tracdap-runtime 0.6.2__py3-none-any.whl → 0.6.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71)
  1. tracdap/rt/_exec/actors.py +87 -10
  2. tracdap/rt/_exec/context.py +207 -100
  3. tracdap/rt/_exec/dev_mode.py +52 -20
  4. tracdap/rt/_exec/engine.py +79 -14
  5. tracdap/rt/_exec/functions.py +14 -17
  6. tracdap/rt/_exec/runtime.py +83 -40
  7. tracdap/rt/_exec/server.py +306 -29
  8. tracdap/rt/_impl/config_parser.py +219 -49
  9. tracdap/rt/_impl/data.py +70 -5
  10. tracdap/rt/_impl/grpc/codec.py +60 -5
  11. tracdap/rt/_impl/grpc/tracdap/api/internal/runtime_pb2.py +19 -19
  12. tracdap/rt/_impl/grpc/tracdap/api/internal/runtime_pb2.pyi +11 -9
  13. tracdap/rt/_impl/grpc/tracdap/api/internal/runtime_pb2_grpc.py +25 -25
  14. tracdap/rt/_impl/grpc/tracdap/metadata/data_pb2.py +18 -18
  15. tracdap/rt/_impl/grpc/tracdap/metadata/model_pb2.py +28 -16
  16. tracdap/rt/_impl/grpc/tracdap/metadata/model_pb2.pyi +37 -6
  17. tracdap/rt/_impl/grpc/tracdap/metadata/object_pb2.py +8 -3
  18. tracdap/rt/_impl/grpc/tracdap/metadata/object_pb2.pyi +13 -2
  19. tracdap/rt/_impl/guard_rails.py +21 -0
  20. tracdap/rt/_impl/models.py +25 -0
  21. tracdap/rt/_impl/static_api.py +43 -13
  22. tracdap/rt/_impl/type_system.py +17 -0
  23. tracdap/rt/_impl/validation.py +47 -4
  24. tracdap/rt/_plugins/config_local.py +49 -0
  25. tracdap/rt/_version.py +1 -1
  26. tracdap/rt/api/hook.py +6 -5
  27. tracdap/rt/api/model_api.py +50 -7
  28. tracdap/rt/api/static_api.py +81 -23
  29. tracdap/rt/config/__init__.py +4 -4
  30. tracdap/rt/config/common.py +25 -15
  31. tracdap/rt/config/job.py +2 -2
  32. tracdap/rt/config/platform.py +25 -35
  33. tracdap/rt/config/result.py +2 -2
  34. tracdap/rt/config/runtime.py +4 -2
  35. tracdap/rt/ext/config.py +34 -0
  36. tracdap/rt/ext/embed.py +1 -3
  37. tracdap/rt/ext/plugins.py +47 -6
  38. tracdap/rt/launch/cli.py +11 -4
  39. tracdap/rt/launch/launch.py +53 -12
  40. tracdap/rt/metadata/__init__.py +17 -17
  41. tracdap/rt/metadata/common.py +2 -2
  42. tracdap/rt/metadata/custom.py +3 -3
  43. tracdap/rt/metadata/data.py +12 -12
  44. tracdap/rt/metadata/file.py +6 -6
  45. tracdap/rt/metadata/flow.py +6 -6
  46. tracdap/rt/metadata/job.py +8 -8
  47. tracdap/rt/metadata/model.py +21 -11
  48. tracdap/rt/metadata/object.py +3 -0
  49. tracdap/rt/metadata/object_id.py +8 -8
  50. tracdap/rt/metadata/search.py +5 -5
  51. tracdap/rt/metadata/stoarge.py +6 -6
  52. tracdap/rt/metadata/tag.py +1 -1
  53. tracdap/rt/metadata/tag_update.py +1 -1
  54. tracdap/rt/metadata/type.py +4 -4
  55. {tracdap_runtime-0.6.2.dist-info → tracdap_runtime-0.6.4.dist-info}/METADATA +4 -4
  56. tracdap_runtime-0.6.4.dist-info/RECORD +112 -0
  57. {tracdap_runtime-0.6.2.dist-info → tracdap_runtime-0.6.4.dist-info}/WHEEL +1 -1
  58. tracdap/rt/_impl/grpc/tracdap/config/common_pb2.py +0 -55
  59. tracdap/rt/_impl/grpc/tracdap/config/common_pb2.pyi +0 -103
  60. tracdap/rt/_impl/grpc/tracdap/config/job_pb2.py +0 -42
  61. tracdap/rt/_impl/grpc/tracdap/config/job_pb2.pyi +0 -44
  62. tracdap/rt/_impl/grpc/tracdap/config/platform_pb2.py +0 -71
  63. tracdap/rt/_impl/grpc/tracdap/config/platform_pb2.pyi +0 -197
  64. tracdap/rt/_impl/grpc/tracdap/config/result_pb2.py +0 -37
  65. tracdap/rt/_impl/grpc/tracdap/config/result_pb2.pyi +0 -35
  66. tracdap/rt/_impl/grpc/tracdap/config/runtime_pb2.py +0 -42
  67. tracdap/rt/_impl/grpc/tracdap/config/runtime_pb2.pyi +0 -46
  68. tracdap/rt/ext/_guard.py +0 -37
  69. tracdap_runtime-0.6.2.dist-info/RECORD +0 -121
  70. {tracdap_runtime-0.6.2.dist-info → tracdap_runtime-0.6.4.dist-info}/LICENSE +0 -0
  71. {tracdap_runtime-0.6.2.dist-info → tracdap_runtime-0.6.4.dist-info}/top_level.txt +0 -0
