dao-ai 0.1.2__py3-none-any.whl → 0.1.20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/apps/__init__.py +24 -0
- dao_ai/apps/handlers.py +105 -0
- dao_ai/apps/model_serving.py +29 -0
- dao_ai/apps/resources.py +1122 -0
- dao_ai/apps/server.py +39 -0
- dao_ai/cli.py +546 -37
- dao_ai/config.py +1179 -139
- dao_ai/evaluation.py +543 -0
- dao_ai/genie/__init__.py +55 -7
- dao_ai/genie/cache/__init__.py +34 -7
- dao_ai/genie/cache/base.py +143 -2
- dao_ai/genie/cache/context_aware/__init__.py +31 -0
- dao_ai/genie/cache/context_aware/base.py +1151 -0
- dao_ai/genie/cache/context_aware/in_memory.py +609 -0
- dao_ai/genie/cache/context_aware/persistent.py +802 -0
- dao_ai/genie/cache/context_aware/postgres.py +1166 -0
- dao_ai/genie/cache/core.py +1 -1
- dao_ai/genie/cache/lru.py +257 -75
- dao_ai/genie/cache/optimization.py +890 -0
- dao_ai/genie/core.py +235 -11
- dao_ai/memory/postgres.py +175 -39
- dao_ai/middleware/__init__.py +38 -0
- dao_ai/middleware/assertions.py +3 -3
- dao_ai/middleware/context_editing.py +230 -0
- dao_ai/middleware/core.py +4 -4
- dao_ai/middleware/guardrails.py +3 -3
- dao_ai/middleware/human_in_the_loop.py +3 -2
- dao_ai/middleware/message_validation.py +4 -4
- dao_ai/middleware/model_call_limit.py +77 -0
- dao_ai/middleware/model_retry.py +121 -0
- dao_ai/middleware/pii.py +157 -0
- dao_ai/middleware/summarization.py +1 -1
- dao_ai/middleware/tool_call_limit.py +210 -0
- dao_ai/middleware/tool_retry.py +174 -0
- dao_ai/middleware/tool_selector.py +129 -0
- dao_ai/models.py +327 -370
- dao_ai/nodes.py +9 -16
- dao_ai/orchestration/core.py +33 -9
- dao_ai/orchestration/supervisor.py +29 -13
- dao_ai/orchestration/swarm.py +6 -1
- dao_ai/{prompts.py → prompts/__init__.py} +12 -61
- dao_ai/prompts/instructed_retriever_decomposition.yaml +58 -0
- dao_ai/prompts/instruction_reranker.yaml +14 -0
- dao_ai/prompts/router.yaml +37 -0
- dao_ai/prompts/verifier.yaml +46 -0
- dao_ai/providers/base.py +28 -2
- dao_ai/providers/databricks.py +363 -33
- dao_ai/state.py +1 -0
- dao_ai/tools/__init__.py +5 -3
- dao_ai/tools/genie.py +103 -26
- dao_ai/tools/instructed_retriever.py +366 -0
- dao_ai/tools/instruction_reranker.py +202 -0
- dao_ai/tools/mcp.py +539 -97
- dao_ai/tools/router.py +89 -0
- dao_ai/tools/slack.py +13 -2
- dao_ai/tools/sql.py +7 -3
- dao_ai/tools/unity_catalog.py +32 -10
- dao_ai/tools/vector_search.py +493 -160
- dao_ai/tools/verifier.py +159 -0
- dao_ai/utils.py +182 -2
- dao_ai/vector_search.py +46 -1
- {dao_ai-0.1.2.dist-info → dao_ai-0.1.20.dist-info}/METADATA +45 -9
- dao_ai-0.1.20.dist-info/RECORD +89 -0
- dao_ai/agent_as_code.py +0 -22
- dao_ai/genie/cache/semantic.py +0 -970
- dao_ai-0.1.2.dist-info/RECORD +0 -64
- {dao_ai-0.1.2.dist-info → dao_ai-0.1.20.dist-info}/WHEEL +0 -0
- {dao_ai-0.1.2.dist-info → dao_ai-0.1.20.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.1.2.dist-info → dao_ai-0.1.20.dist-info}/licenses/LICENSE +0 -0
|
@@ -0,0 +1,1166 @@
|
|
|
1
|
+
"""
|
|
2
|
+
PostgreSQL pg_vector-based context-aware Genie cache implementation.
|
|
3
|
+
|
|
4
|
+
This module provides a context-aware cache that uses PostgreSQL with pg_vector
|
|
5
|
+
for semantic similarity search. It supports both standard PostgreSQL and
|
|
6
|
+
Databricks Lakebase connections via the DatabaseModel abstraction.
|
|
7
|
+
|
|
8
|
+
Features:
|
|
9
|
+
- Dual embedding matching (question + conversation context)
|
|
10
|
+
- pg_vector similarity search with L2 distance
|
|
11
|
+
- Prompt history tracking for conversation context
|
|
12
|
+
- TTL-based expiration with refresh-on-hit
|
|
13
|
+
- Space-partitioned cache entries
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
from __future__ import annotations
|
|
17
|
+
|
|
18
|
+
from datetime import datetime, timedelta
|
|
19
|
+
from typing import Any, Self
|
|
20
|
+
|
|
21
|
+
import mlflow
|
|
22
|
+
from databricks.sdk import WorkspaceClient
|
|
23
|
+
from databricks_ai_bridge.genie import GenieResponse
|
|
24
|
+
from loguru import logger
|
|
25
|
+
|
|
26
|
+
from dao_ai.config import (
|
|
27
|
+
DatabaseModel,
|
|
28
|
+
GenieContextAwareCacheParametersModel,
|
|
29
|
+
WarehouseModel,
|
|
30
|
+
)
|
|
31
|
+
from dao_ai.genie.cache.base import (
|
|
32
|
+
CacheResult,
|
|
33
|
+
GenieServiceBase,
|
|
34
|
+
SQLCacheEntry,
|
|
35
|
+
)
|
|
36
|
+
from dao_ai.genie.cache.context_aware.persistent import (
|
|
37
|
+
DbRow,
|
|
38
|
+
PersistentContextAwareGenieCacheService,
|
|
39
|
+
)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class PostgresContextAwareGenieService(PersistentContextAwareGenieCacheService):
|
|
43
|
+
"""
|
|
44
|
+
PostgreSQL pg_vector-based context-aware caching decorator.
|
|
45
|
+
|
|
46
|
+
This service caches the SQL query generated by Genie along with dual embeddings
|
|
47
|
+
(question + conversation context) for high-precision semantic matching. On
|
|
48
|
+
subsequent queries, it performs similarity search using pg_vector to find
|
|
49
|
+
cached queries that match both the question intent AND conversation context.
|
|
50
|
+
|
|
51
|
+
Supports both standard PostgreSQL and Databricks Lakebase via DatabaseModel.
|
|
52
|
+
|
|
53
|
+
Cache entries are partitioned by genie_space_id to ensure queries from different
|
|
54
|
+
Genie spaces don't return incorrect cache hits.
|
|
55
|
+
|
|
56
|
+
On cache hit, it re-executes the cached SQL using the provided warehouse
|
|
57
|
+
to return fresh data while avoiding the Genie NL-to-SQL translation cost.
|
|
58
|
+
|
|
59
|
+
Example:
|
|
60
|
+
from dao_ai.config import GenieContextAwareCacheParametersModel, DatabaseModel
|
|
61
|
+
from dao_ai.genie.cache.context_aware import PostgresContextAwareGenieService
|
|
62
|
+
|
|
63
|
+
cache_params = GenieContextAwareCacheParametersModel(
|
|
64
|
+
database=database_model,
|
|
65
|
+
warehouse=warehouse_model,
|
|
66
|
+
embedding_model="databricks-gte-large-en",
|
|
67
|
+
time_to_live_seconds=86400, # 24 hours
|
|
68
|
+
similarity_threshold=0.85
|
|
69
|
+
)
|
|
70
|
+
genie = PostgresContextAwareGenieService(
|
|
71
|
+
impl=GenieService(Genie(space_id="my-space")),
|
|
72
|
+
parameters=cache_params
|
|
73
|
+
)
|
|
74
|
+
|
|
75
|
+
Thread-safe: Uses connection pooling from psycopg_pool.
|
|
76
|
+
"""
|
|
77
|
+
|
|
78
|
+
    # Decorated Genie service; queried on every cache miss.
    impl: GenieServiceBase
    # Full cache configuration (database, warehouse, embedding model, thresholds).
    parameters: GenieContextAwareCacheParametersModel
    # Optional client for conversation history; None disables context retrieval.
    _workspace_client: WorkspaceClient | None
    # Layer label used in structured log records.
    name: str
    # Lazily created embedding client (initialized in _setup()).
    _embeddings: Any  # DatabricksEmbeddings
    # Lazily created connection pool (initialized in _setup()).
    _pool: Any  # ConnectionPool
    # Dimensionality of the embedding model's vectors; set during _setup().
    _embedding_dims: int | None
    # Guard so _setup() performs its initialization only once.
    _setup_complete: bool
    # True when _before_cache_lookup stored the current request's prompt.
    _prompt_stored_for_current_request: bool
|
|
86
|
+
|
|
87
|
+
def __init__(
|
|
88
|
+
self,
|
|
89
|
+
impl: GenieServiceBase,
|
|
90
|
+
parameters: GenieContextAwareCacheParametersModel,
|
|
91
|
+
workspace_client: WorkspaceClient | None = None,
|
|
92
|
+
name: str | None = None,
|
|
93
|
+
) -> None:
|
|
94
|
+
"""
|
|
95
|
+
Initialize the PostgreSQL context-aware cache service.
|
|
96
|
+
|
|
97
|
+
Args:
|
|
98
|
+
impl: The underlying GenieServiceBase to delegate to on cache miss.
|
|
99
|
+
The space_id will be obtained from impl.space_id.
|
|
100
|
+
parameters: Cache configuration including database, warehouse, embedding model
|
|
101
|
+
workspace_client: Optional WorkspaceClient for retrieving conversation history.
|
|
102
|
+
If None, conversation context will not be used.
|
|
103
|
+
name: Name for this cache layer (for logging). Defaults to class name.
|
|
104
|
+
"""
|
|
105
|
+
self.impl = impl
|
|
106
|
+
self.parameters = parameters
|
|
107
|
+
self._workspace_client = workspace_client
|
|
108
|
+
self.name = name if name is not None else self.__class__.__name__
|
|
109
|
+
self._embeddings = None
|
|
110
|
+
self._pool = None
|
|
111
|
+
self._embedding_dims = None
|
|
112
|
+
self._setup_complete = False
|
|
113
|
+
self._prompt_stored_for_current_request = False
|
|
114
|
+
|
|
115
|
+
def _setup(self) -> None:
|
|
116
|
+
"""Initialize embeddings and database connection pool lazily."""
|
|
117
|
+
if self._setup_complete:
|
|
118
|
+
return
|
|
119
|
+
|
|
120
|
+
from dao_ai.memory.postgres import PostgresPoolManager
|
|
121
|
+
|
|
122
|
+
# Initialize embeddings using base class helper
|
|
123
|
+
self._initialize_embeddings(
|
|
124
|
+
self.parameters.embedding_model,
|
|
125
|
+
self.parameters.embedding_dims,
|
|
126
|
+
)
|
|
127
|
+
|
|
128
|
+
# Get connection pool
|
|
129
|
+
self._pool = PostgresPoolManager.get_pool(self.parameters.database)
|
|
130
|
+
|
|
131
|
+
# Ensure table exists
|
|
132
|
+
self._create_table_if_not_exists()
|
|
133
|
+
|
|
134
|
+
self._setup_complete = True
|
|
135
|
+
logger.debug(
|
|
136
|
+
"PostgreSQL context-aware cache initialized",
|
|
137
|
+
layer=self.name,
|
|
138
|
+
space_id=self.space_id,
|
|
139
|
+
table_name=self.table_name,
|
|
140
|
+
dims=self._embedding_dims,
|
|
141
|
+
)
|
|
142
|
+
|
|
143
|
+
# Property implementations
|
|
144
|
+
@property
|
|
145
|
+
def database(self) -> DatabaseModel:
|
|
146
|
+
"""The database used for storing cache entries."""
|
|
147
|
+
return self.parameters.database
|
|
148
|
+
|
|
149
|
+
@property
|
|
150
|
+
def warehouse(self) -> WarehouseModel:
|
|
151
|
+
"""The warehouse used for executing cached SQL queries."""
|
|
152
|
+
return self.parameters.warehouse
|
|
153
|
+
|
|
154
|
+
@property
|
|
155
|
+
def time_to_live(self) -> timedelta | None:
|
|
156
|
+
"""Time-to-live for cache entries. None means never expires."""
|
|
157
|
+
ttl = self.parameters.time_to_live_seconds
|
|
158
|
+
if ttl is None or ttl < 0:
|
|
159
|
+
return None
|
|
160
|
+
return timedelta(seconds=ttl)
|
|
161
|
+
|
|
162
|
+
@property
|
|
163
|
+
def time_to_live_seconds(self) -> int | None:
|
|
164
|
+
"""TTL in seconds (None or negative = never expires)."""
|
|
165
|
+
return self.parameters.time_to_live_seconds
|
|
166
|
+
|
|
167
|
+
@property
|
|
168
|
+
def similarity_threshold(self) -> float:
|
|
169
|
+
"""Minimum similarity for cache hit (using L2 distance converted to similarity)."""
|
|
170
|
+
return self.parameters.similarity_threshold
|
|
171
|
+
|
|
172
|
+
@property
|
|
173
|
+
def context_similarity_threshold(self) -> float:
|
|
174
|
+
"""Minimum similarity for context matching."""
|
|
175
|
+
return self.parameters.context_similarity_threshold
|
|
176
|
+
|
|
177
|
+
@property
|
|
178
|
+
def question_weight(self) -> float:
|
|
179
|
+
"""Weight for question similarity in combined score."""
|
|
180
|
+
return self.parameters.question_weight
|
|
181
|
+
|
|
182
|
+
@property
|
|
183
|
+
def context_weight(self) -> float:
|
|
184
|
+
"""Weight for context similarity in combined score."""
|
|
185
|
+
return self.parameters.context_weight
|
|
186
|
+
|
|
187
|
+
@property
|
|
188
|
+
def table_name(self) -> str:
|
|
189
|
+
"""Name of the cache table."""
|
|
190
|
+
return self.parameters.table_name
|
|
191
|
+
|
|
192
|
+
@property
|
|
193
|
+
def prompt_history_table(self) -> str:
|
|
194
|
+
"""Name of the prompt history table."""
|
|
195
|
+
return self.parameters.prompt_history_table
|
|
196
|
+
|
|
197
|
+
@property
|
|
198
|
+
def context_window_size(self) -> int:
|
|
199
|
+
"""Number of previous prompts to include in context."""
|
|
200
|
+
return self.parameters.context_window_size
|
|
201
|
+
|
|
202
|
+
@property
|
|
203
|
+
def max_context_tokens(self) -> int:
|
|
204
|
+
"""Maximum tokens for context string."""
|
|
205
|
+
return self.parameters.max_context_tokens
|
|
206
|
+
|
|
207
|
+
@property
|
|
208
|
+
def max_prompt_history_length(self) -> int:
|
|
209
|
+
"""Maximum number of prompts to keep per conversation."""
|
|
210
|
+
return self.parameters.max_prompt_history_length
|
|
211
|
+
|
|
212
|
+
def _create_table_if_not_exists(self) -> None:
|
|
213
|
+
"""Create the cache table and prompt history table with pg_vector extension."""
|
|
214
|
+
create_extension_sql: str = "CREATE EXTENSION IF NOT EXISTS vector"
|
|
215
|
+
|
|
216
|
+
# Check if table exists and get current embedding dimensions
|
|
217
|
+
check_dims_sql: str = """
|
|
218
|
+
SELECT atttypmod
|
|
219
|
+
FROM pg_attribute
|
|
220
|
+
WHERE attrelid = %s::regclass
|
|
221
|
+
AND attname = 'question_embedding'
|
|
222
|
+
"""
|
|
223
|
+
|
|
224
|
+
create_table_sql: str = f"""
|
|
225
|
+
CREATE TABLE IF NOT EXISTS {self.table_name} (
|
|
226
|
+
id SERIAL PRIMARY KEY,
|
|
227
|
+
genie_space_id TEXT NOT NULL,
|
|
228
|
+
question TEXT NOT NULL,
|
|
229
|
+
conversation_context TEXT,
|
|
230
|
+
context_string TEXT,
|
|
231
|
+
question_embedding vector({self.embedding_dims}),
|
|
232
|
+
context_embedding vector({self.embedding_dims}),
|
|
233
|
+
sql_query TEXT NOT NULL,
|
|
234
|
+
description TEXT,
|
|
235
|
+
conversation_id TEXT,
|
|
236
|
+
message_id TEXT,
|
|
237
|
+
created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
|
|
238
|
+
)
|
|
239
|
+
"""
|
|
240
|
+
|
|
241
|
+
# Migration: Add message_id column if it doesn't exist
|
|
242
|
+
add_message_id_sql: str = f"""
|
|
243
|
+
ALTER TABLE {self.table_name}
|
|
244
|
+
ADD COLUMN IF NOT EXISTS message_id TEXT
|
|
245
|
+
"""
|
|
246
|
+
# Index for efficient similarity search partitioned by genie_space_id
|
|
247
|
+
create_question_embedding_index_sql: str = f"""
|
|
248
|
+
CREATE INDEX IF NOT EXISTS {self.table_name}_question_embedding_idx
|
|
249
|
+
ON {self.table_name}
|
|
250
|
+
USING ivfflat (question_embedding vector_l2_ops)
|
|
251
|
+
WITH (lists = 100)
|
|
252
|
+
"""
|
|
253
|
+
create_context_embedding_index_sql: str = f"""
|
|
254
|
+
CREATE INDEX IF NOT EXISTS {self.table_name}_context_embedding_idx
|
|
255
|
+
ON {self.table_name}
|
|
256
|
+
USING ivfflat (context_embedding vector_l2_ops)
|
|
257
|
+
WITH (lists = 100)
|
|
258
|
+
"""
|
|
259
|
+
create_space_index_sql: str = f"""
|
|
260
|
+
CREATE INDEX IF NOT EXISTS {self.table_name}_space_idx
|
|
261
|
+
ON {self.table_name} (genie_space_id)
|
|
262
|
+
"""
|
|
263
|
+
create_unique_question_index_sql: str = f"""
|
|
264
|
+
CREATE UNIQUE INDEX IF NOT EXISTS {self.table_name}_unique_question_idx
|
|
265
|
+
ON {self.table_name} (genie_space_id, question)
|
|
266
|
+
"""
|
|
267
|
+
|
|
268
|
+
with self._pool.connection() as conn:
|
|
269
|
+
with conn.cursor() as cur:
|
|
270
|
+
cur.execute(create_extension_sql)
|
|
271
|
+
|
|
272
|
+
# Check if table exists and verify embedding dimensions
|
|
273
|
+
try:
|
|
274
|
+
cur.execute(check_dims_sql, (self.table_name,))
|
|
275
|
+
row: DbRow | None = cur.fetchone()
|
|
276
|
+
if row is not None:
|
|
277
|
+
current_dims = row.get("atttypmod", 0)
|
|
278
|
+
if current_dims != self.embedding_dims:
|
|
279
|
+
logger.warning(
|
|
280
|
+
"Embedding dimension mismatch, dropping and recreating table",
|
|
281
|
+
layer=self.name,
|
|
282
|
+
table_dims=current_dims,
|
|
283
|
+
expected_dims=self.embedding_dims,
|
|
284
|
+
table_name=self.table_name,
|
|
285
|
+
)
|
|
286
|
+
cur.execute(f"DROP TABLE {self.table_name}")
|
|
287
|
+
except Exception:
|
|
288
|
+
pass
|
|
289
|
+
|
|
290
|
+
try:
|
|
291
|
+
cur.execute(create_table_sql)
|
|
292
|
+
except Exception as e:
|
|
293
|
+
logger.debug(
|
|
294
|
+
f"Table creation skipped (may already exist): {e}",
|
|
295
|
+
layer=self.name,
|
|
296
|
+
)
|
|
297
|
+
|
|
298
|
+
# Migration: Add message_id column if it doesn't exist (for existing tables)
|
|
299
|
+
try:
|
|
300
|
+
cur.execute(add_message_id_sql)
|
|
301
|
+
logger.debug(
|
|
302
|
+
"Added message_id column (or already exists)",
|
|
303
|
+
layer=self.name,
|
|
304
|
+
table_name=self.table_name,
|
|
305
|
+
)
|
|
306
|
+
except Exception as e:
|
|
307
|
+
# Column might already exist or other error
|
|
308
|
+
logger.debug(
|
|
309
|
+
f"message_id column migration skipped: {e}",
|
|
310
|
+
layer=self.name,
|
|
311
|
+
)
|
|
312
|
+
|
|
313
|
+
# Create indexes
|
|
314
|
+
for idx_name, idx_sql in [
|
|
315
|
+
(f"{self.table_name}_space_idx", create_space_index_sql),
|
|
316
|
+
(
|
|
317
|
+
f"{self.table_name}_question_embedding_idx",
|
|
318
|
+
create_question_embedding_index_sql,
|
|
319
|
+
),
|
|
320
|
+
(
|
|
321
|
+
f"{self.table_name}_context_embedding_idx",
|
|
322
|
+
create_context_embedding_index_sql,
|
|
323
|
+
),
|
|
324
|
+
(
|
|
325
|
+
f"{self.table_name}_unique_question_idx",
|
|
326
|
+
create_unique_question_index_sql,
|
|
327
|
+
),
|
|
328
|
+
]:
|
|
329
|
+
if self._index_exists(cur, idx_name):
|
|
330
|
+
logger.debug(
|
|
331
|
+
f"Index {idx_name} already exists", layer=self.name
|
|
332
|
+
)
|
|
333
|
+
continue
|
|
334
|
+
try:
|
|
335
|
+
cur.execute(idx_sql)
|
|
336
|
+
except Exception as e:
|
|
337
|
+
logger.warning(
|
|
338
|
+
f"Could not create {idx_name}: {e}", layer=self.name
|
|
339
|
+
)
|
|
340
|
+
|
|
341
|
+
# Create prompt history table
|
|
342
|
+
try:
|
|
343
|
+
self._create_prompt_history_table(cur)
|
|
344
|
+
except Exception as e:
|
|
345
|
+
logger.error(
|
|
346
|
+
f"Failed to create prompt history table: {e}",
|
|
347
|
+
layer=self.name,
|
|
348
|
+
exc_info=True,
|
|
349
|
+
)
|
|
350
|
+
|
|
351
|
+
    def _create_prompt_history_table(self, cur: Any) -> None:
        """Create the prompt history table for tracking user prompts.

        Runs on the caller's open cursor so it shares the same transaction as
        the cache-table setup. Idempotent: all statements use IF NOT EXISTS
        and duplicate-object errors are tolerated.

        Args:
            cur: An open database cursor (psycopg-style ``execute``).
        """
        prompt_table_name = self.prompt_history_table

        # One row per user prompt; cache_hit / cache_entry_id link a prompt to
        # the cache entry that served it (if any).
        create_prompt_table_sql: str = f"""
            CREATE TABLE IF NOT EXISTS {prompt_table_name} (
                id SERIAL PRIMARY KEY,
                genie_space_id TEXT NOT NULL,
                conversation_id TEXT NOT NULL,
                prompt TEXT NOT NULL,
                cache_hit BOOLEAN DEFAULT FALSE,
                cache_entry_id INTEGER,
                created_at TIMESTAMP WITH TIME ZONE DEFAULT CURRENT_TIMESTAMP
            )
        """

        # Migration: Add cache_entry_id column if it doesn't exist
        add_cache_entry_id_sql: str = f"""
            ALTER TABLE {prompt_table_name}
            ADD COLUMN IF NOT EXISTS cache_entry_id INTEGER
        """

        # DESC-ordered indexes support "most recent prompts first" queries.
        create_conversation_index_sql: str = f"""
            CREATE INDEX IF NOT EXISTS {prompt_table_name}_conversation_idx
            ON {prompt_table_name} (genie_space_id, conversation_id, created_at DESC)
        """
        create_space_index_sql: str = f"""
            CREATE INDEX IF NOT EXISTS {prompt_table_name}_space_idx
            ON {prompt_table_name} (genie_space_id, created_at DESC)
        """
        # Deduplicates repeated identical prompts within one conversation.
        create_unique_prompt_index_sql: str = f"""
            CREATE UNIQUE INDEX IF NOT EXISTS {prompt_table_name}_unique_prompt_idx
            ON {prompt_table_name} (genie_space_id, conversation_id, prompt)
        """
        # Partial index: only rows actually linked to a cache entry.
        create_cache_entry_index_sql: str = f"""
            CREATE INDEX IF NOT EXISTS {prompt_table_name}_cache_entry_idx
            ON {prompt_table_name} (cache_entry_id)
            WHERE cache_entry_id IS NOT NULL
        """

        try:
            cur.execute(create_prompt_table_sql)
        except Exception as e:
            # Tolerate concurrent creation races; re-raise anything else.
            if "duplicate key" in str(e) or "already exists" in str(e):
                logger.debug("Prompt history table already exists", layer=self.name)
            else:
                raise

        # Migration: Add cache_entry_id column if it doesn't exist (for existing tables)
        try:
            cur.execute(add_cache_entry_id_sql)
            logger.debug(
                "Added cache_entry_id column (or already exists)",
                layer=self.name,
                table_name=prompt_table_name,
            )
        except Exception as e:
            logger.debug(
                f"cache_entry_id column migration skipped: {e}",
                layer=self.name,
            )

        # Create indexes; failures other than "already exists" are logged,
        # not raised, so one bad index doesn't block the rest of setup.
        for idx_name, idx_sql in [
            (f"{prompt_table_name}_conversation_idx", create_conversation_index_sql),
            (f"{prompt_table_name}_space_idx", create_space_index_sql),
            (f"{prompt_table_name}_unique_prompt_idx", create_unique_prompt_index_sql),
            (f"{prompt_table_name}_cache_entry_idx", create_cache_entry_index_sql),
        ]:
            if self._index_exists(cur, idx_name):
                continue
            try:
                cur.execute(idx_sql)
            except Exception as e:
                if "duplicate key" not in str(e) and "already exists" not in str(e):
                    logger.warning(
                        f"Could not create index {idx_name}: {e}", layer=self.name
                    )

        logger.info(
            "Prompt history table ready", layer=self.name, table=prompt_table_name
        )
|
|
432
|
+
|
|
433
|
+
@mlflow.trace(name="semantic_search_postgres")
|
|
434
|
+
def _find_similar(
|
|
435
|
+
self,
|
|
436
|
+
question: str,
|
|
437
|
+
conversation_context: str,
|
|
438
|
+
question_embedding: list[float],
|
|
439
|
+
context_embedding: list[float],
|
|
440
|
+
conversation_id: str | None = None,
|
|
441
|
+
) -> tuple[SQLCacheEntry, float] | None:
|
|
442
|
+
"""Find a semantically similar cached entry using pg_vector."""
|
|
443
|
+
ttl_seconds = self.time_to_live_seconds
|
|
444
|
+
ttl_disabled = ttl_seconds is None or ttl_seconds < 0
|
|
445
|
+
|
|
446
|
+
if ttl_disabled:
|
|
447
|
+
is_valid_expr = "TRUE"
|
|
448
|
+
else:
|
|
449
|
+
is_valid_expr = f"created_at > NOW() - INTERVAL '{ttl_seconds} seconds'"
|
|
450
|
+
|
|
451
|
+
question_weight = self.question_weight
|
|
452
|
+
context_weight = self.context_weight
|
|
453
|
+
|
|
454
|
+
search_sql: str = f"""
|
|
455
|
+
SELECT
|
|
456
|
+
id,
|
|
457
|
+
question,
|
|
458
|
+
conversation_context,
|
|
459
|
+
sql_query,
|
|
460
|
+
description,
|
|
461
|
+
conversation_id,
|
|
462
|
+
message_id,
|
|
463
|
+
created_at,
|
|
464
|
+
1.0 / (1.0 + (question_embedding <-> %s::vector)) as question_similarity,
|
|
465
|
+
1.0 / (1.0 + (context_embedding <-> %s::vector)) as context_similarity,
|
|
466
|
+
({question_weight} * (1.0 / (1.0 + (question_embedding <-> %s::vector)))) +
|
|
467
|
+
({context_weight} * (1.0 / (1.0 + (context_embedding <-> %s::vector)))) as combined_similarity,
|
|
468
|
+
{is_valid_expr} as is_valid
|
|
469
|
+
FROM {self.table_name}
|
|
470
|
+
WHERE genie_space_id = %s
|
|
471
|
+
ORDER BY combined_similarity DESC
|
|
472
|
+
LIMIT 1
|
|
473
|
+
"""
|
|
474
|
+
|
|
475
|
+
question_emb_str = f"[{','.join(str(x) for x in question_embedding)}]"
|
|
476
|
+
context_emb_str = f"[{','.join(str(x) for x in context_embedding)}]"
|
|
477
|
+
|
|
478
|
+
with self._pool.connection() as conn:
|
|
479
|
+
with conn.cursor() as cur:
|
|
480
|
+
cur.execute(
|
|
481
|
+
search_sql,
|
|
482
|
+
(
|
|
483
|
+
question_emb_str,
|
|
484
|
+
context_emb_str,
|
|
485
|
+
question_emb_str,
|
|
486
|
+
context_emb_str,
|
|
487
|
+
self.space_id,
|
|
488
|
+
),
|
|
489
|
+
)
|
|
490
|
+
row: DbRow | None = cur.fetchone()
|
|
491
|
+
|
|
492
|
+
if row is None:
|
|
493
|
+
logger.info(
|
|
494
|
+
"Cache MISS (no entries)",
|
|
495
|
+
layer=self.name,
|
|
496
|
+
question=question[:50],
|
|
497
|
+
space=self.space_id,
|
|
498
|
+
)
|
|
499
|
+
return None
|
|
500
|
+
|
|
501
|
+
entry_id = row.get("id")
|
|
502
|
+
cached_question = row.get("question", "")
|
|
503
|
+
sql_query = row["sql_query"]
|
|
504
|
+
description = row.get("description", "")
|
|
505
|
+
conversation_id_cached = row.get("conversation_id", "")
|
|
506
|
+
created_at = row["created_at"]
|
|
507
|
+
question_similarity = row["question_similarity"]
|
|
508
|
+
context_similarity = row["context_similarity"]
|
|
509
|
+
combined_similarity = row["combined_similarity"]
|
|
510
|
+
is_valid = row.get("is_valid", False)
|
|
511
|
+
|
|
512
|
+
logger.debug(
|
|
513
|
+
"Best match found",
|
|
514
|
+
layer=self.name,
|
|
515
|
+
question_sim=f"{question_similarity:.4f}",
|
|
516
|
+
context_sim=f"{context_similarity:.4f}",
|
|
517
|
+
combined_sim=f"{combined_similarity:.4f}",
|
|
518
|
+
is_valid=is_valid,
|
|
519
|
+
)
|
|
520
|
+
|
|
521
|
+
if question_similarity < self.similarity_threshold:
|
|
522
|
+
logger.info(
|
|
523
|
+
"Cache MISS (question similarity too low)",
|
|
524
|
+
layer=self.name,
|
|
525
|
+
question_sim=f"{question_similarity:.4f}",
|
|
526
|
+
threshold=self.similarity_threshold,
|
|
527
|
+
)
|
|
528
|
+
return None
|
|
529
|
+
|
|
530
|
+
if context_similarity < self.context_similarity_threshold:
|
|
531
|
+
logger.info(
|
|
532
|
+
"Cache MISS (context similarity too low)",
|
|
533
|
+
layer=self.name,
|
|
534
|
+
context_sim=f"{context_similarity:.4f}",
|
|
535
|
+
threshold=self.context_similarity_threshold,
|
|
536
|
+
)
|
|
537
|
+
return None
|
|
538
|
+
|
|
539
|
+
if not is_valid:
|
|
540
|
+
cur.execute(
|
|
541
|
+
f"DELETE FROM {self.table_name} WHERE id = %s", (entry_id,)
|
|
542
|
+
)
|
|
543
|
+
logger.info("Cache MISS (expired, deleted)", layer=self.name)
|
|
544
|
+
return None
|
|
545
|
+
|
|
546
|
+
cache_age_seconds = None
|
|
547
|
+
if created_at:
|
|
548
|
+
cache_age_seconds = (
|
|
549
|
+
datetime.now(created_at.tzinfo) - created_at
|
|
550
|
+
).total_seconds()
|
|
551
|
+
|
|
552
|
+
logger.info(
|
|
553
|
+
"Cache HIT",
|
|
554
|
+
layer=self.name,
|
|
555
|
+
question=question[:80],
|
|
556
|
+
matched_question=cached_question[:80],
|
|
557
|
+
cache_age_seconds=round(cache_age_seconds, 1)
|
|
558
|
+
if cache_age_seconds
|
|
559
|
+
else None,
|
|
560
|
+
question_similarity=f"{question_similarity:.4f}",
|
|
561
|
+
context_similarity=f"{context_similarity:.4f}",
|
|
562
|
+
combined_similarity=f"{combined_similarity:.4f}",
|
|
563
|
+
)
|
|
564
|
+
|
|
565
|
+
message_id_cached = row.get("message_id")
|
|
566
|
+
|
|
567
|
+
entry = SQLCacheEntry(
|
|
568
|
+
query=sql_query,
|
|
569
|
+
description=description,
|
|
570
|
+
conversation_id=conversation_id_cached,
|
|
571
|
+
created_at=created_at,
|
|
572
|
+
message_id=message_id_cached,
|
|
573
|
+
cache_entry_id=entry_id,
|
|
574
|
+
)
|
|
575
|
+
return entry, combined_similarity
|
|
576
|
+
|
|
577
|
+
def _store_entry(
|
|
578
|
+
self,
|
|
579
|
+
question: str,
|
|
580
|
+
conversation_context: str,
|
|
581
|
+
question_embedding: list[float],
|
|
582
|
+
context_embedding: list[float],
|
|
583
|
+
response: GenieResponse,
|
|
584
|
+
message_id: str | None = None,
|
|
585
|
+
) -> None:
|
|
586
|
+
"""Store a new cache entry with dual embeddings and message_id."""
|
|
587
|
+
insert_sql: str = f"""
|
|
588
|
+
INSERT INTO {self.table_name}
|
|
589
|
+
(genie_space_id, question, conversation_context, context_string,
|
|
590
|
+
question_embedding, context_embedding, sql_query, description,
|
|
591
|
+
conversation_id, message_id)
|
|
592
|
+
VALUES (%s, %s, %s, %s, %s::vector, %s::vector, %s, %s, %s, %s)
|
|
593
|
+
"""
|
|
594
|
+
question_emb_str = f"[{','.join(str(x) for x in question_embedding)}]"
|
|
595
|
+
context_emb_str = f"[{','.join(str(x) for x in context_embedding)}]"
|
|
596
|
+
|
|
597
|
+
if conversation_context:
|
|
598
|
+
full_context_string = f"{conversation_context}\nCurrent: {question}"
|
|
599
|
+
else:
|
|
600
|
+
full_context_string = question
|
|
601
|
+
|
|
602
|
+
with self._pool.connection() as conn:
|
|
603
|
+
with conn.cursor() as cur:
|
|
604
|
+
cur.execute(
|
|
605
|
+
insert_sql,
|
|
606
|
+
(
|
|
607
|
+
self.space_id,
|
|
608
|
+
question,
|
|
609
|
+
conversation_context,
|
|
610
|
+
full_context_string,
|
|
611
|
+
question_emb_str,
|
|
612
|
+
context_emb_str,
|
|
613
|
+
response.query,
|
|
614
|
+
response.description,
|
|
615
|
+
response.conversation_id,
|
|
616
|
+
message_id,
|
|
617
|
+
),
|
|
618
|
+
)
|
|
619
|
+
logger.debug(
|
|
620
|
+
"Stored cache entry",
|
|
621
|
+
layer=self.name,
|
|
622
|
+
question=question[:50],
|
|
623
|
+
space=self.space_id,
|
|
624
|
+
message_id=message_id,
|
|
625
|
+
)
|
|
626
|
+
|
|
627
|
+
def _on_stale_cache_entry(self, question: str) -> None:
|
|
628
|
+
"""Delete stale cache entry from database."""
|
|
629
|
+
delete_sql = (
|
|
630
|
+
f"DELETE FROM {self.table_name} WHERE genie_space_id = %s AND question = %s"
|
|
631
|
+
)
|
|
632
|
+
with self._pool.connection() as conn:
|
|
633
|
+
with conn.cursor() as cur:
|
|
634
|
+
cur.execute(delete_sql, (self.space_id, question))
|
|
635
|
+
deleted_rows = cur.rowcount
|
|
636
|
+
logger.info(
|
|
637
|
+
"Deleted stale cache entry",
|
|
638
|
+
layer=self.name,
|
|
639
|
+
deleted_rows=deleted_rows,
|
|
640
|
+
space_id=self.space_id,
|
|
641
|
+
)
|
|
642
|
+
|
|
643
|
+
def _invalidate_by_question(self, question: str) -> bool:
|
|
644
|
+
"""
|
|
645
|
+
Invalidate cache entries matching a specific question.
|
|
646
|
+
|
|
647
|
+
This method is called when negative feedback is received to remove
|
|
648
|
+
the corresponding cache entry from the PostgreSQL database.
|
|
649
|
+
|
|
650
|
+
Args:
|
|
651
|
+
question: The question text to match and invalidate
|
|
652
|
+
|
|
653
|
+
Returns:
|
|
654
|
+
True if an entry was found and invalidated, False otherwise
|
|
655
|
+
"""
|
|
656
|
+
delete_sql = (
|
|
657
|
+
f"DELETE FROM {self.table_name} WHERE genie_space_id = %s AND question = %s"
|
|
658
|
+
)
|
|
659
|
+
with self._pool.connection() as conn:
|
|
660
|
+
with conn.cursor() as cur:
|
|
661
|
+
cur.execute(delete_sql, (self.space_id, question))
|
|
662
|
+
deleted_rows = cur.rowcount if isinstance(cur.rowcount, int) else 0
|
|
663
|
+
if deleted_rows > 0:
|
|
664
|
+
logger.info(
|
|
665
|
+
"Invalidated cache entry by question",
|
|
666
|
+
layer=self.name,
|
|
667
|
+
question=question[:50],
|
|
668
|
+
deleted_rows=deleted_rows,
|
|
669
|
+
space_id=self.space_id,
|
|
670
|
+
)
|
|
671
|
+
return True
|
|
672
|
+
return False
|
|
673
|
+
|
|
674
|
+
# Template Method hook implementations
|
|
675
|
+
|
|
676
|
+
def _before_cache_lookup(self, question: str, conversation_id: str | None) -> None:
|
|
677
|
+
"""Store prompt before cache lookup."""
|
|
678
|
+
if conversation_id:
|
|
679
|
+
self._store_user_prompt(
|
|
680
|
+
prompt=question,
|
|
681
|
+
conversation_id=conversation_id,
|
|
682
|
+
cache_hit=False,
|
|
683
|
+
)
|
|
684
|
+
# Track that we stored the prompt
|
|
685
|
+
self._prompt_stored_for_current_request = True
|
|
686
|
+
else:
|
|
687
|
+
self._prompt_stored_for_current_request = False
|
|
688
|
+
|
|
689
|
+
def _after_cache_hit(
|
|
690
|
+
self,
|
|
691
|
+
question: str,
|
|
692
|
+
conversation_id: str | None,
|
|
693
|
+
result: CacheResult,
|
|
694
|
+
) -> None:
|
|
695
|
+
"""Update cache_hit flag and cache_entry_id after a cache hit."""
|
|
696
|
+
if result.cache_hit and self._prompt_stored_for_current_request:
|
|
697
|
+
actual_conv_id = result.response.conversation_id or conversation_id
|
|
698
|
+
if actual_conv_id:
|
|
699
|
+
self._update_prompt_cache_hit(
|
|
700
|
+
conversation_id=actual_conv_id,
|
|
701
|
+
prompt=question,
|
|
702
|
+
cache_hit=True,
|
|
703
|
+
cache_entry_id=result.cache_entry_id,
|
|
704
|
+
)
|
|
705
|
+
|
|
706
|
+
def _after_cache_miss(
|
|
707
|
+
self,
|
|
708
|
+
question: str,
|
|
709
|
+
conversation_id: str | None,
|
|
710
|
+
result: CacheResult,
|
|
711
|
+
) -> None:
|
|
712
|
+
"""Store prompt if not done earlier (when conversation_id comes from response)."""
|
|
713
|
+
if (
|
|
714
|
+
not self._prompt_stored_for_current_request
|
|
715
|
+
and result.response.conversation_id
|
|
716
|
+
):
|
|
717
|
+
self._store_user_prompt(
|
|
718
|
+
prompt=question,
|
|
719
|
+
conversation_id=result.response.conversation_id,
|
|
720
|
+
cache_hit=False,
|
|
721
|
+
)
|
|
722
|
+
|
|
723
|
+
# Template Method implementations for invalidate_expired() and clear()
|
|
724
|
+
|
|
725
|
+
def _get_empty_expiration_result(self) -> dict[str, int]:
|
|
726
|
+
"""Return empty dict for PostgresContextAwareGenieService."""
|
|
727
|
+
return {"cache": 0, "prompt_history": 0}
|
|
728
|
+
|
|
729
|
+
def _delete_expired_entries(self, ttl_seconds: int) -> dict[str, int]:
    """Delete expired rows from the cache table and the prompt-history table.

    Args:
        ttl_seconds: Age threshold for cache entries; rows older than this
            are removed. Prompt history uses its own TTL when configured
            (``parameters.prompt_history_ttl_seconds``), falling back to
            the same value otherwise.

    Returns:
        Rows deleted per table: ``{"cache": <n>, "prompt_history": <n>}``.
    """
    prompt_ttl_seconds = self.parameters.prompt_history_ttl_seconds
    if prompt_ttl_seconds is None:
        prompt_ttl_seconds = ttl_seconds

    result: dict[str, int] = {"cache": 0, "prompt_history": 0}

    # Delete expired cache entries.
    # NOTE: use make_interval(secs => %s) rather than INTERVAL '%s seconds' —
    # a %s placeholder inside a quoted SQL string literal is not recognized
    # as a bind parameter by the driver.
    delete_cache_sql = f"""
        DELETE FROM {self.table_name}
        WHERE genie_space_id = %s
        AND created_at < NOW() - make_interval(secs => %s)
    """

    with self._pool.connection() as conn:
        with conn.cursor() as cur:
            cur.execute(delete_cache_sql, (self.space_id, ttl_seconds))
            deleted = cur.rowcount if isinstance(cur.rowcount, int) else 0
            result["cache"] = deleted
            logger.debug(
                "Deleted expired cache entries",
                layer=self.name,
                deleted_count=deleted,
            )

    # Delete expired prompt history. Best-effort: a failure here must not
    # prevent the cache-cleanup counts from being reported.
    if prompt_ttl_seconds is not None and prompt_ttl_seconds >= 0:
        try:
            delete_prompt_sql = f"""
                DELETE FROM {self.prompt_history_table}
                WHERE genie_space_id = %s
                AND created_at < NOW() - make_interval(secs => %s)
            """
            with self._pool.connection() as conn:
                with conn.cursor() as cur:
                    cur.execute(
                        delete_prompt_sql, (self.space_id, prompt_ttl_seconds)
                    )
                    deleted = cur.rowcount if isinstance(cur.rowcount, int) else 0
                    result["prompt_history"] = deleted
        except Exception as e:
            logger.warning(
                f"Failed to clean up prompt history: {e}", layer=self.name
            )

    return result
|
|
776
|
+
|
|
777
|
+
def _delete_all_entries(self) -> int:
    """Remove every cache entry belonging to this Genie space.

    Returns:
        Number of rows deleted, as reported by the cursor.
    """
    statement = f"DELETE FROM {self.table_name} WHERE genie_space_id = %s"

    with self._pool.connection() as conn, conn.cursor() as cur:
        cur.execute(statement, (self.space_id,))
        removed: int = cur.rowcount
        logger.debug(
            "Cleared cache entries", layer=self.name, deleted_count=removed
        )
        return removed
|
|
789
|
+
|
|
790
|
+
# Template Method implementations for stats()
|
|
791
|
+
|
|
792
|
+
def _count_all_entries(self) -> int:
    """Count every cache entry stored for this Genie space."""
    query = (
        f"SELECT COUNT(*) as total FROM {self.table_name} WHERE genie_space_id = %s"
    )
    with self._pool.connection() as conn, conn.cursor() as cur:
        cur.execute(query, (self.space_id,))
        record = cur.fetchone()
        # Rows appear to be dict-like (row factory set elsewhere — TODO confirm);
        # a missing row counts as zero.
        return record.get("total", 0) if record else 0
|
|
802
|
+
|
|
803
|
+
def _count_entries_with_ttl(self, ttl_seconds: int) -> tuple[int, int]:
    """Count total and expired cache entries for this Genie space.

    Args:
        ttl_seconds: Age threshold; entries created at or before
            ``NOW() - ttl_seconds`` count as expired.

    Returns:
        ``(total_entries, expired_entries)``; ``(0, 0)`` when the query
        yields no row.
    """
    # NOTE: use make_interval(secs => %s) rather than INTERVAL '%s seconds' —
    # a %s placeholder inside a quoted SQL string literal is not recognized
    # as a bind parameter by the driver.
    stats_sql = f"""
        SELECT
            COUNT(*) as total,
            COUNT(*) FILTER (WHERE created_at <= NOW() - make_interval(secs => %s)) as expired
        FROM {self.table_name}
        WHERE genie_space_id = %s
    """
    with self._pool.connection() as conn:
        with conn.cursor() as cur:
            cur.execute(stats_sql, (ttl_seconds, self.space_id))
            row = cur.fetchone()
            if row:
                return row.get("total", 0), row.get("expired", 0)
            return 0, 0
|
|
819
|
+
|
|
820
|
+
def _get_additional_stats(self) -> dict[str, Any]:
    """Return prompt-history statistics nested under "prompt_history".

    Computes prompt counts (total / hit / miss), the number of distinct
    conversations, and the hit rate for this Genie space. Returns an
    empty dict when the query yields no row.
    """
    query = f"""
        SELECT
            COUNT(*) as total_prompts,
            COUNT(*) FILTER (WHERE cache_hit = true) as cache_hit_prompts,
            COUNT(*) FILTER (WHERE cache_hit = false) as cache_miss_prompts,
            COUNT(DISTINCT conversation_id) as total_conversations
        FROM {self.prompt_history_table}
        WHERE genie_space_id = %s
    """
    with self._pool.connection() as conn, conn.cursor() as cur:
        cur.execute(query, (self.space_id,))
        record = cur.fetchone()

    if not record:
        return {}

    total = record.get("total_prompts", 0)
    hits = record.get("cache_hit_prompts", 0)
    summary = {
        "total_prompts": total,
        "cache_hit_prompts": hits,
        "cache_miss_prompts": record.get("cache_miss_prompts", 0),
        "total_conversations": record.get("total_conversations", 0),
        # Guard against division by zero when the history table is empty.
        "cache_hit_rate": hits / total if total > 0 else 0.0,
    }
    return {"prompt_history": summary}
|
|
851
|
+
|
|
852
|
+
def from_space(
    self,
    space_id: str | None = None,
    *,
    include_all_messages: bool = True,
    from_datetime: datetime | None = None,
    to_datetime: datetime | None = None,
    max_messages: int | None = None,
) -> Self:
    """Populate cache from existing Genie space conversations.

    Fetches all conversations from a Genie space and populates:
    1. Prompt history table - all user messages
    2. Cache embeddings table - messages with SQL query attachments

    Uses ON CONFLICT DO NOTHING to avoid duplicate entries.

    Args:
        space_id: Genie space ID to import from (defaults to self.space_id)
        include_all_messages: If True, fetch all users' conversations
        from_datetime: Only include messages after this time
        to_datetime: Only include messages before this time
        max_messages: Limit to last N messages (most recent first)

    Returns:
        self for method chaining

    Raises:
        ValueError: If no workspace_client is configured.
    """
    if self.workspace_client is None:
        raise ValueError("workspace_client is required for from_space()")

    self._setup()
    target_space_id = space_id or self.space_id

    logger.info(
        "Starting from_space import",
        layer=self.name,
        space_id=target_space_id,
        include_all_messages=include_all_messages,
    )

    stats = {
        "conversations_processed": 0,
        "prompts_imported": 0,
        "prompts_skipped": 0,
        "cache_entries_imported": 0,
        "cache_entries_skipped": 0,
        "errors": 0,
    }

    from databricks.sdk.service.dashboards import GenieMessage

    all_messages: list[tuple[str, GenieMessage]] = []
    page_token: str | None = None

    # Page through all conversations. A listing failure aborts pagination;
    # messages collected so far are still processed below (best-effort).
    while True:
        try:
            response = self.workspace_client.genie.list_conversations(
                space_id=target_space_id,
                include_all=include_all_messages,
                page_token=page_token,
            )
        except Exception as e:
            logger.error(f"Failed to list conversations: {e}", layer=self.name)
            stats["errors"] += 1
            break

        if response.conversations is None:
            break

        for conversation in response.conversations:
            if conversation.conversation_id is None:
                continue

            stats["conversations_processed"] += 1

            try:
                messages_response = (
                    self.workspace_client.genie.list_conversation_messages(
                        space_id=target_space_id,
                        conversation_id=conversation.conversation_id,
                    )
                )
            except Exception as e:
                logger.warning(f"Failed to fetch messages: {e}", layer=self.name)
                stats["errors"] += 1
                continue

            if messages_response.messages is None:
                continue

            for message in messages_response.messages:
                all_messages.append((conversation.conversation_id, message))

            if max_messages and len(all_messages) >= max_messages:
                break

        if max_messages and len(all_messages) >= max_messages:
            break

        page_token = response.next_page_token
        if page_token is None:
            break

    # Keep only the most recent max_messages (newest first).
    all_messages.sort(
        key=lambda x: x[1].created_timestamp if x[1].created_timestamp else 0,
        reverse=True,
    )
    if max_messages:
        all_messages = all_messages[:max_messages]

    # Group messages by conversation_id for context building.
    from collections import defaultdict

    messages_by_conversation: dict[str, list[tuple[str, GenieMessage]]] = (
        defaultdict(list)
    )
    for conv_id, msg in all_messages:
        messages_by_conversation[conv_id].append((conv_id, msg))

    # Sort each conversation's messages oldest-first for context building.
    for conv_id in messages_by_conversation:
        messages_by_conversation[conv_id].sort(
            key=lambda x: x[1].created_timestamp if x[1].created_timestamp else 0
        )

    def build_prior_context(conversation_id: str, message: GenieMessage) -> str:
        """Build "Previous: {content}" context from earlier messages (same
        format as normal cache operations), limited to context_window_size."""
        prior_messages: list[str] = []
        for _, prior_msg in messages_by_conversation.get(conversation_id, []):
            # Ordering is only well-defined when both timestamps exist.
            if not (prior_msg.created_timestamp and message.created_timestamp):
                continue
            if prior_msg.created_timestamp >= message.created_timestamp:
                continue
            if not prior_msg.content:
                continue
            content = prior_msg.content
            if len(content) > 500:
                content = content[:500] + "..."
            prior_messages.append(f"Previous: {content}")
        if len(prior_messages) > self.context_window_size:
            prior_messages = prior_messages[-self.context_window_size:]
        return "\n".join(prior_messages)

    # Derive the tzinfo for message timestamps from whichever bound was
    # supplied. Previously only from_datetime's tzinfo was used, so passing
    # only an aware to_datetime raised TypeError when comparing a naive
    # message_created_at against it.
    tz_reference = from_datetime or to_datetime

    # Process messages: store prompts, then cache any SQL attachments.
    for conversation_id, message in all_messages:
        if message.content is None:
            continue

        message_created_at = None
        if message.created_timestamp:
            message_created_at = datetime.fromtimestamp(
                message.created_timestamp / 1000.0,
                tz=tz_reference.tzinfo if tz_reference else None,
            )

        if message_created_at:
            if from_datetime and message_created_at < from_datetime:
                continue
            if to_datetime and message_created_at > to_datetime:
                continue

        # Store prompt (ON CONFLICT DO NOTHING; False means duplicate/error).
        prompt_stored = self._store_prompt_if_not_exists(
            prompt=message.content,
            conversation_id=conversation_id,
            space_id=target_space_id,
            cache_hit=False,
            created_at=message_created_at,
        )

        if prompt_stored:
            stats["prompts_imported"] += 1
        else:
            stats["prompts_skipped"] += 1

        # Only messages with SQL query attachments become cache entries.
        if not message.attachments:
            continue
        for attachment in message.attachments:
            if not (attachment.query and attachment.query.query):
                continue
            try:
                conversation_context = build_prior_context(
                    conversation_id, message
                )

                # Generate embeddings for the question and its context.
                question_embedding = self._embeddings.embed_query(
                    message.content
                )
                if conversation_context:
                    context_embedding = self._embeddings.embed_query(
                        conversation_context
                    )
                else:
                    # Zero vector when no prior context (first message).
                    context_embedding = [0.0] * len(question_embedding)

                cache_stored = self._store_cache_entry_if_not_exists(
                    question=message.content,
                    conversation_context=conversation_context,
                    question_embedding=question_embedding,
                    context_embedding=context_embedding,
                    sql_query=attachment.query.query,
                    description=attachment.query.description or "",
                    conversation_id=conversation_id,
                    space_id=target_space_id,
                )

                if cache_stored:
                    stats["cache_entries_imported"] += 1
                else:
                    stats["cache_entries_skipped"] += 1
            except Exception as e:
                logger.warning(
                    f"Failed to generate embeddings: {e}", layer=self.name
                )
                stats["errors"] += 1

    logger.info("Completed from_space import", layer=self.name, **stats)
    return self
|
|
1080
|
+
|
|
1081
|
+
def _store_prompt_if_not_exists(
    self,
    prompt: str,
    conversation_id: str,
    space_id: str | None = None,
    cache_hit: bool = False,
    created_at: datetime | None = None,
) -> bool:
    """Insert a prompt-history row, skipping duplicates.

    Uses ON CONFLICT (genie_space_id, conversation_id, prompt) DO NOTHING,
    so re-importing the same prompt is a no-op.

    Args:
        prompt: User prompt text.
        conversation_id: Genie conversation the prompt belongs to.
        space_id: Overrides ``self.space_id`` when importing another space.
        cache_hit: Initial cache_hit flag for the row.
        created_at: Original creation time; when None the column default
            applies.

    Returns:
        True if a new row was inserted, False on duplicate or error.
    """
    target_space_id = space_id or self.space_id
    prompt_table_name = self.prompt_history_table

    # Build column/placeholder lists once instead of duplicating the whole
    # INSERT statement for the optional created_at column.
    columns = ["genie_space_id", "conversation_id", "prompt", "cache_hit"]
    params = (target_space_id, conversation_id, prompt, cache_hit)
    if created_at:
        columns.append("created_at")
        params = params + (created_at,)

    insert_sql = f"""
        INSERT INTO {prompt_table_name}
        ({", ".join(columns)})
        VALUES ({", ".join(["%s"] * len(columns))})
        ON CONFLICT (genie_space_id, conversation_id, prompt) DO NOTHING
    """

    try:
        with self._pool.connection() as conn:
            with conn.cursor() as cur:
                cur.execute(insert_sql, params)
                return cur.rowcount > 0 if isinstance(cur.rowcount, int) else False
    except Exception as e:
        # Best-effort insert: log instead of silently swallowing so failed
        # history writes are diagnosable, but never fail the caller.
        logger.debug(f"Failed to store prompt: {e}", layer=self.name)
        return False
|
|
1117
|
+
|
|
1118
|
+
def _store_cache_entry_if_not_exists(
    self,
    question: str,
    conversation_context: str,
    question_embedding: list[float],
    context_embedding: list[float],
    sql_query: str,
    description: str | None = None,
    conversation_id: str | None = None,
    space_id: str | None = None,
) -> bool:
    """Insert a cache-entry row, skipping duplicates.

    Uses ON CONFLICT (genie_space_id, question) DO NOTHING so repeated
    imports of the same question are no-ops.

    Args:
        question: User question text.
        conversation_context: "Previous: ..." lines preceding the question.
        question_embedding: Embedding vector for the question.
        context_embedding: Embedding vector for the context (zero vector
            when there was no prior context).
        sql_query: SQL attached to the Genie answer.
        description: Optional human-readable description of the query.
        conversation_id: Source conversation, when known.
        space_id: Overrides ``self.space_id`` when importing another space.

    Returns:
        True if a new row was inserted, False on duplicate or error.
    """
    target_space_id = space_id or self.space_id

    insert_sql = f"""
        INSERT INTO {self.table_name}
        (genie_space_id, question, conversation_context, context_string,
         question_embedding, context_embedding, sql_query, description, conversation_id)
        VALUES (%s, %s, %s, %s, %s::vector, %s::vector, %s, %s, %s)
        ON CONFLICT (genie_space_id, question) DO NOTHING
    """

    def to_pgvector(vec: list[float]) -> str:
        # pgvector literal format "[v1,v2,...]", cast via ::vector above.
        return f"[{','.join(str(x) for x in vec)}]"

    # context_string stores the flattened "context + current question" text.
    full_context = (
        f"{conversation_context}\nCurrent: {question}"
        if conversation_context
        else question
    )

    try:
        with self._pool.connection() as conn:
            with conn.cursor() as cur:
                cur.execute(
                    insert_sql,
                    (
                        target_space_id,
                        question,
                        conversation_context,
                        full_context,
                        to_pgvector(question_embedding),
                        to_pgvector(context_embedding),
                        sql_query,
                        description or "",
                        conversation_id or "",
                    ),
                )
                return cur.rowcount > 0 if isinstance(cur.rowcount, int) else False
    except Exception as e:
        # Best-effort insert: log instead of silently swallowing so failed
        # cache writes are diagnosable, but never fail the caller.
        logger.debug(f"Failed to store cache entry: {e}", layer=self.name)
        return False
|