patito 0.6.1__py3-none-any.whl → 0.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
patito/__init__.py CHANGED
@@ -1,4 +1,5 @@
1
1
  """Patito, a data-modelling library built on top of polars and pydantic."""
2
+
2
3
  from polars import Expr, Series, col
3
4
 
4
5
  from patito import exceptions
patito/_docs.py CHANGED
@@ -1,2 +1,3 @@
1
1
  """Ugly workaround for Sphinx + autodoc + ModelMetaclass + classproperty."""
2
+
2
3
  from patito.pydantic import ModelMetaclass as Model # noqa: F401, pragma: no cover
@@ -1 +0,0 @@
1
-
@@ -19,10 +19,9 @@ from patito._pydantic.dtypes import parse_composite_dtype
19
19
 
20
20
 
21
21
  class ColumnInfo(BaseModel, arbitrary_types_allowed=True):
22
- """patito-side model for storing column metadata
22
+ """patito-side model for storing column metadata.
23
23
 
24
24
  Args:
25
- ----
26
25
  constraints (Union[polars.Expression, List[polars.Expression]): A single
27
26
  constraint or list of constraints, expressed as a polars expression objects.
28
27
  All rows must satisfy the given constraint. You can refer to the given column
@@ -40,6 +39,22 @@ class ColumnInfo(BaseModel, arbitrary_types_allowed=True):
40
39
  derived_from: Optional[Union[str, pl.Expr]] = None
41
40
  unique: Optional[bool] = None
42
41
 
42
+ def __repr__(self) -> str:
43
+ """Print only Field attributes whose values are not default (mainly None)."""
44
+ not_default_field = {
45
+ field: getattr(self, field)
46
+ for field in self.model_fields
47
+ if getattr(self, field) is not self.model_fields[field].default
48
+ }
49
+
50
+ string = ""
51
+ for field, value in not_default_field.items():
52
+ string += f"{field}={value}, "
53
+ if string:
54
+ # remove trailing comma and space
55
+ string = string[:-2]
56
+ return f"ColumnInfo({string})"
57
+
43
58
  @field_serializer("constraints", "derived_from")
44
59
  def serialize_exprs(self, exprs: str | pl.Expr | Sequence[pl.Expr] | None) -> Any:
45
60
  if exprs is None:
@@ -56,17 +71,17 @@ class ColumnInfo(BaseModel, arbitrary_types_allowed=True):
56
71
  def _serialize_expr(self, expr: pl.Expr) -> Dict:
57
72
  if isinstance(expr, pl.Expr):
58
73
  return json.loads(
59
- expr.meta.write_json(None)
74
+ expr.meta.serialize(format="json")
60
75
  ) # can we access the dictionary directly?
61
76
  else:
62
77
  raise ValueError(f"Invalid type for expr: {type(expr)}")
63
78
 
64
79
  @field_serializer("dtype")
65
80
  def serialize_dtype(self, dtype: DataTypeClass | DataType | None) -> Any:
66
- """References
67
- ----------
68
- [1] https://stackoverflow.com/questions/76572310/how-to-serialize-deserialize-polars-datatypes
81
+ """Serialize a polars dtype.
69
82
 
83
+ References:
84
+ [1] https://stackoverflow.com/questions/76572310/how-to-serialize-deserialize-polars-datatypes
70
85
  """
71
86
  if dtype is None:
72
87
  return None
@@ -5,7 +5,8 @@ from operator import and_
5
5
  from typing import TYPE_CHECKING, Any, Dict, FrozenSet, Mapping, Optional, Type
6
6
 
7
7
  import polars as pl
8
- from polars.datatypes import DataType, DataTypeClass, DataTypeGroup
8
+ from polars.datatypes import DataType, DataTypeClass
9
+ from polars.datatypes.group import DataTypeGroup
9
10
  from pydantic import TypeAdapter
10
11
 
11
12
  from patito._pydantic.dtypes.utils import (
@@ -39,24 +40,17 @@ def valid_dtypes_for_model(
39
40
  @cache
40
41
  def default_dtypes_for_model(
41
42
  cls: Type[ModelType],
42
- ) -> dict[str, DataTypeClass | DataType]:
43
- default_dtypes = {}
43
+ ) -> dict[str, DataType]:
44
+ default_dtypes: dict[str, DataType] = {}
44
45
  for column in cls.columns:
45
- dtype = cls.column_infos[column].dtype
46
+ dtype = (
47
+ cls.column_infos[column].dtype
48
+ or DtypeResolver(cls.model_fields[column].annotation).default_polars_dtype()
49
+ )
46
50
  if dtype is None:
47
- default_dtype = DtypeResolver(
48
- cls.model_fields[column].annotation
49
- ).default_polars_dtype()
50
- if default_dtype is None:
51
- raise ValueError(
52
- f"Unable to find a default dtype for column `{column}`"
53
- )
54
- else:
55
- default_dtypes[column] = default_dtype
56
- else:
57
- default_dtypes[column] = (
58
- dtype if isinstance(dtype, DataType) else dtype()
59
- ) # if dtype is not instantiated, instantiate it
51
+ raise ValueError(f"Unable to find a default dtype for column `{column}`")
52
+
53
+ default_dtypes[column] = dtype if isinstance(dtype, DataType) else dtype()
60
54
  return default_dtypes
61
55
 
62
56
 
@@ -68,7 +62,6 @@ def validate_polars_dtype(
68
62
  """Check that the polars dtype is valid for the given annotation. Raises ValueError if not.
69
63
 
70
64
  Args:
71
- ----
72
65
  annotation (type[Any] | None): python type annotation
73
66
  dtype (DataType | DataTypeClass | None): polars dtype
74
67
  column (Optional[str], optional): column name. Defaults to None.
@@ -96,7 +89,6 @@ def validate_annotation(
96
89
  """Check that the provided annotation has polars/patito support (we can resolve it to a default dtype). Raises ValueError if not.
97
90
 
98
91
  Args:
99
- ----
100
92
  annotation (type[Any] | None): python type annotation
101
93
  column (Optional[str], optional): column name. Defaults to None.
102
94
 
@@ -130,9 +122,9 @@ class DtypeResolver:
130
122
  return PT_BASE_SUPPORTED_DTYPES
131
123
  return self._valid_polars_dtypes_for_schema(self.schema)
132
124
 
133
- def default_polars_dtype(self) -> DataTypeClass | DataType | None:
125
+ def default_polars_dtype(self) -> DataType | None:
134
126
  if self.annotation == Any:
135
- return pl.String
127
+ return pl.String()
136
128
  return self._default_polars_dtype_for_schema(self.schema)
137
129
 
138
130
  def _valid_polars_dtypes_for_schema(
@@ -197,9 +189,7 @@ class DtypeResolver:
197
189
  PydanticBaseType(pyd_type), props.get("format"), props.get("enum")
198
190
  )
199
191
 
200
- def _default_polars_dtype_for_schema(
201
- self, schema: Dict
202
- ) -> DataTypeClass | DataType | None:
192
+ def _default_polars_dtype_for_schema(self, schema: Dict) -> DataType | None:
203
193
  if "anyOf" in schema:
204
194
  if len(schema["anyOf"]) == 2: # look for optionals first
205
195
  schema = _without_optional(schema)
@@ -216,10 +206,12 @@ class DtypeResolver:
216
206
  def _pydantic_subschema_to_default_dtype(
217
207
  self,
218
208
  props: Dict,
219
- ) -> DataTypeClass | DataType | None:
209
+ ) -> DataType | None:
220
210
  if "column_info" in props: # user has specified in patito model
221
211
  if props["column_info"]["dtype"] is not None:
222
- return dtype_from_string(props["column_info"]["dtype"])
212
+ dtype = dtype_from_string(props["column_info"]["dtype"])
213
+ dtype = dtype() if isinstance(dtype, DataTypeClass) else dtype
214
+ return dtype
223
215
  if "type" not in props:
224
216
  if "enum" in props:
225
217
  raise TypeError("Mixed type enums not supported by patito.")
@@ -231,6 +223,8 @@ class DtypeResolver:
231
223
  )
232
224
  return None
233
225
  pyd_type = props.get("type")
226
+ if pyd_type == "numeric":
227
+ pyd_type = "number"
234
228
  if pyd_type == "array":
235
229
  if "items" not in props:
236
230
  raise NotImplementedError(
@@ -15,12 +15,13 @@ from typing import (
15
15
  )
16
16
 
17
17
  import polars as pl
18
- from polars.datatypes import DataType, DataTypeClass, DataTypeGroup, convert
19
- from polars.datatypes.constants import (
18
+ from polars.datatypes import DataType, DataTypeClass, convert
19
+ from polars.datatypes.group import (
20
20
  DATETIME_DTYPES,
21
21
  DURATION_DTYPES,
22
22
  FLOAT_DTYPES,
23
23
  INTEGER_DTYPES,
24
+ DataTypeGroup,
24
25
  )
25
26
  from polars.polars import (
26
27
  dtype_str_repr, # TODO: this is a rust function, can we implement our own string parser for Time/Duration/Datetime?
@@ -78,11 +79,9 @@ def is_optional(type_annotation: type[Any] | Any | None) -> bool:
78
79
  """Return True if the given type annotation is an Optional annotation.
79
80
 
80
81
  Args:
81
- ----
82
82
  type_annotation: The type annotation to be checked.
83
83
 
84
84
  Returns:
85
- -------
86
85
  True if the outermost type is Optional.
87
86
 
88
87
  """
@@ -92,8 +91,8 @@ def is_optional(type_annotation: type[Any] | Any | None) -> bool:
92
91
 
93
92
 
94
93
  def parse_composite_dtype(dtype: DataTypeClass | DataType) -> str:
95
- """For serialization, converts polars dtype to string representation"""
96
- if dtype in pl.NESTED_DTYPES:
94
+ """For serialization, converts polars dtype to string representation."""
95
+ if dtype.is_nested():
97
96
  if dtype == pl.Struct or isinstance(dtype, pl.Struct):
98
97
  raise NotImplementedError("Structs not yet supported by patito")
99
98
  if not isinstance(dtype, pl.List) or isinstance(dtype, pl.Array):
@@ -103,14 +102,14 @@ def parse_composite_dtype(dtype: DataTypeClass | DataType) -> str:
103
102
  if dtype.inner is None:
104
103
  return convert.DataTypeMappings.DTYPE_TO_FFINAME[dtype.base_type()]
105
104
  return f"{convert.DataTypeMappings.DTYPE_TO_FFINAME[dtype.base_type()]}[{parse_composite_dtype(dtype.inner)}]"
106
- elif dtype in pl.TEMPORAL_DTYPES:
105
+ elif dtype.is_temporal():
107
106
  return cast(str, dtype_str_repr(dtype))
108
107
  else:
109
108
  return convert.DataTypeMappings.DTYPE_TO_FFINAME[dtype]
110
109
 
111
110
 
112
111
  def dtype_from_string(v: str) -> Optional[Union[DataTypeClass, DataType]]:
113
- """For deserialization"""
112
+ """For deserialization."""
114
113
  # TODO test all dtypes
115
114
  return convert.dtype_short_repr_to_dtype(v)
116
115
 
patito/_pydantic/repr.py CHANGED
@@ -82,7 +82,7 @@ class Representation:
82
82
  def __pretty__(
83
83
  self, fmt: Callable[[Any], Any], **kwargs: Any
84
84
  ) -> Generator[Any, None, None]:
85
- """Used by devtools (https://python-devtools.helpmanual.io/) to provide a human readable representations of objects"""
85
+ """Used by devtools (https://python-devtools.helpmanual.io/) to provide a human readable representations of objects."""
86
86
  yield self.__repr_name__() + "("
87
87
  yield 1
88
88
  for name, value in self.__repr_args__():
@@ -101,7 +101,7 @@ class Representation:
101
101
  return f'{self.__repr_name__()}({self.__repr_str__(", ")})'
102
102
 
103
103
  def __rich_repr__(self) -> "RichReprResult":
104
- """Get fields for Rich library"""
104
+ """Get fields for Rich library."""
105
105
  for name, field_repr in self.__repr_args__():
106
106
  if name is None:
107
107
  yield field_repr
@@ -16,14 +16,12 @@ if TYPE_CHECKING:
16
16
  def schema_for_model(cls: Type[ModelType]) -> Dict[str, Dict[str, Any]]:
17
17
  """Return schema properties where definition references have been resolved.
18
18
 
19
- Returns
20
- -------
19
+ Returns:
21
20
  Field information as a dictionary where the keys are field names and the
22
21
  values are dictionaries containing metadata information about the field
23
22
  itself.
24
23
 
25
- Raises
26
- ------
24
+ Raises:
27
25
  TypeError: if a field is annotated with an enum where the values are of
28
26
  different types.
29
27
 
patito/exceptions.py CHANGED
@@ -1,3 +1,5 @@
1
+ """Exceptions used by patito."""
2
+
1
3
  from typing import (
2
4
  TYPE_CHECKING,
3
5
  Any,
@@ -34,19 +36,24 @@ __all__ = "ErrorWrapper", "DataFrameValidationError"
34
36
 
35
37
 
36
38
  class ErrorWrapper(Representation):
39
+ """Error handler for nicely accumulating errors."""
40
+
37
41
  __slots__ = "exc", "_loc"
38
42
 
39
43
  def __init__(self, exc: Exception, loc: Union[str, "Loc"]) -> None:
44
+ """Wrap an error in an ErrorWrapper."""
40
45
  self.exc = exc
41
46
  self._loc = loc
42
47
 
43
48
  def loc_tuple(self) -> "Loc":
49
+ """Represent error as tuple."""
44
50
  if isinstance(self._loc, tuple):
45
51
  return self._loc
46
52
  else:
47
53
  return (self._loc,)
48
54
 
49
55
  def __repr_args__(self) -> "ReprArgs":
56
+ """Pydantic repr."""
50
57
  return [("exc", self.exc), ("loc", self.loc_tuple())]
51
58
 
52
59
 
@@ -56,19 +63,24 @@ ErrorList = Union[Sequence[Any], ErrorWrapper]
56
63
 
57
64
 
58
65
  class DataFrameValidationError(Representation, ValueError):
66
+ """Parent error for DataFrame validation errors."""
67
+
59
68
  __slots__ = "raw_errors", "model", "_error_cache"
60
69
 
61
70
  def __init__(self, errors: Sequence[ErrorList], model: Type["BaseModel"]) -> None:
71
+ """Create a dataframe validation error."""
62
72
  self.raw_errors = errors
63
73
  self.model = model
64
74
  self._error_cache: Optional[List["ErrorDict"]] = None
65
75
 
66
76
  def errors(self) -> List["ErrorDict"]:
77
+ """Get list of errors."""
67
78
  if self._error_cache is None:
68
79
  self._error_cache = list(flatten_errors(self.raw_errors))
69
80
  return self._error_cache
70
81
 
71
82
  def __str__(self) -> str:
83
+ """String reprentation of error."""
72
84
  errors = self.errors()
73
85
  no_errors = len(errors)
74
86
  return (
@@ -77,6 +89,7 @@ class DataFrameValidationError(Representation, ValueError):
77
89
  )
78
90
 
79
91
  def __repr_args__(self) -> "ReprArgs":
92
+ """Pydantic repr."""
80
93
  return [("model", self.model.__name__), ("errors", self.errors())]
81
94
 
82
95