pydiverse-common 0.3.8__py3-none-any.whl → 0.3.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -49,20 +49,20 @@ class Dtype:
49
49
  @staticmethod
50
50
  def from_sql(sql_type) -> "Dtype":
51
51
  """Convert a SQL type to a Dtype."""
52
- import sqlalchemy as sqa
52
+ import sqlalchemy as sa
53
53
 
54
- if isinstance(sql_type, sqa.SmallInteger):
54
+ if isinstance(sql_type, sa.SmallInteger):
55
55
  return Int16()
56
- if isinstance(sql_type, sqa.BigInteger):
56
+ if isinstance(sql_type, sa.BigInteger):
57
57
  return Int64()
58
- if isinstance(sql_type, sqa.Integer):
58
+ if isinstance(sql_type, sa.Integer):
59
59
  return Int32()
60
- if isinstance(sql_type, sqa.Float):
60
+ if isinstance(sql_type, sa.Float):
61
61
  precision = sql_type.precision or 53
62
62
  if precision <= 24:
63
63
  return Float32()
64
64
  return Float64()
65
- if isinstance(sql_type, sqa.Numeric | sqa.DECIMAL):
65
+ if isinstance(sql_type, sa.Numeric | sa.DECIMAL):
66
66
  # Just to be safe, we always use FLOAT64 for fixpoint numbers.
67
67
  # Databases are obsessed about fixpoint. However, in dataframes, it
68
68
  # is more common to just work with double precision floating point.
@@ -70,21 +70,21 @@ class Dtype:
70
70
  # Decimal to Float64 whenever it cannot guarantee semantic correctness
71
71
  # otherwise.
72
72
  return Float64()
73
- if isinstance(sql_type, sqa.String):
73
+ if isinstance(sql_type, sa.String):
74
74
  return String()
75
- if isinstance(sql_type, sqa.Boolean):
75
+ if isinstance(sql_type, sa.Boolean):
76
76
  return Bool()
77
- if isinstance(sql_type, sqa.Date):
77
+ if isinstance(sql_type, sa.Date):
78
78
  return Date()
79
- if isinstance(sql_type, sqa.Time):
79
+ if isinstance(sql_type, sa.Time):
80
80
  return Time()
81
- if isinstance(sql_type, sqa.DateTime):
81
+ if isinstance(sql_type, sa.DateTime):
82
82
  return Datetime()
83
- if isinstance(sql_type, sqa.Interval):
83
+ if isinstance(sql_type, sa.Interval):
84
84
  return Duration()
85
- if isinstance(sql_type, sqa.ARRAY):
85
+ if isinstance(sql_type, sa.ARRAY):
86
86
  return List(Dtype.from_sql(sql_type.item_type))
87
- if isinstance(sql_type, sqa.types.NullType):
87
+ if isinstance(sql_type, sa.types.NullType):
88
88
  return NullType()
89
89
 
90
90
  raise TypeError
@@ -184,7 +184,7 @@ class Dtype:
184
184
  raise TypeError
185
185
  if pa.types.is_decimal(arrow_type):
186
186
  # We don't recommend using Decimal in dataframes, but we support it.
187
- return Decimal()
187
+ return Decimal(arrow_type.precision, arrow_type.scale)
188
188
  if pa.types.is_string(arrow_type):
189
189
  return String()
190
190
  if pa.types.is_boolean(arrow_type):
@@ -201,6 +201,11 @@ class Dtype:
201
201
  return NullType()
202
202
  if pa.types.is_list(arrow_type):
203
203
  return List(Dtype.from_arrow(arrow_type.value_type))
204
+ if pa.types.is_dictionary(arrow_type):
205
+ raise RuntimeError(
206
+ "Most likely this is an Enum type. But metadata about categories is "
207
+ "only in the pyarrow field and not in the pyarrow dtype"
208
+ )
204
209
  raise TypeError
205
210
 
206
211
  @staticmethod
