pydiverse-common 0.3.9__py3-none-any.whl → 0.3.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -49,20 +49,20 @@ class Dtype:
49
49
  @staticmethod
50
50
  def from_sql(sql_type) -> "Dtype":
51
51
  """Convert a SQL type to a Dtype."""
52
- import sqlalchemy as sqa
52
+ import sqlalchemy as sa
53
53
 
54
- if isinstance(sql_type, sqa.SmallInteger):
54
+ if isinstance(sql_type, sa.SmallInteger):
55
55
  return Int16()
56
- if isinstance(sql_type, sqa.BigInteger):
56
+ if isinstance(sql_type, sa.BigInteger):
57
57
  return Int64()
58
- if isinstance(sql_type, sqa.Integer):
58
+ if isinstance(sql_type, sa.Integer):
59
59
  return Int32()
60
- if isinstance(sql_type, sqa.Float):
60
+ if isinstance(sql_type, sa.Float):
61
61
  precision = sql_type.precision or 53
62
62
  if precision <= 24:
63
63
  return Float32()
64
64
  return Float64()
65
- if isinstance(sql_type, sqa.Numeric | sqa.DECIMAL):
65
+ if isinstance(sql_type, sa.Numeric | sa.DECIMAL):
66
66
  # Just to be safe, we always use FLOAT64 for fixpoint numbers.
67
67
  # Databases are obsessed about fixpoint. However, in dataframes, it
68
68
  # is more common to just work with double precision floating point.
@@ -70,21 +70,21 @@ class Dtype:
70
70
  # Decimal to Float64 whenever it cannot guarantee semantic correctness
71
71
  # otherwise.
72
72
  return Float64()
73
- if isinstance(sql_type, sqa.String):
73
+ if isinstance(sql_type, sa.String):
74
74
  return String()
75
- if isinstance(sql_type, sqa.Boolean):
75
+ if isinstance(sql_type, sa.Boolean):
76
76
  return Bool()
77
- if isinstance(sql_type, sqa.Date):
77
+ if isinstance(sql_type, sa.Date):
78
78
  return Date()
79
- if isinstance(sql_type, sqa.Time):
79
+ if isinstance(sql_type, sa.Time):
80
80
  return Time()
81
- if isinstance(sql_type, sqa.DateTime):
81
+ if isinstance(sql_type, sa.DateTime):
82
82
  return Datetime()
83
- if isinstance(sql_type, sqa.Interval):
83
+ if isinstance(sql_type, sa.Interval):
84
84
  return Duration()
85
- if isinstance(sql_type, sqa.ARRAY):
85
+ if isinstance(sql_type, sa.ARRAY):
86
86
  return List(Dtype.from_sql(sql_type.item_type))
87
- if isinstance(sql_type, sqa.types.NullType):
87
+ if isinstance(sql_type, sa.types.NullType):
88
88
  return NullType()
89
89
 
90
90
  raise TypeError
@@ -184,7 +184,7 @@ class Dtype:
184
184
  raise TypeError
185
185
  if pa.types.is_decimal(arrow_type):
186
186
  # We don't recommend using Decimal in dataframes, but we support it.
187
- return Decimal()
187
+ return Decimal(arrow_type.precision, arrow_type.scale)
188
188
  if pa.types.is_string(arrow_type):
189
189
  return String()
190
190
  if pa.types.is_boolean(arrow_type):
@@ -217,6 +217,8 @@ class Dtype:
217
217
  return List(Dtype.from_polars(polars_type.inner))
218
218
  if isinstance(polars_type, pl.Enum):
219
219
  return Enum(*polars_type.categories)
220
+ if isinstance(polars_type, pl.Decimal):
221
+ return Decimal(polars_type.precision, polars_type.scale)
220
222
 
