dao-ai 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/agent_as_code.py +2 -5
- dao_ai/cli.py +65 -15
- dao_ai/config.py +672 -218
- dao_ai/genie/cache/core.py +6 -2
- dao_ai/genie/cache/lru.py +29 -11
- dao_ai/genie/cache/semantic.py +95 -44
- dao_ai/hooks/core.py +5 -5
- dao_ai/logging.py +56 -0
- dao_ai/memory/core.py +61 -44
- dao_ai/memory/databricks.py +54 -41
- dao_ai/memory/postgres.py +77 -36
- dao_ai/middleware/assertions.py +45 -17
- dao_ai/middleware/core.py +13 -7
- dao_ai/middleware/guardrails.py +30 -25
- dao_ai/middleware/human_in_the_loop.py +9 -5
- dao_ai/middleware/message_validation.py +61 -29
- dao_ai/middleware/summarization.py +16 -11
- dao_ai/models.py +172 -69
- dao_ai/nodes.py +148 -19
- dao_ai/optimization.py +26 -16
- dao_ai/orchestration/core.py +15 -8
- dao_ai/orchestration/supervisor.py +22 -8
- dao_ai/orchestration/swarm.py +57 -12
- dao_ai/prompts.py +17 -17
- dao_ai/providers/databricks.py +365 -155
- dao_ai/state.py +24 -6
- dao_ai/tools/__init__.py +2 -0
- dao_ai/tools/agent.py +1 -3
- dao_ai/tools/core.py +7 -7
- dao_ai/tools/email.py +29 -77
- dao_ai/tools/genie.py +18 -13
- dao_ai/tools/mcp.py +223 -156
- dao_ai/tools/python.py +5 -2
- dao_ai/tools/search.py +1 -1
- dao_ai/tools/slack.py +21 -9
- dao_ai/tools/sql.py +202 -0
- dao_ai/tools/time.py +30 -7
- dao_ai/tools/unity_catalog.py +129 -86
- dao_ai/tools/vector_search.py +318 -244
- dao_ai/utils.py +15 -10
- dao_ai-0.1.3.dist-info/METADATA +455 -0
- dao_ai-0.1.3.dist-info/RECORD +64 -0
- dao_ai-0.1.1.dist-info/METADATA +0 -1878
- dao_ai-0.1.1.dist-info/RECORD +0 -62
- {dao_ai-0.1.1.dist-info → dao_ai-0.1.3.dist-info}/WHEEL +0 -0
- {dao_ai-0.1.1.dist-info → dao_ai-0.1.3.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.1.1.dist-info → dao_ai-0.1.3.dist-info}/licenses/LICENSE +0 -0
dao_ai/nodes.py
CHANGED
|
@@ -25,8 +25,14 @@ from dao_ai.config import (
|
|
|
25
25
|
)
|
|
26
26
|
from dao_ai.middleware.core import create_factory_middleware
|
|
27
27
|
from dao_ai.middleware.guardrails import GuardrailMiddleware
|
|
28
|
-
from dao_ai.middleware.human_in_the_loop import
|
|
29
|
-
|
|
28
|
+
from dao_ai.middleware.human_in_the_loop import (
|
|
29
|
+
HumanInTheLoopMiddleware,
|
|
30
|
+
create_hitl_middleware_from_tool_models,
|
|
31
|
+
)
|
|
32
|
+
from dao_ai.middleware.summarization import (
|
|
33
|
+
LoggingSummarizationMiddleware,
|
|
34
|
+
create_summarization_middleware,
|
|
35
|
+
)
|
|
30
36
|
from dao_ai.prompts import make_prompt
|
|
31
37
|
from dao_ai.state import AgentState, Context
|
|
32
38
|
from dao_ai.tools import create_tools
|
|
@@ -50,14 +56,25 @@ def _create_middleware_list(
|
|
|
50
56
|
List of middleware instances (can include both AgentMiddleware and
|
|
51
57
|
LangChain built-in middleware)
|
|
52
58
|
"""
|
|
53
|
-
logger.debug(
|
|
59
|
+
logger.debug("Building middleware list for agent", agent=agent.name)
|
|
54
60
|
middleware_list: list[Any] = []
|
|
55
61
|
|
|
56
62
|
# Add configured middleware using factory pattern
|
|
57
63
|
if agent.middleware:
|
|
58
|
-
|
|
64
|
+
middleware_names: list[str] = [mw.name for mw in agent.middleware]
|
|
65
|
+
logger.info(
|
|
66
|
+
"Middleware configuration",
|
|
67
|
+
agent=agent.name,
|
|
68
|
+
middleware_count=len(agent.middleware),
|
|
69
|
+
middleware_names=middleware_names,
|
|
70
|
+
)
|
|
59
71
|
for middleware_config in agent.middleware:
|
|
60
|
-
|
|
72
|
+
logger.trace(
|
|
73
|
+
"Creating middleware for agent",
|
|
74
|
+
agent=agent.name,
|
|
75
|
+
middleware_name=middleware_config.name,
|
|
76
|
+
)
|
|
77
|
+
middleware: AgentMiddleware[AgentState, Context] = create_factory_middleware(
|
|
61
78
|
function_name=middleware_config.name,
|
|
62
79
|
args=middleware_config.args,
|
|
63
80
|
)
|
|
@@ -66,7 +83,13 @@ def _create_middleware_list(
|
|
|
66
83
|
|
|
67
84
|
# Add guardrails as middleware
|
|
68
85
|
if agent.guardrails:
|
|
69
|
-
|
|
86
|
+
guardrail_names: list[str] = [gr.name for gr in agent.guardrails]
|
|
87
|
+
logger.info(
|
|
88
|
+
"Guardrails configuration",
|
|
89
|
+
agent=agent.name,
|
|
90
|
+
guardrails_count=len(agent.guardrails),
|
|
91
|
+
guardrail_names=guardrail_names,
|
|
92
|
+
)
|
|
70
93
|
for guardrail in agent.guardrails:
|
|
71
94
|
# Extract template string from PromptModel if needed
|
|
72
95
|
prompt_str: str
|
|
@@ -75,28 +98,54 @@ def _create_middleware_list(
|
|
|
75
98
|
else:
|
|
76
99
|
prompt_str = guardrail.prompt
|
|
77
100
|
|
|
78
|
-
guardrail_middleware = GuardrailMiddleware(
|
|
101
|
+
guardrail_middleware: GuardrailMiddleware = GuardrailMiddleware(
|
|
79
102
|
name=guardrail.name,
|
|
80
103
|
model=guardrail.model.as_chat_model(),
|
|
81
104
|
prompt=prompt_str,
|
|
82
105
|
num_retries=guardrail.num_retries or 3,
|
|
83
106
|
)
|
|
84
|
-
logger.
|
|
107
|
+
logger.trace(
|
|
108
|
+
"Created guardrail middleware", guardrail=guardrail.name, agent=agent.name
|
|
109
|
+
)
|
|
85
110
|
middleware_list.append(guardrail_middleware)
|
|
86
111
|
|
|
87
112
|
# Add summarization middleware if chat_history is configured
|
|
88
113
|
if chat_history is not None:
|
|
89
|
-
logger.
|
|
90
|
-
|
|
114
|
+
logger.info(
|
|
115
|
+
"Chat history configuration",
|
|
116
|
+
agent=agent.name,
|
|
117
|
+
max_tokens=chat_history.max_tokens,
|
|
118
|
+
summary_model=chat_history.model.name,
|
|
119
|
+
)
|
|
120
|
+
summarization_middleware: LoggingSummarizationMiddleware = (
|
|
121
|
+
create_summarization_middleware(chat_history)
|
|
122
|
+
)
|
|
91
123
|
middleware_list.append(summarization_middleware)
|
|
92
124
|
|
|
93
125
|
# Add human-in-the-loop middleware if any tools require it
|
|
94
|
-
hitl_middleware =
|
|
126
|
+
hitl_middleware: HumanInTheLoopMiddleware | None = (
|
|
127
|
+
create_hitl_middleware_from_tool_models(tool_models)
|
|
128
|
+
)
|
|
95
129
|
if hitl_middleware is not None:
|
|
96
|
-
|
|
130
|
+
# Log which tools require HITL
|
|
131
|
+
hitl_tool_names: list[str] = [
|
|
132
|
+
tool.name
|
|
133
|
+
for tool in tool_models
|
|
134
|
+
if hasattr(tool.function, "human_in_the_loop")
|
|
135
|
+
and tool.function.human_in_the_loop is not None
|
|
136
|
+
]
|
|
137
|
+
logger.info(
|
|
138
|
+
"Human-in-the-Loop configuration",
|
|
139
|
+
agent=agent.name,
|
|
140
|
+
hitl_tools=hitl_tool_names,
|
|
141
|
+
)
|
|
97
142
|
middleware_list.append(hitl_middleware)
|
|
98
143
|
|
|
99
|
-
logger.
|
|
144
|
+
logger.info(
|
|
145
|
+
"Middleware summary",
|
|
146
|
+
agent=agent.name,
|
|
147
|
+
total_middleware_count=len(middleware_list),
|
|
148
|
+
)
|
|
100
149
|
return middleware_list
|
|
101
150
|
|
|
102
151
|
|
|
@@ -122,26 +171,71 @@ def create_agent_node(
|
|
|
122
171
|
Returns:
|
|
123
172
|
RunnableLike: An agent node that processes state and returns responses
|
|
124
173
|
"""
|
|
125
|
-
logger.
|
|
174
|
+
logger.info("Creating agent node", agent=agent.name)
|
|
175
|
+
|
|
176
|
+
# Log agent configuration details
|
|
177
|
+
logger.info(
|
|
178
|
+
"Agent configuration",
|
|
179
|
+
agent=agent.name,
|
|
180
|
+
model=agent.model.name,
|
|
181
|
+
description=agent.description or "No description",
|
|
182
|
+
)
|
|
126
183
|
|
|
127
184
|
llm: LanguageModelLike = agent.model.as_chat_model()
|
|
128
185
|
|
|
129
186
|
tool_models: Sequence[ToolModel] = agent.tools
|
|
130
187
|
if not additional_tools:
|
|
131
188
|
additional_tools = []
|
|
189
|
+
|
|
190
|
+
# Log tools being created
|
|
191
|
+
tool_names: list[str] = [tool.name for tool in tool_models]
|
|
192
|
+
logger.info(
|
|
193
|
+
"Tools configuration",
|
|
194
|
+
agent=agent.name,
|
|
195
|
+
tools_count=len(tool_models),
|
|
196
|
+
tool_names=tool_names,
|
|
197
|
+
)
|
|
198
|
+
|
|
132
199
|
tools: list[BaseTool] = list(create_tools(tool_models)) + list(additional_tools)
|
|
133
200
|
|
|
201
|
+
if additional_tools:
|
|
202
|
+
logger.debug(
|
|
203
|
+
"Additional tools added",
|
|
204
|
+
agent=agent.name,
|
|
205
|
+
additional_count=len(additional_tools),
|
|
206
|
+
)
|
|
207
|
+
|
|
134
208
|
if memory and memory.store:
|
|
135
209
|
namespace: tuple[str, ...] = ("memory",)
|
|
136
210
|
if memory.store.namespace:
|
|
137
211
|
namespace = namespace + (memory.store.namespace,)
|
|
138
|
-
logger.
|
|
212
|
+
logger.info(
|
|
213
|
+
"Memory configuration",
|
|
214
|
+
agent=agent.name,
|
|
215
|
+
has_store=True,
|
|
216
|
+
has_checkpointer=memory.checkpointer is not None,
|
|
217
|
+
namespace=namespace,
|
|
218
|
+
)
|
|
219
|
+
elif memory:
|
|
220
|
+
logger.info(
|
|
221
|
+
"Memory configuration",
|
|
222
|
+
agent=agent.name,
|
|
223
|
+
has_store=False,
|
|
224
|
+
has_checkpointer=memory.checkpointer is not None,
|
|
225
|
+
)
|
|
139
226
|
|
|
227
|
+
# Add memory tools if store is configured
|
|
228
|
+
if memory and memory.store:
|
|
140
229
|
# Use Databricks-compatible search_memory tool (omits problematic filter field)
|
|
141
230
|
tools += [
|
|
142
231
|
create_manage_memory_tool(namespace=namespace),
|
|
143
232
|
create_search_memory_tool(namespace=namespace),
|
|
144
233
|
]
|
|
234
|
+
logger.debug(
|
|
235
|
+
"Memory tools added",
|
|
236
|
+
agent=agent.name,
|
|
237
|
+
tools=["manage_memory", "search_memory"],
|
|
238
|
+
)
|
|
145
239
|
|
|
146
240
|
# Create middleware list from configuration
|
|
147
241
|
middleware_list = _create_middleware_list(
|
|
@@ -150,7 +244,27 @@ def create_agent_node(
|
|
|
150
244
|
chat_history=chat_history,
|
|
151
245
|
)
|
|
152
246
|
|
|
153
|
-
|
|
247
|
+
# Log prompt configuration
|
|
248
|
+
if agent.prompt:
|
|
249
|
+
if isinstance(agent.prompt, PromptModel):
|
|
250
|
+
logger.info(
|
|
251
|
+
"Prompt configuration",
|
|
252
|
+
agent=agent.name,
|
|
253
|
+
prompt_type="PromptModel",
|
|
254
|
+
prompt_name=agent.prompt.name,
|
|
255
|
+
)
|
|
256
|
+
else:
|
|
257
|
+
prompt_preview: str = (
|
|
258
|
+
agent.prompt[:100] + "..." if len(agent.prompt) > 100 else agent.prompt
|
|
259
|
+
)
|
|
260
|
+
logger.info(
|
|
261
|
+
"Prompt configuration",
|
|
262
|
+
agent=agent.name,
|
|
263
|
+
prompt_type="string",
|
|
264
|
+
prompt_preview=prompt_preview,
|
|
265
|
+
)
|
|
266
|
+
else:
|
|
267
|
+
logger.debug("No custom prompt configured", agent=agent.name)
|
|
154
268
|
|
|
155
269
|
checkpointer: bool = memory is not None and memory.checkpointer is not None
|
|
156
270
|
|
|
@@ -167,18 +281,31 @@ def create_agent_node(
|
|
|
167
281
|
try:
|
|
168
282
|
response_format = agent.response_format.as_strategy()
|
|
169
283
|
if response_format is not None:
|
|
170
|
-
logger.
|
|
171
|
-
|
|
284
|
+
logger.info(
|
|
285
|
+
"Response format configuration",
|
|
286
|
+
agent=agent.name,
|
|
287
|
+
format_type=type(response_format).__name__,
|
|
288
|
+
structured_output=True,
|
|
172
289
|
)
|
|
173
290
|
except ValueError as e:
|
|
174
291
|
logger.error(
|
|
175
|
-
|
|
292
|
+
"Failed to configure structured output for agent",
|
|
293
|
+
agent=agent.name,
|
|
294
|
+
error=str(e),
|
|
176
295
|
)
|
|
177
296
|
raise
|
|
178
297
|
|
|
179
298
|
# Use LangChain v1's create_agent with middleware
|
|
180
299
|
# AgentState extends MessagesState with additional DAO AI fields
|
|
181
300
|
# System prompt is provided via middleware (dynamic_prompt)
|
|
301
|
+
logger.info(
|
|
302
|
+
"Creating LangChain agent",
|
|
303
|
+
agent=agent.name,
|
|
304
|
+
tools_count=len(tools),
|
|
305
|
+
middleware_count=len(middleware_list),
|
|
306
|
+
has_checkpointer=checkpointer,
|
|
307
|
+
)
|
|
308
|
+
|
|
182
309
|
compiled_agent: CompiledStateGraph = create_agent(
|
|
183
310
|
name=agent.name,
|
|
184
311
|
model=llm,
|
|
@@ -192,4 +319,6 @@ def create_agent_node(
|
|
|
192
319
|
|
|
193
320
|
compiled_agent.name = agent.name
|
|
194
321
|
|
|
322
|
+
logger.info("Agent node created successfully", agent=agent.name)
|
|
323
|
+
|
|
195
324
|
return compiled_agent
|
dao_ai/optimization.py
CHANGED
|
@@ -245,7 +245,7 @@ class DAOAgentAdapter(GEPAAdapter[_TrainingExample, _Trajectory, str]):
|
|
|
245
245
|
)
|
|
246
246
|
|
|
247
247
|
except Exception as e:
|
|
248
|
-
logger.warning(
|
|
248
|
+
logger.warning("Error evaluating example", error=str(e))
|
|
249
249
|
outputs.append("")
|
|
250
250
|
scores.append(0.0)
|
|
251
251
|
|
|
@@ -362,7 +362,9 @@ def _convert_dataset(
|
|
|
362
362
|
)
|
|
363
363
|
examples.append(example)
|
|
364
364
|
|
|
365
|
-
logger.debug(
|
|
365
|
+
logger.debug(
|
|
366
|
+
"Converted dataset entries to training examples", examples_count=len(examples)
|
|
367
|
+
)
|
|
366
368
|
return examples
|
|
367
369
|
|
|
368
370
|
|
|
@@ -400,7 +402,7 @@ def _register_optimized_prompt(
|
|
|
400
402
|
prompt_name: str = prompt.full_name
|
|
401
403
|
optimization_timestamp: str = datetime.now(timezone.utc).isoformat()
|
|
402
404
|
|
|
403
|
-
logger.info(
|
|
405
|
+
logger.info("Registering optimized prompt", prompt_name=prompt_name)
|
|
404
406
|
|
|
405
407
|
# Build comprehensive tags for the prompt registry
|
|
406
408
|
tags: dict[str, str] = {
|
|
@@ -442,7 +444,11 @@ def _register_optimized_prompt(
|
|
|
442
444
|
tags=tags,
|
|
443
445
|
)
|
|
444
446
|
|
|
445
|
-
logger.
|
|
447
|
+
logger.success(
|
|
448
|
+
"Registered optimized prompt version",
|
|
449
|
+
prompt_name=prompt_name,
|
|
450
|
+
version=version.version,
|
|
451
|
+
)
|
|
446
452
|
|
|
447
453
|
# Set 'latest' alias for most recently optimized version
|
|
448
454
|
mlflow.genai.set_prompt_alias(
|
|
@@ -450,7 +456,7 @@ def _register_optimized_prompt(
|
|
|
450
456
|
alias="latest",
|
|
451
457
|
version=version.version,
|
|
452
458
|
)
|
|
453
|
-
logger.info(
|
|
459
|
+
logger.info("Set 'latest' alias", prompt_name=prompt_name, version=version.version)
|
|
454
460
|
|
|
455
461
|
# Set 'champion' alias if there was actual improvement
|
|
456
462
|
if improvement > 0:
|
|
@@ -459,7 +465,9 @@ def _register_optimized_prompt(
|
|
|
459
465
|
alias="champion",
|
|
460
466
|
version=version.version,
|
|
461
467
|
)
|
|
462
|
-
logger.
|
|
468
|
+
logger.success(
|
|
469
|
+
"Set 'champion' alias", prompt_name=prompt_name, version=version.version
|
|
470
|
+
)
|
|
463
471
|
|
|
464
472
|
return version
|
|
465
473
|
|
|
@@ -518,7 +526,7 @@ def optimize_prompt(
|
|
|
518
526
|
if result.improved:
|
|
519
527
|
print(f"Improved by {result.improvement:.1%}")
|
|
520
528
|
"""
|
|
521
|
-
logger.info(
|
|
529
|
+
logger.info("Starting GEPA optimization", prompt_name=prompt.name)
|
|
522
530
|
|
|
523
531
|
# Get the original template
|
|
524
532
|
original_template = prompt.template
|
|
@@ -535,11 +543,11 @@ def optimize_prompt(
|
|
|
535
543
|
trainset = examples[:split_idx]
|
|
536
544
|
valset = examples[split_idx:] if split_idx < len(examples) else examples
|
|
537
545
|
|
|
538
|
-
logger.info(
|
|
546
|
+
logger.info("Dataset split", train_size=len(trainset), val_size=len(valset))
|
|
539
547
|
|
|
540
548
|
# Get reflection model
|
|
541
549
|
reflection_model_name = reflection_model or agent.model.uri
|
|
542
|
-
logger.info(
|
|
550
|
+
logger.info("Using reflection model", model=reflection_model_name)
|
|
543
551
|
|
|
544
552
|
# Create adapter
|
|
545
553
|
adapter = DAOAgentAdapter(agent_model=agent, metric_fn=metric)
|
|
@@ -548,7 +556,7 @@ def optimize_prompt(
|
|
|
548
556
|
seed_candidate = {"prompt": original_template}
|
|
549
557
|
|
|
550
558
|
# Run GEPA optimization
|
|
551
|
-
logger.info(
|
|
559
|
+
logger.info("Running GEPA optimization", max_evaluations=num_candidates)
|
|
552
560
|
|
|
553
561
|
try:
|
|
554
562
|
result: GEPAResult = optimize(
|
|
@@ -562,7 +570,7 @@ def optimize_prompt(
|
|
|
562
570
|
skip_perfect_score=True,
|
|
563
571
|
)
|
|
564
572
|
except Exception as e:
|
|
565
|
-
logger.error(
|
|
573
|
+
logger.error("GEPA optimization failed", error=str(e))
|
|
566
574
|
return OptimizationResult(
|
|
567
575
|
optimized_prompt=prompt,
|
|
568
576
|
optimized_template=original_template,
|
|
@@ -596,10 +604,12 @@ def optimize_prompt(
|
|
|
596
604
|
else 0.0
|
|
597
605
|
)
|
|
598
606
|
|
|
599
|
-
logger.
|
|
600
|
-
|
|
601
|
-
|
|
602
|
-
|
|
607
|
+
logger.success(
|
|
608
|
+
"Optimization complete",
|
|
609
|
+
original_score=f"{original_score:.3f}",
|
|
610
|
+
optimized_score=f"{optimized_score:.3f}",
|
|
611
|
+
improvement=f"{improvement:.1%}",
|
|
612
|
+
)
|
|
603
613
|
|
|
604
614
|
# Register if improved
|
|
605
615
|
registered_version: Optional[PromptVersion] = None
|
|
@@ -623,7 +633,7 @@ def optimize_prompt(
|
|
|
623
633
|
val_size=len(valset),
|
|
624
634
|
)
|
|
625
635
|
except Exception as e:
|
|
626
|
-
logger.error(
|
|
636
|
+
logger.error("Failed to register optimized prompt", error=str(e))
|
|
627
637
|
|
|
628
638
|
# Build optimized prompt model with comprehensive tags
|
|
629
639
|
optimized_tags: dict[str, str] = {
|
dao_ai/orchestration/core.py
CHANGED
|
@@ -46,7 +46,7 @@ def create_store(orchestration: OrchestrationModel) -> BaseStore | None:
|
|
|
46
46
|
"""
|
|
47
47
|
if orchestration.memory and orchestration.memory.store:
|
|
48
48
|
store = orchestration.memory.store.as_store()
|
|
49
|
-
logger.debug(
|
|
49
|
+
logger.debug("Memory store configured", store_type=type(store).__name__)
|
|
50
50
|
return store
|
|
51
51
|
return None
|
|
52
52
|
|
|
@@ -65,7 +65,9 @@ def create_checkpointer(
|
|
|
65
65
|
"""
|
|
66
66
|
if orchestration.memory and orchestration.memory.checkpointer:
|
|
67
67
|
checkpointer = orchestration.memory.checkpointer.as_checkpointer()
|
|
68
|
-
logger.debug(
|
|
68
|
+
logger.debug(
|
|
69
|
+
"Checkpointer configured", checkpointer_type=type(checkpointer).__name__
|
|
70
|
+
)
|
|
69
71
|
return checkpointer
|
|
70
72
|
return None
|
|
71
73
|
|
|
@@ -164,9 +166,11 @@ def create_agent_node_handler(
|
|
|
164
166
|
original_messages = state.get("messages", [])
|
|
165
167
|
filtered_messages = filter_messages_for_agent(original_messages)
|
|
166
168
|
|
|
167
|
-
logger.
|
|
168
|
-
|
|
169
|
-
|
|
169
|
+
logger.trace(
|
|
170
|
+
"Agent receiving filtered messages",
|
|
171
|
+
agent=agent_name,
|
|
172
|
+
filtered_count=len(filtered_messages),
|
|
173
|
+
original_count=len(original_messages),
|
|
170
174
|
)
|
|
171
175
|
|
|
172
176
|
# Create state with filtered messages for the agent
|
|
@@ -183,8 +187,11 @@ def create_agent_node_handler(
|
|
|
183
187
|
response_messages = extract_agent_response(result_messages, output_mode)
|
|
184
188
|
|
|
185
189
|
logger.debug(
|
|
186
|
-
|
|
187
|
-
|
|
190
|
+
"Agent completed",
|
|
191
|
+
agent=agent_name,
|
|
192
|
+
response_count=len(response_messages),
|
|
193
|
+
total_messages=len(result_messages),
|
|
194
|
+
output_mode=output_mode,
|
|
188
195
|
)
|
|
189
196
|
|
|
190
197
|
# Return state update with extracted response
|
|
@@ -218,7 +225,7 @@ def create_handoff_tool(
|
|
|
218
225
|
def handoff_tool(runtime: ToolRuntime[Context, AgentState]) -> Command:
|
|
219
226
|
"""Transfer control to another agent."""
|
|
220
227
|
tool_call_id: str = runtime.tool_call_id
|
|
221
|
-
logger.debug(
|
|
228
|
+
logger.debug("Handoff to agent", target_agent=target_agent_name)
|
|
222
229
|
|
|
223
230
|
return Command(
|
|
224
231
|
update={
|
|
@@ -73,7 +73,7 @@ def _create_handoff_back_to_supervisor_tool() -> BaseTool:
|
|
|
73
73
|
summary: A brief summary of what was accomplished
|
|
74
74
|
"""
|
|
75
75
|
tool_call_id: str = runtime.tool_call_id
|
|
76
|
-
logger.debug(
|
|
76
|
+
logger.debug("Agent handing back to supervisor", summary_preview=summary[:100])
|
|
77
77
|
|
|
78
78
|
return Command(
|
|
79
79
|
update={
|
|
@@ -163,11 +163,15 @@ def create_supervisor_graph(config: AppConfig) -> CompiledStateGraph:
|
|
|
163
163
|
|
|
164
164
|
Based on: https://github.com/langchain-ai/langgraph-supervisor-py
|
|
165
165
|
"""
|
|
166
|
-
logger.debug("Creating supervisor graph (handoff pattern)")
|
|
167
|
-
|
|
168
166
|
orchestration: OrchestrationModel = config.app.orchestration
|
|
169
167
|
supervisor_config: SupervisorModel = orchestration.supervisor
|
|
170
168
|
|
|
169
|
+
logger.info(
|
|
170
|
+
"Creating supervisor graph",
|
|
171
|
+
pattern="handoff",
|
|
172
|
+
agents_count=len(config.app.agents),
|
|
173
|
+
)
|
|
174
|
+
|
|
171
175
|
# Create handoff tools for supervisor to route to agents
|
|
172
176
|
handoff_tools: list[BaseTool] = []
|
|
173
177
|
for registered_agent in config.app.agents:
|
|
@@ -177,21 +181,31 @@ def create_supervisor_graph(config: AppConfig) -> CompiledStateGraph:
|
|
|
177
181
|
description=description,
|
|
178
182
|
)
|
|
179
183
|
handoff_tools.append(handoff_tool)
|
|
180
|
-
logger.debug(
|
|
184
|
+
logger.debug("Created handoff tool for supervisor", agent=registered_agent.name)
|
|
181
185
|
|
|
182
186
|
# Create supervisor's own tools (e.g., memory tools)
|
|
187
|
+
logger.debug(
|
|
188
|
+
"Creating tools for supervisor", tools_count=len(supervisor_config.tools)
|
|
189
|
+
)
|
|
183
190
|
supervisor_tools: list[BaseTool] = list(create_tools(supervisor_config.tools))
|
|
184
191
|
|
|
185
192
|
# Create middleware from configuration
|
|
186
193
|
middlewares: list[AgentMiddleware] = []
|
|
194
|
+
|
|
187
195
|
for middleware_config in supervisor_config.middleware:
|
|
188
|
-
|
|
196
|
+
logger.trace(
|
|
197
|
+
"Creating middleware for supervisor",
|
|
198
|
+
middleware_name=middleware_config.name,
|
|
199
|
+
)
|
|
200
|
+
middleware: LangchainAgentMiddleware = create_factory_middleware(
|
|
189
201
|
function_name=middleware_config.name,
|
|
190
202
|
args=middleware_config.args,
|
|
191
203
|
)
|
|
192
204
|
if middleware is not None:
|
|
193
205
|
middlewares.append(middleware)
|
|
194
|
-
logger.debug(
|
|
206
|
+
logger.debug(
|
|
207
|
+
"Created supervisor middleware", middleware=middleware_config.name
|
|
208
|
+
)
|
|
195
209
|
|
|
196
210
|
# Set up memory store and checkpointer
|
|
197
211
|
store: BaseStore | None = create_store(orchestration)
|
|
@@ -204,7 +218,7 @@ def create_supervisor_graph(config: AppConfig) -> CompiledStateGraph:
|
|
|
204
218
|
and orchestration.memory.store.namespace
|
|
205
219
|
):
|
|
206
220
|
namespace: tuple[str, ...] = ("memory", orchestration.memory.store.namespace)
|
|
207
|
-
logger.debug(
|
|
221
|
+
logger.debug("Memory store namespace configured", namespace=namespace)
|
|
208
222
|
# Use Databricks-compatible search_memory tool (omits problematic filter field)
|
|
209
223
|
supervisor_tools += [
|
|
210
224
|
create_manage_memory_tool(namespace=namespace),
|
|
@@ -235,7 +249,7 @@ def create_supervisor_graph(config: AppConfig) -> CompiledStateGraph:
|
|
|
235
249
|
additional_tools=[supervisor_handoff],
|
|
236
250
|
)
|
|
237
251
|
agent_subgraphs[registered_agent.name] = agent_subgraph
|
|
238
|
-
logger.debug(
|
|
252
|
+
logger.debug("Created worker agent subgraph", agent=registered_agent.name)
|
|
239
253
|
|
|
240
254
|
# Build the workflow graph
|
|
241
255
|
# All agents are nodes, handoffs route between them via Command
|
dao_ai/orchestration/swarm.py
CHANGED
|
@@ -69,14 +69,14 @@ def _handoffs_for_agent(
|
|
|
69
69
|
)
|
|
70
70
|
|
|
71
71
|
if handoff_to_agent is None:
|
|
72
|
-
logger.warning(
|
|
73
|
-
f"Handoff agent not found in configuration for agent {agent.name}"
|
|
74
|
-
)
|
|
72
|
+
logger.warning("Handoff agent not found in configuration", agent=agent.name)
|
|
75
73
|
continue
|
|
76
74
|
if agent.name == handoff_to_agent.name:
|
|
77
75
|
continue
|
|
78
76
|
logger.debug(
|
|
79
|
-
|
|
77
|
+
"Creating handoff tool",
|
|
78
|
+
from_agent=agent.name,
|
|
79
|
+
to_agent=handoff_to_agent.name,
|
|
80
80
|
)
|
|
81
81
|
|
|
82
82
|
handoff_description: str = get_handoff_description(handoff_to_agent)
|
|
@@ -116,19 +116,22 @@ def _create_swarm_router(
|
|
|
116
116
|
|
|
117
117
|
# If no active agent set, use default
|
|
118
118
|
if not active_agent:
|
|
119
|
-
logger.
|
|
120
|
-
|
|
119
|
+
logger.trace(
|
|
120
|
+
"No active agent in state, routing to default",
|
|
121
|
+
default_agent=default_agent,
|
|
121
122
|
)
|
|
122
123
|
return default_agent
|
|
123
124
|
|
|
124
125
|
# Validate active_agent exists
|
|
125
126
|
if active_agent in agent_names:
|
|
126
|
-
logger.
|
|
127
|
+
logger.trace("Routing to active agent", active_agent=active_agent)
|
|
127
128
|
return active_agent
|
|
128
129
|
|
|
129
130
|
# Fallback to default if active_agent is invalid
|
|
130
131
|
logger.warning(
|
|
131
|
-
|
|
132
|
+
"Invalid active agent, routing to default",
|
|
133
|
+
active_agent=active_agent,
|
|
134
|
+
default_agent=default_agent,
|
|
132
135
|
)
|
|
133
136
|
return default_agent
|
|
134
137
|
|
|
@@ -157,8 +160,6 @@ def create_swarm_graph(config: AppConfig) -> CompiledStateGraph:
|
|
|
157
160
|
|
|
158
161
|
See: https://github.com/langchain-ai/langgraph-swarm-py
|
|
159
162
|
"""
|
|
160
|
-
logger.debug("Creating swarm graph (handoff pattern)")
|
|
161
|
-
|
|
162
163
|
orchestration: OrchestrationModel = config.app.orchestration
|
|
163
164
|
swarm: SwarmModel = orchestration.swarm
|
|
164
165
|
|
|
@@ -169,10 +170,27 @@ def create_swarm_graph(config: AppConfig) -> CompiledStateGraph:
|
|
|
169
170
|
else:
|
|
170
171
|
default_agent = swarm.default_agent
|
|
171
172
|
|
|
173
|
+
logger.info(
|
|
174
|
+
"Creating swarm graph",
|
|
175
|
+
pattern="handoff",
|
|
176
|
+
default_agent=default_agent,
|
|
177
|
+
agents_count=len(config.app.agents),
|
|
178
|
+
)
|
|
179
|
+
|
|
172
180
|
# Create agent subgraphs with their specific handoff tools
|
|
173
181
|
# Each agent gets handoff tools only for agents they're allowed to hand off to
|
|
174
182
|
agent_subgraphs: dict[str, CompiledStateGraph] = {}
|
|
175
183
|
memory: MemoryModel | None = orchestration.memory
|
|
184
|
+
|
|
185
|
+
# Get swarm-level middleware to apply to all agents
|
|
186
|
+
swarm_middleware: list = swarm.middleware if swarm.middleware else []
|
|
187
|
+
if swarm_middleware:
|
|
188
|
+
logger.info(
|
|
189
|
+
"Applying swarm-level middleware to all agents",
|
|
190
|
+
middleware_count=len(swarm_middleware),
|
|
191
|
+
middleware_names=[mw.name for mw in swarm_middleware],
|
|
192
|
+
)
|
|
193
|
+
|
|
176
194
|
for registered_agent in config.app.agents:
|
|
177
195
|
# Get handoff tools for this agent
|
|
178
196
|
handoff_tools: Sequence[BaseTool] = _handoffs_for_agent(
|
|
@@ -180,14 +198,41 @@ def create_swarm_graph(config: AppConfig) -> CompiledStateGraph:
|
|
|
180
198
|
config=config,
|
|
181
199
|
)
|
|
182
200
|
|
|
201
|
+
# Merge swarm-level middleware with agent-specific middleware
|
|
202
|
+
# Swarm middleware is applied first, then agent middleware
|
|
203
|
+
if swarm_middleware:
|
|
204
|
+
from copy import deepcopy
|
|
205
|
+
|
|
206
|
+
# Create a copy of the agent to avoid modifying the original
|
|
207
|
+
agent_with_middleware = deepcopy(registered_agent)
|
|
208
|
+
|
|
209
|
+
# Combine swarm middleware (first) with agent middleware
|
|
210
|
+
agent_with_middleware.middleware = (
|
|
211
|
+
swarm_middleware + agent_with_middleware.middleware
|
|
212
|
+
)
|
|
213
|
+
|
|
214
|
+
logger.debug(
|
|
215
|
+
"Merged middleware for agent",
|
|
216
|
+
agent=registered_agent.name,
|
|
217
|
+
swarm_middleware_count=len(swarm_middleware),
|
|
218
|
+
agent_middleware_count=len(registered_agent.middleware),
|
|
219
|
+
total_middleware_count=len(agent_with_middleware.middleware),
|
|
220
|
+
)
|
|
221
|
+
else:
|
|
222
|
+
agent_with_middleware = registered_agent
|
|
223
|
+
|
|
183
224
|
agent_subgraph: CompiledStateGraph = create_agent_node(
|
|
184
|
-
agent=
|
|
225
|
+
agent=agent_with_middleware,
|
|
185
226
|
memory=memory,
|
|
186
227
|
chat_history=config.app.chat_history,
|
|
187
228
|
additional_tools=handoff_tools,
|
|
188
229
|
)
|
|
189
230
|
agent_subgraphs[registered_agent.name] = agent_subgraph
|
|
190
|
-
logger.debug(
|
|
231
|
+
logger.debug(
|
|
232
|
+
"Created swarm agent subgraph",
|
|
233
|
+
agent=registered_agent.name,
|
|
234
|
+
handoffs_count=len(handoff_tools),
|
|
235
|
+
)
|
|
191
236
|
|
|
192
237
|
# Set up memory store and checkpointer
|
|
193
238
|
store: BaseStore | None = create_store(orchestration)
|