dao-ai 0.1.5__py3-none-any.whl → 0.1.20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/apps/__init__.py +24 -0
- dao_ai/apps/handlers.py +105 -0
- dao_ai/apps/model_serving.py +29 -0
- dao_ai/apps/resources.py +1122 -0
- dao_ai/apps/server.py +39 -0
- dao_ai/cli.py +446 -16
- dao_ai/config.py +1034 -103
- dao_ai/evaluation.py +543 -0
- dao_ai/genie/__init__.py +55 -7
- dao_ai/genie/cache/__init__.py +34 -7
- dao_ai/genie/cache/base.py +143 -2
- dao_ai/genie/cache/context_aware/__init__.py +31 -0
- dao_ai/genie/cache/context_aware/base.py +1151 -0
- dao_ai/genie/cache/context_aware/in_memory.py +609 -0
- dao_ai/genie/cache/context_aware/persistent.py +802 -0
- dao_ai/genie/cache/context_aware/postgres.py +1166 -0
- dao_ai/genie/cache/core.py +1 -1
- dao_ai/genie/cache/lru.py +257 -75
- dao_ai/genie/cache/optimization.py +890 -0
- dao_ai/genie/core.py +235 -11
- dao_ai/memory/postgres.py +175 -39
- dao_ai/middleware/__init__.py +5 -0
- dao_ai/middleware/tool_selector.py +129 -0
- dao_ai/models.py +327 -370
- dao_ai/nodes.py +4 -4
- dao_ai/orchestration/core.py +33 -9
- dao_ai/orchestration/supervisor.py +23 -8
- dao_ai/orchestration/swarm.py +6 -1
- dao_ai/{prompts.py → prompts/__init__.py} +12 -61
- dao_ai/prompts/instructed_retriever_decomposition.yaml +58 -0
- dao_ai/prompts/instruction_reranker.yaml +14 -0
- dao_ai/prompts/router.yaml +37 -0
- dao_ai/prompts/verifier.yaml +46 -0
- dao_ai/providers/base.py +28 -2
- dao_ai/providers/databricks.py +352 -33
- dao_ai/state.py +1 -0
- dao_ai/tools/__init__.py +5 -3
- dao_ai/tools/genie.py +103 -26
- dao_ai/tools/instructed_retriever.py +366 -0
- dao_ai/tools/instruction_reranker.py +202 -0
- dao_ai/tools/mcp.py +539 -97
- dao_ai/tools/router.py +89 -0
- dao_ai/tools/slack.py +13 -2
- dao_ai/tools/sql.py +7 -3
- dao_ai/tools/unity_catalog.py +32 -10
- dao_ai/tools/vector_search.py +493 -160
- dao_ai/tools/verifier.py +159 -0
- dao_ai/utils.py +182 -2
- dao_ai/vector_search.py +9 -1
- {dao_ai-0.1.5.dist-info → dao_ai-0.1.20.dist-info}/METADATA +10 -8
- dao_ai-0.1.20.dist-info/RECORD +89 -0
- dao_ai/agent_as_code.py +0 -22
- dao_ai/genie/cache/semantic.py +0 -970
- dao_ai-0.1.5.dist-info/RECORD +0 -70
- {dao_ai-0.1.5.dist-info → dao_ai-0.1.20.dist-info}/WHEEL +0 -0
- {dao_ai-0.1.5.dist-info → dao_ai-0.1.20.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.1.5.dist-info → dao_ai-0.1.20.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,802 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Abstract base class for persistent (database-backed) context-aware Genie cache implementations.
|
|
3
|
+
|
|
4
|
+
This module provides the foundational abstract base class for database-backed
|
|
5
|
+
cache implementations. It adds:
|
|
6
|
+
- Connection pooling management
|
|
7
|
+
- Transaction handling with retry logic
|
|
8
|
+
- Prompt history storage and retrieval
|
|
9
|
+
- Database error handling with exponential backoff
|
|
10
|
+
|
|
11
|
+
Subclasses must implement database-specific methods:
|
|
12
|
+
- _create_table_if_not_exists(): Create database schema
|
|
13
|
+
- _get_pool(): Get database connection pool
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
from __future__ import annotations
|
|
17
|
+
|
|
18
|
+
from abc import abstractmethod
|
|
19
|
+
from typing import Any, Callable, TypeVar
|
|
20
|
+
|
|
21
|
+
from loguru import logger
|
|
22
|
+
|
|
23
|
+
from dao_ai.config import DatabaseModel
|
|
24
|
+
from dao_ai.genie.cache.context_aware.base import (
|
|
25
|
+
ContextAwareGenieService,
|
|
26
|
+
get_conversation_history,
|
|
27
|
+
)
|
|
28
|
+
|
|
29
|
+
# Type variable for return types
|
|
30
|
+
T = TypeVar("T")
|
|
31
|
+
|
|
32
|
+
# Type alias for database row (dict due to row_factory=dict_row)
|
|
33
|
+
DbRow = dict[str, Any]
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class PersistentContextAwareGenieCacheService(ContextAwareGenieService):
|
|
37
|
+
"""
|
|
38
|
+
Abstract base class for database-backed context-aware Genie cache implementations.
|
|
39
|
+
|
|
40
|
+
This class extends ContextAwareGenieService with database-specific functionality:
|
|
41
|
+
- Connection pool management
|
|
42
|
+
- Prompt history tracking for conversation context
|
|
43
|
+
- Retry logic for transient database failures
|
|
44
|
+
- Schema creation and management
|
|
45
|
+
|
|
46
|
+
Subclasses must implement:
|
|
47
|
+
- _get_pool(): Return the database connection pool
|
|
48
|
+
- _create_table_if_not_exists(): Create required database tables
|
|
49
|
+
- Database-specific _find_similar() and _store_entry() implementations
|
|
50
|
+
|
|
51
|
+
Thread Safety:
|
|
52
|
+
Uses connection pooling for thread-safe database access.
|
|
53
|
+
All database operations use connection context managers.
|
|
54
|
+
"""
|
|
55
|
+
|
|
56
|
+
# Additional attributes for persistent implementations
|
|
57
|
+
_pool: Any # ConnectionPool
|
|
58
|
+
|
|
59
|
+
@property
@abstractmethod
def database(self) -> DatabaseModel:
    """The database used for storing cache entries.

    Subclasses return the configured DatabaseModel backing this cache.
    """
    pass
|
|
64
|
+
|
|
65
|
+
@property
@abstractmethod
def table_name(self) -> str:
    """Name of the cache table.

    Interpolated directly into SQL by concrete implementations, so it must
    be a trusted identifier (not user input).
    """
    pass
|
|
70
|
+
|
|
71
|
+
@property
@abstractmethod
def prompt_history_table(self) -> str:
    """Name of the prompt history table.

    Interpolated directly into SQL by the prompt-history helpers, so it must
    be a trusted identifier (not user input).
    """
    pass
|
|
76
|
+
|
|
77
|
+
@property
@abstractmethod
def context_window_size(self) -> int:
    """Number of previous prompts to include in context.

    A value of 0 disables conversation-context embedding (see _embed_question).
    """
    pass
|
|
82
|
+
|
|
83
|
+
@property
@abstractmethod
def max_context_tokens(self) -> int:
    """Maximum tokens for context string.

    Used with a rough 4-characters-per-token estimate to truncate the
    conversation context before embedding.
    """
    pass
|
|
88
|
+
|
|
89
|
+
@property
@abstractmethod
def context_similarity_threshold(self) -> float:
    """Minimum similarity for context matching."""
    pass
|
|
94
|
+
|
|
95
|
+
@property
@abstractmethod
def question_weight(self) -> float:
    """Weight for question similarity in the combined similarity score."""
    pass
|
|
100
|
+
|
|
101
|
+
@property
@abstractmethod
def context_weight(self) -> float:
    """Weight for context similarity in the combined similarity score."""
    pass
|
|
106
|
+
|
|
107
|
+
@property
@abstractmethod
def max_prompt_history_length(self) -> int:
    """Maximum number of prompts to keep per conversation.

    Enforced after every insert by _enforce_prompt_history_limit.
    """
    pass
|
|
112
|
+
|
|
113
|
+
@property
@abstractmethod
def time_to_live_seconds(self) -> int | None:
    """TTL in seconds (None or negative = never expires)."""
    pass
|
|
118
|
+
|
|
119
|
+
@abstractmethod
def _create_table_if_not_exists(self) -> None:
    """
    Create the cache and prompt history tables if they don't exist.

    This method should handle:
    - Creating the cache table with vector columns
    - Creating indexes for efficient similarity search
    - Creating the prompt history table
    - Handling schema migrations if needed

    Implementations may use _index_exists() to avoid re-creating indexes.
    """
    pass
|
|
131
|
+
|
|
132
|
+
def _execute_with_retry(
    self,
    operation: Callable[[], T],
    max_attempts: int = 3,
    base_delay: float = 1.0,
    max_delay: float = 10.0,
) -> T:
    """
    Run a database operation, retrying transient failures with exponential backoff.

    Args:
        operation: Zero-argument callable performing the database work
        max_attempts: Total number of attempts before giving up
        base_delay: Delay before the first retry (seconds)
        max_delay: Upper bound on the backoff delay (seconds)

    Returns:
        Whatever ``operation`` returns on its first successful attempt

    Raises:
        The original exception when it is not transient, or once all
        attempts are exhausted
    """
    import time

    # Substrings that mark an error as transient and therefore retryable.
    transient_markers = (
        "connection",
        "timeout",
        "temporarily unavailable",
        "too many connections",
        "connection refused",
        "operational error",
    )

    failure: Exception | None = None
    wait = base_delay

    for attempt in range(max_attempts):
        try:
            return operation()
        except Exception as exc:
            failure = exc
            lowered = str(exc).lower()
            transient = any(marker in lowered for marker in transient_markers)

            # Give up immediately on non-transient errors or the final attempt.
            if not transient or attempt == max_attempts - 1:
                raise

            logger.warning(
                f"Database operation failed (attempt {attempt + 1}/{max_attempts}), retrying",
                layer=self.name,
                error=str(exc),
                delay=wait,
            )

            time.sleep(wait)
            wait = min(wait * 2, max_delay)

    # Defensive: the loop above always returns or raises, so this is unreachable.
    if failure:
        raise failure
    raise RuntimeError("Unexpected state in retry logic")
|
|
194
|
+
|
|
195
|
+
def _index_exists(self, cur: Any, index_name: str) -> bool:
    """
    Return True when an index with the given name already exists.

    Args:
        cur: Open database cursor used to run the lookup
        index_name: Name of the index to look up in pg_indexes

    Returns:
        True if the index exists, False otherwise
    """
    # pg_indexes holds one row per index; any match at all is enough.
    cur.execute("SELECT 1 FROM pg_indexes WHERE indexname = %s", (index_name,))
    row = cur.fetchone()
    return row is not None
|
|
211
|
+
|
|
212
|
+
def _store_user_prompt(
    self,
    prompt: str,
    conversation_id: str,
    cache_hit: bool = False,
) -> bool:
    """
    Store user prompt in local conversation history.

    This is called after embeddings are generated to ensure the current prompt
    is not included in its own context.

    Prompt history is non-critical; failures are logged but don't crash the request.

    Args:
        prompt: The user's question/prompt
        conversation_id: The conversation ID
        cache_hit: Whether this prompt resulted in a cache hit

    Returns:
        True if prompt was stored successfully, False otherwise
    """
    # Table identifier comes from the trusted prompt_history_table property;
    # all values are bound as parameters.
    prompt_table_name = self.prompt_history_table
    insert_sql: str = f"""
        INSERT INTO {prompt_table_name}
        (genie_space_id, conversation_id, prompt, cache_hit)
        VALUES (%s, %s, %s, %s)
    """

    logger.debug(
        "Inserting prompt into history",
        layer=self.name,
        table=prompt_table_name,
        space_id=self.space_id,
        conversation_id=conversation_id,
        prompt_preview=prompt[:80] if len(prompt) > 80 else prompt,
        prompt_length=len(prompt),
        cache_hit=cache_hit,
    )

    try:
        # NOTE(review): assumes the pool's connection context manager commits
        # on successful exit — confirm against the pool implementation.
        with self._pool.connection() as conn:
            with conn.cursor() as cur:
                cur.execute(
                    insert_sql, (self.space_id, conversation_id, prompt, cache_hit)
                )

        logger.info(
            "Stored user prompt in history",
            layer=self.name,
            table=prompt_table_name,
            conversation_id=conversation_id,
            prompt_preview=prompt[:50],
            cache_hit=cache_hit,
        )

        # Enforce max_prompt_history_length per conversation
        self._enforce_prompt_history_limit(conversation_id)

        return True
    except Exception as e:
        # Prompt history is best-effort; swallow and report failure via False.
        logger.warning(
            f"Failed to store prompt in history (non-critical): {e}",
            layer=self.name,
            table=prompt_table_name,
            conversation_id=conversation_id,
        )
        return False
|
|
280
|
+
|
|
281
|
+
def _enforce_prompt_history_limit(self, conversation_id: str) -> int:
    """
    Delete oldest prompts if conversation exceeds max_prompt_history_length.

    This is called after inserting a new prompt to keep history bounded.
    Uses a single DELETE with subquery for efficiency.

    Args:
        conversation_id: The conversation ID to enforce limit for

    Returns:
        Number of prompts deleted (0 if within limit)
    """
    max_length = self.max_prompt_history_length
    prompt_table_name = self.prompt_history_table

    # The subquery selects the created_at of the max_length-th newest row
    # (OFFSET max_length - 1 on a DESC ordering); deleting everything
    # strictly older keeps the newest max_length rows. If fewer rows exist,
    # the subquery yields no row and the DELETE matches nothing.
    # NOTE(review): rows sharing that exact created_at are all retained —
    # confirm timestamp resolution makes ties negligible.
    delete_sql: str = f"""
        DELETE FROM {prompt_table_name}
        WHERE genie_space_id = %s
        AND conversation_id = %s
        AND created_at < (
            SELECT created_at FROM {prompt_table_name}
            WHERE genie_space_id = %s
            AND conversation_id = %s
            ORDER BY created_at DESC
            LIMIT 1 OFFSET %s
        )
    """

    try:
        with self._pool.connection() as conn:
            with conn.cursor() as cur:
                cur.execute(
                    delete_sql,
                    (
                        self.space_id,
                        conversation_id,
                        self.space_id,
                        conversation_id,
                        max_length - 1,
                    ),
                )
                # rowcount may be non-int (e.g. a test mock); treat as 0.
                deleted = cur.rowcount if isinstance(cur.rowcount, int) else 0

        if deleted > 0:
            logger.debug(
                "Enforced prompt history limit",
                layer=self.name,
                table=prompt_table_name,
                conversation_id=conversation_id,
                max_length=max_length,
                deleted=deleted,
            )
        return deleted
    except Exception as e:
        # History trimming is best-effort; report 0 rather than raising.
        logger.debug(
            f"Failed to enforce prompt history limit (non-critical): {e}",
            layer=self.name,
            conversation_id=conversation_id,
        )
        return 0
|
|
343
|
+
|
|
344
|
+
def _get_local_prompt_history(
    self,
    conversation_id: str,
    max_prompts: int | None = None,
) -> list[str]:
    """
    Retrieve recent user prompts from local storage.

    Uses SQL LIMIT for efficiency - only retrieves exactly the number
    of prompts needed for the context window, not all prompts.

    Args:
        conversation_id: The conversation ID to retrieve prompts for
        max_prompts: Maximum number of prompts to retrieve
            (defaults to context_window_size)

    Returns:
        List of prompt strings in chronological order (oldest to newest)
    """
    if max_prompts is None:
        max_prompts = self.context_window_size

    prompt_table_name = self.prompt_history_table
    # DESC + LIMIT fetches the newest rows; reversed below for chronology.
    query_sql: str = f"""
        SELECT prompt
        FROM {prompt_table_name}
        WHERE genie_space_id = %s
        AND conversation_id = %s
        ORDER BY created_at DESC
        LIMIT %s
    """

    logger.debug(
        "Querying prompt history",
        layer=self.name,
        table=prompt_table_name,
        space_id=self.space_id,
        conversation_id=conversation_id,
        max_prompts=max_prompts,
    )

    with self._pool.connection() as conn:
        with conn.cursor() as cur:
            # LIMIT ensures we only fetch exactly what's needed
            cur.execute(query_sql, (self.space_id, conversation_id, max_prompts))
            rows: list[DbRow] = cur.fetchall()
            # Reverse to get chronological order (oldest to newest)
            prompts = [row["prompt"] for row in reversed(rows)]

            logger.info(
                "Retrieved prompt history from database",
                layer=self.name,
                table=prompt_table_name,
                conversation_id=conversation_id,
                requested=max_prompts,
                returned=len(prompts),
                prompts_preview=[
                    p[:40] + "..." if len(p) > 40 else p for p in prompts
                ],
            )

            return prompts
|
|
405
|
+
|
|
406
|
+
def _update_prompt_cache_hit(
    self,
    conversation_id: str,
    prompt: str,
    cache_hit: bool,
    cache_entry_id: int | None = None,
) -> bool:
    """
    Update the cache_hit flag and cache_entry_id for a previously stored prompt.

    This is called after determining whether the prompt resulted in a cache hit.
    Updates the most recent prompt matching the given text.

    Args:
        conversation_id: The conversation ID
        prompt: The prompt text to update
        cache_hit: The cache hit status to set
        cache_entry_id: The ID of the cache entry that served this hit (for traceability)

    Returns:
        True if update was successful, False otherwise
    """
    prompt_table_name = self.prompt_history_table
    # The inner MAX(created_at) subquery narrows the UPDATE to the newest
    # row with identical (space, conversation, prompt) text.
    update_sql: str = f"""
        UPDATE {prompt_table_name}
        SET cache_hit = %s, cache_entry_id = %s
        WHERE genie_space_id = %s
        AND conversation_id = %s
        AND prompt = %s
        AND created_at = (
            SELECT MAX(created_at)
            FROM {prompt_table_name}
            WHERE genie_space_id = %s
            AND conversation_id = %s
            AND prompt = %s
        )
    """

    logger.debug(
        "Updating prompt cache_hit flag and cache_entry_id",
        layer=self.name,
        table=prompt_table_name,
        space_id=self.space_id,
        conversation_id=conversation_id,
        prompt_preview=prompt[:50],
        new_cache_hit=cache_hit,
        cache_entry_id=cache_entry_id,
    )

    try:
        with self._pool.connection() as conn:
            with conn.cursor() as cur:
                cur.execute(
                    update_sql,
                    (
                        cache_hit,
                        cache_entry_id,
                        self.space_id,
                        conversation_id,
                        prompt,
                        self.space_id,
                        conversation_id,
                        prompt,
                    ),
                )
                # Handle rowcount safely (may be Mock in tests or None)
                updated_rows = getattr(cur, "rowcount", 0)
                if not isinstance(updated_rows, int):
                    updated_rows = 0

        if updated_rows > 0:
            logger.info(
                "Updated prompt cache_hit flag and cache_entry_id in history",
                layer=self.name,
                table=prompt_table_name,
                conversation_id=conversation_id,
                prompt_preview=prompt[:50],
                cache_hit=cache_hit,
                cache_entry_id=cache_entry_id,
                rows_updated=updated_rows,
            )
            return True
        else:
            # Zero rows can legitimately happen when the prompt was never
            # stored (e.g. the earlier best-effort insert failed).
            logger.debug(
                "No prompt found to update cache_hit flag (may be expected)",
                layer=self.name,
                table=prompt_table_name,
                conversation_id=conversation_id,
                prompt_preview=prompt[:50],
            )
            return False
    except Exception as e:
        # Traceability metadata is non-critical; never raise from here.
        logger.warning(
            f"Failed to update prompt cache_hit flag and cache_entry_id (non-critical): {e}",
            layer=self.name,
            table=prompt_table_name,
            conversation_id=conversation_id,
            cache_entry_id=cache_entry_id,
        )
        return False
|
|
506
|
+
|
|
507
|
+
def _embed_question(
    self, question: str, conversation_id: str | None = None
) -> tuple[list[float], list[float], str]:
    """
    Generate dual embeddings using local prompt history for context.

    This method retrieves conversation history from local storage first,
    falling back to Genie API if local history is empty.

    Args:
        question: The question to embed
        conversation_id: Optional conversation ID for retrieving context

    Returns:
        Tuple of (question_embedding, context_embedding, conversation_context_string)
    """
    conversation_context = ""

    # If conversation context is enabled and available
    if conversation_id is not None and self.context_window_size > 0:
        try:
            # Try local prompt history first (FASTER, includes cache hits)
            recent_prompts = self._get_local_prompt_history(
                conversation_id=conversation_id,
                max_prompts=self.context_window_size,
            )

            logger.trace(
                "Retrieved local prompt history",
                layer=self.name,
                prompts_count=len(recent_prompts),
                conversation_id=conversation_id,
            )

            # Fallback to Genie API if local history empty and API available
            if not recent_prompts and self.workspace_client is not None:
                logger.debug(
                    "Local prompt history empty, falling back to Genie API",
                    layer=self.name,
                    conversation_id=conversation_id,
                )

                # Fetch extra messages (2x window) so that, after filtering
                # to ones with content, enough remain to fill the window.
                conversation_messages = get_conversation_history(
                    workspace_client=self.workspace_client,
                    space_id=self.space_id,
                    conversation_id=conversation_id,
                    max_messages=self.context_window_size * 2,
                )

                if conversation_messages:
                    recent_messages = (
                        conversation_messages[-self.context_window_size :]
                        if len(conversation_messages) > self.context_window_size
                        else conversation_messages
                    )
                    recent_prompts = [
                        msg.content for msg in recent_messages if msg.content
                    ]

            # Build context string from prompts
            if recent_prompts:
                context_parts: list[str] = []
                for prompt in recent_prompts:
                    content: str = prompt
                    # Cap each prior prompt at 500 chars before joining.
                    if len(content) > 500:
                        content = content[:500] + "..."
                    context_parts.append(f"Previous: {content}")

                conversation_context = "\n".join(context_parts)

                # Truncate if too long (rough estimate: ~4 chars per token)
                estimated_tokens = len(conversation_context) / 4
                if estimated_tokens > self.max_context_tokens:
                    target_chars = self.max_context_tokens * 4
                    conversation_context = (
                        conversation_context[:target_chars] + "..."
                    )

                logger.trace(
                    "Using conversation context",
                    layer=self.name,
                    prompts_count=len(recent_prompts),
                    window_size=self.context_window_size,
                    source="local_db",
                )
        except Exception as e:
            # Context building is best-effort: on any failure embed the
            # bare question rather than failing the lookup.
            logger.warning(
                "Failed to build conversation context, using question only",
                layer=self.name,
                error=str(e),
            )
            conversation_context = ""

    return self._generate_dual_embeddings(question, conversation_context)
|
|
601
|
+
|
|
602
|
+
def get_prompt_history(
    self,
    conversation_id: str,
    max_prompts: int | None = None,
    include_cache_hits: bool = True,
) -> list[dict[str, Any]]:
    """
    Retrieve prompt history for a conversation with metadata.

    Public utility method for inspecting conversation history.

    Args:
        conversation_id: The conversation ID to retrieve
        max_prompts: Maximum number of prompts (None = all prompts)
        include_cache_hits: Whether to include prompts that hit cache

    Returns:
        List of prompt records with metadata (prompt, cache_hit, created_at),
        ordered oldest to newest
    """
    self._setup()

    prompt_table_name = self.prompt_history_table

    cache_filter = "" if include_cache_hits else "AND cache_hit = false"

    # Bind LIMIT as a query parameter rather than interpolating the value
    # into the SQL text (only identifiers like the table name need f-string
    # interpolation). Falsy max_prompts (None or 0) keeps prior behavior of
    # returning all prompts.
    params: tuple[Any, ...] = (self.space_id, conversation_id)
    limit_clause = ""
    if max_prompts:
        limit_clause = "LIMIT %s"
        params = (*params, max_prompts)

    query_sql: str = f"""
        SELECT prompt, cache_hit, created_at
        FROM {prompt_table_name}
        WHERE genie_space_id = %s
        AND conversation_id = %s
        {cache_filter}
        ORDER BY created_at ASC
        {limit_clause}
    """

    with self._pool.connection() as conn:
        with conn.cursor() as cur:
            cur.execute(query_sql, params)
            rows: list[DbRow] = cur.fetchall()

            return [
                {
                    "prompt": row["prompt"],
                    "cache_hit": row["cache_hit"],
                    "created_at": row["created_at"],
                }
                for row in rows
            ]
|
|
651
|
+
|
|
652
|
+
def export_prompt_history(
    self,
    conversation_id: str,
    output_format: str = "text",
) -> str:
    """
    Export prompt history for a conversation in various formats.

    Args:
        conversation_id: The conversation ID to export
        output_format: Format for export ("text", "json", "markdown");
            anything else falls back to "text"

    Returns:
        Formatted prompt history string
    """
    self._setup()

    history = self.get_prompt_history(conversation_id)

    if not history:
        return "No prompt history found."

    if output_format == "json":
        import json

        # default=str makes timestamps serializable.
        return json.dumps(history, indent=2, default=str)

    if output_format == "markdown":
        lines = ["# Conversation History", ""]
        for i, entry in enumerate(history, 1):
            cache_mark = "HIT" if entry["cache_hit"] else "MISS"
            lines.extend(
                [
                    f"## Prompt {i} [{cache_mark}]",
                    f"**Prompt**: {entry['prompt']}",
                    f"**Cache Hit**: {entry['cache_hit']}",
                    f"**Timestamp**: {entry['created_at']}",
                    "",
                ]
            )
        return "\n".join(lines)

    # Default: plain text format.
    lines = [f"Conversation: {conversation_id}", ""]
    for i, entry in enumerate(history, 1):
        cache_mark = "[CACHE HIT]" if entry["cache_hit"] else "[GENIE]"
        lines.extend(
            [
                f"{i}. {cache_mark} {entry['prompt']}",
                f"   Timestamp: {entry['created_at']}",
            ]
        )
    return "\n".join(lines)
|
|
697
|
+
|
|
698
|
+
def clear_prompt_history(self, conversation_id: str | None = None) -> int:
    """
    Clear prompt history for a conversation or entire space.

    Args:
        conversation_id: Specific conversation to clear (None = clear all for space)

    Returns:
        Number of prompts deleted
    """
    self._setup()

    prompt_table_name = self.prompt_history_table

    if conversation_id:
        delete_sql: str = f"""
            DELETE FROM {prompt_table_name}
            WHERE genie_space_id = %s AND conversation_id = %s
        """
        params = (self.space_id, conversation_id)
    else:
        delete_sql = f"""
            DELETE FROM {prompt_table_name}
            WHERE genie_space_id = %s
        """
        params = (self.space_id,)

    with self._pool.connection() as conn:
        with conn.cursor() as cur:
            cur.execute(delete_sql, params)
            # Guard rowcount like the other helpers in this class
            # (_enforce_prompt_history_limit, _update_prompt_cache_hit):
            # drivers may report -1 and test mocks may return non-ints,
            # which would otherwise be returned as the deletion count.
            raw_count = cur.rowcount
            deleted: int = (
                raw_count if isinstance(raw_count, int) and raw_count > 0 else 0
            )

    logger.info(
        "Cleared prompt history",
        layer=self.name,
        conversation_id=conversation_id or "all",
        deleted_count=deleted,
    )

    return deleted
|
|
738
|
+
|
|
739
|
+
def drop_tables(self) -> dict[str, bool]:
    """
    Drop both cache and prompt history tables.

    This is useful for test cleanup to avoid accumulating test tables.

    Returns:
        Dict with 'cache' and 'prompt_history' keys indicating success
    """
    self._setup()

    results: dict[str, bool] = {"cache": False, "prompt_history": False}

    # (result key, table identifier, success message, failure prefix)
    # for each table to drop, in the same order as before.
    targets = (
        (
            "cache",
            self.table_name,
            "Dropped cache table",
            "Failed to drop cache table",
        ),
        (
            "prompt_history",
            self.prompt_history_table,
            "Dropped prompt history table",
            "Failed to drop prompt history table",
        ),
    )

    with self._pool.connection() as conn:
        with conn.cursor() as cur:
            for key, table, ok_msg, fail_prefix in targets:
                try:
                    cur.execute(f"DROP TABLE IF EXISTS {table} CASCADE")
                    results[key] = True
                    logger.info(
                        ok_msg,
                        layer=self.name,
                        table_name=table,
                    )
                except Exception as e:
                    # Best-effort cleanup: record failure and continue.
                    logger.warning(
                        f"{fail_prefix}: {e}",
                        layer=self.name,
                        table_name=table,
                    )

    return results
|
|
789
|
+
|
|
790
|
+
@property
def size(self) -> int:
    """Current number of entries in the cache for this Genie space."""
    self._setup()

    count_sql: str = (
        f"SELECT COUNT(*) as count FROM {self.table_name} WHERE genie_space_id = %s"
    )

    with self._pool.connection() as conn:
        with conn.cursor() as cur:
            cur.execute(count_sql, (self.space_id,))
            row: DbRow | None = cur.fetchone()

    # dict_row cursor returns a mapping; missing/absent row counts as empty.
    if not row:
        return 0
    return row.get("count", 0)