dao-ai 0.0.36__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/__init__.py +29 -0
- dao_ai/cli.py +195 -30
- dao_ai/config.py +770 -244
- dao_ai/genie/__init__.py +1 -22
- dao_ai/genie/cache/__init__.py +1 -2
- dao_ai/genie/cache/base.py +20 -70
- dao_ai/genie/cache/core.py +75 -0
- dao_ai/genie/cache/lru.py +44 -21
- dao_ai/genie/cache/semantic.py +390 -109
- dao_ai/genie/core.py +35 -0
- dao_ai/graph.py +27 -253
- dao_ai/hooks/__init__.py +9 -6
- dao_ai/hooks/core.py +22 -190
- dao_ai/memory/__init__.py +10 -0
- dao_ai/memory/core.py +23 -5
- dao_ai/memory/databricks.py +389 -0
- dao_ai/memory/postgres.py +2 -2
- dao_ai/messages.py +6 -4
- dao_ai/middleware/__init__.py +125 -0
- dao_ai/middleware/assertions.py +778 -0
- dao_ai/middleware/base.py +50 -0
- dao_ai/middleware/core.py +61 -0
- dao_ai/middleware/guardrails.py +415 -0
- dao_ai/middleware/human_in_the_loop.py +228 -0
- dao_ai/middleware/message_validation.py +554 -0
- dao_ai/middleware/summarization.py +192 -0
- dao_ai/models.py +1177 -108
- dao_ai/nodes.py +118 -161
- dao_ai/optimization.py +664 -0
- dao_ai/orchestration/__init__.py +52 -0
- dao_ai/orchestration/core.py +287 -0
- dao_ai/orchestration/supervisor.py +264 -0
- dao_ai/orchestration/swarm.py +226 -0
- dao_ai/prompts.py +126 -29
- dao_ai/providers/databricks.py +126 -381
- dao_ai/state.py +139 -21
- dao_ai/tools/__init__.py +8 -5
- dao_ai/tools/core.py +57 -4
- dao_ai/tools/email.py +280 -0
- dao_ai/tools/genie.py +47 -24
- dao_ai/tools/mcp.py +4 -3
- dao_ai/tools/memory.py +50 -0
- dao_ai/tools/python.py +4 -12
- dao_ai/tools/search.py +14 -0
- dao_ai/tools/slack.py +1 -1
- dao_ai/tools/unity_catalog.py +8 -6
- dao_ai/tools/vector_search.py +16 -9
- dao_ai/utils.py +72 -8
- dao_ai-0.1.1.dist-info/METADATA +1878 -0
- dao_ai-0.1.1.dist-info/RECORD +62 -0
- dao_ai/chat_models.py +0 -204
- dao_ai/guardrails.py +0 -112
- dao_ai/tools/genie/__init__.py +0 -236
- dao_ai/tools/human_in_the_loop.py +0 -100
- dao_ai-0.0.36.dist-info/METADATA +0 -951
- dao_ai-0.0.36.dist-info/RECORD +0 -47
- {dao_ai-0.0.36.dist-info → dao_ai-0.1.1.dist-info}/WHEEL +0 -0
- {dao_ai-0.0.36.dist-info → dao_ai-0.1.1.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.0.36.dist-info → dao_ai-0.1.1.dist-info}/licenses/LICENSE +0 -0
dao_ai/nodes.py
CHANGED
|
@@ -1,101 +1,103 @@
|
|
|
1
|
-
|
|
1
|
+
"""
|
|
2
|
+
Node creation utilities for DAO AI agents.
|
|
2
3
|
|
|
3
|
-
|
|
4
|
+
This module provides factory functions for creating LangGraph nodes
|
|
5
|
+
that implement agent logic using LangChain v1's create_agent pattern.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from typing import Any, Optional, Sequence
|
|
9
|
+
|
|
10
|
+
from langchain.agents import create_agent
|
|
11
|
+
from langchain.agents.middleware import AgentMiddleware
|
|
4
12
|
from langchain_core.language_models import LanguageModelLike
|
|
5
|
-
from langchain_core.messages import AIMessage, AnyMessage, BaseMessage
|
|
6
|
-
from langchain_core.messages.utils import count_tokens_approximately
|
|
7
|
-
from langchain_core.runnables import RunnableConfig
|
|
8
13
|
from langchain_core.runnables.base import RunnableLike
|
|
9
14
|
from langchain_core.tools import BaseTool
|
|
10
|
-
from langgraph.graph import StateGraph
|
|
11
15
|
from langgraph.graph.state import CompiledStateGraph
|
|
12
|
-
from
|
|
13
|
-
from langgraph.runtime import Runtime
|
|
14
|
-
from langmem import create_manage_memory_tool, create_search_memory_tool
|
|
15
|
-
from langmem.short_term import SummarizationNode
|
|
16
|
-
from langmem.short_term.summarization import TokenCounter
|
|
16
|
+
from langmem import create_manage_memory_tool
|
|
17
17
|
from loguru import logger
|
|
18
18
|
|
|
19
19
|
from dao_ai.config import (
|
|
20
20
|
AgentModel,
|
|
21
|
-
AppConfig,
|
|
22
21
|
ChatHistoryModel,
|
|
23
|
-
FunctionHook,
|
|
24
22
|
MemoryModel,
|
|
23
|
+
PromptModel,
|
|
25
24
|
ToolModel,
|
|
26
25
|
)
|
|
27
|
-
from dao_ai.
|
|
28
|
-
from dao_ai.
|
|
26
|
+
from dao_ai.middleware.core import create_factory_middleware
|
|
27
|
+
from dao_ai.middleware.guardrails import GuardrailMiddleware
|
|
28
|
+
from dao_ai.middleware.human_in_the_loop import create_hitl_middleware_from_tool_models
|
|
29
|
+
from dao_ai.middleware.summarization import create_summarization_middleware
|
|
29
30
|
from dao_ai.prompts import make_prompt
|
|
30
|
-
from dao_ai.state import
|
|
31
|
+
from dao_ai.state import AgentState, Context
|
|
31
32
|
from dao_ai.tools import create_tools
|
|
33
|
+
from dao_ai.tools.memory import create_search_memory_tool
|
|
32
34
|
|
|
33
35
|
|
|
34
|
-
def
|
|
36
|
+
def _create_middleware_list(
|
|
37
|
+
agent: AgentModel,
|
|
38
|
+
tool_models: Sequence[ToolModel],
|
|
39
|
+
chat_history: Optional[ChatHistoryModel] = None,
|
|
40
|
+
) -> list[Any]:
|
|
35
41
|
"""
|
|
36
|
-
Create a
|
|
42
|
+
Create a list of middleware instances from agent configuration.
|
|
37
43
|
|
|
38
44
|
Args:
|
|
39
|
-
|
|
45
|
+
agent: AgentModel configuration
|
|
46
|
+
tool_models: Tool model configurations (for HITL settings)
|
|
47
|
+
chat_history: Optional chat history configuration for summarization
|
|
40
48
|
|
|
41
49
|
Returns:
|
|
42
|
-
|
|
50
|
+
List of middleware instances (can include both AgentMiddleware and
|
|
51
|
+
LangChain built-in middleware)
|
|
43
52
|
"""
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
return node
|
|
75
|
-
|
|
76
|
-
|
|
77
|
-
def call_agent_with_summarized_messages(agent: CompiledStateGraph) -> RunnableLike:
|
|
78
|
-
async def call_agent(state: SharedState, runtime: Runtime[Context]) -> SharedState:
|
|
79
|
-
logger.debug(f"Calling agent {agent.name} with summarized messages")
|
|
80
|
-
|
|
81
|
-
# Get the summarized messages from the summarization node
|
|
82
|
-
messages: Sequence[AnyMessage] = state.get("summarized_messages", [])
|
|
83
|
-
logger.debug(f"Found {len(messages)} summarized messages")
|
|
84
|
-
logger.trace(f"Summarized messages: {[m.model_dump() for m in messages]}")
|
|
85
|
-
|
|
86
|
-
input: dict[str, Any] = {
|
|
87
|
-
"messages": messages,
|
|
88
|
-
}
|
|
89
|
-
|
|
90
|
-
response: dict[str, Any] = await agent.ainvoke(
|
|
91
|
-
input=input, context=runtime.context
|
|
53
|
+
logger.debug(f"Building middleware list for agent '{agent.name}'")
|
|
54
|
+
middleware_list: list[Any] = []
|
|
55
|
+
|
|
56
|
+
# Add configured middleware using factory pattern
|
|
57
|
+
if agent.middleware:
|
|
58
|
+
logger.debug(f"Processing {len(agent.middleware)} configured middleware")
|
|
59
|
+
for middleware_config in agent.middleware:
|
|
60
|
+
middleware = create_factory_middleware(
|
|
61
|
+
function_name=middleware_config.name,
|
|
62
|
+
args=middleware_config.args,
|
|
63
|
+
)
|
|
64
|
+
if middleware is not None:
|
|
65
|
+
middleware_list.append(middleware)
|
|
66
|
+
|
|
67
|
+
# Add guardrails as middleware
|
|
68
|
+
if agent.guardrails:
|
|
69
|
+
logger.debug(f"Adding {len(agent.guardrails)} guardrail middleware")
|
|
70
|
+
for guardrail in agent.guardrails:
|
|
71
|
+
# Extract template string from PromptModel if needed
|
|
72
|
+
prompt_str: str
|
|
73
|
+
if isinstance(guardrail.prompt, PromptModel):
|
|
74
|
+
prompt_str = guardrail.prompt.template
|
|
75
|
+
else:
|
|
76
|
+
prompt_str = guardrail.prompt
|
|
77
|
+
|
|
78
|
+
guardrail_middleware = GuardrailMiddleware(
|
|
79
|
+
name=guardrail.name,
|
|
80
|
+
model=guardrail.model.as_chat_model(),
|
|
81
|
+
prompt=prompt_str,
|
|
82
|
+
num_retries=guardrail.num_retries or 3,
|
|
92
83
|
)
|
|
93
|
-
|
|
94
|
-
|
|
84
|
+
logger.debug(f"Created guardrail middleware: {guardrail.name}")
|
|
85
|
+
middleware_list.append(guardrail_middleware)
|
|
95
86
|
|
|
96
|
-
|
|
87
|
+
# Add summarization middleware if chat_history is configured
|
|
88
|
+
if chat_history is not None:
|
|
89
|
+
logger.debug("Adding summarization middleware")
|
|
90
|
+
summarization_middleware = create_summarization_middleware(chat_history)
|
|
91
|
+
middleware_list.append(summarization_middleware)
|
|
97
92
|
|
|
98
|
-
|
|
93
|
+
# Add human-in-the-loop middleware if any tools require it
|
|
94
|
+
hitl_middleware = create_hitl_middleware_from_tool_models(tool_models)
|
|
95
|
+
if hitl_middleware is not None:
|
|
96
|
+
logger.debug("Added human-in-the-loop middleware")
|
|
97
|
+
middleware_list.append(hitl_middleware)
|
|
98
|
+
|
|
99
|
+
logger.debug(f"Total middleware count: {len(middleware_list)}")
|
|
100
|
+
return middleware_list
|
|
99
101
|
|
|
100
102
|
|
|
101
103
|
def create_agent_node(
|
|
@@ -107,9 +109,9 @@ def create_agent_node(
|
|
|
107
109
|
"""
|
|
108
110
|
Factory function that creates a LangGraph node for a specialized agent.
|
|
109
111
|
|
|
110
|
-
This creates
|
|
111
|
-
The function configures the agent with
|
|
112
|
-
|
|
112
|
+
This creates an agent using LangChain v1's create_agent function with
|
|
113
|
+
middleware for customization. The function configures the agent with
|
|
114
|
+
the appropriate model, prompt, tools, and middleware.
|
|
113
115
|
|
|
114
116
|
Args:
|
|
115
117
|
agent: AgentModel configuration for the agent
|
|
@@ -122,16 +124,12 @@ def create_agent_node(
|
|
|
122
124
|
"""
|
|
123
125
|
logger.debug(f"Creating agent node for {agent.name}")
|
|
124
126
|
|
|
125
|
-
if agent.create_agent_hook:
|
|
126
|
-
agent_hook = next(iter(create_hooks(agent.create_agent_hook)), None)
|
|
127
|
-
return agent_hook
|
|
128
|
-
|
|
129
127
|
llm: LanguageModelLike = agent.model.as_chat_model()
|
|
130
128
|
|
|
131
129
|
tool_models: Sequence[ToolModel] = agent.tools
|
|
132
130
|
if not additional_tools:
|
|
133
131
|
additional_tools = []
|
|
134
|
-
tools:
|
|
132
|
+
tools: list[BaseTool] = list(create_tools(tool_models)) + list(additional_tools)
|
|
135
133
|
|
|
136
134
|
if memory and memory.store:
|
|
137
135
|
namespace: tuple[str, ...] = ("memory",)
|
|
@@ -139,100 +137,59 @@ def create_agent_node(
|
|
|
139
137
|
namespace = namespace + (memory.store.namespace,)
|
|
140
138
|
logger.debug(f"Memory store namespace: {namespace}")
|
|
141
139
|
|
|
140
|
+
# Use Databricks-compatible search_memory tool (omits problematic filter field)
|
|
142
141
|
tools += [
|
|
143
142
|
create_manage_memory_tool(namespace=namespace),
|
|
144
143
|
create_search_memory_tool(namespace=namespace),
|
|
145
144
|
]
|
|
146
145
|
|
|
147
|
-
|
|
148
|
-
|
|
146
|
+
# Create middleware list from configuration
|
|
147
|
+
middleware_list = _create_middleware_list(
|
|
148
|
+
agent=agent,
|
|
149
|
+
tool_models=tool_models,
|
|
150
|
+
chat_history=chat_history,
|
|
149
151
|
)
|
|
150
|
-
logger.debug(f"pre_agent_hook: {pre_agent_hook}")
|
|
151
152
|
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
153
|
+
logger.debug(f"Created {len(middleware_list)} middleware for agent {agent.name}")
|
|
154
|
+
|
|
155
|
+
checkpointer: bool = memory is not None and memory.checkpointer is not None
|
|
156
|
+
|
|
157
|
+
# Get the prompt as middleware (always returns AgentMiddleware or None)
|
|
158
|
+
prompt_middleware: AgentMiddleware | None = make_prompt(agent.prompt)
|
|
159
|
+
|
|
160
|
+
# Add prompt middleware at the beginning for priority
|
|
161
|
+
if prompt_middleware is not None:
|
|
162
|
+
middleware_list.insert(0, prompt_middleware)
|
|
163
|
+
|
|
164
|
+
# Configure structured output if response_format is specified
|
|
165
|
+
response_format: Any = None
|
|
166
|
+
if agent.response_format is not None:
|
|
167
|
+
try:
|
|
168
|
+
response_format = agent.response_format.as_strategy()
|
|
169
|
+
if response_format is not None:
|
|
170
|
+
logger.debug(
|
|
171
|
+
f"Agent '{agent.name}' using structured output: {type(response_format).__name__}"
|
|
172
|
+
)
|
|
173
|
+
except ValueError as e:
|
|
174
|
+
logger.error(
|
|
175
|
+
f"Failed to configure structured output for agent {agent.name}: {e}"
|
|
176
|
+
)
|
|
177
|
+
raise
|
|
178
|
+
|
|
179
|
+
# Use LangChain v1's create_agent with middleware
|
|
180
|
+
# AgentState extends MessagesState with additional DAO AI fields
|
|
181
|
+
# System prompt is provided via middleware (dynamic_prompt)
|
|
182
|
+
compiled_agent: CompiledStateGraph = create_agent(
|
|
160
183
|
name=agent.name,
|
|
161
184
|
model=llm,
|
|
162
|
-
prompt=make_prompt(agent.prompt),
|
|
163
185
|
tools=tools,
|
|
164
|
-
|
|
186
|
+
middleware=middleware_list,
|
|
165
187
|
checkpointer=checkpointer,
|
|
166
|
-
state_schema=
|
|
188
|
+
state_schema=AgentState,
|
|
167
189
|
context_schema=Context,
|
|
168
|
-
|
|
169
|
-
post_model_hook=post_agent_hook,
|
|
190
|
+
response_format=response_format, # Add structured output support
|
|
170
191
|
)
|
|
171
192
|
|
|
172
|
-
for guardrail_definition in agent.guardrails:
|
|
173
|
-
guardrail: CompiledStateGraph = reflection_guardrail(guardrail_definition)
|
|
174
|
-
compiled_agent = with_guardrails(compiled_agent, guardrail)
|
|
175
|
-
|
|
176
193
|
compiled_agent.name = agent.name
|
|
177
194
|
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
if chat_history is None:
|
|
181
|
-
logger.debug("No chat history configured, using compiled agent directly")
|
|
182
|
-
agent_node = compiled_agent
|
|
183
|
-
else:
|
|
184
|
-
logger.debug("Creating agent node with chat history summarization")
|
|
185
|
-
workflow: StateGraph = StateGraph(
|
|
186
|
-
SharedState,
|
|
187
|
-
config_schema=RunnableConfig,
|
|
188
|
-
input=SharedState,
|
|
189
|
-
output=SharedState,
|
|
190
|
-
)
|
|
191
|
-
workflow.add_node("summarization", summarization_node(chat_history))
|
|
192
|
-
workflow.add_node(
|
|
193
|
-
"agent",
|
|
194
|
-
call_agent_with_summarized_messages(agent=compiled_agent),
|
|
195
|
-
)
|
|
196
|
-
workflow.add_edge("summarization", "agent")
|
|
197
|
-
workflow.set_entry_point("summarization")
|
|
198
|
-
agent_node = workflow.compile(name=agent.name)
|
|
199
|
-
|
|
200
|
-
return agent_node
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
def message_hook_node(config: AppConfig) -> RunnableLike:
|
|
204
|
-
message_hooks: Sequence[Callable[..., Any]] = create_hooks(config.app.message_hooks)
|
|
205
|
-
|
|
206
|
-
@mlflow.trace()
|
|
207
|
-
async def message_hook(
|
|
208
|
-
state: IncomingState, runtime: Runtime[Context]
|
|
209
|
-
) -> SharedState:
|
|
210
|
-
logger.debug("Running message validation")
|
|
211
|
-
response: dict[str, Any] = {"is_valid": True, "message_error": None}
|
|
212
|
-
|
|
213
|
-
for message_hook in message_hooks:
|
|
214
|
-
message_hook: FunctionHook
|
|
215
|
-
if message_hook:
|
|
216
|
-
try:
|
|
217
|
-
hook_response: dict[str, Any] = message_hook(
|
|
218
|
-
state=state,
|
|
219
|
-
runtime=runtime,
|
|
220
|
-
)
|
|
221
|
-
response.update(hook_response)
|
|
222
|
-
logger.debug(f"Hook response: {hook_response}")
|
|
223
|
-
if not response.get("is_valid", True):
|
|
224
|
-
break
|
|
225
|
-
except Exception as e:
|
|
226
|
-
logger.error(f"Message validation failed: {e}")
|
|
227
|
-
response_messages: Sequence[BaseMessage] = [
|
|
228
|
-
AIMessage(content=str(e))
|
|
229
|
-
]
|
|
230
|
-
return {
|
|
231
|
-
"is_valid": False,
|
|
232
|
-
"message_error": str(e),
|
|
233
|
-
"messages": response_messages,
|
|
234
|
-
}
|
|
235
|
-
|
|
236
|
-
return response
|
|
237
|
-
|
|
238
|
-
return message_hook
|
|
195
|
+
return compiled_agent
|