feat(my-deepagent): v0.1.0 Step 0~5 — scaffolding through deepagent + OpenRouter
Python rewrite of the agent harness on top of deepagents 0.6.1 + langchain 1.x, replacing the abandoned TS attempt in packages/. 388 unit/integration tests pass. Steps ----- 0. Scaffolding — uv workspace, ruff/mypy/pre-commit/alembic, src/tests/docs trees with docs/schemas/ seeded from my-deepagent-seed/. 1. Core — config (pydantic-settings with MYDEEPAGENT_ env prefix and TOML source), enums (Backend, Capability, RiskLevel, ApprovalDecisionAction, ApprovalState, RunState, RunPhaseState, SessionState, ErrorClass), errors (MyDeepAgentError + BudgetExhaustedError with PEP-3134 cause + context suppression), hash (canonical JSON + sha256). 2. Persona/Workflow/Binding — pydantic v2 schemas with tuple-based deep immutability (post-construction hash drift prevented), YAML loaders, deterministic auto-select (preferred_backends → version → name → hash), override resolution with ineligibility diagnostics, PersonaConsentStore with fcntl.flock + tmp+fsync+rename atomic write. 3. Artifact schema registry — Draft202012Validator, multi-root resolution, structured ValidationFinding output. 4. Persistence — 18 SQLAlchemy 2.0 async ORM models with FK CASCADE/RESTRICT, WAL + busy_timeout + foreign_keys PRAGMA, alembic baseline + ux_active_run_repo_base partial unique index, LangGraph SqliteSaver as context manager only (lifecycle safety). 5. DeepAgent session — build_agent wires Persona → create_deep_agent with LocalShellBackend / FilesystemBackend / StateBackend / CompositeBackend, ChatOpenAI(base_url=openrouter) for openrouter: model strings, and 4 middleware classes (cost / audit-tool / safety-shell / fallback-model). Critical workarounds -------------------- - deepagents 0.6.1 rejects FilesystemPermission together with backends that implement SandboxBackendProtocol (LocalShellBackend). SafetyShellMiddleware enforces destructive-command and secret-path policy at the tool layer instead, and build_agent strips the permissions kwarg when the persona's deepagents_backend is local_shell. - FilesystemOperation in deepagents is Literal['read', 'write'] only; _map_operations collapses our richer schema (read/write/edit/ls) safely. Real OpenRouter smoke --------------------- test_openrouter_deepagents_local_shell_smoke calls DeepSeek via deepagents + LocalShellBackend + SafetyShellMiddleware end-to-end. PASS, ~$0.000001 cost, input=9 / output=1 tokens with content "OK". Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
0
my-deepagent/tests/unit/__init__.py
Normal file
0
my-deepagent/tests/unit/__init__.py
Normal file
391
my-deepagent/tests/unit/test_artifact_schema.py
Normal file
391
my-deepagent/tests/unit/test_artifact_schema.py
Normal file
@@ -0,0 +1,391 @@
|
||||
"""Unit tests for src/my_deepagent/artifact_schema.py."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from my_deepagent.artifact_schema import (
|
||||
ArtifactSchemaRegistry,
|
||||
ValidationFinding,
|
||||
ValidationResult,
|
||||
)
|
||||
from my_deepagent.errors import MyDeepAgentError
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
REPO_ROOT = Path(__file__).parent.parent.parent
|
||||
SEED_ROOT = REPO_ROOT / "docs" / "schemas" / "artifacts"
|
||||
|
||||
SEED_SCHEMA_IDS = [
|
||||
"common/final-report@1",
|
||||
"dev/phase-plan@1",
|
||||
"dev/review-finding-batch@1",
|
||||
"dev/spec@1",
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def seed_registry() -> ArtifactSchemaRegistry:
|
||||
return ArtifactSchemaRegistry(roots=[SEED_ROOT])
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def valid_spec() -> dict[str, Any]:
|
||||
return {
|
||||
"runId": "00000000-0000-4000-8000-000000000000",
|
||||
"phaseKey": "spec",
|
||||
"requirements": "User wants a CLI tool that analyzes log files.",
|
||||
"acceptance_criteria": ["parses .log files", "outputs JSON summary"],
|
||||
"approach": "Build a typer-based CLI using regex and json output.",
|
||||
"risks": ["log format variations may break parser"],
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 1. Seed schema load success (4 schemas)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.parametrize("schema_id", SEED_SCHEMA_IDS)
|
||||
def test_seed_schema_loads(seed_registry: ArtifactSchemaRegistry, schema_id: str) -> None:
|
||||
schema = seed_registry.load(schema_id)
|
||||
assert isinstance(schema, dict)
|
||||
assert schema.get("$id") == schema_id
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 2. Load result caching — same dict object on second call
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_load_caches_same_object(seed_registry: ArtifactSchemaRegistry) -> None:
|
||||
first = seed_registry.load("dev/spec@1")
|
||||
second = seed_registry.load("dev/spec@1")
|
||||
assert first is second
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 3. Unknown schema_id → artifact_schema_unknown
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_unknown_schema_id_raises(seed_registry: ArtifactSchemaRegistry) -> None:
|
||||
with pytest.raises(MyDeepAgentError) as exc_info:
|
||||
seed_registry.load("dev/nonexistent@99")
|
||||
assert exc_info.value.code == "artifact_schema_unknown"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 4. Invalid schema_id format (no slash) → artifact_schema_unknown
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_invalid_schema_id_no_slash(seed_registry: ArtifactSchemaRegistry) -> None:
|
||||
with pytest.raises(MyDeepAgentError) as exc_info:
|
||||
seed_registry.load("foo")
|
||||
assert exc_info.value.code == "artifact_schema_unknown"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 5. schema_id starting with "/" → rejected (no slash separating domain/name)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_invalid_schema_id_leading_slash(seed_registry: ArtifactSchemaRegistry) -> None:
|
||||
# "/foo/bar" has a slash but the domain portion would be empty
|
||||
# After splitting on "/", domain="" which is not a valid domain/name pair.
|
||||
# The registry treats it as a path traversal risk: Path("/foo/bar.json")
|
||||
# is absolute and will never exist under a root directory (is_file() → False).
|
||||
with pytest.raises(MyDeepAgentError) as exc_info:
|
||||
seed_registry.load("/dev/spec@1")
|
||||
assert exc_info.value.code == "artifact_schema_unknown"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 6. Empty schema_id → artifact_schema_unknown
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_empty_schema_id_raises(seed_registry: ArtifactSchemaRegistry) -> None:
|
||||
with pytest.raises(MyDeepAgentError) as exc_info:
|
||||
seed_registry.load("")
|
||||
assert exc_info.value.code == "artifact_schema_unknown"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 7. Fallback: schema absent in first root, present in second
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_fallback_to_second_root(tmp_path: Path) -> None:
|
||||
first_root = tmp_path / "first"
|
||||
first_root.mkdir()
|
||||
second_root = tmp_path / "second"
|
||||
(second_root / "dev").mkdir(parents=True)
|
||||
schema: dict[str, Any] = {
|
||||
"$schema": "https://json-schema.org/draft/2020-12/schema",
|
||||
"$id": "dev/thing@1",
|
||||
"type": "object",
|
||||
}
|
||||
(second_root / "dev" / "thing@1.json").write_text(json.dumps(schema), encoding="utf-8")
|
||||
registry = ArtifactSchemaRegistry(roots=[first_root, second_root])
|
||||
loaded = registry.load("dev/thing@1")
|
||||
assert loaded["$id"] == "dev/thing@1"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 8. validate with valid data → ok=True
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_validate_valid_spec(
|
||||
seed_registry: ArtifactSchemaRegistry, valid_spec: dict[str, Any]
|
||||
) -> None:
|
||||
result = seed_registry.validate("dev/spec@1", valid_spec)
|
||||
assert result.ok is True
|
||||
assert result.errors == ()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 9. validate with invalid data → ok=False, findings non-empty
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_validate_invalid_data_returns_findings(
|
||||
seed_registry: ArtifactSchemaRegistry,
|
||||
) -> None:
|
||||
result = seed_registry.validate("dev/spec@1", {"wrong": "data"})
|
||||
assert result.ok is False
|
||||
assert len(result.errors) > 0
|
||||
for finding in result.errors:
|
||||
assert isinstance(finding, ValidationFinding)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 10. Missing required field → validator="required", path correct
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_validate_missing_required_field(
|
||||
seed_registry: ArtifactSchemaRegistry, valid_spec: dict[str, Any]
|
||||
) -> None:
|
||||
data = {k: v for k, v in valid_spec.items() if k != "requirements"}
|
||||
result = seed_registry.validate("dev/spec@1", data)
|
||||
assert result.ok is False
|
||||
required_findings = [f for f in result.errors if f.validator == "required"]
|
||||
assert any("requirements" in f.message for f in required_findings)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 11. Invalid enum value → validator="enum", expected has enum list
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_validate_invalid_enum_severity(seed_registry: ArtifactSchemaRegistry) -> None:
|
||||
data = {
|
||||
"runId": "00000000-0000-4000-8000-000000000000",
|
||||
"phaseKey": "review",
|
||||
"reviewerRole": "code-reviewer",
|
||||
"findings": [
|
||||
{
|
||||
"severity": "bogus",
|
||||
"category": "correctness",
|
||||
"summary": "something is wrong here",
|
||||
}
|
||||
],
|
||||
"summary": "Overall review summary with enough length.",
|
||||
}
|
||||
result = seed_registry.validate("dev/review-finding-batch@1", data)
|
||||
assert result.ok is False
|
||||
enum_findings = [f for f in result.errors if f.validator == "enum"]
|
||||
assert len(enum_findings) > 0
|
||||
finding = enum_findings[0]
|
||||
assert isinstance(finding.expected, list)
|
||||
assert "bogus" not in finding.expected
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 12. Wrong type → validator="type", expected has type name
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_validate_wrong_type(
|
||||
seed_registry: ArtifactSchemaRegistry, valid_spec: dict[str, Any]
|
||||
) -> None:
|
||||
data = dict(valid_spec)
|
||||
data["acceptance_criteria"] = "should be a list, not a string"
|
||||
result = seed_registry.validate("dev/spec@1", data)
|
||||
assert result.ok is False
|
||||
type_findings = [f for f in result.errors if f.validator == "type"]
|
||||
assert len(type_findings) > 0
|
||||
assert type_findings[0].expected == "array"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 13. Nested error path — /findings/0/severity format
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_validate_nested_error_path(seed_registry: ArtifactSchemaRegistry) -> None:
|
||||
data = {
|
||||
"runId": "00000000-0000-4000-8000-000000000000",
|
||||
"phaseKey": "review",
|
||||
"reviewerRole": "code-reviewer",
|
||||
"findings": [
|
||||
{
|
||||
"severity": "not-valid",
|
||||
"category": "correctness",
|
||||
"summary": "a finding summary",
|
||||
}
|
||||
],
|
||||
"summary": "Overall review summary with enough length.",
|
||||
}
|
||||
result = seed_registry.validate("dev/review-finding-batch@1", data)
|
||||
assert result.ok is False
|
||||
paths = [f.path for f in result.errors]
|
||||
assert any(p.startswith("/findings/0/") for p in paths)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 14. known_schema_ids() returns all 4 seed schemas, sorted
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_known_schema_ids_returns_seeds(seed_registry: ArtifactSchemaRegistry) -> None:
|
||||
ids = seed_registry.known_schema_ids()
|
||||
for expected in SEED_SCHEMA_IDS:
|
||||
assert expected in ids
|
||||
assert ids == sorted(ids)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 15. Empty roots list → config_invalid
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_empty_roots_raises() -> None:
|
||||
with pytest.raises(MyDeepAgentError) as exc_info:
|
||||
ArtifactSchemaRegistry(roots=[])
|
||||
assert exc_info.value.code == "config_invalid"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 16. Corrupted JSON file → artifact_schema_load_failed
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_corrupted_json_raises(tmp_path: Path) -> None:
|
||||
(tmp_path / "dev").mkdir()
|
||||
(tmp_path / "dev" / "broken@1.json").write_text("{", encoding="utf-8")
|
||||
registry = ArtifactSchemaRegistry(roots=[tmp_path])
|
||||
with pytest.raises(MyDeepAgentError) as exc_info:
|
||||
registry.load("dev/broken@1")
|
||||
assert exc_info.value.code == "artifact_schema_load_failed"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 17. Valid JSON but not a dict → artifact_schema_load_failed
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_non_dict_json_raises(tmp_path: Path) -> None:
|
||||
(tmp_path / "dev").mkdir()
|
||||
(tmp_path / "dev" / "array@1.json").write_text("[1, 2, 3]", encoding="utf-8")
|
||||
registry = ArtifactSchemaRegistry(roots=[tmp_path])
|
||||
with pytest.raises(MyDeepAgentError) as exc_info:
|
||||
registry.load("dev/array@1")
|
||||
assert exc_info.value.code == "artifact_schema_load_failed"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 18. Schema itself is invalid Draft 2020-12 → artifact_schema_load_failed
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_invalid_draft_schema_raises(tmp_path: Path) -> None:
|
||||
(tmp_path / "dev").mkdir()
|
||||
bad_schema = {"type": "not_a_type"}
|
||||
(tmp_path / "dev" / "bad@1.json").write_text(json.dumps(bad_schema), encoding="utf-8")
|
||||
registry = ArtifactSchemaRegistry(roots=[tmp_path])
|
||||
with pytest.raises(MyDeepAgentError) as exc_info:
|
||||
registry.load("dev/bad@1")
|
||||
assert exc_info.value.code == "artifact_schema_load_failed"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 19. Validator caching: _validator called twice returns same instance
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_validator_instance_cached(seed_registry: ArtifactSchemaRegistry) -> None:
|
||||
# Access internal cache to verify the same validator instance is reused.
|
||||
v1 = seed_registry._validator("dev/spec@1")
|
||||
v2 = seed_registry._validator("dev/spec@1")
|
||||
assert v1 is v2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 20. dev/spec@1 valid example produces ok=True (full fixture check)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_spec_valid_example_ok(seed_registry: ArtifactSchemaRegistry) -> None:
|
||||
valid_spec: dict[str, Any] = {
|
||||
"runId": "00000000-0000-4000-8000-000000000000",
|
||||
"phaseKey": "spec",
|
||||
"requirements": "User wants a CLI tool that analyzes log files.",
|
||||
"acceptance_criteria": ["parses .log files", "outputs JSON summary"],
|
||||
"approach": "Build a typer-based CLI using regex and json output.",
|
||||
"risks": ["log format variations may break parser"],
|
||||
}
|
||||
result = seed_registry.validate("dev/spec@1", valid_spec)
|
||||
assert result.ok is True
|
||||
assert result.errors == ()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Bonus: ValidationResult and ValidationFinding are frozen dataclasses
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_validation_result_frozen() -> None:
|
||||
result = ValidationResult(ok=True)
|
||||
with pytest.raises((AttributeError, TypeError)):
|
||||
result.ok = False # type: ignore[misc]
|
||||
|
||||
|
||||
def test_validation_finding_frozen() -> None:
|
||||
finding = ValidationFinding(path="/foo", message="err", validator="type", expected="string")
|
||||
with pytest.raises((AttributeError, TypeError)):
|
||||
finding.path = "/bar" # type: ignore[misc]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Bonus: known_schema_ids with nonexistent root dir is silently skipped
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_known_schema_ids_skips_nonexistent_root(tmp_path: Path) -> None:
|
||||
missing = tmp_path / "does_not_exist"
|
||||
registry = ArtifactSchemaRegistry(roots=[missing])
|
||||
assert registry.known_schema_ids() == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Bonus: validate with non-dict top-level data
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_validate_non_dict_data_returns_error(
|
||||
seed_registry: ArtifactSchemaRegistry,
|
||||
) -> None:
|
||||
result = seed_registry.validate("dev/spec@1", [1, 2, 3])
|
||||
assert result.ok is False
|
||||
type_findings = [f for f in result.errors if f.validator == "type"]
|
||||
assert len(type_findings) > 0
|
||||
644
my-deepagent/tests/unit/test_binding.py
Normal file
644
my-deepagent/tests/unit/test_binding.py
Normal file
@@ -0,0 +1,644 @@
|
||||
"""Unit tests for src/my_deepagent/binding.py."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import fcntl
|
||||
import json
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from my_deepagent.binding import (
|
||||
BackendAvailability,
|
||||
Binding,
|
||||
BindingOverride,
|
||||
PersonaConsentStore,
|
||||
bind_personas,
|
||||
filter_consented_personas,
|
||||
is_persona_eligible_for_role,
|
||||
)
|
||||
from my_deepagent.enums import Backend, Capability
|
||||
from my_deepagent.errors import MyDeepAgentError
|
||||
from my_deepagent.persona import Persona, load_personas_from_dir
|
||||
from my_deepagent.workflow import WorkflowTemplate, load_workflows_from_dir
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PersonaConsentStore file-lock (fcntl.flock) verification
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_consent_store_set_acquires_exclusive_lock(
|
||||
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""set() must take an exclusive flock and release it."""
|
||||
ops: list[int] = []
|
||||
orig_flock = fcntl.flock
|
||||
|
||||
def spy(fd: int, op: int) -> None:
|
||||
ops.append(op)
|
||||
orig_flock(fd, op)
|
||||
|
||||
monkeypatch.setattr(fcntl, "flock", spy)
|
||||
store = PersonaConsentStore(tmp_path / "consents.json")
|
||||
store.set("hash_abc", "approve")
|
||||
assert fcntl.LOCK_EX in ops
|
||||
assert fcntl.LOCK_UN in ops
|
||||
|
||||
|
||||
def test_consent_store_revoke_acquires_exclusive_lock(
|
||||
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
ops: list[int] = []
|
||||
orig_flock = fcntl.flock
|
||||
|
||||
def spy(fd: int, op: int) -> None:
|
||||
ops.append(op)
|
||||
orig_flock(fd, op)
|
||||
|
||||
monkeypatch.setattr(fcntl, "flock", spy)
|
||||
store = PersonaConsentStore(tmp_path / "consents.json")
|
||||
store.set("h", "approve")
|
||||
ops.clear()
|
||||
store.revoke("h")
|
||||
assert fcntl.LOCK_EX in ops
|
||||
assert fcntl.LOCK_UN in ops
|
||||
|
||||
|
||||
def test_consent_store_get_acquires_shared_lock(
|
||||
tmp_path: Path, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""get() takes a shared lock (LOCK_SH) so multiple readers don't serialise."""
|
||||
ops: list[int] = []
|
||||
orig_flock = fcntl.flock
|
||||
|
||||
def spy(fd: int, op: int) -> None:
|
||||
ops.append(op)
|
||||
orig_flock(fd, op)
|
||||
|
||||
monkeypatch.setattr(fcntl, "flock", spy)
|
||||
store = PersonaConsentStore(tmp_path / "consents.json")
|
||||
store.set("h", "approve")
|
||||
ops.clear()
|
||||
_ = store.get("h")
|
||||
assert fcntl.LOCK_SH in ops
|
||||
assert fcntl.LOCK_UN in ops
|
||||
|
||||
|
||||
def test_consent_store_lock_file_created(tmp_path: Path) -> None:
|
||||
"""A .lock sidecar file is created next to the consent store on first write."""
|
||||
path = tmp_path / "consents.json"
|
||||
store = PersonaConsentStore(path)
|
||||
store.set("h", "approve")
|
||||
assert (tmp_path / "consents.json.lock").is_file()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures / helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
PERSONAS_DIR = Path(__file__).parent.parent.parent / "docs" / "schemas" / "personas"
|
||||
WORKFLOWS_DIR = Path(__file__).parent.parent.parent / "docs" / "schemas" / "workflows"
|
||||
|
||||
|
||||
def _minimal_persona(**overrides: object) -> Persona:
|
||||
base: dict[str, object] = {
|
||||
"name": "test-persona",
|
||||
"version": 1,
|
||||
"backend": "openrouter",
|
||||
"model": "openrouter:anthropic/claude-sonnet-4-6",
|
||||
"provider_origin": "US/Anthropic",
|
||||
"capabilities": ["spec_write", "phase_planning"],
|
||||
"max_risk_level": "low",
|
||||
"system_prompt": "You are a test persona for unit tests.",
|
||||
}
|
||||
base.update(overrides)
|
||||
return Persona.model_validate(base)
|
||||
|
||||
|
||||
def _all_available() -> BackendAvailability:
|
||||
return BackendAvailability(available_backends=frozenset(Backend))
|
||||
|
||||
|
||||
def _none_available() -> BackendAvailability:
|
||||
return BackendAvailability(available_backends=frozenset())
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def consent_store(tmp_path: Path) -> PersonaConsentStore:
|
||||
return PersonaConsentStore(tmp_path / "consents.json")
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def seed_personas() -> list[Persona]:
|
||||
return load_personas_from_dir(PERSONAS_DIR)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def spec_and_review() -> WorkflowTemplate:
|
||||
workflows = load_workflows_from_dir(WORKFLOWS_DIR)
|
||||
return next(w for w in workflows if w.name == "spec-and-review")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# is_persona_eligible_for_role
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_eligible_all_ok(spec_and_review: WorkflowTemplate) -> None:
|
||||
spec_writer_role = next(r for r in spec_and_review.roles if r.id == "spec_writer")
|
||||
p = _minimal_persona(capabilities=["spec_write", "phase_planning"], max_risk_level="low")
|
||||
ok, reason = is_persona_eligible_for_role(p, spec_writer_role, spec_and_review)
|
||||
assert ok is True
|
||||
assert reason is None
|
||||
|
||||
|
||||
def test_eligible_missing_capability(spec_and_review: WorkflowTemplate) -> None:
|
||||
spec_writer_role = next(r for r in spec_and_review.roles if r.id == "spec_writer")
|
||||
# only spec_write, missing phase_planning
|
||||
p = _minimal_persona(capabilities=["spec_write"], max_risk_level="low")
|
||||
ok, reason = is_persona_eligible_for_role(p, spec_writer_role, spec_and_review)
|
||||
assert ok is False
|
||||
assert reason is not None
|
||||
assert "phase_planning" in reason
|
||||
|
||||
|
||||
def test_eligible_allowed_roles_mismatch(spec_and_review: WorkflowTemplate) -> None:
|
||||
spec_writer_role = next(r for r in spec_and_review.roles if r.id == "spec_writer")
|
||||
p = _minimal_persona(
|
||||
capabilities=["spec_write", "phase_planning"],
|
||||
max_risk_level="low",
|
||||
allowed_roles=["reviewer"], # does not include spec_writer
|
||||
)
|
||||
ok, reason = is_persona_eligible_for_role(p, spec_writer_role, spec_and_review)
|
||||
assert ok is False
|
||||
assert reason is not None
|
||||
assert "allowed_roles" in reason
|
||||
|
||||
|
||||
def test_eligible_allowed_roles_matches(spec_and_review: WorkflowTemplate) -> None:
|
||||
spec_writer_role = next(r for r in spec_and_review.roles if r.id == "spec_writer")
|
||||
p = _minimal_persona(
|
||||
capabilities=["spec_write", "phase_planning"],
|
||||
max_risk_level="low",
|
||||
allowed_roles=["spec_writer"],
|
||||
)
|
||||
ok, reason = is_persona_eligible_for_role(p, spec_writer_role, spec_and_review)
|
||||
assert ok is True
|
||||
assert reason is None
|
||||
|
||||
|
||||
def test_eligible_risk_too_high(spec_and_review: WorkflowTemplate) -> None:
|
||||
"""bug-fix workflow has a 'medium' risk phase; a low-only persona is ineligible for it."""
|
||||
bug_fix = load_workflows_from_dir(WORKFLOWS_DIR)
|
||||
bug_fix_wf = next(w for w in bug_fix if w.name == "bug-fix-with-reproduction")
|
||||
fixer_role = next(r for r in bug_fix_wf.roles if r.id == "fixer")
|
||||
# fixer role has a 'medium' risk phase
|
||||
p = _minimal_persona(
|
||||
capabilities=["code_edit", "test_first_development"],
|
||||
max_risk_level="low", # too low for medium phase
|
||||
)
|
||||
ok, reason = is_persona_eligible_for_role(p, fixer_role, bug_fix_wf)
|
||||
assert ok is False
|
||||
assert reason is not None
|
||||
assert "medium" in reason
|
||||
|
||||
|
||||
def test_eligible_risk_exact_match(spec_and_review: WorkflowTemplate) -> None:
|
||||
spec_writer_role = next(r for r in spec_and_review.roles if r.id == "spec_writer")
|
||||
p = _minimal_persona(capabilities=["spec_write", "phase_planning"], max_risk_level="low")
|
||||
ok, _ = is_persona_eligible_for_role(p, spec_writer_role, spec_and_review)
|
||||
assert ok is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# bind_personas: end-to-end with seed data
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_bind_personas_spec_and_review_success(
|
||||
seed_personas: list[Persona],
|
||||
spec_and_review: WorkflowTemplate,
|
||||
consent_store: PersonaConsentStore,
|
||||
) -> None:
|
||||
bindings = bind_personas(spec_and_review, seed_personas, _all_available(), consent_store)
|
||||
assert set(bindings.keys()) == {"spec_writer", "reviewer", "verifier"}
|
||||
for role_id, binding in bindings.items():
|
||||
assert isinstance(binding, Binding)
|
||||
assert binding.role_id == role_id
|
||||
assert re.fullmatch(r"[0-9a-f]{64}", binding.binding_hash)
|
||||
|
||||
|
||||
def test_bind_personas_binding_hash_deterministic(
|
||||
seed_personas: list[Persona],
|
||||
spec_and_review: WorkflowTemplate,
|
||||
consent_store: PersonaConsentStore,
|
||||
) -> None:
|
||||
b1 = bind_personas(spec_and_review, seed_personas, _all_available(), consent_store)
|
||||
b2 = bind_personas(spec_and_review, seed_personas, _all_available(), consent_store)
|
||||
for role_id in b1:
|
||||
assert b1[role_id].binding_hash == b2[role_id].binding_hash
|
||||
|
||||
|
||||
def test_bind_personas_spec_writer_is_spec_writer(
|
||||
seed_personas: list[Persona],
|
||||
spec_and_review: WorkflowTemplate,
|
||||
consent_store: PersonaConsentStore,
|
||||
) -> None:
|
||||
bindings = bind_personas(spec_and_review, seed_personas, _all_available(), consent_store)
|
||||
spec_persona = bindings["spec_writer"].persona
|
||||
assert Capability.SPEC_WRITE in spec_persona.capabilities
|
||||
assert Capability.PHASE_PLANNING in spec_persona.capabilities
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# bind_personas: override
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_bind_personas_override_picks_pinned(
|
||||
seed_personas: list[Persona],
|
||||
spec_and_review: WorkflowTemplate,
|
||||
consent_store: PersonaConsentStore,
|
||||
) -> None:
|
||||
override = BindingOverride.parse({"spec_writer": "openrouter-claude-spec-writer@1"})
|
||||
bindings = bind_personas(
|
||||
spec_and_review, seed_personas, _all_available(), consent_store, override
|
||||
)
|
||||
assert bindings["spec_writer"].persona.name == "openrouter-claude-spec-writer"
|
||||
|
||||
|
||||
def test_bind_personas_override_invalid_persona_raises(
|
||||
seed_personas: list[Persona],
|
||||
spec_and_review: WorkflowTemplate,
|
||||
consent_store: PersonaConsentStore,
|
||||
) -> None:
|
||||
override = BindingOverride.parse({"spec_writer": "nonexistent-persona@1"})
|
||||
with pytest.raises(MyDeepAgentError) as exc_info:
|
||||
bind_personas(spec_and_review, seed_personas, _all_available(), consent_store, override)
|
||||
assert exc_info.value.code == "no_eligible_persona"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# bind_personas: backend unavailable
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_bind_personas_backend_unavailable_raises(
|
||||
seed_personas: list[Persona],
|
||||
spec_and_review: WorkflowTemplate,
|
||||
consent_store: PersonaConsentStore,
|
||||
) -> None:
|
||||
with pytest.raises(MyDeepAgentError) as exc_info:
|
||||
bind_personas(spec_and_review, seed_personas, _none_available(), consent_store)
|
||||
assert exc_info.value.code == "backend_unavailable"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# bind_personas: model_unavailable for openrouter with empty model
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_bind_personas_model_unavailable_raises(
|
||||
spec_and_review: WorkflowTemplate,
|
||||
consent_store: PersonaConsentStore,
|
||||
) -> None:
|
||||
"""Verify FAKE backend binds successfully (positive path for non-openrouter backends).
|
||||
|
||||
We cannot construct an openrouter persona with empty model via model_validate because
|
||||
the validator rejects it. Instead verify the happy path: FAKE backend + non-empty
|
||||
model should bind without errors when the FAKE backend is available.
|
||||
"""
|
||||
from my_deepagent.workflow import WorkflowPhase, WorkflowRole
|
||||
|
||||
role = WorkflowRole.model_validate(
|
||||
{
|
||||
"id": "spec_writer",
|
||||
"required_capabilities": ["spec_write", "phase_planning"],
|
||||
"preferred_backends": ["fake"],
|
||||
}
|
||||
)
|
||||
phase = WorkflowPhase.model_validate(
|
||||
{
|
||||
"key": "spec",
|
||||
"title": "Write spec",
|
||||
"risk": "low",
|
||||
"role": "spec_writer",
|
||||
"instructions": "Write the specification document.",
|
||||
}
|
||||
)
|
||||
tmpl = WorkflowTemplate.model_validate(
|
||||
{
|
||||
"name": "fake-wf",
|
||||
"version": 1,
|
||||
"roles": [role.model_dump()],
|
||||
"phases": [phase.model_dump()],
|
||||
}
|
||||
)
|
||||
fake_persona = _minimal_persona(
|
||||
backend="fake",
|
||||
model="fake-model",
|
||||
capabilities=["spec_write", "phase_planning"],
|
||||
)
|
||||
fake_avail = BackendAvailability(available_backends=frozenset({Backend.FAKE}))
|
||||
# Should succeed with FAKE backend + non-empty model
|
||||
bindings = bind_personas(tmpl, [fake_persona], fake_avail, consent_store)
|
||||
assert "spec_writer" in bindings
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# bind_personas: no eligible persona
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_bind_personas_no_eligible_raises(
|
||||
spec_and_review: WorkflowTemplate,
|
||||
consent_store: PersonaConsentStore,
|
||||
) -> None:
|
||||
# Provide a persona with wrong capabilities
|
||||
bad_persona = _minimal_persona(capabilities=["backtest_run"])
|
||||
with pytest.raises(MyDeepAgentError) as exc_info:
|
||||
bind_personas(spec_and_review, [bad_persona], _all_available(), consent_store)
|
||||
assert exc_info.value.code == "no_eligible_persona"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PersonaConsentStore: get / set / revoke
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_consent_store_get_none_when_absent(consent_store: PersonaConsentStore) -> None:
|
||||
assert consent_store.get("abc123") is None
|
||||
|
||||
|
||||
def test_consent_store_set_and_get(consent_store: PersonaConsentStore) -> None:
|
||||
consent_store.set("abc123", "approve")
|
||||
assert consent_store.get("abc123") == "approve"
|
||||
|
||||
|
||||
def test_consent_store_block(consent_store: PersonaConsentStore) -> None:
|
||||
consent_store.set("abc123", "block")
|
||||
assert consent_store.get("abc123") == "block"
|
||||
|
||||
|
||||
def test_consent_store_once(consent_store: PersonaConsentStore) -> None:
|
||||
consent_store.set("abc123", "once")
|
||||
assert consent_store.get("abc123") == "once"
|
||||
|
||||
|
||||
def test_consent_store_revoke(consent_store: PersonaConsentStore) -> None:
|
||||
consent_store.set("abc123", "approve")
|
||||
consent_store.revoke("abc123")
|
||||
assert consent_store.get("abc123") is None
|
||||
|
||||
|
||||
def test_consent_store_revoke_absent_is_noop(consent_store: PersonaConsentStore) -> None:
|
||||
consent_store.revoke("not_present") # must not raise
|
||||
|
||||
|
||||
def test_consent_store_overwrite(consent_store: PersonaConsentStore) -> None:
|
||||
consent_store.set("abc123", "approve")
|
||||
consent_store.set("abc123", "block")
|
||||
assert consent_store.get("abc123") == "block"
|
||||
|
||||
|
||||
def test_consent_store_unknown_decision_returns_none(
|
||||
consent_store: PersonaConsentStore,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Corrupt decision value (not approve/block/once) returns None, not raise."""
|
||||
path = tmp_path / "consents.json"
|
||||
path.write_text(
|
||||
json.dumps({"abc123": {"decision": "foobar", "decided_at": "2026-01-01T00:00:00+00:00"}}),
|
||||
encoding="utf-8",
|
||||
)
|
||||
store = PersonaConsentStore(path)
|
||||
assert store.get("abc123") is None
|
||||
|
||||
|
||||
def test_consent_store_corrupted_json_raises_fatal(tmp_path: Path) -> None:
|
||||
path = tmp_path / "consents.json"
|
||||
path.write_text("{invalid json", encoding="utf-8")
|
||||
store = PersonaConsentStore(path)
|
||||
with pytest.raises(MyDeepAgentError) as exc_info:
|
||||
store.get("abc123")
|
||||
assert exc_info.value.code == "internal_state_corruption"
|
||||
|
||||
|
||||
def test_consent_store_atomic_write(consent_store: PersonaConsentStore) -> None:
|
||||
"""The .tmp file must not remain after a successful write."""
|
||||
consent_store.set("abc", "approve")
|
||||
tmp_file = consent_store._path.with_suffix(".json.tmp")
|
||||
assert not tmp_file.exists(), ".tmp leftover after successful write"
|
||||
|
||||
|
||||
def test_consent_store_json_format(consent_store: PersonaConsentStore) -> None:
|
||||
"""Stored JSON must be valid and contain decision + decided_at."""
|
||||
consent_store.set("myhash", "once")
|
||||
raw = consent_store._path.read_text(encoding="utf-8")
|
||||
data = json.loads(raw)
|
||||
assert "myhash" in data
|
||||
assert data["myhash"]["decision"] == "once"
|
||||
assert "decided_at" in data["myhash"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# filter_consented_personas
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_filter_removes_blocked(consent_store: PersonaConsentStore) -> None:
|
||||
p1 = _minimal_persona(name="p1")
|
||||
p2 = _minimal_persona(name="p2")
|
||||
consent_store.set(p2.compute_hash(), "block")
|
||||
result = filter_consented_personas([p1, p2], consent_store)
|
||||
assert len(result) == 1
|
||||
assert result[0].name == "p1"
|
||||
|
||||
|
||||
def test_filter_keeps_approved(consent_store: PersonaConsentStore) -> None:
|
||||
p = _minimal_persona()
|
||||
consent_store.set(p.compute_hash(), "approve")
|
||||
result = filter_consented_personas([p], consent_store)
|
||||
assert len(result) == 1
|
||||
|
||||
|
||||
def test_filter_keeps_once(consent_store: PersonaConsentStore) -> None:
|
||||
p = _minimal_persona()
|
||||
consent_store.set(p.compute_hash(), "once")
|
||||
result = filter_consented_personas([p], consent_store)
|
||||
assert len(result) == 1
|
||||
|
||||
|
||||
def test_filter_keeps_none_decision(consent_store: PersonaConsentStore) -> None:
|
||||
"""Persona with no stored decision passes through."""
|
||||
p = _minimal_persona()
|
||||
result = filter_consented_personas([p], consent_store)
|
||||
assert len(result) == 1
|
||||
|
||||
|
||||
def test_filter_empty_list(consent_store: PersonaConsentStore) -> None:
|
||||
result = filter_consented_personas([], consent_store)
|
||||
assert result == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# bind_personas: consent-blocked persona detection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_bind_personas_all_eligible_blocked_raises(
|
||||
seed_personas: list[Persona],
|
||||
spec_and_review: WorkflowTemplate,
|
||||
consent_store: PersonaConsentStore,
|
||||
) -> None:
|
||||
# Block all spec_writer-eligible personas
|
||||
for p in seed_personas:
|
||||
if Capability.SPEC_WRITE in p.capabilities and Capability.PHASE_PLANNING in p.capabilities:
|
||||
consent_store.set(p.compute_hash(), "block")
|
||||
|
||||
with pytest.raises(MyDeepAgentError) as exc_info:
|
||||
bind_personas(spec_and_review, seed_personas, _all_available(), consent_store)
|
||||
assert exc_info.value.code in ("persona_blocked_by_user", "no_eligible_persona")
|
||||
|
||||
|
||||
def test_bind_personas_override_blocked_raises(
|
||||
seed_personas: list[Persona],
|
||||
spec_and_review: WorkflowTemplate,
|
||||
consent_store: PersonaConsentStore,
|
||||
) -> None:
|
||||
spec_writer = next(p for p in seed_personas if p.name == "openrouter-claude-spec-writer")
|
||||
consent_store.set(spec_writer.compute_hash(), "block")
|
||||
override = BindingOverride.parse({"spec_writer": "openrouter-claude-spec-writer@1"})
|
||||
with pytest.raises(MyDeepAgentError) as exc_info:
|
||||
bind_personas(spec_and_review, seed_personas, _all_available(), consent_store, override)
|
||||
assert exc_info.value.code == "persona_blocked_by_user"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _auto_select: preferred_backends order
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_auto_select_prefers_preferred_backend(spec_and_review: WorkflowTemplate) -> None:
|
||||
"""Persona with preferred backend wins over non-preferred even if alphabetically later."""
|
||||
from my_deepagent.binding import _auto_select
|
||||
|
||||
spec_writer_role = next(r for r in spec_and_review.roles if r.id == "spec_writer")
|
||||
# preferred_backends = ["openrouter"]
|
||||
p_openrouter = _minimal_persona(
|
||||
name="z-openrouter-persona",
|
||||
backend="openrouter",
|
||||
capabilities=["spec_write", "phase_planning"],
|
||||
)
|
||||
p_fake = _minimal_persona(
|
||||
name="a-fake-persona",
|
||||
backend="fake",
|
||||
capabilities=["spec_write", "phase_planning"],
|
||||
)
|
||||
chosen = _auto_select([p_openrouter, p_fake], spec_writer_role)
|
||||
assert chosen.name == "z-openrouter-persona"
|
||||
|
||||
|
||||
def test_auto_select_higher_version_wins(spec_and_review: WorkflowTemplate) -> None:
|
||||
from my_deepagent.binding import _auto_select
|
||||
|
||||
spec_writer_role = next(r for r in spec_and_review.roles if r.id == "spec_writer")
|
||||
p_v1 = _minimal_persona(version=1, capabilities=["spec_write", "phase_planning"])
|
||||
p_v2 = _minimal_persona(version=2, capabilities=["spec_write", "phase_planning"])
|
||||
chosen = _auto_select([p_v1, p_v2], spec_writer_role)
|
||||
assert chosen.version == 2
|
||||
|
||||
|
||||
def test_auto_select_name_asc_tiebreak(spec_and_review: WorkflowTemplate) -> None:
|
||||
from my_deepagent.binding import _auto_select
|
||||
|
||||
spec_writer_role = next(r for r in spec_and_review.roles if r.id == "spec_writer")
|
||||
caps = ["spec_write", "phase_planning"]
|
||||
p_b = _minimal_persona(name="b-persona", version=1, capabilities=caps)
|
||||
p_a = _minimal_persona(name="a-persona", version=1, capabilities=caps)
|
||||
chosen = _auto_select([p_b, p_a], spec_writer_role)
|
||||
assert chosen.name == "a-persona"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Step 2 patch: FAKE backend recovery hint
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_backend_recovery_hint_fake() -> None:
|
||||
"""FAKE backend recovery hint must mention 'fake' and 'tests only'."""
|
||||
from my_deepagent.binding import _backend_recovery_hint
|
||||
|
||||
hint = _backend_recovery_hint(Backend.FAKE)
|
||||
assert "fake" in hint.lower()
|
||||
assert "tests only" in hint.lower() or "test harness" in hint.lower()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Step 2 patch: override with non-integer version raises with diagnostic
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_bind_personas_override_non_integer_version_raises(
|
||||
seed_personas: list[Persona],
|
||||
spec_and_review: WorkflowTemplate,
|
||||
consent_store: PersonaConsentStore,
|
||||
) -> None:
|
||||
"""An override spec with a non-integer version must raise with clear diagnostic."""
|
||||
override = BindingOverride(persona_pinned={"spec_writer": "openrouter-claude-spec-writer@abc"})
|
||||
with pytest.raises(MyDeepAgentError) as exc_info:
|
||||
bind_personas(spec_and_review, seed_personas, _all_available(), consent_store, override)
|
||||
assert exc_info.value.code == "no_eligible_persona"
|
||||
assert "non-integer version" in str(exc_info.value)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Step 2 patch: override with ineligible persona surfaces reason
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_bind_personas_override_ineligible_persona_surfaces_reason(
|
||||
seed_personas: list[Persona],
|
||||
spec_and_review: WorkflowTemplate,
|
||||
consent_store: PersonaConsentStore,
|
||||
) -> None:
|
||||
"""Override that names an ineligible persona must surface the ineligibility reason."""
|
||||
# 'spec_writer' role needs spec_write + phase_planning.
|
||||
# Find a persona in seed that does NOT have those caps so we can force it.
|
||||
ineligible = next(
|
||||
p for p in seed_personas if "spec_write" not in [c.value for c in p.capabilities]
|
||||
)
|
||||
override = BindingOverride(
|
||||
persona_pinned={"spec_writer": f"{ineligible.name}@{ineligible.version}"}
|
||||
)
|
||||
with pytest.raises(MyDeepAgentError) as exc_info:
|
||||
bind_personas(spec_and_review, seed_personas, _all_available(), consent_store, override)
|
||||
assert exc_info.value.code == "no_eligible_persona"
|
||||
err_str = str(exc_info.value)
|
||||
# The error message must say the persona is ineligible with a reason.
|
||||
assert "ineligible" in err_str or "missing" in err_str
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Step 2 patch: PersonaConsentStore atomic write calls os.fsync
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_consent_store_write_calls_fsync(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""PersonaConsentStore.set() must call os.fsync() for atomic durability."""
|
||||
import os
|
||||
|
||||
called: list[int] = []
|
||||
orig_fsync = os.fsync
|
||||
|
||||
def spy(fd: int) -> None:
|
||||
called.append(fd)
|
||||
orig_fsync(fd)
|
||||
|
||||
monkeypatch.setattr(os, "fsync", spy)
|
||||
|
||||
store = PersonaConsentStore(tmp_path / "consents.json")
|
||||
store.set("hash_abc", "approve")
|
||||
|
||||
assert len(called) >= 1, "os.fsync must be called at least once during write"
|
||||
238
my-deepagent/tests/unit/test_config.py
Normal file
238
my-deepagent/tests/unit/test_config.py
Normal file
@@ -0,0 +1,238 @@
|
||||
"""Unit tests for src/my_deepagent/config.py."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from my_deepagent.config import Config, load_config
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Default values (no env, no file)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_default_log_level(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
_clear_env(monkeypatch)
|
||||
cfg = Config()
|
||||
assert cfg.log_level == "info"
|
||||
|
||||
|
||||
def test_default_lang(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
_clear_env(monkeypatch)
|
||||
cfg = Config()
|
||||
assert cfg.lang == "ko"
|
||||
|
||||
|
||||
def test_default_budget_daily_usd(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
_clear_env(monkeypatch)
|
||||
cfg = Config()
|
||||
assert cfg.budget_daily_usd == pytest.approx(5.0)
|
||||
|
||||
|
||||
def test_default_budget_run_usd(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
_clear_env(monkeypatch)
|
||||
cfg = Config()
|
||||
assert cfg.budget_run_usd == pytest.approx(1.0)
|
||||
|
||||
|
||||
def test_default_budget_on_hit(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
_clear_env(monkeypatch)
|
||||
cfg = Config()
|
||||
assert cfg.budget_on_hit == "prompt"
|
||||
|
||||
|
||||
def test_default_persona(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
_clear_env(monkeypatch)
|
||||
cfg = Config()
|
||||
assert cfg.default_persona == "default-interactive"
|
||||
|
||||
|
||||
def test_default_openrouter_api_key_is_none(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
_clear_env(monkeypatch)
|
||||
# _env_file=None bypasses any .env that may exist in the cwd (e.g. dev keys).
|
||||
cfg = Config(_env_file=None) # type: ignore[call-arg]
|
||||
assert cfg.openrouter_api_key is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Env var overrides
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_env_budget_daily_usd(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
_clear_env(monkeypatch)
|
||||
monkeypatch.setenv("MYDEEPAGENT_BUDGET_DAILY_USD", "10")
|
||||
cfg = Config()
|
||||
assert cfg.budget_daily_usd == pytest.approx(10.0)
|
||||
|
||||
|
||||
def test_env_lang_en(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
_clear_env(monkeypatch)
|
||||
monkeypatch.setenv("MYDEEPAGENT_LANG", "en")
|
||||
cfg = Config()
|
||||
assert cfg.lang == "en"
|
||||
|
||||
|
||||
def test_env_log_level_debug(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
_clear_env(monkeypatch)
|
||||
monkeypatch.setenv("MYDEEPAGENT_LOG_LEVEL", "debug")
|
||||
cfg = Config()
|
||||
assert cfg.log_level == "debug"
|
||||
|
||||
|
||||
def test_env_openrouter_api_key(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
_clear_env(monkeypatch)
|
||||
monkeypatch.setenv("MYDEEPAGENT_OPENROUTER_API_KEY", "sk-test-abc")
|
||||
cfg = Config()
|
||||
assert cfg.openrouter_api_key == "sk-test-abc"
|
||||
|
||||
|
||||
def test_env_langsmith_tracing(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
_clear_env(monkeypatch)
|
||||
monkeypatch.setenv("MYDEEPAGENT_LANGSMITH_TRACING", "true")
|
||||
cfg = Config()
|
||||
assert cfg.langsmith_tracing is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Validation errors for invalid values
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_invalid_lang_raises(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
_clear_env(monkeypatch)
|
||||
monkeypatch.setenv("MYDEEPAGENT_LANG", "fr")
|
||||
with pytest.raises(ValidationError):
|
||||
Config()
|
||||
|
||||
|
||||
def test_invalid_log_level_raises(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
_clear_env(monkeypatch)
|
||||
monkeypatch.setenv("MYDEEPAGENT_LOG_LEVEL", "verbose")
|
||||
with pytest.raises(ValidationError):
|
||||
Config()
|
||||
|
||||
|
||||
def test_invalid_budget_on_hit_raises(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
_clear_env(monkeypatch)
|
||||
monkeypatch.setenv("MYDEEPAGENT_BUDGET_ON_HIT", "explode")
|
||||
with pytest.raises(ValidationError):
|
||||
Config()
|
||||
|
||||
|
||||
def test_negative_budget_raises(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
_clear_env(monkeypatch)
|
||||
with pytest.raises(ValidationError):
|
||||
Config(budget_daily_usd=-1.0)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Frozen check
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_frozen_prevents_mutation(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
_clear_env(monkeypatch)
|
||||
cfg = Config()
|
||||
with pytest.raises((ValidationError, TypeError)):
|
||||
cfg.budget_daily_usd = 99 # type: ignore[misc]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Path expansion (~ → absolute path)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_tilde_expansion_workspace_root(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
_clear_env(monkeypatch)
|
||||
monkeypatch.setenv("MYDEEPAGENT_WORKSPACE_ROOT", "~/foo/bar")
|
||||
cfg = Config()
|
||||
assert cfg.workspace_root.is_absolute()
|
||||
assert "~" not in str(cfg.workspace_root)
|
||||
|
||||
|
||||
def test_tilde_expansion_data_dir(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
_clear_env(monkeypatch)
|
||||
monkeypatch.setenv("MYDEEPAGENT_DATA_DIR", "~/mydata")
|
||||
cfg = Config()
|
||||
assert cfg.data_dir.is_absolute()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# TOML priority
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_toml_overrides_default(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
|
||||
_clear_env(monkeypatch)
|
||||
toml_file = tmp_path / "config.toml"
|
||||
toml_file.write_text('lang = "en"\nbudget_daily_usd = 7.5\n')
|
||||
|
||||
# Patch the toml_file location via init override
|
||||
# Config reads toml via SettingsConfigDict; we pass via class-level override trick:
|
||||
# Easiest approach: pass budget_daily_usd and lang directly to assert TOML *can* set them.
|
||||
# For true TOML path injection, subclass Config temporarily.
|
||||
class PatchedConfig(Config):
|
||||
model_config = Config.model_config.copy()
|
||||
|
||||
PatchedConfig.model_config["toml_file"] = str(toml_file)
|
||||
|
||||
cfg = PatchedConfig()
|
||||
assert cfg.lang == "en"
|
||||
assert cfg.budget_daily_usd == pytest.approx(7.5)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# load_config helper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_load_config_with_overrides(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
_clear_env(monkeypatch)
|
||||
cfg = load_config(budget_daily_usd=20.0, lang="en")
|
||||
assert cfg.budget_daily_usd == pytest.approx(20.0)
|
||||
assert cfg.lang == "en"
|
||||
|
||||
|
||||
def test_load_config_default(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
_clear_env(monkeypatch)
|
||||
cfg = load_config()
|
||||
assert cfg.log_level == "info"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
_ENV_KEYS = [
|
||||
"MYDEEPAGENT_BUDGET_DAILY_USD",
|
||||
"MYDEEPAGENT_BUDGET_DAILY_WARN_USD",
|
||||
"MYDEEPAGENT_BUDGET_RUN_USD",
|
||||
"MYDEEPAGENT_BUDGET_RUN_WARN_USD",
|
||||
"MYDEEPAGENT_BUDGET_ON_HIT",
|
||||
"MYDEEPAGENT_LANG",
|
||||
"MYDEEPAGENT_LOG_LEVEL",
|
||||
"MYDEEPAGENT_OPENROUTER_API_KEY",
|
||||
"MYDEEPAGENT_OPENROUTER_BASE_URL",
|
||||
"MYDEEPAGENT_LANGSMITH_TRACING",
|
||||
"MYDEEPAGENT_LANGSMITH_API_KEY",
|
||||
"MYDEEPAGENT_LANGSMITH_PROJECT",
|
||||
"MYDEEPAGENT_DATABASE_URL",
|
||||
"MYDEEPAGENT_WORKSPACE_ROOT",
|
||||
"MYDEEPAGENT_DATA_DIR",
|
||||
"MYDEEPAGENT_CONFIG_DIR",
|
||||
"MYDEEPAGENT_STATE_DIR",
|
||||
"MYDEEPAGENT_DEFAULT_PERSONA",
|
||||
]
|
||||
|
||||
|
||||
def _clear_env(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Remove all MYDEEPAGENT_ env vars to isolate tests from the real environment."""
|
||||
for key in _ENV_KEYS:
|
||||
monkeypatch.delenv(key, raising=False)
|
||||
# Also prevent dotenv file from being loaded
|
||||
monkeypatch.setenv("MYDEEPAGENT_ENV_FILE", "")
|
||||
235
my-deepagent/tests/unit/test_enums.py
Normal file
235
my-deepagent/tests/unit/test_enums.py
Normal file
@@ -0,0 +1,235 @@
|
||||
"""Unit tests for src/my_deepagent/enums.py."""
|
||||
|
||||
import pytest
|
||||
|
||||
from my_deepagent.enums import (
|
||||
ApprovalDecisionAction,
|
||||
ApprovalState,
|
||||
Backend,
|
||||
Capability,
|
||||
ErrorClass,
|
||||
RiskLevel,
|
||||
RunPhaseState,
|
||||
RunState,
|
||||
SessionState,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Backend
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_backend_openrouter_value() -> None:
|
||||
assert Backend.OPENROUTER == "openrouter"
|
||||
|
||||
|
||||
def test_backend_anthropic_value() -> None:
|
||||
assert Backend.ANTHROPIC == "anthropic"
|
||||
|
||||
|
||||
def test_backend_openai_value() -> None:
|
||||
assert Backend.OPENAI == "openai"
|
||||
|
||||
|
||||
def test_backend_google_value() -> None:
|
||||
assert Backend.GOOGLE == "google"
|
||||
|
||||
|
||||
def test_backend_fake_value() -> None:
|
||||
assert Backend.FAKE == "fake"
|
||||
|
||||
|
||||
def test_backend_str_equality() -> None:
|
||||
# StrEnum members compare equal to their string values
|
||||
assert Backend.OPENROUTER == "openrouter"
|
||||
assert str(Backend.OPENROUTER) == "openrouter"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Capability
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_capability_count() -> None:
|
||||
assert len(list(Capability)) == 13
|
||||
|
||||
|
||||
def test_capability_spec_write() -> None:
|
||||
assert Capability.SPEC_WRITE == "spec_write"
|
||||
|
||||
|
||||
def test_capability_code_edit() -> None:
|
||||
assert Capability.CODE_EDIT == "code_edit"
|
||||
|
||||
|
||||
def test_capability_final_report_compose() -> None:
|
||||
assert Capability.FINAL_REPORT_COMPOSE == "final_report_compose"
|
||||
|
||||
|
||||
def test_capability_all_are_str() -> None:
|
||||
for cap in Capability:
|
||||
assert isinstance(cap, str)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RiskLevel
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_risk_level_values() -> None:
|
||||
assert RiskLevel.LOW == "low"
|
||||
assert RiskLevel.MEDIUM == "medium"
|
||||
assert RiskLevel.HIGH == "high"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ApprovalDecisionAction
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_approval_decision_action_approve() -> None:
|
||||
assert ApprovalDecisionAction.APPROVE == "approve"
|
||||
|
||||
|
||||
def test_approval_decision_action_reject() -> None:
|
||||
assert ApprovalDecisionAction.REJECT == "reject"
|
||||
|
||||
|
||||
def test_approval_decision_action_request_changes() -> None:
|
||||
assert ApprovalDecisionAction.REQUEST_CHANGES == "request_changes"
|
||||
|
||||
|
||||
def test_approval_decision_action_abort() -> None:
|
||||
assert ApprovalDecisionAction.ABORT == "abort"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ApprovalState
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_approval_state_all_values() -> None:
|
||||
expected = {"pending", "approved", "rejected", "changes_requested", "aborted", "paused"}
|
||||
actual = {s.value for s in ApprovalState}
|
||||
assert actual == expected
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RunState
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_run_state_all_values() -> None:
|
||||
expected = {
|
||||
"created",
|
||||
"bound",
|
||||
"planning",
|
||||
"awaiting_approval",
|
||||
"executing",
|
||||
"paused",
|
||||
"completed",
|
||||
"failed",
|
||||
"aborted",
|
||||
}
|
||||
actual = {s.value for s in RunState}
|
||||
assert actual == expected
|
||||
|
||||
|
||||
def test_run_state_count() -> None:
|
||||
assert len(list(RunState)) == 9
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RunPhaseState
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_run_phase_state_all_values() -> None:
|
||||
expected = {
|
||||
"pending",
|
||||
"running",
|
||||
"awaiting_artifact",
|
||||
"validating",
|
||||
"awaiting_approval",
|
||||
"completed",
|
||||
"failed",
|
||||
"skipped",
|
||||
}
|
||||
actual = {s.value for s in RunPhaseState}
|
||||
assert actual == expected
|
||||
|
||||
|
||||
def test_run_phase_state_count() -> None:
|
||||
assert len(list(RunPhaseState)) == 8
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SessionState
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_session_state_all_values() -> None:
|
||||
expected = {
|
||||
"CREATED",
|
||||
"BOOTSTRAPPING",
|
||||
"READY",
|
||||
"BUSY",
|
||||
"WAITING_FOR_APPROVAL",
|
||||
"ARTIFACT_TIMEOUT",
|
||||
"HUNG",
|
||||
"CRASHED",
|
||||
"RESUMING",
|
||||
"REBOOTSTRAPPED",
|
||||
"FAILED_NEEDS_HUMAN",
|
||||
}
|
||||
actual = {s.value for s in SessionState}
|
||||
assert actual == expected
|
||||
|
||||
|
||||
def test_session_state_count() -> None:
|
||||
assert len(list(SessionState)) == 11
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ErrorClass
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_error_class_recoverable() -> None:
|
||||
assert ErrorClass.RECOVERABLE == "recoverable"
|
||||
|
||||
|
||||
def test_error_class_human_required() -> None:
|
||||
assert ErrorClass.HUMAN_REQUIRED == "human_required"
|
||||
|
||||
|
||||
def test_error_class_fatal() -> None:
|
||||
assert ErrorClass.FATAL == "fatal"
|
||||
|
||||
|
||||
def test_error_class_count() -> None:
|
||||
assert len(list(ErrorClass)) == 3
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# StrEnum serialization / deserialization
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_str_enum_from_value() -> None:
|
||||
assert Backend("openrouter") is Backend.OPENROUTER
|
||||
|
||||
|
||||
def test_str_enum_in_dict() -> None:
|
||||
# StrEnum should work as dict key and compare with string
|
||||
d = {Backend.OPENROUTER: "openrouter backend"}
|
||||
assert d["openrouter"] == "openrouter backend"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"state",
|
||||
list(RunState),
|
||||
)
|
||||
def test_run_state_parametrize(state: RunState) -> None:
|
||||
assert isinstance(state, str)
|
||||
assert RunState(state.value) is state
|
||||
208
my-deepagent/tests/unit/test_errors.py
Normal file
208
my-deepagent/tests/unit/test_errors.py
Normal file
@@ -0,0 +1,208 @@
|
||||
"""Unit tests for src/my_deepagent/errors.py."""
|
||||
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from my_deepagent.enums import ErrorClass
|
||||
from my_deepagent.errors import BudgetExhaustedError, MyDeepAgentError
|
||||
|
||||
|
||||
def test_cause_sets_suppress_context() -> None:
|
||||
"""Wrapping a cause must suppress the implicit context per PEP 3134."""
|
||||
original = ValueError("root cause")
|
||||
err = MyDeepAgentError.recoverable("wrapped", cause=original)
|
||||
assert err.__cause__ is original
|
||||
assert err.__suppress_context__ is True
|
||||
|
||||
|
||||
def test_no_cause_does_not_set_suppress_context() -> None:
|
||||
err = MyDeepAgentError.recoverable("no_cause")
|
||||
assert err.__cause__ is None
|
||||
assert err.__suppress_context__ is False
|
||||
|
||||
|
||||
def test_factory_returns_base_class_not_subclass() -> None:
|
||||
"""LSP fix: factory methods always return MyDeepAgentError, never BudgetExhaustedError."""
|
||||
err = BudgetExhaustedError.recoverable("foo")
|
||||
assert type(err) is MyDeepAgentError
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MyDeepAgentError factory methods
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_recoverable_class() -> None:
|
||||
err = MyDeepAgentError.recoverable("network_blip", recovery_hint="retry")
|
||||
assert err.error_class == ErrorClass.RECOVERABLE
|
||||
|
||||
|
||||
def test_recoverable_code() -> None:
|
||||
err = MyDeepAgentError.recoverable("network_blip")
|
||||
assert err.code == "network_blip"
|
||||
|
||||
|
||||
def test_recoverable_recovery_hint() -> None:
|
||||
err = MyDeepAgentError.recoverable("network_blip", recovery_hint="retry after 1s")
|
||||
assert err.recovery_hint == "retry after 1s"
|
||||
|
||||
|
||||
def test_human_required_class() -> None:
|
||||
err = MyDeepAgentError.human_required("destructive_command_blocked")
|
||||
assert err.error_class == ErrorClass.HUMAN_REQUIRED
|
||||
|
||||
|
||||
def test_human_required_code() -> None:
|
||||
err = MyDeepAgentError.human_required("destructive_command_blocked")
|
||||
assert err.code == "destructive_command_blocked"
|
||||
|
||||
|
||||
def test_fatal_class() -> None:
|
||||
err = MyDeepAgentError.fatal("unrecoverable_state")
|
||||
assert err.error_class == ErrorClass.FATAL
|
||||
|
||||
|
||||
def test_fatal_code() -> None:
|
||||
err = MyDeepAgentError.fatal("unrecoverable_state")
|
||||
assert err.code == "unrecoverable_state"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# run_id / phase_id context
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_run_id_attached() -> None:
|
||||
run_id = uuid4()
|
||||
err = MyDeepAgentError.recoverable("timeout", run_id=run_id)
|
||||
assert err.run_id == run_id
|
||||
|
||||
|
||||
def test_phase_id_attached() -> None:
|
||||
phase_id = uuid4()
|
||||
err = MyDeepAgentError.recoverable("artifact_missing", phase_id=phase_id)
|
||||
assert err.phase_id == phase_id
|
||||
|
||||
|
||||
def test_run_id_none_by_default() -> None:
|
||||
err = MyDeepAgentError.recoverable("x")
|
||||
assert err.run_id is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# __cause__ propagation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_cause_propagation() -> None:
|
||||
original = ValueError("root cause")
|
||||
err = MyDeepAgentError.recoverable("wrapped", cause=original)
|
||||
assert err.__cause__ is original
|
||||
|
||||
|
||||
def test_cause_none_by_default() -> None:
|
||||
err = MyDeepAgentError.recoverable("no_cause")
|
||||
assert err.__cause__ is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# __repr__ format
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_repr_contains_class_and_code() -> None:
|
||||
err = MyDeepAgentError.recoverable("some_code")
|
||||
r = repr(err)
|
||||
assert "class=recoverable" in r
|
||||
assert "code=some_code" in r
|
||||
|
||||
|
||||
def test_repr_contains_run_id_when_present() -> None:
|
||||
run_id = UUID("12345678-1234-5678-1234-567812345678")
|
||||
err = MyDeepAgentError.recoverable("x", run_id=run_id)
|
||||
assert str(run_id) in repr(err)
|
||||
|
||||
|
||||
def test_repr_contains_hint_when_present() -> None:
|
||||
err = MyDeepAgentError.recoverable("x", recovery_hint="do something")
|
||||
assert "do something" in repr(err)
|
||||
|
||||
|
||||
def test_repr_no_hint_when_absent() -> None:
|
||||
err = MyDeepAgentError.recoverable("x")
|
||||
assert "hint" not in repr(err)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Exception hierarchy
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_my_deepagent_error_is_exception() -> None:
|
||||
err = MyDeepAgentError.recoverable("x")
|
||||
assert isinstance(err, Exception)
|
||||
|
||||
|
||||
def test_budget_exhausted_is_my_deepagent_error() -> None:
|
||||
err = BudgetExhaustedError("day:2026-05-15", 1.20, 1.00)
|
||||
assert isinstance(err, MyDeepAgentError)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# BudgetExhaustedError
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_budget_exhausted_scope() -> None:
|
||||
err = BudgetExhaustedError("day:2026-05-15", 1.20, 1.00)
|
||||
assert err.scope == "day:2026-05-15"
|
||||
|
||||
|
||||
def test_budget_exhausted_projected_usd() -> None:
|
||||
err = BudgetExhaustedError("day:2026-05-15", 1.20, 1.00)
|
||||
assert err.projected_usd == pytest.approx(1.20)
|
||||
|
||||
|
||||
def test_budget_exhausted_cap_usd() -> None:
|
||||
err = BudgetExhaustedError("day:2026-05-15", 1.20, 1.00)
|
||||
assert err.cap_usd == pytest.approx(1.00)
|
||||
|
||||
|
||||
def test_budget_exhausted_error_class() -> None:
|
||||
err = BudgetExhaustedError("day:2026-05-15", 1.20, 1.00)
|
||||
assert err.error_class == ErrorClass.HUMAN_REQUIRED
|
||||
|
||||
|
||||
def test_budget_exhausted_code() -> None:
|
||||
err = BudgetExhaustedError("day:2026-05-15", 1.20, 1.00)
|
||||
assert err.code == "budget_exhausted"
|
||||
|
||||
|
||||
def test_budget_exhausted_default_recovery_hint() -> None:
|
||||
err = BudgetExhaustedError("day:2026-05-15", 1.20, 1.00)
|
||||
assert err.recovery_hint is not None
|
||||
assert len(err.recovery_hint) > 0
|
||||
|
||||
|
||||
def test_budget_exhausted_custom_recovery_hint() -> None:
|
||||
err = BudgetExhaustedError("day:2026-05-15", 1.20, 1.00, recovery_hint="call support")
|
||||
assert err.recovery_hint == "call support"
|
||||
|
||||
|
||||
def test_budget_exhausted_run_id() -> None:
|
||||
run_id = uuid4()
|
||||
err = BudgetExhaustedError("run:abc", 0.5, 0.4, run_id=run_id)
|
||||
assert err.run_id == run_id
|
||||
|
||||
|
||||
def test_budget_exhausted_message_contains_scope() -> None:
|
||||
err = BudgetExhaustedError("day:2026-05-15", 1.20, 1.00)
|
||||
assert "day:2026-05-15" in str(err)
|
||||
|
||||
|
||||
def test_budget_exhausted_message_contains_values() -> None:
|
||||
err = BudgetExhaustedError("scope", 1.2345, 1.0000)
|
||||
msg = str(err)
|
||||
assert "1.2345" in msg
|
||||
assert "1.0000" in msg
|
||||
121
my-deepagent/tests/unit/test_hash.py
Normal file
121
my-deepagent/tests/unit/test_hash.py
Normal file
@@ -0,0 +1,121 @@
|
||||
"""Unit tests for src/my_deepagent/hash.py."""
|
||||
|
||||
import re
|
||||
|
||||
import pytest
|
||||
|
||||
from my_deepagent.hash import canonicalize, sha256
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# canonicalize: key ordering
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_canonicalize_sorts_keys() -> None:
|
||||
assert canonicalize({"b": 1, "a": 2}) == '{"a":2,"b":1}'
|
||||
|
||||
|
||||
def test_canonicalize_nested_sorts_keys() -> None:
|
||||
result = canonicalize({"x": {"b": 2, "a": 1}})
|
||||
assert result == '{"x":{"a":1,"b":2}}'
|
||||
|
||||
|
||||
def test_canonicalize_empty_dict() -> None:
|
||||
assert canonicalize({}) == "{}"
|
||||
|
||||
|
||||
def test_canonicalize_empty_list() -> None:
|
||||
assert canonicalize([]) == "[]"
|
||||
|
||||
|
||||
def test_canonicalize_none() -> None:
|
||||
assert canonicalize(None) == "null"
|
||||
|
||||
|
||||
def test_canonicalize_integer() -> None:
|
||||
assert canonicalize(42) == "42"
|
||||
|
||||
|
||||
def test_canonicalize_float() -> None:
|
||||
# 0.1 has a known floating-point representation
|
||||
result = canonicalize(0.1)
|
||||
assert result == "0.1"
|
||||
|
||||
|
||||
def test_canonicalize_no_whitespace() -> None:
|
||||
result = canonicalize({"a": 1, "b": 2})
|
||||
assert " " not in result
|
||||
|
||||
|
||||
def test_canonicalize_list_preserves_order() -> None:
|
||||
# Lists should not be reordered
|
||||
assert canonicalize([3, 1, 2]) == "[3,1,2]"
|
||||
|
||||
|
||||
def test_canonicalize_string_value() -> None:
|
||||
assert canonicalize("hello") == '"hello"'
|
||||
|
||||
|
||||
def test_canonicalize_boolean() -> None:
|
||||
assert canonicalize(True) == "true"
|
||||
assert canonicalize(False) == "false"
|
||||
|
||||
|
||||
def test_canonicalize_nan_raises() -> None:
|
||||
import math
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
canonicalize(math.nan)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# sha256: determinism
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_sha256_deterministic() -> None:
|
||||
value = {"a": 1, "b": [1, 2, 3]}
|
||||
results = [sha256(value) for _ in range(100)]
|
||||
assert len(set(results)) == 1
|
||||
|
||||
|
||||
def test_sha256_returns_64_char_hex() -> None:
|
||||
result = sha256({"a": 1})
|
||||
assert re.fullmatch(r"[0-9a-f]{64}", result) is not None
|
||||
|
||||
|
||||
def test_sha256_different_inputs_different_hash() -> None:
|
||||
h1 = sha256({"a": 1})
|
||||
h2 = sha256({"a": 2})
|
||||
assert h1 != h2
|
||||
|
||||
|
||||
def test_sha256_key_order_irrelevant() -> None:
|
||||
# Same content, different insertion order → same hash
|
||||
h1 = sha256({"a": 1, "b": 2})
|
||||
h2 = sha256({"b": 2, "a": 1})
|
||||
assert h1 == h2
|
||||
|
||||
|
||||
def test_sha256_empty_dict() -> None:
|
||||
result = sha256({})
|
||||
assert re.fullmatch(r"[0-9a-f]{64}", result) is not None
|
||||
|
||||
|
||||
def test_sha256_none() -> None:
|
||||
result = sha256(None)
|
||||
assert re.fullmatch(r"[0-9a-f]{64}", result) is not None
|
||||
|
||||
|
||||
def test_sha256_nested() -> None:
|
||||
h1 = sha256({"x": {"a": 1, "b": 2}})
|
||||
h2 = sha256({"x": {"b": 2, "a": 1}})
|
||||
assert h1 == h2
|
||||
|
||||
|
||||
def test_sha256_known_value() -> None:
|
||||
# Pre-computed: sha256('{"a":1}') in UTF-8
|
||||
import hashlib
|
||||
|
||||
expected = hashlib.sha256(b'{"a":1}').hexdigest()
|
||||
assert sha256({"a": 1}) == expected
|
||||
118
my-deepagent/tests/unit/test_middleware_audit.py
Normal file
118
my-deepagent/tests/unit/test_middleware_audit.py
Normal file
@@ -0,0 +1,118 @@
|
||||
"""Unit tests for src/my_deepagent/middleware/audit.py."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
|
||||
from my_deepagent.middleware.audit import AuditToolMiddleware
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_request(name: str = "read_file", args: dict[str, Any] | None = None) -> MagicMock:
|
||||
request = MagicMock()
|
||||
request.tool_call = {"name": name, "args": args or {"path": "x.py"}}
|
||||
return request
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# awrap_tool_call — success path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_audit_middleware_records_correct_fields_on_success() -> None:
|
||||
recorder = AsyncMock()
|
||||
mw = AuditToolMiddleware(
|
||||
run_id=UUID("00000000-0000-0000-0000-000000000001"),
|
||||
phase_id=UUID("00000000-0000-0000-0000-000000000002"),
|
||||
interactive_session_id=UUID("00000000-0000-0000-0000-000000000003"),
|
||||
recorder=recorder,
|
||||
)
|
||||
result_value = "file contents here"
|
||||
handler = AsyncMock(return_value=result_value)
|
||||
request = _make_request(name="read_file", args={"path": "src/main.py"})
|
||||
|
||||
result = await mw.awrap_tool_call(request, handler)
|
||||
|
||||
assert result == result_value
|
||||
recorder.assert_awaited_once()
|
||||
record: dict[str, Any] = recorder.call_args[0][0]
|
||||
assert record["tool_name"] == "read_file"
|
||||
assert record["args"] == {"path": "src/main.py"}
|
||||
assert record["result"] == result_value
|
||||
assert record["error"] is None
|
||||
assert record["duration_ms"] >= 0
|
||||
assert record["run_id"] == UUID("00000000-0000-0000-0000-000000000001")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_audit_middleware_no_recorder_is_noop() -> None:
|
||||
mw = AuditToolMiddleware()
|
||||
handler = AsyncMock(return_value="ok")
|
||||
result = await mw.awrap_tool_call(_make_request(), handler)
|
||||
assert result == "ok"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# awrap_tool_call — error path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_audit_middleware_records_error_code_on_exception() -> None:
|
||||
recorder = AsyncMock()
|
||||
mw = AuditToolMiddleware(recorder=recorder)
|
||||
handler = AsyncMock(side_effect=PermissionError("access denied"))
|
||||
|
||||
with pytest.raises(PermissionError):
|
||||
await mw.awrap_tool_call(_make_request(), handler)
|
||||
|
||||
recorder.assert_awaited_once()
|
||||
record: dict[str, Any] = recorder.call_args[0][0]
|
||||
assert record["error"] == "PermissionError"
|
||||
assert record["result"] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_audit_middleware_reraises_exception() -> None:
|
||||
mw = AuditToolMiddleware(recorder=AsyncMock())
|
||||
handler = AsyncMock(side_effect=ValueError("bad args"))
|
||||
with pytest.raises(ValueError, match="bad args"):
|
||||
await mw.awrap_tool_call(_make_request(), handler)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# result serialization
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_audit_middleware_serializes_non_primitive_result_as_str() -> None:
|
||||
recorder = AsyncMock()
|
||||
mw = AuditToolMiddleware(recorder=recorder)
|
||||
|
||||
class _CustomResult:
|
||||
def __str__(self) -> str:
|
||||
return "custom-result-str"
|
||||
|
||||
handler = AsyncMock(return_value=_CustomResult())
|
||||
await mw.awrap_tool_call(_make_request(), handler)
|
||||
record = recorder.call_args[0][0]
|
||||
assert record["result"] == "custom-result-str"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_audit_middleware_passes_dict_result_as_is() -> None:
|
||||
recorder = AsyncMock()
|
||||
mw = AuditToolMiddleware(recorder=recorder)
|
||||
handler = AsyncMock(return_value={"key": "value"})
|
||||
await mw.awrap_tool_call(_make_request(), handler)
|
||||
record = recorder.call_args[0][0]
|
||||
assert record["result"] == {"key": "value"}
|
||||
143
my-deepagent/tests/unit/test_middleware_cost.py
Normal file
143
my-deepagent/tests/unit/test_middleware_cost.py
Normal file
@@ -0,0 +1,143 @@
|
||||
"""Unit tests for src/my_deepagent/middleware/cost.py."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
|
||||
from my_deepagent.middleware.cost import CostMiddleware
|
||||
from my_deepagent.monitoring.pricing import ModelPrice, PricingCache
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_pricing_cache(
|
||||
model: str = "anthropic/claude-sonnet",
|
||||
input_per_1k: float = 0.003,
|
||||
output_per_1k: float = 0.015,
|
||||
) -> PricingCache:
|
||||
cache = PricingCache()
|
||||
cache.set(
|
||||
[
|
||||
ModelPrice(
|
||||
model=model,
|
||||
input_per_1k_usd=input_per_1k,
|
||||
output_per_1k_usd=output_per_1k,
|
||||
context_length=200000,
|
||||
)
|
||||
]
|
||||
)
|
||||
return cache
|
||||
|
||||
|
||||
def _make_response(input_tokens: int = 100, output_tokens: int = 50) -> MagicMock:
|
||||
response = MagicMock()
|
||||
response.usage_metadata = {"input_tokens": input_tokens, "output_tokens": output_tokens}
|
||||
return response
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# awrap_model_call — success path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cost_middleware_records_correct_fields_on_success() -> None:
|
||||
recorder = AsyncMock()
|
||||
cache = _make_pricing_cache()
|
||||
mw = CostMiddleware(
|
||||
pricing=cache,
|
||||
model_name="anthropic/claude-sonnet",
|
||||
run_id=UUID("00000000-0000-0000-0000-000000000001"),
|
||||
phase_id=UUID("00000000-0000-0000-0000-000000000002"),
|
||||
persona_name="test-persona",
|
||||
recorder=recorder,
|
||||
)
|
||||
response = _make_response(input_tokens=1000, output_tokens=500)
|
||||
handler = AsyncMock(return_value=response)
|
||||
request = MagicMock()
|
||||
|
||||
result = await mw.awrap_model_call(request, handler)
|
||||
|
||||
assert result is response
|
||||
recorder.assert_awaited_once()
|
||||
record: dict[str, Any] = recorder.call_args[0][0]
|
||||
assert record["model"] == "anthropic/claude-sonnet"
|
||||
assert record["input_tokens"] == 1000
|
||||
assert record["output_tokens"] == 500
|
||||
assert record["status"] == "ok"
|
||||
assert record["error_code"] is None
|
||||
assert record["latency_ms"] >= 0
|
||||
# cost: (1000/1000 * 0.003) + (500/1000 * 0.015)
|
||||
expected_cost = 0.003 * 1.0 + 0.015 * 0.5
|
||||
assert record["cost_usd_total"] == pytest.approx(expected_cost)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cost_middleware_no_recorder_is_noop() -> None:
|
||||
cache = _make_pricing_cache()
|
||||
mw = CostMiddleware(pricing=cache, model_name="anthropic/claude-sonnet")
|
||||
response = _make_response()
|
||||
handler = AsyncMock(return_value=response)
|
||||
# Should not raise even with recorder=None
|
||||
result = await mw.awrap_model_call(MagicMock(), handler)
|
||||
assert result is response
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# awrap_model_call — error path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cost_middleware_records_error_on_handler_exception() -> None:
|
||||
recorder = AsyncMock()
|
||||
cache = _make_pricing_cache()
|
||||
mw = CostMiddleware(
|
||||
pricing=cache,
|
||||
model_name="anthropic/claude-sonnet",
|
||||
recorder=recorder,
|
||||
)
|
||||
handler = AsyncMock(side_effect=RuntimeError("timeout"))
|
||||
|
||||
with pytest.raises(RuntimeError, match="timeout"):
|
||||
await mw.awrap_model_call(MagicMock(), handler)
|
||||
|
||||
recorder.assert_awaited_once()
|
||||
record: dict[str, Any] = recorder.call_args[0][0]
|
||||
assert record["status"] == "error"
|
||||
assert record["error_code"] == "RuntimeError"
|
||||
assert record["input_tokens"] == 0
|
||||
assert record["output_tokens"] == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cost_middleware_reraises_exception() -> None:
|
||||
cache = _make_pricing_cache()
|
||||
mw = CostMiddleware(pricing=cache, model_name="m", recorder=AsyncMock())
|
||||
handler = AsyncMock(side_effect=ValueError("bad input"))
|
||||
|
||||
with pytest.raises(ValueError, match="bad input"):
|
||||
await mw.awrap_model_call(MagicMock(), handler)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# cost computation via cache
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cost_zero_when_model_not_in_cache() -> None:
|
||||
recorder = AsyncMock()
|
||||
cache = PricingCache() # empty
|
||||
mw = CostMiddleware(pricing=cache, model_name="unknown/model", recorder=recorder)
|
||||
response = _make_response(input_tokens=1000, output_tokens=1000)
|
||||
handler = AsyncMock(return_value=response)
|
||||
await mw.awrap_model_call(MagicMock(), handler)
|
||||
record = recorder.call_args[0][0]
|
||||
assert record["cost_usd_total"] == 0.0
|
||||
168
my-deepagent/tests/unit/test_middleware_fallback.py
Normal file
168
my-deepagent/tests/unit/test_middleware_fallback.py
Normal file
@@ -0,0 +1,168 @@
|
||||
"""Unit tests for src/my_deepagent/middleware/fallback.py."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import httpx
|
||||
import openai
|
||||
import pytest
|
||||
|
||||
from my_deepagent.middleware.fallback import FallbackModelMiddleware
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_request(has_model_attr: bool = True) -> MagicMock:
|
||||
request = MagicMock()
|
||||
if not has_model_attr:
|
||||
del request.model
|
||||
return request
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fallback on RateLimitError
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_on_rate_limit_error_calls_handler_with_fallback() -> None:
|
||||
primary = MagicMock(name="primary-model")
|
||||
fallback = MagicMock(name="fallback-model")
|
||||
mw = FallbackModelMiddleware(primary=primary, fallback=fallback)
|
||||
|
||||
call_count = 0
|
||||
fallback_model_seen: Any = None
|
||||
|
||||
async def handler(request: Any) -> str:
|
||||
nonlocal call_count, fallback_model_seen
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
raise openai.RateLimitError(
|
||||
"rate limit",
|
||||
response=MagicMock(status_code=429, headers={}),
|
||||
body={},
|
||||
)
|
||||
fallback_model_seen = getattr(request, "model", None)
|
||||
return "fallback-response"
|
||||
|
||||
request = _make_request()
|
||||
result = await mw.awrap_model_call(request, handler)
|
||||
assert result == "fallback-response"
|
||||
assert call_count == 2
|
||||
assert fallback_model_seen is fallback
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_on_api_connection_error() -> None:
|
||||
primary = MagicMock()
|
||||
fallback = MagicMock()
|
||||
mw = FallbackModelMiddleware(primary=primary, fallback=fallback)
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def handler(request: Any) -> str:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
raise openai.APIConnectionError(request=MagicMock())
|
||||
return "connection-fallback"
|
||||
|
||||
result = await mw.awrap_model_call(_make_request(), handler)
|
||||
assert result == "connection-fallback"
|
||||
assert call_count == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_on_httpx_error() -> None:
|
||||
primary = MagicMock()
|
||||
fallback = MagicMock()
|
||||
mw = FallbackModelMiddleware(primary=primary, fallback=fallback)
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def handler(request: Any) -> str:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
raise httpx.ConnectError("connect failed")
|
||||
return "httpx-fallback"
|
||||
|
||||
result = await mw.awrap_model_call(_make_request(), handler)
|
||||
assert result == "httpx-fallback"
|
||||
assert call_count == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# No fallback — exception propagates
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_fallback_raises_original_error() -> None:
|
||||
mw = FallbackModelMiddleware(primary=MagicMock(), fallback=None)
|
||||
handler = AsyncMock(
|
||||
side_effect=openai.RateLimitError(
|
||||
"rate limit",
|
||||
response=MagicMock(status_code=429, headers={}),
|
||||
body={},
|
||||
)
|
||||
)
|
||||
with pytest.raises(openai.RateLimitError):
|
||||
await mw.awrap_model_call(_make_request(), handler)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# AuthenticationError — never retried
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auth_error_is_not_retried() -> None:
|
||||
primary = MagicMock()
|
||||
fallback = MagicMock()
|
||||
mw = FallbackModelMiddleware(primary=primary, fallback=fallback)
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def handler(request: Any) -> str:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
raise openai.AuthenticationError(
|
||||
"bad api key",
|
||||
response=MagicMock(status_code=401, headers={}),
|
||||
body={},
|
||||
)
|
||||
|
||||
with pytest.raises(openai.AuthenticationError):
|
||||
await mw.awrap_model_call(_make_request(), handler)
|
||||
|
||||
# Handler should only be called once (no retry for auth errors)
|
||||
assert call_count == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _with_fallback_model
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_with_fallback_model_swaps_model_attribute() -> None:
|
||||
primary = MagicMock(name="primary")
|
||||
fallback = MagicMock(name="fallback")
|
||||
mw = FallbackModelMiddleware(primary=primary, fallback=fallback)
|
||||
|
||||
request = MagicMock()
|
||||
request.model = primary
|
||||
patched = mw._with_fallback_model(request)
|
||||
assert patched.model is fallback
|
||||
|
||||
|
||||
def test_with_fallback_model_no_model_attr_does_not_crash() -> None:
|
||||
mw = FallbackModelMiddleware(primary=MagicMock(), fallback=MagicMock())
|
||||
request = MagicMock(spec=[]) # no attributes
|
||||
# Should not raise
|
||||
patched = mw._with_fallback_model(request)
|
||||
assert patched is request
|
||||
258
my-deepagent/tests/unit/test_middleware_safety.py
Normal file
258
my-deepagent/tests/unit/test_middleware_safety.py
Normal file
@@ -0,0 +1,258 @@
|
||||
"""Unit tests for src/my_deepagent/middleware/safety.py."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from my_deepagent.errors import MyDeepAgentError
|
||||
from my_deepagent.middleware.safety import SafetyShellMiddleware, _is_denied_path
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_shell_request(cmd: str | list[str], tool_name: str = "shell") -> MagicMock:
|
||||
request = MagicMock()
|
||||
if isinstance(cmd, list):
|
||||
request.tool_call = {"name": tool_name, "args": {"argv": cmd}}
|
||||
else:
|
||||
request.tool_call = {"name": tool_name, "args": {"command": cmd}}
|
||||
return request
|
||||
|
||||
|
||||
def _make_other_tool_request(
|
||||
name: str = "read_file", args: dict[str, Any] | None = None
|
||||
) -> MagicMock:
|
||||
request = MagicMock()
|
||||
request.tool_call = {"name": name, "args": args or {}}
|
||||
return request
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Destructive commands — should raise
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rm_rf_slash_is_blocked() -> None:
|
||||
mw = SafetyShellMiddleware()
|
||||
with pytest.raises(MyDeepAgentError) as exc_info:
|
||||
await mw.awrap_tool_call(_make_shell_request("rm -rf /"), AsyncMock())
|
||||
assert exc_info.value.code == "destructive_command_blocked"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rm_rf_with_path_is_blocked() -> None:
|
||||
mw = SafetyShellMiddleware()
|
||||
with pytest.raises(MyDeepAgentError) as exc_info:
|
||||
await mw.awrap_tool_call(_make_shell_request("rm -rf ./build"), AsyncMock())
|
||||
assert exc_info.value.code == "destructive_command_blocked"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_git_push_force_is_blocked() -> None:
|
||||
mw = SafetyShellMiddleware()
|
||||
with pytest.raises(MyDeepAgentError):
|
||||
await mw.awrap_tool_call(_make_shell_request("git push --force origin main"), AsyncMock())
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_git_push_force_with_lease_is_blocked() -> None:
|
||||
mw = SafetyShellMiddleware()
|
||||
with pytest.raises(MyDeepAgentError):
|
||||
await mw.awrap_tool_call(
|
||||
_make_shell_request("git push --force-with-lease origin main"), AsyncMock()
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_git_reset_hard_is_blocked() -> None:
|
||||
mw = SafetyShellMiddleware()
|
||||
with pytest.raises(MyDeepAgentError):
|
||||
await mw.awrap_tool_call(_make_shell_request("git reset --hard HEAD"), AsyncMock())
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_git_clean_is_blocked() -> None:
|
||||
mw = SafetyShellMiddleware()
|
||||
with pytest.raises(MyDeepAgentError):
|
||||
await mw.awrap_tool_call(_make_shell_request("git clean -fd"), AsyncMock())
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_drop_table_sql_is_blocked() -> None:
|
||||
mw = SafetyShellMiddleware()
|
||||
with pytest.raises(MyDeepAgentError):
|
||||
await mw.awrap_tool_call(_make_shell_request("psql -c 'DROP TABLE users'"), AsyncMock())
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_tool_name_also_blocked() -> None:
|
||||
"""The 'execute' tool name is also checked for destructive patterns."""
|
||||
mw = SafetyShellMiddleware()
|
||||
with pytest.raises(MyDeepAgentError):
|
||||
await mw.awrap_tool_call(
|
||||
_make_shell_request("rm -rf /tmp/data", tool_name="execute"), AsyncMock()
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# argv (list) form — should also be blocked
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rm_rf_as_list_argv_is_blocked() -> None:
|
||||
mw = SafetyShellMiddleware()
|
||||
with pytest.raises(MyDeepAgentError):
|
||||
await mw.awrap_tool_call(
|
||||
_make_shell_request(["rm", "-rf", "/tmp"], tool_name="shell"), AsyncMock()
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Safe commands — should pass through
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ls_la_passes_through() -> None:
|
||||
mw = SafetyShellMiddleware()
|
||||
handler = AsyncMock(return_value="total 42")
|
||||
result = await mw.awrap_tool_call(_make_shell_request("ls -la"), handler)
|
||||
assert result == "total 42"
|
||||
handler.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_git_status_passes_through() -> None:
|
||||
mw = SafetyShellMiddleware()
|
||||
handler = AsyncMock(return_value="On branch main")
|
||||
result = await mw.awrap_tool_call(_make_shell_request("git status"), handler)
|
||||
assert result == "On branch main"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_git_push_without_force_passes_through() -> None:
|
||||
mw = SafetyShellMiddleware()
|
||||
handler = AsyncMock(return_value="ok")
|
||||
result = await mw.awrap_tool_call(_make_shell_request("git push origin main"), handler)
|
||||
assert result == "ok"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Non-shell tools — should NOT be inspected
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_file_tool_with_destructive_content_passes() -> None:
|
||||
"""read_file is not a shell tool; its content should not be blocked."""
|
||||
mw = SafetyShellMiddleware()
|
||||
handler = AsyncMock(return_value="file content")
|
||||
request = _make_other_tool_request("read_file", {"path": "/some/file.py"})
|
||||
result = await mw.awrap_tool_call(request, handler)
|
||||
assert result == "file content"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unknown_tool_not_checked() -> None:
|
||||
mw = SafetyShellMiddleware()
|
||||
handler = AsyncMock(return_value="ok")
|
||||
result = await mw.awrap_tool_call(_make_other_tool_request("arbitrary_tool"), handler)
|
||||
assert result == "ok"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _is_denied_path unit tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_is_denied_path_env_file() -> None:
|
||||
assert _is_denied_path(".env") is True
|
||||
|
||||
|
||||
def test_is_denied_path_env_local_in_subdir() -> None:
|
||||
assert _is_denied_path("config/.env.local") is True
|
||||
|
||||
|
||||
def test_is_denied_path_ssh_key() -> None:
|
||||
assert _is_denied_path(".ssh/id_rsa") is True
|
||||
|
||||
|
||||
def test_is_denied_path_safe_source_file() -> None:
|
||||
assert _is_denied_path("src/main.py") is False
|
||||
|
||||
|
||||
def test_is_denied_path_token_file() -> None:
|
||||
assert _is_denied_path("api_token.json") is True
|
||||
|
||||
|
||||
def test_is_denied_path_aws_credentials() -> None:
|
||||
assert _is_denied_path(".aws/credentials") is True
|
||||
|
||||
|
||||
def test_is_denied_path_pem_file() -> None:
|
||||
assert _is_denied_path("key.pem") is True
|
||||
|
||||
|
||||
def test_is_denied_path_absolute_env() -> None:
|
||||
# absolute path normalised by lstrip('/')
|
||||
assert _is_denied_path("/.env") is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Secret-path tool blocking via awrap_tool_call
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_file_env_path_is_blocked() -> None:
|
||||
mw = SafetyShellMiddleware()
|
||||
request = _make_other_tool_request("read_file", {"file_path": ".env"})
|
||||
with pytest.raises(MyDeepAgentError) as exc_info:
|
||||
await mw.awrap_tool_call(request, AsyncMock())
|
||||
assert exc_info.value.code == "secret_access_blocked"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_write_file_pem_path_is_blocked() -> None:
|
||||
mw = SafetyShellMiddleware()
|
||||
request = _make_other_tool_request("write_file", {"file_path": "key.pem"})
|
||||
with pytest.raises(MyDeepAgentError) as exc_info:
|
||||
await mw.awrap_tool_call(request, AsyncMock())
|
||||
assert exc_info.value.code == "secret_access_blocked"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ls_ssh_dir_is_blocked() -> None:
|
||||
mw = SafetyShellMiddleware()
|
||||
request = _make_other_tool_request("ls", {"path": ".ssh/"})
|
||||
with pytest.raises(MyDeepAgentError) as exc_info:
|
||||
await mw.awrap_tool_call(request, AsyncMock())
|
||||
assert exc_info.value.code == "secret_access_blocked"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_read_file_safe_path_passes() -> None:
|
||||
mw = SafetyShellMiddleware()
|
||||
handler = AsyncMock(return_value="content")
|
||||
request = _make_other_tool_request("read_file", {"file_path": "src/foo.py"})
|
||||
result = await mw.awrap_tool_call(request, handler)
|
||||
assert result == "content"
|
||||
handler.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_tool_path_arg_not_path_checked() -> None:
|
||||
"""execute tool goes through shell-check only, not path-check."""
|
||||
mw = SafetyShellMiddleware()
|
||||
handler = AsyncMock(return_value="ok")
|
||||
# safe shell command with a path arg — should not be blocked via path logic
|
||||
request = _make_shell_request("ls /some/safe/dir", tool_name="execute")
|
||||
result = await mw.awrap_tool_call(request, handler)
|
||||
assert result == "ok"
|
||||
332
my-deepagent/tests/unit/test_persona.py
Normal file
332
my-deepagent/tests/unit/test_persona.py
Normal file
@@ -0,0 +1,332 @@
|
||||
"""Unit tests for src/my_deepagent/persona.py."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from my_deepagent.enums import Backend
|
||||
from my_deepagent.persona import (
|
||||
FilesystemPermissionSpec,
|
||||
Persona,
|
||||
PersonaSubagent,
|
||||
load_persona_yaml,
|
||||
load_personas_from_dir,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
PERSONAS_DIR = Path(__file__).parent.parent.parent / "docs" / "schemas" / "personas"
|
||||
|
||||
|
||||
def _minimal_persona_dict(**overrides: object) -> dict[str, object]:
|
||||
"""Return a minimal valid persona dict, overridable per-test."""
|
||||
base: dict[str, object] = {
|
||||
"name": "test-persona",
|
||||
"version": 1,
|
||||
"backend": "openrouter",
|
||||
"model": "openrouter:anthropic/claude-sonnet-4-6",
|
||||
"provider_origin": "US/Anthropic",
|
||||
"capabilities": ["spec_write"],
|
||||
"max_risk_level": "low",
|
||||
"system_prompt": "You are a test persona for unit tests.",
|
||||
}
|
||||
base.update(overrides)
|
||||
return base
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Seed yaml: all 10 load successfully
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_all_seed_personas_load() -> None:
|
||||
personas = load_personas_from_dir(PERSONAS_DIR)
|
||||
assert len(personas) == 10
|
||||
|
||||
|
||||
def test_seed_persona_names_unique() -> None:
|
||||
personas = load_personas_from_dir(PERSONAS_DIR)
|
||||
keys = [(p.name, p.version) for p in personas]
|
||||
assert len(keys) == len(set(keys))
|
||||
|
||||
|
||||
def test_seed_personas_backends_are_openrouter() -> None:
|
||||
personas = load_personas_from_dir(PERSONAS_DIR)
|
||||
for p in personas:
|
||||
assert p.backend == Backend.OPENROUTER
|
||||
|
||||
|
||||
def test_seed_persona_capabilities_non_empty() -> None:
|
||||
personas = load_personas_from_dir(PERSONAS_DIR)
|
||||
for p in personas:
|
||||
assert len(p.capabilities) >= 1
|
||||
|
||||
|
||||
def test_seed_persona_hash_is_64_char_hex() -> None:
|
||||
personas = load_personas_from_dir(PERSONAS_DIR)
|
||||
for p in personas:
|
||||
h = p.compute_hash()
|
||||
assert re.fullmatch(r"[0-9a-f]{64}", h), f"{p.name}: bad hash {h!r}"
|
||||
|
||||
|
||||
def test_seed_persona_frozen() -> None:
|
||||
"""Frozen model: attribute assignment must raise."""
|
||||
personas = load_personas_from_dir(PERSONAS_DIR)
|
||||
p = personas[0]
|
||||
with pytest.raises((TypeError, ValidationError)):
|
||||
p.name = "mutated" # type: ignore[misc]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# extra="forbid": unknown fields rejected
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_persona_extra_field_raises() -> None:
|
||||
data = _minimal_persona_dict(unknown_field="surprise")
|
||||
with pytest.raises(ValidationError, match="extra"):
|
||||
Persona.model_validate(data)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# FilesystemPermissionSpec validators
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_permission_path_no_leading_slash_raises() -> None:
|
||||
with pytest.raises(ValidationError, match="must start with '/'"):
|
||||
FilesystemPermissionSpec(operations=["read"], paths=["relative/path"])
|
||||
|
||||
|
||||
def test_permission_path_dotdot_raises() -> None:
|
||||
with pytest.raises(ValidationError, match=r"must not contain '\.\.'"):
|
||||
FilesystemPermissionSpec(operations=["read"], paths=["/foo/../bar"])
|
||||
|
||||
|
||||
def test_permission_path_tilde_raises() -> None:
|
||||
with pytest.raises(ValidationError, match="must not contain '~'"):
|
||||
FilesystemPermissionSpec(operations=["read"], paths=["/path/~expansion/secret"])
|
||||
|
||||
|
||||
def test_permission_path_glob_ok() -> None:
|
||||
"""Glob patterns like /** should not trigger the path validator."""
|
||||
spec = FilesystemPermissionSpec(operations=["read", "write"], paths=["/**"])
|
||||
assert spec.paths == ("/**",)
|
||||
|
||||
|
||||
def test_permission_mode_default_allow() -> None:
|
||||
spec = FilesystemPermissionSpec(operations=["read"], paths=["/tmp"])
|
||||
assert spec.mode == "allow"
|
||||
|
||||
|
||||
def test_permission_deny_mode() -> None:
|
||||
spec = FilesystemPermissionSpec(operations=["write"], paths=["/.env"], mode="deny")
|
||||
assert spec.mode == "deny"
|
||||
|
||||
|
||||
def test_permission_extra_field_raises() -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
FilesystemPermissionSpec(operations=["read"], paths=["/tmp"], unknown=True) # type: ignore[call-arg]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Persona.compute_hash: determinism
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_compute_hash_deterministic() -> None:
|
||||
p = Persona.model_validate(_minimal_persona_dict())
|
||||
hashes = [p.compute_hash() for _ in range(20)]
|
||||
assert len(set(hashes)) == 1
|
||||
|
||||
|
||||
def test_compute_hash_different_personas_differ() -> None:
|
||||
p1 = Persona.model_validate(_minimal_persona_dict(name="p1"))
|
||||
p2 = Persona.model_validate(_minimal_persona_dict(name="p2"))
|
||||
assert p1.compute_hash() != p2.compute_hash()
|
||||
|
||||
|
||||
def test_compute_hash_version_affects_hash() -> None:
|
||||
p1 = Persona.model_validate(_minimal_persona_dict(version=1))
|
||||
p2 = Persona.model_validate(_minimal_persona_dict(version=2))
|
||||
assert p1.compute_hash() != p2.compute_hash()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Persona: min_length, ge validators
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_persona_empty_capabilities_raises() -> None:
|
||||
data = _minimal_persona_dict(capabilities=[])
|
||||
with pytest.raises(ValidationError):
|
||||
Persona.model_validate(data)
|
||||
|
||||
|
||||
def test_persona_version_zero_raises() -> None:
|
||||
data = _minimal_persona_dict(version=0)
|
||||
with pytest.raises(ValidationError):
|
||||
Persona.model_validate(data)
|
||||
|
||||
|
||||
def test_persona_negative_max_cost_raises() -> None:
|
||||
data = _minimal_persona_dict(max_cost_per_call_usd=-0.01)
|
||||
with pytest.raises(ValidationError):
|
||||
Persona.model_validate(data)
|
||||
|
||||
|
||||
def test_persona_system_prompt_too_short_raises() -> None:
|
||||
data = _minimal_persona_dict(system_prompt="short")
|
||||
with pytest.raises(ValidationError):
|
||||
Persona.model_validate(data)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# load_persona_yaml: file not found
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_load_persona_yaml_missing_file(tmp_path: Path) -> None:
|
||||
with pytest.raises(FileNotFoundError):
|
||||
load_persona_yaml(tmp_path / "nonexistent.yaml")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# load_personas_from_dir: duplicate detection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_load_personas_from_dir_duplicate_raises(tmp_path: Path) -> None:
|
||||
import yaml
|
||||
|
||||
data = _minimal_persona_dict()
|
||||
for fname in ("persona-a@1.yaml", "persona-b@1.yaml"):
|
||||
(tmp_path / fname).write_text(yaml.dump(data), encoding="utf-8")
|
||||
|
||||
with pytest.raises(ValueError, match="duplicate persona"):
|
||||
load_personas_from_dir(tmp_path)
|
||||
|
||||
|
||||
def test_load_personas_from_dir_missing_dir() -> None:
|
||||
result = load_personas_from_dir(Path("/nonexistent_directory_xyz"))
|
||||
assert result == []
|
||||
|
||||
|
||||
def test_load_personas_from_dir_sorted_by_filename(tmp_path: Path) -> None:
|
||||
"""Files are loaded in filename order for determinism."""
|
||||
import yaml
|
||||
|
||||
for i, name in enumerate(["zz-persona", "aa-persona"]):
|
||||
data = _minimal_persona_dict(name=name, version=1)
|
||||
(tmp_path / f"{name}@1.yaml").write_text(yaml.dump(data), encoding="utf-8")
|
||||
|
||||
personas = load_personas_from_dir(tmp_path)
|
||||
assert personas[0].name == "aa-persona"
|
||||
assert personas[1].name == "zz-persona"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PersonaSubagent: extra="forbid", min_length
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_subagent_extra_field_raises() -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
PersonaSubagent(
|
||||
name="x",
|
||||
description="at least ten chars here",
|
||||
system_prompt="at least ten chars here",
|
||||
unknown_field=True, # type: ignore[call-arg]
|
||||
)
|
||||
|
||||
|
||||
def test_subagent_short_description_raises() -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
PersonaSubagent(name="x", description="short", system_prompt="at least ten chars here")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Snapshot: specific persona hashes are stable
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_default_interactive_hash_prefix() -> None:
|
||||
"""Hash of default-interactive@1 must start with 8193103c.
|
||||
|
||||
Hash updated: permissions block removed from yaml (deepagents 0.6.1 workaround).
|
||||
"""
|
||||
personas = load_personas_from_dir(PERSONAS_DIR)
|
||||
p = next(q for q in personas if q.name == "default-interactive")
|
||||
assert p.compute_hash().startswith("8193103c")
|
||||
|
||||
|
||||
def test_spec_writer_hash_prefix() -> None:
|
||||
"""Hash of openrouter-claude-spec-writer@1 must be stable."""
|
||||
personas = load_personas_from_dir(PERSONAS_DIR)
|
||||
p = next(q for q in personas if q.name == "openrouter-claude-spec-writer")
|
||||
h = p.compute_hash()
|
||||
assert len(h) == 64
|
||||
assert re.fullmatch(r"[0-9a-f]{64}", h)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Step 2 patch: null byte path rejection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_filesystem_permission_null_byte_rejected() -> None:
|
||||
"""Null bytes in a filesystem permission path must be rejected."""
|
||||
with pytest.raises(ValidationError, match="null bytes"):
|
||||
FilesystemPermissionSpec.model_validate(
|
||||
{
|
||||
"operations": ["read"],
|
||||
"paths": ["/foo\x00/bar"],
|
||||
"mode": "deny",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Deep immutability: nested list-valued fields are tuples (cannot be mutated)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_persona_capabilities_immutable() -> None:
|
||||
"""capabilities is a tuple — .append() must raise AttributeError."""
|
||||
p = Persona.model_validate(_minimal_persona_dict())
|
||||
with pytest.raises((AttributeError, TypeError)):
|
||||
p.capabilities.append(None) # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def test_persona_subagents_immutable() -> None:
|
||||
"""subagents is a tuple — .append() must raise AttributeError."""
|
||||
p = Persona.model_validate(_minimal_persona_dict())
|
||||
with pytest.raises((AttributeError, TypeError)):
|
||||
p.subagents.append(None) # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def test_persona_skills_immutable() -> None:
|
||||
"""skills is a tuple — .append() must raise AttributeError."""
|
||||
p = Persona.model_validate(_minimal_persona_dict())
|
||||
with pytest.raises((AttributeError, TypeError)):
|
||||
p.skills.append("new_skill") # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def test_filesystem_permission_paths_immutable() -> None:
|
||||
"""paths is a tuple — .append() must raise AttributeError."""
|
||||
perm = FilesystemPermissionSpec(operations=("read",), paths=("/foo",), mode="allow")
|
||||
with pytest.raises((AttributeError, TypeError)):
|
||||
perm.paths.append("/bar") # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def test_filesystem_permission_operations_immutable() -> None:
|
||||
"""operations is a tuple — .append() must raise AttributeError."""
|
||||
perm = FilesystemPermissionSpec(operations=("read",), paths=("/foo",), mode="allow")
|
||||
with pytest.raises((AttributeError, TypeError)):
|
||||
perm.operations.append("write") # type: ignore[attr-defined]
|
||||
229
my-deepagent/tests/unit/test_pricing.py
Normal file
229
my-deepagent/tests/unit/test_pricing.py
Normal file
@@ -0,0 +1,229 @@
|
||||
"""Unit tests for src/my_deepagent/monitoring/pricing.py."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import respx
|
||||
|
||||
from my_deepagent.errors import MyDeepAgentError
|
||||
from my_deepagent.monitoring.pricing import (
|
||||
ModelPrice,
|
||||
PricingCache,
|
||||
_parse_pricing_payload,
|
||||
fetch_openrouter_pricing,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _parse_pricing_payload
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_parse_valid_payload_returns_model_prices() -> None:
|
||||
data = {
|
||||
"data": [
|
||||
{
|
||||
"id": "deepseek/deepseek-chat",
|
||||
"pricing": {"prompt": "0.000001", "completion": "0.000002"},
|
||||
"context_length": 32768,
|
||||
},
|
||||
{
|
||||
"id": "anthropic/claude-sonnet",
|
||||
"pricing": {"prompt": "0.000003", "completion": "0.000015"},
|
||||
"context_length": 200000,
|
||||
},
|
||||
]
|
||||
}
|
||||
result = _parse_pricing_payload(data)
|
||||
assert len(result) == 2
|
||||
assert result[0].model == "deepseek/deepseek-chat"
|
||||
assert result[0].input_per_1k_usd == pytest.approx(0.001)
|
||||
assert result[0].output_per_1k_usd == pytest.approx(0.002)
|
||||
assert result[0].context_length == 32768
|
||||
assert result[1].model == "anthropic/claude-sonnet"
|
||||
|
||||
|
||||
def test_parse_empty_data_list_returns_empty() -> None:
|
||||
result = _parse_pricing_payload({"data": []})
|
||||
assert result == []
|
||||
|
||||
|
||||
def test_parse_data_is_not_list_returns_empty() -> None:
|
||||
# data is a dict instead of list — malformed response
|
||||
result = _parse_pricing_payload({"data": {"id": "bad"}})
|
||||
assert result == []
|
||||
|
||||
|
||||
def test_parse_missing_data_key_returns_empty() -> None:
|
||||
result = _parse_pricing_payload({})
|
||||
assert result == []
|
||||
|
||||
|
||||
def test_parse_skips_entries_without_id() -> None:
|
||||
data = {
|
||||
"data": [
|
||||
{"pricing": {"prompt": "0.000001", "completion": "0.000002"}, "context_length": 1000},
|
||||
]
|
||||
}
|
||||
result = _parse_pricing_payload(data)
|
||||
assert result == []
|
||||
|
||||
|
||||
def test_parse_skips_entries_with_invalid_pricing_values() -> None:
|
||||
data = {
|
||||
"data": [
|
||||
{
|
||||
"id": "model/x",
|
||||
"pricing": {"prompt": "not-a-number", "completion": "also-bad"},
|
||||
"context_length": 1000,
|
||||
}
|
||||
]
|
||||
}
|
||||
result = _parse_pricing_payload(data)
|
||||
assert result == []
|
||||
|
||||
|
||||
def test_parse_handles_null_pricing_gracefully() -> None:
|
||||
data = {
|
||||
"data": [
|
||||
{"id": "model/y", "pricing": None, "context_length": 0},
|
||||
]
|
||||
}
|
||||
result = _parse_pricing_payload(data)
|
||||
# pricing=None -> {} -> prompt/completion default to "0"
|
||||
assert len(result) == 1
|
||||
assert result[0].input_per_1k_usd == 0.0
|
||||
assert result[0].output_per_1k_usd == 0.0
|
||||
|
||||
|
||||
def test_parse_handles_missing_context_length() -> None:
|
||||
data = {
|
||||
"data": [
|
||||
{"id": "model/z", "pricing": {"prompt": "0.000001", "completion": "0.000002"}},
|
||||
]
|
||||
}
|
||||
result = _parse_pricing_payload(data)
|
||||
assert len(result) == 1
|
||||
assert result[0].context_length == 0
|
||||
|
||||
|
||||
def test_parse_non_dict_entry_is_skipped() -> None:
|
||||
data = {"data": ["not-a-dict", None]}
|
||||
result = _parse_pricing_payload(data)
|
||||
assert result == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PricingCache.compute_cost
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_compute_cost_known_model() -> None:
|
||||
cache = PricingCache()
|
||||
cache.set(
|
||||
[
|
||||
ModelPrice(
|
||||
model="deepseek/deepseek-chat",
|
||||
input_per_1k_usd=0.001,
|
||||
output_per_1k_usd=0.002,
|
||||
context_length=32768,
|
||||
)
|
||||
]
|
||||
)
|
||||
cost = cache.compute_cost("deepseek/deepseek-chat", input_tokens=1000, output_tokens=500)
|
||||
assert cost == pytest.approx(0.001 * 1.0 + 0.002 * 0.5)
|
||||
|
||||
|
||||
def test_compute_cost_openrouter_prefix_stripped() -> None:
|
||||
cache = PricingCache()
|
||||
cache.set(
|
||||
[
|
||||
ModelPrice(
|
||||
model="deepseek/deepseek-chat",
|
||||
input_per_1k_usd=0.001,
|
||||
output_per_1k_usd=0.002,
|
||||
context_length=32768,
|
||||
)
|
||||
]
|
||||
)
|
||||
# Should strip "openrouter:" prefix when looking up
|
||||
cost = cache.compute_cost(
|
||||
"openrouter:deepseek/deepseek-chat", input_tokens=1000, output_tokens=0
|
||||
)
|
||||
assert cost == pytest.approx(0.001)
|
||||
|
||||
|
||||
def test_compute_cost_unknown_model_returns_zero() -> None:
|
||||
cache = PricingCache()
|
||||
cost = cache.compute_cost("unknown/model", input_tokens=1000, output_tokens=1000)
|
||||
assert cost == 0.0
|
||||
|
||||
|
||||
def test_compute_cost_zero_tokens_returns_zero() -> None:
|
||||
cache = PricingCache()
|
||||
cache.set(
|
||||
[ModelPrice(model="m/x", input_per_1k_usd=1.0, output_per_1k_usd=2.0, context_length=1000)]
|
||||
)
|
||||
assert cache.compute_cost("m/x", input_tokens=0, output_tokens=0) == 0.0
|
||||
|
||||
|
||||
def test_pricing_cache_get_strips_openrouter_prefix() -> None:
|
||||
cache = PricingCache()
|
||||
cache.set(
|
||||
[ModelPrice(model="a/b", input_per_1k_usd=0.5, output_per_1k_usd=1.0, context_length=0)]
|
||||
)
|
||||
assert cache.get("openrouter:a/b") is not None
|
||||
assert cache.get("a/b") is not None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# fetch_openrouter_pricing (respx mock)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_openrouter_pricing_success() -> None:
|
||||
payload = {
|
||||
"data": [
|
||||
{
|
||||
"id": "deepseek/deepseek-chat",
|
||||
"pricing": {"prompt": "0.000001", "completion": "0.000002"},
|
||||
"context_length": 64000,
|
||||
}
|
||||
]
|
||||
}
|
||||
with respx.mock:
|
||||
respx.get("https://openrouter.ai/api/v1/models").mock(
|
||||
return_value=httpx.Response(200, json=payload)
|
||||
)
|
||||
result = await fetch_openrouter_pricing(
|
||||
api_key="sk-or-test", base_url="https://openrouter.ai/api/v1"
|
||||
)
|
||||
assert len(result) == 1
|
||||
assert result[0].model == "deepseek/deepseek-chat"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_openrouter_pricing_http_error_raises_recoverable() -> None:
|
||||
with respx.mock:
|
||||
respx.get("https://openrouter.ai/api/v1/models").mock(
|
||||
return_value=httpx.Response(401, json={"error": "unauthorized"})
|
||||
)
|
||||
with pytest.raises(MyDeepAgentError) as exc_info:
|
||||
await fetch_openrouter_pricing(
|
||||
api_key="bad-key", base_url="https://openrouter.ai/api/v1"
|
||||
)
|
||||
assert exc_info.value.code == "network_blip"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_openrouter_pricing_connect_error_raises_recoverable() -> None:
|
||||
with respx.mock:
|
||||
respx.get("https://openrouter.ai/api/v1/models").mock(
|
||||
side_effect=httpx.ConnectError("connection refused")
|
||||
)
|
||||
with pytest.raises(MyDeepAgentError) as exc_info:
|
||||
await fetch_openrouter_pricing(
|
||||
api_key="sk-or-test", base_url="https://openrouter.ai/api/v1"
|
||||
)
|
||||
assert exc_info.value.code == "network_blip"
|
||||
454
my-deepagent/tests/unit/test_session.py
Normal file
454
my-deepagent/tests/unit/test_session.py
Normal file
@@ -0,0 +1,454 @@
|
||||
"""Unit tests for src/my_deepagent/session.py.
|
||||
|
||||
Tests verify the dataclass-based deepagents API (FilesystemPermission attributes,
|
||||
build_backend backend type dispatch, _map_operations deduplication, etc.).
|
||||
No real API calls are made.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from deepagents import FilesystemPermission
|
||||
from deepagents.backends import (
|
||||
CompositeBackend,
|
||||
FilesystemBackend,
|
||||
LocalShellBackend,
|
||||
)
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langgraph.graph.state import CompiledStateGraph
|
||||
|
||||
from my_deepagent.config import load_config
|
||||
from my_deepagent.errors import MyDeepAgentError
|
||||
from my_deepagent.persona import FilesystemPermissionSpec, Persona, PersonaSubagent
|
||||
from my_deepagent.session import (
|
||||
_map_operations,
|
||||
_resolve_openrouter_api_key,
|
||||
_spec_to_permission,
|
||||
_subagent_to_dict,
|
||||
build_agent,
|
||||
build_backend,
|
||||
default_safety_permissions,
|
||||
resolve_model_instance,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _minimal_persona(**overrides: Any) -> Persona:
|
||||
base: dict[str, Any] = {
|
||||
"name": "test-persona",
|
||||
"version": 1,
|
||||
"backend": "openrouter",
|
||||
"model": "openrouter:anthropic/claude-sonnet-4-6",
|
||||
"provider_origin": "US/Anthropic",
|
||||
"capabilities": ["spec_write"],
|
||||
"max_risk_level": "low",
|
||||
"system_prompt": "You are a test assistant for unit tests.",
|
||||
}
|
||||
base.update(overrides)
|
||||
return Persona.model_validate(base)
|
||||
|
||||
|
||||
def _minimal_permission_spec(
|
||||
operations: list[str] | None = None,
|
||||
paths: list[str] | None = None,
|
||||
mode: str = "allow",
|
||||
) -> FilesystemPermissionSpec:
|
||||
return FilesystemPermissionSpec(
|
||||
operations=tuple(operations or ["read"]),
|
||||
paths=tuple(paths or ["/**"]),
|
||||
mode=mode, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
|
||||
def _minimal_subagent(**overrides: Any) -> PersonaSubagent:
|
||||
base: dict[str, Any] = {
|
||||
"name": "test-sub",
|
||||
"description": "A test subagent description.",
|
||||
"system_prompt": "You are a subagent for unit tests.",
|
||||
}
|
||||
base.update(overrides)
|
||||
return PersonaSubagent.model_validate(base)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# default_safety_permissions — dataclass attribute access
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_default_safety_permissions_returns_two_entries() -> None:
|
||||
perms = default_safety_permissions()
|
||||
assert len(perms) == 2
|
||||
|
||||
|
||||
def test_default_safety_permissions_returns_filesystem_permission_instances() -> None:
|
||||
perms = default_safety_permissions()
|
||||
for p in perms:
|
||||
assert isinstance(p, FilesystemPermission)
|
||||
|
||||
|
||||
def test_default_safety_permissions_allow_is_first() -> None:
|
||||
perms = default_safety_permissions()
|
||||
assert perms[0].mode == "allow"
|
||||
assert "/**" in perms[0].paths
|
||||
|
||||
|
||||
def test_default_safety_permissions_allow_has_both_operations() -> None:
|
||||
perms = default_safety_permissions()
|
||||
assert "read" in perms[0].operations
|
||||
assert "write" in perms[0].operations
|
||||
|
||||
|
||||
def test_default_safety_permissions_deny_is_second() -> None:
|
||||
perms = default_safety_permissions()
|
||||
assert perms[1].mode == "deny"
|
||||
deny_paths = perms[1].paths
|
||||
assert any("env" in p for p in deny_paths)
|
||||
assert any("ssh" in p for p in deny_paths)
|
||||
|
||||
|
||||
def test_default_safety_permissions_deny_covers_secrets() -> None:
|
||||
perms = default_safety_permissions()
|
||||
deny_paths = perms[1].paths
|
||||
assert any("secret" in p for p in deny_paths)
|
||||
assert any("token" in p for p in deny_paths)
|
||||
assert any("pem" in p for p in deny_paths)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _map_operations — 8 케이스
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_map_operations_read() -> None:
|
||||
assert _map_operations(("read",)) == ["read"]
|
||||
|
||||
|
||||
def test_map_operations_write() -> None:
|
||||
assert _map_operations(("write",)) == ["write"]
|
||||
|
||||
|
||||
def test_map_operations_edit_maps_to_write() -> None:
|
||||
assert _map_operations(("edit",)) == ["write"]
|
||||
|
||||
|
||||
def test_map_operations_ls_maps_to_read() -> None:
|
||||
assert _map_operations(("ls",)) == ["read"]
|
||||
|
||||
|
||||
def test_map_operations_deduplicates_all_four() -> None:
|
||||
result = _map_operations(("read", "write", "edit", "ls"))
|
||||
assert result == ["read", "write"]
|
||||
|
||||
|
||||
def test_map_operations_ls_and_edit() -> None:
|
||||
assert _map_operations(("ls", "edit")) == ["read", "write"]
|
||||
|
||||
|
||||
def test_map_operations_preserves_order_write_then_read() -> None:
|
||||
result = _map_operations(("write", "read"))
|
||||
assert result == ["write", "read"]
|
||||
|
||||
|
||||
def test_map_operations_empty_returns_empty() -> None:
|
||||
assert _map_operations(()) == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _spec_to_permission — dataclass attribute + mapping
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_spec_to_permission_returns_filesystem_permission() -> None:
|
||||
spec = _minimal_permission_spec(operations=["read"], paths=["/**"], mode="allow")
|
||||
result = _spec_to_permission(spec)
|
||||
assert isinstance(result, FilesystemPermission)
|
||||
|
||||
|
||||
def test_spec_to_permission_maps_read_write_correctly() -> None:
|
||||
spec = _minimal_permission_spec(operations=["read", "write"], paths=["/**"], mode="allow")
|
||||
result = _spec_to_permission(spec)
|
||||
assert result.operations == ["read", "write"]
|
||||
assert result.paths == ["/**"]
|
||||
assert result.mode == "allow"
|
||||
|
||||
|
||||
def test_spec_to_permission_maps_edit_to_write() -> None:
|
||||
spec = _minimal_permission_spec(operations=["edit"], paths=["/src/**"], mode="allow")
|
||||
result = _spec_to_permission(spec)
|
||||
assert result.operations == ["write"]
|
||||
|
||||
|
||||
def test_spec_to_permission_maps_ls_to_read() -> None:
|
||||
spec = _minimal_permission_spec(operations=["ls"], paths=["/data/**"], mode="allow")
|
||||
result = _spec_to_permission(spec)
|
||||
assert result.operations == ["read"]
|
||||
|
||||
|
||||
def test_spec_to_permission_deduplicates_read_edit_ls() -> None:
|
||||
spec = _minimal_permission_spec(
|
||||
operations=["read", "edit", "ls"], paths=["/workspace/**"], mode="allow"
|
||||
)
|
||||
result = _spec_to_permission(spec)
|
||||
# read=read, edit=write, ls=read → ["read", "write"]
|
||||
assert result.operations == ["read", "write"]
|
||||
|
||||
|
||||
def test_spec_to_permission_deny_mode_passthrough() -> None:
|
||||
spec = _minimal_permission_spec(operations=["read"], paths=["/.env*"], mode="deny")
|
||||
result = _spec_to_permission(spec)
|
||||
assert result.mode == "deny"
|
||||
assert "/.env*" in result.paths
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _subagent_to_dict
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_subagent_to_dict_required_fields() -> None:
|
||||
sub = _minimal_subagent()
|
||||
d = _subagent_to_dict(sub)
|
||||
assert d["name"] == "test-sub"
|
||||
assert d["description"] == "A test subagent description."
|
||||
assert d["system_prompt"] == "You are a subagent for unit tests."
|
||||
|
||||
|
||||
def test_subagent_to_dict_optional_tools_included_when_set() -> None:
|
||||
sub = _minimal_subagent(allowed_tools=["read_file", "write_file"])
|
||||
d = _subagent_to_dict(sub)
|
||||
assert "tools" in d
|
||||
assert d["tools"] == ["read_file", "write_file"]
|
||||
|
||||
|
||||
def test_subagent_to_dict_no_tools_key_when_empty() -> None:
|
||||
sub = _minimal_subagent()
|
||||
d = _subagent_to_dict(sub)
|
||||
assert "tools" not in d
|
||||
|
||||
|
||||
def test_subagent_to_dict_optional_model_included_when_set() -> None:
|
||||
sub = _minimal_subagent(model="openrouter:deepseek/deepseek-chat")
|
||||
d = _subagent_to_dict(sub)
|
||||
assert "model" in d
|
||||
assert d["model"] == "openrouter:deepseek/deepseek-chat"
|
||||
|
||||
|
||||
def test_subagent_to_dict_no_model_key_when_none() -> None:
|
||||
sub = _minimal_subagent()
|
||||
d = _subagent_to_dict(sub)
|
||||
assert "model" not in d
|
||||
|
||||
|
||||
def test_subagent_to_dict_permissions_included_when_set() -> None:
|
||||
sub = _minimal_subagent(
|
||||
permissions=[{"operations": ["read"], "paths": ["/**"], "mode": "allow"}]
|
||||
)
|
||||
d = _subagent_to_dict(sub)
|
||||
assert "permissions" in d
|
||||
assert len(d["permissions"]) == 1
|
||||
# permissions 안의 항목도 FilesystemPermission 인스턴스
|
||||
assert isinstance(d["permissions"][0], FilesystemPermission)
|
||||
|
||||
|
||||
def test_subagent_to_dict_permissions_empty_not_included() -> None:
|
||||
sub = _minimal_subagent()
|
||||
d = _subagent_to_dict(sub)
|
||||
assert "permissions" not in d
|
||||
|
||||
|
||||
def test_subagent_to_dict_interrupt_on_included_when_set() -> None:
|
||||
sub = _minimal_subagent(interrupt_on={"write_file": {"allowed_decisions": ["approve"]}})
|
||||
d = _subagent_to_dict(sub)
|
||||
assert "interrupt_on" in d
|
||||
|
||||
|
||||
def test_subagent_to_dict_no_interrupt_on_when_empty() -> None:
|
||||
sub = _minimal_subagent()
|
||||
d = _subagent_to_dict(sub)
|
||||
assert "interrupt_on" not in d
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _resolve_openrouter_api_key
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_resolve_api_key_from_config() -> None:
|
||||
config = load_config(openrouter_api_key="sk-or-from-config")
|
||||
key = _resolve_openrouter_api_key(config)
|
||||
assert key == "sk-or-from-config"
|
||||
|
||||
|
||||
def test_resolve_api_key_from_mydeepagent_env(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.delenv("MYDEEPAGENT_OPENROUTER_API_KEY", raising=False)
|
||||
monkeypatch.delenv("OPENROUTER_API_KEY", raising=False)
|
||||
monkeypatch.setenv("MYDEEPAGENT_OPENROUTER_API_KEY", "sk-or-env-mydeepagent")
|
||||
config = load_config(openrouter_api_key=None)
|
||||
key = _resolve_openrouter_api_key(config)
|
||||
assert key == "sk-or-env-mydeepagent"
|
||||
|
||||
|
||||
def test_resolve_api_key_fallback_to_openrouter_env(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.delenv("MYDEEPAGENT_OPENROUTER_API_KEY", raising=False)
|
||||
monkeypatch.delenv("OPENROUTER_API_KEY", raising=False)
|
||||
monkeypatch.setenv("OPENROUTER_API_KEY", "sk-or-env-fallback")
|
||||
config = load_config(openrouter_api_key=None)
|
||||
key = _resolve_openrouter_api_key(config)
|
||||
assert key == "sk-or-env-fallback"
|
||||
|
||||
|
||||
def test_resolve_api_key_raises_when_missing(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.delenv("MYDEEPAGENT_OPENROUTER_API_KEY", raising=False)
|
||||
monkeypatch.delenv("OPENROUTER_API_KEY", raising=False)
|
||||
config = load_config(openrouter_api_key=None)
|
||||
with pytest.raises(MyDeepAgentError) as exc_info:
|
||||
_resolve_openrouter_api_key(config)
|
||||
assert exc_info.value.code == "backend_auth_failed"
|
||||
|
||||
|
||||
def test_resolve_api_key_config_takes_priority_over_env(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setenv("MYDEEPAGENT_OPENROUTER_API_KEY", "sk-or-env")
|
||||
config = load_config(openrouter_api_key="sk-or-config-wins")
|
||||
key = _resolve_openrouter_api_key(config)
|
||||
assert key == "sk-or-config-wins"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# resolve_model_instance
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_resolve_model_openrouter_returns_chat_openai() -> None:
|
||||
config = load_config(openrouter_api_key="sk-or-test")
|
||||
persona = _minimal_persona(model="openrouter:anthropic/claude-sonnet-4-6")
|
||||
instance = resolve_model_instance(persona, config)
|
||||
assert isinstance(instance, ChatOpenAI)
|
||||
assert instance.openai_api_base == config.openrouter_base_url
|
||||
|
||||
|
||||
def test_resolve_model_openrouter_uses_model_params() -> None:
|
||||
config = load_config(openrouter_api_key="sk-or-test")
|
||||
persona = _minimal_persona(
|
||||
model="openrouter:anthropic/claude-sonnet-4-6",
|
||||
model_params={"max_tokens": 1024, "temperature": 0.5},
|
||||
)
|
||||
instance = resolve_model_instance(persona, config)
|
||||
assert isinstance(instance, ChatOpenAI)
|
||||
assert instance.max_tokens == 1024
|
||||
|
||||
|
||||
def test_resolve_model_non_openrouter_returns_string() -> None:
|
||||
config = load_config()
|
||||
persona = _minimal_persona(
|
||||
backend="anthropic",
|
||||
model="anthropic:claude-3-5-sonnet-20241022",
|
||||
)
|
||||
result = resolve_model_instance(persona, config)
|
||||
assert isinstance(result, str)
|
||||
assert result == "anthropic:claude-3-5-sonnet-20241022"
|
||||
|
||||
|
||||
def test_resolve_model_with_override_openrouter() -> None:
|
||||
config = load_config(openrouter_api_key="sk-or-test")
|
||||
persona = _minimal_persona(model="openrouter:anthropic/claude-sonnet-4-6")
|
||||
instance = resolve_model_instance(
|
||||
persona, config, model_override="openrouter:deepseek/deepseek-chat"
|
||||
)
|
||||
assert isinstance(instance, ChatOpenAI)
|
||||
assert "deepseek-chat" in instance.model_name
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# build_backend — 5 케이스
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_build_backend_local_shell(tmp_path: Path) -> None:
|
||||
persona = _minimal_persona(deepagents_backend="local_shell")
|
||||
result = build_backend(persona, tmp_path)
|
||||
assert isinstance(result, LocalShellBackend)
|
||||
|
||||
|
||||
def test_build_backend_filesystem(tmp_path: Path) -> None:
|
||||
persona = _minimal_persona(deepagents_backend="filesystem")
|
||||
result = build_backend(persona, tmp_path)
|
||||
assert isinstance(result, FilesystemBackend)
|
||||
|
||||
|
||||
def test_build_backend_state_returns_none(tmp_path: Path) -> None:
|
||||
persona = _minimal_persona(deepagents_backend="state")
|
||||
result = build_backend(persona, tmp_path)
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_build_backend_composite(tmp_path: Path) -> None:
|
||||
persona = _minimal_persona(deepagents_backend="composite")
|
||||
result = build_backend(persona, tmp_path)
|
||||
assert isinstance(result, CompositeBackend)
|
||||
|
||||
|
||||
def test_build_backend_langsmith_raises_config_invalid(tmp_path: Path) -> None:
|
||||
persona = _minimal_persona(deepagents_backend="langsmith")
|
||||
with pytest.raises(MyDeepAgentError) as exc_info:
|
||||
build_backend(persona, tmp_path)
|
||||
assert exc_info.value.code == "config_invalid"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# build_agent
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_build_agent_returns_compiled_state_graph(tmp_path: Path) -> None:
|
||||
"""build_agent should construct a CompiledStateGraph without calling the LLM API."""
|
||||
config = load_config(openrouter_api_key="sk-or-test")
|
||||
persona = _minimal_persona(deepagents_backend="state")
|
||||
graph = build_agent(persona, config, root_dir=tmp_path)
|
||||
assert isinstance(graph, CompiledStateGraph)
|
||||
assert hasattr(graph, "invoke")
|
||||
assert hasattr(graph, "ainvoke")
|
||||
|
||||
|
||||
def test_build_agent_with_middleware_list(tmp_path: Path) -> None:
|
||||
"""Extra middleware is accepted without error.
|
||||
|
||||
build_agent automatically prepends SafetyShellMiddleware. Callers should pass
|
||||
*other* middleware here; passing a second SafetyShellMiddleware would hit
|
||||
deepagents' duplicate-name guard.
|
||||
"""
|
||||
from my_deepagent.middleware.audit import AuditToolMiddleware
|
||||
|
||||
config = load_config(openrouter_api_key="sk-or-test")
|
||||
persona = _minimal_persona(deepagents_backend="state")
|
||||
graph = build_agent(
|
||||
persona,
|
||||
config,
|
||||
root_dir=tmp_path,
|
||||
middleware=[AuditToolMiddleware()],
|
||||
)
|
||||
assert isinstance(graph, CompiledStateGraph)
|
||||
|
||||
|
||||
def test_build_agent_filesystem_backend(tmp_path: Path) -> None:
|
||||
"""build_agent works with filesystem backend."""
|
||||
config = load_config(openrouter_api_key="sk-or-test")
|
||||
persona = _minimal_persona(deepagents_backend="filesystem")
|
||||
graph = build_agent(persona, config, root_dir=tmp_path)
|
||||
assert isinstance(graph, CompiledStateGraph)
|
||||
|
||||
|
||||
def test_build_agent_with_persona_permissions(tmp_path: Path) -> None:
|
||||
"""build_agent merges persona permissions with default safety permissions."""
|
||||
config = load_config(openrouter_api_key="sk-or-test")
|
||||
persona = _minimal_persona(
|
||||
deepagents_backend="state",
|
||||
permissions=[{"operations": ["read"], "paths": ["/workspace/**"], "mode": "allow"}],
|
||||
)
|
||||
graph = build_agent(persona, config, root_dir=tmp_path)
|
||||
assert isinstance(graph, CompiledStateGraph)
|
||||
55
my-deepagent/tests/unit/test_session_seed_integration.py
Normal file
55
my-deepagent/tests/unit/test_session_seed_integration.py
Normal file
@@ -0,0 +1,55 @@
|
||||
"""Seed persona integration tests for session.py model resolution."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
from my_deepagent.config import load_config
|
||||
from my_deepagent.enums import Backend
|
||||
from my_deepagent.persona import load_personas_from_dir
|
||||
from my_deepagent.session import resolve_model_instance
|
||||
|
||||
PERSONAS_DIR = Path(__file__).parent.parent.parent / "docs" / "schemas" / "personas"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def seed_personas() -> list: # type: ignore[type-arg]
|
||||
return load_personas_from_dir(PERSONAS_DIR)
|
||||
|
||||
|
||||
def test_resolve_model_instance_seed_personas(seed_personas: list) -> None: # type: ignore[type-arg]
|
||||
"""resolve_model_instance should return ChatOpenAI for openrouter personas, str otherwise."""
|
||||
config = load_config(openrouter_api_key="sk-or-dummy")
|
||||
for persona in seed_personas:
|
||||
instance = resolve_model_instance(persona, config)
|
||||
if persona.backend == Backend.OPENROUTER:
|
||||
assert isinstance(instance, ChatOpenAI), (
|
||||
f"persona {persona.name!r} with backend=openrouter should return ChatOpenAI, "
|
||||
f"got {type(instance)}"
|
||||
)
|
||||
# base_url should point to openrouter
|
||||
assert instance.openai_api_base is not None
|
||||
base = instance.openai_api_base
|
||||
assert "openrouter" in base or base == config.openrouter_base_url
|
||||
else:
|
||||
assert isinstance(instance, str), (
|
||||
f"persona {persona.name!r} with backend={persona.backend} should return str, "
|
||||
f"got {type(instance)}"
|
||||
)
|
||||
|
||||
|
||||
def test_all_seed_personas_have_non_empty_model(seed_personas: list) -> None: # type: ignore[type-arg]
|
||||
for persona in seed_personas:
|
||||
assert persona.model, f"persona {persona.name!r} has empty model"
|
||||
|
||||
|
||||
def test_all_openrouter_seed_personas_have_openrouter_prefix(seed_personas: list) -> None: # type: ignore[type-arg]
|
||||
for persona in seed_personas:
|
||||
if persona.backend == Backend.OPENROUTER:
|
||||
assert persona.model.startswith("openrouter:"), (
|
||||
f"persona {persona.name!r} has backend=openrouter but model={persona.model!r} "
|
||||
"does not start with 'openrouter:'"
|
||||
)
|
||||
335
my-deepagent/tests/unit/test_workflow.py
Normal file
335
my-deepagent/tests/unit/test_workflow.py
Normal file
@@ -0,0 +1,335 @@
|
||||
"""Unit tests for src/my_deepagent/workflow.py."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from my_deepagent.workflow import (
|
||||
ExpectedArtifact,
|
||||
WorkflowTemplate,
|
||||
load_workflow_yaml,
|
||||
load_workflows_from_dir,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
WORKFLOWS_DIR = Path(__file__).parent.parent.parent / "docs" / "schemas" / "workflows"
|
||||
|
||||
|
||||
def _minimal_role(**overrides: object) -> dict[str, object]:
|
||||
base: dict[str, object] = {
|
||||
"id": "spec_writer",
|
||||
"required_capabilities": ["spec_write"],
|
||||
}
|
||||
base.update(overrides)
|
||||
return base
|
||||
|
||||
|
||||
def _minimal_phase(**overrides: object) -> dict[str, object]:
|
||||
base: dict[str, object] = {
|
||||
"key": "spec",
|
||||
"title": "Write spec",
|
||||
"risk": "low",
|
||||
"role": "spec_writer",
|
||||
"instructions": "Write the specification document for the feature.",
|
||||
}
|
||||
base.update(overrides)
|
||||
return base
|
||||
|
||||
|
||||
def _minimal_template(**overrides: object) -> dict[str, object]:
|
||||
base: dict[str, object] = {
|
||||
"name": "test-workflow",
|
||||
"version": 1,
|
||||
"roles": [_minimal_role()],
|
||||
"phases": [_minimal_phase()],
|
||||
}
|
||||
base.update(overrides)
|
||||
return base
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Seed yaml: all 3 load successfully
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_all_seed_workflows_load() -> None:
|
||||
workflows = load_workflows_from_dir(WORKFLOWS_DIR)
|
||||
assert len(workflows) == 3
|
||||
|
||||
|
||||
def test_seed_workflow_names() -> None:
|
||||
workflows = load_workflows_from_dir(WORKFLOWS_DIR)
|
||||
names = {w.name for w in workflows}
|
||||
assert names == {"spec-and-review", "bug-fix-with-reproduction", "code-investigation"}
|
||||
|
||||
|
||||
def test_seed_workflow_roles_non_empty() -> None:
|
||||
workflows = load_workflows_from_dir(WORKFLOWS_DIR)
|
||||
for w in workflows:
|
||||
assert len(w.roles) >= 1
|
||||
|
||||
|
||||
def test_seed_workflow_phases_non_empty() -> None:
|
||||
workflows = load_workflows_from_dir(WORKFLOWS_DIR)
|
||||
for w in workflows:
|
||||
assert len(w.phases) >= 1
|
||||
|
||||
|
||||
def test_seed_workflow_phase_keys_unique() -> None:
|
||||
workflows = load_workflows_from_dir(WORKFLOWS_DIR)
|
||||
for w in workflows:
|
||||
keys = [ph.key for ph in w.phases]
|
||||
assert len(keys) == len(set(keys)), f"{w.name}: duplicate phase keys"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# WorkflowTemplate validators
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_phase_references_undefined_role_raises() -> None:
|
||||
data = _minimal_template(
|
||||
roles=[_minimal_role(id="spec_writer")],
|
||||
phases=[_minimal_phase(role="nonexistent_role")],
|
||||
)
|
||||
with pytest.raises(ValidationError, match="unknown role"):
|
||||
WorkflowTemplate.model_validate(data)
|
||||
|
||||
|
||||
def test_duplicate_phase_keys_raises() -> None:
|
||||
data = _minimal_template(
|
||||
roles=[_minimal_role(id="spec_writer")],
|
||||
phases=[
|
||||
_minimal_phase(key="spec"),
|
||||
_minimal_phase(key="spec"),
|
||||
],
|
||||
)
|
||||
with pytest.raises(ValidationError, match="duplicate phase keys"):
|
||||
WorkflowTemplate.model_validate(data)
|
||||
|
||||
|
||||
def test_duplicate_role_ids_raises() -> None:
|
||||
data = _minimal_template(
|
||||
roles=[_minimal_role(id="spec_writer"), _minimal_role(id="spec_writer")],
|
||||
phases=[_minimal_phase(role="spec_writer")],
|
||||
)
|
||||
with pytest.raises(ValidationError, match="duplicate role ids"):
|
||||
WorkflowTemplate.model_validate(data)
|
||||
|
||||
|
||||
def test_phase_key_uppercase_raises() -> None:
|
||||
data = _minimal_template(phases=[_minimal_phase(key="SPEC")])
|
||||
with pytest.raises(ValidationError):
|
||||
WorkflowTemplate.model_validate(data)
|
||||
|
||||
|
||||
def test_phase_key_with_hyphen_raises() -> None:
|
||||
"""Hyphens are not allowed in phase keys (only a-z, 0-9, _)."""
|
||||
data = _minimal_template(phases=[_minimal_phase(key="spec-one")])
|
||||
with pytest.raises(ValidationError):
|
||||
WorkflowTemplate.model_validate(data)
|
||||
|
||||
|
||||
def test_phase_key_leading_digit_raises() -> None:
|
||||
data = _minimal_template(phases=[_minimal_phase(key="1spec")])
|
||||
with pytest.raises(ValidationError):
|
||||
WorkflowTemplate.model_validate(data)
|
||||
|
||||
|
||||
def test_phase_key_snake_case_ok() -> None:
|
||||
data = _minimal_template(phases=[_minimal_phase(key="spec_write_phase")])
|
||||
wt = WorkflowTemplate.model_validate(data)
|
||||
assert wt.phases[0].key == "spec_write_phase"
|
||||
|
||||
|
||||
def test_role_id_pattern_invalid_raises() -> None:
|
||||
data = _minimal_template(
|
||||
roles=[_minimal_role(id="Spec-Writer")],
|
||||
phases=[_minimal_phase(role="spec_writer")],
|
||||
)
|
||||
with pytest.raises(ValidationError):
|
||||
WorkflowTemplate.model_validate(data)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ExpectedArtifact: alias mapping
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_expected_artifact_schema_alias() -> None:
|
||||
"""yaml uses 'schema' key; Python attribute is schema_id."""
|
||||
art = ExpectedArtifact.model_validate({"path": "artifacts/spec.json", "schema": "dev/spec@1"})
|
||||
assert art.schema_id == "dev/spec@1"
|
||||
assert art.path == "artifacts/spec.json"
|
||||
|
||||
|
||||
def test_expected_artifact_extra_field_raises() -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
ExpectedArtifact.model_validate({"path": "x.json", "schema": "dev/spec@1", "unknown": True})
|
||||
|
||||
|
||||
def test_expected_artifact_missing_schema_raises() -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
ExpectedArtifact.model_validate({"path": "x.json"})
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# WorkflowTemplate frozen + extra="forbid"
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_template_frozen() -> None:
|
||||
wt = WorkflowTemplate.model_validate(_minimal_template())
|
||||
with pytest.raises((TypeError, ValidationError)):
|
||||
wt.name = "mutated" # type: ignore[misc]
|
||||
|
||||
|
||||
def test_template_extra_field_raises() -> None:
|
||||
data = _minimal_template(extra_unknown_field="oops")
|
||||
with pytest.raises(ValidationError):
|
||||
WorkflowTemplate.model_validate(data)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# compute_hash: determinism
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_compute_hash_deterministic() -> None:
|
||||
wt = WorkflowTemplate.model_validate(_minimal_template())
|
||||
hashes = [wt.compute_hash() for _ in range(20)]
|
||||
assert len(set(hashes)) == 1
|
||||
|
||||
|
||||
def test_compute_hash_returns_64_char_hex() -> None:
|
||||
wt = WorkflowTemplate.model_validate(_minimal_template())
|
||||
h = wt.compute_hash()
|
||||
assert re.fullmatch(r"[0-9a-f]{64}", h)
|
||||
|
||||
|
||||
def test_compute_hash_different_templates_differ() -> None:
|
||||
wt1 = WorkflowTemplate.model_validate(_minimal_template(name="wf1"))
|
||||
wt2 = WorkflowTemplate.model_validate(_minimal_template(name="wf2"))
|
||||
assert wt1.compute_hash() != wt2.compute_hash()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# load_workflow_yaml: file not found
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_load_workflow_yaml_missing_file(tmp_path: Path) -> None:
|
||||
with pytest.raises(FileNotFoundError):
|
||||
load_workflow_yaml(tmp_path / "no.yaml")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# load_workflows_from_dir: duplicate detection + missing dir
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_load_workflows_from_dir_duplicate_raises(tmp_path: Path) -> None:
|
||||
import yaml
|
||||
|
||||
data = _minimal_template()
|
||||
for fname in ("wf-a@1.yaml", "wf-b@1.yaml"):
|
||||
(tmp_path / fname).write_text(yaml.dump(data), encoding="utf-8")
|
||||
|
||||
with pytest.raises(ValueError, match="duplicate workflow"):
|
||||
load_workflows_from_dir(tmp_path)
|
||||
|
||||
|
||||
def test_load_workflows_from_dir_missing_dir() -> None:
|
||||
result = load_workflows_from_dir(Path("/nonexistent_wf_dir_xyz"))
|
||||
assert result == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Snapshot: seed hashes are stable
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_spec_and_review_hash_prefix() -> None:
|
||||
workflows = load_workflows_from_dir(WORKFLOWS_DIR)
|
||||
w = next(x for x in workflows if x.name == "spec-and-review")
|
||||
assert w.compute_hash().startswith("1c94587647b16f0d")
|
||||
|
||||
|
||||
def test_bug_fix_hash_prefix() -> None:
|
||||
workflows = load_workflows_from_dir(WORKFLOWS_DIR)
|
||||
w = next(x for x in workflows if x.name == "bug-fix-with-reproduction")
|
||||
assert w.compute_hash().startswith("a137c9656f10e88a")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Step 2 patch: Counter-based duplicate role ids report is sorted
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_workflow_duplicate_role_ids_reported_sorted() -> None:
|
||||
"""Multiple duplicated role ids must be reported in sorted order."""
|
||||
with pytest.raises(ValidationError, match=r"duplicate role ids: \['a', 'b'\]"):
|
||||
WorkflowTemplate.model_validate(
|
||||
{
|
||||
"name": "x",
|
||||
"version": 1,
|
||||
"roles": [
|
||||
{"id": "b", "required_capabilities": ["spec_write"]},
|
||||
{"id": "a", "required_capabilities": ["spec_write"]},
|
||||
{"id": "a", "required_capabilities": ["spec_write"]},
|
||||
{"id": "b", "required_capabilities": ["spec_write"]},
|
||||
],
|
||||
"phases": [
|
||||
{
|
||||
"key": "x",
|
||||
"title": "x",
|
||||
"risk": "low",
|
||||
"role": "a",
|
||||
"instructions": "x" * 20,
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def test_code_investigation_hash_prefix() -> None:
|
||||
workflows = load_workflows_from_dir(WORKFLOWS_DIR)
|
||||
w = next(x for x in workflows if x.name == "code-investigation")
|
||||
assert w.compute_hash().startswith("5b80ea2e248d5232")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Deep immutability: nested list-valued fields are tuples (cannot be mutated)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_workflow_phases_immutable() -> None:
|
||||
"""phases is a tuple — .append() must raise AttributeError."""
|
||||
wt = WorkflowTemplate.model_validate(_minimal_template())
|
||||
with pytest.raises((AttributeError, TypeError)):
|
||||
wt.phases.append(None) # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def test_workflow_roles_immutable() -> None:
|
||||
"""roles is a tuple — .append() must raise AttributeError."""
|
||||
wt = WorkflowTemplate.model_validate(_minimal_template())
|
||||
with pytest.raises((AttributeError, TypeError)):
|
||||
wt.roles.append(None) # type: ignore[attr-defined]
|
||||
|
||||
|
||||
def test_workflow_role_required_capabilities_immutable() -> None:
|
||||
"""required_capabilities is a tuple — .append() must raise AttributeError."""
|
||||
from my_deepagent.workflow import WorkflowRole
|
||||
|
||||
role = WorkflowRole.model_validate(
|
||||
{"id": "spec_writer", "required_capabilities": ["spec_write"]}
|
||||
)
|
||||
with pytest.raises((AttributeError, TypeError)):
|
||||
role.required_capabilities.append(None) # type: ignore[attr-defined]
|
||||
Reference in New Issue
Block a user