quantalogic 0.80__py3-none-any.whl → 0.93__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. quantalogic/flow/__init__.py +16 -34
  2. quantalogic/main.py +11 -6
  3. quantalogic/tools/tool.py +8 -922
  4. quantalogic-0.93.dist-info/METADATA +475 -0
  5. {quantalogic-0.80.dist-info → quantalogic-0.93.dist-info}/RECORD +8 -54
  6. quantalogic/codeact/TODO.md +0 -14
  7. quantalogic/codeact/__init__.py +0 -0
  8. quantalogic/codeact/agent.py +0 -478
  9. quantalogic/codeact/cli.py +0 -50
  10. quantalogic/codeact/cli_commands/__init__.py +0 -0
  11. quantalogic/codeact/cli_commands/create_toolbox.py +0 -45
  12. quantalogic/codeact/cli_commands/install_toolbox.py +0 -20
  13. quantalogic/codeact/cli_commands/list_executor.py +0 -15
  14. quantalogic/codeact/cli_commands/list_reasoners.py +0 -15
  15. quantalogic/codeact/cli_commands/list_toolboxes.py +0 -47
  16. quantalogic/codeact/cli_commands/task.py +0 -215
  17. quantalogic/codeact/cli_commands/tool_info.py +0 -24
  18. quantalogic/codeact/cli_commands/uninstall_toolbox.py +0 -43
  19. quantalogic/codeact/config.yaml +0 -21
  20. quantalogic/codeact/constants.py +0 -9
  21. quantalogic/codeact/events.py +0 -85
  22. quantalogic/codeact/examples/README.md +0 -342
  23. quantalogic/codeact/examples/agent_sample.yaml +0 -29
  24. quantalogic/codeact/executor.py +0 -186
  25. quantalogic/codeact/history_manager.py +0 -94
  26. quantalogic/codeact/llm_util.py +0 -57
  27. quantalogic/codeact/plugin_manager.py +0 -92
  28. quantalogic/codeact/prompts/error_format.j2 +0 -11
  29. quantalogic/codeact/prompts/generate_action.j2 +0 -77
  30. quantalogic/codeact/prompts/generate_program.j2 +0 -52
  31. quantalogic/codeact/prompts/response_format.j2 +0 -11
  32. quantalogic/codeact/react_agent.py +0 -318
  33. quantalogic/codeact/reasoner.py +0 -185
  34. quantalogic/codeact/templates/toolbox/README.md.j2 +0 -10
  35. quantalogic/codeact/templates/toolbox/pyproject.toml.j2 +0 -16
  36. quantalogic/codeact/templates/toolbox/tools.py.j2 +0 -6
  37. quantalogic/codeact/templates.py +0 -7
  38. quantalogic/codeact/tools_manager.py +0 -258
  39. quantalogic/codeact/utils.py +0 -62
  40. quantalogic/codeact/xml_utils.py +0 -126
  41. quantalogic/flow/flow.py +0 -1070
  42. quantalogic/flow/flow_extractor.py +0 -783
  43. quantalogic/flow/flow_generator.py +0 -322
  44. quantalogic/flow/flow_manager.py +0 -676
  45. quantalogic/flow/flow_manager_schema.py +0 -287
  46. quantalogic/flow/flow_mermaid.py +0 -365
  47. quantalogic/flow/flow_validator.py +0 -479
  48. quantalogic/flow/flow_yaml.linkedin.md +0 -31
  49. quantalogic/flow/flow_yaml.md +0 -767
  50. quantalogic/flow/templates/prompt_check_inventory.j2 +0 -1
  51. quantalogic/flow/templates/system_check_inventory.j2 +0 -1
  52. quantalogic-0.80.dist-info/METADATA +0 -900
  53. {quantalogic-0.80.dist-info → quantalogic-0.93.dist-info}/LICENSE +0 -0
  54. {quantalogic-0.80.dist-info → quantalogic-0.93.dist-info}/WHEEL +0 -0
  55. {quantalogic-0.80.dist-info → quantalogic-0.93.dist-info}/entry_points.txt +0 -0
