tracdap-runtime 0.6.1.dev3__py3-none-any.whl → 0.6.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (102)
  1. tracdap/rt/_exec/actors.py +87 -10
  2. tracdap/rt/_exec/context.py +25 -1
  3. tracdap/rt/_exec/dev_mode.py +277 -221
  4. tracdap/rt/_exec/engine.py +79 -14
  5. tracdap/rt/_exec/functions.py +37 -8
  6. tracdap/rt/_exec/graph.py +2 -0
  7. tracdap/rt/_exec/graph_builder.py +118 -56
  8. tracdap/rt/_exec/runtime.py +108 -37
  9. tracdap/rt/_exec/server.py +345 -0
  10. tracdap/rt/_impl/config_parser.py +219 -49
  11. tracdap/rt/_impl/data.py +14 -0
  12. tracdap/rt/_impl/grpc/__init__.py +13 -0
  13. tracdap/rt/_impl/grpc/codec.py +99 -0
  14. tracdap/rt/_impl/grpc/tracdap/api/internal/runtime_pb2.py +51 -0
  15. tracdap/rt/_impl/grpc/tracdap/api/internal/runtime_pb2.pyi +61 -0
  16. tracdap/rt/_impl/grpc/tracdap/api/internal/runtime_pb2_grpc.py +183 -0
  17. tracdap/rt/_impl/grpc/tracdap/metadata/common_pb2.py +33 -0
  18. tracdap/rt/_impl/grpc/tracdap/metadata/common_pb2.pyi +34 -0
  19. tracdap/rt/{metadata → _impl/grpc/tracdap/metadata}/custom_pb2.py +5 -5
  20. tracdap/rt/_impl/grpc/tracdap/metadata/custom_pb2.pyi +15 -0
  21. tracdap/rt/_impl/grpc/tracdap/metadata/data_pb2.py +51 -0
  22. tracdap/rt/_impl/grpc/tracdap/metadata/data_pb2.pyi +115 -0
  23. tracdap/rt/_impl/grpc/tracdap/metadata/file_pb2.py +28 -0
  24. tracdap/rt/_impl/grpc/tracdap/metadata/file_pb2.pyi +22 -0
  25. tracdap/rt/_impl/grpc/tracdap/metadata/flow_pb2.py +59 -0
  26. tracdap/rt/_impl/grpc/tracdap/metadata/flow_pb2.pyi +109 -0
  27. tracdap/rt/_impl/grpc/tracdap/metadata/job_pb2.py +76 -0
  28. tracdap/rt/_impl/grpc/tracdap/metadata/job_pb2.pyi +177 -0
  29. tracdap/rt/_impl/grpc/tracdap/metadata/model_pb2.py +63 -0
  30. tracdap/rt/_impl/grpc/tracdap/metadata/model_pb2.pyi +119 -0
  31. tracdap/rt/_impl/grpc/tracdap/metadata/object_id_pb2.py +32 -0
  32. tracdap/rt/_impl/grpc/tracdap/metadata/object_id_pb2.pyi +68 -0
  33. tracdap/rt/_impl/grpc/tracdap/metadata/object_pb2.py +40 -0
  34. tracdap/rt/_impl/grpc/tracdap/metadata/object_pb2.pyi +46 -0
  35. tracdap/rt/_impl/grpc/tracdap/metadata/search_pb2.py +39 -0
  36. tracdap/rt/_impl/grpc/tracdap/metadata/search_pb2.pyi +83 -0
  37. tracdap/rt/_impl/grpc/tracdap/metadata/stoarge_pb2.py +50 -0
  38. tracdap/rt/_impl/grpc/tracdap/metadata/stoarge_pb2.pyi +89 -0
  39. tracdap/rt/_impl/grpc/tracdap/metadata/tag_pb2.py +34 -0
  40. tracdap/rt/_impl/grpc/tracdap/metadata/tag_pb2.pyi +26 -0
  41. tracdap/rt/_impl/grpc/tracdap/metadata/tag_update_pb2.py +30 -0
  42. tracdap/rt/_impl/grpc/tracdap/metadata/tag_update_pb2.pyi +34 -0
  43. tracdap/rt/_impl/grpc/tracdap/metadata/type_pb2.py +47 -0
  44. tracdap/rt/_impl/grpc/tracdap/metadata/type_pb2.pyi +101 -0
  45. tracdap/rt/_impl/guard_rails.py +26 -6
  46. tracdap/rt/_impl/models.py +25 -0
  47. tracdap/rt/_impl/static_api.py +27 -9
  48. tracdap/rt/_impl/type_system.py +17 -0
  49. tracdap/rt/_impl/validation.py +10 -0
  50. tracdap/rt/_plugins/config_local.py +49 -0
  51. tracdap/rt/_version.py +1 -1
  52. tracdap/rt/api/hook.py +10 -3
  53. tracdap/rt/api/model_api.py +22 -0
  54. tracdap/rt/api/static_api.py +79 -19
  55. tracdap/rt/config/__init__.py +3 -3
  56. tracdap/rt/config/common.py +10 -0
  57. tracdap/rt/config/platform.py +9 -19
  58. tracdap/rt/config/runtime.py +2 -0
  59. tracdap/rt/ext/config.py +34 -0
  60. tracdap/rt/ext/embed.py +1 -3
  61. tracdap/rt/ext/plugins.py +47 -6
  62. tracdap/rt/launch/cli.py +7 -5
  63. tracdap/rt/launch/launch.py +49 -12
  64. tracdap/rt/metadata/__init__.py +24 -24
  65. tracdap/rt/metadata/common.py +7 -7
  66. tracdap/rt/metadata/custom.py +2 -0
  67. tracdap/rt/metadata/data.py +28 -5
  68. tracdap/rt/metadata/file.py +2 -0
  69. tracdap/rt/metadata/flow.py +66 -4
  70. tracdap/rt/metadata/job.py +56 -16
  71. tracdap/rt/metadata/model.py +10 -0
  72. tracdap/rt/metadata/object.py +3 -0
  73. tracdap/rt/metadata/object_id.py +9 -9
  74. tracdap/rt/metadata/search.py +35 -13
  75. tracdap/rt/metadata/stoarge.py +64 -6
  76. tracdap/rt/metadata/tag_update.py +21 -7
  77. tracdap/rt/metadata/type.py +28 -13
  78. {tracdap_runtime-0.6.1.dev3.dist-info → tracdap_runtime-0.6.3.dist-info}/METADATA +22 -19
  79. tracdap_runtime-0.6.3.dist-info/RECORD +112 -0
  80. {tracdap_runtime-0.6.1.dev3.dist-info → tracdap_runtime-0.6.3.dist-info}/WHEEL +1 -1
  81. tracdap/rt/config/common_pb2.py +0 -55
  82. tracdap/rt/config/job_pb2.py +0 -42
  83. tracdap/rt/config/platform_pb2.py +0 -71
  84. tracdap/rt/config/result_pb2.py +0 -37
  85. tracdap/rt/config/runtime_pb2.py +0 -42
  86. tracdap/rt/ext/_guard.py +0 -37
  87. tracdap/rt/metadata/common_pb2.py +0 -33
  88. tracdap/rt/metadata/data_pb2.py +0 -51
  89. tracdap/rt/metadata/file_pb2.py +0 -28
  90. tracdap/rt/metadata/flow_pb2.py +0 -55
  91. tracdap/rt/metadata/job_pb2.py +0 -76
  92. tracdap/rt/metadata/model_pb2.py +0 -51
  93. tracdap/rt/metadata/object_id_pb2.py +0 -32
  94. tracdap/rt/metadata/object_pb2.py +0 -35
  95. tracdap/rt/metadata/search_pb2.py +0 -39
  96. tracdap/rt/metadata/stoarge_pb2.py +0 -50
  97. tracdap/rt/metadata/tag_pb2.py +0 -34
  98. tracdap/rt/metadata/tag_update_pb2.py +0 -30
  99. tracdap/rt/metadata/type_pb2.py +0 -48
  100. tracdap_runtime-0.6.1.dev3.dist-info/RECORD +0 -96
  101. {tracdap_runtime-0.6.1.dev3.dist-info → tracdap_runtime-0.6.3.dist-info}/LICENSE +0 -0
  102. {tracdap_runtime-0.6.1.dev3.dist-info → tracdap_runtime-0.6.3.dist-info}/top_level.txt +0 -0
