patito 0.5.1__py3-none-any.whl → 0.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
patito/pydantic.py CHANGED
@@ -1,29 +1,55 @@
1
1
  """Logic related to wrapping logic around the pydantic library."""
2
+
2
3
  from __future__ import annotations
3
4
 
4
5
  import itertools
5
6
  from collections.abc import Iterable
6
- from datetime import date, datetime
7
+ from datetime import date, datetime, time, timedelta
8
+ from functools import partial
9
+ from inspect import getfullargspec
7
10
  from typing import (
8
11
  TYPE_CHECKING,
9
12
  Any,
10
13
  ClassVar,
11
14
  Dict,
15
+ FrozenSet,
16
+ Generic,
12
17
  List,
18
+ Literal,
19
+ Mapping,
13
20
  Optional,
14
- Set,
21
+ Sequence,
22
+ Tuple,
15
23
  Type,
16
24
  TypeVar,
17
25
  Union,
18
26
  cast,
27
+ get_args,
19
28
  )
20
29
 
21
30
  import polars as pl
22
- from polars.datatypes import PolarsDataType
23
- from pydantic import BaseConfig, BaseModel, Field, create_model # noqa: F401
24
- from pydantic.main import ModelMetaclass as PydanticModelMetaclass
25
- from typing_extensions import Literal, get_args
26
-
31
+ from polars.datatypes import DataType, DataTypeClass
32
+ from pydantic import ( # noqa: F401
33
+ BaseModel,
34
+ create_model,
35
+ field_serializer,
36
+ fields,
37
+ )
38
+ from pydantic._internal._model_construction import (
39
+ ModelMetaclass as PydanticModelMetaclass,
40
+ )
41
+ from zoneinfo import ZoneInfo
42
+
43
+ from patito._pydantic.column_info import CI, ColumnInfo
44
+ from patito._pydantic.dtypes import (
45
+ default_dtypes_for_model,
46
+ dtype_from_string,
47
+ is_optional,
48
+ valid_dtypes_for_model,
49
+ validate_annotation,
50
+ validate_polars_dtype,
51
+ )
52
+ from patito._pydantic.schema import column_infos_for_model, schema_for_model
27
53
  from patito.polars import DataFrame, LazyFrame
28
54
  from patito.validators import validate
29
55
 
@@ -36,46 +62,34 @@ except ImportError:
36
62
 
37
63
  if TYPE_CHECKING:
38
64
  import patito.polars
39
- from patito.duckdb import DuckDBSQLType
40
65
 
41
66
  # The generic type of a single row in given Relation.
42
67
  # Should be a typed subclass of Model.
43
68
  ModelType = TypeVar("ModelType", bound="Model")
44
69
 
45
- # A mapping from pydantic types to the equivalent type used in DuckDB
46
- PYDANTIC_TO_DUCKDB_TYPES = {
47
- "integer": "BIGINT",
48
- "string": "VARCHAR",
49
- "number": "DOUBLE",
50
- "boolean": "BOOLEAN",
51
- }
52
-
53
- # A mapping from pydantic types to equivalent dtypes used in polars
54
- PYDANTIC_TO_POLARS_TYPES = {
55
- "integer": pl.Int64,
56
- "string": pl.Utf8,
57
- "number": pl.Float64,
58
- "boolean": pl.Boolean,
59
- }
60
70
 
61
-
62
- class ModelMetaclass(PydanticModelMetaclass):
63
- """
64
- Metclass used by patito.Model.
71
+ class ModelMetaclass(PydanticModelMetaclass, Generic[CI]):
72
+ """Metaclass used by patito.Model.
65
73
 
66
74
  Responsible for setting any relevant model-dependent class properties.
67
75
  """
68
76
 
69
- def __init__(cls, name: str, bases: tuple, clsdict: dict) -> None:
70
- """
71
- Construct new patito model.
77
+ column_info_class: ClassVar[Type[ColumnInfo]] = ColumnInfo
78
+
79
+ if TYPE_CHECKING:
80
+ model_fields: ClassVar[Dict[str, fields.FieldInfo]]
81
+
82
+ def __init__(cls, name: str, bases: tuple, clsdict: dict, **kwargs) -> None:
83
+ """Construct new patito model.
72
84
 
73
85
  Args:
74
86
  name: Name of model class.
75
87
  bases: Tuple of superclasses.
76
88
  clsdict: Dictionary containing class properties.
89
+ **kwargs: Additional keyword arguments.
90
+
77
91
  """
78
- super().__init__(name, bases, clsdict)
92
+ super().__init__(name, bases, clsdict, **kwargs)
79
93
  # Add a custom subclass of patito.DataFrame to the model class,
80
94
  # where .set_model() has been implicitly set.
81
95
  cls.DataFrame = DataFrame._construct_dataframe_model_class(
@@ -86,14 +100,34 @@ class ModelMetaclass(PydanticModelMetaclass):
86
100
  model=cls, # type: ignore
87
101
  )
88
102
 
89
- # --- Class properties ---
90
- # These properties will only be available on Model *classes*, not instantiated
91
- # objects This is backwards compatible to python versions before python 3.9,
92
- # unlike a combination of @classmethod and @property.
103
+ def __hash__(self) -> int:
104
+ """Return hash of the model class."""
105
+ return super().__hash__()
106
+
93
107
  @property
94
- def columns(cls: Type[ModelType]) -> List[str]: # type: ignore
108
+ def column_infos(cls: Type[ModelType]) -> Mapping[str, ColumnInfo]:
109
+ """Return column information for the model."""
110
+ return column_infos_for_model(cls)
111
+
112
+ @property
113
+ def model_schema(cls: Type[ModelType]) -> Mapping[str, Mapping[str, Any]]:
114
+ """Return schema properties where definition references have been resolved.
115
+
116
+ Returns:
117
+ Field information as a dictionary where the keys are field names and the
118
+ values are dictionaries containing metadata information about the field
119
+ itself.
120
+
121
+ Raises:
122
+ TypeError: if a field is annotated with an enum where the values are of
123
+ different types.
124
+
95
125
  """
96
- Return the name of the dataframe columns specified by the fields of the model.
126
+ return schema_for_model(cls)
127
+
128
+ @property
129
+ def columns(cls: Type[ModelType]) -> List[str]:
130
+ """Return the name of the dataframe columns specified by the fields of the model.
97
131
 
98
132
  Returns:
99
133
  List of column names.
@@ -106,15 +140,13 @@ class ModelMetaclass(PydanticModelMetaclass):
106
140
  ...
107
141
  >>> Product.columns
108
142
  ['name', 'price']
143
+
109
144
  """
110
- return list(cls.schema()["properties"].keys())
145
+ return list(cls.model_fields.keys())
111
146
 
112
147
  @property
113
- def dtypes( # type: ignore
114
- cls: Type[ModelType], # pyright: ignore
115
- ) -> dict[str, Type[pl.DataType]]:
116
- """
117
- Return the polars dtypes of the dataframe.
148
+ def dtypes(cls: Type[ModelType]) -> dict[str, DataTypeClass | DataType]:
149
+ """Return the polars dtypes of the dataframe.
118
150
 
119
151
  Unless Field(dtype=...) is specified, the highest signed column dtype
120
152
  is chosen for integer and float columns.
@@ -130,18 +162,16 @@ class ModelMetaclass(PydanticModelMetaclass):
130
162
  ... price: float
131
163
  ...
132
164
  >>> Product.dtypes
