quantalogic 0.33.4__py3-none-any.whl → 0.40.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- quantalogic/__init__.py +0 -4
- quantalogic/agent.py +603 -362
- quantalogic/agent_config.py +260 -28
- quantalogic/agent_factory.py +43 -17
- quantalogic/coding_agent.py +20 -12
- quantalogic/config.py +7 -4
- quantalogic/console_print_events.py +4 -8
- quantalogic/console_print_token.py +2 -2
- quantalogic/docs_cli.py +15 -10
- quantalogic/event_emitter.py +258 -83
- quantalogic/flow/__init__.py +23 -0
- quantalogic/flow/flow.py +595 -0
- quantalogic/flow/flow_extractor.py +672 -0
- quantalogic/flow/flow_generator.py +89 -0
- quantalogic/flow/flow_manager.py +407 -0
- quantalogic/flow/flow_manager_schema.py +169 -0
- quantalogic/flow/flow_yaml.md +419 -0
- quantalogic/generative_model.py +109 -77
- quantalogic/get_model_info.py +6 -6
- quantalogic/interactive_text_editor.py +100 -73
- quantalogic/main.py +36 -23
- quantalogic/model_info_list.py +12 -0
- quantalogic/model_info_litellm.py +14 -14
- quantalogic/prompts.py +2 -1
- quantalogic/{llm.py → quantlitellm.py} +29 -39
- quantalogic/search_agent.py +4 -4
- quantalogic/server/models.py +4 -1
- quantalogic/task_file_reader.py +5 -5
- quantalogic/task_runner.py +21 -20
- quantalogic/tool_manager.py +10 -21
- quantalogic/tools/__init__.py +98 -68
- quantalogic/tools/composio/composio.py +416 -0
- quantalogic/tools/{generate_database_report_tool.py → database/generate_database_report_tool.py} +4 -9
- quantalogic/tools/database/sql_query_tool_advanced.py +261 -0
- quantalogic/tools/document_tools/markdown_to_docx_tool.py +620 -0
- quantalogic/tools/document_tools/markdown_to_epub_tool.py +438 -0
- quantalogic/tools/document_tools/markdown_to_html_tool.py +362 -0
- quantalogic/tools/document_tools/markdown_to_ipynb_tool.py +319 -0
- quantalogic/tools/document_tools/markdown_to_latex_tool.py +420 -0
- quantalogic/tools/document_tools/markdown_to_pdf_tool.py +623 -0
- quantalogic/tools/document_tools/markdown_to_pptx_tool.py +319 -0
- quantalogic/tools/duckduckgo_search_tool.py +2 -4
- quantalogic/tools/finance/alpha_vantage_tool.py +440 -0
- quantalogic/tools/finance/ccxt_tool.py +373 -0
- quantalogic/tools/finance/finance_llm_tool.py +387 -0
- quantalogic/tools/finance/google_finance.py +192 -0
- quantalogic/tools/finance/market_intelligence_tool.py +520 -0
- quantalogic/tools/finance/technical_analysis_tool.py +491 -0
- quantalogic/tools/finance/tradingview_tool.py +336 -0
- quantalogic/tools/finance/yahoo_finance.py +236 -0
- quantalogic/tools/git/bitbucket_clone_repo_tool.py +181 -0
- quantalogic/tools/git/bitbucket_operations_tool.py +326 -0
- quantalogic/tools/git/clone_repo_tool.py +189 -0
- quantalogic/tools/git/git_operations_tool.py +532 -0
- quantalogic/tools/google_packages/google_news_tool.py +480 -0
- quantalogic/tools/grep_app_tool.py +123 -186
- quantalogic/tools/{dalle_e.py → image_generation/dalle_e.py} +37 -27
- quantalogic/tools/jinja_tool.py +6 -10
- quantalogic/tools/language_handlers/__init__.py +22 -9
- quantalogic/tools/list_directory_tool.py +131 -42
- quantalogic/tools/llm_tool.py +45 -15
- quantalogic/tools/llm_vision_tool.py +59 -7
- quantalogic/tools/markitdown_tool.py +17 -5
- quantalogic/tools/nasa_packages/models.py +47 -0
- quantalogic/tools/nasa_packages/nasa_apod_tool.py +232 -0
- quantalogic/tools/nasa_packages/nasa_neows_tool.py +147 -0
- quantalogic/tools/nasa_packages/services.py +82 -0
- quantalogic/tools/presentation_tools/presentation_llm_tool.py +396 -0
- quantalogic/tools/product_hunt/product_hunt_tool.py +258 -0
- quantalogic/tools/product_hunt/services.py +63 -0
- quantalogic/tools/rag_tool/__init__.py +48 -0
- quantalogic/tools/rag_tool/document_metadata.py +15 -0
- quantalogic/tools/rag_tool/query_response.py +20 -0
- quantalogic/tools/rag_tool/rag_tool.py +566 -0
- quantalogic/tools/rag_tool/rag_tool_beta.py +264 -0
- quantalogic/tools/read_html_tool.py +24 -38
- quantalogic/tools/replace_in_file_tool.py +10 -10
- quantalogic/tools/safe_python_interpreter_tool.py +10 -24
- quantalogic/tools/search_definition_names.py +2 -2
- quantalogic/tools/sequence_tool.py +14 -23
- quantalogic/tools/sql_query_tool.py +17 -19
- quantalogic/tools/tool.py +39 -15
- quantalogic/tools/unified_diff_tool.py +1 -1
- quantalogic/tools/utilities/csv_processor_tool.py +234 -0
- quantalogic/tools/utilities/download_file_tool.py +179 -0
- quantalogic/tools/utilities/mermaid_validator_tool.py +661 -0
- quantalogic/tools/utils/__init__.py +1 -4
- quantalogic/tools/utils/create_sample_database.py +24 -38
- quantalogic/tools/utils/generate_database_report.py +74 -82
- quantalogic/tools/wikipedia_search_tool.py +17 -21
- quantalogic/utils/ask_user_validation.py +1 -1
- quantalogic/utils/async_utils.py +35 -0
- quantalogic/utils/check_version.py +3 -5
- quantalogic/utils/get_all_models.py +2 -1
- quantalogic/utils/git_ls.py +21 -7
- quantalogic/utils/lm_studio_model_info.py +9 -7
- quantalogic/utils/python_interpreter.py +113 -43
- quantalogic/utils/xml_utility.py +178 -0
- quantalogic/version_check.py +1 -1
- quantalogic/welcome_message.py +7 -7
- quantalogic/xml_parser.py +0 -1
- {quantalogic-0.33.4.dist-info → quantalogic-0.40.0.dist-info}/METADATA +44 -1
- quantalogic-0.40.0.dist-info/RECORD +148 -0
- quantalogic-0.33.4.dist-info/RECORD +0 -102
- {quantalogic-0.33.4.dist-info → quantalogic-0.40.0.dist-info}/LICENSE +0 -0
- {quantalogic-0.33.4.dist-info → quantalogic-0.40.0.dist-info}/WHEEL +0 -0
- {quantalogic-0.33.4.dist-info → quantalogic-0.40.0.dist-info}/entry_points.txt +0 -0
quantalogic/flow/flow.py
ADDED
@@ -0,0 +1,595 @@
|
|
1
|
+
#!/usr/bin/env -S uv run
|
2
|
+
# /// script
|
3
|
+
# requires-python = ">=3.12"
|
4
|
+
# dependencies = [
|
5
|
+
# "loguru",
|
6
|
+
# "litellm",
|
7
|
+
# "pydantic>=2.0",
|
8
|
+
# "anyio",
|
9
|
+
# "jinja2",
|
10
|
+
# "instructor[litellm]" # Required for structured_llm_node
|
11
|
+
# ]
|
12
|
+
# ///
|
13
|
+
|
14
|
+
import asyncio
import functools
import inspect
from dataclasses import dataclass
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Tuple, Type

