tracdap-runtime 0.6.4__py3-none-any.whl → 0.6.6__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (41)
  1. tracdap/rt/_exec/context.py +556 -36
  2. tracdap/rt/_exec/dev_mode.py +320 -198
  3. tracdap/rt/_exec/engine.py +331 -62
  4. tracdap/rt/_exec/functions.py +151 -22
  5. tracdap/rt/_exec/graph.py +47 -13
  6. tracdap/rt/_exec/graph_builder.py +383 -175
  7. tracdap/rt/_exec/runtime.py +7 -5
  8. tracdap/rt/_impl/config_parser.py +11 -4
  9. tracdap/rt/_impl/data.py +329 -152
  10. tracdap/rt/_impl/ext/__init__.py +13 -0
  11. tracdap/rt/_impl/ext/sql.py +116 -0
  12. tracdap/rt/_impl/ext/storage.py +57 -0
  13. tracdap/rt/_impl/grpc/tracdap/metadata/job_pb2.py +82 -30
  14. tracdap/rt/_impl/grpc/tracdap/metadata/job_pb2.pyi +155 -2
  15. tracdap/rt/_impl/grpc/tracdap/metadata/model_pb2.py +12 -10
  16. tracdap/rt/_impl/grpc/tracdap/metadata/model_pb2.pyi +14 -2
  17. tracdap/rt/_impl/grpc/tracdap/metadata/resource_pb2.py +29 -0
  18. tracdap/rt/_impl/grpc/tracdap/metadata/resource_pb2.pyi +16 -0
  19. tracdap/rt/_impl/models.py +8 -0
  20. tracdap/rt/_impl/static_api.py +29 -0
  21. tracdap/rt/_impl/storage.py +39 -27
  22. tracdap/rt/_impl/util.py +10 -0
  23. tracdap/rt/_impl/validation.py +140 -18
  24. tracdap/rt/_plugins/repo_git.py +1 -1
  25. tracdap/rt/_plugins/storage_sql.py +417 -0
  26. tracdap/rt/_plugins/storage_sql_dialects.py +117 -0
  27. tracdap/rt/_version.py +1 -1
  28. tracdap/rt/api/experimental.py +267 -0
  29. tracdap/rt/api/hook.py +14 -0
  30. tracdap/rt/api/model_api.py +48 -6
  31. tracdap/rt/config/__init__.py +2 -2
  32. tracdap/rt/config/common.py +6 -0
  33. tracdap/rt/metadata/__init__.py +29 -20
  34. tracdap/rt/metadata/job.py +99 -0
  35. tracdap/rt/metadata/model.py +18 -0
  36. tracdap/rt/metadata/resource.py +24 -0
  37. {tracdap_runtime-0.6.4.dist-info → tracdap_runtime-0.6.6.dist-info}/METADATA +5 -1
  38. {tracdap_runtime-0.6.4.dist-info → tracdap_runtime-0.6.6.dist-info}/RECORD +41 -32
  39. {tracdap_runtime-0.6.4.dist-info → tracdap_runtime-0.6.6.dist-info}/WHEEL +1 -1
  40. {tracdap_runtime-0.6.4.dist-info → tracdap_runtime-0.6.6.dist-info}/LICENSE +0 -0
  41. {tracdap_runtime-0.6.4.dist-info → tracdap_runtime-0.6.6.dist-info}/top_level.txt +0 -0
@@ -154,7 +154,6 @@ class TracRuntime:
         _plugins.PluginManager.register_plugin_package(plugin_package)
 
         _static_api.StaticApiImpl.register_impl()
-        _guard.PythonGuardRails.protect_dangerous_functions()
 
         # Load sys config (or use embedded), config errors are detected before start()
         # Job config can also be checked before start() by using load_job_config()
@@ -201,6 +200,11 @@ class TracRuntime:
         self._models = _models.ModelLoader(self._sys_config, self._scratch_dir)
         self._storage = _storage.StorageManager(self._sys_config)
 
+        # Enable protection after the initial setup of the runtime is complete
+        # Storage plugins in particular are likely to tigger protected imports
+        # Once the runtime is up, no more plugins should be loaded
+        _guard.PythonGuardRails.protect_dangerous_functions()
+
         self._engine = _engine.TracEngine(
             self._sys_config, self._models, self._storage,
             notify_callback=self._engine_callback)
@@ -329,10 +333,8 @@ class TracRuntime:
             config_file_name="job")
 
         if self._dev_mode:
-            job_config = _dev_mode.DevModeTranslator.translate_job_config(
-                self._sys_config, job_config,
-                self._scratch_dir, self._config_mgr,
-                model_class)
+            translator = _dev_mode.DevModeTranslator(self._sys_config, self._config_mgr, self._scratch_dir)
+            job_config = translator.translate_job_config(job_config, model_class)
 
         return job_config
 
@@ -341,10 +341,17 @@ class ConfigParser(tp.Generic[_T]):
 
             if isinstance(raw_value, tp.Dict):
                 return self._parse_simple_class(location, raw_value, annotation)
-            elif self._is_dev_mode_location(location) and type(raw_value) in ConfigParser.__primitive_types:
-                return self._parse_primitive(location, raw_value, type(raw_value))
-            else:
-                return self._error(location, f"Expected type {annotation.__name__}, got '{str(raw_value)}'")
+
+            if self._is_dev_mode_location(location):
+                if type(raw_value) in ConfigParser.__primitive_types:
+                    return self._parse_primitive(location, raw_value, type(raw_value))
+                if isinstance(raw_value, list):
+                    if len(raw_value) == 0:
+                        return []
+                    list_type = type(raw_value[0])
+                    return list(map(lambda x: self._parse_primitive(location, x, list_type), raw_value))
+
+            return self._error(location, f"Expected type {annotation.__name__}, got '{str(raw_value)}'")
 
         if isinstance(annotation, self.__generic_metaclass):
             return self._parse_generic_class(location, raw_value, annotation)  # noqa
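For context, the new dev-mode branch above accepts a plain list of primitive values where a structured value was previously required, parsing every element with the type of the first element. A minimal standalone sketch of that behaviour, using hypothetical values and a simplified stand-in for _parse_primitive:

    # Standalone sketch only: hypothetical values, _parse_primitive reduced to a constructor call
    raw_value = [0.25, 0.5, 0.75]   # e.g. a list supplied directly in a dev-mode config location

    if isinstance(raw_value, list):
        if len(raw_value) == 0:
            parsed = []
        else:
            list_type = type(raw_value[0])                          # float in this example
            parsed = list(map(lambda x: list_type(x), raw_value))   # stand-in for _parse_primitive

    print(parsed)   # [0.25, 0.5, 0.75]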
tracdap/rt/_impl/data.py CHANGED
@@ -12,8 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from __future__ import annotations
-
+import abc
 import dataclasses as dc
 import typing as tp
 import datetime as dt
@@ -22,8 +21,18 @@ import platform
 
 import pyarrow as pa
 import pyarrow.compute as pc
-import pandas as pd
 
+try:
+    import pandas  # noqa
+except ModuleNotFoundError:
+    pandas = None
+
+try:
+    import polars  # noqa
+except ModuleNotFoundError:
+    polars = None
+
+import tracdap.rt.api.experimental as _api
 import tracdap.rt.metadata as _meta
 import tracdap.rt.exceptions as _ex
 import tracdap.rt._impl.util as _util
@@ -42,7 +51,7 @@ class DataSpec:
 class DataPartKey:
 
     @classmethod
-    def for_root(cls) -> DataPartKey:
+    def for_root(cls) -> "DataPartKey":
         return DataPartKey(opaque_key='part_root')
 
     opaque_key: str
@@ -55,14 +64,14 @@ class DataItem:
     table: tp.Optional[pa.Table] = None
     batches: tp.Optional[tp.List[pa.RecordBatch]] = None
 
-    pandas: tp.Optional[pd.DataFrame] = None
+    pandas: "tp.Optional[pandas.DataFrame]" = None
     pyspark: tp.Any = None
 
     def is_empty(self) -> bool:
         return self.table is None and (self.batches is None or len(self.batches) == 0)
 
     @staticmethod
-    def create_empty() -> DataItem:
+    def create_empty() -> "DataItem":
        return DataItem(pa.schema([]))
 
 
@@ -75,7 +84,7 @@ class DataView:
     parts: tp.Dict[DataPartKey, tp.List[DataItem]]
 
     @staticmethod
-    def create_empty() -> DataView:
+    def create_empty() -> "DataView":
         return DataView(_meta.SchemaDefinition(), pa.schema([]), dict())
 
     @staticmethod
@@ -109,73 +118,19 @@ class DataMapping:
 
     # Matches TRAC_ARROW_TYPE_MAPPING in ArrowSchema, tracdap-lib-data
 
-    __TRAC_DECIMAL_PRECISION = 38
-    __TRAC_DECIMAL_SCALE = 12
-    __TRAC_TIMESTAMP_UNIT = "ms"
-    __TRAC_TIMESTAMP_ZONE = None
+    DEFAULT_DECIMAL_PRECISION = 38
+    DEFAULT_DECIMAL_SCALE = 12
+    DEFAULT_TIMESTAMP_UNIT = "ms"
+    DEFAULT_TIMESTAMP_ZONE = None
 
     __TRAC_TO_ARROW_BASIC_TYPE_MAPPING = {
         _meta.BasicType.BOOLEAN: pa.bool_(),
         _meta.BasicType.INTEGER: pa.int64(),
         _meta.BasicType.FLOAT: pa.float64(),
-        _meta.BasicType.DECIMAL: pa.decimal128(__TRAC_DECIMAL_PRECISION, __TRAC_DECIMAL_SCALE),
+        _meta.BasicType.DECIMAL: pa.decimal128(DEFAULT_DECIMAL_PRECISION, DEFAULT_DECIMAL_SCALE),
         _meta.BasicType.STRING: pa.utf8(),
         _meta.BasicType.DATE: pa.date32(),
-        _meta.BasicType.DATETIME: pa.timestamp(__TRAC_TIMESTAMP_UNIT, __TRAC_TIMESTAMP_ZONE)
-    }
-
-    # Check the Pandas dtypes for handling floats are available before setting up the type mapping
-    __PANDAS_VERSION_ELEMENTS = pd.__version__.split(".")
-    __PANDAS_MAJOR_VERSION = int(__PANDAS_VERSION_ELEMENTS[0])
-    __PANDAS_MINOR_VERSION = int(__PANDAS_VERSION_ELEMENTS[1])
-
-    if __PANDAS_MAJOR_VERSION == 2:
-
-        __PANDAS_DATE_TYPE = pd.to_datetime([dt.date(2000, 1, 1)]).as_unit(__TRAC_TIMESTAMP_UNIT).dtype
-        __PANDAS_DATETIME_TYPE = pd.to_datetime([dt.datetime(2000, 1, 1, 0, 0, 0)]).as_unit(__TRAC_TIMESTAMP_UNIT).dtype
-
-        @classmethod
-        def __pandas_datetime_type(cls, tz, unit):
-            if tz is None and unit is None:
-                return cls.__PANDAS_DATETIME_TYPE
-            _unit = unit if unit is not None else cls.__TRAC_TIMESTAMP_UNIT
-            if tz is None:
-                return pd.to_datetime([dt.datetime(2000, 1, 1, 0, 0, 0)]).as_unit(_unit).dtype
-            else:
-                return pd.DatetimeTZDtype(tz=tz, unit=_unit)
-
-    # Minimum supported version for Pandas is 1.2, when pd.Float64Dtype was introduced
-    elif __PANDAS_MAJOR_VERSION == 1 and __PANDAS_MINOR_VERSION >= 2:
-
-        __PANDAS_DATE_TYPE = pd.to_datetime([dt.date(2000, 1, 1)]).dtype
-        __PANDAS_DATETIME_TYPE = pd.to_datetime([dt.datetime(2000, 1, 1, 0, 0, 0)]).dtype
-
-        @classmethod
-        def __pandas_datetime_type(cls, tz, unit):  # noqa
-            if tz is None:
-                return cls.__PANDAS_DATETIME_TYPE
-            else:
-                return pd.DatetimeTZDtype(tz=tz)
-
-    else:
-        raise _ex.EStartup(f"Pandas version not supported: [{pd.__version__}]")
-
-    # Only partial mapping is possible, decimal and temporal dtypes cannot be mapped this way
-    __ARROW_TO_PANDAS_TYPE_MAPPING = {
-        pa.bool_(): pd.BooleanDtype(),
-        pa.int8(): pd.Int8Dtype(),
-        pa.int16(): pd.Int16Dtype(),
-        pa.int32(): pd.Int32Dtype(),
-        pa.int64(): pd.Int64Dtype(),
-        pa.uint8(): pd.UInt8Dtype(),
-        pa.uint16(): pd.UInt16Dtype(),
-        pa.uint32(): pd.UInt32Dtype(),
-        pa.uint64(): pd.UInt64Dtype(),
-        pa.float16(): pd.Float32Dtype(),
-        pa.float32(): pd.Float32Dtype(),
-        pa.float64(): pd.Float64Dtype(),
-        pa.string(): pd.StringDtype(),
-        pa.utf8(): pd.StringDtype()
+        _meta.BasicType.DATETIME: pa.timestamp(DEFAULT_TIMESTAMP_UNIT, DEFAULT_TIMESTAMP_ZONE)
     }
 
     __ARROW_TO_TRAC_BASIC_TYPE_MAPPING = {
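The four type defaults above were previously private (__TRAC_* names) and are now public class attributes on DataMapping. A small sketch of how they translate into Arrow types (assumes pyarrow and this internal module are importable; illustrative only, not part of the diff):

    import pyarrow as pa
    import tracdap.rt._impl.data as _data   # internal module, not a public API

    decimal_type = pa.decimal128(
        _data.DataMapping.DEFAULT_DECIMAL_PRECISION,   # 38
        _data.DataMapping.DEFAULT_DECIMAL_SCALE)       # 12
    timestamp_type = pa.timestamp(
        _data.DataMapping.DEFAULT_TIMESTAMP_UNIT,      # "ms"
        _data.DataMapping.DEFAULT_TIMESTAMP_ZONE)      # None, i.e. a zone-less timestamp

    print(decimal_type)     # decimal128(38, 12)
    print(timestamp_type)   # timestamp[ms]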
@@ -236,7 +191,7 @@ class DataMapping:
             return pa.float64()
 
         if python_type == decimal.Decimal:
-            return pa.decimal128(cls.__TRAC_DECIMAL_PRECISION, cls.__TRAC_DECIMAL_SCALE)
+            return pa.decimal128(cls.DEFAULT_DECIMAL_PRECISION, cls.DEFAULT_DECIMAL_SCALE)
 
         if python_type == str:
             return pa.utf8()
@@ -245,7 +200,7 @@
             return pa.date32()
 
         if python_type == dt.datetime:
-            return pa.timestamp(cls.__TRAC_TIMESTAMP_UNIT, cls.__TRAC_TIMESTAMP_ZONE)
+            return pa.timestamp(cls.DEFAULT_TIMESTAMP_UNIT, cls.DEFAULT_TIMESTAMP_ZONE)
 
         raise _ex.ETracInternal(f"No Arrow type mapping available for Python type [{python_type}]")
 
@@ -286,8 +241,8 @@
     def trac_arrow_decimal_type(cls) -> pa.Decimal128Type:
 
         return pa.decimal128(
-            cls.__TRAC_DECIMAL_PRECISION,
-            cls.__TRAC_DECIMAL_SCALE)
+            cls.DEFAULT_DECIMAL_PRECISION,
+            cls.DEFAULT_DECIMAL_SCALE,)
 
     @classmethod
     def arrow_to_trac_schema(cls, arrow_schema: pa.Schema) -> _meta.SchemaDefinition:
@@ -330,28 +285,6 @@ class DataMapping:
 
         raise _ex.ETracInternal(f"No data type mapping available for Arrow type [{arrow_type}]")
 
-    @classmethod
-    def pandas_date_type(cls):
-        return cls.__PANDAS_DATE_TYPE
-
-    @classmethod
-    def pandas_datetime_type(cls, tz=None, unit=None):
-        return cls.__pandas_datetime_type(tz, unit)
-
-    @classmethod
-    def view_to_pandas(
-            cls, view: DataView, part: DataPartKey, schema: tp.Optional[pa.Schema],
-            temporal_objects_flag: bool) -> pd.DataFrame:
-
-        table = cls.view_to_arrow(view, part)
-        return cls.arrow_to_pandas(table, schema, temporal_objects_flag)
-
-    @classmethod
-    def pandas_to_item(cls, df: pd.DataFrame, schema: tp.Optional[pa.Schema]) -> DataItem:
-
-        table = cls.pandas_to_arrow(df, schema)
-        return DataItem(table.schema, table)
-
     @classmethod
     def add_item_to_view(cls, view: DataView, part: DataPartKey, item: DataItem) -> DataView:
 
@@ -400,73 +333,306 @@ class DataMapping:
 
     @classmethod
     def arrow_to_pandas(
-            cls, table: pa.Table, schema: tp.Optional[pa.Schema] = None,
-            temporal_objects_flag: bool = False) -> pd.DataFrame:
+            cls, table: pa.Table,
+            schema: tp.Optional[pa.Schema] = None,
+            temporal_objects_flag: bool = False) -> "pandas.DataFrame":
 
-        if schema is not None:
-            table = DataConformance.conform_to_schema(table, schema, warn_extra_columns=False)
-        else:
-            DataConformance.check_duplicate_fields(table.schema.names, False)
+        # This is a legacy internal method and should be removed
+        # DataMapping is no longer responsible for individual data APIs
+
+        # Maintained temporarily for compatibility with existing deployments
+
+        converter = PandasArrowConverter(_api.PANDAS, use_temporal_objects=temporal_objects_flag)
+        return converter.from_internal(table, schema)
+
+    @classmethod
+    def pandas_to_arrow(
+            cls, df: "pandas.DataFrame",
+            schema: tp.Optional[pa.Schema] = None) -> pa.Table:
+
+        # This is a legacy internal method and should be removed
+        # DataMapping is no longer responsible for individual data APIs
+
+        # Maintained temporarily for compatibility with existing deployments
 
-        # Use Arrow's built-in function to convert to Pandas
-        return table.to_pandas(
+        converter = PandasArrowConverter(_api.PANDAS)
+        return converter.to_internal(df, schema)
 
-            # Mapping for arrow -> pandas types for core types
-            types_mapper=cls.__ARROW_TO_PANDAS_TYPE_MAPPING.get,
 
-            # Use Python objects for dates and times if temporal_objects_flag is set
-            date_as_object=temporal_objects_flag,  # noqa
-            timestamp_as_object=temporal_objects_flag,  # noqa
 
-            # Do not bring any Arrow metadata into Pandas dataframe
-            ignore_metadata=True,  # noqa
+T_DATA_API = tp.TypeVar("T_DATA_API")
+T_INTERNAL_DATA = tp.TypeVar("T_INTERNAL_DATA")
+T_INTERNAL_SCHEMA = tp.TypeVar("T_INTERNAL_SCHEMA")
 
-            # Do not consolidate memory across columns when preparing the Pandas vectors
-            # This is a significant performance win for very wide datasets
-            split_blocks=True)  # noqa
+
+class DataConverter(tp.Generic[T_DATA_API, T_INTERNAL_DATA, T_INTERNAL_SCHEMA]):
+
+    # Available per-framework args, to enable framework-specific type-checking in public APIs
+    # These should (for a purist point of view) be in the individual converter classes
+    # For now there are only a few converters, they are all defined here so this is OK
+    __FRAMEWORK_ARGS = {
+        _api.PANDAS: {"use_temporal_objects": tp.Optional[bool]},
+        _api.POLARS: {}
+    }
 
     @classmethod
-    def pandas_to_arrow(cls, df: pd.DataFrame, schema: tp.Optional[pa.Schema] = None) -> pa.Table:
+    def get_framework(cls, dataset: _api.DATA_API) -> _api.DataFramework[_api.DATA_API]:
+
+        if pandas is not None and isinstance(dataset, pandas.DataFrame):
+            return _api.PANDAS
 
-        # Converting pandas -> arrow needs care to ensure type coercion is applied correctly
-        # Calling Table.from_pandas with the supplied schema will very often reject data
-        # Instead, we convert the dataframe as-is and then apply type conversion in a second step
-        # This allows us to apply specific coercion rules for each data type
+        if polars is not None and isinstance(dataset, polars.DataFrame):
+            return _api.POLARS
 
-        # As an optimisation, the column filter means columns will not be converted if they are not needed
-        # E.g. if a model outputs lots of undeclared columns, there is no need to convert them
+        data_api_type = f"{type(dataset).__module__}.{type(dataset).__name__}"
+        raise _ex.EPluginNotAvailable(f"No data framework available for type [{data_api_type}]")
 
-        column_filter = DataConformance.column_filter(df.columns, schema)  # noqa
+    @classmethod
+    def get_framework_args(cls, framework: _api.DataFramework[_api.DATA_API]) -> tp.Dict[str, type]:
 
-        if len(df) > 0:
+        return cls.__FRAMEWORK_ARGS.get(framework) or {}
 
-            table = pa.Table.from_pandas(df, columns=column_filter, preserve_index=False)  # noqa
+    @classmethod
+    def for_framework(cls, framework: _api.DataFramework[_api.DATA_API], **framework_args) -> "DataConverter[_api.DATA_API, pa.Table, pa.Schema]":
 
-        # Special case handling for converting an empty dataframe
-        # These must flow through the pipe with valid schemas, like any other dataset
-        # Type coercion and column filtering happen in conform_to_schema, if a schema has been supplied
+        if framework == _api.PANDAS:
+            if pandas is not None:
+                return PandasArrowConverter(framework, **framework_args)
+            else:
+                raise _ex.EPluginNotAvailable(f"Optional package [{framework}] is not installed")
 
-        else:
+        if framework == _api.POLARS:
+            if polars is not None:
+                return PolarsArrowConverter(framework)
+            else:
+                raise _ex.EPluginNotAvailable(f"Optional package [{framework}] is not installed")
+
+        raise _ex.EPluginNotAvailable(f"Data framework [{framework}] is not recognized")
+
+    @classmethod
+    def for_dataset(cls, dataset: _api.DATA_API) -> "DataConverter[_api.DATA_API, pa.Table, pa.Schema]":
+
+        return cls.for_framework(cls.get_framework(dataset))
+
+    @classmethod
+    def noop(cls) -> "DataConverter[T_INTERNAL_DATA, T_INTERNAL_DATA, T_INTERNAL_SCHEMA]":
+        return NoopConverter()
+
+    def __init__(self, framework: _api.DataFramework[T_DATA_API]):
+        self.framework = framework
 
-            empty_df = df.filter(column_filter) if column_filter else df
-            empty_schema = pa.Schema.from_pandas(empty_df, preserve_index=False)  # noqa
+    @abc.abstractmethod
+    def from_internal(self, dataset: T_INTERNAL_DATA, schema: tp.Optional[T_INTERNAL_SCHEMA] = None) -> T_DATA_API:
+        pass
 
-            table = pa.Table.from_batches(list(), empty_schema)  # noqa
+    @abc.abstractmethod
+    def to_internal(self, dataset: T_DATA_API, schema: tp.Optional[T_INTERNAL_SCHEMA] = None) -> T_INTERNAL_DATA:
+        pass
 
-        # If there is no explict schema, give back the table exactly as it was received from Pandas
-        # There could be an option here to infer and coerce for TRAC standard types
-        # E.g. unsigned int 32 -> signed int 64, TRAC standard integer type
+    @abc.abstractmethod
+    def infer_schema(self, dataset: T_DATA_API) -> _meta.SchemaDefinition:
+        pass
 
-        if schema is None:
-            DataConformance.check_duplicate_fields(table.schema.names, False)
-            return table
 
-        # If a schema has been supplied, apply data conformance
-        # If column filtering has been applied, we also need to filter the pandas dtypes used for hinting
+class NoopConverter(DataConverter[T_INTERNAL_DATA, T_INTERNAL_DATA, T_INTERNAL_SCHEMA]):
+
+    def __init__(self):
+        super().__init__(_api.DataFramework("internal", None))  # noqa
+
+    def from_internal(self, dataset: T_INTERNAL_DATA, schema: tp.Optional[T_INTERNAL_SCHEMA] = None) -> T_DATA_API:
+        return dataset
+
+    def to_internal(self, dataset: T_DATA_API, schema: tp.Optional[T_INTERNAL_SCHEMA] = None) -> T_INTERNAL_DATA:
+        return dataset
+
+    def infer_schema(self, dataset: T_DATA_API) -> _meta.SchemaDefinition:
+        raise _ex.EUnexpected()  # A real converter should be selected before use
+
+
+# Data frameworks are optional, do not blow up the module just because one framework is unavailable!
+if pandas is not None:
+
+    class PandasArrowConverter(DataConverter[pandas.DataFrame, pa.Table, pa.Schema]):
+
+        # Check the Pandas dtypes for handling floats are available before setting up the type mapping
+        __PANDAS_VERSION_ELEMENTS = pandas.__version__.split(".")
+        __PANDAS_MAJOR_VERSION = int(__PANDAS_VERSION_ELEMENTS[0])
+        __PANDAS_MINOR_VERSION = int(__PANDAS_VERSION_ELEMENTS[1])
+
+        if __PANDAS_MAJOR_VERSION == 2:
+
+            __PANDAS_DATE_TYPE = pandas.to_datetime([dt.date(2000, 1, 1)]).as_unit(DataMapping.DEFAULT_TIMESTAMP_UNIT).dtype
+            __PANDAS_DATETIME_TYPE = pandas.to_datetime([dt.datetime(2000, 1, 1, 0, 0, 0)]).as_unit(DataMapping.DEFAULT_TIMESTAMP_UNIT).dtype
+
+            @classmethod
+            def __pandas_datetime_type(cls, tz, unit):
+                if tz is None and unit is None:
+                    return cls.__PANDAS_DATETIME_TYPE
+                _unit = unit if unit is not None else DataMapping.DEFAULT_TIMESTAMP_UNIT
+                if tz is None:
+                    return pandas.to_datetime([dt.datetime(2000, 1, 1, 0, 0, 0)]).as_unit(_unit).dtype
+                else:
+                    return pandas.DatetimeTZDtype(tz=tz, unit=_unit)
+
+        # Minimum supported version for Pandas is 1.2, when pandas.Float64Dtype was introduced
+        elif __PANDAS_MAJOR_VERSION == 1 and __PANDAS_MINOR_VERSION >= 2:
+
+            __PANDAS_DATE_TYPE = pandas.to_datetime([dt.date(2000, 1, 1)]).dtype
+            __PANDAS_DATETIME_TYPE = pandas.to_datetime([dt.datetime(2000, 1, 1, 0, 0, 0)]).dtype
+
+            @classmethod
+            def __pandas_datetime_type(cls, tz, unit):  # noqa
+                if tz is None:
+                    return cls.__PANDAS_DATETIME_TYPE
+                else:
+                    return pandas.DatetimeTZDtype(tz=tz)
 
         else:
-            df_types = df.dtypes.filter(column_filter) if column_filter else df.dtypes
-            return DataConformance.conform_to_schema(table, schema, df_types)
+            raise _ex.EStartup(f"Pandas version not supported: [{pandas.__version__}]")
+
+        # Only partial mapping is possible, decimal and temporal dtypes cannot be mapped this way
+        __ARROW_TO_PANDAS_TYPE_MAPPING = {
+            pa.bool_(): pandas.BooleanDtype(),
+            pa.int8(): pandas.Int8Dtype(),
+            pa.int16(): pandas.Int16Dtype(),
+            pa.int32(): pandas.Int32Dtype(),
+            pa.int64(): pandas.Int64Dtype(),
+            pa.uint8(): pandas.UInt8Dtype(),
+            pa.uint16(): pandas.UInt16Dtype(),
+            pa.uint32(): pandas.UInt32Dtype(),
+            pa.uint64(): pandas.UInt64Dtype(),
+            pa.float16(): pandas.Float32Dtype(),
+            pa.float32(): pandas.Float32Dtype(),
+            pa.float64(): pandas.Float64Dtype(),
+            pa.string(): pandas.StringDtype(),
+            pa.utf8(): pandas.StringDtype()
+        }
+
+        __DEFAULT_TEMPORAL_OBJECTS = False
+
+        # Expose date type for testing
+        @classmethod
+        def pandas_date_type(cls):
+            return cls.__PANDAS_DATE_TYPE
+
+        # Expose datetime type for testing
+        @classmethod
+        def pandas_datetime_type(cls, tz=None, unit=None):
+            return cls.__pandas_datetime_type(tz, unit)
+
+        def __init__(self, framework: _api.DataFramework[T_DATA_API], use_temporal_objects: tp.Optional[bool] = None):
+            super().__init__(framework)
+            if use_temporal_objects is None:
+                self.__temporal_objects_flag = self.__DEFAULT_TEMPORAL_OBJECTS
+            else:
+                self.__temporal_objects_flag = use_temporal_objects
+
+        def from_internal(self, table: pa.Table, schema: tp.Optional[pa.Schema] = None) -> pandas.DataFrame:
+
+            if schema is not None:
+                table = DataConformance.conform_to_schema(table, schema, warn_extra_columns=False)
+            else:
+                DataConformance.check_duplicate_fields(table.schema.names, False)
+
+            # Use Arrow's built-in function to convert to Pandas
+            return table.to_pandas(
+
+                # Mapping for arrow -> pandas types for core types
+                types_mapper=self.__ARROW_TO_PANDAS_TYPE_MAPPING.get,
+
+                # Use Python objects for dates and times if temporal_objects_flag is set
+                date_as_object=self.__temporal_objects_flag,  # noqa
+                timestamp_as_object=self.__temporal_objects_flag,  # noqa
+
+                # Do not bring any Arrow metadata into Pandas dataframe
+                ignore_metadata=True,  # noqa
+
+                # Do not consolidate memory across columns when preparing the Pandas vectors
+                # This is a significant performance win for very wide datasets
+                split_blocks=True)  # noqa
+
+        def to_internal(self, df: pandas.DataFrame, schema: tp.Optional[pa.Schema] = None) -> pa.Table:
+
+            # Converting pandas -> arrow needs care to ensure type coercion is applied correctly
+            # Calling Table.from_pandas with the supplied schema will very often reject data
+            # Instead, we convert the dataframe as-is and then apply type conversion in a second step
+            # This allows us to apply specific coercion rules for each data type
+
+            # As an optimisation, the column filter means columns will not be converted if they are not needed
+            # E.g. if a model outputs lots of undeclared columns, there is no need to convert them
+
+            column_filter = DataConformance.column_filter(df.columns, schema)  # noqa
+
+            if len(df) > 0:
+
+                table = pa.Table.from_pandas(df, columns=column_filter, preserve_index=False)  # noqa
+
+            # Special case handling for converting an empty dataframe
+            # These must flow through the pipe with valid schemas, like any other dataset
+            # Type coercion and column filtering happen in conform_to_schema, if a schema has been supplied
+
+            else:
+
+                empty_df = df.filter(column_filter) if column_filter else df
+                empty_schema = pa.Schema.from_pandas(empty_df, preserve_index=False)  # noqa
+
+                table = pa.Table.from_batches(list(), empty_schema)  # noqa
+
+            # If there is no explict schema, give back the table exactly as it was received from Pandas
+            # There could be an option here to infer and coerce for TRAC standard types
+            # E.g. unsigned int 32 -> signed int 64, TRAC standard integer type
+
+            if schema is None:
+                DataConformance.check_duplicate_fields(table.schema.names, False)
+                return table
+
+            # If a schema has been supplied, apply data conformance
+            # If column filtering has been applied, we also need to filter the pandas dtypes used for hinting
+
+            else:
+                df_types = df.dtypes.filter(column_filter) if column_filter else df.dtypes
+                return DataConformance.conform_to_schema(table, schema, df_types)
+
+        def infer_schema(self, dataset: pandas.DataFrame) -> _meta.SchemaDefinition:
+
+            arrow_schema = pa.Schema.from_pandas(dataset, preserve_index=False)  # noqa
+            return DataMapping.arrow_to_trac_schema(arrow_schema)
+
+
+# Data frameworks are optional, do not blow up the module just because one framework is unavailable!
+if polars is not None:
+
+    class PolarsArrowConverter(DataConverter[polars.DataFrame, pa.Table, pa.Schema]):
+
+        def __init__(self, framework: _api.DataFramework[T_DATA_API]):
+            super().__init__(framework)
+
+        def from_internal(self, table: pa.Table, schema: tp.Optional[pa.Schema] = None) -> polars.DataFrame:
+
+            if schema is not None:
+                table = DataConformance.conform_to_schema(table, schema, warn_extra_columns=False)
+            else:
+                DataConformance.check_duplicate_fields(table.schema.names, False)
+
+            return polars.from_arrow(table)
+
+        def to_internal(self, df: polars.DataFrame, schema: tp.Optional[pa.Schema] = None,) -> pa.Table:
+
+            column_filter = DataConformance.column_filter(df.columns, schema)
+
+            filtered_df = df.select(polars.col(*column_filter)) if column_filter else df
+            table = filtered_df.to_arrow()
+
+            if schema is None:
+                DataConformance.check_duplicate_fields(table.schema.names, False)
+                return table
+            else:
+                return DataConformance.conform_to_schema(table, schema, None)
+
+        def infer_schema(self, dataset: T_DATA_API) -> _meta.SchemaDefinition:
+
+            arrow_schema = dataset.top_k(1).to_arrow().schema
+            return DataMapping.arrow_to_trac_schema(arrow_schema)
 
 
 class DataConformance:
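Based on the signatures introduced above, a usage sketch for the new converter layer might look like the following. DataConverter lives in the internal module tracdap.rt._impl.data, so this illustrates the internal API rather than a supported public interface; it assumes pandas is installed and pyarrow.Table is the internal representation, as the class declarations suggest.

    import pandas
    import tracdap.rt._impl.data as _data   # internal module, illustration only

    df = pandas.DataFrame({"id": [1, 2, 3], "value": [0.1, 0.2, 0.3]})

    converter = _data.DataConverter.for_dataset(df)   # selects PandasArrowConverter
    table = converter.to_internal(df)                 # pandas.DataFrame -> pyarrow.Table
    trac_schema = converter.infer_schema(df)          # TRAC SchemaDefinition
    round_trip = converter.from_internal(table)       # pyarrow.Table -> pandas.DataFrame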
@@ -597,7 +763,7 @@ class DataConformance:
         # Columns not defined in the schema will not be included in the conformed output
         if warn_extra_columns and table.num_columns > len(schema.types):
 
-            schema_columns = set(map(str.lower, schema.names))
+            schema_columns = set(map(lambda c: c.lower(), schema.names))
             extra_columns = [
                 f"[{col}]"
                 for col in table.schema.names
@@ -784,21 +950,32 @@
     @classmethod
     def _coerce_string(cls, vector: pa.Array, field: pa.Field) -> pa.Array:
 
-        if pa.types.is_string(field.type):
-            if pa.types.is_string(vector.type):
-                return vector
+        try:
 
-        if pa.types.is_large_string(field.type):
-            if pa.types.is_large_string(vector.type):
-                return vector
-            # Allow up-casting string -> large_string
-            if pa.types.is_string(vector.type):
-                return pc.cast(vector, field.type)
+            if pa.types.is_string(field.type):
+                if pa.types.is_string(vector.type):
+                    return vector
+                # Try to down-cast large string -> string, will raise ArrowInvalid if data does not fit
+                if pa.types.is_large_string(vector.type):
+                    return pc.cast(vector, field.type, safe=True)
+
+            if pa.types.is_large_string(field.type):
+                if pa.types.is_large_string(vector.type):
+                    return vector
+                # Allow up-casting string -> large_string
+                if pa.types.is_string(vector.type):
+                    return pc.cast(vector, field.type)
 
-        error_message = cls._format_error(cls.__E_WRONG_DATA_TYPE, vector, field)
-        cls.__log.error(error_message)
+            error_message = cls._format_error(cls.__E_WRONG_DATA_TYPE, vector, field)
+            cls.__log.error(error_message)
+            raise _ex.EDataConformance(error_message)
+
+        except pa.ArrowInvalid as e:
+
+            error_message = cls._format_error(cls.__E_DATA_LOSS_DID_OCCUR, vector, field, e)
+            cls.__log.error(error_message)
+            raise _ex.EDataConformance(error_message) from e
 
-        raise _ex.EDataConformance(error_message)
 
     @classmethod
     def _coerce_date(cls, vector: pa.Array, field: pa.Field, pandas_type=None) -> pa.Array:
816
993
  # For Pandas 2.x dates are still np.datetime64 but can be in s, ms, us or ns
817
994
  # This conversion will not apply to dates held in Pandas using the Python date object types
818
995
  if pandas_type is not None:
819
- if pa.types.is_timestamp(vector.type) and pd.api.types.is_datetime64_any_dtype(pandas_type):
996
+ if pa.types.is_timestamp(vector.type) and pandas.api.types.is_datetime64_any_dtype(pandas_type):
820
997
  return pc.cast(vector, field.type)
821
998
 
822
999
  error_message = cls._format_error(cls.__E_WRONG_DATA_TYPE, vector, field)