133
- {'name': Utf8, 'ideal_temperature': Int64, 'price': Float64}
165
+ {'name': String, 'ideal_temperature': Int64, 'price': Float64}
166
+
134
167
  """
135
- return {
136
- column: valid_dtypes[0] for column, valid_dtypes in cls.valid_dtypes.items()
137
- }
168
+ return default_dtypes_for_model(cls)
138
169
 
139
170
  @property
140
- def valid_dtypes( # type: ignore
141
- cls: Type[ModelType], # pyright: ignore
142
- ) -> dict[str, List[Union[pl.PolarsDataType, pl.List]]]:
143
- """
144
- Return a list of polars dtypes which Patito considers valid for each field.
171
+ def valid_dtypes(
172
+ cls: Type[ModelType],
173
+ ) -> Mapping[str, FrozenSet[DataTypeClass | DataType]]:
174
+ """Return a list of polars dtypes which Patito considers valid for each field.
145
175
 
146
176
  The first item of each list is the default dtype chosen by Patito.
147
177
 
@@ -152,271 +182,12 @@ class ModelMetaclass(PydanticModelMetaclass):
152
182
  NotImplementedError: If one or more model fields are annotated with types
153
183
  not compatible with polars.
154
184
 
155
- Example:
156
- >>> from pprint import pprint
157
- >>> import patito as pt
158
-
159
- >>> class MyModel(pt.Model):
160
- ... bool_column: bool
161
- ... str_column: str
162
- ... int_column: int
163
- ... float_column: float
164
- ...
165
- >>> pprint(MyModel.valid_dtypes)
166
- {'bool_column': [Boolean],
167
- 'float_column': [Float64, Float32],
168
- 'int_column': [Int64, Int32, Int16, Int8, UInt64, UInt32, UInt16, UInt8],
169
- 'str_column': [Utf8]}
170
185
  """
171
- valid_dtypes = {}
172
- for column, props in cls._schema_properties().items():
173
- column_dtypes: List[Union[PolarsDataType, pl.List]]
174
- if props.get("type") == "array":
175
- array_props = props["items"]
176
- item_dtypes = cls._valid_dtypes(props=array_props)
177
- if item_dtypes is None:
178
- raise NotImplementedError(
179
- f"No valid dtype mapping found for column '{column}'."
180
- )
181
- column_dtypes = [pl.List(dtype) for dtype in item_dtypes]
182
- else:
183
- column_dtypes = cls._valid_dtypes(props=props) # pyright: ignore
184
-
185
- if column_dtypes is None:
186
- raise NotImplementedError(
187
- f"No valid dtype mapping found for column '{column}'."
188
- )
189
- valid_dtypes[column] = column_dtypes
190
-
191
- return valid_dtypes
192
-
193
- @staticmethod
194
- def _valid_dtypes( # noqa: C901
195
- props: Dict,
196
- ) -> Optional[List[pl.PolarsDataType]]:
197
- """
198
- Map schema property to list of valid polars data types.
199
-
200
- Args:
201
- props: Dictionary value retrieved from BaseModel._schema_properties().
202
-
203
- Returns:
204
- List of valid dtypes. None if no mapping exists.
205
- """
206
- if "dtype" in props:
207
- return [
208
- props["dtype"],
209
- ]
210
- elif "enum" in props and props["type"] == "string":
211
- return [pl.Categorical, pl.Utf8]
212
- elif "type" not in props:
213
- return None
214
- elif props["type"] == "integer":
215
- return [
216
- pl.Int64,
217
- pl.Int32,
218
- pl.Int16,
219
- pl.Int8,
220
- pl.UInt64,
221
- pl.UInt32,
222
- pl.UInt16,
223
- pl.UInt8,
224
- ]
225
- elif props["type"] == "number":
226
- if props.get("format") == "time-delta":
227
- return [pl.Duration]
228
- else:
229
- return [pl.Float64, pl.Float32]
230
- elif props["type"] == "boolean":
231
- return [pl.Boolean]
232
- elif props["type"] == "string":
233
- string_format = props.get("format")
234
- if string_format is None:
235
- return [pl.Utf8]
236
- elif string_format == "date":
237
- return [pl.Date]
238
- # TODO: Find out why this branch is not being hit
239
- elif string_format == "date-time": # pragma: no cover
240
- return [pl.Datetime]
241
- else:
242
- return None # pragma: no cover
243
- elif props["type"] == "null":
244
- return [pl.Null]
245
- else: # pragma: no cover
246
- return None
247
-
248
- @property
249
- def valid_sql_types( # type: ignore # noqa: C901
250
- cls: Type[ModelType], # pyright: ignore
251
- ) -> dict[str, List["DuckDBSQLType"]]:
252
- """
253
- Return a list of DuckDB SQL types which Patito considers valid for each field.
254
-
255
- The first item of each list is the default dtype chosen by Patito.
256
-
257
- Returns:
258
- A dictionary mapping each column string name to a list of DuckDB SQL types
259
- represented as strings.
260
-
261
- Raises:
262
- NotImplementedError: If one or more model fields are annotated with types
263
- not compatible with DuckDB.
264
-
265
- Example:
266
- >>> import patito as pt
267
- >>> from pprint import pprint
268
-
269
- >>> class MyModel(pt.Model):
270
- ... bool_column: bool
271
- ... str_column: str
272
- ... int_column: int
273
- ... float_column: float
274
- ...
275
- >>> pprint(MyModel.valid_sql_types)
276
- {'bool_column': ['BOOLEAN', 'BOOL', 'LOGICAL'],
277
- 'float_column': ['DOUBLE',
278
- 'FLOAT8',
279
- 'NUMERIC',
280
- 'DECIMAL',
281
- 'REAL',
282
- 'FLOAT4',
283
- 'FLOAT'],
284
- 'int_column': ['INTEGER',
285
- 'INT4',
286
- 'INT',
287
- 'SIGNED',
288
- 'BIGINT',
289
- 'INT8',
290
- 'LONG',
291
- 'HUGEINT',
292
- 'SMALLINT',
293
- 'INT2',
294
- 'SHORT',
295
- 'TINYINT',
296
- 'INT1',
297
- 'UBIGINT',
298
- 'UINTEGER',
299
- 'USMALLINT',
300
- 'UTINYINT'],
301
- 'str_column': ['VARCHAR', 'CHAR', 'BPCHAR', 'TEXT', 'STRING']}
302
- """
303
- valid_dtypes: Dict[str, List["DuckDBSQLType"]] = {}
304
- for column, props in cls._schema_properties().items():
305
- if "sql_type" in props:
306
- valid_dtypes[column] = [
307
- props["sql_type"],
308
- ]
309
- elif "enum" in props and props["type"] == "string":
310
- from patito.duckdb import _enum_type_name
311
-
312
- # fmt: off
313
- valid_dtypes[column] = [ # pyright: ignore
314
- _enum_type_name(field_properties=props), # type: ignore
315
- "VARCHAR", "CHAR", "BPCHAR", "TEXT", "STRING",
316
- ]
317
- # fmt: on
318
- elif "type" not in props:
319
- raise NotImplementedError(
320
- f"No valid sql_type mapping found for column '{column}'."
321
- )
322
- elif props["type"] == "integer":
323
- # fmt: off
324
- valid_dtypes[column] = [
325
- "INTEGER", "INT4", "INT", "SIGNED",
326
- "BIGINT", "INT8", "LONG",
327
- "HUGEINT",
328
- "SMALLINT", "INT2", "SHORT",
329
- "TINYINT", "INT1",
330
- "UBIGINT",
331
- "UINTEGER",
332
- "USMALLINT",
333
- "UTINYINT",
334
- ]
335
- # fmt: on
336
- elif props["type"] == "number":
337
- if props.get("format") == "time-delta":
338
- valid_dtypes[column] = [
339
- "INTERVAL",
340
- ]
341
- else:
342
- # fmt: off
343
- valid_dtypes[column] = [
344
- "DOUBLE", "FLOAT8", "NUMERIC", "DECIMAL",
345
- "REAL", "FLOAT4", "FLOAT",
346
- ]
347
- # fmt: on
348
- elif props["type"] == "boolean":
349
- # fmt: off
350
- valid_dtypes[column] = [
351
- "BOOLEAN", "BOOL", "LOGICAL",
352
- ]
353
- # fmt: on
354
- elif props["type"] == "string":
355
- string_format = props.get("format")
356
- if string_format is None:
357
- # fmt: off
358
- valid_dtypes[column] = [
359
- "VARCHAR", "CHAR", "BPCHAR", "TEXT", "STRING",
360
- ]
361
- # fmt: on
362
- elif string_format == "date":
363
- valid_dtypes[column] = ["DATE"]
364
- # TODO: Find out why this branch is not being hit
365
- elif string_format == "date-time": # pragma: no cover
366
- # fmt: off
367
- valid_dtypes[column] = [
368
- "TIMESTAMP", "DATETIME",
369
- "TIMESTAMP WITH TIMEZONE", "TIMESTAMPTZ",
370
- ]
371
- # fmt: on
372
- elif props["type"] == "null":
373
- valid_dtypes[column] = [
374
- "INTEGER",
375
- ]
376
- else: # pragma: no cover
377
- raise NotImplementedError(
378
- f"No valid sql_type mapping found for column '{column}'"
379
- )
380
-
381
- return valid_dtypes
186
+ return valid_dtypes_for_model(cls)
382
187
 
