validating 0.0.1__tar.gz → 0.0.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: validating
3
- Version: 0.0.1
3
+ Version: 0.0.2
4
4
  Summary: Provides a lightweight `attr()` descriptor factory that adds runtime validation to dataclass fields.
5
5
  Project-URL: Documentation, https://github.com/Chitaoji/validating/blob/main/README.md
6
6
  Project-URL: Repository, https://github.com/Chitaoji/validating/
@@ -181,7 +181,11 @@ class EvenNumber:
181
181
  This project falls under the BSD 3-Clause License.
182
182
 
183
183
  ## History
184
+ ### v0.0.2
185
+ * Fixed runtime checking for `TypedDict` and several special typing hints to avoid invalid `isinstance` paths.
186
+ * Added runtime validation support for `typing.Unpack`-style annotations.
187
+ * Improved forward-reference handling under `TYPE_CHECKING` imports, including deferred annotation resolution and safer fallback behavior when references are temporarily unresolved.
188
+ * Refined the `@validate` argument-checking path for better robustness and consistency across annotated call patterns.
184
189
 
185
190
  ### v0.0.1
186
-
187
- - Initial release.
191
+ * Initial release.
@@ -165,7 +165,11 @@ class EvenNumber:
165
165
  This project falls under the BSD 3-Clause License.
166
166
 
167
167
  ## History
168
+ ### v0.0.2
169
+ * Fixed runtime checking for `TypedDict` and several special typing hints to avoid invalid `isinstance` paths.
170
+ * Added runtime validation support for `typing.Unpack`-style annotations.
171
+ * Improved forward-reference handling under `TYPE_CHECKING` imports, including deferred annotation resolution and safer fallback behavior when references are temporarily unresolved.
172
+ * Refined the `@validate` argument-checking path for better robustness and consistency across annotated call patterns.
168
173
 
169
174
  ### v0.0.1
170
-
171
- - Initial release.
175
+ * Initial release.
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
4
4
 
5
5
  [project]
6
6
  name = "validating"
7
- version = "0.0.1"
7
+ version = "0.0.2"
8
8
  dependencies = []
9
9
  requires-python = ">=3.13"
