vellum-ai 0.14.40__py3-none-any.whl → 0.14.42__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. vellum/__init__.py +2 -4
  2. vellum/client/core/client_wrapper.py +1 -1
  3. vellum/client/reference.md +141 -4
  4. vellum/client/resources/ad_hoc/client.py +311 -1
  5. vellum/client/resources/deployments/client.py +2 -2
  6. vellum/client/resources/documents/client.py +0 -6
  7. vellum/client/types/__init__.py +2 -4
  8. vellum/client/types/execute_api_response.py +3 -4
  9. vellum/client/types/execute_api_response_json.py +7 -0
  10. vellum/client/types/prompt_settings.py +1 -0
  11. vellum/client/types/workflow_event_execution_read.py +0 -4
  12. vellum/client/types/workflow_execution_initiated_body.py +0 -9
  13. vellum/client/types/workflow_execution_initiated_event.py +0 -4
  14. vellum/client/types/workflow_execution_span.py +0 -4
  15. vellum/types/{node_event_display_context.py → execute_api_response_json.py} +1 -1
  16. vellum/workflows/inputs/base.py +26 -3
  17. vellum/workflows/inputs/tests/test_inputs.py +15 -0
  18. vellum/workflows/nodes/bases/base_adornment_node.py +9 -0
  19. vellum/workflows/nodes/core/map_node/node.py +3 -2
  20. vellum/workflows/nodes/core/map_node/tests/test_node.py +56 -0
  21. vellum/workflows/nodes/core/retry_node/node.py +2 -1
  22. vellum/workflows/nodes/experimental/tool_calling_node/node.py +6 -28
  23. vellum/workflows/nodes/experimental/tool_calling_node/utils.py +6 -10
  24. vellum/workflows/nodes/utils.py +14 -1
  25. vellum/workflows/references/workflow_input.py +5 -1
  26. vellum/workflows/runner/runner.py +2 -0
  27. vellum/workflows/workflows/base.py +5 -0
  28. {vellum_ai-0.14.40.dist-info → vellum_ai-0.14.42.dist-info}/METADATA +1 -1
  29. {vellum_ai-0.14.40.dist-info → vellum_ai-0.14.42.dist-info}/RECORD +65 -68
  30. vellum_ee/workflows/display/nodes/base_node_display.py +67 -28
  31. vellum_ee/workflows/display/nodes/tests/test_base_node_display.py +18 -0
  32. vellum_ee/workflows/display/nodes/vellum/api_node.py +3 -2
  33. vellum_ee/workflows/display/nodes/vellum/base_adornment_node.py +1 -2
  34. vellum_ee/workflows/display/nodes/vellum/code_execution_node.py +3 -2
  35. vellum_ee/workflows/display/nodes/vellum/conditional_node.py +3 -2
  36. vellum_ee/workflows/display/nodes/vellum/error_node.py +2 -2
  37. vellum_ee/workflows/display/nodes/vellum/final_output_node.py +8 -8
  38. vellum_ee/workflows/display/nodes/vellum/guardrail_node.py +3 -2
  39. vellum_ee/workflows/display/nodes/vellum/inline_prompt_node.py +3 -2
  40. vellum_ee/workflows/display/nodes/vellum/inline_subworkflow_node.py +3 -2
  41. vellum_ee/workflows/display/nodes/vellum/map_node.py +15 -12
  42. vellum_ee/workflows/display/nodes/vellum/merge_node.py +3 -2
  43. vellum_ee/workflows/display/nodes/vellum/note_node.py +2 -2
  44. vellum_ee/workflows/display/nodes/vellum/prompt_deployment_node.py +3 -4
  45. vellum_ee/workflows/display/nodes/vellum/search_node.py +3 -2
  46. vellum_ee/workflows/display/nodes/vellum/subworkflow_deployment_node.py +3 -2
  47. vellum_ee/workflows/display/nodes/vellum/templating_node.py +3 -2
  48. vellum_ee/workflows/display/nodes/vellum/tests/test_utils.py +2 -2
  49. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_api_node_serialization.py +1 -0
  50. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_code_execution_node_serialization.py +3 -0
  51. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_conditional_node_serialization.py +138 -0
  52. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_guardrail_node_serialization.py +1 -0
  53. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_inline_subworkflow_serialization.py +1 -0
  54. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_map_node_serialization.py +3 -2
  55. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_merge_node_serialization.py +1 -0
  56. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_prompt_deployment_serialization.py +1 -0
  57. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_search_node_serialization.py +1 -0
  58. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_subworkflow_deployment_serialization.py +1 -0
  59. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_templating_node_serialization.py +1 -0
  60. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_terminal_node_serialization.py +2 -2
  61. vellum_ee/workflows/display/tests/workflow_serialization/test_complex_terminal_node_serialization.py +2 -2
  62. vellum_ee/workflows/display/workflows/base_workflow_display.py +4 -12
  63. vellum/client/types/node_event_display_context.py +0 -30
  64. vellum/client/types/workflow_event_display_context.py +0 -28
  65. vellum/types/workflow_event_display_context.py +0 -3
  66. vellum_ee/workflows/display/nodes/base_node_vellum_display.py +0 -40
  67. {vellum_ai-0.14.40.dist-info → vellum_ai-0.14.42.dist-info}/LICENSE +0 -0
  68. {vellum_ai-0.14.40.dist-info → vellum_ai-0.14.42.dist-info}/WHEEL +0 -0
  69. {vellum_ai-0.14.40.dist-info → vellum_ai-0.14.42.dist-info}/entry_points.txt +0 -0
vellum/client/types/execute_api_response.py
@@ -2,8 +2,9 @@
 
 from ..core.pydantic_utilities import UniversalBaseModel
 import typing_extensions
-import typing
+from .execute_api_response_json import ExecuteApiResponseJson
 from ..core.serialization import FieldMetadata
+import typing
 from ..core.pydantic_utilities import IS_PYDANTIC_V2
 import pydantic
 
@@ -11,9 +12,7 @@ import pydantic
 class ExecuteApiResponse(UniversalBaseModel):
     status_code: int
     text: str
-    json_: typing_extensions.Annotated[
-        typing.Optional[typing.Dict[str, typing.Optional[typing.Any]]], FieldMetadata(alias="json")
-    ] = None
+    json_: typing_extensions.Annotated[ExecuteApiResponseJson, FieldMetadata(alias="json")]
     headers: typing.Dict[str, str]
 
     if IS_PYDANTIC_V2:
vellum/client/types/execute_api_response_json.py
@@ -0,0 +1,7 @@
+# This file was auto-generated by Fern from our API Definition.
+
+import typing
+
+ExecuteApiResponseJson = typing.Union[
+    typing.Dict[str, typing.Optional[typing.Any]], typing.List[typing.Optional[typing.Any]]
+]
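
In practice, the new alias means `ExecuteApiResponse.json_` (wire name `json`) can now carry a top-level JSON array as well as an object, and the field is no longer `Optional`. A minimal sketch of both shapes (values are illustrative; assumes the generated model can be constructed by field name, as Fern-generated clients typically allow):

```python
from vellum.client.types.execute_api_response import ExecuteApiResponse

# A top-level JSON object, as before.
obj_response = ExecuteApiResponse(
    status_code=200,
    text='{"ok": true}',
    json_={"ok": True},
    headers={"content-type": "application/json"},
)

# A top-level JSON array, newly representable via ExecuteApiResponseJson.
list_response = ExecuteApiResponse(
    status_code=200,
    text="[1, 2, 3]",
    json_=[1, 2, 3],
    headers={"content-type": "application/json"},
)
```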

vellum/client/types/prompt_settings.py
@@ -8,6 +8,7 @@ import pydantic
 
 class PromptSettings(UniversalBaseModel):
     timeout: typing.Optional[float] = None
+    stream_enabled: typing.Optional[bool] = None
 
     if IS_PYDANTIC_V2:
         model_config: typing.ClassVar[pydantic.ConfigDict] = pydantic.ConfigDict(extra="allow", frozen=True)  # type: ignore # Pydantic v2
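
`stream_enabled` joins `timeout` as an optional prompt setting. A minimal usage sketch (assuming `PromptSettings` is re-exported from the package root, as other generated types are):

```python
from vellum import PromptSettings

# Both fields are optional and default to None when omitted.
settings = PromptSettings(timeout=30.0, stream_enabled=False)
assert settings.stream_enabled is False
```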

vellum/client/types/workflow_event_execution_read.py
@@ -10,8 +10,6 @@ from .workflow_deployment_parent_context import WorkflowDeploymentParentContext
 from .workflow_parent_context import WorkflowParentContext
 from .workflow_sandbox_parent_context import WorkflowSandboxParentContext
 from .array_vellum_value import ArrayVellumValue
-from .node_event_display_context import NodeEventDisplayContext
-from .workflow_event_display_context import WorkflowEventDisplayContext
 import typing
 import datetime as dt
 from .execution_vellum_value import ExecutionVellumValue
@@ -56,5 +54,3 @@ update_forward_refs(WorkflowDeploymentParentContext, WorkflowEventExecutionRead=
 update_forward_refs(WorkflowParentContext, WorkflowEventExecutionRead=WorkflowEventExecutionRead)
 update_forward_refs(WorkflowSandboxParentContext, WorkflowEventExecutionRead=WorkflowEventExecutionRead)
 update_forward_refs(ArrayVellumValue, WorkflowEventExecutionRead=WorkflowEventExecutionRead)
-update_forward_refs(NodeEventDisplayContext, WorkflowEventExecutionRead=WorkflowEventExecutionRead)
-update_forward_refs(WorkflowEventDisplayContext, WorkflowEventExecutionRead=WorkflowEventExecutionRead)

vellum/client/types/workflow_execution_initiated_body.py
@@ -1,19 +1,14 @@
 # This file was auto-generated by Fern from our API Definition.
 
-from __future__ import annotations
 from ..core.pydantic_utilities import UniversalBaseModel
-from .node_event_display_context import NodeEventDisplayContext
-from .workflow_event_display_context import WorkflowEventDisplayContext
 from .vellum_code_resource_definition import VellumCodeResourceDefinition
 import typing
 from ..core.pydantic_utilities import IS_PYDANTIC_V2
 import pydantic
-from ..core.pydantic_utilities import update_forward_refs
 
 
 class WorkflowExecutionInitiatedBody(UniversalBaseModel):
     workflow_definition: VellumCodeResourceDefinition
-    display_context: typing.Optional[WorkflowEventDisplayContext] = None
     inputs: typing.Dict[str, typing.Optional[typing.Any]]
 
     if IS_PYDANTIC_V2:
@@ -24,7 +19,3 @@ class WorkflowExecutionInitiatedBody(UniversalBaseModel):
         frozen = True
         smart_union = True
         extra = pydantic.Extra.allow
-
-
-update_forward_refs(NodeEventDisplayContext, WorkflowExecutionInitiatedBody=WorkflowExecutionInitiatedBody)
-update_forward_refs(WorkflowEventDisplayContext, WorkflowExecutionInitiatedBody=WorkflowExecutionInitiatedBody)

vellum/client/types/workflow_execution_initiated_event.py
@@ -9,8 +9,6 @@ from .span_link import SpanLink
 from .workflow_deployment_parent_context import WorkflowDeploymentParentContext
 from .workflow_parent_context import WorkflowParentContext
 from .workflow_sandbox_parent_context import WorkflowSandboxParentContext
-from .node_event_display_context import NodeEventDisplayContext
-from .workflow_event_display_context import WorkflowEventDisplayContext
 import typing
 from .parent_context import ParentContext
 from .workflow_execution_initiated_body import WorkflowExecutionInitiatedBody
@@ -49,5 +47,3 @@ update_forward_refs(SpanLink, WorkflowExecutionInitiatedEvent=WorkflowExecutionI
 update_forward_refs(WorkflowDeploymentParentContext, WorkflowExecutionInitiatedEvent=WorkflowExecutionInitiatedEvent)
 update_forward_refs(WorkflowParentContext, WorkflowExecutionInitiatedEvent=WorkflowExecutionInitiatedEvent)
 update_forward_refs(WorkflowSandboxParentContext, WorkflowExecutionInitiatedEvent=WorkflowExecutionInitiatedEvent)
-update_forward_refs(NodeEventDisplayContext, WorkflowExecutionInitiatedEvent=WorkflowExecutionInitiatedEvent)
-update_forward_refs(WorkflowEventDisplayContext, WorkflowExecutionInitiatedEvent=WorkflowExecutionInitiatedEvent)

vellum/client/types/workflow_execution_span.py
@@ -3,12 +3,10 @@
 from __future__ import annotations
 from ..core.pydantic_utilities import UniversalBaseModel
 from .api_request_parent_context import ApiRequestParentContext
-from .node_event_display_context import NodeEventDisplayContext
 from .node_parent_context import NodeParentContext
 from .prompt_deployment_parent_context import PromptDeploymentParentContext
 from .span_link import SpanLink
 from .workflow_deployment_parent_context import WorkflowDeploymentParentContext
-from .workflow_event_display_context import WorkflowEventDisplayContext
 from .workflow_parent_context import WorkflowParentContext
 from .workflow_sandbox_parent_context import WorkflowSandboxParentContext
 import typing
@@ -40,11 +38,9 @@ class WorkflowExecutionSpan(UniversalBaseModel):
 
 
 update_forward_refs(ApiRequestParentContext, WorkflowExecutionSpan=WorkflowExecutionSpan)
-update_forward_refs(NodeEventDisplayContext, WorkflowExecutionSpan=WorkflowExecutionSpan)
 update_forward_refs(NodeParentContext, WorkflowExecutionSpan=WorkflowExecutionSpan)
 update_forward_refs(PromptDeploymentParentContext, WorkflowExecutionSpan=WorkflowExecutionSpan)
 update_forward_refs(SpanLink, WorkflowExecutionSpan=WorkflowExecutionSpan)
 update_forward_refs(WorkflowDeploymentParentContext, WorkflowExecutionSpan=WorkflowExecutionSpan)
-update_forward_refs(WorkflowEventDisplayContext, WorkflowExecutionSpan=WorkflowExecutionSpan)
 update_forward_refs(WorkflowParentContext, WorkflowExecutionSpan=WorkflowExecutionSpan)
 update_forward_refs(WorkflowSandboxParentContext, WorkflowExecutionSpan=WorkflowExecutionSpan)

vellum/types/{node_event_display_context.py → execute_api_response_json.py}
@@ -1,3 +1,3 @@
 # WARNING: This file will be removed in a future release. Please import from "vellum.client" instead.
 
-from vellum.client.types.node_event_display_context import *
+from vellum.client.types.execute_api_response_json import *
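
The shim keeps the deprecated flat import path working by re-exporting the new module, so both of the following resolve to the same alias (the `vellum.types` path is the one slated for removal, per the warning above):

```python
# Preferred path:
from vellum.client.types.execute_api_response_json import ExecuteApiResponseJson

# Deprecated path, kept working by the wildcard re-export above:
from vellum.types.execute_api_response_json import ExecuteApiResponseJson  # noqa: F811
```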

vellum/workflows/inputs/base.py
@@ -1,4 +1,4 @@
-from typing import Any, Iterator, Tuple, Type, Union, get_args, get_origin
+from typing import Any, Dict, Iterator, Set, Tuple, Type, Union, get_args, get_origin
 from typing_extensions import dataclass_transform
 
 from pydantic import GetCoreSchemaHandler
@@ -14,11 +14,28 @@ from vellum.workflows.types.utils import get_class_attr_names, infer_types
 
 @dataclass_transform(kw_only_default=True)
 class _BaseInputsMeta(type):
+    def __new__(cls, name: str, bases: Tuple[Type, ...], dct: Dict[str, Any]) -> Any:
+        dct["__parent_class__"] = type(None)
+        return super().__new__(cls, name, bases, dct)
+
     def __getattribute__(cls, name: str) -> Any:
-        if not name.startswith("_") and name in cls.__annotations__ and issubclass(cls, BaseInputs):
+        if name.startswith("_") or not issubclass(cls, BaseInputs):
+            return super().__getattribute__(name)
+
+        attr_names = get_class_attr_names(cls)
+        if name in attr_names:
+            # We first try to resolve the instance that this class attribute name is mapped to. If it's not found,
+            # we iterate through its inheritance hierarchy to find the first base class that has this attribute
+            # and use its mapping.
             instance = vars(cls).get(name, undefined)
-            types = infer_types(cls, name)
+            if instance is undefined:
+                for base in cls.__mro__[1:]:
+                    inherited_input_reference = getattr(base, name, undefined)
+                    if isinstance(inherited_input_reference, (ExternalInputReference, WorkflowInputReference)):
+                        instance = inherited_input_reference.instance
+                        break
 
+            types = infer_types(cls, name)
             if getattr(cls, "__descriptor_class__", None) is ExternalInputReference:
                 return ExternalInputReference(name=name, types=types, instance=instance, inputs_class=cls)
             else:
@@ -29,14 +46,20 @@ class _BaseInputsMeta(type):
     def __iter__(cls) -> Iterator[InputReference]:
         # We iterate through the inheritance hierarchy to find all the WorkflowInputReference attached to this
         # Inputs class. __mro__ is the method resolution order, which is the order in which base classes are resolved.
+        yielded_attr_names: Set[str] = set()
+
         for resolved_cls in cls.__mro__:
             attr_names = get_class_attr_names(resolved_cls)
             for attr_name in attr_names:
+                if attr_name in yielded_attr_names:
+                    continue
+
                 attr_value = getattr(resolved_cls, attr_name)
                 if not isinstance(attr_value, (WorkflowInputReference, ExternalInputReference)):
                     continue
 
                 yield attr_value
+                yielded_attr_names.add(attr_name)
 
 
 class BaseInputs(metaclass=_BaseInputsMeta):

vellum/workflows/inputs/tests/test_inputs.py
@@ -47,3 +47,18 @@ def test_base_inputs_with_default():
 
     # THEN it should use the default value
     assert inputs.string_with_default == "default_value"
+
+
+def test_base_inputs__supports_inherited_inputs():
+    # GIVEN an inputs class
+    class TopInputs(BaseInputs):
+        first: str
+
+    # WHEN we inherit from the base inputs class
+    class BottomInputs(TopInputs):
+        second: int
+
+    # THEN both references should be available
+    assert BottomInputs.first.name == "first"
+    assert BottomInputs.second.name == "second"
+    assert len([ref for ref in BottomInputs]) == 2

vellum/workflows/nodes/bases/base_adornment_node.py
@@ -1,5 +1,6 @@
 from typing import TYPE_CHECKING, Any, Dict, Generic, Optional, Tuple, Type
 
+from vellum.workflows.inputs.base import BaseInputs
 from vellum.workflows.nodes.bases.base import BaseNode, BaseNodeMeta
 from vellum.workflows.outputs.base import BaseOutputs
 from vellum.workflows.references.output import OutputReference
@@ -13,6 +14,14 @@ class _BaseAdornmentNodeMeta(BaseNodeMeta):
     def __new__(cls, name: str, bases: Tuple[Type, ...], dct: Dict[str, Any]) -> Any:
         node_class = super().__new__(cls, name, bases, dct)
 
+        SubworkflowInputs = dct.get("SubworkflowInputs")
+        if (
+            isinstance(SubworkflowInputs, type)
+            and issubclass(SubworkflowInputs, BaseInputs)
+            and SubworkflowInputs.__parent_class__ is type(None)
+        ):
+            SubworkflowInputs.__parent_class__ = node_class
+
         subworkflow_attribute = dct.get("subworkflow")
         if not subworkflow_attribute:
             return node_class

vellum/workflows/nodes/core/map_node/node.py
@@ -62,7 +62,7 @@ class MapNode(BaseAdornmentNode[StateType], Generic[StateType, MapNodeItemType])
 
         item: MapNodeItemType  # type: ignore[valid-type]
         index: int
-        all_items: List[MapNodeItemType]  # type: ignore[valid-type]
+        items: List[MapNodeItemType]  # type: ignore[valid-type]
 
     def run(self) -> Iterator[BaseOutput]:
         mapped_items: Dict[str, List] = defaultdict(list)
@@ -176,8 +176,9 @@ class MapNode(BaseAdornmentNode[StateType], Generic[StateType, MapNodeItemType])
                 parent_state=self.state,
                 context=context,
             )
+            SubworkflowInputsClass = self.subworkflow.get_inputs_class()
             events = subworkflow.stream(
-                inputs=self.SubworkflowInputs(index=index, item=item, all_items=self.items),
+                inputs=SubworkflowInputsClass(index=index, item=item, items=self.items),
                 node_output_mocks=self._context._get_all_node_output_mocks(),
                 event_filter=all_workflow_event_filter,
            )
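
For subworkflow authors, the visible change here is the rename of the map-list input from `all_items` to `items`, plus inputs now being resolved via `get_inputs_class()` so custom `SubworkflowInputs` subclasses are honored. A hedged sketch of a node reading the renamed field (class names invented; relies on the inherited-inputs support added in `vellum/workflows/inputs/base.py` above):

```python
from vellum.workflows.nodes import BaseNode
from vellum.workflows.nodes.core.map_node.node import MapNode

class ItemInputs(MapNode.SubworkflowInputs):
    item: str

class LabelNode(BaseNode):
    item = ItemInputs.item
    items = ItemInputs.items  # formerly `all_items`

    class Outputs(BaseNode.Outputs):
        label: str

    def run(self) -> Outputs:
        # Each mapped item sees the full list alongside its own value.
        return self.Outputs(label=f"{self.item} (1 of {len(self.items)})")
```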

vellum/workflows/nodes/core/map_node/tests/test_node.py
@@ -116,3 +116,59 @@ def test_map_node__inner_try():
 
     # THEN the workflow should succeed
     assert outputs[-1].name == "final_output"
     assert len(outputs[-1].value) == 2
+
+
+def test_map_node__nested_map_node():
+    # GIVEN the inner map node's inputs
+    class VegetableMapNodeInputs(MapNode.SubworkflowInputs):
+        item: str
+
+    # AND the outer map node's inputs
+    class FruitMapNodeInputs(MapNode.SubworkflowInputs):
+        item: str
+
+    # AND a simple node that concats both attributes
+    class SimpleConcatNode(BaseNode):
+        fruit = FruitMapNodeInputs.item
+        vegetable = VegetableMapNodeInputs.item
+
+        class Outputs(BaseNode.Outputs):
+            medley: str
+
+        def run(self) -> Outputs:
+            return self.Outputs(medley=f"{self.fruit} {self.vegetable}")
+
+    # AND a workflow using that node
+    class VegetableMapNodeWorkflow(BaseWorkflow[VegetableMapNodeInputs, BaseState]):
+        graph = SimpleConcatNode
+
+        class Outputs(BaseWorkflow.Outputs):
+            final_output = SimpleConcatNode.Outputs.medley
+
+    # AND an inner map node referencing that workflow
+    class VegetableMapNode(MapNode):
+        items = ["carrot", "potato"]
+        subworkflow = VegetableMapNodeWorkflow
+
+    # AND an outer subworkflow referencing the inner map node
+    class FruitMapNodeWorkflow(BaseWorkflow[FruitMapNodeInputs, BaseState]):
+        graph = VegetableMapNode
+
+        class Outputs(BaseWorkflow.Outputs):
+            final_output = VegetableMapNode.Outputs.final_output
+
+    # AND an outer map node referencing the outer subworkflow
+    class FruitMapNode(MapNode):
+        items = ["apple", "banana"]
+        subworkflow = FruitMapNodeWorkflow
+
+    # WHEN we run the workflow
+    stream = FruitMapNode().run()
+    outputs = list(stream)
+
+    # THEN the workflow should succeed
+    assert outputs[-1].name == "final_output"
+    assert outputs[-1].value == [
+        ["apple carrot", "apple potato"],
+        ["banana carrot", "banana potato"],
+    ]

vellum/workflows/nodes/core/retry_node/node.py
@@ -47,8 +47,9 @@ class RetryNode(BaseAdornmentNode[StateType], Generic[StateType]):
                 parent_state=self.state,
                 context=WorkflowContext.create_from(self._context),
             )
+            inputs_class = subworkflow.get_inputs_class()
             subworkflow_stream = subworkflow.stream(
-                inputs=self.SubworkflowInputs(attempt_number=attempt_number),
+                inputs=inputs_class(attempt_number=attempt_number),
                 event_filter=all_workflow_event_filter,
                 node_output_mocks=self._context._get_all_node_output_mocks(),
             )
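
`RetryNode` gets the same treatment: the inputs class is resolved from the subworkflow rather than hard-coded to `self.SubworkflowInputs`. The `attempt_number` input is still supplied on every attempt, so a wrapped node can read it; a sketch (assuming the documented `RetryNode.wrap` adornment usage):

```python
from vellum.workflows.nodes import BaseNode
from vellum.workflows.nodes.core.retry_node.node import RetryNode

@RetryNode.wrap(max_attempts=3)
class FlakyNode(BaseNode):
    # The current attempt number is exposed to the wrapped subworkflow as an input.
    attempt = RetryNode.SubworkflowInputs.attempt_number

    class Outputs(BaseNode.Outputs):
        result: str

    def run(self) -> Outputs:
        return self.Outputs(result=f"succeeded on attempt {self.attempt}")
```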

vellum/workflows/nodes/experimental/tool_calling_node/node.py
@@ -1,8 +1,7 @@
 from collections.abc import Callable
-from typing import Any, ClassVar, Dict, List, Optional, cast
+from typing import Any, ClassVar, List, Optional
 
-from vellum import ChatMessage, FunctionDefinition, PromptBlock
-from vellum.client.types.chat_message_request import ChatMessageRequest
+from vellum import ChatMessage, PromptBlock
 from vellum.workflows.context import execution_context, get_parent_context
 from vellum.workflows.errors.types import WorkflowErrorCode
 from vellum.workflows.exceptions import NodeException
@@ -35,8 +34,7 @@ class ToolCallingNode(BaseNode):
 
     ml_model: ClassVar[str] = "gpt-4o-mini"
     blocks: ClassVar[List[PromptBlock]] = []
-    functions: ClassVar[List[FunctionDefinition]] = []
-    function_callables: ClassVar[Dict[str, Callable[..., Any]]] = {}
+    functions: ClassVar[List[Callable[..., Any]]] = []
     prompt_inputs: ClassVar[Optional[EntityInputsInterface]] = None
     # TODO: https://linear.app/vellum/issue/APO-342/support-tool-call-max-retries
     max_tool_calls: ClassVar[int] = 1
@@ -59,27 +57,13 @@ class ToolCallingNode(BaseNode):
         This dynamically builds a graph with router and function nodes,
         then executes the workflow.
         """
-        self._validate_functions()
-
-        initial_chat_history = []
-
-        # Extract chat history from prompt inputs if available
-        if self.prompt_inputs and "chat_history" in self.prompt_inputs:
-            chat_history_input = self.prompt_inputs["chat_history"]
-            if isinstance(chat_history_input, list) and all(
-                isinstance(msg, (ChatMessage, ChatMessageRequest)) for msg in chat_history_input
-            ):
-                initial_chat_history = [
-                    msg if isinstance(msg, ChatMessage) else ChatMessage.model_validate(msg.model_dump())
-                    for msg in chat_history_input
-                ]
 
         self._build_graph()
 
         with execution_context(parent_context=get_parent_context()):
 
             class ToolCallingState(BaseState):
-                chat_history: List[ChatMessage] = initial_chat_history
+                chat_history: List[ChatMessage] = []
 
             class ToolCallingWorkflow(BaseWorkflow[BaseInputs, ToolCallingState]):
                 graph = self._graph
@@ -121,9 +105,8 @@ class ToolCallingNode(BaseNode):
         )
 
         self._function_nodes = {
-            function.name: create_function_node(
+            function.__name__: create_function_node(
                 function=function,
-                function_callable=cast(Callable[..., Any], self.function_callables[function.name]),  # type: ignore
             )
             for function in self.functions
         }
@@ -132,7 +115,7 @@ class ToolCallingNode(BaseNode):
 
         # Add connections from ports of router to function nodes and back to router
         for function_name, FunctionNodeClass in self._function_nodes.items():
-            router_port = getattr(self.tool_router_node.Ports, function_name)  # type: ignore # mypy thinks name is still optional # noqa: E501
+            router_port = getattr(self.tool_router_node.Ports, function_name)
             edge_graph = router_port >> FunctionNodeClass >> self.tool_router_node
             graph_set.add(edge_graph)
 
@@ -140,8 +123,3 @@ class ToolCallingNode(BaseNode):
             graph_set.add(default_port)
 
         self._graph = Graph.from_set(graph_set)
-
-    def _validate_functions(self) -> None:
-        for function in self.functions:
-            if function.name is None:
-                raise ValueError("Function name is required")
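
Taken together, the changes to this file collapse the old two-part API (`functions` as `FunctionDefinition` objects plus a parallel `function_callables` dict) into a single list of plain Python callables keyed by `__name__`, which also removes the need for `_validate_functions`. A minimal sketch of the new surface (model choice and tool body are illustrative):

```python
from vellum.workflows.nodes.experimental.tool_calling_node.node import ToolCallingNode

def get_current_weather(location: str, unit: str) -> str:
    # An illustrative tool; any plain callable with a __name__ works here.
    return f"It is sunny in {location} (70 degrees {unit})."

class WeatherAgentNode(ToolCallingNode):
    ml_model = "gpt-4o-mini"
    functions = [get_current_weather]  # previously: FunctionDefinition list + function_callables dict
```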

vellum/workflows/nodes/experimental/tool_calling_node/utils.py
@@ -35,7 +35,7 @@ class ToolRouterNode(InlinePromptNode):
         if function_call is not None:
             self.state.chat_history.append(
                 ChatMessage(
-                    role="FUNCTION",
+                    role="ASSISTANT",
                     content=FunctionCallChatMessageContent(
                         value=FunctionCallChatMessageContentValue(
                             name=function_call.name,
@@ -51,16 +51,12 @@
 def create_tool_router_node(
     ml_model: str,
     blocks: List[PromptBlock],
-    functions: List[FunctionDefinition],
+    functions: List[Callable[..., Any]],
     prompt_inputs: Optional[EntityInputsInterface],
 ) -> Type[ToolRouterNode]:
     Ports = type("Ports", (), {})
     for function in functions:
-        if function.name is None:
-            # We should not raise an error here since we filter out functions without names
-            raise ValueError("Function name is required")
-
-        function_name = function.name
+        function_name = function.__name__
         port_condition = LazyReference(
             lambda: (
                 ToolRouterNode.Outputs.results[0]["type"].equals("FUNCTION_CALL")
@@ -98,7 +94,7 @@ def create_tool_router_node(
     return node
 
 
-def create_function_node(function: FunctionDefinition, function_callable: Callable[..., Any]) -> Type[FunctionNode]:
+def create_function_node(function: Callable[..., Any]) -> Type[FunctionNode]:
     """
     Create a FunctionNode class for a given function.
 
@@ -113,14 +109,14 @@ def create_function_node(function: FunctionDefinition, function_callab
         arguments = outputs["arguments"]
 
         # Call the original function directly with the arguments
-        result = function_callable(**arguments)
+        result = function(**arguments)
 
         self.state.chat_history.append(ChatMessage(role="FUNCTION", text=result))
 
         return self.Outputs()
 
     node = type(
-        f"FunctionNode_{function.name}",
+        f"FunctionNode_{function.__name__}",
         (FunctionNode,),
         {
             "function": function,

vellum/workflows/nodes/utils.py
@@ -9,9 +9,11 @@ from pydantic import BaseModel, create_model
 from vellum.client.types.function_call import FunctionCall
 from vellum.workflows.errors.types import WorkflowErrorCode
 from vellum.workflows.exceptions import NodeException
+from vellum.workflows.inputs.base import BaseInputs
 from vellum.workflows.nodes import BaseNode
 from vellum.workflows.nodes.bases.base_adornment_node import BaseAdornmentNode
 from vellum.workflows.ports.port import Port
+from vellum.workflows.state.base import BaseState
 from vellum.workflows.types.core import Json
 from vellum.workflows.types.generics import NodeType
 
@@ -54,7 +56,18 @@
     # https://app.shortcut.com/vellum/story/4116
     from vellum.workflows import BaseWorkflow
 
-    class Subworkflow(BaseWorkflow):
+    SubworkflowInputs = getattr(adornable_cls, "SubworkflowInputs", None)
+    BaseSubworkflowInputs = (
+        SubworkflowInputs
+        if isinstance(SubworkflowInputs, type) and issubclass(SubworkflowInputs, BaseInputs)
+        else BaseInputs
+    )
+
+    # mypy is too conservative here - you can absolutely inherit from dynamic classes in python
+    class Inputs(BaseSubworkflowInputs):  # type: ignore[misc, valid-type]
+        pass
+
+    class Subworkflow(BaseWorkflow[Inputs, BaseState]):
         graph = inner_cls
 
         class Outputs(BaseWorkflow.Outputs):

vellum/workflows/references/workflow_input.py
@@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Generic, Optional, Tuple, Type, TypeVar, cast
 from vellum.workflows.descriptors.base import BaseDescriptor
 from vellum.workflows.errors.types import WorkflowErrorCode
 from vellum.workflows.exceptions import NodeException
+from vellum.workflows.types.generics import import_workflow_class
 
 if TYPE_CHECKING:
     from vellum.workflows.inputs.base import BaseInputs
@@ -29,7 +30,10 @@ class WorkflowInputReference(BaseDescriptor[_InputType], Generic[_InputType]):
         return self._inputs_class
 
     def resolve(self, state: "BaseState") -> _InputType:
-        if hasattr(state.meta.workflow_inputs, self._name):
+        if hasattr(state.meta.workflow_inputs, self._name) and (
+            state.meta.workflow_definition == self._inputs_class.__parent_class__
+            or not issubclass(self._inputs_class.__parent_class__, import_workflow_class())
+        ):
             return cast(_InputType, getattr(state.meta.workflow_inputs, self._name))
 
         if state.meta.parent:
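
The added condition scopes resolution to the workflow that owns the inputs class: a reference reads from a state's `workflow_inputs` only when that state's `workflow_definition` matches the class's `__parent_class__` (or when the inputs class is not bound to any workflow); otherwise resolution falls through to `state.meta.parent`. A hypothetical sketch of the name collision this guards against (names invented):

```python
from vellum.workflows.inputs.base import BaseInputs

class OuterInputs(BaseInputs):
    item: str

class InnerInputs(BaseInputs):
    item: str

# OuterInputs.item and InnerInputs.item are distinct references that share an
# attribute name. Previously, resolving OuterInputs.item against a nested
# workflow's state succeeded on hasattr() alone and could return the inner
# workflow's value; with the parent-class check, it defers to the parent state
# until it reaches the workflow that actually owns OuterInputs.
```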

vellum/workflows/runner/runner.py
@@ -101,6 +101,7 @@ class WorkflowRunner(Generic[StateType]):
             if state:
                 self._initial_state = deepcopy(state)
                 self._initial_state.meta.span_id = uuid4()
+                self._initial_state.meta.workflow_definition = self.workflow.__class__
             else:
                 self._initial_state = self.workflow.get_state_at_node(node)
             self._entrypoints = entrypoint_nodes
@@ -126,6 +127,7 @@ class WorkflowRunner(Generic[StateType]):
             self._initial_state = deepcopy(state)
             self._initial_state.meta.workflow_inputs = normalized_inputs
             self._initial_state.meta.span_id = uuid4()
+            self._initial_state.meta.workflow_definition = self.workflow.__class__
         else:
             self._initial_state = self.workflow.get_default_state(normalized_inputs)
             # We don't want to emit the initial state on the base case of Workflow Runs, since

vellum/workflows/workflows/base.py
@@ -133,6 +133,11 @@ class _BaseWorkflowMeta(type):
         cls = super().__new__(mcs, name, bases, dct)
         workflow_class = cast(Type["BaseWorkflow"], cls)
         workflow_class.__id__ = uuid4_from_hash(workflow_class.__qualname__)
+
+        inputs_class = workflow_class.get_inputs_class()
+        if inputs_class is not BaseInputs and inputs_class.__parent_class__ is type(None):
+            inputs_class.__parent_class__ = workflow_class
+
         return workflow_class
 
 

{vellum_ai-0.14.40.dist-info → vellum_ai-0.14.42.dist-info}/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: vellum-ai
-Version: 0.14.40
+Version: 0.14.42
 Summary:
 License: MIT
 Requires-Python: >=3.9,<4.0