kailash 0.1.5__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kailash/__init__.py +1 -1
- kailash/access_control.py +740 -0
- kailash/api/__main__.py +6 -0
- kailash/api/auth.py +668 -0
- kailash/api/custom_nodes.py +285 -0
- kailash/api/custom_nodes_secure.py +377 -0
- kailash/api/database.py +620 -0
- kailash/api/studio.py +915 -0
- kailash/api/studio_secure.py +893 -0
- kailash/mcp/__init__.py +53 -0
- kailash/mcp/__main__.py +13 -0
- kailash/mcp/ai_registry_server.py +712 -0
- kailash/mcp/client.py +447 -0
- kailash/mcp/client_new.py +334 -0
- kailash/mcp/server.py +293 -0
- kailash/mcp/server_new.py +336 -0
- kailash/mcp/servers/__init__.py +12 -0
- kailash/mcp/servers/ai_registry.py +289 -0
- kailash/nodes/__init__.py +4 -2
- kailash/nodes/ai/__init__.py +2 -0
- kailash/nodes/ai/a2a.py +714 -67
- kailash/nodes/ai/intelligent_agent_orchestrator.py +31 -37
- kailash/nodes/ai/iterative_llm_agent.py +1280 -0
- kailash/nodes/ai/llm_agent.py +324 -1
- kailash/nodes/ai/self_organizing.py +5 -6
- kailash/nodes/base.py +15 -2
- kailash/nodes/base_async.py +45 -0
- kailash/nodes/base_cycle_aware.py +374 -0
- kailash/nodes/base_with_acl.py +338 -0
- kailash/nodes/code/python.py +135 -27
- kailash/nodes/data/readers.py +16 -6
- kailash/nodes/data/writers.py +16 -6
- kailash/nodes/logic/__init__.py +8 -0
- kailash/nodes/logic/convergence.py +642 -0
- kailash/nodes/logic/loop.py +153 -0
- kailash/nodes/logic/operations.py +187 -27
- kailash/nodes/mixins/__init__.py +11 -0
- kailash/nodes/mixins/mcp.py +228 -0
- kailash/nodes/mixins.py +387 -0
- kailash/runtime/__init__.py +2 -1
- kailash/runtime/access_controlled.py +458 -0
- kailash/runtime/local.py +106 -33
- kailash/runtime/parallel_cyclic.py +529 -0
- kailash/sdk_exceptions.py +90 -5
- kailash/security.py +845 -0
- kailash/tracking/manager.py +38 -15
- kailash/tracking/models.py +1 -1
- kailash/tracking/storage/filesystem.py +30 -2
- kailash/utils/__init__.py +8 -0
- kailash/workflow/__init__.py +18 -0
- kailash/workflow/convergence.py +270 -0
- kailash/workflow/cycle_analyzer.py +768 -0
- kailash/workflow/cycle_builder.py +573 -0
- kailash/workflow/cycle_config.py +709 -0
- kailash/workflow/cycle_debugger.py +760 -0
- kailash/workflow/cycle_exceptions.py +601 -0
- kailash/workflow/cycle_profiler.py +671 -0
- kailash/workflow/cycle_state.py +338 -0
- kailash/workflow/cyclic_runner.py +985 -0
- kailash/workflow/graph.py +500 -39
- kailash/workflow/migration.py +768 -0
- kailash/workflow/safety.py +365 -0
- kailash/workflow/templates.py +744 -0
- kailash/workflow/validation.py +693 -0
- {kailash-0.1.5.dist-info → kailash-0.2.0.dist-info}/METADATA +256 -12
- kailash-0.2.0.dist-info/RECORD +125 -0
- kailash/nodes/mcp/__init__.py +0 -11
- kailash/nodes/mcp/client.py +0 -554
- kailash/nodes/mcp/resource.py +0 -682
- kailash/nodes/mcp/server.py +0 -577
- kailash-0.1.5.dist-info/RECORD +0 -88
- {kailash-0.1.5.dist-info → kailash-0.2.0.dist-info}/WHEEL +0 -0
- {kailash-0.1.5.dist-info → kailash-0.2.0.dist-info}/entry_points.txt +0 -0
- {kailash-0.1.5.dist-info → kailash-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {kailash-0.1.5.dist-info → kailash-0.2.0.dist-info}/top_level.txt +0 -0
kailash/nodes/logic/convergence.py (new file)
@@ -0,0 +1,642 @@
```python
"""
Convergence checking nodes for cyclic workflows.

This module provides specialized nodes for detecting convergence in cyclic workflows,
eliminating the need for custom convergence logic in every workflow. These nodes
implement common convergence patterns and can be easily configured for different
scenarios.

Design Philosophy:
    Convergence detection is a critical pattern in cyclic workflows. This module
    provides declarative convergence checking that replaces imperative convergence
    logic with configurable nodes, making workflows more maintainable and testable.

Example usage:
    >>> from kailash.nodes.logic.convergence import ConvergenceCheckerNode
    >>> from kailash import Workflow
    >>>
    >>> workflow = Workflow("convergence-demo")
    >>> workflow.add_node("convergence", ConvergenceCheckerNode(),
    ...                   threshold=0.8, mode="threshold")
    >>>
    >>> # Connect to SwitchNode for conditional routing
    >>> workflow.add_node("switch", SwitchNode(
    ...     condition_field="converged",
    ...     true_route="output",
    ...     false_route="processor"
    ... ))
"""

from typing import Any, Dict, List, Optional

from ..base import NodeParameter, register_node
from ..base_cycle_aware import CycleAwareNode


@register_node()
class ConvergenceCheckerNode(CycleAwareNode):
    """
    Specialized node for detecting convergence in cyclic workflows.

    This node implements common convergence patterns and eliminates the need
    for custom convergence logic in every workflow. It supports multiple
    convergence modes and provides detailed feedback about convergence status.

    Design Philosophy:
        ConvergenceCheckerNode provides a declarative approach to convergence
        detection. Instead of writing custom convergence logic in each workflow,
        users configure convergence criteria and the node handles the detection
        logic, state tracking, and reporting.

    Upstream Dependencies:
        - Any node producing numeric values to monitor
        - Common patterns: optimizers, iterative refiners, quality improvers
        - Must receive 'value' parameter to check for convergence

    Downstream Consumers:
        - SwitchNode: Routes based on 'converged' field
        - Output nodes: Process final converged results
        - Monitoring nodes: Track convergence progress

    Configuration:
        mode (str): Convergence detection mode
            - 'threshold': Value reaches target threshold
            - 'stability': Value becomes stable (low variance)
            - 'improvement': Rate of improvement drops below threshold
            - 'combined': Multiple criteria must be met
            - 'custom': User-defined convergence expression
        threshold (float): Target value for threshold mode
        stability_window (int): Number of values for stability check
        min_variance (float): Maximum variance for stability
        min_improvement (float): Minimum improvement rate
        patience (int): Iterations without improvement before stopping

    Implementation Details:
        - Inherits from CycleAwareNode for iteration tracking
        - Maintains value history across iterations
        - Tracks best value and no-improvement count
        - Supports multiple convergence detection algorithms
        - Provides detailed metrics for debugging

    Error Handling:
        - Invalid modes raise ValueError
        - Missing value parameter uses default 0.0
        - Custom expressions are safely evaluated

    Side Effects:
        - Logs convergence status each iteration
        - No external state modifications

    Examples:
        >>> # Simple threshold convergence
        >>> convergence = ConvergenceCheckerNode()
        >>> workflow.add_node("convergence", convergence,
        ...                   threshold=0.95, mode="threshold")
        >>>
        >>> # Stability-based convergence
        >>> workflow.add_node("stability", ConvergenceCheckerNode(),
        ...                   mode="stability", stability_window=5, min_variance=0.001)
        >>>
        >>> # Combined convergence criteria
        >>> workflow.add_node("combined", ConvergenceCheckerNode(),
        ...                   mode="combined", threshold=0.9, stability_window=3)
    """

    def get_parameters(self) -> Dict[str, NodeParameter]:
        """Define input parameters for convergence checking."""
        return {
            "value": NodeParameter(
                name="value",
                type=float,  # Changed from Union[float, int] to just float
                required=False,
                default=0.0,
                description="Value to check for convergence",
            ),
            "threshold": NodeParameter(
                name="threshold",
                type=float,
                required=False,
                default=0.8,
                description="Target threshold for convergence (mode: threshold, combined)",
            ),
            "mode": NodeParameter(
                name="mode",
                type=str,
                required=False,
                default="threshold",
                description="Convergence detection mode: threshold|stability|improvement|combined|custom",
            ),
            "stability_window": NodeParameter(
                name="stability_window",
                type=int,
                required=False,
                default=3,
                description="Number of recent values to analyze for stability",
            ),
            "min_variance": NodeParameter(
                name="min_variance",
                type=float,
                required=False,
                default=0.01,
                description="Maximum variance for stability convergence",
            ),
            "min_improvement": NodeParameter(
                name="min_improvement",
                type=float,
                required=False,
                default=0.01,
                description="Minimum improvement rate to continue (mode: improvement)",
            ),
            "improvement_window": NodeParameter(
                name="improvement_window",
                type=int,
                required=False,
                default=3,
                description="Window for calculating improvement rate",
            ),
            "custom_expression": NodeParameter(
                name="custom_expression",
                type=str,
                required=False,
                description="Custom convergence expression (mode: custom)",
            ),
            "early_stop_iterations": NodeParameter(
                name="early_stop_iterations",
                type=int,
                required=False,
                description="Force convergence after this many iterations",
            ),
            "patience": NodeParameter(
                name="patience",
                type=int,
                required=False,
                default=5,
                description="Iterations to wait without improvement before stopping",
            ),
            "data": NodeParameter(
                name="data",
                type=Any,
                required=False,
                description="Pass-through data to preserve in the output",
            ),
        }

    def get_output_schema(self) -> Dict[str, NodeParameter]:
        """Define output schema for convergence results."""
        return {
            "converged": NodeParameter(
                name="converged",
                type=bool,
                required=True,
                description="Whether convergence has been achieved",
            ),
            "reason": NodeParameter(
                name="reason",
                type=str,
                required=True,
                description="Explanation of convergence decision",
            ),
            "value": NodeParameter(
                name="value",
                type=float,
                required=True,
                description="Current value being monitored",
            ),
            "iteration": NodeParameter(
                name="iteration",
                type=int,
                required=True,
                description="Current iteration number",
            ),
            "convergence_metrics": NodeParameter(
                name="convergence_metrics",
                type=dict,
                required=True,
                description="Detailed metrics about convergence progress",
            ),
        }

    def run(self, context: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        """Execute convergence checking logic."""
        # Get parameters
        value = kwargs["value"]
        threshold = kwargs.get("threshold", 0.8)
        mode = kwargs.get("mode", "threshold")
        stability_window = kwargs.get("stability_window", 3)
        min_variance = kwargs.get("min_variance", 0.01)
        min_improvement = kwargs.get("min_improvement", 0.01)
        improvement_window = kwargs.get("improvement_window", 3)
        custom_expression = kwargs.get("custom_expression")
        early_stop_iterations = kwargs.get("early_stop_iterations")
        patience = kwargs.get("patience", 5)

        # Get cycle information
        iteration = self.get_iteration(context)
        is_first = self.is_first_iteration(context)

        # Update value history
        value_history = self.accumulate_values(context, "values", value)

        # Get previous state for additional tracking
        prev_state = self.get_previous_state(context)
        best_value = prev_state.get("best_value", value)
        no_improvement_count = prev_state.get("no_improvement_count", 0)
        convergence_start_iteration = prev_state.get("convergence_start_iteration")

        # Update best value and improvement tracking
        if value > best_value:
            best_value = value
            no_improvement_count = 0
        else:
            no_improvement_count += 1

        # Initialize convergence state
        converged = False
        reason = ""
        metrics = {
            "value": value,
            "best_value": best_value,
            "value_history": value_history[-10:],  # Keep last 10 for metrics
            "no_improvement_count": no_improvement_count,
            "iteration": iteration,
        }

        # Check early stopping conditions first
        if early_stop_iterations and iteration >= early_stop_iterations:
            converged = True
            reason = f"Early stop: reached {early_stop_iterations} iterations"
        elif patience and no_improvement_count >= patience:
            converged = True
            reason = f"Early stop: no improvement for {patience} iterations"
        else:
            # Apply convergence mode logic
            if mode == "threshold":
                converged, reason, mode_metrics = self._check_threshold_convergence(
                    value, threshold, iteration
                )
            elif mode == "stability":
                converged, reason, mode_metrics = self._check_stability_convergence(
                    value_history, stability_window, min_variance, iteration
                )
            elif mode == "improvement":
                converged, reason, mode_metrics = self._check_improvement_convergence(
                    value_history, improvement_window, min_improvement, iteration
                )
            elif mode == "combined":
                converged, reason, mode_metrics = self._check_combined_convergence(
                    value,
                    value_history,
                    threshold,
                    stability_window,
                    min_variance,
                    iteration,
                )
            elif mode == "custom":
                converged, reason, mode_metrics = self._check_custom_convergence(
                    value, value_history, custom_expression, iteration, **kwargs
                )
            else:
                raise ValueError(f"Unsupported convergence mode: {mode}")

            metrics.update(mode_metrics)

        # Track convergence start time
        if converged and convergence_start_iteration is None:
            convergence_start_iteration = iteration
        elif not converged:
            convergence_start_iteration = None

        metrics["convergence_start_iteration"] = convergence_start_iteration

        # Log convergence status
        if is_first:
            self.log_cycle_info(
                context, f"Starting convergence monitoring (mode: {mode})"
            )
        elif converged:
            self.log_cycle_info(context, f"✅ CONVERGED: {reason}")
        else:
            self.log_cycle_info(context, f"Monitoring: {reason}")

        # Prepare state for next iteration
        next_state = {
            "values": value_history,
            "best_value": best_value,
            "no_improvement_count": no_improvement_count,
            "convergence_start_iteration": convergence_start_iteration,
        }

        # Include pass-through data if provided
        result = {
            "converged": converged,
            "reason": reason,
            "value": value,
            "iteration": iteration,
            "convergence_metrics": metrics,
            **self.set_cycle_state(next_state),
        }

        # Add data to output if provided
        if "data" in kwargs:
            result["data"] = kwargs["data"]

        return result

    def _check_threshold_convergence(
        self, value: float, threshold: float, iteration: int
    ) -> tuple[bool, str, dict]:
        """Check if value has reached threshold."""
        converged = value >= threshold
        reason = (
            f"Value {value:.3f} {'≥' if converged else '<'} threshold {threshold:.3f}"
        )
        metrics = {"threshold": threshold, "distance_to_threshold": threshold - value}
        return converged, reason, metrics

    def _check_stability_convergence(
        self,
        value_history: List[float],
        window: int,
        min_variance: float,
        iteration: int,
    ) -> tuple[bool, str, dict]:
        """Check if values have stabilized."""
        if len(value_history) < window:
            reason = f"Need {window} values, have {len(value_history)}"
            metrics = {"variance": None, "window_size": len(value_history)}
            return False, reason, metrics

        recent_values = value_history[-window:]
        variance = max(recent_values) - min(recent_values)
        converged = variance <= min_variance

        reason = (
            f"Variance {variance:.4f} {'≤' if converged else '>'} {min_variance:.4f}"
        )
        metrics = {
            "variance": variance,
            "min_variance": min_variance,
            "window_values": recent_values,
        }
        return converged, reason, metrics

    def _check_improvement_convergence(
        self,
        value_history: List[float],
        window: int,
        min_improvement: float,
        iteration: int,
    ) -> tuple[bool, str, dict]:
        """Check if improvement rate has dropped below threshold."""
        if len(value_history) < window:
            reason = f"Need {window} values for improvement calculation"
            metrics = {"improvement_rate": None}
            return False, reason, metrics

        recent_values = value_history[-window:]
        if len(recent_values) < 2:
            improvement_rate = 0.0
        else:
            improvement_rate = (recent_values[-1] - recent_values[0]) / (
                len(recent_values) - 1
            )

        converged = improvement_rate < min_improvement
        reason = f"Improvement rate {improvement_rate:.4f} {'<' if converged else '≥'} {min_improvement:.4f}"
        metrics = {
            "improvement_rate": improvement_rate,
            "min_improvement": min_improvement,
            "window_values": recent_values,
        }
        return converged, reason, metrics

    def _check_combined_convergence(
        self,
        value: float,
        value_history: List[float],
        threshold: float,
        stability_window: int,
        min_variance: float,
        iteration: int,
    ) -> tuple[bool, str, dict]:
        """Check combined threshold and stability convergence."""
        # Check threshold first
        threshold_met, threshold_reason, threshold_metrics = (
            self._check_threshold_convergence(value, threshold, iteration)
        )

        # Check stability
        stability_met, stability_reason, stability_metrics = (
            self._check_stability_convergence(
                value_history, stability_window, min_variance, iteration
            )
        )

        # Both must be met for convergence
        converged = threshold_met and stability_met

        if converged:
            reason = f"Both conditions met: {threshold_reason} AND {stability_reason}"
        elif threshold_met:
            reason = f"Threshold met but unstable: {stability_reason}"
        else:
            reason = f"Threshold not met: {threshold_reason}"

        metrics = {
            "threshold_met": threshold_met,
            "stability_met": stability_met,
            **threshold_metrics,
            **stability_metrics,
        }

        return converged, reason, metrics

    def _check_custom_convergence(
        self,
        value: float,
        value_history: List[float],
        expression: Optional[str],
        iteration: int,
        **kwargs,
    ) -> tuple[bool, str, dict]:
        """Check custom convergence expression."""
        if not expression:
            return False, "No custom expression provided", {}

        try:
            # Create evaluation context
            eval_context = {
                "value": value,
                "iteration": iteration,
                "history": value_history,
                "len": len,
                "max": max,
                "min": min,
                "sum": sum,
                "abs": abs,
                **kwargs,  # Include all parameters
            }

            # Evaluate custom expression
            converged = bool(eval(expression, {"__builtins__": {}}, eval_context))
            reason = f"Custom expression '{expression}' = {converged}"
            metrics = {"custom_expression": expression, "eval_context": eval_context}

            return converged, reason, metrics

        except Exception as e:
            reason = f"Custom expression error: {e}"
            metrics = {"custom_expression": expression, "error": str(e)}
            return False, reason, metrics


@register_node()
class MultiCriteriaConvergenceNode(CycleAwareNode):
    """
    Node for checking convergence across multiple metrics simultaneously.

    This node monitors multiple values and applies different convergence
    criteria to each, allowing for complex multi-dimensional convergence
    checking.

    Example:
        >>> convergence = MultiCriteriaConvergenceNode()
        >>> workflow.add_node("convergence", convergence,
        ...     criteria={
        ...         "accuracy": {"threshold": 0.95, "mode": "threshold"},
        ...         "loss": {"threshold": 0.01, "mode": "threshold", "direction": "minimize"},
        ...         "stability": {"mode": "stability", "window": 5}
        ...     },
        ...     require_all=True
        ... )
    """

    def get_parameters(self) -> Dict[str, NodeParameter]:
        """Define input parameters for multi-criteria convergence."""
        return {
            "metrics": NodeParameter(
                name="metrics",
                type=dict,
                required=False,
                default={},
                description="Dictionary of metric_name: value pairs to monitor",
            ),
            "criteria": NodeParameter(
                name="criteria",
                type=dict,
                required=False,
                default={},
                description="Dictionary of convergence criteria for each metric",
            ),
            "require_all": NodeParameter(
                name="require_all",
                type=bool,
                required=False,
                default=True,
                description="Whether all criteria must be met (True) or any (False)",
            ),
        }

    def run(self, context: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        """Execute multi-criteria convergence checking."""
        metrics = kwargs.get("metrics", {})

        # On first iteration, store criteria in state
        if self.is_first_iteration(context):
            criteria = kwargs.get("criteria", {})
            require_all = kwargs.get("require_all", True)
            # Store in cycle state for persistence
            self._stored_criteria = criteria
            self._stored_require_all = require_all
        else:
            # Use stored criteria from previous iterations
            criteria = getattr(self, "_stored_criteria", kwargs.get("criteria", {}))
            require_all = getattr(
                self, "_stored_require_all", kwargs.get("require_all", True)
            )

        iteration = self.get_iteration(context)
        prev_state = self.get_previous_state(context)

        # Track metrics history
        metrics_history = prev_state.get("metrics_history", {})
        for metric_name, value in metrics.items():
            if metric_name not in metrics_history:
                metrics_history[metric_name] = []
            metrics_history[metric_name].append(value)
            # Keep last 100 values
            metrics_history[metric_name] = metrics_history[metric_name][-100:]

        # Check each criterion
        results = {}
        met_criteria = []
        failed_criteria = []

        for metric_name, criterion in criteria.items():
            if metric_name not in metrics:
                results[metric_name] = {
                    "converged": False,
                    "reason": f"Metric '{metric_name}' not provided",
                    "value": None,
                }
                failed_criteria.append(metric_name)
                continue

            value = metrics[metric_name]
            history = metrics_history.get(metric_name, [])

            # Create individual convergence checker
            checker = ConvergenceCheckerNode()

            # Prepare parameters for the checker
            checker_params = {"value": value, **criterion}

            # Use a mock context for the individual checker
            mock_context = {
                "cycle": {"iteration": iteration, "node_state": {"values": history}}
            }

            # Run individual convergence check
            result = checker.run(mock_context, **checker_params)

            results[metric_name] = {
                "converged": result["converged"],
                "reason": result["reason"],
                "value": value,
                "metrics": result["convergence_metrics"],
            }

            if result["converged"]:
                met_criteria.append(metric_name)
            else:
                failed_criteria.append(metric_name)

        # Determine overall convergence
        if require_all:
            converged = len(failed_criteria) == 0
            if converged:
                reason = f"All {len(met_criteria)} criteria met"
            else:
                reason = f"{len(failed_criteria)} criteria not met: {failed_criteria}"
        else:
            converged = len(met_criteria) > 0
            if converged:
                reason = f"{len(met_criteria)} criteria met: {met_criteria}"
            else:
                reason = "No criteria met"

        # Log status
        self.log_cycle_info(
            context, f"Multi-criteria: {len(met_criteria)}/{len(criteria)} met"
        )

        return {
            "converged": converged,
            "reason": reason,
            "met_criteria": met_criteria,
            "failed_criteria": failed_criteria,
            "detailed_results": results,
            "iteration": iteration,
            "metrics": metrics,  # Pass through current metrics for cycle
            **self.set_cycle_state({"metrics_history": metrics_history}),
        }
```
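For orientation, the docstring examples in the new module can be read together as the intended wiring pattern for these nodes. The sketch below restates that pattern as a single script; the `SwitchNode` import path and the exact `Workflow.add_node` keyword behaviour are assumptions drawn from the docstrings and the rest of this 0.2.0 diff (SwitchNode appears to be part of `kailash/nodes/logic/operations.py`), not verified against the released API.

```python
# Hypothetical usage sketch assembled from the docstring examples above.
# Assumptions: SwitchNode is importable from kailash.nodes.logic, and
# Workflow.add_node accepts node-configuration keyword arguments as shown
# in the doctests; neither is verified against the 0.2.0 release.
from kailash import Workflow
from kailash.nodes.logic import SwitchNode  # assumed import path
from kailash.nodes.logic.convergence import ConvergenceCheckerNode

workflow = Workflow("convergence-demo")

# Monitor a numeric quality score and report convergence once it reaches 0.8.
workflow.add_node("convergence", ConvergenceCheckerNode(),
                  threshold=0.8, mode="threshold")

# Route on the checker's "converged" output: continue to "output" when True,
# loop back to a (hypothetical) "processor" node when False.
workflow.add_node("switch", SwitchNode(
    condition_field="converged",
    true_route="output",
    false_route="processor",
))
```

In a real cycle the runtime supplies the `context` dictionary that `run()` reads iteration state from, so the checker is not invoked directly; `MultiCriteriaConvergenceNode` shows the mock-context shape it builds when it does drive a checker by hand.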