json2vec 0.4.7__tar.gz → 0.4.9__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {json2vec-0.4.7/src/json2vec.egg-info → json2vec-0.4.9}/PKG-INFO +8 -5
- {json2vec-0.4.7 → json2vec-0.4.9}/README.md +1 -1
- {json2vec-0.4.7 → json2vec-0.4.9}/pyproject.toml +12 -5
- {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/__init__.py +13 -14
- {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/architecture/checkpoint.py +42 -0
- {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/architecture/contracts.py +5 -4
- {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/architecture/mutations.py +137 -14
- {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/architecture/root.py +62 -133
- {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/architecture/runtime.py +12 -3
- {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/data/datasets/base.py +8 -0
- {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/data/datasets/custom.py +11 -4
- {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/data/datasets/polars.py +11 -4
- {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/data/datasets/streaming.py +11 -4
- {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/data/iterables.py +114 -8
- {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/data/processing.py +90 -3
- json2vec-0.4.9/src/json2vec/helpers/__init__.py +8 -0
- json2vec-0.4.9/src/json2vec/helpers/inference.py +632 -0
- json2vec-0.4.9/src/json2vec/helpers/optimizers.py +78 -0
- json2vec-0.4.9/src/json2vec/helpers/trainer.py +0 -0
- {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/inference/__init__.py +5 -11
- json2vec-0.4.9/src/json2vec/inference/deployment.py +691 -0
- json2vec-0.4.9/src/json2vec/structs/__init__.py +0 -0
- {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/structs/enums.py +0 -1
- {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/structs/experiment.py +70 -25
- json2vec-0.4.9/src/json2vec/structs/structure.py +228 -0
- {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/structs/tree.py +147 -2
- {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/tensorfields/base.py +254 -45
- {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/tensorfields/extensions/category.py +83 -84
- {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/tensorfields/extensions/dateparts.py +32 -29
- {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/tensorfields/extensions/entity.py +33 -28
- {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/tensorfields/extensions/number.py +77 -74
- {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/tensorfields/extensions/set.py +61 -90
- {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/tensorfields/extensions/text.py +30 -20
- {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/tensorfields/extensions/vector.py +32 -26
- json2vec-0.4.9/src/json2vec/tensorfields/shared/__init__.py +80 -0
- {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/tensorfields/shared/counter.py +3 -1
- json2vec-0.4.9/src/json2vec/tensorfields/shared/vocabulary.py +440 -0
- {json2vec-0.4.7 → json2vec-0.4.9/src/json2vec.egg-info}/PKG-INFO +8 -5
- {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec.egg-info/SOURCES.txt +8 -2
- {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec.egg-info/requires.txt +6 -3
- {json2vec-0.4.7 → json2vec-0.4.9}/tests/test_callbacks.py +6 -30
- json2vec-0.4.9/tests/test_optimizers.py +78 -0
- {json2vec-0.4.7 → json2vec-0.4.9}/tests/test_public_api.py +1 -3
- json2vec-0.4.9/tests/test_schema_inference.py +327 -0
- json2vec-0.4.7/src/json2vec/architecture/plot.py +0 -562
- json2vec-0.4.7/src/json2vec/inference/deployment.py +0 -422
- json2vec-0.4.7/src/json2vec/structs/structure.py +0 -59
- json2vec-0.4.7/src/json2vec/tensorfields/shared/__init__.py +0 -12
- json2vec-0.4.7/src/json2vec/tensorfields/shared/vocabulary.py +0 -283
- {json2vec-0.4.7 → json2vec-0.4.9}/LICENSE +0 -0
- {json2vec-0.4.7 → json2vec-0.4.9}/NOTICE +0 -0
- {json2vec-0.4.7 → json2vec-0.4.9}/setup.cfg +0 -0
- {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/architecture/__init__.py +0 -0
- {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/architecture/attention.py +0 -0
- {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/architecture/encoder.py +0 -0
- {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/architecture/graph.py +0 -0
- {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/architecture/node.py +0 -0
- {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/architecture/pool.py +0 -0
- {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/architecture/rotary.py +0 -0
- {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/data/__init__.py +0 -0
- {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/data/datasets/__init__.py +0 -0
- {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/distributed.py +0 -0
- /json2vec-0.4.7/src/json2vec/structs/__init__.py → /json2vec-0.4.9/src/json2vec/helpers/hyperparameters.py +0 -0
- {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/inference/callback.py +0 -0
- {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/logging/__init__.py +0 -0
- {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/logging/config.py +0 -0
- {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/logging/epoch.py +0 -0
- {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/logging/throughput.py +0 -0
- {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/preprocessors/__init__.py +0 -0
- {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/preprocessors/base.py +0 -0
- {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/preprocessors/extensions/__init__.py +0 -0
- {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/preprocessors/spec.py +0 -0
- {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/structs/packages.py +0 -0
- {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/structs/selectors.py +0 -0
- {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/tensorfields/__init__.py +0 -0
- {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/tensorfields/extensions/__init__.py +0 -0
- {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec/tensorfields/spec.py +0 -0
- {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec.egg-info/dependency_links.txt +0 -0
- {json2vec-0.4.7 → json2vec-0.4.9}/src/json2vec.egg-info/top_level.txt +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: json2vec
|
|
3
|
-
Version: 0.4.7
|
|
3
|
+
Version: 0.4.9
|
|
4
4
|
Summary: Schema-first PyTorch models for hierarchical / nested / sequence data structures
|
|
5
5
|
License-Expression: Apache-2.0
|
|
6
6
|
Requires-Python: >=3.12
|
|
@@ -14,7 +14,6 @@ Requires-Dist: pydantic>=2.11.7
|
|
|
14
14
|
Requires-Dist: jmespath>=1.0.1
|
|
15
15
|
Requires-Dist: loguru>=0.7.3
|
|
16
16
|
Requires-Dist: anytree>=2.13.0
|
|
17
|
-
Requires-Dist: ordered-set>=4.1.0
|
|
18
17
|
Requires-Dist: pyarrow>=21.0.0
|
|
19
18
|
Requires-Dist: polars>=1.35.2
|
|
20
19
|
Requires-Dist: numpy>=2.2.6
|
|
@@ -22,16 +21,20 @@ Requires-Dist: lightning>=2.6.4
|
|
|
22
21
|
Requires-Dist: tensordict>=0.10.0
|
|
23
22
|
Requires-Dist: torch>=2.7.1
|
|
24
23
|
Provides-Extra: serving
|
|
25
|
-
Requires-Dist:
|
|
24
|
+
Requires-Dist: fastapi>=0.124.0; extra == "serving"
|
|
25
|
+
Requires-Dist: orjson>=3.10.0; extra == "serving"
|
|
26
26
|
Requires-Dist: pydantic-settings>=2.10.1; extra == "serving"
|
|
27
|
+
Requires-Dist: uvicorn>=0.38.0; extra == "serving"
|
|
27
28
|
Provides-Extra: text
|
|
28
29
|
Requires-Dist: transformers>=4.55.0; extra == "text"
|
|
29
30
|
Provides-Extra: docs
|
|
30
|
-
Requires-Dist:
|
|
31
|
+
Requires-Dist: fastapi>=0.124.0; extra == "docs"
|
|
31
32
|
Requires-Dist: mkdocs-material>=9.6; extra == "docs"
|
|
32
33
|
Requires-Dist: mkdocs-jupyter>=0.26.3; extra == "docs"
|
|
33
34
|
Requires-Dist: mkdocstrings[python]>=0.27; extra == "docs"
|
|
35
|
+
Requires-Dist: orjson>=3.10.0; extra == "docs"
|
|
34
36
|
Requires-Dist: pydantic-settings>=2.10.1; extra == "docs"
|
|
37
|
+
Requires-Dist: uvicorn>=0.38.0; extra == "docs"
|
|
35
38
|
Dynamic: license-file
|
|
36
39
|
|
|
37
40
|
<h1 align="center"><code>json2vec</code></h1>
|
|
@@ -314,7 +317,7 @@ uv sync --extra docs
|
|
|
314
317
|
```
|
|
315
318
|
|
|
316
319
|
The `text` extra installs Hugging Face `transformers`. The `serving` extra
|
|
317
|
-
installs
|
|
320
|
+
installs FastAPI-backed deployment dependencies. The `docs` extra installs the
|
|
318
321
|
MkDocs toolchain.
|
|
319
322
|
|
|
320
323
|
## Documentation Map
|
|
@@ -278,7 +278,7 @@ uv sync --extra docs
|
|
|
278
278
|
```
|
|
279
279
|
|
|
280
280
|
The `text` extra installs Hugging Face `transformers`. The `serving` extra
|
|
281
|
-
installs
|
|
281
|
+
installs FastAPI-backed deployment dependencies. The `docs` extra installs the
|
|
282
282
|
MkDocs toolchain.
|
|
283
283
|
|
|
284
284
|
## Documentation Map
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[project]
|
|
2
2
|
name = "json2vec"
|
|
3
|
-
version = "0.4.7"
|
|
3
|
+
version = "0.4.9"
|
|
4
4
|
description = "Schema-first PyTorch models for hierarchical / nested / sequence data structures"
|
|
5
5
|
readme = "README.md"
|
|
6
6
|
license = "Apache-2.0"
|
|
@@ -13,7 +13,6 @@ dependencies = [
|
|
|
13
13
|
"jmespath>=1.0.1",
|
|
14
14
|
"loguru>=0.7.3",
|
|
15
15
|
"anytree>=2.13.0",
|
|
16
|
-
"ordered-set>=4.1.0",
|
|
17
16
|
"pyarrow>=21.0.0",
|
|
18
17
|
"polars>=1.35.2",
|
|
19
18
|
"numpy>=2.2.6",
|
|
@@ -24,30 +23,37 @@ dependencies = [
|
|
|
24
23
|
|
|
25
24
|
[project.optional-dependencies]
|
|
26
25
|
serving = [
|
|
27
|
-
"
|
|
26
|
+
"fastapi>=0.124.0",
|
|
27
|
+
"orjson>=3.10.0",
|
|
28
28
|
"pydantic-settings>=2.10.1",
|
|
29
|
+
"uvicorn>=0.38.0",
|
|
29
30
|
]
|
|
30
31
|
text = [
|
|
31
32
|
"transformers>=4.55.0",
|
|
32
33
|
]
|
|
33
34
|
docs = [
|
|
34
|
-
"
|
|
35
|
+
"fastapi>=0.124.0",
|
|
35
36
|
"mkdocs-material>=9.6",
|
|
36
37
|
"mkdocs-jupyter>=0.26.3",
|
|
37
38
|
"mkdocstrings[python]>=0.27",
|
|
39
|
+
"orjson>=3.10.0",
|
|
38
40
|
"pydantic-settings>=2.10.1",
|
|
41
|
+
"uvicorn>=0.38.0",
|
|
39
42
|
]
|
|
40
43
|
|
|
41
44
|
[dependency-groups]
|
|
42
45
|
dev = [
|
|
43
46
|
"ruff>=0.12.12",
|
|
44
47
|
"pytest>=8.4.1",
|
|
48
|
+
"pytest-xdist>=3.8.0",
|
|
45
49
|
"ipython>=9.9.0",
|
|
46
50
|
"ipykernel>=6.29.5",
|
|
47
51
|
"nbclient>=0.10.2",
|
|
48
52
|
"nbformat>=5.10.4",
|
|
49
|
-
"
|
|
53
|
+
"fastapi>=0.124.0",
|
|
54
|
+
"orjson>=3.10.0",
|
|
50
55
|
"pydantic-settings>=2.10.1",
|
|
56
|
+
"uvicorn>=0.38.0",
|
|
51
57
|
"ty>=0.0.1a20",
|
|
52
58
|
"pre-commit>=4.3.0",
|
|
53
59
|
]
|
|
@@ -66,6 +72,7 @@ include = ["json2vec*"]
|
|
|
66
72
|
[tool.pytest.ini_options]
|
|
67
73
|
testpaths = ["tests"]
|
|
68
74
|
python_files = ["test_*.py"]
|
|
75
|
+
addopts = ["-n", "auto"]
|
|
69
76
|
|
|
70
77
|
[tool.ruff]
|
|
71
78
|
line-length = 120
|
|
@@ -8,15 +8,16 @@ mutation predicates, and the `@preprocess` decorator.
|
|
|
8
8
|
|
|
9
9
|
from typing import TYPE_CHECKING, Any
|
|
10
10
|
|
|
11
|
+
from json2vec import helpers as helpers
|
|
12
|
+
from json2vec.architecture.checkpoint import RollbackCheckpoint
|
|
13
|
+
from json2vec.architecture.mutations import MutationLockCallback, RuntimePlacementCallback
|
|
11
14
|
from json2vec.architecture.root import (
|
|
12
15
|
Model,
|
|
13
|
-
MutationLockCallback,
|
|
14
16
|
OptimizerConfig,
|
|
15
|
-
RollbackCheckpoint,
|
|
16
|
-
RuntimePlacementCallback,
|
|
17
17
|
SchedulerConfig,
|
|
18
18
|
)
|
|
19
19
|
from json2vec.data.datasets import CustomDataModule, PolarsDataModule, StreamingDataModule
|
|
20
|
+
from json2vec.data.processing import MASK_LITERAL, MaskLiteral
|
|
20
21
|
from json2vec.inference.callback import Postprocessor, Writer
|
|
21
22
|
from json2vec.preprocessors import PREPROCESSORS, Preprocessor, PreprocessorMode, preprocess
|
|
22
23
|
from json2vec.structs.enums import (
|
|
@@ -38,7 +39,7 @@ from json2vec.structs.experiment import (
|
|
|
38
39
|
predicate,
|
|
39
40
|
where,
|
|
40
41
|
)
|
|
41
|
-
from json2vec.structs.structure import Array
|
|
42
|
+
from json2vec.structs.structure import Array, Mask
|
|
42
43
|
from json2vec.structs.tree import Address, Leaf
|
|
43
44
|
from json2vec.tensorfields import TENSORFIELDS, DecoderBase, EmbedderBase, Plugin, RequestBase, TensorFieldBase
|
|
44
45
|
from json2vec.tensorfields.extensions.category import Request as Category
|
|
@@ -52,23 +53,19 @@ from json2vec.tensorfields.shared.vocabulary import VocabularySyncCallback
|
|
|
52
53
|
|
|
53
54
|
if TYPE_CHECKING:
|
|
54
55
|
from json2vec.inference.deployment import (
|
|
55
|
-
API,
|
|
56
56
|
Accelerator,
|
|
57
|
-
BatchItem,
|
|
58
57
|
Deployment,
|
|
59
|
-
ErrorItem,
|
|
60
58
|
Input,
|
|
59
|
+
JSONBackend,
|
|
61
60
|
ModelSource,
|
|
62
61
|
UpdateOperation,
|
|
63
62
|
)
|
|
64
63
|
|
|
65
64
|
_SERVING_EXPORTS = {
|
|
66
|
-
"API",
|
|
67
65
|
"Accelerator",
|
|
68
|
-
"BatchItem",
|
|
69
66
|
"Deployment",
|
|
70
|
-
"ErrorItem",
|
|
71
67
|
"Input",
|
|
68
|
+
"JSONBackend",
|
|
72
69
|
"ModelSource",
|
|
73
70
|
"UpdateOperation",
|
|
74
71
|
}
|
|
@@ -81,7 +78,7 @@ def __getattr__(name: str) -> Any:
|
|
|
81
78
|
try:
|
|
82
79
|
from json2vec.inference import deployment
|
|
83
80
|
except ModuleNotFoundError as error:
|
|
84
|
-
if error.name in {"
|
|
81
|
+
if error.name in {"fastapi", "orjson", "pydantic_settings", "uvicorn"}:
|
|
85
82
|
raise ModuleNotFoundError(
|
|
86
83
|
f"json2vec.{name} requires the serving extra; install with `pip install json2vec[serving]`."
|
|
87
84
|
) from error
|
|
@@ -98,11 +95,9 @@ def __dir__() -> list[str]:
|
|
|
98
95
|
|
|
99
96
|
__all__ = [
|
|
100
97
|
"Address",
|
|
101
|
-
"API",
|
|
102
98
|
"Accelerator",
|
|
103
99
|
"Array",
|
|
104
100
|
"AttentionMode",
|
|
105
|
-
"BatchItem",
|
|
106
101
|
"Category",
|
|
107
102
|
"Component",
|
|
108
103
|
"CustomDataModule",
|
|
@@ -111,11 +106,15 @@ __all__ = [
|
|
|
111
106
|
"Deployment",
|
|
112
107
|
"EmbedderBase",
|
|
113
108
|
"Entity",
|
|
114
|
-
"ErrorItem",
|
|
109
|
+
"helpers",
|
|
115
110
|
"Hyperparameters",
|
|
116
111
|
"Input",
|
|
112
|
+
"JSONBackend",
|
|
117
113
|
"Leaf",
|
|
118
114
|
"Metric",
|
|
115
|
+
"MASK_LITERAL",
|
|
116
|
+
"Mask",
|
|
117
|
+
"MaskLiteral",
|
|
119
118
|
"Model",
|
|
120
119
|
"ModelSource",
|
|
121
120
|
"MutationLockCallback",
|
|
@@ -5,7 +5,9 @@ from __future__ import annotations
|
|
|
5
5
|
from pathlib import Path
|
|
6
6
|
from typing import TYPE_CHECKING, Any
|
|
7
7
|
|
|
8
|
+
import lightning.pytorch as lit
|
|
8
9
|
import torch
|
|
10
|
+
from lightning.pytorch.callbacks import ModelCheckpoint
|
|
9
11
|
from loguru import logger
|
|
10
12
|
|
|
11
13
|
from json2vec.architecture.graph import ModelGraph
|
|
@@ -15,6 +17,46 @@ if TYPE_CHECKING:
|
|
|
15
17
|
from json2vec.architecture.root import Model
|
|
16
18
|
|
|
17
19
|
|
|
20
|
+
class RollbackCheckpoint(ModelCheckpoint):
|
|
21
|
+
"""Checkpoint the best model during fit and restore it into the module at fit end."""
|
|
22
|
+
|
|
23
|
+
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
|
24
|
+
super().__init__(*args, **kwargs)
|
|
25
|
+
if self.save_weights_only:
|
|
26
|
+
raise ValueError("RollbackCheckpoint requires full checkpoints; set save_weights_only=False")
|
|
27
|
+
if self.save_top_k == 0:
|
|
28
|
+
raise ValueError("RollbackCheckpoint requires at least one saved checkpoint; set save_top_k != 0")
|
|
29
|
+
|
|
30
|
+
def on_fit_end(self, trainer: lit.Trainer, pl_module: lit.LightningModule) -> None:
|
|
31
|
+
from json2vec.architecture.root import Model
|
|
32
|
+
|
|
33
|
+
super().on_fit_end(trainer=trainer, pl_module=pl_module)
|
|
34
|
+
if not isinstance(pl_module, Model):
|
|
35
|
+
raise TypeError("RollbackCheckpoint can only restore json2vec Model instances")
|
|
36
|
+
|
|
37
|
+
best_model_path = self.best_model_path
|
|
38
|
+
if not best_model_path:
|
|
39
|
+
raise RuntimeError("RollbackCheckpoint did not find a best checkpoint to restore")
|
|
40
|
+
|
|
41
|
+
strategy = getattr(trainer, "strategy", None)
|
|
42
|
+
if strategy is not None:
|
|
43
|
+
strategy.barrier("rollback_checkpoint_load")
|
|
44
|
+
checkpoint = strategy.checkpoint_io.load_checkpoint(
|
|
45
|
+
best_model_path,
|
|
46
|
+
map_location=pl_module.device,
|
|
47
|
+
weights_only=False,
|
|
48
|
+
)
|
|
49
|
+
else:
|
|
50
|
+
checkpoint = torch.load(best_model_path, weights_only=False, map_location=pl_module.device)
|
|
51
|
+
|
|
52
|
+
pl_module.restore_checkpoint_state(checkpoint)
|
|
53
|
+
logger.bind(
|
|
54
|
+
component="checkpoint",
|
|
55
|
+
checkpoint=best_model_path,
|
|
56
|
+
score=self.best_model_score,
|
|
57
|
+
).info("rolled back Model to best checkpoint")
|
|
58
|
+
|
|
59
|
+
|
|
18
60
|
class CheckpointState:
|
|
19
61
|
"""Save, load, and restore model state without owning the public facade."""
|
|
20
62
|
|
|
@@ -3,9 +3,9 @@
|
|
|
3
3
|
from __future__ import annotations
|
|
4
4
|
|
|
5
5
|
from collections.abc import Iterator, Mapping
|
|
6
|
-
from dataclasses import dataclass, field
|
|
7
6
|
from typing import TYPE_CHECKING, Any
|
|
8
7
|
|
|
8
|
+
import pydantic
|
|
9
9
|
import torch
|
|
10
10
|
from tensordict import TensorDict
|
|
11
11
|
|
|
@@ -34,12 +34,13 @@ ContractSignature = tuple[Any, ...]
|
|
|
34
34
|
ContractScope = tuple[str, int, int, ContractSignature]
|
|
35
35
|
|
|
36
36
|
|
|
37
|
-
|
|
38
|
-
class ContractScheduler:
|
|
37
|
+
class ContractScheduler(pydantic.BaseModel):
|
|
39
38
|
"""Deterministic backoff scheduler for expensive forward contract checks."""
|
|
40
39
|
|
|
40
|
+
model_config = pydantic.ConfigDict(arbitrary_types_allowed=True)
|
|
41
|
+
|
|
41
42
|
periodic_interval: int = 1024
|
|
42
|
-
_counts: dict[ContractScope, int] =
|
|
43
|
+
_counts: dict[ContractScope, int] = pydantic.PrivateAttr(default_factory=dict)
|
|
43
44
|
|
|
44
45
|
def reset(self) -> None:
|
|
45
46
|
self._counts.clear()
|
|
@@ -4,12 +4,17 @@ from __future__ import annotations
|
|
|
4
4
|
|
|
5
5
|
from collections.abc import Callable, Iterator
|
|
6
6
|
from contextlib import contextmanager
|
|
7
|
-
from
|
|
7
|
+
from functools import partialmethod, wraps
|
|
8
8
|
from typing import TYPE_CHECKING, Any
|
|
9
9
|
|
|
10
|
+
import lightning.pytorch as lit
|
|
11
|
+
import pydantic
|
|
12
|
+
import torch
|
|
13
|
+
from lightning.pytorch import Callback
|
|
10
14
|
from loguru import logger
|
|
11
15
|
|
|
12
16
|
from json2vec.architecture.graph import ModelGraph
|
|
17
|
+
from json2vec.structs.enums import Strata
|
|
13
18
|
from json2vec.structs.experiment import NodeAttribute, NodePredicate, SchemaField
|
|
14
19
|
from json2vec.structs.structure import Array
|
|
15
20
|
from json2vec.structs.tree import Leaf, Node
|
|
@@ -20,12 +25,85 @@ if TYPE_CHECKING:
|
|
|
20
25
|
_MISSING = object()
|
|
21
26
|
|
|
22
27
|
|
|
23
|
-
|
|
24
|
-
|
|
28
|
+
def immutable(name: str | Strata) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
|
|
29
|
+
def decorator(method: Callable[..., Any]) -> Callable[..., Any]:
|
|
30
|
+
@wraps(method)
|
|
31
|
+
def wrapped(self: Any, *args: Any, **kwargs: Any) -> Any:
|
|
32
|
+
locks = self.locks
|
|
33
|
+
locks[name] += 1
|
|
34
|
+
try:
|
|
35
|
+
return method(self, *args, **kwargs)
|
|
36
|
+
finally:
|
|
37
|
+
if locks[name] <= 1:
|
|
38
|
+
locks.pop(name, None)
|
|
39
|
+
else:
|
|
40
|
+
locks[name] -= 1
|
|
41
|
+
|
|
42
|
+
return wrapped
|
|
43
|
+
|
|
44
|
+
return decorator
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class MutationLockCallback(Callback):
|
|
48
|
+
"""Prevent runtime schema mutations while Lightning owns an active loop."""
|
|
49
|
+
|
|
50
|
+
locks: tuple[Strata, ...] = (Strata.train, Strata.validate, Strata.test, Strata.predict)
|
|
51
|
+
|
|
52
|
+
def _on_loop_start(self, trainer: lit.Trainer, pl_module: "Model", strata: Strata) -> None:
|
|
53
|
+
pl_module.locks[strata] += 1
|
|
54
|
+
|
|
55
|
+
def _on_loop_end(self, trainer: lit.Trainer, pl_module: "Model", strata: Strata) -> None:
|
|
56
|
+
locks = pl_module.locks
|
|
57
|
+
if locks[strata] <= 1:
|
|
58
|
+
locks.pop(strata, None)
|
|
59
|
+
else:
|
|
60
|
+
locks[strata] -= 1
|
|
61
|
+
|
|
62
|
+
def on_exception(
|
|
63
|
+
self,
|
|
64
|
+
trainer: lit.Trainer,
|
|
65
|
+
pl_module: "Model",
|
|
66
|
+
exception: BaseException,
|
|
67
|
+
) -> None: # ty:ignore[invalid-method-override]
|
|
68
|
+
for lock in self.locks:
|
|
69
|
+
pl_module.locks.pop(lock, None)
|
|
70
|
+
|
|
71
|
+
on_train_start = partialmethod(_on_loop_start, strata=Strata.train)
|
|
72
|
+
on_train_end = partialmethod(_on_loop_end, strata=Strata.train)
|
|
73
|
+
on_validation_start = partialmethod(_on_loop_start, strata=Strata.validate)
|
|
74
|
+
on_validation_end = partialmethod(_on_loop_end, strata=Strata.validate)
|
|
75
|
+
on_test_start = partialmethod(_on_loop_start, strata=Strata.test)
|
|
76
|
+
on_test_end = partialmethod(_on_loop_end, strata=Strata.test)
|
|
77
|
+
on_predict_start = partialmethod(_on_loop_start, strata=Strata.predict)
|
|
78
|
+
on_predict_end = partialmethod(_on_loop_end, strata=Strata.predict)
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class RuntimePlacementCallback(Callback):
|
|
82
|
+
"""Move late-created modules onto the Lightning module's active device."""
|
|
83
|
+
|
|
84
|
+
def _on_loop_start(self, trainer: lit.Trainer, pl_module: lit.LightningModule, strata: Strata) -> None:
|
|
85
|
+
device = getattr(pl_module, "device", None)
|
|
86
|
+
if isinstance(device, torch.device):
|
|
87
|
+
pl_module.to(device=device)
|
|
88
|
+
|
|
89
|
+
on_train_start = partialmethod(_on_loop_start, strata=Strata.train)
|
|
90
|
+
on_validation_start = partialmethod(_on_loop_start, strata=Strata.validate)
|
|
91
|
+
on_test_start = partialmethod(_on_loop_start, strata=Strata.test)
|
|
92
|
+
on_predict_start = partialmethod(_on_loop_start, strata=Strata.predict)
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
class AttributeChange(pydantic.BaseModel):
|
|
96
|
+
model_config = pydantic.ConfigDict(arbitrary_types_allowed=True)
|
|
97
|
+
|
|
25
98
|
node: Node
|
|
26
99
|
name: str
|
|
27
100
|
original: Any
|
|
28
101
|
definition_attribute: bool
|
|
102
|
+
address: str
|
|
103
|
+
node_name: str
|
|
104
|
+
node_type: str
|
|
105
|
+
changed: Any = _MISSING
|
|
106
|
+
changed_address: Any = _MISSING
|
|
29
107
|
|
|
30
108
|
|
|
31
109
|
class SchemaEditor:
|
|
@@ -34,6 +112,12 @@ class SchemaEditor:
|
|
|
34
112
|
def __init__(self, module: "Model") -> None:
|
|
35
113
|
self.module = module
|
|
36
114
|
|
|
115
|
+
def _assert_mutation_allowed(self, action: str) -> None:
|
|
116
|
+
active = tuple(name for name, count in self.module.locks.items() if count > 0)
|
|
117
|
+
if active:
|
|
118
|
+
labels = ", ".join(active)
|
|
119
|
+
raise RuntimeError(f"model.{action}(...) cannot run while the model is in an active loop: {labels}")
|
|
120
|
+
|
|
37
121
|
def select(
|
|
38
122
|
self,
|
|
39
123
|
*predicates: NodePredicate | NodeAttribute | Callable[[Node], bool],
|
|
@@ -56,7 +140,7 @@ class SchemaEditor:
|
|
|
56
140
|
use_cache: bool = False,
|
|
57
141
|
**values: Any,
|
|
58
142
|
) -> None:
|
|
59
|
-
self.
|
|
143
|
+
self._assert_mutation_allowed("update")
|
|
60
144
|
values = self.module.hyperparameters.update_values(values)
|
|
61
145
|
changes = self._attribute_changes(
|
|
62
146
|
values=values,
|
|
@@ -84,7 +168,7 @@ class SchemaEditor:
|
|
|
84
168
|
include_root: bool = True,
|
|
85
169
|
use_cache: bool = True,
|
|
86
170
|
) -> None:
|
|
87
|
-
self.
|
|
171
|
+
self._assert_mutation_allowed("extend")
|
|
88
172
|
parent, field_count = self._extend_target(*args, include_root=include_root, use_cache=use_cache)
|
|
89
173
|
self.module.hyperparameters.extend(*args, include_root=include_root, use_cache=use_cache)
|
|
90
174
|
ModelGraph.rebuild(self.module)
|
|
@@ -103,7 +187,7 @@ class SchemaEditor:
|
|
|
103
187
|
include_root: bool = False,
|
|
104
188
|
use_cache: bool = True,
|
|
105
189
|
) -> None:
|
|
106
|
-
self.
|
|
190
|
+
self._assert_mutation_allowed("delete")
|
|
107
191
|
roots = self._delete_roots(*predicates, include_root=include_root, use_cache=use_cache)
|
|
108
192
|
self.module.hyperparameters.delete(*predicates, include_root=include_root, use_cache=use_cache)
|
|
109
193
|
ModelGraph.rebuild(self.module)
|
|
@@ -123,7 +207,7 @@ class SchemaEditor:
|
|
|
123
207
|
use_cache: bool = True,
|
|
124
208
|
descendants: bool = False,
|
|
125
209
|
) -> None:
|
|
126
|
-
self.
|
|
210
|
+
self._assert_mutation_allowed("reset")
|
|
127
211
|
selected = self.module.hyperparameters.select(
|
|
128
212
|
*predicates,
|
|
129
213
|
include_root=include_root,
|
|
@@ -154,7 +238,7 @@ class SchemaEditor:
|
|
|
154
238
|
use_cache: bool = False,
|
|
155
239
|
**values: Any,
|
|
156
240
|
) -> Iterator[None]:
|
|
157
|
-
self.
|
|
241
|
+
self._assert_mutation_allowed("override")
|
|
158
242
|
values = self.module.hyperparameters.update_values(values)
|
|
159
243
|
changes = self._attribute_changes(
|
|
160
244
|
values=values,
|
|
@@ -208,6 +292,9 @@ class SchemaEditor:
|
|
|
208
292
|
name=name,
|
|
209
293
|
original=getattr(node, name, _MISSING),
|
|
210
294
|
definition_attribute=_is_definition_attribute(node, name),
|
|
295
|
+
address=str(node.address),
|
|
296
|
+
node_name=node.name,
|
|
297
|
+
node_type=node.type,
|
|
211
298
|
)
|
|
212
299
|
)
|
|
213
300
|
|
|
@@ -280,29 +367,55 @@ class SchemaEditor:
|
|
|
280
367
|
|
|
281
368
|
def _log_attribute_changes(self, action: str, changes: list[AttributeChange], *, restored: bool = False) -> None:
|
|
282
369
|
for change in changes:
|
|
370
|
+
current_address = str(change.node.address)
|
|
283
371
|
value = change.original if restored else getattr(change.node, change.name, _MISSING)
|
|
372
|
+
if not restored:
|
|
373
|
+
change.changed = value
|
|
374
|
+
change.changed_address = current_address
|
|
375
|
+
previous_value = change.changed if restored else change.original
|
|
376
|
+
previous_address = change.changed_address if restored else change.address
|
|
377
|
+
if previous_address is _MISSING:
|
|
378
|
+
previous_address = change.address
|
|
379
|
+
address_context = (
|
|
380
|
+
current_address if previous_address == current_address else f"{previous_address} -> {current_address}"
|
|
381
|
+
)
|
|
382
|
+
value_text = _format_log_value(value)
|
|
383
|
+
previous_value_text = _format_log_value(previous_value)
|
|
284
384
|
logger.bind(
|
|
285
385
|
component="schema_mutation",
|
|
286
386
|
action=action,
|
|
287
|
-
address=
|
|
288
|
-
|
|
387
|
+
address=current_address,
|
|
388
|
+
previous_address=previous_address,
|
|
389
|
+
node_name=change.node.name,
|
|
390
|
+
previous_node_name=change.node_name,
|
|
391
|
+
node_type=change.node_type,
|
|
289
392
|
attribute=change.name,
|
|
290
393
|
definition_attribute=change.definition_attribute,
|
|
291
|
-
value=
|
|
292
|
-
previous_value=
|
|
293
|
-
|
|
394
|
+
value=value_text,
|
|
395
|
+
previous_value=previous_value_text,
|
|
396
|
+
change=f"{change.name}: {previous_value_text} -> {value_text}",
|
|
397
|
+
).info(
|
|
398
|
+
"{} {}: {} {} -> {}",
|
|
399
|
+
"restored" if restored else "mutated",
|
|
400
|
+
address_context,
|
|
401
|
+
change.name,
|
|
402
|
+
previous_value_text,
|
|
403
|
+
value_text,
|
|
404
|
+
)
|
|
294
405
|
|
|
295
406
|
def _log_node_mutation(self, *, action: str, message: str, node: Node, **kwargs: Any) -> None:
|
|
296
407
|
extra = {key: str(value.address) if isinstance(value, Node) else value for key, value in kwargs.items()}
|
|
408
|
+
context = _format_node_log_context(node, extra)
|
|
297
409
|
logger.bind(
|
|
298
410
|
component="schema_mutation",
|
|
299
411
|
action=action,
|
|
300
412
|
address=str(node.address),
|
|
301
413
|
node_type=node.type,
|
|
414
|
+
node_name=node.name,
|
|
302
415
|
attribute=None,
|
|
303
416
|
definition_attribute=None,
|
|
304
417
|
**extra,
|
|
305
|
-
).info(message)
|
|
418
|
+
).info("{} {}", message, context)
|
|
306
419
|
|
|
307
420
|
|
|
308
421
|
def _has_node_attribute(node: Node, name: str) -> bool:
|
|
@@ -321,3 +434,13 @@ def _format_log_value(value: Any) -> str:
|
|
|
321
434
|
|
|
322
435
|
text = repr(value)
|
|
323
436
|
return text if len(text) <= 160 else f"{text[:157]}..."
|
|
437
|
+
|
|
438
|
+
|
|
439
|
+
def _format_node_log_context(node: Node, extra: dict[str, Any]) -> str:
|
|
440
|
+
parts = [str(node.address)]
|
|
441
|
+
if parent := extra.get("parent"):
|
|
442
|
+
parts.append(f"under {parent}")
|
|
443
|
+
if "descendants" in extra:
|
|
444
|
+
parts.append(f"descendants={extra['descendants']}")
|
|
445
|
+
|
|
446
|
+
return " ".join(parts)
|