fix: renforcer la robustesse du module et étendre les tests/CI

Co-authored-by: aider (openrouter/openai/gpt-5.3-codex) <aider@aider.chat>
This commit is contained in:
Jacquin Antoine
2026-02-28 20:28:40 +01:00
parent a935ed1641
commit 46291898e1
10 changed files with 735 additions and 447 deletions

View File

@ -3,24 +3,21 @@
test_integration.py - Integration tests for mod_reqin_log
This script runs integration tests for the mod_reqin_log Apache module.
It tests the 4 required scenarios from architecture.yml:
It tests the required scenarios:
1. basic_logging - Verify JSON logs with expected fields
2. header_limits - Verify header count and value length limits
3. socket_unavailable_on_start - Verify reconnect behavior when socket is unavailable
4. runtime_socket_loss - Verify behavior when socket disappears during traffic
3. sensitive_headers_blacklist - Verify sensitive headers are never logged
4. socket_unavailable_on_start - Verify reconnect behavior when socket is unavailable
5. runtime_socket_loss - Verify behavior when socket disappears during traffic
"""
import socket
import os
import sys
import json
import signal
import time
import subprocess
import threading
import argparse
from datetime import datetime
from http.server import HTTPServer, BaseHTTPRequestHandler
# Default paths
# Use /var/run for production (more secure than /tmp)
@ -32,12 +29,6 @@ tests_run = 0
tests_passed = 0
tests_failed = 0
# Global flags
shutdown_requested = False
log_entries = []
socket_server = None
socket_thread = None
def log_info(msg):
print(f"[INFO] {msg}", file=sys.stderr)
@ -63,7 +54,7 @@ def log_test_start(name):
class SocketServer:
"""Unix socket server that collects JSON log entries."""
def __init__(self, socket_path):
self.socket_path = socket_path
self.server = None
@ -72,31 +63,28 @@ class SocketServer:
self.lock = threading.Lock()
self.connection = None
self.buffer = b""
def start(self):
"""Start the socket server."""
# Remove existing socket
if os.path.exists(self.socket_path):
os.remove(self.socket_path)
# Create socket
self.server = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
self.server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
self.server.bind(self.socket_path)
self.server.listen(5)
self.server.settimeout(1.0)
os.chmod(self.socket_path, 0o666)
os.chmod(self.socket_path, 0o660)
self.running = True
# Start accept thread
self.thread = threading.Thread(target=self._accept_loop, daemon=True)
self.thread.start()
def _accept_loop(self):
"""Accept connections and read data."""
while self.running:
try:
conn, addr = self.server.accept()
conn, _ = self.server.accept()
conn.settimeout(0.5)
self.connection = conn
while self.running:
@ -105,8 +93,7 @@ class SocketServer:
if not chunk:
break
self.buffer += chunk
# Process complete lines
while b'\n' in self.buffer:
newline_pos = self.buffer.index(b'\n')
line = self.buffer[:newline_pos].decode('utf-8', errors='replace')
@ -115,7 +102,7 @@ class SocketServer:
self._process_entry(line)
except socket.timeout:
continue
except Exception as e:
except Exception:
break
conn.close()
self.connection = None
@ -124,7 +111,7 @@ class SocketServer:
except Exception as e:
if self.running:
log_info(f"Socket server error: {e}")
def _process_entry(self, line):
"""Process a log entry."""
try:
@ -133,36 +120,36 @@ class SocketServer:
self.entries.append(entry)
except json.JSONDecodeError:
log_info(f"Invalid JSON entry: {line[:100]}")
def stop(self):
"""Stop the socket server."""
self.running = False
if self.connection:
try:
self.connection.close()
except:
except Exception:
pass
if self.server:
try:
self.server.close()
except:
except Exception:
pass
if os.path.exists(self.socket_path):
try:
os.remove(self.socket_path)
except:
except Exception:
pass
def get_entries(self):
"""Get collected log entries."""
with self.lock:
return list(self.entries)
def clear_entries(self):
"""Clear collected entries."""
with self.lock:
self.entries.clear()
def wait_for_entries(self, count, timeout=5.0):
"""Wait for at least 'count' entries to arrive."""
start = time.time()
@ -175,14 +162,14 @@ class SocketServer:
def make_request(url, headers=None, method='GET'):
"""Make an HTTP request using curl."""
"""Make an HTTP request using urllib."""
import urllib.request
req = urllib.request.Request(url, method=method)
if headers:
for key, value in headers.items():
req.add_header(key, value)
try:
with urllib.request.urlopen(req, timeout=5) as response:
return response.status, response.read().decode('utf-8', errors='replace')
@ -194,52 +181,42 @@ def make_request(url, headers=None, method='GET'):
# Test 1: Basic Logging
# ============================================================================
def test_basic_logging(socket_server, apache_url):
"""
Test: basic_logging
Description: With JsonSockLogEnabled On and valid socket, verify that each request
produces a valid JSON line with expected fields.
"""
log_test_start("basic_logging")
socket_server.clear_entries()
# Make a simple request
status, _ = make_request(f"{apache_url}/")
# Wait for log entry
if status is None:
log_fail("basic_logging - HTTP request failed")
return False
if not socket_server.wait_for_entries(1, timeout=3.0):
log_fail("basic_logging - No log entries received")
return False
entries = socket_server.get_entries()
entry = entries[-1]
# Verify required fields
required_fields = ['time', 'timestamp', 'src_ip', 'src_port', 'dst_ip',
required_fields = ['time', 'timestamp', 'src_ip', 'src_port', 'dst_ip',
'dst_port', 'method', 'path', 'host', 'http_version']
missing_fields = []
for field in required_fields:
if field not in entry:
missing_fields.append(field)
missing_fields = [field for field in required_fields if field not in entry]
if missing_fields:
log_fail(f"basic_logging - Missing fields: {missing_fields}")
return False
# Verify field types and values
if entry.get('method') != 'GET':
log_fail(f"basic_logging - Expected method 'GET', got '{entry.get('method')}'")
return False
if not isinstance(entry.get('timestamp'), int):
log_fail(f"basic_logging - timestamp should be integer, got {type(entry.get('timestamp'))}")
return False
if not entry.get('time', '').startswith('20'):
log_fail(f"basic_logging - Invalid time format: {entry.get('time')}")
return False
log_pass("basic_logging - All required fields present and valid")
return True
@ -248,188 +225,189 @@ def test_basic_logging(socket_server, apache_url):
# Test 2: Header Limits
# ============================================================================
def test_header_limits(socket_server, apache_url):
"""
Test: header_limits
Description: Configure more headers than JsonSockLogMaxHeaders and verify only
the first N are logged and values are truncated.
"""
log_test_start("header_limits")
socket_server.clear_entries()
# Make request with multiple headers including a long one
headers = {
'X-Request-Id': 'test-123',
'X-Trace-Id': 'trace-456',
'User-Agent': 'TestAgent/1.0',
'X-Long-Header': 'A' * 500, # Very long value
'X-Long-Header': 'A' * 500,
}
status, _ = make_request(f"{apache_url}/test-headers", headers=headers)
# Wait for log entry
if status is None:
log_fail("header_limits - HTTP request failed")
return False
if not socket_server.wait_for_entries(1, timeout=3.0):
log_fail("header_limits - No log entries received")
return False
entries = socket_server.get_entries()
entry = entries[-1]
# Check that header fields are present (implementation logs configured headers)
header_fields = [k for k in entry.keys() if k.startswith('header_')]
# Verify header value truncation (max 256 chars by default)
for key, value in entry.items():
if key.startswith('header_') and isinstance(value, str):
if len(value) > 256:
log_fail(f"header_limits - Header value not truncated: {key} has {len(value)} chars")
return False
if key.startswith('header_') and isinstance(value, str) and len(value) > 256:
log_fail(f"header_limits - Header value not truncated: {key} has {len(value)} chars")
return False
log_pass(f"header_limits - Headers logged correctly ({len(header_fields)} header fields)")
return True
# ============================================================================
# Test 3: Socket Unavailable on Start
# Test 3: Sensitive Headers Blacklist
# ============================================================================
def test_sensitive_headers_blacklist(socket_server, apache_url):
log_test_start("sensitive_headers_blacklist")
socket_server.clear_entries()
headers = {
'X-Request-Id': 'blacklist-test',
'Authorization': 'Bearer secret-token',
'Cookie': 'sessionid=super-secret',
'X-Api-Key': 'api-key-secret',
}
status, _ = make_request(f"{apache_url}/blacklist-check", headers=headers)
if status is None:
log_fail("sensitive_headers_blacklist - HTTP request failed")
return False
if not socket_server.wait_for_entries(1, timeout=3.0):
log_fail("sensitive_headers_blacklist - No log entries received")
return False
entry = socket_server.get_entries()[-1]
forbidden_keys = [
'header_Authorization',
'header_Cookie',
'header_X-Api-Key',
]
for key in forbidden_keys:
if key in entry:
log_fail(f"sensitive_headers_blacklist - Sensitive header leaked: {key}")
return False
log_pass("sensitive_headers_blacklist - Sensitive headers are not logged")
return True
# ============================================================================
# Test 4: Socket Unavailable on Start
# ============================================================================
def test_socket_unavailable_on_start(socket_server, apache_url):
"""
Test: socket_unavailable_on_start
Description: Start with socket not yet created; verify periodic reconnect attempts
and throttled error logging.
"""
log_test_start("socket_unavailable_on_start")
# Stop the socket server to simulate unavailable socket
socket_server.stop()
time.sleep(0.5)
# Make requests while socket is unavailable
for i in range(3):
make_request(f"{apache_url}/unavailable-{i}")
time.sleep(0.2)
# Requests should still succeed (logging failures don't affect client)
status, _ = make_request(f"{apache_url}/final-check")
if status != 200:
log_fail("socket_unavailable_on_start - Request failed when socket unavailable")
# Restart socket server for subsequent tests
socket_server.start()
return False
# Restart socket server
socket_server.start()
time.sleep(0.5)
# Verify module can reconnect
time.sleep(10.5)
socket_server.clear_entries()
status, _ = make_request(f"{apache_url}/after-reconnect")
if socket_server.wait_for_entries(1, timeout=3.0):
make_request(f"{apache_url}/after-reconnect")
if socket_server.wait_for_entries(1, timeout=12.0):
log_pass("socket_unavailable_on_start - Module reconnected after socket became available")
return True
else:
log_fail("socket_unavailable_on_start - Module did not reconnect")
return False
log_fail("socket_unavailable_on_start - Module did not reconnect")
return False
# ============================================================================
# Test 4: Runtime Socket Loss
# Test 5: Runtime Socket Loss
# ============================================================================
def test_runtime_socket_loss(socket_server, apache_url):
"""
Test: runtime_socket_loss
Description: Drop the Unix socket while traffic is ongoing; verify that log lines
are dropped, worker threads are not blocked, and reconnect attempts
resume once the socket reappears.
"""
log_test_start("runtime_socket_loss")
socket_server.clear_entries()
# Make some initial requests
for i in range(3):
make_request(f"{apache_url}/before-loss-{i}")
if not socket_server.wait_for_entries(3, timeout=3.0):
log_fail("runtime_socket_loss - Initial requests not logged")
return False
initial_count = len(socket_server.get_entries())
# Simulate socket loss by stopping server
socket_server.stop()
time.sleep(0.3)
# Make requests while socket is gone
start_time = time.time()
for i in range(3):
req_start = time.time()
status, _ = make_request(f"{apache_url}/during-loss-{i}")
make_request(f"{apache_url}/during-loss-{i}")
req_duration = time.time() - req_start
# Requests should NOT block (should complete quickly)
if req_duration > 2.0:
log_fail(f"runtime_socket_loss - Request blocked for {req_duration:.2f}s")
socket_server.start()
return False
# Give time for any pending logs
time.sleep(0.5)
# Verify no new entries were logged (socket was down)
current_count = len(socket_server.get_entries())
if current_count != initial_count:
log_info(f"runtime_socket_loss - Some entries logged during socket loss (expected: {initial_count}, got: {current_count})")
# Restart socket server
socket_server.start()
time.sleep(0.5)
# Verify module can reconnect and log again
time.sleep(10.5)
socket_server.clear_entries()
status, _ = make_request(f"{apache_url}/after-loss")
if socket_server.wait_for_entries(1, timeout=3.0):
make_request(f"{apache_url}/after-loss")
if socket_server.wait_for_entries(1, timeout=12.0):
log_pass("runtime_socket_loss - Module recovered after socket restored")
return True
else:
log_fail("runtime_socket_loss - Module did not recover after socket restored")
return False
log_fail("runtime_socket_loss - Module did not recover after socket restored")
return False
# ============================================================================
# Main Test Runner
# ============================================================================
def run_all_tests(apache_url, socket_path):
"""Run all integration tests."""
global tests_run, tests_passed, tests_failed
print("=" * 60, file=sys.stderr)
print("mod_reqin_log Integration Tests", file=sys.stderr)
print("=" * 60, file=sys.stderr)
# Create socket server
server = SocketServer(socket_path)
server.start()
log_info(f"Socket server started on {socket_path}")
# Give Apache time to connect
time.sleep(1.0)
try:
# Run all tests
test_basic_logging(server, apache_url)
test_header_limits(server, apache_url)
test_sensitive_headers_blacklist(server, apache_url)
test_socket_unavailable_on_start(server, apache_url)
test_runtime_socket_loss(server, apache_url)
finally:
# Cleanup
server.stop()
log_info("Socket server stopped")
# Print summary
print("\n" + "=" * 60, file=sys.stderr)
print("Test Summary", file=sys.stderr)
print("=" * 60, file=sys.stderr)
@ -437,7 +415,7 @@ def run_all_tests(apache_url, socket_path):
print(f"Tests passed: {tests_passed}", file=sys.stderr)
print(f"Tests failed: {tests_failed}", file=sys.stderr)
print("=" * 60, file=sys.stderr)
return tests_failed == 0
@ -448,7 +426,7 @@ def main():
parser.add_argument('--url', default=DEFAULT_APACHE_URL,
help=f'Apache URL (default: {DEFAULT_APACHE_URL})')
args = parser.parse_args()
success = run_all_tests(args.url, args.socket)
sys.exit(0 if success else 1)

View File

@ -9,12 +9,15 @@
#include <string.h>
#include <stdio.h>
#include <stdlib.h>
#include <errno.h>
#include <limits.h>
/* Default configuration values */
#define DEFAULT_MAX_HEADERS 10
#define DEFAULT_MAX_HEADER_VALUE_LEN 256
#define DEFAULT_RECONNECT_INTERVAL 10
#define DEFAULT_ERROR_REPORT_INTERVAL 10
#define MAX_SOCKET_PATH_LEN 108
/* Mock configuration structure */
typedef struct {
@ -40,34 +43,59 @@ static const char *parse_socket_path(const char *value)
if (value == NULL || strlen(value) == 0) {
return NULL;
}
if (strlen(value) >= MAX_SOCKET_PATH_LEN) {
return NULL;
}
return value;
}
static int parse_int_strict(const char *value, int *result)
{
char *endptr = NULL;
long val;
if (value == NULL || *value == '\0' || result == NULL) {
return -1;
}
errno = 0;
val = strtol(value, &endptr, 10);
if (errno != 0 || endptr == value || *endptr != '\0' || val < INT_MIN || val > INT_MAX) {
return -1;
}
*result = (int)val;
return 0;
}
static int parse_max_headers(const char *value, int *result)
{
char *endptr;
long val = strtol(value, &endptr, 10);
if (*endptr != '\0' || val < 0) {
if (parse_int_strict(value, result) != 0 || *result < 0) {
return -1;
}
*result = (int)val;
return 0;
}
static int parse_interval(const char *value, int *result)
{
char *endptr;
long val = strtol(value, &endptr, 10);
if (*endptr != '\0' || val < 0) {
if (parse_int_strict(value, result) != 0 || *result < 0) {
return -1;
}
return 0;
}
static int parse_max_header_value_len(const char *value, int *result)
{
if (parse_int_strict(value, result) != 0 || *result < 1) {
return -1;
}
*result = (int)val;
return 0;
}
/* Test: Parse enabled On */
static void test_parse_enabled_on(void **state)
{
(void)state;
assert_int_equal(parse_enabled("On"), 1);
assert_int_equal(parse_enabled("on"), 1);
assert_int_equal(parse_enabled("ON"), 1);
@ -77,6 +105,7 @@ static void test_parse_enabled_on(void **state)
/* Test: Parse enabled Off */
static void test_parse_enabled_off(void **state)
{
(void)state;
assert_int_equal(parse_enabled("Off"), 0);
assert_int_equal(parse_enabled("off"), 0);
assert_int_equal(parse_enabled("OFF"), 0);
@ -86,6 +115,7 @@ static void test_parse_enabled_off(void **state)
/* Test: Parse socket path valid */
static void test_parse_socket_path_valid(void **state)
{
(void)state;
const char *result = parse_socket_path("/var/run/mod_reqin_log.sock");
assert_string_equal(result, "/var/run/mod_reqin_log.sock");
}
@ -93,6 +123,7 @@ static void test_parse_socket_path_valid(void **state)
/* Test: Parse socket path empty */
static void test_parse_socket_path_empty(void **state)
{
(void)state;
const char *result = parse_socket_path("");
assert_null(result);
}
@ -100,20 +131,45 @@ static void test_parse_socket_path_empty(void **state)
/* Test: Parse socket path NULL */
static void test_parse_socket_path_null(void **state)
{
(void)state;
const char *result = parse_socket_path(NULL);
assert_null(result);
}
/* Test: Parse socket path max length valid */
static void test_parse_socket_path_max_len_valid(void **state)
{
(void)state;
char path[MAX_SOCKET_PATH_LEN];
memset(path, 'a', MAX_SOCKET_PATH_LEN - 1);
path[MAX_SOCKET_PATH_LEN - 1] = '\0';
assert_non_null(parse_socket_path(path));
}
/* Test: Parse socket path max length invalid */
static void test_parse_socket_path_max_len_invalid(void **state)
{
(void)state;
char path[MAX_SOCKET_PATH_LEN + 1];
memset(path, 'b', MAX_SOCKET_PATH_LEN);
path[MAX_SOCKET_PATH_LEN] = '\0';
assert_null(parse_socket_path(path));
}
/* Test: Parse max headers valid */
static void test_parse_max_headers_valid(void **state)
{
int result;
(void)state;
assert_int_equal(parse_max_headers("10", &result), 0);
assert_int_equal(result, 10);
assert_int_equal(parse_max_headers("0", &result), 0);
assert_int_equal(result, 0);
assert_int_equal(parse_max_headers("100", &result), 0);
assert_int_equal(result, 100);
}
@ -122,6 +178,8 @@ static void test_parse_max_headers_valid(void **state)
static void test_parse_max_headers_invalid(void **state)
{
int result;
(void)state;
assert_int_equal(parse_max_headers("-1", &result), -1);
assert_int_equal(parse_max_headers("abc", &result), -1);
assert_int_equal(parse_max_headers("10abc", &result), -1);
@ -131,12 +189,14 @@ static void test_parse_max_headers_invalid(void **state)
static void test_parse_reconnect_interval_valid(void **state)
{
int result;
(void)state;
assert_int_equal(parse_interval("10", &result), 0);
assert_int_equal(result, 10);
assert_int_equal(parse_interval("0", &result), 0);
assert_int_equal(result, 0);
assert_int_equal(parse_interval("60", &result), 0);
assert_int_equal(result, 60);
}
@ -145,13 +205,52 @@ static void test_parse_reconnect_interval_valid(void **state)
static void test_parse_reconnect_interval_invalid(void **state)
{
int result;
(void)state;
assert_int_equal(parse_interval("-5", &result), -1);
assert_int_equal(parse_interval("abc", &result), -1);
assert_int_equal(parse_interval("10abc", &result), -1);
}
/* Test: Parse max header value length valid */
static void test_parse_max_header_value_len_valid(void **state)
{
int result;
(void)state;
assert_int_equal(parse_max_header_value_len("1", &result), 0);
assert_int_equal(result, 1);
assert_int_equal(parse_max_header_value_len("256", &result), 0);
assert_int_equal(result, 256);
}
/* Test: Parse max header value length invalid */
static void test_parse_max_header_value_len_invalid(void **state)
{
int result;
(void)state;
assert_int_equal(parse_max_header_value_len("0", &result), -1);
assert_int_equal(parse_max_header_value_len("-1", &result), -1);
assert_int_equal(parse_max_header_value_len("10abc", &result), -1);
}
/* Test: strict numeric parsing invalid suffix for all int directives */
static void test_strict_numeric_invalid_suffix_all(void **state)
{
int result;
(void)state;
assert_int_equal(parse_max_headers("10abc", &result), -1);
assert_int_equal(parse_interval("10abc", &result), -1);
assert_int_equal(parse_max_header_value_len("10abc", &result), -1);
}
/* Test: Default configuration values */
static void test_default_config_values(void **state)
{
(void)state;
assert_int_equal(DEFAULT_MAX_HEADERS, 10);
assert_int_equal(DEFAULT_MAX_HEADER_VALUE_LEN, 256);
assert_int_equal(DEFAULT_RECONNECT_INTERVAL, 10);
@ -161,34 +260,47 @@ static void test_default_config_values(void **state)
/* Test: Configuration validation - enabled requires socket */
static void test_config_validation_enabled_requires_socket(void **state)
{
/* Valid: enabled with socket */
int enabled = 1;
const char *socket = "/var/run/socket";
(void)state;
assert_true(enabled == 0 || socket != NULL);
/* Invalid: enabled without socket */
socket = NULL;
assert_false(enabled == 0 || socket != NULL);
}
/* Test: Configuration validation - enabled with empty socket is invalid */
static void test_config_validation_enabled_with_empty_socket(void **state)
{
int enabled = 1;
const char *socket = parse_socket_path("");
(void)state;
assert_false(enabled == 0 || socket != NULL);
}
/* Test: Header value length validation */
static void test_header_value_len_validation(void **state)
{
int result;
assert_int_equal(parse_interval("1", &result), 0);
(void)state;
assert_int_equal(parse_max_header_value_len("1", &result), 0);
assert_true(result >= 1);
assert_int_equal(parse_interval("0", &result), 0);
assert_false(result >= 1);
assert_int_equal(parse_max_header_value_len("0", &result), -1);
}
/* Test: Large but valid values */
static void test_large_valid_values(void **state)
{
int result;
(void)state;
assert_int_equal(parse_max_headers("1000000", &result), 0);
assert_int_equal(result, 1000000);
assert_int_equal(parse_interval("86400", &result), 0);
assert_int_equal(result, 86400);
}
@ -201,15 +313,21 @@ int main(void)
cmocka_unit_test(test_parse_socket_path_valid),
cmocka_unit_test(test_parse_socket_path_empty),
cmocka_unit_test(test_parse_socket_path_null),
cmocka_unit_test(test_parse_socket_path_max_len_valid),
cmocka_unit_test(test_parse_socket_path_max_len_invalid),
cmocka_unit_test(test_parse_max_headers_valid),
cmocka_unit_test(test_parse_max_headers_invalid),
cmocka_unit_test(test_parse_reconnect_interval_valid),
cmocka_unit_test(test_parse_reconnect_interval_invalid),
cmocka_unit_test(test_parse_max_header_value_len_valid),
cmocka_unit_test(test_parse_max_header_value_len_invalid),
cmocka_unit_test(test_strict_numeric_invalid_suffix_all),
cmocka_unit_test(test_default_config_values),
cmocka_unit_test(test_config_validation_enabled_requires_socket),
cmocka_unit_test(test_config_validation_enabled_with_empty_socket),
cmocka_unit_test(test_header_value_len_validation),
cmocka_unit_test(test_large_valid_values),
};
return cmocka_run_group_tests(tests, NULL, NULL);
}

View File

@ -8,34 +8,88 @@
#include <cmocka.h>
#include <string.h>
#include <stdio.h>
#include <apr_pools.h>
#include <apr_strings.h>
#include <apr_time.h>
#include <apr_lib.h>
/* Mock JSON string escaping function for testing */
static void append_json_string(apr_pool_t *pool, apr_strbuf_t *buf, const char *str)
typedef struct {
char *data;
size_t len;
size_t cap;
apr_pool_t *pool;
} testbuf_t;
static void testbuf_init(testbuf_t *buf, apr_pool_t *pool, size_t initial_capacity)
{
buf->pool = pool;
buf->cap = initial_capacity;
buf->len = 0;
buf->data = apr_palloc(pool, initial_capacity);
buf->data[0] = '\0';
}
static void testbuf_append(testbuf_t *buf, const char *str, size_t len)
{
if (str == NULL) {
return;
}
if (len == (size_t)-1) {
len = strlen(str);
}
if (buf->len + len + 1 > buf->cap) {
size_t new_cap = (buf->len + len + 1) * 2;
char *new_data = apr_palloc(buf->pool, new_cap);
memcpy(new_data, buf->data, buf->len + 1);
buf->data = new_data;
buf->cap = new_cap;
}
memcpy(buf->data + buf->len, str, len);
buf->len += len;
buf->data[buf->len] = '\0';
}
static void testbuf_append_char(testbuf_t *buf, char c)
{
if (buf->len + 2 > buf->cap) {
size_t new_cap = (buf->cap * 2);
char *new_data = apr_palloc(buf->pool, new_cap);
memcpy(new_data, buf->data, buf->len + 1);
buf->data = new_data;
buf->cap = new_cap;
}
buf->data[buf->len++] = c;
buf->data[buf->len] = '\0';
}
/* Mock JSON string escaping function for testing */
static void append_json_string(testbuf_t *buf, const char *str)
{
if (str == NULL) {
return;
}
for (const char *p = str; *p; p++) {
char c = *p;
switch (c) {
case '"': apr_strbuf_append(buf, "\\\"", 2); break;
case '\\': apr_strbuf_append(buf, "\\\\", 2); break;
case '\b': apr_strbuf_append(buf, "\\b", 2); break;
case '\f': apr_strbuf_append(buf, "\\f", 2); break;
case '\n': apr_strbuf_append(buf, "\\n", 2); break;
case '\r': apr_strbuf_append(buf, "\\r", 2); break;
case '\t': apr_strbuf_append(buf, "\\t", 2); break;
case '"': testbuf_append(buf, "\\\"", 2); break;
case '\\': testbuf_append(buf, "\\\\", 2); break;
case '\b': testbuf_append(buf, "\\b", 2); break;
case '\f': testbuf_append(buf, "\\f", 2); break;
case '\n': testbuf_append(buf, "\\n", 2); break;
case '\r': testbuf_append(buf, "\\r", 2); break;
case '\t': testbuf_append(buf, "\\t", 2); break;
default:
if ((unsigned char)c < 0x20) {
char unicode[8];
apr_snprintf(unicode, sizeof(unicode), "\\u%04x", (unsigned char)c);
apr_strbuf_append(buf, unicode, -1);
testbuf_append(buf, unicode, (size_t)-1);
} else {
apr_strbuf_append_char(buf, c);
testbuf_append_char(buf, c);
}
break;
}
@ -46,16 +100,16 @@ static void append_json_string(apr_pool_t *pool, apr_strbuf_t *buf, const char *
static void test_json_escape_empty_string(void **state)
{
apr_pool_t *pool;
testbuf_t buf;
(void)state;
apr_pool_create(&pool, NULL);
apr_strbuf_t buf;
char *initial = apr_palloc(pool, 256);
apr_strbuf_init(pool, &buf, initial, 256);
append_json_string(pool, &buf, "");
assert_string_equal(buf.buf, "");
testbuf_init(&buf, pool, 256);
append_json_string(&buf, "");
assert_string_equal(buf.data, "");
apr_pool_destroy(pool);
}
@ -63,16 +117,16 @@ static void test_json_escape_empty_string(void **state)
static void test_json_escape_simple_string(void **state)
{
apr_pool_t *pool;
testbuf_t buf;
(void)state;
apr_pool_create(&pool, NULL);
apr_strbuf_t buf;
char *initial = apr_palloc(pool, 256);
apr_strbuf_init(pool, &buf, initial, 256);
append_json_string(pool, &buf, "hello world");
assert_string_equal(buf.buf, "hello world");
testbuf_init(&buf, pool, 256);
append_json_string(&buf, "hello world");
assert_string_equal(buf.data, "hello world");
apr_pool_destroy(pool);
}
@ -80,16 +134,16 @@ static void test_json_escape_simple_string(void **state)
static void test_json_escape_quotes(void **state)
{
apr_pool_t *pool;
testbuf_t buf;
(void)state;
apr_pool_create(&pool, NULL);
apr_strbuf_t buf;
char *initial = apr_palloc(pool, 256);
apr_strbuf_init(pool, &buf, initial, 256);
append_json_string(pool, &buf, "hello \"world\"");
assert_string_equal(buf.buf, "hello \\\"world\\\"");
testbuf_init(&buf, pool, 256);
append_json_string(&buf, "hello \"world\"");
assert_string_equal(buf.data, "hello \\\"world\\\"");
apr_pool_destroy(pool);
}
@ -97,16 +151,16 @@ static void test_json_escape_quotes(void **state)
static void test_json_escape_backslashes(void **state)
{
apr_pool_t *pool;
testbuf_t buf;
(void)state;
apr_pool_create(&pool, NULL);
apr_strbuf_t buf;
char *initial = apr_palloc(pool, 256);
apr_strbuf_init(pool, &buf, initial, 256);
append_json_string(pool, &buf, "path\\to\\file");
assert_string_equal(buf.buf, "path\\\\to\\\\file");
testbuf_init(&buf, pool, 256);
append_json_string(&buf, "path\\to\\file");
assert_string_equal(buf.data, "path\\\\to\\\\file");
apr_pool_destroy(pool);
}
@ -114,16 +168,16 @@ static void test_json_escape_backslashes(void **state)
static void test_json_escape_newlines_tabs(void **state)
{
apr_pool_t *pool;
testbuf_t buf;
(void)state;
apr_pool_create(&pool, NULL);
apr_strbuf_t buf;
char *initial = apr_palloc(pool, 256);
apr_strbuf_init(pool, &buf, initial, 256);
append_json_string(pool, &buf, "line1\nline2\ttab");
assert_string_equal(buf.buf, "line1\\nline2\\ttab");
testbuf_init(&buf, pool, 256);
append_json_string(&buf, "line1\nline2\ttab");
assert_string_equal(buf.data, "line1\\nline2\\ttab");
apr_pool_destroy(pool);
}
@ -131,18 +185,18 @@ static void test_json_escape_newlines_tabs(void **state)
static void test_json_escape_control_chars(void **state)
{
apr_pool_t *pool;
testbuf_t buf;
(void)state;
apr_pool_create(&pool, NULL);
apr_strbuf_t buf;
char *initial = apr_palloc(pool, 256);
apr_strbuf_init(pool, &buf, initial, 256);
testbuf_init(&buf, pool, 256);
/* Test with bell character (0x07) */
append_json_string(pool, &buf, "test\bell");
append_json_string(&buf, "test\bell");
/* Should contain unicode escape */
assert_true(strstr(buf.buf, "\\u0007") != NULL);
assert_true(strstr(buf.data, "\\u0007") != NULL);
apr_pool_destroy(pool);
}
@ -150,16 +204,16 @@ static void test_json_escape_control_chars(void **state)
static void test_json_escape_null_string(void **state)
{
apr_pool_t *pool;
testbuf_t buf;
(void)state;
apr_pool_create(&pool, NULL);
apr_strbuf_t buf;
char *initial = apr_palloc(pool, 256);
apr_strbuf_init(pool, &buf, initial, 256);
append_json_string(pool, &buf, NULL);
assert_string_equal(buf.buf, "");
testbuf_init(&buf, pool, 256);
append_json_string(&buf, NULL);
assert_string_equal(buf.data, "");
apr_pool_destroy(pool);
}
@ -167,17 +221,17 @@ static void test_json_escape_null_string(void **state)
static void test_json_escape_user_agent(void **state)
{
apr_pool_t *pool;
apr_pool_create(&pool, NULL);
apr_strbuf_t buf;
char *initial = apr_palloc(pool, 512);
apr_strbuf_init(pool, &buf, initial, 512);
testbuf_t buf;
const char *ua = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) \"Test\"";
append_json_string(pool, &buf, ua);
assert_true(strstr(buf.buf, "\\\"Test\\\"") != NULL);
(void)state;
apr_pool_create(&pool, NULL);
testbuf_init(&buf, pool, 512);
append_json_string(&buf, ua);
assert_true(strstr(buf.data, "\\\"Test\\\"") != NULL);
apr_pool_destroy(pool);
}
@ -193,6 +247,6 @@ int main(void)
cmocka_unit_test(test_json_escape_null_string),
cmocka_unit_test(test_json_escape_user_agent),
};
return cmocka_run_group_tests(tests, NULL, NULL);
}