383
188
  @property
384
- def sql_types( # type: ignore
385
- cls: Type[ModelType], # pyright: ignore
386
- ) -> dict[str, str]:
387
- """
388
- Return compatible DuckDB SQL types for all model fields.
389
-
390
- Returns:
391
- Dictionary with column name keys and SQL type identifier strings.
392
-
393
- Example:
394
- >>> from typing import Literal
395
- >>> import patito as pt
396
-
397
- >>> class MyModel(pt.Model):
398
- ... int_column: int
399
- ... str_column: str
400
- ... float_column: float
401
- ... literal_column: Literal["a", "b", "c"]
402
- ...
403
- >>> MyModel.sql_types
404
- {'int_column': 'INTEGER',
405
- 'str_column': 'VARCHAR',
406
- 'float_column': 'DOUBLE',
407
- 'literal_column': 'enum__4a496993dde04060df4e15a340651b45'}
408
- """
409
- return {
410
- column: valid_types[0]
411
- for column, valid_types in cls.valid_sql_types.items()
412
- }
413
-
414
- @property
415
- def defaults( # type: ignore
416
- cls: Type[ModelType], # pyright: ignore
417
- ) -> dict[str, Any]:
418
- """
419
- Return default field values specified on the model.
189
+ def defaults(cls: Type[ModelType]) -> dict[str, Any]:
190
+ """Return default field values specified on the model.
420
191
 
421
192
  Returns:
422
193
  Dictionary containing fields with their respective default values.
@@ -431,6 +202,7 @@ class ModelMetaclass(PydanticModelMetaclass):
431
202
  ...
432
203
  >>> Product.defaults
433
204
  {'price': 0, 'temperature_zone': 'dry'}
205
+
434
206
  """
435
207
  return {
436
208
  field_name: props["default"]
@@ -439,11 +211,8 @@ class ModelMetaclass(PydanticModelMetaclass):
439
211
  }
440
212
 
441
213
  @property
442
- def non_nullable_columns( # type: ignore
443
- cls: Type[ModelType], # pyright: ignore
444
- ) -> set[str]:
445
- """
446
- Return names of those columns that are non-nullable in the schema.
214
+ def non_nullable_columns(cls: Type[ModelType]) -> set[str]:
215
+ """Return names of those columns that are non-nullable in the schema.
447
216
 
448
217
  Returns:
449
218
  Set of column name strings.
@@ -453,21 +222,26 @@ class ModelMetaclass(PydanticModelMetaclass):
453
222
  >>> import patito as pt
454
223
  >>> class MyModel(pt.Model):
455
224
  ... nullable_field: Optional[int]
456
- ... inferred_nullable_field: int = None
225
+ ... another_nullable_field: Optional[int] = None
457
226
  ... non_nullable_field: int
458
227
  ... another_non_nullable_field: str
459
228
  ...
460
229
  >>> sorted(MyModel.non_nullable_columns)
461
230
  ['another_non_nullable_field', 'non_nullable_field']
231
+
462
232
  """
463
- return set(cls.schema().get("required", {}))
233
+ return set(
234
+ k
235
+ for k in cls.columns
236
+ if not (
237
+ is_optional(cls.model_fields[k].annotation)
238
+ or cls.model_fields[k].annotation == type(None)
239
+ )
240
+ )
464
241
 
465
242
  @property
466
- def nullable_columns( # type: ignore
467
- cls: Type[ModelType], # pyright: ignore
468
- ) -> set[str]:
469
- """
470
- Return names of those columns that are nullable in the schema.
243
+ def nullable_columns(cls: Type[ModelType]) -> set[str]:
244
+ """Return names of those columns that are nullable in the schema.
471
245
 
472
246
  Returns:
473
247
  Set of column name strings.
@@ -477,21 +251,19 @@ class ModelMetaclass(PydanticModelMetaclass):
477
251
  >>> import patito as pt
478
252
  >>> class MyModel(pt.Model):
479
253
  ... nullable_field: Optional[int]
480
- ... inferred_nullable_field: int = None
254
+ ... another_nullable_field: Optional[int] = None
481
255
  ... non_nullable_field: int
482
256
  ... another_non_nullable_field: str
483
257
  ...
484
258
  >>> sorted(MyModel.nullable_columns)
485
- ['inferred_nullable_field', 'nullable_field']
259
+ ['another_nullable_field', 'nullable_field']
260
+
486
261
  """
487
262
  return set(cls.columns) - cls.non_nullable_columns
488
263
 
489
264
  @property
490
- def unique_columns( # type: ignore
491
- cls: Type[ModelType], # pyright: ignore
492
- ) -> set[str]:
493
- """
494
- Return columns with uniqueness constraint.
265
+ def unique_columns(cls: Type[ModelType]) -> set[str]:
266
+ """Return columns with uniqueness constraint.
495
267
 
496
268
  Returns:
497
269
  Set of column name strings.
@@ -507,49 +279,35 @@ class ModelMetaclass(PydanticModelMetaclass):
507
279
  ...
508
280
  >>> sorted(Product.unique_columns)
509
281
  ['barcode', 'product_id']
282
+
510
283
  """
511
- props = cls._schema_properties()
512
- return {column for column in cls.columns if props[column].get("unique", False)}
284
+ infos = cls.column_infos
285
+ return {column for column in cls.columns if infos[column].unique}
286
+
287
+ @property
288
+ def derived_columns(cls: Type[ModelType]) -> set[str]:
289
+ """Return set of columns which are derived from other columns."""
290
+ infos = cls.column_infos
291
+ return {
292
+ column for column in cls.columns if infos[column].derived_from is not None
293
+ }
513
294
 
514
295
 
515
296
  class Model(BaseModel, metaclass=ModelMetaclass):
516
297
  """Custom pydantic class for representing table schema and constructing rows."""
517
298
 
