"""Integration tests for src/my_deepagent/budget.py (BudgetTracker).""" from __future__ import annotations from uuid import UUID, uuid4 import pytest import pytest_asyncio from my_deepagent.budget import BudgetOnHit, BudgetTracker from my_deepagent.errors import BudgetExhaustedError from my_deepagent.persistence.db import Database # --------------------------------------------------------------------------- # Fixtures # --------------------------------------------------------------------------- _RUN_ID = UUID("00000000-0000-0000-0000-000000000001") @pytest_asyncio.fixture async def db(tmp_path: object) -> Database: import tempfile from pathlib import Path p = Path(tempfile.mkdtemp()) / "test_budget.sqlite3" database = Database(f"sqlite+aiosqlite:///{p}") await database.init_schema() return database def _make_tracker( db: Database, daily_cap: float = 5.0, run_cap: float = 1.0, on_hit: BudgetOnHit = BudgetOnHit.BLOCK, prompt_callback: object = None, ) -> BudgetTracker: return BudgetTracker( db=db, daily_cap_usd=daily_cap, run_cap_usd=run_cap, daily_warn_usd=3.0, run_warn_usd=0.5, on_hit=on_hit, prompt_callback=prompt_callback, # type: ignore[arg-type] ) # --------------------------------------------------------------------------- # init() # --------------------------------------------------------------------------- @pytest.mark.asyncio async def test_init_creates_day_scope_row(db: Database) -> None: tracker = _make_tracker(db) await tracker.init() spent = await tracker.get_spent(f"day:{_today()}") assert spent == 0.0 @pytest.mark.asyncio async def test_init_is_idempotent(db: Database) -> None: tracker = _make_tracker(db) await tracker.init() await tracker.init() # second call should not error or double-insert spent = await tracker.get_spent(f"day:{_today()}") assert spent == 0.0 # --------------------------------------------------------------------------- # assert_can_call — under cap # --------------------------------------------------------------------------- 
@pytest.mark.asyncio
async def test_assert_can_call_under_cap_returns_ok(db: Database) -> None:
    """An estimate below every cap is allowed and reports no blocking scope."""
    tracker = _make_tracker(db, daily_cap=5.0, run_cap=1.0)
    result = await tracker.assert_can_call(
        run_id=_RUN_ID,
        persona_name="researcher",
        estimated_cost_usd=0.5,
    )
    assert result.ok is True
    assert result.blocked_scope is None


# ---------------------------------------------------------------------------
# assert_can_call — over run cap (on_hit=block)
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_assert_can_call_over_run_cap_raises(db: Database) -> None:
    """Exceeding only the run cap under BLOCK raises for the ``run:`` scope."""
    tracker = _make_tracker(db, run_cap=0.01, on_hit=BudgetOnHit.BLOCK)
    with pytest.raises(BudgetExhaustedError) as exc_info:
        await tracker.assert_can_call(
            run_id=_RUN_ID,
            persona_name=None,
            estimated_cost_usd=1.0,
        )
    err = exc_info.value
    assert err.scope.startswith("run:")
    assert err.projected_usd > 0.01
    # Same checks as the day-cap test: the error must report the cap it hit.
    assert err.cap_usd == pytest.approx(0.01)


# ---------------------------------------------------------------------------
# assert_can_call — over day cap (on_hit=block)
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_assert_can_call_over_day_cap_raises(db: Database) -> None:
    """Exceeding only the daily cap under BLOCK raises for the ``day:`` scope."""
    tracker = _make_tracker(
        db, daily_cap=0.001, run_cap=999.0, on_hit=BudgetOnHit.BLOCK
    )
    with pytest.raises(BudgetExhaustedError) as exc_info:
        await tracker.assert_can_call(
            run_id=_RUN_ID,
            persona_name=None,
            estimated_cost_usd=1.0,
        )
    err = exc_info.value
    assert err.scope.startswith("day:")
    assert err.projected_usd > 0.001
    assert err.cap_usd == pytest.approx(0.001)