10
10
  authors = [
@@ -8,12 +8,24 @@ NOTE: this module is private. All functions and objects are available in the mai
8
8
  """
9
9
 
10
10
  import sys
11
+ from ast import (
12
+ Attribute,
13
+ If,
14
+ Import,
15
+ ImportFrom,
16
+ Name,
17
+ NodeVisitor,
18
+ parse,
19
+ )
11
20
  from dataclasses import Field, field
12
21
  from functools import partialmethod
22
+ from importlib import import_module
23
+ from pathlib import Path
13
24
  from types import UnionType
14
25
  from typing import (
15
26
  Any,
16
27
  Callable,
28
+ ForwardRef,
17
29
  Literal,
18
30
  Optional,
19
31
  Union,
@@ -22,6 +34,22 @@ from typing import (
22
34
  get_type_hints,
23
35
  )
24
36
 
37
+ try: # pragma: no cover - Python >= 3.11
38
+ from typing import NotRequired, Required, Unpack
39
+ except ImportError: # pragma: no cover - Python < 3.11
40
+ from typing_extensions import NotRequired, Required, Unpack
41
+
42
try:  # pragma: no cover - available on modern Python versions
    from typing import is_typeddict
except ImportError:  # pragma: no cover - compatibility fallback

    def is_typeddict(type_hint: Any) -> bool:
        """Best-effort ``TypedDict`` detection for Pythons lacking ``typing.is_typeddict``.

        *type_hint* counts as a TypedDict when it is a class whose
        ``__annotations__`` is a dict and which exposes both
        ``__required_keys__`` and ``__optional_keys__``.
        """
        if not isinstance(type_hint, type):
            return False
        if not isinstance(getattr(type_hint, "__annotations__", None), dict):
            return False
        return all(
            hasattr(type_hint, marker)
            for marker in ("__required_keys__", "__optional_keys__")
        )
52
+
25
53
  __all__ = ["attr"]
26
54
 
27
55
 
@@ -536,6 +564,7 @@ def _resolve_field_type_hint(
536
564
  name: str,
537
565
  *,
538
566
  localns: dict[str, Any] | None = None,
567
+ include_type_checking_names: bool = False,
539
568
  ) -> type:
540
569
  if name not in cls.__annotations__:
541
570
  return Any
@@ -544,18 +573,28 @@ def _resolve_field_type_hint(
544
573
  if not isinstance(raw_type_hint, str):
545
574
  return raw_type_hint
546
575
 
576
+ merged_localns = localns
577
+ if include_type_checking_names:
578
+ merged_localns = _merge_localns(
579
+ localns,
580
+ _collect_type_checking_names(cls.__module__),
581
+ )
582
+
547
583
  try:
548
584
  resolved_hints = get_type_hints(
549
585
  cls,
550
586
  globalns=vars(sys.modules[cls.__module__]),
551
- localns=localns,
587
+ localns=merged_localns,
552
588
  include_extras=True,
553
589
  )
554
- except NameError as exc:
555
- raise ValidatorError(
556
- f"failed to resolve annotation for {cls.__name__}.{name}: {raw_type_hint!r}"
557
- ) from exc
558
590
  except Exception as exc: # pragma: no cover - exact exception depends on annotation
591
+ if not include_type_checking_names:
592
+ return _resolve_field_type_hint(
593
+ cls,
594
+ name,
595
+ localns=localns,
596
+ include_type_checking_names=True,
597
+ )
559
598
  raise ValidatorError(
560
599
  f"failed to resolve annotation for {cls.__name__}.{name}: {raw_type_hint!r}"
561
600
  ) from exc
@@ -595,6 +634,89 @@ def _collect_runtime_localns(
595
634
  return merged or None
596
635
 
597
636
 
637
+ def _merge_localns(*namespaces: dict[str, Any] | None) -> dict[str, Any] | None:
638
+ merged: dict[str, Any] = {}
639
+ for namespace in namespaces:
640
+ if namespace:
641
+ merged.update(namespace)
642
+ return merged or None
643
+
644
+
645
+ def _collect_type_checking_names(module_name: str) -> dict[str, Any]:
646
+ module = sys.modules.get(module_name)
647
+ if module is None:
648
+ return {}
649
+
650
+ module_file = getattr(module, "__file__", None)
651
+ if module_file is None:
652
+ return {}
653
+
654
+ try:
655
+ source = Path(module_file).read_text(encoding="utf-8")
656
+ except (OSError, UnicodeDecodeError):
657
+ return {}
658
+
659
+ package = getattr(module, "__package__", None)
660
+ return _TypeCheckingImportCollector(package).collect(source)
661
+
662
+
663
+ class _TypeCheckingImportCollector(NodeVisitor):
664
+ def __init__(self, package: str | None) -> None:
665
+ self.package = package
666
+ self.names: dict[str, Any] = {}
667
+
668
+ def collect(self, source: str) -> dict[str, Any]:
669
+ try:
670
+ tree = parse(source)
671
+ except SyntaxError:
672
+ return {}
673
+ self.visit(tree)
674
+ return self.names
675
+
676
+ def visit_If(self, node: If) -> None:
677
+ if self._is_type_checking_guard(node.test):
678
+ for stmt in node.body:
679
+ self._consume_type_checking_stmt(stmt)
680
+
681
+ self.generic_visit(node)
682
+
683
+ @staticmethod
684
+ def _is_type_checking_guard(test: Any) -> bool:
685
+ if isinstance(test, Name):
686
+ return test.id == "TYPE_CHECKING"
687
+ if isinstance(test, Attribute):
688
+ return isinstance(test.value, Name) and test.value.id == "typing" and test.attr == "TYPE_CHECKING"
689
+ return False
690
+
691
+ def _consume_type_checking_stmt(self, stmt: Any) -> None:
692
+ if isinstance(stmt, Import):
693
+ for alias in stmt.names:
694
+ try:
695
+ module = import_module(alias.name)
696
+ except Exception:
697
+ continue
698
+ self.names[alias.asname or alias.name.split(".")[0]] = module
699
+ return
700
+
701
+ if not isinstance(stmt, ImportFrom) or stmt.module is None:
702
+ return
703
+
704
+ target_module = "." * stmt.level + stmt.module
705
+ try:
706
+ imported = import_module(target_module, self.package)
707
+ except Exception:
708
+ return
709
+
710
+ for alias in stmt.names:
711
+ if alias.name == "*":
712
+ continue
713
+ try:
714
+ symbol = getattr(imported, alias.name)
715
+ except AttributeError:
716
+ continue
717
+ self.names[alias.asname or alias.name] = symbol
718
+
719
+
598
720
  def isoftype(
599
721
  value: object, type_hint: type, name: str, path: str = ""
600
722
  ) -> Optional[str]:
@@ -625,8 +747,24 @@ def isoftype(
625
747
  args = get_args(type_hint)
626
748
 
627
749
  if origin is None:
628
- if isinstance(value, type_hint):
750
+ if isinstance(type_hint, ForwardRef):
751
+ # Unresolved forward references may appear in dynamic annotations
752
+ # (e.g. ``Union[ForwardRef("MyTypedDict"), None]``). At runtime,
753
+ # accept these values instead of failing with a low-signal error.
629
754
  return None
755
+
756
+ newtype_super = getattr(type_hint, "__supertype__", None)
757
+ if newtype_super is not None:
758
+ return isoftype(value, newtype_super, name, path)
759
+
760
+ if is_typeddict(type_hint):
761
+ return _validate_typed_dict(value, type_hint, name, path)
762
+
763
+ try:
764
+ if isinstance(value, type_hint):
765
+ return None
766
+ except TypeError:
767
+ pass
630
768
  return _format_isoftype_error(
631
769
  path, f"{type_hint!r}, got {type(value)!r} instead"
632
770
  )
@@ -645,6 +783,10 @@ def isoftype(
645
783
  return None
646
784
  return _format_isoftype_error(path, f"one of {args!r}, got {value!r} instead")
647
785
 
786
+ if origin is Unpack:
787
+ (unpacked_type,) = args
788
+ return isoftype(value, unpacked_type, name, path)
789
+
648
790
  if origin is list:
649
791
  (elem_type,) = args
650
792
  if not isinstance(value, list):
@@ -713,4 +855,46 @@ def _format_isoftype_error(path: str, detail: str) -> str:
713
855
  return f"expected {detail}"
714
856
 
715
857
 
858
+ def _validate_typed_dict(
859
+ value: object,
860
+ type_hint: Any,
861
+ name: str,
862
+ path: str,
863
+ ) -> Optional[str]:
864
+ if not isinstance(value, dict):
865
+ return _format_isoftype_error(path, f"a dict, got {type(value)!r} instead")
866
+
867
+ annotations = getattr(type_hint, "__annotations__", {})
868
+ required_keys = set(getattr(type_hint, "__required_keys__", set()))
869
+ optional_keys = set(getattr(type_hint, "__optional_keys__", set()))
870
+ allowed_keys = required_keys | optional_keys | set(annotations)
871
+
872
+ missing = sorted(required_keys - set(value))
873
+ if missing:
874
+ keys = ", ".join(repr(k) for k in missing)
875
+ return _format_isoftype_error(path, f"missing required keys: {keys}")
876
+
877
+ extra = sorted(set(value) - allowed_keys)
878
+ if extra:
879
+ keys = ", ".join(repr(k) for k in extra)
880
+ return _format_isoftype_error(path, f"unexpected keys: {keys}")
881
+
882
+ for key, annotated in annotations.items():
883
+ if key not in value:
884
+ continue
885
+ key_type = _unwrap_required_marker(annotated)
886
+ key_error = isoftype(value[key], key_type, name, f"{name}[{key!r}]")
887
+ if key_error is not None:
888
+ return key_error
889
+
890
+ return None
891
+
892
+
893
+ def _unwrap_required_marker(type_hint: Any) -> Any:
894
+ origin = get_origin(type_hint)
895
+ if origin is Required or origin is NotRequired:
896
+ return get_args(type_hint)[0]
897
+ return type_hint
898
+
899
+
716
900
  class ValidatorError(RuntimeError): ...
@@ -13,12 +13,23 @@ from functools import wraps
13
13
  from inspect import Parameter, Signature, getsourcelines, signature
14
14
  from textwrap import dedent
15
15
  from traceback import TracebackException
16
- from typing import Any, Callable, get_type_hints
16
+ from typing import Any, Callable, get_args, get_origin, get_type_hints
17
17
 
18
- from .valid_attr import _collect_runtime_localns, _get_owner_localns, isoftype
18
+ from .valid_attr import (
19
+ _collect_runtime_localns,
20
+ _collect_type_checking_names,
21
+ _get_owner_localns,
22
+ _merge_localns,
23
+ isoftype,
24
+ )
19
25
 
20
26
  __all__ = ["validate"]
21
27
 
28
+ try: # pragma: no cover - Python >= 3.11
29
+ from typing import Unpack
30
+ except ImportError: # pragma: no cover - Python < 3.11
31
+ from typing_extensions import Unpack
32
+
22
33
  _VALIDATE_MARKER = "__validating_is_validate_wrapped__"
23
34
 
24
35
 
@@ -53,6 +64,7 @@ def validate[T](func: T) -> T:
53
64
  return func
54
65
 
55
66
  sig = signature(func)
67
+ parameters = sig.parameters
56
68
  owner_localns = _get_owner_localns()
57
69
  resolved_annotations, deferred_annotations = _resolve_signature_annotations(
58
70
  func,
@@ -69,7 +81,7 @@ def validate[T](func: T) -> T:
69
81
  deferred_annotations,
70
82
  initial_localns=owner_localns,
71
83
  )
72
- _validate_bound_arguments(func, sig, bound.arguments, resolved_annotations)
84
+ _validate_bound_arguments(func, parameters, bound.arguments, resolved_annotations)
73
85
  try:
74
86
  return func(*args, **kwargs)
75
87
  except AssertionError as exc:
@@ -219,7 +231,15 @@ def _resolve_signature_annotations(
219
231
  include_extras=True,
220
232
  )
221
233
  except Exception:
222
- type_hints = {}
234
+ try:
235
+ type_hints = get_type_hints(
236
+ func,
237
+ globalns=globalns,
238
+ localns=_merge_localns(localns, _collect_type_checking_names(func.__module__)),
239
+ include_extras=True,
240
+ )
241
+ except Exception:
242
+ type_hints = {}
223
243
 
224
244
  for name, param in sig.parameters.items():
225
245
  annotation = param.annotation
@@ -246,7 +266,10 @@ def _resolve_deferred_signature_annotations(
246
266
  if not deferred_annotations:
247
267
  return
248
268
 
249
- localns = _collect_runtime_localns(initial_localns)
269
+ localns = _merge_localns(
270
+ _collect_runtime_localns(initial_localns),
271
+ _collect_type_checking_names(func.__module__),
272
+ )
250
273
  globalns = vars(__import__(func.__module__, fromlist=["*"]))
251
274
  try:
252
275
  type_hints = get_type_hints(
@@ -276,12 +299,12 @@ def _resolve_deferred_signature_annotations(
276
299
 
277
300
  def _validate_bound_arguments(
278
301
  func: Callable[..., Any],
279
- sig: Signature,
302
+ parameters: dict[str, Parameter],
280
303
  arguments: dict[str, Any],
281
304
  resolved_annotations: dict[str, Any],
282
305
  ) -> None:
283
306
  for name, value in arguments.items():
284
- param = sig.parameters[name]
307
+ param = parameters[name]
285
308
  if name not in resolved_annotations:
286
309
  continue
287
310
  annotation = resolved_annotations[name]
@@ -295,13 +318,18 @@ def _validate_bound_arguments(
295
318
  path=f"{name}[{idx}]",
296
319
  )
297
320
  if mismatch_reason is not None:
298
- raise TypeError(
299
- f"invalid type for argument {name!r} of {func.__name__}: "
300
- + mismatch_reason
301
- )
321
+ _raise_argument_type_error(func, name, mismatch_reason)
302
322
  continue
303
323
 
304
324
  if param.kind is Parameter.VAR_KEYWORD:
325
+ if get_origin(annotation) is Unpack:
326
+ unpacked = get_args(annotation)
327
+ unpacked_annotation = unpacked[0] if unpacked else annotation
328
+ mismatch_reason = isoftype(value, unpacked_annotation, name, path=name)
329
+ if mismatch_reason is not None:
330
+ _raise_argument_type_error(func, name, mismatch_reason)
331
+ continue
332
+
305
333
  for key, item in value.items():
306
334
  mismatch_reason = isoftype(
307
335
  item,
@@ -310,15 +338,17 @@ def _validate_bound_arguments(
310
338
  path=f"{name}[{key!r}]",
311
339
  )
312
340
  if mismatch_reason is not None:
313
- raise TypeError(
314
- f"invalid type for argument {name!r} of {func.__name__}: "
315
- + mismatch_reason
316
- )
341
+ _raise_argument_type_error(func, name, mismatch_reason)
317
342
  continue
318
343
 
319
344
  mismatch_reason = isoftype(value, annotation, name)
320
345
  if mismatch_reason is not None:
321
- raise TypeError(
322
- f"invalid type for argument {name!r} of {func.__name__}: "
323
- + mismatch_reason
324
- )
346
+ _raise_argument_type_error(func, name, mismatch_reason)
347
+
348
+
349
+ def _raise_argument_type_error(
350
+ func: Callable[..., Any], name: str, mismatch_reason: str
351
+ ) -> None:
352
+ raise TypeError(
353
+ f"invalid type for argument {name!r} of {func.__name__}: " + mismatch_reason
354
+ )
@@ -1,6 +1,15 @@
1
1
  import unittest
2
+ import sys
3
+ from importlib.util import module_from_spec, spec_from_file_location
4
+ from pathlib import Path
5
+ from tempfile import TemporaryDirectory
2
6
  from dataclasses import dataclass
3
- from typing import Any, Literal
7
+ from typing import Any, ForwardRef, Literal, TypedDict
8
+
9
+ try:
10
+ from typing import Unpack
11
+ except ImportError:
12
+ from typing_extensions import Unpack
4
13
 
5
14
  from src.validating import (
6
15
  ValidatorError,
@@ -297,6 +306,86 @@ class TestAttrWithDataclasses(unittest.TestCase):
297
306
  cfg = Config(retries=LaterType())
298
307
  self.assertIsInstance(cfg.retries, LaterType)
299
308
 
309
def test_type_checking_import_is_lazily_loaded_for_attr(self):
    """A TYPE_CHECKING-only import missing at class definition resolves later."""

    def load(module_name, path):
        # Execute *path* as *module_name* and register it in sys.modules; the
        # cleanup removes it afterwards so it cannot leak into other tests
        # (or make the "not yet imported" assertions fail on a rerun).
        spec = spec_from_file_location(module_name, path)
        assert spec is not None and spec.loader is not None
        module = module_from_spec(spec)
        sys.modules[module_name] = module
        self.addCleanup(sys.modules.pop, module_name, None)
        spec.loader.exec_module(module)
        return module

    with TemporaryDirectory() as temp_dir:
        root = Path(temp_dir)
        (root / "lazy_models.py").write_text(
            "class LaterType:\n"
            " pass\n",
            encoding="utf-8",
        )
        (root / "lazy_consumer.py").write_text(
            "from dataclasses import dataclass\n"
            "from typing import TYPE_CHECKING\n"
            "from src.validating import attr\n"
            "\n"
            "if TYPE_CHECKING:\n"
            " from lazy_models import LaterType\n"
            "\n"
            "@dataclass\n"
            "class Config:\n"
            " value: 'LaterType' = attr()\n",
            encoding="utf-8",
        )

        self.assertNotIn("lazy_models", sys.modules)

        consumer_module = load("lazy_consumer", root / "lazy_consumer.py")

        # Defining Config must not have imported the TYPE_CHECKING-only module.
        self.assertNotIn("lazy_models", sys.modules)

        models_module = load("lazy_models", root / "lazy_models.py")

        cfg = consumer_module.Config(value=models_module.LaterType())
        self.assertIsInstance(cfg.value, models_module.LaterType)
349
+
350
def test_forward_string_annotation_from_type_checking_import_is_resolved_for_attr(
    self,
):
    """A string annotation naming a TYPE_CHECKING import resolves for attr()."""

    def load(module_name, path):
        # Execute *path* as *module_name* and register it in sys.modules; the
        # cleanup removes it afterwards so the generic names "models" and
        # "consumer" do not leak into (or collide with) other tests.
        spec = spec_from_file_location(module_name, path)
        assert spec is not None and spec.loader is not None
        module = module_from_spec(spec)
        sys.modules[module_name] = module
        self.addCleanup(sys.modules.pop, module_name, None)
        spec.loader.exec_module(module)
        return module

    with TemporaryDirectory() as temp_dir:
        root = Path(temp_dir)
        (root / "models.py").write_text(
            "class LaterType:\n"
            " pass\n",
            encoding="utf-8",
        )
        (root / "consumer.py").write_text(
            "from dataclasses import dataclass\n"
            "from typing import TYPE_CHECKING\n"
            "from src.validating import attr\n"
            "\n"
            "if TYPE_CHECKING:\n"
            " from models import LaterType\n"
            "\n"
            "@dataclass\n"
            "class Config:\n"
            " value: 'LaterType' = attr()\n",
            encoding="utf-8",
        )

        # The models module is registered before the consumer is executed.
        models_module = load("models", root / "models.py")
        consumer_module = load("consumer", root / "consumer.py")

        cfg = consumer_module.Config(value=models_module.LaterType())
        self.assertIsInstance(cfg.value, models_module.LaterType)
388
+
300
389
  def test_unresolvable_string_annotation_raises_for_attr(self):
301
390
  @dataclass
302
391
  class Config:
@@ -497,6 +586,23 @@ class TestValidateFunctionDecorator(unittest.TestCase):
497
586
  with self.assertRaises(TypeError):
498
587
  collect(1, 2, key=3)
499
588
 
589
def test_validate_supports_unpack_typed_dict_for_kwargs(self):
    class Query(TypedDict):
        limit: int
        cursor: str

    @validate
    def fetch(**kwargs: Unpack[Query]):
        return kwargs

    # Well-formed keyword arguments pass straight through.
    self.assertEqual(fetch(limit=1, cursor="next"), {"limit": 1, "cursor": "next"})

    # Wrong value type, missing required key and unknown key are all rejected.
    rejected_calls = (
        {"limit": "1", "cursor": "next"},
        {"limit": 1},
        {"limit": 1, "cursor": "next", "extra": "x"},
    )
    for call_kwargs in rejected_calls:
        with self.subTest(kwargs=call_kwargs), self.assertRaises(TypeError):
            fetch(**call_kwargs)
605
+
500
606
  def test_validate_resolves_string_annotations(self):
501
607
  @validate
502
608
  def add(a: "int", b: "int") -> int:
@@ -506,6 +612,17 @@ class TestValidateFunctionDecorator(unittest.TestCase):
506
612
  with self.assertRaises(TypeError):
507
613
  add("1", 2)
508
614
 
615
def test_validate_accepts_union_with_unresolved_forward_ref(self):
    # "SubplotDict" is never defined, so the ForwardRef stays unresolved and
    # @validate is expected to accept the value instead of failing.
    @validate
    def set_figure(kwargs: ForwardRef("SubplotDict") | None):
        return kwargs

    payload = {"left": 0.1, "right": 0.9}
    self.assertEqual(set_figure(payload), {"left": 0.1, "right": 0.9})
    self.assertIsNone(set_figure(None))
625
+
509
626
  def test_validate_resolves_forward_string_annotations_in_local_scope(self):
510
627
  @validate
511
628
  def build(value: "LaterType") -> "LaterType":
@@ -518,6 +635,86 @@ class TestValidateFunctionDecorator(unittest.TestCase):
518
635
  with self.assertRaises(TypeError):
519
636
  build(1)
520
637
 
638
def test_validate_type_checking_import_is_lazily_loaded(self):
    """A TYPE_CHECKING-only import missing at decoration time resolves later."""

    def load(module_name, path):
        # Execute *path* as *module_name* and register it in sys.modules; the
        # cleanup removes it afterwards so it cannot leak into other tests
        # (or make the "not yet imported" assertions fail on a rerun).
        spec = spec_from_file_location(module_name, path)
        assert spec is not None and spec.loader is not None
        module = module_from_spec(spec)
        sys.modules[module_name] = module
        self.addCleanup(sys.modules.pop, module_name, None)
        spec.loader.exec_module(module)
        return module

    with TemporaryDirectory() as temp_dir:
        root = Path(temp_dir)
        (root / "lazy_v_models.py").write_text(
            "class LaterType:\n"
            " pass\n",
            encoding="utf-8",
        )
        (root / "lazy_v_consumer.py").write_text(
            "from typing import TYPE_CHECKING\n"
            "from src.validating import validate\n"
            "\n"
            "if TYPE_CHECKING:\n"
            " from lazy_v_models import LaterType\n"
            "\n"
            "@validate\n"
            "def build(value: 'LaterType') -> 'LaterType':\n"
            " return value\n",
            encoding="utf-8",
        )

        self.assertNotIn("lazy_v_models", sys.modules)

        consumer_module = load("lazy_v_consumer", root / "lazy_v_consumer.py")

        # Decorating build must not have imported the TYPE_CHECKING-only module.
        self.assertNotIn("lazy_v_models", sys.modules)

        models_module = load("lazy_v_models", root / "lazy_v_models.py")

        instance = models_module.LaterType()
        self.assertIs(consumer_module.build(instance), instance)
        with self.assertRaises(TypeError):
            consumer_module.build(1)
679
+
680
def test_validate_resolves_forward_annotation_from_type_checking_import(self):
    """A string annotation naming a TYPE_CHECKING import resolves for @validate."""

    def load(module_name, path):
        # Execute *path* as *module_name* and register it in sys.modules; the
        # cleanup removes it afterwards so the generic names "models" and
        # "consumer" do not leak into (or collide with) other tests.
        spec = spec_from_file_location(module_name, path)
        assert spec is not None and spec.loader is not None
        module = module_from_spec(spec)
        sys.modules[module_name] = module
        self.addCleanup(sys.modules.pop, module_name, None)
        spec.loader.exec_module(module)
        return module

    with TemporaryDirectory() as temp_dir:
        root = Path(temp_dir)
        (root / "models.py").write_text(
            "class LaterType:\n"
            " pass\n",
            encoding="utf-8",
        )
        (root / "consumer.py").write_text(
            "from typing import TYPE_CHECKING\n"
            "from src.validating import validate\n"
            "\n"
            "if TYPE_CHECKING:\n"
            " from models import LaterType\n"
            "\n"
            "@validate\n"
            "def build(value: 'LaterType') -> 'LaterType':\n"
            " return value\n",
            encoding="utf-8",
        )

        # The models module is registered before the consumer is executed.
        models_module = load("models", root / "models.py")
        consumer_module = load("consumer", root / "consumer.py")

        instance = models_module.LaterType()
        self.assertIs(consumer_module.build(instance), instance)
        with self.assertRaises(TypeError):
            consumer_module.build(1)
717
+
521
718
  def test_validate_raises_for_unresolvable_string_annotation(self):
522
719
  @validate
523
720
  def add(a: "MissingType") -> int:
File without changes
File without changes
File without changes