#!/usr/bin/env python3
"""
test_integration.py - Integration tests for mod_reqin_log

This script runs integration tests for the mod_reqin_log Apache module.
It acts as the log collector itself: it listens on the Unix socket the module
writes to, parses newline-delimited JSON records, and drives HTTP traffic at
Apache. It covers these scenarios:
  1. basic_logging - Verify JSON logs with expected fields
  2. header_limits - Verify logged header values respect the length limit
  3. sensitive_headers_blacklist - Verify sensitive headers are never logged
  4. socket_unavailable_on_start - Verify reconnect behavior when socket is unavailable
  5. runtime_socket_loss - Verify behavior when socket disappears during traffic
"""

import argparse
import json
import os
import socket
import sys
import threading
import time
import urllib.error
import urllib.request

# Default paths.
# /var/run is normally writable only by privileged users, unlike the
# world-writable /tmp. Override with MOD_REQIN_LOG_SOCKET.
DEFAULT_SOCKET_PATH = os.environ.get("MOD_REQIN_LOG_SOCKET", "/var/run/mod_reqin_log_test.sock")
DEFAULT_APACHE_URL = "http://localhost:8080"

# Longest header value a log entry may contain (checked by test_header_limits).
# NOTE(review): assumed to match the module's configured limit - confirm.
MAX_LOGGED_HEADER_VALUE_LEN = 256

# Delay after restoring the socket before probing for a reconnect.
# NOTE(review): presumably just above the module's reconnect interval - confirm.
RECONNECT_WAIT_SECONDS = 10.5

# Test results
tests_run = 0
tests_passed = 0
tests_failed = 0


def log_info(msg):
    print(f"[INFO] {msg}", file=sys.stderr)


def log_pass(msg):
    global tests_passed
    tests_passed += 1
    print(f"[PASS] {msg}", file=sys.stderr)


def log_fail(msg):
    global tests_failed
    tests_failed += 1
    print(f"[FAIL] {msg}", file=sys.stderr)


def log_test_start(name):
    global tests_run
    tests_run += 1
    print(f"\n[TEST] Starting: {name}", file=sys.stderr)


def _not_a_pytest_test(func):
    """Mark *func* so pytest does not collect it.

    The scenario functions below are named ``test_*`` but take plain
    positional arguments and are driven by run_all_tests(). pytest would
    otherwise try to run them and fail looking for fixtures.
    """
    func.__test__ = False
    return func


