amati 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,438 @@
1
+ """Generic factories to add repetitive validators to Pydantic models."""
2
+
3
+ from collections.abc import Iterable
4
+ from numbers import Number
5
+ from typing import Any, Optional, Sequence
6
+
7
+ from pydantic import model_validator
8
+ from pydantic._internal._decorators import (
9
+ ModelValidatorDecoratorInfo,
10
+ PydanticDescriptorProxy,
11
+ )
12
+
13
+ from amati.logging import Log, LogMixin
14
+ from amati.validators.generic import GenericObject
15
+
16
+
17
class UnknownValue:
    """
    Singleton sentinel standing for "some value is present, whatever it is".

    Every call to ``UnknownValue()`` returns the same object, so it can be
    compared by identity. It prints as ``UNKNOWN``.
    """

    _instance = None

    def __new__(cls):
        existing = cls._instance
        if existing is None:
            existing = super().__new__(cls)
            cls._instance = existing
        return existing

    def __repr__(self):  # pragma: no cover
        return "UNKNOWN"

    # str() and repr() render identically.
    __str__ = __repr__  # pragma: no cover


# The shared sentinel instance used by the validator factories.
UNKNOWN = UnknownValue()
37
+
38
+
39
def is_truthy_with_numeric_zero(value: Any) -> bool:
    """Report whether ``value`` counts as present, treating numeric zero as present.

    Standard Python truthiness applies, with one exception: every instance of
    ``numbers.Number`` is considered truthy, so ``0``, ``0.0`` and ``0j`` are
    treated as present values. Because ``bool`` is a subclass of ``int``,
    ``False`` is a Number too and is therefore also reported as truthy.

    Args:
        value: The variable to test for truthiness. Can be of any type.

    Returns:
        True if the variable is truthy according to the custom rules, False otherwise.

    Example:
        >>> is_truthy_with_numeric_zero(0)
        True
        >>> is_truthy_with_numeric_zero(1)
        True
        >>> is_truthy_with_numeric_zero(0.0)
        True
        >>> is_truthy_with_numeric_zero([])
        False
        >>> is_truthy_with_numeric_zero("Hello")
        True
        >>> is_truthy_with_numeric_zero(None)
        False
    """
    # Nonzero numbers are truthy anyway, so short-circuiting on any Number only
    # changes the outcome for values equal to zero. Everything else falls back
    # to plain bool().
    return isinstance(value, Number) or bool(value)
72
+
73
+
74
+ def _get_candidates(
75
+ self: GenericObject, fields: Optional[Sequence[str]]
76
+ ) -> dict[str, Any]:
77
+ """
78
+ Helper function to filter down the list of fields of a model to examine.
79
+ """
80
+
81
+ model_fields: dict[str, Any] = self.model_dump()
82
+
83
+ options: Sequence[str] = fields or list(model_fields.keys())
84
+
85
+ return {
86
+ name: value
87
+ for name, value in model_fields.items()
88
+ if not name.startswith("_") and name in options
89
+ }
90
+
91
+
92
def at_least_one_of(
    fields: Optional[Sequence[str]] = None,
) -> PydanticDescriptorProxy[ModelValidatorDecoratorInfo]:
    """Factory that adds validation to ensure at least one public field is non-empty.

    The returned Pydantic "after" model validator inspects the model's public
    fields (names not starting with an underscore), or only those named in
    ``fields``. If none of them holds a value, it logs a ``ValueError``-typed
    entry through ``LogMixin.log``, using the model's ``_reference``. The
    validator never raises and always returns the model instance.

    A value counts as present according to ``is_truthy_with_numeric_zero``:
    normal truthiness, except that any number (including 0) is present.

    Args:
        fields: Optional sequence of field names to check. If provided, only these
            fields will be validated. If not provided, all public fields will be
            checked.

    Returns:
        The validator that ensures at least one public field is non-empty.

    Example:
        >>> from amati import Reference
        >>> LogMixin.logs = []
        >>>
        >>> class User(GenericObject):
        ...     name: str = ""
        ...     email: str = None
        ...     _at_least_one_of = at_least_one_of()
        ...     _reference: Reference = Reference(title="test")
        ...
        >>> user = User()
        >>> assert len(LogMixin.logs) == 1
        >>> LogMixin.logs = []

        >>> class User(GenericObject):
        ...     name: str = ""
        ...     email: str = None
        ...     age: int = None
        ...     _at_least_one_of = at_least_one_of(fields=["name", "email"])
        ...     _reference: Reference = Reference(title="test")
        ...
        >>>
        >>> user = User(name="John")  # Works fine
        >>> assert not LogMixin.logs
        >>> user = User()
        >>> assert len(LogMixin.logs) == 1
        >>> user = User(age=30)
        >>> assert len(LogMixin.logs) == 2


    Note:
        Candidates are taken from ``model_dump()``; names starting with '_'
        are skipped. If none of the candidate names exist, nothing is logged.
        NOTE(review): ``model_dump()`` may also contain computed fields;
        confirm whether those should count towards this check.
    """

    # Create the validator function with proper binding
    @model_validator(mode="after")
    def validate_at_least_one(self: GenericObject) -> Any:
        """Log an error unless at least one candidate field has a value."""

        # Early return if no fields exist
        if not (candidates := _get_candidates(self, fields)):
            return self

        # Satisfied as soon as any candidate holds a value.
        if any(is_truthy_with_numeric_zero(value) for value in candidates.values()):
            return self

        public_fields = ", ".join(candidates)

        msg = f"{public_fields} do not have values, expected at least one."
        LogMixin.log(
            Log(
                message=msg,
                type=ValueError,
                reference=self._reference,  # pylint: disable=protected-access # type: ignore
            )
        )

        return self

    return validate_at_least_one
