dao-ai 0.1.2__py3-none-any.whl → 0.1.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. dao_ai/apps/__init__.py +24 -0
  2. dao_ai/apps/handlers.py +105 -0
  3. dao_ai/apps/model_serving.py +29 -0
  4. dao_ai/apps/resources.py +1122 -0
  5. dao_ai/apps/server.py +39 -0
  6. dao_ai/cli.py +546 -37
  7. dao_ai/config.py +1179 -139
  8. dao_ai/evaluation.py +543 -0
  9. dao_ai/genie/__init__.py +55 -7
  10. dao_ai/genie/cache/__init__.py +34 -7
  11. dao_ai/genie/cache/base.py +143 -2
  12. dao_ai/genie/cache/context_aware/__init__.py +31 -0
  13. dao_ai/genie/cache/context_aware/base.py +1151 -0
  14. dao_ai/genie/cache/context_aware/in_memory.py +609 -0
  15. dao_ai/genie/cache/context_aware/persistent.py +802 -0
  16. dao_ai/genie/cache/context_aware/postgres.py +1166 -0
  17. dao_ai/genie/cache/core.py +1 -1
  18. dao_ai/genie/cache/lru.py +257 -75
  19. dao_ai/genie/cache/optimization.py +890 -0
  20. dao_ai/genie/core.py +235 -11
  21. dao_ai/memory/postgres.py +175 -39
  22. dao_ai/middleware/__init__.py +38 -0
  23. dao_ai/middleware/assertions.py +3 -3
  24. dao_ai/middleware/context_editing.py +230 -0
  25. dao_ai/middleware/core.py +4 -4
  26. dao_ai/middleware/guardrails.py +3 -3
  27. dao_ai/middleware/human_in_the_loop.py +3 -2
  28. dao_ai/middleware/message_validation.py +4 -4
  29. dao_ai/middleware/model_call_limit.py +77 -0
  30. dao_ai/middleware/model_retry.py +121 -0
  31. dao_ai/middleware/pii.py +157 -0
  32. dao_ai/middleware/summarization.py +1 -1
  33. dao_ai/middleware/tool_call_limit.py +210 -0
  34. dao_ai/middleware/tool_retry.py +174 -0
  35. dao_ai/middleware/tool_selector.py +129 -0
  36. dao_ai/models.py +327 -370
  37. dao_ai/nodes.py +9 -16
  38. dao_ai/orchestration/core.py +33 -9
  39. dao_ai/orchestration/supervisor.py +29 -13
  40. dao_ai/orchestration/swarm.py +6 -1
  41. dao_ai/{prompts.py → prompts/__init__.py} +12 -61
  42. dao_ai/prompts/instructed_retriever_decomposition.yaml +58 -0
  43. dao_ai/prompts/instruction_reranker.yaml +14 -0
  44. dao_ai/prompts/router.yaml +37 -0
  45. dao_ai/prompts/verifier.yaml +46 -0
  46. dao_ai/providers/base.py +28 -2
  47. dao_ai/providers/databricks.py +363 -33
  48. dao_ai/state.py +1 -0
  49. dao_ai/tools/__init__.py +5 -3
  50. dao_ai/tools/genie.py +103 -26
  51. dao_ai/tools/instructed_retriever.py +366 -0
  52. dao_ai/tools/instruction_reranker.py +202 -0
  53. dao_ai/tools/mcp.py +539 -97
  54. dao_ai/tools/router.py +89 -0
  55. dao_ai/tools/slack.py +13 -2
  56. dao_ai/tools/sql.py +7 -3
  57. dao_ai/tools/unity_catalog.py +32 -10
  58. dao_ai/tools/vector_search.py +493 -160
  59. dao_ai/tools/verifier.py +159 -0
  60. dao_ai/utils.py +182 -2
  61. dao_ai/vector_search.py +46 -1
  62. {dao_ai-0.1.2.dist-info → dao_ai-0.1.20.dist-info}/METADATA +45 -9
  63. dao_ai-0.1.20.dist-info/RECORD +89 -0
  64. dao_ai/agent_as_code.py +0 -22
  65. dao_ai/genie/cache/semantic.py +0 -970
  66. dao_ai-0.1.2.dist-info/RECORD +0 -64
  67. {dao_ai-0.1.2.dist-info → dao_ai-0.1.20.dist-info}/WHEEL +0 -0
  68. {dao_ai-0.1.2.dist-info → dao_ai-0.1.20.dist-info}/entry_points.txt +0 -0
  69. {dao_ai-0.1.2.dist-info → dao_ai-0.1.20.dist-info}/licenses/LICENSE +0 -0
