kailash 0.7.0__py3-none-any.whl → 0.8.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -4,12 +4,10 @@ import logging
4
4
  import uuid
5
5
  from typing import TYPE_CHECKING, Any
6
6
 
7
+ from kailash.nodes.base import Node
7
8
  from kailash.sdk_exceptions import ConnectionError, WorkflowValidationError
8
9
  from kailash.workflow.graph import Workflow
9
10
 
10
- if TYPE_CHECKING:
11
- from kailash.nodes.base import Node
12
-
13
11
  logger = logging.getLogger(__name__)
14
12
 
15
13
 
@@ -21,6 +19,9 @@ class WorkflowBuilder:
21
19
  self.nodes: dict[str, dict[str, Any]] = {}
22
20
  self.connections: list[dict[str, str]] = []
23
21
  self._metadata: dict[str, Any] = {}
22
+ # Parameter injection capabilities
23
+ self.workflow_parameters: dict[str, Any] = {}
24
+ self.parameter_mappings: dict[str, dict[str, str]] = {}
24
25
 
25
26
  def add_node(self, *args, **kwargs) -> str:
26
27
  """
@@ -74,8 +75,6 @@ class WorkflowBuilder:
74
75
  # Pattern: add_node(NodeClass)
75
76
  return self._add_node_alternative(args[0], None, **kwargs)
76
77
  else:
77
- from kailash.nodes.base import Node
78
-
79
78
  if isinstance(args[0], Node):
80
79
  # Pattern: add_node(node_instance)
81
80
  return self._add_node_instance(args[0], None)
@@ -111,8 +110,6 @@ class WorkflowBuilder:
111
110
 
112
111
  elif len(args) >= 2:
113
112
  # Check if first arg is a Node instance
114
- from kailash.nodes.base import Node
115
-
116
113
  if isinstance(args[0], Node):
117
114
  # Pattern 4: Instance - add_node(node_instance, "node_id") or add_node(node_instance, "node_id", config)
118
115
  # Config is ignored for instances
@@ -149,8 +146,6 @@ class WorkflowBuilder:
149
146
  """Handle legacy fluent API pattern: add_node('node_id', NodeClass, param=value)"""
150
147
  import warnings
151
148
 
152
- from kailash.nodes.base import Node
153
-
154
149
  # If it's a class, validate it's a Node subclass
155
150
  if isinstance(node_class_or_type, type) and not issubclass(
156
151
  node_class_or_type, Node
@@ -183,8 +178,6 @@ class WorkflowBuilder:
183
178
  """Handle alternative pattern: add_node(NodeClass, 'node_id', param=value)"""
184
179
  import warnings
185
180
 
186
- from kailash.nodes.base import Node
187
-
188
181
  # Validate that node_class is actually a Node subclass
189
182
  if not isinstance(node_class, type) or not issubclass(node_class, Node):
190
183
  raise WorkflowValidationError(
@@ -267,9 +260,6 @@ class WorkflowBuilder:
267
260
  f"Node ID '{node_id}' already exists in workflow"
268
261
  )
269
262
 
270
- # Import Node here to avoid circular imports
271
- from kailash.nodes.base import Node
272
-
273
263
  # Handle different input types
274
264
  if isinstance(node_type, str):
275
265
  # String node type name
@@ -606,12 +596,108 @@ class WorkflowBuilder:
606
596
  f"Failed to connect '{from_node}' to '{to_node}': {e}"
607
597
  ) from e
608
598
 
599
+ # Parameter injection: Find nodes without incoming connections and inject parameters
600
+ if self.workflow_parameters:
601
+ nodes_with_inputs = set()
602
+ for conn in self.connections:
603
+ if not conn.get("is_workflow_input"):
604
+ nodes_with_inputs.add(conn["to_node"])
605
+
606
+ nodes_without_inputs = set(self.nodes.keys()) - nodes_with_inputs
607
+
608
+ # For each node without inputs, check if it needs workflow parameters
609
+ for node_id in nodes_without_inputs:
610
+ node_info = self.nodes[node_id]
611
+ node_instance = workflow.get_node(node_id)
612
+
613
+ if hasattr(node_instance, "get_parameters"):
614
+ params = node_instance.get_parameters()
615
+
616
+ # Check which required parameters are missing from config
617
+ for param_name, param_def in params.items():
618
+ if param_def.required and param_name not in node_info["config"]:
619
+ # Check if this parameter should come from workflow parameters
620
+ if param_name in self.workflow_parameters:
621
+ # Add to node config
622
+ node_info["config"][param_name] = (
623
+ self.workflow_parameters[param_name]
624
+ )
625
+ elif node_id in self.parameter_mappings:
626
+ # Check parameter mappings
627
+ mapping = self.parameter_mappings[node_id]
628
+ if param_name in mapping:
629
+ workflow_param = mapping[param_name]
630
+ if workflow_param in self.workflow_parameters:
631
+ node_info["config"][param_name] = (
632
+ self.workflow_parameters[workflow_param]
633
+ )
634
+
635
+ # Store workflow parameters in metadata for runtime reference
636
+ workflow.metadata["workflow_parameters"] = self.workflow_parameters
637
+ workflow.metadata["parameter_mappings"] = self.parameter_mappings
638
+
609
639
  logger.info(
610
640
  f"Built workflow '{workflow_id}' with "
611
641
  f"{len(self.nodes)} nodes and {len(self.connections)} connections"
612
642
  )
613
643
  return workflow
614
644
 
645
+ def set_workflow_parameters(self, **parameters) -> "WorkflowBuilder":
646
+ """
647
+ Set workflow-level defaults, injected at build time into nodes without incoming connections for required parameters missing from their config.
648
+
649
+ Args:
650
+ **parameters: Key-value pairs of workflow-level parameters
651
+
652
+ Returns:
653
+ Self for chaining
654
+ """
655
+ self.workflow_parameters.update(parameters)
656
+ return self
657
+
658
+ def add_parameter_mapping(
659
+ self, node_id: str, mappings: dict[str, str]
660
+ ) -> "WorkflowBuilder":
661
+ """
662
+ Add parameter mappings for a specific node.
663
+
664
+ Args:
665
+ node_id: Node to configure
666
+ mappings: Dict mapping node param names to workflow param names
667
+
668
+ Returns:
669
+ Self for chaining
670
+ """
671
+ if node_id not in self.parameter_mappings:
672
+ self.parameter_mappings[node_id] = {}
673
+ self.parameter_mappings[node_id].update(mappings)
674
+ return self
675
+
676
+ def add_input_connection(
677
+ self, to_node: str, to_input: str, from_workflow_param: str
678
+ ) -> "WorkflowBuilder":
679
+ """
680
+ Connect a workflow parameter directly to a node input.
681
+
682
+ Args:
683
+ to_node: Target node ID
684
+ to_input: Input parameter name on the node
685
+ from_workflow_param: Workflow parameter name
686
+
687
+ Returns:
688
+ Self for chaining
689
+ """
690
+ # Add a special connection type for workflow inputs
691
+ connection = {
692
+ "from_node": "__workflow_input__",
693
+ "from_output": from_workflow_param,
694
+ "to_node": to_node,
695
+ "to_input": to_input,
696
+ "is_workflow_input": True,
697
+ }
698
+ self.connections.append(connection)
699
+ return self
700
+
615
701
  def clear(self) -> "WorkflowBuilder":
616
702
  """
