flock-core 0.4.0b22__py3-none-any.whl → 0.4.0b23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of flock-core might be problematic. Click here for more details.

@@ -0,0 +1,286 @@
1
+ # src/flock/modules/asserts/assertion_module.py (New File)
2
+
3
+ import json
4
+ from collections.abc import Callable
5
+ from typing import Any, Literal
6
+
7
+ import dspy # For potential LLM-based rule checking
8
+ from pydantic import BaseModel, Field, PrivateAttr, ValidationError
9
+
10
+ from flock.core.context.context import FlockContext
11
+ from flock.core.flock_agent import FlockAgent
12
+ from flock.core.flock_module import FlockModule, FlockModuleConfig
13
+
14
+ # Need registry access if rules are callables defined elsewhere
15
+ from flock.core.flock_registry import flock_component, get_registry
16
+ from flock.core.logging.logging import get_logger
17
+
18
logger = get_logger("module.assertion")

# --- Rule Definition ---
# A rule condition may take any one of these forms:
#   1. A Python lambda/function called as (result, inputs, context) that
#      returns either a bool or a (bool, message) tuple.
#   2. A string naming a callable registered in the Flock registry,
#      e.g. "my_validation_function".
#   3. A natural language rule string, e.g.
#      "The summary must contain the keyword 'Flock'." (requires an LLM judge).
#   4. A Pydantic model class that the output must conform to.

# Signature shared by inline and registry-resolved rule callables.
_RuleCallable = Callable[
    [dict, dict, FlockContext | None], bool | tuple[bool, str]
]

RuleType = _RuleCallable | str | type[BaseModel]
32
+
33
+
34
class Rule(BaseModel):
    """Container for a single assertion rule.

    Attributes are documented inline. Construction fails with a
    ``ValueError`` when ``fail_message`` is empty.
    """

    condition: RuleType = Field(
        ...,
        description="""
# --- Rule Definition ---
# Rules can be defined in several ways:
# 1. Python lambda/function: (result: Dict, inputs: Dict, context: FlockContext) -> bool | Tuple[bool, str]
# 2. String referencing a registered callable: "my_validation_function"
# 3. Natural language rule string: "The summary must contain the keyword 'Flock'." (requires LLM judge)
# 4. Pydantic Model: The output must conform to this Pydantic model.
""",
    )
    fail_message: str  # Message to provide as feedback on failure
    name: str | None = None  # Optional name for clarity

    def model_post_init(self, __context: Any) -> None:
        """Reject rules whose ``fail_message`` is empty.

        Pydantic models never invoke the dataclass-style ``__post_init__``
        hook, so the check lives in ``model_post_init``, which pydantic runs
        after field validation. The ``str`` type itself is already enforced
        by the field annotation.

        Raises:
            ValueError: If ``fail_message`` is an empty string.
        """
        if not self.fail_message:
            raise ValueError("Rule fail_message must be a non-empty string.")
55
+
56
+
57
class AssertionModuleConfig(FlockModuleConfig):
    """Settings for the assertion checker module.

    Bundles the list of rules to evaluate, the optional model used to judge
    natural language rules, and the options that decide what happens when a
    rule fails or when every rule passes.
    """

    # Rules evaluated against the agent's output; empty by default.
    rules: list[Rule] = Field(
        description="List of rules to check against the agent's output.",
        default_factory=list,
    )

    # Model identifier for the optional LLM judge (natural language rules).
    judge_lm_model: str | None = Field(
        default=None,
        description="LLM model to use for judging natural language rules.",
    )

    # Reaction to a failed rule.
    on_failure: Literal["add_feedback", "raise_error", "log_warning"] = Field(
        "add_feedback",
        description="Action on rule failure: 'add_feedback' to context, 'raise_error', 'log_warning'.",
    )

    # Context variable that receives the joined failure messages.
    feedback_context_key: str = Field(
        "flock.assertion_feedback",
        description="Context key to store failure messages for retry loops.",
    )

    # Whether a fully successful run removes earlier feedback from the context.
    clear_feedback_on_success: bool = Field(
        True,
        description="Clear the feedback key from context if all assertions pass.",
    )