518
- # -- Class properties set by model metaclass --
519
- # This weird combination of a MetaClass + type annotation
520
- # in order to make the following work simultaneously:
521
- # 1. Make these dynamically constructed properties of the class.
522
- # 2. Have the correct type information for type checkers.
523
- # 3. Allow sphinx-autodoc to construct correct documentation.
524
- # 4. Be compatible with python 3.7.
525
- # Once we drop support for python 3.7, we can replace all of this with just a simple
526
- # combination of @property and @classmethod.
527
- columns: ClassVar[List[str]]
528
-
529
- unique_columns: ClassVar[Set[str]]
530
- non_nullable_columns: ClassVar[Set[str]]
531
- nullable_columns: ClassVar[Set[str]]
532
-
533
- dtypes: ClassVar[Dict[str, Type[pl.DataType]]]
534
- sql_types: ClassVar[Dict[str, str]]
535
- valid_dtypes: ClassVar[Dict[str, List[Type[pl.DataType]]]]
536
- valid_sql_types: ClassVar[Dict[str, List["DuckDBSQLType"]]]
537
-
538
- defaults: ClassVar[Dict[str, Any]]
539
-
540
- @classmethod # type: ignore[misc]
541
- @property
542
- def DataFrame(
543
- cls: Type[ModelType],
544
- ) -> Type[DataFrame[ModelType]]: # pyright: ignore # noqa
545
- """Return DataFrame class where DataFrame.set_model() is set to self."""
546
-
547
- @classmethod # type: ignore[misc]
548
- @property
549
- def LazyFrame(
550
- cls: Type[ModelType],
551
- ) -> Type[LazyFrame[ModelType]]: # pyright: ignore
552
- """Return DataFrame class where DataFrame.set_model() is set to self."""
299
+ @classmethod
300
+ def validate_schema(cls: Type[ModelType]):
301
+ """Users should run this after defining or edit a model. We withhold the checks at model definition time to avoid expensive queries of the model schema."""
302
+ for column in cls.columns:
303
+ col_info = cls.column_infos[column]
304
+ field_info = cls.model_fields[column]
305
+ if col_info.dtype:
306
+ validate_polars_dtype(
307
+ annotation=field_info.annotation, dtype=col_info.dtype
308
+ )
309
+ else:
310
+ validate_annotation(field_info.annotation)
553
311
 
554
312
  @classmethod
555
313
  def from_row(
@@ -557,8 +315,7 @@ class Model(BaseModel, metaclass=ModelMetaclass):
557
315
  row: Union["pd.DataFrame", pl.DataFrame],
558
316
  validate: bool = True,
559
317
  ) -> ModelType:
560
- """
561
- Represent a single data frame row as a Patito model.
318
+ """Represent a single data frame row as a Patito model.
562
319
 
563
320
  Args:
564
321
  row: A dataframe, either polars and pandas, consisting of a single row.
@@ -588,12 +345,13 @@ class Model(BaseModel, metaclass=ModelMetaclass):
588
345
  Product(product_id=1, name='product name', price=1.22)
589
346
  >>> Product.from_row(df, validate=False)
590
347
  Product(product_id='1', name='product name', price='1.22')
348
+
591
349
  """
592
350
  if isinstance(row, pl.DataFrame):
593
351
  dataframe = row
594
352
  elif _PANDAS_AVAILABLE and isinstance(row, pd.DataFrame):
595
353
  dataframe = pl.DataFrame._from_pandas(row)
596
- elif _PANDAS_AVAILABLE and isinstance(row, pd.Series): # type: ignore[unreachable]
354
+ elif _PANDAS_AVAILABLE and isinstance(row, pd.Series):
597
355
  return cls(**dict(row.items())) # type: ignore[unreachable]
598
356
  else:
599
357
  raise TypeError(f"{cls.__name__}.from_row not implemented for {type(row)}.")
@@ -605,8 +363,7 @@ class Model(BaseModel, metaclass=ModelMetaclass):
605
363
  dataframe: pl.DataFrame,
606
364
  validate: bool = True,
607
365
  ) -> ModelType:
608
- """
609
- Construct model from a single polars row.
366
+ """Construct model from a single polars row.
610
367
 
611
368
  Args:
612
369
  dataframe: A polars dataframe consisting of one single row.
@@ -640,6 +397,7 @@ class Model(BaseModel, metaclass=ModelMetaclass):
640
397
  Product(product_id=1, name='product name', price=1.22)
641
398
  >>> Product._from_polars(df, validate=False)
642
399
  Product(product_id='1', name='product name', price='1.22')
400
+
643
401
  """
644
402
  if not isinstance(dataframe, pl.DataFrame):
645
403
  raise TypeError(
@@ -657,28 +415,38 @@ class Model(BaseModel, metaclass=ModelMetaclass):
657
415
  if validate:
658
416
  return cls(**dataframe.to_dicts()[0])
659
417
  else:
660
- return cls.construct(**dataframe.to_dicts()[0])
418
+ return cls.model_construct(**dataframe.to_dicts()[0])
661
419
 
662
420
  @classmethod
663
421
  def validate(
664
422
  cls,
665
423
  dataframe: Union["pd.DataFrame", pl.DataFrame],
424
+ columns: Optional[Sequence[str]] = None,
425
+ allow_missing_columns: bool = False,
426
+ allow_superfluous_columns: bool = False,
427
+ **kwargs,
666
428
  ) -> None:
667
- """
668
- Validate the schema and content of the given dataframe.
429
+ """Validate the schema and content of the given dataframe.
669
430
 
670
431
  Args:
671
432
  dataframe: Polars DataFrame to be validated.
433
+ columns: Optional list of columns to validate. If not provided, all columns
434
+ of the dataframe will be validated.
435
+ allow_missing_columns: If True, missing columns will not be considered an error.
436
+ allow_superfluous_columns: If True, additional columns will not be considered an error.
437
+ **kwargs: Additional keyword arguments to be passed to the validation
438
+
439
+ Returns:
440
+ ``None``:
672
441
 
673
442
  Raises:
674
- patito.exceptions.ValidationError: If the given dataframe does not match
443
+ patito.exceptions.DataFrameValidationError: If the given dataframe does not match
675
444
  the given schema.
676
445
 
677
446
  Examples:
678
447
  >>> import patito as pt
679
448
  >>> import polars as pl
680
449
 
681
-
682
450
  >>> class Product(pt.Model):
683
451
  ... product_id: int = pt.Field(unique=True)
684
452
  ... temperature_zone: Literal["dry", "cold", "frozen"]
@@ -693,7 +461,7 @@ class Model(BaseModel, metaclass=ModelMetaclass):
693
461
  ... )
694
462
  >>> try:
695
463
  ... Product.validate(df)
696
- ... except pt.ValidationError as exc:
464
+ ... except pt.DataFrameValidationError as exc:
697
465
  ... print(exc)
698
466
  ...
699
467
  3 validation errors for Product
@@ -703,19 +471,28 @@ class Model(BaseModel, metaclass=ModelMetaclass):
703
471
  2 rows with duplicated values. (type=value_error.rowvalue)
704
472
  temperature_zone
705
473
  Rows with invalid values: {'oven'}. (type=value_error.rowvalue)