617
703
  Clear builder state.
@@ -622,6 +708,8 @@ class WorkflowBuilder:
622
708
  self.nodes = {}
623
709
  self.connections = []
624
710
  self._metadata = {}
711
+ self.workflow_parameters = {}
712
+ self.parameter_mappings = {}
625
713
  return self
626
714
 
627
715
  @classmethod
@@ -184,7 +184,7 @@ class CyclicWorkflowExecutor:
184
184
  if not workflow.has_cycles():
185
185
  # No cycles, use standard DAG execution
186
186
  logger.info("No cycles detected, using standard DAG execution")
187
- return self.dag_runner.run(workflow, parameters), run_id
187
+ return self.dag_runner.execute(workflow, parameters), run_id
188
188
 
189
189
  # Execute with cycle support
190
190
  try:
@@ -370,18 +370,110 @@ class CyclicWorkflowExecutor:
370
370
  )
371
371
  results.update(cycle_results)
372
372
  else:
373
- # Execute DAG nodes
374
- for node_id in stage.nodes:
375
- if node_id not in state.node_outputs:
376
- logger.info(f"Executing DAG node: {node_id}")
377
- node_result = self._execute_node(
378
- workflow, node_id, state, task_manager=task_manager
379
- )
380
- results[node_id] = node_result
381
- state.node_outputs[node_id] = node_result
373
+ # Execute DAG nodes using extracted method
374
+ dag_results = self._execute_dag_portion(
375
+ workflow, stage.nodes, state, task_manager
376
+ )
377
+ results.update(dag_results)
378
+
379
+ return results
380
+
381
+ def _execute_dag_portion(
382
+ self,
383
+ workflow: Workflow,
384
+ dag_nodes: list[str],
385
+ state: WorkflowState,
386
+ task_manager: TaskManager | None = None,
387
+ ) -> dict[str, Any]:
388
+ """Execute DAG (non-cyclic) portion of the workflow.
389
+
390
+ Args:
391
+ workflow: Workflow instance
392
+ dag_nodes: List of DAG node IDs to execute
393
+ state: Workflow state
394
+ task_manager: Optional task manager for tracking
395
+
396
+ Returns:
397
+ Dictionary with node IDs as keys and their results as values
398
+ """
399
+ results = {}
400
+
401
+ for node_id in dag_nodes:
402
+ if node_id not in state.node_outputs:
403
+ logger.info(f"Executing DAG node: {node_id}")
404
+ node_result = self._execute_node(
405
+ workflow, node_id, state, task_manager=task_manager
406
+ )
407
+ results[node_id] = node_result
408
+ state.node_outputs[node_id] = node_result
409
+
410
+ return results
411
+
412
+ def _execute_cycle_groups(
413
+ self,
414
+ workflow: Workflow,
415
+ cycle_groups: list["CycleGroup"],
416
+ state: WorkflowState,
417
+ task_manager: TaskManager | None = None,
418
+ ) -> dict[str, Any]:
419
+ """Execute cycle groups portion of the workflow.
420
+
421
+ Args:
422
+ workflow: Workflow instance
423
+ cycle_groups: List of cycle groups to execute
424
+ state: Workflow state
425
+ task_manager: Optional task manager for tracking
426
+
427
+ Returns:
428
+ Dictionary with node IDs as keys and their results as values
429
+ """
430
+ results = {}
431
+
432
+ for cycle_group in cycle_groups:
433
+ logger.info(f"Executing cycle group: {cycle_group.cycle_id}")
434
+ cycle_results = self._execute_cycle_group(
435
+ workflow, cycle_group, state, task_manager
436
+ )
437
+ results.update(cycle_results)
382
438
 
