tracdap-runtime 0.6.1.dev3__py3-none-any.whl → 0.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (98)
  1. tracdap/rt/_exec/context.py +25 -1
  2. tracdap/rt/_exec/dev_mode.py +277 -213
  3. tracdap/rt/_exec/functions.py +37 -8
  4. tracdap/rt/_exec/graph.py +2 -0
  5. tracdap/rt/_exec/graph_builder.py +118 -56
  6. tracdap/rt/_exec/runtime.py +28 -0
  7. tracdap/rt/_exec/server.py +68 -0
  8. tracdap/rt/_impl/data.py +14 -0
  9. tracdap/rt/_impl/grpc/__init__.py +13 -0
  10. tracdap/rt/_impl/grpc/codec.py +44 -0
  11. tracdap/rt/_impl/grpc/tracdap/api/internal/runtime_pb2.py +51 -0
  12. tracdap/rt/_impl/grpc/tracdap/api/internal/runtime_pb2.pyi +59 -0
  13. tracdap/rt/_impl/grpc/tracdap/api/internal/runtime_pb2_grpc.py +183 -0
  14. tracdap/rt/_impl/grpc/tracdap/config/common_pb2.py +55 -0
  15. tracdap/rt/_impl/grpc/tracdap/config/common_pb2.pyi +103 -0
  16. tracdap/rt/_impl/grpc/tracdap/config/job_pb2.py +42 -0
  17. tracdap/rt/_impl/grpc/tracdap/config/job_pb2.pyi +44 -0
  18. tracdap/rt/_impl/grpc/tracdap/config/platform_pb2.py +71 -0
  19. tracdap/rt/_impl/grpc/tracdap/config/platform_pb2.pyi +197 -0
  20. tracdap/rt/_impl/grpc/tracdap/config/result_pb2.py +37 -0
  21. tracdap/rt/_impl/grpc/tracdap/config/result_pb2.pyi +35 -0
  22. tracdap/rt/_impl/grpc/tracdap/config/runtime_pb2.py +42 -0
  23. tracdap/rt/_impl/grpc/tracdap/config/runtime_pb2.pyi +46 -0
  24. tracdap/rt/_impl/grpc/tracdap/metadata/common_pb2.py +33 -0
  25. tracdap/rt/_impl/grpc/tracdap/metadata/common_pb2.pyi +34 -0
  26. tracdap/rt/{metadata → _impl/grpc/tracdap/metadata}/custom_pb2.py +5 -5
  27. tracdap/rt/_impl/grpc/tracdap/metadata/custom_pb2.pyi +15 -0
  28. tracdap/rt/_impl/grpc/tracdap/metadata/data_pb2.py +51 -0
  29. tracdap/rt/_impl/grpc/tracdap/metadata/data_pb2.pyi +115 -0
  30. tracdap/rt/_impl/grpc/tracdap/metadata/file_pb2.py +28 -0
  31. tracdap/rt/_impl/grpc/tracdap/metadata/file_pb2.pyi +22 -0
  32. tracdap/rt/_impl/grpc/tracdap/metadata/flow_pb2.py +59 -0
  33. tracdap/rt/_impl/grpc/tracdap/metadata/flow_pb2.pyi +109 -0
  34. tracdap/rt/_impl/grpc/tracdap/metadata/job_pb2.py +76 -0
  35. tracdap/rt/_impl/grpc/tracdap/metadata/job_pb2.pyi +177 -0
  36. tracdap/rt/_impl/grpc/tracdap/metadata/model_pb2.py +51 -0
  37. tracdap/rt/_impl/grpc/tracdap/metadata/model_pb2.pyi +92 -0
  38. tracdap/rt/_impl/grpc/tracdap/metadata/object_id_pb2.py +32 -0
  39. tracdap/rt/_impl/grpc/tracdap/metadata/object_id_pb2.pyi +68 -0
  40. tracdap/rt/_impl/grpc/tracdap/metadata/object_pb2.py +35 -0
  41. tracdap/rt/_impl/grpc/tracdap/metadata/object_pb2.pyi +35 -0
  42. tracdap/rt/_impl/grpc/tracdap/metadata/search_pb2.py +39 -0
  43. tracdap/rt/_impl/grpc/tracdap/metadata/search_pb2.pyi +83 -0
  44. tracdap/rt/_impl/grpc/tracdap/metadata/stoarge_pb2.py +50 -0
  45. tracdap/rt/_impl/grpc/tracdap/metadata/stoarge_pb2.pyi +89 -0
  46. tracdap/rt/_impl/grpc/tracdap/metadata/tag_pb2.py +34 -0
  47. tracdap/rt/_impl/grpc/tracdap/metadata/tag_pb2.pyi +26 -0
  48. tracdap/rt/_impl/grpc/tracdap/metadata/tag_update_pb2.py +30 -0
  49. tracdap/rt/_impl/grpc/tracdap/metadata/tag_update_pb2.pyi +34 -0
  50. tracdap/rt/_impl/grpc/tracdap/metadata/type_pb2.py +47 -0
  51. tracdap/rt/_impl/grpc/tracdap/metadata/type_pb2.pyi +101 -0
  52. tracdap/rt/_impl/guard_rails.py +5 -6
  53. tracdap/rt/_impl/static_api.py +10 -6
  54. tracdap/rt/_version.py +1 -1
  55. tracdap/rt/api/hook.py +6 -2
  56. tracdap/rt/api/model_api.py +22 -0
  57. tracdap/rt/api/static_api.py +14 -4
  58. tracdap/rt/config/__init__.py +3 -3
  59. tracdap/rt/config/platform.py +9 -9
  60. tracdap/rt/launch/cli.py +3 -5
  61. tracdap/rt/launch/launch.py +15 -3
  62. tracdap/rt/metadata/__init__.py +15 -15
  63. tracdap/rt/metadata/common.py +7 -7
  64. tracdap/rt/metadata/custom.py +2 -0
  65. tracdap/rt/metadata/data.py +28 -5
  66. tracdap/rt/metadata/file.py +2 -0
  67. tracdap/rt/metadata/flow.py +66 -4
  68. tracdap/rt/metadata/job.py +56 -16
  69. tracdap/rt/metadata/model.py +4 -0
  70. tracdap/rt/metadata/object_id.py +9 -9
  71. tracdap/rt/metadata/search.py +35 -13
  72. tracdap/rt/metadata/stoarge.py +64 -6
  73. tracdap/rt/metadata/tag_update.py +21 -7
  74. tracdap/rt/metadata/type.py +28 -13
  75. {tracdap_runtime-0.6.1.dev3.dist-info → tracdap_runtime-0.6.2.dist-info}/METADATA +22 -19
  76. tracdap_runtime-0.6.2.dist-info/RECORD +121 -0
  77. {tracdap_runtime-0.6.1.dev3.dist-info → tracdap_runtime-0.6.2.dist-info}/WHEEL +1 -1
  78. tracdap/rt/config/common_pb2.py +0 -55
  79. tracdap/rt/config/job_pb2.py +0 -42
  80. tracdap/rt/config/platform_pb2.py +0 -71
  81. tracdap/rt/config/result_pb2.py +0 -37
  82. tracdap/rt/config/runtime_pb2.py +0 -42
  83. tracdap/rt/metadata/common_pb2.py +0 -33
  84. tracdap/rt/metadata/data_pb2.py +0 -51
  85. tracdap/rt/metadata/file_pb2.py +0 -28
  86. tracdap/rt/metadata/flow_pb2.py +0 -55
  87. tracdap/rt/metadata/job_pb2.py +0 -76
  88. tracdap/rt/metadata/model_pb2.py +0 -51
  89. tracdap/rt/metadata/object_id_pb2.py +0 -32
  90. tracdap/rt/metadata/object_pb2.py +0 -35
  91. tracdap/rt/metadata/search_pb2.py +0 -39
  92. tracdap/rt/metadata/stoarge_pb2.py +0 -50
  93. tracdap/rt/metadata/tag_pb2.py +0 -34
  94. tracdap/rt/metadata/tag_update_pb2.py +0 -30
  95. tracdap/rt/metadata/type_pb2.py +0 -48
  96. tracdap_runtime-0.6.1.dev3.dist-info/RECORD +0 -96
  97. {tracdap_runtime-0.6.1.dev3.dist-info → tracdap_runtime-0.6.2.dist-info}/LICENSE +0 -0
  98. {tracdap_runtime-0.6.1.dev3.dist-info → tracdap_runtime-0.6.2.dist-info}/top_level.txt +0 -0
