Files
cross-eval/cross_eval/agent.py
2026-03-13 21:47:54 +09:00

436 lines
14 KiB
Python

"""Agent invocation via subprocess with live spinner."""
from __future__ import annotations
import itertools
import logging
import os
import subprocess
import sys
import tempfile
import threading
import time
from pathlib import Path
from typing import Optional
from cross_eval.models import AgentConfig, AgentResult
# Module-level logger; all debug/info/warning output in this file goes through it.
logger = logging.getLogger(__name__)
# CLI tools that support --system-prompt flag natively.
# Entries are matched as plain substrings of the configured command string,
# so e.g. "/usr/local/bin/claude" also counts as "claude".
_SYSTEM_PROMPT_AGENTS = ("claude",)
# CLI tools that accept a reasoning-effort override, passed on the command
# line as `-c model_reasoning_effort="<level>"`; matched by substring too.
_REASONING_EFFORT_AGENTS = ("codex",)
class AgentInvocationError(RuntimeError):
    """Structured error raised when an agent CLI exits unsuccessfully.

    Every constructor argument is keyword-only and is kept as an attribute of
    the same name, so callers can inspect the failure programmatically.
    ``str(exc)`` renders a multi-line, human-readable summary; an empty
    ``raw_error`` is shown as ``(no output)``.
    """

    def __init__(
        self,
        *,
        agent_name: str,
        step_name: str,
        cmd_preview: str,
        raw_error: str,
        failure_type: str,
        suggested_action: str,
    ) -> None:
        summary_lines = [
            f"Agent '{agent_name}' failed (exit code != 0) at step '{step_name}':",
            f" type: {failure_type}",
            f" cmd: {cmd_preview}",
            f" error: {raw_error if raw_error else '(no output)'}",
            f" action: {suggested_action}",
        ]
        super().__init__("\n".join(summary_lines))
        # Expose the structured fields alongside the rendered message.
        self.agent_name = agent_name
        self.step_name = step_name
        self.cmd_preview = cmd_preview
        self.raw_error = raw_error
        self.failure_type = failure_type
        self.suggested_action = suggested_action
def _supports_system_prompt_flag(command: str) -> bool:
    """Return True when *command* names a CLI that accepts ``--system-prompt``.

    The check is a plain substring test against each entry of
    ``_SYSTEM_PROMPT_AGENTS``.
    """
    for known_agent in _SYSTEM_PROMPT_AGENTS:
        if known_agent in command:
            return True
    return False
def _supports_reasoning_effort(command: str) -> bool:
    """Return True when *command* names a CLI that accepts a reasoning-effort override.

    The check is a plain substring test against each entry of
    ``_REASONING_EFFORT_AGENTS``.
    """
    hits = [candidate for candidate in _REASONING_EFFORT_AGENTS if candidate in command]
    return len(hits) > 0
def _classify_agent_failure(detail: str) -> tuple[str, str]:
    """Classify a failed agent invocation into a user-actionable bucket.

    Args:
        detail: Raw stderr/stdout text from the failed CLI call. ``None`` or
            empty text is treated as "no information".

    Returns:
        ``(failure_type, suggested_action)`` where ``failure_type`` is one of
        ``"AUTH"``, ``"USAGE_LIMIT"``, ``"API_ERROR"`` or ``"UNKNOWN"``.
        Buckets are tested in that order, so text matching several buckets
        is assigned to the first one that matches.
    """
    # CLIs frequently surface machine-readable error codes such as
    # ``rate_limit_error``, ``api_error`` or ``invalid_api_key``. Fold ``_`` and
    # ``-`` into spaces so those codes match the human-readable markers below.
    normalized = (detail or "").lower().replace("_", " ").replace("-", " ")
    auth_markers = (
        "not logged in",
        "please run /login",
        "auth",
        "authentication",
        "invalid api key",
        "api key",
        "unauthorized",
        "forbidden",
    )
    usage_limit_markers = (
        "quota",
        "rate limit",
        "too many requests",  # standard HTTP 429 reason phrase
        "credits",
        "credit balance",
        "budget",
        "insufficient funds",
        "usage limit",
        "token limit",
        "billing",
    )
    if any(marker in normalized for marker in auth_markers):
        return (
            "AUTH",
            "Agent CLI authentication is missing or expired. Re-authenticate the CLI, then rerun.",
        )
    if any(marker in normalized for marker in usage_limit_markers):
        return (
            "USAGE_LIMIT",
            "Agent CLI hit a quota, billing, or token budget limit. Refill or raise the limit, then rerun.",
        )
    if "api error" in normalized:
        return (
            "API_ERROR",
            "Agent CLI returned an API error. Inspect the saved error file for the raw response.",
        )
    return (
        "UNKNOWN",
        "Agent CLI failed for an unknown reason. Inspect the saved error file for details.",
    )
