dao-ai 0.0.25__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/__init__.py +29 -0
- dao_ai/agent_as_code.py +5 -5
- dao_ai/cli.py +245 -40
- dao_ai/config.py +1863 -338
- dao_ai/genie/__init__.py +38 -0
- dao_ai/genie/cache/__init__.py +43 -0
- dao_ai/genie/cache/base.py +72 -0
- dao_ai/genie/cache/core.py +79 -0
- dao_ai/genie/cache/lru.py +347 -0
- dao_ai/genie/cache/semantic.py +970 -0
- dao_ai/genie/core.py +35 -0
- dao_ai/graph.py +27 -228
- dao_ai/hooks/__init__.py +9 -6
- dao_ai/hooks/core.py +27 -195
- dao_ai/logging.py +56 -0
- dao_ai/memory/__init__.py +10 -0
- dao_ai/memory/core.py +65 -30
- dao_ai/memory/databricks.py +402 -0
- dao_ai/memory/postgres.py +79 -38
- dao_ai/messages.py +6 -4
- dao_ai/middleware/__init__.py +125 -0
- dao_ai/middleware/assertions.py +806 -0
- dao_ai/middleware/base.py +50 -0
- dao_ai/middleware/core.py +67 -0
- dao_ai/middleware/guardrails.py +420 -0
- dao_ai/middleware/human_in_the_loop.py +232 -0
- dao_ai/middleware/message_validation.py +586 -0
- dao_ai/middleware/summarization.py +197 -0
- dao_ai/models.py +1306 -114
- dao_ai/nodes.py +261 -166
- dao_ai/optimization.py +674 -0
- dao_ai/orchestration/__init__.py +52 -0
- dao_ai/orchestration/core.py +294 -0
- dao_ai/orchestration/supervisor.py +278 -0
- dao_ai/orchestration/swarm.py +271 -0
- dao_ai/prompts.py +128 -31
- dao_ai/providers/databricks.py +645 -172
- dao_ai/state.py +157 -21
- dao_ai/tools/__init__.py +13 -5
- dao_ai/tools/agent.py +1 -3
- dao_ai/tools/core.py +64 -11
- dao_ai/tools/email.py +232 -0
- dao_ai/tools/genie.py +144 -295
- dao_ai/tools/mcp.py +220 -133
- dao_ai/tools/memory.py +50 -0
- dao_ai/tools/python.py +9 -14
- dao_ai/tools/search.py +14 -0
- dao_ai/tools/slack.py +22 -10
- dao_ai/tools/sql.py +202 -0
- dao_ai/tools/time.py +30 -7
- dao_ai/tools/unity_catalog.py +165 -88
- dao_ai/tools/vector_search.py +360 -40
- dao_ai/utils.py +218 -16
- dao_ai-0.1.2.dist-info/METADATA +455 -0
- dao_ai-0.1.2.dist-info/RECORD +64 -0
- {dao_ai-0.0.25.dist-info → dao_ai-0.1.2.dist-info}/WHEEL +1 -1
- dao_ai/chat_models.py +0 -204
- dao_ai/guardrails.py +0 -112
- dao_ai/tools/human_in_the_loop.py +0 -100
- dao_ai-0.0.25.dist-info/METADATA +0 -1165
- dao_ai-0.0.25.dist-info/RECORD +0 -41
- {dao_ai-0.0.25.dist-info → dao_ai-0.1.2.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.0.25.dist-info → dao_ai-0.1.2.dist-info}/licenses/LICENSE +0 -0
dao_ai/memory/postgres.py
CHANGED
|
@@ -28,9 +28,10 @@ def _create_pool(
|
|
|
28
28
|
kwargs: dict,
|
|
29
29
|
) -> ConnectionPool:
|
|
30
30
|
"""Create a connection pool using the provided connection parameters."""
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
|
|
31
|
+
safe_params = {
|
|
32
|
+
k: (str(v) if k != "password" else "***") for k, v in connection_params.items()
|
|
33
|
+
}
|
|
34
|
+
logger.debug("Creating connection pool", database=database_name, **safe_params)
|
|
34
35
|
|
|
35
36
|
# Merge connection_params into kwargs for psycopg
|
|
36
37
|
connection_kwargs = kwargs | connection_params
|
|
@@ -43,7 +44,11 @@ def _create_pool(
|
|
|
43
44
|
kwargs=connection_kwargs,
|
|
44
45
|
)
|
|
45
46
|
pool.open(wait=True, timeout=timeout_seconds)
|
|
46
|
-
logger.
|
|
47
|
+
logger.success(
|
|
48
|
+
"PostgreSQL connection pool created",
|
|
49
|
+
database=database_name,
|
|
50
|
+
pool_size=max_pool_size,
|
|
51
|
+
)
|
|
47
52
|
return pool
|
|
48
53
|
|
|
49
54
|
|
|
@@ -55,8 +60,11 @@ async def _create_async_pool(
|
|
|
55
60
|
kwargs: dict,
|
|
56
61
|
) -> AsyncConnectionPool:
|
|
57
62
|
"""Create an async connection pool using the provided connection parameters."""
|
|
63
|
+
safe_params = {
|
|
64
|
+
k: (str(v) if k != "password" else "***") for k, v in connection_params.items()
|
|
65
|
+
}
|
|
58
66
|
logger.debug(
|
|
59
|
-
|
|
67
|
+
"Creating async connection pool", database=database_name, **safe_params
|
|
60
68
|
)
|
|
61
69
|
|
|
62
70
|
# Merge connection_params into kwargs for psycopg
|
|
@@ -69,7 +77,11 @@ async def _create_async_pool(
|
|
|
69
77
|
kwargs=connection_kwargs,
|
|
70
78
|
)
|
|
71
79
|
await pool.open(wait=True, timeout=timeout_seconds)
|
|
72
|
-
logger.
|
|
80
|
+
logger.success(
|
|
81
|
+
"Async PostgreSQL connection pool created",
|
|
82
|
+
database=database_name,
|
|
83
|
+
pool_size=max_pool_size,
|
|
84
|
+
)
|
|
73
85
|
return pool
|
|
74
86
|
|
|
75
87
|
|
|
@@ -84,10 +96,12 @@ class AsyncPostgresPoolManager:
|
|
|
84
96
|
|
|
85
97
|
async with cls._lock:
|
|
86
98
|
if connection_key in cls._pools:
|
|
87
|
-
logger.
|
|
99
|
+
logger.trace(
|
|
100
|
+
"Reusing existing async PostgreSQL pool", database=database.name
|
|
101
|
+
)
|
|
88
102
|
return cls._pools[connection_key]
|
|
89
103
|
|
|
90
|
-
logger.debug(
|
|
104
|
+
logger.debug("Creating new async PostgreSQL pool", database=database.name)
|
|
91
105
|
|
|
92
106
|
kwargs: dict[str, Any] = {
|
|
93
107
|
"row_factory": dict_row,
|
|
@@ -114,7 +128,7 @@ class AsyncPostgresPoolManager:
|
|
|
114
128
|
if connection_key in cls._pools:
|
|
115
129
|
pool = cls._pools.pop(connection_key)
|
|
116
130
|
await pool.close()
|
|
117
|
-
logger.debug(
|
|
131
|
+
logger.debug("Async PostgreSQL pool closed", database=database.name)
|
|
118
132
|
|
|
119
133
|
@classmethod
|
|
120
134
|
async def close_all_pools(cls):
|
|
@@ -123,17 +137,21 @@ class AsyncPostgresPoolManager:
|
|
|
123
137
|
try:
|
|
124
138
|
# Use a short timeout to avoid blocking on pool closure
|
|
125
139
|
await asyncio.wait_for(pool.close(), timeout=2.0)
|
|
126
|
-
logger.debug(
|
|
140
|
+
logger.debug("Async PostgreSQL pool closed", pool=connection_key)
|
|
127
141
|
except asyncio.TimeoutError:
|
|
128
142
|
logger.warning(
|
|
129
|
-
|
|
143
|
+
"Timeout closing async pool, forcing closure",
|
|
144
|
+
pool=connection_key,
|
|
130
145
|
)
|
|
131
146
|
except asyncio.CancelledError:
|
|
132
147
|
logger.warning(
|
|
133
|
-
|
|
148
|
+
"Async pool closure cancelled (shutdown in progress)",
|
|
149
|
+
pool=connection_key,
|
|
134
150
|
)
|
|
135
151
|
except Exception as e:
|
|
136
|
-
logger.error(
|
|
152
|
+
logger.error(
|
|
153
|
+
"Error closing async pool", pool=connection_key, error=str(e)
|
|
154
|
+
)
|
|
137
155
|
cls._pools.clear()
|
|
138
156
|
|
|
139
157
|
|
|
@@ -181,12 +199,16 @@ class AsyncPostgresStoreManager(StoreManagerBase):
|
|
|
181
199
|
await self._store.setup()
|
|
182
200
|
|
|
183
201
|
self._setup_complete = True
|
|
184
|
-
logger.
|
|
185
|
-
|
|
202
|
+
logger.success(
|
|
203
|
+
"Async PostgreSQL store initialized", store=self.store_model.name
|
|
186
204
|
)
|
|
187
205
|
|
|
188
206
|
except Exception as e:
|
|
189
|
-
logger.error(
|
|
207
|
+
logger.error(
|
|
208
|
+
"Error setting up async PostgreSQL store",
|
|
209
|
+
store=self.store_model.name,
|
|
210
|
+
error=str(e),
|
|
211
|
+
)
|
|
190
212
|
raise
|
|
191
213
|
|
|
192
214
|
|
|
@@ -244,12 +266,17 @@ class AsyncPostgresCheckpointerManager(CheckpointManagerBase):
|
|
|
244
266
|
await self._checkpointer.setup()
|
|
245
267
|
|
|
246
268
|
self._setup_complete = True
|
|
247
|
-
logger.
|
|
248
|
-
|
|
269
|
+
logger.success(
|
|
270
|
+
"Async PostgreSQL checkpointer initialized",
|
|
271
|
+
checkpointer=self.checkpointer_model.name,
|
|
249
272
|
)
|
|
250
273
|
|
|
251
274
|
except Exception as e:
|
|
252
|
-
logger.error(
|
|
275
|
+
logger.error(
|
|
276
|
+
"Error setting up async PostgreSQL checkpointer",
|
|
277
|
+
checkpointer=self.checkpointer_model.name,
|
|
278
|
+
error=str(e),
|
|
279
|
+
)
|
|
253
280
|
raise
|
|
254
281
|
|
|
255
282
|
|
|
@@ -269,10 +296,10 @@ class PostgresPoolManager:
|
|
|
269
296
|
|
|
270
297
|
with cls._lock:
|
|
271
298
|
if connection_key in cls._pools:
|
|
272
|
-
logger.
|
|
299
|
+
logger.trace("Reusing existing PostgreSQL pool", database=database.name)
|
|
273
300
|
return cls._pools[connection_key]
|
|
274
301
|
|
|
275
|
-
logger.debug(
|
|
302
|
+
logger.debug("Creating new PostgreSQL pool", database=database.name)
|
|
276
303
|
|
|
277
304
|
kwargs: dict[str, Any] = {
|
|
278
305
|
"row_factory": dict_row,
|
|
@@ -299,7 +326,7 @@ class PostgresPoolManager:
|
|
|
299
326
|
if connection_key in cls._pools:
|
|
300
327
|
pool = cls._pools.pop(connection_key)
|
|
301
328
|
pool.close()
|
|
302
|
-
logger.debug(
|
|
329
|
+
logger.debug("PostgreSQL pool closed", database=database.name)
|
|
303
330
|
|
|
304
331
|
@classmethod
|
|
305
332
|
def close_all_pools(cls):
|
|
@@ -307,9 +334,13 @@ class PostgresPoolManager:
|
|
|
307
334
|
for connection_key, pool in cls._pools.items():
|
|
308
335
|
try:
|
|
309
336
|
pool.close()
|
|
310
|
-
logger.debug(
|
|
337
|
+
logger.debug("PostgreSQL pool closed", pool=connection_key)
|
|
311
338
|
except Exception as e:
|
|
312
|
-
logger.error(
|
|
339
|
+
logger.error(
|
|
340
|
+
"Error closing PostgreSQL pool",
|
|
341
|
+
pool=connection_key,
|
|
342
|
+
error=str(e),
|
|
343
|
+
)
|
|
313
344
|
cls._pools.clear()
|
|
314
345
|
|
|
315
346
|
|
|
@@ -349,12 +380,14 @@ class PostgresStoreManager(StoreManagerBase):
|
|
|
349
380
|
self._store.setup()
|
|
350
381
|
|
|
351
382
|
self._setup_complete = True
|
|
352
|
-
logger.
|
|
353
|
-
f"PostgresStore initialized successfully for {self.store_model.name}"
|
|
354
|
-
)
|
|
383
|
+
logger.success("PostgreSQL store initialized", store=self.store_model.name)
|
|
355
384
|
|
|
356
385
|
except Exception as e:
|
|
357
|
-
logger.error(
|
|
386
|
+
logger.error(
|
|
387
|
+
"Error setting up PostgreSQL store",
|
|
388
|
+
store=self.store_model.name,
|
|
389
|
+
error=str(e),
|
|
390
|
+
)
|
|
358
391
|
raise
|
|
359
392
|
|
|
360
393
|
|
|
@@ -400,24 +433,31 @@ class PostgresCheckpointerManager(CheckpointManagerBase):
|
|
|
400
433
|
self._checkpointer.setup()
|
|
401
434
|
|
|
402
435
|
self._setup_complete = True
|
|
403
|
-
logger.
|
|
404
|
-
|
|
436
|
+
logger.success(
|
|
437
|
+
"PostgreSQL checkpointer initialized",
|
|
438
|
+
checkpointer=self.checkpointer_model.name,
|
|
405
439
|
)
|
|
406
440
|
|
|
407
441
|
except Exception as e:
|
|
408
|
-
logger.error(
|
|
442
|
+
logger.error(
|
|
443
|
+
"Error setting up PostgreSQL checkpointer",
|
|
444
|
+
checkpointer=self.checkpointer_model.name,
|
|
445
|
+
error=str(e),
|
|
446
|
+
)
|
|
409
447
|
raise
|
|
410
448
|
|
|
411
449
|
|
|
412
|
-
def _shutdown_pools():
|
|
450
|
+
def _shutdown_pools() -> None:
|
|
413
451
|
try:
|
|
414
452
|
PostgresPoolManager.close_all_pools()
|
|
415
|
-
logger.debug("
|
|
453
|
+
logger.debug("All synchronous PostgreSQL pools closed during shutdown")
|
|
416
454
|
except Exception as e:
|
|
417
|
-
logger.error(
|
|
455
|
+
logger.error(
|
|
456
|
+
"Error closing synchronous PostgreSQL pools during shutdown", error=str(e)
|
|
457
|
+
)
|
|
418
458
|
|
|
419
459
|
|
|
420
|
-
def _shutdown_async_pools():
|
|
460
|
+
def _shutdown_async_pools() -> None:
|
|
421
461
|
try:
|
|
422
462
|
# Try to get the current event loop first
|
|
423
463
|
try:
|
|
@@ -434,15 +474,16 @@ def _shutdown_async_pools():
|
|
|
434
474
|
loop = asyncio.new_event_loop()
|
|
435
475
|
asyncio.set_event_loop(loop)
|
|
436
476
|
loop.run_until_complete(AsyncPostgresPoolManager.close_all_pools())
|
|
437
|
-
logger.debug("
|
|
477
|
+
logger.debug("All asynchronous PostgreSQL pools closed during shutdown")
|
|
438
478
|
except Exception as inner_e:
|
|
439
479
|
# If all else fails, just log the error
|
|
440
480
|
logger.warning(
|
|
441
|
-
|
|
481
|
+
"Could not close async pools cleanly during shutdown",
|
|
482
|
+
error=str(inner_e),
|
|
442
483
|
)
|
|
443
484
|
except Exception as e:
|
|
444
485
|
logger.error(
|
|
445
|
-
|
|
486
|
+
"Error closing asynchronous PostgreSQL pools during shutdown", error=str(e)
|
|
446
487
|
)
|
|
447
488
|
|
|
448
489
|
|
dao_ai/messages.py
CHANGED
|
@@ -125,8 +125,6 @@ def has_image(messages: BaseMessage | Sequence[BaseMessage]) -> bool:
|
|
|
125
125
|
"image_url",
|
|
126
126
|
]:
|
|
127
127
|
return True
|
|
128
|
-
if hasattr(item, "type") and item.type in ["image", "image_url"]:
|
|
129
|
-
return True
|
|
130
128
|
return False
|
|
131
129
|
|
|
132
130
|
if isinstance(messages, BaseMessage):
|
|
@@ -176,7 +174,9 @@ def last_human_message(messages: Sequence[BaseMessage]) -> Optional[HumanMessage
|
|
|
176
174
|
Returns:
|
|
177
175
|
The last HumanMessage in the sequence, or None if no human messages found
|
|
178
176
|
"""
|
|
179
|
-
return last_message(
|
|
177
|
+
return last_message(
|
|
178
|
+
messages, lambda m: isinstance(m, HumanMessage) and bool(m.content)
|
|
179
|
+
)
|
|
180
180
|
|
|
181
181
|
|
|
182
182
|
def last_ai_message(messages: Sequence[BaseMessage]) -> Optional[AIMessage]:
|
|
@@ -192,7 +192,9 @@ def last_ai_message(messages: Sequence[BaseMessage]) -> Optional[AIMessage]:
|
|
|
192
192
|
Returns:
|
|
193
193
|
The last AIMessage in the sequence, or None if no AI messages found
|
|
194
194
|
"""
|
|
195
|
-
return last_message(
|
|
195
|
+
return last_message(
|
|
196
|
+
messages, lambda m: isinstance(m, AIMessage) and bool(m.content)
|
|
197
|
+
)
|
|
196
198
|
|
|
197
199
|
|
|
198
200
|
def last_tool_message(messages: Sequence[BaseMessage]) -> Optional[ToolMessage]:
|
|
@@ -0,0 +1,125 @@
|
|
|
1
|
+
# DAO AI Middleware Module
|
|
2
|
+
# This module provides middleware implementations compatible with LangChain v1's create_agent
|
|
3
|
+
|
|
4
|
+
# Re-export LangChain built-in middleware
|
|
5
|
+
from langchain.agents.middleware import (
|
|
6
|
+
HumanInTheLoopMiddleware,
|
|
7
|
+
SummarizationMiddleware,
|
|
8
|
+
after_agent,
|
|
9
|
+
after_model,
|
|
10
|
+
before_agent,
|
|
11
|
+
before_model,
|
|
12
|
+
dynamic_prompt,
|
|
13
|
+
wrap_model_call,
|
|
14
|
+
wrap_tool_call,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
# DSPy-style assertion middleware
|
|
18
|
+
from dao_ai.middleware.assertions import (
|
|
19
|
+
# Middleware classes
|
|
20
|
+
AssertMiddleware,
|
|
21
|
+
# Types
|
|
22
|
+
Constraint,
|
|
23
|
+
ConstraintResult,
|
|
24
|
+
FunctionConstraint,
|
|
25
|
+
KeywordConstraint,
|
|
26
|
+
LengthConstraint,
|
|
27
|
+
LLMConstraint,
|
|
28
|
+
RefineMiddleware,
|
|
29
|
+
SuggestMiddleware,
|
|
30
|
+
# Factory functions
|
|
31
|
+
create_assert_middleware,
|
|
32
|
+
create_refine_middleware,
|
|
33
|
+
create_suggest_middleware,
|
|
34
|
+
)
|
|
35
|
+
from dao_ai.middleware.base import (
|
|
36
|
+
AgentMiddleware,
|
|
37
|
+
ModelRequest,
|
|
38
|
+
ModelResponse,
|
|
39
|
+
)
|
|
40
|
+
from dao_ai.middleware.core import create_factory_middleware
|
|
41
|
+
from dao_ai.middleware.guardrails import (
|
|
42
|
+
ContentFilterMiddleware,
|
|
43
|
+
GuardrailMiddleware,
|
|
44
|
+
SafetyGuardrailMiddleware,
|
|
45
|
+
create_content_filter_middleware,
|
|
46
|
+
create_guardrail_middleware,
|
|
47
|
+
create_safety_guardrail_middleware,
|
|
48
|
+
)
|
|
49
|
+
from dao_ai.middleware.human_in_the_loop import (
|
|
50
|
+
create_hitl_middleware_from_tool_models,
|
|
51
|
+
create_human_in_the_loop_middleware,
|
|
52
|
+
)
|
|
53
|
+
from dao_ai.middleware.message_validation import (
|
|
54
|
+
CustomFieldValidationMiddleware,
|
|
55
|
+
FilterLastHumanMessageMiddleware,
|
|
56
|
+
MessageValidationMiddleware,
|
|
57
|
+
RequiredField,
|
|
58
|
+
ThreadIdValidationMiddleware,
|
|
59
|
+
UserIdValidationMiddleware,
|
|
60
|
+
create_custom_field_validation_middleware,
|
|
61
|
+
create_filter_last_human_message_middleware,
|
|
62
|
+
create_thread_id_validation_middleware,
|
|
63
|
+
create_user_id_validation_middleware,
|
|
64
|
+
)
|
|
65
|
+
from dao_ai.middleware.summarization import (
|
|
66
|
+
LoggingSummarizationMiddleware,
|
|
67
|
+
create_summarization_middleware,
|
|
68
|
+
)
|
|
69
|
+
|
|
70
|
+
__all__ = [
|
|
71
|
+
# Base class (from LangChain)
|
|
72
|
+
"AgentMiddleware",
|
|
73
|
+
# Types
|
|
74
|
+
"ModelRequest",
|
|
75
|
+
"ModelResponse",
|
|
76
|
+
# LangChain decorators
|
|
77
|
+
"before_agent",
|
|
78
|
+
"before_model",
|
|
79
|
+
"after_agent",
|
|
80
|
+
"after_model",
|
|
81
|
+
"wrap_model_call",
|
|
82
|
+
"wrap_tool_call",
|
|
83
|
+
"dynamic_prompt",
|
|
84
|
+
# LangChain built-in middleware
|
|
85
|
+
"SummarizationMiddleware",
|
|
86
|
+
"LoggingSummarizationMiddleware",
|
|
87
|
+
"HumanInTheLoopMiddleware",
|
|
88
|
+
# Core factory function
|
|
89
|
+
"create_factory_middleware",
|
|
90
|
+
# DAO AI middleware implementations
|
|
91
|
+
"GuardrailMiddleware",
|
|
92
|
+
"ContentFilterMiddleware",
|
|
93
|
+
"SafetyGuardrailMiddleware",
|
|
94
|
+
"MessageValidationMiddleware",
|
|
95
|
+
"UserIdValidationMiddleware",
|
|
96
|
+
"ThreadIdValidationMiddleware",
|
|
97
|
+
"CustomFieldValidationMiddleware",
|
|
98
|
+
"RequiredField",
|
|
99
|
+
"FilterLastHumanMessageMiddleware",
|
|
100
|
+
# DSPy-style assertion middleware
|
|
101
|
+
"Constraint",
|
|
102
|
+
"ConstraintResult",
|
|
103
|
+
"FunctionConstraint",
|
|
104
|
+
"KeywordConstraint",
|
|
105
|
+
"LengthConstraint",
|
|
106
|
+
"LLMConstraint",
|
|
107
|
+
"AssertMiddleware",
|
|
108
|
+
"SuggestMiddleware",
|
|
109
|
+
"RefineMiddleware",
|
|
110
|
+
# DAO AI middleware factory functions
|
|
111
|
+
"create_guardrail_middleware",
|
|
112
|
+
"create_content_filter_middleware",
|
|
113
|
+
"create_safety_guardrail_middleware",
|
|
114
|
+
"create_user_id_validation_middleware",
|
|
115
|
+
"create_thread_id_validation_middleware",
|
|
116
|
+
"create_custom_field_validation_middleware",
|
|
117
|
+
"create_filter_last_human_message_middleware",
|
|
118
|
+
"create_summarization_middleware",
|
|
119
|
+
"create_human_in_the_loop_middleware",
|
|
120
|
+
"create_hitl_middleware_from_tool_models",
|
|
121
|
+
# DSPy-style assertion factory functions
|
|
122
|
+
"create_assert_middleware",
|
|
123
|
+
"create_suggest_middleware",
|
|
124
|
+
"create_refine_middleware",
|
|
125
|
+
]
|