vellum-ai 0.14.56__py3-none-any.whl → 0.14.58__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (24)
  1. vellum/client/core/client_wrapper.py +1 -1
  2. vellum/workflows/nodes/bases/base.py +19 -8
  3. vellum/workflows/nodes/core/retry_node/node.py +6 -0
  4. vellum/workflows/nodes/displayable/api_node/node.py +8 -1
  5. vellum/workflows/nodes/displayable/api_node/tests/test_api_node.py +66 -3
  6. vellum/workflows/nodes/displayable/bases/inline_prompt_node/node.py +14 -10
  7. vellum/workflows/nodes/displayable/guardrail_node/node.py +13 -2
  8. vellum/workflows/nodes/displayable/guardrail_node/test_node.py +29 -0
  9. vellum/workflows/nodes/experimental/tool_calling_node/node.py +3 -1
  10. vellum/workflows/nodes/experimental/tool_calling_node/utils.py +46 -8
  11. vellum/workflows/runner/runner.py +14 -10
  12. vellum/workflows/state/base.py +28 -10
  13. vellum/workflows/state/encoder.py +5 -1
  14. vellum/workflows/utils/functions.py +42 -1
  15. vellum/workflows/utils/tests/test_functions.py +156 -1
  16. vellum/workflows/workflows/tests/test_base_workflow.py +4 -4
  17. {vellum_ai-0.14.56.dist-info → vellum_ai-0.14.58.dist-info}/METADATA +1 -1
  18. {vellum_ai-0.14.56.dist-info → vellum_ai-0.14.58.dist-info}/RECORD +24 -23
  19. vellum_ee/workflows/display/nodes/vellum/tests/test_tool_calling_node.py +118 -0
  20. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_inline_prompt_node_serialization.py +265 -5
  21. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_tool_calling_node_serialization.py +2 -1
  22. {vellum_ai-0.14.56.dist-info → vellum_ai-0.14.58.dist-info}/LICENSE +0 -0
  23. {vellum_ai-0.14.56.dist-info → vellum_ai-0.14.58.dist-info}/WHEEL +0 -0
  24. {vellum_ai-0.14.56.dist-info → vellum_ai-0.14.58.dist-info}/entry_points.txt +0 -0
@@ -18,7 +18,7 @@ class BaseClientWrapper:
18
18
  headers: typing.Dict[str, str] = {
19
19
  "X-Fern-Language": "Python",
20
20
  "X-Fern-SDK-Name": "vellum-ai",
21
- "X-Fern-SDK-Version": "0.14.56",
21
+ "X-Fern-SDK-Version": "0.14.58",
22
22
  }
23
23
  headers["X-API-KEY"] = self.api_key
24
24
  return headers
@@ -8,7 +8,6 @@ from typing import Any, Dict, Generic, Iterator, Optional, Set, Tuple, Type, Typ
8
8
  from vellum.workflows.constants import undefined
9
9
  from vellum.workflows.descriptors.base import BaseDescriptor
10
10
  from vellum.workflows.descriptors.utils import is_unresolved, resolve_value
11
- from vellum.workflows.edges.edge import Edge
12
11
  from vellum.workflows.errors.types import WorkflowErrorCode
13
12
  from vellum.workflows.exceptions import NodeException
14
13
  from vellum.workflows.graph import Graph
@@ -317,8 +316,15 @@ class BaseNode(Generic[StateType], metaclass=BaseNodeMeta):
317
316
  """
318
317
  # Check if all dependencies have invoked this node
319
318
  dependencies_invoked = state.meta.node_execution_cache._dependencies_invoked.get(node_span_id, set())
320
- all_deps_invoked = all(dep in dependencies_invoked for dep in dependencies)
321
-
319
+ node_classes_invoked = {
320
+ state.meta.node_execution_cache.__node_execution_lookup__[dep]
321
+ for dep in dependencies_invoked
322
+ if dep in state.meta.node_execution_cache.__node_execution_lookup__
323
+ }
324
+ if len(node_classes_invoked) != len(dependencies):
325
+ return False
326
+
327
+ all_deps_invoked = all(dep in node_classes_invoked for dep in dependencies)
322
328
  return all_deps_invoked
323
329
 
324
330
  raise NodeException(
@@ -328,7 +334,7 @@ class BaseNode(Generic[StateType], metaclass=BaseNodeMeta):
328
334
 
329
335
  @classmethod
330
336
  def _queue_node_execution(
331
- cls, state: StateType, dependencies: Set["Type[BaseNode]"], invoked_by: Optional[Edge] = None
337
+ cls, state: StateType, dependencies: Set["Type[BaseNode]"], invoked_by: Optional[UUID] = None
332
338
  ) -> UUID:
333
339
  """
