"""Integration tests for src/my_deepagent/persistence/ (DB engine + ORM models)."""

from __future__ import annotations

import subprocess
import sys
import uuid
from collections.abc import AsyncIterator
from pathlib import Path
from typing import Any

import pytest
import pytest_asyncio
from sqlalchemy import text
from sqlalchemy.exc import IntegrityError

from my_deepagent.persistence.db import Database
from my_deepagent.persistence.models import (
    AgentPersonaRow,
    RunEventRow,
    RunInputRow,
    RunPhaseRow,
    RunRow,
    WorkflowTemplateRow,
)

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

# Fixed ISO-8601 timestamp used for every created_at / updated_at / ts column
# so rows are deterministic across test runs.
_NOW = "2026-05-15T00:00:00+00:00"


def _make_id() -> str:
    """Return a fresh random UUID4 rendered as a string."""
    return str(uuid.uuid4())


def _workflow_template_row(template_id: str) -> WorkflowTemplateRow:
    """Return a WorkflowTemplateRow that satisfies the runs.template_id FK.

    Args:
        template_id: Primary key for the template. It is reused as the
            ``hash`` value so that two templates built from different ids
            never share a hash.
    """
    return WorkflowTemplateRow(
        id=template_id,
        name="test-wf",
        version=1,
        hash=template_id,  # unique per invocation
        definition={},
        created_at=_NOW,
    )


def _run_row(run_id: str | None = None, template_id: str | None = None) -> RunRow:
    """Return a RunRow in state ``"pending"`` with fixed repo/branch values.

    Args:
        run_id: Primary key for the run; a fresh UUID is generated when it is
            ``None`` (or empty).
        template_id: Value for ``template_id``; a fresh UUID is generated when
            it is ``None`` (or empty). Callers that actually insert the row
            pass the id of a template they insert first.

    Every row shares ``repo_path="/repo"`` and ``base_branch="main"``; the
    active-run unique-index tests in this file rely on that.
    """
    rid = run_id or _make_id()
    tid = template_id or _make_id()
    return RunRow(
        id=rid,
        template_id=tid,
        template_hash="a" * 64,
        state="pending",
        repo_path="/repo",
        base_branch="main",
        worktree_root="/wt",
        created_at=_NOW,
        updated_at=_NOW,
    )


# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------


@pytest.fixture()
def db_url(tmp_path: Path) -> str:
    """Return an aiosqlite URL pointing at ``test.db`` inside pytest's tmp_path."""
    return f"sqlite+aiosqlite:///{tmp_path}/test.db"


@pytest_asyncio.fixture()
async def db(db_url: str) -> AsyncIterator[Database]:
    """Yield a Database whose schema has been initialised.

    The fixture is an async generator, so it is annotated as an
    ``AsyncIterator`` rather than as the yielded type. Code after the
    ``yield`` is the teardown: the database is disposed once the test ends.
    """
    database = Database(db_url)
    await database.init_schema()
    yield database
    await database.dispose()


# ---------------------------------------------------------------------------
# A.1: All 18 tables exist after init_schema
# ---------------------------------------------------------------------------

# Table names that must be present after schema creation. Shared by the
# init_schema test and the alembic upgrade test.
EXPECTED_TABLES = {
    "workflow_templates",
    "agent_personas",
    "runs",
    "run_inputs",
    "run_bindings",
    "run_phases",
    "run_events",
    "approval_requests",
    "approval_decisions",
    "artifacts",
    "interactive_sessions",
    "tool_calls",
    "llm_calls",
    "model_pricing",
    "budget_ledger",
    "persona_consents",
    "phase_feedback",
    "run_commands",
}


@pytest.mark.asyncio
async def test_init_schema_creates_all_tables(db: Database) -> None:
    """Every name in EXPECTED_TABLES must be listed in sqlite_master."""
    list_tables = text("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name")
    async with db.session() as session:
        listing = await session.execute(list_tables)
        present = {name for (name,) in listing.fetchall()}

    # alembic's bookkeeping table is not part of the application schema.
    present.discard("alembic_version")
    missing = EXPECTED_TABLES - present
    assert not missing, f"Missing tables: {missing}"


