pydiverse-common 0.3.9__py3-none-any.whl → 0.3.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydiverse/common/dtypes.py +109 -55
- pydiverse/common/util/hashing.py +6 -3
- {pydiverse_common-0.3.9.dist-info → pydiverse_common-0.3.10.dist-info}/METADATA +1 -1
- {pydiverse_common-0.3.9.dist-info → pydiverse_common-0.3.10.dist-info}/RECORD +6 -6
- {pydiverse_common-0.3.9.dist-info → pydiverse_common-0.3.10.dist-info}/WHEEL +0 -0
- {pydiverse_common-0.3.9.dist-info → pydiverse_common-0.3.10.dist-info}/licenses/LICENSE +0 -0
pydiverse/common/dtypes.py
CHANGED
@@ -49,20 +49,20 @@ class Dtype:
|
|
49
49
|
@staticmethod
|
50
50
|
def from_sql(sql_type) -> "Dtype":
|
51
51
|
"""Convert a SQL type to a Dtype."""
|
52
|
-
import sqlalchemy as
|
52
|
+
import sqlalchemy as sa
|
53
53
|
|
54
|
-
if isinstance(sql_type,
|
54
|
+
if isinstance(sql_type, sa.SmallInteger):
|
55
55
|
return Int16()
|
56
|
-
if isinstance(sql_type,
|
56
|
+
if isinstance(sql_type, sa.BigInteger):
|
57
57
|
return Int64()
|
58
|
-
if isinstance(sql_type,
|
58
|
+
if isinstance(sql_type, sa.Integer):
|
59
59
|
return Int32()
|
60
|
-
if isinstance(sql_type,
|
60
|
+
if isinstance(sql_type, sa.Float):
|
61
61
|
precision = sql_type.precision or 53
|
62
62
|
if precision <= 24:
|
63
63
|
return Float32()
|
64
64
|
return Float64()
|
65
|
-
if isinstance(sql_type,
|
65
|
+
if isinstance(sql_type, sa.Numeric | sa.DECIMAL):
|
66
66
|
# Just to be safe, we always use FLOAT64 for fixpoint numbers.
|
67
67
|
# Databases are obsessed about fixpoint. However, in dataframes, it
|
68
68
|
# is more common to just work with double precision floating point.
|
@@ -70,21 +70,21 @@ class Dtype:
|
|
70
70
|
# Decimal to Float64 whenever it cannot guarantee semantic correctness
|
71
71
|
# otherwise.
|
72
72
|
return Float64()
|
73
|
-
if isinstance(sql_type,
|
73
|
+
if isinstance(sql_type, sa.String):
|
74
74
|
return String()
|
75
|
-
if isinstance(sql_type,
|
75
|
+
if isinstance(sql_type, sa.Boolean):
|
76
76
|
return Bool()
|
77
|
-
if isinstance(sql_type,
|
77
|
+
if isinstance(sql_type, sa.Date):
|
78
78
|
return Date()
|
79
|
-
if isinstance(sql_type,
|
79
|
+
if isinstance(sql_type, sa.Time):
|
80
80
|
return Time()
|
81
|
-
if isinstance(sql_type,
|
81
|
+
if isinstance(sql_type, sa.DateTime):
|
82
82
|
return Datetime()
|
83
|
-
if isinstance(sql_type,
|
83
|
+
if isinstance(sql_type, sa.Interval):
|
84
84
|
return Duration()
|
85
|
-
if isinstance(sql_type,
|
85
|
+
if isinstance(sql_type, sa.ARRAY):
|
86
86
|
return List(Dtype.from_sql(sql_type.item_type))
|
87
|
-
if isinstance(sql_type,
|
87
|
+
if isinstance(sql_type, sa.types.NullType):
|
88
88
|
return NullType()
|
89
89
|
|
90
90
|
raise TypeError
|
@@ -184,7 +184,7 @@ class Dtype:
|
|
184
184
|
raise TypeError
|
185
185
|
if pa.types.is_decimal(arrow_type):
|
186
186
|
# We don't recommend using Decimal in dataframes, but we support it.
|
187
|
-
return Decimal()
|
187
|
+
return Decimal(arrow_type.precision, arrow_type.scale)
|
188
188
|
if pa.types.is_string(arrow_type):
|
189
189
|
return String()
|
190
190
|
if pa.types.is_boolean(arrow_type):
|
@@ -217,6 +217,8 @@ class Dtype:
|
|
217
217
|
return List(Dtype.from_polars(polars_type.inner))
|
218
218
|
if isinstance(polars_type, pl.Enum):
|
219
219
|
return Enum(*polars_type.categories)
|
220
|
+
if isinstance(polars_type, pl.Decimal):
|
221
|
+
return Decimal(polars_type.precision, polars_type.scale)
|
220
222
|
|
221
223
|
return {
|
222
224
|
pl.Int64: Int64(),
|
@@ -229,7 +231,6 @@ class Dtype:
|
|
229
231
|
pl.UInt8: UInt8(),
|
230
232
|
pl.Float64: Float64(),
|
231
233
|
pl.Float32: Float32(),
|
232
|
-
pl.Decimal: Decimal(),
|
233
234
|
pl.Utf8: String(),
|
234
235
|
pl.Boolean: Bool(),
|
235
236
|
pl.Datetime: Datetime(),
|
@@ -241,29 +242,28 @@ class Dtype:
|
|
241
242
|
|
242
243
|
def to_sql(self):
|
243
244
|
"""Convert this Dtype to a SQL type."""
|
244
|
-
import sqlalchemy as
|
245
|
+
import sqlalchemy as sa
|
245
246
|
|
246
247
|
return {
|
247
|
-
Int():
|
248
|
-
Int8():
|
249
|
-
Int16():
|
250
|
-
Int32():
|
251
|
-
Int64():
|
252
|
-
UInt8():
|
253
|
-
UInt16():
|
254
|
-
UInt32():
|
255
|
-
UInt64():
|
256
|
-
Float():
|
257
|
-
Float32():
|
258
|
-
Float64():
|
259
|
-
|
260
|
-
|
261
|
-
|
262
|
-
|
263
|
-
|
264
|
-
|
265
|
-
|
266
|
-
NullType(): sqa.types.NullType(),
|
248
|
+
Int(): sa.BigInteger(), # we default to 64 bit
|
249
|
+
Int8(): sa.SmallInteger(),
|
250
|
+
Int16(): sa.SmallInteger(),
|
251
|
+
Int32(): sa.Integer(),
|
252
|
+
Int64(): sa.BigInteger(),
|
253
|
+
UInt8(): sa.SmallInteger(),
|
254
|
+
UInt16(): sa.Integer(),
|
255
|
+
UInt32(): sa.BigInteger(),
|
256
|
+
UInt64(): sa.BigInteger(),
|
257
|
+
Float(): sa.Float(53), # we default to 64 bit
|
258
|
+
Float32(): sa.Float(24),
|
259
|
+
Float64(): sa.Float(53),
|
260
|
+
String(): sa.String(),
|
261
|
+
Bool(): sa.Boolean(),
|
262
|
+
Date(): sa.Date(),
|
263
|
+
Time(): sa.Time(),
|
264
|
+
Datetime(): sa.DateTime(),
|
265
|
+
Duration(): sa.Interval(),
|
266
|
+
NullType(): sa.types.NullType(),
|
267
267
|
}[self]
|
268
268
|
|
269
269
|
def to_pandas(self, backend: PandasBackend = PandasBackend.ARROW):
|
@@ -273,7 +273,7 @@ class Dtype:
|
|
273
273
|
if backend == PandasBackend.NUMPY:
|
274
274
|
return self.to_pandas_nullable(backend)
|
275
275
|
if backend == PandasBackend.ARROW:
|
276
|
-
if self
|
276
|
+
if isinstance(self, String) or isinstance(self, Enum):
|
277
277
|
return pd.StringDtype(storage="pyarrow")
|
278
278
|
return pd.ArrowDtype(self.to_arrow())
|
279
279
|
|
@@ -321,7 +321,7 @@ class Dtype:
|
|
321
321
|
Float(): pd.Float64Dtype(), # we default to 64 bit
|
322
322
|
Float32(): pd.Float32Dtype(),
|
323
323
|
Float64(): pd.Float64Dtype(),
|
324
|
-
Decimal(): pd.Float64Dtype(), # NumericDtype is
|
324
|
+
Decimal(): pd.Float64Dtype(), # NumericDtype exists but is not used
|
325
325
|
String(): pd.StringDtype(),
|
326
326
|
Bool(): pd.BooleanDtype(),
|
327
327
|
Date(): "datetime64[s]",
|
@@ -334,9 +334,6 @@ class Dtype:
|
|
334
334
|
"""Convert this Dtype to a PyArrow type."""
|
335
335
|
import pyarrow as pa
|
336
336
|
|
337
|
-
if isinstance(self, Enum):
|
338
|
-
return pa.string()
|
339
|
-
|
340
337
|
return {
|
341
338
|
Int(): pa.int64(), # we default to 64 bit
|
342
339
|
Int8(): pa.int8(),
|
@@ -350,7 +347,6 @@ class Dtype:
|
|
350
347
|
Float(): pa.float64(), # we default to 64 bit
|
351
348
|
Float32(): pa.float32(),
|
352
349
|
Float64(): pa.float64(),
|
353
|
-
Decimal(): pa.decimal128(35, 10), # Arbitrary precision
|
354
350
|
String(): pa.string(),
|
355
351
|
Bool(): pa.bool_(),
|
356
352
|
Date(): pa.date32(),
|
@@ -383,8 +379,6 @@ class Dtype:
|
|
383
379
|
Float(): pl.Float64, # we default to 64 bit
|
384
380
|
Float64(): pl.Float64,
|
385
381
|
Float32(): pl.Float32,
|
386
|
-
Decimal(): pl.Decimal(scale=10), # Arbitrary precision
|
387
|
-
String(): pl.Utf8,
|
388
382
|
Bool(): pl.Boolean,
|
389
383
|
Datetime(): pl.Datetime("us"),
|
390
384
|
Duration(): pl.Duration("us"),
|
@@ -406,7 +400,43 @@ class Float64(Float): ...
|
|
406
400
|
class Float32(Float): ...
|
407
401
|
|
408
402
|
|
409
|
-
class Decimal(Float):
|
403
|
+
class Decimal(Float):
|
404
|
+
def __init__(self, precision: int | None = None, scale: int | None = None):
|
405
|
+
"""
|
406
|
+
Initialize a Decimal Dtype.
|
407
|
+
|
408
|
+
Default is Decimal(31,10) which is the highest precision that works with DB2.
|
409
|
+
If you like to save memory, Decimal(15,6) will get you quite far as well.
|
410
|
+
|
411
|
+
:param precision: total number of digits in the number
|
412
|
+
If not specified, it is assumed to be 31.
|
413
|
+
:param scale: number of digits after the decimal point
|
414
|
+
If not specified, it is assumed to be (precision//3+1).
|
415
|
+
"""
|
416
|
+
self.precision = precision or 31
|
417
|
+
self.scale = scale or (self.precision // 3 + 1)
|
418
|
+
|
419
|
+
def to_sql(self):
|
420
|
+
import sqlalchemy as sa
|
421
|
+
|
422
|
+
return sa.Numeric(self.precision, self.scale)
|
423
|
+
|
424
|
+
def to_polars(self):
|
425
|
+
import polars as pl
|
426
|
+
|
427
|
+
return pl.Decimal(self.precision, self.scale)
|
428
|
+
|
429
|
+
def to_arrow(self):
|
430
|
+
import pyarrow as pa
|
431
|
+
|
432
|
+
if self.precision > 38:
|
433
|
+
return pa.decimal256(self.precision, self.scale)
|
434
|
+
elif self.precision > 18:
|
435
|
+
return pa.decimal128(self.precision, self.scale)
|
436
|
+
elif self.precision > 9:
|
437
|
+
return pa.decimal64(self.precision, self.scale)
|
438
|
+
else:
|
439
|
+
return pa.decimal32(self.precision, self.scale)
|
410
440
|
|
411
441
|
|
412
442
|
class Int(Dtype):
|
@@ -439,7 +469,32 @@ class UInt16(Int): ...
|
|
439
469
|
class UInt8(Int): ...
|
440
470
|
|
441
471
|
|
442
|
-
class String(Dtype):
|
472
|
+
class String(Dtype):
|
473
|
+
def __init__(self, max_length: int | None = None):
|
474
|
+
"""
|
475
|
+
Initialize a String Dtype.
|
476
|
+
|
477
|
+
:param max_length: maximum length of string
|
478
|
+
This length will only be used for specifying fixed length strings in SQL.
|
479
|
+
Thus, the meaning of characters vs. bytes is dependent on the SQL dialect.
|
480
|
+
"""
|
481
|
+
self.max_length = max_length
|
482
|
+
|
483
|
+
def to_sql(self):
|
484
|
+
"""Convert this Dtype to a SQL type."""
|
485
|
+
import sqlalchemy as sa
|
486
|
+
|
487
|
+
return sa.String(length=self.max_length)
|
488
|
+
|
489
|
+
def to_polars(self):
|
490
|
+
import polars as pl
|
491
|
+
|
492
|
+
return pl.Utf8
|
493
|
+
|
494
|
+
def to_arrow(self):
|
495
|
+
import pyarrow as pa
|
496
|
+
|
497
|
+
return pa.string()
|
443
498
|
|
444
499
|
|
445
500
|
class Bool(Dtype): ...
|
@@ -474,9 +529,9 @@ class List(Dtype):
|
|
474
529
|
return f"List[{repr(self.inner)}]"
|
475
530
|
|
476
531
|
def to_sql(self):
|
477
|
-
import sqlalchemy as
|
532
|
+
import sqlalchemy as sa
|
478
533
|
|
479
|
-
return
|
534
|
+
return sa.ARRAY(self.inner.to_sql())
|
480
535
|
|
481
536
|
def to_polars(self):
|
482
537
|
import polars as pl
|
@@ -494,6 +549,9 @@ class Enum(String):
|
|
494
549
|
if not all(isinstance(c, str) for c in categories):
|
495
550
|
raise TypeError("arguments for `Enum` must have type `str`")
|
496
551
|
self.categories = list(categories)
|
552
|
+
self.max_length = (
|
553
|
+
max([len(c) for c in categories]) if len(categories) > 0 else None
|
554
|
+
)
|
497
555
|
|
498
556
|
def __eq__(self, rhs):
|
499
557
|
return isinstance(rhs, Enum) and self.categories == rhs.categories
|
@@ -509,11 +567,6 @@ class Enum(String):
|
|
509
567
|
|
510
568
|
return pl.Enum(self.categories)
|
511
569
|
|
512
|
-
def to_sql(self):
|
513
|
-
import sqlalchemy as sqa
|
514
|
-
|
515
|
-
return sqa.String()
|
516
|
-
|
517
570
|
def to_arrow(self):
|
518
571
|
import pyarrow as pa
|
519
572
|
|
@@ -530,6 +583,7 @@ class Enum(String):
|
|
530
583
|
pa.dictionary(pa.uint32(), pa.large_string()),
|
531
584
|
nullable=nullable,
|
532
585
|
metadata={
|
533
|
-
|
586
|
+
# the key might change with polars versions
|
587
|
+
"_PL_ENUM_VALUES2": "".join([f"{len(c)};{c}" for c in self.categories])
|
534
588
|
},
|
535
589
|
)
|
pydiverse/common/util/hashing.py
CHANGED
@@ -52,11 +52,14 @@ def hash_polars_dataframe(df: pl.DataFrame, use_init_repr=False) -> str:
|
|
52
52
|
]:
|
53
53
|
df = df.with_columns(
|
54
54
|
pl.col(struct_col_name).struct.rename_fields(
|
55
|
-
[stable_hash(struct_col_name,
|
55
|
+
[stable_hash(struct_col_name, struct_dtype)]
|
56
56
|
)
|
57
|
-
for struct_col_name,
|
57
|
+
for struct_col_name, struct_dtype in struct_cols_and_dtypes
|
58
58
|
).unnest(
|
59
|
-
|
59
|
+
[
|
60
|
+
struct_col_name
|
61
|
+
for struct_col_name, _ in struct_cols_and_dtypes
|
62
|
+
]
|
60
63
|
)
|
61
64
|
return df
|
62
65
|
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: pydiverse-common
|
3
|
-
Version: 0.3.9
|
3
|
+
Version: 0.3.10
|
4
4
|
Summary: Common functionality shared between pydiverse libraries
|
5
5
|
Author: QuantCo, Inc.
|
6
6
|
Author-email: Martin Trautmann <windiana@users.sf.net>, Finn Rudolph <finn.rudolph@t-online.de>
|
@@ -1,5 +1,5 @@
|
|
1
1
|
pydiverse/common/__init__.py,sha256=J7b4iStFyaEMYre_jdlZ4l_8dLyrMWCIpQdsMQcB8aI,806
|
2
|
-
pydiverse/common/dtypes.py,sha256=
|
2
|
+
pydiverse/common/dtypes.py,sha256=ekeH6Ygnfy6sTsMVG4CcEdGHH_C6Z550tkdkaMQNIyY,18782
|
3
3
|
pydiverse/common/testing.py,sha256=FcivI5wn0X3gzJhwnysKvCOgjSTTXaN6FtSFJ72jfSg,341
|
4
4
|
pydiverse/common/version.py,sha256=1IU_m4r76_Qq0u-Tyo2_bERZFOkh0ZFueVzDqcCfLO0,336
|
5
5
|
pydiverse/common/errors/__init__.py,sha256=FNeEfVbUa23b9sHkFsmxHYhY6sRgjaZysPQmlovpJrI,262
|
@@ -8,10 +8,10 @@ pydiverse/common/util/computation_tracing.py,sha256=HeXRHRUI8vxpzQ27Xcpa0StndSTP
|
|
8
8
|
pydiverse/common/util/deep_map.py,sha256=JtY5ViWMMelOiLzPF7ZjzruCfB-bETISGxCk37qETxg,2540
|
9
9
|
pydiverse/common/util/deep_merge.py,sha256=bV5p5_lsC-9nFah28EiEyG2h6U3Z5AuTqSooxOgCHN0,1929
|
10
10
|
pydiverse/common/util/disposable.py,sha256=4XoGz70YRWA9TAqnUBvRCTAdsOGBviFN0gzxU7veY9o,993
|
11
|
-
pydiverse/common/util/hashing.py,sha256=
|
11
|
+
pydiverse/common/util/hashing.py,sha256=EofnKULVKXv-S9kry0mOqHU5bxPGomCtr6XfYqIhGgc,4650
|
12
12
|
pydiverse/common/util/import_.py,sha256=K7dSgz4YyrqEvqhoOzbwgD7D8HScMoO5XoSWtjbaoUs,4056
|
13
13
|
pydiverse/common/util/structlog.py,sha256=xxhauxMuyxcKXTVg1MiPTkuvPBj8Zcr4o_v8Bq59Nig,3778
|
14
|
-
pydiverse_common-0.3.9.dist-info/METADATA,sha256=
|
15
|
-
pydiverse_common-0.3.9.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
16
|
-
pydiverse_common-0.3.9.dist-info/licenses/LICENSE,sha256=AcE6SDVuAq6v9ZLE_8eOCe_NvSE0rAPR3NR7lSowYh4,1517
|
17
|
-
pydiverse_common-0.3.9.dist-info/RECORD,,
|
14
|
+
pydiverse_common-0.3.10.dist-info/METADATA,sha256=8xmLOqUJZbuETqz49ve_W8NNVkHGyaeOwT2exOhk0uY,3400
|
15
|
+
pydiverse_common-0.3.10.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
16
|
+
pydiverse_common-0.3.10.dist-info/licenses/LICENSE,sha256=AcE6SDVuAq6v9ZLE_8eOCe_NvSE0rAPR3NR7lSowYh4,1517
|
17
|
+
pydiverse_common-0.3.10.dist-info/RECORD,,
|
File without changes
|
File without changes
|