172
+
173
+
174
def only_one_of(
    fields: Optional[Sequence[str]] = None,
) -> PydanticDescriptorProxy[ModelValidatorDecoratorInfo]:
    """Factory that adds validation to ensure exactly one public field is non-empty.

    The returned Pydantic "after" model validator inspects all public fields
    (names not starting with an underscore), or only those named in ``fields``.
    It logs a ``ValueError``-typed entry through ``LogMixin.log`` when more than
    one of them, or none of them, holds a value. The validator never raises and
    always returns the model instance.

    A value counts as present according to ``is_truthy_with_numeric_zero``:
    normal truthiness, except that any number (including 0) is present.

    Args:
        fields: Optional sequence of field names to check. If provided, only these
            fields will be validated. If not provided, all public fields will be
            checked.

    Returns:
        The validator that ensures exactly one public field is non-empty.

    Example:
        >>> from amati import Reference
        >>> LogMixin.logs = []
        >>>
        >>> class User(GenericObject):
        ...     email: str = ""
        ...     name: str = ""
        ...     _only_one_of = only_one_of()
        ...     _reference: Reference = Reference(title="test")
        ...
        >>> user = User(email="test@example.com")  # Works fine
        >>> user = User(name="123-456-7890")  # Works fine
        >>> assert not LogMixin.logs
        >>> user = User(email="a@b.com", name="123")
        >>> assert LogMixin.logs
        >>> LogMixin.logs = []

        >>> class User(GenericObject):
        ...     name: str = ""
        ...     email: str = ""
        ...     age: int = None
        ...     _only_one_of = only_one_of(["name", "email"])
        ...     _reference: Reference = Reference(title="test")
        ...
        >>> user = User(name="Bob")  # Works fine
        >>> user = User(email="test@example.com")  # Works fine
        >>> user = User(name="Bob", age=30)  # Works fine
        >>> assert not LogMixin.logs
        >>> user = User(name="Bob", email="a@b.com")
        >>> assert len(LogMixin.logs) == 1
        >>> user = User(age=30)
        >>> assert len(LogMixin.logs) == 2

    Note:
        Candidates are taken from ``model_dump()``; names starting with '_'
        are skipped. If none of the candidate names exist, nothing is logged.
    """

    @model_validator(mode="after")
    def validate_only_one(self: GenericObject) -> Any:
        """Log an error unless exactly one candidate field has a value."""

        # Early return if no fields exist
        if not (candidates := _get_candidates(self, fields)):
            return self

        # Names of candidate fields that hold a value.
        truthy: list[str] = [
            name
            for name, value in candidates.items()
            if is_truthy_with_numeric_zero(value)
        ]

        if len(truthy) == 1:
            return self

        if truthy:
            msg = f"Expected at most one field to have a value, {', '.join(truthy)} did"
        else:
            # No field was set; report which fields were examined instead of
            # producing an empty list in the message.
            msg = f"{', '.join(candidates)} do not have values, expected exactly one."

        LogMixin.log(
            Log(
                message=msg,
                type=ValueError,
                reference=self._reference,  # pylint: disable=protected-access # type: ignore
            )
        )

        return self

    return validate_only_one
