feat(my-deepagent): v0.3 PR #1 — interactive session persistence + LangGraph saver wiring
v0.3의 토대. REPL/GUI 둘 다 장기 대화를 영속해서 `mydeepagent --session <id>`
또는 `GET /api/sessions/{id}`로 어디서든 이어 진행 가능. Claude Code의
`claude --resume` 등가 능력.
Data model
- `persistence/models.py`:
- 신규 `MessageRow` 테이블 — (session_id, seq) UNIQUE, role/content/
tool_calls/token_count/is_summary/archived/ts. LangGraph checkpoint =
source of truth, 이 테이블은 GUI/CLI 빠른 조회 mirror. divergence
rebuild 매커니즘 없음 (단순성 우선).
- `InteractiveSessionRow` 컬럼 8개 추가:
total_input_tokens, total_output_tokens (PR #2 tiktoken으로 정밀화 예정),
model, project_key (sha256(realpath(repo_path))[:16]),
title (첫 user msg 50자), plan_mode (PR #5), parent_session_id (PR #6),
depth (PR #6 sub-agent depth ≤ 3).
- `alembic/versions/684e70f4536a_*.py` (신규):
- `op.batch_alter_table` 사용 — SQLite ALTER constraint 미지원 우회. Postgres는
native DDL.
- 자동생성이 제안한 LangGraph 테이블 (`checkpoints` 등) drop 라인은 의도적으로
제거 (langgraph-checkpoint-postgres가 자체 관리).
- server_default 박아서 기존 row 안전.
CLI
- `cli/interactive.py`:
- REPL 진입 시 `get_checkpointer_ctx(config.database_url)` 컨텍스트 열고
REPL 전체 동안 유지. `build_agent(..., checkpointer=saver)`로 deepagents에
LangGraph saver wire. v0.2 PR #10의 CostMiddleware / AuditToolMiddleware
보존.
- `_invoke_and_stream`이 ainvoke 전후 명시적 MessageRow insert
(user → ainvoke → assistant). last_message_at + total_*_tokens 누적 +
첫 user msg로 title 자동 setter.
- `InteractiveSession.thread_suffix` 도입. /model / /agent / /clear 호출
시 suffix bump → LangGraph thread_id = `{session_id}:{suffix}` 로 새
deepagents 컨텍스트 시작 (compaction과 같은 패턴, PR #2 재사용).
- 신규 `--session <id|prefix>` 옵션: 기존 row 로드 (ended이면 거부) 또는
신규 row insert (AgentPersonaRow upsert + project_key 박음).
- `/clear` 슬래시 갱신: messages.archived=True + 새 thread 시작. 세션 자체
는 살아있음 — `sessions show <id> --all`로 조회 가능.
- `cli/sessions.py` (신규): `mydeepagent sessions list/show/resume/end`.
show <id> [--all]이 archived 메시지까지. 6+ char prefix + 중복 시 명시
에러.
- `cli/main.py`: --session 옵션 + sessions 서브명령 + interactive_command
시그니처 확장.
HTTP API
- `api/models.py`: SessionSummary / MessageInfo / SessionDetail /
CreateSessionRequest / PostMessageRequest / SessionAck DTO 신규 (모두
extra="forbid").
- `api/routes/sessions.py` (신규):
GET /api/sessions?limit=&state=
GET /api/sessions/{id}?all=true (마지막 200 메시지)
POST /api/sessions (persona_name, model_override, repo_path)
POST /api/sessions/{id}/messages (사용자 메시지 append, 동기 persist;
PR #7 GUI에서 background ainvoke 추가)
GET /api/sessions/{id}/stream (SSE — 0.5s polling, last-event-id 헤더
+ ?last_seq 둘 다 지원)
POST /api/sessions/{id}/end
- `api/app.py`: sessions 라우터 마운트.
Tests
- `tests/integration/test_session_persist.py` (5 시나리오):
1. create + post → row + 메시지 + title + token 누적 영속
2. list가 신규 3 세션 모두 포함
3. prefix resolution + 404
4. end 후 메시지 거부 (409)
5. ?all=true가 archived 메시지 surfacing
Gates
- ruff check + ruff format + mypy --strict: PASS (124 source files)
- pytest non-E2E: 608 PASS (25.86 s) — v0.2 PR #3 후 603에서 +5 신규
- pytest E2E real OpenRouter on Postgres: PASS 82.07 s (베이스라인 60–122s
범위 내; DR-3 +20% 임계점 통과)
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -1,8 +1,18 @@
|
||||
"""mydeepagent (no subcommand) — interactive REPL.
|
||||
|
||||
prompt_toolkit-based REPL. Slash commands for navigation; everything else
|
||||
goes to the bound agent. File refs ``@path/to/file.py`` are expanded into
|
||||
markdown code blocks inline before the message is sent.
|
||||
v0.3 PR #1 changes:
|
||||
- LangGraph `AsyncPostgresSaver` is now wired per REPL lifetime — checkpoints
|
||||
survive ^C and a later `mydeepagent --session <id>` resumes the thread.
|
||||
- Every user/assistant turn is mirrored into the `messages` table for fast
|
||||
GUI/CLI listing. LangGraph checkpoints remain the source of truth.
|
||||
- `InteractiveSessionRow` is now persisted at REPL start (or loaded when
|
||||
`--session <id>` is given) — sessions are addressable by short id.
|
||||
- `/model <name>` issues a fresh LangGraph thread suffix so the deepagents
|
||||
context restarts on model switch (compaction-style pattern).
|
||||
- `_resolve_session_id` accepts a 6+ char prefix.
|
||||
|
||||
PR #2 will hook compaction triggers + tiktoken-accurate token counts onto
|
||||
the same `MessageRow` + `InteractiveSessionRow` foundation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -14,10 +24,12 @@ from pathlib import Path
|
||||
from typing import Any
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import typer
|
||||
from prompt_toolkit import PromptSession
|
||||
from prompt_toolkit.completion import WordCompleter
|
||||
from prompt_toolkit.history import FileHistory
|
||||
from rich.console import Console
|
||||
from sqlalchemy import desc, select
|
||||
|
||||
from ..audit import make_audit_recorder
|
||||
from ..budget import make_budget_tracker_from_config
|
||||
@@ -26,7 +38,9 @@ from ..governance import require_consent
|
||||
from ..middleware.audit import AuditToolMiddleware
|
||||
from ..middleware.cost import CostMiddleware
|
||||
from ..monitoring.pricing import ModelPrice, PricingCache
|
||||
from ..persistence.checkpointer import get_checkpointer_ctx
|
||||
from ..persistence.db import Database
|
||||
from ..persistence.models import InteractiveSessionRow, MessageRow
|
||||
from ..persona import Persona, load_personas_from_dir
|
||||
from ..session import build_agent
|
||||
from ..slash import SlashParsed, SlashRegistry, parse_slash
|
||||
@@ -91,8 +105,18 @@ def _now_iso() -> str:
|
||||
return datetime.now(UTC).isoformat(timespec="seconds")
|
||||
|
||||
|
||||
def _truncate_title(text: str, max_chars: int = 50) -> str:
|
||||
one_line = re.sub(r"\s+", " ", text).strip()
|
||||
return one_line[: max_chars - 1] + "…" if len(one_line) > max_chars else one_line
|
||||
|
||||
|
||||
class InteractiveSession:
|
||||
"""Holds REPL state: current persona, current model override, history, agent."""
|
||||
"""Holds REPL state: persona, model override, agent, LangGraph saver, DB row.
|
||||
|
||||
v0.3 PR #1: also tracks `thread_suffix` so `/model` and (future PR #2)
|
||||
compaction can issue a fresh LangGraph thread while the session row stays
|
||||
the same.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -102,6 +126,7 @@ class InteractiveSession:
|
||||
pricing: PricingCache,
|
||||
repo_root: Path,
|
||||
session_id: UUID,
|
||||
saver: Any,
|
||||
) -> None:
|
||||
self.config = config
|
||||
self.personas = personas
|
||||
@@ -109,9 +134,17 @@ class InteractiveSession:
|
||||
self.pricing = pricing
|
||||
self.repo_root = repo_root
|
||||
self.session_id = session_id
|
||||
self.saver = saver
|
||||
self._model_override: str | None = None
|
||||
self._persona = self._default_persona()
|
||||
self._agent: Any | None = None
|
||||
# thread_suffix bumps on /model and compaction; LangGraph thread_id =
|
||||
# f"{session_id}:{suffix}" so model switches start fresh deepagents state.
|
||||
self._thread_suffix: int = 0
|
||||
|
||||
@property
|
||||
def thread_id(self) -> str:
|
||||
return f"{self.session_id}:{self._thread_suffix}"
|
||||
|
||||
def _default_persona(self) -> Persona:
|
||||
name = self.config.default_persona
|
||||
@@ -132,21 +165,28 @@ class InteractiveSession:
|
||||
def model_override(self) -> str | None:
|
||||
return self._model_override
|
||||
|
||||
@property
|
||||
def active_model(self) -> str:
|
||||
return self._model_override or self._persona.model
|
||||
|
||||
def set_persona(self, name: str) -> Persona:
|
||||
for p in self.personas:
|
||||
if p.name == name or f"{p.name}@{p.version}" == name:
|
||||
self._persona = p
|
||||
self._agent = None # rebuild on next turn
|
||||
self._thread_suffix += 1 # persona switch → new LangGraph thread
|
||||
return p
|
||||
raise ValueError(f"persona not found: {name!r}")
|
||||
|
||||
def set_model(self, model: str | None) -> None:
|
||||
self._model_override = model
|
||||
self._agent = None
|
||||
self._thread_suffix += 1 # model switch → new LangGraph thread
|
||||
|
||||
def clear_agent_cache(self) -> None:
|
||||
"""Flush the cached agent so the next call rebuilds with a fresh thread."""
|
||||
self._agent = None
|
||||
self._thread_suffix += 1
|
||||
|
||||
def build_agent_if_needed(self) -> Any:
|
||||
if self._agent is not None:
|
||||
@@ -154,7 +194,7 @@ class InteractiveSession:
|
||||
budget = make_budget_tracker_from_config(self.db, self.config)
|
||||
cost_mw = CostMiddleware(
|
||||
pricing=self.pricing,
|
||||
model_name=self._model_override or self._persona.model,
|
||||
model_name=self.active_model,
|
||||
interactive_session_id=self.session_id,
|
||||
persona_name=self._persona.name,
|
||||
budget_tracker=budget,
|
||||
@@ -169,10 +209,159 @@ class InteractiveSession:
|
||||
root_dir=self.repo_root,
|
||||
middleware=[cost_mw, audit_mw],
|
||||
model_override=self._model_override,
|
||||
checkpointer=self.saver,
|
||||
)
|
||||
return self._agent
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DB helpers (session + message persistence)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def _load_or_create_session_row(
|
||||
db: Database,
|
||||
session_id: UUID,
|
||||
persona: Persona,
|
||||
repo_root: Path,
|
||||
*,
|
||||
create: bool,
|
||||
) -> InteractiveSessionRow:
|
||||
"""Return the session row, creating it if ``create=True`` and not found."""
|
||||
from sqlalchemy import select as _select
|
||||
|
||||
from ..persistence.models import AgentPersonaRow
|
||||
|
||||
async with db.session() as s:
|
||||
existing = await s.get(InteractiveSessionRow, str(session_id))
|
||||
if existing is not None:
|
||||
return existing
|
||||
if not create:
|
||||
raise RuntimeError(f"session {session_id} not found")
|
||||
|
||||
# Find or upsert the AgentPersonaRow. We need persona_id for the FK.
|
||||
ph = persona.compute_hash()
|
||||
persona_row = (
|
||||
await s.execute(_select(AgentPersonaRow).where(AgentPersonaRow.hash == ph))
|
||||
).scalar_one_or_none()
|
||||
if persona_row is None:
|
||||
persona_row = AgentPersonaRow(
|
||||
id=str(uuid4()),
|
||||
name=persona.name,
|
||||
version=persona.version,
|
||||
hash=ph,
|
||||
definition=persona.model_dump(by_alias=True),
|
||||
created_at=_now_iso(),
|
||||
)
|
||||
s.add(persona_row)
|
||||
await s.flush()
|
||||
|
||||
# Derive project_key from the repo root (stable hash).
|
||||
from ..hash import sha256
|
||||
|
||||
project_key = sha256(str(repo_root.resolve()))[:16]
|
||||
|
||||
row = InteractiveSessionRow(
|
||||
id=str(session_id),
|
||||
persona_id=persona_row.id,
|
||||
persona_hash=ph,
|
||||
started_at=_now_iso(),
|
||||
last_message_at=None,
|
||||
state="active",
|
||||
total_input_tokens=0,
|
||||
total_output_tokens=0,
|
||||
model=persona.model,
|
||||
project_key=project_key,
|
||||
title=None,
|
||||
plan_mode=False,
|
||||
parent_session_id=None,
|
||||
depth=0,
|
||||
)
|
||||
s.add(row)
|
||||
await s.commit()
|
||||
return row
|
||||
|
||||
|
||||
async def _next_message_seq(db: Database, session_id: UUID) -> int:
|
||||
async with db.session() as s:
|
||||
result = await s.execute(
|
||||
select(MessageRow.seq)
|
||||
.where(MessageRow.session_id == str(session_id))
|
||||
.order_by(desc(MessageRow.seq))
|
||||
.limit(1)
|
||||
)
|
||||
last = result.scalar_one_or_none()
|
||||
return (last or 0) + 1
|
||||
|
||||
|
||||
async def _append_message(
|
||||
db: Database,
|
||||
session_id: UUID,
|
||||
role: str,
|
||||
content: str,
|
||||
*,
|
||||
tool_calls: dict[str, Any] | None = None,
|
||||
token_count: int = 0,
|
||||
) -> None:
|
||||
"""Insert one MessageRow + update last_message_at / title (if first user msg)."""
|
||||
seq = await _next_message_seq(db, session_id)
|
||||
now = _now_iso()
|
||||
async with db.session() as s:
|
||||
s.add(
|
||||
MessageRow(
|
||||
session_id=str(session_id),
|
||||
seq=seq,
|
||||
role=role,
|
||||
content=content,
|
||||
tool_calls=tool_calls,
|
||||
token_count=token_count,
|
||||
is_summary=False,
|
||||
archived=False,
|
||||
ts=now,
|
||||
)
|
||||
)
|
||||
row = await s.get(InteractiveSessionRow, str(session_id))
|
||||
if row is not None:
|
||||
row.last_message_at = now
|
||||
if row.title is None and role == "user":
|
||||
row.title = _truncate_title(content)
|
||||
if role == "user":
|
||||
row.total_input_tokens += token_count
|
||||
elif role == "assistant":
|
||||
row.total_output_tokens += token_count
|
||||
await s.commit()
|
||||
|
||||
|
||||
async def _archive_messages(db: Database, session_id: UUID) -> int:
|
||||
"""Mark all current messages as archived=True. Returns the count touched."""
|
||||
from sqlalchemy import update
|
||||
|
||||
async with db.session() as s:
|
||||
result = await s.execute(
|
||||
update(MessageRow)
|
||||
.where(MessageRow.session_id == str(session_id))
|
||||
.where(MessageRow.archived.is_(False))
|
||||
.values(archived=True)
|
||||
)
|
||||
await s.commit()
|
||||
# update() returns CursorResult which has rowcount; cast for mypy.
|
||||
return int(getattr(result, "rowcount", 0) or 0)
|
||||
|
||||
|
||||
async def _mark_session_ended(db: Database, session_id: UUID) -> None:
|
||||
async with db.session() as s:
|
||||
row = await s.get(InteractiveSessionRow, str(session_id))
|
||||
if row is not None and row.state != "ended":
|
||||
row.state = "ended"
|
||||
row.ended_at = _now_iso()
|
||||
await s.commit()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Slash commands
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _register_navigation_slash(reg: SlashRegistry, sess: InteractiveSession) -> None:
|
||||
"""Register /quit, /exit, /help, /clear slash handlers."""
|
||||
|
||||
@@ -181,19 +370,24 @@ def _register_navigation_slash(reg: SlashRegistry, sess: InteractiveSession) ->
|
||||
|
||||
async def _help(_: SlashParsed) -> bool:
|
||||
_CONSOLE.print("[bold]Slash commands:[/]")
|
||||
for name, desc in reg.all_help():
|
||||
_CONSOLE.print(f" /{name:14s} {desc}")
|
||||
for name, help_text in reg.all_help():
|
||||
_CONSOLE.print(f" /{name:14s} {help_text}")
|
||||
return False
|
||||
|
||||
async def _clear(_: SlashParsed) -> bool:
|
||||
# v0.3 PR #1: /clear archives the current session's messages and bumps
|
||||
# the LangGraph thread suffix so the next turn starts with a fresh
|
||||
# context. The session row stays — only the message history is
|
||||
# archived (still inspectable via `sessions show <id> --all`).
|
||||
count = await _archive_messages(sess.db, sess.session_id)
|
||||
sess.clear_agent_cache()
|
||||
_CONSOLE.print("[dim]context cleared (new session thread)[/]")
|
||||
_CONSOLE.print(f"[dim]context cleared ({count} messages archived, new thread)[/]")
|
||||
return False
|
||||
|
||||
reg.register("quit", _quit, help="exit the REPL")
|
||||
reg.register("exit", _quit, help="alias for /quit")
|
||||
reg.register("help", _help, help="show slash commands")
|
||||
reg.register("clear", _clear, help="clear conversation context")
|
||||
reg.register("clear", _clear, help="archive messages + start a fresh thread")
|
||||
|
||||
|
||||
def _register_persona_slash(reg: SlashRegistry, sess: InteractiveSession) -> None:
|
||||
@@ -214,15 +408,21 @@ def _register_persona_slash(reg: SlashRegistry, sess: InteractiveSession) -> Non
|
||||
|
||||
async def _model_cmd(cmd: SlashParsed) -> bool:
|
||||
if not cmd.args:
|
||||
cur = sess.model_override or sess.persona.model
|
||||
_CONSOLE.print(f"current model: [cyan]{cur}[/]")
|
||||
_CONSOLE.print(f"current model: [cyan]{sess.active_model}[/]")
|
||||
return False
|
||||
if cmd.args[0] in ("-", "reset"):
|
||||
sess.set_model(None)
|
||||
_CONSOLE.print("[green]model override cleared[/]")
|
||||
new_model = sess.active_model
|
||||
_CONSOLE.print(f"[green]model override cleared → {new_model} (new thread)[/]")
|
||||
else:
|
||||
sess.set_model(cmd.args[0])
|
||||
_CONSOLE.print(f"[green]model → {cmd.args[0]}[/]")
|
||||
_CONSOLE.print(f"[green]model → {cmd.args[0]} (new thread)[/]")
|
||||
# Persist the new active model on the session row.
|
||||
async with sess.db.session() as s:
|
||||
row = await s.get(InteractiveSessionRow, str(sess.session_id))
|
||||
if row is not None:
|
||||
row.model = sess.active_model
|
||||
await s.commit()
|
||||
return False
|
||||
|
||||
reg.register("agent", _agent_cmd, help="list or switch persona: /agent [name]")
|
||||
@@ -230,7 +430,7 @@ def _register_persona_slash(reg: SlashRegistry, sess: InteractiveSession) -> Non
|
||||
|
||||
|
||||
def _register_telemetry_slash(reg: SlashRegistry) -> None:
|
||||
"""Register /stats, /budget, /runs slash handlers."""
|
||||
"""Register /stats, /budget, /runs, /sessions slash handlers."""
|
||||
|
||||
async def _stats(_: SlashParsed) -> bool:
|
||||
from .stats import stats_command
|
||||
@@ -250,9 +450,16 @@ def _register_telemetry_slash(reg: SlashRegistry) -> None:
|
||||
runs_list_command(limit=10, state_filter=None)
|
||||
return False
|
||||
|
||||
async def _sessions(_: SlashParsed) -> bool:
|
||||
from .sessions import sessions_list_command
|
||||
|
||||
sessions_list_command(limit=10)
|
||||
return False
|
||||
|
||||
reg.register("stats", _stats, help="LLM-call stats (last 24h)")
|
||||
reg.register("budget", _budget, help="budget ledger")
|
||||
reg.register("runs", _runs, help="list recent workflow runs")
|
||||
reg.register("sessions", _sessions, help="list recent interactive sessions")
|
||||
|
||||
|
||||
def _register_slash(reg: SlashRegistry, sess: InteractiveSession) -> None:
|
||||
@@ -267,16 +474,42 @@ def _completer(personas: list[Persona], slash_names: list[str]) -> WordCompleter
|
||||
return WordCompleter(words, ignore_case=True, sentence=True)
|
||||
|
||||
|
||||
async def _invoke_and_stream(agent: Any, user_text: str, session_id: UUID) -> None:
|
||||
"""Invoke the agent and pretty-print the response.
|
||||
def _approx_token_count(text: str) -> int:
|
||||
"""Conservative char-based token estimate (PR #1 placeholder).
|
||||
|
||||
v0.1 keeps it simple — full ainvoke, then print the final message.
|
||||
Token-level streaming via astream is a Step 16 polish.
|
||||
PR #2 swaps this for tiktoken with model-aware tokenizer selection.
|
||||
1 token ≈ 4 chars is the cl100k_base rule of thumb for English; mixed
|
||||
Korean text trends higher tokens/char, so we round up.
|
||||
"""
|
||||
result = await agent.ainvoke(
|
||||
{"messages": [{"role": "user", "content": user_text}]},
|
||||
config={"configurable": {"thread_id": str(session_id)}},
|
||||
return max(0, (len(text) + 3) // 4)
|
||||
|
||||
|
||||
async def _invoke_and_stream(
|
||||
agent: Any,
|
||||
user_text: str,
|
||||
sess: InteractiveSession,
|
||||
) -> None:
|
||||
"""Invoke the agent, print the assistant response, and persist both messages."""
|
||||
# 1. Persist the user message first so it's durable even if ainvoke fails.
|
||||
await _append_message(
|
||||
sess.db,
|
||||
sess.session_id,
|
||||
"user",
|
||||
user_text,
|
||||
token_count=_approx_token_count(user_text),
|
||||
)
|
||||
|
||||
# 2. Invoke the agent. LangGraph thread_id includes the suffix so /model
|
||||
# or /clear-induced switches start a fresh context.
|
||||
try:
|
||||
result = await agent.ainvoke(
|
||||
{"messages": [{"role": "user", "content": user_text}]},
|
||||
config={"configurable": {"thread_id": sess.thread_id}},
|
||||
)
|
||||
except Exception:
|
||||
# User msg is already persisted; surface the error and bail.
|
||||
raise
|
||||
|
||||
messages = result.get("messages", []) if isinstance(result, dict) else []
|
||||
if not messages:
|
||||
return
|
||||
@@ -286,7 +519,17 @@ async def _invoke_and_stream(agent: Any, user_text: str, session_id: UUID) -> No
|
||||
content = "\n".join(
|
||||
(c.get("text", str(c)) if isinstance(c, dict) else str(c)) for c in content
|
||||
)
|
||||
_CONSOLE.print(str(content))
|
||||
content_str = str(content)
|
||||
_CONSOLE.print(content_str)
|
||||
|
||||
# 3. Persist the assistant response.
|
||||
await _append_message(
|
||||
sess.db,
|
||||
sess.session_id,
|
||||
"assistant",
|
||||
content_str,
|
||||
token_count=_approx_token_count(content_str),
|
||||
)
|
||||
|
||||
|
||||
async def _repl_loop(
|
||||
@@ -319,12 +562,46 @@ async def _repl_loop(
|
||||
expanded = _expand_file_refs(line, sess.repo_root)
|
||||
agent = sess.build_agent_if_needed()
|
||||
try:
|
||||
await _invoke_and_stream(agent, expanded, sess.session_id)
|
||||
await _invoke_and_stream(agent, expanded, sess)
|
||||
except Exception as e:
|
||||
_CONSOLE.print(f"[red]agent error:[/] {type(e).__name__}: {e}")
|
||||
|
||||
|
||||
async def _interactive_loop_async(persona_override: str | None, model_override: str | None) -> int:
|
||||
async def _resolve_session_arg(db: Database, prefix_or_full: str) -> UUID:
|
||||
"""Accept full UUID or 6+ char prefix; return resolved UUID. Exit on miss."""
|
||||
try:
|
||||
return UUID(prefix_or_full)
|
||||
except ValueError:
|
||||
pass
|
||||
if len(prefix_or_full) < 6:
|
||||
_CONSOLE.print("[red]session prefix must be >=6 chars or a full UUID[/]")
|
||||
raise typer.Exit(code=2)
|
||||
async with db.session() as s:
|
||||
rows = (
|
||||
(
|
||||
await s.execute(
|
||||
select(InteractiveSessionRow.id)
|
||||
.where(InteractiveSessionRow.id.like(f"{prefix_or_full}%"))
|
||||
.limit(2)
|
||||
)
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
)
|
||||
if not rows:
|
||||
_CONSOLE.print(f"[red]no session matches prefix:[/] {prefix_or_full}")
|
||||
raise typer.Exit(code=1)
|
||||
if len(rows) > 1:
|
||||
_CONSOLE.print(f"[red]ambiguous prefix matches >1 session:[/] {prefix_or_full}")
|
||||
raise typer.Exit(code=1)
|
||||
return UUID(rows[0])
|
||||
|
||||
|
||||
async def _interactive_loop_async(
|
||||
persona_override: str | None,
|
||||
model_override: str | None,
|
||||
session_arg: str | None,
|
||||
) -> int:
|
||||
config = load_config()
|
||||
require_consent(config.data_dir)
|
||||
db = Database(config.database_url)
|
||||
@@ -334,34 +611,78 @@ async def _interactive_loop_async(persona_override: str | None, model_override:
|
||||
_CONSOLE.print("[red]no personas seeded; run `mydeepagent init`[/]")
|
||||
return 1
|
||||
pricing = _static_pricing_seed()
|
||||
session_id = uuid4()
|
||||
|
||||
# Resolve session id: --session given → existing; otherwise new uuid.
|
||||
if session_arg:
|
||||
session_id = await _resolve_session_arg(db, session_arg)
|
||||
async with db.session() as s:
|
||||
row = await s.get(InteractiveSessionRow, str(session_id))
|
||||
if row is None:
|
||||
_CONSOLE.print(f"[red]session not found:[/] {session_arg}")
|
||||
await db.dispose()
|
||||
return 1
|
||||
if row.state == "ended":
|
||||
_CONSOLE.print(
|
||||
f"[yellow]session {row.id} is ended; start a new one with `mydeepagent`.[/]"
|
||||
)
|
||||
await db.dispose()
|
||||
return 1
|
||||
creating = False
|
||||
else:
|
||||
session_id = uuid4()
|
||||
creating = True
|
||||
|
||||
try:
|
||||
sess = InteractiveSession(config, personas, db, pricing, Path.cwd(), session_id)
|
||||
if persona_override:
|
||||
try:
|
||||
sess.set_persona(persona_override)
|
||||
except ValueError as e:
|
||||
_CONSOLE.print(f"[red]{e}[/]")
|
||||
return 1
|
||||
if model_override:
|
||||
sess.set_model(model_override)
|
||||
reg = SlashRegistry()
|
||||
_register_slash(reg, sess)
|
||||
async with get_checkpointer_ctx(config.database_url) as saver:
|
||||
# Resolve initial persona (may be overridden below).
|
||||
sess = InteractiveSession(config, personas, db, pricing, Path.cwd(), session_id, saver)
|
||||
if persona_override:
|
||||
try:
|
||||
sess.set_persona(persona_override)
|
||||
except ValueError as e:
|
||||
_CONSOLE.print(f"[red]{e}[/]")
|
||||
return 1
|
||||
# set_persona bumps thread_suffix; reset to 0 for new sessions so
|
||||
# initial thread_id is just "<session_id>:0" — clean.
|
||||
if creating:
|
||||
sess._thread_suffix = 0
|
||||
if model_override:
|
||||
sess.set_model(model_override)
|
||||
if creating:
|
||||
sess._thread_suffix = 0
|
||||
|
||||
persona_label = f"{sess.persona.name}@{sess.persona.version}"
|
||||
_CONSOLE.print(f"[bold cyan]my-deepagent[/] — persona [cyan]{persona_label}[/]")
|
||||
_CONSOLE.print("[dim]type /help for commands, /quit to exit[/]")
|
||||
# Now persist the session row (or load existing).
|
||||
await _load_or_create_session_row(
|
||||
db, session_id, sess.persona, Path.cwd(), create=creating
|
||||
)
|
||||
|
||||
prompt_session: PromptSession[str] = PromptSession(
|
||||
history=FileHistory(str(_history_path(config))),
|
||||
completer=_completer(personas, reg.names),
|
||||
)
|
||||
return await _repl_loop(sess, reg, prompt_session)
|
||||
reg = SlashRegistry()
|
||||
_register_slash(reg, sess)
|
||||
|
||||
persona_label = f"{sess.persona.name}@{sess.persona.version}"
|
||||
mode_tag = "[bold green]resuming[/]" if not creating else "[bold cyan]new[/]"
|
||||
_CONSOLE.print(
|
||||
f"{mode_tag} session [dim]{str(session_id)[:8]}…[/] · "
|
||||
f"persona [cyan]{persona_label}[/] · model [dim]{sess.active_model}[/]"
|
||||
)
|
||||
_CONSOLE.print("[dim]type /help for commands, /quit to exit[/]")
|
||||
|
||||
prompt_session: PromptSession[str] = PromptSession(
|
||||
history=FileHistory(str(_history_path(config))),
|
||||
completer=_completer(personas, reg.names),
|
||||
)
|
||||
code = await _repl_loop(sess, reg, prompt_session)
|
||||
# Leave the session "active" — user may resume via --session <id>.
|
||||
# Only explicit `/sessions end <id>` (or terminal state) marks it ended.
|
||||
return code
|
||||
finally:
|
||||
await db.dispose()
|
||||
|
||||
|
||||
def interactive_command(persona: str | None = None, model: str | None = None) -> int:
|
||||
def interactive_command(
|
||||
persona: str | None = None,
|
||||
model: str | None = None,
|
||||
session: str | None = None,
|
||||
) -> int:
|
||||
"""Entry point for the interactive REPL. Returns an exit code."""
|
||||
return asyncio.run(_interactive_loop_async(persona, model))
|
||||
return asyncio.run(_interactive_loop_async(persona, model, session))
|
||||
|
||||
Reference in New Issue
Block a user