334
340
  Queues a future execution of a node, returning the span id of the execution.
@@ -341,19 +347,24 @@ class BaseNode(Generic[StateType], metaclass=BaseNodeMeta):
341
347
  if not invoked_by:
342
348
  return execution_id
343
349
 
350
+ if invoked_by not in state.meta.node_execution_cache.__node_execution_lookup__:
351
+ return execution_id
352
+
344
353
  if cls.merge_behavior not in {MergeBehavior.AWAIT_ANY, MergeBehavior.AWAIT_ALL}:
354
+ # Keep track of the dependencies that have invoked this node
355
+ # This would be needed while climbing the history in the loop
356
+ state.meta.node_execution_cache._dependencies_invoked[execution_id].add(invoked_by)
345
357
  return execution_id
346
358
 
347
- source_node = invoked_by.from_port.node_class
348
359
  for queued_node_execution_id in state.meta.node_execution_cache._node_executions_queued[cls.node_class]:
349
- if source_node not in state.meta.node_execution_cache._dependencies_invoked[queued_node_execution_id]:
360
+ if invoked_by not in state.meta.node_execution_cache._dependencies_invoked[queued_node_execution_id]:
350
361
  state.meta.node_execution_cache._invoke_dependency(
351
- queued_node_execution_id, cls.node_class, source_node, dependencies
362
+ queued_node_execution_id, cls.node_class, invoked_by, dependencies
352
363
  )
353
364
  return queued_node_execution_id
354
365
 
355
366
  state.meta.node_execution_cache._node_executions_queued[cls.node_class].append(execution_id)
356
- state.meta.node_execution_cache._invoke_dependency(execution_id, cls.node_class, source_node, dependencies)
367
+ state.meta.node_execution_cache._invoke_dependency(execution_id, cls.node_class, invoked_by, dependencies)
357
368
  return execution_id
358
369
 
359
370
  class Execution(metaclass=_BaseNodeExecutionMeta):
@@ -70,6 +70,12 @@ class RetryNode(BaseAdornmentNode[StateType], Generic[StateType]):
70
70
 
71
71
  for output_descriptor, output_value in event.outputs:
72
72
  setattr(node_outputs, output_descriptor.name, output_value)
73
+
74
+ if self.__wrapped_node__:
75
+ inner_desc = getattr(self.__wrapped_node__.Outputs, output_descriptor.name, None)
76
+ if inner_desc:
77
+ self.state.meta.node_outputs[inner_desc] = output_value
78
+
73
79
  elif event.name == "workflow.execution.paused":
74
80
  exception = NodeException(
75
81
  code=WorkflowErrorCode.INVALID_OUTPUTS,
@@ -47,11 +47,18 @@ class APINode(BaseAPINode):
47
47
  self.bearer_token_value, VellumSecret
48
48
  ):
49
49
  bearer_token = self.bearer_token_value
50
+
51
+ final_headers = {**headers, **header_overrides}
52
+
53
+ vellum_client_wrapper = self._context.vellum_client._client_wrapper
54
+ if self.url.startswith(vellum_client_wrapper._environment.default) and "X-API-Key" not in final_headers:
55
+ final_headers["X-API-Key"] = vellum_client_wrapper.api_key
56
+
50
57
  return self._run(
51
58
  method=self.method,
52
59
  url=self.url,
53
60
  data=self.data,
54
61
  json=self.json,
55
- headers={**headers, **header_overrides},
62
+ headers=final_headers,
56
63
  bearer_token=bearer_token,
57
64
  )
@@ -19,7 +19,7 @@ def test_run_workflow__secrets(vellum_client):
19
19
  class SimpleBaseAPINode(APINode):
20
20
  method = APIRequestMethod.POST
21
21
  authorization_type = AuthorizationType.BEARER_TOKEN
