fableforge-agent-runtime 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agent_runtime/__init__.py +3 -0
- agent_runtime/cli.py +128 -0
- agent_runtime/daemon.py +172 -0
- agent_runtime/memory_store.py +244 -0
- agent_runtime/models.py +93 -0
- agent_runtime/server.py +119 -0
- agent_runtime/session_manager.py +156 -0
- agent_runtime/state_serializer.py +222 -0
- fableforge_agent_runtime-0.1.0.dist-info/METADATA +274 -0
- fableforge_agent_runtime-0.1.0.dist-info/RECORD +13 -0
- fableforge_agent_runtime-0.1.0.dist-info/WHEEL +4 -0
- fableforge_agent_runtime-0.1.0.dist-info/entry_points.txt +2 -0
- fableforge_agent_runtime-0.1.0.dist-info/licenses/LICENSE +21 -0
agent_runtime/cli.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
1
|
+
"""CLI interface for agentd — start, stop, create, list, resume."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
import sys
|
|
7
|
+
|
|
8
|
+
import click
|
|
9
|
+
|
|
10
|
+
from .daemon import AgentDaemon
|
|
11
|
+
from .models import SessionCreate
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@click.group()
@click.option("--verbose", "-v", is_flag=True, help="Enable verbose logging")
def app(verbose: bool) -> None:
    # Root command group: configure root logging once, before any subcommand
    # runs. --verbose lowers the threshold from INFO to DEBUG.
    import logging

    log_format = "%(asctime)s [%(levelname)s] %(name)s: %(message)s"
    logging.basicConfig(
        format=log_format,
        level=logging.DEBUG if verbose else logging.INFO,
    )
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@app.command()
@click.option("--host", default="0.0.0.0", help="Bind host")
@click.option("--port", default=8721, type=int, help="Bind port")
def start(host: str, port: int) -> None:
    # Run the daemon in the foreground. Exits with status 1 when the PID file
    # already points at a live agentd.
    # NOTE(review): the default host binds every interface; consider
    # 127.0.0.1 unless the HTTP API is meant to be reachable remotely.
    if AgentDaemon.is_running():
        click.echo("agentd is already running")
        sys.exit(1)

    click.echo(f"Starting agentd on {host}:{port}")
    # AgentDaemon.start() awaits the daemon's shutdown, so this call blocks
    # for the lifetime of the daemon.
    asyncio.run(AgentDaemon(host=host, port=port).start())
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@app.command()
def stop() -> None:
    # Ask a running agentd to terminate. Exits with status 1 when no live
    # daemon is recorded or the signal could not be delivered.
    if not AgentDaemon.is_running():
        click.echo("agentd is not running")
        sys.exit(1)

    # stop_running() sends SIGTERM and returns without waiting for the exit.
    signalled = AgentDaemon.stop_running()
    if not signalled:
        click.echo("Failed to stop agentd", err=True)
        sys.exit(1)
    click.echo("Stopped agentd")
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
@app.command()
@click.option("--name", required=True, help="Session name")
@click.option("--model", required=True, help="Model identifier")
@click.option("--system-prompt", default="", help="System prompt")
@click.option("--tools", multiple=True, help="Tool names")
def create(name: str, model: str, system_prompt: str, tools: tuple[str, ...]) -> None:
    # Create a session through a locally constructed SessionManager and print
    # a short summary of it. Repeated --tools options arrive as a tuple.
    from .session_manager import SessionManager

    manager = SessionManager()
    session_config = {"system_prompt": system_prompt, "tools": list(tools)}
    new_state = manager.create_session(name=name, model=model, config=session_config)

    summary_lines = (
        f"Created session: {new_state.session_id}",
        f" Name: {new_state.name}",
        f" Model: {new_state.model}",
        f" Status: {new_state.status.value}",
    )
    for line in summary_lines:
        click.echo(line)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
@app.command("list")
def list_sessions() -> None:
    # Print every known session as a fixed-width table, or a notice when
    # there are none.
    from .session_manager import SessionManager

    known = SessionManager().list_sessions()
    if known:
        click.echo(f"{'ID':<20} {'Name':<20} {'Model':<20} {'Status':<10} {'Created'}")
        click.echo("-" * 90)
        for session in known:
            click.echo(
                f"{session.session_id:<20} {session.name:<20} {session.model:<20} "
                f"{session.status.value:<10} {session.created_at.isoformat()}"
            )
    else:
        click.echo("No sessions found")
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
@app.command()
@click.argument("session_id")
@click.option("--checkpoint-id", default=None, help="Resume from specific checkpoint")
def resume(session_id: str, checkpoint_id: str | None) -> None:
    # Resume a session, optionally from a specific checkpoint. Any failure is
    # reported on stderr and turned into exit status 1.
    from .session_manager import SessionManager

    manager = SessionManager()
    try:
        resume_coro = manager.resume_session(session_id, checkpoint_id=checkpoint_id)
        resumed = asyncio.run(resume_coro)
        click.echo(f"Resumed session: {resumed.session_id}")
        click.echo(f" Status: {resumed.status.value}")
        if checkpoint_id:
            click.echo(f" From checkpoint: {checkpoint_id}")
    except Exception as exc:
        click.echo(f"Error: {exc}", err=True)
        sys.exit(1)
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
@app.command()
@click.argument("session_id")
def pause(session_id: str) -> None:
    # Pause a session; failures go to stderr with exit status 1.
    from .session_manager import SessionManager

    manager = SessionManager()
    try:
        pause_coro = manager.pause_session(session_id)
        paused = asyncio.run(pause_coro)
        click.echo(f"Paused session: {paused.session_id}")
    except Exception as exc:
        click.echo(f"Error: {exc}", err=True)
        sys.exit(1)
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
@app.command()
@click.argument("session_id")
def stop_session(session_id: str) -> None:
    # Stop a single session (not the daemon); failures go to stderr with
    # exit status 1.
    from .session_manager import SessionManager

    manager = SessionManager()
    try:
        stop_coro = manager.stop_session(session_id)
        stopped = asyncio.run(stop_coro)
        click.echo(f"Stopped session: {stopped.session_id}")
    except Exception as exc:
        click.echo(f"Error: {exc}", err=True)
        sys.exit(1)
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
# Allow running the CLI module directly (e.g. `python -m agent_runtime.cli`).
# NOTE(review): the installed console script is presumably declared in
# entry_points.txt, whose contents are not shown here.
if __name__ == "__main__":
    app()
|
agent_runtime/daemon.py
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
1
|
+
"""Main daemon process — background agent manager with heartbeat and auto-checkpoint."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
import logging
|
|
7
|
+
import os
|
|
8
|
+
import signal
|
|
9
|
+
import sys
|
|
10
|
+
import time
|
|
11
|
+
from pathlib import Path
|
|
12
|
+
from typing import Any
|
|
13
|
+
|
|
14
|
+
from .memory_store import MemoryStore
|
|
15
|
+
from .models import HealthResponse
|
|
16
|
+
from .session_manager import SessionManager
|
|
17
|
+
from .state_serializer import StateSerializer
|
|
18
|
+
|
|
19
|
+
logger = logging.getLogger(__name__)