# ---------------------------------------------------------------------------
# A.2: WAL mode active
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_wal_mode_active(db: Database) -> None:
    """PRAGMA journal_mode must report 'wal' on a session from the fixture DB."""
    pragma = text("PRAGMA journal_mode")
    async with db.session() as session:
        journal_mode = (await session.execute(pragma)).scalar()
    assert journal_mode == "wal", f"Expected 'wal', got {journal_mode!r}"


# ---------------------------------------------------------------------------
# A.3: busy_timeout active
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_busy_timeout_active(db: Database) -> None:
    """PRAGMA busy_timeout must report 5000 on a session from the fixture DB."""
    pragma = text("PRAGMA busy_timeout")
    async with db.session() as session:
        busy_timeout = (await session.execute(pragma)).scalar()
    assert busy_timeout == 5000, f"Expected 5000, got {busy_timeout!r}"


# ---------------------------------------------------------------------------
# A.4: foreign_keys active
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_foreign_keys_active(db: Database) -> None:
    """foreign_keys PRAGMA must return 1."""
    async with db.session() as session:
        result = await session.execute(text("PRAGMA foreign_keys"))
        fk = result.scalar()
    assert fk == 1, f"Expected 1, got {fk!r}"


# ---------------------------------------------------------------------------
# A.5: basic insert + select round-trip
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_run_row_insert_and_select(db: Database) -> None:
    """RunRow insert then SELECT must return the same state."""
    rid = _make_id()
    tid = _make_id()
    template = _workflow_template_row(tid)
    run = _run_row(rid, template_id=tid)

    async with db.session() as session:
        session.add(template)
        await session.flush()  # persist template before run references it
        session.add(run)

    # Read back through a second session so the row comes from the database.
    async with db.session() as session:
        fetched = await session.get(RunRow, rid)
        assert fetched is not None
        assert fetched.id == rid
        assert fetched.state == "pending"


@pytest.mark.asyncio
async def test_agent_persona_row_insert_and_select(db: Database) -> None:
    """AgentPersonaRow insert then SELECT must return the same record."""
    persona_id = _make_id()
    persona = AgentPersonaRow(
        id=persona_id,
        name="test-persona",
        version=1,
        hash="b" * 64,
        definition={"model": "test"},
        created_at=_NOW,
    )

    async with db.session() as session:
        session.add(persona)

    async with db.session() as session:
        fetched = await session.get(AgentPersonaRow, persona_id)
        assert fetched is not None
        assert fetched.name == "test-persona"
        assert fetched.version == 1


# ---------------------------------------------------------------------------
# A.6: UNIQUE constraint — workflow_templates.hash duplicate
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_workflow_template_hash_unique_constraint(db: Database) -> None:
    """Inserting two WorkflowTemplateRows with the same hash must raise IntegrityError."""

    def make_template(tid: str) -> WorkflowTemplateRow:
        """Build a template with the given id and a fixed, shared hash."""
        return WorkflowTemplateRow(
            id=tid,
            name="my-wf",
            version=1,
            hash="c" * 64,  # same hash for both
            definition={},
            created_at=_NOW,
        )

    t1 = make_template(_make_id())
    async with db.session() as session:
        session.add(t1)

    t2 = make_template(_make_id())
    # The error is expected no later than when the session context exits.
    # NOTE(review): presumably db.session() flushes/commits on exit — it is
    # defined outside this file.
    with pytest.raises(IntegrityError):
        async with db.session() as session:
            session.add(t2)


# ---------------------------------------------------------------------------
# A.7: FK CASCADE — RunRow delete cascades to RunInputRow
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_fk_cascade_run_delete_cascades_run_input(db: Database) -> None:
    """Deleting a RunRow must cascade-delete its RunInputRow."""
    rid = _make_id()
    tid = _make_id()
    template = _workflow_template_row(tid)
    run = _run_row(rid, template_id=tid)
    inp = RunInputRow(
        id=_make_id(),
        run_id=rid,
        requirements_md="# Requirements",
        objective={"goal": "test"},
        extra={},
        input_hash="d" * 64,
    )

    # Insert parent and child in the same transaction so FK is satisfied.
    async with db.session() as session:
        session.add(template)
        await session.flush()  # persist template before run references it
        session.add(run)
        await session.flush()  # persist run before inp references it
        session.add(inp)

    async with db.session() as session:
        fetched_run = await session.get(RunRow, rid)
        assert fetched_run is not None
        await session.delete(fetched_run)

    # Check the child table with raw SQL rather than through the ORM.
    async with db.session() as session:
        result = await session.execute(
            text("SELECT id FROM run_inputs WHERE run_id = :rid"),
            {"rid": rid},
        )
        rows = result.fetchall()
    assert rows == [], f"Expected cascade delete of run_inputs, got {rows}"