87
+
88
+
89
@flock_component
class AssertionCheckerModule(FlockModule):
    """Checks the output of an agent against a set of defined rules.

    Can trigger different actions on failure, including adding feedback
    to the context to enable self-correction loops via routing.
    """

    name: str = "assertion_checker"
    config: AssertionModuleConfig = Field(default_factory=AssertionModuleConfig)
    _judge_lm: dspy.LM | None = PrivateAttr(None)  # Initialize lazily

    def _get_judge_lm(self) -> dspy.LM | None:
        """Return the judge LM, creating it on first use.

        Returns:
            The cached ``dspy.LM`` instance, or ``None`` when no
            ``judge_lm_model`` is configured or construction failed (the
            failure is logged, not raised).
        """
        if self.config.judge_lm_model and self._judge_lm is None:
            try:
                self._judge_lm = dspy.LM(self.config.judge_lm_model)
            except Exception as e:
                logger.error(
                    f"Failed to initialize judge LM '{self.config.judge_lm_model}': {e}"
                )
                # Proceed without judge LM for other rule types
        return self._judge_lm

    def _judge_natural_language(
        self,
        rule: Rule,
        condition: str,
        inputs: dict[str, Any],
        result: dict[str, Any],
    ) -> tuple[bool, str]:
        """Ask the judge LM whether ``result`` satisfies a natural language rule.

        Args:
            rule: The rule being checked (supplies the base failure message).
            condition: The natural language rule text.
            inputs: Inputs that were given to the agent.
            result: Output produced by the agent.

        Returns:
            ``(passed, feedback_message)``. When no judge LM is available the
            rule is treated as passed and a warning is logged.
        """
        judge_lm = self._get_judge_lm()
        if not judge_lm:
            logger.warning(
                f"Cannot evaluate natural language rule '{condition}' - no judge_lm_model configured."
            )
            # Deliberate choice: an unjudgeable rule does not fail the agent.
            return True, rule.fail_message

        class JudgeSignature(dspy.Signature):
            """Evaluate if the output meets the rule based on input and output."""

            program_input: str = dspy.InputField(
                desc="Input provided to the agent."
            )
            program_output: str = dspy.InputField(
                desc="Output generated by the agent."
            )
            rule_to_check: str = dspy.InputField(desc="The rule to verify.")
            is_met: bool = dspy.OutputField(
                desc="True if the rule is met, False otherwise."
            )
            reasoning: str = dspy.OutputField(
                desc="Brief reasoning for the decision."
            )

        judge_predictor = dspy.Predict(JudgeSignature)
        # Convert complex dicts/lists to strings for the judge prompt
        input_str = json.dumps(inputs, default=str, indent=2)
        result_str = json.dumps(result, default=str, indent=2)
        # Constructor keyword arguments of dspy.Predict are stored as
        # generation config and do not select the LM, so the judge model is
        # bound through dspy.context for the duration of this call.
        with dspy.context(lm=judge_lm):
            judge_pred = judge_predictor(
                program_input=input_str,
                program_output=result_str,
                rule_to_check=condition,
            )
        passed = judge_pred.is_met
        logger.debug(
            f"LLM Judge result for rule '{condition}': {passed} ({judge_pred.reasoning})"
        )
        return passed, f"{rule.fail_message} (Reason: {judge_pred.reasoning})"

    def _evaluate_rule(
        self,
        rule: Rule,
        rule_name: str,
        inputs: dict[str, Any],
        result: dict[str, Any],
        context: FlockContext | None,
        registry: Any,
    ) -> tuple[bool, str] | None:
        """Evaluate one rule against the agent result.

        Args:
            rule: The rule to evaluate.
            rule_name: Display name used in log messages.
            inputs: Inputs that were given to the agent.
            result: Output produced by the agent.
            context: Flock context forwarded to rule callables (may be None).
            registry: Registry used to resolve string-named callables.

        Returns:
            ``(passed, feedback_message)``, or ``None`` when the rule is
            skipped (unsupported condition type or unexpected return type).
            Exceptions raised by rule callables propagate to the caller.
        """
        condition = rule.condition
        feedback_msg = rule.fail_message

        # Pydantic model classes are themselves callable, so this check has to
        # come before the generic callable branch; otherwise the class would
        # be invoked as a rule function and the validation path never runs.
        if isinstance(condition, type) and issubclass(condition, BaseModel):
            logger.debug(
                f"Checking Pydantic validation rule: {condition.__name__}"
            )
            try:
                # Assumes the *entire* result dict should match the model
                condition.model_validate(result)
            except ValidationError as e:
                return False, f"{rule.fail_message} (Validation Error: {e})"
            return True, feedback_msg

        if isinstance(condition, str):
            if not registry.contains(condition):
                # Not a registered name: treat as natural language rule.
                logger.debug(f"Checking natural language rule: '{condition}'")
                return self._judge_natural_language(
                    rule, condition, inputs, result
                )
            logger.debug(f"Checking registered callable rule: '{condition}'")
            rule_func = registry.get_callable(condition)
            eval_result = rule_func(result, inputs, context)
        elif callable(condition):
            logger.debug(f"Checking callable rule: {rule_name}")
            eval_result = condition(result, inputs, context)
        else:
            logger.warning(
                f"Unsupported rule type for rule '{rule_name}': {type(condition)}"
            )
            return None

        # A callable that returns None leaves the rule in its initial
        # "not passed" state, i.e. it counts as a failure.
        if eval_result is None:
            return False, feedback_msg
        if isinstance(eval_result, bool):
            return eval_result, feedback_msg
        if (
            isinstance(eval_result, tuple)
            and len(eval_result) == 2
            and isinstance(eval_result[0], bool)
        ):
            passed, custom_msg = eval_result
            if not passed and custom_msg:
                feedback_msg = custom_msg  # Use custom message on failure
            return passed, feedback_msg

        logger.warning(
            f"Rule callable '{rule_name}' returned unexpected type: {type(eval_result)}. Rule skipped."
        )
        return None

    async def post_evaluate(
        self,
        agent: FlockAgent,
        inputs: dict[str, Any],
        result: dict[str, Any],
        context: FlockContext | None = None,
    ) -> dict[str, Any]:
        """Check all configured rules after the main evaluator runs.

        Every rule is evaluated (no early exit). An exception raised while
        checking a rule is logged and counted as a failure of that rule.

        Args:
            agent: The agent whose output is being checked.
            inputs: Inputs that were given to the agent.
            result: Output produced by the agent.
            context: Optional Flock context used for feedback storage.

        Returns:
            The original ``result``, unmodified.

        Raises:
            AssertionError: If any rule failed and ``on_failure`` is
                ``"raise_error"``.
        """
        if not self.config.rules:
            return result  # No rules to check

        logger.debug(f"Running assertion checks for agent '{agent.name}'...")
        failed_messages: list[str] = []
        registry = get_registry()  # Needed for callable lookup

        for i, rule in enumerate(self.config.rules):
            rule_name = rule.name or f"Rule_{i + 1}"
            try:
                outcome = self._evaluate_rule(
                    rule, rule_name, inputs, result, context, registry
                )
            except Exception as e:
                # Treat error during check as failure
                logger.error(
                    f"Error executing rule '{rule_name}' for agent '{agent.name}': {e}",
                    exc_info=True,
                )
                failed_messages.append(
                    f"Error checking rule '{rule_name}': {e}"
                )
                continue

            if outcome is None:
                continue  # Rule skipped

            passed, feedback_msg = outcome
            if not passed:
                failed_messages.append(feedback_msg)
                logger.warning(
                    f"Assertion Failed for agent '{agent.name}': {feedback_msg}"
                )

        # --- Take action based on results ---
        feedback_key = self.config.feedback_context_key
        if failed_messages:
            logger.warning(f"Agent '{agent.name}' failed assertion checks.")
            if self.config.on_failure == "add_feedback" and context:
                context.set_variable(feedback_key, "\n".join(failed_messages))
                logger.debug(
                    f"Added assertion feedback to context key '{self.config.feedback_context_key}'"
                )
            elif self.config.on_failure == "raise_error":
                raise AssertionError(
                    f"Agent '{agent.name}' failed assertions: {'; '.join(failed_messages)}"
                )
            # else "log_warning": the warnings above are the whole action
        elif context and self.config.clear_feedback_on_success:
            # Clear feedback key if all rules passed and key exists
            if feedback_key in context.state:
                del context.state[feedback_key]
                logger.debug(
                    f"Cleared assertion feedback key '{self.config.feedback_context_key}' on success."
                )

        return result  # Return the original result unmodified
