dao-ai 0.1.2__py3-none-any.whl → 0.1.20__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/apps/__init__.py +24 -0
- dao_ai/apps/handlers.py +105 -0
- dao_ai/apps/model_serving.py +29 -0
- dao_ai/apps/resources.py +1122 -0
- dao_ai/apps/server.py +39 -0
- dao_ai/cli.py +546 -37
- dao_ai/config.py +1179 -139
- dao_ai/evaluation.py +543 -0
- dao_ai/genie/__init__.py +55 -7
- dao_ai/genie/cache/__init__.py +34 -7
- dao_ai/genie/cache/base.py +143 -2
- dao_ai/genie/cache/context_aware/__init__.py +31 -0
- dao_ai/genie/cache/context_aware/base.py +1151 -0
- dao_ai/genie/cache/context_aware/in_memory.py +609 -0
- dao_ai/genie/cache/context_aware/persistent.py +802 -0
- dao_ai/genie/cache/context_aware/postgres.py +1166 -0
- dao_ai/genie/cache/core.py +1 -1
- dao_ai/genie/cache/lru.py +257 -75
- dao_ai/genie/cache/optimization.py +890 -0
- dao_ai/genie/core.py +235 -11
- dao_ai/memory/postgres.py +175 -39
- dao_ai/middleware/__init__.py +38 -0
- dao_ai/middleware/assertions.py +3 -3
- dao_ai/middleware/context_editing.py +230 -0
- dao_ai/middleware/core.py +4 -4
- dao_ai/middleware/guardrails.py +3 -3
- dao_ai/middleware/human_in_the_loop.py +3 -2
- dao_ai/middleware/message_validation.py +4 -4
- dao_ai/middleware/model_call_limit.py +77 -0
- dao_ai/middleware/model_retry.py +121 -0
- dao_ai/middleware/pii.py +157 -0
- dao_ai/middleware/summarization.py +1 -1
- dao_ai/middleware/tool_call_limit.py +210 -0
- dao_ai/middleware/tool_retry.py +174 -0
- dao_ai/middleware/tool_selector.py +129 -0
- dao_ai/models.py +327 -370
- dao_ai/nodes.py +9 -16
- dao_ai/orchestration/core.py +33 -9
- dao_ai/orchestration/supervisor.py +29 -13
- dao_ai/orchestration/swarm.py +6 -1
- dao_ai/{prompts.py → prompts/__init__.py} +12 -61
- dao_ai/prompts/instructed_retriever_decomposition.yaml +58 -0
- dao_ai/prompts/instruction_reranker.yaml +14 -0
- dao_ai/prompts/router.yaml +37 -0
- dao_ai/prompts/verifier.yaml +46 -0
- dao_ai/providers/base.py +28 -2
- dao_ai/providers/databricks.py +363 -33
- dao_ai/state.py +1 -0
- dao_ai/tools/__init__.py +5 -3
- dao_ai/tools/genie.py +103 -26
- dao_ai/tools/instructed_retriever.py +366 -0
- dao_ai/tools/instruction_reranker.py +202 -0
- dao_ai/tools/mcp.py +539 -97
- dao_ai/tools/router.py +89 -0
- dao_ai/tools/slack.py +13 -2
- dao_ai/tools/sql.py +7 -3
- dao_ai/tools/unity_catalog.py +32 -10
- dao_ai/tools/vector_search.py +493 -160
- dao_ai/tools/verifier.py +159 -0
- dao_ai/utils.py +182 -2
- dao_ai/vector_search.py +46 -1
- {dao_ai-0.1.2.dist-info → dao_ai-0.1.20.dist-info}/METADATA +45 -9
- dao_ai-0.1.20.dist-info/RECORD +89 -0
- dao_ai/agent_as_code.py +0 -22
- dao_ai/genie/cache/semantic.py +0 -970
- dao_ai-0.1.2.dist-info/RECORD +0 -64
- {dao_ai-0.1.2.dist-info → dao_ai-0.1.20.dist-info}/WHEEL +0 -0
- {dao_ai-0.1.2.dist-info → dao_ai-0.1.20.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.1.2.dist-info → dao_ai-0.1.20.dist-info}/licenses/LICENSE +0 -0
dao_ai/genie/core.py
CHANGED
|
@@ -1,35 +1,259 @@
|
|
|
1
1
|
"""
|
|
2
2
|
Core Genie service implementation.
|
|
3
3
|
|
|
4
|
-
This module provides
|
|
5
|
-
|
|
4
|
+
This module provides:
|
|
5
|
+
- Extended Genie and GenieResponse classes that capture message_id
|
|
6
|
+
- GenieService: Concrete implementation of GenieServiceBase
|
|
7
|
+
|
|
8
|
+
The extended classes wrap the databricks_ai_bridge versions to add message_id
|
|
9
|
+
support, which is needed for sending feedback to the Genie API.
|
|
6
10
|
"""
|
|
7
11
|
|
|
12
|
+
from __future__ import annotations
|
|
13
|
+
|
|
14
|
+
from dataclasses import dataclass
|
|
15
|
+
from typing import TYPE_CHECKING, Union
|
|
16
|
+
|
|
8
17
|
import mlflow
|
|
9
|
-
|
|
18
|
+
import pandas as pd
|
|
19
|
+
from databricks.sdk import WorkspaceClient
|
|
20
|
+
from databricks.sdk.service.dashboards import GenieFeedbackRating
|
|
21
|
+
from databricks_ai_bridge.genie import Genie as DatabricksGenie
|
|
22
|
+
from databricks_ai_bridge.genie import GenieResponse as DatabricksGenieResponse
|
|
23
|
+
from loguru import logger
|
|
10
24
|
|
|
11
25
|
from dao_ai.genie.cache import CacheResult, GenieServiceBase
|
|
26
|
+
from dao_ai.genie.cache.base import get_latest_message_id
|
|
27
|
+
|
|
28
|
+
if TYPE_CHECKING:
|
|
29
|
+
from typing import Optional
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
# =============================================================================
|
|
33
|
+
# Extended Genie Classes with message_id Support
|
|
34
|
+
# =============================================================================
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@dataclass
|
|
38
|
+
class GenieResponse(DatabricksGenieResponse):
|
|
39
|
+
"""
|
|
40
|
+
Extended GenieResponse that includes message_id.
|
|
41
|
+
|
|
42
|
+
This extends the databricks_ai_bridge GenieResponse to capture the message_id
|
|
43
|
+
from API responses, which is required for sending feedback to the Genie API.
|
|
44
|
+
|
|
45
|
+
Attributes:
|
|
46
|
+
result: The query result as string or DataFrame
|
|
47
|
+
query: The generated SQL query
|
|
48
|
+
description: Description of the query
|
|
49
|
+
conversation_id: The conversation ID
|
|
50
|
+
message_id: The message ID (NEW - enables feedback without extra API call)
|
|
51
|
+
"""
|
|
52
|
+
|
|
53
|
+
result: Union[str, pd.DataFrame] = ""
|
|
54
|
+
query: Optional[str] = ""
|
|
55
|
+
description: Optional[str] = ""
|
|
56
|
+
conversation_id: Optional[str] = None
|
|
57
|
+
message_id: Optional[str] = None
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class Genie(DatabricksGenie):
|
|
61
|
+
"""
|
|
62
|
+
Extended Genie that captures message_id in responses.
|
|
63
|
+
|
|
64
|
+
This extends the databricks_ai_bridge Genie to return GenieResponse objects
|
|
65
|
+
that include the message_id from the API response. This enables sending
|
|
66
|
+
feedback without requiring an additional API call to look up the message ID.
|
|
67
|
+
|
|
68
|
+
Usage:
|
|
69
|
+
genie = Genie(space_id="my-space")
|
|
70
|
+
response = genie.ask_question("What are total sales?")
|
|
71
|
+
print(response.message_id) # Now available!
|
|
72
|
+
|
|
73
|
+
The original databricks_ai_bridge classes are available as:
|
|
74
|
+
- DatabricksGenie
|
|
75
|
+
- DatabricksGenieResponse
|
|
76
|
+
"""
|
|
77
|
+
|
|
78
|
+
def ask_question(
|
|
79
|
+
self, question: str, conversation_id: str | None = None
|
|
80
|
+
) -> GenieResponse:
|
|
81
|
+
"""
|
|
82
|
+
Ask a question and return response with message_id.
|
|
83
|
+
|
|
84
|
+
This overrides the parent method to capture the message_id from the
|
|
85
|
+
API response and include it in the returned GenieResponse.
|
|
86
|
+
|
|
87
|
+
Args:
|
|
88
|
+
question: The question to ask
|
|
89
|
+
conversation_id: Optional conversation ID for follow-up questions
|
|
90
|
+
|
|
91
|
+
Returns:
|
|
92
|
+
GenieResponse with message_id populated
|
|
93
|
+
"""
|
|
94
|
+
with mlflow.start_span(name="ask_question"):
|
|
95
|
+
# Start or continue conversation
|
|
96
|
+
if not conversation_id:
|
|
97
|
+
resp = self.start_conversation(question)
|
|
98
|
+
else:
|
|
99
|
+
resp = self.create_message(conversation_id, question)
|
|
100
|
+
|
|
101
|
+
# Capture message_id from the API response
|
|
102
|
+
message_id = resp.get("message_id")
|
|
103
|
+
|
|
104
|
+
# Poll for the result using parent's method
|
|
105
|
+
genie_response = self.poll_for_result(resp["conversation_id"], message_id)
|
|
106
|
+
|
|
107
|
+
# Ensure conversation_id is set
|
|
108
|
+
if not genie_response.conversation_id:
|
|
109
|
+
genie_response.conversation_id = resp["conversation_id"]
|
|
110
|
+
|
|
111
|
+
# Return our extended response with message_id
|
|
112
|
+
return GenieResponse(
|
|
113
|
+
result=genie_response.result,
|
|
114
|
+
query=genie_response.query,
|
|
115
|
+
description=genie_response.description,
|
|
116
|
+
conversation_id=genie_response.conversation_id,
|
|
117
|
+
message_id=message_id,
|
|
118
|
+
)
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
# =============================================================================
|
|
122
|
+
# GenieService Implementation
|
|
123
|
+
# =============================================================================
|
|
12
124
|
|
|
13
125
|
|
|
14
126
|
class GenieService(GenieServiceBase):
|
|
15
|
-
"""
|
|
127
|
+
"""
|
|
128
|
+
Concrete implementation of GenieServiceBase using the extended Genie.
|
|
129
|
+
|
|
130
|
+
This service wraps the extended Genie class and provides the GenieServiceBase
|
|
131
|
+
interface for use with cache layers.
|
|
132
|
+
"""
|
|
16
133
|
|
|
17
134
|
genie: Genie
|
|
135
|
+
_workspace_client: WorkspaceClient | None
|
|
136
|
+
|
|
137
|
+
def __init__(
|
|
138
|
+
self,
|
|
139
|
+
genie: Genie | DatabricksGenie,
|
|
140
|
+
workspace_client: WorkspaceClient | None = None,
|
|
141
|
+
) -> None:
|
|
142
|
+
"""
|
|
143
|
+
Initialize the GenieService.
|
|
144
|
+
|
|
145
|
+
Args:
|
|
146
|
+
genie: The Genie instance for asking questions. Can be either our
|
|
147
|
+
extended Genie or the original DatabricksGenie.
|
|
148
|
+
workspace_client: Optional WorkspaceClient for feedback API.
|
|
149
|
+
If not provided, one will be created lazily when needed.
|
|
150
|
+
"""
|
|
151
|
+
self.genie = genie # type: ignore[assignment]
|
|
152
|
+
self._workspace_client = workspace_client
|
|
153
|
+
|
|
154
|
+
@property
|
|
155
|
+
def workspace_client(self) -> WorkspaceClient:
|
|
156
|
+
"""
|
|
157
|
+
Get or create a WorkspaceClient for API calls.
|
|
18
158
|
|
|
19
|
-
|
|
20
|
-
|
|
159
|
+
Lazily creates a WorkspaceClient using default credentials if not provided.
|
|
160
|
+
"""
|
|
161
|
+
if self._workspace_client is None:
|
|
162
|
+
self._workspace_client = WorkspaceClient()
|
|
163
|
+
return self._workspace_client
|
|
21
164
|
|
|
22
165
|
@mlflow.trace(name="genie_ask_question")
|
|
23
166
|
def ask_question(
|
|
24
167
|
self, question: str, conversation_id: str | None = None
|
|
25
168
|
) -> CacheResult:
|
|
26
|
-
"""
|
|
27
|
-
|
|
28
|
-
|
|
29
|
-
|
|
169
|
+
"""
|
|
170
|
+
Ask question to Genie and return CacheResult.
|
|
171
|
+
|
|
172
|
+
No caching at this level - returns cache miss with fresh response.
|
|
173
|
+
If using our extended Genie, the message_id will be captured in the response.
|
|
174
|
+
"""
|
|
175
|
+
response = self.genie.ask_question(question, conversation_id=conversation_id)
|
|
176
|
+
|
|
177
|
+
# Extract message_id if available (from our extended GenieResponse)
|
|
178
|
+
message_id = getattr(response, "message_id", None)
|
|
179
|
+
|
|
30
180
|
# No caching at this level - return cache miss
|
|
31
|
-
return CacheResult(
|
|
181
|
+
return CacheResult(
|
|
182
|
+
response=response,
|
|
183
|
+
cache_hit=False,
|
|
184
|
+
served_by=None,
|
|
185
|
+
message_id=message_id,
|
|
186
|
+
)
|
|
32
187
|
|
|
33
188
|
@property
|
|
34
189
|
def space_id(self) -> str:
|
|
35
190
|
return self.genie.space_id
|
|
191
|
+
|
|
192
|
+
@mlflow.trace(name="genie_send_feedback")
|
|
193
|
+
def send_feedback(
|
|
194
|
+
self,
|
|
195
|
+
conversation_id: str,
|
|
196
|
+
rating: GenieFeedbackRating,
|
|
197
|
+
message_id: str | None = None,
|
|
198
|
+
was_cache_hit: bool = False,
|
|
199
|
+
) -> None:
|
|
200
|
+
"""
|
|
201
|
+
Send feedback for a Genie message.
|
|
202
|
+
|
|
203
|
+
For the core GenieService, this always sends feedback to the Genie API
|
|
204
|
+
(the was_cache_hit parameter is ignored here - it's used by cache wrappers).
|
|
205
|
+
|
|
206
|
+
Args:
|
|
207
|
+
conversation_id: The conversation containing the message
|
|
208
|
+
rating: The feedback rating (POSITIVE, NEGATIVE, or NONE)
|
|
209
|
+
message_id: Optional message ID. If None, looks up the most recent message.
|
|
210
|
+
was_cache_hit: Ignored by GenieService. Cache wrappers use this to decide
|
|
211
|
+
whether to forward feedback to the underlying service.
|
|
212
|
+
"""
|
|
213
|
+
# Look up message_id if not provided
|
|
214
|
+
if message_id is None:
|
|
215
|
+
message_id = get_latest_message_id(
|
|
216
|
+
workspace_client=self.workspace_client,
|
|
217
|
+
space_id=self.space_id,
|
|
218
|
+
conversation_id=conversation_id,
|
|
219
|
+
)
|
|
220
|
+
if message_id is None:
|
|
221
|
+
logger.warning(
|
|
222
|
+
"Could not find message_id for feedback, skipping",
|
|
223
|
+
space_id=self.space_id,
|
|
224
|
+
conversation_id=conversation_id,
|
|
225
|
+
rating=rating.value if rating else None,
|
|
226
|
+
)
|
|
227
|
+
return
|
|
228
|
+
|
|
229
|
+
logger.info(
|
|
230
|
+
"Sending feedback to Genie",
|
|
231
|
+
space_id=self.space_id,
|
|
232
|
+
conversation_id=conversation_id,
|
|
233
|
+
message_id=message_id,
|
|
234
|
+
rating=rating.value if rating else None,
|
|
235
|
+
)
|
|
236
|
+
|
|
237
|
+
try:
|
|
238
|
+
self.workspace_client.genie.send_message_feedback(
|
|
239
|
+
space_id=self.space_id,
|
|
240
|
+
conversation_id=conversation_id,
|
|
241
|
+
message_id=message_id,
|
|
242
|
+
rating=rating,
|
|
243
|
+
)
|
|
244
|
+
logger.debug(
|
|
245
|
+
"Feedback sent successfully",
|
|
246
|
+
space_id=self.space_id,
|
|
247
|
+
conversation_id=conversation_id,
|
|
248
|
+
message_id=message_id,
|
|
249
|
+
)
|
|
250
|
+
except Exception as e:
|
|
251
|
+
logger.error(
|
|
252
|
+
"Failed to send feedback to Genie",
|
|
253
|
+
space_id=self.space_id,
|
|
254
|
+
conversation_id=conversation_id,
|
|
255
|
+
message_id=message_id,
|
|
256
|
+
rating=rating.value if rating else None,
|
|
257
|
+
error=str(e),
|
|
258
|
+
exc_info=True,
|
|
259
|
+
)
|
dao_ai/memory/postgres.py
CHANGED
|
@@ -3,6 +3,7 @@ import atexit
|
|
|
3
3
|
import threading
|
|
4
4
|
from typing import Any, Optional
|
|
5
5
|
|
|
6
|
+
from databricks_ai_bridge.lakebase import AsyncLakebasePool, LakebasePool
|
|
6
7
|
from langgraph.checkpoint.base import BaseCheckpointSaver
|
|
7
8
|
from langgraph.checkpoint.postgres import ShallowPostgresSaver
|
|
8
9
|
from langgraph.checkpoint.postgres.aio import AsyncShallowPostgresSaver
|
|
@@ -86,13 +87,22 @@ async def _create_async_pool(
|
|
|
86
87
|
|
|
87
88
|
|
|
88
89
|
class AsyncPostgresPoolManager:
|
|
90
|
+
"""
|
|
91
|
+
Asynchronous PostgreSQL connection pool manager that shares pools
|
|
92
|
+
based on database configuration.
|
|
93
|
+
|
|
94
|
+
For Lakebase connections (when instance_name is provided), uses AsyncLakebasePool
|
|
95
|
+
from databricks_ai_bridge which handles automatic token rotation and host resolution.
|
|
96
|
+
For standard PostgreSQL connections, uses psycopg_pool.AsyncConnectionPool.
|
|
97
|
+
"""
|
|
98
|
+
|
|
89
99
|
_pools: dict[str, AsyncConnectionPool] = {}
|
|
100
|
+
_lakebase_pools: dict[str, AsyncLakebasePool] = {}
|
|
90
101
|
_lock: asyncio.Lock = asyncio.Lock()
|
|
91
102
|
|
|
92
103
|
@classmethod
|
|
93
104
|
async def get_pool(cls, database: DatabaseModel) -> AsyncConnectionPool:
|
|
94
105
|
connection_key: str = database.name
|
|
95
|
-
connection_params: dict[str, Any] = database.connection_params
|
|
96
106
|
|
|
97
107
|
async with cls._lock:
|
|
98
108
|
if connection_key in cls._pools:
|
|
@@ -103,19 +113,43 @@ class AsyncPostgresPoolManager:
|
|
|
103
113
|
|
|
104
114
|
logger.debug("Creating new async PostgreSQL pool", database=database.name)
|
|
105
115
|
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
116
|
+
if database.is_lakebase:
|
|
117
|
+
# Use AsyncLakebasePool for Lakebase connections
|
|
118
|
+
# AsyncLakebasePool handles automatic token rotation and host resolution
|
|
119
|
+
lakebase_pool = AsyncLakebasePool(
|
|
120
|
+
instance_name=database.instance_name,
|
|
121
|
+
workspace_client=database.workspace_client,
|
|
122
|
+
min_size=1,
|
|
123
|
+
max_size=database.max_pool_size,
|
|
124
|
+
timeout=float(database.timeout_seconds),
|
|
125
|
+
)
|
|
126
|
+
# Open the async pool
|
|
127
|
+
await lakebase_pool.open()
|
|
128
|
+
# Store the AsyncLakebasePool for proper cleanup
|
|
129
|
+
cls._lakebase_pools[connection_key] = lakebase_pool
|
|
130
|
+
# Get the underlying AsyncConnectionPool
|
|
131
|
+
pool = lakebase_pool.pool
|
|
132
|
+
logger.success(
|
|
133
|
+
"Async Lakebase connection pool created",
|
|
134
|
+
database=database.name,
|
|
135
|
+
instance_name=database.instance_name,
|
|
136
|
+
pool_size=database.max_pool_size,
|
|
137
|
+
)
|
|
138
|
+
else:
|
|
139
|
+
# Use standard async PostgreSQL pool for non-Lakebase connections
|
|
140
|
+
connection_params: dict[str, Any] = database.connection_params
|
|
141
|
+
kwargs: dict[str, Any] = {
|
|
142
|
+
"row_factory": dict_row,
|
|
143
|
+
"autocommit": True,
|
|
144
|
+
} | database.connection_kwargs or {}
|
|
145
|
+
|
|
146
|
+
pool = await _create_async_pool(
|
|
147
|
+
connection_params=connection_params,
|
|
148
|
+
database_name=database.name,
|
|
149
|
+
max_pool_size=database.max_pool_size,
|
|
150
|
+
timeout_seconds=database.timeout_seconds,
|
|
151
|
+
kwargs=kwargs,
|
|
152
|
+
)
|
|
119
153
|
|
|
120
154
|
cls._pools[connection_key] = pool
|
|
121
155
|
return pool
|
|
@@ -125,7 +159,13 @@ class AsyncPostgresPoolManager:
|
|
|
125
159
|
connection_key: str = database.name
|
|
126
160
|
|
|
127
161
|
async with cls._lock:
|
|
128
|
-
if
|
|
162
|
+
# Close AsyncLakebasePool if it exists (handles underlying pool cleanup)
|
|
163
|
+
if connection_key in cls._lakebase_pools:
|
|
164
|
+
lakebase_pool = cls._lakebase_pools.pop(connection_key)
|
|
165
|
+
await lakebase_pool.close()
|
|
166
|
+
cls._pools.pop(connection_key, None)
|
|
167
|
+
logger.debug("Async Lakebase pool closed", database=database.name)
|
|
168
|
+
elif connection_key in cls._pools:
|
|
129
169
|
pool = cls._pools.pop(connection_key)
|
|
130
170
|
await pool.close()
|
|
131
171
|
logger.debug("Async PostgreSQL pool closed", database=database.name)
|
|
@@ -133,9 +173,32 @@ class AsyncPostgresPoolManager:
|
|
|
133
173
|
@classmethod
|
|
134
174
|
async def close_all_pools(cls):
|
|
135
175
|
async with cls._lock:
|
|
176
|
+
# Close all AsyncLakebasePool instances first
|
|
177
|
+
for connection_key, lakebase_pool in cls._lakebase_pools.items():
|
|
178
|
+
try:
|
|
179
|
+
await asyncio.wait_for(lakebase_pool.close(), timeout=2.0)
|
|
180
|
+
logger.debug("Async Lakebase pool closed", pool=connection_key)
|
|
181
|
+
except asyncio.TimeoutError:
|
|
182
|
+
logger.warning(
|
|
183
|
+
"Timeout closing async Lakebase pool, forcing closure",
|
|
184
|
+
pool=connection_key,
|
|
185
|
+
)
|
|
186
|
+
except asyncio.CancelledError:
|
|
187
|
+
logger.warning(
|
|
188
|
+
"Async Lakebase pool closure cancelled (shutdown in progress)",
|
|
189
|
+
pool=connection_key,
|
|
190
|
+
)
|
|
191
|
+
except Exception as e:
|
|
192
|
+
logger.error(
|
|
193
|
+
"Error closing async Lakebase pool",
|
|
194
|
+
pool=connection_key,
|
|
195
|
+
error=str(e),
|
|
196
|
+
)
|
|
197
|
+
cls._lakebase_pools.clear()
|
|
198
|
+
|
|
199
|
+
# Close any remaining standard async PostgreSQL pools
|
|
136
200
|
for connection_key, pool in cls._pools.items():
|
|
137
201
|
try:
|
|
138
|
-
# Use a short timeout to avoid blocking on pool closure
|
|
139
202
|
await asyncio.wait_for(pool.close(), timeout=2.0)
|
|
140
203
|
logger.debug("Async PostgreSQL pool closed", pool=connection_key)
|
|
141
204
|
except asyncio.TimeoutError:
|
|
@@ -178,7 +241,20 @@ class AsyncPostgresStoreManager(StoreManagerBase):
|
|
|
178
241
|
def _setup(self):
|
|
179
242
|
if self._setup_complete:
|
|
180
243
|
return
|
|
181
|
-
|
|
244
|
+
try:
|
|
245
|
+
# Check if we're already in an async context
|
|
246
|
+
asyncio.get_running_loop()
|
|
247
|
+
# If we get here, we're in an async context - raise to caller
|
|
248
|
+
raise RuntimeError(
|
|
249
|
+
"Cannot call sync _setup() from async context. "
|
|
250
|
+
"Use await _async_setup() instead."
|
|
251
|
+
)
|
|
252
|
+
except RuntimeError as e:
|
|
253
|
+
if "no running event loop" in str(e).lower():
|
|
254
|
+
# No event loop running - safe to use asyncio.run()
|
|
255
|
+
asyncio.run(self._async_setup())
|
|
256
|
+
else:
|
|
257
|
+
raise
|
|
182
258
|
|
|
183
259
|
async def _async_setup(self):
|
|
184
260
|
if self._setup_complete:
|
|
@@ -237,13 +313,25 @@ class AsyncPostgresCheckpointerManager(CheckpointManagerBase):
|
|
|
237
313
|
|
|
238
314
|
def _setup(self):
|
|
239
315
|
"""
|
|
240
|
-
Run the async setup.
|
|
316
|
+
Run the async setup. For async contexts, use await _async_setup() directly.
|
|
241
317
|
"""
|
|
242
318
|
if self._setup_complete:
|
|
243
319
|
return
|
|
244
320
|
|
|
245
|
-
|
|
246
|
-
|
|
321
|
+
try:
|
|
322
|
+
# Check if we're already in an async context
|
|
323
|
+
asyncio.get_running_loop()
|
|
324
|
+
# If we get here, we're in an async context - raise to caller
|
|
325
|
+
raise RuntimeError(
|
|
326
|
+
"Cannot call sync _setup() from async context. "
|
|
327
|
+
"Use await _async_setup() instead."
|
|
328
|
+
)
|
|
329
|
+
except RuntimeError as e:
|
|
330
|
+
if "no running event loop" in str(e).lower():
|
|
331
|
+
# No event loop running - safe to use asyncio.run()
|
|
332
|
+
asyncio.run(self._async_setup())
|
|
333
|
+
else:
|
|
334
|
+
raise
|
|
247
335
|
|
|
248
336
|
async def _async_setup(self):
|
|
249
337
|
"""
|
|
@@ -284,15 +372,19 @@ class PostgresPoolManager:
|
|
|
284
372
|
"""
|
|
285
373
|
Synchronous PostgreSQL connection pool manager that shares pools
|
|
286
374
|
based on database configuration.
|
|
375
|
+
|
|
376
|
+
For Lakebase connections (when instance_name is provided), uses LakebasePool
|
|
377
|
+
from databricks_ai_bridge which handles automatic token rotation and host resolution.
|
|
378
|
+
For standard PostgreSQL connections, uses psycopg_pool.ConnectionPool.
|
|
287
379
|
"""
|
|
288
380
|
|
|
289
381
|
_pools: dict[str, ConnectionPool] = {}
|
|
382
|
+
_lakebase_pools: dict[str, LakebasePool] = {}
|
|
290
383
|
_lock: threading.Lock = threading.Lock()
|
|
291
384
|
|
|
292
385
|
@classmethod
|
|
293
386
|
def get_pool(cls, database: DatabaseModel) -> ConnectionPool:
|
|
294
387
|
connection_key: str = str(database.name)
|
|
295
|
-
connection_params: dict[str, Any] = database.connection_params
|
|
296
388
|
|
|
297
389
|
with cls._lock:
|
|
298
390
|
if connection_key in cls._pools:
|
|
@@ -301,19 +393,41 @@ class PostgresPoolManager:
|
|
|
301
393
|
|
|
302
394
|
logger.debug("Creating new PostgreSQL pool", database=database.name)
|
|
303
395
|
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
396
|
+
if database.is_lakebase:
|
|
397
|
+
# Use LakebasePool for Lakebase connections
|
|
398
|
+
# LakebasePool handles automatic token rotation and host resolution
|
|
399
|
+
lakebase_pool = LakebasePool(
|
|
400
|
+
instance_name=database.instance_name,
|
|
401
|
+
workspace_client=database.workspace_client,
|
|
402
|
+
min_size=1,
|
|
403
|
+
max_size=database.max_pool_size,
|
|
404
|
+
timeout=float(database.timeout_seconds),
|
|
405
|
+
)
|
|
406
|
+
# Store the LakebasePool for proper cleanup
|
|
407
|
+
cls._lakebase_pools[connection_key] = lakebase_pool
|
|
408
|
+
# Get the underlying ConnectionPool
|
|
409
|
+
pool = lakebase_pool.pool
|
|
410
|
+
logger.success(
|
|
411
|
+
"Lakebase connection pool created",
|
|
412
|
+
database=database.name,
|
|
413
|
+
instance_name=database.instance_name,
|
|
414
|
+
pool_size=database.max_pool_size,
|
|
415
|
+
)
|
|
416
|
+
else:
|
|
417
|
+
# Use standard PostgreSQL pool for non-Lakebase connections
|
|
418
|
+
connection_params: dict[str, Any] = database.connection_params
|
|
419
|
+
kwargs: dict[str, Any] = {
|
|
420
|
+
"row_factory": dict_row,
|
|
421
|
+
"autocommit": True,
|
|
422
|
+
} | database.connection_kwargs or {}
|
|
423
|
+
|
|
424
|
+
pool = _create_pool(
|
|
425
|
+
connection_params=connection_params,
|
|
426
|
+
database_name=database.name,
|
|
427
|
+
max_pool_size=database.max_pool_size,
|
|
428
|
+
timeout_seconds=database.timeout_seconds,
|
|
429
|
+
kwargs=kwargs,
|
|
430
|
+
)
|
|
317
431
|
|
|
318
432
|
cls._pools[connection_key] = pool
|
|
319
433
|
return pool
|
|
@@ -323,7 +437,13 @@ class PostgresPoolManager:
|
|
|
323
437
|
connection_key: str = database.name
|
|
324
438
|
|
|
325
439
|
with cls._lock:
|
|
326
|
-
if
|
|
440
|
+
# Close LakebasePool if it exists (handles underlying pool cleanup)
|
|
441
|
+
if connection_key in cls._lakebase_pools:
|
|
442
|
+
lakebase_pool = cls._lakebase_pools.pop(connection_key)
|
|
443
|
+
lakebase_pool.close()
|
|
444
|
+
cls._pools.pop(connection_key, None)
|
|
445
|
+
logger.debug("Lakebase pool closed", database=database.name)
|
|
446
|
+
elif connection_key in cls._pools:
|
|
327
447
|
pool = cls._pools.pop(connection_key)
|
|
328
448
|
pool.close()
|
|
329
449
|
logger.debug("PostgreSQL pool closed", database=database.name)
|
|
@@ -331,16 +451,32 @@ class PostgresPoolManager:
|
|
|
331
451
|
@classmethod
|
|
332
452
|
def close_all_pools(cls):
|
|
333
453
|
with cls._lock:
|
|
334
|
-
|
|
454
|
+
# Close all LakebasePool instances first
|
|
455
|
+
for connection_key, lakebase_pool in cls._lakebase_pools.items():
|
|
335
456
|
try:
|
|
336
|
-
|
|
337
|
-
logger.debug("
|
|
457
|
+
lakebase_pool.close()
|
|
458
|
+
logger.debug("Lakebase pool closed", pool=connection_key)
|
|
338
459
|
except Exception as e:
|
|
339
460
|
logger.error(
|
|
340
|
-
"Error closing
|
|
461
|
+
"Error closing Lakebase pool",
|
|
341
462
|
pool=connection_key,
|
|
342
463
|
error=str(e),
|
|
343
464
|
)
|
|
465
|
+
cls._lakebase_pools.clear()
|
|
466
|
+
|
|
467
|
+
# Close any remaining standard PostgreSQL pools
|
|
468
|
+
for connection_key, pool in cls._pools.items():
|
|
469
|
+
# Skip if already closed via LakebasePool
|
|
470
|
+
if connection_key not in cls._lakebase_pools:
|
|
471
|
+
try:
|
|
472
|
+
pool.close()
|
|
473
|
+
logger.debug("PostgreSQL pool closed", pool=connection_key)
|
|
474
|
+
except Exception as e:
|
|
475
|
+
logger.error(
|
|
476
|
+
"Error closing PostgreSQL pool",
|
|
477
|
+
pool=connection_key,
|
|
478
|
+
error=str(e),
|
|
479
|
+
)
|
|
344
480
|
cls._pools.clear()
|
|
345
481
|
|
|
346
482
|
|
dao_ai/middleware/__init__.py
CHANGED
|
@@ -3,8 +3,16 @@
|
|
|
3
3
|
|
|
4
4
|
# Re-export LangChain built-in middleware
|
|
5
5
|
from langchain.agents.middleware import (
|
|
6
|
+
ClearToolUsesEdit,
|
|
7
|
+
ContextEditingMiddleware,
|
|
6
8
|
HumanInTheLoopMiddleware,
|
|
9
|
+
LLMToolSelectorMiddleware,
|
|
10
|
+
ModelCallLimitMiddleware,
|
|
11
|
+
ModelRetryMiddleware,
|
|
12
|
+
PIIMiddleware,
|
|
7
13
|
SummarizationMiddleware,
|
|
14
|
+
ToolCallLimitMiddleware,
|
|
15
|
+
ToolRetryMiddleware,
|
|
8
16
|
after_agent,
|
|
9
17
|
after_model,
|
|
10
18
|
before_agent,
|
|
@@ -37,6 +45,10 @@ from dao_ai.middleware.base import (
|
|
|
37
45
|
ModelRequest,
|
|
38
46
|
ModelResponse,
|
|
39
47
|
)
|
|
48
|
+
from dao_ai.middleware.context_editing import (
|
|
49
|
+
create_clear_tool_uses_edit,
|
|
50
|
+
create_context_editing_middleware,
|
|
51
|
+
)
|
|
40
52
|
from dao_ai.middleware.core import create_factory_middleware
|
|
41
53
|
from dao_ai.middleware.guardrails import (
|
|
42
54
|
ContentFilterMiddleware,
|
|
@@ -62,10 +74,16 @@ from dao_ai.middleware.message_validation import (
|
|
|
62
74
|
create_thread_id_validation_middleware,
|
|
63
75
|
create_user_id_validation_middleware,
|
|
64
76
|
)
|
|
77
|
+
from dao_ai.middleware.model_call_limit import create_model_call_limit_middleware
|
|
78
|
+
from dao_ai.middleware.model_retry import create_model_retry_middleware
|
|
79
|
+
from dao_ai.middleware.pii import create_pii_middleware
|
|
65
80
|
from dao_ai.middleware.summarization import (
|
|
66
81
|
LoggingSummarizationMiddleware,
|
|
67
82
|
create_summarization_middleware,
|
|
68
83
|
)
|
|
84
|
+
from dao_ai.middleware.tool_call_limit import create_tool_call_limit_middleware
|
|
85
|
+
from dao_ai.middleware.tool_retry import create_tool_retry_middleware
|
|
86
|
+
from dao_ai.middleware.tool_selector import create_llm_tool_selector_middleware
|
|
69
87
|
|
|
70
88
|
__all__ = [
|
|
71
89
|
# Base class (from LangChain)
|
|
@@ -85,6 +103,14 @@ __all__ = [
|
|
|
85
103
|
"SummarizationMiddleware",
|
|
86
104
|
"LoggingSummarizationMiddleware",
|
|
87
105
|
"HumanInTheLoopMiddleware",
|
|
106
|
+
"ToolCallLimitMiddleware",
|
|
107
|
+
"ModelCallLimitMiddleware",
|
|
108
|
+
"ToolRetryMiddleware",
|
|
109
|
+
"ModelRetryMiddleware",
|
|
110
|
+
"LLMToolSelectorMiddleware",
|
|
111
|
+
"ContextEditingMiddleware",
|
|
112
|
+
"ClearToolUsesEdit",
|
|
113
|
+
"PIIMiddleware",
|
|
88
114
|
# Core factory function
|
|
89
115
|
"create_factory_middleware",
|
|
90
116
|
# DAO AI middleware implementations
|
|
@@ -122,4 +148,16 @@ __all__ = [
|
|
|
122
148
|
"create_assert_middleware",
|
|
123
149
|
"create_suggest_middleware",
|
|
124
150
|
"create_refine_middleware",
|
|
151
|
+
# Limit and retry middleware factory functions
|
|
152
|
+
"create_tool_call_limit_middleware",
|
|
153
|
+
"create_model_call_limit_middleware",
|
|
154
|
+
"create_tool_retry_middleware",
|
|
155
|
+
"create_model_retry_middleware",
|
|
156
|
+
# Tool selection middleware factory functions
|
|
157
|
+
"create_llm_tool_selector_middleware",
|
|
158
|
+
# Context editing middleware factory functions
|
|
159
|
+
"create_context_editing_middleware",
|
|
160
|
+
"create_clear_tool_uses_edit",
|
|
161
|
+
# PII middleware factory functions
|
|
162
|
+
"create_pii_middleware",
|
|
125
163
|
]
|
dao_ai/middleware/assertions.py
CHANGED
|
@@ -688,7 +688,7 @@ def create_assert_middleware(
|
|
|
688
688
|
name: Name for function constraints
|
|
689
689
|
|
|
690
690
|
Returns:
|
|
691
|
-
AssertMiddleware configured with the constraint
|
|
691
|
+
List containing AssertMiddleware configured with the constraint
|
|
692
692
|
|
|
693
693
|
Example:
|
|
694
694
|
# Using a Constraint class
|
|
@@ -737,7 +737,7 @@ def create_suggest_middleware(
|
|
|
737
737
|
name: Name for function constraints
|
|
738
738
|
|
|
739
739
|
Returns:
|
|
740
|
-
SuggestMiddleware configured with the constraint
|
|
740
|
+
List containing SuggestMiddleware configured with the constraint
|
|
741
741
|
|
|
742
742
|
Example:
|
|
743
743
|
def is_professional(response: str, ctx: dict) -> ConstraintResult:
|
|
@@ -783,7 +783,7 @@ def create_refine_middleware(
|
|
|
783
783
|
select_best: Track and return best response across iterations
|
|
784
784
|
|
|
785
785
|
Returns:
|
|
786
|
-
RefineMiddleware configured with the reward function
|
|
786
|
+
List containing RefineMiddleware configured with the reward function
|
|
787
787
|
|
|
788
788
|
Example:
|
|
789
789
|
def evaluate_completeness(response: str, ctx: dict) -> float:
|