474
+
706
475
  """
707
- validate(dataframe=dataframe, schema=cls)
476
+ validate(
477
+ dataframe=dataframe,
478
+ schema=cls,
479
+ columns=columns,
480
+ allow_missing_columns=allow_missing_columns,
481
+ allow_superfluous_columns=allow_superfluous_columns,
482
+ **kwargs,
483
+ )
708
484
 
709
485
  @classmethod
710
486
  def example_value( # noqa: C901
711
487
  cls,
712
- field: str,
713
- ) -> Union[date, datetime, float, int, str, None]:
714
- """
715
- Return a valid example value for the given model field.
488
+ field: Optional[str] = None,
489
+ properties: Optional[Dict[str, Any]] = None,
490
+ ) -> Union[date, datetime, time, timedelta, float, int, str, None, Mapping, List]:
491
+ """Return a valid example value for the given model field.
716
492
 
717
493
  Args:
718
494
  field: Field name identifier.
495
+ properties: Pydantic v2-style properties dict
719
496
 
720
497
  Returns:
721
498
  A single value which is consistent with the given field definition.
@@ -738,10 +515,36 @@ class Model(BaseModel, metaclass=ModelMetaclass):
738
515
  'dummy_string'
739
516
  >>> Product.example_value("temperature_zone")
740
517
  'dry'
518
+
741
519
  """
742
- field_data = cls._schema_properties()
743
- properties = field_data[field]
744
- field_type = properties["type"]
520
+ if field is None and properties is None:
521
+ raise ValueError(
522
+ "Either 'field' or 'properties' must be provided as argument."
523
+ )
524
+ if field is not None and properties is not None:
525
+ raise ValueError(
526
+ "Only one of 'field' or 'properties' can be provided as argument."
527
+ )
528
+ if field:
529
+ properties = cls._schema_properties()[field]
530
+ info = cls.column_infos[field]
531
+ else:
532
+ info = cls.column_info_class()
533
+ properties = properties or {}
534
+
535
+ if "type" in properties:
536
+ field_type = properties["type"]
537
+ elif "anyOf" in properties:
538
+ allowable = [x["type"] for x in properties["anyOf"] if "type" in x]
539
+ if "null" in allowable:
540
+ field_type = "null"
541
+ else:
542
+ field_type = allowable[0]
543
+ else:
544
+ raise NotImplementedError(
545
+ f"Field type for {properties['title']} not found."
546
+ )
547
+
745
548
  if "const" in properties:
746
549
  # The default value is the only valid value, provided as const
747
550
  return properties["const"]
@@ -750,7 +553,10 @@ class Model(BaseModel, metaclass=ModelMetaclass):
750
553
  # A default value has been specified in the model field definition
751
554
  return properties["default"]
752
555
 
753
- elif not properties["required"]:
556
+ elif not properties.get("required", True):
557
+ return None
558
+
559
+ elif field_type == "null":
754
560
  return None
755
561
 
756
562
  elif "enum" in properties:
@@ -758,12 +564,18 @@ class Model(BaseModel, metaclass=ModelMetaclass):
758
564
 
759
565
  elif field_type in {"integer", "number"}:
760
566
  # For integer and float types we must check if there are imposed bounds
761
- lower = properties.get("minimum") or properties.get("exclusiveMinimum")
762
- upper = properties.get("maximum") or properties.get("exclusiveMaximum")
567
+
568
+ minimum = properties.get("minimum")
569
+ exclusive_minimum = properties.get("exclusiveMinimum")
570
+ maximum = properties.get("maximum")
571
+ exclusive_maximum = properties.get("exclusiveMaximum")
572
+
573
+ lower = minimum if minimum is not None else exclusive_minimum
574
+ upper = maximum if maximum is not None else exclusive_maximum
763
575
 
764
576
  # If the dtype is an unsigned integer type, we must return a positive value
765
- if "dtype" in properties:
766
- dtype = properties["dtype"]
577
+ if info.dtype:
578
+ dtype = info.dtype
767
579
  if dtype in (pl.UInt8, pl.UInt16, pl.UInt32, pl.UInt64):
768
580
  lower = 0 if lower is None else max(lower, 0)
769
581
 
@@ -798,7 +610,19 @@ class Model(BaseModel, metaclass=ModelMetaclass):
798
610
  elif "format" in properties and properties["format"] == "date":
799
611
  return date(year=1970, month=1, day=1)
800
612
  elif "format" in properties and properties["format"] == "date-time":
613
+ if "column_info" in properties:
614
+ dtype_str = properties["column_info"]["dtype"]
615
+ dtype = dtype_from_string(dtype_str)
616
+ if getattr(dtype, "time_zone", None) is not None:
617
+ tzinfo = ZoneInfo(dtype.time_zone)
618
+ else:
619
+ tzinfo = None
620
+ return datetime(year=1970, month=1, day=1, tzinfo=tzinfo)
801
621
  return datetime(year=1970, month=1, day=1)
622
+ elif "format" in properties and properties["format"] == "time":
623
+ return time(12, 30)
624
+ elif "format" in properties and properties["format"] == "duration":
625
+ return timedelta(1)
802
626
  elif "minLength" in properties:
803
627
  return "a" * properties["minLength"]
804
628
  elif "maxLength" in properties:
@@ -809,6 +633,18 @@ class Model(BaseModel, metaclass=ModelMetaclass):
809
633
  elif field_type == "boolean":
810
634
  return False
811
635
 
636
+ elif field_type == "object":
637
+ try:
638
+ props_o = cls.model_schema["$defs"][properties["title"]]["properties"]
639
+ return {f: cls.example_value(properties=props_o[f]) for f in props_o}
640
+ except AttributeError as err:
641
+ raise NotImplementedError(
642
+ "Nested example generation only supported for nested pt.Model classes."
643
+ ) from err
644
+
645
+ elif field_type == "array":
646
+ return [cls.example_value(properties=properties["items"])]
647
+
812
648
  else: # pragma: no cover
813
649
  raise NotImplementedError
814
650
 
@@ -817,8 +653,7 @@ class Model(BaseModel, metaclass=ModelMetaclass):
817
653
  cls: Type[ModelType],
818
654
  **kwargs: Any, # noqa: ANN401
819
655
  ) -> ModelType:
820
- """
821
- Produce model instance with filled dummy data for all unspecified fields.
656
+ """Produce model instance with filled dummy data for all unspecified fields.
822
657
 
823
658
  The type annotation of unspecified field is used to fill in type-correct
824
659
  dummy data, e.g. ``-1`` for ``int``, ``"dummy_string"`` for ``str``, and so
@@ -849,6 +684,7 @@ class Model(BaseModel, metaclass=ModelMetaclass):
849
684
  ...
850
685
  >>> Product.example(product_id=1)
851
686
  Product(product_id=1, name='dummy_string', temperature_zone='dry')
687
+
852
688
  """
853
689
  # Non-iterable values besides strings must be repeated
854
690
  wrong_columns = set(kwargs.keys()) - set(cls.columns)
@@ -870,8 +706,7 @@ class Model(BaseModel, metaclass=ModelMetaclass):
870
706
  data: Union[dict, Iterable],
871
707
  columns: Optional[Iterable[str]] = None,
872
708
  ) -> "pd.DataFrame":
873
- """
874
- Generate dataframe with dummy data for all unspecified columns.
709
+ """Generate dataframe with dummy data for all unspecified columns.
875
710
 
876
711
  Offers the same API as the pandas.DataFrame constructor.
877
712
  Non-iterable values, besides strings, are repeated until they become as long as
@@ -903,9 +738,10 @@ class Model(BaseModel, metaclass=ModelMetaclass):
903
738
  ...
904
739
 
905
740
  >>> Product.pandas_examples({"name": ["product A", "product B"]})
906
- product_id name temperature_zone
741
+ product_id name temperature_zone
907
742
  0 -1 product A dry
908
743
  1 -1 product B dry
744
+
909
745
  """
910
746
  if not _PANDAS_AVAILABLE:
911
747
  # Re-trigger the import error, but this time don't catch it
@@ -932,7 +768,7 @@ class Model(BaseModel, metaclass=ModelMetaclass):
932
768
  dummies = []
