dao-ai 0.0.25__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. dao_ai/__init__.py +29 -0
  2. dao_ai/agent_as_code.py +5 -5
  3. dao_ai/cli.py +245 -40
  4. dao_ai/config.py +1863 -338
  5. dao_ai/genie/__init__.py +38 -0
  6. dao_ai/genie/cache/__init__.py +43 -0
  7. dao_ai/genie/cache/base.py +72 -0
  8. dao_ai/genie/cache/core.py +79 -0
  9. dao_ai/genie/cache/lru.py +347 -0
  10. dao_ai/genie/cache/semantic.py +970 -0
  11. dao_ai/genie/core.py +35 -0
  12. dao_ai/graph.py +27 -228
  13. dao_ai/hooks/__init__.py +9 -6
  14. dao_ai/hooks/core.py +27 -195
  15. dao_ai/logging.py +56 -0
  16. dao_ai/memory/__init__.py +10 -0
  17. dao_ai/memory/core.py +65 -30
  18. dao_ai/memory/databricks.py +402 -0
  19. dao_ai/memory/postgres.py +79 -38
  20. dao_ai/messages.py +6 -4
  21. dao_ai/middleware/__init__.py +125 -0
  22. dao_ai/middleware/assertions.py +806 -0
  23. dao_ai/middleware/base.py +50 -0
  24. dao_ai/middleware/core.py +67 -0
  25. dao_ai/middleware/guardrails.py +420 -0
  26. dao_ai/middleware/human_in_the_loop.py +232 -0
  27. dao_ai/middleware/message_validation.py +586 -0
  28. dao_ai/middleware/summarization.py +197 -0
  29. dao_ai/models.py +1306 -114
  30. dao_ai/nodes.py +261 -166
  31. dao_ai/optimization.py +674 -0
  32. dao_ai/orchestration/__init__.py +52 -0
  33. dao_ai/orchestration/core.py +294 -0
  34. dao_ai/orchestration/supervisor.py +278 -0
  35. dao_ai/orchestration/swarm.py +271 -0
  36. dao_ai/prompts.py +128 -31
  37. dao_ai/providers/databricks.py +645 -172
  38. dao_ai/state.py +157 -21
  39. dao_ai/tools/__init__.py +13 -5
  40. dao_ai/tools/agent.py +1 -3
  41. dao_ai/tools/core.py +64 -11
  42. dao_ai/tools/email.py +232 -0
  43. dao_ai/tools/genie.py +144 -295
  44. dao_ai/tools/mcp.py +220 -133
  45. dao_ai/tools/memory.py +50 -0
  46. dao_ai/tools/python.py +9 -14
  47. dao_ai/tools/search.py +14 -0
  48. dao_ai/tools/slack.py +22 -10
  49. dao_ai/tools/sql.py +202 -0
  50. dao_ai/tools/time.py +30 -7
  51. dao_ai/tools/unity_catalog.py +165 -88
  52. dao_ai/tools/vector_search.py +360 -40
  53. dao_ai/utils.py +218 -16
  54. dao_ai-0.1.2.dist-info/METADATA +455 -0
  55. dao_ai-0.1.2.dist-info/RECORD +64 -0
  56. {dao_ai-0.0.25.dist-info → dao_ai-0.1.2.dist-info}/WHEEL +1 -1
  57. dao_ai/chat_models.py +0 -204
  58. dao_ai/guardrails.py +0 -112
  59. dao_ai/tools/human_in_the_loop.py +0 -100
  60. dao_ai-0.0.25.dist-info/METADATA +0 -1165
  61. dao_ai-0.0.25.dist-info/RECORD +0 -41
  62. {dao_ai-0.0.25.dist-info → dao_ai-0.1.2.dist-info}/entry_points.txt +0 -0
  63. {dao_ai-0.0.25.dist-info → dao_ai-0.1.2.dist-info}/licenses/LICENSE +0 -0
dao_ai/genie/core.py ADDED
@@ -0,0 +1,35 @@
+"""
+Core Genie service implementation.
+
+This module provides the concrete implementation of GenieServiceBase
+that wraps the Databricks Genie SDK.
+"""
+
+import mlflow
+from databricks_ai_bridge.genie import Genie, GenieResponse
+
+from dao_ai.genie.cache import CacheResult, GenieServiceBase
+
+
+class GenieService(GenieServiceBase):
+    """Concrete implementation of GenieServiceBase using the Genie SDK."""
+
+    genie: Genie
+
+    def __init__(self, genie: Genie) -> None:
+        self.genie = genie
+
+    @mlflow.trace(name="genie_ask_question")
+    def ask_question(
+        self, question: str, conversation_id: str | None = None
+    ) -> CacheResult:
+        """Ask question to Genie and return CacheResult (no caching at this level)."""
+        response: GenieResponse = self.genie.ask_question(
+            question, conversation_id=conversation_id
+        )
+        # No caching at this level - return cache miss
+        return CacheResult(response=response, cache_hit=False, served_by=None)
+
+    @property
+    def space_id(self) -> str:
+        return self.genie.space_id
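The added GenieService is a thin adapter: it forwards a question to the Databricks Genie client and always reports a cache miss, leaving caching decisions to the new dao_ai/genie/cache layers (LRU and semantic). A minimal usage sketch under stated assumptions — the `Genie(space_id=...)` constructor argument and the shape of the returned GenieResponse are assumptions, not shown in this diff:

```python
# Hedged sketch: wrapping a Genie client with the new GenieService.
# Assumes Genie is constructed from a Genie space ID; the ID below is a placeholder.
from databricks_ai_bridge.genie import Genie

from dao_ai.genie.core import GenieService

service = GenieService(genie=Genie(space_id="your-genie-space-id"))

result = service.ask_question("How many orders shipped last week?")
print(result.cache_hit)   # always False at this layer; caches above decide what to store
print(result.response)    # the underlying GenieResponse from the Genie SDK
```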
dao_ai/graph.py CHANGED
@@ -1,238 +1,37 @@
-from typing import Sequence
+"""
+Graph creation utilities for DAO AI multi-agent orchestration.
 
-from langchain_core.language_models import LanguageModelLike
-from langchain_core.runnables import RunnableConfig
-from langchain_core.tools import BaseTool
-from langgraph.checkpoint.base import BaseCheckpointSaver
-from langgraph.graph import END, StateGraph
-from langgraph.graph.state import CompiledStateGraph
-from langgraph.store.base import BaseStore
-from langgraph_supervisor import create_handoff_tool as supervisor_handoff_tool
-from langgraph_supervisor import create_supervisor
-from langgraph_swarm import create_handoff_tool as swarm_handoff_tool
-from langgraph_swarm import create_swarm
-from langmem import create_manage_memory_tool, create_search_memory_tool
-from loguru import logger
-
-from dao_ai.config import (
-    AgentModel,
-    AppConfig,
-    OrchestrationModel,
-    SupervisorModel,
-    SwarmModel,
-)
-from dao_ai.nodes import (
-    create_agent_node,
-    message_hook_node,
-)
-from dao_ai.prompts import make_prompt
-from dao_ai.state import Context, IncomingState, OutgoingState, SharedState
-from dao_ai.tools import create_tools
-
-
-def route_message(state: SharedState) -> str:
-    if not state["is_valid"]:
-        return END
-    return "orchestration"
-
-
-def _handoffs_for_agent(agent: AgentModel, config: AppConfig) -> Sequence[BaseTool]:
-    handoff_tools: list[BaseTool] = []
-
-    handoffs: dict[str, Sequence[AgentModel | str]] = (
-        config.app.orchestration.swarm.handoffs or {}
-    )
-    agent_handoffs: Sequence[AgentModel | str] = handoffs.get(agent.name)
-    if agent_handoffs is None:
-        agent_handoffs = config.app.agents
-
-    for handoff_to_agent in agent_handoffs:
-        if isinstance(handoff_to_agent, str):
-            handoff_to_agent = next(
-                iter(config.find_agents(lambda a: a.name == handoff_to_agent)), None
-            )
-
-        if handoff_to_agent is None:
-            logger.warning(
-                f"Handoff agent {handoff_to_agent} not found in configuration for agent {agent.name}"
-            )
-            continue
-        if agent.name == handoff_to_agent.name:
-            continue
-        logger.debug(
-            f"Creating handoff tool from agent {agent.name} to {handoff_to_agent.name}"
-        )
-        handoff_tools.append(
-            swarm_handoff_tool(
-                agent_name=handoff_to_agent.name,
-                description=f"Ask {handoff_to_agent.name} for help with: "
-                + handoff_to_agent.handoff_prompt,
-            )
-        )
-    return handoff_tools
-
-
-def _create_supervisor_graph(config: AppConfig) -> CompiledStateGraph:
-    logger.debug("Creating supervisor graph")
-    agents: list[CompiledStateGraph] = []
-    tools: Sequence[BaseTool] = []
-    for registered_agent in config.app.agents:
-        agents.append(
-            create_agent_node(
-                app=config.app, agent=registered_agent, additional_tools=[]
-            )
-        )
-        tools.append(
-            supervisor_handoff_tool(
-                agent_name=registered_agent.name,
-                description=registered_agent.handoff_prompt,
-            )
-        )
-
-    orchestration: OrchestrationModel = config.app.orchestration
-    supervisor: SupervisorModel = orchestration.supervisor
-
-    tools += create_tools(orchestration.supervisor.tools)
-
-    store: BaseStore = None
-    if orchestration.memory and orchestration.memory.store:
-        store = orchestration.memory.store.as_store()
-        logger.debug(f"Using memory store: {store}")
-        namespace: tuple[str, ...] = ("memory",)
-
-        if orchestration.memory.store.namespace:
-            namespace = namespace + (orchestration.memory.store.namespace,)
-        logger.debug(f"Memory store namespace: {namespace}")
-        tools += [
-            create_manage_memory_tool(namespace=namespace),
-            create_search_memory_tool(namespace=namespace),
-        ]
-
-    checkpointer: BaseCheckpointSaver = None
-    if orchestration.memory and orchestration.memory.checkpointer:
-        checkpointer = orchestration.memory.checkpointer.as_checkpointer()
-        logger.debug(f"Using checkpointer: {checkpointer}")
-
-    prompt: str = supervisor.prompt
-
-    model: LanguageModelLike = supervisor.model.as_chat_model()
-    supervisor_workflow: StateGraph = create_supervisor(
-        supervisor_name="supervisor",
-        prompt=make_prompt(base_system_prompt=prompt),
-        agents=agents,
-        model=model,
-        tools=tools,
-        state_schema=SharedState,
-        config_schema=RunnableConfig,
-        output_mode="last_message",
-        add_handoff_messages=False,
-        add_handoff_back_messages=False,
-        context_schema=Context,
-        # output_mode="full",
-        # add_handoff_messages=True,
-        # add_handoff_back_messages=True,
-    )
+This module provides backwards-compatible imports from the refactored
+orchestration package. New code should import directly from:
+- dao_ai.orchestration
+- dao_ai.orchestration.supervisor
+- dao_ai.orchestration.swarm
 
-    supervisor_node: CompiledStateGraph = supervisor_workflow.compile(
-        checkpointer=checkpointer, store=store
-    )
+See: https://docs.langchain.com/oss/python/langchain/multi-agent
+"""
 
-    workflow: StateGraph = StateGraph(
-        SharedState,
-        input_schema=IncomingState,
-        output_schema=OutgoingState,
-        context_schema=Context,
-    )
-
-    workflow.add_node("message_hook", message_hook_node(config=config))
-
-    workflow.add_node("orchestration", supervisor_node)
-    workflow.add_conditional_edges(
-        "message_hook",
-        route_message,
-        {
-            "orchestration": "orchestration",
-            END: END,
-        },
-    )
-    workflow.set_entry_point("message_hook")
-
-    return workflow.compile(checkpointer=checkpointer, store=store)
-
-
-def _create_swarm_graph(config: AppConfig) -> CompiledStateGraph:
-    logger.debug("Creating swarm graph")
-    agents: list[CompiledStateGraph] = []
-    for registered_agent in config.app.agents:
-        handoff_tools: Sequence[BaseTool] = _handoffs_for_agent(
-            agent=registered_agent, config=config
-        )
-        agents.append(
-            create_agent_node(
-                app=config.app, agent=registered_agent, additional_tools=handoff_tools
-            )
-        )
-
-    orchestration: OrchestrationModel = config.app.orchestration
-    swarm: SwarmModel = orchestration.swarm
-
-    store: BaseStore = None
-    if orchestration.memory and orchestration.memory.store:
-        store = orchestration.memory.store.as_store()
-        logger.debug(f"Using memory store: {store}")
-
-    checkpointer: BaseCheckpointSaver = None
-    if orchestration.memory and orchestration.memory.checkpointer:
-        checkpointer = orchestration.memory.checkpointer.as_checkpointer()
-        logger.debug(f"Using checkpointer: {checkpointer}")
-
-    default_agent: AgentModel = swarm.default_agent
-    if isinstance(default_agent, AgentModel):
-        default_agent = default_agent.name
-
-    swarm_workflow: StateGraph = create_swarm(
-        agents=agents,
-        default_active_agent=default_agent,
-        state_schema=SharedState,
-        context_schema=Context,
-    )
-
-    swarm_node: CompiledStateGraph = swarm_workflow.compile(
-        checkpointer=checkpointer, store=store
-    )
-
-    workflow: StateGraph = StateGraph(
-        SharedState,
-        input_schema=IncomingState,
-        output_schema=OutgoingState,
-        context_schema=Context,
-    )
-
-    workflow.add_node("message_hook", message_hook_node(config=config))
-    workflow.add_node("orchestration", swarm_node)
-
-    workflow.add_conditional_edges(
-        "message_hook",
-        route_message,
-        {
-            "orchestration": "orchestration",
-            END: END,
-        },
-    )
+from langgraph.graph.state import CompiledStateGraph
 
-    workflow.set_entry_point("message_hook")
+from dao_ai.config import AppConfig
+from dao_ai.orchestration import create_orchestration_graph
 
-    return swarm_node
 
-    # return workflow.compile(checkpointer=checkpointer, store=store)
+def create_dao_ai_graph(config: AppConfig) -> CompiledStateGraph:
+    """
+    Create the main DAO AI graph based on the orchestration configuration.
 
+    This factory function creates either a supervisor or swarm graph
+    depending on the configuration.
 
-def create_dao_ai_graph(config: AppConfig) -> CompiledStateGraph:
-    orchestration: OrchestrationModel = config.app.orchestration
-    if orchestration.supervisor:
-        return _create_supervisor_graph(config)
+    Args:
+        config: The application configuration
 
-    if orchestration.swarm:
-        return _create_swarm_graph(config)
+    Returns:
+        A compiled LangGraph state machine
 
-    raise ValueError("No valid orchestration model found in the configuration.")
+    Note:
+        This function is provided for backwards compatibility.
+        New code should use create_orchestration_graph from
+        dao_ai.orchestration instead.
+    """
+    return create_orchestration_graph(config)
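graph.py is now a backwards-compatible shim over the new dao_ai.orchestration package. A minimal sketch of the two equivalent entry points, assuming an already-constructed AppConfig (how the config is loaded is not shown in this diff):

```python
# Hedged sketch: both entry points compile the same orchestration graph in 0.1.2.
from langgraph.graph.state import CompiledStateGraph

from dao_ai.config import AppConfig
from dao_ai.graph import create_dao_ai_graph                  # legacy shim, still works
from dao_ai.orchestration import create_orchestration_graph   # preferred going forward


def build_graph(config: AppConfig) -> CompiledStateGraph:
    # create_dao_ai_graph(config) would return the same thing; the shim delegates.
    return create_orchestration_graph(config)
```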
dao_ai/hooks/__init__.py CHANGED
@@ -1,11 +1,17 @@
+"""
+Hook utilities for DAO AI.
+
+For validation hooks, use middleware instead:
+- dao_ai.middleware.UserIdValidationMiddleware
+- dao_ai.middleware.ThreadIdValidationMiddleware
+- dao_ai.middleware.FilterLastHumanMessageMiddleware
+"""
+
 from dao_ai.hooks.core import (
     create_hooks,
-    filter_last_human_message_hook,
     null_hook,
     null_initialization_hook,
     null_shutdown_hook,
-    require_thread_id_hook,
-    require_user_id_hook,
 )
 
 __all__ = [
@@ -13,7 +19,4 @@ __all__ = [
     "null_hook",
     "null_initialization_hook",
     "null_shutdown_hook",
-    "require_thread_id_hook",
-    "require_user_id_hook",
-    "filter_last_human_message_hook",
 ]
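The validation hooks dropped here have middleware counterparts, per the new module docstring. A hedged migration sketch — the class names come from the docstring above; their constructor signatures are not shown in this diff, so none are assumed:

```python
# 0.0.25 style (removed):
# from dao_ai.hooks import require_user_id_hook, require_thread_id_hook, filter_last_human_message_hook

# 0.1.2 style, per the docstring above (constructor arguments not shown in this diff):
from dao_ai.middleware import (
    FilterLastHumanMessageMiddleware,
    ThreadIdValidationMiddleware,
    UserIdValidationMiddleware,
)
```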
dao_ai/hooks/core.py CHANGED
@@ -1,20 +1,32 @@
-import json
+"""
+Hook utilities for DAO AI.
+
+This module provides the create_hooks function for resolving FunctionHook
+references to callable functions. Individual validation hooks have been
+migrated to middleware - see dao_ai.middleware.message_validation.
+"""
+
 from typing import Any, Callable, Sequence
 
-from langchain_core.messages import BaseMessage, HumanMessage, RemoveMessage
-from langgraph.runtime import Runtime
 from loguru import logger
 
 from dao_ai.config import AppConfig, FunctionHook, PythonFunctionModel
-from dao_ai.messages import last_human_message
-from dao_ai.state import Context
 
 
 def create_hooks(
     function_hooks: FunctionHook | list[FunctionHook] | None,
 ) -> Sequence[Callable[..., Any]]:
-    logger.debug(f"Creating hooks from: {function_hooks}")
-    hooks: Sequence[Callable[..., Any]] = []
+    """
+    Resolve FunctionHook references to callable functions.
+
+    Args:
+        function_hooks: A single FunctionHook or list of FunctionHooks to resolve
+
+    Returns:
+        Sequence of callable functions
+    """
+    logger.trace("Creating hooks", function_hooks=function_hooks)
+    hooks: list[Callable[..., Any]] = []
     if not function_hooks:
         return []
     if not isinstance(function_hooks, (list, tuple, set)):
@@ -23,201 +35,21 @@ def create_hooks(
         if isinstance(function_hook, str):
             function_hook = PythonFunctionModel(name=function_hook)
         hooks.extend(function_hook.as_tools())
-    logger.debug(f"Created hooks: {hooks}")
+    logger.trace("Created hooks", hooks_count=len(hooks))
     return hooks
 
 
-def null_hook(state: dict[str, Any], runtime: Runtime[Context]) -> dict[str, Any]:
-    logger.debug("Executing null hook")
+def null_hook(state: dict[str, Any], config: Any) -> dict[str, Any]:
+    """A no-op hook that returns an empty dict."""
+    logger.trace("Executing null hook")
     return {}
 
 
 def null_initialization_hook(config: AppConfig) -> None:
-    logger.debug("Executing null initialization hook")
+    """A no-op initialization hook."""
+    logger.trace("Executing null initialization hook")
 
 
 def null_shutdown_hook(config: AppConfig) -> None:
-    logger.debug("Executing null shutdown hook")
-
-
-def filter_last_human_message_hook(
-    state: dict[str, Any], runtime: Runtime[Context]
-) -> dict[str, Any]:
-    """
-    Filter messages to keep only the last human message.
-
-    This hook removes all messages except for the most recent human message,
-    which can be useful for scenarios where you want to process only the
-    latest user input without conversation history.
-
-    Args:
-        state: The current state containing messages
-        runtime: The runtime context (unused in this hook)
-
-    Returns:
-        Updated state with filtered messages
-    """
-    logger.debug("Executing filter_last_human_message hook")
-
-    messages: list[BaseMessage] = state.get("messages", [])
-
-    if not messages:
-        logger.debug("No messages found in state")
-        return state
-
-    # Use the helper function to find the last human message
-    last_message: HumanMessage = last_human_message(messages)
-
-    if last_message is None:
-        logger.debug("No human messages found in state")
-        # Return empty messages if no human message found
-        updated_state = state.copy()
-        updated_state["messages"] = []
-        return updated_state
-
-    logger.debug(f"Filtered {len(messages)} messages down to 1 (last human message)")
-
-    removed_messages: Sequence[BaseMessage] = [
-        RemoveMessage(id=message.id)
-        for message in messages
-        if message.id != last_message.id
-    ]
-
-    updated_state: dict[str, Sequence[BaseMessage]] = {"messages": removed_messages}
-
-    return updated_state
-
-
-def require_user_id_hook(
-    state: dict[str, Any], runtime: Runtime[Context]
-) -> dict[str, Any]:
-    logger.debug("Executing user_id validation hook")
-
-    context: Context = runtime.context or Context()
-
-    user_id: str | None = context.user_id
-
-    if not user_id:
-        logger.error("User ID is required but not provided in the configuration.")
-
-        # Create corrected configuration using any provided context parameters
-        corrected_config = {
-            "configurable": {
-                "thread_id": context.thread_id or "1",
-                "user_id": "my_user_id",
-                "store_num": context.store_num or 87887,
-            }
-        }
-
-        # Format as JSON for copy-paste
-        corrected_config_json = json.dumps(corrected_config, indent=2)
-
-        error_message = f"""
-## Authentication Required
-
-A **user_id** is required to process your request. Please provide your user ID in the configuration.
-
-### Required Configuration Format
-
-Please include the following JSON in your request configuration:
-
-```json
-{corrected_config_json}
-```
-
-### Field Descriptions
-- **user_id**: Your unique user identifier (required)
-- **thread_id**: Conversation thread identifier (required)
-- **store_num**: Your store number (required)
-
-Please update your configuration and try again.
-""".strip()
-
-        raise ValueError(error_message)
-
-    if "." in user_id:
-        logger.error(f"User ID '{user_id}' contains invalid character '.'")
-
-        # Create a corrected version of the user_id
-        corrected_user_id = user_id.replace(".", "_")
-
-        # Create corrected configuration for the error message
-
-        # Corrected config with fixed user_id
-        corrected_config = {
-            "configurable": {
-                "thread_id": context.thread_id or "1",
-                "user_id": corrected_user_id,
-                "store_num": context.store_num or 87887,
-            }
-        }
-
-        # Format as JSON for copy-paste
-        corrected_config_json = json.dumps(corrected_config, indent=2)
-
-        error_message = f"""
-## Invalid User ID Format
-
-The **user_id** cannot contain a dot character ('.'). Please provide a valid user ID without dots.
-
-### Corrected Configuration (Copy & Paste This)
-```json
-{corrected_config_json}
-```
-
-Please update your user_id and try again.
-""".strip()
-
-        raise ValueError(error_message)
-
-    return {}
-
-
-def require_thread_id_hook(
-    state: dict[str, Any], runtime: Runtime[Context]
-) -> dict[str, Any]:
-    logger.debug("Executing thread_id validation hook")
-
-    context: Context = runtime.context or Context()
-
-    thread_id: str | None = context.thread_id
-
-    if not thread_id:
-        logger.error("Thread ID is required but not provided in the configuration.")
-
-        # Create corrected configuration using any provided context parameters
-        corrected_config = {
-            "configurable": {
-                "thread_id": "1",
-                "user_id": context.user_id or "my_user_id",
-                "store_num": context.store_num or 87887,
-            }
-        }
-
-        # Format as JSON for copy-paste
-        corrected_config_json = json.dumps(corrected_config, indent=2)
-
-        error_message = f"""
-## Authentication Required
-
-A **thread_id** is required to process your request. Please provide your thread ID in the configuration.
-
-### Required Configuration Format
-
-Please include the following JSON in your request configuration:
-
-```json
-{corrected_config_json}
-```
-
-### Field Descriptions
-- **thread_id**: Conversation thread identifier (required)
-- **user_id**: Your unique user identifier (required)
-- **store_num**: Your store number (required)
-
-Please update your configuration and try again.
-""".strip()
-
-        raise ValueError(error_message)
-
-    return {}
+    """A no-op shutdown hook."""
+    logger.trace("Executing null shutdown hook")
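create_hooks keeps its old job: it normalizes a single hook or a list, promotes bare strings to PythonFunctionModel, and resolves everything to callables via as_tools(). A hedged sketch; the dotted path below is hypothetical, for illustration only:

```python
# Hedged sketch of create_hooks resolution, per the diff above.
# "my_project.hooks.audit_request" is a hypothetical dotted path, not a real module.
from dao_ai.hooks import create_hooks

hooks = create_hooks("my_project.hooks.audit_request")  # also accepts a list of hooks
for hook in hooks:
    print(hook)  # resolved callables ready to be registered with the agent graph
```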
dao_ai/logging.py ADDED
@@ -0,0 +1,56 @@
+"""Logging configuration for DAO AI."""
+
+import sys
+from typing import Any
+
+from loguru import logger
+
+# Re-export logger for convenience
+__all__ = ["logger", "configure_logging"]
+
+
+def format_extra(record: dict[str, Any]) -> str:
+    """Format extra fields as key=value pairs."""
+    extra: dict[str, Any] = record["extra"]
+    if not extra:
+        return ""
+
+    formatted_pairs: list[str] = []
+    for key, value in extra.items():
+        # Handle different value types
+        if isinstance(value, str):
+            formatted_pairs.append(f"{key}={value}")
+        elif isinstance(value, (list, tuple)):
+            formatted_pairs.append(f"{key}={','.join(str(v) for v in value)}")
+        else:
+            formatted_pairs.append(f"{key}={value}")
+
+    return " | ".join(formatted_pairs)
+
+
+def configure_logging(level: str = "INFO") -> None:
+    """
+    Configure loguru logging with structured output.
+
+    Args:
+        level: The log level (e.g., "INFO", "DEBUG", "WARNING")
+    """
+    logger.remove()
+    logger.add(
+        sys.stderr,
+        level=level,
+        format=(
+            "<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | "
+            "<level>{level: <8}</level> | "
+            "<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - "
+            "<level>{message}</level>"
+            "{extra}"
+        ),
+    )
+
+    # Add custom formatter for extra fields
+    logger.configure(
+        patcher=lambda record: record.update(
+            extra=" | " + format_extra(record) if record["extra"] else ""
+        )
+    )
dao_ai/memory/__init__.py CHANGED
@@ -3,10 +3,20 @@ from dao_ai.memory.base import (
     StoreManagerBase,
 )
 from dao_ai.memory.core import CheckpointManager, StoreManager
+from dao_ai.memory.databricks import (
+    AsyncDatabricksCheckpointSaver,
+    AsyncDatabricksStore,
+    DatabricksCheckpointerManager,
+    DatabricksStoreManager,
+)
 
 __all__ = [
     "CheckpointManagerBase",
     "StoreManagerBase",
     "CheckpointManager",
     "StoreManager",
+    "AsyncDatabricksCheckpointSaver",
+    "AsyncDatabricksStore",
+    "DatabricksCheckpointerManager",
+    "DatabricksStoreManager",
 ]
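The memory package now re-exports the Databricks-backed checkpointer and store implementations alongside the existing managers. A hedged import sketch; constructor arguments for these classes live in the new dao_ai/memory/databricks.py and are not shown here:

```python
# Hedged sketch: the new Databricks-backed memory classes are importable from the
# package root in 0.1.2. Their constructor signatures are not shown in this diff.
from dao_ai.memory import (
    AsyncDatabricksCheckpointSaver,
    AsyncDatabricksStore,
    DatabricksCheckpointerManager,
    DatabricksStoreManager,
)
```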