"""Integration tests: CostMiddleware + BudgetTracker wire-up.""" from __future__ import annotations import tempfile from pathlib import Path from typing import Any from unittest.mock import AsyncMock, MagicMock from uuid import uuid4 import pytest import pytest_asyncio from my_deepagent.budget import BudgetOnHit, BudgetTracker from my_deepagent.errors import BudgetExhaustedError from my_deepagent.middleware.cost import CostMiddleware from my_deepagent.monitoring.pricing import ModelPrice, PricingCache from my_deepagent.persistence.db import Database # --------------------------------------------------------------------------- # Fixtures # --------------------------------------------------------------------------- _MODEL = "anthropic/claude-sonnet-4-6" _IN_PRICE = 0.003 _OUT_PRICE = 0.015 @pytest_asyncio.fixture async def db() -> Database: p = Path(tempfile.mkdtemp()) / "test_mw_budget.sqlite3" database = Database(f"sqlite+aiosqlite:///{p}") await database.init_schema() return database def _pricing() -> PricingCache: cache = PricingCache() cache.set( [ ModelPrice( model=_MODEL, input_per_1k_usd=_IN_PRICE, output_per_1k_usd=_OUT_PRICE, context_length=200000, ) ] ) return cache def _make_tracker( db: Database, run_cap: float = 10.0, on_hit: BudgetOnHit = BudgetOnHit.BLOCK, ) -> BudgetTracker: return BudgetTracker( db=db, daily_cap_usd=100.0, run_cap_usd=run_cap, daily_warn_usd=50.0, run_warn_usd=5.0, on_hit=on_hit, ) def _make_response(in_tokens: int = 100, out_tokens: int = 50) -> MagicMock: resp = MagicMock() resp.usage_metadata = {"input_tokens": in_tokens, "output_tokens": out_tokens} return resp # --------------------------------------------------------------------------- # Test: over cap → assert_can_call raises before handler is called # --------------------------------------------------------------------------- @pytest.mark.asyncio async def test_over_cap_raises_before_handler(db: Database) -> None: tracker = _make_tracker(db, run_cap=0.000001, on_hit=BudgetOnHit.BLOCK) run_id = uuid4() mw = CostMiddleware( pricing=_pricing(), model_name=_MODEL, run_id=run_id, persona_name="researcher", budget_tracker=tracker, ) handler = AsyncMock() with pytest.raises(BudgetExhaustedError): await mw.awrap_model_call(MagicMock(), handler) handler.assert_not_awaited() # --------------------------------------------------------------------------- # Test: under cap → handler called + ledger accumulated # --------------------------------------------------------------------------- @pytest.mark.asyncio async def test_under_cap_handler_called_and_ledger_updated(db: Database) -> None: tracker = _make_tracker(db, run_cap=10.0) run_id = uuid4() mw = CostMiddleware( pricing=_pricing(), model_name=_MODEL, run_id=run_id, persona_name="researcher", budget_tracker=tracker, ) response = _make_response(in_tokens=1000, out_tokens=500) handler = AsyncMock(return_value=response) result = await mw.awrap_model_call(MagicMock(), handler) assert result is response handler.assert_awaited_once() # Check ledger was updated run_spent = await tracker.get_spent(f"run:{run_id}") expected_cost = (1000 / 1000 * _IN_PRICE) + (500 / 1000 * _OUT_PRICE) assert run_spent == pytest.approx(expected_cost) # --------------------------------------------------------------------------- # Test: handler exception → recorder gets status=error, budget NOT accumulated # --------------------------------------------------------------------------- @pytest.mark.asyncio async def test_handler_exception_error_status_no_budget(db: Database) -> None: tracker = _make_tracker(db, run_cap=10.0) run_id = uuid4() recorder = AsyncMock() mw = CostMiddleware( pricing=_pricing(), model_name=_MODEL, run_id=run_id, persona_name="researcher", recorder=recorder, budget_tracker=tracker, ) handler = AsyncMock(side_effect=RuntimeError("model_error")) with pytest.raises(RuntimeError, match="model_error"): await mw.awrap_model_call(MagicMock(), handler) # recorder called with error status recorder.assert_awaited_once() record: dict[str, Any] = recorder.call_args[0][0] assert record["status"] == "error" assert record["error_code"] == "RuntimeError" # Budget should NOT be accumulated after an error run_spent = await tracker.get_spent(f"run:{run_id}") assert run_spent == 0.0 # --------------------------------------------------------------------------- # Test: budget=None → existing behaviour preserved (no BudgetExhaustedError) # --------------------------------------------------------------------------- @pytest.mark.asyncio async def test_no_budget_tracker_still_works() -> None: recorder = AsyncMock() mw = CostMiddleware( pricing=_pricing(), model_name=_MODEL, recorder=recorder, budget_tracker=None, ) response = _make_response() handler = AsyncMock(return_value=response) result = await mw.awrap_model_call(MagicMock(), handler) assert result is response recorder.assert_awaited_once() record: dict[str, Any] = recorder.call_args[0][0] assert record["status"] == "ok"