933
769
  for values in zip(*kwargs.values()):
934
770
  dummies.append(cls.example(**dict(zip(kwargs.keys(), values))))
935
- return pd.DataFrame([dummy.dict() for dummy in dummies])
771
+ return pd.DataFrame([dummy.model_dump() for dummy in dummies])
936
772
 
937
773
  @classmethod
938
774
  def examples(
@@ -940,8 +776,7 @@ class Model(BaseModel, metaclass=ModelMetaclass):
940
776
  data: Optional[Union[dict, Iterable]] = None,
941
777
  columns: Optional[Iterable[str]] = None,
942
778
  ) -> "patito.polars.DataFrame":
943
- """
944
- Generate polars dataframe with dummy data for all unspecified columns.
779
+ """Generate polars dataframe with dummy data for all unspecified columns.
945
780
 
946
781
  This constructor accepts the same data format as polars.DataFrame.
947
782
 
@@ -976,9 +811,9 @@ class Model(BaseModel, metaclass=ModelMetaclass):
976
811
  ┌──────────────┬──────────────────┬────────────┐
977
812
  │ name ┆ temperature_zone ┆ product_id │
978
813
  │ --- ┆ --- ┆ --- │
979
- │ str ┆ cat ┆ i64 │
814
+ │ str ┆ enum ┆ i64 │
980
815
  ╞══════════════╪══════════════════╪════════════╡
981
- │ dummy_string ┆ dry ┆ 0
816
+ │ dummy_string ┆ dry ┆ 1
982
817
  └──────────────┴──────────────────┴────────────┘
983
818
 
984
819
  >>> Product.examples({"name": ["product A", "product B"]})
@@ -986,11 +821,12 @@ class Model(BaseModel, metaclass=ModelMetaclass):
986
821
  ┌───────────┬──────────────────┬────────────┐
987
822
  │ name ┆ temperature_zone ┆ product_id │
988
823
  │ --- ┆ --- ┆ --- │
989
- │ str ┆ cat ┆ i64 │
824
+ │ str ┆ enum ┆ i64 │
990
825
  ╞═══════════╪══════════════════╪════════════╡
991
- │ product A ┆ dry ┆ 0
992
- │ product B ┆ dry ┆ 1
826
+ │ product A ┆ dry ┆ 1
827
+ │ product B ┆ dry ┆ 2
993
828
  └───────────┴──────────────────┴────────────┘
829
+
994
830
  """
995
831
  if data is None:
996
832
  # We should create an empty dataframe, but with the correct dtypes
@@ -1014,11 +850,13 @@ class Model(BaseModel, metaclass=ModelMetaclass):
1014
850
  if column_name not in kwargs:
1015
851
  if column_name in cls.unique_columns:
1016
852
  unique_series.append(
1017
- pl.first().cumcount().cast(dtype).alias(column_name)
853
+ pl.first().cum_count().cast(dtype).alias(column_name)
1018
854
  )
1019
855
  else:
1020
856
  example_value = cls.example_value(field=column_name)
1021
- series.append(pl.lit(example_value, dtype=dtype).alias(column_name))
857
+ series.append(
858
+ pl.Series(column_name, values=[example_value], dtype=dtype)
859
+ )
1022
860
  continue
1023
861
 
1024
862
  value = kwargs.get(column_name)
@@ -1030,7 +868,7 @@ class Model(BaseModel, metaclass=ModelMetaclass):
1030
868
  else:
1031
869
  series.append(pl.lit(value, dtype=dtype).alias(column_name))
1032
870
 
1033
- return DataFrame().with_columns(series).with_columns(unique_series)
871
+ return cls.DataFrame().with_columns(series).with_columns(unique_series)
1034
872
 
1035
873
  @classmethod
1036
874
  def join(
@@ -1038,8 +876,7 @@ class Model(BaseModel, metaclass=ModelMetaclass):
1038
876
  other: Type["Model"],
1039
877
  how: Literal["inner", "left", "outer", "asof", "cross", "semi", "anti"],
1040
878
  ) -> Type["Model"]:
1041
- """
1042
- Dynamically create a new model compatible with an SQL Join operation.
879
+ """Dynamically create a new model compatible with an SQL Join operation.
1043
880
 
1044
881
  For instance, ``ModelA.join(ModelB, how="left")`` will create a model containing
1045
882
  all the fields of ``ModelA`` and ``ModelB``, but where all fields of ``ModelB``
@@ -1078,6 +915,7 @@ class Model(BaseModel, metaclass=ModelMetaclass):
1078
915
 
1079
916
  >>> A.join(B, how="anti") is A
1080
917
  True
918
+
1081
919
  """
1082
920
  if how in {"semi", "anti"}:
1083
921
  return cls
@@ -1087,18 +925,13 @@ class Model(BaseModel, metaclass=ModelMetaclass):
1087
925
  (cls, {"outer"}),
1088
926
  (other, {"left", "outer", "asof"}),
1089
927
  ):
1090
- for field_name, field in model.__fields__.items():
1091
- field_type = field.type_
1092
- field_default = field.default
1093
- if how in nullable_methods and type(None) not in get_args(field.type_):
1094
- # This originally non-nullable field has become nullable
1095
- field_type = Optional[field_type]
1096
- elif field.required and field_default is None:
1097
- # We need to replace Pydantic's None default value with ... in order
1098
- # to make it clear that the field is still non-nullable and
1099
- # required.
1100
- field_default = ...
1101
- kwargs[field_name] = (field_type, field_default)
928
+ for field_name, field in model.model_fields.items():
929
+ make_nullable = how in nullable_methods and type(None) not in get_args(
930
+ field.annotation
931
+ )
932
+ kwargs[field_name] = cls._derive_field(
933
+ field, make_nullable=make_nullable
934
+ )
1102
935
 