class _Spinner:
    """Animated stderr spinner shown while a long-running agent call is in flight.

    ``start()`` launches a daemon thread that redraws ``"\\r <frame> <message>
    (<N>s)"`` roughly every 0.1s. ``stop(final)`` signals that thread, waits
    up to 1s for it to exit, clears the line, and prints a check-marked
    summary with the elapsed time (0.0s if ``start()`` was never called).
    """

    FRAMES = "⠋⠙⠹⠸⠼⠴⠦⠧⠇⠏"
    # Carriage return + spaces + carriage return: blanks the spinner line.
    _CLEAR_LINE = "\r" + (" " * 160) + "\r"

    def __init__(self, message: str) -> None:
        self.message = message
        # An Event rather than a bare bool: it is a proper cross-thread signal,
        # and the worker can block on it so stop() wakes it immediately.
        self._stop_event = threading.Event()
        self._thread: Optional[threading.Thread] = None
        self._start_time = 0.0

    def start(self) -> None:
        """Begin animating in a background daemon thread."""
        self._stop_event.clear()  # allow restarting after a previous stop()
        self._start_time = time.monotonic()
        self._thread = threading.Thread(target=self._spin, daemon=True)
        self._thread.start()

    def _spin(self) -> None:
        for frame in itertools.cycle(self.FRAMES):
            if self._stop_event.is_set():
                break
            elapsed = int(time.monotonic() - self._start_time)
            line = f"\r {frame} {self.message} ({elapsed}s)"
            sys.stderr.write(line)
            sys.stderr.flush()
            # Returns True as soon as stop() sets the event, so shutdown does
            # not wait out the remainder of the frame delay.
            if self._stop_event.wait(0.1):
                break

    def stop(self, final: str) -> None:
        """Stop the animation and print ``✓ <final> (<elapsed>s)`` on its own line."""
        self._stop_event.set()
        if self._thread is not None:
            self._thread.join(timeout=1)
            elapsed = round(time.monotonic() - self._start_time, 1)
        else:
            # Never started: there is no meaningful start time to measure from.
            elapsed = 0.0
        sys.stderr.write(self._CLEAR_LINE)
        sys.stderr.write(f" \u2713 {final} ({elapsed}s)\n")
        sys.stderr.flush()
def _is_print_mode(args: list[str]) -> bool:
    """Return True if *args* contains the ``-p`` or ``--print`` flag as an exact element."""
    print_flags = ("-p", "--print")
    return any(flag in args for flag in print_flags)