@@ -1,287 +0,0 @@
1
- from typing import Any, Dict, List, Optional, Union
2
-
3
- from pydantic import BaseModel, Field, model_validator
4
-
5
-
6
class FunctionDefinition(BaseModel):
    """Schema for a function referenced by workflow nodes.

    A function is either *embedded* (inline Python source carried in ``code``)
    or *external* (a ``function`` name resolved from a ``module`` source such
    as a PyPI package, local file path, or remote URL).
    """

    type: str = Field(
        ...,
        description="Type of function source. Must be 'embedded' for inline code or 'external' for module-based functions.",
    )
    code: Optional[str] = Field(
        None, description="Multi-line Python code for embedded functions. Required if type is 'embedded'."
    )
    module: Optional[str] = Field(
        None,
        description=(
            "Source of the external module for 'external' functions. Can be:"
            " - A PyPI package name (e.g., 'requests', 'numpy') installed in the environment."
            " - A local file path (e.g., '/path/to/module.py')."
            " - A remote URL (e.g., 'https://example.com/module.py')."
            " Required if type is 'external'."
        ),
    )
    function: Optional[str] = Field(
        None,
        description="Name of the function within the module for 'external' functions. Required if type is 'external'.",
    )

    @model_validator(mode="before")
    @classmethod
    def check_function_source(cls, data: Any) -> Any:
        """Ensure the function definition is valid based on its type.

        Args:
            data: Raw data to validate.

        Returns:
            Validated data.

        Raises:
            ValueError: If the function source configuration is invalid.
        """
        source_type = data.get("type")
        # Reject unknown types up front so the branches below only handle
        # the two legal values.
        if source_type not in ("embedded", "external"):
            raise ValueError("Function type must be 'embedded' or 'external'")
        if source_type == "embedded":
            if not data.get("code"):
                raise ValueError("Embedded functions require 'code' to be specified")
            if data.get("module") or data.get("function"):
                raise ValueError("Embedded functions should not specify 'module' or 'function'")
        else:
            if not (data.get("module") and data.get("function")):
                raise ValueError("External functions require both 'module' and 'function'")
            if data.get("code"):
                raise ValueError("External functions should not specify 'code'")
        return data
58
-
59
-
60
class LLMConfig(BaseModel):
    """Configuration for LLM-based nodes.

    Carries the model selection, prompt sources (inline templates or external
    Jinja2 files), and the usual sampling parameters passed to the provider.
    """

    model: str = Field(
        default="gpt-3.5-turbo",
        description=(
            "The LLM model to use. Can be a static model name (e.g., 'gpt-3.5-turbo', 'gemini/gemini-2.0-flash') "
            "or a lambda expression (e.g., 'lambda ctx: ctx.get(\"model_name\")') for dynamic selection."
        ),
    )
    system_prompt: Optional[str] = Field(None, description="System prompt defining the LLM's role or context.")
    system_prompt_file: Optional[str] = Field(
        None,
        description="Path to an external Jinja2 template file for the system prompt. Takes precedence over system_prompt."
    )
    prompt_template: str = Field(
        default="{{ input }}", description="Jinja2 template for the user prompt. Ignored if prompt_file is set."
    )
    prompt_file: Optional[str] = Field(
        None, description="Path to an external Jinja2 template file. Takes precedence over prompt_template."
    )
    # Sampling controls, constrained to the ranges the providers accept.
    temperature: float = Field(
        default=0.7, ge=0.0, le=1.0, description="Controls randomness of LLM output (0.0 to 1.0)."
    )
    max_tokens: Optional[int] = Field(None, ge=1, description="Maximum number of tokens in the response.")
    top_p: float = Field(default=1.0, ge=0.0, le=1.0, description="Nucleus sampling parameter (0.0 to 1.0).")
    presence_penalty: float = Field(
        default=0.0, ge=-2.0, le=2.0, description="Penalty for repeating topics (-2.0 to 2.0)."
    )
    frequency_penalty: float = Field(
        default=0.0, ge=-2.0, le=2.0, description="Penalty for repeating words (-2.0 to 2.0)."
    )
    stop: Optional[List[str]] = Field(None, description="List of stop sequences for LLM generation.")
    response_model: Optional[str] = Field(
        None,
        description="Path to a Pydantic model for structured output (e.g., 'my_module:OrderDetails')."
    )
    api_key: Optional[str] = Field(None, description="Custom API key for the LLM provider, if required.")

    @model_validator(mode="before")
    @classmethod
    def check_prompt_source(cls, data: Any) -> Any:
        """Ensure prompt_file and prompt_template are used appropriately.

        Args:
            data: Raw data to validate.

        Returns:
            Validated data.

        Raises:
            ValueError: If prompt configuration is invalid.
        """
        template_path = data.get("prompt_file")
        # Only a truthy non-string value is rejected; None/empty pass through.
        if template_path and not isinstance(template_path, str):
            raise ValueError("prompt_file must be a string path to a Jinja2 template file")
        return data
