patito 0.6.2__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
patito/polars.py CHANGED
@@ -2,25 +2,18 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
+ from collections.abc import Collection, Iterable, Iterator, Sequence
5
6
  from typing import (
6
7
  TYPE_CHECKING,
7
8
  Any,
8
- Collection,
9
- Dict,
10
9
  Generic,
11
- Iterable,
12
10
  Literal,
13
- Optional,
14
- Sequence,
15
- Tuple,
16
- Type,
17
11
  TypeVar,
18
- Union,
19
12
  cast,
20
13
  )
21
14
 
22
15
  import polars as pl
23
- from polars.type_aliases import IntoExpr
16
+ from polars._typing import IntoExpr
24
17
  from pydantic import AliasChoices, AliasPath, create_model
25
18
 
26
19
  from patito._pydantic.column_info import ColumnInfo
@@ -31,22 +24,40 @@ if TYPE_CHECKING:
31
24
 
32
25
  from patito.pydantic import Model
33
26
 
34
-
35
27
  DF = TypeVar("DF", bound="DataFrame")
36
28
  LDF = TypeVar("LDF", bound="LazyFrame")
37
29
  ModelType = TypeVar("ModelType", bound="Model")
38
30
  OtherModelType = TypeVar("OtherModelType", bound="Model")
31
+ T = TypeVar("T")
32
+
33
+
34
+ class ModelGenerator(Iterator[ModelType], Generic[ModelType]):
35
+ """An iterator that can be converted to a list."""
36
+
37
+ def __init__(self, iterator: Iterator[ModelType]) -> None:
38
+ """Construct a ModelGenerator from an iterator."""
39
+ self._iterator = iterator
40
+
41
+ def to_list(self) -> list[ModelType]:
42
+ """Convert iterator to list."""
43
+ return list(self)
44
+
45
+ def __next__(self) -> ModelType: # noqa: D105
46
+ return next(self._iterator)
47
+
48
+ def __iter__(self) -> Iterator[ModelType]: # noqa: D105
49
+ return self
39
50
 
40
51
 
41
52
  class LazyFrame(pl.LazyFrame, Generic[ModelType]):
42
53
  """LazyFrame class associated to DataFrame."""
43
54
 
44
- model: Type[ModelType]
55
+ model: type[ModelType]
45
56
 
46
57
  @classmethod
47
58
  def _construct_lazyframe_model_class(
48
- cls: Type[LDF], model: Optional[Type[ModelType]]
49
- ) -> Type[LazyFrame[ModelType]]:
59
+ cls: type[LDF], model: type[ModelType] | None
60
+ ) -> type[LazyFrame[ModelType]]:
50
61
  """Return custom LazyFrame sub-class where LazyFrame.model is set.
51
62
 
52
63
  Can be used to construct a LazyFrame class where
@@ -75,7 +86,7 @@ class LazyFrame(pl.LazyFrame, Generic[ModelType]):
75
86
  self,
76
87
  *args,
77
88
  **kwargs,
78
- ) -> "DataFrame[ModelType]": # noqa: DAR101, DAR201
89
+ ) -> DataFrame[ModelType]: # noqa: DAR101, DAR201
79
90
  """Collect into a DataFrame.
80
91
 
81
92
  See documentation of polars.DataFrame.collect for full description of