258
+
259
+
260
def all_of(
    fields: Optional[Sequence[str]] = None,
) -> PydanticDescriptorProxy[ModelValidatorDecoratorInfo]:
    """Factory that adds validation to ensure every public field is non-empty.

    The returned Pydantic "after" model validator inspects all public fields
    (names not starting with an underscore), or only those named in ``fields``.
    It logs a single ``ValueError``-typed entry through ``LogMixin.log`` listing
    every field that does not hold a value. The validator never raises and
    always returns the model instance.

    A value counts as present according to ``is_truthy_with_numeric_zero``:
    normal truthiness, except that any number (including 0) is present.

    Args:
        fields: Optional sequence of field names to check. If provided, only these
            fields will be validated. If not provided, all public fields will be
            checked.

    Returns:
        The validator that ensures all checked fields are non-empty.

    Example:
        >>> from amati import Reference
        >>> LogMixin.logs = []
        >>>
        >>> class User(GenericObject):
        ...     email: str = ""
        ...     name: str = ""
        ...     _all_of = all_of()
        ...     _reference: Reference = Reference(title="test")
        ...
        >>> user = User(email="a@b.com", name="123")  # Works fine
        >>> assert not LogMixin.logs
        >>> user = User(email="test@example.com")
        >>> assert len(LogMixin.logs) == 1
        >>> user = User(name="123-456-7890")
        >>> assert len(LogMixin.logs) == 2

        >>> class User(GenericObject):
        ...     name: str = ""
        ...     email: str = ""
        ...     age: int = None
        ...     _all_of = all_of(["name", "email"])
        ...     _reference: Reference = Reference(title="test")
        ...
        >>> LogMixin.logs = []
        >>> user = User(name="Bob", email="a@b.com")  # Works fine
        >>> assert not LogMixin.logs
        >>> user = User(name="Bob")
        >>> assert len(LogMixin.logs) == 1
        >>> user = User(email="test@example.com")
        >>> assert len(LogMixin.logs) == 2
        >>> user = User(age=30)
        >>> assert len(LogMixin.logs) == 3
        >>> user = User(name="Bob", age=30)
        >>> assert len(LogMixin.logs) == 4

    Note:
        Candidates are taken from ``model_dump()``; names starting with '_'
        are skipped. If none of the candidate names exist, nothing is logged.
    """

    @model_validator(mode="after")
    def validate_all_of(self: GenericObject) -> Any:
        """Log an error listing every candidate field that has no value."""

        # Early return if no fields exist
        if not (candidates := _get_candidates(self, fields)):
            return self

        # Names of candidate fields that do not hold a value.
        falsy: list[str] = [
            name
            for name, value in candidates.items()
            if not is_truthy_with_numeric_zero(value)
        ]

        if falsy:
            msg = f"Expected all fields to have a value, {', '.join(falsy)} did not"

            LogMixin.log(
                Log(
                    message=msg,
                    type=ValueError,
                    reference=self._reference,  # pylint: disable=protected-access # type: ignore
                )
            )

        return self

    return validate_all_of