116
-
117
-
118
class TemplateConfig(BaseModel):
    """Configuration for template-based nodes.

    Exactly one rendering source must be usable: either an inline Jinja2
    ``template`` string or a ``template_file`` path (which wins when set).
    """

    template: str = Field(
        default="", description="Jinja2 template string to render. Ignored if template_file is set."
    )
    template_file: Optional[str] = Field(
        None, description="Path to an external Jinja2 template file. Takes precedence over template."
    )

    @model_validator(mode="before")
    @classmethod
    def check_template_source(cls, data: Any) -> Any:
        """Ensure template_file and template are used appropriately.

        Args:
            data: Raw data to validate.

        Returns:
            Validated data.

        Raises:
            ValueError: If template configuration is invalid.
        """
        file_source = data.get("template_file")
        inline_source = data.get("template")
        # At least one source must be present (truthy) ...
        if not (inline_source or file_source):
            raise ValueError("Either 'template' or 'template_file' must be provided")
        # ... and a provided file source must be a path string.
        if file_source and not isinstance(file_source, str):
            raise ValueError("template_file must be a string path to a Jinja2 template file")
        return data
148
-
149
-
150
class NodeDefinition(BaseModel):
    """Definition of a workflow node with template_node and inputs_mapping support.

    A node is backed by exactly one execution strategy: a named function, a
    nested sub-workflow, an LLM configuration, or a template configuration.
    """

    function: Optional[str] = Field(
        None, description="Name of the function to execute (references a FunctionDefinition)."
    )
    sub_workflow: Optional["WorkflowStructure"] = Field(
        None, description="Nested workflow definition for sub-workflow nodes."
    )
    llm_config: Optional[LLMConfig] = Field(None, description="Configuration for LLM-based nodes.")
    template_config: Optional[TemplateConfig] = Field(None, description="Configuration for template-based nodes.")
    inputs_mapping: Optional[Dict[str, str]] = Field(
        None,
        description="Mapping of node inputs to context keys or stringified lambda expressions (e.g., 'lambda ctx: value')."
    )
    output: Optional[str] = Field(
        None,
        description="Context key to store the node's result. Defaults to '<node_name>_result' if applicable."
    )
    # Execution controls shared by every node kind.
    retries: int = Field(default=3, ge=0, description="Number of retry attempts on failure.")
    delay: float = Field(default=1.0, ge=0.0, description="Delay in seconds between retries.")
    timeout: Optional[float] = Field(
        None, ge=0.0, description="Maximum execution time in seconds (null for no timeout)."
    )
    parallel: bool = Field(default=False, description="Whether the node can execute in parallel with others.")

    @model_validator(mode="before")
    @classmethod
    def check_function_or_sub_workflow_or_llm_or_template(cls, data: Any) -> Any:
        """Ensure a node has exactly one of 'function', 'sub_workflow', 'llm_config', or 'template_config'.

        Args:
            data: Raw data to validate.

        Returns:
            Validated data.

        Raises:
            ValueError: If node type configuration is invalid.
        """
        strategy_keys = ("function", "sub_workflow", "llm_config", "template_config")
        provided = sum(data.get(key) is not None for key in strategy_keys)
        if provided != 1:
            raise ValueError("Node must have exactly one of 'function', 'sub_workflow', 'llm_config', or 'template_config'")
        return data
