tracdap-runtime 0.6.1.dev3__py3-none-any.whl → 0.6.2__py3-none-any.whl

This diff compares two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions exactly as they appear in the public registry.
Files changed (98)
  1. tracdap/rt/_exec/context.py +25 -1
  2. tracdap/rt/_exec/dev_mode.py +277 -213
  3. tracdap/rt/_exec/functions.py +37 -8
  4. tracdap/rt/_exec/graph.py +2 -0
  5. tracdap/rt/_exec/graph_builder.py +118 -56
  6. tracdap/rt/_exec/runtime.py +28 -0
  7. tracdap/rt/_exec/server.py +68 -0
  8. tracdap/rt/_impl/data.py +14 -0
  9. tracdap/rt/_impl/grpc/__init__.py +13 -0
  10. tracdap/rt/_impl/grpc/codec.py +44 -0
  11. tracdap/rt/_impl/grpc/tracdap/api/internal/runtime_pb2.py +51 -0
  12. tracdap/rt/_impl/grpc/tracdap/api/internal/runtime_pb2.pyi +59 -0
  13. tracdap/rt/_impl/grpc/tracdap/api/internal/runtime_pb2_grpc.py +183 -0
  14. tracdap/rt/_impl/grpc/tracdap/config/common_pb2.py +55 -0
  15. tracdap/rt/_impl/grpc/tracdap/config/common_pb2.pyi +103 -0
  16. tracdap/rt/_impl/grpc/tracdap/config/job_pb2.py +42 -0
  17. tracdap/rt/_impl/grpc/tracdap/config/job_pb2.pyi +44 -0
  18. tracdap/rt/_impl/grpc/tracdap/config/platform_pb2.py +71 -0
  19. tracdap/rt/_impl/grpc/tracdap/config/platform_pb2.pyi +197 -0
  20. tracdap/rt/_impl/grpc/tracdap/config/result_pb2.py +37 -0
  21. tracdap/rt/_impl/grpc/tracdap/config/result_pb2.pyi +35 -0
  22. tracdap/rt/_impl/grpc/tracdap/config/runtime_pb2.py +42 -0
  23. tracdap/rt/_impl/grpc/tracdap/config/runtime_pb2.pyi +46 -0
  24. tracdap/rt/_impl/grpc/tracdap/metadata/common_pb2.py +33 -0
  25. tracdap/rt/_impl/grpc/tracdap/metadata/common_pb2.pyi +34 -0
  26. tracdap/rt/{metadata → _impl/grpc/tracdap/metadata}/custom_pb2.py +5 -5
  27. tracdap/rt/_impl/grpc/tracdap/metadata/custom_pb2.pyi +15 -0
  28. tracdap/rt/_impl/grpc/tracdap/metadata/data_pb2.py +51 -0
  29. tracdap/rt/_impl/grpc/tracdap/metadata/data_pb2.pyi +115 -0
  30. tracdap/rt/_impl/grpc/tracdap/metadata/file_pb2.py +28 -0
  31. tracdap/rt/_impl/grpc/tracdap/metadata/file_pb2.pyi +22 -0
  32. tracdap/rt/_impl/grpc/tracdap/metadata/flow_pb2.py +59 -0
  33. tracdap/rt/_impl/grpc/tracdap/metadata/flow_pb2.pyi +109 -0
  34. tracdap/rt/_impl/grpc/tracdap/metadata/job_pb2.py +76 -0
  35. tracdap/rt/_impl/grpc/tracdap/metadata/job_pb2.pyi +177 -0
  36. tracdap/rt/_impl/grpc/tracdap/metadata/model_pb2.py +51 -0
  37. tracdap/rt/_impl/grpc/tracdap/metadata/model_pb2.pyi +92 -0
  38. tracdap/rt/_impl/grpc/tracdap/metadata/object_id_pb2.py +32 -0
  39. tracdap/rt/_impl/grpc/tracdap/metadata/object_id_pb2.pyi +68 -0
  40. tracdap/rt/_impl/grpc/tracdap/metadata/object_pb2.py +35 -0
  41. tracdap/rt/_impl/grpc/tracdap/metadata/object_pb2.pyi +35 -0
  42. tracdap/rt/_impl/grpc/tracdap/metadata/search_pb2.py +39 -0
  43. tracdap/rt/_impl/grpc/tracdap/metadata/search_pb2.pyi +83 -0
  44. tracdap/rt/_impl/grpc/tracdap/metadata/stoarge_pb2.py +50 -0
  45. tracdap/rt/_impl/grpc/tracdap/metadata/stoarge_pb2.pyi +89 -0
  46. tracdap/rt/_impl/grpc/tracdap/metadata/tag_pb2.py +34 -0
  47. tracdap/rt/_impl/grpc/tracdap/metadata/tag_pb2.pyi +26 -0
  48. tracdap/rt/_impl/grpc/tracdap/metadata/tag_update_pb2.py +30 -0
  49. tracdap/rt/_impl/grpc/tracdap/metadata/tag_update_pb2.pyi +34 -0
  50. tracdap/rt/_impl/grpc/tracdap/metadata/type_pb2.py +47 -0
  51. tracdap/rt/_impl/grpc/tracdap/metadata/type_pb2.pyi +101 -0
  52. tracdap/rt/_impl/guard_rails.py +5 -6
  53. tracdap/rt/_impl/static_api.py +10 -6
  54. tracdap/rt/_version.py +1 -1
  55. tracdap/rt/api/hook.py +6 -2
  56. tracdap/rt/api/model_api.py +22 -0
  57. tracdap/rt/api/static_api.py +14 -4
  58. tracdap/rt/config/__init__.py +3 -3
  59. tracdap/rt/config/platform.py +9 -9
  60. tracdap/rt/launch/cli.py +3 -5
  61. tracdap/rt/launch/launch.py +15 -3
  62. tracdap/rt/metadata/__init__.py +15 -15
  63. tracdap/rt/metadata/common.py +7 -7
  64. tracdap/rt/metadata/custom.py +2 -0
  65. tracdap/rt/metadata/data.py +28 -5
  66. tracdap/rt/metadata/file.py +2 -0
  67. tracdap/rt/metadata/flow.py +66 -4
  68. tracdap/rt/metadata/job.py +56 -16
  69. tracdap/rt/metadata/model.py +4 -0
  70. tracdap/rt/metadata/object_id.py +9 -9
  71. tracdap/rt/metadata/search.py +35 -13
  72. tracdap/rt/metadata/stoarge.py +64 -6
  73. tracdap/rt/metadata/tag_update.py +21 -7
  74. tracdap/rt/metadata/type.py +28 -13
  75. {tracdap_runtime-0.6.1.dev3.dist-info → tracdap_runtime-0.6.2.dist-info}/METADATA +22 -19
  76. tracdap_runtime-0.6.2.dist-info/RECORD +121 -0
  77. {tracdap_runtime-0.6.1.dev3.dist-info → tracdap_runtime-0.6.2.dist-info}/WHEEL +1 -1
  78. tracdap/rt/config/common_pb2.py +0 -55
  79. tracdap/rt/config/job_pb2.py +0 -42
  80. tracdap/rt/config/platform_pb2.py +0 -71
  81. tracdap/rt/config/result_pb2.py +0 -37
  82. tracdap/rt/config/runtime_pb2.py +0 -42
  83. tracdap/rt/metadata/common_pb2.py +0 -33
  84. tracdap/rt/metadata/data_pb2.py +0 -51
  85. tracdap/rt/metadata/file_pb2.py +0 -28
  86. tracdap/rt/metadata/flow_pb2.py +0 -55
  87. tracdap/rt/metadata/job_pb2.py +0 -76
  88. tracdap/rt/metadata/model_pb2.py +0 -51
  89. tracdap/rt/metadata/object_id_pb2.py +0 -32
  90. tracdap/rt/metadata/object_pb2.py +0 -35
  91. tracdap/rt/metadata/search_pb2.py +0 -39
  92. tracdap/rt/metadata/stoarge_pb2.py +0 -50
  93. tracdap/rt/metadata/tag_pb2.py +0 -34
  94. tracdap/rt/metadata/tag_update_pb2.py +0 -30
  95. tracdap/rt/metadata/type_pb2.py +0 -48
  96. tracdap_runtime-0.6.1.dev3.dist-info/RECORD +0 -96
  97. {tracdap_runtime-0.6.1.dev3.dist-info → tracdap_runtime-0.6.2.dist-info}/LICENSE +0 -0
  98. {tracdap_runtime-0.6.1.dev3.dist-info → tracdap_runtime-0.6.2.dist-info}/top_level.txt +0 -0
tracdap/rt/_exec/functions.py CHANGED
@@ -248,6 +248,10 @@ class DataViewFunc(NodeFunction[_data.DataView]):
  root_item = _ctx_lookup(self.node.root_item, ctx)
  root_part_key = _data.DataPartKey.for_root()

+ # Map empty item -> empty view (for optional inputs not supplied)
+ if root_item.is_empty():
+     return _data.DataView.create_empty()
+
  data_view = _data.DataView.for_trac_schema(self.node.schema)
  data_view = _data.DataMapping.add_item_to_view(data_view, root_part_key, root_item)

@@ -263,6 +267,10 @@ class DataItemFunc(NodeFunction[_data.DataItem]):

  data_view = _ctx_lookup(self.node.data_view_id, ctx)

+ # Map empty view -> empty item (for optional outputs not supplied)
+ if data_view.is_empty():
+     return _data.DataItem.create_empty()
+
  # TODO: Support selecting data item described by self.node

  # Selecting data item for part-root, delta=0
@@ -280,6 +288,12 @@ class DataResultFunc(NodeFunction[ObjectBundle]):

  def _execute(self, ctx: NodeContext) -> ObjectBundle:

+     data_item = _ctx_lookup(self.node.data_item_id, ctx)
+
+     # Do not record output metadata for optional outputs that are empty
+     if data_item.is_empty():
+         return {}
+
      data_spec = _ctx_lookup(self.node.data_spec_id, ctx)

      # TODO: Check result of save operation
@@ -451,6 +465,13 @@ class SaveDataFunc(NodeFunction[None], _LoadSaveDataFunc):

  def _execute(self, ctx: NodeContext):

+     # Item to be saved should exist in the current context
+     data_item = _ctx_lookup(self.node.data_item_id, ctx)
+
+     # Do not save empty outputs (optional outputs that were not produced)
+     if data_item.is_empty():
+         return
+
      # This function assumes that metadata has already been generated as the data_spec
      # i.e. it is already known which incarnation / copy of the data will be created

@@ -458,9 +479,6 @@ class SaveDataFunc(NodeFunction[None], _LoadSaveDataFunc):
  data_copy = self._choose_copy(data_spec.data_item, data_spec.storage_def)
  data_storage = self.storage.get_data_storage(data_copy.storageKey)

- # Item to be saved should exist in the current context
- data_item = _ctx_lookup(self.node.data_item_id, ctx)
-
  # Current implementation will always put an Arrow table in the data item
  # Empty tables are allowed, so explicitly check if table is None
  # Testing "if not data_item.table" will fail for empty tables
@@ -567,12 +585,23 @@ class RunModelFunc(NodeFunction[Bundle[_data.DataView]]):
  msg = f"There was an unhandled error in the model: {str(e)}{details}"
  raise _ex.EModelExec(msg) from e

- # The node result is just the model outputs taken from the local context
- model_outputs: Bundle[_data.DataView] = {
-     name: obj for name, obj in local_ctx.items()
-     if name in self.node.model_def.outputs}
+ # Check required outputs are present and build the results bundle
+
+ results: Bundle[_data.DataView] = dict()
+
+ for output_name, output_schema in model_def.outputs.items():
+
+     result: _data.DataView = local_ctx.get(output_name)
+
+     if result is None or result.is_empty():
+         if not output_schema.optional:
+             model_name = self.model_class.__name__
+             raise _ex.ERuntimeValidation(f"Missing required output [{output_name}] from model [{model_name}]")
+
+     if result is not None:
+         results[output_name] = result

- return model_outputs
+ return results


  # ----------------------------------------------------------------------------------------------------------------------
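The new output check distinguishes required outputs (missing ones raise ERuntimeValidation) from optional outputs (missing ones are simply dropped from the results bundle). A minimal standalone sketch of the same pattern, using simplified stand-in types rather than the real Bundle[_data.DataView] machinery:

    # Sketch of the required-output check, with stand-in types (not the real TRAC classes)
    import dataclasses as dc
    import typing as tp

    @dc.dataclass
    class OutputSchema:
        optional: bool = False

    def collect_outputs(
            declared: tp.Dict[str, OutputSchema],
            local_ctx: tp.Dict[str, tp.Any]) -> tp.Dict[str, tp.Any]:

        results = {}

        for name, schema in declared.items():
            result = local_ctx.get(name)
            if result is None:
                # Required outputs must be present, optional outputs are simply dropped
                if not schema.optional:
                    raise RuntimeError(f"Missing required output [{name}]")
            else:
                results[name] = result

        return results

    # An optional output the model never produced is omitted without error
    outputs = collect_outputs(
        {"main_result": OutputSchema(), "debug_dump": OutputSchema(optional=True)},
        {"main_result": object()})

    assert "debug_dump" not in outputs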
tracdap/rt/_exec/graph.py CHANGED
@@ -297,6 +297,7 @@ class DataItemNode(MappingNode[_data.DataItem]):
  class DataResultNode(Node[ObjectBundle]):

      output_name: str
+     data_item_id: NodeId[_data.DataItem]
      data_spec_id: NodeId[_data.DataSpec]
      data_save_id: NodeId[type(None)]

@@ -306,6 +307,7 @@ class DataResultNode(Node[ObjectBundle]):
  def _node_dependencies(self) -> tp.Dict[NodeId, DependencyType]:

      return {
+         self.data_item_id: DependencyType.HARD,
          self.data_spec_id: DependencyType.HARD,
          self.data_save_id: DependencyType.HARD}

tracdap/rt/_exec/graph_builder.py CHANGED
@@ -14,9 +14,6 @@

  from __future__ import annotations

- import copy
- import dataclasses  # noqa
-
  import tracdap.rt.config as config
  import tracdap.rt.exceptions as _ex
  import tracdap.rt._impl.data as _data  # noqa
@@ -123,12 +120,14 @@ class GraphBuilder:
      job_namespace: NodeNamespace, job_push_id: NodeId) \
      -> GraphSection:

+     target_selector = job_config.job.runModel.model
+     target_obj = _util.get_job_resource(target_selector, job_config)
+     target_def = target_obj.model
+     job_def = job_config.job.runModel
+
      return cls.build_calculation_job(
          job_config, result_spec, job_namespace, job_push_id,
-         job_config.job.runModel.model,
-         job_config.job.runModel.parameters,
-         job_config.job.runModel.inputs,
-         job_config.job.runModel.outputs)
+         target_selector, target_def, job_def)

  @classmethod
  def build_run_flow_job(
@@ -136,40 +135,53 @@ class GraphBuilder:
      job_namespace: NodeNamespace, job_push_id: NodeId) \
      -> GraphSection:

+     target_selector = job_config.job.runFlow.flow
+     target_obj = _util.get_job_resource(target_selector, job_config)
+     target_def = target_obj.flow
+     job_def = job_config.job.runFlow
+
      return cls.build_calculation_job(
          job_config, result_spec, job_namespace, job_push_id,
-         job_config.job.runFlow.flow,
-         job_config.job.runFlow.parameters,
-         job_config.job.runFlow.inputs,
-         job_config.job.runFlow.outputs)
+         target_selector, target_def, job_def)

  @classmethod
  def build_calculation_job(
      cls, job_config: config.JobConfig, result_spec: JobResultSpec,
      job_namespace: NodeNamespace, job_push_id: NodeId,
-     target: meta.TagSelector, parameters: tp.Dict[str, meta.Value],
-     inputs: tp.Dict[str, meta.TagSelector], outputs: tp.Dict[str, meta.TagSelector]) \
+     target_selector: meta.TagSelector,
+     target_def: tp.Union[meta.ModelDefinition, meta.FlowDefinition],
+     job_def: tp.Union[meta.RunModelJob, meta.RunFlowJob]) \
      -> GraphSection:

      # The main execution graph can run directly in the job context, no need to do a context push
      # since inputs and outputs in this context line up with the top level execution task

+     # Required / provided items are the same for RUN_MODEL and RUN_FLOW jobs
+
+     required_params = target_def.parameters
+     required_inputs = target_def.inputs
+     required_outputs = target_def.outputs
+
+     provided_params = job_def.parameters
+     provided_inputs = job_def.inputs
+     provided_outputs = job_def.outputs
+
      params_section = cls.build_job_parameters(
-         job_namespace, parameters,
+         job_namespace, required_params, provided_params,
          explicit_deps=[job_push_id])

      input_section = cls.build_job_inputs(
-         job_config, job_namespace, inputs,
+         job_config, job_namespace, required_inputs, provided_inputs,
          explicit_deps=[job_push_id])

-     exec_obj = _util.get_job_resource(target, job_config)
+     exec_obj = _util.get_job_resource(target_selector, job_config)

      exec_section = cls.build_model_or_flow(
          job_config, job_namespace, exec_obj,
          explicit_deps=[job_push_id])

      output_section = cls.build_job_outputs(
-         job_config, job_namespace, outputs,
+         job_config, job_namespace, required_outputs, provided_outputs,
          explicit_deps=[job_push_id])

      main_section = cls._join_sections(params_section, input_section, exec_section, output_section)
@@ -190,13 +202,22 @@ class GraphBuilder:
  @classmethod
  def build_job_parameters(
      cls, job_namespace: NodeNamespace,
-     parameters: tp.Dict[str, meta.Value],
+     required_params: tp.Dict[str, meta.ModelParameter],
+     supplied_params: tp.Dict[str, meta.Value],
      explicit_deps: tp.Optional[tp.List[NodeId]] = None) \
      -> GraphSection:

      nodes = dict()

-     for param_name, param_def in parameters.items():
+     for param_name, param_schema in required_params.items():
+
+         param_def = supplied_params.get(param_name)
+
+         if param_def is None:
+             if param_schema.defaultValue is not None:
+                 param_def = param_schema.defaultValue
+             else:
+                 raise _ex.EJobValidation(f"Missing required parameter: [{param_name}]")

          param_id = NodeId(param_name, job_namespace, meta.Value)
          param_node = StaticValueNode(param_id, param_def, explicit_deps=explicit_deps)
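The loop above resolves each declared parameter in order of precedence: a supplied value wins, a schema default fills the gap, and anything else is a job validation failure. A self-contained sketch of the same rule, with a stand-in for meta.ModelParameter (the real code wraps each resolved value in a StaticValueNode rather than returning a dict):

    # Sketch of the parameter fallback rule, with a stand-in for meta.ModelParameter
    import dataclasses as dc
    import typing as tp

    @dc.dataclass
    class ParamSchema:
        defaultValue: tp.Optional[object] = None

    def resolve_params(
            required: tp.Dict[str, ParamSchema],
            supplied: tp.Dict[str, object]) -> tp.Dict[str, object]:

        resolved = {}

        for name, schema in required.items():
            value = supplied.get(name)
            if value is None:
                if schema.defaultValue is not None:
                    value = schema.defaultValue  # fall back to the model's default
                else:
                    raise ValueError(f"Missing required parameter: [{name}]")
            resolved[name] = value

        return resolved

    # A parameter with a default may be omitted, one without a default may not
    params = resolve_params(
        {"cutoff": ParamSchema(defaultValue=0.5), "region": ParamSchema()},
        {"region": "EU"})

    assert params == {"cutoff": 0.5, "region": "EU"}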
@@ -208,7 +229,8 @@ class GraphBuilder:
  @classmethod
  def build_job_inputs(
      cls, job_config: config.JobConfig, job_namespace: NodeNamespace,
-     inputs: tp.Dict[str, meta.TagSelector],
+     required_inputs: tp.Dict[str, meta.ModelInputSchema],
+     supplied_inputs: tp.Dict[str, meta.TagSelector],
      explicit_deps: tp.Optional[tp.List[NodeId]] = None) \
      -> GraphSection:

@@ -216,7 +238,18 @@ class GraphBuilder:
      outputs = set()
      must_run = list()

-     for input_name, data_selector in inputs.items():
+     for input_name, input_schema in required_inputs.items():
+
+         data_selector = supplied_inputs.get(input_name)
+
+         if data_selector is None:
+             if input_schema.optional:
+                 data_view_id = NodeId.of(input_name, job_namespace, _data.DataView)
+                 nodes[data_view_id] = StaticValueNode(data_view_id, _data.DataView.create_empty())
+                 outputs.add(data_view_id)
+                 continue
+             else:
+                 raise _ex.EJobValidation(f"Missing required input: [{input_name}]")

          # Build a data spec using metadata from the job config
          # For now we are always loading the root part, snap 0, delta 0
@@ -258,14 +291,24 @@ class GraphBuilder:
  @classmethod
  def build_job_outputs(
      cls, job_config: config.JobConfig, job_namespace: NodeNamespace,
-     outputs: tp.Dict[str, meta.TagSelector],
+     required_outputs: tp.Dict[str, meta.ModelOutputSchema],
+     supplied_outputs: tp.Dict[str, meta.TagSelector],
      explicit_deps: tp.Optional[tp.List[NodeId]] = None) \
      -> GraphSection:

      nodes = {}
      inputs = set()

-     for output_name, data_selector in outputs.items():
+     for output_name, output_schema in required_outputs.items():
+
+         data_selector = supplied_outputs.get(output_name)
+
+         if data_selector is None:
+             if output_schema.optional:
+                 optional_info = "(configuration is required for all optional outputs, in case they are produced)"
+                 raise _ex.EJobValidation(f"Missing optional output: [{output_name}] {optional_info}")
+             else:
+                 raise _ex.EJobValidation(f"Missing required output: [{output_name}]")

          # Output data view must already exist in the namespace
          data_view_id = NodeId.of(output_name, job_namespace, _data.DataView)
@@ -323,7 +366,8 @@ class GraphBuilder:

  data_result_id = NodeId.of(f"{output_name}:RESULT", job_namespace, ObjectBundle)
  data_result_node = DataResultNode(
-     data_result_id, output_name, data_spec_id, data_save_id,
+     data_result_id, output_name,
+     data_item_id, data_spec_id, data_save_id,
      output_data_key, output_storage_key)

  nodes[data_spec_id] = data_spec_node
@@ -458,10 +502,10 @@ class GraphBuilder:
      frozenset(parameter_ids), frozenset(input_ids),
      explicit_deps=explicit_deps, bundle=model_id.namespace)

- module_result_id = NodeId(f"{model_name}:RESULT", namespace)
- model_result_node = RunModelResultNode(module_result_id, model_id)
+ model_result_id = NodeId(f"{model_name}:RESULT", namespace)
+ model_result_node = RunModelResultNode(model_result_id, model_id)

- nodes = {model_id: model_node, module_result_id: model_result_node}
+ nodes = {model_id: model_node, model_result_id: model_result_node}

  # Create nodes for each model output
  # The model node itself outputs a bundle (dictionary of named outputs)
@@ -474,7 +518,7 @@ class GraphBuilder:
  nodes[output_id] = BundleItemNode(output_id, model_id, output_id.name)

  # Assemble a graph to include the model and its outputs
- return GraphSection(nodes, inputs={*parameter_ids, *input_ids}, outputs=output_ids, must_run=[module_result_id])
+ return GraphSection(nodes, inputs={*parameter_ids, *input_ids}, outputs=output_ids, must_run=[model_result_id])

  @classmethod
  def build_flow(
@@ -488,7 +532,7 @@ class GraphBuilder:

  # https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm

- remaining_nodes = copy.copy(flow_def.nodes)
+ # Group edges by source and target node
  remaining_edges_by_target = {edge.target.node: [] for edge in flow_def.edges}
  remaining_edges_by_source = {edge.source.node: [] for edge in flow_def.edges}

@@ -496,16 +540,14 @@ class GraphBuilder:
  remaining_edges_by_target[edge.target.node].append(edge)
  remaining_edges_by_source[edge.source.node].append(edge)

- reachable_nodes = dict()
-
- # Initial set of reachable flow nodes is just the input nodes
- for node_name, node in list(remaining_nodes.items()):
-     if node.nodeType == meta.FlowNodeType.INPUT_NODE:
-         reachable_nodes[node_name] = node
-         del remaining_nodes[node_name]
-
+ # Group edges by target socket (only one edge per target in a consistent flow)
  target_edges = {socket_key(edge.target): edge for edge in flow_def.edges}

+ # Initially parameters and inputs are reachable, everything else is not
+ def is_input(n): return n[1].nodeType in [meta.FlowNodeType.PARAMETER_NODE, meta.FlowNodeType.INPUT_NODE]
+ reachable_nodes = dict(filter(is_input, flow_def.nodes.items()))
+ remaining_nodes = dict(filter(lambda n: not is_input(n), flow_def.nodes.items()))
+
  # Initial graph section for the flow is empty
  graph_section = GraphSection({}, must_run=explicit_deps)

@@ -559,10 +601,16 @@ class GraphBuilder:
      return NodeId(socket_name, namespace, result_type)

  def edge_mapping(node_: str, socket_: str = None, result_type=None):
-     socket = meta.FlowSocket(node_, socket_)
-     edge = target_edges.get(socket_key(socket))  # todo: inconsistent if missing
+     socket = socket_key(meta.FlowSocket(node_, socket_))
+     edge = target_edges.get(socket)
+     # Report missing edges as a job consistency error (this might happen sometimes in dev mode)
+     if edge is None:
+         raise _ex.EJobValidation(f"Inconsistent flow: Socket [{socket}] is not connected")
      return socket_id(edge.source.node, edge.source.socket, result_type)

+ if node.nodeType == meta.FlowNodeType.PARAMETER_NODE:
+     return GraphSection({}, inputs={NodeId(node_name, namespace, result_type=meta.Value)})
+
  if node.nodeType == meta.FlowNodeType.INPUT_NODE:
      return GraphSection({}, inputs={NodeId(node_name, namespace, result_type=_data.DataView)})

@@ -573,32 +621,46 @@ class GraphBuilder:

  if node.nodeType == meta.FlowNodeType.MODEL_NODE:

-     model_selector = flow_job.models.get(node_name)
-     model_obj = _util.get_job_resource(model_selector, job_config)
-
-     # TODO: Whether to use flow node or model_obj to build the push mapping?
+     param_mapping = {socket: edge_mapping(node_name, socket, meta.Value) for socket in node.parameters}
+     input_mapping = {socket: edge_mapping(node_name, socket, _data.DataView) for socket in node.inputs}
+     output_mapping = {socket: socket_id(node_name, socket, _data.DataView) for socket in node.outputs}

-     input_mapping_ = {
-         input_name: edge_mapping(node_name, input_name, _data.DataView)
-         for input_name in model_obj.model.inputs
-     }
+     push_mapping = {**input_mapping, **param_mapping}
+     pop_mapping = output_mapping

-     param_mapping_ = {
-         param_name: NodeId(param_name, namespace, meta.Value)
-         for param_name in model_obj.model.parameters
-     }
+     model_selector = flow_job.models.get(node_name)
+     model_obj = _util.get_job_resource(model_selector, job_config)

-     push_mapping = {**input_mapping_, **param_mapping_}
+     # A missing model in the job config is a job consistency error
+     if model_obj is None or model_obj.objectType != meta.ObjectType.MODEL:
+         raise _ex.EJobValidation(f"No model was provided for flow node [{node_name}]")

-     pop_mapping = {
-         output_: NodeId(f"{node_name}.{output_}", namespace, _data.DataView)
-         for output_ in model_obj.model.outputs}
+     # Explicit check for model compatibility - report an error now, do not try build_model()
+     cls.check_model_compatibility(model_selector, model_obj.model, node_name, node)

      return cls.build_model_or_flow_with_context(
          job_config, namespace, node_name, model_obj,
          push_mapping, pop_mapping, explicit_deps)

- raise _ex.ETracInternal()  # TODO: Invalid node type
+ # Missing / invalid node type - should be caught in static validation
+ raise _ex.ETracInternal(f"Flow node [{node_name}] has invalid node type [{node.nodeType}]")
+
+ @classmethod
+ def check_model_compatibility(
+     cls, model_selector: meta.TagSelector, model_def: meta.ModelDefinition,
+     node_name: str, flow_node: meta.FlowNode):
+
+     model_params = list(sorted(model_def.parameters.keys()))
+     model_inputs = list(sorted(model_def.inputs.keys()))
+     model_outputs = list(sorted(model_def.outputs.keys()))
+
+     node_params = list(sorted(flow_node.parameters))
+     node_inputs = list(sorted(flow_node.inputs))
+     node_outputs = list(sorted(flow_node.outputs))
+
+     if model_params != node_params or model_inputs != node_inputs or model_outputs != node_outputs:
+         model_key = _util.object_key(model_selector)
+         raise _ex.EJobValidation(f"Incompatible model for flow node [{node_name}] (Model: [{model_key}])")

  @staticmethod
  def build_context_push(
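The flow builder walks the graph with a Kahn-style traversal: parameter and input nodes start out reachable, and each remaining node becomes reachable once all of its incoming edges have been consumed. A self-contained sketch of that propagation, simplified to plain node names with no sockets or node types:

    # Simplified Kahn-style traversal, mirroring the reachability loop used to build flows
    def topological_order(nodes, edges):
        # edges is a list of (source, target) pairs between named nodes
        incoming = {n: 0 for n in nodes}
        outgoing = {n: [] for n in nodes}
        for src, tgt in edges:
            incoming[tgt] += 1
            outgoing[src].append(tgt)

        # Nodes with no incoming edges (parameters / inputs) are reachable immediately
        reachable = [n for n in nodes if incoming[n] == 0]
        ordered = []

        while reachable:
            node = reachable.pop()
            ordered.append(node)
            for tgt in outgoing[node]:
                incoming[tgt] -= 1
                if incoming[tgt] == 0:
                    reachable.append(tgt)

        # Leftover nodes mean a cycle or a disconnected section - a job consistency error
        if len(ordered) != len(nodes):
            raise ValueError("Inconsistent flow")

        return ordered

    assert topological_order(
        ["in", "model", "out"],
        [("in", "model"), ("model", "out")]) == ["in", "model", "out"]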
tracdap/rt/_exec/runtime.py CHANGED
@@ -90,6 +90,10 @@ class TracRuntime:
  self._scratch_dir_provided = True if scratch_dir is not None else False
  self._scratch_dir_persist = scratch_dir_persist
  self._dev_mode = dev_mode
+ self._server_enabled = False
+ self._server_port = 0
+
+ self._pre_start_complete = False

  # Top level resources
  self._models: tp.Optional[_models.ModelLoader] = None
@@ -100,6 +104,9 @@ class TracRuntime:
  self._engine: tp.Optional[_engine.TracEngine] = None
  self._engine_event = threading.Condition()

+ # Runtime API server
+ self._server = None
+
  self._jobs: tp.Dict[str, _RuntimeJobInfo] = dict()

  # ------------------------------------------------------------------------------------------------------------------
@@ -152,6 +159,8 @@ class TracRuntime:
      config_dir = self._sys_config_path.parent if self._sys_config_path is not None else None
      self._sys_config = _dev_mode.DevModeTranslator.translate_sys_config(self._sys_config, config_dir)

+     self._pre_start_complete = True
+
  except Exception as e:
      self._handle_startup_error(e)

@@ -159,6 +168,10 @@ class TracRuntime:

  try:

+     # Ensure pre-start has been run
+     if not self._pre_start_complete:
+         self.pre_start()
+
      self._log.info("Starting the engine...")

      self._models = _models.ModelLoader(self._sys_config, self._scratch_dir)
@@ -175,11 +188,26 @@ class TracRuntime:

      self._system.start(wait=wait)

+     # If the runtime server has been enabled, start it up
+     if self._server_enabled:
+
+         self._log.info("Starting the runtime API server...")
+
+         # The server module pulls in all the gRPC dependencies, don't import it unless we have to
+         import tracdap.rt._exec.server as _server
+
+         self._server = _server.RuntimeApiServer(self._server_port)
+         self._server.start()
+
  except Exception as e:
      self._handle_startup_error(e)

  def stop(self, due_to_error=False):

+     if self._server is not None:
+         self._log.info("Stopping the runtime API server...")
+         self._server.stop()
+
      if due_to_error:
          self._log.info("Shutting down the engine in response to an error")
      else:
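The new _pre_start_complete flag makes start() safe to call whether or not the caller invoked pre_start() first. A generic sketch of that guard pattern, using a stand-in class rather than the real TracRuntime (note the diff does not show how _server_enabled gets switched on):

    # Generic sketch of the idempotent pre-start guard added to TracRuntime.start()
    class Service:

        def __init__(self):
            self._pre_start_complete = False

        def pre_start(self):
            # One-time setup (config loading, dev-mode translation) would happen here
            self._pre_start_complete = True

        def start(self):
            # Ensure pre-start has been run, even if the caller skipped it
            if not self._pre_start_complete:
                self.pre_start()

    svc = Service()
    svc.start()  # pre_start() is triggered implicitly
    assert svc._pre_start_complete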
tracdap/rt/_exec/server.py ADDED
@@ -0,0 +1,68 @@
+ # Copyright 2024 Accenture Global Solutions Limited
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import typing as tp
+ import concurrent.futures as futures
+
+ # Imports for gRPC generated code, these are managed by build_runtime.py for distribution
+ import tracdap.rt._impl.grpc.tracdap.api.internal.runtime_pb2 as runtime_pb2
+ import tracdap.rt._impl.grpc.tracdap.api.internal.runtime_pb2_grpc as runtime_grpc
+ import grpc
+
+
+ class RuntimeApiServer(runtime_grpc.TracRuntimeApiServicer):
+
+     __THREAD_POOL_DEFAULT_SIZE = 2
+     __THREAD_NAME_PREFIX = "server-"
+     __DEFAULT_SHUTDOWN_TIMEOUT = 10.0  # seconds
+
+     def __init__(self, port: int, n_workers: int = None):
+         self.__port = port
+         self.__n_workers = n_workers or self.__THREAD_POOL_DEFAULT_SIZE
+         self.__server: tp.Optional[grpc.Server] = None
+         self.__thread_pool: tp.Optional[futures.ThreadPoolExecutor] = None
+
+     def listJobs(self, request, context):
+         return super().listJobs(request, context)
+
+     def getJobStatus(self, request: runtime_pb2.BatchJobStatusRequest, context: grpc.ServicerContext):
+         return super().getJobStatus(request, context)
+
+     def getJobDetails(self, request, context):
+         return super().getJobDetails(request, context)
+
+     def start(self):
+
+         self.__thread_pool = futures.ThreadPoolExecutor(
+             max_workers=self.__n_workers,
+             thread_name_prefix=self.__THREAD_NAME_PREFIX)
+
+         self.__server = grpc.server(self.__thread_pool)
+
+         socket = f"[::]:{self.__port}"
+         self.__server.add_insecure_port(socket)
+
+         runtime_grpc.add_TracRuntimeApiServicer_to_server(self, self.__server)
+
+         self.__server.start()
+
+     def stop(self, shutdown_timeout: float = None):
+
+         grace = shutdown_timeout or self.__DEFAULT_SHUTDOWN_TIMEOUT
+
+         if self.__server is not None:
+             self.__server.stop(grace)
+
+         if self.__thread_pool is not None:
+             self.__thread_pool.shutdown()
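A usage sketch for the new server, based only on the constructor and methods shown above. The port is illustrative, and the three service methods are still stubs that fall through to gRPC's default UNIMPLEMENTED behaviour:

    from tracdap.rt._exec.server import RuntimeApiServer

    server = RuntimeApiServer(port=9000)  # illustrative port; n_workers defaults to 2
    server.start()                        # binds [::]:9000 and starts serving

    # ... runtime handles jobs while the API server is up ...

    server.stop()  # graceful stop, default grace period of 10 seconds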
tracdap/rt/_impl/data.py CHANGED
@@ -58,6 +58,13 @@ class DataItem:
  pandas: tp.Optional[pd.DataFrame] = None
  pyspark: tp.Any = None

+ def is_empty(self) -> bool:
+     return self.table is None and (self.batches is None or len(self.batches) == 0)
+
+ @staticmethod
+ def create_empty() -> DataItem:
+     return DataItem(pa.schema([]))
+

  @dc.dataclass(frozen=True)
  class DataView:
@@ -72,6 +79,13 @@ class DataView:
  arrow_schema = DataMapping.trac_to_arrow_schema(trac_schema)
  return DataView(trac_schema, arrow_schema, dict())

+ def is_empty(self) -> bool:
+     return self.parts is None or len(self.parts) == 0
+
+ @staticmethod
+ def create_empty() -> DataView:
+     return DataView(_meta.SchemaDefinition(), pa.schema([]), dict())
+

  class _DataInternal:
      pass
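These two helper pairs are the hinge for the optional input/output handling in functions.py and graph_builder.py: a view with no parts, or an item with no table and no batches, reads as "not supplied". A quick usage sketch:

    import tracdap.rt._impl.data as _data

    view = _data.DataView.create_empty()
    assert view.is_empty()  # no parts -> optional input treated as not supplied

    item = _data.DataItem.create_empty()
    assert item.is_empty()  # no table or batches -> nothing is saved or recorded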
tracdap/rt/_impl/grpc/__init__.py ADDED
@@ -0,0 +1,13 @@
+ # Copyright 2024 Accenture Global Solutions Limited
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
tracdap/rt/_impl/grpc/codec.py ADDED
@@ -0,0 +1,44 @@
+ # Copyright 2024 Accenture Global Solutions Limited
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import enum
+ import typing as tp
+
+
+ def encode(obj: tp.Any) -> tp.Any:
+
+     # Translate TRAC domain objects into generic dict / list structures
+     # These can be accepted by gRPC message constructors, do not try to build messages directly
+     # Use shallow copies and builtins to minimize performance impact
+
+     if obj is None:
+         return None
+
+     if isinstance(obj, str) or isinstance(obj, bool) or isinstance(obj, int) or isinstance(obj, float):
+         return obj
+
+     if isinstance(obj, enum.Enum):
+         return obj.value
+
+     if isinstance(obj, list):
+         return list(map(encode, obj))
+
+     if isinstance(obj, dict):
+         return dict(map(lambda kv: (kv[0], encode(kv[1])), obj.items()))
+
+     # Filter classes for TRAC domain objects (sanity check, not a watertight validation)
+     if hasattr(obj, "__module__") and "tracdap" in obj.__module__:
+         return dict(map(lambda kv: (kv[0], encode(kv[1])), obj.__dict__.items()))
+
+     raise RuntimeError(f"Cannot encode object of type [{type(obj).__name__}] for gRPC")
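A usage sketch for encode(); the enum here is an illustrative stand-in, not a real tracdap type. Only builtins, enums, lists, dicts and tracdap domain objects are accepted, anything else raises:

    import enum
    from tracdap.rt._impl.grpc.codec import encode

    class JobStatusCode(enum.Enum):  # illustrative stand-in for a tracdap enum
        SUCCEEDED = 5

    encoded = encode({"statusCode": JobStatusCode.SUCCEEDED, "attempts": [1, 2], "error": None})
    assert encoded == {"statusCode": 5, "attempts": [1, 2], "error": None}

    # encode(object())  # would raise RuntimeError: cannot encode type [object] for gRPC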