383
439
  return results
384
440
 
441
+ def _propagate_parameters(
442
+ self,
443
+ current_params: dict[str, Any],
444
+ current_results: dict[str, Any],
445
+ cycle_config: dict[str, Any] | None = None,
446
+ ) -> dict[str, Any]:
447
+ """Handle parameter propagation between cycle iterations.
448
+
449
+ Args:
450
+ current_params: Current iteration parameters
451
+ current_results: Results from current iteration
452
+ cycle_config: Cycle configuration (optional)
453
+
454
+ Returns:
455
+ Updated parameters for the next iteration
456
+ """
457
+ # Base propagation: copy current results for next iteration
458
+ next_params = current_results.copy() if current_results else {}
459
+
460
+ # Apply any cycle-specific parameter mappings if provided
461
+ if cycle_config and "parameter_mappings" in cycle_config:
462
+ mappings = cycle_config["parameter_mappings"]
463
+ for src_key, dst_key in mappings.items():
464
+ if src_key in current_results:
465
+ next_params[dst_key] = current_results[src_key]
466
+
467
+ # Preserve any initial parameters that aren't overridden
468
+ for key, value in current_params.items():
469
+ if key not in next_params:
470
+ next_params[key] = value
471
+
472
+ # Filter out None values to avoid validation errors
473
+ next_params = self._filter_none_values(next_params)
474
+
475
+ return next_params
476
+
385
477
  def _execute_cycle_group(
386
478
  self,
387
479
  workflow: Workflow,
@@ -18,7 +18,7 @@ class WorkflowVisualizer:
18
18
 
19
19
  def __init__(
20
20
  self,
21
- workflow: Workflow,
21
+ workflow: Workflow | None = None,
22
22
  node_colors: dict[str, str] | None = None,
23
23
  edge_colors: dict[str, str] | None = None,
24
24
  layout: str = "hierarchical",
@@ -26,7 +26,7 @@ class WorkflowVisualizer:
26
26
  """Initialize visualizer.
27
27
 
28
28
  Args:
29
- workflow: Workflow to visualize
29
+ workflow: Workflow to visualize (can be set later)
30
30
  node_colors: Custom node color map
31
31
  edge_colors: Custom edge color map
32
32
  layout: Layout algorithm to use
@@ -70,9 +70,12 @@ class WorkflowVisualizer:
70
70
  """Get colors for all nodes in workflow."""
71
71
  colors = []
72
72
  for node_id in self.workflow.graph.nodes():
73
- node_instance = self.workflow.nodes[node_id]
74
- node_type = node_instance.node_type
75
- colors.append(self._get_node_color(node_type))
73
+ node_instance = self.workflow.nodes.get(node_id)
74
+ if node_instance:
75
+ node_type = node_instance.node_type
76
+ colors.append(self._get_node_color(node_type))
77
+ else:
78
+ colors.append(self.node_colors["default"])
76
79
  return colors
77
80
 
78
81
  def _get_node_labels(self) -> dict[str, str]:
@@ -119,11 +122,17 @@ class WorkflowVisualizer:
119
122
 
120
123
  return edge_labels
121
124
 
122
- def _calculate_layout(self) -> dict[str, tuple[float, float]]:
125
+ def _calculate_layout(
126
+ self, workflow: "Workflow" = None
127
+ ) -> dict[str, tuple[float, float]]:
123
128
  """Calculate node positions for visualization."""
129
+ target_workflow = workflow or self.workflow
130
+ if not target_workflow:
131
+ return {}
132
+
124
133
  # Try to use stored positions first
125
134
  pos = {}
126
- for node_id, node_instance in self.workflow.nodes.items():
135
+ for node_id, node_instance in target_workflow.nodes.items():
127
136
  if node_instance.position != (0, 0):
128
137
  pos[node_id] = node_instance.position
129
138
 
@@ -133,32 +142,73 @@ class WorkflowVisualizer:
133
142
  # Use hierarchical layout for DAGs
134
143
  try:
135
144
  # Create layers based on topological order
136
- layers = self._create_layers()
145
+ layers = self._create_layers(target_workflow)
137
146
  pos = self._hierarchical_layout(layers)
138
147
  except Exception:
139
148
  # Fallback to spring layout
140
- pos = nx.spring_layout(self.workflow.graph, k=3, iterations=50)
149
+ pos = nx.spring_layout(target_workflow.graph, k=3, iterations=50)
141
150
  elif self.layout == "circular":
142
- pos = nx.circular_layout(self.workflow.graph)
151
+ pos = nx.circular_layout(target_workflow.graph)
143
152
  elif self.layout == "spring":
144
- pos = nx.spring_layout(self.workflow.graph, k=2, iterations=100)
153
+ pos = nx.spring_layout(target_workflow.graph, k=2, iterations=100)
145
154
  else:
146
155
  # Default to spring layout
147
- pos = nx.spring_layout(self.workflow.graph)
156
+ pos = nx.spring_layout(target_workflow.graph)
148
157
 
149
158
  return pos
150
159
 
151
- def _create_layers(self) -> dict[int, list]:
160
+ def _get_layout_positions(
161
+ self, workflow: Workflow
162
+ ) -> dict[str, tuple[float, float]]:
163
+ """Get layout positions for workflow nodes."""
164
+ # Temporarily store workflow and calculate layout
165
+ original_workflow = self.workflow
166
+ self.workflow = workflow
167
+ try:
168
+ return self._calculate_layout()
169
+ finally:
170
+ self.workflow = original_workflow
171
+
172
+ def _get_node_colors(self, workflow: Workflow) -> list[str]:
173
+ """Get node colors for workflow."""
174
+ colors = []
175
+ for node_id in workflow.graph.nodes():
176
+ node_instance = workflow.get_node(node_id)
177
+ if node_instance:
178
+ node_type = node_instance.__class__.__name__.lower()
179
+ # Map node types to color categories
180
+ if "data" in node_type or "csv" in node_type or "json" in node_type:
181
+ color_key = "data"
182
+ elif "transform" in node_type or "python" in node_type:
183
+ color_key = "transform"
184
+ elif "switch" in node_type or "merge" in node_type:
185
+ color_key = "logic"
186
+ elif "llm" in node_type or "ai" in node_type:
187
+ color_key = "ai"
188
+ else:
189
+ color_key = "default"
190
+ colors.append(
191
+ self.node_colors.get(color_key, self.node_colors["default"])
192
+ )
193
+ else:
194
+ colors.append(self.node_colors["default"])
195
+ return colors
196
+
197
+ def _create_layers(self, workflow: "Workflow" = None) -> dict[int, list]:
152
198
  """Create layers of nodes for hierarchical layout."""
199
+ target_workflow = workflow or self.workflow
200
+ if not target_workflow:
201
+ return {}
202
+
153
203
  layers = {}
154
- remaining = set(self.workflow.graph.nodes())
204
+ remaining = set(target_workflow.graph.nodes())
155
205
  layer = 0
156
206
 
157
207
  while remaining:
158
208
  # Find nodes with no dependencies in remaining set
159
209
  current_layer = []
160
210
  for node in remaining:
161
- predecessors = set(self.workflow.graph.predecessors(node))
211
+ predecessors = set(target_workflow.graph.predecessors(node))
162
212
  if not predecessors.intersection(remaining):
163
213
  current_layer.append(node)
164
214
 
@@ -196,20 +246,38 @@ class WorkflowVisualizer:
196
246
 
197
247
  def _draw_graph(
198
248
  self,
199
- pos: dict[str, tuple[float, float]],
200
- node_colors: list[str],
201
- show_labels: bool,
202
- show_connections: bool,
249
+ workflow: Workflow | None = None,
250
+ pos: dict[str, tuple[float, float]] | None = None,
251
+ node_colors: list[str] | None = None,
252
+ show_labels: bool = True,
253
+ show_connections: bool = True,
203
254
  ) -> None:
204
255
  """Draw the graph with given positions and options."""
256
+ # Use provided workflow or fall back to instance workflow
257
+ target_workflow = workflow or self.workflow
258
+ if not target_workflow:
259
+ raise ValueError("No workflow provided to draw")
260
+
261
+ # Use default position if not provided
262
+ if pos is None:
263
+ pos = self._get_layout_positions(target_workflow)
264
+
265
+ # Use default colors if not provided
266
+ if node_colors is None:
267
+ node_colors = self._get_node_colors(target_workflow)
268
+
205
269
  # Draw nodes
206
270
  nx.draw_networkx_nodes(
207
- self.workflow.graph, pos, node_color=node_colors, node_size=3000, alpha=0.9
271
+ target_workflow.graph,
272
+ pos,
273
+ node_color=node_colors,
274
+ node_size=3000,
275
+ alpha=0.9,
208
276
  )
209
277
 
210
278
  # Draw edges
211
279
  nx.draw_networkx_edges(
212
- self.workflow.graph,
280
+ target_workflow.graph,
213
281
  pos,
214
282
  edge_color=self.edge_colors["default"],
215
283
  width=2,
@@ -221,16 +289,26 @@ class WorkflowVisualizer:
221
289
 
222
290
  # Draw labels
223
291
  if show_labels:
292
+ # Temporarily set workflow for label generation
293
+ old_workflow = self.workflow
294
+ self.workflow = target_workflow
224
295
  labels = self._get_node_labels()
296
+ self.workflow = old_workflow
297
+
225
298
  nx.draw_networkx_labels(
226
- self.workflow.graph, pos, labels, font_size=10, font_weight="bold"
299
+ target_workflow.graph, pos, labels, font_size=10, font_weight="bold"
227
300
  )
228
301
 
229
302
  # Draw connection labels
230
303
  if show_connections:
304
+ # Temporarily set workflow for edge label generation
305
+ old_workflow = self.workflow
306
+ self.workflow = target_workflow
231
307
  edge_labels = self._get_edge_labels()
308
+ self.workflow = old_workflow
309
+
232
310
  nx.draw_networkx_edge_labels(
233
- self.workflow.graph, pos, edge_labels, font_size=8
311
+ target_workflow.graph, pos, edge_labels, font_size=8
234
312
  )
235
313
 
236
314
  def visualize(
@@ -255,10 +333,16 @@ class WorkflowVisualizer:
255
333
  **kwargs: Additional options passed to plt.savefig
256
334
  """
257
335
  try:
336
+ # Check if workflow is available
337
+ if not self.workflow:
338
+ raise ValueError(
339
+ "No workflow to visualize. Set workflow property or create visualizer with workflow."
340
+ )
341
+
258
342
  plt.figure(figsize=figsize)
259
343
 
260
344
  # Calculate node positions
261
- pos = self._calculate_layout()
345
+ pos = self._calculate_layout(self.workflow)
262
346
 
263
347
  # Handle empty workflow case
264
348
  if not self.workflow.graph.nodes():
@@ -266,11 +350,17 @@ class WorkflowVisualizer:
266
350
  node_colors = []
267
351
  else:
268
352
  # Draw the graph with colors
269
- node_colors = self._get_node_colors()
353
+ node_colors = self._get_node_colors(self.workflow)
270
354
 
271
355
  # Draw the graph components
272
- if pos and node_colors:
273
- self._draw_graph(pos, node_colors, show_labels, show_connections)
356
+ if pos:
357
+ self._draw_graph(
358
+ workflow=self.workflow,
359
+ pos=pos,
360
+ node_colors=node_colors,
361
+ show_labels=show_labels,
362
+ show_connections=show_connections,
363
+ )
274
364
 
275
365
  # Set title
276
366
  if title is None:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: kailash
3
- Version: 0.7.0
3
+ Version: 0.8.1
4
4
  Summary: Python SDK for the Kailash container-node architecture
5
5
  Home-page: https://github.com/integrum/kailash-python-sdk
6
6
  Author: Integrum
@@ -21,7 +21,7 @@ Requires-Dist: matplotlib>=3.5
21
21
  Requires-Dist: pyyaml>=6.0
22
22
  Requires-Dist: click>=8.0
23
23
  Requires-Dist: pytest>=8.3.5
24
- Requires-Dist: mcp[cli]>=1.9.2
24
+ Requires-Dist: mcp[cli]==1.11.0
25
25
  Requires-Dist: pandas>=2.2.3
26
26
  Requires-Dist: numpy>=2.2.5
27
27
  Requires-Dist: scipy>=1.15.3
@@ -58,6 +58,7 @@ Requires-Dist: python-jose>=3.5.0
58
58
  Requires-Dist: pytest-xdist>=3.6.0
59
59
  Requires-Dist: pytest-timeout>=2.3.0
60
60
  Requires-Dist: pytest-split>=0.9.0
61
+ Requires-Dist: pytest-forked>=1.6.0
61
62
  Requires-Dist: asyncpg>=0.30.0
62
63
  Requires-Dist: aiomysql>=0.2.0
63
64
  Requires-Dist: twilio>=9.6.3
@@ -78,6 +79,7 @@ Requires-Dist: passlib>=1.7.4
78
79
  Requires-Dist: pyotp>=2.9.0
79
80
  Requires-Dist: opentelemetry-instrumentation-fastapi>=0.55b1
80
81
  Requires-Dist: seaborn>=0.13.2
82
+ Requires-Dist: sqlparse>=0.5.3
81
83
  Provides-Extra: dev
82
84
  Requires-Dist: pytest>=7.0; extra == "dev"
83
85
  Requires-Dist: pytest-cov>=3.0; extra == "dev"