patito 0.5.1__py3-none-any.whl → 0.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,139 @@
1
+ import sys
2
+ import types
3
+ import typing
4
+ from typing import (
5
+ Any,
6
+ Callable,
7
+ Generator,
8
+ Iterable,
9
+ Literal,
10
+ Optional,
11
+ Sequence,
12
+ Tuple,
13
+ Type,
14
+ Union,
15
+ get_args,
16
+ get_origin,
17
+ )
18
+
19
+ if typing.TYPE_CHECKING:
20
+ Loc = Tuple[Union[int, str], ...]
21
+ ReprArgs = Sequence[Tuple[Optional[str], Any]]
22
+ RichReprResult = Iterable[
23
+ Union[Any, Tuple[Any], Tuple[str, Any], Tuple[str, Any, Any]]
24
+ ]
25
+
26
# The private base class of ``typing`` constructs is spelled differently
# across Python versions; use whichever one this interpreter provides.
try:
    _TypingBase = typing._TypingBase  # type: ignore[attr-defined]
except AttributeError:
    _TypingBase = typing._Final  # type: ignore[attr-defined]

typing_base = _TypingBase

if sys.version_info >= (3, 9):
    TypingGenericAlias = types.GenericAlias  # type: ignore
else:
    # python < 3.9 does not have GenericAlias (list[int], tuple[str, ...] and so on)
    TypingGenericAlias = ()

if sys.version_info >= (3, 10):

    def origin_is_union(tp: type[Any] | None) -> bool:
        """Return True when ``tp`` is ``typing.Union`` or the ``X | Y`` union type."""
        return tp is types.UnionType or tp is typing.Union

    # Every alias flavour that carries type arguments on Python >= 3.10.
    WithArgsTypes = (typing._GenericAlias, types.GenericAlias, types.UnionType)  # type: ignore[attr-defined]

else:

    def origin_is_union(tp: Optional[Type[Any]]) -> bool:
        """Return True when ``tp`` is ``typing.Union``."""
        return tp is typing.Union

    WithArgsTypes = (TypingGenericAlias,)
52
+
53
+
54
class Representation:
    """Mixin that derives ``__str__``, ``__repr__``, ``__pretty__`` and ``__rich_repr__``.

    All four renderings are built from ``__repr_args__``, so subclasses normally
    only override that one method (and optionally ``__repr_name__``).

    ``__pretty__`` is the hook used by
    [devtools](https://python-devtools.helpmanual.io/) to produce human
    readable representations of objects.
    """

    __slots__: Tuple[str, ...] = ()

    def __repr_args__(self) -> "ReprArgs":
        """Return the attributes shown in the textual representations.

        The default lists every slot whose value is not ``None``. Overrides may
        return either:
          * name - value pairs, e.g.: `[('foo_name', 'foo'), ('bar_name', ['b', 'a', 'r'])]`
          * or, just values, e.g.: `[(None, 'foo'), (None, ['b', 'a', 'r'])]`
        """
        shown = []
        for slot in self.__slots__:
            value = getattr(self, slot)
            if value is not None:
                shown.append((slot, value))
        return shown

    def __repr_name__(self) -> str:
        """Return the class name of the instance, as used in ``__repr__``."""
        return self.__class__.__name__

    def __repr_str__(self, join_str: str) -> str:
        """Render the repr arguments, separated by ``join_str``."""
        pieces = [
            repr(value) if name is None else f"{name}={value!r}"
            for name, value in self.__repr_args__()
        ]
        return join_str.join(pieces)

    def __pretty__(
        self, fmt: Callable[[Any], Any], **kwargs: Any
    ) -> Generator[Any, None, None]:
        """Yield the token stream consumed by devtools (https://python-devtools.helpmanual.io/).

        Strings are emitted as text; the integers ``1``, ``0`` and ``-1`` are
        yielded between them exactly as devtools' pretty-printing hook expects.
        """
        yield self.__repr_name__() + "("
        yield 1
        for name, value in self.__repr_args__():
            if name is not None:
                yield name + "="
            yield from (fmt(value), ",", 0)
        yield from (-1, ")")

    def __str__(self) -> str:
        """Return the repr arguments joined by single spaces."""
        return self.__repr_str__(" ")

    def __repr__(self) -> str:
        """Return ``ClassName(arg, name=value, ...)``."""
        class_name = self.__repr_name__()
        arguments = self.__repr_str__(", ")
        return f"{class_name}({arguments})"

    def __rich_repr__(self) -> "RichReprResult":
        """Yield the repr arguments in the form expected by the Rich library."""
        for name, field_repr in self.__repr_args__():
            yield field_repr if name is None else (name, field_repr)
