"""Unit tests for src/my_deepagent/middleware/safety.py.""" from __future__ import annotations from typing import Any from unittest.mock import AsyncMock, MagicMock import pytest from my_deepagent.errors import MyDeepAgentError from my_deepagent.middleware.safety import SafetyShellMiddleware, _is_denied_path # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def _make_shell_request(cmd: str | list[str], tool_name: str = "shell") -> MagicMock: request = MagicMock() if isinstance(cmd, list): request.tool_call = {"name": tool_name, "args": {"argv": cmd}} else: request.tool_call = {"name": tool_name, "args": {"command": cmd}} return request def _make_other_tool_request( name: str = "read_file", args: dict[str, Any] | None = None ) -> MagicMock: request = MagicMock() request.tool_call = {"name": name, "args": args or {}} return request # --------------------------------------------------------------------------- # Destructive commands — should raise # --------------------------------------------------------------------------- @pytest.mark.asyncio async def test_rm_rf_slash_is_blocked() -> None: mw = SafetyShellMiddleware() with pytest.raises(MyDeepAgentError) as exc_info: await mw.awrap_tool_call(_make_shell_request("rm -rf /"), AsyncMock()) assert exc_info.value.code == "destructive_command_blocked" @pytest.mark.asyncio async def test_rm_rf_with_path_is_blocked() -> None: mw = SafetyShellMiddleware() with pytest.raises(MyDeepAgentError) as exc_info: await mw.awrap_tool_call(_make_shell_request("rm -rf ./build"), AsyncMock()) assert exc_info.value.code == "destructive_command_blocked" @pytest.mark.asyncio async def test_git_push_force_is_blocked() -> None: mw = SafetyShellMiddleware() with pytest.raises(MyDeepAgentError): await mw.awrap_tool_call(_make_shell_request("git push --force origin main"), AsyncMock()) @pytest.mark.asyncio async def test_git_push_force_with_lease_is_blocked() -> None: mw = SafetyShellMiddleware() with pytest.raises(MyDeepAgentError): await mw.awrap_tool_call( _make_shell_request("git push --force-with-lease origin main"), AsyncMock() ) @pytest.mark.asyncio async def test_git_reset_hard_is_blocked() -> None: mw = SafetyShellMiddleware() with pytest.raises(MyDeepAgentError): await mw.awrap_tool_call(_make_shell_request("git reset --hard HEAD"), AsyncMock()) @pytest.mark.asyncio async def test_git_clean_is_blocked() -> None: mw = SafetyShellMiddleware() with pytest.raises(MyDeepAgentError): await mw.awrap_tool_call(_make_shell_request("git clean -fd"), AsyncMock()) @pytest.mark.asyncio async def test_drop_table_sql_is_blocked() -> None: mw = SafetyShellMiddleware() with pytest.raises(MyDeepAgentError): await mw.awrap_tool_call(_make_shell_request("psql -c 'DROP TABLE users'"), AsyncMock()) @pytest.mark.asyncio async def test_execute_tool_name_also_blocked() -> None: """The 'execute' tool name is also checked for destructive patterns.""" mw = SafetyShellMiddleware() with pytest.raises(MyDeepAgentError): await mw.awrap_tool_call( _make_shell_request("rm -rf /tmp/data", tool_name="execute"), AsyncMock() ) # --------------------------------------------------------------------------- # argv (list) form — should also be blocked # --------------------------------------------------------------------------- @pytest.mark.asyncio async def test_rm_rf_as_list_argv_is_blocked() -> None: mw = SafetyShellMiddleware() with pytest.raises(MyDeepAgentError): await mw.awrap_tool_call( _make_shell_request(["rm", "-rf", "/tmp"], tool_name="shell"), AsyncMock() ) # --------------------------------------------------------------------------- # Safe commands — should pass through # --------------------------------------------------------------------------- @pytest.mark.asyncio async def test_ls_la_passes_through() -> None: mw = SafetyShellMiddleware() handler = AsyncMock(return_value="total 42") result = await mw.awrap_tool_call(_make_shell_request("ls -la"), handler) assert result == "total 42" handler.assert_awaited_once() @pytest.mark.asyncio async def test_git_status_passes_through() -> None: mw = SafetyShellMiddleware() handler = AsyncMock(return_value="On branch main") result = await mw.awrap_tool_call(_make_shell_request("git status"), handler) assert result == "On branch main" @pytest.mark.asyncio async def test_git_push_without_force_passes_through() -> None: mw = SafetyShellMiddleware() handler = AsyncMock(return_value="ok") result = await mw.awrap_tool_call(_make_shell_request("git push origin main"), handler) assert result == "ok" # --------------------------------------------------------------------------- # Non-shell tools — should NOT be inspected # --------------------------------------------------------------------------- @pytest.mark.asyncio async def test_read_file_tool_with_destructive_content_passes() -> None: """read_file is not a shell tool; its content should not be blocked.""" mw = SafetyShellMiddleware() handler = AsyncMock(return_value="file content") request = _make_other_tool_request("read_file", {"path": "/some/file.py"}) result = await mw.awrap_tool_call(request, handler) assert result == "file content" @pytest.mark.asyncio async def test_unknown_tool_not_checked() -> None: mw = SafetyShellMiddleware() handler = AsyncMock(return_value="ok") result = await mw.awrap_tool_call(_make_other_tool_request("arbitrary_tool"), handler) assert result == "ok" # --------------------------------------------------------------------------- # _is_denied_path unit tests # --------------------------------------------------------------------------- def test_is_denied_path_env_file() -> None: assert _is_denied_path(".env") is True def test_is_denied_path_env_local_in_subdir() -> None: assert _is_denied_path("config/.env.local") is True def test_is_denied_path_ssh_key() -> None: assert _is_denied_path(".ssh/id_rsa") is True def test_is_denied_path_safe_source_file() -> None: assert _is_denied_path("src/main.py") is False def test_is_denied_path_token_file() -> None: assert _is_denied_path("api_token.json") is True def test_is_denied_path_aws_credentials() -> None: assert _is_denied_path(".aws/credentials") is True def test_is_denied_path_pem_file() -> None: assert _is_denied_path("key.pem") is True def test_is_denied_path_absolute_env() -> None: # absolute path normalised by lstrip('/') assert _is_denied_path("/.env") is True # --------------------------------------------------------------------------- # Secret-path tool blocking via awrap_tool_call # --------------------------------------------------------------------------- @pytest.mark.asyncio async def test_read_file_env_path_is_blocked() -> None: mw = SafetyShellMiddleware() request = _make_other_tool_request("read_file", {"file_path": ".env"}) with pytest.raises(MyDeepAgentError) as exc_info: await mw.awrap_tool_call(request, AsyncMock()) assert exc_info.value.code == "secret_access_blocked" @pytest.mark.asyncio async def test_write_file_pem_path_is_blocked() -> None: mw = SafetyShellMiddleware() request = _make_other_tool_request("write_file", {"file_path": "key.pem"}) with pytest.raises(MyDeepAgentError) as exc_info: await mw.awrap_tool_call(request, AsyncMock()) assert exc_info.value.code == "secret_access_blocked" @pytest.mark.asyncio async def test_ls_ssh_dir_is_blocked() -> None: mw = SafetyShellMiddleware() request = _make_other_tool_request("ls", {"path": ".ssh/"}) with pytest.raises(MyDeepAgentError) as exc_info: await mw.awrap_tool_call(request, AsyncMock()) assert exc_info.value.code == "secret_access_blocked" @pytest.mark.asyncio async def test_read_file_safe_path_passes() -> None: mw = SafetyShellMiddleware() handler = AsyncMock(return_value="content") request = _make_other_tool_request("read_file", {"file_path": "src/foo.py"}) result = await mw.awrap_tool_call(request, handler) assert result == "content" handler.assert_awaited_once() @pytest.mark.asyncio async def test_execute_tool_path_arg_not_path_checked() -> None: """execute tool goes through shell-check only, not path-check.""" mw = SafetyShellMiddleware() handler = AsyncMock(return_value="ok") # safe shell command with a path arg — should not be blocked via path logic request = _make_shell_request("ls /some/safe/dir", tool_name="execute") result = await mw.awrap_tool_call(request, handler) assert result == "ok"