tracdap-runtime 0.6.3__py3-none-any.whl → 0.6.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52) hide show
  1. tracdap/rt/_exec/context.py +572 -112
  2. tracdap/rt/_exec/dev_mode.py +166 -97
  3. tracdap/rt/_exec/engine.py +120 -9
  4. tracdap/rt/_exec/functions.py +137 -35
  5. tracdap/rt/_exec/graph.py +38 -13
  6. tracdap/rt/_exec/graph_builder.py +120 -9
  7. tracdap/rt/_impl/data.py +183 -52
  8. tracdap/rt/_impl/grpc/tracdap/metadata/data_pb2.py +18 -18
  9. tracdap/rt/_impl/grpc/tracdap/metadata/job_pb2.py +74 -30
  10. tracdap/rt/_impl/grpc/tracdap/metadata/job_pb2.pyi +120 -2
  11. tracdap/rt/_impl/grpc/tracdap/metadata/model_pb2.py +20 -18
  12. tracdap/rt/_impl/grpc/tracdap/metadata/model_pb2.pyi +22 -6
  13. tracdap/rt/_impl/grpc/tracdap/metadata/resource_pb2.py +29 -0
  14. tracdap/rt/_impl/grpc/tracdap/metadata/resource_pb2.pyi +16 -0
  15. tracdap/rt/_impl/models.py +8 -0
  16. tracdap/rt/_impl/static_api.py +42 -10
  17. tracdap/rt/_impl/storage.py +37 -25
  18. tracdap/rt/_impl/validation.py +113 -11
  19. tracdap/rt/_plugins/repo_git.py +1 -1
  20. tracdap/rt/_version.py +1 -1
  21. tracdap/rt/api/experimental.py +220 -0
  22. tracdap/rt/api/hook.py +6 -4
  23. tracdap/rt/api/model_api.py +98 -13
  24. tracdap/rt/api/static_api.py +14 -6
  25. tracdap/rt/config/__init__.py +2 -2
  26. tracdap/rt/config/common.py +23 -17
  27. tracdap/rt/config/job.py +2 -2
  28. tracdap/rt/config/platform.py +25 -25
  29. tracdap/rt/config/result.py +2 -2
  30. tracdap/rt/config/runtime.py +3 -3
  31. tracdap/rt/launch/cli.py +7 -4
  32. tracdap/rt/launch/launch.py +19 -3
  33. tracdap/rt/metadata/__init__.py +25 -20
  34. tracdap/rt/metadata/common.py +2 -2
  35. tracdap/rt/metadata/custom.py +3 -3
  36. tracdap/rt/metadata/data.py +12 -12
  37. tracdap/rt/metadata/file.py +6 -6
  38. tracdap/rt/metadata/flow.py +6 -6
  39. tracdap/rt/metadata/job.py +62 -8
  40. tracdap/rt/metadata/model.py +33 -11
  41. tracdap/rt/metadata/object_id.py +8 -8
  42. tracdap/rt/metadata/resource.py +24 -0
  43. tracdap/rt/metadata/search.py +5 -5
  44. tracdap/rt/metadata/stoarge.py +6 -6
  45. tracdap/rt/metadata/tag.py +1 -1
  46. tracdap/rt/metadata/tag_update.py +1 -1
  47. tracdap/rt/metadata/type.py +4 -4
  48. {tracdap_runtime-0.6.3.dist-info → tracdap_runtime-0.6.5.dist-info}/METADATA +3 -1
  49. {tracdap_runtime-0.6.3.dist-info → tracdap_runtime-0.6.5.dist-info}/RECORD +52 -48
  50. {tracdap_runtime-0.6.3.dist-info → tracdap_runtime-0.6.5.dist-info}/LICENSE +0 -0
  51. {tracdap_runtime-0.6.3.dist-info → tracdap_runtime-0.6.5.dist-info}/WHEEL +0 -0
  52. {tracdap_runtime-0.6.3.dist-info → tracdap_runtime-0.6.5.dist-info}/top_level.txt +0 -0
