json2vec 0.4.8__tar.gz → 0.4.9__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {json2vec-0.4.8/src/json2vec.egg-info → json2vec-0.4.9}/PKG-INFO +1 -1
- {json2vec-0.4.8 → json2vec-0.4.9}/pyproject.toml +1 -1
- {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/__init__.py +9 -4
- {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/architecture/checkpoint.py +42 -0
- {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/architecture/mutations.py +83 -5
- {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/architecture/root.py +16 -109
- {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/architecture/runtime.py +12 -3
- {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/data/datasets/custom.py +0 -2
- {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/data/datasets/polars.py +0 -2
- {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/data/datasets/streaming.py +0 -2
- {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/data/iterables.py +57 -4
- {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/data/processing.py +90 -3
- json2vec-0.4.9/src/json2vec/helpers/__init__.py +8 -0
- json2vec-0.4.9/src/json2vec/helpers/inference.py +632 -0
- {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/structs/experiment.py +30 -15
- json2vec-0.4.9/src/json2vec/structs/structure.py +228 -0
- {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/tensorfields/base.py +235 -6
- {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/tensorfields/extensions/category.py +33 -27
- {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/tensorfields/extensions/dateparts.py +32 -29
- {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/tensorfields/extensions/entity.py +33 -28
- {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/tensorfields/extensions/number.py +35 -33
- {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/tensorfields/extensions/set.py +31 -26
- {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/tensorfields/extensions/text.py +30 -20
- {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/tensorfields/extensions/vector.py +32 -26
- {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/tensorfields/shared/vocabulary.py +4 -0
- {json2vec-0.4.8 → json2vec-0.4.9/src/json2vec.egg-info}/PKG-INFO +1 -1
- {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec.egg-info/SOURCES.txt +3 -1
- json2vec-0.4.9/tests/test_schema_inference.py +327 -0
- json2vec-0.4.8/src/json2vec/structs/__init__.py +0 -0
- json2vec-0.4.8/src/json2vec/structs/structure.py +0 -110
- {json2vec-0.4.8 → json2vec-0.4.9}/LICENSE +0 -0
- {json2vec-0.4.8 → json2vec-0.4.9}/NOTICE +0 -0
- {json2vec-0.4.8 → json2vec-0.4.9}/README.md +0 -0
- {json2vec-0.4.8 → json2vec-0.4.9}/setup.cfg +0 -0
- {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/architecture/__init__.py +0 -0
- {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/architecture/attention.py +0 -0
- {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/architecture/contracts.py +0 -0
- {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/architecture/encoder.py +0 -0
- {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/architecture/graph.py +0 -0
- {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/architecture/node.py +0 -0
- {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/architecture/pool.py +0 -0
- {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/architecture/rotary.py +0 -0
- {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/data/__init__.py +0 -0
- {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/data/datasets/__init__.py +0 -0
- {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/data/datasets/base.py +0 -0
- {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/distributed.py +0 -0
- {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/helpers/hyperparameters.py +0 -0
- {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/helpers/optimizers.py +0 -0
- {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/helpers/trainer.py +0 -0
- {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/inference/__init__.py +0 -0
- {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/inference/callback.py +0 -0
- {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/inference/deployment.py +0 -0
- {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/logging/__init__.py +0 -0
- {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/logging/config.py +0 -0
- {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/logging/epoch.py +0 -0
- {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/logging/throughput.py +0 -0
- {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/preprocessors/__init__.py +0 -0
- {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/preprocessors/base.py +0 -0
- {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/preprocessors/extensions/__init__.py +0 -0
- {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/preprocessors/spec.py +0 -0
- {json2vec-0.4.8/src/json2vec/helpers → json2vec-0.4.9/src/json2vec/structs}/__init__.py +0 -0
- {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/structs/enums.py +0 -0
- {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/structs/packages.py +0 -0
- {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/structs/selectors.py +0 -0
- {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/structs/tree.py +0 -0
- {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/tensorfields/__init__.py +0 -0
- {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/tensorfields/extensions/__init__.py +0 -0
- {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/tensorfields/shared/__init__.py +0 -0
- {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/tensorfields/shared/counter.py +0 -0
- {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/tensorfields/spec.py +0 -0
- {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec.egg-info/dependency_links.txt +0 -0
- {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec.egg-info/requires.txt +0 -0
- {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec.egg-info/top_level.txt +0 -0
- {json2vec-0.4.8 → json2vec-0.4.9}/tests/test_callbacks.py +0 -0
- {json2vec-0.4.8 → json2vec-0.4.9}/tests/test_optimizers.py +0 -0
- {json2vec-0.4.8 → json2vec-0.4.9}/tests/test_public_api.py +0 -0
|
@@ -8,15 +8,16 @@ mutation predicates, and the `@preprocess` decorator.
|
|
|
8
8
|
|
|
9
9
|
from typing import TYPE_CHECKING, Any
|
|
10
10
|
|
|
11
|
+
from json2vec import helpers as helpers
|
|
12
|
+
from json2vec.architecture.checkpoint import RollbackCheckpoint
|
|
13
|
+
from json2vec.architecture.mutations import MutationLockCallback, RuntimePlacementCallback
|
|
11
14
|
from json2vec.architecture.root import (
|
|
12
15
|
Model,
|
|
13
|
-
MutationLockCallback,
|
|
14
16
|
OptimizerConfig,
|
|
15
|
-
RollbackCheckpoint,
|
|
16
|
-
RuntimePlacementCallback,
|
|
17
17
|
SchedulerConfig,
|
|
18
18
|
)
|
|
19
19
|
from json2vec.data.datasets import CustomDataModule, PolarsDataModule, StreamingDataModule
|
|
20
|
+
from json2vec.data.processing import MASK_LITERAL, MaskLiteral
|
|
20
21
|
from json2vec.inference.callback import Postprocessor, Writer
|
|
21
22
|
from json2vec.preprocessors import PREPROCESSORS, Preprocessor, PreprocessorMode, preprocess
|
|
22
23
|
from json2vec.structs.enums import (
|
|
@@ -38,7 +39,7 @@ from json2vec.structs.experiment import (
|
|
|
38
39
|
predicate,
|
|
39
40
|
where,
|
|
40
41
|
)
|
|
41
|
-
from json2vec.structs.structure import Array
|
|
42
|
+
from json2vec.structs.structure import Array, Mask
|
|
42
43
|
from json2vec.structs.tree import Address, Leaf
|
|
43
44
|
from json2vec.tensorfields import TENSORFIELDS, DecoderBase, EmbedderBase, Plugin, RequestBase, TensorFieldBase
|
|
44
45
|
from json2vec.tensorfields.extensions.category import Request as Category
|
|
@@ -105,11 +106,15 @@ __all__ = [
|
|
|
105
106
|
"Deployment",
|
|
106
107
|
"EmbedderBase",
|
|
107
108
|
"Entity",
|
|
109
|
+
"helpers",
|
|
108
110
|
"Hyperparameters",
|
|
109
111
|
"Input",
|
|
110
112
|
"JSONBackend",
|
|
111
113
|
"Leaf",
|
|
112
114
|
"Metric",
|
|
115
|
+
"MASK_LITERAL",
|
|
116
|
+
"Mask",
|
|
117
|
+
"MaskLiteral",
|
|
113
118
|
"Model",
|
|
114
119
|
"ModelSource",
|
|
115
120
|
"MutationLockCallback",
|
|
@@ -5,7 +5,9 @@ from __future__ import annotations
|
|
|
5
5
|
from pathlib import Path
|
|
6
6
|
from typing import TYPE_CHECKING, Any
|
|
7
7
|
|
|
8
|
+
import lightning.pytorch as lit
|
|
8
9
|
import torch
|
|
10
|
+
from lightning.pytorch.callbacks import ModelCheckpoint
|
|
9
11
|
from loguru import logger
|
|
10
12
|
|
|
11
13
|
from json2vec.architecture.graph import ModelGraph
|
|
@@ -15,6 +17,46 @@ if TYPE_CHECKING:
|
|
|
15
17
|
from json2vec.architecture.root import Model
|
|
16
18
|
|
|
17
19
|
|
|
20
|
+
class RollbackCheckpoint(ModelCheckpoint):
    """Checkpoint the best model during fit and restore it into the module at fit end.

    Behaves like ``ModelCheckpoint`` while training runs; once ``fit`` finishes,
    the checkpoint recorded in ``best_model_path`` is loaded and handed to the
    module's ``restore_checkpoint_state``.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Forward every argument to ``ModelCheckpoint`` and validate the result.

        Raises:
            ValueError: If ``save_weights_only`` is truthy or ``save_top_k`` is ``0``.
        """
        super().__init__(*args, **kwargs)
        if self.save_weights_only:
            raise ValueError("RollbackCheckpoint requires full checkpoints; set save_weights_only=False")
        if self.save_top_k == 0:
            raise ValueError("RollbackCheckpoint requires at least one saved checkpoint; set save_top_k != 0")

    def on_fit_end(self, trainer: lit.Trainer, pl_module: lit.LightningModule) -> None:
        """Run the parent hook, then load the best checkpoint back into ``pl_module``.

        Raises:
            TypeError: If ``pl_module`` is not a json2vec ``Model``.
            RuntimeError: If ``best_model_path`` is empty.
        """
        # Imported here rather than at module level: the module-level import of
        # Model is guarded by TYPE_CHECKING, so the name is not bound at runtime.
        from json2vec.architecture.root import Model

        super().on_fit_end(trainer=trainer, pl_module=pl_module)
        if not isinstance(pl_module, Model):
            raise TypeError("RollbackCheckpoint can only restore json2vec Model instances")

        best_model_path = self.best_model_path
        if not best_model_path:
            raise RuntimeError("RollbackCheckpoint did not find a best checkpoint to restore")

        strategy = getattr(trainer, "strategy", None)
        if strategy is not None:
            # Barrier before reading the file.
            # NOTE(review): presumably so every rank sees the finished checkpoint — confirm.
            strategy.barrier("rollback_checkpoint_load")
            checkpoint = strategy.checkpoint_io.load_checkpoint(
                best_model_path,
                map_location=pl_module.device,
                weights_only=False,
            )
        else:
            # NOTE(review): weights_only=False unpickles arbitrary objects; only
            # load checkpoints this process (or a trusted one) wrote.
            checkpoint = torch.load(best_model_path, weights_only=False, map_location=pl_module.device)

        pl_module.restore_checkpoint_state(checkpoint)
        logger.bind(
            component="checkpoint",
            checkpoint=best_model_path,
            score=self.best_model_score,
        ).info("rolled back Model to best checkpoint")
|
|
58
|
+
|
|
59
|
+
|
|
18
60
|
class CheckpointState:
|
|
19
61
|
"""Save, load, and restore model state without owning the public facade."""
|
|
20
62
|
|
|
@@ -4,12 +4,17 @@ from __future__ import annotations
|
|
|
4
4
|
|
|
5
5
|
from collections.abc import Callable, Iterator
|
|
6
6
|
from contextlib import contextmanager
|
|
7
|
+
from functools import partialmethod, wraps
|
|
7
8
|
from typing import TYPE_CHECKING, Any
|
|
8
9
|
|
|
10
|
+
import lightning.pytorch as lit
|
|
9
11
|
import pydantic
|
|
12
|
+
import torch
|
|
13
|
+
from lightning.pytorch import Callback
|
|
10
14
|
from loguru import logger
|
|
11
15
|
|
|
12
16
|
from json2vec.architecture.graph import ModelGraph
|
|
17
|
+
from json2vec.structs.enums import Strata
|
|
13
18
|
from json2vec.structs.experiment import NodeAttribute, NodePredicate, SchemaField
|
|
14
19
|
from json2vec.structs.structure import Array
|
|
15
20
|
from json2vec.structs.tree import Leaf, Node
|
|
@@ -20,6 +25,73 @@ if TYPE_CHECKING:
|
|
|
20
25
|
_MISSING = object()
|
|
21
26
|
|
|
22
27
|
|
|
28
|
+
def immutable(name: "str | Strata") -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Build a method decorator that holds lock ``name`` for the duration of each call.

    The wrapped method's owner must expose ``self.locks``, a mapping that
    supports ``+= 1`` on a missing key (presumably a ``collections.Counter`` —
    confirm against the owning class). The count for ``name`` is incremented
    before the call and released afterwards, even if the method raises.

    Args:
        name: Key under which the active call is counted in ``self.locks``.

    Returns:
        A decorator that wraps a method with the acquire/release bookkeeping.
    """

    def decorator(method: Callable[..., Any]) -> Callable[..., Any]:
        @wraps(method)
        def wrapped(self: Any, *args: Any, **kwargs: Any) -> Any:
            locks = self.locks
            locks[name] += 1
            try:
                return method(self, *args, **kwargs)
            finally:
                # Drop the key once the last holder leaves so the mapping only
                # ever lists locks that are actually held; nested calls just
                # decrement.
                if locks[name] <= 1:
                    locks.pop(name, None)
                else:
                    locks[name] -= 1

        return wrapped

    return decorator
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class MutationLockCallback(Callback):
    """Prevent runtime schema mutations while Lightning owns an active loop.

    Each ``on_<loop>_start`` hook increments the module's lock count for that
    strata and the matching ``on_<loop>_end`` hook releases it.
    """

    # Strata whose locks this callback manages; all of them are cleared on exception.
    locks: tuple[Strata, ...] = (Strata.train, Strata.validate, Strata.test, Strata.predict)

    def _on_loop_start(self, trainer: lit.Trainer, pl_module: "Model", strata: Strata) -> None:
        """Count one more active loop of kind ``strata`` on the module."""
        pl_module.locks[strata] += 1

    def _on_loop_end(self, trainer: lit.Trainer, pl_module: "Model", strata: Strata) -> None:
        """Release one hold on ``strata``, removing the key when it was the last."""
        locks = pl_module.locks
        if locks[strata] <= 1:
            locks.pop(strata, None)
        else:
            locks[strata] -= 1

    def on_exception(
        self,
        trainer: lit.Trainer,
        pl_module: "Model",
        exception: BaseException,
    ) -> None:  # ty:ignore[invalid-method-override]
        """Clear every lock this callback manages, whatever its current count."""
        for lock in self.locks:
            pl_module.locks.pop(lock, None)

    # Bind the shared start/end handlers to each Lightning loop hook.
    on_train_start = partialmethod(_on_loop_start, strata=Strata.train)
    on_train_end = partialmethod(_on_loop_end, strata=Strata.train)
    on_validation_start = partialmethod(_on_loop_start, strata=Strata.validate)
    on_validation_end = partialmethod(_on_loop_end, strata=Strata.validate)
    on_test_start = partialmethod(_on_loop_start, strata=Strata.test)
    on_test_end = partialmethod(_on_loop_end, strata=Strata.test)
    on_predict_start = partialmethod(_on_loop_start, strata=Strata.predict)
    on_predict_end = partialmethod(_on_loop_end, strata=Strata.predict)
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class RuntimePlacementCallback(Callback):
    """Move late-created modules onto the Lightning module's active device."""

    def _on_loop_start(self, trainer: lit.Trainer, pl_module: lit.LightningModule, strata: Strata) -> None:
        """Re-apply ``pl_module.to(device)`` at the start of a loop.

        Nothing happens when ``pl_module.device`` is missing or is not a
        ``torch.device``. ``strata`` is unused; it exists so the handler can be
        bound per hook with ``partialmethod``.
        """
        device = getattr(pl_module, "device", None)
        if isinstance(device, torch.device):
            pl_module.to(device=device)

    # The same handler serves every loop-start hook.
    on_train_start = partialmethod(_on_loop_start, strata=Strata.train)
    on_validation_start = partialmethod(_on_loop_start, strata=Strata.validate)
    on_test_start = partialmethod(_on_loop_start, strata=Strata.test)
    on_predict_start = partialmethod(_on_loop_start, strata=Strata.predict)
|
|
93
|
+
|
|
94
|
+
|
|
23
95
|
class AttributeChange(pydantic.BaseModel):
|
|
24
96
|
model_config = pydantic.ConfigDict(arbitrary_types_allowed=True)
|
|
25
97
|
|
|
@@ -40,6 +112,12 @@ class SchemaEditor:
|
|
|
40
112
|
def __init__(self, module: "Model") -> None:
|
|
41
113
|
self.module = module
|
|
42
114
|
|
|
115
|
+
def _assert_mutation_allowed(self, action: str) -> None:
|
|
116
|
+
active = tuple(name for name, count in self.module.locks.items() if count > 0)
|
|
117
|
+
if active:
|
|
118
|
+
labels = ", ".join(active)
|
|
119
|
+
raise RuntimeError(f"model.{action}(...) cannot run while the model is in an active loop: {labels}")
|
|
120
|
+
|
|
43
121
|
def select(
|
|
44
122
|
self,
|
|
45
123
|
*predicates: NodePredicate | NodeAttribute | Callable[[Node], bool],
|
|
@@ -62,7 +140,7 @@ class SchemaEditor:
|
|
|
62
140
|
use_cache: bool = False,
|
|
63
141
|
**values: Any,
|
|
64
142
|
) -> None:
|
|
65
|
-
self.
|
|
143
|
+
self._assert_mutation_allowed("update")
|
|
66
144
|
values = self.module.hyperparameters.update_values(values)
|
|
67
145
|
changes = self._attribute_changes(
|
|
68
146
|
values=values,
|
|
@@ -90,7 +168,7 @@ class SchemaEditor:
|
|
|
90
168
|
include_root: bool = True,
|
|
91
169
|
use_cache: bool = True,
|
|
92
170
|
) -> None:
|
|
93
|
-
self.
|
|
171
|
+
self._assert_mutation_allowed("extend")
|
|
94
172
|
parent, field_count = self._extend_target(*args, include_root=include_root, use_cache=use_cache)
|
|
95
173
|
self.module.hyperparameters.extend(*args, include_root=include_root, use_cache=use_cache)
|
|
96
174
|
ModelGraph.rebuild(self.module)
|
|
@@ -109,7 +187,7 @@ class SchemaEditor:
|
|
|
109
187
|
include_root: bool = False,
|
|
110
188
|
use_cache: bool = True,
|
|
111
189
|
) -> None:
|
|
112
|
-
self.
|
|
190
|
+
self._assert_mutation_allowed("delete")
|
|
113
191
|
roots = self._delete_roots(*predicates, include_root=include_root, use_cache=use_cache)
|
|
114
192
|
self.module.hyperparameters.delete(*predicates, include_root=include_root, use_cache=use_cache)
|
|
115
193
|
ModelGraph.rebuild(self.module)
|
|
@@ -129,7 +207,7 @@ class SchemaEditor:
|
|
|
129
207
|
use_cache: bool = True,
|
|
130
208
|
descendants: bool = False,
|
|
131
209
|
) -> None:
|
|
132
|
-
self.
|
|
210
|
+
self._assert_mutation_allowed("reset")
|
|
133
211
|
selected = self.module.hyperparameters.select(
|
|
134
212
|
*predicates,
|
|
135
213
|
include_root=include_root,
|
|
@@ -160,7 +238,7 @@ class SchemaEditor:
|
|
|
160
238
|
use_cache: bool = False,
|
|
161
239
|
**values: Any,
|
|
162
240
|
) -> Iterator[None]:
|
|
163
|
-
self.
|
|
241
|
+
self._assert_mutation_allowed("override")
|
|
164
242
|
values = self.module.hyperparameters.update_values(values)
|
|
165
243
|
changes = self._attribute_changes(
|
|
166
244
|
values=values,
|
|
@@ -3,7 +3,7 @@
|
|
|
3
3
|
from collections import Counter
|
|
4
4
|
from collections.abc import Callable, Iterator, Sequence
|
|
5
5
|
from contextlib import contextmanager
|
|
6
|
-
from functools import partialmethod
|
|
6
|
+
from functools import partialmethod
|
|
7
7
|
from pathlib import Path
|
|
8
8
|
from typing import Any, Self, cast
|
|
9
9
|
|
|
@@ -11,15 +11,19 @@ import lightning.pytorch as lit
|
|
|
11
11
|
import torch
|
|
12
12
|
from beartype import beartype
|
|
13
13
|
from lightning.pytorch import Callback
|
|
14
|
-
from lightning.pytorch.callbacks import ModelCheckpoint
|
|
15
14
|
from loguru import logger
|
|
16
15
|
from rich.text import Text
|
|
17
16
|
from tensordict import TensorDict
|
|
18
17
|
|
|
19
|
-
from json2vec.architecture.checkpoint import CheckpointState
|
|
18
|
+
from json2vec.architecture.checkpoint import CheckpointState, RollbackCheckpoint
|
|
20
19
|
from json2vec.architecture.contracts import ContractScheduler
|
|
21
20
|
from json2vec.architecture.graph import ModelGraph
|
|
22
|
-
from json2vec.architecture.mutations import
|
|
21
|
+
from json2vec.architecture.mutations import (
|
|
22
|
+
MutationLockCallback,
|
|
23
|
+
RuntimePlacementCallback,
|
|
24
|
+
SchemaEditor,
|
|
25
|
+
immutable,
|
|
26
|
+
)
|
|
23
27
|
from json2vec.architecture.runtime import ModelRuntime, Postprocessor, Preprocessor, step
|
|
24
28
|
from json2vec.data.datasets.base import EncodedBatch, EncodedInput
|
|
25
29
|
from json2vec.logging.throughput import ThroughputLogger
|
|
@@ -37,105 +41,12 @@ from json2vec.tensorfields.base import TENSORFIELDS, Plugin, TensorFieldBase
|
|
|
37
41
|
OptimizerConfig = torch.optim.Optimizer | Callable[["Model"], torch.optim.Optimizer]
|
|
38
42
|
SchedulerConfig = Any | Callable[["Model", torch.optim.Optimizer], Any]
|
|
39
43
|
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
locks[name] += 1
|
|
47
|
-
try:
|
|
48
|
-
return method(self, *args, **kwargs)
|
|
49
|
-
finally:
|
|
50
|
-
if locks[name] <= 1:
|
|
51
|
-
locks.pop(name, None)
|
|
52
|
-
else:
|
|
53
|
-
locks[name] -= 1
|
|
54
|
-
|
|
55
|
-
return wrapped
|
|
56
|
-
|
|
57
|
-
return decorator
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
class MutationLockCallback(Callback):
|
|
61
|
-
"""Prevent runtime schema mutations while Lightning owns an active loop."""
|
|
62
|
-
|
|
63
|
-
locks: tuple[Strata, ...] = (Strata.train, Strata.validate, Strata.test, Strata.predict)
|
|
64
|
-
|
|
65
|
-
def _on_loop_start(self, trainer: lit.Trainer, pl_module: "Model", strata: Strata) -> None:
|
|
66
|
-
pl_module.locks[strata] += 1
|
|
67
|
-
|
|
68
|
-
def _on_loop_end(self, trainer: lit.Trainer, pl_module: "Model", strata: Strata) -> None:
|
|
69
|
-
locks = pl_module.locks
|
|
70
|
-
if locks[strata] <= 1:
|
|
71
|
-
locks.pop(strata, None)
|
|
72
|
-
else:
|
|
73
|
-
locks[strata] -= 1
|
|
74
|
-
|
|
75
|
-
def on_exception(self, trainer: lit.Trainer, pl_module: "Model", exception: BaseException) -> None: # ty:ignore[invalid-method-override]
|
|
76
|
-
for lock in self.locks:
|
|
77
|
-
pl_module.locks.pop(lock, None)
|
|
78
|
-
|
|
79
|
-
on_train_start = partialmethod(_on_loop_start, strata=Strata.train)
|
|
80
|
-
on_train_end = partialmethod(_on_loop_end, strata=Strata.train)
|
|
81
|
-
on_validation_start = partialmethod(_on_loop_start, strata=Strata.validate)
|
|
82
|
-
on_validation_end = partialmethod(_on_loop_end, strata=Strata.validate)
|
|
83
|
-
on_test_start = partialmethod(_on_loop_start, strata=Strata.test)
|
|
84
|
-
on_test_end = partialmethod(_on_loop_end, strata=Strata.test)
|
|
85
|
-
on_predict_start = partialmethod(_on_loop_start, strata=Strata.predict)
|
|
86
|
-
on_predict_end = partialmethod(_on_loop_end, strata=Strata.predict)
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
class RuntimePlacementCallback(Callback):
|
|
90
|
-
"""Move late-created modules onto the Lightning module's active device."""
|
|
91
|
-
|
|
92
|
-
def _on_loop_start(self, trainer: lit.Trainer, pl_module: lit.LightningModule, strata: Strata) -> None:
|
|
93
|
-
device = getattr(pl_module, "device", None)
|
|
94
|
-
if isinstance(device, torch.device):
|
|
95
|
-
pl_module.to(device=device)
|
|
96
|
-
|
|
97
|
-
on_train_start = partialmethod(_on_loop_start, strata=Strata.train)
|
|
98
|
-
on_validation_start = partialmethod(_on_loop_start, strata=Strata.validate)
|
|
99
|
-
on_test_start = partialmethod(_on_loop_start, strata=Strata.test)
|
|
100
|
-
on_predict_start = partialmethod(_on_loop_start, strata=Strata.predict)
|
|
101
|
-
|
|
102
|
-
|
|
103
|
-
class RollbackCheckpoint(ModelCheckpoint):
|
|
104
|
-
"""Checkpoint the best model during fit and restore it into the module at fit end."""
|
|
105
|
-
|
|
106
|
-
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
|
107
|
-
super().__init__(*args, **kwargs)
|
|
108
|
-
if self.save_weights_only:
|
|
109
|
-
raise ValueError("RollbackCheckpoint requires full checkpoints; set save_weights_only=False")
|
|
110
|
-
if self.save_top_k == 0:
|
|
111
|
-
raise ValueError("RollbackCheckpoint requires at least one saved checkpoint; set save_top_k != 0")
|
|
112
|
-
|
|
113
|
-
def on_fit_end(self, trainer: lit.Trainer, pl_module: lit.LightningModule) -> None:
|
|
114
|
-
super().on_fit_end(trainer=trainer, pl_module=pl_module)
|
|
115
|
-
if not isinstance(pl_module, Model):
|
|
116
|
-
raise TypeError("RollbackCheckpoint can only restore json2vec Model instances")
|
|
117
|
-
|
|
118
|
-
best_model_path = self.best_model_path
|
|
119
|
-
if not best_model_path:
|
|
120
|
-
raise RuntimeError("RollbackCheckpoint did not find a best checkpoint to restore")
|
|
121
|
-
|
|
122
|
-
strategy = getattr(trainer, "strategy", None)
|
|
123
|
-
if strategy is not None:
|
|
124
|
-
strategy.barrier("rollback_checkpoint_load")
|
|
125
|
-
checkpoint = strategy.checkpoint_io.load_checkpoint(
|
|
126
|
-
best_model_path,
|
|
127
|
-
map_location=pl_module.device,
|
|
128
|
-
weights_only=False,
|
|
129
|
-
)
|
|
130
|
-
else:
|
|
131
|
-
checkpoint = torch.load(best_model_path, weights_only=False, map_location=pl_module.device)
|
|
132
|
-
|
|
133
|
-
pl_module.restore_checkpoint_state(checkpoint)
|
|
134
|
-
logger.bind(
|
|
135
|
-
component="checkpoint",
|
|
136
|
-
checkpoint=best_model_path,
|
|
137
|
-
score=self.best_model_score,
|
|
138
|
-
).info("rolled back Model to best checkpoint")
|
|
44
|
+
__all__ = [
|
|
45
|
+
"Model",
|
|
46
|
+
"MutationLockCallback",
|
|
47
|
+
"RollbackCheckpoint",
|
|
48
|
+
"RuntimePlacementCallback",
|
|
49
|
+
]
|
|
139
50
|
|
|
140
51
|
|
|
141
52
|
class Model(lit.LightningModule, Renderable):
|
|
@@ -406,12 +317,6 @@ class Model(lit.LightningModule, Renderable):
|
|
|
406
317
|
):
|
|
407
318
|
yield
|
|
408
319
|
|
|
409
|
-
def _assert_mutation_allowed(self, action: str) -> None:
|
|
410
|
-
active = tuple(name for name, count in self.locks.items() if count > 0)
|
|
411
|
-
if active:
|
|
412
|
-
labels = ", ".join(active)
|
|
413
|
-
raise RuntimeError(f"model.{action}(...) cannot run while the model is in an active loop: {labels}")
|
|
414
|
-
|
|
415
320
|
def configure_callbacks(self) -> list[Callback]:
|
|
416
321
|
callbacks: list[Callback] = []
|
|
417
322
|
factories: set[Any] = set()
|
|
@@ -538,6 +443,7 @@ class Model(lit.LightningModule, Renderable):
|
|
|
538
443
|
batch: EncodedBatch | list[dict[str, Any]],
|
|
539
444
|
preprocess: Preprocessor | None = None,
|
|
540
445
|
strata: Strata | str = Strata.predict,
|
|
446
|
+
mask: bool = True,
|
|
541
447
|
) -> EncodedInput:
|
|
542
448
|
"""Return encoded tensorfield inputs for raw or processed observations."""
|
|
543
449
|
return ModelRuntime.encode(
|
|
@@ -545,6 +451,7 @@ class Model(lit.LightningModule, Renderable):
|
|
|
545
451
|
batch=batch,
|
|
546
452
|
preprocess=preprocess,
|
|
547
453
|
strata=strata,
|
|
454
|
+
mask=mask,
|
|
548
455
|
)
|
|
549
456
|
|
|
550
457
|
@immutable("inference")
|
|
@@ -14,8 +14,9 @@ from json2vec.architecture.contracts import sanitize
|
|
|
14
14
|
from json2vec.architecture.encoder import ArrayEncoder
|
|
15
15
|
from json2vec.architecture.node import NodeModule
|
|
16
16
|
from json2vec.data.datasets.base import EncodedBatch, EncodedInput
|
|
17
|
-
from json2vec.data.iterables import encode
|
|
18
|
-
from json2vec.
|
|
17
|
+
from json2vec.data.iterables import encode as encode_batch
|
|
18
|
+
from json2vec.data.iterables import mask as apply_mask
|
|
19
|
+
from json2vec.structs.enums import Metric, Strata, TensorKey, Tokens
|
|
19
20
|
from json2vec.structs.packages import Parcel, Prediction
|
|
20
21
|
from json2vec.structs.tree import Address
|
|
21
22
|
from json2vec.tensorfields.base import (
|
|
@@ -99,8 +100,10 @@ class ModelRuntime:
|
|
|
99
100
|
)
|
|
100
101
|
|
|
101
102
|
for address in module.hyperparameters.active_requests.keys():
|
|
103
|
+
has_masked_input = inputs[address].state.eq(Tokens.masked.value).any()
|
|
102
104
|
if (
|
|
103
105
|
torch.any(inputs[address].trainable)
|
|
106
|
+
or (strata == Strata.predict and has_masked_input)
|
|
104
107
|
or (address in module.hyperparameters.target)
|
|
105
108
|
or (address in module.hyperparameters.embed)
|
|
106
109
|
):
|
|
@@ -193,6 +196,7 @@ class ModelRuntime:
|
|
|
193
196
|
batch: EncodedBatch | list[dict[str, Any]],
|
|
194
197
|
preprocess: Preprocessor | None = None,
|
|
195
198
|
strata: Strata | str = Strata.predict,
|
|
199
|
+
mask: bool = True,
|
|
196
200
|
) -> EncodedInput:
|
|
197
201
|
strata = Strata.normalize(strata)
|
|
198
202
|
|
|
@@ -209,12 +213,17 @@ class ModelRuntime:
|
|
|
209
213
|
elif batch and isinstance(batch[0], dict):
|
|
210
214
|
batch = [[request] for request in cast(list[dict[str, Any]], batch)]
|
|
211
215
|
|
|
212
|
-
|
|
216
|
+
inputs = encode_batch(
|
|
213
217
|
batch=cast(EncodedBatch, batch),
|
|
214
218
|
hyperparameters=module.hyperparameters,
|
|
215
219
|
strata=strata,
|
|
216
220
|
interprocess_encoding_context=module.interprocess_encoding_context,
|
|
221
|
+
defer_target_masking=True,
|
|
217
222
|
)
|
|
223
|
+
if mask:
|
|
224
|
+
return next(apply_mask([inputs], module.hyperparameters, strata=strata))
|
|
225
|
+
|
|
226
|
+
return inputs
|
|
218
227
|
|
|
219
228
|
@staticmethod
|
|
220
229
|
def predict(
|
|
@@ -31,7 +31,6 @@ from json2vec.data.iterables import (
|
|
|
31
31
|
process,
|
|
32
32
|
sample,
|
|
33
33
|
shuffle,
|
|
34
|
-
target,
|
|
35
34
|
transform,
|
|
36
35
|
)
|
|
37
36
|
from json2vec.data.processing import Pipeline
|
|
@@ -130,7 +129,6 @@ class CustomBatchDataset(IterableDataset):
|
|
|
130
129
|
| batch
|
|
131
130
|
| transform
|
|
132
131
|
| mask
|
|
133
|
-
| target
|
|
134
132
|
)
|
|
135
133
|
|
|
136
134
|
|
|
@@ -35,7 +35,6 @@ from json2vec.data.iterables import (
|
|
|
35
35
|
process,
|
|
36
36
|
sample,
|
|
37
37
|
shuffle,
|
|
38
|
-
target,
|
|
39
38
|
transform,
|
|
40
39
|
)
|
|
41
40
|
from json2vec.data.processing import Pipeline
|
|
@@ -183,7 +182,6 @@ class PolarsBatchDataset(IterableDataset):
|
|
|
183
182
|
| batch
|
|
184
183
|
| transform
|
|
185
184
|
| mask
|
|
186
|
-
| target
|
|
187
185
|
)
|
|
188
186
|
|
|
189
187
|
|
|
@@ -40,7 +40,6 @@ from json2vec.data.iterables import (
|
|
|
40
40
|
process,
|
|
41
41
|
sample,
|
|
42
42
|
shuffle,
|
|
43
|
-
target,
|
|
44
43
|
transform,
|
|
45
44
|
)
|
|
46
45
|
from json2vec.data.processing import Pipeline
|
|
@@ -323,7 +322,6 @@ class BatchDataset(IterableDataset):
|
|
|
323
322
|
| batch
|
|
324
323
|
| transform
|
|
325
324
|
| mask
|
|
326
|
-
| target
|
|
327
325
|
)
|
|
328
326
|
|
|
329
327
|
|
|
@@ -23,6 +23,7 @@ from json2vec.data.datasets.base import (
|
|
|
23
23
|
ProcessedObservation,
|
|
24
24
|
RawObservation,
|
|
25
25
|
)
|
|
26
|
+
from json2vec.data.processing import MASK_LITERAL, contains_mask_literal
|
|
26
27
|
from json2vec.preprocessors.base import PREPROCESSORS, Preprocessor, PreprocessorMode
|
|
27
28
|
from json2vec.structs.enums import Strata, TensorKey
|
|
28
29
|
from json2vec.structs.experiment import Hyperparameters
|
|
@@ -224,10 +225,14 @@ def encode(
|
|
|
224
225
|
strata: Strata,
|
|
225
226
|
interprocess_encoding_context: InterprocessEncodingContext,
|
|
226
227
|
jmespath_resolution_monitor: JMESPathResolutionMonitor | None = None,
|
|
228
|
+
defer_target_masking: bool = False,
|
|
227
229
|
) -> EncodedInput:
|
|
228
230
|
out: dict[Address, TensorFieldBase] = {}
|
|
229
231
|
target_addresses = set(hyperparameters.target)
|
|
230
232
|
|
|
233
|
+
if strata != Strata.predict and contains_mask_literal(batch):
|
|
234
|
+
raise ValueError(f"{MASK_LITERAL!r} is only valid during predict strata")
|
|
235
|
+
|
|
231
236
|
for address, request in hyperparameters.active_requests.items():
|
|
232
237
|
TensorField = cast(type[TensorFieldBase], getattr(TENSORFIELDS[request.type], "TensorField"))
|
|
233
238
|
|
|
@@ -262,8 +267,8 @@ def encode(
|
|
|
262
267
|
|
|
263
268
|
out[address] = TensorField.new(**kwargs)
|
|
264
269
|
|
|
265
|
-
if address in target_addresses:
|
|
266
|
-
out[address].
|
|
270
|
+
if not defer_target_masking and strata != Strata.predict and address in target_addresses:
|
|
271
|
+
out[address].mask(p_prune=1.0)
|
|
267
272
|
|
|
268
273
|
inputs = cast(EncodedInput, TensorDict(source=cast(Any, out)))
|
|
269
274
|
|
|
@@ -288,21 +293,69 @@ def transform(
|
|
|
288
293
|
strata=strata,
|
|
289
294
|
interprocess_encoding_context=interprocess_encoding_context,
|
|
290
295
|
jmespath_resolution_monitor=jmespath_resolution_monitor,
|
|
296
|
+
defer_target_masking=True,
|
|
291
297
|
)
|
|
292
298
|
|
|
293
299
|
|
|
300
|
+
def _apply_mask_policy(
|
|
301
|
+
field: TensorFieldBase,
|
|
302
|
+
*,
|
|
303
|
+
p_mask: float,
|
|
304
|
+
p_prune: float,
|
|
305
|
+
array_masks: tuple[Any, ...],
|
|
306
|
+
address: Address,
|
|
307
|
+
hyperparameters: Hyperparameters,
|
|
308
|
+
) -> None:
|
|
309
|
+
parameters = inspect.signature(field.mask).parameters
|
|
310
|
+
supports_policy_kwargs = any(parameter.kind == inspect.Parameter.VAR_KEYWORD for parameter in parameters.values())
|
|
311
|
+
supports_policy_kwargs |= any(name in parameters for name in ("p_prune", "array_masks", "hyperparameters"))
|
|
312
|
+
|
|
313
|
+
if supports_policy_kwargs:
|
|
314
|
+
field.mask(
|
|
315
|
+
p_mask=p_mask,
|
|
316
|
+
p_prune=p_prune,
|
|
317
|
+
array_masks=array_masks,
|
|
318
|
+
address=address,
|
|
319
|
+
hyperparameters=hyperparameters,
|
|
320
|
+
)
|
|
321
|
+
return
|
|
322
|
+
|
|
323
|
+
if array_masks:
|
|
324
|
+
raise TypeError(f"tensorfield at '{address}' must accept mask(..., array_masks=...) to use Array masks")
|
|
325
|
+
|
|
326
|
+
if p_mask > 0.0:
|
|
327
|
+
field.mask(p_mask=p_mask)
|
|
328
|
+
|
|
329
|
+
if p_prune > 0.0:
|
|
330
|
+
field.target(p_prune=p_prune)
|
|
331
|
+
|
|
332
|
+
|
|
294
333
|
@beartype
def mask(
    pipe: Iterable[EncodedInput],
    hyperparameters: Hyperparameters,
    strata: Strata = Strata.train,
) -> Iterator[EncodedInput]:
    """Apply each active request's masking policy to every item in ``pipe``.

    Args:
        pipe: Encoded inputs to process; each item is yielded after handling.
        hyperparameters: Supplies ``active_requests`` (per-address ``p_mask`` and
            ``p_prune``) and ``array_masks_for(address)``.
        strata: Loop the items belong to. Under ``Strata.predict`` items are
            yielded untouched.

    Yields:
        The same item objects that came in.
        NOTE(review): masking appears to act on the item's fields in place,
        since the helper returns nothing — confirm.
    """
    for item in pipe:
        if strata == Strata.predict:
            yield item
            continue

        for address, request in hyperparameters.active_requests.items():
            # `or 0.0` treats an unset (None) probability as zero.
            p_mask = float(request.p_mask or 0.0)
            p_prune = float(request.p_prune or 0.0)
            array_masks = hyperparameters.array_masks_for(address)
            # Nothing to do for this address: no probabilities and no array masks.
            if p_mask <= 0.0 and p_prune <= 0.0 and not array_masks:
                continue

            _apply_mask_policy(
                item[address],
                p_mask=p_mask,
                p_prune=p_prune,
                array_masks=array_masks,
                address=address,
                hyperparameters=hyperparameters,
            )

        yield item
|
|
308
361
|
|