@@ -25,6 +25,7 @@ import queue
25
25
  import time
26
26
 
27
27
  import tracdap.rt._impl.util as util # noqa
28
+ import tracdap.rt._impl.validation as _val # noqa
28
29
  import tracdap.rt.exceptions as _ex
29
30
 
30
31
 
@@ -180,6 +181,49 @@ class ActorContext:
180
181
  return self.__error or self.__node.error
181
182
 
182
183
 
184
+ class ThreadsafeActor(Actor):
185
+
186
+ def __init__(self):
187
+ super().__init__()
188
+ self.__threadsafe: tp.Optional[ThreadsafeContext] = None
189
+
190
+ def threadsafe(self) -> ThreadsafeContext:
191
+ return self.__threadsafe
192
+
193
+
194
+ class ThreadsafeContext:
195
+
196
+ def __init__(self, node: ActorNode):
197
+ self.__node = node
198
+ self.__id = node.actor_id
199
+ self.__parent = node.parent.actor_id if node.parent is not None else None
200
+
201
+ def spawn(self, actor: Actor):
202
+ self.__node.event_loop.post_message(
203
+ None, lambda _:
204
+ self.__node.spawn(actor) and None)
205
+
206
+ def send(self, target_id: ActorId, message: str, *args, **kwargs):
207
+ self.__node.event_loop.post_message(
208
+ None, lambda _:
209
+ self.__node.send_message(self.__id, target_id, message, args, kwargs))
210
+
211
+ def send_parent(self, message: str, *args, **kwargs):
212
+ self.__node.event_loop.post_message(
213
+ None, lambda _:
214
+ self.__node.send_message(self.__id, self.__parent, message, args, kwargs))
215
+
216
+ def stop(self):
217
+ self.__node.event_loop.post_message(
218
+ None, lambda _:
219
+ self.__node.send_signal(self.__id, self.__id, SignalNames.STOP))
220
+
221
+ def fail(self, error: Exception):
222
+ self.__node.event_loop.post_message(
223
+ None, lambda _:
224
+ self.__node.send_signal(self.__id, self.__id, SignalNames.STOP, error))
225
+
226
+
183
227
  class EventLoop:
184
228
 
185
229
  _T_MSG = tp.TypeVar("_T_MSG")
@@ -340,7 +384,7 @@ class ActorNode:
340
384
  self.state: ActorState = ActorState.NOT_STARTED
341
385
  self.error: tp.Optional[Exception] = None
342
386
 
343
- def spawn(self, child_actor: Actor):
387
+ def spawn(self, child_actor: Actor) -> ActorId:
344
388
 
345
389
  if self._log.isEnabledFor(logging.DEBUG):
346
390
  self._log.debug(f"spawn [{self.actor_id}]: [{type(child_actor)}]")
@@ -355,6 +399,11 @@ class ActorNode:
355
399
  child_node = ActorNode(child_id, child_actor, self, self.system, event_loop)
356
400
  self.children[child_id] = child_node
357
401
 
402
+ # If this is a threadsafe actor, set up the threadsafe context
403
+ if isinstance(child_actor, ThreadsafeActor):
404
+ threadsafe = ThreadsafeContext(child_node)
405
+ child_actor._ThreadsafeActor__threadsafe = threadsafe
406
+
358
407
  child_node.send_signal(self.actor_id, child_id, SignalNames.START)
359
408
 
360
409
  return child_id
@@ -542,6 +591,12 @@ class ActorNode:
542
591
  if not self._check_message_target(signal):
543
592
  return
544
593
 
594
+ # Do not process signals after the actor has stopped
595
+ # This is common with e.g. STOP signals that propagate up and down the tree
596
+
597
+ if self.state in [ActorState.STOPPED, ActorState.FAILED]:
598
+ return
599
+
545
600
  # Call the signal receiver function