22
- url = "https://api.vellum.ai"
22
+ url = "https://example.vellum.ai"
23
23
  json = {
24
24
  "key": "value",
25
25
  }
@@ -32,7 +32,9 @@ def test_run_workflow__secrets(vellum_client):
32
32
  terminal = node.run()
33
33
 
34
34
  assert vellum_client.execute_api.call_count == 1
35
+ assert vellum_client.execute_api.call_args.kwargs["url"] == "https://example.vellum.ai"
35
36
  assert vellum_client.execute_api.call_args.kwargs["body"] == {"key": "value"}
37
+ assert vellum_client.execute_api.call_args.kwargs["headers"] == {"X-Test-Header": "foo"}
36
38
  bearer_token = vellum_client.execute_api.call_args.kwargs["bearer_token"]
37
39
  assert bearer_token == ClientVellumSecret(name="secret")
38
40
  assert terminal.headers == {"X-Response-Header": "bar"}
@@ -45,7 +47,7 @@ def test_api_node_raises_error_when_api_call_fails(vellum_client):
45
47
  class SimpleAPINode(APINode):
46
48
  method = APIRequestMethod.GET
47
49
  authorization_type = AuthorizationType.BEARER_TOKEN
48
- url = "https://api.vellum.ai"
50
+ url = "https://example.vellum.ai"
49
51
  json = {
50
52
  "key": "value",
51
53
  }
@@ -65,7 +67,9 @@ def test_api_node_raises_error_when_api_call_fails(vellum_client):
65
67
 
66
68
  # AND the API call should have been made
67
69
  assert vellum_client.execute_api.call_count == 1
70
+ assert vellum_client.execute_api.call_args.kwargs["url"] == "https://example.vellum.ai"
68
71
  assert vellum_client.execute_api.call_args.kwargs["body"] == {"key": "value"}
72
+ assert vellum_client.execute_api.call_args.kwargs["headers"] == {"X-Test-Header": "foo"}
69
73
 
70
74
 
71
75
  def test_api_node_defaults_to_get_method(vellum_client):
@@ -80,7 +84,7 @@ def test_api_node_defaults_to_get_method(vellum_client):
80
84
  # AND an API node without a method specified
81
85
  class SimpleAPINodeWithoutMethod(APINode):
82
86
  authorization_type = AuthorizationType.BEARER_TOKEN
83
- url = "https://api.vellum.ai"
87
+ url = "https://example.vellum.ai"
84
88
  headers = {
85
89
  "X-Test-Header": "foo",
86
90
  }
@@ -95,3 +99,62 @@ def test_api_node_defaults_to_get_method(vellum_client):
95
99
  assert vellum_client.execute_api.call_count == 1
96
100
  method = vellum_client.execute_api.call_args.kwargs["method"]
97
101
  assert method == APIRequestMethod.GET.value
102
+
103
+
104
+ def test_api_node__detects_client_environment_urls__adds_token(mock_httpx_transport, mock_requests, monkeypatch):
105
+ # GIVEN an API node with a URL pointing back to Vellum
106
+ class SimpleAPINodeToVellum(APINode):
107
+ url = "https://api.vellum.ai"
108
+
109
+ # AND a mock request sent to the Vellum API would return a 200
110
+ mock_response = mock_requests.get(
111
+ "https://api.vellum.ai",
112
+ status_code=200,
113
+ json={"data": [1, 2, 3]},
114
+ )
115
+
116
+ # AND an api key is set
117
+ monkeypatch.setenv("VELLUM_API_KEY", "vellum-api-key-1234")
118
+
119
+ # WHEN we run the node
120
+ node = SimpleAPINodeToVellum()
121
+ node.run()
122
+
123
+ # THEN the execute_api method should not have been called
124
+ mock_httpx_transport.handle_request.assert_not_called()
125
+
126
+ # AND the vellum API should have been called with the correct headers
127
+ assert mock_response.last_request
128
+ assert mock_response.last_request.headers["X-API-Key"] == "vellum-api-key-1234"
129
+
130
+
131
+ def test_api_node__detects_client_environment_urls__does_not_override_headers(
132
+ mock_httpx_transport, mock_requests, monkeypatch
133
+ ):
134
+ # GIVEN an API node with a URL pointing back to Vellum
135
+ class SimpleAPINodeToVellum(APINode):
136
+ url = "https://api.vellum.ai"
137
+ headers = {
138
+ "X-API-Key": "vellum-api-key-5678",
139
+ }
140
+
141
+ # AND a mock request sent to the Vellum API would return a 200
142
+ mock_response = mock_requests.get(
143
+ "https://api.vellum.ai",
144
+ status_code=200,
145
+ json={"data": [1, 2, 3]},
146
+ )
147
+
148
+ # AND an api key is set
149
+ monkeypatch.setenv("VELLUM_API_KEY", "vellum-api-key-1234")
150
+
151
+ # WHEN we run the node
152
+ node = SimpleAPINodeToVellum()
153
+ node.run()
154
+
155
+ # THEN the execute_api method should not have been called
156
+ mock_httpx_transport.handle_request.assert_not_called()
157
+
158
+ # AND the vellum API should have been called with the correct headers
159
+ assert mock_response.last_request
160
+ assert mock_response.last_request.headers["X-API-Key"] == "vellum-api-key-5678"
@@ -30,8 +30,8 @@ from vellum.workflows.nodes.displayable.bases.base_prompt_node import BasePrompt
30
30
  from vellum.workflows.nodes.displayable.bases.inline_prompt_node.constants import DEFAULT_PROMPT_PARAMETERS
