soprano-sdk 0.1.94__py3-none-any.whl → 0.1.96__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
soprano_sdk/tools.py ADDED
@@ -0,0 +1,219 @@
+ """
+ Workflow Tools - wraps workflows as callable tools for agent frameworks.
+ """
+
+ import uuid
+ from typing import Optional, Dict, Any
+
+ from langfuse.langchain import CallbackHandler
+
+ from .core.engine import load_workflow
+ from .utils.logger import logger
+
+
+ class WorkflowTool:
+     """Wraps a conversational workflow as a tool for agent orchestration.
+
+     This allows workflows to be used as tools in LangGraph, CrewAI, or other
+     agent frameworks. A supervisor agent can decide which workflow to invoke
+     based on user intent.
+     """
+
+     def __init__(
+         self,
+         yaml_path: str,
+         name: str,
+         description: str,
+         checkpointer=None,
+         config: Optional[Dict] = None
+     ):
+         """Initialize the workflow tool.
+
+         Args:
+             yaml_path: Path to the workflow YAML file
+             name: Tool name (used by agents to reference this tool)
+             description: Tool description (helps the agent decide when to use it)
+             checkpointer: Optional checkpointer for persistence
+             config: Optional configuration passed through to the workflow loader
+         """
+         self.yaml_path = yaml_path
+         self.name = name
+         self.description = description
+         self.checkpointer = checkpointer
+
+         # Load and compile the workflow graph
+         self.graph, self.engine = load_workflow(yaml_path, checkpointer=checkpointer, config=config)
+     def execute(
+         self,
+         thread_id: Optional[str] = None,
+         user_message: Optional[str] = None,
+         initial_context: Optional[Dict[str, Any]] = None
+     ) -> str:
+         """Execute the workflow with automatic state detection.
+
+         Checks workflow state and automatically resumes if interrupted,
+         or starts fresh if not started or already completed.
+
+         Args:
+             thread_id: Thread ID for state tracking
+             user_message: User's message (used to resume if the workflow is interrupted)
+             initial_context: Context to inject for fresh starts (e.g., {"order_id": "123"})
+
+         Returns:
+             Final outcome message or interrupt prompt
+         """
+         from langgraph.types import Command
+         from soprano_sdk.utils.tracing import trace_workflow_execution
+
+         if thread_id is None:
+             thread_id = str(uuid.uuid4())
+
+         with trace_workflow_execution(
+             workflow_name=self.engine.workflow_name,
+             thread_id=thread_id,
+             has_initial_context=initial_context is not None
+         ) as span:
+             callback_handler = CallbackHandler()
+             config = {"configurable": {"thread_id": thread_id}, "callbacks": [callback_handler]}
+
+             # Split the caller-supplied context: fields the engine collects as
+             # input go to the engine; everything else becomes a state update.
+             update_context = {}
+             engine_context_data = {}
+             for key, value in (initial_context or {}).items():
+                 if key in self.engine.collect_input_fields:
+                     engine_context_data[key] = value
+                     continue
+                 if value:
+                     update_context[key] = value
+
+             if engine_context_data:
+                 self.engine.update_context(engine_context_data)
+                 span.add_event("context.updated", {"fields": list(engine_context_data.keys())})
+
+             state = self.graph.get_state(config)
+
+             if state.next:
+                 # Workflow is interrupted and waiting for input
+                 span.set_attribute("workflow.resumed", True)
+                 logger.info(f"[WorkflowTool] Resuming interrupted workflow {self.name} (thread: {thread_id})")
+                 result = self.graph.invoke(
+                     Command(resume=user_message or "", update=update_context),
+                     config=config
+                 )
+             else:
+                 # Workflow is fresh or completed; start/restart it
+                 span.set_attribute("workflow.resumed", False)
+                 logger.info(f"[WorkflowTool] Starting fresh workflow {self.name} (thread: {thread_id})")
+                 result = self.graph.invoke(update_context, config=config)
+
+             # Clean up checkpointer state once the workflow has finished
+             final_state = self.graph.get_state(config)
+             if not final_state.next and self.checkpointer:
+                 self.checkpointer.delete_thread(thread_id)
+
+             # If the workflow needs user input, return structured interrupt data
+             if "__interrupt__" in result and result["__interrupt__"]:
+                 span.set_attribute("workflow.status", "interrupted")
+                 prompt = result["__interrupt__"][0].value
+                 return f"__WORKFLOW_INTERRUPT__|{thread_id}|{self.name}|{prompt}"
+
+             # Workflow completed without interrupting
+             span.set_attribute("workflow.status", "completed")
+             return self.engine.get_outcome_message(result)
+
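The sentinel returned above is a pipe-delimited string rather than a structured object. A minimal consumer-side sketch of how a supervisor might unpack it; the helper name is hypothetical, only the `__WORKFLOW_INTERRUPT__|thread_id|tool_name|prompt` layout comes from the code above:

```python
# Hypothetical helper: parses the interrupt sentinel emitted by
# WorkflowTool.execute()/resume(). Only the pipe-delimited layout
# "__WORKFLOW_INTERRUPT__|<thread_id>|<tool_name>|<prompt>" is from the SDK.
def parse_workflow_result(result: str) -> dict:
    if result.startswith("__WORKFLOW_INTERRUPT__|"):
        # maxsplit=3 keeps any "|" characters inside the prompt intact
        _, thread_id, tool_name, prompt = result.split("|", 3)
        return {"status": "interrupted", "thread_id": thread_id,
                "tool": tool_name, "prompt": prompt}
    return {"status": "completed", "message": result}
```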
+     def resume(self, thread_id: str, user_message: str) -> str:
+         """Resume an interrupted workflow with user input.
+
+         Args:
+             thread_id: Thread ID of the interrupted workflow
+             user_message: User's response to the interrupt prompt
+
+         Returns:
+             Either another interrupt prompt or the final outcome message
+         """
+         from langgraph.types import Command
+
+         config = {"configurable": {"thread_id": thread_id}}
+         result = self.graph.invoke(Command(resume=user_message), config=config)
+
+         # Check whether the workflow needs more input
+         if "__interrupt__" in result and result["__interrupt__"]:
+             prompt = result["__interrupt__"][0].value
+             return f"__WORKFLOW_INTERRUPT__|{thread_id}|{self.name}|{prompt}"
+
+         # Workflow completed
+         return self.engine.get_outcome_message(result)
+
+     def to_langchain_tool(self):
+         """Convert to LangChain tool format.
+
+         Returns:
+             A LangChain tool that can be used by LangGraph agents
+         """
+         from langchain_core.tools import tool
+
+         # Create a function with the proper name and docstring
+         def workflow_tool(context: str = "") -> str:
+             """Execute workflow with optional context"""
+             # Parse context if provided (simple comma-separated key=value format)
+             initial_context = {}
+             if context:
+                 for pair in context.split(","):
+                     if "=" in pair:
+                         key, value = pair.split("=", 1)
+                         initial_context[key.strip()] = value.strip()
+
+             return self.execute(initial_context=initial_context)
+
+         # Set the function name and docstring from the tool definition
+         workflow_tool.__name__ = self.name
+         workflow_tool.__doc__ = self.description
+
+         # Decorate and return
+         return tool(workflow_tool)
+
+     def to_crewai_tool(self):
+         """Convert to CrewAI tool format.
+
+         Returns:
+             A CrewAI BaseTool that can be used by CrewAI agents
+         """
+         from crewai.tools import BaseTool
+
+         # Capture self in a closure
+         workflow_tool = self
+
+         # Create a custom CrewAI tool class
+         class WorkflowCrewAITool(BaseTool):
+             name: str = workflow_tool.name
+             description: str = workflow_tool.description
+
+             def _run(self, context: str = "") -> str:
+                 """Execute workflow with optional context"""
+                 # Parse context if provided (simple comma-separated key=value format)
+                 initial_context = {}
+                 if context:
+                     for pair in context.split(","):
+                         if "=" in pair:
+                             key, value = pair.split("=", 1)
+                             initial_context[key.strip()] = value.strip()
+
+                 return workflow_tool.execute(initial_context=initial_context)
+
+         # Return an instance of the tool
+         return WorkflowCrewAITool()
+
+     def get_mermaid_diagram(self) -> str:
+         """Return the workflow graph as a Mermaid diagram string."""
+         return self.graph.get_graph().draw_mermaid()
+
+     def __call__(self, **kwargs) -> str:
+         """Allow the tool to be called directly.
+
+         Args:
+             **kwargs: Context to pass to the workflow
+
+         Returns:
+             Workflow result
+         """
+         return self.execute(initial_context=kwargs)
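Taken together, tools.py supports a flow like the following sketch. The YAML path and tool metadata are made up for illustration; the constructor arguments and conversion methods are as defined above:

```python
from soprano_sdk.tools import WorkflowTool

# Hypothetical workflow file and metadata, for illustration only.
refund_tool = WorkflowTool(
    yaml_path="workflows/refund.yaml",
    name="process_refund",
    description="Collects an order ID and processes a refund request",
)

# Direct call: keyword arguments become the initial workflow context.
print(refund_tool(order_id="123"))

# Or hand the same workflow to an agent framework:
langchain_tool = refund_tool.to_langchain_tool()  # for LangGraph agents
crewai_tool = refund_tool.to_crewai_tool()        # for CrewAI agents
```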
@@ -0,0 +1,35 @@
+ import importlib
+ from typing import Callable, Dict
+
+ from ..utils.logger import logger
+
+
+ class FunctionRepository:
+     """Loads callables by dotted path and caches them in memory."""
+
+     def __init__(self):
+         self._cache: Dict[str, Callable] = {}
+
+     def load(self, function_path: str) -> Callable:
+         if function_path in self._cache:
+             logger.info(f"Loading function from cache: {function_path}")
+             return self._cache[function_path]
+
+         logger.info(f"Loading function: {function_path}")
+
+         try:
+             # "package.module.func" -> ("package.module", "func")
+             module_name, function_name = function_path.rsplit('.', 1)
+             module = importlib.import_module(module_name)
+
+             if not hasattr(module, function_name):
+                 raise RuntimeError(f"Module '{module_name}' has no function '{function_name}'")
+
+             func = getattr(module, function_name)
+
+             if not callable(func):
+                 raise RuntimeError(f"'{function_path}' is not callable (type: {type(func).__name__})")
+
+             self._cache[function_path] = func
+             logger.info(f"Successfully loaded and cached function: {function_path}")
+             return func
+
+         except RuntimeError:
+             raise
+         except Exception as e:
+             raise RuntimeError(f"Unexpected error loading function '{function_path}': {e}") from e
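A quick sketch of the repository's dotted-path contract, using a stdlib function so the example is self-contained:

```python
repo = FunctionRepository()

# "json.dumps" resolves as module "json", attribute "dumps".
dumps = repo.load("json.dumps")
print(dumps({"status": "ok"}))        # {"status": "ok"}

# A second load of the same path is served from the in-memory cache.
assert repo.load("json.dumps") is dumps
```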
@@ -0,0 +1,6 @@
+ import logging
+
+ logging.basicConfig(
+     level=logging.INFO,
+     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+ )
+ # Use a named logger rather than the root logger so records from this
+ # package are identifiable in the output.
+ logger = logging.getLogger("soprano_sdk")
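Usage elsewhere in the package is a single import; records then carry the shared format (timestamp below is illustrative):

```python
from soprano_sdk.utils.logger import logger

logger.info("workflow loaded")
# e.g. 2025-01-01 12:00:00,000 - soprano_sdk - INFO - workflow loaded
```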
@@ -0,0 +1,27 @@
+ import ast
+ from typing import Any
+
+ from jinja2 import Template, TemplateError, UndefinedError
+
+
+ def get_nested_value(data: Any, path: str) -> Any:
+     """Extract a value from `data` using a Jinja2 expression.
+
+     The payload is bound to the template variable `result`, so a path like
+     "result.user.id" digs into nested structures. Bare expressions are
+     wrapped in {{ ... }} automatically.
+     """
+     if not path:
+         return data
+
+     if not path.strip().startswith('{{'):
+         template_str = f'{{{{ {path} }}}}'
+     else:
+         template_str = path
+
+     try:
+         template = Template(template_str)
+         result = template.render(result=data)
+
+         if not result:
+             return None
+
+         # Jinja2 renders to a string; recover the original Python type
+         # (int, list, dict, ...) where possible.
+         try:
+             return ast.literal_eval(result)
+         except (ValueError, SyntaxError):
+             return result
+     except (TemplateError, UndefinedError):
+         return None
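A few illustrative calls (values chosen arbitrarily); the payload is bound to the template name `result`, and `ast.literal_eval` restores native types from the rendered string:

```python
data = {"user": {"id": 7, "tags": ["a", "b"]}}

get_nested_value(data, "result.user.id")            # -> 7 (int, not "7")
get_nested_value(data, "result.user.tags")          # -> ['a', 'b']
get_nested_value(data, "{{ result.user.id + 1 }}")  # -> 8
get_nested_value(data, "result.user.missing")       # -> None (renders empty)
```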
@@ -0,0 +1,60 @@
+ import functools
+ import importlib
+ from typing import Dict, Any
+
+ from .logger import logger
+
+ TYPE_MAP = {
+     "string": str,
+     "number": float,
+     "integer": int,
+     "boolean": bool,
+     "array": list,
+     "object": dict
+ }
+
+
+ def wrap_state(state: Dict[str, Any]):
+     """Return a decorator that injects the workflow state as keyword args."""
+     def wrapper(func):
+
+         @functools.wraps(func)
+         def inner(*args, **kwargs):
+             return func(*args, **kwargs, **state)
+         return inner
+     return wrapper
+
+
+ class ToolRepository:
+     def __init__(self, tool_config):
+         self._cache: Dict[str, Any] = {}
+         self.tool_config = tool_config
+
+     def load(self, tool_name, state):
+         # The unwrapped callable is cached under the tool name; the state
+         # wrapper is applied per call, since state can differ between runs.
+         if tool_name in self._cache:
+             logger.info(f"Tool '{tool_name}' found in cache")
+             name, description, func = self._cache[tool_name]
+             return name, description, wrap_state(state)(func)
+
+         tool_info = next((t for t in self.tool_config.get('tools', []) if t['name'] == tool_name), None)
+         if not tool_info:
+             raise RuntimeError(f"Tool '{tool_name}' not found in tool config")
+
+         tool_description = tool_info.get("description", "")
+         function_path = tool_info.get("callable")
+
+         module_name, function_name = function_path.rsplit('.', 1)
+
+         try:
+             module = importlib.import_module(module_name)
+         except ImportError as e:
+             raise RuntimeError(f"Failed to import module '{module_name}': {e}") from e
+
+         if not hasattr(module, function_name):
+             raise RuntimeError(f"Module '{module_name}' has no function '{function_name}'")
+
+         func = getattr(module, function_name)
+
+         if not callable(func):
+             raise RuntimeError(f"'{function_path}' is not callable (type: {type(func).__name__})")
+
+         self._cache[tool_name] = (tool_name, tool_description, func)
+         logger.info(f"Successfully loaded and cached function: {function_path}")
+         return tool_name, tool_description, wrap_state(state)(func)
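`wrap_state` is the piece that lets configured tools receive workflow state without declaring it at the call site. A runnable sketch with a made-up tool function:

```python
# Hypothetical tool function, for illustration only.
def lookup_order(order_id: str, db_url: str = "") -> str:
    return f"querying {db_url} for order {order_id}"

state = {"db_url": "postgres://localhost/orders"}

# The wrapper forwards call-site args and merges state in as keyword args.
wrapped = wrap_state(state)(lookup_order)
print(wrapped("123"))  # querying postgres://localhost/orders for order 123
```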
@@ -0,0 +1,71 @@
+ from contextlib import contextmanager
+ from typing import Any
+
+ from opentelemetry import trace
+ from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
+
+ from ..utils.logger import logger
+
+ tracer = trace.get_tracer(__name__)
+ propagator = TraceContextTextMapPropagator()
+
+
+ @contextmanager
+ def trace_node_execution(node_id: str, node_type: str, **attributes):
+     span_attributes = {
+         "node.id": node_id,
+         "node.type": node_type,
+     }
+     span_attributes.update(attributes)
+
+     with tracer.start_as_current_span(
+         f"node.{node_id}",
+         attributes=span_attributes
+     ) as span:
+         logger.info(f"Started tracing node: {node_id} ({node_type})")
+         yield span
+         logger.info(f"Finished tracing node: {node_id}")
+
+
+ @contextmanager
+ def trace_agent_invocation(agent_name: str, model: str, **attributes):
+     span_attributes = {
+         "agent.name": agent_name,
+         "agent.model": model,
+     }
+     span_attributes.update(attributes)
+
+     with tracer.start_as_current_span(
+         "agent.invoke",
+         attributes=span_attributes
+     ) as span:
+         yield span
+
+
+ @contextmanager
+ def trace_workflow_execution(workflow_name: str, **attributes):
+     span_attributes = {
+         "workflow.name": workflow_name,
+     }
+     span_attributes.update(attributes)
+
+     with tracer.start_as_current_span(
+         "workflow.execute",
+         attributes=span_attributes
+     ) as span:
+         yield span
+
+
+ def add_node_result(span, field: str, value: Any, status: str):
+     if span and span.is_recording():
+         span.set_attribute(f"field.{field}", str(value) if value is not None else "None")
+         span.set_attribute("node.status", status)
+
+
+ def add_agent_result(span, response_length: int, tool_calls: int = 0):
+     if span and span.is_recording():
+         span.set_attribute("agent.response.length", response_length)
+         if tool_calls > 0:
+             span.set_attribute("agent.tool_calls.count", tool_calls)
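A minimal sketch of how the node-level helper composes with the result recorder (node id and attributes are invented). Without a configured OpenTelemetry TracerProvider the spans are no-ops, but the code still runs:

```python
from soprano_sdk.utils.tracing import trace_node_execution, add_node_result

with trace_node_execution("collect_input", "input", workflow="refund") as span:
    value = "order-123"  # stand-in for real node work
    add_node_result(span, "order_id", value, status="ok")
```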
@@ -0,0 +1,13 @@
+ """
+ Workflow validation module
+ """
+
+ from .validator import WorkflowValidator, ValidationResult, validate_workflow
+ from .schema import WORKFLOW_SCHEMA
+
+ __all__ = [
+     'WorkflowValidator',
+     'ValidationResult',
+     'validate_workflow',
+     'WORKFLOW_SCHEMA',
+ ]