# PID file used by the CLI to detect and signal a running agentd.
PID_DIR = Path.home() / ".agent_runtime"
PID_FILE = PID_DIR / "agentd.pid"


class AgentDaemon:
    """Background agent manager: HTTP API plus heartbeat and cleanup loops.

    :meth:`start` records this process's PID in ``PID_FILE`` and installs
    SIGINT/SIGTERM handlers. It then launches the heartbeat and
    checkpoint-cleanup loops plus a uvicorn server for
    ``server.create_app(self)``. It blocks until a shutdown signal arrives or
    the HTTP server task finishes, whichever comes first.

    ``auto_checkpoint_interval`` and ``session_timeout`` are stored but are
    not read by any loop in this class.
    """

    # Seconds between heartbeat passes over the session list.
    HEARTBEAT_INTERVAL = 10.0
    # Seconds between checkpoint-cleanup passes.
    CLEANUP_INTERVAL = 300.0
    # A stopped session is pruned once it has more than this many checkpoints...
    CLEANUP_TRIGGER = 10
    # ...passing this value as ``keep_last`` to prune_checkpoints().
    CLEANUP_KEEP = 5
    # Seconds uvicorn gets to finish its own shutdown before being cancelled.
    SERVER_SHUTDOWN_TIMEOUT = 10.0

    def __init__(
        self,
        state_serializer: StateSerializer | None = None,
        memory_store: MemoryStore | None = None,
        host: str = "0.0.0.0",
        port: int = 8721,
        auto_checkpoint_interval: float = 60.0,
        session_timeout: float = 3600.0,
    ) -> None:
        """Wire up persistence and record configuration.

        No tasks or servers are started here; see :meth:`start`.

        Args:
            state_serializer: Checkpoint backend; ``StateSerializer()`` if None.
            memory_store: Memory backend; ``MemoryStore()`` if None.
            host: Bind address for the HTTP API.
            port: Bind port for the HTTP API.
            auto_checkpoint_interval: Stored only; not used by this class.
            session_timeout: Stored only; not used by this class.
        """
        self._state_serializer = state_serializer or StateSerializer()
        self._memory_store = memory_store or MemoryStore()
        self._session_manager = SessionManager(
            state_serializer=self._state_serializer,
            memory_store=self._memory_store,
        )
        self._host = host
        self._port = port
        self._auto_checkpoint_interval = auto_checkpoint_interval
        self._session_timeout = session_timeout
        # Uptime is measured from construction, not from start().
        self._start_time = time.monotonic()
        self._shutdown_event = asyncio.Event()
        self._heartbeat_task: asyncio.Task | None = None
        self._cleanup_task: asyncio.Task | None = None
        self._server_task: asyncio.Task | None = None
        # uvicorn.Server instance, kept so shutdown can ask it to exit cleanly.
        self._server: Any = None

    @property
    def session_manager(self) -> SessionManager:
        """The SessionManager sharing this daemon's serializer and memory store."""
        return self._session_manager

    @property
    def uptime_seconds(self) -> float:
        """Seconds elapsed since this daemon object was constructed."""
        return time.monotonic() - self._start_time

    def health(self) -> HealthResponse:
        """Return a health snapshot; status is always "healthy" here."""
        active = sum(
            1
            for s in self._session_manager.list_sessions()
            if s.status.value == "running"
        )
        return HealthResponse(
            status="healthy",
            uptime_seconds=self.uptime_seconds,
            active_sessions=active,
        )

    async def start(self) -> None:
        """Run the daemon until shutdown, always removing our PID file on exit."""
        logger.info("AgentDaemon starting on %s:%d", self._host, self._port)
        PID_DIR.mkdir(parents=True, exist_ok=True)
        PID_FILE.write_text(str(os.getpid()))

        # Everything after the PID file is written sits inside try/finally so
        # that a setup failure cannot leave a stale PID file behind.
        try:
            self._install_signal_handlers()
            self._heartbeat_task = asyncio.create_task(self._heartbeat_loop())
            self._cleanup_task = asyncio.create_task(self._cleanup_loop())

            from .server import create_app
            import uvicorn

            app = create_app(self)
            config = uvicorn.Config(app, host=self._host, port=self._port, log_level="info")
            self._server = uvicorn.Server(config)
            self._server_task = asyncio.create_task(self._server.serve())
            logger.info("AgentDaemon HTTP server started on %s:%d", self._host, self._port)

            # Wake on either a shutdown signal or the HTTP server finishing by
            # itself (crash, or uvicorn handling the signal on its own).
            # Waiting on the event alone would hang forever in the latter case.
            shutdown_waiter = asyncio.create_task(self._shutdown_event.wait())
            try:
                await asyncio.wait(
                    {shutdown_waiter, self._server_task},
                    return_when=asyncio.FIRST_COMPLETED,
                )
            finally:
                shutdown_waiter.cancel()
            if not self._shutdown_event.is_set():
                logger.info("HTTP server task finished; shutting down")
            await self._graceful_shutdown()
        finally:
            self._remove_pid()

    def _install_signal_handlers(self) -> None:
        """Route SIGINT/SIGTERM to :meth:`_handle_signal` where supported."""
        loop = asyncio.get_running_loop()
        for sig in (signal.SIGINT, signal.SIGTERM):
            try:
                loop.add_signal_handler(sig, self._handle_signal, sig)
            except (NotImplementedError, RuntimeError) as exc:
                # Unavailable on Windows event loops and outside the main
                # thread; Ctrl-C then surfaces as KeyboardInterrupt instead.
                logger.warning("Cannot install handler for %s: %s", sig.name, exc)

    async def _graceful_shutdown(self) -> None:
        """Stop background loops and the HTTP server, then shut sessions down."""
        logger.info("Starting graceful shutdown...")
        for task in (self._heartbeat_task, self._cleanup_task):
            if task is not None and not task.done():
                task.cancel()

        server_task = self._server_task
        if server_task is not None and not server_task.done():
            if self._server is not None:
                # Let uvicorn run its own shutdown sequence instead of
                # cancelling it mid-request; cancel only if it overruns.
                self._server.should_exit = True
                await asyncio.wait({server_task}, timeout=self.SERVER_SHUTDOWN_TIMEOUT)
            if not server_task.done():
                server_task.cancel()

        # Await every task so cancellation has completed before sessions shut
        # down, and so a crashed task is reported rather than dropped.
        tasks = [t for t in (self._heartbeat_task, self._cleanup_task, server_task) if t is not None]
        results = await asyncio.gather(*tasks, return_exceptions=True)
        for task, result in zip(tasks, results):
            # CancelledError is a BaseException, so expected cancellations are skipped.
            if isinstance(result, Exception):
                logger.error("Task %s ended with error: %r", task.get_name(), result)

        await self._session_manager.shutdown()
        logger.info("Graceful shutdown complete")

    def _handle_signal(self, sig: signal.Signals) -> None:
        """Event-loop signal callback: request shutdown."""
        logger.info("Received signal %s, initiating shutdown", sig.name)
        self._shutdown_event.set()

    async def _heartbeat_loop(self) -> None:
        """Every HEARTBEAT_INTERVAL seconds, log each running session at DEBUG.

        A failing pass is logged and retried on the next tick instead of
        silently ending the loop.
        """
        try:
            while True:
                try:
                    for s in self._session_manager.list_sessions():
                        if s.status.value == "running":
                            logger.debug("Heartbeat: session %s is running", s.session_id)
                except Exception:
                    logger.exception("Heartbeat pass failed")
                await asyncio.sleep(self.HEARTBEAT_INTERVAL)
        except asyncio.CancelledError:
            logger.info("Heartbeat loop cancelled")
            # Re-raise so the task is actually marked cancelled.
            raise

    async def _cleanup_loop(self) -> None:
        """Every CLEANUP_INTERVAL seconds, prune checkpoints of stopped sessions."""
        try:
            while True:
                try:
                    self._prune_stopped_sessions()
                except Exception:
                    logger.exception("Cleanup pass failed")
                await asyncio.sleep(self.CLEANUP_INTERVAL)
        except asyncio.CancelledError:
            logger.info("Cleanup loop cancelled")
            raise

    def _prune_stopped_sessions(self) -> None:
        """One cleanup pass; a failure on one session does not skip the others."""
        for s in self._session_manager.list_sessions():
            if s.status.value != "stopped":
                continue
            try:
                checkpoints = self._state_serializer.list_checkpoints(s.session_id)
                if len(checkpoints) > self.CLEANUP_TRIGGER:
                    pruned = self._state_serializer.prune_checkpoints(
                        s.session_id, keep_last=self.CLEANUP_KEEP
                    )
                    logger.info("Cleanup: pruned %d checkpoints for session %s", pruned, s.session_id)
            except Exception:
                logger.exception("Cleanup failed for session %s", s.session_id)

    def _remove_pid(self) -> None:
        """Delete PID_FILE, but only while it still records this process."""
        try:
            recorded = PID_FILE.read_text().strip()
        except OSError:  # includes FileNotFoundError
            return
        if recorded != str(os.getpid()):
            # Another agentd instance has claimed the PID file since; keep it.
            return
        AgentDaemon._unlink_pid_file()

    @staticmethod
    def _unlink_pid_file() -> None:
        """Best-effort removal of PID_FILE; OS errors are ignored."""
        try:
            PID_FILE.unlink(missing_ok=True)
        except OSError:
            pass

    @staticmethod
    def _read_pid() -> int:
        """Return the PID stored in PID_FILE.

        Raises:
            FileNotFoundError: PID_FILE does not exist.
            ValueError: the content is not a positive integer. A value of 0 or
                less would address a whole process group in ``os.kill``.
            OSError: PID_FILE cannot be read.
        """
        pid = int(PID_FILE.read_text().strip())
        if pid <= 0:
            raise ValueError(f"invalid pid {pid} in {PID_FILE}")
        return pid

    @staticmethod
    def is_running() -> bool:
        """Return True if PID_FILE names a live process.

        A missing file, an unreadable or corrupt file, or a dead PID yields
        False. In the last two cases the stale PID file is also removed.
        """
        try:
            pid = AgentDaemon._read_pid()
        except FileNotFoundError:
            return False
        except (ValueError, OSError):
            AgentDaemon._unlink_pid_file()
            return False
        try:
            os.kill(pid, 0)
        except PermissionError:
            # The process exists but belongs to another user: still running.
            return True
        except (ProcessLookupError, OverflowError, OSError):
            AgentDaemon._unlink_pid_file()
            return False
        return True

    @staticmethod
    def stop_running() -> bool:
        """Send SIGTERM to the PID in PID_FILE; True if the signal was sent.

        Does not wait for the process to exit. If the process no longer
        exists, the stale PID file is removed and False is returned.
        """
        try:
            pid = AgentDaemon._read_pid()
        except FileNotFoundError:
            return False
        except (ValueError, OSError) as e:
            logger.error("Failed to stop agentd: %s", e)
            return False
        try:
            os.kill(pid, signal.SIGTERM)
        except ProcessLookupError:
            AgentDaemon._unlink_pid_file()
            return False
        except (OverflowError, OSError) as e:
            logger.error("Failed to stop agentd: %s", e)
            return False
        logger.info("Sent SIGTERM to agentd (pid=%d)", pid)
        return True
