dao-ai 0.0.35__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (58) hide show
  1. dao_ai/__init__.py +29 -0
  2. dao_ai/cli.py +195 -30
  3. dao_ai/config.py +797 -242
  4. dao_ai/genie/__init__.py +38 -0
  5. dao_ai/genie/cache/__init__.py +43 -0
  6. dao_ai/genie/cache/base.py +72 -0
  7. dao_ai/genie/cache/core.py +75 -0
  8. dao_ai/genie/cache/lru.py +329 -0
  9. dao_ai/genie/cache/semantic.py +919 -0
  10. dao_ai/genie/core.py +35 -0
  11. dao_ai/graph.py +27 -253
  12. dao_ai/hooks/__init__.py +9 -6
  13. dao_ai/hooks/core.py +22 -190
  14. dao_ai/memory/__init__.py +10 -0
  15. dao_ai/memory/core.py +23 -5
  16. dao_ai/memory/databricks.py +389 -0
  17. dao_ai/memory/postgres.py +2 -2
  18. dao_ai/messages.py +6 -4
  19. dao_ai/middleware/__init__.py +125 -0
  20. dao_ai/middleware/assertions.py +778 -0
  21. dao_ai/middleware/base.py +50 -0
  22. dao_ai/middleware/core.py +61 -0
  23. dao_ai/middleware/guardrails.py +415 -0
  24. dao_ai/middleware/human_in_the_loop.py +228 -0
  25. dao_ai/middleware/message_validation.py +554 -0
  26. dao_ai/middleware/summarization.py +192 -0
  27. dao_ai/models.py +1177 -108
  28. dao_ai/nodes.py +118 -161
  29. dao_ai/optimization.py +664 -0
  30. dao_ai/orchestration/__init__.py +52 -0
  31. dao_ai/orchestration/core.py +287 -0
  32. dao_ai/orchestration/supervisor.py +264 -0
  33. dao_ai/orchestration/swarm.py +226 -0
  34. dao_ai/prompts.py +126 -29
  35. dao_ai/providers/databricks.py +126 -381
  36. dao_ai/state.py +139 -21
  37. dao_ai/tools/__init__.py +11 -5
  38. dao_ai/tools/core.py +57 -4
  39. dao_ai/tools/email.py +280 -0
  40. dao_ai/tools/genie.py +108 -35
  41. dao_ai/tools/mcp.py +4 -3
  42. dao_ai/tools/memory.py +50 -0
  43. dao_ai/tools/python.py +4 -12
  44. dao_ai/tools/search.py +14 -0
  45. dao_ai/tools/slack.py +1 -1
  46. dao_ai/tools/unity_catalog.py +8 -6
  47. dao_ai/tools/vector_search.py +16 -9
  48. dao_ai/utils.py +72 -8
  49. dao_ai-0.1.0.dist-info/METADATA +1878 -0
  50. dao_ai-0.1.0.dist-info/RECORD +62 -0
  51. dao_ai/chat_models.py +0 -204
  52. dao_ai/guardrails.py +0 -112
  53. dao_ai/tools/human_in_the_loop.py +0 -100
  54. dao_ai-0.0.35.dist-info/METADATA +0 -1169
  55. dao_ai-0.0.35.dist-info/RECORD +0 -41
  56. {dao_ai-0.0.35.dist-info → dao_ai-0.1.0.dist-info}/WHEEL +0 -0
  57. {dao_ai-0.0.35.dist-info → dao_ai-0.1.0.dist-info}/entry_points.txt +0 -0
  58. {dao_ai-0.0.35.dist-info → dao_ai-0.1.0.dist-info}/licenses/LICENSE +0 -0
