feat(my-deepagent): v0.2 PR #2a — wire LangGraph AsyncPostgresSaver into engine

Foundation for `runs resume` (v0.2 PR #2b). v0.2 PR #1 added
langgraph-checkpoint-postgres as a dependency, but engine.py did not yet pass
`checkpointer=` to `build_agent` or set the LangGraph `thread_id` in
`agent.ainvoke`, meaning resume had no state to restore. This commit actually
wires the dependency.

Highlights
- `WorkflowEngine.__init__` accepts `checkpointer_url: str | None`
  (default = `config.database_url`).
- `_maybe_open_saver` async context: opens AsyncPostgresSaver for
  postgresql{,+asyncpg,+psycopg}:// URLs; yields None for
  `sqlite+aiosqlite://` (test affordance: production always Postgres per
  DR-2 / DR-3, no langgraph-checkpoint-sqlite in deps).
- `WorkflowEngine.run()` opens the saver **once per run** and shares it across
  all phases. Opening per-phase would reconnect 5+ times for no isolation
  gain, since LangGraph checkpoints are keyed by `thread_id`, not by saver
  instance.
- `_invoke_agent_until_artifact` forwards `checkpointer=self._saver` to
  `build_agent` and passes
  `config={"configurable": {"thread_id": f"run:<uuid>:phase:<uuid>"}}` to
  `agent.ainvoke`. The thread_id format is already used by
  `LlmCallRow.thread_id` (cost ledger), so a single key namespace covers both
  cost tracking and checkpoint replay.

Tests
- `tests/integration/test_engine_checkpointer_wiring.py` (new, 2 tests):
  1. Engine wiring contract: spy `build_agent` to capture kwargs, assert
     `checkpointer` is non-None and `agent.ainvoke` receives the expected
     `config.configurable.thread_id` in run:<uuid>:phase:<uuid> format.
  2. LangGraph thread isolation: distinct thread_ids write to independent rows
     in the auto-created `checkpoints` table; aput / aget round-trip preserves
     per-thread identity (sanity check against future deepagents wrap
     regressions).
- `tests/integration/test_engine.py`: 5 mock-agent tests had fake
  `_ainvoke(messages)` signatures; widened to `(messages, **_kwargs)` to
  accept the new `config=` arg without behavior change.

Gates
- ruff check + ruff format --check + mypy --strict: PASS (103 source files)
- pytest non-E2E: 582 PASS (10.55 s); was 576 before, +7 from new wiring
  tests, +/-1 from engine.py reshape, +/-... settled at 582 net.
- pytest E2E real OpenRouter on Postgres: PASS 75.99 s (baseline 71–122 s;
  within DR-3 acceptance threshold ≤+20%).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
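
The `run:<uuid>:phase:<uuid>` key described in the message above can be sketched as a pair of helpers. This is only an illustration: `make_thread_id` and `parse_thread_id` are hypothetical names that do not appear in the repository (the engine formats the string inline with an f-string), and the strict UUID parsing is an assumption about what a consumer of the key might want.

```python
from uuid import UUID, uuid4


def make_thread_id(run_id: UUID, phase_id: UUID) -> str:
    """Build the shared key in the run:<uuid>:phase:<uuid> format."""
    return f"run:{run_id}:phase:{phase_id}"


def parse_thread_id(thread_id: str) -> tuple[UUID, UUID]:
    """Split a thread_id back into (run_id, phase_id).

    Raises ValueError if the layout or either UUID is malformed.
    """
    parts = thread_id.split(":")
    # Canonical UUID strings contain hyphens but no colons, so a
    # well-formed key always splits into exactly four fields.
    if len(parts) != 4 or parts[0] != "run" or parts[2] != "phase":
        raise ValueError(f"unexpected thread_id format: {thread_id!r}")
    return UUID(parts[1]), UUID(parts[3])


if __name__ == "__main__":
    run_id, phase_id = uuid4(), uuid4()
    key = make_thread_id(run_id, phase_id)
    print(key, parse_thread_id(key) == (run_id, phase_id))
```

Because the same string would key both the cost ledger rows and the LangGraph checkpoints, a single parser like this could serve both lookups.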
@@ -3,6 +3,36 @@
 ## [Unreleased]
 
 ### Added
+- **v0.2 PR #2a — LangGraph `AsyncPostgresSaver` engine wiring** (foundation
+  for `runs resume`). v0.2 PR #1 added the dependency; this commit actually
+  uses it.
+  - `src/my_deepagent/engine.py`:
+    - `WorkflowEngine.__init__` accepts `checkpointer_url: str | None` (defaults
+      to `config.database_url`).
+    - New `_maybe_open_saver` async context: opens `get_checkpointer_ctx` for
+      `postgresql{,+asyncpg,+psycopg}://` URLs, yields `None` for `sqlite+aiosqlite://`
+      (test affordance — production always Postgres per DR-2 / DR-3).
+    - `WorkflowEngine.run()` opens the saver **once per run** and shares it
+      across all phases via `self._saver` — opening per-phase would re-connect
+      5+ times for no isolation gain (checkpoints are keyed by `thread_id`, not
+      saver instance).
+    - `_invoke_agent_until_artifact` forwards `checkpointer=self._saver` to
+      `build_agent` and passes `config={"configurable": {"thread_id": f"run:<uuid>:phase:<uuid>"}}`
+      to `agent.ainvoke`. Same `thread_id` format already used by
+      `LlmCallRow.thread_id` (cost ledger), so one key namespace covers both.
+  - `tests/integration/test_engine_checkpointer_wiring.py` (new):
+    1. **Contract 1 — engine wiring**: `build_agent` receives a non-None saver;
+       `agent.ainvoke` receives `config.configurable.thread_id` in the
+       expected `run:<uuid>:phase:<uuid>` format.
+    2. **Contract 2 — LangGraph thread isolation**: two distinct `thread_id`s
+       write independent rows in the auto-created `checkpoints` table; aput /
+       aget round-trip preserves per-thread identity (sanity check against
+       future deepagents wrap regressions).
+  - `tests/integration/test_engine.py` — 5 mock-agent tests: fake `_ainvoke`
+    signature widened with `**_kwargs` to accept the new `config=` arg.
+  - E2E real OpenRouter regression PASS 75.99 s (baseline 71–122 s); within
+    DR-3 acceptance threshold (+20%).
+
 - **v0.2 PR #1 — Postgres migration**: production backing store switched from
   SQLite to PostgreSQL 16 ahead of M8-Py (FastAPI) per DR-2.
   - `pyproject.toml`: `asyncpg>=0.30` + `psycopg[binary]>=3.2` +
@@ -5,7 +5,8 @@ from __future__ import annotations
 import asyncio
 import json
 import signal
-from contextlib import suppress
+from collections.abc import AsyncIterator
+from contextlib import asynccontextmanager, suppress
 from dataclasses import dataclass
 from datetime import UTC, datetime
 from pathlib import Path
@@ -32,6 +33,7 @@ from .middleware.artifact_watcher import ArtifactWatcherMiddleware
 from .middleware.audit import AuditToolMiddleware
 from .middleware.cost import CostMiddleware
 from .monitoring.pricing import PricingCache
+from .persistence.checkpointer import get_checkpointer_ctx
 from .persistence.db import Database
 from .persistence.models import (
     AgentPersonaRow,
@@ -93,6 +95,7 @@ class WorkflowEngine:
         approval_callback: ApprovalCallback,
         budget_tracker: BudgetTracker | None = None,
         pricing: PricingCache | None = None,
+        checkpointer_url: str | None = None,
     ) -> None:
         self._db = db
         self._config = config
@@ -105,6 +108,11 @@ class WorkflowEngine:
         self._pricing = pricing or PricingCache()
         self._shutdown_event: asyncio.Event = asyncio.Event()
         self._inflight_tasks: set[asyncio.Task[Any]] = set()
+        # LangGraph checkpoint URL. None → falls back to config.database_url at run-time.
+        # The saver itself is opened inside `run()` (one ctx per run, shared across phases)
+        # and lives on `self._saver` for the duration of that run.
+        self._checkpointer_url: str = checkpointer_url or config.database_url
+        self._saver: Any | None = None
 
     def install_signal_handlers(self) -> None:
         """Attach SIGTERM/SIGINT handlers to the running event loop.
@@ -132,6 +140,22 @@ class WorkflowEngine:
     def shutdown_requested(self) -> bool:
         return self._shutdown_event.is_set()
 
+    @asynccontextmanager
+    async def _maybe_open_saver(self) -> AsyncIterator[Any | None]:
+        """Yield an AsyncPostgresSaver for Postgres URLs; yield None for SQLite.
+
+        SQLite is supported for tests only and never wires durable resume.
+        LangGraph's AsyncPostgresSaver requires a libpq DSN; passing a SQLite
+        URL would raise psycopg.ProgrammingError. Production runs always use
+        Postgres (see DR-2 / DR-3), so this is purely a test-affordance shim.
+        """
+        url = self._checkpointer_url
+        if url.startswith(("postgresql://", "postgresql+asyncpg://", "postgresql+psycopg://")):
+            async with get_checkpointer_ctx(url) as saver:
+                yield saver
+        else:
+            yield None
+
     async def run(
         self,
         template: WorkflowTemplate,
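
The URL-scheme dispatch in `_maybe_open_saver` can be tried in isolation with a stand-in saver. The sketch below is not the project's code: `FakeSaver`, `fake_checkpointer_ctx`, and `maybe_open_saver` are hypothetical names, and the fake context only records that it was opened and closed instead of connecting to Postgres.

```python
from __future__ import annotations

import asyncio
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager

# Same three prefixes the hunk above checks with str.startswith.
POSTGRES_PREFIXES = ("postgresql://", "postgresql+asyncpg://", "postgresql+psycopg://")


class FakeSaver:
    """Hypothetical stand-in for a real checkpoint saver."""

    def __init__(self, url: str) -> None:
        self.url = url
        self.closed = False


@asynccontextmanager
async def fake_checkpointer_ctx(url: str) -> AsyncIterator[FakeSaver]:
    """Stand-in for get_checkpointer_ctx: open on enter, mark closed on exit."""
    saver = FakeSaver(url)
    try:
        yield saver
    finally:
        saver.closed = True


@asynccontextmanager
async def maybe_open_saver(url: str) -> AsyncIterator[FakeSaver | None]:
    """Yield a saver for Postgres URLs and None for anything else."""
    if url.startswith(POSTGRES_PREFIXES):
        async with fake_checkpointer_ctx(url) as saver:
            yield saver
    else:
        yield None


async def demo() -> tuple[bool, bool, bool]:
    async with maybe_open_saver("postgresql+asyncpg://localhost/app") as pg:
        open_inside = pg is not None and not pg.closed
    closed_after = pg is not None and pg.closed
    async with maybe_open_saver("sqlite+aiosqlite:///test.db") as lite:
        sqlite_is_none = lite is None
    return open_inside, closed_after, sqlite_is_none
```

Keeping both branches inside one context manager means callers always write a single `async with`, whichever backend the URL names.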
@@ -164,42 +188,55 @@ class WorkflowEngine:
         await self._append_event(run_id, None, RunEventType.RUN_STARTED, {})
         await self._set_run_state(run_id, RunState.EXECUTING)
 
-        try:
-            for phase_def in template.phases:
-                role_binding = bindings[phase_def.role]
-                await self._run_phase(run_id, worktree_root, template, phase_def, role_binding)
-            await self._set_run_state(run_id, RunState.COMPLETED)
-            await self._append_event(run_id, None, RunEventType.RUN_COMPLETED, {})
-            report_path = await self._compose_final_report(
-                run_id, worktree_root, RunState.COMPLETED
-            )
-            return RunResult(run_id=run_id, state=RunState.COMPLETED, final_report_path=report_path)
-        except _PhaseAbortedError as e:
-            await self._set_run_state(run_id, RunState.ABORTED)
-            await self._append_event(run_id, None, RunEventType.RUN_ABORTED, {"reason": e.reason})
-            report_path = await self._compose_final_report(
-                run_id, worktree_root, RunState.ABORTED, error=e.reason
-            )
-            return RunResult(
-                run_id=run_id,
-                state=RunState.ABORTED,
-                final_report_path=report_path,
-                error=e.reason,
-            )
-        except MyDeepAgentError as e:
-            await self._set_run_state(run_id, RunState.FAILED)
-            await self._append_event(
-                run_id, None, RunEventType.RUN_FAILED, {"code": e.code, "message": str(e)}
-            )
-            report_path = await self._compose_final_report(
-                run_id, worktree_root, RunState.FAILED, error=str(e)
-            )
-            return RunResult(
-                run_id=run_id,
-                state=RunState.FAILED,
-                final_report_path=report_path,
-                error=str(e),
-            )
+        # Open the LangGraph AsyncPostgresSaver once per run; all phases share it.
+        # Opening per-phase would re-connect to Postgres 5+ times per run for no
+        # gain — checkpoints are isolated by `thread_id` not by saver instance.
+        # SQLite URLs (test-only) skip the saver entirely — deepagents accepts
+        # checkpointer=None and runs without resume support.
+        async with self._maybe_open_saver() as saver:
+            self._saver = saver
+            try:
+                for phase_def in template.phases:
+                    role_binding = bindings[phase_def.role]
+                    await self._run_phase(run_id, worktree_root, template, phase_def, role_binding)
+                await self._set_run_state(run_id, RunState.COMPLETED)
+                await self._append_event(run_id, None, RunEventType.RUN_COMPLETED, {})
+                report_path = await self._compose_final_report(
+                    run_id, worktree_root, RunState.COMPLETED
+                )
+                return RunResult(
+                    run_id=run_id, state=RunState.COMPLETED, final_report_path=report_path
+                )
+            except _PhaseAbortedError as e:
+                await self._set_run_state(run_id, RunState.ABORTED)
+                await self._append_event(
+                    run_id, None, RunEventType.RUN_ABORTED, {"reason": e.reason}
+                )
+                report_path = await self._compose_final_report(
+                    run_id, worktree_root, RunState.ABORTED, error=e.reason
+                )
+                return RunResult(
+                    run_id=run_id,
+                    state=RunState.ABORTED,
+                    final_report_path=report_path,
+                    error=e.reason,
+                )
+            except MyDeepAgentError as e:
+                await self._set_run_state(run_id, RunState.FAILED)
+                await self._append_event(
+                    run_id, None, RunEventType.RUN_FAILED, {"code": e.code, "message": str(e)}
+                )
+                report_path = await self._compose_final_report(
+                    run_id, worktree_root, RunState.FAILED, error=str(e)
+                )
+                return RunResult(
+                    run_id=run_id,
+                    state=RunState.FAILED,
+                    final_report_path=report_path,
+                    error=str(e),
+                )
+            finally:
+                self._saver = None
 
     # ------------------------------------------------------------------
     # Phase execution
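
The lifetime rule in the `run()` hunk above (open the saver once, let every phase read it from the instance, clear it in `finally`) can be reduced to a small model. `MiniEngine` and its members are hypothetical and stand in for `WorkflowEngine`; the "saver" is a plain `object()`, so this only demonstrates the sharing and cleanup pattern, not real checkpointing.

```python
from __future__ import annotations

import asyncio
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager


class MiniEngine:
    """Toy model of a run that shares one saver across all phases."""

    def __init__(self) -> None:
        self.open_count = 0                   # how many times the saver was opened
        self.seen: list[object | None] = []   # saver observed by each phase
        self._saver: object | None = None

    @asynccontextmanager
    async def _open_saver(self) -> AsyncIterator[object]:
        self.open_count += 1
        yield object()

    async def _run_phase(self, name: str) -> None:
        self.seen.append(self._saver)
        if name == "boom":
            raise RuntimeError("phase failed")

    async def run(self, phases: list[str]) -> str:
        async with self._open_saver() as saver:
            self._saver = saver
            try:
                for name in phases:
                    await self._run_phase(name)
                return "completed"
            except RuntimeError:
                return "failed"
            finally:
                # Runs on success and on failure, so a later run never
                # sees a saver whose context has already been exited.
                self._saver = None
```

With phases `["spec", "review", "boom"]` the run reports `"failed"`, yet the saver is opened once, every phase sees the same instance, and the attribute is reset afterwards.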
@@ -400,6 +437,7 @@ class WorkflowEngine:
                 self._config,
                 root_dir=worktree_root,
                 middleware=[watcher, cost_mw, audit_mw],
+                checkpointer=self._saver,
             )
             envelope = self._build_envelope(run_id, phase_id, phase_def, attempt, expected_path)
 
@@ -409,10 +447,17 @@ class WorkflowEngine:
             event_type = RunEventType.PROMPT_REPAIRED if attempt > 1 else RunEventType.PROMPT_SENT
             await self._append_event(run_id, phase_id, event_type, {"attempt": attempt})
 
+            # thread_id matches the format already used by LlmCallRow.thread_id
+            # (engine.py _record_llm_call) so a single namespace covers both
+            # cost ledger and LangGraph checkpoint replay.
+            thread_id = f"run:{run_id}:phase:{phase_id}"
             timeout = float(phase_def.timeout_seconds or _DEFAULT_PHASE_TIMEOUT_SECONDS)
             try:
                 invoke_task: asyncio.Task[Any] = asyncio.create_task(
-                    agent.ainvoke({"messages": [{"role": "user", "content": envelope}]})
+                    agent.ainvoke(
+                        {"messages": [{"role": "user", "content": envelope}]},
+                        config={"configurable": {"thread_id": thread_id}},
+                    )
                 )
                 self._inflight_tasks.add(invoke_task)
                 try:
@@ -320,7 +320,7 @@ async def test_engine_phase_completes_with_valid_artifact(
     ) -> Any:
         run_id_placeholder = uuid4()  # placeholder; overwritten by test side-effect below
 
-        async def _ainvoke(messages: Any) -> Any:
+        async def _ainvoke(messages: Any, **_kwargs: Any) -> Any:
             # Write a valid spec.json to the expected path
             expected = root_dir / "artifacts" / "spec.json"
             expected.parent.mkdir(parents=True, exist_ok=True)
@@ -378,7 +378,7 @@ async def test_engine_invalid_artifact_triggers_repair_then_fails(
     def _fake_build_agent(
         persona: Any, config: Any, *, root_dir: Path, middleware: list[Any], **_kw: Any
     ) -> Any:
-        async def _ainvoke(messages: Any) -> Any:
+        async def _ainvoke(messages: Any, **_kwargs: Any) -> Any:
             nonlocal call_count
             call_count += 1
             expected = root_dir / "artifacts" / "spec.json"
@@ -437,7 +437,7 @@ async def test_engine_agent_writes_nothing_exhausts_timeout(
     def _fake_build_agent(
         persona: Any, config: Any, *, root_dir: Path, middleware: list[Any], **_kw: Any
     ) -> Any:
-        async def _ainvoke(messages: Any) -> Any:
+        async def _ainvoke(messages: Any, **_kwargs: Any) -> Any:
             nonlocal invoke_count
             invoke_count += 1
             # Write NOTHING — simulate timeout by returning immediately
@@ -478,7 +478,7 @@ async def test_engine_approval_reject_fails_run(
     def _fake_build_agent(
         persona: Any, config: Any, *, root_dir: Path, middleware: list[Any], **_kw: Any
     ) -> Any:
-        async def _ainvoke(messages: Any) -> Any:
+        async def _ainvoke(messages: Any, **_kwargs: Any) -> Any:
             expected = root_dir / "artifacts" / "spec.json"
             expected.parent.mkdir(parents=True, exist_ok=True)
             artifact = _valid_spec_artifact(uuid4())
@@ -529,7 +529,7 @@ async def test_engine_approval_abort_aborts_run(
     def _fake_build_agent(
         persona: Any, config: Any, *, root_dir: Path, middleware: list[Any], **_kw: Any
     ) -> Any:
-        async def _ainvoke(messages: Any) -> Any:
+        async def _ainvoke(messages: Any, **_kwargs: Any) -> Any:
             expected = root_dir / "artifacts" / "spec.json"
             expected.parent.mkdir(parents=True, exist_ok=True)
             artifact = _valid_spec_artifact(uuid4())
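
The five signature changes above all serve one purpose: the engine now calls `ainvoke` with an extra `config=` keyword, and a fake that declares only `messages` would raise `TypeError` at the call site. A minimal sketch of that difference follows; `strict_ainvoke`, `widened_ainvoke`, and `call_like_engine` are illustrative names, not code from the test suite.

```python
import asyncio
from collections.abc import Awaitable, Callable
from typing import Any


async def strict_ainvoke(messages: Any) -> Any:
    """Pre-change fake: accepts the payload only."""
    return {"messages": []}


async def widened_ainvoke(messages: Any, **_kwargs: Any) -> Any:
    """Post-change fake: silently absorbs config= and any later keywords."""
    return {"messages": []}


async def call_like_engine(fake: Callable[..., Awaitable[Any]]) -> str:
    """Invoke a fake the way the engine now does, reporting the outcome."""
    try:
        await fake(
            {"messages": [{"role": "user", "content": "envelope"}]},
            config={"configurable": {"thread_id": "run:x:phase:y"}},
        )
    except TypeError:
        # Raised while binding arguments, before the coroutine body runs.
        return "TypeError"
    return "ok"
```

Absorbing keywords with `**_kwargs` keeps these mocks indifferent to further keyword additions, at the cost of not asserting on what was passed; the dedicated wiring test covers that separately.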
@@ -0,0 +1,198 @@
+"""LangGraph AsyncPostgresSaver wiring verification (v0.2 PR #2a).
+
+Verifies two contracts:
+
+1. **Engine wiring**: `WorkflowEngine.run` opens a saver context, passes the
+   saver to `build_agent(checkpointer=...)`, and passes
+   ``config={"configurable": {"thread_id": "run:<uuid>:phase:<uuid>"}}`` to
+   ``agent.ainvoke``.
+2. **LangGraph thread isolation**: two distinct ``thread_id`` values write
+   independent checkpoint rows; the same ``thread_id`` re-opened produces the
+   previous state. Library-level guarantee, but tested once here to detect
+   future regressions if deepagents wraps the runtime.
+"""
+
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Any
+from unittest.mock import MagicMock, patch
+from uuid import uuid4
+
+import pytest
+
+from my_deepagent.artifact_schema import ArtifactSchemaRegistry
+from my_deepagent.binding import (
+    BackendAvailability,
+    PersonaConsentStore,
+)
+from my_deepagent.config import load_config
+from my_deepagent.engine import WorkflowEngine
+from my_deepagent.enums import Backend
+from my_deepagent.persistence.checkpointer import get_checkpointer_ctx
+from my_deepagent.persistence.db import Database
+from my_deepagent.persona import load_personas_from_dir
+from my_deepagent.workflow import load_workflow_yaml
+
+pytestmark = [pytest.mark.integration]
+
+_SEED_ROOT = Path(__file__).resolve().parents[2] / "docs" / "schemas"
+
+
+# ---------------------------------------------------------------------------
+# Contract 1: engine wires saver + thread_id correctly
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.asyncio
+async def test_engine_passes_saver_and_thread_id_to_agent(tmp_path: Path, pg_db_url: str) -> None:
+    """`build_agent` receives `checkpointer=saver`; `ainvoke` receives matching thread_id."""
+    captured_build: dict[str, Any] = {}
+    captured_invoke_configs: list[dict[str, Any]] = []
+
+    def fake_build_agent(*args: Any, **kwargs: Any) -> Any:
+        captured_build.update(kwargs)
+        fake_agent = MagicMock()
+
+        async def _ainvoke(
+            _payload: dict[str, Any],
+            *,
+            config: dict[str, Any] | None = None,
+        ) -> dict[str, Any]:
+            captured_invoke_configs.append(config or {})
+            # Pretend the agent wrote the expected artifact.
+            root_dir: Path = kwargs["root_dir"]
+            artifact_path = root_dir / "artifacts" / "spec.json"
+            artifact_path.parent.mkdir(parents=True, exist_ok=True)
+            artifact_path.write_text(
+                '{"runId": "00000000-0000-0000-0000-000000000000", '
+                '"workflowId": "spec-and-review", "phaseKey": "spec", '
+                '"persona": "test", "summary": "fake", "decisions": [], '
+                '"openQuestions": []}',
+                encoding="utf-8",
+            )
+            return {"messages": []}
+
+        fake_agent.ainvoke = _ainvoke
+        return fake_agent
+
+    ws_root = tmp_path / "ws"
+    ws_root.mkdir()
+    config = load_config(
+        workspace_root=ws_root,
+        data_dir=tmp_path / "data",
+        state_dir=tmp_path / "state",
+        database_url=pg_db_url,
+    )
+
+    template = load_workflow_yaml(_SEED_ROOT / "workflows" / "spec-and-review@1.yaml")
+    personas = load_personas_from_dir(_SEED_ROOT / "personas")
+    registry = ArtifactSchemaRegistry(roots=[_SEED_ROOT / "artifacts"])
+    consent = PersonaConsentStore(tmp_path / "consents.json")
+    backends = BackendAvailability(available_backends=frozenset(Backend))
+
+    db = Database(config.database_url)
+    await db.init_schema()
+
+    async def _auto_approve(_payload: dict[str, Any], _gates: list[str]) -> Any:
+        from my_deepagent.enums import ApprovalDecisionAction
+
+        return ApprovalDecisionAction.APPROVE
+
+    engine = WorkflowEngine(
+        db=db,
+        config=config,
+        persona_pool=personas,
+        artifact_registry=registry,
+        consent_store=consent,
+        available_backends=backends,
+        approval_callback=_auto_approve,
+    )
+
+    with patch("my_deepagent.engine.build_agent", side_effect=fake_build_agent):
+        try:
+            await engine.run(
+                template,
+                repo_path=tmp_path / "fake-repo",
+                base_branch="main",
+                requirements_md="test",
+            )
+        finally:
+            await db.dispose()
+
+    # Contract 1.a: build_agent received a checkpointer (not None)
+    assert "checkpointer" in captured_build
+    assert captured_build["checkpointer"] is not None, "engine must forward saver to build_agent"
+
+    # Contract 1.b: ainvoke received a config with thread_id matching the
+    # run:<uuid>:phase:<uuid> format
+    assert captured_invoke_configs, "ainvoke must have been called at least once"
+    first_config = captured_invoke_configs[0]
+    assert "configurable" in first_config
+    thread_id = first_config["configurable"].get("thread_id")
+    assert thread_id is not None, "thread_id must be set in agent.ainvoke config"
+    assert thread_id.startswith("run:"), f"unexpected thread_id format: {thread_id!r}"
+    assert ":phase:" in thread_id
+
+
+# ---------------------------------------------------------------------------
+# Contract 2: AsyncPostgresSaver thread isolation + round-trip
+# ---------------------------------------------------------------------------
+
+
+@pytest.mark.asyncio
+async def test_async_postgres_saver_round_trip_isolated_threads(pg_db_url: str) -> None:
+    """Two different thread_ids write to different rows; same thread_id reads back."""
+    thread_a = f"run:{uuid4()}:phase:{uuid4()}"
+    thread_b = f"run:{uuid4()}:phase:{uuid4()}"
+
+    # First open: setup() runs the LangGraph DDL.
+    async with get_checkpointer_ctx(pg_db_url) as saver:
+        # Verify LangGraph created its own tables alongside the alembic schema.
+        conn_url = pg_db_url.replace("postgresql+asyncpg://", "postgresql://")
+        from psycopg import connect
+
+        with connect(conn_url, autocommit=True) as conn:
+            with conn.cursor() as cur:
+                cur.execute(
+                    """
+                    SELECT tablename FROM pg_tables
+                    WHERE schemaname='public'
+                      AND tablename LIKE 'checkpoint%%'
+                    """
+                )
+                lg_tables = {row[0] for row in cur.fetchall()}
+        assert "checkpoints" in lg_tables, (
+            f"AsyncPostgresSaver did not create the `checkpoints` table; got {lg_tables}"
+        )
+
+        # Write a synthetic checkpoint to thread_a.
+        from langgraph.checkpoint.base import empty_checkpoint
+
+        ck_a = empty_checkpoint()
+        ck_a["channel_values"] = {"messages": ["hello from a"]}
+
+        # AsyncPostgresSaver requires both thread_id AND checkpoint_ns in
+        # configurable; LangGraph's prebuilt graphs default checkpoint_ns to
+        # "" so we replicate that here. new_versions advertises that the
+        # "messages" channel is at version 1. RunnableConfig is a TypedDict
+        # so we cast through Any for mypy.
+        config_a: Any = {"configurable": {"thread_id": thread_a, "checkpoint_ns": ""}}
+        await saver.aput(config_a, ck_a, {"source": "input", "step": 1}, {"messages": "1"})
+
+        # And one to thread_b
+        ck_b = empty_checkpoint()
+        ck_b["channel_values"] = {"messages": ["hello from b"]}
+        config_b: Any = {"configurable": {"thread_id": thread_b, "checkpoint_ns": ""}}
+        await saver.aput(config_b, ck_b, {"source": "input", "step": 1}, {"messages": "1"})
+
+        # Each thread must read back its own latest checkpoint and not the other's.
+        # LangGraph's internal serialization shape is opaque — we only verify
+        # the wiring guarantees thread isolation (different IDs return distinct
+        # checkpoints) and round-trip (aput → aget returns a non-None result).
+        latest_a = await saver.aget(config_a)
+        assert latest_a is not None, "thread_a checkpoint must persist across aget"
+        latest_b = await saver.aget(config_b)
+        assert latest_b is not None, "thread_b checkpoint must persist across aget"
+        # Sanity: the two checkpoint IDs are distinct (proves thread isolation).
+        assert latest_a["id"] != latest_b["id"]
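
The isolation test rewrites `postgresql+asyncpg://` to `postgresql://` before handing the URL to psycopg, because the `+driver` suffix is SQLAlchemy URL syntax that libpq does not accept. A slightly more general version of that rewrite might look like the sketch below; `to_libpq_dsn` is a hypothetical helper, not something defined in the repository, and it assumes only the two driver suffixes named in this commit.

```python
def to_libpq_dsn(url: str) -> str:
    """Strip a SQLAlchemy-style driver suffix from a Postgres URL.

    URLs that already use the plain postgresql:// scheme, and any
    non-Postgres URL, are returned unchanged.
    """
    for prefix in ("postgresql+asyncpg://", "postgresql+psycopg://"):
        if url.startswith(prefix):
            return "postgresql://" + url[len(prefix):]
    return url


if __name__ == "__main__":
    print(to_libpq_dsn("postgresql+asyncpg://user:secret@localhost:5432/app"))
```

Matching on the prefix rather than calling `str.replace` on the whole string avoids touching the same substring if it ever appeared later in the URL, for example inside a query parameter.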