dao_ai/genie/core.py CHANGED
@@ -1,35 +1,259 @@
1
1
  """
2
2
  Core Genie service implementation.
3
3
 
4
- This module provides the concrete implementation of GenieServiceBase
5
- that wraps the Databricks Genie SDK.
4
+ This module provides:
5
+ - Extended Genie and GenieResponse classes that capture message_id
6
+ - GenieService: Concrete implementation of GenieServiceBase
7
+
8
+ The extended classes wrap the databricks_ai_bridge versions to add message_id
9
+ support, which is needed for sending feedback to the Genie API.
6
10
  """
7
11
 
12
+ from __future__ import annotations
13
+
14
+ from dataclasses import dataclass
15
+ from typing import TYPE_CHECKING, Union
16
+
8
17
  import mlflow
9
- from databricks_ai_bridge.genie import Genie, GenieResponse
18
+ import pandas as pd
19
+ from databricks.sdk import WorkspaceClient
20
+ from databricks.sdk.service.dashboards import GenieFeedbackRating
21
+ from databricks_ai_bridge.genie import Genie as DatabricksGenie
22
+ from databricks_ai_bridge.genie import GenieResponse as DatabricksGenieResponse
23
+ from loguru import logger
10
24
 
11
25
  from dao_ai.genie.cache import CacheResult, GenieServiceBase
26
+ from dao_ai.genie.cache.base import get_latest_message_id
27
+
28
+ if TYPE_CHECKING:
29
+ from typing import Optional
30
+
31
+
32
+ # =============================================================================
33
+ # Extended Genie Classes with message_id Support
34
+ # =============================================================================
35
+
36
+
37
+ @dataclass
38
+ class GenieResponse(DatabricksGenieResponse):
39
+ """
40
+ Extended GenieResponse that includes message_id.
41
+
42
+ This extends the databricks_ai_bridge GenieResponse to capture the message_id
43
+ from API responses, which is required for sending feedback to the Genie API.
44
+
45
+ Attributes:
46
+ result: The query result as string or DataFrame
47
+ query: The generated SQL query
48
+ description: Description of the query
49
+ conversation_id: The conversation ID
50
+ message_id: The message ID (NEW - enables feedback without extra API call)
51
+ """
52
+
53
+ result: Union[str, pd.DataFrame] = ""
54
+ query: Optional[str] = ""
55
+ description: Optional[str] = ""
56
+ conversation_id: Optional[str] = None
57
+ message_id: Optional[str] = None
58
+
59
+
60
+ class Genie(DatabricksGenie):
61
+ """
62
+ Extended Genie that captures message_id in responses.
63
+
64
+ This extends the databricks_ai_bridge Genie to return GenieResponse objects
65
+ that include the message_id from the API response. This enables sending
66
+ feedback without requiring an additional API call to look up the message ID.
67
+
68
+ Usage:
69
+ genie = Genie(space_id="my-space")
70
+ response = genie.ask_question("What are total sales?")
71
+ print(response.message_id) # Now available!
72
+
73
+ The original databricks_ai_bridge classes are available as:
74
+ - DatabricksGenie
75
+ - DatabricksGenieResponse
76
+ """
77
+
78
+ def ask_question(
79
+ self, question: str, conversation_id: str | None = None
80
+ ) -> GenieResponse:
81
+ """
82
+ Ask a question and return response with message_id.
83
+
84
+ This overrides the parent method to capture the message_id from the
85
+ API response and include it in the returned GenieResponse.
86
+
87
+ Args:
88
+ question: The question to ask
89
+ conversation_id: Optional conversation ID for follow-up questions
90
+
91
+ Returns:
92
+ GenieResponse with message_id populated
93
+ """
94
+ with mlflow.start_span(name="ask_question"):
95
+ # Start or continue conversation
96
+ if not conversation_id:
97
+ resp = self.start_conversation(question)
98
+ else:
99
+ resp = self.create_message(conversation_id, question)
100
+
101
+ # Capture message_id from the API response
102
+ message_id = resp.get("message_id")
103
+
104
+ # Poll for the result using parent's method
105
+ genie_response = self.poll_for_result(resp["conversation_id"], message_id)
106
+
107
+ # Ensure conversation_id is set
108
+ if not genie_response.conversation_id:
109
+ genie_response.conversation_id = resp["conversation_id"]
110
+
111
+ # Return our extended response with message_id
112
+ return GenieResponse(
113
+ result=genie_response.result,
114
+ query=genie_response.query,
115
+ description=genie_response.description,
116
+ conversation_id=genie_response.conversation_id,
117
+ message_id=message_id,
118
+ )
119
+
120
+
121
+ # =============================================================================
122
+ # GenieService Implementation
123
+ # =============================================================================
12
124
 
