"""Unit tests for src/my_deepagent/audit.py."""

from __future__ import annotations

import json
import os
import sys
from datetime import datetime
from pathlib import Path
from typing import Any

import pytest

from my_deepagent.audit import (
    append_audit_record,
    audit_path,
    make_audit_recorder,
    read_audit_records,
)


# ---------------------------------------------------------------------------
# audit_path
# ---------------------------------------------------------------------------


def test_audit_path_returns_correct_location(tmp_path: Path) -> None:
    """audit_path() resolves to ``<directory>/audit.jsonl``."""
    expected = tmp_path / "audit.jsonl"
    assert audit_path(tmp_path) == expected


# ---------------------------------------------------------------------------
# append_audit_record
# ---------------------------------------------------------------------------


def test_append_audit_record_creates_file_with_one_line(tmp_path: Path) -> None:
    """A first append creates the log file containing exactly one JSON line."""
    record: dict[str, Any] = {"tool_name": "read_file", "args": {"path": "x.py"}}
    append_audit_record(tmp_path, record)

    target = audit_path(tmp_path)
    assert target.is_file()

    # Ignore blank lines so a trailing newline is not counted as a record.
    lines = [ln for ln in target.read_text(encoding="utf-8").splitlines() if ln.strip()]
    assert len(lines) == 1

    parsed = json.loads(lines[0])
    assert parsed["tool_name"] == "read_file"
    assert "ts" in parsed


def test_append_audit_record_accumulates_multiple_records(tmp_path: Path) -> None:
    """Successive appends are all retained, in insertion order."""
    for i in range(5):
        append_audit_record(tmp_path, {"seq": i})

    records = read_audit_records(tmp_path)
    assert len(records) == 5
    assert [r["seq"] for r in records] == list(range(5))


@pytest.mark.skipif(
    sys.platform == "win32",
    # Windows only reports read-only vs read-write in st_mode, so an
    # owner-only 0o600 mode can never be observed there.
    reason="POSIX permission bits are not supported on Windows",
)
def test_append_audit_record_file_permission_is_0600(tmp_path: Path) -> None:
    """The audit log is created readable and writable by its owner only."""
    append_audit_record(tmp_path, {"tool_name": "test"})
    target = audit_path(tmp_path)
    mode = os.stat(target).st_mode & 0o777
    assert mode == 0o600


def test_append_audit_record_adds_ts_field(tmp_path: Path) -> None:
    """Each appended record is stamped with a non-empty ISO-8601 ``ts`` string."""
    append_audit_record(tmp_path, {"tool_name": "execute"})
    records = read_audit_records(tmp_path)
    assert len(records) == 1
    assert "ts" in records[0]

    ts = records[0]["ts"]
    assert isinstance(ts, str)
    assert len(ts) > 0
    # datetime.fromisoformat() rejects a trailing "Z" before Python 3.11, so
    # normalise it to an explicit UTC offset; a ValueError fails the test.
    normalised = ts[:-1] + "+00:00" if ts.endswith("Z") else ts
    datetime.fromisoformat(normalised)


# ---------------------------------------------------------------------------
# read_audit_records
# ---------------------------------------------------------------------------


def test_read_audit_records_returns_empty_when_file_missing(tmp_path: Path) -> None:
    """A directory with no audit log yields an empty list rather than an error."""
    result = read_audit_records(tmp_path)
    assert result == []


def test_read_audit_records_returns_empty_for_empty_file(tmp_path: Path) -> None:
    """An existing but empty audit log yields an empty list."""
    target = audit_path(tmp_path)
    target.write_text("", encoding="utf-8")
    result = read_audit_records(tmp_path)
    assert result == []


def test_read_audit_records_with_limit_returns_last_n(tmp_path: Path) -> None:
    """``limit=n`` returns only the newest ``n`` records, oldest first."""
    for i in range(10):
        append_audit_record(tmp_path, {"seq": i})

    result = read_audit_records(tmp_path, limit=3)
    assert len(result) == 3
    # Should be the last 3 records (seq 7, 8, 9), in their original order.
    assert [r["seq"] for r in result] == [7, 8, 9]


def test_read_audit_records_skips_corrupted_lines(tmp_path: Path) -> None:
    """Lines that are not valid JSON are skipped; valid neighbours survive."""
    target = audit_path(tmp_path)
    # Write one valid + one corrupt + one valid line.
    valid1 = json.dumps({"tool_name": "first"}) + "\n"
    corrupt = "NOT_VALID_JSON{\n"
    valid2 = json.dumps({"tool_name": "third"}) + "\n"
    target.write_text(valid1 + corrupt + valid2, encoding="utf-8")

    records = read_audit_records(tmp_path)
    assert len(records) == 2
    assert records[0]["tool_name"] == "first"
    assert records[1]["tool_name"] == "third"


# ---------------------------------------------------------------------------
# make_audit_recorder
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_make_audit_recorder_writes_record(tmp_path: Path) -> None:
    """The awaitable recorder persists the record it is called with."""
    recorder = make_audit_recorder(tmp_path)
    await recorder({"tool_name": "write_file", "args": {"path": "out.txt"}})

    records = read_audit_records(tmp_path)
    assert len(records) == 1
    assert records[0]["tool_name"] == "write_file"