1103
936
  return create_model(
1104
937
  f"{cls.__name__}{how.capitalize()}Join{other.__name__}",
@@ -1110,8 +943,7 @@ class Model(BaseModel, metaclass=ModelMetaclass):
1110
943
  def select(
1111
944
  cls: Type[ModelType], fields: Union[str, Iterable[str]]
1112
945
  ) -> Type["Model"]:
1113
- """
1114
- Create a new model consisting of only a subset of the model fields.
946
+ """Create a new model consisting of only a subset of the model fields.
1115
947
 
1116
948
  Args:
1117
949
  fields: A single field name as a string or a collection of strings.
@@ -1134,6 +966,7 @@ class Model(BaseModel, metaclass=ModelMetaclass):
1134
966
 
1135
967
  >>> sorted(MyModel.select(["b", "c"]).columns)
1136
968
  ['b', 'c']
969
+
1137
970
  """
1138
971
  if isinstance(fields, str):
1139
972
  fields = [fields]
@@ -1152,8 +985,7 @@ class Model(BaseModel, metaclass=ModelMetaclass):
1152
985
 
1153
986
  @classmethod
1154
987
  def drop(cls: Type[ModelType], name: Union[str, Iterable[str]]) -> Type["Model"]:
1155
- """
1156
- Return a new model where one or more fields are excluded.
988
+ """Return a new model where one or more fields are excluded.
1157
989
 
1158
990
  Args:
1159
991
  name: A single string field name, or a list of such field names,
@@ -1177,6 +1009,7 @@ class Model(BaseModel, metaclass=ModelMetaclass):
1177
1009
 
1178
1010
  >>> MyModel.drop(["b", "c"]).columns
1179
1011
  ['a']
1012
+
1180
1013
  """
1181
1014
  dropped_columns = {name} if isinstance(name, str) else set(name)
1182
1015
  mapping = {
@@ -1191,8 +1024,7 @@ class Model(BaseModel, metaclass=ModelMetaclass):
1191
1024
 
1192
1025
  @classmethod
1193
1026
  def prefix(cls: Type[ModelType], prefix: str) -> Type["Model"]:
1194
- """
1195
- Return a new model where all field names have been prefixed.
1027
+ """Return a new model where all field names have been prefixed.
1196
1028
 
1197
1029
  Args:
1198
1030
  prefix: String prefix to add to all field names.
@@ -1208,6 +1040,7 @@ class Model(BaseModel, metaclass=ModelMetaclass):
1208
1040
 
1209
1041
  >>> MyModel.prefix("x_").columns
1210
1042
  ['x_a', 'x_b']
1043
+
1211
1044
  """
1212
1045
  mapping = {f"{prefix}{field_name}": field_name for field_name in cls.columns}
1213
1046
  return cls._derive_model(
@@ -1217,8 +1050,7 @@ class Model(BaseModel, metaclass=ModelMetaclass):
1217
1050
 
1218
1051
  @classmethod
1219
1052
  def suffix(cls: Type[ModelType], suffix: str) -> Type["Model"]:
1220
- """
1221
- Return a new model where all field names have been suffixed.
1053
+ """Return a new model where all field names have been suffixed.
1222
1054
 
1223
1055
  Args:
1224
1056
  suffix: String suffix to add to all field names.
@@ -1235,6 +1067,7 @@ class Model(BaseModel, metaclass=ModelMetaclass):
1235
1067
 
1236
1068
  >>> MyModel.suffix("_x").columns
1237
1069
  ['a_x', 'b_x']
1070
+
1238
1071
  """
1239
1072
  mapping = {f"{field_name}{suffix}": field_name for field_name in cls.columns}
1240
1073
  return cls._derive_model(
@@ -1244,8 +1077,7 @@ class Model(BaseModel, metaclass=ModelMetaclass):
1244
1077
 
1245
1078
  @classmethod
1246
1079
  def rename(cls: Type[ModelType], mapping: Dict[str, str]) -> Type["Model"]:
1247
- """
1248
- Return a new model class where the specified fields have been renamed.
1080
+ """Return a new model class where the specified fields have been renamed.
1249
1081
 
1250
1082
  Args:
1251
1083
  mapping: A dictionary where the keys are the old field names
@@ -1265,6 +1097,7 @@ class Model(BaseModel, metaclass=ModelMetaclass):
1265
1097
 
1266
1098
  >>> MyModel.rename({"a": "A"}).columns
1267
1099
  ['b', 'A']
1100
+
1268
1101
  """
1269
1102
  non_existent_fields = set(mapping.keys()) - set(cls.columns)
1270
1103
  if non_existent_fields:
@@ -1287,8 +1120,7 @@ class Model(BaseModel, metaclass=ModelMetaclass):
1287
1120
  cls: Type[ModelType],
1288
1121
  **field_definitions: Any, # noqa: ANN401
1289
1122
  ) -> Type["Model"]:
1290
- """
1291
- Return a new model class where the given fields have been added.
1123
+ """Return a new model class where the given fields have been added.
1292
1124
 
1293
1125
  Args:
1294
1126
  **field_definitions: the keywords are of the form:
@@ -1310,6 +1142,7 @@ class Model(BaseModel, metaclass=ModelMetaclass):
1310
1142
  ...
1311
1143
  >>> MyModel.with_fields(b=(int, ...)).columns == ExpandedModel.columns
1312
1144
  True
1145
+
1313
1146
  """
1314
1147
  fields = {field_name: field_name for field_name in cls.columns}
1315
1148
  fields.update(field_definitions)
@@ -1319,49 +1152,16 @@ class Model(BaseModel, metaclass=ModelMetaclass):
1319
1152
  )
1320
1153
 
1321
1154
  @classmethod
1322
- def _schema_properties(cls) -> Dict[str, Dict[str, Any]]:
1323
- """
1324
- Return schema properties where definition references have been resolved.
1325
-
1326
- Returns:
1327
- Field information as a dictionary where the keys are field names and the
1328
- values are dictionaries containing metadata information about the field
1329
- itself.
1155
+ def _schema_properties(cls: Type[ModelType]) -> Mapping[str, Any]:
1156
+ return cls.model_schema["properties"]
1330
1157
 
1331
- Raises:
1332
- TypeError: if a field is annotated with an enum where the values are of
1333
- different types.
1334
- """
1335
- schema = cls.schema(ref_template="{model}")
1336
- required = schema.get("required", set())
1337
- fields = {}
1338
- for field_name, field_info in schema["properties"].items():
1339
- if "$ref" in field_info:
1340
- definition = schema["definitions"][field_info["$ref"]]
1341
- if "enum" in definition and "type" not in definition:
1342
- enum_types = set(type(value) for value in definition["enum"])
1343
- if len(enum_types) > 1:
1344
- raise TypeError(
1345
- "All enumerated values of enums used to annotate "
1346
- "Patito model fields must have the same type. "
1347
- "Encountered types: "
1348
- f"{sorted(map(lambda t: t.__name__, enum_types))}."
1349
- )
1350
- enum_type = enum_types.pop()
1351
- # TODO: Support time-delta, date, and date-time.
1352
- definition["type"] = {
1353
- str: "string",
1354
- int: "integer",
1355
- float: "number",
1356
- bool: "boolean",
1357
- type(None): "null",
1358
- }[enum_type]
1359
- fields[field_name] = definition
1360
- else:
1361
- fields[field_name] = field_info
1362
- fields[field_name]["required"] = field_name in required
1363
-
1364
- return fields
1158
+ @classmethod
1159
+ def _update_dfn(cls, annotation: Any, schema: Dict[str, Any]) -> None:
1160
+ try:
1161
+ if issubclass(annotation, Model) and annotation.__name__ != cls.__name__:
1162
+ schema["$defs"][annotation.__name__] = annotation.model_schema
1163
+ except TypeError:
1164
+ pass
1365
1165
 
1366
1166
  @classmethod
1367
1167
  def _derive_model(
@@ -1369,8 +1169,7 @@ class Model(BaseModel, metaclass=ModelMetaclass):
1369
1169
  model_name: str,
1370
1170
  field_mapping: Dict[str, Any],
1371
1171
  ) -> Type["Model"]:
1372
- """
1373
- Derive a new model with new field definitions.
1172
+ """Derive a new model with new field definitions.
1374
1173
 
1375
1174
  Args:
1376
1175
  model_name: Name of new model class.
@@ -1382,50 +1181,91 @@ class Model(BaseModel, metaclass=ModelMetaclass):
1382
1181
 
1383
1182
  Returns:
1384
1183
  A new model class derived from the model type of self.
1184
+
1385
1185
  """
1386
1186
  new_fields = {}
1387
1187
  for new_field_name, field_definition in field_mapping.items():
1388
1188
  if isinstance(field_definition, str):
1389
1189
  # A single string, interpreted as the name of a field on the existing
1390
1190
  # model.
1391
- old_field = cls.__fields__[field_definition]
1392
- field_type = old_field.type_
1393
- field_default = old_field.default
1394
- if old_field.required and field_default is None:
1395
- # The default None value needs to be replaced with ... in order to
1396
- # make the field required in the new model.
1397
- field_default = ...
1398
- new_fields[new_field_name] = (field_type, field_default)
1191
+ old_field = cls.model_fields[field_definition]
1192
+ new_fields[new_field_name] = cls._derive_field(old_field)
1399
1193
  else:
1400
1194
  # We have been given a (field_type, field_default) tuple defining the
1401
1195
  # new field directly.
1402
- new_fields[new_field_name] = field_definition
1196
+ field_type = field_definition[0]
1197
+ if field_definition[1] is None and type(None) not in get_args(
1198
+ field_type
1199
+ ):
1200
+ field_type = Optional[field_type]
1201
+ new_fields[new_field_name] = (field_type, field_definition[1])
1403
1202
  return create_model( # type: ignore
1404
1203
  __model_name=model_name,
1405
- __validators__={"__validators__": cls.__validators__},
1406
1204
  __base__=Model,
1407
1205
  **new_fields,
1408
1206
  )
1409
1207
 
1208
+ @staticmethod
1209
+ def _derive_field(
1210
+ field: fields.FieldInfo,
1211
+ make_nullable: bool = False,
1212
+ ) -> Tuple[Type | None, fields.FieldInfo]:
1213
+ field_type = field.annotation
1214
+ default = field.default
1215
+ extra_attrs = {
1216
+ x: getattr(field, x)
1217
+ for x in field._attributes_set
1218
+ if x in field.__slots__ and x not in ["annotation", "default"]
1219
+ }
1220
+ if make_nullable:
1221
+ if field_type is None:
1222
+ raise TypeError(
1223
+ "Cannot make field nullable if no type annotation is provided!"
1224
+ )
1225
+ else:
1226
+ # This originally non-nullable field has become nullable
1227
+ field_type = Optional[field_type]
1228
+ elif field.is_required() and default is None:
1229
+ # We need to replace Pydantic's None default value with ... in order
1230
+ # to make it clear that the field is still non-nullable and
1231
+ # required.
1232
+ default = ...
1233
+ field_new = fields.Field(default=default, **extra_attrs)
1234
+ field_new.metadata = field.metadata
1235
+ return field_type, field_new
1410
1236
 
1411
- class FieldDoc:
1412
- """
1413
- Annotate model field with additional type and validation information.
1414
1237
 
1415
- This class is built on ``pydantic.Field`` and you can find its full documentation
1416
- `here <https://pydantic-docs.helpmanual.io/usage/schema/#field-customization>`_.
1238
+ FIELD_KWARGS = getfullargspec(fields.Field)
1239
+
1240
+
1241
+ # Helper function for patito Field.
1242
+
1243
+
1244
+ def FieldCI(
1245
+ column_info: Type[ColumnInfo], *args: Any, **kwargs: Any
1246
+ ) -> Any: # annotate with Any to make the downstream type annotations happy
1247
+ """Annotate model field with additional type and validation information.
1248
+
1249
+ This class is built on ``pydantic.Field`` and you can find the list of parameters
1250
+ in the `API reference <https://docs.pydantic.dev/latest/api/fields/>`_.
1417
1251
  Patito adds additional parameters which are used when validating dataframes,
1418
- these are documented here.
1252
+ these are documented here along with the main parameters which can be used for
1253
+ validation. Pydantic's `usage documentation <https://docs.pydantic.dev/latest/concepts/fields/>`_
1254
+ can be read with the below examples.
1419
1255
 
1420
1256
  Args:
1257
+ column_info: (Type[ColumnInfo]): ColumnInfo object to pass args to.
1421
1258
  constraints (Union[polars.Expression, List[polars.Expression]): A single
1422
1259
  constraint or list of constraints, expressed as a polars expression objects.
1423
1260
  All rows must satisfy the given constraint. You can refer to the given column
1424
1261
  with ``pt.field``, which will automatically be replaced with
1425
1262
  ``polars.col(<field_name>)`` before evaluation.
1426
- unique (bool): All row values must be unique.
1263
+ derived_from (Union[str, polars.Expr]): used to mark fields that are meant to be
1264
+ derived from other fields. Users can specify a polars expression that will
1265
+ be called to derive the column value when `pt.DataFrame.derive` is called.
1427
1266
  dtype (polars.datatype.DataType): The given dataframe column must have the given
1428
1267
  polars dtype, for instance ``polars.UInt64`` or ``pl.Float32``.
1268
+ unique (bool): All row values must be unique.
1429
1269
  gt: All values must be greater than ``gt``.
1430
1270
  ge: All values must be greater than or equal to ``ge``.
1431
1271
  lt: All values must be less than ``lt``.
@@ -1436,10 +1276,12 @@ class FieldDoc:
1436
1276
  regex (str): UTF-8 string column must match regex pattern for all row values.
1437
1277
  min_length (int): Minimum length of all string values in a UTF-8 column.
1438
1278
  max_length (int): Maximum length of all string values in a UTF-8 column.
1279
+ args (Any): additional arguments to pass to pydantic's field.
1280
+ kwargs (Any): additional keyword arguments to pass to pydantic's field.
1439
1281
 
1440
1282
  Return:
1441
- FieldInfo: Object used to represent additional constraints put upon the given
1442
- field.
1283
+ `FieldInfo <https://docs.pydantic.dev/latest/api/fields/#pydantic.fields.FieldInfo>`_:
1284
+ Object used to represent additional constraints put upon the given field.
1443
1285
 
1444
1286
  Examples:
1445
1287
  >>> import patito as pt
@@ -1454,29 +1296,37 @@ class FieldDoc:
1454
1296
  ... # The product name should be from 3 to 128 characters long
1455
1297
  ... name: str = pt.Field(min_length=3, max_length=128)
1456
1298
  ...
1457
- ... # Represent colors in the form of upper cased hex colors
1458
- ... brand_color: str = pt.Field(regex=r"^\\#[0-9A-F]{6}$")
1459
1299
  ...
1460
1300
  >>> Product.DataFrame(
1461
1301
  ... {
1462
1302
  ... "product_id": [1, 1],
1463
1303
  ... "price": [400, 600],
1464
- ... "brand_color": ["#ab00ff", "AB00FF"],
1465
1304
  ... }
1466
1305
  ... ).validate()
1467
1306
  Traceback (most recent call last):
1468
- ...
1469
- patito.exceptions.ValidationError: 4 validation errors for Product
1307
+ patito.exceptions.DataFrameValidationError: 3 validation errors for Product
1470
1308
  name
1471
- Missing column (type=type_error.missingcolumns)
1309
+ Missing column (type=type_error.missingcolumns)
1472
1310
  product_id
1473
- 2 rows with duplicated values. (type=value_error.rowvalue)
1311
+ 2 rows with duplicated values. (type=value_error.rowvalue)
1474
1312
  price
1475
- Polars dtype Int64 does not match model field type. \
1476
- (type=type_error.columndtype)
1477
- brand_color
1478
- 2 rows with out of bound values. (type=value_error.rowvalue)
1313
+ Polars dtype Int64 does not match model field type. (type=type_error.columndtype)
1314
+
1479
1315
  """
1316
+ ci = column_info(**kwargs)
1317
+ for field in ci.model_fields_set:
1318
+ kwargs.pop(field)
1319
+ if kwargs.pop("modern_kwargs_only", True):
1320
+ for kwarg in kwargs:
1321
+ if kwarg not in FIELD_KWARGS.kwonlyargs and kwarg not in FIELD_KWARGS.args:
1322
+ raise ValueError(
1323
+ f"unexpected kwarg {kwarg}={kwargs[kwarg]}. Add modern_kwargs_only=False to ignore"
1324
+ )
1325
+ return fields.Field(
1326
+ *args,
1327
+ json_schema_extra={"column_info": ci},
1328
+ **kwargs,
1329
+ )
1480
1330
 
1481
1331
 
1482
- Field.__doc__ = FieldDoc.__doc__
1332
+ Field = partial(FieldCI, column_info=ColumnInfo)