dao-ai 0.0.28__py3-none-any.whl → 0.1.2__py3-none-any.whl
This diff compares publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- dao_ai/__init__.py +29 -0
- dao_ai/agent_as_code.py +2 -5
- dao_ai/cli.py +245 -40
- dao_ai/config.py +1491 -370
- dao_ai/genie/__init__.py +38 -0
- dao_ai/genie/cache/__init__.py +43 -0
- dao_ai/genie/cache/base.py +72 -0
- dao_ai/genie/cache/core.py +79 -0
- dao_ai/genie/cache/lru.py +347 -0
- dao_ai/genie/cache/semantic.py +970 -0
- dao_ai/genie/core.py +35 -0
- dao_ai/graph.py +27 -253
- dao_ai/hooks/__init__.py +9 -6
- dao_ai/hooks/core.py +27 -195
- dao_ai/logging.py +56 -0
- dao_ai/memory/__init__.py +10 -0
- dao_ai/memory/core.py +65 -30
- dao_ai/memory/databricks.py +402 -0
- dao_ai/memory/postgres.py +79 -38
- dao_ai/messages.py +6 -4
- dao_ai/middleware/__init__.py +125 -0
- dao_ai/middleware/assertions.py +806 -0
- dao_ai/middleware/base.py +50 -0
- dao_ai/middleware/core.py +67 -0
- dao_ai/middleware/guardrails.py +420 -0
- dao_ai/middleware/human_in_the_loop.py +232 -0
- dao_ai/middleware/message_validation.py +586 -0
- dao_ai/middleware/summarization.py +197 -0
- dao_ai/models.py +1306 -114
- dao_ai/nodes.py +245 -159
- dao_ai/optimization.py +674 -0
- dao_ai/orchestration/__init__.py +52 -0
- dao_ai/orchestration/core.py +294 -0
- dao_ai/orchestration/supervisor.py +278 -0
- dao_ai/orchestration/swarm.py +271 -0
- dao_ai/prompts.py +128 -31
- dao_ai/providers/databricks.py +573 -601
- dao_ai/state.py +157 -21
- dao_ai/tools/__init__.py +13 -5
- dao_ai/tools/agent.py +1 -3
- dao_ai/tools/core.py +64 -11
- dao_ai/tools/email.py +232 -0
- dao_ai/tools/genie.py +144 -294
- dao_ai/tools/mcp.py +223 -155
- dao_ai/tools/memory.py +50 -0
- dao_ai/tools/python.py +9 -14
- dao_ai/tools/search.py +14 -0
- dao_ai/tools/slack.py +22 -10
- dao_ai/tools/sql.py +202 -0
- dao_ai/tools/time.py +30 -7
- dao_ai/tools/unity_catalog.py +165 -88
- dao_ai/tools/vector_search.py +331 -221
- dao_ai/utils.py +166 -20
- dao_ai-0.1.2.dist-info/METADATA +455 -0
- dao_ai-0.1.2.dist-info/RECORD +64 -0
- dao_ai/chat_models.py +0 -204
- dao_ai/guardrails.py +0 -112
- dao_ai/tools/human_in_the_loop.py +0 -100
- dao_ai-0.0.28.dist-info/METADATA +0 -1168
- dao_ai-0.0.28.dist-info/RECORD +0 -41
- {dao_ai-0.0.28.dist-info → dao_ai-0.1.2.dist-info}/WHEEL +0 -0
- {dao_ai-0.0.28.dist-info → dao_ai-0.1.2.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.0.28.dist-info → dao_ai-0.1.2.dist-info}/licenses/LICENSE +0 -0
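The headline change in this release, alongside the new `genie` cache, `orchestration`, and `optimization` modules listed above, is the `dao_ai/middleware` package: in `dao_ai/nodes.py` below, agent construction migrates from LangGraph's `create_react_agent` (with guardrails, summarization, and hooks wired in as separate graph nodes) to LangChain v1's `create_agent`, where guardrails, chat-history summarization, human-in-the-loop, and the system prompt compose as middleware. A minimal sketch of the new pattern, assuming only what the diff shows; the tool, model identifier, and names are illustrative stand-ins, not part of the package:

```python
from langchain.agents import create_agent
from langchain.chat_models import init_chat_model
from langchain_core.tools import tool

from dao_ai.middleware.guardrails import GuardrailMiddleware
from dao_ai.state import AgentState, Context


@tool
def echo(text: str) -> str:
    """Stand-in tool: echo the input back."""
    return text


# Guardrails are plain middleware now, mirroring the GuardrailMiddleware(...)
# call in the diff (previously reflection_guardrail + with_guardrails wrappers).
guardrail = GuardrailMiddleware(
    name="example_guardrail",  # illustrative name
    model=init_chat_model("openai:gpt-4o-mini"),  # any chat model instance
    prompt="Reject responses that reveal secrets.",
    num_retries=3,
)

# A single create_agent call replaces create_react_agent plus wrapper graphs;
# summarization, HITL, and prompt middleware would be appended to this list.
agent = create_agent(
    name="example_agent",  # illustrative name
    model=init_chat_model("openai:gpt-4o-mini"),
    tools=[echo],
    middleware=[guardrail],
    state_schema=AgentState,
    context_schema=Context,
)
```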
dao_ai/nodes.py
CHANGED
@@ -1,101 +1,152 @@
-
+"""
+Node creation utilities for DAO AI agents.

-
+This module provides factory functions for creating LangGraph nodes
+that implement agent logic using LangChain v1's create_agent pattern.
+"""
+
+from typing import Any, Optional, Sequence
+
+from langchain.agents import create_agent
+from langchain.agents.middleware import AgentMiddleware
 from langchain_core.language_models import LanguageModelLike
-from langchain_core.messages import AIMessage, AnyMessage, BaseMessage
-from langchain_core.messages.utils import count_tokens_approximately
-from langchain_core.runnables import RunnableConfig
 from langchain_core.runnables.base import RunnableLike
 from langchain_core.tools import BaseTool
-from langgraph.graph import StateGraph
 from langgraph.graph.state import CompiledStateGraph
-from
-from langgraph.runtime import Runtime
-from langmem import create_manage_memory_tool, create_search_memory_tool
-from langmem.short_term import SummarizationNode
-from langmem.short_term.summarization import TokenCounter
+from langmem import create_manage_memory_tool
 from loguru import logger

 from dao_ai.config import (
     AgentModel,
-    AppConfig,
     ChatHistoryModel,
-    FunctionHook,
     MemoryModel,
+    PromptModel,
     ToolModel,
 )
-from dao_ai.
-from dao_ai.
+from dao_ai.middleware.core import create_factory_middleware
+from dao_ai.middleware.guardrails import GuardrailMiddleware
+from dao_ai.middleware.human_in_the_loop import (
+    HumanInTheLoopMiddleware,
+    create_hitl_middleware_from_tool_models,
+)
+from dao_ai.middleware.summarization import (
+    LoggingSummarizationMiddleware,
+    create_summarization_middleware,
+)
 from dao_ai.prompts import make_prompt
-from dao_ai.state import
+from dao_ai.state import AgentState, Context
 from dao_ai.tools import create_tools
+from dao_ai.tools.memory import create_search_memory_tool


-def
+def _create_middleware_list(
+    agent: AgentModel,
+    tool_models: Sequence[ToolModel],
+    chat_history: Optional[ChatHistoryModel] = None,
+) -> list[Any]:
     """
-    Create a
+    Create a list of middleware instances from agent configuration.

     Args:
-
+        agent: AgentModel configuration
+        tool_models: Tool model configurations (for HITL settings)
+        chat_history: Optional chat history configuration for summarization

     Returns:
-
+        List of middleware instances (can include both AgentMiddleware and
+        LangChain built-in middleware)
     """
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    logger.debug("Building middleware list for agent", agent=agent.name)
+    middleware_list: list[Any] = []
+
+    # Add configured middleware using factory pattern
+    if agent.middleware:
+        middleware_names: list[str] = [mw.name for mw in agent.middleware]
+        logger.info(
+            "Middleware configuration",
+            agent=agent.name,
+            middleware_count=len(agent.middleware),
+            middleware_names=middleware_names,
+        )
+        for middleware_config in agent.middleware:
+            logger.trace(
+                "Creating middleware for agent",
+                agent=agent.name,
+                middleware_name=middleware_config.name,
+            )
+            middleware: AgentMiddleware[AgentState, Context] = create_factory_middleware(
+                function_name=middleware_config.name,
+                args=middleware_config.args,
+            )
+            if middleware is not None:
+                middleware_list.append(middleware)
+
+    # Add guardrails as middleware
+    if agent.guardrails:
+        guardrail_names: list[str] = [gr.name for gr in agent.guardrails]
+        logger.info(
+            "Guardrails configuration",
+            agent=agent.name,
+            guardrails_count=len(agent.guardrails),
+            guardrail_names=guardrail_names,
+        )
+        for guardrail in agent.guardrails:
+            # Extract template string from PromptModel if needed
+            prompt_str: str
+            if isinstance(guardrail.prompt, PromptModel):
+                prompt_str = guardrail.prompt.template
+            else:
+                prompt_str = guardrail.prompt
+
+            guardrail_middleware: GuardrailMiddleware = GuardrailMiddleware(
+                name=guardrail.name,
+                model=guardrail.model.as_chat_model(),
+                prompt=prompt_str,
+                num_retries=guardrail.num_retries or 3,
+            )
+            logger.trace(
+                "Created guardrail middleware", guardrail=guardrail.name, agent=agent.name
+            )
+            middleware_list.append(guardrail_middleware)
+
+    # Add summarization middleware if chat_history is configured
+    if chat_history is not None:
+        logger.info(
+            "Chat history configuration",
+            agent=agent.name,
+            max_tokens=chat_history.max_tokens,
+            summary_model=chat_history.model.name,
+        )
+        summarization_middleware: LoggingSummarizationMiddleware = (
+            create_summarization_middleware(chat_history)
+        )
+        middleware_list.append(summarization_middleware)

-
-
-
-        model=summarization_model,
-        max_tokens=max_tokens,
-        max_tokens_before_summary=max_tokens_before_summary
-        or max_messages_before_summary,
-        max_summary_tokens=max_summary_tokens,
-        token_counter=token_counter,
-        input_messages_key="messages",
-        output_messages_key="summarized_messages",
+    # Add human-in-the-loop middleware if any tools require it
+    hitl_middleware: HumanInTheLoopMiddleware | None = (
+        create_hitl_middleware_from_tool_models(tool_models)
     )
-
-
-
-
-
-
-
-
-
-
-
-
-    input: dict[str, Any] = {
-        "messages": messages,
-    }
-
-    response: dict[str, Any] = await agent.ainvoke(
-        input=input, context=runtime.context
+    if hitl_middleware is not None:
+        # Log which tools require HITL
+        hitl_tool_names: list[str] = [
+            tool.name
+            for tool in tool_models
+            if hasattr(tool.function, "human_in_the_loop")
+            and tool.function.human_in_the_loop is not None
+        ]
+        logger.info(
+            "Human-in-the-Loop configuration",
+            agent=agent.name,
+            hitl_tools=hitl_tool_names,
        )
-
-    logger.debug(f"Agent returned {len(response_messages)} messages")
-
-    return {"messages": response_messages}
+        middleware_list.append(hitl_middleware)

-
+    logger.info(
+        "Middleware summary",
+        agent=agent.name,
+        total_middleware_count=len(middleware_list),
+    )
+    return middleware_list


 def create_agent_node(
@@ -107,9 +158,9 @@ def create_agent_node(
     """
     Factory function that creates a LangGraph node for a specialized agent.

-    This creates
-    The function configures the agent with
-
+    This creates an agent using LangChain v1's create_agent function with
+    middleware for customization. The function configures the agent with
+    the appropriate model, prompt, tools, and middleware.

     Args:
         agent: AgentModel configuration for the agent
@@ -120,119 +171,154 @@
     Returns:
         RunnableLike: An agent node that processes state and returns responses
     """
-    logger.
-
-
-
-
+    logger.info("Creating agent node", agent=agent.name)
+
+    # Log agent configuration details
+    logger.info(
+        "Agent configuration",
+        agent=agent.name,
+        model=agent.model.name,
+        description=agent.description or "No description",
+    )

     llm: LanguageModelLike = agent.model.as_chat_model()

     tool_models: Sequence[ToolModel] = agent.tools
     if not additional_tools:
         additional_tools = []
-
+
+    # Log tools being created
+    tool_names: list[str] = [tool.name for tool in tool_models]
+    logger.info(
+        "Tools configuration",
+        agent=agent.name,
+        tools_count=len(tool_models),
+        tool_names=tool_names,
+    )
+
+    tools: list[BaseTool] = list(create_tools(tool_models)) + list(additional_tools)
+
+    if additional_tools:
+        logger.debug(
+            "Additional tools added",
+            agent=agent.name,
+            additional_count=len(additional_tools),
+        )

     if memory and memory.store:
         namespace: tuple[str, ...] = ("memory",)
         if memory.store.namespace:
             namespace = namespace + (memory.store.namespace,)
-        logger.
+        logger.info(
+            "Memory configuration",
+            agent=agent.name,
+            has_store=True,
+            has_checkpointer=memory.checkpointer is not None,
+            namespace=namespace,
+        )
+    elif memory:
+        logger.info(
+            "Memory configuration",
+            agent=agent.name,
+            has_store=False,
+            has_checkpointer=memory.checkpointer is not None,
+        )

+    # Add memory tools if store is configured
+    if memory and memory.store:
+        # Use Databricks-compatible search_memory tool (omits problematic filter field)
         tools += [
             create_manage_memory_tool(namespace=namespace),
             create_search_memory_tool(namespace=namespace),
         ]
+        logger.debug(
+            "Memory tools added",
+            agent=agent.name,
+            tools=["manage_memory", "search_memory"],
+        )

-
-
+    # Create middleware list from configuration
+    middleware_list = _create_middleware_list(
+        agent=agent,
+        tool_models=tool_models,
+        chat_history=chat_history,
     )
-    logger.debug(f"pre_agent_hook: {pre_agent_hook}")

-
-
+    # Log prompt configuration
+    if agent.prompt:
+        if isinstance(agent.prompt, PromptModel):
+            logger.info(
+                "Prompt configuration",
+                agent=agent.name,
+                prompt_type="PromptModel",
+                prompt_name=agent.prompt.name,
+            )
+        else:
+            prompt_preview: str = (
+                agent.prompt[:100] + "..." if len(agent.prompt) > 100 else agent.prompt
+            )
+            logger.info(
+                "Prompt configuration",
+                agent=agent.name,
+                prompt_type="string",
+                prompt_preview=prompt_preview,
+            )
+    else:
+        logger.debug("No custom prompt configured", agent=agent.name)
+
+    checkpointer: bool = memory is not None and memory.checkpointer is not None
+
+    # Get the prompt as middleware (always returns AgentMiddleware or None)
+    prompt_middleware: AgentMiddleware | None = make_prompt(agent.prompt)
+
+    # Add prompt middleware at the beginning for priority
+    if prompt_middleware is not None:
+        middleware_list.insert(0, prompt_middleware)
+
+    # Configure structured output if response_format is specified
+    response_format: Any = None
+    if agent.response_format is not None:
+        try:
+            response_format = agent.response_format.as_strategy()
+            if response_format is not None:
+                logger.info(
+                    "Response format configuration",
+                    agent=agent.name,
+                    format_type=type(response_format).__name__,
+                    structured_output=True,
+                )
+        except ValueError as e:
+            logger.error(
+                "Failed to configure structured output for agent",
+                agent=agent.name,
+                error=str(e),
+            )
+            raise
+
+    # Use LangChain v1's create_agent with middleware
+    # AgentState extends MessagesState with additional DAO AI fields
+    # System prompt is provided via middleware (dynamic_prompt)
+    logger.info(
+        "Creating LangChain agent",
+        agent=agent.name,
+        tools_count=len(tools),
+        middleware_count=len(middleware_list),
+        has_checkpointer=checkpointer,
     )
-    logger.debug(f"post_agent_hook: {post_agent_hook}")

-
-
-    compiled_agent: CompiledStateGraph = create_react_agent(
+    compiled_agent: CompiledStateGraph = create_agent(
         name=agent.name,
         model=llm,
-        prompt=make_prompt(agent.prompt),
         tools=tools,
-
+        middleware=middleware_list,
         checkpointer=checkpointer,
-        state_schema=
+        state_schema=AgentState,
         context_schema=Context,
-
-        post_model_hook=post_agent_hook,
+        response_format=response_format,  # Add structured output support
     )

-    for guardrail_definition in agent.guardrails:
-        guardrail: CompiledStateGraph = reflection_guardrail(guardrail_definition)
-        compiled_agent = with_guardrails(compiled_agent, guardrail)
-
     compiled_agent.name = agent.name

-
+    logger.info("Agent node created successfully", agent=agent.name)

-
-        logger.debug("No chat history configured, using compiled agent directly")
-        agent_node = compiled_agent
-    else:
-        logger.debug("Creating agent node with chat history summarization")
-        workflow: StateGraph = StateGraph(
-            SharedState,
-            config_schema=RunnableConfig,
-            input=SharedState,
-            output=SharedState,
-        )
-        workflow.add_node("summarization", summarization_node(chat_history))
-        workflow.add_node(
-            "agent",
-            call_agent_with_summarized_messages(agent=compiled_agent),
-        )
-        workflow.add_edge("summarization", "agent")
-        workflow.set_entry_point("summarization")
-        agent_node = workflow.compile(name=agent.name)
-
-    return agent_node
-
-
-def message_hook_node(config: AppConfig) -> RunnableLike:
-    message_hooks: Sequence[Callable[..., Any]] = create_hooks(config.app.message_hooks)
-
-    @mlflow.trace()
-    async def message_hook(
-        state: IncomingState, runtime: Runtime[Context]
-    ) -> SharedState:
-        logger.debug("Running message validation")
-        response: dict[str, Any] = {"is_valid": True, "message_error": None}
-
-        for message_hook in message_hooks:
-            message_hook: FunctionHook
-            if message_hook:
-                try:
-                    hook_response: dict[str, Any] = message_hook(
-                        state=state,
-                        runtime=runtime,
-                    )
-                    response.update(hook_response)
-                    logger.debug(f"Hook response: {hook_response}")
-                    if not response.get("is_valid", True):
-                        break
-                except Exception as e:
-                    logger.error(f"Message validation failed: {e}")
-                    response_messages: Sequence[BaseMessage] = [
-                        AIMessage(content=str(e))
-                    ]
-                    return {
-                        "is_valid": False,
-                        "message_error": str(e),
-                        "messages": response_messages,
-                    }
-
-        return response
-
-    return message_hook
+    return compiled_agent
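For orientation when reading the long tail of removals: the deleted `message_hook_node` validation loop, the `SummarizationNode` subgraph, and the `reflection_guardrail`/`with_guardrails` wrappers plausibly map onto the new `dao_ai/middleware/message_validation.py`, `dao_ai/middleware/summarization.py`, and `dao_ai/middleware/guardrails.py` modules in the file list above. The compiled agent is still a LangGraph `CompiledStateGraph`, so invocation keeps the `{"messages": [...]}` shape the removed code used internally; a short sketch continuing from the construction example after the file list, with the message text illustrative:

```python
from langchain_core.messages import HumanMessage

# `agent` is the CompiledStateGraph built by create_agent in the earlier sketch.
result = agent.invoke({"messages": [HumanMessage(content="What can you do?")]})
print(result["messages"][-1].content)
```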