json2vec 0.4.8__tar.gz → 0.4.9__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76) hide show
  1. {json2vec-0.4.8/src/json2vec.egg-info → json2vec-0.4.9}/PKG-INFO +1 -1
  2. {json2vec-0.4.8 → json2vec-0.4.9}/pyproject.toml +1 -1
  3. {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/__init__.py +9 -4
  4. {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/architecture/checkpoint.py +42 -0
  5. {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/architecture/mutations.py +83 -5
  6. {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/architecture/root.py +16 -109
  7. {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/architecture/runtime.py +12 -3
  8. {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/data/datasets/custom.py +0 -2
  9. {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/data/datasets/polars.py +0 -2
  10. {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/data/datasets/streaming.py +0 -2
  11. {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/data/iterables.py +57 -4
  12. {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/data/processing.py +90 -3
  13. json2vec-0.4.9/src/json2vec/helpers/__init__.py +8 -0
  14. json2vec-0.4.9/src/json2vec/helpers/inference.py +632 -0
  15. {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/structs/experiment.py +30 -15
  16. json2vec-0.4.9/src/json2vec/structs/structure.py +228 -0
  17. {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/tensorfields/base.py +235 -6
  18. {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/tensorfields/extensions/category.py +33 -27
  19. {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/tensorfields/extensions/dateparts.py +32 -29
  20. {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/tensorfields/extensions/entity.py +33 -28
  21. {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/tensorfields/extensions/number.py +35 -33
  22. {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/tensorfields/extensions/set.py +31 -26
  23. {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/tensorfields/extensions/text.py +30 -20
  24. {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/tensorfields/extensions/vector.py +32 -26
  25. {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/tensorfields/shared/vocabulary.py +4 -0
  26. {json2vec-0.4.8 → json2vec-0.4.9/src/json2vec.egg-info}/PKG-INFO +1 -1
  27. {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec.egg-info/SOURCES.txt +3 -1
  28. json2vec-0.4.9/tests/test_schema_inference.py +327 -0
  29. json2vec-0.4.8/src/json2vec/structs/__init__.py +0 -0
  30. json2vec-0.4.8/src/json2vec/structs/structure.py +0 -110
  31. {json2vec-0.4.8 → json2vec-0.4.9}/LICENSE +0 -0
  32. {json2vec-0.4.8 → json2vec-0.4.9}/NOTICE +0 -0
  33. {json2vec-0.4.8 → json2vec-0.4.9}/README.md +0 -0
  34. {json2vec-0.4.8 → json2vec-0.4.9}/setup.cfg +0 -0
  35. {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/architecture/__init__.py +0 -0
  36. {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/architecture/attention.py +0 -0
  37. {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/architecture/contracts.py +0 -0
  38. {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/architecture/encoder.py +0 -0
  39. {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/architecture/graph.py +0 -0
  40. {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/architecture/node.py +0 -0
  41. {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/architecture/pool.py +0 -0
  42. {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/architecture/rotary.py +0 -0
  43. {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/data/__init__.py +0 -0
  44. {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/data/datasets/__init__.py +0 -0
  45. {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/data/datasets/base.py +0 -0
  46. {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/distributed.py +0 -0
  47. {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/helpers/hyperparameters.py +0 -0
  48. {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/helpers/optimizers.py +0 -0
  49. {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/helpers/trainer.py +0 -0
  50. {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/inference/__init__.py +0 -0
  51. {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/inference/callback.py +0 -0
  52. {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/inference/deployment.py +0 -0
  53. {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/logging/__init__.py +0 -0
  54. {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/logging/config.py +0 -0
  55. {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/logging/epoch.py +0 -0
  56. {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/logging/throughput.py +0 -0
  57. {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/preprocessors/__init__.py +0 -0
  58. {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/preprocessors/base.py +0 -0
  59. {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/preprocessors/extensions/__init__.py +0 -0
  60. {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/preprocessors/spec.py +0 -0
  61. {json2vec-0.4.8/src/json2vec/helpers → json2vec-0.4.9/src/json2vec/structs}/__init__.py +0 -0
  62. {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/structs/enums.py +0 -0
  63. {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/structs/packages.py +0 -0
  64. {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/structs/selectors.py +0 -0
  65. {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/structs/tree.py +0 -0
  66. {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/tensorfields/__init__.py +0 -0
  67. {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/tensorfields/extensions/__init__.py +0 -0
  68. {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/tensorfields/shared/__init__.py +0 -0
  69. {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/tensorfields/shared/counter.py +0 -0
  70. {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec/tensorfields/spec.py +0 -0
  71. {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec.egg-info/dependency_links.txt +0 -0
  72. {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec.egg-info/requires.txt +0 -0
  73. {json2vec-0.4.8 → json2vec-0.4.9}/src/json2vec.egg-info/top_level.txt +0 -0
  74. {json2vec-0.4.8 → json2vec-0.4.9}/tests/test_callbacks.py +0 -0
  75. {json2vec-0.4.8 → json2vec-0.4.9}/tests/test_optimizers.py +0 -0
  76. {json2vec-0.4.8 → json2vec-0.4.9}/tests/test_public_api.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: json2vec
3
- Version: 0.4.8
3
+ Version: 0.4.9
4
4
  Summary: Schema-first PyTorch models for hierarchical / nested / sequence data structures
5
5
  License-Expression: Apache-2.0
6
6
  Requires-Python: >=3.12
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "json2vec"
3
- version = "0.4.8"
3
+ version = "0.4.9"
4
4
  description = "Schema-first PyTorch models for hierarchical / nested / sequence data structures"
5
5
  readme = "README.md"
6
6
  license = "Apache-2.0"
@@ -8,15 +8,16 @@ mutation predicates, and the `@preprocess` decorator.
8
8
 
9
9
  from typing import TYPE_CHECKING, Any
10
10
 
11
+ from json2vec import helpers as helpers
12
+ from json2vec.architecture.checkpoint import RollbackCheckpoint
13
+ from json2vec.architecture.mutations import MutationLockCallback, RuntimePlacementCallback
11
14
  from json2vec.architecture.root import (
12
15
  Model,
13
- MutationLockCallback,
14
16
  OptimizerConfig,
15
- RollbackCheckpoint,
16
- RuntimePlacementCallback,
17
17
  SchedulerConfig,
18
18
  )
19
19
  from json2vec.data.datasets import CustomDataModule, PolarsDataModule, StreamingDataModule
20
+ from json2vec.data.processing import MASK_LITERAL, MaskLiteral
20
21
  from json2vec.inference.callback import Postprocessor, Writer
21
22
  from json2vec.preprocessors import PREPROCESSORS, Preprocessor, PreprocessorMode, preprocess
22
23
  from json2vec.structs.enums import (
@@ -38,7 +39,7 @@ from json2vec.structs.experiment import (
38
39
  predicate,
39
40
  where,
40
41
  )
41
- from json2vec.structs.structure import Array
42
+ from json2vec.structs.structure import Array, Mask
42
43
  from json2vec.structs.tree import Address, Leaf
43
44
  from json2vec.tensorfields import TENSORFIELDS, DecoderBase, EmbedderBase, Plugin, RequestBase, TensorFieldBase
44
45
  from json2vec.tensorfields.extensions.category import Request as Category
@@ -105,11 +106,15 @@ __all__ = [
105
106
  "Deployment",
106
107
  "EmbedderBase",
107
108
  "Entity",
109
+ "helpers",
108
110
  "Hyperparameters",
109
111
  "Input",
110
112
  "JSONBackend",
111
113
  "Leaf",
112
114
  "Metric",
115
+ "MASK_LITERAL",
116
+ "Mask",
117
+ "MaskLiteral",
113
118
  "Model",
114
119
  "ModelSource",
115
120
  "MutationLockCallback",
@@ -5,7 +5,9 @@ from __future__ import annotations
5
5
  from pathlib import Path
6
6
  from typing import TYPE_CHECKING, Any
7
7
 
8
+ import lightning.pytorch as lit
8
9
  import torch
10
+ from lightning.pytorch.callbacks import ModelCheckpoint
9
11
  from loguru import logger
10
12
 
11
13
  from json2vec.architecture.graph import ModelGraph
@@ -15,6 +17,46 @@ if TYPE_CHECKING:
15
17
  from json2vec.architecture.root import Model
16
18
 
17
19
 
20
+ class RollbackCheckpoint(ModelCheckpoint):
21
+ """Checkpoint the best model during fit and restore it into the module at fit end."""
22
+
23
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
24
+ super().__init__(*args, **kwargs)
25
+ if self.save_weights_only:
26
+ raise ValueError("RollbackCheckpoint requires full checkpoints; set save_weights_only=False")
27
+ if self.save_top_k == 0:
28
+ raise ValueError("RollbackCheckpoint requires at least one saved checkpoint; set save_top_k != 0")
29
+
30
+ def on_fit_end(self, trainer: lit.Trainer, pl_module: lit.LightningModule) -> None:
31
+ from json2vec.architecture.root import Model
32
+
33
+ super().on_fit_end(trainer=trainer, pl_module=pl_module)
34
+ if not isinstance(pl_module, Model):
35
+ raise TypeError("RollbackCheckpoint can only restore json2vec Model instances")
36
+
37
+ best_model_path = self.best_model_path
38
+ if not best_model_path:
39
+ raise RuntimeError("RollbackCheckpoint did not find a best checkpoint to restore")
40
+
41
+ strategy = getattr(trainer, "strategy", None)
42
+ if strategy is not None:
43
+ strategy.barrier("rollback_checkpoint_load")
44
+ checkpoint = strategy.checkpoint_io.load_checkpoint(
45
+ best_model_path,
46
+ map_location=pl_module.device,
47
+ weights_only=False,
48
+ )
49
+ else:
50
+ checkpoint = torch.load(best_model_path, weights_only=False, map_location=pl_module.device)
51
+
52
+ pl_module.restore_checkpoint_state(checkpoint)
53
+ logger.bind(
54
+ component="checkpoint",
55
+ checkpoint=best_model_path,
56
+ score=self.best_model_score,
57
+ ).info("rolled back Model to best checkpoint")
58
+
59
+
18
60
  class CheckpointState:
19
61
  """Save, load, and restore model state without owning the public facade."""
20
62
 
@@ -4,12 +4,17 @@ from __future__ import annotations
4
4
 
5
5
  from collections.abc import Callable, Iterator
6
6
  from contextlib import contextmanager
7
+ from functools import partialmethod, wraps
7
8
  from typing import TYPE_CHECKING, Any
8
9
 
10
+ import lightning.pytorch as lit
9
11
  import pydantic
12
+ import torch
13
+ from lightning.pytorch import Callback
10
14
  from loguru import logger
11
15
 
12
16
  from json2vec.architecture.graph import ModelGraph
17
+ from json2vec.structs.enums import Strata
13
18
  from json2vec.structs.experiment import NodeAttribute, NodePredicate, SchemaField
14
19
  from json2vec.structs.structure import Array
15
20
  from json2vec.structs.tree import Leaf, Node
@@ -20,6 +25,73 @@ if TYPE_CHECKING:
20
25
  _MISSING = object()
21
26
 
22
27
 
28
+ def immutable(name: str | Strata) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
29
+ def decorator(method: Callable[..., Any]) -> Callable[..., Any]:
30
+ @wraps(method)
31
+ def wrapped(self: Any, *args: Any, **kwargs: Any) -> Any:
32
+ locks = self.locks
33
+ locks[name] += 1
34
+ try:
35
+ return method(self, *args, **kwargs)
36
+ finally:
37
+ if locks[name] <= 1:
38
+ locks.pop(name, None)
39
+ else:
40
+ locks[name] -= 1
41
+
42
+ return wrapped
43
+
44
+ return decorator
45
+
46
+
47
+ class MutationLockCallback(Callback):
48
+ """Prevent runtime schema mutations while Lightning owns an active loop."""
49
+
50
+ locks: tuple[Strata, ...] = (Strata.train, Strata.validate, Strata.test, Strata.predict)
51
+
52
+ def _on_loop_start(self, trainer: lit.Trainer, pl_module: "Model", strata: Strata) -> None:
53
+ pl_module.locks[strata] += 1
54
+
55
+ def _on_loop_end(self, trainer: lit.Trainer, pl_module: "Model", strata: Strata) -> None:
56
+ locks = pl_module.locks
57
+ if locks[strata] <= 1:
58
+ locks.pop(strata, None)
59
+ else:
60
+ locks[strata] -= 1
61
+
62
+ def on_exception(
63
+ self,
64
+ trainer: lit.Trainer,
65
+ pl_module: "Model",
66
+ exception: BaseException,
67
+ ) -> None: # ty:ignore[invalid-method-override]
68
+ for lock in self.locks:
69
+ pl_module.locks.pop(lock, None)
70
+
71
+ on_train_start = partialmethod(_on_loop_start, strata=Strata.train)
72
+ on_train_end = partialmethod(_on_loop_end, strata=Strata.train)
73
+ on_validation_start = partialmethod(_on_loop_start, strata=Strata.validate)
74
+ on_validation_end = partialmethod(_on_loop_end, strata=Strata.validate)
75
+ on_test_start = partialmethod(_on_loop_start, strata=Strata.test)
76
+ on_test_end = partialmethod(_on_loop_end, strata=Strata.test)
77
+ on_predict_start = partialmethod(_on_loop_start, strata=Strata.predict)
78
+ on_predict_end = partialmethod(_on_loop_end, strata=Strata.predict)
79
+
80
+
81
+ class RuntimePlacementCallback(Callback):
82
+ """Move late-created modules onto the Lightning module's active device."""
83
+
84
+ def _on_loop_start(self, trainer: lit.Trainer, pl_module: lit.LightningModule, strata: Strata) -> None:
85
+ device = getattr(pl_module, "device", None)
86
+ if isinstance(device, torch.device):
87
+ pl_module.to(device=device)
88
+
89
+ on_train_start = partialmethod(_on_loop_start, strata=Strata.train)
90
+ on_validation_start = partialmethod(_on_loop_start, strata=Strata.validate)
91
+ on_test_start = partialmethod(_on_loop_start, strata=Strata.test)
92
+ on_predict_start = partialmethod(_on_loop_start, strata=Strata.predict)
93
+
94
+
23
95
  class AttributeChange(pydantic.BaseModel):
24
96
  model_config = pydantic.ConfigDict(arbitrary_types_allowed=True)
25
97
 
@@ -40,6 +112,12 @@ class SchemaEditor:
40
112
  def __init__(self, module: "Model") -> None:
41
113
  self.module = module
42
114
 
115
+ def _assert_mutation_allowed(self, action: str) -> None:
116
+ active = tuple(name for name, count in self.module.locks.items() if count > 0)
117
+ if active:
118
+ labels = ", ".join(active)
119
+ raise RuntimeError(f"model.{action}(...) cannot run while the model is in an active loop: {labels}")
120
+
43
121
  def select(
44
122
  self,
45
123
  *predicates: NodePredicate | NodeAttribute | Callable[[Node], bool],
@@ -62,7 +140,7 @@ class SchemaEditor:
62
140
  use_cache: bool = False,
63
141
  **values: Any,
64
142
  ) -> None:
65
- self.module._assert_mutation_allowed("update")
143
+ self._assert_mutation_allowed("update")
66
144
  values = self.module.hyperparameters.update_values(values)
67
145
  changes = self._attribute_changes(
68
146
  values=values,
@@ -90,7 +168,7 @@ class SchemaEditor:
90
168
  include_root: bool = True,
91
169
  use_cache: bool = True,
92
170
  ) -> None:
93
- self.module._assert_mutation_allowed("extend")
171
+ self._assert_mutation_allowed("extend")
94
172
  parent, field_count = self._extend_target(*args, include_root=include_root, use_cache=use_cache)
95
173
  self.module.hyperparameters.extend(*args, include_root=include_root, use_cache=use_cache)
96
174
  ModelGraph.rebuild(self.module)
@@ -109,7 +187,7 @@ class SchemaEditor:
109
187
  include_root: bool = False,
110
188
  use_cache: bool = True,
111
189
  ) -> None:
112
- self.module._assert_mutation_allowed("delete")
190
+ self._assert_mutation_allowed("delete")
113
191
  roots = self._delete_roots(*predicates, include_root=include_root, use_cache=use_cache)
114
192
  self.module.hyperparameters.delete(*predicates, include_root=include_root, use_cache=use_cache)
115
193
  ModelGraph.rebuild(self.module)
@@ -129,7 +207,7 @@ class SchemaEditor:
129
207
  use_cache: bool = True,
130
208
  descendants: bool = False,
131
209
  ) -> None:
132
- self.module._assert_mutation_allowed("reset")
210
+ self._assert_mutation_allowed("reset")
133
211
  selected = self.module.hyperparameters.select(
134
212
  *predicates,
135
213
  include_root=include_root,
@@ -160,7 +238,7 @@ class SchemaEditor:
160
238
  use_cache: bool = False,
161
239
  **values: Any,
162
240
  ) -> Iterator[None]:
163
- self.module._assert_mutation_allowed("override")
241
+ self._assert_mutation_allowed("override")
164
242
  values = self.module.hyperparameters.update_values(values)
165
243
  changes = self._attribute_changes(
166
244
  values=values,
@@ -3,7 +3,7 @@
3
3
  from collections import Counter
4
4
  from collections.abc import Callable, Iterator, Sequence
5
5
  from contextlib import contextmanager
6
- from functools import partialmethod, wraps
6
+ from functools import partialmethod
7
7
  from pathlib import Path
8
8
  from typing import Any, Self, cast
9
9
 
@@ -11,15 +11,19 @@ import lightning.pytorch as lit
11
11
  import torch
12
12
  from beartype import beartype
13
13
  from lightning.pytorch import Callback
14
- from lightning.pytorch.callbacks import ModelCheckpoint
15
14
  from loguru import logger
16
15
  from rich.text import Text
17
16
  from tensordict import TensorDict
18
17
 
19
- from json2vec.architecture.checkpoint import CheckpointState
18
+ from json2vec.architecture.checkpoint import CheckpointState, RollbackCheckpoint
20
19
  from json2vec.architecture.contracts import ContractScheduler
21
20
  from json2vec.architecture.graph import ModelGraph
22
- from json2vec.architecture.mutations import SchemaEditor
21
+ from json2vec.architecture.mutations import (
22
+ MutationLockCallback,
23
+ RuntimePlacementCallback,
24
+ SchemaEditor,
25
+ immutable,
26
+ )
23
27
  from json2vec.architecture.runtime import ModelRuntime, Postprocessor, Preprocessor, step
24
28
  from json2vec.data.datasets.base import EncodedBatch, EncodedInput
25
29
  from json2vec.logging.throughput import ThroughputLogger
@@ -37,105 +41,12 @@ from json2vec.tensorfields.base import TENSORFIELDS, Plugin, TensorFieldBase
37
41
  OptimizerConfig = torch.optim.Optimizer | Callable[["Model"], torch.optim.Optimizer]
38
42
  SchedulerConfig = Any | Callable[["Model", torch.optim.Optimizer], Any]
39
43
 
40
-
41
- def immutable(name: str | Strata) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
42
- def decorator(method: Callable[..., Any]) -> Callable[..., Any]:
43
- @wraps(method)
44
- def wrapped(self: Any, *args: Any, **kwargs: Any) -> Any:
45
- locks = self.locks
46
- locks[name] += 1
47
- try:
48
- return method(self, *args, **kwargs)
49
- finally:
50
- if locks[name] <= 1:
51
- locks.pop(name, None)
52
- else:
53
- locks[name] -= 1
54
-
55
- return wrapped
56
-
57
- return decorator
58
-
59
-
60
- class MutationLockCallback(Callback):
61
- """Prevent runtime schema mutations while Lightning owns an active loop."""
62
-
63
- locks: tuple[Strata, ...] = (Strata.train, Strata.validate, Strata.test, Strata.predict)
64
-
65
- def _on_loop_start(self, trainer: lit.Trainer, pl_module: "Model", strata: Strata) -> None:
66
- pl_module.locks[strata] += 1
67
-
68
- def _on_loop_end(self, trainer: lit.Trainer, pl_module: "Model", strata: Strata) -> None:
69
- locks = pl_module.locks
70
- if locks[strata] <= 1:
71
- locks.pop(strata, None)
72
- else:
73
- locks[strata] -= 1
74
-
75
- def on_exception(self, trainer: lit.Trainer, pl_module: "Model", exception: BaseException) -> None: # ty:ignore[invalid-method-override]
76
- for lock in self.locks:
77
- pl_module.locks.pop(lock, None)
78
-
79
- on_train_start = partialmethod(_on_loop_start, strata=Strata.train)
80
- on_train_end = partialmethod(_on_loop_end, strata=Strata.train)
81
- on_validation_start = partialmethod(_on_loop_start, strata=Strata.validate)
82
- on_validation_end = partialmethod(_on_loop_end, strata=Strata.validate)
83
- on_test_start = partialmethod(_on_loop_start, strata=Strata.test)
84
- on_test_end = partialmethod(_on_loop_end, strata=Strata.test)
85
- on_predict_start = partialmethod(_on_loop_start, strata=Strata.predict)
86
- on_predict_end = partialmethod(_on_loop_end, strata=Strata.predict)
87
-
88
-
89
- class RuntimePlacementCallback(Callback):
90
- """Move late-created modules onto the Lightning module's active device."""
91
-
92
- def _on_loop_start(self, trainer: lit.Trainer, pl_module: lit.LightningModule, strata: Strata) -> None:
93
- device = getattr(pl_module, "device", None)
94
- if isinstance(device, torch.device):
95
- pl_module.to(device=device)
96
-
97
- on_train_start = partialmethod(_on_loop_start, strata=Strata.train)
98
- on_validation_start = partialmethod(_on_loop_start, strata=Strata.validate)
99
- on_test_start = partialmethod(_on_loop_start, strata=Strata.test)
100
- on_predict_start = partialmethod(_on_loop_start, strata=Strata.predict)
101
-
102
-
103
- class RollbackCheckpoint(ModelCheckpoint):
104
- """Checkpoint the best model during fit and restore it into the module at fit end."""
105
-
106
- def __init__(self, *args: Any, **kwargs: Any) -> None:
107
- super().__init__(*args, **kwargs)
108
- if self.save_weights_only:
109
- raise ValueError("RollbackCheckpoint requires full checkpoints; set save_weights_only=False")
110
- if self.save_top_k == 0:
111
- raise ValueError("RollbackCheckpoint requires at least one saved checkpoint; set save_top_k != 0")
112
-
113
- def on_fit_end(self, trainer: lit.Trainer, pl_module: lit.LightningModule) -> None:
114
- super().on_fit_end(trainer=trainer, pl_module=pl_module)
115
- if not isinstance(pl_module, Model):
116
- raise TypeError("RollbackCheckpoint can only restore json2vec Model instances")
117
-
118
- best_model_path = self.best_model_path
119
- if not best_model_path:
120
- raise RuntimeError("RollbackCheckpoint did not find a best checkpoint to restore")
121
-
122
- strategy = getattr(trainer, "strategy", None)
123
- if strategy is not None:
124
- strategy.barrier("rollback_checkpoint_load")
125
- checkpoint = strategy.checkpoint_io.load_checkpoint(
126
- best_model_path,
127
- map_location=pl_module.device,
128
- weights_only=False,
129
- )
130
- else:
131
- checkpoint = torch.load(best_model_path, weights_only=False, map_location=pl_module.device)
132
-
133
- pl_module.restore_checkpoint_state(checkpoint)
134
- logger.bind(
135
- component="checkpoint",
136
- checkpoint=best_model_path,
137
- score=self.best_model_score,
138
- ).info("rolled back Model to best checkpoint")
44
+ __all__ = [
45
+ "Model",
46
+ "MutationLockCallback",
47
+ "RollbackCheckpoint",
48
+ "RuntimePlacementCallback",
49
+ ]
139
50
 
140
51
 
141
52
  class Model(lit.LightningModule, Renderable):
@@ -406,12 +317,6 @@ class Model(lit.LightningModule, Renderable):
406
317
  ):
407
318
  yield
408
319
 
409
- def _assert_mutation_allowed(self, action: str) -> None:
410
- active = tuple(name for name, count in self.locks.items() if count > 0)
411
- if active:
412
- labels = ", ".join(active)
413
- raise RuntimeError(f"model.{action}(...) cannot run while the model is in an active loop: {labels}")
414
-
415
320
  def configure_callbacks(self) -> list[Callback]:
416
321
  callbacks: list[Callback] = []
417
322
  factories: set[Any] = set()
@@ -538,6 +443,7 @@ class Model(lit.LightningModule, Renderable):
538
443
  batch: EncodedBatch | list[dict[str, Any]],
539
444
  preprocess: Preprocessor | None = None,
540
445
  strata: Strata | str = Strata.predict,
446
+ mask: bool = True,
541
447
  ) -> EncodedInput:
542
448
  """Return encoded tensorfield inputs for raw or processed observations."""
543
449
  return ModelRuntime.encode(
@@ -545,6 +451,7 @@ class Model(lit.LightningModule, Renderable):
545
451
  batch=batch,
546
452
  preprocess=preprocess,
547
453
  strata=strata,
454
+ mask=mask,
548
455
  )
549
456
 
550
457
  @immutable("inference")
@@ -14,8 +14,9 @@ from json2vec.architecture.contracts import sanitize
14
14
  from json2vec.architecture.encoder import ArrayEncoder
15
15
  from json2vec.architecture.node import NodeModule
16
16
  from json2vec.data.datasets.base import EncodedBatch, EncodedInput
17
- from json2vec.data.iterables import encode
18
- from json2vec.structs.enums import Metric, Strata, TensorKey
17
+ from json2vec.data.iterables import encode as encode_batch
18
+ from json2vec.data.iterables import mask as apply_mask
19
+ from json2vec.structs.enums import Metric, Strata, TensorKey, Tokens
19
20
  from json2vec.structs.packages import Parcel, Prediction
20
21
  from json2vec.structs.tree import Address
21
22
  from json2vec.tensorfields.base import (
@@ -99,8 +100,10 @@ class ModelRuntime:
99
100
  )
100
101
 
101
102
  for address in module.hyperparameters.active_requests.keys():
103
+ has_masked_input = inputs[address].state.eq(Tokens.masked.value).any()
102
104
  if (
103
105
  torch.any(inputs[address].trainable)
106
+ or (strata == Strata.predict and has_masked_input)
104
107
  or (address in module.hyperparameters.target)
105
108
  or (address in module.hyperparameters.embed)
106
109
  ):
@@ -193,6 +196,7 @@ class ModelRuntime:
193
196
  batch: EncodedBatch | list[dict[str, Any]],
194
197
  preprocess: Preprocessor | None = None,
195
198
  strata: Strata | str = Strata.predict,
199
+ mask: bool = True,
196
200
  ) -> EncodedInput:
197
201
  strata = Strata.normalize(strata)
198
202
 
@@ -209,12 +213,17 @@ class ModelRuntime:
209
213
  elif batch and isinstance(batch[0], dict):
210
214
  batch = [[request] for request in cast(list[dict[str, Any]], batch)]
211
215
 
212
- return encode(
216
+ inputs = encode_batch(
213
217
  batch=cast(EncodedBatch, batch),
214
218
  hyperparameters=module.hyperparameters,
215
219
  strata=strata,
216
220
  interprocess_encoding_context=module.interprocess_encoding_context,
221
+ defer_target_masking=True,
217
222
  )
223
+ if mask:
224
+ return next(apply_mask([inputs], module.hyperparameters, strata=strata))
225
+
226
+ return inputs
218
227
 
219
228
  @staticmethod
220
229
  def predict(
@@ -31,7 +31,6 @@ from json2vec.data.iterables import (
31
31
  process,
32
32
  sample,
33
33
  shuffle,
34
- target,
35
34
  transform,
36
35
  )
37
36
  from json2vec.data.processing import Pipeline
@@ -130,7 +129,6 @@ class CustomBatchDataset(IterableDataset):
130
129
  | batch
131
130
  | transform
132
131
  | mask
133
- | target
134
132
  )
135
133
 
136
134
 
@@ -35,7 +35,6 @@ from json2vec.data.iterables import (
35
35
  process,
36
36
  sample,
37
37
  shuffle,
38
- target,
39
38
  transform,
40
39
  )
41
40
  from json2vec.data.processing import Pipeline
@@ -183,7 +182,6 @@ class PolarsBatchDataset(IterableDataset):
183
182
  | batch
184
183
  | transform
185
184
  | mask
186
- | target
187
185
  )
188
186
 
189
187
 
@@ -40,7 +40,6 @@ from json2vec.data.iterables import (
40
40
  process,
41
41
  sample,
42
42
  shuffle,
43
- target,
44
43
  transform,
45
44
  )
46
45
  from json2vec.data.processing import Pipeline
@@ -323,7 +322,6 @@ class BatchDataset(IterableDataset):
323
322
  | batch
324
323
  | transform
325
324
  | mask
326
- | target
327
325
  )
328
326
 
329
327
 
@@ -23,6 +23,7 @@ from json2vec.data.datasets.base import (
23
23
  ProcessedObservation,
24
24
  RawObservation,
25
25
  )
26
+ from json2vec.data.processing import MASK_LITERAL, contains_mask_literal
26
27
  from json2vec.preprocessors.base import PREPROCESSORS, Preprocessor, PreprocessorMode
27
28
  from json2vec.structs.enums import Strata, TensorKey
28
29
  from json2vec.structs.experiment import Hyperparameters
@@ -224,10 +225,14 @@ def encode(
224
225
  strata: Strata,
225
226
  interprocess_encoding_context: InterprocessEncodingContext,
226
227
  jmespath_resolution_monitor: JMESPathResolutionMonitor | None = None,
228
+ defer_target_masking: bool = False,
227
229
  ) -> EncodedInput:
228
230
  out: dict[Address, TensorFieldBase] = {}
229
231
  target_addresses = set(hyperparameters.target)
230
232
 
233
+ if strata != Strata.predict and contains_mask_literal(batch):
234
+ raise ValueError(f"{MASK_LITERAL!r} is only valid during predict strata")
235
+
231
236
  for address, request in hyperparameters.active_requests.items():
232
237
  TensorField = cast(type[TensorFieldBase], getattr(TENSORFIELDS[request.type], "TensorField"))
233
238
 
@@ -262,8 +267,8 @@ def encode(
262
267
 
263
268
  out[address] = TensorField.new(**kwargs)
264
269
 
265
- if address in target_addresses:
266
- out[address].target(p_prune=1.0)
270
+ if not defer_target_masking and strata != Strata.predict and address in target_addresses:
271
+ out[address].mask(p_prune=1.0)
267
272
 
268
273
  inputs = cast(EncodedInput, TensorDict(source=cast(Any, out)))
269
274
 
@@ -288,21 +293,69 @@ def transform(
288
293
  strata=strata,
289
294
  interprocess_encoding_context=interprocess_encoding_context,
290
295
  jmespath_resolution_monitor=jmespath_resolution_monitor,
296
+ defer_target_masking=True,
291
297
  )
292
298
 
293
299
 
300
+ def _apply_mask_policy(
301
+ field: TensorFieldBase,
302
+ *,
303
+ p_mask: float,
304
+ p_prune: float,
305
+ array_masks: tuple[Any, ...],
306
+ address: Address,
307
+ hyperparameters: Hyperparameters,
308
+ ) -> None:
309
+ parameters = inspect.signature(field.mask).parameters
310
+ supports_policy_kwargs = any(parameter.kind == inspect.Parameter.VAR_KEYWORD for parameter in parameters.values())
311
+ supports_policy_kwargs |= any(name in parameters for name in ("p_prune", "array_masks", "hyperparameters"))
312
+
313
+ if supports_policy_kwargs:
314
+ field.mask(
315
+ p_mask=p_mask,
316
+ p_prune=p_prune,
317
+ array_masks=array_masks,
318
+ address=address,
319
+ hyperparameters=hyperparameters,
320
+ )
321
+ return
322
+
323
+ if array_masks:
324
+ raise TypeError(f"tensorfield at '{address}' must accept mask(..., array_masks=...) to use Array masks")
325
+
326
+ if p_mask > 0.0:
327
+ field.mask(p_mask=p_mask)
328
+
329
+ if p_prune > 0.0:
330
+ field.target(p_prune=p_prune)
331
+
332
+
294
333
  @beartype
295
334
  def mask(
296
335
  pipe: Iterable[EncodedInput],
297
336
  hyperparameters: Hyperparameters,
337
+ strata: Strata = Strata.train,
298
338
  ) -> Iterator[EncodedInput]:
299
339
  for item in pipe:
340
+ if strata == Strata.predict:
341
+ yield item
342
+ continue
343
+
300
344
  for address, request in hyperparameters.active_requests.items():
301
345
  p_mask = float(request.p_mask or 0.0)
302
- if p_mask <= 0.0:
346
+ p_prune = float(request.p_prune or 0.0)
347
+ array_masks = hyperparameters.array_masks_for(address)
348
+ if p_mask <= 0.0 and p_prune <= 0.0 and not array_masks:
303
349
  continue
304
350
 
305
- item[address].mask(p_mask=p_mask)
351
+ _apply_mask_policy(
352
+ item[address],
353
+ p_mask=p_mask,
354
+ p_prune=p_prune,
355
+ array_masks=array_masks,
356
+ address=address,
357
+ hyperparameters=hyperparameters,
358
+ )
306
359
 
307
360
  yield item
308
361