@@ -130,7 +141,7 @@ class LazyFrame(pl.LazyFrame, Generic[ModelType]):
130
141
  """
131
142
  derived_columns = []
132
143
  props = self.model._schema_properties()
133
- original_columns = set(self.columns)
144
+ original_columns = set(self.collect_schema())
134
145
  to_derive = self.model.derived_columns if columns is None else columns
135
146
  for column_name in to_derive:
136
147
  if column_name not in derived_columns:
@@ -148,33 +159,35 @@ class LazyFrame(pl.LazyFrame, Generic[ModelType]):
148
159
 
149
160
  def _derive_column(
150
161
  self,
151
- df: LDF,
162
+ lf: LDF,
152
163
  column_name: str,
153
- column_infos: Dict[str, ColumnInfo],
154
- ) -> Tuple[LDF, Sequence[str]]:
164
+ column_infos: dict[str, ColumnInfo],
165
+ ) -> tuple[LDF, Sequence[str]]:
155
166
  if (
156
167
  column_infos.get(column_name, None) is None
157
168
  or column_infos[column_name].derived_from is None
158
169
  ):
159
- return df, []
170
+ return lf, []
171
+
160
172
  derived_from = column_infos[column_name].derived_from
161
173
  dtype = self.model.dtypes[column_name]
162
174
  derived_columns = []
175
+
163
176
  if isinstance(derived_from, str):
164
- df = df.with_columns(pl.col(derived_from).cast(dtype).alias(column_name))
177
+ lf = lf.with_columns(pl.col(derived_from).cast(dtype).alias(column_name))
165
178
  elif isinstance(derived_from, pl.Expr):
166
179
  root_cols = derived_from.meta.root_names()
167
180
  while root_cols:
168
181
  root_col = root_cols.pop()
169
- df, _derived_columns = self._derive_column(df, root_col, column_infos)
182
+ lf, _derived_columns = self._derive_column(lf, root_col, column_infos)
170
183
  derived_columns.extend(_derived_columns)
171
- df = df.with_columns(derived_from.cast(dtype).alias(column_name))
184
+ lf = lf.with_columns(derived_from.cast(dtype).alias(column_name))
172
185
  else:
173
186
  raise TypeError(
174
187
  "Can not derive dataframe column from type " f"{type(derived_from)}."
175
188
  )
176
189
  derived_columns.append(column_name)
177
- return df, derived_columns
190
+ return lf, derived_columns
178
191
 
179
192
  def unalias(self: LDF) -> LDF:
180
193
  """Un-aliases column names using information from pydantic validation_alias.
@@ -191,21 +204,21 @@ class LazyFrame(pl.LazyFrame, Generic[ModelType]):
191
204
  return self
192
205
  exprs = []
193
206
 
194
- def to_expr(va: str | AliasPath | AliasChoices) -> Optional[pl.Expr]:
207
+ def to_expr(va: str | AliasPath | AliasChoices) -> pl.Expr | None:
195
208
  if isinstance(va, str):
196
- return pl.col(va) if va in self.columns else None
209
+ return pl.col(va) if va in self.collect_schema() else None
197
210
  elif isinstance(va, AliasPath):
198
211
  if len(va.path) != 2 or not isinstance(va.path[1], int):
199
212
  raise NotImplementedError(
200
213
  f"TODO figure out how this AliasPath behaves ({va})"
201
214
  )
202
215
  return (
203
- pl.col(va.path[0]).list.get(va.path[1])
204
- if va.path[0] in self.columns
216
+ pl.col(va.path[0]).list.get(va.path[1], null_on_oob=True)
217
+ if va.path[0] in self.collect_schema()
205
218
  else None
206
219
  )
207
220
  elif isinstance(va, AliasChoices):
208
- local_expr: Optional[pl.Expr] = None
221
+ local_expr: pl.Expr | None = None
209
222
  for choice in va.choices:
210
223
  if (part := to_expr(choice)) is not None:
211
224
  local_expr = (
@@ -224,7 +237,7 @@ class LazyFrame(pl.LazyFrame, Generic[ModelType]):
224
237
  exprs.append(pl.col(name))
225
238
  else:
226
239
  expr = to_expr(field_info.validation_alias)
227
- if name in self.columns:
240
+ if name in self.collect_schema().names():
228
241
  if expr is None:
229
242
  exprs.append(pl.col(name))
230
243
  else:
@@ -235,7 +248,7 @@ class LazyFrame(pl.LazyFrame, Generic[ModelType]):
235
248
  return self.select(exprs)
236
249
 
237
250
  def cast(
238
- self: LDF, strict: bool = False, columns: Optional[Sequence[str]] = None
251
+ self: LDF, strict: bool = False, columns: Sequence[str] | None = None
239
252
  ) -> LDF:
240
253
  """Cast columns to `dtypes` specified by the associated Patito model.
241
254
 
@@ -278,9 +291,9 @@ class LazyFrame(pl.LazyFrame, Generic[ModelType]):
278
291
  properties = self.model._schema_properties()
279
292
  valid_dtypes = self.model.valid_dtypes
280
293
  default_dtypes = self.model.dtypes
281
- columns = columns or self.columns
294
+ columns = columns or self.collect_schema().names()
282
295
  exprs = []
283
- for column, current_dtype in zip(self.columns, self.dtypes):
296
+ for column, current_dtype in self.collect_schema().items():
284
297
  if (column not in columns) or (column not in properties):
285
298
  exprs.append(pl.col(column))
286
299
  elif "dtype" in properties[column]:
@@ -292,7 +305,7 @@ class LazyFrame(pl.LazyFrame, Generic[ModelType]):
292
305
  return self.with_columns(exprs)
293
306
 
294
307
  @classmethod
295
- def from_existing(cls: Type[LDF], lf: pl.LazyFrame) -> LDF:
308
+ def from_existing(cls: type[LDF], lf: pl.LazyFrame) -> LDF:
296
309
  """Construct a patito.DataFrame object from an existing polars.DataFrame object."""
297
310
  return cls.model.LazyFrame._from_pyldf(lf._ldf).cast()
298
311
 
@@ -326,12 +339,12 @@ class DataFrame(pl.DataFrame, Generic[ModelType]):
326
339
  :ref:`Product.validate <DataFrame.validate>`.
327
340
  """
328
341
 
329
- model: Type[ModelType]
342
+ model: type[ModelType]
330
343
 
331
344
  @classmethod
332
345
  def _construct_dataframe_model_class(
333
- cls: Type[DF], model: Type[OtherModelType]
334
- ) -> Type[DataFrame[OtherModelType]]:
346
+ cls: type[DF], model: type[OtherModelType]
347
+ ) -> type[DataFrame[OtherModelType]]:
335
348
  """Return custom DataFrame sub-class where DataFrame.model is set.
336
349
 
337
350
  Can be used to construct a DataFrame class where
@@ -445,7 +458,7 @@ class DataFrame(pl.DataFrame, Generic[ModelType]):
445
458
  return self.lazy().unalias().collect()
446
459
 
447
460
  def cast(
448
- self: DF, strict: bool = False, columns: Optional[Sequence[str]] = None
461
+ self: DF, strict: bool = False, columns: Sequence[str] | None = None
449
462
  ) -> DF:
450
463
  """Cast columns to `dtypes` specified by the associated Patito model.
451
464
 
@@ -489,7 +502,7 @@ class DataFrame(pl.DataFrame, Generic[ModelType]):
489
502
 
490
503
  def drop(
491
504
  self: DF,
492
- columns: Optional[Union[str, Collection[str]]] = None,
505
+ columns: str | Collection[str] | None = None,
493
506
  *more_columns: str,
494
507
  ) -> DF:
495
508
  """Drop one or more columns from the dataframe.
@@ -529,23 +542,23 @@ class DataFrame(pl.DataFrame, Generic[ModelType]):
529
542
  else:
530
543
  return self.drop(list(set(self.columns) - set(self.model.columns)))
531
544
 
532
- def validate(self, columns: Optional[Sequence[str]] = None, **kwargs: Any):
545
+ def validate(self, columns: Sequence[str] | None = None, **kwargs: Any):
533
546
  """Validate the schema and content of the dataframe.
534
547
 
535
548
  You must invoke ``.set_model()`` before invoking ``.validate()`` in order
536
549
  to specify how the dataframe should be validated.
537
550
 
538
551
  Returns:
539
- DataFrame[Model]: The original dataframe, if correctly validated.
552
+ DataFrame[Model]: The original patito dataframe, if correctly validated.
540
553
 
541
554
  Raises:
555
+ patito.exceptions.DataFrameValidationError: If the dataframe does not match the
556
+ specified schema.
557
+
542
558
  TypeError: If ``DataFrame.set_model()`` has not been invoked prior to
543
559
  validation. Note that ``patito.Model.DataFrame`` automatically invokes
