fableforge-agent-runtime 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agent_runtime/__init__.py +3 -0
- agent_runtime/cli.py +128 -0
- agent_runtime/daemon.py +172 -0
- agent_runtime/memory_store.py +244 -0
- agent_runtime/models.py +93 -0
- agent_runtime/server.py +119 -0
- agent_runtime/session_manager.py +156 -0
- agent_runtime/state_serializer.py +222 -0
- fableforge_agent_runtime-0.1.0.dist-info/METADATA +274 -0
- fableforge_agent_runtime-0.1.0.dist-info/RECORD +13 -0
- fableforge_agent_runtime-0.1.0.dist-info/WHEEL +4 -0
- fableforge_agent_runtime-0.1.0.dist-info/entry_points.txt +2 -0
- fableforge_agent_runtime-0.1.0.dist-info/licenses/LICENSE +21 -0
agent_runtime/server.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
1
|
+
"""FastAPI server — HTTP API for agent runtime."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from datetime import datetime, timezone
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
from fastapi import FastAPI, HTTPException
|
|
9
|
+
from fastapi.middleware.cors import CORSMiddleware
|
|
10
|
+
|
|
11
|
+
from .daemon import AgentDaemon
|
|
12
|
+
from .models import (
|
|
13
|
+
CheckpointInfo,
|
|
14
|
+
HealthResponse,
|
|
15
|
+
MemoryRetrieveResponse,
|
|
16
|
+
MemorySearchResult,
|
|
17
|
+
MemoryStoreRequest,
|
|
18
|
+
SessionCreate,
|
|
19
|
+
SessionResume,
|
|
20
|
+
SessionState,
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def create_app(daemon: AgentDaemon) -> FastAPI:
    """Build the FastAPI application that exposes the agent runtime over HTTP.

    Args:
        daemon: Daemon whose ``session_manager`` and ``health()`` back the routes.

    Returns:
        A ``FastAPI`` app with health, session lifecycle, memory and
        checkpoint routes registered.
    """
    # Function-scope import keeps this block self-contained: the module's
    # top-level imports do not bring SessionError into scope.
    from .session_manager import SessionError

    app = FastAPI(title="AgentRuntime", version="0.1.0")
    # NOTE(review): CORS is fully open (any origin/method/header). Confirm this
    # is intended before exposing the server beyond localhost.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_methods=["*"],
        allow_headers=["*"],
    )
    sm = daemon.session_manager

    # Only failures caused by the request itself are reported as client errors.
    # Anything else (database failures, bugs) propagates and becomes a 500
    # instead of being disguised as 400/404 with internal details in the body.
    client_errors = (SessionError, ValueError)

    @app.get("/health", response_model=HealthResponse)
    async def health():
        return daemon.health()

    @app.post("/sessions", response_model=SessionState, status_code=201)
    async def create_session(body: SessionCreate):
        return sm.create_session(
            name=body.name,
            model=body.model,
            config={
                "system_prompt": body.system_prompt,
                "tools": body.tools,
            },
        )

    @app.get("/sessions", response_model=list[SessionState])
    async def list_sessions():
        return sm.list_sessions()

    @app.get("/sessions/{session_id}", response_model=SessionState)
    async def get_session(session_id: str):
        try:
            return sm.get_session_status(session_id)
        except SessionError as exc:
            raise HTTPException(
                status_code=404, detail=f"Session {session_id} not found"
            ) from exc

    @app.post("/sessions/{session_id}/start", response_model=SessionState)
    async def start_session(session_id: str):
        try:
            return await sm.start_session(session_id)
        except client_errors as exc:
            raise HTTPException(status_code=400, detail=str(exc)) from exc

    @app.post("/sessions/{session_id}/pause", response_model=SessionState)
    async def pause_session(session_id: str):
        try:
            return await sm.pause_session(session_id)
        except client_errors as exc:
            raise HTTPException(status_code=400, detail=str(exc)) from exc

    @app.post("/sessions/{session_id}/resume", response_model=SessionState)
    async def resume_session(session_id: str, body: SessionResume | None = None):
        # The request body is optional; without it the session resumes from its
        # current persisted state rather than from a checkpoint.
        checkpoint_id = body.checkpoint_id if body else None
        try:
            return await sm.resume_session(session_id, checkpoint_id=checkpoint_id)
        except client_errors as exc:
            raise HTTPException(status_code=400, detail=str(exc)) from exc

    @app.post("/sessions/{session_id}/stop", response_model=SessionState)
    async def stop_session(session_id: str):
        try:
            return await sm.stop_session(session_id)
        except client_errors as exc:
            raise HTTPException(status_code=400, detail=str(exc)) from exc

    # NOTE(review): both memory routes ignore ``session_id``; they read and
    # write the single store returned by ``sm.get_memory_store()``, so keys are
    # shared across sessions and the session is not checked for existence.
    @app.get("/sessions/{session_id}/memory")
    async def get_memory(session_id: str, key: str | None = None):
        mem = sm.get_memory_store()
        if key:
            entry = mem.retrieve(key)
            if entry is None:
                raise HTTPException(status_code=404, detail=f"Memory key {key} not found")
            return MemoryRetrieveResponse(
                key=entry.key, value=entry.value, timestamp=entry.timestamp
            )
        # No key requested: return the list of known keys instead.
        return {"keys": mem.list_keys()}

    @app.post("/sessions/{session_id}/memory", status_code=201)
    async def store_memory(session_id: str, body: MemoryStoreRequest):
        mem = sm.get_memory_store()
        mem.store(key=body.key, value=body.value)
        return {"status": "stored", "key": body.key}

    @app.get("/sessions/{session_id}/checkpoints", response_model=list[CheckpointInfo])
    async def list_checkpoints(session_id: str):
        serializer = sm.get_state_serializer()
        return serializer.list_checkpoints(session_id)

    @app.post("/sessions/{session_id}/checkpoints", response_model=CheckpointInfo, status_code=201)
    async def create_checkpoint(session_id: str, label: str | None = None):
        serializer = sm.get_state_serializer()
        try:
            return serializer.create_checkpoint(session_id, label=label)
        except ValueError as exc:
            # The serializer raises ValueError when the session does not exist.
            raise HTTPException(status_code=404, detail=str(exc)) from exc

    return app
|
|
@@ -0,0 +1,156 @@
|
|
|
1
|
+
"""Session lifecycle management — create, start, pause, resume, stop."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
import logging
|
|
7
|
+
import uuid
|
|
8
|
+
from datetime import datetime, timezone
|
|
9
|
+
from typing import Any
|
|
10
|
+
|
|
11
|
+
from .memory_store import MemoryStore
|
|
12
|
+
from .models import SessionCreate, SessionState, SessionStatus
|
|
13
|
+
from .state_serializer import StateSerializer
|
|
14
|
+
|
|
15
|
+
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class SessionError(Exception):
    """Raised when a session is missing or is in the wrong state for an operation."""
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class SessionManager:
    """Drive session lifecycle transitions and their background loops.

    Session state is persisted through a ``StateSerializer``. Every running
    session owns one asyncio task (``_session_loop``) that takes periodic
    auto-checkpoints; the task is tracked in ``_running_tasks`` by session id.
    """

    def __init__(
        self,
        state_serializer: StateSerializer | None = None,
        memory_store: MemoryStore | None = None,
    ) -> None:
        """Create a manager, building default collaborators when none are given.

        Args:
            state_serializer: Persistence backend; a default ``StateSerializer``
                is created when ``None``.
            memory_store: Memory backend; a default ``MemoryStore`` is created
                when ``None``.
        """
        # Compare against None explicitly: an injected object that happens to
        # be falsy (e.g. an empty container-like store) must not be replaced.
        self._state_serializer = (
            state_serializer if state_serializer is not None else StateSerializer()
        )
        self._memory_store = memory_store if memory_store is not None else MemoryStore()
        self._running_tasks: dict[str, asyncio.Task] = {}
        self._heartbeat_intervals: dict[str, float] = {}
        # Seconds between automatic checkpoints taken by the session loop.
        self._auto_checkpoint_interval = 60.0

    # ------------------------------------------------------------------ helpers

    def _require_session(self, session_id: str) -> SessionState:
        """Load a session or raise ``SessionError`` when it does not exist."""
        state = self._state_serializer.load_session(session_id)
        if state is None:
            raise SessionError(f"Session {session_id} not found")
        return state

    def _persist_status(self, state: SessionState, status: SessionStatus) -> None:
        """Set ``status`` on ``state``, refresh ``updated_at`` and save it."""
        state.status = status
        state.updated_at = datetime.now(timezone.utc)
        self._state_serializer.save_session(state)

    def _launch_loop(self, session_id: str) -> None:
        """Start the background loop for a session and register its task."""
        self._running_tasks[session_id] = asyncio.create_task(
            self._session_loop(session_id)
        )

    async def _cancel_task(self, session_id: str) -> None:
        """Unregister the session's loop task and cancel it if still running."""
        task = self._running_tasks.pop(session_id, None)
        if task is None or task.done():
            return
        task.cancel()
        try:
            # Wait for the loop to unwind so no work overlaps what follows.
            await task
        except asyncio.CancelledError:
            pass

    # ---------------------------------------------------------------- lifecycle

    def create_session(self, name: str, model: str, config: dict[str, Any] | None = None) -> SessionState:
        """Create and persist a new session.

        Args:
            name: Human-readable session name.
            model: Model identifier stored on the session.
            config: Optional settings; ``system_prompt`` and ``tools`` are read.

        Returns:
            The newly persisted ``SessionState``.
        """
        cfg = config or {}
        state = SessionState(
            session_id=uuid.uuid4().hex[:16],
            name=name,
            model=model,
            system_prompt=cfg.get("system_prompt", ""),
            tools=cfg.get("tools", []),
        )
        self._state_serializer.save_session(state)
        logger.info("Created session %s (%s)", state.session_id, name)
        return state

    async def start_session(self, session_id: str) -> SessionState:
        """Mark a session as running and start its background loop.

        Raises:
            SessionError: If the session is missing or already running.
        """
        state = self._require_session(session_id)
        if state.status == SessionStatus.RUNNING:
            raise SessionError(f"Session {session_id} is already running")
        self._persist_status(state, SessionStatus.RUNNING)
        self._launch_loop(session_id)
        logger.info("Started session %s", session_id)
        return state

    async def pause_session(self, session_id: str) -> SessionState:
        """Pause a running session, stop its loop and take a checkpoint.

        Raises:
            SessionError: If the session is missing or not running.
        """
        state = self._require_session(session_id)
        if state.status != SessionStatus.RUNNING:
            raise SessionError(f"Session {session_id} is not running")
        self._persist_status(state, SessionStatus.PAUSED)
        await self._cancel_task(session_id)
        self._state_serializer.create_checkpoint(session_id, label="pause-checkpoint")
        logger.info("Paused session %s", session_id)
        return state

    async def resume_session(self, session_id: str, checkpoint_id: str | None = None) -> SessionState:
        """Resume a session, optionally restoring a checkpoint first.

        Args:
            session_id: Session to resume.
            checkpoint_id: When given, state is restored from this checkpoint.

        Raises:
            SessionError: If the session does not exist.
            ValueError: Propagated from the serializer when the checkpoint
                does not belong to the session.
        """
        state = self._require_session(session_id)
        if checkpoint_id:
            state = self._state_serializer.resume_from_checkpoint(session_id, checkpoint_id)
        # Resuming a session whose loop is still alive must not leave that loop
        # orphaned: once the registry entry is overwritten nothing could ever
        # cancel it again.
        await self._cancel_task(session_id)
        self._persist_status(state, SessionStatus.RUNNING)
        self._launch_loop(session_id)
        logger.info("Resumed session %s", session_id)
        return state

    async def stop_session(self, session_id: str) -> SessionState:
        """Stop a session, cancel its loop and take a final checkpoint.

        Raises:
            SessionError: If the session does not exist.
        """
        state = self._require_session(session_id)
        self._persist_status(state, SessionStatus.STOPPED)
        await self._cancel_task(session_id)
        self._state_serializer.create_checkpoint(session_id, label="stop-checkpoint")
        logger.info("Stopped session %s", session_id)
        return state

    # ------------------------------------------------------------------ queries

    def list_sessions(self) -> list[SessionState]:
        """Return every persisted session as reported by the serializer."""
        return self._state_serializer.list_sessions()

    def get_session_status(self, session_id: str) -> SessionState:
        """Return the persisted state of one session.

        Raises:
            SessionError: If the session does not exist.
        """
        return self._require_session(session_id)

    def get_memory_store(self) -> MemoryStore:
        """Return the memory store used by this manager."""
        return self._memory_store

    def get_state_serializer(self) -> StateSerializer:
        """Return the state serializer used by this manager."""
        return self._state_serializer

    # ------------------------------------------------------------- background

    async def _session_loop(self, session_id: str) -> None:
        """Wake once per second and auto-checkpoint the session when due."""
        loop = asyncio.get_running_loop()
        last_checkpoint = loop.time()
        try:
            while True:
                await asyncio.sleep(1.0)
                now = loop.time()
                if now - last_checkpoint < self._auto_checkpoint_interval:
                    continue
                last_checkpoint = now
                try:
                    state = self._state_serializer.load_session(session_id)
                    if state and state.status == SessionStatus.RUNNING:
                        self._state_serializer.create_checkpoint(
                            session_id, label="auto-checkpoint"
                        )
                        self._state_serializer.prune_checkpoints(session_id, keep_last=10)
                except Exception:
                    # One failed checkpoint must not kill the loop silently;
                    # record it and try again at the next interval.
                    logger.exception("Auto-checkpoint failed for session %s", session_id)
        except asyncio.CancelledError:
            logger.info("Session loop cancelled for %s", session_id)

    async def shutdown(self) -> None:
        """Cancel every session loop and take a shutdown checkpoint for each."""
        logger.info("Shutting down session manager, preserving %d active sessions", len(self._running_tasks))
        for session_id in list(self._running_tasks):
            await self._cancel_task(session_id)
            try:
                self._state_serializer.create_checkpoint(session_id, label="shutdown-checkpoint")
            except Exception:
                # Keep going so one bad session cannot block the others.
                logger.exception("Shutdown checkpoint failed for session %s", session_id)
        self._running_tasks.clear()
        logger.info("All sessions preserved and shut down")