dao_ai/genie/core.py ADDED
@@ -0,0 +1,35 @@
1
+ """
2
+ Core Genie service implementation.
3
+
4
+ This module provides the concrete implementation of GenieServiceBase
5
+ that wraps the Databricks Genie SDK.
6
+ """
7
+
8
+ import mlflow
9
+ from databricks_ai_bridge.genie import Genie, GenieResponse
10
+
11
+ from dao_ai.genie.cache import CacheResult, GenieServiceBase
12
+
13
+
14
+ class GenieService(GenieServiceBase):
15
+ """Concrete implementation of GenieServiceBase using the Genie SDK."""
16
+
17
+ genie: Genie
18
+
19
+ def __init__(self, genie: Genie) -> None:
20
+ self.genie = genie
21
+
22
+ @mlflow.trace(name="genie_ask_question")
23
+ def ask_question(
24
+ self, question: str, conversation_id: str | None = None
25
+ ) -> CacheResult:
26
+ """Ask question to Genie and return CacheResult (no caching at this level)."""
27
+ response: GenieResponse = self.genie.ask_question(
28
+ question, conversation_id=conversation_id
29
+ )
30
+ # No caching at this level - return cache miss
31
+ return CacheResult(response=response, cache_hit=False, served_by=None)
32
+
33
+ @property
34
+ def space_id(self) -> str:
35
+ return self.genie.space_id
dao_ai/graph.py CHANGED
@@ -1,263 +1,37 @@
1
- from typing import Sequence
1
+ """
2
+ Graph creation utilities for DAO AI multi-agent orchestration.
2
3
 
3
- from langchain_core.language_models import LanguageModelLike
4
- from langchain_core.runnables import RunnableConfig
5
- from langchain_core.tools import BaseTool
6
- from langgraph.checkpoint.base import BaseCheckpointSaver
7
- from langgraph.graph import END, StateGraph
8
- from langgraph.graph.state import CompiledStateGraph
9
- from langgraph.store.base import BaseStore
10
- from langgraph_supervisor import create_handoff_tool as supervisor_handoff_tool
11
- from langgraph_supervisor import create_supervisor
12
- from langgraph_swarm import create_handoff_tool as swarm_handoff_tool
13
- from langgraph_swarm import create_swarm
14
- from langmem import create_manage_memory_tool, create_search_memory_tool
15
- from loguru import logger
16
-
17
- from dao_ai.config import (
18
- AgentModel,
19
- AppConfig,
20
- OrchestrationModel,
21
- SupervisorModel,
22
- SwarmModel,
23
- )
24
- from dao_ai.nodes import (
25
- create_agent_node,
26
- message_hook_node,
27
- )
28
- from dao_ai.prompts import make_prompt
29
- from dao_ai.state import Context, IncomingState, OutgoingState, SharedState
30
- from dao_ai.tools import create_tools
31
-
32
-
33
- def route_message(state: SharedState) -> str:
34
- if not state["is_valid"]:
35
- return END
36
- return "orchestration"
37
-
38
-
39
- def _handoffs_for_agent(agent: AgentModel, config: AppConfig) -> Sequence[BaseTool]:
40
- handoff_tools: list[BaseTool] = []
41
-
42
- handoffs: dict[str, Sequence[AgentModel | str]] = (
43
- config.app.orchestration.swarm.handoffs or {}
44
- )
45
- agent_handoffs: Sequence[AgentModel | str] = handoffs.get(agent.name)
46
- if agent_handoffs is None:
47
- agent_handoffs = config.app.agents
48
-
49
- for handoff_to_agent in agent_handoffs:
50
- if isinstance(handoff_to_agent, str):
51
- handoff_to_agent = next(
52
- iter(config.find_agents(lambda a: a.name == handoff_to_agent)), None
53
- )
54
-
55
- if handoff_to_agent is None:
56
- logger.warning(
57
- f"Handoff agent {handoff_to_agent} not found in configuration for agent {agent.name}"
58
- )
59
- continue
60
- if agent.name == handoff_to_agent.name:
61
- continue
62
- logger.debug(
63
- f"Creating handoff tool from agent {agent.name} to {handoff_to_agent.name}"
64
- )
65
-
66
- # Use handoff_prompt if provided, otherwise create default description
67
- handoff_description = handoff_to_agent.handoff_prompt or (
68
- handoff_to_agent.description
69
- if handoff_to_agent.description
70
- else "general assistance and questions"
71
- )
72
-
73
- handoff_tools.append(
74
- swarm_handoff_tool(
75
- agent_name=handoff_to_agent.name,
76
- description=f"Ask {handoff_to_agent.name} for help with: "
77
- + handoff_description,
78
- )
79
- )
80
- return handoff_tools
81
-
82
-
83
- def _create_supervisor_graph(config: AppConfig) -> CompiledStateGraph:
84
- logger.debug("Creating supervisor graph")
85
- agents: list[CompiledStateGraph] = []
86
- tools: Sequence[BaseTool] = []
87
- for registered_agent in config.app.agents:
88
- agents.append(
89
- create_agent_node(
90
- agent=registered_agent,
91
- memory=config.app.orchestration.memory
92
- if config.app.orchestration
93
- else None,
94
- chat_history=config.app.chat_history,
95
- additional_tools=[],
96
- )
97
- )
98
- # Use handoff_prompt if provided, otherwise create default description
99
- handoff_description = registered_agent.handoff_prompt or (
100
- registered_agent.description
101
- if registered_agent.description
102
- else f"General assistance with {registered_agent.name} related tasks"
103
- )
104
-
105
- tools.append(
106
- supervisor_handoff_tool(
107
- agent_name=registered_agent.name,
108
- description=handoff_description,
109
- )
110
- )
111
-
112
- orchestration: OrchestrationModel = config.app.orchestration
113
- supervisor: SupervisorModel = orchestration.supervisor
114
-
115
- tools += create_tools(orchestration.supervisor.tools)
116
-
117
- store: BaseStore = None
118
- if orchestration.memory and orchestration.memory.store:
119
- store = orchestration.memory.store.as_store()
120
- logger.debug(f"Using memory store: {store}")
121
- namespace: tuple[str, ...] = ("memory",)
122
-
123
- if orchestration.memory.store.namespace:
124
- namespace = namespace + (orchestration.memory.store.namespace,)
125
- logger.debug(f"Memory store namespace: {namespace}")
126
- tools += [
127
- create_manage_memory_tool(namespace=namespace),
128
- create_search_memory_tool(namespace=namespace),
129
- ]
130
-
131
- checkpointer: BaseCheckpointSaver = None
132
- if orchestration.memory and orchestration.memory.checkpointer:
133
- checkpointer = orchestration.memory.checkpointer.as_checkpointer()
134
- logger.debug(f"Using checkpointer: {checkpointer}")
4
+ This module provides backwards-compatible imports from the refactored
5
+ orchestration package. New code should import directly from:
6
+ - dao_ai.orchestration
7
+ - dao_ai.orchestration.supervisor
8
+ - dao_ai.orchestration.swarm
135
9
 
136
- prompt: str = supervisor.prompt
10
+ See: https://docs.langchain.com/oss/python/langchain/multi-agent
11
+ """
137
12
 
