dao-ai 0.0.28__py3-none-any.whl → 0.1.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dao_ai/__init__.py +29 -0
- dao_ai/agent_as_code.py +2 -5
- dao_ai/cli.py +342 -58
- dao_ai/config.py +1610 -380
- dao_ai/genie/__init__.py +38 -0
- dao_ai/genie/cache/__init__.py +43 -0
- dao_ai/genie/cache/base.py +72 -0
- dao_ai/genie/cache/core.py +79 -0
- dao_ai/genie/cache/lru.py +347 -0
- dao_ai/genie/cache/semantic.py +970 -0
- dao_ai/genie/core.py +35 -0
- dao_ai/graph.py +27 -253
- dao_ai/hooks/__init__.py +9 -6
- dao_ai/hooks/core.py +27 -195
- dao_ai/logging.py +56 -0
- dao_ai/memory/__init__.py +10 -0
- dao_ai/memory/core.py +65 -30
- dao_ai/memory/databricks.py +402 -0
- dao_ai/memory/postgres.py +79 -38
- dao_ai/messages.py +6 -4
- dao_ai/middleware/__init__.py +158 -0
- dao_ai/middleware/assertions.py +806 -0
- dao_ai/middleware/base.py +50 -0
- dao_ai/middleware/context_editing.py +230 -0
- dao_ai/middleware/core.py +67 -0
- dao_ai/middleware/guardrails.py +420 -0
- dao_ai/middleware/human_in_the_loop.py +233 -0
- dao_ai/middleware/message_validation.py +586 -0
- dao_ai/middleware/model_call_limit.py +77 -0
- dao_ai/middleware/model_retry.py +121 -0
- dao_ai/middleware/pii.py +157 -0
- dao_ai/middleware/summarization.py +197 -0
- dao_ai/middleware/tool_call_limit.py +210 -0
- dao_ai/middleware/tool_retry.py +174 -0
- dao_ai/models.py +1306 -114
- dao_ai/nodes.py +240 -161
- dao_ai/optimization.py +674 -0
- dao_ai/orchestration/__init__.py +52 -0
- dao_ai/orchestration/core.py +294 -0
- dao_ai/orchestration/supervisor.py +279 -0
- dao_ai/orchestration/swarm.py +271 -0
- dao_ai/prompts.py +128 -31
- dao_ai/providers/databricks.py +584 -601
- dao_ai/state.py +157 -21
- dao_ai/tools/__init__.py +13 -5
- dao_ai/tools/agent.py +1 -3
- dao_ai/tools/core.py +64 -11
- dao_ai/tools/email.py +232 -0
- dao_ai/tools/genie.py +144 -294
- dao_ai/tools/mcp.py +223 -155
- dao_ai/tools/memory.py +50 -0
- dao_ai/tools/python.py +9 -14
- dao_ai/tools/search.py +14 -0
- dao_ai/tools/slack.py +22 -10
- dao_ai/tools/sql.py +202 -0
- dao_ai/tools/time.py +30 -7
- dao_ai/tools/unity_catalog.py +165 -88
- dao_ai/tools/vector_search.py +331 -221
- dao_ai/utils.py +166 -20
- dao_ai/vector_search.py +37 -0
- dao_ai-0.1.5.dist-info/METADATA +489 -0
- dao_ai-0.1.5.dist-info/RECORD +70 -0
- dao_ai/chat_models.py +0 -204
- dao_ai/guardrails.py +0 -112
- dao_ai/tools/human_in_the_loop.py +0 -100
- dao_ai-0.0.28.dist-info/METADATA +0 -1168
- dao_ai-0.0.28.dist-info/RECORD +0 -41
- {dao_ai-0.0.28.dist-info → dao_ai-0.1.5.dist-info}/WHEEL +0 -0
- {dao_ai-0.0.28.dist-info → dao_ai-0.1.5.dist-info}/entry_points.txt +0 -0
- {dao_ai-0.0.28.dist-info → dao_ai-0.1.5.dist-info}/licenses/LICENSE +0 -0
dao_ai/middleware/assertions.py (new file)
@@ -0,0 +1,806 @@
"""
DSPy-style assertion middleware for DAO AI agents.

This module provides middleware implementations inspired by DSPy's assertion
mechanisms (dspy.Assert, dspy.Suggest, dspy.Refine) but implemented natively
in the LangChain middleware pattern for optimal latency and streaming support.

Key concepts:
- Assert: Hard constraint - retry until satisfied or fail after max attempts
- Suggest: Soft constraint - provide feedback but don't block execution
- Refine: Iterative improvement - run multiple times, select best result

These work with LangChain's middleware hooks (after_model) to validate and
improve agent outputs without requiring the DSPy library.
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Callable, Optional, TypeVar

from langchain_core.language_models import LanguageModelLike
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langgraph.runtime import Runtime
from loguru import logger

from dao_ai.messages import last_ai_message, last_human_message
from dao_ai.middleware.base import AgentMiddleware
from dao_ai.state import AgentState, Context

__all__ = [
    # Types
    "Constraint",
    "ConstraintResult",
    # Middleware classes
    "AssertMiddleware",
    "SuggestMiddleware",
    "RefineMiddleware",
    # Factory functions
    "create_assert_middleware",
    "create_suggest_middleware",
    "create_refine_middleware",
]

T = TypeVar("T")


@dataclass
class ConstraintResult:
    """Result of evaluating a constraint against model output.

    Attributes:
        passed: Whether the constraint was satisfied
        feedback: Feedback message explaining the result
        score: Optional numeric score (0.0 to 1.0)
        metadata: Additional metadata from the evaluation
    """

    passed: bool
    feedback: str = ""
    score: Optional[float] = None
    metadata: dict[str, Any] = field(default_factory=dict)


class Constraint(ABC):
    """Base class for constraints that can be evaluated against model outputs.

    Constraints can be:
    - Callable functions: (response: str, context: dict) -> ConstraintResult | bool
    - LLM-based evaluators: Use a judge model to evaluate responses
    - Rule-based: Deterministic checks like regex, keywords, length

    Example:
        class LengthConstraint(Constraint):
            def __init__(self, min_length: int, max_length: int):
                self.min_length = min_length
                self.max_length = max_length

            def evaluate(self, response: str, context: dict) -> ConstraintResult:
                length = len(response)
                if self.min_length <= length <= self.max_length:
                    return ConstraintResult(passed=True, feedback="Length OK")
                return ConstraintResult(
                    passed=False,
                    feedback=f"Response length {length} not in range [{self.min_length}, {self.max_length}]"
                )
    """

    @abstractmethod
    def evaluate(self, response: str, context: dict[str, Any]) -> ConstraintResult:
        """Evaluate the constraint against a response.

        Args:
            response: The model's response text
            context: Additional context (user input, state, etc.)

        Returns:
            ConstraintResult indicating whether constraint was satisfied
        """
        ...

    @property
    def name(self) -> str:
        """Name of this constraint for logging."""
        return self.__class__.__name__


class FunctionConstraint(Constraint):
    """Constraint that wraps a callable function.

    The function can return either:
    - bool: True = passed, False = failed with default feedback
    - ConstraintResult: Full result with feedback and score
    """

    def __init__(
        self,
        func: Callable[[str, dict[str, Any]], ConstraintResult | bool],
        name: Optional[str] = None,
        default_feedback: str = "Constraint not satisfied",
    ):
        self._func = func
        self._name = name or func.__name__
        self._default_feedback = default_feedback

    @property
    def name(self) -> str:
        return self._name

    def evaluate(self, response: str, context: dict[str, Any]) -> ConstraintResult:
        result = self._func(response, context)
        if isinstance(result, bool):
            return ConstraintResult(
                passed=result,
                feedback="" if result else self._default_feedback,
            )
        return result


class LLMConstraint(Constraint):
    """Constraint that uses an LLM judge to evaluate responses.

    Similar to LLM-as-judge evaluation but returns a ConstraintResult.
    """

    def __init__(
        self,
        model: LanguageModelLike,
        prompt: str,
        name: Optional[str] = None,
        threshold: float = 0.5,
    ):
        """Initialize LLM-based constraint.

        Args:
            model: LLM to use for evaluation
            prompt: Evaluation prompt. Should include {response} and optionally {input} placeholders.
            name: Name for logging
            threshold: Score threshold for passing (0.0-1.0)
        """
        self._model = model
        self._prompt = prompt
        self._name = name or "LLMConstraint"
        self._threshold = threshold

    @property
    def name(self) -> str:
        return self._name

    def evaluate(self, response: str, context: dict[str, Any]) -> ConstraintResult:
        user_input = context.get("input", "")

        eval_prompt = self._prompt.format(response=response, input=user_input)

        result = self._model.invoke(
            [
                {
                    "role": "system",
                    "content": (
                        "You are an evaluation assistant. Evaluate the response and reply with:\n"
                        "PASS: <feedback> if the constraint is satisfied\n"
                        "FAIL: <feedback> if the constraint is not satisfied\n"
                        "Be concise."
                    ),
                },
                {"role": "user", "content": eval_prompt},
            ]
        )

        content = str(result.content).strip()

        if content.upper().startswith("PASS"):
            feedback = content[5:].strip(": ").strip()
            return ConstraintResult(passed=True, feedback=feedback, score=1.0)
        elif content.upper().startswith("FAIL"):
            feedback = content[5:].strip(": ").strip()
            return ConstraintResult(passed=False, feedback=feedback, score=0.0)
        else:
            # Try to interpret as pass/fail
            is_pass = any(
                word in content.lower()
                for word in ["yes", "pass", "correct", "good", "valid"]
            )
            return ConstraintResult(passed=is_pass, feedback=content)


class KeywordConstraint(Constraint):
    """Simple constraint that checks for required/banned keywords."""

    def __init__(
        self,
        required_keywords: Optional[list[str]] = None,
        banned_keywords: Optional[list[str]] = None,
        case_sensitive: bool = False,
        name: Optional[str] = None,
    ):
        self._required = required_keywords or []
        self._banned = banned_keywords or []
        self._case_sensitive = case_sensitive
        self._name = name or "KeywordConstraint"

    @property
    def name(self) -> str:
        return self._name

    def evaluate(self, response: str, context: dict[str, Any]) -> ConstraintResult:
        check_response = response if self._case_sensitive else response.lower()

        # Check banned keywords
        for keyword in self._banned:
            check_keyword = keyword if self._case_sensitive else keyword.lower()
            if check_keyword in check_response:
                return ConstraintResult(
                    passed=False,
                    feedback=f"Response contains banned keyword: '{keyword}'",
                )

        # Check required keywords
        for keyword in self._required:
            check_keyword = keyword if self._case_sensitive else keyword.lower()
            if check_keyword not in check_response:
                return ConstraintResult(
                    passed=False,
                    feedback=f"Response missing required keyword: '{keyword}'",
                )

        return ConstraintResult(
            passed=True, feedback="All keyword constraints satisfied"
        )


class LengthConstraint(Constraint):
    """Constraint that checks response length."""

    def __init__(
        self,
        min_length: Optional[int] = None,
        max_length: Optional[int] = None,
        unit: str = "chars",  # "chars", "words", "sentences"
        name: Optional[str] = None,
    ):
        self._min_length = min_length
        self._max_length = max_length
        self._unit = unit
        self._name = name or "LengthConstraint"

    @property
    def name(self) -> str:
        return self._name

    def evaluate(self, response: str, context: dict[str, Any]) -> ConstraintResult:
        if self._unit == "chars":
            length = len(response)
        elif self._unit == "words":
            length = len(response.split())
        elif self._unit == "sentences":
            length = response.count(".") + response.count("!") + response.count("?")
        else:
            length = len(response)

        if self._min_length is not None and length < self._min_length:
            return ConstraintResult(
                passed=False,
                feedback=f"Response too short: {length} {self._unit} (min: {self._min_length})",
                score=length / self._min_length if self._min_length > 0 else 0.0,
            )

        if self._max_length is not None and length > self._max_length:
            return ConstraintResult(
                passed=False,
                feedback=f"Response too long: {length} {self._unit} (max: {self._max_length})",
                score=self._max_length / length if length > 0 else 0.0,
            )

        return ConstraintResult(
            passed=True,
            feedback=f"Length OK: {length} {self._unit}",
            score=1.0,
        )


# =============================================================================
# AssertMiddleware - Hard constraint with retry (like dspy.Assert)
# =============================================================================


class AssertMiddleware(AgentMiddleware[AgentState, Context]):
    """
    Hard constraint middleware that retries until satisfied.

    Inspired by dspy.Assert - if the constraint fails, the middleware
    adds feedback to the conversation and requests a retry. If max
    retries are exhausted, it raises an error or returns a fallback.

    Args:
        constraint: The constraint to enforce
        max_retries: Maximum retry attempts before giving up
        on_failure: What to do when max retries exhausted:
            - "error": Raise ValueError (default)
            - "fallback": Return fallback_message
            - "pass": Let the response through anyway
        fallback_message: Message to return if on_failure="fallback"

    Example:
        middleware = AssertMiddleware(
            constraint=LengthConstraint(min_length=100),
            max_retries=3,
            on_failure="fallback",
            fallback_message="Unable to generate a complete response."
        )
    """

    def __init__(
        self,
        constraint: Constraint,
        max_retries: int = 3,
        on_failure: str = "error",  # "error", "fallback", "pass"
        fallback_message: str = "Unable to generate a valid response.",
    ):
        super().__init__()
        self.constraint = constraint
        self.max_retries = max_retries
        self.on_failure = on_failure
        self.fallback_message = fallback_message
        self._retry_count = 0

    def after_model(
        self, state: AgentState, runtime: Runtime[Context]
    ) -> dict[str, Any] | None:
        """Evaluate constraint and retry if not satisfied."""
        messages: list[BaseMessage] = state.get("messages", [])

        if not messages:
            return None

        ai_message: AIMessage | None = last_ai_message(messages)
        human_message: HumanMessage | None = last_human_message(messages)

        if not ai_message:
            return None

        response = str(ai_message.content)
        user_input = str(human_message.content) if human_message else ""

        context = {
            "input": user_input,
            "messages": messages,
            "runtime": runtime,
        }

        logger.trace(
            "Evaluating Assert constraint", constraint_name=self.constraint.name
        )

        result = self.constraint.evaluate(response, context)

        if result.passed:
            logger.trace(
                "Assert constraint passed", constraint_name=self.constraint.name
            )
            self._retry_count = 0
            return None

        # Constraint failed
        self._retry_count += 1
        logger.warning(
            "Assert constraint failed",
            constraint_name=self.constraint.name,
            attempt=self._retry_count,
            max_retries=self.max_retries,
            feedback=result.feedback,
        )

        if self._retry_count >= self.max_retries:
            self._retry_count = 0

            if self.on_failure == "error":
                raise ValueError(
                    f"Assert constraint '{self.constraint.name}' failed after "
                    f"{self.max_retries} retries: {result.feedback}"
                )
            elif self.on_failure == "fallback":
                ai_message.content = self.fallback_message
                return None
            else:  # "pass"
                logger.warning(
                    "Assert constraint failed but passing through",
                    constraint_name=self.constraint.name,
                )
                return None

        # Add feedback and retry
        retry_prompt = (
            f"Your previous response did not meet the requirements:\n"
            f"{result.feedback}\n\n"
            f"Please try again with the original request:\n{user_input}"
        )
        return {"messages": [HumanMessage(content=retry_prompt)]}


# =============================================================================
# SuggestMiddleware - Soft constraint with feedback (like dspy.Suggest)
# =============================================================================


class SuggestMiddleware(AgentMiddleware[AgentState, Context]):
    """
    Soft constraint middleware that provides feedback without blocking.

    Inspired by dspy.Suggest - evaluates the constraint and logs feedback
    but does not retry or block the response. The feedback is captured
    in metadata for observability but the response passes through.

    Optionally, can request one improvement attempt if constraint fails.

    Args:
        constraint: The constraint to evaluate
        allow_one_retry: If True, request one improvement attempt on failure
        log_level: Log level for feedback ("warning", "info", "debug")

    Example:
        middleware = SuggestMiddleware(
            constraint=LLMConstraint(
                model=ChatDatabricks(...),
                prompt="Check if response is professional: {response}"
            ),
            allow_one_retry=True,
        )
    """

    def __init__(
        self,
        constraint: Constraint,
        allow_one_retry: bool = False,
        log_level: str = "warning",
    ):
        super().__init__()
        self.constraint = constraint
        self.allow_one_retry = allow_one_retry
        self.log_level = log_level
        self._has_retried = False

    def after_model(
        self, state: AgentState, runtime: Runtime[Context]
    ) -> dict[str, Any] | None:
        """Evaluate constraint and log feedback."""
        messages: list[BaseMessage] = state.get("messages", [])

        if not messages:
            return None

        ai_message: AIMessage | None = last_ai_message(messages)
        human_message: HumanMessage | None = last_human_message(messages)

        if not ai_message:
            return None

        response = str(ai_message.content)
        user_input = str(human_message.content) if human_message else ""

        context = {
            "input": user_input,
            "messages": messages,
            "runtime": runtime,
        }

        logger.trace(
            "Evaluating Suggest constraint", constraint_name=self.constraint.name
        )

        result = self.constraint.evaluate(response, context)

        if result.passed:
            logger.trace(
                "Suggest constraint passed", constraint_name=self.constraint.name
            )
            self._has_retried = False
            return None

        # Log feedback based on configured level
        if self.log_level == "warning":
            logger.warning(
                "Suggest constraint feedback",
                constraint_name=self.constraint.name,
                feedback=result.feedback,
            )
        elif self.log_level == "info":
            logger.info(
                "Suggest constraint feedback",
                constraint_name=self.constraint.name,
                feedback=result.feedback,
            )
        else:
            logger.debug(
                "Suggest constraint feedback",
                constraint_name=self.constraint.name,
                feedback=result.feedback,
            )

        # Optionally request one improvement
        if self.allow_one_retry and not self._has_retried:
            self._has_retried = True
            retry_prompt = (
                f"Consider this feedback for your response:\n"
                f"{result.feedback}\n\n"
                f"Original request: {user_input}\n"
                f"Please provide an improved response."
            )
            return {"messages": [HumanMessage(content=retry_prompt)]}

        # Pass through without modification
        self._has_retried = False
        return None


# =============================================================================
# RefineMiddleware - Iterative improvement (like dspy.Refine)
# =============================================================================


class RefineMiddleware(AgentMiddleware[AgentState, Context]):
    """
    Iterative refinement middleware that improves responses.

    Inspired by dspy.Refine - runs the response through multiple iterations,
    using a reward function to score each attempt. Selects the best response
    or stops early if a threshold is reached.

    Since middleware runs in the agent loop, this works by:
    1. Scoring the current response
    2. If below threshold and iterations remain, requesting improvement
    3. Tracking the best response across iterations
    4. Returning the best response when done

    Args:
        reward_fn: Function that scores a response (returns 0.0 to 1.0)
        threshold: Score threshold to stop early (default 0.8)
        max_iterations: Maximum improvement iterations (default 3)
        select_best: If True, track and return best response; else use last

    Example:
        def score_response(response: str, context: dict) -> float:
            # Score based on helpfulness, completeness, etc.
            return 0.85

        middleware = RefineMiddleware(
            reward_fn=score_response,
            threshold=0.9,
            max_iterations=3,
        )
    """

    def __init__(
        self,
        reward_fn: Callable[[str, dict[str, Any]], float],
        threshold: float = 0.8,
        max_iterations: int = 3,
        select_best: bool = True,
    ):
        super().__init__()
        self.reward_fn = reward_fn
        self.threshold = threshold
        self.max_iterations = max_iterations
        self.select_best = select_best
        self._iteration = 0
        self._best_score = 0.0
        self._best_response: Optional[str] = None

    def after_model(
        self, state: AgentState, runtime: Runtime[Context]
    ) -> dict[str, Any] | None:
        """Score response and request improvement if needed."""
        messages: list[BaseMessage] = state.get("messages", [])

        if not messages:
            return None

        ai_message: AIMessage | None = last_ai_message(messages)
        human_message: HumanMessage | None = last_human_message(messages)

        if not ai_message:
            return None

        response = str(ai_message.content)
        user_input = str(human_message.content) if human_message else ""

        context = {
            "input": user_input,
            "messages": messages,
            "runtime": runtime,
            "iteration": self._iteration,
        }

        score: float = self.reward_fn(response, context)
        self._iteration += 1

        logger.debug(
            "Refine iteration",
            iteration=self._iteration,
            max_iterations=self.max_iterations,
            score=f"{score:.3f}",
            threshold=self.threshold,
        )

        # Track best response
        if self.select_best and score > self._best_score:
            self._best_score = score
            self._best_response = response

        # Check if we should stop
        if score >= self.threshold:
            logger.debug(
                "Refine threshold reached",
                score=f"{score:.3f}",
                threshold=self.threshold,
            )
            self._reset()
            return None

        if self._iteration >= self.max_iterations:
            logger.debug(
                "Refine max iterations reached", best_score=f"{self._best_score:.3f}"
            )
            # Use best response if tracking
            if self.select_best and self._best_response:
                ai_message.content = self._best_response
            self._reset()
            return None

        # Request improvement
        feedback = f"Current response scored {score:.2f}/{self.threshold:.2f}."
        if score < 0.5:
            feedback += " The response needs significant improvement."
        elif score < self.threshold:
            feedback += " The response is good but could be better."

        retry_prompt = f"{feedback}\n\nPlease improve your response to:\n{user_input}"
        return {"messages": [HumanMessage(content=retry_prompt)]}

    def _reset(self) -> None:
        """Reset iteration state for next invocation."""
        self._iteration = 0
        self._best_score = 0.0
        self._best_response = None


# =============================================================================
# Factory Functions
# =============================================================================


def create_assert_middleware(
    constraint: Constraint | Callable[[str, dict[str, Any]], ConstraintResult | bool],
    max_retries: int = 3,
    on_failure: str = "error",
    fallback_message: str = "Unable to generate a valid response.",
    name: Optional[str] = None,
) -> AssertMiddleware:
    """
    Create an AssertMiddleware (hard constraint with retry).

    Like dspy.Assert - enforces a constraint and retries if not satisfied.

    Args:
        constraint: Constraint object or callable function
        max_retries: Maximum retry attempts
        on_failure: "error", "fallback", or "pass"
        fallback_message: Message if on_failure="fallback"
        name: Name for function constraints

    Returns:
        AssertMiddleware configured with the constraint

    Example:
        # Using a Constraint class
        middleware = create_assert_middleware(
            constraint=LengthConstraint(min_length=100),
            max_retries=3,
        )

        # Using a function
        def has_sources(response: str, ctx: dict) -> bool:
            return "[source]" in response.lower() or "reference" in response.lower()

        middleware = create_assert_middleware(
            constraint=has_sources,
            max_retries=2,
            on_failure="fallback",
            fallback_message="I couldn't find relevant sources.",
        )
    """
    if callable(constraint) and not isinstance(constraint, Constraint):
        constraint = FunctionConstraint(constraint, name=name)

    return AssertMiddleware(
        constraint=constraint,
        max_retries=max_retries,
        on_failure=on_failure,
        fallback_message=fallback_message,
    )


def create_suggest_middleware(
    constraint: Constraint | Callable[[str, dict[str, Any]], ConstraintResult | bool],
    allow_one_retry: bool = False,
    log_level: str = "warning",
    name: Optional[str] = None,
) -> SuggestMiddleware:
    """
    Create a SuggestMiddleware (soft constraint with feedback).

    Like dspy.Suggest - evaluates constraint and logs feedback without blocking.

    Args:
        constraint: Constraint object or callable function
        allow_one_retry: Request one improvement attempt on failure
        log_level: "warning", "info", or "debug"
        name: Name for function constraints

    Returns:
        SuggestMiddleware configured with the constraint

    Example:
        def is_professional(response: str, ctx: dict) -> ConstraintResult:
            informal = ["lol", "omg", "btw", "gonna"]
            found = [w for w in informal if w in response.lower()]
            if found:
                return ConstraintResult(
                    passed=False,
                    feedback=f"Response contains informal language: {found}"
                )
            return ConstraintResult(passed=True)

        middleware = create_suggest_middleware(
            constraint=is_professional,
            allow_one_retry=True,
        )
    """
    if callable(constraint) and not isinstance(constraint, Constraint):
        constraint = FunctionConstraint(constraint, name=name)

    return SuggestMiddleware(
        constraint=constraint,
        allow_one_retry=allow_one_retry,
        log_level=log_level,
    )


def create_refine_middleware(
    reward_fn: Callable[[str, dict[str, Any]], float],
    threshold: float = 0.8,
    max_iterations: int = 3,
    select_best: bool = True,
) -> RefineMiddleware:
    """
    Create a RefineMiddleware (iterative improvement).

    Like dspy.Refine - iteratively improves responses using a reward function.

    Args:
        reward_fn: Function that scores a response (0.0 to 1.0)
        threshold: Score threshold to stop early
        max_iterations: Maximum improvement iterations
        select_best: Track and return best response across iterations

    Returns:
        RefineMiddleware configured with the reward function

    Example:
        def evaluate_completeness(response: str, ctx: dict) -> float:
            # Check for expected sections
            sections = ["introduction", "details", "conclusion"]
            found = sum(1 for s in sections if s in response.lower())
            return found / len(sections)

        middleware = create_refine_middleware(
            reward_fn=evaluate_completeness,
            threshold=1.0,
            max_iterations=3,
        )
    """
    return RefineMiddleware(
        reward_fn=reward_fn,
        threshold=threshold,
        max_iterations=max_iterations,
        select_best=select_best,
    )