13
125
 
14
126
  class GenieService(GenieServiceBase):
15
- """Concrete implementation of GenieServiceBase using the Genie SDK."""
127
+ """
128
+ Concrete implementation of GenieServiceBase using the extended Genie.
129
+
130
+ This service wraps the extended Genie class and provides the GenieServiceBase
131
+ interface for use with cache layers.
132
+ """
16
133
 
17
134
  genie: Genie
135
+ _workspace_client: WorkspaceClient | None
136
+
137
+ def __init__(
138
+ self,
139
+ genie: Genie | DatabricksGenie,
140
+ workspace_client: WorkspaceClient | None = None,
141
+ ) -> None:
142
+ """
143
+ Initialize the GenieService.
144
+
145
+ Args:
146
+ genie: The Genie instance for asking questions. Can be either our
147
+ extended Genie or the original DatabricksGenie.
148
+ workspace_client: Optional WorkspaceClient for feedback API.
149
+ If not provided, one will be created lazily when needed.
150
+ """
151
+ self.genie = genie # type: ignore[assignment]
152
+ self._workspace_client = workspace_client
153
+
154
+ @property
155
+ def workspace_client(self) -> WorkspaceClient:
156
+ """
157
+ Get or create a WorkspaceClient for API calls.
18
158
 
19
- def __init__(self, genie: Genie) -> None:
20
- self.genie = genie
159
+ Lazily creates a WorkspaceClient using default credentials if not provided.
160
+ """
161
+ if self._workspace_client is None:
162
+ self._workspace_client = WorkspaceClient()
163
+ return self._workspace_client
21
164
 
22
165
  @mlflow.trace(name="genie_ask_question")
23
166
  def ask_question(
24
167
  self, question: str, conversation_id: str | None = None
25
168
  ) -> CacheResult:
26
- """Ask question to Genie and return CacheResult (no caching at this level)."""
27
- response: GenieResponse = self.genie.ask_question(
28
- question, conversation_id=conversation_id
29
- )
169
+ """
170
+ Ask question to Genie and return CacheResult.
171
+
172
+ No caching at this level - returns cache miss with fresh response.
173
+ If using our extended Genie, the message_id will be captured in the response.
174
+ """
175
+ response = self.genie.ask_question(question, conversation_id=conversation_id)
176
+
177
+ # Extract message_id if available (from our extended GenieResponse)
178
+ message_id = getattr(response, "message_id", None)
179
+
30
180
  # No caching at this level - return cache miss
31
- return CacheResult(response=response, cache_hit=False, served_by=None)
181
+ return CacheResult(
182
+ response=response,
183
+ cache_hit=False,
184
+ served_by=None,
185
+ message_id=message_id,
186
+ )
32
187
 
33
188
  @property
34
189
  def space_id(self) -> str:
35
190
  return self.genie.space_id
