tracdap-runtime 0.6.4__py3-none-any.whl → 0.6.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tracdap/rt/_exec/context.py +556 -36
- tracdap/rt/_exec/dev_mode.py +320 -198
- tracdap/rt/_exec/engine.py +331 -62
- tracdap/rt/_exec/functions.py +151 -22
- tracdap/rt/_exec/graph.py +47 -13
- tracdap/rt/_exec/graph_builder.py +383 -175
- tracdap/rt/_exec/runtime.py +7 -5
- tracdap/rt/_impl/config_parser.py +11 -4
- tracdap/rt/_impl/data.py +329 -152
- tracdap/rt/_impl/ext/__init__.py +13 -0
- tracdap/rt/_impl/ext/sql.py +116 -0
- tracdap/rt/_impl/ext/storage.py +57 -0
- tracdap/rt/_impl/grpc/tracdap/metadata/job_pb2.py +82 -30
- tracdap/rt/_impl/grpc/tracdap/metadata/job_pb2.pyi +155 -2
- tracdap/rt/_impl/grpc/tracdap/metadata/model_pb2.py +12 -10
- tracdap/rt/_impl/grpc/tracdap/metadata/model_pb2.pyi +14 -2
- tracdap/rt/_impl/grpc/tracdap/metadata/resource_pb2.py +29 -0
- tracdap/rt/_impl/grpc/tracdap/metadata/resource_pb2.pyi +16 -0
- tracdap/rt/_impl/models.py +8 -0
- tracdap/rt/_impl/static_api.py +29 -0
- tracdap/rt/_impl/storage.py +39 -27
- tracdap/rt/_impl/util.py +10 -0
- tracdap/rt/_impl/validation.py +140 -18
- tracdap/rt/_plugins/repo_git.py +1 -1
- tracdap/rt/_plugins/storage_sql.py +417 -0
- tracdap/rt/_plugins/storage_sql_dialects.py +117 -0
- tracdap/rt/_version.py +1 -1
- tracdap/rt/api/experimental.py +267 -0
- tracdap/rt/api/hook.py +14 -0
- tracdap/rt/api/model_api.py +48 -6
- tracdap/rt/config/__init__.py +2 -2
- tracdap/rt/config/common.py +6 -0
- tracdap/rt/metadata/__init__.py +29 -20
- tracdap/rt/metadata/job.py +99 -0
- tracdap/rt/metadata/model.py +18 -0
- tracdap/rt/metadata/resource.py +24 -0
- {tracdap_runtime-0.6.4.dist-info → tracdap_runtime-0.6.6.dist-info}/METADATA +5 -1
- {tracdap_runtime-0.6.4.dist-info → tracdap_runtime-0.6.6.dist-info}/RECORD +41 -32
- {tracdap_runtime-0.6.4.dist-info → tracdap_runtime-0.6.6.dist-info}/WHEEL +1 -1
- {tracdap_runtime-0.6.4.dist-info → tracdap_runtime-0.6.6.dist-info}/LICENSE +0 -0
- {tracdap_runtime-0.6.4.dist-info → tracdap_runtime-0.6.6.dist-info}/top_level.txt +0 -0
tracdap/rt/_exec/functions.py
CHANGED
@@ -23,6 +23,7 @@ import tracdap.rt.api as _api
|
|
23
23
|
import tracdap.rt.config as _config
|
24
24
|
import tracdap.rt.exceptions as _ex
|
25
25
|
import tracdap.rt._exec.context as _ctx
|
26
|
+
import tracdap.rt._exec.graph_builder as _graph
|
26
27
|
import tracdap.rt._impl.config_parser as _cfg_p # noqa
|
27
28
|
import tracdap.rt._impl.type_system as _types # noqa
|
28
29
|
import tracdap.rt._impl.data as _data # noqa
|
@@ -59,6 +60,13 @@ class NodeContext:
|
|
59
60
|
pass
|
60
61
|
|
61
62
|
|
63
|
+
class NodeCallback(abc.ABC):

    """
    Callback interface made available to a node function while it is executing.

    The class derives from abc.ABC so that the abstract method below is actually
    enforced: a subclass that does not implement send_graph_updates() cannot be
    instantiated. Without the ABC base, @abc.abstractmethod has no effect.
    """

    @abc.abstractmethod
    def send_graph_updates(
            self,
            new_nodes: "tp.Dict[NodeId, Node]",
            new_deps: "tp.Dict[NodeId, tp.List[Dependency]]"):

        """
        Report nodes and dependencies that were created while a node function was running.

        :param new_nodes: Newly created nodes, keyed by their node ID
        :param new_deps: For each node ID, a list of additional dependencies for that node
        """

        pass
