quantalogic 0.53.0__py3-none-any.whl → 0.56.0__py3-none-any.whl

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the differences between those package versions as they appear in their respective public registries.
@@ -4,13 +4,7 @@ from pydantic import BaseModel, Field, model_validator
4
4
 
5
5
 
6
6
  class FunctionDefinition(BaseModel):
7
- """
8
- Definition of a function used in the workflow.
9
-
10
- This model supports both embedded functions (inline code) and external functions sourced
11
- from Python modules, including PyPI packages, local files, or remote URLs.
12
- """
13
-
7
+ """Definition of a function used in the workflow."""
14
8
  type: str = Field(
15
9
  ...,
16
10
  description="Type of function source. Must be 'embedded' for inline code or 'external' for module-based functions.",
@@ -55,16 +49,15 @@ class FunctionDefinition(BaseModel):
55
49
 
56
50
  class LLMConfig(BaseModel):
57
51
  """Configuration for LLM-based nodes."""
58
-
59
52
  model: str = Field(
60
53
  default="gpt-3.5-turbo", description="The LLM model to use (e.g., 'gpt-3.5-turbo', 'gemini/gemini-2.0-flash')."
61
54
  )
62
55
  system_prompt: Optional[str] = Field(None, description="System prompt defining the LLM's role or context.")
63
56
  prompt_template: str = Field(
64
- default="{{ input }}", description="Jinja2 template for the user prompt (e.g., 'Summarize {{ text }}'). Ignored if prompt_file is set."
57
+ default="{{ input }}", description="Jinja2 template for the user prompt. Ignored if prompt_file is set."
65
58
  )
66
59
  prompt_file: Optional[str] = Field(
67
- None, description="Path to an external Jinja2 template file (e.g., 'prompts/summary.j2'). Takes precedence over prompt_template if provided."
60
+ None, description="Path to an external Jinja2 template file. Takes precedence over prompt_template."
68
61
  )
69
62
  temperature: float = Field(
70
63
  default=0.7, ge=0.0, le=1.0, description="Controls randomness of LLM output (0.0 to 1.0)."
@@ -77,13 +70,10 @@ class LLMConfig(BaseModel):
77
70
  frequency_penalty: float = Field(
78
71
  default=0.0, ge=-2.0, le=2.0, description="Penalty for repeating words (-2.0 to 2.0)."
79
72
  )
80
- stop: Optional[List[str]] = Field(None, description="List of stop sequences for LLM generation (e.g., ['\\n']).")
73
+ stop: Optional[List[str]] = Field(None, description="List of stop sequences for LLM generation.")
81
74
  response_model: Optional[str] = Field(
82
75
  None,
83
- description=(
84
- "Path to a Pydantic model for structured output (e.g., 'my_module:OrderDetails'). "
85
- "If specified, uses structured_llm_node; otherwise, uses llm_node."
86
- ),
76
+ description="Path to a Pydantic model for structured output (e.g., 'my_module:OrderDetails')."
87
77
  )
88
78
  api_key: Optional[str] = Field(None, description="Custom API key for the LLM provider, if required.")
89
79
 
@@ -97,13 +87,30 @@ class LLMConfig(BaseModel):
97
87
  return data
98
88
 
99
89
 
100
- class NodeDefinition(BaseModel):
101
- """
102
- Definition of a workflow node.
90
+ class TemplateConfig(BaseModel):
91
+ """Configuration for template-based nodes."""
92
+ template: str = Field(
93
+ default="", description="Jinja2 template string to render. Ignored if template_file is set."
94
+ )
95
+ template_file: Optional[str] = Field(
96
+ None, description="Path to an external Jinja2 template file. Takes precedence over template."
97
+ )
98
+
99
+ @model_validator(mode="before")
100
+ @classmethod
101
+ def check_template_source(cls, data: Any) -> Any:
102
+ """Ensure template_file and template are used appropriately."""
103
+ template_file = data.get("template_file")
104
+ template = data.get("template")
105
+ if not template and not template_file:
106
+ raise ValueError("Either 'template' or 'template_file' must be provided")
107
+ if template_file and not isinstance(template_file, str):
108
+ raise ValueError("template_file must be a string path to a Jinja2 template file")
109
+ return data
103
110
 
104
- A node must specify exactly one of 'function', 'sub_workflow', or 'llm_config'.
105
- """
106
111
 
112
+ class NodeDefinition(BaseModel):
113
+ """Definition of a workflow node with template_node and inputs_mapping support."""
107
114
  function: Optional[str] = Field(
108
115
  None, description="Name of the function to execute (references a FunctionDefinition)."
109
116
  )
@@ -111,12 +118,14 @@ class NodeDefinition(BaseModel):
111
118
  None, description="Nested workflow definition for sub-workflow nodes."
112
119
  )
113
120
  llm_config: Optional[LLMConfig] = Field(None, description="Configuration for LLM-based nodes.")
121
+ template_config: Optional[TemplateConfig] = Field(None, description="Configuration for template-based nodes.")
122
+ inputs_mapping: Optional[Dict[str, str]] = Field(
123
+ None,
124
+ description="Mapping of node inputs to context keys or stringified lambda expressions (e.g., 'lambda ctx: value')."
125
+ )
114
126
  output: Optional[str] = Field(
115
127
  None,
116
- description=(
117
- "Context key to store the node's result. Defaults to '<node_name>_result' "
118
- "for function or LLM nodes if not specified."
119
- ),
128
+ description="Context key to store the node's result. Defaults to '<node_name>_result' if applicable."
120
129
  )
121
130
  retries: int = Field(default=3, ge=0, description="Number of retry attempts on failure.")
122
131
  delay: float = Field(default=1.0, ge=0.0, description="Delay in seconds between retries.")
@@ -127,62 +136,72 @@ class NodeDefinition(BaseModel):
127
136
 
128
137
  @model_validator(mode="before")
129
138
  @classmethod
130
- def check_function_or_sub_workflow_or_llm(cls, data: Any) -> Any:
131
- """Ensure a node has exactly one of 'function', 'sub_workflow', or 'llm_config'."""
139
+ def check_function_or_sub_workflow_or_llm_or_template(cls, data: Any) -> Any:
140
+ """Ensure a node has exactly one of 'function', 'sub_workflow', 'llm_config', or 'template_config'."""
132
141
  func = data.get("function")
133
142
  sub_wf = data.get("sub_workflow")
134
143
  llm = data.get("llm_config")
135
- if sum(x is not None for x in (func, sub_wf, llm)) != 1:
136
- raise ValueError("Node must have exactly one of 'function', 'sub_workflow', or 'llm_config'")
144
+ template = data.get("template_config")
145
+ if sum(x is not None for x in (func, sub_wf, llm, template)) != 1:
146
+ raise ValueError("Node must have exactly one of 'function', 'sub_workflow', 'llm_config', or 'template_config'")
137
147
  return data
138
148
 
139
149
 
150
+ class BranchCondition(BaseModel):
151
+ """Definition of a branch condition for a transition."""
152
+ to_node: str = Field(
153
+ ..., description="Target node name for this branch."
154
+ )
155
+ condition: Optional[str] = Field(
156
+ None, description="Python expression using 'ctx' for conditional branching."
157
+ )
158
+
159
+
140
160
  class TransitionDefinition(BaseModel):
141
161
  """Definition of a transition between nodes."""
142
-
143
162
  from_node: str = Field(
144
163
  ...,
145
164
  description="Source node name for the transition.",
146
165
  )
147
- to_node: Union[str, List[str]] = Field(
148
- ..., description="Target node(s). A string for sequential, a list for parallel execution."
166
+ to_node: Union[str, List[Union[str, BranchCondition]]] = Field(
167
+ ...,
168
+ description=(
169
+ "Target node(s). Can be: a string, list of strings (parallel), or list of BranchCondition (branching)."
170
+ ),
149
171
  )
150
172
  condition: Optional[str] = Field(
151
- None, description="Python expression using 'ctx' for conditional transitions (e.g., 'ctx.get(\"in_stock\")')."
173
+ None,
174
+ description="Python expression using 'ctx' for simple transitions."
152
175
  )
153
176
 
154
177
 
155
178
  class WorkflowStructure(BaseModel):
156
179
  """Structure defining the workflow's execution flow."""
157
-
158
180
  start: Optional[str] = Field(None, description="Name of the starting node.")
159
181
  transitions: List[TransitionDefinition] = Field(
160
182
  default_factory=list, description="List of transitions between nodes."
161
183
  )
184
+ convergence_nodes: List[str] = Field(
185
+ default_factory=list, description="List of nodes where branches converge."
186
+ )
162
187
 
163
188
 
164
189
  class WorkflowDefinition(BaseModel):
165
190
  """Top-level definition of the workflow."""
166
-
167
191
  functions: Dict[str, FunctionDefinition] = Field(
168
- default_factory=dict, description="Dictionary of function definitions used in the workflow."
192
+ default_factory=dict, description="Dictionary of function definitions."
169
193
  )
170
194
  nodes: Dict[str, NodeDefinition] = Field(default_factory=dict, description="Dictionary of node definitions.")
171
195
  workflow: WorkflowStructure = Field(
172
- default_factory=lambda: WorkflowStructure(start=None), description="Main workflow structure with start node and transitions."
196
+ default_factory=lambda: WorkflowStructure(start=None), description="Main workflow structure."
173
197
  )
174
198
  observers: List[str] = Field(
175
- default_factory=list, description="List of observer function names to monitor workflow execution."
199
+ default_factory=list, description="List of observer function names."
176
200
  )
177
201
  dependencies: List[str] = Field(
178
202
  default_factory=list,
179
- description=(
180
- "List of Python module dependencies required by the workflow. "
181
- "Examples: PyPI packages ('requests>=2.28.0'), local paths ('/path/to/module.py'), "
182
- "or remote URLs ('https://example.com/module.py'). Processed during workflow instantiation."
183
- ),
203
+ description="List of Python module dependencies."
184
204
  )
185
205
 
186
206
 
187
- # Resolve forward reference for sub_workflow in NodeDefinition
188
207
  NodeDefinition.model_rebuild()
@@ -1,8 +1,8 @@
1
1
  import re
2
- from typing import Dict, List, Optional, Set, Tuple
2
+ from typing import Dict, List, Optional, Set, Tuple, Union
3
3
 
4
4
  from quantalogic.flow.flow_manager import WorkflowManager
5
- from quantalogic.flow.flow_manager_schema import NodeDefinition, WorkflowDefinition
5
+ from quantalogic.flow.flow_manager_schema import BranchCondition, NodeDefinition, WorkflowDefinition
6
6
 
7
7
 
8
8
  def get_node_label_and_type(node_name: str, node_def: Optional[NodeDefinition], has_conditions: bool) -> Tuple[str, str, str]:
@@ -12,29 +12,48 @@ def get_node_label_and_type(node_name: str, node_def: Optional[NodeDefinition],
12
12
  Args:
13
13
  node_name: The name of the node.
14
14
  node_def: The NodeDefinition object from the workflow, or None if undefined.
15
- has_conditions: True if the node has outgoing transitions with conditions.
15
+ has_conditions: True if the node has outgoing transitions with conditions (branching).
16
16
 
17
17
  Returns:
18
18
  A tuple of (display label, type key for styling, shape identifier).
19
19
  """
20
- # No truncation unless necessary, escape quotes for safety
20
+ # Escape quotes for Mermaid compatibility
21
21
  escaped_name = node_name.replace('"', '\\"')
22
-
23
- # Use diamond shape for nodes with conditional transitions, rectangle otherwise
24
22
  shape = "diamond" if has_conditions else "rect"
25
23
 
26
24
  if not node_def:
27
25
  return f"{escaped_name} (unknown)", "unknown", shape
28
-
26
+
27
+ # Base label starts with node name and type
29
28
  if node_def.function:
30
- return f"{escaped_name} (function)", "function", shape
29
+ label = f"{escaped_name} (function)"
30
+ node_type = "function"
31
31
  elif node_def.llm_config:
32
32
  if node_def.llm_config.response_model:
33
- return f"{escaped_name} (structured LLM)", "structured_llm", shape
34
- return f"{escaped_name} (LLM)", "llm", shape
33
+ label = f"{escaped_name} (structured LLM)"
34
+ node_type = "structured_llm"
35
+ else:
36
+ label = f"{escaped_name} (LLM)"
37
+ node_type = "llm"
38
+ elif node_def.template_config:
39
+ label = f"{escaped_name} (template)"
40
+ node_type = "template"
35
41
  elif node_def.sub_workflow:
36
- return f"{escaped_name} (Sub-Workflow)", "sub_workflow", shape
37
- return f"{escaped_name} (unknown)", "unknown", shape
42
+ label = f"{escaped_name} (Sub-Workflow)"
43
+ node_type = "sub_workflow"
44
+ else:
45
+ label = f"{escaped_name} (unknown)"
46
+ node_type = "unknown"
47
+
48
+ # Append inputs_mapping if present
49
+ if node_def and node_def.inputs_mapping:
50
+ mapping_str = ", ".join(f"{k}={v}" for k, v in node_def.inputs_mapping.items())
51
+ # Truncate if too long for readability (e.g., > 30 chars)
52
+ if len(mapping_str) > 30:
53
+ mapping_str = mapping_str[:27] + "..."
54
+ label += f"\\nInputs: {mapping_str}"
55
+
56
+ return label, node_type, shape
38
57
 
39
58
 
40
59
  def generate_mermaid_diagram(
@@ -51,7 +70,7 @@ def generate_mermaid_diagram(
51
70
  workflow_def: The workflow definition to visualize.
52
71
  include_subgraphs: If True, nests sub-workflows in Mermaid subgraphs (flowchart only).
53
72
  title: Optional title for the diagram.
54
- include_legend: If True, adds a comment-based legend explaining node types.
73
+ include_legend: If True, adds a comment-based legend explaining node types and shapes.
55
74
  diagram_type: Type of diagram to generate: "flowchart" (default) or "stateDiagram".
56
75
 
57
76
  Returns:
@@ -68,6 +87,7 @@ def generate_mermaid_diagram(
68
87
  "function": "fill:#90CAF9,stroke:#42A5F5,stroke-width:2px", # Pastel Blue
69
88
  "structured_llm": "fill:#A5D6A7,stroke:#66BB6A,stroke-width:2px", # Pastel Green
70
89
  "llm": "fill:#CE93D8,stroke:#AB47BC,stroke-width:2px", # Pastel Purple
90
+ "template": "fill:#FCE4EC,stroke:#F06292,stroke-width:2px", # Pastel Pink (new for template)
71
91
  "sub_workflow": "fill:#FFCCBC,stroke:#FF7043,stroke-width:2px", # Pastel Orange
72
92
  "unknown": "fill:#CFD8DC,stroke:#B0BEC5,stroke-width:2px" # Pastel Grey
73
93
  }
@@ -75,7 +95,7 @@ def generate_mermaid_diagram(
75
95
  # Shape mappings for flowchart syntax
76
96
  shape_syntax: Dict[str, Tuple[str, str]] = {
77
97
  "rect": ("[", "]"), # Rectangle for standard nodes
78
- "diamond": ("{{", "}}") # Diamond for decision points
98
+ "diamond": ("{{", "}}") # Diamond for decision points (branching)
79
99
  }
80
100
 
81
101
  # Validate node names for Mermaid compatibility (alphanumeric, underscore, hyphen)
@@ -94,17 +114,26 @@ def generate_mermaid_diagram(
94
114
  raise ValueError(f"Invalid node name '{trans.to_node}' for Mermaid")
95
115
  all_nodes.add(trans.to_node)
96
116
  else:
97
- for to_node in trans.to_node:
98
- if re.search(invalid_chars, to_node):
99
- raise ValueError(f"Invalid node name '{to_node}' for Mermaid")
100
- all_nodes.add(to_node)
101
-
102
- # Determine which nodes have conditional transitions
117
+ for tn in trans.to_node:
118
+ target = tn if isinstance(tn, str) else tn.to_node
119
+ if re.search(invalid_chars, target):
120
+ raise ValueError(f"Invalid node name '{target}' for Mermaid")
121
+ all_nodes.add(target)
122
+ for conv_node in workflow_def.workflow.convergence_nodes:
123
+ if re.search(invalid_chars, conv_node):
124
+ raise ValueError(f"Invalid node name '{conv_node}' for Mermaid")
125
+ all_nodes.add(conv_node)
126
+
127
+ # Determine nodes with conditional transitions (branching)
103
128
  conditional_nodes: Set[str] = set()
104
129
  for trans in workflow_def.workflow.transitions:
105
- if trans.condition and isinstance(trans.to_node, str):
130
+ if (trans.condition and isinstance(trans.to_node, str)) or \
131
+ (isinstance(trans.to_node, list) and any(isinstance(tn, BranchCondition) and tn.condition for tn in trans.to_node)):
106
132
  conditional_nodes.add(trans.from_node)
107
133
 
134
+ # Identify convergence nodes
135
+ convergence_nodes: Set[str] = set(workflow_def.workflow.convergence_nodes)
136
+
108
137
  # Shared node definitions and types
109
138
  node_types: Dict[str, str] = {}
110
139
  node_shapes: Dict[str, str] = {} # Only used for flowchart
@@ -119,13 +148,14 @@ def generate_mermaid_diagram(
119
148
  if title:
120
149
  mermaid_code += f" %% Diagram: {title}\n"
121
150
 
122
- # Optional legend for UX
151
+ # Optional legend for UX, updated to include template nodes
123
152
  if include_legend:
124
153
  mermaid_code += " %% Legend:\n"
125
154
  if diagram_type == "flowchart":
126
- mermaid_code += " %% - Rectangle: Process Step\n"
127
- mermaid_code += " %% - Diamond: Decision Point\n"
128
- mermaid_code += " %% - Colors: Blue (Function), Green (Structured LLM), Purple (LLM), Orange (Sub-Workflow), Grey (Unknown)\n"
155
+ mermaid_code += " %% - Rectangle: Process Step or Convergence Point\n"
156
+ mermaid_code += " %% - Diamond: Decision Point (Branching)\n"
157
+ mermaid_code += " %% - Colors: Blue (Function), Green (Structured LLM), Purple (LLM), Pink (Template), Orange (Sub-Workflow), Grey (Unknown)\n"
158
+ mermaid_code += " %% - Dashed Border: Convergence Node\n"
129
159
 
130
160
  if diagram_type == "flowchart":
131
161
  # Flowchart-specific: Generate node definitions with shapes
@@ -140,10 +170,10 @@ def generate_mermaid_diagram(
140
170
  node_shapes[node] = shape
141
171
 
142
172
  # Add node definitions
143
- for node_def_str in node_defs:
173
+ for node_def_str in sorted(node_defs): # Sort for consistent output
144
174
  mermaid_code += f" {node_def_str}\n"
145
175
 
146
- # Generate arrows for transitions (all solid lines)
176
+ # Generate arrows for transitions
147
177
  for trans in workflow_def.workflow.transitions:
148
178
  from_node = trans.from_node
149
179
  if isinstance(trans.to_node, str):
@@ -155,15 +185,29 @@ def generate_mermaid_diagram(
155
185
  else:
156
186
  mermaid_code += f' {from_node} --> {to_node}\n'
157
187
  else:
158
- for to_node in trans.to_node:
159
- mermaid_code += f' {from_node} --> {to_node}\n'
188
+ for tn in trans.to_node:
189
+ if isinstance(tn, str):
190
+ # Parallel transition (no condition)
191
+ mermaid_code += f' {from_node} --> {tn}\n'
192
+ else: # BranchCondition
193
+ to_node = tn.to_node
194
+ condition = tn.condition
195
+ if condition:
196
+ cond = condition.replace('"', '\\"')[:30] + ("..." if len(condition) > 30 else "")
197
+ mermaid_code += f' {from_node} -->|"{cond}"| {to_node}\n'
198
+ else:
199
+ mermaid_code += f' {from_node} --> {to_node}\n'
160
200
 
161
201
  # Add styles for node types
162
202
  for node, node_type in node_types.items():
163
203
  if node_type in node_styles:
164
- mermaid_code += f" style {node} {node_styles[node_type]}\n"
204
+ style = node_styles[node_type]
205
+ # Add dashed stroke for convergence nodes
206
+ if node in convergence_nodes:
207
+ style += ",stroke-dasharray:5 5"
208
+ mermaid_code += f" style {node} {style}\n"
165
209
 
166
- # Highlight the start node
210
+ # Highlight the start node with a thicker border
167
211
  if workflow_def.workflow.start and workflow_def.workflow.start in node_types:
168
212
  mermaid_code += f" style {workflow_def.workflow.start} stroke-width:4px\n"
169
213
 
@@ -178,8 +222,10 @@ def generate_mermaid_diagram(
178
222
  if isinstance(trans.to_node, str):
179
223
  sub_nodes.add(trans.to_node)
180
224
  else:
181
- sub_nodes.update(trans.to_node)
182
- for sub_node in sub_nodes:
225
+ for tn in trans.to_node:
226
+ target = tn if isinstance(tn, str) else tn.to_node
227
+ sub_nodes.add(target)
228
+ for sub_node in sorted(sub_nodes): # Sort for consistency
183
229
  mermaid_code += f" {sub_node}[[{sub_node}]]\n"
184
230
  mermaid_code += " end\n"
185
231
 
@@ -189,6 +235,9 @@ def generate_mermaid_diagram(
189
235
  node_def_state: Optional[NodeDefinition] = workflow_def.nodes.get(node)
190
236
  has_conditions = node in conditional_nodes
191
237
  label, node_type, _ = get_node_label_and_type(node, node_def_state, has_conditions) # Shape unused
238
+ # Append (Convergence) to label for convergence nodes
239
+ if node in convergence_nodes:
240
+ label += " (Convergence)"
192
241
  mermaid_code += f" state \"{label}\" as {node}\n"
193
242
  node_types[node] = node_type
194
243
 
@@ -208,75 +257,109 @@ def generate_mermaid_diagram(
208
257
  else:
209
258
  mermaid_code += f" {from_node} --> {to_node}\n"
210
259
  else:
211
- # Parallel transitions approximated with a note
212
- for to_node in trans.to_node:
213
- mermaid_code += f" {from_node} --> {to_node} : parallel\n"
260
+ for tn in trans.to_node:
261
+ if isinstance(tn, str):
262
+ # Parallel transition approximated with a note
263
+ mermaid_code += f" {from_node} --> {tn} : parallel\n"
264
+ else: # BranchCondition
265
+ to_node = tn.to_node
266
+ condition = tn.condition
267
+ if condition:
268
+ cond = condition.replace('"', '\\"')[:30] + ("..." if len(condition) > 30 else "")
269
+ mermaid_code += f" {from_node} --> {to_node} : {cond}\n"
270
+ else:
271
+ mermaid_code += f" {from_node} --> {to_node}\n"
214
272
 
215
273
  # Add styles for node types
216
274
  for node, node_type in node_types.items():
217
275
  if node_type in node_styles:
218
- mermaid_code += f" style {node} {node_styles[node_type]}\n"
276
+ style = node_styles[node_type]
277
+ # Add dashed stroke for convergence nodes
278
+ if node in convergence_nodes:
279
+ style += ",stroke-dasharray:5 5"
280
+ mermaid_code += f" style {node} {style}\n"
219
281
 
220
282
  mermaid_code += "```\n"
221
283
  return mermaid_code
222
284
 
223
285
 
224
286
  def main() -> None:
225
- """Create a complex workflow and print its improved Mermaid diagram representation."""
287
+ """Create a complex workflow with branch, converge, template node, and input mapping, and print its Mermaid diagram."""
226
288
  manager = WorkflowManager()
227
289
 
228
290
  # Add functions
229
291
  manager.add_function(
230
- name="analyze_sentiment",
292
+ name="say_hello",
231
293
  type_="embedded",
232
- code="async def analyze_sentiment(summary: str) -> str:\n return 'positive' if 'good' in summary.lower() else 'negative'",
294
+ code="def say_hello():\n return 'Hello, World!'"
233
295
  )
234
296
  manager.add_function(
235
- name="extract_keywords",
297
+ name="check_condition",
236
298
  type_="embedded",
237
- code="async def extract_keywords(summary: str) -> str:\n return 'key1, key2'",
299
+ code="def check_condition(text: str):\n return 'yes' if 'Hello' in text else 'no'"
238
300
  )
239
301
  manager.add_function(
240
- name="publish_content",
302
+ name="say_goodbye",
241
303
  type_="embedded",
242
- code="async def publish_content(summary: str, sentiment: str, keywords: str) -> str:\n return 'Published'",
304
+ code="def say_goodbye():\n return 'Goodbye, World!'"
243
305
  )
244
306
  manager.add_function(
245
- name="revise_content",
307
+ name="finalize",
246
308
  type_="embedded",
247
- code="async def revise_content(summary: str) -> str:\n return 'Revised summary'",
309
+ code="def finalize(text: str):\n return 'Done'"
248
310
  )
249
311
 
312
+ # Add nodes
313
+ manager.add_node(name="start", function="say_hello", output="text")
314
+ manager.add_node(name="check", function="check_condition", output="result",
315
+ inputs_mapping={"text": "text"})
316
+ manager.add_node(name="goodbye", function="say_goodbye", output="farewell")
317
+ manager.add_node(name="finalize", function="finalize", output="status",
318
+ inputs_mapping={"text": "lambda ctx: ctx['farewell'] if ctx['result'] == 'no' else ctx['ai_result']"})
319
+
250
320
  # Add LLM node
251
- llm_config = {
252
- "model": "grok/xai",
253
- "system_prompt": "You are a concise summarizer.",
254
- "prompt_template": "Summarize the following text: {{ input_text }}",
255
- "temperature": "0.5",
256
- "max_tokens": "150",
257
- }
258
- manager.add_node(name="summarize_text", llm_config=llm_config, output="summary")
321
+ manager.add_node(
322
+ name="ai_node",
323
+ llm_config={
324
+ "model": "gpt-3.5-turbo",
325
+ "prompt_template": "{{text}}",
326
+ "temperature": 0.7
327
+ },
328
+ output="ai_result"
329
+ )
259
330
 
260
- # Add function nodes
261
- manager.add_node(name="sentiment_analysis", function="analyze_sentiment", output="sentiment")
262
- manager.add_node(name="keyword_extraction", function="extract_keywords", output="keywords")
263
- manager.add_node(name="publish", function="publish_content", output="status")
264
- manager.add_node(name="revise", function="revise_content", output="revised_summary")
331
+ # Add template node
332
+ manager.add_node(
333
+ name="template_node",
334
+ template_config={
335
+ "template": "Response: {{text}} - {{result}}"
336
+ },
337
+ output="template_output",
338
+ inputs_mapping={"text": "text", "result": "result"}
339
+ )
265
340
 
266
- # Define workflow structure
267
- manager.set_start_node("summarize_text")
268
- manager.add_transition(from_node="summarize_text", to_node=["sentiment_analysis", "keyword_extraction"])
269
- manager.add_transition(from_node="sentiment_analysis", to_node="publish", condition="ctx['sentiment'] == 'positive'")
270
- manager.add_transition(from_node="sentiment_analysis", to_node="revise", condition="ctx['sentiment'] == 'negative'")
271
- manager.add_transition(from_node="keyword_extraction", to_node="publish")
341
+ # Define workflow structure with branch and converge
342
+ manager.set_start_node("start")
343
+ manager.add_transition(from_node="start", to_node="check")
344
+ manager.add_transition(
345
+ from_node="check",
346
+ to_node=[
347
+ BranchCondition(to_node="ai_node", condition="ctx['result'] == 'yes'"),
348
+ BranchCondition(to_node="goodbye", condition="ctx['result'] == 'no'")
349
+ ]
350
+ )
351
+ manager.add_transition(from_node="ai_node", to_node="finalize")
352
+ manager.add_transition(from_node="goodbye", to_node="finalize")
353
+ manager.add_transition(from_node="finalize", to_node="template_node")
354
+ manager.add_convergence_node("finalize")
272
355
 
273
356
  # Generate and print both diagrams
274
357
  workflow_def = manager.workflow
275
358
  print("Flowchart (default):")
276
- print(generate_mermaid_diagram(workflow_def, include_subgraphs=False, title="Content Processing Workflow"))
359
+ print(generate_mermaid_diagram(workflow_def, include_subgraphs=False, title="Sample Workflow with Template and Mapping"))
277
360
  print("\nState Diagram:")
278
- print(generate_mermaid_diagram(workflow_def, diagram_type="stateDiagram", title="Content Processing Workflow"))
279
-
361
+ print(generate_mermaid_diagram(workflow_def, diagram_type="stateDiagram", title="Sample Workflow with Template and Mapping"))
362
+
280
363
 
281
364
  if __name__ == "__main__":
282
365
  main()