196
-
197
-
198
class BranchCondition(BaseModel):
    """Definition of a branch condition for a transition.

    Pairs a target node with an optional guard expression; a ``None``
    condition means the branch is taken unconditionally.
    """

    # Target node name for this branch.
    to_node: str = Field(..., description="Target node name for this branch.")
    # Guard expression evaluated against the workflow context.
    condition: Optional[str] = Field(None, description="Python expression using 'ctx' for conditional branching.")
206
-
207
-
208
class TransitionDefinition(BaseModel):
    """Definition of a transition between nodes.

    ``to_node`` accepts a single node name, a list of names (parallel fan-out),
    or a list of :class:`BranchCondition` objects (conditional branching).
    """

    from_node: str = Field(
        ...,
        description="Source node name for the transition.",
    )
    to_node: Union[str, List[Union[str, BranchCondition]]] = Field(
        ...,
        description=(
            "Target node(s). Can be: a string, list of strings (parallel), or list of BranchCondition (branching)."
        ),
    )
    # Guard for the simple (single-target) form; branch-level guards live on
    # each BranchCondition instead.
    condition: Optional[str] = Field(
        None,
        description="Python expression using 'ctx' for simple transitions."
    )
224
-
225
-
226
class LoopDefinition(BaseModel):
    """Definition of a loop within the workflow."""
    # Node names executed inside the loop body, in workflow order.
    nodes: List[str] = Field(..., description="List of node names in the loop.")
    # Loop continues while this expression (evaluated against 'ctx') holds.
    condition: str = Field(..., description="Python expression using 'ctx' for the loop condition.")
    # Node control jumps to once the condition no longer holds.
    exit_node: str = Field(..., description="Node to transition to when the loop ends.")
231
-
232
-
233
class WorkflowStructure(BaseModel):
    """Structure defining the workflow's execution flow.

    Holds the start node, the transition graph, optional explicit loops, and
    the set of convergence nodes where parallel/conditional branches rejoin.
    """

    start: Optional[str] = Field(None, description="Name of the starting node.")
    transitions: List[TransitionDefinition] = Field(
        default_factory=list, description="List of transitions between nodes."
    )
    loops: List[LoopDefinition] = Field(
        default_factory=list, description="List of loop definitions (optional, for explicit loop support)."
    )
    convergence_nodes: List[str] = Field(
        default_factory=list, description="List of nodes where branches converge."
    )

    @model_validator(mode="before")
    @classmethod
    def check_loop_nodes(cls, data: Any) -> Any:
        """Ensure all nodes in loops exist in the workflow.

        Args:
            data: Raw data to validate.

        Returns:
            Validated data.

        Raises:
            ValueError: If loop nodes are not defined.
        """
        loops = data.get("loops", [])
        # BUG FIX: WorkflowStructure itself has no 'nodes' field — the node
        # map lives on WorkflowDefinition — so the previous
        # data.get("nodes", {}) always produced an empty set and rejected
        # every non-empty loop. Only validate when a node map is actually
        # supplied alongside the structure data; otherwise defer the check
        # to whoever owns the node definitions.
        node_map = data.get("nodes")
        if node_map is None:
            return data
        known_nodes = set(node_map.keys())
        for loop in loops:
            for node in loop["nodes"] + [loop["exit_node"]]:
                if node not in known_nodes:
                    raise ValueError(f"Loop node '{node}' not defined in nodes")
        return data
267
-
268
-
269
class WorkflowDefinition(BaseModel):
    """Top-level definition of the workflow.

    Aggregates the reusable function registry, the node definitions, the
    execution structure, observer hooks, and module dependencies.
    """

    # Registry of reusable functions referenced by nodes.
    functions: Dict[str, FunctionDefinition] = Field(
        default_factory=dict, description="Dictionary of function definitions."
    )
    # Node name -> node definition.
    nodes: Dict[str, NodeDefinition] = Field(default_factory=dict, description="Dictionary of node definitions.")
    # Execution graph; defaults to an empty structure with no start node.
    workflow: WorkflowStructure = Field(
        default_factory=lambda: WorkflowStructure(start=None), description="Main workflow structure."
    )
    observers: List[str] = Field(
        default_factory=list, description="List of observer function names."
    )
    dependencies: List[str] = Field(
        default_factory=list,
        description="List of Python module dependencies."
    )
