vellum-ai 0.14.19__py3-none-any.whl → 0.14.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37) hide show
  1. vellum/client/core/client_wrapper.py +1 -1
  2. vellum/workflows/nodes/displayable/code_execution_node/node.py +1 -1
  3. vellum/workflows/nodes/displayable/code_execution_node/tests/test_code_execution_node.py +41 -0
  4. vellum/workflows/nodes/displayable/code_execution_node/utils.py +6 -1
  5. vellum/workflows/references/lazy.py +9 -1
  6. vellum/workflows/references/tests/test_lazy.py +30 -0
  7. {vellum_ai-0.14.19.dist-info → vellum_ai-0.14.20.dist-info}/METADATA +1 -1
  8. {vellum_ai-0.14.19.dist-info → vellum_ai-0.14.20.dist-info}/RECORD +37 -35
  9. vellum_ee/workflows/display/base.py +6 -2
  10. vellum_ee/workflows/display/nodes/base_node_display.py +27 -2
  11. vellum_ee/workflows/display/nodes/base_node_vellum_display.py +0 -20
  12. vellum_ee/workflows/display/nodes/get_node_display_class.py +3 -3
  13. vellum_ee/workflows/display/nodes/vellum/code_execution_node.py +7 -3
  14. vellum_ee/workflows/display/nodes/vellum/inline_prompt_node.py +2 -6
  15. vellum_ee/workflows/display/nodes/vellum/prompt_deployment_node.py +2 -6
  16. vellum_ee/workflows/display/nodes/vellum/tests/test_code_execution_node.py +113 -0
  17. vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/conftest.py +2 -2
  18. vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_adornments_serialization.py +4 -4
  19. vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_attributes_serialization.py +8 -8
  20. vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_outputs_serialization.py +3 -3
  21. vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_ports_serialization.py +14 -14
  22. vellum_ee/workflows/display/tests/workflow_serialization/generic_nodes/test_trigger_serialization.py +3 -3
  23. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_conditional_node_serialization.py +5 -5
  24. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_default_state_serialization.py +1 -1
  25. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_generic_node_serialization.py +1 -1
  26. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_inline_subworkflow_serialization.py +2 -2
  27. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_map_node_serialization.py +2 -2
  28. vellum_ee/workflows/display/tests/workflow_serialization/test_basic_merge_node_serialization.py +3 -3
  29. vellum_ee/workflows/display/types.py +4 -7
  30. vellum_ee/workflows/display/vellum.py +10 -2
  31. vellum_ee/workflows/display/workflows/base_workflow_display.py +60 -32
  32. vellum_ee/workflows/display/workflows/vellum_workflow_display.py +33 -78
  33. vellum_ee/workflows/server/virtual_file_loader.py +52 -22
  34. vellum_ee/workflows/tests/test_server.py +61 -0
  35. {vellum_ai-0.14.19.dist-info → vellum_ai-0.14.20.dist-info}/LICENSE +0 -0
  36. {vellum_ai-0.14.19.dist-info → vellum_ai-0.14.20.dist-info}/WHEEL +0 -0
  37. {vellum_ai-0.14.19.dist-info → vellum_ai-0.14.20.dist-info}/entry_points.txt +0 -0
@@ -12,15 +12,14 @@ from vellum.workflows.edges import Edge
12
12
  from vellum.workflows.events.workflow import NodeEventDisplayContext, WorkflowEventDisplayContext
13
13
  from vellum.workflows.expressions.coalesce_expression import CoalesceExpression
14
14
  from vellum.workflows.nodes.bases import BaseNode
15
- from vellum.workflows.nodes.utils import get_wrapped_node
15
+ from vellum.workflows.nodes.utils import get_unadorned_node, get_unadorned_port, get_wrapped_node
16
16
  from vellum.workflows.ports import Port
17
17
  from vellum.workflows.references import OutputReference, StateValueReference, WorkflowInputReference
18
18
  from vellum.workflows.types.core import JsonObject
19
19
  from vellum.workflows.types.generics import WorkflowType
20
20
  from vellum.workflows.utils.uuids import uuid4_from_hash
