dao-ai 0.0.28__py3-none-any.whl → 0.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/__init__.py +29 -0
- dao_ai/agent_as_code.py +2 -5
- dao_ai/cli.py +342 -58
- dao_ai/config.py +1610 -380
- dao_ai/genie/__init__.py +38 -0
- dao_ai/genie/cache/__init__.py +43 -0
- dao_ai/genie/cache/base.py +72 -0
- dao_ai/genie/cache/core.py +79 -0
- dao_ai/genie/cache/lru.py +347 -0
- dao_ai/genie/cache/semantic.py +970 -0
- dao_ai/genie/core.py +35 -0
- dao_ai/graph.py +27 -253
- dao_ai/hooks/__init__.py +9 -6
- dao_ai/hooks/core.py +27 -195
- dao_ai/logging.py +56 -0
- dao_ai/memory/__init__.py +10 -0
- dao_ai/memory/core.py +65 -30
- dao_ai/memory/databricks.py +402 -0
- dao_ai/memory/postgres.py +79 -38
- dao_ai/messages.py +6 -4
- dao_ai/middleware/__init__.py +158 -0
- dao_ai/middleware/assertions.py +806 -0
- dao_ai/middleware/base.py +50 -0
- dao_ai/middleware/context_editing.py +230 -0
- dao_ai/middleware/core.py +67 -0
- dao_ai/middleware/guardrails.py +420 -0
- dao_ai/middleware/human_in_the_loop.py +233 -0
- dao_ai/middleware/message_validation.py +586 -0
- dao_ai/middleware/model_call_limit.py +77 -0
- dao_ai/middleware/model_retry.py +121 -0
- dao_ai/middleware/pii.py +157 -0
- dao_ai/middleware/summarization.py +197 -0
- dao_ai/middleware/tool_call_limit.py +210 -0
- dao_ai/middleware/tool_retry.py +174 -0
- dao_ai/models.py +1306 -114
- dao_ai/nodes.py +240 -161
- dao_ai/optimization.py +674 -0
- dao_ai/orchestration/__init__.py +52 -0
- dao_ai/orchestration/core.py +294 -0
- dao_ai/orchestration/supervisor.py +279 -0
- dao_ai/orchestration/swarm.py +271 -0
- dao_ai/prompts.py +128 -31
- dao_ai/providers/databricks.py +584 -601
- dao_ai/state.py +157 -21
- dao_ai/tools/__init__.py +13 -5
- dao_ai/tools/agent.py +1 -3
- dao_ai/tools/core.py +64 -11
- dao_ai/tools/email.py +232 -0
- dao_ai/tools/genie.py +144 -294
- dao_ai/tools/mcp.py +223 -155
- dao_ai/tools/memory.py +50 -0
- dao_ai/tools/python.py +9 -14
- dao_ai/tools/search.py +14 -0
- dao_ai/tools/slack.py +22 -10
- dao_ai/tools/sql.py +202 -0
- dao_ai/tools/time.py +30 -7
- dao_ai/tools/unity_catalog.py +165 -88
- dao_ai/tools/vector_search.py +331 -221
- dao_ai/utils.py +166 -20
- dao_ai/vector_search.py +37 -0
- dao_ai-0.1.5.dist-info/METADATA +489 -0
- dao_ai-0.1.5.dist-info/RECORD +70 -0
- dao_ai/chat_models.py +0 -204
- dao_ai/guardrails.py +0 -112
- dao_ai/tools/human_in_the_loop.py +0 -100
- dao_ai-0.0.28.dist-info/METADATA +0 -1168
- dao_ai-0.0.28.dist-info/RECORD +0 -41
- {dao_ai-0.0.28.dist-info → dao_ai-0.1.5.dist-info}/WHEEL +0 -0
- {dao_ai-0.0.28.dist-info → dao_ai-0.1.5.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.0.28.dist-info → dao_ai-0.1.5.dist-info}/licenses/LICENSE +0 -0
dao_ai/middleware/guardrails.py
@@ -0,0 +1,420 @@
+"""
+Guardrail middleware for DAO AI agents.
+
+This module provides middleware implementations for applying guardrails
+to agent responses, including LLM-based judging and content validation.
+
+Factory functions are provided for consistent configuration via the
+DAO AI middleware factory pattern.
+"""
+
+from typing import Any, Optional
+
+from langchain.agents.middleware import hook_config
+from langchain_core.language_models import LanguageModelLike
+from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
+from langgraph.runtime import Runtime
+from loguru import logger
+from openevals.llm import create_llm_as_judge
+
+from dao_ai.messages import last_ai_message, last_human_message
+from dao_ai.middleware.base import AgentMiddleware
+from dao_ai.state import AgentState, Context
+
+
+def _extract_text_content(message: BaseMessage) -> str:
+    """
+    Extract text content from a message, handling both string and list formats.
+
+    Args:
+        message: The message to extract text from
+
+    Returns:
+        The extracted text content as a string
+    """
+    content = message.content
+
+    if isinstance(content, str):
+        return content
+    elif isinstance(content, list):
+        # Extract text from content blocks (e.g., Claude's structured content)
+        text_parts = []
+        for block in content:
+            if isinstance(block, dict) and block.get("type") == "text":
+                text_parts.append(block.get("text", ""))
+            elif isinstance(block, str):
+                text_parts.append(block)
+        return " ".join(text_parts)
+    else:
+        return str(content)
+
+
+__all__ = [
+    "GuardrailMiddleware",
+    "ContentFilterMiddleware",
+    "SafetyGuardrailMiddleware",
+    "create_guardrail_middleware",
+    "create_content_filter_middleware",
+    "create_safety_guardrail_middleware",
+]
+
+
+class GuardrailMiddleware(AgentMiddleware[AgentState, Context]):
+    """
+    Middleware that applies LLM-based guardrails to agent responses.
+
+    Uses an LLM judge to evaluate responses against a prompt/criteria and
+    can request improvements if the response doesn't meet the criteria.
+
+    This is equivalent to the previous reflection_guardrail pattern but
+    implemented as middleware for better composability.
+
+    Args:
+        name: Name identifying this guardrail
+        model: The LLM to use for evaluation
+        prompt: The evaluation prompt/criteria
+        num_retries: Maximum number of retry attempts (default: 3)
+    """
+
+    def __init__(
+        self,
+        name: str,
+        model: LanguageModelLike,
+        prompt: str,
+        num_retries: int = 3,
+    ):
+        super().__init__()
+        self.guardrail_name = name
+        self.model = model
+        self.prompt = prompt
+        self.num_retries = num_retries
+        self._retry_count = 0
+
+    @property
+    def name(self) -> str:
+        """Return the guardrail name for middleware identification."""
+        return self.guardrail_name
+
+    def after_model(
+        self, state: AgentState, runtime: Runtime[Context]
+    ) -> dict[str, Any] | None:
+        """
+        Evaluate the model's response using an LLM judge.
+
+        If the response doesn't meet the guardrail criteria, returns a
+        HumanMessage with feedback to trigger a retry.
+        """
+        messages: list[BaseMessage] = state.get("messages", [])
+
+        if not messages:
+            return None
+
+        ai_message: AIMessage | None = last_ai_message(messages)
+        human_message: HumanMessage | None = last_human_message(messages)
+
+        if not ai_message or not human_message:
+            return None
+
+        # Skip evaluation if the AI message has tool calls (not the final response yet)
+        if ai_message.tool_calls:
+            logger.trace(
+                "Guardrail skipping evaluation - AI message contains tool calls",
+                guardrail_name=self.guardrail_name,
+            )
+            return None
+
+        # Skip evaluation if the AI message has no content to evaluate
+        if not ai_message.content:
+            logger.trace(
+                "Guardrail skipping evaluation - AI message has no content",
+                guardrail_name=self.guardrail_name,
+            )
+            return None
+
+        # Extract text content from messages (handles both string and structured content)
+        human_content = _extract_text_content(human_message)
+        ai_content = _extract_text_content(ai_message)
+
+        logger.debug(
+            "Evaluating response with guardrail",
+            guardrail_name=self.guardrail_name,
+            input_length=len(human_content),
+            output_length=len(ai_content),
+        )
+
+        evaluator = create_llm_as_judge(
+            prompt=self.prompt,
+            judge=self.model,
+        )
+
+        eval_result = evaluator(inputs=human_content, outputs=ai_content)
+
+        if eval_result["score"]:
+            logger.debug(
+                "Response approved by guardrail",
+                guardrail_name=self.guardrail_name,
+                comment=eval_result["comment"],
+            )
+            self._retry_count = 0
+            return None
+        else:
+            self._retry_count += 1
+            comment: str = eval_result["comment"]
+
+            if self._retry_count >= self.num_retries:
+                logger.warning(
+                    "Guardrail failed - max retries reached",
+                    guardrail_name=self.guardrail_name,
+                    retry_count=self._retry_count,
+                    max_retries=self.num_retries,
+                    critique=comment,
+                )
+                self._retry_count = 0
+
+                # Add a final AI message to inform the user of the guardrail failure
+                failure_message = (
+                    f"⚠️ **Quality Check Failed**\n\n"
+                    f"The response did not meet the '{self.guardrail_name}' quality standards "
+                    f"after {self.num_retries} attempts.\n\n"
+                    f"**Issue:** {comment}\n\n"
+                    f"The best available response has been provided, but please be aware it may not fully meet quality expectations."
+                )
+                return {"messages": [AIMessage(content=failure_message)]}
+
+            logger.warning(
+                "Guardrail requested improvements",
+                guardrail_name=self.guardrail_name,
+                retry=self._retry_count,
+                max_retries=self.num_retries,
+                critique=comment,
+            )
+
+            content: str = "\n".join([str(human_message.content), comment])
+            return {"messages": [HumanMessage(content=content)]}
+
+
+class ContentFilterMiddleware(AgentMiddleware[AgentState, Context]):
+    """
+    Middleware that filters responses containing banned keywords.
+
+    This is a deterministic guardrail that blocks responses containing
+    specified keywords.
+
+    Args:
+        banned_keywords: List of keywords to block
+        block_message: Message to return when content is blocked
+    """
+
+    def __init__(
+        self,
+        banned_keywords: list[str],
+        block_message: str = "I cannot provide that response. Please rephrase your request.",
+    ):
+        super().__init__()
+        self.banned_keywords = [kw.lower() for kw in banned_keywords]
+        self.block_message = block_message
+
+    @hook_config(can_jump_to=["end"])
+    def before_agent(
+        self, state: AgentState, runtime: Runtime[Context]
+    ) -> dict[str, Any] | None:
+        """Block requests containing banned keywords."""
+        messages: list[BaseMessage] = state.get("messages", [])
+
+        if not messages:
+            return None
+
+        first_message = messages[0]
+        if not isinstance(first_message, HumanMessage):
+            return None
+
+        content = str(first_message.content).lower()
+
+        for keyword in self.banned_keywords:
+            if keyword in content:
+                logger.warning(f"Content filter blocked request containing '{keyword}'")
+                return {
+                    "messages": [AIMessage(content=self.block_message)],
+                    "jump_to": "end",
+                }
+
+        return None
+
+    def after_model(
+        self, state: AgentState, runtime: Runtime[Context]
+    ) -> dict[str, Any] | None:
+        """Block responses containing banned keywords."""
+        messages: list[BaseMessage] = state.get("messages", [])
+
+        if not messages:
+            return None
+
+        last_message: AIMessage | None = last_ai_message(messages)
+        if not last_message:
+            return None
+
+        content = str(last_message.content).lower()
+
+        for keyword in self.banned_keywords:
+            if keyword in content:
+                logger.warning("Content filter blocked response", keyword=keyword)
+                # Modify the last message content
+                last_message.content = self.block_message
+                return None
+
+        return None
+
+
+class SafetyGuardrailMiddleware(AgentMiddleware[AgentState, Context]):
+    """
+    Model-based safety guardrail middleware.
+
+    Uses an LLM to evaluate whether responses are safe and appropriate.
+    This is applied after the agent completes to validate the final output.
+
+    Args:
+        safety_model: The LLM to use for safety evaluation
+    """
+
+    def __init__(self, safety_model: Optional[LanguageModelLike] = None):
+        super().__init__()
+        self._safety_model = safety_model
+
+    @property
+    def safety_model(self) -> LanguageModelLike:
+        """Lazily initialize the safety model."""
+        if self._safety_model is None:
+            from langchain.chat_models import init_chat_model
+
+            self._safety_model = init_chat_model("gpt-4o-mini")
+        return self._safety_model
+
+    def after_agent(
+        self, state: AgentState, runtime: Runtime[Context]
+    ) -> dict[str, Any] | None:
+        """Evaluate response safety before returning to user."""
+        messages: list[BaseMessage] = state.get("messages", [])
+
+        if not messages:
+            return None
+
+        last_message: AIMessage | None = last_ai_message(messages)
+        if not last_message:
+            return None
+
+        safety_prompt = f"""Evaluate if this response is safe and appropriate.
+Respond with only 'SAFE' or 'UNSAFE'.
+
+Response: {last_message.content}"""
+
+        result = self.safety_model.invoke([{"role": "user", "content": safety_prompt}])
+
+        if "UNSAFE" in str(result.content):
+            logger.warning("Safety guardrail blocked unsafe response")
+            last_message.content = (
+                "I cannot provide that response. Please rephrase your request."
+            )
+
+        return None
+
+
+# =============================================================================
+# Factory Functions
+# =============================================================================
+
+
+def create_guardrail_middleware(
+    name: str,
+    model: LanguageModelLike,
+    prompt: str,
+    num_retries: int = 3,
+) -> GuardrailMiddleware:
+    """
+    Create a GuardrailMiddleware instance.
+
+    Factory function for creating LLM-based guardrail middleware that evaluates
+    agent responses against specified criteria using an LLM judge.
+
+    Args:
+        name: Name identifying this guardrail
+        model: The LLM to use for evaluation
+        prompt: The evaluation prompt/criteria
+        num_retries: Maximum number of retry attempts (default: 3)
+
+    Returns:
+        A GuardrailMiddleware instance configured with the specified parameters
+
+    Example:
+        middleware = create_guardrail_middleware(
+            name="tone_check",
+            model=ChatDatabricks(endpoint="databricks-meta-llama-3-3-70b-instruct"),
+            prompt="Evaluate if the response is professional and helpful.",
+            num_retries=2,
+        )
+    """
+    logger.trace("Creating guardrail middleware", guardrail_name=name)
+    return GuardrailMiddleware(
+        name=name,
+        model=model,
+        prompt=prompt,
+        num_retries=num_retries,
+    )
+
+
+def create_content_filter_middleware(
+    banned_keywords: list[str],
+    block_message: str = "I cannot provide that response. Please rephrase your request.",
+) -> ContentFilterMiddleware:
+    """
+    Create a ContentFilterMiddleware instance.
+
+    Factory function for creating deterministic content filter middleware
+    that blocks requests/responses containing banned keywords.
+
+    Args:
+        banned_keywords: List of keywords to block
+        block_message: Message to return when content is blocked
+
+    Returns:
+        A ContentFilterMiddleware instance configured with the specified parameters
+
+    Example:
+        middleware = create_content_filter_middleware(
+            banned_keywords=["password", "secret", "api_key"],
+            block_message="I cannot discuss sensitive credentials.",
+        )
+    """
+    logger.trace(
+        "Creating content filter middleware", keywords_count=len(banned_keywords)
+    )
+    return ContentFilterMiddleware(
+        banned_keywords=banned_keywords,
+        block_message=block_message,
+    )
+
+
+def create_safety_guardrail_middleware(
+    safety_model: Optional[LanguageModelLike] = None,
+) -> SafetyGuardrailMiddleware:
+    """
+    Create a SafetyGuardrailMiddleware instance.
+
+    Factory function for creating model-based safety guardrail middleware
+    that evaluates whether responses are safe and appropriate.
+
+    Args:
+        safety_model: The LLM to use for safety evaluation. If not provided,
+            defaults to gpt-4o-mini.
+
+    Returns:
+        A SafetyGuardrailMiddleware instance configured with the specified model
+
+    Example:
+        from databricks_langchain import ChatDatabricks
+
+        middleware = create_safety_guardrail_middleware(
+            safety_model=ChatDatabricks(endpoint="databricks-meta-llama-3-3-70b-instruct"),
+        )
+    """
+    logger.trace("Creating safety guardrail middleware")
+    return SafetyGuardrailMiddleware(safety_model=safety_model)
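The factory functions above are intended to be composed via the DAO AI middleware factory pattern. As a rough illustration, the sketch below wires the three guardrail middlewares from this file into a LangChain v1-style agent. The `create_agent(..., middleware=[...])` wiring and the tool list are assumptions for illustration and are not part of this diff; the `ChatDatabricks` endpoint name is taken from the docstring examples above.

```python
# Hypothetical composition sketch: only the dao_ai.middleware.guardrails factories
# come from this diff. The agent wiring below is assumed, not shown in the package.
from databricks_langchain import ChatDatabricks
from langchain.agents import create_agent

from dao_ai.middleware.guardrails import (
    create_content_filter_middleware,
    create_guardrail_middleware,
    create_safety_guardrail_middleware,
)

llm = ChatDatabricks(endpoint="databricks-meta-llama-3-3-70b-instruct")

middleware = [
    # Deterministic keyword filter: checks the first user message and each model reply.
    create_content_filter_middleware(banned_keywords=["password", "secret", "api_key"]),
    # LLM-as-judge guardrail: re-prompts the model up to num_retries times on failure.
    create_guardrail_middleware(
        name="tone_check",
        model=llm,
        prompt="Evaluate if the response is professional and helpful.",
        num_retries=2,
    ),
    # Final safety pass after the agent completes.
    create_safety_guardrail_middleware(safety_model=llm),
]

agent = create_agent(model=llm, tools=[], middleware=middleware)
```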
dao_ai/middleware/human_in_the_loop.py
@@ -0,0 +1,233 @@
+"""
+Human-in-the-loop middleware for DAO AI agents.
+
+This module provides utilities for creating HITL middleware from DAO AI configuration.
+It re-exports LangChain's built-in HumanInTheLoopMiddleware.
+
+LangChain's HumanInTheLoopMiddleware automatically:
+- Pauses agent execution for human approval of tool calls
+- Allows humans to approve, edit, or reject tool calls
+- Uses LangGraph's interrupt mechanism for persistence across pauses
+
+Example:
+    from dao_ai.middleware import create_human_in_the_loop_middleware
+
+    middleware = create_human_in_the_loop_middleware(
+        interrupt_on={"send_email": True, "delete_record": True},
+    )
+"""
+
+from typing import Any, Sequence
+
+from langchain.agents.middleware import HumanInTheLoopMiddleware
+from langchain.agents.middleware.human_in_the_loop import (
+    Action,
+    ActionRequest,
+    ApproveDecision,
+    Decision,
+    DecisionType,
+    EditDecision,
+    HITLRequest,
+    HITLResponse,
+    InterruptOnConfig,
+    RejectDecision,
+    ReviewConfig,
+)
+from loguru import logger
+
+from dao_ai.config import HumanInTheLoopModel, ToolModel
+
+__all__ = [
+    # LangChain middleware
+    "HumanInTheLoopMiddleware",
+    # LangChain HITL types
+    "Action",
+    "ActionRequest",
+    "ApproveDecision",
+    "Decision",
+    "DecisionType",
+    "EditDecision",
+    "HITLRequest",
+    "HITLResponse",
+    "InterruptOnConfig",
+    "RejectDecision",
+    "ReviewConfig",
+    # DAO AI helper functions and models
+    "create_human_in_the_loop_middleware",
+    "create_hitl_middleware_from_tool_models",
+]
+
+
+def _hitl_config_to_allowed_decisions(
+    hitl_config: HumanInTheLoopModel,
+) -> list[DecisionType]:
+    """
+    Extract allowed decisions from HumanInTheLoopModel.
+
+    LangChain's HumanInTheLoopMiddleware supports 3 decision types:
+    - "approve": Execute tool with original arguments
+    - "edit": Modify arguments before execution
+    - "reject": Skip execution with optional feedback message
+
+    Args:
+        hitl_config: HumanInTheLoopModel with allowed_decisions
+
+    Returns:
+        List of allowed decision types (e.g., ["approve", "edit", "reject"])
+    """
+    return hitl_config.allowed_decisions  # type: ignore
+
+
+def _config_to_interrupt_on_entry(
+    config: HumanInTheLoopModel | bool,
+) -> dict[str, Any] | bool:
+    """
+    Convert a HITL config value to interrupt_on entry format.
+
+    Args:
+        config: HumanInTheLoopModel, True, or False
+
+    Returns:
+        dict with allowed_decisions and optional description, True, or False
+    """
+    if config is False:
+        return False
+    if config is True:
+        return {"allowed_decisions": ["approve", "edit", "reject"]}
+    if isinstance(config, HumanInTheLoopModel):
+        interrupt_entry: dict[str, Any] = {
+            "allowed_decisions": _hitl_config_to_allowed_decisions(config)
+        }
+        # If review_prompt is provided, use it as the description
+        if config.review_prompt is not None:
+            interrupt_entry["description"] = config.review_prompt
+        return interrupt_entry
+
+    logger.warning(
+        "Unknown HITL config type, defaulting to True",
+        config_type=type(config).__name__,
+    )
+    return True
+
+
+def create_human_in_the_loop_middleware(
+    interrupt_on: dict[str, HumanInTheLoopModel | bool | dict[str, Any]],
+    description_prefix: str = "Tool execution pending approval",
+) -> HumanInTheLoopMiddleware:
+    """
+    Create a HumanInTheLoopMiddleware instance.
+
+    Factory function for creating LangChain's built-in HumanInTheLoopMiddleware.
+    Accepts HumanInTheLoopModel, bool, or raw dict configurations per tool.
+
+    Note: This middleware requires a checkpointer to be configured on the agent.
+
+    Args:
+        interrupt_on: Dictionary mapping tool names to HITL configuration.
+            Each tool can be configured with:
+            - HumanInTheLoopModel: Full configuration with custom settings
+            - True: Enable HITL with default settings (approve, edit, reject)
+            - False: Disable HITL for this tool
+            - dict: Raw interrupt_on config (e.g., {"allowed_decisions": [...]})
+        description_prefix: Message prefix shown when pausing for review
+
+    Returns:
+        A HumanInTheLoopMiddleware instance configured with the specified parameters
+
+    Example:
+        from dao_ai.config import HumanInTheLoopModel
+
+        middleware = create_human_in_the_loop_middleware(
+            interrupt_on={
+                "send_email": HumanInTheLoopModel(review_prompt="Review email"),
+                "delete_record": True,
+                "search": False,
+            },
+        )
+    """
+    # Convert HumanInTheLoopModel entries to dict format
+    normalized_interrupt_on: dict[str, Any] = {}
+    for tool_name, config in interrupt_on.items():
+        if isinstance(config, (HumanInTheLoopModel, bool)):
+            normalized_interrupt_on[tool_name] = _config_to_interrupt_on_entry(config)
+        else:
+            # Already in dict format
+            normalized_interrupt_on[tool_name] = config
+
+    logger.debug(
+        "Creating HITL middleware",
+        tools_count=len(normalized_interrupt_on),
+        tools=list(normalized_interrupt_on.keys()),
+    )
+
+    return HumanInTheLoopMiddleware(
+        interrupt_on=normalized_interrupt_on,
+        description_prefix=description_prefix,
+    )
+
+
+def create_hitl_middleware_from_tool_models(
+    tool_models: Sequence[ToolModel],
+    description_prefix: str = "Tool execution pending approval",
+) -> HumanInTheLoopMiddleware | None:
+    """
+    Create HumanInTheLoopMiddleware from ToolModel configurations.
+
+    Scans tool_models for those with human_in_the_loop configured and
+    creates the appropriate middleware. This is the primary entry point
+    used by the agent node creation.
+
+    Args:
+        tool_models: List of ToolModel configurations from agent config
+        description_prefix: Message prefix shown when pausing for review
+
+    Returns:
+        A HumanInTheLoopMiddleware instance if any tools require approval,
+        None otherwise
+
+    Example:
+        from dao_ai.config import ToolModel, PythonFunctionModel, HumanInTheLoopModel
+
+        tool_models = [
+            ToolModel(
+                name="email_tool",
+                function=PythonFunctionModel(
+                    name="send_email",
+                    human_in_the_loop=HumanInTheLoopModel(
+                        review_prompt="Review this email",
+                    ),
+                ),
+            ),
+        ]
+
+        middleware = create_hitl_middleware_from_tool_models(tool_models)
+    """
+    from dao_ai.config import BaseFunctionModel
+
+    interrupt_on: dict[str, HumanInTheLoopModel] = {}
+
+    for tool_model in tool_models:
+        function = tool_model.function
+
+        if not isinstance(function, BaseFunctionModel):
+            continue
+
+        hitl_config: HumanInTheLoopModel | None = function.human_in_the_loop
+        if not hitl_config:
+            continue
+
+        # Get tool names created by this function
+        for func_tool in function.as_tools():
+            tool_name: str | None = getattr(func_tool, "name", None)
+            if tool_name:
+                interrupt_on[tool_name] = hitl_config
+                logger.trace("Tool configured for HITL", tool_name=tool_name)
+
+    if not interrupt_on:
+        logger.trace("No tools require HITL - returning None")
+        return None
+
+    return create_human_in_the_loop_middleware(
+        interrupt_on=interrupt_on,
+        description_prefix=description_prefix,
+    )
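The HITL factory accepts the dao-ai config model, a boolean, or the raw LangChain entry format per tool. The sketch below (not part of the diff) shows the mixed-configuration form from the docstring example together with the checkpointer requirement the docstring calls out; the `create_agent`/`MemorySaver` wiring and the model endpoint are assumptions for illustration only.

```python
# Hypothetical usage sketch: only create_human_in_the_loop_middleware and
# HumanInTheLoopModel come from this diff. The agent/checkpointer wiring is assumed.
from databricks_langchain import ChatDatabricks
from langchain.agents import create_agent
from langgraph.checkpoint.memory import MemorySaver

from dao_ai.config import HumanInTheLoopModel
from dao_ai.middleware.human_in_the_loop import create_human_in_the_loop_middleware

hitl = create_human_in_the_loop_middleware(
    interrupt_on={
        # Full config: review_prompt becomes the interrupt entry's "description".
        "send_email": HumanInTheLoopModel(review_prompt="Review this email before it is sent"),
        # True expands to {"allowed_decisions": ["approve", "edit", "reject"]}.
        "delete_record": True,
        # False disables HITL for this tool.
        "search": False,
    },
)

agent = create_agent(
    model=ChatDatabricks(endpoint="databricks-meta-llama-3-3-70b-instruct"),
    tools=[],
    middleware=[hitl],
    checkpointer=MemorySaver(),  # interrupts require a checkpointer, per the docstring above
)
```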