quantalogic 0.55.0__py3-none-any.whl → 0.57.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
quantalogic/flow/flow.py CHANGED
@@ -12,7 +12,8 @@
  # ///
 
  import asyncio
- import inspect # Added for accurate parameter detection
+ import inspect
+ import os
  from dataclasses import dataclass
  from enum import Enum
  from pathlib import Path
@@ -47,18 +48,17 @@ class WorkflowEvent:
  transition_from: Optional[str] = None
  transition_to: Optional[str] = None
  sub_workflow_name: Optional[str] = None
- usage: Optional[Dict[str, Any]] = None # Added to store token usage and cost
+ usage: Optional[Dict[str, Any]] = None
 
 
  WorkflowObserver = Callable[[WorkflowEvent], None]
 
 
- # Define a class for sub-workflow nodes with updated inputs handling
  class SubWorkflowNode:
  def __init__(self, sub_workflow: "Workflow", inputs: Dict[str, Any], output: str):
  """Initialize a sub-workflow node with flexible inputs mapping."""
  self.sub_workflow = sub_workflow
- self.inputs = inputs # Maps sub_key to main_key, callable, or value
+ self.inputs = inputs
  self.output = output
 
  async def __call__(self, engine: "WorkflowEngine"):
@@ -70,7 +70,7 @@ class SubWorkflowNode:
  elif isinstance(mapping, str):
  sub_context[sub_key] = engine.context.get(mapping)
  else:
- sub_context[sub_key] = mapping # Direct value
+ sub_context[sub_key] = mapping
  sub_engine = self.sub_workflow.build(parent_engine=engine)
  result = await sub_engine.run(sub_context)
  return result.get(self.output)
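Note on the mapping semantics in the `__call__` body above: each value in a sub-workflow's `inputs` dict is resolved as a callable (invoked with the parent context), a string (looked up as a parent-context key), or any other value (passed through unchanged). A minimal sketch, with hypothetical node and key names:

    # Sketch only: node names and context keys are illustrative.
    sub_wf = Workflow("greet")
    main_wf = Workflow("start").add_sub_workflow(
        "greeting_step",
        sub_wf,
        inputs={
            "name": "customer_name",              # str -> read from parent context["customer_name"]
            "language": lambda ctx: ctx["lang"],  # callable -> computed from the parent context
            "retries": 3,                         # any other value -> passed through as-is
        },
        output="greeting_result",
    )

Because bare strings are treated as context keys, a literal string value has to be wrapped in a lambda (e.g. `lambda ctx: "hello"`).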
@@ -82,7 +82,7 @@ class WorkflowEngine:
  self.workflow = workflow
  self.context: Dict[str, Any] = {}
  self.observers: List[WorkflowObserver] = []
- self.parent_engine = parent_engine # Link to parent engine for sub-workflow observer propagation
+ self.parent_engine = parent_engine
 
  def add_observer(self, observer: WorkflowObserver) -> None:
  """Register an event observer callback."""
@@ -90,7 +90,7 @@ class WorkflowEngine:
  self.observers.append(observer)
  logger.debug(f"Added observer: {observer}")
  if self.parent_engine:
- self.parent_engine.add_observer(observer) # Propagate to parent for global visibility
+ self.parent_engine.add_observer(observer)
 
  def remove_observer(self, observer: WorkflowObserver) -> None:
  """Remove an event observer callback."""
@@ -140,18 +140,15 @@ class WorkflowEngine:
  )
  break
 
- # Prepare inputs with mappings
  input_mappings = self.workflow.node_input_mappings.get(current_node, {})
  inputs = {}
- # Add all mapped inputs
  for key, mapping in input_mappings.items():
  if callable(mapping):
  inputs[key] = mapping(self.context)
  elif isinstance(mapping, str):
  inputs[key] = self.context.get(mapping)
  else:
- inputs[key] = mapping # Direct value
- # For parameters in node_inputs that are not mapped, get from context
+ inputs[key] = mapping
  for param in self.workflow.node_inputs[current_node]:
  if param not in inputs:
  inputs[param] = self.context.get(param)
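Per-node `inputs_mapping` entries resolve through the same three forms, and any declared input left unmapped falls back to a context lookup under its own name. A hedged sketch (node and key names are illustrative, and the upstream "fetch_text" node is assumed to exist):

    @Nodes.define(output="summary")
    async def summarize(text: str, limit: int, prefix: str) -> str:
        # "prefix" is left unmapped below, so it is read from context["prefix"].
        return (prefix + text)[:limit]

    wf = (
        Workflow("fetch_text")  # hypothetical upstream node
        .node("summarize", inputs_mapping={
            "text": "raw_text",                   # context key
            "limit": lambda ctx: ctx["max_len"],  # computed per run
        })
    )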
@@ -159,7 +156,6 @@ class WorkflowEngine:
  result = None
  exception = None
 
- # Handle sub-workflow nodes
  if isinstance(node_func, SubWorkflowNode):
  await self._notify_observers(
  WorkflowEvent(
@@ -172,21 +168,24 @@ class WorkflowEngine:
 
  try:
  if isinstance(node_func, SubWorkflowNode):
- result = await node_func(self) # Sub-workflow handles its own inputs
- usage = None # Sub-workflow usage is handled by its own nodes
+ result = await node_func(self)
+ usage = None
  else:
  result = await node_func(**inputs)
- usage = getattr(node_func, "usage", None) # Extract usage if set by LLM nodes
+ usage = getattr(node_func, "usage", None)
  output_key = self.workflow.node_outputs[current_node]
  if output_key:
  self.context[output_key] = result
+ elif isinstance(result, dict):
+ self.context.update(result)
+ logger.debug(f"Updated context with {result} from node {current_node}")
  await self._notify_observers(
  WorkflowEvent(
  event_type=WorkflowEventType.NODE_COMPLETED,
  node_name=current_node,
  context=self.context,
  result=result,
- usage=usage, # Include usage data in the event
+ usage=usage,
  )
  )
  except Exception as e:
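The added `elif isinstance(result, dict)` branch above is new behavior in this release: when a node declares no output key and returns a dict, the engine merges the dict into the shared context instead of dropping the result. A minimal sketch:

    @Nodes.define()  # no output key declared
    async def load_order(order_id: str) -> dict:
        # With no output key, this dict is merged into the workflow context,
        # so later nodes can read "customer" and "items" directly.
        return {"customer": "Alice", "items": ["item1", "item2"]}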
@@ -239,17 +238,21 @@ class WorkflowEngine:
 
  class Workflow:
  def __init__(self, start_node: str):
- """Initialize a workflow with a starting node."""
+ """Initialize a workflow with a starting node.
+
+ Args:
+ start_node: The name of the initial node in the workflow.
+ """
  self.start_node = start_node
  self.nodes: Dict[str, Callable] = {}
  self.node_inputs: Dict[str, List[str]] = {}
  self.node_outputs: Dict[str, Optional[str]] = {}
  self.transitions: Dict[str, List[Tuple[str, Optional[Callable]]]] = {}
- self.node_input_mappings: Dict[str, Dict[str, Any]] = {} # Store input mappings for nodes
+ self.node_input_mappings: Dict[str, Dict[str, Any]] = {}
  self.current_node = None
  self._observers: List[WorkflowObserver] = []
- self._register_node(start_node) # Register the start node without setting current_node
- self.current_node = start_node # Set current_node explicitly after registration
+ self._register_node(start_node)
+ self.current_node = start_node
 
  def _register_node(self, name: str):
  """Register a node without modifying the current node."""
@@ -261,7 +264,15 @@ class Workflow:
  self.node_outputs[name] = output
 
  def node(self, name: str, inputs_mapping: Optional[Dict[str, Any]] = None):
- """Add a node to the workflow chain with an optional inputs mapping."""
+ """Add a node to the workflow chain with an optional inputs mapping.
+
+ Args:
+ name: The name of the node to add.
+ inputs_mapping: Optional dictionary mapping node inputs to context keys or callables.
+
+ Returns:
+ Self for method chaining.
+ """
  self._register_node(name)
  if inputs_mapping:
  self.node_input_mappings[name] = inputs_mapping
@@ -270,7 +281,14 @@ class Workflow:
  return self
 
  def sequence(self, *nodes: str):
- """Add a sequence of nodes to execute in order."""
+ """Add a sequence of nodes to execute in order.
+
+ Args:
+ *nodes: Variable number of node names to execute sequentially.
+
+ Returns:
+ Self for method chaining.
+ """
  if not nodes:
  return self
  for node in nodes:
@@ -286,9 +304,17 @@ class Workflow:
  return self
 
  def then(self, next_node: str, condition: Optional[Callable] = None):
- """Add a transition to the next node with an optional condition."""
+ """Add a transition to the next node with an optional condition.
+
+ Args:
+ next_node: Name of the node to transition to.
+ condition: Optional callable taking context and returning a boolean.
+
+ Returns:
+ Self for method chaining.
+ """
  if next_node not in self.nodes:
- self._register_node(next_node) # Register without changing current_node
+ self._register_node(next_node)
  if self.current_node:
  self.transitions.setdefault(self.current_node, []).append((next_node, condition))
  logger.debug(f"Added transition from {self.current_node} to {next_node} with condition {condition}")
@@ -297,23 +323,49 @@ class Workflow:
  self.current_node = next_node
  return self
 
- def branch(self, branches: List[Tuple[str, Optional[Callable]]]) -> "Workflow":
- """Add multiple conditional branches from the current node."""
+ def branch(
+ self,
+ branches: List[Tuple[str, Optional[Callable]]],
+ default: Optional[str] = None,
+ next_node: Optional[str] = None,
+ ) -> "Workflow":
+ """Add multiple conditional branches from the current node with an optional default and next node.
+
+ Args:
+ branches: List of tuples (next_node, condition), where condition takes context and returns a boolean.
+ default: Optional node to transition to if no branch conditions are met.
+ next_node: Optional node to set as current_node after branching (e.g., for convergence).
+
+ Returns:
+ Self for method chaining.
+ """
  if not self.current_node:
  logger.warning("No current node set for branching")
  return self
- for next_node, condition in branches:
- if next_node not in self.nodes:
- self._register_node(next_node)
- self.transitions.setdefault(self.current_node, []).append((next_node, condition))
- logger.debug(f"Added branch from {self.current_node} to {next_node} with condition {condition}")
+ for next_node_name, condition in branches:
+ if next_node_name not in self.nodes:
+ self._register_node(next_node_name)
+ self.transitions.setdefault(self.current_node, []).append((next_node_name, condition))
+ logger.debug(f"Added branch from {self.current_node} to {next_node_name} with condition {condition}")
+ if default:
+ if default not in self.nodes:
+ self._register_node(default)
+ self.transitions.setdefault(self.current_node, []).append((default, None))
+ logger.debug(f"Added default transition from {self.current_node} to {default}")
+ self.current_node = next_node # Explicitly set next_node if provided
  return self
 
  def converge(self, convergence_node: str) -> "Workflow":
- """Set a convergence point for all previous branches."""
+ """Set a convergence point for all previous branches.
+
+ Args:
+ convergence_node: Name of the node where branches converge.
+
+ Returns:
+ Self for method chaining.
+ """
  if convergence_node not in self.nodes:
  self._register_node(convergence_node)
- # Find all leaf nodes (nodes with no outgoing transitions) and point them to convergence_node
  for node in self.nodes:
  if (node not in self.transitions or not self.transitions[node]) and node != convergence_node:
  self.transitions.setdefault(node, []).append((convergence_node, None))
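Taken together, the new `default` and `next_node` parameters let `branch` express a fallback path and keep the fluent chain pointed at the convergence node. A hedged sketch, assuming hypothetical node names:

    wf = (
        Workflow("check_inventory")
        .branch(
            [
                ("ship_order", lambda ctx: ctx.get("in_stock", False)),
                ("backorder_items", lambda ctx: ctx.get("allow_backorder", False)),
            ],
            default="notify_customer",  # taken when no branch condition matches
            next_node="update_status",  # current_node after branching
        )
        .converge("update_status")      # leaf branches all rejoin here
    )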
@@ -322,32 +374,63 @@ class Workflow:
  return self
 
  def parallel(self, *nodes: str):
- """Add parallel nodes to execute concurrently."""
+ """Add parallel nodes to execute concurrently.
+
+ Args:
+ *nodes: Variable number of node names to execute in parallel.
+
+ Returns:
+ Self for method chaining.
+ """
  if self.current_node:
  for node in nodes:
  self.transitions.setdefault(self.current_node, []).append((node, None))
- self.current_node = None # Reset after parallel to force explicit next node
+ self.current_node = None
  return self
 
  def add_observer(self, observer: WorkflowObserver) -> "Workflow":
- """Add an event observer callback to the workflow."""
+ """Add an event observer callback to the workflow.
+
+ Args:
+ observer: Callable to handle workflow events.
+
+ Returns:
+ Self for method chaining.
+ """
  if observer not in self._observers:
  self._observers.append(observer)
  logger.debug(f"Added observer to workflow: {observer}")
- return self # Support chaining
+ return self
 
  def add_sub_workflow(self, name: str, sub_workflow: "Workflow", inputs: Dict[str, Any], output: str):
- """Add a sub-workflow as a node with flexible inputs mapping."""
+ """Add a sub-workflow as a node with flexible inputs mapping.
+
+ Args:
+ name: Name of the sub-workflow node.
+ sub_workflow: The Workflow instance to embed.
+ inputs: Dictionary mapping sub-workflow inputs to context keys or callables.
+ output: Context key for the sub-workflow's result.
+
+ Returns:
+ Self for method chaining.
+ """
  sub_node = SubWorkflowNode(sub_workflow, inputs, output)
  self.nodes[name] = sub_node
- self.node_inputs[name] = [] # Inputs handled internally by SubWorkflowNode
+ self.node_inputs[name] = []
  self.node_outputs[name] = output
  self.current_node = name
  logger.debug(f"Added sub-workflow {name} with inputs {inputs} and output {output}")
  return self
 
  def build(self, parent_engine: Optional["WorkflowEngine"] = None) -> WorkflowEngine:
- """Build and return a WorkflowEngine instance with registered observers."""
+ """Build and return a WorkflowEngine instance with registered observers.
+
+ Args:
+ parent_engine: Optional parent WorkflowEngine for sub-workflows.
+
+ Returns:
+ Configured WorkflowEngine instance.
+ """
  engine = WorkflowEngine(self, parent_engine=parent_engine)
  for observer in self._observers:
  engine.add_observer(observer)
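Observers registered on the workflow are copied onto the engine at build time, and `add_observer` on a child engine also propagates to its parent, so a single callback can watch a whole tree of sub-workflows. A small sketch in the async-observer style of the example workflow further down (node names are illustrative):

    async def log_event(event: WorkflowEvent) -> None:
        print(f"[{event.event_type.value}] {event.node_name or 'Workflow'}")

    wf = (
        Workflow("prepare")
        .parallel("fetch_prices", "fetch_reviews")  # fan out; current_node resets to None
        .converge("merge_results")                  # rejoin the parallel paths
        .add_observer(log_event)
    )
    engine = wf.build()  # log_event is registered on the engine here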
@@ -355,11 +438,18 @@ class Workflow:
 
 
  class Nodes:
- NODE_REGISTRY: Dict[str, Tuple[Callable, List[str], Optional[str]]] = {} # Registry to hold node functions and metadata
+ NODE_REGISTRY: Dict[str, Tuple[Callable, List[str], Optional[str]]] = {}
 
  @classmethod
  def define(cls, output: Optional[str] = None):
- """Decorator for defining simple workflow nodes."""
+ """Decorator for defining simple workflow nodes.
+
+ Args:
+ output: Optional context key for the node's result.
+
+ Returns:
+ Decorator function wrapping the node logic.
+ """
  def decorator(func: Callable) -> Callable:
  async def wrapped_func(**kwargs):
  try:
@@ -372,8 +462,6 @@ class Nodes:
  except Exception as e:
  logger.error(f"Error in node {func.__name__}: {e}")
  raise
-
- # Get parameter names from function signature
  sig = inspect.signature(func)
  inputs = [param.name for param in sig.parameters.values()]
  logger.debug(f"Registering node {func.__name__} with inputs {inputs} and output {output}")
@@ -383,7 +471,14 @@ class Nodes:
 
  @classmethod
  def validate_node(cls, output: str):
- """Decorator for nodes that validate inputs."""
+ """Decorator for nodes that validate inputs and return a string.
+
+ Args:
+ output: Context key for the validation result.
+
+ Returns:
+ Decorator function wrapping the validation logic.
+ """
  def decorator(func: Callable) -> Callable:
  async def wrapped_func(**kwargs):
  try:
@@ -398,8 +493,6 @@ class Nodes:
  except Exception as e:
  logger.error(f"Validation error in {func.__name__}: {e}")
  raise
-
- # Get parameter names from function signature
  sig = inspect.signature(func)
  inputs = [param.name for param in sig.parameters.values()]
  logger.debug(f"Registering node {func.__name__} with inputs {inputs} and output {output}")
@@ -409,11 +502,18 @@ class Nodes:
 
  @classmethod
  def transform_node(cls, output: str, transformer: Callable[[Any], Any]):
- """Decorator for nodes that transform their inputs."""
+ """Decorator for nodes that transform their inputs.
+
+ Args:
+ output: Context key for the transformed result.
+ transformer: Callable to transform the input.
+
+ Returns:
+ Decorator function wrapping the transformation logic.
+ """
  def decorator(func: Callable) -> Callable:
  async def wrapped_func(**kwargs):
  try:
- # Apply transformer to the first input value
  input_key = list(kwargs.keys())[0] if kwargs else None
  if input_key:
  transformed_input = transformer(kwargs[input_key])
@@ -427,8 +527,6 @@ class Nodes:
  except Exception as e:
  logger.error(f"Error in transform node {func.__name__}: {e}")
  raise
-
- # Get parameter names from function signature
  sig = inspect.signature(func)
  inputs = [param.name for param in sig.parameters.values()]
  logger.debug(f"Registering node {func.__name__} with inputs {inputs} and output {output}")
@@ -467,8 +565,9 @@ class Nodes:
  @classmethod
  def llm_node(
  cls,
- system_prompt: str,
- output: str,
+ system_prompt: str = "",
+ system_prompt_file: Optional[str] = None,
+ output: str = "",
  prompt_template: str = "",
  prompt_file: Optional[str] = None,
  temperature: float = 0.7,
@@ -476,13 +575,38 @@ class Nodes:
  top_p: float = 1.0,
  presence_penalty: float = 0.0,
  frequency_penalty: float = 0.0,
+ model: Callable[[Dict[str, Any]], str] = lambda ctx: "gpt-3.5-turbo",
  **kwargs,
  ):
- """Decorator for creating LLM nodes with plain text output, supporting dynamic parameters via input mappings."""
+ """Decorator for creating LLM nodes with plain text output, supporting dynamic parameters.
+
+ Args:
+ system_prompt: Inline system prompt defining LLM behavior.
+ system_prompt_file: Path to a system prompt template file (overrides system_prompt).
+ output: Context key for the LLM's result.
+ prompt_template: Inline Jinja2 template for the user prompt.
+ prompt_file: Path to a user prompt template file (overrides prompt_template).
+ temperature: Randomness control (0.0 to 1.0).
+ max_tokens: Maximum response length.
+ top_p: Nucleus sampling parameter (0.0 to 1.0).
+ presence_penalty: Penalty for repetition (-2.0 to 2.0).
+ frequency_penalty: Penalty for frequent words (-2.0 to 2.0).
+ model: Callable or string to determine the LLM model dynamically from context.
+ **kwargs: Additional parameters for the LLM call.
+
+ Returns:
+ Decorator function wrapping the LLM logic.
+ """
  def decorator(func: Callable) -> Callable:
- async def wrapped_func(model: str, **func_kwargs):
- # Extract parameters from func_kwargs if provided, else use defaults
+ async def wrapped_func(model_param: str = None, **func_kwargs):
  system_prompt_to_use = func_kwargs.pop("system_prompt", system_prompt)
+ system_prompt_file_to_use = func_kwargs.pop("system_prompt_file", system_prompt_file)
+
+ if system_prompt_file_to_use:
+ system_content = cls._load_prompt_from_file(system_prompt_file_to_use, func_kwargs)
+ else:
+ system_content = system_prompt_to_use
+
  prompt_template_to_use = func_kwargs.pop("prompt_template", prompt_template)
  prompt_file_to_use = func_kwargs.pop("prompt_file", prompt_file)
  temperature_to_use = func_kwargs.pop("temperature", temperature)
@@ -490,25 +614,27 @@ class Nodes:
  top_p_to_use = func_kwargs.pop("top_p", top_p)
  presence_penalty_to_use = func_kwargs.pop("presence_penalty", presence_penalty)
  frequency_penalty_to_use = func_kwargs.pop("frequency_penalty", frequency_penalty)
+
+ # Prioritize model from func_kwargs (workflow mapping), then model_param, then default
+ model_to_use = func_kwargs.get("model", model_param if model_param is not None else model(func_kwargs))
+ logger.debug(f"Selected model for {func.__name__}: {model_to_use}")
 
- # Use only signature parameters for template rendering
  sig = inspect.signature(func)
  template_vars = {k: v for k, v in func_kwargs.items() if k in sig.parameters}
  prompt = cls._render_template(prompt_template_to_use, prompt_file_to_use, template_vars)
  messages = [
- {"role": "system", "content": system_prompt_to_use},
+ {"role": "system", "content": system_content},
  {"role": "user", "content": prompt},
  ]
 
- # Log the model and a preview of the prompt
  truncated_prompt = prompt[:200] + "..." if len(prompt) > 200 else prompt
- logger.info(f"LLM node {func.__name__} using model: {model}")
- logger.debug(f"System prompt: {system_prompt_to_use[:100]}...")
+ logger.info(f"LLM node {func.__name__} using model: {model_to_use}")
+ logger.debug(f"System prompt: {system_content[:100]}...")
  logger.debug(f"User prompt preview: {truncated_prompt}")
 
  try:
  response = await acompletion(
- model=model,
+ model=model_to_use,
  messages=messages,
  temperature=temperature_to_use,
  max_tokens=max_tokens_to_use,
@@ -530,8 +656,6 @@ class Nodes:
  except Exception as e:
  logger.error(f"Error in LLM node {func.__name__}: {e}")
  raise
-
- # Get parameter names from function signature and add 'model'
  sig = inspect.signature(func)
  inputs = ['model'] + [param.name for param in sig.parameters.values()]
  logger.debug(f"Registering node {func.__name__} with inputs {inputs} and output {output}")
@@ -542,9 +666,10 @@ class Nodes:
  @classmethod
  def structured_llm_node(
  cls,
- system_prompt: str,
- output: str,
- response_model: Type[BaseModel],
+ system_prompt: str = "",
+ system_prompt_file: Optional[str] = None,
+ output: str = "",
+ response_model: Type[BaseModel] = None,
  prompt_template: str = "",
  prompt_file: Optional[str] = None,
  temperature: float = 0.7,
@@ -552,9 +677,29 @@ class Nodes:
  top_p: float = 1.0,
  presence_penalty: float = 0.0,
  frequency_penalty: float = 0.0,
+ model: Callable[[Dict[str, Any]], str] = lambda ctx: "gpt-3.5-turbo",
  **kwargs,
  ):
- """Decorator for creating LLM nodes with structured output, supporting dynamic parameters via input mappings."""
+ """Decorator for creating LLM nodes with structured output, supporting dynamic parameters.
+
+ Args:
+ system_prompt: Inline system prompt defining LLM behavior.
+ system_prompt_file: Path to a system prompt template file (overrides system_prompt).
+ output: Context key for the LLM's structured result.
+ response_model: Pydantic model class for structured output.
+ prompt_template: Inline Jinja2 template for the user prompt.
+ prompt_file: Path to a user prompt template file (overrides prompt_template).
+ temperature: Randomness control (0.0 to 1.0).
+ max_tokens: Maximum response length.
+ top_p: Nucleus sampling parameter (0.0 to 1.0).
+ presence_penalty: Penalty for repetition (-2.0 to 2.0).
+ frequency_penalty: Penalty for frequent words (-2.0 to 2.0).
+ model: Callable or string to determine the LLM model dynamically from context.
+ **kwargs: Additional parameters for the LLM call.
+
+ Returns:
+ Decorator function wrapping the structured LLM logic.
+ """
  try:
  client = instructor.from_litellm(acompletion)
  except ImportError:
@@ -562,9 +707,15 @@ class Nodes:
  raise ImportError("Instructor is required for structured_llm_node")
 
  def decorator(func: Callable) -> Callable:
- async def wrapped_func(model: str, **func_kwargs):
- # Extract parameters from func_kwargs if provided, else use defaults
+ async def wrapped_func(model_param: str = None, **func_kwargs):
  system_prompt_to_use = func_kwargs.pop("system_prompt", system_prompt)
+ system_prompt_file_to_use = func_kwargs.pop("system_prompt_file", system_prompt_file)
+
+ if system_prompt_file_to_use:
+ system_content = cls._load_prompt_from_file(system_prompt_file_to_use, func_kwargs)
+ else:
+ system_content = system_prompt_to_use
+
  prompt_template_to_use = func_kwargs.pop("prompt_template", prompt_template)
  prompt_file_to_use = func_kwargs.pop("prompt_file", prompt_file)
  temperature_to_use = func_kwargs.pop("temperature", temperature)
@@ -572,26 +723,28 @@ class Nodes:
  top_p_to_use = func_kwargs.pop("top_p", top_p)
  presence_penalty_to_use = func_kwargs.pop("presence_penalty", presence_penalty)
  frequency_penalty_to_use = func_kwargs.pop("frequency_penalty", frequency_penalty)
+
+ # Prioritize model from func_kwargs (workflow mapping), then model_param, then default
+ model_to_use = func_kwargs.get("model", model_param if model_param is not None else model(func_kwargs))
+ logger.debug(f"Selected model for {func.__name__}: {model_to_use}")
 
- # Use only signature parameters for template rendering
  sig = inspect.signature(func)
  template_vars = {k: v for k, v in func_kwargs.items() if k in sig.parameters}
  prompt = cls._render_template(prompt_template_to_use, prompt_file_to_use, template_vars)
  messages = [
- {"role": "system", "content": system_prompt_to_use},
+ {"role": "system", "content": system_content},
  {"role": "user", "content": prompt},
  ]
 
- # Log the model and a preview of the prompt
  truncated_prompt = prompt[:200] + "..." if len(prompt) > 200 else prompt
- logger.info(f"Structured LLM node {func.__name__} using model: {model}")
- logger.debug(f"System prompt: {system_prompt_to_use[:100]}...")
+ logger.info(f"Structured LLM node {func.__name__} using model: {model_to_use}")
+ logger.debug(f"System prompt: {system_content[:100]}...")
  logger.debug(f"User prompt preview: {truncated_prompt}")
  logger.debug(f"Expected response model: {response_model.__name__}")
 
  try:
  structured_response, raw_response = await client.chat.completions.create_with_completion(
- model=model,
+ model=model_to_use,
  messages=messages,
  response_model=response_model,
  temperature=temperature_to_use,
@@ -616,8 +769,6 @@ class Nodes:
  except Exception as e:
  logger.error(f"Error in structured LLM node {func.__name__}: {e}")
  raise
-
- # Get parameter names from function signature and add 'model'
  sig = inspect.signature(func)
  inputs = ['model'] + [param.name for param in sig.parameters.values()]
  logger.debug(f"Registering node {func.__name__} with inputs {inputs} and output {output}")
@@ -632,20 +783,26 @@ class Nodes:
  template: str = "",
  template_file: Optional[str] = None,
  ):
- """Decorator for creating nodes that apply a Jinja2 template to inputs, supporting dynamic parameters."""
+ """Decorator for creating nodes that apply a Jinja2 template to inputs.
+
+ Args:
+ output: Context key for the rendered result.
+ template: Inline Jinja2 template string.
+ template_file: Path to a template file (overrides template).
+
+ Returns:
+ Decorator function wrapping the template logic.
+ """
  def decorator(func: Callable) -> Callable:
  async def wrapped_func(**func_kwargs):
- # Extract template parameters from func_kwargs if provided, else use defaults
  template_to_use = func_kwargs.pop("template", template)
  template_file_to_use = func_kwargs.pop("template_file", template_file)
 
- # Use only signature parameters (excluding rendered_content) for template rendering
  sig = inspect.signature(func)
  expected_params = [p.name for p in sig.parameters.values() if p.name != 'rendered_content']
  template_vars = {k: v for k, v in func_kwargs.items() if k in expected_params}
  rendered_content = cls._render_template(template_to_use, template_file_to_use, template_vars)
 
- # Filter func_kwargs for the function call
  filtered_kwargs = {k: v for k, v in func_kwargs.items() if k in expected_params}
 
  try:
@@ -658,8 +815,6 @@ class Nodes:
  except Exception as e:
  logger.error(f"Error in template node {func.__name__}: {e}")
  raise
-
- # Get parameter names from function signature and add 'rendered_content' if not present
  sig = inspect.signature(func)
  inputs = [param.name for param in sig.parameters.values()]
  if 'rendered_content' not in inputs:
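`template_node` injects the rendered text as `rendered_content`, appending that parameter to the node's inputs when the function does not declare it. A minimal sketch, matching the `format_order_message` node in the example workflow below:

    @Nodes.template_node(
        output="order_message",
        template="Order contains: {{ items | join(', ') }}",
    )
    async def format_order_message(rendered_content: str, items: List[str]) -> str:
        return rendered_content  # rendered by the decorator from the template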
@@ -670,15 +825,20 @@ class Nodes:
  return decorator
 
 
- # Example workflow with observer integration, updated nodes, input mappings, and dynamic parameters
+ # Add a templates directory path at the module level
+ TEMPLATES_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "templates")
+
+ # Helper function to get template paths
+ def get_template_path(template_name):
+ return os.path.join(TEMPLATES_DIR, template_name)
+
+
  async def example_workflow():
- # Define Pydantic model for structured output
  class OrderDetails(BaseModel):
  order_id: str
  items_in_stock: List[str]
  items_out_of_stock: List[str]
 
- # Define an example observer for progress
  async def progress_monitor(event: WorkflowEvent):
  print(f"[{event.event_type.value}] {event.node_name or 'Workflow'}")
  if event.result is not None:
@@ -686,7 +846,6 @@ async def example_workflow():
  if event.exception is not None:
  print(f"Exception: {event.exception}")
 
- # Define an observer for token usage
  class TokenUsageObserver:
  def __init__(self):
  self.total_prompt_tokens = 0
@@ -709,16 +868,15 @@ async def example_workflow():
  for node, usage in self.node_usages.items():
  print(f"Node {node}: {usage}")
 
- # Define nodes
  @Nodes.validate_node(output="validation_result")
  async def validate_order(order: Dict[str, Any]) -> str:
  return "Order validated" if order.get("items") else "Invalid order"
 
  @Nodes.structured_llm_node(
- system_prompt="You are an inventory checker. Respond with a JSON object containing 'order_id', 'items_in_stock', and 'items_out_of_stock'.",
+ system_prompt_file=get_template_path("system_check_inventory.j2"),
  output="inventory_status",
  response_model=OrderDetails,
- prompt_template="Check if the following items are in stock: {{ items }}. Return the result in JSON format with 'order_id' set to '123'.",
+ prompt_file=get_template_path("prompt_check_inventory.j2"),
  )
  async def check_inventory(items: List[str]) -> OrderDetails:
  return OrderDetails(order_id="123", items_in_stock=["item1"], items_out_of_stock=[])
@@ -754,13 +912,10 @@ async def example_workflow():
  async def format_order_message(rendered_content: str, items: List[str]) -> str:
  return rendered_content
 
- # Sub-workflow for payment and shipping
  payment_shipping_sub_wf = Workflow("process_payment").sequence("process_payment", "arrange_shipping")
 
- # Instantiate token usage observer
  token_observer = TokenUsageObserver()
 
- # Main workflow with dynamic parameter overrides
  workflow = (
  Workflow("validate_order")
  .add_observer(progress_monitor)
@@ -769,13 +924,13 @@ async def example_workflow():
  .node("transform_items")
  .node("format_order_message", inputs_mapping={
  "items": "items",
- "template": "Custom order: {{ items | join(', ') }}" # Dynamic override
+ "template": "Custom order: {{ items | join(', ') }}"
  })
  .node("check_inventory", inputs_mapping={
  "model": lambda ctx: "gemini/gemini-2.0-flash",
  "items": "transformed_items",
- "temperature": 0.5, # Dynamic override
- "max_tokens": 1000 # Dynamic override
+ "temperature": 0.5,
+ "max_tokens": 1000
  })
  .add_sub_workflow(
  "payment_shipping",
@@ -783,15 +938,17 @@ async def example_workflow():
  inputs={"order": lambda ctx: {"items": ctx["items"]}},
  output="shipping_confirmation"
  )
- .branch([
- ("payment_shipping", lambda ctx: len(ctx.get("inventory_status").items_out_of_stock) == 0 if ctx.get("inventory_status") else False),
- ("notify_customer_out_of_stock", lambda ctx: len(ctx.get("inventory_status").items_out_of_stock) > 0 if ctx.get("inventory_status") else True)
- ])
+ .branch(
+ [
+ ("payment_shipping", lambda ctx: len(ctx.get("inventory_status").items_out_of_stock) == 0 if ctx.get("inventory_status") else False),
+ ("notify_customer_out_of_stock", lambda ctx: len(ctx.get("inventory_status").items_out_of_stock) > 0 if ctx.get("inventory_status") else True)
+ ],
+ next_node="update_order_status"
+ )
  .converge("update_order_status")
  .sequence("update_order_status", "send_confirmation_email")
  )
 
- # Execute workflow
  initial_context = {"customer_order": {"items": ["item1", "item2"]}, "items": ["item1", "item2"]}
  engine = workflow.build()
  result = await engine.run(initial_context)