347
+
348
+
349
def if_then(
    conditions: dict[str, Any] | None = None,
    consequences: dict[str, Any | UnknownValue] | None = None,
) -> PydanticDescriptorProxy[ModelValidatorDecoratorInfo]:
    """Factory that adds validation to ensure if-then relationships between fields.

    The returned Pydantic "after" model validator checks whether all field
    conditions are met. If they are, it verifies that other fields have the
    expected values, logging one ``ValueError``-typed entry through
    ``LogMixin.log`` for each consequence that does not hold.

    Matching rules:
        * A condition value of ``UNKNOWN`` matches any non-None field value;
          otherwise the field must equal the condition value. Condition keys
          that are not in ``model_dump()`` are ignored.
        * A consequence value of ``UNKNOWN`` requires the field to hold a value
          (per ``is_truthy_with_numeric_zero``).
        * A non-string iterable consequence value (e.g. a list) requires the
          field to be one of its members.
        * Any other consequence value, including strings and bytes, must equal
          the field value.

    Args:
        conditions: Dictionary mapping field names to their required values that trigger
            the validation. All conditions must be met for the consequences to be
            checked.
        consequences: Dictionary mapping field names to their required values that must
            be true when the conditions are met.

    Returns:
        A validator that ensures the if-then relationship between fields is maintained.

    Raises:
        ValueError: At validation time, if ``conditions`` or ``consequences`` is
            empty or None. It is raised inside the model validator, so Pydantic
            reports it as a validation error.

    Example:
        >>> from amati import Reference
        >>> LogMixin.logs = []
        >>>
        >>> class User(GenericObject):
        ...     role: str = ""
        ...     can_edit: bool = False
        ...     _if_admin = if_then(
        ...         conditions={"role": "admin"},
        ...         consequences={"can_edit": True}
        ...     )
        ...     _reference: Reference = Reference(title="test")
        ...
        >>> user = User(role="admin", can_edit=True)  # Works fine
        >>> assert not LogMixin.logs
        >>> user = User(role="admin", can_edit=False)  # Fails validation
        >>> assert len(LogMixin.logs) == 1
        >>> user = User(role="user", can_edit=False)  # Works fine
        >>> assert len(LogMixin.logs) == 1
    """

    @model_validator(mode="after")
    def validate_if_then(self: GenericObject) -> GenericObject:
        """Log an error for every unmet consequence once all conditions hold."""

        # Checked here rather than in the factory, so a misconfigured validator
        # surfaces when a model using it is validated.
        if not conditions or not consequences:
            raise ValueError(
                "A condition and a consequence must be "
                f"present to validate {self.__class__.__name__}"
            )

        model_fields: dict[str, Any] = self.model_dump()

        candidates = {k: v for k, v in model_fields.items() if k in conditions}

        for k, v in candidates.items():
            # Unfulfilled condition: the consequences do not apply.
            if conditions[k] not in (v, UNKNOWN):
                return self

            # None and UNKNOWN are opposites: UNKNOWN requires a value.
            if v is None and conditions[k] is UNKNOWN:
                return self

        for field, value in consequences.items():
            actual = model_fields.get(field)

            # Strings and bytes are compared as whole values. `in` on them is a
            # substring test (so "" would match anything), and it raises
            # TypeError when the left operand is not a string (e.g. None).
            is_collection = isinstance(value, Iterable) and not isinstance(
                value, (str, bytes, bytearray)
            )

            if is_collection:
                try:
                    if actual in value:
                        continue
                except TypeError:
                    # e.g. an unhashable dumped value (list/dict) tested against
                    # a set or dict: not a member, so fall through and report.
                    pass

            if value is UNKNOWN and is_truthy_with_numeric_zero(actual):
                continue

            if value == actual:
                continue

            qualifier = "in " if is_collection else ""
            LogMixin.log(
                Log(
                    message=f"Expected {field} to be {qualifier}"
                    f"{value} found {actual}",
                    type=ValueError,
                    reference=self._reference,  # pylint: disable=protected-access # type: ignore
                )
            )

        return self

    return validate_if_then
amati/references.py ADDED
@@ -0,0 +1,33 @@
1
+ """
2
+ Represents a reference, declared here to not put in __init__.
3
+ """
4
+
5
+ from dataclasses import dataclass
6
+ from typing import Optional, Sequence
7
+
8
+
9
class AmatiReferenceException(Exception):
    """Raised when a Reference is created without a title, section or url."""

    message: str = "Cannot construct empty references"

    def __init__(self, *args: object) -> None:
        # Fall back to the class-level message, so str(exc) is informative even
        # for a bare `raise AmatiReferenceException`. Explicit arguments are
        # passed through unchanged.
        super().__init__(*(args or (self.message,)))
11
+
12
+
13
@dataclass
class Reference:
    """
    Pointer to a piece of referenced content.

    At least one attribute must be non-empty; otherwise construction raises
    AmatiReferenceException.

    Attributes:
        title : Title of the referenced content
        section : Section of the referenced content
        url : URL where the referenced content can be found
    """

    title: Optional[str] = None
    section: Optional[str] = None
    url: Optional[str] = None

    def __post_init__(self):
        # Reject a reference that points at nothing.
        if any((self.title, self.section, self.url)):
            return
        raise AmatiReferenceException
