dao-ai 0.0.25__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff compares the contents of two publicly released versions of this package as they appear in their public registry. It is provided for informational purposes only.
- dao_ai/__init__.py +29 -0
- dao_ai/agent_as_code.py +5 -5
- dao_ai/cli.py +245 -40
- dao_ai/config.py +1863 -338
- dao_ai/genie/__init__.py +38 -0
- dao_ai/genie/cache/__init__.py +43 -0
- dao_ai/genie/cache/base.py +72 -0
- dao_ai/genie/cache/core.py +79 -0
- dao_ai/genie/cache/lru.py +347 -0
- dao_ai/genie/cache/semantic.py +970 -0
- dao_ai/genie/core.py +35 -0
- dao_ai/graph.py +27 -228
- dao_ai/hooks/__init__.py +9 -6
- dao_ai/hooks/core.py +27 -195
- dao_ai/logging.py +56 -0
- dao_ai/memory/__init__.py +10 -0
- dao_ai/memory/core.py +65 -30
- dao_ai/memory/databricks.py +402 -0
- dao_ai/memory/postgres.py +79 -38
- dao_ai/messages.py +6 -4
- dao_ai/middleware/__init__.py +125 -0
- dao_ai/middleware/assertions.py +806 -0
- dao_ai/middleware/base.py +50 -0
- dao_ai/middleware/core.py +67 -0
- dao_ai/middleware/guardrails.py +420 -0
- dao_ai/middleware/human_in_the_loop.py +232 -0
- dao_ai/middleware/message_validation.py +586 -0
- dao_ai/middleware/summarization.py +197 -0
- dao_ai/models.py +1306 -114
- dao_ai/nodes.py +261 -166
- dao_ai/optimization.py +674 -0
- dao_ai/orchestration/__init__.py +52 -0
- dao_ai/orchestration/core.py +294 -0
- dao_ai/orchestration/supervisor.py +278 -0
- dao_ai/orchestration/swarm.py +271 -0
- dao_ai/prompts.py +128 -31
- dao_ai/providers/databricks.py +645 -172
- dao_ai/state.py +157 -21
- dao_ai/tools/__init__.py +13 -5
- dao_ai/tools/agent.py +1 -3
- dao_ai/tools/core.py +64 -11
- dao_ai/tools/email.py +232 -0
- dao_ai/tools/genie.py +144 -295
- dao_ai/tools/mcp.py +220 -133
- dao_ai/tools/memory.py +50 -0
- dao_ai/tools/python.py +9 -14
- dao_ai/tools/search.py +14 -0
- dao_ai/tools/slack.py +22 -10
- dao_ai/tools/sql.py +202 -0
- dao_ai/tools/time.py +30 -7
- dao_ai/tools/unity_catalog.py +165 -88
- dao_ai/tools/vector_search.py +360 -40
- dao_ai/utils.py +218 -16
- dao_ai-0.1.2.dist-info/METADATA +455 -0
- dao_ai-0.1.2.dist-info/RECORD +64 -0
- {dao_ai-0.0.25.dist-info → dao_ai-0.1.2.dist-info}/WHEEL +1 -1
- dao_ai/chat_models.py +0 -204
- dao_ai/guardrails.py +0 -112
- dao_ai/tools/human_in_the_loop.py +0 -100
- dao_ai-0.0.25.dist-info/METADATA +0 -1165
- dao_ai-0.0.25.dist-info/RECORD +0 -41
- {dao_ai-0.0.25.dist-info → dao_ai-0.1.2.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.0.25.dist-info → dao_ai-0.1.2.dist-info}/licenses/LICENSE +0 -0
dao_ai/memory/core.py
CHANGED
```diff
@@ -1,8 +1,6 @@
 from typing import Any
 
-from databricks_langchain import (
-    DatabricksEmbeddings,
-)
+from databricks_langchain import DatabricksEmbeddings
 from langchain_core.embeddings.embeddings import Embeddings
 from langgraph.checkpoint.base import BaseCheckpointSaver
 from langgraph.checkpoint.memory import InMemorySaver
@@ -27,11 +25,13 @@ class InMemoryStoreManager(StoreManagerBase):
         self.store_model = store_model
 
     def store(self) -> BaseStore:
-
+        embedding_model: LLMModel = self.store_model.embedding_model
 
-
+        logger.debug(
+            "Creating in-memory store", embeddings_enabled=embedding_model is not None
+        )
 
-
+        index: dict[str, Any] = None
 
         if embedding_model:
             embeddings: Embeddings = DatabricksEmbeddings(endpoint=embedding_model.name)
@@ -41,6 +41,11 @@ class InMemoryStoreManager(StoreManagerBase):
 
             dims: int = self.store_model.dims
             index = {"dims": dims, "embed": embed_texts}
+            logger.debug(
+                "Store embeddings configured",
+                endpoint=embedding_model.name,
+                dimensions=dims,
+            )
 
         store: BaseStore = InMemoryStore(index=index)
 
```
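The `index` dict in the hunk above is what turns LangGraph's `InMemoryStore` into a semantic store: `dims` declares the embedding width and `embed` supplies the embedding function. A minimal runnable sketch of that wiring, using a stand-in embedder in place of the `DatabricksEmbeddings`-backed `embed_texts` from the diff:

```python
from langgraph.store.memory import InMemoryStore


def embed_texts(texts: list[str]) -> list[list[float]]:
    # Stand-in embedder for illustration; the real code wraps
    # DatabricksEmbeddings(endpoint=...) from databricks_langchain.
    return [[float(len(text)), 1.0] for text in texts]


store = InMemoryStore(index={"dims": 2, "embed": embed_texts})
store.put(("users", "u1"), "pref-1", {"text": "prefers dark mode"})
results = store.search(("users", "u1"), query="dark mode", limit=5)
print([item.key for item in results])
```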
```diff
@@ -60,26 +65,39 @@ class StoreManager:
 
     @classmethod
     def instance(cls, store_model: StoreModel) -> StoreManagerBase:
-        store_manager: StoreManagerBase = None
-        match store_model.
+        store_manager: StoreManagerBase | None = None
+        match store_model.storage_type:
             case StorageType.MEMORY:
                 store_manager = cls.store_managers.get(store_model.name)
                 if store_manager is None:
                     store_manager = InMemoryStoreManager(store_model)
                     cls.store_managers[store_model.name] = store_manager
             case StorageType.POSTGRES:
-
+                # Route based on database configuration: instance_name -> Databricks, host -> Postgres
+                if store_model.database.is_lakebase:
+                    # Databricks Lakebase connection
+                    from dao_ai.memory.databricks import DatabricksStoreManager
 
-
-
-                )
-                if store_manager is None:
-                    store_manager = PostgresStoreManager(store_model)
-                    cls.store_managers[store_model.database.instance_name] = (
-                        store_manager
+                    store_manager = cls.store_managers.get(
+                        store_model.database.instance_name
                     )
+                    if store_manager is None:
+                        store_manager = DatabricksStoreManager(store_model)
+                        cls.store_managers[store_model.database.instance_name] = (
+                            store_manager
+                        )
+                else:
+                    # Standard PostgreSQL connection
+                    from dao_ai.memory.postgres import PostgresStoreManager
+
+                    # Use database name as key for standard PostgreSQL
+                    cache_key = f"{store_model.database.name}"
+                    store_manager = cls.store_managers.get(cache_key)
+                    if store_manager is None:
+                        store_manager = PostgresStoreManager(store_model)
+                        cls.store_managers[cache_key] = store_manager
             case _:
-                raise ValueError(f"Unknown
+                raise ValueError(f"Unknown storage type: {store_model.storage_type}")
 
         return store_manager
 
```
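The routing comment ("instance_name -> Databricks, host -> Postgres") relies on an `is_lakebase` check on the database config. A sketch of how such a property could be derived; the field names here are assumptions based on the attributes the diff reads (`instance_name`, `name`), not the exact dao_ai schema:

```python
from dataclasses import dataclass


@dataclass
class DatabaseConfig:
    # Illustrative stand-in for dao_ai's database config model.
    name: str
    instance_name: str | None = None  # set for Databricks Lakebase
    host: str | None = None           # set for standard PostgreSQL

    @property
    def is_lakebase(self) -> bool:
        # Lakebase connections are identified by a Databricks instance name.
        return self.instance_name is not None


assert DatabaseConfig(name="mem", instance_name="my-lakebase").is_lakebase
assert not DatabaseConfig(name="mem", host="db.example.com").is_lakebase
```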
```diff
@@ -89,8 +107,8 @@ class CheckpointManager:
 
     @classmethod
     def instance(cls, checkpointer_model: CheckpointerModel) -> CheckpointManagerBase:
-        checkpointer_manager: CheckpointManagerBase = None
-        match checkpointer_model.
+        checkpointer_manager: CheckpointManagerBase | None = None
+        match checkpointer_model.storage_type:
             case StorageType.MEMORY:
                 checkpointer_manager = cls.checkpoint_managers.get(
                     checkpointer_model.name
@@ -103,19 +121,36 @@ class CheckpointManager:
                         checkpointer_manager
                     )
             case StorageType.POSTGRES:
-
+                # Route based on database configuration: instance_name -> Databricks, host -> Postgres
+                if checkpointer_model.database.is_lakebase:
+                    # Databricks Lakebase connection
+                    from dao_ai.memory.databricks import DatabricksCheckpointerManager
 
-
-                    checkpointer_model.database.instance_name
-                )
-                if checkpointer_manager is None:
-                    checkpointer_manager = AsyncPostgresCheckpointerManager(
-                        checkpointer_model
-                    )
-                    cls.checkpoint_managers[
+                    checkpointer_manager = cls.checkpoint_managers.get(
                         checkpointer_model.database.instance_name
-
+                    )
+                    if checkpointer_manager is None:
+                        checkpointer_manager = DatabricksCheckpointerManager(
+                            checkpointer_model
+                        )
+                        cls.checkpoint_managers[
+                            checkpointer_model.database.instance_name
+                        ] = checkpointer_manager
+                else:
+                    # Standard PostgreSQL connection
+                    from dao_ai.memory.postgres import AsyncPostgresCheckpointerManager
+
+                    # Use database name as key for standard PostgreSQL
+                    cache_key = f"{checkpointer_model.database.name}"
+                    checkpointer_manager = cls.checkpoint_managers.get(cache_key)
+                    if checkpointer_manager is None:
+                        checkpointer_manager = AsyncPostgresCheckpointerManager(
+                            checkpointer_model
+                        )
+                        cls.checkpoint_managers[cache_key] = checkpointer_manager
             case _:
-                raise ValueError(
+                raise ValueError(
+                    f"Unknown storage type: {checkpointer_model.storage_type}"
+                )
 
         return checkpointer_manager
```
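Both managers now share the same shape: dispatch on `storage_type` with a `match` statement, then memoize one manager per cache key (Lakebase instance name or Postgres database name). A runnable sketch of that dispatch-and-cache pattern, with a stand-in `StorageType` enum (the real one lives in `dao_ai.config`):

```python
from enum import Enum


class StorageType(str, Enum):
    MEMORY = "memory"
    POSTGRES = "postgres"


_managers: dict[str, object] = {}


def manager_for(cache_key: str, storage_type: StorageType) -> object:
    match storage_type:
        case StorageType.MEMORY | StorageType.POSTGRES:
            # Create on miss, then reuse the cached singleton per key.
            if cache_key not in _managers:
                _managers[cache_key] = object()
            return _managers[cache_key]
        case _:
            raise ValueError(f"Unknown storage type: {storage_type}")


assert manager_for("orders-db", StorageType.POSTGRES) is manager_for(
    "orders-db", StorageType.POSTGRES
)
```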
dao_ai/memory/databricks.py
ADDED

```python
"""
Databricks-native memory storage implementations.

Provides CheckpointSaver and DatabricksStore implementations using
Databricks Lakebase for persistent storage, with async support.

See:
- https://api-docs.databricks.com/python/databricks-ai-bridge/latest/databricks_langchain.html#databricks_langchain.CheckpointSaver
- https://api-docs.databricks.com/python/databricks-ai-bridge/latest/databricks_langchain.html#databricks_langchain.DatabricksStore
"""

import asyncio
from collections.abc import AsyncIterator, Iterable, Sequence
from functools import partial
from typing import Any, Literal

from databricks_langchain import (
    CheckpointSaver as DatabricksCheckpointSaver,
)
from databricks_langchain import (
    DatabricksEmbeddings,
    DatabricksStore,
)
from langchain_core.runnables import RunnableConfig
from langgraph.checkpoint.base import (
    BaseCheckpointSaver,
    ChannelVersions,
    Checkpoint,
    CheckpointMetadata,
    CheckpointTuple,
)
from langgraph.store.base import BaseStore, Item, Op, Result, SearchItem
from loguru import logger

from dao_ai.config import (
    CheckpointerModel,
    StoreModel,
)
from dao_ai.memory.base import (
    CheckpointManagerBase,
    StoreManagerBase,
)

# Type alias for namespace path
NamespacePath = tuple[str, ...]

# Sentinel for not-provided values
NOT_PROVIDED = object()


class AsyncDatabricksCheckpointSaver(DatabricksCheckpointSaver):
    """
    Async wrapper for DatabricksCheckpointSaver.

    Provides async implementations of checkpoint methods by delegating
    to the sync methods using asyncio.to_thread().
    """

    async def aget_tuple(self, config: RunnableConfig) -> CheckpointTuple | None:
        """Async version of get_tuple."""
        thread_id = config.get("configurable", {}).get("thread_id", "unknown")
        logger.trace("Fetching checkpoint", thread_id=thread_id, method="aget_tuple")
        result = await asyncio.to_thread(self.get_tuple, config)
        if result:
            logger.trace("Checkpoint found", thread_id=thread_id)
        else:
            logger.trace("No checkpoint found", thread_id=thread_id)
        return result

    async def aget(self, config: RunnableConfig) -> Checkpoint | None:
        """Async version of get."""
        thread_id = config.get("configurable", {}).get("thread_id", "unknown")
        logger.trace("Fetching checkpoint", thread_id=thread_id, method="aget")
        result = await asyncio.to_thread(self.get, config)
        if result:
            logger.trace("Checkpoint found", thread_id=thread_id)
        else:
            logger.trace("No checkpoint found", thread_id=thread_id)
        return result

    async def aput(
        self,
        config: RunnableConfig,
        checkpoint: Checkpoint,
        metadata: CheckpointMetadata,
        new_versions: ChannelVersions,
    ) -> RunnableConfig:
        """Async version of put."""
        thread_id = config.get("configurable", {}).get("thread_id", "unknown")
        checkpoint_id = checkpoint.get("id", "unknown")
        logger.trace(
            "Saving checkpoint", checkpoint_id=checkpoint_id, thread_id=thread_id
        )
        result = await asyncio.to_thread(
            self.put, config, checkpoint, metadata, new_versions
        )
        logger.trace(
            "Checkpoint saved", thread_id=thread_id, checkpoint_id=checkpoint_id
        )
        return result

    async def aput_writes(
        self,
        config: RunnableConfig,
        writes: Sequence[tuple[str, Any]],
        task_id: str,
        task_path: str = "",
    ) -> None:
        """Async version of put_writes."""
        thread_id = config.get("configurable", {}).get("thread_id", "unknown")
        logger.trace(
            "Saving checkpoint writes",
            writes_count=len(writes),
            thread_id=thread_id,
            task_id=task_id,
        )
        await asyncio.to_thread(self.put_writes, config, writes, task_id, task_path)
        logger.trace("Checkpoint writes saved", thread_id=thread_id, task_id=task_id)

    async def alist(
        self,
        config: RunnableConfig | None,
        *,
        filter: dict[str, Any] | None = None,
        before: RunnableConfig | None = None,
        limit: int | None = None,
    ) -> AsyncIterator[CheckpointTuple]:
        """Async version of list."""
        thread_id = (
            config.get("configurable", {}).get("thread_id", "unknown")
            if config
            else "all"
        )
        logger.trace("Listing checkpoints", thread_id=thread_id, limit=limit)
        # Get all items from sync iterator in a thread
        items = await asyncio.to_thread(
            lambda: list(self.list(config, filter=filter, before=before, limit=limit))
        )
        logger.debug("Checkpoints listed", thread_id=thread_id, count=len(items))
        for item in items:
            yield item

    async def adelete_thread(self, thread_id: str) -> None:
        """Async version of delete_thread."""
        logger.trace("Deleting thread", thread_id=thread_id)
        await asyncio.to_thread(self.delete_thread, thread_id)
        logger.debug("Thread deleted", thread_id=thread_id)


class AsyncDatabricksStore(DatabricksStore):
    """
    Async wrapper for DatabricksStore.

    Provides async implementations of store methods by delegating
    to the sync methods using asyncio.to_thread().
    """

    async def abatch(self, ops: Iterable[Op]) -> list[Result]:
        """Async version of batch."""
        ops_list = list(ops)
        logger.trace("Executing batch operations", operations_count=len(ops_list))
        result = await asyncio.to_thread(self.batch, ops_list)
        logger.debug("Batch operations completed", operations_count=len(result))
        return result

    async def aget(
        self,
        namespace: tuple[str, ...],
        key: str,
        *,
        refresh_ttl: bool | None = None,
    ) -> Item | None:
        """Async version of get."""
        ns_str = "/".join(namespace)
        logger.trace("Fetching store item", key=key, namespace=ns_str)
        result = await asyncio.to_thread(
            partial(self.get, namespace, key, refresh_ttl=refresh_ttl)
        )
        if result:
            logger.trace("Store item found", key=key, namespace=ns_str)
        else:
            logger.trace("Store item not found", key=key, namespace=ns_str)
        return result

    async def aput(
        self,
        namespace: tuple[str, ...],
        key: str,
        value: dict[str, Any],
        index: Literal[False] | list[str] | None = None,
        *,
        ttl: float | None = None,
    ) -> None:
        """Async version of put."""
        ns_str = "/".join(namespace)
        logger.trace("Storing item", key=key, namespace=ns_str, has_ttl=ttl is not None)
        # Handle the ttl parameter - only pass if explicitly provided
        if ttl is not None:
            await asyncio.to_thread(
                partial(self.put, namespace, key, value, index, ttl=ttl)
            )
        else:
            await asyncio.to_thread(partial(self.put, namespace, key, value, index))
        logger.trace("Item stored", key=key, namespace=ns_str)

    async def adelete(self, namespace: tuple[str, ...], key: str) -> None:
        """Async version of delete."""
        ns_str = "/".join(namespace)
        logger.trace("Deleting item", key=key, namespace=ns_str)
        await asyncio.to_thread(self.delete, namespace, key)
        logger.trace("Item deleted", key=key, namespace=ns_str)

    async def asearch(
        self,
        namespace_prefix: tuple[str, ...],
        /,
        *,
        query: str | None = None,
        filter: dict[str, Any] | None = None,
        limit: int = 10,
        offset: int = 0,
        refresh_ttl: bool | None = None,
    ) -> list[SearchItem]:
        """Async version of search."""
        ns_str = "/".join(namespace_prefix)
        logger.trace(
            "Searching store", namespace_prefix=ns_str, query=query, limit=limit
        )
        result = await asyncio.to_thread(
            partial(
                self.search,
                namespace_prefix,
                query=query,
                filter=filter,
                limit=limit,
                offset=offset,
                refresh_ttl=refresh_ttl,
            )
        )
        logger.debug(
            "Store search completed", namespace_prefix=ns_str, results_count=len(result)
        )
        return result

    async def alist_namespaces(
        self,
        *,
        prefix: NamespacePath | None = None,
        suffix: NamespacePath | None = None,
        max_depth: int | None = None,
        limit: int = 100,
        offset: int = 0,
    ) -> list[tuple[str, ...]]:
        """Async version of list_namespaces."""
        prefix_str = "/".join(prefix) if prefix else "all"
        logger.trace("Listing namespaces", prefix=prefix_str, limit=limit)
        result = await asyncio.to_thread(
            partial(
                self.list_namespaces,
                prefix=prefix,
                suffix=suffix,
                max_depth=max_depth,
                limit=limit,
                offset=offset,
            )
        )
        logger.debug("Namespaces listed", count=len(result))
        return result


class DatabricksCheckpointerManager(CheckpointManagerBase):
    """
    Checkpointer manager using Databricks CheckpointSaver with async support.

    Uses AsyncDatabricksCheckpointSaver which wraps databricks_langchain.CheckpointSaver
    with async method implementations for LangGraph async streaming compatibility.

    Required configuration via CheckpointerModel.database:
    - instance_name: The Databricks Lakebase instance name
    - workspace_client: WorkspaceClient (supports OBO, service principal, or default auth)

    See: https://api-docs.databricks.com/python/databricks-ai-bridge/latest/databricks_langchain.html#databricks_langchain.CheckpointSaver
    """

    def __init__(self, checkpointer_model: CheckpointerModel):
        self.checkpointer_model = checkpointer_model
        self._checkpointer: BaseCheckpointSaver | None = None

    def checkpointer(self) -> BaseCheckpointSaver:
        if self._checkpointer is None:
            database = self.checkpointer_model.database
            if database is None:
                raise ValueError(
                    "Database configuration is required for Databricks checkpointer. "
                    "Please provide a 'database' field in the checkpointer configuration."
                )

            instance_name = database.instance_name
            workspace_client = database.workspace_client

            logger.debug(
                "Creating Databricks checkpointer", instance_name=instance_name
            )

            checkpointer = AsyncDatabricksCheckpointSaver(
                instance_name=instance_name,
                workspace_client=workspace_client,
            )

            # Setup the checkpointer (creates necessary tables if needed)
            logger.debug("Setting up checkpoint tables", instance_name=instance_name)
            checkpointer.setup()
            logger.success(
                "Databricks checkpointer initialized", instance_name=instance_name
            )

            self._checkpointer = checkpointer

        return self._checkpointer


class DatabricksStoreManager(StoreManagerBase):
    """
    Store manager using Databricks DatabricksStore with async support.

    Uses AsyncDatabricksStore which wraps databricks_langchain.DatabricksStore
    with async method implementations for LangGraph async streaming compatibility.

    Required configuration via StoreModel.database:
    - instance_name: The Databricks Lakebase instance name
    - workspace_client: WorkspaceClient (supports OBO, service principal, or default auth)

    Optional configuration via StoreModel:
    - embedding_model: LLMModel for embeddings (will be converted to DatabricksEmbeddings)
    - dims: Embedding dimensions

    See: https://api-docs.databricks.com/python/databricks-ai-bridge/latest/databricks_langchain.html#databricks_langchain.DatabricksStore
    """

    def __init__(self, store_model: StoreModel):
        self.store_model = store_model
        self._store: BaseStore | None = None

    def store(self) -> BaseStore:
        if self._store is None:
            database = self.store_model.database
            if database is None:
                raise ValueError(
                    "Database configuration is required for Databricks store. "
                    "Please provide a 'database' field in the store configuration."
                )

            instance_name = database.instance_name
            workspace_client = database.workspace_client

            # Build embeddings configuration if embedding_model is provided
            embeddings: DatabricksEmbeddings | None = None
            embedding_dims: int | None = None

            if self.store_model.embedding_model is not None:
                embedding_endpoint = self.store_model.embedding_model.name
                embedding_dims = self.store_model.dims

                logger.debug(
                    "Configuring store embeddings",
                    endpoint=embedding_endpoint,
                    dimensions=embedding_dims,
                )

                embeddings = DatabricksEmbeddings(endpoint=embedding_endpoint)

            logger.debug(
                "Creating Databricks store",
                instance_name=instance_name,
                embeddings_enabled=embeddings is not None,
            )

            store = AsyncDatabricksStore(
                instance_name=instance_name,
                workspace_client=workspace_client,
                embeddings=embeddings,
                embedding_dims=embedding_dims,
            )

            # Setup the store (creates necessary tables if needed)
            store.setup()
            logger.success(
                "Databricks store initialized",
                instance_name=instance_name,
                embeddings_enabled=embeddings is not None,
            )
            self._store = store

        return self._store


__all__ = [
    "AsyncDatabricksCheckpointSaver",
    "AsyncDatabricksStore",
    "DatabricksCheckpointerManager",
    "DatabricksStoreManager",
]
```
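Every async method in this module follows the same delegation pattern: run the inherited sync implementation on a worker thread via `asyncio.to_thread()`, using `functools.partial` to bind keyword arguments. A self-contained sketch of the pattern with illustrative classes (not the dao_ai API):

```python
import asyncio
from functools import partial


class SyncStore:
    def get(self, key: str, *, refresh_ttl: bool | None = None) -> str:
        # Imagine a blocking database read here.
        return f"value:{key}"


class AsyncStore(SyncStore):
    async def aget(self, key: str, *, refresh_ttl: bool | None = None) -> str:
        # Delegate to the blocking method on a worker thread so the
        # event loop stays responsive; partial binds the kwargs.
        return await asyncio.to_thread(partial(self.get, key, refresh_ttl=refresh_ttl))


async def main() -> None:
    print(await AsyncStore().aget("greeting"))


asyncio.run(main())
```

Note that `asyncio.to_thread()` also forwards keyword arguments itself, so the `partial` form is a stylistic choice; it keeps the callable and its arguments in a single expression.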