tracdap-runtime 0.6.1.dev3__py3-none-any.whl → 0.6.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tracdap/rt/_exec/context.py +25 -1
- tracdap/rt/_exec/dev_mode.py +277 -213
- tracdap/rt/_exec/functions.py +37 -8
- tracdap/rt/_exec/graph.py +2 -0
- tracdap/rt/_exec/graph_builder.py +118 -56
- tracdap/rt/_exec/runtime.py +28 -0
- tracdap/rt/_exec/server.py +68 -0
- tracdap/rt/_impl/data.py +14 -0
- tracdap/rt/_impl/grpc/__init__.py +13 -0
- tracdap/rt/_impl/grpc/codec.py +44 -0
- tracdap/rt/_impl/grpc/tracdap/api/internal/runtime_pb2.py +51 -0
- tracdap/rt/_impl/grpc/tracdap/api/internal/runtime_pb2.pyi +59 -0
- tracdap/rt/_impl/grpc/tracdap/api/internal/runtime_pb2_grpc.py +183 -0
- tracdap/rt/_impl/grpc/tracdap/config/common_pb2.py +55 -0
- tracdap/rt/_impl/grpc/tracdap/config/common_pb2.pyi +103 -0
- tracdap/rt/_impl/grpc/tracdap/config/job_pb2.py +42 -0
- tracdap/rt/_impl/grpc/tracdap/config/job_pb2.pyi +44 -0
- tracdap/rt/_impl/grpc/tracdap/config/platform_pb2.py +71 -0
- tracdap/rt/_impl/grpc/tracdap/config/platform_pb2.pyi +197 -0
- tracdap/rt/_impl/grpc/tracdap/config/result_pb2.py +37 -0
- tracdap/rt/_impl/grpc/tracdap/config/result_pb2.pyi +35 -0
- tracdap/rt/_impl/grpc/tracdap/config/runtime_pb2.py +42 -0
- tracdap/rt/_impl/grpc/tracdap/config/runtime_pb2.pyi +46 -0
- tracdap/rt/_impl/grpc/tracdap/metadata/common_pb2.py +33 -0
- tracdap/rt/_impl/grpc/tracdap/metadata/common_pb2.pyi +34 -0
- tracdap/rt/{metadata → _impl/grpc/tracdap/metadata}/custom_pb2.py +5 -5
- tracdap/rt/_impl/grpc/tracdap/metadata/custom_pb2.pyi +15 -0
- tracdap/rt/_impl/grpc/tracdap/metadata/data_pb2.py +51 -0
- tracdap/rt/_impl/grpc/tracdap/metadata/data_pb2.pyi +115 -0
- tracdap/rt/_impl/grpc/tracdap/metadata/file_pb2.py +28 -0
- tracdap/rt/_impl/grpc/tracdap/metadata/file_pb2.pyi +22 -0
- tracdap/rt/_impl/grpc/tracdap/metadata/flow_pb2.py +59 -0
- tracdap/rt/_impl/grpc/tracdap/metadata/flow_pb2.pyi +109 -0
- tracdap/rt/_impl/grpc/tracdap/metadata/job_pb2.py +76 -0
- tracdap/rt/_impl/grpc/tracdap/metadata/job_pb2.pyi +177 -0
- tracdap/rt/_impl/grpc/tracdap/metadata/model_pb2.py +51 -0
- tracdap/rt/_impl/grpc/tracdap/metadata/model_pb2.pyi +92 -0
- tracdap/rt/_impl/grpc/tracdap/metadata/object_id_pb2.py +32 -0
- tracdap/rt/_impl/grpc/tracdap/metadata/object_id_pb2.pyi +68 -0
- tracdap/rt/_impl/grpc/tracdap/metadata/object_pb2.py +35 -0
- tracdap/rt/_impl/grpc/tracdap/metadata/object_pb2.pyi +35 -0
- tracdap/rt/_impl/grpc/tracdap/metadata/search_pb2.py +39 -0
- tracdap/rt/_impl/grpc/tracdap/metadata/search_pb2.pyi +83 -0
- tracdap/rt/_impl/grpc/tracdap/metadata/stoarge_pb2.py +50 -0
- tracdap/rt/_impl/grpc/tracdap/metadata/stoarge_pb2.pyi +89 -0
- tracdap/rt/_impl/grpc/tracdap/metadata/tag_pb2.py +34 -0
- tracdap/rt/_impl/grpc/tracdap/metadata/tag_pb2.pyi +26 -0
- tracdap/rt/_impl/grpc/tracdap/metadata/tag_update_pb2.py +30 -0
- tracdap/rt/_impl/grpc/tracdap/metadata/tag_update_pb2.pyi +34 -0
- tracdap/rt/_impl/grpc/tracdap/metadata/type_pb2.py +47 -0
- tracdap/rt/_impl/grpc/tracdap/metadata/type_pb2.pyi +101 -0
- tracdap/rt/_impl/guard_rails.py +5 -6
- tracdap/rt/_impl/static_api.py +10 -6
- tracdap/rt/_version.py +1 -1
- tracdap/rt/api/hook.py +6 -2
- tracdap/rt/api/model_api.py +22 -0
- tracdap/rt/api/static_api.py +14 -4
- tracdap/rt/config/__init__.py +3 -3
- tracdap/rt/config/platform.py +9 -9
- tracdap/rt/launch/cli.py +3 -5
- tracdap/rt/launch/launch.py +15 -3
- tracdap/rt/metadata/__init__.py +15 -15
- tracdap/rt/metadata/common.py +7 -7
- tracdap/rt/metadata/custom.py +2 -0
- tracdap/rt/metadata/data.py +28 -5
- tracdap/rt/metadata/file.py +2 -0
- tracdap/rt/metadata/flow.py +66 -4
- tracdap/rt/metadata/job.py +56 -16
- tracdap/rt/metadata/model.py +4 -0
- tracdap/rt/metadata/object_id.py +9 -9
- tracdap/rt/metadata/search.py +35 -13
- tracdap/rt/metadata/stoarge.py +64 -6
- tracdap/rt/metadata/tag_update.py +21 -7
- tracdap/rt/metadata/type.py +28 -13
- {tracdap_runtime-0.6.1.dev3.dist-info → tracdap_runtime-0.6.2.dist-info}/METADATA +22 -19
- tracdap_runtime-0.6.2.dist-info/RECORD +121 -0
- {tracdap_runtime-0.6.1.dev3.dist-info → tracdap_runtime-0.6.2.dist-info}/WHEEL +1 -1
- tracdap/rt/config/common_pb2.py +0 -55
- tracdap/rt/config/job_pb2.py +0 -42
- tracdap/rt/config/platform_pb2.py +0 -71
- tracdap/rt/config/result_pb2.py +0 -37
- tracdap/rt/config/runtime_pb2.py +0 -42
- tracdap/rt/metadata/common_pb2.py +0 -33
- tracdap/rt/metadata/data_pb2.py +0 -51
- tracdap/rt/metadata/file_pb2.py +0 -28
- tracdap/rt/metadata/flow_pb2.py +0 -55
- tracdap/rt/metadata/job_pb2.py +0 -76
- tracdap/rt/metadata/model_pb2.py +0 -51
- tracdap/rt/metadata/object_id_pb2.py +0 -32
- tracdap/rt/metadata/object_pb2.py +0 -35
- tracdap/rt/metadata/search_pb2.py +0 -39
- tracdap/rt/metadata/stoarge_pb2.py +0 -50
- tracdap/rt/metadata/tag_pb2.py +0 -34
- tracdap/rt/metadata/tag_update_pb2.py +0 -30
- tracdap/rt/metadata/type_pb2.py +0 -48
- tracdap_runtime-0.6.1.dev3.dist-info/RECORD +0 -96
- {tracdap_runtime-0.6.1.dev3.dist-info → tracdap_runtime-0.6.2.dist-info}/LICENSE +0 -0
- {tracdap_runtime-0.6.1.dev3.dist-info → tracdap_runtime-0.6.2.dist-info}/top_level.txt +0 -0
tracdap/rt/_exec/functions.py
CHANGED
@@ -248,6 +248,10 @@ class DataViewFunc(NodeFunction[_data.DataView]):
|
|
248
248
|
root_item = _ctx_lookup(self.node.root_item, ctx)
|
249
249
|
root_part_key = _data.DataPartKey.for_root()
|
250
250
|
|
251
|
+
# Map empty item -> emtpy view (for optional inputs not supplied)
|
252
|
+
if root_item.is_empty():
|
253
|
+
return _data.DataView.create_empty()
|
254
|
+
|
251
255
|
data_view = _data.DataView.for_trac_schema(self.node.schema)
|
252
256
|
data_view = _data.DataMapping.add_item_to_view(data_view, root_part_key, root_item)
|
253
257
|
|
@@ -263,6 +267,10 @@ class DataItemFunc(NodeFunction[_data.DataItem]):
|
|
263
267
|
|
264
268
|
data_view = _ctx_lookup(self.node.data_view_id, ctx)
|
265
269
|
|
270
|
+
# Map empty view -> emtpy item (for optional outputs not supplied)
|
271
|
+
if data_view.is_empty():
|
272
|
+
return _data.DataItem.create_empty()
|
273
|
+
|
266
274
|
# TODO: Support selecting data item described by self.node
|
267
275
|
|
268
276
|
# Selecting data item for part-root, delta=0
|
@@ -280,6 +288,12 @@ class DataResultFunc(NodeFunction[ObjectBundle]):
|
|
280
288
|
|
281
289
|
def _execute(self, ctx: NodeContext) -> ObjectBundle:
|
282
290
|
|
291
|
+
data_item = _ctx_lookup(self.node.data_item_id, ctx)
|
292
|
+
|
293
|
+
# Do not record output metadata for optional outputs that are empty
|
294
|
+
if data_item.is_empty():
|
295
|
+
return {}
|
296
|
+
|
283
297
|
data_spec = _ctx_lookup(self.node.data_spec_id, ctx)
|
284
298
|
|
285
299
|
# TODO: Check result of save operation
|
@@ -451,6 +465,13 @@ class SaveDataFunc(NodeFunction[None], _LoadSaveDataFunc):
|
|
451
465
|
|
452
466
|
def _execute(self, ctx: NodeContext):
|
453
467
|
|
468
|
+
# Item to be saved should exist in the current context
|
469
|
+
data_item = _ctx_lookup(self.node.data_item_id, ctx)
|
470
|
+
|
471
|
+
# Do not save empty outputs (optional outputs that were not produced)
|
472
|
+
if data_item.is_empty():
|
473
|
+
return
|
474
|
+
|
454
475
|
# This function assumes that metadata has already been generated as the data_spec
|
455
476
|
# i.e. it is already known which incarnation / copy of the data will be created
|
456
477
|
|
@@ -458,9 +479,6 @@ class SaveDataFunc(NodeFunction[None], _LoadSaveDataFunc):
|
|
458
479
|
data_copy = self._choose_copy(data_spec.data_item, data_spec.storage_def)
|
459
480
|
data_storage = self.storage.get_data_storage(data_copy.storageKey)
|
460
481
|
|
461
|
-
# Item to be saved should exist in the current context
|
462
|
-
data_item = _ctx_lookup(self.node.data_item_id, ctx)
|
463
|
-
|
464
482
|
# Current implementation will always put an Arrow table in the data item
|
465
483
|
# Empty tables are allowed, so explicitly check if table is None
|
466
484
|
# Testing "if not data_item.table" will fail for empty tables
|
@@ -567,12 +585,23 @@ class RunModelFunc(NodeFunction[Bundle[_data.DataView]]):
|
|
567
585
|
msg = f"There was an unhandled error in the model: {str(e)}{details}"
|
568
586
|
raise _ex.EModelExec(msg) from e
|
569
587
|
|
570
|
-
#
|
571
|
-
|
572
|
-
|
573
|
-
|
588
|
+
# Check required outputs are present and build the results bundle
|
589
|
+
|
590
|
+
results: Bundle[_data.DataView] = dict()
|
591
|
+
|
592
|
+
for output_name, output_schema in model_def.outputs.items():
|
593
|
+
|
594
|
+
result: _data.DataView = local_ctx.get(output_name)
|
595
|
+
|
596
|
+
if result is None or result.is_empty():
|
597
|
+
if not output_schema.optional:
|
598
|
+
model_name = self.model_class.__name__
|
599
|
+
raise _ex.ERuntimeValidation(f"Missing required output [{output_name}] from model [{model_name}]")
|
600
|
+
|
601
|
+
if result is not None:
|
602
|
+
results[output_name] = result
|
574
603
|
|
575
|
-
return
|
604
|
+
return results
|
576
605
|
|
577
606
|
|
578
607
|
# ----------------------------------------------------------------------------------------------------------------------
|
tracdap/rt/_exec/graph.py
CHANGED
@@ -297,6 +297,7 @@ class DataItemNode(MappingNode[_data.DataItem]):
|
|
297
297
|
class DataResultNode(Node[ObjectBundle]):
|
298
298
|
|
299
299
|
output_name: str
|
300
|
+
data_item_id: NodeId[_data.DataItem]
|
300
301
|
data_spec_id: NodeId[_data.DataSpec]
|
301
302
|
data_save_id: NodeId[type(None)]
|
302
303
|
|
@@ -306,6 +307,7 @@ class DataResultNode(Node[ObjectBundle]):
|
|
306
307
|
def _node_dependencies(self) -> tp.Dict[NodeId, DependencyType]:
|
307
308
|
|
308
309
|
return {
|
310
|
+
self.data_item_id: DependencyType.HARD,
|
309
311
|
self.data_spec_id: DependencyType.HARD,
|
310
312
|
self.data_save_id: DependencyType.HARD}
|
311
313
|
|
@@ -14,9 +14,6 @@
|
|
14
14
|
|
15
15
|
from __future__ import annotations
|
16
16
|
|
17
|
-
import copy
|
18
|
-
import dataclasses # noqa
|
19
|
-
|
20
17
|
import tracdap.rt.config as config
|
21
18
|
import tracdap.rt.exceptions as _ex
|
22
19
|
import tracdap.rt._impl.data as _data # noqa
|
@@ -123,12 +120,14 @@ class GraphBuilder:
|
|
123
120
|
job_namespace: NodeNamespace, job_push_id: NodeId) \
|
124
121
|
-> GraphSection:
|
125
122
|
|
123
|
+
target_selector = job_config.job.runModel.model
|
124
|
+
target_obj = _util.get_job_resource(target_selector, job_config)
|
125
|
+
target_def = target_obj.model
|
126
|
+
job_def = job_config.job.runModel
|
127
|
+
|
126
128
|
return cls.build_calculation_job(
|
127
129
|
job_config, result_spec, job_namespace, job_push_id,
|
128
|
-
|
129
|
-
job_config.job.runModel.parameters,
|
130
|
-
job_config.job.runModel.inputs,
|
131
|
-
job_config.job.runModel.outputs)
|
130
|
+
target_selector, target_def, job_def)
|
132
131
|
|
133
132
|
@classmethod
|
134
133
|
def build_run_flow_job(
|
@@ -136,40 +135,53 @@ class GraphBuilder:
|
|
136
135
|
job_namespace: NodeNamespace, job_push_id: NodeId) \
|
137
136
|
-> GraphSection:
|
138
137
|
|
138
|
+
target_selector = job_config.job.runFlow.flow
|
139
|
+
target_obj = _util.get_job_resource(target_selector, job_config)
|
140
|
+
target_def = target_obj.flow
|
141
|
+
job_def = job_config.job.runFlow
|
142
|
+
|
139
143
|
return cls.build_calculation_job(
|
140
144
|
job_config, result_spec, job_namespace, job_push_id,
|
141
|
-
|
142
|
-
job_config.job.runFlow.parameters,
|
143
|
-
job_config.job.runFlow.inputs,
|
144
|
-
job_config.job.runFlow.outputs)
|
145
|
+
target_selector, target_def, job_def)
|
145
146
|
|
146
147
|
@classmethod
|
147
148
|
def build_calculation_job(
|
148
149
|
cls, job_config: config.JobConfig, result_spec: JobResultSpec,
|
149
150
|
job_namespace: NodeNamespace, job_push_id: NodeId,
|
150
|
-
|
151
|
-
|
151
|
+
target_selector: meta.TagSelector,
|
152
|
+
target_def: tp.Union[meta.ModelDefinition, meta.FlowDefinition],
|
153
|
+
job_def: tp.Union[meta.RunModelJob, meta.RunFlowJob]) \
|
152
154
|
-> GraphSection:
|
153
155
|
|
154
156
|
# The main execution graph can run directly in the job context, no need to do a context push
|
155
157
|
# since inputs and outputs in this context line up with the top level execution task
|
156
158
|
|
159
|
+
# Required / provided items are the same for RUN_MODEL and RUN_FLOW jobs
|
160
|
+
|
161
|
+
required_params = target_def.parameters
|
162
|
+
required_inputs = target_def.inputs
|
163
|
+
required_outputs = target_def.outputs
|
164
|
+
|
165
|
+
provided_params = job_def.parameters
|
166
|
+
provided_inputs = job_def.inputs
|
167
|
+
provided_outputs = job_def.outputs
|
168
|
+
|
157
169
|
params_section = cls.build_job_parameters(
|
158
|
-
job_namespace,
|
170
|
+
job_namespace, required_params, provided_params,
|
159
171
|
explicit_deps=[job_push_id])
|
160
172
|
|
161
173
|
input_section = cls.build_job_inputs(
|
162
|
-
job_config, job_namespace,
|
174
|
+
job_config, job_namespace, required_inputs, provided_inputs,
|
163
175
|
explicit_deps=[job_push_id])
|
164
176
|
|
165
|
-
exec_obj = _util.get_job_resource(
|
177
|
+
exec_obj = _util.get_job_resource(target_selector, job_config)
|
166
178
|
|
167
179
|
exec_section = cls.build_model_or_flow(
|
168
180
|
job_config, job_namespace, exec_obj,
|
169
181
|
explicit_deps=[job_push_id])
|
170
182
|
|
171
183
|
output_section = cls.build_job_outputs(
|
172
|
-
job_config, job_namespace,
|
184
|
+
job_config, job_namespace, required_outputs, provided_outputs,
|
173
185
|
explicit_deps=[job_push_id])
|
174
186
|
|
175
187
|
main_section = cls._join_sections(params_section, input_section, exec_section, output_section)
|
@@ -190,13 +202,22 @@ class GraphBuilder:
|
|
190
202
|
@classmethod
|
191
203
|
def build_job_parameters(
|
192
204
|
cls, job_namespace: NodeNamespace,
|
193
|
-
|
205
|
+
required_params: tp.Dict[str, meta.ModelParameter],
|
206
|
+
supplied_params: tp.Dict[str, meta.Value],
|
194
207
|
explicit_deps: tp.Optional[tp.List[NodeId]] = None) \
|
195
208
|
-> GraphSection:
|
196
209
|
|
197
210
|
nodes = dict()
|
198
211
|
|
199
|
-
for param_name,
|
212
|
+
for param_name, param_schema in required_params.items():
|
213
|
+
|
214
|
+
param_def = supplied_params.get(param_name)
|
215
|
+
|
216
|
+
if param_def is None:
|
217
|
+
if param_schema.defaultValue is not None:
|
218
|
+
param_def = param_schema.defaultValue
|
219
|
+
else:
|
220
|
+
raise _ex.EJobValidation(f"Missing required parameter: [{param_name}]")
|
200
221
|
|
201
222
|
param_id = NodeId(param_name, job_namespace, meta.Value)
|
202
223
|
param_node = StaticValueNode(param_id, param_def, explicit_deps=explicit_deps)
|
@@ -208,7 +229,8 @@ class GraphBuilder:
|
|
208
229
|
@classmethod
|
209
230
|
def build_job_inputs(
|
210
231
|
cls, job_config: config.JobConfig, job_namespace: NodeNamespace,
|
211
|
-
|
232
|
+
required_inputs: tp.Dict[str, meta.ModelInputSchema],
|
233
|
+
supplied_inputs: tp.Dict[str, meta.TagSelector],
|
212
234
|
explicit_deps: tp.Optional[tp.List[NodeId]] = None) \
|
213
235
|
-> GraphSection:
|
214
236
|
|
@@ -216,7 +238,18 @@ class GraphBuilder:
|
|
216
238
|
outputs = set()
|
217
239
|
must_run = list()
|
218
240
|
|
219
|
-
for input_name,
|
241
|
+
for input_name, input_schema in required_inputs.items():
|
242
|
+
|
243
|
+
data_selector = supplied_inputs.get(input_name)
|
244
|
+
|
245
|
+
if data_selector is None:
|
246
|
+
if input_schema.optional:
|
247
|
+
data_view_id = NodeId.of(input_name, job_namespace, _data.DataView)
|
248
|
+
nodes[data_view_id] = StaticValueNode(data_view_id, _data.DataView.create_empty())
|
249
|
+
outputs.add(data_view_id)
|
250
|
+
continue
|
251
|
+
else:
|
252
|
+
raise _ex.EJobValidation(f"Missing required input: [{input_name}]")
|
220
253
|
|
221
254
|
# Build a data spec using metadata from the job config
|
222
255
|
# For now we are always loading the root part, snap 0, delta 0
|
@@ -258,14 +291,24 @@ class GraphBuilder:
|
|
258
291
|
@classmethod
|
259
292
|
def build_job_outputs(
|
260
293
|
cls, job_config: config.JobConfig, job_namespace: NodeNamespace,
|
261
|
-
|
294
|
+
required_outputs: tp.Dict[str, meta.ModelOutputSchema],
|
295
|
+
supplied_outputs: tp.Dict[str, meta.TagSelector],
|
262
296
|
explicit_deps: tp.Optional[tp.List[NodeId]] = None) \
|
263
297
|
-> GraphSection:
|
264
298
|
|
265
299
|
nodes = {}
|
266
300
|
inputs = set()
|
267
301
|
|
268
|
-
for output_name,
|
302
|
+
for output_name, output_schema in required_outputs.items():
|
303
|
+
|
304
|
+
data_selector = supplied_outputs.get(output_name)
|
305
|
+
|
306
|
+
if data_selector is None:
|
307
|
+
if output_schema.optional:
|
308
|
+
optional_info = "(configuration is required for all optional outputs, in case they are produced)"
|
309
|
+
raise _ex.EJobValidation(f"Missing optional output: [{output_name}] {optional_info}")
|
310
|
+
else:
|
311
|
+
raise _ex.EJobValidation(f"Missing required output: [{output_name}]")
|
269
312
|
|
270
313
|
# Output data view must already exist in the namespace
|
271
314
|
data_view_id = NodeId.of(output_name, job_namespace, _data.DataView)
|
@@ -323,7 +366,8 @@ class GraphBuilder:
|
|
323
366
|
|
324
367
|
data_result_id = NodeId.of(f"{output_name}:RESULT", job_namespace, ObjectBundle)
|
325
368
|
data_result_node = DataResultNode(
|
326
|
-
data_result_id, output_name,
|
369
|
+
data_result_id, output_name,
|
370
|
+
data_item_id, data_spec_id, data_save_id,
|
327
371
|
output_data_key, output_storage_key)
|
328
372
|
|
329
373
|
nodes[data_spec_id] = data_spec_node
|
@@ -458,10 +502,10 @@ class GraphBuilder:
|
|
458
502
|
frozenset(parameter_ids), frozenset(input_ids),
|
459
503
|
explicit_deps=explicit_deps, bundle=model_id.namespace)
|
460
504
|
|
461
|
-
|
462
|
-
model_result_node = RunModelResultNode(
|
505
|
+
model_result_id = NodeId(f"{model_name}:RESULT", namespace)
|
506
|
+
model_result_node = RunModelResultNode(model_result_id, model_id)
|
463
507
|
|
464
|
-
nodes = {model_id: model_node,
|
508
|
+
nodes = {model_id: model_node, model_result_id: model_result_node}
|
465
509
|
|
466
510
|
# Create nodes for each model output
|
467
511
|
# The model node itself outputs a bundle (dictionary of named outputs)
|
@@ -474,7 +518,7 @@ class GraphBuilder:
|
|
474
518
|
nodes[output_id] = BundleItemNode(output_id, model_id, output_id.name)
|
475
519
|
|
476
520
|
# Assemble a graph to include the model and its outputs
|
477
|
-
return GraphSection(nodes, inputs={*parameter_ids, *input_ids}, outputs=output_ids, must_run=[
|
521
|
+
return GraphSection(nodes, inputs={*parameter_ids, *input_ids}, outputs=output_ids, must_run=[model_result_id])
|
478
522
|
|
479
523
|
@classmethod
|
480
524
|
def build_flow(
|
@@ -488,7 +532,7 @@ class GraphBuilder:
|
|
488
532
|
|
489
533
|
# https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm
|
490
534
|
|
491
|
-
|
535
|
+
# Group edges by source and target node
|
492
536
|
remaining_edges_by_target = {edge.target.node: [] for edge in flow_def.edges}
|
493
537
|
remaining_edges_by_source = {edge.source.node: [] for edge in flow_def.edges}
|
494
538
|
|
@@ -496,16 +540,14 @@ class GraphBuilder:
|
|
496
540
|
remaining_edges_by_target[edge.target.node].append(edge)
|
497
541
|
remaining_edges_by_source[edge.source.node].append(edge)
|
498
542
|
|
499
|
-
|
500
|
-
|
501
|
-
# Initial set of reachable flow nodes is just the input nodes
|
502
|
-
for node_name, node in list(remaining_nodes.items()):
|
503
|
-
if node.nodeType == meta.FlowNodeType.INPUT_NODE:
|
504
|
-
reachable_nodes[node_name] = node
|
505
|
-
del remaining_nodes[node_name]
|
506
|
-
|
543
|
+
# Group edges by target socket (only one edge per target in a consistent flow)
|
507
544
|
target_edges = {socket_key(edge.target): edge for edge in flow_def.edges}
|
508
545
|
|
546
|
+
# Initially parameters and inputs are reachable, everything else is not
|
547
|
+
def is_input(n): return n[1].nodeType in [meta.FlowNodeType.PARAMETER_NODE, meta.FlowNodeType.INPUT_NODE]
|
548
|
+
reachable_nodes = dict(filter(is_input, flow_def.nodes.items()))
|
549
|
+
remaining_nodes = dict(filter(lambda n: not is_input(n), flow_def.nodes.items()))
|
550
|
+
|
509
551
|
# Initial graph section for the flow is empty
|
510
552
|
graph_section = GraphSection({}, must_run=explicit_deps)
|
511
553
|
|
@@ -559,10 +601,16 @@ class GraphBuilder:
|
|
559
601
|
return NodeId(socket_name, namespace, result_type)
|
560
602
|
|
561
603
|
def edge_mapping(node_: str, socket_: str = None, result_type=None):
|
562
|
-
socket = meta.FlowSocket(node_, socket_)
|
563
|
-
edge = target_edges.get(
|
604
|
+
socket = socket_key(meta.FlowSocket(node_, socket_))
|
605
|
+
edge = target_edges.get(socket)
|
606
|
+
# Report missing edges as a job consistency error (this might happen sometimes in dev mode)
|
607
|
+
if edge is None:
|
608
|
+
raise _ex.EJobValidation(f"Inconsistent flow: Socket [{socket}] is not connected")
|
564
609
|
return socket_id(edge.source.node, edge.source.socket, result_type)
|
565
610
|
|
611
|
+
if node.nodeType == meta.FlowNodeType.PARAMETER_NODE:
|
612
|
+
return GraphSection({}, inputs={NodeId(node_name, namespace, result_type=meta.Value)})
|
613
|
+
|
566
614
|
if node.nodeType == meta.FlowNodeType.INPUT_NODE:
|
567
615
|
return GraphSection({}, inputs={NodeId(node_name, namespace, result_type=_data.DataView)})
|
568
616
|
|
@@ -573,32 +621,46 @@ class GraphBuilder:
|
|
573
621
|
|
574
622
|
if node.nodeType == meta.FlowNodeType.MODEL_NODE:
|
575
623
|
|
576
|
-
|
577
|
-
|
578
|
-
|
579
|
-
# TODO: Whether to use flow node or model_obj to build the push mapping?
|
624
|
+
param_mapping = {socket: edge_mapping(node_name, socket, meta.Value) for socket in node.parameters}
|
625
|
+
input_mapping = {socket: edge_mapping(node_name, socket, _data.DataView) for socket in node.inputs}
|
626
|
+
output_mapping = {socket: socket_id(node_name, socket, _data.DataView) for socket in node.outputs}
|
580
627
|
|
581
|
-
|
582
|
-
|
583
|
-
for input_name in model_obj.model.inputs
|
584
|
-
}
|
628
|
+
push_mapping = {**input_mapping, **param_mapping}
|
629
|
+
pop_mapping = output_mapping
|
585
630
|
|
586
|
-
|
587
|
-
|
588
|
-
for param_name in model_obj.model.parameters
|
589
|
-
}
|
631
|
+
model_selector = flow_job.models.get(node_name)
|
632
|
+
model_obj = _util.get_job_resource(model_selector, job_config)
|
590
633
|
|
591
|
-
|
634
|
+
# Missing models in the job config is a job consistency error
|
635
|
+
if model_obj is None or model_obj.objectType != meta.ObjectType.MODEL:
|
636
|
+
raise _ex.EJobValidation(f"No model was provided for flow node [{node_name}]")
|
592
637
|
|
593
|
-
|
594
|
-
|
595
|
-
for output_ in model_obj.model.outputs}
|
638
|
+
# Explicit check for model compatibility - report an error now, do not try build_model()
|
639
|
+
cls.check_model_compatibility(model_selector, model_obj.model, node_name, node)
|
596
640
|
|
597
641
|
return cls.build_model_or_flow_with_context(
|
598
642
|
job_config, namespace, node_name, model_obj,
|
599
643
|
push_mapping, pop_mapping, explicit_deps)
|
600
644
|
|
601
|
-
|
645
|
+
# Missing / invalid node type - should be caught in static validation
|
646
|
+
raise _ex.ETracInternal(f"Flow node [{node_name}] has invalid node type [{node.nodeType}]")
|
647
|
+
|
648
|
+
@classmethod
def check_model_compatibility(
        cls, model_selector: meta.TagSelector, model_def: meta.ModelDefinition,
        node_name: str, flow_node: meta.FlowNode):

    """Check that a model's parameter, input and output names match those declared on a flow node.

    Each group of names is sorted before comparison, so declaration order is irrelevant,
    but every name (and its multiplicity) must agree between the model and the node.

    :param model_selector: selector used to fetch the model (only used in the error message)
    :param model_def: definition of the model supplied for the flow node
    :param node_name: name of the flow node (only used in the error message)
    :param flow_node: the flow node the model is being assigned to
    :raises EJobValidation: if the parameter, input or output names differ
    """

    # Sorted name lists, grouped as (parameters, inputs, outputs)
    model_sockets = (
        sorted(model_def.parameters.keys()),
        sorted(model_def.inputs.keys()),
        sorted(model_def.outputs.keys()))

    node_sockets = (
        sorted(flow_node.parameters),
        sorted(flow_node.inputs),
        sorted(flow_node.outputs))

    if model_sockets == node_sockets:
        return

    model_key = _util.object_key(model_selector)
    raise _ex.EJobValidation(f"Incompatible model for flow node [{node_name}] (Model: [{model_key}])")
|
602
664
|
|
603
665
|
@staticmethod
|
604
666
|
def build_context_push(
|
tracdap/rt/_exec/runtime.py
CHANGED
@@ -90,6 +90,10 @@ class TracRuntime:
|
|
90
90
|
self._scratch_dir_provided = True if scratch_dir is not None else False
|
91
91
|
self._scratch_dir_persist = scratch_dir_persist
|
92
92
|
self._dev_mode = dev_mode
|
93
|
+
self._server_enabled = False
|
94
|
+
self._server_port = 0
|
95
|
+
|
96
|
+
self._pre_start_complete = False
|
93
97
|
|
94
98
|
# Top level resources
|
95
99
|
self._models: tp.Optional[_models.ModelLoader] = None
|
@@ -100,6 +104,9 @@ class TracRuntime:
|
|
100
104
|
self._engine: tp.Optional[_engine.TracEngine] = None
|
101
105
|
self._engine_event = threading.Condition()
|
102
106
|
|
107
|
+
# Runtime API server
|
108
|
+
self._server = None
|
109
|
+
|
103
110
|
self._jobs: tp.Dict[str, _RuntimeJobInfo] = dict()
|
104
111
|
|
105
112
|
# ------------------------------------------------------------------------------------------------------------------
|
@@ -152,6 +159,8 @@ class TracRuntime:
|
|
152
159
|
config_dir = self._sys_config_path.parent if self._sys_config_path is not None else None
|
153
160
|
self._sys_config = _dev_mode.DevModeTranslator.translate_sys_config(self._sys_config, config_dir)
|
154
161
|
|
162
|
+
self._pre_start_complete = True
|
163
|
+
|
155
164
|
except Exception as e:
|
156
165
|
self._handle_startup_error(e)
|
157
166
|
|
@@ -159,6 +168,10 @@ class TracRuntime:
|
|
159
168
|
|
160
169
|
try:
|
161
170
|
|
171
|
+
# Ensure pre-start has been run
|
172
|
+
if not self._pre_start_complete:
|
173
|
+
self.pre_start()
|
174
|
+
|
162
175
|
self._log.info("Starting the engine...")
|
163
176
|
|
164
177
|
self._models = _models.ModelLoader(self._sys_config, self._scratch_dir)
|
@@ -175,11 +188,26 @@ class TracRuntime:
|
|
175
188
|
|
176
189
|
self._system.start(wait=wait)
|
177
190
|
|
191
|
+
# If the runtime server has been enabled, start it up
|
192
|
+
if self._server_enabled:
|
193
|
+
|
194
|
+
self._log.info("Starting the runtime API server...")
|
195
|
+
|
196
|
+
# The server module pulls in all the gRPC dependencies, don't import it unless we have to
|
197
|
+
import tracdap.rt._exec.server as _server
|
198
|
+
|
199
|
+
self._server = _server.RuntimeApiServer(self._server_port)
|
200
|
+
self._server.start()
|
201
|
+
|
178
202
|
except Exception as e:
|
179
203
|
self._handle_startup_error(e)
|
180
204
|
|
181
205
|
def stop(self, due_to_error=False):
|
182
206
|
|
207
|
+
if self._server is not None:
|
208
|
+
self._log.info("Stopping the runtime API server...")
|
209
|
+
self._server.stop()
|
210
|
+
|
183
211
|
if due_to_error:
|
184
212
|
self._log.info("Shutting down the engine in response to an error")
|
185
213
|
else:
|
@@ -0,0 +1,68 @@
|
|
1
|
+
# Copyright 2024 Accenture Global Solutions Limited
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import typing as tp
|
16
|
+
import concurrent.futures as futures
|
17
|
+
|
18
|
+
# Imports for gRPC generated code, these are managed by build_runtime.py for distribution
|
19
|
+
import tracdap.rt._impl.grpc.tracdap.api.internal.runtime_pb2 as runtime_pb2
|
20
|
+
import tracdap.rt._impl.grpc.tracdap.api.internal.runtime_pb2_grpc as runtime_grpc
|
21
|
+
import grpc
|
22
|
+
|
23
|
+
|
24
|
+
class RuntimeApiServer(runtime_grpc.TracRuntimeApiServicer):

    """gRPC server hosting the TRAC runtime API servicer.

    The service methods below delegate to the generated base servicer (in standard grpc
    code generation that base reports calls as UNIMPLEMENTED - confirm against
    runtime_pb2_grpc). start() binds a plaintext (insecure) port on all interfaces and
    serves requests from a private thread pool; stop() shuts both down again.
    """

    __THREAD_POOL_DEFAULT_SIZE = 2
    __THREAD_NAME_PREFIX = "server-"
    __DEFAULT_SHUTDOWN_TIMEOUT = 10.0  # seconds

    def __init__(self, port: int, n_workers: int = None):

        """
        :param port: TCP port to listen on, bound on all interfaces ("[::]")
        :param n_workers: size of the worker thread pool; None / 0 fall back to the default (2)
        """

        self.__port = port
        self.__n_workers = n_workers or self.__THREAD_POOL_DEFAULT_SIZE
        self.__server: tp.Optional[grpc.Server] = None
        self.__thread_pool: tp.Optional[futures.ThreadPoolExecutor] = None

    def listJobs(self, request, context):
        return super().listJobs(request, context)

    def getJobStatus(self, request: runtime_pb2.BatchJobStatusRequest, context: grpc.ServicerContext):
        return super().getJobStatus(request, context)

    def getJobDetails(self, request, context):
        return super().getJobDetails(request, context)

    def start(self):

        """Create the worker pool and gRPC server, bind the port and start serving.

        :raises RuntimeError: if the port could not be bound
        """

        thread_pool = futures.ThreadPoolExecutor(
            max_workers=self.__n_workers,
            thread_name_prefix=self.__THREAD_NAME_PREFIX)

        server = grpc.server(thread_pool)

        socket = f"[::]:{self.__port}"
        bound_port = server.add_insecure_port(socket)

        # Some grpcio versions report a failed bind by returning 0 rather than raising;
        # fail loudly instead of running a server that listens nowhere
        if bound_port == 0:
            thread_pool.shutdown(wait=False)
            raise RuntimeError(f"Runtime API server could not bind to [{socket}]")

        runtime_grpc.add_TracRuntimeApiServicer_to_server(self, server)

        self.__thread_pool = thread_pool
        self.__server = server
        self.__server.start()

    def stop(self, shutdown_timeout: float = None):

        """Stop the server, then release its worker pool. Safe to call more than once.

        :param shutdown_timeout: grace period in seconds for in-flight calls;
            None uses the default (10 seconds), 0 stops without a grace period
        """

        # Explicit None check so that a caller asking for 0 gets 0, not the default
        grace = self.__DEFAULT_SHUTDOWN_TIMEOUT if shutdown_timeout is None else shutdown_timeout

        if self.__server is not None:
            # Server.stop() returns immediately with an event that is set once the server has
            # fully stopped; wait for it so no call is still using the pool when it shuts down
            self.__server.stop(grace).wait()
            self.__server = None

        if self.__thread_pool is not None:
            self.__thread_pool.shutdown()
            self.__thread_pool = None
|
tracdap/rt/_impl/data.py
CHANGED
@@ -58,6 +58,13 @@ class DataItem:
|
|
58
58
|
pandas: tp.Optional[pd.DataFrame] = None
|
59
59
|
pyspark: tp.Any = None
|
60
60
|
|
61
|
+
def is_empty(self) -> bool:
    """Return True if this item carries no data in any of its representations.

    An item is empty when it has no Arrow table, no record batches (None or an empty list),
    and no pandas or pyspark payload. Previously the pandas / pyspark fields were ignored,
    so an item holding only one of those would have been treated as empty and dropped.
    """
    # Use "is None" for the dataframe fields - pandas DataFrame truthiness raises an error
    return (
        self.table is None
        and (self.batches is None or len(self.batches) == 0)
        and self.pandas is None
        and self.pyspark is None)
|
63
|
+
|
64
|
+
@staticmethod
def create_empty() -> DataItem:
    """Build a placeholder item, e.g. to stand in for an optional dataset that was not supplied.

    Only an Arrow schema with no fields is passed; every other field keeps its default value.
    """
    no_columns = pa.schema([])
    return DataItem(no_columns)
|
67
|
+
|
61
68
|
|
62
69
|
@dc.dataclass(frozen=True)
|
63
70
|
class DataView:
|
@@ -72,6 +79,13 @@ class DataView:
|
|
72
79
|
arrow_schema = DataMapping.trac_to_arrow_schema(trac_schema)
|
73
80
|
return DataView(trac_schema, arrow_schema, dict())
|
74
81
|
|
82
|
+
def is_empty(self) -> bool:
    """Return True when the view has no parts (parts is None or an empty collection)."""
    if self.parts is None:
        return True
    return len(self.parts) == 0
|
84
|
+
|
85
|
+
@staticmethod
def create_empty() -> DataView:
    """Build a placeholder view, e.g. to stand in for an optional input that was not supplied.

    The view gets a blank TRAC schema definition, an Arrow schema with no fields and an
    empty dict as its remaining field (presumably the parts map - confirm against the dataclass).
    """
    trac_schema = _meta.SchemaDefinition()
    arrow_schema = pa.schema([])
    return DataView(trac_schema, arrow_schema, {})
|
88
|
+
|
75
89
|
|
76
90
|
class _DataInternal:
|
77
91
|
pass
|
@@ -0,0 +1,13 @@
|
|
1
|
+
# Copyright 2024 Accenture Global Solutions Limited
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
@@ -0,0 +1,44 @@
|
|
1
|
+
# Copyright 2024 Accenture Global Solutions Limited
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import enum
|
16
|
+
import typing as tp
|
17
|
+
|
18
|
+
|
19
|
+
def encode(obj: tp.Any) -> tp.Any:

    """Translate TRAC domain objects into generic dict / list / builtin structures.

    The result can be passed to gRPC message constructors; this function does not try to
    build messages directly. Shallow copies and builtins are used to minimize the cost.

    :param obj: None, a str / bool / int / float / bytes value, an enum member, a list, tuple
        or dict of encodable values, or an object whose class lives in a "tracdap" module
    :return: the encoded value - enum members become their ``.value``, lists and tuples become
        lists, dicts and TRAC objects become dicts (TRAC objects via their instance attributes)
    :raises RuntimeError: if obj, or any value nested inside it, cannot be encoded
    """

    if obj is None:
        return None

    # Enums must be checked before scalars: IntEnum / str-based enums are also int / str
    # instances and would otherwise be passed through as enum members instead of their values
    if isinstance(obj, enum.Enum):
        return obj.value

    if isinstance(obj, (str, bool, int, float, bytes)):
        return obj

    if isinstance(obj, (list, tuple)):
        return [encode(item) for item in obj]

    if isinstance(obj, dict):
        return {key: encode(value) for key, value in obj.items()}

    # Filter classes for TRAC domain objects (sanity check, not a watertight validation)
    # A non-string __module__ is treated as "not TRAC" rather than failing with a TypeError
    module = getattr(obj, "__module__", None)

    if isinstance(module, str) and "tracdap" in module:
        return {key: encode(value) for key, value in obj.__dict__.items()}

    raise RuntimeError(f"Cannot encode object of type [{type(obj).__name__}] for gRPC")
|