dao-ai 0.1.17__py3-none-any.whl → 0.1.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/cli.py +8 -3
- dao_ai/config.py +513 -32
- dao_ai/evaluation.py +543 -0
- dao_ai/genie/cache/__init__.py +2 -0
- dao_ai/genie/cache/core.py +1 -1
- dao_ai/genie/cache/in_memory_semantic.py +871 -0
- dao_ai/genie/cache/lru.py +15 -11
- dao_ai/genie/cache/semantic.py +52 -18
- dao_ai/memory/postgres.py +146 -35
- dao_ai/orchestration/core.py +33 -9
- dao_ai/orchestration/supervisor.py +23 -8
- dao_ai/{prompts.py → prompts/__init__.py} +10 -1
- dao_ai/prompts/instructed_retriever_decomposition.yaml +58 -0
- dao_ai/prompts/instruction_reranker.yaml +14 -0
- dao_ai/prompts/router.yaml +37 -0
- dao_ai/prompts/verifier.yaml +46 -0
- dao_ai/providers/databricks.py +33 -12
- dao_ai/tools/genie.py +28 -3
- dao_ai/tools/instructed_retriever.py +366 -0
- dao_ai/tools/instruction_reranker.py +202 -0
- dao_ai/tools/router.py +89 -0
- dao_ai/tools/vector_search.py +441 -134
- dao_ai/tools/verifier.py +159 -0
- dao_ai/utils.py +182 -2
- dao_ai/vector_search.py +9 -1
- {dao_ai-0.1.17.dist-info → dao_ai-0.1.19.dist-info}/METADATA +4 -3
- {dao_ai-0.1.17.dist-info → dao_ai-0.1.19.dist-info}/RECORD +30 -20
- {dao_ai-0.1.17.dist-info → dao_ai-0.1.19.dist-info}/WHEEL +0 -0
- {dao_ai-0.1.17.dist-info → dao_ai-0.1.19.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.1.17.dist-info → dao_ai-0.1.19.dist-info}/licenses/LICENSE +0 -0
dao_ai/genie/cache/lru.py
CHANGED
|
@@ -124,9 +124,7 @@ class LRUCacheService(GenieServiceBase):
|
|
|
124
124
|
if self._cache:
|
|
125
125
|
oldest_key: str = next(iter(self._cache))
|
|
126
126
|
del self._cache[oldest_key]
|
|
127
|
-
logger.trace(
|
|
128
|
-
"Evicted cache entry", layer=self.name, key_prefix=oldest_key[:50]
|
|
129
|
-
)
|
|
127
|
+
logger.trace("Evicted cache entry", layer=self.name, key=oldest_key[:50])
|
|
130
128
|
|
|
131
129
|
def _get(self, key: str) -> SQLCacheEntry | None:
|
|
132
130
|
"""Get from cache, returning None if not found or expired."""
|
|
@@ -137,7 +135,7 @@ class LRUCacheService(GenieServiceBase):
|
|
|
137
135
|
|
|
138
136
|
if self._is_expired(entry):
|
|
139
137
|
del self._cache[key]
|
|
140
|
-
logger.trace("Expired cache entry", layer=self.name,
|
|
138
|
+
logger.trace("Expired cache entry", layer=self.name, key=key[:50])
|
|
141
139
|
return None
|
|
142
140
|
|
|
143
141
|
self._cache.move_to_end(key)
|
|
@@ -157,11 +155,11 @@ class LRUCacheService(GenieServiceBase):
|
|
|
157
155
|
conversation_id=response.conversation_id,
|
|
158
156
|
created_at=datetime.now(),
|
|
159
157
|
)
|
|
160
|
-
logger.
|
|
158
|
+
logger.debug(
|
|
161
159
|
"Stored cache entry",
|
|
162
160
|
layer=self.name,
|
|
163
|
-
|
|
164
|
-
|
|
161
|
+
key=key[:50],
|
|
162
|
+
sql=response.query[:50] if response.query else None,
|
|
165
163
|
cache_size=len(self._cache),
|
|
166
164
|
capacity=self.capacity,
|
|
167
165
|
)
|
|
@@ -180,7 +178,7 @@ class LRUCacheService(GenieServiceBase):
|
|
|
180
178
|
w: WorkspaceClient = self.warehouse.workspace_client
|
|
181
179
|
warehouse_id: str = str(self.warehouse.warehouse_id)
|
|
182
180
|
|
|
183
|
-
logger.trace("Executing cached SQL", layer=self.name,
|
|
181
|
+
logger.trace("Executing cached SQL", layer=self.name, sql=sql[:100])
|
|
184
182
|
|
|
185
183
|
statement_response: StatementResponse = w.statement_execution.execute_statement(
|
|
186
184
|
statement=sql,
|
|
@@ -258,13 +256,17 @@ class LRUCacheService(GenieServiceBase):
|
|
|
258
256
|
cached: SQLCacheEntry | None = self._get(key)
|
|
259
257
|
|
|
260
258
|
if cached is not None:
|
|
259
|
+
cache_age_seconds = (datetime.now() - cached.created_at).total_seconds()
|
|
261
260
|
logger.info(
|
|
262
261
|
"Cache HIT",
|
|
263
262
|
layer=self.name,
|
|
264
|
-
|
|
263
|
+
question=question[:80],
|
|
265
264
|
conversation_id=conversation_id,
|
|
265
|
+
cached_sql=cached.query[:80] if cached.query else None,
|
|
266
|
+
cache_age_seconds=round(cache_age_seconds, 1),
|
|
266
267
|
cache_size=self.size,
|
|
267
268
|
capacity=self.capacity,
|
|
269
|
+
ttl_seconds=self.parameters.time_to_live_seconds,
|
|
268
270
|
)
|
|
269
271
|
|
|
270
272
|
# Re-execute the cached SQL to get fresh data
|
|
@@ -286,17 +288,19 @@ class LRUCacheService(GenieServiceBase):
|
|
|
286
288
|
logger.info(
|
|
287
289
|
"Cache MISS",
|
|
288
290
|
layer=self.name,
|
|
289
|
-
|
|
291
|
+
question=question[:80],
|
|
290
292
|
conversation_id=conversation_id,
|
|
291
293
|
cache_size=self.size,
|
|
292
294
|
capacity=self.capacity,
|
|
295
|
+
ttl_seconds=self.parameters.time_to_live_seconds,
|
|
293
296
|
delegating_to=type(self.impl).__name__,
|
|
294
297
|
)
|
|
295
298
|
|
|
296
299
|
result: CacheResult = self.impl.ask_question(question, conversation_id)
|
|
297
300
|
with self._lock:
|
|
298
301
|
self._put(key, result.response)
|
|
299
|
-
|
|
302
|
+
# Propagate the inner cache's result - if it was a hit there, preserve that info
|
|
303
|
+
return result
|
|
300
304
|
|
|
301
305
|
@property
|
|
302
306
|
def space_id(self) -> str:
|
dao_ai/genie/cache/semantic.py
CHANGED
|
@@ -497,6 +497,7 @@ class SemanticCacheService(GenieServiceBase):
|
|
|
497
497
|
conversation_context: str,
|
|
498
498
|
question_embedding: list[float],
|
|
499
499
|
context_embedding: list[float],
|
|
500
|
+
conversation_id: str | None = None,
|
|
500
501
|
) -> tuple[SQLCacheEntry, float] | None:
|
|
501
502
|
"""
|
|
502
503
|
Find a semantically similar cached entry using dual embedding matching.
|
|
@@ -509,6 +510,7 @@ class SemanticCacheService(GenieServiceBase):
|
|
|
509
510
|
conversation_context: The conversation context string
|
|
510
511
|
question_embedding: The embedding vector of just the question
|
|
511
512
|
context_embedding: The embedding vector of the conversation context
|
|
513
|
+
conversation_id: Optional conversation ID (for logging)
|
|
512
514
|
|
|
513
515
|
Returns:
|
|
514
516
|
Tuple of (SQLCacheEntry, combined_similarity_score) if found, None otherwise
|
|
@@ -576,8 +578,9 @@ class SemanticCacheService(GenieServiceBase):
|
|
|
576
578
|
logger.info(
|
|
577
579
|
"Cache MISS (no entries)",
|
|
578
580
|
layer=self.name,
|
|
579
|
-
|
|
581
|
+
question=question[:50],
|
|
580
582
|
space=self.space_id,
|
|
583
|
+
delegating_to=type(self.impl).__name__,
|
|
581
584
|
)
|
|
582
585
|
return None
|
|
583
586
|
|
|
@@ -602,8 +605,8 @@ class SemanticCacheService(GenieServiceBase):
|
|
|
602
605
|
context_sim=f"{context_similarity:.4f}",
|
|
603
606
|
combined_sim=f"{combined_similarity:.4f}",
|
|
604
607
|
is_valid=is_valid,
|
|
605
|
-
|
|
606
|
-
|
|
608
|
+
cached_question=cached_question[:50],
|
|
609
|
+
cached_context=cached_context[:80],
|
|
607
610
|
)
|
|
608
611
|
|
|
609
612
|
# Check BOTH similarity thresholds (dual embedding precision check)
|
|
@@ -613,6 +616,7 @@ class SemanticCacheService(GenieServiceBase):
|
|
|
613
616
|
layer=self.name,
|
|
614
617
|
question_sim=f"{question_similarity:.4f}",
|
|
615
618
|
threshold=self.parameters.similarity_threshold,
|
|
619
|
+
delegating_to=type(self.impl).__name__,
|
|
616
620
|
)
|
|
617
621
|
return None
|
|
618
622
|
|
|
@@ -622,6 +626,7 @@ class SemanticCacheService(GenieServiceBase):
|
|
|
622
626
|
layer=self.name,
|
|
623
627
|
context_sim=f"{context_similarity:.4f}",
|
|
624
628
|
threshold=self.parameters.context_similarity_threshold,
|
|
629
|
+
delegating_to=type(self.impl).__name__,
|
|
625
630
|
)
|
|
626
631
|
return None
|
|
627
632
|
|
|
@@ -635,17 +640,32 @@ class SemanticCacheService(GenieServiceBase):
|
|
|
635
640
|
layer=self.name,
|
|
636
641
|
combined_sim=f"{combined_similarity:.4f}",
|
|
637
642
|
ttl_seconds=ttl_seconds,
|
|
638
|
-
|
|
643
|
+
cached_question=cached_question[:50],
|
|
644
|
+
delegating_to=type(self.impl).__name__,
|
|
639
645
|
)
|
|
640
646
|
return None
|
|
641
647
|
|
|
648
|
+
from datetime import datetime as dt
|
|
649
|
+
|
|
650
|
+
cache_age_seconds = (
|
|
651
|
+
(dt.now(created_at.tzinfo) - created_at).total_seconds()
|
|
652
|
+
if created_at
|
|
653
|
+
else None
|
|
654
|
+
)
|
|
642
655
|
logger.info(
|
|
643
656
|
"Cache HIT",
|
|
644
657
|
layer=self.name,
|
|
645
|
-
|
|
646
|
-
|
|
647
|
-
|
|
648
|
-
|
|
658
|
+
question=question[:80],
|
|
659
|
+
conversation_id=conversation_id,
|
|
660
|
+
matched_question=cached_question[:80],
|
|
661
|
+
cache_age_seconds=round(cache_age_seconds, 1)
|
|
662
|
+
if cache_age_seconds
|
|
663
|
+
else None,
|
|
664
|
+
question_similarity=f"{question_similarity:.4f}",
|
|
665
|
+
context_similarity=f"{context_similarity:.4f}",
|
|
666
|
+
combined_similarity=f"{combined_similarity:.4f}",
|
|
667
|
+
cached_sql=sql_query[:80] if sql_query else None,
|
|
668
|
+
ttl_seconds=self.parameters.time_to_live_seconds,
|
|
649
669
|
)
|
|
650
670
|
|
|
651
671
|
entry = SQLCacheEntry(
|
|
@@ -696,12 +716,12 @@ class SemanticCacheService(GenieServiceBase):
|
|
|
696
716
|
response.conversation_id,
|
|
697
717
|
),
|
|
698
718
|
)
|
|
699
|
-
logger.
|
|
719
|
+
logger.debug(
|
|
700
720
|
"Stored cache entry",
|
|
701
721
|
layer=self.name,
|
|
702
|
-
|
|
703
|
-
|
|
704
|
-
|
|
722
|
+
question=question[:50],
|
|
723
|
+
context=conversation_context[:80],
|
|
724
|
+
sql=response.query[:50] if response.query else None,
|
|
705
725
|
space=self.space_id,
|
|
706
726
|
table=self.table_name,
|
|
707
727
|
)
|
|
@@ -796,7 +816,11 @@ class SemanticCacheService(GenieServiceBase):
|
|
|
796
816
|
|
|
797
817
|
# Check cache using dual embedding similarity
|
|
798
818
|
cache_result: tuple[SQLCacheEntry, float] | None = self._find_similar(
|
|
799
|
-
question,
|
|
819
|
+
question,
|
|
820
|
+
conversation_context,
|
|
821
|
+
question_embedding,
|
|
822
|
+
context_embedding,
|
|
823
|
+
conversation_id,
|
|
800
824
|
)
|
|
801
825
|
|
|
802
826
|
if cache_result is not None:
|
|
@@ -805,7 +829,8 @@ class SemanticCacheService(GenieServiceBase):
|
|
|
805
829
|
"Semantic cache hit",
|
|
806
830
|
layer=self.name,
|
|
807
831
|
combined_similarity=f"{combined_similarity:.3f}",
|
|
808
|
-
|
|
832
|
+
question=question[:50],
|
|
833
|
+
conversation_id=conversation_id,
|
|
809
834
|
)
|
|
810
835
|
|
|
811
836
|
# Re-execute the cached SQL to get fresh data
|
|
@@ -825,16 +850,25 @@ class SemanticCacheService(GenieServiceBase):
|
|
|
825
850
|
return CacheResult(response=response, cache_hit=True, served_by=self.name)
|
|
826
851
|
|
|
827
852
|
# Cache miss - delegate to wrapped service
|
|
828
|
-
logger.
|
|
853
|
+
logger.info(
|
|
854
|
+
"Cache MISS",
|
|
855
|
+
layer=self.name,
|
|
856
|
+
question=question[:80],
|
|
857
|
+
conversation_id=conversation_id,
|
|
858
|
+
space_id=self.space_id,
|
|
859
|
+
similarity_threshold=self.similarity_threshold,
|
|
860
|
+
delegating_to=type(self.impl).__name__,
|
|
861
|
+
)
|
|
829
862
|
|
|
830
863
|
result: CacheResult = self.impl.ask_question(question, conversation_id)
|
|
831
864
|
|
|
832
865
|
# Store in cache if we got a SQL query
|
|
833
866
|
if result.response.query:
|
|
834
|
-
logger.
|
|
867
|
+
logger.debug(
|
|
835
868
|
"Storing new cache entry",
|
|
836
869
|
layer=self.name,
|
|
837
|
-
|
|
870
|
+
question=question[:50],
|
|
871
|
+
conversation_id=conversation_id,
|
|
838
872
|
space=self.space_id,
|
|
839
873
|
)
|
|
840
874
|
self._store_entry(
|
|
@@ -848,7 +882,7 @@ class SemanticCacheService(GenieServiceBase):
|
|
|
848
882
|
logger.warning(
|
|
849
883
|
"Not caching: response has no SQL query",
|
|
850
884
|
layer=self.name,
|
|
851
|
-
|
|
885
|
+
question=question[:50],
|
|
852
886
|
)
|
|
853
887
|
|
|
854
888
|
return CacheResult(response=result.response, cache_hit=False, served_by=None)
|
dao_ai/memory/postgres.py
CHANGED
|
@@ -3,6 +3,7 @@ import atexit
|
|
|
3
3
|
import threading
|
|
4
4
|
from typing import Any, Optional
|
|
5
5
|
|
|
6
|
+
from databricks_ai_bridge.lakebase import AsyncLakebasePool, LakebasePool
|
|
6
7
|
from langgraph.checkpoint.base import BaseCheckpointSaver
|
|
7
8
|
from langgraph.checkpoint.postgres import ShallowPostgresSaver
|
|
8
9
|
from langgraph.checkpoint.postgres.aio import AsyncShallowPostgresSaver
|
|
@@ -86,13 +87,22 @@ async def _create_async_pool(
|
|
|
86
87
|
|
|
87
88
|
|
|
88
89
|
class AsyncPostgresPoolManager:
|
|
90
|
+
"""
|
|
91
|
+
Asynchronous PostgreSQL connection pool manager that shares pools
|
|
92
|
+
based on database configuration.
|
|
93
|
+
|
|
94
|
+
For Lakebase connections (when instance_name is provided), uses AsyncLakebasePool
|
|
95
|
+
from databricks_ai_bridge which handles automatic token rotation and host resolution.
|
|
96
|
+
For standard PostgreSQL connections, uses psycopg_pool.AsyncConnectionPool.
|
|
97
|
+
"""
|
|
98
|
+
|
|
89
99
|
_pools: dict[str, AsyncConnectionPool] = {}
|
|
100
|
+
_lakebase_pools: dict[str, AsyncLakebasePool] = {}
|
|
90
101
|
_lock: asyncio.Lock = asyncio.Lock()
|
|
91
102
|
|
|
92
103
|
@classmethod
|
|
93
104
|
async def get_pool(cls, database: DatabaseModel) -> AsyncConnectionPool:
|
|
94
105
|
connection_key: str = database.name
|
|
95
|
-
connection_params: dict[str, Any] = database.connection_params
|
|
96
106
|
|
|
97
107
|
async with cls._lock:
|
|
98
108
|
if connection_key in cls._pools:
|
|
@@ -103,19 +113,43 @@ class AsyncPostgresPoolManager:
|
|
|
103
113
|
|
|
104
114
|
logger.debug("Creating new async PostgreSQL pool", database=database.name)
|
|
105
115
|
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
116
|
+
if database.is_lakebase:
|
|
117
|
+
# Use AsyncLakebasePool for Lakebase connections
|
|
118
|
+
# AsyncLakebasePool handles automatic token rotation and host resolution
|
|
119
|
+
lakebase_pool = AsyncLakebasePool(
|
|
120
|
+
instance_name=database.instance_name,
|
|
121
|
+
workspace_client=database.workspace_client,
|
|
122
|
+
min_size=1,
|
|
123
|
+
max_size=database.max_pool_size,
|
|
124
|
+
timeout=float(database.timeout_seconds),
|
|
125
|
+
)
|
|
126
|
+
# Open the async pool
|
|
127
|
+
await lakebase_pool.open()
|
|
128
|
+
# Store the AsyncLakebasePool for proper cleanup
|
|
129
|
+
cls._lakebase_pools[connection_key] = lakebase_pool
|
|
130
|
+
# Get the underlying AsyncConnectionPool
|
|
131
|
+
pool = lakebase_pool.pool
|
|
132
|
+
logger.success(
|
|
133
|
+
"Async Lakebase connection pool created",
|
|
134
|
+
database=database.name,
|
|
135
|
+
instance_name=database.instance_name,
|
|
136
|
+
pool_size=database.max_pool_size,
|
|
137
|
+
)
|
|
138
|
+
else:
|
|
139
|
+
# Use standard async PostgreSQL pool for non-Lakebase connections
|
|
140
|
+
connection_params: dict[str, Any] = database.connection_params
|
|
141
|
+
kwargs: dict[str, Any] = {
|
|
142
|
+
"row_factory": dict_row,
|
|
143
|
+
"autocommit": True,
|
|
144
|
+
} | database.connection_kwargs or {}
|
|
145
|
+
|
|
146
|
+
pool = await _create_async_pool(
|
|
147
|
+
connection_params=connection_params,
|
|
148
|
+
database_name=database.name,
|
|
149
|
+
max_pool_size=database.max_pool_size,
|
|
150
|
+
timeout_seconds=database.timeout_seconds,
|
|
151
|
+
kwargs=kwargs,
|
|
152
|
+
)
|
|
119
153
|
|
|
120
154
|
cls._pools[connection_key] = pool
|
|
121
155
|
return pool
|
|
@@ -125,7 +159,13 @@ class AsyncPostgresPoolManager:
|
|
|
125
159
|
connection_key: str = database.name
|
|
126
160
|
|
|
127
161
|
async with cls._lock:
|
|
128
|
-
if
|
|
162
|
+
# Close AsyncLakebasePool if it exists (handles underlying pool cleanup)
|
|
163
|
+
if connection_key in cls._lakebase_pools:
|
|
164
|
+
lakebase_pool = cls._lakebase_pools.pop(connection_key)
|
|
165
|
+
await lakebase_pool.close()
|
|
166
|
+
cls._pools.pop(connection_key, None)
|
|
167
|
+
logger.debug("Async Lakebase pool closed", database=database.name)
|
|
168
|
+
elif connection_key in cls._pools:
|
|
129
169
|
pool = cls._pools.pop(connection_key)
|
|
130
170
|
await pool.close()
|
|
131
171
|
logger.debug("Async PostgreSQL pool closed", database=database.name)
|
|
@@ -133,9 +173,32 @@ class AsyncPostgresPoolManager:
|
|
|
133
173
|
@classmethod
|
|
134
174
|
async def close_all_pools(cls):
|
|
135
175
|
async with cls._lock:
|
|
176
|
+
# Close all AsyncLakebasePool instances first
|
|
177
|
+
for connection_key, lakebase_pool in cls._lakebase_pools.items():
|
|
178
|
+
try:
|
|
179
|
+
await asyncio.wait_for(lakebase_pool.close(), timeout=2.0)
|
|
180
|
+
logger.debug("Async Lakebase pool closed", pool=connection_key)
|
|
181
|
+
except asyncio.TimeoutError:
|
|
182
|
+
logger.warning(
|
|
183
|
+
"Timeout closing async Lakebase pool, forcing closure",
|
|
184
|
+
pool=connection_key,
|
|
185
|
+
)
|
|
186
|
+
except asyncio.CancelledError:
|
|
187
|
+
logger.warning(
|
|
188
|
+
"Async Lakebase pool closure cancelled (shutdown in progress)",
|
|
189
|
+
pool=connection_key,
|
|
190
|
+
)
|
|
191
|
+
except Exception as e:
|
|
192
|
+
logger.error(
|
|
193
|
+
"Error closing async Lakebase pool",
|
|
194
|
+
pool=connection_key,
|
|
195
|
+
error=str(e),
|
|
196
|
+
)
|
|
197
|
+
cls._lakebase_pools.clear()
|
|
198
|
+
|
|
199
|
+
# Close any remaining standard async PostgreSQL pools
|
|
136
200
|
for connection_key, pool in cls._pools.items():
|
|
137
201
|
try:
|
|
138
|
-
# Use a short timeout to avoid blocking on pool closure
|
|
139
202
|
await asyncio.wait_for(pool.close(), timeout=2.0)
|
|
140
203
|
logger.debug("Async PostgreSQL pool closed", pool=connection_key)
|
|
141
204
|
except asyncio.TimeoutError:
|
|
@@ -309,15 +372,19 @@ class PostgresPoolManager:
|
|
|
309
372
|
"""
|
|
310
373
|
Synchronous PostgreSQL connection pool manager that shares pools
|
|
311
374
|
based on database configuration.
|
|
375
|
+
|
|
376
|
+
For Lakebase connections (when instance_name is provided), uses LakebasePool
|
|
377
|
+
from databricks_ai_bridge which handles automatic token rotation and host resolution.
|
|
378
|
+
For standard PostgreSQL connections, uses psycopg_pool.ConnectionPool.
|
|
312
379
|
"""
|
|
313
380
|
|
|
314
381
|
_pools: dict[str, ConnectionPool] = {}
|
|
382
|
+
_lakebase_pools: dict[str, LakebasePool] = {}
|
|
315
383
|
_lock: threading.Lock = threading.Lock()
|
|
316
384
|
|
|
317
385
|
@classmethod
|
|
318
386
|
def get_pool(cls, database: DatabaseModel) -> ConnectionPool:
|
|
319
387
|
connection_key: str = str(database.name)
|
|
320
|
-
connection_params: dict[str, Any] = database.connection_params
|
|
321
388
|
|
|
322
389
|
with cls._lock:
|
|
323
390
|
if connection_key in cls._pools:
|
|
@@ -326,19 +393,41 @@ class PostgresPoolManager:
|
|
|
326
393
|
|
|
327
394
|
logger.debug("Creating new PostgreSQL pool", database=database.name)
|
|
328
395
|
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
396
|
+
if database.is_lakebase:
|
|
397
|
+
# Use LakebasePool for Lakebase connections
|
|
398
|
+
# LakebasePool handles automatic token rotation and host resolution
|
|
399
|
+
lakebase_pool = LakebasePool(
|
|
400
|
+
instance_name=database.instance_name,
|
|
401
|
+
workspace_client=database.workspace_client,
|
|
402
|
+
min_size=1,
|
|
403
|
+
max_size=database.max_pool_size,
|
|
404
|
+
timeout=float(database.timeout_seconds),
|
|
405
|
+
)
|
|
406
|
+
# Store the LakebasePool for proper cleanup
|
|
407
|
+
cls._lakebase_pools[connection_key] = lakebase_pool
|
|
408
|
+
# Get the underlying ConnectionPool
|
|
409
|
+
pool = lakebase_pool.pool
|
|
410
|
+
logger.success(
|
|
411
|
+
"Lakebase connection pool created",
|
|
412
|
+
database=database.name,
|
|
413
|
+
instance_name=database.instance_name,
|
|
414
|
+
pool_size=database.max_pool_size,
|
|
415
|
+
)
|
|
416
|
+
else:
|
|
417
|
+
# Use standard PostgreSQL pool for non-Lakebase connections
|
|
418
|
+
connection_params: dict[str, Any] = database.connection_params
|
|
419
|
+
kwargs: dict[str, Any] = {
|
|
420
|
+
"row_factory": dict_row,
|
|
421
|
+
"autocommit": True,
|
|
422
|
+
} | database.connection_kwargs or {}
|
|
423
|
+
|
|
424
|
+
pool = _create_pool(
|
|
425
|
+
connection_params=connection_params,
|
|
426
|
+
database_name=database.name,
|
|
427
|
+
max_pool_size=database.max_pool_size,
|
|
428
|
+
timeout_seconds=database.timeout_seconds,
|
|
429
|
+
kwargs=kwargs,
|
|
430
|
+
)
|
|
342
431
|
|
|
343
432
|
cls._pools[connection_key] = pool
|
|
344
433
|
return pool
|
|
@@ -348,7 +437,13 @@ class PostgresPoolManager:
|
|
|
348
437
|
connection_key: str = database.name
|
|
349
438
|
|
|
350
439
|
with cls._lock:
|
|
351
|
-
if
|
|
440
|
+
# Close LakebasePool if it exists (handles underlying pool cleanup)
|
|
441
|
+
if connection_key in cls._lakebase_pools:
|
|
442
|
+
lakebase_pool = cls._lakebase_pools.pop(connection_key)
|
|
443
|
+
lakebase_pool.close()
|
|
444
|
+
cls._pools.pop(connection_key, None)
|
|
445
|
+
logger.debug("Lakebase pool closed", database=database.name)
|
|
446
|
+
elif connection_key in cls._pools:
|
|
352
447
|
pool = cls._pools.pop(connection_key)
|
|
353
448
|
pool.close()
|
|
354
449
|
logger.debug("PostgreSQL pool closed", database=database.name)
|
|
@@ -356,16 +451,32 @@ class PostgresPoolManager:
|
|
|
356
451
|
@classmethod
|
|
357
452
|
def close_all_pools(cls):
|
|
358
453
|
with cls._lock:
|
|
359
|
-
|
|
454
|
+
# Close all LakebasePool instances first
|
|
455
|
+
for connection_key, lakebase_pool in cls._lakebase_pools.items():
|
|
360
456
|
try:
|
|
361
|
-
|
|
362
|
-
logger.debug("
|
|
457
|
+
lakebase_pool.close()
|
|
458
|
+
logger.debug("Lakebase pool closed", pool=connection_key)
|
|
363
459
|
except Exception as e:
|
|
364
460
|
logger.error(
|
|
365
|
-
"Error closing
|
|
461
|
+
"Error closing Lakebase pool",
|
|
366
462
|
pool=connection_key,
|
|
367
463
|
error=str(e),
|
|
368
464
|
)
|
|
465
|
+
cls._lakebase_pools.clear()
|
|
466
|
+
|
|
467
|
+
# Close any remaining standard PostgreSQL pools
|
|
468
|
+
for connection_key, pool in cls._pools.items():
|
|
469
|
+
# Skip if already closed via LakebasePool
|
|
470
|
+
if connection_key not in cls._lakebase_pools:
|
|
471
|
+
try:
|
|
472
|
+
pool.close()
|
|
473
|
+
logger.debug("PostgreSQL pool closed", pool=connection_key)
|
|
474
|
+
except Exception as e:
|
|
475
|
+
logger.error(
|
|
476
|
+
"Error closing PostgreSQL pool",
|
|
477
|
+
pool=connection_key,
|
|
478
|
+
error=str(e),
|
|
479
|
+
)
|
|
369
480
|
cls._pools.clear()
|
|
370
481
|
|
|
371
482
|
|
dao_ai/orchestration/core.py
CHANGED
|
@@ -9,7 +9,7 @@ This module provides the foundational utilities for multi-agent orchestration:
|
|
|
9
9
|
- Main orchestration graph factory
|
|
10
10
|
"""
|
|
11
11
|
|
|
12
|
-
from typing import Awaitable, Callable, Literal
|
|
12
|
+
from typing import Any, Awaitable, Callable, Literal
|
|
13
13
|
|
|
14
14
|
from langchain.tools import ToolRuntime, tool
|
|
15
15
|
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, ToolMessage
|
|
@@ -179,8 +179,16 @@ def create_agent_node_handler(
|
|
|
179
179
|
"messages": filtered_messages,
|
|
180
180
|
}
|
|
181
181
|
|
|
182
|
-
#
|
|
183
|
-
|
|
182
|
+
# Build config with configurable from context for langmem compatibility
|
|
183
|
+
# langmem tools expect user_id to be in config.configurable
|
|
184
|
+
config: dict[str, Any] = {}
|
|
185
|
+
if runtime.context:
|
|
186
|
+
config = {"configurable": runtime.context.model_dump()}
|
|
187
|
+
|
|
188
|
+
# Invoke the agent with both context and config
|
|
189
|
+
result: AgentState = await agent.ainvoke(
|
|
190
|
+
agent_state, context=runtime.context, config=config
|
|
191
|
+
)
|
|
184
192
|
|
|
185
193
|
# Extract agent response based on output mode
|
|
186
194
|
result_messages = result.get("messages", [])
|
|
@@ -227,15 +235,31 @@ def create_handoff_tool(
|
|
|
227
235
|
tool_call_id: str = runtime.tool_call_id
|
|
228
236
|
logger.debug("Handoff to agent", target_agent=target_agent_name)
|
|
229
237
|
|
|
238
|
+
# Get the AIMessage that triggered this handoff (required for tool_use/tool_result pairing)
|
|
239
|
+
# LLMs expect tool calls to be paired with their responses, so we must include both
|
|
240
|
+
# the AIMessage containing the tool call and the ToolMessage acknowledging it.
|
|
241
|
+
messages: list[BaseMessage] = runtime.state.get("messages", [])
|
|
242
|
+
last_ai_message: AIMessage | None = None
|
|
243
|
+
for msg in reversed(messages):
|
|
244
|
+
if isinstance(msg, AIMessage) and msg.tool_calls:
|
|
245
|
+
last_ai_message = msg
|
|
246
|
+
break
|
|
247
|
+
|
|
248
|
+
# Build message list with proper pairing
|
|
249
|
+
update_messages: list[BaseMessage] = []
|
|
250
|
+
if last_ai_message:
|
|
251
|
+
update_messages.append(last_ai_message)
|
|
252
|
+
update_messages.append(
|
|
253
|
+
ToolMessage(
|
|
254
|
+
content=f"Transferred to {target_agent_name}",
|
|
255
|
+
tool_call_id=tool_call_id,
|
|
256
|
+
)
|
|
257
|
+
)
|
|
258
|
+
|
|
230
259
|
return Command(
|
|
231
260
|
update={
|
|
232
261
|
"active_agent": target_agent_name,
|
|
233
|
-
"messages":
|
|
234
|
-
ToolMessage(
|
|
235
|
-
content=f"Transferred to {target_agent_name}",
|
|
236
|
-
tool_call_id=tool_call_id,
|
|
237
|
-
)
|
|
238
|
-
],
|
|
262
|
+
"messages": update_messages,
|
|
239
263
|
},
|
|
240
264
|
goto=target_agent_name,
|
|
241
265
|
graph=Command.PARENT,
|
|
@@ -13,7 +13,7 @@ from langchain.agents import create_agent
|
|
|
13
13
|
from langchain.agents.middleware import AgentMiddleware as LangchainAgentMiddleware
|
|
14
14
|
from langchain.tools import ToolRuntime, tool
|
|
15
15
|
from langchain_core.language_models import LanguageModelLike
|
|
16
|
-
from langchain_core.messages import ToolMessage
|
|
16
|
+
from langchain_core.messages import AIMessage, BaseMessage, ToolMessage
|
|
17
17
|
from langchain_core.tools import BaseTool
|
|
18
18
|
from langgraph.checkpoint.base import BaseCheckpointSaver
|
|
19
19
|
from langgraph.graph import StateGraph
|
|
@@ -75,15 +75,30 @@ def _create_handoff_back_to_supervisor_tool() -> BaseTool:
|
|
|
75
75
|
tool_call_id: str = runtime.tool_call_id
|
|
76
76
|
logger.debug("Agent handing back to supervisor", summary_preview=summary[:100])
|
|
77
77
|
|
|
78
|
+
# Get the AIMessage that triggered this handoff (required for tool_use/tool_result pairing)
|
|
79
|
+
# LLMs expect tool calls to be paired with their responses, so we must include both
|
|
80
|
+
# the AIMessage containing the tool call and the ToolMessage acknowledging it.
|
|
81
|
+
messages: list[BaseMessage] = runtime.state.get("messages", [])
|
|
82
|
+
last_ai_message: AIMessage | None = None
|
|
83
|
+
for msg in reversed(messages):
|
|
84
|
+
if isinstance(msg, AIMessage) and msg.tool_calls:
|
|
85
|
+
last_ai_message = msg
|
|
86
|
+
break
|
|
87
|
+
|
|
88
|
+
# Build message list with proper pairing
|
|
89
|
+
update_messages: list[BaseMessage] = []
|
|
90
|
+
if last_ai_message:
|
|
91
|
+
update_messages.append(last_ai_message)
|
|
92
|
+
update_messages.append(
|
|
93
|
+
ToolMessage(
|
|
94
|
+
content=f"Task completed: {summary}",
|
|
95
|
+
tool_call_id=tool_call_id,
|
|
96
|
+
)
|
|
97
|
+
)
|
|
98
|
+
|
|
78
99
|
return Command(
|
|
79
100
|
update={
|
|
80
|
-
"
|
|
81
|
-
"messages": [
|
|
82
|
-
ToolMessage(
|
|
83
|
-
content=f"Task completed: {summary}",
|
|
84
|
-
tool_call_id=tool_call_id,
|
|
85
|
-
)
|
|
86
|
-
],
|
|
101
|
+
"messages": update_messages,
|
|
87
102
|
},
|
|
88
103
|
goto=SUPERVISOR_NODE,
|
|
89
104
|
graph=Command.PARENT,
|
|
@@ -2,9 +2,11 @@
|
|
|
2
2
|
Prompt utilities for DAO AI agents.
|
|
3
3
|
|
|
4
4
|
This module provides utilities for creating dynamic prompts using
|
|
5
|
-
LangChain v1's @dynamic_prompt middleware decorator pattern
|
|
5
|
+
LangChain v1's @dynamic_prompt middleware decorator pattern, as well as
|
|
6
|
+
paths to prompt template files.
|
|
6
7
|
"""
|
|
7
8
|
|
|
9
|
+
from pathlib import Path
|
|
8
10
|
from typing import Any, Optional
|
|
9
11
|
|
|
10
12
|
from langchain.agents.middleware import (
|
|
@@ -18,6 +20,13 @@ from loguru import logger
|
|
|
18
20
|
from dao_ai.config import PromptModel
|
|
19
21
|
from dao_ai.state import Context
|
|
20
22
|
|
|
23
|
+
PROMPTS_DIR = Path(__file__).parent
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def get_prompt_path(name: str) -> Path:
|
|
27
|
+
"""Get the path to a prompt template file."""
|
|
28
|
+
return PROMPTS_DIR / name
|
|
29
|
+
|
|
21
30
|
|
|
22
31
|
def make_prompt(
|
|
23
32
|
base_system_prompt: Optional[str | PromptModel],
|