31
31
  from vellum.workflows.outputs import BaseOutput
32
32
  from vellum.workflows.types import MergeBehavior
33
- from vellum.workflows.types.generics import StateType
34
- from vellum.workflows.utils.functions import compile_function_definition
33
+ from vellum.workflows.types.generics import StateType, is_workflow_class
34
+ from vellum.workflows.utils.functions import compile_function_definition, compile_workflow_function_definition
35
35
 
36
36
 
37
37
  class BaseInlinePromptNode(BasePromptNode[StateType], Generic[StateType]):
@@ -97,14 +97,18 @@ class BaseInlinePromptNode(BasePromptNode[StateType], Generic[StateType]):
97
97
  "execution_context": execution_context.model_dump(mode="json"),
98
98
  **request_options.get("additional_body_parameters", {}),
99
99
  }
100
- normalized_functions = (
101
- [
102
- function if isinstance(function, FunctionDefinition) else compile_function_definition(function)
103
- for function in self.functions
104
- ]
105
- if self.functions
106
- else None
107
- )
100
+
101
+ normalized_functions: Optional[List[FunctionDefinition]] = None
102
+
103
+ if self.functions:
104
+ normalized_functions = []
105
+ for function in self.functions:
106
+ if isinstance(function, FunctionDefinition):
107
+ normalized_functions.append(function)
108
+ elif is_workflow_class(function):
109
+ normalized_functions.append(compile_workflow_function_definition(function))
110
+ else:
111
+ normalized_functions.append(compile_function_definition(function))
108
112
 
109
113
  if self.settings and not self.settings.stream_enabled:
110
114
  # This endpoint is returning a single event, so we need to wrap it in a generator
@@ -37,6 +37,7 @@ class GuardrailNode(BaseNode[StateType], Generic[StateType]):
37
37
  score: float
38
38
  normalized_score: Optional[float]
39
39
  log: Optional[str]
40
+ reason: Optional[str]
40
41
 
41
42
  def run(self) -> Outputs:
42
43
  try:
@@ -54,10 +55,10 @@ class GuardrailNode(BaseNode[StateType], Generic[StateType]):
54
55
  )
55
56
 
56
57
  metric_outputs = {output.name: output.value for output in metric_execution.outputs}
57
-
58
58
  SCORE_KEY = "score"
59
59
  NORMALIZED_SCORE_KEY = "normalized_score"
60
60
  LOG_KEY = "log"
61
+ REASON_KEY = "reason"
61
62
 
62
63
  score = metric_outputs.get(SCORE_KEY)
63
64
  if not isinstance(score, float):
@@ -87,7 +88,17 @@ class GuardrailNode(BaseNode[StateType], Generic[StateType]):
87
88
  else:
88
89
  log = None
89
90
 
