validating 0.0.1__tar.gz → 0.0.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: validating
3
- Version: 0.0.1
3
+ Version: 0.0.3
4
4
  Summary: Provides a lightweight `attr()` descriptor factory that adds runtime validation to dataclass fields.
5
5
  Project-URL: Documentation, https://github.com/Chitaoji/validating/blob/main/README.md
6
6
  Project-URL: Repository, https://github.com/Chitaoji/validating/
@@ -181,7 +181,14 @@ class EvenNumber:
181
181
  This project falls under the BSD 3-Clause License.
182
182
 
183
183
  ## History
184
+ ### v0.0.3
185
+ * Improved support for `PEP 585` generics with quoted builtin annotations (for example `list["int"]`) to ensure consistent runtime type validation.
184
186
 
185
- ### v0.0.1
187
+ ### v0.0.2
188
+ * Fixed runtime checking for `TypedDict` and several special typing hints to avoid invalid `isinstance` paths.
189
+ * Added runtime validation support for `typing.Unpack`-style annotations.
190
+ * Improved forward-reference handling under `TYPE_CHECKING` imports, including deferred annotation resolution and safer fallback behavior when references are temporarily unresolved.
191
+ * Refined the `@validate` argument-checking path for better robustness and consistency across annotated call patterns.
186
192
 
187
- - Initial release.
193
+ ### v0.0.1
194
+ * Initial release.
@@ -165,7 +165,14 @@ class EvenNumber:
165
165
  This project falls under the BSD 3-Clause License.
166
166
 
167
167
  ## History
168
+ ### v0.0.3
169
+ * Improved support for `PEP 585` generics with quoted builtin annotations (for example `list["int"]`) to ensure consistent runtime type validation.
168
170
 
169
- ### v0.0.1
171
+ ### v0.0.2
172
+ * Fixed runtime checking for `TypedDict` and several special typing hints to avoid invalid `isinstance` paths.
173
+ * Added runtime validation support for `typing.Unpack`-style annotations.
174
+ * Improved forward-reference handling under `TYPE_CHECKING` imports, including deferred annotation resolution and safer fallback behavior when references are temporarily unresolved.
175
+ * Refined the `@validate` argument-checking path for better robustness and consistency across annotated call patterns.
170
176
 
171
- - Initial release.
177
+ ### v0.0.1
178
+ * Initial release.
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
4
4
 
5
5
  [project]
6
6
  name = "validating"
7
- version = "0.0.1"
7
+ version = "0.0.3"
8
8
  dependencies = []
9
9
  requires-python = ">=3.13"
