pydiverse-common 0.3.8__py3-none-any.whl → 0.3.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydiverse/common/dtypes.py +134 -56
- pydiverse/common/util/hashing.py +6 -3
- {pydiverse_common-0.3.8.dist-info → pydiverse_common-0.3.10.dist-info}/METADATA +1 -1
- {pydiverse_common-0.3.8.dist-info → pydiverse_common-0.3.10.dist-info}/RECORD +6 -6
- {pydiverse_common-0.3.8.dist-info → pydiverse_common-0.3.10.dist-info}/WHEEL +0 -0
- {pydiverse_common-0.3.8.dist-info → pydiverse_common-0.3.10.dist-info}/licenses/LICENSE +0 -0
pydiverse/common/dtypes.py
CHANGED
@@ -49,20 +49,20 @@ class Dtype:
|
|
49
49
|
@staticmethod
|
50
50
|
def from_sql(sql_type) -> "Dtype":
|
51
51
|
"""Convert a SQL type to a Dtype."""
|
52
|
-
import sqlalchemy as
|
52
|
+
import sqlalchemy as sa
|
53
53
|
|
54
|
-
if isinstance(sql_type,
|
54
|
+
if isinstance(sql_type, sa.SmallInteger):
|
55
55
|
return Int16()
|
56
|
-
if isinstance(sql_type,
|
56
|
+
if isinstance(sql_type, sa.BigInteger):
|
57
57
|
return Int64()
|
58
|
-
if isinstance(sql_type,
|
58
|
+
if isinstance(sql_type, sa.Integer):
|
59
59
|
return Int32()
|
60
|
-
if isinstance(sql_type,
|
60
|
+
if isinstance(sql_type, sa.Float):
|
61
61
|
precision = sql_type.precision or 53
|
62
62
|
if precision <= 24:
|
63
63
|
return Float32()
|
64
64
|
return Float64()
|
65
|
-
if isinstance(sql_type,
|
65
|
+
if isinstance(sql_type, sa.Numeric | sa.DECIMAL):
|
66
66
|
# Just to be safe, we always use FLOAT64 for fixpoint numbers.
|
67
67
|
# Databases are obsessed about fixpoint. However, in dataframes, it
|
68
68
|
# is more common to just work with double precision floating point.
|
@@ -70,21 +70,21 @@ class Dtype:
|
|
70
70
|
# Decimal to Float64 whenever it cannot guarantee semantic correctness
|
71
71
|
# otherwise.
|
72
72
|
return Float64()
|
73
|
-
if isinstance(sql_type,
|
73
|
+
if isinstance(sql_type, sa.String):
|
74
74
|
return String()
|
75
|
-
if isinstance(sql_type,
|
75
|
+
if isinstance(sql_type, sa.Boolean):
|
76
76
|
return Bool()
|
77
|
-
if isinstance(sql_type,
|
77
|
+
if isinstance(sql_type, sa.Date):
|
78
78
|
return Date()
|
79
|
-
if isinstance(sql_type,
|
79
|
+
if isinstance(sql_type, sa.Time):
|
80
80
|
return Time()
|
81
|
-
if isinstance(sql_type,
|
81
|
+
if isinstance(sql_type, sa.DateTime):
|
82
82
|
return Datetime()
|
83
|
-
if isinstance(sql_type,
|
83
|
+
if isinstance(sql_type, sa.Interval):
|
84
84
|
return Duration()
|
85
|
-
if isinstance(sql_type,
|
85
|
+
if isinstance(sql_type, sa.ARRAY):
|
86
86
|
return List(Dtype.from_sql(sql_type.item_type))
|
87
|
-
if isinstance(sql_type,
|
87
|
+
if isinstance(sql_type, sa.types.NullType):
|
88
88
|
return NullType()
|
89
89
|
|
90
90
|
raise TypeError
|
@@ -184,7 +184,7 @@ class Dtype:
|
|
184
184
|
raise TypeError
|
185
185
|
if pa.types.is_decimal(arrow_type):
|
186
186
|
# We don't recommend using Decimal in dataframes, but we support it.
|
187
|
-
return Decimal()
|
187
|
+
return Decimal(arrow_type.precision, arrow_type.scale)
|
188
188
|
if pa.types.is_string(arrow_type):
|
189
189
|
return String()
|
190
190
|
if pa.types.is_boolean(arrow_type):
|
@@ -201,6 +201,11 @@ class Dtype:
|
|
201
201
|
return NullType()
|
202
202
|
if pa.types.is_list(arrow_type):
|
203
203
|
return List(Dtype.from_arrow(arrow_type.value_type))
|
204
|
+
if pa.types.is_dictionary(arrow_type):
|
205
|
+
raise RuntimeError(
|
206
|
+
"Most likely this is an Enum type. But metadata about categories is "
|
207
|
+
"only in the pyarrow field and not in the pyarrow dtype"
|
208
|
+
)
|
204
209
|
raise TypeError
|
205
210
|
|
206
211
|
@staticmethod
|
@@ -212,6 +217,8 @@ class Dtype:
|
|
212
217
|
return List(Dtype.from_polars(polars_type.inner))
|
213
218
|
if isinstance(polars_type, pl.Enum):
|
214
219
|
return Enum(*polars_type.categories)
|
220
|
+
if isinstance(polars_type, pl.Decimal):
|
221
|
+
return Decimal(polars_type.precision, polars_type.scale)
|
215
222
|
|
216
223
|
return {
|
217
224
|
pl.Int64: Int64(),
|
@@ -224,7 +231,6 @@ class Dtype:
|
|
224
231
|
pl.UInt8: UInt8(),
|
225
232
|
pl.Float64: Float64(),
|
226
233
|
pl.Float32: Float32(),
|
227
|
-
pl.Decimal: Decimal(),
|
228
234
|
pl.Utf8: String(),
|
229
235
|
pl.Boolean: Bool(),
|
230
236
|
pl.Datetime: Datetime(),
|
@@ -236,29 +242,28 @@ class Dtype:
|
|
236
242
|
|
237
243
|
def to_sql(self):
|
238
244
|
"""Convert this Dtype to a SQL type."""
|
239
|
-
import sqlalchemy as
|
245
|
+
import sqlalchemy as sa
|
240
246
|
|
241
247
|
return {
|
242
|
-
Int():
|
243
|
-
Int8():
|
244
|
-
Int16():
|
245
|
-
Int32():
|
246
|
-
Int64():
|
247
|
-
UInt8():
|
248
|
-
UInt16():
|
249
|
-
UInt32():
|
250
|
-
UInt64():
|
251
|
-
Float():
|
252
|
-
Float32():
|
253
|
-
Float64():
|
254
|
-
|
255
|
-
|
256
|
-
|
257
|
-
|
258
|
-
|
259
|
-
|
260
|
-
|
261
|
-
NullType(): sqa.types.NullType(),
|
248
|
+
Int(): sa.BigInteger(), # we default to 64 bit
|
249
|
+
Int8(): sa.SmallInteger(),
|
250
|
+
Int16(): sa.SmallInteger(),
|
251
|
+
Int32(): sa.Integer(),
|
252
|
+
Int64(): sa.BigInteger(),
|
253
|
+
UInt8(): sa.SmallInteger(),
|
254
|
+
UInt16(): sa.Integer(),
|
255
|
+
UInt32(): sa.BigInteger(),
|
256
|
+
UInt64(): sa.BigInteger(),
|
257
|
+
Float(): sa.Float(53), # we default to 64 bit
|
258
|
+
Float32(): sa.Float(24),
|
259
|
+
Float64(): sa.Float(53),
|
260
|
+
String(): sa.String(),
|
261
|
+
Bool(): sa.Boolean(),
|
262
|
+
Date(): sa.Date(),
|
263
|
+
Time(): sa.Time(),
|
264
|
+
Datetime(): sa.DateTime(),
|
265
|
+
Duration(): sa.Interval(),
|
266
|
+
NullType(): sa.types.NullType(),
|
262
267
|
}[self]
|
263
268
|
|
264
269
|
def to_pandas(self, backend: PandasBackend = PandasBackend.ARROW):
|
@@ -268,7 +273,7 @@ class Dtype:
|
|
268
273
|
if backend == PandasBackend.NUMPY:
|
269
274
|
return self.to_pandas_nullable(backend)
|
270
275
|
if backend == PandasBackend.ARROW:
|
271
|
-
if self
|
276
|
+
if isinstance(self, String) or isinstance(self, Enum):
|
272
277
|
return pd.StringDtype(storage="pyarrow")
|
273
278
|
return pd.ArrowDtype(self.to_arrow())
|
274
279
|
|
@@ -316,7 +321,7 @@ class Dtype:
|
|
316
321
|
Float(): pd.Float64Dtype(), # we default to 64 bit
|
317
322
|
Float32(): pd.Float32Dtype(),
|
318
323
|
Float64(): pd.Float64Dtype(),
|
319
|
-
Decimal(): pd.Float64Dtype(), # NumericDtype is
|
324
|
+
Decimal(): pd.Float64Dtype(), # NumericDtype exists but is not used
|
320
325
|
String(): pd.StringDtype(),
|
321
326
|
Bool(): pd.BooleanDtype(),
|
322
327
|
Date(): "datetime64[s]",
|
@@ -329,9 +334,6 @@ class Dtype:
|
|
329
334
|
"""Convert this Dtype to a PyArrow type."""
|
330
335
|
import pyarrow as pa
|
331
336
|
|
332
|
-
if isinstance(self, Enum):
|
333
|
-
return pa.string()
|
334
|
-
|
335
337
|
return {
|
336
338
|
Int(): pa.int64(), # we default to 64 bit
|
337
339
|
Int8(): pa.int8(),
|
@@ -345,7 +347,6 @@ class Dtype:
|
|
345
347
|
Float(): pa.float64(), # we default to 64 bit
|
346
348
|
Float32(): pa.float32(),
|
347
349
|
Float64(): pa.float64(),
|
348
|
-
Decimal(): pa.decimal128(35, 10), # Arbitrary precision
|
349
350
|
String(): pa.string(),
|
350
351
|
Bool(): pa.bool_(),
|
351
352
|
Date(): pa.date32(),
|
@@ -355,6 +356,12 @@ class Dtype:
|
|
355
356
|
NullType(): pa.null(),
|
356
357
|
}[self]
|
357
358
|
|
359
|
+
def to_arrow_field(self, name: str, nullable: bool = True):
|
360
|
+
"""Convert this Dtype to a PyArrow Field."""
|
361
|
+
import pyarrow as pa
|
362
|
+
|
363
|
+
return pa.field(name, self.to_arrow(), nullable=nullable)
|
364
|
+
|
358
365
|
def to_polars(self: "Dtype"):
|
359
366
|
"""Convert this Dtype to a Polars type."""
|
360
367
|
import polars as pl
|
@@ -372,8 +379,6 @@ class Dtype:
|
|
372
379
|
Float(): pl.Float64, # we default to 64 bit
|
373
380
|
Float64(): pl.Float64,
|
374
381
|
Float32(): pl.Float32,
|
375
|
-
Decimal(): pl.Decimal(scale=10), # Arbitrary precision
|
376
|
-
String(): pl.Utf8,
|
377
382
|
Bool(): pl.Boolean,
|
378
383
|
Datetime(): pl.Datetime("us"),
|
379
384
|
Duration(): pl.Duration("us"),
|
@@ -395,7 +400,43 @@ class Float64(Float): ...
|
|
395
400
|
class Float32(Float): ...
|
396
401
|
|
397
402
|
|
398
|
-
class Decimal(Float):
|
403
|
+
class Decimal(Float):
|
404
|
+
def __init__(self, precision: int | None = None, scale: int | None = None):
|
405
|
+
"""
|
406
|
+
Initialize a Decimal Dtype.
|
407
|
+
|
408
|
+
Default is Decimal(31,10) which is the highest precision that works with DB2.
|
409
|
+
If you like to save memory, Decimal(15,6) will get you quite far as well.
|
410
|
+
|
411
|
+
:param precision: total number of digits in the number
|
412
|
+
If not specified, it is assumed to be 31.
|
413
|
+
:param scale: number of digits after the decimal point
|
414
|
+
If not specified, it is assumed to be (precision//3+1).
|
415
|
+
"""
|
416
|
+
self.precision = precision or 31
|
417
|
+
self.scale = scale or (self.precision // 3 + 1)
|
418
|
+
|
419
|
+
def to_sql(self):
|
420
|
+
import sqlalchemy as sa
|
421
|
+
|
422
|
+
return sa.Numeric(self.precision, self.scale)
|
423
|
+
|
424
|
+
def to_polars(self):
|
425
|
+
import polars as pl
|
426
|
+
|
427
|
+
return pl.Decimal(self.precision, self.scale)
|
428
|
+
|
429
|
+
def to_arrow(self):
|
430
|
+
import pyarrow as pa
|
431
|
+
|
432
|
+
if self.precision > 38:
|
433
|
+
return pa.decimal256(self.precision, self.scale)
|
434
|
+
elif self.precision > 18:
|
435
|
+
return pa.decimal128(self.precision, self.scale)
|
436
|
+
elif self.precision > 9:
|
437
|
+
return pa.decimal64(self.precision, self.scale)
|
438
|
+
else:
|
439
|
+
return pa.decimal32(self.precision, self.scale)
|
399
440
|
|
400
441
|
|
401
442
|
class Int(Dtype):
|
@@ -428,7 +469,32 @@ class UInt16(Int): ...
|
|
428
469
|
class UInt8(Int): ...
|
429
470
|
|
430
471
|
|
431
|
-
class String(Dtype):
|
472
|
+
class String(Dtype):
|
473
|
+
def __init__(self, max_length: int | None = None):
|
474
|
+
"""
|
475
|
+
Initialize a String Dtype.
|
476
|
+
|
477
|
+
:param max_length: maximum length of string
|
478
|
+
This length will only be used for specifying fixed length strings in SQL.
|
479
|
+
Thus, the meaning of characters vs. bytes is dependent on the SQL dialect.
|
480
|
+
"""
|
481
|
+
self.max_length = max_length
|
482
|
+
|
483
|
+
def to_sql(self):
|
484
|
+
"""Convert this Dtype to a SQL type."""
|
485
|
+
import sqlalchemy as sa
|
486
|
+
|
487
|
+
return sa.String(length=self.max_length)
|
488
|
+
|
489
|
+
def to_polars(self):
|
490
|
+
import polars as pl
|
491
|
+
|
492
|
+
return pl.Utf8
|
493
|
+
|
494
|
+
def to_arrow(self):
|
495
|
+
import pyarrow as pa
|
496
|
+
|
497
|
+
return pa.string()
|
432
498
|
|
433
499
|
|
434
500
|
class Bool(Dtype): ...
|
@@ -463,9 +529,9 @@ class List(Dtype):
|
|
463
529
|
return f"List[{repr(self.inner)}]"
|
464
530
|
|
465
531
|
def to_sql(self):
|
466
|
-
import sqlalchemy as
|
532
|
+
import sqlalchemy as sa
|
467
533
|
|
468
|
-
return
|
534
|
+
return sa.ARRAY(self.inner.to_sql())
|
469
535
|
|
470
536
|
def to_polars(self):
|
471
537
|
import polars as pl
|
@@ -483,6 +549,9 @@ class Enum(String):
|
|
483
549
|
if not all(isinstance(c, str) for c in categories):
|
484
550
|
raise TypeError("arguments for `Enum` must have type `str`")
|
485
551
|
self.categories = list(categories)
|
552
|
+
self.max_length = (
|
553
|
+
max([len(c) for c in categories]) if len(categories) > 0 else None
|
554
|
+
)
|
486
555
|
|
487
556
|
def __eq__(self, rhs):
|
488
557
|
return isinstance(rhs, Enum) and self.categories == rhs.categories
|
@@ -498,14 +567,23 @@ class Enum(String):
|
|
498
567
|
|
499
568
|
return pl.Enum(self.categories)
|
500
569
|
|
501
|
-
def to_sql(self):
|
502
|
-
import sqlalchemy as sqa
|
503
|
-
|
504
|
-
return sqa.String()
|
505
|
-
|
506
570
|
def to_arrow(self):
|
507
571
|
import pyarrow as pa
|
508
572
|
|
509
|
-
#
|
510
|
-
# Maybe it is better to convert to this.
|
573
|
+
# enum categories can only be maintained in pyarrow field (see to_arrow_field)
|
511
574
|
return pa.string()
|
575
|
+
|
576
|
+
def to_arrow_field(self, name: str, nullable: bool = True):
|
577
|
+
"""Convert this Dtype to a PyArrow Field."""
|
578
|
+
import pyarrow as pa
|
579
|
+
|
580
|
+
# try to mimic what polars does
|
581
|
+
return pa.field(
|
582
|
+
name,
|
583
|
+
pa.dictionary(pa.uint32(), pa.large_string()),
|
584
|
+
nullable=nullable,
|
585
|
+
metadata={
|
586
|
+
# the key might change with polars versions
|
587
|
+
"_PL_ENUM_VALUES2": "".join([f"{len(c)};{c}" for c in self.categories])
|
588
|
+
},
|
589
|
+
)
|
pydiverse/common/util/hashing.py
CHANGED
@@ -52,11 +52,14 @@ def hash_polars_dataframe(df: pl.DataFrame, use_init_repr=False) -> str:
|
|
52
52
|
]:
|
53
53
|
df = df.with_columns(
|
54
54
|
pl.col(struct_col_name).struct.rename_fields(
|
55
|
-
[stable_hash(struct_col_name,
|
55
|
+
[stable_hash(struct_col_name, struct_dtype)]
|
56
56
|
)
|
57
|
-
for struct_col_name,
|
57
|
+
for struct_col_name, struct_dtype in struct_cols_and_dtypes
|
58
58
|
).unnest(
|
59
|
-
|
59
|
+
[
|
60
|
+
struct_col_name
|
61
|
+
for struct_col_name, _ in struct_cols_and_dtypes
|
62
|
+
]
|
60
63
|
)
|
61
64
|
return df
|
62
65
|
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: pydiverse-common
|
3
|
-
Version: 0.3.
|
3
|
+
Version: 0.3.10
|
4
4
|
Summary: Common functionality shared between pydiverse libraries
|
5
5
|
Author: QuantCo, Inc.
|
6
6
|
Author-email: Martin Trautmann <windiana@users.sf.net>, Finn Rudolph <finn.rudolph@t-online.de>
|
@@ -1,5 +1,5 @@
|
|
1
1
|
pydiverse/common/__init__.py,sha256=J7b4iStFyaEMYre_jdlZ4l_8dLyrMWCIpQdsMQcB8aI,806
|
2
|
-
pydiverse/common/dtypes.py,sha256=
|
2
|
+
pydiverse/common/dtypes.py,sha256=ekeH6Ygnfy6sTsMVG4CcEdGHH_C6Z550tkdkaMQNIyY,18782
|
3
3
|
pydiverse/common/testing.py,sha256=FcivI5wn0X3gzJhwnysKvCOgjSTTXaN6FtSFJ72jfSg,341
|
4
4
|
pydiverse/common/version.py,sha256=1IU_m4r76_Qq0u-Tyo2_bERZFOkh0ZFueVzDqcCfLO0,336
|
5
5
|
pydiverse/common/errors/__init__.py,sha256=FNeEfVbUa23b9sHkFsmxHYhY6sRgjaZysPQmlovpJrI,262
|
@@ -8,10 +8,10 @@ pydiverse/common/util/computation_tracing.py,sha256=HeXRHRUI8vxpzQ27Xcpa0StndSTP
|
|
8
8
|
pydiverse/common/util/deep_map.py,sha256=JtY5ViWMMelOiLzPF7ZjzruCfB-bETISGxCk37qETxg,2540
|
9
9
|
pydiverse/common/util/deep_merge.py,sha256=bV5p5_lsC-9nFah28EiEyG2h6U3Z5AuTqSooxOgCHN0,1929
|
10
10
|
pydiverse/common/util/disposable.py,sha256=4XoGz70YRWA9TAqnUBvRCTAdsOGBviFN0gzxU7veY9o,993
|
11
|
-
pydiverse/common/util/hashing.py,sha256=
|
11
|
+
pydiverse/common/util/hashing.py,sha256=EofnKULVKXv-S9kry0mOqHU5bxPGomCtr6XfYqIhGgc,4650
|
12
12
|
pydiverse/common/util/import_.py,sha256=K7dSgz4YyrqEvqhoOzbwgD7D8HScMoO5XoSWtjbaoUs,4056
|
13
13
|
pydiverse/common/util/structlog.py,sha256=xxhauxMuyxcKXTVg1MiPTkuvPBj8Zcr4o_v8Bq59Nig,3778
|
14
|
-
pydiverse_common-0.3.
|
15
|
-
pydiverse_common-0.3.
|
16
|
-
pydiverse_common-0.3.
|
17
|
-
pydiverse_common-0.3.
|
14
|
+
pydiverse_common-0.3.10.dist-info/METADATA,sha256=8xmLOqUJZbuETqz49ve_W8NNVkHGyaeOwT2exOhk0uY,3400
|
15
|
+
pydiverse_common-0.3.10.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
16
|
+
pydiverse_common-0.3.10.dist-info/licenses/LICENSE,sha256=AcE6SDVuAq6v9ZLE_8eOCe_NvSE0rAPR3NR7lSowYh4,1517
|
17
|
+
pydiverse_common-0.3.10.dist-info/RECORD,,
|
File without changes
|
File without changes
|