import instructor
from jinja2 import Template
from litellm import acompletion
from loguru import logger
from pydantic import BaseModel, ValidationError
|
24
|
+
|
25
|
+
|
26
|
+
# Event taxonomy for the observer system: every WorkflowEvent carries exactly one of these.
class WorkflowEventType(Enum):
    """Kinds of notifications that ``WorkflowEngine`` sends to registered observers."""

    # Per-node lifecycle
    NODE_STARTED = "node_started"
    NODE_COMPLETED = "node_completed"
    NODE_FAILED = "node_failed"

    # Sent once for each outgoing transition considered after a node finishes
    TRANSITION_EVALUATED = "transition_evaluated"

    # Whole-run lifecycle
    WORKFLOW_STARTED = "workflow_started"
    WORKFLOW_COMPLETED = "workflow_completed"

    # Bracket the execution of a SubWorkflowNode
    SUB_WORKFLOW_ENTERED = "sub_workflow_entered"
    SUB_WORKFLOW_EXITED = "sub_workflow_exited"
|
36
|
+
|
37
|
+
|
38
|
+
@dataclass
class WorkflowEvent:
    """Payload handed to every observer callback.

    Which optional fields are filled depends on ``event_type``: ``result`` and ``usage`` on
    NODE_COMPLETED, ``exception`` on NODE_FAILED, ``transition_from``/``transition_to`` on
    TRANSITION_EVALUATED, and ``sub_workflow_name`` (plus ``result``/``exception`` on exit)
    on the SUB_WORKFLOW_* events.
    """

    # Required fields (positional order matters for callers constructing events)
    event_type: WorkflowEventType  # what happened
    node_name: Optional[str]  # node concerned; None for workflow-level and transition events
    context: Dict[str, Any]  # the engine's context dict itself (the engine passes it, not a copy)

    # Optional details
    result: Optional[Any] = None  # value produced by the node / sub-workflow
    exception: Optional[Exception] = None  # error raised by the node / sub-workflow
    transition_from: Optional[str] = None  # source node of an evaluated transition
    transition_to: Optional[str] = None  # target node of an evaluated transition
    sub_workflow_name: Optional[str] = None  # node name of the sub-workflow being entered/exited
    usage: Optional[Dict[str, Any]] = None  # token counts / cost reported by LLM nodes


