json2vec 0.4.3__tar.gz → 0.4.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {json2vec-0.4.3/src/json2vec.egg-info → json2vec-0.4.4}/PKG-INFO +1 -1
- {json2vec-0.4.3 → json2vec-0.4.4}/pyproject.toml +19 -1
- json2vec-0.4.4/src/json2vec/architecture/checkpoint.py +69 -0
- json2vec-0.4.4/src/json2vec/architecture/contracts.py +466 -0
- json2vec-0.4.4/src/json2vec/architecture/graph.py +100 -0
- {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/architecture/plot.py +1 -1
- {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/architecture/root.py +57 -313
- json2vec-0.4.4/src/json2vec/architecture/runtime.py +241 -0
- json2vec-0.4.4/src/json2vec/architecture/schema_editor.py +126 -0
- {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/data/datasets/streaming.py +5 -2
- {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/data/iterables.py +2 -2
- {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/inference/callback.py +4 -3
- {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/inference/deployment.py +2 -2
- {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/structs/enums.py +3 -0
- {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/structs/experiment.py +27 -232
- json2vec-0.4.4/src/json2vec/structs/selectors.py +236 -0
- {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/tensorfields/extensions/category.py +4 -1
- {json2vec-0.4.3 → json2vec-0.4.4/src/json2vec.egg-info}/PKG-INFO +1 -1
- {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec.egg-info/SOURCES.txt +6 -0
- {json2vec-0.4.3 → json2vec-0.4.4}/LICENSE +0 -0
- {json2vec-0.4.3 → json2vec-0.4.4}/NOTICE +0 -0
- {json2vec-0.4.3 → json2vec-0.4.4}/README.md +0 -0
- {json2vec-0.4.3 → json2vec-0.4.4}/setup.cfg +0 -0
- {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/__init__.py +0 -0
- {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/architecture/__init__.py +0 -0
- {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/architecture/attention.py +0 -0
- {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/architecture/encoder.py +0 -0
- {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/architecture/node.py +0 -0
- {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/architecture/pool.py +0 -0
- {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/architecture/rotary.py +0 -0
- {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/data/__init__.py +0 -0
- {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/data/datasets/__init__.py +0 -0
- {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/data/datasets/base.py +0 -0
- {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/data/datasets/polars.py +0 -0
- {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/data/processing.py +0 -0
- {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/distributed.py +0 -0
- {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/inference/__init__.py +0 -0
- {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/logging/__init__.py +0 -0
- {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/logging/config.py +0 -0
- {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/logging/epoch.py +0 -0
- {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/logging/throughput.py +0 -0
- {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/preprocessors/__init__.py +0 -0
- {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/preprocessors/base.py +0 -0
- {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/preprocessors/extensions/__init__.py +0 -0
- {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/preprocessors/spec.py +0 -0
- {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/structs/__init__.py +0 -0
- {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/structs/packages.py +0 -0
- {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/structs/structure.py +0 -0
- {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/structs/tree.py +0 -0
- {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/tensorfields/__init__.py +0 -0
- {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/tensorfields/base.py +0 -0
- {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/tensorfields/extensions/__init__.py +0 -0
- {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/tensorfields/extensions/dateparts.py +0 -0
- {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/tensorfields/extensions/entity.py +0 -0
- {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/tensorfields/extensions/number.py +0 -0
- {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/tensorfields/extensions/set.py +0 -0
- {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/tensorfields/extensions/text.py +0 -0
- {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/tensorfields/extensions/vector.py +0 -0
- {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/tensorfields/shared/__init__.py +0 -0
- {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/tensorfields/shared/counter.py +0 -0
- {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/tensorfields/shared/vocabulary.py +0 -0
- {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/tensorfields/spec.py +0 -0
- {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec.egg-info/dependency_links.txt +0 -0
- {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec.egg-info/requires.txt +0 -0
- {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec.egg-info/top_level.txt +0 -0
- {json2vec-0.4.3 → json2vec-0.4.4}/tests/test_callbacks.py +0 -0
- {json2vec-0.4.3 → json2vec-0.4.4}/tests/test_public_api.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
[project]
|
|
2
2
|
name = "json2vec"
|
|
3
|
-
version = "0.4.3"
|
|
3
|
+
version = "0.4.4"
|
|
4
4
|
description = "{...} -> [*]"
|
|
5
5
|
readme = "README.md"
|
|
6
6
|
license = "Apache-2.0"
|
|
@@ -68,3 +68,21 @@ python_files = ["test_*.py"]
|
|
|
68
68
|
[tool.ruff]
|
|
69
69
|
line-length = 120
|
|
70
70
|
lint.extend-select = ["I"]
|
|
71
|
+
|
|
72
|
+
[tool.ty.rules]
|
|
73
|
+
# Keep ty as a green baseline while the dynamic plugin/Torch/Lightning surfaces
|
|
74
|
+
# are covered by runtime tests and can be tightened incrementally.
|
|
75
|
+
unresolved-attribute = "ignore"
|
|
76
|
+
invalid-type-form = "ignore"
|
|
77
|
+
invalid-argument-type = "ignore"
|
|
78
|
+
invalid-assignment = "ignore"
|
|
79
|
+
unknown-argument = "ignore"
|
|
80
|
+
invalid-method-override = "ignore"
|
|
81
|
+
call-non-callable = "ignore"
|
|
82
|
+
invalid-return-type = "ignore"
|
|
83
|
+
not-subscriptable = "ignore"
|
|
84
|
+
unsupported-operator = "ignore"
|
|
85
|
+
no-matching-overload = "ignore"
|
|
86
|
+
invalid-attribute-override = "ignore"
|
|
87
|
+
redundant-cast = "ignore"
|
|
88
|
+
unused-ignore-comment = "ignore"
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
"""Checkpoint serialization helpers for JSON2Vec models."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import TYPE_CHECKING, Any
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
from loguru import logger
|
|
10
|
+
|
|
11
|
+
from json2vec.architecture.graph import ModelGraph
|
|
12
|
+
from json2vec.structs.experiment import Hyperparameters
|
|
13
|
+
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
from json2vec.architecture.root import Model
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class CheckpointState:
    """Save, load, and restore model state without owning the public facade.

    A checkpoint is a plain dictionary holding the module ``state_dict`` together
    with the serialized hyperparameters and the batch size.
    """

    # Keys that must be present before a checkpoint can be restored or loaded.
    required_fields = {"state_dict", "hyperparameters", "batch_size"}

    @staticmethod
    def dump(module: "Model", checkpoint: dict[str, Any]) -> None:
        """Write the module's hyperparameters and batch size into ``checkpoint`` in place."""
        checkpoint["hyperparameters"] = module.hyperparameters.model_dump(mode="python")
        checkpoint["batch_size"] = module.batch_size

    @staticmethod
    def save(module: "Model", pathname: str | Path) -> None:
        """Serialize the module to ``pathname``, creating parent directories as needed."""
        path = Path(pathname)
        path.parent.mkdir(parents=True, exist_ok=True)

        checkpoint: dict[str, Any] = {"state_dict": module.state_dict()}
        CheckpointState.dump(module, checkpoint)
        torch.save(checkpoint, path)

    @staticmethod
    def restore(module: "Model", checkpoint: dict[str, Any]) -> None:
        """Rebuild ``module`` in place from an already-loaded checkpoint dictionary.

        Raises:
            ValueError: if any of ``required_fields`` is absent from ``checkpoint``.
        """
        missing = CheckpointState.required_fields - set(checkpoint)
        if missing:
            fields = ", ".join(sorted(missing))
            raise ValueError(f"missing checkpoint fields: {fields}")

        # Capture device and train/eval mode before the graph is replaced so both can be re-applied.
        device = module.device
        was_training = module.training
        module.hyperparameters = Hyperparameters.model_validate(checkpoint["hyperparameters"])
        module.batch_size = checkpoint["batch_size"]
        ModelGraph.install(module)
        if isinstance(device, torch.device):
            module.to(device=device)
        module.load_state_dict(state_dict=checkpoint["state_dict"])
        module.train(was_training)

    @staticmethod
    def load(model_cls: type["Model"], checkpoint: str | Path) -> "Model":
        """Construct a ``model_cls`` instance from the checkpoint file at ``checkpoint``.

        Raises:
            ValueError: if the checkpoint lacks hyperparameters or any other required field.
        """
        path = Path(checkpoint)
        log = logger.bind(component="model_factory", checkpoint=str(path))
        log.info("loading Model from checkpoint")
        # NOTE(security): weights_only=False unpickles arbitrary Python objects, so only
        # checkpoints from trusted sources should be loaded through this path.
        state = torch.load(path, weights_only=False, map_location="cpu")
        if "hyperparameters" not in state:
            raise ValueError("missing hyperparameters in checkpoint")

        # Validate the remaining required fields before building the model, so an incomplete
        # checkpoint fails with the same ValueError as restore() instead of a bare KeyError.
        missing = CheckpointState.required_fields - set(state)
        if missing:
            fields = ", ".join(sorted(missing))
            raise ValueError(f"missing checkpoint fields: {fields}")

        model = model_cls(
            hyperparameters=Hyperparameters.model_validate(state["hyperparameters"]),
            batch_size=state["batch_size"],
        )
        model.restore_checkpoint_state(state)
        log.info("restored model state from checkpoint")

        return model
|
|
@@ -0,0 +1,466 @@
|
|
|
1
|
+
"""Generic runtime contracts for model forward inputs."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from collections.abc import Iterator, Mapping
|
|
6
|
+
from dataclasses import dataclass, field
|
|
7
|
+
from typing import TYPE_CHECKING, Any
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
from tensordict import TensorDict
|
|
11
|
+
|
|
12
|
+
from json2vec.structs.enums import Strata, TensorKey, Tokens
|
|
13
|
+
from json2vec.structs.tree import Address
|
|
14
|
+
from json2vec.tensorfields.base import TENSORFIELDS, TensorFieldBase
|
|
15
|
+
|
|
16
|
+
if TYPE_CHECKING:
|
|
17
|
+
from json2vec.architecture.root import Model
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class ForwardContractError(ValueError):
    """Raised when a forward batch violates a model input contract.

    Subclasses ``ValueError``, so callers that catch ``ValueError`` also catch it.
    """
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
# Torch dtypes accepted wherever the contract checks require an integer tensor.
INTEGER_DTYPES = {
    torch.uint8,
    torch.int8,
    torch.int16,
    torch.int32,
    torch.int64,
}


# Hashable structural description of a batch, as produced by batch_signature().
ContractSignature = tuple[Any, ...]
# Scheduler key: (strata name, dataloader index, contract generation, batch signature).
ContractScope = tuple[str, int, int, ContractSignature]
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
@dataclass
class ContractScheduler:
    """Decide, per batch scope, whether the expensive forward contract checks should run.

    Each distinct scope keeps its own call counter; the counter value is handed to
    ``is_backoff_index`` to determine whether this call is due for a check.
    """

    periodic_interval: int = 1024
    _counts: dict[ContractScope, int] = field(default_factory=dict)

    def reset(self) -> None:
        """Forget every per-scope counter."""
        self._counts.clear()

    def should_check(
        self,
        module: "Model",
        inputs: Any,
        *,
        strata: Strata,
        dataloader_idx: int,
    ) -> bool:
        """Record one call for the batch's scope and report whether it must be validated."""
        generation = int(getattr(module, "_contract_generation", 0))
        strata_name = str(strata)
        signature = batch_signature(module, inputs)
        key = (strata_name, dataloader_idx, generation, signature)

        # Calls seen so far for this scope, not counting the current one.
        seen = self._counts.get(key, 0)
        self._counts[key] = seen + 1
        return is_backoff_index(seen, periodic_interval=self.periodic_interval)
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def sanitize(
    module: "Model",
    inputs: TensorDict[Address, TensorFieldBase],
    *,
    strata: Strata | str,
    dataloader_idx: int = 0,
) -> None:
    """Check a forward batch against the generic input contract before the model runs.

    When the module carries a ``ContractScheduler`` in ``_contract_scheduler`` and the
    scheduler says this batch is not due, the function returns without validating.

    Raises:
        TypeError: if ``inputs`` is not a ``TensorDict`` (or a per-field check raises it).
        ForwardContractError: if any per-field contract check fails.
    """
    normalized = Strata.normalize(strata)

    scheduler = getattr(module, "_contract_scheduler", None)
    if isinstance(scheduler, ContractScheduler):
        due = scheduler.should_check(
            module,
            inputs,
            strata=normalized,
            dataloader_idx=dataloader_idx,
        )
        if not due:
            return

    if not isinstance(inputs, TensorDict):
        raise TypeError(f"forward inputs must be a TensorDict, got {type(inputs).__name__}")

    require_forward_addresses(module, inputs, strata=normalized)

    for request_address in module.hyperparameters.active_requests:
        candidate = inputs[request_address]
        require_registered_tensorfield(module, request_address, candidate)
        require_core_tensors(module, request_address, candidate)
        require_tensor_devices(module, request_address, candidate)
        require_target_contract(module, request_address, candidate, strata=normalized)
        require_mask_contract(module, request_address, candidate)
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def is_backoff_index(index: int, *, periodic_interval: int) -> bool:
    """Return whether the zero-based call ``index`` is one on which a check should run.

    A check is due on index 0, on every index with at most one bit set, and on every
    multiple of ``periodic_interval`` when that interval is positive.
    """
    single_bit = (index & (index - 1)) == 0
    if index == 0 or single_bit:
        return True

    if periodic_interval <= 0:
        return False

    return index % periodic_interval == 0
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
def batch_signature(module: "Model", inputs: Any) -> ContractSignature:
    """Summarize a batch as a hashable tuple of keys, field classes, and tensor signatures.

    Non-``TensorDict`` inputs are summarized by their qualified type name only. Active
    request addresses that are absent from ``inputs`` are recorded as ``"missing"``.
    """
    if not isinstance(inputs, TensorDict):
        return ("inputs", qualified_name(type(inputs)))

    input_keys = tuple(sorted(map(str, inputs.keys())))

    def describe(address: Any) -> tuple[Any, ...]:
        # One entry per active request: either a "missing" marker or the field's signature.
        if address not in inputs.keys():
            return (str(address), "missing")

        tensorfield = inputs[address]
        return (
            str(address),
            qualified_name(type(tensorfield)),
            tensor_signature(getattr(tensorfield, TensorKey.state, None)),
            tensor_signature(getattr(tensorfield, TensorKey.trainable, None)),
            tensor_tree_signature(getattr(tensorfield, TensorKey.content, None)),
            tensor_tree_signature(getattr(tensorfield, TensorKey.targets, None)),
        )

    ordered = sorted(module.hyperparameters.active_requests, key=str)
    return (input_keys, tuple(describe(address) for address in ordered))
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def tensor_signature(value: Any) -> tuple[Any, ...]:
    """Describe ``value`` by shape, dtype, and device if it is a tensor, else by its type name."""
    if torch.is_tensor(value):
        shape = tuple(value.shape)
        return ("tensor", shape, str(value.dtype), str(value.device))

    return ("object", qualified_name(type(value)))
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def tensor_tree_signature(value: Any) -> tuple[Any, ...]:
    """Recursively describe a tensor, ``TensorDict``, or mapping as a hashable tuple.

    Children are ordered by the string form of their key; anything else is summarized
    by its qualified type name.
    """
    if torch.is_tensor(value):
        return tensor_signature(value)

    if isinstance(value, TensorDict):
        entries = [(str(key), tensor_tree_signature(value[key])) for key in sorted(value.keys(), key=str)]
        return ("tensordict", tuple(entries))

    if isinstance(value, Mapping):
        ordered = sorted(value.items(), key=lambda pair: str(pair[0]))
        return ("mapping", tuple((str(key), tensor_tree_signature(item)) for key, item in ordered))

    return ("object", qualified_name(type(value)))
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
def require_forward_addresses(
    module: "Model",
    inputs: TensorDict[Address, TensorFieldBase],
    *,
    strata: Strata,
) -> None:
    """Require the batch keys to be exactly the module's active request addresses.

    The metadata key is tolerated only in the predict strata. Extra addresses are
    reported as array addresses, inactive requests, or unknown addresses, in that order.

    Raises:
        ForwardContractError: on metadata outside predict, or missing/extra addresses.
    """
    keys = set(inputs.keys())
    has_metadata = any(key == TensorKey.metadata for key in keys)
    addresses = {Address(str(key)) for key in keys if key != TensorKey.metadata}
    expected = set(module.hyperparameters.active_requests)

    if has_metadata and strata != Strata.predict:
        raise ForwardContractError(f"forward input contains {TensorKey.metadata} outside predict strata")

    absent = expected - addresses
    if absent:
        raise ForwardContractError(f"forward input is missing active request address(es): {format_addresses(absent)}")

    unexpected = addresses - expected
    if not unexpected:
        return

    array_hits = unexpected.intersection(module.hyperparameters.arrays)
    if array_hits:
        raise ForwardContractError(
            "forward input contains array address(es); only active leaf request addresses are allowed: "
            f"{format_addresses(array_hits)}"
        )

    known_requests = module.hyperparameters.requests
    inactive = {address for address in unexpected if address in known_requests}
    if inactive:
        raise ForwardContractError(
            "forward input contains inactive request address(es): "
            f"{format_addresses(inactive)}. Inactive fields remain in the schema but must not be present in runtime input."
        )

    raise ForwardContractError(f"forward input contains unknown address(es): {format_addresses(unexpected)}")
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
def require_registered_tensorfield(module: "Model", address: Address, value: Any) -> None:
    """Require ``value`` to be an instance of the tensorfield class registered for the request's type.

    Raises:
        TypeError: if ``value`` is not a ``TensorFieldBase`` or is of the wrong registered class.
    """
    if not isinstance(value, TensorFieldBase):
        raise TypeError(f"forward input '{address}' must be a TensorFieldBase, got {type(value).__name__}")

    request_type = module.hyperparameters.requests[address].type
    expected = TENSORFIELDS[request_type].TensorField
    if isinstance(value, expected):
        return

    raise TypeError(
        f"forward input '{address}' must use tensorfield class {qualified_name(expected)}, "
        f"got {qualified_name(type(value))}"
    )
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
def require_core_tensors(module: "Model", address: Address, tensorfield: TensorFieldBase) -> None:
    """Validate the state, trainable, content, and targets tensors of one field.

    Checks run in order: presence/type of each attribute, state rank and shape against
    ``module.hyperparameters.shapes[address]``, state dtype, trainable shape and dtype,
    token values, content prefix shapes, and finally any state/content targets.

    Raises:
        TypeError: on a wrong attribute type or dtype.
        ForwardContractError: on a rank, shape, key, or token-value mismatch.
    """
    state = require_tensor_attribute(address, tensorfield, TensorKey.state)
    trainable = require_tensor_attribute(address, tensorfield, TensorKey.trainable)
    content = require_tensor_tree(
        address,
        TensorKey.content,
        getattr(tensorfield, TensorKey.content, None),
    )
    targets = require_targets(address, tensorfield)

    # State carries one extra leading dimension in front of the configured field shape.
    field_shape = module.hyperparameters.shapes[address]
    if state.ndim != len(field_shape) + 1:
        raise ForwardContractError(
            f"forward input '{address}' state must have rank {len(field_shape) + 1}, got {state.ndim}"
        )

    # The leading dimension is taken from the tensor itself; only the trailing dims are constrained.
    expected_shape = (state.shape[0], *field_shape)
    if tuple(state.shape) != expected_shape:
        raise ForwardContractError(
            f"forward input '{address}' state must have shape {expected_shape}, got {tuple(state.shape)}"
        )

    if state.dtype not in INTEGER_DTYPES:
        raise TypeError(f"forward input '{address}' state must use an integer dtype, got {state.dtype}")

    if tuple(trainable.shape) != tuple(state.shape):
        raise ForwardContractError(
            f"forward input '{address}' trainable must have shape {tuple(state.shape)}, got {tuple(trainable.shape)}"
        )

    if trainable.dtype != torch.bool:
        raise TypeError(f"forward input '{address}' trainable must use bool dtype, got {trainable.dtype}")

    require_token_values(address, TensorKey.state, state)
    require_content_prefix_shapes(address, content, state)

    # Optional state targets must be a single integer tensor shaped like state.
    if TensorKey.state in targets.keys():
        target_state_name = f"{TensorKey.targets}[{TensorKey.state}]"
        target_state = require_tensor_tree(address, target_state_name, targets[TensorKey.state])
        require_matching_tree_shapes(
            address,
            actual_name=target_state_name,
            actual=target_state,
            expected_name=TensorKey.state,
            # The empty path key means the target must be a bare tensor, not a nested tree.
            expected={(): state},
        )
        require_integer_tensors(address, target_state_name, target_state)
        require_token_values(address, target_state_name, targets[TensorKey.state])

    # Optional content targets must mirror the content tree's keys and shapes.
    if TensorKey.content in targets.keys():
        target_content_name = f"{TensorKey.targets}[{TensorKey.content}]"
        target_content = require_tensor_tree(address, target_content_name, targets[TensorKey.content])
        require_matching_tree_shapes(
            address,
            actual_name=target_content_name,
            actual=target_content,
            expected_name=TensorKey.content,
            expected=content,
        )
|
|
279
|
+
|
|
280
|
+
|
|
281
|
+
def require_tensor_devices(module: "Model", address: Address, tensorfield: TensorFieldBase) -> None:
    """Require every tensor leaf of the field to share one device, matching the module's device.

    The module comparison is skipped when ``module.device`` is not a ``torch.device``
    or the field has no tensor leaves.

    Raises:
        ForwardContractError: if the leaves span several devices or differ from the module device.
    """
    devices = {leaf.device for _, leaf in iter_tensor_leaves(tensorfield)}
    if len(devices) > 1:
        formatted = ", ".join(sorted(map(str, devices)))
        raise ForwardContractError(f"forward input '{address}' tensors must share one device, got {formatted}")

    module_device = getattr(module, "device", None)
    if not isinstance(module_device, torch.device) or not devices:
        return

    field_device = next(iter(devices))
    if field_device != module_device:
        raise ForwardContractError(
            f"forward input '{address}' tensors must be on module device {module_device}, got {field_device}"
        )
|
|
293
|
+
|
|
294
|
+
|
|
295
|
+
def require_mask_contract(module: "Model", address: Address, tensorfield: TensorFieldBase) -> None:
    """Check that masked state, trainable flags, and targets agree for one field.

    Trainable positions must hold the masked token; for non-target fields, masked
    positions must be trainable; and a field with trainable positions must carry
    unmasked state targets as well as content targets.

    Raises:
        ForwardContractError: if any of those conditions is violated.
    """
    masked_token = Tokens.masked.value
    state = tensorfield.state
    trainable = tensorfield.trainable
    masked_positions = state.eq(masked_token)
    is_target = address in module.hyperparameters.target

    if trainable.any():
        if not state.masked_select(trainable).eq(masked_token).all():
            raise ForwardContractError(f"forward input '{address}' trainable positions must have masked state")

    if not is_target:
        if (masked_positions & ~trainable).any():
            raise ForwardContractError(f"forward input '{address}' has masked state where trainable is false")

    if not trainable.any():
        return

    targets = tensorfield.targets
    present = targets.keys()
    for required in (TensorKey.state, TensorKey.content):
        if required not in present:
            raise ForwardContractError(
                f"forward input '{address}' has trainable positions but lacks targets[{required}]"
            )

    supervised_state = targets[TensorKey.state].masked_select(trainable)
    if supervised_state.eq(masked_token).any():
        raise ForwardContractError(f"forward input '{address}' targets[{TensorKey.state}] must not be masked")
|
|
318
|
+
|
|
319
|
+
|
|
320
|
+
def require_target_contract(
    module: "Model",
    address: Address,
    tensorfield: TensorFieldBase,
    *,
    strata: Strata | None,
) -> None:
    """Enforce the extra rules for addresses listed in ``module.hyperparameters.target``.

    A target field's state must be masked everywhere, and in the train, validate, and
    test strata it must have at least one trainable position. Other addresses are ignored.

    Raises:
        ForwardContractError: if either rule is violated.
    """
    if address not in module.hyperparameters.target:
        return

    fully_masked = tensorfield.state.eq(Tokens.masked.value).all()
    if not fully_masked:
        raise ForwardContractError(f"target field '{address}' must not contain visible input state")

    needs_supervision = strata in (Strata.train, Strata.validate, Strata.test)
    if needs_supervision and not tensorfield.trainable.any():
        raise ForwardContractError(f"target field '{address}' must have trainable positions in {strata} strata")
|
|
335
|
+
|
|
336
|
+
|
|
337
|
+
def require_tensor_attribute(address: Address, tensorfield: TensorFieldBase, name: str) -> torch.Tensor:
    """Return attribute ``name`` of ``tensorfield``, requiring it to be a tensor.

    Raises:
        TypeError: if the attribute is missing or is not a ``torch.Tensor``.
    """
    candidate = getattr(tensorfield, name, None)
    if torch.is_tensor(candidate):
        return candidate

    raise TypeError(f"forward input '{address}' {name} must be a torch.Tensor, got {type(candidate).__name__}")
|
|
343
|
+
|
|
344
|
+
|
|
345
|
+
def require_targets(address: Address, tensorfield: TensorFieldBase) -> TensorDict:
    """Return the field's targets, requiring a ``TensorDict`` whose leaves are all tensors.

    An empty targets container is accepted.

    Raises:
        TypeError: if the targets attribute is not a ``TensorDict`` or holds a non-tensor leaf.
    """
    targets = getattr(tensorfield, TensorKey.targets, None)
    if isinstance(targets, TensorDict):
        # Walk the tree purely for validation; the collected leaves are discarded.
        require_tensor_tree(address, TensorKey.targets, targets, allow_empty=True)
        return targets

    raise TypeError(
        f"forward input '{address}' {TensorKey.targets} must be a TensorDict, got {type(targets).__name__}"
    )
|
|
354
|
+
|
|
355
|
+
|
|
356
|
+
def require_tensor_tree(
    address: Address,
    name: str,
    value: Any,
    *,
    allow_empty: bool = False,
) -> dict[tuple[str, ...], torch.Tensor]:
    """Flatten ``value`` into a ``{path: tensor}`` dictionary of its tensor leaves.

    Raises:
        TypeError: if the tree has no tensors and ``allow_empty`` is false, or if the
            leaf iteration encounters a value that is not part of a tensor tree.
    """
    leaves = {path: tensor for path, tensor in iter_tensor_leaves(value)}
    if allow_empty or leaves:
        return leaves

    raise TypeError(f"forward input '{address}' {name} must contain at least one tensor")
|
|
368
|
+
|
|
369
|
+
|
|
370
|
+
def require_matching_tree_shapes(
    address: Address,
    *,
    actual_name: str,
    actual: dict[tuple[str, ...], torch.Tensor],
    expected_name: str,
    expected: dict[tuple[str, ...], torch.Tensor],
) -> None:
    """Require two flattened tensor trees to have the same paths and per-path shapes.

    Raises:
        ForwardContractError: if the path sets differ or any tensor shape differs.
    """
    if set(actual) != set(expected):
        raise ForwardContractError(
            f"forward input '{address}' {actual_name} keys must match {expected_name} keys: "
            f"expected {format_paths(expected)}, got {format_paths(actual)}"
        )

    for path, candidate in actual.items():
        got = tuple(candidate.shape)
        wanted = tuple(expected[path].shape)
        if got == wanted:
            continue

        raise ForwardContractError(
            f"forward input '{address}' {actual_name}{format_path(path)} must have shape {wanted}, got {got}"
        )
|
|
392
|
+
|
|
393
|
+
|
|
394
|
+
def require_content_prefix_shapes(
    address: Address,
    content: dict[tuple[str, ...], torch.Tensor],
    state: torch.Tensor,
) -> None:
    """Require every content tensor's leading dimensions to equal the full state shape.

    Raises:
        ForwardContractError: if a content tensor has too few dimensions or a different prefix.
    """
    prefix = tuple(state.shape)
    rank = len(prefix)
    for path, tensor in content.items():
        # A tensor with fewer than `rank` dims yields a shorter slice and so fails the comparison too.
        if tuple(tensor.shape[:rank]) == prefix:
            continue

        raise ForwardContractError(
            f"forward input '{address}' {TensorKey.content}{format_path(path)} must start with {TensorKey.state} "
            f"shape {prefix}, "
            f"got {tuple(tensor.shape)}"
        )
|
|
409
|
+
|
|
410
|
+
|
|
411
|
+
def require_integer_tensors(
    address: Address,
    name: str,
    tensors: dict[tuple[str, ...], torch.Tensor],
) -> None:
    """Require every tensor in the flattened tree to use a dtype from ``INTEGER_DTYPES``.

    Raises:
        TypeError: for the first tensor, in iteration order, whose dtype is not accepted.
    """
    offender = next(
        ((path, tensor) for path, tensor in tensors.items() if tensor.dtype not in INTEGER_DTYPES),
        None,
    )
    if offender is None:
        return

    path, tensor = offender
    raise TypeError(
        f"forward input '{address}' {name}{format_path(path)} must use an integer dtype, got {tensor.dtype}"
    )
|
|
420
|
+
|
|
421
|
+
|
|
422
|
+
def require_token_values(address: Address, name: str, values: torch.Tensor) -> None:
    """Require every element of ``values`` to equal the value of some ``Tokens`` member.

    Raises:
        ForwardContractError: naming the first offending value in flattened order.
    """
    known = torch.tensor([token.value for token in Tokens], device=values.device, dtype=values.dtype)
    recognized = torch.isin(values, known)
    if recognized.all():
        return

    first_bad = values.masked_select(~recognized).reshape(-1)[0].item()
    raise ForwardContractError(f"forward input '{address}' {name} contains invalid token id {first_bad}")
|
|
428
|
+
|
|
429
|
+
|
|
430
|
+
def iter_tensor_leaves(value: Any, path: tuple[str, ...] = ()) -> Iterator[tuple[tuple[str, ...], torch.Tensor]]:
    """Yield ``(path, tensor)`` for every tensor reachable from ``value``.

    Tensor fields are descended through their state, trainable, content, and targets
    attributes; ``TensorDict`` and mapping values are descended through their keys.

    Raises:
        TypeError: when a value that is none of the supported kinds is encountered.
    """
    if torch.is_tensor(value):
        yield path, value
    elif isinstance(value, TensorFieldBase):
        for name in (TensorKey.state, TensorKey.trainable, TensorKey.content, TensorKey.targets):
            # A missing attribute becomes None and is rejected by the recursive call.
            yield from iter_tensor_leaves(getattr(value, name, None), path + (name,))
    elif isinstance(value, TensorDict):
        for key in value.keys():
            yield from iter_tensor_leaves(value[key], path + (str(key),))
    elif isinstance(value, Mapping):
        for key, item in value.items():
            yield from iter_tensor_leaves(item, path + (str(key),))
    else:
        raise TypeError(f"expected tensor tree at {format_path(path) or '<root>'}, got {type(value).__name__}")
|
|
451
|
+
|
|
452
|
+
|
|
453
|
+
def format_addresses(addresses: set[Address]) -> str:
    """Render addresses as a comma-separated list sorted by their string form."""
    rendered = sorted(map(str, addresses))
    return ", ".join(rendered)
|
|
455
|
+
|
|
456
|
+
|
|
457
|
+
def format_paths(values: Mapping[tuple[str, ...], Any]) -> str:
    """Render the sorted key paths of ``values``, showing the empty path as ``<tensor>``."""
    labels = [format_path(path) or "<tensor>" for path in sorted(values)]
    return ", ".join(labels)
|
|
459
|
+
|
|
460
|
+
|
|
461
|
+
def format_path(path: tuple[str, ...]) -> str:
    """Render a key path as consecutive bracketed parts, e.g. ``[a][b]``; empty path gives ``""``."""
    return "".join(map("[{}]".format, path))
|
|
463
|
+
|
|
464
|
+
|
|
465
|
+
def qualified_name(cls: type[Any]) -> str:
    """Return ``cls`` as ``<module>.<qualname>``."""
    return ".".join((cls.__module__, cls.__qualname__))
|
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
"""Runtime graph construction for schema-backed models."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from copy import deepcopy
|
|
6
|
+
from typing import TYPE_CHECKING
|
|
7
|
+
|
|
8
|
+
import torch
|
|
9
|
+
|
|
10
|
+
from json2vec.architecture.node import NodeModule
|
|
11
|
+
from json2vec.structs.experiment import Hyperparameters
|
|
12
|
+
from json2vec.structs.tree import Address, Node
|
|
13
|
+
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
from json2vec.architecture.root import Model
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class ModelGraph:
    """Build and rebuild runtime modules from schema hyperparameters."""

    @staticmethod
    def build(hyperparameters: Hyperparameters, batch_size: int) -> tuple[torch.nn.ModuleDict, torch.Tensor]:
        """Create one ``NodeModule`` per request/array address and a mock example input.

        Returns:
            The populated ``ModuleDict`` and the value produced by ``mock`` for the same
            hyperparameters and batch size.
        """
        # Imported at call time rather than module level (presumably to avoid an import cycle — TODO confirm).
        from json2vec.data.iterables import mock

        nodes: torch.nn.ModuleDict[str, NodeModule] = torch.nn.ModuleDict()

        for address in hyperparameters.requests | hyperparameters.arrays:
            nodes[address] = NodeModule(
                hyperparameters=hyperparameters,
                address=address,
                batch_size=batch_size,
            )

        return nodes, mock(hyperparameters=hyperparameters, batch_size=batch_size)

    @staticmethod
    def install(module: "Model") -> None:
        """Build a fresh graph from the module's own settings and attach it to the module."""
        module.nodes, module.example_input_array = ModelGraph.build(
            hyperparameters=module.hyperparameters,
            batch_size=module.batch_size,
        )

    @staticmethod
    def rebuild(module: "Model") -> None:
        """Reinstall the graph while carrying over every still-compatible state entry.

        An entry is carried over when its name still exists and, for tensors, the shape is
        unchanged; for non-tensors, the type is unchanged. Device and train/eval mode are
        restored afterwards.
        """
        module.hyperparameters._clear_tree_caches()
        was_training = module.training
        device = module.device
        # Snapshot by value so the copies do not alias storage owned by the modules being replaced.
        previous = {
            name: value.detach().clone() if isinstance(value, torch.Tensor) else deepcopy(value)
            for name, value in module.state_dict().items()
        }
        ModelGraph.install(module)
        if isinstance(device, torch.device):
            module.to(device=device)
        current = module.state_dict()
        compatible = {}
        for name, value in previous.items():
            if name not in current:
                continue

            current_value = current[name]
            if isinstance(current_value, torch.Tensor) and isinstance(value, torch.Tensor):
                if current_value.shape != value.shape:
                    continue
            elif type(current_value) is not type(value):
                continue

            compatible[name] = value

        # strict=False because entries that were dropped above are intentionally absent.
        module.load_state_dict(compatible, strict=False)
        module.train(was_training)

    @staticmethod
    def reset_selected(module: "Model", selected: list[Node], *, descendants: bool = False) -> None:
        """Replace the runtime modules of the selected nodes with newly constructed ones.

        With ``descendants=True`` the nodes listed in each selected node's ``descendants``
        attribute are reset as well. Only addresses already present in ``module.nodes``
        are considered.

        Raises:
            ValueError: if no selected node (or descendant) matches a runtime node.
        """
        # Imported at call time rather than module level (presumably to avoid an import cycle — TODO confirm).
        from json2vec.data.iterables import mock

        selected_by_address: dict[Address, Node] = {}
        for node in selected:
            if node.address in module.nodes:
                selected_by_address[Address(str(node.address))] = node

            if descendants:
                for descendant in getattr(node, "descendants", ()):
                    if descendant.address in module.nodes:
                        selected_by_address[Address(str(descendant.address))] = descendant

        if not selected_by_address:
            raise ValueError("reset matched no runtime nodes")

        for address in selected_by_address:
            module.nodes[address] = NodeModule(
                hyperparameters=module.hyperparameters,
                address=address,
                batch_size=module.batch_size,
            )

        module.example_input_array = mock(hyperparameters=module.hyperparameters, batch_size=module.batch_size)
        # Move the newly created modules onto the module's current device.
        device = module.device
        if isinstance(device, torch.device):
            module.to(device=device)
|
|
@@ -88,7 +88,7 @@ def render_schema_plot(
|
|
|
88
88
|
) -> RenderableType:
|
|
89
89
|
hyperparameters = module.hyperparameters
|
|
90
90
|
root = hyperparameters.fields if address is None else resolve_node(hyperparameters=hyperparameters, address=address)
|
|
91
|
-
title = "State" if state_focus else "Schema"
|
|
91
|
+
title = "JSON2Vec State" if state_focus else "JSON2Vec Schema"
|
|
92
92
|
|
|
93
93
|
tree = Tree(render_node_label(module=module, node=root, state_focus=state_focus), guide_style="dim")
|
|
94
94
|
append_schema_children(tree=tree, module=module, node=root, detail=detail or state_focus, state_focus=state_focus)
|