roampal 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- roampal/__init__.py +29 -0
- roampal/__main__.py +6 -0
- roampal/backend/__init__.py +1 -0
- roampal/backend/modules/__init__.py +1 -0
- roampal/backend/modules/memory/__init__.py +43 -0
- roampal/backend/modules/memory/chromadb_adapter.py +623 -0
- roampal/backend/modules/memory/config.py +102 -0
- roampal/backend/modules/memory/content_graph.py +543 -0
- roampal/backend/modules/memory/context_service.py +455 -0
- roampal/backend/modules/memory/embedding_service.py +96 -0
- roampal/backend/modules/memory/knowledge_graph_service.py +1052 -0
- roampal/backend/modules/memory/memory_bank_service.py +433 -0
- roampal/backend/modules/memory/memory_types.py +296 -0
- roampal/backend/modules/memory/outcome_service.py +400 -0
- roampal/backend/modules/memory/promotion_service.py +473 -0
- roampal/backend/modules/memory/routing_service.py +444 -0
- roampal/backend/modules/memory/scoring_service.py +324 -0
- roampal/backend/modules/memory/search_service.py +646 -0
- roampal/backend/modules/memory/tests/__init__.py +1 -0
- roampal/backend/modules/memory/tests/conftest.py +12 -0
- roampal/backend/modules/memory/tests/unit/__init__.py +1 -0
- roampal/backend/modules/memory/tests/unit/conftest.py +7 -0
- roampal/backend/modules/memory/tests/unit/test_knowledge_graph_service.py +517 -0
- roampal/backend/modules/memory/tests/unit/test_memory_bank_service.py +504 -0
- roampal/backend/modules/memory/tests/unit/test_outcome_service.py +485 -0
- roampal/backend/modules/memory/tests/unit/test_scoring_service.py +255 -0
- roampal/backend/modules/memory/tests/unit/test_search_service.py +413 -0
- roampal/backend/modules/memory/tests/unit/test_unified_memory_system.py +418 -0
- roampal/backend/modules/memory/unified_memory_system.py +1277 -0
- roampal/cli.py +638 -0
- roampal/hooks/__init__.py +16 -0
- roampal/hooks/session_manager.py +587 -0
- roampal/hooks/stop_hook.py +176 -0
- roampal/hooks/user_prompt_submit_hook.py +103 -0
- roampal/mcp/__init__.py +7 -0
- roampal/mcp/server.py +611 -0
- roampal/server/__init__.py +7 -0
- roampal/server/main.py +744 -0
- roampal-0.1.4.dist-info/METADATA +179 -0
- roampal-0.1.4.dist-info/RECORD +44 -0
- roampal-0.1.4.dist-info/WHEEL +5 -0
- roampal-0.1.4.dist-info/entry_points.txt +2 -0
- roampal-0.1.4.dist-info/licenses/LICENSE +190 -0
- roampal-0.1.4.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,587 @@
|
|
|
1
|
+
"""
|
|
2
|
+
SessionManager - Tracks exchanges for enforced outcome scoring
|
|
3
|
+
|
|
4
|
+
Stores exchanges in JSONL files so the Stop hook can:
|
|
5
|
+
1. Load the previous exchange
|
|
6
|
+
2. Inject scoring prompt into UserPromptSubmit
|
|
7
|
+
3. Verify record_response() was called before allowing stop
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
import json
|
|
11
|
+
import logging
|
|
12
|
+
import os
|
|
13
|
+
from pathlib import Path
|
|
14
|
+
from datetime import datetime
|
|
15
|
+
from typing import Dict, List, Optional, Any
|
|
16
|
+
|
|
17
|
+
logger = logging.getLogger(__name__)


class SessionManager:
    """
    Manages conversation sessions for roampal-core.

    Session files store exchanges in JSONL format:
    {"role":"user","content":"...","timestamp":"..."}
    {"role":"assistant","content":"...","doc_id":"...","timestamp":"..."}

    The doc_id links to ChromaDB for outcome scoring.

    Completion Tracking (per-conversation flags in _completion_state.json):
    - Stop hook sets "completed"=True (via set_completed) when the assistant finishes
    - UserPromptSubmit checks/clears this flag (check_and_clear_completed) to decide
      if scoring is needed
    - This prevents scoring prompts during mid-work interruptions
    """

    def __init__(self, data_path: Path):
        """
        Initialize session manager.

        Args:
            data_path: Root data directory (e.g., %APPDATA%/Roampal/data).
                Session files live in ``<data_path>/mcp_sessions``, which is
                created if missing.
        """
        self.data_path = Path(data_path)
        self.sessions_dir = self.data_path / "mcp_sessions"
        self.sessions_dir.mkdir(parents=True, exist_ok=True)

        # In-memory cache of last exchange per session (for fast lookup).
        # Keyed by the caller's raw conversation_id, not the sanitized filename.
        self._last_exchange_cache: Dict[str, Dict[str, Any]] = {}

        # Completion state file (persists across hook invocations)
        self._state_file = self.sessions_dir / "_completion_state.json"

        logger.info(f"SessionManager initialized: {self.sessions_dir}")

    # ========== Low-level helpers ==========

    def _get_session_file(self, conversation_id: str) -> Path:
        """Return the JSONL path for a conversation, with unsafe filename chars replaced by '_'."""
        safe_id = "".join(c if c.isalnum() or c in "-_" else "_" for c in conversation_id)
        return self.sessions_dir / f"{safe_id}.jsonl"

    @staticmethod
    def _parse_line(line: str) -> Optional[Dict[str, Any]]:
        """Decode one JSONL line; return None for blank, corrupt, or non-object lines."""
        text = line.strip()
        if not text:
            return None
        try:
            record = json.loads(text)
        except json.JSONDecodeError:
            return None
        return record if isinstance(record, dict) else None

    def _exchange_ending_at(
        self,
        lines: List[str],
        index: int,
        assistant_record: Dict[str, Any],
    ) -> Optional[Dict[str, Any]]:
        """
        Pair the assistant record at ``lines[index]`` with the user record on the line before it.

        Returns None when there is no preceding line, or it is corrupt or not a
        user record. store_exchange always writes the user line immediately
        before the assistant line.
        """
        if index <= 0:
            return None
        user_record = self._parse_line(lines[index - 1])
        if user_record is None or user_record.get("role") != "user":
            return None
        return {
            "user": user_record.get("content", ""),
            "assistant": assistant_record.get("content", ""),
            "doc_id": assistant_record.get("doc_id"),
            "timestamp": assistant_record.get("timestamp"),
            "scored": False,
        }

    @staticmethod
    def _atomic_write(path: Path, text: str) -> None:
        """
        Replace ``path``'s contents with ``text`` without exposing a half-written file.

        The text is written to a temp file in the same directory and then moved over
        the target with os.replace. That way another process reading the file
        concurrently sees either the old or the new contents, never a truncated one.
        If the replace is refused with PermissionError (e.g. on Windows while another
        process holds the target open), this falls back to an in-place write rather
        than dropping the update.
        """
        import tempfile  # local import: only this helper needs it

        fd, tmp_name = tempfile.mkstemp(
            dir=str(path.parent), prefix=f".{path.name}.", suffix=".tmp"
        )
        try:
            with os.fdopen(fd, "w", encoding="utf-8") as tmp:
                tmp.write(text)
            try:
                os.replace(tmp_name, path)
            except PermissionError:
                with open(path, "w", encoding="utf-8") as f:
                    f.write(text)
        finally:
            # After a successful replace the temp name no longer exists.
            try:
                os.unlink(tmp_name)
            except OSError:
                pass

    # ========== Exchange storage ==========

    async def store_exchange(
        self,
        conversation_id: str,
        user_message: str,
        assistant_response: str,
        doc_id: str
    ) -> Dict[str, Any]:
        """
        Store a completed exchange.

        Called by Stop hook after Claude responds. Appends one user record and one
        assistant record (with scored=False) to the session file and refreshes
        the in-memory cache.

        Args:
            conversation_id: Session identifier
            user_message: What the user said
            assistant_response: What Claude said
            doc_id: ChromaDB document ID for outcome scoring

        Returns:
            Exchange record that was stored (the assistant record)
        """
        session_file = self._get_session_file(conversation_id)
        timestamp = datetime.now().isoformat()

        user_record = {
            "role": "user",
            "content": user_message,
            "timestamp": timestamp
        }

        assistant_record = {
            "role": "assistant",
            "content": assistant_response,
            "doc_id": doc_id,
            "timestamp": timestamp,
            "scored": False  # Will be set True when the exchange is scored (mark_scored)
        }

        # Append both records; the user line must directly precede the assistant
        # line because _exchange_ending_at relies on that adjacency.
        with open(session_file, "a", encoding="utf-8") as f:
            f.write(json.dumps(user_record) + "\n")
            f.write(json.dumps(assistant_record) + "\n")

        self._last_exchange_cache[conversation_id] = {
            "user": user_message,
            "assistant": assistant_response,
            "doc_id": doc_id,
            "timestamp": timestamp,
            "scored": False
        }

        logger.info(f"Stored exchange for {conversation_id}, doc_id={doc_id}")

        return assistant_record

    async def get_previous_exchange(
        self,
        conversation_id: str
    ) -> Optional[Dict[str, Any]]:
        """
        Get the previous (unscored) exchange for scoring.

        Called by UserPromptSubmit hook to inject scoring prompt. Only the LAST
        assistant record in the session is considered. If it is already scored,
        or has no user record directly before it, there is nothing to score.

        Args:
            conversation_id: Session identifier

        Returns:
            Previous exchange or None if no unscored exchange exists
        """
        cached = self._last_exchange_cache.get(conversation_id)
        if cached is not None and not cached.get("scored", False):
            return cached

        session_file = self._get_session_file(conversation_id)
        if not session_file.exists():
            return None

        try:
            with open(session_file, "r", encoding="utf-8") as f:
                lines = f.readlines()
        except Exception as e:
            logger.error(f"Error reading session file: {e}")
            return None

        # Walk backwards by index so duplicate lines cannot be confused with each other.
        for index in range(len(lines) - 1, -1, -1):
            record = self._parse_line(lines[index])
            if record is None or record.get("role") != "assistant":
                continue
            if record.get("scored", False):
                return None
            return self._exchange_ending_at(lines, index, record)

        return None

    async def get_most_recent_unscored_exchange(self) -> Optional[Dict[str, Any]]:
        """
        Get the most recent unscored exchange across ALL sessions.

        This handles the MCP/hook session ID mismatch by finding
        the most recent exchange regardless of which session it's in.
        In each file, the newest unscored assistant record is taken. Across
        files, the one with the greatest timestamp string wins (ISO-8601 strings
        from this class compare chronologically).

        Returns:
            Most recent unscored exchange with conversation_id (the sanitized
            session-file stem), or None
        """
        most_recent: Optional[Dict[str, Any]] = None
        most_recent_time: Optional[str] = None

        for session_file in self.sessions_dir.glob("*.jsonl"):
            conversation_id = session_file.stem

            try:
                with open(session_file, "r", encoding="utf-8") as f:
                    lines = f.readlines()
            except Exception as e:
                logger.error(f"Error reading session file {session_file}: {e}")
                continue

            for index in range(len(lines) - 1, -1, -1):
                record = self._parse_line(lines[index])
                if (
                    record is None
                    or record.get("role") != "assistant"
                    or record.get("scored", False)
                ):
                    continue

                # Coerce to str so a missing/null timestamp cannot break the comparison.
                timestamp = str(record.get("timestamp") or "")
                if most_recent_time is None or timestamp > most_recent_time:
                    exchange = self._exchange_ending_at(lines, index, record)
                    if exchange is not None:
                        exchange["timestamp"] = timestamp
                        exchange["conversation_id"] = conversation_id
                        most_recent = exchange
                        most_recent_time = timestamp
                break  # Only check the last unscored assistant message per file

        if most_recent:
            logger.info(f"Found most recent unscored exchange in session {most_recent.get('conversation_id')}")

        return most_recent

    async def mark_scored(
        self,
        conversation_id: str,
        doc_id: str,
        outcome: str
    ) -> bool:
        """
        Mark an exchange as scored.

        Called when record_response() MCP tool is invoked. Updates the cache (if it
        holds this doc_id) and rewrites the newest matching assistant record in the
        session file with scored=True and the outcome. The rewrite is atomic.

        Args:
            conversation_id: Session identifier
            doc_id: Document ID that was scored
            outcome: The outcome that was recorded

        Returns:
            True if the session file was updated
        """
        cached = self._last_exchange_cache.get(conversation_id)
        if cached is not None and cached.get("doc_id") == doc_id:
            cached["scored"] = True
            cached["outcome"] = outcome

        session_file = self._get_session_file(conversation_id)
        if not session_file.exists():
            return False

        try:
            with open(session_file, "r", encoding="utf-8") as f:
                lines = f.readlines()

            for index in range(len(lines) - 1, -1, -1):
                record = self._parse_line(lines[index])
                if record is None:
                    continue
                if record.get("role") == "assistant" and record.get("doc_id") == doc_id:
                    record["scored"] = True
                    record["outcome"] = outcome
                    lines[index] = json.dumps(record) + "\n"
                    self._atomic_write(session_file, "".join(lines))
                    logger.info(f"Marked {doc_id} as scored with outcome={outcome}")
                    return True

        except Exception as e:
            logger.error(f"Error marking scored: {e}")

        return False

    async def get_session_history(
        self,
        conversation_id: str,
        limit: int = 20
    ) -> List[Dict[str, Any]]:
        """
        Get recent conversation history.

        Args:
            conversation_id: Session identifier
            limit: Max messages to return; values <= 0 return an empty list

        Returns:
            Up to ``limit`` of the newest readable message records, oldest first.
            Blank or corrupt lines are skipped and do not count toward the limit.
        """
        # Guard explicitly: lines[-0:] would otherwise return the whole file.
        if limit <= 0:
            return []

        session_file = self._get_session_file(conversation_id)
        if not session_file.exists():
            return []

        try:
            with open(session_file, "r", encoding="utf-8") as f:
                lines = f.readlines()
        except Exception as e:
            logger.error(f"Error reading history: {e}")
            return []

        newest_first: List[Dict[str, Any]] = []
        for line in reversed(lines):
            record = self._parse_line(line)
            if record is None:
                continue
            newest_first.append(record)
            if len(newest_first) >= limit:
                break

        newest_first.reverse()
        return newest_first

    def build_scoring_prompt(
        self,
        previous_exchange: Dict[str, Any],
        current_user_message: str,
        surfaced_memories: Optional[List[Dict[str, Any]]] = None
    ) -> str:
        """
        Build the scoring prompt to inject.

        This is what Claude sees at the start of each turn,
        prompting them to score the previous exchange.

        Args:
            previous_exchange: The exchange to score ("user"/"assistant" keys)
            current_user_message: What the user just said (for context)
            surfaced_memories: Optional list of memories that were injected (for selective scoring).
                Each may carry its text under "content", "text", or "metadata"["content"].

        Returns:
            Formatted prompt to prepend
        """
        user_asked = previous_exchange.get("user", "")
        assistant_said = previous_exchange.get("assistant", "")

        memories_section = ""
        if surfaced_memories:
            memory_lines = []
            for mem in surfaced_memories:
                doc_id = mem.get("id", mem.get("doc_id", "unknown"))
                metadata = mem.get("metadata")
                if not isinstance(metadata, dict):
                    metadata = {}
                # Tolerate None / non-str content instead of crashing on len().
                content = mem.get("content") or mem.get("text") or metadata.get("content", "") or ""
                content = str(content)
                # Truncate long content for display
                if len(content) > 100:
                    content = content[:100] + "..."
                memory_lines.append(f"• [{doc_id}] \"{content}\"")

            if memory_lines:
                memories_section = "\nMemories surfaced:\n" + "\n".join(memory_lines) + "\n"

        # Build selective scoring instruction
        scoring_instruction = "Call score_response(outcome=\"worked|failed|partial|unknown\") FIRST, then respond."
        if surfaced_memories:
            scoring_instruction = """Call score_response(outcome="...", related=["doc_ids that were relevant"]) FIRST.