110
+
111
+
112
def display_as_type(obj: Any) -> str:
    """Pretty representation of a type, as close as possible to its definition string.

    Takes some logic from `typing._type_repr`.

    Args:
        obj: A type, a typing construct, a function, ``...``, a
            ``Representation`` instance, or any other object (in which case
            its class is displayed).

    Returns:
        A short string such as ``"int"``, ``"list[int]"`` or ``"Union[int, str]"``.
    """
    if isinstance(obj, types.FunctionType):
        return obj.__name__
    if obj is ...:
        return "..."
    if isinstance(obj, Representation):
        return repr(obj)

    # Plain instances are displayed through their class.
    if not isinstance(obj, (typing_base, WithArgsTypes, type)):
        obj = obj.__class__

    origin = get_origin(obj)
    if origin_is_union(origin):
        args = ", ".join(map(display_as_type, get_args(obj)))
        return f"Union[{args}]"

    if isinstance(obj, WithArgsTypes):
        if origin == Literal:
            # Literal arguments are values, not types, so show them verbatim.
            args = ", ".join(map(repr, get_args(obj)))
        else:
            args = ", ".join(map(display_as_type, get_args(obj)))
        try:
            return f"{obj.__qualname__}[{args}]"
        except AttributeError:
            # Some parameterised aliases forward attribute access to an origin
            # that has no ``__qualname__``; fall back to the stripped repr
            # instead of failing while merely formatting a type.
            return repr(obj).replace("typing.", "").replace("typing_extensions.", "")

    if isinstance(obj, type):
        return obj.__qualname__
    return repr(obj).replace("typing.", "").replace("typing_extensions.", "")
@@ -0,0 +1,96 @@
1
+ from __future__ import annotations
2
+
3
+ from functools import cache
4
+ from typing import TYPE_CHECKING, Any, Dict, Mapping, Optional, Type, cast, get_args
5
+
6
+ from pydantic.fields import FieldInfo
7
+
8
+ from patito._pydantic.column_info import ColumnInfo
9
+ from patito._pydantic.dtypes import PYTHON_TO_PYDANTIC_TYPES
10
+
11
+ if TYPE_CHECKING:
12
+ from patito.pydantic import ModelType
13
+
14
+
15
+ @cache
16
+ def schema_for_model(cls: Type[ModelType]) -> Dict[str, Dict[str, Any]]:
17
+ """Return schema properties where definition references have been resolved.
18
+
19
+ Returns:
20
+ Field information as a dictionary where the keys are field names and the
21
+ values are dictionaries containing metadata information about the field
22
+ itself.
23
+
24
+ Raises:
25
+ TypeError: if a field is annotated with an enum where the values are of
26
+ different types.
27
+
28
+ """
29
+ schema = cls.model_json_schema(by_alias=False, ref_template="{model}")
30
+ fields = {}
31
+ # first resolve definitions for nested models TODO checks for one-way references, if models are self-referencing this falls apart with recursion depth error
32
+ for f in cls.model_fields.values():
33
+ annotation = f.annotation
34
+ cls._update_dfn(annotation, schema)
35
+ for a in get_args(annotation):
36
+ cls._update_dfn(a, schema)
37
+ for field_name, field_info in schema["properties"].items():
38
+ fields[field_name] = _append_field_info_to_props(
39
+ field_info=field_info,
40
+ field_name=field_name,
41
+ required=field_name in schema.get("required", set()),
42
+ model_schema=schema,
43
+ )
44
+ schema["properties"] = fields
45
+ return schema
46
+
47
+
48
+ @cache
49
+ def column_infos_for_model(cls: Type[ModelType]) -> Mapping[str, ColumnInfo]:
50
+ fields = cls.model_fields
51
+
52
+ def get_column_info(field: FieldInfo) -> ColumnInfo:
53
+ if field.json_schema_extra is None:
54
+ return cast(ColumnInfo, cls.column_info_class())
55
+ elif callable(field.json_schema_extra):
56
+ raise NotImplementedError(
57
+ "Callable json_schema_extra not supported by patito."
58
+ )
59
+ return cast(ColumnInfo, field.json_schema_extra["column_info"])
60
+
61
+ return {k: get_column_info(v) for k, v in fields.items()}
62
+
63
+
64
+ def _append_field_info_to_props(
65
+ field_info: Dict[str, Any],
66
+ field_name: str,
67
+ model_schema: Dict[str, Any],
68
+ required: Optional[bool] = None,
69
+ ) -> Dict[str, Any]:
70
+ if "$ref" in field_info: # TODO onto runtime append
71
+ definition = model_schema["$defs"][field_info["$ref"]]
72
+ if "enum" in definition and "type" not in definition:
73
+ enum_types = set(type(value) for value in definition["enum"])
74
+ if len(enum_types) > 1:
75
+ raise TypeError(
76
+ "All enumerated values of enums used to annotate "
77
+ "Patito model fields must have the same type. "
78
+ "Encountered types: "
79
+ f"{sorted(map(lambda t: t.__name__, enum_types))}."
80
+ )
81
+ enum_type = enum_types.pop()
82
+ definition["type"] = PYTHON_TO_PYDANTIC_TYPES[enum_type]
83
+ field = definition
84
+ else:
85
+ field = field_info
86
+ if "items" in field_info:
87
+ field["items"] = _append_field_info_to_props(
88
+ field_info=field_info["items"],
89
+ field_name=field_name,
90
+ model_schema=model_schema,
91
+ )
92
+ if required is not None:
93
+ field["required"] = required
94
+ if "const" in field_info and "type" not in field_info:
95
+ field["type"] = PYTHON_TO_PYDANTIC_TYPES[type(field_info["const"])]
96
+ return field
patito/exceptions.py CHANGED
@@ -1,14 +1,181 @@
1
- """Module containing all custom exceptions raised by patito."""
1
+ """Exceptions used by patito."""
2
2
 