# ---------------------------------------------------------------------------
# A.8: JSON column round-trip
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_json_column_round_trip(db: Database) -> None:
    """RunEventRow.payload nested dict must survive DB round-trip intact."""
    import json

    rid = _make_id()
    tid = _make_id()
    template = _workflow_template_row(tid)
    run = _run_row(rid, template_id=tid)
    payload: dict[str, Any] = {
        "nested": {"list": [1, 2, 3], "flag": True},
        "msg": "hello",
    }
    event = RunEventRow(
        run_id=rid,
        seq=1,
        type="phase_started",
        payload=payload,
        idempotency_key="idem-1",
        ts=_NOW,
    )

    async with db.session() as session:
        session.add(template)
        await session.flush()  # persist template before run references it
        session.add(run)
        await session.flush()  # persist run before event references it
        session.add(event)

    async with db.session() as session:
        result = await session.execute(
            text("SELECT payload FROM run_events WHERE run_id = :rid"), {"rid": rid}
        )
        raw = result.scalar()

    # A raw SELECT may hand the column back as JSON text; decode it only then.
    restored = json.loads(raw) if isinstance(raw, str) else raw
    assert restored == payload


# ---------------------------------------------------------------------------
# A.9: UUID string column round-trip
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_uuid_column_round_trip(db: Database) -> None:
    """UUID primary key stored as string must compare equal after retrieval."""
    expected_id = str(uuid.uuid4())
    tid = _make_id()
    template = _workflow_template_row(tid)
    run = RunRow(
        id=expected_id,
        template_id=tid,
        template_hash="e" * 64,
        state="running",
        repo_path="/r",
        base_branch="main",
        worktree_root="/w",
        created_at=_NOW,
        updated_at=_NOW,
    )

    async with db.session() as session:
        session.add(template)
        await session.flush()  # persist template before run references it
        session.add(run)

    async with db.session() as session:
        fetched = await session.get(RunRow, expected_id)
        assert fetched is not None
        assert fetched.id == expected_id


# ---------------------------------------------------------------------------
# A.10: UNIQUE(run_id, seq) on run_events
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_run_events_unique_run_seq(db: Database) -> None:
    """Two RunEventRows with the same (run_id, seq) must raise IntegrityError."""
    rid = _make_id()
    tid = _make_id()
    template = _workflow_template_row(tid)
    run = _run_row(rid, template_id=tid)

    async with db.session() as session:
        session.add(template)
        await session.flush()
        session.add(run)
        await session.flush()
        session.add(
            RunEventRow(
                run_id=rid,
                seq=1,
                type="x",
                payload={},
                idempotency_key="key-a",
                ts=_NOW,
            )
        )

    with pytest.raises(IntegrityError):
        async with db.session() as session:
            session.add(
                RunEventRow(
                    run_id=rid,
                    seq=1,  # same seq → collision on (run_id, seq)
                    type="x",
                    payload={},
                    idempotency_key="key-b",
                    ts=_NOW,
                )
            )


# ---------------------------------------------------------------------------
# A.11: UNIQUE(run_id, idempotency_key) on run_events
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_run_events_unique_idempotency_key(db: Database) -> None:
    """Two RunEventRows with the same (run_id, idempotency_key) must raise IntegrityError."""
    rid = _make_id()
    tid = _make_id()
    template = _workflow_template_row(tid)
    run = _run_row(rid, template_id=tid)

    async with db.session() as session:
        session.add(template)
        await session.flush()
        session.add(run)
        await session.flush()
        session.add(
            RunEventRow(
                run_id=rid,
                seq=1,
                type="x",
                payload={},
                idempotency_key="shared-key",
                ts=_NOW,
            )
        )

    with pytest.raises(IntegrityError):
        async with db.session() as session:
            session.add(
                RunEventRow(
                    run_id=rid,
                    seq=2,  # different seq
                    type="x",
                    payload={},
                    idempotency_key="shared-key",  # same idem key → collision
                    ts=_NOW,
                )
            )


