dao-ai 0.0.36__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
- dao_ai/__init__.py +29 -0
- dao_ai/cli.py +195 -30
- dao_ai/config.py +770 -244
- dao_ai/genie/__init__.py +1 -22
- dao_ai/genie/cache/__init__.py +1 -2
- dao_ai/genie/cache/base.py +20 -70
- dao_ai/genie/cache/core.py +75 -0
- dao_ai/genie/cache/lru.py +44 -21
- dao_ai/genie/cache/semantic.py +390 -109
- dao_ai/genie/core.py +35 -0
- dao_ai/graph.py +27 -253
- dao_ai/hooks/__init__.py +9 -6
- dao_ai/hooks/core.py +22 -190
- dao_ai/memory/__init__.py +10 -0
- dao_ai/memory/core.py +23 -5
- dao_ai/memory/databricks.py +389 -0
- dao_ai/memory/postgres.py +2 -2
- dao_ai/messages.py +6 -4
- dao_ai/middleware/__init__.py +125 -0
- dao_ai/middleware/assertions.py +778 -0
- dao_ai/middleware/base.py +50 -0
- dao_ai/middleware/core.py +61 -0
- dao_ai/middleware/guardrails.py +415 -0
- dao_ai/middleware/human_in_the_loop.py +228 -0
- dao_ai/middleware/message_validation.py +554 -0
- dao_ai/middleware/summarization.py +192 -0
- dao_ai/models.py +1177 -108
- dao_ai/nodes.py +118 -161
- dao_ai/optimization.py +664 -0
- dao_ai/orchestration/__init__.py +52 -0
- dao_ai/orchestration/core.py +287 -0
- dao_ai/orchestration/supervisor.py +264 -0
- dao_ai/orchestration/swarm.py +226 -0
- dao_ai/prompts.py +126 -29
- dao_ai/providers/databricks.py +126 -381
- dao_ai/state.py +139 -21
- dao_ai/tools/__init__.py +8 -5
- dao_ai/tools/core.py +57 -4
- dao_ai/tools/email.py +280 -0
- dao_ai/tools/genie.py +47 -24
- dao_ai/tools/mcp.py +4 -3
- dao_ai/tools/memory.py +50 -0
- dao_ai/tools/python.py +4 -12
- dao_ai/tools/search.py +14 -0
- dao_ai/tools/slack.py +1 -1
- dao_ai/tools/unity_catalog.py +8 -6
- dao_ai/tools/vector_search.py +16 -9
- dao_ai/utils.py +72 -8
- dao_ai-0.1.1.dist-info/METADATA +1878 -0
- dao_ai-0.1.1.dist-info/RECORD +62 -0
- dao_ai/chat_models.py +0 -204
- dao_ai/guardrails.py +0 -112
- dao_ai/tools/genie/__init__.py +0 -236
- dao_ai/tools/human_in_the_loop.py +0 -100
- dao_ai-0.0.36.dist-info/METADATA +0 -951
- dao_ai-0.0.36.dist-info/RECORD +0 -47
- {dao_ai-0.0.36.dist-info → dao_ai-0.1.1.dist-info}/WHEEL +0 -0
- {dao_ai-0.0.36.dist-info → dao_ai-0.1.1.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.0.36.dist-info → dao_ai-0.1.1.dist-info}/licenses/LICENSE +0 -0
dao_ai/memory/databricks.py ADDED
@@ -0,0 +1,389 @@
+"""
+Databricks-native memory storage implementations.
+
+Provides CheckpointSaver and DatabricksStore implementations using
+Databricks Lakebase for persistent storage, with async support.
+
+See:
+- https://api-docs.databricks.com/python/databricks-ai-bridge/latest/databricks_langchain.html#databricks_langchain.CheckpointSaver
+- https://api-docs.databricks.com/python/databricks-ai-bridge/latest/databricks_langchain.html#databricks_langchain.DatabricksStore
+"""
+
+import asyncio
+from collections.abc import AsyncIterator, Iterable, Sequence
+from functools import partial
+from typing import Any, Literal
+
+from databricks_langchain import (
+    CheckpointSaver as DatabricksCheckpointSaver,
+)
+from databricks_langchain import (
+    DatabricksEmbeddings,
+    DatabricksStore,
+)
+from langchain_core.runnables import RunnableConfig
+from langgraph.checkpoint.base import (
+    BaseCheckpointSaver,
+    ChannelVersions,
+    Checkpoint,
+    CheckpointMetadata,
+    CheckpointTuple,
+)
+from langgraph.store.base import BaseStore, Item, Op, Result, SearchItem
+from loguru import logger
+
+from dao_ai.config import (
+    CheckpointerModel,
+    StoreModel,
+)
+from dao_ai.memory.base import (
+    CheckpointManagerBase,
+    StoreManagerBase,
+)
+
+# Type alias for namespace path
+NamespacePath = tuple[str, ...]
+
+# Sentinel for not-provided values
+NOT_PROVIDED = object()
+
+
+class AsyncDatabricksCheckpointSaver(DatabricksCheckpointSaver):
+    """
+    Async wrapper for DatabricksCheckpointSaver.
+
+    Provides async implementations of checkpoint methods by delegating
+    to the sync methods using asyncio.to_thread().
+    """
+
+    async def aget_tuple(self, config: RunnableConfig) -> CheckpointTuple | None:
+        """Async version of get_tuple."""
+        thread_id = config.get("configurable", {}).get("thread_id", "unknown")
+        logger.debug(f"aget_tuple: Fetching checkpoint for thread_id={thread_id}")
+        result = await asyncio.to_thread(self.get_tuple, config)
+        if result:
+            logger.debug(f"aget_tuple: Found checkpoint for thread_id={thread_id}")
+        else:
+            logger.debug(f"aget_tuple: No checkpoint found for thread_id={thread_id}")
+        return result
+
+    async def aget(self, config: RunnableConfig) -> Checkpoint | None:
+        """Async version of get."""
+        thread_id = config.get("configurable", {}).get("thread_id", "unknown")
+        logger.debug(f"aget: Fetching checkpoint for thread_id={thread_id}")
+        result = await asyncio.to_thread(self.get, config)
+        if result:
+            logger.debug(f"aget: Found checkpoint for thread_id={thread_id}")
+        else:
+            logger.debug(f"aget: No checkpoint found for thread_id={thread_id}")
+        return result
+
+    async def aput(
+        self,
+        config: RunnableConfig,
+        checkpoint: Checkpoint,
+        metadata: CheckpointMetadata,
+        new_versions: ChannelVersions,
+    ) -> RunnableConfig:
+        """Async version of put."""
+        thread_id = config.get("configurable", {}).get("thread_id", "unknown")
+        checkpoint_id = checkpoint.get("id", "unknown")
+        logger.debug(
+            f"aput: Saving checkpoint id={checkpoint_id} for thread_id={thread_id}"
+        )
+        result = await asyncio.to_thread(
+            self.put, config, checkpoint, metadata, new_versions
+        )
+        logger.debug(f"aput: Checkpoint saved for thread_id={thread_id}")
+        return result
+
+    async def aput_writes(
+        self,
+        config: RunnableConfig,
+        writes: Sequence[tuple[str, Any]],
+        task_id: str,
+        task_path: str = "",
+    ) -> None:
+        """Async version of put_writes."""
+        thread_id = config.get("configurable", {}).get("thread_id", "unknown")
+        logger.debug(
+            f"aput_writes: Saving {len(writes)} writes for thread_id={thread_id}, "
+            f"task_id={task_id}"
+        )
+        await asyncio.to_thread(self.put_writes, config, writes, task_id, task_path)
+        logger.debug(f"aput_writes: Writes saved for thread_id={thread_id}")
+
+    async def alist(
+        self,
+        config: RunnableConfig | None,
+        *,
+        filter: dict[str, Any] | None = None,
+        before: RunnableConfig | None = None,
+        limit: int | None = None,
+    ) -> AsyncIterator[CheckpointTuple]:
+        """Async version of list."""
+        thread_id = (
+            config.get("configurable", {}).get("thread_id", "unknown")
+            if config
+            else "all"
+        )
+        logger.debug(
+            f"alist: Listing checkpoints for thread_id={thread_id}, limit={limit}"
+        )
+        # Get all items from sync iterator in a thread
+        items = await asyncio.to_thread(
+            lambda: list(self.list(config, filter=filter, before=before, limit=limit))
+        )
+        logger.debug(f"alist: Found {len(items)} checkpoints for thread_id={thread_id}")
+        for item in items:
+            yield item
+
+    async def adelete_thread(self, thread_id: str) -> None:
+        """Async version of delete_thread."""
+        logger.debug(f"adelete_thread: Deleting thread_id={thread_id}")
+        await asyncio.to_thread(self.delete_thread, thread_id)
+        logger.debug(f"adelete_thread: Thread deleted thread_id={thread_id}")
+
+
+class AsyncDatabricksStore(DatabricksStore):
+    """
+    Async wrapper for DatabricksStore.
+
+    Provides async implementations of store methods by delegating
+    to the sync methods using asyncio.to_thread().
+    """
+
+    async def abatch(self, ops: Iterable[Op]) -> list[Result]:
+        """Async version of batch."""
+        ops_list = list(ops)
+        logger.debug(f"abatch: Executing {len(ops_list)} operations")
+        result = await asyncio.to_thread(self.batch, ops_list)
+        logger.debug(f"abatch: Completed {len(result)} operations")
+        return result
+
+    async def aget(
+        self,
+        namespace: tuple[str, ...],
+        key: str,
+        *,
+        refresh_ttl: bool | None = None,
+    ) -> Item | None:
+        """Async version of get."""
+        ns_str = "/".join(namespace)
+        logger.debug(f"aget: Fetching key={key} from namespace={ns_str}")
+        result = await asyncio.to_thread(
+            partial(self.get, namespace, key, refresh_ttl=refresh_ttl)
+        )
+        if result:
+            logger.debug(f"aget: Found item key={key} in namespace={ns_str}")
+        else:
+            logger.debug(f"aget: No item found key={key} in namespace={ns_str}")
+        return result
+
+    async def aput(
+        self,
+        namespace: tuple[str, ...],
+        key: str,
+        value: dict[str, Any],
+        index: Literal[False] | list[str] | None = None,
+        *,
+        ttl: float | None = None,
+    ) -> None:
+        """Async version of put."""
+        ns_str = "/".join(namespace)
+        logger.debug(f"aput: Storing key={key} in namespace={ns_str}")
+        # Handle the ttl parameter - only pass if explicitly provided
+        if ttl is not None:
+            await asyncio.to_thread(
+                partial(self.put, namespace, key, value, index, ttl=ttl)
+            )
+        else:
+            await asyncio.to_thread(partial(self.put, namespace, key, value, index))
+        logger.debug(f"aput: Stored key={key} in namespace={ns_str}")
+
+    async def adelete(self, namespace: tuple[str, ...], key: str) -> None:
+        """Async version of delete."""
+        ns_str = "/".join(namespace)
+        logger.debug(f"adelete: Deleting key={key} from namespace={ns_str}")
+        await asyncio.to_thread(self.delete, namespace, key)
+        logger.debug(f"adelete: Deleted key={key} from namespace={ns_str}")
+
+    async def asearch(
+        self,
+        namespace_prefix: tuple[str, ...],
+        /,
+        *,
+        query: str | None = None,
+        filter: dict[str, Any] | None = None,
+        limit: int = 10,
+        offset: int = 0,
+        refresh_ttl: bool | None = None,
+    ) -> list[SearchItem]:
+        """Async version of search."""
+        ns_str = "/".join(namespace_prefix)
+        logger.debug(
+            f"asearch: Searching namespace_prefix={ns_str}, query={query}, limit={limit}"
+        )
+        result = await asyncio.to_thread(
+            partial(
+                self.search,
+                namespace_prefix,
+                query=query,
+                filter=filter,
+                limit=limit,
+                offset=offset,
+                refresh_ttl=refresh_ttl,
+            )
+        )
+        logger.debug(f"asearch: Found {len(result)} items in namespace_prefix={ns_str}")
+        return result
+
+    async def alist_namespaces(
+        self,
+        *,
+        prefix: NamespacePath | None = None,
+        suffix: NamespacePath | None = None,
+        max_depth: int | None = None,
+        limit: int = 100,
+        offset: int = 0,
+    ) -> list[tuple[str, ...]]:
+        """Async version of list_namespaces."""
+        prefix_str = "/".join(prefix) if prefix else "all"
+        logger.debug(
+            f"alist_namespaces: Listing namespaces prefix={prefix_str}, limit={limit}"
+        )
+        result = await asyncio.to_thread(
+            partial(
+                self.list_namespaces,
+                prefix=prefix,
+                suffix=suffix,
+                max_depth=max_depth,
+                limit=limit,
+                offset=offset,
+            )
+        )
+        logger.debug(f"alist_namespaces: Found {len(result)} namespaces")
+        return result
+
+
+class DatabricksCheckpointerManager(CheckpointManagerBase):
+    """
+    Checkpointer manager using Databricks CheckpointSaver with async support.
+
+    Uses AsyncDatabricksCheckpointSaver which wraps databricks_langchain.CheckpointSaver
+    with async method implementations for LangGraph async streaming compatibility.
+
+    Required configuration via CheckpointerModel.database:
+    - instance_name: The Databricks Lakebase instance name
+    - workspace_client: WorkspaceClient (supports OBO, service principal, or default auth)
+
+    See: https://api-docs.databricks.com/python/databricks-ai-bridge/latest/databricks_langchain.html#databricks_langchain.CheckpointSaver
+    """
+
+    def __init__(self, checkpointer_model: CheckpointerModel):
+        self.checkpointer_model = checkpointer_model
+        self._checkpointer: BaseCheckpointSaver | None = None
+
+    def checkpointer(self) -> BaseCheckpointSaver:
+        if self._checkpointer is None:
+            database = self.checkpointer_model.database
+            if database is None:
+                raise ValueError(
+                    "Database configuration is required for Databricks checkpointer. "
+                    "Please provide a 'database' field in the checkpointer configuration."
+                )
+
+            instance_name = database.instance_name
+            workspace_client = database.workspace_client
+
+            logger.debug(
+                f"Creating AsyncDatabricksCheckpointSaver for instance: {instance_name}"
+            )
+
+            checkpointer = AsyncDatabricksCheckpointSaver(
+                instance_name=instance_name,
+                workspace_client=workspace_client,
+            )
+
+            # Setup the checkpointer (creates necessary tables if needed)
+            logger.debug(f"Setting up checkpoint tables for instance: {instance_name}")
+            checkpointer.setup()
+            logger.debug(
+                f"Checkpoint tables setup complete for instance: {instance_name}"
+            )
+
+            self._checkpointer = checkpointer
+
+        return self._checkpointer
+
+
+class DatabricksStoreManager(StoreManagerBase):
+    """
+    Store manager using Databricks DatabricksStore with async support.
+
+    Uses AsyncDatabricksStore which wraps databricks_langchain.DatabricksStore
+    with async method implementations for LangGraph async streaming compatibility.
+
+    Required configuration via StoreModel.database:
+    - instance_name: The Databricks Lakebase instance name
+    - workspace_client: WorkspaceClient (supports OBO, service principal, or default auth)
+
+    Optional configuration via StoreModel:
+    - embedding_model: LLMModel for embeddings (will be converted to DatabricksEmbeddings)
+    - dims: Embedding dimensions
+
+    See: https://api-docs.databricks.com/python/databricks-ai-bridge/latest/databricks_langchain.html#databricks_langchain.DatabricksStore
+    """
+
+    def __init__(self, store_model: StoreModel):
+        self.store_model = store_model
+        self._store: BaseStore | None = None
+
+    def store(self) -> BaseStore:
+        if self._store is None:
+            database = self.store_model.database
+            if database is None:
+                raise ValueError(
+                    "Database configuration is required for Databricks store. "
+                    "Please provide a 'database' field in the store configuration."
+                )
+
+            instance_name = database.instance_name
+            workspace_client = database.workspace_client
+
+            # Build embeddings configuration if embedding_model is provided
+            embeddings: DatabricksEmbeddings | None = None
+            embedding_dims: int | None = None
+
+            if self.store_model.embedding_model is not None:
+                embedding_endpoint = self.store_model.embedding_model.name
+                embedding_dims = self.store_model.dims
+
+                logger.debug(
+                    f"Configuring embeddings: endpoint={embedding_endpoint}, dims={embedding_dims}"
+                )
+
+                embeddings = DatabricksEmbeddings(endpoint=embedding_endpoint)
+
+            logger.debug(f"Creating AsyncDatabricksStore for instance: {instance_name}")
+
+            store = AsyncDatabricksStore(
+                instance_name=instance_name,
+                workspace_client=workspace_client,
+                embeddings=embeddings,
+                embedding_dims=embedding_dims,
+            )
+
+            # Setup the store (creates necessary tables if needed)
+            store.setup()
+            self._store = store
+
+        return self._store
+
+
+__all__ = [
+    "AsyncDatabricksCheckpointSaver",
+    "AsyncDatabricksStore",
+    "DatabricksCheckpointerManager",
+    "DatabricksStoreManager",
+]
dao_ai/memory/postgres.py CHANGED
@@ -409,7 +409,7 @@ class PostgresCheckpointerManager(CheckpointManagerBase):
             raise
 
 
-def _shutdown_pools():
+def _shutdown_pools() -> None:
     try:
        PostgresPoolManager.close_all_pools()
        logger.debug("Successfully closed all synchronous PostgreSQL pools")
@@ -417,7 +417,7 @@ def _shutdown_pools():
        logger.error(f"Error closing synchronous PostgreSQL pools during shutdown: {e}")
 
 
-def _shutdown_async_pools():
+def _shutdown_async_pools() -> None:
     try:
        # Try to get the current event loop first
        try:
dao_ai/messages.py CHANGED
@@ -125,8 +125,6 @@ def has_image(messages: BaseMessage | Sequence[BaseMessage]) -> bool:
                 "image_url",
             ]:
                 return True
-            if hasattr(item, "type") and item.type in ["image", "image_url"]:
-                return True
         return False
 
     if isinstance(messages, BaseMessage):
@@ -176,7 +174,9 @@ def last_human_message(messages: Sequence[BaseMessage]) -> Optional[HumanMessage
     Returns:
         The last HumanMessage in the sequence, or None if no human messages found
     """
-    return last_message(
+    return last_message(
+        messages, lambda m: isinstance(m, HumanMessage) and bool(m.content)
+    )
 
 
 def last_ai_message(messages: Sequence[BaseMessage]) -> Optional[AIMessage]:
@@ -192,7 +192,9 @@ def last_ai_message(messages: Sequence[BaseMessage]) -> Optional[AIMessage]:
     Returns:
         The last AIMessage in the sequence, or None if no AI messages found
     """
-    return last_message(
+    return last_message(
+        messages, lambda m: isinstance(m, AIMessage) and bool(m.content)
+    )
 
 
 def last_tool_message(messages: Sequence[BaseMessage]) -> Optional[ToolMessage]:
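The added `bool(m.content)` guard means a trailing human or AI message with empty content (for example, an AI turn that only carries tool calls) no longer wins. A small sketch of the new behavior, assuming `last_message` scans from the end and returns the first message matching the predicate, as its usage here suggests:

# Hedged sketch of the new last_human_message filtering; assumes last_message
# returns the latest message satisfying the predicate.
from langchain_core.messages import AIMessage, HumanMessage

from dao_ai.messages import last_human_message

messages = [
    HumanMessage(content="What changed in 0.1.1?"),
    AIMessage(content="Mostly middleware and memory."),
    HumanMessage(content=""),  # e.g. a placeholder turn with no text
]

# With the bool(m.content) guard, the trailing empty HumanMessage is skipped
# and the last human message with actual content is returned instead.
assert last_human_message(messages).content == "What changed in 0.1.1?"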
dao_ai/middleware/__init__.py ADDED
@@ -0,0 +1,125 @@
+# DAO AI Middleware Module
+# This module provides middleware implementations compatible with LangChain v1's create_agent
+
+# Re-export LangChain built-in middleware
+from langchain.agents.middleware import (
+    HumanInTheLoopMiddleware,
+    SummarizationMiddleware,
+    after_agent,
+    after_model,
+    before_agent,
+    before_model,
+    dynamic_prompt,
+    wrap_model_call,
+    wrap_tool_call,
+)
+
+# DSPy-style assertion middleware
+from dao_ai.middleware.assertions import (
+    # Middleware classes
+    AssertMiddleware,
+    # Types
+    Constraint,
+    ConstraintResult,
+    FunctionConstraint,
+    KeywordConstraint,
+    LengthConstraint,
+    LLMConstraint,
+    RefineMiddleware,
+    SuggestMiddleware,
+    # Factory functions
+    create_assert_middleware,
+    create_refine_middleware,
+    create_suggest_middleware,
+)
+from dao_ai.middleware.base import (
+    AgentMiddleware,
+    ModelRequest,
+    ModelResponse,
+)
+from dao_ai.middleware.core import create_factory_middleware
+from dao_ai.middleware.guardrails import (
+    ContentFilterMiddleware,
+    GuardrailMiddleware,
+    SafetyGuardrailMiddleware,
+    create_content_filter_middleware,
+    create_guardrail_middleware,
+    create_safety_guardrail_middleware,
+)
+from dao_ai.middleware.human_in_the_loop import (
+    create_hitl_middleware_from_tool_models,
+    create_human_in_the_loop_middleware,
+)
+from dao_ai.middleware.message_validation import (
+    CustomFieldValidationMiddleware,
+    FilterLastHumanMessageMiddleware,
+    MessageValidationMiddleware,
+    RequiredField,
+    ThreadIdValidationMiddleware,
+    UserIdValidationMiddleware,
+    create_custom_field_validation_middleware,
+    create_filter_last_human_message_middleware,
+    create_thread_id_validation_middleware,
+    create_user_id_validation_middleware,
+)
+from dao_ai.middleware.summarization import (
+    LoggingSummarizationMiddleware,
+    create_summarization_middleware,
+)
+
+__all__ = [
+    # Base class (from LangChain)
+    "AgentMiddleware",
+    # Types
+    "ModelRequest",
+    "ModelResponse",
+    # LangChain decorators
+    "before_agent",
+    "before_model",
+    "after_agent",
+    "after_model",
+    "wrap_model_call",
+    "wrap_tool_call",
+    "dynamic_prompt",
+    # LangChain built-in middleware
+    "SummarizationMiddleware",
+    "LoggingSummarizationMiddleware",
+    "HumanInTheLoopMiddleware",
+    # Core factory function
+    "create_factory_middleware",
+    # DAO AI middleware implementations
+    "GuardrailMiddleware",
+    "ContentFilterMiddleware",
+    "SafetyGuardrailMiddleware",
+    "MessageValidationMiddleware",
+    "UserIdValidationMiddleware",
+    "ThreadIdValidationMiddleware",
+    "CustomFieldValidationMiddleware",
+    "RequiredField",
+    "FilterLastHumanMessageMiddleware",
+    # DSPy-style assertion middleware
+    "Constraint",
+    "ConstraintResult",
+    "FunctionConstraint",
+    "KeywordConstraint",
+    "LengthConstraint",
+    "LLMConstraint",
+    "AssertMiddleware",
+    "SuggestMiddleware",
+    "RefineMiddleware",
+    # DAO AI middleware factory functions
+    "create_guardrail_middleware",
+    "create_content_filter_middleware",
+    "create_safety_guardrail_middleware",
+    "create_user_id_validation_middleware",
+    "create_thread_id_validation_middleware",
+    "create_custom_field_validation_middleware",
+    "create_filter_last_human_message_middleware",
+    "create_summarization_middleware",
+    "create_human_in_the_loop_middleware",
+    "create_hitl_middleware_from_tool_models",
+    # DSPy-style assertion factory functions
+    "create_assert_middleware",
+    "create_suggest_middleware",
+    "create_refine_middleware",
+]
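Because this file re-exports the LangChain built-ins alongside the DAO AI middleware, callers can assemble an agent from a single namespace. A hedged sketch against LangChain v1's create_agent follows; only the import surface comes from the diff above, and the constructor arguments (interrupt_on, max_tokens_before_summary, the model id, the zero-arg factory call) are assumptions for illustration.

# Hedged sketch: assembling an agent from the re-exported middleware surface.
# Only the import names come from the diff above; the constructor arguments
# below are assumptions, not confirmed signatures.
from langchain.agents import create_agent

from dao_ai.middleware import (
    HumanInTheLoopMiddleware,
    SummarizationMiddleware,
    create_user_id_validation_middleware,
)

agent = create_agent(
    model="databricks:databricks-claude-sonnet-4",  # hypothetical model id
    tools=[],
    middleware=[
        # Assumed zero-arg factory; the real signature lives in message_validation.py
        create_user_id_validation_middleware(),
        SummarizationMiddleware(
            model="databricks:databricks-claude-sonnet-4",
            max_tokens_before_summary=4000,
        ),
        # Pause before running a hypothetical sensitive tool
        HumanInTheLoopMiddleware(interrupt_on={"send_email": True}),
    ],
)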