# Observer callback signature; the engine also awaits observers written as coroutine functions.
WorkflowObserver = Callable[[WorkflowEvent], None]
|
52
|
+
|
53
|
+
|
54
|
+
# Sub-workflow support: lets a whole Workflow run as a single node of another Workflow.
class SubWorkflowNode:
    """Node wrapper that executes a nested workflow.

    ``inputs`` maps parent-context keys to the keys the sub-workflow expects in its own
    context; ``output`` names the sub-workflow context key whose value is returned.
    """

    def __init__(self, sub_workflow: "Workflow", inputs: Dict[str, str], output: str):
        """Remember the nested workflow, its input mapping and the output key to return."""
        self.sub_workflow = sub_workflow
        self.inputs = inputs
        self.output = output

    async def __call__(self, engine: "WorkflowEngine", **kwargs):
        """Run the nested workflow seeded from ``kwargs`` and return its ``output`` value.

        Args:
            engine: The parent engine, passed to ``build`` as ``parent_engine``.
            **kwargs: Parent-context values keyed by the parent-side names in ``inputs``.

        Returns:
            The sub-workflow's final context value under ``self.output`` (None if absent).

        Raises:
            KeyError: if a parent-side key listed in ``inputs`` is missing from ``kwargs``.
        """
        seeded_context = {}
        for parent_key, child_key in self.inputs.items():
            seeded_context[child_key] = kwargs[parent_key]
        child_engine = self.sub_workflow.build(parent_engine=engine)
        final_context = await child_engine.run(seeded_context)
        return final_context.get(self.output)