# ---------------------------------------------------------------------------
# A.12: Index existence on run_events
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_run_events_index_exists(db: Database) -> None:
    """The run_events_run_id_ts_idx index must exist in sqlite_master."""
    async with db.session() as session:
        result = await session.execute(
            text(
                "SELECT name FROM sqlite_master "
                "WHERE type='index' AND name='run_events_run_id_ts_idx'"
            )
        )
        names = [row[0] for row in result.fetchall()]
    assert "run_events_run_id_ts_idx" in names


# ---------------------------------------------------------------------------
# A.13: dispose + new session works
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_dispose_and_reconnect(db_url: str) -> None:
    """After dispose(), creating a new Database and querying must succeed.

    Uses ``db_url`` directly instead of the ``db`` fixture because the test
    manages two Database objects itself. Each one is disposed in a ``finally``
    so a failing step does not leave it undisposed.
    """
    db1 = Database(db_url)
    try:
        await db1.init_schema()
    finally:
        await db1.dispose()

    # A second Database on the same URL must see the schema created above.
    db2 = Database(db_url)
    try:
        async with db2.session() as session:
            result = await session.execute(
                text("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name")
            )
            tables = [row[0] for row in result.fetchall()]
    finally:
        await db2.dispose()

    assert "runs" in tables


# ---------------------------------------------------------------------------
# A.14: Alembic upgrade head produces valid schema
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_alembic_upgrade_head_produces_valid_schema(tmp_path: Path) -> None:
    """Running alembic upgrade head on a fresh DB must create the expected tables.

    Alembic runs in a subprocess with DATABASE_URL pointing at a throwaway
    SQLite file; the file is then inspected with the stdlib sqlite3 driver.
    """
    import os
    import sqlite3
    from contextlib import closing

    db_path = tmp_path / "alembic_test.db"
    db_url = f"sqlite:///{db_path}"  # sync URL for alembic env.py

    # Three directories above this file.
    # NOTE(review): assumed to be the directory holding the alembic config.
    project_root = Path(__file__).parent.parent.parent

    result = subprocess.run(
        [
            sys.executable,
            "-m",
            "alembic",
            "upgrade",
            "head",
        ],
        cwd=str(project_root),
        env={**os.environ, "DATABASE_URL": db_url},
        capture_output=True,
        text=True,
        check=False,  # the return code is asserted below with captured output
    )
    assert result.returncode == 0, (
        f"alembic upgrade head failed:\nSTDOUT:\n{result.stdout}\nSTDERR:\n{result.stderr}"
    )

    # sqlite3's own context manager only ends the transaction, so closing()
    # is what actually closes the connection.
    with closing(sqlite3.connect(str(db_path))) as conn:
        cur = conn.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name")
        tables = {row[0] for row in cur.fetchall()}

    tables.discard("alembic_version")
    assert EXPECTED_TABLES <= tables, f"Missing after alembic upgrade: {EXPECTED_TABLES - tables}"


# ---------------------------------------------------------------------------
# P0-1: partial unique index ux_active_run_repo_base
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_active_run_unique_index_blocks_duplicate(db: Database) -> None:
    """Two active runs with the same (repo_path, base_branch) must raise IntegrityError."""
    tid = _make_id()
    template = _workflow_template_row(tid)

    rid1 = _make_id()
    run1 = _run_row(rid1, template_id=tid)
    run1.state = "running"

    rid2 = _make_id()
    run2 = _run_row(rid2, template_id=tid)
    run2.state = "pending"

    # Same repo_path and base_branch — both active → must violate unique index.
    async with db.session() as session:
        session.add(template)
        await session.flush()  # persist template before run1 references it
        session.add(run1)

    with pytest.raises(IntegrityError):
        async with db.session() as session:
            session.add(run2)