@@ -34,8 +34,9 @@ DEV_MODE_JOB_CONFIG = [
34
34
  re.compile(r"job\.run(Model|Flow)\.parameters\.\w+"),
35
35
  re.compile(r"job\.run(Model|Flow)\.inputs\.\w+"),
36
36
  re.compile(r"job\.run(Model|Flow)\.outputs\.\w+"),
37
- re.compile(r"job\.run(Model|Flow)\.models\.\w+"),
38
- re.compile(r"job\.run(Model|Flow)\.flow+")]
37
+ re.compile(r"job\.runModel\.model"),
38
+ re.compile(r"job\.runFlow\.flow"),
39
+ re.compile(r"job\.runFlow\.models\.\w+")]
39
40
 
40
41
  DEV_MODE_SYS_CONFIG = []
41
42
 
@@ -45,7 +46,7 @@ class DevModeTranslator:
45
46
  _log: tp.Optional[_util.logging.Logger] = None
46
47
 
47
48
  @classmethod
48
- def translate_sys_config(cls, sys_config: _cfg.RuntimeConfig, config_dir: tp.Optional[pathlib.Path]):
49
+ def translate_sys_config(cls, sys_config: _cfg.RuntimeConfig, config_mgr: _cfg_p.ConfigManager):
49
50
 
50
51
  cls._log.info(f"Applying dev mode config translation to system config")
51
52
 
@@ -55,7 +56,7 @@ class DevModeTranslator:
55
56
  sys_config.storage = _cfg.StorageConfig()
56
57
 
57
58
  sys_config = cls._add_integrated_repo(sys_config)
58
- sys_config = cls._resolve_relative_storage_root(sys_config, config_dir)
59
+ sys_config = cls._resolve_relative_storage_root(sys_config, config_mgr)
59
60
 
60
61
  return sys_config
61
62
 
@@ -65,98 +66,30 @@ class DevModeTranslator:
65
66
  sys_config: _cfg.RuntimeConfig,
66
67
  job_config: _cfg.JobConfig,
67
68
  scratch_dir: pathlib.Path,
68
- config_dir: tp.Optional[pathlib.Path],
69
+ config_mgr: _cfg_p.ConfigManager,
69
70
  model_class: tp.Optional[_api.TracModel.__class__]) \
70
71
  -> _cfg.JobConfig:
71
72
 
72
73
  cls._log.info(f"Applying dev mode config translation to job config")
73
74
 
74
- model_loader = _models.ModelLoader(sys_config, scratch_dir)
75
- model_loader.create_scope("DEV_MODE_TRANSLATION")
76
-
77
75
  if not job_config.jobId:
78
76
  job_config = cls._process_job_id(job_config)
79
77
 
80
78
  if job_config.job.jobType is None or job_config.job.jobType == _meta.JobType.JOB_TYPE_NOT_SET:
81
79
  job_config = cls._process_job_type(job_config)
82
80
 
83
- if model_class is not None:
84
-
85
- model_id, model_obj = cls._generate_model_for_class(model_loader, model_class)
86
- job_config = cls._add_job_resource(job_config, model_id, model_obj)
87
- job_config.job.runModel.model = _util.selector_for(model_id)
81
+ # Load and populate any models provided as a Python class or class name
82
+ if job_config.job.jobType in [_meta.JobType.RUN_MODEL, _meta.JobType.RUN_FLOW]:
83
+ job_config = cls._process_models(sys_config, job_config, scratch_dir, model_class)
88
84
 
85
+ # Fow flows, load external flow definitions then perform auto-wiring and type inference
89
86
  if job_config.job.jobType == _meta.JobType.RUN_FLOW:
87
+ job_config = cls._process_flow_definition(job_config, config_mgr)
90
88
 
91
- original_models = job_config.job.runFlow.models.copy()
92
- for model_key, model_detail in original_models.items():
93
- model_id, model_obj = cls._generate_model_for_entry_point(model_loader, model_detail)
94
- job_config = cls._add_job_resource(job_config, model_id, model_obj)
95
- job_config.job.runFlow.models[model_key] = _util.selector_for(model_id)
96
-
97
- flow_id, flow_obj = cls._expand_flow_definition(job_config, config_dir)
98
- job_config = cls._add_job_resource(job_config, flow_id, flow_obj)
99
- job_config.job.runFlow.flow = _util.selector_for(flow_id)
100
-
101
- model_loader.destroy_scope("DEV_MODE_TRANSLATION")
102
-
89
+ # For run (model|flow) jobs, apply processing to the parameters, inputs and outputs
103
90
  if job_config.job.jobType in [_meta.JobType.RUN_MODEL, _meta.JobType.RUN_FLOW]:
104
91
  job_config = cls._process_parameters(job_config)
105
-
106
- if job_config.job.jobType not in [_meta.JobType.RUN_MODEL, _meta.JobType.RUN_FLOW]:
107
- return job_config
108
-
109
- run_info = job_config.job.runModel \
110
- if job_config.job.jobType == _meta.JobType.RUN_MODEL \
111
- else job_config.job.runFlow
112
-
113
- original_inputs = run_info.inputs
114
- original_outputs = run_info.outputs
115
- original_resources = job_config.resources
116
-
117
- translated_inputs = copy.copy(original_inputs)
118
- translated_outputs = copy.copy(original_outputs)
119
- translated_resources = copy.copy(job_config.resources)
120
-
121
- def process_input_or_output(data_key, data_value, is_input: bool):
122
-
123
- data_id = _util.new_object_id(_meta.ObjectType.DATA)
124
- storage_id = _util.new_object_id(_meta.ObjectType.STORAGE)
125
-
126
- if is_input:
127
- if job_config.job.jobType == _meta.JobType.RUN_MODEL:
128
- model_def = job_config.resources[_util.object_key(job_config.job.runModel.model)]
129
- schema = model_def.model.inputs[data_key].schema
130
- else:
131
- flow_def = job_config.resources[_util.object_key(job_config.job.runFlow.flow)]
132
- schema = flow_def.flow.inputs[data_key].schema
133
- else:
134
- schema = None
135
-
136
- data_obj, storage_obj = cls._process_job_io(
137
- sys_config, data_key, data_value, data_id, storage_id,
138
- new_unique_file=not is_input, schema=schema)
139
-
140
- translated_resources[_util.object_key(data_id)] = data_obj
141
- translated_resources[_util.object_key(storage_id)] = storage_obj
142
-
143
- if is_input:
144
- translated_inputs[data_key] = _util.selector_for(data_id)
145
- else:
146
- translated_outputs[data_key] = _util.selector_for(data_id)
147
-
148
- for input_key, input_value in original_inputs.items():
149
- if not (isinstance(input_value, str) and input_value in original_resources):
150
- process_input_or_output(input_key, input_value, is_input=True)
151
-
152
- for output_key, output_value in original_outputs.items():
153
- if not (isinstance(output_value, str) and output_value in original_outputs):
154
- process_input_or_output(output_key, output_value, is_input=False)
155
-
156
- job_config = copy.copy(job_config)
157
- job_config.resources = translated_resources
158
- run_info.inputs = translated_inputs
159
- run_info.outputs = translated_outputs
92
+ job_config = cls._process_inputs_and_outputs(sys_config, job_config)
160
93
 
161
94
  return job_config
162
95
 
@@ -176,7 +109,7 @@ class DevModeTranslator:
176
109
  @classmethod
177
110
  def _resolve_relative_storage_root(
178
111
  cls, sys_config: _cfg.RuntimeConfig,
179
- sys_config_path: tp.Optional[pathlib.Path]):
112
+ config_mgr: _cfg_p.ConfigManager):
180
113
 
181
114
  storage_config = copy.deepcopy(sys_config.storage)
182
115
 
@@ -195,6 +128,7 @@ class DevModeTranslator:
195
128
 
196
129
  cls._log.info(f"Resolving relative path for [{bucket_key}] local storage...")
197
130
 
131
+ sys_config_path = config_mgr.config_dir_path()
198
132
  if sys_config_path is not None:
199
133
  absolute_path = sys_config_path.joinpath(root_path).resolve()
200
134
  if absolute_path.exists():
@@ -268,6 +202,58 @@ class DevModeTranslator:
268
202
 
269
203
  return job_config
270
204
 
205
+ @classmethod
206
+ def _process_models(
207
+ cls,
208
+ sys_config: _cfg.RuntimeConfig,
209
+ job_config: _cfg.JobConfig,
210
+ scratch_dir: pathlib.Path,
211
+ model_class: tp.Optional[_api.TracModel.__class__]) \
212
+ -> _cfg.JobConfig:
213
+
214
+ model_loader = _models.ModelLoader(sys_config, scratch_dir)
215
+ model_loader.create_scope("DEV_MODE_TRANSLATION")
216
+
217
+ original_config = job_config
218
+
219
+ job_config = copy.copy(job_config)
220
+ job_config.job = copy.copy(job_config.job)
221
+ job_config.resources = copy.copy(job_config.resources)
222
+
223
+ if job_config.job.jobType == _meta.JobType.RUN_MODEL:
224
+
225
+ job_config.job.runModel = copy.copy(job_config.job.runModel)
226
+
227
+ # If a model class is supplied in code, use that to generate the model def
228
+ if model_class is not None:
229
+ model_id, model_obj = cls._generate_model_for_class(model_loader, model_class)
230
+ job_config = cls._add_job_resource(job_config, model_id, model_obj)
231
+ job_config.job.runModel.model = _util.selector_for(model_id)
232
+
233
+ # Otherwise if model specified as a string instead of a selector, apply the translation
234
+ elif isinstance(original_config.job.runModel.model, str):
235
+ model_detail = original_config.job.runModel.model
236
+ model_id, model_obj = cls._generate_model_for_entry_point(model_loader, model_detail) # noqa
237
+ job_config = cls._add_job_resource(job_config, model_id, model_obj)
238
+ job_config.job.runModel.model = _util.selector_for(model_id)
239
+
240
+ if job_config.job.jobType == _meta.JobType.RUN_FLOW:
241
+
242
+ job_config.job.runFlow = copy.copy(job_config.job.runFlow)
243
+ job_config.job.runFlow.models = copy.copy(job_config.job.runFlow.models)
244
+
245
+ for model_key, model_detail in original_config.job.runFlow.models.items():
246
+
247
+ # Only apply translation if the model is specified as a string instead of a selector
248
+ if isinstance(model_detail, str):
249
+ model_id, model_obj = cls._generate_model_for_entry_point(model_loader, model_detail)
250
+ job_config = cls._add_job_resource(job_config, model_id, model_obj)
251
+ job_config.job.runFlow.models[model_key] = _util.selector_for(model_id)
252
+
253
+ model_loader.destroy_scope("DEV_MODE_TRANSLATION")
254
+
255
+ return job_config
256
+
271
257
  @classmethod
272
258
  def _generate_model_for_class(
273
259
  cls, model_loader: _models.ModelLoader, model_class: _api.TracModel.__class__) \
@@ -306,17 +292,13 @@ class DevModeTranslator:
306
292
  return model_id, model_object
307
293
 
308
294
  @classmethod
309
- def _expand_flow_definition(
310
- cls, job_config: _cfg.JobConfig, config_dir: pathlib.Path) \
311
- -> (_meta.TagHeader, _meta.ObjectDefinition):
295
+ def _process_flow_definition(cls, job_config: _cfg.JobConfig, config_mgr: _cfg_p.ConfigManager) -> _cfg.JobConfig:
312
296
 
313
297
  flow_details = job_config.job.runFlow.flow
314
298
 
315
- # The full specification for a flow is as a tag selector for a valid job resource
316
- # This is still allowed in dev mode, in which case dev mode translation is not applied
299
+ # Do not apply translation if flow is specified as an object ID / selector (assume full config is supplied)
317
300
  if isinstance(flow_details, _meta.TagHeader) or isinstance(flow_details, _meta.TagSelector):
318
- flow_obj = _util.get_job_resource(flow_details, job_config, optional=False)
319
- return flow_details, flow_obj
301
+ return job_config
320
302
 
321
303
  # Otherwise, flow is specified as the path to dev-mode flow definition
322
304
  if not isinstance(flow_details, str):
@@ -327,36 +309,47 @@ class DevModeTranslator:
327
309
  flow_id = _util.new_object_id(_meta.ObjectType.FLOW)
328
310
  flow_key = _util.object_key(flow_id)
329
311
 
330
- cls._log.info(f"Generating flow definition for [{flow_details}] with ID = [{flow_key}]")
312
+ cls._log.info(f"Generating flow definition from [{flow_details}] with ID = [{flow_key}]")
331
313
 
332
- flow_path = config_dir.joinpath(flow_details) if config_dir is not None else pathlib.Path(flow_details)
333
- flow_parser = _cfg_p.ConfigParser(_meta.FlowDefinition)
334
- flow_raw_data = flow_parser.load_raw_config(flow_path, flow_path.name)
335
- flow_def = flow_parser.parse(flow_raw_data, flow_path.name)
314
+ flow_def = config_mgr.load_config_object(flow_details, _meta.FlowDefinition)
336
315
 
337
- flow_def = cls._autowire_flow(flow_def)
338
- flow_def = cls._generate_flow_parameters(flow_def, job_config)
339
- flow_def = cls._generate_flow_inputs(flow_def, job_config)
340
- flow_def = cls._generate_flow_outputs(flow_def, job_config)
316
+ # Auto-wiring and inference only applied to externally loaded flows for now
317
+ flow_def = cls._autowire_flow(flow_def, job_config)
318
+ flow_def = cls._apply_type_inference(flow_def, job_config)
341
319
 
342
- flow_object = _meta.ObjectDefinition(
320
+ flow_obj = _meta.ObjectDefinition(
343
321
  objectType=_meta.ObjectType.FLOW,
344
322
  flow=flow_def)
345
323
 
346
- return flow_id, flow_object
324
+ job_config = copy.copy(job_config)
325
+ job_config.job = copy.copy(job_config.job)
326
+ job_config.job.runFlow = copy.copy(job_config.job.runFlow)
327
+ job_config.resources = copy.copy(job_config.resources)
328
+
329
+ job_config = cls._add_job_resource(job_config, flow_id, flow_obj)
330
+ job_config.job.runFlow.flow = _util.selector_for(flow_id)
331
+
332
+ return job_config
347
333
 
348
334
  @classmethod
349
- def _autowire_flow(cls, flow: _meta.FlowDefinition):
335
+ def _autowire_flow(cls, flow: _meta.FlowDefinition, job_config: _cfg.JobConfig):
336
+
337
+ job = job_config.job.runFlow
338
+ nodes = copy.copy(flow.nodes)
339
+ edges: tp.Dict[str, _meta.FlowEdge] = dict()
350
340
 
351
341
  sources: tp.Dict[str, _meta.FlowSocket] = dict()
352
342
  duplicates: tp.Dict[str, tp.List[_meta.FlowSocket]] = dict()
353
-
354
- edges: tp.Dict[str, _meta.FlowEdge] = dict()
355
343
  errors: tp.Dict[str, str] = dict()
356
344
 
357
345
  def socket_key(socket: _meta.FlowSocket):
358
346
  return f"{socket.node}.{socket.socket}" if socket.socket else socket.node
359
347
 
348
+ # Before starting, add any edges defined explicitly in the flow
349
+ # These take precedence over auto-wired edges
350
+ for edge in flow.edges:
351
+ edges[socket_key(edge.target)] = edge
352
+
360
353
  def add_source(name: str, socket: _meta.FlowSocket):
361
354
  if name in duplicates:
362
355
  duplicates[name].append(socket)
@@ -366,6 +359,14 @@ class DevModeTranslator:
366
359
  else:
367
360
  sources[name] = socket
368
361
 
362
+ def add_param_to_flow(nodel_node: str, param: str):
363
+ target = f"{nodel_node}.{param}"
364
+ if target not in edges and param not in nodes:
365
+ param_node = _meta.FlowNode(_meta.FlowNodeType.PARAMETER_NODE)
366
+ nodes[param] = param_node
367
+ socket = _meta.FlowSocket(param)
368
+ add_source(param, socket)
369
+
369
370
  def add_edge(target: _meta.FlowSocket):
370
371
  target_key = socket_key(target)
371
372
  if target_key in edges:
@@ -380,23 +381,29 @@ class DevModeTranslator:
380
381
  errors[target_key] = f"Flow target {target_name} is not provided by any node"
381
382
 
382
383
  for node_name, node in flow.nodes.items():
383
- if node.nodeType == _meta.FlowNodeType.INPUT_NODE:
384
+ if node.nodeType == _meta.FlowNodeType.INPUT_NODE or node.nodeType == _meta.FlowNodeType.PARAMETER_NODE:
384
385
  add_source(node_name, _meta.FlowSocket(node_name))
385
386
  if node.nodeType == _meta.FlowNodeType.MODEL_NODE:
386
387
  for model_output in node.outputs:
387
388
  add_source(model_output, _meta.FlowSocket(node_name, model_output))
388
-
389
- # Include any edges defined explicitly in the flow
390
- # These take precedence over auto-wired edges
391
- for edge in flow.edges:
392
- edges[socket_key(edge.target)] = edge
393
-
394
- for node_name, node in flow.nodes.items():
389
+ # Generate node param sockets needed by the model
390
+ if node_name in job.models:
391
+ model_selector = job.models[node_name]
392
+ model_obj = _util.get_job_resource(model_selector, job_config)
393
+ for param_name in model_obj.model.parameters:
394
+ add_param_to_flow(node_name, param_name)
395
+ if param_name not in node.parameters:
396
+ node.parameters.append(param_name)
397
+
398
+ # Look at the new set of nodes, which includes any added by auto-wiring
399
+ for node_name, node in nodes.items():
395
400
  if node.nodeType == _meta.FlowNodeType.OUTPUT_NODE:
396
401
  add_edge(_meta.FlowSocket(node_name))
397
402
  if node.nodeType == _meta.FlowNodeType.MODEL_NODE:
398
403
  for model_input in node.inputs:
399
404
  add_edge(_meta.FlowSocket(node_name, model_input))
405
+ for model_param in node.parameters:
406
+ add_edge(_meta.FlowSocket(node_name, model_param))
400
407
 
401
408
  if any(errors):
402
409
 
@@ -408,140 +415,149 @@ class DevModeTranslator:
408
415
  raise _ex.EConfigParse(err)
409
416
 
410
417
  autowired_flow = copy.copy(flow)
418
+ autowired_flow.nodes = nodes
411
419
  autowired_flow.edges = list(edges.values())
412
420
 
413
421
  return autowired_flow
414
422
 
415
423
  @classmethod
416
- def _generate_flow_parameters(cls, flow: _meta.FlowDefinition, job_config: _cfg.JobConfig) -> _meta.FlowDefinition:
417
-
418
- params: tp.Dict[str, _meta.ModelParameter] = dict()
424
+ def _apply_type_inference(cls, flow: _meta.FlowDefinition, job_config: _cfg.JobConfig) -> _meta.FlowDefinition:
419
425
 
420
- for node_name, node in flow.nodes.items():
421
-
422
- if node.nodeType != _meta.FlowNodeType.MODEL_NODE:
423
- continue
426
+ updated_flow = copy.copy(flow)
427
+ updated_flow.parameters = copy.copy(flow.parameters)
428
+ updated_flow.inputs = copy.copy(flow.inputs)
429
+ updated_flow.outputs = copy.copy(flow.outputs)
424
430
 
425
- if node_name not in job_config.job.runFlow.models:
426
- err = f"No model supplied for flow model node [{node_name}]"
427
- cls._log.error(err)
428
- raise _ex.EConfigParse(err)
429
-
430
- model_selector = job_config.job.runFlow.models[node_name]
431
- model_obj = _util.get_job_resource(model_selector, job_config)
432
-
433
- for param_name, param in model_obj.model.parameters.items():
431
+ def socket_key(socket):
432
+ return f"{socket.node}.{socket.socket}" if socket.socket else socket.node
434
433
 
435
- if param_name not in params:
436
- params[param_name] = param
434
+ # Build a map of edges by source socket, mapping to all edges flowing from that source
435
+ edges_by_source = {socket_key(edge.source): [] for edge in flow.edges}
436
+ edges_by_target = {socket_key(edge.target): [] for edge in flow.edges}
437
+ for edge in flow.edges:
438
+ edges_by_source[socket_key(edge.source)].append(edge.target)
439
+ edges_by_target[socket_key(edge.target)].append(edge.source)
437
440
 
438
- else:
439
- existing_param = params[param_name]
441
+ for node_name, node in flow.nodes.items():
440
442
 
441
- if param.paramType != existing_param.paramType:
442
- err = f"Model parameter [{param_name}] has different types in different models"
443
- cls._log.error(err)
444
- raise _ex.EConfigParse(err)
443
+ if node.nodeType == _meta.FlowNodeType.PARAMETER_NODE and node_name not in flow.parameters:
444
+ targets = edges_by_source.get(node_name) or []
445
+ model_parameter = cls._infer_parameter(node_name, targets, job_config)
446
+ updated_flow.parameters[node_name] = model_parameter
445
447
 
446
- if param.defaultValue != existing_param.defaultValue:
447
- if existing_param.defaultValue is None:
448
- params[param_name] = param
449
- elif param.defaultValue is not None:
450
- warn = f"Model parameter [{param_name}] has different default values in different models" \
451
- + f" (using [{_types.MetadataCodec.decode_value(existing_param.defaultValue)}])"
452
- cls._log.warning(warn)
448
+ if node.nodeType == _meta.FlowNodeType.INPUT_NODE and node_name not in flow.inputs:
449
+ targets = edges_by_source.get(node_name) or []
450
+ model_input = cls._infer_input_schema(node_name, targets, job_config)
451
+ updated_flow.inputs[node_name] = model_input
453
452
 
454
- flow.parameters = params
453
+ if node.nodeType == _meta.FlowNodeType.OUTPUT_NODE and node_name not in flow.outputs:
454
+ sources = edges_by_target.get(node_name) or []
455
+ model_output = cls._infer_output_schema(node_name, sources, job_config)
456
+ updated_flow.outputs[node_name] = model_output
455
457
 
456
- return flow
458
+ return updated_flow
457
459
 
458
460
  @classmethod
459
- def _generate_flow_inputs(cls, flow: _meta.FlowDefinition, job_config: _cfg.JobConfig) -> _meta.FlowDefinition:
461
+ def _infer_parameter(
462
+ cls, param_name: str, targets: tp.List[_meta.FlowSocket],
463
+ job_config: _cfg.JobConfig) -> _meta.ModelParameter:
460
464
 
461
- inputs: tp.Dict[str, _meta.ModelInputSchema] = dict()
465
+ model_params = []
462
466
 
463
- def socket_key(socket):
464
- return f"{socket.node}.{socket.socket}" if socket.socket else socket.node
467
+ for target in targets:
465
468
 
466
- # Build a map of edges by source socket, mapping to all edges flowing from that source
467
- edges = {socket_key(edge.source): [] for edge in flow.edges}
468
- for edge in flow.edges:
469
- edges[socket_key(edge.source)].append(edge)
470
-
471
- for node_name, node in flow.nodes.items():
472
-
473
- if node.nodeType != _meta.FlowNodeType.INPUT_NODE:
474
- continue
475
-
476
- input_edges = edges.get(node_name)
469
+ model_selector = job_config.job.runFlow.models.get(target.node)
470
+ model_obj = _util.get_job_resource(model_selector, job_config)
471
+ model_param = model_obj.model.parameters.get(target.socket)
472
+ model_params.append(model_param)
477
473
 
478
- if not input_edges:
479
- err = f"Flow input [{node_name}] is not connected, so the input schema cannot be inferred" \
480
- + f" (either remove the input or connect it to a model)"
481
- cls._log.error(err)
482
- raise _ex.EConfigParse(err)
474
+ if len(model_params) == 0:
475
+ err = f"Flow parameter [{param_name}] is not connected to any models, type information cannot be inferred" \
476
+ + f" (either remove the parameter or connect it to a model)"
477
+ cls._log.error(err)
478
+ raise _ex.EJobValidation(err)
483
479
 
484
- input_schemas = []
480
+ if len(model_params) == 1:
481
+ return model_params[0]
485
482
 
486
- for edge in input_edges:
483
+ model_param = model_params[0]
487
484
 
488
- target_node = flow.nodes.get(edge.target.node) # or cls._report_error(cls._MISSING_FLOW_NODE, node_name)
489
- # cls._require(target_node.nodeType == _meta.FlowNodeType.MODEL_NODE)
485
+ for i in range(1, len(targets)):
486
+ next_param = model_params[i]
487
+ if next_param.paramType != model_param.paramType:
488
+ err = f"Parameter is ambiguous for [{param_name}]: " + \
489
+ f"Types are different for [{cls._socket_key(targets[0])}] and [{cls._socket_key(targets[i])}]"
490
+ raise _ex.EJobValidation(err)
491
+ if next_param.defaultValue is None or next_param.defaultValue != model_param.defaultValue:
492
+ model_param.defaultValue = None
490
493
 
491
- model_selector = job_config.job.runFlow.models.get(edge.target.node)
492
- model_obj = _util.get_job_resource(model_selector, job_config)
493
- model_input = model_obj.model.inputs[edge.target.socket]
494
- input_schemas.append(model_input)
494
+ return model_param
495
495
 
496
- if len(input_schemas) == 1:
497
- inputs[node_name] = input_schemas[0]
498
- else:
499
- first_schema = input_schemas[0]
500
- if all(map(lambda s: s == first_schema, input_schemas[1:])):
501
- inputs[node_name] = first_schema
502
- else:
503
- raise _ex.EJobValidation(f"Multiple models use input [{node_name}] but expect different schemas")
496
+ @classmethod
497
+ def _infer_input_schema(
498
+ cls, input_name: str, targets: tp.List[_meta.FlowSocket],
499
+ job_config: _cfg.JobConfig) -> _meta.ModelInputSchema:
504
500
 
505
- flow.inputs = inputs
501
+ model_inputs = []
506
502
 
507
- return flow
503
+ for target in targets:
508
504
 
509
- @classmethod
510
- def _generate_flow_outputs(cls, flow: _meta.FlowDefinition, job_config: _cfg.JobConfig) -> _meta.FlowDefinition:
505
+ model_selector = job_config.job.runFlow.models.get(target.node)
506
+ model_obj = _util.get_job_resource(model_selector, job_config)
507
+ model_input = model_obj.model.inputs.get(target.socket)
508
+ model_inputs.append(model_input)
511
509
 
512
- outputs: tp.Dict[str, _meta.ModelOutputSchema] = dict()
510
+ if len(model_inputs) == 0:
511
+ err = f"Flow input [{input_name}] is not connected to any models, schema cannot be inferred" \
512
+ + f" (either remove the input or connect it to a model)"
513
+ cls._log.error(err)
514
+ raise _ex.EJobValidation(err)
513
515
 
514
- def socket_key(socket):
515
- return f"{socket.node}.{socket.socket}" if socket.socket else socket.node
516
+ if len(model_inputs) == 1:
517
+ return model_inputs[0]
516
518
 
517
- # Build a map of edges by target socket, there can only be one edge per target in a valid flow
518
- edges = {socket_key(edge.target): edge for edge in flow.edges}
519
+ model_input = model_inputs[0]
519
520
 
520
- for node_name, node in flow.nodes.items():
521
+ for i in range(1, len(targets)):
522
+ next_input = model_inputs[i]
523
+ # Very strict rules on inputs, they must have the exact same schema
524
+ # The Java code includes a combineSchema() method which could be used here as well
525
+ if next_input != model_input:
526
+ raise _ex.EJobValidation(f"Multiple models use input [{input_name}] but expect different schemas")
521
527
 
522
- if node.nodeType != _meta.FlowNodeType.OUTPUT_NODE:
523
- continue
528
+ return model_input
524
529
 
525
- edge = edges.get(node_name)
530
+ @classmethod
531
+ def _infer_output_schema(
532
+ cls, output_name: str, sources: tp.List[_meta.FlowSocket],
533
+ job_config: _cfg.JobConfig) -> _meta.ModelOutputSchema:
526
534
 
527
- if not edge:
528
- err = f"Flow output [{node_name}] is not connected, so the output schema cannot be inferred" \
529
- + f" (either remove the output or connect it to a model)"
530
- cls._log.error(err)
531
- raise _ex.EConfigParse(err)
535
+ model_outputs = []
532
536
 
533
- source_node = flow.nodes.get(edge.source.node) # or cls._report_error(cls._MISSING_FLOW_NODE, node_name)
534
- # cls._require(target_node.nodeType == _meta.FlowNodeType.MODEL_NODE)
537
+ for source in sources:
535
538
 
536
- model_selector = job_config.job.runFlow.models.get(edge.source.node)
539
+ model_selector = job_config.job.runFlow.models.get(source.node)
537
540
  model_obj = _util.get_job_resource(model_selector, job_config)
538
- model_output = model_obj.model.outputs[edge.source.socket]
541
+ model_input = model_obj.model.inputs.get(source.socket)
542
+ model_outputs.append(model_input)
539
543
 
540
- outputs[node_name] = model_output
544
+ if len(model_outputs) == 0:
545
+ err = f"Flow output [{output_name}] is not connected to any models, schema cannot be inferred" \
546
+ + f" (either remove the output or connect it to a model)"
547
+ cls._log.error(err)
548
+ raise _ex.EJobValidation(err)
541
549
 
542
- flow.outputs = outputs
550
+ if len(model_outputs) > 1:
551
+ err = f"Flow output [{output_name}] is not to multiple models" \
552
+ + f" (only one model can supply one output)"
553
+ cls._log.error(err)
554
+ raise _ex.EJobValidation(err)
543
555
 
544
- return flow
556
+ return model_outputs[0]
557
+
558
+ @classmethod
559
+ def _socket_key(cls, socket):
560
+ return f"{socket.node}.{socket.socket}" if socket.socket else socket.node
545
561
 
546
562
  @classmethod
547
563
  def _process_parameters(cls, job_config: _cfg.JobConfig) -> _cfg.JobConfig:
@@ -564,19 +580,8 @@ class DevModeTranslator:
564
580
  param_specs = model_or_flow.parameters
565
581
  param_values = job_details.parameters
566
582
 
567
- encoded_params = cls._process_parameters_dict(param_specs, param_values)
568
-
569
- job_details = copy.copy(job_details)
570
- job_def = copy.copy(job_config.job)
571
- job_config = copy.copy(job_config)
572
-
573
- if job_config.job.jobType == _meta.JobType.RUN_MODEL:
574
- job_def.runModel = job_details
575
- else:
576
- job_def.runFlow = job_details
577
-
578
- job_details.parameters = encoded_params
579
- job_config.job = job_def
583
+ # Set encoded params on runModel or runFlow depending on the job type
584
+ job_details.parameters = cls._process_parameters_dict(param_specs, param_values)
580
585
 
581
586
  return job_config
582
587
 
@@ -610,9 +615,57 @@ class DevModeTranslator:
610
615
  return encoded_values
611
616
 
612
617
  @classmethod
613
- def _process_job_io(
614
- cls, sys_config, data_key, data_value, data_id, storage_id,
615
- new_unique_file=False, schema: tp.Optional[_meta.SchemaDefinition] = None):
618
+ def _process_inputs_and_outputs(cls, sys_config: _cfg.RuntimeConfig, job_config: _cfg.JobConfig) -> _cfg.JobConfig:
619
+
620
+ if job_config.job.jobType == _meta.JobType.RUN_MODEL:
621
+ job_details = job_config.job.runModel
622
+ model_obj = _util.get_job_resource(job_details.model, job_config)
623
+ required_inputs = model_obj.model.inputs
624
+
625
+ elif job_config.job.jobType == _meta.JobType.RUN_FLOW:
626
+ job_details = job_config.job.runFlow
627
+ flow_obj = _util.get_job_resource(job_details.flow, job_config)
628
+ required_inputs = flow_obj.flow.inputs
629
+
630
+ else:
631
+ return job_config
632
+
633
+ job_inputs = job_details.inputs
634
+ job_outputs = job_details.outputs
635
+ job_resources = job_config.resources
636
+
637
+ for input_key, input_value in job_inputs.items():
638
+ if not (isinstance(input_value, str) and input_value in job_resources):
639
+
640
+ input_schema = required_inputs[input_key].schema
641
+
642
+ input_id = cls._process_input_or_output(
643
+ sys_config, input_key, input_value, job_resources,
644
+ new_unique_file=False, schema=input_schema)
645
+
646
+ job_inputs[input_key] = _util.selector_for(input_id)
647
+
648
+ for output_key, output_value in job_outputs.items():
649
+ if not (isinstance(output_value, str) and output_value in job_resources):
650
+
651
+ output_id = cls._process_input_or_output(
652
+ sys_config, output_key, output_value, job_resources,
653
+ new_unique_file=True, schema=None)
654
+
655
+ job_outputs[output_key] = _util.selector_for(output_id)
656
+
657
+ return job_config
658
+
659
+ @classmethod
660
+ def _process_input_or_output(
661
+ cls, sys_config, data_key, data_value,
662
+ resources: tp.Dict[str, _meta.ObjectDefinition],
663
+ new_unique_file=False,
664
+ schema: tp.Optional[_meta.SchemaDefinition] = None) \
665
+ -> _meta.TagHeader:
666
+
667
+ data_id = _util.new_object_id(_meta.ObjectType.DATA)
668
+ storage_id = _util.new_object_id(_meta.ObjectType.STORAGE)
616
669
 
617
670
  if isinstance(data_value, str):
618
671
  storage_path = data_value
@@ -666,7 +719,10 @@ class DevModeTranslator:
666
719
  snap_index=snap_version, delta_index=1, incarnation_index=1,
667
720
  schema=schema)
668
721
 
669
- return data_obj, storage_obj
722
+ resources[_util.object_key(data_id)] = data_obj
723
+ resources[_util.object_key(storage_id)] = storage_obj
724
+
725
+ return data_id
670
726
 
671
727
  @staticmethod
672
728
  def infer_format(storage_path: str, storage_config: _cfg.StorageConfig):