@@ -23,6 +23,7 @@ import tracdap.rt.api as _api
23
23
  import tracdap.rt.config as _config
24
24
  import tracdap.rt.exceptions as _ex
25
25
  import tracdap.rt._exec.context as _ctx
26
+ import tracdap.rt._exec.graph_builder as _graph
26
27
  import tracdap.rt._impl.config_parser as _cfg_p # noqa
27
28
  import tracdap.rt._impl.type_system as _types # noqa
28
29
  import tracdap.rt._impl.data as _data # noqa
@@ -59,6 +60,13 @@ class NodeContext:
59
60
  pass
60
61
 
61
62
 
63
+ class NodeCallback:
64
+
65
+ @abc.abstractmethod
66
+ def send_graph_updates(self, new_nodes: tp.Dict[NodeId, Node], new_deps: tp.Dict[NodeId, tp.List[Dependency]]):
67
+ pass
68
+
69
+
62
70
  # Helper functions to access the node context (in case the NodeContext interface needs to change)
63
71
 
64
72
  def _ctx_lookup(node_id: NodeId[_T], ctx: NodeContext) -> _T:
@@ -89,8 +97,15 @@ class NodeFunction(tp.Generic[_T]):
89
97
  :py:class:`NodeContext <NodeContext>`
90
98
  """
91
99
 
92
- def __call__(self, ctx: NodeContext) -> _T:
93
- return self._execute(ctx)
100
+ def __init__(self):
101
+ self.node_callback: tp.Optional[NodeCallback] = None
102
+
103
+ def __call__(self, ctx: NodeContext, callback: NodeCallback = None) -> _T:
104
+ try:
105
+ self.node_callback = callback
106
+ return self._execute(ctx)
107
+ finally:
108
+ self.node_callback = None
94
109
 
95
110
  @abc.abstractmethod
96
111
  def _execute(self, ctx: NodeContext) -> _T:
@@ -105,6 +120,7 @@ class NodeFunction(tp.Generic[_T]):
105
120
  class NoopFunc(NodeFunction[None]):
106
121
 
107
122
  def __init__(self, node: NoopNode):
123
+ super().__init__()
108
124
  self.node = node
109
125
 
110
126
  def _execute(self, _: NodeContext) -> None:
@@ -114,6 +130,7 @@ class NoopFunc(NodeFunction[None]):
114
130
  class StaticValueFunc(NodeFunction[_T]):
115
131
 
116
132
  def __init__(self, node: StaticValueNode[_T]):
133
+ super().__init__()
117
134
  self.node = node
118
135
 
119
136
  def _execute(self, ctx: NodeContext) -> _T:
@@ -123,6 +140,7 @@ class StaticValueFunc(NodeFunction[_T]):
123
140
  class IdentityFunc(NodeFunction[_T]):
124
141
 
125
142
  def __init__(self, node: IdentityNode[_T]):
143
+ super().__init__()
126
144
  self.node = node
127
145
 
128
146
  def _execute(self, ctx: NodeContext) -> _T:
@@ -138,6 +156,7 @@ class _ContextPushPopFunc(NodeFunction[Bundle[tp.Any]], abc.ABC):
138
156
  _POP = False
139
157
 
140
158
  def __init__(self, node: tp.Union[ContextPushNode, ContextPopNode], direction: bool):
159
+ super().__init__()
141
160
  self.node = node
142
161
  self.direction = direction
143
162
 
@@ -176,6 +195,7 @@ class ContextPopFunc(_ContextPushPopFunc):
176
195
  class KeyedItemFunc(NodeFunction[_T]):
177
196
 
178
197
  def __init__(self, node: KeyedItemNode[_T]):
198
+ super().__init__()
179
199
  self.node = node
180
200
 
181
201
  def _execute(self, ctx: NodeContext) -> _T:
@@ -184,9 +204,20 @@ class KeyedItemFunc(NodeFunction[_T]):
184
204
  return src_item
185
205
 
186
206
 
207
+ class RuntimeOutputsFunc(NodeFunction[JobOutputs]):
208
+
209
+ def __init__(self, node: RuntimeOutputsNode):
210
+ super().__init__()
211
+ self.node = node
212
+
213
+ def _execute(self, ctx: NodeContext) -> JobOutputs:
214
+ return self.node.outputs
215
+
216
+
187
217
  class BuildJobResultFunc(NodeFunction[_config.JobResult]):
188
218
 
189
219
  def __init__(self, node: BuildJobResultNode):
220
+ super().__init__()
190
221
  self.node = node
191
222
 
192
223
  def _execute(self, ctx: NodeContext) -> _config.JobResult:
@@ -197,20 +228,33 @@ class BuildJobResultFunc(NodeFunction[_config.JobResult]):
197
228
 
198
229
  # TODO: Handle individual failed results
199
230
 
200
- for obj_id, node_id in self.node.objects.items():
231
+ for obj_id, node_id in self.node.outputs.objects.items():
201
232
  obj_def = _ctx_lookup(node_id, ctx)
202
233
  job_result.results[obj_id] = obj_def
203
234
 
204
- for bundle_id in self.node.bundles:
235
+ for bundle_id in self.node.outputs.bundles:
205
236
  bundle = _ctx_lookup(bundle_id, ctx)
206
237
  job_result.results.update(bundle.items())
207
238
 
239
+ if self.node.runtime_outputs is not None:
240
+
241
+ runtime_outputs = _ctx_lookup(self.node.runtime_outputs, ctx)
242
+
243
+ for obj_id, node_id in runtime_outputs.objects.items():
244
+ obj_def = _ctx_lookup(node_id, ctx)
245
+ job_result.results[obj_id] = obj_def
246
+
247
+ for bundle_id in runtime_outputs.bundles:
248
+ bundle = _ctx_lookup(bundle_id, ctx)
249
+ job_result.results.update(bundle.items())
250
+
208
251
  return job_result
209
252
 
210
253
 
211
254
  class SaveJobResultFunc(NodeFunction[None]):
212
255
 
213
256
  def __init__(self, node: SaveJobResultNode):
257
+ super().__init__()
214
258
  self.node = node
215
259
 
216
260
  def _execute(self, ctx: NodeContext) -> None:
@@ -241,6 +285,7 @@ class SaveJobResultFunc(NodeFunction[None]):
241
285
  class DataViewFunc(NodeFunction[_data.DataView]):
242
286
 
243
287
  def __init__(self, node: DataViewNode):
288
+ super().__init__()
244
289
  self.node = node
245
290
 
246
291
  def _execute(self, ctx: NodeContext) -> _data.DataView:
@@ -252,7 +297,13 @@ class DataViewFunc(NodeFunction[_data.DataView]):
252
297
  if root_item.is_empty():
253
298
  return _data.DataView.create_empty()
254
299
 
255
- data_view = _data.DataView.for_trac_schema(self.node.schema)
300
+ if self.node.schema is not None and len(self.node.schema.table.fields) > 0:
301
+ trac_schema = self.node.schema
302
+ else:
303
+ arrow_schema = root_item.schema
304
+ trac_schema = _data.DataMapping.arrow_to_trac_schema(arrow_schema)
305
+
306
+ data_view = _data.DataView.for_trac_schema(trac_schema)
256
307
  data_view = _data.DataMapping.add_item_to_view(data_view, root_part_key, root_item)
257
308
 
258
309
  return data_view
@@ -261,6 +312,7 @@ class DataViewFunc(NodeFunction[_data.DataView]):
261
312
  class DataItemFunc(NodeFunction[_data.DataItem]):
262
313
 
263
314
  def __init__(self, node: DataItemNode):
315
+ super().__init__()
264
316
  self.node = node
265
317
 
266
318
  def _execute(self, ctx: NodeContext) -> _data.DataItem:
@@ -284,6 +336,7 @@ class DataItemFunc(NodeFunction[_data.DataItem]):
284
336
  class DataResultFunc(NodeFunction[ObjectBundle]):
285
337
 
286
338
  def __init__(self, node: DataResultNode):
339
+ super().__init__()
287
340
  self.node = node
288
341
 
289
342
  def _execute(self, ctx: NodeContext) -> ObjectBundle:
@@ -318,6 +371,7 @@ class DynamicDataSpecFunc(NodeFunction[_data.DataSpec]):
318
371
  RANDOM.seed()
319
372
 
320
373
  def __init__(self, node: DynamicDataSpecNode, storage: _storage.StorageManager):
374
+ super().__init__()
321
375
  self.node = node
322
376
  self.storage = storage
323
377
 
@@ -428,7 +482,7 @@ class _LoadSaveDataFunc(abc.ABC):
428
482
  return copy_
429
483
 
430
484
 
431
- class LoadDataFunc(NodeFunction[_data.DataItem], _LoadSaveDataFunc):
485
+ class LoadDataFunc( _LoadSaveDataFunc, NodeFunction[_data.DataItem],):
432
486
 
433
487
  def __init__(self, node: LoadDataNode, storage: _storage.StorageManager):
434
488
  super().__init__(storage)
@@ -457,7 +511,7 @@ class LoadDataFunc(NodeFunction[_data.DataItem], _LoadSaveDataFunc):
457
511
  return _data.DataItem(table.schema, table)
458
512
 
459
513
 
460
- class SaveDataFunc(NodeFunction[None], _LoadSaveDataFunc):
514
+ class SaveDataFunc(_LoadSaveDataFunc, NodeFunction[None]):
461
515
 
462
516
  def __init__(self, node: SaveDataNode, storage: _storage.StorageManager):
463
517
  super().__init__(storage)
@@ -512,6 +566,7 @@ def _model_def_for_import(import_details: meta.ImportModelJob):
512
566
  class ImportModelFunc(NodeFunction[meta.ObjectDefinition]):
513
567
 
514
568
  def __init__(self, node: ImportModelNode, models: _models.ModelLoader):
569
+ super().__init__()
515
570
  self.node = node
516
571
  self._models = models
517
572
 
@@ -529,11 +584,17 @@ class ImportModelFunc(NodeFunction[meta.ObjectDefinition]):
529
584
 
530
585
  class RunModelFunc(NodeFunction[Bundle[_data.DataView]]):
531
586
 
532
- def __init__(self, node: RunModelNode, model_class: _api.TracModel.__class__, checkout_directory: pathlib.Path):
587
+ def __init__(
588
+ self, node: RunModelNode,
589
+ model_class: _api.TracModel.__class__,
590
+ checkout_directory: pathlib.Path,
591
+ storage_manager: _storage.StorageManager):
592
+
533
593
  super().__init__()
534
594
  self.node = node
535
595
  self.model_class = model_class
536
596
  self.checkout_directory = checkout_directory
597
+ self.storage_manager = storage_manager
537
598
 
538
599
  def _execute(self, ctx: NodeContext) -> Bundle[_data.DataView]:
539
600
 
@@ -544,36 +605,37 @@ class RunModelFunc(NodeFunction[Bundle[_data.DataView]]):
544
605
  # Still, if any nodes are missing or have the wrong type TracContextImpl will raise ERuntimeValidation
545
606
 
546
607
  local_ctx = {}
547
- static_schemas = {}
608
+ dynamic_outputs = []
548
609
 
549
610
  for node_id, node_result in _ctx_iter_items(ctx):
611
+ if node_id.namespace == self.node.id.namespace:
612
+ if node_id.name in model_def.parameters or node_id.name in model_def.inputs:
613
+ local_ctx[node_id.name] = node_result
550
614
 
551
- if node_id.namespace != self.node.id.namespace:
552
- continue
553
-
554
- if node_id.name in model_def.parameters:
555
- param_name = node_id.name
556
- local_ctx[param_name] = node_result
557
-
558
- if node_id.name in model_def.inputs:
559
- input_name = node_id.name
560
- local_ctx[input_name] = node_result
561
- # At the moment, all model inputs have static schemas
562
- static_schemas[input_name] = model_def.inputs[input_name].schema
615
+ # Set up access to external storage if required
563
616
 
564
- # Add empty data views to the local context to hold model outputs
565
- # Assuming outputs are all defined with static schemas
617
+ storage_map = {}
566
618
 
567
- for output_name in model_def.outputs:
568
- output_schema = self.node.model_def.outputs[output_name].schema
569
- empty_data_view = _data.DataView.for_trac_schema(output_schema)
570
- local_ctx[output_name] = empty_data_view
571
- # At the moment, all model outputs have static schemas
572
- static_schemas[output_name] = output_schema
619
+ if self.node.storage_access:
620
+ write_access = True if self.node.model_def.modelType == meta.ModelType.DATA_EXPORT_MODEL else False
621
+ for storage_key in self.node.storage_access:
622
+ if self.storage_manager.has_file_storage(storage_key, external=True):
623
+ storage_impl = self.storage_manager.get_file_storage(storage_key, external=True)
624
+ storage = _ctx.TracFileStorageImpl(storage_key, storage_impl, write_access, self.checkout_directory)
625
+ storage_map[storage_key] = storage
573
626
 
574
627
  # Run the model against the mapped local context
575
628
 
576
- trac_ctx = _ctx.TracContextImpl(self.node.model_def, self.model_class, local_ctx, static_schemas)
629
+ if model_def.modelType in [meta.ModelType.DATA_IMPORT_MODEL, meta.ModelType.DATA_EXPORT_MODEL]:
630
+ trac_ctx = _ctx.TracDataContextImpl(
631
+ self.node.model_def, self.model_class,
632
+ local_ctx, dynamic_outputs, storage_map,
633
+ self.checkout_directory)
634
+ else:
635
+ trac_ctx = _ctx.TracContextImpl(
636
+ self.node.model_def, self.model_class,
637
+ local_ctx, dynamic_outputs,
638
+ self.checkout_directory)
577
639
 
578
640
  try:
579
641
  model = self.model_class()
@@ -587,20 +649,50 @@ class RunModelFunc(NodeFunction[Bundle[_data.DataView]]):
587
649
 
588
650
  # Check required outputs are present and build the results bundle
589
651
 
652
+ model_name = self.model_class.__name__
590
653
  results: Bundle[_data.DataView] = dict()
654
+ new_nodes = dict()
655
+ new_deps = dict()
591
656
 
592
657
  for output_name, output_schema in model_def.outputs.items():
593
658
 
594
659
  result: _data.DataView = local_ctx.get(output_name)
595
660
 
596
661
  if result is None or result.is_empty():
662
+
597
663
  if not output_schema.optional:
598
- model_name = self.model_class.__name__
599
664
  raise _ex.ERuntimeValidation(f"Missing required output [{output_name}] from model [{model_name}]")
600
665
 
601
- if result is not None:
666
+ # Create a placeholder for optional outputs that were not emitted
667
+ elif result is None:
668
+ result = _data.DataView.create_empty()
669
+
670
+ results[output_name] = result
671
+
672
+ if dynamic_outputs:
673
+
674
+ for output_name in dynamic_outputs:
675
+
676
+ result: _data.DataView = local_ctx.get(output_name)
677
+
678
+ if result is None or result.is_empty():
679
+ raise _ex.ERuntimeValidation(f"No data provided for [{output_name}] from model [{model_name}]")
680
+
602
681
  results[output_name] = result
603
682
 
683
+ result_node_id = NodeId.of(output_name, self.node.id.namespace, _data.DataView)
684
+ result_node = BundleItemNode(result_node_id, self.node.id, output_name)
685
+
686
+ new_nodes[result_node_id] = result_node
687
+
688
+ output_section = _graph.GraphBuilder.build_runtime_outputs(dynamic_outputs, self.node.id.namespace)
689
+ new_nodes.update(output_section.nodes)
690
+
691
+ ctx_id = NodeId.of("trac_build_result", self.node.id.namespace, result_type=None)
692
+ new_deps[ctx_id] = list(_graph.Dependency(nid, _graph.DependencyType.HARD) for nid in output_section.outputs)
693
+
694
+ self.node_callback.send_graph_updates(new_nodes, new_deps)
695
+
604
696
  return results
605
697
 
606
698
 
@@ -624,6 +716,14 @@ class FunctionResolver:
624
716
  :py:class:`NodeFunction <NodeFunction>`
625
717
  """