3
- import pydantic
3
+ from typing import (
4
+ TYPE_CHECKING,
5
+ Any,
6
+ Dict,
7
+ Generator,
8
+ List,
9
+ Optional,
10
+ Sequence,
11
+ Tuple,
12
+ Type,
13
+ TypedDict,
14
+ Union,
15
+ )
4
16
 
17
+ from patito._pydantic.repr import Representation
5
18
 
6
- class ValidationError(pydantic.ValidationError):
7
- """Exception raised when dataframe does not match schema."""
19
+ if TYPE_CHECKING:
20
+ from pydantic import BaseModel
8
21
 
22
+ Loc = Tuple[Union[int, str], ...]
9
23
 
10
- class ErrorWrapper(pydantic.error_wrappers.ErrorWrapper):
11
- """Wrapper for specific column validation error."""
24
+ class _ErrorDictRequired(TypedDict):
25
+ loc: Loc
26
+ msg: str
27
+ type: str
28
+
29
+ class ErrorDict(_ErrorDictRequired, total=False):
30
+ ctx: Dict[str, Any]
31
+
32
+ from patito._pydantic.repr import ReprArgs
33
+
34
+
35
+ __all__ = "ErrorWrapper", "DataFrameValidationError"
36
+
37
+
38
class ErrorWrapper(Representation):
    """Pair an exception with the location it refers to, for accumulating errors."""

    __slots__ = ("exc", "_loc")

    def __init__(self, exc: Exception, loc: Union[str, "Loc"]) -> None:
        """Wrap ``exc`` together with its location ``loc``."""
        self.exc = exc
        self._loc = loc

    def loc_tuple(self) -> "Loc":
        """Return the location as a tuple, wrapping a non-tuple location in one."""
        location = self._loc
        return location if isinstance(location, tuple) else (location,)

    def __repr_args__(self) -> "ReprArgs":
        """Return the wrapped exception and its tuple location for the repr."""
        return [("exc", self.exc), ("loc", self.loc_tuple())]
58
+
59
+
60
+ # ErrorList is something like Union[List[Union[List[ErrorWrapper], ErrorWrapper]], ErrorWrapper]
61
+ # but recursive, therefore just use:
62
+ ErrorList = Union[Sequence[Any], ErrorWrapper]
63
+
64
+
65
class DataFrameValidationError(Representation, ValueError):
    """Parent error for DataFrame validation errors."""

    __slots__ = ("raw_errors", "model", "_error_cache")

    def __init__(self, errors: Sequence[ErrorList], model: Type["BaseModel"]) -> None:
        """Store the raw errors and the model they were raised for."""
        self.raw_errors = errors
        self.model = model
        # Filled lazily by ``errors()`` on first access.
        self._error_cache: Optional[List["ErrorDict"]] = None

    def errors(self) -> List["ErrorDict"]:
        """Return the flattened list of error dictionaries, computed once."""
        cached = self._error_cache
        if cached is None:
            cached = self._error_cache = list(flatten_errors(self.raw_errors))
        return cached

    def __str__(self) -> str:
        """Return a summary line followed by one entry per error."""
        error_dicts = self.errors()
        count = len(error_dicts)
        plural = "" if count == 1 else "s"
        header = f"{count} validation error{plural} for {self.model.__name__}"
        return f"{header}\n{display_errors(error_dicts)}"

    def __repr_args__(self) -> "ReprArgs":
        """Return the model name and flattened errors for the repr."""
        return [("model", self.model.__name__), ("errors", self.errors())]
94
+
95
+
96
def display_errors(errors: List["ErrorDict"]) -> str:
    """Render error dictionaries as text, one location line and one detail line each."""
    rendered = []
    for err in errors:
        location = _display_error_loc(err)
        message = err["msg"]
        details = _display_error_type_and_ctx(err)
        rendered.append(f"{location}\n {message} ({details})")
    return "\n".join(rendered)
101
+
102
+
103
+ def _display_error_loc(error: "ErrorDict") -> str:
104
+ return " -> ".join(str(e) for e in error["loc"])
105
+
106
+
107
+ def _display_error_type_and_ctx(error: "ErrorDict") -> str:
108
+ t = "type=" + error["type"]
109
+ ctx = error.get("ctx")
110
+ if ctx:
111
+ return t + "".join(f"; {k}={v}" for k, v in ctx.items())
112
+ else:
113
+ return t
114
+
115
+
116
def flatten_errors(
    errors: Sequence[Any], loc: Optional["Loc"] = None
) -> Generator["ErrorDict", None, None]:
    """Flatten a nested structure of wrapped errors into error dictionaries.

    Args:
        errors: ``ErrorWrapper`` instances and/or nested lists or tuples of them.
        loc: Location prefix prepended to the location of every wrapper found.

    Yields:
        One error dictionary per leaf exception. A wrapped
        ``DataFrameValidationError`` is expanded recursively, with the
        wrapper's location used as the prefix for its raw errors.

    Raises:
        RuntimeError: if an item is neither an ``ErrorWrapper`` nor a list or
            tuple.
    """
    for error in errors:
        if isinstance(error, ErrorWrapper):
            error_loc = loc + error.loc_tuple() if loc else error.loc_tuple()

            if isinstance(error.exc, DataFrameValidationError):
                yield from flatten_errors(error.exc.raw_errors, error_loc)
            else:
                yield error_dict(error.exc, error_loc)
        elif isinstance(error, (list, tuple)):
            # The signature promises ``Sequence``, so tuples are accepted as
            # nested groups alongside lists.
            yield from flatten_errors(error, loc=loc)
        else:
            raise RuntimeError(f"Unknown error object: {error}")
134
+
135
+
136
def error_dict(exc: Exception, loc: "Loc") -> "ErrorDict":
    """Build the error dictionary describing ``exc`` at location ``loc``.

    The message is the exception's ``msg_template`` formatted with the
    exception's instance attributes when such a template exists, otherwise
    ``str(exc)``. Non-empty instance attributes are attached under ``"ctx"``.
    """
    exc_type = get_exc_type(exc.__class__)
    template = getattr(exc, "msg_template", None)
    context = exc.__dict__
    message = template.format(**context) if template else str(exc)

    result: "ErrorDict" = {"loc": loc, "msg": message, "type": exc_type}
    if context:
        result["ctx"] = context
    return result
151
+
152
+
153
+ _EXC_TYPE_CACHE: Dict[Type[Exception], str] = {}
154
+
155
+
156
def get_exc_type(cls: Type[Exception]) -> str:
    """Return the error-type string for ``cls``, memoised per exception class."""
    # slightly more efficient than using lru_cache since we don't need to worry about the cache filling up
    cached = _EXC_TYPE_CACHE.get(cls)
    if cached is None:
        cached = _EXC_TYPE_CACHE[cls] = _get_exc_type(cls)
    return cached
164
+
165
+
166
+ def _get_exc_type(cls: Type[Exception]) -> str:
167
+ if issubclass(cls, AssertionError):
168
+ return "assertion_error"
169
+
170
+ base_name = "type_error" if issubclass(cls, TypeError) else "value_error"
171
+ if cls in (TypeError, ValueError):
172
+ # just TypeError or ValueError, no extra code
173
+ return base_name
174
+
175
+ # if it's not a TypeError or ValueError, we just take the lowercase of the exception name
176
+ # no chaining or snake case logic, use "code" for more complex error types.
177
+ code = getattr(cls, "code", None) or cls.__name__.replace("Error", "").lower()
178
+ return base_name + "." + code
12
179
 
13
180
 
14
181
  class WrongColumnsError(TypeError):
@@ -19,7 +186,7 @@ class MissingColumnsError(WrongColumnsError):
19
186
  """Exception for when a dataframe is missing one or more columns."""
20
187
 
21
188
 
22
- class SuperflousColumnsError(WrongColumnsError):
189
+ class SuperfluousColumnsError(WrongColumnsError):
23
190
  """Exception for when a dataframe has one ore more non-specified columns."""
24
191
 
25
192