|
68
|
+
|
69
|
+
|
70
|
+
class WorkflowEngine:
    """Execute a ``Workflow`` node by node, emitting ``WorkflowEvent``s to observers.

    Nodes run one at a time. After each node, its outgoing transitions are evaluated in
    insertion order and the first one whose condition is ``None`` or returns a truthy value
    for the current context is followed; the run ends when no transition matches.
    """

    def __init__(self, workflow: "Workflow", parent_engine: Optional["WorkflowEngine"] = None):
        """Initialize the engine.

        Args:
            workflow: Workflow whose nodes, inputs, outputs and transitions are executed.
            parent_engine: Engine of the enclosing workflow when this engine runs a
                sub-workflow. Observers added here are also registered on the parent.
        """
        self.workflow = workflow
        self.context: Dict[str, Any] = {}
        self.observers: List[WorkflowObserver] = []
        self.parent_engine = parent_engine  # Link to parent engine for sub-workflow observer propagation

    def add_observer(self, observer: WorkflowObserver) -> None:
        """Register an event observer callback (no-op if already registered on this engine).

        The observer is also forwarded to ``parent_engine``, if any.
        """
        if observer not in self.observers:
            self.observers.append(observer)
            logger.debug(f"Added observer: {observer}")
        if self.parent_engine:
            self.parent_engine.add_observer(observer)  # Propagate to parent for global visibility

    def remove_observer(self, observer: WorkflowObserver) -> None:
        """Remove an event observer callback from this engine (not from the parent)."""
        if observer in self.observers:
            self.observers.remove(observer)
            logger.debug(f"Removed observer: {observer}")

    async def _notify_observers(self, event: WorkflowEvent) -> None:
        """Deliver ``event`` to every observer.

        Observers may be plain callables or return an awaitable (``async def`` functions or
        objects whose ``__call__`` is ``async def``); awaitables are awaited concurrently.
        A failing observer, sync or async, is logged and never aborts the workflow.
        """
        pending = []
        for observer in self.observers:
            try:
                outcome = observer(event)
            except Exception as e:
                logger.error(f"Observer {observer} failed for {event.event_type.value}: {e}")
                continue
            # Checking the returned value (not the callable) also catches async __call__ objects.
            if inspect.isawaitable(outcome):
                pending.append((observer, outcome))

        if not pending:
            return
        outcomes = await asyncio.gather(*(awaitable for _, awaitable in pending), return_exceptions=True)
        for (observer, _), outcome in zip(pending, outcomes):
            if isinstance(outcome, Exception):
                logger.error(f"Observer {observer} failed for {event.event_type.value}: {outcome}")
            elif isinstance(outcome, BaseException):
                # Do not swallow cancellation / interpreter-exit signals.
                raise outcome

    async def run(self, initial_context: Dict[str, Any]) -> Dict[str, Any]:
        """Execute the workflow from its start node, emitting events along the way.

        Args:
            initial_context: Starting context; it is shallow-copied, not mutated.

        Returns:
            The final context dict, including every node output stored during the run.

        Raises:
            Exception: Whatever a node raises (after a ``NODE_FAILED`` event is emitted).
        """
        self.context = initial_context.copy()
        await self._notify_observers(
            WorkflowEvent(event_type=WorkflowEventType.WORKFLOW_STARTED, node_name=None, context=self.context)
        )

        current_node = self.workflow.start_node
        while current_node:
            logger.info(f"Executing node: {current_node}")
            await self._notify_observers(
                WorkflowEvent(event_type=WorkflowEventType.NODE_STARTED, node_name=current_node, context=self.context)
            )

            node_func = self.workflow.nodes.get(current_node)
            if not node_func:
                logger.error(f"Node {current_node} not found")
                exc = ValueError(f"Node {current_node} not found")
                await self._notify_observers(
                    WorkflowEvent(
                        event_type=WorkflowEventType.NODE_FAILED,
                        node_name=current_node,
                        context=self.context,
                        exception=exc,
                    )
                )
                break

            await self._execute_node(current_node, node_func)
            current_node = await self._select_next_node(current_node)

        logger.info("Workflow execution completed")
        await self._notify_observers(
            WorkflowEvent(event_type=WorkflowEventType.WORKFLOW_COMPLETED, node_name=None, context=self.context)
        )
        return self.context

    async def _execute_node(self, node_name: str, node_func: Callable) -> None:
        """Run one node, store its result under the node's output key and emit its events.

        Only context keys that are present are passed as inputs. Sub-workflow nodes are
        bracketed by SUB_WORKFLOW_ENTERED / SUB_WORKFLOW_EXITED events (the latter is emitted
        even on failure).

        Raises:
            Exception: Whatever the node raises, re-raised after a ``NODE_FAILED`` event.
        """
        inputs = {k: self.context[k] for k in self.workflow.node_inputs[node_name] if k in self.context}
        is_sub_workflow = isinstance(node_func, SubWorkflowNode)
        result = None
        exception = None

        if is_sub_workflow:
            await self._notify_observers(
                WorkflowEvent(
                    event_type=WorkflowEventType.SUB_WORKFLOW_ENTERED,
                    node_name=node_name,
                    context=self.context,
                    sub_workflow_name=node_name,
                )
            )

        try:
            if is_sub_workflow:
                result = await node_func(self, **inputs)
                usage = None  # Sub-workflow usage is reported by the sub-workflow's own node events
            else:
                result = await node_func(**inputs)
                usage = getattr(node_func, "usage", None)  # Set by LLM nodes after each call
            output_key = self.workflow.node_outputs[node_name]
            if output_key:
                self.context[output_key] = result
            await self._notify_observers(
                WorkflowEvent(
                    event_type=WorkflowEventType.NODE_COMPLETED,
                    node_name=node_name,
                    context=self.context,
                    result=result,
                    usage=usage,
                )
            )
        except Exception as e:
            logger.error(f"Error executing node {node_name}: {e}")
            exception = e
            await self._notify_observers(
                WorkflowEvent(
                    event_type=WorkflowEventType.NODE_FAILED,
                    node_name=node_name,
                    context=self.context,
                    exception=e,
                )
            )
            raise
        finally:
            if is_sub_workflow:
                await self._notify_observers(
                    WorkflowEvent(
                        event_type=WorkflowEventType.SUB_WORKFLOW_EXITED,
                        node_name=node_name,
                        context=self.context,
                        sub_workflow_name=node_name,
                        result=result,
                        exception=exception,
                    )
                )

    async def _select_next_node(self, from_node: str) -> Optional[str]:
        """Return the target of ``from_node``'s first matching transition, or ``None``.

        A TRANSITION_EVALUATED event is emitted for each transition before its condition is
        checked, so observers also see transitions that end up not being taken.
        """
        for next_node, condition in self.workflow.transitions.get(from_node, []):
            await self._notify_observers(
                WorkflowEvent(
                    event_type=WorkflowEventType.TRANSITION_EVALUATED,
                    node_name=None,
                    context=self.context,
                    transition_from=from_node,  # the node that just finished executing
                    transition_to=next_node,
                )
            )
            if condition is None or condition(self.context):
                return next_node
        return None