626
718
 
719
+ # TODO: Validate consistency for resource keys
720
+ # Storage key should be validated for load data, save data and run model with storage access
721
+ # Repository key should be validated for import model (and explicitly for run model)
722
+
723
+ # Currently jobs with missing resources will fail at runtime, with a suitable error
724
+ # The resolver is called during graph building
725
+ # Putting the check here will raise a consistency error before the job starts processing
726
+
627
727
  __ResolveFunc = tp.Callable[['FunctionResolver', Node[_T]], NodeFunction[_T]]
628
728
 
629
729
  def __init__(self, models: _models.ModelLoader, storage: _storage.StorageManager):
@@ -658,12 +758,13 @@ class FunctionResolver:
658
758
 
659
759
  def resolve_run_model_node(self, node: RunModelNode) -> NodeFunction:
660
760
 
761
+ # TODO: Verify model_class against model_def
762
+
661
763
  model_class = self._models.load_model_class(node.model_scope, node.model_def)
662
764
  checkout_directory = self._models.model_load_checkout_directory(node.model_scope, node.model_def)
765
+ storage_manager = self._storage if node.storage_access else None
663
766
 
664
- # TODO: Verify model_class against model_def
665
-
666
- return RunModelFunc(node, model_class, checkout_directory)
767
+ return RunModelFunc(node, model_class, checkout_directory, storage_manager)
667
768
 