@@ -212,6 +217,8 @@ class Dtype:
212
217
  return List(Dtype.from_polars(polars_type.inner))
213
218
  if isinstance(polars_type, pl.Enum):
214
219
  return Enum(*polars_type.categories)
220
+ if isinstance(polars_type, pl.Decimal):
221
+ return Decimal(polars_type.precision, polars_type.scale)
215
222
 
216
223
  return {
217
224
  pl.Int64: Int64(),
@@ -224,7 +231,6 @@ class Dtype:
224
231
  pl.UInt8: UInt8(),
225
232
  pl.Float64: Float64(),
226
233
  pl.Float32: Float32(),
227
- pl.Decimal: Decimal(),
228
234
  pl.Utf8: String(),
229
235
  pl.Boolean: Bool(),
230
236
  pl.Datetime: Datetime(),
@@ -236,29 +242,28 @@ class Dtype:
236
242
 
237
243
  def to_sql(self):
238
244
  """Convert this Dtype to a SQL type."""
239
- import sqlalchemy as sqa
245
+ import sqlalchemy as sa
240
246
 
241
247
  return {
242
- Int(): sqa.BigInteger(), # we default to 64 bit
243
- Int8(): sqa.SmallInteger(),
244
- Int16(): sqa.SmallInteger(),
245
- Int32(): sqa.Integer(),
246
- Int64(): sqa.BigInteger(),
247
- UInt8(): sqa.SmallInteger(),
248
- UInt16(): sqa.Integer(),
249
- UInt32(): sqa.BigInteger(),
250
- UInt64(): sqa.BigInteger(),
251
- Float(): sqa.Float(53), # we default to 64 bit
252
- Float32(): sqa.Float(24),
253
- Float64(): sqa.Float(53),
254
- Decimal(): sqa.DECIMAL(),
255
- String(): sqa.String(),
256
- Bool(): sqa.Boolean(),
257
- Date(): sqa.Date(),
258
- Time(): sqa.Time(),
259
- Datetime(): sqa.DateTime(),
260
- Duration(): sqa.Interval(),
261
- NullType(): sqa.types.NullType(),
248
+ Int(): sa.BigInteger(), # we default to 64 bit
249
+ Int8(): sa.SmallInteger(),
250
+ Int16(): sa.SmallInteger(),
251
+ Int32(): sa.Integer(),
252
+ Int64(): sa.BigInteger(),
253
+ UInt8(): sa.SmallInteger(),
254
+ UInt16(): sa.Integer(),
255
+ UInt32(): sa.BigInteger(),
256
+ UInt64(): sa.BigInteger(),
257
+ Float(): sa.Float(53), # we default to 64 bit
258
+ Float32(): sa.Float(24),
259
+ Float64(): sa.Float(53),
260
+ String(): sa.String(),
261
+ Bool(): sa.Boolean(),
262
+ Date(): sa.Date(),
263
+ Time(): sa.Time(),
264
+ Datetime(): sa.DateTime(),
265
+ Duration(): sa.Interval(),
266
+ NullType(): sa.types.NullType(),
262
267
  }[self]
263
268
 
264
269
  def to_pandas(self, backend: PandasBackend = PandasBackend.ARROW):
@@ -268,7 +273,7 @@ class Dtype:
268
273
  if backend == PandasBackend.NUMPY:
269
274
  return self.to_pandas_nullable(backend)
270
275
  if backend == PandasBackend.ARROW:
271
- if self == String():
276
+ if isinstance(self, String) or isinstance(self, Enum):
272
277
  return pd.StringDtype(storage="pyarrow")
273
278
  return pd.ArrowDtype(self.to_arrow())
274
279
 
@@ -316,7 +321,7 @@ class Dtype:
316
321
  Float(): pd.Float64Dtype(), # we default to 64 bit
317
322
  Float32(): pd.Float32Dtype(),
318
323
  Float64(): pd.Float64Dtype(),
319
- Decimal(): pd.Float64Dtype(), # NumericDtype is
324
+ Decimal(): pd.Float64Dtype(), # NumericDtype exists but is not used
320
325
  String(): pd.StringDtype(),
321
326
  Bool(): pd.BooleanDtype(),
322
327
  Date(): "datetime64[s]",
@@ -329,9 +334,6 @@ class Dtype:
329
334
  """Convert this Dtype to a PyArrow type."""
330
335
  import pyarrow as pa
331
336
 
332
- if isinstance(self, Enum):
333
- return pa.string()
334
-
335
337
  return {
336
338
  Int(): pa.int64(), # we default to 64 bit
337
339
  Int8(): pa.int8(),
@@ -345,7 +347,6 @@ class Dtype:
345
347
  Float(): pa.float64(), # we default to 64 bit
346
348
  Float32(): pa.float32(),
347
349
  Float64(): pa.float64(),
348
- Decimal(): pa.decimal128(35, 10), # Arbitrary precision
349
350
  String(): pa.string(),
350
351
  Bool(): pa.bool_(),
351
352
  Date(): pa.date32(),
@@ -355,6 +356,12 @@ class Dtype:
355
356
  NullType(): pa.null(),
356
357
  }[self]
357
358
 
359
+ def to_arrow_field(self, name: str, nullable: bool = True):
360
+ """Convert this Dtype to a PyArrow Field."""
361
+ import pyarrow as pa
362
+
363
+ return pa.field(name, self.to_arrow(), nullable=nullable)
364
+
358
365
  def to_polars(self: "Dtype"):
359
366
  """Convert this Dtype to a Polars type."""
360
367
  import polars as pl
@@ -372,8 +379,6 @@ class Dtype:
372
379
  Float(): pl.Float64, # we default to 64 bit
373
380
  Float64(): pl.Float64,
374
381
  Float32(): pl.Float32,
375
- Decimal(): pl.Decimal(scale=10), # Arbitrary precision
376
- String(): pl.Utf8,
377
382
  Bool(): pl.Boolean,
378
383
  Datetime(): pl.Datetime("us"),
379
384
  Duration(): pl.Duration("us"),
@@ -395,7 +400,43 @@ class Float64(Float): ...
395
400
  class Float32(Float): ...
396
401
 
397
402
 
398
- class Decimal(Float): ...
403
+ class Decimal(Float):
404
+ def __init__(self, precision: int | None = None, scale: int | None = None):
405
+ """
406
+ Initialize a Decimal Dtype.
407
+
408
+ Default is Decimal(31,10) which is the highest precision that works with DB2.
409
+ If you like to save memory, Decimal(15,6) will get you quite far as well.
410
+
411
+ :param precision: total number of digits in the number
412
+ If not specified, it is assumed to be 31.
413
+ :param scale: number of digits after the decimal point
414
+ If not specified, it is assumed to be (precision//3+1).
415
+ """
416
+ self.precision = precision or 31
417
+ self.scale = scale or (self.precision // 3 + 1)
418
+
419
+ def to_sql(self):
420
+ import sqlalchemy as sa
421
+
422
+ return sa.Numeric(self.precision, self.scale)
423
+
424
+ def to_polars(self):
425
+ import polars as pl
426
+
427
+ return pl.Decimal(self.precision, self.scale)
428
+
429
+ def to_arrow(self):
430
+ import pyarrow as pa
431
+
432
+ if self.precision > 38:
433
+ return pa.decimal256(self.precision, self.scale)
434
+ elif self.precision > 18:
435
+ return pa.decimal128(self.precision, self.scale)
436
+ elif self.precision > 9:
437
+ return pa.decimal64(self.precision, self.scale)
438
+ else:
439
+ return pa.decimal32(self.precision, self.scale)
399
440
 
400
441
 
401
442
  class Int(Dtype):
@@ -428,7 +469,32 @@ class UInt16(Int): ...
428
469
  class UInt8(Int): ...
429
470
 
430
471
 
431
- class String(Dtype): ...
472
+ class String(Dtype):
473
+ def __init__(self, max_length: int | None = None):
474
+ """
475
+ Initialize a String Dtype.
476
+
477
+ :param max_length: maximum length of string
478
+ This length will only be used for specifying fixed length strings in SQL.
479
+ Thus, the meaning of characters vs. bytes is dependent on the SQL dialect.
480
+ """
481
+ self.max_length = max_length
482
+
483
+ def to_sql(self):
484
+ """Convert this Dtype to a SQL type."""
485
+ import sqlalchemy as sa
486
+
487
+ return sa.String(length=self.max_length)
488
+
489
+ def to_polars(self):
490
+ import polars as pl
491
+
492
+ return pl.Utf8
493
+
494
+ def to_arrow(self):
495
+ import pyarrow as pa
496
+
497
+ return pa.string()
432
498
 
433
499
 
434
500
  class Bool(Dtype): ...
@@ -463,9 +529,9 @@ class List(Dtype):
463
529
  return f"List[{repr(self.inner)}]"
464
530
 
465
531
  def to_sql(self):
466
- import sqlalchemy as sqa
532
+ import sqlalchemy as sa
467
533
 
468
- return sqa.ARRAY(self.inner.to_sql())
534
+ return sa.ARRAY(self.inner.to_sql())
469
535
 
470
536
  def to_polars(self):
471
537
  import polars as pl
@@ -483,6 +549,9 @@ class Enum(String):
483
549
  if not all(isinstance(c, str) for c in categories):
484
550
  raise TypeError("arguments for `Enum` must have type `str`")
485
551
  self.categories = list(categories)
552
+ self.max_length = (
553
+ max([len(c) for c in categories]) if len(categories) > 0 else None
554
+ )
486
555
 
487
556
  def __eq__(self, rhs):
488
557
  return isinstance(rhs, Enum) and self.categories == rhs.categories
@@ -498,14 +567,23 @@ class Enum(String):
498
567
 
499
568
  return pl.Enum(self.categories)
500
569
 
501
- def to_sql(self):
502
- import sqlalchemy as sqa
503
-
504
- return sqa.String()
505
-
506
570
  def to_arrow(self):
507
571
  import pyarrow as pa
508
572
 
509
- # There is also pa.dictionary(), which seems to be kind of similar to an enum.
510
- # Maybe it is better to convert to this.
573
+ # enum categories can only be maintained in pyarrow field (see to_arrow_field)
511
574
  return pa.string()
575
+
576
+ def to_arrow_field(self, name: str, nullable: bool = True):
577
+ """Convert this Dtype to a PyArrow Field."""
578
+ import pyarrow as pa
579
+
580
+ # try to mimic what polars does
581
+ return pa.field(
582
+ name,
583
+ pa.dictionary(pa.uint32(), pa.large_string()),
584
+ nullable=nullable,
585
+ metadata={
586
+ # the key might change with polars versions
587
+ "_PL_ENUM_VALUES2": "".join([f"{len(c)};{c}" for c in self.categories])
588
+ },
589
+ )
@@ -52,11 +52,14 @@ def hash_polars_dataframe(df: pl.DataFrame, use_init_repr=False) -> str:
52
52
  ]:
53
53
  df = df.with_columns(
54
54
  pl.col(struct_col_name).struct.rename_fields(
55
- [stable_hash(struct_col_name, struct_field_name)]
55
+ [stable_hash(struct_col_name, struct_dtype)]
56
56
  )
57
- for struct_col_name, struct_field_name in struct_cols_and_dtypes
57
+ for struct_col_name, struct_dtype in struct_cols_and_dtypes
58
58
  ).unnest(
59
- struct_col_name for struct_col_name, _ in struct_cols_and_dtypes
59
+ [
60
+ struct_col_name
61
+ for struct_col_name, _ in struct_cols_and_dtypes
62
+ ]
60
63
  )
61
64
  return df
62
65
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: pydiverse-common
3
- Version: 0.3.8
3
+ Version: 0.3.10
4
4
  Summary: Common functionality shared between pydiverse libraries
5
5
  Author: QuantCo, Inc.
6
6
  Author-email: Martin Trautmann <windiana@users.sf.net>, Finn Rudolph <finn.rudolph@t-online.de>
@@ -1,5 +1,5 @@
1
1
  pydiverse/common/__init__.py,sha256=J7b4iStFyaEMYre_jdlZ4l_8dLyrMWCIpQdsMQcB8aI,806
2
- pydiverse/common/dtypes.py,sha256=LYZKaKYq_4uI4kUhoaCTTo5j1SRurswIOfN11Bkz25A,15986
2
+ pydiverse/common/dtypes.py,sha256=ekeH6Ygnfy6sTsMVG4CcEdGHH_C6Z550tkdkaMQNIyY,18782
3
3
  pydiverse/common/testing.py,sha256=FcivI5wn0X3gzJhwnysKvCOgjSTTXaN6FtSFJ72jfSg,341
4
4
  pydiverse/common/version.py,sha256=1IU_m4r76_Qq0u-Tyo2_bERZFOkh0ZFueVzDqcCfLO0,336
5
5
  pydiverse/common/errors/__init__.py,sha256=FNeEfVbUa23b9sHkFsmxHYhY6sRgjaZysPQmlovpJrI,262
@@ -8,10 +8,10 @@ pydiverse/common/util/computation_tracing.py,sha256=HeXRHRUI8vxpzQ27Xcpa0StndSTP
8
8
  pydiverse/common/util/deep_map.py,sha256=JtY5ViWMMelOiLzPF7ZjzruCfB-bETISGxCk37qETxg,2540
9
9
  pydiverse/common/util/deep_merge.py,sha256=bV5p5_lsC-9nFah28EiEyG2h6U3Z5AuTqSooxOgCHN0,1929
10
10
  pydiverse/common/util/disposable.py,sha256=4XoGz70YRWA9TAqnUBvRCTAdsOGBviFN0gzxU7veY9o,993
11
- pydiverse/common/util/hashing.py,sha256=8Z1NybJ_zd3ONpn5annHGjowwArWkd2ZkCtlb3dtz_Q,4576
11
+ pydiverse/common/util/hashing.py,sha256=EofnKULVKXv-S9kry0mOqHU5bxPGomCtr6XfYqIhGgc,4650
12
12
  pydiverse/common/util/import_.py,sha256=K7dSgz4YyrqEvqhoOzbwgD7D8HScMoO5XoSWtjbaoUs,4056
13
13
  pydiverse/common/util/structlog.py,sha256=xxhauxMuyxcKXTVg1MiPTkuvPBj8Zcr4o_v8Bq59Nig,3778
14
- pydiverse_common-0.3.8.dist-info/METADATA,sha256=ptAGp299BY9NSaM-XEaojLzhL_KVc0SEY-MFqqqAwL0,3399
15
- pydiverse_common-0.3.8.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
16
- pydiverse_common-0.3.8.dist-info/licenses/LICENSE,sha256=AcE6SDVuAq6v9ZLE_8eOCe_NvSE0rAPR3NR7lSowYh4,1517
17
- pydiverse_common-0.3.8.dist-info/RECORD,,
14
+ pydiverse_common-0.3.10.dist-info/METADATA,sha256=8xmLOqUJZbuETqz49ve_W8NNVkHGyaeOwT2exOhk0uY,3400
15
+ pydiverse_common-0.3.10.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
16
+ pydiverse_common-0.3.10.dist-info/licenses/LICENSE,sha256=AcE6SDVuAq6v9ZLE_8eOCe_NvSE0rAPR3NR7lSowYh4,1517
17
+ pydiverse_common-0.3.10.dist-info/RECORD,,