|
68
|
+
|
69
|
+
|
62
70
|
# Helper functions to access the node context (in case the NodeContext interface needs to change)
|
63
71
|
|
64
72
|
def _ctx_lookup(node_id: NodeId[_T], ctx: NodeContext) -> _T:
|
@@ -89,8 +97,15 @@ class NodeFunction(tp.Generic[_T]):
|
|
89
97
|
:py:class:`NodeContext <NodeContext>`
|
90
98
|
"""
|
91
99
|
|
92
|
-
def
|
93
|
-
|
100
|
+
def __init__(self):
    # Callback for the invocation currently in progress, None whenever no call is active
    self.node_callback: tp.Optional[NodeCallback] = None

def __call__(self, ctx: NodeContext, callback: tp.Optional[NodeCallback] = None) -> _T:

    """
    Run the node function against the supplied context.

    The callback is exposed to the implementation as self.node_callback for the
    duration of the call. It is always reset to None afterwards, including when
    _execute() raises, so a callback never outlives the call it was supplied for.

    :param ctx: Node context passed through to _execute()
    :param callback: Optional callback for this invocation, may be None
    :return: Whatever _execute() returns
    """

    self.node_callback = callback

    try:
        return self._execute(ctx)
    finally:
        self.node_callback = None
|
94
109
|
|
95
110
|
@abc.abstractmethod
|
96
111
|
def _execute(self, ctx: NodeContext) -> _T:
|
@@ -105,6 +120,7 @@ class NodeFunction(tp.Generic[_T]):
|
|
105
120
|
class NoopFunc(NodeFunction[None]):
|
106
121
|
|
107
122
|
def __init__(self, node: NoopNode):
|
123
|
+
super().__init__()
|
108
124
|
self.node = node
|
109
125
|
|
110
126
|
def _execute(self, _: NodeContext) -> None:
|
@@ -114,6 +130,7 @@ class NoopFunc(NodeFunction[None]):
|
|
114
130
|
class StaticValueFunc(NodeFunction[_T]):
|
115
131
|
|
116
132
|
def __init__(self, node: StaticValueNode[_T]):
|
133
|
+
super().__init__()
|
117
134
|
self.node = node
|
118
135
|
|
119
136
|
def _execute(self, ctx: NodeContext) -> _T:
|
@@ -123,6 +140,7 @@ class StaticValueFunc(NodeFunction[_T]):
|
|
123
140
|
class IdentityFunc(NodeFunction[_T]):
|
124
141
|
|
125
142
|
def __init__(self, node: IdentityNode[_T]):
|
143
|
+
super().__init__()
|
126
144
|
self.node = node
|
127
145
|
|
128
146
|
def _execute(self, ctx: NodeContext) -> _T:
|
@@ -138,6 +156,7 @@ class _ContextPushPopFunc(NodeFunction[Bundle[tp.Any]], abc.ABC):
|
|
138
156
|
_POP = False
|
139
157
|
|
140
158
|
def __init__(self, node: tp.Union[ContextPushNode, ContextPopNode], direction: bool):
|
159
|
+
super().__init__()
|
141
160
|
self.node = node
|
142
161
|
self.direction = direction
|
143
162
|
|
@@ -176,6 +195,7 @@ class ContextPopFunc(_ContextPushPopFunc):
|
|
176
195
|
class KeyedItemFunc(NodeFunction[_T]):
|
177
196
|
|
178
197
|
def __init__(self, node: KeyedItemNode[_T]):
|
198
|
+
super().__init__()
|
179
199
|
self.node = node
|
180
200
|
|
181
201
|
def _execute(self, ctx: NodeContext) -> _T:
|
@@ -184,9 +204,20 @@ class KeyedItemFunc(NodeFunction[_T]):
|
|
184
204
|
return src_item
|
185
205
|
|
186
206
|
|
207
|
+
class RuntimeOutputsFunc(NodeFunction[JobOutputs]):

    """
    Node function for a RuntimeOutputsNode, its result is the JobOutputs held on the node.
    """

    def __init__(self, node: RuntimeOutputsNode):
        super().__init__()
        self.node = node

    def _execute(self, ctx: NodeContext) -> JobOutputs:
        # The outputs are carried on the node itself, the context is not consulted
        job_outputs = self.node.outputs
        return job_outputs
|
215
|
+
|
216
|
+
|
187
217
|
class BuildJobResultFunc(NodeFunction[_config.JobResult]):
|
188
218
|
|
189
219
|
def __init__(self, node: BuildJobResultNode):
|
220
|
+
super().__init__()
|
190
221
|
self.node = node
|
191
222
|
|
192
223
|
def _execute(self, ctx: NodeContext) -> _config.JobResult:
|
@@ -197,20 +228,33 @@ class BuildJobResultFunc(NodeFunction[_config.JobResult]):
|
|
197
228
|
|
198
229
|
# TODO: Handle individual failed results
|
199
230
|
|
200
|
-
for obj_id, node_id in self.node.objects.items():
|
231
|
+
for obj_id, node_id in self.node.outputs.objects.items():
|
201
232
|
obj_def = _ctx_lookup(node_id, ctx)
|
202
233
|
job_result.results[obj_id] = obj_def
|
203
234
|
|
204
|
-
for bundle_id in self.node.bundles:
|
235
|
+
for bundle_id in self.node.outputs.bundles:
|
205
236
|
bundle = _ctx_lookup(bundle_id, ctx)
|
206
237
|
job_result.results.update(bundle.items())
|
207
238
|
|
239
|
+
if self.node.runtime_outputs is not None:
|
240
|
+
|
241
|
+
runtime_outputs = _ctx_lookup(self.node.runtime_outputs, ctx)
|
242
|
+
|
243
|
+
for obj_id, node_id in runtime_outputs.objects.items():
|
244
|
+
obj_def = _ctx_lookup(node_id, ctx)
|
245
|
+
job_result.results[obj_id] = obj_def
|
246
|
+
|
247
|
+
for bundle_id in runtime_outputs.bundles:
|
248
|
+
bundle = _ctx_lookup(bundle_id, ctx)
|
249
|
+
job_result.results.update(bundle.items())
|
250
|
+
|
208
251
|
return job_result
|
209
252
|
|
210
253
|
|
211
254
|
class SaveJobResultFunc(NodeFunction[None]):
|
212
255
|
|
213
256
|
def __init__(self, node: SaveJobResultNode):
|
257
|
+
super().__init__()
|
214
258
|
self.node = node
|
215
259
|
|
216
260
|
def _execute(self, ctx: NodeContext) -> None:
|
@@ -241,6 +285,7 @@ class SaveJobResultFunc(NodeFunction[None]):
|
|
241
285
|
class DataViewFunc(NodeFunction[_data.DataView]):
|
242
286
|
|
243
287
|
def __init__(self, node: DataViewNode):
|
288
|
+
super().__init__()
|
244
289
|
self.node = node
|
245
290
|
|
246
291
|
def _execute(self, ctx: NodeContext) -> _data.DataView:
|
@@ -267,6 +312,7 @@ class DataViewFunc(NodeFunction[_data.DataView]):
|
|
267
312
|
class DataItemFunc(NodeFunction[_data.DataItem]):
|
268
313
|
|
269
314
|
def __init__(self, node: DataItemNode):
|
315
|
+
super().__init__()
|
270
316
|
self.node = node
|
271
317
|
|
272
318
|
def _execute(self, ctx: NodeContext) -> _data.DataItem:
|
@@ -290,6 +336,7 @@ class DataItemFunc(NodeFunction[_data.DataItem]):
|
|
290
336
|
class DataResultFunc(NodeFunction[ObjectBundle]):
|
291
337
|
|
292
338
|
def __init__(self, node: DataResultNode):
|
339
|
+
super().__init__()
|
293
340
|
self.node = node
|
294
341
|
|
295
342
|
def _execute(self, ctx: NodeContext) -> ObjectBundle:
|
@@ -324,6 +371,7 @@ class DynamicDataSpecFunc(NodeFunction[_data.DataSpec]):
|
|
324
371
|
RANDOM.seed()
|
325
372
|
|
326
373
|
def __init__(self, node: DynamicDataSpecNode, storage: _storage.StorageManager):
|
374
|
+
super().__init__()
|
327
375
|
self.node = node
|
328
376
|
self.storage = storage
|
329
377
|
|
@@ -434,7 +482,7 @@ class _LoadSaveDataFunc(abc.ABC):
|
|
434
482
|
return copy_
|
435
483
|
|
436
484
|
|
437
|
-
class LoadDataFunc(NodeFunction[_data.DataItem],
|
485
|
+
class LoadDataFunc( _LoadSaveDataFunc, NodeFunction[_data.DataItem],):
|
438
486
|
|
439
487
|
def __init__(self, node: LoadDataNode, storage: _storage.StorageManager):
|
440
488
|
super().__init__(storage)
|
@@ -463,7 +511,7 @@ class LoadDataFunc(NodeFunction[_data.DataItem], _LoadSaveDataFunc):
|
|
463
511
|
return _data.DataItem(table.schema, table)
|
464
512
|
|
465
513
|
|
466
|
-
class SaveDataFunc(NodeFunction[None]
|
514
|
+
class SaveDataFunc(_LoadSaveDataFunc, NodeFunction[None]):
|
467
515
|
|
468
516
|
def __init__(self, node: SaveDataNode, storage: _storage.StorageManager):
|
469
517
|
super().__init__(storage)
|
@@ -518,6 +566,7 @@ def _model_def_for_import(import_details: meta.ImportModelJob):
|
|
518
566
|
class ImportModelFunc(NodeFunction[meta.ObjectDefinition]):
|
519
567
|
|
520
568
|
def __init__(self, node: ImportModelNode, models: _models.ModelLoader):
|
569
|
+
super().__init__()
|
521
570
|
self.node = node
|
522
571
|
self._models = models
|
523
572
|
|
@@ -535,11 +584,17 @@ class ImportModelFunc(NodeFunction[meta.ObjectDefinition]):
|
|
535
584
|
|
536
585
|
class RunModelFunc(NodeFunction[Bundle[_data.DataView]]):
|
537
586
|
|
538
|
-
def __init__(
|
587
|
+
def __init__(
        self, node: RunModelNode,
        model_class: tp.Type[_api.TracModel],
        checkout_directory: pathlib.Path,
        storage_manager: tp.Optional[_storage.StorageManager] = None):

    """
    Create the function that runs a model for the given node.

    :param node: The run model node this function will execute
    :param model_class: The model class, it is instantiated with no arguments when the node runs
    :param checkout_directory: Checkout directory for the model, passed on to the model context
    :param storage_manager: Storage manager used to set up external storage access.
        May be None, the function resolver passes None for nodes that have no storage access.
    """

    super().__init__()
    self.node = node
    self.model_class = model_class
    self.checkout_directory = checkout_directory
    self.storage_manager = storage_manager
|
543
598
|
|
544
599
|
def _execute(self, ctx: NodeContext) -> Bundle[_data.DataView]:
|
545
600
|
|
@@ -550,23 +605,48 @@ class RunModelFunc(NodeFunction[Bundle[_data.DataView]]):
|
|
550
605
|
# Still, if any nodes are missing or have the wrong type TracContextImpl will raise ERuntimeValidation
|
551
606
|
|
552
607
|
local_ctx = {}
|
608
|
+
dynamic_outputs = []
|
553
609
|
|
554
610
|
for node_id, node_result in _ctx_iter_items(ctx):
|
611
|
+
if node_id.namespace == self.node.id.namespace:
|
612
|
+
if node_id.name in model_def.parameters or node_id.name in model_def.inputs:
|
613
|
+
local_ctx[node_id.name] = node_result
|
614
|
+
|
615
|
+
# Set up access to external storage if required
|
616
|
+
|
617
|
+
storage_map = {}
|
618
|
+
|
619
|
+
if self.node.storage_access:
|
620
|
+
write_access = True if self.node.model_def.modelType == meta.ModelType.DATA_EXPORT_MODEL else False
|
621
|
+
for storage_key in self.node.storage_access:
|
622
|
+
if self.storage_manager.has_file_storage(storage_key, external=True):
|
623
|
+
storage_impl = self.storage_manager.get_file_storage(storage_key, external=True)
|
624
|
+
storage = _ctx.TracFileStorageImpl(storage_key, storage_impl, write_access, self.checkout_directory)
|
625
|
+
storage_map[storage_key] = storage
|
626
|
+
elif self.storage_manager.has_data_storage(storage_key, external=True):
|
627
|
+
storage_impl = self.storage_manager.get_data_storage(storage_key, external=True)
|
628
|
+
# This is a work-around until the storage extension API can be updated / unified
|
629
|
+
if not isinstance(storage_impl, _storage.IDataStorageBase):
|
630
|
+
raise _ex.EStorageConfig(f"External storage for [{storage_key}] is using the legacy storage framework]")
|
631
|
+
converter = _data.DataConverter.noop()
|
632
|
+
storage = _ctx.TracDataStorageImpl(storage_key, storage_impl, converter, write_access, self.checkout_directory)
|
633
|
+
storage_map[storage_key] = storage
|
634
|
+
else:
|
635
|
+
raise _ex.EStorageConfig(f"External storage is not available: [{storage_key}]")
|
555
636
|
|
556
|
-
if node_id.namespace != self.node.id.namespace:
|
557
|
-
continue
|
558
|
-
|
559
|
-
if node_id.name in model_def.parameters:
|
560
|
-
param_name = node_id.name
|
561
|
-
local_ctx[param_name] = node_result
|
562
|
-
|
563
|
-
if node_id.name in model_def.inputs:
|
564
|
-
input_name = node_id.name
|
565
|
-
local_ctx[input_name] = node_result
|
566
637
|
|
567
638
|
# Run the model against the mapped local context
|
568
639
|
|
569
|
-
|
640
|
+
if model_def.modelType in [meta.ModelType.DATA_IMPORT_MODEL, meta.ModelType.DATA_EXPORT_MODEL]:
|
641
|
+
trac_ctx = _ctx.TracDataContextImpl(
|
642
|
+
self.node.model_def, self.model_class,
|
643
|
+
local_ctx, dynamic_outputs, storage_map,
|
644
|
+
self.checkout_directory)
|
645
|
+
else:
|
646
|
+
trac_ctx = _ctx.TracContextImpl(
|
647
|
+
self.node.model_def, self.model_class,
|
648
|
+
local_ctx, dynamic_outputs,
|
649
|
+
self.checkout_directory)
|
570
650
|
|
571
651
|
try:
|
572
652
|
model = self.model_class()
|
@@ -580,7 +660,10 @@ class RunModelFunc(NodeFunction[Bundle[_data.DataView]]):
|
|
580
660
|
|
581
661
|
# Check required outputs are present and build the results bundle
|
582
662
|
|
663
|
+
model_name = self.model_class.__name__
|
583
664
|
results: Bundle[_data.DataView] = dict()
|
665
|
+
new_nodes = dict()
|
666
|
+
new_deps = dict()
|
584
667
|
|
585
668
|
for output_name, output_schema in model_def.outputs.items():
|
586
669
|
|
@@ -589,7 +672,6 @@ class RunModelFunc(NodeFunction[Bundle[_data.DataView]]):
|
|
589
672
|
if result is None or result.is_empty():
|
590
673
|
|
591
674
|
if not output_schema.optional:
|
592
|
-
model_name = self.model_class.__name__
|
593
675
|
raise _ex.ERuntimeValidation(f"Missing required output [{output_name}] from model [{model_name}]")
|
594
676
|
|
595
677
|
# Create a placeholder for optional outputs that were not emitted
|
@@ -598,9 +680,45 @@ class RunModelFunc(NodeFunction[Bundle[_data.DataView]]):
|
|
598
680
|
|
599
681
|
results[output_name] = result
|
600
682
|
|
683
|
+
if dynamic_outputs:
|
684
|
+
|
685
|
+
for output_name in dynamic_outputs:
|
686
|
+
|
687
|
+
result: _data.DataView = local_ctx.get(output_name)
|
688
|
+
|
689
|
+
if result is None or result.is_empty():
|
690
|
+
raise _ex.ERuntimeValidation(f"No data provided for [{output_name}] from model [{model_name}]")
|
691
|
+
|
692
|
+
results[output_name] = result
|
693
|
+
|
694
|
+
result_node_id = NodeId.of(output_name, self.node.id.namespace, _data.DataView)
|
695
|
+
result_node = BundleItemNode(result_node_id, self.node.id, output_name)
|
696
|
+
|
697
|
+
new_nodes[result_node_id] = result_node
|
698
|
+
|
699
|
+
output_section = _graph.GraphBuilder.build_runtime_outputs(dynamic_outputs, self.node.id.namespace)
|
700
|
+
new_nodes.update(output_section.nodes)
|
701
|
+
|
702
|
+
ctx_id = NodeId.of("trac_job_result", self.node.id.namespace, result_type=None)
|
703
|
+
new_deps[ctx_id] = list(_graph.Dependency(nid, _graph.DependencyType.HARD) for nid in output_section.outputs)
|
704
|
+
|
705
|
+
self.node_callback.send_graph_updates(new_nodes, new_deps)
|
706
|
+
|
601
707
|
return results
|
602
708
|
|
603
709
|
|
710
|
+
class ChildJobFunction(NodeFunction[None]):

    """
    Placeholder function for child job nodes, executing it always raises ETracInternal.
    """

    def __init__(self, node: ChildJobNode):
        super().__init__()
        self.node = node

    def _execute(self, ctx: NodeContext):
        # Child job nodes are meant to be intercepted by the engine, which provides special handling
        # If execution reaches this point, the interception did not happen
        message = "Child job was not processed correctly (this is a bug)"
        raise _ex.ETracInternal(message)
|
719
|
+
|
720
|
+
|
721
|
+
|
604
722
|
# ----------------------------------------------------------------------------------------------------------------------
|
605
723
|
# FUNCTION RESOLUTION
|
606
724
|
# ----------------------------------------------------------------------------------------------------------------------
|
@@ -621,6 +739,14 @@ class FunctionResolver:
|
|
621
739
|
:py:class:`NodeFunction <NodeFunction>`
|
622
740
|
"""
|
623
741
|
|
742
|
+
# TODO: Validate consistency for resource keys
|
743
|
+
# Storage key should be validated for load data, save data and run model with storage access
|
744
|
+
# Repository key should be validated for import model (and explicitly for run model)
|
745
|
+
|
746
|
+
# Currently jobs with missing resources will fail at runtime, with a suitable error
|
747
|
+
# The resolver is called during graph building
|
748
|
+
# Putting the check here will raise a consistency error before the job starts processing
|
749
|
+
|
624
750
|
__ResolveFunc = tp.Callable[['FunctionResolver', Node[_T]], NodeFunction[_T]]
|
625
751
|
|
626
752
|
def __init__(self, models: _models.ModelLoader, storage: _storage.StorageManager):
|
@@ -655,12 +781,13 @@ class FunctionResolver:
|
|
655
781
|
|
656
782
|
def resolve_run_model_node(self, node: RunModelNode) -> NodeFunction:

    """
    Build the RunModelFunc for a run model node.

    The model class and its checkout directory are obtained from the model loader.
    The storage manager is only handed over when the node declares storage access,
    otherwise the function receives None in its place.
    """

    # TODO: Verify model_class against model_def

    loader = self._models
    model_class = loader.load_model_class(node.model_scope, node.model_def)
    checkout_directory = loader.model_load_checkout_directory(node.model_scope, node.model_def)

    storage_manager = None

    if node.storage_access:
        storage_manager = self._storage

    return RunModelFunc(node, model_class, checkout_directory, storage_manager)
