tracdap-runtime 0.6.3__py3-none-any.whl → 0.6.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. tracdap/rt/_exec/context.py +572 -112
  2. tracdap/rt/_exec/dev_mode.py +166 -97
  3. tracdap/rt/_exec/engine.py +120 -9
  4. tracdap/rt/_exec/functions.py +137 -35
  5. tracdap/rt/_exec/graph.py +38 -13
  6. tracdap/rt/_exec/graph_builder.py +120 -9
  7. tracdap/rt/_impl/data.py +183 -52
  8. tracdap/rt/_impl/grpc/tracdap/metadata/data_pb2.py +18 -18
  9. tracdap/rt/_impl/grpc/tracdap/metadata/job_pb2.py +74 -30
  10. tracdap/rt/_impl/grpc/tracdap/metadata/job_pb2.pyi +120 -2
  11. tracdap/rt/_impl/grpc/tracdap/metadata/model_pb2.py +20 -18
  12. tracdap/rt/_impl/grpc/tracdap/metadata/model_pb2.pyi +22 -6
  13. tracdap/rt/_impl/grpc/tracdap/metadata/resource_pb2.py +29 -0
  14. tracdap/rt/_impl/grpc/tracdap/metadata/resource_pb2.pyi +16 -0
  15. tracdap/rt/_impl/models.py +8 -0
  16. tracdap/rt/_impl/static_api.py +42 -10
  17. tracdap/rt/_impl/storage.py +37 -25
  18. tracdap/rt/_impl/validation.py +113 -11
  19. tracdap/rt/_plugins/repo_git.py +1 -1
  20. tracdap/rt/_version.py +1 -1
  21. tracdap/rt/api/experimental.py +220 -0
  22. tracdap/rt/api/hook.py +6 -4
  23. tracdap/rt/api/model_api.py +98 -13
  24. tracdap/rt/api/static_api.py +14 -6
  25. tracdap/rt/config/__init__.py +2 -2
  26. tracdap/rt/config/common.py +23 -17
  27. tracdap/rt/config/job.py +2 -2
  28. tracdap/rt/config/platform.py +25 -25
  29. tracdap/rt/config/result.py +2 -2
  30. tracdap/rt/config/runtime.py +3 -3
  31. tracdap/rt/launch/cli.py +7 -4
  32. tracdap/rt/launch/launch.py +19 -3
  33. tracdap/rt/metadata/__init__.py +25 -20
  34. tracdap/rt/metadata/common.py +2 -2
  35. tracdap/rt/metadata/custom.py +3 -3
  36. tracdap/rt/metadata/data.py +12 -12
  37. tracdap/rt/metadata/file.py +6 -6
  38. tracdap/rt/metadata/flow.py +6 -6
  39. tracdap/rt/metadata/job.py +62 -8
  40. tracdap/rt/metadata/model.py +33 -11
  41. tracdap/rt/metadata/object_id.py +8 -8
  42. tracdap/rt/metadata/resource.py +24 -0
  43. tracdap/rt/metadata/search.py +5 -5
  44. tracdap/rt/metadata/stoarge.py +6 -6
  45. tracdap/rt/metadata/tag.py +1 -1
  46. tracdap/rt/metadata/tag_update.py +1 -1
  47. tracdap/rt/metadata/type.py +4 -4
  48. {tracdap_runtime-0.6.3.dist-info → tracdap_runtime-0.6.5.dist-info}/METADATA +3 -1
  49. {tracdap_runtime-0.6.3.dist-info → tracdap_runtime-0.6.5.dist-info}/RECORD +52 -48
  50. {tracdap_runtime-0.6.3.dist-info → tracdap_runtime-0.6.5.dist-info}/LICENSE +0 -0
  51. {tracdap_runtime-0.6.3.dist-info → tracdap_runtime-0.6.5.dist-info}/WHEEL +0 -0
  52. {tracdap_runtime-0.6.3.dist-info → tracdap_runtime-0.6.5.dist-info}/top_level.txt +0 -0
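Note: the most visible change for model code in this diff is the experimental dataframe API in tracdap.rt.api.experimental. TracContext gains get_table() / put_table() methods that take a framework selector, alongside the existing pandas-specific methods. A minimal sketch of how a model might use it (illustrative only; the model class and dataset names are invented, the PANDAS / POLARS selectors are the constants referenced in the context.py diff below):

```python
import tracdap.rt.api as trac
import tracdap.rt.api.experimental as trac_x

class ReportModel(trac.TracModel):

    # define_parameters / define_inputs / define_outputs omitted for brevity

    def run_model(self, ctx: trac.TracContext):
        # The framework is chosen per call, replacing the pandas-only
        # get_pandas_table() / put_pandas_table() path from 0.6.3
        customers = ctx.get_table("customer_data", trac_x.PANDAS)
        summary = customers.groupby("region", as_index=False)["balance"].sum()
        ctx.put_table("customer_summary", summary)
```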
tracdap/rt/_exec/context.py

@@ -12,19 +12,20 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  
+ import copy
  import logging
  import pathlib
  import typing as tp
  import re
  import traceback
  
- import pandas as pd
- 
  import tracdap.rt.api as _api
+ import tracdap.rt.api.experimental as _eapi
  import tracdap.rt.metadata as _meta
  import tracdap.rt.exceptions as _ex
  import tracdap.rt._impl.type_system as _types # noqa
  import tracdap.rt._impl.data as _data # noqa
+ import tracdap.rt._impl.storage as _storage # noqa
  import tracdap.rt._impl.util as _util # noqa
  import tracdap.rt._impl.validation as _val # noqa
  
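Note the direction of the import changes: pandas is no longer imported unconditionally at module level, and dataframe libraries are reached through tracdap.rt._impl.data instead, so either pandas or polars may be absent at runtime. A sketch of the guard pattern this implies (illustrative; the actual guard lives in _impl.data and in _val.require_package, neither of which is shown in this hunk):

```python
# If the library is missing, the module attribute is left as None and
# require_package() can raise a clean runtime error instead of ImportError
try:
    import polars
except ModuleNotFoundError:
    polars = None
```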
@@ -32,8 +33,8 @@ import tracdap.rt._impl.validation as _val # noqa
  class TracContextImpl(_api.TracContext):
  
      """
-     TracContextImpl is the main implementation of the API class TracContext (from trac.rt.api).
-     It provides get/put operations the inputs, outputs and parameters of a model according to the model definition,
+     TracContextImpl is the main implementation of the API class TracContext (from tracdap.rt.api).
+     It provides get/put operations on the inputs, outputs and parameters of a model according to the model definition,
      as well as exposing other information needed by the model at runtime and offering a few utility functions.
  
      An instance of TracContextImpl is constructed by the runtime engine for each model node in the execution graph.
@@ -44,8 +45,8 @@ class TracContextImpl(_api.TracContext):
  
      Optimizations for lazy loading and eager saving require the context to call back into the runtime engine. For lazy
      load, the graph node to prepare an input is injected when the data is requested and the model thread blocks until
-     it is available; for eager save child nodes of individual outputs are triggered when those outputs are produced.
-     In both cases this complexity is hidden from the model, which only sees one thread with synchronous get/put calls.
+     it is available; for eager save outputs are sent to child actors as soon as they are produced. In both cases this
+     complexity is hidden from the model, which only sees one thread with synchronous get/put calls.
  
      :param model_def: Definition object for the model that will run in this context
      :param model_class: Type for the model that will run in this context
@@ -59,8 +60,8 @@ class TracContextImpl(_api.TracContext):
      def __init__(self,
                   model_def: _meta.ModelDefinition,
                   model_class: _api.TracModel.__class__,
-                  local_ctx: tp.Dict[str, _data.DataView],
-                  schemas: tp.Dict[str, _meta.SchemaDefinition],
+                  local_ctx: tp.Dict[str, tp.Any],
+                  dynamic_outputs: tp.List[str] = None,
                   checkout_directory: pathlib.Path = None):
  
          self.__ctx_log = _util.logger_for_object(self)
@@ -68,26 +69,27 @@
  
          self.__model_def = model_def
          self.__model_class = model_class
- 
-         self.__parameters = local_ctx or {}
-         self.__data = local_ctx or {}
-         self.__schemas = schemas
+         self.__local_ctx = local_ctx if local_ctx is not None else {}
+         self.__dynamic_outputs = dynamic_outputs if dynamic_outputs is not None else []
  
          self.__val = TracContextValidator(
              self.__ctx_log,
-             self.__parameters,
-             self.__data,
+             self.__model_def,
+             self.__local_ctx,
+             self.__dynamic_outputs,
              checkout_directory)
  
      def get_parameter(self, parameter_name: str) -> tp.Any:
  
          _val.validate_signature(self.get_parameter, parameter_name)
  
-         self.__val.check_param_not_null(parameter_name)
          self.__val.check_param_valid_identifier(parameter_name)
-         self.__val.check_param_exists(parameter_name)
+         self.__val.check_param_defined_in_model(parameter_name)
+         self.__val.check_param_available_in_context(parameter_name)
  
-         value: _meta.Value = self.__parameters[parameter_name] # noqa
+         value: _meta.Value = self.__local_ctx.get(parameter_name)
+ 
+         self.__val.check_context_object_type(parameter_name, value, _meta.Value)
  
          return _types.MetadataCodec.decode_value(value)
  
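The parameter checks are now split in two, so a model gets a different error for a parameter it never declared than for one the engine failed to supply. Illustrative calls (the parameter name is invented):

```python
# Declared in define_parameters() and supplied by the engine: returns a
# decoded Python value (str, int, date, ...), not a metadata Value object
rate = ctx.get_parameter("eur_usd_rate")

# Not declared in the model at all  -> "not defined in the model"
# Declared but missing from the job -> "not available in the current context"
```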
@@ -95,101 +97,204 @@ class TracContextImpl(_api.TracContext):
  
          _val.validate_signature(self.has_dataset, dataset_name)
  
-         part_key = _data.DataPartKey.for_root()
- 
-         self.__val.check_dataset_name_not_null(dataset_name)
          self.__val.check_dataset_valid_identifier(dataset_name)
+         self.__val.check_dataset_defined_in_model(dataset_name)
  
-         data_view = self.__data.get(dataset_name)
+         data_view: _data.DataView = self.__local_ctx.get(dataset_name)
  
          if data_view is None:
              return False
  
-         # If the item exists but is not a dataset, that is still a runtime error
-         # E.g. if this method is called for FILE inputs
-         self.__val.check_context_item_is_dataset(dataset_name)
- 
-         part = data_view.parts.get(part_key)
+         self.__val.check_context_object_type(dataset_name, data_view, _data.DataView)
  
-         if part is None or len(part) == 0:
-             return False
- 
-         return True
+         return not data_view.is_empty()
  
      def get_schema(self, dataset_name: str) -> _meta.SchemaDefinition:
  
          _val.validate_signature(self.get_schema, dataset_name)
  
-         self.__val.check_dataset_name_not_null(dataset_name)
          self.__val.check_dataset_valid_identifier(dataset_name)
+         self.__val.check_dataset_defined_in_model(dataset_name)
+         self.__val.check_dataset_available_in_context(dataset_name)
+ 
+         static_schema = self.__get_static_schema(self.__model_def, dataset_name)
+         data_view: _data.DataView = self.__local_ctx.get(dataset_name)
  
-         # There is no need to look in the data map if the model has defined a static schema
-         if dataset_name in self.__schemas:
-             return self.__schemas[dataset_name]
+         # Check the data view has a well-defined schema even if a static schema exists in the model
+         # This ensures errors are always reported and is consistent with get_pandas_table()
  
-         self.__val.check_context_item_exists(dataset_name)
-         self.__val.check_context_item_is_dataset(dataset_name)
-         self.__val.check_dataset_schema_defined(dataset_name)
+         self.__val.check_context_object_type(dataset_name, data_view, _data.DataView)
+         self.__val.check_dataset_schema_defined(dataset_name, data_view)
  
-         data_view = self.__data[dataset_name]
+         # If a static schema exists, that takes priority
+         # Return deep copies, do not allow model code to change schemas provided by the engine
+ 
+         if static_schema is not None:
+             return copy.deepcopy(static_schema)
+         else:
+             return copy.deepcopy(data_view.trac_schema)
  
-         return data_view.trac_schema
+     def get_table(self, dataset_name: str, framework, **kwargs) -> _eapi._DATA_FRAMEWORK: # noqa
+ 
+         # Support the experimental API data framework syntax
+ 
+         if framework == _eapi.PANDAS:
+             return self.get_pandas_table(dataset_name, **kwargs)
+         elif framework == _eapi.POLARS:
+             return self.get_polars_table(dataset_name)
+         else:
+             raise _ex.ERuntimeValidation(f"Unsupported data framework [{framework}]")
  
-     def get_pandas_table(self, dataset_name: str, use_temporal_objects: tp.Optional[bool] = None) -> pd.DataFrame:
+     def get_pandas_table(self, dataset_name: str, use_temporal_objects: tp.Optional[bool] = None) \
+             -> "_data.pandas.DataFrame":
  
+         _val.require_package("pandas", _data.pandas)
          _val.validate_signature(self.get_pandas_table, dataset_name, use_temporal_objects)
  
+         data_view, schema = self.__get_data_view(dataset_name)
          part_key = _data.DataPartKey.for_root()
  
-         self.__val.check_dataset_name_not_null(dataset_name)
+         if use_temporal_objects is None:
+             use_temporal_objects = self.__DEFAULT_TEMPORAL_OBJECTS
+ 
+         return _data.DataMapping.view_to_pandas(data_view, part_key, schema, use_temporal_objects)
+ 
+     def get_polars_table(self, dataset_name: str) -> "_data.polars.DataFrame":
+ 
+         _val.require_package("polars", _data.polars)
+         _val.validate_signature(self.get_polars_table, dataset_name)
+ 
+         data_view, schema = self.__get_data_view(dataset_name)
+         part_key = _data.DataPartKey.for_root()
+ 
+         return _data.DataMapping.view_to_polars(data_view, part_key, schema)
+ 
+     def __get_data_view(self, dataset_name: str):
+ 
+         _val.validate_signature(self.__get_data_view, dataset_name)
+ 
          self.__val.check_dataset_valid_identifier(dataset_name)
-         self.__val.check_context_item_exists(dataset_name)
-         self.__val.check_context_item_is_dataset(dataset_name)
-         self.__val.check_dataset_schema_defined(dataset_name)
-         self.__val.check_dataset_part_present(dataset_name, part_key)
+         self.__val.check_dataset_defined_in_model(dataset_name)
+         self.__val.check_dataset_available_in_context(dataset_name)
+ 
+         static_schema = self.__get_static_schema(self.__model_def, dataset_name)
+         data_view = self.__local_ctx.get(dataset_name)
+         part_key = _data.DataPartKey.for_root()
  
-         data_view = self.__data[dataset_name]
+         self.__val.check_context_object_type(dataset_name, data_view, _data.DataView)
+         self.__val.check_dataset_schema_defined(dataset_name, data_view)
+         self.__val.check_dataset_part_present(dataset_name, data_view, part_key)
  
          # If the model defines a static input schema, use that for schema conformance
          # Otherwise, take what is in the incoming dataset (schema is dynamic)
-         if dataset_name in self.__schemas:
-             schema = _data.DataMapping.trac_to_arrow_schema(self.__schemas[dataset_name])
+ 
+         if static_schema is not None:
+             schema = _data.DataMapping.trac_to_arrow_schema(static_schema)
          else:
              schema = data_view.arrow_schema
  
-         if use_temporal_objects is None:
-             use_temporal_objects = self.__DEFAULT_TEMPORAL_OBJECTS
+         return data_view, schema
  
-         return _data.DataMapping.view_to_pandas(data_view, part_key, schema, use_temporal_objects)
+     def put_schema(self, dataset_name: str, schema: _meta.SchemaDefinition):
+ 
+         _val.validate_signature(self.get_schema, dataset_name, schema)
+ 
+         # Copy the schema - schema cannot be changed in model code after put_schema
+         # If field ordering is not assigned by the model, assign it here (model code will not see the numbers)
+         schema_copy = self.__assign_field_order(copy.deepcopy(schema))
+ 
+         self.__val.check_dataset_valid_identifier(dataset_name)
+         self.__val.check_dataset_is_dynamic_output(dataset_name)
+         self.__val.check_provided_schema_is_valid(dataset_name, schema_copy)
+ 
+         static_schema = self.__get_static_schema(self.__model_def, dataset_name)
+         data_view = self.__local_ctx.get(dataset_name)
+ 
+         if data_view is None:
+             if static_schema is not None:
+                 data_view = _data.DataView.for_trac_schema(static_schema)
+             else:
+                 data_view = _data.DataView.create_empty()
+ 
+         # If there is a prior view it must contain nothing and will be replaced
+         self.__val.check_context_object_type(dataset_name, data_view, _data.DataView)
+         self.__val.check_dataset_schema_not_defined(dataset_name, data_view)
+         self.__val.check_dataset_is_empty(dataset_name, data_view)
+ 
+         updated_view = data_view.with_trac_schema(schema_copy)
+ 
+         self.__local_ctx[dataset_name] = updated_view
+ 
+     def put_table(self, dataset_name: str, dataset: _eapi._DATA_FRAMEWORK, **kwargs): # noqa
  
-     def put_pandas_table(self, dataset_name: str, dataset: pd.DataFrame):
+         # Support the experimental API data framework syntax
  
+         if _data.pandas and isinstance(dataset, _data.pandas.DataFrame):
+             self.put_pandas_table(dataset_name, dataset)
+         elif _data.polars and isinstance(dataset, _data.polars.DataFrame):
+             self.put_polars_table(dataset_name, dataset)
+         else:
+             raise _ex.ERuntimeValidation(f"Unsupported data framework[{type(dataset)}]")
+ 
+     def put_pandas_table(self, dataset_name: str, dataset: "_data.pandas.DataFrame"):
+ 
+         _val.require_package("pandas", _data.pandas)
          _val.validate_signature(self.put_pandas_table, dataset_name, dataset)
  
          part_key = _data.DataPartKey.for_root()
+         data_view, schema = self.__put_data_view(dataset_name, part_key, dataset, _data.pandas.DataFrame)
+ 
+         # Data conformance is applied inside these conversion functions
+ 
+         updated_item = _data.DataMapping.pandas_to_item(dataset, schema)
+         updated_view = _data.DataMapping.add_item_to_view(data_view, part_key, updated_item)
+ 
+         self.__local_ctx[dataset_name] = updated_view
+ 
+     def put_polars_table(self, dataset_name: str, dataset: "_data.polars.DataFrame"):
+ 
+         _val.require_package("polars", _data.polars)
+         _val.validate_signature(self.put_polars_table, dataset_name, dataset)
+ 
+         part_key = _data.DataPartKey.for_root()
+         data_view, schema = self.__put_data_view(dataset_name, part_key, dataset, _data.polars.DataFrame)
+ 
+         # Data conformance is applied inside these conversion functions
+ 
+         updated_item = _data.DataMapping.polars_to_item(dataset, schema)
+         updated_view = _data.DataMapping.add_item_to_view(data_view, part_key, updated_item)
+ 
+         self.__local_ctx[dataset_name] = updated_view
+ 
+     def __put_data_view(self, dataset_name: str, part_key: _data.DataPartKey, dataset: tp.Any, framework: type):
+ 
+         _val.validate_signature(self.__put_data_view, dataset_name, part_key, dataset, framework)
  
-         self.__val.check_dataset_name_not_null(dataset_name)
          self.__val.check_dataset_valid_identifier(dataset_name)
-         self.__val.check_context_item_exists(dataset_name)
-         self.__val.check_context_item_is_dataset(dataset_name)
-         self.__val.check_dataset_schema_defined(dataset_name)
-         self.__val.check_dataset_part_not_present(dataset_name, part_key)
-         self.__val.check_provided_dataset_not_null(dataset)
-         self.__val.check_provided_dataset_type(dataset, pd.DataFrame)
- 
-         prior_view = self.__data[dataset_name]
- 
-         # If the model defines a static output schema, use that for schema conformance
-         # Otherwise, use the schema in the data view for this output (this could be a dynamic schema)
-         if dataset_name in self.__schemas:
-             schema = _data.DataMapping.trac_to_arrow_schema(self.__schemas[dataset_name])
-         else:
-             schema = prior_view.arrow_schema
+         self.__val.check_dataset_is_model_output(dataset_name)
+         self.__val.check_provided_dataset_type(dataset, framework)
+ 
+         static_schema = self.__get_static_schema(self.__model_def, dataset_name)
+         data_view = self.__local_ctx.get(dataset_name)
+ 
+         if data_view is None:
+             if static_schema is not None:
+                 data_view = _data.DataView.for_trac_schema(static_schema)
+             else:
+                 data_view = _data.DataView.create_empty()
+ 
+         self.__val.check_context_object_type(dataset_name, data_view, _data.DataView)
+         self.__val.check_dataset_schema_defined(dataset_name, data_view)
+         self.__val.check_dataset_part_not_present(dataset_name, data_view, part_key)
  
-         data_item = _data.DataMapping.pandas_to_item(dataset, schema)
-         data_view = _data.DataMapping.add_item_to_view(prior_view, part_key, data_item)
+         # Prefer static schemas for data conformance
  
-         self.__data[dataset_name] = data_view
+         if static_schema is not None:
+             schema = _data.DataMapping.trac_to_arrow_schema(static_schema)
+         else:
+             schema = data_view.arrow_schema
+ 
+         return data_view, schema
  
      def log(self) -> logging.Logger:
  
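Taken together, put_schema() and put_table() give models a two-step protocol for dynamic outputs: fix the schema first (field order is assigned automatically if the model left it unset), then write the data, which is checked for conformance against that schema. A hypothetical sketch (the trac.define_schema / trac.F helpers are assumed from the public static API, and all names are invented):

```python
import tracdap.rt.api as trac

def run_model(self, ctx: trac.TracContext):

    schema = trac.define_schema(
        trac.F("region", trac.BasicType.STRING, label="Customer region"),
        trac.F("total_balance", trac.BasicType.DECIMAL, label="Total balance"))

    # Only allowed for outputs marked dynamic (or added via add_data_import)
    ctx.put_schema("regional_totals", schema)

    # The view now has a schema but no data; put_table() adds the root part
    ctx.put_table("regional_totals", totals_df)
```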
@@ -197,24 +302,243 @@ class TracContextImpl(_api.TracContext):
  
          return self.__model_log
  
+     @staticmethod
+     def __get_static_schema(model_def: _meta.ModelDefinition, dataset_name: str):
  
- class TracContextValidator:
+         input_schema = model_def.inputs.get(dataset_name)
  
-     __VALID_IDENTIFIER = re.compile("^[a-zA-Z_]\\w*$",)
-     __RESERVED_IDENTIFIER = re.compile("^(trac_|_)\\w*")
+         if input_schema is not None and not input_schema.dynamic:
+             return input_schema.schema
+ 
+         output_schema = model_def.outputs.get(dataset_name)
+ 
+         if output_schema is not None and not output_schema.dynamic:
+             return output_schema.schema
+ 
+         return None
+ 
+     @staticmethod
+     def __assign_field_order(schema_def: _meta.SchemaDefinition):
+ 
+         if schema_def is None or schema_def.table is None or schema_def.table.fields is None:
+             return schema_def
+ 
+         if all(map(lambda f: f.fieldOrder is None, schema_def.table.fields)):
+             for index, field in enumerate(schema_def.table.fields):
+                 field.fieldOrder = index
+ 
+         return schema_def
+ 
+ 
+ class TracDataContextImpl(TracContextImpl, _eapi.TracDataContext):
  
      def __init__(
-             self, log: logging.Logger,
-             parameters: tp.Dict[str, tp.Any],
-             data_ctx: tp.Dict[str, _data.DataView],
-             checkout_directory: pathlib.Path):
+             self, model_def: _meta.ModelDefinition, model_class: _api.TracModel.__class__,
+             local_ctx: tp.Dict[str, tp.Any], dynamic_outputs: tp.List[str],
+             storage_map: tp.Dict[str, tp.Union[_eapi.TracFileStorage]],
+             checkout_directory: pathlib.Path = None):
+ 
+         super().__init__(model_def, model_class, local_ctx, dynamic_outputs, checkout_directory)
+ 
+         self.__model_def = model_def
+         self.__local_ctx = local_ctx
+         self.__dynamic_outputs = dynamic_outputs
+         self.__storage_map = storage_map
+         self.__checkout_directory = checkout_directory
+ 
+         self.__val = self._TracContextImpl__val # noqa
+ 
+     def get_file_storage(self, storage_key: str) -> _eapi.TracFileStorage:
+ 
+         _val.validate_signature(self.get_file_storage, storage_key)
+ 
+         self.__val.check_storage_valid_identifier(storage_key)
+         self.__val.check_storage_available(self.__storage_map, storage_key)
+         self.__val.check_storage_type(self.__storage_map, storage_key, _eapi.TracFileStorage)
+ 
+         return self.__storage_map[storage_key]
+ 
+     def get_data_storage(self, storage_key: str) -> None:
+         raise _ex.ERuntimeValidation("Data storage API not available yet")
+ 
+     def add_data_import(self, dataset_name: str):
+ 
+         _val.validate_signature(self.add_data_import, dataset_name)
+ 
+         self.__val.check_dataset_valid_identifier(dataset_name)
+         self.__val.check_dataset_not_defined_in_model(dataset_name)
+         self.__val.check_dataset_not_available_in_context(dataset_name)
+ 
+         self.__local_ctx[dataset_name] = _data.DataView.create_empty()
+         self.__dynamic_outputs.append(dataset_name)
+ 
+     def set_source_metadata(self, dataset_name: str, storage_key: str, source_info: _eapi.FileStat):
+ 
+         _val.validate_signature(self.add_data_import, dataset_name, storage_key, source_info)
+ 
+         pass # Not implemented yet, only required when imports are sent back to the platform
+ 
+     def set_attribute(self, dataset_name: str, attribute_name: str, value: tp.Any):
+ 
+         _val.validate_signature(self.add_data_import, dataset_name, attribute_name, value)
+ 
+         pass # Not implemented yet, only required when imports are sent back to the platform
+ 
+     def set_schema(self, dataset_name: str, schema: _meta.SchemaDefinition):
+ 
+         _val.validate_signature(self.set_schema, dataset_name, schema)
+ 
+         # Forward to existing method (these should be swapped round)
+         self.put_schema(dataset_name, schema)
+ 
+ 
+ class TracFileStorageImpl(_eapi.TracFileStorage):
+ 
+     def __init__(self, storage_key: str, storage_impl: _storage.IFileStorage, write_access: bool, checkout_directory):
+ 
+         self.__storage_key = storage_key
+ 
+         self.__exists = lambda sp: storage_impl.exists(sp)
+         self.__size = lambda sp: storage_impl.size(sp)
+         self.__stat = lambda sp: storage_impl.stat(sp)
+         self.__ls = lambda sp, rec: storage_impl.ls(sp, rec)
+         self.__read_byte_stream = lambda sp: storage_impl.read_byte_stream(sp)
+ 
+         if write_access:
+             self.__mkdir = lambda sp, rec: storage_impl.mkdir(sp, rec)
+             self.__rm = lambda sp: storage_impl.rm(sp)
+             self.__rmdir = lambda sp: storage_impl.rmdir(sp)
+             self.__write_byte_stream = lambda sp: storage_impl.write_byte_stream(sp)
+         else:
+             self.__mkdir = None
+             self.__rm = None
+             self.__rmdir = None
+             self.__write_byte_stream = None
+ 
+         self.__log = _util.logger_for_object(self)
+         self.__val = TracStorageValidator(self.__log, checkout_directory, self.__storage_key)
+ 
+     def get_storage_key(self) -> str:
+ 
+         _val.validate_signature(self.get_storage_key)
+ 
+         return self.__storage_key
+ 
+     def exists(self, storage_path: str) -> bool:
+ 
+         _val.validate_signature(self.exists, storage_path)
+ 
+         self.__val.check_operation_available(self.exists, self.__exists)
+         self.__val.check_storage_path_is_valid(storage_path)
+ 
+         return self.__exists(storage_path)
+ 
+     def size(self, storage_path: str) -> int:
+ 
+         _val.validate_signature(self.size, storage_path)
+ 
+         self.__val.check_operation_available(self.size, self.__size)
+         self.__val.check_storage_path_is_valid(storage_path)
+ 
+         return self.__size(storage_path)
+ 
+     def stat(self, storage_path: str) -> _eapi.FileStat:
+ 
+         _val.validate_signature(self.stat, storage_path)
+ 
+         self.__val.check_operation_available(self.stat, self.__stat)
+         self.__val.check_storage_path_is_valid(storage_path)
+ 
+         stat = self.__stat(storage_path)
+         return _eapi.FileStat(**stat.__dict__)
+ 
+     def ls(self, storage_path: str, recursive: bool = False) -> tp.List[_eapi.FileStat]:
+ 
+         _val.validate_signature(self.ls, storage_path, recursive)
+ 
+         self.__val.check_operation_available(self.ls, self.__ls)
+         self.__val.check_storage_path_is_valid(storage_path)
+ 
+         listing = self.__ls(storage_path, recursive)
+         return list(_eapi.FileStat(**stat.__dict__) for stat in listing)
+ 
+     def mkdir(self, storage_path: str, recursive: bool = False):
+ 
+         _val.validate_signature(self.mkdir, storage_path, recursive)
+ 
+         self.__val.check_operation_available(self.mkdir, self.__mkdir)
+         self.__val.check_storage_path_is_valid(storage_path)
+         self.__val.check_storage_path_is_not_root(storage_path)
+ 
+         self.__mkdir(storage_path, recursive)
+ 
+     def rm(self, storage_path: str):
+ 
+         _val.validate_signature(self.rm, storage_path)
+ 
+         self.__val.check_operation_available(self.rm, self.__rm)
+         self.__val.check_storage_path_is_valid(storage_path)
+         self.__val.check_storage_path_is_not_root(storage_path)
+ 
+         self.__rm(storage_path)
+ 
+     def rmdir(self, storage_path: str):
+ 
+         _val.validate_signature(self.rmdir, storage_path)
+ 
+         self.__val.check_operation_available(self.rmdir, self.__rmdir)
+         self.__val.check_storage_path_is_valid(storage_path)
+         self.__val.check_storage_path_is_not_root(storage_path)
+ 
+         self.__rmdir(storage_path)
+ 
+     def read_byte_stream(self, storage_path: str) -> tp.ContextManager[tp.BinaryIO]:
+ 
+         _val.validate_signature(self.read_byte_stream, storage_path)
+ 
+         self.__val.check_operation_available(self.read_byte_stream, self.__read_byte_stream)
+         self.__val.check_storage_path_is_valid(storage_path)
+ 
+         return self.__read_byte_stream(storage_path)
+ 
+     def read_bytes(self, storage_path: str) -> bytes:
+ 
+         _val.validate_signature(self.read_bytes, storage_path)
+ 
+         self.__val.check_operation_available(self.read_bytes, self.__read_byte_stream)
+         self.__val.check_storage_path_is_valid(storage_path)
+ 
+         return super().read_bytes(storage_path)
+ 
+     def write_byte_stream(self, storage_path: str) -> tp.ContextManager[tp.BinaryIO]:
+ 
+         _val.validate_signature(self.write_byte_stream, storage_path)
+ 
+         self.__val.check_operation_available(self.write_byte_stream, self.__write_byte_stream)
+         self.__val.check_storage_path_is_valid(storage_path)
+         self.__val.check_storage_path_is_not_root(storage_path)
+ 
+         return self.__write_byte_stream(storage_path)
+ 
+     def write_bytes(self, storage_path: str, data: bytes):
+ 
+         _val.validate_signature(self.write_bytes, storage_path)
+ 
+         self.__val.check_operation_available(self.write_bytes, self.__write_byte_stream)
+         self.__val.check_storage_path_is_valid(storage_path)
+         self.__val.check_storage_path_is_not_root(storage_path)
+ 
+         super().write_bytes(storage_path, data)
+ 
+ 
+ class TracContextErrorReporter:
+ 
+     def __init__(self, log: logging.Logger, checkout_directory: pathlib.Path):
  
          self.__log = log
-         self.__parameters = parameters
-         self.__data_ctx = data_ctx
          self.__checkout_directory = checkout_directory
  
-     def _report_error(self, message):
+     def _report_error(self, message, cause: Exception = None):
  
          full_stack = traceback.extract_stack()
          model_stack = _util.filter_model_stack_trace(full_stack, self.__checkout_directory)
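TracDataContextImpl is the context handed to data import jobs: storage access is mediated by get_file_storage(), and datasets created by the import are registered with add_data_import() before they can be written. A hypothetical import-model sketch (the storage key and dataset name are invented, and the FileStat field name is an assumption about the experimental API referenced above):

```python
import tracdap.rt.api.experimental as trac_x

def run_model(self, ctx: trac_x.TracDataContext):

    storage = ctx.get_file_storage("raw_data")   # key must exist in the storage map

    for entry in storage.ls("inbound", recursive=True):
        with storage.read_byte_stream(entry.storage_path) as stream:
            ...  # parse the raw file

    ctx.add_data_import("landing_table")         # registers a new dynamic output
    ctx.set_schema("landing_table", schema)      # then put_table() as usual
```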
@@ -225,80 +549,147 @@ class TracContextValidator:
          self.__log.error(message)
          self.__log.error(f"Model stack trace:\n{model_stack_str}")
  
-         raise _ex.ERuntimeValidation(message)
+         if cause:
+             raise _ex.ERuntimeValidation(message) from cause
+         else:
+             raise _ex.ERuntimeValidation(message)
  
-     def check_param_not_null(self, param_name):
  
-         if param_name is None:
-             self._report_error(f"Parameter name is null")
+ class TracContextValidator(TracContextErrorReporter):
+ 
+     __VALID_IDENTIFIER = re.compile("^[a-zA-Z_]\\w*$",)
+     __RESERVED_IDENTIFIER = re.compile("^(trac_|_)\\w*")
+ 
+     def __init__(
+             self, log: logging.Logger,
+             model_def: _meta.ModelDefinition,
+             local_ctx: tp.Dict[str, tp.Any],
+             dynamic_outputs: tp.List[str],
+             checkout_directory: pathlib.Path):
+ 
+         super().__init__(log, checkout_directory)
+ 
+         self.__model_def = model_def
+         self.__local_ctx = local_ctx
+         self.__dynamic_outputs = dynamic_outputs
  
      def check_param_valid_identifier(self, param_name: str):
  
+         if param_name is None:
+             self._report_error(f"Parameter name is null")
+ 
          if not self.__VALID_IDENTIFIER.match(param_name):
              self._report_error(f"Parameter name {param_name} is not a valid identifier")
  
-     def check_param_exists(self, param_name: str):
+     def check_param_defined_in_model(self, param_name: str):
  
-         if param_name not in self.__parameters:
-             self._report_error(f"Parameter {param_name} is not defined in the current context")
+         if param_name not in self.__model_def.parameters:
+             self._report_error(f"Parameter {param_name} is not defined in the model")
  
-     def check_dataset_name_not_null(self, dataset_name):
+     def check_param_available_in_context(self, param_name: str):
  
-         if dataset_name is None:
-             self._report_error(f"Dataset name is null")
+         if param_name not in self.__local_ctx:
+             self._report_error(f"Parameter {param_name} is not available in the current context")
  
      def check_dataset_valid_identifier(self, dataset_name: str):
  
+         if dataset_name is None:
+             self._report_error(f"Dataset name is null")
+ 
          if not self.__VALID_IDENTIFIER.match(dataset_name):
              self._report_error(f"Dataset name {dataset_name} is not a valid identifier")
  
-     def check_context_item_exists(self, item_name: str):
+     def check_dataset_not_defined_in_model(self, dataset_name: str):
+ 
+         if dataset_name in self.__model_def.inputs or dataset_name in self.__model_def.outputs:
+             self._report_error(f"Dataset {dataset_name} is already defined in the model")
+ 
+         if dataset_name in self.__model_def.parameters:
+             self._report_error(f"Dataset name {dataset_name} is already in use as a model parameter")
+ 
+     def check_dataset_defined_in_model(self, dataset_name: str):
+ 
+         if dataset_name not in self.__model_def.inputs and dataset_name not in self.__model_def.outputs:
+             self._report_error(f"Dataset {dataset_name} is not defined in the model")
  
-         if item_name not in self.__data_ctx:
-             self._report_error(f"The identifier {item_name} is not defined in the current context")
+     def check_dataset_is_model_output(self, dataset_name: str):
  
-     def check_context_item_is_dataset(self, item_name: str):
+         if dataset_name not in self.__model_def.outputs and dataset_name not in self.__dynamic_outputs:
+             self._report_error(f"Dataset {dataset_name} is not defined as a model output")
  
-         ctx_item = self.__data_ctx[item_name]
+     def check_dataset_is_dynamic_output(self, dataset_name: str):
  
-         if not isinstance(ctx_item, _data.DataView):
-             self._report_error(f"The object referenced by {item_name} is not a dataset in the current context")
+         model_output: _meta.ModelOutputSchema = self.__model_def.outputs.get(dataset_name)
+         dynamic_output = dataset_name in self.__dynamic_outputs
  
-     def check_dataset_schema_defined(self, dataset_name: str):
+         if model_output is None and not dynamic_output:
+             self._report_error(f"Dataset {dataset_name} is not defined as a model output")
  
-         schema = self.__data_ctx[dataset_name].trac_schema
+         if model_output and not model_output.dynamic:
+             self._report_error(f"Model output {dataset_name} is not a dynamic output")
  
-         if schema is None or not schema.table or not schema.table.fields:
+     def check_dataset_available_in_context(self, item_name: str):
+ 
+         if item_name not in self.__local_ctx:
+             self._report_error(f"Dataset {item_name} is not available in the current context")
+ 
+     def check_dataset_not_available_in_context(self, item_name: str):
+ 
+         if item_name in self.__local_ctx:
+             self._report_error(f"Dataset {item_name} already exists in the current context")
+ 
+     def check_dataset_schema_defined(self, dataset_name: str, data_view: _data.DataView):
+ 
+         schema = data_view.trac_schema if data_view is not None else None
+ 
+         if schema is None or schema.table is None or not schema.table.fields:
              self._report_error(f"Schema not defined for dataset {dataset_name} in the current context")
  
-     def check_dataset_schema_not_defined(self, dataset_name: str):
+     def check_dataset_schema_not_defined(self, dataset_name: str, data_view: _data.DataView):
  
-         schema = self.__data_ctx[dataset_name].trac_schema
+         schema = data_view.trac_schema if data_view is not None else None
  
          if schema is not None and (schema.table or schema.schemaType != _meta.SchemaType.SCHEMA_TYPE_NOT_SET):
              self._report_error(f"Schema already defined for dataset {dataset_name} in the current context")
  
-     def check_dataset_part_present(self, dataset_name: str, part_key: _data.DataPartKey):
+     def check_dataset_part_present(self, dataset_name: str, data_view: _data.DataView, part_key: _data.DataPartKey):
  
-         part = self.__data_ctx[dataset_name].parts.get(part_key)
+         part = data_view.parts.get(part_key) if data_view.parts is not None else None
  
          if part is None or len(part) == 0:
-             self._report_error(f"No data present for dataset {dataset_name} ({part_key}) in the current context")
+             self._report_error(f"No data present for {dataset_name} ({part_key}) in the current context")
  
-     def check_dataset_part_not_present(self, dataset_name: str, part_key: _data.DataPartKey):
+     def check_dataset_part_not_present(self, dataset_name: str, data_view: _data.DataView, part_key: _data.DataPartKey):
  
-         part = self.__data_ctx[dataset_name].parts.get(part_key)
+         part = data_view.parts.get(part_key) if data_view.parts is not None else None
  
          if part is not None and len(part) > 0:
-             self._report_error(f"Data already present for dataset {dataset_name} ({part_key}) in the current context")
+             self._report_error(f"Data already present for {dataset_name} ({part_key}) in the current context")
  
-     def check_provided_dataset_not_null(self, dataset):
+     def check_dataset_is_empty(self, dataset_name: str, data_view: _data.DataView):
  
-         if dataset is None:
-             self._report_error(f"Provided dataset is null")
+         if not data_view.is_empty():
+             self._report_error(f"Dataset {dataset_name} is not empty")
+ 
+     def check_provided_schema_is_valid(self, dataset_name: str, schema: _meta.SchemaDefinition):
+ 
+         if schema is None:
+             self._report_error(f"The schema provided for [{dataset_name}] is null")
+ 
+         if not isinstance(schema, _meta.SchemaDefinition):
+             schema_type_name = self._type_name(type(schema))
+             self._report_error(f"The object provided for [{dataset_name}] is not a schema (got {schema_type_name})")
+ 
+         try:
+             _val.StaticValidator.quick_validate_schema(schema)
+         except _ex.EModelValidation as e:
+             self._report_error(f"The schema provided for [{dataset_name}] failed validation: {str(e)}", e)
  
      def check_provided_dataset_type(self, dataset: tp.Any, expected_type: type):
  
+         if dataset is None:
+             self._report_error(f"Provided dataset is null")
+ 
          if not isinstance(dataset, expected_type):
  
              expected_type_name = self._type_name(expected_type)
@@ -308,6 +699,44 @@ class TracContextValidator:
                  f"Provided dataset is the wrong type" +
                  f" (expected {expected_type_name}, got {actual_type_name})")
  
+     def check_context_object_type(self, item_name: str, item: tp.Any, expected_type: type):
+ 
+         if not isinstance(item, expected_type):
+ 
+             expected_type_name = self._type_name(expected_type)
+             actual_type_name = self._type_name(type(item))
+ 
+             self._report_error(
+                 f"The object referenced by [{item_name}] in the current context has the wrong type" +
+                 f" (expected {expected_type_name}, got {actual_type_name})")
+ 
+     def check_storage_valid_identifier(self, storage_key):
+ 
+         if storage_key is None:
+             self._report_error(f"Storage key is null")
+ 
+         if not self.__VALID_IDENTIFIER.match(storage_key):
+             self._report_error(f"Storage key {storage_key} is not a valid identifier")
+ 
+     def check_storage_available(self, storage_map: tp.Dict, storage_key: str):
+ 
+         storage_instance = storage_map.get(storage_key)
+ 
+         if storage_instance is None:
+             self._report_error(f"Storage not available for storage key [{storage_key}]")
+ 
+     def check_storage_type(
+             self, storage_map: tp.Dict, storage_key: str,
+             storage_type: tp.Union[_eapi.TracFileStorage.__class__]):
+ 
+         storage_instance = storage_map.get(storage_key)
+ 
+         if not isinstance(storage_instance, storage_type):
+             if storage_type == _eapi.TracFileStorage:
+                 self._report_error(f"Storage key [{storage_key}] refers to data storage, not file storage")
+             else:
+                 self._report_error(f"Storage key [{storage_key}] refers to file storage, not data storage")
+ 
      @staticmethod
      def _type_name(type_: type):
  
@@ -317,3 +746,34 @@ class TracContextValidator:
              return type_.__qualname__
  
          return module + '.' + type_.__name__
+ 
+ 
+ class TracStorageValidator(TracContextErrorReporter):
+ 
+     def __init__(self, log, checkout_directory, storage_key):
+         super().__init__(log, checkout_directory)
+         self.__storage_key = storage_key
+ 
+     def check_operation_available(self, public_func: tp.Callable, impl_func: tp.Callable):
+ 
+         if impl_func is None:
+             self._report_error(f"Operation [{public_func.__name__}] is not available for storage [{self.__storage_key}]")
+ 
+     def check_storage_path_is_valid(self, storage_path: str):
+ 
+         if _val.StorageValidator.storage_path_is_empty(storage_path):
+             self._report_error(f"Storage path is None or empty")
+ 
+         if _val.StorageValidator.storage_path_invalid(storage_path):
+             self._report_error(f"Storage path [{storage_path}] contains invalid characters")
+ 
+         if _val.StorageValidator.storage_path_not_relative(storage_path):
+             self._report_error(f"Storage path [{storage_path}] is not a relative path")
+ 
+         if _val.StorageValidator.storage_path_outside_root(storage_path):
+             self._report_error(f"Storage path [{storage_path}] is outside the storage root")
+ 
+     def check_storage_path_is_not_root(self, storage_path: str):
+ 
+         if _val.StorageValidator.storage_path_is_empty(storage_path):
+             self._report_error(f"Storage path [{storage_path}] is not allowed")