90
- return self.Outputs(score=score, normalized_score=normalized_score, log=log, **metric_outputs)
91
+ if REASON_KEY in metric_outputs:
92
+ reason = metric_outputs.pop(REASON_KEY) or ""
93
+ if not isinstance(reason, str):
94
+ raise NodeException(
95
+ message="Metric execution reason output must be of type 'str'",
96
+ code=WorkflowErrorCode.INVALID_OUTPUTS,
97
+ )
98
+ else:
99
+ reason = None
100
+
101
+ return self.Outputs(score=score, normalized_score=normalized_score, log=log, reason=reason, **metric_outputs)
91
102
 
92
103
  def _compile_metric_inputs(self) -> List[MetricDefinitionInput]:
93
104
  # TODO: We may want to consolidate with prompt deployment input compilation
@@ -102,6 +102,35 @@ def test_run_guardrail_node__normalized_score_null(vellum_client):
102
102
  assert exc_info.value.code == WorkflowErrorCode.INVALID_OUTPUTS
103
103
 
104
104
 
105
+ def test_run_guardrail_node__reason(vellum_client):
106
+ # GIVEN a Guardrail Node
107
+ class MyGuard(GuardrailNode):
108
+ metric_definition = "example_metric_definition"
109
+ metric_inputs = {}
110
+
111
+ # AND we know that the guardrail node will return a reason
112
+ mock_metric_execution = MetricDefinitionExecution(
113
+ outputs=[
114
+ TestSuiteRunMetricNumberOutput(
115
+ name="score",
116
+ value=0.6,
117
+ ),
118
+ TestSuiteRunMetricStringOutput(
119
+ name="reason",
120
+ value="foo",
121
+ ),
122
+ ],
123
+ )
124
+ vellum_client.metric_definitions.execute_metric_definition.return_value = mock_metric_execution
125
+
126
+ # WHEN we run the Guardrail Node
127
+ outputs = MyGuard().run()
128
+
129
+ # THEN the workflow should have completed successfully
130
+ assert outputs.score == 0.6
131
+ assert outputs.reason == "foo"
132
+
133
+
105
134
  def test_run_guardrail_node__api_error(vellum_client):
106
135
  # GIVEN a Guardrail Node
107
136
  class MyGuard(GuardrailNode):
@@ -1,6 +1,8 @@
1
1
  from collections.abc import Callable
2
2
  from typing import Any, ClassVar, List, Optional
3
3
 
4
+ from pydash import snake_case
5
+
4
6
  from vellum import ChatMessage, PromptBlock
5
7
  from vellum.workflows.context import execution_context, get_parent_context
6
8
  from vellum.workflows.errors.types import WorkflowErrorCode
@@ -101,7 +103,7 @@ class ToolCallingNode(BaseNode):
101
103
  )
102
104
 