# ---------------------------------------------------------------------------
# record() — accumulates spend
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_record_accumulates_spend(db: Database) -> None:
    """Successive ``record()`` calls add up in both the day and run scopes."""
    tracker = _make_tracker(db)
    run_id = uuid4()
    await tracker.record(run_id=run_id, persona_name=None, actual_cost_usd=0.10)
    await tracker.record(run_id=run_id, persona_name=None, actual_cost_usd=0.05)
    day_spent = await tracker.get_spent(f"day:{_today()}")
    run_spent = await tracker.get_spent(f"run:{run_id}")
    assert day_spent == pytest.approx(0.15)
    assert run_spent == pytest.approx(0.15)


@pytest.mark.asyncio
async def test_record_zero_is_noop(db: Database) -> None:
    """Recording a zero cost leaves the run scope at zero."""
    tracker = _make_tracker(db)
    run_id = uuid4()
    await tracker.record(run_id=run_id, persona_name=None, actual_cost_usd=0.0)
    run_spent = await tracker.get_spent(f"run:{run_id}")
    assert run_spent == 0.0


# ---------------------------------------------------------------------------
# on_hit=warn_continue
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_warn_continue_over_cap_returns_ok_no_raise(db: Database) -> None:
    """WARN_CONTINUE lets an over-cap call through without raising."""
    tracker = _make_tracker(db, run_cap=0.001, on_hit=BudgetOnHit.WARN_CONTINUE)
    result = await tracker.assert_can_call(
        run_id=_RUN_ID,
        persona_name=None,
        estimated_cost_usd=1.0,
    )
    # WARN_CONTINUE: blocked=False, no raise
    assert result.ok is True


# ---------------------------------------------------------------------------
# on_hit=prompt
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_prompt_callback_returns_true_proceeds(db: Database) -> None:
    """PROMPT consults the callback; a ``True`` answer lets the call through."""
    calls: list[tuple[str, float, float]] = []

    async def _allow(scope: str, projected: float, cap: float) -> bool:
        calls.append((scope, projected, cap))
        return True

    tracker = _make_tracker(
        db, run_cap=0.001, on_hit=BudgetOnHit.PROMPT, prompt_callback=_allow
    )
    result = await tracker.assert_can_call(
        run_id=_RUN_ID,
        persona_name=None,
        estimated_cost_usd=1.0,
    )
    assert result.ok is True
    # Without this, the test would also pass if PROMPT silently behaved like
    # WARN_CONTINUE and never asked the callback.
    assert calls, "prompt_callback was never invoked"
    assert calls[0][0].startswith("run:")


@pytest.mark.asyncio
async def test_prompt_callback_returns_false_raises(db: Database) -> None:
    """PROMPT consults the callback; a ``False`` answer raises."""
    calls: list[tuple[str, float, float]] = []

    async def _deny(scope: str, projected: float, cap: float) -> bool:
        calls.append((scope, projected, cap))
        return False

    tracker = _make_tracker(
        db, run_cap=0.001, on_hit=BudgetOnHit.PROMPT, prompt_callback=_deny
    )
    with pytest.raises(BudgetExhaustedError):
        await tracker.assert_can_call(
            run_id=_RUN_ID,
            persona_name=None,
            estimated_cost_usd=1.0,
        )
    # Distinguishes "callback said no" from "PROMPT fell back to BLOCK".
    assert calls, "prompt_callback was never invoked"


@pytest.mark.asyncio
async def test_prompt_callback_none_raises_like_block(db: Database) -> None:
    """PROMPT with no callback configured raises, just like BLOCK."""
    tracker = _make_tracker(
        db, run_cap=0.001, on_hit=BudgetOnHit.PROMPT, prompt_callback=None
    )
    with pytest.raises(BudgetExhaustedError):
        await tracker.assert_can_call(
            run_id=_RUN_ID,
            persona_name=None,
            estimated_cost_usd=1.0,
        )


