tracdap-runtime 0.6.1.dev3__py3-none-any.whl → 0.6.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tracdap/rt/_exec/actors.py +87 -10
- tracdap/rt/_exec/context.py +25 -1
- tracdap/rt/_exec/dev_mode.py +277 -221
- tracdap/rt/_exec/engine.py +79 -14
- tracdap/rt/_exec/functions.py +37 -8
- tracdap/rt/_exec/graph.py +2 -0
- tracdap/rt/_exec/graph_builder.py +118 -56
- tracdap/rt/_exec/runtime.py +108 -37
- tracdap/rt/_exec/server.py +345 -0
- tracdap/rt/_impl/config_parser.py +219 -49
- tracdap/rt/_impl/data.py +14 -0
- tracdap/rt/_impl/grpc/__init__.py +13 -0
- tracdap/rt/_impl/grpc/codec.py +99 -0
- tracdap/rt/_impl/grpc/tracdap/api/internal/runtime_pb2.py +51 -0
- tracdap/rt/_impl/grpc/tracdap/api/internal/runtime_pb2.pyi +61 -0
- tracdap/rt/_impl/grpc/tracdap/api/internal/runtime_pb2_grpc.py +183 -0
- tracdap/rt/_impl/grpc/tracdap/metadata/common_pb2.py +33 -0
- tracdap/rt/_impl/grpc/tracdap/metadata/common_pb2.pyi +34 -0
- tracdap/rt/{metadata → _impl/grpc/tracdap/metadata}/custom_pb2.py +5 -5
- tracdap/rt/_impl/grpc/tracdap/metadata/custom_pb2.pyi +15 -0
- tracdap/rt/_impl/grpc/tracdap/metadata/data_pb2.py +51 -0
- tracdap/rt/_impl/grpc/tracdap/metadata/data_pb2.pyi +115 -0
- tracdap/rt/_impl/grpc/tracdap/metadata/file_pb2.py +28 -0
- tracdap/rt/_impl/grpc/tracdap/metadata/file_pb2.pyi +22 -0
- tracdap/rt/_impl/grpc/tracdap/metadata/flow_pb2.py +59 -0
- tracdap/rt/_impl/grpc/tracdap/metadata/flow_pb2.pyi +109 -0
- tracdap/rt/_impl/grpc/tracdap/metadata/job_pb2.py +76 -0
- tracdap/rt/_impl/grpc/tracdap/metadata/job_pb2.pyi +177 -0
- tracdap/rt/_impl/grpc/tracdap/metadata/model_pb2.py +63 -0
- tracdap/rt/_impl/grpc/tracdap/metadata/model_pb2.pyi +119 -0
- tracdap/rt/_impl/grpc/tracdap/metadata/object_id_pb2.py +32 -0
- tracdap/rt/_impl/grpc/tracdap/metadata/object_id_pb2.pyi +68 -0
- tracdap/rt/_impl/grpc/tracdap/metadata/object_pb2.py +40 -0
- tracdap/rt/_impl/grpc/tracdap/metadata/object_pb2.pyi +46 -0
- tracdap/rt/_impl/grpc/tracdap/metadata/search_pb2.py +39 -0
- tracdap/rt/_impl/grpc/tracdap/metadata/search_pb2.pyi +83 -0
- tracdap/rt/_impl/grpc/tracdap/metadata/stoarge_pb2.py +50 -0
- tracdap/rt/_impl/grpc/tracdap/metadata/stoarge_pb2.pyi +89 -0
- tracdap/rt/_impl/grpc/tracdap/metadata/tag_pb2.py +34 -0
- tracdap/rt/_impl/grpc/tracdap/metadata/tag_pb2.pyi +26 -0
- tracdap/rt/_impl/grpc/tracdap/metadata/tag_update_pb2.py +30 -0
- tracdap/rt/_impl/grpc/tracdap/metadata/tag_update_pb2.pyi +34 -0
- tracdap/rt/_impl/grpc/tracdap/metadata/type_pb2.py +47 -0
- tracdap/rt/_impl/grpc/tracdap/metadata/type_pb2.pyi +101 -0
- tracdap/rt/_impl/guard_rails.py +26 -6
- tracdap/rt/_impl/models.py +25 -0
- tracdap/rt/_impl/static_api.py +27 -9
- tracdap/rt/_impl/type_system.py +17 -0
- tracdap/rt/_impl/validation.py +10 -0
- tracdap/rt/_plugins/config_local.py +49 -0
- tracdap/rt/_version.py +1 -1
- tracdap/rt/api/hook.py +10 -3
- tracdap/rt/api/model_api.py +22 -0
- tracdap/rt/api/static_api.py +79 -19
- tracdap/rt/config/__init__.py +3 -3
- tracdap/rt/config/common.py +10 -0
- tracdap/rt/config/platform.py +9 -19
- tracdap/rt/config/runtime.py +2 -0
- tracdap/rt/ext/config.py +34 -0
- tracdap/rt/ext/embed.py +1 -3
- tracdap/rt/ext/plugins.py +47 -6
- tracdap/rt/launch/cli.py +7 -5
- tracdap/rt/launch/launch.py +49 -12
- tracdap/rt/metadata/__init__.py +24 -24
- tracdap/rt/metadata/common.py +7 -7
- tracdap/rt/metadata/custom.py +2 -0
- tracdap/rt/metadata/data.py +28 -5
- tracdap/rt/metadata/file.py +2 -0
- tracdap/rt/metadata/flow.py +66 -4
- tracdap/rt/metadata/job.py +56 -16
- tracdap/rt/metadata/model.py +10 -0
- tracdap/rt/metadata/object.py +3 -0
- tracdap/rt/metadata/object_id.py +9 -9
- tracdap/rt/metadata/search.py +35 -13
- tracdap/rt/metadata/stoarge.py +64 -6
- tracdap/rt/metadata/tag_update.py +21 -7
- tracdap/rt/metadata/type.py +28 -13
- {tracdap_runtime-0.6.1.dev3.dist-info → tracdap_runtime-0.6.3.dist-info}/METADATA +22 -19
- tracdap_runtime-0.6.3.dist-info/RECORD +112 -0
- {tracdap_runtime-0.6.1.dev3.dist-info → tracdap_runtime-0.6.3.dist-info}/WHEEL +1 -1
- tracdap/rt/config/common_pb2.py +0 -55
- tracdap/rt/config/job_pb2.py +0 -42
- tracdap/rt/config/platform_pb2.py +0 -71
- tracdap/rt/config/result_pb2.py +0 -37
- tracdap/rt/config/runtime_pb2.py +0 -42
- tracdap/rt/ext/_guard.py +0 -37
- tracdap/rt/metadata/common_pb2.py +0 -33
- tracdap/rt/metadata/data_pb2.py +0 -51
- tracdap/rt/metadata/file_pb2.py +0 -28
- tracdap/rt/metadata/flow_pb2.py +0 -55
- tracdap/rt/metadata/job_pb2.py +0 -76
- tracdap/rt/metadata/model_pb2.py +0 -51
- tracdap/rt/metadata/object_id_pb2.py +0 -32
- tracdap/rt/metadata/object_pb2.py +0 -35
- tracdap/rt/metadata/search_pb2.py +0 -39
- tracdap/rt/metadata/stoarge_pb2.py +0 -50
- tracdap/rt/metadata/tag_pb2.py +0 -34
- tracdap/rt/metadata/tag_update_pb2.py +0 -30
- tracdap/rt/metadata/type_pb2.py +0 -48
- tracdap_runtime-0.6.1.dev3.dist-info/RECORD +0 -96
- {tracdap_runtime-0.6.1.dev3.dist-info → tracdap_runtime-0.6.3.dist-info}/LICENSE +0 -0
- {tracdap_runtime-0.6.1.dev3.dist-info → tracdap_runtime-0.6.3.dist-info}/top_level.txt +0 -0
tracdap/rt/_exec/engine.py
CHANGED
@@ -19,6 +19,7 @@ import dataclasses as dc
 import enum
 import typing as tp

+import tracdap.rt.metadata as _meta
 import tracdap.rt.config as _cfg
 import tracdap.rt.exceptions as _ex
 import tracdap.rt._exec.actors as _actors
@@ -28,7 +29,6 @@ import tracdap.rt._impl.models as _models  # noqa
 import tracdap.rt._impl.data as _data  # noqa
 import tracdap.rt._impl.storage as _storage  # noqa
 import tracdap.rt._impl.util as _util  # noqa
-from .actors import Signal

 from .graph import NodeId

@@ -66,6 +66,18 @@ class _EngineContext:
     failed_nodes: tp.Set[NodeId] = dc.field(default_factory=set)


+@dc.dataclass
+class _JobState:
+
+    job_id: _meta.TagHeader
+    job_config: _cfg.JobConfig
+
+    actor_id: _actors.ActorId = None
+
+    job_result: _cfg.JobResult = None
+    job_error: Exception = None
+
+
 class TracEngine(_actors.Actor):

     """
@@ -88,7 +100,7 @@ class TracEngine(_actors.Actor):
         self._storage = storage
         self._notify_callback = notify_callback

-        self.
+        self._jobs: tp.Dict[str, _JobState] = dict()

     def on_start(self):

@@ -98,7 +110,7 @@

         self._log.info("Engine shutdown complete")

-    def on_signal(self, signal: Signal) -> tp.Optional[bool]:
+    def on_signal(self, signal: _actors.Signal) -> tp.Optional[bool]:

         # Failed signals can propagate from leaf nodes up the actor tree for a job
         # If the failure goes all the way up the tree without being handled, it will reach the engine node
@@ -110,8 +122,8 @@
         failed_job_key = None

         # Look for the job key corresponding to the failed actor
-        for job_key,
-            if
+        for job_key, job_state in self._jobs.items():
+            if job_state.actor_id == signal.sender:
                 failed_job_key = job_key

         # If the job is still live, call job_failed explicitly
@@ -147,19 +159,34 @@
         job_processor = JobProcessor(job_key, job_config, result_spec, self._models, self._storage)
         job_actor_id = self.actors().spawn(job_processor)

-
-
+        job_state = _JobState(job_config.jobId, job_config)
+        job_state.actor_id = job_actor_id
+
+        self._jobs[job_key] = job_state
+
+    @_actors.Message
+    def get_job_list(self):
+
+        job_list = list(map(self._get_job_info, self._jobs.keys()))
+        self.actors().reply("job_list", job_list)
+
+    @_actors.Message
+    def get_job_details(self, job_key: str, details: bool):
+
+        details = self._get_job_info(job_key, details)
+        self.actors().reply("job_details", details)

     @_actors.Message
     def job_succeeded(self, job_key: str, job_result: _cfg.JobResult):

         # Ignore duplicate messages from the job processor (can happen in unusual error cases)
-        if job_key not in self.
+        if job_key not in self._jobs:
             self._log.warning(f"Ignoring [job_succeeded] message, job [{job_key}] has already completed")
             return

         self._log.info(f"Recording job as successful: {job_key}")

+        self._jobs[job_key].job_result = job_result
         self._finalize_job(job_key)

         if self._notify_callback is not None:
@@ -169,12 +196,13 @@
     def job_failed(self, job_key: str, error: Exception):

         # Ignore duplicate messages from the job processor (can happen in unusual error cases)
-        if job_key not in self.
+        if job_key not in self._jobs:
             self._log.warning(f"Ignoring [job_failed] message, job [{job_key}] has already completed")
             return

         self._log.error(f"Recording job as failed: {job_key}")

+        self._jobs[job_key].job_error = error
         self._finalize_job(job_key)

         if self._notify_callback is not None:
@@ -182,10 +210,47 @@

     def _finalize_job(self, job_key: str):

-
-
-
-
+        # Stop the actor but keep the job state available for status / results queries
+
+        # In the future, job state will need to be expunged after some period of time
+        # For now each instance of the runtime only processes one job so no need to worry
+
+        job_state = self._jobs.get(job_key)
+        job_actor_id = job_state.actor_id if job_state is not None else None
+
+        if job_actor_id is not None:
+            self.actors().stop(job_actor_id)
+            job_state.actor_id = None
+
+    def _get_job_info(self, job_key: str, details: bool = False) -> tp.Optional[_cfg.JobResult]:
+
+        job_state = self._jobs.get(job_key)
+
+        if job_state is None:
+            return None
+
+        job_result = _cfg.JobResult()
+        job_result.jobId = job_state.job_id
+
+        if job_state.actor_id is not None:
+            job_result.statusCode = _meta.JobStatusCode.RUNNING
+
+        elif job_state.job_result is not None:
+            job_result.statusCode = job_state.job_result.statusCode
+            job_result.statusMessage = job_state.job_result.statusMessage
+            if details:
+                job_result.results = job_state.job_result.results or dict()
+
+        elif job_state.job_error is not None:
+            job_result.statusCode = _meta.JobStatusCode.FAILED
+            job_result.statusMessage = str(job_state.job_error.args[0])
+
+        else:
+            # Alternatively return UNKNOWN status or throw an error here
+            job_result.statusCode = _meta.JobStatusCode.FAILED
+            job_result.statusMessage = "No details available"
+
+        return job_result


 class JobProcessor(_actors.Actor):
@@ -218,7 +283,7 @@ class JobProcessor(_actors.Actor):
         self._log.info(f"Cleaning up job [{self.job_key}]")
         self._models.destroy_scope(self.job_key)

-    def on_signal(self, signal: Signal) -> tp.Optional[bool]:
+    def on_signal(self, signal: _actors.Signal) -> tp.Optional[bool]:

         if signal.message == _actors.SignalNames.FAILED and isinstance(signal, _actors.ErrorSignal):

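The new _JobState record decouples a job's reported status from the lifetime of its actor: while the actor is live the job reports RUNNING, afterwards the recorded result or error is used, with FAILED as a last-resort fallback. The sketch below reproduces that resolution order in plain Python; JobState and the string status codes are simplified stand-ins for illustration, not the TRAC classes.

import dataclasses as dc
import typing as tp

@dc.dataclass
class JobState:
    actor_id: tp.Optional[str] = None          # set while the job actor is live
    job_result: tp.Optional[str] = None        # recorded by job_succeeded
    job_error: tp.Optional[Exception] = None   # recorded by job_failed

def job_status(state: tp.Optional[JobState]) -> str:
    # Same precedence as _get_job_info: live actor wins, then result, then error
    if state is None:
        return "NOT_FOUND"
    if state.actor_id is not None:
        return "RUNNING"
    if state.job_result is not None:
        return "SUCCEEDED"
    if state.job_error is not None:
        return "FAILED"
    return "FAILED"  # the engine's fallback: FAILED with "No details available"

assert job_status(JobState(actor_id="actor-1")) == "RUNNING"
assert job_status(JobState(job_result="ok")) == "SUCCEEDED"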
tracdap/rt/_exec/functions.py
CHANGED
@@ -248,6 +248,10 @@ class DataViewFunc(NodeFunction[_data.DataView]):
         root_item = _ctx_lookup(self.node.root_item, ctx)
         root_part_key = _data.DataPartKey.for_root()

+        # Map empty item -> emtpy view (for optional inputs not supplied)
+        if root_item.is_empty():
+            return _data.DataView.create_empty()
+
         data_view = _data.DataView.for_trac_schema(self.node.schema)
         data_view = _data.DataMapping.add_item_to_view(data_view, root_part_key, root_item)

@@ -263,6 +267,10 @@ class DataItemFunc(NodeFunction[_data.DataItem]):

         data_view = _ctx_lookup(self.node.data_view_id, ctx)

+        # Map empty view -> emtpy item (for optional outputs not supplied)
+        if data_view.is_empty():
+            return _data.DataItem.create_empty()
+
         # TODO: Support selecting data item described by self.node

         # Selecting data item for part-root, delta=0
@@ -280,6 +288,12 @@ class DataResultFunc(NodeFunction[ObjectBundle]):

     def _execute(self, ctx: NodeContext) -> ObjectBundle:

+        data_item = _ctx_lookup(self.node.data_item_id, ctx)
+
+        # Do not record output metadata for optional outputs that are empty
+        if data_item.is_empty():
+            return {}
+
         data_spec = _ctx_lookup(self.node.data_spec_id, ctx)

         # TODO: Check result of save operation
@@ -451,6 +465,13 @@ class SaveDataFunc(NodeFunction[None], _LoadSaveDataFunc):

     def _execute(self, ctx: NodeContext):

+        # Item to be saved should exist in the current context
+        data_item = _ctx_lookup(self.node.data_item_id, ctx)
+
+        # Do not save empty outputs (optional outputs that were not produced)
+        if data_item.is_empty():
+            return
+
         # This function assumes that metadata has already been generated as the data_spec
         # i.e. it is already known which incarnation / copy of the data will be created

@@ -458,9 +479,6 @@
         data_copy = self._choose_copy(data_spec.data_item, data_spec.storage_def)
         data_storage = self.storage.get_data_storage(data_copy.storageKey)

-        # Item to be saved should exist in the current context
-        data_item = _ctx_lookup(self.node.data_item_id, ctx)
-
         # Current implementation will always put an Arrow table in the data item
         # Empty tables are allowed, so explicitly check if table is None
         # Testing "if not data_item.table" will fail for empty tables
@@ -567,12 +585,23 @@ class RunModelFunc(NodeFunction[Bundle[_data.DataView]]):
             msg = f"There was an unhandled error in the model: {str(e)}{details}"
             raise _ex.EModelExec(msg) from e

-        #
-
-
-
+        # Check required outputs are present and build the results bundle
+
+        results: Bundle[_data.DataView] = dict()
+
+        for output_name, output_schema in model_def.outputs.items():
+
+            result: _data.DataView = local_ctx.get(output_name)
+
+            if result is None or result.is_empty():
+                if not output_schema.optional:
+                    model_name = self.model_class.__name__
+                    raise _ex.ERuntimeValidation(f"Missing required output [{output_name}] from model [{model_name}]")
+
+            if result is not None:
+                results[output_name] = result

-        return
+        return results


 # ----------------------------------------------------------------------------------------------------------------------
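Taken together, these changes thread an explicit notion of "empty" through the data nodes, so unsupplied optional inputs and unproduced optional outputs are skipped end to end, while RunModelFunc now enforces that every non-optional output is actually produced. A minimal sketch of that output check follows; OutputSchema and DataView here are hypothetical stand-ins, not the tracdap classes.

import dataclasses as dc
import typing as tp

@dc.dataclass
class OutputSchema:
    optional: bool = False

@dc.dataclass
class DataView:
    rows: int = 0
    def is_empty(self) -> bool:
        return self.rows == 0

def collect_outputs(
        declared: tp.Dict[str, OutputSchema],
        produced: tp.Dict[str, DataView]) -> tp.Dict[str, DataView]:

    results: tp.Dict[str, DataView] = dict()

    for name, schema in declared.items():
        result = produced.get(name)
        # Missing or empty results are only allowed for optional outputs
        if result is None or result.is_empty():
            if not schema.optional:
                raise RuntimeError(f"Missing required output [{name}]")
        # Empty optional outputs still pass through, so downstream save steps can skip them
        if result is not None:
            results[name] = result

    return results

declared = {"main": OutputSchema(), "extra": OutputSchema(optional=True)}
print(collect_outputs(declared, {"main": DataView(rows=10)}))  # "extra" may be absent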
tracdap/rt/_exec/graph.py
CHANGED
@@ -297,6 +297,7 @@ class DataItemNode(MappingNode[_data.DataItem]):
 class DataResultNode(Node[ObjectBundle]):

     output_name: str
+    data_item_id: NodeId[_data.DataItem]
     data_spec_id: NodeId[_data.DataSpec]
     data_save_id: NodeId[type(None)]

@@ -306,6 +307,7 @@ class DataResultNode(Node[ObjectBundle]):
     def _node_dependencies(self) -> tp.Dict[NodeId, DependencyType]:

         return {
+            self.data_item_id: DependencyType.HARD,
             self.data_spec_id: DependencyType.HARD,
             self.data_save_id: DependencyType.HARD}

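Declaring data_item_id as a HARD dependency means the result node only executes once the data item itself is available, which is what lets DataResultFunc (above) inspect the item and skip metadata for empty optional outputs. A toy illustration of hard-dependency gating, under the assumption that a node becomes runnable once all of its hard dependencies have completed (the node names here are invented):

import typing as tp

def ready_nodes(deps: tp.Dict[str, tp.Set[str]], completed: tp.Set[str]) -> tp.Set[str]:
    # A node is ready when every one of its hard dependencies has completed
    return {node for node, hard_deps in deps.items()
            if node not in completed and hard_deps <= completed}

deps = {
    "output:RESULT": {"output:ITEM", "output:SPEC", "output:SAVE"},  # data_item_id now included
    "output:SAVE": {"output:ITEM", "output:SPEC"},
}
print(ready_nodes(deps, completed={"output:ITEM", "output:SPEC"}))  # {'output:SAVE'}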
tracdap/rt/_exec/graph_builder.py
CHANGED
@@ -14,9 +14,6 @@

 from __future__ import annotations

-import copy
-import dataclasses  # noqa
-
 import tracdap.rt.config as config
 import tracdap.rt.exceptions as _ex
 import tracdap.rt._impl.data as _data  # noqa
@@ -123,12 +120,14 @@ class GraphBuilder:
             job_namespace: NodeNamespace, job_push_id: NodeId) \
             -> GraphSection:

+        target_selector = job_config.job.runModel.model
+        target_obj = _util.get_job_resource(target_selector, job_config)
+        target_def = target_obj.model
+        job_def = job_config.job.runModel
+
         return cls.build_calculation_job(
             job_config, result_spec, job_namespace, job_push_id,
-
-            job_config.job.runModel.parameters,
-            job_config.job.runModel.inputs,
-            job_config.job.runModel.outputs)
+            target_selector, target_def, job_def)

     @classmethod
     def build_run_flow_job(
@@ -136,40 +135,53 @@ class GraphBuilder:
             job_namespace: NodeNamespace, job_push_id: NodeId) \
             -> GraphSection:

+        target_selector = job_config.job.runFlow.flow
+        target_obj = _util.get_job_resource(target_selector, job_config)
+        target_def = target_obj.flow
+        job_def = job_config.job.runFlow
+
         return cls.build_calculation_job(
             job_config, result_spec, job_namespace, job_push_id,
-
-            job_config.job.runFlow.parameters,
-            job_config.job.runFlow.inputs,
-            job_config.job.runFlow.outputs)
+            target_selector, target_def, job_def)

     @classmethod
     def build_calculation_job(
             cls, job_config: config.JobConfig, result_spec: JobResultSpec,
             job_namespace: NodeNamespace, job_push_id: NodeId,
-
-
+            target_selector: meta.TagSelector,
+            target_def: tp.Union[meta.ModelDefinition, meta.FlowDefinition],
+            job_def: tp.Union[meta.RunModelJob, meta.RunFlowJob]) \
             -> GraphSection:

         # The main execution graph can run directly in the job context, no need to do a context push
         # since inputs and outputs in this context line up with the top level execution task

+        # Required / provided items are the same for RUN_MODEL and RUN_FLOW jobs
+
+        required_params = target_def.parameters
+        required_inputs = target_def.inputs
+        required_outputs = target_def.outputs
+
+        provided_params = job_def.parameters
+        provided_inputs = job_def.inputs
+        provided_outputs = job_def.outputs
+
         params_section = cls.build_job_parameters(
-            job_namespace,
+            job_namespace, required_params, provided_params,
             explicit_deps=[job_push_id])

         input_section = cls.build_job_inputs(
-            job_config, job_namespace,
+            job_config, job_namespace, required_inputs, provided_inputs,
             explicit_deps=[job_push_id])

-        exec_obj = _util.get_job_resource(
+        exec_obj = _util.get_job_resource(target_selector, job_config)

         exec_section = cls.build_model_or_flow(
             job_config, job_namespace, exec_obj,
             explicit_deps=[job_push_id])

         output_section = cls.build_job_outputs(
-            job_config, job_namespace,
+            job_config, job_namespace, required_outputs, provided_outputs,
             explicit_deps=[job_push_id])

         main_section = cls._join_sections(params_section, input_section, exec_section, output_section)
@@ -190,13 +202,22 @@ class GraphBuilder:
     @classmethod
     def build_job_parameters(
             cls, job_namespace: NodeNamespace,
-
+            required_params: tp.Dict[str, meta.ModelParameter],
+            supplied_params: tp.Dict[str, meta.Value],
             explicit_deps: tp.Optional[tp.List[NodeId]] = None) \
             -> GraphSection:

         nodes = dict()

-        for param_name,
+        for param_name, param_schema in required_params.items():
+
+            param_def = supplied_params.get(param_name)
+
+            if param_def is None:
+                if param_schema.defaultValue is not None:
+                    param_def = param_schema.defaultValue
+                else:
+                    raise _ex.EJobValidation(f"Missing required parameter: [{param_name}]")

             param_id = NodeId(param_name, job_namespace, meta.Value)
             param_node = StaticValueNode(param_id, param_def, explicit_deps=explicit_deps)
@@ -208,7 +229,8 @@
     @classmethod
     def build_job_inputs(
             cls, job_config: config.JobConfig, job_namespace: NodeNamespace,
-
+            required_inputs: tp.Dict[str, meta.ModelInputSchema],
+            supplied_inputs: tp.Dict[str, meta.TagSelector],
             explicit_deps: tp.Optional[tp.List[NodeId]] = None) \
             -> GraphSection:

@@ -216,7 +238,18 @@
         outputs = set()
         must_run = list()

-        for input_name,
+        for input_name, input_schema in required_inputs.items():
+
+            data_selector = supplied_inputs.get(input_name)
+
+            if data_selector is None:
+                if input_schema.optional:
+                    data_view_id = NodeId.of(input_name, job_namespace, _data.DataView)
+                    nodes[data_view_id] = StaticValueNode(data_view_id, _data.DataView.create_empty())
+                    outputs.add(data_view_id)
+                    continue
+                else:
+                    raise _ex.EJobValidation(f"Missing required input: [{input_name}]")

             # Build a data spec using metadata from the job config
             # For now we are always loading the root part, snap 0, delta 0
@@ -258,14 +291,24 @@
     @classmethod
     def build_job_outputs(
             cls, job_config: config.JobConfig, job_namespace: NodeNamespace,
-
+            required_outputs: tp.Dict[str, meta.ModelOutputSchema],
+            supplied_outputs: tp.Dict[str, meta.TagSelector],
             explicit_deps: tp.Optional[tp.List[NodeId]] = None) \
             -> GraphSection:

         nodes = {}
         inputs = set()

-        for output_name,
+        for output_name, output_schema in required_outputs.items():
+
+            data_selector = supplied_outputs.get(output_name)
+
+            if data_selector is None:
+                if output_schema.optional:
+                    optional_info = "(configuration is required for all optional outputs, in case they are produced)"
+                    raise _ex.EJobValidation(f"Missing optional output: [{output_name}] {optional_info}")
+                else:
+                    raise _ex.EJobValidation(f"Missing required output: [{output_name}]")

             # Output data view must already exist in the namespace
             data_view_id = NodeId.of(output_name, job_namespace, _data.DataView)
@@ -323,7 +366,8 @@

         data_result_id = NodeId.of(f"{output_name}:RESULT", job_namespace, ObjectBundle)
         data_result_node = DataResultNode(
-            data_result_id, output_name,
+            data_result_id, output_name,
+            data_item_id, data_spec_id, data_save_id,
             output_data_key, output_storage_key)

         nodes[data_spec_id] = data_spec_node
@@ -458,10 +502,10 @@
             frozenset(parameter_ids), frozenset(input_ids),
             explicit_deps=explicit_deps, bundle=model_id.namespace)

-
-        model_result_node = RunModelResultNode(
+        model_result_id = NodeId(f"{model_name}:RESULT", namespace)
+        model_result_node = RunModelResultNode(model_result_id, model_id)

-        nodes = {model_id: model_node,
+        nodes = {model_id: model_node, model_result_id: model_result_node}

         # Create nodes for each model output
         # The model node itself outputs a bundle (dictionary of named outputs)
@@ -474,7 +518,7 @@
             nodes[output_id] = BundleItemNode(output_id, model_id, output_id.name)

         # Assemble a graph to include the model and its outputs
-        return GraphSection(nodes, inputs={*parameter_ids, *input_ids}, outputs=output_ids, must_run=[
+        return GraphSection(nodes, inputs={*parameter_ids, *input_ids}, outputs=output_ids, must_run=[model_result_id])

     @classmethod
     def build_flow(
@@ -488,7 +532,7 @@

         # https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm

-
+        # Group edges by source and target node
         remaining_edges_by_target = {edge.target.node: [] for edge in flow_def.edges}
         remaining_edges_by_source = {edge.source.node: [] for edge in flow_def.edges}

@@ -496,16 +540,14 @@
             remaining_edges_by_target[edge.target.node].append(edge)
             remaining_edges_by_source[edge.source.node].append(edge)

-
-
-        # Initial set of reachable flow nodes is just the input nodes
-        for node_name, node in list(remaining_nodes.items()):
-            if node.nodeType == meta.FlowNodeType.INPUT_NODE:
-                reachable_nodes[node_name] = node
-                del remaining_nodes[node_name]
-
+        # Group edges by target socket (only one edge per target in a consistent flow)
         target_edges = {socket_key(edge.target): edge for edge in flow_def.edges}

+        # Initially parameters and inputs are reachable, everything else is not
+        def is_input(n): return n[1].nodeType in [meta.FlowNodeType.PARAMETER_NODE, meta.FlowNodeType.INPUT_NODE]
+        reachable_nodes = dict(filter(is_input, flow_def.nodes.items()))
+        remaining_nodes = dict(filter(lambda n: not is_input(n), flow_def.nodes.items()))
+
         # Initial graph section for the flow is empty
         graph_section = GraphSection({}, must_run=explicit_deps)

@@ -559,10 +601,16 @@
             return NodeId(socket_name, namespace, result_type)

         def edge_mapping(node_: str, socket_: str = None, result_type=None):
-            socket = meta.FlowSocket(node_, socket_)
-            edge = target_edges.get(
+            socket = socket_key(meta.FlowSocket(node_, socket_))
+            edge = target_edges.get(socket)
+            # Report missing edges as a job consistency error (this might happen sometimes in dev mode)
+            if edge is None:
+                raise _ex.EJobValidation(f"Inconsistent flow: Socket [{socket}] is not connected")
             return socket_id(edge.source.node, edge.source.socket, result_type)

+        if node.nodeType == meta.FlowNodeType.PARAMETER_NODE:
+            return GraphSection({}, inputs={NodeId(node_name, namespace, result_type=meta.Value)})
+
         if node.nodeType == meta.FlowNodeType.INPUT_NODE:
             return GraphSection({}, inputs={NodeId(node_name, namespace, result_type=_data.DataView)})

@@ -573,32 +621,46 @@

         if node.nodeType == meta.FlowNodeType.MODEL_NODE:

-
-
-
-            # TODO: Whether to use flow node or model_obj to build the push mapping?
+            param_mapping = {socket: edge_mapping(node_name, socket, meta.Value) for socket in node.parameters}
+            input_mapping = {socket: edge_mapping(node_name, socket, _data.DataView) for socket in node.inputs}
+            output_mapping = {socket: socket_id(node_name, socket, _data.DataView) for socket in node.outputs}

-
-
-                for input_name in model_obj.model.inputs
-            }
+            push_mapping = {**input_mapping, **param_mapping}
+            pop_mapping = output_mapping

-
-
-                for param_name in model_obj.model.parameters
-            }
+            model_selector = flow_job.models.get(node_name)
+            model_obj = _util.get_job_resource(model_selector, job_config)

-
+            # Missing models in the job config is a job consistency error
+            if model_obj is None or model_obj.objectType != meta.ObjectType.MODEL:
+                raise _ex.EJobValidation(f"No model was provided for flow node [{node_name}]")

-
-
-                for output_ in model_obj.model.outputs}
+            # Explicit check for model compatibility - report an error now, do not try build_model()
+            cls.check_model_compatibility(model_selector, model_obj.model, node_name, node)

             return cls.build_model_or_flow_with_context(
                 job_config, namespace, node_name, model_obj,
                 push_mapping, pop_mapping, explicit_deps)

-
+        # Missing / invalid node type - should be caught in static validation
+        raise _ex.ETracInternal(f"Flow node [{node_name}] has invalid node type [{node.nodeType}]")
+
+    @classmethod
+    def check_model_compatibility(
+            cls, model_selector: meta.TagSelector, model_def: meta.ModelDefinition,
+            node_name: str, flow_node: meta.FlowNode):
+
+        model_params = list(sorted(model_def.parameters.keys()))
+        model_inputs = list(sorted(model_def.inputs.keys()))
+        model_outputs = list(sorted(model_def.outputs.keys()))
+
+        node_params = list(sorted(flow_node.parameters))
+        node_inputs = list(sorted(flow_node.inputs))
+        node_outputs = list(sorted(flow_node.outputs))
+
+        if model_params != node_params or model_inputs != node_inputs or model_outputs != node_outputs:
+            model_key = _util.object_key(model_selector)
+            raise _ex.EJobValidation(f"Incompatible model for flow node [{node_name}] (Model: [{model_key}])")

     @staticmethod
     def build_context_push(