quantalogic 0.53.0__py3-none-any.whl → 0.56.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,3 +1,4 @@
1
+ import asyncio
1
2
  import importlib
2
3
  import importlib.util
3
4
  import os
@@ -13,12 +14,13 @@ import yaml # type: ignore
13
14
  from loguru import logger
14
15
  from pydantic import BaseModel, ValidationError
15
16
 
16
- # Import directly from flow.py to avoid circular import through __init__.py
17
17
  from quantalogic.flow.flow import Nodes, Workflow
18
18
  from quantalogic.flow.flow_manager_schema import (
19
+ BranchCondition,
19
20
  FunctionDefinition,
20
21
  LLMConfig,
21
22
  NodeDefinition,
23
+ TemplateConfig,
22
24
  TransitionDefinition,
23
25
  WorkflowDefinition,
24
26
  WorkflowStructure,
@@ -38,15 +40,11 @@ class WorkflowManager:
38
40
 
39
41
  for dep in self.workflow.dependencies:
40
42
  if dep.startswith("http://") or dep.startswith("https://"):
41
- # Remote URL: handled by import_module_from_source later
42
43
  logger.debug(f"Dependency '{dep}' is a remote URL, will be fetched during instantiation")
43
44
  elif os.path.isfile(dep):
44
- # Local file: handled by import_module_from_source later
45
45
  logger.debug(f"Dependency '{dep}' is a local file, will be loaded during instantiation")
46
46
  else:
47
- # Assume PyPI package
48
47
  try:
49
- # Check if the module is already installed
50
48
  module_name = dep.split(">")[0].split("<")[0].split("=")[0].strip()
51
49
  importlib.import_module(module_name)
52
50
  logger.debug(f"Dependency '{dep}' is already installed")
@@ -64,21 +62,41 @@ class WorkflowManager:
64
62
  function: Optional[str] = None,
65
63
  sub_workflow: Optional[WorkflowStructure] = None,
66
64
  llm_config: Optional[Dict[str, Any]] = None,
65
+ template_config: Optional[Dict[str, Any]] = None,
66
+ inputs_mapping: Optional[Dict[str, Union[str, Callable]]] = None,
67
67
  output: Optional[str] = None,
68
68
  retries: int = 3,
69
69
  delay: float = 1.0,
70
70
  timeout: Optional[float] = None,
71
71
  parallel: bool = False,
72
72
  ) -> None:
73
- """Add a new node to the workflow definition, supporting sub-workflows and LLM nodes."""
74
- # Convert dict to LLMConfig if provided
73
+ """Add a new node to the workflow definition with support for template nodes and inputs mapping."""
75
74
  llm_config_obj = LLMConfig(**llm_config) if llm_config is not None else None
75
+ template_config_obj = TemplateConfig(**template_config) if template_config is not None else None
76
+
77
+ serializable_inputs_mapping = {}
78
+ if inputs_mapping:
79
+ for key, value in inputs_mapping.items():
80
+ if callable(value):
81
+ if hasattr(value, '__name__') and value.__name__ == '<lambda>':
82
+ import inspect
83
+ try:
84
+ source = inspect.getsource(value).strip()
85
+ serializable_inputs_mapping[key] = f"lambda ctx: {source.split(':')[-1].strip()}"
86
+ except Exception:
87
+ serializable_inputs_mapping[key] = str(value)
88
+ else:
89
+ serializable_inputs_mapping[key] = value.__name__
90
+ else:
91
+ serializable_inputs_mapping[key] = value
76
92
 