221
223
  return {
222
224
  pl.Int64: Int64(),
@@ -229,7 +231,6 @@ class Dtype:
229
231
  pl.UInt8: UInt8(),
230
232
  pl.Float64: Float64(),
231
233
  pl.Float32: Float32(),
232
- pl.Decimal: Decimal(),
233
234
  pl.Utf8: String(),
234
235
  pl.Boolean: Bool(),
235
236
  pl.Datetime: Datetime(),
@@ -241,29 +242,28 @@ class Dtype:
241
242
 
242
243
  def to_sql(self):
243
244
  """Convert this Dtype to a SQL type."""
244
- import sqlalchemy as sqa
245
+ import sqlalchemy as sa
245
246
 
246
247
  return {
247
- Int(): sqa.BigInteger(), # we default to 64 bit
248
- Int8(): sqa.SmallInteger(),
249
- Int16(): sqa.SmallInteger(),
250
- Int32(): sqa.Integer(),
251
- Int64(): sqa.BigInteger(),
252
- UInt8(): sqa.SmallInteger(),
253
- UInt16(): sqa.Integer(),
254
- UInt32(): sqa.BigInteger(),
255
- UInt64(): sqa.BigInteger(),
256
- Float(): sqa.Float(53), # we default to 64 bit
257
- Float32(): sqa.Float(24),
258
- Float64(): sqa.Float(53),
259
- Decimal(): sqa.DECIMAL(),
260
- String(): sqa.String(),
261
- Bool(): sqa.Boolean(),
262
- Date(): sqa.Date(),
263
- Time(): sqa.Time(),
264
- Datetime(): sqa.DateTime(),
265
- Duration(): sqa.Interval(),
266
- NullType(): sqa.types.NullType(),
248
+ Int(): sa.BigInteger(), # we default to 64 bit
249
+ Int8(): sa.SmallInteger(),
250
+ Int16(): sa.SmallInteger(),
251
+ Int32(): sa.Integer(),
252
+ Int64(): sa.BigInteger(),
253
+ UInt8(): sa.SmallInteger(),
254
+ UInt16(): sa.Integer(),
255
+ UInt32(): sa.BigInteger(),
256
+ UInt64(): sa.BigInteger(),
257
+ Float(): sa.Float(53), # we default to 64 bit
258
+ Float32(): sa.Float(24),
259
+ Float64(): sa.Float(53),
260
+ String(): sa.String(),
261
+ Bool(): sa.Boolean(),
262
+ Date(): sa.Date(),
263
+ Time(): sa.Time(),
264
+ Datetime(): sa.DateTime(),
265
+ Duration(): sa.Interval(),
266
+ NullType(): sa.types.NullType(),
267
267
  }[self]
268
268
 
269
269
  def to_pandas(self, backend: PandasBackend = PandasBackend.ARROW):
@@ -273,7 +273,7 @@ class Dtype:
273
273
  if backend == PandasBackend.NUMPY:
274
274
  return self.to_pandas_nullable(backend)
275
275
  if backend == PandasBackend.ARROW:
276
- if self == String() or isinstance(self, Enum):
276
+ if isinstance(self, String) or isinstance(self, Enum):
277
277
  return pd.StringDtype(storage="pyarrow")
278
278
  return pd.ArrowDtype(self.to_arrow())
279
279
 
@@ -321,7 +321,7 @@ class Dtype:
321
321
  Float(): pd.Float64Dtype(), # we default to 64 bit
322
322
  Float32(): pd.Float32Dtype(),
323
323
  Float64(): pd.Float64Dtype(),
324
- Decimal(): pd.Float64Dtype(), # NumericDtype is
324
+ Decimal(): pd.Float64Dtype(), # NumericDtype exists but is not used
325
325
  String(): pd.StringDtype(),
326
326
  Bool(): pd.BooleanDtype(),
327
327
  Date(): "datetime64[s]",
@@ -334,9 +334,6 @@ class Dtype:
334
334
  """Convert this Dtype to a PyArrow type."""
335
335
  import pyarrow as pa
336
336
 
337
- if isinstance(self, Enum):
338
- return pa.string()
339
-
340
337
  return {
341
338
  Int(): pa.int64(), # we default to 64 bit
342
339
  Int8(): pa.int8(),
@@ -350,7 +347,6 @@ class Dtype:
350
347
  Float(): pa.float64(), # we default to 64 bit
351
348
  Float32(): pa.float32(),
352
349
  Float64(): pa.float64(),
353
- Decimal(): pa.decimal128(35, 10), # Arbitrary precision
354
350
  String(): pa.string(),
355
351
  Bool(): pa.bool_(),
356
352
  Date(): pa.date32(),
@@ -383,8 +379,6 @@ class Dtype:
383
379
  Float(): pl.Float64, # we default to 64 bit
384
380
  Float64(): pl.Float64,
385
381
  Float32(): pl.Float32,
386
- Decimal(): pl.Decimal(scale=10), # Arbitrary precision
387
- String(): pl.Utf8,
388
382
  Bool(): pl.Boolean,
389
383
  Datetime(): pl.Datetime("us"),
390
384
  Duration(): pl.Duration("us"),
@@ -406,7 +400,43 @@ class Float64(Float): ...
406
400
  class Float32(Float): ...
407
401
 
408
402
 
409
- class Decimal(Float): ...
403
+ class Decimal(Float):
404
+ def __init__(self, precision: int | None = None, scale: int | None = None):
405
+ """
406
+ Initialize a Decimal Dtype.
407
+
408
+ Default is Decimal(31,10) which is the highest precision that works with DB2.
409
+ If you like to save memory, Decimal(15,6) will get you quite far as well.
410
+
411
+ :param precision: total number of digits in the number
412
+ If not specified, it is assumed to be 31.
413
+ :param scale: number of digits after the decimal point
414
+ If not specified, it is assumed to be (precision//3+1).
415
+ """
416
+ self.precision = precision or 31
417
+ self.scale = scale or (self.precision // 3 + 1)
418
+
419
+ def to_sql(self):
420
+ import sqlalchemy as sa
421
+
422
+ return sa.Numeric(self.precision, self.scale)
423
+
424
+ def to_polars(self):
425
+ import polars as pl
426
+
427
+ return pl.Decimal(self.precision, self.scale)
428
+
429
+ def to_arrow(self):
430
+ import pyarrow as pa
431
+
432
+ if self.precision > 38:
433
+ return pa.decimal256(self.precision, self.scale)
434
+ elif self.precision > 18:
435
+ return pa.decimal128(self.precision, self.scale)
436
+ elif self.precision > 9:
437
+ return pa.decimal64(self.precision, self.scale)
438
+ else:
439
+ return pa.decimal32(self.precision, self.scale)
410
440
 
411
441
 
412
442
  class Int(Dtype):
@@ -439,7 +469,32 @@ class UInt16(Int): ...
439
469
  class UInt8(Int): ...
440
470
 
441
471
 
442
- class String(Dtype): ...
472
+ class String(Dtype):
473
+ def __init__(self, max_length: int | None = None):
474
+ """
475
+ Initialize a String Dtype.
476
+
477
+ :param max_length: maximum length of string
478
+ This length will only be used for specifying fixed length strings in SQL.
479
+ Thus, the meaning of characters vs. bytes is dependent on the SQL dialect.
480
+ """
481
+ self.max_length = max_length
482
+
483
+ def to_sql(self):
484
+ """Convert this Dtype to a SQL type."""
485
+ import sqlalchemy as sa
486
+
487
+ return sa.String(length=self.max_length)
488
+
489
+ def to_polars(self):
490
+ import polars as pl
491
+
492
+ return pl.Utf8
493
+
494
+ def to_arrow(self):
495
+ import pyarrow as pa
496
+
497
+ return pa.string()
443
498
 
444
499
 
445
500
  class Bool(Dtype): ...
@@ -474,9 +529,9 @@ class List(Dtype):
474
529
  return f"List[{repr(self.inner)}]"
475
530
 
476
531
  def to_sql(self):
477
- import sqlalchemy as sqa
532
+ import sqlalchemy as sa
478
533
 
479
- return sqa.ARRAY(self.inner.to_sql())
534
+ return sa.ARRAY(self.inner.to_sql())
480
535
 
481
536
  def to_polars(self):
482
537
  import polars as pl
@@ -494,6 +549,9 @@ class Enum(String):
494
549
  if not all(isinstance(c, str) for c in categories):
495
550
  raise TypeError("arguments for `Enum` must have type `str`")
496
551
  self.categories = list(categories)
552
+ self.max_length = (
553
+ max([len(c) for c in categories]) if len(categories) > 0 else None
554
+ )
497
555
 
498
556
  def __eq__(self, rhs):
499
557
  return isinstance(rhs, Enum) and self.categories == rhs.categories
@@ -509,11 +567,6 @@ class Enum(String):
509
567
 
510
568
  return pl.Enum(self.categories)
511
569
 
512
- def to_sql(self):
513
- import sqlalchemy as sqa
514
-
515
- return sqa.String()
516
-
517
570
  def to_arrow(self):
518
571
  import pyarrow as pa
519
572
 
@@ -530,6 +583,7 @@ class Enum(String):
530
583
  pa.dictionary(pa.uint32(), pa.large_string()),
531
584
  nullable=nullable,
532
585
  metadata={
533
- "_PL_ENUM_VALUES": "".join([f"{len(c)};{c}" for c in self.categories])
586
+ # the key might change with polars versions
587
+ "_PL_ENUM_VALUES2": "".join([f"{len(c)};{c}" for c in self.categories])
534
588
  },
535
589
  )
@@ -52,11 +52,14 @@ def hash_polars_dataframe(df: pl.DataFrame, use_init_repr=False) -> str:
52
52
  ]:
53
53
  df = df.with_columns(
54
54
  pl.col(struct_col_name).struct.rename_fields(
55
- [stable_hash(struct_col_name, struct_field_name)]
55
+ [stable_hash(struct_col_name, struct_dtype)]
56
56
  )
57
- for struct_col_name, struct_field_name in struct_cols_and_dtypes
57
+ for struct_col_name, struct_dtype in struct_cols_and_dtypes
58
58
  ).unnest(
59
- struct_col_name for struct_col_name, _ in struct_cols_and_dtypes
59
+ [
60
+ struct_col_name
61
+ for struct_col_name, _ in struct_cols_and_dtypes
62
+ ]
60
63
  )
61
64
  return df
62
65
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: pydiverse-common
3
- Version: 0.3.9
3
+ Version: 0.3.10
4
4
  Summary: Common functionality shared between pydiverse libraries
5
5
  Author: QuantCo, Inc.
6
6
  Author-email: Martin Trautmann <windiana@users.sf.net>, Finn Rudolph <finn.rudolph@t-online.de>
@@ -1,5 +1,5 @@
1
1
  pydiverse/common/__init__.py,sha256=J7b4iStFyaEMYre_jdlZ4l_8dLyrMWCIpQdsMQcB8aI,806
2
- pydiverse/common/dtypes.py,sha256=TN3UjI-4uBT_CB_3mLowOYq5GPcyvSbI3N6zHpxxV3c,16885
2
+ pydiverse/common/dtypes.py,sha256=ekeH6Ygnfy6sTsMVG4CcEdGHH_C6Z550tkdkaMQNIyY,18782
3
3
  pydiverse/common/testing.py,sha256=FcivI5wn0X3gzJhwnysKvCOgjSTTXaN6FtSFJ72jfSg,341
4
4
  pydiverse/common/version.py,sha256=1IU_m4r76_Qq0u-Tyo2_bERZFOkh0ZFueVzDqcCfLO0,336
5
5
  pydiverse/common/errors/__init__.py,sha256=FNeEfVbUa23b9sHkFsmxHYhY6sRgjaZysPQmlovpJrI,262
@@ -8,10 +8,10 @@ pydiverse/common/util/computation_tracing.py,sha256=HeXRHRUI8vxpzQ27Xcpa0StndSTP
8
8
  pydiverse/common/util/deep_map.py,sha256=JtY5ViWMMelOiLzPF7ZjzruCfB-bETISGxCk37qETxg,2540
9
9
  pydiverse/common/util/deep_merge.py,sha256=bV5p5_lsC-9nFah28EiEyG2h6U3Z5AuTqSooxOgCHN0,1929
10
10
  pydiverse/common/util/disposable.py,sha256=4XoGz70YRWA9TAqnUBvRCTAdsOGBviFN0gzxU7veY9o,993
11
- pydiverse/common/util/hashing.py,sha256=8Z1NybJ_zd3ONpn5annHGjowwArWkd2ZkCtlb3dtz_Q,4576
11
+ pydiverse/common/util/hashing.py,sha256=EofnKULVKXv-S9kry0mOqHU5bxPGomCtr6XfYqIhGgc,4650
12
12
  pydiverse/common/util/import_.py,sha256=K7dSgz4YyrqEvqhoOzbwgD7D8HScMoO5XoSWtjbaoUs,4056
13
13
  pydiverse/common/util/structlog.py,sha256=xxhauxMuyxcKXTVg1MiPTkuvPBj8Zcr4o_v8Bq59Nig,3778
14
- pydiverse_common-0.3.9.dist-info/METADATA,sha256=_P0mMant1LZMCa4p5tFBnpr0cvLIYXZXfu5xmYuih9Q,3399
15
- pydiverse_common-0.3.9.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
16
- pydiverse_common-0.3.9.dist-info/licenses/LICENSE,sha256=AcE6SDVuAq6v9ZLE_8eOCe_NvSE0rAPR3NR7lSowYh4,1517
17
- pydiverse_common-0.3.9.dist-info/RECORD,,
14
+ pydiverse_common-0.3.10.dist-info/METADATA,sha256=8xmLOqUJZbuETqz49ve_W8NNVkHGyaeOwT2exOhk0uY,3400
15
+ pydiverse_common-0.3.10.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
16
+ pydiverse_common-0.3.10.dist-info/licenses/LICENSE,sha256=AcE6SDVuAq6v9ZLE_8eOCe_NvSE0rAPR3NR7lSowYh4,1517
17
+ pydiverse_common-0.3.10.dist-info/RECORD,,