"""Main pipeline execution engine.""" from __future__ import annotations import logging import os import re import subprocess import time from concurrent.futures import ThreadPoolExecutor, as_completed from datetime import datetime from pathlib import Path from cross_eval.agent import invoke_agent from cross_eval.config import try_reload_config from cross_eval.models import ( AgentResult, IterationResult, PipelineConfig, PipelineResult, StepConfig, ) from cross_eval.prompts import render_template, resolve_template, set_language from cross_eval.report import build_report logger = logging.getLogger(__name__) def run_pipeline( config: PipelineConfig, cwd: Path | None = None, dry_run: bool = False, timeout: int | None = None, ) -> PipelineResult: """Execute the full cross-eval pipeline.""" # Create run directory: output/{preset}_{datetime}/ run_dir = _make_run_dir(config) if config.phases: return _run_phased_pipeline(config, run_dir, cwd, dry_run, timeout) return _run_simple_pipeline(config, run_dir, cwd, dry_run, timeout) def _make_run_dir(config: PipelineConfig) -> Path: """Create timestamped run directory under output_dir.""" ts = datetime.now().strftime("%Y%m%d_%H%M%S") run_dir = config.output_dir / f"{config.preset_name}_{ts}" run_dir.mkdir(parents=True, exist_ok=True) return run_dir def _run_simple_pipeline( config: PipelineConfig, run_dir: Path, cwd: Path | None = None, dry_run: bool = False, timeout: int | None = None, ) -> PipelineResult: """Execute a simple (non-phased) pipeline.""" if cwd is None: cwd = Path(os.getcwd()) set_language(config.language) input_contents = _load_inputs(config) feedback = "(no feedback — first iteration)" iterations: list[IterationResult] = [] start_time = time.monotonic() final_verdict = "MAX_ITERATIONS_REACHED" aggregate_history: dict[str, int] = {} aggregate_warnings: list[str] = [] for i in range(1, config.max_iterations + 1): config = try_reload_config(config) set_language(config.language) _refresh_inputs(config, input_contents) logger.info("=" * 50) logger.info(" Iteration %d/%d", i, config.max_iterations) logger.info("=" * 50) step_outputs, step_results, verdict = _run_steps( config.pipeline, config, input_contents, feedback, i, config.max_iterations, cwd, timeout, dry_run, run_dir=run_dir, output_iter=i, ) iter_result = IterationResult( iteration=i, step_results=step_results, step_outputs=step_outputs, verdict=verdict, ) warning = _detect_repeated_aggregate( config.pipeline, step_outputs, aggregate_history, iteration=i, ) if warning: iter_result.repeated_aggregate_warning = warning aggregate_warnings.append(warning) logger.warning(" %s", warning) iter_result.feedback = _collect_feedback(config.pipeline, step_outputs) feedback = iter_result.feedback or feedback iterations.append(iter_result) if verdict == "PASS": final_verdict = "PASS" if i >= config.min_iterations: logger.info(" PASS at iteration %d (min=%d reached)!", i, config.min_iterations) break else: logger.info( " PASS at iteration %d, but min_iterations=%d — continuing", i, config.min_iterations, ) if dry_run: logger.info(" (dry-run: stopping after iteration 1)") break total_duration = time.monotonic() - start_time pipeline_result = PipelineResult( iterations=iterations, final_verdict=final_verdict, total_duration=round(total_duration, 1), run_dir=run_dir, repeated_aggregate_warnings=aggregate_warnings, ) if not dry_run: _save_report(run_dir, config, pipeline_result) return pipeline_result def _run_phased_pipeline( config: PipelineConfig, run_dir: Path, cwd: Path | None = None, dry_run: bool = False, timeout: int | None = None, ) -> PipelineResult: """Execute a multi-phase pipeline (e.g. review-fix).""" if cwd is None: cwd = Path(os.getcwd()) set_language(config.language) input_contents = _load_inputs(config) iterations: list[IterationResult] = [] feedback = "(no feedback — first iteration)" start_time = time.monotonic() final_verdict = "MAX_ITERATIONS_REACHED" global_iter = 0 aggregate_history_by_phase: dict[str, dict[str, int]] = {} aggregate_warnings: list[str] = [] for phase_idx, phase in enumerate(config.phases): logger.info("=" * 60) logger.info( " Phase: %s (max_iter=%d, consecutive_pass=%d)", phase.name, phase.max_iterations, phase.consecutive_pass, ) logger.info("=" * 60) consecutive_passes = 0 phase_converged = False for pi in range(1, phase.max_iterations + 1): global_iter += 1 config = try_reload_config(config) set_language(config.language) _refresh_inputs(config, input_contents) logger.info("-" * 50) logger.info( " [%s] Iteration %d/%d (global: v%d)", phase.name, pi, phase.max_iterations, global_iter, ) logger.info("-" * 50) step_outputs, step_results, verdict = _run_steps( phase.steps, config, input_contents, feedback, pi, phase.max_iterations, cwd, timeout, dry_run, run_dir=run_dir, output_iter=global_iter, phase_name=phase.name, ) iter_result = IterationResult( iteration=global_iter, step_results=step_results, step_outputs=step_outputs, verdict=verdict, phase_name=phase.name, ) phase_history = aggregate_history_by_phase.setdefault(phase.name, {}) warning = _detect_repeated_aggregate( phase.steps, step_outputs, phase_history, iteration=global_iter, phase_name=phase.name, ) if warning: iter_result.repeated_aggregate_warning = warning aggregate_warnings.append(warning) logger.warning(" %s", warning) iter_result.feedback = _collect_feedback(phase.steps, step_outputs) feedback = iter_result.feedback or feedback iterations.append(iter_result) if verdict == "PASS": consecutive_passes += 1 logger.info( " [%s] PASS (%d/%d consecutive)", phase.name, consecutive_passes, phase.consecutive_pass, ) if consecutive_passes >= phase.consecutive_pass: logger.info( " [%s] Converged! %d consecutive PASSes.", phase.name, phase.consecutive_pass, ) phase_converged = True break else: consecutive_passes = 0 if dry_run: break if phase_converged: logger.info(" Phase '%s' completed: CONVERGED", phase.name) else: logger.info( " Phase '%s' completed: max iterations (%d) reached", phase.name, phase.max_iterations, ) if phase_idx == len(config.phases) - 1: final_verdict = "PASS" if phase_converged else "MAX_ITERATIONS_REACHED" total_duration = time.monotonic() - start_time pipeline_result = PipelineResult( iterations=iterations, final_verdict=final_verdict, total_duration=round(total_duration, 1), run_dir=run_dir, repeated_aggregate_warnings=aggregate_warnings, ) if not dry_run: _save_report(run_dir, config, pipeline_result) return pipeline_result # --------------------------------------------------------------------------- # Shared helpers # --------------------------------------------------------------------------- def _load_inputs(config: PipelineConfig) -> dict[str, str]: """Load input file contents from config.""" input_contents: dict[str, str] = {} for key, val in config.inputs.items(): if isinstance(val, str): input_contents[key] = val else: input_contents[key] = val.read_text(encoding="utf-8") return input_contents def _refresh_inputs( config: PipelineConfig, input_contents: dict[str, str], ) -> None: """Re-read input files (they may have changed on disk).""" for key, val in config.inputs.items(): if isinstance(val, str): input_contents[key] = val elif isinstance(val, Path) and val.exists(): input_contents[key] = val.read_text(encoding="utf-8") # --------------------------------------------------------------------------- # Parallel step grouping # --------------------------------------------------------------------------- def _get_step_dependencies(step: StepConfig) -> set[str]: """Extract output_key references from context_override values.""" deps: set[str] = set() for val in step.context_override.values(): for match in re.finditer(r"\{(\w+)\}", val): deps.add(match.group(1)) return deps def _group_parallel_steps(steps: list[StepConfig]) -> list[list[StepConfig]]: """Group consecutive parallel steps into batches. Consecutive steps with parallel=True are grouped together, but a new batch starts when a step depends on an output_key from a step in the current batch (dependency breaking). """ batches: list[list[StepConfig]] = [] current: list[StepConfig] = [] current_output_keys: set[str] = set() for step in steps: if not step.parallel: if current: batches.append(current) current = [] current_output_keys = set() batches.append([step]) continue # Check if this step depends on any output from the current batch deps = _get_step_dependencies(step) if deps & current_output_keys: batches.append(current) current = [] current_output_keys = set() current.append(step) current_output_keys.add(step.output_key) if current: batches.append(current) return batches # --------------------------------------------------------------------------- # Step execution # --------------------------------------------------------------------------- def _run_steps( steps: list[StepConfig], config: PipelineConfig, input_contents: dict[str, str], feedback: str, iteration: int, max_iterations: int, cwd: Path, timeout: int | None, dry_run: bool, *, run_dir: Path, output_iter: int, phase_name: str | None = None, ) -> tuple[dict[str, str], dict[str, AgentResult], str | None]: """Execute all steps in one iteration, parallelizing where possible.""" step_outputs: dict[str, str] = {} step_results: dict[str, AgentResult] = {} verdict: str | None = None batches = _group_parallel_steps(steps) for batch in batches: if len(batch) == 1: # Single step — run directly step = batch[0] _execute_step( step, config, input_contents, feedback, iteration, max_iterations, cwd, timeout, dry_run, step_outputs, step_results, run_dir=run_dir, output_iter=output_iter, phase_name=phase_name, ) else: # Parallel batch — run with ThreadPoolExecutor _execute_parallel_batch( batch, config, input_contents, feedback, iteration, max_iterations, cwd, timeout, dry_run, step_outputs, step_results, run_dir=run_dir, output_iter=output_iter, phase_name=phase_name, ) # Extract verdict from all verdict steps (ALL must PASS) for step in steps: if step.verdict: output = step_outputs.get(step.output_key, "") step_verdict = _extract_verdict(output, step.verdict_pattern) logger.info(" [%s] verdict: %s", step.name, step_verdict) if verdict is None: verdict = step_verdict elif step_verdict == "FAIL": verdict = "FAIL" return step_outputs, step_results, verdict def _execute_step( step: StepConfig, config: PipelineConfig, input_contents: dict[str, str], feedback: str, iteration: int, max_iterations: int, cwd: Path, timeout: int | None, dry_run: bool, step_outputs: dict[str, str], step_results: dict[str, AgentResult], *, run_dir: Path, output_iter: int, phase_name: str | None = None, quiet: bool = False, ) -> None: """Execute a single step, updating step_outputs and step_results in place.""" if not quiet: logger.info(" [%s] agent='%s' role='%s'", step.name, step.agent, step.role) # 1. Resolve template template = resolve_template(step.prompt_template) # 2. Build context context = _build_context( input_contents, step_outputs, feedback, iteration, max_iterations, ) # 3. Apply context overrides if step.context_override: context = _apply_context_override(context, step.context_override) # 4. Render prompt prompt = render_template(template, context) # 5. Dry run: print and skip if dry_run: phase_label = f" phase={phase_name}" if phase_name else "" print(f"\n--- Step: {step.name} (agent={step.agent}{phase_label}) ---") print(prompt) print(f"--- end {step.name} ---\n") step_outputs[step.output_key] = f"(dry-run: no output for {step.output_key})" return # 6. Invoke agent agent_config = config.agents[step.agent] try: result = invoke_agent( agent_config, prompt, step.name, cwd=cwd, timeout=timeout, quiet=quiet, ) except subprocess.TimeoutExpired as e: stdout = (e.stdout or b"") if isinstance(e.stdout, bytes) else (e.stdout or "") stderr = (e.stderr or b"") if isinstance(e.stderr, bytes) else (e.stderr or "") if isinstance(stdout, bytes): stdout = stdout.decode("utf-8", errors="replace") if isinstance(stderr, bytes): stderr = stderr.decode("utf-8", errors="replace") phase_info = f"- **Phase**: {phase_name}\n" if phase_name else "" error_msg = ( f"# Agent Timeout\n\n" f"{phase_info}" f"- **Step**: {step.name}\n" f"- **Agent**: {step.agent}\n" f"- **Timeout**: {timeout}s\n\n" f"Partial stdout ({len(stdout)} chars):\n" f"```\n{stdout[:2000] or '(none)'}\n```\n\n" f"Stderr:\n```\n{stderr[:2000] or '(none)'}\n```\n" ) _save_step_output(run_dir, output_iter, f"{step.name}_error", error_msg) logger.error(" [%s] TIMEOUT after %ss — saved to output", step.name, timeout) raise RuntimeError( f"Agent '{step.agent}' timed out after {timeout}s at step '{step.name}'. " f"Error saved to {run_dir}/v{output_iter}/{step.name}_error.md. " f"Try --timeout 0 (unlimited)" ) except RuntimeError as e: phase_info = f"- **Phase**: {phase_name}\n" if phase_name else "" error_msg = ( f"# Agent Error\n\n{phase_info}" f"- **Step**: {step.name}\n- **Agent**: {step.agent}\n\n```\n{e}\n```\n" ) _save_step_output(run_dir, output_iter, f"{step.name}_error", error_msg) logger.error(" [%s] FAILED — saved to output", step.name) raise # 7. Store output step_outputs[step.output_key] = result.output step_results[step.output_key] = result if not quiet: logger.info( " [%s] completed (%.1fs, %d chars)", step.name, result.duration_seconds, len(result.output), ) # 8. Save to disk _save_step_output(run_dir, output_iter, step.name, result.output) def _execute_parallel_batch( batch: list[StepConfig], config: PipelineConfig, input_contents: dict[str, str], feedback: str, iteration: int, max_iterations: int, cwd: Path, timeout: int | None, dry_run: bool, step_outputs: dict[str, str], step_results: dict[str, AgentResult], *, run_dir: Path, output_iter: int, phase_name: str | None = None, ) -> None: """Execute multiple steps in parallel using threads.""" agent_names = ", ".join(s.agent for s in batch) logger.info(" [parallel] %d agents: %s", len(batch), agent_names) if dry_run: for step in batch: _execute_step( step, config, input_contents, feedback, iteration, max_iterations, cwd, timeout, dry_run, step_outputs, step_results, run_dir=run_dir, output_iter=output_iter, phase_name=phase_name, ) return # Snapshot context before parallel execution (all steps see same state) context_snapshot = dict(input_contents) context_snapshot.update(step_outputs) # Collect results from parallel threads local_outputs: dict[str, str] = {} local_results: dict[str, AgentResult] = {} errors: list[Exception] = [] # Show a single spinner for the batch from cross_eval.agent import _Spinner spinner = _Spinner( f"[parallel] {len(batch)} agents running ({agent_names})..." ) spinner.start() batch_start = time.monotonic() def _run_one(step: StepConfig) -> tuple[str, str, AgentResult]: """Run one step, return (output_key, output, result).""" template = resolve_template(step.prompt_template) context = _build_context( context_snapshot, {}, feedback, iteration, max_iterations, ) if step.context_override: context = _apply_context_override(context, step.context_override) prompt = render_template(template, context) agent_config = config.agents[step.agent] result = invoke_agent( agent_config, prompt, step.name, cwd=cwd, timeout=timeout, quiet=True, ) return step.output_key, result.output, result with ThreadPoolExecutor(max_workers=len(batch)) as executor: futures = {executor.submit(_run_one, step): step for step in batch} for future in as_completed(futures): step = futures[future] try: output_key, output, result = future.result() local_results[output_key] = result local_outputs[output_key] = output except Exception as e: errors.append(e) batch_elapsed = round(time.monotonic() - batch_start, 1) if errors: spinner.stop(f"[parallel] FAILED ({batch_elapsed}s)") raise errors[0] spinner.stop(f"[parallel] {len(batch)} agents done ({batch_elapsed}s)") # Merge results for step in batch: key = step.output_key step_outputs[key] = local_outputs[key] step_results[key] = local_results[key] r = local_results[key] logger.info( " [%s] completed (%.1fs, %d chars)", step.name, r.duration_seconds, len(r.output), ) _save_step_output(run_dir, output_iter, step.name, r.output) # --------------------------------------------------------------------------- # Context and template helpers # --------------------------------------------------------------------------- def _build_context( input_contents: dict[str, str], step_outputs: dict[str, str], feedback: str, iteration: int, max_iterations: int, ) -> dict[str, str]: """Build the template context dict.""" context: dict[str, str] = {} context.update(input_contents) context.update(step_outputs) context["feedback"] = feedback context["iteration"] = str(iteration) context["max_iterations"] = str(max_iterations) return context def _apply_context_override( context: dict[str, str], overrides: dict[str, str], ) -> dict[str, str]: """Apply context_override mappings for cross-review scenarios.""" result = dict(context) for key, value_template in overrides.items(): result[key] = render_template(value_template, context) return result def _collect_feedback( steps: list[StepConfig], step_outputs: dict[str, str], ) -> str: """Collect feedback from all verdict steps. Single verdict step → raw output (backward compatible). Multiple verdict steps → combined with agent headers for cross-referencing. """ verdict_steps = [s for s in steps if s.verdict] if len(verdict_steps) == 1: return step_outputs.get(verdict_steps[0].output_key, "") parts: list[str] = [] for s in verdict_steps: output = step_outputs.get(s.output_key, "") if output: parts.append(f"## Review by {s.agent} ({s.name})\n{output}") return "\n\n---\n\n".join(parts) def _detect_repeated_aggregate( steps: list[StepConfig], step_outputs: dict[str, str], history: dict[str, int], *, iteration: int, phase_name: str | None = None, ) -> str | None: """Detect repeated aggregate-review outputs across iterations.""" for step in steps: if step.prompt_template != "default:aggregate-review": continue output = step_outputs.get(step.output_key, "") normalized = _normalize_aggregate_output(output) if not normalized: return None if normalized in history: prev_iter = history[normalized] phase_prefix = f"[{phase_name}] " if phase_name else "" return ( f"{phase_prefix}Repeated aggregate_review detected at iteration {iteration} " f"(same as iteration {prev_iter})." ) history[normalized] = iteration return None return None def _normalize_aggregate_output(output: str) -> str: """Normalize aggregate output for repeat detection.""" return " ".join(output.lower().split()) def _extract_verdict(output: str, pattern: str) -> str: """Extract PASS or FAIL from output using regex pattern.""" if re.search(pattern, output): return "PASS" return "FAIL" def _save_step_output( run_dir: Path, iteration: int, step_name: str, content: str, ) -> Path: """Save step output to run_dir/v{iteration}/{step_name}.md""" path = run_dir / f"v{iteration}" / f"{step_name}.md" path.parent.mkdir(parents=True, exist_ok=True) path.write_text(content, encoding="utf-8") return path def _save_report(run_dir: Path, config: PipelineConfig, result: PipelineResult) -> None: """Generate and save the final markdown report.""" report = build_report(config, result) report_path = run_dir / "final-report.md" report_path.parent.mkdir(parents=True, exist_ok=True) report_path.write_text(report, encoding="utf-8") logger.info("Report saved: %s", report_path)