# ---------------------------------------------------------------------------
# persona scope
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_persona_scope_accumulates_separately(db: Database) -> None:
    """A persona's spend lands in its own daily scope and in the day total."""
    tracker = _make_tracker(db)
    await tracker.record(run_id=None, persona_name="researcher", actual_cost_usd=0.20)
    persona_spent = await tracker.get_spent(f"persona:researcher:day:{_today()}")
    day_spent = await tracker.get_spent(f"day:{_today()}")
    assert persona_spent == pytest.approx(0.20)
    assert day_spent == pytest.approx(0.20)


# ---------------------------------------------------------------------------
# get_remaining()
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_get_remaining_with_no_spend(db: Database) -> None:
    """With nothing recorded, the remaining daily budget is the full cap."""
    tracker = _make_tracker(db, daily_cap=5.0)
    remaining = await tracker.get_remaining(f"day:{_today()}")
    assert remaining == pytest.approx(5.0)


@pytest.mark.asyncio
async def test_get_remaining_after_spend(db: Database) -> None:
    """Remaining daily budget is the cap minus recorded spend."""
    tracker = _make_tracker(db, daily_cap=5.0)
    await tracker.record(run_id=None, persona_name=None, actual_cost_usd=1.5)
    remaining = await tracker.get_remaining(f"day:{_today()}")
    assert remaining == pytest.approx(3.5)


@pytest.mark.asyncio
async def test_get_remaining_unknown_scope_returns_none(db: Database) -> None:
    """A scope with no configured cap has no "remaining" value."""
    tracker = _make_tracker(db)
    # "unknown:xyz" has no cap in _cap_for_scope
    remaining = await tracker.get_remaining("unknown:xyz")
    assert remaining is None


# ---------------------------------------------------------------------------
# session: scope
# (v0.3 PR #6) — sub-agent rollup to root session
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_session_scope_accumulates_cost(db: Database) -> None:
    """Costs recorded with a ``session_id`` roll up under ``session:<id>``."""
    tracker = _make_tracker(db, run_cap=2.0)
    session_id = uuid4()
    await tracker.record(
        run_id=None, persona_name=None, actual_cost_usd=0.30, session_id=session_id
    )
    await tracker.record(
        run_id=None, persona_name=None, actual_cost_usd=0.20, session_id=session_id
    )
    spent = await tracker.get_spent(f"session:{session_id}")
    assert spent == pytest.approx(0.50)
    # This expects the session scope to be capped at run_cap (2.0 - 0.50).
    remaining = await tracker.get_remaining(f"session:{session_id}")
    assert remaining == pytest.approx(1.50)


@pytest.mark.asyncio
async def test_session_scope_omitted_when_no_session_id(db: Database) -> None:
    """Calls without ``session_id`` must NOT create a session: ledger row."""
    tracker = _make_tracker(db)
    # Drive a record without session_id.
    await tracker.record(run_id=None, persona_name=None, actual_cost_usd=0.10)
    # An implementation that formats the scope unconditionally would charge
    # ``session:None``. Check that key explicitly, because a freshly
    # generated session id is trivially unspent and proves nothing alone.
    assert (await tracker.get_spent("session:None")) == pytest.approx(0.0)
    assert (await tracker.get_spent(f"session:{uuid4()}")) == pytest.approx(0.0)


# ---------------------------------------------------------------------------
# helpers
# ---------------------------------------------------------------------------


def _today() -> str:
    """Return today's UTC date as ``YYYY-MM-DD`` for building ``day:`` scopes.

    NOTE(review): this is computed independently of the tracker, so a test
    that straddles UTC midnight could query a different day than the one
    the tracker wrote to.
    """
    from datetime import UTC, datetime

    return datetime.now(UTC).strftime("%Y-%m-%d")