546
601
  # This gives the actor a chance to respond to the signal
547
602
 
@@ -768,10 +823,12 @@ class ActorNode:
768
823
  # Positional arg types
769
824
  for pos_param, pos_arg in zip(pos_params, args):
770
825
 
826
+ # If no type hint is available, allow anything through
827
+ # Otherwise, reuse the validator logic to type check individual args
771
828
  type_hint = type_hints.get(pos_param.name)
829
+ type_check = type_hint is None or _val.check_type(type_hint, pos_arg)
772
830
 
773
- # If no type hint is available, allow anything through
774
- if type_hint is not None and not isinstance(pos_arg, type_hint):
831
+ if not type_check:
775
832
  error = f"Invalid message: [{message}] -> {target_id} (wrong parameter type for '{pos_param.name}')"
776
833
  self._log.error(error)
777
834
  raise EBadActor(error)
@@ -780,20 +837,20 @@ class ActorNode:
780
837
  for kw_param in kw_params:
781
838
 
782
839
  kw_arg = kwargs.get(kw_param.name)
783
- type_hint = type_hints.get(kw_param.name)
784
840
 
785
841
  # If param has taken a default value, no type check is needed
786
842
  if kw_arg is None:
787
843
  continue
788
844
 
789
- # If no type hint is available, allow anything through
790
- if type_hint is not None and not isinstance(kw_arg, type_hint):
845
+ # Otherwise use the same type-validation logic as positional args
846
+ type_hint = type_hints.get(kw_param.name)
847
+ type_check = type_hint is None or _val.check_type(type_hint, kw_arg)
848
+
849
+ if not type_check:
791
850
  error = f"Invalid message: [{message}] -> {target_id} (wrong parameter type for '{kw_param.name}')"
792
851
  self._log.error(error)
793
852
  raise EBadActor(error)
794
853
 
795
- # TODO: Verify generics for both args and kwargs
796
-
797
854
 
798
855
  class RootActor(Actor):
799
856
 
@@ -864,11 +921,17 @@ class ActorSystem:
864
921
 
865
922
  self.__root_started = threading.Event()
866
923
  self.__root_stopped = threading.Event()
924
+
867
925
  self.__root_actor = RootActor(main_actor, self.__root_started, self.__root_stopped)
868
926
  self.__root_node = ActorNode(self.ROOT_ID, self.__root_actor, None, self, self.__system_event_loop)
869
927
 
870
928
  # Public API
871
929
 
930
+ def main_id(self) -> ActorId:
931
+ if not self.__root_started.is_set():
932
+ raise EBadActor("System has not started yet")
933
+ return self.__root_actor.main_id
934
+
872
935
  def start(self, wait=True):
873
936
 
874
937
  self.__system_thread.start()
@@ -913,12 +976,26 @@ class ActorSystem:
913
976
 
914
977
  return self.__root_node.error
915
978
 
916
- def send(self, message: str, *args, **kwargs):
979
+ def spawn_agent(self, agent: Actor) -> ActorId:
980
+
981
+ if not self.__root_started.is_set():
982
+ raise EBadActor("System has not started yet")
983
+
984
+ return self.__root_node.spawn(agent)
985
+
986
+ def send_main(self, message: str, *args, **kwargs):
917
987
 
918
988
  if self.__root_actor.main_id is None:
919
989
  raise EBadActor("System has not started yet")
920
990
 
921
- self.__root_node.send_message("/external", self.__root_actor.main_id, message, args, kwargs)
991
+ self.__root_node.send_message("/external", self.__root_actor.main_id, message, args, kwargs) # TODO
992
+
993
+ def send(self, actor_id: ActorId, message: str, *args, **kwargs):
994
+
995
+ if not self.__root_started.is_set():
996
+ raise EBadActor("System has not started yet")
997
+
998
+ self.__root_node.send_message("/external", actor_id, message, args, kwargs)
922
999
 
923
1000
  def _setup_event_loops(self, thread_pools: tp.Dict[str, int]):
924
1001
 
@@ -12,6 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
+ import copy
15
16
  import logging
16
17
  import pathlib
17
18
  import typing as tp
@@ -32,8 +33,8 @@ import tracdap.rt._impl.validation as _val # noqa
32
33
  class TracContextImpl(_api.TracContext):
33
34
 
