quantalogic 0.53.0__py3-none-any.whl → 0.56.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
--- a/quantalogic/flow/flow_extractor.py
+++ b/quantalogic/flow/flow_extractor.py
@@ -4,10 +4,12 @@ import os
 from loguru import logger
 
 from quantalogic.flow.flow_generator import generate_executable_script  # Import from flow_generator
-from quantalogic.flow.flow_manager import WorkflowManager  # Added for YAML saving
+from quantalogic.flow.flow_manager import WorkflowManager  # For YAML saving
 from quantalogic.flow.flow_manager_schema import (
+    BranchCondition,
     FunctionDefinition,
     NodeDefinition,
+    TemplateConfig,
     TransitionDefinition,
     WorkflowDefinition,
     WorkflowStructure,
@@ -19,17 +21,19 @@ class WorkflowExtractor(ast.NodeVisitor):
     AST visitor to extract workflow nodes and structure from a Python file.
 
     This class parses Python source code to identify workflow components defined with Nodes decorators
-    and Workflow construction, building a WorkflowDefinition compatible with WorkflowManager.
+    and Workflow construction, including branch and converge patterns, building a WorkflowDefinition
+    compatible with WorkflowManager. Fully supports input mappings and template nodes.
     """
 
     def __init__(self):
         """Initialize the extractor with empty collections for workflow components."""
         self.nodes = {}  # Maps node names to their definitions
         self.functions = {}  # Maps function names to their code
-        self.transitions = []  # List of (from_node, to_node, condition) tuples
+        self.transitions = []  # List of TransitionDefinition objects
         self.start_node = None  # Starting node of the workflow
         self.global_vars = {}  # Tracks global variable assignments (e.g., DEFAULT_LLM_PARAMS)
         self.observers = []  # List of observer function names
+        self.convergence_nodes = []  # List of convergence nodes
 
     def visit_Module(self, node):
         """Log and explicitly process top-level statements in the module."""
@@ -58,7 +62,6 @@ class WorkflowExtractor(ast.NodeVisitor):
                     if isinstance(v, ast.Constant):
                         self.global_vars[var_name][key] = v.value
                     elif isinstance(v, ast.Name) and v.id in self.global_vars:
-                        # Resolve variable references to previously defined globals
                         self.global_vars[var_name][key] = self.global_vars[v.id]
             logger.debug(
                 f"Captured global variable '{var_name}' with keys: {list(self.global_vars[var_name].keys())}"
@@ -85,7 +88,6 @@ class WorkflowExtractor(ast.NodeVisitor):
             kwargs = {}
             logger.debug(f"Examining decorator for '{node.name}': {ast.dump(decorator)}")
 
-            # Handle simple decorators (e.g., @Nodes.define)
             if (
                 isinstance(decorator, ast.Attribute)
                 and isinstance(decorator.value, ast.Name)
@@ -94,7 +96,6 @@ class WorkflowExtractor(ast.NodeVisitor):
                 decorator_name = decorator.attr
                 logger.debug(f"Found simple decorator 'Nodes.{decorator_name}' for '{node.name}'")
 
-            # Handle decorators with arguments (e.g., @Nodes.llm_node(...))
            elif (
                 isinstance(decorator, ast.Call)
                 and isinstance(decorator.func, ast.Attribute)
@@ -113,8 +114,9 @@ class WorkflowExtractor(ast.NodeVisitor):
                         kwargs[kw.arg] = kw.value.value
                     elif kw.arg == "response_model" and isinstance(kw.value, ast.Name):
                         kwargs[kw.arg] = ast.unparse(kw.value)
+                    elif kw.arg == "transformer" and isinstance(kw.value, ast.Lambda):
+                        kwargs[kw.arg] = ast.unparse(kw.value)
 
-            # Process recognized decorators
             if decorator_name:
                 func_name = node.name
                 inputs = [arg.arg for arg in node.args.args]
@@ -127,6 +129,7 @@ class WorkflowExtractor(ast.NodeVisitor):
                         "inputs": inputs,
                         "output": output,
                     }
+                    logger.debug(f"Registered function node '{func_name}' with output '{output}'")
                 elif decorator_name == "llm_node":
                     llm_config = {
                         key: value
@@ -135,7 +138,7 @@ class WorkflowExtractor(ast.NodeVisitor):
                             "model",
                             "system_prompt",
                             "prompt_template",
-                            "prompt_file",  # Added to support external Jinja2 files
+                            "prompt_file",
                             "temperature",
                             "max_tokens",
                             "top_p",
@@ -150,6 +153,7 @@ class WorkflowExtractor(ast.NodeVisitor):
                         "inputs": inputs,
                         "output": llm_config.get("output"),
                     }
+                    logger.debug(f"Registered LLM node '{func_name}' with model '{llm_config.get('model')}'")
                 elif decorator_name == "validate_node":
                     output = kwargs.get("output")
                     self.nodes[func_name] = {
@@ -158,6 +162,7 @@ class WorkflowExtractor(ast.NodeVisitor):
                         "inputs": inputs,
                         "output": output,
                     }
+                    logger.debug(f"Registered validate node '{func_name}' with output '{output}'")
                 elif decorator_name == "structured_llm_node":
                     llm_config = {
                         key: value
@@ -166,7 +171,7 @@ class WorkflowExtractor(ast.NodeVisitor):
                             "model",
                             "system_prompt",
                             "prompt_template",
-                            "prompt_file",  # Added to support external Jinja2 files
+                            "prompt_file",
                             "temperature",
                             "max_tokens",
                             "top_p",
@@ -182,10 +187,33 @@ class WorkflowExtractor(ast.NodeVisitor):
                         "inputs": inputs,
                         "output": llm_config.get("output"),
                     }
+                    logger.debug(f"Registered structured LLM node '{func_name}' with model '{llm_config.get('model')}'")
+                elif decorator_name == "template_node":
+                    template_config = {
+                        "template": kwargs.get("template", ""),
+                        "template_file": kwargs.get("template_file"),
+                    }
+                    if "rendered_content" not in inputs:
+                        inputs.insert(0, "rendered_content")
+                    self.nodes[func_name] = {
+                        "type": "template",
+                        "template_config": template_config,
+                        "inputs": inputs,
+                        "output": kwargs.get("output"),
+                    }
+                    logger.debug(f"Registered template node '{func_name}' with config: {template_config}")
+                elif decorator_name == "transform_node":
+                    output = kwargs.get("output")
+                    self.nodes[func_name] = {
+                        "type": "function",
+                        "function": func_name,
+                        "inputs": inputs,
+                        "output": output,
+                    }
+                    logger.debug(f"Registered transform node '{func_name}' with output '{output}'")
                 else:
                     logger.warning(f"Unsupported decorator 'Nodes.{decorator_name}' in function '{func_name}'")
 
-                # Store the function code as embedded
                 func_code = ast.unparse(node)
                 self.functions[func_name] = {
                     "type": "embedded",
@@ -204,7 +232,6 @@ class WorkflowExtractor(ast.NodeVisitor):
             kwargs = {}
             logger.debug(f"Examining decorator for '{node.name}': {ast.dump(decorator)}")
 
-            # Handle simple decorators (e.g., @Nodes.define)
             if (
                 isinstance(decorator, ast.Attribute)
                 and isinstance(decorator.value, ast.Name)
@@ -213,7 +240,6 @@ class WorkflowExtractor(ast.NodeVisitor):
                 decorator_name = decorator.attr
                 logger.debug(f"Found simple decorator 'Nodes.{decorator_name}' for '{node.name}'")
 
-            # Handle decorators with arguments (e.g., @Nodes.llm_node(...))
             elif (
                 isinstance(decorator, ast.Call)
                 and isinstance(decorator.func, ast.Attribute)
@@ -232,8 +258,9 @@ class WorkflowExtractor(ast.NodeVisitor):
                         kwargs[kw.arg] = kw.value.value
                     elif kw.arg == "response_model" and isinstance(kw.value, ast.Name):
                         kwargs[kw.arg] = ast.unparse(kw.value)
+                    elif kw.arg == "transformer" and isinstance(kw.value, ast.Lambda):
+                        kwargs[kw.arg] = ast.unparse(kw.value)
 
-            # Process recognized decorators
             if decorator_name:
                 func_name = node.name
                 inputs = [arg.arg for arg in node.args.args]
@@ -246,6 +273,7 @@ class WorkflowExtractor(ast.NodeVisitor):
                         "inputs": inputs,
                         "output": output,
                     }
+                    logger.debug(f"Registered function node '{func_name}' with output '{output}'")
                 elif decorator_name == "llm_node":
                     llm_config = {
                         key: value
@@ -254,7 +282,7 @@ class WorkflowExtractor(ast.NodeVisitor):
                             "model",
                             "system_prompt",
                             "prompt_template",
-                            "prompt_file",  # Added to support external Jinja2 files
+                            "prompt_file",
                             "temperature",
                             "max_tokens",
                             "top_p",
@@ -269,6 +297,7 @@ class WorkflowExtractor(ast.NodeVisitor):
                         "inputs": inputs,
                         "output": llm_config.get("output"),
                     }
+                    logger.debug(f"Registered LLM node '{func_name}' with model '{llm_config.get('model')}'")
                 elif decorator_name == "validate_node":
                     output = kwargs.get("output")
                     self.nodes[func_name] = {
@@ -277,6 +306,7 @@ class WorkflowExtractor(ast.NodeVisitor):
                         "inputs": inputs,
                         "output": output,
                     }
+                    logger.debug(f"Registered validate node '{func_name}' with output '{output}'")
                 elif decorator_name == "structured_llm_node":
                     llm_config = {
                         key: value
@@ -285,7 +315,7 @@ class WorkflowExtractor(ast.NodeVisitor):
                             "model",
                             "system_prompt",
                             "prompt_template",
-                            "prompt_file",  # Added to support external Jinja2 files
+                            "prompt_file",
                             "temperature",
                             "max_tokens",
                             "top_p",
@@ -301,10 +331,33 @@ class WorkflowExtractor(ast.NodeVisitor):
                         "inputs": inputs,
                         "output": llm_config.get("output"),
                     }
+                    logger.debug(f"Registered structured LLM node '{func_name}' with model '{llm_config.get('model')}'")
+                elif decorator_name == "template_node":
+                    template_config = {
+                        "template": kwargs.get("template", ""),
+                        "template_file": kwargs.get("template_file"),
+                    }
+                    if "rendered_content" not in inputs:
+                        inputs.insert(0, "rendered_content")
+                    self.nodes[func_name] = {
+                        "type": "template",
+                        "template_config": template_config,
+                        "inputs": inputs,
+                        "output": kwargs.get("output"),
+                    }
+                    logger.debug(f"Registered template node '{func_name}' with config: {template_config}")
+                elif decorator_name == "transform_node":
+                    output = kwargs.get("output")
+                    self.nodes[func_name] = {
+                        "type": "function",
+                        "function": func_name,
+                        "inputs": inputs,
+                        "output": output,
+                    }
+                    logger.debug(f"Registered transform node '{func_name}' with output '{output}'")
                 else:
                     logger.warning(f"Unsupported decorator 'Nodes.{decorator_name}' in function '{func_name}'")
 
-                # Store the function code as embedded
                 func_code = ast.unparse(node)
                 self.functions[func_name] = {
                     "type": "embedded",
@@ -346,68 +399,110 @@ class WorkflowExtractor(ast.NodeVisitor):
                 next_node = expr.args[0].value if expr.args else None
                 condition = None
                 for keyword in expr.keywords:
-                    if keyword.arg == "condition":
-                        if isinstance(keyword.value, ast.Lambda):
-                            condition = ast.unparse(keyword.value)
-                        else:
-                            condition = ast.unparse(keyword.value)
-                            logger.warning(
-                                f"Non-lambda condition in 'then' for '{next_node}' may not be fully supported"
-                            )
+                    if keyword.arg == "condition" and keyword.value:
+                        condition = ast.unparse(keyword.value)
                 if previous_node and next_node:
-                    self.transitions.append((previous_node, next_node, condition))
+                    self.transitions.append(TransitionDefinition(from_node=previous_node, to_node=next_node, condition=condition))
                     logger.debug(f"Added transition: {previous_node} -> {next_node} (condition: {condition})")
                 return next_node
 
             elif method_name == "sequence":
                 nodes = [arg.value for arg in expr.args]
-                if previous_node:
-                    self.transitions.append((previous_node, nodes[0], None))
+                if previous_node and nodes:
+                    self.transitions.append(TransitionDefinition(from_node=previous_node, to_node=nodes[0]))
+                    logger.debug(f"Added sequence start transition: {previous_node} -> {nodes[0]}")
                 for i in range(len(nodes) - 1):
-                    self.transitions.append((nodes[i], nodes[i + 1], None))
+                    self.transitions.append(TransitionDefinition(from_node=nodes[i], to_node=nodes[i + 1]))
                     logger.debug(f"Added sequence transition: {nodes[i]} -> {nodes[i + 1]}")
                 return nodes[-1] if nodes else previous_node
 
             elif method_name == "parallel":
                 to_nodes = [arg.value for arg in expr.args]
                 if previous_node:
-                    for to_node in to_nodes:
-                        self.transitions.append((previous_node, to_node, None))
-                        logger.debug(f"Added parallel transition: {previous_node} -> {to_node}")
-                return None  # Parallel transitions reset the current node
+                    self.transitions.append(TransitionDefinition(from_node=previous_node, to_node=to_nodes))
+                    logger.debug(f"Added parallel transition: {previous_node} -> {to_nodes}")
+                return None
+
+            elif method_name == "branch":
+                branches = []
+                if expr.args and isinstance(expr.args[0], ast.List):
+                    for elt in expr.args[0].elts:
+                        if isinstance(elt, ast.Tuple) and len(elt.elts) == 2:
+                            to_node = elt.elts[0].value
+                            cond = ast.unparse(elt.elts[1]) if elt.elts[1] else None
+                            branches.append(BranchCondition(to_node=to_node, condition=cond))
+                            logger.debug(f"Added branch: {previous_node} -> {to_node} (condition: {cond})")
+                if previous_node and branches:
+                    self.transitions.append(TransitionDefinition(from_node=previous_node, to_node=branches))
+                return None
+
+            elif method_name == "converge":
+                conv_node = expr.args[0].value if expr.args else None
+                if conv_node and conv_node not in self.convergence_nodes:
+                    self.convergence_nodes.append(conv_node)
+                    logger.debug(f"Added convergence node: {conv_node}")
+                return conv_node
 
             elif method_name == "node":
                 node_name = expr.args[0].value if expr.args else None
-                if node_name and previous_node:
-                    self.transitions.append((previous_node, node_name, None))
-                    logger.debug(f"Added node transition: {previous_node} -> {node_name}")
+                inputs_mapping = None
+                for keyword in expr.keywords:
+                    if keyword.arg == "inputs_mapping" and isinstance(keyword.value, ast.Dict):
+                        inputs_mapping = {}
+                        for k, v in zip(keyword.value.keys, keyword.value.values):
+                            key = k.value if isinstance(k, ast.Constant) else ast.unparse(k)
+                            if isinstance(v, ast.Constant):
+                                inputs_mapping[key] = v.value
+                            elif isinstance(v, ast.Lambda):
+                                inputs_mapping[key] = f"lambda ctx: {ast.unparse(v.body)}"
+                            else:
+                                inputs_mapping[key] = ast.unparse(v)
+                if node_name:
+                    if node_name in self.nodes and inputs_mapping:
+                        self.nodes[node_name]["inputs_mapping"] = inputs_mapping
+                        logger.debug(f"Added inputs_mapping to node '{node_name}': {inputs_mapping}")
+                    if previous_node:
+                        self.transitions.append(TransitionDefinition(from_node=previous_node, to_node=node_name))
+                        logger.debug(f"Added node transition: {previous_node} -> {node_name}")
                 return node_name
 
             elif method_name == "add_sub_workflow":
-                sub_wf_name = expr.args[0].value
-                sub_wf_obj = expr.args[1]
+                sub_wf_name = expr.args[0].value if expr.args else None
+                sub_wf_obj = expr.args[1] if len(expr.args) > 1 else None
                 inputs = {}
+                inputs_mapping = None
+                output = None
                 if len(expr.args) > 2 and isinstance(expr.args[2], ast.Dict):
-                    inputs = {k.value: v.value for k, v in zip(expr.args[2].keys, expr.args[2].values)}
-                output = expr.args[3].value if len(expr.args) > 3 else None
-                sub_extractor = WorkflowExtractor()
-                sub_extractor.process_workflow_expr(sub_wf_obj, f"{var_name}_{sub_wf_name}")
-                self.nodes[sub_wf_name] = {
-                    "type": "sub_workflow",
-                    "sub_workflow": WorkflowStructure(
-                        start=sub_extractor.start_node,
-                        transitions=[
-                            TransitionDefinition(from_node=t[0], to_node=t[1], condition=t[2]) for t in sub_extractor.transitions
-                        ],
-                    ),
-                    "inputs": list(inputs.keys()),
-                    "output": output,
-                }
-                # Propagate observers from sub-workflow
-                self.observers.extend(sub_extractor.observers)
-                logger.debug(f"Added sub-workflow node '{sub_wf_name}' with start '{sub_extractor.start_node}'")
-                if previous_node:
-                    self.transitions.append((previous_node, sub_wf_name, None))
+                    inputs_mapping = {}
+                    for k, v in zip(expr.args[2].keys, expr.args[2].values):
+                        key = k.value if isinstance(k, ast.Constant) else ast.unparse(k)
+                        if isinstance(v, ast.Constant):
+                            inputs_mapping[key] = v.value
+                        elif isinstance(v, ast.Lambda):
+                            inputs_mapping[key] = f"lambda ctx: {ast.unparse(v.body)}"
+                        else:
+                            inputs_mapping[key] = ast.unparse(v)
+                    inputs = list(inputs_mapping.keys())
+                if len(expr.args) > 3:
+                    output = expr.args[3].value
+                if sub_wf_name and sub_wf_obj:
+                    sub_extractor = WorkflowExtractor()
+                    sub_extractor.process_workflow_expr(sub_wf_obj, f"{var_name}_{sub_wf_name}")
+                    self.nodes[sub_wf_name] = {
+                        "type": "sub_workflow",
+                        "sub_workflow": WorkflowStructure(
+                            start=sub_extractor.start_node,
+                            transitions=sub_extractor.transitions,
+                            convergence_nodes=sub_extractor.convergence_nodes,
+                        ),
+                        "inputs": inputs,
+                        "inputs_mapping": inputs_mapping,
+                        "output": output,
+                    }
+                    self.observers.extend(sub_extractor.observers)
+                    logger.debug(f"Added sub-workflow node '{sub_wf_name}' with start '{sub_extractor.start_node}' and inputs_mapping: {inputs_mapping}")
+                    if previous_node:
+                        self.transitions.append(TransitionDefinition(from_node=previous_node, to_node=sub_wf_name))
                 return sub_wf_name
 
             elif method_name == "add_observer":
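Taken together, the new `branch`, `converge`, and `inputs_mapping` handling lets the visitor decompose a chain like the following sketch (node names, conditions, and mapping values are invented; the method names and argument shapes are the ones parsed above):

```python
from quantalogic.flow import Workflow  # assumed import path

workflow = (
    Workflow("fetch")
    # The dict literal is captured as inputs_mapping on the "classify" node:
    # constants pass through, lambdas are unparsed to "lambda ctx: ..." strings.
    .node("classify", inputs_mapping={"text": "raw_text", "lang": lambda ctx: ctx["language"]})
    # Parsed into a single TransitionDefinition whose to_node is a list of
    # BranchCondition objects, one per (target, condition) tuple.
    .branch([
        ("summarize", lambda ctx: ctx["kind"] == "article"),
        ("transcribe", lambda ctx: ctx["kind"] == "audio"),
    ])
    # Appended to convergence_nodes and returned as the new current node.
    .converge("publish")
)
```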
@@ -435,37 +530,34 @@ def extract_workflow_from_file(file_path):
     Returns:
         tuple: (WorkflowDefinition, Dict[str, Any]) - The workflow definition and captured global variables.
     """
-    # Read and parse the file
     with open(file_path) as f:
         source = f.read()
     tree = ast.parse(source)
 
-    # Extract workflow components
     extractor = WorkflowExtractor()
     extractor.visit(tree)
 
-    # Construct FunctionDefinition objects
     functions = {name: FunctionDefinition(**func) for name, func in extractor.functions.items()}
 
-    # Construct NodeDefinition objects
     nodes = {}
-    from quantalogic.flow.flow_manager_schema import LLMConfig  # Import LLMConfig explicitly
+    from quantalogic.flow.flow_manager_schema import LLMConfig
 
     for name, node_info in extractor.nodes.items():
         if node_info["type"] == "function":
             nodes[name] = NodeDefinition(
                 function=node_info["function"],
+                inputs_mapping=node_info.get("inputs_mapping"),
                 output=node_info["output"],
-                retries=3,  # Default values
+                retries=3,
                 delay=1.0,
                 timeout=None,
                 parallel=False,
             )
         elif node_info["type"] == "llm":
-            # Convert llm_config dictionary to LLMConfig object to ensure model is preserved
             llm_config = LLMConfig(**node_info["llm_config"])
             nodes[name] = NodeDefinition(
                 llm_config=llm_config,
+                inputs_mapping=node_info.get("inputs_mapping"),
                 output=node_info["output"],
                 retries=3,
                 delay=1.0,
@@ -473,10 +565,21 @@ def extract_workflow_from_file(file_path):
                 parallel=False,
             )
         elif node_info["type"] == "structured_llm":
-            # Convert llm_config dictionary to LLMConfig object for structured LLM
             llm_config = LLMConfig(**node_info["llm_config"])
             nodes[name] = NodeDefinition(
                 llm_config=llm_config,
+                inputs_mapping=node_info.get("inputs_mapping"),
+                output=node_info["output"],
+                retries=3,
+                delay=1.0,
+                timeout=None,
+                parallel=False,
+            )
+        elif node_info["type"] == "template":
+            template_config = TemplateConfig(**node_info["template_config"])
+            nodes[name] = NodeDefinition(
+                template_config=template_config,
+                inputs_mapping=node_info.get("inputs_mapping"),
                 output=node_info["output"],
                 retries=3,
                 delay=1.0,
@@ -486,6 +589,7 @@ def extract_workflow_from_file(file_path):
         elif node_info["type"] == "sub_workflow":
             nodes[name] = NodeDefinition(
                 sub_workflow=node_info["sub_workflow"],
+                inputs_mapping=node_info.get("inputs_mapping"),
                 output=node_info["output"],
                 retries=3,
                 delay=1.0,
@@ -493,18 +597,30 @@ def extract_workflow_from_file(file_path):
                 parallel=False,
             )
 
-    # Construct TransitionDefinition objects
-    transitions = [
-        TransitionDefinition(from_node=from_node, to_node=to_node, condition=cond)
-        for from_node, to_node, cond in extractor.transitions
-    ]
-
-    # Build WorkflowStructure
-    workflow_structure = WorkflowStructure(start=extractor.start_node, transitions=transitions)
+    # Optional: Deduplicate transitions (uncomment if desired)
+    # seen = set()
+    # unique_transitions = []
+    # for t in extractor.transitions:
+    #     key = (t.from_node, str(t.to_node), t.condition)
+    #     if key not in seen:
+    #         seen.add(key)
+    #         unique_transitions.append(t)
+    # workflow_structure = WorkflowStructure(
+    #     start=extractor.start_node,
+    #     transitions=unique_transitions,
+    #     convergence_nodes=extractor.convergence_nodes,
+    # )
+    workflow_structure = WorkflowStructure(
+        start=extractor.start_node,
+        transitions=extractor.transitions,
+        convergence_nodes=extractor.convergence_nodes,
+    )
 
-    # Assemble WorkflowDefinition with observers
     workflow_def = WorkflowDefinition(
-        functions=functions, nodes=nodes, workflow=workflow_structure, observers=extractor.observers
+        functions=functions,
+        nodes=nodes,
+        workflow=workflow_structure,
+        observers=extractor.observers,
     )
 
     return workflow_def, extractor.global_vars
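A short usage sketch of the updated entry point (the file path is hypothetical; the attributes walked are exactly the ones constructed above):

```python
from quantalogic.flow.flow_extractor import extract_workflow_from_file

workflow_def, global_vars = extract_workflow_from_file("my_workflow.py")
print(workflow_def.workflow.start)
print(workflow_def.workflow.convergence_nodes)  # filled by .converge(...) calls
for trans in workflow_def.workflow.transitions:
    # to_node is a node name, a list of node names (parallel),
    # or a list of BranchCondition objects (branch).
    print(trans.from_node, "->", trans.to_node)
```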
@@ -538,23 +654,37 @@ def print_workflow_definition(workflow_def):
             print(" Type: LLM")
             print(f" Model: {node.llm_config.model}")
             print(f" Prompt Template: {node.llm_config.prompt_template}")
-            if node.llm_config.prompt_file:  # Added to display external prompt file if present
+            if node.llm_config.prompt_file:
                 print(f" Prompt File: {node.llm_config.prompt_file}")
+        elif node.template_config:
+            print(" Type: Template")
+            print(f" Template: {node.template_config.template}")
+            if node.template_config.template_file:
+                print(f" Template File: {node.template_config.template_file}")
         elif node.sub_workflow:
             print(" Type: Sub-Workflow")
             print(f" Start Node: {node.sub_workflow.start}")
+        if node.inputs_mapping:
+            print(f" Inputs Mapping: {node.inputs_mapping}")
         print(f" Output: {node.output or 'None'}")
 
     print("\n#### Workflow Structure:")
     print(f"Start Node: {workflow_def.workflow.start}")
     print("Transitions:")
     for trans in workflow_def.workflow.transitions:
-        condition_str = f" [Condition: {trans.condition}]" if trans.condition else ""
         if isinstance(trans.to_node, list):
-            for to_node in trans.to_node:
-                print(f"- {trans.from_node} -> {to_node}{condition_str}")
+            if all(isinstance(tn, BranchCondition) for tn in trans.to_node):
+                for branch in trans.to_node:
+                    cond_str = f" [Condition: {branch.condition}]" if branch.condition else ""
+                    print(f"- {trans.from_node} -> {branch.to_node}{cond_str}")
+            else:
+                print(f"- {trans.from_node} -> {trans.to_node} (parallel)")
         else:
-            print(f"- {trans.from_node} -> {trans.to_node}{condition_str}")
+            cond_str = f" [Condition: {trans.condition}]" if trans.condition else ""
+            print(f"- {trans.from_node} -> {trans.to_node}{cond_str}")
+    print("Convergence Nodes:")
+    for conv_node in workflow_def.workflow.convergence_nodes:
+        print(f"- {conv_node}")
 
     print("\n#### Observers:")
     for observer in workflow_def.observers:
@@ -565,33 +695,32 @@ def main():
     """Demonstrate extracting a workflow from a Python file and saving it to YAML."""
     import argparse
     import sys
-
+
     parser = argparse.ArgumentParser(description='Extract workflow from a Python file')
-    parser.add_argument('file_path', nargs='?', default="examples/qflow/story_generator_agent.py",
+    parser.add_argument('file_path', nargs='?', default="examples/flow/simple_story_generator/story_generator_agent.py",
                         help='Path to the Python file containing the workflow')
     parser.add_argument('--output', '-o', default="./generated_workflow.py",
                         help='Output path for the executable Python script')
     parser.add_argument('--yaml', '-y', default="workflow_definition.yaml",
                         help='Output path for the YAML workflow definition')
-
+
     args = parser.parse_args()
     file_path = args.file_path
     output_file_python = args.output
     yaml_output_path = args.yaml
-
+
     if not os.path.exists(file_path):
         logger.error(f"File '{file_path}' not found. Please provide a valid file path.")
         logger.info("Example usage: python -m quantalogic.flow.flow_extractor path/to/your/workflow_file.py")
         sys.exit(1)
-
+
     try:
         workflow_def, global_vars = extract_workflow_from_file(file_path)
         logger.info(f"Successfully extracted workflow from '{file_path}'")
         print_workflow_definition(workflow_def)
         generate_executable_script(workflow_def, global_vars, output_file_python)
         logger.info(f"Executable script generated at '{output_file_python}'")
-
-    # Save the workflow to a YAML file
+
         manager = WorkflowManager(workflow_def)
         manager.save_to_yaml(yaml_output_path)
         logger.info(f"Workflow saved to YAML file '{yaml_output_path}'")