tracdap-runtime 0.6.3__py3-none-any.whl → 0.6.5__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in the public registry.
- tracdap/rt/_exec/context.py +572 -112
- tracdap/rt/_exec/dev_mode.py +166 -97
- tracdap/rt/_exec/engine.py +120 -9
- tracdap/rt/_exec/functions.py +137 -35
- tracdap/rt/_exec/graph.py +38 -13
- tracdap/rt/_exec/graph_builder.py +120 -9
- tracdap/rt/_impl/data.py +183 -52
- tracdap/rt/_impl/grpc/tracdap/metadata/data_pb2.py +18 -18
- tracdap/rt/_impl/grpc/tracdap/metadata/job_pb2.py +74 -30
- tracdap/rt/_impl/grpc/tracdap/metadata/job_pb2.pyi +120 -2
- tracdap/rt/_impl/grpc/tracdap/metadata/model_pb2.py +20 -18
- tracdap/rt/_impl/grpc/tracdap/metadata/model_pb2.pyi +22 -6
- tracdap/rt/_impl/grpc/tracdap/metadata/resource_pb2.py +29 -0
- tracdap/rt/_impl/grpc/tracdap/metadata/resource_pb2.pyi +16 -0
- tracdap/rt/_impl/models.py +8 -0
- tracdap/rt/_impl/static_api.py +42 -10
- tracdap/rt/_impl/storage.py +37 -25
- tracdap/rt/_impl/validation.py +113 -11
- tracdap/rt/_plugins/repo_git.py +1 -1
- tracdap/rt/_version.py +1 -1
- tracdap/rt/api/experimental.py +220 -0
- tracdap/rt/api/hook.py +6 -4
- tracdap/rt/api/model_api.py +98 -13
- tracdap/rt/api/static_api.py +14 -6
- tracdap/rt/config/__init__.py +2 -2
- tracdap/rt/config/common.py +23 -17
- tracdap/rt/config/job.py +2 -2
- tracdap/rt/config/platform.py +25 -25
- tracdap/rt/config/result.py +2 -2
- tracdap/rt/config/runtime.py +3 -3
- tracdap/rt/launch/cli.py +7 -4
- tracdap/rt/launch/launch.py +19 -3
- tracdap/rt/metadata/__init__.py +25 -20
- tracdap/rt/metadata/common.py +2 -2
- tracdap/rt/metadata/custom.py +3 -3
- tracdap/rt/metadata/data.py +12 -12
- tracdap/rt/metadata/file.py +6 -6
- tracdap/rt/metadata/flow.py +6 -6
- tracdap/rt/metadata/job.py +62 -8
- tracdap/rt/metadata/model.py +33 -11
- tracdap/rt/metadata/object_id.py +8 -8
- tracdap/rt/metadata/resource.py +24 -0
- tracdap/rt/metadata/search.py +5 -5
- tracdap/rt/metadata/stoarge.py +6 -6
- tracdap/rt/metadata/tag.py +1 -1
- tracdap/rt/metadata/tag_update.py +1 -1
- tracdap/rt/metadata/type.py +4 -4
- {tracdap_runtime-0.6.3.dist-info → tracdap_runtime-0.6.5.dist-info}/METADATA +3 -1
- {tracdap_runtime-0.6.3.dist-info → tracdap_runtime-0.6.5.dist-info}/RECORD +52 -48
- {tracdap_runtime-0.6.3.dist-info → tracdap_runtime-0.6.5.dist-info}/LICENSE +0 -0
- {tracdap_runtime-0.6.3.dist-info → tracdap_runtime-0.6.5.dist-info}/WHEEL +0 -0
- {tracdap_runtime-0.6.3.dist-info → tracdap_runtime-0.6.5.dist-info}/top_level.txt +0 -0
tracdap/rt/_exec/functions.py
CHANGED
@@ -23,6 +23,7 @@ import tracdap.rt.api as _api
 import tracdap.rt.config as _config
 import tracdap.rt.exceptions as _ex
 import tracdap.rt._exec.context as _ctx
+import tracdap.rt._exec.graph_builder as _graph
 import tracdap.rt._impl.config_parser as _cfg_p # noqa
 import tracdap.rt._impl.type_system as _types # noqa
 import tracdap.rt._impl.data as _data # noqa
@@ -59,6 +60,13 @@ class NodeContext:
     pass


+class NodeCallback:
+
+    @abc.abstractmethod
+    def send_graph_updates(self, new_nodes: tp.Dict[NodeId, Node], new_deps: tp.Dict[NodeId, tp.List[Dependency]]):
+        pass
+
+
 # Helper functions to access the node context (in case the NodeContext interface needs to change)

 def _ctx_lookup(node_id: NodeId[_T], ctx: NodeContext) -> _T:
@@ -89,8 +97,15 @@ class NodeFunction(tp.Generic[_T]):
     :py:class:`NodeContext <NodeContext>`
     """

-    def
-
+    def __init__(self):
+        self.node_callback: tp.Optional[NodeCallback] = None
+
+    def __call__(self, ctx: NodeContext, callback: NodeCallback = None) -> _T:
+        try:
+            self.node_callback = callback
+            return self._execute(ctx)
+        finally:
+            self.node_callback = None

     @abc.abstractmethod
     def _execute(self, ctx: NodeContext) -> _T:
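The hunk above, together with the new NodeCallback class earlier in this file, changes the node execution protocol: NodeFunction.__call__ now accepts an optional NodeCallback, exposes it to _execute as self.node_callback, and clears it when the call completes. Below is a minimal caller-side sketch of that protocol; the CollectingCallback class and the run_node helper are illustrative only and are not part of the package.

```python
import typing as tp

from tracdap.rt._exec.functions import NodeCallback, NodeContext, NodeFunction
from tracdap.rt._exec.graph import Dependency, Node, NodeId


class CollectingCallback(NodeCallback):

    """Illustrative callback that records any graph updates emitted while a node runs"""

    def __init__(self):
        self.new_nodes: tp.Dict[NodeId, Node] = {}
        self.new_deps: tp.Dict[NodeId, tp.List[Dependency]] = {}

    def send_graph_updates(self, new_nodes, new_deps):
        self.new_nodes.update(new_nodes)
        self.new_deps.update(new_deps)


def run_node(func: NodeFunction, ctx: NodeContext):

    callback = CollectingCallback()

    # __call__ sets func.node_callback for the duration of _execute, then clears it
    result = func(ctx, callback)

    return result, callback.new_nodes, callback.new_deps
```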
@@ -105,6 +120,7 @@ class NodeFunction(tp.Generic[_T]):
 class NoopFunc(NodeFunction[None]):

     def __init__(self, node: NoopNode):
+        super().__init__()
         self.node = node

     def _execute(self, _: NodeContext) -> None:
@@ -114,6 +130,7 @@ class NoopFunc(NodeFunction[None]):
 class StaticValueFunc(NodeFunction[_T]):

     def __init__(self, node: StaticValueNode[_T]):
+        super().__init__()
         self.node = node

     def _execute(self, ctx: NodeContext) -> _T:
@@ -123,6 +140,7 @@ class StaticValueFunc(NodeFunction[_T]):
 class IdentityFunc(NodeFunction[_T]):

     def __init__(self, node: IdentityNode[_T]):
+        super().__init__()
         self.node = node

     def _execute(self, ctx: NodeContext) -> _T:
@@ -138,6 +156,7 @@ class _ContextPushPopFunc(NodeFunction[Bundle[tp.Any]], abc.ABC):
     _POP = False

     def __init__(self, node: tp.Union[ContextPushNode, ContextPopNode], direction: bool):
+        super().__init__()
         self.node = node
         self.direction = direction

@@ -176,6 +195,7 @@ class ContextPopFunc(_ContextPushPopFunc):
 class KeyedItemFunc(NodeFunction[_T]):

     def __init__(self, node: KeyedItemNode[_T]):
+        super().__init__()
         self.node = node

     def _execute(self, ctx: NodeContext) -> _T:
@@ -184,9 +204,20 @@ class KeyedItemFunc(NodeFunction[_T]):
         return src_item


+class RuntimeOutputsFunc(NodeFunction[JobOutputs]):
+
+    def __init__(self, node: RuntimeOutputsNode):
+        super().__init__()
+        self.node = node
+
+    def _execute(self, ctx: NodeContext) -> JobOutputs:
+        return self.node.outputs
+
+
 class BuildJobResultFunc(NodeFunction[_config.JobResult]):

     def __init__(self, node: BuildJobResultNode):
+        super().__init__()
         self.node = node

     def _execute(self, ctx: NodeContext) -> _config.JobResult:
@@ -197,20 +228,33 @@ class BuildJobResultFunc(NodeFunction[_config.JobResult]):

         # TODO: Handle individual failed results

-        for obj_id, node_id in self.node.objects.items():
+        for obj_id, node_id in self.node.outputs.objects.items():
             obj_def = _ctx_lookup(node_id, ctx)
             job_result.results[obj_id] = obj_def

-        for bundle_id in self.node.bundles:
+        for bundle_id in self.node.outputs.bundles:
             bundle = _ctx_lookup(bundle_id, ctx)
             job_result.results.update(bundle.items())

+        if self.node.runtime_outputs is not None:
+
+            runtime_outputs = _ctx_lookup(self.node.runtime_outputs, ctx)
+
+            for obj_id, node_id in runtime_outputs.objects.items():
+                obj_def = _ctx_lookup(node_id, ctx)
+                job_result.results[obj_id] = obj_def
+
+            for bundle_id in runtime_outputs.bundles:
+                bundle = _ctx_lookup(bundle_id, ctx)
+                job_result.results.update(bundle.items())
+
         return job_result


 class SaveJobResultFunc(NodeFunction[None]):

     def __init__(self, node: SaveJobResultNode):
+        super().__init__()
         self.node = node

     def _execute(self, ctx: NodeContext) -> None:
@@ -241,6 +285,7 @@ class SaveJobResultFunc(NodeFunction[None]):
 class DataViewFunc(NodeFunction[_data.DataView]):

     def __init__(self, node: DataViewNode):
+        super().__init__()
         self.node = node

     def _execute(self, ctx: NodeContext) -> _data.DataView:
@@ -252,7 +297,13 @@ class DataViewFunc(NodeFunction[_data.DataView]):
         if root_item.is_empty():
             return _data.DataView.create_empty()

-
+        if self.node.schema is not None and len(self.node.schema.table.fields) > 0:
+            trac_schema = self.node.schema
+        else:
+            arrow_schema = root_item.schema
+            trac_schema = _data.DataMapping.arrow_to_trac_schema(arrow_schema)
+
+        data_view = _data.DataView.for_trac_schema(trac_schema)
         data_view = _data.DataMapping.add_item_to_view(data_view, root_part_key, root_item)

         return data_view
@@ -261,6 +312,7 @@ class DataViewFunc(NodeFunction[_data.DataView]):
 class DataItemFunc(NodeFunction[_data.DataItem]):

     def __init__(self, node: DataItemNode):
+        super().__init__()
         self.node = node

     def _execute(self, ctx: NodeContext) -> _data.DataItem:
@@ -284,6 +336,7 @@ class DataItemFunc(NodeFunction[_data.DataItem]):
 class DataResultFunc(NodeFunction[ObjectBundle]):

     def __init__(self, node: DataResultNode):
+        super().__init__()
         self.node = node

     def _execute(self, ctx: NodeContext) -> ObjectBundle:
@@ -318,6 +371,7 @@ class DynamicDataSpecFunc(NodeFunction[_data.DataSpec]):
     RANDOM.seed()

     def __init__(self, node: DynamicDataSpecNode, storage: _storage.StorageManager):
+        super().__init__()
         self.node = node
         self.storage = storage

@@ -428,7 +482,7 @@ class _LoadSaveDataFunc(abc.ABC):
         return copy_


-class LoadDataFunc(NodeFunction[_data.DataItem], _LoadSaveDataFunc):
+class LoadDataFunc( _LoadSaveDataFunc, NodeFunction[_data.DataItem],):

     def __init__(self, node: LoadDataNode, storage: _storage.StorageManager):
         super().__init__(storage)
@@ -457,7 +511,7 @@ class LoadDataFunc(NodeFunction[_data.DataItem], _LoadSaveDataFunc):
         return _data.DataItem(table.schema, table)


-class SaveDataFunc(NodeFunction[None]
+class SaveDataFunc(_LoadSaveDataFunc, NodeFunction[None]):

     def __init__(self, node: SaveDataNode, storage: _storage.StorageManager):
         super().__init__(storage)
@@ -512,6 +566,7 @@ def _model_def_for_import(import_details: meta.ImportModelJob):
 class ImportModelFunc(NodeFunction[meta.ObjectDefinition]):

     def __init__(self, node: ImportModelNode, models: _models.ModelLoader):
+        super().__init__()
         self.node = node
         self._models = models

@@ -529,11 +584,17 @@ class ImportModelFunc(NodeFunction[meta.ObjectDefinition]):

 class RunModelFunc(NodeFunction[Bundle[_data.DataView]]):

-    def __init__(
+    def __init__(
+            self, node: RunModelNode,
+            model_class: _api.TracModel.__class__,
+            checkout_directory: pathlib.Path,
+            storage_manager: _storage.StorageManager):
+
         super().__init__()
         self.node = node
         self.model_class = model_class
         self.checkout_directory = checkout_directory
+        self.storage_manager = storage_manager

     def _execute(self, ctx: NodeContext) -> Bundle[_data.DataView]:

@@ -544,36 +605,37 @@ class RunModelFunc(NodeFunction[Bundle[_data.DataView]]):
         # Still, if any nodes are missing or have the wrong type TracContextImpl will raise ERuntimeValidation

         local_ctx = {}
-
+        dynamic_outputs = []

         for node_id, node_result in _ctx_iter_items(ctx):
+            if node_id.namespace == self.node.id.namespace:
+                if node_id.name in model_def.parameters or node_id.name in model_def.inputs:
+                    local_ctx[node_id.name] = node_result

-
-                continue
-
-            if node_id.name in model_def.parameters:
-                param_name = node_id.name
-                local_ctx[param_name] = node_result
-
-            if node_id.name in model_def.inputs:
-                input_name = node_id.name
-                local_ctx[input_name] = node_result
-                # At the moment, all model inputs have static schemas
-                static_schemas[input_name] = model_def.inputs[input_name].schema
+        # Set up access to external storage if required

-
-        # Assuming outputs are all defined with static schemas
+        storage_map = {}

-
-
-
-
-
-
+        if self.node.storage_access:
+            write_access = True if self.node.model_def.modelType == meta.ModelType.DATA_EXPORT_MODEL else False
+            for storage_key in self.node.storage_access:
+                if self.storage_manager.has_file_storage(storage_key, external=True):
+                    storage_impl = self.storage_manager.get_file_storage(storage_key, external=True)
+                    storage = _ctx.TracFileStorageImpl(storage_key, storage_impl, write_access, self.checkout_directory)
+                    storage_map[storage_key] = storage

         # Run the model against the mapped local context

-
+        if model_def.modelType in [meta.ModelType.DATA_IMPORT_MODEL, meta.ModelType.DATA_EXPORT_MODEL]:
+            trac_ctx = _ctx.TracDataContextImpl(
+                self.node.model_def, self.model_class,
+                local_ctx, dynamic_outputs, storage_map,
+                self.checkout_directory)
+        else:
+            trac_ctx = _ctx.TracContextImpl(
+                self.node.model_def, self.model_class,
+                local_ctx, dynamic_outputs,
+                self.checkout_directory)

         try:
             model = self.model_class()
@@ -587,20 +649,50 @@ class RunModelFunc(NodeFunction[Bundle[_data.DataView]]):

         # Check required outputs are present and build the results bundle

+        model_name = self.model_class.__name__
         results: Bundle[_data.DataView] = dict()
+        new_nodes = dict()
+        new_deps = dict()

         for output_name, output_schema in model_def.outputs.items():

             result: _data.DataView = local_ctx.get(output_name)

             if result is None or result.is_empty():
+
                 if not output_schema.optional:
-                    model_name = self.model_class.__name__
                     raise _ex.ERuntimeValidation(f"Missing required output [{output_name}] from model [{model_name}]")

-
+                # Create a placeholder for optional outputs that were not emitted
+                elif result is None:
+                    result = _data.DataView.create_empty()
+
+            results[output_name] = result
+
+        if dynamic_outputs:
+
+            for output_name in dynamic_outputs:
+
+                result: _data.DataView = local_ctx.get(output_name)
+
+                if result is None or result.is_empty():
+                    raise _ex.ERuntimeValidation(f"No data provided for [{output_name}] from model [{model_name}]")
+
                 results[output_name] = result

+                result_node_id = NodeId.of(output_name, self.node.id.namespace, _data.DataView)
+                result_node = BundleItemNode(result_node_id, self.node.id, output_name)
+
+                new_nodes[result_node_id] = result_node
+
+            output_section = _graph.GraphBuilder.build_runtime_outputs(dynamic_outputs, self.node.id.namespace)
+            new_nodes.update(output_section.nodes)
+
+            ctx_id = NodeId.of("trac_build_result", self.node.id.namespace, result_type=None)
+            new_deps[ctx_id] = list(_graph.Dependency(nid, _graph.DependencyType.HARD) for nid in output_section.outputs)
+
+            self.node_callback.send_graph_updates(new_nodes, new_deps)
+
         return results

@@ -624,6 +716,14 @@ class FunctionResolver:
     :py:class:`NodeFunction <NodeFunction>`
     """

+    # TODO: Validate consistency for resource keys
+    # Storage key should be validated for load data, save data and run model with storage access
+    # Repository key should be validated for import model (and explicitly for run model)
+
+    # Currently jobs with missing resources will fail at runtime, with a suitable error
+    # The resolver is called during graph building
+    # Putting the check here will raise a consistency error before the job starts processing
+
     __ResolveFunc = tp.Callable[['FunctionResolver', Node[_T]], NodeFunction[_T]]

     def __init__(self, models: _models.ModelLoader, storage: _storage.StorageManager):
@@ -658,12 +758,13 @@ class FunctionResolver:

     def resolve_run_model_node(self, node: RunModelNode) -> NodeFunction:

+        # TODO: Verify model_class against model_def
+
         model_class = self._models.load_model_class(node.model_scope, node.model_def)
         checkout_directory = self._models.model_load_checkout_directory(node.model_scope, node.model_def)
+        storage_manager = self._storage if node.storage_access else None

-
-
-        return RunModelFunc(node, model_class, checkout_directory)
+        return RunModelFunc(node, model_class, checkout_directory, storage_manager)

     __basic_node_mapping: tp.Dict[Node.__class__, NodeFunction.__class__] = {

@@ -677,6 +778,7 @@ class FunctionResolver:
         SaveJobResultNode: SaveJobResultFunc,
         DataResultNode: DataResultFunc,
         StaticValueNode: StaticValueFunc,
+        RuntimeOutputsNode: RuntimeOutputsFunc,
         BundleItemNode: NoopFunc,
         NoopNode: NoopFunc,
         RunModelResultNode: NoopFunc
tracdap/rt/_exec/graph.py
CHANGED
@@ -12,8 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from __future__ import annotations
-
 import pathlib
 import typing as tp
 import dataclasses as dc
@@ -38,7 +36,7 @@ class NodeNamespace:
         return cls.__ROOT

     name: str
-    parent: tp.Optional[NodeNamespace] = dc.field(default_factory=lambda: NodeNamespace.root())
+    parent: "tp.Optional[NodeNamespace]" = dc.field(default_factory=lambda: NodeNamespace.root())

     def __str__(self):
         if self is self.__ROOT:
@@ -62,7 +60,7 @@ class NodeNamespace:
 class NodeId(tp.Generic[_T]):

     @staticmethod
-    def of(name: str, namespace: NodeNamespace, result_type: tp.Type[_T]) -> NodeId[_T]:
+    def of(name: str, namespace: NodeNamespace, result_type: tp.Type[_T]) -> "NodeId[_T]":
         return NodeId(name, namespace, result_type)

     name: str
@@ -83,8 +81,8 @@ class DependencyType:
     immediate: bool = True
     tolerant: bool = False

-    HARD: tp.ClassVar[DependencyType]
-    TOLERANT: tp.ClassVar[DependencyType]
+    HARD: "tp.ClassVar[DependencyType]"
+    TOLERANT: "tp.ClassVar[DependencyType]"


 DependencyType.HARD = DependencyType(immediate=True, tolerant=False)
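The hunks above drop `from __future__ import annotations` and switch the self-referential annotations to quoted forward references, so the module no longer relies on postponed annotation evaluation. The same pattern in isolation (the TreeNode class below is illustrative, not part of the package):

```python
import dataclasses as dc
import typing as tp


@dc.dataclass(frozen=True)
class TreeNode:

    # A quoted annotation lets a field refer to its own class without
    # `from __future__ import annotations`
    name: str
    parent: "tp.Optional[TreeNode]" = None

    @staticmethod
    def of(name: str, parent: "tp.Optional[TreeNode]" = None) -> "TreeNode":
        return TreeNode(name, parent)
```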
@@ -93,6 +91,13 @@ DependencyType.TOLERANT = DependencyType(immediate=True, tolerant=True)
 DependencyType.DELAYED = DependencyType(immediate=False, tolerant=False)


+@dc.dataclass(frozen=True)
+class Dependency:
+
+    node_id: NodeId
+    dependency_type: DependencyType
+
+
 @dc.dataclass(frozen=True)
 class Node(tp.Generic[_T]):

@@ -165,6 +170,17 @@ class GraphSection:
     must_run: tp.List[NodeId] = dc.field(default_factory=list)


+Bundle: tp.Generic[_T] = tp.Dict[str, _T]
+ObjectBundle = Bundle[meta.ObjectDefinition]
+
+
+@dc.dataclass(frozen=True)
+class JobOutputs:
+
+    objects: tp.Dict[str, NodeId[meta.ObjectDefinition]] = dc.field(default_factory=dict)
+    bundles: tp.List[NodeId[ObjectBundle]] = dc.field(default_factory=list)
+
+
 # TODO: Where does this go?
 @dc.dataclass(frozen=True)
 class JobResultSpec:
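A minimal sketch of how the relocated Bundle / ObjectBundle aliases and the new Dependency and JobOutputs dataclasses fit together; the namespace and node names below are illustrative, not taken from the package:

```python
import tracdap.rt.metadata as meta

from tracdap.rt._exec.graph import (
    Dependency, DependencyType, JobOutputs, NodeId, NodeNamespace, ObjectBundle)

# Illustrative namespace for a job's nodes
job_ns = NodeNamespace("example_job")

# Job outputs are tracked as node IDs that resolve to object definitions, or bundles of them
job_outputs = JobOutputs(
    objects={"model_output": NodeId.of("model_output", job_ns, meta.ObjectDefinition)},
    bundles=[NodeId.of("data_results", job_ns, ObjectBundle)])

# Graph edges discovered at runtime are expressed with the new Dependency dataclass
dep = Dependency(
    node_id=NodeId.of("dynamic_output", job_ns, meta.ObjectDefinition),
    dependency_type=DependencyType.HARD)
```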
@@ -179,10 +195,6 @@ class JobResultSpec:
 # ----------------------------------------------------------------------------------------------------------------------


-Bundle: tp.Generic[_T] = tp.Dict[str, _T]
-ObjectBundle = Bundle[meta.ObjectDefinition]
-
-
 @_node_type
 class NoopNode(Node):
     pass
@@ -354,6 +366,7 @@ class RunModelNode(Node[Bundle[_data.DataView]]):
     model_def: meta.ModelDefinition
     parameter_ids: tp.FrozenSet[NodeId]
     input_ids: tp.FrozenSet[NodeId]
+    storage_access: tp.Optional[tp.List[str]] = None

     def _node_dependencies(self) -> tp.Dict[NodeId, DependencyType]:
         return {dep_id: DependencyType.HARD for dep_id in [*self.parameter_ids, *self.input_ids]}
@@ -368,16 +381,28 @@ class RunModelResultNode(Node[None]):
         return {self.model_id: DependencyType.HARD}


+@_node_type
+class RuntimeOutputsNode(Node[JobOutputs]):
+
+    outputs: JobOutputs
+
+    def _node_dependencies(self) -> tp.Dict[NodeId, DependencyType]:
+        dep_ids = [*self.outputs.bundles, *self.outputs.objects.values()]
+        return {node_id: DependencyType.HARD for node_id in dep_ids}
+
+
 @_node_type
 class BuildJobResultNode(Node[cfg.JobResult]):

     job_id: meta.TagHeader

-
-
+    outputs: JobOutputs
+    runtime_outputs: tp.Optional[NodeId[JobOutputs]] = None

     def _node_dependencies(self) -> tp.Dict[NodeId, DependencyType]:
-        dep_ids = [*self.bundles, *self.objects.values()]
+        dep_ids = [*self.outputs.bundles, *self.outputs.objects.values()]
+        if self.runtime_outputs is not None:
+            dep_ids.append(self.runtime_outputs)
         return {node_id: DependencyType.HARD for node_id in dep_ids}