10
10
  authors = [
@@ -25,4 +25,4 @@ classifiers = [
25
25
 
26
26
  [project.urls]
27
27
  Documentation = "https://github.com/Chitaoji/validating/blob/main/README.md"
28
- Repository = "https://github.com/Chitaoji/validating/"
28
+ Repository = "https://github.com/Chitaoji/validating/"
@@ -8,12 +8,25 @@ NOTE: this module is private. All functions and objects are available in the mai
8
8
  """
9
9
 
10
10
  import sys
11
+ import builtins
12
+ from ast import (
13
+ Attribute,
14
+ If,
15
+ Import,
16
+ ImportFrom,
17
+ Name,
18
+ NodeVisitor,
19
+ parse,
20
+ )
11
21
  from dataclasses import Field, field
12
22
  from functools import partialmethod
23
+ from importlib import import_module
24
+ from pathlib import Path
13
25
  from types import UnionType
14
26
  from typing import (
15
27
  Any,
16
28
  Callable,
29
+ ForwardRef,
17
30
  Literal,
18
31
  Optional,
19
32
  Union,
@@ -22,6 +35,22 @@ from typing import (
22
35
  get_type_hints,
23
36
  )
24
37
 
38
+ try: # pragma: no cover - Python >= 3.11
39
+ from typing import NotRequired, Required, Unpack
40
+ except ImportError: # pragma: no cover - Python < 3.11
41
+ from typing_extensions import NotRequired, Required, Unpack
42
+
43
+ try: # pragma: no cover - available on modern Python versions
44
+ from typing import is_typeddict
45
+ except ImportError: # pragma: no cover - compatibility fallback
46
+ def is_typeddict(type_hint: Any) -> bool:
47
+ return bool(
48
+ isinstance(type_hint, type)
49
+ and isinstance(getattr(type_hint, "__annotations__", None), dict)
50
+ and hasattr(type_hint, "__required_keys__")
51
+ and hasattr(type_hint, "__optional_keys__")
52
+ )
53
+
25
54
  __all__ = ["attr"]
26
55
 
27
56
 
@@ -536,6 +565,7 @@ def _resolve_field_type_hint(
536
565
  name: str,
537
566
  *,
538
567
  localns: dict[str, Any] | None = None,
568
+ include_type_checking_names: bool = False,
539
569
  ) -> type:
540
570
  if name not in cls.__annotations__:
541
571
  return Any
@@ -544,18 +574,28 @@ def _resolve_field_type_hint(
544
574
  if not isinstance(raw_type_hint, str):
545
575
  return raw_type_hint
546
576
 
577
+ merged_localns = localns
578
+ if include_type_checking_names:
579
+ merged_localns = _merge_localns(
580
+ localns,
581
+ _collect_type_checking_names(cls.__module__),
582
+ )
583
+
547
584
  try:
548
585
  resolved_hints = get_type_hints(
549
586
  cls,
550
587
  globalns=vars(sys.modules[cls.__module__]),
551
- localns=localns,
588
+ localns=merged_localns,
552
589
  include_extras=True,
553
590
  )
554
- except NameError as exc:
555
- raise ValidatorError(
556
- f"failed to resolve annotation for {cls.__name__}.{name}: {raw_type_hint!r}"
557
- ) from exc
558
591
  except Exception as exc: # pragma: no cover - exact exception depends on annotation
592
+ if not include_type_checking_names:
593
+ return _resolve_field_type_hint(
594
+ cls,
595
+ name,
596
+ localns=localns,
597
+ include_type_checking_names=True,
598
+ )
559
599
  raise ValidatorError(
560
600
  f"failed to resolve annotation for {cls.__name__}.{name}: {raw_type_hint!r}"
561
601
  ) from exc
@@ -595,6 +635,89 @@ def _collect_runtime_localns(
595
635
  return merged or None
596
636
 
597
637
 
638
+ def _merge_localns(*namespaces: dict[str, Any] | None) -> dict[str, Any] | None:
639
+ merged: dict[str, Any] = {}
640
+ for namespace in namespaces:
641
+ if namespace:
642
+ merged.update(namespace)
643
+ return merged or None
644
+
645
+
646
+ def _collect_type_checking_names(module_name: str) -> dict[str, Any]:
647
+ module = sys.modules.get(module_name)
648
+ if module is None:
649
+ return {}
650
+
651
+ module_file = getattr(module, "__file__", None)
652
+ if module_file is None:
653
+ return {}
654
+
655
+ try:
656
+ source = Path(module_file).read_text(encoding="utf-8")
657
+ except (OSError, UnicodeDecodeError):
658
+ return {}
659
+
660
+ package = getattr(module, "__package__", None)
661
+ return _TypeCheckingImportCollector(package).collect(source)
662
+
663
+
664
+ class _TypeCheckingImportCollector(NodeVisitor):
665
+ def __init__(self, package: str | None) -> None:
666
+ self.package = package
667
+ self.names: dict[str, Any] = {}
668
+
669
+ def collect(self, source: str) -> dict[str, Any]:
670
+ try:
671
+ tree = parse(source)
672
+ except SyntaxError:
673
+ return {}
674
+ self.visit(tree)
675
+ return self.names
676
+
677
+ def visit_If(self, node: If) -> None:
678
+ if self._is_type_checking_guard(node.test):
679
+ for stmt in node.body:
680
+ self._consume_type_checking_stmt(stmt)
681
+
682
+ self.generic_visit(node)
683
+
684
+ @staticmethod
685
+ def _is_type_checking_guard(test: Any) -> bool:
686
+ if isinstance(test, Name):
687
+ return test.id == "TYPE_CHECKING"
688
+ if isinstance(test, Attribute):
689
+ return isinstance(test.value, Name) and test.value.id == "typing" and test.attr == "TYPE_CHECKING"
690
+ return False
691
+
692
+ def _consume_type_checking_stmt(self, stmt: Any) -> None:
693
+ if isinstance(stmt, Import):
694
+ for alias in stmt.names:
695
+ try:
696
+ module = import_module(alias.name)
697
+ except Exception:
698
+ continue
699
+ self.names[alias.asname or alias.name.split(".")[0]] = module
700
+ return
701
+
702
+ if not isinstance(stmt, ImportFrom) or stmt.module is None:
703
+ return
704
+
705
+ target_module = "." * stmt.level + stmt.module
706
+ try:
707
+ imported = import_module(target_module, self.package)
708
+ except Exception:
709
+ return
710
+
711
+ for alias in stmt.names:
712
+ if alias.name == "*":
713
+ continue
714
+ try:
715
+ symbol = getattr(imported, alias.name)
716
+ except AttributeError:
717
+ continue
718
+ self.names[alias.asname or alias.name] = symbol
719
+
720
+
598
721
  def isoftype(
599
722
  value: object, type_hint: type, name: str, path: str = ""
600
723
  ) -> Optional[str]:
@@ -621,12 +744,36 @@ def isoftype(
621
744
  if type_hint is Any:
622
745
  return None
623
746
 
747
+ if isinstance(type_hint, str):
748
+ resolved_builtin = getattr(builtins, type_hint, None)
749
+ if isinstance(resolved_builtin, type):
750
+ return isoftype(value, resolved_builtin, name, path)
751
+ # Unresolved string annotations can appear in runtime generic aliases
752
+ # such as ``list["MyType"]``. Match ForwardRef behavior and accept.
753
+ return None
754
+
624
755
  origin = get_origin(type_hint)
625
756
  args = get_args(type_hint)
626
757
 
627
758
  if origin is None:
628
- if isinstance(value, type_hint):
759
+ if isinstance(type_hint, ForwardRef):
760
+ # Unresolved forward references may appear in dynamic annotations
761
+ # (e.g. ``Union[ForwardRef("MyTypedDict"), None]``). At runtime,
762
+ # accept these values instead of failing with a low-signal error.
629
763
  return None
764
+
765
+ newtype_super = getattr(type_hint, "__supertype__", None)
766
+ if newtype_super is not None:
767
+ return isoftype(value, newtype_super, name, path)
768
+
769
+ if is_typeddict(type_hint):
770
+ return _validate_typed_dict(value, type_hint, name, path)
771
+
772
+ try:
773
+ if isinstance(value, type_hint):
774
+ return None
775
+ except TypeError:
776
+ pass
630
777
  return _format_isoftype_error(
631
778
  path, f"{type_hint!r}, got {type(value)!r} instead"
632
779
  )
@@ -645,6 +792,10 @@ def isoftype(
645
792
  return None
646
793
  return _format_isoftype_error(path, f"one of {args!r}, got {value!r} instead")
647
794
 
795
+ if origin is Unpack:
796
+ (unpacked_type,) = args
797
+ return isoftype(value, unpacked_type, name, path)
798
+
648
799
  if origin is list:
649
800
  (elem_type,) = args
650
801
  if not isinstance(value, list):
@@ -713,4 +864,46 @@ def _format_isoftype_error(path: str, detail: str) -> str:
713
864
  return f"expected {detail}"
714
865
 
715
866
 
867
+ def _validate_typed_dict(
868
+ value: object,
869
+ type_hint: Any,
870
+ name: str,
871
+ path: str,
872
+ ) -> Optional[str]:
873
+ if not isinstance(value, dict):
874
+ return _format_isoftype_error(path, f"a dict, got {type(value)!r} instead")
875
+
876
+ annotations = getattr(type_hint, "__annotations__", {})
877
+ required_keys = set(getattr(type_hint, "__required_keys__", set()))
878
+ optional_keys = set(getattr(type_hint, "__optional_keys__", set()))
879
+ allowed_keys = required_keys | optional_keys | set(annotations)
880
+
881
+ missing = sorted(required_keys - set(value))
882
+ if missing:
883
+ keys = ", ".join(repr(k) for k in missing)
884
+ return _format_isoftype_error(path, f"missing required keys: {keys}")
885
+
886
+ extra = sorted(set(value) - allowed_keys)
887
+ if extra:
888
+ keys = ", ".join(repr(k) for k in extra)
889
+ return _format_isoftype_error(path, f"unexpected keys: {keys}")
890
+
891
+ for key, annotated in annotations.items():
892
+ if key not in value:
893
+ continue
894
+ key_type = _unwrap_required_marker(annotated)
895
+ key_error = isoftype(value[key], key_type, name, f"{name}[{key!r}]")
896
+ if key_error is not None:
897
+ return key_error
898
+
899
+ return None
900
+
901
+
902
+ def _unwrap_required_marker(type_hint: Any) -> Any:
903
+ origin = get_origin(type_hint)
904
+ if origin is Required or origin is NotRequired:
905
+ return get_args(type_hint)[0]
906
+ return type_hint
907
+
908
+
716
909
  class ValidatorError(RuntimeError): ...
@@ -13,12 +13,23 @@ from functools import wraps
13
13
  from inspect import Parameter, Signature, getsourcelines, signature
14
14
  from textwrap import dedent
15
15
  from traceback import TracebackException
16
- from typing import Any, Callable, get_type_hints
16
+ from typing import Any, Callable, get_args, get_origin, get_type_hints
17
17
 
18
- from .valid_attr import _collect_runtime_localns, _get_owner_localns, isoftype
18
+ from .valid_attr import (
19
+ _collect_runtime_localns,
20
+ _collect_type_checking_names,
21
+ _get_owner_localns,
22
+ _merge_localns,
23
+ isoftype,
24
+ )
19
25
 
20
26
  __all__ = ["validate"]
21
27
 
28
+ try: # pragma: no cover - Python >= 3.11
29
+ from typing import Unpack
30
+ except ImportError: # pragma: no cover - Python < 3.11
31
+ from typing_extensions import Unpack
32
+
22
33
  _VALIDATE_MARKER = "__validating_is_validate_wrapped__"
23
34
 
24
35
 
@@ -53,6 +64,7 @@ def validate[T](func: T) -> T:
53
64
  return func
54
65
 
55
66
  sig = signature(func)
67
+ parameters = sig.parameters
56
68
  owner_localns = _get_owner_localns()
57
69
  resolved_annotations, deferred_annotations = _resolve_signature_annotations(
58
70
  func,
@@ -69,7 +81,7 @@ def validate[T](func: T) -> T:
69
81
  deferred_annotations,
70
82
  initial_localns=owner_localns,
71
83
  )
72
- _validate_bound_arguments(func, sig, bound.arguments, resolved_annotations)
84
+ _validate_bound_arguments(func, parameters, bound.arguments, resolved_annotations)
73
85
  try:
74
86
  return func(*args, **kwargs)
75
87
  except AssertionError as exc:
@@ -219,7 +231,15 @@ def _resolve_signature_annotations(
219
231
  include_extras=True,
220
232
  )
221
233
  except Exception:
222
- type_hints = {}
234
+ try:
235
+ type_hints = get_type_hints(
236
+ func,
237
+ globalns=globalns,
238
+ localns=_merge_localns(localns, _collect_type_checking_names(func.__module__)),
239
+ include_extras=True,
240
+ )
241
+ except Exception:
242
+ type_hints = {}
223
243
 
224
244
  for name, param in sig.parameters.items():
225
245
  annotation = param.annotation
@@ -246,7 +266,10 @@ def _resolve_deferred_signature_annotations(
246
266
  if not deferred_annotations:
247
267
  return
248
268
 
249
- localns = _collect_runtime_localns(initial_localns)
269
+ localns = _merge_localns(
270
+ _collect_runtime_localns(initial_localns),
271
+ _collect_type_checking_names(func.__module__),
272
+ )
250
273
  globalns = vars(__import__(func.__module__, fromlist=["*"]))
251
274
  try:
252
275
  type_hints = get_type_hints(
@@ -276,12 +299,12 @@ def _resolve_deferred_signature_annotations(
276
299
 
277
300
  def _validate_bound_arguments(
278
301
  func: Callable[..., Any],
279
- sig: Signature,
302
+ parameters: dict[str, Parameter],
280
303
  arguments: dict[str, Any],
281
304
  resolved_annotations: dict[str, Any],
282
305
  ) -> None:
283
306
  for name, value in arguments.items():
284
- param = sig.parameters[name]
307
+ param = parameters[name]
285
308
  if name not in resolved_annotations:
286
309
  continue
287
310
  annotation = resolved_annotations[name]
@@ -295,13 +318,18 @@ def _validate_bound_arguments(
295
318
  path=f"{name}[{idx}]",
296
319
  )
297
320
  if mismatch_reason is not None:
298
- raise TypeError(
299
- f"invalid type for argument {name!r} of {func.__name__}: "
300
- + mismatch_reason
301
- )
321
+ _raise_argument_type_error(func, name, mismatch_reason)
302
322
  continue
303
323
 
304
324
  if param.kind is Parameter.VAR_KEYWORD:
325
+ if get_origin(annotation) is Unpack:
326
+ unpacked = get_args(annotation)
327
+ unpacked_annotation = unpacked[0] if unpacked else annotation
328
+ mismatch_reason = isoftype(value, unpacked_annotation, name, path=name)
329
+ if mismatch_reason is not None:
330
+ _raise_argument_type_error(func, name, mismatch_reason)
331
+ continue
332
+
305
333
  for key, item in value.items():
306
334
  mismatch_reason = isoftype(
307
335
  item,
@@ -310,15 +338,17 @@ def _validate_bound_arguments(
310
338
  path=f"{name}[{key!r}]",
311
339
  )
312
340
  if mismatch_reason is not None:
313
- raise TypeError(
314
- f"invalid type for argument {name!r} of {func.__name__}: "
315
- + mismatch_reason
316
- )
341
+ _raise_argument_type_error(func, name, mismatch_reason)
317
342
  continue
318
343
 
319
344
  mismatch_reason = isoftype(value, annotation, name)
320
345
  if mismatch_reason is not None:
321
- raise TypeError(
322
- f"invalid type for argument {name!r} of {func.__name__}: "
323
- + mismatch_reason
324
- )
346
+ _raise_argument_type_error(func, name, mismatch_reason)
347
+
348
+
349
+ def _raise_argument_type_error(
350
+ func: Callable[..., Any], name: str, mismatch_reason: str
351
+ ) -> None:
352
+ raise TypeError(
353
+ f"invalid type for argument {name!r} of {func.__name__}: " + mismatch_reason
354
+ )
@@ -1,6 +1,15 @@
1
1
  import unittest
2
+ import sys
3
+ from importlib.util import module_from_spec, spec_from_file_location
4
+ from pathlib import Path
5
+ from tempfile import TemporaryDirectory
2
6
  from dataclasses import dataclass
3
- from typing import Any, Literal
7
+ from typing import Any, ForwardRef, Literal, TypedDict
8
+
9
+ try:
10
+ from typing import Unpack
11
+ except ImportError:
12
+ from typing_extensions import Unpack
4
13
 
5
14
  from src.validating import (
6
15
  ValidatorError,
@@ -131,6 +140,17 @@ class TestAttrWithDataclasses(unittest.TestCase):
131
140
  with self.assertRaises(TypeError):
132
141
  u.age = "bad"
133
142
 
143
+ def test_type_validation_for_quoted_builtin_in_pep585_generic(self):
144
+ @validating_dataclass
145
+ class Payload:
146
+ values: list["int"] = []
147
+
148
+ item = Payload([1])
149
+ self.assertEqual(item.values, [1])
150
+
151
+ with self.assertRaises(TypeError):
152
+ Payload(["1"])
153
+
134
154
  def test_allowlist_and_denylist(self):
135
155
  @dataclass
136
156
  class Role:
@@ -297,6 +317,86 @@ class TestAttrWithDataclasses(unittest.TestCase):
297
317
  cfg = Config(retries=LaterType())
298
318
  self.assertIsInstance(cfg.retries, LaterType)
299
319
 
320
+ def test_type_checking_import_is_lazily_loaded_for_attr(self):
321
+ with TemporaryDirectory() as temp_dir:
322
+ root = Path(temp_dir)
323
+ (root / "lazy_models.py").write_text(
324
+ "class LaterType:\n"
325
+ " pass\n",
326
+ encoding="utf-8",
327
+ )
328
+ (root / "lazy_consumer.py").write_text(
329
+ "from dataclasses import dataclass\n"
330
+ "from typing import TYPE_CHECKING\n"
331
+ "from src.validating import attr\n"
332
+ "\n"
333
+ "if TYPE_CHECKING:\n"
334
+ " from lazy_models import LaterType\n"
335
+ "\n"
336
+ "@dataclass\n"
337
+ "class Config:\n"
338
+ " value: 'LaterType' = attr()\n",
339
+ encoding="utf-8",
340
+ )
341
+
342
+ self.assertNotIn("lazy_models", sys.modules)
343
+
344
+ consumer_spec = spec_from_file_location("lazy_consumer", root / "lazy_consumer.py")
345
+ assert consumer_spec is not None and consumer_spec.loader is not None
346
+ consumer_module = module_from_spec(consumer_spec)
347
+ sys.modules["lazy_consumer"] = consumer_module
348
+ consumer_spec.loader.exec_module(consumer_module)
349
+
350
+ self.assertNotIn("lazy_models", sys.modules)
351
+
352
+ models_spec = spec_from_file_location("lazy_models", root / "lazy_models.py")
353
+ assert models_spec is not None and models_spec.loader is not None
354
+ models_module = module_from_spec(models_spec)
355
+ sys.modules["lazy_models"] = models_module
356
+ models_spec.loader.exec_module(models_module)
357
+
358
+ cfg = consumer_module.Config(value=models_module.LaterType())
359
+ self.assertIsInstance(cfg.value, models_module.LaterType)
360
+
361
+ def test_forward_string_annotation_from_type_checking_import_is_resolved_for_attr(
362
+ self,
363
+ ):
364
+ with TemporaryDirectory() as temp_dir:
365
+ root = Path(temp_dir)
366
+ (root / "models.py").write_text(
367
+ "class LaterType:\n"
368
+ " pass\n",
369
+ encoding="utf-8",
370
+ )
371
+ (root / "consumer.py").write_text(
372
+ "from dataclasses import dataclass\n"
373
+ "from typing import TYPE_CHECKING\n"
374
+ "from src.validating import attr\n"
375
+ "\n"
376
+ "if TYPE_CHECKING:\n"
377
+ " from models import LaterType\n"
378
+ "\n"
379
+ "@dataclass\n"
380
+ "class Config:\n"
381
+ " value: 'LaterType' = attr()\n",
382
+ encoding="utf-8",
383
+ )
384
+
385
+ models_spec = spec_from_file_location("models", root / "models.py")
386
+ assert models_spec is not None and models_spec.loader is not None
387
+ models_module = module_from_spec(models_spec)
388
+ sys.modules["models"] = models_module
389
+ models_spec.loader.exec_module(models_module)
390
+
391
+ consumer_spec = spec_from_file_location("consumer", root / "consumer.py")
392
+ assert consumer_spec is not None and consumer_spec.loader is not None
393
+ consumer_module = module_from_spec(consumer_spec)
394
+ sys.modules["consumer"] = consumer_module
395
+ consumer_spec.loader.exec_module(consumer_module)
396
+
397
+ cfg = consumer_module.Config(value=models_module.LaterType())
398
+ self.assertIsInstance(cfg.value, models_module.LaterType)
399
+
300
400
  def test_unresolvable_string_annotation_raises_for_attr(self):
301
401
  @dataclass
302
402
  class Config:
@@ -497,6 +597,23 @@ class TestValidateFunctionDecorator(unittest.TestCase):
497
597
  with self.assertRaises(TypeError):
498
598
  collect(1, 2, key=3)
499
599
 
600
+ def test_validate_supports_unpack_typed_dict_for_kwargs(self):
601
+ class Query(TypedDict):
602
+ limit: int
603
+ cursor: str
604
+
605
+ @validate
606
+ def fetch(**kwargs: Unpack[Query]):
607
+ return kwargs
608
+
609
+ self.assertEqual(fetch(limit=1, cursor="next"), {"limit": 1, "cursor": "next"})
610
+ with self.assertRaises(TypeError):
611
+ fetch(limit="1", cursor="next")
612
+ with self.assertRaises(TypeError):
613
+ fetch(limit=1)
614
+ with self.assertRaises(TypeError):
615
+ fetch(limit=1, cursor="next", extra="x")
616
+
500
617
  def test_validate_resolves_string_annotations(self):
501
618
  @validate
502
619
  def add(a: "int", b: "int") -> int:
@@ -506,6 +623,17 @@ class TestValidateFunctionDecorator(unittest.TestCase):
506
623
  with self.assertRaises(TypeError):
507
624
  add("1", 2)
508
625
 
626
+ def test_validate_accepts_union_with_unresolved_forward_ref(self):
627
+ @validate
628
+ def set_figure(kwargs: ForwardRef("SubplotDict") | None):
629
+ return kwargs
630
+
631
+ self.assertEqual(
632
+ set_figure({"left": 0.1, "right": 0.9}),
633
+ {"left": 0.1, "right": 0.9},
634
+ )
635
+ self.assertIsNone(set_figure(None))
636
+
509
637
  def test_validate_resolves_forward_string_annotations_in_local_scope(self):
510
638
  @validate
511
639
  def build(value: "LaterType") -> "LaterType":
@@ -518,6 +646,86 @@ class TestValidateFunctionDecorator(unittest.TestCase):
518
646
  with self.assertRaises(TypeError):
519
647
  build(1)
520
648
 
649
+ def test_validate_type_checking_import_is_lazily_loaded(self):
650
+ with TemporaryDirectory() as temp_dir:
651
+ root = Path(temp_dir)
652
+ (root / "lazy_v_models.py").write_text(
653
+ "class LaterType:\n"
654
+ " pass\n",
655
+ encoding="utf-8",
656
+ )
657
+ (root / "lazy_v_consumer.py").write_text(
658
+ "from typing import TYPE_CHECKING\n"
659
+ "from src.validating import validate\n"
660
+ "\n"
661
+ "if TYPE_CHECKING:\n"
662
+ " from lazy_v_models import LaterType\n"
663
+ "\n"
664
+ "@validate\n"
665
+ "def build(value: 'LaterType') -> 'LaterType':\n"
666
+ " return value\n",
667
+ encoding="utf-8",
668
+ )
669
+
670
+ self.assertNotIn("lazy_v_models", sys.modules)
671
+
672
+ consumer_spec = spec_from_file_location("lazy_v_consumer", root / "lazy_v_consumer.py")
673
+ assert consumer_spec is not None and consumer_spec.loader is not None
674
+ consumer_module = module_from_spec(consumer_spec)
675
+ sys.modules["lazy_v_consumer"] = consumer_module
676
+ consumer_spec.loader.exec_module(consumer_module)
677
+
678
+ self.assertNotIn("lazy_v_models", sys.modules)
679
+
680
+ models_spec = spec_from_file_location("lazy_v_models", root / "lazy_v_models.py")
681
+ assert models_spec is not None and models_spec.loader is not None
682
+ models_module = module_from_spec(models_spec)
683
+ sys.modules["lazy_v_models"] = models_module
684
+ models_spec.loader.exec_module(models_module)
685
+
686
+ instance = models_module.LaterType()
687
+ self.assertIs(consumer_module.build(instance), instance)
688
+ with self.assertRaises(TypeError):
689
+ consumer_module.build(1)
690
+
691
+ def test_validate_resolves_forward_annotation_from_type_checking_import(self):
692
+ with TemporaryDirectory() as temp_dir:
693
+ root = Path(temp_dir)
694
+ (root / "models.py").write_text(
695
+ "class LaterType:\n"
696
+ " pass\n",
697
+ encoding="utf-8",
698
+ )
699
+ (root / "consumer.py").write_text(
700
+ "from typing import TYPE_CHECKING\n"
701
+ "from src.validating import validate\n"
702
+ "\n"
703
+ "if TYPE_CHECKING:\n"
704
+ " from models import LaterType\n"
705
+ "\n"
706
+ "@validate\n"
707
+ "def build(value: 'LaterType') -> 'LaterType':\n"
708
+ " return value\n",
709
+ encoding="utf-8",
710
+ )
711
+
712
+ models_spec = spec_from_file_location("models", root / "models.py")
713
+ assert models_spec is not None and models_spec.loader is not None
714
+ models_module = module_from_spec(models_spec)
715
+ sys.modules["models"] = models_module
716
+ models_spec.loader.exec_module(models_module)
717
+
718
+ consumer_spec = spec_from_file_location("consumer", root / "consumer.py")
719
+ assert consumer_spec is not None and consumer_spec.loader is not None
720
+ consumer_module = module_from_spec(consumer_spec)
721
+ sys.modules["consumer"] = consumer_module
722
+ consumer_spec.loader.exec_module(consumer_module)
723
+
724
+ instance = models_module.LaterType()
725
+ self.assertIs(consumer_module.build(instance), instance)
726
+ with self.assertRaises(TypeError):
727
+ consumer_module.build(1)
728
+
521
729
  def test_validate_raises_for_unresolvable_string_annotation(self):
522
730
  @validate
523
731
  def add(a: "MissingType") -> int:
File without changes
File without changes
File without changes