dao-ai 0.0.25__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/__init__.py +29 -0
- dao_ai/agent_as_code.py +5 -5
- dao_ai/cli.py +245 -40
- dao_ai/config.py +1863 -338
- dao_ai/genie/__init__.py +38 -0
- dao_ai/genie/cache/__init__.py +43 -0
- dao_ai/genie/cache/base.py +72 -0
- dao_ai/genie/cache/core.py +79 -0
- dao_ai/genie/cache/lru.py +347 -0
- dao_ai/genie/cache/semantic.py +970 -0
- dao_ai/genie/core.py +35 -0
- dao_ai/graph.py +27 -228
- dao_ai/hooks/__init__.py +9 -6
- dao_ai/hooks/core.py +27 -195
- dao_ai/logging.py +56 -0
- dao_ai/memory/__init__.py +10 -0
- dao_ai/memory/core.py +65 -30
- dao_ai/memory/databricks.py +402 -0
- dao_ai/memory/postgres.py +79 -38
- dao_ai/messages.py +6 -4
- dao_ai/middleware/__init__.py +125 -0
- dao_ai/middleware/assertions.py +806 -0
- dao_ai/middleware/base.py +50 -0
- dao_ai/middleware/core.py +67 -0
- dao_ai/middleware/guardrails.py +420 -0
- dao_ai/middleware/human_in_the_loop.py +232 -0
- dao_ai/middleware/message_validation.py +586 -0
- dao_ai/middleware/summarization.py +197 -0
- dao_ai/models.py +1306 -114
- dao_ai/nodes.py +261 -166
- dao_ai/optimization.py +674 -0
- dao_ai/orchestration/__init__.py +52 -0
- dao_ai/orchestration/core.py +294 -0
- dao_ai/orchestration/supervisor.py +278 -0
- dao_ai/orchestration/swarm.py +271 -0
- dao_ai/prompts.py +128 -31
- dao_ai/providers/databricks.py +645 -172
- dao_ai/state.py +157 -21
- dao_ai/tools/__init__.py +13 -5
- dao_ai/tools/agent.py +1 -3
- dao_ai/tools/core.py +64 -11
- dao_ai/tools/email.py +232 -0
- dao_ai/tools/genie.py +144 -295
- dao_ai/tools/mcp.py +220 -133
- dao_ai/tools/memory.py +50 -0
- dao_ai/tools/python.py +9 -14
- dao_ai/tools/search.py +14 -0
- dao_ai/tools/slack.py +22 -10
- dao_ai/tools/sql.py +202 -0
- dao_ai/tools/time.py +30 -7
- dao_ai/tools/unity_catalog.py +165 -88
- dao_ai/tools/vector_search.py +360 -40
- dao_ai/utils.py +218 -16
- dao_ai-0.1.2.dist-info/METADATA +455 -0
- dao_ai-0.1.2.dist-info/RECORD +64 -0
- {dao_ai-0.0.25.dist-info → dao_ai-0.1.2.dist-info}/WHEEL +1 -1
- dao_ai/chat_models.py +0 -204
- dao_ai/guardrails.py +0 -112
- dao_ai/tools/human_in_the_loop.py +0 -100
- dao_ai-0.0.25.dist-info/METADATA +0 -1165
- dao_ai-0.0.25.dist-info/RECORD +0 -41
- {dao_ai-0.0.25.dist-info → dao_ai-0.1.2.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.0.25.dist-info → dao_ai-0.1.2.dist-info}/licenses/LICENSE +0 -0
dao_ai/genie/core.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Core Genie service implementation.
|
|
3
|
+
|
|
4
|
+
This module provides the concrete implementation of GenieServiceBase
|
|
5
|
+
that wraps the Databricks Genie SDK.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import mlflow
|
|
9
|
+
from databricks_ai_bridge.genie import Genie, GenieResponse
|
|
10
|
+
|
|
11
|
+
from dao_ai.genie.cache import CacheResult, GenieServiceBase
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class GenieService(GenieServiceBase):
|
|
15
|
+
"""Concrete implementation of GenieServiceBase using the Genie SDK."""
|
|
16
|
+
|
|
17
|
+
genie: Genie
|
|
18
|
+
|
|
19
|
+
def __init__(self, genie: Genie) -> None:
|
|
20
|
+
self.genie = genie
|
|
21
|
+
|
|
22
|
+
@mlflow.trace(name="genie_ask_question")
|
|
23
|
+
def ask_question(
|
|
24
|
+
self, question: str, conversation_id: str | None = None
|
|
25
|
+
) -> CacheResult:
|
|
26
|
+
"""Ask question to Genie and return CacheResult (no caching at this level)."""
|
|
27
|
+
response: GenieResponse = self.genie.ask_question(
|
|
28
|
+
question, conversation_id=conversation_id
|
|
29
|
+
)
|
|
30
|
+
# No caching at this level - return cache miss
|
|
31
|
+
return CacheResult(response=response, cache_hit=False, served_by=None)
|
|
32
|
+
|
|
33
|
+
@property
|
|
34
|
+
def space_id(self) -> str:
|
|
35
|
+
return self.genie.space_id
|
dao_ai/graph.py
CHANGED
|
@@ -1,238 +1,37 @@
|
|
|
1
|
-
|
|
1
|
+
"""
|
|
2
|
+
Graph creation utilities for DAO AI multi-agent orchestration.
|
|
2
3
|
|
|
3
|
-
from
|
|
4
|
-
|
|
5
|
-
|
|
6
|
-
|
|
7
|
-
|
|
8
|
-
from langgraph.graph.state import CompiledStateGraph
|
|
9
|
-
from langgraph.store.base import BaseStore
|
|
10
|
-
from langgraph_supervisor import create_handoff_tool as supervisor_handoff_tool
|
|
11
|
-
from langgraph_supervisor import create_supervisor
|
|
12
|
-
from langgraph_swarm import create_handoff_tool as swarm_handoff_tool
|
|
13
|
-
from langgraph_swarm import create_swarm
|
|
14
|
-
from langmem import create_manage_memory_tool, create_search_memory_tool
|
|
15
|
-
from loguru import logger
|
|
16
|
-
|
|
17
|
-
from dao_ai.config import (
|
|
18
|
-
AgentModel,
|
|
19
|
-
AppConfig,
|
|
20
|
-
OrchestrationModel,
|
|
21
|
-
SupervisorModel,
|
|
22
|
-
SwarmModel,
|
|
23
|
-
)
|
|
24
|
-
from dao_ai.nodes import (
|
|
25
|
-
create_agent_node,
|
|
26
|
-
message_hook_node,
|
|
27
|
-
)
|
|
28
|
-
from dao_ai.prompts import make_prompt
|
|
29
|
-
from dao_ai.state import Context, IncomingState, OutgoingState, SharedState
|
|
30
|
-
from dao_ai.tools import create_tools
|
|
31
|
-
|
|
32
|
-
|
|
33
|
-
def route_message(state: SharedState) -> str:
|
|
34
|
-
if not state["is_valid"]:
|
|
35
|
-
return END
|
|
36
|
-
return "orchestration"
|
|
37
|
-
|
|
38
|
-
|
|
39
|
-
def _handoffs_for_agent(agent: AgentModel, config: AppConfig) -> Sequence[BaseTool]:
|
|
40
|
-
handoff_tools: list[BaseTool] = []
|
|
41
|
-
|
|
42
|
-
handoffs: dict[str, Sequence[AgentModel | str]] = (
|
|
43
|
-
config.app.orchestration.swarm.handoffs or {}
|
|
44
|
-
)
|
|
45
|
-
agent_handoffs: Sequence[AgentModel | str] = handoffs.get(agent.name)
|
|
46
|
-
if agent_handoffs is None:
|
|
47
|
-
agent_handoffs = config.app.agents
|
|
48
|
-
|
|
49
|
-
for handoff_to_agent in agent_handoffs:
|
|
50
|
-
if isinstance(handoff_to_agent, str):
|
|
51
|
-
handoff_to_agent = next(
|
|
52
|
-
iter(config.find_agents(lambda a: a.name == handoff_to_agent)), None
|
|
53
|
-
)
|
|
54
|
-
|
|
55
|
-
if handoff_to_agent is None:
|
|
56
|
-
logger.warning(
|
|
57
|
-
f"Handoff agent {handoff_to_agent} not found in configuration for agent {agent.name}"
|
|
58
|
-
)
|
|
59
|
-
continue
|
|
60
|
-
if agent.name == handoff_to_agent.name:
|
|
61
|
-
continue
|
|
62
|
-
logger.debug(
|
|
63
|
-
f"Creating handoff tool from agent {agent.name} to {handoff_to_agent.name}"
|
|
64
|
-
)
|
|
65
|
-
handoff_tools.append(
|
|
66
|
-
swarm_handoff_tool(
|
|
67
|
-
agent_name=handoff_to_agent.name,
|
|
68
|
-
description=f"Ask {handoff_to_agent.name} for help with: "
|
|
69
|
-
+ handoff_to_agent.handoff_prompt,
|
|
70
|
-
)
|
|
71
|
-
)
|
|
72
|
-
return handoff_tools
|
|
73
|
-
|
|
74
|
-
|
|
75
|
-
def _create_supervisor_graph(config: AppConfig) -> CompiledStateGraph:
|
|
76
|
-
logger.debug("Creating supervisor graph")
|
|
77
|
-
agents: list[CompiledStateGraph] = []
|
|
78
|
-
tools: Sequence[BaseTool] = []
|
|
79
|
-
for registered_agent in config.app.agents:
|
|
80
|
-
agents.append(
|
|
81
|
-
create_agent_node(
|
|
82
|
-
app=config.app, agent=registered_agent, additional_tools=[]
|
|
83
|
-
)
|
|
84
|
-
)
|
|
85
|
-
tools.append(
|
|
86
|
-
supervisor_handoff_tool(
|
|
87
|
-
agent_name=registered_agent.name,
|
|
88
|
-
description=registered_agent.handoff_prompt,
|
|
89
|
-
)
|
|
90
|
-
)
|
|
91
|
-
|
|
92
|
-
orchestration: OrchestrationModel = config.app.orchestration
|
|
93
|
-
supervisor: SupervisorModel = orchestration.supervisor
|
|
94
|
-
|
|
95
|
-
tools += create_tools(orchestration.supervisor.tools)
|
|
96
|
-
|
|
97
|
-
store: BaseStore = None
|
|
98
|
-
if orchestration.memory and orchestration.memory.store:
|
|
99
|
-
store = orchestration.memory.store.as_store()
|
|
100
|
-
logger.debug(f"Using memory store: {store}")
|
|
101
|
-
namespace: tuple[str, ...] = ("memory",)
|
|
102
|
-
|
|
103
|
-
if orchestration.memory.store.namespace:
|
|
104
|
-
namespace = namespace + (orchestration.memory.store.namespace,)
|
|
105
|
-
logger.debug(f"Memory store namespace: {namespace}")
|
|
106
|
-
tools += [
|
|
107
|
-
create_manage_memory_tool(namespace=namespace),
|
|
108
|
-
create_search_memory_tool(namespace=namespace),
|
|
109
|
-
]
|
|
110
|
-
|
|
111
|
-
checkpointer: BaseCheckpointSaver = None
|
|
112
|
-
if orchestration.memory and orchestration.memory.checkpointer:
|
|
113
|
-
checkpointer = orchestration.memory.checkpointer.as_checkpointer()
|
|
114
|
-
logger.debug(f"Using checkpointer: {checkpointer}")
|
|
115
|
-
|
|
116
|
-
prompt: str = supervisor.prompt
|
|
117
|
-
|
|
118
|
-
model: LanguageModelLike = supervisor.model.as_chat_model()
|
|
119
|
-
supervisor_workflow: StateGraph = create_supervisor(
|
|
120
|
-
supervisor_name="supervisor",
|
|
121
|
-
prompt=make_prompt(base_system_prompt=prompt),
|
|
122
|
-
agents=agents,
|
|
123
|
-
model=model,
|
|
124
|
-
tools=tools,
|
|
125
|
-
state_schema=SharedState,
|
|
126
|
-
config_schema=RunnableConfig,
|
|
127
|
-
output_mode="last_message",
|
|
128
|
-
add_handoff_messages=False,
|
|
129
|
-
add_handoff_back_messages=False,
|
|
130
|
-
context_schema=Context,
|
|
131
|
-
# output_mode="full",
|
|
132
|
-
# add_handoff_messages=True,
|
|
133
|
-
# add_handoff_back_messages=True,
|
|
134
|
-
)
|
|
4
|
+
This module provides backwards-compatible imports from the refactored
|
|
5
|
+
orchestration package. New code should import directly from:
|
|
6
|
+
- dao_ai.orchestration
|
|
7
|
+
- dao_ai.orchestration.supervisor
|
|
8
|
+
- dao_ai.orchestration.swarm
|
|
135
9
|
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
)
|
|
10
|
+
See: https://docs.langchain.com/oss/python/langchain/multi-agent
|
|
11
|
+
"""
|
|
139
12
|
|
|
140
|
-
|
|
141
|
-
SharedState,
|
|
142
|
-
input_schema=IncomingState,
|
|
143
|
-
output_schema=OutgoingState,
|
|
144
|
-
context_schema=Context,
|
|
145
|
-
)
|
|
146
|
-
|
|
147
|
-
workflow.add_node("message_hook", message_hook_node(config=config))
|
|
148
|
-
|
|
149
|
-
workflow.add_node("orchestration", supervisor_node)
|
|
150
|
-
workflow.add_conditional_edges(
|
|
151
|
-
"message_hook",
|
|
152
|
-
route_message,
|
|
153
|
-
{
|
|
154
|
-
"orchestration": "orchestration",
|
|
155
|
-
END: END,
|
|
156
|
-
},
|
|
157
|
-
)
|
|
158
|
-
workflow.set_entry_point("message_hook")
|
|
159
|
-
|
|
160
|
-
return workflow.compile(checkpointer=checkpointer, store=store)
|
|
161
|
-
|
|
162
|
-
|
|
163
|
-
def _create_swarm_graph(config: AppConfig) -> CompiledStateGraph:
|
|
164
|
-
logger.debug("Creating swarm graph")
|
|
165
|
-
agents: list[CompiledStateGraph] = []
|
|
166
|
-
for registered_agent in config.app.agents:
|
|
167
|
-
handoff_tools: Sequence[BaseTool] = _handoffs_for_agent(
|
|
168
|
-
agent=registered_agent, config=config
|
|
169
|
-
)
|
|
170
|
-
agents.append(
|
|
171
|
-
create_agent_node(
|
|
172
|
-
app=config.app, agent=registered_agent, additional_tools=handoff_tools
|
|
173
|
-
)
|
|
174
|
-
)
|
|
175
|
-
|
|
176
|
-
orchestration: OrchestrationModel = config.app.orchestration
|
|
177
|
-
swarm: SwarmModel = orchestration.swarm
|
|
178
|
-
|
|
179
|
-
store: BaseStore = None
|
|
180
|
-
if orchestration.memory and orchestration.memory.store:
|
|
181
|
-
store = orchestration.memory.store.as_store()
|
|
182
|
-
logger.debug(f"Using memory store: {store}")
|
|
183
|
-
|
|
184
|
-
checkpointer: BaseCheckpointSaver = None
|
|
185
|
-
if orchestration.memory and orchestration.memory.checkpointer:
|
|
186
|
-
checkpointer = orchestration.memory.checkpointer.as_checkpointer()
|
|
187
|
-
logger.debug(f"Using checkpointer: {checkpointer}")
|
|
188
|
-
|
|
189
|
-
default_agent: AgentModel = swarm.default_agent
|
|
190
|
-
if isinstance(default_agent, AgentModel):
|
|
191
|
-
default_agent = default_agent.name
|
|
192
|
-
|
|
193
|
-
swarm_workflow: StateGraph = create_swarm(
|
|
194
|
-
agents=agents,
|
|
195
|
-
default_active_agent=default_agent,
|
|
196
|
-
state_schema=SharedState,
|
|
197
|
-
context_schema=Context,
|
|
198
|
-
)
|
|
199
|
-
|
|
200
|
-
swarm_node: CompiledStateGraph = swarm_workflow.compile(
|
|
201
|
-
checkpointer=checkpointer, store=store
|
|
202
|
-
)
|
|
203
|
-
|
|
204
|
-
workflow: StateGraph = StateGraph(
|
|
205
|
-
SharedState,
|
|
206
|
-
input_schema=IncomingState,
|
|
207
|
-
output_schema=OutgoingState,
|
|
208
|
-
context_schema=Context,
|
|
209
|
-
)
|
|
210
|
-
|
|
211
|
-
workflow.add_node("message_hook", message_hook_node(config=config))
|
|
212
|
-
workflow.add_node("orchestration", swarm_node)
|
|
213
|
-
|
|
214
|
-
workflow.add_conditional_edges(
|
|
215
|
-
"message_hook",
|
|
216
|
-
route_message,
|
|
217
|
-
{
|
|
218
|
-
"orchestration": "orchestration",
|
|
219
|
-
END: END,
|
|
220
|
-
},
|
|
221
|
-
)
|
|
13
|
+
from langgraph.graph.state import CompiledStateGraph
|
|
222
14
|
|
|
223
|
-
|
|
15
|
+
from dao_ai.config import AppConfig
|
|
16
|
+
from dao_ai.orchestration import create_orchestration_graph
|
|
224
17
|
|
|
225
|
-
return swarm_node
|
|
226
18
|
|
|
227
|
-
|
|
19
|
+
def create_dao_ai_graph(config: AppConfig) -> CompiledStateGraph:
|
|
20
|
+
"""
|
|
21
|
+
Create the main DAO AI graph based on the orchestration configuration.
|
|
228
22
|
|
|
23
|
+
This factory function creates either a supervisor or swarm graph
|
|
24
|
+
depending on the configuration.
|
|
229
25
|
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
if orchestration.supervisor:
|
|
233
|
-
return _create_supervisor_graph(config)
|
|
26
|
+
Args:
|
|
27
|
+
config: The application configuration
|
|
234
28
|
|
|
235
|
-
|
|
236
|
-
|
|
29
|
+
Returns:
|
|
30
|
+
A compiled LangGraph state machine
|
|
237
31
|
|
|
238
|
-
|
|
32
|
+
Note:
|
|
33
|
+
This function is provided for backwards compatibility.
|
|
34
|
+
New code should use `create_orchestration_graph` from
|
|
35
|
+
`dao_ai.orchestration` instead.
|
|
36
|
+
"""
|
|
37
|
+
return create_orchestration_graph(config)
|
dao_ai/hooks/__init__.py
CHANGED
|
@@ -1,11 +1,17 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Hook utilities for DAO AI.
|
|
3
|
+
|
|
4
|
+
For validation hooks, use middleware instead:
|
|
5
|
+
- dao_ai.middleware.UserIdValidationMiddleware
|
|
6
|
+
- dao_ai.middleware.ThreadIdValidationMiddleware
|
|
7
|
+
- dao_ai.middleware.FilterLastHumanMessageMiddleware
|
|
8
|
+
"""
|
|
9
|
+
|
|
1
10
|
from dao_ai.hooks.core import (
|
|
2
11
|
create_hooks,
|
|
3
|
-
filter_last_human_message_hook,
|
|
4
12
|
null_hook,
|
|
5
13
|
null_initialization_hook,
|
|
6
14
|
null_shutdown_hook,
|
|
7
|
-
require_thread_id_hook,
|
|
8
|
-
require_user_id_hook,
|
|
9
15
|
)
|
|
10
16
|
|
|
11
17
|
__all__ = [
|
|
@@ -13,7 +19,4 @@ __all__ = [
|
|
|
13
19
|
"null_hook",
|
|
14
20
|
"null_initialization_hook",
|
|
15
21
|
"null_shutdown_hook",
|
|
16
|
-
"require_thread_id_hook",
|
|
17
|
-
"require_user_id_hook",
|
|
18
|
-
"filter_last_human_message_hook",
|
|
19
22
|
]
|
dao_ai/hooks/core.py
CHANGED
|
@@ -1,20 +1,32 @@
|
|
|
1
|
-
|
|
1
|
+
"""
|
|
2
|
+
Hook utilities for DAO AI.
|
|
3
|
+
|
|
4
|
+
This module provides the create_hooks function for resolving FunctionHook
|
|
5
|
+
references to callable functions. Individual validation hooks have been
|
|
6
|
+
migrated to middleware - see dao_ai.middleware.message_validation.
|
|
7
|
+
"""
|
|
8
|
+
|
|
2
9
|
from typing import Any, Callable, Sequence
|
|
3
10
|
|
|
4
|
-
from langchain_core.messages import BaseMessage, HumanMessage, RemoveMessage
|
|
5
|
-
from langgraph.runtime import Runtime
|
|
6
11
|
from loguru import logger
|
|
7
12
|
|
|
8
13
|
from dao_ai.config import AppConfig, FunctionHook, PythonFunctionModel
|
|
9
|
-
from dao_ai.messages import last_human_message
|
|
10
|
-
from dao_ai.state import Context
|
|
11
14
|
|
|
12
15
|
|
|
13
16
|
def create_hooks(
|
|
14
17
|
function_hooks: FunctionHook | list[FunctionHook] | None,
|
|
15
18
|
) -> Sequence[Callable[..., Any]]:
|
|
16
|
-
|
|
17
|
-
|
|
19
|
+
"""
|
|
20
|
+
Resolve FunctionHook references to callable functions.
|
|
21
|
+
|
|
22
|
+
Args:
|
|
23
|
+
function_hooks: A single FunctionHook or list of FunctionHooks to resolve
|
|
24
|
+
|
|
25
|
+
Returns:
|
|
26
|
+
Sequence of callable functions
|
|
27
|
+
"""
|
|
28
|
+
logger.trace("Creating hooks", function_hooks=function_hooks)
|
|
29
|
+
hooks: list[Callable[..., Any]] = []
|
|
18
30
|
if not function_hooks:
|
|
19
31
|
return []
|
|
20
32
|
if not isinstance(function_hooks, (list, tuple, set)):
|
|
@@ -23,201 +35,21 @@ def create_hooks(
|
|
|
23
35
|
if isinstance(function_hook, str):
|
|
24
36
|
function_hook = PythonFunctionModel(name=function_hook)
|
|
25
37
|
hooks.extend(function_hook.as_tools())
|
|
26
|
-
logger.
|
|
38
|
+
logger.trace("Created hooks", hooks_count=len(hooks))
|
|
27
39
|
return hooks
|
|
28
40
|
|
|
29
41
|
|
|
30
|
-
def null_hook(state: dict[str, Any],
|
|
31
|
-
|
|
42
|
+
def null_hook(state: dict[str, Any], config: Any) -> dict[str, Any]:
|
|
43
|
+
"""A no-op hook that returns an empty dict."""
|
|
44
|
+
logger.trace("Executing null hook")
|
|
32
45
|
return {}
|
|
33
46
|
|
|
34
47
|
|
|
35
48
|
def null_initialization_hook(config: AppConfig) -> None:
|
|
36
|
-
|
|
49
|
+
"""A no-op initialization hook."""
|
|
50
|
+
logger.trace("Executing null initialization hook")
|
|
37
51
|
|
|
38
52
|
|
|
39
53
|
def null_shutdown_hook(config: AppConfig) -> None:
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
def filter_last_human_message_hook(
|
|
44
|
-
state: dict[str, Any], runtime: Runtime[Context]
|
|
45
|
-
) -> dict[str, Any]:
|
|
46
|
-
"""
|
|
47
|
-
Filter messages to keep only the last human message.
|
|
48
|
-
|
|
49
|
-
This hook removes all messages except for the most recent human message,
|
|
50
|
-
which can be useful for scenarios where you want to process only the
|
|
51
|
-
latest user input without conversation history.
|
|
52
|
-
|
|
53
|
-
Args:
|
|
54
|
-
state: The current state containing messages
|
|
55
|
-
runtime: The runtime context (unused in this hook)
|
|
56
|
-
|
|
57
|
-
Returns:
|
|
58
|
-
Updated state with filtered messages
|
|
59
|
-
"""
|
|
60
|
-
logger.debug("Executing filter_last_human_message hook")
|
|
61
|
-
|
|
62
|
-
messages: list[BaseMessage] = state.get("messages", [])
|
|
63
|
-
|
|
64
|
-
if not messages:
|
|
65
|
-
logger.debug("No messages found in state")
|
|
66
|
-
return state
|
|
67
|
-
|
|
68
|
-
# Use the helper function to find the last human message
|
|
69
|
-
last_message: HumanMessage = last_human_message(messages)
|
|
70
|
-
|
|
71
|
-
if last_message is None:
|
|
72
|
-
logger.debug("No human messages found in state")
|
|
73
|
-
# Return empty messages if no human message found
|
|
74
|
-
updated_state = state.copy()
|
|
75
|
-
updated_state["messages"] = []
|
|
76
|
-
return updated_state
|
|
77
|
-
|
|
78
|
-
logger.debug(f"Filtered {len(messages)} messages down to 1 (last human message)")
|
|
79
|
-
|
|
80
|
-
removed_messages: Sequence[BaseMessage] = [
|
|
81
|
-
RemoveMessage(id=message.id)
|
|
82
|
-
for message in messages
|
|
83
|
-
if message.id != last_message.id
|
|
84
|
-
]
|
|
85
|
-
|
|
86
|
-
updated_state: dict[str, Sequence[BaseMessage]] = {"messages": removed_messages}
|
|
87
|
-
|
|
88
|
-
return updated_state
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
def require_user_id_hook(
|
|
92
|
-
state: dict[str, Any], runtime: Runtime[Context]
|
|
93
|
-
) -> dict[str, Any]:
|
|
94
|
-
logger.debug("Executing user_id validation hook")
|
|
95
|
-
|
|
96
|
-
context: Context = runtime.context or Context()
|
|
97
|
-
|
|
98
|
-
user_id: str | None = context.user_id
|
|
99
|
-
|
|
100
|
-
if not user_id:
|
|
101
|
-
logger.error("User ID is required but not provided in the configuration.")
|
|
102
|
-
|
|
103
|
-
# Create corrected configuration using any provided context parameters
|
|
104
|
-
corrected_config = {
|
|
105
|
-
"configurable": {
|
|
106
|
-
"thread_id": context.thread_id or "1",
|
|
107
|
-
"user_id": "my_user_id",
|
|
108
|
-
"store_num": context.store_num or 87887,
|
|
109
|
-
}
|
|
110
|
-
}
|
|
111
|
-
|
|
112
|
-
# Format as JSON for copy-paste
|
|
113
|
-
corrected_config_json = json.dumps(corrected_config, indent=2)
|
|
114
|
-
|
|
115
|
-
error_message = f"""
|
|
116
|
-
## Authentication Required
|
|
117
|
-
|
|
118
|
-
A **user_id** is required to process your request. Please provide your user ID in the configuration.
|
|
119
|
-
|
|
120
|
-
### Required Configuration Format
|
|
121
|
-
|
|
122
|
-
Please include the following JSON in your request configuration:
|
|
123
|
-
|
|
124
|
-
```json
|
|
125
|
-
{corrected_config_json}
|
|
126
|
-
```
|
|
127
|
-
|
|
128
|
-
### Field Descriptions
|
|
129
|
-
- **user_id**: Your unique user identifier (required)
|
|
130
|
-
- **thread_id**: Conversation thread identifier (required)
|
|
131
|
-
- **store_num**: Your store number (required)
|
|
132
|
-
|
|
133
|
-
Please update your configuration and try again.
|
|
134
|
-
""".strip()
|
|
135
|
-
|
|
136
|
-
raise ValueError(error_message)
|
|
137
|
-
|
|
138
|
-
if "." in user_id:
|
|
139
|
-
logger.error(f"User ID '{user_id}' contains invalid character '.'")
|
|
140
|
-
|
|
141
|
-
# Create a corrected version of the user_id
|
|
142
|
-
corrected_user_id = user_id.replace(".", "_")
|
|
143
|
-
|
|
144
|
-
# Create corrected configuration for the error message
|
|
145
|
-
|
|
146
|
-
# Corrected config with fixed user_id
|
|
147
|
-
corrected_config = {
|
|
148
|
-
"configurable": {
|
|
149
|
-
"thread_id": context.thread_id or "1",
|
|
150
|
-
"user_id": corrected_user_id,
|
|
151
|
-
"store_num": context.store_num or 87887,
|
|
152
|
-
}
|
|
153
|
-
}
|
|
154
|
-
|
|
155
|
-
# Format as JSON for copy-paste
|
|
156
|
-
corrected_config_json = json.dumps(corrected_config, indent=2)
|
|
157
|
-
|
|
158
|
-
error_message = f"""
|
|
159
|
-
## Invalid User ID Format
|
|
160
|
-
|
|
161
|
-
The **user_id** cannot contain a dot character ('.'). Please provide a valid user ID without dots.
|
|
162
|
-
|
|
163
|
-
### Corrected Configuration (Copy & Paste This)
|
|
164
|
-
```json
|
|
165
|
-
{corrected_config_json}
|
|
166
|
-
```
|
|
167
|
-
|
|
168
|
-
Please update your user_id and try again.
|
|
169
|
-
""".strip()
|
|
170
|
-
|
|
171
|
-
raise ValueError(error_message)
|
|
172
|
-
|
|
173
|
-
return {}
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
def require_thread_id_hook(
|
|
177
|
-
state: dict[str, Any], runtime: Runtime[Context]
|
|
178
|
-
) -> dict[str, Any]:
|
|
179
|
-
logger.debug("Executing thread_id validation hook")
|
|
180
|
-
|
|
181
|
-
context: Context = runtime.context or Context()
|
|
182
|
-
|
|
183
|
-
thread_id: str | None = context.thread_id
|
|
184
|
-
|
|
185
|
-
if not thread_id:
|
|
186
|
-
logger.error("Thread ID is required but not provided in the configuration.")
|
|
187
|
-
|
|
188
|
-
# Create corrected configuration using any provided context parameters
|
|
189
|
-
corrected_config = {
|
|
190
|
-
"configurable": {
|
|
191
|
-
"thread_id": "1",
|
|
192
|
-
"user_id": context.user_id or "my_user_id",
|
|
193
|
-
"store_num": context.store_num or 87887,
|
|
194
|
-
}
|
|
195
|
-
}
|
|
196
|
-
|
|
197
|
-
# Format as JSON for copy-paste
|
|
198
|
-
corrected_config_json = json.dumps(corrected_config, indent=2)
|
|
199
|
-
|
|
200
|
-
error_message = f"""
|
|
201
|
-
## Authentication Required
|
|
202
|
-
|
|
203
|
-
A **thread_id** is required to process your request. Please provide your thread ID in the configuration.
|
|
204
|
-
|
|
205
|
-
### Required Configuration Format
|
|
206
|
-
|
|
207
|
-
Please include the following JSON in your request configuration:
|
|
208
|
-
|
|
209
|
-
```json
|
|
210
|
-
{corrected_config_json}
|
|
211
|
-
```
|
|
212
|
-
|
|
213
|
-
### Field Descriptions
|
|
214
|
-
- **thread_id**: Conversation thread identifier (required)
|
|
215
|
-
- **user_id**: Your unique user identifier (required)
|
|
216
|
-
- **store_num**: Your store number (required)
|
|
217
|
-
|
|
218
|
-
Please update your configuration and try again.
|
|
219
|
-
""".strip()
|
|
220
|
-
|
|
221
|
-
raise ValueError(error_message)
|
|
222
|
-
|
|
223
|
-
return {}
|
|
54
|
+
"""A no-op shutdown hook."""
|
|
55
|
+
logger.trace("Executing null shutdown hook")
|
dao_ai/logging.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
"""Logging configuration for DAO AI."""
|
|
2
|
+
|
|
3
|
+
import sys
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from loguru import logger
|
|
7
|
+
|
|
8
|
+
# Re-export logger for convenience
|
|
9
|
+
__all__ = ["logger", "configure_logging"]
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def format_extra(record: dict[str, Any]) -> str:
|
|
13
|
+
"""Format extra fields as key=value pairs."""
|
|
14
|
+
extra: dict[str, Any] = record["extra"]
|
|
15
|
+
if not extra:
|
|
16
|
+
return ""
|
|
17
|
+
|
|
18
|
+
formatted_pairs: list[str] = []
|
|
19
|
+
for key, value in extra.items():
|
|
20
|
+
# Handle different value types
|
|
21
|
+
if isinstance(value, str):
|
|
22
|
+
formatted_pairs.append(f"{key}={value}")
|
|
23
|
+
elif isinstance(value, (list, tuple)):
|
|
24
|
+
formatted_pairs.append(f"{key}={','.join(str(v) for v in value)}")
|
|
25
|
+
else:
|
|
26
|
+
formatted_pairs.append(f"{key}={value}")
|
|
27
|
+
|
|
28
|
+
return " | ".join(formatted_pairs)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def configure_logging(level: str = "INFO") -> None:
|
|
32
|
+
"""
|
|
33
|
+
Configure loguru logging with structured output.
|
|
34
|
+
|
|
35
|
+
Args:
|
|
36
|
+
level: The log level (e.g., "INFO", "DEBUG", "WARNING")
|
|
37
|
+
"""
|
|
38
|
+
logger.remove()
|
|
39
|
+
logger.add(
|
|
40
|
+
sys.stderr,
|
|
41
|
+
level=level,
|
|
42
|
+
format=(
|
|
43
|
+
"<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | "
|
|
44
|
+
"<level>{level: <8}</level> | "
|
|
45
|
+
"<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - "
|
|
46
|
+
"<level>{message}</level>"
|
|
47
|
+
"{extra}"
|
|
48
|
+
),
|
|
49
|
+
)
|
|
50
|
+
|
|
51
|
+
# Add custom formatter for extra fields
|
|
52
|
+
logger.configure(
|
|
53
|
+
patcher=lambda record: record.update(
|
|
54
|
+
extra=" | " + format_extra(record) if record["extra"] else ""
|
|
55
|
+
)
|
|
56
|
+
)
|
dao_ai/memory/__init__.py
CHANGED
|
@@ -3,10 +3,20 @@ from dao_ai.memory.base import (
|
|
|
3
3
|
StoreManagerBase,
|
|
4
4
|
)
|
|
5
5
|
from dao_ai.memory.core import CheckpointManager, StoreManager
|
|
6
|
+
from dao_ai.memory.databricks import (
|
|
7
|
+
AsyncDatabricksCheckpointSaver,
|
|
8
|
+
AsyncDatabricksStore,
|
|
9
|
+
DatabricksCheckpointerManager,
|
|
10
|
+
DatabricksStoreManager,
|
|
11
|
+
)
|
|
6
12
|
|
|
7
13
|
__all__ = [
|
|
8
14
|
"CheckpointManagerBase",
|
|
9
15
|
"StoreManagerBase",
|
|
10
16
|
"CheckpointManager",
|
|
11
17
|
"StoreManager",
|
|
18
|
+
"AsyncDatabricksCheckpointSaver",
|
|
19
|
+
"AsyncDatabricksStore",
|
|
20
|
+
"DatabricksCheckpointerManager",
|
|
21
|
+
"DatabricksStoreManager",
|
|
12
22
|
]
|