json2vec 0.4.4__tar.gz → 0.4.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68) hide show
  1. {json2vec-0.4.4/src/json2vec.egg-info → json2vec-0.4.5}/PKG-INFO +1 -1
  2. {json2vec-0.4.4 → json2vec-0.4.5}/pyproject.toml +1 -1
  3. {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/architecture/contracts.py +3 -3
  4. {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/architecture/graph.py +18 -5
  5. json2vec-0.4.5/src/json2vec/architecture/mutations.py +323 -0
  6. {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/architecture/root.py +25 -9
  7. {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/architecture/runtime.py +20 -14
  8. {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/structs/experiment.py +11 -6
  9. {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/structs/selectors.py +1 -1
  10. {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/structs/tree.py +15 -0
  11. {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/tensorfields/extensions/category.py +10 -2
  12. {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/tensorfields/extensions/dateparts.py +46 -0
  13. {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/tensorfields/extensions/number.py +1 -1
  14. {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/tensorfields/extensions/set.py +18 -1
  15. {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/tensorfields/extensions/vector.py +50 -8
  16. {json2vec-0.4.4 → json2vec-0.4.5/src/json2vec.egg-info}/PKG-INFO +1 -1
  17. {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec.egg-info/SOURCES.txt +1 -1
  18. json2vec-0.4.4/src/json2vec/architecture/schema_editor.py +0 -126
  19. {json2vec-0.4.4 → json2vec-0.4.5}/LICENSE +0 -0
  20. {json2vec-0.4.4 → json2vec-0.4.5}/NOTICE +0 -0
  21. {json2vec-0.4.4 → json2vec-0.4.5}/README.md +0 -0
  22. {json2vec-0.4.4 → json2vec-0.4.5}/setup.cfg +0 -0
  23. {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/__init__.py +0 -0
  24. {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/architecture/__init__.py +0 -0
  25. {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/architecture/attention.py +0 -0
  26. {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/architecture/checkpoint.py +0 -0
  27. {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/architecture/encoder.py +0 -0
  28. {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/architecture/node.py +0 -0
  29. {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/architecture/plot.py +0 -0
  30. {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/architecture/pool.py +0 -0
  31. {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/architecture/rotary.py +0 -0
  32. {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/data/__init__.py +0 -0
  33. {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/data/datasets/__init__.py +0 -0
  34. {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/data/datasets/base.py +0 -0
  35. {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/data/datasets/polars.py +0 -0
  36. {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/data/datasets/streaming.py +0 -0
  37. {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/data/iterables.py +0 -0
  38. {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/data/processing.py +0 -0
  39. {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/distributed.py +0 -0
  40. {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/inference/__init__.py +0 -0
  41. {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/inference/callback.py +0 -0
  42. {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/inference/deployment.py +0 -0
  43. {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/logging/__init__.py +0 -0
  44. {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/logging/config.py +0 -0
  45. {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/logging/epoch.py +0 -0
  46. {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/logging/throughput.py +0 -0
  47. {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/preprocessors/__init__.py +0 -0
  48. {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/preprocessors/base.py +0 -0
  49. {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/preprocessors/extensions/__init__.py +0 -0
  50. {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/preprocessors/spec.py +0 -0
  51. {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/structs/__init__.py +0 -0
  52. {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/structs/enums.py +0 -0
  53. {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/structs/packages.py +0 -0
  54. {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/structs/structure.py +0 -0
  55. {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/tensorfields/__init__.py +0 -0
  56. {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/tensorfields/base.py +0 -0
  57. {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/tensorfields/extensions/__init__.py +0 -0
  58. {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/tensorfields/extensions/entity.py +0 -0
  59. {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/tensorfields/extensions/text.py +0 -0
  60. {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/tensorfields/shared/__init__.py +0 -0
  61. {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/tensorfields/shared/counter.py +0 -0
  62. {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/tensorfields/shared/vocabulary.py +0 -0
  63. {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/tensorfields/spec.py +0 -0
  64. {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec.egg-info/dependency_links.txt +0 -0
  65. {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec.egg-info/requires.txt +0 -0
  66. {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec.egg-info/top_level.txt +0 -0
  67. {json2vec-0.4.4 → json2vec-0.4.5}/tests/test_callbacks.py +0 -0
  68. {json2vec-0.4.4 → json2vec-0.4.5}/tests/test_public_api.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: json2vec
3
- Version: 0.4.4
3
+ Version: 0.4.5
4
4
  Summary: {...} -> [*]
5
5
  License-Expression: Apache-2.0
6
6
  Requires-Python: >=3.12
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "json2vec"
3
- version = "0.4.4"
3
+ version = "0.4.5"
4
4
  description = "{...} -> [*]"
5
5
  readme = "README.md"
6
6
  license = "Apache-2.0"
@@ -93,7 +93,7 @@ def sanitize(
93
93
  require_core_tensors(module, address, tensorfield)
94
94
  require_tensor_devices(module, address, tensorfield)
95
95
  require_target_contract(module, address, tensorfield, strata=normalized)
96
- require_mask_contract(module, address, tensorfield)
96
+ require_mask_contract(module, address, tensorfield, strata=normalized)
97
97
 
98
98
 
99
99
  def is_backoff_index(index: int, *, periodic_interval: int) -> bool:
@@ -292,7 +292,7 @@ def require_tensor_devices(module: "Model", address: Address, tensorfield: Tenso
292
292
  )
293
293
 
294
294
 
295
- def require_mask_contract(module: "Model", address: Address, tensorfield: TensorFieldBase) -> None:
295
+ def require_mask_contract(module: "Model", address: Address, tensorfield: TensorFieldBase, *, strata: Strata) -> None:
296
296
  state = tensorfield.state
297
297
  trainable = tensorfield.trainable
298
298
  is_masked = state.eq(Tokens.masked.value)
@@ -301,7 +301,7 @@ def require_mask_contract(module: "Model", address: Address, tensorfield: Tensor
301
301
  if trainable.any() and not state.masked_select(trainable).eq(Tokens.masked.value).all():
302
302
  raise ForwardContractError(f"forward input '{address}' trainable positions must have masked state")
303
303
 
304
- if not is_target and (is_masked & ~trainable).any():
304
+ if strata != Strata.predict and not is_target and (is_masked & ~trainable).any():
305
305
  raise ForwardContractError(f"forward input '{address}' has masked state where trainable is false")
306
306
 
307
307
  if not trainable.any():
@@ -8,6 +8,8 @@ from typing import TYPE_CHECKING
8
8
  import torch
9
9
 
10
10
  from json2vec.architecture.node import NodeModule
11
+ from json2vec.data.datasets.base import EncodedInput
12
+ from json2vec.structs.enums import Strata
11
13
  from json2vec.structs.experiment import Hyperparameters
12
14
  from json2vec.structs.tree import Address, Node
13
15
 
@@ -19,9 +21,19 @@ class ModelGraph:
19
21
  """Build and rebuild runtime modules from schema hyperparameters."""
20
22
 
21
23
  @staticmethod
22
- def build(hyperparameters: Hyperparameters, batch_size: int) -> tuple[torch.nn.ModuleDict, torch.Tensor]:
24
+ def example_forward_kwargs(hyperparameters: Hyperparameters, batch_size: int) -> dict[str, EncodedInput | Strata]:
23
25
  from json2vec.data.iterables import mock
24
26
 
27
+ return {
28
+ "inputs": mock(hyperparameters=hyperparameters, batch_size=batch_size),
29
+ "strata": Strata.predict,
30
+ }
31
+
32
+ @staticmethod
33
+ def build(
34
+ hyperparameters: Hyperparameters,
35
+ batch_size: int,
36
+ ) -> tuple[torch.nn.ModuleDict, dict[str, EncodedInput | Strata]]:
25
37
  nodes: torch.nn.ModuleDict[str, NodeModule] = torch.nn.ModuleDict()
26
38
 
27
39
  for address in hyperparameters.requests | hyperparameters.arrays:
@@ -31,7 +43,7 @@ class ModelGraph:
31
43
  batch_size=batch_size,
32
44
  )
33
45
 
34
- return nodes, mock(hyperparameters=hyperparameters, batch_size=batch_size)
46
+ return nodes, ModelGraph.example_forward_kwargs(hyperparameters=hyperparameters, batch_size=batch_size)
35
47
 
36
48
  @staticmethod
37
49
  def install(module: "Model") -> None:
@@ -72,8 +84,6 @@ class ModelGraph:
72
84
 
73
85
  @staticmethod
74
86
  def reset_selected(module: "Model", selected: list[Node], *, descendants: bool = False) -> None:
75
- from json2vec.data.iterables import mock
76
-
77
87
  selected_by_address: dict[Address, Node] = {}
78
88
  for node in selected:
79
89
  if node.address in module.nodes:
@@ -94,7 +104,10 @@ class ModelGraph:
94
104
  batch_size=module.batch_size,
95
105
  )
96
106
 
97
- module.example_input_array = mock(hyperparameters=module.hyperparameters, batch_size=module.batch_size)
107
+ module.example_input_array = ModelGraph.example_forward_kwargs(
108
+ hyperparameters=module.hyperparameters,
109
+ batch_size=module.batch_size,
110
+ )
98
111
  device = module.device
99
112
  if isinstance(device, torch.device):
100
113
  module.to(device=device)
@@ -0,0 +1,323 @@
1
+ """Model-facing schema mutation orchestration."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from collections.abc import Callable, Iterator
6
+ from contextlib import contextmanager
7
+ from dataclasses import dataclass
8
+ from typing import TYPE_CHECKING, Any
9
+
10
+ from loguru import logger
11
+
12
+ from json2vec.architecture.graph import ModelGraph
13
+ from json2vec.structs.experiment import NodeAttribute, NodePredicate, SchemaField
14
+ from json2vec.structs.structure import Array
15
+ from json2vec.structs.tree import Leaf, Node
16
+
17
+ if TYPE_CHECKING:
18
+ from json2vec.architecture.root import Model
19
+
20
+ _MISSING = object()
21
+
22
+
23
+ @dataclass(frozen=True)
24
+ class AttributeChange:
25
+ node: Node
26
+ name: str
27
+ original: Any
28
+ definition_attribute: bool
29
+
30
+
31
+ class SchemaEditor:
32
+ """Coordinate schema mutations with runtime graph rebuilds."""
33
+
34
+ def __init__(self, module: "Model") -> None:
35
+ self.module = module
36
+
37
+ def select(
38
+ self,
39
+ *predicates: NodePredicate | NodeAttribute | Callable[[Node], bool],
40
+ include_root: bool = True,
41
+ use_cache: bool = True,
42
+ ) -> list[Node]:
43
+ return self.module.hyperparameters.select(
44
+ *predicates,
45
+ include_root=include_root,
46
+ use_cache=use_cache,
47
+ )
48
+
49
+ def update(
50
+ self,
51
+ *predicates: NodePredicate | NodeAttribute | Callable[[Node], bool],
52
+ strict: bool = True,
53
+ allow_extra: bool = False,
54
+ include_root: bool = True,
55
+ validate: bool = True,
56
+ use_cache: bool = False,
57
+ **values: Any,
58
+ ) -> None:
59
+ self.module._assert_mutation_allowed("update")
60
+ values = self.module.hyperparameters.update_values(values)
61
+ changes = self._attribute_changes(
62
+ values=values,
63
+ predicates=predicates,
64
+ allow_extra=allow_extra,
65
+ include_root=include_root,
66
+ use_cache=use_cache,
67
+ )
68
+ self.module.hyperparameters.update(
69
+ *predicates,
70
+ strict=strict,
71
+ allow_extra=allow_extra,
72
+ include_root=include_root,
73
+ validate=validate,
74
+ use_cache=use_cache,
75
+ **values,
76
+ )
77
+ ModelGraph.rebuild(self.module)
78
+ self.module._reset_contracts()
79
+ self._log_attribute_changes("update", changes)
80
+
81
+ def extend(
82
+ self,
83
+ *args: NodePredicate | NodeAttribute | Callable[[Node], bool] | SchemaField,
84
+ include_root: bool = True,
85
+ use_cache: bool = True,
86
+ ) -> None:
87
+ self.module._assert_mutation_allowed("extend")
88
+ parent, field_count = self._extend_target(*args, include_root=include_root, use_cache=use_cache)
89
+ self.module.hyperparameters.extend(*args, include_root=include_root, use_cache=use_cache)
90
+ ModelGraph.rebuild(self.module)
91
+ self.module._reset_contracts()
92
+ for field in parent.fields[-field_count:]:
93
+ self._log_node_mutation(
94
+ action="extend",
95
+ message="extended schema node",
96
+ node=field,
97
+ parent=parent,
98
+ )
99
+
100
+ def delete(
101
+ self,
102
+ *predicates: NodePredicate | NodeAttribute | Callable[[Node], bool],
103
+ include_root: bool = False,
104
+ use_cache: bool = True,
105
+ ) -> None:
106
+ self.module._assert_mutation_allowed("delete")
107
+ roots = self._delete_roots(*predicates, include_root=include_root, use_cache=use_cache)
108
+ self.module.hyperparameters.delete(*predicates, include_root=include_root, use_cache=use_cache)
109
+ ModelGraph.rebuild(self.module)
110
+ self.module._reset_contracts()
111
+ for node in roots:
112
+ self._log_node_mutation(
113
+ action="delete",
114
+ message="deleted schema node",
115
+ node=node,
116
+ descendants=len(getattr(node, "descendants", ())),
117
+ )
118
+
119
+ def reset(
120
+ self,
121
+ *predicates: NodePredicate | NodeAttribute | Callable[[Node], bool],
122
+ include_root: bool = True,
123
+ use_cache: bool = True,
124
+ descendants: bool = False,
125
+ ) -> None:
126
+ self.module._assert_mutation_allowed("reset")
127
+ selected = self.module.hyperparameters.select(
128
+ *predicates,
129
+ include_root=include_root,
130
+ use_cache=use_cache,
131
+ )
132
+ if not selected:
133
+ raise ValueError("reset matched no nodes")
134
+
135
+ nodes = self._runtime_reset_nodes(selected, descendants=descendants)
136
+ ModelGraph.reset_selected(self.module, selected, descendants=descendants)
137
+ self.module._reset_contracts()
138
+ for node in nodes:
139
+ self._log_node_mutation(
140
+ action="reset",
141
+ message="reset runtime node",
142
+ node=node,
143
+ descendants=descendants,
144
+ )
145
+
146
+ @contextmanager
147
+ def override(
148
+ self,
149
+ *predicates: NodePredicate | NodeAttribute | Callable[[Node], bool],
150
+ strict: bool = True,
151
+ allow_extra: bool = False,
152
+ include_root: bool = True,
153
+ validate: bool = True,
154
+ use_cache: bool = False,
155
+ **values: Any,
156
+ ) -> Iterator[None]:
157
+ self.module._assert_mutation_allowed("override")
158
+ values = self.module.hyperparameters.update_values(values)
159
+ changes = self._attribute_changes(
160
+ values=values,
161
+ predicates=predicates,
162
+ allow_extra=allow_extra,
163
+ include_root=include_root,
164
+ use_cache=use_cache,
165
+ )
166
+ entered = False
167
+ try:
168
+ with self.module.hyperparameters.override(
169
+ *predicates,
170
+ strict=strict,
171
+ allow_extra=allow_extra,
172
+ include_root=include_root,
173
+ validate=validate,
174
+ use_cache=use_cache,
175
+ **values,
176
+ ):
177
+ entered = True
178
+ ModelGraph.rebuild(self.module)
179
+ self.module._reset_contracts()
180
+ self._log_attribute_changes("override", changes)
181
+ yield
182
+ finally:
183
+ ModelGraph.rebuild(self.module)
184
+ self.module._reset_contracts()
185
+ if entered:
186
+ self._log_attribute_changes("override_restore", changes, restored=True)
187
+
188
+ def _attribute_changes(
189
+ self,
190
+ *,
191
+ values: dict[str, Any],
192
+ predicates: tuple[NodePredicate | NodeAttribute | Callable[[Node], bool], ...],
193
+ allow_extra: bool,
194
+ include_root: bool,
195
+ use_cache: bool,
196
+ ) -> list[AttributeChange]:
197
+ nodes = self.module.hyperparameters.select(*predicates, include_root=include_root, use_cache=use_cache)
198
+ changes: list[AttributeChange] = []
199
+ for node in nodes:
200
+ can_apply_extra = allow_extra and getattr(type(node), "model_config", {}).get("extra") == "allow"
201
+ for name in values:
202
+ if not (_has_node_attribute(node, name) or can_apply_extra):
203
+ continue
204
+
205
+ changes.append(
206
+ AttributeChange(
207
+ node=node,
208
+ name=name,
209
+ original=getattr(node, name, _MISSING),
210
+ definition_attribute=_is_definition_attribute(node, name),
211
+ )
212
+ )
213
+
214
+ return changes
215
+
216
+ def _extend_target(
217
+ self,
218
+ *args: NodePredicate | NodeAttribute | Callable[[Node], bool] | SchemaField,
219
+ include_root: bool,
220
+ use_cache: bool,
221
+ ) -> tuple[Array, int]:
222
+ predicates: list[NodePredicate | NodeAttribute | Callable[[Node], bool]] = []
223
+ field_count = 0
224
+ reading_fields = False
225
+
226
+ for item in args:
227
+ if isinstance(item, (Array, Leaf)):
228
+ reading_fields = True
229
+ field_count += 1
230
+ continue
231
+
232
+ if reading_fields:
233
+ raise TypeError("extend predicates must come before new schema fields")
234
+
235
+ predicates.append(item)
236
+
237
+ if field_count == 0:
238
+ raise ValueError("extend requires at least one schema field")
239
+
240
+ candidates = [
241
+ node
242
+ for node in self.module.hyperparameters.select(*predicates, include_root=include_root, use_cache=use_cache)
243
+ if isinstance(node, Array)
244
+ ]
245
+ if len(candidates) != 1:
246
+ raise ValueError(f"extend requires exactly one matching array node, found {len(candidates)}")
247
+
248
+ return candidates[0], field_count
249
+
250
+ def _delete_roots(
251
+ self,
252
+ *predicates: NodePredicate | NodeAttribute | Callable[[Node], bool],
253
+ include_root: bool,
254
+ use_cache: bool,
255
+ ) -> list[Node]:
256
+ selected = self.module.hyperparameters.select(*predicates, include_root=include_root, use_cache=use_cache)
257
+ selected_ids = {id(node) for node in selected}
258
+ return [
259
+ node
260
+ for node in selected
261
+ if not any(
262
+ id(ancestor) in selected_ids
263
+ for ancestor in getattr(node, "ancestors", ())
264
+ if ancestor is not self.module.hyperparameters
265
+ )
266
+ ]
267
+
268
+ def _runtime_reset_nodes(self, selected: list[Node], *, descendants: bool) -> list[Node]:
269
+ nodes: dict[str, Node] = {}
270
+ for node in selected:
271
+ if node.address in self.module.nodes:
272
+ nodes[str(node.address)] = node
273
+
274
+ if descendants:
275
+ for descendant in getattr(node, "descendants", ()):
276
+ if descendant.address in self.module.nodes:
277
+ nodes[str(descendant.address)] = descendant
278
+
279
+ return list(nodes.values())
280
+
281
+ def _log_attribute_changes(self, action: str, changes: list[AttributeChange], *, restored: bool = False) -> None:
282
+ for change in changes:
283
+ value = change.original if restored else getattr(change.node, change.name, _MISSING)
284
+ logger.bind(
285
+ component="schema_mutation",
286
+ action=action,
287
+ address=str(change.node.address),
288
+ node_type=change.node.type,
289
+ attribute=change.name,
290
+ definition_attribute=change.definition_attribute,
291
+ value=_format_log_value(value),
292
+ previous_value=_format_log_value(change.original),
293
+ ).info("restored schema node attribute" if restored else "mutated schema node attribute")
294
+
295
+ def _log_node_mutation(self, *, action: str, message: str, node: Node, **kwargs: Any) -> None:
296
+ extra = {key: str(value.address) if isinstance(value, Node) else value for key, value in kwargs.items()}
297
+ logger.bind(
298
+ component="schema_mutation",
299
+ action=action,
300
+ address=str(node.address),
301
+ node_type=node.type,
302
+ attribute=None,
303
+ definition_attribute=None,
304
+ **extra,
305
+ ).info(message)
306
+
307
+
308
+ def _has_node_attribute(node: Node, name: str) -> bool:
309
+ fields = getattr(type(node), "model_fields", {})
310
+ extra = getattr(node, "model_extra", None) or {}
311
+ return name in fields or name in extra or hasattr(node, name)
312
+
313
+
314
+ def _is_definition_attribute(node: Node, name: str) -> bool:
315
+ return name in getattr(type(node), "model_fields", {})
316
+
317
+
318
+ def _format_log_value(value: Any) -> str:
319
+ if value is _MISSING:
320
+ return "<missing>"
321
+
322
+ text = repr(value)
323
+ return text if len(text) <= 160 else f"{text[:157]}..."
@@ -18,10 +18,10 @@ from tensordict import TensorDict
18
18
  from json2vec.architecture.checkpoint import CheckpointState
19
19
  from json2vec.architecture.contracts import ContractScheduler
20
20
  from json2vec.architecture.graph import ModelGraph
21
+ from json2vec.architecture.mutations import SchemaEditor
21
22
  from json2vec.architecture.plot import PlotMode
22
23
  from json2vec.architecture.runtime import EvaluationResult, ModelRuntime, Postprocessor, PreprocessFn, step
23
- from json2vec.architecture.schema_editor import SchemaEditor
24
- from json2vec.data.datasets.base import EncodedBatch
24
+ from json2vec.data.datasets.base import EncodedBatch, EncodedInput
25
25
  from json2vec.structs.enums import AttentionMode, Strata
26
26
  from json2vec.structs.experiment import (
27
27
  Hyperparameters,
@@ -520,24 +520,38 @@ class Model(lit.LightningModule):
520
520
  ) -> tuple[dict[Address, dict[str, Any]], dict[Address, dict[str, Any]]]:
521
521
  return ModelRuntime.write(self, predictions)
522
522
 
523
+ @immutable("inference")
524
+ def encode(
525
+ self,
526
+ batch: EncodedBatch | list[dict[str, Any]],
527
+ preprocess: PreprocessFn | None = None,
528
+ strata: Strata | str = Strata.predict,
529
+ ) -> EncodedInput:
530
+ """Return encoded tensorfield inputs for raw or processed observations."""
531
+ return ModelRuntime.encode(
532
+ self,
533
+ batch=batch,
534
+ preprocess=preprocess,
535
+ strata=strata,
536
+ )
537
+
523
538
  @immutable("inference")
524
539
  def evaluate(
525
540
  self,
526
541
  batch: EncodedBatch | list[dict[str, Any]],
527
542
  preprocess: PreprocessFn | None = None,
528
543
  postprocess: Postprocessor | None = None,
529
- ) -> tuple[dict[Address, dict[str, Any]], dict[Address, dict[str, Any]]]:
544
+ ) -> EvaluationResult:
530
545
  """Run prediction and embedding for encoded or raw observations.
531
546
 
532
547
  If `preprocess` is omitted, raw records are encoded unchanged.
533
548
  """
534
- result: EvaluationResult = ModelRuntime.evaluate(
549
+ return ModelRuntime.evaluate(
535
550
  self,
536
551
  batch=batch,
537
552
  preprocess=preprocess,
538
553
  postprocess=postprocess,
539
554
  )
540
- return result.as_tuple()
541
555
 
542
556
  def predict(
543
557
  self,
@@ -546,12 +560,14 @@ class Model(lit.LightningModule):
546
560
  postprocess: Postprocessor | None = None,
547
561
  ) -> dict[Address, dict[str, Any]]:
548
562
  """Return typed predictions for a raw or encoded batch."""
549
- supervised, _ = self.evaluate(
563
+
564
+ result = self.evaluate(
550
565
  batch=batch,
551
566
  preprocess=preprocess,
552
567
  postprocess=postprocess,
553
568
  )
554
- return supervised
569
+
570
+ return result.predictions
555
571
 
556
572
  def embed(
557
573
  self,
@@ -560,12 +576,12 @@ class Model(lit.LightningModule):
560
576
  postprocess: Postprocessor | None = None,
561
577
  ) -> dict[Address, dict[str, Any]]:
562
578
  """Return configured embeddings for a raw or encoded batch."""
563
- _, embeddings = self.evaluate(
579
+ result = self.evaluate(
564
580
  batch=batch,
565
581
  preprocess=preprocess,
566
582
  postprocess=postprocess,
567
583
  )
568
- return embeddings
584
+ return result.embeddings
569
585
 
570
586
  training_step = partialmethod(step, strata=Strata.train)
571
587
  validation_step = partialmethod(step, strata=Strata.validate)
@@ -14,7 +14,8 @@ from tensordict import TensorDict
14
14
  from json2vec.architecture.contracts import sanitize
15
15
  from json2vec.architecture.encoder import ArrayEncoder
16
16
  from json2vec.architecture.node import NodeModule
17
- from json2vec.data.datasets.base import EncodedBatch
17
+ from json2vec.data.datasets.base import EncodedBatch, EncodedInput
18
+ from json2vec.data.iterables import encode
18
19
  from json2vec.structs.enums import Metric, Strata, TensorKey
19
20
  from json2vec.structs.packages import Embedding, Parcel, Prediction
20
21
  from json2vec.structs.tree import Address
@@ -50,9 +51,6 @@ class EvaluationResult:
50
51
  predictions: dict[Address, dict[str, Any]]
51
52
  embeddings: dict[Address, dict[str, Any]]
52
53
 
53
- def as_tuple(self) -> tuple[dict[Address, dict[str, Any]], dict[Address, dict[str, Any]]]:
54
- return self.predictions, self.embeddings
55
-
56
54
 
57
55
  class ModelRuntime:
58
56
  """Own runtime behavior that depends on an already-built model graph."""
@@ -182,16 +180,13 @@ class ModelRuntime:
182
180
  return supervised, embeddings
183
181
 
184
182
  @staticmethod
185
- def evaluate(
183
+ def encode(
186
184
  module: "Model",
187
185
  batch: EncodedBatch | list[dict[str, Any]],
188
186
  preprocess: PreprocessFn | None = None,
189
- postprocess: Postprocessor | None = None,
190
- ) -> EvaluationResult:
191
- from json2vec.data.iterables import encode
192
-
193
- was_training = module.training
194
- raw_batch = batch
187
+ strata: Strata | str = Strata.predict,
188
+ ) -> EncodedInput:
189
+ strata = Strata.normalize(strata)
195
190
 
196
191
  if preprocess is not None:
197
192
  observations: EncodedBatch = []
@@ -206,13 +201,24 @@ class ModelRuntime:
206
201
  elif batch and isinstance(batch[0], dict):
207
202
  batch = [[request] for request in cast(list[dict[str, Any]], batch)]
208
203
 
209
- inputs = encode(
204
+ return encode(
210
205
  batch=cast(EncodedBatch, batch),
211
206
  hyperparameters=module.hyperparameters,
212
- strata=Strata.predict,
207
+ strata=strata,
213
208
  interprocess_encoding_context=module.interprocess_encoding_context,
214
209
  )
215
210
 
211
+ @staticmethod
212
+ def evaluate(
213
+ module: "Model",
214
+ batch: EncodedBatch | list[dict[str, Any]],
215
+ preprocess: PreprocessFn | None = None,
216
+ postprocess: Postprocessor | None = None,
217
+ ) -> EvaluationResult:
218
+ was_training = module.training
219
+ raw_batch = batch
220
+ inputs = ModelRuntime.encode(module=module, batch=batch, preprocess=preprocess, strata=Strata.predict)
221
+
216
222
  module.eval()
217
223
  try:
218
224
  with torch.inference_mode():
@@ -226,7 +232,7 @@ class ModelRuntime:
226
232
  if postprocess is not None:
227
233
  context = {
228
234
  "batch": raw_batch,
229
- "observations": batch,
235
+ "observations": inputs[TensorKey.metadata],
230
236
  "input": inputs,
231
237
  TensorKey.metadata: inputs[TensorKey.metadata],
232
238
  }
@@ -65,7 +65,7 @@ class Hyperparameters(Node):
65
65
  @classmethod
66
66
  def update_values(cls, values: Mapping[str, Any]) -> dict[str, Any]:
67
67
  normalized = dict(values)
68
- target = normalized.pop("target", None)
68
+ target = normalized.get("target", None)
69
69
 
70
70
  if target is None:
71
71
  return normalized
@@ -76,11 +76,9 @@ class Hyperparameters(Node):
76
76
  if target:
77
77
  if normalized.get("p_prune") not in (None, 1.0):
78
78
  raise ValueError("target=True is shorthand for p_prune=1.0")
79
- normalized["p_prune"] = 1.0
80
79
  else:
81
- if normalized.get("p_prune") is not None:
80
+ if "p_prune" in normalized and normalized["p_prune"] is not None:
82
81
  raise ValueError("target=False is shorthand for p_prune=None")
83
- normalized["p_prune"] = None
84
82
 
85
83
  return normalized
86
84
 
@@ -215,7 +213,7 @@ class Hyperparameters(Node):
215
213
  @property
216
214
  def target(self) -> list[Address]:
217
215
  role = NodePredicate(
218
- func=lambda node: isinstance(node, Leaf) and node.active and getattr(node, "p_prune", 0.0) == 1.0,
216
+ func=lambda node: isinstance(node, Leaf) and node.active and node.target,
219
217
  key=("role", "target"),
220
218
  )
221
219
  return [Address(str(node.address)) for node in self.select(role)]
@@ -348,6 +346,8 @@ class Hyperparameters(Node):
348
346
 
349
347
  if validate and applicable_values:
350
348
  payload = node.model_dump(mode="python", round_trip=True)
349
+ if "target" in applicable_values and "p_prune" not in applicable_values:
350
+ payload.pop("p_prune", None)
351
351
  payload.update(applicable_values)
352
352
  validated = type(node).model_validate(payload)
353
353
  applicable_values = {name: getattr(validated, name) for name in applicable_values}
@@ -501,7 +501,12 @@ class Hyperparameters(Node):
501
501
  nodes = self.select(*predicates, include_root=include_root, use_cache=use_cache)
502
502
  normalized_values = self.update_values(values)
503
503
  snapshot = [
504
- (node, name, getattr(node, name, _MISSING), name in getattr(node, "model_fields_set", set()))
504
+ (
505
+ node,
506
+ "p_prune" if name == "target" else name,
507
+ getattr(node, "p_prune" if name == "target" else name, _MISSING),
508
+ ("p_prune" if name == "target" else name) in getattr(node, "model_fields_set", set()),
509
+ )
505
510
  for node in nodes
506
511
  for name in normalized_values
507
512
  if _has_model_attribute(node, name)
@@ -140,7 +140,7 @@ class NodeAttribute(pydantic.BaseModel):
140
140
  if self.name == "descendants":
141
141
  return tuple(str(child.address) for child in getattr(node, "descendants", ()))
142
142
  if self.name == "target":
143
- return isinstance(node, Leaf) and node.active and getattr(node, "p_prune", None) == 1.0
143
+ return isinstance(node, Leaf) and node.active and node.target
144
144
 
145
145
  extra = getattr(node, "model_extra", None) or {}
146
146
  if self.name in extra:
@@ -51,6 +51,17 @@ class Node(NodeMixin, pydantic.BaseModel):
51
51
  p_mask: Rate | None = None
52
52
  p_prune: PruneRate | None = None
53
53
 
54
+ @property
55
+ def target(self) -> bool:
56
+ return self.p_prune == 1.0
57
+
58
+ @target.setter
59
+ def target(self, value: bool) -> None:
60
+ if not isinstance(value, bool):
61
+ raise ValueError("target must be a boolean")
62
+
63
+ self.p_prune = 1.0 if value else None
64
+
54
65
  @classmethod
55
66
  def sanitize_name(cls, value: str) -> str:
56
67
  sanitized = re.sub(r"[^0-9A-Za-z_-]+", "_", value).strip("_")
@@ -78,6 +89,10 @@ class Node(NodeMixin, pydantic.BaseModel):
78
89
  if values.get("p_prune") not in (None, 1.0):
79
90
  raise ValueError("target=True is shorthand for p_prune=1.0")
80
91
  values["p_prune"] = 1.0
92
+ else:
93
+ if values.get("p_prune") is not None:
94
+ raise ValueError("target=False is shorthand for p_prune=None")
95
+ values["p_prune"] = None
81
96
 
82
97
  return values
83
98
 
@@ -1,8 +1,9 @@
1
1
  # ty: ignore[invalid-method-override,unknown-argument]
2
2
  from __future__ import annotations
3
3
 
4
+ from collections.abc import Mapping
4
5
  from functools import partial
5
- from typing import TYPE_CHECKING, Annotated, Literal, cast
6
+ from typing import TYPE_CHECKING, Annotated, Any, Literal, cast
6
7
 
7
8
  import numpy as np
8
9
  import pydantic
@@ -46,10 +47,17 @@ class Request(RequestBase):
46
47
 
47
48
  type: Literal["category"] = "category"
48
49
  max_vocab_size: Annotated[int, pydantic.Field(gt=0, default=10_000)] = 10000
49
- n_bands: Annotated[int, pydantic.Field(gt=0, default=8)] = 8
50
50
  p_unavailable: Annotated[float, pydantic.Field(ge=0.0, le=1.0, default=0.01)] = 0.01
51
51
  topk: list[int] | None = None
52
52
 
53
+ @pydantic.model_validator(mode="before")
54
+ @classmethod
55
+ def reject_removed_options(cls, data: Any) -> Any:
56
+ if isinstance(data, Mapping) and "n_bands" in data:
57
+ raise ValueError("Category does not support n_bands")
58
+
59
+ return data
60
+
53
61
  @pydantic.model_validator(mode="after")
54
62
  def check_topk(self):
55
63
  if self.topk is None:
@@ -1,6 +1,7 @@
1
1
  # ty: ignore[invalid-argument-type,invalid-assignment,unknown-argument,unresolved-attribute]
2
2
  from __future__ import annotations
3
3
 
4
+ import difflib
4
5
  import enum
5
6
  import math
6
7
  import re
@@ -74,6 +75,22 @@ class DatePart(enum.StrEnum):
74
75
  return cls.DEPTH[datepart]
75
76
 
76
77
 
78
+ def _normalize_datepart_key(value: str) -> str:
79
+ value = re.sub(r"(?<=[a-z0-9])(?=[A-Z])", "_", value.strip())
80
+ value = re.sub(r"[^0-9A-Za-z]+", "_", value)
81
+ return value.strip("_").casefold()
82
+
83
+
84
+ def _datepart_lookup() -> dict[str, DatePart]:
85
+ lookup: dict[str, DatePart] = {}
86
+ for datepart in DatePart:
87
+ normalized = _normalize_datepart_key(datepart.value)
88
+ lookup[normalized] = datepart
89
+ lookup[normalized.replace("_", "")] = datepart
90
+
91
+ return lookup
92
+
93
+
77
94
  @DatePart.day_of_month.register(depth=31)
78
95
  def _(arr: np.ndarray) -> np.ndarray:
79
96
  month_start = arr.astype("datetime64[M]")
@@ -133,6 +150,35 @@ class Request(RequestBase):
133
150
  dateparts: list[DatePart]
134
151
  pattern: Annotated[str | None, pydantic.Field(default=None)] = None
135
152
 
153
+ @pydantic.field_validator("dateparts", mode="before", check_fields=False)
154
+ @classmethod
155
+ def _coerce_dateparts(cls, value: Any) -> Any:
156
+ if not isinstance(value, (list, tuple)):
157
+ return value
158
+
159
+ lookup = _datepart_lookup()
160
+ canonical = [datepart.value for datepart in DatePart]
161
+ dateparts: list[DatePart] = []
162
+ for item in value:
163
+ if isinstance(item, DatePart):
164
+ dateparts.append(item)
165
+ continue
166
+
167
+ if not isinstance(item, str):
168
+ raise ValueError(f"datepart values must be strings, got {type(item).__name__}")
169
+
170
+ key = _normalize_datepart_key(item)
171
+ match = lookup.get(key) or lookup.get(key.replace("_", ""))
172
+ if match is not None:
173
+ dateparts.append(match)
174
+ continue
175
+
176
+ suggestions = difflib.get_close_matches(key, canonical, n=1)
177
+ suggestion = f"; did you mean '{suggestions[0]}'?" if suggestions else ""
178
+ raise ValueError(f"unknown datepart '{item}'{suggestion}")
179
+
180
+ return dateparts
181
+
136
182
  @pydantic.field_validator("dateparts", check_fields=False)
137
183
  @classmethod
138
184
  def check_dateparts(cls, v):
@@ -59,7 +59,7 @@ class Request(RequestBase):
59
59
  """Numeric scalar tensorfield request."""
60
60
 
61
61
  type: Literal["number"] = "number"
62
- jitter: Annotated[float, pydantic.Field(ge=0.0, lt=1.0, default=0.0)] = 0.0
62
+ jitter: Annotated[float, pydantic.Field(ge=0.0, default=0.0)] = 0.0
63
63
  n_bands: Annotated[int, pydantic.Field(gt=0, default=8)] = 8
64
64
  offset: Annotated[int, pydantic.Field(gt=0, default=4)] = 4
65
65
  alpha: Annotated[float | None, pydantic.Field(gt=0.0, lt=1.0, default=None)] = None
@@ -43,6 +43,7 @@ class Request(RequestBase):
43
43
  type: Literal["set"] = "set"
44
44
  max_vocab_size: Annotated[int, pydantic.Field(gt=0, default=10_000)] = 10_000
45
45
  p_unavailable: Annotated[float, pydantic.Field(ge=0.0, le=1.0, default=0.01)] = 0.01
46
+ threshold: Annotated[float | None, pydantic.Field(ge=0.0, le=1.0, default=None)] = None
46
47
 
47
48
 
48
49
  def _items(value: Any) -> Iterable[Any]:
@@ -373,6 +374,7 @@ def loss(
373
374
  @sets.register
374
375
  def write(module: Model, prediction: Prediction):
375
376
  node = module.nodes[prediction.address]
377
+ request: Request = module.hyperparameters.requests[prediction.address]
376
378
  state_logits: torch.Tensor = prediction.payload[TensorKey.state]
377
379
  content_logits: torch.Tensor = prediction.payload[TensorKey.content]
378
380
 
@@ -383,7 +385,22 @@ def write(module: Model, prediction: Prediction):
383
385
 
384
386
  vocab = node.embedder.vocab.snapshot()
385
387
  probabilities = content_logits[..., : len(vocab)].sigmoid().detach().float().cpu().numpy()
386
- content_payload = {str(label): probabilities[..., index] for index, label in enumerate(vocab)}
388
+ if request.threshold is None:
389
+ content_payload = {str(label): probabilities[..., index] for index, label in enumerate(vocab)}
390
+ else:
391
+ labels = np.asarray(vocab, dtype=object)
392
+
393
+ def pack_thresholded(values: np.ndarray) -> dict[str, float] | list:
394
+ if values.ndim == 1:
395
+ keep = values >= request.threshold
396
+ return {
397
+ str(label): float(probability)
398
+ for label, probability in zip(labels[keep].tolist(), values[keep].tolist())
399
+ }
400
+
401
+ return [pack_thresholded(values[index]) for index in range(values.shape[0])]
402
+
403
+ content_payload = pack_thresholded(probabilities)
387
404
 
388
405
  return {
389
406
  TensorKey.state.name: state_payload,
@@ -226,7 +226,11 @@ class Decoder(DecoderBase):
226
226
 
227
227
  request: Request = hyperparameters.requests[address]
228
228
 
229
- self.linear = torch.nn.Linear(
229
+ self.classification = torch.nn.Linear(
230
+ in_features=hyperparameters.d_model,
231
+ out_features=len(Tokens),
232
+ )
233
+ self.regression = torch.nn.Linear(
230
234
  in_features=hyperparameters.d_model,
231
235
  out_features=request.n_dim,
232
236
  )
@@ -235,7 +239,8 @@ class Decoder(DecoderBase):
235
239
  def decode(self, pooled: torch.Tensor) -> TensorDict[TensorKey, torch.Tensor]:
236
240
  return TensorDict(
237
241
  source={
238
- TensorKey.content: self.linear(pooled),
242
+ TensorKey.state: self.classification(pooled),
243
+ TensorKey.content: self.regression(pooled),
239
244
  }
240
245
  )
241
246
 
@@ -251,30 +256,67 @@ def loss(
251
256
  request: Request = module.hyperparameters.requests[address]
252
257
 
253
258
  trainable = batch.trainable.reshape(-1)
259
+ state_targets = batch.targets[TensorKey.state].reshape(-1)
260
+ state_inputs = prediction.payload[TensorKey.state].reshape(-1, len(Tokens))
261
+
262
+ output: torch.Tensor = module.track(
263
+ (address, strata, Metric.loss, TensorKey.state),
264
+ value=(
265
+ torch.nn.functional.cross_entropy(
266
+ input=state_inputs,
267
+ target=state_targets,
268
+ reduction="none",
269
+ )
270
+ .masked_select(trainable)
271
+ .mean()
272
+ ),
273
+ )
274
+
275
+ module.track(
276
+ (address, strata, Metric.accuracy, TensorKey.state),
277
+ value=state_inputs.argmax(dim=1).eq(state_targets).masked_select(trainable).float().mean(),
278
+ )
279
+
280
+ valued = trainable & state_targets.eq(Tokens.valued.value)
281
+ if not valued.any():
282
+ return output
283
+
254
284
  inputs = prediction.payload[TensorKey.content].reshape(-1, request.n_dim)
255
285
  targets = batch.targets[TensorKey.content].reshape(-1, request.n_dim)
256
286
  diff = inputs.subtract(targets)
257
287
 
258
- loss: torch.Tensor = module.track(
288
+ output += module.track(
259
289
  (address, strata, Metric.loss, TensorKey.content),
260
- value=request.objective.loss(inputs=inputs, targets=targets).masked_select(trainable).mean(),
290
+ value=request.objective.loss(inputs=inputs, targets=targets).masked_select(valued).mean(),
261
291
  )
262
292
 
263
293
  module.track(
264
294
  (address, strata, Metric.mae, TensorKey.content),
265
- value=diff.absolute().mean(dim=1).masked_select(trainable).mean(),
295
+ value=diff.absolute().mean(dim=1).masked_select(valued).mean(),
266
296
  )
267
297
 
268
298
  module.track(
269
299
  (address, strata, Metric.rmse, TensorKey.content),
270
- value=diff.square().mean(dim=1).sqrt().masked_select(trainable).mean(),
300
+ value=diff.square().mean(dim=1).sqrt().masked_select(valued).mean(),
271
301
  )
272
302
 
273
- return loss
303
+ return output
274
304
 
275
305
 
276
306
  @vector.register
277
307
  def write(module: Model, prediction: Prediction):
308
+ content: np.ndarray = prediction.payload[TensorKey.content].detach().float().cpu().numpy()
309
+ state_logits: torch.Tensor = prediction.payload[TensorKey.state]
310
+ tokens: np.ndarray = np.fromiter((token.name for token in Tokens), dtype=object, count=len(Tokens))
311
+ state_log_norm = state_logits.logsumexp(dim=-1, keepdim=True)
312
+ state_distribution = (state_logits - state_log_norm).exp().detach().float().cpu().numpy()
313
+ state_payload = {token: state_distribution[..., index] for index, token in enumerate(tokens.tolist())}
314
+
315
+ non_valued = state_logits.argmax(dim=-1).ne(Tokens.valued.value).detach().cpu().numpy()
316
+ content = content.copy()
317
+ content[non_valued] = 0.0
318
+
278
319
  return {
279
- TensorKey.content.name: prediction.payload[TensorKey.content].detach().float().cpu().numpy(),
320
+ TensorKey.state.name: state_payload,
321
+ TensorKey.content.name: content,
280
322
  }
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: json2vec
3
- Version: 0.4.4
3
+ Version: 0.4.5
4
4
  Summary: {...} -> [*]
5
5
  License-Expression: Apache-2.0
6
6
  Requires-Python: >=3.12
@@ -15,13 +15,13 @@ src/json2vec/architecture/checkpoint.py
15
15
  src/json2vec/architecture/contracts.py
16
16
  src/json2vec/architecture/encoder.py
17
17
  src/json2vec/architecture/graph.py
18
+ src/json2vec/architecture/mutations.py
18
19
  src/json2vec/architecture/node.py
19
20
  src/json2vec/architecture/plot.py
20
21
  src/json2vec/architecture/pool.py
21
22
  src/json2vec/architecture/root.py
22
23
  src/json2vec/architecture/rotary.py
23
24
  src/json2vec/architecture/runtime.py
24
- src/json2vec/architecture/schema_editor.py
25
25
  src/json2vec/data/__init__.py
26
26
  src/json2vec/data/iterables.py
27
27
  src/json2vec/data/processing.py
@@ -1,126 +0,0 @@
1
- """Model-facing schema mutation orchestration."""
2
-
3
- from __future__ import annotations
4
-
5
- from collections.abc import Callable, Iterator
6
- from contextlib import contextmanager
7
- from typing import TYPE_CHECKING, Any
8
-
9
- from json2vec.architecture.graph import ModelGraph
10
- from json2vec.structs.experiment import NodeAttribute, NodePredicate, SchemaField
11
- from json2vec.structs.tree import Node
12
-
13
- if TYPE_CHECKING:
14
- from json2vec.architecture.root import Model
15
-
16
-
17
- class SchemaEditor:
18
- """Coordinate schema mutations with runtime graph rebuilds."""
19
-
20
- def __init__(self, module: "Model") -> None:
21
- self.module = module
22
-
23
- def select(
24
- self,
25
- *predicates: NodePredicate | NodeAttribute | Callable[[Node], bool],
26
- include_root: bool = True,
27
- use_cache: bool = True,
28
- ) -> list[Node]:
29
- return self.module.hyperparameters.select(
30
- *predicates,
31
- include_root=include_root,
32
- use_cache=use_cache,
33
- )
34
-
35
- def update(
36
- self,
37
- *predicates: NodePredicate | NodeAttribute | Callable[[Node], bool],
38
- strict: bool = True,
39
- allow_extra: bool = False,
40
- include_root: bool = True,
41
- validate: bool = True,
42
- use_cache: bool = False,
43
- **values: Any,
44
- ) -> None:
45
- self.module._assert_mutation_allowed("update")
46
- self.module.hyperparameters.update(
47
- *predicates,
48
- strict=strict,
49
- allow_extra=allow_extra,
50
- include_root=include_root,
51
- validate=validate,
52
- use_cache=use_cache,
53
- **values,
54
- )
55
- ModelGraph.rebuild(self.module)
56
- self.module._reset_contracts()
57
-
58
- def extend(
59
- self,
60
- *args: NodePredicate | NodeAttribute | Callable[[Node], bool] | SchemaField,
61
- include_root: bool = True,
62
- use_cache: bool = True,
63
- ) -> None:
64
- self.module._assert_mutation_allowed("extend")
65
- self.module.hyperparameters.extend(*args, include_root=include_root, use_cache=use_cache)
66
- ModelGraph.rebuild(self.module)
67
- self.module._reset_contracts()
68
-
69
- def delete(
70
- self,
71
- *predicates: NodePredicate | NodeAttribute | Callable[[Node], bool],
72
- include_root: bool = False,
73
- use_cache: bool = True,
74
- ) -> None:
75
- self.module._assert_mutation_allowed("delete")
76
- self.module.hyperparameters.delete(*predicates, include_root=include_root, use_cache=use_cache)
77
- ModelGraph.rebuild(self.module)
78
- self.module._reset_contracts()
79
-
80
- def reset(
81
- self,
82
- *predicates: NodePredicate | NodeAttribute | Callable[[Node], bool],
83
- include_root: bool = True,
84
- use_cache: bool = True,
85
- descendants: bool = False,
86
- ) -> None:
87
- self.module._assert_mutation_allowed("reset")
88
- selected = self.module.hyperparameters.select(
89
- *predicates,
90
- include_root=include_root,
91
- use_cache=use_cache,
92
- )
93
- if not selected:
94
- raise ValueError("reset matched no nodes")
95
-
96
- ModelGraph.reset_selected(self.module, selected, descendants=descendants)
97
- self.module._reset_contracts()
98
-
99
- @contextmanager
100
- def override(
101
- self,
102
- *predicates: NodePredicate | NodeAttribute | Callable[[Node], bool],
103
- strict: bool = True,
104
- allow_extra: bool = False,
105
- include_root: bool = True,
106
- validate: bool = True,
107
- use_cache: bool = False,
108
- **values: Any,
109
- ) -> Iterator[None]:
110
- self.module._assert_mutation_allowed("override")
111
- try:
112
- with self.module.hyperparameters.override(
113
- *predicates,
114
- strict=strict,
115
- allow_extra=allow_extra,
116
- include_root=include_root,
117
- validate=validate,
118
- use_cache=use_cache,
119
- **values,
120
- ):
121
- ModelGraph.rebuild(self.module)
122
- self.module._reset_contracts()
123
- yield
124
- finally:
125
- ModelGraph.rebuild(self.module)
126
- self.module._reset_contracts()
File without changes
File without changes
File without changes
File without changes