"""Unit tests for src/my_deepagent/middleware/cost.py.""" from __future__ import annotations from typing import Any from unittest.mock import AsyncMock, MagicMock from uuid import UUID import pytest from my_deepagent.middleware.cost import CostMiddleware from my_deepagent.monitoring.pricing import ModelPrice, PricingCache # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def _make_pricing_cache( model: str = "anthropic/claude-sonnet", input_per_1k: float = 0.003, output_per_1k: float = 0.015, ) -> PricingCache: cache = PricingCache() cache.set( [ ModelPrice( model=model, input_per_1k_usd=input_per_1k, output_per_1k_usd=output_per_1k, context_length=200000, ) ] ) return cache def _make_response(input_tokens: int = 100, output_tokens: int = 50) -> MagicMock: response = MagicMock() response.usage_metadata = {"input_tokens": input_tokens, "output_tokens": output_tokens} return response # --------------------------------------------------------------------------- # awrap_model_call — success path # --------------------------------------------------------------------------- @pytest.mark.asyncio async def test_cost_middleware_records_correct_fields_on_success() -> None: recorder = AsyncMock() cache = _make_pricing_cache() mw = CostMiddleware( pricing=cache, model_name="anthropic/claude-sonnet", run_id=UUID("00000000-0000-0000-0000-000000000001"), phase_id=UUID("00000000-0000-0000-0000-000000000002"), persona_name="test-persona", recorder=recorder, ) response = _make_response(input_tokens=1000, output_tokens=500) handler = AsyncMock(return_value=response) request = MagicMock() result = await mw.awrap_model_call(request, handler) assert result is response recorder.assert_awaited_once() record: dict[str, Any] = recorder.call_args[0][0] assert record["model"] == "anthropic/claude-sonnet" assert record["input_tokens"] == 1000 assert record["output_tokens"] == 500 
assert record["status"] == "ok" assert record["error_code"] is None assert record["latency_ms"] >= 0 # cost: (1000/1000 * 0.003) + (500/1000 * 0.015) expected_cost = 0.003 * 1.0 + 0.015 * 0.5 assert record["cost_usd_total"] == pytest.approx(expected_cost) @pytest.mark.asyncio async def test_cost_middleware_no_recorder_is_noop() -> None: cache = _make_pricing_cache() mw = CostMiddleware(pricing=cache, model_name="anthropic/claude-sonnet") response = _make_response() handler = AsyncMock(return_value=response) # Should not raise even with recorder=None result = await mw.awrap_model_call(MagicMock(), handler) assert result is response # --------------------------------------------------------------------------- # awrap_model_call — error path # --------------------------------------------------------------------------- @pytest.mark.asyncio async def test_cost_middleware_records_error_on_handler_exception() -> None: recorder = AsyncMock() cache = _make_pricing_cache() mw = CostMiddleware( pricing=cache, model_name="anthropic/claude-sonnet", recorder=recorder, ) handler = AsyncMock(side_effect=RuntimeError("timeout")) with pytest.raises(RuntimeError, match="timeout"): await mw.awrap_model_call(MagicMock(), handler) recorder.assert_awaited_once() record: dict[str, Any] = recorder.call_args[0][0] assert record["status"] == "error" assert record["error_code"] == "RuntimeError" assert record["input_tokens"] == 0 assert record["output_tokens"] == 0 @pytest.mark.asyncio async def test_cost_middleware_reraises_exception() -> None: cache = _make_pricing_cache() mw = CostMiddleware(pricing=cache, model_name="m", recorder=AsyncMock()) handler = AsyncMock(side_effect=ValueError("bad input")) with pytest.raises(ValueError, match="bad input"): await mw.awrap_model_call(MagicMock(), handler) # --------------------------------------------------------------------------- # cost computation via cache # --------------------------------------------------------------------------- 
@pytest.mark.asyncio
async def test_cost_zero_when_model_not_in_cache() -> None:
    """With an empty pricing cache, the recorded ``cost_usd_total`` is 0.0."""
    recorder = AsyncMock()
    cache = PricingCache()  # empty: no price entry for the model below
    mw = CostMiddleware(pricing=cache, model_name="unknown/model", recorder=recorder)
    response = _make_response(input_tokens=1000, output_tokens=1000)
    handler = AsyncMock(return_value=response)

    await mw.awrap_model_call(MagicMock(), handler)

    # Check the recorder ran before indexing call_args: if it was never
    # awaited, call_args is None and the lookup below would fail with an
    # opaque TypeError instead of a clear assertion message.
    recorder.assert_awaited_once()
    record: dict[str, Any] = recorder.call_args[0][0]
    assert record["cost_usd_total"] == 0.0