quantalogic 0.53.0__py3-none-any.whl → 0.56.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quantalogic/__init__.py +7 -0
- quantalogic/flow/flow.py +267 -80
- quantalogic/flow/flow_extractor.py +216 -87
- quantalogic/flow/flow_generator.py +157 -88
- quantalogic/flow/flow_manager.py +252 -125
- quantalogic/flow/flow_manager_schema.py +62 -43
- quantalogic/flow/flow_mermaid.py +151 -68
- quantalogic/flow/flow_validator.py +204 -77
- quantalogic/flow/flow_yaml.md +341 -156
- quantalogic/tools/safe_python_interpreter_tool.py +6 -1
- quantalogic/xml_parser.py +5 -1
- quantalogic/xml_tool_parser.py +4 -1
- {quantalogic-0.53.0.dist-info → quantalogic-0.56.0.dist-info}/METADATA +16 -6
- {quantalogic-0.53.0.dist-info → quantalogic-0.56.0.dist-info}/RECORD +17 -17
- {quantalogic-0.53.0.dist-info → quantalogic-0.56.0.dist-info}/LICENSE +0 -0
- {quantalogic-0.53.0.dist-info → quantalogic-0.56.0.dist-info}/WHEEL +0 -0
- {quantalogic-0.53.0.dist-info → quantalogic-0.56.0.dist-info}/entry_points.txt +0 -0
quantalogic/flow/flow_manager_schema.py CHANGED
@@ -4,13 +4,7 @@ from pydantic import BaseModel, Field, model_validator
 
 
 class FunctionDefinition(BaseModel):
-    """
-    Definition of a function used in the workflow.
-
-    This model supports both embedded functions (inline code) and external functions sourced
-    from Python modules, including PyPI packages, local files, or remote URLs.
-    """
-
+    """Definition of a function used in the workflow."""
     type: str = Field(
         ...,
         description="Type of function source. Must be 'embedded' for inline code or 'external' for module-based functions.",
@@ -55,16 +49,15 @@ class FunctionDefinition(BaseModel):
 
 class LLMConfig(BaseModel):
     """Configuration for LLM-based nodes."""
-
     model: str = Field(
         default="gpt-3.5-turbo", description="The LLM model to use (e.g., 'gpt-3.5-turbo', 'gemini/gemini-2.0-flash')."
     )
     system_prompt: Optional[str] = Field(None, description="System prompt defining the LLM's role or context.")
     prompt_template: str = Field(
-        default="{{ input }}", description="Jinja2 template for the user prompt
+        default="{{ input }}", description="Jinja2 template for the user prompt. Ignored if prompt_file is set."
     )
     prompt_file: Optional[str] = Field(
-        None, description="Path to an external Jinja2 template file
+        None, description="Path to an external Jinja2 template file. Takes precedence over prompt_template."
     )
     temperature: float = Field(
         default=0.7, ge=0.0, le=1.0, description="Controls randomness of LLM output (0.0 to 1.0)."
@@ -77,13 +70,10 @@ class LLMConfig(BaseModel):
     frequency_penalty: float = Field(
         default=0.0, ge=-2.0, le=2.0, description="Penalty for repeating words (-2.0 to 2.0)."
     )
-    stop: Optional[List[str]] = Field(None, description="List of stop sequences for LLM generation
+    stop: Optional[List[str]] = Field(None, description="List of stop sequences for LLM generation.")
     response_model: Optional[str] = Field(
         None,
-        description=(
-            "Path to a Pydantic model for structured output (e.g., 'my_module:OrderDetails'). "
-            "If specified, uses structured_llm_node; otherwise, uses llm_node."
-        ),
+        description="Path to a Pydantic model for structured output (e.g., 'my_module:OrderDetails')."
     )
     api_key: Optional[str] = Field(None, description="Custom API key for the LLM provider, if required.")
 
@@ -97,13 +87,30 @@
         return data
 
 
-class NodeDefinition(BaseModel):
-    """
-
+class TemplateConfig(BaseModel):
+    """Configuration for template-based nodes."""
+    template: str = Field(
+        default="", description="Jinja2 template string to render. Ignored if template_file is set."
+    )
+    template_file: Optional[str] = Field(
+        None, description="Path to an external Jinja2 template file. Takes precedence over template."
+    )
+
+    @model_validator(mode="before")
+    @classmethod
+    def check_template_source(cls, data: Any) -> Any:
+        """Ensure template_file and template are used appropriately."""
+        template_file = data.get("template_file")
+        template = data.get("template")
+        if not template and not template_file:
+            raise ValueError("Either 'template' or 'template_file' must be provided")
+        if template_file and not isinstance(template_file, str):
+            raise ValueError("template_file must be a string path to a Jinja2 template file")
+        return data
 
-    A node must specify exactly one of 'function', 'sub_workflow', or 'llm_config'.
-    """
 
+class NodeDefinition(BaseModel):
+    """Definition of a workflow node with template_node and inputs_mapping support."""
     function: Optional[str] = Field(
         None, description="Name of the function to execute (references a FunctionDefinition)."
     )
@@ -111,12 +118,14 @@ class NodeDefinition(BaseModel):
         None, description="Nested workflow definition for sub-workflow nodes."
     )
     llm_config: Optional[LLMConfig] = Field(None, description="Configuration for LLM-based nodes.")
+    template_config: Optional[TemplateConfig] = Field(None, description="Configuration for template-based nodes.")
+    inputs_mapping: Optional[Dict[str, str]] = Field(
+        None,
+        description="Mapping of node inputs to context keys or stringified lambda expressions (e.g., 'lambda ctx: value')."
+    )
     output: Optional[str] = Field(
         None,
-        description=
-            "Context key to store the node's result. Defaults to '<node_name>_result' "
-            "for function or LLM nodes if not specified."
-        ),
+        description="Context key to store the node's result. Defaults to '<node_name>_result' if applicable."
     )
     retries: int = Field(default=3, ge=0, description="Number of retry attempts on failure.")
     delay: float = Field(default=1.0, ge=0.0, description="Delay in seconds between retries.")
@@ -127,62 +136,72 @@ class NodeDefinition(BaseModel):
 
     @model_validator(mode="before")
     @classmethod
-    def
-        """Ensure a node has exactly one of 'function', 'sub_workflow', or '
+    def check_function_or_sub_workflow_or_llm_or_template(cls, data: Any) -> Any:
+        """Ensure a node has exactly one of 'function', 'sub_workflow', 'llm_config', or 'template_config'."""
        func = data.get("function")
        sub_wf = data.get("sub_workflow")
        llm = data.get("llm_config")
-
-
+        template = data.get("template_config")
+        if sum(x is not None for x in (func, sub_wf, llm, template)) != 1:
+            raise ValueError("Node must have exactly one of 'function', 'sub_workflow', 'llm_config', or 'template_config'")
        return data
 
 
+class BranchCondition(BaseModel):
+    """Definition of a branch condition for a transition."""
+    to_node: str = Field(
+        ..., description="Target node name for this branch."
+    )
+    condition: Optional[str] = Field(
+        None, description="Python expression using 'ctx' for conditional branching."
+    )
+
+
 class TransitionDefinition(BaseModel):
     """Definition of a transition between nodes."""
-
     from_node: str = Field(
         ...,
         description="Source node name for the transition.",
     )
-    to_node: Union[str, List[str]] = Field(
-        ...,
+    to_node: Union[str, List[Union[str, BranchCondition]]] = Field(
+        ...,
+        description=(
+            "Target node(s). Can be: a string, list of strings (parallel), or list of BranchCondition (branching)."
+        ),
     )
     condition: Optional[str] = Field(
-        None,
+        None,
+        description="Python expression using 'ctx' for simple transitions."
     )
 
 
 class WorkflowStructure(BaseModel):
     """Structure defining the workflow's execution flow."""
-
     start: Optional[str] = Field(None, description="Name of the starting node.")
     transitions: List[TransitionDefinition] = Field(
         default_factory=list, description="List of transitions between nodes."
     )
+    convergence_nodes: List[str] = Field(
+        default_factory=list, description="List of nodes where branches converge."
+    )
 
 
 class WorkflowDefinition(BaseModel):
     """Top-level definition of the workflow."""
-
     functions: Dict[str, FunctionDefinition] = Field(
-        default_factory=dict, description="Dictionary of function definitions
+        default_factory=dict, description="Dictionary of function definitions."
     )
     nodes: Dict[str, NodeDefinition] = Field(default_factory=dict, description="Dictionary of node definitions.")
     workflow: WorkflowStructure = Field(
-        default_factory=lambda: WorkflowStructure(start=None), description="Main workflow structure
+        default_factory=lambda: WorkflowStructure(start=None), description="Main workflow structure."
     )
     observers: List[str] = Field(
-        default_factory=list, description="List of observer function names
+        default_factory=list, description="List of observer function names."
     )
     dependencies: List[str] = Field(
         default_factory=list,
-        description=
-            "List of Python module dependencies required by the workflow. "
-            "Examples: PyPI packages ('requests>=2.28.0'), local paths ('/path/to/module.py'), "
-            "or remote URLs ('https://example.com/module.py'). Processed during workflow instantiation."
-        ),
+        description="List of Python module dependencies."
     )
 
 
-# Resolve forward reference for sub_workflow in NodeDefinition
 NodeDefinition.model_rebuild()
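Taken together, the schema changes above introduce four new capabilities: template-based nodes (TemplateConfig), per-node inputs_mapping, conditional branching via BranchCondition entries in TransitionDefinition.to_node, and convergence_nodes on WorkflowStructure. The sketch below is a minimal illustration of how these models compose into a WorkflowDefinition, using only fields shown in this diff; the node names, context keys, and the referenced function names ("fetch_data", "manual_review") are hypothetical and not taken from the package.

```python
# Illustrative sketch assembled from the 0.56.0 schema shown above.
# Node names, context keys, and function references are hypothetical;
# a real workflow would also define matching FunctionDefinition entries.
from quantalogic.flow.flow_manager_schema import (
    BranchCondition,
    NodeDefinition,
    TemplateConfig,
    TransitionDefinition,
    WorkflowDefinition,
    WorkflowStructure,
)

workflow = WorkflowDefinition(
    nodes={
        "start": NodeDefinition(function="fetch_data", output="data"),
        # Template-based node with an explicit inputs_mapping (new in this release).
        "summary": NodeDefinition(
            template_config=TemplateConfig(template="Result: {{ data }}"),
            inputs_mapping={"data": "lambda ctx: ctx['data']"},
            output="summary_text",
        ),
        "review": NodeDefinition(function="manual_review", output="review_result"),
        "done": NodeDefinition(
            template_config=TemplateConfig(template="Workflow complete."),
            output="final_text",
        ),
    },
    workflow=WorkflowStructure(
        start="start",
        transitions=[
            # A list of BranchCondition objects expresses conditional branching.
            TransitionDefinition(
                from_node="start",
                to_node=[
                    BranchCondition(to_node="summary", condition="ctx['data'] is not None"),
                    BranchCondition(to_node="review", condition="ctx['data'] is None"),
                ],
            ),
            TransitionDefinition(from_node="summary", to_node="done"),
            TransitionDefinition(from_node="review", to_node="done"),
        ],
        # The node where the two branches above converge.
        convergence_nodes=["done"],
    ),
)
print(workflow.model_dump_json(indent=2))
```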
quantalogic/flow/flow_mermaid.py CHANGED
@@ -1,8 +1,8 @@
 import re
-from typing import Dict, List, Optional, Set, Tuple
+from typing import Dict, List, Optional, Set, Tuple, Union
 
 from quantalogic.flow.flow_manager import WorkflowManager
-from quantalogic.flow.flow_manager_schema import NodeDefinition, WorkflowDefinition
+from quantalogic.flow.flow_manager_schema import BranchCondition, NodeDefinition, WorkflowDefinition
 
 
 def get_node_label_and_type(node_name: str, node_def: Optional[NodeDefinition], has_conditions: bool) -> Tuple[str, str, str]:
@@ -12,29 +12,48 @@ def get_node_label_and_type(node_name: str, node_def: Optional[NodeDefinition], has_conditions: bool) -> Tuple[str, str, str]:
     Args:
         node_name: The name of the node.
         node_def: The NodeDefinition object from the workflow, or None if undefined.
-        has_conditions: True if the node has outgoing transitions with conditions.
+        has_conditions: True if the node has outgoing transitions with conditions (branching).
 
     Returns:
         A tuple of (display label, type key for styling, shape identifier).
     """
-    #
+    # Escape quotes for Mermaid compatibility
     escaped_name = node_name.replace('"', '\\"')
-
-    # Use diamond shape for nodes with conditional transitions, rectangle otherwise
     shape = "diamond" if has_conditions else "rect"
 
     if not node_def:
         return f"{escaped_name} (unknown)", "unknown", shape
-
+
+    # Base label starts with node name and type
     if node_def.function:
-
+        label = f"{escaped_name} (function)"
+        node_type = "function"
     elif node_def.llm_config:
         if node_def.llm_config.response_model:
-
-
+            label = f"{escaped_name} (structured LLM)"
+            node_type = "structured_llm"
+        else:
+            label = f"{escaped_name} (LLM)"
+            node_type = "llm"
+    elif node_def.template_config:
+        label = f"{escaped_name} (template)"
+        node_type = "template"
     elif node_def.sub_workflow:
-
-
+        label = f"{escaped_name} (Sub-Workflow)"
+        node_type = "sub_workflow"
+    else:
+        label = f"{escaped_name} (unknown)"
+        node_type = "unknown"
+
+    # Append inputs_mapping if present
+    if node_def and node_def.inputs_mapping:
+        mapping_str = ", ".join(f"{k}={v}" for k, v in node_def.inputs_mapping.items())
+        # Truncate if too long for readability (e.g., > 30 chars)
+        if len(mapping_str) > 30:
+            mapping_str = mapping_str[:27] + "..."
+        label += f"\\nInputs: {mapping_str}"
+
+    return label, node_type, shape
 
 
 def generate_mermaid_diagram(
@@ -51,7 +70,7 @@ def generate_mermaid_diagram(
         workflow_def: The workflow definition to visualize.
         include_subgraphs: If True, nests sub-workflows in Mermaid subgraphs (flowchart only).
         title: Optional title for the diagram.
-        include_legend: If True, adds a comment-based legend explaining node types.
+        include_legend: If True, adds a comment-based legend explaining node types and shapes.
         diagram_type: Type of diagram to generate: "flowchart" (default) or "stateDiagram".
 
     Returns:
@@ -68,6 +87,7 @@ def generate_mermaid_diagram(
         "function": "fill:#90CAF9,stroke:#42A5F5,stroke-width:2px",  # Pastel Blue
         "structured_llm": "fill:#A5D6A7,stroke:#66BB6A,stroke-width:2px",  # Pastel Green
         "llm": "fill:#CE93D8,stroke:#AB47BC,stroke-width:2px",  # Pastel Purple
+        "template": "fill:#FCE4EC,stroke:#F06292,stroke-width:2px",  # Pastel Pink (new for template)
         "sub_workflow": "fill:#FFCCBC,stroke:#FF7043,stroke-width:2px",  # Pastel Orange
         "unknown": "fill:#CFD8DC,stroke:#B0BEC5,stroke-width:2px"  # Pastel Grey
     }
@@ -75,7 +95,7 @@ def generate_mermaid_diagram(
     # Shape mappings for flowchart syntax
     shape_syntax: Dict[str, Tuple[str, str]] = {
         "rect": ("[", "]"),  # Rectangle for standard nodes
-        "diamond": ("{{", "}}")  # Diamond for decision points
+        "diamond": ("{{", "}}")  # Diamond for decision points (branching)
     }
 
     # Validate node names for Mermaid compatibility (alphanumeric, underscore, hyphen)
@@ -94,17 +114,26 @@ def generate_mermaid_diagram(
                 raise ValueError(f"Invalid node name '{trans.to_node}' for Mermaid")
             all_nodes.add(trans.to_node)
         else:
-            for
-                if
-
-
-
-
+            for tn in trans.to_node:
+                target = tn if isinstance(tn, str) else tn.to_node
+                if re.search(invalid_chars, target):
+                    raise ValueError(f"Invalid node name '{target}' for Mermaid")
+                all_nodes.add(target)
+    for conv_node in workflow_def.workflow.convergence_nodes:
+        if re.search(invalid_chars, conv_node):
+            raise ValueError(f"Invalid node name '{conv_node}' for Mermaid")
+        all_nodes.add(conv_node)
+
+    # Determine nodes with conditional transitions (branching)
     conditional_nodes: Set[str] = set()
     for trans in workflow_def.workflow.transitions:
-        if trans.condition and isinstance(trans.to_node, str)
+        if (trans.condition and isinstance(trans.to_node, str)) or \
+           (isinstance(trans.to_node, list) and any(isinstance(tn, BranchCondition) and tn.condition for tn in trans.to_node)):
             conditional_nodes.add(trans.from_node)
 
+    # Identify convergence nodes
+    convergence_nodes: Set[str] = set(workflow_def.workflow.convergence_nodes)
+
     # Shared node definitions and types
     node_types: Dict[str, str] = {}
     node_shapes: Dict[str, str] = {}  # Only used for flowchart
@@ -119,13 +148,14 @@ def generate_mermaid_diagram(
     if title:
         mermaid_code += f" %% Diagram: {title}\n"
 
-    # Optional legend for UX
+    # Optional legend for UX, updated to include template nodes
     if include_legend:
         mermaid_code += " %% Legend:\n"
         if diagram_type == "flowchart":
-            mermaid_code += " %% - Rectangle: Process Step\n"
-            mermaid_code += " %% - Diamond: Decision Point\n"
-            mermaid_code += " %% - Colors: Blue (Function), Green (Structured LLM), Purple (LLM), Orange (Sub-Workflow), Grey (Unknown)\n"
+            mermaid_code += " %% - Rectangle: Process Step or Convergence Point\n"
+            mermaid_code += " %% - Diamond: Decision Point (Branching)\n"
+            mermaid_code += " %% - Colors: Blue (Function), Green (Structured LLM), Purple (LLM), Pink (Template), Orange (Sub-Workflow), Grey (Unknown)\n"
+            mermaid_code += " %% - Dashed Border: Convergence Node\n"
 
     if diagram_type == "flowchart":
         # Flowchart-specific: Generate node definitions with shapes
@@ -140,10 +170,10 @@ def generate_mermaid_diagram(
             node_shapes[node] = shape
 
         # Add node definitions
-        for node_def_str in node_defs:
+        for node_def_str in sorted(node_defs):  # Sort for consistent output
             mermaid_code += f" {node_def_str}\n"
 
-        # Generate arrows for transitions
+        # Generate arrows for transitions
         for trans in workflow_def.workflow.transitions:
             from_node = trans.from_node
             if isinstance(trans.to_node, str):
@@ -155,15 +185,29 @@ def generate_mermaid_diagram(
                 else:
                     mermaid_code += f' {from_node} --> {to_node}\n'
             else:
-                for
-
+                for tn in trans.to_node:
+                    if isinstance(tn, str):
+                        # Parallel transition (no condition)
+                        mermaid_code += f' {from_node} --> {tn}\n'
+                    else:  # BranchCondition
+                        to_node = tn.to_node
+                        condition = tn.condition
+                        if condition:
+                            cond = condition.replace('"', '\\"')[:30] + ("..." if len(condition) > 30 else "")
+                            mermaid_code += f' {from_node} -->|"{cond}"| {to_node}\n'
+                        else:
+                            mermaid_code += f' {from_node} --> {to_node}\n'
 
         # Add styles for node types
         for node, node_type in node_types.items():
             if node_type in node_styles:
-
+                style = node_styles[node_type]
+                # Add dashed stroke for convergence nodes
+                if node in convergence_nodes:
+                    style += ",stroke-dasharray:5 5"
+                mermaid_code += f" style {node} {style}\n"
 
-        # Highlight the start node
+        # Highlight the start node with a thicker border
         if workflow_def.workflow.start and workflow_def.workflow.start in node_types:
             mermaid_code += f" style {workflow_def.workflow.start} stroke-width:4px\n"
 
@@ -178,8 +222,10 @@ def generate_mermaid_diagram(
                 if isinstance(trans.to_node, str):
                     sub_nodes.add(trans.to_node)
                 else:
-
-
+                    for tn in trans.to_node:
+                        target = tn if isinstance(tn, str) else tn.to_node
+                        sub_nodes.add(target)
+            for sub_node in sorted(sub_nodes):  # Sort for consistency
                 mermaid_code += f" {sub_node}[[{sub_node}]]\n"
             mermaid_code += " end\n"
 
@@ -189,6 +235,9 @@ def generate_mermaid_diagram(
             node_def_state: Optional[NodeDefinition] = workflow_def.nodes.get(node)
             has_conditions = node in conditional_nodes
             label, node_type, _ = get_node_label_and_type(node, node_def_state, has_conditions)  # Shape unused
+            # Append (Convergence) to label for convergence nodes
+            if node in convergence_nodes:
+                label += " (Convergence)"
             mermaid_code += f" state \"{label}\" as {node}\n"
             node_types[node] = node_type
 
@@ -208,75 +257,109 @@ def generate_mermaid_diagram(
                 else:
                     mermaid_code += f" {from_node} --> {to_node}\n"
             else:
-
-
-
+                for tn in trans.to_node:
+                    if isinstance(tn, str):
+                        # Parallel transition approximated with a note
+                        mermaid_code += f" {from_node} --> {tn} : parallel\n"
+                    else:  # BranchCondition
+                        to_node = tn.to_node
+                        condition = tn.condition
+                        if condition:
+                            cond = condition.replace('"', '\\"')[:30] + ("..." if len(condition) > 30 else "")
+                            mermaid_code += f" {from_node} --> {to_node} : {cond}\n"
+                        else:
+                            mermaid_code += f" {from_node} --> {to_node}\n"
 
         # Add styles for node types
         for node, node_type in node_types.items():
             if node_type in node_styles:
-
+                style = node_styles[node_type]
+                # Add dashed stroke for convergence nodes
+                if node in convergence_nodes:
+                    style += ",stroke-dasharray:5 5"
+                mermaid_code += f" style {node} {style}\n"
 
     mermaid_code += "```\n"
     return mermaid_code
 
 
 def main() -> None:
-    """Create a complex workflow and print its
+    """Create a complex workflow with branch, converge, template node, and input mapping, and print its Mermaid diagram."""
     manager = WorkflowManager()
 
     # Add functions
     manager.add_function(
-        name="
+        name="say_hello",
         type_="embedded",
-        code="
+        code="def say_hello():\n return 'Hello, World!'"
     )
     manager.add_function(
-        name="
+        name="check_condition",
         type_="embedded",
-        code="
+        code="def check_condition(text: str):\n return 'yes' if 'Hello' in text else 'no'"
     )
     manager.add_function(
-        name="
+        name="say_goodbye",
         type_="embedded",
-        code="
+        code="def say_goodbye():\n return 'Goodbye, World!'"
    )
    manager.add_function(
-        name="
+        name="finalize",
        type_="embedded",
-        code="
+        code="def finalize(text: str):\n return 'Done'"
    )
 
+    # Add nodes
+    manager.add_node(name="start", function="say_hello", output="text")
+    manager.add_node(name="check", function="check_condition", output="result",
+                     inputs_mapping={"text": "text"})
+    manager.add_node(name="goodbye", function="say_goodbye", output="farewell")
+    manager.add_node(name="finalize", function="finalize", output="status",
+                     inputs_mapping={"text": "lambda ctx: ctx['farewell'] if ctx['result'] == 'no' else ctx['ai_result']"})
+
     # Add LLM node
-
-    "
-
-
-
-
-
-
+    manager.add_node(
+        name="ai_node",
+        llm_config={
+            "model": "gpt-3.5-turbo",
+            "prompt_template": "{{text}}",
+            "temperature": 0.7
+        },
+        output="ai_result"
+    )
 
-    # Add
-    manager.add_node(
-
-
-
+    # Add template node
+    manager.add_node(
+        name="template_node",
+        template_config={
+            "template": "Response: {{text}} - {{result}}"
+        },
+        output="template_output",
+        inputs_mapping={"text": "text", "result": "result"}
+    )
 
-    # Define workflow structure
-    manager.set_start_node("
-    manager.add_transition(from_node="
-    manager.add_transition(
-
-
+    # Define workflow structure with branch and converge
+    manager.set_start_node("start")
+    manager.add_transition(from_node="start", to_node="check")
+    manager.add_transition(
+        from_node="check",
+        to_node=[
+            BranchCondition(to_node="ai_node", condition="ctx['result'] == 'yes'"),
+            BranchCondition(to_node="goodbye", condition="ctx['result'] == 'no'")
+        ]
+    )
+    manager.add_transition(from_node="ai_node", to_node="finalize")
+    manager.add_transition(from_node="goodbye", to_node="finalize")
+    manager.add_transition(from_node="finalize", to_node="template_node")
+    manager.add_convergence_node("finalize")
 
     # Generate and print both diagrams
     workflow_def = manager.workflow
     print("Flowchart (default):")
-    print(generate_mermaid_diagram(workflow_def, include_subgraphs=False, title="
+    print(generate_mermaid_diagram(workflow_def, include_subgraphs=False, title="Sample Workflow with Template and Mapping"))
     print("\nState Diagram:")
-    print(generate_mermaid_diagram(workflow_def, diagram_type="stateDiagram", title="
-
+    print(generate_mermaid_diagram(workflow_def, diagram_type="stateDiagram", title="Sample Workflow with Template and Mapping"))
+
 
 
 if __name__ == "__main__":
     main()