class SocketServer:
    """Unix socket server that collects JSON log entries.

    Each accepted connection is served by its own thread, so several writers
    (e.g. several Apache child processes) are drained concurrently. Every
    connection has a private line buffer, so a partial record left behind by
    a dropped connection cannot corrupt the next connection's first record.
    """

    # Poll intervals; they bound how quickly background threads notice stop().
    _ACCEPT_POLL_SECONDS = 0.25
    _RECV_POLL_SECONDS = 0.5

    def __init__(self, socket_path):
        self.socket_path = socket_path
        self.server = None
        self.running = False
        self.entries = []
        self.lock = threading.Lock()  # guards self.entries
        self.thread = None  # accept-loop thread
        self.connection = None  # most recently accepted, still-open connection
        self._conn_lock = threading.Lock()  # guards _connections and connection
        self._connections = set()
        self._handlers = []

    def start(self):
        """Bind the socket and start accepting connections in the background.

        A stale file at ``socket_path`` is removed first. Calling start() on a
        server that is already running does nothing.
        """
        if self.running:
            return
        if os.path.exists(self.socket_path):
            os.remove(self.socket_path)
        server = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        try:
            server.bind(self.socket_path)
            server.listen(5)
            server.settimeout(self._ACCEPT_POLL_SECONDS)
            os.chmod(self.socket_path, 0o660)
        except BaseException:
            server.close()
            raise
        self.server = server
        self.running = True
        # Pass the listening socket explicitly so a thread left over from an
        # earlier start()/stop() cycle can never accept on a newer socket.
        self.thread = threading.Thread(target=self._accept_loop, args=(server,), daemon=True)
        self.thread.start()

    def _accept_loop(self, server):
        """Accept connections on *server* until stop() clears ``running``."""
        while self.running:
            try:
                conn, _ = server.accept()
            except socket.timeout:
                continue
            except OSError as e:
                if not self.running:
                    break
                log_info(f"Socket server error: {e}")
                time.sleep(self._ACCEPT_POLL_SECONDS)  # avoid a hot error loop
                continue
            if not self.running:
                conn.close()
                break
            conn.settimeout(self._RECV_POLL_SECONDS)
            with self._conn_lock:
                self._connections.add(conn)
                self.connection = conn
            handler = threading.Thread(target=self._handle_connection, args=(conn,), daemon=True)
            handler.start()
            self._handlers = [t for t in self._handlers if t.is_alive()]
            self._handlers.append(handler)

    def _handle_connection(self, conn):
        """Read newline-delimited JSON from *conn* until EOF, error, or stop()."""
        pending = b""
        try:
            while self.running:
                try:
                    chunk = conn.recv(4096)
                except socket.timeout:
                    continue
                except OSError:
                    break
                if not chunk:
                    break
                # split() always yields at least one item; the last one is the
                # (possibly empty) unterminated tail kept for the next chunk.
                *complete, pending = (pending + chunk).split(b"\n")
                for raw in complete:
                    line = raw.decode('utf-8', errors='replace')
                    if line.strip():
                        self._process_entry(line)
            # An unterminated tail left in ``pending`` here is an incomplete
            # record from a dropped connection; it is deliberately discarded.
        finally:
            conn.close()
            with self._conn_lock:
                self._connections.discard(conn)
                if self.connection is conn:
                    self.connection = None

    def _process_entry(self, line):
        """Parse one line as JSON and store it if it is a JSON object."""
        try:
            entry = json.loads(line)
        except json.JSONDecodeError:
            log_info(f"Invalid JSON entry: {line[:100]}")
            return
        if not isinstance(entry, dict):
            # The tests read entries with dict methods; ignore other JSON values.
            log_info(f"Non-object JSON entry: {line[:100]}")
            return
        with self.lock:
            self.entries.append(entry)

    def stop(self):
        """Stop accepting, drop open connections, and remove the socket file.

        Waits (bounded) for the background threads, so a following start()
        never races with threads from this run.
        """
        self.running = False
        with self._conn_lock:
            open_conns = list(self._connections)
        for conn in open_conns:
            try:
                # shutdown() wakes a recv() blocked in the handler thread and
                # tells the peer at once that the connection is gone.
                conn.shutdown(socket.SHUT_RDWR)
            except OSError:
                pass
        if self.thread is not None:
            self.thread.join(timeout=2.0)
            self.thread = None
        for handler in self._handlers:
            handler.join(timeout=2.0)
        self._handlers = []
        if self.server is not None:
            try:
                self.server.close()
            except OSError:
                pass
            self.server = None
        if os.path.exists(self.socket_path):
            try:
                os.remove(self.socket_path)
            except OSError:
                pass

    def get_entries(self):
        """Return a snapshot copy of the collected log entries."""
        with self.lock:
            return list(self.entries)

    def clear_entries(self):
        """Discard all collected entries."""
        with self.lock:
            self.entries.clear()

    def wait_for_entries(self, count, timeout=5.0):
        """Wait up to *timeout* seconds for at least *count* entries.

        Returns True as soon as enough entries are present, False otherwise.
        The count is always checked at least once, including at the deadline.
        """
        deadline = time.monotonic() + timeout
        while True:
            with self.lock:
                if len(self.entries) >= count:
                    return True
            if time.monotonic() >= deadline:
                return False
            time.sleep(0.1)


def make_request(url, headers=None, method='GET'):
    """Make an HTTP request using urllib and return ``(status, body)``.

    HTTP error responses (4xx/5xx) are still responses served by Apache, so
    their status code and body are returned as-is. Only failures where no
    response was obtained (connection refused, timeout, ...) return
    ``(None, error_message)``.
    """
    req = urllib.request.Request(url, method=method)
    for key, value in (headers or {}).items():
        req.add_header(key, value)
    try:
        with urllib.request.urlopen(req, timeout=5) as response:
            return response.status, response.read().decode('utf-8', errors='replace')
    except urllib.error.HTTPError as err:
        try:
            body = err.read().decode('utf-8', errors='replace')
        except Exception:
            body = str(err)
        finally:
            err.close()
        return err.code, body
    except Exception as e:
        return None, str(e)


# ============================================================================
# Test 1: Basic Logging
# ============================================================================

