tracdap-runtime 0.6.5__py3-none-any.whl → 0.6.6__py3-none-any.whl

This diff compares the content of two publicly released package versions as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registries.
@@ -154,7 +154,6 @@ class TracRuntime:
  _plugins.PluginManager.register_plugin_package(plugin_package)

  _static_api.StaticApiImpl.register_impl()
- _guard.PythonGuardRails.protect_dangerous_functions()

  # Load sys config (or use embedded), config errors are detected before start()
  # Job config can also be checked before start() by using load_job_config()
@@ -201,6 +200,11 @@ class TracRuntime:
  self._models = _models.ModelLoader(self._sys_config, self._scratch_dir)
  self._storage = _storage.StorageManager(self._sys_config)

+ # Enable protection after the initial setup of the runtime is complete
+ # Storage plugins in particular are likely to trigger protected imports
+ # Once the runtime is up, no more plugins should be loaded
+ _guard.PythonGuardRails.protect_dangerous_functions()
+
  self._engine = _engine.TracEngine(
  self._sys_config, self._models, self._storage,
  notify_callback=self._engine_callback)
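The call to protect_dangerous_functions() moves from pre-start initialisation to the point after plugins, config, models and storage have been set up. A minimal runnable sketch of the revised ordering; apart from the guard-rails call it stands in for, every step below is illustrative and not the actual TracRuntime API:

    def protect_dangerous_functions():
        # Stand-in for _guard.PythonGuardRails.protect_dangerous_functions()
        print("guard rails enabled")

    def start_runtime():
        print("register plugin packages")          # plugin imports run unguarded
        print("load sys config / job config")
        print("create model loader and storage")   # storage plugins may trigger protected imports
        protect_dangerous_functions()              # enabled only once setup is complete
        print("create TracEngine")

    start_runtime()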
@@ -329,10 +333,8 @@ class TracRuntime:
  config_file_name="job")

  if self._dev_mode:
- job_config = _dev_mode.DevModeTranslator.translate_job_config(
- self._sys_config, job_config,
- self._scratch_dir, self._config_mgr,
- model_class)
+ translator = _dev_mode.DevModeTranslator(self._sys_config, self._config_mgr, self._scratch_dir)
+ job_config = translator.translate_job_config(job_config, model_class)

  return job_config

@@ -341,10 +341,17 @@ class ConfigParser(tp.Generic[_T]):
  if isinstance(raw_value, tp.Dict):
  return self._parse_simple_class(location, raw_value, annotation)
- elif self._is_dev_mode_location(location) and type(raw_value) in ConfigParser.__primitive_types:
- return self._parse_primitive(location, raw_value, type(raw_value))
- else:
- return self._error(location, f"Expected type {annotation.__name__}, got '{str(raw_value)}'")
+
+ if self._is_dev_mode_location(location):
+ if type(raw_value) in ConfigParser.__primitive_types:
+ return self._parse_primitive(location, raw_value, type(raw_value))
+ if isinstance(raw_value, list):
+ if len(raw_value) == 0:
+ return []
+ list_type = type(raw_value[0])
+ return list(map(lambda x: self._parse_primitive(location, x, list_type), raw_value))
+
+ return self._error(location, f"Expected type {annotation.__name__}, got '{str(raw_value)}'")

  if isinstance(annotation, self.__generic_metaclass):
  return self._parse_generic_class(location, raw_value, annotation) # noqa
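With the new branch, a dev-mode config location can supply a plain list of primitives instead of raising a type error. A minimal standalone sketch of the same logic, not the TRAC implementation (which routes each element through _parse_primitive rather than the type constructor used here):

    import typing as tp

    def parse_dev_mode_list(raw_value: tp.List[tp.Any]) -> tp.List[tp.Any]:
        # Empty lists pass straight through
        if len(raw_value) == 0:
            return []
        # Every element is coerced to the type of the first element,
        # mirroring the parser's use of type(raw_value[0]) as list_type
        list_type = type(raw_value[0])
        return [list_type(x) for x in raw_value]

    print(parse_dev_mode_list(["a", "b", "c"]))   # ['a', 'b', 'c']
    print(parse_dev_mode_list([]))                # []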
tracdap/rt/_impl/data.py CHANGED
@@ -12,6 +12,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

+ import abc
  import dataclasses as dc
  import typing as tp
  import datetime as dt
@@ -31,6 +32,7 @@ try:
  except ModuleNotFoundError:
  polars = None

+ import tracdap.rt.api.experimental as _api
  import tracdap.rt.metadata as _meta
  import tracdap.rt.exceptions as _ex
  import tracdap.rt._impl.util as _util
@@ -116,73 +118,19 @@ class DataMapping:

  # Matches TRAC_ARROW_TYPE_MAPPING in ArrowSchema, tracdap-lib-data

- __TRAC_DECIMAL_PRECISION = 38
- __TRAC_DECIMAL_SCALE = 12
- __TRAC_TIMESTAMP_UNIT = "ms"
- __TRAC_TIMESTAMP_ZONE = None
+ DEFAULT_DECIMAL_PRECISION = 38
+ DEFAULT_DECIMAL_SCALE = 12
+ DEFAULT_TIMESTAMP_UNIT = "ms"
+ DEFAULT_TIMESTAMP_ZONE = None

  __TRAC_TO_ARROW_BASIC_TYPE_MAPPING = {
  _meta.BasicType.BOOLEAN: pa.bool_(),
  _meta.BasicType.INTEGER: pa.int64(),
  _meta.BasicType.FLOAT: pa.float64(),
- _meta.BasicType.DECIMAL: pa.decimal128(__TRAC_DECIMAL_PRECISION, __TRAC_DECIMAL_SCALE),
+ _meta.BasicType.DECIMAL: pa.decimal128(DEFAULT_DECIMAL_PRECISION, DEFAULT_DECIMAL_SCALE),
  _meta.BasicType.STRING: pa.utf8(),
  _meta.BasicType.DATE: pa.date32(),
- _meta.BasicType.DATETIME: pa.timestamp(__TRAC_TIMESTAMP_UNIT, __TRAC_TIMESTAMP_ZONE)
- }
-
- # Check the Pandas dtypes for handling floats are available before setting up the type mapping
- __PANDAS_VERSION_ELEMENTS = pandas.__version__.split(".")
- __PANDAS_MAJOR_VERSION = int(__PANDAS_VERSION_ELEMENTS[0])
- __PANDAS_MINOR_VERSION = int(__PANDAS_VERSION_ELEMENTS[1])
-
- if __PANDAS_MAJOR_VERSION == 2:
-
- __PANDAS_DATE_TYPE = pandas.to_datetime([dt.date(2000, 1, 1)]).as_unit(__TRAC_TIMESTAMP_UNIT).dtype
- __PANDAS_DATETIME_TYPE = pandas.to_datetime([dt.datetime(2000, 1, 1, 0, 0, 0)]).as_unit(__TRAC_TIMESTAMP_UNIT).dtype
-
- @classmethod
- def __pandas_datetime_type(cls, tz, unit):
- if tz is None and unit is None:
- return cls.__PANDAS_DATETIME_TYPE
- _unit = unit if unit is not None else cls.__TRAC_TIMESTAMP_UNIT
- if tz is None:
- return pandas.to_datetime([dt.datetime(2000, 1, 1, 0, 0, 0)]).as_unit(_unit).dtype
- else:
- return pandas.DatetimeTZDtype(tz=tz, unit=_unit)
-
- # Minimum supported version for Pandas is 1.2, when pandas.Float64Dtype was introduced
- elif __PANDAS_MAJOR_VERSION == 1 and __PANDAS_MINOR_VERSION >= 2:
-
- __PANDAS_DATE_TYPE = pandas.to_datetime([dt.date(2000, 1, 1)]).dtype
- __PANDAS_DATETIME_TYPE = pandas.to_datetime([dt.datetime(2000, 1, 1, 0, 0, 0)]).dtype
-
- @classmethod
- def __pandas_datetime_type(cls, tz, unit): # noqa
- if tz is None:
- return cls.__PANDAS_DATETIME_TYPE
- else:
- return pandas.DatetimeTZDtype(tz=tz)
-
- else:
- raise _ex.EStartup(f"Pandas version not supported: [{pandas.__version__}]")
-
- # Only partial mapping is possible, decimal and temporal dtypes cannot be mapped this way
- __ARROW_TO_PANDAS_TYPE_MAPPING = {
- pa.bool_(): pandas.BooleanDtype(),
- pa.int8(): pandas.Int8Dtype(),
- pa.int16(): pandas.Int16Dtype(),
- pa.int32(): pandas.Int32Dtype(),
- pa.int64(): pandas.Int64Dtype(),
- pa.uint8(): pandas.UInt8Dtype(),
- pa.uint16(): pandas.UInt16Dtype(),
- pa.uint32(): pandas.UInt32Dtype(),
- pa.uint64(): pandas.UInt64Dtype(),
- pa.float16(): pandas.Float32Dtype(),
- pa.float32(): pandas.Float32Dtype(),
- pa.float64(): pandas.Float64Dtype(),
- pa.string(): pandas.StringDtype(),
- pa.utf8(): pandas.StringDtype()
+ _meta.BasicType.DATETIME: pa.timestamp(DEFAULT_TIMESTAMP_UNIT, DEFAULT_TIMESTAMP_ZONE)
  }

  __ARROW_TO_TRAC_BASIC_TYPE_MAPPING = {
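The former private __TRAC_* constants become public DEFAULT_* class attributes, so the new converter classes further down can reference them as DataMapping.DEFAULT_TIMESTAMP_UNIT. A hedged sketch of what the public names allow, assuming the internal module is importable under the path shown in the file header (values taken from the diff):

    import pyarrow as pa
    import tracdap.rt._impl.data as _data   # internal module, illustrative import

    # Build the TRAC default Arrow types from the now-public class attributes
    decimal_type = pa.decimal128(
        _data.DataMapping.DEFAULT_DECIMAL_PRECISION,   # 38
        _data.DataMapping.DEFAULT_DECIMAL_SCALE)       # 12

    timestamp_type = pa.timestamp(
        _data.DataMapping.DEFAULT_TIMESTAMP_UNIT,      # "ms"
        _data.DataMapping.DEFAULT_TIMESTAMP_ZONE)      # None -> zone-less timestamp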
@@ -243,7 +191,7 @@ class DataMapping:
  return pa.float64()

  if python_type == decimal.Decimal:
- return pa.decimal128(cls.__TRAC_DECIMAL_PRECISION, cls.__TRAC_DECIMAL_SCALE)
+ return pa.decimal128(cls.DEFAULT_DECIMAL_PRECISION, cls.DEFAULT_DECIMAL_SCALE)

  if python_type == str:
  return pa.utf8()
@@ -252,7 +200,7 @@ class DataMapping:
  return pa.date32()

  if python_type == dt.datetime:
- return pa.timestamp(cls.__TRAC_TIMESTAMP_UNIT, cls.__TRAC_TIMESTAMP_ZONE)
+ return pa.timestamp(cls.DEFAULT_TIMESTAMP_UNIT, cls.DEFAULT_TIMESTAMP_ZONE)

  raise _ex.ETracInternal(f"No Arrow type mapping available for Python type [{python_type}]")

@@ -293,8 +241,8 @@ class DataMapping:
  def trac_arrow_decimal_type(cls) -> pa.Decimal128Type:

  return pa.decimal128(
- cls.__TRAC_DECIMAL_PRECISION,
- cls.__TRAC_DECIMAL_SCALE)
+ cls.DEFAULT_DECIMAL_PRECISION,
+ cls.DEFAULT_DECIMAL_SCALE,)

  @classmethod
  def arrow_to_trac_schema(cls, arrow_schema: pa.Schema) -> _meta.SchemaDefinition:
@@ -337,41 +285,6 @@ class DataMapping:

  raise _ex.ETracInternal(f"No data type mapping available for Arrow type [{arrow_type}]")

- @classmethod
- def pandas_date_type(cls):
- return cls.__PANDAS_DATE_TYPE
-
- @classmethod
- def pandas_datetime_type(cls, tz=None, unit=None):
- return cls.__pandas_datetime_type(tz, unit)
-
- @classmethod
- def view_to_pandas(
- cls, view: DataView, part: DataPartKey, schema: tp.Optional[pa.Schema],
- temporal_objects_flag: bool) -> "pandas.DataFrame":
-
- table = cls.view_to_arrow(view, part)
- return cls.arrow_to_pandas(table, schema, temporal_objects_flag)
-
- @classmethod
- def view_to_polars(
- cls, view: DataView, part: DataPartKey, schema: tp.Optional[pa.Schema]):
-
- table = cls.view_to_arrow(view, part)
- return cls.arrow_to_polars(table, schema)
-
- @classmethod
- def pandas_to_item(cls, df: "pandas.DataFrame", schema: tp.Optional[pa.Schema]) -> DataItem:
-
- table = cls.pandas_to_arrow(df, schema)
- return DataItem(table.schema, table)
-
- @classmethod
- def polars_to_item(cls, df: "polars.DataFrame", schema: tp.Optional[pa.Schema]) -> DataItem:
-
- table = cls.polars_to_arrow(df, schema)
- return DataItem(table.schema, table)
-
  @classmethod
  def add_item_to_view(cls, view: DataView, part: DataPartKey, item: DataItem) -> DataView:

@@ -420,108 +333,306 @@ class DataMapping:

  @classmethod
  def arrow_to_pandas(
- cls, table: pa.Table, schema: tp.Optional[pa.Schema] = None,
+ cls, table: pa.Table,
+ schema: tp.Optional[pa.Schema] = None,
  temporal_objects_flag: bool = False) -> "pandas.DataFrame":

- if schema is not None:
- table = DataConformance.conform_to_schema(table, schema, warn_extra_columns=False)
- else:
- DataConformance.check_duplicate_fields(table.schema.names, False)
+ # This is a legacy internal method and should be removed
+ # DataMapping is no longer responsible for individual data APIs

- # Use Arrow's built-in function to convert to Pandas
- return table.to_pandas(
+ # Maintained temporarily for compatibility with existing deployments

- # Mapping for arrow -> pandas types for core types
- types_mapper=cls.__ARROW_TO_PANDAS_TYPE_MAPPING.get,
+ converter = PandasArrowConverter(_api.PANDAS, use_temporal_objects=temporal_objects_flag)
+ return converter.from_internal(table, schema)

- # Use Python objects for dates and times if temporal_objects_flag is set
- date_as_object=temporal_objects_flag, # noqa
- timestamp_as_object=temporal_objects_flag, # noqa
+ @classmethod
+ def pandas_to_arrow(
+ cls, df: "pandas.DataFrame",
+ schema: tp.Optional[pa.Schema] = None) -> pa.Table:

- # Do not bring any Arrow metadata into Pandas dataframe
- ignore_metadata=True, # noqa
+ # This is a legacy internal method and should be removed
+ # DataMapping is no longer responsible for individual data APIs

- # Do not consolidate memory across columns when preparing the Pandas vectors
- # This is a significant performance win for very wide datasets
- split_blocks=True) # noqa
+ # Maintained temporarily for compatibility with existing deployments

- @classmethod
- def arrow_to_polars(
- cls, table: pa.Table, schema: tp.Optional[pa.Schema] = None) -> "polars.DataFrame":
+ converter = PandasArrowConverter(_api.PANDAS)
+ return converter.to_internal(df, schema)

- if schema is not None:
- table = DataConformance.conform_to_schema(table, schema, warn_extra_columns=False)
- else:
- DataConformance.check_duplicate_fields(table.schema.names, False)

- return polars.from_arrow(table)

- @classmethod
- def pandas_to_arrow(cls, df: "pandas.DataFrame", schema: tp.Optional[pa.Schema] = None) -> pa.Table:
+ T_DATA_API = tp.TypeVar("T_DATA_API")
+ T_INTERNAL_DATA = tp.TypeVar("T_INTERNAL_DATA")
+ T_INTERNAL_SCHEMA = tp.TypeVar("T_INTERNAL_SCHEMA")

- # Converting pandas -> arrow needs care to ensure type coercion is applied correctly
- # Calling Table.from_pandas with the supplied schema will very often reject data
- # Instead, we convert the dataframe as-is and then apply type conversion in a second step
- # This allows us to apply specific coercion rules for each data type

- # As an optimisation, the column filter means columns will not be converted if they are not needed
- # E.g. if a model outputs lots of undeclared columns, there is no need to convert them
+ class DataConverter(tp.Generic[T_DATA_API, T_INTERNAL_DATA, T_INTERNAL_SCHEMA]):

- column_filter = DataConformance.column_filter(df.columns, schema) # noqa
+ # Available per-framework args, to enable framework-specific type-checking in public APIs
+ # These should (for a purist point of view) be in the individual converter classes
+ # For now there are only a few converters, they are all defined here so this is OK
+ __FRAMEWORK_ARGS = {
+ _api.PANDAS: {"use_temporal_objects": tp.Optional[bool]},
+ _api.POLARS: {}
+ }

- if len(df) > 0:
+ @classmethod
+ def get_framework(cls, dataset: _api.DATA_API) -> _api.DataFramework[_api.DATA_API]:

- table = pa.Table.from_pandas(df, columns=column_filter, preserve_index=False) # noqa
+ if pandas is not None and isinstance(dataset, pandas.DataFrame):
+ return _api.PANDAS

- # Special case handling for converting an empty dataframe
- # These must flow through the pipe with valid schemas, like any other dataset
- # Type coercion and column filtering happen in conform_to_schema, if a schema has been supplied
+ if polars is not None and isinstance(dataset, polars.DataFrame):
+ return _api.POLARS

- else:
+ data_api_type = f"{type(dataset).__module__}.{type(dataset).__name__}"
+ raise _ex.EPluginNotAvailable(f"No data framework available for type [{data_api_type}]")

- empty_df = df.filter(column_filter) if column_filter else df
- empty_schema = pa.Schema.from_pandas(empty_df, preserve_index=False) # noqa
+ @classmethod
+ def get_framework_args(cls, framework: _api.DataFramework[_api.DATA_API]) -> tp.Dict[str, type]:

- table = pa.Table.from_batches(list(), empty_schema) # noqa
+ return cls.__FRAMEWORK_ARGS.get(framework) or {}

- # If there is no explict schema, give back the table exactly as it was received from Pandas
- # There could be an option here to infer and coerce for TRAC standard types
- # E.g. unsigned int 32 -> signed int 64, TRAC standard integer type
+ @classmethod
+ def for_framework(cls, framework: _api.DataFramework[_api.DATA_API], **framework_args) -> "DataConverter[_api.DATA_API, pa.Table, pa.Schema]":

- if schema is None:
- DataConformance.check_duplicate_fields(table.schema.names, False)
- return table
+ if framework == _api.PANDAS:
+ if pandas is not None:
+ return PandasArrowConverter(framework, **framework_args)
+ else:
+ raise _ex.EPluginNotAvailable(f"Optional package [{framework}] is not installed")

- # If a schema has been supplied, apply data conformance
- # If column filtering has been applied, we also need to filter the pandas dtypes used for hinting
+ if framework == _api.POLARS:
+ if polars is not None:
+ return PolarsArrowConverter(framework)
+ else:
+ raise _ex.EPluginNotAvailable(f"Optional package [{framework}] is not installed")

- else:
- df_types = df.dtypes.filter(column_filter) if column_filter else df.dtypes
- return DataConformance.conform_to_schema(table, schema, df_types)
+ raise _ex.EPluginNotAvailable(f"Data framework [{framework}] is not recognized")

  @classmethod
- def pandas_to_arrow_schema(cls, df: "pandas.DataFrame") -> pa.Schema:
+ def for_dataset(cls, dataset: _api.DATA_API) -> "DataConverter[_api.DATA_API, pa.Table, pa.Schema]":

- return pa.Schema.from_pandas(df, preserve_index=False) # noqa
+ return cls.for_framework(cls.get_framework(dataset))

  @classmethod
- def polars_to_arrow(cls, df: "polars.DataFrame", schema: tp.Optional[pa.Schema] = None) -> pa.Table:
+ def noop(cls) -> "DataConverter[T_INTERNAL_DATA, T_INTERNAL_DATA, T_INTERNAL_SCHEMA]":
+ return NoopConverter()
+
+ def __init__(self, framework: _api.DataFramework[T_DATA_API]):
+ self.framework = framework
+
+ @abc.abstractmethod
+ def from_internal(self, dataset: T_INTERNAL_DATA, schema: tp.Optional[T_INTERNAL_SCHEMA] = None) -> T_DATA_API:
+ pass
+
+ @abc.abstractmethod
+ def to_internal(self, dataset: T_DATA_API, schema: tp.Optional[T_INTERNAL_SCHEMA] = None) -> T_INTERNAL_DATA:
+ pass
+
+ @abc.abstractmethod
+ def infer_schema(self, dataset: T_DATA_API) -> _meta.SchemaDefinition:
+ pass
+
+
+ class NoopConverter(DataConverter[T_INTERNAL_DATA, T_INTERNAL_DATA, T_INTERNAL_SCHEMA]):
+
+ def __init__(self):
+ super().__init__(_api.DataFramework("internal", None)) # noqa
+
+ def from_internal(self, dataset: T_INTERNAL_DATA, schema: tp.Optional[T_INTERNAL_SCHEMA] = None) -> T_DATA_API:
+ return dataset
+
+ def to_internal(self, dataset: T_DATA_API, schema: tp.Optional[T_INTERNAL_SCHEMA] = None) -> T_INTERNAL_DATA:
+ return dataset
+
+ def infer_schema(self, dataset: T_DATA_API) -> _meta.SchemaDefinition:
+ raise _ex.EUnexpected() # A real converter should be selected before use
+
+
+ # Data frameworks are optional, do not blow up the module just because one framework is unavailable!
+ if pandas is not None:

- column_filter = DataConformance.column_filter(df.columns, schema)
+ class PandasArrowConverter(DataConverter[pandas.DataFrame, pa.Table, pa.Schema]):

- filtered_df = df.select(polars.col(*column_filter)) if column_filter else df
- table = filtered_df.to_arrow()
+ # Check the Pandas dtypes for handling floats are available before setting up the type mapping
+ __PANDAS_VERSION_ELEMENTS = pandas.__version__.split(".")
+ __PANDAS_MAJOR_VERSION = int(__PANDAS_VERSION_ELEMENTS[0])
+ __PANDAS_MINOR_VERSION = int(__PANDAS_VERSION_ELEMENTS[1])
+
+ if __PANDAS_MAJOR_VERSION == 2:
+
+ __PANDAS_DATE_TYPE = pandas.to_datetime([dt.date(2000, 1, 1)]).as_unit(DataMapping.DEFAULT_TIMESTAMP_UNIT).dtype
+ __PANDAS_DATETIME_TYPE = pandas.to_datetime([dt.datetime(2000, 1, 1, 0, 0, 0)]).as_unit(DataMapping.DEFAULT_TIMESTAMP_UNIT).dtype
+
+ @classmethod
+ def __pandas_datetime_type(cls, tz, unit):
+ if tz is None and unit is None:
+ return cls.__PANDAS_DATETIME_TYPE
+ _unit = unit if unit is not None else DataMapping.DEFAULT_TIMESTAMP_UNIT
+ if tz is None:
+ return pandas.to_datetime([dt.datetime(2000, 1, 1, 0, 0, 0)]).as_unit(_unit).dtype
+ else:
+ return pandas.DatetimeTZDtype(tz=tz, unit=_unit)
+
+ # Minimum supported version for Pandas is 1.2, when pandas.Float64Dtype was introduced
+ elif __PANDAS_MAJOR_VERSION == 1 and __PANDAS_MINOR_VERSION >= 2:
+
+ __PANDAS_DATE_TYPE = pandas.to_datetime([dt.date(2000, 1, 1)]).dtype
+ __PANDAS_DATETIME_TYPE = pandas.to_datetime([dt.datetime(2000, 1, 1, 0, 0, 0)]).dtype
+
+ @classmethod
+ def __pandas_datetime_type(cls, tz, unit): # noqa
+ if tz is None:
+ return cls.__PANDAS_DATETIME_TYPE
+ else:
+ return pandas.DatetimeTZDtype(tz=tz)

- if schema is None:
- DataConformance.check_duplicate_fields(table.schema.names, False)
- return table
  else:
- return DataConformance.conform_to_schema(table, schema, None)
+ raise _ex.EStartup(f"Pandas version not supported: [{pandas.__version__}]")
+
+ # Only partial mapping is possible, decimal and temporal dtypes cannot be mapped this way
+ __ARROW_TO_PANDAS_TYPE_MAPPING = {
+ pa.bool_(): pandas.BooleanDtype(),
+ pa.int8(): pandas.Int8Dtype(),
+ pa.int16(): pandas.Int16Dtype(),
+ pa.int32(): pandas.Int32Dtype(),
+ pa.int64(): pandas.Int64Dtype(),
+ pa.uint8(): pandas.UInt8Dtype(),
+ pa.uint16(): pandas.UInt16Dtype(),
+ pa.uint32(): pandas.UInt32Dtype(),
+ pa.uint64(): pandas.UInt64Dtype(),
+ pa.float16(): pandas.Float32Dtype(),
+ pa.float32(): pandas.Float32Dtype(),
+ pa.float64(): pandas.Float64Dtype(),
+ pa.string(): pandas.StringDtype(),
+ pa.utf8(): pandas.StringDtype()
+ }
+
+ __DEFAULT_TEMPORAL_OBJECTS = False
+
+ # Expose date type for testing
+ @classmethod
+ def pandas_date_type(cls):
+ return cls.__PANDAS_DATE_TYPE

- @classmethod
- def polars_to_arrow_schema(cls, df: "polars.DataFrame") -> pa.Schema:
+ # Expose datetime type for testing
+ @classmethod
+ def pandas_datetime_type(cls, tz=None, unit=None):
+ return cls.__pandas_datetime_type(tz, unit)
+
+ def __init__(self, framework: _api.DataFramework[T_DATA_API], use_temporal_objects: tp.Optional[bool] = None):
+ super().__init__(framework)
+ if use_temporal_objects is None:
+ self.__temporal_objects_flag = self.__DEFAULT_TEMPORAL_OBJECTS
+ else:
+ self.__temporal_objects_flag = use_temporal_objects
+
+ def from_internal(self, table: pa.Table, schema: tp.Optional[pa.Schema] = None) -> pandas.DataFrame:
+
+ if schema is not None:
+ table = DataConformance.conform_to_schema(table, schema, warn_extra_columns=False)
+ else:
+ DataConformance.check_duplicate_fields(table.schema.names, False)
+
+ # Use Arrow's built-in function to convert to Pandas
+ return table.to_pandas(
+
+ # Mapping for arrow -> pandas types for core types
+ types_mapper=self.__ARROW_TO_PANDAS_TYPE_MAPPING.get,
+
+ # Use Python objects for dates and times if temporal_objects_flag is set
+ date_as_object=self.__temporal_objects_flag, # noqa
+ timestamp_as_object=self.__temporal_objects_flag, # noqa
+
+ # Do not bring any Arrow metadata into Pandas dataframe
+ ignore_metadata=True, # noqa
+
+ # Do not consolidate memory across columns when preparing the Pandas vectors
+ # This is a significant performance win for very wide datasets
+ split_blocks=True) # noqa
+
+ def to_internal(self, df: pandas.DataFrame, schema: tp.Optional[pa.Schema] = None) -> pa.Table:
+
+ # Converting pandas -> arrow needs care to ensure type coercion is applied correctly
+ # Calling Table.from_pandas with the supplied schema will very often reject data
+ # Instead, we convert the dataframe as-is and then apply type conversion in a second step
+ # This allows us to apply specific coercion rules for each data type
+
+ # As an optimisation, the column filter means columns will not be converted if they are not needed
+ # E.g. if a model outputs lots of undeclared columns, there is no need to convert them
+
+ column_filter = DataConformance.column_filter(df.columns, schema) # noqa
+
+ if len(df) > 0:
+
+ table = pa.Table.from_pandas(df, columns=column_filter, preserve_index=False) # noqa
+
+ # Special case handling for converting an empty dataframe
+ # These must flow through the pipe with valid schemas, like any other dataset
+ # Type coercion and column filtering happen in conform_to_schema, if a schema has been supplied
+
+ else:
+
+ empty_df = df.filter(column_filter) if column_filter else df
+ empty_schema = pa.Schema.from_pandas(empty_df, preserve_index=False) # noqa
+
+ table = pa.Table.from_batches(list(), empty_schema) # noqa
+
+ # If there is no explict schema, give back the table exactly as it was received from Pandas
+ # There could be an option here to infer and coerce for TRAC standard types
+ # E.g. unsigned int 32 -> signed int 64, TRAC standard integer type
+
+ if schema is None:
+ DataConformance.check_duplicate_fields(table.schema.names, False)
+ return table
+
+ # If a schema has been supplied, apply data conformance
+ # If column filtering has been applied, we also need to filter the pandas dtypes used for hinting
+
+ else:
+ df_types = df.dtypes.filter(column_filter) if column_filter else df.dtypes
+ return DataConformance.conform_to_schema(table, schema, df_types)
+
+ def infer_schema(self, dataset: pandas.DataFrame) -> _meta.SchemaDefinition:
+
+ arrow_schema = pa.Schema.from_pandas(dataset, preserve_index=False) # noqa
+ return DataMapping.arrow_to_trac_schema(arrow_schema)
+
+
+ # Data frameworks are optional, do not blow up the module just because one framework is unavailable!
+ if polars is not None:
+
+ class PolarsArrowConverter(DataConverter[polars.DataFrame, pa.Table, pa.Schema]):
+
+ def __init__(self, framework: _api.DataFramework[T_DATA_API]):
+ super().__init__(framework)
+
+ def from_internal(self, table: pa.Table, schema: tp.Optional[pa.Schema] = None) -> polars.DataFrame:
+
+ if schema is not None:
+ table = DataConformance.conform_to_schema(table, schema, warn_extra_columns=False)
+ else:
+ DataConformance.check_duplicate_fields(table.schema.names, False)
+
+ return polars.from_arrow(table)
+
+ def to_internal(self, df: polars.DataFrame, schema: tp.Optional[pa.Schema] = None,) -> pa.Table:
+
+ column_filter = DataConformance.column_filter(df.columns, schema)
+
+ filtered_df = df.select(polars.col(*column_filter)) if column_filter else df
+ table = filtered_df.to_arrow()
+
+ if schema is None:
+ DataConformance.check_duplicate_fields(table.schema.names, False)
+ return table
+ else:
+ return DataConformance.conform_to_schema(table, schema, None)
+
+ def infer_schema(self, dataset: T_DATA_API) -> _meta.SchemaDefinition:

- return df.top_k(1).to_arrow().schema
+ arrow_schema = dataset.top_k(1).to_arrow().schema
+ return DataMapping.arrow_to_trac_schema(arrow_schema)


  class DataConformance:
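This hunk replaces the pandas/polars specific methods on DataMapping with a converter hierarchy; DataMapping keeps thin legacy wrappers that delegate to PandasArrowConverter. A hedged usage sketch of the new API, based only on the methods added above and assuming pandas is installed and the internal module import path from the file header:

    import pandas as pd
    import tracdap.rt._impl.data as _data   # internal module, illustrative import

    df = pd.DataFrame({"id": [1, 2, 3], "value": [0.1, 0.2, 0.3]})

    converter = _data.DataConverter.for_dataset(df)   # picks PandasArrowConverter for a pandas.DataFrame
    table = converter.to_internal(df)                 # pandas -> Arrow table (conformance applied if a schema is passed)
    df_round_trip = converter.from_internal(table)    # Arrow table -> pandas
    trac_schema = converter.infer_schema(df)          # Arrow schema mapped to a TRAC SchemaDefinition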
@@ -652,7 +763,7 @@ class DataConformance:
  # Columns not defined in the schema will not be included in the conformed output
  if warn_extra_columns and table.num_columns > len(schema.types):

- schema_columns = set(map(str.lower, schema.names))
+ schema_columns = set(map(lambda c: c.lower(), schema.names))
  extra_columns = [
  f"[{col}]"
  for col in table.schema.names
@@ -0,0 +1,13 @@
+ # Copyright 2024 Accenture Global Solutions Limited
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.