tracdap-runtime 0.6.1.dev3__py3-none-any.whl → 0.6.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (102) hide show
  1. tracdap/rt/_exec/actors.py +87 -10
  2. tracdap/rt/_exec/context.py +25 -1
  3. tracdap/rt/_exec/dev_mode.py +277 -221
  4. tracdap/rt/_exec/engine.py +79 -14
  5. tracdap/rt/_exec/functions.py +37 -8
  6. tracdap/rt/_exec/graph.py +2 -0
  7. tracdap/rt/_exec/graph_builder.py +118 -56
  8. tracdap/rt/_exec/runtime.py +108 -37
  9. tracdap/rt/_exec/server.py +345 -0
  10. tracdap/rt/_impl/config_parser.py +219 -49
  11. tracdap/rt/_impl/data.py +14 -0
  12. tracdap/rt/_impl/grpc/__init__.py +13 -0
  13. tracdap/rt/_impl/grpc/codec.py +99 -0
  14. tracdap/rt/_impl/grpc/tracdap/api/internal/runtime_pb2.py +51 -0
  15. tracdap/rt/_impl/grpc/tracdap/api/internal/runtime_pb2.pyi +61 -0
  16. tracdap/rt/_impl/grpc/tracdap/api/internal/runtime_pb2_grpc.py +183 -0
  17. tracdap/rt/_impl/grpc/tracdap/metadata/common_pb2.py +33 -0
  18. tracdap/rt/_impl/grpc/tracdap/metadata/common_pb2.pyi +34 -0
  19. tracdap/rt/{metadata → _impl/grpc/tracdap/metadata}/custom_pb2.py +5 -5
  20. tracdap/rt/_impl/grpc/tracdap/metadata/custom_pb2.pyi +15 -0
  21. tracdap/rt/_impl/grpc/tracdap/metadata/data_pb2.py +51 -0
  22. tracdap/rt/_impl/grpc/tracdap/metadata/data_pb2.pyi +115 -0
  23. tracdap/rt/_impl/grpc/tracdap/metadata/file_pb2.py +28 -0
  24. tracdap/rt/_impl/grpc/tracdap/metadata/file_pb2.pyi +22 -0
  25. tracdap/rt/_impl/grpc/tracdap/metadata/flow_pb2.py +59 -0
  26. tracdap/rt/_impl/grpc/tracdap/metadata/flow_pb2.pyi +109 -0
  27. tracdap/rt/_impl/grpc/tracdap/metadata/job_pb2.py +76 -0
  28. tracdap/rt/_impl/grpc/tracdap/metadata/job_pb2.pyi +177 -0
  29. tracdap/rt/_impl/grpc/tracdap/metadata/model_pb2.py +63 -0
  30. tracdap/rt/_impl/grpc/tracdap/metadata/model_pb2.pyi +119 -0
  31. tracdap/rt/_impl/grpc/tracdap/metadata/object_id_pb2.py +32 -0
  32. tracdap/rt/_impl/grpc/tracdap/metadata/object_id_pb2.pyi +68 -0
  33. tracdap/rt/_impl/grpc/tracdap/metadata/object_pb2.py +40 -0
  34. tracdap/rt/_impl/grpc/tracdap/metadata/object_pb2.pyi +46 -0
  35. tracdap/rt/_impl/grpc/tracdap/metadata/search_pb2.py +39 -0
  36. tracdap/rt/_impl/grpc/tracdap/metadata/search_pb2.pyi +83 -0
  37. tracdap/rt/_impl/grpc/tracdap/metadata/stoarge_pb2.py +50 -0
  38. tracdap/rt/_impl/grpc/tracdap/metadata/stoarge_pb2.pyi +89 -0
  39. tracdap/rt/_impl/grpc/tracdap/metadata/tag_pb2.py +34 -0
  40. tracdap/rt/_impl/grpc/tracdap/metadata/tag_pb2.pyi +26 -0
  41. tracdap/rt/_impl/grpc/tracdap/metadata/tag_update_pb2.py +30 -0
  42. tracdap/rt/_impl/grpc/tracdap/metadata/tag_update_pb2.pyi +34 -0
  43. tracdap/rt/_impl/grpc/tracdap/metadata/type_pb2.py +47 -0
  44. tracdap/rt/_impl/grpc/tracdap/metadata/type_pb2.pyi +101 -0
  45. tracdap/rt/_impl/guard_rails.py +26 -6
  46. tracdap/rt/_impl/models.py +25 -0
  47. tracdap/rt/_impl/static_api.py +27 -9
  48. tracdap/rt/_impl/type_system.py +17 -0
  49. tracdap/rt/_impl/validation.py +10 -0
  50. tracdap/rt/_plugins/config_local.py +49 -0
  51. tracdap/rt/_version.py +1 -1
  52. tracdap/rt/api/hook.py +10 -3
  53. tracdap/rt/api/model_api.py +22 -0
  54. tracdap/rt/api/static_api.py +79 -19
  55. tracdap/rt/config/__init__.py +3 -3
  56. tracdap/rt/config/common.py +10 -0
  57. tracdap/rt/config/platform.py +9 -19
  58. tracdap/rt/config/runtime.py +2 -0
  59. tracdap/rt/ext/config.py +34 -0
  60. tracdap/rt/ext/embed.py +1 -3
  61. tracdap/rt/ext/plugins.py +47 -6
  62. tracdap/rt/launch/cli.py +7 -5
  63. tracdap/rt/launch/launch.py +49 -12
  64. tracdap/rt/metadata/__init__.py +24 -24
  65. tracdap/rt/metadata/common.py +7 -7
  66. tracdap/rt/metadata/custom.py +2 -0
  67. tracdap/rt/metadata/data.py +28 -5
  68. tracdap/rt/metadata/file.py +2 -0
  69. tracdap/rt/metadata/flow.py +66 -4
  70. tracdap/rt/metadata/job.py +56 -16
  71. tracdap/rt/metadata/model.py +10 -0
  72. tracdap/rt/metadata/object.py +3 -0
  73. tracdap/rt/metadata/object_id.py +9 -9
  74. tracdap/rt/metadata/search.py +35 -13
  75. tracdap/rt/metadata/stoarge.py +64 -6
  76. tracdap/rt/metadata/tag_update.py +21 -7
  77. tracdap/rt/metadata/type.py +28 -13
  78. {tracdap_runtime-0.6.1.dev3.dist-info → tracdap_runtime-0.6.3.dist-info}/METADATA +22 -19
  79. tracdap_runtime-0.6.3.dist-info/RECORD +112 -0
  80. {tracdap_runtime-0.6.1.dev3.dist-info → tracdap_runtime-0.6.3.dist-info}/WHEEL +1 -1
  81. tracdap/rt/config/common_pb2.py +0 -55
  82. tracdap/rt/config/job_pb2.py +0 -42
  83. tracdap/rt/config/platform_pb2.py +0 -71
  84. tracdap/rt/config/result_pb2.py +0 -37
  85. tracdap/rt/config/runtime_pb2.py +0 -42
  86. tracdap/rt/ext/_guard.py +0 -37
  87. tracdap/rt/metadata/common_pb2.py +0 -33
  88. tracdap/rt/metadata/data_pb2.py +0 -51
  89. tracdap/rt/metadata/file_pb2.py +0 -28
  90. tracdap/rt/metadata/flow_pb2.py +0 -55
  91. tracdap/rt/metadata/job_pb2.py +0 -76
  92. tracdap/rt/metadata/model_pb2.py +0 -51
  93. tracdap/rt/metadata/object_id_pb2.py +0 -32
  94. tracdap/rt/metadata/object_pb2.py +0 -35
  95. tracdap/rt/metadata/search_pb2.py +0 -39
  96. tracdap/rt/metadata/stoarge_pb2.py +0 -50
  97. tracdap/rt/metadata/tag_pb2.py +0 -34
  98. tracdap/rt/metadata/tag_update_pb2.py +0 -30
  99. tracdap/rt/metadata/type_pb2.py +0 -48
  100. tracdap_runtime-0.6.1.dev3.dist-info/RECORD +0 -96
  101. {tracdap_runtime-0.6.1.dev3.dist-info → tracdap_runtime-0.6.3.dist-info}/LICENSE +0 -0
  102. {tracdap_runtime-0.6.1.dev3.dist-info → tracdap_runtime-0.6.3.dist-info}/top_level.txt +0 -0
@@ -19,6 +19,7 @@ import dataclasses as dc
19
19
  import enum
20
20
  import typing as tp
21
21
 
22
+ import tracdap.rt.metadata as _meta
22
23
  import tracdap.rt.config as _cfg
23
24
  import tracdap.rt.exceptions as _ex
24
25
  import tracdap.rt._exec.actors as _actors
@@ -28,7 +29,6 @@ import tracdap.rt._impl.models as _models # noqa
28
29
  import tracdap.rt._impl.data as _data # noqa
29
30
  import tracdap.rt._impl.storage as _storage # noqa
30
31
  import tracdap.rt._impl.util as _util # noqa
31
- from .actors import Signal
32
32
 
33
33
  from .graph import NodeId
34
34
 
@@ -66,6 +66,18 @@ class _EngineContext:
66
66
  failed_nodes: tp.Set[NodeId] = dc.field(default_factory=set)
67
67
 
68
68
 
69
+ @dc.dataclass
70
+ class _JobState:
71
+
72
+ job_id: _meta.TagHeader
73
+ job_config: _cfg.JobConfig
74
+
75
+ actor_id: _actors.ActorId = None
76
+
77
+ job_result: _cfg.JobResult = None
78
+ job_error: Exception = None
79
+
80
+
69
81
  class TracEngine(_actors.Actor):
70
82
 
71
83
  """
@@ -88,7 +100,7 @@ class TracEngine(_actors.Actor):
88
100
  self._storage = storage
89
101
  self._notify_callback = notify_callback
90
102
 
91
- self._job_actors = dict()
103
+ self._jobs: tp.Dict[str, _JobState] = dict()
92
104
 
93
105
  def on_start(self):
94
106
 
@@ -98,7 +110,7 @@ class TracEngine(_actors.Actor):
98
110
 
99
111
  self._log.info("Engine shutdown complete")
100
112
 
101
- def on_signal(self, signal: Signal) -> tp.Optional[bool]:
113
+ def on_signal(self, signal: _actors.Signal) -> tp.Optional[bool]:
102
114
 
103
115
  # Failed signals can propagate from leaf nodes up the actor tree for a job
104
116
  # If the failure goes all the way up the tree without being handled, it will reach the engine node
@@ -110,8 +122,8 @@ class TracEngine(_actors.Actor):
110
122
  failed_job_key = None
111
123
 
112
124
  # Look for the job key corresponding to the failed actor
113
- for job_key, job_actor in self._job_actors.items():
114
- if job_actor == signal.sender:
125
+ for job_key, job_state in self._jobs.items():
126
+ if job_state.actor_id == signal.sender:
115
127
  failed_job_key = job_key
116
128
 
117
129
  # If the job is still live, call job_failed explicitly
@@ -147,19 +159,34 @@ class TracEngine(_actors.Actor):
147
159
  job_processor = JobProcessor(job_key, job_config, result_spec,self._models, self._storage)
148
160
  job_actor_id = self.actors().spawn(job_processor)
149
161
 
150
- job_actors = {**self._job_actors, job_key: job_actor_id}
151
- self._job_actors = job_actors
162
+ job_state = _JobState(job_config.jobId, job_config)
163
+ job_state.actor_id = job_actor_id
164
+
165
+ self._jobs[job_key] = job_state
166
+
167
+ @_actors.Message
168
+ def get_job_list(self):
169
+
170
+ job_list = list(map(self._get_job_info, self._jobs.keys()))
171
+ self.actors().reply("job_list", job_list)
172
+
173
+ @_actors.Message
174
+ def get_job_details(self, job_key: str, details: bool):
175
+
176
+ details = self._get_job_info(job_key, details)
177
+ self.actors().reply("job_details", details)
152
178
 
153
179
  @_actors.Message
154
180
  def job_succeeded(self, job_key: str, job_result: _cfg.JobResult):
155
181
 
156
182
  # Ignore duplicate messages from the job processor (can happen in unusual error cases)
157
- if job_key not in self._job_actors:
183
+ if job_key not in self._jobs:
158
184
  self._log.warning(f"Ignoring [job_succeeded] message, job [{job_key}] has already completed")
159
185
  return
160
186
 
161
187
  self._log.info(f"Recording job as successful: {job_key}")
162
188
 
189
+ self._jobs[job_key].job_result = job_result
163
190
  self._finalize_job(job_key)
164
191
 
165
192
  if self._notify_callback is not None:
@@ -169,12 +196,13 @@ class TracEngine(_actors.Actor):
169
196
  def job_failed(self, job_key: str, error: Exception):
170
197
 
171
198
  # Ignore duplicate messages from the job processor (can happen in unusual error cases)
172
- if job_key not in self._job_actors:
199
+ if job_key not in self._jobs:
173
200
  self._log.warning(f"Ignoring [job_failed] message, job [{job_key}] has already completed")
174
201
  return
175
202
 
176
203
  self._log.error(f"Recording job as failed: {job_key}")
177
204
 
205
+ self._jobs[job_key].job_error = error
178
206
  self._finalize_job(job_key)
179
207
 
180
208
  if self._notify_callback is not None:
@@ -182,10 +210,47 @@ class TracEngine(_actors.Actor):
182
210
 
183
211
  def _finalize_job(self, job_key: str):
184
212
 
185
- job_actors = self._job_actors
186
- job_actor_id = job_actors.pop(job_key)
187
- self.actors().stop(job_actor_id)
188
- self._job_actors = job_actors
213
+ # Stop the actor but keep the job state available for status / results queries
214
+
215
+ # In the future, job state will need to be expunged after some period of time
216
+ # For now each instance of the runtime only processes one job so no need to worry
217
+
218
+ job_state = self._jobs.get(job_key)
219
+ job_actor_id = job_state.actor_id if job_state is not None else None
220
+
221
+ if job_actor_id is not None:
222
+ self.actors().stop(job_actor_id)
223
+ job_state.actor_id = None
224
+
225
+ def _get_job_info(self, job_key: str, details: bool = False) -> tp.Optional[_cfg.JobResult]:
226
+
227
+ job_state = self._jobs.get(job_key)
228
+
229
+ if job_state is None:
230
+ return None
231
+
232
+ job_result = _cfg.JobResult()
233
+ job_result.jobId = job_state.job_id
234
+
235
+ if job_state.actor_id is not None:
236
+ job_result.statusCode = _meta.JobStatusCode.RUNNING
237
+
238
+ elif job_state.job_result is not None:
239
+ job_result.statusCode = job_state.job_result.statusCode
240
+ job_result.statusMessage = job_state.job_result.statusMessage
241
+ if details:
242
+ job_result.results = job_state.job_result.results or dict()
243
+
244
+ elif job_state.job_error is not None:
245
+ job_result.statusCode = _meta.JobStatusCode.FAILED
246
+ job_result.statusMessage = str(job_state.job_error.args[0])
247
+
248
+ else:
249
+ # Alternatively return UNKNOWN status or throw an error here
250
+ job_result.statusCode = _meta.JobStatusCode.FAILED
251
+ job_result.statusMessage = "No details available"
252
+
253
+ return job_result
189
254
 
190
255
 
191
256
  class JobProcessor(_actors.Actor):
@@ -218,7 +283,7 @@ class JobProcessor(_actors.Actor):
218
283
  self._log.info(f"Cleaning up job [{self.job_key}]")
219
284
  self._models.destroy_scope(self.job_key)
220
285
 
221
- def on_signal(self, signal: Signal) -> tp.Optional[bool]:
286
+ def on_signal(self, signal: _actors.Signal) -> tp.Optional[bool]:
222
287
 
223
288
  if signal.message == _actors.SignalNames.FAILED and isinstance(signal, _actors.ErrorSignal):
224
289
 
@@ -248,6 +248,10 @@ class DataViewFunc(NodeFunction[_data.DataView]):
248
248
  root_item = _ctx_lookup(self.node.root_item, ctx)
249
249
  root_part_key = _data.DataPartKey.for_root()
250
250
 
251
+ # Map empty item -> empty view (for optional inputs not supplied)
252
+ if root_item.is_empty():
253
+ return _data.DataView.create_empty()
254
+
251
255
  data_view = _data.DataView.for_trac_schema(self.node.schema)
252
256
  data_view = _data.DataMapping.add_item_to_view(data_view, root_part_key, root_item)
253
257
 
@@ -263,6 +267,10 @@ class DataItemFunc(NodeFunction[_data.DataItem]):
263
267
 
264
268
  data_view = _ctx_lookup(self.node.data_view_id, ctx)
265
269
 
270
+ # Map empty view -> empty item (for optional outputs not supplied)
271
+ if data_view.is_empty():
272
+ return _data.DataItem.create_empty()
273
+
266
274
  # TODO: Support selecting data item described by self.node
267
275
 
268
276
  # Selecting data item for part-root, delta=0
@@ -280,6 +288,12 @@ class DataResultFunc(NodeFunction[ObjectBundle]):
280
288
 
281
289
  def _execute(self, ctx: NodeContext) -> ObjectBundle:
282
290
 
291
+ data_item = _ctx_lookup(self.node.data_item_id, ctx)
292
+
293
+ # Do not record output metadata for optional outputs that are empty
294
+ if data_item.is_empty():
295
+ return {}
296
+
283
297
  data_spec = _ctx_lookup(self.node.data_spec_id, ctx)
284
298
 
285
299
  # TODO: Check result of save operation
@@ -451,6 +465,13 @@ class SaveDataFunc(NodeFunction[None], _LoadSaveDataFunc):
451
465
 
452
466
  def _execute(self, ctx: NodeContext):
453
467
 
468
+ # Item to be saved should exist in the current context
469
+ data_item = _ctx_lookup(self.node.data_item_id, ctx)
470
+
471
+ # Do not save empty outputs (optional outputs that were not produced)
472
+ if data_item.is_empty():
473
+ return
474
+
454
475
  # This function assumes that metadata has already been generated as the data_spec
455
476
  # i.e. it is already known which incarnation / copy of the data will be created
456
477
 
@@ -458,9 +479,6 @@ class SaveDataFunc(NodeFunction[None], _LoadSaveDataFunc):
458
479
  data_copy = self._choose_copy(data_spec.data_item, data_spec.storage_def)
459
480
  data_storage = self.storage.get_data_storage(data_copy.storageKey)
460
481
 
461
- # Item to be saved should exist in the current context
462
- data_item = _ctx_lookup(self.node.data_item_id, ctx)
463
-
464
482
  # Current implementation will always put an Arrow table in the data item
465
483
  # Empty tables are allowed, so explicitly check if table is None
466
484
  # Testing "if not data_item.table" will fail for empty tables
@@ -567,12 +585,23 @@ class RunModelFunc(NodeFunction[Bundle[_data.DataView]]):
567
585
  msg = f"There was an unhandled error in the model: {str(e)}{details}"
568
586
  raise _ex.EModelExec(msg) from e
569
587
 
570
- # The node result is just the model outputs taken from the local context
571
- model_outputs: Bundle[_data.DataView] = {
572
- name: obj for name, obj in local_ctx.items()
573
- if name in self.node.model_def.outputs}
588
+ # Check required outputs are present and build the results bundle
589
+
590
+ results: Bundle[_data.DataView] = dict()
591
+
592
+ for output_name, output_schema in model_def.outputs.items():
593
+
594
+ result: _data.DataView = local_ctx.get(output_name)
595
+
596
+ if result is None or result.is_empty():
597
+ if not output_schema.optional:
598
+ model_name = self.model_class.__name__
599
+ raise _ex.ERuntimeValidation(f"Missing required output [{output_name}] from model [{model_name}]")
600
+
601
+ if result is not None:
602
+ results[output_name] = result
574
603
 
575
- return model_outputs
604
+ return results
576
605
 
577
606
 
578
607
  # ----------------------------------------------------------------------------------------------------------------------
tracdap/rt/_exec/graph.py CHANGED
@@ -297,6 +297,7 @@ class DataItemNode(MappingNode[_data.DataItem]):
297
297
  class DataResultNode(Node[ObjectBundle]):
298
298
 
299
299
  output_name: str
300
+ data_item_id: NodeId[_data.DataItem]
300
301
  data_spec_id: NodeId[_data.DataSpec]
301
302
  data_save_id: NodeId[type(None)]
302
303
 
@@ -306,6 +307,7 @@ class DataResultNode(Node[ObjectBundle]):
306
307
  def _node_dependencies(self) -> tp.Dict[NodeId, DependencyType]:
307
308
 
308
309
  return {
310
+ self.data_item_id: DependencyType.HARD,
309
311
  self.data_spec_id: DependencyType.HARD,
310
312
  self.data_save_id: DependencyType.HARD}
311
313
 
@@ -14,9 +14,6 @@
14
14
 
15
15
  from __future__ import annotations
16
16
 
17
- import copy
18
- import dataclasses # noqa
19
-
20
17
  import tracdap.rt.config as config
21
18
  import tracdap.rt.exceptions as _ex
22
19
  import tracdap.rt._impl.data as _data # noqa
@@ -123,12 +120,14 @@ class GraphBuilder:
123
120
  job_namespace: NodeNamespace, job_push_id: NodeId) \
124
121
  -> GraphSection:
125
122
 
123
+ target_selector = job_config.job.runModel.model
124
+ target_obj = _util.get_job_resource(target_selector, job_config)
125
+ target_def = target_obj.model
126
+ job_def = job_config.job.runModel
127
+
126
128
  return cls.build_calculation_job(
127
129
  job_config, result_spec, job_namespace, job_push_id,
128
- job_config.job.runModel.model,
129
- job_config.job.runModel.parameters,
130
- job_config.job.runModel.inputs,
131
- job_config.job.runModel.outputs)
130
+ target_selector, target_def, job_def)
132
131
 
133
132
  @classmethod
134
133
  def build_run_flow_job(
@@ -136,40 +135,53 @@ class GraphBuilder:
136
135
  job_namespace: NodeNamespace, job_push_id: NodeId) \
137
136
  -> GraphSection:
138
137
 
138
+ target_selector = job_config.job.runFlow.flow
139
+ target_obj = _util.get_job_resource(target_selector, job_config)
140
+ target_def = target_obj.flow
141
+ job_def = job_config.job.runFlow
142
+
139
143
  return cls.build_calculation_job(
140
144
  job_config, result_spec, job_namespace, job_push_id,
141
- job_config.job.runFlow.flow,
142
- job_config.job.runFlow.parameters,
143
- job_config.job.runFlow.inputs,
144
- job_config.job.runFlow.outputs)
145
+ target_selector, target_def, job_def)
145
146
 
146
147
  @classmethod
147
148
  def build_calculation_job(
148
149
  cls, job_config: config.JobConfig, result_spec: JobResultSpec,
149
150
  job_namespace: NodeNamespace, job_push_id: NodeId,
150
- target: meta.TagSelector, parameters: tp.Dict[str, meta.Value],
151
- inputs: tp.Dict[str, meta.TagSelector], outputs: tp.Dict[str, meta.TagSelector]) \
151
+ target_selector: meta.TagSelector,
152
+ target_def: tp.Union[meta.ModelDefinition, meta.FlowDefinition],
153
+ job_def: tp.Union[meta.RunModelJob, meta.RunFlowJob]) \
152
154
  -> GraphSection:
153
155
 
154
156
  # The main execution graph can run directly in the job context, no need to do a context push
155
157
  # since inputs and outputs in this context line up with the top level execution task
156
158
 
159
+ # Required / provided items are the same for RUN_MODEL and RUN_FLOW jobs
160
+
161
+ required_params = target_def.parameters
162
+ required_inputs = target_def.inputs
163
+ required_outputs = target_def.outputs
164
+
165
+ provided_params = job_def.parameters
166
+ provided_inputs = job_def.inputs
167
+ provided_outputs = job_def.outputs
168
+
157
169
  params_section = cls.build_job_parameters(
158
- job_namespace, parameters,
170
+ job_namespace, required_params, provided_params,
159
171
  explicit_deps=[job_push_id])
160
172
 
161
173
  input_section = cls.build_job_inputs(
162
- job_config, job_namespace, inputs,
174
+ job_config, job_namespace, required_inputs, provided_inputs,
163
175
  explicit_deps=[job_push_id])
164
176
 
165
- exec_obj = _util.get_job_resource(target, job_config)
177
+ exec_obj = _util.get_job_resource(target_selector, job_config)
166
178
 
167
179
  exec_section = cls.build_model_or_flow(
168
180
  job_config, job_namespace, exec_obj,
169
181
  explicit_deps=[job_push_id])
170
182
 
171
183
  output_section = cls.build_job_outputs(
172
- job_config, job_namespace, outputs,
184
+ job_config, job_namespace, required_outputs, provided_outputs,
173
185
  explicit_deps=[job_push_id])
174
186
 
175
187
  main_section = cls._join_sections(params_section, input_section, exec_section, output_section)
@@ -190,13 +202,22 @@ class GraphBuilder:
190
202
  @classmethod
191
203
  def build_job_parameters(
192
204
  cls, job_namespace: NodeNamespace,
193
- parameters: tp.Dict[str, meta.Value],
205
+ required_params: tp.Dict[str, meta.ModelParameter],
206
+ supplied_params: tp.Dict[str, meta.Value],
194
207
  explicit_deps: tp.Optional[tp.List[NodeId]] = None) \
195
208
  -> GraphSection:
196
209
 
197
210
  nodes = dict()
198
211
 
199
- for param_name, param_def in parameters.items():
212
+ for param_name, param_schema in required_params.items():
213
+
214
+ param_def = supplied_params.get(param_name)
215
+
216
+ if param_def is None:
217
+ if param_schema.defaultValue is not None:
218
+ param_def = param_schema.defaultValue
219
+ else:
220
+ raise _ex.EJobValidation(f"Missing required parameter: [{param_name}]")
200
221
 
201
222
  param_id = NodeId(param_name, job_namespace, meta.Value)
202
223
  param_node = StaticValueNode(param_id, param_def, explicit_deps=explicit_deps)
@@ -208,7 +229,8 @@ class GraphBuilder:
208
229
  @classmethod
209
230
  def build_job_inputs(
210
231
  cls, job_config: config.JobConfig, job_namespace: NodeNamespace,
211
- inputs: tp.Dict[str, meta.TagSelector],
232
+ required_inputs: tp.Dict[str, meta.ModelInputSchema],
233
+ supplied_inputs: tp.Dict[str, meta.TagSelector],
212
234
  explicit_deps: tp.Optional[tp.List[NodeId]] = None) \
213
235
  -> GraphSection:
214
236
 
@@ -216,7 +238,18 @@ class GraphBuilder:
216
238
  outputs = set()
217
239
  must_run = list()
218
240
 
219
- for input_name, data_selector in inputs.items():
241
+ for input_name, input_schema in required_inputs.items():
242
+
243
+ data_selector = supplied_inputs.get(input_name)
244
+
245
+ if data_selector is None:
246
+ if input_schema.optional:
247
+ data_view_id = NodeId.of(input_name, job_namespace, _data.DataView)
248
+ nodes[data_view_id] = StaticValueNode(data_view_id, _data.DataView.create_empty())
249
+ outputs.add(data_view_id)
250
+ continue
251
+ else:
252
+ raise _ex.EJobValidation(f"Missing required input: [{input_name}]")
220
253
 
221
254
  # Build a data spec using metadata from the job config
222
255
  # For now we are always loading the root part, snap 0, delta 0
@@ -258,14 +291,24 @@ class GraphBuilder:
258
291
  @classmethod
259
292
  def build_job_outputs(
260
293
  cls, job_config: config.JobConfig, job_namespace: NodeNamespace,
261
- outputs: tp.Dict[str, meta.TagSelector],
294
+ required_outputs: tp.Dict[str, meta.ModelOutputSchema],
295
+ supplied_outputs: tp.Dict[str, meta.TagSelector],
262
296
  explicit_deps: tp.Optional[tp.List[NodeId]] = None) \
263
297
  -> GraphSection:
264
298
 
265
299
  nodes = {}
266
300
  inputs = set()
267
301
 
268
- for output_name, data_selector in outputs.items():
302
+ for output_name, output_schema in required_outputs.items():
303
+
304
+ data_selector = supplied_outputs.get(output_name)
305
+
306
+ if data_selector is None:
307
+ if output_schema.optional:
308
+ optional_info = "(configuration is required for all optional outputs, in case they are produced)"
309
+ raise _ex.EJobValidation(f"Missing optional output: [{output_name}] {optional_info}")
310
+ else:
311
+ raise _ex.EJobValidation(f"Missing required output: [{output_name}]")
269
312
 
270
313
  # Output data view must already exist in the namespace
271
314
  data_view_id = NodeId.of(output_name, job_namespace, _data.DataView)
@@ -323,7 +366,8 @@ class GraphBuilder:
323
366
 
324
367
  data_result_id = NodeId.of(f"{output_name}:RESULT", job_namespace, ObjectBundle)
325
368
  data_result_node = DataResultNode(
326
- data_result_id, output_name, data_spec_id, data_save_id,
369
+ data_result_id, output_name,
370
+ data_item_id, data_spec_id, data_save_id,
327
371
  output_data_key, output_storage_key)
328
372
 
329
373
  nodes[data_spec_id] = data_spec_node
@@ -458,10 +502,10 @@ class GraphBuilder:
458
502
  frozenset(parameter_ids), frozenset(input_ids),
459
503
  explicit_deps=explicit_deps, bundle=model_id.namespace)
460
504
 
461
- module_result_id = NodeId(f"{model_name}:RESULT", namespace)
462
- model_result_node = RunModelResultNode(module_result_id, model_id)
505
+ model_result_id = NodeId(f"{model_name}:RESULT", namespace)
506
+ model_result_node = RunModelResultNode(model_result_id, model_id)
463
507
 
464
- nodes = {model_id: model_node, module_result_id: model_result_node}
508
+ nodes = {model_id: model_node, model_result_id: model_result_node}
465
509
 
466
510
  # Create nodes for each model output
467
511
  # The model node itself outputs a bundle (dictionary of named outputs)
@@ -474,7 +518,7 @@ class GraphBuilder:
474
518
  nodes[output_id] = BundleItemNode(output_id, model_id, output_id.name)
475
519
 
476
520
  # Assemble a graph to include the model and its outputs
477
- return GraphSection(nodes, inputs={*parameter_ids, *input_ids}, outputs=output_ids, must_run=[module_result_id])
521
+ return GraphSection(nodes, inputs={*parameter_ids, *input_ids}, outputs=output_ids, must_run=[model_result_id])
478
522
 
479
523
  @classmethod
480
524
  def build_flow(
@@ -488,7 +532,7 @@ class GraphBuilder:
488
532
 
489
533
  # https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm
490
534
 
491
- remaining_nodes = copy.copy(flow_def.nodes)
535
+ # Group edges by source and target node
492
536
  remaining_edges_by_target = {edge.target.node: [] for edge in flow_def.edges}
493
537
  remaining_edges_by_source = {edge.source.node: [] for edge in flow_def.edges}
494
538
 
@@ -496,16 +540,14 @@ class GraphBuilder:
496
540
  remaining_edges_by_target[edge.target.node].append(edge)
497
541
  remaining_edges_by_source[edge.source.node].append(edge)
498
542
 
499
- reachable_nodes = dict()
500
-
501
- # Initial set of reachable flow nodes is just the input nodes
502
- for node_name, node in list(remaining_nodes.items()):
503
- if node.nodeType == meta.FlowNodeType.INPUT_NODE:
504
- reachable_nodes[node_name] = node
505
- del remaining_nodes[node_name]
506
-
543
+ # Group edges by target socket (only one edge per target in a consistent flow)
507
544
  target_edges = {socket_key(edge.target): edge for edge in flow_def.edges}
508
545
 
546
+ # Initially parameters and inputs are reachable, everything else is not
547
+ def is_input(n): return n[1].nodeType in [meta.FlowNodeType.PARAMETER_NODE, meta.FlowNodeType.INPUT_NODE]
548
+ reachable_nodes = dict(filter(is_input, flow_def.nodes.items()))
549
+ remaining_nodes = dict(filter(lambda n: not is_input(n), flow_def.nodes.items()))
550
+
509
551
  # Initial graph section for the flow is empty
510
552
  graph_section = GraphSection({}, must_run=explicit_deps)
511
553
 
@@ -559,10 +601,16 @@ class GraphBuilder:
559
601
  return NodeId(socket_name, namespace, result_type)
560
602
 
561
603
  def edge_mapping(node_: str, socket_: str = None, result_type=None):
562
- socket = meta.FlowSocket(node_, socket_)
563
- edge = target_edges.get(socket_key(socket)) # todo: inconsistent if missing
604
+ socket = socket_key(meta.FlowSocket(node_, socket_))
605
+ edge = target_edges.get(socket)
606
+ # Report missing edges as a job consistency error (this might happen sometimes in dev mode)
607
+ if edge is None:
608
+ raise _ex.EJobValidation(f"Inconsistent flow: Socket [{socket}] is not connected")
564
609
  return socket_id(edge.source.node, edge.source.socket, result_type)
565
610
 
611
+ if node.nodeType == meta.FlowNodeType.PARAMETER_NODE:
612
+ return GraphSection({}, inputs={NodeId(node_name, namespace, result_type=meta.Value)})
613
+
566
614
  if node.nodeType == meta.FlowNodeType.INPUT_NODE:
567
615
  return GraphSection({}, inputs={NodeId(node_name, namespace, result_type=_data.DataView)})
568
616
 
@@ -573,32 +621,46 @@ class GraphBuilder:
573
621
 
574
622
  if node.nodeType == meta.FlowNodeType.MODEL_NODE:
575
623
 
576
- model_selector = flow_job.models.get(node_name)
577
- model_obj = _util.get_job_resource(model_selector, job_config)
578
-
579
- # TODO: Whether to use flow node or model_obj to build the push mapping?
624
+ param_mapping = {socket: edge_mapping(node_name, socket, meta.Value) for socket in node.parameters}
625
+ input_mapping = {socket: edge_mapping(node_name, socket, _data.DataView) for socket in node.inputs}
626
+ output_mapping = {socket: socket_id(node_name, socket, _data.DataView) for socket in node.outputs}
580
627
 
581
- input_mapping_ = {
582
- input_name: edge_mapping(node_name, input_name, _data.DataView)
583
- for input_name in model_obj.model.inputs
584
- }
628
+ push_mapping = {**input_mapping, **param_mapping}
629
+ pop_mapping = output_mapping
585
630
 
586
- param_mapping_ = {
587
- param_name: NodeId(param_name, namespace, meta.Value)
588
- for param_name in model_obj.model.parameters
589
- }
631
+ model_selector = flow_job.models.get(node_name)
632
+ model_obj = _util.get_job_resource(model_selector, job_config)
590
633
 
591
- push_mapping = {**input_mapping_, **param_mapping_}
634
+ # Missing models in the job config is a job consistency error
635
+ if model_obj is None or model_obj.objectType != meta.ObjectType.MODEL:
636
+ raise _ex.EJobValidation(f"No model was provided for flow node [{node_name}]")
592
637
 
593
- pop_mapping = {
594
- output_: NodeId(f"{node_name}.{output_}", namespace, _data.DataView)
595
- for output_ in model_obj.model.outputs}
638
+ # Explicit check for model compatibility - report an error now, do not try build_model()
639
+ cls.check_model_compatibility(model_selector, model_obj.model, node_name, node)
596
640
 
597
641
  return cls.build_model_or_flow_with_context(
598
642
  job_config, namespace, node_name, model_obj,
599
643
  push_mapping, pop_mapping, explicit_deps)
600
644
 
601
- raise _ex.ETracInternal() # TODO: Invalid node type
645
+ # Missing / invalid node type - should be caught in static validation
646
+ raise _ex.ETracInternal(f"Flow node [{node_name}] has invalid node type [{node.nodeType}]")
647
+
648
+ @classmethod
649
+ def check_model_compatibility(
650
+ cls, model_selector: meta.TagSelector, model_def: meta.ModelDefinition,
651
+ node_name: str, flow_node: meta.FlowNode):
652
+
653
+ model_params = list(sorted(model_def.parameters.keys()))
654
+ model_inputs = list(sorted(model_def.inputs.keys()))
655
+ model_outputs = list(sorted(model_def.outputs.keys()))
656
+
657
+ node_params = list(sorted(flow_node.parameters))
658
+ node_inputs = list(sorted(flow_node.inputs))
659
+ node_outputs = list(sorted(flow_node.outputs))
660
+
661
+ if model_params != node_params or model_inputs != node_inputs or model_outputs != node_outputs:
662
+ model_key = _util.object_key(model_selector)
663
+ raise _ex.EJobValidation(f"Incompatible model for flow node [{node_name}] (Model: [{model_key}])")
602
664
 
603
665
  @staticmethod
604
666
  def build_context_push(