21
21
  from vellum_ee.workflows.display.base import (
22
- EdgeDisplayOverridesType,
23
- EdgeDisplayType,
22
+ EdgeDisplay,
24
23
  EntrypointDisplayOverridesType,
25
24
  EntrypointDisplayType,
26
25
  StateValueDisplayOverridesType,
@@ -31,11 +30,13 @@ from vellum_ee.workflows.display.base import (
31
30
  WorkflowMetaDisplayType,
32
31
  WorkflowOutputDisplay,
33
32
  )
33
+ from vellum_ee.workflows.display.nodes.base_node_display import BaseNodeDisplay
34
34
  from vellum_ee.workflows.display.nodes.base_node_vellum_display import BaseNodeVellumDisplay
35
35
  from vellum_ee.workflows.display.nodes.get_node_display_class import get_node_display_class
36
36
  from vellum_ee.workflows.display.nodes.types import NodeOutputDisplay, PortDisplay, PortDisplayOverrides
37
37
  from vellum_ee.workflows.display.nodes.utils import raise_if_descriptor
38
- from vellum_ee.workflows.display.types import NodeDisplayType, WorkflowDisplayContext
38
+ from vellum_ee.workflows.display.types import WorkflowDisplayContext
39
+ from vellum_ee.workflows.display.vellum import EdgeVellumDisplay
39
40
  from vellum_ee.workflows.display.workflows.get_vellum_workflow_display_class import get_workflow_display
40
41
 
41
42
  logger = logging.getLogger(__name__)
@@ -50,11 +51,8 @@ class BaseWorkflowDisplay(
50
51
  WorkflowInputsDisplayOverridesType,
51
52
  StateValueDisplayType,
52
53
  StateValueDisplayOverridesType,
53
- NodeDisplayType,
54
54
  EntrypointDisplayType,
55
55
  EntrypointDisplayOverridesType,
56
- EdgeDisplayType,
57
- EdgeDisplayOverridesType,
58
56
  ]
59
57
  ):
60
58
  # Used to specify the display data for a workflow.
@@ -73,7 +71,7 @@ class BaseWorkflowDisplay(
73
71
  output_displays: Dict[BaseDescriptor, WorkflowOutputDisplay] = {}
74
72
 
75
73
  # Used to explicitly specify display data for a workflow's edges.
76
- edge_displays: Dict[Tuple[Port, Type[BaseNode]], EdgeDisplayOverridesType] = {}
74
+ edge_displays: Dict[Tuple[Port, Type[BaseNode]], EdgeDisplay] = {}
77
75
 
78
76
  # Used to explicitly specify display data for a workflow's ports.
79
77
  port_displays: Dict[Port, PortDisplayOverrides] = {}
@@ -94,9 +92,7 @@ class BaseWorkflowDisplay(
94
92
  WorkflowMetaDisplayType,
95
93
  WorkflowInputsDisplayType,
96
94
  StateValueDisplayType,
97
- NodeDisplayType,
98
95
  EntrypointDisplayType,
99
- EdgeDisplayType,
100
96
  ]
101
97
  ] = None,
102
98
  dry_run: bool = False,
