feat: isolate agentic worktrees and surface execution evidence

This commit is contained in:
chungyeong
2026-03-13 22:50:46 +09:00
parent 3fb19e90c0
commit b19d174c98
7 changed files with 758 additions and 14 deletions

View File

@@ -6,6 +6,7 @@ import os
import re
import subprocess
import time
from hashlib import sha256
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
from pathlib import Path
@@ -92,15 +93,110 @@ def _setup_worktree(cwd: Path, run_dir: Path, preset_name: str) -> tuple[Path, s
Returns (worktree_path, branch_name).
"""
from cross_eval.worktree import create_worktree, make_branch_name
from cross_eval.worktree import create_worktree, make_branch_name, make_worktree_dir
branch_name = make_branch_name(preset_name)
worktree_dir = run_dir / "work"
worktree_dir = make_worktree_dir(cwd, branch_name)
worktree_path = create_worktree(
base_cwd=cwd, work_dir=worktree_dir, branch_name=branch_name,
)
(run_dir / "worktree_path.txt").write_text(f"{worktree_path}\n", encoding="utf-8")
(run_dir / "worktree_branch.txt").write_text(f"{branch_name}\n", encoding="utf-8")
return worktree_path, branch_name
def _snapshot_repo_state(cwd: Path) -> str:
"""Capture the base repository working-tree state.
This is used to detect agentic runs that accidentally modify the original
checkout instead of the isolated worktree.
"""
status = subprocess.run(
["git", "status", "--short", "--untracked-files=all"],
cwd=cwd,
capture_output=True,
text=True,
)
if status.returncode != 0:
return ""
diff = subprocess.run(
["git", "diff", "--no-ext-diff", "--binary", "HEAD"],
cwd=cwd,
capture_output=True,
text=True,
)
cached_diff = subprocess.run(
["git", "diff", "--no-ext-diff", "--binary", "--cached"],
cwd=cwd,
capture_output=True,
text=True,
)
untracked = subprocess.run(
["git", "ls-files", "--others", "--exclude-standard", "-z"],
cwd=cwd,
capture_output=True,
)
parts = [
status.stdout,
diff.stdout,
cached_diff.stdout,
]
if untracked.returncode == 0 and untracked.stdout:
for rel_path in untracked.stdout.decode("utf-8", errors="replace").split("\0"):
if not rel_path:
continue
file_path = cwd / rel_path
if file_path.is_file():
digest = sha256(file_path.read_bytes()).hexdigest()
parts.append(f"UNTRACKED {rel_path} {digest}")
else:
parts.append(f"UNTRACKED {rel_path} (non-file)")
return "\n".join(parts)
def _snapshot_repo_status(cwd: Path) -> str:
"""Capture a human-readable status summary for error reporting."""
result = subprocess.run(
["git", "status", "--short", "--untracked-files=all"],
cwd=cwd,
capture_output=True,
text=True,
)
if result.returncode != 0:
return ""
return result.stdout.strip()
def _assert_base_repo_isolation(
cwd: Path,
baseline_state: str,
*,
step_name: str,
agent_name: str,
worktree_path: Path,
baseline_status: str,
) -> None:
"""Fail fast if an agentic run leaked changes into the base repo."""
current_state = _snapshot_repo_state(cwd)
if current_state == baseline_state:
return
current_status = _snapshot_repo_status(cwd)
before = baseline_status or "(clean)"
after = current_status or "(clean)"
raise WorktreeError(
"Agent modified the base repository instead of the isolated worktree.\n\n"
f"Step: {step_name}\n"
f"Agent: {agent_name}\n"
f"Worktree: {worktree_path}\n\n"
f"Baseline status:\n{before}\n\n"
f"Current status:\n{after}"
)
def _finalize_worktree(
cwd: Path,
worktree_path: Path,
@@ -172,10 +268,14 @@ def _run_simple_pipeline(
# Setup shared worktree for agentic mode
worktree_path: Path | None = None
agentic_branch_name: str | None = None
base_repo_state: str | None = None
base_repo_status: str | None = None
if not dry_run and _has_agentic_steps(config, config.pipeline):
worktree_path, agentic_branch_name = _setup_worktree(
cwd, run_dir, config.preset_name,
)
base_repo_state = _snapshot_repo_state(cwd)
base_repo_status = _snapshot_repo_status(cwd)
feedback = "(no feedback — first iteration)"
iterations: list[IterationResult] = []
@@ -203,6 +303,8 @@ def _run_simple_pipeline(
run_dir=run_dir, output_iter=i,
worktree_path=worktree_path,
runtime_env=runtime_env,
base_repo_state=base_repo_state,
base_repo_status=base_repo_status,
)
# Intermediate commit so next iteration's diff only shows new changes
@@ -332,10 +434,14 @@ def _run_phased_pipeline(
all_phase_steps = [s for p in config.phases for s in p.steps]
worktree_path: Path | None = None
agentic_branch_name: str | None = None
base_repo_state: str | None = None
base_repo_status: str | None = None
if not dry_run and _has_agentic_steps(config, all_phase_steps):
worktree_path, agentic_branch_name = _setup_worktree(
cwd, run_dir, config.preset_name,
)
base_repo_state = _snapshot_repo_state(cwd)
base_repo_status = _snapshot_repo_status(cwd)
iterations: list[IterationResult] = []
feedback = "(no feedback — first iteration)"
@@ -384,6 +490,8 @@ def _run_phased_pipeline(
run_dir=run_dir, output_iter=global_iter, phase_name=phase.name,
worktree_path=worktree_path,
runtime_env=runtime_env,
base_repo_state=base_repo_state,
base_repo_status=base_repo_status,
)
# Intermediate commit so next iteration's diff only shows new changes
@@ -626,6 +734,8 @@ def _run_steps(
phase_name: str | None = None,
worktree_path: Path | None = None,
runtime_env: dict[str, str] | None = None,
base_repo_state: str | None = None,
base_repo_status: str | None = None,
) -> tuple[dict[str, str], dict[str, AgentResult], str | None]:
"""Execute all steps in one iteration, parallelizing where possible."""
step_outputs: dict[str, str] = {}
@@ -644,6 +754,8 @@ def _run_steps(
run_dir=run_dir, output_iter=output_iter,
phase_name=phase_name, worktree_path=worktree_path,
runtime_env=runtime_env,
base_repo_state=base_repo_state,
base_repo_status=base_repo_status,
)
else:
_execute_parallel_batch(
@@ -653,6 +765,8 @@ def _run_steps(
run_dir=run_dir, output_iter=output_iter,
phase_name=phase_name, worktree_path=worktree_path,
runtime_env=runtime_env,
base_repo_state=base_repo_state,
base_repo_status=base_repo_status,
)
# Extract verdict from all verdict steps (ALL must PASS; ESCALATE wins over all)
@@ -709,6 +823,8 @@ def _execute_step(
quiet: bool = False,
worktree_path: Path | None = None,
runtime_env: dict[str, str] | None = None,
base_repo_state: str | None = None,
base_repo_status: str | None = None,
) -> None:
"""Execute a single step, updating step_outputs and step_results in place."""
if not quiet:
@@ -717,9 +833,10 @@ def _execute_step(
# 1. Resolve template
template = resolve_template(step.prompt_template)
# 2. Build context
# 2. Build context (include prior step results for evidence)
context = _build_context(
input_contents, step_outputs, feedback, iteration, max_iterations,
step_results=step_results,
)
# 3. Apply context overrides
@@ -794,6 +911,16 @@ def _execute_step(
raise
# 7. Store output
if worktree_path is not None and base_repo_state is not None:
_assert_base_repo_isolation(
cwd,
base_repo_state,
step_name=step.name,
agent_name=step.agent,
worktree_path=worktree_path,
baseline_status=base_repo_status or "",
)
step_outputs[step.output_key] = result.output
step_results[step.output_key] = result
@@ -826,6 +953,8 @@ def _execute_parallel_batch(
phase_name: str | None = None,
worktree_path: Path | None = None,
runtime_env: dict[str, str] | None = None,
base_repo_state: str | None = None,
base_repo_status: str | None = None,
) -> None:
"""Execute multiple steps in parallel using threads."""
agent_names = ", ".join(s.agent for s in batch)
@@ -838,6 +967,8 @@ def _execute_parallel_batch(
iteration, max_iterations, cwd, timeout, dry_run,
step_outputs, step_results,
run_dir=run_dir, output_iter=output_iter, phase_name=phase_name,
base_repo_state=base_repo_state,
base_repo_status=base_repo_status,
)
return
@@ -858,12 +989,15 @@ def _execute_parallel_batch(
step_outputs, step_results,
run_dir=run_dir, output_iter=output_iter,
phase_name=phase_name, worktree_path=worktree_path,
base_repo_state=base_repo_state,
base_repo_status=base_repo_status,
)
return
# Snapshot context before parallel execution (all steps see same state)
context_snapshot = dict(input_contents)
context_snapshot.update(step_outputs)
results_snapshot = dict(step_results)
# Collect results from parallel threads
local_outputs: dict[str, str] = {}
@@ -883,6 +1017,7 @@ def _execute_parallel_batch(
template = resolve_template(step.prompt_template)
context = _build_context(
context_snapshot, {}, feedback, iteration, max_iterations,
step_results=results_snapshot,
)
if step.context_override:
context = _apply_context_override(context, step.context_override)
@@ -919,6 +1054,16 @@ def _execute_parallel_batch(
batch_elapsed = round(time.monotonic() - batch_start, 1)
# Persist successful outputs even if a sibling step failed.
if worktree_path is not None and base_repo_state is not None:
_assert_base_repo_isolation(
cwd,
base_repo_state,
step_name=phase_name or "parallel-batch",
agent_name=agent_names,
worktree_path=worktree_path,
baseline_status=base_repo_status or "",
)
for step in batch:
key = step.output_key
if key not in local_outputs:
@@ -986,6 +1131,7 @@ def _build_context(
feedback: str,
iteration: int,
max_iterations: int,
step_results: dict[str, AgentResult] | None = None,
) -> dict[str, str]:
"""Build the template context dict."""
context: dict[str, str] = {}
@@ -994,9 +1140,42 @@ def _build_context(
context["feedback"] = feedback
context["iteration"] = str(iteration)
context["max_iterations"] = str(max_iterations)
# Surface execution evidence from prior steps so reviewers can inspect it
if step_results:
context["execution_evidence"] = _format_execution_evidence(step_results)
return context
def _format_execution_evidence(
step_results: dict[str, AgentResult],
) -> str:
"""Format execution evidence from prior steps for reviewer consumption.
Produces a compact summary of command, exit code, duration, and a truncated
transcript excerpt for each completed step so that reviewers and seniors
can verify claims against real execution data.
"""
if not step_results:
return "(no prior execution evidence)"
parts: list[str] = []
for key, result in step_results.items():
section = [
f"### Step: {result.step_name} ({result.agent_name})",
f"- Command: `{result.command_preview}`" if result.command_preview else "",
f"- Exit code: {result.exit_code}",
f"- Duration: {result.duration_seconds}s",
]
section = [line for line in section if line]
if result.transcript:
# Include a truncated transcript excerpt for debugging
excerpt = result.transcript[:2000]
if len(result.transcript) > 2000:
excerpt += "\n... (truncated)"
section.append(f"\n<details>\n<summary>Transcript excerpt</summary>\n\n{excerpt}\n</details>")
parts.append("\n".join(section))
return "\n\n---\n\n".join(parts)
def _build_runtime_inputs(
config: PipelineConfig,
input_contents: dict[str, str],