"""Integration tests for crash recovery sweep (sweep_orphan_runs)."""

from __future__ import annotations

import uuid
from collections.abc import AsyncGenerator
from pathlib import Path

import pytest
import pytest_asyncio
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError

from my_deepagent.enums import RunPhaseState, RunState
from my_deepagent.persistence.db import Database
from my_deepagent.persistence.models import (
    RunEventRow,
    RunPhaseRow,
    RunRow,
    WorkflowTemplateRow,
)
from my_deepagent.recovery import SweepReport, sweep_orphan_runs
from my_deepagent.run_event import RunEventType

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

# Fixed timestamp so rows are deterministic across runs.
_NOW = "2026-05-14T00:00:00+00:00"


def _make_id() -> str:
    """Return a fresh random UUID4 as a string."""
    return str(uuid.uuid4())


def _template_row(template_id: str | None = None) -> WorkflowTemplateRow:
    """Build an unsaved workflow template row.

    Args:
        template_id: Id to use; a random UUID is generated when omitted.
            The id is reused as the template ``hash`` so each template row
            gets a distinct hash.
    """
    tid = template_id or _make_id()
    return WorkflowTemplateRow(
        id=tid,
        name="test-wf",
        version=1,
        hash=tid,
        definition={},
        created_at=_NOW,
    )


def _run_row(
    *,
    run_id: str | None = None,
    template_id: str,
    state: str = RunState.EXECUTING.value,
    repo_path: str = "/repo",
    base_branch: str = "main",
) -> RunRow:
    """Build an unsaved run row (defaults to an EXECUTING run on /repo@main)."""
    rid = run_id or _make_id()
    return RunRow(
        id=rid,
        template_id=template_id,
        template_hash="a" * 64,
        state=state,
        repo_path=repo_path,
        base_branch=base_branch,
        worktree_root="/wt",
        created_at=_NOW,
        updated_at=_NOW,
    )


def _phase_row(run_id: str, state: str = RunPhaseState.RUNNING.value) -> RunPhaseRow:
    """Build an unsaved ``spec`` phase row (seq 0, one attempt) for *run_id*."""
    return RunPhaseRow(
        id=_make_id(),
        run_id=run_id,
        phase_key="spec",
        seq=0,
        state=state,
        attempts=1,
        started_at=_NOW,
    )


async def _seed_run(db: Database, run: RunRow) -> None:
    """Persist a template for ``run.template_id``, then *run* itself.

    Two separate sessions are used so the template row is written before the
    run row that references it (presumably needed for a template FK — the
    original tests always seed in this order).
    """
    async with db.session() as s:
        s.add(_template_row(run.template_id))
    async with db.session() as s:
        s.add(run)


async def _run_failed_events(db: Database, run_id: str) -> list[RunEventRow]:
    """Return every ``run.failed`` event stored for *run_id*."""
    async with db.session() as s:
        result = await s.execute(
            select(RunEventRow)
            .where(RunEventRow.run_id == run_id)
            .where(RunEventRow.type == RunEventType.RUN_FAILED.value)
        )
        return list(result.scalars().all())


# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------


@pytest_asyncio.fixture()
async def db(tmp_path: Path) -> AsyncGenerator[Database, None]:
    """Provide a fresh file-backed SQLite database with the schema created."""
    url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
    database = Database(url)
    await database.init_schema()
    yield database
    await database.dispose()


# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_sweep_with_no_orphans_returns_empty_report(db: Database) -> None:
    """Sweep on empty DB returns SweepReport with zero counts."""
    report = await sweep_orphan_runs(db)

    assert isinstance(report, SweepReport)
    assert report.total == 0
    assert report.failed_runs == ()
    assert report.failed_phases == ()


@pytest.mark.asyncio
async def test_sweep_marks_executing_run_as_failed(db: Database) -> None:
    """A run in EXECUTING state is marked FAILED after sweep."""
    run = _run_row(template_id=_make_id(), state=RunState.EXECUTING.value)
    await _seed_run(db, run)

    report = await sweep_orphan_runs(db)
    assert len(report.failed_runs) == 1

    async with db.session() as s:
        refreshed = await s.get(RunRow, run.id)
        assert refreshed is not None
        assert refreshed.state == RunState.FAILED.value
        assert refreshed.ended_at is not None


@pytest.mark.asyncio
async def test_sweep_marks_paused_run_as_failed(db: Database) -> None:
    """A run in PAUSED state is marked FAILED after sweep."""
    run = _run_row(template_id=_make_id(), state=RunState.PAUSED.value)
    await _seed_run(db, run)

    report = await sweep_orphan_runs(db)
    assert len(report.failed_runs) == 1

    async with db.session() as s:
        refreshed = await s.get(RunRow, run.id)
        assert refreshed is not None
        assert refreshed.state == RunState.FAILED.value