|
664
791
|
|
665
792
|
__basic_node_mapping: tp.Dict[Node.__class__, NodeFunction.__class__] = {
|
666
793
|
|
@@ -674,6 +801,8 @@ class FunctionResolver:
|
|
674
801
|
SaveJobResultNode: SaveJobResultFunc,
|
675
802
|
DataResultNode: DataResultFunc,
|
676
803
|
StaticValueNode: StaticValueFunc,
|
804
|
+
RuntimeOutputsNode: RuntimeOutputsFunc,
|
805
|
+
ChildJobNode: ChildJobFunction,
|
677
806
|
BundleItemNode: NoopFunc,
|
678
807
|
NoopNode: NoopFunc,
|
679
808
|
RunModelResultNode: NoopFunc
|
tracdap/rt/_exec/graph.py
CHANGED
@@ -12,8 +12,6 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
from __future__ import annotations
|
16
|
-
|
17
15
|
import pathlib
|
18
16
|
import typing as tp
|
19
17
|
import dataclasses as dc
|
@@ -38,7 +36,7 @@ class NodeNamespace:
|
|
38
36
|
return cls.__ROOT
|
39
37
|
|
40
38
|
name: str
|
41
|
-
parent: tp.Optional[NodeNamespace] = dc.field(default_factory=lambda: NodeNamespace.root())
|
39
|
+
parent: "tp.Optional[NodeNamespace]" = dc.field(default_factory=lambda: NodeNamespace.root())
|
42
40
|
|
43
41
|
def __str__(self):
|
44
42
|
if self is self.__ROOT:
|
@@ -62,7 +60,7 @@ class NodeNamespace:
|
|
62
60
|
class NodeId(tp.Generic[_T]):
|
63
61
|
|
64
62
|
@staticmethod
|
65
|
-
def of(name: str, namespace: NodeNamespace, result_type: tp.Type[_T]) -> NodeId[_T]:
|
63
|
+
def of(name: str, namespace: NodeNamespace, result_type: tp.Type[_T]) -> "NodeId[_T]":
|
66
64
|
return NodeId(name, namespace, result_type)
|
67
65
|
|
68
66
|
name: str
|
@@ -83,8 +81,8 @@ class DependencyType:
|
|
83
81
|
immediate: bool = True
|
84
82
|
tolerant: bool = False
|
85
83
|
|
86
|
-
HARD: tp.ClassVar[DependencyType]
|
87
|
-
TOLERANT: tp.ClassVar[DependencyType]
|
84
|
+
HARD: "tp.ClassVar[DependencyType]"
|
85
|
+
TOLERANT: "tp.ClassVar[DependencyType]"
|
88
86
|
|
89
87
|
|
90
88
|
DependencyType.HARD = DependencyType(immediate=True, tolerant=False)
|
@@ -93,6 +91,13 @@ DependencyType.TOLERANT = DependencyType(immediate=True, tolerant=True)
|
|
93
91
|
DependencyType.DELAYED = DependencyType(immediate=False, tolerant=False)
|
94
92
|
|
95
93
|
|
94
|
+
@dc.dataclass(frozen=True)
class Dependency:

    """
    Immutable pairing of a node ID with the type of dependency on that node.
    """

    # ID of the node that is depended upon
    node_id: NodeId

    # How the dependency is treated, e.g. DependencyType.HARD
    dependency_type: DependencyType
|
99
|
+
|
100
|
+
|
96
101
|
@dc.dataclass(frozen=True)
|
97
102
|
class Node(tp.Generic[_T]):
|
98
103
|
|
@@ -165,6 +170,17 @@ class GraphSection:
|
|
165
170
|
must_run: tp.List[NodeId] = dc.field(default_factory=list)
|
166
171
|
|
167
172
|
|
173
|
+
Bundle: tp.Generic[_T] = tp.Dict[str, _T]
|
174
|
+
ObjectBundle = Bundle[meta.ObjectDefinition]
|
175
|
+
|
176
|
+
|
177
|
+
@dc.dataclass(frozen=True)
class JobOutputs:

    """
    Describes the outputs of a job as references to the nodes that produce them.

    Both fields default to a fresh empty container for each instance.
    """

    # Individual object definitions: output key -> ID of the node supplying the object definition
    objects: tp.Dict[str, NodeId[meta.ObjectDefinition]] = dc.field(default_factory=dict)

    # IDs of nodes that each supply a whole bundle of object definitions
    bundles: tp.List[NodeId[ObjectBundle]] = dc.field(default_factory=list)
|
182
|
+
|
183
|
+
|
168
184
|
# TODO: Where does this go?
|
169
185
|
@dc.dataclass(frozen=True)
|
170
186
|
class JobResultSpec:
|
@@ -179,10 +195,6 @@ class JobResultSpec:
|
|
179
195
|
# ----------------------------------------------------------------------------------------------------------------------
|
180
196
|
|
181
197
|
|
182
|
-
Bundle: tp.Generic[_T] = tp.Dict[str, _T]
|
183
|
-
ObjectBundle = Bundle[meta.ObjectDefinition]
|
184
|
-
|
185
|
-
|
186
198
|
@_node_type
|
187
199
|
class NoopNode(Node):
|
188
200
|
pass
|
@@ -354,6 +366,7 @@ class RunModelNode(Node[Bundle[_data.DataView]]):
|
|
354
366
|
model_def: meta.ModelDefinition
|
355
367
|
parameter_ids: tp.FrozenSet[NodeId]
|
356
368
|
input_ids: tp.FrozenSet[NodeId]
|
369
|
+
storage_access: tp.Optional[tp.List[str]] = None
|
357
370
|
|
358
371
|
def _node_dependencies(self) -> tp.Dict[NodeId, DependencyType]:
|
359
372
|
return {dep_id: DependencyType.HARD for dep_id in [*self.parameter_ids, *self.input_ids]}
|
@@ -368,16 +381,28 @@ class RunModelResultNode(Node[None]):
|
|
368
381
|
return {self.model_id: DependencyType.HARD}
|
369
382
|
|
370
383
|
|
384
|
+
@_node_type
class RuntimeOutputsNode(Node[JobOutputs]):

    """
    Node holding a set of job outputs, with a hard dependency on every node those outputs reference.
    """

    # The outputs carried by this node
    outputs: JobOutputs

    def _node_dependencies(self) -> tp.Dict[NodeId, DependencyType]:

        # Bundle nodes first, then the individual object nodes, all as hard dependencies
        hard_deps = dict.fromkeys(self.outputs.bundles, DependencyType.HARD)

        for object_node_id in self.outputs.objects.values():
            hard_deps[object_node_id] = DependencyType.HARD

        return hard_deps
|
392
|
+
|
393
|
+
|
371
394
|
@_node_type
class BuildJobResultNode(Node[cfg.JobResult]):

    """
    Node that produces the job result from the job outputs and, optionally, a runtime outputs node.
    """

    job_id: meta.TagHeader

    # Outputs that are referenced directly by this node
    outputs: JobOutputs

    # Optional ID of a node whose result is a further set of JobOutputs
    runtime_outputs: tp.Optional[NodeId[JobOutputs]] = None

    def _node_dependencies(self) -> tp.Dict[NodeId, DependencyType]:

        # Bundle nodes first, then the individual object nodes, all as hard dependencies
        hard_deps = dict.fromkeys(self.outputs.bundles, DependencyType.HARD)

        for object_node_id in self.outputs.objects.values():
            hard_deps[object_node_id] = DependencyType.HARD

        # The runtime outputs node, when there is one, is also a hard dependency
        if self.runtime_outputs is not None:
            hard_deps[self.runtime_outputs] = DependencyType.HARD

        return hard_deps
|
382
407
|
|
383
408
|
|
@@ -389,3 +414,12 @@ class SaveJobResultNode(Node[None]):
|
|
389
414
|
|
390
415
|
def _node_dependencies(self) -> tp.Dict[NodeId, DependencyType]:
|
391
416
|
return {self.job_result_id: DependencyType.HARD}
|
417
|
+
|
418
|
+
|
419
|
+
@_node_type
class ChildJobNode(Node[cfg.JobResult]):

    """
    Node representing a child job, its result type is a job result.
    """

    # Tag header identifying the child job
    job_id: meta.TagHeader

    # Definition of the child job
    job_def: meta.JobDefinition

    # Graph associated with the child job
    graph: Graph
|