def invoke_agent(
    agent: AgentConfig,
    prompt: str,
    step_name: str,
    cwd: Optional[Path] = None,
    env: Optional[dict[str, str]] = None,
    timeout: int | None = None,
    quiet: bool = False,
) -> AgentResult:
    """Invoke an agent CLI with the given prompt and return its text output.

    Two delivery modes are used:

    * Interactive claude (``"claude"`` in the command and no ``-p``/``--print``
      in ``agent.args``): the prompt is written to a temp task file, the agent
      is told via a positional argument to read it and write its answer to a
      temp output file, and that file's content becomes the output (falling
      back to stdout if the file is empty or unreadable).
    * Everything else: the prompt is piped on stdin and stdout is the output.

    Temp files are removed on every exit path, including timeouts and errors.

    Args:
        agent: CLI command, extra args, and optional system prompt /
            reasoning effort.
        prompt: Task text for the agent.
        step_name: Label used in spinner/log messages and errors.
        cwd: Working directory for the subprocess.
        env: Environment for the subprocess (``None`` inherits ours).
        timeout: Seconds before the subprocess is killed (``None`` = no limit).
        quiet: If True, suppress spinner (for parallel execution).

    Returns:
        An ``AgentResult`` with the stripped output, exit code and duration.

    Raises:
        AgentInvocationError: The CLI exited with a non-zero status.
        subprocess.TimeoutExpired: The CLI ran longer than ``timeout``.
    """
    is_claude = "claude" in agent.command
    is_interactive = is_claude and not _is_print_mode(agent.args)
    cmd = [agent.command]
    if agent.reasoning_effort and _supports_reasoning_effort(agent.command):
        cmd.extend(["-c", f'model_reasoning_effort="{agent.reasoning_effort}"'])
    cmd.extend(agent.args)

    task_file: Optional[Path] = None
    output_file: Optional[Path] = None
    try:
        input_data: str | None
        if is_interactive:
            # --- Temp files for interactive (non -p) claude ---
            task_fd, task_path = tempfile.mkstemp(suffix=".md", prefix="cross_eval_task_")
            os.close(task_fd)
            task_file = Path(task_path)
            # mkstemp creates the file empty, so an empty read after the run
            # means the agent never wrote its answer there.
            out_fd, out_path = tempfile.mkstemp(suffix=".md", prefix="cross_eval_out_")
            os.close(out_fd)
            output_file = Path(out_path)
            wrapped_prompt = (
                f"{prompt}\n\n"
                f"---\n"
                f"IMPORTANT: Write your COMPLETE response to this file: {output_file}\n"
                f"Do NOT modify any other files in the project."
            )
            task_file.write_text(wrapped_prompt, encoding="utf-8")
            # System prompt via flag
            if agent.system_prompt and _supports_system_prompt_flag(agent.command):
                cmd.extend(["--system-prompt", agent.system_prompt])
            # Positional arg: point claude to the task file
            cmd.append(
                f"Read the task file at {task_file} and follow all instructions in it. "
                f"Write your complete output to {output_file}."
            )
            input_data = None
        else:
            # Print mode (-p) or non-claude: deliver prompt via stdin
            if agent.system_prompt and _supports_system_prompt_flag(agent.command):
                cmd.extend(["--system-prompt", agent.system_prompt])
                input_data = prompt
            elif agent.system_prompt:
                # No flag support: inline the system prompt ahead of the task.
                input_data = (
                    f"<system>\n{agent.system_prompt}\n</system>\n\n"
                    f"{prompt}"
                )
            else:
                input_data = prompt

        logger.debug("Invoking agent '%s': %s", agent.name, " ".join(cmd[:5]) + " ...")
        spinner: Optional[_Spinner] = None
        if not quiet:
            mode_label = "interactive" if is_interactive else ""
            logger.info(" cmd: %s %s", " ".join(cmd[:6]), f"({mode_label})" if mode_label else "")
            spinner = _Spinner(f"[{step_name}] {agent.name} running...")
            spinner.start()

        try:
            start = time.monotonic()
            result = subprocess.run(
                cmd,
                input=input_data,
                capture_output=True,
                text=True,
                timeout=timeout,
                cwd=cwd,
                env=env,
            )
            duration = time.monotonic() - start
        except subprocess.TimeoutExpired:
            if spinner:
                spinner.stop(f"[{step_name}] TIMEOUT after {timeout}s")
            raise
        except Exception:
            if spinner:
                spinner.stop(f"[{step_name}] ERROR")
            raise

        if result.returncode != 0:
            if spinner:
                spinner.stop(f"[{step_name}] FAILED (exit {result.returncode})")
            full_detail = result.stderr.strip() or result.stdout.strip()
            # Classify on the full text: the telling marker may sit beyond the
            # 500-char preview that is kept for the error message.
            failure_type, suggested_action = _classify_agent_failure(full_detail)
            err_detail = full_detail if len(full_detail) <= 500 else full_detail[:500] + "..."
            raise AgentInvocationError(
                agent_name=agent.name,
                step_name=step_name,
                cmd_preview=" ".join(cmd[:6]),
                raw_error=err_detail or "(no output)",
                failure_type=failure_type,
                suggested_action=suggested_action,
            )

        # --- Capture output ---
        output = ""
        if output_file is not None:
            try:
                # errors="replace": a stray non-UTF-8 byte must not abort the
                # step (and leave the spinner running).
                output = output_file.read_text(encoding="utf-8", errors="replace").strip()
            except OSError as exc:
                logger.warning("Could not read agent output file %s: %s", output_file, exc)
        if not output:
            # Print mode, or the agent didn't write to the file: use stdout.
            output = result.stdout.strip()
    finally:
        for tmp_path in (task_file, output_file):
            if tmp_path is not None:
                tmp_path.unlink(missing_ok=True)

    chars = len(output)
    if spinner:
        spinner.stop(f"[{step_name}] done — {chars} chars")
    if not output:
        stderr_info = result.stderr.strip()
        if stderr_info:
            logger.warning(
                "Agent '%s' produced empty output at step '%s'. stderr: %s",
                agent.name, step_name, stderr_info[:500],
            )
        else:
            logger.warning(
                "Agent '%s' produced empty output at step '%s' (no stderr either)",
                agent.name, step_name,
            )
    return AgentResult(
        output=output,
        exit_code=result.returncode,
        agent_name=agent.name,
        step_name=step_name,
        duration_seconds=round(duration, 1),
    )