544
560
  ``DataFrame.set_model()`` for you.
545
561
 
546
- patito.exceptions.DataFrameValidationError: If the dataframe does not match the
547
- specified schema.
548
-
549
562
  Examples:
550
563
  >>> import patito as pt
551
564
 
@@ -623,13 +636,12 @@ class DataFrame(pl.DataFrame, Generic[ModelType]):
623
636
 
624
637
  def fill_null(
625
638
  self: DF,
626
- value: Optional[Any] = None,
627
- strategy: Optional[
628
- Literal[
629
- "forward", "backward", "min", "max", "mean", "zero", "one", "defaults"
630
- ]
631
- ] = None,
632
- limit: Optional[int] = None,
639
+ value: Any | None = None,
640
+ strategy: Literal[
641
+ "forward", "backward", "min", "max", "mean", "zero", "one", "defaults"
642
+ ]
643
+ | None = None,
644
+ limit: int | None = None,
633
645
  matches_supertype: bool = True,
634
646
  ) -> DF:
635
647
  """Fill null values using a filling strategy, literal, or ``Expr``.
@@ -689,14 +701,13 @@ class DataFrame(pl.DataFrame, Generic[ModelType]):
689
701
  pl.lit(default_value, self.model.dtypes[column])
690
702
  )
691
703
  if column in self.columns
692
- else pl.Series(column, [default_value], self.model.dtypes[column])
693
- ) # NOTE: hack to get around polars bug https://github.com/pola-rs/polars/issues/13602
694
- # else pl.lit(default_value, self.model.dtypes[column]).alias(column)
704
+ else pl.lit(default_value, self.model.dtypes[column]).alias(column)
705
+ )
695
706
  for column, default_value in self.model.defaults.items()
696
707
  ]
697
708
  ).set_model(self.model)
698
709
 
699
- def get(self, predicate: Optional[pl.Expr] = None) -> ModelType:
710
+ def get(self, predicate: pl.Expr | None = None) -> ModelType:
700
711
  """Fetch the single row that matches the given polars predicate.
701
712
 
702
713
  If you expect a data frame to already consist of one single row,
@@ -778,7 +789,57 @@ class DataFrame(pl.DataFrame, Generic[ModelType]):
778
789
  else:
779
790
  return self._pydantic_model().from_row(row) # type: ignore
780
791
 
781
- def _pydantic_model(self) -> Type[Model]:
792
+ def iter_models(
793
+ self, validate_df: bool = True, validate_model: bool = False
794
+ ) -> ModelGenerator[ModelType]:
795
+ """Iterate over all rows in the dataframe as pydantic models.
796
+
797
+ Args:
798
+ validate_df: If set to ``True``, the dataframe will be validated before
799
+ making models out of each row. If set to ``False``, beware that columns
800
+ need to be the exact same as the model fields.
801
+ validate_model: If set to ``True``, each model will be validated when
802
+ constructing. Disabled by default since df validation should cover this case.
803
+
804
+ Yields:
805
+ Model: A pydantic-derived model representing the given row. .to_list() can be
806
+ used to convert the iterator to a list.
807
+
808
+ Raises:
809
+ TypeError: If ``DataFrame.set_model()`` has not been invoked prior to
810
+ iteration.
811
+
812
+ Example:
813
+ >>> import patito as pt
814
+ >>> import polars as pl
815
+ >>> class Product(pt.Model):
816
+ ... product_id: int = pt.Field(unique=True)
817
+ ... price: float
818
+ ...
819
+ >>> df = pt.DataFrame({"product_id": [1, 2], "price": [10., 20.]})
820
+ >>> df = df.set_model(Product)
821
+ >>> for product in df.iter_models():
822
+ ... print(product)
823
+ ...
824
+ Product(product_id=1, price=10.0)
825
+ Product(product_id=2, price=20.0)
826
+
827
+ """
828
+ if not hasattr(self, "model"):
829
+ raise TypeError(
830
+ f"You must invoke {self.__class__.__name__}.set_model() "
831
+ f"before invoking {self.__class__.__name__}.iter_models()."
832
+ )
833
+
834
+ df = self.validate(drop_superfluous_columns=True) if validate_df else self
835
+
836
+ def _iter_models(_df: DF) -> Iterator[ModelType]:
837
+ for idx in range(_df.height):
838
+ yield self.model.from_row(_df[idx], validate=validate_model)
839
+
840
+ return ModelGenerator(_iter_models(df))
841
+
842
+ def _pydantic_model(self) -> type[Model]:
782
843
  """Dynamically construct patito model compliant with dataframe.
783
844
 
784
845
  Returns:
@@ -790,7 +851,7 @@ class DataFrame(pl.DataFrame, Generic[ModelType]):
790
851
 
791
852
  pydantic_annotations = {column: (Any, ...) for column in self.columns}
792
853
  return cast(
793
- Type[Model],
854
+ type[Model],
794
855
  create_model( # type: ignore
795
856
  "UntypedRow",
796
857
  __base__=Model,
@@ -804,7 +865,7 @@ class DataFrame(pl.DataFrame, Generic[ModelType]):
804
865
 
805
866
  @classmethod
806
867
  def read_csv( # type: ignore[no-untyped-def]
807
- cls: Type[DF],
868
+ cls: type[DF],
808
869
  *args, # noqa: ANN002
809
870
  **kwargs, # noqa: ANN003
810
871
  ) -> DF:
@@ -865,7 +926,7 @@ class DataFrame(pl.DataFrame, Generic[ModelType]):
865
926
  # └─────┴─────┘
866
927
 
867
928
  """