@@ -124,7 +120,7 @@ class BaseWorkflowDisplay(
124
120
 
125
121
  @property
126
122
  @abstractmethod
127
- def node_display_base_class(self) -> Type[NodeDisplayType]:
123
+ def node_display_base_class(self) -> Type[BaseNodeDisplay]:
128
124
  pass
129
125
 
130
126
  def add_error(self, error: Exception) -> None:
@@ -141,7 +137,7 @@ class BaseWorkflowDisplay(
141
137
  def _enrich_global_node_output_displays(
142
138
  self,
143
139
  node: Type[BaseNode],
144
- node_display: NodeDisplayType,
140
+ node_display: BaseNodeDisplay,
145
141
  node_output_displays: Dict[OutputReference, Tuple[Type[BaseNode], NodeOutputDisplay]],
146
142
  ):
147
143
  """This method recursively adds nodes wrapped in decorators to the node_output_displays dictionary."""
@@ -162,7 +158,7 @@ class BaseWorkflowDisplay(
162
158
  def _enrich_node_port_displays(
163
159
  self,
164
160
  node: Type[BaseNode],
165
- node_display: NodeDisplayType,
161
+ node_display: BaseNodeDisplay,
166
162
  port_displays: Dict[Port, PortDisplay],
167
163
  ):
168
164
  """This method recursively adds nodes wrapped in decorators to the port_displays dictionary."""
@@ -178,7 +174,7 @@ class BaseWorkflowDisplay(
178
174
 
179
175
  port_displays[port] = node_display.get_node_port_display(port)
180
176
 
181
- def _get_node_display(self, node: Type[BaseNode]) -> NodeDisplayType:
177
+ def _get_node_display(self, node: Type[BaseNode]) -> BaseNodeDisplay:
182
178
  node_display_class = get_node_display_class(self.node_display_base_class, node)
183
179
  node_display = node_display_class()
184
180
 
@@ -194,9 +190,7 @@ class BaseWorkflowDisplay(
194
190
  WorkflowMetaDisplayType,
195
191
  WorkflowInputsDisplayType,
196
192
  StateValueDisplayType,
197
- NodeDisplayType,
198
193
  EntrypointDisplayType,
199
- EdgeDisplayType,
200
194
  ]:
201
195
  workflow_display = self._generate_workflow_meta_display()
202
196
 
@@ -204,9 +198,9 @@ class BaseWorkflowDisplay(
204
198
  copy(self._parent_display_context.global_node_output_displays) if self._parent_display_context else {}
205
199
  )
206
200
 
207
- node_displays: Dict[Type[BaseNode], NodeDisplayType] = {}
201
+ node_displays: Dict[Type[BaseNode], BaseNodeDisplay] = {}
208
202
 
209
- global_node_displays: Dict[Type[BaseNode], NodeDisplayType] = (
203
+ global_node_displays: Dict[Type[BaseNode], BaseNodeDisplay] = (
210
204
  copy(self._parent_display_context.global_node_displays) if self._parent_display_context else {}
211
205
  )
212
206
 
@@ -273,7 +267,7 @@ class BaseWorkflowDisplay(
273
267
  entrypoint, workflow_display, node_displays, overrides=entrypoint_display_overrides
274
268
  )
275
269
 
276
- edge_displays: Dict[Tuple[Port, Type[BaseNode]], EdgeDisplayType] = {}
270
+ edge_displays: Dict[Tuple[Port, Type[BaseNode]], EdgeVellumDisplay] = {}
277
271
  for edge in self._workflow.get_edges():
278
272
  if edge in edge_displays:
279
273
  continue
@@ -347,7 +341,7 @@ class BaseWorkflowDisplay(
347
341
  self,
348
342
  entrypoint: Type[BaseNode],
349
343
  workflow_display: WorkflowMetaDisplayType,
350
- node_displays: Dict[Type[BaseNode], NodeDisplayType],
344
+ node_displays: Dict[Type[BaseNode], BaseNodeDisplay],
351
345
  overrides: Optional[EntrypointDisplayOverridesType] = None,
352
346
  ) -> EntrypointDisplayType:
353
347
  pass
@@ -357,16 +351,6 @@ class BaseWorkflowDisplay(
357
351
 
358
352
  return WorkflowOutputDisplay(id=output_id, name=output.name)
359
353
 
360
- @abstractmethod
361
- def _generate_edge_display(
362
- self,
363
- edge: Edge,
364
- node_displays: Dict[Type[BaseNode], NodeDisplayType],
365
- port_displays: Dict[Port, PortDisplay],
366
- overrides: Optional[EdgeDisplayOverridesType] = None,
367
- ) -> EdgeDisplayType:
368
- pass
369
-
370
354
  def __init_subclass__(cls, **kwargs: Any) -> None:
371
355
  super().__init_subclass__(**kwargs)
372
356
 
@@ -443,9 +427,9 @@ class BaseWorkflowDisplay(
443
427
  )
444
428
  return display_meta
445
429
 
446
- def _extract_node_displays(self, node: Type[BaseNode]) -> Dict[Type[BaseNode], NodeDisplayType]:
430
+ def _extract_node_displays(self, node: Type[BaseNode]) -> Dict[Type[BaseNode], BaseNodeDisplay]:
447
431
  node_display = self._get_node_display(node)
448
- additional_node_displays: Dict[Type[BaseNode], NodeDisplayType] = {
432
+ additional_node_displays: Dict[Type[BaseNode], BaseNodeDisplay] = {
449
433
  node: node_display,
450
434
  }
451
435
 
@@ -459,3 +443,47 @@ class BaseWorkflowDisplay(
459
443
  additional_node_displays[node] = display
460
444
 
461
445
  return additional_node_displays
446
+
447
+ def _generate_edge_display(
448
+ self,
449
+ edge: Edge,
450
+ node_displays: Dict[Type[BaseNode], BaseNodeDisplay],
451
+ port_displays: Dict[Port, PortDisplay],
452
+ overrides: Optional[EdgeDisplay] = None,
453
+ ) -> EdgeVellumDisplay:
454
+ source_node = get_unadorned_node(edge.from_port.node_class)
455
+ target_node = get_unadorned_node(edge.to_node)
456
+
457
+ source_node_id = node_displays[source_node].node_id
458
+ from_port = get_unadorned_port(edge.from_port)
459
+ source_handle_id = port_displays[from_port].id
460
+
461
+ target_node_display = node_displays[target_node]
462
+ target_node_id = target_node_display.node_id
463
+ target_handle_id = target_node_display.get_target_handle_id_by_source_node_id(source_node_id)
464
+
465
+ return self._generate_edge_display_from_source(
466
+ source_node_id, source_handle_id, target_node_id, target_handle_id, overrides
467
+ )
468
+
469
+ def _generate_edge_display_from_source(
470
+ self,
471
+ source_node_id: UUID,
472
+ source_handle_id: UUID,
473
+ target_node_id: UUID,
474
+ target_handle_id: UUID,
475
+ overrides: Optional[EdgeDisplay] = None,
476
+ ) -> EdgeVellumDisplay:
477
+ edge_id: UUID
478
+ if overrides:
479
+ edge_id = overrides.id
480
+ else:
481
+ edge_id = uuid4_from_hash(f"{self.workflow_id}|id|{source_node_id}|{target_node_id}")
482
+
483
+ return EdgeVellumDisplay(
484
+ id=edge_id,
485
+ source_node_id=source_node_id,
486
+ target_node_id=target_node_id,
487
+ source_handle_id=source_handle_id,
488
+ target_handle_id=target_handle_id,
489
+ )
@@ -1,14 +1,12 @@
1
1
  import logging
2
2
  from uuid import UUID
3
- from typing import Dict, List, Optional, Type, cast
3
+ from typing import Dict, Optional, Type, cast
4
4
 
5
5
  from vellum.workflows.descriptors.base import BaseDescriptor
6
- from vellum.workflows.edges import Edge
7
6
  from vellum.workflows.nodes.bases import BaseNode
8
7
  from vellum.workflows.nodes.displayable.bases.utils import primitive_to_vellum_value
9
8
  from vellum.workflows.nodes.displayable.final_output_node import FinalOutputNode
10
9
  from vellum.workflows.nodes.utils import get_unadorned_node, get_unadorned_port
11
- from vellum.workflows.ports import Port
12
10
  from vellum.workflows.references import WorkflowInputReference
13
11
  from vellum.workflows.references.output import OutputReference
14
12
  from vellum.workflows.types.core import JsonArray, JsonObject
@@ -16,12 +14,9 @@ from vellum.workflows.types.generics import WorkflowType
16
14
  from vellum.workflows.utils.uuids import uuid4_from_hash
17
15
  from vellum_ee.workflows.display.nodes.base_node_display import BaseNodeDisplay
18
16
  from vellum_ee.workflows.display.nodes.base_node_vellum_display import BaseNodeVellumDisplay
19
- from vellum_ee.workflows.display.nodes.types import PortDisplay
20
17
  from vellum_ee.workflows.display.nodes.vellum.utils import create_node_input
21
18
  from vellum_ee.workflows.display.utils.vellum import infer_vellum_variable_type
22
19
  from vellum_ee.workflows.display.vellum import (
23
- EdgeVellumDisplay,
24
- EdgeVellumDisplayOverrides,
25
20
  EntrypointVellumDisplay,
26
21
  EntrypointVellumDisplayOverrides,
27
22
  NodeDisplayData,
@@ -46,11 +41,8 @@ class VellumWorkflowDisplay(
46
41
  WorkflowInputsVellumDisplayOverrides,
47
42
  StateValueVellumDisplay,
48
43
  StateValueVellumDisplayOverrides,
49
- BaseNodeDisplay,
50
44
  EntrypointVellumDisplay,
51
45
  EntrypointVellumDisplayOverrides,
52
- EdgeVellumDisplay,
53
- EdgeVellumDisplayOverrides,
54
46
  ]
55
47
  ):
56
48
  node_display_base_class = BaseNodeDisplay
@@ -99,14 +91,16 @@ class VellumWorkflowDisplay(
99
91
  edges: JsonArray = []
100
92
 
101
93
  # Add a single synthetic node for the workflow entrypoint
94
+ entrypoint_node_id = self.display_context.workflow_display.entrypoint_node_id
95
+ entrypoint_node_source_handle_id = self.display_context.workflow_display.entrypoint_node_source_handle_id
102
96
  nodes.append(
103
97
  {
104
- "id": str(self.display_context.workflow_display.entrypoint_node_id),
98
+ "id": str(entrypoint_node_id),
105
99
  "type": "ENTRYPOINT",
106
100
  "inputs": [],
107
101
  "data": {
108
102
  "label": "Entrypoint Node",
109
- "source_handle_id": str(self.display_context.workflow_display.entrypoint_node_source_handle_id),
103
+ "source_handle_id": str(entrypoint_node_source_handle_id),
110
104
  },
111
105
  "display_data": self.display_context.workflow_display.entrypoint_node_display.dict(),
112
106
  "base": None,
@@ -245,25 +239,37 @@ class VellumWorkflowDisplay(
245
239
  )
246
240
 
247
241
  # Add an edge for each edge in the workflow
248
- all_edge_displays: List[EdgeVellumDisplay] = [
249
- # Create a synthetic edge from the synthetic entrypoint node to each actual entrypoint
250
- *[
251
- entrypoint_display.edge_display
252
- for entrypoint_display in self.display_context.entrypoint_displays.values()
253
- ],
254
- # Include the concrete edges in the workflow
255
- *self.display_context.edge_displays.values(),
256
- ]
242
+ for target_node, entrypoint_display in self.display_context.entrypoint_displays.items():
243
+ unadorned_target_node = get_unadorned_node(target_node)
244
+ target_node_display = self.display_context.node_displays[unadorned_target_node]
245
+ edges.append(
246
+ {
247
+ "id": str(entrypoint_display.edge_display.id),
248
+ "source_node_id": str(entrypoint_node_id),
249
+ "source_handle_id": str(entrypoint_node_source_handle_id),
250
+ "target_node_id": str(target_node_display.node_id),
251
+ "target_handle_id": str(target_node_display.get_trigger_id()),
252
+ "type": "DEFAULT",
253
+ }
254
+ )
255
+
256
+ for (source_node_port, target_node), edge_display in self.display_context.edge_displays.items():
257
+ unadorned_source_node_port = get_unadorned_port(source_node_port)
258
+ unadorned_target_node = get_unadorned_node(target_node)
259
+
260
+ source_node_port_display = self.display_context.port_displays[unadorned_source_node_port]
261
+ target_node_display = self.display_context.node_displays[unadorned_target_node]
257
262
 
258
- for edge_display in all_edge_displays:
259
263
  edges.append(
260
264
  {
261
265
  "id": str(edge_display.id),
262
- "source_node_id": str(edge_display.source_node_id),
263
- "source_handle_id": str(edge_display.source_handle_id),
264
- "target_node_id": str(edge_display.target_node_id),
265
- "target_handle_id": str(edge_display.target_handle_id),
266
- "type": edge_display.type,
266
+ "source_node_id": str(source_node_port_display.node_id),
267
+ "source_handle_id": str(source_node_port_display.id),
268
+ "target_node_id": str(target_node_display.node_id),
269
+ "target_handle_id": str(
270
+ target_node_display.get_target_handle_id_by_source_node_id(source_node_port_display.node_id)
271
+ ),
272
+ "type": "DEFAULT",
267
273
  }
268
274
  )
269
275
 
@@ -357,61 +363,10 @@ class VellumWorkflowDisplay(
357
363
  entrypoint_target = get_unadorned_node(entrypoint)
358
364
  target_node_display = node_displays[entrypoint_target]
359
365
  target_node_id = target_node_display.node_id
360
- if isinstance(target_node_display, BaseNodeVellumDisplay):
361
- target_handle_id = target_node_display.get_target_handle_id_by_source_node_id(entrypoint_node_id)
362
- else:
363
- target_handle_id = target_node_display.get_trigger_id()
366
+ target_handle_id = target_node_display.get_target_handle_id_by_source_node_id(entrypoint_node_id)
364
367
 
365
368
  edge_display = self._generate_edge_display_from_source(
366
369
  entrypoint_node_id, source_handle_id, target_node_id, target_handle_id, overrides=edge_display_overrides
367
370
  )
368
371
 
369
372
  return EntrypointVellumDisplay(id=entrypoint_id, edge_display=edge_display)
370
-
371
- def _generate_edge_display(
372
- self,
373
- edge: Edge,
374
- node_displays: Dict[Type[BaseNode], BaseNodeDisplay],
375
- port_displays: Dict[Port, PortDisplay],
376
- overrides: Optional[EdgeVellumDisplayOverrides] = None,
377
- ) -> EdgeVellumDisplay:
378
- source_node = get_unadorned_node(edge.from_port.node_class)
379
- target_node = get_unadorned_node(edge.to_node)
380
-
381
- source_node_id = node_displays[source_node].node_id
382
- from_port = get_unadorned_port(edge.from_port)
383
- source_handle_id = port_displays[from_port].id
384
-
385
- target_node_display = node_displays[target_node]
386
- target_node_id = target_node_display.node_id
387
-
388
- if isinstance(target_node_display, BaseNodeVellumDisplay):
389
- target_handle_id = target_node_display.get_target_handle_id_by_source_node_id(source_node_id)
390
- else:
391
- target_handle_id = target_node_display.get_trigger_id()
392
-
393
- return self._generate_edge_display_from_source(
394
- source_node_id, source_handle_id, target_node_id, target_handle_id, overrides
395
- )
396
-
397
- def _generate_edge_display_from_source(
398
- self,
399
- source_node_id: UUID,
400
- source_handle_id: UUID,
401
- target_node_id: UUID,
402
- target_handle_id: UUID,
403
- overrides: Optional[EdgeVellumDisplayOverrides] = None,
404
- ) -> EdgeVellumDisplay:
405
- edge_id: UUID
406
- if overrides:
407
- edge_id = overrides.id
408
- else:
409
- edge_id = uuid4_from_hash(f"{self.workflow_id}|id|{source_node_id}|{target_node_id}")
410
-
411
- return EdgeVellumDisplay(
412
- id=edge_id,
413
- source_node_id=source_node_id,
414
- target_node_id=target_node_id,
415
- source_handle_id=source_handle_id,
416
- target_handle_id=target_handle_id,
417
- )
@@ -1,42 +1,72 @@
1
1
  import importlib
2
+ import re
3
+ from typing import Optional
2
4
 
3
5
 
4
6
  class VirtualFileLoader(importlib.abc.Loader):
5
- def __init__(self, code: str, is_package: bool):
6
- self.code = code
7
- self.is_package = is_package
7
+ def __init__(self, files: dict[str, str], namespace: str):
8
+ self.files = files
9
+ self.namespace = namespace
8
10
 
9
11
  def create_module(self, spec):
10
12
  return None # use default module creation
11
13
 
12
14
  def exec_module(self, module):
13
- if not self.is_package or self.code:
14
- exec(self.code, module.__dict__)
15
+ module_info = self._resolve_module(module.__spec__.origin)
15
16
 
17
+ if module_info:
18
+ file_path, code = module_info
19
+ compiled = compile(code, file_path, "exec")
20
+ exec(compiled, module.__dict__)
16
21
 
17
- class VirtualFileFinder(importlib.abc.MetaPathFinder, importlib.abc.Loader):
18
- def __init__(self, files: dict[str, str], namespace: str):
19
- self.files = files
20
- self.namespace = namespace
22
+ def get_source(self, fullname):
23
+ """
24
+ `inspect` module uses this method to get the source code of a module.
25
+ """
26
+
27
+ module_info = self._resolve_module(fullname)
28
+ if module_info:
29
+ return module_info[1]
30
+
31
+ return None
32
+
33
+ def _resolve_module(self, fullname: str) -> Optional[tuple[str, str]]:
34
+ file_path = self._get_file_path(fullname)
35
+ code = self._get_code(file_path)
36
+
37
+ if code is not None:
38
+ return file_path, code
21
39
 
22
- def find_spec(self, fullname, path, target=None):
23
- # Do the namespacing on the fly to avoid having to copy the file dict
24
- prefixed_name = fullname if fullname.startswith(self.namespace) else f"{self.namespace}.{fullname}"
40
+ if not file_path.endswith("__init__.py"):
41
+ file_path = re.sub(r"\.py$", "/__init__.py", file_path)
42
+ code = self._get_code(file_path)
25
43
 
26
- key_name = "__init__" if fullname == self.namespace else fullname.replace(f"{self.namespace}.", "")
44
+ if code is not None:
45
+ return file_path, code
46
+
47
+ return None
27
48
 
28
- files_key = f"{key_name.replace('.', '/')}.py"
29
- if self.files.get(files_key) is None:
30
- files_key = f"{key_name.replace('.', '/')}/__init__.py"
49
+ def _get_file_path(self, fullname):
50
+ return f"{fullname.replace('.', '/')}.py"
31
51
 
32
- file = self.files.get(files_key)
33
- is_package = "__init__" in files_key
52
+ def _get_code(self, file_path):
53
+ file_key_name = re.sub(r"^" + re.escape(self.namespace) + r"/", "", file_path)
54
+ return self.files.get(file_key_name)
55
+
56
+
57
+ class VirtualFileFinder(importlib.abc.MetaPathFinder, importlib.abc.Loader):
58
+ def __init__(self, files: dict[str, str], namespace: str):
59
+ self.loader = VirtualFileLoader(files, namespace)
34
60
 
35
- if file is not None:
61
+ def find_spec(self, fullname: str, path, target=None):
62
+ module_info = self.loader._resolve_module(fullname)
63
+ if module_info:
64
+ file_path, _ = module_info
65
+ is_package = file_path.endswith("__init__.py")
36
66
  return importlib.machinery.ModuleSpec(
37
- prefixed_name,
38
- VirtualFileLoader(file, is_package),
39
- origin=prefixed_name,
67
+ fullname,
68
+ self.loader,
69
+ origin=fullname,
40
70
  is_package=is_package,
41
71
  )
42
72
  return None
@@ -1,4 +1,11 @@
1
+ import sys
2
+ from uuid import uuid4
3
+ from typing import Type, cast
4
+
1
5
  from vellum.client.core.pydantic_utilities import UniversalBaseModel
6
+ from vellum.workflows import BaseWorkflow
7
+ from vellum.workflows.nodes import BaseNode
8
+ from vellum_ee.workflows.server.virtual_file_loader import VirtualFileFinder
2
9
 
3
10
 
4
11
  def test_load_workflow_event_display_context():
@@ -8,3 +15,57 @@ def test_load_workflow_event_display_context():
8
15
  # We are actually just ensuring there are no circular dependencies when
9
16
  # our Workflow Server imports this class.
10
17
  assert issubclass(WorkflowEventDisplayContext, UniversalBaseModel)
18
+
19
+
20
+ def test_load_from_module__lazy_reference_in_file_loader():
21
+ # GIVEN a workflow module with a node containing a lazy reference
22
+ files = {
23
+ "__init__.py": "",
24
+ "workflow.py": """\
25
+ from vellum.workflows import BaseWorkflow
26
+ from .nodes.start_node import StartNode
27
+
28
+ class Workflow(BaseWorkflow):
29
+ graph = StartNode
30
+ """,
31
+ "nodes/__init__.py": """\
32
+ from .start_node import StartNode
33
+
34
+ __all__ = [
35
+ "StartNode",
36
+ ]
37
+ """,
38
+ "nodes/start_node.py": """\
39
+ from vellum.workflows.nodes import BaseNode
40
+ from vellum.workflows.references import LazyReference
41
+
42
+ class StartNode(BaseNode):
43
+ foo = LazyReference(lambda: StartNode.Outputs.bar)
44
+
45
+ class Outputs(BaseNode.Outputs):
46
+ bar = str
47
+ """,
48
+ }
49
+
50
+ namespace = str(uuid4())
51
+
52
+ # AND the virtual file loader is registered
53
+ sys.meta_path.append(VirtualFileFinder(files, namespace))
54
+
55
+ # WHEN the workflow is loaded
56
+ Workflow = BaseWorkflow.load_from_module(namespace)
57
+ workflow = Workflow()
58
+
59
+ # THEN the workflow is successfully initialized
60
+ assert workflow
61
+
62
+ # AND the graph is just a BaseNode
63
+ # ideally this would be true, but the loader uses a different BaseNode class definition than
64
+ # the one in this test module.
65
+ # assert isinstance(workflow.graph, BaseNode)
66
+ start_node = cast(Type[BaseNode], workflow.graph)
67
+ assert start_node.__bases__ == (BaseNode,)
68
+
69
+ # AND the lazy reference has the correct name
70
+ assert start_node.foo.instance
71
+ assert start_node.foo.instance.name == "StartNode.Outputs.bar"