@_not_a_pytest_test
def test_basic_logging(socket_server, apache_url):
    """Check that a plain GET produces an entry with the core fields."""
    log_test_start("basic_logging")
    socket_server.clear_entries()

    status, _ = make_request(f"{apache_url}/")
    if status is None:
        log_fail("basic_logging - HTTP request failed")
        return False

    if not socket_server.wait_for_entries(1, timeout=3.0):
        log_fail("basic_logging - No log entries received")
        return False

    entry = socket_server.get_entries()[-1]

    required_fields = ['time', 'timestamp', 'src_ip', 'src_port', 'dst_ip', 'dst_port',
                       'method', 'path', 'host', 'http_version']
    missing_fields = [field for field in required_fields if field not in entry]
    if missing_fields:
        log_fail(f"basic_logging - Missing fields: {missing_fields}")
        return False

    if entry.get('method') != 'GET':
        log_fail(f"basic_logging - Expected method 'GET', got '{entry.get('method')}'")
        return False

    timestamp = entry.get('timestamp')
    # bool is a subclass of int, so JSON true/false must be excluded explicitly.
    if not isinstance(timestamp, int) or isinstance(timestamp, bool):
        log_fail(f"basic_logging - timestamp should be integer, got {type(timestamp)}")
        return False

    time_value = entry.get('time')
    # Loose sanity check only: a date string whose year starts with "20".
    if not (isinstance(time_value, str) and time_value.startswith('20')):
        log_fail(f"basic_logging - Invalid time format: {time_value}")
        return False

    log_pass("basic_logging - All required fields present and valid")
    return True


# ============================================================================
# Test 2: Header Limits
# ============================================================================

@_not_a_pytest_test
def test_header_limits(socket_server, apache_url):
    """Check that no logged header value exceeds MAX_LOGGED_HEADER_VALUE_LEN."""
    log_test_start("header_limits")
    socket_server.clear_entries()

    headers = {
        'X-Request-Id': 'test-123',
        'X-Trace-Id': 'trace-456',
        'User-Agent': 'TestAgent/1.0',
        # Longer than the limit: it must be logged truncated, or not at all.
        'X-Long-Header': 'A' * 500,
    }

    status, _ = make_request(f"{apache_url}/test-headers", headers=headers)
    if status is None:
        log_fail("header_limits - HTTP request failed")
        return False

    if not socket_server.wait_for_entries(1, timeout=3.0):
        log_fail("header_limits - No log entries received")
        return False

    entry = socket_server.get_entries()[-1]
    header_fields = [k for k in entry if k.startswith('header_')]

    for key in header_fields:
        value = entry[key]
        if isinstance(value, str) and len(value) > MAX_LOGGED_HEADER_VALUE_LEN:
            log_fail(f"header_limits - Header value not truncated: {key} has {len(value)} chars")
            return False

    log_pass(f"header_limits - Headers logged correctly ({len(header_fields)} header fields)")
    return True


# ============================================================================
# Test 3: Sensitive Headers Blacklist
# ============================================================================

@_not_a_pytest_test
def test_sensitive_headers_blacklist(socket_server, apache_url):
    """Check that credential-bearing headers never reach the log."""
    log_test_start("sensitive_headers_blacklist")
    socket_server.clear_entries()

    sensitive_headers = {
        'Authorization': 'Bearer secret-token',
        'Cookie': 'sessionid=super-secret',
        'X-Api-Key': 'api-key-secret',
    }
    headers = {'X-Request-Id': 'blacklist-test', **sensitive_headers}

    status, _ = make_request(f"{apache_url}/blacklist-check", headers=headers)
    if status is None:
        log_fail("sensitive_headers_blacklist - HTTP request failed")
        return False

    if not socket_server.wait_for_entries(1, timeout=3.0):
        log_fail("sensitive_headers_blacklist - No log entries received")
        return False

    entry = socket_server.get_entries()[-1]

    # HTTP header names are case-insensitive, so a leak may appear under any
    # casing of the key (e.g. header_authorization).
    forbidden_keys = {f"header_{name}".lower() for name in sensitive_headers}
    for key in entry:
        if key.lower() in forbidden_keys:
            log_fail(f"sensitive_headers_blacklist - Sensitive header leaked: {key}")
            return False

    # Also catch a secret value logged under some other key name.
    serialized = json.dumps(entry)
    for name, value in sensitive_headers.items():
        if value in serialized:
            log_fail(f"sensitive_headers_blacklist - Value of sensitive header {name} leaked")
            return False

    log_pass("sensitive_headers_blacklist - Sensitive headers are not logged")
    return True


# ============================================================================
# Test 4: Socket Unavailable on Start
# ============================================================================