|
214
|
+
|
215
|
+
|
216
|
+
class Workflow:
    """Fluent builder for a graph of named nodes and the transitions between them.

    Regular nodes are looked up by name in ``Nodes.NODE_REGISTRY``; sub-workflows are added
    with ``add_sub_workflow``. Builder methods return ``self`` for chaining and maintain a
    *current node*, which is the source of the transition added by the next ``then`` call.
    """

    def __init__(self, start_node: str):
        """Initialize a workflow whose execution begins at ``start_node``.

        Raises:
            ValueError: if ``start_node`` is not in ``Nodes.NODE_REGISTRY``.
        """
        self.start_node = start_node
        self.nodes: Dict[str, Callable] = {}
        self.node_inputs: Dict[str, List[str]] = {}
        self.node_outputs: Dict[str, Optional[str]] = {}
        # source node -> ordered (target node, optional condition(context) -> truthy) pairs
        self.transitions: Dict[str, List[Tuple[str, Optional[Callable]]]] = {}
        self.current_node = None
        self._observers: List[WorkflowObserver] = []  # Registered on every engine built by build()
        self._register_node(start_node)  # Register the start node without setting current_node
        self.current_node = start_node  # Set current_node explicitly after registration

    def _register_node(self, name: str):
        """Copy a node's function, inputs and output from the registry without touching current_node.

        Raises:
            ValueError: if ``name`` is not in ``Nodes.NODE_REGISTRY``.
        """
        if name not in Nodes.NODE_REGISTRY:
            raise ValueError(f"Node {name} not registered")
        func, inputs, output = Nodes.NODE_REGISTRY[name]
        self.nodes[name] = func
        self.node_inputs[name] = inputs
        self.node_outputs[name] = output

    def node(self, name: str):
        """Register node ``name`` and make it the current node (no transition is added)."""
        self._register_node(name)
        self.current_node = name
        return self

    def sequence(self, *nodes: str):
        """Register ``nodes`` and chain them with unconditional transitions, in the given order.

        No transition is added from the previous current node to the first node; the last
        node becomes the current node. An empty call is a no-op.

        Raises:
            ValueError: if any node is not in ``Nodes.NODE_REGISTRY``.
        """
        if not nodes:
            return self
        for name in nodes:
            self._register_node(name)
        for source, target in zip(nodes, nodes[1:]):
            self.transitions.setdefault(source, []).append((target, None))
        self.current_node = nodes[-1]
        return self

    def then(self, next_node: str, condition: Optional[Callable] = None):
        """Add a transition from the current node to ``next_node`` and make it current.

        Args:
            next_node: Target node; registered from ``Nodes.NODE_REGISTRY`` if not yet known.
            condition: Optional callable receiving the context; the transition is taken only
                if it returns a truthy value. ``None`` means unconditional.
        """
        if next_node not in self.nodes:
            self._register_node(next_node)  # Register without changing current_node
        if self.current_node:
            self.transitions.setdefault(self.current_node, []).append((next_node, condition))
            logger.debug(f"Added transition from {self.current_node} to {next_node} with condition {condition}")
        else:
            logger.warning("No current node set for transition")
        self.current_node = next_node
        return self

    def parallel(self, *nodes: str):
        """Add unconditional transitions from the current node to each of ``nodes``.

        Note: ``WorkflowEngine`` follows only the first matching transition, so these nodes
        are NOT executed concurrently -- only the first one listed runs. The nodes are not
        registered here; register them (e.g. with ``node``) before running. The current node
        is reset to ``None`` afterwards, so the next ``then`` adds no transition.
        """
        if self.current_node:
            for name in nodes:
                self.transitions.setdefault(self.current_node, []).append((name, None))
        self.current_node = None  # Reset after parallel to force explicit next node
        return self

    def add_observer(self, observer: WorkflowObserver) -> "Workflow":
        """Add an event observer to be registered on every engine created by ``build``."""
        if observer not in self._observers:
            self._observers.append(observer)
            logger.debug(f"Added observer to workflow: {observer}")
        return self  # Support chaining

    def add_sub_workflow(self, name: str, sub_workflow: "Workflow", inputs: Dict[str, str], output: str):
        """Add ``sub_workflow`` as node ``name`` and make it the current node.

        No transition into the new node is added. ``inputs`` maps parent-context keys to the
        sub-workflow's keys; ``output`` is both the sub-workflow key whose value is returned
        and the parent-context key that value is stored under.
        """
        sub_node = SubWorkflowNode(sub_workflow, inputs, output)
        self.nodes[name] = sub_node
        self.node_inputs[name] = list(inputs.keys())
        self.node_outputs[name] = output
        self.current_node = name
        return self

    def build(self, parent_engine: Optional["WorkflowEngine"] = None) -> WorkflowEngine:
        """Create a ``WorkflowEngine`` for this workflow with this workflow's observers registered."""
        engine = WorkflowEngine(self, parent_engine=parent_engine)
        for observer in self._observers:
            engine.add_observer(observer)
        return engine