34
35
  """
35
- TracContextImpl is the main implementation of the API class TracContext (from trac.rt.api).
36
- It provides get/put operations the inputs, outputs and parameters of a model according to the model definition,
36
+ TracContextImpl is the main implementation of the API class TracContext (from tracdap.rt.api).
37
+ It provides get/put operations on the inputs, outputs and parameters of a model according to the model definition,
37
38
  as well as exposing other information needed by the model at runtime and offering a few utility functions.
38
39
 
39
40
  An instance of TracContextImpl is constructed by the runtime engine for each model node in the execution graph.
@@ -44,8 +45,8 @@ class TracContextImpl(_api.TracContext):
44
45
 
45
46
  Optimizations for lazy loading and eager saving require the context to call back into the runtime engine. For lazy
46
47
  load, the graph node to prepare an input is injected when the data is requested and the model thread blocks until
47
- it is available; for eager save child nodes of individual outputs are triggered when those outputs are produced.
48
- In both cases this complexity is hidden from the model, which only sees one thread with synchronous get/put calls.
48
+ it is available; for eager save outputs are sent to child actors as soon as they are produced. In both cases this
49
+ complexity is hidden from the model, which only sees one thread with synchronous get/put calls.
49
50
 
50
51
  :param model_def: Definition object for the model that will run in this context
51
52
  :param model_class: Type for the model that will run in this context
@@ -59,8 +60,7 @@ class TracContextImpl(_api.TracContext):
59
60
  def __init__(self,
60
61
  model_def: _meta.ModelDefinition,
61
62
  model_class: _api.TracModel.__class__,
62
- local_ctx: tp.Dict[str, _data.DataView],
63
- schemas: tp.Dict[str, _meta.SchemaDefinition],
63
+ local_ctx: tp.Dict[str, tp.Any],
64
64
  checkout_directory: pathlib.Path = None):
65
65
 
66
66
  self.__ctx_log = _util.logger_for_object(self)
@@ -68,26 +68,25 @@ class TracContextImpl(_api.TracContext):
68
68
 
69
69
  self.__model_def = model_def
70
70
  self.__model_class = model_class
71
-
72
- self.__parameters = local_ctx or {}
73
- self.__data = local_ctx or {}
74
- self.__schemas = schemas
71
+ self.__local_ctx = local_ctx or {}
75
72
 
76
73
  self.__val = TracContextValidator(
77
74
  self.__ctx_log,
78
- self.__parameters,
79
- self.__data,
75
+ self.__model_def,
76
+ self.__local_ctx,
80
77
  checkout_directory)
81
78
 
82
79
  def get_parameter(self, parameter_name: str) -> tp.Any:
83
80
 
84
81
  _val.validate_signature(self.get_parameter, parameter_name)
85
82
 
86
- self.__val.check_param_not_null(parameter_name)
87
83
  self.__val.check_param_valid_identifier(parameter_name)
88
- self.__val.check_param_exists(parameter_name)
84
+ self.__val.check_param_defined_in_model(parameter_name)
85
+ self.__val.check_param_available_in_context(parameter_name)
86
+
87
+ value: _meta.Value = self.__local_ctx.get(parameter_name)
89
88
 
90
- value: _meta.Value = self.__parameters[parameter_name] # noqa
89
+ self.__val.check_context_object_type(parameter_name, value, _meta.Value)
91
90
 
92
91
  return _types.MetadataCodec.decode_value(value)
93
92
 
@@ -95,65 +94,64 @@ class TracContextImpl(_api.TracContext):
95
94
 
96
95
  _val.validate_signature(self.has_dataset, dataset_name)
97
96
 
98
- part_key = _data.DataPartKey.for_root()
99
-
100
- self.__val.check_dataset_name_not_null(dataset_name)
101
97
  self.__val.check_dataset_valid_identifier(dataset_name)
98
+ self.__val.check_dataset_defined_in_model(dataset_name)
102
99
 
103
- data_view = self.__data.get(dataset_name)
100
+ data_view: _data.DataView = self.__local_ctx.get(dataset_name)
104
101
 
105
102
  if data_view is None:
106
103
  return False
107
104
 
108
- # If the item exists but is not a dataset, that is still a runtime error
109
- # E.g. if this method is called for FILE inputs
110
- self.__val.check_context_item_is_dataset(dataset_name)
111
-
112
- part = data_view.parts.get(part_key)
105
+ self.__val.check_context_object_type(dataset_name, data_view, _data.DataView)
113
106
 
114
- if part is None or len(part) == 0:
115
- return False
116
-
117
- return True
107
+ return not data_view.is_empty()
118
108
 
119
109
  def get_schema(self, dataset_name: str) -> _meta.SchemaDefinition:
120
110
 
121
111
  _val.validate_signature(self.get_schema, dataset_name)
122
112
 
123
- self.__val.check_dataset_name_not_null(dataset_name)
124
113
  self.__val.check_dataset_valid_identifier(dataset_name)
114
+ self.__val.check_dataset_defined_in_model(dataset_name)
115
+ self.__val.check_dataset_available_in_context(dataset_name)
125
116
 
126
- # There is no need to look in the data map if the model has defined a static schema
127
- if dataset_name in self.__schemas:
128
- return self.__schemas[dataset_name]
117
+ static_schema = self.__get_static_schema(self.__model_def, dataset_name)
118
+ data_view: _data.DataView = self.__local_ctx.get(dataset_name)
129
119
 
130
- self.__val.check_context_item_exists(dataset_name)
131
- self.__val.check_context_item_is_dataset(dataset_name)
132
- self.__val.check_dataset_schema_defined(dataset_name)
120
+ # Check the data view has a well-defined schema even if a static schema exists in the model
121
+ # This ensures errors are always reported and is consistent with get_pandas_table()
133
122
 
134
- data_view = self.__data[dataset_name]
123
+ self.__val.check_context_object_type(dataset_name, data_view, _data.DataView)
124
+ self.__val.check_dataset_schema_defined(dataset_name, data_view)
135
125
 
136
- return data_view.trac_schema
126
+ # If a static schema exists, that takes priority
127
+ # Return deep copies, do not allow model code to change schemas provided by the engine
128
+
129
+ if static_schema is not None:
130
+ return copy.deepcopy(static_schema)
131
+ else:
132
+ return copy.deepcopy(data_view.trac_schema)
137
133
 
138
134
  def get_pandas_table(self, dataset_name: str, use_temporal_objects: tp.Optional[bool] = None) -> pd.DataFrame:
139
135
 
140
136
  _val.validate_signature(self.get_pandas_table, dataset_name, use_temporal_objects)
141
137
 
142
- part_key = _data.DataPartKey.for_root()
143
-
144
- self.__val.check_dataset_name_not_null(dataset_name)
145
138
  self.__val.check_dataset_valid_identifier(dataset_name)
146
- self.__val.check_context_item_exists(dataset_name)
147
- self.__val.check_context_item_is_dataset(dataset_name)
148
- self.__val.check_dataset_schema_defined(dataset_name)
149
- self.__val.check_dataset_part_present(dataset_name, part_key)
139
+ self.__val.check_dataset_defined_in_model(dataset_name)
140
+ self.__val.check_dataset_available_in_context(dataset_name)
150
141
 
151
- data_view = self.__data[dataset_name]
142
+ static_schema = self.__get_static_schema(self.__model_def, dataset_name)
143
+ data_view = self.__local_ctx.get(dataset_name)
144
+ part_key = _data.DataPartKey.for_root()
145
+
146
+ self.__val.check_context_object_type(dataset_name, data_view, _data.DataView)
147
+ self.__val.check_dataset_schema_defined(dataset_name, data_view)
148
+ self.__val.check_dataset_part_present(dataset_name, data_view, part_key)
152
149
 
153
150
  # If the model defines a static input schema, use that for schema conformance
154
151
  # Otherwise, take what is in the incoming dataset (schema is dynamic)
155
- if dataset_name in self.__schemas:
156
- schema = _data.DataMapping.trac_to_arrow_schema(self.__schemas[dataset_name])
152
+
153
+ if static_schema is not None:
154
+ schema = _data.DataMapping.trac_to_arrow_schema(static_schema)
157
155
  else:
158
156
  schema = data_view.arrow_schema
159
157
 
@@ -162,34 +160,71 @@ class TracContextImpl(_api.TracContext):
162
160
 
163
161
  return _data.DataMapping.view_to_pandas(data_view, part_key, schema, use_temporal_objects)
164
162
 
163
+ def put_schema(self, dataset_name: str, schema: _meta.SchemaDefinition):
164
+
165
+ _val.validate_signature(self.get_schema, dataset_name, schema)
166
+
167
+ # Copy the schema - schema cannot be changed in model code after put_schema
168
+ # If field ordering is not assigned by the model, assign it here (model code will not see the numbers)
169
+ schema_copy = self.__assign_field_order(copy.deepcopy(schema))
170
+
171
+ self.__val.check_dataset_valid_identifier(dataset_name)
172
+ self.__val.check_dataset_is_dynamic_output(dataset_name)
173
+ self.__val.check_provided_schema_is_valid(dataset_name, schema_copy)
174
+
175
+ static_schema = self.__get_static_schema(self.__model_def, dataset_name)
176
+ data_view = self.__local_ctx.get(dataset_name)
177
+
178
+ if data_view is None:
179
+ if static_schema is not None:
180
+ data_view = _data.DataView.for_trac_schema(static_schema)
181
+ else:
182
+ data_view = _data.DataView.create_empty()
183
+
184
+ # If there is a prior view it must contain nothing and will be replaced
185
+ self.__val.check_context_object_type(dataset_name, data_view, _data.DataView)
186
+ self.__val.check_dataset_schema_not_defined(dataset_name, data_view)
187
+ self.__val.check_dataset_is_empty(dataset_name, data_view)
188
+
189
+ updated_view = data_view.with_trac_schema(schema_copy)
190
+
191
+ self.__local_ctx[dataset_name] = updated_view
192
+
165
193
  def put_pandas_table(self, dataset_name: str, dataset: pd.DataFrame):
166
194
 
167
195
  _val.validate_signature(self.put_pandas_table, dataset_name, dataset)
168
196
 
169
- part_key = _data.DataPartKey.for_root()
170
-
171
- self.__val.check_dataset_name_not_null(dataset_name)
172
197
  self.__val.check_dataset_valid_identifier(dataset_name)
173
- self.__val.check_context_item_exists(dataset_name)
174
- self.__val.check_context_item_is_dataset(dataset_name)
175
- self.__val.check_dataset_schema_defined(dataset_name)
176
- self.__val.check_dataset_part_not_present(dataset_name, part_key)
177
- self.__val.check_provided_dataset_not_null(dataset)
198
+ self.__val.check_dataset_is_model_output(dataset_name)
178
199
  self.__val.check_provided_dataset_type(dataset, pd.DataFrame)
179
200
 
180
- prior_view = self.__data[dataset_name]
201
+ static_schema = self.__get_static_schema(self.__model_def, dataset_name)
202
+ data_view = self.__local_ctx.get(dataset_name)
203
+ part_key = _data.DataPartKey.for_root()
181
204
 
182
- # If the model defines a static output schema, use that for schema conformance
183
- # Otherwise, use the schema in the data view for this output (this could be a dynamic schema)
184
- if dataset_name in self.__schemas:
185
- schema = _data.DataMapping.trac_to_arrow_schema(self.__schemas[dataset_name])
205
+ if data_view is None:
206
+ if static_schema is not None:
207
+ data_view = _data.DataView.for_trac_schema(static_schema)
208
+ else:
209
+ data_view = _data.DataView.create_empty()
210
+
211
+ self.__val.check_context_object_type(dataset_name, data_view, _data.DataView)
212
+ self.__val.check_dataset_schema_defined(dataset_name, data_view)
213
+ self.__val.check_dataset_part_not_present(dataset_name, data_view, part_key)
214
+
215
+ # Prefer static schemas for data conformance
216
+
217
+ if static_schema is not None:
218
+ schema = _data.DataMapping.trac_to_arrow_schema(static_schema)
186
219
  else:
187
- schema = prior_view.arrow_schema
220
+ schema = data_view.arrow_schema
221
+
222
+ # Data conformance is applied inside these conversion functions
188
223
 
189
- data_item = _data.DataMapping.pandas_to_item(dataset, schema)
190
- data_view = _data.DataMapping.add_item_to_view(prior_view, part_key, data_item)
224
+ updated_item = _data.DataMapping.pandas_to_item(dataset, schema)
225
+ updated_view = _data.DataMapping.add_item_to_view(data_view, part_key, updated_item)
191
226
 
192
- self.__data[dataset_name] = data_view
227
+ self.__local_ctx[dataset_name] = updated_view
193
228
 
194
229
  def log(self) -> logging.Logger:
195
230
 
@@ -197,6 +232,33 @@ class TracContextImpl(_api.TracContext):
197
232
 
198
233
  return self.__model_log
199
234
 
235
+ @staticmethod
236
+ def __get_static_schema(model_def: _meta.ModelDefinition, dataset_name: str):
237
+
238
+ input_schema = model_def.inputs.get(dataset_name)
239
+
240
+ if input_schema is not None and not input_schema.dynamic:
241
+ return input_schema.schema
242
+
243
+ output_schema = model_def.outputs.get(dataset_name)
244
+
245
+ if output_schema is not None and not output_schema.dynamic:
246
+ return output_schema.schema
247
+
248
+ return None
249
+
250
+ @staticmethod
251
+ def __assign_field_order(schema_def: _meta.SchemaDefinition):
252
+
253
+ if schema_def is None or schema_def.table is None or schema_def.table.fields is None:
254
+ return schema_def
255
+
256
+ if all(map(lambda f: f.fieldOrder is None, schema_def.table.fields)):
257
+ for index, field in enumerate(schema_def.table.fields):
258
+ field.fieldOrder = index
259
+
260
+ return schema_def
261
+
200
262
 
201
263
  class TracContextValidator:
202
264
 
@@ -205,16 +267,16 @@ class TracContextValidator:
205
267
 
206
268
  def __init__(
207
269
  self, log: logging.Logger,
208
- parameters: tp.Dict[str, tp.Any],
209
- data_ctx: tp.Dict[str, _data.DataView],
270
+ model_def: _meta.ModelDefinition,
271
+ local_ctx: tp.Dict[str, tp.Any],
210
272
  checkout_directory: pathlib.Path):
211
273
 
212
274
  self.__log = log
213
- self.__parameters = parameters
214
- self.__data_ctx = data_ctx
275
+ self.__model_def = model_def
276
+ self.__local_ctx = local_ctx
215
277
  self.__checkout_directory = checkout_directory
216
278
 
217
- def _report_error(self, message):
279
+ def _report_error(self, message, cause: Exception = None):
218
280
 
219
281
  full_stack = traceback.extract_stack()
220
282
  model_stack = _util.filter_model_stack_trace(full_stack, self.__checkout_directory)
@@ -225,80 +287,114 @@ class TracContextValidator:
225
287
  self.__log.error(message)
226
288
  self.__log.error(f"Model stack trace:\n{model_stack_str}")
227
289
 
228
- raise _ex.ERuntimeValidation(message)
290
+ if cause:
291
+ raise _ex.ERuntimeValidation(message) from cause
292
+ else:
293
+ raise _ex.ERuntimeValidation(message)
229
294
 
230
- def check_param_not_null(self, param_name):
295
+ def check_param_valid_identifier(self, param_name: str):
231
296
 
232
297
  if param_name is None:
233
298
  self._report_error(f"Parameter name is null")
234
299
 
235
- def check_param_valid_identifier(self, param_name: str):
236
-
237
300
  if not self.__VALID_IDENTIFIER.match(param_name):
238
301
  self._report_error(f"Parameter name {param_name} is not a valid identifier")
239
302
 
240
- def check_param_exists(self, param_name: str):
303
+ def check_param_defined_in_model(self, param_name: str):
241
304
 
242
- if param_name not in self.__parameters:
243
- self._report_error(f"Parameter {param_name} is not defined in the current context")
305
+ if param_name not in self.__model_def.parameters:
306
+ self._report_error(f"Parameter {param_name} is not defined in the model")
244
307
 
245
- def check_dataset_name_not_null(self, dataset_name):
308
+ def check_param_available_in_context(self, param_name: str):
246
309
 
247
- if dataset_name is None:
248
- self._report_error(f"Dataset name is null")
310
+ if param_name not in self.__local_ctx:
311
+ self._report_error(f"Parameter {param_name} is not available in the current context")
249
312
 
250
313
  def check_dataset_valid_identifier(self, dataset_name: str):
251
314
 
315
+ if dataset_name is None:
316
+ self._report_error(f"Dataset name is null")
317
+
252
318
  if not self.__VALID_IDENTIFIER.match(dataset_name):
253
319
  self._report_error(f"Dataset name {dataset_name} is not a valid identifier")
254
320
 
255
- def check_context_item_exists(self, item_name: str):
321
+ def check_dataset_defined_in_model(self, dataset_name: str):
322
+
323
+ if dataset_name not in self.__model_def.inputs and dataset_name not in self.__model_def.outputs:
324
+ self._report_error(f"Dataset {dataset_name} is not defined in the model")
325
+
326
+ def check_dataset_is_model_output(self, dataset_name: str):
256
327
 
257
- if item_name not in self.__data_ctx:
258
- self._report_error(f"The identifier {item_name} is not defined in the current context")
328
+ if dataset_name not in self.__model_def.outputs:
329
+ self._report_error(f"Dataset {dataset_name} is not defined as a model output")
259
330
 
260
- def check_context_item_is_dataset(self, item_name: str):
331
+ def check_dataset_is_dynamic_output(self, dataset_name: str):
261
332
 
262
- ctx_item = self.__data_ctx[item_name]
333
+ model_output: _meta.ModelOutputSchema = self.__model_def.outputs.get(dataset_name)
263
334
 
264
- if not isinstance(ctx_item, _data.DataView):
265
- self._report_error(f"The object referenced by {item_name} is not a dataset in the current context")
335
+ if model_output is None:
336
+ self._report_error(f"Dataset {dataset_name} is not defined as a model output")
266
337
 
267
- def check_dataset_schema_defined(self, dataset_name: str):
338
+ if not model_output.dynamic:
339
+ self._report_error(f"Model output {dataset_name} is not a dynamic output")
268
340
 
269
- schema = self.__data_ctx[dataset_name].trac_schema
341
+ def check_dataset_available_in_context(self, item_name: str):
270
342
 
271
- if schema is None or not schema.table or not schema.table.fields:
343
+ if item_name not in self.__local_ctx:
344
+ self._report_error(f"Dataset {item_name} is not available in the current context")
345
+
346
+ def check_dataset_schema_defined(self, dataset_name: str, data_view: _data.DataView):
347
+
348
+ schema = data_view.trac_schema if data_view is not None else None
349
+
350
+ if schema is None or schema.table is None or not schema.table.fields:
272
351
  self._report_error(f"Schema not defined for dataset {dataset_name} in the current context")
273
352
 
274
- def check_dataset_schema_not_defined(self, dataset_name: str):
353
+ def check_dataset_schema_not_defined(self, dataset_name: str, data_view: _data.DataView):
275
354
 
276
- schema = self.__data_ctx[dataset_name].trac_schema
355
+ schema = data_view.trac_schema if data_view is not None else None
277
356
 
278
357
  if schema is not None and (schema.table or schema.schemaType != _meta.SchemaType.SCHEMA_TYPE_NOT_SET):
279
358
  self._report_error(f"Schema already defined for dataset {dataset_name} in the current context")
280
359
 
281
- def check_dataset_part_present(self, dataset_name: str, part_key: _data.DataPartKey):
360
+ def check_dataset_part_present(self, dataset_name: str, data_view: _data.DataView, part_key: _data.DataPartKey):
282
361
 
283
- part = self.__data_ctx[dataset_name].parts.get(part_key)
362
+ part = data_view.parts.get(part_key) if data_view.parts is not None else None
284
363
 
285
364
  if part is None or len(part) == 0:
286
- self._report_error(f"No data present for dataset {dataset_name} ({part_key}) in the current context")
365
+ self._report_error(f"No data present for {dataset_name} ({part_key}) in the current context")
287
366
 
288
- def check_dataset_part_not_present(self, dataset_name: str, part_key: _data.DataPartKey):
367
+ def check_dataset_part_not_present(self, dataset_name: str, data_view: _data.DataView, part_key: _data.DataPartKey):
289
368
 
290
- part = self.__data_ctx[dataset_name].parts.get(part_key)
369
+ part = data_view.parts.get(part_key) if data_view.parts is not None else None
291
370
 
292
371
  if part is not None and len(part) > 0:
293
- self._report_error(f"Data already present for dataset {dataset_name} ({part_key}) in the current context")
372
+ self._report_error(f"Data already present for {dataset_name} ({part_key}) in the current context")
294
373
 
295
- def check_provided_dataset_not_null(self, dataset):
374
+ def check_dataset_is_empty(self, dataset_name: str, data_view: _data.DataView):
296
375
 
297
- if dataset is None:
298
- self._report_error(f"Provided dataset is null")
376
+ if not data_view.is_empty():
377
+ self._report_error(f"Dataset {dataset_name} is not empty")
378
+
379
+ def check_provided_schema_is_valid(self, dataset_name: str, schema: _meta.SchemaDefinition):
380
+
381
+ if schema is None:
382
+ self._report_error(f"The schema provided for [{dataset_name}] is null")
383
+
384
+ if not isinstance(schema, _meta.SchemaDefinition):
385
+ schema_type_name = self._type_name(type(schema))
386
+ self._report_error(f"The object provided for [{dataset_name}] is not a schema (got {schema_type_name})")
387
+
388
+ try:
389
+ _val.StaticValidator.quick_validate_schema(schema)
390
+ except _ex.EModelValidation as e:
391
+ self._report_error(f"The schema provided for [{dataset_name}] failed validation: {str(e)}", e)
299
392
 
300
393
  def check_provided_dataset_type(self, dataset: tp.Any, expected_type: type):
301
394
 
395
+ if dataset is None:
396
+ self._report_error(f"Provided dataset is null")
397
+
302
398
  if not isinstance(dataset, expected_type):
303
399
 
304
400
  expected_type_name = self._type_name(expected_type)
@@ -308,6 +404,17 @@ class TracContextValidator:
308
404
  f"Provided dataset is the wrong type" +
309
405
  f" (expected {expected_type_name}, got {actual_type_name})")
310
406
 
407
+ def check_context_object_type(self, item_name: str, item: tp.Any, expected_type: type):
408
+
409
+ if not isinstance(item, expected_type):
410
+
411
+ expected_type_name = self._type_name(expected_type)
412
+ actual_type_name = self._type_name(type(item))
413
+
414
+ self._report_error(
415
+ f"The object referenced by [{item_name}] in the current context has the wrong type" +
416
+ f" (expected {expected_type_name}, got {actual_type_name})")
417
+
311
418
  @staticmethod
312
419
  def _type_name(type_: type):
313
420