285
-
286
-
287
- NodeDefinition.model_rebuild()
@@ -1,365 +0,0 @@
1
- import re
2
- from typing import Dict, List, Optional, Set, Tuple, Union
3
-
4
- from quantalogic.flow.flow_manager import WorkflowManager
5
- from quantalogic.flow.flow_manager_schema import BranchCondition, NodeDefinition, WorkflowDefinition
6
-
7
-
8
def get_node_label_and_type(node_name: str, node_def: Optional[NodeDefinition], has_conditions: bool) -> Tuple[str, str, str]:
    """
    Generate a label, type identifier, and shape for a node based on its definition and transition context.

    Args:
        node_name: The name of the node.
        node_def: The NodeDefinition object from the workflow, or None if undefined.
        has_conditions: True if the node has outgoing transitions with conditions (branching).

    Returns:
        A tuple of (display label, type key for styling, shape identifier).
    """
    # Quotes must be escaped so the label survives Mermaid's string syntax.
    safe_name = node_name.replace('"', '\\"')
    # Branching nodes render as decision diamonds; everything else as boxes.
    shape = "diamond" if has_conditions else "rect"

    if not node_def:
        return f"{safe_name} (unknown)", "unknown", shape

    # Map the node's single execution strategy to a (type, label) pair.
    if node_def.function:
        node_type, label = "function", f"{safe_name} (function)"
    elif node_def.llm_config:
        if node_def.llm_config.response_model:
            node_type, label = "structured_llm", f"{safe_name} (structured LLM)"
        else:
            node_type, label = "llm", f"{safe_name} (LLM)"
    elif node_def.template_config:
        node_type, label = "template", f"{safe_name} (template)"
    elif node_def.sub_workflow:
        node_type, label = "sub_workflow", f"{safe_name} (Sub-Workflow)"
    else:
        node_type, label = "unknown", f"{safe_name} (unknown)"

    # Append inputs_mapping if present (node_def is known truthy here).
    if node_def.inputs_mapping:
        mapping_text = ", ".join(f"{k}={v}" for k, v in node_def.inputs_mapping.items())
        # Truncate if too long for readability (e.g., > 30 chars)
        if len(mapping_text) > 30:
            mapping_text = mapping_text[:27] + "..."
        label += f"\\nInputs: {mapping_text}"

    return label, node_type, shape
