Pre-flight assets prepared on the main machine before the new-machine rewrite of my-deepagent in Python. - poc/: BudgetTracker + CostMiddleware + MockChatModel PoC. Validates wrap_model_call pattern, SQLite WAL + ON CONFLICT upsert, per-scope cap accounting. 5/5 pytest PASS in isolated uv venv. - schemas/: 10 personas (Anthropic Sonnet/Opus/Haiku + DeepSeek mix), 3 workflows (spec-and-review, bug-fix-with-reproduction, code-investigation), 4 artifact JSON Schemas (dev/spec@1, dev/phase-plan@1, dev/review-finding-batch@1, common/final-report@1). - schemas/validate.py: pydantic + Draft202012 cross-validation. 18/18 assets verified. - README.md: new-machine bootstrap instructions. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
83 lines
2.6 KiB
Python
83 lines
2.6 KiB
Python
"""Cost tracking middleware for my-deepagent PoC.
|
|
|
|
Uses langchain.agents.middleware.AgentMiddleware (langchain 1.x) to intercept
|
|
model calls and record budget usage via BudgetTracker.
|
|
|
|
Import path confirmed: from langchain.agents.middleware import AgentMiddleware
|
|
"""
|
|
|
|
import logging
|
|
from collections.abc import Awaitable, Callable
|
|
from typing import Any
|
|
|
|
from langchain.agents.middleware import AgentMiddleware, ModelRequest, ModelResponse
|
|
|
|
from poc.budget import BudgetTracker
|
|
from poc.pricing import compute_cost
|
|
|
|
# Module-scoped logger; CostMiddleware emits its per-call cost summary through it.
logger = logging.getLogger(__name__)
|
|
|
|
# Input-token count fed to compute_cost for the conservative pre-call estimate
# (the real usage is not known until the model has responded).
_WORST_CASE_INPUT_TOKENS = 4096
|
|
# Output-token count fed to compute_cost for the conservative pre-call estimate.
_WORST_CASE_OUTPUT_TOKENS = 2048
|
|
|
|
|
|
class CostMiddleware(AgentMiddleware):  # type: ignore[type-arg]
    """Middleware that checks budget before a model call and records cost after.

    Pre-call: estimates cost with a conservative worst-case token count and calls
    ``tracker.assert_can_call``. If the tracker raises ``BudgetExhausted``,
    the exception propagates and the model is never called.

    Post-call: extracts actual ``usage_metadata`` from the first message in the
    response that carries it and records the real cost via ``tracker.record``.
    """

    def __init__(
        self,
        tracker: BudgetTracker,
        run_id: str | None,
        persona_name: str,
        model_name: str,
    ) -> None:
        """Bind the middleware to one tracker, run, persona and model.

        Args:
            tracker: Budget tracker consulted before and updated after each call.
            run_id: Identifier forwarded to the tracker; may be ``None``.
            persona_name: Persona name forwarded to the tracker and logged.
            model_name: Model name passed to ``compute_cost`` for pricing.
        """
        super().__init__()
        self._tracker = tracker
        self._run_id = run_id
        self._persona_name = persona_name
        self._model_name = model_name

    @staticmethod
    def _extract_usage(response: ModelResponse[Any]) -> dict[str, Any]:
        """Return the first non-empty ``usage_metadata`` in *response*, else ``{}``."""
        # Scan every message rather than only index 0 so usage is not lost when
        # the leading message happens not to carry it.
        for message in response.result or ():
            usage = getattr(message, "usage_metadata", None)
            if usage:
                return usage
        return {}

    async def awrap_model_call(
        self,
        request: ModelRequest[Any],
        handler: Callable[[ModelRequest[Any]], Awaitable[ModelResponse[Any]]],
    ) -> ModelResponse[Any]:
        """Check budget, call model, record actual cost.

        Args:
            request: The model request, passed to *handler* unchanged.
            handler: Awaitable callable that performs the actual model call.

        Returns:
            The response produced by *handler*, unmodified.

        Raises:
            Whatever ``tracker.assert_can_call`` raises when the estimated cost
            does not fit the budget; in that case *handler* is never invoked.
        """
        estimated = compute_cost(
            self._model_name,
            _WORST_CASE_INPUT_TOKENS,
            _WORST_CASE_OUTPUT_TOKENS,
        )
        # Gate on the worst-case estimate before spending anything.
        await self._tracker.assert_can_call(self._run_id, self._persona_name, estimated)

        response = await handler(request)

        usage = self._extract_usage(response)
        if not usage:
            # Surface the gap instead of silently under-counting the budget.
            logger.warning(
                "CostMiddleware: no usage_metadata in response for model=%s "
                "persona=%s; token counts default to 0",
                self._model_name,
                self._persona_name,
            )

        # ``or 0`` tolerates keys that are present but set to None.
        actual = compute_cost(
            self._model_name,
            int(usage.get("input_tokens") or 0),
            int(usage.get("output_tokens") or 0),
        )
        await self._tracker.record(self._run_id, self._persona_name, actual)

        logger.debug(
            "CostMiddleware: model=%s persona=%s estimated=$%.6f actual=$%.6f",
            self._model_name,
            self._persona_name,
            estimated,
            actual,
        )

        return response
|