"""Unit tests for src/my_deepagent/monitoring/cost_estimator.py.""" from __future__ import annotations from unittest.mock import MagicMock import pytest from my_deepagent.monitoring.cost_estimator import ( _DEFAULT_INPUT_TOKENS, _DEFAULT_OUTPUT_TOKENS, PhaseCostEstimate, WorkflowCostEstimate, estimate_phase, estimate_workflow, ) from my_deepagent.monitoring.pricing import ModelPrice, PricingCache # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def _make_pricing(model: str = "anthropic/claude-sonnet-4-6") -> PricingCache: cache = PricingCache() cache.set( [ ModelPrice( model=model, input_per_1k_usd=0.003, output_per_1k_usd=0.015, context_length=200000, ) ] ) return cache def _make_persona( model: str = "anthropic/claude-sonnet-4-6", max_tokens: int | None = None, ) -> object: p = MagicMock() p.name = "test-persona" p.version = 1 p.model = model p.model_params = {"max_tokens": max_tokens} if max_tokens else {} return p def _make_phase(key: str = "spec") -> MagicMock: phase = MagicMock() phase.key = key return phase def _make_binding(persona: object) -> MagicMock: b = MagicMock() b.persona = persona return b # --------------------------------------------------------------------------- # estimate_phase # --------------------------------------------------------------------------- def test_estimate_phase_known_model_correct_cost() -> None: pricing = _make_pricing("anthropic/claude-sonnet-4-6") persona = _make_persona("anthropic/claude-sonnet-4-6") phase = _make_phase("spec") est = estimate_phase(phase, persona, pricing) # type: ignore[arg-type] expected_cost = _DEFAULT_INPUT_TOKENS / 1000.0 * 0.003 + _DEFAULT_OUTPUT_TOKENS / 1000.0 * 0.015 assert isinstance(est, PhaseCostEstimate) assert est.phase_key == "spec" assert est.persona_name == "test-persona@1" assert est.model == "anthropic/claude-sonnet-4-6" assert est.estimated_input_tokens == _DEFAULT_INPUT_TOKENS assert est.estimated_output_tokens == _DEFAULT_OUTPUT_TOKENS assert est.estimated_cost_usd == pytest.approx(expected_cost) def test_estimate_phase_unknown_model_returns_zero_cost() -> None: pricing = PricingCache() # empty persona = _make_persona("unknown/model-xyz") phase = _make_phase("unknown_phase") est = estimate_phase(phase, persona, pricing) # type: ignore[arg-type] assert est.estimated_cost_usd == 0.0 def test_estimate_phase_max_tokens_override() -> None: pricing = _make_pricing() persona = _make_persona(max_tokens=2000) phase = _make_phase() est = estimate_phase(phase, persona, pricing) # type: ignore[arg-type] assert est.estimated_output_tokens == 2000 def test_estimate_phase_default_output_tokens_when_no_max_tokens() -> None: pricing = _make_pricing() persona = _make_persona() # no max_tokens phase = _make_phase() est = estimate_phase(phase, persona, pricing) # type: ignore[arg-type] assert est.estimated_output_tokens == _DEFAULT_OUTPUT_TOKENS # --------------------------------------------------------------------------- # estimate_workflow # --------------------------------------------------------------------------- def test_estimate_workflow_sums_phases() -> None: pricing = _make_pricing() phase1 = _make_phase("phase1") phase1.role = "researcher" phase2 = _make_phase("phase2") phase2.role = "reviewer" template = MagicMock() template.phases = [phase1, phase2] persona1 = _make_persona() persona2 = _make_persona() bindings = { "researcher": _make_binding(persona1), "reviewer": _make_binding(persona2), } est = estimate_workflow(template, bindings, pricing) # type: ignore[arg-type] assert isinstance(est, WorkflowCostEstimate) assert len(est.phases) == 2 assert est.total_usd == pytest.approx(sum(p.estimated_cost_usd for p in est.phases)) assert est.total_usd > 0.0 def test_estimate_workflow_total_greater_than_zero_with_known_models() -> None: pricing = _make_pricing() phase = _make_phase("spec") phase.role = "researcher" template = MagicMock() template.phases = [phase] persona = _make_persona() bindings = {"researcher": _make_binding(persona)} est = estimate_workflow(template, bindings, pricing) # type: ignore[arg-type] assert est.total_usd > 0.0