|
|
@@ -0,0 +1,222 @@
|
|
|
1
|
+
"""Serialize and deserialize agent state to/from SQLite."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
import logging
|
|
7
|
+
import uuid
|
|
8
|
+
from datetime import datetime, timezone
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
from typing import Any
|
|
11
|
+
|
|
12
|
+
from sqlalchemy import (
|
|
13
|
+
Column,
|
|
14
|
+
DateTime,
|
|
15
|
+
Index,
|
|
16
|
+
LargeBinary,
|
|
17
|
+
String,
|
|
18
|
+
Text,
|
|
19
|
+
create_engine,
|
|
20
|
+
text,
|
|
21
|
+
)
|
|
22
|
+
from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker
|
|
23
|
+
|
|
24
|
+
from .models import (
|
|
25
|
+
CheckpointInfo,
|
|
26
|
+
Message,
|
|
27
|
+
SessionState,
|
|
28
|
+
SessionStatus,
|
|
29
|
+
ToolCall,
|
|
30
|
+
)
|
|
31
|
+
|
|
32
|
+
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
|
33
|
+
|
|
34
|
+
# SQLite file under the user's home directory, used when no db_path is supplied.
DEFAULT_DB_PATH = Path.home() / ".agent_runtime" / "state.db"
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class Base(DeclarativeBase):
    """Declarative base class shared by the ORM models in this module."""
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class SessionRow(Base):
    """ORM row for the ``sessions`` table: one persisted session per row."""

    __tablename__ = "sessions"

    session_id = Column(String(64), primary_key=True)
    name = Column(String(256), default="")
    model = Column(String(128), default="")
    system_prompt = Column(Text, default="")
    # Text columns below hold JSON documents; the defaults are empty JSON values.
    tools = Column(Text, default="[]")
    status = Column(String(32), default=SessionStatus.CREATED.value)
    memory_json = Column(Text, default="{}")
    messages_json = Column(Text, default="[]")
    tool_history_json = Column(Text, default="[]")
    # Callables so that the timestamp is computed per insert, not at import time.
    created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
    updated_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
class CheckpointRow(Base):
    """ORM row for the ``checkpoints`` table: one stored snapshot per row."""

    __tablename__ = "checkpoints"

    checkpoint_id = Column(String(64), primary_key=True)
    session_id = Column(String(64), nullable=False, index=True)
    label = Column(String(256), nullable=True)
    # JSON text of the session state captured by this checkpoint.
    state_json = Column(Text, nullable=False)
    # Callable so that the timestamp is computed per insert, not at import time.
    created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))

    # Composite index over the columns used to list a session's checkpoints
    # in creation order.
    __table_args__ = (Index("ix_checkpoints_session_created", "session_id", "created_at"),)
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
class StateSerializer:
    """Persist session state and checkpoints in a SQLite database.

    Each public method opens its own short-lived SQLAlchemy session. Timestamps
    come back from SQLite without timezone information; they are written from
    UTC values, so they are tagged as UTC again when loaded.
    """

    def __init__(self, db_path: Path | str | None = None) -> None:
        """Open (creating if needed) the database and its tables.

        Args:
            db_path: Location of the SQLite file; ``DEFAULT_DB_PATH`` when ``None``.
        """
        if db_path is None:
            db_path = DEFAULT_DB_PATH
        self.db_path = Path(db_path)
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self.engine = create_engine(f"sqlite:///{self.db_path}", echo=False)
        Base.metadata.create_all(self.engine)
        self._session_factory = sessionmaker(bind=self.engine)

    def _get_session(self) -> Session:
        """Return a new SQLAlchemy session bound to this serializer's engine."""
        return self._session_factory()

    @staticmethod
    def _as_utc(value: datetime | None) -> datetime | None:
        """Attach UTC to a naive datetime; pass aware values and ``None`` through."""
        if value is not None and value.tzinfo is None:
            return value.replace(tzinfo=timezone.utc)
        return value

    def save_session(self, state: SessionState) -> None:
        """Insert or update the row for ``state`` and commit.

        ``updated_at`` is always set to the current UTC time, regardless of the
        value carried by ``state``.
        """
        with self._get_session() as db:
            row = db.get(SessionRow, state.session_id)
            if row is None:
                row = SessionRow(session_id=state.session_id)
                db.add(row)
            row.name = state.name
            row.model = state.model
            row.system_prompt = state.system_prompt
            row.tools = json.dumps(state.tools)
            row.status = state.status.value
            row.memory_json = json.dumps(state.memory, default=str)
            row.messages_json = json.dumps(
                [m.model_dump(mode="json") for m in state.messages],
                default=str,
            )
            row.tool_history_json = json.dumps(
                [t.model_dump(mode="json") for t in state.tool_history],
                default=str,
            )
            row.created_at = state.created_at
            row.updated_at = datetime.now(timezone.utc)
            db.commit()
        logger.info("Saved session %s", state.session_id)

    def load_session(self, session_id: str) -> SessionState | None:
        """Return the stored state for ``session_id``, or ``None`` if absent."""
        with self._get_session() as db:
            row = db.get(SessionRow, session_id)
            if row is None:
                return None
            return self._row_to_state(row)

    def delete_session(self, session_id: str) -> bool:
        """Delete a session and all of its checkpoints.

        Returns:
            ``True`` if the session existed and was removed, ``False`` otherwise.
        """
        with self._get_session() as db:
            row = db.get(SessionRow, session_id)
            if row is None:
                return False
            db.delete(row)
            # Checkpoints are removed explicitly with a parameterized statement.
            db.execute(
                text("DELETE FROM checkpoints WHERE session_id = :sid"),
                {"sid": session_id},
            )
            db.commit()
            return True

    def list_sessions(self) -> list[SessionState]:
        """Return all stored sessions ordered by creation time."""
        with self._get_session() as db:
            rows = db.query(SessionRow).order_by(SessionRow.created_at).all()
            return [self._row_to_state(r) for r in rows]

    def create_checkpoint(self, session_id: str, label: str | None = None) -> CheckpointInfo:
        """Snapshot the current stored state of a session.

        Args:
            session_id: Session to snapshot.
            label: Optional label stored with the checkpoint.

        Returns:
            Metadata describing the new checkpoint.

        Raises:
            ValueError: If the session does not exist.
        """
        state = self.load_session(session_id)
        if state is None:
            raise ValueError(f"Session {session_id} not found")
        checkpoint_id = uuid.uuid4().hex[:16]
        state_json = json.dumps(state.model_dump(mode="json"), default=str)
        with self._get_session() as db:
            row = CheckpointRow(
                checkpoint_id=checkpoint_id,
                session_id=session_id,
                label=label,
                state_json=state_json,
            )
            db.add(row)
            db.commit()
            # Must be read while the session is still open.
            created_at = self._as_utc(row.created_at)
        logger.info("Created checkpoint %s for session %s", checkpoint_id, session_id)
        return CheckpointInfo(
            checkpoint_id=checkpoint_id,
            session_id=session_id,
            created_at=created_at,
            label=label,
        )

    def list_checkpoints(self, session_id: str) -> list[CheckpointInfo]:
        """Return checkpoint metadata for a session, oldest first."""
        with self._get_session() as db:
            rows = (
                db.query(CheckpointRow)
                .filter(CheckpointRow.session_id == session_id)
                .order_by(CheckpointRow.created_at)
                .all()
            )
            return [
                CheckpointInfo(
                    checkpoint_id=r.checkpoint_id,
                    session_id=r.session_id,
                    created_at=self._as_utc(r.created_at),
                    label=r.label,
                )
                for r in rows
            ]

    def resume_from_checkpoint(self, session_id: str, checkpoint_id: str) -> SessionState:
        """Restore a session's stored state from one of its checkpoints.

        The restored state is saved with status ``PAUSED`` and returned.

        Raises:
            ValueError: If the checkpoint is missing or belongs to another session.
        """
        with self._get_session() as db:
            row = db.get(CheckpointRow, checkpoint_id)
            if row is None or row.session_id != session_id:
                raise ValueError(f"Checkpoint {checkpoint_id} not found for session {session_id}")
            state_data = json.loads(row.state_json)
        state = SessionState.model_validate(state_data)
        state.status = SessionStatus.PAUSED
        self.save_session(state)
        logger.info("Resumed session %s from checkpoint %s", session_id, checkpoint_id)
        return state

    def prune_checkpoints(
        self, session_id: str, keep_last: int = 10
    ) -> int:
        """Delete all but the newest ``keep_last`` checkpoints of a session.

        Args:
            session_id: Session whose checkpoints are pruned.
            keep_last: Number of most recent checkpoints to keep; ``0`` removes all.

        Returns:
            The number of checkpoints deleted.

        Raises:
            ValueError: If ``keep_last`` is negative.
        """
        # A negative value would slice from the wrong end and silently delete
        # only the oldest rows, so reject it before touching the database.
        if keep_last < 0:
            raise ValueError("keep_last must be non-negative")
        with self._get_session() as db:
            rows = (
                db.query(CheckpointRow)
                .filter(CheckpointRow.session_id == session_id)
                .order_by(CheckpointRow.created_at.desc())
                .all()
            )
            # Rows are newest first, so everything past keep_last is stale.
            to_delete = rows[keep_last:]
            for row in to_delete:
                db.delete(row)
            db.commit()
        count = len(to_delete)
        if count:
            logger.info("Pruned %d old checkpoints for session %s", count, session_id)
        return count

    @staticmethod
    def _row_to_state(row: SessionRow) -> SessionState:
        """Convert a ``SessionRow`` into a ``SessionState`` model."""
        messages = [Message.model_validate(m) for m in json.loads(row.messages_json)]
        tool_history = [ToolCall.model_validate(t) for t in json.loads(row.tool_history_json)]
        return SessionState(
            session_id=row.session_id,
            name=row.name,
            model=row.model,
            system_prompt=row.system_prompt,
            tools=json.loads(row.tools),
            status=SessionStatus(row.status),
            memory=json.loads(row.memory_json),
            messages=messages,
            tool_history=tool_history,
            created_at=StateSerializer._as_utc(row.created_at),
            updated_at=StateSerializer._as_utc(row.updated_at),
        )
|