@pytest.mark.asyncio
async def test_active_run_unique_index_allows_completed(db: Database) -> None:
    """A completed run allows a new active run with the same (repo_path, base_branch)."""
    tid = _make_id()
    template = _workflow_template_row(tid)

    rid1 = _make_id()
    run1 = _run_row(rid1, template_id=tid)
    run1.state = "completed"

    rid2 = _make_id()
    run2 = _run_row(rid2, template_id=tid)
    run2.state = "running"

    # Same repo/branch; run1 is completed (excluded) → run2 must succeed.
    async with db.session() as session:
        session.add(template)
        await session.flush()  # persist template before run1 references it
        session.add(run1)

    async with db.session() as session:
        session.add(run2)

    async with db.session() as session:
        fetched = await session.get(RunRow, rid2)
        assert fetched is not None
        assert fetched.state == "running"


# ---------------------------------------------------------------------------
# P0-3: FK CASCADE — RunRow delete cascades to all audit children
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_fk_cascade_run_delete_cascades_phase_feedback(db: Database) -> None:
    """Deleting a RunRow cascades to phase_feedback and run_phases rows."""
    from my_deepagent.persistence.models import PhaseFeedbackRow

    tid = _make_id()
    rid = _make_id()
    phase_id = _make_id()

    template = _workflow_template_row(tid)
    run = _run_row(rid, template_id=tid)
    phase = RunPhaseRow(
        id=phase_id,
        run_id=rid,
        phase_key="plan",
        seq=1,
        state="completed",
        attempts=1,
    )
    feedback = PhaseFeedbackRow(
        run_id=rid,
        phase_id=phase_id,
        reaction="thumbs_up",
        created_at=_NOW,
    )

    # Flush after each parent so the next row's reference already exists:
    # template → run → phase → feedback.
    async with db.session() as session:
        session.add(template)
        await session.flush()
        session.add(run)
        await session.flush()
        session.add(phase)
        await session.flush()
        session.add(feedback)

    async with db.session() as session:
        fetched_run = await session.get(RunRow, rid)
        assert fetched_run is not None
        await session.delete(fetched_run)

    # Check both child tables with raw SQL rather than through the ORM.
    async with db.session() as session:
        fb_result = await session.execute(
            text("SELECT id FROM phase_feedback WHERE run_id = :rid"), {"rid": rid}
        )
        ph_result = await session.execute(
            text("SELECT id FROM run_phases WHERE run_id = :rid"), {"rid": rid}
        )
        assert fb_result.fetchall() == [], "phase_feedback must cascade-delete with run"
        assert ph_result.fetchall() == [], "run_phases must cascade-delete with run"


# ---------------------------------------------------------------------------
# P0-3: FK RESTRICT — deleting WorkflowTemplateRow with runs is blocked
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_fk_restrict_template_delete_blocked_by_run(db: Database) -> None:
    """Deleting a WorkflowTemplateRow that has a referencing RunRow must raise IntegrityError."""
    tid = _make_id()
    rid = _make_id()
    template = _workflow_template_row(tid)
    run = _run_row(rid, template_id=tid)

    async with db.session() as session:
        session.add(template)
        await session.flush()  # persist template before run references it
        session.add(run)

    # The error is expected no later than when the session context exits.
    # NOTE(review): presumably db.session() flushes/commits on exit — it is
    # defined outside this file.
    with pytest.raises(IntegrityError):
        async with db.session() as session:
            fetched = await session.get(WorkflowTemplateRow, tid)
            assert fetched is not None
            await session.delete(fetched)


# ---------------------------------------------------------------------------
# P0-1: partial unique index exists in sqlite_master after init_schema
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_active_run_partial_index_exists_in_schema(db: Database) -> None:
    """ux_active_run_repo_base partial unique index must exist after init_schema."""
    async with db.session() as session:
        result = await session.execute(
            text(
                "SELECT sql FROM sqlite_master "
                "WHERE type='index' AND name='ux_active_run_repo_base'"
            )
        )
        row = result.fetchone()

    assert row is not None, "ux_active_run_repo_base index missing from sqlite_master"
    # A partial index carries a WHERE clause in its stored CREATE INDEX text;
    # `or ""` guards against a NULL sql column.
    assert "WHERE" in (row[0] or ""), f"Expected WHERE clause in index SQL, got: {row[0]}"