vellum-ai 0.13.0__py3-none-any.whl → 0.13.1__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- vellum/client/core/client_wrapper.py +1 -1
- vellum/client/core/pydantic_utilities.py +5 -0
- vellum/client/resources/workflows/client.py +8 -0
- vellum/client/types/logical_operator.py +2 -0
- vellum/workflows/descriptors/base.py +1 -1
- vellum/workflows/descriptors/tests/test_utils.py +3 -0
- vellum/workflows/expressions/accessor.py +8 -2
- vellum/workflows/nodes/core/map_node/node.py +49 -24
- vellum/workflows/nodes/core/map_node/tests/test_node.py +4 -4
- vellum/workflows/nodes/displayable/bases/base_prompt_node/node.py +1 -1
- vellum/workflows/nodes/displayable/bases/inline_prompt_node/node.py +5 -3
- vellum/workflows/nodes/displayable/bases/prompt_deployment_node.py +3 -0
- vellum/workflows/nodes/displayable/bases/search_node.py +37 -2
- vellum/workflows/nodes/displayable/bases/tests/__init__.py +0 -0
- vellum/workflows/nodes/displayable/bases/tests/test_utils.py +61 -0
- vellum/workflows/nodes/displayable/bases/types.py +42 -0
- vellum/workflows/nodes/displayable/bases/utils.py +112 -0
- vellum/workflows/nodes/displayable/inline_prompt_node/tests/test_node.py +0 -1
- vellum/workflows/nodes/displayable/search_node/tests/__init__.py +0 -0
- vellum/workflows/nodes/displayable/search_node/tests/test_node.py +164 -0
- vellum/workflows/nodes/displayable/tests/test_inline_text_prompt_node.py +2 -3
- vellum/workflows/nodes/displayable/tests/test_text_prompt_deployment_node.py +0 -1
- vellum/workflows/runner/runner.py +37 -4
- vellum/workflows/types/tests/test_utils.py +5 -2
- vellum/workflows/types/utils.py +4 -0
- vellum/workflows/workflows/base.py +14 -0
- {vellum_ai-0.13.0.dist-info → vellum_ai-0.13.1.dist-info}/METADATA +1 -1
- {vellum_ai-0.13.0.dist-info → vellum_ai-0.13.1.dist-info}/RECORD +46 -36
- vellum_cli/__init__.py +10 -0
- vellum_cli/ping.py +28 -0
- vellum_cli/tests/test_ping.py +47 -0
- vellum_ee/workflows/display/nodes/vellum/base_node.py +22 -9
- vellum_ee/workflows/display/nodes/vellum/inline_prompt_node.py +3 -0
- vellum_ee/workflows/display/nodes/vellum/map_node.py +1 -1
- vellum_ee/workflows/display/nodes/vellum/prompt_deployment_node.py +14 -10
- vellum_ee/workflows/display/nodes/vellum/tests/test_utils.py +2 -2
- vellum_ee/workflows/display/nodes/vellum/utils.py +8 -1
- vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_adornments_serialization.py +67 -0
- vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_attributes_serialization.py +66 -0
- vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_ports_serialization.py +660 -0
- vellum_ee/workflows/display/utils/vellum.py +4 -42
- vellum_ee/workflows/display/vellum.py +7 -36
- vellum_ee/workflows/display/workflows/vellum_workflow_display.py +2 -1
- {vellum_ai-0.13.0.dist-info → vellum_ai-0.13.1.dist-info}/LICENSE +0 -0
- {vellum_ai-0.13.0.dist-info → vellum_ai-0.13.1.dist-info}/WHEEL +0 -0
- {vellum_ai-0.13.0.dist-info → vellum_ai-0.13.1.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,112 @@
|
|
1
|
+
import enum
|
2
|
+
import json
|
3
|
+
from typing import Any, List, Union, cast
|
4
|
+
|
5
|
+
from vellum.client.types.array_vellum_value import ArrayVellumValue
|
6
|
+
from vellum.client.types.array_vellum_value_request import ArrayVellumValueRequest
|
7
|
+
from vellum.client.types.audio_vellum_value import AudioVellumValue
|
8
|
+
from vellum.client.types.audio_vellum_value_request import AudioVellumValueRequest
|
9
|
+
from vellum.client.types.chat_history_vellum_value import ChatHistoryVellumValue
|
10
|
+
from vellum.client.types.chat_history_vellum_value_request import ChatHistoryVellumValueRequest
|
11
|
+
from vellum.client.types.chat_message import ChatMessage
|
12
|
+
from vellum.client.types.error_vellum_value import ErrorVellumValue
|
13
|
+
from vellum.client.types.error_vellum_value_request import ErrorVellumValueRequest
|
14
|
+
from vellum.client.types.function_call_vellum_value import FunctionCallVellumValue
|
15
|
+
from vellum.client.types.function_call_vellum_value_request import FunctionCallVellumValueRequest
|
16
|
+
from vellum.client.types.image_vellum_value import ImageVellumValue
|
17
|
+
from vellum.client.types.image_vellum_value_request import ImageVellumValueRequest
|
18
|
+
from vellum.client.types.json_vellum_value import JsonVellumValue
|
19
|
+
from vellum.client.types.json_vellum_value_request import JsonVellumValueRequest
|
20
|
+
from vellum.client.types.number_vellum_value import NumberVellumValue
|
21
|
+
from vellum.client.types.number_vellum_value_request import NumberVellumValueRequest
|
22
|
+
from vellum.client.types.search_result import SearchResult
|
23
|
+
from vellum.client.types.search_result_request import SearchResultRequest
|
24
|
+
from vellum.client.types.search_results_vellum_value import SearchResultsVellumValue
|
25
|
+
from vellum.client.types.search_results_vellum_value_request import SearchResultsVellumValueRequest
|
26
|
+
from vellum.client.types.string_vellum_value import StringVellumValue
|
27
|
+
from vellum.client.types.string_vellum_value_request import StringVellumValueRequest
|
28
|
+
from vellum.client.types.vellum_value import VellumValue
|
29
|
+
from vellum.client.types.vellum_value_request import VellumValueRequest
|
30
|
+
|
31
|
+
# The ``*VellumValueRequest`` counterparts of the ``*VellumValue`` types handled in this module.
# It is used in two places:
#   * ``primitive_to_vellum_value`` passes instances of these types through unchanged;
#   * ``primitive_to_vellum_value_request`` walks this tuple IN ORDER and picks the first class
#     whose name starts with the converted value's class name. Keep the order stable.
VELLUM_VALUE_REQUEST_TUPLE = (
    StringVellumValueRequest,
    NumberVellumValueRequest,
    JsonVellumValueRequest,
    ImageVellumValueRequest,
    AudioVellumValueRequest,
    FunctionCallVellumValueRequest,
    ErrorVellumValueRequest,
    ArrayVellumValueRequest,
    ChatHistoryVellumValueRequest,
    SearchResultsVellumValueRequest,
)
|
43
|
+
|
44
|
+
|
45
|
+
def primitive_to_vellum_value(value: Any) -> VellumValue:
    """Converts a python primitive to a VellumValue.

    Conversion rules, applied in order:

    - ``str`` -> ``StringVellumValue``
    - ``enum.Enum`` -> ``StringVellumValue`` holding the member's ``.value``
    - ``int`` / ``float`` -> ``NumberVellumValue``
    - a list made up entirely of ``ChatMessage`` (or entirely of ``ChatMessageRequest``)
      -> ``ChatHistoryVellumValue``
    - a list made up entirely of ``SearchResultRequest`` (or entirely of ``SearchResult``)
      -> ``SearchResultsVellumValue``
    - an instance of any ``*VellumValue`` / ``*VellumValueRequest`` type is returned unchanged
    - anything else that ``json`` can serialize -> ``JsonVellumValue``

    Note: ``all`` over an empty iterable is ``True``, so an empty list satisfies the
    chat-history check and is returned as an empty ``ChatHistoryVellumValue``.

    Raises:
        ValueError: if the value matches none of the rules above and is not JSON-serializable.
    """
    # Local import so this fix does not depend on changes to the module's import block.
    from vellum.client.types.chat_message_request import ChatMessageRequest

    if isinstance(value, str):
        return StringVellumValue(value=value)
    elif isinstance(value, enum.Enum):
        return StringVellumValue(value=value.value)
    elif isinstance(value, (int, float)):
        return NumberVellumValue(value=value)
    elif isinstance(value, list) and (
        all(isinstance(message, ChatMessage) for message in value)
        # Previously this repeated the ChatMessage check; accept the request model too, mirroring
        # the SearchResult / SearchResultRequest branch below.
        or all(isinstance(message, ChatMessageRequest) for message in value)
    ):
        chat_messages = cast(Union[List[ChatMessage], List[ChatMessageRequest]], value)
        return ChatHistoryVellumValue(value=chat_messages)
    elif isinstance(value, list) and (
        all(isinstance(search_result, SearchResultRequest) for search_result in value)
        or all(isinstance(search_result, SearchResult) for search_result in value)
    ):
        search_results = cast(Union[List[SearchResultRequest], List[SearchResult]], value)
        return SearchResultsVellumValue(value=search_results)
    elif isinstance(
        value,
        (
            StringVellumValue,
            NumberVellumValue,
            JsonVellumValue,
            ImageVellumValue,
            AudioVellumValue,
            FunctionCallVellumValue,
            ErrorVellumValue,
            ArrayVellumValue,
            ChatHistoryVellumValue,
            SearchResultsVellumValue,
        ),
    ):
        return value
    elif isinstance(
        value,
        VELLUM_VALUE_REQUEST_TUPLE,
    ):
        # This type ignore is safe because consumers of this function won't care the difference between
        # XVellumValue and XVellumValueRequest. Hopefully in the near future, neither will we
        return value  # type: ignore

    try:
        json_value = json.dumps(value)
    except (TypeError, ValueError) as exc:
        # json.dumps reports unsupported objects with TypeError and circular references with
        # ValueError. It never raises JSONDecodeError, which is a *decoding* error, so catching
        # that left this branch unreachable.
        raise ValueError(f"Unsupported variable type: {value.__class__.__name__}") from exc

    # Round-trip through JSON so the stored value contains only plain JSON types.
    return JsonVellumValue(value=json.loads(json_value))
|
96
|
+
|
97
|
+
|
98
|
+
def primitive_to_vellum_value_request(value: Any) -> VellumValueRequest:
    """Convert a python primitive into the corresponding ``*VellumValueRequest`` model.

    The value is first normalized with ``primitive_to_vellum_value``. The request class is then
    the first entry of ``VELLUM_VALUE_REQUEST_TUPLE`` whose name starts with the normalized
    value's class name. For example, ``StringVellumValue`` selects ``StringVellumValueRequest``,
    and a value that is already a request model selects its own class. The normalized value is
    dumped and re-validated into that class.

    Raises:
        ValueError: if no request class matches the normalized value's type.
    """
    normalized = primitive_to_vellum_value(value)
    normalized_type_name = normalized.__class__.__name__

    # First prefix match wins, so the tuple's order is significant.
    for request_cls in VELLUM_VALUE_REQUEST_TUPLE:
        if not request_cls.__name__.startswith(normalized_type_name):
            continue
        return request_cls.model_validate(normalized.model_dump())

    raise ValueError(f"Unsupported variable type: {normalized_type_name}")
|
@@ -76,7 +76,6 @@ def test_inline_prompt_node__function_definitions(vellum_adhoc_prompt_client):
|
|
76
76
|
class MyNode(InlinePromptNode):
|
77
77
|
ml_model = "gpt-4o"
|
78
78
|
functions = [my_function]
|
79
|
-
prompt_inputs = {}
|
80
79
|
blocks = []
|
81
80
|
|
82
81
|
# AND a known response from invoking an inline prompt
|
File without changes
|
@@ -0,0 +1,164 @@
|
|
1
|
+
from vellum import SearchResponse, SearchResult, SearchResultDocument
|
2
|
+
from vellum.client.types.search_filters_request import SearchFiltersRequest
|
3
|
+
from vellum.client.types.search_request_options_request import SearchRequestOptionsRequest
|
4
|
+
from vellum.client.types.search_result_merging_request import SearchResultMergingRequest
|
5
|
+
from vellum.client.types.search_weights_request import SearchWeightsRequest
|
6
|
+
from vellum.client.types.string_vellum_value_request import StringVellumValueRequest
|
7
|
+
from vellum.client.types.vellum_value_logical_condition_group_request import VellumValueLogicalConditionGroupRequest
|
8
|
+
from vellum.client.types.vellum_value_logical_condition_request import VellumValueLogicalConditionRequest
|
9
|
+
from vellum.workflows.nodes.displayable.bases.types import (
|
10
|
+
MetadataLogicalCondition,
|
11
|
+
MetadataLogicalConditionGroup,
|
12
|
+
SearchFilters,
|
13
|
+
)
|
14
|
+
from vellum.workflows.nodes.displayable.search_node.node import SearchNode
|
15
|
+
|
16
|
+
|
17
|
+
def test_run_workflow__happy_path(vellum_client):
    """Confirm that we can successfully invoke a single Search node with the new option attributes"""

    # GIVEN a Search Node that's configured via its individual option attributes
    class MySearchNode(SearchNode):
        query = "Search query"
        document_index = "document_index"
        limit = 1
        weights = SearchWeightsRequest(
            semantic_similarity=0.8,
            keywords=0.2,
        )
        result_merging = SearchResultMergingRequest(
            enabled=True,
        )
        filters = SearchFilters(
            external_ids=["external_id"],
            metadata=MetadataLogicalConditionGroup(
                combinator="AND",
                negated=False,
                conditions=[
                    MetadataLogicalCondition(
                        lhs_variable="TYPE",
                        operator="=",
                        rhs_variable="COMPANY",
                    )
                ],
            ),
        )

    # AND a Search request that will return a 200 ok response
    search_response = SearchResponse(
        results=[
            SearchResult(
                text="Search query", score="0.0", keywords=["keywords"], document=SearchResultDocument(label="label")
            )
        ]
    )

    vellum_client.search.return_value = search_response

    # WHEN we run the node
    outputs = MySearchNode().run()

    # THEN the node should have completed successfully
    assert outputs.text == "Search query"

    # AND the options sent to the Search API should be assembled from the individual attributes,
    # with the plain-python metadata filter converted into its request-model equivalent
    assert vellum_client.search.call_args.kwargs["options"] == SearchRequestOptionsRequest(
        limit=1,
        weights=SearchWeightsRequest(
            semantic_similarity=0.8,
            keywords=0.2,
        ),
        result_merging=SearchResultMergingRequest(
            enabled=True,
        ),
        filters=SearchFiltersRequest(
            external_ids=["external_id"],
            metadata=VellumValueLogicalConditionGroupRequest(
                combinator="AND",
                negated=False,
                conditions=[
                    VellumValueLogicalConditionRequest(
                        lhs_variable=StringVellumValueRequest(value="TYPE"),
                        operator="=",
                        rhs_variable=StringVellumValueRequest(value="COMPANY"),
                    )
                ],
            ),
        ),
    )
|
89
|
+
|
90
|
+
|
91
|
+
def test_run_workflow__happy_path__options_attribute(vellum_client):
    """Confirm that we can successfully invoke a single Search node with the legacy options attribute"""

    # GIVEN a Search Node that's configured via the legacy `options` attribute
    class MySearchNode(SearchNode):
        query = "Search query"
        document_index = "document_index"
        options = SearchRequestOptionsRequest(
            limit=1,
            weights=SearchWeightsRequest(
                semantic_similarity=0.8,
                keywords=0.2,
            ),
            result_merging=SearchResultMergingRequest(
                enabled=True,
            ),
            filters=SearchFiltersRequest(
                external_ids=["external_id"],
                metadata=VellumValueLogicalConditionGroupRequest(
                    combinator="AND",
                    negated=False,
                    conditions=[
                        VellumValueLogicalConditionRequest(
                            lhs_variable=StringVellumValueRequest(value="TYPE"),
                            operator="=",
                            rhs_variable=StringVellumValueRequest(value="COMPANY"),
                        )
                    ],
                ),
            ),
        )

    # AND a Search request that will return a 200 ok response
    search_response = SearchResponse(
        results=[
            SearchResult(
                text="Search query", score="0.0", keywords=["keywords"], document=SearchResultDocument(label="label")
            )
        ]
    )

    vellum_client.search.return_value = search_response

    # WHEN we run the node
    outputs = MySearchNode().run()

    # THEN the node should have completed successfully
    assert outputs.text == "Search query"

    # AND the options sent to the Search API should match the configured `options` attribute
    assert vellum_client.search.call_args.kwargs["options"] == SearchRequestOptionsRequest(
        limit=1,
        weights=SearchWeightsRequest(
            semantic_similarity=0.8,
            keywords=0.2,
        ),
        result_merging=SearchResultMergingRequest(
            enabled=True,
        ),
        filters=SearchFiltersRequest(
            external_ids=["external_id"],
            metadata=VellumValueLogicalConditionGroupRequest(
                combinator="AND",
                negated=False,
                conditions=[
                    VellumValueLogicalConditionRequest(
                        lhs_variable=StringVellumValueRequest(value="TYPE"),
                        operator="=",
                        rhs_variable=StringVellumValueRequest(value="COMPANY"),
                    )
                ],
            ),
        ),
    )
|
@@ -1,3 +1,4 @@
|
|
1
|
+
from unittest import mock
|
1
2
|
from uuid import uuid4
|
2
3
|
from typing import Any, Iterator, List
|
3
4
|
|
@@ -33,7 +34,6 @@ def test_inline_text_prompt_node__basic(vellum_adhoc_prompt_client):
|
|
33
34
|
|
34
35
|
class MyInlinePromptNode(InlinePromptNode):
|
35
36
|
ml_model = "gpt-4o"
|
36
|
-
prompt_inputs = {}
|
37
37
|
blocks = []
|
38
38
|
|
39
39
|
# AND a known response from invoking an inline prompt
|
@@ -90,7 +90,7 @@ def test_inline_text_prompt_node__basic(vellum_adhoc_prompt_client):
|
|
90
90
|
logit_bias=None,
|
91
91
|
custom_parameters=None,
|
92
92
|
),
|
93
|
-
request_options=
|
93
|
+
request_options=mock.ANY,
|
94
94
|
)
|
95
95
|
|
96
96
|
|
@@ -107,7 +107,6 @@ def test_inline_text_prompt_node__catch_provider_error(vellum_adhoc_prompt_clien
|
|
107
107
|
@TryNode.wrap(on_error_code=WorkflowErrorCode.PROVIDER_ERROR)
|
108
108
|
class MyInlinePromptNode(InlinePromptNode):
|
109
109
|
ml_model = "gpt-4o"
|
110
|
-
prompt_inputs = {}
|
111
110
|
blocks = []
|
112
111
|
|
113
112
|
# AND a known response from invoking an inline prompt that fails
|
@@ -27,7 +27,6 @@ def test_text_prompt_deployment_node__basic(vellum_client):
|
|
27
27
|
|
28
28
|
class MyPromptDeploymentNode(PromptDeploymentNode):
|
29
29
|
deployment = "my-deployment"
|
30
|
-
prompt_inputs = {}
|
31
30
|
|
32
31
|
# AND a known response from invoking a deployed prompt
|
33
32
|
expected_outputs: List[PromptOutput] = [
|
@@ -4,7 +4,21 @@ import logging
|
|
4
4
|
from queue import Empty, Queue
|
5
5
|
from threading import Event as ThreadingEvent, Thread
|
6
6
|
from uuid import UUID
|
7
|
-
from typing import
|
7
|
+
from typing import (
|
8
|
+
TYPE_CHECKING,
|
9
|
+
Any,
|
10
|
+
Dict,
|
11
|
+
Generic,
|
12
|
+
Iterable,
|
13
|
+
Iterator,
|
14
|
+
List,
|
15
|
+
Optional,
|
16
|
+
Sequence,
|
17
|
+
Set,
|
18
|
+
Tuple,
|
19
|
+
Type,
|
20
|
+
Union,
|
21
|
+
)
|
8
22
|
|
9
23
|
from vellum.workflows.constants import UNDEF
|
10
24
|
from vellum.workflows.context import execution_context, get_parent_context
|
@@ -76,6 +90,7 @@ class WorkflowRunner(Generic[StateType]):
|
|
76
90
|
cancel_signal: Optional[ThreadingEvent] = None,
|
77
91
|
node_output_mocks: Optional[List[BaseOutputs]] = None,
|
78
92
|
parent_context: Optional[ParentContext] = None,
|
93
|
+
max_concurrency: Optional[int] = None,
|
79
94
|
):
|
80
95
|
if state and external_inputs:
|
81
96
|
raise ValueError("Can only run a Workflow providing one of state or external inputs, not both")
|
@@ -120,6 +135,9 @@ class WorkflowRunner(Generic[StateType]):
|
|
120
135
|
# This queue is responsible for sending events from the inner worker threads to WorkflowRunner
|
121
136
|
self._workflow_event_inner_queue: Queue[WorkflowEvent] = Queue()
|
122
137
|
|
138
|
+
self._max_concurrency = max_concurrency
|
139
|
+
self._concurrency_queue: Queue[Tuple[StateType, Type[BaseNode], Optional[Edge]]] = Queue()
|
140
|
+
|
123
141
|
# This queue is responsible for sending events from WorkflowRunner to the background thread
|
124
142
|
# for user defined emitters
|
125
143
|
self._background_thread_queue: Queue[BackgroundThreadItem] = Queue()
|
@@ -350,7 +368,19 @@ class WorkflowRunner(Generic[StateType]):
|
|
350
368
|
else:
|
351
369
|
next_state = state
|
352
370
|
|
353
|
-
self.
|
371
|
+
if self._max_concurrency:
|
372
|
+
self._concurrency_queue.put((next_state, edge.to_node, edge))
|
373
|
+
else:
|
374
|
+
self._run_node_if_ready(next_state, edge.to_node, edge)
|
375
|
+
|
376
|
+
if self._max_concurrency:
|
377
|
+
num_nodes_to_run = self._max_concurrency - len(self._active_nodes_by_execution_id)
|
378
|
+
for _ in range(num_nodes_to_run):
|
379
|
+
if self._concurrency_queue.empty():
|
380
|
+
break
|
381
|
+
|
382
|
+
next_state, node_class, invoked_edge = self._concurrency_queue.get()
|
383
|
+
self._run_node_if_ready(next_state, node_class, invoked_edge)
|
354
384
|
|
355
385
|
def _run_node_if_ready(
|
356
386
|
self,
|
@@ -513,8 +543,11 @@ class WorkflowRunner(Generic[StateType]):
|
|
513
543
|
)
|
514
544
|
for node_cls in self._entrypoints:
|
515
545
|
try:
|
516
|
-
|
517
|
-
|
546
|
+
if not self._max_concurrency or len(self._active_nodes_by_execution_id) < self._max_concurrency:
|
547
|
+
with execution_context(parent_context=current_parent):
|
548
|
+
self._run_node_if_ready(self._initial_state, node_cls)
|
549
|
+
else:
|
550
|
+
self._concurrency_queue.put((self._initial_state, node_cls, None))
|
518
551
|
except NodeException as e:
|
519
552
|
self._workflow_event_outer_queue.put(self._reject_workflow_event(e.error))
|
520
553
|
return
|
@@ -21,6 +21,7 @@ class ExampleClass:
|
|
21
21
|
zeta: ClassVar[str]
|
22
22
|
eta: List[str]
|
23
23
|
kappa: Any
|
24
|
+
mu: list[str]
|
24
25
|
|
25
26
|
|
26
27
|
T = TypeVar("T")
|
@@ -56,6 +57,7 @@ class ExampleNode(BaseNode):
|
|
56
57
|
(ExampleInheritedClass, "beta", (int,)),
|
57
58
|
(ExampleNode.Outputs, "iota", (str,)),
|
58
59
|
(ExampleClass, "kappa", (Any,)),
|
60
|
+
(ExampleClass, "mu", (list[str],)),
|
59
61
|
],
|
60
62
|
ids=[
|
61
63
|
"str",
|
@@ -71,6 +73,7 @@ class ExampleNode(BaseNode):
|
|
71
73
|
"inherited_parent_class_var",
|
72
74
|
"try_node_output",
|
73
75
|
"any",
|
76
|
+
"list_str_generic",
|
74
77
|
],
|
75
78
|
)
|
76
79
|
def test_infer_types(cls, attr_name, expected_type):
|
@@ -80,9 +83,9 @@ def test_infer_types(cls, attr_name, expected_type):
|
|
80
83
|
@pytest.mark.parametrize(
|
81
84
|
"cls, expected_attr_names",
|
82
85
|
[
|
83
|
-
(ExampleClass, {"alpha", "beta", "gamma", "epsilon", "zeta", "eta", "kappa"}),
|
86
|
+
(ExampleClass, {"alpha", "beta", "gamma", "epsilon", "zeta", "eta", "kappa", "mu"}),
|
84
87
|
(ExampleGenericClass, {"delta"}),
|
85
|
-
(ExampleInheritedClass, {"alpha", "beta", "gamma", "epsilon", "zeta", "eta", "theta", "kappa"}),
|
88
|
+
(ExampleInheritedClass, {"alpha", "beta", "gamma", "epsilon", "zeta", "eta", "theta", "kappa", "mu"}),
|
86
89
|
],
|
87
90
|
)
|
88
91
|
def test_class_attr_names(cls, expected_attr_names):
|
vellum/workflows/types/utils.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1
1
|
from copy import deepcopy
|
2
2
|
from datetime import datetime
|
3
3
|
import importlib
|
4
|
+
from types import GenericAlias
|
4
5
|
from typing import (
|
5
6
|
Any,
|
6
7
|
ClassVar,
|
@@ -77,6 +78,9 @@ def infer_types(object_: Type, attr_name: str, localns: Optional[Dict[str, Any]]
|
|
77
78
|
return (type_hint,)
|
78
79
|
if isinstance(type_hint, SpecialGenericAlias):
|
79
80
|
return (type_hint,)
|
81
|
+
if isinstance(type_hint, GenericAlias):
|
82
|
+
# In future versions of python, list[str] will be a `GenericAlias`
|
83
|
+
return (cast(Type, type_hint),)
|
80
84
|
if isinstance(type_hint, TypeVar):
|
81
85
|
if type_hint in type_var_mapping:
|
82
86
|
return (type_var_mapping[type_hint],)
|
@@ -188,6 +188,7 @@ class BaseWorkflow(Generic[WorkflowInputsType, StateType], metaclass=_BaseWorkfl
|
|
188
188
|
external_inputs: Optional[ExternalInputsArg] = None,
|
189
189
|
cancel_signal: Optional[ThreadingEvent] = None,
|
190
190
|
node_output_mocks: Optional[List[BaseOutputs]] = None,
|
191
|
+
max_concurrency: Optional[int] = None,
|
191
192
|
) -> TerminalWorkflowEvent:
|
192
193
|
"""
|
193
194
|
Invoke a Workflow, returning the last event emitted, which should be one of:
|
@@ -215,6 +216,11 @@ class BaseWorkflow(Generic[WorkflowInputsType, StateType], metaclass=_BaseWorkfl
|
|
215
216
|
|
216
217
|
node_output_mocks: Optional[List[Outputs]] = None
|
217
218
|
A list of Outputs to mock for Nodes during Workflow Execution.
|
219
|
+
|
220
|
+
max_concurrency: Optional[int] = None
|
221
|
+
The max number of concurrent threads to run the Workflow with. If not provided, the Workflow will run
|
222
|
+
without limiting concurrency. This configuration only applies to the current Workflow and not to any
|
223
|
+
subworkflows or nodes that utilizes threads.
|
218
224
|
"""
|
219
225
|
|
220
226
|
events = WorkflowRunner(
|
@@ -226,6 +232,7 @@ class BaseWorkflow(Generic[WorkflowInputsType, StateType], metaclass=_BaseWorkfl
|
|
226
232
|
cancel_signal=cancel_signal,
|
227
233
|
node_output_mocks=node_output_mocks,
|
228
234
|
parent_context=self._context.parent_context,
|
235
|
+
max_concurrency=max_concurrency,
|
229
236
|
).stream()
|
230
237
|
first_event: Optional[Union[WorkflowExecutionInitiatedEvent, WorkflowExecutionResumedEvent]] = None
|
231
238
|
last_event = None
|
@@ -289,6 +296,7 @@ class BaseWorkflow(Generic[WorkflowInputsType, StateType], metaclass=_BaseWorkfl
|
|
289
296
|
external_inputs: Optional[ExternalInputsArg] = None,
|
290
297
|
cancel_signal: Optional[ThreadingEvent] = None,
|
291
298
|
node_output_mocks: Optional[List[BaseOutputs]] = None,
|
299
|
+
max_concurrency: Optional[int] = None,
|
292
300
|
) -> WorkflowEventStream:
|
293
301
|
"""
|
294
302
|
Invoke a Workflow, yielding events as they are emitted.
|
@@ -317,6 +325,11 @@ class BaseWorkflow(Generic[WorkflowInputsType, StateType], metaclass=_BaseWorkfl
|
|
317
325
|
|
318
326
|
node_output_mocks: Optional[List[Outputs]] = None
|
319
327
|
A list of Outputs to mock for Nodes during Workflow Execution.
|
328
|
+
|
329
|
+
max_concurrency: Optional[int] = None
|
330
|
+
The max number of concurrent threads to run the Workflow with. If not provided, the Workflow will run
|
331
|
+
without limiting concurrency. This configuration only applies to the current Workflow and not to any
|
332
|
+
subworkflows or nodes that utilizes threads.
|
320
333
|
"""
|
321
334
|
|
322
335
|
should_yield = event_filter or workflow_event_filter
|
@@ -329,6 +342,7 @@ class BaseWorkflow(Generic[WorkflowInputsType, StateType], metaclass=_BaseWorkfl
|
|
329
342
|
cancel_signal=cancel_signal,
|
330
343
|
node_output_mocks=node_output_mocks,
|
331
344
|
parent_context=self.context.parent_context,
|
345
|
+
max_concurrency=max_concurrency,
|
332
346
|
).stream():
|
333
347
|
if should_yield(self.__class__, event):
|
334
348
|
yield event
|