patito-0.5.1-py3-none-any.whl → patito-0.6.2-py3-none-any.whl
This diff shows the contents of publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.
- patito/__init__.py +4 -23
- patito/_docs.py +1 -0
- patito/_pydantic/__init__.py +0 -0
- patito/_pydantic/column_info.py +94 -0
- patito/_pydantic/dtypes/__init__.py +25 -0
- patito/_pydantic/dtypes/dtypes.py +249 -0
- patito/_pydantic/dtypes/utils.py +227 -0
- patito/_pydantic/repr.py +139 -0
- patito/_pydantic/schema.py +96 -0
- patito/exceptions.py +174 -7
- patito/polars.py +310 -102
- patito/pydantic.py +361 -511
- patito/validators.py +229 -96
- {patito-0.5.1.dist-info → patito-0.6.2.dist-info}/METADATA +12 -26
- patito-0.6.2.dist-info/RECORD +17 -0
- patito/database.py +0 -658
- patito/duckdb.py +0 -2793
- patito/sql.py +0 -88
- patito/xdg.py +0 -22
- patito-0.5.1.dist-info/RECORD +0 -14
- {patito-0.5.1.dist-info → patito-0.6.2.dist-info}/LICENSE +0 -0
- {patito-0.5.1.dist-info → patito-0.6.2.dist-info}/WHEEL +0 -0
patito/__init__.py
CHANGED
@@ -1,47 +1,28 @@
 """Patito, a data-modelling library built on top of polars and pydantic."""
+
 from polars import Expr, Series, col
 
-from patito import exceptions
-from patito.exceptions import
+from patito import exceptions
+from patito.exceptions import DataFrameValidationError
 from patito.polars import DataFrame, LazyFrame
 from patito.pydantic import Field, Model
 
 _CACHING_AVAILABLE = False
-_DUCKDB_AVAILABLE = False
 field = col("_")
 __all__ = [
     "DataFrame",
+    "DataFrameValidationError",
     "Expr",
     "Field",
     "LazyFrame",
     "Model",
     "Series",
-    "ValidationError",
     "_CACHING_AVAILABLE",
-    "_DUCKDB_AVAILABLE",
     "col",
     "exceptions",
     "field",
-    "sql",
 ]
 
-try:
-    from patito import duckdb
-
-    _DUCKDB_AVAILABLE = True
-    __all__ += ["duckdb"]
-except ImportError:  # pragma: no cover
-    pass
-
-try:
-    from patito.database import Database
-
-    _CACHING_AVAILABLE = True
-    __all__ += ["Database"]
-except ImportError:
-    pass
-
-
 try:
     from importlib.metadata import PackageNotFoundError, version
 except ImportError:  # pragma: no cover
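The user-visible part of this hunk is the rename of the exported exception from `ValidationError` to `DataFrameValidationError` (and the removal of the optional duckdb/caching imports). A minimal sketch of catching the new name, assuming the long-standing `Model.validate` entry point and a hypothetical `Product` model:

import patito as pt
import polars as pl

class Product(pt.Model):
    name: str
    price: int

# "price" holds strings, so validation should fail against the int annotation.
df = pl.DataFrame({"name": ["apple"], "price": ["not a number"]})
try:
    Product.validate(df)
except pt.DataFrameValidationError as exc:
    print(exc)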
patito/_docs.py
CHANGED
File without changes

patito/_pydantic/column_info.py
ADDED
@@ -0,0 +1,94 @@
from __future__ import annotations

import json
from typing import (
    Any,
    Dict,
    Optional,
    Sequence,
    Type,
    TypeVar,
    Union,
)

import polars as pl
from polars.datatypes import DataType, DataTypeClass
from pydantic import BaseModel, field_serializer

from patito._pydantic.dtypes import parse_composite_dtype


class ColumnInfo(BaseModel, arbitrary_types_allowed=True):
    """patito-side model for storing column metadata.

    Args:
        constraints (Union[polars.Expression, List[polars.Expression]): A single
            constraint or list of constraints, expressed as a polars expression objects.
            All rows must satisfy the given constraint. You can refer to the given column
            with ``pt.field``, which will automatically be replaced with
            ``polars.col(<field_name>)`` before evaluation.
        derived_from (Union[str, polars.Expr]): used to mark fields that are meant to be derived from other fields. Users can specify a polars expression that will be called to derive the column value when `pt.DataFrame.derive` is called.
        dtype (polars.datatype.DataType): The given dataframe column must have the given
            polars dtype, for instance ``polars.UInt64`` or ``pl.Float32``.
        unique (bool): All row values must be unique.

    """

    dtype: Optional[Union[DataTypeClass, DataType]] = None
    constraints: Optional[Union[pl.Expr, Sequence[pl.Expr]]] = None
    derived_from: Optional[Union[str, pl.Expr]] = None
    unique: Optional[bool] = None

    def __repr__(self) -> str:
        """Print only Field attributes whose values are not default (mainly None)."""
        not_default_field = {
            field: getattr(self, field)
            for field in self.model_fields
            if getattr(self, field) is not self.model_fields[field].default
        }

        string = ""
        for field, value in not_default_field.items():
            string += f"{field}={value}, "
        if string:
            # remove trailing comma and space
            string = string[:-2]
        return f"ColumnInfo({string})"

    @field_serializer("constraints", "derived_from")
    def serialize_exprs(self, exprs: str | pl.Expr | Sequence[pl.Expr] | None) -> Any:
        if exprs is None:
            return None
        elif isinstance(exprs, str):
            return exprs
        elif isinstance(exprs, pl.Expr):
            return self._serialize_expr(exprs)
        elif isinstance(exprs, Sequence):
            return [self._serialize_expr(c) for c in exprs]
        else:
            raise ValueError(f"Invalid type for exprs: {type(exprs)}")

    def _serialize_expr(self, expr: pl.Expr) -> Dict:
        if isinstance(expr, pl.Expr):
            return json.loads(
                expr.meta.serialize(None)
            )  # can we access the dictionary directly?
        else:
            raise ValueError(f"Invalid type for expr: {type(expr)}")

    @field_serializer("dtype")
    def serialize_dtype(self, dtype: DataTypeClass | DataType | None) -> Any:
        """Serialize a polars dtype.

        References:
            [1] https://stackoverflow.com/questions/76572310/how-to-serialize-deserialize-polars-datatypes
        """
        if dtype is None:
            return None
        elif isinstance(dtype, DataTypeClass) or isinstance(dtype, DataType):
            return parse_composite_dtype(dtype)
        else:
            raise ValueError(f"Invalid type for dtype: {type(dtype)}")


CI = TypeVar("CI", bound=Type[ColumnInfo])
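ColumnInfo is the container for the per-column metadata that `pt.Field` attaches to a pydantic field (dtype, constraints, derived_from, unique). A small sketch of a model exercising that metadata, assuming the `Field` keywords described in the docstring above; the model and values are illustrative only:

import patito as pt
import polars as pl

class Product(pt.Model):
    product_id: int = pt.Field(unique=True)            # stored in ColumnInfo.unique
    price: int = pt.Field(dtype=pl.UInt32)              # stored in ColumnInfo.dtype
    tax: float = pt.Field(constraints=pt.field >= 0)    # pt.field is replaced by pl.col("tax")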
patito/_pydantic/dtypes/__init__.py
ADDED
@@ -0,0 +1,25 @@
from patito._pydantic.dtypes.dtypes import (
    DtypeResolver,
    default_dtypes_for_model,
    valid_dtypes_for_model,
    validate_annotation,
    validate_polars_dtype,
)
from patito._pydantic.dtypes.utils import (
    PYTHON_TO_PYDANTIC_TYPES,
    dtype_from_string,
    is_optional,
    parse_composite_dtype,
)

__all__ = [
    "DtypeResolver",
    "validate_annotation",
    "validate_polars_dtype",
    "parse_composite_dtype",
    "dtype_from_string",
    "valid_dtypes_for_model",
    "default_dtypes_for_model",
    "PYTHON_TO_PYDANTIC_TYPES",
    "is_optional",
]
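These re-exports are the internal dtype API used when building models. A rough illustration of the two validators defined in dtypes.py below, assuming they are called directly; the annotations and dtypes are arbitrary examples:

import polars as pl
from patito._pydantic.dtypes import validate_annotation, validate_polars_dtype

validate_polars_dtype(annotation=int, dtype=pl.UInt16)  # compatible: returns None
validate_polars_dtype(annotation=int, dtype=pl.String)  # raises ValueError: invalid dtype for int
validate_annotation(list[int])                          # resolvable to a default list dtype: returns None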
patito/_pydantic/dtypes/dtypes.py
ADDED
@@ -0,0 +1,249 @@
from __future__ import annotations

from functools import cache, reduce
from operator import and_
from typing import TYPE_CHECKING, Any, Dict, FrozenSet, Mapping, Optional, Type

import polars as pl
from polars.datatypes import DataType, DataTypeClass, DataTypeGroup
from pydantic import TypeAdapter

from patito._pydantic.dtypes.utils import (
    PT_BASE_SUPPORTED_DTYPES,
    PydanticBaseType,
    _pyd_type_to_default_dtype,
    _pyd_type_to_valid_dtypes,
    _without_optional,
    dtype_from_string,
)
from patito._pydantic.repr import display_as_type

if TYPE_CHECKING:
    from patito.pydantic import ModelType


@cache
def valid_dtypes_for_model(
    cls: Type[ModelType],
) -> Mapping[str, FrozenSet[DataTypeClass]]:
    return {
        column: (
            DtypeResolver(cls.model_fields[column].annotation).valid_polars_dtypes()
            if cls.column_infos[column].dtype is None
            else DataTypeGroup([cls.dtypes[column]], match_base_type=False)
        )
        for column in cls.columns
    }


@cache
def default_dtypes_for_model(
    cls: Type[ModelType],
) -> dict[str, DataType]:
    default_dtypes: dict[str, DataType] = {}
    for column in cls.columns:
        dtype = (
            cls.column_infos[column].dtype
            or DtypeResolver(cls.model_fields[column].annotation).default_polars_dtype()
        )
        if dtype is None:
            raise ValueError(f"Unable to find a default dtype for column `{column}`")

        default_dtypes[column] = dtype if isinstance(dtype, DataType) else dtype()
    return default_dtypes


def validate_polars_dtype(
    annotation: type[Any] | None,
    dtype: DataType | DataTypeClass | None,
    column: Optional[str] = None,
) -> None:
    """Check that the polars dtype is valid for the given annotation. Raises ValueError if not.

    Args:
        annotation (type[Any] | None): python type annotation
        dtype (DataType | DataTypeClass | None): polars dtype
        column (Optional[str], optional): column name. Defaults to None.

    """
    if (
        dtype is None or annotation is None
    ):  # no potential conflict between type annotation and chosen polars type
        return
    valid_dtypes = DtypeResolver(annotation).valid_polars_dtypes()
    if dtype not in valid_dtypes:
        if column:
            column_msg = f" for column `{column}`"
        else:
            column_msg = ""
        raise ValueError(
            f"Invalid dtype {dtype}{column_msg}. Allowable polars dtypes for {display_as_type(annotation)} are: {', '.join([str(x) for x in valid_dtypes])}."
        )
    return


def validate_annotation(
    annotation: type[Any] | Any | None, column: Optional[str] = None
) -> None:
    """Check that the provided annotation has polars/patito support (we can resolve it to a default dtype). Raises ValueError if not.

    Args:
        annotation (type[Any] | None): python type annotation
        column (Optional[str], optional): column name. Defaults to None.

    """
    default_dtype = DtypeResolver(annotation).default_polars_dtype()
    if default_dtype is None:
        valid_polars_dtypes = DtypeResolver(annotation).valid_polars_dtypes()
        if column:
            column_msg = f" for column `{column}`"
        else:
            column_msg = ""
        if len(valid_polars_dtypes) == 0:
            raise ValueError(
                f"Annotation {display_as_type(annotation)}{column_msg} is not compatible with any polars dtypes."
            )
        else:
            raise ValueError(
                f"Unable to determine default dtype for annotation {display_as_type(annotation)}{column_msg}. Please provide a valid default polars dtype via the `dtype` argument to `Field`. Valid dtypes are: {', '.join([str(x) for x in valid_polars_dtypes])}."
            )
    return


class DtypeResolver:
    def __init__(self, annotation: Any | None):
        self.annotation = annotation
        self.schema = TypeAdapter(annotation).json_schema()
        self.defs = self.schema.get("$defs", {})

    def valid_polars_dtypes(self) -> DataTypeGroup:
        if self.annotation == Any:
            return PT_BASE_SUPPORTED_DTYPES
        return self._valid_polars_dtypes_for_schema(self.schema)

    def default_polars_dtype(self) -> DataType | None:
        if self.annotation == Any:
            return pl.String()
        return self._default_polars_dtype_for_schema(self.schema)

    def _valid_polars_dtypes_for_schema(
        self,
        schema: Dict,
    ) -> DataTypeGroup:
        valid_type_sets = []
        if "anyOf" in schema:
            schema = _without_optional(schema)
            for sub_props in schema["anyOf"]:
                valid_type_sets.append(
                    self._pydantic_subschema_to_valid_polars_types(sub_props)
                )
        else:
            valid_type_sets.append(
                self._pydantic_subschema_to_valid_polars_types(schema)
            )
        return reduce(and_, valid_type_sets) if valid_type_sets else DataTypeGroup([])

    def _pydantic_subschema_to_valid_polars_types(
        self,
        props: Dict,
    ) -> DataTypeGroup:
        if "type" not in props:
            if "enum" in props:
                raise TypeError("Mixed type enums not supported by patito.")
            elif "const" in props:
                return DtypeResolver(type(props["const"])).valid_polars_dtypes()
            elif "$ref" in props:
                return self._pydantic_subschema_to_valid_polars_types(
                    self.defs[props["$ref"].split("/")[-1]]
                )
            return DataTypeGroup([])
        pyd_type = props.get("type")
        if pyd_type == "array":
            if "items" not in props:
                return DataTypeGroup([])
            array_props = props["items"]
            item_dtypes = self._valid_polars_dtypes_for_schema(array_props)
            # TODO support pl.Array?
            return DataTypeGroup(
                [pl.List(dtype) for dtype in item_dtypes], match_base_type=False
            )
        elif pyd_type == "object":
            if "properties" not in props:
                return DataTypeGroup([])
            object_props = props["properties"]
            return DataTypeGroup(
                [
                    pl.Struct(
                        [
                            pl.Field(
                                name, self._default_polars_dtype_for_schema(sub_props)
                            )
                            for name, sub_props in object_props.items()
                        ]
                    )
                ],
                match_base_type=False,
            )  # for structs, return only the default dtype set to avoid combinatoric issues
        return _pyd_type_to_valid_dtypes(
            PydanticBaseType(pyd_type), props.get("format"), props.get("enum")
        )

    def _default_polars_dtype_for_schema(self, schema: Dict) -> DataType | None:
        if "anyOf" in schema:
            if len(schema["anyOf"]) == 2:  # look for optionals first
                schema = _without_optional(schema)
            if len(schema["anyOf"]) == 1:
                if "column_info" in schema:
                    schema["anyOf"][0]["column_info"] = schema[
                        "column_info"
                    ]  # push column info through optional
                schema = schema["anyOf"][0]
            else:
                return None
        return self._pydantic_subschema_to_default_dtype(schema)

    def _pydantic_subschema_to_default_dtype(
        self,
        props: Dict,
    ) -> DataType | None:
        if "column_info" in props:  # user has specified in patito model
            if props["column_info"]["dtype"] is not None:
                dtype = dtype_from_string(props["column_info"]["dtype"])
                dtype = dtype() if isinstance(dtype, DataTypeClass) else dtype
                return dtype
        if "type" not in props:
            if "enum" in props:
                raise TypeError("Mixed type enums not supported by patito.")
            elif "const" in props:
                return DtypeResolver(type(props["const"])).default_polars_dtype()
            elif "$ref" in props:
                return self._pydantic_subschema_to_default_dtype(
                    self.defs[props["$ref"].split("/")[-1]]
                )
            return None
        pyd_type = props.get("type")
        if pyd_type == "array":
            if "items" not in props:
                raise NotImplementedError(
                    "Unexpected error processing pydantic schema. Please file an issue."
                )
            array_props = props["items"]
            inner_default_type = self._default_polars_dtype_for_schema(array_props)
            if inner_default_type is None:
                return None
            return pl.List(inner_default_type)
        elif pyd_type == "object":
            if "properties" not in props:
                raise NotImplementedError(
                    "dictionaries not currently supported by patito"
                )
            object_props = props["properties"]
            return pl.Struct(
                [
                    pl.Field(name, self._default_polars_dtype_for_schema(sub_props))
                    for name, sub_props in object_props.items()
                ]
            )
        return _pyd_type_to_default_dtype(
            PydanticBaseType(pyd_type), props.get("format"), props.get("enum")
        )
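DtypeResolver is the piece that maps a python/pydantic annotation onto polars dtypes by walking the pydantic JSON schema. A rough sketch of its two public methods, assuming the defaults defined in utils.py below; the expected results are stated as comments, not asserted from the released package:

from typing import Optional
import polars as pl
from patito._pydantic.dtypes import DtypeResolver

resolver = DtypeResolver(Optional[int])
resolver.valid_polars_dtypes()   # integer (and float) dtypes are acceptable for an int annotation
resolver.default_polars_dtype()  # pl.Int64() is the default listed for the "integer" base type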
patito/_pydantic/dtypes/utils.py
ADDED
@@ -0,0 +1,227 @@
from __future__ import annotations

import sys
from enum import Enum
from typing import (
    Any,
    Dict,
    List,
    Optional,
    Sequence,
    Union,
    cast,
    get_args,
    get_origin,
)

import polars as pl
from polars.datatypes import DataType, DataTypeClass, DataTypeGroup, convert
from polars.datatypes.constants import (
    DATETIME_DTYPES,
    DURATION_DTYPES,
    FLOAT_DTYPES,
    INTEGER_DTYPES,
)
from polars.polars import (
    dtype_str_repr,  # TODO: this is a rust function, can we implement our own string parser for Time/Duration/Datetime?
)

PYTHON_TO_PYDANTIC_TYPES = {
    str: "string",
    int: "integer",
    float: "number",
    bool: "boolean",
    type(None): "null",
}

BOOLEAN_DTYPES = DataTypeGroup([pl.Boolean])
STRING_DTYPES = DataTypeGroup([pl.String])
DATE_DTYPES = DataTypeGroup([pl.Date])
TIME_DTYPES = DataTypeGroup([pl.Time])

PT_BASE_SUPPORTED_DTYPES = DataTypeGroup(
    INTEGER_DTYPES
    | FLOAT_DTYPES
    | BOOLEAN_DTYPES
    | STRING_DTYPES
    | DATE_DTYPES
    | DATETIME_DTYPES
    | DURATION_DTYPES
    | TIME_DTYPES
)

if sys.version_info >= (3, 10):  # pragma: no cover
    from types import UnionType  # pyright: ignore

    UNION_TYPES = (Union, UnionType)
else:
    UNION_TYPES = (Union,)  # pragma: no cover


class PydanticBaseType(Enum):
    STRING = "string"
    INTEGER = "integer"
    NUMBER = "number"
    BOOLEAN = "boolean"
    NULL = "null"
    OBJECT = "object"


class PydanticStringFormat(Enum):
    DATE = "date"
    DATE_TIME = "date-time"
    DURATION = "duration"
    TIME = "time"


def is_optional(type_annotation: type[Any] | Any | None) -> bool:
    """Return True if the given type annotation is an Optional annotation.

    Args:
        type_annotation: The type annotation to be checked.

    Returns:
        True if the outermost type is Optional.

    """
    return (get_origin(type_annotation) in UNION_TYPES) and (
        type(None) in get_args(type_annotation)
    )


def parse_composite_dtype(dtype: DataTypeClass | DataType) -> str:
    """For serialization, converts polars dtype to string representation."""
    if dtype in pl.NESTED_DTYPES:
        if dtype == pl.Struct or isinstance(dtype, pl.Struct):
            raise NotImplementedError("Structs not yet supported by patito")
        if not isinstance(dtype, pl.List) or isinstance(dtype, pl.Array):
            raise NotImplementedError(
                f"Unsupported nested dtype: {dtype} of type {type(dtype)}"
            )
        if dtype.inner is None:
            return convert.DataTypeMappings.DTYPE_TO_FFINAME[dtype.base_type()]
        return f"{convert.DataTypeMappings.DTYPE_TO_FFINAME[dtype.base_type()]}[{parse_composite_dtype(dtype.inner)}]"
    elif dtype in pl.TEMPORAL_DTYPES:
        return cast(str, dtype_str_repr(dtype))
    else:
        return convert.DataTypeMappings.DTYPE_TO_FFINAME[dtype]


def dtype_from_string(v: str) -> Optional[Union[DataTypeClass, DataType]]:
    """For deserialization."""
    # TODO test all dtypes
    return convert.dtype_short_repr_to_dtype(v)


def _pyd_type_to_valid_dtypes(
    pyd_type: PydanticBaseType, string_format: Optional[str], enum: List[str] | None
) -> DataTypeGroup:
    if enum is not None:
        _validate_enum_values(pyd_type, enum)
        return DataTypeGroup([pl.Enum(enum), pl.String], match_base_type=False)
    if pyd_type.value == "integer":
        return DataTypeGroup(INTEGER_DTYPES | FLOAT_DTYPES)
    elif pyd_type.value == "number":
        return (
            FLOAT_DTYPES
            if isinstance(FLOAT_DTYPES, DataTypeGroup)
            else DataTypeGroup(FLOAT_DTYPES)
        )
    elif pyd_type.value == "boolean":
        return BOOLEAN_DTYPES
    elif pyd_type.value == "string":
        _string_format = (
            PydanticStringFormat(string_format) if string_format is not None else None
        )
        return _pyd_string_format_to_valid_dtypes(_string_format)
    elif pyd_type.value == "null":
        return DataTypeGroup([pl.Null])
    else:
        return DataTypeGroup([])


def _pyd_type_to_default_dtype(
    pyd_type: PydanticBaseType, string_format: Optional[str], enum: List[str] | None
) -> DataTypeClass | DataType:
    if enum is not None:
        _validate_enum_values(pyd_type, enum)
        return pl.Enum(enum)
    elif pyd_type.value == "integer":
        return pl.Int64()
    elif pyd_type.value == "number":
        return pl.Float64()
    elif pyd_type.value == "boolean":
        return pl.Boolean()
    elif pyd_type.value == "string":
        _string_format = (
            PydanticStringFormat(string_format) if string_format is not None else None
        )
        return _pyd_string_format_to_default_dtype(_string_format)
    elif pyd_type.value == "null":
        return pl.Null()
    elif pyd_type.value == "object":
        raise ValueError("pydantic object types not currently supported by patito")
    else:
        raise NotImplementedError


def _pyd_string_format_to_valid_dtypes(
    string_format: PydanticStringFormat | None,
) -> DataTypeGroup:
    if string_format is None:
        return STRING_DTYPES
    elif string_format.value == "date":
        return DATE_DTYPES
    elif string_format.value == "date-time":
        return (
            DATETIME_DTYPES
            if isinstance(DATE_DTYPES, DataTypeGroup)
            else DataTypeGroup(DATE_DTYPES)
        )
    elif string_format.value == "duration":
        return (
            DURATION_DTYPES
            if isinstance(DURATION_DTYPES, DataTypeGroup)
            else DataTypeGroup(DURATION_DTYPES)
        )
    elif string_format.value == "time":
        return TIME_DTYPES
    else:
        raise NotImplementedError


def _pyd_string_format_to_default_dtype(
    string_format: PydanticStringFormat | None,
) -> DataTypeClass | DataType:
    if string_format is None:
        return pl.String()
    elif string_format.value == "date":
        return pl.Date()
    elif string_format.value == "date-time":
        return pl.Datetime()
    elif string_format.value == "duration":
        return pl.Duration()
    elif string_format.value == "time":
        return pl.Time()
    else:
        raise NotImplementedError


def _without_optional(schema: Dict) -> Dict:
    if "anyOf" in schema:
        for sub_props in schema["anyOf"]:
            if "type" in sub_props and sub_props["type"] == "null":
                schema["anyOf"].remove(sub_props)
    return schema


def _validate_enum_values(pyd_type: PydanticBaseType, enum: Sequence[Any]) -> None:
    enum_types = set(type(value) for value in enum)
    if len(enum_types) > 1:
        raise TypeError(
            f"All enumerated values of enums used to annotate Patito model fields must have the same type. Encountered types: {sorted(map(lambda t: t.__name__, enum_types))}."
        )
    if pyd_type.value != "string":
        raise TypeError(
            f"Enums used to annotate Patito model fields must be strings. Encountered type: {enum_types.pop().__name__}."
        )
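parse_composite_dtype and dtype_from_string form the string round trip that ColumnInfo.serialize_dtype relies on. A small sketch, assuming the strings follow polars' short FFI-style dtype names; the exact spellings shown in comments are illustrative:

from typing import Optional
import polars as pl
from patito._pydantic.dtypes import dtype_from_string, is_optional, parse_composite_dtype

parse_composite_dtype(pl.Int64)           # a short name such as "i64"
dtype_from_string("i64")                  # back to the polars Int64 dtype
parse_composite_dtype(pl.List(pl.Int64))  # nested form, e.g. "list[i64]"
is_optional(Optional[int])                # True: the outermost annotation is Optional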