30
+
31
+
32
# Any sequence of Reference objects.
type ReferenceArray = Sequence[Reference]
# Either a single Reference or a sequence of them.
type References = Reference | ReferenceArray
File without changes
@@ -0,0 +1,133 @@
1
+ """
2
+ A generic object to add extra functionality to pydantic.BaseModel.
3
+
4
+ Should be used as the base class for all classes in the project.
5
+ """
6
+
7
+ import re
8
+ from typing import (
9
+ Any,
10
+ Callable,
11
+ ClassVar,
12
+ Optional,
13
+ Pattern,
14
+ Type,
15
+ TypeVar,
16
+ Union,
17
+ cast,
18
+ )
19
+
20
+ from pydantic import BaseModel, ConfigDict, PrivateAttr
21
+ from pydantic_core._pydantic_core import PydanticUndefined
22
+
23
+ from amati import Reference
24
+ from amati.logging import Log, LogMixin
25
+
26
+
27
class GenericObject(LogMixin, BaseModel):
    """
    A generic model that provides extra functionality on top of
    pydantic.BaseModel: input keys that will not become fields are logged
    rather than silently dropped.
    """

    _reference: ClassVar[Reference] = PrivateAttr()
    # Regex that extra field names must match. It is set by allow_extra_fields,
    # which stores a plain string; None means any extra field is accepted.
    _extra_field_pattern: Optional[Union[str, Pattern[str]]] = PrivateAttr()

    def __init__(self, **data: Any) -> None:
        """
        Build the model and log any input keys the model will not keep.

        Unless the model's config has ``extra="allow"``, each key in ``data``
        that is neither a dumped field name nor a field alias is logged as a
        ``ValueError``-typed entry.
        """

        super().__init__(**data)

        if self.model_config.get("extra") == "allow":
            return

        if not data:
            return

        # Build the set of accepted names once, instead of re-dumping the
        # model for every input key.
        known: set[str] = set(self.model_dump()) | set(self.get_field_aliases())

        # If extra fields aren't allowed log those that aren't going to be added
        # to the model.
        for field in data:
            if field in known:
                continue
            message = f"{field} is not a valid field for {self.__repr_name__()}."
            self.log(
                Log(
                    message=message,
                    type=ValueError,
                )
            )

    def model_post_init(self, __context: Any) -> None:
        """
        Log extra fields whose names do not match ``_extra_field_pattern``.

        Does nothing when there are no extra fields, when no pattern was
        configured for the class, or when the configured pattern is None.
        """
        if not self.model_extra:
            return

        # The private attribute was never given a default: no pattern applies.
        if self.__private_attributes__["_extra_field_pattern"] == PrivateAttr(
            PydanticUndefined
        ):
            return

        # Any extra fields are allowed
        if self._extra_field_pattern is None:
            return

        pattern: Pattern[str] = re.compile(self._extra_field_pattern)

        # Walk model_extra in insertion order, so log entries come out in a
        # deterministic order. A set would order them by string hash.
        for field in self.model_extra:
            if pattern.match(field):
                continue
            message = f"{field} is not a valid field for {self.__repr_name__()}."
            LogMixin.log(
                Log(
                    message=message,
                    type=ValueError,
                )
            )

    def get_field_aliases(self) -> list[str]:
        """
        Gets a list of aliases for confirming whether extra
        fields are allowed.

        Returns:
            A list of the non-empty field aliases declared on the class.
        """

        return [
            field_info.alias
            for field_info in self.__class__.model_fields.values()
            if field_info.alias
        ]
103
+
104
+
105
T = TypeVar("T", bound=GenericObject)


def allow_extra_fields(pattern: Optional[str] = None) -> Callable[[Type[T]], Type[T]]:
    """
    A decorator that modifies a Pydantic BaseModel to allow extra fields and optionally
    sets a pattern for those extra fields.

    Args:
        pattern: Optional pattern string for extra fields. If not provided all extra
            fields will be allowed.

    Returns:
        A decorator that returns a subclass of the decorated class. The subclass
        has ``ConfigDict(extra="allow")`` and stores ``pattern`` as its
        ``_extra_field_pattern``. It keeps the original class's name,
        docstring, module and qualified name.
    """

    def decorator(cls: Type[T]) -> Type[T]:
        """
        Return a subclass of ``cls`` that allows extra fields.
        """
        namespace: dict[str, Union[ConfigDict, Optional[str]]] = {
            "model_config": ConfigDict(extra="allow"),
            "_extra_field_pattern": pattern,
            # Without these, the three-argument type() call gives the new class
            # __doc__ = None, plus a __module__ and __qualname__ derived from
            # where type() is called rather than from the decorated class.
            "__module__": cls.__module__,
            "__qualname__": cls.__qualname__,
            "__doc__": cls.__doc__,
        }
        # Create a new class with the updated configuration
        return cast(Type[T], type(cls.__name__, (cls,), namespace))

    return decorator
+ return decorator