pydiverse-common 0.3.9__py3-none-any.whl → 0.3.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -23,6 +23,10 @@ class Dtype:
23
23
  """Return a string representation of this dtype."""
24
24
  return self.__class__.__name__
25
25
 
26
+ def __str__(self):
27
+ """Return a string representation of this dtype."""
28
+ return self.__repr__()
29
+
26
30
  @classmethod
27
31
  def is_int(cls):
28
32
  """Return ``True`` if this dtype is an integer type."""
@@ -49,20 +53,20 @@ class Dtype:
49
53
  @staticmethod
50
54
  def from_sql(sql_type) -> "Dtype":
51
55
  """Convert a SQL type to a Dtype."""
52
- import sqlalchemy as sqa
56
+ import sqlalchemy as sa
53
57
 
54
- if isinstance(sql_type, sqa.SmallInteger):
58
+ if isinstance(sql_type, sa.SmallInteger):
55
59
  return Int16()
56
- if isinstance(sql_type, sqa.BigInteger):
60
+ if isinstance(sql_type, sa.BigInteger):
57
61
  return Int64()
58
- if isinstance(sql_type, sqa.Integer):
62
+ if isinstance(sql_type, sa.Integer):
59
63
  return Int32()
60
- if isinstance(sql_type, sqa.Float):
64
+ if isinstance(sql_type, sa.Float):
61
65
  precision = sql_type.precision or 53
62
66
  if precision <= 24:
63
67
  return Float32()
64
68
  return Float64()
65
- if isinstance(sql_type, sqa.Numeric | sqa.DECIMAL):
69
+ if isinstance(sql_type, sa.Numeric | sa.DECIMAL):
66
70
  # Just to be safe, we always use FLOAT64 for fixpoint numbers.
67
71
  # Databases are obsessed about fixpoint. However, in dataframes, it
68
72
  # is more common to just work with double precision floating point.
@@ -70,21 +74,21 @@ class Dtype:
70
74
  # Decimal to Float64 whenever it cannot guarantee semantic correctness
71
75
  # otherwise.
72
76
  return Float64()
73
- if isinstance(sql_type, sqa.String):
74
- return String()
75
- if isinstance(sql_type, sqa.Boolean):
77
+ if isinstance(sql_type, sa.String):
78
+ return String(sql_type.length)
79
+ if isinstance(sql_type, sa.Boolean):
76
80
  return Bool()
77
- if isinstance(sql_type, sqa.Date):
81
+ if isinstance(sql_type, sa.Date):
78
82
  return Date()
79
- if isinstance(sql_type, sqa.Time):
83
+ if isinstance(sql_type, sa.Time):
80
84
  return Time()
81
- if isinstance(sql_type, sqa.DateTime):
85
+ if isinstance(sql_type, sa.DateTime):
82
86
  return Datetime()
83
- if isinstance(sql_type, sqa.Interval):
87
+ if isinstance(sql_type, sa.Interval):
84
88
  return Duration()
85
- if isinstance(sql_type, sqa.ARRAY):
89
+ if isinstance(sql_type, sa.ARRAY):
86
90
  return List(Dtype.from_sql(sql_type.item_type))
87
- if isinstance(sql_type, sqa.types.NullType):
91
+ if isinstance(sql_type, sa.types.NullType):
88
92
  return NullType()
89
93
 
90
94
  raise TypeError
@@ -184,7 +188,7 @@ class Dtype:
184
188
  raise TypeError
185
189
  if pa.types.is_decimal(arrow_type):
186
190
  # We don't recommend using Decimal in dataframes, but we support it.
187
- return Decimal()
191
+ return Decimal(arrow_type.precision, arrow_type.scale)
188
192
  if pa.types.is_string(arrow_type):
189
193
  return String()
190
194
  if pa.types.is_boolean(arrow_type):
@@ -217,6 +221,8 @@ class Dtype:
217
221
  return List(Dtype.from_polars(polars_type.inner))
218
222
  if isinstance(polars_type, pl.Enum):
219
223
  return Enum(*polars_type.categories)
