dao-ai 0.0.25__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff shows the content changes between publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
Files changed (63)
  1. dao_ai/__init__.py +29 -0
  2. dao_ai/agent_as_code.py +5 -5
  3. dao_ai/cli.py +245 -40
  4. dao_ai/config.py +1863 -338
  5. dao_ai/genie/__init__.py +38 -0
  6. dao_ai/genie/cache/__init__.py +43 -0
  7. dao_ai/genie/cache/base.py +72 -0
  8. dao_ai/genie/cache/core.py +79 -0
  9. dao_ai/genie/cache/lru.py +347 -0
  10. dao_ai/genie/cache/semantic.py +970 -0
  11. dao_ai/genie/core.py +35 -0
  12. dao_ai/graph.py +27 -228
  13. dao_ai/hooks/__init__.py +9 -6
  14. dao_ai/hooks/core.py +27 -195
  15. dao_ai/logging.py +56 -0
  16. dao_ai/memory/__init__.py +10 -0
  17. dao_ai/memory/core.py +65 -30
  18. dao_ai/memory/databricks.py +402 -0
  19. dao_ai/memory/postgres.py +79 -38
  20. dao_ai/messages.py +6 -4
  21. dao_ai/middleware/__init__.py +125 -0
  22. dao_ai/middleware/assertions.py +806 -0
  23. dao_ai/middleware/base.py +50 -0
  24. dao_ai/middleware/core.py +67 -0
  25. dao_ai/middleware/guardrails.py +420 -0
  26. dao_ai/middleware/human_in_the_loop.py +232 -0
  27. dao_ai/middleware/message_validation.py +586 -0
  28. dao_ai/middleware/summarization.py +197 -0
  29. dao_ai/models.py +1306 -114
  30. dao_ai/nodes.py +261 -166
  31. dao_ai/optimization.py +674 -0
  32. dao_ai/orchestration/__init__.py +52 -0
  33. dao_ai/orchestration/core.py +294 -0
  34. dao_ai/orchestration/supervisor.py +278 -0
  35. dao_ai/orchestration/swarm.py +271 -0
  36. dao_ai/prompts.py +128 -31
  37. dao_ai/providers/databricks.py +645 -172
  38. dao_ai/state.py +157 -21
  39. dao_ai/tools/__init__.py +13 -5
  40. dao_ai/tools/agent.py +1 -3
  41. dao_ai/tools/core.py +64 -11
  42. dao_ai/tools/email.py +232 -0
  43. dao_ai/tools/genie.py +144 -295
  44. dao_ai/tools/mcp.py +220 -133
  45. dao_ai/tools/memory.py +50 -0
  46. dao_ai/tools/python.py +9 -14
  47. dao_ai/tools/search.py +14 -0
  48. dao_ai/tools/slack.py +22 -10
  49. dao_ai/tools/sql.py +202 -0
  50. dao_ai/tools/time.py +30 -7
  51. dao_ai/tools/unity_catalog.py +165 -88
  52. dao_ai/tools/vector_search.py +360 -40
  53. dao_ai/utils.py +218 -16
  54. dao_ai-0.1.2.dist-info/METADATA +455 -0
  55. dao_ai-0.1.2.dist-info/RECORD +64 -0
  56. {dao_ai-0.0.25.dist-info → dao_ai-0.1.2.dist-info}/WHEEL +1 -1
  57. dao_ai/chat_models.py +0 -204
  58. dao_ai/guardrails.py +0 -112
  59. dao_ai/tools/human_in_the_loop.py +0 -100
  60. dao_ai-0.0.25.dist-info/METADATA +0 -1165
  61. dao_ai-0.0.25.dist-info/RECORD +0 -41
  62. {dao_ai-0.0.25.dist-info → dao_ai-0.1.2.dist-info}/entry_points.txt +0 -0
  63. {dao_ai-0.0.25.dist-info → dao_ai-0.1.2.dist-info}/licenses/LICENSE +0 -0
dao_ai/memory/core.py CHANGED
@@ -1,8 +1,6 @@
 from typing import Any
 
-from databricks_langchain import (
-    DatabricksEmbeddings,
-)
+from databricks_langchain import DatabricksEmbeddings
 from langchain_core.embeddings.embeddings import Embeddings
 from langgraph.checkpoint.base import BaseCheckpointSaver
 from langgraph.checkpoint.memory import InMemorySaver
@@ -27,11 +25,13 @@ class InMemoryStoreManager(StoreManagerBase):
         self.store_model = store_model
 
     def store(self) -> BaseStore:
-        logger.debug("Creating InMemory store")
+        embedding_model: LLMModel = self.store_model.embedding_model
 
-        index: dict[str, Any] = None
+        logger.debug(
+            "Creating in-memory store", embeddings_enabled=embedding_model is not None
+        )
 
-        embedding_model: LLMModel = self.store_model.embedding_model
+        index: dict[str, Any] = None
 
         if embedding_model:
             embeddings: Embeddings = DatabricksEmbeddings(endpoint=embedding_model.name)
@@ -41,6 +41,11 @@ class InMemoryStoreManager(StoreManagerBase):
 
             dims: int = self.store_model.dims
             index = {"dims": dims, "embed": embed_texts}
+            logger.debug(
+                "Store embeddings configured",
+                endpoint=embedding_model.name,
+                dimensions=dims,
+            )
 
         store: BaseStore = InMemoryStore(index=index)
 
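Note on the index mapping built above: LangGraph's InMemoryStore accepts an index configuration with "dims" (embedding dimensionality) and "embed" (a callable or Embeddings instance) to enable semantic search over stored items. A minimal sketch of the same pattern, with a toy embedding function standing in for the DatabricksEmbeddings-backed embed_texts used here:

    # Sketch only: fake_embed is a hypothetical stand-in for embed_texts.
    from langgraph.store.memory import InMemoryStore

    def fake_embed(texts: list[str]) -> list[list[float]]:
        # One 4-dimensional vector per input text.
        return [[float(len(t)), 0.0, 0.0, 0.0] for t in texts]

    store = InMemoryStore(index={"dims": 4, "embed": fake_embed})
    store.put(("users", "alice"), "prefs", {"text": "prefers dark mode"})
    results = store.search(("users", "alice"), query="dark mode", limit=5)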
@@ -60,26 +65,39 @@ class StoreManager:
 
     @classmethod
     def instance(cls, store_model: StoreModel) -> StoreManagerBase:
-        store_manager: StoreManagerBase = None
-        match store_model.type:
+        store_manager: StoreManagerBase | None = None
+        match store_model.storage_type:
             case StorageType.MEMORY:
                 store_manager = cls.store_managers.get(store_model.name)
                 if store_manager is None:
                     store_manager = InMemoryStoreManager(store_model)
                     cls.store_managers[store_model.name] = store_manager
             case StorageType.POSTGRES:
-                from dao_ai.memory.postgres import PostgresStoreManager
+                # Route based on database configuration: instance_name -> Databricks, host -> Postgres
+                if store_model.database.is_lakebase:
+                    # Databricks Lakebase connection
+                    from dao_ai.memory.databricks import DatabricksStoreManager
 
-                store_manager = cls.store_managers.get(
-                    store_model.database.instance_name
-                )
-                if store_manager is None:
-                    store_manager = PostgresStoreManager(store_model)
-                    cls.store_managers[store_model.database.instance_name] = (
-                        store_manager
+                    store_manager = cls.store_managers.get(
+                        store_model.database.instance_name
                     )
+                    if store_manager is None:
+                        store_manager = DatabricksStoreManager(store_model)
+                        cls.store_managers[store_model.database.instance_name] = (
+                            store_manager
+                        )
+                else:
+                    # Standard PostgreSQL connection
+                    from dao_ai.memory.postgres import PostgresStoreManager
+
+                    # Use database name as key for standard PostgreSQL
+                    cache_key = f"{store_model.database.name}"
+                    store_manager = cls.store_managers.get(cache_key)
+                    if store_manager is None:
+                        store_manager = PostgresStoreManager(store_model)
+                        cls.store_managers[cache_key] = store_manager
             case _:
-                raise ValueError(f"Unknown store type: {store_model.type}")
+                raise ValueError(f"Unknown storage type: {store_model.storage_type}")
 
         return store_manager
 
@@ -89,8 +107,8 @@ class CheckpointManager:
 
     @classmethod
     def instance(cls, checkpointer_model: CheckpointerModel) -> CheckpointManagerBase:
-        checkpointer_manager: CheckpointManagerBase = None
-        match checkpointer_model.type:
+        checkpointer_manager: CheckpointManagerBase | None = None
+        match checkpointer_model.storage_type:
             case StorageType.MEMORY:
                 checkpointer_manager = cls.checkpoint_managers.get(
                     checkpointer_model.name
@@ -103,19 +121,36 @@
                     checkpointer_manager
                 )
             case StorageType.POSTGRES:
-                from dao_ai.memory.postgres import AsyncPostgresCheckpointerManager
+                # Route based on database configuration: instance_name -> Databricks, host -> Postgres
+                if checkpointer_model.database.is_lakebase:
+                    # Databricks Lakebase connection
+                    from dao_ai.memory.databricks import DatabricksCheckpointerManager
 
-                checkpointer_manager = cls.checkpoint_managers.get(
-                    checkpointer_model.database.instance_name
-                )
-                if checkpointer_manager is None:
-                    checkpointer_manager = AsyncPostgresCheckpointerManager(
-                        checkpointer_model
-                    )
-                    cls.checkpoint_managers[
+                    checkpointer_manager = cls.checkpoint_managers.get(
                         checkpointer_model.database.instance_name
-                    ] = checkpointer_manager
+                    )
+                    if checkpointer_manager is None:
+                        checkpointer_manager = DatabricksCheckpointerManager(
+                            checkpointer_model
+                        )
+                        cls.checkpoint_managers[
+                            checkpointer_model.database.instance_name
+                        ] = checkpointer_manager
+                else:
+                    # Standard PostgreSQL connection
+                    from dao_ai.memory.postgres import AsyncPostgresCheckpointerManager
+
+                    # Use database name as key for standard PostgreSQL
+                    cache_key = f"{checkpointer_model.database.name}"
+                    checkpointer_manager = cls.checkpoint_managers.get(cache_key)
+                    if checkpointer_manager is None:
+                        checkpointer_manager = AsyncPostgresCheckpointerManager(
+                            checkpointer_model
+                        )
+                        cls.checkpoint_managers[cache_key] = checkpointer_manager
             case _:
-                raise ValueError(f"Unknown store type: {checkpointer_model.type}")
+                raise ValueError(
+                    f"Unknown storage type: {checkpointer_model.storage_type}"
+                )
 
         return checkpointer_manager
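Net effect of the routing added above: a StorageType.POSTGRES entry now fans out on the database configuration, with Lakebase-backed managers cached under database.instance_name and plain PostgreSQL managers cached under database.name. A simplified, self-contained sketch of that selection logic (this DatabaseModel is illustrative only, not the real dao_ai.config class):

    # Hypothetical stand-in mirroring the is_lakebase routing in the diff.
    from dataclasses import dataclass

    @dataclass
    class DatabaseModel:
        name: str
        instance_name: str | None = None  # set for Databricks Lakebase
        host: str | None = None  # set for standard PostgreSQL

        @property
        def is_lakebase(self) -> bool:
            # instance_name -> Databricks, host -> Postgres
            return self.instance_name is not None

    def cache_key_for(db: DatabaseModel) -> str:
        # Lakebase managers cache per instance; Postgres managers per database name.
        return db.instance_name if db.is_lakebase else db.name

    print(cache_key_for(DatabaseModel(name="mem", instance_name="lakebase-1")))  # lakebase-1
    print(cache_key_for(DatabaseModel(name="mem", host="db.example.com")))  # mem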
dao_ai/memory/databricks.py ADDED
@@ -0,0 +1,402 @@
+"""
+Databricks-native memory storage implementations.
+
+Provides CheckpointSaver and DatabricksStore implementations using
+Databricks Lakebase for persistent storage, with async support.
+
+See:
+- https://api-docs.databricks.com/python/databricks-ai-bridge/latest/databricks_langchain.html#databricks_langchain.CheckpointSaver
+- https://api-docs.databricks.com/python/databricks-ai-bridge/latest/databricks_langchain.html#databricks_langchain.DatabricksStore
+"""
+
+import asyncio
+from collections.abc import AsyncIterator, Iterable, Sequence
+from functools import partial
+from typing import Any, Literal
+
+from databricks_langchain import (
+    CheckpointSaver as DatabricksCheckpointSaver,
+)
+from databricks_langchain import (
+    DatabricksEmbeddings,
+    DatabricksStore,
+)
+from langchain_core.runnables import RunnableConfig
+from langgraph.checkpoint.base import (
+    BaseCheckpointSaver,
+    ChannelVersions,
+    Checkpoint,
+    CheckpointMetadata,
+    CheckpointTuple,
+)
+from langgraph.store.base import BaseStore, Item, Op, Result, SearchItem
+from loguru import logger
+
+from dao_ai.config import (
+    CheckpointerModel,
+    StoreModel,
+)
+from dao_ai.memory.base import (
+    CheckpointManagerBase,
+    StoreManagerBase,
+)
+
+# Type alias for namespace path
+NamespacePath = tuple[str, ...]
+
+# Sentinel for not-provided values
+NOT_PROVIDED = object()
+
+
+class AsyncDatabricksCheckpointSaver(DatabricksCheckpointSaver):
+    """
+    Async wrapper for DatabricksCheckpointSaver.
+
+    Provides async implementations of checkpoint methods by delegating
+    to the sync methods using asyncio.to_thread().
+    """
+
+    async def aget_tuple(self, config: RunnableConfig) -> CheckpointTuple | None:
+        """Async version of get_tuple."""
+        thread_id = config.get("configurable", {}).get("thread_id", "unknown")
+        logger.trace("Fetching checkpoint", thread_id=thread_id, method="aget_tuple")
+        result = await asyncio.to_thread(self.get_tuple, config)
+        if result:
+            logger.trace("Checkpoint found", thread_id=thread_id)
+        else:
+            logger.trace("No checkpoint found", thread_id=thread_id)
+        return result
+
+    async def aget(self, config: RunnableConfig) -> Checkpoint | None:
+        """Async version of get."""
+        thread_id = config.get("configurable", {}).get("thread_id", "unknown")
+        logger.trace("Fetching checkpoint", thread_id=thread_id, method="aget")
+        result = await asyncio.to_thread(self.get, config)
+        if result:
+            logger.trace("Checkpoint found", thread_id=thread_id)
+        else:
+            logger.trace("No checkpoint found", thread_id=thread_id)
+        return result
+
+    async def aput(
+        self,
+        config: RunnableConfig,
+        checkpoint: Checkpoint,
+        metadata: CheckpointMetadata,
+        new_versions: ChannelVersions,
+    ) -> RunnableConfig:
+        """Async version of put."""
+        thread_id = config.get("configurable", {}).get("thread_id", "unknown")
+        checkpoint_id = checkpoint.get("id", "unknown")
+        logger.trace(
+            "Saving checkpoint", checkpoint_id=checkpoint_id, thread_id=thread_id
+        )
+        result = await asyncio.to_thread(
+            self.put, config, checkpoint, metadata, new_versions
+        )
+        logger.trace(
+            "Checkpoint saved", thread_id=thread_id, checkpoint_id=checkpoint_id
+        )
+        return result
+
+    async def aput_writes(
+        self,
+        config: RunnableConfig,
+        writes: Sequence[tuple[str, Any]],
+        task_id: str,
+        task_path: str = "",
+    ) -> None:
+        """Async version of put_writes."""
+        thread_id = config.get("configurable", {}).get("thread_id", "unknown")
+        logger.trace(
+            "Saving checkpoint writes",
+            writes_count=len(writes),
+            thread_id=thread_id,
+            task_id=task_id,
+        )
+        await asyncio.to_thread(self.put_writes, config, writes, task_id, task_path)
+        logger.trace("Checkpoint writes saved", thread_id=thread_id, task_id=task_id)
+
+    async def alist(
+        self,
+        config: RunnableConfig | None,
+        *,
+        filter: dict[str, Any] | None = None,
+        before: RunnableConfig | None = None,
+        limit: int | None = None,
+    ) -> AsyncIterator[CheckpointTuple]:
+        """Async version of list."""
+        thread_id = (
+            config.get("configurable", {}).get("thread_id", "unknown")
+            if config
+            else "all"
+        )
+        logger.trace("Listing checkpoints", thread_id=thread_id, limit=limit)
+        # Get all items from sync iterator in a thread
+        items = await asyncio.to_thread(
+            lambda: list(self.list(config, filter=filter, before=before, limit=limit))
+        )
+        logger.debug("Checkpoints listed", thread_id=thread_id, count=len(items))
+        for item in items:
+            yield item
+
+    async def adelete_thread(self, thread_id: str) -> None:
+        """Async version of delete_thread."""
+        logger.trace("Deleting thread", thread_id=thread_id)
+        await asyncio.to_thread(self.delete_thread, thread_id)
+        logger.debug("Thread deleted", thread_id=thread_id)
+
+
+class AsyncDatabricksStore(DatabricksStore):
+    """
+    Async wrapper for DatabricksStore.
+
+    Provides async implementations of store methods by delegating
+    to the sync methods using asyncio.to_thread().
+    """
+
+    async def abatch(self, ops: Iterable[Op]) -> list[Result]:
+        """Async version of batch."""
+        ops_list = list(ops)
+        logger.trace("Executing batch operations", operations_count=len(ops_list))
+        result = await asyncio.to_thread(self.batch, ops_list)
+        logger.debug("Batch operations completed", operations_count=len(result))
+        return result
+
+    async def aget(
+        self,
+        namespace: tuple[str, ...],
+        key: str,
+        *,
+        refresh_ttl: bool | None = None,
+    ) -> Item | None:
+        """Async version of get."""
+        ns_str = "/".join(namespace)
+        logger.trace("Fetching store item", key=key, namespace=ns_str)
+        result = await asyncio.to_thread(
+            partial(self.get, namespace, key, refresh_ttl=refresh_ttl)
+        )
+        if result:
+            logger.trace("Store item found", key=key, namespace=ns_str)
+        else:
+            logger.trace("Store item not found", key=key, namespace=ns_str)
+        return result
+
+    async def aput(
+        self,
+        namespace: tuple[str, ...],
+        key: str,
+        value: dict[str, Any],
+        index: Literal[False] | list[str] | None = None,
+        *,
+        ttl: float | None = None,
+    ) -> None:
+        """Async version of put."""
+        ns_str = "/".join(namespace)
+        logger.trace("Storing item", key=key, namespace=ns_str, has_ttl=ttl is not None)
+        # Handle the ttl parameter - only pass if explicitly provided
+        if ttl is not None:
+            await asyncio.to_thread(
+                partial(self.put, namespace, key, value, index, ttl=ttl)
+            )
+        else:
+            await asyncio.to_thread(partial(self.put, namespace, key, value, index))
+        logger.trace("Item stored", key=key, namespace=ns_str)
+
+    async def adelete(self, namespace: tuple[str, ...], key: str) -> None:
+        """Async version of delete."""
+        ns_str = "/".join(namespace)
+        logger.trace("Deleting item", key=key, namespace=ns_str)
+        await asyncio.to_thread(self.delete, namespace, key)
+        logger.trace("Item deleted", key=key, namespace=ns_str)
+
+    async def asearch(
+        self,
+        namespace_prefix: tuple[str, ...],
+        /,
+        *,
+        query: str | None = None,
+        filter: dict[str, Any] | None = None,
+        limit: int = 10,
+        offset: int = 0,
+        refresh_ttl: bool | None = None,
+    ) -> list[SearchItem]:
+        """Async version of search."""
+        ns_str = "/".join(namespace_prefix)
+        logger.trace(
+            "Searching store", namespace_prefix=ns_str, query=query, limit=limit
+        )
+        result = await asyncio.to_thread(
+            partial(
+                self.search,
+                namespace_prefix,
+                query=query,
+                filter=filter,
+                limit=limit,
+                offset=offset,
+                refresh_ttl=refresh_ttl,
+            )
+        )
+        logger.debug(
+            "Store search completed", namespace_prefix=ns_str, results_count=len(result)
+        )
+        return result
+
+    async def alist_namespaces(
+        self,
+        *,
+        prefix: NamespacePath | None = None,
+        suffix: NamespacePath | None = None,
+        max_depth: int | None = None,
+        limit: int = 100,
+        offset: int = 0,
+    ) -> list[tuple[str, ...]]:
+        """Async version of list_namespaces."""
+        prefix_str = "/".join(prefix) if prefix else "all"
+        logger.trace("Listing namespaces", prefix=prefix_str, limit=limit)
+        result = await asyncio.to_thread(
+            partial(
+                self.list_namespaces,
+                prefix=prefix,
+                suffix=suffix,
+                max_depth=max_depth,
+                limit=limit,
+                offset=offset,
+            )
+        )
+        logger.debug("Namespaces listed", count=len(result))
+        return result
+
+
+class DatabricksCheckpointerManager(CheckpointManagerBase):
+    """
+    Checkpointer manager using Databricks CheckpointSaver with async support.
+
+    Uses AsyncDatabricksCheckpointSaver which wraps databricks_langchain.CheckpointSaver
+    with async method implementations for LangGraph async streaming compatibility.
+
+    Required configuration via CheckpointerModel.database:
+    - instance_name: The Databricks Lakebase instance name
+    - workspace_client: WorkspaceClient (supports OBO, service principal, or default auth)
+
+    See: https://api-docs.databricks.com/python/databricks-ai-bridge/latest/databricks_langchain.html#databricks_langchain.CheckpointSaver
+    """
+
+    def __init__(self, checkpointer_model: CheckpointerModel):
+        self.checkpointer_model = checkpointer_model
+        self._checkpointer: BaseCheckpointSaver | None = None
+
+    def checkpointer(self) -> BaseCheckpointSaver:
+        if self._checkpointer is None:
+            database = self.checkpointer_model.database
+            if database is None:
+                raise ValueError(
+                    "Database configuration is required for Databricks checkpointer. "
+                    "Please provide a 'database' field in the checkpointer configuration."
+                )
+
+            instance_name = database.instance_name
+            workspace_client = database.workspace_client
+
+            logger.debug(
+                "Creating Databricks checkpointer", instance_name=instance_name
+            )
+
+            checkpointer = AsyncDatabricksCheckpointSaver(
+                instance_name=instance_name,
+                workspace_client=workspace_client,
+            )
+
+            # Setup the checkpointer (creates necessary tables if needed)
+            logger.debug("Setting up checkpoint tables", instance_name=instance_name)
+            checkpointer.setup()
+            logger.success(
+                "Databricks checkpointer initialized", instance_name=instance_name
+            )
+
+            self._checkpointer = checkpointer
+
+        return self._checkpointer
+
+
+class DatabricksStoreManager(StoreManagerBase):
+    """
+    Store manager using Databricks DatabricksStore with async support.
+
+    Uses AsyncDatabricksStore which wraps databricks_langchain.DatabricksStore
+    with async method implementations for LangGraph async streaming compatibility.
+
+    Required configuration via StoreModel.database:
+    - instance_name: The Databricks Lakebase instance name
+    - workspace_client: WorkspaceClient (supports OBO, service principal, or default auth)
+
+    Optional configuration via StoreModel:
+    - embedding_model: LLMModel for embeddings (will be converted to DatabricksEmbeddings)
+    - dims: Embedding dimensions
+
+    See: https://api-docs.databricks.com/python/databricks-ai-bridge/latest/databricks_langchain.html#databricks_langchain.DatabricksStore
+    """
+
+    def __init__(self, store_model: StoreModel):
+        self.store_model = store_model
+        self._store: BaseStore | None = None
+
+    def store(self) -> BaseStore:
+        if self._store is None:
+            database = self.store_model.database
+            if database is None:
+                raise ValueError(
+                    "Database configuration is required for Databricks store. "
+                    "Please provide a 'database' field in the store configuration."
+                )
+
+            instance_name = database.instance_name
+            workspace_client = database.workspace_client
+
+            # Build embeddings configuration if embedding_model is provided
+            embeddings: DatabricksEmbeddings | None = None
+            embedding_dims: int | None = None
+
+            if self.store_model.embedding_model is not None:
+                embedding_endpoint = self.store_model.embedding_model.name
+                embedding_dims = self.store_model.dims
+
+                logger.debug(
+                    "Configuring store embeddings",
+                    endpoint=embedding_endpoint,
+                    dimensions=embedding_dims,
+                )
+
+                embeddings = DatabricksEmbeddings(endpoint=embedding_endpoint)
+
+            logger.debug(
+                "Creating Databricks store",
+                instance_name=instance_name,
+                embeddings_enabled=embeddings is not None,
+            )
+
+            store = AsyncDatabricksStore(
+                instance_name=instance_name,
+                workspace_client=workspace_client,
+                embeddings=embeddings,
+                embedding_dims=embedding_dims,
+            )
+
+            # Setup the store (creates necessary tables if needed)
+            store.setup()
+            logger.success(
+                "Databricks store initialized",
+                instance_name=instance_name,
+                embeddings_enabled=embeddings is not None,
+            )
+            self._store = store
+
+        return self._store
+
+
+__all__ = [
+    "AsyncDatabricksCheckpointSaver",
+    "AsyncDatabricksStore",
+    "DatabricksCheckpointerManager",
+    "DatabricksStoreManager",
+]
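Since every a-prefixed method above delegates to its sync counterpart via asyncio.to_thread(), the wrappers can be awaited from LangGraph's async streaming paths without blocking the event loop. A hedged usage sketch ("my-lakebase" is a made-up instance name, and the constructor arguments simply mirror what DatabricksStoreManager passes above):

    # Sketch only; assumes a reachable Lakebase instance named "my-lakebase".
    import asyncio

    from dao_ai.memory.databricks import AsyncDatabricksStore

    async def main() -> None:
        store = AsyncDatabricksStore(instance_name="my-lakebase")
        store.setup()  # sync table setup, as in DatabricksStoreManager.store()

        # Each awaited call runs the sync DatabricksStore method in a worker thread.
        await store.aput(("users", "alice"), "prefs", {"theme": "dark"})
        item = await store.aget(("users", "alice"), "prefs")
        print(item.value if item else "not found")

    asyncio.run(main())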