138
- model: LanguageModelLike = supervisor.model.as_chat_model()
139
- supervisor_workflow: StateGraph = create_supervisor(
140
- supervisor_name="supervisor",
141
- prompt=make_prompt(base_system_prompt=prompt),
142
- agents=agents,
143
- model=model,
144
- tools=tools,
145
- state_schema=SharedState,
146
- config_schema=RunnableConfig,
147
- output_mode="last_message",
148
- add_handoff_messages=False,
149
- add_handoff_back_messages=False,
150
- context_schema=Context,
151
- # output_mode="full",
152
- # add_handoff_messages=True,
153
- # add_handoff_back_messages=True,
154
- )
155
-
156
- supervisor_node: CompiledStateGraph = supervisor_workflow.compile(
157
- checkpointer=checkpointer, store=store
158
- )
159
-
160
- workflow: StateGraph = StateGraph(
161
- SharedState,
162
- input_schema=IncomingState,
163
- output_schema=OutgoingState,
164
- context_schema=Context,
165
- )
166
-
167
- workflow.add_node("message_hook", message_hook_node(config=config))
168
-
169
- workflow.add_node("orchestration", supervisor_node)
170
- workflow.add_conditional_edges(
171
- "message_hook",
172
- route_message,
173
- {
174
- "orchestration": "orchestration",
175
- END: END,
176
- },
177
- )
178
- workflow.set_entry_point("message_hook")
179
-
180
- return workflow.compile(checkpointer=checkpointer, store=store)
181
-
182
-
183
- def _create_swarm_graph(config: AppConfig) -> CompiledStateGraph:
184
- logger.debug("Creating swarm graph")
185
- agents: list[CompiledStateGraph] = []
186
- for registered_agent in config.app.agents:
187
- handoff_tools: Sequence[BaseTool] = _handoffs_for_agent(
188
- agent=registered_agent, config=config
189
- )
190
- agents.append(
191
- create_agent_node(
192
- agent=registered_agent,
193
- memory=config.app.orchestration.memory
194
- if config.app.orchestration
195
- else None,
196
- chat_history=config.app.chat_history,
197
- additional_tools=handoff_tools,
198
- )
199
- )
200
-
201
- orchestration: OrchestrationModel = config.app.orchestration
202
- swarm: SwarmModel = orchestration.swarm
203
-
204
- store: BaseStore = None
205
- if orchestration.memory and orchestration.memory.store:
206
- store = orchestration.memory.store.as_store()
207
- logger.debug(f"Using memory store: {store}")
208
-
209
- checkpointer: BaseCheckpointSaver = None
210
- if orchestration.memory and orchestration.memory.checkpointer:
211
- checkpointer = orchestration.memory.checkpointer.as_checkpointer()
212
- logger.debug(f"Using checkpointer: {checkpointer}")
213
-
214
- default_agent: AgentModel = swarm.default_agent
215
- if isinstance(default_agent, AgentModel):
216
- default_agent = default_agent.name
217
-
218
- swarm_workflow: StateGraph = create_swarm(
219
- agents=agents,
220
- default_active_agent=default_agent,
221
- state_schema=SharedState,
222
- context_schema=Context,
223
- )
224
-
225
- swarm_node: CompiledStateGraph = swarm_workflow.compile(
226
- checkpointer=checkpointer, store=store
227
- )
228
-
229
- workflow: StateGraph = StateGraph(
230
- SharedState,
231
- input_schema=IncomingState,
232
- output_schema=OutgoingState,
233
- context_schema=Context,
234
- )
235
-
236
- workflow.add_node("message_hook", message_hook_node(config=config))
237
- workflow.add_node("orchestration", swarm_node)
238
-
239
- workflow.add_conditional_edges(
240
- "message_hook",
241
- route_message,
242
- {
243
- "orchestration": "orchestration",
244
- END: END,
245
- },
246
- )
13
+ from langgraph.graph.state import CompiledStateGraph
247
14
 
