json2vec 0.4.3__tar.gz → 0.4.4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67) hide show
  1. {json2vec-0.4.3/src/json2vec.egg-info → json2vec-0.4.4}/PKG-INFO +1 -1
  2. {json2vec-0.4.3 → json2vec-0.4.4}/pyproject.toml +19 -1
  3. json2vec-0.4.4/src/json2vec/architecture/checkpoint.py +69 -0
  4. json2vec-0.4.4/src/json2vec/architecture/contracts.py +466 -0
  5. json2vec-0.4.4/src/json2vec/architecture/graph.py +100 -0
  6. {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/architecture/plot.py +1 -1
  7. {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/architecture/root.py +57 -313
  8. json2vec-0.4.4/src/json2vec/architecture/runtime.py +241 -0
  9. json2vec-0.4.4/src/json2vec/architecture/schema_editor.py +126 -0
  10. {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/data/datasets/streaming.py +5 -2
  11. {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/data/iterables.py +2 -2
  12. {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/inference/callback.py +4 -3
  13. {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/inference/deployment.py +2 -2
  14. {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/structs/enums.py +3 -0
  15. {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/structs/experiment.py +27 -232
  16. json2vec-0.4.4/src/json2vec/structs/selectors.py +236 -0
  17. {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/tensorfields/extensions/category.py +4 -1
  18. {json2vec-0.4.3 → json2vec-0.4.4/src/json2vec.egg-info}/PKG-INFO +1 -1
  19. {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec.egg-info/SOURCES.txt +6 -0
  20. {json2vec-0.4.3 → json2vec-0.4.4}/LICENSE +0 -0
  21. {json2vec-0.4.3 → json2vec-0.4.4}/NOTICE +0 -0
  22. {json2vec-0.4.3 → json2vec-0.4.4}/README.md +0 -0
  23. {json2vec-0.4.3 → json2vec-0.4.4}/setup.cfg +0 -0
  24. {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/__init__.py +0 -0
  25. {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/architecture/__init__.py +0 -0
  26. {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/architecture/attention.py +0 -0
  27. {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/architecture/encoder.py +0 -0
  28. {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/architecture/node.py +0 -0
  29. {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/architecture/pool.py +0 -0
  30. {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/architecture/rotary.py +0 -0
  31. {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/data/__init__.py +0 -0
  32. {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/data/datasets/__init__.py +0 -0
  33. {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/data/datasets/base.py +0 -0
  34. {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/data/datasets/polars.py +0 -0
  35. {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/data/processing.py +0 -0
  36. {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/distributed.py +0 -0
  37. {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/inference/__init__.py +0 -0
  38. {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/logging/__init__.py +0 -0
  39. {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/logging/config.py +0 -0
  40. {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/logging/epoch.py +0 -0
  41. {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/logging/throughput.py +0 -0
  42. {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/preprocessors/__init__.py +0 -0
  43. {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/preprocessors/base.py +0 -0
  44. {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/preprocessors/extensions/__init__.py +0 -0
  45. {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/preprocessors/spec.py +0 -0
  46. {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/structs/__init__.py +0 -0
  47. {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/structs/packages.py +0 -0
  48. {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/structs/structure.py +0 -0
  49. {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/structs/tree.py +0 -0
  50. {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/tensorfields/__init__.py +0 -0
  51. {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/tensorfields/base.py +0 -0
  52. {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/tensorfields/extensions/__init__.py +0 -0
  53. {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/tensorfields/extensions/dateparts.py +0 -0
  54. {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/tensorfields/extensions/entity.py +0 -0
  55. {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/tensorfields/extensions/number.py +0 -0
  56. {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/tensorfields/extensions/set.py +0 -0
  57. {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/tensorfields/extensions/text.py +0 -0
  58. {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/tensorfields/extensions/vector.py +0 -0
  59. {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/tensorfields/shared/__init__.py +0 -0
  60. {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/tensorfields/shared/counter.py +0 -0
  61. {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/tensorfields/shared/vocabulary.py +0 -0
  62. {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec/tensorfields/spec.py +0 -0
  63. {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec.egg-info/dependency_links.txt +0 -0
  64. {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec.egg-info/requires.txt +0 -0
  65. {json2vec-0.4.3 → json2vec-0.4.4}/src/json2vec.egg-info/top_level.txt +0 -0
  66. {json2vec-0.4.3 → json2vec-0.4.4}/tests/test_callbacks.py +0 -0
  67. {json2vec-0.4.3 → json2vec-0.4.4}/tests/test_public_api.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: json2vec
3
- Version: 0.4.3
3
+ Version: 0.4.4
4
4
  Summary: {...} -> [*]
5
5
  License-Expression: Apache-2.0
6
6
  Requires-Python: >=3.12
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "json2vec"
3
- version = "0.4.3"
3
+ version = "0.4.4"
4
4
  description = "{...} -> [*]"
5
5
  readme = "README.md"
6
6
  license = "Apache-2.0"
@@ -68,3 +68,21 @@ python_files = ["test_*.py"]
68
68
  [tool.ruff]
69
69
  line-length = 120
70
70
  lint.extend-select = ["I"]
71
+
72
+ [tool.ty.rules]
73
+ # Keep ty as a green baseline while the dynamic plugin/Torch/Lightning surfaces
74
+ # are covered by runtime tests and can be tightened incrementally.
75
+ unresolved-attribute = "ignore"
76
+ invalid-type-form = "ignore"
77
+ invalid-argument-type = "ignore"
78
+ invalid-assignment = "ignore"
79
+ unknown-argument = "ignore"
80
+ invalid-method-override = "ignore"
81
+ call-non-callable = "ignore"
82
+ invalid-return-type = "ignore"
83
+ not-subscriptable = "ignore"
84
+ unsupported-operator = "ignore"
85
+ no-matching-overload = "ignore"
86
+ invalid-attribute-override = "ignore"
87
+ redundant-cast = "ignore"
88
+ unused-ignore-comment = "ignore"
@@ -0,0 +1,69 @@
1
+ """Checkpoint serialization helpers for JSON2Vec models."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from pathlib import Path
6
+ from typing import TYPE_CHECKING, Any
7
+
8
+ import torch
9
+ from loguru import logger
10
+
11
+ from json2vec.architecture.graph import ModelGraph
12
+ from json2vec.structs.experiment import Hyperparameters
13
+
14
+ if TYPE_CHECKING:
15
+ from json2vec.architecture.root import Model
16
+
17
+
18
class CheckpointState:
    """Save, load, and restore model state without owning the public facade.

    All members are static; the class is only a namespace for the checkpoint
    helpers and the set of keys a restorable checkpoint must contain.
    """

    # Keys that must be present before a checkpoint can be restored.
    required_fields = {"state_dict", "hyperparameters", "batch_size"}

    @staticmethod
    def _require_fields(checkpoint: dict[str, Any]) -> None:
        """Raise ``ValueError`` naming every required key absent from *checkpoint*."""
        missing = CheckpointState.required_fields - set(checkpoint)
        if missing:
            fields = ", ".join(sorted(missing))
            raise ValueError(f"missing checkpoint fields: {fields}")

    @staticmethod
    def dump(module: "Model", checkpoint: dict[str, Any]) -> None:
        """Write the module's hyperparameters and batch size into *checkpoint* in place."""
        checkpoint["hyperparameters"] = module.hyperparameters.model_dump(mode="python")
        checkpoint["batch_size"] = module.batch_size

    @staticmethod
    def save(module: "Model", pathname: str | Path) -> None:
        """Serialize the module's state dict and metadata to *pathname*.

        Missing parent directories are created first.
        """
        path = Path(pathname)
        path.parent.mkdir(parents=True, exist_ok=True)

        checkpoint: dict[str, Any] = {"state_dict": module.state_dict()}
        CheckpointState.dump(module, checkpoint)
        torch.save(checkpoint, path)

    @staticmethod
    def restore(module: "Model", checkpoint: dict[str, Any]) -> None:
        """Rebuild *module* from *checkpoint* and load its weights.

        The module's device and training flag are captured before the graph is
        reinstalled and re-applied afterwards.

        Raises:
            ValueError: if any key in ``required_fields`` is missing.
        """
        CheckpointState._require_fields(checkpoint)

        device = module.device
        was_training = module.training
        module.hyperparameters = Hyperparameters.model_validate(checkpoint["hyperparameters"])
        module.batch_size = checkpoint["batch_size"]
        ModelGraph.install(module)
        if isinstance(device, torch.device):
            module.to(device=device)
        module.load_state_dict(state_dict=checkpoint["state_dict"])
        module.train(was_training)

    @staticmethod
    def load(model_cls: type["Model"], checkpoint: str | Path) -> "Model":
        """Construct a ``model_cls`` instance from the checkpoint file at *checkpoint*.

        Raises:
            ValueError: if the checkpoint lacks ``hyperparameters`` or any other
                key in ``required_fields``. Validation happens before the model
                is constructed, so a truncated checkpoint fails fast instead of
                surfacing as a ``KeyError`` on ``batch_size``.
        """
        path = Path(checkpoint)
        log = logger.bind(component="model_factory", checkpoint=str(path))
        log.info("loading Model from checkpoint")
        # SECURITY NOTE(review): weights_only=False lets torch.load unpickle arbitrary
        # objects; only load checkpoints from trusted sources.
        state = torch.load(path, weights_only=False, map_location="cpu")
        if "hyperparameters" not in state:
            raise ValueError("missing hyperparameters in checkpoint")
        CheckpointState._require_fields(state)

        model = model_cls(
            hyperparameters=Hyperparameters.model_validate(state["hyperparameters"]),
            batch_size=state["batch_size"],
        )
        model.restore_checkpoint_state(state)
        log.info("restored model state from checkpoint")

        return model
@@ -0,0 +1,466 @@
1
+ """Generic runtime contracts for model forward inputs."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from collections.abc import Iterator, Mapping
6
+ from dataclasses import dataclass, field
7
+ from typing import TYPE_CHECKING, Any
8
+
9
+ import torch
10
+ from tensordict import TensorDict
11
+
12
+ from json2vec.structs.enums import Strata, TensorKey, Tokens
13
+ from json2vec.structs.tree import Address
14
+ from json2vec.tensorfields.base import TENSORFIELDS, TensorFieldBase
15
+
16
+ if TYPE_CHECKING:
17
+ from json2vec.architecture.root import Model
18
+
19
+
20
class ForwardContractError(ValueError):
    """Signal that a batch handed to forward breaks the model's input contract.

    Derives from ``ValueError``, so handlers written for ``ValueError`` also
    catch it.
    """
22
+
23
+
24
# Tensor dtypes accepted wherever the contract demands an integer dtype.
INTEGER_DTYPES = {
    torch.int64,
    torch.int32,
    torch.int16,
    torch.int8,
    torch.uint8,
}


# Hashable nested-tuple summary of one batch's structure.
ContractSignature = tuple[Any, ...]
# Scheduler key: (strata name, dataloader index, contract generation, batch signature).
ContractScope = tuple[str, int, int, ContractSignature]
35
+
36
+
37
@dataclass
class ContractScheduler:
    """Decide, with deterministic backoff, when the costly contract checks run.

    Every distinct scope keeps its own occurrence counter. The counter value
    seen before the increment is passed to ``is_backoff_index`` to decide
    whether that occurrence is checked.
    """

    # Forwarded to ``is_backoff_index`` as its ``periodic_interval`` argument.
    periodic_interval: int = 1024
    # Occurrences observed so far, keyed by scope.
    _counts: dict[ContractScope, int] = field(default_factory=dict)

    def reset(self) -> None:
        """Drop every per-scope counter."""
        self._counts.clear()

    def should_check(
        self,
        module: "Model",
        inputs: Any,
        *,
        strata: Strata,
        dataloader_idx: int,
    ) -> bool:
        """Record one occurrence of this batch's scope and report whether to check it.

        The scope combines the strata name, the dataloader index, the module's
        ``_contract_generation`` attribute (0 when absent), and the batch
        signature.
        """
        scope = (
            str(strata),
            dataloader_idx,
            int(getattr(module, "_contract_generation", 0)),
            batch_signature(module, inputs),
        )
        seen = self._counts.get(scope, 0)
        self._counts[scope] = seen + 1
        return is_backoff_index(seen, periodic_interval=self.periodic_interval)
65
+
66
+
67
def sanitize(
    module: "Model",
    inputs: TensorDict[Address, TensorFieldBase],
    *,
    strata: Strata | str,
    dataloader_idx: int = 0,
) -> None:
    """Check a forward batch against the generic input contract.

    When the module carries a ``ContractScheduler`` in ``_contract_scheduler``
    and it reports that this batch is not due for a check, the function
    returns without validating anything.

    Raises:
        TypeError: if *inputs* is not a ``TensorDict`` (or a nested check
            rejects a value's type).
        ForwardContractError: if any contract check fails.
    """
    stratum = Strata.normalize(strata)

    # The scheduler is consulted first so skipped batches cost only a signature.
    scheduler = getattr(module, "_contract_scheduler", None)
    if isinstance(scheduler, ContractScheduler):
        due = scheduler.should_check(
            module,
            inputs,
            strata=stratum,
            dataloader_idx=dataloader_idx,
        )
        if not due:
            return

    if not isinstance(inputs, TensorDict):
        raise TypeError(f"forward inputs must be a TensorDict, got {type(inputs).__name__}")

    require_forward_addresses(module, inputs, strata=stratum)

    for address in module.hyperparameters.active_requests:
        entry = inputs[address]
        require_registered_tensorfield(module, address, entry)
        require_core_tensors(module, address, entry)
        require_tensor_devices(module, address, entry)
        require_target_contract(module, address, entry, strata=stratum)
        require_mask_contract(module, address, entry)
97
+
98
+
99
def is_backoff_index(index: int, *, periodic_interval: int) -> bool:
    """Report whether occurrence number *index* falls on the backoff schedule.

    The schedule contains 0, every power of two, and, when
    *periodic_interval* is positive, every multiple of it.
    """
    is_first = index == 0
    is_power_of_two = (index & (index - 1)) == 0
    if is_first or is_power_of_two:
        return True

    if periodic_interval <= 0:
        return False

    return index % periodic_interval == 0
107
+
108
+
109
def batch_signature(module: "Model", inputs: Any) -> ContractSignature:
    """Summarize the structure of *inputs* as a hashable nested tuple.

    Non-``TensorDict`` inputs are summarized by their qualified type name.
    Otherwise the signature holds the sorted input keys plus one entry per
    active request address (sorted by string form). Each entry is either a
    ``"missing"`` marker or the tensorfield's class name followed by the
    signatures of its state, trainable, content, and targets attributes.
    """
    if not isinstance(inputs, TensorDict):
        return ("inputs", qualified_name(type(inputs)))

    input_keys = tuple(sorted(str(key) for key in inputs.keys()))

    entries: list[tuple[Any, ...]] = []
    for address in sorted(module.hyperparameters.active_requests, key=str):
        label = str(address)
        if address in inputs.keys():
            tensorfield = inputs[address]
            entry = (
                label,
                qualified_name(type(tensorfield)),
                tensor_signature(getattr(tensorfield, TensorKey.state, None)),
                tensor_signature(getattr(tensorfield, TensorKey.trainable, None)),
                tensor_tree_signature(getattr(tensorfield, TensorKey.content, None)),
                tensor_tree_signature(getattr(tensorfield, TensorKey.targets, None)),
            )
        else:
            entry = (label, "missing")
        entries.append(entry)

    return (input_keys, tuple(entries))
133
+
134
+
135
def tensor_signature(value: Any) -> tuple[Any, ...]:
    """Describe *value* by shape, dtype, and device if it is a tensor, else by type name."""
    if torch.is_tensor(value):
        shape = tuple(value.shape)
        return ("tensor", shape, str(value.dtype), str(value.device))

    return ("object", qualified_name(type(value)))
145
+
146
+
147
def tensor_tree_signature(value: Any) -> tuple[Any, ...]:
    """Recursively describe a tensor, ``TensorDict``, or mapping as a nested tuple.

    Container children are visited in order of their keys' string form. Any
    other value is described by its qualified type name.
    """
    if torch.is_tensor(value):
        return tensor_signature(value)

    if isinstance(value, TensorDict):
        children = [(str(key), tensor_tree_signature(value[key])) for key in sorted(value.keys(), key=str)]
        return ("tensordict", tuple(children))

    if isinstance(value, Mapping):
        ordered = sorted(value.items(), key=lambda pair: str(pair[0]))
        children = [(str(key), tensor_tree_signature(item)) for key, item in ordered]
        return ("mapping", tuple(children))

    return ("object", qualified_name(type(value)))
166
+
167
+
168
def require_forward_addresses(
    module: "Model",
    inputs: TensorDict[Address, TensorFieldBase],
    *,
    strata: Strata,
) -> None:
    """Ensure the keys of *inputs* are exactly the module's active request addresses.

    The metadata key is tolerated only in the predict strata. Surplus
    addresses are reported as array addresses, inactive requests, or unknown
    addresses, in that order of precedence.

    Raises:
        ForwardContractError: on misplaced metadata, or on missing or surplus
            addresses.
    """
    keys = set(inputs.keys())
    has_metadata = any(key == TensorKey.metadata for key in keys)
    present = {Address(str(key)) for key in keys if key != TensorKey.metadata}
    expected = set(module.hyperparameters.active_requests)

    if has_metadata and strata != Strata.predict:
        raise ForwardContractError(f"forward input contains {TensorKey.metadata} outside predict strata")

    absent = expected - present
    if absent:
        raise ForwardContractError(f"forward input is missing active request address(es): {format_addresses(absent)}")

    surplus = present - expected
    if surplus:
        array_addresses = surplus & set(module.hyperparameters.arrays)
        if array_addresses:
            raise ForwardContractError(
                "forward input contains array address(es); only active leaf request addresses are allowed: "
                f"{format_addresses(array_addresses)}"
            )

        known_requests = module.hyperparameters.requests
        inactive = {address for address in surplus if address in known_requests}
        if inactive:
            raise ForwardContractError(
                "forward input contains inactive request address(es): "
                f"{format_addresses(inactive)}. Inactive fields remain in the schema but must not be present in runtime input."
            )

        raise ForwardContractError(f"forward input contains unknown address(es): {format_addresses(surplus)}")
205
+
206
+
207
def require_registered_tensorfield(module: "Model", address: Address, value: Any) -> None:
    """Ensure *value* is an instance of the tensorfield class registered for the request's type.

    Raises:
        TypeError: if *value* is not a ``TensorFieldBase`` or not an instance
            of the class looked up in ``TENSORFIELDS`` for this request.
    """
    if not isinstance(value, TensorFieldBase):
        raise TypeError(f"forward input '{address}' must be a TensorFieldBase, got {type(value).__name__}")

    request_type = module.hyperparameters.requests[address].type
    expected_cls = TENSORFIELDS[request_type].TensorField
    if isinstance(value, expected_cls):
        return

    raise TypeError(
        f"forward input '{address}' must use tensorfield class {qualified_name(expected_cls)}, "
        f"got {qualified_name(type(value))}"
    )
218
+
219
+
220
def require_core_tensors(module: "Model", address: Address, tensorfield: TensorFieldBase) -> None:
    """Validate the shapes and dtypes of a tensorfield's state, trainable, content, and targets.

    Checks run in a fixed order: presence of the four parts, state rank and
    shape against ``module.hyperparameters.shapes[address]`` (with a leading
    dimension taken from the state itself), state integer dtype, trainable
    shape and bool dtype, state token values, content leading shapes, and
    finally the optional state and content entries in targets.

    Raises:
        TypeError: for missing tensors or wrong dtypes.
        ForwardContractError: for rank, shape, key, or token-value violations.
    """
    state = require_tensor_attribute(address, tensorfield, TensorKey.state)
    trainable = require_tensor_attribute(address, tensorfield, TensorKey.trainable)
    raw_content = getattr(tensorfield, TensorKey.content, None)
    content = require_tensor_tree(address, TensorKey.content, raw_content)
    targets = require_targets(address, tensorfield)

    field_shape = module.hyperparameters.shapes[address]
    expected_rank = len(field_shape) + 1
    if state.ndim != expected_rank:
        raise ForwardContractError(
            f"forward input '{address}' state must have rank {expected_rank}, got {state.ndim}"
        )

    # The leading dimension is taken from the state, so only the remaining ones are constrained.
    state_shape = tuple(state.shape)
    expected_shape = (state.shape[0], *field_shape)
    if state_shape != expected_shape:
        raise ForwardContractError(
            f"forward input '{address}' state must have shape {expected_shape}, got {state_shape}"
        )

    if state.dtype not in INTEGER_DTYPES:
        raise TypeError(f"forward input '{address}' state must use an integer dtype, got {state.dtype}")

    trainable_shape = tuple(trainable.shape)
    if trainable_shape != state_shape:
        raise ForwardContractError(
            f"forward input '{address}' trainable must have shape {state_shape}, got {trainable_shape}"
        )

    if trainable.dtype != torch.bool:
        raise TypeError(f"forward input '{address}' trainable must use bool dtype, got {trainable.dtype}")

    require_token_values(address, TensorKey.state, state)
    require_content_prefix_shapes(address, content, state)

    if TensorKey.state in targets.keys():
        state_label = f"{TensorKey.targets}[{TensorKey.state}]"
        state_leaves = require_tensor_tree(address, state_label, targets[TensorKey.state])
        # An empty path key means the target state must itself be a single tensor.
        require_matching_tree_shapes(
            address,
            actual_name=state_label,
            actual=state_leaves,
            expected_name=TensorKey.state,
            expected={(): state},
        )
        require_integer_tensors(address, state_label, state_leaves)
        require_token_values(address, state_label, targets[TensorKey.state])

    if TensorKey.content in targets.keys():
        content_label = f"{TensorKey.targets}[{TensorKey.content}]"
        content_leaves = require_tensor_tree(address, content_label, targets[TensorKey.content])
        require_matching_tree_shapes(
            address,
            actual_name=content_label,
            actual=content_leaves,
            expected_name=TensorKey.content,
            expected=content,
        )
279
+
280
+
281
def require_tensor_devices(module: "Model", address: Address, tensorfield: TensorFieldBase) -> None:
    """Ensure every tensor leaf shares one device, matching the module's device when known.

    The module comparison is skipped when ``module.device`` is absent or is
    not a ``torch.device``.

    Raises:
        ForwardContractError: if leaves span several devices or sit on a
            device other than the module's.
    """
    devices = {leaf.device for _path, leaf in iter_tensor_leaves(tensorfield)}
    if len(devices) > 1:
        formatted = ", ".join(sorted(str(device) for device in devices))
        raise ForwardContractError(f"forward input '{address}' tensors must share one device, got {formatted}")

    module_device = getattr(module, "device", None)
    if not isinstance(module_device, torch.device) or not devices:
        return

    input_device = next(iter(devices))
    if input_device != module_device:
        raise ForwardContractError(
            f"forward input '{address}' tensors must be on module device {module_device}, got {input_device}"
        )
293
+
294
+
295
def require_mask_contract(module: "Model", address: Address, tensorfield: TensorFieldBase) -> None:
    """Ensure masking, trainable flags, and targets agree with each other.

    Rules enforced, in order:
      * every trainable position has masked state;
      * for addresses not listed in ``module.hyperparameters.target``, no
        position is masked without being trainable;
      * when any position is trainable, targets hold both a state and a
        content entry, and the target state is not masked at trainable
        positions.

    Raises:
        ForwardContractError: when any rule is violated.
    """
    masked_token = Tokens.masked.value
    state = tensorfield.state
    trainable = tensorfield.trainable
    is_masked = state.eq(masked_token)
    is_target = address in module.hyperparameters.target
    has_trainable = bool(trainable.any())

    if has_trainable and not state.masked_select(trainable).eq(masked_token).all():
        raise ForwardContractError(f"forward input '{address}' trainable positions must have masked state")

    if not is_target:
        masked_but_frozen = is_masked & ~trainable
        if masked_but_frozen.any():
            raise ForwardContractError(f"forward input '{address}' has masked state where trainable is false")

    if not has_trainable:
        return

    targets = tensorfield.targets
    for key in (TensorKey.state, TensorKey.content):
        if key not in targets.keys():
            raise ForwardContractError(f"forward input '{address}' has trainable positions but lacks targets[{key}]")

    supervised = targets[TensorKey.state].masked_select(trainable)
    if supervised.eq(masked_token).any():
        raise ForwardContractError(f"forward input '{address}' targets[{TensorKey.state}] must not be masked")
318
+
319
+
320
def require_target_contract(
    module: "Model",
    address: Address,
    tensorfield: TensorFieldBase,
    *,
    strata: Strata | None,
) -> None:
    """Apply the extra rules for addresses listed in ``module.hyperparameters.target``.

    Addresses not listed there are accepted as-is. A listed field must have
    fully masked state, and in the train, validate, and test strata it must
    have at least one trainable position.

    Raises:
        ForwardContractError: when a target field violates either rule.
    """
    if address not in module.hyperparameters.target:
        return

    fully_masked = tensorfield.state.eq(Tokens.masked.value).all()
    if not fully_masked:
        raise ForwardContractError(f"target field '{address}' must not contain visible input state")

    supervised_strata = (Strata.train, Strata.validate, Strata.test)
    if strata in supervised_strata and not tensorfield.trainable.any():
        raise ForwardContractError(f"target field '{address}' must have trainable positions in {strata} strata")
335
+
336
+
337
def require_tensor_attribute(address: Address, tensorfield: TensorFieldBase, name: str) -> torch.Tensor:
    """Return attribute *name* of *tensorfield*, insisting that it is a tensor.

    Raises:
        TypeError: if the attribute is missing or is not a ``torch.Tensor``.
    """
    candidate = getattr(tensorfield, name, None)
    if torch.is_tensor(candidate):
        return candidate

    raise TypeError(f"forward input '{address}' {name} must be a torch.Tensor, got {type(candidate).__name__}")
343
+
344
+
345
def require_targets(address: Address, tensorfield: TensorFieldBase) -> TensorDict:
    """Return the tensorfield's targets after confirming they form a ``TensorDict`` tensor tree.

    An empty targets container is accepted.

    Raises:
        TypeError: if targets are missing, are not a ``TensorDict``, or
            contain a non-tensor leaf.
    """
    targets = getattr(tensorfield, TensorKey.targets, None)
    if isinstance(targets, TensorDict):
        # Walk the tree purely for validation; the collected leaves are discarded.
        require_tensor_tree(address, TensorKey.targets, targets, allow_empty=True)
        return targets

    raise TypeError(
        f"forward input '{address}' {TensorKey.targets} must be a TensorDict, got {type(targets).__name__}"
    )
354
+
355
+
356
def require_tensor_tree(
    address: Address,
    name: str,
    value: Any,
    *,
    allow_empty: bool = False,
) -> dict[tuple[str, ...], torch.Tensor]:
    """Flatten *value* into a mapping from key path to tensor leaf.

    Raises:
        TypeError: if the tree has no tensors and *allow_empty* is false, or
            if ``iter_tensor_leaves`` rejects a node in the tree.
    """
    leaves = dict(iter_tensor_leaves(value))
    if leaves or allow_empty:
        return leaves

    raise TypeError(f"forward input '{address}' {name} must contain at least one tensor")
368
+
369
+
370
def require_matching_tree_shapes(
    address: Address,
    *,
    actual_name: str,
    actual: dict[tuple[str, ...], torch.Tensor],
    expected_name: str,
    expected: dict[tuple[str, ...], torch.Tensor],
) -> None:
    """Ensure two flattened tensor trees share the same key paths and per-path shapes.

    Raises:
        ForwardContractError: if the key sets differ, or if any tensor in
            *actual* has a different shape from its counterpart in *expected*.
    """
    if actual.keys() != expected.keys():
        raise ForwardContractError(
            f"forward input '{address}' {actual_name} keys must match {expected_name} keys: "
            f"expected {format_paths(expected)}, got {format_paths(actual)}"
        )

    for path, tensor in actual.items():
        got_shape = tuple(tensor.shape)
        want_shape = tuple(expected[path].shape)
        if got_shape == want_shape:
            continue

        raise ForwardContractError(
            f"forward input '{address}' {actual_name}{format_path(path)} must have shape "
            f"{want_shape}, got {got_shape}"
        )
392
+
393
+
394
def require_content_prefix_shapes(
    address: Address,
    content: dict[tuple[str, ...], torch.Tensor],
    state: torch.Tensor,
) -> None:
    """Ensure each content tensor's leading dimensions equal the full shape of *state*.

    Raises:
        ForwardContractError: for the first content tensor whose shape does
            not start with the state shape.
    """
    state_shape = tuple(state.shape)
    prefix_length = len(state_shape)
    for path, tensor in content.items():
        tensor_shape = tuple(tensor.shape)
        has_prefix = len(tensor_shape) >= prefix_length and tensor_shape[:prefix_length] == state_shape
        if has_prefix:
            continue

        raise ForwardContractError(
            f"forward input '{address}' {TensorKey.content}{format_path(path)} must start with {TensorKey.state} "
            f"shape {state_shape}, "
            f"got {tensor_shape}"
        )
409
+
410
+
411
def require_integer_tensors(
    address: Address,
    name: str,
    tensors: dict[tuple[str, ...], torch.Tensor],
) -> None:
    """Ensure every tensor in *tensors* has a dtype listed in ``INTEGER_DTYPES``.

    Raises:
        TypeError: for the first tensor with any other dtype.
    """
    for path, tensor in tensors.items():
        if tensor.dtype in INTEGER_DTYPES:
            continue

        raise TypeError(
            f"forward input '{address}' {name}{format_path(path)} must use an integer dtype, got {tensor.dtype}"
        )
420
+
421
+
422
def require_token_values(address: Address, name: str, values: torch.Tensor) -> None:
    """Ensure every element of *values* equals the value of some ``Tokens`` member.

    Raises:
        ForwardContractError: naming the first offending token id found.
    """
    token_ids = [token.value for token in Tokens]
    # Build the lookup on the same device and dtype as the data being checked.
    valid = torch.tensor(token_ids, device=values.device, dtype=values.dtype)
    invalid = ~torch.isin(values, valid)
    if not invalid.any():
        return

    offender = values.masked_select(invalid).reshape(-1)[0].item()
    raise ForwardContractError(f"forward input '{address}' {name} contains invalid token id {offender}")
428
+
429
+
430
def iter_tensor_leaves(value: Any, path: tuple[str, ...] = ()) -> Iterator[tuple[tuple[str, ...], torch.Tensor]]:
    """Yield ``(path, tensor)`` for each tensor leaf found under *value*.

    A tensor is yielded directly. A ``TensorFieldBase`` is descended through
    its state, trainable, content, and targets attributes. ``TensorDict`` and
    mapping values are descended through their keys, stringified in the path.

    Raises:
        TypeError: when a node is none of the above (including a missing
            attribute, which is read as ``None``).
    """
    if torch.is_tensor(value):
        yield path, value
        return

    # The generators below stay lazy so each child is fetched just before it is visited.
    if isinstance(value, TensorFieldBase):
        parts = (TensorKey.state, TensorKey.trainable, TensorKey.content, TensorKey.targets)
        children = ((name, getattr(value, name, None)) for name in parts)
    elif isinstance(value, TensorDict):
        children = ((str(key), value[key]) for key in value.keys())
    elif isinstance(value, Mapping):
        children = ((str(key), item) for key, item in value.items())
    else:
        raise TypeError(f"expected tensor tree at {format_path(path) or '<root>'}, got {type(value).__name__}")

    for name, child in children:
        yield from iter_tensor_leaves(child, (*path, name))
451
+
452
+
453
def format_addresses(addresses: set[Address]) -> str:
    """Render *addresses* as a comma-separated list sorted by string form."""
    names = [str(address) for address in addresses]
    names.sort()
    return ", ".join(names)
455
+
456
+
457
def format_paths(values: Mapping[tuple[str, ...], Any]) -> str:
    """Render the sorted key paths of *values*, showing an empty path as ``<tensor>``."""
    rendered = []
    for path in sorted(values):
        rendered.append(format_path(path) or "<tensor>")
    return ", ".join(rendered)
459
+
460
+
461
def format_path(path: tuple[str, ...]) -> str:
    """Render a key path as bracketed segments, e.g. ``("a", "b")`` becomes ``[a][b]``."""
    return "".join(map("[{}]".format, path))
463
+
464
+
465
def qualified_name(cls: type[Any]) -> str:
    """Return the class's module and qualified name joined by a dot."""
    return "{}.{}".format(cls.__module__, cls.__qualname__)
@@ -0,0 +1,100 @@
1
+ """Runtime graph construction for schema-backed models."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from copy import deepcopy
6
+ from typing import TYPE_CHECKING
7
+
8
+ import torch
9
+
10
+ from json2vec.architecture.node import NodeModule
11
+ from json2vec.structs.experiment import Hyperparameters
12
+ from json2vec.structs.tree import Address, Node
13
+
14
+ if TYPE_CHECKING:
15
+ from json2vec.architecture.root import Model
16
+
17
+
18
class ModelGraph:
    """Construct, install, rebuild, and reset the runtime node modules of a model."""

    @staticmethod
    def build(hyperparameters: Hyperparameters, batch_size: int) -> tuple[torch.nn.ModuleDict, torch.Tensor]:
        """Create one ``NodeModule`` per request and array address, plus a mock example input."""
        from json2vec.data.iterables import mock

        nodes = torch.nn.ModuleDict()
        for address in hyperparameters.requests | hyperparameters.arrays:
            nodes[address] = NodeModule(
                hyperparameters=hyperparameters,
                address=address,
                batch_size=batch_size,
            )

        example = mock(hyperparameters=hyperparameters, batch_size=batch_size)
        return nodes, example

    @staticmethod
    def install(module: "Model") -> None:
        """Replace the module's nodes and example input with freshly built ones."""
        nodes, example = ModelGraph.build(
            hyperparameters=module.hyperparameters,
            batch_size=module.batch_size,
        )
        module.nodes, module.example_input_array = nodes, example

    @staticmethod
    def rebuild(module: "Model") -> None:
        """Reinstall the graph, carrying over every state entry that still fits.

        An entry is carried over when its name exists in the new state dict
        and, for tensors, the shapes agree; for other values, the types agree.
        The module's device and training flag are preserved.
        """

        def fits(new_value: object, old_value: object) -> bool:
            if isinstance(new_value, torch.Tensor) and isinstance(old_value, torch.Tensor):
                return new_value.shape == old_value.shape
            return type(new_value) is type(old_value)

        module.hyperparameters._clear_tree_caches()
        was_training = module.training
        device = module.device

        # Snapshot before reinstalling so the old values cannot alias the new modules.
        snapshot = {}
        for name, value in module.state_dict().items():
            snapshot[name] = value.detach().clone() if isinstance(value, torch.Tensor) else deepcopy(value)

        ModelGraph.install(module)
        if isinstance(device, torch.device):
            module.to(device=device)

        current = module.state_dict()
        compatible = {
            name: value for name, value in snapshot.items() if name in current and fits(current[name], value)
        }

        module.load_state_dict(compatible, strict=False)
        module.train(was_training)

    @staticmethod
    def reset_selected(module: "Model", selected: list[Node], *, descendants: bool = False) -> None:
        """Replace the runtime modules of the selected nodes with fresh ``NodeModule`` instances.

        With *descendants* set, each node's ``descendants`` (if it has any) are
        reset as well. Nodes without a runtime module are ignored.

        Raises:
            ValueError: if no selected node has a runtime module.
        """
        from json2vec.data.iterables import mock

        matched: dict[Address, Node] = {}
        for node in selected:
            candidates = [node]
            if descendants:
                candidates.extend(getattr(node, "descendants", ()))

            for candidate in candidates:
                if candidate.address in module.nodes:
                    matched[Address(str(candidate.address))] = candidate

        if not matched:
            raise ValueError("reset matched no runtime nodes")

        for address in matched:
            module.nodes[address] = NodeModule(
                hyperparameters=module.hyperparameters,
                address=address,
                batch_size=module.batch_size,
            )

        module.example_input_array = mock(hyperparameters=module.hyperparameters, batch_size=module.batch_size)
        device = module.device
        if isinstance(device, torch.device):
            module.to(device=device)
@@ -88,7 +88,7 @@ def render_schema_plot(
88
88
  ) -> RenderableType:
89
89
  hyperparameters = module.hyperparameters
90
90
  root = hyperparameters.fields if address is None else resolve_node(hyperparameters=hyperparameters, address=address)
91
- title = "State" if state_focus else "Schema"
91
+ title = "JSON2Vec State" if state_focus else "JSON2Vec Schema"
92
92
 
93
93
  tree = Tree(render_node_label(module=module, node=root, state_focus=state_focus), guide_style="dim")
94
94
  append_schema_children(tree=tree, module=module, node=root, detail=detail or state_focus, state_focus=state_focus)