tracdap-runtime 0.6.5__py3-none-any.whl → 0.6.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tracdap/rt/_exec/context.py +272 -105
- tracdap/rt/_exec/dev_mode.py +231 -138
- tracdap/rt/_exec/engine.py +217 -59
- tracdap/rt/_exec/functions.py +25 -1
- tracdap/rt/_exec/graph.py +9 -0
- tracdap/rt/_exec/graph_builder.py +295 -198
- tracdap/rt/_exec/runtime.py +7 -5
- tracdap/rt/_impl/config_parser.py +11 -4
- tracdap/rt/_impl/data.py +278 -167
- tracdap/rt/_impl/ext/__init__.py +13 -0
- tracdap/rt/_impl/ext/sql.py +116 -0
- tracdap/rt/_impl/ext/storage.py +57 -0
- tracdap/rt/_impl/grpc/tracdap/metadata/job_pb2.py +62 -54
- tracdap/rt/_impl/grpc/tracdap/metadata/job_pb2.pyi +37 -2
- tracdap/rt/_impl/static_api.py +24 -11
- tracdap/rt/_impl/storage.py +2 -2
- tracdap/rt/_impl/util.py +10 -0
- tracdap/rt/_impl/validation.py +66 -13
- tracdap/rt/_plugins/storage_sql.py +417 -0
- tracdap/rt/_plugins/storage_sql_dialects.py +117 -0
- tracdap/rt/_version.py +1 -1
- tracdap/rt/api/experimental.py +79 -32
- tracdap/rt/api/hook.py +10 -0
- tracdap/rt/metadata/__init__.py +4 -0
- tracdap/rt/metadata/job.py +45 -0
- {tracdap_runtime-0.6.5.dist-info → tracdap_runtime-0.6.6.dist-info}/METADATA +3 -1
- {tracdap_runtime-0.6.5.dist-info → tracdap_runtime-0.6.6.dist-info}/RECORD +30 -25
- {tracdap_runtime-0.6.5.dist-info → tracdap_runtime-0.6.6.dist-info}/WHEEL +1 -1
- {tracdap_runtime-0.6.5.dist-info → tracdap_runtime-0.6.6.dist-info}/LICENSE +0 -0
- {tracdap_runtime-0.6.5.dist-info → tracdap_runtime-0.6.6.dist-info}/top_level.txt +0 -0
tracdap/rt/_exec/runtime.py
CHANGED
```diff
@@ -154,7 +154,6 @@ class TracRuntime:
         _plugins.PluginManager.register_plugin_package(plugin_package)

         _static_api.StaticApiImpl.register_impl()
-        _guard.PythonGuardRails.protect_dangerous_functions()

         # Load sys config (or use embedded), config errors are detected before start()
         # Job config can also be checked before start() by using load_job_config()
@@ -201,6 +200,11 @@ class TracRuntime:
         self._models = _models.ModelLoader(self._sys_config, self._scratch_dir)
         self._storage = _storage.StorageManager(self._sys_config)

+        # Enable protection after the initial setup of the runtime is complete
+        # Storage plugins in particular are likely to tigger protected imports
+        # Once the runtime is up, no more plugins should be loaded
+        _guard.PythonGuardRails.protect_dangerous_functions()
+
         self._engine = _engine.TracEngine(
             self._sys_config, self._models, self._storage,
             notify_callback=self._engine_callback)
@@ -329,10 +333,8 @@ class TracRuntime:
             config_file_name="job")

         if self._dev_mode:
-
-
-                self._scratch_dir, self._config_mgr,
-                model_class)
+            translator = _dev_mode.DevModeTranslator(self._sys_config, self._config_mgr, self._scratch_dir)
+            job_config = translator.translate_job_config(job_config, model_class)

         return job_config

```
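The last hunk above changes how dev-mode job configs are translated: `DevModeTranslator` is now constructed with the system config, config manager, and scratch directory, and translation happens through an instance method. A minimal sketch of the new call pattern, using only the names visible in the diff (the wrapper function and its arguments are illustrative, not part of the package):

```python
import tracdap.rt._exec.dev_mode as _dev_mode  # internal module, path as shown in the file list

def translate_for_dev_mode(sys_config, config_mgr, scratch_dir, job_config, model_class=None):
    # 0.6.6 pattern per the hunk above: build a translator instance, then translate the job config
    translator = _dev_mode.DevModeTranslator(sys_config, config_mgr, scratch_dir)
    return translator.translate_job_config(job_config, model_class)
```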
tracdap/rt/_impl/config_parser.py
CHANGED
```diff
@@ -341,10 +341,17 @@ class ConfigParser(tp.Generic[_T]):

            if isinstance(raw_value, tp.Dict):
                return self._parse_simple_class(location, raw_value, annotation)
-
-
-
-
+
+            if self._is_dev_mode_location(location):
+                if type(raw_value) in ConfigParser.__primitive_types:
+                    return self._parse_primitive(location, raw_value, type(raw_value))
+                if isinstance(raw_value, list):
+                    if len(raw_value) == 0:
+                        return []
+                    list_type = type(raw_value[0])
+                    return list(map(lambda x: self._parse_primitive(location, x, list_type), raw_value))
+
+            return self._error(location, f"Expected type {annotation.__name__}, got '{str(raw_value)}'")

         if isinstance(annotation, self.__generic_metaclass):
             return self._parse_generic_class(location, raw_value, annotation)  # noqa
```
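This hunk relaxes parsing at dev-mode config locations: a bare primitive, or a list of primitives, is now accepted where a structured value would normally be required. A standalone sketch of that rule (simplified and hypothetical; the real `ConfigParser` coerces each list element through `_parse_primitive` and reports problems via its own error handler):

```python
# Simplified, hypothetical re-implementation of the dev-mode rule added above
_PRIMITIVE_TYPES = (bool, int, float, str)

def parse_dev_mode_value(raw_value):
    if type(raw_value) in _PRIMITIVE_TYPES:
        return raw_value                      # primitives pass straight through
    if isinstance(raw_value, list):
        if len(raw_value) == 0:
            return []                         # empty lists are allowed
        list_type = type(raw_value[0])        # element type is taken from the first element
        return [list_type(x) for x in raw_value]
    raise ValueError(f"Unexpected dev-mode config value: {raw_value!r}")

print(parse_dev_mode_value(3))            # 3
print(parse_dev_mode_value(["a", "b"]))   # ['a', 'b']
```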
tracdap/rt/_impl/data.py
CHANGED
```diff
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import abc
 import dataclasses as dc
 import typing as tp
 import datetime as dt
@@ -31,6 +32,7 @@ try:
 except ModuleNotFoundError:
     polars = None

+import tracdap.rt.api.experimental as _api
 import tracdap.rt.metadata as _meta
 import tracdap.rt.exceptions as _ex
 import tracdap.rt._impl.util as _util
@@ -116,73 +118,19 @@ class DataMapping:

     # Matches TRAC_ARROW_TYPE_MAPPING in ArrowSchema, tracdap-lib-data

-
-
-
-
+    DEFAULT_DECIMAL_PRECISION = 38
+    DEFAULT_DECIMAL_SCALE = 12
+    DEFAULT_TIMESTAMP_UNIT = "ms"
+    DEFAULT_TIMESTAMP_ZONE = None

     __TRAC_TO_ARROW_BASIC_TYPE_MAPPING = {
         _meta.BasicType.BOOLEAN: pa.bool_(),
         _meta.BasicType.INTEGER: pa.int64(),
         _meta.BasicType.FLOAT: pa.float64(),
-        _meta.BasicType.DECIMAL: pa.decimal128(
+        _meta.BasicType.DECIMAL: pa.decimal128(DEFAULT_DECIMAL_PRECISION, DEFAULT_DECIMAL_SCALE),
         _meta.BasicType.STRING: pa.utf8(),
         _meta.BasicType.DATE: pa.date32(),
-        _meta.BasicType.DATETIME: pa.timestamp(
-    }
-
-    # Check the Pandas dtypes for handling floats are available before setting up the type mapping
-    __PANDAS_VERSION_ELEMENTS = pandas.__version__.split(".")
-    __PANDAS_MAJOR_VERSION = int(__PANDAS_VERSION_ELEMENTS[0])
-    __PANDAS_MINOR_VERSION = int(__PANDAS_VERSION_ELEMENTS[1])
-
-    if __PANDAS_MAJOR_VERSION == 2:
-
-        __PANDAS_DATE_TYPE = pandas.to_datetime([dt.date(2000, 1, 1)]).as_unit(__TRAC_TIMESTAMP_UNIT).dtype
-        __PANDAS_DATETIME_TYPE = pandas.to_datetime([dt.datetime(2000, 1, 1, 0, 0, 0)]).as_unit(__TRAC_TIMESTAMP_UNIT).dtype
-
-        @classmethod
-        def __pandas_datetime_type(cls, tz, unit):
-            if tz is None and unit is None:
-                return cls.__PANDAS_DATETIME_TYPE
-            _unit = unit if unit is not None else cls.__TRAC_TIMESTAMP_UNIT
-            if tz is None:
-                return pandas.to_datetime([dt.datetime(2000, 1, 1, 0, 0, 0)]).as_unit(_unit).dtype
-            else:
-                return pandas.DatetimeTZDtype(tz=tz, unit=_unit)
-
-    # Minimum supported version for Pandas is 1.2, when pandas.Float64Dtype was introduced
-    elif __PANDAS_MAJOR_VERSION == 1 and __PANDAS_MINOR_VERSION >= 2:
-
-        __PANDAS_DATE_TYPE = pandas.to_datetime([dt.date(2000, 1, 1)]).dtype
-        __PANDAS_DATETIME_TYPE = pandas.to_datetime([dt.datetime(2000, 1, 1, 0, 0, 0)]).dtype
-
-        @classmethod
-        def __pandas_datetime_type(cls, tz, unit):  # noqa
-            if tz is None:
-                return cls.__PANDAS_DATETIME_TYPE
-            else:
-                return pandas.DatetimeTZDtype(tz=tz)
-
-    else:
-        raise _ex.EStartup(f"Pandas version not supported: [{pandas.__version__}]")
-
-    # Only partial mapping is possible, decimal and temporal dtypes cannot be mapped this way
-    __ARROW_TO_PANDAS_TYPE_MAPPING = {
-        pa.bool_(): pandas.BooleanDtype(),
-        pa.int8(): pandas.Int8Dtype(),
-        pa.int16(): pandas.Int16Dtype(),
-        pa.int32(): pandas.Int32Dtype(),
-        pa.int64(): pandas.Int64Dtype(),
-        pa.uint8(): pandas.UInt8Dtype(),
-        pa.uint16(): pandas.UInt16Dtype(),
-        pa.uint32(): pandas.UInt32Dtype(),
-        pa.uint64(): pandas.UInt64Dtype(),
-        pa.float16(): pandas.Float32Dtype(),
-        pa.float32(): pandas.Float32Dtype(),
-        pa.float64(): pandas.Float64Dtype(),
-        pa.string(): pandas.StringDtype(),
-        pa.utf8(): pandas.StringDtype()
+        _meta.BasicType.DATETIME: pa.timestamp(DEFAULT_TIMESTAMP_UNIT, DEFAULT_TIMESTAMP_ZONE)
     }

     __ARROW_TO_TRAC_BASIC_TYPE_MAPPING = {
```
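The private TRAC type constants become public `DEFAULT_*` attributes on `DataMapping`, with the values shown in the hunk above. For reference, the Arrow types they produce can be built directly (a small illustration, assuming only that `pyarrow` is installed):

```python
import pyarrow as pa

# Values copied from the hunk above: precision 38, scale 12, millisecond timestamps, no time zone
decimal_type = pa.decimal128(38, 12)
datetime_type = pa.timestamp("ms", tz=None)

print(decimal_type)    # decimal128(38, 12)
print(datetime_type)   # timestamp[ms]
```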
```diff
@@ -243,7 +191,7 @@ class DataMapping:
             return pa.float64()

         if python_type == decimal.Decimal:
-            return pa.decimal128(cls.
+            return pa.decimal128(cls.DEFAULT_DECIMAL_PRECISION, cls.DEFAULT_DECIMAL_SCALE)

         if python_type == str:
             return pa.utf8()
@@ -252,7 +200,7 @@ class DataMapping:
             return pa.date32()

         if python_type == dt.datetime:
-            return pa.timestamp(cls.
+            return pa.timestamp(cls.DEFAULT_TIMESTAMP_UNIT, cls.DEFAULT_TIMESTAMP_ZONE)

         raise _ex.ETracInternal(f"No Arrow type mapping available for Python type [{python_type}]")

@@ -293,8 +241,8 @@ class DataMapping:
     def trac_arrow_decimal_type(cls) -> pa.Decimal128Type:

         return pa.decimal128(
-            cls.
-            cls.
+            cls.DEFAULT_DECIMAL_PRECISION,
+            cls.DEFAULT_DECIMAL_SCALE,)

     @classmethod
     def arrow_to_trac_schema(cls, arrow_schema: pa.Schema) -> _meta.SchemaDefinition:
@@ -337,41 +285,6 @@ class DataMapping:

         raise _ex.ETracInternal(f"No data type mapping available for Arrow type [{arrow_type}]")

-    @classmethod
-    def pandas_date_type(cls):
-        return cls.__PANDAS_DATE_TYPE
-
-    @classmethod
-    def pandas_datetime_type(cls, tz=None, unit=None):
-        return cls.__pandas_datetime_type(tz, unit)
-
-    @classmethod
-    def view_to_pandas(
-            cls, view: DataView, part: DataPartKey, schema: tp.Optional[pa.Schema],
-            temporal_objects_flag: bool) -> "pandas.DataFrame":
-
-        table = cls.view_to_arrow(view, part)
-        return cls.arrow_to_pandas(table, schema, temporal_objects_flag)
-
-    @classmethod
-    def view_to_polars(
-            cls, view: DataView, part: DataPartKey, schema: tp.Optional[pa.Schema]):
-
-        table = cls.view_to_arrow(view, part)
-        return cls.arrow_to_polars(table, schema)
-
-    @classmethod
-    def pandas_to_item(cls, df: "pandas.DataFrame", schema: tp.Optional[pa.Schema]) -> DataItem:
-
-        table = cls.pandas_to_arrow(df, schema)
-        return DataItem(table.schema, table)
-
-    @classmethod
-    def polars_to_item(cls, df: "polars.DataFrame", schema: tp.Optional[pa.Schema]) -> DataItem:
-
-        table = cls.polars_to_arrow(df, schema)
-        return DataItem(table.schema, table)
-
     @classmethod
     def add_item_to_view(cls, view: DataView, part: DataPartKey, item: DataItem) -> DataView:

@@ -420,108 +333,306 @@ class DataMapping:

     @classmethod
     def arrow_to_pandas(
-            cls, table: pa.Table,
+            cls, table: pa.Table,
+            schema: tp.Optional[pa.Schema] = None,
             temporal_objects_flag: bool = False) -> "pandas.DataFrame":

-
-
-        else:
-            DataConformance.check_duplicate_fields(table.schema.names, False)
+        # This is a legacy internal method and should be removed
+        # DataMapping is no longer responsible for individual data APIs

-        #
-        return table.to_pandas(
+        # Maintained temporarily for compatibility with existing deployments

-
-
+        converter = PandasArrowConverter(_api.PANDAS, use_temporal_objects=temporal_objects_flag)
+        return converter.from_internal(table, schema)

-
-
-
+    @classmethod
+    def pandas_to_arrow(
+            cls, df: "pandas.DataFrame",
+            schema: tp.Optional[pa.Schema] = None) -> pa.Table:

-
-
+        # This is a legacy internal method and should be removed
+        # DataMapping is no longer responsible for individual data APIs

-
-            # This is a significant performance win for very wide datasets
-            split_blocks=True)  # noqa
+        # Maintained temporarily for compatibility with existing deployments

-
-
-            cls, table: pa.Table, schema: tp.Optional[pa.Schema] = None) -> "polars.DataFrame":
+        converter = PandasArrowConverter(_api.PANDAS)
+        return converter.to_internal(df, schema)

-        if schema is not None:
-            table = DataConformance.conform_to_schema(table, schema, warn_extra_columns=False)
-        else:
-            DataConformance.check_duplicate_fields(table.schema.names, False)

-        return polars.from_arrow(table)

-
-
+T_DATA_API = tp.TypeVar("T_DATA_API")
+T_INTERNAL_DATA = tp.TypeVar("T_INTERNAL_DATA")
+T_INTERNAL_SCHEMA = tp.TypeVar("T_INTERNAL_SCHEMA")

-        # Converting pandas -> arrow needs care to ensure type coercion is applied correctly
-        # Calling Table.from_pandas with the supplied schema will very often reject data
-        # Instead, we convert the dataframe as-is and then apply type conversion in a second step
-        # This allows us to apply specific coercion rules for each data type

-
-        # E.g. if a model outputs lots of undeclared columns, there is no need to convert them
+class DataConverter(tp.Generic[T_DATA_API, T_INTERNAL_DATA, T_INTERNAL_SCHEMA]):

-
+    # Available per-framework args, to enable framework-specific type-checking in public APIs
+    # These should (for a purist point of view) be in the individual converter classes
+    # For now there are only a few converters, they are all defined here so this is OK
+    __FRAMEWORK_ARGS = {
+        _api.PANDAS: {"use_temporal_objects": tp.Optional[bool]},
+        _api.POLARS: {}
+    }

-
+    @classmethod
+    def get_framework(cls, dataset: _api.DATA_API) -> _api.DataFramework[_api.DATA_API]:

-
+        if pandas is not None and isinstance(dataset, pandas.DataFrame):
+            return _api.PANDAS

-
-
-        # Type coercion and column filtering happen in conform_to_schema, if a schema has been supplied
+        if polars is not None and isinstance(dataset, polars.DataFrame):
+            return _api.POLARS

-
+        data_api_type = f"{type(dataset).__module__}.{type(dataset).__name__}"
+        raise _ex.EPluginNotAvailable(f"No data framework available for type [{data_api_type}]")

-
-
+    @classmethod
+    def get_framework_args(cls, framework: _api.DataFramework[_api.DATA_API]) -> tp.Dict[str, type]:

-
+        return cls.__FRAMEWORK_ARGS.get(framework) or {}

-
-
-        # E.g. unsigned int 32 -> signed int 64, TRAC standard integer type
+    @classmethod
+    def for_framework(cls, framework: _api.DataFramework[_api.DATA_API], **framework_args) -> "DataConverter[_api.DATA_API, pa.Table, pa.Schema]":

-        if
-
-
+        if framework == _api.PANDAS:
+            if pandas is not None:
+                return PandasArrowConverter(framework, **framework_args)
+            else:
+                raise _ex.EPluginNotAvailable(f"Optional package [{framework}] is not installed")

-
-
+        if framework == _api.POLARS:
+            if polars is not None:
+                return PolarsArrowConverter(framework)
+            else:
+                raise _ex.EPluginNotAvailable(f"Optional package [{framework}] is not installed")

-
-            df_types = df.dtypes.filter(column_filter) if column_filter else df.dtypes
-            return DataConformance.conform_to_schema(table, schema, df_types)
+        raise _ex.EPluginNotAvailable(f"Data framework [{framework}] is not recognized")

     @classmethod
-    def
+    def for_dataset(cls, dataset: _api.DATA_API) -> "DataConverter[_api.DATA_API, pa.Table, pa.Schema]":

-        return
+        return cls.for_framework(cls.get_framework(dataset))

     @classmethod
-    def
+    def noop(cls) -> "DataConverter[T_INTERNAL_DATA, T_INTERNAL_DATA, T_INTERNAL_SCHEMA]":
+        return NoopConverter()
+
+    def __init__(self, framework: _api.DataFramework[T_DATA_API]):
+        self.framework = framework
+
+    @abc.abstractmethod
+    def from_internal(self, dataset: T_INTERNAL_DATA, schema: tp.Optional[T_INTERNAL_SCHEMA] = None) -> T_DATA_API:
+        pass
+
+    @abc.abstractmethod
+    def to_internal(self, dataset: T_DATA_API, schema: tp.Optional[T_INTERNAL_SCHEMA] = None) -> T_INTERNAL_DATA:
+        pass
+
+    @abc.abstractmethod
+    def infer_schema(self, dataset: T_DATA_API) -> _meta.SchemaDefinition:
+        pass
+
+
+class NoopConverter(DataConverter[T_INTERNAL_DATA, T_INTERNAL_DATA, T_INTERNAL_SCHEMA]):
+
+    def __init__(self):
+        super().__init__(_api.DataFramework("internal", None))  # noqa
+
+    def from_internal(self, dataset: T_INTERNAL_DATA, schema: tp.Optional[T_INTERNAL_SCHEMA] = None) -> T_DATA_API:
+        return dataset
+
+    def to_internal(self, dataset: T_DATA_API, schema: tp.Optional[T_INTERNAL_SCHEMA] = None) -> T_INTERNAL_DATA:
+        return dataset
+
+    def infer_schema(self, dataset: T_DATA_API) -> _meta.SchemaDefinition:
+        raise _ex.EUnexpected()  # A real converter should be selected before use
+
+
+# Data frameworks are optional, do not blow up the module just because one framework is unavailable!
+if pandas is not None:

-
+    class PandasArrowConverter(DataConverter[pandas.DataFrame, pa.Table, pa.Schema]):

-
-
+        # Check the Pandas dtypes for handling floats are available before setting up the type mapping
+        __PANDAS_VERSION_ELEMENTS = pandas.__version__.split(".")
+        __PANDAS_MAJOR_VERSION = int(__PANDAS_VERSION_ELEMENTS[0])
+        __PANDAS_MINOR_VERSION = int(__PANDAS_VERSION_ELEMENTS[1])
+
+        if __PANDAS_MAJOR_VERSION == 2:
+
+            __PANDAS_DATE_TYPE = pandas.to_datetime([dt.date(2000, 1, 1)]).as_unit(DataMapping.DEFAULT_TIMESTAMP_UNIT).dtype
+            __PANDAS_DATETIME_TYPE = pandas.to_datetime([dt.datetime(2000, 1, 1, 0, 0, 0)]).as_unit(DataMapping.DEFAULT_TIMESTAMP_UNIT).dtype
+
+            @classmethod
+            def __pandas_datetime_type(cls, tz, unit):
+                if tz is None and unit is None:
+                    return cls.__PANDAS_DATETIME_TYPE
+                _unit = unit if unit is not None else DataMapping.DEFAULT_TIMESTAMP_UNIT
+                if tz is None:
+                    return pandas.to_datetime([dt.datetime(2000, 1, 1, 0, 0, 0)]).as_unit(_unit).dtype
+                else:
+                    return pandas.DatetimeTZDtype(tz=tz, unit=_unit)
+
+        # Minimum supported version for Pandas is 1.2, when pandas.Float64Dtype was introduced
+        elif __PANDAS_MAJOR_VERSION == 1 and __PANDAS_MINOR_VERSION >= 2:
+
+            __PANDAS_DATE_TYPE = pandas.to_datetime([dt.date(2000, 1, 1)]).dtype
+            __PANDAS_DATETIME_TYPE = pandas.to_datetime([dt.datetime(2000, 1, 1, 0, 0, 0)]).dtype
+
+            @classmethod
+            def __pandas_datetime_type(cls, tz, unit):  # noqa
+                if tz is None:
+                    return cls.__PANDAS_DATETIME_TYPE
+                else:
+                    return pandas.DatetimeTZDtype(tz=tz)

-        if schema is None:
-            DataConformance.check_duplicate_fields(table.schema.names, False)
-            return table
         else:
-
+            raise _ex.EStartup(f"Pandas version not supported: [{pandas.__version__}]")
+
+        # Only partial mapping is possible, decimal and temporal dtypes cannot be mapped this way
+        __ARROW_TO_PANDAS_TYPE_MAPPING = {
+            pa.bool_(): pandas.BooleanDtype(),
+            pa.int8(): pandas.Int8Dtype(),
+            pa.int16(): pandas.Int16Dtype(),
+            pa.int32(): pandas.Int32Dtype(),
+            pa.int64(): pandas.Int64Dtype(),
+            pa.uint8(): pandas.UInt8Dtype(),
+            pa.uint16(): pandas.UInt16Dtype(),
+            pa.uint32(): pandas.UInt32Dtype(),
+            pa.uint64(): pandas.UInt64Dtype(),
+            pa.float16(): pandas.Float32Dtype(),
+            pa.float32(): pandas.Float32Dtype(),
+            pa.float64(): pandas.Float64Dtype(),
+            pa.string(): pandas.StringDtype(),
+            pa.utf8(): pandas.StringDtype()
+        }
+
+        __DEFAULT_TEMPORAL_OBJECTS = False
+
+        # Expose date type for testing
+        @classmethod
+        def pandas_date_type(cls):
+            return cls.__PANDAS_DATE_TYPE

-
-
+        # Expose datetime type for testing
+        @classmethod
+        def pandas_datetime_type(cls, tz=None, unit=None):
+            return cls.__pandas_datetime_type(tz, unit)
+
+        def __init__(self, framework: _api.DataFramework[T_DATA_API], use_temporal_objects: tp.Optional[bool] = None):
+            super().__init__(framework)
+            if use_temporal_objects is None:
+                self.__temporal_objects_flag = self.__DEFAULT_TEMPORAL_OBJECTS
+            else:
+                self.__temporal_objects_flag = use_temporal_objects
+
+        def from_internal(self, table: pa.Table, schema: tp.Optional[pa.Schema] = None) -> pandas.DataFrame:
+
+            if schema is not None:
+                table = DataConformance.conform_to_schema(table, schema, warn_extra_columns=False)
+            else:
+                DataConformance.check_duplicate_fields(table.schema.names, False)
+
+            # Use Arrow's built-in function to convert to Pandas
+            return table.to_pandas(
+
+                # Mapping for arrow -> pandas types for core types
+                types_mapper=self.__ARROW_TO_PANDAS_TYPE_MAPPING.get,
+
+                # Use Python objects for dates and times if temporal_objects_flag is set
+                date_as_object=self.__temporal_objects_flag,  # noqa
+                timestamp_as_object=self.__temporal_objects_flag,  # noqa
+
+                # Do not bring any Arrow metadata into Pandas dataframe
+                ignore_metadata=True,  # noqa
+
+                # Do not consolidate memory across columns when preparing the Pandas vectors
+                # This is a significant performance win for very wide datasets
+                split_blocks=True)  # noqa
+
+        def to_internal(self, df: pandas.DataFrame, schema: tp.Optional[pa.Schema] = None) -> pa.Table:
+
+            # Converting pandas -> arrow needs care to ensure type coercion is applied correctly
+            # Calling Table.from_pandas with the supplied schema will very often reject data
+            # Instead, we convert the dataframe as-is and then apply type conversion in a second step
+            # This allows us to apply specific coercion rules for each data type
+
+            # As an optimisation, the column filter means columns will not be converted if they are not needed
+            # E.g. if a model outputs lots of undeclared columns, there is no need to convert them
+
+            column_filter = DataConformance.column_filter(df.columns, schema)  # noqa
+
+            if len(df) > 0:
+
+                table = pa.Table.from_pandas(df, columns=column_filter, preserve_index=False)  # noqa
+
+            # Special case handling for converting an empty dataframe
+            # These must flow through the pipe with valid schemas, like any other dataset
+            # Type coercion and column filtering happen in conform_to_schema, if a schema has been supplied
+
+            else:
+
+                empty_df = df.filter(column_filter) if column_filter else df
+                empty_schema = pa.Schema.from_pandas(empty_df, preserve_index=False)  # noqa
+
+                table = pa.Table.from_batches(list(), empty_schema)  # noqa
+
+            # If there is no explict schema, give back the table exactly as it was received from Pandas
+            # There could be an option here to infer and coerce for TRAC standard types
+            # E.g. unsigned int 32 -> signed int 64, TRAC standard integer type
+
+            if schema is None:
+                DataConformance.check_duplicate_fields(table.schema.names, False)
+                return table
+
+            # If a schema has been supplied, apply data conformance
+            # If column filtering has been applied, we also need to filter the pandas dtypes used for hinting
+
+            else:
+                df_types = df.dtypes.filter(column_filter) if column_filter else df.dtypes
+                return DataConformance.conform_to_schema(table, schema, df_types)
+
+        def infer_schema(self, dataset: pandas.DataFrame) -> _meta.SchemaDefinition:
+
+            arrow_schema = pa.Schema.from_pandas(dataset, preserve_index=False)  # noqa
+            return DataMapping.arrow_to_trac_schema(arrow_schema)
+
+
+# Data frameworks are optional, do not blow up the module just because one framework is unavailable!
+if polars is not None:
+
+    class PolarsArrowConverter(DataConverter[polars.DataFrame, pa.Table, pa.Schema]):
+
+        def __init__(self, framework: _api.DataFramework[T_DATA_API]):
+            super().__init__(framework)
+
+        def from_internal(self, table: pa.Table, schema: tp.Optional[pa.Schema] = None) -> polars.DataFrame:
+
+            if schema is not None:
+                table = DataConformance.conform_to_schema(table, schema, warn_extra_columns=False)
+            else:
+                DataConformance.check_duplicate_fields(table.schema.names, False)
+
+            return polars.from_arrow(table)
+
+        def to_internal(self, df: polars.DataFrame, schema: tp.Optional[pa.Schema] = None,) -> pa.Table:
+
+            column_filter = DataConformance.column_filter(df.columns, schema)
+
+            filtered_df = df.select(polars.col(*column_filter)) if column_filter else df
+            table = filtered_df.to_arrow()
+
+            if schema is None:
+                DataConformance.check_duplicate_fields(table.schema.names, False)
+                return table
+            else:
+                return DataConformance.conform_to_schema(table, schema, None)
+
+        def infer_schema(self, dataset: T_DATA_API) -> _meta.SchemaDefinition:

-
+            arrow_schema = dataset.top_k(1).to_arrow().schema
+            return DataMapping.arrow_to_trac_schema(arrow_schema)


 class DataConformance:
```
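The large hunk above moves per-framework conversion out of `DataMapping` and into a `DataConverter` hierarchy, with `PandasArrowConverter` and `PolarsArrowConverter` selected by `for_dataset` / `for_framework`. A hedged sketch of how that internal API fits together, based only on the methods visible in the diff (this is an internal module, not a public API; the sample dataframe is made up):

```python
import pandas as pd
import tracdap.rt._impl.data as _data  # internal module, path as shown in the file list

df = pd.DataFrame({"id": [1, 2, 3], "value": [0.1, 0.2, 0.3]})

converter = _data.DataConverter.for_dataset(df)   # resolves to PandasArrowConverter for a pandas DataFrame
schema = converter.infer_schema(df)               # TRAC SchemaDefinition, inferred via the Arrow schema
table = converter.to_internal(df)                 # pandas -> pyarrow.Table
round_trip = converter.from_internal(table)       # pyarrow.Table -> pandas
```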
```diff
@@ -652,7 +763,7 @@ class DataConformance:
         # Columns not defined in the schema will not be included in the conformed output
         if warn_extra_columns and table.num_columns > len(schema.types):

-            schema_columns = set(map(
+            schema_columns = set(map(lambda c: c.lower(), schema.names))
             extra_columns = [
                 f"[{col}]"
                 for col in table.schema.names
```
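The one-line fix above lower-cases the schema column names before looking for extra columns, making the warning check case-insensitive. A small standalone illustration of the comparison (the filter condition on the table columns is an assumption based on the surrounding context lines, which the hunk truncates):

```python
# Hypothetical, simplified version of the extra-column check after the fix
schema_names = ["ID", "Value"]
table_names = ["id", "value", "extra_col"]

schema_columns = set(map(lambda c: c.lower(), schema_names))
extra_columns = [f"[{col}]" for col in table_names if col.lower() not in schema_columns]

print(extra_columns)  # ['[extra_col]']
```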
tracdap/rt/_impl/ext/__init__.py
ADDED
```diff
@@ -0,0 +1,13 @@
+# Copyright 2024 Accenture Global Solutions Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
```