|
|
agent_runtime/memory_store.py
ADDED
|
@@ -0,0 +1,244 @@
|
|
|
1
|
+
"""Persistent memory store with short-term and long-term memory."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
import logging
|
|
7
|
+
import math
|
|
8
|
+
import uuid
|
|
9
|
+
from datetime import datetime, timezone
|
|
10
|
+
from pathlib import Path
|
|
11
|
+
from typing import Any
|
|
12
|
+
|
|
13
|
+
from sqlalchemy import (
|
|
14
|
+
Column,
|
|
15
|
+
DateTime,
|
|
16
|
+
Float,
|
|
17
|
+
Integer,
|
|
18
|
+
LargeBinary,
|
|
19
|
+
String,
|
|
20
|
+
Text,
|
|
21
|
+
create_engine,
|
|
22
|
+
text,
|
|
23
|
+
)
|
|
24
|
+
from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker
|
|
25
|
+
|
|
26
|
+
from .models import MemoryEntry, MemorySearchResult
|
|
27
|
+
|
|
28
|
+
logger = logging.getLogger(__name__)

# SQLite file used when MemoryStore() is constructed without an explicit path.
DEFAULT_DB_PATH = Path.home().joinpath(".agent_runtime", "memory.db")

# Default window size for MemoryStore.get_short_term(); also the basis of the
# trimming threshold in MemoryStore._prune_short_term().
SHORT_TERM_WINDOW = 50
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
class MemBase(DeclarativeBase):
    """Declarative base shared by all memory-store tables (one metadata)."""
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class LongTermMemoryRow(MemBase):
    """One long-term memory entry keyed by an arbitrary string (table ``long_term_memory``)."""

    __tablename__ = "long_term_memory"

    # Caller-supplied key; as the primary key it makes MemoryStore.store() an upsert.
    key = Column(String(512), primary_key=True)
    # JSON-encoded value (MemoryStore.store uses json.dumps(..., default=str)).
    value_json = Column(Text, nullable=False)
    # Optional embedding packed as consecutive doubles (see MemoryStore._encode_embedding).
    embedding_blob = Column(LargeBinary, nullable=True)
    # Python-side default is a UTC-aware datetime.
    # NOTE(review): a plain DateTime on SQLite likely reads back naive — confirm.
    timestamp = Column(DateTime, default=lambda: datetime.now(timezone.utc))
    # Incremented by both MemoryStore.store() and MemoryStore.retrieve().
    access_count = Column(Integer, default=1)
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class ShortTermMemoryRow(MemBase):
    """One message in a session's short-term log (table ``short_term_memory``)."""

    __tablename__ = "short_term_memory"

    # 16-hex-char random id assigned by MemoryStore.store_short_term().
    id = Column(String(64), primary_key=True)
    # Owning session; indexed because every MemoryStore query on this table filters by it.
    session_id = Column(String(64), nullable=False, index=True)
    # Role string exactly as passed to store_short_term().
    role = Column(String(32), nullable=False)
    content = Column(Text, default="")
    # Ordering key for get_short_term()/_prune_short_term(); UTC-aware default.
    timestamp = Column(DateTime, default=lambda: datetime.now(timezone.utc))
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class MemoryConsolidationRow(MemBase):
    """Record of one MemoryStore.consolidate() run (table ``memory_consolidations``)."""

    __tablename__ = "memory_consolidations"

    # 16-hex-char id; consolidate() also stores the summary under the
    # long-term key "consolidation:<session_id>:<id>".
    id = Column(String(64), primary_key=True)
    session_id = Column(String(64), nullable=False, index=True)
    # Newline-joined "[role] content" lines, each content cut to 200 characters.
    summary = Column(Text, default="")
    # Number of short-term messages that went into the summary.
    from_count = Column(Integer, default=0)
    created_at = Column(DateTime, default=lambda: datetime.now(timezone.utc))
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
class MemoryStore:
    """SQLite-backed agent memory.

    One database file holds two kinds of memory:

    * long-term: JSON-serialisable values under a string key, each with an
      optional embedding used by :meth:`search`;
    * short-term: a per-session log of ``(role, content)`` messages, trimmed
      by :meth:`_prune_short_term` and summarised by :meth:`consolidate`.

    Every public method opens its own ORM session, so calls are independent
    transactions.
    """

    def __init__(self, db_path: Path | str | None = None) -> None:
        """Open the memory database, creating it if needed.

        Args:
            db_path: SQLite file to use; ``None`` selects ``DEFAULT_DB_PATH``.
                Missing parent directories are created, as are missing tables.
        """
        if db_path is None:
            db_path = DEFAULT_DB_PATH
        self.db_path = Path(db_path)
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self.engine = create_engine(f"sqlite:///{self.db_path}", echo=False)
        MemBase.metadata.create_all(self.engine)
        self._session_factory = sessionmaker(bind=self.engine)

    def _get_session(self) -> Session:
        """Return a fresh ORM session; callers use it as a context manager."""
        return self._session_factory()

    def store(self, key: str, value: Any, embedding: list[float] | None = None) -> None:
        """Insert or overwrite the long-term entry *key*.

        The value is JSON-encoded with ``default=str``, so non-JSON types are
        stringified. The embedding is replaced (``None`` clears it). The
        timestamp is reset to now (UTC) and ``access_count`` is incremented.
        """
        with self._get_session() as db:
            row = db.get(LongTermMemoryRow, key)
            if row is None:
                row = LongTermMemoryRow(key=key)
                db.add(row)
            row.value_json = json.dumps(value, default=str)
            row.embedding_blob = self._encode_embedding(embedding)
            row.timestamp = datetime.now(timezone.utc)
            # The column default only applies at INSERT, so a new row is still None here.
            row.access_count = (row.access_count or 0) + 1
            db.commit()
        logger.info("Stored memory key=%s", key)

    def retrieve(self, key: str) -> MemoryEntry | None:
        """Return the entry for *key*, bumping its ``access_count``, or None if absent."""
        with self._get_session() as db:
            row = db.get(LongTermMemoryRow, key)
            if row is None:
                return None
            row.access_count = (row.access_count or 0) + 1
            db.commit()
            return MemoryEntry(
                key=row.key,
                value=json.loads(row.value_json),
                embedding=self._decode_embedding(row.embedding_blob),
                timestamp=row.timestamp,
            )

    def search(self, query_embedding: list[float], limit: int = 10) -> list[MemorySearchResult]:
        """Return up to *limit* entries ranked by cosine similarity to *query_embedding*.

        Entries without an embedding are skipped. Entries whose embedding has
        a different length score 0.0. Ties keep database row order. A
        ``limit`` of 0 or less returns an empty list.

        This is a linear scan over all long-term rows.
        """
        import heapq

        if limit <= 0:
            return []
        with self._get_session() as db:
            scored: list[tuple[float, LongTermMemoryRow]] = []
            for row in db.query(LongTermMemoryRow).all():
                emb = self._decode_embedding(row.embedding_blob)
                if emb is None:
                    continue
                scored.append((self._cosine_similarity(query_embedding, emb), row))
            # nlargest is documented as equal to sorted(..., reverse=True)[:n].
            # Only the winners get their JSON value decoded.
            best = heapq.nlargest(limit, scored, key=lambda pair: pair[0])
            return [
                MemorySearchResult(
                    key=row.key,
                    value=json.loads(row.value_json),
                    score=score,
                    timestamp=row.timestamp,
                )
                for score, row in best
            ]

    def store_short_term(self, session_id: str, role: str, content: str) -> None:
        """Append a message to *session_id*'s short-term log, then trim the log."""
        with self._get_session() as db:
            row = ShortTermMemoryRow(
                id=uuid.uuid4().hex[:16],
                session_id=session_id,
                role=role,
                content=content,
            )
            db.add(row)
            db.commit()
        self._prune_short_term(session_id)

    def get_short_term(self, session_id: str, last_n: int = SHORT_TERM_WINDOW) -> list[dict[str, Any]]:
        """Return the newest *last_n* messages of a session, oldest first.

        Each item is ``{"role", "content", "timestamp"}`` with the timestamp
        as an ISO-8601 string.
        """
        with self._get_session() as db:
            rows = (
                db.query(ShortTermMemoryRow)
                .filter(ShortTermMemoryRow.session_id == session_id)
                .order_by(ShortTermMemoryRow.timestamp.desc())
                .limit(last_n)
                .all()
            )
            rows.reverse()
            return [{"role": r.role, "content": r.content, "timestamp": r.timestamp.isoformat()} for r in rows]

    def consolidate(self, session_id: str) -> str | None:
        """Summarise recent short-term messages into a consolidation record.

        The summary is the newline-joined ``[role] content`` lines, with each
        content cut to 200 characters. It is saved as a MemoryConsolidationRow
        and also under the long-term key
        ``consolidation:<session_id>:<id>``. The short-term log is left
        untouched.

        Returns:
            The consolidation id, or None when fewer than 5 messages exist.
        """
        messages = self.get_short_term(session_id, last_n=SHORT_TERM_WINDOW)
        if len(messages) < 5:
            logger.info("Not enough short-term memory to consolidate for %s", session_id)
            return None
        summary = "\n".join(f"[{msg['role']}] {msg['content'][:200]}" for msg in messages)
        consolidation_id = uuid.uuid4().hex[:16]
        with self._get_session() as db:
            row = MemoryConsolidationRow(
                id=consolidation_id,
                session_id=session_id,
                summary=summary,
                from_count=len(messages),
            )
            db.add(row)
            db.commit()
        self.store(
            key=f"consolidation:{session_id}:{consolidation_id}",
            value={"summary": summary, "message_count": len(messages)},
        )
        logger.info("Consolidated %d messages for session %s", len(messages), session_id)
        return consolidation_id

    def list_keys(self, prefix: str = "") -> list[str]:
        """Return long-term keys, optionally only those starting with *prefix*.

        The prefix match is exact and case-sensitive. ``%`` and ``_`` in the
        prefix are matched literally.
        """
        with self._get_session() as db:
            query = db.query(LongTermMemoryRow.key)
            if prefix:
                # Narrow in SQL first. autoescape=True stops "%"/"_" in the
                # prefix acting as LIKE wildcards, but SQLite's LIKE is still
                # ASCII case-insensitive, hence the exact re-check below.
                query = query.filter(LongTermMemoryRow.key.startswith(prefix, autoescape=True))
            keys = [r[0] for r in query.all()]
        if prefix:
            keys = [k for k in keys if k.startswith(prefix)]
        return keys

    def delete(self, key: str) -> bool:
        """Delete the long-term entry *key*; return False if it did not exist."""
        with self._get_session() as db:
            row = db.get(LongTermMemoryRow, key)
            if row is None:
                return False
            db.delete(row)
            db.commit()
            return True

    def _prune_short_term(self, session_id: str) -> None:
        """Trim a session's short-term log to its newest SHORT_TERM_WINDOW rows.

        Trimming only starts once the log exceeds twice the window, so deletes
        are batched rather than issued on every insert.
        """
        with self._get_session() as db:
            session_rows = db.query(ShortTermMemoryRow).filter(
                ShortTermMemoryRow.session_id == session_id
            )
            if session_rows.count() <= SHORT_TERM_WINDOW * 2:
                return
            # Collect the ids to keep in Python. The previous raw SQL referenced
            # the Python variable name "subq" as a table, which SQLite rejects.
            keep_ids = [
                row_id
                for (row_id,) in db.query(ShortTermMemoryRow.id)
                .filter(ShortTermMemoryRow.session_id == session_id)
                .order_by(ShortTermMemoryRow.timestamp.desc())
                .limit(SHORT_TERM_WINDOW)
                .all()
            ]
            session_rows.filter(~ShortTermMemoryRow.id.in_(keep_ids)).delete(
                synchronize_session=False
            )
            db.commit()

    @staticmethod
    def _encode_embedding(embedding: list[float] | None) -> bytes | None:
        """Pack an embedding as consecutive native-order doubles; None stays None."""
        if embedding is None:
            return None
        import struct

        return struct.pack(f"{len(embedding)}d", *embedding)

    @staticmethod
    def _decode_embedding(blob: bytes | None) -> list[float] | None:
        """Inverse of :meth:`_encode_embedding` (8 bytes per value)."""
        if blob is None:
            return None
        import struct

        count = len(blob) // 8
        return list(struct.unpack(f"{count}d", blob))

    @staticmethod
    def _cosine_similarity(a: list[float], b: list[float]) -> float:
        """Cosine similarity of *a* and *b*; 0.0 for mismatched lengths or zero vectors."""
        if len(a) != len(b):
            return 0.0
        dot = sum(x * y for x, y in zip(a, b))
        mag_a = math.sqrt(sum(x * x for x in a))
        mag_b = math.sqrt(sum(x * x for x in b))
        if mag_a == 0 or mag_b == 0:
            return 0.0
        return dot / (mag_a * mag_b)