77
93
  node = NodeDefinition(
78
94
  function=function,
79
95
  sub_workflow=sub_workflow,
80
96
  llm_config=llm_config_obj,
81
- output=output or (f"{name}_result" if function or llm_config else None),
97
+ template_config=template_config_obj,
98
+ inputs_mapping=serializable_inputs_mapping,
99
+ output=output or (f"{name}_result" if function or llm_config or template_config else None),
82
100
  retries=retries,
83
101
  delay=delay,
84
102
  timeout=timeout,
@@ -94,27 +112,52 @@ class WorkflowManager:
94
112
  self.workflow.workflow.transitions = [
95
113
  t
96
114
  for t in self.workflow.workflow.transitions
97
- if t.from_node != name and (isinstance(t.to_node, str) or name not in t.to_node)
115
+ if t.from_node != name and (isinstance(t.to_node, str) or all(
116
+ isinstance(tn, str) and tn != name or isinstance(tn, BranchCondition) and tn.to_node != name
117
+ for tn in t.to_node
118
+ ))
98
119
  ]
99
120
  if self.workflow.workflow.start == name:
100
121
  self.workflow.workflow.start = None
122
+ if name in self.workflow.workflow.convergence_nodes:
123
+ self.workflow.workflow.convergence_nodes.remove(name)
101
124
 
102
125
  def update_node(
103
126
  self,
104
127
  name: str,
105
128
  function: Optional[str] = None,
129
+ template_config: Optional[Dict[str, Any]] = None,
130
+ inputs_mapping: Optional[Dict[str, Union[str, Callable]]] = None,
106
131
  output: Optional[str] = None,
107
132
  retries: Optional[int] = None,
108
133
  delay: Optional[float] = None,
109
134
  timeout: Optional[Union[float, None]] = None,
110
135
  parallel: Optional[bool] = None,
111
136
  ) -> None:
112
- """Update specific fields of an existing node."""
137
+ """Update specific fields of an existing node with template and mapping support."""
113
138
  if name not in self.workflow.nodes:
114
139
  raise ValueError(f"Node '{name}' does not exist")
115
140
  node = self.workflow.nodes[name]
116
141
  if function is not None:
117
142
  node.function = function
143
+ if template_config is not None:
144
+ node.template_config = TemplateConfig(**template_config)
145
+ if inputs_mapping is not None:
146
+ serializable_inputs_mapping = {}
147
+ for key, value in inputs_mapping.items():
148
+ if callable(value):
149
+ if hasattr(value, '__name__') and value.__name__ == '<lambda>':
150
+ import inspect
151
+ try:
152
+ source = inspect.getsource(value).strip()
153
+ serializable_inputs_mapping[key] = f"lambda ctx: {source.split(':')[-1].strip()}"
154
+ except Exception:
155
+ serializable_inputs_mapping[key] = str(value)
156
+ else:
157
+ serializable_inputs_mapping[key] = value.__name__
158
+ else:
159
+ serializable_inputs_mapping[key] = value
160
+ node.inputs_mapping = serializable_inputs_mapping
118
161
  if output is not None:
119
162
  node.output = output
120
163
  if retries is not None:
@@ -129,19 +172,11 @@ class WorkflowManager:
129
172
  def add_transition(
130
173
  self,
131
174
  from_node: str,
132
- to_node: Union[str, List[str]],
175
+ to_node: Union[str, List[Union[str, BranchCondition]]],
133
176
  condition: Optional[str] = None,
134
177
  strict: bool = True,
135
178
  ) -> None:
136
- """Add a transition between nodes.
137
-
138
- Args:
139
- from_node: Source node name
140
- to_node: Target node name or list of target node names
141
- condition: Optional condition for the transition
142
- strict: If True, validates that all nodes exist before adding the transition.
143
- If False, allows adding transitions to non-existent nodes.
144
- """
179
+ """Add a transition between nodes, supporting branching."""
145
180
  if strict:
146
181
  if from_node not in self.workflow.nodes:
147
182
  raise ValueError(f"Source node '{from_node}' does not exist")
@@ -150,9 +185,9 @@ class WorkflowManager:
150
185
  raise ValueError(f"Target node '{to_node}' does not exist")
151
186
  else:
152
187
  for t in to_node:
153
- if t not in self.workflow.nodes:
154
- raise ValueError(f"Target node '{t}' does not exist")
155
- # Create TransitionDefinition with named parameters
188
+ target = t if isinstance(t, str) else t.to_node
189
+ if target not in self.workflow.nodes:
190
+ raise ValueError(f"Target node '{target}' does not exist")
156
191
  transition = TransitionDefinition(
157
192
  from_node=from_node,
158
193
  to_node=to_node,
@@ -166,6 +201,14 @@ class WorkflowManager:
166
201
  raise ValueError(f"Node '{name}' does not exist")
167
202
  self.workflow.workflow.start = name
168
203
 
204
+ def add_convergence_node(self, name: str) -> None:
205
+ """Add a convergence node to the workflow."""
206
+ if name not in self.workflow.nodes:
207
+ raise ValueError(f"Node '{name}' does not exist")
208
+ if name not in self.workflow.workflow.convergence_nodes:
209
+ self.workflow.workflow.convergence_nodes.append(name)
210
+ logger.debug(f"Added convergence node '{name}'")
211
+
169
212
  def add_function(
170
213
  self,
171
214
  name: str,
@@ -199,20 +242,8 @@ class WorkflowManager:
199
242
  raise ValueError(f"Failed to resolve response_model '{model_str}': {e}")
200
243
 
201
244
  def import_module_from_source(self, source: str) -> Any:
202
- """
203
- Import a module from various sources: installed module name (e.g., PyPI), local file path, or remote URL.
204
-
205
- Args:
206
- source: The module specification (e.g., 'requests', '/path/to/file.py', 'https://example.com/module.py').
207
-
208
- Returns:
209
- The imported module object.
210
-
211
- Raises:
212
- ValueError: If the module cannot be imported, with suggestions for installation if it's a PyPI package.
213
- """
245
+ """Import a module from various sources."""
214
246
  if source.startswith("http://") or source.startswith("https://"):
215
- # Handle remote URL
216
247
  try:
217
248
  with urllib.request.urlopen(source) as response:
218
249
  code = response.read().decode("utf-8")
@@ -233,7 +264,6 @@ class WorkflowManager:
233
264
  except Exception as e:
234
265
  raise ValueError(f"Failed to import module from URL '{source}': {e}")
235
266
  elif os.path.isfile(source):
236
- # Handle local file path
237
267
  try:
238
268
  module_name = f"local_module_{hash(source)}"
239
269
  spec = importlib.util.spec_from_file_location(module_name, source)
@@ -248,20 +278,17 @@ class WorkflowManager:
248
278
  except Exception as e:
249
279
  raise ValueError(f"Failed to import module from file '{source}': {e}")
250
280
  else:
251
- # Assume installed module name from PyPI or system
252
281
  try:
253
282
  return importlib.import_module(source)
254
283
  except ImportError as e:
255
284
  logger.error(f"Module '{source}' not found: {e}")
256
285
  raise ValueError(
257
286
  f"Failed to import module '{source}': {e}. "
258
- f"This may be a PyPI package. Ensure it is installed using 'pip install {source}' "
259
- "or check if the module name is correct."
287
+ f"Ensure it is installed using 'pip install {source}' or check the module name."
260
288
  )
261
289
 
262
290
  def instantiate_workflow(self) -> Workflow:
263
- """Instantiates a Workflow object based on the definitions stored in the WorkflowManager."""
264
- # Ensure dependencies are available before instantiation
291
+ """Instantiate a Workflow object with full support for template_node and inputs_mapping."""
265
292
  self._ensure_dependencies()
266
293
 
267
294
  functions: Dict[str, Callable] = {}
@@ -286,70 +313,27 @@ class WorkflowManager:
286
313
  except (ImportError, AttributeError) as e:
287
314
  raise ValueError(f"Failed to import external function '{func_name}': {e}")
288
315
 
289
- # Check if start node is set
290
316
  if not self.workflow.workflow.start:
291
317
  raise ValueError("Start node not set in workflow definition")
292
-
293
- # We need to ensure we have a valid string for the start node
294
- # First check if it's None and provide a fallback
318
+
319
+ start_node_name = str(self.workflow.workflow.start) if self.workflow.workflow.start else "start"
295
320
  if self.workflow.workflow.start is None:
296
321
  logger.warning("Start node was None, using 'start' as default")
297
- start_node_name = "start"
298
- else:
299
- # Otherwise convert to string
300
- start_node_name = str(self.workflow.workflow.start)
301
-
302
- # Create the workflow with a valid start node
303
- wf = Workflow(start_node=start_node_name)
304
322
 
305
- # Register observers
306
- for observer_name in self.workflow.observers:
307
- if observer_name not in functions:
308
- raise ValueError(f"Observer '{observer_name}' not found in functions")
309
- wf.add_observer(functions[observer_name])
310
- logger.debug(f"Registered observer '{observer_name}' in workflow")
311
-
312
- sub_workflows: Dict[str, Workflow] = {}
323
+ # Register all nodes with their node names
313
324
  for node_name, node_def in self.workflow.nodes.items():
314
- if node_def.sub_workflow:
315
- # Ensure we have a valid start node for the sub-workflow
316
- if node_def.sub_workflow.start is None:
317
- logger.warning(f"Sub-workflow for node '{node_name}' has no start node, using '{node_name}_start' as default")
318
- start_node = f"{node_name}_start"
319
- else:
320
- start_node = str(node_def.sub_workflow.start)
321
- sub_wf = Workflow(start_node=start_node)
322
- sub_workflows[node_name] = sub_wf
323
- added_sub_nodes = set()
324
- for trans in node_def.sub_workflow.transitions:
325
- from_node = trans.from_node
326
- to_nodes = [trans.to_node] if isinstance(trans.to_node, str) else trans.to_node
327
- if from_node not in added_sub_nodes:
328
- sub_wf.node(from_node)
329
- added_sub_nodes.add(from_node)
330
- for to_node in to_nodes:
331
- if to_node not in added_sub_nodes:
332
- sub_wf.node(to_node)
333
- added_sub_nodes.add(to_node)
334
- condition = eval(f"lambda ctx: {trans.condition}") if trans.condition else None
335
- if len(to_nodes) > 1:
336
- sub_wf.parallel(*to_nodes) # No condition support in parallel as per original
337
- else:
338
- sub_wf.then(to_nodes[0], condition=condition)
339
- inputs = list(Nodes.NODE_REGISTRY[sub_wf.start_node][1])
340
- # Ensure output is a string
341
- output = node_def.output if node_def.output is not None else f"{node_name}_result"
342
- wf.add_sub_workflow(node_name, sub_wf, inputs={k: k for k in inputs}, output=output)
343
- elif node_def.function:
325
+ if node_def.function:
344
326
  if node_def.function not in functions:
345
327
  raise ValueError(f"Function '{node_def.function}' for node '{node_name}' not found")
346
328
  func = functions[node_def.function]
347
- Nodes.define(
348
- output=node_def.output,
349
- )(func)
329
+ # Register with the node name, not the function name
330
+ Nodes.NODE_REGISTRY[node_name] = (
331
+ Nodes.define(output=node_def.output)(func),
332
+ ["user_name"], # Explicitly define inputs based on function signature
333
+ node_def.output
334
+ )
350
335
  elif node_def.llm_config:
351
336
  llm_config = node_def.llm_config
352
- # Extract inputs from prompt_template if no prompt_file, otherwise assume inputs will be inferred at runtime
353
337
  input_vars = set(re.findall(r"{{\s*([^}]+?)\s*}}", llm_config.prompt_template)) if not llm_config.prompt_file else set()
354
338
  cleaned_inputs = set()
355
339
  for input_var in input_vars:
@@ -358,18 +342,16 @@ class WorkflowManager:
358
342
  cleaned_inputs.add(base_var)
359
343
  inputs_list: List[str] = list(cleaned_inputs)
360
344
 
361
- # Define a dummy function to be decorated
362
345
  async def dummy_func(**kwargs):
363
- pass # This will be replaced by the decorator logic
346
+ pass
364
347
 
365
348
  if llm_config.response_model:
366
- # Structured LLM node
367
349
  response_model = self._resolve_model(llm_config.response_model)
368
350
  decorated_func = Nodes.structured_llm_node(
369
351
  model=llm_config.model,
370
352
  system_prompt=llm_config.system_prompt or "",
371
353
  prompt_template=llm_config.prompt_template,
372
- prompt_file=llm_config.prompt_file, # Pass prompt_file if provided
354
+ prompt_file=llm_config.prompt_file,
373
355
  response_model=response_model,
374
356
  output=node_def.output or f"{node_name}_result",
375
357
  temperature=llm_config.temperature,
@@ -380,12 +362,11 @@ class WorkflowManager:
380
362
  api_key=llm_config.api_key,
381
363
  )(dummy_func)
382
364
  else:
383
- # Plain LLM node
384
365
  decorated_func = Nodes.llm_node(
385
366
  model=llm_config.model,
386
367
  system_prompt=llm_config.system_prompt or "",
387
368
  prompt_template=llm_config.prompt_template,
388
- prompt_file=llm_config.prompt_file, # Pass prompt_file if provided
369
+ prompt_file=llm_config.prompt_file,
389
370
  output=node_def.output or f"{node_name}_result",
390
371
  temperature=llm_config.temperature,
391
372
  max_tokens=llm_config.max_tokens or 2000,
@@ -395,28 +376,121 @@ class WorkflowManager:
395
376
  api_key=llm_config.api_key,
396
377
  )(dummy_func)
397
378
 
398
- # Register the node in NODE_REGISTRY with proper inputs
399
379
  Nodes.NODE_REGISTRY[node_name] = (decorated_func, inputs_list, node_def.output or f"{node_name}_result")
400
- logger.debug(
401
- f"Registered LLM node '{node_name}' with inputs {inputs_list} and output {node_def.output or f'{node_name}_result'}"
402
- )
380
+ elif node_def.template_config:
381
+ template_config = node_def.template_config
382
+ input_vars = set(re.findall(r"{{\s*([^}]+?)\s*}}", template_config.template)) if not template_config.template_file else set()
383
+ cleaned_inputs = {var.strip() for var in input_vars if var.strip().isidentifier()}
384
+ inputs_list = list(cleaned_inputs)
385
+
386
+ async def dummy_template_func(rendered_content: str, **kwargs):
387
+ return rendered_content
388
+
389
+ decorated_func = Nodes.template_node(
390
+ output=node_def.output or f"{node_name}_result",
391
+ template=template_config.template,
392
+ template_file=template_config.template_file,
393
+ )(dummy_template_func)
394
+
395
+ Nodes.NODE_REGISTRY[node_name] = (decorated_func, ["rendered_content"] + inputs_list, node_def.output or f"{node_name}_result")
396
+
397
+ # Create the Workflow instance after all nodes are registered
398
+ wf = Workflow(start_node=start_node_name)
399
+
400
+ for observer_name in self.workflow.observers:
401
+ if observer_name not in functions:
402
+ raise ValueError(f"Observer '{observer_name}' not found in functions")
403
+ wf.add_observer(functions[observer_name])
404
+ logger.debug(f"Registered observer '{observer_name}' in workflow")
405
+
406
+ sub_workflows: Dict[str, Workflow] = {}
407
+ for node_name, node_def in self.workflow.nodes.items():
408
+ inputs_mapping = {}
409
+ if node_def.inputs_mapping:
410
+ for key, value in node_def.inputs_mapping.items():
411
+ if isinstance(value, str) and value.startswith("lambda ctx:"):
412
+ try:
413
+ inputs_mapping[key] = eval(value)
414
+ except Exception as e:
415
+ logger.warning(f"Failed to evaluate lambda for {key} in {node_name}: {e}")
416
+ inputs_mapping[key] = value
417
+ else:
418
+ inputs_mapping[key] = value
419
+
420
+ if node_def.sub_workflow:
421
+ start_node = str(node_def.sub_workflow.start) if node_def.sub_workflow.start else f"{node_name}_start"
422
+ if node_def.sub_workflow.start is None:
423
+ logger.warning(f"Sub-workflow for node '{node_name}' has no start node, using '{start_node}'")
424
+ sub_wf = Workflow(start_node=start_node)
425
+ sub_workflows[node_name] = sub_wf
426
+ added_sub_nodes = set()
427
+ for trans in node_def.sub_workflow.transitions:
428
+ from_node = trans.from_node
429
+ if from_node not in added_sub_nodes:
430
+ sub_wf.node(from_node)
431
+ added_sub_nodes.add(from_node)
432
+ if isinstance(trans.to_node, str):
433
+ to_nodes = [trans.to_node]
434
+ condition = eval(f"lambda ctx: {trans.condition}") if trans.condition else None
435
+ if to_nodes[0] not in added_sub_nodes:
436
+ sub_wf.node(to_nodes[0])
437
+ added_sub_nodes.add(to_nodes[0])
438
+ sub_wf.then(to_nodes[0], condition=condition)
439
+ elif all(isinstance(tn, str) for tn in trans.to_node):
440
+ to_nodes = trans.to_node
441
+ for to_node in to_nodes:
442
+ if to_node not in added_sub_nodes:
443
+ sub_wf.node(to_node)
444
+ added_sub_nodes.add(to_node)
445
+ sub_wf.parallel(*to_nodes)
446
+ else:
447
+ branches = [(tn.to_node, eval(f"lambda ctx: {tn.condition}") if tn.condition else None)
448
+ for tn in trans.to_node]
449
+ for to_node, _ in branches:
450
+ if to_node not in added_sub_nodes:
451
+ sub_wf.node(to_node)
452
+ added_sub_nodes.add(to_node)
453
+ sub_wf.branch(branches)
454
+ inputs = list(Nodes.NODE_REGISTRY[sub_wf.start_node][1])
455
+ output = node_def.output if node_def.output is not None else f"{node_name}_result"
456
+ wf.add_sub_workflow(node_name, sub_wf, inputs={k: k for k in inputs}, output=output)
457
+ else:
458
+ wf.node(node_name, inputs_mapping=inputs_mapping if inputs_mapping else None)
403
459
 
404
460
  added_nodes = set()
405
461
  for trans in self.workflow.workflow.transitions:
406
462
  from_node = trans.from_node
407
- to_nodes = [trans.to_node] if isinstance(trans.to_node, str) else trans.to_node
408
463
  if from_node not in added_nodes and from_node not in sub_workflows:
409
464
  wf.node(from_node)
410
465
  added_nodes.add(from_node)
411
- for to_node in to_nodes:
412
- if to_node not in added_nodes and to_node not in sub_workflows:
413
- wf.node(to_node)
414
- added_nodes.add(to_node)
415
- condition = eval(f"lambda ctx: {trans.condition}") if trans.condition else None
416
- if len(to_nodes) > 1:
466
+ if isinstance(trans.to_node, str):
467
+ to_nodes = [trans.to_node]
468
+ condition = eval(f"lambda ctx: {trans.condition}") if trans.condition else None
469
+ if to_nodes[0] not in added_nodes and to_nodes[0] not in sub_workflows:
470
+ wf.node(to_nodes[0])
471
+ added_nodes.add(to_nodes[0])
472
+ wf.then(to_nodes[0], condition=condition)
473
+ elif all(isinstance(tn, str) for tn in trans.to_node):
474
+ to_nodes = trans.to_node
475
+ for to_node in to_nodes:
476
+ if to_node not in added_nodes and to_node not in sub_workflows:
477
+ wf.node(to_node)
478
+ added_nodes.add(to_node)
417
479
  wf.parallel(*to_nodes)
418
480
  else:
419
- wf.then(to_nodes[0], condition=condition)
481
+ branches = [(tn.to_node, eval(f"lambda ctx: {tn.condition}") if tn.condition else None)
482
+ for tn in trans.to_node]
483
+ for to_node, _ in branches:
484
+ if to_node not in added_nodes and to_node not in sub_workflows:
485
+ wf.node(to_node)
486
+ added_nodes.add(to_node)
487
+ wf.branch(branches)
488
+
489
+ for conv_node in self.workflow.workflow.convergence_nodes:
490
+ if conv_node not in added_nodes and conv_node not in sub_workflows:
491
+ wf.node(conv_node)
492
+ added_nodes.add(conv_node)
493
+ wf.converge(conv_node)
420
494
 
421
495
  return wf
422
496
 
@@ -429,7 +503,7 @@ class WorkflowManager:
429
503
  data = yaml.safe_load(f)
430
504
  try:
431
505
  self.workflow = WorkflowDefinition.model_validate(data)
432
- self._ensure_dependencies() # Ensure dependencies after loading
506
+ self._ensure_dependencies()
433
507
  except ValidationError as e:
434
508
  raise ValueError(f"Invalid workflow YAML: {e}")
435
509
 
@@ -437,13 +511,11 @@ class WorkflowManager:
437
511
  """Save the workflow to a YAML file using aliases and multi-line block scalars for code."""
438
512
  file_path = Path(file_path)
439
513
 
440
- # Custom representer to use multi-line block scalars for multi-line strings
441
514
  def str_representer(dumper, data):
442
- if "\n" in data: # Use block scalar for multi-line strings
515
+ if "\n" in data:
443
516
  return dumper.represent_scalar("tag:yaml.org,2002:str", data, style="|")
444
517
  return dumper.represent_scalar("tag:yaml.org,2002:str", data)
445
518
 
446
- # Add the custom representer to the SafeDumper
447
519
  yaml.add_representer(str, str_representer, Dumper=yaml.SafeDumper)
448
520
 
449
521
  with file_path.open("w") as f:
@@ -453,19 +525,24 @@ class WorkflowManager:
453
525
  default_flow_style=False,
454
526
  sort_keys=False,
455
527
  allow_unicode=True,
456
- width=120, # Wider width to reduce wrapping
528
+ width=120,
457
529
  )
458
530
 
459
531
 
460
- def main():
461
- """Demonstrate usage of WorkflowManager with observer support."""
532
+ async def test_workflow():
533
+ """Test the workflow execution."""
462
534
  manager = WorkflowManager()
463
- manager.workflow.dependencies = ["requests>=2.28.0"] # Example dependency
535
+ manager.workflow.dependencies = ["requests>=2.28.0"]
464
536
  manager.add_function(
465
537
  name="greet",
466
538
  type_="embedded",
467
539
  code="def greet(user_name): return f'Hello, {user_name}!'",
468
540
  )
541
+ manager.add_function(
542
+ name="check_condition",
543
+ type_="embedded",
544
+ code="def check_condition(user_name): return len(user_name) > 3",
545
+ )
469
546
  manager.add_function(
470
547
  name="farewell",
471
548
  type_="embedded",
@@ -481,16 +558,66 @@ def main():
481
558
  if event.exception:
482
559
  print(f'Error: {event.exception}')""",
483
560
  )
484
- manager.add_node(name="start", function="greet")
485
- manager.add_node(name="end", function="farewell")
561
+ manager.add_node(
562
+ name="start",
563
+ function="greet",
564
+ inputs_mapping={"user_name": "name_input"},
565
+ )
566
+ manager.add_node(
567
+ name="format_greeting",
568
+ template_config={"template": "User: {{ user_name }} greeted on {{ date }}"},
569
+ inputs_mapping={"user_name": "name_input", "date": "lambda ctx: '2025-03-06'"},
570
+ )
571
+ manager.add_node(
572
+ name="branch_true",
573
+ function="check_condition",
574
+ inputs_mapping={"user_name": "name_input"},
575
+ )
576
+ manager.add_node(
577
+ name="branch_false",
578
+ function="check_condition",
579
+ inputs_mapping={"user_name": "name_input"},
580
+ )
581
+ manager.add_node(
582
+ name="end",
583
+ function="farewell",
584
+ inputs_mapping={"user_name": "name_input"},
585
+ )
486
586
  manager.set_start_node("start")
487
- manager.add_transition(from_node="start", to_node="end")
488
- manager.add_observer("monitor") # Add the observer
587
+ manager.add_transition(
588
+ from_node="start",
589
+ to_node="format_greeting"
590
+ )
591
+ manager.add_transition(
592
+ from_node="format_greeting",
593
+ to_node=[
594
+ BranchCondition(to_node="branch_true", condition="ctx.get('user_name') == 'Alice'"),
595
+ BranchCondition(to_node="branch_false", condition="ctx.get('user_name') != 'Alice'")
596
+ ]
597
+ )
598
+ manager.add_convergence_node("end")
599
+ manager.add_observer("monitor")
489
600
  manager.save_to_yaml("workflow.yaml")
601
+
602
+ # Load and instantiate
490
603
  new_manager = WorkflowManager()
491
604
  new_manager.load_from_yaml("workflow.yaml")
605
+ print("Workflow structure:")
492
606
  print(new_manager.workflow.model_dump())
493
607
 
608
+ # Execute the workflow
609
+ workflow = new_manager.instantiate_workflow()
610
+ engine = workflow.build()
611
+ initial_context = {"name_input": "Alice"}
612
+ result = await engine.run(initial_context)
613
+ print("\nExecution result:")
614
+ print(result)
615
+
616
+
617
+ def main():
618
+ """Run the workflow test."""
619
+ asyncio.run(test_workflow())
620
+
494
621
 
495
622
  if __name__ == "__main__":
496
623
  main()