@_not_a_pytest_test
def test_socket_unavailable_on_start(socket_server, apache_url):
    """Check that Apache keeps serving without the socket and later reconnects."""
    log_test_start("socket_unavailable_on_start")

    socket_server.stop()
    time.sleep(0.5)

    for i in range(3):
        make_request(f"{apache_url}/unavailable-{i}")
        time.sleep(0.2)

    status, _ = make_request(f"{apache_url}/final-check")
    # Any non-5xx response means Apache still served the request; the path
    # need not exist, so e.g. a 404 is acceptable here.
    if status is None or status >= 500:
        log_fail("socket_unavailable_on_start - Request failed when socket unavailable")
        socket_server.start()
        return False

    socket_server.start()
    time.sleep(RECONNECT_WAIT_SECONDS)

    socket_server.clear_entries()
    make_request(f"{apache_url}/after-reconnect")

    if socket_server.wait_for_entries(1, timeout=12.0):
        log_pass("socket_unavailable_on_start - Module reconnected after socket became available")
        return True

    log_fail("socket_unavailable_on_start - Module did not reconnect")
    return False


# ============================================================================
# Test 5: Runtime Socket Loss
# ============================================================================

@_not_a_pytest_test
def test_runtime_socket_loss(socket_server, apache_url):
    """Check that losing the socket mid-traffic neither blocks requests nor
    prevents recovery once the socket is back."""
    log_test_start("runtime_socket_loss")
    socket_server.clear_entries()

    for i in range(3):
        make_request(f"{apache_url}/before-loss-{i}")

    if not socket_server.wait_for_entries(3, timeout=3.0):
        log_fail("runtime_socket_loss - Initial requests not logged")
        return False

    initial_count = len(socket_server.get_entries())

    socket_server.stop()
    time.sleep(0.3)

    for i in range(3):
        req_start = time.monotonic()
        make_request(f"{apache_url}/during-loss-{i}")
        req_duration = time.monotonic() - req_start
        if req_duration > 2.0:
            log_fail(f"runtime_socket_loss - Request blocked for {req_duration:.2f}s")
            socket_server.start()
            return False

    time.sleep(0.5)
    current_count = len(socket_server.get_entries())
    if current_count != initial_count:
        log_info(f"runtime_socket_loss - Some entries logged during socket loss (expected: {initial_count}, got: {current_count})")

    socket_server.start()
    time.sleep(RECONNECT_WAIT_SECONDS)

    socket_server.clear_entries()
    make_request(f"{apache_url}/after-loss")

    if socket_server.wait_for_entries(1, timeout=12.0):
        log_pass("runtime_socket_loss - Module recovered after socket restored")
        return True

    log_fail("runtime_socket_loss - Module did not recover after socket restored")
    return False


# ============================================================================
# Main Test Runner
# ============================================================================

def run_all_tests(apache_url, socket_path):
    """Run every scenario against *apache_url* and print a summary.

    A scenario that raises unexpectedly is recorded as a failure, and the
    remaining scenarios still run. Returns True when no scenario failed.
    """
    print("=" * 60, file=sys.stderr)
    print("mod_reqin_log Integration Tests", file=sys.stderr)
    print("=" * 60, file=sys.stderr)

    server = SocketServer(socket_path)
    server.start()
    log_info(f"Socket server started on {socket_path}")

    time.sleep(1.0)

    scenarios = (
        test_basic_logging,
        test_header_limits,
        test_sensitive_headers_blacklist,
        test_socket_unavailable_on_start,
        test_runtime_socket_loss,
    )
    try:
        for scenario in scenarios:
            try:
                scenario(server, apache_url)
            except Exception as exc:
                name = scenario.__name__[len("test_"):]
                log_fail(f"{name} - Unexpected error: {exc!r}")
                # A scenario that died mid-way may have left the collector stopped;
                # start() is a no-op if it is still running.
                server.start()
    finally:
        server.stop()
        log_info("Socket server stopped")

    print("\n" + "=" * 60, file=sys.stderr)
    print("Test Summary", file=sys.stderr)
    print("=" * 60, file=sys.stderr)
    print(f"Tests run: {tests_run}", file=sys.stderr)
    print(f"Tests passed: {tests_passed}", file=sys.stderr)
    print(f"Tests failed: {tests_failed}", file=sys.stderr)
    print("=" * 60, file=sys.stderr)

    return tests_failed == 0


def main():
    parser = argparse.ArgumentParser(description='Integration tests for mod_reqin_log')
    parser.add_argument('--socket', default=DEFAULT_SOCKET_PATH,
                        help=f'Unix socket path (default: {DEFAULT_SOCKET_PATH})')
    parser.add_argument('--url', default=DEFAULT_APACHE_URL,
                        help=f'Apache URL (default: {DEFAULT_APACHE_URL})')
    args = parser.parse_args()

    success = run_all_tests(args.url, args.socket)
    sys.exit(0 if success else 1)


if __name__ == '__main__':
    main()