quantalogic 0.56.0__py3-none-any.whl → 0.58.0__py3-none-any.whl

This diff represents the changes between two package versions publicly released to one of the supported registries. It is provided for informational purposes only and reflects the versions as they appear in their respective public registries.
quantalogic/flow/flow.py CHANGED
@@ -12,7 +12,8 @@
 # ///

 import asyncio
-import inspect # Added for accurate parameter detection
+import inspect
+import os
 from dataclasses import dataclass
 from enum import Enum
 from pathlib import Path
@@ -47,18 +48,17 @@ class WorkflowEvent:
     transition_from: Optional[str] = None
     transition_to: Optional[str] = None
     sub_workflow_name: Optional[str] = None
-    usage: Optional[Dict[str, Any]] = None # Added to store token usage and cost
+    usage: Optional[Dict[str, Any]] = None


 WorkflowObserver = Callable[[WorkflowEvent], None]


-# Define a class for sub-workflow nodes with updated inputs handling
 class SubWorkflowNode:
     def __init__(self, sub_workflow: "Workflow", inputs: Dict[str, Any], output: str):
         """Initialize a sub-workflow node with flexible inputs mapping."""
         self.sub_workflow = sub_workflow
-        self.inputs = inputs # Maps sub_key to main_key, callable, or value
+        self.inputs = inputs
         self.output = output

     async def __call__(self, engine: "WorkflowEngine"):
@@ -70,7 +70,7 @@ class SubWorkflowNode:
             elif isinstance(mapping, str):
                 sub_context[sub_key] = engine.context.get(mapping)
             else:
-                sub_context[sub_key] = mapping # Direct value
+                sub_context[sub_key] = mapping
         sub_engine = self.sub_workflow.build(parent_engine=engine)
         result = await sub_engine.run(sub_context)
         return result.get(self.output)
@@ -82,7 +82,7 @@ class WorkflowEngine:
         self.workflow = workflow
         self.context: Dict[str, Any] = {}
         self.observers: List[WorkflowObserver] = []
-        self.parent_engine = parent_engine # Link to parent engine for sub-workflow observer propagation
+        self.parent_engine = parent_engine

     def add_observer(self, observer: WorkflowObserver) -> None:
         """Register an event observer callback."""
@@ -90,7 +90,7 @@ class WorkflowEngine:
             self.observers.append(observer)
             logger.debug(f"Added observer: {observer}")
             if self.parent_engine:
-                self.parent_engine.add_observer(observer) # Propagate to parent for global visibility
+                self.parent_engine.add_observer(observer)

     def remove_observer(self, observer: WorkflowObserver) -> None:
         """Remove an event observer callback."""
@@ -140,18 +140,15 @@ class WorkflowEngine:
                 )
                 break

-            # Prepare inputs with mappings
             input_mappings = self.workflow.node_input_mappings.get(current_node, {})
             inputs = {}
-            # Add all mapped inputs
             for key, mapping in input_mappings.items():
                 if callable(mapping):
                     inputs[key] = mapping(self.context)
                 elif isinstance(mapping, str):
                     inputs[key] = self.context.get(mapping)
                 else:
-                    inputs[key] = mapping # Direct value
-            # For parameters in node_inputs that are not mapped, get from context
+                    inputs[key] = mapping
             for param in self.workflow.node_inputs[current_node]:
                 if param not in inputs:
                     inputs[param] = self.context.get(param)
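
Note: in the resolution loop above, each `inputs_mapping` value is handled one of three ways: a callable is invoked with the workflow context, a string is treated as a context key, and any other value is passed through as a literal. A minimal sketch of that contract (node and key names hypothetical, assuming the node is registered via a Nodes decorator):

    from quantalogic.flow.flow import Workflow

    wf = Workflow("greet").node("greet", inputs_mapping={
        "name": "customer_name",  # str -> looked up as context["customer_name"]
        "greeting": lambda ctx: f"Hi {ctx.get('customer_name', 'there')}",  # callable -> called with context
        "max_items": 3,  # any other value -> passed through as-is
    })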
@@ -159,7 +156,6 @@ class WorkflowEngine:
             result = None
             exception = None

-            # Handle sub-workflow nodes
             if isinstance(node_func, SubWorkflowNode):
                 await self._notify_observers(
                     WorkflowEvent(
@@ -172,15 +168,15 @@ class WorkflowEngine:

             try:
                 if isinstance(node_func, SubWorkflowNode):
-                    result = await node_func(self) # Sub-workflow handles its own inputs
-                    usage = None # Sub-workflow usage is handled by its own nodes
+                    result = await node_func(self)
+                    usage = None
                 else:
                     result = await node_func(**inputs)
-                    usage = getattr(node_func, "usage", None) # Extract usage if set by LLM nodes
+                    usage = getattr(node_func, "usage", None)
                 output_key = self.workflow.node_outputs[current_node]
                 if output_key:
                     self.context[output_key] = result
-                elif isinstance(result, dict): # Update context if node returns a dict and output is None
+                elif isinstance(result, dict):
                     self.context.update(result)
                     logger.debug(f"Updated context with {result} from node {current_node}")
                 await self._notify_observers(
@@ -189,7 +185,7 @@ class WorkflowEngine:
                         node_name=current_node,
                         context=self.context,
                         result=result,
-                        usage=usage, # Include usage data in the event
+                        usage=usage,
                     )
                 )
             except Exception as e:
@@ -242,17 +238,21 @@

 class Workflow:
     def __init__(self, start_node: str):
-        """Initialize a workflow with a starting node."""
+        """Initialize a workflow with a starting node.
+
+        Args:
+            start_node: The name of the initial node in the workflow.
+        """
         self.start_node = start_node
         self.nodes: Dict[str, Callable] = {}
         self.node_inputs: Dict[str, List[str]] = {}
         self.node_outputs: Dict[str, Optional[str]] = {}
         self.transitions: Dict[str, List[Tuple[str, Optional[Callable]]]] = {}
-        self.node_input_mappings: Dict[str, Dict[str, Any]] = {} # Store input mappings for nodes
+        self.node_input_mappings: Dict[str, Dict[str, Any]] = {}
         self.current_node = None
         self._observers: List[WorkflowObserver] = []
-        self._register_node(start_node) # Register the start node without setting current_node
-        self.current_node = start_node # Set current_node explicitly after registration
+        self._register_node(start_node)
+        self.current_node = start_node

     def _register_node(self, name: str):
         """Register a node without modifying the current node."""
@@ -264,7 +264,15 @@ class Workflow:
         self.node_outputs[name] = output

     def node(self, name: str, inputs_mapping: Optional[Dict[str, Any]] = None):
-        """Add a node to the workflow chain with an optional inputs mapping."""
+        """Add a node to the workflow chain with an optional inputs mapping.
+
+        Args:
+            name: The name of the node to add.
+            inputs_mapping: Optional dictionary mapping node inputs to context keys or callables.
+
+        Returns:
+            Self for method chaining.
+        """
         self._register_node(name)
         if inputs_mapping:
             self.node_input_mappings[name] = inputs_mapping
@@ -273,7 +281,14 @@ class Workflow:
         return self

     def sequence(self, *nodes: str):
-        """Add a sequence of nodes to execute in order."""
+        """Add a sequence of nodes to execute in order.
+
+        Args:
+            *nodes: Variable number of node names to execute sequentially.
+
+        Returns:
+            Self for method chaining.
+        """
         if not nodes:
             return self
         for node in nodes:
@@ -289,9 +304,17 @@ class Workflow:
         return self

     def then(self, next_node: str, condition: Optional[Callable] = None):
-        """Add a transition to the next node with an optional condition."""
+        """Add a transition to the next node with an optional condition.
+
+        Args:
+            next_node: Name of the node to transition to.
+            condition: Optional callable taking context and returning a boolean.
+
+        Returns:
+            Self for method chaining.
+        """
         if next_node not in self.nodes:
-            self._register_node(next_node) # Register without changing current_node
+            self._register_node(next_node)
         if self.current_node:
             self.transitions.setdefault(self.current_node, []).append((next_node, condition))
             logger.debug(f"Added transition from {self.current_node} to {next_node} with condition {condition}")
@@ -300,23 +323,49 @@ class Workflow:
         self.current_node = next_node
         return self

-    def branch(self, branches: List[Tuple[str, Optional[Callable]]]) -> "Workflow":
-        """Add multiple conditional branches from the current node."""
+    def branch(
+        self,
+        branches: List[Tuple[str, Optional[Callable]]],
+        default: Optional[str] = None,
+        next_node: Optional[str] = None,
+    ) -> "Workflow":
+        """Add multiple conditional branches from the current node with an optional default and next node.
+
+        Args:
+            branches: List of tuples (next_node, condition), where condition takes context and returns a boolean.
+            default: Optional node to transition to if no branch conditions are met.
+            next_node: Optional node to set as current_node after branching (e.g., for convergence).
+
+        Returns:
+            Self for method chaining.
+        """
         if not self.current_node:
             logger.warning("No current node set for branching")
             return self
-        for next_node, condition in branches:
-            if next_node not in self.nodes:
-                self._register_node(next_node)
-            self.transitions.setdefault(self.current_node, []).append((next_node, condition))
-            logger.debug(f"Added branch from {self.current_node} to {next_node} with condition {condition}")
+        for next_node_name, condition in branches:
+            if next_node_name not in self.nodes:
+                self._register_node(next_node_name)
+            self.transitions.setdefault(self.current_node, []).append((next_node_name, condition))
+            logger.debug(f"Added branch from {self.current_node} to {next_node_name} with condition {condition}")
+        if default:
+            if default not in self.nodes:
+                self._register_node(default)
+            self.transitions.setdefault(self.current_node, []).append((default, None))
+            logger.debug(f"Added default transition from {self.current_node} to {default}")
+        self.current_node = next_node # Explicitly set next_node if provided
         return self

     def converge(self, convergence_node: str) -> "Workflow":
-        """Set a convergence point for all previous branches."""
+        """Set a convergence point for all previous branches.
+
+        Args:
+            convergence_node: Name of the node where branches converge.
+
+        Returns:
+            Self for method chaining.
+        """
         if convergence_node not in self.nodes:
             self._register_node(convergence_node)
-        # Find all leaf nodes (nodes with no outgoing transitions) and point them to convergence_node
         for node in self.nodes:
             if (node not in self.transitions or not self.transitions[node]) and node != convergence_node:
                 self.transitions.setdefault(node, []).append((convergence_node, None))
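
Note: the reworked `branch()` accepts an optional `default` fallback and an explicit `next_node` to continue the chain; a minimal sketch (node names hypothetical, assuming the nodes are registered):

    from quantalogic.flow.flow import Workflow

    wf = (
        Workflow("classify")
        .branch(
            [
                ("handle_urgent", lambda ctx: ctx.get("priority") == "high"),
                ("handle_normal", lambda ctx: ctx.get("priority") == "low"),
            ],
            default="triage_manually",  # taken when no branch condition matches
            next_node="archive",        # becomes current_node after branching
        )
        .converge("archive")
    )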
@@ -325,32 +374,63 @@ class Workflow:
         return self

     def parallel(self, *nodes: str):
-        """Add parallel nodes to execute concurrently."""
+        """Add parallel nodes to execute concurrently.
+
+        Args:
+            *nodes: Variable number of node names to execute in parallel.
+
+        Returns:
+            Self for method chaining.
+        """
         if self.current_node:
             for node in nodes:
                 self.transitions.setdefault(self.current_node, []).append((node, None))
-        self.current_node = None # Reset after parallel to force explicit next node
+        self.current_node = None
         return self

     def add_observer(self, observer: WorkflowObserver) -> "Workflow":
-        """Add an event observer callback to the workflow."""
+        """Add an event observer callback to the workflow.
+
+        Args:
+            observer: Callable to handle workflow events.
+
+        Returns:
+            Self for method chaining.
+        """
         if observer not in self._observers:
             self._observers.append(observer)
             logger.debug(f"Added observer to workflow: {observer}")
-        return self # Support chaining
+        return self

     def add_sub_workflow(self, name: str, sub_workflow: "Workflow", inputs: Dict[str, Any], output: str):
-        """Add a sub-workflow as a node with flexible inputs mapping."""
+        """Add a sub-workflow as a node with flexible inputs mapping.
+
+        Args:
+            name: Name of the sub-workflow node.
+            sub_workflow: The Workflow instance to embed.
+            inputs: Dictionary mapping sub-workflow inputs to context keys or callables.
+            output: Context key for the sub-workflow's result.
+
+        Returns:
+            Self for method chaining.
+        """
         sub_node = SubWorkflowNode(sub_workflow, inputs, output)
         self.nodes[name] = sub_node
-        self.node_inputs[name] = [] # Inputs handled internally by SubWorkflowNode
+        self.node_inputs[name] = []
         self.node_outputs[name] = output
         self.current_node = name
         logger.debug(f"Added sub-workflow {name} with inputs {inputs} and output {output}")
         return self

     def build(self, parent_engine: Optional["WorkflowEngine"] = None) -> WorkflowEngine:
-        """Build and return a WorkflowEngine instance with registered observers."""
+        """Build and return a WorkflowEngine instance with registered observers.
+
+        Args:
+            parent_engine: Optional parent WorkflowEngine for sub-workflows.
+
+        Returns:
+            Configured WorkflowEngine instance.
+        """
         engine = WorkflowEngine(self, parent_engine=parent_engine)
         for observer in self._observers:
             engine.add_observer(observer)
@@ -358,11 +438,18 @@ class Workflow:


 class Nodes:
-    NODE_REGISTRY: Dict[str, Tuple[Callable, List[str], Optional[str]]] = {} # Registry to hold node functions and metadata
+    NODE_REGISTRY: Dict[str, Tuple[Callable, List[str], Optional[str]]] = {}

     @classmethod
     def define(cls, output: Optional[str] = None):
-        """Decorator for defining simple workflow nodes."""
+        """Decorator for defining simple workflow nodes.
+
+        Args:
+            output: Optional context key for the node's result.
+
+        Returns:
+            Decorator function wrapping the node logic.
+        """
         def decorator(func: Callable) -> Callable:
             async def wrapped_func(**kwargs):
                 try:
@@ -375,8 +462,6 @@ class Nodes:
                 except Exception as e:
                     logger.error(f"Error in node {func.__name__}: {e}")
                     raise
-
-            # Get parameter names from function signature
             sig = inspect.signature(func)
             inputs = [param.name for param in sig.parameters.values()]
             logger.debug(f"Registering node {func.__name__} with inputs {inputs} and output {output}")
@@ -386,7 +471,14 @@ class Nodes:

     @classmethod
     def validate_node(cls, output: str):
-        """Decorator for nodes that validate inputs."""
+        """Decorator for nodes that validate inputs and return a string.
+
+        Args:
+            output: Context key for the validation result.
+
+        Returns:
+            Decorator function wrapping the validation logic.
+        """
         def decorator(func: Callable) -> Callable:
             async def wrapped_func(**kwargs):
                 try:
@@ -401,8 +493,6 @@ class Nodes:
                 except Exception as e:
                     logger.error(f"Validation error in {func.__name__}: {e}")
                     raise
-
-            # Get parameter names from function signature
             sig = inspect.signature(func)
             inputs = [param.name for param in sig.parameters.values()]
             logger.debug(f"Registering node {func.__name__} with inputs {inputs} and output {output}")
@@ -412,11 +502,18 @@ class Nodes:

     @classmethod
     def transform_node(cls, output: str, transformer: Callable[[Any], Any]):
-        """Decorator for nodes that transform their inputs."""
+        """Decorator for nodes that transform their inputs.
+
+        Args:
+            output: Context key for the transformed result.
+            transformer: Callable to transform the input.
+
+        Returns:
+            Decorator function wrapping the transformation logic.
+        """
         def decorator(func: Callable) -> Callable:
             async def wrapped_func(**kwargs):
                 try:
-                    # Apply transformer to the first input value
                     input_key = list(kwargs.keys())[0] if kwargs else None
                     if input_key:
                         transformed_input = transformer(kwargs[input_key])
@@ -430,8 +527,6 @@ class Nodes:
                 except Exception as e:
                     logger.error(f"Error in transform node {func.__name__}: {e}")
                     raise
-
-            # Get parameter names from function signature
             sig = inspect.signature(func)
             inputs = [param.name for param in sig.parameters.values()]
             logger.debug(f"Registering node {func.__name__} with inputs {inputs} and output {output}")
@@ -470,8 +565,9 @@ class Nodes:
     @classmethod
     def llm_node(
         cls,
-        system_prompt: str,
-        output: str,
+        system_prompt: str = "",
+        system_prompt_file: Optional[str] = None,
+        output: str = "",
         prompt_template: str = "",
         prompt_file: Optional[str] = None,
         temperature: float = 0.7,
@@ -479,13 +575,38 @@ class Nodes:
         top_p: float = 1.0,
         presence_penalty: float = 0.0,
         frequency_penalty: float = 0.0,
+        model: Callable[[Dict[str, Any]], str] = lambda ctx: "gpt-3.5-turbo",
         **kwargs,
     ):
-        """Decorator for creating LLM nodes with plain text output, supporting dynamic parameters via input mappings."""
+        """Decorator for creating LLM nodes with plain text output, supporting dynamic parameters.
+
+        Args:
+            system_prompt: Inline system prompt defining LLM behavior.
+            system_prompt_file: Path to a system prompt template file (overrides system_prompt).
+            output: Context key for the LLM's result.
+            prompt_template: Inline Jinja2 template for the user prompt.
+            prompt_file: Path to a user prompt template file (overrides prompt_template).
+            temperature: Randomness control (0.0 to 1.0).
+            max_tokens: Maximum response length.
+            top_p: Nucleus sampling parameter (0.0 to 1.0).
+            presence_penalty: Penalty for repetition (-2.0 to 2.0).
+            frequency_penalty: Penalty for frequent words (-2.0 to 2.0).
+            model: Callable or string to determine the LLM model dynamically from context.
+            **kwargs: Additional parameters for the LLM call.
+
+        Returns:
+            Decorator function wrapping the LLM logic.
+        """
         def decorator(func: Callable) -> Callable:
-            async def wrapped_func(model: str, **func_kwargs):
-                # Extract parameters from func_kwargs if provided, else use defaults
+            async def wrapped_func(model_param: str = None, **func_kwargs):
                 system_prompt_to_use = func_kwargs.pop("system_prompt", system_prompt)
+                system_prompt_file_to_use = func_kwargs.pop("system_prompt_file", system_prompt_file)
+
+                if system_prompt_file_to_use:
+                    system_content = cls._load_prompt_from_file(system_prompt_file_to_use, func_kwargs)
+                else:
+                    system_content = system_prompt_to_use
+
                 prompt_template_to_use = func_kwargs.pop("prompt_template", prompt_template)
                 prompt_file_to_use = func_kwargs.pop("prompt_file", prompt_file)
                 temperature_to_use = func_kwargs.pop("temperature", temperature)
@@ -493,25 +614,27 @@ class Nodes:
                 top_p_to_use = func_kwargs.pop("top_p", top_p)
                 presence_penalty_to_use = func_kwargs.pop("presence_penalty", presence_penalty)
                 frequency_penalty_to_use = func_kwargs.pop("frequency_penalty", frequency_penalty)
+
+                # Prioritize model from func_kwargs (workflow mapping), then model_param, then default
+                model_to_use = func_kwargs.get("model", model_param if model_param is not None else model(func_kwargs))
+                logger.debug(f"Selected model for {func.__name__}: {model_to_use}")

-                # Use only signature parameters for template rendering
                 sig = inspect.signature(func)
                 template_vars = {k: v for k, v in func_kwargs.items() if k in sig.parameters}
                 prompt = cls._render_template(prompt_template_to_use, prompt_file_to_use, template_vars)
                 messages = [
-                    {"role": "system", "content": system_prompt_to_use},
+                    {"role": "system", "content": system_content},
                     {"role": "user", "content": prompt},
                 ]

-                # Log the model and a preview of the prompt
                 truncated_prompt = prompt[:200] + "..." if len(prompt) > 200 else prompt
-                logger.info(f"LLM node {func.__name__} using model: {model}")
-                logger.debug(f"System prompt: {system_prompt_to_use[:100]}...")
+                logger.info(f"LLM node {func.__name__} using model: {model_to_use}")
+                logger.debug(f"System prompt: {system_content[:100]}...")
                 logger.debug(f"User prompt preview: {truncated_prompt}")

                 try:
                     response = await acompletion(
-                        model=model,
+                        model=model_to_use,
                         messages=messages,
                         temperature=temperature_to_use,
                         max_tokens=max_tokens_to_use,
@@ -533,8 +656,6 @@ class Nodes:
                 except Exception as e:
                     logger.error(f"Error in LLM node {func.__name__}: {e}")
                     raise
-
-            # Get parameter names from function signature and add 'model'
             sig = inspect.signature(func)
             inputs = ['model'] + [param.name for param in sig.parameters.values()]
             logger.debug(f"Registering node {func.__name__} with inputs {inputs} and output {output}")
@@ -545,9 +666,10 @@ class Nodes:
     @classmethod
     def structured_llm_node(
         cls,
-        system_prompt: str,
-        output: str,
-        response_model: Type[BaseModel],
+        system_prompt: str = "",
+        system_prompt_file: Optional[str] = None,
+        output: str = "",
+        response_model: Type[BaseModel] = None,
         prompt_template: str = "",
         prompt_file: Optional[str] = None,
         temperature: float = 0.7,
@@ -555,9 +677,29 @@ class Nodes:
         top_p: float = 1.0,
         presence_penalty: float = 0.0,
         frequency_penalty: float = 0.0,
+        model: Callable[[Dict[str, Any]], str] = lambda ctx: "gpt-3.5-turbo",
         **kwargs,
     ):
-        """Decorator for creating LLM nodes with structured output, supporting dynamic parameters via input mappings."""
+        """Decorator for creating LLM nodes with structured output, supporting dynamic parameters.
+
+        Args:
+            system_prompt: Inline system prompt defining LLM behavior.
+            system_prompt_file: Path to a system prompt template file (overrides system_prompt).
+            output: Context key for the LLM's structured result.
+            response_model: Pydantic model class for structured output.
+            prompt_template: Inline Jinja2 template for the user prompt.
+            prompt_file: Path to a user prompt template file (overrides prompt_template).
+            temperature: Randomness control (0.0 to 1.0).
+            max_tokens: Maximum response length.
+            top_p: Nucleus sampling parameter (0.0 to 1.0).
+            presence_penalty: Penalty for repetition (-2.0 to 2.0).
+            frequency_penalty: Penalty for frequent words (-2.0 to 2.0).
+            model: Callable or string to determine the LLM model dynamically from context.
+            **kwargs: Additional parameters for the LLM call.
+
+        Returns:
+            Decorator function wrapping the structured LLM logic.
+        """
         try:
             client = instructor.from_litellm(acompletion)
         except ImportError:
@@ -565,9 +707,15 @@ class Nodes:
             raise ImportError("Instructor is required for structured_llm_node")

         def decorator(func: Callable) -> Callable:
-            async def wrapped_func(model: str, **func_kwargs):
-                # Extract parameters from func_kwargs if provided, else use defaults
+            async def wrapped_func(model_param: str = None, **func_kwargs):
                 system_prompt_to_use = func_kwargs.pop("system_prompt", system_prompt)
+                system_prompt_file_to_use = func_kwargs.pop("system_prompt_file", system_prompt_file)
+
+                if system_prompt_file_to_use:
+                    system_content = cls._load_prompt_from_file(system_prompt_file_to_use, func_kwargs)
+                else:
+                    system_content = system_prompt_to_use
+
                 prompt_template_to_use = func_kwargs.pop("prompt_template", prompt_template)
                 prompt_file_to_use = func_kwargs.pop("prompt_file", prompt_file)
                 temperature_to_use = func_kwargs.pop("temperature", temperature)
@@ -575,26 +723,28 @@ class Nodes:
                 top_p_to_use = func_kwargs.pop("top_p", top_p)
                 presence_penalty_to_use = func_kwargs.pop("presence_penalty", presence_penalty)
                 frequency_penalty_to_use = func_kwargs.pop("frequency_penalty", frequency_penalty)
+
+                # Prioritize model from func_kwargs (workflow mapping), then model_param, then default
+                model_to_use = func_kwargs.get("model", model_param if model_param is not None else model(func_kwargs))
+                logger.debug(f"Selected model for {func.__name__}: {model_to_use}")

-                # Use only signature parameters for template rendering
                 sig = inspect.signature(func)
                 template_vars = {k: v for k, v in func_kwargs.items() if k in sig.parameters}
                 prompt = cls._render_template(prompt_template_to_use, prompt_file_to_use, template_vars)
                 messages = [
-                    {"role": "system", "content": system_prompt_to_use},
+                    {"role": "system", "content": system_content},
                     {"role": "user", "content": prompt},
                 ]

-                # Log the model and a preview of the prompt
                 truncated_prompt = prompt[:200] + "..." if len(prompt) > 200 else prompt
-                logger.info(f"Structured LLM node {func.__name__} using model: {model}")
-                logger.debug(f"System prompt: {system_prompt_to_use[:100]}...")
+                logger.info(f"Structured LLM node {func.__name__} using model: {model_to_use}")
+                logger.debug(f"System prompt: {system_content[:100]}...")
                 logger.debug(f"User prompt preview: {truncated_prompt}")
                 logger.debug(f"Expected response model: {response_model.__name__}")

                 try:
                     structured_response, raw_response = await client.chat.completions.create_with_completion(
-                        model=model,
+                        model=model_to_use,
                         messages=messages,
                         response_model=response_model,
                         temperature=temperature_to_use,
@@ -619,8 +769,6 @@ class Nodes:
                 except Exception as e:
                     logger.error(f"Error in structured LLM node {func.__name__}: {e}")
                     raise
-
-            # Get parameter names from function signature and add 'model'
             sig = inspect.signature(func)
             inputs = ['model'] + [param.name for param in sig.parameters.values()]
             logger.debug(f"Registering node {func.__name__} with inputs {inputs} and output {output}")
@@ -635,20 +783,26 @@ class Nodes:
         template: str = "",
         template_file: Optional[str] = None,
     ):
-        """Decorator for creating nodes that apply a Jinja2 template to inputs, supporting dynamic parameters."""
+        """Decorator for creating nodes that apply a Jinja2 template to inputs.
+
+        Args:
+            output: Context key for the rendered result.
+            template: Inline Jinja2 template string.
+            template_file: Path to a template file (overrides template).
+
+        Returns:
+            Decorator function wrapping the template logic.
+        """
         def decorator(func: Callable) -> Callable:
             async def wrapped_func(**func_kwargs):
-                # Extract template parameters from func_kwargs if provided, else use defaults
                 template_to_use = func_kwargs.pop("template", template)
                 template_file_to_use = func_kwargs.pop("template_file", template_file)

-                # Use only signature parameters (excluding rendered_content) for template rendering
                 sig = inspect.signature(func)
                 expected_params = [p.name for p in sig.parameters.values() if p.name != 'rendered_content']
                 template_vars = {k: v for k, v in func_kwargs.items() if k in expected_params}
                 rendered_content = cls._render_template(template_to_use, template_file_to_use, template_vars)

-                # Filter func_kwargs for the function call
                 filtered_kwargs = {k: v for k, v in func_kwargs.items() if k in expected_params}

                 try:
@@ -661,8 +815,6 @@ class Nodes:
                 except Exception as e:
                     logger.error(f"Error in template node {func.__name__}: {e}")
                     raise
-
-            # Get parameter names from function signature and add 'rendered_content' if not present
             sig = inspect.signature(func)
             inputs = [param.name for param in sig.parameters.values()]
             if 'rendered_content' not in inputs:
@@ -673,15 +825,20 @@ class Nodes:
         return decorator


-# Example workflow with observer integration, updated nodes, input mappings, and dynamic parameters
+# Add a templates directory path at the module level
+TEMPLATES_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "templates")
+
+# Helper function to get template paths
+def get_template_path(template_name):
+    return os.path.join(TEMPLATES_DIR, template_name)
+
+
 async def example_workflow():
-    # Define Pydantic model for structured output
     class OrderDetails(BaseModel):
         order_id: str
         items_in_stock: List[str]
         items_out_of_stock: List[str]

-    # Define an example observer for progress
     async def progress_monitor(event: WorkflowEvent):
         print(f"[{event.event_type.value}] {event.node_name or 'Workflow'}")
         if event.result is not None:
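
Note: the new module-level helper resolves template filenames against a `templates` directory next to flow.py; a usage sketch (filename taken from the example below):

    path = get_template_path("system_check_inventory.j2")
    # -> <module dir>/templates/system_check_inventory.j2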
@@ -689,7 +846,6 @@ async def example_workflow():
         if event.exception is not None:
             print(f"Exception: {event.exception}")

-    # Define an observer for token usage
     class TokenUsageObserver:
         def __init__(self):
             self.total_prompt_tokens = 0
@@ -712,16 +868,15 @@ async def example_workflow():
         for node, usage in self.node_usages.items():
             print(f"Node {node}: {usage}")

-    # Define nodes
     @Nodes.validate_node(output="validation_result")
     async def validate_order(order: Dict[str, Any]) -> str:
         return "Order validated" if order.get("items") else "Invalid order"

     @Nodes.structured_llm_node(
-        system_prompt="You are an inventory checker. Respond with a JSON object containing 'order_id', 'items_in_stock', and 'items_out_of_stock'.",
+        system_prompt_file=get_template_path("system_check_inventory.j2"),
        output="inventory_status",
        response_model=OrderDetails,
-        prompt_template="Check if the following items are in stock: {{ items }}. Return the result in JSON format with 'order_id' set to '123'.",
+        prompt_file=get_template_path("prompt_check_inventory.j2"),
     )
     async def check_inventory(items: List[str]) -> OrderDetails:
         return OrderDetails(order_id="123", items_in_stock=["item1"], items_out_of_stock=[])
@@ -757,13 +912,10 @@ async def example_workflow():
     async def format_order_message(rendered_content: str, items: List[str]) -> str:
         return rendered_content

-    # Sub-workflow for payment and shipping
     payment_shipping_sub_wf = Workflow("process_payment").sequence("process_payment", "arrange_shipping")

-    # Instantiate token usage observer
     token_observer = TokenUsageObserver()

-    # Main workflow with dynamic parameter overrides
     workflow = (
         Workflow("validate_order")
         .add_observer(progress_monitor)
@@ -772,13 +924,13 @@ async def example_workflow():
         .node("transform_items")
         .node("format_order_message", inputs_mapping={
             "items": "items",
-            "template": "Custom order: {{ items | join(', ') }}" # Dynamic override
+            "template": "Custom order: {{ items | join(', ') }}"
         })
         .node("check_inventory", inputs_mapping={
             "model": lambda ctx: "gemini/gemini-2.0-flash",
             "items": "transformed_items",
-            "temperature": 0.5, # Dynamic override
-            "max_tokens": 1000 # Dynamic override
+            "temperature": 0.5,
+            "max_tokens": 1000
         })
         .add_sub_workflow(
             "payment_shipping",
@@ -786,15 +938,17 @@ async def example_workflow():
             inputs={"order": lambda ctx: {"items": ctx["items"]}},
             output="shipping_confirmation"
         )
-        .branch([
-            ("payment_shipping", lambda ctx: len(ctx.get("inventory_status").items_out_of_stock) == 0 if ctx.get("inventory_status") else False),
-            ("notify_customer_out_of_stock", lambda ctx: len(ctx.get("inventory_status").items_out_of_stock) > 0 if ctx.get("inventory_status") else True)
-        ])
+        .branch(
+            [
+                ("payment_shipping", lambda ctx: len(ctx.get("inventory_status").items_out_of_stock) == 0 if ctx.get("inventory_status") else False),
+                ("notify_customer_out_of_stock", lambda ctx: len(ctx.get("inventory_status").items_out_of_stock) > 0 if ctx.get("inventory_status") else True)
+            ],
+            next_node="update_order_status"
+        )
         .converge("update_order_status")
         .sequence("update_order_status", "send_confirmation_email")
     )

-    # Execute workflow
     initial_context = {"customer_order": {"items": ["item1", "item2"]}, "items": ["item1", "item2"]}
     engine = workflow.build()
     result = await engine.run(initial_context)