248
- workflow.set_entry_point("message_hook")
15
+ from dao_ai.config import AppConfig
16
+ from dao_ai.orchestration import create_orchestration_graph
249
17
 
250
- return swarm_node
251
18
 
252
- # return workflow.compile(checkpointer=checkpointer, store=store)
19
+ def create_dao_ai_graph(config: AppConfig) -> CompiledStateGraph:
20
+ """
21
+ Create the main DAO AI graph based on the orchestration configuration.
253
22
 
23
+ This factory function creates either a supervisor or swarm graph
24
+ depending on the configuration.
254
25
 
255
- def create_dao_ai_graph(config: AppConfig) -> CompiledStateGraph:
256
- orchestration: OrchestrationModel = config.app.orchestration
257
- if orchestration.supervisor:
258
- return _create_supervisor_graph(config)
26
+ Args:
27
+ config: The application configuration
259
28
 
260
- if orchestration.swarm:
261
- return _create_swarm_graph(config)
29
+ Returns:
30
+ A compiled LangGraph state machine
262
31
 
263
- raise ValueError("No valid orchestration model found in the configuration.")
32
+ Note:
33
+ This function is provided for backwards compatibility.
34
+ New code should use `create_orchestration_graph` from
35
+ `dao_ai.orchestration` instead.
36
+ """
37
+ return create_orchestration_graph(config)
dao_ai/hooks/__init__.py CHANGED
@@ -1,11 +1,17 @@
1
+ """
2
+ Hook utilities for DAO AI.
3
+
4
+ For validation hooks, use middleware instead:
5
+ - dao_ai.middleware.UserIdValidationMiddleware
6
+ - dao_ai.middleware.ThreadIdValidationMiddleware
7
+ - dao_ai.middleware.FilterLastHumanMessageMiddleware
8
+ """
9
+
1
10
  from dao_ai.hooks.core import (
2
11
  create_hooks,
3
- filter_last_human_message_hook,
4
12
  null_hook,
5
13
  null_initialization_hook,
6
14
  null_shutdown_hook,
7
- require_thread_id_hook,
8
- require_user_id_hook,
9
15
  )