57
-
58
-
59
def generate_mermaid_diagram(
    workflow_def: WorkflowDefinition,
    include_subgraphs: bool = False,
    title: Optional[str] = None,
    include_legend: bool = True,
    diagram_type: str = "flowchart"
) -> str:
    """
    Generate a Mermaid diagram (flowchart or stateDiagram) from a WorkflowDefinition with pastel colors and optimal UX.

    Args:
        workflow_def: The workflow definition to visualize.
        include_subgraphs: If True, nests sub-workflows in Mermaid subgraphs (flowchart only).
        title: Optional title for the diagram.
        include_legend: If True, adds a comment-based legend explaining node types and shapes.
        diagram_type: Type of diagram to generate: "flowchart" (default) or "stateDiagram".

    Returns:
        A string containing the Mermaid syntax for the diagram.

    Raises:
        ValueError: If node names contain invalid Mermaid characters or diagram_type is invalid.
    """
    # NOTE(review): leading whitespace inside the emitted Mermaid lines was
    # reconstructed from a whitespace-collapsed source — confirm the exact
    # indentation against the original rendering before relying on it.
    if diagram_type not in ("flowchart", "stateDiagram"):
        raise ValueError(f"Invalid diagram_type '{diagram_type}'; must be 'flowchart' or 'stateDiagram'")

    # Pastel color scheme for a soft, user-friendly look
    node_styles: Dict[str, str] = {
        "function": "fill:#90CAF9,stroke:#42A5F5,stroke-width:2px",  # Pastel Blue
        "structured_llm": "fill:#A5D6A7,stroke:#66BB6A,stroke-width:2px",  # Pastel Green
        "llm": "fill:#CE93D8,stroke:#AB47BC,stroke-width:2px",  # Pastel Purple
        "template": "fill:#FCE4EC,stroke:#F06292,stroke-width:2px",  # Pastel Pink
        "sub_workflow": "fill:#FFCCBC,stroke:#FF7043,stroke-width:2px",  # Pastel Orange
        "unknown": "fill:#CFD8DC,stroke:#B0BEC5,stroke-width:2px"  # Pastel Grey
    }

    # Shape mappings for flowchart syntax
    shape_syntax: Dict[str, Tuple[str, str]] = {
        "rect": ("[", "]"),  # Rectangle for standard nodes
        "diamond": ("{{", "}}")  # Diamond for decision points (branching)
    }

    # Validate node names for Mermaid compatibility (alphanumeric, underscore, hyphen)
    invalid_chars = r'[^a-zA-Z0-9_-]'
    all_nodes: Set[str] = set()
    if workflow_def.workflow.start:
        if re.search(invalid_chars, workflow_def.workflow.start):
            raise ValueError(f"Invalid node name '{workflow_def.workflow.start}' for Mermaid")
        all_nodes.add(workflow_def.workflow.start)
    for trans in workflow_def.workflow.transitions:
        if re.search(invalid_chars, trans.from_node):
            raise ValueError(f"Invalid node name '{trans.from_node}' for Mermaid")
        all_nodes.add(trans.from_node)
        if isinstance(trans.to_node, str):
            if re.search(invalid_chars, trans.to_node):
                raise ValueError(f"Invalid node name '{trans.to_node}' for Mermaid")
            all_nodes.add(trans.to_node)
        else:
            # to_node is a list: plain strings (parallel) or BranchCondition objects.
            for tn in trans.to_node:
                target = tn if isinstance(tn, str) else tn.to_node
                if re.search(invalid_chars, target):
                    raise ValueError(f"Invalid node name '{target}' for Mermaid")
                all_nodes.add(target)
    for conv_node in workflow_def.workflow.convergence_nodes:
        if re.search(invalid_chars, conv_node):
            raise ValueError(f"Invalid node name '{conv_node}' for Mermaid")
        all_nodes.add(conv_node)

    # Determine nodes with conditional transitions (branching)
    conditional_nodes: Set[str] = set()
    for trans in workflow_def.workflow.transitions:
        if (trans.condition and isinstance(trans.to_node, str)) or \
           (isinstance(trans.to_node, list) and any(isinstance(tn, BranchCondition) and tn.condition for tn in trans.to_node)):
            conditional_nodes.add(trans.from_node)

    # Identify convergence nodes
    convergence_nodes: Set[str] = set(workflow_def.workflow.convergence_nodes)

    # Shared node definitions and types
    node_types: Dict[str, str] = {}
    node_shapes: Dict[str, str] = {}  # Only used for flowchart

    # Assemble the Mermaid syntax
    mermaid_code = "```mermaid\n"
    if diagram_type == "flowchart":
        mermaid_code += "graph TD\n"  # Top-down layout
    else:  # stateDiagram
        mermaid_code += "stateDiagram-v2\n"

    if title:
        mermaid_code += f" %% Diagram: {title}\n"

    # Optional legend for UX
    if include_legend:
        mermaid_code += " %% Legend:\n"
        if diagram_type == "flowchart":
            # Shape hints only make sense for the flowchart variant.
            mermaid_code += " %% - Rectangle: Process Step or Convergence Point\n"
            mermaid_code += " %% - Diamond: Decision Point (Branching)\n"
        mermaid_code += " %% - Colors: Blue (Function), Green (Structured LLM), Purple (LLM), Pink (Template), Orange (Sub-Workflow), Grey (Unknown)\n"
        mermaid_code += " %% - Dashed Border: Convergence Node\n"

    if diagram_type == "flowchart":
        # Flowchart-specific: Generate node definitions with shapes
        node_defs: List[str] = []
        for node in all_nodes:
            node_def_flow: Optional[NodeDefinition] = workflow_def.nodes.get(node)
            has_conditions = node in conditional_nodes
            label, node_type, shape = get_node_label_and_type(node, node_def_flow, has_conditions)
            start_shape, end_shape = shape_syntax[shape]
            node_defs.append(f'{node}{start_shape}"{label}"{end_shape}')
            node_types[node] = node_type
            node_shapes[node] = shape

        # Add node definitions
        for node_def_str in sorted(node_defs):  # Sort for consistent output
            mermaid_code += f" {node_def_str}\n"

        # Generate arrows for transitions
        for trans in workflow_def.workflow.transitions:
            from_node = trans.from_node
            if isinstance(trans.to_node, str):
                to_node = trans.to_node
                condition = trans.condition
                if condition:
                    # Truncate long conditions so edge labels stay readable.
                    cond = condition.replace('"', '\\"')[:30] + ("..." if len(condition) > 30 else "")
                    mermaid_code += f' {from_node} -->|"{cond}"| {to_node}\n'
                else:
                    mermaid_code += f' {from_node} --> {to_node}\n'
            else:
                for tn in trans.to_node:
                    if isinstance(tn, str):
                        # Parallel transition (no condition)
                        mermaid_code += f' {from_node} --> {tn}\n'
                    else:  # BranchCondition
                        to_node = tn.to_node
                        condition = tn.condition
                        if condition:
                            cond = condition.replace('"', '\\"')[:30] + ("..." if len(condition) > 30 else "")
                            mermaid_code += f' {from_node} -->|"{cond}"| {to_node}\n'
                        else:
                            mermaid_code += f' {from_node} --> {to_node}\n'

        # Add styles for node types
        for node, node_type in node_types.items():
            if node_type in node_styles:
                style = node_styles[node_type]
                # Add dashed stroke for convergence nodes
                if node in convergence_nodes:
                    style += ",stroke-dasharray:5 5"
                mermaid_code += f" style {node} {style}\n"

        # Highlight the start node with a thicker border
        if workflow_def.workflow.start and workflow_def.workflow.start in node_types:
            mermaid_code += f" style {workflow_def.workflow.start} stroke-width:4px\n"

        # Optional: Subgraphs for sub-workflows
        if include_subgraphs:
            for node, node_def_entry in workflow_def.nodes.items():
                if node_def_entry and node_def_entry.sub_workflow:
                    mermaid_code += f" subgraph {node}_sub[Sub-Workflow: {node}]\n"
                    sub_nodes: Set[str] = {node_def_entry.sub_workflow.start} if node_def_entry.sub_workflow.start else set()
                    for trans in node_def_entry.sub_workflow.transitions:
                        sub_nodes.add(trans.from_node)
                        if isinstance(trans.to_node, str):
                            sub_nodes.add(trans.to_node)
                        else:
                            for tn in trans.to_node:
                                target = tn if isinstance(tn, str) else tn.to_node
                                sub_nodes.add(target)
                    for sub_node in sorted(sub_nodes):  # Sort for consistency
                        mermaid_code += f" {sub_node}[[{sub_node}]]\n"
                    mermaid_code += " end\n"

    else:  # stateDiagram
        # StateDiagram-specific: Define states
        for node in all_nodes:
            node_def_state: Optional[NodeDefinition] = workflow_def.nodes.get(node)
            has_conditions = node in conditional_nodes
            label, node_type, _ = get_node_label_and_type(node, node_def_state, has_conditions)  # Shape unused
            # Append (Convergence) to label for convergence nodes
            if node in convergence_nodes:
                label += " (Convergence)"
            mermaid_code += f" state \"{label}\" as {node}\n"
            node_types[node] = node_type

        # Start state
        if workflow_def.workflow.start:
            mermaid_code += f" [*] --> {workflow_def.workflow.start}\n"

        # Transitions
        for trans in workflow_def.workflow.transitions:
            from_node = trans.from_node
            if isinstance(trans.to_node, str):
                to_node = trans.to_node
                condition = trans.condition
                if condition:
                    cond = condition.replace('"', '\\"')[:30] + ("..." if len(condition) > 30 else "")
                    mermaid_code += f" {from_node} --> {to_node} : {cond}\n"
                else:
                    mermaid_code += f" {from_node} --> {to_node}\n"
            else:
                for tn in trans.to_node:
                    if isinstance(tn, str):
                        # Parallel transition approximated with a note
                        mermaid_code += f" {from_node} --> {tn} : parallel\n"
                    else:  # BranchCondition
                        to_node = tn.to_node
                        condition = tn.condition
                        if condition:
                            cond = condition.replace('"', '\\"')[:30] + ("..." if len(condition) > 30 else "")
                            mermaid_code += f" {from_node} --> {to_node} : {cond}\n"
                        else:
                            mermaid_code += f" {from_node} --> {to_node}\n"

        # Add styles for node types
        for node, node_type in node_types.items():
            if node_type in node_styles:
                style = node_styles[node_type]
                # Add dashed stroke for convergence nodes
                if node in convergence_nodes:
                    style += ",stroke-dasharray:5 5"
                mermaid_code += f" style {node} {style}\n"

    mermaid_code += "```\n"
    return mermaid_code