- related is optional: omit to score all, or list only the memories you actually used"""

        return f"""<roampal-score-required>
Score the previous exchange before responding.

Previous:
- User asked: "{user_asked}"
- You answered: "{assistant_said}"
{memories_section}
Current user message: "{current_user_message}"

Based on the user's current message, evaluate if your previous answer helped:
- "worked" = user satisfied, says thanks, moves on to new topic
- "failed" = user corrects you, says no/wrong, repeats question
- "partial" = lukewarm response, "kind of", "I guess"
- "unknown" = no clear signal

{scoring_instruction}

Separately, record_response(key_takeaway="...") is OPTIONAL - only for significant learnings.
</roampal-score-required>

"""

    # ========== Completion State Tracking ==========
    # These methods track whether the assistant has completed a response,
    # so we only inject scoring prompts when the user responds to completed work.
    # State is a JSON object: {conversation_id: {flag: value, ...}, ...}.

    def _load_completion_state(self) -> Dict[str, Any]:
        """Load completion state from file; return {} if missing, unreadable, or not a JSON object."""
        if not self._state_file.exists():
            return {}
        try:
            with open(self._state_file, "r", encoding="utf-8") as f:
                state = json.load(f)
        except Exception as e:
            logger.warning(f"Error loading completion state: {e}")
            return {}
        if not isinstance(state, dict):
            logger.warning(f"Error loading completion state: expected a JSON object, got {type(state).__name__}")
            return {}
        return state

    def _save_completion_state(self, state: Dict[str, Any]) -> None:
        """Save completion state atomically so concurrent readers never see a torn file."""
        try:
            self._atomic_write(self._state_file, json.dumps(state))
        except Exception as e:
            logger.error(f"Error saving completion state: {e}")

    @staticmethod
    def _entry(state: Dict[str, Any], conversation_id: str) -> Dict[str, Any]:
        """Return the per-conversation flag dict, creating (or replacing a non-dict) entry."""
        entry = state.get(conversation_id)
        if not isinstance(entry, dict):
            entry = {}
            state[conversation_id] = entry
        return entry

    def set_completed(self, conversation_id: str) -> None:
        """
        Mark that the assistant has completed a response.

        Called by Stop hook when assistant finishes responding.
        This signals that the next user message should trigger scoring.
        Other flags for the conversation (e.g. first_message_seen) are preserved.

        Args:
            conversation_id: Session identifier
        """
        state = self._load_completion_state()
        entry = self._entry(state, conversation_id)
        entry["completed"] = True
        entry["timestamp"] = datetime.now().isoformat()
        entry["scoring_required"] = False  # Will be set by set_scoring_required next turn
        self._save_completion_state(state)
        logger.info(f"Marked conversation {conversation_id} as completed")

    def set_scoring_required(self, conversation_id: str, required: bool) -> None:
        """
        Track that scoring was required this turn.

        Called by get-context endpoint when it injects a scoring prompt.
        The Stop hook uses this to decide whether to block. When ``required`` is
        True, scored_this_turn is reset to False for the new turn.

        Args:
            conversation_id: Session identifier
            required: Whether scoring prompt was injected
        """
        state = self._load_completion_state()
        entry = self._entry(state, conversation_id)
        entry["scoring_required"] = required
        if required:
            entry["scored_this_turn"] = False
        self._save_completion_state(state)
        logger.info(f"Set scoring_required={required} for {conversation_id}")

    def was_scoring_required(self, conversation_id: str) -> bool:
        """
        Check if scoring was required this turn.

        Called by Stop hook to decide whether to block.

        Args:
            conversation_id: Session identifier

        Returns:
            True if scoring prompt was injected this turn
        """
        state = self._load_completion_state()
        return self._entry(state, conversation_id).get("scoring_required", False)

    def set_scored_this_turn(self, conversation_id: str, scored: bool = True) -> None:
        """
        Track that scoring was completed this turn.

        Called by record_response endpoint when LLM calls record_response MCP tool.
        The Stop hook uses this to decide whether to block.

        Args:
            conversation_id: Session identifier
            scored: Whether scoring was completed (default True)
        """
        state = self._load_completion_state()
        self._entry(state, conversation_id)["scored_this_turn"] = scored
        self._save_completion_state(state)
        logger.info(f"Set scored_this_turn={scored} for {conversation_id}")

    def was_scored_this_turn(self, conversation_id: str) -> bool:
        """
        Check if scoring was completed this turn.

        Called by Stop hook to decide whether to block.

        Args:
            conversation_id: Session identifier

        Returns:
            True if record_response was called this turn
        """
        state = self._load_completion_state()
        return self._entry(state, conversation_id).get("scored_this_turn", False)

    def check_and_clear_completed(self, conversation_id: str) -> bool:
        """
        Check if assistant completed and clear the flag.

        Called by UserPromptSubmit hook to decide if scoring is needed.
        Only the "completed" flag is cleared; other flags are preserved.

        Args:
            conversation_id: Session identifier

        Returns:
            True if assistant had completed (scoring should happen)
            False if assistant was mid-work (no scoring needed)
        """
        state = self._load_completion_state()
        entry = self._entry(state, conversation_id)

        if entry.get("completed"):
            entry["completed"] = False
            self._save_completion_state(state)
            logger.info(f"Conversation {conversation_id} was completed - scoring needed")
            return True

        logger.info(f"Conversation {conversation_id} not completed - skip scoring")
        return False

    def is_completed(self, conversation_id: str) -> bool:
        """
        Check if assistant completed without clearing flag.

        Useful for checking state without side effects.

        Args:
            conversation_id: Session identifier

        Returns:
            True if assistant had completed
        """
        state = self._load_completion_state()
        return self._entry(state, conversation_id).get("completed", False)

    # ========== Cold Start / First Message Tracking ==========
    # Track which sessions have had their first message, so we can dump
    # the full user profile on cold start.

    def is_first_message(self, conversation_id: str) -> bool:
        """
        Check if this is the first message in a session.

        Called by get-context to decide if cold start user profile dump is needed.

        Args:
            conversation_id: Session identifier

        Returns:
            True if this is the first message (cold start)
        """
        state = self._load_completion_state()
        return not self._entry(state, conversation_id).get("first_message_seen", False)

    def mark_first_message_seen(self, conversation_id: str) -> None:
        """
        Mark that the first message has been seen for this session.

        Called by get-context after injecting cold start profile.

        Args:
            conversation_id: Session identifier
        """
        state = self._load_completion_state()
        entry = self._entry(state, conversation_id)
        entry["first_message_seen"] = True
        entry["first_message_timestamp"] = datetime.now().isoformat()
        self._save_completion_state(state)
        logger.info(f"Marked first message seen for {conversation_id}")
|