868
- kwargs.setdefault("dtypes", cls.model.dtypes)
929
+ kwargs.setdefault("schema_overrides", cls.model.dtypes)
869
930
  has_header = kwargs.get("has_header", True)
870
931
  if not has_header and "columns" not in kwargs:
871
932
  kwargs.setdefault("new_columns", cls.model.columns)
@@ -877,9 +938,9 @@ class DataFrame(pl.DataFrame, Generic[ModelType]):
877
938
  field_name: alias_func(field_name)
878
939
  for field_name in cls.model.model_fields
879
940
  }
880
- kwargs["dtypes"] = {
941
+ kwargs["schema_overrides"] = {
881
942
  fields_to_cols.get(field, field): dtype
882
- for field, dtype in kwargs["dtypes"].items()
943
+ for field, dtype in kwargs["schema_overrides"].items()
883
944
  }
884
945
  # TODO: other forms of alias setting like in Field
885
946
  df = cls.model.DataFrame._from_pydf(pl.read_csv(*args, **kwargs)._df)
@@ -888,15 +949,13 @@ class DataFrame(pl.DataFrame, Generic[ModelType]):
888
949
  # --- Type annotation overrides ---
889
950
  def filter( # noqa: D102
890
951
  self: DF,
891
- predicate: Union[
892
- pl.Expr, str, pl.Series, list[bool], np.ndarray[Any, Any], bool
893
- ],
952
+ predicate: pl.Expr | str | pl.Series | list[bool] | np.ndarray[Any, Any] | bool,
894
953
  ) -> DF:
895
954
  return cast(DF, super().filter(predicate))
896
955
 
897
956
  def select( # noqa: D102
898
957
  self: DF,
899
- *exprs: Union[IntoExpr, Iterable[IntoExpr]],
958
+ *exprs: IntoExpr | Iterable[IntoExpr],
900
959
  **named_exprs: IntoExpr,
901
960
  ) -> DF:
902
961
  return cast( # pyright: ignore[redundant-cast]
@@ -905,7 +964,7 @@ class DataFrame(pl.DataFrame, Generic[ModelType]):
905
964
 
906
965
  def with_columns( # noqa: D102
907
966
  self: DF,
908
- *exprs: Union[IntoExpr, Iterable[IntoExpr]],
967
+ *exprs: IntoExpr | Iterable[IntoExpr],
909
968
  **named_exprs: IntoExpr,
910
969
  ) -> DF:
911
970
  return cast(DF, super().with_columns(*exprs, **named_exprs))