191
+
192
+ @mlflow.trace(name="genie_send_feedback")
193
+ def send_feedback(
194
+ self,
195
+ conversation_id: str,
196
+ rating: GenieFeedbackRating,
197
+ message_id: str | None = None,
198
+ was_cache_hit: bool = False,
199
+ ) -> None:
200
+ """
201
+ Send feedback for a Genie message.
202
+
203
+ For the core GenieService, this always sends feedback to the Genie API
204
+ (the was_cache_hit parameter is ignored here - it's used by cache wrappers).
205
+
206
+ Args:
207
+ conversation_id: The conversation containing the message
208
+ rating: The feedback rating (POSITIVE, NEGATIVE, or NONE)
209
+ message_id: Optional message ID. If None, looks up the most recent message.
210
+ was_cache_hit: Ignored by GenieService. Cache wrappers use this to decide
211
+ whether to forward feedback to the underlying service.
212
+ """
213
+ # Look up message_id if not provided
214
+ if message_id is None:
215
+ message_id = get_latest_message_id(
216
+ workspace_client=self.workspace_client,
217
+ space_id=self.space_id,
218
+ conversation_id=conversation_id,
219
+ )
220
+ if message_id is None:
221
+ logger.warning(
222
+ "Could not find message_id for feedback, skipping",
223
+ space_id=self.space_id,
224
+ conversation_id=conversation_id,
225
+ rating=rating.value if rating else None,
226
+ )
227
+ return
228
+
229
+ logger.info(
230
+ "Sending feedback to Genie",
231
+ space_id=self.space_id,
232
+ conversation_id=conversation_id,
233
+ message_id=message_id,
234
+ rating=rating.value if rating else None,
235
+ )
236
+
237
+ try:
238
+ self.workspace_client.genie.send_message_feedback(
239
+ space_id=self.space_id,
240
+ conversation_id=conversation_id,
241
+ message_id=message_id,
242
+ rating=rating,
243
+ )
244
+ logger.debug(
245
+ "Feedback sent successfully",
246
+ space_id=self.space_id,
247
+ conversation_id=conversation_id,
248
+ message_id=message_id,
249
+ )
250
+ except Exception as e:
251
+ logger.error(
252
+ "Failed to send feedback to Genie",
253
+ space_id=self.space_id,
254
+ conversation_id=conversation_id,
255
+ message_id=message_id,
256
+ rating=rating.value if rating else None,
257
+ error=str(e),
258
+ exc_info=True,
259
+ )
dao_ai/memory/postgres.py CHANGED
@@ -3,6 +3,7 @@ import atexit
3
3
  import threading
4
4
  from typing import Any, Optional
5
5
 
6
+ from databricks_ai_bridge.lakebase import AsyncLakebasePool, LakebasePool
6
7
  from langgraph.checkpoint.base import BaseCheckpointSaver
7
8
  from langgraph.checkpoint.postgres import ShallowPostgresSaver
8
9
  from langgraph.checkpoint.postgres.aio import AsyncShallowPostgresSaver
