dao-ai 0.0.25__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. dao_ai/__init__.py +29 -0
  2. dao_ai/agent_as_code.py +5 -5
  3. dao_ai/cli.py +245 -40
  4. dao_ai/config.py +1863 -338
  5. dao_ai/genie/__init__.py +38 -0
  6. dao_ai/genie/cache/__init__.py +43 -0
  7. dao_ai/genie/cache/base.py +72 -0
  8. dao_ai/genie/cache/core.py +79 -0
  9. dao_ai/genie/cache/lru.py +347 -0
  10. dao_ai/genie/cache/semantic.py +970 -0
  11. dao_ai/genie/core.py +35 -0
  12. dao_ai/graph.py +27 -228
  13. dao_ai/hooks/__init__.py +9 -6
  14. dao_ai/hooks/core.py +27 -195
  15. dao_ai/logging.py +56 -0
  16. dao_ai/memory/__init__.py +10 -0
  17. dao_ai/memory/core.py +65 -30
  18. dao_ai/memory/databricks.py +402 -0
  19. dao_ai/memory/postgres.py +79 -38
  20. dao_ai/messages.py +6 -4
  21. dao_ai/middleware/__init__.py +125 -0
  22. dao_ai/middleware/assertions.py +806 -0
  23. dao_ai/middleware/base.py +50 -0
  24. dao_ai/middleware/core.py +67 -0
  25. dao_ai/middleware/guardrails.py +420 -0
  26. dao_ai/middleware/human_in_the_loop.py +232 -0
  27. dao_ai/middleware/message_validation.py +586 -0
  28. dao_ai/middleware/summarization.py +197 -0
  29. dao_ai/models.py +1306 -114
  30. dao_ai/nodes.py +261 -166
  31. dao_ai/optimization.py +674 -0
  32. dao_ai/orchestration/__init__.py +52 -0
  33. dao_ai/orchestration/core.py +294 -0
  34. dao_ai/orchestration/supervisor.py +278 -0
  35. dao_ai/orchestration/swarm.py +271 -0
  36. dao_ai/prompts.py +128 -31
  37. dao_ai/providers/databricks.py +645 -172
  38. dao_ai/state.py +157 -21
  39. dao_ai/tools/__init__.py +13 -5
  40. dao_ai/tools/agent.py +1 -3
  41. dao_ai/tools/core.py +64 -11
  42. dao_ai/tools/email.py +232 -0
  43. dao_ai/tools/genie.py +144 -295
  44. dao_ai/tools/mcp.py +220 -133
  45. dao_ai/tools/memory.py +50 -0
  46. dao_ai/tools/python.py +9 -14
  47. dao_ai/tools/search.py +14 -0
  48. dao_ai/tools/slack.py +22 -10
  49. dao_ai/tools/sql.py +202 -0
  50. dao_ai/tools/time.py +30 -7
  51. dao_ai/tools/unity_catalog.py +165 -88
  52. dao_ai/tools/vector_search.py +360 -40
  53. dao_ai/utils.py +218 -16
  54. dao_ai-0.1.2.dist-info/METADATA +455 -0
  55. dao_ai-0.1.2.dist-info/RECORD +64 -0
  56. {dao_ai-0.0.25.dist-info → dao_ai-0.1.2.dist-info}/WHEEL +1 -1
  57. dao_ai/chat_models.py +0 -204
  58. dao_ai/guardrails.py +0 -112
  59. dao_ai/tools/human_in_the_loop.py +0 -100
  60. dao_ai-0.0.25.dist-info/METADATA +0 -1165
  61. dao_ai-0.0.25.dist-info/RECORD +0 -41
  62. {dao_ai-0.0.25.dist-info → dao_ai-0.1.2.dist-info}/entry_points.txt +0 -0
  63. {dao_ai-0.0.25.dist-info → dao_ai-0.1.2.dist-info}/licenses/LICENSE +0 -0
dao_ai/middleware/assertions.py (new file)
@@ -0,0 +1,806 @@
+"""
+DSPy-style assertion middleware for DAO AI agents.
+
+This module provides middleware implementations inspired by DSPy's assertion
+mechanisms (dspy.Assert, dspy.Suggest, dspy.Refine) but implemented natively
+in the LangChain middleware pattern for optimal latency and streaming support.
+
+Key concepts:
+- Assert: Hard constraint - retry until satisfied or fail after max attempts
+- Suggest: Soft constraint - provide feedback but don't block execution
+- Refine: Iterative improvement - run multiple times, select best result
+
+These work with LangChain's middleware hooks (after_model) to validate and
+improve agent outputs without requiring the DSPy library.
+"""
+
+from abc import ABC, abstractmethod
+from dataclasses import dataclass, field
+from typing import Any, Callable, Optional, TypeVar
+
+from langchain_core.language_models import LanguageModelLike
+from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
+from langgraph.runtime import Runtime
+from loguru import logger
+
+from dao_ai.messages import last_ai_message, last_human_message
+from dao_ai.middleware.base import AgentMiddleware
+from dao_ai.state import AgentState, Context
+
+__all__ = [
+    # Types
+    "Constraint",
+    "ConstraintResult",
+    # Middleware classes
+    "AssertMiddleware",
+    "SuggestMiddleware",
+    "RefineMiddleware",
+    # Factory functions
+    "create_assert_middleware",
+    "create_suggest_middleware",
+    "create_refine_middleware",
+]
+
+T = TypeVar("T")
+
+
+@dataclass
+class ConstraintResult:
+    """Result of evaluating a constraint against model output.
+
+    Attributes:
+        passed: Whether the constraint was satisfied
+        feedback: Feedback message explaining the result
+        score: Optional numeric score (0.0 to 1.0)
+        metadata: Additional metadata from the evaluation
+    """
+
+    passed: bool
+    feedback: str = ""
+    score: Optional[float] = None
+    metadata: dict[str, Any] = field(default_factory=dict)
+
+
+class Constraint(ABC):
+    """Base class for constraints that can be evaluated against model outputs.
+
+    Constraints can be:
+    - Callable functions: (response: str, context: dict) -> ConstraintResult | bool
+    - LLM-based evaluators: Use a judge model to evaluate responses
+    - Rule-based: Deterministic checks like regex, keywords, length
+
+    Example:
+        class LengthConstraint(Constraint):
+            def __init__(self, min_length: int, max_length: int):
+                self.min_length = min_length
+                self.max_length = max_length
+
+            def evaluate(self, response: str, context: dict) -> ConstraintResult:
+                length = len(response)
+                if self.min_length <= length <= self.max_length:
+                    return ConstraintResult(passed=True, feedback="Length OK")
+                return ConstraintResult(
+                    passed=False,
+                    feedback=f"Response length {length} not in range [{self.min_length}, {self.max_length}]"
+                )
+    """
+
+    @abstractmethod
+    def evaluate(self, response: str, context: dict[str, Any]) -> ConstraintResult:
+        """Evaluate the constraint against a response.
+
+        Args:
+            response: The model's response text
+            context: Additional context (user input, state, etc.)
+
+        Returns:
+            ConstraintResult indicating whether constraint was satisfied
+        """
+        ...
+
+    @property
+    def name(self) -> str:
+        """Name of this constraint for logging."""
+        return self.__class__.__name__
+
+
+class FunctionConstraint(Constraint):
+    """Constraint that wraps a callable function.
+
+    The function can return either:
+    - bool: True = passed, False = failed with default feedback
+    - ConstraintResult: Full result with feedback and score
+    """
+
+    def __init__(
+        self,
+        func: Callable[[str, dict[str, Any]], ConstraintResult | bool],
+        name: Optional[str] = None,
+        default_feedback: str = "Constraint not satisfied",
+    ):
+        self._func = func
+        self._name = name or func.__name__
+        self._default_feedback = default_feedback
+
+    @property
+    def name(self) -> str:
+        return self._name
+
+    def evaluate(self, response: str, context: dict[str, Any]) -> ConstraintResult:
+        result = self._func(response, context)
+        if isinstance(result, bool):
+            return ConstraintResult(
+                passed=result,
+                feedback="" if result else self._default_feedback,
+            )
+        return result
+
+
+class LLMConstraint(Constraint):
+    """Constraint that uses an LLM judge to evaluate responses.
+
+    Similar to LLM-as-judge evaluation but returns a ConstraintResult.
+    """
+
+    def __init__(
+        self,
+        model: LanguageModelLike,
+        prompt: str,
+        name: Optional[str] = None,
+        threshold: float = 0.5,
+    ):
+        """Initialize LLM-based constraint.
+
+        Args:
+            model: LLM to use for evaluation
+            prompt: Evaluation prompt. Should include {response} and optionally {input} placeholders.
+            name: Name for logging
+            threshold: Score threshold for passing (0.0-1.0)
+        """
+        self._model = model
+        self._prompt = prompt
+        self._name = name or "LLMConstraint"
+        self._threshold = threshold
+
+    @property
+    def name(self) -> str:
+        return self._name
+
+    def evaluate(self, response: str, context: dict[str, Any]) -> ConstraintResult:
+        user_input = context.get("input", "")
+
+        eval_prompt = self._prompt.format(response=response, input=user_input)
+
+        result = self._model.invoke(
+            [
+                {
+                    "role": "system",
+                    "content": (
+                        "You are an evaluation assistant. Evaluate the response and reply with:\n"
+                        "PASS: <feedback> if the constraint is satisfied\n"
+                        "FAIL: <feedback> if the constraint is not satisfied\n"
+                        "Be concise."
+                    ),
+                },
+                {"role": "user", "content": eval_prompt},
+            ]
+        )
+
+        content = str(result.content).strip()
+
+        if content.upper().startswith("PASS"):
+            feedback = content[5:].strip(": ").strip()
+            return ConstraintResult(passed=True, feedback=feedback, score=1.0)
+        elif content.upper().startswith("FAIL"):
+            feedback = content[5:].strip(": ").strip()
+            return ConstraintResult(passed=False, feedback=feedback, score=0.0)
+        else:
+            # Try to interpret as pass/fail
+            is_pass = any(
+                word in content.lower()
+                for word in ["yes", "pass", "correct", "good", "valid"]
+            )
+            return ConstraintResult(passed=is_pass, feedback=content)
+
+
+class KeywordConstraint(Constraint):
+    """Simple constraint that checks for required/banned keywords."""
+
+    def __init__(
+        self,
+        required_keywords: Optional[list[str]] = None,
+        banned_keywords: Optional[list[str]] = None,
+        case_sensitive: bool = False,
+        name: Optional[str] = None,
+    ):
+        self._required = required_keywords or []
+        self._banned = banned_keywords or []
+        self._case_sensitive = case_sensitive
+        self._name = name or "KeywordConstraint"
+
+    @property
+    def name(self) -> str:
+        return self._name
+
+    def evaluate(self, response: str, context: dict[str, Any]) -> ConstraintResult:
+        check_response = response if self._case_sensitive else response.lower()
+
+        # Check banned keywords
+        for keyword in self._banned:
+            check_keyword = keyword if self._case_sensitive else keyword.lower()
+            if check_keyword in check_response:
+                return ConstraintResult(
+                    passed=False,
+                    feedback=f"Response contains banned keyword: '{keyword}'",
+                )
+
+        # Check required keywords
+        for keyword in self._required:
+            check_keyword = keyword if self._case_sensitive else keyword.lower()
+            if check_keyword not in check_response:
+                return ConstraintResult(
+                    passed=False,
+                    feedback=f"Response missing required keyword: '{keyword}'",
+                )
+
+        return ConstraintResult(
+            passed=True, feedback="All keyword constraints satisfied"
+        )
+
+
+class LengthConstraint(Constraint):
+    """Constraint that checks response length."""
+
+    def __init__(
+        self,
+        min_length: Optional[int] = None,
+        max_length: Optional[int] = None,
+        unit: str = "chars",  # "chars", "words", "sentences"
+        name: Optional[str] = None,
+    ):
+        self._min_length = min_length
+        self._max_length = max_length
+        self._unit = unit
+        self._name = name or "LengthConstraint"
+
+    @property
+    def name(self) -> str:
+        return self._name
+
+    def evaluate(self, response: str, context: dict[str, Any]) -> ConstraintResult:
+        if self._unit == "chars":
+            length = len(response)
+        elif self._unit == "words":
+            length = len(response.split())
+        elif self._unit == "sentences":
+            length = response.count(".") + response.count("!") + response.count("?")
+        else:
+            length = len(response)
+
+        if self._min_length is not None and length < self._min_length:
+            return ConstraintResult(
+                passed=False,
+                feedback=f"Response too short: {length} {self._unit} (min: {self._min_length})",
+                score=length / self._min_length if self._min_length > 0 else 0.0,
+            )
+
+        if self._max_length is not None and length > self._max_length:
+            return ConstraintResult(
+                passed=False,
+                feedback=f"Response too long: {length} {self._unit} (max: {self._max_length})",
+                score=self._max_length / length if length > 0 else 0.0,
+            )
+
+        return ConstraintResult(
+            passed=True,
+            feedback=f"Length OK: {length} {self._unit}",
+            score=1.0,
+        )
+
+
+# =============================================================================
+# AssertMiddleware - Hard constraint with retry (like dspy.Assert)
+# =============================================================================
+
+
+class AssertMiddleware(AgentMiddleware[AgentState, Context]):
+    """
+    Hard constraint middleware that retries until satisfied.
+
+    Inspired by dspy.Assert - if the constraint fails, the middleware
+    adds feedback to the conversation and requests a retry. If max
+    retries are exhausted, it raises an error or returns a fallback.
+
+    Args:
+        constraint: The constraint to enforce
+        max_retries: Maximum retry attempts before giving up
+        on_failure: What to do when max retries exhausted:
+            - "error": Raise ValueError (default)
+            - "fallback": Return fallback_message
+            - "pass": Let the response through anyway
+        fallback_message: Message to return if on_failure="fallback"
+
+    Example:
+        middleware = AssertMiddleware(
+            constraint=LengthConstraint(min_length=100),
+            max_retries=3,
+            on_failure="fallback",
+            fallback_message="Unable to generate a complete response."
+        )
+    """
+
+    def __init__(
+        self,
+        constraint: Constraint,
+        max_retries: int = 3,
+        on_failure: str = "error",  # "error", "fallback", "pass"
+        fallback_message: str = "Unable to generate a valid response.",
+    ):
+        super().__init__()
+        self.constraint = constraint
+        self.max_retries = max_retries
+        self.on_failure = on_failure
+        self.fallback_message = fallback_message
+        self._retry_count = 0
+
+    def after_model(
+        self, state: AgentState, runtime: Runtime[Context]
+    ) -> dict[str, Any] | None:
+        """Evaluate constraint and retry if not satisfied."""
+        messages: list[BaseMessage] = state.get("messages", [])
+
+        if not messages:
+            return None
+
+        ai_message: AIMessage | None = last_ai_message(messages)
+        human_message: HumanMessage | None = last_human_message(messages)
+
+        if not ai_message:
+            return None
+
+        response = str(ai_message.content)
+        user_input = str(human_message.content) if human_message else ""
+
+        context = {
+            "input": user_input,
+            "messages": messages,
+            "runtime": runtime,
+        }
+
+        logger.trace(
+            "Evaluating Assert constraint", constraint_name=self.constraint.name
+        )
+
+        result = self.constraint.evaluate(response, context)
+
+        if result.passed:
+            logger.trace(
+                "Assert constraint passed", constraint_name=self.constraint.name
+            )
+            self._retry_count = 0
+            return None
+
+        # Constraint failed
+        self._retry_count += 1
+        logger.warning(
+            "Assert constraint failed",
+            constraint_name=self.constraint.name,
+            attempt=self._retry_count,
+            max_retries=self.max_retries,
+            feedback=result.feedback,
+        )
+
+        if self._retry_count >= self.max_retries:
+            self._retry_count = 0
+
+            if self.on_failure == "error":
+                raise ValueError(
+                    f"Assert constraint '{self.constraint.name}' failed after "
+                    f"{self.max_retries} retries: {result.feedback}"
+                )
+            elif self.on_failure == "fallback":
+                ai_message.content = self.fallback_message
+                return None
+            else:  # "pass"
+                logger.warning(
+                    "Assert constraint failed but passing through",
+                    constraint_name=self.constraint.name,
+                )
+                return None
+
+        # Add feedback and retry
+        retry_prompt = (
+            f"Your previous response did not meet the requirements:\n"
+            f"{result.feedback}\n\n"
+            f"Please try again with the original request:\n{user_input}"
+        )
+        return {"messages": [HumanMessage(content=retry_prompt)]}
+
+
+# =============================================================================
+# SuggestMiddleware - Soft constraint with feedback (like dspy.Suggest)
+# =============================================================================
+
+
+class SuggestMiddleware(AgentMiddleware[AgentState, Context]):
+    """
+    Soft constraint middleware that provides feedback without blocking.
+
+    Inspired by dspy.Suggest - evaluates the constraint and logs feedback
+    but does not retry or block the response. The feedback is captured
+    in metadata for observability but the response passes through.
+
+    Optionally, can request one improvement attempt if constraint fails.
+
+    Args:
+        constraint: The constraint to evaluate
+        allow_one_retry: If True, request one improvement attempt on failure
+        log_level: Log level for feedback ("warning", "info", "debug")
+
+    Example:
+        middleware = SuggestMiddleware(
+            constraint=LLMConstraint(
+                model=ChatDatabricks(...),
+                prompt="Check if response is professional: {response}"
+            ),
+            allow_one_retry=True,
+        )
+    """
+
+    def __init__(
+        self,
+        constraint: Constraint,
+        allow_one_retry: bool = False,
+        log_level: str = "warning",
+    ):
+        super().__init__()
+        self.constraint = constraint
+        self.allow_one_retry = allow_one_retry
+        self.log_level = log_level
+        self._has_retried = False
+
+    def after_model(
+        self, state: AgentState, runtime: Runtime[Context]
+    ) -> dict[str, Any] | None:
+        """Evaluate constraint and log feedback."""
+        messages: list[BaseMessage] = state.get("messages", [])
+
+        if not messages:
+            return None
+
+        ai_message: AIMessage | None = last_ai_message(messages)
+        human_message: HumanMessage | None = last_human_message(messages)
+
+        if not ai_message:
+            return None
+
+        response = str(ai_message.content)
+        user_input = str(human_message.content) if human_message else ""
+
+        context = {
+            "input": user_input,
+            "messages": messages,
+            "runtime": runtime,
+        }
+
+        logger.trace(
+            "Evaluating Suggest constraint", constraint_name=self.constraint.name
+        )
+
+        result = self.constraint.evaluate(response, context)
+
+        if result.passed:
+            logger.trace(
+                "Suggest constraint passed", constraint_name=self.constraint.name
+            )
+            self._has_retried = False
+            return None
+
+        # Log feedback based on configured level
+        if self.log_level == "warning":
+            logger.warning(
+                "Suggest constraint feedback",
+                constraint_name=self.constraint.name,
+                feedback=result.feedback,
+            )
+        elif self.log_level == "info":
+            logger.info(
+                "Suggest constraint feedback",
+                constraint_name=self.constraint.name,
+                feedback=result.feedback,
+            )
+        else:
+            logger.debug(
+                "Suggest constraint feedback",
+                constraint_name=self.constraint.name,
+                feedback=result.feedback,
+            )
+
+        # Optionally request one improvement
+        if self.allow_one_retry and not self._has_retried:
+            self._has_retried = True
+            retry_prompt = (
+                f"Consider this feedback for your response:\n"
+                f"{result.feedback}\n\n"
+                f"Original request: {user_input}\n"
+                f"Please provide an improved response."
+            )
+            return {"messages": [HumanMessage(content=retry_prompt)]}
+
+        # Pass through without modification
+        self._has_retried = False
+        return None
+
+
+# =============================================================================
+# RefineMiddleware - Iterative improvement (like dspy.Refine)
+# =============================================================================
+
+
+class RefineMiddleware(AgentMiddleware[AgentState, Context]):
+    """
+    Iterative refinement middleware that improves responses.
+
+    Inspired by dspy.Refine - runs the response through multiple iterations,
+    using a reward function to score each attempt. Selects the best response
+    or stops early if a threshold is reached.
+
+    Since middleware runs in the agent loop, this works by:
+    1. Scoring the current response
+    2. If below threshold and iterations remain, requesting improvement
+    3. Tracking the best response across iterations
+    4. Returning the best response when done
+
+    Args:
+        reward_fn: Function that scores a response (returns 0.0 to 1.0)
+        threshold: Score threshold to stop early (default 0.8)
+        max_iterations: Maximum improvement iterations (default 3)
+        select_best: If True, track and return best response; else use last
+
+    Example:
+        def score_response(response: str, context: dict) -> float:
+            # Score based on helpfulness, completeness, etc.
+            return 0.85
+
+        middleware = RefineMiddleware(
+            reward_fn=score_response,
+            threshold=0.9,
+            max_iterations=3,
+        )
+    """
+
+    def __init__(
+        self,
+        reward_fn: Callable[[str, dict[str, Any]], float],
+        threshold: float = 0.8,
+        max_iterations: int = 3,
+        select_best: bool = True,
+    ):
+        super().__init__()
+        self.reward_fn = reward_fn
+        self.threshold = threshold
+        self.max_iterations = max_iterations
+        self.select_best = select_best
+        self._iteration = 0
+        self._best_score = 0.0
+        self._best_response: Optional[str] = None
+
+    def after_model(
+        self, state: AgentState, runtime: Runtime[Context]
+    ) -> dict[str, Any] | None:
+        """Score response and request improvement if needed."""
+        messages: list[BaseMessage] = state.get("messages", [])
+
+        if not messages:
+            return None
+
+        ai_message: AIMessage | None = last_ai_message(messages)
+        human_message: HumanMessage | None = last_human_message(messages)
+
+        if not ai_message:
+            return None
+
+        response = str(ai_message.content)
+        user_input = str(human_message.content) if human_message else ""
+
+        context = {
+            "input": user_input,
+            "messages": messages,
+            "runtime": runtime,
+            "iteration": self._iteration,
+        }
+
+        score: float = self.reward_fn(response, context)
+        self._iteration += 1
+
+        logger.debug(
+            "Refine iteration",
+            iteration=self._iteration,
+            max_iterations=self.max_iterations,
+            score=f"{score:.3f}",
+            threshold=self.threshold,
+        )
+
+        # Track best response
+        if self.select_best and score > self._best_score:
+            self._best_score = score
+            self._best_response = response
+
+        # Check if we should stop
+        if score >= self.threshold:
+            logger.debug(
+                "Refine threshold reached",
+                score=f"{score:.3f}",
+                threshold=self.threshold,
+            )
+            self._reset()
+            return None
+
+        if self._iteration >= self.max_iterations:
+            logger.debug(
+                "Refine max iterations reached", best_score=f"{self._best_score:.3f}"
+            )
+            # Use best response if tracking
+            if self.select_best and self._best_response:
+                ai_message.content = self._best_response
+            self._reset()
+            return None
+
+        # Request improvement
+        feedback = f"Current response scored {score:.2f}/{self.threshold:.2f}."
+        if score < 0.5:
+            feedback += " The response needs significant improvement."
+        elif score < self.threshold:
+            feedback += " The response is good but could be better."
+
+        retry_prompt = f"{feedback}\n\nPlease improve your response to:\n{user_input}"
+        return {"messages": [HumanMessage(content=retry_prompt)]}
+
+    def _reset(self) -> None:
+        """Reset iteration state for next invocation."""
+        self._iteration = 0
+        self._best_score = 0.0
+        self._best_response = None
+
+
+# =============================================================================
+# Factory Functions
+# =============================================================================
+
+
+def create_assert_middleware(
+    constraint: Constraint | Callable[[str, dict[str, Any]], ConstraintResult | bool],
+    max_retries: int = 3,
+    on_failure: str = "error",
+    fallback_message: str = "Unable to generate a valid response.",
+    name: Optional[str] = None,
+) -> AssertMiddleware:
+    """
+    Create an AssertMiddleware (hard constraint with retry).
+
+    Like dspy.Assert - enforces a constraint and retries if not satisfied.
+
+    Args:
+        constraint: Constraint object or callable function
+        max_retries: Maximum retry attempts
+        on_failure: "error", "fallback", or "pass"
+        fallback_message: Message if on_failure="fallback"
+        name: Name for function constraints
+
+    Returns:
+        AssertMiddleware configured with the constraint
+
+    Example:
+        # Using a Constraint class
+        middleware = create_assert_middleware(
+            constraint=LengthConstraint(min_length=100),
+            max_retries=3,
+        )
+
+        # Using a function
+        def has_sources(response: str, ctx: dict) -> bool:
+            return "[source]" in response.lower() or "reference" in response.lower()
+
+        middleware = create_assert_middleware(
+            constraint=has_sources,
+            max_retries=2,
+            on_failure="fallback",
+            fallback_message="I couldn't find relevant sources.",
+        )
+    """
+    if callable(constraint) and not isinstance(constraint, Constraint):
+        constraint = FunctionConstraint(constraint, name=name)
+
+    return AssertMiddleware(
+        constraint=constraint,
+        max_retries=max_retries,
+        on_failure=on_failure,
+        fallback_message=fallback_message,
+    )
+
+
+def create_suggest_middleware(
+    constraint: Constraint | Callable[[str, dict[str, Any]], ConstraintResult | bool],
+    allow_one_retry: bool = False,
+    log_level: str = "warning",
+    name: Optional[str] = None,
+) -> SuggestMiddleware:
+    """
+    Create a SuggestMiddleware (soft constraint with feedback).
+
+    Like dspy.Suggest - evaluates constraint and logs feedback without blocking.
+
+    Args:
+        constraint: Constraint object or callable function
+        allow_one_retry: Request one improvement attempt on failure
+        log_level: "warning", "info", or "debug"
+        name: Name for function constraints
+
+    Returns:
+        SuggestMiddleware configured with the constraint
+
+    Example:
+        def is_professional(response: str, ctx: dict) -> ConstraintResult:
+            informal = ["lol", "omg", "btw", "gonna"]
+            found = [w for w in informal if w in response.lower()]
+            if found:
+                return ConstraintResult(
+                    passed=False,
+                    feedback=f"Response contains informal language: {found}"
+                )
+            return ConstraintResult(passed=True)
+
+        middleware = create_suggest_middleware(
+            constraint=is_professional,
+            allow_one_retry=True,
+        )
+    """
+    if callable(constraint) and not isinstance(constraint, Constraint):
+        constraint = FunctionConstraint(constraint, name=name)
+
+    return SuggestMiddleware(
+        constraint=constraint,
+        allow_one_retry=allow_one_retry,
+        log_level=log_level,
+    )
+
+
+def create_refine_middleware(
+    reward_fn: Callable[[str, dict[str, Any]], float],
+    threshold: float = 0.8,
+    max_iterations: int = 3,
+    select_best: bool = True,
+) -> RefineMiddleware:
+    """
+    Create a RefineMiddleware (iterative improvement).
+
+    Like dspy.Refine - iteratively improves responses using a reward function.
+
+    Args:
+        reward_fn: Function that scores a response (0.0 to 1.0)
+        threshold: Score threshold to stop early
+        max_iterations: Maximum improvement iterations
+        select_best: Track and return best response across iterations
+
+    Returns:
+        RefineMiddleware configured with the reward function
+
+    Example:
+        def evaluate_completeness(response: str, ctx: dict) -> float:
+            # Check for expected sections
+            sections = ["introduction", "details", "conclusion"]
+            found = sum(1 for s in sections if s in response.lower())
+            return found / len(sections)
+
+        middleware = create_refine_middleware(
+            reward_fn=evaluate_completeness,
+            threshold=1.0,
+            max_iterations=3,
+        )
+    """
+    return RefineMiddleware(
+        reward_fn=reward_fn,
+        threshold=threshold,
+        max_iterations=max_iterations,
+        select_best=select_best,
+    )
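
Usage note (not part of the diff): a minimal sketch of how the new assertion middleware in dao_ai/middleware/assertions.py might be combined. The factory functions and KeywordConstraint come from the file above; the final create_agent line is a hypothetical placeholder for however the hosting agent attaches its middleware list.

from dao_ai.middleware.assertions import (
    KeywordConstraint,
    create_assert_middleware,
    create_refine_middleware,
    create_suggest_middleware,
)

# Hard constraint: retry until the answer cites a source, then fall back.
def has_sources(response: str, ctx: dict) -> bool:
    return "[source]" in response.lower() or "reference" in response.lower()

assert_sources = create_assert_middleware(
    constraint=has_sources,
    max_retries=3,
    on_failure="fallback",
    fallback_message="I could not find supporting sources for this answer.",
)

# Soft constraint: warn (and allow one rewrite) if informal phrases appear.
suggest_tone = create_suggest_middleware(
    constraint=KeywordConstraint(banned_keywords=["lol", "gonna"], name="tone"),
    allow_one_retry=True,
)

# Iterative refinement: keep the best-scoring draft across up to 3 attempts.
refine = create_refine_middleware(
    reward_fn=lambda response, ctx: min(len(response) / 500, 1.0),
    threshold=0.9,
    max_iterations=3,
)

middleware = [assert_sources, suggest_tone, refine]
# agent = create_agent(model=..., middleware=middleware)  # hypothetical wiring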