@@ -34,8 +34,9 @@ DEV_MODE_JOB_CONFIG = [
34
34
  re.compile(r"job\.run(Model|Flow)\.parameters\.\w+"),
35
35
  re.compile(r"job\.run(Model|Flow)\.inputs\.\w+"),
36
36
  re.compile(r"job\.run(Model|Flow)\.outputs\.\w+"),
37
- re.compile(r"job\.run(Model|Flow)\.models\.\w+"),
38
- re.compile(r"job\.run(Model|Flow)\.flow+")]
37
+ re.compile(r"job\.runModel\.model"),
38
+ re.compile(r"job\.runFlow\.flow"),
39
+ re.compile(r"job\.runFlow\.models\.\w+")]
39
40
 
40
41
  DEV_MODE_SYS_CONFIG = []
41
42
 
@@ -71,92 +72,24 @@ class DevModeTranslator:
71
72
 
72
73
  cls._log.info(f"Applying dev mode config translation to job config")
73
74
 
74
- model_loader = _models.ModelLoader(sys_config, scratch_dir)
75
- model_loader.create_scope("DEV_MODE_TRANSLATION")
76
-
77
75
  if not job_config.jobId:
78
76
  job_config = cls._process_job_id(job_config)
79
77
 
80
78
  if job_config.job.jobType is None or job_config.job.jobType == _meta.JobType.JOB_TYPE_NOT_SET:
81
79
  job_config = cls._process_job_type(job_config)
82
80
 
83
- if model_class is not None:
84
-
85
- model_id, model_obj = cls._generate_model_for_class(model_loader, model_class)
86
- job_config = cls._add_job_resource(job_config, model_id, model_obj)
87
- job_config.job.runModel.model = _util.selector_for(model_id)
81
+ # Load and populate any models provided as a Python class or class name
82
+ if job_config.job.jobType in [_meta.JobType.RUN_MODEL, _meta.JobType.RUN_FLOW]:
83
+ job_config = cls._process_models(sys_config, job_config, scratch_dir, model_class)
88
84
 
85
+ # Fow flows, load external flow definitions then perform auto-wiring and type inference
89
86
  if job_config.job.jobType == _meta.JobType.RUN_FLOW:
87
+ job_config = cls._process_flow_definition(job_config, config_dir)
90
88
 
91
- original_models = job_config.job.runFlow.models.copy()
92
- for model_key, model_detail in original_models.items():
93
- model_id, model_obj = cls._generate_model_for_entry_point(model_loader, model_detail)
94
- job_config = cls._add_job_resource(job_config, model_id, model_obj)
95
- job_config.job.runFlow.models[model_key] = _util.selector_for(model_id)
96
-
97
- flow_id, flow_obj = cls._expand_flow_definition(job_config, config_dir)
98
- job_config = cls._add_job_resource(job_config, flow_id, flow_obj)
99
- job_config.job.runFlow.flow = _util.selector_for(flow_id)
100
-
101
- model_loader.destroy_scope("DEV_MODE_TRANSLATION")
102
-
89
+ # For run (model|flow) jobs, apply processing to the parameters, inputs and outputs
103
90
  if job_config.job.jobType in [_meta.JobType.RUN_MODEL, _meta.JobType.RUN_FLOW]:
104
91
  job_config = cls._process_parameters(job_config)
105
-
106
- if job_config.job.jobType not in [_meta.JobType.RUN_MODEL, _meta.JobType.RUN_FLOW]:
107
- return job_config
108
-
109
- run_info = job_config.job.runModel \
110
- if job_config.job.jobType == _meta.JobType.RUN_MODEL \
111
- else job_config.job.runFlow
112
-
113
- original_inputs = run_info.inputs
114
- original_outputs = run_info.outputs
115
- original_resources = job_config.resources
116
-
117
- translated_inputs = copy.copy(original_inputs)
118
- translated_outputs = copy.copy(original_outputs)
119
- translated_resources = copy.copy(job_config.resources)
120
-
121
- def process_input_or_output(data_key, data_value, is_input: bool):
122
-
123
- data_id = _util.new_object_id(_meta.ObjectType.DATA)
124
- storage_id = _util.new_object_id(_meta.ObjectType.STORAGE)
125
-
126
- if is_input:
127
- if job_config.job.jobType == _meta.JobType.RUN_MODEL:
128
- model_def = job_config.resources[_util.object_key(job_config.job.runModel.model)]
129
- schema = model_def.model.inputs[data_key].schema
130
- else:
131
- flow_def = job_config.resources[_util.object_key(job_config.job.runFlow.flow)]
132
- schema = flow_def.flow.inputs[data_key].schema
133
- else:
134
- schema = None
135
-
136
- data_obj, storage_obj = cls._process_job_io(
137
- sys_config, data_key, data_value, data_id, storage_id,
138
- new_unique_file=not is_input, schema=schema)
139
-
140
- translated_resources[_util.object_key(data_id)] = data_obj
141
- translated_resources[_util.object_key(storage_id)] = storage_obj
142
-
143
- if is_input:
144
- translated_inputs[data_key] = _util.selector_for(data_id)
145
- else:
146
- translated_outputs[data_key] = _util.selector_for(data_id)
147
-
148
- for input_key, input_value in original_inputs.items():
149
- if not (isinstance(input_value, str) and input_value in original_resources):
150
- process_input_or_output(input_key, input_value, is_input=True)
151
-
152
- for output_key, output_value in original_outputs.items():
153
- if not (isinstance(output_value, str) and output_value in original_outputs):
154
- process_input_or_output(output_key, output_value, is_input=False)
155
-
156
- job_config = copy.copy(job_config)
157
- job_config.resources = translated_resources
158
- run_info.inputs = translated_inputs
159
- run_info.outputs = translated_outputs
92
+ job_config = cls._process_inputs_and_outputs(sys_config, job_config)
160
93
 
161
94
  return job_config
162
95
 
@@ -268,6 +201,58 @@ class DevModeTranslator:
268
201
 
269
202
  return job_config
270
203
 
204
+ @classmethod
205
+ def _process_models(
206
+ cls,
207
+ sys_config: _cfg.RuntimeConfig,
208
+ job_config: _cfg.JobConfig,
209
+ scratch_dir: pathlib.Path,
210
+ model_class: tp.Optional[_api.TracModel.__class__]) \
211
+ -> _cfg.JobConfig:
212
+
213
+ model_loader = _models.ModelLoader(sys_config, scratch_dir)
214
+ model_loader.create_scope("DEV_MODE_TRANSLATION")
215
+
216
+ original_config = job_config
217
+
218
+ job_config = copy.copy(job_config)
219
+ job_config.job = copy.copy(job_config.job)
220
+ job_config.resources = copy.copy(job_config.resources)
221
+
222
+ if job_config.job.jobType == _meta.JobType.RUN_MODEL:
223
+
224
+ job_config.job.runModel = copy.copy(job_config.job.runModel)
225
+
226
+ # If a model class is supplied in code, use that to generate the model def
227
+ if model_class is not None:
228
+ model_id, model_obj = cls._generate_model_for_class(model_loader, model_class)
229
+ job_config = cls._add_job_resource(job_config, model_id, model_obj)
230
+ job_config.job.runModel.model = _util.selector_for(model_id)
231
+
232
+ # Otherwise if model specified as a string instead of a selector, apply the translation
233
+ elif isinstance(original_config.job.runModel.model, str):
234
+ model_detail = original_config.job.runModel.model
235
+ model_id, model_obj = cls._generate_model_for_entry_point(model_loader, model_detail) # noqa
236
+ job_config = cls._add_job_resource(job_config, model_id, model_obj)
237
+ job_config.job.runModel.model = _util.selector_for(model_id)
238
+
239
+ if job_config.job.jobType == _meta.JobType.RUN_FLOW:
240
+
241
+ job_config.job.runFlow = copy.copy(job_config.job.runFlow)
242
+ job_config.job.runFlow.models = copy.copy(job_config.job.runFlow.models)
243
+
244
+ for model_key, model_detail in original_config.job.runFlow.models.items():
245
+
246
+ # Only apply translation if the model is specified as a string instead of a selector
247
+ if isinstance(model_detail, str):
248
+ model_id, model_obj = cls._generate_model_for_entry_point(model_loader, model_detail)
249
+ job_config = cls._add_job_resource(job_config, model_id, model_obj)
250
+ job_config.job.runFlow.models[model_key] = _util.selector_for(model_id)
251
+
252
+ model_loader.destroy_scope("DEV_MODE_TRANSLATION")
253
+
254
+ return job_config
255
+
271
256
  @classmethod
272
257
  def _generate_model_for_class(
273
258
  cls, model_loader: _models.ModelLoader, model_class: _api.TracModel.__class__) \
@@ -306,17 +291,13 @@ class DevModeTranslator:
306
291
  return model_id, model_object
307
292
 
308
293
  @classmethod
309
- def _expand_flow_definition(
310
- cls, job_config: _cfg.JobConfig, config_dir: pathlib.Path) \
311
- -> (_meta.TagHeader, _meta.ObjectDefinition):
294
+ def _process_flow_definition(cls, job_config: _cfg.JobConfig, config_dir: pathlib.Path) -> _cfg.JobConfig:
312
295
 
313
296
  flow_details = job_config.job.runFlow.flow
314
297
 
315
- # The full specification for a flow is as a tag selector for a valid job resource
316
- # This is still allowed in dev mode, in which case dev mode translation is not applied
298
+ # Do not apply translation if flow is specified as an object ID / selector (assume full config is supplied)
317
299
  if isinstance(flow_details, _meta.TagHeader) or isinstance(flow_details, _meta.TagSelector):
318
- flow_obj = _util.get_job_resource(flow_details, job_config, optional=False)
319
- return flow_details, flow_obj
300
+ return job_config
320
301
 
321
302
  # Otherwise, flow is specified as the path to dev-mode flow definition
322
303
  if not isinstance(flow_details, str):
@@ -324,39 +305,59 @@ class DevModeTranslator:
324
305
  cls._log.error(err)
325
306
  raise _ex.EConfigParse(err)
326
307
 
308
+ flow_path = config_dir.joinpath(flow_details) if config_dir is not None else pathlib.Path(flow_details)
309
+
310
+ if not flow_path.exists():
311
+ err = f"Flow definition not available for [{flow_details}]: File not found ({flow_path})"
312
+ cls._log.error(err)
313
+ raise _ex.EConfigParse(err)
314
+
327
315
  flow_id = _util.new_object_id(_meta.ObjectType.FLOW)
328
316
  flow_key = _util.object_key(flow_id)
329
317
 
330
318
  cls._log.info(f"Generating flow definition for [{flow_details}] with ID = [{flow_key}]")
331
319
 
332
- flow_path = config_dir.joinpath(flow_details) if config_dir is not None else pathlib.Path(flow_details)
333
320
  flow_parser = _cfg_p.ConfigParser(_meta.FlowDefinition)
334
321
  flow_raw_data = flow_parser.load_raw_config(flow_path, flow_path.name)
335
322
  flow_def = flow_parser.parse(flow_raw_data, flow_path.name)
336
323
 
337
- flow_def = cls._autowire_flow(flow_def)
338
- flow_def = cls._generate_flow_parameters(flow_def, job_config)
339
- flow_def = cls._generate_flow_inputs(flow_def, job_config)
340
- flow_def = cls._generate_flow_outputs(flow_def, job_config)
324
+ # Auto-wiring and inference only applied to externally loaded flows for now
325
+ flow_def = cls._autowire_flow(flow_def, job_config)
326
+ flow_def = cls._apply_type_inference(flow_def, job_config)
341
327
 
342
- flow_object = _meta.ObjectDefinition(
328
+ flow_obj = _meta.ObjectDefinition(
343
329
  objectType=_meta.ObjectType.FLOW,
344
330
  flow=flow_def)
345
331
 
346
- return flow_id, flow_object
332
+ job_config = copy.copy(job_config)
333
+ job_config.job = copy.copy(job_config.job)
334
+ job_config.job.runFlow = copy.copy(job_config.job.runFlow)
335
+ job_config.resources = copy.copy(job_config.resources)
336
+
337
+ job_config = cls._add_job_resource(job_config, flow_id, flow_obj)
338
+ job_config.job.runFlow.flow = _util.selector_for(flow_id)
339
+
340
+ return job_config
347
341
 
348
342
  @classmethod
349
- def _autowire_flow(cls, flow: _meta.FlowDefinition):
343
+ def _autowire_flow(cls, flow: _meta.FlowDefinition, job_config: _cfg.JobConfig):
344
+
345
+ job = job_config.job.runFlow
346
+ nodes = copy.copy(flow.nodes)
347
+ edges: tp.Dict[str, _meta.FlowEdge] = dict()
350
348
 
351
349
  sources: tp.Dict[str, _meta.FlowSocket] = dict()
352
350
  duplicates: tp.Dict[str, tp.List[_meta.FlowSocket]] = dict()
353
-
354
- edges: tp.Dict[str, _meta.FlowEdge] = dict()
355
351
  errors: tp.Dict[str, str] = dict()
356
352
 
357
353
  def socket_key(socket: _meta.FlowSocket):
358
354
  return f"{socket.node}.{socket.socket}" if socket.socket else socket.node
359
355
 
356
+ # Before starting, add any edges defined explicitly in the flow
357
+ # These take precedence over auto-wired edges
358
+ for edge in flow.edges:
359
+ edges[socket_key(edge.target)] = edge
360
+
360
361
  def add_source(name: str, socket: _meta.FlowSocket):
361
362
  if name in duplicates:
362
363
  duplicates[name].append(socket)
@@ -366,6 +367,14 @@ class DevModeTranslator:
366
367
  else:
367
368
  sources[name] = socket
368
369
 
370
+ def add_param_to_flow(nodel_node: str, param: str):
371
+ target = f"{nodel_node}.{param}"
372
+ if target not in edges and param not in nodes:
373
+ param_node = _meta.FlowNode(_meta.FlowNodeType.PARAMETER_NODE)
374
+ nodes[param] = param_node
375
+ socket = _meta.FlowSocket(param)
376
+ add_source(param, socket)
377
+
369
378
  def add_edge(target: _meta.FlowSocket):
370
379
  target_key = socket_key(target)
371
380
  if target_key in edges:
@@ -380,23 +389,29 @@ class DevModeTranslator:
380
389
  errors[target_key] = f"Flow target {target_name} is not provided by any node"
381
390
 
382
391
  for node_name, node in flow.nodes.items():
383
- if node.nodeType == _meta.FlowNodeType.INPUT_NODE:
392
+ if node.nodeType == _meta.FlowNodeType.INPUT_NODE or node.nodeType == _meta.FlowNodeType.PARAMETER_NODE:
384
393
  add_source(node_name, _meta.FlowSocket(node_name))
385
394
  if node.nodeType == _meta.FlowNodeType.MODEL_NODE:
386
395
  for model_output in node.outputs:
387
396
  add_source(model_output, _meta.FlowSocket(node_name, model_output))
388
-
389
- # Include any edges defined explicitly in the flow
390
- # These take precedence over auto-wired edges
391
- for edge in flow.edges:
392
- edges[socket_key(edge.target)] = edge
393
-
394
- for node_name, node in flow.nodes.items():
397
+ # Generate node param sockets needed by the model
398
+ if node_name in job.models:
399
+ model_selector = job.models[node_name]
400
+ model_obj = _util.get_job_resource(model_selector, job_config)
401
+ for param_name in model_obj.model.parameters:
402
+ add_param_to_flow(node_name, param_name)
403
+ if param_name not in node.parameters:
404
+ node.parameters.append(param_name)
405
+
406
+ # Look at the new set of nodes, which includes any added by auto-wiring
407
+ for node_name, node in nodes.items():
395
408
  if node.nodeType == _meta.FlowNodeType.OUTPUT_NODE:
396
409
  add_edge(_meta.FlowSocket(node_name))
397
410
  if node.nodeType == _meta.FlowNodeType.MODEL_NODE:
398
411
  for model_input in node.inputs:
399
412
  add_edge(_meta.FlowSocket(node_name, model_input))
413
+ for model_param in node.parameters:
414
+ add_edge(_meta.FlowSocket(node_name, model_param))
400
415
 
401
416
  if any(errors):
402
417
 
@@ -408,140 +423,149 @@ class DevModeTranslator:
408
423
  raise _ex.EConfigParse(err)
409
424
 
410
425
  autowired_flow = copy.copy(flow)
426
+ autowired_flow.nodes = nodes
411
427
  autowired_flow.edges = list(edges.values())
412
428
 
413
429
  return autowired_flow
414
430
 
415
431
  @classmethod
416
- def _generate_flow_parameters(cls, flow: _meta.FlowDefinition, job_config: _cfg.JobConfig) -> _meta.FlowDefinition:
417
-
418
- params: tp.Dict[str, _meta.ModelParameter] = dict()
419
-
420
- for node_name, node in flow.nodes.items():
432
+ def _apply_type_inference(cls, flow: _meta.FlowDefinition, job_config: _cfg.JobConfig) -> _meta.FlowDefinition:
421
433
 
422
- if node.nodeType != _meta.FlowNodeType.MODEL_NODE:
423
- continue
424
-
425
- if node_name not in job_config.job.runFlow.models:
426
- err = f"No model supplied for flow model node [{node_name}]"
427
- cls._log.error(err)
428
- raise _ex.EConfigParse(err)
434
+ updated_flow = copy.copy(flow)
435
+ updated_flow.parameters = copy.copy(flow.parameters)
436
+ updated_flow.inputs = copy.copy(flow.inputs)
437
+ updated_flow.outputs = copy.copy(flow.outputs)
429
438
 
430
- model_selector = job_config.job.runFlow.models[node_name]
431
- model_obj = _util.get_job_resource(model_selector, job_config)
432
-
433
- for param_name, param in model_obj.model.parameters.items():
439
+ def socket_key(socket):
440
+ return f"{socket.node}.{socket.socket}" if socket.socket else socket.node
434
441
 
435
- if param_name not in params:
436
- params[param_name] = param
442
+ # Build a map of edges by source socket, mapping to all edges flowing from that source
443
+ edges_by_source = {socket_key(edge.source): [] for edge in flow.edges}
444
+ edges_by_target = {socket_key(edge.target): [] for edge in flow.edges}
445
+ for edge in flow.edges:
446
+ edges_by_source[socket_key(edge.source)].append(edge.target)
447
+ edges_by_target[socket_key(edge.target)].append(edge.source)
437
448
 
438
- else:
439
- existing_param = params[param_name]
449
+ for node_name, node in flow.nodes.items():
440
450
 
441
- if param.paramType != existing_param.paramType:
442
- err = f"Model parameter [{param_name}] has different types in different models"
443
- cls._log.error(err)
444
- raise _ex.EConfigParse(err)
451
+ if node.nodeType == _meta.FlowNodeType.PARAMETER_NODE and node_name not in flow.parameters:
452
+ targets = edges_by_source.get(node_name) or []
453
+ model_parameter = cls._infer_parameter(node_name, targets, job_config)
454
+ updated_flow.parameters[node_name] = model_parameter
445
455
 
446
- if param.defaultValue != existing_param.defaultValue:
447
- if existing_param.defaultValue is None:
448
- params[param_name] = param
449
- elif param.defaultValue is not None:
450
- warn = f"Model parameter [{param_name}] has different default values in different models" \
451
- + f" (using [{_types.MetadataCodec.decode_value(existing_param.defaultValue)}])"
452
- cls._log.warning(warn)
456
+ if node.nodeType == _meta.FlowNodeType.INPUT_NODE and node_name not in flow.inputs:
457
+ targets = edges_by_source.get(node_name) or []
458
+ model_input = cls._infer_input_schema(node_name, targets, job_config)
459
+ updated_flow.inputs[node_name] = model_input
453
460
 
454
- flow.parameters = params
461
+ if node.nodeType == _meta.FlowNodeType.OUTPUT_NODE and node_name not in flow.outputs:
462
+ sources = edges_by_target.get(node_name) or []
463
+ model_output = cls._infer_output_schema(node_name, sources, job_config)
464
+ updated_flow.outputs[node_name] = model_output
455
465
 
456
- return flow
466
+ return updated_flow
457
467
 
458
468
  @classmethod
459
- def _generate_flow_inputs(cls, flow: _meta.FlowDefinition, job_config: _cfg.JobConfig) -> _meta.FlowDefinition:
469
+ def _infer_parameter(
470
+ cls, param_name: str, targets: tp.List[_meta.FlowSocket],
471
+ job_config: _cfg.JobConfig) -> _meta.ModelParameter:
460
472
 
461
- inputs: tp.Dict[str, _meta.ModelInputSchema] = dict()
473
+ model_params = []
462
474
 
463
- def socket_key(socket):
464
- return f"{socket.node}.{socket.socket}" if socket.socket else socket.node
465
-
466
- # Build a map of edges by source socket, mapping to all edges flowing from that source
467
- edges = {socket_key(edge.source): [] for edge in flow.edges}
468
- for edge in flow.edges:
469
- edges[socket_key(edge.source)].append(edge)
470
-
471
- for node_name, node in flow.nodes.items():
472
-
473
- if node.nodeType != _meta.FlowNodeType.INPUT_NODE:
474
- continue
475
+ for target in targets:
475
476
 
476
- input_edges = edges.get(node_name)
477
+ model_selector = job_config.job.runFlow.models.get(target.node)
478
+ model_obj = _util.get_job_resource(model_selector, job_config)
479
+ model_param = model_obj.model.parameters.get(target.socket)
480
+ model_params.append(model_param)
477
481
 
478
- if not input_edges:
479
- err = f"Flow input [{node_name}] is not connected, so the input schema cannot be inferred" \
480
- + f" (either remove the input or connect it to a model)"
481
- cls._log.error(err)
482
- raise _ex.EConfigParse(err)
482
+ if len(model_params) == 0:
483
+ err = f"Flow parameter [{param_name}] is not connected to any models, type information cannot be inferred" \
484
+ + f" (either remove the parameter or connect it to a model)"
485
+ cls._log.error(err)
486
+ raise _ex.EJobValidation(err)
483
487
 
484
- input_schemas = []
488
+ if len(model_params) == 1:
489
+ return model_params[0]
485
490
 
486
- for edge in input_edges:
491
+ model_param = model_params[0]
487
492
 
488
- target_node = flow.nodes.get(edge.target.node) # or cls._report_error(cls._MISSING_FLOW_NODE, node_name)
489
- # cls._require(target_node.nodeType == _meta.FlowNodeType.MODEL_NODE)
493
+ for i in range(1, len(targets)):
494
+ next_param = model_params[i]
495
+ if next_param.paramType != model_param.paramType:
496
+ err = f"Parameter is ambiguous for [{param_name}]: " + \
497
+ f"Types are different for [{cls._socket_key(targets[0])}] and [{cls._socket_key(targets[i])}]"
498
+ raise _ex.EJobValidation(err)
499
+ if next_param.defaultValue is None or next_param.defaultValue != model_param.defaultValue:
500
+ model_param.defaultValue = None
490
501
 
491
- model_selector = job_config.job.runFlow.models.get(edge.target.node)
492
- model_obj = _util.get_job_resource(model_selector, job_config)
493
- model_input = model_obj.model.inputs[edge.target.socket]
494
- input_schemas.append(model_input)
502
+ return model_param
495
503
 
496
- if len(input_schemas) == 1:
497
- inputs[node_name] = input_schemas[0]
498
- else:
499
- first_schema = input_schemas[0]
500
- if all(map(lambda s: s == first_schema, input_schemas[1:])):
501
- inputs[node_name] = first_schema
502
- else:
503
- raise _ex.EJobValidation(f"Multiple models use input [{node_name}] but expect different schemas")
504
+ @classmethod
505
+ def _infer_input_schema(
506
+ cls, input_name: str, targets: tp.List[_meta.FlowSocket],
507
+ job_config: _cfg.JobConfig) -> _meta.ModelInputSchema:
504
508
 
505
- flow.inputs = inputs
509
+ model_inputs = []
506
510
 
507
- return flow
511
+ for target in targets:
508
512
 
509
- @classmethod
510
- def _generate_flow_outputs(cls, flow: _meta.FlowDefinition, job_config: _cfg.JobConfig) -> _meta.FlowDefinition:
513
+ model_selector = job_config.job.runFlow.models.get(target.node)
514
+ model_obj = _util.get_job_resource(model_selector, job_config)
515
+ model_input = model_obj.model.inputs.get(target.socket)
516
+ model_inputs.append(model_input)
511
517
 
512
- outputs: tp.Dict[str, _meta.ModelOutputSchema] = dict()
518
+ if len(model_inputs) == 0:
519
+ err = f"Flow input [{input_name}] is not connected to any models, schema cannot be inferred" \
520
+ + f" (either remove the input or connect it to a model)"
521
+ cls._log.error(err)
522
+ raise _ex.EJobValidation(err)
513
523
 
514
- def socket_key(socket):
515
- return f"{socket.node}.{socket.socket}" if socket.socket else socket.node
524
+ if len(model_inputs) == 1:
525
+ return model_inputs[0]
516
526
 
517
- # Build a map of edges by target socket, there can only be one edge per target in a valid flow
518
- edges = {socket_key(edge.target): edge for edge in flow.edges}
527
+ model_input = model_inputs[0]
519
528
 
520
- for node_name, node in flow.nodes.items():
529
+ for i in range(1, len(targets)):
530
+ next_input = model_inputs[i]
531
+ # Very strict rules on inputs, they must have the exact same schema
532
+ # The Java code includes a combineSchema() method which could be used here as well
533
+ if next_input != model_input:
534
+ raise _ex.EJobValidation(f"Multiple models use input [{input_name}] but expect different schemas")
521
535
 
522
- if node.nodeType != _meta.FlowNodeType.OUTPUT_NODE:
523
- continue
536
+ return model_input
524
537
 
525
- edge = edges.get(node_name)
538
+ @classmethod
539
+ def _infer_output_schema(
540
+ cls, output_name: str, sources: tp.List[_meta.FlowSocket],
541
+ job_config: _cfg.JobConfig) -> _meta.ModelOutputSchema:
526
542
 
527
- if not edge:
528
- err = f"Flow output [{node_name}] is not connected, so the output schema cannot be inferred" \
529
- + f" (either remove the output or connect it to a model)"
530
- cls._log.error(err)
531
- raise _ex.EConfigParse(err)
543
+ model_outputs = []
532
544
 
533
- source_node = flow.nodes.get(edge.source.node) # or cls._report_error(cls._MISSING_FLOW_NODE, node_name)
534
- # cls._require(target_node.nodeType == _meta.FlowNodeType.MODEL_NODE)
545
+ for source in sources:
535
546
 
536
- model_selector = job_config.job.runFlow.models.get(edge.source.node)
547
+ model_selector = job_config.job.runFlow.models.get(source.node)
537
548
  model_obj = _util.get_job_resource(model_selector, job_config)
538
- model_output = model_obj.model.outputs[edge.source.socket]
549
+ model_input = model_obj.model.inputs.get(source.socket)
550
+ model_outputs.append(model_input)
539
551
 
540
- outputs[node_name] = model_output
552
+ if len(model_outputs) == 0:
553
+ err = f"Flow output [{output_name}] is not connected to any models, schema cannot be inferred" \
554
+ + f" (either remove the output or connect it to a model)"
555
+ cls._log.error(err)
556
+ raise _ex.EJobValidation(err)
541
557
 
542
- flow.outputs = outputs
558
+ if len(model_outputs) > 1:
559
+ err = f"Flow output [{output_name}] is not to multiple models" \
560
+ + f" (only one model can supply one output)"
561
+ cls._log.error(err)
562
+ raise _ex.EJobValidation(err)
543
563
 
544
- return flow
564
+ return model_outputs[0]
565
+
566
+ @classmethod
567
+ def _socket_key(cls, socket):
568
+ return f"{socket.node}.{socket.socket}" if socket.socket else socket.node
545
569
 
546
570
  @classmethod
547
571
  def _process_parameters(cls, job_config: _cfg.JobConfig) -> _cfg.JobConfig:
@@ -564,19 +588,8 @@ class DevModeTranslator:
564
588
  param_specs = model_or_flow.parameters
565
589
  param_values = job_details.parameters
566
590
 
567
- encoded_params = cls._process_parameters_dict(param_specs, param_values)
568
-
569
- job_details = copy.copy(job_details)
570
- job_def = copy.copy(job_config.job)
571
- job_config = copy.copy(job_config)
572
-
573
- if job_config.job.jobType == _meta.JobType.RUN_MODEL:
574
- job_def.runModel = job_details
575
- else:
576
- job_def.runFlow = job_details
577
-
578
- job_details.parameters = encoded_params
579
- job_config.job = job_def
591
+ # Set encoded params on runModel or runFlow depending on the job type
592
+ job_details.parameters = cls._process_parameters_dict(param_specs, param_values)
580
593
 
581
594
  return job_config
582
595
 
@@ -610,9 +623,57 @@ class DevModeTranslator:
610
623
  return encoded_values
611
624
 
612
625
  @classmethod
613
- def _process_job_io(
614
- cls, sys_config, data_key, data_value, data_id, storage_id,
615
- new_unique_file=False, schema: tp.Optional[_meta.SchemaDefinition] = None):
626
+ def _process_inputs_and_outputs(cls, sys_config: _cfg.RuntimeConfig, job_config: _cfg.JobConfig) -> _cfg.JobConfig:
627
+
628
+ if job_config.job.jobType == _meta.JobType.RUN_MODEL:
629
+ job_details = job_config.job.runModel
630
+ model_obj = _util.get_job_resource(job_details.model, job_config)
631
+ required_inputs = model_obj.model.inputs
632
+
633
+ elif job_config.job.jobType == _meta.JobType.RUN_FLOW:
634
+ job_details = job_config.job.runFlow
635
+ flow_obj = _util.get_job_resource(job_details.flow, job_config)
636
+ required_inputs = flow_obj.flow.inputs
637
+
638
+ else:
639
+ return job_config
640
+
641
+ job_inputs = job_details.inputs
642
+ job_outputs = job_details.outputs
643
+ job_resources = job_config.resources
644
+
645
+ for input_key, input_value in job_inputs.items():
646
+ if not (isinstance(input_value, str) and input_value in job_resources):
647
+
648
+ input_schema = required_inputs[input_key].schema
649
+
650
+ input_id = cls._process_input_or_output(
651
+ sys_config, input_key, input_value, job_resources,
652
+ new_unique_file=False, schema=input_schema)
653
+
654
+ job_inputs[input_key] = _util.selector_for(input_id)
655
+
656
+ for output_key, output_value in job_outputs.items():
657
+ if not (isinstance(output_value, str) and output_value in job_resources):
658
+
659
+ output_id = cls._process_input_or_output(
660
+ sys_config, output_key, output_value, job_resources,
661
+ new_unique_file=True, schema=None)
662
+
663
+ job_outputs[output_key] = _util.selector_for(output_id)
664
+
665
+ return job_config
666
+
667
+ @classmethod
668
+ def _process_input_or_output(
669
+ cls, sys_config, data_key, data_value,
670
+ resources: tp.Dict[str, _meta.ObjectDefinition],
671
+ new_unique_file=False,
672
+ schema: tp.Optional[_meta.SchemaDefinition] = None) \
673
+ -> _meta.TagHeader:
674
+
675
+ data_id = _util.new_object_id(_meta.ObjectType.DATA)
676
+ storage_id = _util.new_object_id(_meta.ObjectType.STORAGE)
616
677
 
617
678
  if isinstance(data_value, str):
618
679
  storage_path = data_value
@@ -666,7 +727,10 @@ class DevModeTranslator:
666
727
  snap_index=snap_version, delta_index=1, incarnation_index=1,
667
728
  schema=schema)
668
729
 
669
- return data_obj, storage_obj
730
+ resources[_util.object_key(data_id)] = data_obj
731
+ resources[_util.object_key(storage_id)] = storage_obj
732
+
733
+ return data_id
670
734
 
671
735
  @staticmethod
672
736
  def infer_format(storage_path: str, storage_config: _cfg.StorageConfig):