@@ -86,13 +87,22 @@ async def _create_async_pool(
86
87
 
87
88
 
88
89
  class AsyncPostgresPoolManager:
90
+ """
91
+ Asynchronous PostgreSQL connection pool manager that shares pools
92
+ based on database configuration.
93
+
94
+ For Lakebase connections (when instance_name is provided), uses AsyncLakebasePool
95
+ from databricks_ai_bridge which handles automatic token rotation and host resolution.
96
+ For standard PostgreSQL connections, uses psycopg_pool.AsyncConnectionPool.
97
+ """
98
+
89
99
  _pools: dict[str, AsyncConnectionPool] = {}
100
+ _lakebase_pools: dict[str, AsyncLakebasePool] = {}
90
101
  _lock: asyncio.Lock = asyncio.Lock()
91
102
 
92
103
  @classmethod
93
104
  async def get_pool(cls, database: DatabaseModel) -> AsyncConnectionPool:
94
105
  connection_key: str = database.name
95
- connection_params: dict[str, Any] = database.connection_params
96
106
 
97
107
  async with cls._lock:
98
108
  if connection_key in cls._pools:
@@ -103,19 +113,43 @@ class AsyncPostgresPoolManager:
103
113
 
104
114
  logger.debug("Creating new async PostgreSQL pool", database=database.name)
105
115
 
106
- kwargs: dict[str, Any] = {
107
- "row_factory": dict_row,
108
- "autocommit": True,
109
- } | database.connection_kwargs or {}
110
-
111
- # Create connection pool
112
- pool: AsyncConnectionPool = await _create_async_pool(
113
- connection_params=connection_params,
114
- database_name=database.name,
115
- max_pool_size=database.max_pool_size,
116
- timeout_seconds=database.timeout_seconds,
117
- kwargs=kwargs,
118
- )
116
+ if database.is_lakebase:
117
+ # Use AsyncLakebasePool for Lakebase connections
118
+ # AsyncLakebasePool handles automatic token rotation and host resolution
119
+ lakebase_pool = AsyncLakebasePool(
120
+ instance_name=database.instance_name,
121
+ workspace_client=database.workspace_client,
122
+ min_size=1,
123
+ max_size=database.max_pool_size,
124
+ timeout=float(database.timeout_seconds),
125
+ )
126
+ # Open the async pool
127
+ await lakebase_pool.open()
128
+ # Store the AsyncLakebasePool for proper cleanup
129
+ cls._lakebase_pools[connection_key] = lakebase_pool
130
+ # Get the underlying AsyncConnectionPool
131
+ pool = lakebase_pool.pool
132
+ logger.success(
133
+ "Async Lakebase connection pool created",
134
+ database=database.name,
135
+ instance_name=database.instance_name,
136
+ pool_size=database.max_pool_size,
137
+ )
138
+ else:
139
+ # Use standard async PostgreSQL pool for non-Lakebase connections
140
+ connection_params: dict[str, Any] = database.connection_params
141
+ kwargs: dict[str, Any] = {
142
+ "row_factory": dict_row,
143
+ "autocommit": True,
144
+ } | database.connection_kwargs or {}
145
+
146
+ pool = await _create_async_pool(
147
+ connection_params=connection_params,
148
+ database_name=database.name,
149
+ max_pool_size=database.max_pool_size,
150
+ timeout_seconds=database.timeout_seconds,
151
+ kwargs=kwargs,
152
+ )
119
153
 
120
154
  cls._pools[connection_key] = pool
121
155
  return pool
@@ -125,7 +159,13 @@ class AsyncPostgresPoolManager:
125
159
  connection_key: str = database.name
126
160
 
127
161
  async with cls._lock:
128
- if connection_key in cls._pools:
162
+ # Close AsyncLakebasePool if it exists (handles underlying pool cleanup)
163
+ if connection_key in cls._lakebase_pools:
164
+ lakebase_pool = cls._lakebase_pools.pop(connection_key)
165
+ await lakebase_pool.close()
166
+ cls._pools.pop(connection_key, None)
167
+ logger.debug("Async Lakebase pool closed", database=database.name)
168
+ elif connection_key in cls._pools:
129
169
  pool = cls._pools.pop(connection_key)
130
170
  await pool.close()
131
171
  logger.debug("Async PostgreSQL pool closed", database=database.name)
@@ -133,9 +173,32 @@ class AsyncPostgresPoolManager:
133
173
  @classmethod
134
174
  async def close_all_pools(cls):
135
175
  async with cls._lock:
176
+ # Close all AsyncLakebasePool instances first
177
+ for connection_key, lakebase_pool in cls._lakebase_pools.items():
178
+ try:
179
+ await asyncio.wait_for(lakebase_pool.close(), timeout=2.0)
180
+ logger.debug("Async Lakebase pool closed", pool=connection_key)
181
+ except asyncio.TimeoutError:
182
+ logger.warning(
183
+ "Timeout closing async Lakebase pool, forcing closure",
184
+ pool=connection_key,
185
+ )
186
+ except asyncio.CancelledError:
187
+ logger.warning(
188
+ "Async Lakebase pool closure cancelled (shutdown in progress)",
189
+ pool=connection_key,
190
+ )
191
+ except Exception as e:
192
+ logger.error(
193
+ "Error closing async Lakebase pool",
194
+ pool=connection_key,
195
+ error=str(e),
196
+ )
197
+ cls._lakebase_pools.clear()
198
+
199
+ # Close any remaining standard async PostgreSQL pools
136
200
  for connection_key, pool in cls._pools.items():
137
201
  try:
138
- # Use a short timeout to avoid blocking on pool closure
139
202
  await asyncio.wait_for(pool.close(), timeout=2.0)
140
203
  logger.debug("Async PostgreSQL pool closed", pool=connection_key)
141
204
  except asyncio.TimeoutError:
@@ -178,7 +241,20 @@ class AsyncPostgresStoreManager(StoreManagerBase):
178
241
  def _setup(self):
179
242
  if self._setup_complete:
180
243
  return
181
- asyncio.run(self._async_setup())
244
+ try:
245
+ # Check if we're already in an async context
246
+ asyncio.get_running_loop()
247
+ # If we get here, we're in an async context - raise to caller
248
+ raise RuntimeError(
249
+ "Cannot call sync _setup() from async context. "
250
+ "Use await _async_setup() instead."
251
+ )
252
+ except RuntimeError as e:
253
+ if "no running event loop" in str(e).lower():
254
+ # No event loop running - safe to use asyncio.run()
255
+ asyncio.run(self._async_setup())
256
+ else:
257
+ raise
182
258
 
183
259
  async def _async_setup(self):
184
260
  if self._setup_complete:
@@ -237,13 +313,25 @@ class AsyncPostgresCheckpointerManager(CheckpointManagerBase):
237
313
 
238
314
  def _setup(self):
239
315
  """
240
- Run the async setup. Works in both sync and async contexts when nest_asyncio is applied.
316
+ Run the async setup. For async contexts, use await _async_setup() directly.
241
317
  """
242
318
  if self._setup_complete:
243
319
  return
244
320
 
245
- # With nest_asyncio applied in notebooks, asyncio.run() works everywhere
246
- asyncio.run(self._async_setup())
321
+ try:
322
+ # Check if we're already in an async context
323
+ asyncio.get_running_loop()
324
+ # If we get here, we're in an async context - raise to caller
325
+ raise RuntimeError(
326
+ "Cannot call sync _setup() from async context. "
327
+ "Use await _async_setup() instead."
328
+ )
329
+ except RuntimeError as e:
330
+ if "no running event loop" in str(e).lower():
331
+ # No event loop running - safe to use asyncio.run()
332
+ asyncio.run(self._async_setup())
333
+ else:
334
+ raise
247
335
 
248
336
  async def _async_setup(self):
249
337
  """
@@ -284,15 +372,19 @@ class PostgresPoolManager:
284
372
  """
285
373
  Synchronous PostgreSQL connection pool manager that shares pools
286
374
  based on database configuration.
375
+
376
+ For Lakebase connections (when instance_name is provided), uses LakebasePool
377
+ from databricks_ai_bridge which handles automatic token rotation and host resolution.
378
+ For standard PostgreSQL connections, uses psycopg_pool.ConnectionPool.
287
379
  """
288
380
 
289
381
  _pools: dict[str, ConnectionPool] = {}
382
+ _lakebase_pools: dict[str, LakebasePool] = {}
290
383
  _lock: threading.Lock = threading.Lock()
291
384
 
292
385
  @classmethod
293
386
  def get_pool(cls, database: DatabaseModel) -> ConnectionPool:
294
387
  connection_key: str = str(database.name)
295
- connection_params: dict[str, Any] = database.connection_params
296
388
 
297
389
  with cls._lock:
298
390
  if connection_key in cls._pools:
@@ -301,19 +393,41 @@ class PostgresPoolManager:
301
393
 
302
394
  logger.debug("Creating new PostgreSQL pool", database=database.name)
303
395
 
304
- kwargs: dict[str, Any] = {
305
- "row_factory": dict_row,
306
- "autocommit": True,
307
- } | database.connection_kwargs or {}
308
-
309
- # Create connection pool
310
- pool: ConnectionPool = _create_pool(
311
- connection_params=connection_params,
312
- database_name=database.name,
313
- max_pool_size=database.max_pool_size,
314
- timeout_seconds=database.timeout_seconds,
315
- kwargs=kwargs,
316
- )
396
+ if database.is_lakebase:
397
+ # Use LakebasePool for Lakebase connections
398
+ # LakebasePool handles automatic token rotation and host resolution
399
+ lakebase_pool = LakebasePool(
400
+ instance_name=database.instance_name,
401
+ workspace_client=database.workspace_client,
402
+ min_size=1,
403
+ max_size=database.max_pool_size,
404
+ timeout=float(database.timeout_seconds),
405
+ )
406
+ # Store the LakebasePool for proper cleanup
407
+ cls._lakebase_pools[connection_key] = lakebase_pool
408
+ # Get the underlying ConnectionPool
409
+ pool = lakebase_pool.pool
410
+ logger.success(
411
+ "Lakebase connection pool created",
412
+ database=database.name,
413
+ instance_name=database.instance_name,
414
+ pool_size=database.max_pool_size,
415
+ )
416
+ else:
417
+ # Use standard PostgreSQL pool for non-Lakebase connections
418
+ connection_params: dict[str, Any] = database.connection_params
419
+ kwargs: dict[str, Any] = {
420
+ "row_factory": dict_row,
421
+ "autocommit": True,
422
+ } | database.connection_kwargs or {}
423
+
424
+ pool = _create_pool(
425
+ connection_params=connection_params,
426
+ database_name=database.name,
427
+ max_pool_size=database.max_pool_size,
428
+ timeout_seconds=database.timeout_seconds,
429
+ kwargs=kwargs,
430
+ )
317
431
 
318
432
  cls._pools[connection_key] = pool
319
433
  return pool
@@ -323,7 +437,13 @@ class PostgresPoolManager:
323
437
  connection_key: str = database.name
324
438
 
325
439
  with cls._lock:
326
- if connection_key in cls._pools:
440
+ # Close LakebasePool if it exists (handles underlying pool cleanup)
441
+ if connection_key in cls._lakebase_pools:
442
+ lakebase_pool = cls._lakebase_pools.pop(connection_key)
443
+ lakebase_pool.close()
444
+ cls._pools.pop(connection_key, None)
445
+ logger.debug("Lakebase pool closed", database=database.name)
446
+ elif connection_key in cls._pools:
327
447
  pool = cls._pools.pop(connection_key)
328
448
  pool.close()
329
449
  logger.debug("PostgreSQL pool closed", database=database.name)
@@ -331,16 +451,32 @@ class PostgresPoolManager:
331
451
  @classmethod
332
452
  def close_all_pools(cls):
333
453
  with cls._lock:
334
- for connection_key, pool in cls._pools.items():
454
+ # Close all LakebasePool instances first
455
+ for connection_key, lakebase_pool in cls._lakebase_pools.items():
335
456
  try:
336
- pool.close()
337
- logger.debug("PostgreSQL pool closed", pool=connection_key)
457
+ lakebase_pool.close()
458
+ logger.debug("Lakebase pool closed", pool=connection_key)
338
459
  except Exception as e:
339
460
  logger.error(
340
- "Error closing PostgreSQL pool",
461
+ "Error closing Lakebase pool",
341
462
  pool=connection_key,
342
463
  error=str(e),
343
464
  )
465
+ cls._lakebase_pools.clear()
466
+
467
+ # Close any remaining standard PostgreSQL pools
468
+ for connection_key, pool in cls._pools.items():
469
+ # Skip if already closed via LakebasePool
470
+ if connection_key not in cls._lakebase_pools:
471
+ try:
472
+ pool.close()
473
+ logger.debug("PostgreSQL pool closed", pool=connection_key)
474
+ except Exception as e:
475
+ logger.error(
476
+ "Error closing PostgreSQL pool",
477
+ pool=connection_key,
478
+ error=str(e),
479
+ )
344
480
  cls._pools.clear()
345
481
 
346
482
 
@@ -3,8 +3,16 @@
3
3
 
4
4
  # Re-export LangChain built-in middleware
5
5
  from langchain.agents.middleware import (
6
+ ClearToolUsesEdit,
7
+ ContextEditingMiddleware,
6
8
  HumanInTheLoopMiddleware,
9
+ LLMToolSelectorMiddleware,
10
+ ModelCallLimitMiddleware,
11
+ ModelRetryMiddleware,
12
+ PIIMiddleware,
7
13
  SummarizationMiddleware,
14
+ ToolCallLimitMiddleware,
15
+ ToolRetryMiddleware,
8
16
  after_agent,
9
17
  after_model,
10
18
  before_agent,
@@ -37,6 +45,10 @@ from dao_ai.middleware.base import (
37
45
  ModelRequest,
38
46
  ModelResponse,
39
47
  )
48
+ from dao_ai.middleware.context_editing import (
49
+ create_clear_tool_uses_edit,
50
+ create_context_editing_middleware,
51
+ )
40
52
  from dao_ai.middleware.core import create_factory_middleware
41
53
  from dao_ai.middleware.guardrails import (
42
54
  ContentFilterMiddleware,
@@ -62,10 +74,16 @@ from dao_ai.middleware.message_validation import (
62
74
  create_thread_id_validation_middleware,
63
75
  create_user_id_validation_middleware,
64
76
  )
77
+ from dao_ai.middleware.model_call_limit import create_model_call_limit_middleware
78
+ from dao_ai.middleware.model_retry import create_model_retry_middleware
79
+ from dao_ai.middleware.pii import create_pii_middleware
65
80
  from dao_ai.middleware.summarization import (
66
81
  LoggingSummarizationMiddleware,
67
82
  create_summarization_middleware,
68
83
  )
84
+ from dao_ai.middleware.tool_call_limit import create_tool_call_limit_middleware
85
+ from dao_ai.middleware.tool_retry import create_tool_retry_middleware
86
+ from dao_ai.middleware.tool_selector import create_llm_tool_selector_middleware
69
87
 
70
88
  __all__ = [
71
89
  # Base class (from LangChain)
@@ -85,6 +103,14 @@ __all__ = [
85
103
  "SummarizationMiddleware",
86
104
  "LoggingSummarizationMiddleware",
87
105
  "HumanInTheLoopMiddleware",
106
+ "ToolCallLimitMiddleware",
107
+ "ModelCallLimitMiddleware",
108
+ "ToolRetryMiddleware",
109
+ "ModelRetryMiddleware",
110
+ "LLMToolSelectorMiddleware",
111
+ "ContextEditingMiddleware",
112
+ "ClearToolUsesEdit",
113
+ "PIIMiddleware",
88
114
  # Core factory function
89
115
  "create_factory_middleware",
90
116
  # DAO AI middleware implementations
@@ -122,4 +148,16 @@ __all__ = [
122
148
  "create_assert_middleware",
123
149
  "create_suggest_middleware",
124
150
  "create_refine_middleware",
151
+ # Limit and retry middleware factory functions
152
+ "create_tool_call_limit_middleware",
153
+ "create_model_call_limit_middleware",
154
+ "create_tool_retry_middleware",
155
+ "create_model_retry_middleware",
156
+ # Tool selection middleware factory functions
157
+ "create_llm_tool_selector_middleware",
158
+ # Context editing middleware factory functions
159
+ "create_context_editing_middleware",
160
+ "create_clear_tool_uses_edit",
161
+ # PII middleware factory functions
162
+ "create_pii_middleware",
125
163
  ]
@@ -688,7 +688,7 @@ def create_assert_middleware(
688
688
  name: Name for function constraints
689
689
 
690
690
  Returns:
691
- AssertMiddleware configured with the constraint
691
+ List containing AssertMiddleware configured with the constraint
692
692
 
693
693
  Example:
694
694
  # Using a Constraint class
@@ -737,7 +737,7 @@ def create_suggest_middleware(
737
737
  name: Name for function constraints
738
738
 
739
739
  Returns:
740
- SuggestMiddleware configured with the constraint
740
+ List containing SuggestMiddleware configured with the constraint
741
741
 
742
742
  Example:
743
743
  def is_professional(response: str, ctx: dict) -> ConstraintResult:
@@ -783,7 +783,7 @@ def create_refine_middleware(
783
783
  select_best: Track and return best response across iterations
784
784
 
785
785
  Returns:
786
- RefineMiddleware configured with the reward function
786
+ List containing RefineMiddleware configured with the reward function
787
787
 
788
788
  Example:
789
789
  def evaluate_completeness(response: str, ctx: dict) -> float: