tracdap-runtime 0.6.3__py3-none-any.whl → 0.6.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tracdap/rt/_exec/context.py +207 -100
- tracdap/rt/_exec/dev_mode.py +43 -3
- tracdap/rt/_exec/functions.py +14 -17
- tracdap/rt/_impl/data.py +70 -5
- tracdap/rt/_impl/grpc/tracdap/metadata/data_pb2.py +18 -18
- tracdap/rt/_impl/grpc/tracdap/metadata/model_pb2.py +18 -18
- tracdap/rt/_impl/grpc/tracdap/metadata/model_pb2.pyi +8 -4
- tracdap/rt/_impl/static_api.py +26 -10
- tracdap/rt/_impl/validation.py +37 -4
- tracdap/rt/_version.py +1 -1
- tracdap/rt/api/hook.py +2 -4
- tracdap/rt/api/model_api.py +50 -7
- tracdap/rt/api/static_api.py +14 -6
- tracdap/rt/config/common.py +17 -17
- tracdap/rt/config/job.py +2 -2
- tracdap/rt/config/platform.py +25 -25
- tracdap/rt/config/result.py +2 -2
- tracdap/rt/config/runtime.py +3 -3
- tracdap/rt/launch/cli.py +7 -4
- tracdap/rt/launch/launch.py +19 -3
- tracdap/rt/metadata/common.py +2 -2
- tracdap/rt/metadata/custom.py +3 -3
- tracdap/rt/metadata/data.py +12 -12
- tracdap/rt/metadata/file.py +6 -6
- tracdap/rt/metadata/flow.py +6 -6
- tracdap/rt/metadata/job.py +8 -8
- tracdap/rt/metadata/model.py +15 -11
- tracdap/rt/metadata/object_id.py +8 -8
- tracdap/rt/metadata/search.py +5 -5
- tracdap/rt/metadata/stoarge.py +6 -6
- tracdap/rt/metadata/tag.py +1 -1
- tracdap/rt/metadata/tag_update.py +1 -1
- tracdap/rt/metadata/type.py +4 -4
- {tracdap_runtime-0.6.3.dist-info → tracdap_runtime-0.6.4.dist-info}/METADATA +1 -1
- {tracdap_runtime-0.6.3.dist-info → tracdap_runtime-0.6.4.dist-info}/RECORD +38 -38
- {tracdap_runtime-0.6.3.dist-info → tracdap_runtime-0.6.4.dist-info}/LICENSE +0 -0
- {tracdap_runtime-0.6.3.dist-info → tracdap_runtime-0.6.4.dist-info}/WHEEL +0 -0
- {tracdap_runtime-0.6.3.dist-info → tracdap_runtime-0.6.4.dist-info}/top_level.txt +0 -0
tracdap/rt/_exec/context.py
CHANGED
@@ -12,6 +12,7 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
+
import copy
|
15
16
|
import logging
|
16
17
|
import pathlib
|
17
18
|
import typing as tp
|
@@ -32,8 +33,8 @@ import tracdap.rt._impl.validation as _val # noqa
|
|
32
33
|
class TracContextImpl(_api.TracContext):
|
33
34
|
|
34
35
|
"""
|
35
|
-
TracContextImpl is the main implementation of the API class TracContext (from
|
36
|
-
It provides get/put operations the inputs, outputs and parameters of a model according to the model definition,
|
36
|
+
TracContextImpl is the main implementation of the API class TracContext (from tracdap.rt.api).
|
37
|
+
It provides get/put operations on the inputs, outputs and parameters of a model according to the model definition,
|
37
38
|
as well as exposing other information needed by the model at runtime and offering a few utility functions.
|
38
39
|
|
39
40
|
An instance of TracContextImpl is constructed by the runtime engine for each model node in the execution graph.
|
@@ -44,8 +45,8 @@ class TracContextImpl(_api.TracContext):
|
|
44
45
|
|
45
46
|
Optimizations for lazy loading and eager saving require the context to call back into the runtime engine. For lazy
|
46
47
|
load, the graph node to prepare an input is injected when the data is requested and the model thread blocks until
|
47
|
-
it is available; for eager save
|
48
|
-
|
48
|
+
it is available; for eager save outputs are sent to child actors as soon as they are produced. In both cases this
|
49
|
+
complexity is hidden from the model, which only sees one thread with synchronous get/put calls.
|
49
50
|
|
50
51
|
:param model_def: Definition object for the model that will run in this context
|
51
52
|
:param model_class: Type for the model that will run in this context
|
@@ -59,8 +60,7 @@ class TracContextImpl(_api.TracContext):
|
|
59
60
|
def __init__(self,
|
60
61
|
model_def: _meta.ModelDefinition,
|
61
62
|
model_class: _api.TracModel.__class__,
|
62
|
-
local_ctx: tp.Dict[str,
|
63
|
-
schemas: tp.Dict[str, _meta.SchemaDefinition],
|
63
|
+
local_ctx: tp.Dict[str, tp.Any],
|
64
64
|
checkout_directory: pathlib.Path = None):
|
65
65
|
|
66
66
|
self.__ctx_log = _util.logger_for_object(self)
|
@@ -68,26 +68,25 @@ class TracContextImpl(_api.TracContext):
|
|
68
68
|
|
69
69
|
self.__model_def = model_def
|
70
70
|
self.__model_class = model_class
|
71
|
-
|
72
|
-
self.__parameters = local_ctx or {}
|
73
|
-
self.__data = local_ctx or {}
|
74
|
-
self.__schemas = schemas
|
71
|
+
self.__local_ctx = local_ctx or {}
|
75
72
|
|
76
73
|
self.__val = TracContextValidator(
|
77
74
|
self.__ctx_log,
|
78
|
-
self.
|
79
|
-
self.
|
75
|
+
self.__model_def,
|
76
|
+
self.__local_ctx,
|
80
77
|
checkout_directory)
|
81
78
|
|
82
79
|
def get_parameter(self, parameter_name: str) -> tp.Any:
|
83
80
|
|
84
81
|
_val.validate_signature(self.get_parameter, parameter_name)
|
85
82
|
|
86
|
-
self.__val.check_param_not_null(parameter_name)
|
87
83
|
self.__val.check_param_valid_identifier(parameter_name)
|
88
|
-
self.__val.
|
84
|
+
self.__val.check_param_defined_in_model(parameter_name)
|
85
|
+
self.__val.check_param_available_in_context(parameter_name)
|
86
|
+
|
87
|
+
value: _meta.Value = self.__local_ctx.get(parameter_name)
|
89
88
|
|
90
|
-
value
|
89
|
+
self.__val.check_context_object_type(parameter_name, value, _meta.Value)
|
91
90
|
|
92
91
|
return _types.MetadataCodec.decode_value(value)
|
93
92
|
|
@@ -95,65 +94,64 @@ class TracContextImpl(_api.TracContext):
|
|
95
94
|
|
96
95
|
_val.validate_signature(self.has_dataset, dataset_name)
|
97
96
|
|
98
|
-
part_key = _data.DataPartKey.for_root()
|
99
|
-
|
100
|
-
self.__val.check_dataset_name_not_null(dataset_name)
|
101
97
|
self.__val.check_dataset_valid_identifier(dataset_name)
|
98
|
+
self.__val.check_dataset_defined_in_model(dataset_name)
|
102
99
|
|
103
|
-
data_view = self.
|
100
|
+
data_view: _data.DataView = self.__local_ctx.get(dataset_name)
|
104
101
|
|
105
102
|
if data_view is None:
|
106
103
|
return False
|
107
104
|
|
108
|
-
|
109
|
-
# E.g. if this method is called for FILE inputs
|
110
|
-
self.__val.check_context_item_is_dataset(dataset_name)
|
111
|
-
|
112
|
-
part = data_view.parts.get(part_key)
|
105
|
+
self.__val.check_context_object_type(dataset_name, data_view, _data.DataView)
|
113
106
|
|
114
|
-
|
115
|
-
return False
|
116
|
-
|
117
|
-
return True
|
107
|
+
return not data_view.is_empty()
|
118
108
|
|
119
109
|
def get_schema(self, dataset_name: str) -> _meta.SchemaDefinition:
|
120
110
|
|
121
111
|
_val.validate_signature(self.get_schema, dataset_name)
|
122
112
|
|
123
|
-
self.__val.check_dataset_name_not_null(dataset_name)
|
124
113
|
self.__val.check_dataset_valid_identifier(dataset_name)
|
114
|
+
self.__val.check_dataset_defined_in_model(dataset_name)
|
115
|
+
self.__val.check_dataset_available_in_context(dataset_name)
|
125
116
|
|
126
|
-
|
127
|
-
|
128
|
-
return self.__schemas[dataset_name]
|
117
|
+
static_schema = self.__get_static_schema(self.__model_def, dataset_name)
|
118
|
+
data_view: _data.DataView = self.__local_ctx.get(dataset_name)
|
129
119
|
|
130
|
-
|
131
|
-
|
132
|
-
self.__val.check_dataset_schema_defined(dataset_name)
|
120
|
+
# Check the data view has a well-defined schema even if a static schema exists in the model
|
121
|
+
# This ensures errors are always reported and is consistent with get_pandas_table()
|
133
122
|
|
134
|
-
data_view
|
123
|
+
self.__val.check_context_object_type(dataset_name, data_view, _data.DataView)
|
124
|
+
self.__val.check_dataset_schema_defined(dataset_name, data_view)
|
135
125
|
|
136
|
-
|
126
|
+
# If a static schema exists, that takes priority
|
127
|
+
# Return deep copies, do not allow model code to change schemas provided by the engine
|
128
|
+
|
129
|
+
if static_schema is not None:
|
130
|
+
return copy.deepcopy(static_schema)
|
131
|
+
else:
|
132
|
+
return copy.deepcopy(data_view.trac_schema)
|
137
133
|
|
138
134
|
def get_pandas_table(self, dataset_name: str, use_temporal_objects: tp.Optional[bool] = None) -> pd.DataFrame:
|
139
135
|
|
140
136
|
_val.validate_signature(self.get_pandas_table, dataset_name, use_temporal_objects)
|
141
137
|
|
142
|
-
part_key = _data.DataPartKey.for_root()
|
143
|
-
|
144
|
-
self.__val.check_dataset_name_not_null(dataset_name)
|
145
138
|
self.__val.check_dataset_valid_identifier(dataset_name)
|
146
|
-
self.__val.
|
147
|
-
self.__val.
|
148
|
-
self.__val.check_dataset_schema_defined(dataset_name)
|
149
|
-
self.__val.check_dataset_part_present(dataset_name, part_key)
|
139
|
+
self.__val.check_dataset_defined_in_model(dataset_name)
|
140
|
+
self.__val.check_dataset_available_in_context(dataset_name)
|
150
141
|
|
151
|
-
|
142
|
+
static_schema = self.__get_static_schema(self.__model_def, dataset_name)
|
143
|
+
data_view = self.__local_ctx.get(dataset_name)
|
144
|
+
part_key = _data.DataPartKey.for_root()
|
145
|
+
|
146
|
+
self.__val.check_context_object_type(dataset_name, data_view, _data.DataView)
|
147
|
+
self.__val.check_dataset_schema_defined(dataset_name, data_view)
|
148
|
+
self.__val.check_dataset_part_present(dataset_name, data_view, part_key)
|
152
149
|
|
153
150
|
# If the model defines a static input schema, use that for schema conformance
|
154
151
|
# Otherwise, take what is in the incoming dataset (schema is dynamic)
|
155
|
-
|
156
|
-
|
152
|
+
|
153
|
+
if static_schema is not None:
|
154
|
+
schema = _data.DataMapping.trac_to_arrow_schema(static_schema)
|
157
155
|
else:
|
158
156
|
schema = data_view.arrow_schema
|
159
157
|
|
@@ -162,34 +160,71 @@ class TracContextImpl(_api.TracContext):
|
|
162
160
|
|
163
161
|
return _data.DataMapping.view_to_pandas(data_view, part_key, schema, use_temporal_objects)
|
164
162
|
|
163
|
+
def put_schema(self, dataset_name: str, schema: _meta.SchemaDefinition):
|
164
|
+
|
165
|
+
_val.validate_signature(self.get_schema, dataset_name, schema)
|
166
|
+
|
167
|
+
# Copy the schema - schema cannot be changed in model code after put_schema
|
168
|
+
# If field ordering is not assigned by the model, assign it here (model code will not see the numbers)
|
169
|
+
schema_copy = self.__assign_field_order(copy.deepcopy(schema))
|
170
|
+
|
171
|
+
self.__val.check_dataset_valid_identifier(dataset_name)
|
172
|
+
self.__val.check_dataset_is_dynamic_output(dataset_name)
|
173
|
+
self.__val.check_provided_schema_is_valid(dataset_name, schema_copy)
|
174
|
+
|
175
|
+
static_schema = self.__get_static_schema(self.__model_def, dataset_name)
|
176
|
+
data_view = self.__local_ctx.get(dataset_name)
|
177
|
+
|
178
|
+
if data_view is None:
|
179
|
+
if static_schema is not None:
|
180
|
+
data_view = _data.DataView.for_trac_schema(static_schema)
|
181
|
+
else:
|
182
|
+
data_view = _data.DataView.create_empty()
|
183
|
+
|
184
|
+
# If there is a prior view it must contain nothing and will be replaced
|
185
|
+
self.__val.check_context_object_type(dataset_name, data_view, _data.DataView)
|
186
|
+
self.__val.check_dataset_schema_not_defined(dataset_name, data_view)
|
187
|
+
self.__val.check_dataset_is_empty(dataset_name, data_view)
|
188
|
+
|
189
|
+
updated_view = data_view.with_trac_schema(schema_copy)
|
190
|
+
|
191
|
+
self.__local_ctx[dataset_name] = updated_view
|
192
|
+
|
165
193
|
def put_pandas_table(self, dataset_name: str, dataset: pd.DataFrame):
|
166
194
|
|
167
195
|
_val.validate_signature(self.put_pandas_table, dataset_name, dataset)
|
168
196
|
|
169
|
-
part_key = _data.DataPartKey.for_root()
|
170
|
-
|
171
|
-
self.__val.check_dataset_name_not_null(dataset_name)
|
172
197
|
self.__val.check_dataset_valid_identifier(dataset_name)
|
173
|
-
self.__val.
|
174
|
-
self.__val.check_context_item_is_dataset(dataset_name)
|
175
|
-
self.__val.check_dataset_schema_defined(dataset_name)
|
176
|
-
self.__val.check_dataset_part_not_present(dataset_name, part_key)
|
177
|
-
self.__val.check_provided_dataset_not_null(dataset)
|
198
|
+
self.__val.check_dataset_is_model_output(dataset_name)
|
178
199
|
self.__val.check_provided_dataset_type(dataset, pd.DataFrame)
|
179
200
|
|
180
|
-
|
201
|
+
static_schema = self.__get_static_schema(self.__model_def, dataset_name)
|
202
|
+
data_view = self.__local_ctx.get(dataset_name)
|
203
|
+
part_key = _data.DataPartKey.for_root()
|
181
204
|
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
205
|
+
if data_view is None:
|
206
|
+
if static_schema is not None:
|
207
|
+
data_view = _data.DataView.for_trac_schema(static_schema)
|
208
|
+
else:
|
209
|
+
data_view = _data.DataView.create_empty()
|
210
|
+
|
211
|
+
self.__val.check_context_object_type(dataset_name, data_view, _data.DataView)
|
212
|
+
self.__val.check_dataset_schema_defined(dataset_name, data_view)
|
213
|
+
self.__val.check_dataset_part_not_present(dataset_name, data_view, part_key)
|
214
|
+
|
215
|
+
# Prefer static schemas for data conformance
|
216
|
+
|
217
|
+
if static_schema is not None:
|
218
|
+
schema = _data.DataMapping.trac_to_arrow_schema(static_schema)
|
186
219
|
else:
|
187
|
-
schema =
|
220
|
+
schema = data_view.arrow_schema
|
221
|
+
|
222
|
+
# Data conformance is applied inside these conversion functions
|
188
223
|
|
189
|
-
|
190
|
-
|
224
|
+
updated_item = _data.DataMapping.pandas_to_item(dataset, schema)
|
225
|
+
updated_view = _data.DataMapping.add_item_to_view(data_view, part_key, updated_item)
|
191
226
|
|
192
|
-
self.
|
227
|
+
self.__local_ctx[dataset_name] = updated_view
|
193
228
|
|
194
229
|
def log(self) -> logging.Logger:
|
195
230
|
|
@@ -197,6 +232,33 @@ class TracContextImpl(_api.TracContext):
|
|
197
232
|
|
198
233
|
return self.__model_log
|
199
234
|
|
235
|
+
@staticmethod
|
236
|
+
def __get_static_schema(model_def: _meta.ModelDefinition, dataset_name: str):
|
237
|
+
|
238
|
+
input_schema = model_def.inputs.get(dataset_name)
|
239
|
+
|
240
|
+
if input_schema is not None and not input_schema.dynamic:
|
241
|
+
return input_schema.schema
|
242
|
+
|
243
|
+
output_schema = model_def.outputs.get(dataset_name)
|
244
|
+
|
245
|
+
if output_schema is not None and not output_schema.dynamic:
|
246
|
+
return output_schema.schema
|
247
|
+
|
248
|
+
return None
|
249
|
+
|
250
|
+
@staticmethod
|
251
|
+
def __assign_field_order(schema_def: _meta.SchemaDefinition):
|
252
|
+
|
253
|
+
if schema_def is None or schema_def.table is None or schema_def.table.fields is None:
|
254
|
+
return schema_def
|
255
|
+
|
256
|
+
if all(map(lambda f: f.fieldOrder is None, schema_def.table.fields)):
|
257
|
+
for index, field in enumerate(schema_def.table.fields):
|
258
|
+
field.fieldOrder = index
|
259
|
+
|
260
|
+
return schema_def
|
261
|
+
|
200
262
|
|
201
263
|
class TracContextValidator:
|
202
264
|
|
@@ -205,16 +267,16 @@ class TracContextValidator:
|
|
205
267
|
|
206
268
|
def __init__(
|
207
269
|
self, log: logging.Logger,
|
208
|
-
|
209
|
-
|
270
|
+
model_def: _meta.ModelDefinition,
|
271
|
+
local_ctx: tp.Dict[str, tp.Any],
|
210
272
|
checkout_directory: pathlib.Path):
|
211
273
|
|
212
274
|
self.__log = log
|
213
|
-
self.
|
214
|
-
self.
|
275
|
+
self.__model_def = model_def
|
276
|
+
self.__local_ctx = local_ctx
|
215
277
|
self.__checkout_directory = checkout_directory
|
216
278
|
|
217
|
-
def _report_error(self, message):
|
279
|
+
def _report_error(self, message, cause: Exception = None):
|
218
280
|
|
219
281
|
full_stack = traceback.extract_stack()
|
220
282
|
model_stack = _util.filter_model_stack_trace(full_stack, self.__checkout_directory)
|
@@ -225,80 +287,114 @@ class TracContextValidator:
|
|
225
287
|
self.__log.error(message)
|
226
288
|
self.__log.error(f"Model stack trace:\n{model_stack_str}")
|
227
289
|
|
228
|
-
|
290
|
+
if cause:
|
291
|
+
raise _ex.ERuntimeValidation(message) from cause
|
292
|
+
else:
|
293
|
+
raise _ex.ERuntimeValidation(message)
|
229
294
|
|
230
|
-
def
|
295
|
+
def check_param_valid_identifier(self, param_name: str):
|
231
296
|
|
232
297
|
if param_name is None:
|
233
298
|
self._report_error(f"Parameter name is null")
|
234
299
|
|
235
|
-
def check_param_valid_identifier(self, param_name: str):
|
236
|
-
|
237
300
|
if not self.__VALID_IDENTIFIER.match(param_name):
|
238
301
|
self._report_error(f"Parameter name {param_name} is not a valid identifier")
|
239
302
|
|
240
|
-
def
|
303
|
+
def check_param_defined_in_model(self, param_name: str):
|
241
304
|
|
242
|
-
if param_name not in self.
|
243
|
-
self._report_error(f"Parameter {param_name} is not defined in the
|
305
|
+
if param_name not in self.__model_def.parameters:
|
306
|
+
self._report_error(f"Parameter {param_name} is not defined in the model")
|
244
307
|
|
245
|
-
def
|
308
|
+
def check_param_available_in_context(self, param_name: str):
|
246
309
|
|
247
|
-
if
|
248
|
-
self._report_error(f"
|
310
|
+
if param_name not in self.__local_ctx:
|
311
|
+
self._report_error(f"Parameter {param_name} is not available in the current context")
|
249
312
|
|
250
313
|
def check_dataset_valid_identifier(self, dataset_name: str):
|
251
314
|
|
315
|
+
if dataset_name is None:
|
316
|
+
self._report_error(f"Dataset name is null")
|
317
|
+
|
252
318
|
if not self.__VALID_IDENTIFIER.match(dataset_name):
|
253
319
|
self._report_error(f"Dataset name {dataset_name} is not a valid identifier")
|
254
320
|
|
255
|
-
def
|
321
|
+
def check_dataset_defined_in_model(self, dataset_name: str):
|
322
|
+
|
323
|
+
if dataset_name not in self.__model_def.inputs and dataset_name not in self.__model_def.outputs:
|
324
|
+
self._report_error(f"Dataset {dataset_name} is not defined in the model")
|
325
|
+
|
326
|
+
def check_dataset_is_model_output(self, dataset_name: str):
|
256
327
|
|
257
|
-
if
|
258
|
-
self._report_error(f"
|
328
|
+
if dataset_name not in self.__model_def.outputs:
|
329
|
+
self._report_error(f"Dataset {dataset_name} is not defined as a model output")
|
259
330
|
|
260
|
-
def
|
331
|
+
def check_dataset_is_dynamic_output(self, dataset_name: str):
|
261
332
|
|
262
|
-
|
333
|
+
model_output: _meta.ModelOutputSchema = self.__model_def.outputs.get(dataset_name)
|
263
334
|
|
264
|
-
if
|
265
|
-
self._report_error(f"
|
335
|
+
if model_output is None:
|
336
|
+
self._report_error(f"Dataset {dataset_name} is not defined as a model output")
|
266
337
|
|
267
|
-
|
338
|
+
if not model_output.dynamic:
|
339
|
+
self._report_error(f"Model output {dataset_name} is not a dynamic output")
|
268
340
|
|
269
|
-
|
341
|
+
def check_dataset_available_in_context(self, item_name: str):
|
270
342
|
|
271
|
-
if
|
343
|
+
if item_name not in self.__local_ctx:
|
344
|
+
self._report_error(f"Dataset {item_name} is not available in the current context")
|
345
|
+
|
346
|
+
def check_dataset_schema_defined(self, dataset_name: str, data_view: _data.DataView):
|
347
|
+
|
348
|
+
schema = data_view.trac_schema if data_view is not None else None
|
349
|
+
|
350
|
+
if schema is None or schema.table is None or not schema.table.fields:
|
272
351
|
self._report_error(f"Schema not defined for dataset {dataset_name} in the current context")
|
273
352
|
|
274
|
-
def check_dataset_schema_not_defined(self, dataset_name: str):
|
353
|
+
def check_dataset_schema_not_defined(self, dataset_name: str, data_view: _data.DataView):
|
275
354
|
|
276
|
-
schema =
|
355
|
+
schema = data_view.trac_schema if data_view is not None else None
|
277
356
|
|
278
357
|
if schema is not None and (schema.table or schema.schemaType != _meta.SchemaType.SCHEMA_TYPE_NOT_SET):
|
279
358
|
self._report_error(f"Schema already defined for dataset {dataset_name} in the current context")
|
280
359
|
|
281
|
-
def check_dataset_part_present(self, dataset_name: str, part_key: _data.DataPartKey):
|
360
|
+
def check_dataset_part_present(self, dataset_name: str, data_view: _data.DataView, part_key: _data.DataPartKey):
|
282
361
|
|
283
|
-
part =
|
362
|
+
part = data_view.parts.get(part_key) if data_view.parts is not None else None
|
284
363
|
|
285
364
|
if part is None or len(part) == 0:
|
286
|
-
self._report_error(f"No data present for
|
365
|
+
self._report_error(f"No data present for {dataset_name} ({part_key}) in the current context")
|
287
366
|
|
288
|
-
def check_dataset_part_not_present(self, dataset_name: str, part_key: _data.DataPartKey):
|
367
|
+
def check_dataset_part_not_present(self, dataset_name: str, data_view: _data.DataView, part_key: _data.DataPartKey):
|
289
368
|
|
290
|
-
part =
|
369
|
+
part = data_view.parts.get(part_key) if data_view.parts is not None else None
|
291
370
|
|
292
371
|
if part is not None and len(part) > 0:
|
293
|
-
self._report_error(f"Data already present for
|
372
|
+
self._report_error(f"Data already present for {dataset_name} ({part_key}) in the current context")
|
294
373
|
|
295
|
-
def
|
374
|
+
def check_dataset_is_empty(self, dataset_name: str, data_view: _data.DataView):
|
296
375
|
|
297
|
-
if
|
298
|
-
self._report_error(f"
|
376
|
+
if not data_view.is_empty():
|
377
|
+
self._report_error(f"Dataset {dataset_name} is not empty")
|
378
|
+
|
379
|
+
def check_provided_schema_is_valid(self, dataset_name: str, schema: _meta.SchemaDefinition):
|
380
|
+
|
381
|
+
if schema is None:
|
382
|
+
self._report_error(f"The schema provided for [{dataset_name}] is null")
|
383
|
+
|
384
|
+
if not isinstance(schema, _meta.SchemaDefinition):
|
385
|
+
schema_type_name = self._type_name(type(schema))
|
386
|
+
self._report_error(f"The object provided for [{dataset_name}] is not a schema (got {schema_type_name})")
|
387
|
+
|
388
|
+
try:
|
389
|
+
_val.StaticValidator.quick_validate_schema(schema)
|
390
|
+
except _ex.EModelValidation as e:
|
391
|
+
self._report_error(f"The schema provided for [{dataset_name}] failed validation: {str(e)}", e)
|
299
392
|
|
300
393
|
def check_provided_dataset_type(self, dataset: tp.Any, expected_type: type):
|
301
394
|
|
395
|
+
if dataset is None:
|
396
|
+
self._report_error(f"Provided dataset is null")
|
397
|
+
|
302
398
|
if not isinstance(dataset, expected_type):
|
303
399
|
|
304
400
|
expected_type_name = self._type_name(expected_type)
|
@@ -308,6 +404,17 @@ class TracContextValidator:
|
|
308
404
|
f"Provided dataset is the wrong type" +
|
309
405
|
f" (expected {expected_type_name}, got {actual_type_name})")
|
310
406
|
|
407
|
+
def check_context_object_type(self, item_name: str, item: tp.Any, expected_type: type):
|
408
|
+
|
409
|
+
if not isinstance(item, expected_type):
|
410
|
+
|
411
|
+
expected_type_name = self._type_name(expected_type)
|
412
|
+
actual_type_name = self._type_name(type(item))
|
413
|
+
|
414
|
+
self._report_error(
|
415
|
+
f"The object referenced by [{item_name}] in the current context has the wrong type" +
|
416
|
+
f" (expected {expected_type_name}, got {actual_type_name})")
|
417
|
+
|
311
418
|
@staticmethod
|
312
419
|
def _type_name(type_: type):
|
313
420
|
|
tracdap/rt/_exec/dev_mode.py
CHANGED
@@ -313,6 +313,9 @@ class DevModeTranslator:
|
|
313
313
|
|
314
314
|
flow_def = config_mgr.load_config_object(flow_details, _meta.FlowDefinition)
|
315
315
|
|
316
|
+
# Validate models against the flow (this could move to _impl.validation and check prod jobs as well)
|
317
|
+
cls._check_models_for_flow(flow_def, job_config)
|
318
|
+
|
316
319
|
# Auto-wiring and inference only applied to externally loaded flows for now
|
317
320
|
flow_def = cls._autowire_flow(flow_def, job_config)
|
318
321
|
flow_def = cls._apply_type_inference(flow_def, job_config)
|
@@ -331,6 +334,37 @@ class DevModeTranslator:
|
|
331
334
|
|
332
335
|
return job_config
|
333
336
|
|
337
|
+
@classmethod
|
338
|
+
def _check_models_for_flow(cls, flow: _meta.FlowDefinition, job_config: _cfg.JobConfig):
|
339
|
+
|
340
|
+
model_nodes = dict(filter(lambda n: n[1].nodeType == _meta.FlowNodeType.MODEL_NODE, flow.nodes.items()))
|
341
|
+
|
342
|
+
missing_models = list(filter(lambda m: m not in job_config.job.runFlow.models, model_nodes.keys()))
|
343
|
+
extra_models = list(filter(lambda m: m not in model_nodes, job_config.job.runFlow.models.keys()))
|
344
|
+
|
345
|
+
if any(missing_models):
|
346
|
+
error = f"Missing models in job definition: {', '.join(missing_models)}"
|
347
|
+
cls._log.error(error)
|
348
|
+
raise _ex.EJobValidation(error)
|
349
|
+
|
350
|
+
if any (extra_models):
|
351
|
+
error = f"Extra models in job definition: {', '.join(extra_models)}"
|
352
|
+
cls._log.error(error)
|
353
|
+
raise _ex.EJobValidation(error)
|
354
|
+
|
355
|
+
for model_name, model_node in model_nodes.items():
|
356
|
+
|
357
|
+
model_selector = job_config.job.runFlow.models[model_name]
|
358
|
+
model_obj = _util.get_job_resource(model_selector, job_config)
|
359
|
+
|
360
|
+
model_inputs = set(model_obj.model.inputs.keys())
|
361
|
+
model_outputs = set(model_obj.model.outputs.keys())
|
362
|
+
|
363
|
+
if model_inputs != set(model_node.inputs) or model_outputs != set(model_node.outputs):
|
364
|
+
error = f"The model supplied for [{model_name}] does not match the flow definition"
|
365
|
+
cls._log.error(error)
|
366
|
+
raise _ex.EJobValidation(error)
|
367
|
+
|
334
368
|
@classmethod
|
335
369
|
def _autowire_flow(cls, flow: _meta.FlowDefinition, job_config: _cfg.JobConfig):
|
336
370
|
|
@@ -621,11 +655,13 @@ class DevModeTranslator:
|
|
621
655
|
job_details = job_config.job.runModel
|
622
656
|
model_obj = _util.get_job_resource(job_details.model, job_config)
|
623
657
|
required_inputs = model_obj.model.inputs
|
658
|
+
required_outputs = model_obj.model.outputs
|
624
659
|
|
625
660
|
elif job_config.job.jobType == _meta.JobType.RUN_FLOW:
|
626
661
|
job_details = job_config.job.runFlow
|
627
662
|
flow_obj = _util.get_job_resource(job_details.flow, job_config)
|
628
663
|
required_inputs = flow_obj.flow.inputs
|
664
|
+
required_outputs = flow_obj.flow.outputs
|
629
665
|
|
630
666
|
else:
|
631
667
|
return job_config
|
@@ -637,7 +673,8 @@ class DevModeTranslator:
|
|
637
673
|
for input_key, input_value in job_inputs.items():
|
638
674
|
if not (isinstance(input_value, str) and input_value in job_resources):
|
639
675
|
|
640
|
-
|
676
|
+
model_input = required_inputs[input_key]
|
677
|
+
input_schema = model_input.schema if model_input and not model_input.dynamic else None
|
641
678
|
|
642
679
|
input_id = cls._process_input_or_output(
|
643
680
|
sys_config, input_key, input_value, job_resources,
|
@@ -648,9 +685,12 @@ class DevModeTranslator:
|
|
648
685
|
for output_key, output_value in job_outputs.items():
|
649
686
|
if not (isinstance(output_value, str) and output_value in job_resources):
|
650
687
|
|
688
|
+
model_output= required_outputs[output_key]
|
689
|
+
output_schema = model_output.schema if model_output and not model_output.dynamic else None
|
690
|
+
|
651
691
|
output_id = cls._process_input_or_output(
|
652
692
|
sys_config, output_key, output_value, job_resources,
|
653
|
-
new_unique_file=True, schema=
|
693
|
+
new_unique_file=True, schema=output_schema)
|
654
694
|
|
655
695
|
job_outputs[output_key] = _util.selector_for(output_id)
|
656
696
|
|
@@ -768,7 +808,7 @@ class DevModeTranslator:
|
|
768
808
|
if schema is not None:
|
769
809
|
data_def.schema = schema
|
770
810
|
else:
|
771
|
-
data_def.schema =
|
811
|
+
data_def.schema = None
|
772
812
|
|
773
813
|
data_def.storageId = _meta.TagSelector(
|
774
814
|
_meta.ObjectType.STORAGE, storage_id.objectId,
|
tracdap/rt/_exec/functions.py
CHANGED
@@ -252,7 +252,13 @@ class DataViewFunc(NodeFunction[_data.DataView]):
|
|
252
252
|
if root_item.is_empty():
|
253
253
|
return _data.DataView.create_empty()
|
254
254
|
|
255
|
-
|
255
|
+
if self.node.schema is not None and len(self.node.schema.table.fields) > 0:
|
256
|
+
trac_schema = self.node.schema
|
257
|
+
else:
|
258
|
+
arrow_schema = root_item.schema
|
259
|
+
trac_schema = _data.DataMapping.arrow_to_trac_schema(arrow_schema)
|
260
|
+
|
261
|
+
data_view = _data.DataView.for_trac_schema(trac_schema)
|
256
262
|
data_view = _data.DataMapping.add_item_to_view(data_view, root_part_key, root_item)
|
257
263
|
|
258
264
|
return data_view
|
@@ -544,7 +550,6 @@ class RunModelFunc(NodeFunction[Bundle[_data.DataView]]):
|
|
544
550
|
# Still, if any nodes are missing or have the wrong type TracContextImpl will raise ERuntimeValidation
|
545
551
|
|
546
552
|
local_ctx = {}
|
547
|
-
static_schemas = {}
|
548
553
|
|
549
554
|
for node_id, node_result in _ctx_iter_items(ctx):
|
550
555
|
|
@@ -558,22 +563,10 @@ class RunModelFunc(NodeFunction[Bundle[_data.DataView]]):
|
|
558
563
|
if node_id.name in model_def.inputs:
|
559
564
|
input_name = node_id.name
|
560
565
|
local_ctx[input_name] = node_result
|
561
|
-
# At the moment, all model inputs have static schemas
|
562
|
-
static_schemas[input_name] = model_def.inputs[input_name].schema
|
563
|
-
|
564
|
-
# Add empty data views to the local context to hold model outputs
|
565
|
-
# Assuming outputs are all defined with static schemas
|
566
|
-
|
567
|
-
for output_name in model_def.outputs:
|
568
|
-
output_schema = self.node.model_def.outputs[output_name].schema
|
569
|
-
empty_data_view = _data.DataView.for_trac_schema(output_schema)
|
570
|
-
local_ctx[output_name] = empty_data_view
|
571
|
-
# At the moment, all model outputs have static schemas
|
572
|
-
static_schemas[output_name] = output_schema
|
573
566
|
|
574
567
|
# Run the model against the mapped local context
|
575
568
|
|
576
|
-
trac_ctx = _ctx.TracContextImpl(self.node.model_def, self.model_class, local_ctx,
|
569
|
+
trac_ctx = _ctx.TracContextImpl(self.node.model_def, self.model_class, local_ctx, self.checkout_directory)
|
577
570
|
|
578
571
|
try:
|
579
572
|
model = self.model_class()
|
@@ -594,12 +587,16 @@ class RunModelFunc(NodeFunction[Bundle[_data.DataView]]):
|
|
594
587
|
result: _data.DataView = local_ctx.get(output_name)
|
595
588
|
|
596
589
|
if result is None or result.is_empty():
|
590
|
+
|
597
591
|
if not output_schema.optional:
|
598
592
|
model_name = self.model_class.__name__
|
599
593
|
raise _ex.ERuntimeValidation(f"Missing required output [{output_name}] from model [{model_name}]")
|
600
594
|
|
601
|
-
|
602
|
-
|
595
|
+
# Create a placeholder for optional outputs that were not emitted
|
596
|
+
elif result is None:
|
597
|
+
result = _data.DataView.create_empty()
|
598
|
+
|
599
|
+
results[output_name] = result
|
603
600
|
|
604
601
|
return results
|
605
602
|
|