patito 0.4.4__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
patito/polars.py CHANGED
@@ -4,9 +4,10 @@ from __future__ import annotations
4
4
  from typing import (
5
5
  TYPE_CHECKING,
6
6
  Any,
7
+ Collection,
7
8
  Generic,
9
+ Iterable,
8
10
  Optional,
9
- Sequence,
10
11
  Type,
11
12
  TypeVar,
12
13
  Union,
@@ -14,6 +15,7 @@ from typing import (
14
15
  )
15
16
 
16
17
  import polars as pl
18
+ from polars.type_aliases import IntoExpr
17
19
  from pydantic import create_model
18
20
  from typing_extensions import Literal
19
21
 
@@ -21,7 +23,6 @@ from patito.exceptions import MultipleRowsReturned, RowDoesNotExist
21
23
 
22
24
  if TYPE_CHECKING:
23
25
  import numpy as np
24
- from polars.internals import WhenThen, WhenThenThen
25
26
 
26
27
  from patito.pydantic import Model
27
28
 
@@ -209,11 +210,8 @@ class DataFrame(pl.DataFrame, Generic[ModelType]):
209
210
  │ i64 ┆ str │
210
211
  ╞══════╪════════╡
211
212
  │ 1 ┆ A │
212
- ├╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
213
213
  │ 1 ┆ B │
214
- ├╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
215
214
  │ 2 ┆ A │
216
- ├╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
217
215
  │ 2 ┆ B │
218
216
  └──────┴────────┘
219
217
  >>> casted_classes = classes.cast()
@@ -225,11 +223,8 @@ class DataFrame(pl.DataFrame, Generic[ModelType]):
225
223
  │ u16 ┆ cat │
226
224
  ╞══════╪════════╡
227
225
  │ 1 ┆ A │
228
- ├╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
229
226
  │ 1 ┆ B │
230
- ├╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
231
227
  │ 2 ┆ A │
232
- ├╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌┤
233
228
  │ 2 ┆ B │
234
229
  └──────┴────────┘
235
230
  >>> casted_classes.validate()
@@ -292,7 +287,11 @@ class DataFrame(pl.DataFrame, Generic[ModelType]):
292
287
  columns.append(pl.col(column).cast(default_dtypes[column]))
293
288
  return self.with_columns(columns)
294
289
 
295
- def drop(self: DF, columns: Optional[Union[str, Sequence[str]]] = None) -> DF:
290
+ def drop(
291
+ self: DF,
292
+ columns: Optional[Union[str, Collection[str]]] = None,
293
+ *more_columns: str,
294
+ ) -> DF:
296
295
  """
297
296
  Drop one or more columns from the dataframe.
298
297
 
@@ -304,6 +303,7 @@ class DataFrame(pl.DataFrame, Generic[ModelType]):
304
303
  columns: A single column string name, or list of strings, indicating
305
304
  which columns to drop. If not specified, all columns *not*
306
305
  specified by the associated dataframe model will be dropped.
306
+ more_columns: Additional named columns to drop.
307
307
 
308
308
  Returns:
309
309
  DataFrame[Model]: New dataframe without the specified columns.
@@ -321,13 +321,12 @@ class DataFrame(pl.DataFrame, Generic[ModelType]):
321
321
  │ i64 │
322
322
  ╞══════════╡
323
323
  │ 1 │
324
- ├╌╌╌╌╌╌╌╌╌╌┤
325
324
  │ 2 │
326
325
  └──────────┘
327
326
 
328
327
  """
329
328
  if columns is not None:
330
- return super().drop(columns)
329
+ return self._from_pydf(super().drop(columns)._df)
331
330
  else:
332
331
  return self.drop(list(set(self.columns) - set(self.model.columns)))
333
332
 
@@ -418,7 +417,6 @@ class DataFrame(pl.DataFrame, Generic[ModelType]):
418
417
  │ i64 ┆ i64 ┆ i64 │
419
418
  ╞═════╪═════╪════════════╡
420
419
  │ 1 ┆ 1 ┆ 2 │
421
- ├╌╌╌╌╌┼╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌┤
422
420
  │ 2 ┆ 2 ┆ 4 │
423
421
  └─────┴─────┴────────────┘
424
422
  """
@@ -428,11 +426,11 @@ class DataFrame(pl.DataFrame, Generic[ModelType]):
428
426
  derived_from = props["derived_from"]
429
427
  dtype = self.model.dtypes[column_name]
430
428
  if isinstance(derived_from, str):
431
- df = df.with_column(
429
+ df = df.with_columns(
432
430
  pl.col(derived_from).cast(dtype).alias(column_name)
433
431
  )
434
432
  elif isinstance(derived_from, pl.Expr):
435
- df = df.with_column(derived_from.cast(dtype).alias(column_name))
433
+ df = df.with_columns(derived_from.cast(dtype).alias(column_name))
436
434
  else:
437
435
  raise TypeError(
438
436
  "Can not derive dataframe column from type "
@@ -488,12 +486,11 @@ class DataFrame(pl.DataFrame, Generic[ModelType]):
488
486
  │ str ┆ i64 │
489
487
  ╞════════╪═══════╡
490
488
  │ apple ┆ 10 │
491
- ├╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌┤
492
489
  │ banana ┆ 19 │
493
490
  └────────┴───────┘
494
491
  """
495
492
  if strategy != "defaults": # pragma: no cover
496
- return cast( # type: ignore[redundant-cast]
493
+ return cast( # pyright: ignore[redundant-cast]
497
494
  DF,
498
495
  super().fill_null(
499
496
  value=value,
@@ -607,7 +604,7 @@ class DataFrame(pl.DataFrame, Generic[ModelType]):
607
604
  create_model( # type: ignore
608
605
  "UntypedRow",
609
606
  __base__=Model,
610
- **pydantic_annotations,
607
+ **pydantic_annotations, # pyright: ignore
611
608
  ),
612
609
  )
613
610
 
@@ -662,15 +659,17 @@ class DataFrame(pl.DataFrame, Generic[ModelType]):
662
659
  ... b: str = pt.Field(derived_from="source_of_b")
663
660
  ...
664
661
  >>> csv_file = io.StringIO("a,source_of_b\n1,1")
665
- >>> CSVModel.DataFrame.read_csv(csv_file).drop()
666
- shape: (1, 2)
667
- ┌─────┬─────┐
668
- a ┆ b │
669
- --- ┆ --- │
670
- f64 str
671
- ╞═════╪═════╡
672
- 1.01
673
- └─────┴─────┘
662
+
663
+
664
+ # >>> CSVModel.DataFrame.read_csv(csv_file).drop()
665
+ # shape: (1, 2)
666
+ # ┌─────┬─────┐
667
+ # a b
668
+ # │ --- ┆ --- │
669
+ # f64str
670
+ # ╞═════╪═════╡
671
+ # │ 1.0 ┆ 1 │
672
+ # └─────┴─────┘
674
673
  """
675
674
  kwargs.setdefault("dtypes", cls.model.dtypes)
676
675
  if not kwargs.get("has_header", True) and "columns" not in kwargs:
@@ -681,31 +680,24 @@ class DataFrame(pl.DataFrame, Generic[ModelType]):
681
680
  # --- Type annotation overrides ---
682
681
  def filter( # noqa: D102
683
682
  self: DF,
684
- predicate: Union[pl.Expr, str, pl.Series, list[bool], np.ndarray[Any, Any]],
683
+ predicate: Union[
684
+ pl.Expr, str, pl.Series, list[bool], np.ndarray[Any, Any], bool
685
+ ],
685
686
  ) -> DF:
686
687
  return cast(DF, super().filter(predicate=predicate))
687
688
 
688
689
  def select( # noqa: D102
689
690
  self: DF,
690
- exprs: Union[
691
- pl.Expr,
692
- pl.Series,
693
- Sequence[Union[str, pl.Expr, pl.Series, "WhenThen", "WhenThenThen"]],
694
- ],
691
+ *exprs: Union[IntoExpr, Iterable[IntoExpr]],
692
+ **named_exprs: IntoExpr,
695
693
  ) -> DF:
696
- return cast(DF, super().select(exprs=exprs)) # type: ignore[redundant-cast]
697
-
698
- def with_column(self: DF, column: Union[pl.Series, pl.Expr]) -> DF: # noqa: D102
699
- return cast(DF, super().with_column(column=column))
694
+ return cast( # pyright: ignore[redundant-cast]
695
+ DF, super().select(*exprs, **named_exprs)
696
+ )
700
697
 
701
698
  def with_columns( # noqa: D102
702
699
  self: DF,
703
- exprs: Union[
704
- pl.Expr,
705
- pl.Series,
706
- Sequence[Union[pl.Expr, pl.Series]],
707
- None,
708
- ] = None,
709
- **named_exprs: Union[pl.Expr, pl.Series],
700
+ *exprs: Union[IntoExpr, Iterable[IntoExpr]],
701
+ **named_exprs: IntoExpr,
710
702
  ) -> DF:
711
- return cast(DF, super().with_columns(exprs=exprs, **named_exprs))
703
+ return cast(DF, super().with_columns(*exprs, **named_exprs))
patito/pydantic.py CHANGED
@@ -19,6 +19,7 @@ from typing import (
19
19
  )
20
20
 
21
21
  import polars as pl
22
+ from polars.datatypes import PolarsDataType
22
23
  from pydantic import BaseConfig, BaseModel, Field, create_model # noqa: F401
23
24
  from pydantic.main import ModelMetaclass as PydanticModelMetaclass
24
25
  from typing_extensions import Literal, get_args
@@ -110,7 +111,7 @@ class ModelMetaclass(PydanticModelMetaclass):
110
111
 
111
112
  @property
112
113
  def dtypes( # type: ignore
113
- cls: Type[ModelType],
114
+ cls: Type[ModelType], # pyright: ignore
114
115
  ) -> dict[str, Type[pl.DataType]]:
115
116
  """
116
117
  Return the polars dtypes of the dataframe.
@@ -129,18 +130,16 @@ class ModelMetaclass(PydanticModelMetaclass):
129
130
  ... price: float
130
131
  ...
131
132
  >>> Product.dtypes
132
- {'name': <class 'polars.datatypes.Utf8'>, \
133
- 'ideal_temperature': <class 'polars.datatypes.Int64'>, \
134
- 'price': <class 'polars.datatypes.Float64'>}
133
+ {'name': Utf8, 'ideal_temperature': Int64, 'price': Float64}
135
134
  """
136
135
  return {
137
136
  column: valid_dtypes[0] for column, valid_dtypes in cls.valid_dtypes.items()
138
137
  }
139
138
 
140
139
  @property
141
- def valid_dtypes( # type: ignore # noqa: C901
142
- cls: Type[ModelType],
143
- ) -> dict[str, List[Type[pl.DataType]]]:
140
+ def valid_dtypes( # type: ignore
141
+ cls: Type[ModelType], # pyright: ignore
142
+ ) -> dict[str, List[Union[pl.PolarsDataType, pl.List]]]:
144
143
  """
145
144
  Return a list of polars dtypes which Patito considers valid for each field.
146
145
 
@@ -164,82 +163,91 @@ class ModelMetaclass(PydanticModelMetaclass):
164
163
  ... float_column: float
165
164
  ...
166
165
  >>> pprint(MyModel.valid_dtypes)
167
- {'bool_column': [<class 'polars.datatypes.Boolean'>],
168
- 'float_column': [<class 'polars.datatypes.Float64'>,
169
- <class 'polars.datatypes.Float32'>],
170
- 'int_column': [<class 'polars.datatypes.Int64'>,
171
- <class 'polars.datatypes.Int32'>,
172
- <class 'polars.datatypes.Int16'>,
173
- <class 'polars.datatypes.Int8'>,
174
- <class 'polars.datatypes.UInt64'>,
175
- <class 'polars.datatypes.UInt32'>,
176
- <class 'polars.datatypes.UInt16'>,
177
- <class 'polars.datatypes.UInt8'>],
178
- 'str_column': [<class 'polars.datatypes.Utf8'>]}
166
+ {'bool_column': [Boolean],
167
+ 'float_column': [Float64, Float32],
168
+ 'int_column': [Int64, Int32, Int16, Int8, UInt64, UInt32, UInt16, UInt8],
169
+ 'str_column': [Utf8]}
179
170
  """
180
171
  valid_dtypes = {}
181
172
  for column, props in cls._schema_properties().items():
182
- if "dtype" in props:
183
- valid_dtypes[column] = [
184
- props["dtype"],
185
- ]
186
- elif "enum" in props and props["type"] == "string":
187
- valid_dtypes[column] = [pl.Categorical, pl.Utf8]
188
- elif "type" not in props:
173
+ column_dtypes: List[Union[PolarsDataType, pl.List]]
174
+ if props.get("type") == "array":
175
+ array_props = props["items"]
176
+ item_dtypes = cls._valid_dtypes(props=array_props)
177
+ if item_dtypes is None:
178
+ raise NotImplementedError(
179
+ f"No valid dtype mapping found for column '{column}'."
180
+ )
181
+ column_dtypes = [pl.List(dtype) for dtype in item_dtypes]
182
+ else:
183
+ column_dtypes = cls._valid_dtypes(props=props) # pyright: ignore
184
+
185
+ if column_dtypes is None:
189
186
  raise NotImplementedError(
190
187
  f"No valid dtype mapping found for column '{column}'."
191
188
  )
192
- elif props["type"] == "integer":
193
- valid_dtypes[column] = [
194
- pl.Int64,
195
- pl.Int32,
196
- pl.Int16,
197
- pl.Int8,
198
- pl.UInt64,
199
- pl.UInt32,
200
- pl.UInt16,
201
- pl.UInt8,
202
- ]
203
- elif props["type"] == "number":
204
- if props.get("format") == "time-delta":
205
- valid_dtypes[column] = [
206
- pl.Duration,
207
- ] # pyright: reportPrivateImportUsage=false
208
- else:
209
- valid_dtypes[column] = [pl.Float64, pl.Float32]
210
- elif props["type"] == "boolean":
211
- valid_dtypes[column] = [
212
- pl.Boolean,
213
- ]
214
- elif props["type"] == "string":
215
- string_format = props.get("format")
216
- if string_format is None:
217
- valid_dtypes[column] = [
218
- pl.Utf8,
219
- ]
220
- elif string_format == "date":
221
- valid_dtypes[column] = [
222
- pl.Date,
223
- ]
224
- # TODO: Find out why this branch is not being hit
225
- elif string_format == "date-time": # pragma: no cover
226
- valid_dtypes[column] = [
227
- pl.Datetime,
228
- ]
229
- elif props["type"] == "null":
230
- valid_dtypes[column] = [
231
- pl.Null,
232
- ]
233
- else: # pragma: no cover
234
- raise NotImplementedError(
235
- f"No valid dtype mapping found for column '{column}'"
236
- )
189
+ valid_dtypes[column] = column_dtypes
237
190
 
238
191
  return valid_dtypes
239
192
 
193
@staticmethod
def _valid_dtypes(  # noqa: C901
    props: Dict,
) -> Optional[List[pl.PolarsDataType]]:
    """
    Translate one JSON-schema property into the polars dtypes it accepts.

    Args:
        props: A single property dictionary, as produced by
            ``BaseModel._schema_properties()``.

    Returns:
        The accepted polars dtypes, listed in order of preference, or ``None``
        when the property has no known polars counterpart.
    """
    # An explicitly configured dtype always takes precedence.
    if "dtype" in props:
        return [props["dtype"]]
    # String enums may be stored either as categoricals or as plain strings.
    # This check runs before the "type" presence check below, so an "enum"
    # property lacking a "type" key raises KeyError here.
    if "enum" in props and props["type"] == "string":
        return [pl.Categorical, pl.Utf8]
    if "type" not in props:
        return None

    json_type = props["type"]

    if json_type == "integer":
        return [
            pl.Int64,
            pl.Int32,
            pl.Int16,
            pl.Int8,
            pl.UInt64,
            pl.UInt32,
            pl.UInt16,
            pl.UInt8,
        ]

    if json_type == "number":
        if props.get("format") == "time-delta":
            return [pl.Duration]
        return [pl.Float64, pl.Float32]

    if json_type == "boolean":
        return [pl.Boolean]

    if json_type == "string":
        string_format = props.get("format")
        if string_format is None:
            return [pl.Utf8]
        if string_format == "date":
            return [pl.Date]
        # TODO: Find out why this branch is not being hit
        if string_format == "date-time":  # pragma: no cover
            return [pl.Datetime]
        return None  # pragma: no cover

    if json_type == "null":
        return [pl.Null]

    # Unknown JSON-schema type: no mapping available.
    return None  # pragma: no cover
247
+
240
248
  @property
241
249
  def valid_sql_types( # type: ignore # noqa: C901
242
- cls: Type[ModelType],
250
+ cls: Type[ModelType], # pyright: ignore
243
251
  ) -> dict[str, List["DuckDBSQLType"]]:
244
252
  """
245
253
  Return a list of DuckDB SQL types which Patito considers valid for each field.
@@ -302,7 +310,7 @@ class ModelMetaclass(PydanticModelMetaclass):
302
310
  from patito.duckdb import _enum_type_name
303
311
 
304
312
  # fmt: off
305
- valid_dtypes[column] = [
313
+ valid_dtypes[column] = [ # pyright: ignore
306
314
  _enum_type_name(field_properties=props), # type: ignore
307
315
  "VARCHAR", "CHAR", "BPCHAR", "TEXT", "STRING",
308
316
  ]
@@ -374,7 +382,7 @@ class ModelMetaclass(PydanticModelMetaclass):
374
382
 
375
383
  @property
376
384
  def sql_types( # type: ignore
377
- cls: Type[ModelType],
385
+ cls: Type[ModelType], # pyright: ignore
378
386
  ) -> dict[str, str]:
379
387
  """
380
388
  Return compatible DuckDB SQL types for all model fields.
@@ -405,7 +413,7 @@ class ModelMetaclass(PydanticModelMetaclass):
405
413
 
406
414
  @property
407
415
  def defaults( # type: ignore
408
- cls: Type[ModelType],
416
+ cls: Type[ModelType], # pyright: ignore
409
417
  ) -> dict[str, Any]:
410
418
  """
411
419
  Return default field values specified on the model.
@@ -432,7 +440,7 @@ class ModelMetaclass(PydanticModelMetaclass):
432
440
 
433
441
  @property
434
442
  def non_nullable_columns( # type: ignore
435
- cls: Type[ModelType], # pyright: reportGeneralTypeIssues=false
443
+ cls: Type[ModelType], # pyright: ignore
436
444
  ) -> set[str]:
437
445
  """
438
446
  Return names of those columns that are non-nullable in the schema.
@@ -456,7 +464,7 @@ class ModelMetaclass(PydanticModelMetaclass):
456
464
 
457
465
  @property
458
466
  def nullable_columns( # type: ignore
459
- cls: Type[ModelType], # pyright: reportGeneralTypeIssues=false
467
+ cls: Type[ModelType], # pyright: ignore
460
468
  ) -> set[str]:
461
469
  """
462
470
  Return names of those columns that are nullable in the schema.
@@ -480,7 +488,7 @@ class ModelMetaclass(PydanticModelMetaclass):
480
488
 
481
489
  @property
482
490
  def unique_columns( # type: ignore
483
- cls: Type[ModelType],
491
+ cls: Type[ModelType], # pyright: ignore
484
492
  ) -> set[str]:
485
493
  """
486
494
  Return columns with uniqueness constraint.
@@ -531,12 +539,16 @@ class Model(BaseModel, metaclass=ModelMetaclass):
531
539
 
532
540
  @classmethod # type: ignore[misc]
533
541
  @property
534
- def DataFrame(cls: Type[ModelType]) -> Type[DataFrame[ModelType]]:
542
+ def DataFrame(
543
+ cls: Type[ModelType],
544
+ ) -> Type[DataFrame[ModelType]]: # pyright: ignore # noqa
535
545
  """Return DataFrame class where DataFrame.set_model() is set to self."""
536
546
 
537
547
  @classmethod # type: ignore[misc]
538
548
  @property
539
- def LazyFrame(cls: Type[ModelType]) -> Type[LazyFrame[ModelType]]:
549
+ def LazyFrame(
550
+ cls: Type[ModelType],
551
+ ) -> Type[LazyFrame[ModelType]]: # pyright: ignore
540
552
  """Return DataFrame class where DataFrame.set_model() is set to self."""
541
553
 
542
554
  @classmethod
@@ -570,7 +582,7 @@ class Model(BaseModel, metaclass=ModelMetaclass):
570
582
 
571
583
  >>> df = pl.DataFrame(
572
584
  ... [["1", "product name", "1.22"]],
573
- ... columns=["product_id", "name", "price"],
585
+ ... schema=["product_id", "name", "price"],
574
586
  ... )
575
587
  >>> Product.from_row(df)
576
588
  Product(product_id=1, name='product name', price=1.22)
@@ -622,7 +634,7 @@ class Model(BaseModel, metaclass=ModelMetaclass):
622
634
 
623
635
  >>> df = pl.DataFrame(
624
636
  ... [["1", "product name", "1.22"]],
625
- ... columns=["product_id", "name", "price"],
637
+ ... schema=["product_id", "name", "price"],
626
638
  ... )
627
639
  >>> Product._from_polars(df)
628
640
  Product(product_id=1, name='product name', price=1.22)
@@ -977,7 +989,6 @@ class Model(BaseModel, metaclass=ModelMetaclass):
977
989
  │ str ┆ cat ┆ i64 │
978
990
  ╞═══════════╪══════════════════╪════════════╡
979
991
  │ product A ┆ dry ┆ 0 │
980
- ├╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌╌┼╌╌╌╌╌╌╌╌╌╌╌╌┤
981
992
  │ product B ┆ dry ┆ 1 │
982
993
  └───────────┴──────────────────┴────────────┘
983
994
  """
@@ -1461,8 +1472,8 @@ class FieldDoc:
1461
1472
  product_id
1462
1473
  2 rows with duplicated values. (type=value_error.rowvalue)
1463
1474
  price
1464
- Polars dtype <class 'polars.datatypes.Int64'> \
1465
- does not match model field type. (type=type_error.columndtype)
1475
+ Polars dtype Int64 does not match model field type. \
1476
+ (type=type_error.columndtype)
1466
1477
  brand_color
1467
1478
  2 rows with out of bound values. (type=value_error.rowvalue)
1468
1479
  """
patito/sql.py CHANGED
@@ -45,7 +45,7 @@ class Case:
45
45
 
46
46
  Examples:
47
47
  >>> import patito as pt
48
- >>> db = pt.Database()
48
+ >>> db = pt.duckdb.Database()
49
49
  >>> relation = db.to_relation("select 1 as a union select 2 as a")
50
50
  >>> case_statement = pt.sql.Case(
51
51
  ... on_column="a",
@@ -61,7 +61,6 @@ class Case:
61
61
  │ i64 ┆ str │
62
62
  ╞═════╪═════╡
63
63
  │ 1 ┆ one │
64
- ├╌╌╌╌╌┼╌╌╌╌╌┤
65
64
  │ 2 ┆ two │
66
65
  └─────┴─────┘
67
66
  """
patito/validators.py CHANGED
@@ -1,9 +1,11 @@
1
1
  """Module for validating datastructures with respect to model specifications."""
2
2
  from __future__ import annotations
3
3
 
4
+ import sys
4
5
  from typing import TYPE_CHECKING, Type, Union, cast
5
6
 
6
7
  import polars as pl
8
+ from typing_extensions import get_args, get_origin
7
9
 
8
10
  from patito.exceptions import (
9
11
  ColumnDTypeError,
@@ -15,6 +17,13 @@ from patito.exceptions import (
15
17
  ValidationError,
16
18
  )
17
19
 
20
# Origins that ``get_origin`` may report for a union annotation. PEP 604
# ``X | Y`` unions have their own runtime type (``types.UnionType``), which
# only exists on Python 3.10+, so detect it by import rather than by version.
try:  # pragma: no cover
    from types import UnionType  # pyright: ignore
except ImportError:  # pragma: no cover
    UNION_TYPES = (Union,)
else:
    UNION_TYPES = (Union, UnionType)
26
+
18
27
  try:
19
28
  import pandas as pd
20
29
 
@@ -44,6 +53,44 @@ VALID_POLARS_TYPES = {
44
53
  }
45
54
 
46
55
 
56
def _is_optional(type_annotation: Type) -> bool:
    """
    Check whether an annotation is ``Optional[...]`` at its outermost level.

    Both ``typing.Union[..., None]`` and, where supported, PEP 604
    ``X | None`` unions are recognised (see ``UNION_TYPES``).

    Args:
        type_annotation: The type annotation to inspect.

    Returns:
        True if the outermost type is a union that includes ``None``.
    """
    if get_origin(type_annotation) not in UNION_TYPES:
        return False
    return type(None) in get_args(type_annotation)
69
+
70
+
71
def _dewrap_optional(type_annotation: Type) -> Type:
    """
    Strip one outer ``Optional[...]`` layer from a type annotation.

    Annotations that are not Optional are returned unchanged.

    Args:
        type_annotation: The type annotation to unwrap.

    Returns:
        The first non-``None`` member of the outer Optional union, or the
        input itself when it is not Optional.
    """
    if not _is_optional(type_annotation):
        return type_annotation
    # Optional[X] is Union[X, None]; hand back the first member that is not NoneType.
    non_none_members = (
        member
        for member in get_args(type_annotation)
        if member is not type(None)  # noqa: E721
    )
    return next(non_none_members)  # pragma: no cover
92
+
93
+
47
94
  def _find_errors( # noqa: C901
48
95
  dataframe: pl.DataFrame,
49
96
  schema: Type[Model],
@@ -99,6 +146,45 @@ def _find_errors( # noqa: C901
99
146
  )
100
147
  )
101
148
 
149
+ for column, dtype in schema.dtypes.items():
150
+ if not isinstance(dtype, pl.List):
151
+ continue
152
+
153
+ annotation = schema.__annotations__[column] # type: ignore[unreachable]
154
+
155
+ # Retrieve the annotation of the list itself,
156
+ # dewrapping any potential Optional[...]
157
+ list_type = _dewrap_optional(annotation)
158
+
159
+ # Check if the list items themselves should be considered nullable
160
+ item_type = get_args(list_type)[0]
161
+ if _is_optional(item_type):
162
+ continue
163
+
164
+ num_missing_values = (
165
+ dataframe.lazy()
166
+ .select(column)
167
+ # Remove those rows that do not contain lists at all
168
+ .filter(pl.col(column).is_not_null())
169
+ # Convert lists of N items to N individual rows
170
+ .explode(column)
171
+ # Calculate how many nulls are present in lists
172
+ .filter(pl.col(column).is_null())
173
+ .collect()
174
+ .height
175
+ )
176
+ if num_missing_values != 0:
177
+ errors.append(
178
+ ErrorWrapper(
179
+ MissingValuesError(
180
+ f"{num_missing_values} missing "
181
+ f"{'value' if num_missing_values == 1 else 'values'} "
182
+ f"in lists"
183
+ ),
184
+ loc=column,
185
+ )
186
+ )
187
+
102
188
  # Check if any column has a wrong dtype
103
189
  valid_dtypes = schema.valid_dtypes
104
190
  dataframe_datatypes = dict(zip(dataframe.columns, dataframe.dtypes))
@@ -189,7 +275,7 @@ def _find_errors( # noqa: C901
189
275
  )
190
276
  if "_" in constraints.meta.root_names():
191
277
  # An underscore is an alias for the current field
192
- illegal_rows = dataframe.with_column(
278
+ illegal_rows = dataframe.with_columns(
193
279
  pl.col(column_name).alias("_")
194
280
  ).filter(constraints)
195
281
  else:
patito/xdg.py ADDED
@@ -0,0 +1,22 @@
1
+ """Module implementing the XDG directory standard."""
2
+ import os
3
+ from pathlib import Path
4
+ from typing import Optional
5
+
6
+
7
def cache_home(application: Optional[str] = None) -> Path:
    """
    Return path to directory containing user-specific non-essential data files.

    Follows the XDG Base Directory Specification: ``$XDG_CACHE_HOME`` is used
    when it is set to a non-empty value, otherwise ``~/.cache``. A leading
    ``~`` is expanded to the user's home directory. The returned directory
    (including any parents) is created if it does not already exist.

    Args:
        application: An optional name of an application for which to return an
            application-specific cache directory.

    Returns:
        A path object pointing to a directory to store cache files.
    """
    # The XDG spec says an empty XDG_CACHE_HOME must be treated as unset;
    # ``os.environ.get(..., default)`` would return "" and resolve to the CWD.
    raw_path = os.environ.get("XDG_CACHE_HOME") or "~/.cache"
    # Path.resolve() does not expand "~", so without expanduser() the default
    # would become a literal "~" directory below the current working directory.
    path = Path(raw_path).expanduser().resolve()
    if application:
        path = path / application
    path.mkdir(exist_ok=True, parents=True)
    return path
@@ -1,6 +1,7 @@
1
1
  MIT License
2
2
 
3
3
  Copyright (c) 2022 Oda Group Holding AS
4
+ Copyright (c) 2023 Jakob Gerhard Martinussen and contributors
4
5
 
5
6
  Permission is hereby granted, free of charge, to any person obtaining a copy
6
7
  of this software and associated documentation files (the "Software"), to deal