284
-
285
-
286
def main() -> None:
    """Create a complex workflow with branch, converge, template node, and input mapping, and print its Mermaid diagram."""
    wm = WorkflowManager()

    # Register the embedded helper functions (same definitions, same order).
    for fn_name, fn_code in (
        ("say_hello", "def say_hello():\n return 'Hello, World!'"),
        ("check_condition", "def check_condition(text: str):\n return 'yes' if 'Hello' in text else 'no'"),
        ("say_goodbye", "def say_goodbye():\n return 'Goodbye, World!'"),
        ("finalize", "def finalize(text: str):\n return 'Done'"),
    ):
        wm.add_function(name=fn_name, type_="embedded", code=fn_code)

    # Function-backed nodes.
    wm.add_node(name="start", function="say_hello", output="text")
    wm.add_node(
        name="check",
        function="check_condition",
        output="result",
        inputs_mapping={"text": "text"},
    )
    wm.add_node(name="goodbye", function="say_goodbye", output="farewell")
    wm.add_node(
        name="finalize",
        function="finalize",
        output="status",
        inputs_mapping={"text": "lambda ctx: ctx['farewell'] if ctx['result'] == 'no' else ctx['ai_result']"},
    )

    # LLM-backed node.
    wm.add_node(
        name="ai_node",
        llm_config={
            "model": "gpt-3.5-turbo",
            "prompt_template": "{{text}}",
            "temperature": 0.7
        },
        output="ai_result"
    )

    # Template-backed node.
    wm.add_node(
        name="template_node",
        template_config={
            "template": "Response: {{text}} - {{result}}"
        },
        output="template_output",
        inputs_mapping={"text": "text", "result": "result"}
    )

    # Wire up the graph: start -> check, then branch on the check result,
    # converge on finalize, and finish with the template node.
    wm.set_start_node("start")
    wm.add_transition(from_node="start", to_node="check")
    wm.add_transition(
        from_node="check",
        to_node=[
            BranchCondition(to_node="ai_node", condition="ctx['result'] == 'yes'"),
            BranchCondition(to_node="goodbye", condition="ctx['result'] == 'no'")
        ]
    )
    wm.add_transition(from_node="ai_node", to_node="finalize")
    wm.add_transition(from_node="goodbye", to_node="finalize")
    wm.add_transition(from_node="finalize", to_node="template_node")
    wm.add_convergence_node("finalize")

    # Render the same workflow in both diagram flavors.
    print("Flowchart (default):")
    print(generate_mermaid_diagram(wm.workflow, include_subgraphs=False, title="Sample Workflow with Template and Mapping"))
    print("\nState Diagram:")
    print(generate_mermaid_diagram(wm.workflow, diagram_type="stateDiagram", title="Sample Workflow with Template and Mapping"))


if __name__ == "__main__":
    main()