def invoke_agent_agentic(
    agent: AgentConfig,
    prompt: str,
    step_name: str,
    worktree_path: Path,
    env: Optional[dict[str, str]] = None,
    timeout: int | None = None,
    quiet: bool = False,
) -> AgentResult:
    """Invoke an agent in agentic mode (no -p, runs in worktree, captures git diff).

    The agent runs without print mode so it can modify files directly.
    After the agent exits, git diff (since last commit) is captured as the output.

    Prompt delivery:

    * codex: ``-`` is appended and the prompt is piped on stdin.
    * any other CLI: the prompt is written to a temp task file (created by
      ``tempfile.mkstemp``, i.e. in the system temp dir rather than the
      worktree, so it does not show up in the diff) and a short positional
      argument tells the agent to read it. This avoids OS argument-length
      limits for large prompts.

    In both cases, a system prompt that cannot be passed via
    ``--system-prompt`` is inlined ahead of the task in ``<system>`` tags
    instead of being dropped.

    Args:
        agent: CLI command, extra args, and optional system prompt /
            reasoning effort.
        prompt: Task text for the agent.
        step_name: Label used in spinner/log messages and errors.
        worktree_path: Directory the agent runs in and whose diff is captured.
        env: Environment for the subprocess (``None`` inherits ours).
        timeout: Seconds before the subprocess is killed (``None`` = no limit).
        quiet: If True, suppress the spinner (for parallel execution).

    Returns:
        An ``AgentResult`` whose output is the captured diff, or
        ``"(no changes)"`` when the diff is empty.

    Raises:
        AgentInvocationError: The CLI exited with a non-zero status.
        subprocess.TimeoutExpired: The CLI ran longer than ``timeout``.
    """
    # Function-local import kept from the original; presumably it avoids an
    # import cycle with cross_eval.worktree (unverified).
    from cross_eval.worktree import capture_diff

    cmd = [agent.command]
    if agent.reasoning_effort and _supports_reasoning_effort(agent.command):
        cmd.extend(["-c", f'model_reasoning_effort="{agent.reasoning_effort}"'])
    # Strip stdin sentinel ("-") from args for agentic mode
    cmd.extend(a for a in agent.args if a != "-")

    # System prompt via flag if supported; otherwise inline it into the task.
    system_via_flag = bool(agent.system_prompt) and _supports_system_prompt_flag(agent.command)
    if system_via_flag:
        cmd.extend(["--system-prompt", agent.system_prompt])
    if agent.system_prompt and not system_via_flag:
        task_text = f"<system>\n{agent.system_prompt}\n</system>\n\n{prompt}"
    else:
        task_text = prompt

    # Deliver the prompt differently per agent type
    is_codex = "codex" in agent.command
    task_file: Optional[Path] = None
    try:
        input_data: str | None = None
        if is_codex:
            # codex: stdin mode
            cmd.append("-")
            input_data = task_text
        else:
            # Close the mkstemp fd before writing through the path so a failed
            # write cannot leak the descriptor; the finally removes the file.
            task_fd, task_path = tempfile.mkstemp(suffix=".md", prefix="cross_eval_task_")
            os.close(task_fd)
            task_file = Path(task_path)
            task_file.write_text(task_text, encoding="utf-8")
            cmd.append(
                f"Read the task file at {task_file} and execute all instructions in it. "
                f"Work in the current directory."
            )

        logger.debug(
            "Invoking agent '%s' (agentic) in worktree: %s",
            agent.name, worktree_path,
        )
        spinner: Optional[_Spinner] = None
        if not quiet:
            logger.info(" cmd: %s (agentic)", " ".join(cmd[:6]))
            spinner = _Spinner(f"[{step_name}] {agent.name} (agentic) running...")
            spinner.start()

        try:
            start = time.monotonic()
            result = subprocess.run(
                cmd,
                input=input_data,
                capture_output=True,
                text=True,
                timeout=timeout,
                cwd=worktree_path,
                env=env,
            )
            duration = time.monotonic() - start
        except subprocess.TimeoutExpired:
            if spinner:
                spinner.stop(f"[{step_name}] TIMEOUT after {timeout}s")
            raise
        except Exception:
            if spinner:
                spinner.stop(f"[{step_name}] ERROR")
            raise
    finally:
        # Clean up temp task file (it lives in the temp dir, not the worktree)
        if task_file is not None:
            task_file.unlink(missing_ok=True)

    if result.returncode != 0:
        if spinner:
            spinner.stop(f"[{step_name}] FAILED (exit {result.returncode})")
        full_detail = result.stderr.strip() or result.stdout.strip()
        # Classify on the full text: the telling marker may sit beyond the
        # 500-char preview that is kept for the error message.
        failure_type, suggested_action = _classify_agent_failure(full_detail)
        err_detail = full_detail if len(full_detail) <= 500 else full_detail[:500] + "..."
        raise AgentInvocationError(
            agent_name=agent.name,
            step_name=step_name,
            cmd_preview=" ".join(cmd[:6]),
            raw_error=err_detail or "(no output)",
            failure_type=failure_type,
            suggested_action=suggested_action,
        )

    # Capture git diff as the output (changes since last commit on the branch)
    try:
        diff_output = capture_diff(worktree_path)
    except Exception:
        # Don't leave the spinner thread animating over the traceback.
        if spinner:
            spinner.stop(f"[{step_name}] ERROR")
        raise
    if not diff_output:
        diff_output = "(no changes)"
        logger.warning(
            "Agent '%s' made no file changes at step '%s'",
            agent.name, step_name,
        )
    chars = len(diff_output)
    if spinner:
        spinner.stop(f"[{step_name}] done — {chars} chars (agentic)")
    return AgentResult(
        output=diff_output,
        exit_code=result.returncode,
        agent_name=agent.name,
        step_name=step_name,
        duration_seconds=round(duration, 1),
    )