|
302
|
+
|
303
|
+
|
304
|
+
class Nodes:
    """Decorators that turn async functions into registered workflow nodes.

    Each decorator stores ``(wrapped_func, input_names, output)`` in ``NODE_REGISTRY`` under
    the decorated function's ``__name__``; ``Workflow`` looks nodes up there by name.
    """

    NODE_REGISTRY = {}  # node name -> (wrapped async callable, input names, output context key)

    @staticmethod
    def _input_names(func: Callable) -> List[str]:
        """Return the names of ``func``'s parameters that can be supplied by keyword.

        Uses the signature rather than ``__annotations__`` so unannotated parameters are
        included and the ``"return"`` annotation is never mistaken for an input.
        """
        keyword_kinds = (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
        return [name for name, param in inspect.signature(func).parameters.items() if param.kind in keyword_kinds]

    @classmethod
    def _register(cls, func: Callable, wrapped_func: Callable, output: Optional[str]) -> Callable:
        """Record ``wrapped_func`` in the registry under ``func.__name__`` and return it."""
        inputs = cls._input_names(func)
        logger.debug(f"Registering node {func.__name__} with inputs {inputs} and output {output}")
        cls.NODE_REGISTRY[func.__name__] = (wrapped_func, inputs, output)
        return wrapped_func

    @staticmethod
    def _extract_usage(response: Any) -> Dict[str, Any]:
        """Build the usage dict attached to LLM nodes from a LiteLLM completion response.

        Missing token counts are reported as 0 instead of failing a node whose LLM call
        succeeded. ``cost`` is taken from a ``cost`` attribute if present, otherwise from the
        response's ``_hidden_params["response_cost"]`` (where LiteLLM records its computed
        cost), else ``None``.
        """
        usage = getattr(response, "usage", None)
        cost = getattr(response, "cost", None)
        if cost is None:
            hidden = getattr(response, "_hidden_params", None)
            if isinstance(hidden, dict):
                cost = hidden.get("response_cost")
        return {
            "prompt_tokens": getattr(usage, "prompt_tokens", 0) or 0,
            "completion_tokens": getattr(usage, "completion_tokens", 0) or 0,
            "total_tokens": getattr(usage, "total_tokens", 0) or 0,
            "cost": cost,
        }

    @classmethod
    def define(cls, output: Optional[str] = None):
        """Decorator for defining simple workflow nodes.

        Args:
            output: Context key under which the engine stores the node's return value, or
                ``None`` to discard it.
        """

        def decorator(func: Callable) -> Callable:
            @functools.wraps(func)
            async def wrapped_func(**kwargs):
                try:
                    result = await func(**kwargs)
                    logger.debug(f"Node {func.__name__} executed with result: {result}")
                    return result
                except Exception as e:
                    logger.error(f"Error in node {func.__name__}: {e}")
                    raise

            return cls._register(func, wrapped_func, output)

        return decorator

    @classmethod
    def validate_node(cls, output: str):
        """Decorator for validation nodes, which must return a string.

        Args:
            output: Context key under which the validation message is stored.

        The wrapped node raises ``ValueError`` if the decorated function returns a non-string.
        """

        def decorator(func: Callable) -> Callable:
            @functools.wraps(func)
            async def wrapped_func(**kwargs):
                try:
                    result = await func(**kwargs)
                    if not isinstance(result, str):
                        raise ValueError(f"Validation node {func.__name__} must return a string")
                    logger.info(f"Validation result from {func.__name__}: {result}")
                    return result
                except Exception as e:
                    logger.error(f"Validation error in {func.__name__}: {e}")
                    raise

            return cls._register(func, wrapped_func, output)

        return decorator

    @classmethod
    def llm_node(
        cls,
        model: str,
        system_prompt: str,
        prompt_template: str,
        output: str,
        temperature: float = 0.7,
        max_tokens: int = 2000,
        top_p: float = 1.0,
        presence_penalty: float = 0.0,
        frequency_penalty: float = 0.0,
        **kwargs,
    ):
        """Decorator for creating LLM nodes with plain text output.

        The decorated function's body is never called; its parameters only declare which
        context values are rendered into ``prompt_template`` (a Jinja2 template).

        Args:
            model: LiteLLM model identifier.
            system_prompt: System message sent with every call.
            prompt_template: Jinja2 template for the user message.
            output: Context key under which the stripped response text is stored.
            temperature, max_tokens, top_p, presence_penalty, frequency_penalty: Sampling
                parameters passed to ``litellm.acompletion``.
            **kwargs: Extra keyword arguments forwarded verbatim to ``litellm.acompletion``.

        After each successful call the wrapped node exposes a ``usage`` dict attribute
        (prompt/completion/total tokens and cost), which the engine attaches to events.
        """
        llm_kwargs = kwargs  # kept distinct from the node inputs received at call time

        def decorator(func: Callable) -> Callable:
            @functools.wraps(func)
            async def wrapped_func(**node_inputs):
                prompt = cls._render_prompt(prompt_template, node_inputs)
                messages = [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": prompt},
                ]
                try:
                    response = await acompletion(
                        model=model,
                        messages=messages,
                        temperature=temperature,
                        max_tokens=max_tokens,
                        top_p=top_p,
                        presence_penalty=presence_penalty,
                        frequency_penalty=frequency_penalty,
                        drop_params=True,
                        **llm_kwargs,
                    )
                    content = response.choices[0].message.content.strip()
                    wrapped_func.usage = cls._extract_usage(response)
                    logger.debug(f"LLM output from {func.__name__}: {content[:50]}...")
                    return content
                except Exception as e:
                    logger.error(f"Error in LLM node {func.__name__}: {e}")
                    raise

            return cls._register(func, wrapped_func, output)

        return decorator

    @classmethod
    def structured_llm_node(
        cls,
        model: str,
        system_prompt: str,
        prompt_template: str,
        response_model: Type[BaseModel],
        output: str,
        temperature: float = 0.7,
        max_tokens: int = 2000,
        top_p: float = 1.0,
        presence_penalty: float = 0.0,
        frequency_penalty: float = 0.0,
        **kwargs,
    ):
        """Decorator for creating LLM nodes with structured output using instructor.

        Same contract as ``llm_node`` except that the response is parsed into
        ``response_model`` and that instance is the node's result.

        Raises:
            ImportError: if instructor cannot build its LiteLLM client.
        """
        try:
            client = instructor.from_litellm(acompletion)
        except ImportError:
            logger.error("Instructor not installed. Install with 'pip install instructor[litellm]'")
            raise ImportError("Instructor is required for structured_llm_node")
        llm_kwargs = kwargs  # kept distinct from the node inputs received at call time

        def decorator(func: Callable) -> Callable:
            @functools.wraps(func)
            async def wrapped_func(**node_inputs):
                prompt = cls._render_prompt(prompt_template, node_inputs)
                messages = [
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": prompt},
                ]
                try:
                    # create_with_completion returns the parsed model and the raw LiteLLM response
                    structured_response, raw_response = await client.chat.completions.create_with_completion(
                        model=model,
                        messages=messages,
                        response_model=response_model,
                        temperature=temperature,
                        max_tokens=max_tokens,
                        top_p=top_p,
                        presence_penalty=presence_penalty,
                        frequency_penalty=frequency_penalty,
                        drop_params=True,
                        **llm_kwargs,
                    )
                    wrapped_func.usage = cls._extract_usage(raw_response)
                    logger.debug(f"Structured output from {func.__name__}: {structured_response}")
                    return structured_response
                except ValidationError as e:
                    logger.error(f"Validation error in {func.__name__}: {e}")
                    raise
                except Exception as e:
                    logger.error(f"Error in structured LLM node {func.__name__}: {e}")
                    raise

            return cls._register(func, wrapped_func, output)

        return decorator

    @staticmethod
    def _render_prompt(template: str, context: Dict[str, Any]) -> str:
        """Render a Jinja2 template string with ``context`` as template variables."""
        try:
            return Template(template).render(**context)
        except Exception as e:
            logger.error(f"Error rendering prompt template: {e}")
            raise
|
481
|
+
|
482
|
+
|
483
|
+
# Example workflow with observer integration and updated structured node
async def example_workflow():
    """Run a demo order-processing workflow with branching, a sub-workflow and two observers.

    Graph built below::

        validate_order -> check_inventory
            -> payment_shipping (if in stock) -> update_order_status -> send_confirmation_email
            -> notify_customer_out_of_stock (otherwise)

    ``check_inventory`` is a structured LLM node using ``gemini/gemini-2.0-flash``; a run
    presumably needs credentials for that provider configured for LiteLLM.
    """

    # Define Pydantic model for structured output
    class OrderDetails(BaseModel):
        order_id: str
        items: List[str]
        in_stock: bool

    # Observer printing every event (async observers are awaited by the engine)
    async def progress_monitor(event: WorkflowEvent):
        print(f"[{event.event_type.value}] {event.node_name or 'Workflow'}")
        if event.result is not None:
            print(f"Result: {event.result}")
        if event.exception is not None:
            print(f"Exception: {event.exception}")

    # Observer accumulating the token usage reported on NODE_COMPLETED events
    class TokenUsageObserver:
        def __init__(self):
            self.total_prompt_tokens = 0
            self.total_completion_tokens = 0
            self.total_cost = 0.0
            self.node_usages = {}

        def __call__(self, event: WorkflowEvent):
            if event.event_type == WorkflowEventType.NODE_COMPLETED and event.usage:
                usage = event.usage
                # `or 0` guards against usage dicts that carry None counts
                self.total_prompt_tokens += usage.get("prompt_tokens") or 0
                self.total_completion_tokens += usage.get("completion_tokens") or 0
                if usage.get("cost") is not None:
                    self.total_cost += usage["cost"]
                self.node_usages[event.node_name] = usage
            # Print summary at workflow completion
            if event.event_type == WorkflowEventType.WORKFLOW_COMPLETED:
                print(f"Total prompt tokens: {self.total_prompt_tokens}")
                print(f"Total completion tokens: {self.total_completion_tokens}")
                print(f"Total cost: {self.total_cost}")
                for node, usage in self.node_usages.items():
                    print(f"Node {node}: {usage}")

    # Define nodes
    @Nodes.validate_node(output="validation_result")
    async def validate_order(order: Dict[str, Any]) -> str:
        return "Order validated" if order.get("items") else "Invalid order"

    @Nodes.structured_llm_node(
        model="gemini/gemini-2.0-flash",
        system_prompt="You are an inventory checker. Respond with a JSON object containing 'order_id', 'items', and 'in_stock' (boolean).",
        prompt_template="Check if the following items are in stock: {{ items }}. Return the result in JSON format with 'order_id' set to '123'.",
        response_model=OrderDetails,
        output="inventory_status",
    )
    async def check_inventory(items: List[str]) -> OrderDetails:
        pass

    @Nodes.define(output="payment_status")
    async def process_payment(order: Dict[str, Any]) -> str:
        return "Payment processed"

    @Nodes.define(output="shipping_confirmation")
    async def arrange_shipping(order: Dict[str, Any]) -> str:
        return "Shipping arranged"

    @Nodes.define(output="order_status")
    async def update_order_status(shipping_confirmation: str) -> str:
        return "Order updated"

    @Nodes.define(output="email_status")
    async def send_confirmation_email(shipping_confirmation: str) -> str:
        return "Email sent"

    @Nodes.define(output="notification_status")
    async def notify_customer_out_of_stock(inventory_status: OrderDetails) -> str:
        return "Customer notified of out-of-stock"

    def in_stock(ctx: Dict[str, Any]) -> bool:
        """True when check_inventory has stored a result whose ``in_stock`` flag is set."""
        return bool(getattr(ctx.get("inventory_status"), "in_stock", False))

    # Sub-workflow for payment and shipping
    payment_shipping_sub_wf = Workflow("process_payment").sequence("process_payment", "arrange_shipping")

    # Instantiate token usage observer
    token_observer = TokenUsageObserver()

    # Main workflow. `then` always links from the *current* node, so each branch is wired
    # from an explicitly selected source node.
    workflow = (
        Workflow("validate_order")
        .add_observer(progress_monitor)
        .add_observer(token_observer)
        # In-stock path: sub-workflow, then the post-shipping steps in order
        .add_sub_workflow(
            "payment_shipping", payment_shipping_sub_wf, inputs={"order": "order"}, output="shipping_confirmation"
        )
        .then("update_order_status")
        .then("send_confirmation_email")
        # Main chain and the two mutually exclusive branches out of check_inventory
        .sequence("validate_order", "check_inventory")
        .then("payment_shipping", condition=in_stock)
        .node("check_inventory")  # make check_inventory current again for the second branch
        .then("notify_customer_out_of_stock", condition=lambda ctx: not in_stock(ctx))
    )

    # Execute workflow
    initial_context = {"order": {"items": ["item1", "item2"]}, "items": ["item1", "item2"]}
    engine = workflow.build()
    result = await engine.run(initial_context)
    logger.info(f"Workflow result: {result}")
|
592
|
+
|
593
|
+
|
594
|
+
if __name__ == "__main__":
    # Script entry point: drive the async demo to completion on a new event loop.
    demo = example_workflow()
    asyncio.run(demo)
|