224
+ if isinstance(polars_type, pl.Decimal):
225
+ return Decimal(polars_type.precision, polars_type.scale)
220
226
 
221
227
  return {
222
228
  pl.Int64: Int64(),
@@ -229,7 +235,6 @@ class Dtype:
229
235
  pl.UInt8: UInt8(),
230
236
  pl.Float64: Float64(),
231
237
  pl.Float32: Float32(),
232
- pl.Decimal: Decimal(),
233
238
  pl.Utf8: String(),
234
239
  pl.Boolean: Bool(),
235
240
  pl.Datetime: Datetime(),
@@ -241,29 +246,28 @@ class Dtype:
241
246
 
242
247
  def to_sql(self):
243
248
  """Convert this Dtype to a SQL type."""
244
- import sqlalchemy as sqa
249
+ import sqlalchemy as sa
245
250
 
246
251
  return {
247
- Int(): sqa.BigInteger(), # we default to 64 bit
248
- Int8(): sqa.SmallInteger(),
249
- Int16(): sqa.SmallInteger(),
250
- Int32(): sqa.Integer(),
251
- Int64(): sqa.BigInteger(),
252
- UInt8(): sqa.SmallInteger(),
253
- UInt16(): sqa.Integer(),
254
- UInt32(): sqa.BigInteger(),
255
- UInt64(): sqa.BigInteger(),
256
- Float(): sqa.Float(53), # we default to 64 bit
257
- Float32(): sqa.Float(24),
258
- Float64(): sqa.Float(53),
259
- Decimal(): sqa.DECIMAL(),
260
- String(): sqa.String(),
261
- Bool(): sqa.Boolean(),
262
- Date(): sqa.Date(),
263
- Time(): sqa.Time(),
264
- Datetime(): sqa.DateTime(),
265
- Duration(): sqa.Interval(),
266
- NullType(): sqa.types.NullType(),
252
+ Int(): sa.BigInteger(), # we default to 64 bit
253
+ Int8(): sa.SmallInteger(),
254
+ Int16(): sa.SmallInteger(),
255
+ Int32(): sa.Integer(),
256
+ Int64(): sa.BigInteger(),
257
+ UInt8(): sa.SmallInteger(),
258
+ UInt16(): sa.Integer(),
259
+ UInt32(): sa.BigInteger(),
260
+ UInt64(): sa.BigInteger(),
261
+ Float(): sa.Float(53), # we default to 64 bit
262
+ Float32(): sa.Float(24),
263
+ Float64(): sa.Float(53),
264
+ String(): sa.String(),
265
+ Bool(): sa.Boolean(),
266
+ Date(): sa.Date(),
267
+ Time(): sa.Time(),
268
+ Datetime(): sa.DateTime(),
269
+ Duration(): sa.Interval(),
270
+ NullType(): sa.types.NullType(),
267
271
  }[self]
268
272
 
269
273
  def to_pandas(self, backend: PandasBackend = PandasBackend.ARROW):
@@ -273,7 +277,7 @@ class Dtype:
273
277
  if backend == PandasBackend.NUMPY:
274
278
  return self.to_pandas_nullable(backend)
275
279
  if backend == PandasBackend.ARROW:
276
- if self == String() or isinstance(self, Enum):
280
+ if isinstance(self, String) or isinstance(self, Enum):
277
281
  return pd.StringDtype(storage="pyarrow")
278
282
  return pd.ArrowDtype(self.to_arrow())
279
283
 
@@ -321,7 +325,7 @@ class Dtype:
321
325
  Float(): pd.Float64Dtype(), # we default to 64 bit
322
326
  Float32(): pd.Float32Dtype(),
323
327
  Float64(): pd.Float64Dtype(),
324
- Decimal(): pd.Float64Dtype(), # NumericDtype is
328
+ Decimal(): pd.Float64Dtype(), # NumericDtype exists but is not used
325
329
  String(): pd.StringDtype(),
326
330
  Bool(): pd.BooleanDtype(),
327
331
  Date(): "datetime64[s]",
@@ -334,9 +338,6 @@ class Dtype:
334
338
  """Convert this Dtype to a PyArrow type."""
335
339
  import pyarrow as pa
336
340
 
337
- if isinstance(self, Enum):
338
- return pa.string()
339
-
340
341
  return {
341
342
  Int(): pa.int64(), # we default to 64 bit
342
343
  Int8(): pa.int8(),
@@ -350,7 +351,6 @@ class Dtype:
350
351
  Float(): pa.float64(), # we default to 64 bit
351
352
  Float32(): pa.float32(),
352
353
  Float64(): pa.float64(),
353
- Decimal(): pa.decimal128(35, 10), # Arbitrary precision
354
354
  String(): pa.string(),
355
355
  Bool(): pa.bool_(),
356
356
  Date(): pa.date32(),
@@ -383,8 +383,6 @@ class Dtype:
383
383
  Float(): pl.Float64, # we default to 64 bit
384
384
  Float64(): pl.Float64,
385
385
  Float32(): pl.Float32,
386
- Decimal(): pl.Decimal(scale=10), # Arbitrary precision
387
- String(): pl.Utf8,
388
386
  Bool(): pl.Boolean,
389
387
  Datetime(): pl.Datetime("us"),
390
388
  Duration(): pl.Duration("us"),
@@ -406,7 +404,57 @@ class Float64(Float): ...
406
404
  class Float32(Float): ...
407
405
 
408
406
 
409
- class Decimal(Float): ...
407
+ class Decimal(Float):
408
+ def __init__(self, precision: int | None = None, scale: int | None = None):
409
+ """
410
+ Initialize a Decimal Dtype.
411
+
412
+ Default is Decimal(31,10) which is the highest precision that works with DB2.
413
+ If you like to save memory, Decimal(15,6) will get you quite far as well.
414
+
415
+ :param precision: total number of digits in the number
416
+ If not specified, it is assumed to be 31.
417
+ :param scale: number of digits after the decimal point
418
+ If not specified, it is assumed to be (precision//3+1).
419
+ """
420
+ self.precision = precision or 31
421
+ self.scale = scale or (self.precision // 3 + 1)
422
+
423
+ def __eq__(self, rhs):
424
+ return (
425
+ isinstance(rhs, self.__class__)
426
+ and self.precision == rhs.precision
427
+ and self.scale == rhs.scale
428
+ )
429
+
430
+ def __hash__(self):
431
+ return hash((self.__class__.__name__, self.precision, self.scale))
432
+
433
+ def __repr__(self):
434
+ """Return a string representation of this dtype."""
435
+ return f"{self.__class__.__name__}({self.precision}, {self.scale})"
436
+
437
+ def to_sql(self):
438
+ import sqlalchemy as sa
439
+
440
+ return sa.Numeric(self.precision, self.scale)
441
+
442
+ def to_polars(self):
443
+ import polars as pl
444
+
445
+ return pl.Decimal(self.precision, self.scale)
446
+
447
+ def to_arrow(self):
448
+ import pyarrow as pa
449
+
450
+ if self.precision > 38:
451
+ return pa.decimal256(self.precision, self.scale)
452
+ elif self.precision > 18:
453
+ return pa.decimal128(self.precision, self.scale)
454
+ elif self.precision > 9:
455
+ return pa.decimal64(self.precision, self.scale)
456
+ else:
457
+ return pa.decimal32(self.precision, self.scale)
410
458
 
411
459
 
412
460
  class Int(Dtype):
@@ -439,7 +487,42 @@ class UInt16(Int): ...
439
487
  class UInt8(Int): ...
440
488
 
441
489
 
442
- class String(Dtype): ...
490
+ class String(Dtype):
491
+ def __init__(self, max_length: int | None = None):
492
+ """
493
+ Initialize a String Dtype.
494
+
495
+ :param max_length: maximum length of string
496
+ This length will only be used for specifying fixed length strings in SQL.
497
+ Thus, the meaning of characters vs. bytes is dependent on the SQL dialect.
498
+ """
499
+ self.max_length = max_length
500
+
501
+ def __eq__(self, rhs):
502
+ return isinstance(rhs, self.__class__) and self.max_length == rhs.max_length
503
+
504
+ def __hash__(self):
505
+ return hash((self.__class__.__name__, self.max_length))
506
+
507
+ def __repr__(self):
508
+ """Return a string representation of this dtype."""
509
+ return f"{self.__class__.__name__}({self.max_length})"
510
+
511
+ def to_sql(self):
512
+ """Convert this Dtype to a SQL type."""
513
+ import sqlalchemy as sa
514
+
515
+ return sa.String(length=self.max_length)
516
+
517
+ def to_polars(self):
518
+ import polars as pl
519
+
520
+ return pl.Utf8
521
+
522
+ def to_arrow(self):
523
+ import pyarrow as pa
524
+
525
+ return pa.string()
443
526
 
444
527
 
445
528
  class Bool(Dtype): ...
@@ -468,15 +551,16 @@ class List(Dtype):
468
551
  return isinstance(rhs, List) and self.inner == rhs.inner
469
552
 
470
553
  def __hash__(self):
471
- return hash((0, hash(self.inner)))
554
+ return hash((self.__class__.__name__, hash(self.inner)))
472
555
 
473
556
  def __repr__(self):
474
- return f"List[{repr(self.inner)}]"
557
+ """Return a string representation of this dtype."""
558
+ return f"{self.__class__.__name__}[{self.inner}]"
475
559
 
476
560
  def to_sql(self):
477
- import sqlalchemy as sqa
561
+ import sqlalchemy as sa
478
562
 
479
- return sqa.ARRAY(self.inner.to_sql())
563
+ return sa.ARRAY(self.inner.to_sql())
480
564
 
481
565
  def to_polars(self):
482
566
  import polars as pl
@@ -494,26 +578,26 @@ class Enum(String):
494
578
  if not all(isinstance(c, str) for c in categories):
495
579
  raise TypeError("arguments for `Enum` must have type `str`")
496
580
  self.categories = list(categories)
581
+ self.max_length = (
582
+ max([len(c) for c in categories]) if len(categories) > 0 else None
583
+ )
497
584
 
498
585
  def __eq__(self, rhs):
499
586
  return isinstance(rhs, Enum) and self.categories == rhs.categories
500
587
 
501
- def __repr__(self) -> str:
502
- return f"Enum[{', '.join(repr(c) for c in self.categories)}]"
503
-
504
588
  def __hash__(self):
505
- return hash(tuple(self.categories))
589
+ return hash((self.__class__.__name__, tuple(self.categories)))
590
+
591
+ def __repr__(self) -> str:
592
+ return (
593
+ f"{self.__class__.__name__}[{', '.join(repr(c) for c in self.categories)}]"
594
+ )
506
595
 
507
596
  def to_polars(self):
508
597
  import polars as pl
509
598
 
510
599
  return pl.Enum(self.categories)
511
600
 
512
- def to_sql(self):
513
- import sqlalchemy as sqa
514
-
515
- return sqa.String()
516
-
517
601
  def to_arrow(self):
518
602
  import pyarrow as pa
519
603
 
@@ -530,6 +614,7 @@ class Enum(String):
530
614
  pa.dictionary(pa.uint32(), pa.large_string()),
531
615
  nullable=nullable,
532
616
  metadata={
533
- "_PL_ENUM_VALUES": "".join([f"{len(c)};{c}" for c in self.categories])
617
+ # the key might change with polars versions
618
+ "_PL_ENUM_VALUES2": "".join([f"{len(c)};{c}" for c in self.categories])
534
619
  },
535
620
  )
@@ -52,11 +52,14 @@ def hash_polars_dataframe(df: pl.DataFrame, use_init_repr=False) -> str:
52
52
  ]:
53
53
  df = df.with_columns(
54
54
  pl.col(struct_col_name).struct.rename_fields(
55
- [stable_hash(struct_col_name, struct_field_name)]
55
+ [stable_hash(struct_col_name, struct_dtype)]
56
56
  )
57
- for struct_col_name, struct_field_name in struct_cols_and_dtypes
57
+ for struct_col_name, struct_dtype in struct_cols_and_dtypes
58
58
  ).unnest(
59
- struct_col_name for struct_col_name, _ in struct_cols_and_dtypes
59
+ [
60
+ struct_col_name
61
+ for struct_col_name, _ in struct_cols_and_dtypes
62
+ ]
60
63
  )
61
64
  return df
62
65
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: pydiverse-common
3
- Version: 0.3.9
3
+ Version: 0.3.11
4
4
  Summary: Common functionality shared between pydiverse libraries
5
5
  Author: QuantCo, Inc.
6
6
  Author-email: Martin Trautmann <windiana@users.sf.net>, Finn Rudolph <finn.rudolph@t-online.de>
@@ -1,5 +1,5 @@
1
1
  pydiverse/common/__init__.py,sha256=J7b4iStFyaEMYre_jdlZ4l_8dLyrMWCIpQdsMQcB8aI,806
2
- pydiverse/common/dtypes.py,sha256=TN3UjI-4uBT_CB_3mLowOYq5GPcyvSbI3N6zHpxxV3c,16885
2
+ pydiverse/common/dtypes.py,sha256=QJUd4HwFLsO2uJ_3kqQZ8QwdoAk-qIhixKQ8n-x4uk0,19879
3
3
  pydiverse/common/testing.py,sha256=FcivI5wn0X3gzJhwnysKvCOgjSTTXaN6FtSFJ72jfSg,341
4
4
  pydiverse/common/version.py,sha256=1IU_m4r76_Qq0u-Tyo2_bERZFOkh0ZFueVzDqcCfLO0,336
5
5
  pydiverse/common/errors/__init__.py,sha256=FNeEfVbUa23b9sHkFsmxHYhY6sRgjaZysPQmlovpJrI,262
@@ -8,10 +8,10 @@ pydiverse/common/util/computation_tracing.py,sha256=HeXRHRUI8vxpzQ27Xcpa0StndSTP
8
8
  pydiverse/common/util/deep_map.py,sha256=JtY5ViWMMelOiLzPF7ZjzruCfB-bETISGxCk37qETxg,2540
9
9
  pydiverse/common/util/deep_merge.py,sha256=bV5p5_lsC-9nFah28EiEyG2h6U3Z5AuTqSooxOgCHN0,1929
10
10
  pydiverse/common/util/disposable.py,sha256=4XoGz70YRWA9TAqnUBvRCTAdsOGBviFN0gzxU7veY9o,993
11
- pydiverse/common/util/hashing.py,sha256=8Z1NybJ_zd3ONpn5annHGjowwArWkd2ZkCtlb3dtz_Q,4576
11
+ pydiverse/common/util/hashing.py,sha256=EofnKULVKXv-S9kry0mOqHU5bxPGomCtr6XfYqIhGgc,4650
12
12
  pydiverse/common/util/import_.py,sha256=K7dSgz4YyrqEvqhoOzbwgD7D8HScMoO5XoSWtjbaoUs,4056
13
13
  pydiverse/common/util/structlog.py,sha256=xxhauxMuyxcKXTVg1MiPTkuvPBj8Zcr4o_v8Bq59Nig,3778
14
- pydiverse_common-0.3.9.dist-info/METADATA,sha256=_P0mMant1LZMCa4p5tFBnpr0cvLIYXZXfu5xmYuih9Q,3399
15
- pydiverse_common-0.3.9.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
16
- pydiverse_common-0.3.9.dist-info/licenses/LICENSE,sha256=AcE6SDVuAq6v9ZLE_8eOCe_NvSE0rAPR3NR7lSowYh4,1517
17
- pydiverse_common-0.3.9.dist-info/RECORD,,
14
+ pydiverse_common-0.3.11.dist-info/METADATA,sha256=mg6x4saQhm43MNjOLgwLHNZWkYIzNHsTwzJvey-a8Ig,3400
15
+ pydiverse_common-0.3.11.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
16
+ pydiverse_common-0.3.11.dist-info/licenses/LICENSE,sha256=AcE6SDVuAq6v9ZLE_8eOCe_NvSE0rAPR3NR7lSowYh4,1517
17
+ pydiverse_common-0.3.11.dist-info/RECORD,,