103
105
  self._function_nodes = {
104
- function.__name__: create_function_node(
106
+ snake_case(function.__name__): create_function_node(
105
107
  function=function,
106
108
  tool_router_node=self.tool_router_node,
107
109
  )
@@ -2,22 +2,30 @@ from collections.abc import Callable
2
2
  import json
3
3
  from typing import Any, Iterator, List, Optional, Type, cast
4
4
 
5
- from vellum import ChatMessage, FunctionDefinition, PromptBlock
5
+ from pydash import snake_case
6
+
7
+ from vellum import ChatMessage, PromptBlock
6
8
  from vellum.client.types.function_call_chat_message_content import FunctionCallChatMessageContent
7
9
  from vellum.client.types.function_call_chat_message_content_value import FunctionCallChatMessageContentValue
10
+ from vellum.client.types.string_chat_message_content import StringChatMessageContent
8
11
  from vellum.client.types.variable_prompt_block import VariablePromptBlock
12
+ from vellum.workflows.errors.types import WorkflowErrorCode
13
+ from vellum.workflows.exceptions import NodeException
14
+ from vellum.workflows.inputs.base import BaseInputs
9
15
  from vellum.workflows.nodes.bases import BaseNode
10
16
  from vellum.workflows.nodes.displayable.inline_prompt_node.node import InlinePromptNode
11
17
  from vellum.workflows.outputs.base import BaseOutput
12
18
  from vellum.workflows.ports.port import Port
13
19
  from vellum.workflows.references.lazy import LazyReference
20
+ from vellum.workflows.state.encoder import DefaultStateEncoder
14
21
  from vellum.workflows.types.core import EntityInputsInterface, MergeBehavior
22
+ from vellum.workflows.types.generics import is_workflow_class
15
23
 
16
24
 
17
25
  class FunctionNode(BaseNode):
18
26
  """Node that executes a specific function."""
19
27
 
20
- function: FunctionDefinition
28
+ pass
21
29
 
22
30
 
23
31
  class ToolRouterNode(InlinePromptNode):
@@ -61,7 +69,7 @@ def create_tool_router_node(
61
69
  # If we have functions, create dynamic ports for each function
62
70
  Ports = type("Ports", (), {})
63
71
  for function in functions:
64
- function_name = function.__name__
72
+ function_name = snake_case(function.__name__)
65
73
 
66
74
  # Avoid using lambda to capture function_name
67
75
  # lambda will capture the function_name by reference,
@@ -127,10 +135,41 @@ def create_function_node(function: Callable[..., Any], tool_router_node: Type[To
127
135
  outputs = json.loads(outputs)
128
136
  arguments = outputs["arguments"]
129
137
 
130
- # Call the original function directly with the arguments
131
- result = function(**arguments)
132
-
133
- self.state.chat_history.append(ChatMessage(role="FUNCTION", text=result))
138
+ # Call the function based on its type
139
+ if is_workflow_class(function):
140
+ # Dynamically define an Inputs subclass of BaseInputs
141
+ Inputs = type(
142
+ "Inputs",
143
+ (BaseInputs,),
144
+ {"__annotations__": {k: type(v) for k, v in arguments.items()}},
145
+ )
146
+
147
+ # Create an instance with arguments
148
+ inputs_instance = Inputs(**arguments)
149
+
150
+ workflow = function()
151
+ terminal_event = workflow.run(
152
+ inputs=inputs_instance,
153
+ )
154
+ if terminal_event.name == "workflow.execution.paused":
155
+ raise NodeException(
156
+ code=WorkflowErrorCode.INVALID_OUTPUTS,
157
+ message="Subworkflow unexpectedly paused",
158
+ )
159
+ elif terminal_event.name == "workflow.execution.fulfilled":
160
+ result = terminal_event.outputs
161
+ elif terminal_event.name == "workflow.execution.rejected":
162
+ raise Exception(f"Workflow execution rejected: {terminal_event.error}")
163
+ else:
164
+ # If it's a regular callable, call it directly
165
+ result = function(**arguments)
166
+
167
+ self.state.chat_history.append(
168
+ ChatMessage(
169
+ role="FUNCTION",
170
+ content=StringChatMessageContent(value=json.dumps(result, cls=DefaultStateEncoder)),
171
+ )
172
+ )
134
173
 
135
174
  return self.Outputs()
136
175
 
@@ -138,7 +177,6 @@ def create_function_node(function: Callable[..., Any], tool_router_node: Type[To
138
177
  f"FunctionNode_{function.__name__}",
139
178
  (FunctionNode,),
140
179
  {
141
- "function": function,
142
180
  "run": execute_function,
143
181
  "__module__": __name__,
144
182
  },
@@ -10,7 +10,6 @@ from typing import TYPE_CHECKING, Any, Dict, Generic, Iterable, Iterator, Option
10
10
  from vellum.workflows.constants import undefined
11
11
  from vellum.workflows.context import ExecutionContext, execution_context, get_execution_context
12
12
  from vellum.workflows.descriptors.base import BaseDescriptor
13
- from vellum.workflows.edges.edge import Edge
14
13
  from vellum.workflows.errors import WorkflowError, WorkflowErrorCode
15
14
  from vellum.workflows.events import (
16
15
  NodeExecutionFulfilledEvent,
@@ -143,7 +142,7 @@ class WorkflowRunner(Generic[StateType]):
143
142
  self._workflow_event_inner_queue: Queue[WorkflowEvent] = Queue()
144
143
 
145
144
  self._max_concurrency = max_concurrency
146
- self._concurrency_queue: Queue[Tuple[StateType, Type[BaseNode], Optional[Edge]]] = Queue()
145
+ self._concurrency_queue: Queue[Tuple[StateType, Type[BaseNode], Optional[UUID]]] = Queue()
147
146
 
148
147
  # This queue is responsible for sending events from WorkflowRunner to the background thread
149
148
  # for user defined emitters
@@ -389,7 +388,12 @@ class WorkflowRunner(Generic[StateType]):
389
388
  ):
390
389
  self._run_work_item(node, span_id)
391
390
 
392
- def _handle_invoked_ports(self, state: StateType, ports: Optional[Iterable[Port]]) -> None:
391
+ def _handle_invoked_ports(
392
+ self,
393
+ state: StateType,
394
+ ports: Optional[Iterable[Port]],
395
+ invoked_by: Optional[UUID],
396
+ ) -> None:
393
397
  if not ports:
394
398
  return
395
399
 
@@ -402,9 +406,9 @@ class WorkflowRunner(Generic[StateType]):
402
406
  next_state = state
403
407
 
404
408
  if self._max_concurrency:
405
- self._concurrency_queue.put((next_state, edge.to_node, edge))
409
+ self._concurrency_queue.put((next_state, edge.to_node, invoked_by))
406
410
  else:
407
- self._run_node_if_ready(next_state, edge.to_node, edge)
411
+ self._run_node_if_ready(next_state, edge.to_node, invoked_by)
408
412
 
409
413
  if self._max_concurrency:
410
414
  num_nodes_to_run = self._max_concurrency - len(self._active_nodes_by_execution_id)
@@ -412,14 +416,14 @@ class WorkflowRunner(Generic[StateType]):
412
416
  if self._concurrency_queue.empty():
413
417
  break
414
418
 
415
- next_state, node_class, invoked_edge = self._concurrency_queue.get()
416
- self._run_node_if_ready(next_state, node_class, invoked_edge)
419
+ next_state, node_class, invoked_by = self._concurrency_queue.get()
420
+ self._run_node_if_ready(next_state, node_class, invoked_by)
417
421
 
418
422
  def _run_node_if_ready(
419
423
  self,
420
424
  state: StateType,
421
425
  node_class: Type[BaseNode],
422
- invoked_by: Optional[Edge] = None,
426
+ invoked_by: Optional[UUID] = None,
423
427
  ) -> None:
424
428
  with state.__lock__:
425
429
  for descriptor in node_class.ExternalInputs:
@@ -482,7 +486,7 @@ class WorkflowRunner(Generic[StateType]):
482
486
  )
483
487
  )
484
488
 
485
- self._handle_invoked_ports(node.state, event.invoked_ports)
489
+ self._handle_invoked_ports(node.state, event.invoked_ports, event.span_id)
486
490
 
487
491
  return None
488
492
 
@@ -508,7 +512,7 @@ class WorkflowRunner(Generic[StateType]):
508
512
  )
509
513
  )
510
514
 
511
- self._handle_invoked_ports(node.state, event.invoked_ports)
515
+ self._handle_invoked_ports(node.state, event.invoked_ports, event.span_id)
512
516
 
513
517
  return None
514
518
 
@@ -93,7 +93,8 @@ def _make_snapshottable(value: Any, snapshot_callback: Callable[[], None]) -> An
93
93
  NodeExecutionsFulfilled = Dict[Type["BaseNode"], Stack[UUID]]
94
94
  NodeExecutionsInitiated = Dict[Type["BaseNode"], Set[UUID]]
95
95
  NodeExecutionsQueued = Dict[Type["BaseNode"], List[UUID]]
96
- DependenciesInvoked = Dict[UUID, Set[Type["BaseNode"]]]
96
+ NodeExecutionLookup = Dict[UUID, Type["BaseNode"]]
97
+ DependenciesInvoked = Dict[UUID, Set[UUID]]
97
98
 
98
99
 
99
100
  class NodeExecutionCache:
@@ -102,11 +103,15 @@ class NodeExecutionCache:
102
103
  _node_executions_queued: NodeExecutionsQueued
103
104
  _dependencies_invoked: DependenciesInvoked
104
105
 
106
+ # Derived fields - no need to serialize
107
+ __node_execution_lookup__: NodeExecutionLookup
108
+
105
109
  def __init__(self) -> None:
106
110
  self._dependencies_invoked = defaultdict(set)
107
111
  self._node_executions_fulfilled = defaultdict(Stack[UUID])
108
112
  self._node_executions_initiated = defaultdict(set)
109
113
  self._node_executions_queued = defaultdict(list)
114
+ self.__node_execution_lookup__ = {}
110
115
 
111
116
  @classmethod
112
117
  def deserialize(cls, raw_data: dict, nodes: Dict[Union[str, UUID], Type["BaseNode"]]):
@@ -124,10 +129,8 @@ class NodeExecutionCache:
124
129
  dependencies_invoked = raw_data.get("dependencies_invoked")
125
130
  if isinstance(dependencies_invoked, dict):
126
131
  for execution_id, dependencies in dependencies_invoked.items():
127
- dependency_classes = {get_node_class(dep) for dep in dependencies}
128
- cache._dependencies_invoked[UUID(execution_id)] = {
129
- dep_class for dep_class in dependency_classes if dep_class is not None
130
- }
132
+ dependency_execution_ids = {UUID(dep) for dep in dependencies if is_valid_uuid(dep)}
133
+ cache._dependencies_invoked[UUID(execution_id)] = dependency_execution_ids
131
134
 
132
135
  node_executions_fulfilled = raw_data.get("node_executions_fulfilled")
133
136
  if isinstance(node_executions_fulfilled, dict):
@@ -151,6 +154,10 @@ class NodeExecutionCache:
151
154
  {UUID(execution_id) for execution_id in execution_ids}
152
155
  )
153
156
 
157
+ for node_class, execution_ids in cache._node_executions_initiated.items():
158
+ for execution_id in execution_ids:
159
+ cache.__node_execution_lookup__[execution_id] = node_class
160
+
154
161
  node_executions_queued = raw_data.get("node_executions_queued")
155
162
  if isinstance(node_executions_queued, dict):
156
163
  for node, execution_ids in node_executions_queued.items():
@@ -166,18 +173,29 @@ class NodeExecutionCache:
166
173
  self,
167
174
  execution_id: UUID,
168
175
  node: Type["BaseNode"],
169
- dependency: Type["BaseNode"],
176
+ invoked_by: UUID,
170
177
  dependencies: Set["Type[BaseNode]"],
171
178
  ) -> None:
172
- self._dependencies_invoked[execution_id].add(dependency)
173
- if all(dep in self._dependencies_invoked[execution_id] for dep in dependencies):
174
- self._node_executions_queued[node].remove(execution_id)
179
+ self._dependencies_invoked[execution_id].add(invoked_by)
180
+ invoked_node_classes = {
181
+ self.__node_execution_lookup__[dep]
182
+ for dep in self._dependencies_invoked[execution_id]
183
+ if dep in self.__node_execution_lookup__
184
+ }
185
+ if len(invoked_node_classes) != len(dependencies):
186
+ return
187
+
188
+ if any(dep not in invoked_node_classes for dep in dependencies):
189
+ return
190
+
191
+ self._node_executions_queued[node].remove(execution_id)
175
192
 
176
193
  def is_node_execution_initiated(self, node: Type["BaseNode"], execution_id: UUID) -> bool:
177
194
  return execution_id in self._node_executions_initiated[node]
178
195
 
179
196
  def initiate_node_execution(self, node: Type["BaseNode"], execution_id: UUID) -> None:
180
197
  self._node_executions_initiated[node].add(execution_id)
198
+ self.__node_execution_lookup__[execution_id] = node
181
199
 
182
200
  def fulfill_node_execution(self, node: Type["BaseNode"], execution_id: UUID) -> None:
183
201
  self._node_executions_fulfilled[node].push(execution_id)
@@ -188,7 +206,7 @@ class NodeExecutionCache:
188
206
  def dump(self) -> Dict[str, Any]:
189
207
  return {
190
208
  "dependencies_invoked": {
191
- str(execution_id): [str(dep.__id__) for dep in dependencies]
209
+ str(execution_id): [str(dep) for dep in dependencies]
192
210
  for execution_id, dependencies in self._dependencies_invoked.items()
193
211
  },
194
212
  "node_executions_initiated": {
@@ -66,7 +66,11 @@ class DefaultStateEncoder(JSONEncoder):
66
66
  except Exception:
67
67
  source_code = f"<source code not available for {obj.__name__}>"
68
68
 
69
- return {"definition": function_definition, "src": source_code}
69
+ return {
70
+ "type": "CODE_EXECUTION",
71
+ "definition": function_definition,
72
+ "src": source_code,
73
+ }
70
74
 
71
75
  if obj.__class__ in self.encoders:
72
76
  return self.encoders[obj.__class__](obj)