"""WorkflowEngine: orchestrates run lifecycle, phase loop, artifact validation, approval gate.""" from __future__ import annotations import asyncio import json import logging import signal from collections.abc import AsyncIterator from contextlib import asynccontextmanager, suppress from dataclasses import dataclass from datetime import UTC, datetime from pathlib import Path from typing import Any from uuid import UUID, uuid4 from sqlalchemy import select from .artifact_schema import ArtifactSchemaRegistry from .audit import make_audit_recorder from .binding import ( BackendAvailability, Binding, BindingOverride, PersonaConsentStore, bind_personas, ) from .budget import BudgetTracker from .config import Config from .enums import ApprovalDecisionAction, ApprovalState, RunPhaseState, RunState from .errors import MyDeepAgentError from .hash import sha256 from .middleware.artifact_watcher import ArtifactWatcherMiddleware from .middleware.audit import AuditToolMiddleware from .middleware.cost import CostMiddleware from .monitoring.pricing import PricingCache from .persistence.checkpointer import get_checkpointer_ctx from .persistence.db import Database from .persistence.models import ( AgentPersonaRow, ApprovalDecisionRow, ApprovalRequestRow, ArtifactRow, LlmCallRow, RunBindingRow, RunEventRow, RunInputRow, RunPhaseRow, RunRow, WorkflowTemplateRow, ) from .persona import Persona from .run_event import RunEventType, run_idempotency_key from .session import build_agent from .workflow import WorkflowPhase, WorkflowTemplate # ApprovalCallback type: async (request_payload: dict, gates: list[str]) -> ApprovalDecisionAction ApprovalCallback = Any # Callable[[dict, list[str]], Awaitable[ApprovalDecisionAction]] _DEFAULT_PHASE_TIMEOUT_SECONDS = 300 # 5 minutes _LOG_CORRUPT_PERSONA = logging.getLogger(__name__ + ".resume") @dataclass(frozen=True) class RunResult: run_id: UUID state: RunState final_report_path: Path | None error: str | None = None class _PhaseAbortedError(Exception): def __init__(self, reason: str) -> None: self.reason = reason super().__init__(reason) class WorkflowEngine: """In-process workflow engine for v0.1.0. For each phase: build_agent -> invoke -> wait for write_file targeting expected_artifact_path -> load + jsonschema validate -> repair 1x if invalid -> approval gate -> next phase. All events appended idempotently to run_events via the (run_id, idempotency_key) UNIQUE constraint — concurrent/retry safe. """ def __init__( self, db: Database, config: Config, persona_pool: list[Persona], artifact_registry: ArtifactSchemaRegistry, consent_store: PersonaConsentStore, available_backends: BackendAvailability, approval_callback: ApprovalCallback, budget_tracker: BudgetTracker | None = None, pricing: PricingCache | None = None, checkpointer_url: str | None = None, ) -> None: self._db = db self._config = config self._personas = persona_pool self._artifacts = artifact_registry self._consent = consent_store self._backends = available_backends self._approval = approval_callback self._budget = budget_tracker self._pricing = pricing or PricingCache() self._shutdown_event: asyncio.Event = asyncio.Event() self._inflight_tasks: set[asyncio.Task[Any]] = set() # LangGraph checkpoint URL. None → falls back to config.database_url at run-time. # The saver itself is opened inside `run()` (one ctx per run, shared across phases) # and lives on `self._saver` for the duration of that run. self._checkpointer_url: str = checkpointer_url or config.database_url self._saver: Any | None = None def install_signal_handlers(self) -> None: """Attach SIGTERM/SIGINT handlers to the running event loop. Idempotent: calling twice replaces the previous handlers. Should be invoked from ``cli/run.py`` once the asyncio loop is up. On shutdown signal: in-flight ainvoke() tasks get a 30s grace, then are cancelled. """ loop = asyncio.get_running_loop() for sig in (signal.SIGTERM, signal.SIGINT): with suppress(NotImplementedError, ValueError): loop.add_signal_handler(sig, self._on_signal, sig) def _on_signal(self, sig: signal.Signals) -> None: self._shutdown_event.set() loop = asyncio.get_running_loop() loop.call_later(30.0, self._force_cancel_inflight) def _force_cancel_inflight(self) -> None: for task in list(self._inflight_tasks): if not task.done(): task.cancel() @property def shutdown_requested(self) -> bool: return self._shutdown_event.is_set() @asynccontextmanager async def _maybe_open_saver(self) -> AsyncIterator[Any | None]: """Yield an AsyncPostgresSaver for Postgres URLs; yield None for SQLite. SQLite is supported for tests only and never wires durable resume. LangGraph's AsyncPostgresSaver requires a libpq DSN; passing a SQLite URL would raise psycopg.ProgrammingError. Production runs always use Postgres (see DR-2 / DR-3), so this is purely a test-affordance shim. """ url = self._checkpointer_url if url.startswith(("postgresql://", "postgresql+asyncpg://", "postgresql+psycopg://")): async with get_checkpointer_ctx(url) as saver: yield saver else: yield None async def run( self, template: WorkflowTemplate, *, repo_path: Path, base_branch: str = "main", requirements_md: str = "", override: BindingOverride | None = None, pre_allocated_run_id: UUID | None = None, ) -> RunResult: """Start a brand-new run. Allocates a new `run_id`, binds personas, persists skeleton metadata, and dispatches to the shared `_execute_run` phase loop. For resuming an existing non-terminal run, use :meth:`resume` instead. `pre_allocated_run_id` lets the FastAPI runner pick the UUID up-front so the route can return it before the phase loop completes. """ run_id = pre_allocated_run_id if pre_allocated_run_id is not None else uuid4() worktree_root = self._config.workspace_root / str(run_id) worktree_root.mkdir(parents=True, exist_ok=True) artifacts_dir = worktree_root / "artifacts" artifacts_dir.mkdir(parents=True, exist_ok=True) bindings = bind_personas(template, self._personas, self._backends, self._consent, override) await self._persist_run_skeleton( None, run_id, template, bindings, repo_path, base_branch, worktree_root, requirements_md, ) await self._append_event(run_id, None, RunEventType.RUN_CREATED, {}) await self._append_event(run_id, None, RunEventType.RUN_STARTED, {}) return await self._execute_run(run_id, template, worktree_root, bindings) async def resume(self, run_id: UUID) -> RunResult: """Resume a non-terminal run from its first non-completed phase. Reloads worktree_root, template, and bindings from the DB — does **not** re-run `bind_personas`, so consent/pool changes since the original run do not silently shift the binding. Phases whose `RunPhaseRow.state` is already ``completed`` are skipped; the rest re-execute and (when a LangGraph saver is wired) replay deepagents from the last checkpoint for that phase's thread_id. Raises: MyDeepAgentError: if the run is missing, terminal, or its bindings / template metadata cannot be reloaded. """ run_row = await self._get_run_or_raise(run_id) if run_row.state in { RunState.COMPLETED.value, RunState.FAILED.value, RunState.ABORTED.value, }: raise MyDeepAgentError.human_required( "run_already_terminal", message=( f"run {run_id} is already {run_row.state}; start a fresh run " f"with `mydeepagent run`" ), ) worktree_root = Path(run_row.worktree_root) template = await self._reload_template(run_row.template_id) bindings = await self._reload_bindings(run_id) if not bindings: raise MyDeepAgentError.human_required( "run_metadata_missing", message=( f"run {run_id} has no binding rows; cannot resume — start a fresh run instead" ), ) await self._append_event(run_id, None, RunEventType.RUN_RESUMED, {}) return await self._execute_run(run_id, template, worktree_root, bindings) async def _execute_run( self, run_id: UUID, template: WorkflowTemplate, worktree_root: Path, bindings: dict[str, Binding], ) -> RunResult: """Shared phase loop used by both `run` (new) and `resume`.""" await self._set_run_state(run_id, RunState.EXECUTING) # Open the LangGraph AsyncPostgresSaver once per run; all phases share it. # Opening per-phase would re-connect to Postgres 5+ times per run for no # gain — checkpoints are isolated by `thread_id` not by saver instance. # SQLite URLs (test-only) skip the saver entirely — deepagents accepts # checkpointer=None and runs without resume support. async with self._maybe_open_saver() as saver: self._saver = saver completed_keys = await self._get_completed_phase_keys(run_id) try: for phase_def in template.phases: if phase_def.key in completed_keys: await self._append_event( run_id, None, RunEventType.PHASE_SKIPPED, {"phase_key": phase_def.key, "reason": "already_completed"}, ) continue role_binding = bindings[phase_def.role] await self._run_phase(run_id, worktree_root, template, phase_def, role_binding) await self._set_run_state(run_id, RunState.COMPLETED) await self._append_event(run_id, None, RunEventType.RUN_COMPLETED, {}) report_path = await self._compose_final_report( run_id, worktree_root, RunState.COMPLETED ) return RunResult( run_id=run_id, state=RunState.COMPLETED, final_report_path=report_path ) except _PhaseAbortedError as e: await self._set_run_state(run_id, RunState.ABORTED) await self._append_event( run_id, None, RunEventType.RUN_ABORTED, {"reason": e.reason} ) report_path = await self._compose_final_report( run_id, worktree_root, RunState.ABORTED, error=e.reason ) return RunResult( run_id=run_id, state=RunState.ABORTED, final_report_path=report_path, error=e.reason, ) except MyDeepAgentError as e: await self._set_run_state(run_id, RunState.FAILED) await self._append_event( run_id, None, RunEventType.RUN_FAILED, {"code": e.code, "message": str(e)} ) report_path = await self._compose_final_report( run_id, worktree_root, RunState.FAILED, error=str(e) ) return RunResult( run_id=run_id, state=RunState.FAILED, final_report_path=report_path, error=str(e), ) finally: self._saver = None # ------------------------------------------------------------------ # Phase execution # ------------------------------------------------------------------ async def _run_phase( self, run_id: UUID, worktree_root: Path, template: WorkflowTemplate, phase_def: WorkflowPhase, binding: Binding, ) -> None: if self.shutdown_requested: await self._append_event(run_id, None, RunEventType.RUN_PAUSED, {"reason": "shutdown"}) await self._set_run_state(run_id, RunState.PAUSED) raise _PhaseAbortedError(reason="shutdown signal received") phase_id = await self._ensure_phase_row(run_id, phase_def) await self._set_phase_state(phase_id, RunPhaseState.RUNNING) await self._append_event( run_id, phase_id, RunEventType.PHASE_STARTED, {"phase_key": phase_def.key} ) # Phases without an expected artifact complete immediately if phase_def.expected_artifact is None: await self._set_phase_state(phase_id, RunPhaseState.COMPLETED) await self._append_event(run_id, phase_id, RunEventType.PHASE_COMPLETED, {}) return expected_path = (worktree_root / phase_def.expected_artifact.path).resolve() expected_path.parent.mkdir(parents=True, exist_ok=True) # Repair loop: max 2 attempts for attempt in range(1, 3): validated = await self._run_agent_and_validate( run_id, phase_id, worktree_root, phase_def, binding, expected_path, attempt ) if validated: break # validated=False means: invalid/timeout + still have budget for retry # on attempt 2, _run_agent_and_validate raises instead of returning False await self._run_approval_gate(run_id, phase_id, phase_def, expected_path) await self._set_phase_state(phase_id, RunPhaseState.COMPLETED) await self._append_event(run_id, phase_id, RunEventType.PHASE_COMPLETED, {}) async def _run_agent_and_validate( self, run_id: UUID, phase_id: UUID, worktree_root: Path, phase_def: WorkflowPhase, binding: Binding, expected_path: Path, attempt: int, ) -> bool: """Invoke agent for one attempt and validate artifact. Returns True on success. Returns False when attempt < 2 and artifact is missing/invalid (caller retries). Raises MyDeepAgentError on final failure (attempt >= 2). """ written = await self._invoke_agent_until_artifact( run_id, phase_id, worktree_root, phase_def, binding, expected_path, attempt=attempt ) if not written: await self._append_event(run_id, phase_id, RunEventType.ARTIFACT_TIMEOUT, {}) if attempt >= 2: await self._set_phase_state(phase_id, RunPhaseState.FAILED) await self._append_event( run_id, phase_id, RunEventType.PHASE_FAILED, {"reason": "artifact_timeout_exhausted"}, ) raise MyDeepAgentError.human_required( "artifact_timeout_exhausted", message=( f"phase '{phase_def.key}' did not produce expected artifact " f"after {attempt} attempts" ), ) return False # Validate the written artifact await self._set_phase_state(phase_id, RunPhaseState.VALIDATING) assert phase_def.expected_artifact is not None schema_id = phase_def.expected_artifact.schema_id try: data = json.loads(expected_path.read_text(encoding="utf-8")) except (OSError, json.JSONDecodeError) as exc: await self._append_event( run_id, phase_id, RunEventType.ARTIFACT_INVALID, {"errors": [{"message": str(exc)}]}, ) if attempt >= 2: raise MyDeepAgentError.human_required( "artifact_invalid_after_repair", message=str(exc), cause=exc, ) from exc await self._append_event(run_id, phase_id, RunEventType.PROMPT_REPAIRED, {}) return False result = self._artifacts.validate(schema_id, data) if result.ok: await self._persist_artifact(run_id, phase_id, expected_path, schema_id, valid=True) await self._append_event(run_id, phase_id, RunEventType.ARTIFACT_VALIDATED, {}) return True error_payload = [{"path": f.path, "message": f.message} for f in result.errors[:5]] await self._persist_artifact( run_id, phase_id, expected_path, schema_id, valid=False, errors=list(result.errors), ) await self._append_event( run_id, phase_id, RunEventType.ARTIFACT_INVALID, {"errors": error_payload} ) if attempt >= 2: await self._set_phase_state(phase_id, RunPhaseState.FAILED) await self._append_event( run_id, phase_id, RunEventType.PHASE_FAILED, {"reason": "artifact_invalid_after_repair"}, ) raise MyDeepAgentError.human_required( "artifact_invalid_after_repair", message=f"phase '{phase_def.key}' artifact failed validation after repair", ) await self._append_event(run_id, phase_id, RunEventType.PROMPT_REPAIRED, {}) return False async def _run_approval_gate( self, run_id: UUID, phase_id: UUID, phase_def: WorkflowPhase, expected_path: Path, ) -> None: """Run the approval gate if gates are configured. Raises on reject/abort.""" if not phase_def.gates: return await self._set_phase_state(phase_id, RunPhaseState.AWAITING_APPROVAL) decision = await self._request_approval(run_id, phase_id, phase_def, expected_path) if decision == ApprovalDecisionAction.ABORT: raise _PhaseAbortedError(reason=f"aborted at phase {phase_def.key}") if decision != ApprovalDecisionAction.APPROVE: await self._set_phase_state(phase_id, RunPhaseState.FAILED) await self._append_event( run_id, phase_id, RunEventType.PHASE_FAILED, {"reason": decision.value} ) raise MyDeepAgentError.human_required( "approval_rejected", message=f"phase '{phase_def.key}' approval was {decision.value}", ) async def _invoke_agent_until_artifact( self, run_id: UUID, phase_id: UUID, worktree_root: Path, phase_def: WorkflowPhase, binding: Binding, expected_path: Path, attempt: int, ) -> bool: """Build agent + invoke + return True if expected_path was written, False on timeout.""" written_paths: list[str] = [] async def _on_written(path: str, _content: str) -> None: written_paths.append(path) watcher = ArtifactWatcherMiddleware(expected_path, _on_written) cost_mw = CostMiddleware( pricing=self._pricing, model_name=binding.persona.model, run_id=run_id, phase_id=phase_id, persona_name=binding.persona.name, budget_tracker=self._budget, recorder=self._record_llm_call, ) audit_mw = AuditToolMiddleware( run_id=run_id, phase_id=phase_id, file_recorder=make_audit_recorder(self._config.state_dir), ) agent = build_agent( binding.persona, self._config, root_dir=worktree_root, middleware=[watcher, cost_mw, audit_mw], checkpointer=self._saver, ) envelope = self._build_envelope(run_id, phase_id, phase_def, attempt, expected_path) await self._append_event( run_id, phase_id, RunEventType.ARTIFACT_EXPECTED, {"path": str(expected_path)} ) event_type = RunEventType.PROMPT_REPAIRED if attempt > 1 else RunEventType.PROMPT_SENT await self._append_event(run_id, phase_id, event_type, {"attempt": attempt}) # thread_id matches the format already used by LlmCallRow.thread_id # (engine.py _record_llm_call) so a single namespace covers both # cost ledger and LangGraph checkpoint replay. thread_id = f"run:{run_id}:phase:{phase_id}" timeout = float(phase_def.timeout_seconds or _DEFAULT_PHASE_TIMEOUT_SECONDS) try: invoke_task: asyncio.Task[Any] = asyncio.create_task( agent.ainvoke( {"messages": [{"role": "user", "content": envelope}]}, config={"configurable": {"thread_id": thread_id}}, ) ) self._inflight_tasks.add(invoke_task) try: await asyncio.wait_for(asyncio.shield(invoke_task), timeout=timeout) except TimeoutError: pass finally: self._inflight_tasks.discard(invoke_task) except asyncio.CancelledError: pass return expected_path.is_file() def _build_envelope( self, run_id: UUID, phase_id: UUID, phase_def: WorkflowPhase, attempt: int, expected_path: Path, ) -> str: artifact = phase_def.expected_artifact assert artifact is not None try: schema_def = self._artifacts.load(artifact.schema_id) schema_inline = json.dumps(schema_def, indent=2, ensure_ascii=False) except (MyDeepAgentError, AttributeError): # AttributeError covers test scaffolding that instantiates the engine # via __new__ without wiring _artifacts; production paths always have it. schema_inline = "(schema not available)" repair_note = ( "\n\n[REPAIR ATTEMPT]\n" "Your previous artifact did not validate against the JSON Schema below. " "Re-read the schema carefully and emit a corrected JSON object that satisfies " "every `required` field and respects all `enum`, `type`, `minLength`, and " "`additionalProperties: false` constraints." if attempt > 1 else "" ) return ( f"MYDEEPAGENT_PROMPT_BEGIN {phase_id}\n" f"Run: {run_id}\n" f"Phase: {phase_def.key}\n" f"Attempt: {attempt}\n" f"Expected artifact path: {expected_path}\n" f"Expected schema id: {artifact.schema_id}\n" f"\n" f"JSON Schema 2020-12 for this artifact (you MUST satisfy it exactly):\n" f"```json\n{schema_inline}\n```\n" f"\n" f"Use the `write_file` tool to write a JSON object that matches the schema " f"to the exact path `{expected_path}`. The file must parse as valid JSON.\n" f"\n" f"Instructions:\n" f"{phase_def.instructions}" f"{repair_note}\n" f"MYDEEPAGENT_PROMPT_END {phase_id}" ) # ------------------------------------------------------------------ # Approval gate # ------------------------------------------------------------------ async def _request_approval( self, run_id: UUID, phase_id: UUID, phase_def: WorkflowPhase, artifact_path: Path, ) -> ApprovalDecisionAction: request_id = uuid4() idem_key = f"{phase_def.key}:{artifact_path.name}" payload: dict[str, Any] = { "phase_key": phase_def.key, "artifact_path": str(artifact_path), "gates": list(phase_def.gates), } async with self._db.session() as s: s.add( ApprovalRequestRow( id=str(request_id), run_id=str(run_id), phase_id=str(phase_id), gate_key=phase_def.gates[0] if phase_def.gates else "default", state=ApprovalState.PENDING.value, idempotency_key=idem_key, payload=payload, created_at=_now_iso(), ) ) await self._append_event( run_id, phase_id, RunEventType.APPROVAL_REQUESTED, {"request_id": str(request_id)}, ) decision: ApprovalDecisionAction = await self._approval(payload, list(phase_def.gates)) async with self._db.session() as s: s.add( ApprovalDecisionRow( id=str(uuid4()), approval_request_id=str(request_id), action=decision.value, decided_at=_now_iso(), idempotency_key=f"{idem_key}:{decision.value}", ) ) await self._append_event( run_id, phase_id, RunEventType.APPROVAL_RESOLVED, {"action": decision.value} ) return decision # ------------------------------------------------------------------ # Final report # ------------------------------------------------------------------ async def _compose_final_report( self, run_id: UUID, worktree_root: Path, status: RunState, error: str | None = None, ) -> Path: worktree_root.mkdir(parents=True, exist_ok=True) async with self._db.session() as s: run = await s.get(RunRow, str(run_id)) phase_rows = list( (await s.execute(select(RunPhaseRow).where(RunPhaseRow.run_id == str(run_id)))) .scalars() .all() ) artifact_rows = list( (await s.execute(select(ArtifactRow).where(ArtifactRow.run_id == str(run_id)))) .scalars() .all() ) event_rows = list( ( await s.execute( select(RunEventRow) .where(RunEventRow.run_id == str(run_id)) .order_by(RunEventRow.seq.desc()) .limit(20) ) ) .scalars() .all() ) report: dict[str, Any] = { "runId": str(run_id), "templateHash": run.template_hash if run else "", "status": status.value, "phases": [ { "key": p.phase_key, "state": p.state, "started_at": p.started_at, "ended_at": p.ended_at, "attempts": p.attempts, } for p in phase_rows ], "artifacts": [ {"path": a.path, "schema": a.schema_id, "hash": a.hash} for a in artifact_rows ], "events": [{"seq": e.seq, "type": e.type, "ts": e.ts} for e in reversed(event_rows)], "unresolved": [], "endedAt": _now_iso(), "error": error, } json_path = worktree_root / f"{run_id}.report.json" md_path = worktree_root / f"{run_id}.report.md" json_path.write_text(json.dumps(report, indent=2, ensure_ascii=False), encoding="utf-8") md_path.write_text(_render_report_md(report), encoding="utf-8") return json_path # ------------------------------------------------------------------ # Persistence helpers # ------------------------------------------------------------------ async def _record_llm_call(self, record: dict[str, Any]) -> None: """CostMiddleware recorder: persist one LlmCallRow per model call. Fills every NOT NULL column of LlmCallRow. Per-input/output cost is computed from the same PricingCache that the middleware already consulted, so the ledger and the row stay consistent. """ in_tokens = int(record.get("input_tokens") or 0) out_tokens = int(record.get("output_tokens") or 0) model = str(record.get("model") or "") # Reproduce per-direction cost from the cached price. price = self._pricing.get(model) if self._pricing is not None else None if price is not None: cost_input = (in_tokens / 1000.0) * price.input_per_1k_usd cost_output = (out_tokens / 1000.0) * price.output_per_1k_usd else: cost_input = 0.0 cost_output = 0.0 cost_total = float(record.get("cost_usd_total") or (cost_input + cost_output)) run_id_val = record.get("run_id") phase_id_val = record.get("phase_id") session_id_val = record.get("interactive_session_id") thread_id = ( f"run:{run_id_val}:phase:{phase_id_val}" if run_id_val is not None else f"session:{session_id_val}" ) persona_name = str(record.get("persona_name") or "") async with self._db.session() as s: s.add( LlmCallRow( run_id=(str(run_id_val) if run_id_val is not None else None), phase_id=(str(phase_id_val) if phase_id_val is not None else None), interactive_session_id=( str(session_id_val) if session_id_val is not None else None ), thread_id=thread_id, persona_name=persona_name, persona_version=1, model=model, role="main", turn_index=0, input_tokens=in_tokens, output_tokens=out_tokens, cached_tokens=0, reasoning_tokens=0, cost_usd_input=cost_input, cost_usd_output=cost_output, cost_usd_total=cost_total, latency_ms=int(record.get("latency_ms") or 0), status=str(record.get("status") or "ok"), error_code=record.get("error_code"), request_id=None, ts=_now_iso(), ) ) try: await s.commit() except Exception: await s.rollback() async def _persist_run_skeleton( self, _unused_session: Any, # kept for caller compatibility — we open own sessions run_id: UUID, template: WorkflowTemplate, bindings: dict[str, Binding], repo_path: Path, base_branch: str, worktree_root: Path, requirements_md: str, ) -> None: template_hash = template.compute_hash() now = _now_iso() # --- Phase 1: upsert FK targets (committed separately to satisfy FK ordering) --- template_id = uuid4() async with self._db.session() as s: existing_tpl = ( await s.execute( select(WorkflowTemplateRow).where(WorkflowTemplateRow.hash == template_hash) ) ).scalar_one_or_none() if existing_tpl is None: s.add( WorkflowTemplateRow( id=str(template_id), name=template.name, version=template.version, hash=template_hash, definition=template.model_dump(by_alias=True), created_at=now, ) ) else: template_id = UUID(existing_tpl.id) persona_ids: dict[str, UUID] = {} for role_id, binding in bindings.items(): persona_hash = binding.persona.compute_hash() async with self._db.session() as s: existing_persona = ( await s.execute( select(AgentPersonaRow).where(AgentPersonaRow.hash == persona_hash) ) ).scalar_one_or_none() if existing_persona is None: persona_id = uuid4() s.add( AgentPersonaRow( id=str(persona_id), name=binding.persona.name, version=binding.persona.version, hash=persona_hash, definition=binding.persona.model_dump(), created_at=now, ) ) else: persona_id = UUID(existing_persona.id) persona_ids[role_id] = persona_id # --- Phase 2: insert RunRow (FK: workflow_templates — already committed above) --- async with self._db.session() as s: s.add( RunRow( id=str(run_id), template_id=str(template_id), template_hash=template_hash, state=RunState.CREATED.value, repo_path=str(repo_path), base_branch=base_branch, worktree_root=str(worktree_root), created_at=now, updated_at=now, ) ) # --- Phase 3: insert RunInputRow + RunBindingRow (FK: runs — now committed) --- async with self._db.session() as s: s.add( RunInputRow( id=str(uuid4()), run_id=str(run_id), requirements_md=requirements_md, objective={}, extra={}, input_hash=sha256( {"requirements": requirements_md, "template_hash": template_hash} ), ) ) for role_id, binding in bindings.items(): persona_hash = binding.persona.compute_hash() s.add( RunBindingRow( id=str(uuid4()), run_id=str(run_id), role_id=role_id, persona_id=str(persona_ids[role_id]), persona_hash=persona_hash, backend=binding.persona.backend.value, binding_hash=binding.binding_hash, ) ) async def _ensure_phase_row(self, run_id: UUID, phase_def: WorkflowPhase) -> UUID: async with self._db.session() as s: existing = ( await s.execute( select(RunPhaseRow).where( RunPhaseRow.run_id == str(run_id), RunPhaseRow.phase_key == phase_def.key, ) ) ).scalar_one_or_none() if existing is not None: return UUID(existing.id) phase_id = uuid4() existing_count = len( ( await s.execute(select(RunPhaseRow).where(RunPhaseRow.run_id == str(run_id))) ).all() ) s.add( RunPhaseRow( id=str(phase_id), run_id=str(run_id), phase_key=phase_def.key, seq=existing_count, state=RunPhaseState.PENDING.value, attempts=0, started_at=_now_iso(), ) ) return phase_id async def _set_phase_state(self, phase_id: UUID, state: RunPhaseState) -> None: async with self._db.session() as s: row = await s.get(RunPhaseRow, str(phase_id)) if row is not None: row.state = state.value if state in ( RunPhaseState.COMPLETED, RunPhaseState.FAILED, RunPhaseState.SKIPPED, ): row.ended_at = _now_iso() async def _set_run_state(self, run_id: UUID, state: RunState) -> None: async with self._db.session() as s: row = await s.get(RunRow, str(run_id)) if row is not None: row.state = state.value row.updated_at = _now_iso() if state in (RunState.COMPLETED, RunState.FAILED, RunState.ABORTED): row.ended_at = _now_iso() async def _append_event( self, run_id: UUID, phase_id: UUID | None, event_type: RunEventType, payload: dict[str, Any], ) -> None: idem_extra = { k: str(v) for k, v in payload.items() if k in ("phase_key", "attempt", "request_id", "action", "code") } idem = run_idempotency_key(event_type, run_id, **idem_extra) async with self._db.session() as s: existing_count = len( ( await s.execute(select(RunEventRow).where(RunEventRow.run_id == str(run_id))) ).all() ) s.add( RunEventRow( run_id=str(run_id), phase_id=str(phase_id) if phase_id is not None else None, seq=existing_count + 1, type=event_type.value, payload=payload, idempotency_key=idem, ts=_now_iso(), ) ) try: await s.flush() except Exception: await s.rollback() async def _persist_artifact( self, run_id: UUID, phase_id: UUID, path: Path, schema_id: str, *, valid: bool, errors: list[Any] | None = None, ) -> None: try: content = path.read_bytes() except OSError: return artifact_hash = sha256({"bytes_len": len(content), "hex_prefix": content[:64].hex()}) async with self._db.session() as s: s.add( ArtifactRow( id=str(uuid4()), run_id=str(run_id), phase_id=str(phase_id), path=str(path), schema_id=schema_id, hash=artifact_hash, valid=valid, validation_error=( [{"path": f.path, "message": f.message} for f in errors] if errors else None ), created_at=_now_iso(), ) ) try: await s.flush() except Exception: await s.rollback() # ------------------------------------------------------------------ # Resume helpers (used by `resume` to rehydrate state from DB) # ------------------------------------------------------------------ async def _get_run_or_raise(self, run_id: UUID) -> RunRow: async with self._db.session() as s: row = await s.get(RunRow, str(run_id)) if row is None: raise MyDeepAgentError.human_required( "run_not_found", message=f"run {run_id} not found in DB", ) return row async def _reload_template(self, template_id: str) -> WorkflowTemplate: async with self._db.session() as s: row = await s.get(WorkflowTemplateRow, template_id) if row is None: raise MyDeepAgentError.fatal( "template_load_failed", message=f"workflow_templates row {template_id} not found", ) try: return WorkflowTemplate.model_validate(row.definition) except Exception as e: raise MyDeepAgentError.fatal( "template_load_failed", message=f"workflow_templates.definition for {template_id} is malformed: {e}", ) from e async def _reload_bindings(self, run_id: UUID) -> dict[str, Binding]: """Rebuild the `{role_id: Binding}` dict from `run_bindings` + `agent_personas`. Empty result means the run was never fully persisted — caller raises `run_metadata_missing`. We do NOT re-run `bind_personas` here on purpose: consent / pool state could have shifted since the original run. """ from .binding import Binding as _Binding # local import to avoid cycle hint async with self._db.session() as s: binding_rows = ( (await s.execute(select(RunBindingRow).where(RunBindingRow.run_id == str(run_id)))) .scalars() .all() ) persona_rows: dict[str, AgentPersonaRow] = {} for br in binding_rows: pr = await s.get(AgentPersonaRow, br.persona_id) if pr is not None: persona_rows[br.persona_id] = pr out: dict[str, Binding] = {} for br in binding_rows: pr = persona_rows.get(br.persona_id) if pr is None: continue try: persona = Persona.model_validate(pr.definition) except Exception as e: # Corrupt persona JSON: skip the binding; an empty bindings dict # surfaces as `run_metadata_missing` in `resume`. _LOG_CORRUPT_PERSONA.warning("corrupt persona row %s during resume: %s", pr.id, e) continue out[br.role_id] = _Binding( role_id=br.role_id, persona=persona, binding_hash=br.binding_hash ) return out async def _get_completed_phase_keys(self, run_id: UUID) -> set[str]: """Return the set of phase_keys that already reached `completed` state.""" async with self._db.session() as s: rows = ( ( await s.execute( select(RunPhaseRow.phase_key) .where(RunPhaseRow.run_id == str(run_id)) .where(RunPhaseRow.state == RunPhaseState.COMPLETED.value) ) ) .scalars() .all() ) return set(rows) # ------------------------------------------------------------------ # Module-level helpers # ------------------------------------------------------------------ def _now_iso() -> str: return datetime.now(UTC).isoformat(timespec="seconds") def _render_report_md(report: dict[str, Any]) -> str: lines: list[str] = [ f"# Run {report['runId']}", f"**Status**: {report['status']}", f"**Template hash**: `{report['templateHash']}`", f"**Ended at**: {report['endedAt']}", "", "## Phases", ] for p in report["phases"]: lines.append(f"- **{p['key']}** — state={p['state']}, attempts={p['attempts']}") lines.append("\n## Artifacts") for a in report["artifacts"]: lines.append(f"- `{a['path']}` (schema={a['schema']}, hash={a['hash'][:16]}...)") if report.get("error"): lines += ["", "## Error", str(report["error"])] return "\n".join(lines) + "\n"