@@ -105,7 +105,7 @@ class AgentRouter(FlockRouter):
105
105
  logger.warning("No available agents for agent-based routing")
106
106
  return HandOffRequest(
107
107
  next_agent="",
108
- hand_off_mode="add",
108
+ output_to_input_merge_strategy="add",
109
109
  override_next_agent=None,
110
110
  override_context=None,
111
111
  )
@@ -138,7 +138,7 @@ class AgentRouter(FlockRouter):
138
138
  )
139
139
  return HandOffRequest(
140
140
  next_agent="",
141
- hand_off_mode="add",
141
+ output_to_input_merge_strategy="add",
142
142
  override_next_agent=None,
143
143
  override_context=None,
144
144
  )
@@ -150,7 +150,7 @@ class AgentRouter(FlockRouter):
150
150
  )
151
151
  return HandOffRequest(
152
152
  next_agent="",
153
- hand_off_mode="add",
153
+ output_to_input_merge_strategy="add",
154
154
  override_next_agent=None,
155
155
  override_context=None,
156
156
  )
@@ -160,7 +160,7 @@ class AgentRouter(FlockRouter):
160
160
  )
161
161
  return HandOffRequest(
162
162
  next_agent=next_agent_name,
163
- hand_off_mode="add",
163
+ output_to_input_merge_strategy="add",
164
164
  override_next_agent=None,
165
165
  override_context=None,
166
166
  )
@@ -169,7 +169,7 @@ class AgentRouter(FlockRouter):
169
169
  logger.error(f"Error in agent-based routing: {e}")
170
170
  return HandOffRequest(
171
171
  next_agent="",
172
- hand_off_mode="add",
172
+ output_to_input_merge_strategy="add",
173
173
  override_next_agent=None,
174
174
  override_context=None,
175
175
  )