10
16
 
11
17
  __all__ = [
@@ -13,7 +19,4 @@ __all__ = [
13
19
  "null_hook",
14
20
  "null_initialization_hook",
15
21
  "null_shutdown_hook",
16
- "require_thread_id_hook",
17
- "require_user_id_hook",
18
- "filter_last_human_message_hook",
19
22
  ]
dao_ai/hooks/core.py CHANGED
@@ -1,20 +1,32 @@
1
- import json
1
+ """
2
+ Hook utilities for DAO AI.
3
+
4
+ This module provides the create_hooks function for resolving FunctionHook
5
+ references to callable functions. Individual validation hooks have been
6
+ migrated to middleware - see dao_ai.middleware.message_validation.
7
+ """
8
+
2
9
  from typing import Any, Callable, Sequence
3
10
 
4
- from langchain_core.messages import BaseMessage, HumanMessage, RemoveMessage
5
- from langgraph.runtime import Runtime
6
11
  from loguru import logger
7
12
 
8
13
  from dao_ai.config import AppConfig, FunctionHook, PythonFunctionModel
9
- from dao_ai.messages import last_human_message
10
- from dao_ai.state import Context
11
14
 
12
15
 
13
16
  def create_hooks(
14
17
  function_hooks: FunctionHook | list[FunctionHook] | None,
15
18
  ) -> Sequence[Callable[..., Any]]:
19
+ """
20
+ Resolve FunctionHook references to callable functions.
21
+
22
+ Args:
23
+ function_hooks: A single FunctionHook or list of FunctionHooks to resolve
24
+
25
+ Returns:
26
+ Sequence of callable functions
27
+ """
16
28
  logger.debug(f"Creating hooks from: {function_hooks}")
17
- hooks: Sequence[Callable[..., Any]] = []
29
+ hooks: list[Callable[..., Any]] = []
18
30
  if not function_hooks:
19
31
  return []
20
32
  if not isinstance(function_hooks, (list, tuple, set)):
@@ -27,197 +39,17 @@ def create_hooks(
27
39
  return hooks
28
40
 
29
41
 
30
- def null_hook(state: dict[str, Any], runtime: Runtime[Context]) -> dict[str, Any]:
42
+ def null_hook(state: dict[str, Any], config: Any) -> dict[str, Any]:
43
+ """A no-op hook that returns an empty dict."""
31
44
  logger.debug("Executing null hook")
32
45
  return {}
33
46
 
34
47
 
35
48
  def null_initialization_hook(config: AppConfig) -> None:
49
+ """A no-op initialization hook."""
36
50
  logger.debug("Executing null initialization hook")
37
51
 
38
52
 
39
53
  def null_shutdown_hook(config: AppConfig) -> None:
54
+ """A no-op shutdown hook."""
40
55
  logger.debug("Executing null shutdown hook")
41
-
42
-
43
- def filter_last_human_message_hook(
44
- state: dict[str, Any], runtime: Runtime[Context]
45
- ) -> dict[str, Any]:
46
- """
47
- Filter messages to keep only the last human message.
48
-
49
- This hook removes all messages except for the most recent human message,
50
- which can be useful for scenarios where you want to process only the
51
- latest user input without conversation history.
52
-
53
- Args:
54
- state: The current state containing messages
55
- runtime: The runtime context (unused in this hook)
56
-
57
- Returns:
58
- Updated state with filtered messages
59
- """
60
- logger.debug("Executing filter_last_human_message hook")
61
-
62
- messages: list[BaseMessage] = state.get("messages", [])
63
-
64
- if not messages:
65
- logger.debug("No messages found in state")
66
- return state
67
-
68
- # Use the helper function to find the last human message
69
- last_message: HumanMessage = last_human_message(messages)
70
-
71
- if last_message is None:
72
- logger.debug("No human messages found in state")
73
- # Return empty messages if no human message found
74
- updated_state = state.copy()
75
- updated_state["messages"] = []
76
- return updated_state
77
-
78
- logger.debug(f"Filtered {len(messages)} messages down to 1 (last human message)")
79
-
80
- removed_messages: Sequence[BaseMessage] = [
81
- RemoveMessage(id=message.id)
82
- for message in messages
83
- if message.id != last_message.id
84
- ]
85
-
86
- updated_state: dict[str, Sequence[BaseMessage]] = {"messages": removed_messages}
87
-
88
- return updated_state
89
-
90
-
91
- def require_user_id_hook(
92
- state: dict[str, Any], runtime: Runtime[Context]
93
- ) -> dict[str, Any]:
94
- logger.debug("Executing user_id validation hook")
95
-
96
- context: Context = runtime.context or Context()
97
-
98
- user_id: str | None = context.user_id
99
-
100
- if not user_id:
101
- logger.error("User ID is required but not provided in the configuration.")
102
-
103
- # Create corrected configuration using any provided context parameters
104
- corrected_config = {
105
- "configurable": {
106
- "thread_id": context.thread_id or "1",
107
- "user_id": "my_user_id",
108
- "store_num": context.store_num or 87887,
109
- }
110
- }
111
-
112
- # Format as JSON for copy-paste
113
- corrected_config_json = json.dumps(corrected_config, indent=2)
114
-
115
- error_message = f"""
116
- ## Authentication Required
117
-
118
- A **user_id** is required to process your request. Please provide your user ID in the configuration.
119
-
120
- ### Required Configuration Format
121
-
122
- Please include the following JSON in your request configuration:
123
-
124
- ```json
125
- {corrected_config_json}
126
- ```
127
-
128
- ### Field Descriptions
129
- - **user_id**: Your unique user identifier (required)
130
- - **thread_id**: Conversation thread identifier (required)
131
- - **store_num**: Your store number (required)
132
-
133
- Please update your configuration and try again.
134
- """.strip()
135
-
136
- raise ValueError(error_message)
137
-
138
- if "." in user_id:
139
- logger.error(f"User ID '{user_id}' contains invalid character '.'")
140
-
141
- # Create a corrected version of the user_id
142
- corrected_user_id = user_id.replace(".", "_")
143
-
144
- # Create corrected configuration for the error message
145
-
146
- # Corrected config with fixed user_id
147
- corrected_config = {
148
- "configurable": {
149
- "thread_id": context.thread_id or "1",
150
- "user_id": corrected_user_id,
151
- "store_num": context.store_num or 87887,
152
- }
153
- }
154
-
155
- # Format as JSON for copy-paste
156
- corrected_config_json = json.dumps(corrected_config, indent=2)
157
-
158
- error_message = f"""
159
- ## Invalid User ID Format
160
-
161
- The **user_id** cannot contain a dot character ('.'). Please provide a valid user ID without dots.
162
-
163
- ### Corrected Configuration (Copy & Paste This)
164
- ```json
165
- {corrected_config_json}
166
- ```
167
-
168
- Please update your user_id and try again.
169
- """.strip()
170
-
171
- raise ValueError(error_message)
172
-
173
- return {}
174
-
175
-
176
- def require_thread_id_hook(
177
- state: dict[str, Any], runtime: Runtime[Context]
178
- ) -> dict[str, Any]:
179
- logger.debug("Executing thread_id validation hook")
180
-
181
- context: Context = runtime.context or Context()
182
-
183
- thread_id: str | None = context.thread_id
184
-
185
- if not thread_id:
186
- logger.error("Thread ID is required but not provided in the configuration.")
187
-
188
- # Create corrected configuration using any provided context parameters
189
- corrected_config = {
190
- "configurable": {
191
- "thread_id": "1",
192
- "user_id": context.user_id or "my_user_id",
193
- "store_num": context.store_num or 87887,
194
- }
195
- }
196
-
197
- # Format as JSON for copy-paste
198
- corrected_config_json = json.dumps(corrected_config, indent=2)
199
-
200
- error_message = f"""
201
- ## Authentication Required
202
-
203
- A **thread_id** is required to process your request. Please provide your thread ID in the configuration.
204
-
205
- ### Required Configuration Format
206
-
207
- Please include the following JSON in your request configuration:
208
-
209
- ```json
210
- {corrected_config_json}
211
- ```
212
-
213
- ### Field Descriptions
214
- - **thread_id**: Conversation thread identifier (required)
215
- - **user_id**: Your unique user identifier (required)
216
- - **store_num**: Your store number (required)
217
-
218
- Please update your configuration and try again.
219
- """.strip()
220
-
221
- raise ValueError(error_message)
222
-
223
- return {}
dao_ai/memory/__init__.py CHANGED
@@ -3,10 +3,20 @@ from dao_ai.memory.base import (
3
3
  StoreManagerBase,
4
4
  )
5
5
  from dao_ai.memory.core import CheckpointManager, StoreManager
6
+ from dao_ai.memory.databricks import (
7
+ AsyncDatabricksCheckpointSaver,
8
+ AsyncDatabricksStore,
9
+ DatabricksCheckpointerManager,
10
+ DatabricksStoreManager,
11
+ )
6
12
 
7
13
  __all__ = [
8
14
  "CheckpointManagerBase",
9
15
  "StoreManagerBase",
10
16
  "CheckpointManager",
11
17
  "StoreManager",
18
+ "AsyncDatabricksCheckpointSaver",
19
+ "AsyncDatabricksStore",
20
+ "DatabricksCheckpointerManager",
21
+ "DatabricksStoreManager",
12
22
  ]
dao_ai/memory/core.py CHANGED
@@ -1,8 +1,6 @@
1
1
  from typing import Any
2
2
 
3
- from databricks_langchain import (
4
- DatabricksEmbeddings,
5
- )
3
+ from databricks_langchain import DatabricksEmbeddings
6
4
  from langchain_core.embeddings.embeddings import Embeddings
7
5
  from langgraph.checkpoint.base import BaseCheckpointSaver
8
6
  from langgraph.checkpoint.memory import InMemorySaver
@@ -60,7 +58,7 @@ class StoreManager:
60
58
 
61
59
  @classmethod
62
60
  def instance(cls, store_model: StoreModel) -> StoreManagerBase:
63
- store_manager: StoreManagerBase = None
61
+ store_manager: StoreManagerBase | None = None
64
62
  match store_model.type:
65
63
  case StorageType.MEMORY:
66
64
  store_manager = cls.store_managers.get(store_model.name)
@@ -78,6 +76,13 @@ class StoreManager:
78
76
  cls.store_managers[store_model.database.instance_name] = (
79
77
  store_manager
80
78
  )
79
+ case StorageType.LAKEBASE:
80
+ from dao_ai.memory.databricks import DatabricksStoreManager
81
+
82
+ store_manager = cls.store_managers.get(store_model.name)
83
+ if store_manager is None:
84
+ store_manager = DatabricksStoreManager(store_model)
85
+ cls.store_managers[store_model.name] = store_manager
81
86
  case _:
82
87
  raise ValueError(f"Unknown store type: {store_model.type}")
83
88
 
@@ -89,7 +94,7 @@ class CheckpointManager:
89
94
 
90
95
  @classmethod
91
96
  def instance(cls, checkpointer_model: CheckpointerModel) -> CheckpointManagerBase:
92
- checkpointer_manager: CheckpointManagerBase = None
97
+ checkpointer_manager: CheckpointManagerBase | None = None
93
98
  match checkpointer_model.type:
94
99
  case StorageType.MEMORY:
95
100
  checkpointer_manager = cls.checkpoint_managers.get(
@@ -115,6 +120,19 @@ class CheckpointManager:
115
120
  cls.checkpoint_managers[
116
121
  checkpointer_model.database.instance_name
117
122
  ] = checkpointer_manager
123
+ case StorageType.LAKEBASE:
124
+ from dao_ai.memory.databricks import DatabricksCheckpointerManager
125
+
126
+ checkpointer_manager = cls.checkpoint_managers.get(
127
+ checkpointer_model.name
128
+ )
129
+ if checkpointer_manager is None:
130
+ checkpointer_manager = DatabricksCheckpointerManager(
131
+ checkpointer_model
132
+ )
133
+ cls.checkpoint_managers[checkpointer_model.name] = (
134
+ checkpointer_manager
135
+ )
118
136
  case _:
119
137
  raise ValueError(f"Unknown store type: {checkpointer_model.type}")
120
138