json2vec 0.4.7__tar.gz → 0.4.9__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (79)
  1. {json2vec-0.4.7/src/json2vec.egg-info → json2vec-0.4.9}/PKG-INFO +8 -5
  2. {json2vec-0.4.7 → json2vec-0.4.9}/README.md +1 -1
  3. {json2vec-0.4.7 → json2vec-0.4.9}/pyproject.toml +12 -5
  4. {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/__init__.py +13 -14
  5. {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/architecture/checkpoint.py +42 -0
  6. {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/architecture/contracts.py +5 -4
  7. {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/architecture/mutations.py +137 -14
  8. {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/architecture/root.py +62 -133
  9. {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/architecture/runtime.py +12 -3
  10. {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/data/datasets/base.py +8 -0
  11. {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/data/datasets/custom.py +11 -4
  12. {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/data/datasets/polars.py +11 -4
  13. {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/data/datasets/streaming.py +11 -4
  14. {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/data/iterables.py +114 -8
  15. {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/data/processing.py +90 -3
  16. json2vec-0.4.9/src/json2vec/helpers/__init__.py +8 -0
  17. json2vec-0.4.9/src/json2vec/helpers/inference.py +632 -0
  18. json2vec-0.4.9/src/json2vec/helpers/optimizers.py +78 -0
  19. json2vec-0.4.9/src/json2vec/helpers/trainer.py +0 -0
  20. {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/inference/__init__.py +5 -11
  21. json2vec-0.4.9/src/json2vec/inference/deployment.py +691 -0
  22. json2vec-0.4.9/src/json2vec/structs/__init__.py +0 -0
  23. {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/structs/enums.py +0 -1
  24. {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/structs/experiment.py +70 -25
  25. json2vec-0.4.9/src/json2vec/structs/structure.py +228 -0
  26. {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/structs/tree.py +147 -2
  27. {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/tensorfields/base.py +254 -45
  28. {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/tensorfields/extensions/category.py +83 -84
  29. {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/tensorfields/extensions/dateparts.py +32 -29
  30. {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/tensorfields/extensions/entity.py +33 -28
  31. {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/tensorfields/extensions/number.py +77 -74
  32. {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/tensorfields/extensions/set.py +61 -90
  33. {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/tensorfields/extensions/text.py +30 -20
  34. {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/tensorfields/extensions/vector.py +32 -26
  35. json2vec-0.4.9/src/json2vec/tensorfields/shared/__init__.py +80 -0
  36. {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/tensorfields/shared/counter.py +3 -1
  37. json2vec-0.4.9/src/json2vec/tensorfields/shared/vocabulary.py +440 -0
  38. {json2vec-0.4.7 → json2vec-0.4.9/src/json2vec.egg-info}/PKG-INFO +8 -5
  39. {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec.egg-info/SOURCES.txt +8 -2
  40. {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec.egg-info/requires.txt +6 -3
  41. {json2vec-0.4.7 → json2vec-0.4.9}/tests/test_callbacks.py +6 -30
  42. json2vec-0.4.9/tests/test_optimizers.py +78 -0
  43. {json2vec-0.4.7 → json2vec-0.4.9}/tests/test_public_api.py +1 -3
  44. json2vec-0.4.9/tests/test_schema_inference.py +327 -0
  45. json2vec-0.4.7/src/json2vec/architecture/plot.py +0 -562
  46. json2vec-0.4.7/src/json2vec/inference/deployment.py +0 -422
  47. json2vec-0.4.7/src/json2vec/structs/structure.py +0 -59
  48. json2vec-0.4.7/src/json2vec/tensorfields/shared/__init__.py +0 -12
  49. json2vec-0.4.7/src/json2vec/tensorfields/shared/vocabulary.py +0 -283
  50. {json2vec-0.4.7 → json2vec-0.4.9}/LICENSE +0 -0
  51. {json2vec-0.4.7 → json2vec-0.4.9}/NOTICE +0 -0
  52. {json2vec-0.4.7 → json2vec-0.4.9}/setup.cfg +0 -0
  53. {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/architecture/__init__.py +0 -0
  54. {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/architecture/attention.py +0 -0
  55. {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/architecture/encoder.py +0 -0
  56. {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/architecture/graph.py +0 -0
  57. {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/architecture/node.py +0 -0
  58. {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/architecture/pool.py +0 -0
  59. {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/architecture/rotary.py +0 -0
  60. {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/data/__init__.py +0 -0
  61. {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/data/datasets/__init__.py +0 -0
  62. {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/distributed.py +0 -0
  63. /json2vec-0.4.7/src/json2vec/structs/__init__.py → /json2vec-0.4.9/src/json2vec/helpers/hyperparameters.py +0 -0
  64. {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/inference/callback.py +0 -0
  65. {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/logging/__init__.py +0 -0
  66. {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/logging/config.py +0 -0
  67. {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/logging/epoch.py +0 -0
  68. {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/logging/throughput.py +0 -0
  69. {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/preprocessors/__init__.py +0 -0
  70. {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/preprocessors/base.py +0 -0
  71. {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/preprocessors/extensions/__init__.py +0 -0
  72. {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/preprocessors/spec.py +0 -0
  73. {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/structs/packages.py +0 -0
  74. {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/structs/selectors.py +0 -0
  75. {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/tensorfields/__init__.py +0 -0
  76. {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/tensorfields/extensions/__init__.py +0 -0
  77. {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/tensorfields/spec.py +0 -0
  78. {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec.egg-info/dependency_links.txt +0 -0
  79. {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: json2vec
3
- Version: 0.4.7
3
+ Version: 0.4.9
4
4
  Summary: Schema-first PyTorch models for hierarchical / nested / sequence data structures
5
5
  License-Expression: Apache-2.0
6
6
  Requires-Python: >=3.12
@@ -14,7 +14,6 @@ Requires-Dist: pydantic>=2.11.7
14
14
  Requires-Dist: jmespath>=1.0.1
15
15
  Requires-Dist: loguru>=0.7.3
16
16
  Requires-Dist: anytree>=2.13.0
17
- Requires-Dist: ordered-set>=4.1.0
18
17
  Requires-Dist: pyarrow>=21.0.0
19
18
  Requires-Dist: polars>=1.35.2
20
19
  Requires-Dist: numpy>=2.2.6
@@ -22,16 +21,20 @@ Requires-Dist: lightning>=2.6.4
22
21
  Requires-Dist: tensordict>=0.10.0
23
22
  Requires-Dist: torch>=2.7.1
24
23
  Provides-Extra: serving
25
- Requires-Dist: litserve>=0.2.13; extra == "serving"
24
+ Requires-Dist: fastapi>=0.124.0; extra == "serving"
25
+ Requires-Dist: orjson>=3.10.0; extra == "serving"
26
26
  Requires-Dist: pydantic-settings>=2.10.1; extra == "serving"
27
+ Requires-Dist: uvicorn>=0.38.0; extra == "serving"
27
28
  Provides-Extra: text
28
29
  Requires-Dist: transformers>=4.55.0; extra == "text"
29
30
  Provides-Extra: docs
30
- Requires-Dist: litserve>=0.2.13; extra == "docs"
31
+ Requires-Dist: fastapi>=0.124.0; extra == "docs"
31
32
  Requires-Dist: mkdocs-material>=9.6; extra == "docs"
32
33
  Requires-Dist: mkdocs-jupyter>=0.26.3; extra == "docs"
33
34
  Requires-Dist: mkdocstrings[python]>=0.27; extra == "docs"
35
+ Requires-Dist: orjson>=3.10.0; extra == "docs"
34
36
  Requires-Dist: pydantic-settings>=2.10.1; extra == "docs"
37
+ Requires-Dist: uvicorn>=0.38.0; extra == "docs"
35
38
  Dynamic: license-file
36
39
 
37
40
  <h1 align="center"><code>json2vec</code></h1>
@@ -314,7 +317,7 @@ uv sync --extra docs
314
317
  ```
315
318
 
316
319
  The `text` extra installs Hugging Face `transformers`. The `serving` extra
317
- installs LitServe-backed deployment dependencies. The `docs` extra installs the
320
+ installs FastAPI-backed deployment dependencies. The `docs` extra installs the
318
321
  MkDocs toolchain.
319
322
 
320
323
  ## Documentation Map
@@ -278,7 +278,7 @@ uv sync --extra docs
278
278
  ```
279
279
 
280
280
  The `text` extra installs Hugging Face `transformers`. The `serving` extra
281
- installs LitServe-backed deployment dependencies. The `docs` extra installs the
281
+ installs FastAPI-backed deployment dependencies. The `docs` extra installs the
282
282
  MkDocs toolchain.
283
283
 
284
284
  ## Documentation Map
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "json2vec"
3
- version = "0.4.7"
3
+ version = "0.4.9"
4
4
  description = "Schema-first PyTorch models for hierarchical / nested / sequence data structures"
5
5
  readme = "README.md"
6
6
  license = "Apache-2.0"
@@ -13,7 +13,6 @@ dependencies = [
13
13
  "jmespath>=1.0.1",
14
14
  "loguru>=0.7.3",
15
15
  "anytree>=2.13.0",
16
- "ordered-set>=4.1.0",
17
16
  "pyarrow>=21.0.0",
18
17
  "polars>=1.35.2",
19
18
  "numpy>=2.2.6",
@@ -24,30 +23,37 @@ dependencies = [
24
23
 
25
24
  [project.optional-dependencies]
26
25
  serving = [
27
- "litserve>=0.2.13",
26
+ "fastapi>=0.124.0",
27
+ "orjson>=3.10.0",
28
28
  "pydantic-settings>=2.10.1",
29
+ "uvicorn>=0.38.0",
29
30
  ]
30
31
  text = [
31
32
  "transformers>=4.55.0",
32
33
  ]
33
34
  docs = [
34
- "litserve>=0.2.13",
35
+ "fastapi>=0.124.0",
35
36
  "mkdocs-material>=9.6",
36
37
  "mkdocs-jupyter>=0.26.3",
37
38
  "mkdocstrings[python]>=0.27",
39
+ "orjson>=3.10.0",
38
40
  "pydantic-settings>=2.10.1",
41
+ "uvicorn>=0.38.0",
39
42
  ]
40
43
 
41
44
  [dependency-groups]
42
45
  dev = [
43
46
  "ruff>=0.12.12",
44
47
  "pytest>=8.4.1",
48
+ "pytest-xdist>=3.8.0",
45
49
  "ipython>=9.9.0",
46
50
  "ipykernel>=6.29.5",
47
51
  "nbclient>=0.10.2",
48
52
  "nbformat>=5.10.4",
49
- "litserve>=0.2.13",
53
+ "fastapi>=0.124.0",
54
+ "orjson>=3.10.0",
50
55
  "pydantic-settings>=2.10.1",
56
+ "uvicorn>=0.38.0",
51
57
  "ty>=0.0.1a20",
52
58
  "pre-commit>=4.3.0",
53
59
  ]
@@ -66,6 +72,7 @@ include = ["json2vec*"]
66
72
  [tool.pytest.ini_options]
67
73
  testpaths = ["tests"]
68
74
  python_files = ["test_*.py"]
75
+ addopts = ["-n", "auto"]
69
76
 
70
77
  [tool.ruff]
71
78
  line-length = 120
@@ -8,15 +8,16 @@ mutation predicates, and the `@preprocess` decorator.
8
8
 
9
9
  from typing import TYPE_CHECKING, Any
10
10
 
11
+ from json2vec import helpers as helpers
12
+ from json2vec.architecture.checkpoint import RollbackCheckpoint
13
+ from json2vec.architecture.mutations import MutationLockCallback, RuntimePlacementCallback
11
14
  from json2vec.architecture.root import (
12
15
  Model,
13
- MutationLockCallback,
14
16
  OptimizerConfig,
15
- RollbackCheckpoint,
16
- RuntimePlacementCallback,
17
17
  SchedulerConfig,
18
18
  )
19
19
  from json2vec.data.datasets import CustomDataModule, PolarsDataModule, StreamingDataModule
20
+ from json2vec.data.processing import MASK_LITERAL, MaskLiteral
20
21
  from json2vec.inference.callback import Postprocessor, Writer
21
22
  from json2vec.preprocessors import PREPROCESSORS, Preprocessor, PreprocessorMode, preprocess
22
23
  from json2vec.structs.enums import (
@@ -38,7 +39,7 @@ from json2vec.structs.experiment import (
38
39
  predicate,
39
40
  where,
40
41
  )
41
- from json2vec.structs.structure import Array
42
+ from json2vec.structs.structure import Array, Mask
42
43
  from json2vec.structs.tree import Address, Leaf
43
44
  from json2vec.tensorfields import TENSORFIELDS, DecoderBase, EmbedderBase, Plugin, RequestBase, TensorFieldBase
44
45
  from json2vec.tensorfields.extensions.category import Request as Category
@@ -52,23 +53,19 @@ from json2vec.tensorfields.shared.vocabulary import VocabularySyncCallback
52
53
 
53
54
  if TYPE_CHECKING:
54
55
  from json2vec.inference.deployment import (
55
- API,
56
56
  Accelerator,
57
- BatchItem,
58
57
  Deployment,
59
- ErrorItem,
60
58
  Input,
59
+ JSONBackend,
61
60
  ModelSource,
62
61
  UpdateOperation,
63
62
  )
64
63
 
65
64
  _SERVING_EXPORTS = {
66
- "API",
67
65
  "Accelerator",
68
- "BatchItem",
69
66
  "Deployment",
70
- "ErrorItem",
71
67
  "Input",
68
+ "JSONBackend",
72
69
  "ModelSource",
73
70
  "UpdateOperation",
74
71
  }
@@ -81,7 +78,7 @@ def __getattr__(name: str) -> Any:
81
78
  try:
82
79
  from json2vec.inference import deployment
83
80
  except ModuleNotFoundError as error:
84
- if error.name in {"litserve", "pydantic_settings"}:
81
+ if error.name in {"fastapi", "orjson", "pydantic_settings", "uvicorn"}:
85
82
  raise ModuleNotFoundError(
86
83
  f"json2vec.{name} requires the serving extra; install with `pip install json2vec[serving]`."
87
84
  ) from error
@@ -98,11 +95,9 @@ def __dir__() -> list[str]:
98
95
 
99
96
  __all__ = [
100
97
  "Address",
101
- "API",
102
98
  "Accelerator",
103
99
  "Array",
104
100
  "AttentionMode",
105
- "BatchItem",
106
101
  "Category",
107
102
  "Component",
108
103
  "CustomDataModule",
@@ -111,11 +106,15 @@ __all__ = [
111
106
  "Deployment",
112
107
  "EmbedderBase",
113
108
  "Entity",
114
- "ErrorItem",
109
+ "helpers",
115
110
  "Hyperparameters",
116
111
  "Input",
112
+ "JSONBackend",
117
113
  "Leaf",
118
114
  "Metric",
115
+ "MASK_LITERAL",
116
+ "Mask",
117
+ "MaskLiteral",
119
118
  "Model",
120
119
  "ModelSource",
121
120
  "MutationLockCallback",
@@ -5,7 +5,9 @@ from __future__ import annotations
5
5
  from pathlib import Path
6
6
  from typing import TYPE_CHECKING, Any
7
7
 
8
+ import lightning.pytorch as lit
8
9
  import torch
10
+ from lightning.pytorch.callbacks import ModelCheckpoint
9
11
  from loguru import logger
10
12
 
11
13
  from json2vec.architecture.graph import ModelGraph
@@ -15,6 +17,46 @@ if TYPE_CHECKING:
15
17
  from json2vec.architecture.root import Model
16
18
 
17
19
 
20
+ class RollbackCheckpoint(ModelCheckpoint):
21
+ """Checkpoint the best model during fit and restore it into the module at fit end."""
22
+
23
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
24
+ super().__init__(*args, **kwargs)
25
+ if self.save_weights_only:
26
+ raise ValueError("RollbackCheckpoint requires full checkpoints; set save_weights_only=False")
27
+ if self.save_top_k == 0:
28
+ raise ValueError("RollbackCheckpoint requires at least one saved checkpoint; set save_top_k != 0")
29
+
30
+ def on_fit_end(self, trainer: lit.Trainer, pl_module: lit.LightningModule) -> None:
31
+ from json2vec.architecture.root import Model
32
+
33
+ super().on_fit_end(trainer=trainer, pl_module=pl_module)
34
+ if not isinstance(pl_module, Model):
35
+ raise TypeError("RollbackCheckpoint can only restore json2vec Model instances")
36
+
37
+ best_model_path = self.best_model_path
38
+ if not best_model_path:
39
+ raise RuntimeError("RollbackCheckpoint did not find a best checkpoint to restore")
40
+
41
+ strategy = getattr(trainer, "strategy", None)
42
+ if strategy is not None:
43
+ strategy.barrier("rollback_checkpoint_load")
44
+ checkpoint = strategy.checkpoint_io.load_checkpoint(
45
+ best_model_path,
46
+ map_location=pl_module.device,
47
+ weights_only=False,
48
+ )
49
+ else:
50
+ checkpoint = torch.load(best_model_path, weights_only=False, map_location=pl_module.device)
51
+
52
+ pl_module.restore_checkpoint_state(checkpoint)
53
+ logger.bind(
54
+ component="checkpoint",
55
+ checkpoint=best_model_path,
56
+ score=self.best_model_score,
57
+ ).info("rolled back Model to best checkpoint")
58
+
59
+
18
60
  class CheckpointState:
19
61
  """Save, load, and restore model state without owning the public facade."""
20
62
 
@@ -3,9 +3,9 @@
3
3
  from __future__ import annotations
4
4
 
5
5
  from collections.abc import Iterator, Mapping
6
- from dataclasses import dataclass, field
7
6
  from typing import TYPE_CHECKING, Any
8
7
 
8
+ import pydantic
9
9
  import torch
10
10
  from tensordict import TensorDict
11
11
 
@@ -34,12 +34,13 @@ ContractSignature = tuple[Any, ...]
34
34
  ContractScope = tuple[str, int, int, ContractSignature]
35
35
 
36
36
 
37
- @dataclass
38
- class ContractScheduler:
37
+ class ContractScheduler(pydantic.BaseModel):
39
38
  """Deterministic backoff scheduler for expensive forward contract checks."""
40
39
 
40
+ model_config = pydantic.ConfigDict(arbitrary_types_allowed=True)
41
+
41
42
  periodic_interval: int = 1024
42
- _counts: dict[ContractScope, int] = field(default_factory=dict)
43
+ _counts: dict[ContractScope, int] = pydantic.PrivateAttr(default_factory=dict)
43
44
 
44
45
  def reset(self) -> None:
45
46
  self._counts.clear()
@@ -4,12 +4,17 @@ from __future__ import annotations
4
4
 
5
5
  from collections.abc import Callable, Iterator
6
6
  from contextlib import contextmanager
7
- from dataclasses import dataclass
7
+ from functools import partialmethod, wraps
8
8
  from typing import TYPE_CHECKING, Any
9
9
 
10
+ import lightning.pytorch as lit
11
+ import pydantic
12
+ import torch
13
+ from lightning.pytorch import Callback
10
14
  from loguru import logger
11
15
 
12
16
  from json2vec.architecture.graph import ModelGraph
17
+ from json2vec.structs.enums import Strata
13
18
  from json2vec.structs.experiment import NodeAttribute, NodePredicate, SchemaField
14
19
  from json2vec.structs.structure import Array
15
20
  from json2vec.structs.tree import Leaf, Node
@@ -20,12 +25,85 @@ if TYPE_CHECKING:
20
25
  _MISSING = object()
21
26
 
22
27
 
23
- @dataclass(frozen=True)
24
- class AttributeChange:
28
+ def immutable(name: str | Strata) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
29
+ def decorator(method: Callable[..., Any]) -> Callable[..., Any]:
30
+ @wraps(method)
31
+ def wrapped(self: Any, *args: Any, **kwargs: Any) -> Any:
32
+ locks = self.locks
33
+ locks[name] += 1
34
+ try:
35
+ return method(self, *args, **kwargs)
36
+ finally:
37
+ if locks[name] <= 1:
38
+ locks.pop(name, None)
39
+ else:
40
+ locks[name] -= 1
41
+
42
+ return wrapped
43
+
44
+ return decorator
45
+
46
+
47
+ class MutationLockCallback(Callback):
48
+ """Prevent runtime schema mutations while Lightning owns an active loop."""
49
+
50
+ locks: tuple[Strata, ...] = (Strata.train, Strata.validate, Strata.test, Strata.predict)
51
+
52
+ def _on_loop_start(self, trainer: lit.Trainer, pl_module: "Model", strata: Strata) -> None:
53
+ pl_module.locks[strata] += 1
54
+
55
+ def _on_loop_end(self, trainer: lit.Trainer, pl_module: "Model", strata: Strata) -> None:
56
+ locks = pl_module.locks
57
+ if locks[strata] <= 1:
58
+ locks.pop(strata, None)
59
+ else:
60
+ locks[strata] -= 1
61
+
62
+ def on_exception(
63
+ self,
64
+ trainer: lit.Trainer,
65
+ pl_module: "Model",
66
+ exception: BaseException,
67
+ ) -> None: # ty:ignore[invalid-method-override]
68
+ for lock in self.locks:
69
+ pl_module.locks.pop(lock, None)
70
+
71
+ on_train_start = partialmethod(_on_loop_start, strata=Strata.train)
72
+ on_train_end = partialmethod(_on_loop_end, strata=Strata.train)
73
+ on_validation_start = partialmethod(_on_loop_start, strata=Strata.validate)
74
+ on_validation_end = partialmethod(_on_loop_end, strata=Strata.validate)
75
+ on_test_start = partialmethod(_on_loop_start, strata=Strata.test)
76
+ on_test_end = partialmethod(_on_loop_end, strata=Strata.test)
77
+ on_predict_start = partialmethod(_on_loop_start, strata=Strata.predict)
78
+ on_predict_end = partialmethod(_on_loop_end, strata=Strata.predict)
79
+
80
+
81
+ class RuntimePlacementCallback(Callback):
82
+ """Move late-created modules onto the Lightning module's active device."""
83
+
84
+ def _on_loop_start(self, trainer: lit.Trainer, pl_module: lit.LightningModule, strata: Strata) -> None:
85
+ device = getattr(pl_module, "device", None)
86
+ if isinstance(device, torch.device):
87
+ pl_module.to(device=device)
88
+
89
+ on_train_start = partialmethod(_on_loop_start, strata=Strata.train)
90
+ on_validation_start = partialmethod(_on_loop_start, strata=Strata.validate)
91
+ on_test_start = partialmethod(_on_loop_start, strata=Strata.test)
92
+ on_predict_start = partialmethod(_on_loop_start, strata=Strata.predict)
93
+
94
+
95
+ class AttributeChange(pydantic.BaseModel):
96
+ model_config = pydantic.ConfigDict(arbitrary_types_allowed=True)
97
+
25
98
  node: Node
26
99
  name: str
27
100
  original: Any
28
101
  definition_attribute: bool
102
+ address: str
103
+ node_name: str
104
+ node_type: str
105
+ changed: Any = _MISSING
106
+ changed_address: Any = _MISSING
29
107
 
30
108
 
31
109
  class SchemaEditor:
@@ -34,6 +112,12 @@ class SchemaEditor:
34
112
  def __init__(self, module: "Model") -> None:
35
113
  self.module = module
36
114
 
115
+ def _assert_mutation_allowed(self, action: str) -> None:
116
+ active = tuple(name for name, count in self.module.locks.items() if count > 0)
117
+ if active:
118
+ labels = ", ".join(active)
119
+ raise RuntimeError(f"model.{action}(...) cannot run while the model is in an active loop: {labels}")
120
+
37
121
  def select(
38
122
  self,
39
123
  *predicates: NodePredicate | NodeAttribute | Callable[[Node], bool],
@@ -56,7 +140,7 @@ class SchemaEditor:
56
140
  use_cache: bool = False,
57
141
  **values: Any,
58
142
  ) -> None:
59
- self.module._assert_mutation_allowed("update")
143
+ self._assert_mutation_allowed("update")
60
144
  values = self.module.hyperparameters.update_values(values)
61
145
  changes = self._attribute_changes(
62
146
  values=values,
@@ -84,7 +168,7 @@ class SchemaEditor:
84
168
  include_root: bool = True,
85
169
  use_cache: bool = True,
86
170
  ) -> None:
87
- self.module._assert_mutation_allowed("extend")
171
+ self._assert_mutation_allowed("extend")
88
172
  parent, field_count = self._extend_target(*args, include_root=include_root, use_cache=use_cache)
89
173
  self.module.hyperparameters.extend(*args, include_root=include_root, use_cache=use_cache)
90
174
  ModelGraph.rebuild(self.module)
@@ -103,7 +187,7 @@ class SchemaEditor:
103
187
  include_root: bool = False,
104
188
  use_cache: bool = True,
105
189
  ) -> None:
106
- self.module._assert_mutation_allowed("delete")
190
+ self._assert_mutation_allowed("delete")
107
191
  roots = self._delete_roots(*predicates, include_root=include_root, use_cache=use_cache)
108
192
  self.module.hyperparameters.delete(*predicates, include_root=include_root, use_cache=use_cache)
109
193
  ModelGraph.rebuild(self.module)
@@ -123,7 +207,7 @@ class SchemaEditor:
123
207
  use_cache: bool = True,
124
208
  descendants: bool = False,
125
209
  ) -> None:
126
- self.module._assert_mutation_allowed("reset")
210
+ self._assert_mutation_allowed("reset")
127
211
  selected = self.module.hyperparameters.select(
128
212
  *predicates,
129
213
  include_root=include_root,
@@ -154,7 +238,7 @@ class SchemaEditor:
154
238
  use_cache: bool = False,
155
239
  **values: Any,
156
240
  ) -> Iterator[None]:
157
- self.module._assert_mutation_allowed("override")
241
+ self._assert_mutation_allowed("override")
158
242
  values = self.module.hyperparameters.update_values(values)
159
243
  changes = self._attribute_changes(
160
244
  values=values,
@@ -208,6 +292,9 @@ class SchemaEditor:
208
292
  name=name,
209
293
  original=getattr(node, name, _MISSING),
210
294
  definition_attribute=_is_definition_attribute(node, name),
295
+ address=str(node.address),
296
+ node_name=node.name,
297
+ node_type=node.type,
211
298
  )
212
299
  )
213
300
 
@@ -280,29 +367,55 @@ class SchemaEditor:
280
367
 
281
368
  def _log_attribute_changes(self, action: str, changes: list[AttributeChange], *, restored: bool = False) -> None:
282
369
  for change in changes:
370
+ current_address = str(change.node.address)
283
371
  value = change.original if restored else getattr(change.node, change.name, _MISSING)
372
+ if not restored:
373
+ change.changed = value
374
+ change.changed_address = current_address
375
+ previous_value = change.changed if restored else change.original
376
+ previous_address = change.changed_address if restored else change.address
377
+ if previous_address is _MISSING:
378
+ previous_address = change.address
379
+ address_context = (
380
+ current_address if previous_address == current_address else f"{previous_address} -> {current_address}"
381
+ )
382
+ value_text = _format_log_value(value)
383
+ previous_value_text = _format_log_value(previous_value)
284
384
  logger.bind(
285
385
  component="schema_mutation",
286
386
  action=action,
287
- address=str(change.node.address),
288
- node_type=change.node.type,
387
+ address=current_address,
388
+ previous_address=previous_address,
389
+ node_name=change.node.name,
390
+ previous_node_name=change.node_name,
391
+ node_type=change.node_type,
289
392
  attribute=change.name,
290
393
  definition_attribute=change.definition_attribute,
291
- value=_format_log_value(value),
292
- previous_value=_format_log_value(change.original),
293
- ).info("restored schema node attribute" if restored else "mutated schema node attribute")
394
+ value=value_text,
395
+ previous_value=previous_value_text,
396
+ change=f"{change.name}: {previous_value_text} -> {value_text}",
397
+ ).info(
398
+ "{} {}: {} {} -> {}",
399
+ "restored" if restored else "mutated",
400
+ address_context,
401
+ change.name,
402
+ previous_value_text,
403
+ value_text,
404
+ )
294
405
 
295
406
  def _log_node_mutation(self, *, action: str, message: str, node: Node, **kwargs: Any) -> None:
296
407
  extra = {key: str(value.address) if isinstance(value, Node) else value for key, value in kwargs.items()}
408
+ context = _format_node_log_context(node, extra)
297
409
  logger.bind(
298
410
  component="schema_mutation",
299
411
  action=action,
300
412
  address=str(node.address),
301
413
  node_type=node.type,
414
+ node_name=node.name,
302
415
  attribute=None,
303
416
  definition_attribute=None,
304
417
  **extra,
305
- ).info(message)
418
+ ).info("{} {}", message, context)
306
419
 
307
420
 
308
421
  def _has_node_attribute(node: Node, name: str) -> bool:
@@ -321,3 +434,13 @@ def _format_log_value(value: Any) -> str:
321
434
 
322
435
  text = repr(value)
323
436
  return text if len(text) <= 160 else f"{text[:157]}..."
437
+
438
+
439
+ def _format_node_log_context(node: Node, extra: dict[str, Any]) -> str:
440
+ parts = [str(node.address)]
441
+ if parent := extra.get("parent"):
442
+ parts.append(f"under {parent}")
443
+ if "descendants" in extra:
444
+ parts.append(f"descendants={extra['descendants']}")
445
+
446
+ return " ".join(parts)