patito 0.6.1__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
patito/polars.py CHANGED
@@ -9,6 +9,7 @@ from typing import (
9
9
  Dict,
10
10
  Generic,
11
11
  Iterable,
12
+ Literal,
12
13
  Optional,
13
14
  Sequence,
14
15
  Tuple,
@@ -19,9 +20,8 @@ from typing import (
19
20
  )
20
21
 
21
22
  import polars as pl
22
- from polars.type_aliases import IntoExpr
23
+ from polars._typing import IntoExpr
23
24
  from pydantic import AliasChoices, AliasPath, create_model
24
- from typing_extensions import Literal
25
25
 
26
26
  from patito._pydantic.column_info import ColumnInfo
27
27
  from patito.exceptions import MultipleRowsReturned, RowDoesNotExist
@@ -53,12 +53,10 @@ class LazyFrame(pl.LazyFrame, Generic[ModelType]):
53
53
  DataFrame.set_model(model) is implicitly invoked at collection.
54
54
 
55
55
  Args:
56
- ----
57
56
  model: A patito model which should be used to validate the final dataframe.
58
57
  If None is provided, the regular LazyFrame class will be returned.
59
58
 
60
59
  Returns:
61
- -------
62
60
  A custom LazyFrame model class where LazyFrame.model has been correctly
63
61
  "hard-coded" to the given model.
64
62
 
@@ -101,21 +99,17 @@ class LazyFrame(pl.LazyFrame, Generic[ModelType]):
101
99
  result of which will be used to populate the column values.
102
100
 
103
101
  Args:
104
- ----
105
102
  columns: Optionally, a list of column names to derive. If not provided, all
106
103
  columns are used.
107
104
 
108
105
  Returns:
109
- -------
110
106
  DataFrame[Model]: A new dataframe where all derivable columns are provided.
111
107
 
112
108
  Raises:
113
- ------
114
109
  TypeError: If the ``derived_from`` parameter of ``patito.Field`` is given
115
110
  as something else than a string or polars expression.
116
111
 
117
112
  Examples:
118
- --------
119
113
  >>> import patito as pt
120
114
  >>> import polars as pl
121
115
  >>> class Foo(pt.Model):
@@ -136,7 +130,7 @@ class LazyFrame(pl.LazyFrame, Generic[ModelType]):
136
130
  """
137
131
  derived_columns = []
138
132
  props = self.model._schema_properties()
139
- original_columns = set(self.columns)
133
+ original_columns = set(self.collect_schema())
140
134
  to_derive = self.model.derived_columns if columns is None else columns
141
135
  for column_name in to_derive:
142
136
  if column_name not in derived_columns:
@@ -189,8 +183,7 @@ class LazyFrame(pl.LazyFrame, Generic[ModelType]):
189
183
 
190
184
  limitation - AliasChoice validation type only supports selecting a single element of an array
191
185
 
192
- Returns
193
- -------
186
+ Returns:
194
187
  DataFrame[Model]: A dataframe with columns normalized to model names.
195
188
 
196
189
  """
@@ -200,15 +193,15 @@ class LazyFrame(pl.LazyFrame, Generic[ModelType]):
200
193
 
201
194
  def to_expr(va: str | AliasPath | AliasChoices) -> Optional[pl.Expr]:
202
195
  if isinstance(va, str):
203
- return pl.col(va) if va in self.columns else None
196
+ return pl.col(va) if va in self.collect_schema() else None
204
197
  elif isinstance(va, AliasPath):
205
198
  if len(va.path) != 2 or not isinstance(va.path[1], int):
206
199
  raise NotImplementedError(
207
200
  f"TODO figure out how this AliasPath behaves ({va})"
208
201
  )
209
202
  return (
210
- pl.col(va.path[0]).list.get(va.path[1])
211
- if va.path[0] in self.columns
203
+ pl.col(va.path[0]).list.get(va.path[1], null_on_oob=True)
204
+ if va.path[0] in self.collect_schema()
212
205
  else None
213
206
  )
214
207
  elif isinstance(va, AliasChoices):
@@ -231,7 +224,7 @@ class LazyFrame(pl.LazyFrame, Generic[ModelType]):
231
224
  exprs.append(pl.col(name))
232
225
  else:
233
226
  expr = to_expr(field_info.validation_alias)
234
- if name in self.columns:
227
+ if name in self.collect_schema().names():
235
228
  if expr is None:
236
229
  exprs.append(pl.col(name))
237
230
  else:
@@ -247,7 +240,6 @@ class LazyFrame(pl.LazyFrame, Generic[ModelType]):
247
240
  """Cast columns to `dtypes` specified by the associated Patito model.
248
241
 
249
242
  Args:
250
- ----
251
243
  strict: If set to ``False``, columns which are technically compliant with
252
244
  the specified field type, will not be casted. For example, a column
253
245
  annotated with ``int`` is technically compliant with ``pl.UInt8``, even
@@ -258,11 +250,9 @@ class LazyFrame(pl.LazyFrame, Generic[ModelType]):
258
250
  columns are casted.
259
251
 
260
252
  Returns:
261
- -------
262
253
  LazyFrame[Model]: A dataframe with columns casted to the correct dtypes.
263
254
 
264
255
  Examples:
265
- --------
266
256
  Create a simple model:
267
257
 
268
258
  >>> import patito as pt
@@ -288,9 +278,9 @@ class LazyFrame(pl.LazyFrame, Generic[ModelType]):
288
278
  properties = self.model._schema_properties()
289
279
  valid_dtypes = self.model.valid_dtypes
290
280
  default_dtypes = self.model.dtypes
291
- columns = columns or self.columns
281
+ columns = columns or self.collect_schema().names()
292
282
  exprs = []
293
- for column, current_dtype in zip(self.columns, self.dtypes):
283
+ for column, current_dtype in self.collect_schema().items():
294
284
  if (column not in columns) or (column not in properties):
295
285
  exprs.append(pl.col(column))
296
286
  elif "dtype" in properties[column]:
@@ -348,11 +338,9 @@ class DataFrame(pl.DataFrame, Generic[ModelType]):
348
338
  DataFrame.set_model(model) is implicitly invoked at instantiation.
349
339
 
350
340
  Args:
351
- ----
352
341
  model: A patito model which should be used to validate the dataframe.
353
342
 
354
343
  Returns:
355
- -------
356
344
  A custom DataFrame model class where DataFrame._model has been correctly
357
345
  "hard-coded" to the given model.
358
346
 
@@ -369,15 +357,14 @@ class DataFrame(pl.DataFrame, Generic[ModelType]):
369
357
 
370
358
  See documentation of polars.DataFrame.lazy() for full description.
371
359
 
372
- Returns
373
- -------
360
+ Returns:
374
361
  A new LazyFrame object.
375
362
 
376
363
  """
377
- lazyframe_class: LazyFrame[
378
- ModelType
379
- ] = LazyFrame._construct_lazyframe_model_class(
380
- model=getattr(self, "model", None)
364
+ lazyframe_class: LazyFrame[ModelType] = (
365
+ LazyFrame._construct_lazyframe_model_class(
366
+ model=getattr(self, "model", None)
367
+ )
381
368
  ) # type: ignore
382
369
  ldf = lazyframe_class._from_pyldf(super().lazy()._ldf)
383
370
  return ldf
@@ -392,17 +379,14 @@ class DataFrame(pl.DataFrame, Generic[ModelType]):
392
379
  ``DataFrame(...).set_model(Model)`` is equivalent with ``Model.DataFrame(...)``.
393
380
 
394
381
  Args:
395
- ----
396
382
  model (Model): Sub-class of ``patito.Model`` declaring the schema of the
397
383
  dataframe.
398
384
 
399
385
  Returns:
400
- -------
401
386
  DataFrame[Model]: Returns the same dataframe, but with an attached model
402
387
  that is required for certain model-specific dataframe methods to work.
403
388
 
404
389
  Examples:
405
- --------
406
390
  >>> from typing_extensions import Literal
407
391
  >>> import patito as pt
408
392
  >>> import polars as pl
@@ -454,8 +438,7 @@ class DataFrame(pl.DataFrame, Generic[ModelType]):
454
438
 
455
439
  limitation - AliasChoice validation type only supports selecting a single element of an array
456
440
 
457
- Returns
458
- -------
441
+ Returns:
459
442
  DataFrame[Model]: A dataframe with columns normalized to model names.
460
443
 
461
444
  """
@@ -467,7 +450,6 @@ class DataFrame(pl.DataFrame, Generic[ModelType]):
467
450
  """Cast columns to `dtypes` specified by the associated Patito model.
468
451
 
469
452
  Args:
470
- ----
471
453
  strict: If set to ``False``, columns which are technically compliant with
472
454
  the specified field type, will not be casted. For example, a column
473
455
  annotated with ``int`` is technically compliant with ``pl.UInt8``, even
@@ -478,11 +460,9 @@ class DataFrame(pl.DataFrame, Generic[ModelType]):
478
460
  columns are casted.
479
461
 
480
462
  Returns:
481
- -------
482
463
  DataFrame[Model]: A dataframe with columns casted to the correct dtypes.
483
464
 
484
465
  Examples:
485
- --------
486
466
  Create a simple model:
487
467
 
488
468
  >>> import patito as pt
@@ -519,18 +499,15 @@ class DataFrame(pl.DataFrame, Generic[ModelType]):
519
499
  :ref:`DataFrame.set_model <DataFrame.set_model>`, are dropped.
520
500
 
521
501
  Args:
522
- ----
523
502
  columns: A single column string name, or list of strings, indicating
524
503
  which columns to drop. If not specified, all columns *not*
525
504
  specified by the associated dataframe model will be dropped.
526
505
  more_columns: Additional named columns to drop.
527
506
 
528
507
  Returns:
529
- -------
530
508
  DataFrame[Model]: New dataframe without the specified columns.
531
509
 
532
510
  Examples:
533
- --------
534
511
  >>> import patito as pt
535
512
  >>> class Model(pt.Model):
536
513
  ... column_1: int
@@ -552,20 +529,16 @@ class DataFrame(pl.DataFrame, Generic[ModelType]):
552
529
  else:
553
530
  return self.drop(list(set(self.columns) - set(self.model.columns)))
554
531
 
555
- def validate(
556
- self: DF, columns: Optional[Sequence[str]] = None, **kwargs: Any
557
- ) -> DF:
532
+ def validate(self, columns: Optional[Sequence[str]] = None, **kwargs: Any):
558
533
  """Validate the schema and content of the dataframe.
559
534
 
560
535
  You must invoke ``.set_model()`` before invoking ``.validate()`` in order
561
536
  to specify how the dataframe should be validated.
562
537
 
563
- Returns
564
- -------
538
+ Returns:
565
539
  DataFrame[Model]: The original dataframe, if correctly validated.
566
540
 
567
- Raises
568
- ------
541
+ Raises:
569
542
  TypeError: If ``DataFrame.set_model()`` has not been invoked prior to
570
543
  validation. Note that ``patito.Model.DataFrame`` automatically invokes
571
544
  ``DataFrame.set_model()`` for you.
@@ -573,8 +546,7 @@ class DataFrame(pl.DataFrame, Generic[ModelType]):
573
546
  patito.exceptions.DataFrameValidationError: If the dataframe does not match the
574
547
  specified schema.
575
548
 
576
- Examples
577
- --------
549
+ Examples:
578
550
  >>> import patito as pt
579
551
 
580
552
 
@@ -621,17 +593,14 @@ class DataFrame(pl.DataFrame, Generic[ModelType]):
621
593
  column name. Alternatively, an arbitrary polars expression can be given, the
622
594
  result of which will be used to populate the column values.
623
595
 
624
- Returns
625
- -------
596
+ Returns:
626
597
  DataFrame[Model]: A new dataframe where all derivable columns are provided.
627
598
 
628
- Raises
629
- ------
599
+ Raises:
630
600
  TypeError: If the ``derived_from`` parameter of ``patito.Field`` is given
631
601
  as something else than a string or polars expression.
632
602
 
633
- Examples
634
- --------
603
+ Examples:
635
604
  >>> import patito as pt
636
605
  >>> import polars as pl
637
606
  >>> class Foo(pt.Model):
@@ -665,11 +634,10 @@ class DataFrame(pl.DataFrame, Generic[ModelType]):
665
634
  ) -> DF:
666
635
  """Fill null values using a filling strategy, literal, or ``Expr``.
667
636
 
668
- If ``"default"`` is provided as the strategy, the model fields with default
637
+ If ``"defaults"`` is provided as the strategy, the model fields with default
669
638
  values are used to fill missing values.
670
639
 
671
640
  Args:
672
- ----
673
641
  value: Value used to fill null values.
674
642
  strategy: Accepts the same arguments as ``polars.DataFrame.fill_null`` in
675
643
  addition to ``"defaults"`` which will use the field's default value if
@@ -680,12 +648,10 @@ class DataFrame(pl.DataFrame, Generic[ModelType]):
680
648
 
681
649
 
682
650
  Returns:
683
- -------
684
651
  DataFrame[Model]: A new dataframe with nulls filled in according to the
685
652
  provided ``strategy`` parameter.
686
653
 
687
654
  Example:
688
- -------
689
655
  >>> import patito as pt
690
656
  >>> class Product(pt.Model):
691
657
  ... name: str
@@ -737,7 +703,6 @@ class DataFrame(pl.DataFrame, Generic[ModelType]):
737
703
  you can use ``.get()`` without any arguments to return that row.
738
704
 
739
705
  Raises:
740
- ------
741
706
  RowDoesNotExist: If zero rows evaluate to true for the given predicate.
742
707
  MultipleRowsReturned: If more than one row evaluates to true for the given
743
708
  predicate.
@@ -746,15 +711,12 @@ class DataFrame(pl.DataFrame, Generic[ModelType]):
746
711
  same class.
747
712
 
748
713
  Args:
749
- ----
750
714
  predicate: A polars expression defining the criteria of the filter.
751
715
 
752
716
  Returns:
753
- -------
754
717
  Model: A pydantic-derived base model representing the given row.
755
718
 
756
719
  Example:
757
- -------
758
720
  >>> import patito as pt
759
721
  >>> import polars as pl
760
722
  >>> df = pt.DataFrame({"product_id": [1, 2, 3], "price": [10, 10, 20]})
@@ -819,8 +781,7 @@ class DataFrame(pl.DataFrame, Generic[ModelType]):
819
781
  def _pydantic_model(self) -> Type[Model]:
820
782
  """Dynamically construct patito model compliant with dataframe.
821
783
 
822
- Returns
823
- -------
784
+ Returns:
824
785
  A pydantic model class where all the rows have been specified as
825
786
  `typing.Any` fields.
826
787
 
@@ -853,16 +814,13 @@ class DataFrame(pl.DataFrame, Generic[ModelType]):
853
814
  to populate the given column(s).
854
815
 
855
816
  Args:
856
- ----
857
817
  *args: All positional arguments are forwarded to ``polars.read_csv``.
858
818
  **kwargs: All keyword arguments are forwarded to ``polars.read_csv``.
859
819
 
860
820
  Returns:
861
- -------
862
821
  DataFrame[Model]: A dataframe representing the given CSV file data.
863
822
 
864
823
  Examples:
865
- --------
866
824
  The ``DataFrame.read_csv`` method can be used to automatically set the
867
825
  correct column names when reading CSV files without headers.
868
826
 
@@ -907,9 +865,23 @@ class DataFrame(pl.DataFrame, Generic[ModelType]):
907
865
  # └─────┴─────┘
908
866
 
909
867
  """
910
- kwargs.setdefault("dtypes", cls.model.dtypes)
911
- if not kwargs.get("has_header", True) and "columns" not in kwargs:
868
+ kwargs.setdefault("schema_overrides", cls.model.dtypes)
869
+ has_header = kwargs.get("has_header", True)
870
+ if not has_header and "columns" not in kwargs:
912
871
  kwargs.setdefault("new_columns", cls.model.columns)
872
+ alias_gen = cls.model.model_config.get("alias_generator")
873
+ if alias_gen:
874
+ alias_func = alias_gen.validation_alias or alias_gen.alias
875
+ if has_header and alias_gen and alias_func:
876
+ fields_to_cols = {
877
+ field_name: alias_func(field_name)
878
+ for field_name in cls.model.model_fields
879
+ }
880
+ kwargs["schema_overrides"] = {
881
+ fields_to_cols.get(field, field): dtype
882
+ for field, dtype in kwargs["schema_overrides"].items()
883
+ }
884
+ # TODO: other forms of alias setting like in Field
913
885
  df = cls.model.DataFrame._from_pydf(pl.read_csv(*args, **kwargs)._df)
914
886
  return df.derive()
915
887