668
769
  __basic_node_mapping: tp.Dict[Node.__class__, NodeFunction.__class__] = {
669
770
 
@@ -677,6 +778,7 @@ class FunctionResolver:
677
778
  SaveJobResultNode: SaveJobResultFunc,
678
779
  DataResultNode: DataResultFunc,
679
780
  StaticValueNode: StaticValueFunc,
781
+ RuntimeOutputsNode: RuntimeOutputsFunc,
680
782
  BundleItemNode: NoopFunc,
681
783
  NoopNode: NoopFunc,
682
784
  RunModelResultNode: NoopFunc
tracdap/rt/_exec/graph.py CHANGED
@@ -12,8 +12,6 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from __future__ import annotations
16
-
17
15
  import pathlib
18
16
  import typing as tp
19
17
  import dataclasses as dc
@@ -38,7 +36,7 @@ class NodeNamespace:
38
36
  return cls.__ROOT
39
37
 
40
38
  name: str
41
- parent: tp.Optional[NodeNamespace] = dc.field(default_factory=lambda: NodeNamespace.root())
39
+ parent: "tp.Optional[NodeNamespace]" = dc.field(default_factory=lambda: NodeNamespace.root())
42
40
 
43
41
  def __str__(self):
44
42
  if self is self.__ROOT:
@@ -62,7 +60,7 @@ class NodeNamespace:
62
60
  class NodeId(tp.Generic[_T]):
63
61
 
64
62
  @staticmethod
65
- def of(name: str, namespace: NodeNamespace, result_type: tp.Type[_T]) -> NodeId[_T]:
63
+ def of(name: str, namespace: NodeNamespace, result_type: tp.Type[_T]) -> "NodeId[_T]":
66
64
  return NodeId(name, namespace, result_type)
67
65
 
68
66
  name: str
@@ -83,8 +81,8 @@ class DependencyType:
83
81
  immediate: bool = True
84
82
  tolerant: bool = False
85
83
 
86
- HARD: tp.ClassVar[DependencyType]
87
- TOLERANT: tp.ClassVar[DependencyType]
84
+ HARD: "tp.ClassVar[DependencyType]"
85
+ TOLERANT: "tp.ClassVar[DependencyType]"
88
86
 
89
87
 
90
88
  DependencyType.HARD = DependencyType(immediate=True, tolerant=False)
@@ -93,6 +91,13 @@ DependencyType.TOLERANT = DependencyType(immediate=True, tolerant=True)
93
91
  DependencyType.DELAYED = DependencyType(immediate=False, tolerant=False)
94
92
 
95
93
 
94
+ @dc.dataclass(frozen=True)
95
+ class Dependency:
96
+
97
+ node_id: NodeId
98
+ dependency_type: DependencyType
99
+
100
+
96
101
  @dc.dataclass(frozen=True)
97
102
  class Node(tp.Generic[_T]):
98
103
 
@@ -165,6 +170,17 @@ class GraphSection:
165
170
  must_run: tp.List[NodeId] = dc.field(default_factory=list)
166
171
 
167
172
 
173
+ Bundle: tp.Generic[_T] = tp.Dict[str, _T]
174
+ ObjectBundle = Bundle[meta.ObjectDefinition]
175
+
176
+
177
+ @dc.dataclass(frozen=True)
178
+ class JobOutputs:
179
+
180
+ objects: tp.Dict[str, NodeId[meta.ObjectDefinition]] = dc.field(default_factory=dict)
181
+ bundles: tp.List[NodeId[ObjectBundle]] = dc.field(default_factory=list)
182
+
183
+
168
184
  # TODO: Where does this go?
169
185
  @dc.dataclass(frozen=True)
170
186
  class JobResultSpec:
@@ -179,10 +195,6 @@ class JobResultSpec:
179
195
  # ----------------------------------------------------------------------------------------------------------------------
180
196
 
181
197
 
182
- Bundle: tp.Generic[_T] = tp.Dict[str, _T]
183
- ObjectBundle = Bundle[meta.ObjectDefinition]
184
-
185
-
186
198
  @_node_type
187
199
  class NoopNode(Node):
188
200
  pass
@@ -354,6 +366,7 @@ class RunModelNode(Node[Bundle[_data.DataView]]):
354
366
  model_def: meta.ModelDefinition
355
367
  parameter_ids: tp.FrozenSet[NodeId]
356
368
  input_ids: tp.FrozenSet[NodeId]
369
+ storage_access: tp.Optional[tp.List[str]] = None
357
370
 
358
371
  def _node_dependencies(self) -> tp.Dict[NodeId, DependencyType]:
359
372
  return {dep_id: DependencyType.HARD for dep_id in [*self.parameter_ids, *self.input_ids]}
@@ -368,16 +381,28 @@ class RunModelResultNode(Node[None]):
368
381
  return {self.model_id: DependencyType.HARD}
369
382
 
370
383
 
384
+ @_node_type
385
+ class RuntimeOutputsNode(Node[JobOutputs]):
386
+
387
+ outputs: JobOutputs
388
+
389
+ def _node_dependencies(self) -> tp.Dict[NodeId, DependencyType]:
390
+ dep_ids = [*self.outputs.bundles, *self.outputs.objects.values()]
391
+ return {node_id: DependencyType.HARD for node_id in dep_ids}
392
+
393
+
371
394
  @_node_type
372
395
  class BuildJobResultNode(Node[cfg.JobResult]):
373
396
 
374
397
  job_id: meta.TagHeader
375
398
 
376
- objects: tp.Dict[str, NodeId[meta.ObjectDefinition]] = dc.field(default_factory=dict)
377
- bundles: tp.List[NodeId[ObjectBundle]] = dc.field(default_factory=list)
399
+ outputs: JobOutputs
400
+ runtime_outputs: tp.Optional[NodeId[JobOutputs]] = None
378
401
 
379
402
  def _node_dependencies(self) -> tp.Dict[NodeId, DependencyType]:
380
- dep_ids = [*self.bundles, *self.objects.values()]
403
+ dep_ids = [*self.outputs.bundles, *self.outputs.objects.values()]
404
+ if self.runtime_outputs is not None:
405
+ dep_ids.append(self.runtime_outputs)
381
406
  return {node_id: DependencyType.HARD for node_id in dep_ids}
382
407
 
383
408