dao-ai 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/agent_as_code.py +2 -5
- dao_ai/cli.py +65 -15
- dao_ai/config.py +672 -218
- dao_ai/genie/cache/core.py +6 -2
- dao_ai/genie/cache/lru.py +29 -11
- dao_ai/genie/cache/semantic.py +95 -44
- dao_ai/hooks/core.py +5 -5
- dao_ai/logging.py +56 -0
- dao_ai/memory/core.py +61 -44
- dao_ai/memory/databricks.py +54 -41
- dao_ai/memory/postgres.py +77 -36
- dao_ai/middleware/assertions.py +45 -17
- dao_ai/middleware/core.py +13 -7
- dao_ai/middleware/guardrails.py +30 -25
- dao_ai/middleware/human_in_the_loop.py +9 -5
- dao_ai/middleware/message_validation.py +61 -29
- dao_ai/middleware/summarization.py +16 -11
- dao_ai/models.py +172 -69
- dao_ai/nodes.py +148 -19
- dao_ai/optimization.py +26 -16
- dao_ai/orchestration/core.py +15 -8
- dao_ai/orchestration/supervisor.py +22 -8
- dao_ai/orchestration/swarm.py +57 -12
- dao_ai/prompts.py +17 -17
- dao_ai/providers/databricks.py +365 -155
- dao_ai/state.py +24 -6
- dao_ai/tools/__init__.py +2 -0
- dao_ai/tools/agent.py +1 -3
- dao_ai/tools/core.py +7 -7
- dao_ai/tools/email.py +29 -77
- dao_ai/tools/genie.py +18 -13
- dao_ai/tools/mcp.py +223 -156
- dao_ai/tools/python.py +5 -2
- dao_ai/tools/search.py +1 -1
- dao_ai/tools/slack.py +21 -9
- dao_ai/tools/sql.py +202 -0
- dao_ai/tools/time.py +30 -7
- dao_ai/tools/unity_catalog.py +129 -86
- dao_ai/tools/vector_search.py +318 -244
- dao_ai/utils.py +15 -10
- dao_ai-0.1.3.dist-info/METADATA +455 -0
- dao_ai-0.1.3.dist-info/RECORD +64 -0
- dao_ai-0.1.1.dist-info/METADATA +0 -1878
- dao_ai-0.1.1.dist-info/RECORD +0 -62
- {dao_ai-0.1.1.dist-info → dao_ai-0.1.3.dist-info}/WHEEL +0 -0
- {dao_ai-0.1.1.dist-info → dao_ai-0.1.3.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.1.1.dist-info → dao_ai-0.1.3.dist-info}/licenses/LICENSE +0 -0
dao_ai/memory/databricks.py
CHANGED
|
@@ -59,23 +59,23 @@ class AsyncDatabricksCheckpointSaver(DatabricksCheckpointSaver):
|
|
|
59
59
|
async def aget_tuple(self, config: RunnableConfig) -> CheckpointTuple | None:
|
|
60
60
|
"""Async version of get_tuple."""
|
|
61
61
|
thread_id = config.get("configurable", {}).get("thread_id", "unknown")
|
|
62
|
-
logger.
|
|
62
|
+
logger.trace("Fetching checkpoint", thread_id=thread_id, method="aget_tuple")
|
|
63
63
|
result = await asyncio.to_thread(self.get_tuple, config)
|
|
64
64
|
if result:
|
|
65
|
-
logger.
|
|
65
|
+
logger.trace("Checkpoint found", thread_id=thread_id)
|
|
66
66
|
else:
|
|
67
|
-
logger.
|
|
67
|
+
logger.trace("No checkpoint found", thread_id=thread_id)
|
|
68
68
|
return result
|
|
69
69
|
|
|
70
70
|
async def aget(self, config: RunnableConfig) -> Checkpoint | None:
|
|
71
71
|
"""Async version of get."""
|
|
72
72
|
thread_id = config.get("configurable", {}).get("thread_id", "unknown")
|
|
73
|
-
logger.
|
|
73
|
+
logger.trace("Fetching checkpoint", thread_id=thread_id, method="aget")
|
|
74
74
|
result = await asyncio.to_thread(self.get, config)
|
|
75
75
|
if result:
|
|
76
|
-
logger.
|
|
76
|
+
logger.trace("Checkpoint found", thread_id=thread_id)
|
|
77
77
|
else:
|
|
78
|
-
logger.
|
|
78
|
+
logger.trace("No checkpoint found", thread_id=thread_id)
|
|
79
79
|
return result
|
|
80
80
|
|
|
81
81
|
async def aput(
|
|
@@ -88,13 +88,15 @@ class AsyncDatabricksCheckpointSaver(DatabricksCheckpointSaver):
|
|
|
88
88
|
"""Async version of put."""
|
|
89
89
|
thread_id = config.get("configurable", {}).get("thread_id", "unknown")
|
|
90
90
|
checkpoint_id = checkpoint.get("id", "unknown")
|
|
91
|
-
logger.
|
|
92
|
-
|
|
91
|
+
logger.trace(
|
|
92
|
+
"Saving checkpoint", checkpoint_id=checkpoint_id, thread_id=thread_id
|
|
93
93
|
)
|
|
94
94
|
result = await asyncio.to_thread(
|
|
95
95
|
self.put, config, checkpoint, metadata, new_versions
|
|
96
96
|
)
|
|
97
|
-
logger.
|
|
97
|
+
logger.trace(
|
|
98
|
+
"Checkpoint saved", thread_id=thread_id, checkpoint_id=checkpoint_id
|
|
99
|
+
)
|
|
98
100
|
return result
|
|
99
101
|
|
|
100
102
|
async def aput_writes(
|
|
@@ -106,12 +108,14 @@ class AsyncDatabricksCheckpointSaver(DatabricksCheckpointSaver):
|
|
|
106
108
|
) -> None:
|
|
107
109
|
"""Async version of put_writes."""
|
|
108
110
|
thread_id = config.get("configurable", {}).get("thread_id", "unknown")
|
|
109
|
-
logger.
|
|
110
|
-
|
|
111
|
-
|
|
111
|
+
logger.trace(
|
|
112
|
+
"Saving checkpoint writes",
|
|
113
|
+
writes_count=len(writes),
|
|
114
|
+
thread_id=thread_id,
|
|
115
|
+
task_id=task_id,
|
|
112
116
|
)
|
|
113
117
|
await asyncio.to_thread(self.put_writes, config, writes, task_id, task_path)
|
|
114
|
-
logger.
|
|
118
|
+
logger.trace("Checkpoint writes saved", thread_id=thread_id, task_id=task_id)
|
|
115
119
|
|
|
116
120
|
async def alist(
|
|
117
121
|
self,
|
|
@@ -127,22 +131,20 @@ class AsyncDatabricksCheckpointSaver(DatabricksCheckpointSaver):
|
|
|
127
131
|
if config
|
|
128
132
|
else "all"
|
|
129
133
|
)
|
|
130
|
-
logger.
|
|
131
|
-
f"alist: Listing checkpoints for thread_id={thread_id}, limit={limit}"
|
|
132
|
-
)
|
|
134
|
+
logger.trace("Listing checkpoints", thread_id=thread_id, limit=limit)
|
|
133
135
|
# Get all items from sync iterator in a thread
|
|
134
136
|
items = await asyncio.to_thread(
|
|
135
137
|
lambda: list(self.list(config, filter=filter, before=before, limit=limit))
|
|
136
138
|
)
|
|
137
|
-
logger.debug(
|
|
139
|
+
logger.debug("Checkpoints listed", thread_id=thread_id, count=len(items))
|
|
138
140
|
for item in items:
|
|
139
141
|
yield item
|
|
140
142
|
|
|
141
143
|
async def adelete_thread(self, thread_id: str) -> None:
|
|
142
144
|
"""Async version of delete_thread."""
|
|
143
|
-
logger.
|
|
145
|
+
logger.trace("Deleting thread", thread_id=thread_id)
|
|
144
146
|
await asyncio.to_thread(self.delete_thread, thread_id)
|
|
145
|
-
logger.debug(
|
|
147
|
+
logger.debug("Thread deleted", thread_id=thread_id)
|
|
146
148
|
|
|
147
149
|
|
|
148
150
|
class AsyncDatabricksStore(DatabricksStore):
|
|
@@ -156,9 +158,9 @@ class AsyncDatabricksStore(DatabricksStore):
|
|
|
156
158
|
async def abatch(self, ops: Iterable[Op]) -> list[Result]:
|
|
157
159
|
"""Async version of batch."""
|
|
158
160
|
ops_list = list(ops)
|
|
159
|
-
logger.
|
|
161
|
+
logger.trace("Executing batch operations", operations_count=len(ops_list))
|
|
160
162
|
result = await asyncio.to_thread(self.batch, ops_list)
|
|
161
|
-
logger.debug(
|
|
163
|
+
logger.debug("Batch operations completed", operations_count=len(result))
|
|
162
164
|
return result
|
|
163
165
|
|
|
164
166
|
async def aget(
|
|
@@ -170,14 +172,14 @@ class AsyncDatabricksStore(DatabricksStore):
|
|
|
170
172
|
) -> Item | None:
|
|
171
173
|
"""Async version of get."""
|
|
172
174
|
ns_str = "/".join(namespace)
|
|
173
|
-
logger.
|
|
175
|
+
logger.trace("Fetching store item", key=key, namespace=ns_str)
|
|
174
176
|
result = await asyncio.to_thread(
|
|
175
177
|
partial(self.get, namespace, key, refresh_ttl=refresh_ttl)
|
|
176
178
|
)
|
|
177
179
|
if result:
|
|
178
|
-
logger.
|
|
180
|
+
logger.trace("Store item found", key=key, namespace=ns_str)
|
|
179
181
|
else:
|
|
180
|
-
logger.
|
|
182
|
+
logger.trace("Store item not found", key=key, namespace=ns_str)
|
|
181
183
|
return result
|
|
182
184
|
|
|
183
185
|
async def aput(
|
|
@@ -191,7 +193,7 @@ class AsyncDatabricksStore(DatabricksStore):
|
|
|
191
193
|
) -> None:
|
|
192
194
|
"""Async version of put."""
|
|
193
195
|
ns_str = "/".join(namespace)
|
|
194
|
-
logger.
|
|
196
|
+
logger.trace("Storing item", key=key, namespace=ns_str, has_ttl=ttl is not None)
|
|
195
197
|
# Handle the ttl parameter - only pass if explicitly provided
|
|
196
198
|
if ttl is not None:
|
|
197
199
|
await asyncio.to_thread(
|
|
@@ -199,14 +201,14 @@ class AsyncDatabricksStore(DatabricksStore):
|
|
|
199
201
|
)
|
|
200
202
|
else:
|
|
201
203
|
await asyncio.to_thread(partial(self.put, namespace, key, value, index))
|
|
202
|
-
logger.
|
|
204
|
+
logger.trace("Item stored", key=key, namespace=ns_str)
|
|
203
205
|
|
|
204
206
|
async def adelete(self, namespace: tuple[str, ...], key: str) -> None:
|
|
205
207
|
"""Async version of delete."""
|
|
206
208
|
ns_str = "/".join(namespace)
|
|
207
|
-
logger.
|
|
209
|
+
logger.trace("Deleting item", key=key, namespace=ns_str)
|
|
208
210
|
await asyncio.to_thread(self.delete, namespace, key)
|
|
209
|
-
logger.
|
|
211
|
+
logger.trace("Item deleted", key=key, namespace=ns_str)
|
|
210
212
|
|
|
211
213
|
async def asearch(
|
|
212
214
|
self,
|
|
@@ -221,8 +223,8 @@ class AsyncDatabricksStore(DatabricksStore):
|
|
|
221
223
|
) -> list[SearchItem]:
|
|
222
224
|
"""Async version of search."""
|
|
223
225
|
ns_str = "/".join(namespace_prefix)
|
|
224
|
-
logger.
|
|
225
|
-
|
|
226
|
+
logger.trace(
|
|
227
|
+
"Searching store", namespace_prefix=ns_str, query=query, limit=limit
|
|
226
228
|
)
|
|
227
229
|
result = await asyncio.to_thread(
|
|
228
230
|
partial(
|
|
@@ -235,7 +237,9 @@ class AsyncDatabricksStore(DatabricksStore):
|
|
|
235
237
|
refresh_ttl=refresh_ttl,
|
|
236
238
|
)
|
|
237
239
|
)
|
|
238
|
-
logger.debug(
|
|
240
|
+
logger.debug(
|
|
241
|
+
"Store search completed", namespace_prefix=ns_str, results_count=len(result)
|
|
242
|
+
)
|
|
239
243
|
return result
|
|
240
244
|
|
|
241
245
|
async def alist_namespaces(
|
|
@@ -249,9 +253,7 @@ class AsyncDatabricksStore(DatabricksStore):
|
|
|
249
253
|
) -> list[tuple[str, ...]]:
|
|
250
254
|
"""Async version of list_namespaces."""
|
|
251
255
|
prefix_str = "/".join(prefix) if prefix else "all"
|
|
252
|
-
logger.
|
|
253
|
-
f"alist_namespaces: Listing namespaces prefix={prefix_str}, limit={limit}"
|
|
254
|
-
)
|
|
256
|
+
logger.trace("Listing namespaces", prefix=prefix_str, limit=limit)
|
|
255
257
|
result = await asyncio.to_thread(
|
|
256
258
|
partial(
|
|
257
259
|
self.list_namespaces,
|
|
@@ -262,7 +264,7 @@ class AsyncDatabricksStore(DatabricksStore):
|
|
|
262
264
|
offset=offset,
|
|
263
265
|
)
|
|
264
266
|
)
|
|
265
|
-
logger.debug(
|
|
267
|
+
logger.debug("Namespaces listed", count=len(result))
|
|
266
268
|
return result
|
|
267
269
|
|
|
268
270
|
|
|
@@ -297,7 +299,7 @@ class DatabricksCheckpointerManager(CheckpointManagerBase):
|
|
|
297
299
|
workspace_client = database.workspace_client
|
|
298
300
|
|
|
299
301
|
logger.debug(
|
|
300
|
-
|
|
302
|
+
"Creating Databricks checkpointer", instance_name=instance_name
|
|
301
303
|
)
|
|
302
304
|
|
|
303
305
|
checkpointer = AsyncDatabricksCheckpointSaver(
|
|
@@ -306,10 +308,10 @@ class DatabricksCheckpointerManager(CheckpointManagerBase):
|
|
|
306
308
|
)
|
|
307
309
|
|
|
308
310
|
# Setup the checkpointer (creates necessary tables if needed)
|
|
309
|
-
logger.debug(
|
|
311
|
+
logger.debug("Setting up checkpoint tables", instance_name=instance_name)
|
|
310
312
|
checkpointer.setup()
|
|
311
|
-
logger.
|
|
312
|
-
|
|
313
|
+
logger.success(
|
|
314
|
+
"Databricks checkpointer initialized", instance_name=instance_name
|
|
313
315
|
)
|
|
314
316
|
|
|
315
317
|
self._checkpointer = checkpointer
|
|
@@ -360,12 +362,18 @@ class DatabricksStoreManager(StoreManagerBase):
|
|
|
360
362
|
embedding_dims = self.store_model.dims
|
|
361
363
|
|
|
362
364
|
logger.debug(
|
|
363
|
-
|
|
365
|
+
"Configuring store embeddings",
|
|
366
|
+
endpoint=embedding_endpoint,
|
|
367
|
+
dimensions=embedding_dims,
|
|
364
368
|
)
|
|
365
369
|
|
|
366
370
|
embeddings = DatabricksEmbeddings(endpoint=embedding_endpoint)
|
|
367
371
|
|
|
368
|
-
logger.debug(
|
|
372
|
+
logger.debug(
|
|
373
|
+
"Creating Databricks store",
|
|
374
|
+
instance_name=instance_name,
|
|
375
|
+
embeddings_enabled=embeddings is not None,
|
|
376
|
+
)
|
|
369
377
|
|
|
370
378
|
store = AsyncDatabricksStore(
|
|
371
379
|
instance_name=instance_name,
|
|
@@ -376,6 +384,11 @@ class DatabricksStoreManager(StoreManagerBase):
|
|
|
376
384
|
|
|
377
385
|
# Setup the store (creates necessary tables if needed)
|
|
378
386
|
store.setup()
|
|
387
|
+
logger.success(
|
|
388
|
+
"Databricks store initialized",
|
|
389
|
+
instance_name=instance_name,
|
|
390
|
+
embeddings_enabled=embeddings is not None,
|
|
391
|
+
)
|
|
379
392
|
self._store = store
|
|
380
393
|
|
|
381
394
|
return self._store
|
dao_ai/memory/postgres.py
CHANGED
|
@@ -28,9 +28,10 @@ def _create_pool(
|
|
|
28
28
|
kwargs: dict,
|
|
29
29
|
) -> ConnectionPool:
|
|
30
30
|
"""Create a connection pool using the provided connection parameters."""
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
31
|
+
safe_params = {
|
|
32
|
+
k: (str(v) if k != "password" else "***") for k, v in connection_params.items()
|
|
33
|
+
}
|
|
34
|
+
logger.debug("Creating connection pool", database=database_name, **safe_params)
|
|
34
35
|
|
|
35
36
|
# Merge connection_params into kwargs for psycopg
|
|
36
37
|
connection_kwargs = kwargs | connection_params
|
|
@@ -43,7 +44,11 @@ def _create_pool(
|
|
|
43
44
|
kwargs=connection_kwargs,
|
|
44
45
|
)
|
|
45
46
|
pool.open(wait=True, timeout=timeout_seconds)
|
|
46
|
-
logger.
|
|
47
|
+
logger.success(
|
|
48
|
+
"PostgreSQL connection pool created",
|
|
49
|
+
database=database_name,
|
|
50
|
+
pool_size=max_pool_size,
|
|
51
|
+
)
|
|
47
52
|
return pool
|
|
48
53
|
|
|
49
54
|
|
|
@@ -55,8 +60,11 @@ async def _create_async_pool(
|
|
|
55
60
|
kwargs: dict,
|
|
56
61
|
) -> AsyncConnectionPool:
|
|
57
62
|
"""Create an async connection pool using the provided connection parameters."""
|
|
63
|
+
safe_params = {
|
|
64
|
+
k: (str(v) if k != "password" else "***") for k, v in connection_params.items()
|
|
65
|
+
}
|
|
58
66
|
logger.debug(
|
|
59
|
-
|
|
67
|
+
"Creating async connection pool", database=database_name, **safe_params
|
|
60
68
|
)
|
|
61
69
|
|
|
62
70
|
# Merge connection_params into kwargs for psycopg
|
|
@@ -69,7 +77,11 @@ async def _create_async_pool(
|
|
|
69
77
|
kwargs=connection_kwargs,
|
|
70
78
|
)
|
|
71
79
|
await pool.open(wait=True, timeout=timeout_seconds)
|
|
72
|
-
logger.
|
|
80
|
+
logger.success(
|
|
81
|
+
"Async PostgreSQL connection pool created",
|
|
82
|
+
database=database_name,
|
|
83
|
+
pool_size=max_pool_size,
|
|
84
|
+
)
|
|
73
85
|
return pool
|
|
74
86
|
|
|
75
87
|
|
|
@@ -84,10 +96,12 @@ class AsyncPostgresPoolManager:
|
|
|
84
96
|
|
|
85
97
|
async with cls._lock:
|
|
86
98
|
if connection_key in cls._pools:
|
|
87
|
-
logger.
|
|
99
|
+
logger.trace(
|
|
100
|
+
"Reusing existing async PostgreSQL pool", database=database.name
|
|
101
|
+
)
|
|
88
102
|
return cls._pools[connection_key]
|
|
89
103
|
|
|
90
|
-
logger.debug(
|
|
104
|
+
logger.debug("Creating new async PostgreSQL pool", database=database.name)
|
|
91
105
|
|
|
92
106
|
kwargs: dict[str, Any] = {
|
|
93
107
|
"row_factory": dict_row,
|
|
@@ -114,7 +128,7 @@ class AsyncPostgresPoolManager:
|
|
|
114
128
|
if connection_key in cls._pools:
|
|
115
129
|
pool = cls._pools.pop(connection_key)
|
|
116
130
|
await pool.close()
|
|
117
|
-
logger.debug(
|
|
131
|
+
logger.debug("Async PostgreSQL pool closed", database=database.name)
|
|
118
132
|
|
|
119
133
|
@classmethod
|
|
120
134
|
async def close_all_pools(cls):
|
|
@@ -123,17 +137,21 @@ class AsyncPostgresPoolManager:
|
|
|
123
137
|
try:
|
|
124
138
|
# Use a short timeout to avoid blocking on pool closure
|
|
125
139
|
await asyncio.wait_for(pool.close(), timeout=2.0)
|
|
126
|
-
logger.debug(
|
|
140
|
+
logger.debug("Async PostgreSQL pool closed", pool=connection_key)
|
|
127
141
|
except asyncio.TimeoutError:
|
|
128
142
|
logger.warning(
|
|
129
|
-
|
|
143
|
+
"Timeout closing async pool, forcing closure",
|
|
144
|
+
pool=connection_key,
|
|
130
145
|
)
|
|
131
146
|
except asyncio.CancelledError:
|
|
132
147
|
logger.warning(
|
|
133
|
-
|
|
148
|
+
"Async pool closure cancelled (shutdown in progress)",
|
|
149
|
+
pool=connection_key,
|
|
134
150
|
)
|
|
135
151
|
except Exception as e:
|
|
136
|
-
logger.error(
|
|
152
|
+
logger.error(
|
|
153
|
+
"Error closing async pool", pool=connection_key, error=str(e)
|
|
154
|
+
)
|
|
137
155
|
cls._pools.clear()
|
|
138
156
|
|
|
139
157
|
|
|
@@ -181,12 +199,16 @@ class AsyncPostgresStoreManager(StoreManagerBase):
|
|
|
181
199
|
await self._store.setup()
|
|
182
200
|
|
|
183
201
|
self._setup_complete = True
|
|
184
|
-
logger.
|
|
185
|
-
|
|
202
|
+
logger.success(
|
|
203
|
+
"Async PostgreSQL store initialized", store=self.store_model.name
|
|
186
204
|
)
|
|
187
205
|
|
|
188
206
|
except Exception as e:
|
|
189
|
-
logger.error(
|
|
207
|
+
logger.error(
|
|
208
|
+
"Error setting up async PostgreSQL store",
|
|
209
|
+
store=self.store_model.name,
|
|
210
|
+
error=str(e),
|
|
211
|
+
)
|
|
190
212
|
raise
|
|
191
213
|
|
|
192
214
|
|
|
@@ -244,12 +266,17 @@ class AsyncPostgresCheckpointerManager(CheckpointManagerBase):
|
|
|
244
266
|
await self._checkpointer.setup()
|
|
245
267
|
|
|
246
268
|
self._setup_complete = True
|
|
247
|
-
logger.
|
|
248
|
-
|
|
269
|
+
logger.success(
|
|
270
|
+
"Async PostgreSQL checkpointer initialized",
|
|
271
|
+
checkpointer=self.checkpointer_model.name,
|
|
249
272
|
)
|
|
250
273
|
|
|
251
274
|
except Exception as e:
|
|
252
|
-
logger.error(
|
|
275
|
+
logger.error(
|
|
276
|
+
"Error setting up async PostgreSQL checkpointer",
|
|
277
|
+
checkpointer=self.checkpointer_model.name,
|
|
278
|
+
error=str(e),
|
|
279
|
+
)
|
|
253
280
|
raise
|
|
254
281
|
|
|
255
282
|
|
|
@@ -269,10 +296,10 @@ class PostgresPoolManager:
|
|
|
269
296
|
|
|
270
297
|
with cls._lock:
|
|
271
298
|
if connection_key in cls._pools:
|
|
272
|
-
logger.
|
|
299
|
+
logger.trace("Reusing existing PostgreSQL pool", database=database.name)
|
|
273
300
|
return cls._pools[connection_key]
|
|
274
301
|
|
|
275
|
-
logger.debug(
|
|
302
|
+
logger.debug("Creating new PostgreSQL pool", database=database.name)
|
|
276
303
|
|
|
277
304
|
kwargs: dict[str, Any] = {
|
|
278
305
|
"row_factory": dict_row,
|
|
@@ -299,7 +326,7 @@ class PostgresPoolManager:
|
|
|
299
326
|
if connection_key in cls._pools:
|
|
300
327
|
pool = cls._pools.pop(connection_key)
|
|
301
328
|
pool.close()
|
|
302
|
-
logger.debug(
|
|
329
|
+
logger.debug("PostgreSQL pool closed", database=database.name)
|
|
303
330
|
|
|
304
331
|
@classmethod
|
|
305
332
|
def close_all_pools(cls):
|
|
@@ -307,9 +334,13 @@ class PostgresPoolManager:
|
|
|
307
334
|
for connection_key, pool in cls._pools.items():
|
|
308
335
|
try:
|
|
309
336
|
pool.close()
|
|
310
|
-
logger.debug(
|
|
337
|
+
logger.debug("PostgreSQL pool closed", pool=connection_key)
|
|
311
338
|
except Exception as e:
|
|
312
|
-
logger.error(
|
|
339
|
+
logger.error(
|
|
340
|
+
"Error closing PostgreSQL pool",
|
|
341
|
+
pool=connection_key,
|
|
342
|
+
error=str(e),
|
|
343
|
+
)
|
|
313
344
|
cls._pools.clear()
|
|
314
345
|
|
|
315
346
|
|
|
@@ -349,12 +380,14 @@ class PostgresStoreManager(StoreManagerBase):
|
|
|
349
380
|
self._store.setup()
|
|
350
381
|
|
|
351
382
|
self._setup_complete = True
|
|
352
|
-
logger.
|
|
353
|
-
f"PostgresStore initialized successfully for {self.store_model.name}"
|
|
354
|
-
)
|
|
383
|
+
logger.success("PostgreSQL store initialized", store=self.store_model.name)
|
|
355
384
|
|
|
356
385
|
except Exception as e:
|
|
357
|
-
logger.error(
|
|
386
|
+
logger.error(
|
|
387
|
+
"Error setting up PostgreSQL store",
|
|
388
|
+
store=self.store_model.name,
|
|
389
|
+
error=str(e),
|
|
390
|
+
)
|
|
358
391
|
raise
|
|
359
392
|
|
|
360
393
|
|
|
@@ -400,21 +433,28 @@ class PostgresCheckpointerManager(CheckpointManagerBase):
|
|
|
400
433
|
self._checkpointer.setup()
|
|
401
434
|
|
|
402
435
|
self._setup_complete = True
|
|
403
|
-
logger.
|
|
404
|
-
|
|
436
|
+
logger.success(
|
|
437
|
+
"PostgreSQL checkpointer initialized",
|
|
438
|
+
checkpointer=self.checkpointer_model.name,
|
|
405
439
|
)
|
|
406
440
|
|
|
407
441
|
except Exception as e:
|
|
408
|
-
logger.error(
|
|
442
|
+
logger.error(
|
|
443
|
+
"Error setting up PostgreSQL checkpointer",
|
|
444
|
+
checkpointer=self.checkpointer_model.name,
|
|
445
|
+
error=str(e),
|
|
446
|
+
)
|
|
409
447
|
raise
|
|
410
448
|
|
|
411
449
|
|
|
412
450
|
def _shutdown_pools() -> None:
|
|
413
451
|
try:
|
|
414
452
|
PostgresPoolManager.close_all_pools()
|
|
415
|
-
logger.debug("
|
|
453
|
+
logger.debug("All synchronous PostgreSQL pools closed during shutdown")
|
|
416
454
|
except Exception as e:
|
|
417
|
-
logger.error(
|
|
455
|
+
logger.error(
|
|
456
|
+
"Error closing synchronous PostgreSQL pools during shutdown", error=str(e)
|
|
457
|
+
)
|
|
418
458
|
|
|
419
459
|
|
|
420
460
|
def _shutdown_async_pools() -> None:
|
|
@@ -434,15 +474,16 @@ def _shutdown_async_pools() -> None:
|
|
|
434
474
|
loop = asyncio.new_event_loop()
|
|
435
475
|
asyncio.set_event_loop(loop)
|
|
436
476
|
loop.run_until_complete(AsyncPostgresPoolManager.close_all_pools())
|
|
437
|
-
logger.debug("
|
|
477
|
+
logger.debug("All asynchronous PostgreSQL pools closed during shutdown")
|
|
438
478
|
except Exception as inner_e:
|
|
439
479
|
# If all else fails, just log the error
|
|
440
480
|
logger.warning(
|
|
441
|
-
|
|
481
|
+
"Could not close async pools cleanly during shutdown",
|
|
482
|
+
error=str(inner_e),
|
|
442
483
|
)
|
|
443
484
|
except Exception as e:
|
|
444
485
|
logger.error(
|
|
445
|
-
|
|
486
|
+
"Error closing asynchronous PostgreSQL pools during shutdown", error=str(e)
|
|
446
487
|
)
|
|
447
488
|
|
|
448
489
|
|
dao_ai/middleware/assertions.py
CHANGED
|
@@ -367,20 +367,27 @@ class AssertMiddleware(AgentMiddleware[AgentState, Context]):
|
|
|
367
367
|
"runtime": runtime,
|
|
368
368
|
}
|
|
369
369
|
|
|
370
|
-
logger.
|
|
370
|
+
logger.trace(
|
|
371
|
+
"Evaluating Assert constraint", constraint_name=self.constraint.name
|
|
372
|
+
)
|
|
371
373
|
|
|
372
374
|
result = self.constraint.evaluate(response, context)
|
|
373
375
|
|
|
374
376
|
if result.passed:
|
|
375
|
-
logger.
|
|
377
|
+
logger.trace(
|
|
378
|
+
"Assert constraint passed", constraint_name=self.constraint.name
|
|
379
|
+
)
|
|
376
380
|
self._retry_count = 0
|
|
377
381
|
return None
|
|
378
382
|
|
|
379
383
|
# Constraint failed
|
|
380
384
|
self._retry_count += 1
|
|
381
385
|
logger.warning(
|
|
382
|
-
|
|
383
|
-
|
|
386
|
+
"Assert constraint failed",
|
|
387
|
+
constraint_name=self.constraint.name,
|
|
388
|
+
attempt=self._retry_count,
|
|
389
|
+
max_retries=self.max_retries,
|
|
390
|
+
feedback=result.feedback,
|
|
384
391
|
)
|
|
385
392
|
|
|
386
393
|
if self._retry_count >= self.max_retries:
|
|
@@ -396,7 +403,8 @@ class AssertMiddleware(AgentMiddleware[AgentState, Context]):
|
|
|
396
403
|
return None
|
|
397
404
|
else: # "pass"
|
|
398
405
|
logger.warning(
|
|
399
|
-
|
|
406
|
+
"Assert constraint failed but passing through",
|
|
407
|
+
constraint_name=self.constraint.name,
|
|
400
408
|
)
|
|
401
409
|
return None
|
|
402
410
|
|
|
@@ -475,25 +483,38 @@ class SuggestMiddleware(AgentMiddleware[AgentState, Context]):
|
|
|
475
483
|
"runtime": runtime,
|
|
476
484
|
}
|
|
477
485
|
|
|
478
|
-
logger.
|
|
486
|
+
logger.trace(
|
|
487
|
+
"Evaluating Suggest constraint", constraint_name=self.constraint.name
|
|
488
|
+
)
|
|
479
489
|
|
|
480
490
|
result = self.constraint.evaluate(response, context)
|
|
481
491
|
|
|
482
492
|
if result.passed:
|
|
483
|
-
logger.
|
|
493
|
+
logger.trace(
|
|
494
|
+
"Suggest constraint passed", constraint_name=self.constraint.name
|
|
495
|
+
)
|
|
484
496
|
self._has_retried = False
|
|
485
497
|
return None
|
|
486
498
|
|
|
487
499
|
# Log feedback based on configured level
|
|
488
|
-
log_msg = (
|
|
489
|
-
f"Suggest constraint '{self.constraint.name}' feedback: {result.feedback}"
|
|
490
|
-
)
|
|
491
500
|
if self.log_level == "warning":
|
|
492
|
-
logger.warning(
|
|
501
|
+
logger.warning(
|
|
502
|
+
"Suggest constraint feedback",
|
|
503
|
+
constraint_name=self.constraint.name,
|
|
504
|
+
feedback=result.feedback,
|
|
505
|
+
)
|
|
493
506
|
elif self.log_level == "info":
|
|
494
|
-
logger.info(
|
|
507
|
+
logger.info(
|
|
508
|
+
"Suggest constraint feedback",
|
|
509
|
+
constraint_name=self.constraint.name,
|
|
510
|
+
feedback=result.feedback,
|
|
511
|
+
)
|
|
495
512
|
else:
|
|
496
|
-
logger.debug(
|
|
513
|
+
logger.debug(
|
|
514
|
+
"Suggest constraint feedback",
|
|
515
|
+
constraint_name=self.constraint.name,
|
|
516
|
+
feedback=result.feedback,
|
|
517
|
+
)
|
|
497
518
|
|
|
498
519
|
# Optionally request one improvement
|
|
499
520
|
if self.allow_one_retry and not self._has_retried:
|
|
@@ -593,8 +614,11 @@ class RefineMiddleware(AgentMiddleware[AgentState, Context]):
|
|
|
593
614
|
self._iteration += 1
|
|
594
615
|
|
|
595
616
|
logger.debug(
|
|
596
|
-
|
|
597
|
-
|
|
617
|
+
"Refine iteration",
|
|
618
|
+
iteration=self._iteration,
|
|
619
|
+
max_iterations=self.max_iterations,
|
|
620
|
+
score=f"{score:.3f}",
|
|
621
|
+
threshold=self.threshold,
|
|
598
622
|
)
|
|
599
623
|
|
|
600
624
|
# Track best response
|
|
@@ -604,13 +628,17 @@ class RefineMiddleware(AgentMiddleware[AgentState, Context]):
|
|
|
604
628
|
|
|
605
629
|
# Check if we should stop
|
|
606
630
|
if score >= self.threshold:
|
|
607
|
-
logger.debug(
|
|
631
|
+
logger.debug(
|
|
632
|
+
"Refine threshold reached",
|
|
633
|
+
score=f"{score:.3f}",
|
|
634
|
+
threshold=self.threshold,
|
|
635
|
+
)
|
|
608
636
|
self._reset()
|
|
609
637
|
return None
|
|
610
638
|
|
|
611
639
|
if self._iteration >= self.max_iterations:
|
|
612
640
|
logger.debug(
|
|
613
|
-
|
|
641
|
+
"Refine max iterations reached", best_score=f"{self._best_score:.3f}"
|
|
614
642
|
)
|
|
615
643
|
# Use best response if tracking
|
|
616
644
|
if self.select_best and self._best_response:
|