patito 0.5.1__py3-none-any.whl → 0.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
patito/polars.py CHANGED
@@ -1,13 +1,18 @@
1
1
  """Logic related to the wrapping of the polars data frame library."""
2
+
2
3
  from __future__ import annotations
3
4
 
4
5
  from typing import (
5
6
  TYPE_CHECKING,
6
7
  Any,
7
8
  Collection,
9
+ Dict,
8
10
  Generic,
9
11
  Iterable,
12
+ Literal,
10
13
  Optional,
14
+ Sequence,
15
+ Tuple,
11
16
  Type,
12
17
  TypeVar,
13
18
  Union,
@@ -16,9 +21,9 @@ from typing import (
16
21
 
17
22
  import polars as pl
18
23
  from polars.type_aliases import IntoExpr
19
- from pydantic import create_model
20
- from typing_extensions import Literal
24
+ from pydantic import AliasChoices, AliasPath, create_model
21
25
 
26
+ from patito._pydantic.column_info import ColumnInfo
22
27
  from patito.exceptions import MultipleRowsReturned, RowDoesNotExist
23
28
 
24
29
  if TYPE_CHECKING:
@@ -42,8 +47,7 @@ class LazyFrame(pl.LazyFrame, Generic[ModelType]):
42
47
  def _construct_lazyframe_model_class(
43
48
  cls: Type[LDF], model: Optional[Type[ModelType]]
44
49
  ) -> Type[LazyFrame[ModelType]]:
45
- """
46
- Return custom LazyFrame sub-class where LazyFrame.model is set.
50
+ """Return custom LazyFrame sub-class where LazyFrame.model is set.
47
51
 
48
52
  Can be used to construct a LazyFrame class where
49
53
  DataFrame.set_model(model) is implicitly invoked at collection.
@@ -55,12 +59,13 @@ class LazyFrame(pl.LazyFrame, Generic[ModelType]):
55
59
  Returns:
56
60
  A custom LazyFrame model class where LazyFrame.model has been correctly
57
61
  "hard-coded" to the given model.
62
+
58
63
  """
59
64
  if model is None:
60
65
  return cls
61
66
 
62
67
  new_class = type(
63
- f"{model.schema()['title']}LazyFrame",
68
+ f"{model.__name__}LazyFrame",
64
69
  (cls,),
65
70
  {"model": model},
66
71
  )
@@ -68,41 +73,232 @@ class LazyFrame(pl.LazyFrame, Generic[ModelType]):
68
73
 
69
74
  def collect(
70
75
  self,
71
- type_coercion: bool = True,
72
- predicate_pushdown: bool = True,
73
- projection_pushdown: bool = True,
74
- simplify_expression: bool = True,
75
- no_optimization: bool = False,
76
- slice_pushdown: bool = True,
77
- common_subplan_elimination: bool = True,
78
- streaming: bool = False,
76
+ *args,
77
+ **kwargs,
79
78
  ) -> "DataFrame[ModelType]": # noqa: DAR101, DAR201
80
- """
81
- Collect into a DataFrame.
79
+ """Collect into a DataFrame.
82
80
 
83
81
  See documentation of polars.LazyFrame.collect for full description of
84
82
  parameters.
85
83
  """
86
- df = super().collect(
87
- type_coercion=type_coercion,
88
- predicate_pushdown=predicate_pushdown,
89
- projection_pushdown=projection_pushdown,
90
- simplify_expression=simplify_expression,
91
- no_optimization=no_optimization,
92
- slice_pushdown=slice_pushdown,
93
- common_subplan_elimination=common_subplan_elimination,
94
- streaming=streaming,
95
- )
84
+ background = kwargs.pop("background", False)
85
+ df = super().collect(*args, background=background, **kwargs)
96
86
  if getattr(self, "model", False):
97
87
  cls = DataFrame._construct_dataframe_model_class(model=self.model)
98
88
  else:
99
89
  cls = DataFrame
100
90
  return cls._from_pydf(df._df)
101
91
 
92
+ def derive(self: LDF, columns: list[str] | None = None) -> LDF:
93
+ """Populate columns which have ``pt.Field(derived_from=...)`` definitions.
94
+
95
+ If a column field on the data frame model has ``patito.Field(derived_from=...)``
96
+ specified, the given value will be used to define the column. If
97
+ ``derived_from`` is set to a string, the column will be derived from the given
98
+ column name. Alternatively, an arbitrary polars expression can be given, the
99
+ result of which will be used to populate the column values.
100
+
101
+ Args:
102
+ columns: Optionally, a list of column names to derive. If not provided, all
103
+ columns are used.
104
+
105
+ Returns:
106
+ LazyFrame[Model]: A new lazy frame where all derivable columns are provided.
107
+
108
+ Raises:
109
+ TypeError: If the ``derived_from`` parameter of ``patito.Field`` is given
110
+ as something else than a string or polars expression.
111
+
112
+ Examples:
113
+ >>> import patito as pt
114
+ >>> import polars as pl
115
+ >>> class Foo(pt.Model):
116
+ ... bar: int = pt.Field(derived_from="foo")
117
+ ... double_bar: int = pt.Field(derived_from=2 * pl.col("bar"))
118
+ ...
119
+ >>> Foo.DataFrame({"foo": [1, 2]}).derive()
120
+ shape: (2, 3)
121
+ ┌─────┬────────────┬─────┐
122
+ │ bar ┆ double_bar ┆ foo │
123
+ │ --- ┆ --- ┆ --- │
124
+ │ i64 ┆ i64 ┆ i64 │
125
+ ╞═════╪════════════╪═════╡
126
+ │ 1 ┆ 2 ┆ 1 │
127
+ │ 2 ┆ 4 ┆ 2 │
128
+ └─────┴────────────┴─────┘
129
+
130
+ """
131
+ derived_columns = []
132
+ props = self.model._schema_properties()
133
+ original_columns = set(self.columns)
134
+ to_derive = self.model.derived_columns if columns is None else columns
135
+ for column_name in to_derive:
136
+ if column_name not in derived_columns:
137
+ self, _derived_columns = self._derive_column(
138
+ self, column_name, self.model.column_infos
139
+ )
140
+ derived_columns.extend(_derived_columns)
141
+ out_cols = [
142
+ x for x in props if x in original_columns.union(to_derive)
143
+ ] # ensure that model columns are first and in the correct order
144
+ out_cols += [
145
+ x for x in original_columns.union(to_derive) if x not in out_cols
146
+ ] # collect columns originally in data frame that are not in the model and append to end of df
147
+ return self.select(out_cols)
148
+
149
+ def _derive_column(
150
+ self,
151
+ df: LDF,
152
+ column_name: str,
153
+ column_infos: Dict[str, ColumnInfo],
154
+ ) -> Tuple[LDF, Sequence[str]]:
155
+ if (
156
+ column_infos.get(column_name, None) is None
157
+ or column_infos[column_name].derived_from is None
158
+ ):
159
+ return df, []
160
+ derived_from = column_infos[column_name].derived_from
161
+ dtype = self.model.dtypes[column_name]
162
+ derived_columns = []
163
+ if isinstance(derived_from, str):
164
+ df = df.with_columns(pl.col(derived_from).cast(dtype).alias(column_name))
165
+ elif isinstance(derived_from, pl.Expr):
166
+ root_cols = derived_from.meta.root_names()
167
+ while root_cols:
168
+ root_col = root_cols.pop()
169
+ df, _derived_columns = self._derive_column(df, root_col, column_infos)
170
+ derived_columns.extend(_derived_columns)
171
+ df = df.with_columns(derived_from.cast(dtype).alias(column_name))
172
+ else:
173
+ raise TypeError(
174
+ "Can not derive dataframe column from type " f"{type(derived_from)}."
175
+ )
176
+ derived_columns.append(column_name)
177
+ return df, derived_columns
178
+
179
+ def unalias(self: LDF) -> LDF:
180
+ """Un-aliases column names using information from pydantic validation_alias.
181
+
182
+ Values are taken in order of preference: the model field name first, then the validation aliases in order of occurrence.
183
+
184
+ Limitation: an ``AliasPath`` validation alias is only supported when it selects a single element of a list column, i.e. a column name followed by one integer index.
185
+
186
+ Returns:
187
+ LazyFrame[Model]: A lazy frame with columns normalized to model names.
188
+
189
+ """
190
+ if not any(fi.validation_alias for fi in self.model.model_fields.values()):
191
+ return self
192
+ exprs = []
193
+
194
+ def to_expr(va: str | AliasPath | AliasChoices) -> Optional[pl.Expr]:
195
+ if isinstance(va, str):
196
+ return pl.col(va) if va in self.columns else None
197
+ elif isinstance(va, AliasPath):
198
+ if len(va.path) != 2 or not isinstance(va.path[1], int):
199
+ raise NotImplementedError(
200
+ f"TODO figure out how this AliasPath behaves ({va})"
201
+ )
202
+ return (
203
+ pl.col(va.path[0]).list.get(va.path[1])
204
+ if va.path[0] in self.columns
205
+ else None
206
+ )
207
+ elif isinstance(va, AliasChoices):
208
+ local_expr: Optional[pl.Expr] = None
209
+ for choice in va.choices:
210
+ if (part := to_expr(choice)) is not None:
211
+ local_expr = (
212
+ local_expr.fill_null(value=part)
213
+ if local_expr is not None
214
+ else part
215
+ )
216
+ return local_expr
217
+ else:
218
+ raise NotImplementedError(
219
+ f"unknown validation_alias type {field_info.validation_alias}"
220
+ )
221
+
222
+ for name, field_info in self.model.model_fields.items():
223
+ if field_info.validation_alias is None:
224
+ exprs.append(pl.col(name))
225
+ else:
226
+ expr = to_expr(field_info.validation_alias)
227
+ if name in self.columns:
228
+ if expr is None:
229
+ exprs.append(pl.col(name))
230
+ else:
231
+ exprs.append(pl.col(name).fill_null(value=expr))
232
+ elif expr is not None:
233
+ exprs.append(expr.alias(name))
234
+
235
+ return self.select(exprs)
236
+
237
+ def cast(
238
+ self: LDF, strict: bool = False, columns: Optional[Sequence[str]] = None
239
+ ) -> LDF:
240
+ """Cast columns to `dtypes` specified by the associated Patito model.
241
+
242
+ Args:
243
+ strict: If set to ``False``, columns which are technically compliant with
244
+ the specified field type, will not be casted. For example, a column
245
+ annotated with ``int`` is technically compliant with ``pl.UInt8``, even
246
+ if ``pl.Int64`` is the default dtype associated with ``int``-annotated
247
+ fields. If ``strict`` is set to ``True``, the resulting dtypes will
248
+ be forced to the default dtype associated with each python type.
249
+ columns: Optionally, a list of column names to cast. If not provided, all
250
+ columns are casted.
251
+
252
+ Returns:
253
+ LazyFrame[Model]: A dataframe with columns casted to the correct dtypes.
254
+
255
+ Examples:
256
+ Create a simple model:
257
+
258
+ >>> import patito as pt
259
+ >>> import polars as pl
260
+ >>> class Product(pt.Model):
261
+ ... name: str
262
+ ... cent_price: int = pt.Field(dtype=pl.UInt16)
263
+ ...
264
+
265
+ Now we can use this model to cast some simple data:
266
+
267
+ >>> Product.LazyFrame({"name": ["apple"], "cent_price": ["8"]}).cast().collect()
268
+ shape: (1, 2)
269
+ ┌───────┬────────────┐
270
+ │ name ┆ cent_price │
271
+ │ --- ┆ --- │
272
+ │ str ┆ u16 │
273
+ ╞═══════╪════════════╡
274
+ │ apple ┆ 8 │
275
+ └───────┴────────────┘
276
+
277
+ """
278
+ properties = self.model._schema_properties()
279
+ valid_dtypes = self.model.valid_dtypes
280
+ default_dtypes = self.model.dtypes
281
+ columns = columns or self.columns
282
+ exprs = []
283
+ for column, current_dtype in zip(self.columns, self.dtypes):
284
+ if (column not in columns) or (column not in properties):
285
+ exprs.append(pl.col(column))
286
+ elif "dtype" in properties[column]:
287
+ exprs.append(pl.col(column).cast(properties[column]["dtype"]))
288
+ elif not strict and current_dtype in valid_dtypes[column]:
289
+ exprs.append(pl.col(column))
290
+ else:
291
+ exprs.append(pl.col(column).cast(default_dtypes[column]))
292
+ return self.with_columns(exprs)
293
+
294
+ @classmethod
295
+ def from_existing(cls: Type[LDF], lf: pl.LazyFrame) -> LDF:
296
+ """Construct a patito.LazyFrame object from an existing polars.LazyFrame object."""
297
+ return cls.model.LazyFrame._from_pyldf(lf._ldf).cast()
298
+
102
299
 
103
300
  class DataFrame(pl.DataFrame, Generic[ModelType]):
104
- """
105
- A sub-class of polars.DataFrame with additional functionality related to Model.
301
+ """A sub-class of polars.DataFrame with additional functionality related to Model.
106
302
 
107
303
  Two different methods are available for constructing model-aware data frames.
108
304
  Assume a simple model with two fields:
@@ -136,8 +332,7 @@ class DataFrame(pl.DataFrame, Generic[ModelType]):
136
332
  def _construct_dataframe_model_class(
137
333
  cls: Type[DF], model: Type[OtherModelType]
138
334
  ) -> Type[DataFrame[OtherModelType]]:
139
- """
140
- Return custom DataFrame sub-class where DataFrame.model is set.
335
+ """Return custom DataFrame sub-class where DataFrame.model is set.
141
336
 
142
337
  Can be used to construct a DataFrame class where
143
338
  DataFrame.set_model(model) is implicitly invoked at instantiation.
@@ -148,34 +343,34 @@ class DataFrame(pl.DataFrame, Generic[ModelType]):
148
343
  Returns:
149
344
  A custom DataFrame model class where DataFrame.model has been correctly
150
345
  "hard-coded" to the given model.
346
+
151
347
  """
152
348
  new_class = type(
153
- f"{model.schema()['title']}DataFrame",
349
+ f"{model.model_json_schema()['title']}DataFrame",
154
350
  (cls,),
155
351
  {"model": model},
156
352
  )
157
353
  return new_class
158
354
 
159
355
  def lazy(self: DataFrame[ModelType]) -> LazyFrame[ModelType]:
160
- """
161
- Convert DataFrame into LazyFrame.
356
+ """Convert DataFrame into LazyFrame.
162
357
 
163
358
  See documentation of polars.DataFrame.lazy() for full description.
164
359
 
165
360
  Returns:
166
361
  A new LazyFrame object.
362
+
167
363
  """
168
- lazyframe_class: LazyFrame[
169
- ModelType
170
- ] = LazyFrame._construct_lazyframe_model_class(
171
- model=getattr(self, "model", None)
364
+ lazyframe_class: LazyFrame[ModelType] = (
365
+ LazyFrame._construct_lazyframe_model_class(
366
+ model=getattr(self, "model", None)
367
+ )
172
368
  ) # type: ignore
173
369
  ldf = lazyframe_class._from_pyldf(super().lazy()._ldf)
174
370
  return ldf
175
371
 
176
372
  def set_model(self, model): # type: ignore[no-untyped-def] # noqa: ANN001, ANN201
177
- """
178
- Associate a given patito ``Model`` with the dataframe.
373
+ """Associate a given patito ``Model`` with the dataframe.
179
374
 
180
375
  The model schema is used by methods that depend on a model being associated with
181
376
  the given dataframe such as :ref:`DataFrame.validate() <DataFrame.validate>`
@@ -228,6 +423,7 @@ class DataFrame(pl.DataFrame, Generic[ModelType]):
228
423
  │ 2 ┆ B │
229
424
  └──────┴────────┘
230
425
  >>> casted_classes.validate()
426
+
231
427
  """
232
428
  cls = self._construct_dataframe_model_class(model=model)
233
429
  return cast(
@@ -235,9 +431,23 @@ class DataFrame(pl.DataFrame, Generic[ModelType]):
235
431
  cls._from_pydf(self._df),
236
432
  )
237
433
 
238
- def cast(self: DF, strict: bool = False) -> DF:
434
+ def unalias(self: DF) -> DF:
435
+ """Un-aliases column names using information from pydantic validation_alias.
436
+
437
+ Values are taken in order of preference: the model field name first, then the validation aliases in order of occurrence.
438
+
439
+ Limitation: an ``AliasPath`` validation alias is only supported when it selects a single element of a list column, i.e. a column name followed by one integer index.
440
+
441
+ Returns:
442
+ DataFrame[Model]: A dataframe with columns normalized to model names.
443
+
239
444
  """
240
- Cast columns to `dtypes` specified by the associated Patito model.
445
+ return self.lazy().unalias().collect()
446
+
447
+ def cast(
448
+ self: DF, strict: bool = False, columns: Optional[Sequence[str]] = None
449
+ ) -> DF:
450
+ """Cast columns to `dtypes` specified by the associated Patito model.
241
451
 
242
452
  Args:
243
453
  strict: If set to ``False``, columns which are technically compliant with
@@ -246,6 +456,8 @@ class DataFrame(pl.DataFrame, Generic[ModelType]):
246
456
  if ``pl.Int64`` is the default dtype associated with ``int``-annotated
247
457
  fields. If ``strict`` is set to ``True``, the resulting dtypes will
248
458
  be forced to the default dtype associated with each python type.
459
+ columns: Optionally, a list of column names to cast. If not provided, all
460
+ columns are casted.
249
461
 
250
462
  Returns:
251
463
  DataFrame[Model]: A dataframe with columns casted to the correct dtypes.
@@ -271,29 +483,16 @@ class DataFrame(pl.DataFrame, Generic[ModelType]):
271
483
  ╞═══════╪════════════╡
272
484
  │ apple ┆ 8 │
273
485
  └───────┴────────────┘
486
+
274
487
  """
275
- properties = self.model._schema_properties()
276
- valid_dtypes = self.model.valid_dtypes
277
- default_dtypes = self.model.dtypes
278
- columns = []
279
- for column, current_dtype in zip(self.columns, self.dtypes):
280
- if column not in properties:
281
- columns.append(pl.col(column))
282
- elif "dtype" in properties[column]:
283
- columns.append(pl.col(column).cast(properties[column]["dtype"]))
284
- elif not strict and current_dtype in valid_dtypes[column]:
285
- columns.append(pl.col(column))
286
- else:
287
- columns.append(pl.col(column).cast(default_dtypes[column]))
288
- return self.with_columns(columns)
488
+ return self.lazy().cast(strict=strict, columns=columns).collect()
289
489
 
290
490
  def drop(
291
491
  self: DF,
292
492
  columns: Optional[Union[str, Collection[str]]] = None,
293
493
  *more_columns: str,
294
494
  ) -> DF:
295
- """
296
- Drop one or more columns from the dataframe.
495
+ """Drop one or more columns from the dataframe.
297
496
 
298
497
  If ``columns`` is not provided then all columns `not` specified by the associated
299
498
  patito model, for instance set with
@@ -330,9 +529,8 @@ class DataFrame(pl.DataFrame, Generic[ModelType]):
330
529
  else:
331
530
  return self.drop(list(set(self.columns) - set(self.model.columns)))
332
531
 
333
- def validate(self: DF) -> DF:
334
- """
335
- Validate the schema and content of the dataframe.
532
+ def validate(self, columns: Optional[Sequence[str]] = None, **kwargs: Any):
533
+ """Validate the schema and content of the dataframe.
336
534
 
337
535
  You must invoke ``.set_model()`` before invoking ``.validate()`` in order
338
536
  to specify how the dataframe should be validated.
@@ -345,7 +543,7 @@ class DataFrame(pl.DataFrame, Generic[ModelType]):
345
543
  validation. Note that ``patito.Model.DataFrame`` automatically invokes
346
544
  ``DataFrame.set_model()`` for you.
347
545
 
348
- patito.exceptions.ValidationError: If the dataframe does not match the
546
+ patito.exceptions.DataFrameValidationError: If the dataframe does not match the
349
547
  specified schema.
350
548
 
351
549
  Examples:
@@ -366,7 +564,7 @@ class DataFrame(pl.DataFrame, Generic[ModelType]):
366
564
  ... ).set_model(Product)
367
565
  >>> try:
368
566
  ... df.validate()
369
- ... except pt.ValidationError as exc:
567
+ ... except pt.DataFrameValidationError as exc:
370
568
  ... print(exc)
371
569
  ...
372
570
  3 validation errors for Product
@@ -376,18 +574,18 @@ class DataFrame(pl.DataFrame, Generic[ModelType]):
376
574
  2 rows with duplicated values. (type=value_error.rowvalue)
377
575
  temperature_zone
378
576
  Rows with invalid values: {'oven'}. (type=value_error.rowvalue)
577
+
379
578
  """
380
579
  if not hasattr(self, "model"):
381
580
  raise TypeError(
382
581
  f"You must invoke {self.__class__.__name__}.set_model() "
383
582
  f"before invoking {self.__class__.__name__}.validate()."
384
583
  )
385
- self.model.validate(dataframe=self)
584
+ self.model.validate(dataframe=self, columns=columns, **kwargs)
386
585
  return self
387
586
 
388
- def derive(self: DF) -> DF:
389
- """
390
- Populate columns which have ``pt.Field(derived_from=...)`` definitions.
587
+ def derive(self: DF, columns: list[str] | None = None) -> DF:
588
+ """Populate columns which have ``pt.Field(derived_from=...)`` definitions.
391
589
 
392
590
  If a column field on the data frame model has ``patito.Field(derived_from=...)``
393
591
  specified, the given value will be used to define the column. If
@@ -411,32 +609,17 @@ class DataFrame(pl.DataFrame, Generic[ModelType]):
411
609
  ...
412
610
  >>> Foo.DataFrame({"foo": [1, 2]}).derive()
413
611
  shape: (2, 3)
414
- ┌─────┬─────┬────────────┐
415
- │ foo ┆ bar ┆ double_bar │
416
- │ --- ┆ --- ┆ --- │
417
- │ i64 ┆ i64 ┆ i64 │
418
- ╞═════╪═════╪════════════╡
419
- │ 1 ┆ 1 ┆ 2 │
420
- │ 2 ┆ 2 ┆ 4 │
421
- └─────┴─────┴────────────┘
612
+ ┌─────┬────────────┬─────┐
613
+ │ bar ┆ double_bar ┆ foo │
614
+ │ --- ┆ --- ┆ --- │
615
+ │ i64 ┆ i64 ┆ i64 │
616
+ ╞═════╪════════════╪═════╡
617
+ │ 1 ┆ 2 ┆ 1 │
618
+ │ 2 ┆ 4 ┆ 2 │
619
+ └─────┴────────────┴─────┘
620
+
422
621
  """
423
- df = self.lazy()
424
- for column_name, props in self.model._schema_properties().items():
425
- if "derived_from" in props:
426
- derived_from = props["derived_from"]
427
- dtype = self.model.dtypes[column_name]
428
- if isinstance(derived_from, str):
429
- df = df.with_columns(
430
- pl.col(derived_from).cast(dtype).alias(column_name)
431
- )
432
- elif isinstance(derived_from, pl.Expr):
433
- df = df.with_columns(derived_from.cast(dtype).alias(column_name))
434
- else:
435
- raise TypeError(
436
- "Can not derive dataframe column from type "
437
- f"{type(derived_from)}."
438
- )
439
- return cast(DF, df.collect())
622
+ return cast(DF, self.lazy().derive(columns=columns).collect())
440
623
 
441
624
  def fill_null(
442
625
  self: DF,
@@ -449,10 +632,9 @@ class DataFrame(pl.DataFrame, Generic[ModelType]):
449
632
  limit: Optional[int] = None,
450
633
  matches_supertype: bool = True,
451
634
  ) -> DF:
452
- """
453
- Fill null values using a filling strategy, literal, or ``Expr``.
635
+ """Fill null values using a filling strategy, literal, or ``Expr``.
454
636
 
455
- If ``"default"`` is provided as the strategy, the model fields with default
637
+ If ``"defaults"`` is provided as the strategy, the model fields with default
456
638
  values are used to fill missing values.
457
639
 
458
640
  Args:
@@ -488,6 +670,7 @@ class DataFrame(pl.DataFrame, Generic[ModelType]):
488
670
  │ apple ┆ 10 │
489
671
  │ banana ┆ 19 │
490
672
  └────────┴───────┘
673
+
491
674
  """
492
675
  if strategy != "defaults": # pragma: no cover
493
676
  return cast( # pyright: ignore[redundant-cast]
@@ -501,14 +684,20 @@ class DataFrame(pl.DataFrame, Generic[ModelType]):
501
684
  )
502
685
  return self.with_columns(
503
686
  [
504
- pl.col(column).fill_null(pl.lit(default_value))
687
+ (
688
+ pl.col(column).fill_null(
689
+ pl.lit(default_value, self.model.dtypes[column])
690
+ )
691
+ if column in self.columns
692
+ else pl.Series(column, [default_value], self.model.dtypes[column])
693
+ ) # NOTE: hack to get around polars bug https://github.com/pola-rs/polars/issues/13602
694
+ # else pl.lit(default_value, self.model.dtypes[column]).alias(column)
505
695
  for column, default_value in self.model.defaults.items()
506
696
  ]
507
697
  ).set_model(self.model)
508
698
 
509
699
  def get(self, predicate: Optional[pl.Expr] = None) -> ModelType:
510
- """
511
- Fetch the single row that matches the given polars predicate.
700
+ """Fetch the single row that matches the given polars predicate.
512
701
 
513
702
  If you expect a data frame to already consist of one single row,
514
703
  you can use ``.get()`` without any arguments to return that row.
@@ -574,6 +763,7 @@ class DataFrame(pl.DataFrame, Generic[ModelType]):
574
763
  ... print(e)
575
764
  ...
576
765
  DataFrame.get() yielded 0 rows.
766
+
577
767
  """
578
768
  row = self if predicate is None else self.filter(predicate)
579
769
  if row.height == 0:
@@ -589,12 +779,12 @@ class DataFrame(pl.DataFrame, Generic[ModelType]):
589
779
  return self._pydantic_model().from_row(row) # type: ignore
590
780
 
591
781
  def _pydantic_model(self) -> Type[Model]:
592
- """
593
- Dynamically construct patito model compliant with dataframe.
782
+ """Dynamically construct patito model compliant with dataframe.
594
783
 
595
784
  Returns:
596
785
  A pydantic model class where all the rows have been specified as
597
786
  `typing.Any` fields.
787
+
598
788
  """
599
789
  from patito.pydantic import Model
600
790
 
@@ -608,14 +798,17 @@ class DataFrame(pl.DataFrame, Generic[ModelType]):
608
798
  ),
609
799
  )
610
800
 
801
+ def as_polars(self) -> pl.DataFrame:
802
+ """Convert patito dataframe to polars dataframe."""
803
+ return pl.DataFrame._from_pydf(self._df)
804
+
611
805
  @classmethod
612
806
  def read_csv( # type: ignore[no-untyped-def]
613
807
  cls: Type[DF],
614
808
  *args, # noqa: ANN002
615
809
  **kwargs, # noqa: ANN003
616
810
  ) -> DF:
617
- r"""
618
- Read CSV and apply correct column name and types from model.
811
+ r"""Read CSV and apply correct column name and types from model.
619
812
 
620
813
  If any fields have ``derived_from`` specified, the given expression will be used
621
814
  to populate the given column(s).
@@ -670,10 +863,25 @@ class DataFrame(pl.DataFrame, Generic[ModelType]):
670
863
  # ╞═════╪═════╡
671
864
  # │ 1.0 ┆ 1 │
672
865
  # └─────┴─────┘
866
+
673
867
  """
674
868
  kwargs.setdefault("dtypes", cls.model.dtypes)
675
- if not kwargs.get("has_header", True) and "columns" not in kwargs:
869
+ has_header = kwargs.get("has_header", True)
870
+ if not has_header and "columns" not in kwargs:
676
871
  kwargs.setdefault("new_columns", cls.model.columns)
872
+ alias_gen = cls.model.model_config.get("alias_generator")
873
+ if alias_gen:
874
+ alias_func = alias_gen.validation_alias or alias_gen.alias
875
+ if has_header and alias_gen and alias_func:
876
+ fields_to_cols = {
877
+ field_name: alias_func(field_name)
878
+ for field_name in cls.model.model_fields
879
+ }
880
+ kwargs["dtypes"] = {
881
+ fields_to_cols.get(field, field): dtype
882
+ for field, dtype in kwargs["dtypes"].items()
883
+ }
884
+ # TODO: other forms of alias setting like in Field
677
885
  df = cls.model.DataFrame._from_pydf(pl.read_csv(*args, **kwargs)._df)
678
886
  return df.derive()
679
887
 
@@ -684,7 +892,7 @@ class DataFrame(pl.DataFrame, Generic[ModelType]):
684
892
  pl.Expr, str, pl.Series, list[bool], np.ndarray[Any, Any], bool
685
893
  ],
686
894
  ) -> DF:
687
- return cast(DF, super().filter(predicate=predicate))
895
+ return cast(DF, super().filter(predicate))
688
896
 
689
897
  def select( # noqa: D102
690
898
  self: DF,