@pytest.mark.asyncio
async def test_sweep_leaves_completed_run_alone(db: Database) -> None:
    """A run in COMPLETED state is NOT touched by the sweep."""
    run = _run_row(template_id=_make_id(), state=RunState.COMPLETED.value)
    await _seed_run(db, run)

    report = await sweep_orphan_runs(db)
    assert report.total == 0

    async with db.session() as s:
        refreshed = await s.get(RunRow, run.id)
        assert refreshed is not None
        assert refreshed.state == RunState.COMPLETED.value


@pytest.mark.asyncio
async def test_sweep_cascades_phase_states(db: Database) -> None:
    """Orphan phases belonging to a swept run are also marked FAILED."""
    run = _run_row(template_id=_make_id(), state=RunState.EXECUTING.value)
    await _seed_run(db, run)

    phase = _phase_row(run.id, state=RunPhaseState.RUNNING.value)
    async with db.session() as s:
        s.add(phase)

    report = await sweep_orphan_runs(db)
    assert len(report.failed_runs) == 1
    assert len(report.failed_phases) == 1

    async with db.session() as s:
        refreshed_phase = await s.get(RunPhaseRow, phase.id)
        assert refreshed_phase is not None
        assert refreshed_phase.state == RunPhaseState.FAILED.value
        assert refreshed_phase.ended_at is not None


@pytest.mark.asyncio
async def test_sweep_emits_run_failed_event(db: Database) -> None:
    """Sweep emits exactly one run.failed event per orphan run."""
    run = _run_row(template_id=_make_id(), state=RunState.EXECUTING.value)
    await _seed_run(db, run)

    await sweep_orphan_runs(db)

    events = await _run_failed_events(db, run.id)
    assert len(events) == 1
    assert events[0].payload.get("reason") == "process_restart_unrecovered"


@pytest.mark.asyncio
async def test_sweep_idempotent_no_duplicate_event(db: Database) -> None:
    """Running the sweep twice leaves exactly one run.failed event.

    The first sweep moves the run to a terminal state, so the second sweep
    finds no non-terminal runs and reports nothing. This checks end-to-end
    idempotence; it does not directly exercise any conflict handling in the
    event insert, because no second insert is attempted.
    """
    run = _run_row(template_id=_make_id(), state=RunState.EXECUTING.value)
    await _seed_run(db, run)

    # First sweep marks the run as failed.
    report1 = await sweep_orphan_runs(db)
    assert len(report1.failed_runs) == 1

    # Second sweep: no more non-terminal runs, no duplicate events.
    report2 = await sweep_orphan_runs(db)
    assert report2.total == 0

    events = await _run_failed_events(db, run.id)
    assert len(events) == 1


@pytest.mark.asyncio
async def test_sweep_frees_active_run_slot(db: Database) -> None:
    """After sweep, a second run with same (repo_path, base_branch) can be inserted.

    Without sweep: the partial unique index ux_active_run_repo_base prevents a
    second active run for the same (repo_path, base_branch). After sweep marks
    the first run FAILED, the uniqueness slot is freed and the second insert
    succeeds.
    """
    repo = "/unique-repo"
    branch = "main"
    tid1 = _make_id()
    tid2 = _make_id()

    run1 = _run_row(
        template_id=tid1,
        state=RunState.EXECUTING.value,
        repo_path=repo,
        base_branch=branch,
    )
    async with db.session() as s:
        s.add(_template_row(tid1))
        s.add(_template_row(tid2))
    async with db.session() as s:
        s.add(run1)

    # A second executing run for the same (repo, branch) must raise IntegrityError.
    run2 = _run_row(
        template_id=tid2,
        state=RunState.EXECUTING.value,
        repo_path=repo,
        base_branch=branch,
    )
    with pytest.raises(IntegrityError):
        async with db.session() as s:
            s.add(run2)

    # Sweep frees the slot.
    report = await sweep_orphan_runs(db)
    assert len(report.failed_runs) == 1

    # Now the second insert should succeed (fresh row object: run2 failed to flush).
    run3 = _run_row(
        template_id=tid2,
        state=RunState.EXECUTING.value,
        repo_path=repo,
        base_branch=branch,
    )
    async with db.session() as s:
        s.add(run3)

    async with db.session() as s:
        refreshed = await s.get(RunRow, run3.id)
        assert refreshed is not None
        assert refreshed.state == RunState.EXECUTING.value