"""Unit tests for src/my_deepagent/middleware/audit.py."""

from __future__ import annotations

from typing import Any
from unittest.mock import AsyncMock, MagicMock
from uuid import UUID

import pytest

from my_deepagent.middleware.audit import AuditToolMiddleware

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _make_request(name: str = "read_file", args: dict[str, Any] | None = None) -> MagicMock:
    """Build a stand-in tool-call request exposing a ``tool_call`` dict.

    Args:
        name: Tool name stored under ``tool_call["name"]``.
        args: Tool arguments stored under ``tool_call["args"]``. ``None``
            selects the default ``{"path": "x.py"}``; any other value,
            including an empty dict, is used exactly as given.

    Returns:
        A ``MagicMock`` whose ``tool_call`` attribute is a real dict, so the
        middleware reads concrete values instead of auto-created mocks.
    """
    request = MagicMock()
    # `is None` rather than `or`: an explicit empty dict must stay empty so
    # that zero-argument tool calls can be exercised.
    request.tool_call = {
        "name": name,
        "args": {"path": "x.py"} if args is None else args,
    }
    return request


# ---------------------------------------------------------------------------
# awrap_tool_call — success path
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_audit_middleware_records_correct_fields_on_success() -> None:
    """A successful call returns the handler's result and records one audit entry."""
    run_id = UUID("00000000-0000-0000-0000-000000000001")
    recorder = AsyncMock()
    mw = AuditToolMiddleware(
        run_id=run_id,
        phase_id=UUID("00000000-0000-0000-0000-000000000002"),
        interactive_session_id=UUID("00000000-0000-0000-0000-000000000003"),
        recorder=recorder,
    )
    result_value = "file contents here"
    handler = AsyncMock(return_value=result_value)
    request = _make_request(name="read_file", args={"path": "src/main.py"})

    result = await mw.awrap_tool_call(request, handler)

    assert result == result_value
    # The middleware must delegate to the wrapped handler exactly once.
    handler.assert_awaited_once()
    recorder.assert_awaited_once()
    # The audit record is passed to the recorder as its first positional argument.
    record: dict[str, Any] = recorder.call_args.args[0]
    assert record["tool_name"] == "read_file"
    assert record["args"] == {"path": "src/main.py"}
    assert record["result"] == result_value
    assert record["error"] is None
    assert record["duration_ms"] >= 0
    assert record["run_id"] == run_id


@pytest.mark.asyncio
async def test_audit_middleware_no_recorder_is_noop() -> None:
    """Without a recorder the middleware still returns the handler's result."""
    mw = AuditToolMiddleware()
    handler = AsyncMock(return_value="ok")

    result = await mw.awrap_tool_call(_make_request(), handler)

    assert result == "ok"
    handler.assert_awaited_once()
# ---------------------------------------------------------------------------
# awrap_tool_call — error path
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_audit_middleware_records_error_code_on_exception() -> None:
    """A failing handler is still audited: the record holds the exception class name and no result."""
    recorder = AsyncMock()
    mw = AuditToolMiddleware(recorder=recorder)
    handler = AsyncMock(side_effect=PermissionError("access denied"))

    with pytest.raises(PermissionError):
        await mw.awrap_tool_call(_make_request(), handler)

    recorder.assert_awaited_once()
    record: dict[str, Any] = recorder.call_args.args[0]
    assert record["error"] == "PermissionError"
    assert record["result"] is None


@pytest.mark.asyncio
async def test_audit_middleware_reraises_exception() -> None:
    """The handler's exception propagates to the caller unchanged (same object, not a wrapper)."""
    mw = AuditToolMiddleware(recorder=AsyncMock())
    original_error = ValueError("bad args")
    handler = AsyncMock(side_effect=original_error)

    with pytest.raises(ValueError, match="bad args") as exc_info:
        await mw.awrap_tool_call(_make_request(), handler)

    # Identity check: a middleware that raised a fresh ValueError with the same
    # message would pass the type/message check above but not this one.
    assert exc_info.value is original_error


# ---------------------------------------------------------------------------
# result serialization
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_audit_middleware_serializes_non_primitive_result_as_str() -> None:
    """A result of an arbitrary type is recorded via its ``str()`` form."""
    recorder = AsyncMock()
    mw = AuditToolMiddleware(recorder=recorder)

    class _CustomResult:
        def __str__(self) -> str:
            return "custom-result-str"

    handler = AsyncMock(return_value=_CustomResult())

    await mw.awrap_tool_call(_make_request(), handler)

    # Assert the call first so a missing record fails with a clear message
    # instead of a TypeError from indexing `call_args` when it is None.
    recorder.assert_awaited_once()
    record: dict[str, Any] = recorder.call_args.args[0]
    assert record["result"] == "custom-result-str"


@pytest.mark.asyncio
async def test_audit_middleware_passes_dict_result_as_is() -> None:
    """A dict result is recorded without conversion to a string."""
    recorder = AsyncMock()
    mw = AuditToolMiddleware(recorder=recorder)
    handler = AsyncMock(return_value={"key": "value"})

    await mw.awrap_tool_call(_make_request(), handler)

    recorder.assert_awaited_once()
    record: dict[str, Any] = recorder.call_args.args[0]
    assert record["result"] == {"key": "value"}