json2vec 0.4.4__tar.gz → 0.4.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {json2vec-0.4.4/src/json2vec.egg-info → json2vec-0.4.5}/PKG-INFO +1 -1
- {json2vec-0.4.4 → json2vec-0.4.5}/pyproject.toml +1 -1
- {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/architecture/contracts.py +3 -3
- {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/architecture/graph.py +18 -5
- json2vec-0.4.5/src/json2vec/architecture/mutations.py +323 -0
- {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/architecture/root.py +25 -9
- {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/architecture/runtime.py +20 -14
- {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/structs/experiment.py +11 -6
- {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/structs/selectors.py +1 -1
- {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/structs/tree.py +15 -0
- {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/tensorfields/extensions/category.py +10 -2
- {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/tensorfields/extensions/dateparts.py +46 -0
- {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/tensorfields/extensions/number.py +1 -1
- {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/tensorfields/extensions/set.py +18 -1
- {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/tensorfields/extensions/vector.py +50 -8
- {json2vec-0.4.4 → json2vec-0.4.5/src/json2vec.egg-info}/PKG-INFO +1 -1
- {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec.egg-info/SOURCES.txt +1 -1
- json2vec-0.4.4/src/json2vec/architecture/schema_editor.py +0 -126
- {json2vec-0.4.4 → json2vec-0.4.5}/LICENSE +0 -0
- {json2vec-0.4.4 → json2vec-0.4.5}/NOTICE +0 -0
- {json2vec-0.4.4 → json2vec-0.4.5}/README.md +0 -0
- {json2vec-0.4.4 → json2vec-0.4.5}/setup.cfg +0 -0
- {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/__init__.py +0 -0
- {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/architecture/__init__.py +0 -0
- {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/architecture/attention.py +0 -0
- {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/architecture/checkpoint.py +0 -0
- {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/architecture/encoder.py +0 -0
- {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/architecture/node.py +0 -0
- {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/architecture/plot.py +0 -0
- {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/architecture/pool.py +0 -0
- {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/architecture/rotary.py +0 -0
- {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/data/__init__.py +0 -0
- {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/data/datasets/__init__.py +0 -0
- {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/data/datasets/base.py +0 -0
- {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/data/datasets/polars.py +0 -0
- {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/data/datasets/streaming.py +0 -0
- {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/data/iterables.py +0 -0
- {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/data/processing.py +0 -0
- {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/distributed.py +0 -0
- {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/inference/__init__.py +0 -0
- {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/inference/callback.py +0 -0
- {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/inference/deployment.py +0 -0
- {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/logging/__init__.py +0 -0
- {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/logging/config.py +0 -0
- {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/logging/epoch.py +0 -0
- {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/logging/throughput.py +0 -0
- {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/preprocessors/__init__.py +0 -0
- {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/preprocessors/base.py +0 -0
- {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/preprocessors/extensions/__init__.py +0 -0
- {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/preprocessors/spec.py +0 -0
- {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/structs/__init__.py +0 -0
- {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/structs/enums.py +0 -0
- {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/structs/packages.py +0 -0
- {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/structs/structure.py +0 -0
- {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/tensorfields/__init__.py +0 -0
- {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/tensorfields/base.py +0 -0
- {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/tensorfields/extensions/__init__.py +0 -0
- {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/tensorfields/extensions/entity.py +0 -0
- {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/tensorfields/extensions/text.py +0 -0
- {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/tensorfields/shared/__init__.py +0 -0
- {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/tensorfields/shared/counter.py +0 -0
- {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/tensorfields/shared/vocabulary.py +0 -0
- {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec/tensorfields/spec.py +0 -0
- {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec.egg-info/dependency_links.txt +0 -0
- {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec.egg-info/requires.txt +0 -0
- {json2vec-0.4.4 → json2vec-0.4.5}/src/json2vec.egg-info/top_level.txt +0 -0
- {json2vec-0.4.4 → json2vec-0.4.5}/tests/test_callbacks.py +0 -0
- {json2vec-0.4.4 → json2vec-0.4.5}/tests/test_public_api.py +0 -0
|
@@ -93,7 +93,7 @@ def sanitize(
|
|
|
93
93
|
require_core_tensors(module, address, tensorfield)
|
|
94
94
|
require_tensor_devices(module, address, tensorfield)
|
|
95
95
|
require_target_contract(module, address, tensorfield, strata=normalized)
|
|
96
|
-
require_mask_contract(module, address, tensorfield)
|
|
96
|
+
require_mask_contract(module, address, tensorfield, strata=normalized)
|
|
97
97
|
|
|
98
98
|
|
|
99
99
|
def is_backoff_index(index: int, *, periodic_interval: int) -> bool:
|
|
@@ -292,7 +292,7 @@ def require_tensor_devices(module: "Model", address: Address, tensorfield: Tenso
|
|
|
292
292
|
)
|
|
293
293
|
|
|
294
294
|
|
|
295
|
-
def require_mask_contract(module: "Model", address: Address, tensorfield: TensorFieldBase) -> None:
|
|
295
|
+
def require_mask_contract(module: "Model", address: Address, tensorfield: TensorFieldBase, *, strata: Strata) -> None:
|
|
296
296
|
state = tensorfield.state
|
|
297
297
|
trainable = tensorfield.trainable
|
|
298
298
|
is_masked = state.eq(Tokens.masked.value)
|
|
@@ -301,7 +301,7 @@ def require_mask_contract(module: "Model", address: Address, tensorfield: Tensor
|
|
|
301
301
|
if trainable.any() and not state.masked_select(trainable).eq(Tokens.masked.value).all():
|
|
302
302
|
raise ForwardContractError(f"forward input '{address}' trainable positions must have masked state")
|
|
303
303
|
|
|
304
|
-
if not is_target and (is_masked & ~trainable).any():
|
|
304
|
+
if strata != Strata.predict and not is_target and (is_masked & ~trainable).any():
|
|
305
305
|
raise ForwardContractError(f"forward input '{address}' has masked state where trainable is false")
|
|
306
306
|
|
|
307
307
|
if not trainable.any():
|
|
@@ -8,6 +8,8 @@ from typing import TYPE_CHECKING
|
|
|
8
8
|
import torch
|
|
9
9
|
|
|
10
10
|
from json2vec.architecture.node import NodeModule
|
|
11
|
+
from json2vec.data.datasets.base import EncodedInput
|
|
12
|
+
from json2vec.structs.enums import Strata
|
|
11
13
|
from json2vec.structs.experiment import Hyperparameters
|
|
12
14
|
from json2vec.structs.tree import Address, Node
|
|
13
15
|
|
|
@@ -19,9 +21,19 @@ class ModelGraph:
|
|
|
19
21
|
"""Build and rebuild runtime modules from schema hyperparameters."""
|
|
20
22
|
|
|
21
23
|
@staticmethod
|
|
22
|
-
def
|
|
24
|
+
def example_forward_kwargs(hyperparameters: Hyperparameters, batch_size: int) -> dict[str, EncodedInput | Strata]:
|
|
23
25
|
from json2vec.data.iterables import mock
|
|
24
26
|
|
|
27
|
+
return {
|
|
28
|
+
"inputs": mock(hyperparameters=hyperparameters, batch_size=batch_size),
|
|
29
|
+
"strata": Strata.predict,
|
|
30
|
+
}
|
|
31
|
+
|
|
32
|
+
@staticmethod
|
|
33
|
+
def build(
|
|
34
|
+
hyperparameters: Hyperparameters,
|
|
35
|
+
batch_size: int,
|
|
36
|
+
) -> tuple[torch.nn.ModuleDict, dict[str, EncodedInput | Strata]]:
|
|
25
37
|
nodes: torch.nn.ModuleDict[str, NodeModule] = torch.nn.ModuleDict()
|
|
26
38
|
|
|
27
39
|
for address in hyperparameters.requests | hyperparameters.arrays:
|
|
@@ -31,7 +43,7 @@ class ModelGraph:
|
|
|
31
43
|
batch_size=batch_size,
|
|
32
44
|
)
|
|
33
45
|
|
|
34
|
-
return nodes,
|
|
46
|
+
return nodes, ModelGraph.example_forward_kwargs(hyperparameters=hyperparameters, batch_size=batch_size)
|
|
35
47
|
|
|
36
48
|
@staticmethod
|
|
37
49
|
def install(module: "Model") -> None:
|
|
@@ -72,8 +84,6 @@ class ModelGraph:
|
|
|
72
84
|
|
|
73
85
|
@staticmethod
|
|
74
86
|
def reset_selected(module: "Model", selected: list[Node], *, descendants: bool = False) -> None:
|
|
75
|
-
from json2vec.data.iterables import mock
|
|
76
|
-
|
|
77
87
|
selected_by_address: dict[Address, Node] = {}
|
|
78
88
|
for node in selected:
|
|
79
89
|
if node.address in module.nodes:
|
|
@@ -94,7 +104,10 @@ class ModelGraph:
|
|
|
94
104
|
batch_size=module.batch_size,
|
|
95
105
|
)
|
|
96
106
|
|
|
97
|
-
module.example_input_array =
|
|
107
|
+
module.example_input_array = ModelGraph.example_forward_kwargs(
|
|
108
|
+
hyperparameters=module.hyperparameters,
|
|
109
|
+
batch_size=module.batch_size,
|
|
110
|
+
)
|
|
98
111
|
device = module.device
|
|
99
112
|
if isinstance(device, torch.device):
|
|
100
113
|
module.to(device=device)
|
|
@@ -0,0 +1,323 @@
|
|
|
1
|
+
"""Model-facing schema mutation orchestration."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from collections.abc import Callable, Iterator
|
|
6
|
+
from contextlib import contextmanager
|
|
7
|
+
from dataclasses import dataclass
|
|
8
|
+
from typing import TYPE_CHECKING, Any
|
|
9
|
+
|
|
10
|
+
from loguru import logger
|
|
11
|
+
|
|
12
|
+
from json2vec.architecture.graph import ModelGraph
|
|
13
|
+
from json2vec.structs.experiment import NodeAttribute, NodePredicate, SchemaField
|
|
14
|
+
from json2vec.structs.structure import Array
|
|
15
|
+
from json2vec.structs.tree import Leaf, Node
|
|
16
|
+
|
|
17
|
+
if TYPE_CHECKING:
|
|
18
|
+
from json2vec.architecture.root import Model
|
|
19
|
+
|
|
20
|
+
_MISSING = object()
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@dataclass(frozen=True)
|
|
24
|
+
class AttributeChange:
|
|
25
|
+
node: Node
|
|
26
|
+
name: str
|
|
27
|
+
original: Any
|
|
28
|
+
definition_attribute: bool
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class SchemaEditor:
|
|
32
|
+
"""Coordinate schema mutations with runtime graph rebuilds."""
|
|
33
|
+
|
|
34
|
+
def __init__(self, module: "Model") -> None:
|
|
35
|
+
self.module = module
|
|
36
|
+
|
|
37
|
+
def select(
|
|
38
|
+
self,
|
|
39
|
+
*predicates: NodePredicate | NodeAttribute | Callable[[Node], bool],
|
|
40
|
+
include_root: bool = True,
|
|
41
|
+
use_cache: bool = True,
|
|
42
|
+
) -> list[Node]:
|
|
43
|
+
return self.module.hyperparameters.select(
|
|
44
|
+
*predicates,
|
|
45
|
+
include_root=include_root,
|
|
46
|
+
use_cache=use_cache,
|
|
47
|
+
)
|
|
48
|
+
|
|
49
|
+
def update(
|
|
50
|
+
self,
|
|
51
|
+
*predicates: NodePredicate | NodeAttribute | Callable[[Node], bool],
|
|
52
|
+
strict: bool = True,
|
|
53
|
+
allow_extra: bool = False,
|
|
54
|
+
include_root: bool = True,
|
|
55
|
+
validate: bool = True,
|
|
56
|
+
use_cache: bool = False,
|
|
57
|
+
**values: Any,
|
|
58
|
+
) -> None:
|
|
59
|
+
self.module._assert_mutation_allowed("update")
|
|
60
|
+
values = self.module.hyperparameters.update_values(values)
|
|
61
|
+
changes = self._attribute_changes(
|
|
62
|
+
values=values,
|
|
63
|
+
predicates=predicates,
|
|
64
|
+
allow_extra=allow_extra,
|
|
65
|
+
include_root=include_root,
|
|
66
|
+
use_cache=use_cache,
|
|
67
|
+
)
|
|
68
|
+
self.module.hyperparameters.update(
|
|
69
|
+
*predicates,
|
|
70
|
+
strict=strict,
|
|
71
|
+
allow_extra=allow_extra,
|
|
72
|
+
include_root=include_root,
|
|
73
|
+
validate=validate,
|
|
74
|
+
use_cache=use_cache,
|
|
75
|
+
**values,
|
|
76
|
+
)
|
|
77
|
+
ModelGraph.rebuild(self.module)
|
|
78
|
+
self.module._reset_contracts()
|
|
79
|
+
self._log_attribute_changes("update", changes)
|
|
80
|
+
|
|
81
|
+
def extend(
|
|
82
|
+
self,
|
|
83
|
+
*args: NodePredicate | NodeAttribute | Callable[[Node], bool] | SchemaField,
|
|
84
|
+
include_root: bool = True,
|
|
85
|
+
use_cache: bool = True,
|
|
86
|
+
) -> None:
|
|
87
|
+
self.module._assert_mutation_allowed("extend")
|
|
88
|
+
parent, field_count = self._extend_target(*args, include_root=include_root, use_cache=use_cache)
|
|
89
|
+
self.module.hyperparameters.extend(*args, include_root=include_root, use_cache=use_cache)
|
|
90
|
+
ModelGraph.rebuild(self.module)
|
|
91
|
+
self.module._reset_contracts()
|
|
92
|
+
for field in parent.fields[-field_count:]:
|
|
93
|
+
self._log_node_mutation(
|
|
94
|
+
action="extend",
|
|
95
|
+
message="extended schema node",
|
|
96
|
+
node=field,
|
|
97
|
+
parent=parent,
|
|
98
|
+
)
|
|
99
|
+
|
|
100
|
+
def delete(
|
|
101
|
+
self,
|
|
102
|
+
*predicates: NodePredicate | NodeAttribute | Callable[[Node], bool],
|
|
103
|
+
include_root: bool = False,
|
|
104
|
+
use_cache: bool = True,
|
|
105
|
+
) -> None:
|
|
106
|
+
self.module._assert_mutation_allowed("delete")
|
|
107
|
+
roots = self._delete_roots(*predicates, include_root=include_root, use_cache=use_cache)
|
|
108
|
+
self.module.hyperparameters.delete(*predicates, include_root=include_root, use_cache=use_cache)
|
|
109
|
+
ModelGraph.rebuild(self.module)
|
|
110
|
+
self.module._reset_contracts()
|
|
111
|
+
for node in roots:
|
|
112
|
+
self._log_node_mutation(
|
|
113
|
+
action="delete",
|
|
114
|
+
message="deleted schema node",
|
|
115
|
+
node=node,
|
|
116
|
+
descendants=len(getattr(node, "descendants", ())),
|
|
117
|
+
)
|
|
118
|
+
|
|
119
|
+
def reset(
|
|
120
|
+
self,
|
|
121
|
+
*predicates: NodePredicate | NodeAttribute | Callable[[Node], bool],
|
|
122
|
+
include_root: bool = True,
|
|
123
|
+
use_cache: bool = True,
|
|
124
|
+
descendants: bool = False,
|
|
125
|
+
) -> None:
|
|
126
|
+
self.module._assert_mutation_allowed("reset")
|
|
127
|
+
selected = self.module.hyperparameters.select(
|
|
128
|
+
*predicates,
|
|
129
|
+
include_root=include_root,
|
|
130
|
+
use_cache=use_cache,
|
|
131
|
+
)
|
|
132
|
+
if not selected:
|
|
133
|
+
raise ValueError("reset matched no nodes")
|
|
134
|
+
|
|
135
|
+
nodes = self._runtime_reset_nodes(selected, descendants=descendants)
|
|
136
|
+
ModelGraph.reset_selected(self.module, selected, descendants=descendants)
|
|
137
|
+
self.module._reset_contracts()
|
|
138
|
+
for node in nodes:
|
|
139
|
+
self._log_node_mutation(
|
|
140
|
+
action="reset",
|
|
141
|
+
message="reset runtime node",
|
|
142
|
+
node=node,
|
|
143
|
+
descendants=descendants,
|
|
144
|
+
)
|
|
145
|
+
|
|
146
|
+
@contextmanager
|
|
147
|
+
def override(
|
|
148
|
+
self,
|
|
149
|
+
*predicates: NodePredicate | NodeAttribute | Callable[[Node], bool],
|
|
150
|
+
strict: bool = True,
|
|
151
|
+
allow_extra: bool = False,
|
|
152
|
+
include_root: bool = True,
|
|
153
|
+
validate: bool = True,
|
|
154
|
+
use_cache: bool = False,
|
|
155
|
+
**values: Any,
|
|
156
|
+
) -> Iterator[None]:
|
|
157
|
+
self.module._assert_mutation_allowed("override")
|
|
158
|
+
values = self.module.hyperparameters.update_values(values)
|
|
159
|
+
changes = self._attribute_changes(
|
|
160
|
+
values=values,
|
|
161
|
+
predicates=predicates,
|
|
162
|
+
allow_extra=allow_extra,
|
|
163
|
+
include_root=include_root,
|
|
164
|
+
use_cache=use_cache,
|
|
165
|
+
)
|
|
166
|
+
entered = False
|
|
167
|
+
try:
|
|
168
|
+
with self.module.hyperparameters.override(
|
|
169
|
+
*predicates,
|
|
170
|
+
strict=strict,
|
|
171
|
+
allow_extra=allow_extra,
|
|
172
|
+
include_root=include_root,
|
|
173
|
+
validate=validate,
|
|
174
|
+
use_cache=use_cache,
|
|
175
|
+
**values,
|
|
176
|
+
):
|
|
177
|
+
entered = True
|
|
178
|
+
ModelGraph.rebuild(self.module)
|
|
179
|
+
self.module._reset_contracts()
|
|
180
|
+
self._log_attribute_changes("override", changes)
|
|
181
|
+
yield
|
|
182
|
+
finally:
|
|
183
|
+
ModelGraph.rebuild(self.module)
|
|
184
|
+
self.module._reset_contracts()
|
|
185
|
+
if entered:
|
|
186
|
+
self._log_attribute_changes("override_restore", changes, restored=True)
|
|
187
|
+
|
|
188
|
+
def _attribute_changes(
|
|
189
|
+
self,
|
|
190
|
+
*,
|
|
191
|
+
values: dict[str, Any],
|
|
192
|
+
predicates: tuple[NodePredicate | NodeAttribute | Callable[[Node], bool], ...],
|
|
193
|
+
allow_extra: bool,
|
|
194
|
+
include_root: bool,
|
|
195
|
+
use_cache: bool,
|
|
196
|
+
) -> list[AttributeChange]:
|
|
197
|
+
nodes = self.module.hyperparameters.select(*predicates, include_root=include_root, use_cache=use_cache)
|
|
198
|
+
changes: list[AttributeChange] = []
|
|
199
|
+
for node in nodes:
|
|
200
|
+
can_apply_extra = allow_extra and getattr(type(node), "model_config", {}).get("extra") == "allow"
|
|
201
|
+
for name in values:
|
|
202
|
+
if not (_has_node_attribute(node, name) or can_apply_extra):
|
|
203
|
+
continue
|
|
204
|
+
|
|
205
|
+
changes.append(
|
|
206
|
+
AttributeChange(
|
|
207
|
+
node=node,
|
|
208
|
+
name=name,
|
|
209
|
+
original=getattr(node, name, _MISSING),
|
|
210
|
+
definition_attribute=_is_definition_attribute(node, name),
|
|
211
|
+
)
|
|
212
|
+
)
|
|
213
|
+
|
|
214
|
+
return changes
|
|
215
|
+
|
|
216
|
+
def _extend_target(
|
|
217
|
+
self,
|
|
218
|
+
*args: NodePredicate | NodeAttribute | Callable[[Node], bool] | SchemaField,
|
|
219
|
+
include_root: bool,
|
|
220
|
+
use_cache: bool,
|
|
221
|
+
) -> tuple[Array, int]:
|
|
222
|
+
predicates: list[NodePredicate | NodeAttribute | Callable[[Node], bool]] = []
|
|
223
|
+
field_count = 0
|
|
224
|
+
reading_fields = False
|
|
225
|
+
|
|
226
|
+
for item in args:
|
|
227
|
+
if isinstance(item, (Array, Leaf)):
|
|
228
|
+
reading_fields = True
|
|
229
|
+
field_count += 1
|
|
230
|
+
continue
|
|
231
|
+
|
|
232
|
+
if reading_fields:
|
|
233
|
+
raise TypeError("extend predicates must come before new schema fields")
|
|
234
|
+
|
|
235
|
+
predicates.append(item)
|
|
236
|
+
|
|
237
|
+
if field_count == 0:
|
|
238
|
+
raise ValueError("extend requires at least one schema field")
|
|
239
|
+
|
|
240
|
+
candidates = [
|
|
241
|
+
node
|
|
242
|
+
for node in self.module.hyperparameters.select(*predicates, include_root=include_root, use_cache=use_cache)
|
|
243
|
+
if isinstance(node, Array)
|
|
244
|
+
]
|
|
245
|
+
if len(candidates) != 1:
|
|
246
|
+
raise ValueError(f"extend requires exactly one matching array node, found {len(candidates)}")
|
|
247
|
+
|
|
248
|
+
return candidates[0], field_count
|
|
249
|
+
|
|
250
|
+
def _delete_roots(
|
|
251
|
+
self,
|
|
252
|
+
*predicates: NodePredicate | NodeAttribute | Callable[[Node], bool],
|
|
253
|
+
include_root: bool,
|
|
254
|
+
use_cache: bool,
|
|
255
|
+
) -> list[Node]:
|
|
256
|
+
selected = self.module.hyperparameters.select(*predicates, include_root=include_root, use_cache=use_cache)
|
|
257
|
+
selected_ids = {id(node) for node in selected}
|
|
258
|
+
return [
|
|
259
|
+
node
|
|
260
|
+
for node in selected
|
|
261
|
+
if not any(
|
|
262
|
+
id(ancestor) in selected_ids
|
|
263
|
+
for ancestor in getattr(node, "ancestors", ())
|
|
264
|
+
if ancestor is not self.module.hyperparameters
|
|
265
|
+
)
|
|
266
|
+
]
|
|
267
|
+
|
|
268
|
+
def _runtime_reset_nodes(self, selected: list[Node], *, descendants: bool) -> list[Node]:
|
|
269
|
+
nodes: dict[str, Node] = {}
|
|
270
|
+
for node in selected:
|
|
271
|
+
if node.address in self.module.nodes:
|
|
272
|
+
nodes[str(node.address)] = node
|
|
273
|
+
|
|
274
|
+
if descendants:
|
|
275
|
+
for descendant in getattr(node, "descendants", ()):
|
|
276
|
+
if descendant.address in self.module.nodes:
|
|
277
|
+
nodes[str(descendant.address)] = descendant
|
|
278
|
+
|
|
279
|
+
return list(nodes.values())
|
|
280
|
+
|
|
281
|
+
def _log_attribute_changes(self, action: str, changes: list[AttributeChange], *, restored: bool = False) -> None:
|
|
282
|
+
for change in changes:
|
|
283
|
+
value = change.original if restored else getattr(change.node, change.name, _MISSING)
|
|
284
|
+
logger.bind(
|
|
285
|
+
component="schema_mutation",
|
|
286
|
+
action=action,
|
|
287
|
+
address=str(change.node.address),
|
|
288
|
+
node_type=change.node.type,
|
|
289
|
+
attribute=change.name,
|
|
290
|
+
definition_attribute=change.definition_attribute,
|
|
291
|
+
value=_format_log_value(value),
|
|
292
|
+
previous_value=_format_log_value(change.original),
|
|
293
|
+
).info("restored schema node attribute" if restored else "mutated schema node attribute")
|
|
294
|
+
|
|
295
|
+
def _log_node_mutation(self, *, action: str, message: str, node: Node, **kwargs: Any) -> None:
|
|
296
|
+
extra = {key: str(value.address) if isinstance(value, Node) else value for key, value in kwargs.items()}
|
|
297
|
+
logger.bind(
|
|
298
|
+
component="schema_mutation",
|
|
299
|
+
action=action,
|
|
300
|
+
address=str(node.address),
|
|
301
|
+
node_type=node.type,
|
|
302
|
+
attribute=None,
|
|
303
|
+
definition_attribute=None,
|
|
304
|
+
**extra,
|
|
305
|
+
).info(message)
|
|
306
|
+
|
|
307
|
+
|
|
308
|
+
def _has_node_attribute(node: Node, name: str) -> bool:
|
|
309
|
+
fields = getattr(type(node), "model_fields", {})
|
|
310
|
+
extra = getattr(node, "model_extra", None) or {}
|
|
311
|
+
return name in fields or name in extra or hasattr(node, name)
|
|
312
|
+
|
|
313
|
+
|
|
314
|
+
def _is_definition_attribute(node: Node, name: str) -> bool:
|
|
315
|
+
return name in getattr(type(node), "model_fields", {})
|
|
316
|
+
|
|
317
|
+
|
|
318
|
+
def _format_log_value(value: Any) -> str:
|
|
319
|
+
if value is _MISSING:
|
|
320
|
+
return "<missing>"
|
|
321
|
+
|
|
322
|
+
text = repr(value)
|
|
323
|
+
return text if len(text) <= 160 else f"{text[:157]}..."
|
|
@@ -18,10 +18,10 @@ from tensordict import TensorDict
|
|
|
18
18
|
from json2vec.architecture.checkpoint import CheckpointState
|
|
19
19
|
from json2vec.architecture.contracts import ContractScheduler
|
|
20
20
|
from json2vec.architecture.graph import ModelGraph
|
|
21
|
+
from json2vec.architecture.mutations import SchemaEditor
|
|
21
22
|
from json2vec.architecture.plot import PlotMode
|
|
22
23
|
from json2vec.architecture.runtime import EvaluationResult, ModelRuntime, Postprocessor, PreprocessFn, step
|
|
23
|
-
from json2vec.
|
|
24
|
-
from json2vec.data.datasets.base import EncodedBatch
|
|
24
|
+
from json2vec.data.datasets.base import EncodedBatch, EncodedInput
|
|
25
25
|
from json2vec.structs.enums import AttentionMode, Strata
|
|
26
26
|
from json2vec.structs.experiment import (
|
|
27
27
|
Hyperparameters,
|
|
@@ -520,24 +520,38 @@ class Model(lit.LightningModule):
|
|
|
520
520
|
) -> tuple[dict[Address, dict[str, Any]], dict[Address, dict[str, Any]]]:
|
|
521
521
|
return ModelRuntime.write(self, predictions)
|
|
522
522
|
|
|
523
|
+
@immutable("inference")
|
|
524
|
+
def encode(
|
|
525
|
+
self,
|
|
526
|
+
batch: EncodedBatch | list[dict[str, Any]],
|
|
527
|
+
preprocess: PreprocessFn | None = None,
|
|
528
|
+
strata: Strata | str = Strata.predict,
|
|
529
|
+
) -> EncodedInput:
|
|
530
|
+
"""Return encoded tensorfield inputs for raw or processed observations."""
|
|
531
|
+
return ModelRuntime.encode(
|
|
532
|
+
self,
|
|
533
|
+
batch=batch,
|
|
534
|
+
preprocess=preprocess,
|
|
535
|
+
strata=strata,
|
|
536
|
+
)
|
|
537
|
+
|
|
523
538
|
@immutable("inference")
|
|
524
539
|
def evaluate(
|
|
525
540
|
self,
|
|
526
541
|
batch: EncodedBatch | list[dict[str, Any]],
|
|
527
542
|
preprocess: PreprocessFn | None = None,
|
|
528
543
|
postprocess: Postprocessor | None = None,
|
|
529
|
-
) ->
|
|
544
|
+
) -> EvaluationResult:
|
|
530
545
|
"""Run prediction and embedding for encoded or raw observations.
|
|
531
546
|
|
|
532
547
|
If `preprocess` is omitted, raw records are encoded unchanged.
|
|
533
548
|
"""
|
|
534
|
-
|
|
549
|
+
return ModelRuntime.evaluate(
|
|
535
550
|
self,
|
|
536
551
|
batch=batch,
|
|
537
552
|
preprocess=preprocess,
|
|
538
553
|
postprocess=postprocess,
|
|
539
554
|
)
|
|
540
|
-
return result.as_tuple()
|
|
541
555
|
|
|
542
556
|
def predict(
|
|
543
557
|
self,
|
|
@@ -546,12 +560,14 @@ class Model(lit.LightningModule):
|
|
|
546
560
|
postprocess: Postprocessor | None = None,
|
|
547
561
|
) -> dict[Address, dict[str, Any]]:
|
|
548
562
|
"""Return typed predictions for a raw or encoded batch."""
|
|
549
|
-
|
|
563
|
+
|
|
564
|
+
result = self.evaluate(
|
|
550
565
|
batch=batch,
|
|
551
566
|
preprocess=preprocess,
|
|
552
567
|
postprocess=postprocess,
|
|
553
568
|
)
|
|
554
|
-
|
|
569
|
+
|
|
570
|
+
return result.predictions
|
|
555
571
|
|
|
556
572
|
def embed(
|
|
557
573
|
self,
|
|
@@ -560,12 +576,12 @@ class Model(lit.LightningModule):
|
|
|
560
576
|
postprocess: Postprocessor | None = None,
|
|
561
577
|
) -> dict[Address, dict[str, Any]]:
|
|
562
578
|
"""Return configured embeddings for a raw or encoded batch."""
|
|
563
|
-
|
|
579
|
+
result = self.evaluate(
|
|
564
580
|
batch=batch,
|
|
565
581
|
preprocess=preprocess,
|
|
566
582
|
postprocess=postprocess,
|
|
567
583
|
)
|
|
568
|
-
return embeddings
|
|
584
|
+
return result.embeddings
|
|
569
585
|
|
|
570
586
|
training_step = partialmethod(step, strata=Strata.train)
|
|
571
587
|
validation_step = partialmethod(step, strata=Strata.validate)
|
|
@@ -14,7 +14,8 @@ from tensordict import TensorDict
|
|
|
14
14
|
from json2vec.architecture.contracts import sanitize
|
|
15
15
|
from json2vec.architecture.encoder import ArrayEncoder
|
|
16
16
|
from json2vec.architecture.node import NodeModule
|
|
17
|
-
from json2vec.data.datasets.base import EncodedBatch
|
|
17
|
+
from json2vec.data.datasets.base import EncodedBatch, EncodedInput
|
|
18
|
+
from json2vec.data.iterables import encode
|
|
18
19
|
from json2vec.structs.enums import Metric, Strata, TensorKey
|
|
19
20
|
from json2vec.structs.packages import Embedding, Parcel, Prediction
|
|
20
21
|
from json2vec.structs.tree import Address
|
|
@@ -50,9 +51,6 @@ class EvaluationResult:
|
|
|
50
51
|
predictions: dict[Address, dict[str, Any]]
|
|
51
52
|
embeddings: dict[Address, dict[str, Any]]
|
|
52
53
|
|
|
53
|
-
def as_tuple(self) -> tuple[dict[Address, dict[str, Any]], dict[Address, dict[str, Any]]]:
|
|
54
|
-
return self.predictions, self.embeddings
|
|
55
|
-
|
|
56
54
|
|
|
57
55
|
class ModelRuntime:
|
|
58
56
|
"""Own runtime behavior that depends on an already-built model graph."""
|
|
@@ -182,16 +180,13 @@ class ModelRuntime:
|
|
|
182
180
|
return supervised, embeddings
|
|
183
181
|
|
|
184
182
|
@staticmethod
|
|
185
|
-
def
|
|
183
|
+
def encode(
|
|
186
184
|
module: "Model",
|
|
187
185
|
batch: EncodedBatch | list[dict[str, Any]],
|
|
188
186
|
preprocess: PreprocessFn | None = None,
|
|
189
|
-
|
|
190
|
-
) ->
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
was_training = module.training
|
|
194
|
-
raw_batch = batch
|
|
187
|
+
strata: Strata | str = Strata.predict,
|
|
188
|
+
) -> EncodedInput:
|
|
189
|
+
strata = Strata.normalize(strata)
|
|
195
190
|
|
|
196
191
|
if preprocess is not None:
|
|
197
192
|
observations: EncodedBatch = []
|
|
@@ -206,13 +201,24 @@ class ModelRuntime:
|
|
|
206
201
|
elif batch and isinstance(batch[0], dict):
|
|
207
202
|
batch = [[request] for request in cast(list[dict[str, Any]], batch)]
|
|
208
203
|
|
|
209
|
-
|
|
204
|
+
return encode(
|
|
210
205
|
batch=cast(EncodedBatch, batch),
|
|
211
206
|
hyperparameters=module.hyperparameters,
|
|
212
|
-
strata=
|
|
207
|
+
strata=strata,
|
|
213
208
|
interprocess_encoding_context=module.interprocess_encoding_context,
|
|
214
209
|
)
|
|
215
210
|
|
|
211
|
+
@staticmethod
|
|
212
|
+
def evaluate(
|
|
213
|
+
module: "Model",
|
|
214
|
+
batch: EncodedBatch | list[dict[str, Any]],
|
|
215
|
+
preprocess: PreprocessFn | None = None,
|
|
216
|
+
postprocess: Postprocessor | None = None,
|
|
217
|
+
) -> EvaluationResult:
|
|
218
|
+
was_training = module.training
|
|
219
|
+
raw_batch = batch
|
|
220
|
+
inputs = ModelRuntime.encode(module=module, batch=batch, preprocess=preprocess, strata=Strata.predict)
|
|
221
|
+
|
|
216
222
|
module.eval()
|
|
217
223
|
try:
|
|
218
224
|
with torch.inference_mode():
|
|
@@ -226,7 +232,7 @@ class ModelRuntime:
|
|
|
226
232
|
if postprocess is not None:
|
|
227
233
|
context = {
|
|
228
234
|
"batch": raw_batch,
|
|
229
|
-
"observations":
|
|
235
|
+
"observations": inputs[TensorKey.metadata],
|
|
230
236
|
"input": inputs,
|
|
231
237
|
TensorKey.metadata: inputs[TensorKey.metadata],
|
|
232
238
|
}
|
|
@@ -65,7 +65,7 @@ class Hyperparameters(Node):
|
|
|
65
65
|
@classmethod
|
|
66
66
|
def update_values(cls, values: Mapping[str, Any]) -> dict[str, Any]:
|
|
67
67
|
normalized = dict(values)
|
|
68
|
-
target = normalized.
|
|
68
|
+
target = normalized.get("target", None)
|
|
69
69
|
|
|
70
70
|
if target is None:
|
|
71
71
|
return normalized
|
|
@@ -76,11 +76,9 @@ class Hyperparameters(Node):
|
|
|
76
76
|
if target:
|
|
77
77
|
if normalized.get("p_prune") not in (None, 1.0):
|
|
78
78
|
raise ValueError("target=True is shorthand for p_prune=1.0")
|
|
79
|
-
normalized["p_prune"] = 1.0
|
|
80
79
|
else:
|
|
81
|
-
if normalized
|
|
80
|
+
if "p_prune" in normalized and normalized["p_prune"] is not None:
|
|
82
81
|
raise ValueError("target=False is shorthand for p_prune=None")
|
|
83
|
-
normalized["p_prune"] = None
|
|
84
82
|
|
|
85
83
|
return normalized
|
|
86
84
|
|
|
@@ -215,7 +213,7 @@ class Hyperparameters(Node):
|
|
|
215
213
|
@property
|
|
216
214
|
def target(self) -> list[Address]:
|
|
217
215
|
role = NodePredicate(
|
|
218
|
-
func=lambda node: isinstance(node, Leaf) and node.active and
|
|
216
|
+
func=lambda node: isinstance(node, Leaf) and node.active and node.target,
|
|
219
217
|
key=("role", "target"),
|
|
220
218
|
)
|
|
221
219
|
return [Address(str(node.address)) for node in self.select(role)]
|
|
@@ -348,6 +346,8 @@ class Hyperparameters(Node):
|
|
|
348
346
|
|
|
349
347
|
if validate and applicable_values:
|
|
350
348
|
payload = node.model_dump(mode="python", round_trip=True)
|
|
349
|
+
if "target" in applicable_values and "p_prune" not in applicable_values:
|
|
350
|
+
payload.pop("p_prune", None)
|
|
351
351
|
payload.update(applicable_values)
|
|
352
352
|
validated = type(node).model_validate(payload)
|
|
353
353
|
applicable_values = {name: getattr(validated, name) for name in applicable_values}
|
|
@@ -501,7 +501,12 @@ class Hyperparameters(Node):
|
|
|
501
501
|
nodes = self.select(*predicates, include_root=include_root, use_cache=use_cache)
|
|
502
502
|
normalized_values = self.update_values(values)
|
|
503
503
|
snapshot = [
|
|
504
|
-
(
|
|
504
|
+
(
|
|
505
|
+
node,
|
|
506
|
+
"p_prune" if name == "target" else name,
|
|
507
|
+
getattr(node, "p_prune" if name == "target" else name, _MISSING),
|
|
508
|
+
("p_prune" if name == "target" else name) in getattr(node, "model_fields_set", set()),
|
|
509
|
+
)
|
|
505
510
|
for node in nodes
|
|
506
511
|
for name in normalized_values
|
|
507
512
|
if _has_model_attribute(node, name)
|
|
@@ -140,7 +140,7 @@ class NodeAttribute(pydantic.BaseModel):
|
|
|
140
140
|
if self.name == "descendants":
|
|
141
141
|
return tuple(str(child.address) for child in getattr(node, "descendants", ()))
|
|
142
142
|
if self.name == "target":
|
|
143
|
-
return isinstance(node, Leaf) and node.active and
|
|
143
|
+
return isinstance(node, Leaf) and node.active and node.target
|
|
144
144
|
|
|
145
145
|
extra = getattr(node, "model_extra", None) or {}
|
|
146
146
|
if self.name in extra:
|
|
@@ -51,6 +51,17 @@ class Node(NodeMixin, pydantic.BaseModel):
|
|
|
51
51
|
p_mask: Rate | None = None
|
|
52
52
|
p_prune: PruneRate | None = None
|
|
53
53
|
|
|
54
|
+
@property
|
|
55
|
+
def target(self) -> bool:
|
|
56
|
+
return self.p_prune == 1.0
|
|
57
|
+
|
|
58
|
+
@target.setter
|
|
59
|
+
def target(self, value: bool) -> None:
|
|
60
|
+
if not isinstance(value, bool):
|
|
61
|
+
raise ValueError("target must be a boolean")
|
|
62
|
+
|
|
63
|
+
self.p_prune = 1.0 if value else None
|
|
64
|
+
|
|
54
65
|
@classmethod
|
|
55
66
|
def sanitize_name(cls, value: str) -> str:
|
|
56
67
|
sanitized = re.sub(r"[^0-9A-Za-z_-]+", "_", value).strip("_")
|
|
@@ -78,6 +89,10 @@ class Node(NodeMixin, pydantic.BaseModel):
|
|
|
78
89
|
if values.get("p_prune") not in (None, 1.0):
|
|
79
90
|
raise ValueError("target=True is shorthand for p_prune=1.0")
|
|
80
91
|
values["p_prune"] = 1.0
|
|
92
|
+
else:
|
|
93
|
+
if values.get("p_prune") is not None:
|
|
94
|
+
raise ValueError("target=False is shorthand for p_prune=None")
|
|
95
|
+
values["p_prune"] = None
|
|
81
96
|
|
|
82
97
|
return values
|
|
83
98
|
|
|
@@ -1,8 +1,9 @@
|
|
|
1
1
|
# ty: ignore[invalid-method-override,unknown-argument]
|
|
2
2
|
from __future__ import annotations
|
|
3
3
|
|
|
4
|
+
from collections.abc import Mapping
|
|
4
5
|
from functools import partial
|
|
5
|
-
from typing import TYPE_CHECKING, Annotated, Literal, cast
|
|
6
|
+
from typing import TYPE_CHECKING, Annotated, Any, Literal, cast
|
|
6
7
|
|
|
7
8
|
import numpy as np
|
|
8
9
|
import pydantic
|
|
@@ -46,10 +47,17 @@ class Request(RequestBase):
|
|
|
46
47
|
|
|
47
48
|
type: Literal["category"] = "category"
|
|
48
49
|
max_vocab_size: Annotated[int, pydantic.Field(gt=0, default=10_000)] = 10000
|
|
49
|
-
n_bands: Annotated[int, pydantic.Field(gt=0, default=8)] = 8
|
|
50
50
|
p_unavailable: Annotated[float, pydantic.Field(ge=0.0, le=1.0, default=0.01)] = 0.01
|
|
51
51
|
topk: list[int] | None = None
|
|
52
52
|
|
|
53
|
+
@pydantic.model_validator(mode="before")
|
|
54
|
+
@classmethod
|
|
55
|
+
def reject_removed_options(cls, data: Any) -> Any:
|
|
56
|
+
if isinstance(data, Mapping) and "n_bands" in data:
|
|
57
|
+
raise ValueError("Category does not support n_bands")
|
|
58
|
+
|
|
59
|
+
return data
|
|
60
|
+
|
|
53
61
|
@pydantic.model_validator(mode="after")
|
|
54
62
|
def check_topk(self):
|
|
55
63
|
if self.topk is None:
|
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
# ty: ignore[invalid-argument-type,invalid-assignment,unknown-argument,unresolved-attribute]
|
|
2
2
|
from __future__ import annotations
|
|
3
3
|
|
|
4
|
+
import difflib
|
|
4
5
|
import enum
|
|
5
6
|
import math
|
|
6
7
|
import re
|
|
@@ -74,6 +75,22 @@ class DatePart(enum.StrEnum):
|
|
|
74
75
|
return cls.DEPTH[datepart]
|
|
75
76
|
|
|
76
77
|
|
|
78
|
+
def _normalize_datepart_key(value: str) -> str:
|
|
79
|
+
value = re.sub(r"(?<=[a-z0-9])(?=[A-Z])", "_", value.strip())
|
|
80
|
+
value = re.sub(r"[^0-9A-Za-z]+", "_", value)
|
|
81
|
+
return value.strip("_").casefold()
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def _datepart_lookup() -> dict[str, DatePart]:
|
|
85
|
+
lookup: dict[str, DatePart] = {}
|
|
86
|
+
for datepart in DatePart:
|
|
87
|
+
normalized = _normalize_datepart_key(datepart.value)
|
|
88
|
+
lookup[normalized] = datepart
|
|
89
|
+
lookup[normalized.replace("_", "")] = datepart
|
|
90
|
+
|
|
91
|
+
return lookup
|
|
92
|
+
|
|
93
|
+
|
|
77
94
|
@DatePart.day_of_month.register(depth=31)
|
|
78
95
|
def _(arr: np.ndarray) -> np.ndarray:
|
|
79
96
|
month_start = arr.astype("datetime64[M]")
|
|
@@ -133,6 +150,35 @@ class Request(RequestBase):
|
|
|
133
150
|
dateparts: list[DatePart]
|
|
134
151
|
pattern: Annotated[str | None, pydantic.Field(default=None)] = None
|
|
135
152
|
|
|
153
|
+
@pydantic.field_validator("dateparts", mode="before", check_fields=False)
|
|
154
|
+
@classmethod
|
|
155
|
+
def _coerce_dateparts(cls, value: Any) -> Any:
|
|
156
|
+
if not isinstance(value, (list, tuple)):
|
|
157
|
+
return value
|
|
158
|
+
|
|
159
|
+
lookup = _datepart_lookup()
|
|
160
|
+
canonical = [datepart.value for datepart in DatePart]
|
|
161
|
+
dateparts: list[DatePart] = []
|
|
162
|
+
for item in value:
|
|
163
|
+
if isinstance(item, DatePart):
|
|
164
|
+
dateparts.append(item)
|
|
165
|
+
continue
|
|
166
|
+
|
|
167
|
+
if not isinstance(item, str):
|
|
168
|
+
raise ValueError(f"datepart values must be strings, got {type(item).__name__}")
|
|
169
|
+
|
|
170
|
+
key = _normalize_datepart_key(item)
|
|
171
|
+
match = lookup.get(key) or lookup.get(key.replace("_", ""))
|
|
172
|
+
if match is not None:
|
|
173
|
+
dateparts.append(match)
|
|
174
|
+
continue
|
|
175
|
+
|
|
176
|
+
suggestions = difflib.get_close_matches(key, canonical, n=1)
|
|
177
|
+
suggestion = f"; did you mean '{suggestions[0]}'?" if suggestions else ""
|
|
178
|
+
raise ValueError(f"unknown datepart '{item}'{suggestion}")
|
|
179
|
+
|
|
180
|
+
return dateparts
|
|
181
|
+
|
|
136
182
|
@pydantic.field_validator("dateparts", check_fields=False)
|
|
137
183
|
@classmethod
|
|
138
184
|
def check_dateparts(cls, v):
|
|
@@ -59,7 +59,7 @@ class Request(RequestBase):
|
|
|
59
59
|
"""Numeric scalar tensorfield request."""
|
|
60
60
|
|
|
61
61
|
type: Literal["number"] = "number"
|
|
62
|
-
jitter: Annotated[float, pydantic.Field(ge=0.0,
|
|
62
|
+
jitter: Annotated[float, pydantic.Field(ge=0.0, default=0.0)] = 0.0
|
|
63
63
|
n_bands: Annotated[int, pydantic.Field(gt=0, default=8)] = 8
|
|
64
64
|
offset: Annotated[int, pydantic.Field(gt=0, default=4)] = 4
|
|
65
65
|
alpha: Annotated[float | None, pydantic.Field(gt=0.0, lt=1.0, default=None)] = None
|
|
@@ -43,6 +43,7 @@ class Request(RequestBase):
|
|
|
43
43
|
type: Literal["set"] = "set"
|
|
44
44
|
max_vocab_size: Annotated[int, pydantic.Field(gt=0, default=10_000)] = 10_000
|
|
45
45
|
p_unavailable: Annotated[float, pydantic.Field(ge=0.0, le=1.0, default=0.01)] = 0.01
|
|
46
|
+
threshold: Annotated[float | None, pydantic.Field(ge=0.0, le=1.0, default=None)] = None
|
|
46
47
|
|
|
47
48
|
|
|
48
49
|
def _items(value: Any) -> Iterable[Any]:
|
|
@@ -373,6 +374,7 @@ def loss(
|
|
|
373
374
|
@sets.register
|
|
374
375
|
def write(module: Model, prediction: Prediction):
|
|
375
376
|
node = module.nodes[prediction.address]
|
|
377
|
+
request: Request = module.hyperparameters.requests[prediction.address]
|
|
376
378
|
state_logits: torch.Tensor = prediction.payload[TensorKey.state]
|
|
377
379
|
content_logits: torch.Tensor = prediction.payload[TensorKey.content]
|
|
378
380
|
|
|
@@ -383,7 +385,22 @@ def write(module: Model, prediction: Prediction):
|
|
|
383
385
|
|
|
384
386
|
vocab = node.embedder.vocab.snapshot()
|
|
385
387
|
probabilities = content_logits[..., : len(vocab)].sigmoid().detach().float().cpu().numpy()
|
|
386
|
-
|
|
388
|
+
if request.threshold is None:
|
|
389
|
+
content_payload = {str(label): probabilities[..., index] for index, label in enumerate(vocab)}
|
|
390
|
+
else:
|
|
391
|
+
labels = np.asarray(vocab, dtype=object)
|
|
392
|
+
|
|
393
|
+
def pack_thresholded(values: np.ndarray) -> dict[str, float] | list:
|
|
394
|
+
if values.ndim == 1:
|
|
395
|
+
keep = values >= request.threshold
|
|
396
|
+
return {
|
|
397
|
+
str(label): float(probability)
|
|
398
|
+
for label, probability in zip(labels[keep].tolist(), values[keep].tolist())
|
|
399
|
+
}
|
|
400
|
+
|
|
401
|
+
return [pack_thresholded(values[index]) for index in range(values.shape[0])]
|
|
402
|
+
|
|
403
|
+
content_payload = pack_thresholded(probabilities)
|
|
387
404
|
|
|
388
405
|
return {
|
|
389
406
|
TensorKey.state.name: state_payload,
|
|
@@ -226,7 +226,11 @@ class Decoder(DecoderBase):
|
|
|
226
226
|
|
|
227
227
|
request: Request = hyperparameters.requests[address]
|
|
228
228
|
|
|
229
|
-
self.
|
|
229
|
+
self.classification = torch.nn.Linear(
|
|
230
|
+
in_features=hyperparameters.d_model,
|
|
231
|
+
out_features=len(Tokens),
|
|
232
|
+
)
|
|
233
|
+
self.regression = torch.nn.Linear(
|
|
230
234
|
in_features=hyperparameters.d_model,
|
|
231
235
|
out_features=request.n_dim,
|
|
232
236
|
)
|
|
@@ -235,7 +239,8 @@ class Decoder(DecoderBase):
|
|
|
235
239
|
def decode(self, pooled: torch.Tensor) -> TensorDict[TensorKey, torch.Tensor]:
|
|
236
240
|
return TensorDict(
|
|
237
241
|
source={
|
|
238
|
-
TensorKey.
|
|
242
|
+
TensorKey.state: self.classification(pooled),
|
|
243
|
+
TensorKey.content: self.regression(pooled),
|
|
239
244
|
}
|
|
240
245
|
)
|
|
241
246
|
|
|
@@ -251,30 +256,67 @@ def loss(
|
|
|
251
256
|
request: Request = module.hyperparameters.requests[address]
|
|
252
257
|
|
|
253
258
|
trainable = batch.trainable.reshape(-1)
|
|
259
|
+
state_targets = batch.targets[TensorKey.state].reshape(-1)
|
|
260
|
+
state_inputs = prediction.payload[TensorKey.state].reshape(-1, len(Tokens))
|
|
261
|
+
|
|
262
|
+
output: torch.Tensor = module.track(
|
|
263
|
+
(address, strata, Metric.loss, TensorKey.state),
|
|
264
|
+
value=(
|
|
265
|
+
torch.nn.functional.cross_entropy(
|
|
266
|
+
input=state_inputs,
|
|
267
|
+
target=state_targets,
|
|
268
|
+
reduction="none",
|
|
269
|
+
)
|
|
270
|
+
.masked_select(trainable)
|
|
271
|
+
.mean()
|
|
272
|
+
),
|
|
273
|
+
)
|
|
274
|
+
|
|
275
|
+
module.track(
|
|
276
|
+
(address, strata, Metric.accuracy, TensorKey.state),
|
|
277
|
+
value=state_inputs.argmax(dim=1).eq(state_targets).masked_select(trainable).float().mean(),
|
|
278
|
+
)
|
|
279
|
+
|
|
280
|
+
valued = trainable & state_targets.eq(Tokens.valued.value)
|
|
281
|
+
if not valued.any():
|
|
282
|
+
return output
|
|
283
|
+
|
|
254
284
|
inputs = prediction.payload[TensorKey.content].reshape(-1, request.n_dim)
|
|
255
285
|
targets = batch.targets[TensorKey.content].reshape(-1, request.n_dim)
|
|
256
286
|
diff = inputs.subtract(targets)
|
|
257
287
|
|
|
258
|
-
|
|
288
|
+
output += module.track(
|
|
259
289
|
(address, strata, Metric.loss, TensorKey.content),
|
|
260
|
-
value=request.objective.loss(inputs=inputs, targets=targets).masked_select(
|
|
290
|
+
value=request.objective.loss(inputs=inputs, targets=targets).masked_select(valued).mean(),
|
|
261
291
|
)
|
|
262
292
|
|
|
263
293
|
module.track(
|
|
264
294
|
(address, strata, Metric.mae, TensorKey.content),
|
|
265
|
-
value=diff.absolute().mean(dim=1).masked_select(
|
|
295
|
+
value=diff.absolute().mean(dim=1).masked_select(valued).mean(),
|
|
266
296
|
)
|
|
267
297
|
|
|
268
298
|
module.track(
|
|
269
299
|
(address, strata, Metric.rmse, TensorKey.content),
|
|
270
|
-
value=diff.square().mean(dim=1).sqrt().masked_select(
|
|
300
|
+
value=diff.square().mean(dim=1).sqrt().masked_select(valued).mean(),
|
|
271
301
|
)
|
|
272
302
|
|
|
273
|
-
return
|
|
303
|
+
return output
|
|
274
304
|
|
|
275
305
|
|
|
276
306
|
@vector.register
|
|
277
307
|
def write(module: Model, prediction: Prediction):
|
|
308
|
+
content: np.ndarray = prediction.payload[TensorKey.content].detach().float().cpu().numpy()
|
|
309
|
+
state_logits: torch.Tensor = prediction.payload[TensorKey.state]
|
|
310
|
+
tokens: np.ndarray = np.fromiter((token.name for token in Tokens), dtype=object, count=len(Tokens))
|
|
311
|
+
state_log_norm = state_logits.logsumexp(dim=-1, keepdim=True)
|
|
312
|
+
state_distribution = (state_logits - state_log_norm).exp().detach().float().cpu().numpy()
|
|
313
|
+
state_payload = {token: state_distribution[..., index] for index, token in enumerate(tokens.tolist())}
|
|
314
|
+
|
|
315
|
+
non_valued = state_logits.argmax(dim=-1).ne(Tokens.valued.value).detach().cpu().numpy()
|
|
316
|
+
content = content.copy()
|
|
317
|
+
content[non_valued] = 0.0
|
|
318
|
+
|
|
278
319
|
return {
|
|
279
|
-
TensorKey.
|
|
320
|
+
TensorKey.state.name: state_payload,
|
|
321
|
+
TensorKey.content.name: content,
|
|
280
322
|
}
|
|
@@ -15,13 +15,13 @@ src/json2vec/architecture/checkpoint.py
|
|
|
15
15
|
src/json2vec/architecture/contracts.py
|
|
16
16
|
src/json2vec/architecture/encoder.py
|
|
17
17
|
src/json2vec/architecture/graph.py
|
|
18
|
+
src/json2vec/architecture/mutations.py
|
|
18
19
|
src/json2vec/architecture/node.py
|
|
19
20
|
src/json2vec/architecture/plot.py
|
|
20
21
|
src/json2vec/architecture/pool.py
|
|
21
22
|
src/json2vec/architecture/root.py
|
|
22
23
|
src/json2vec/architecture/rotary.py
|
|
23
24
|
src/json2vec/architecture/runtime.py
|
|
24
|
-
src/json2vec/architecture/schema_editor.py
|
|
25
25
|
src/json2vec/data/__init__.py
|
|
26
26
|
src/json2vec/data/iterables.py
|
|
27
27
|
src/json2vec/data/processing.py
|
|
@@ -1,126 +0,0 @@
|
|
|
1
|
-
"""Model-facing schema mutation orchestration."""
|
|
2
|
-
|
|
3
|
-
from __future__ import annotations
|
|
4
|
-
|
|
5
|
-
from collections.abc import Callable, Iterator
|
|
6
|
-
from contextlib import contextmanager
|
|
7
|
-
from typing import TYPE_CHECKING, Any
|
|
8
|
-
|
|
9
|
-
from json2vec.architecture.graph import ModelGraph
|
|
10
|
-
from json2vec.structs.experiment import NodeAttribute, NodePredicate, SchemaField
|
|
11
|
-
from json2vec.structs.tree import Node
|
|
12
|
-
|
|
13
|
-
if TYPE_CHECKING:
|
|
14
|
-
from json2vec.architecture.root import Model
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
class SchemaEditor:
|
|
18
|
-
"""Coordinate schema mutations with runtime graph rebuilds."""
|
|
19
|
-
|
|
20
|
-
def __init__(self, module: "Model") -> None:
|
|
21
|
-
self.module = module
|
|
22
|
-
|
|
23
|
-
def select(
|
|
24
|
-
self,
|
|
25
|
-
*predicates: NodePredicate | NodeAttribute | Callable[[Node], bool],
|
|
26
|
-
include_root: bool = True,
|
|
27
|
-
use_cache: bool = True,
|
|
28
|
-
) -> list[Node]:
|
|
29
|
-
return self.module.hyperparameters.select(
|
|
30
|
-
*predicates,
|
|
31
|
-
include_root=include_root,
|
|
32
|
-
use_cache=use_cache,
|
|
33
|
-
)
|
|
34
|
-
|
|
35
|
-
def update(
|
|
36
|
-
self,
|
|
37
|
-
*predicates: NodePredicate | NodeAttribute | Callable[[Node], bool],
|
|
38
|
-
strict: bool = True,
|
|
39
|
-
allow_extra: bool = False,
|
|
40
|
-
include_root: bool = True,
|
|
41
|
-
validate: bool = True,
|
|
42
|
-
use_cache: bool = False,
|
|
43
|
-
**values: Any,
|
|
44
|
-
) -> None:
|
|
45
|
-
self.module._assert_mutation_allowed("update")
|
|
46
|
-
self.module.hyperparameters.update(
|
|
47
|
-
*predicates,
|
|
48
|
-
strict=strict,
|
|
49
|
-
allow_extra=allow_extra,
|
|
50
|
-
include_root=include_root,
|
|
51
|
-
validate=validate,
|
|
52
|
-
use_cache=use_cache,
|
|
53
|
-
**values,
|
|
54
|
-
)
|
|
55
|
-
ModelGraph.rebuild(self.module)
|
|
56
|
-
self.module._reset_contracts()
|
|
57
|
-
|
|
58
|
-
def extend(
|
|
59
|
-
self,
|
|
60
|
-
*args: NodePredicate | NodeAttribute | Callable[[Node], bool] | SchemaField,
|
|
61
|
-
include_root: bool = True,
|
|
62
|
-
use_cache: bool = True,
|
|
63
|
-
) -> None:
|
|
64
|
-
self.module._assert_mutation_allowed("extend")
|
|
65
|
-
self.module.hyperparameters.extend(*args, include_root=include_root, use_cache=use_cache)
|
|
66
|
-
ModelGraph.rebuild(self.module)
|
|
67
|
-
self.module._reset_contracts()
|
|
68
|
-
|
|
69
|
-
def delete(
|
|
70
|
-
self,
|
|
71
|
-
*predicates: NodePredicate | NodeAttribute | Callable[[Node], bool],
|
|
72
|
-
include_root: bool = False,
|
|
73
|
-
use_cache: bool = True,
|
|
74
|
-
) -> None:
|
|
75
|
-
self.module._assert_mutation_allowed("delete")
|
|
76
|
-
self.module.hyperparameters.delete(*predicates, include_root=include_root, use_cache=use_cache)
|
|
77
|
-
ModelGraph.rebuild(self.module)
|
|
78
|
-
self.module._reset_contracts()
|
|
79
|
-
|
|
80
|
-
def reset(
|
|
81
|
-
self,
|
|
82
|
-
*predicates: NodePredicate | NodeAttribute | Callable[[Node], bool],
|
|
83
|
-
include_root: bool = True,
|
|
84
|
-
use_cache: bool = True,
|
|
85
|
-
descendants: bool = False,
|
|
86
|
-
) -> None:
|
|
87
|
-
self.module._assert_mutation_allowed("reset")
|
|
88
|
-
selected = self.module.hyperparameters.select(
|
|
89
|
-
*predicates,
|
|
90
|
-
include_root=include_root,
|
|
91
|
-
use_cache=use_cache,
|
|
92
|
-
)
|
|
93
|
-
if not selected:
|
|
94
|
-
raise ValueError("reset matched no nodes")
|
|
95
|
-
|
|
96
|
-
ModelGraph.reset_selected(self.module, selected, descendants=descendants)
|
|
97
|
-
self.module._reset_contracts()
|
|
98
|
-
|
|
99
|
-
@contextmanager
|
|
100
|
-
def override(
|
|
101
|
-
self,
|
|
102
|
-
*predicates: NodePredicate | NodeAttribute | Callable[[Node], bool],
|
|
103
|
-
strict: bool = True,
|
|
104
|
-
allow_extra: bool = False,
|
|
105
|
-
include_root: bool = True,
|
|
106
|
-
validate: bool = True,
|
|
107
|
-
use_cache: bool = False,
|
|
108
|
-
**values: Any,
|
|
109
|
-
) -> Iterator[None]:
|
|
110
|
-
self.module._assert_mutation_allowed("override")
|
|
111
|
-
try:
|
|
112
|
-
with self.module.hyperparameters.override(
|
|
113
|
-
*predicates,
|
|
114
|
-
strict=strict,
|
|
115
|
-
allow_extra=allow_extra,
|
|
116
|
-
include_root=include_root,
|
|
117
|
-
validate=validate,
|
|
118
|
-
use_cache=use_cache,
|
|
119
|
-
**values,
|
|
120
|
-
):
|
|
121
|
-
ModelGraph.rebuild(self.module)
|
|
122
|
-
self.module._reset_contracts()
|
|
123
|
-
yield
|
|
124
|
-
finally:
|
|
125
|
-
ModelGraph.rebuild(self.module)
|
|
126
|
-
self.module._reset_contracts()
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|