|
agent_runtime/models.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import uuid
|
|
4
|
+
from datetime import datetime, timezone
|
|
5
|
+
from enum import Enum
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
from pydantic import BaseModel, Field
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class SessionStatus(str, Enum):
    """Lifecycle state of an agent session.

    Being a ``str`` subclass, each member compares equal to its string value;
    other modules compare ``status.value`` against these strings.
    """

    CREATED = "created"  # default for a new SessionState
    RUNNING = "running"
    PAUSED = "paused"
    STOPPED = "stopped"
    ERROR = "error"
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class ToolCall(BaseModel):
    """A tool invocation: tool name, keyword arguments and (optional) result."""

    name: str
    arguments: dict[str, Any] = Field(default_factory=dict)
    # Defaults to None; presumably filled in once the tool has run — confirm with callers.
    result: Any = None
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class Message(BaseModel):
    """One conversation message with any tool calls it carried."""

    role: str
    content: str = ""
    tool_calls: list[ToolCall] = Field(default_factory=list)
    # Defaults to the current time in UTC.
    timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class SessionCreate(BaseModel):
    """Parameters for creating a session: name, model, system prompt, tool names."""

    name: str
    model: str
    system_prompt: str = ""
    tools: list[str] = Field(default_factory=list)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class SessionState(BaseModel):
    """Complete state of one agent session.

    Holds configuration (name, model, system prompt, tool names), the
    lifecycle status, free-form memory, and the message and tool-call
    histories.
    """

    # Defaults to the first 16 hex characters of a random UUID4.
    session_id: str = Field(default_factory=lambda: uuid.uuid4().hex[:16])
    name: str = ""
    model: str = ""
    system_prompt: str = ""
    tools: list[str] = Field(default_factory=list)
    status: SessionStatus = SessionStatus.CREATED
    memory: dict[str, Any] = Field(default_factory=dict)
    messages: list[Message] = Field(default_factory=list)
    tool_history: list[ToolCall] = Field(default_factory=list)
    # Both default to the current UTC time. This model does not refresh
    # updated_at on mutation; presumably callers set it — confirm.
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    updated_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
class SessionResume(BaseModel):
    """Identifies a session to resume, optionally at a specific checkpoint."""

    session_id: str
    # NOTE(review): None presumably means "latest state" — confirm in SessionManager.
    checkpoint_id: str | None = None
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class MemoryEntry(BaseModel):
    """A long-term memory entry, as returned by MemoryStore.retrieve()."""

    key: str
    # Decoded JSON value.
    value: Any
    embedding: list[float] | None = None
    timestamp: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class MemorySearchResult(BaseModel):
    """One hit from MemoryStore.search()."""

    key: str
    value: Any
    # Cosine similarity to the query (0.0 for mismatched lengths or zero vectors).
    score: float
    timestamp: datetime
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class CheckpointInfo(BaseModel):
    """Metadata for a session checkpoint: id, owning session, creation time, label."""

    checkpoint_id: str
    session_id: str
    created_at: datetime
    label: str | None = None
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
class MemoryStoreRequest(BaseModel):
    """Key/value pair to store in long-term memory.

    NOTE(review): presumably the HTTP request body for the memory-store
    endpoint (server.py not visible). Unlike MemoryStore.store(), it has no
    embedding field.
    """

    key: str
    value: Any
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
class MemoryRetrieveResponse(BaseModel):
    """A retrieved long-term memory value without its embedding.

    NOTE(review): presumably the HTTP response for memory retrieval —
    server.py not visible.
    """

    key: str
    value: Any
    timestamp: datetime
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
class HealthResponse(BaseModel):
    """Daemon health snapshot built by AgentDaemon.health()."""

    # AgentDaemon.health() always reports "healthy".
    status: str
    uptime_seconds: float
    # Number of sessions whose status is "running".
    active_sessions: int
|