json2vec 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- json2vec/__init__.py +0 -0
- json2vec/__main__.py +32 -0
- json2vec/architecture/__init__.py +0 -0
- json2vec/architecture/attention.py +64 -0
- json2vec/architecture/counter.py +37 -0
- json2vec/architecture/encoder.py +88 -0
- json2vec/architecture/node.py +34 -0
- json2vec/architecture/pool.py +61 -0
- json2vec/architecture/root.py +338 -0
- json2vec/architecture/rotary.py +39 -0
- json2vec/data/__init__.py +0 -0
- json2vec/data/datasets.py +539 -0
- json2vec/data/processing.py +152 -0
- json2vec/entrypoints/__init__.py +3 -0
- json2vec/entrypoints/pipeline.py +174 -0
- json2vec/inference/__init__.py +0 -0
- json2vec/inference/callback.py +98 -0
- json2vec/inference/deployment.py +175 -0
- json2vec/logging/__init__.py +0 -0
- json2vec/logging/config.py +27 -0
- json2vec/logging/epoch.py +42 -0
- json2vec/logging/throughput.py +39 -0
- json2vec/logging/tracking.py +152 -0
- json2vec/processors/__init__.py +8 -0
- json2vec/processors/base.py +102 -0
- json2vec/processors/extensions/__init__.py +0 -0
- json2vec/processors/extensions/example.py +6 -0
- json2vec/processors/spec.py +8 -0
- json2vec/structs/__init__.py +0 -0
- json2vec/structs/enums.py +84 -0
- json2vec/structs/environment.py +138 -0
- json2vec/structs/experiment.py +330 -0
- json2vec/structs/packages.py +117 -0
- json2vec/structs/structure.py +70 -0
- json2vec/structs/tree.py +92 -0
- json2vec/tensorfields/__init__.py +8 -0
- json2vec/tensorfields/base.py +210 -0
- json2vec/tensorfields/extensions/__init__.py +0 -0
- json2vec/tensorfields/extensions/category.py +484 -0
- json2vec/tensorfields/extensions/dateparts.py +410 -0
- json2vec/tensorfields/extensions/entity.py +336 -0
- json2vec/tensorfields/extensions/number.py +400 -0
- json2vec/tensorfields/extensions/vector.py +279 -0
- json2vec/tensorfields/spec.py +8 -0
- json2vec-0.1.0.dist-info/METADATA +227 -0
- json2vec-0.1.0.dist-info/RECORD +51 -0
- json2vec-0.1.0.dist-info/WHEEL +5 -0
- json2vec-0.1.0.dist-info/entry_points.txt +2 -0
- json2vec-0.1.0.dist-info/licenses/LICENSE +178 -0
- json2vec-0.1.0.dist-info/licenses/NOTICE +8 -0
- json2vec-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,210 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import functools
|
|
4
|
+
import math
|
|
5
|
+
import re
|
|
6
|
+
from abc import ABC, abstractmethod
|
|
7
|
+
from typing import TYPE_CHECKING, Any, Callable, Type, TypeAlias
|
|
8
|
+
|
|
9
|
+
import pluggy
|
|
10
|
+
import torch
|
|
11
|
+
from tensordict import TensorDict
|
|
12
|
+
|
|
13
|
+
from json2vec.architecture.pool import LearnedQueryCrossAttention
|
|
14
|
+
from json2vec.structs.enums import Component, Metric, Strata, TensorKey
|
|
15
|
+
from json2vec.structs.packages import Parcel, Prediction
|
|
16
|
+
from json2vec.structs.tree import Address, Leaf, Node
|
|
17
|
+
from json2vec.tensorfields.spec import PluginSpec
|
|
18
|
+
|
|
19
|
+
if TYPE_CHECKING:
|
|
20
|
+
from json2vec.architecture.root import JSON2Vec
|
|
21
|
+
from json2vec.structs.experiment import Session
|
|
22
|
+
from json2vec.structs.structure import Structure
|
|
23
|
+
|
|
24
|
+
# Hook manager for the "tensorfields" project; its hook specifications are the
# ones declared on ``PluginSpec``.
pm: pluggy.PluginManager = pluggy.PluginManager("tensorfields")
pm.add_hookspecs(PluginSpec)

# Plugins describe their request schema by subclassing ``RequestBase``, which is
# just another name for the tree ``Leaf`` type.
RequestBase: TypeAlias = Leaf
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class EmbedderBase(torch.nn.Module):
    """Base class for per-field embedders registered through a ``Plugin``.

    The constructor takes ``structure`` and ``address`` so that every embedder
    shares one constructor signature. This base class uses neither argument; it
    only initialises ``torch.nn.Module``.
    """

    def __init__(self, structure: Structure, address: Address):
        # Both arguments are consumed by subclasses, not here.
        super(EmbedderBase, self).__init__()
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class DecoderBase(torch.nn.Module):
    """Base class for per-field decoders registered through a ``Plugin``.

    Incoming parcels are flattened to ``(N, tokens, C)``, concatenated along the
    token axis, pooled by a learned-query cross-attention, and the pooled tensor
    is handed to :meth:`decode`, which subclasses implement.
    """

    def __init__(self, structure: Structure, address: Address):
        super().__init__()

        self.address: Address = address
        # Learnable scalar initialised to zero; it is not read anywhere in this class.
        self.sigma: torch.Tensor = torch.nn.Parameter(torch.zeros(1))

        self.pool = LearnedQueryCrossAttention(
            n_context=math.prod(structure.shapes[address]),
            d_model=structure.d_model,
            nhead=structure.requests[address].n_heads,
            dropout=structure.dropout,
            n_linear=structure.requests[address].n_linear,
        )

    def decode(self, pooled: torch.Tensor) -> TensorDict[TensorKey, torch.Tensor]:
        """Map the pooled representation to output tensors; must be overridden."""
        raise NotImplementedError("decoder must implement decode(pooled)")

    def forward(self, parcels: list[Parcel]) -> Prediction:
        """Pool all *parcels* and decode them into a ``Prediction``.

        Raises:
            ValueError: if *parcels* is empty, or if any parcel's leading (batch)
                or trailing (feature) dimension differs from the first parcel's.
        """
        if not parcels:
            raise ValueError("decoder requires at least one parcel")

        N, *_, C = parcels[0].payload.shape

        # reshape(N, -1, C) would silently fold a parcel with a different batch
        # size into this one (mixing samples) whenever the element count happens
        # to divide evenly, so check the dimensions explicitly first.
        for index, parcel in enumerate(parcels):
            shape = parcel.payload.shape
            if shape[0] != N or shape[-1] != C:
                raise ValueError(
                    f"parcel {index} has payload shape {tuple(shape)}; "
                    f"expected batch size {N} and feature size {C}"
                )

        stacked = torch.cat([parcel.payload.reshape(N, -1, C) for parcel in parcels], dim=1)
        pooled = self.pool(stacked)

        payload = self.decode(pooled)
        return Prediction(
            payload=payload,
            address=self.address,
            batch_size=pooled.shape[0],
        )
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class TensorFieldBase(ABC):
    """Interface that every plugin's tensor-field class implements.

    A tensor field is built from raw values with :meth:`new` and can be
    corrupted for self-supervised training with :meth:`mask` and :meth:`prune`.
    """

    # NOTE(review): in the category plugin these are int64 token/index tensors,
    # a bool mask of positions to train on, and pre-corruption copies keyed by
    # TensorKey; other plugins may differ — confirm per implementation.
    content: torch.Tensor
    state: torch.Tensor
    trainable: torch.Tensor
    targets: TensorDict[TensorKey, torch.Tensor]

    @classmethod
    @abstractmethod
    def new(
        cls,
        values: list,
        address: Address,
        session: Session,
        strata: Strata,
        state: Any,
    ) -> "TensorFieldBase":
        """Build a field for *values* at *address*; *state* is plugin-specific."""
        raise NotImplementedError

    @abstractmethod
    def mask(self, p_mask: float):
        """Corrupt positions with probability *p_mask* (in place)."""
        raise NotImplementedError

    @abstractmethod
    def prune(self, p_prune: float):
        """Prune with probability *p_prune* (in place).

        This is an instance method; its first parameter was previously
        misleadingly named ``cls``.
        """
        raise NotImplementedError
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
# Process-wide registry of plugins keyed by plugin name; each ``Plugin`` adds
# itself here on construction and duplicate names are rejected there.
TENSORFIELDS: dict[str, "Plugin"] = dict()
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
class Plugin:
|
|
101
|
+
def __init__(self, name: str):
|
|
102
|
+
if not isinstance(name, str):
|
|
103
|
+
raise TypeError("Plugin name must be a string")
|
|
104
|
+
|
|
105
|
+
# should start with a letter and contain only lowercase letters, numbers, and underscores
|
|
106
|
+
if not re.match(r"^[a-z0-9_]+$", name):
|
|
107
|
+
raise ValueError("Plugin name must consist of lowercase letters, numbers, and underscores only")
|
|
108
|
+
|
|
109
|
+
self.name: str = name
|
|
110
|
+
self.components: dict[Component, Callable | Type] = {}
|
|
111
|
+
|
|
112
|
+
if name in TENSORFIELDS:
|
|
113
|
+
raise ValueError(f"Plugin '{name}' already registered")
|
|
114
|
+
|
|
115
|
+
TENSORFIELDS[name] = self
|
|
116
|
+
|
|
117
|
+
def register(self, obj: Type | Callable) -> Type | Callable:
|
|
118
|
+
if not hasattr(obj, "__name__"):
|
|
119
|
+
raise NameError(f"Object {obj} does not have a name")
|
|
120
|
+
|
|
121
|
+
name: str = str(obj.__name__)
|
|
122
|
+
|
|
123
|
+
if name in self.components:
|
|
124
|
+
raise ValueError(f"Component '{name}' already registered in plugin '{self.name}'")
|
|
125
|
+
|
|
126
|
+
match name:
|
|
127
|
+
case Component.Request:
|
|
128
|
+
if not isinstance(obj, type):
|
|
129
|
+
raise TypeError("Request must be a class type")
|
|
130
|
+
|
|
131
|
+
if not issubclass(obj, Node):
|
|
132
|
+
raise TypeError("Request must be a subclass of Node")
|
|
133
|
+
|
|
134
|
+
# for attr in Leaf.__annotations__.keys():
|
|
135
|
+
# if not hasattr(obj, attr):
|
|
136
|
+
# raise AttributeError(f"Request class must have a '{attr}' attribute")
|
|
137
|
+
|
|
138
|
+
# if getattr(obj, "type") != self.name:
|
|
139
|
+
# raise ValueError(
|
|
140
|
+
# f"Request class 'type' attribute must be '{self.name}', got '{getattr(obj, 'type')}'"
|
|
141
|
+
# )
|
|
142
|
+
|
|
143
|
+
case Component.TensorField:
|
|
144
|
+
if not isinstance(obj, type):
|
|
145
|
+
raise TypeError("TensorField must be a class type")
|
|
146
|
+
|
|
147
|
+
if not issubclass(obj, TensorFieldBase):
|
|
148
|
+
raise TypeError("TensorField must be a subclass of TensorFieldBase")
|
|
149
|
+
|
|
150
|
+
case Component.Embedder:
|
|
151
|
+
if not isinstance(obj, type):
|
|
152
|
+
raise TypeError("Embedder must be a class type")
|
|
153
|
+
|
|
154
|
+
if not issubclass(obj, EmbedderBase):
|
|
155
|
+
raise TypeError("Embedder must be a subclass of EmbedderBase")
|
|
156
|
+
|
|
157
|
+
# confirm the init method is expecting structure and address
|
|
158
|
+
init_params = list(obj.__init__.__annotations__.keys())
|
|
159
|
+
if "structure" not in init_params or "address" not in init_params:
|
|
160
|
+
raise TypeError("Embedder __init__ method must accept 'structure' and 'address' parameters")
|
|
161
|
+
|
|
162
|
+
case Component.Decoder:
|
|
163
|
+
if not isinstance(obj, type):
|
|
164
|
+
raise TypeError("Decoder must be a class type")
|
|
165
|
+
|
|
166
|
+
if not issubclass(obj, DecoderBase):
|
|
167
|
+
raise TypeError("Decoder must be a subclass of DecoderBase")
|
|
168
|
+
|
|
169
|
+
init_params = list(obj.__init__.__annotations__.keys())
|
|
170
|
+
if "structure" not in init_params or "address" not in init_params:
|
|
171
|
+
raise TypeError("Decoder __init__ method must accept 'structure' and 'address' parameters")
|
|
172
|
+
|
|
173
|
+
case Component.loss:
|
|
174
|
+
if not callable(obj):
|
|
175
|
+
raise TypeError("Loss must be a callable function")
|
|
176
|
+
|
|
177
|
+
expected_params: list[str] = ["module", "prediction", "batch", "strata"]
|
|
178
|
+
func_params: list[str] = list(obj.__annotations__.keys())
|
|
179
|
+
|
|
180
|
+
if not set(expected_params).issubset(set(func_params)):
|
|
181
|
+
raise TypeError(
|
|
182
|
+
f"Loss function must accept the following parameters: {expected_params}, got {func_params}"
|
|
183
|
+
)
|
|
184
|
+
|
|
185
|
+
case Component.write:
|
|
186
|
+
if not callable(obj):
|
|
187
|
+
raise TypeError("Write must be a callable function")
|
|
188
|
+
|
|
189
|
+
# check the signature of the function
|
|
190
|
+
expected_params: list[str] = ["module", "prediction"]
|
|
191
|
+
func_params: list[str] = list(obj.__annotations__.keys())
|
|
192
|
+
|
|
193
|
+
if func_params != expected_params:
|
|
194
|
+
raise TypeError(
|
|
195
|
+
f"Write function must accept the following parameters: {expected_params}, got {func_params}"
|
|
196
|
+
)
|
|
197
|
+
|
|
198
|
+
self.components[name] = obj
|
|
199
|
+
|
|
200
|
+
return obj
|
|
201
|
+
|
|
202
|
+
@functools.cache
|
|
203
|
+
def __getattr__(self, key: Component) -> Callable | Type:
|
|
204
|
+
if key not in Component:
|
|
205
|
+
raise ValueError(f"Component '{key}' is not a valid Component enum value")
|
|
206
|
+
|
|
207
|
+
if key in self.components:
|
|
208
|
+
return self.components[key]
|
|
209
|
+
|
|
210
|
+
raise AttributeError(f"Plugin '{self.name}' has no component '{key}'")
|
|
File without changes
|
|
@@ -0,0 +1,484 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from functools import partial
|
|
4
|
+
from multiprocessing import Manager
|
|
5
|
+
from multiprocessing.managers import ListProxy, SyncManager
|
|
6
|
+
from multiprocessing.synchronize import Lock
|
|
7
|
+
from typing import TYPE_CHECKING, Annotated, Literal
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
10
|
+
import pydantic
|
|
11
|
+
import torch
|
|
12
|
+
from beartype import beartype
|
|
13
|
+
from ordered_set import OrderedSet
|
|
14
|
+
from tensordict import TensorDict, tensorclass
|
|
15
|
+
|
|
16
|
+
from json2vec.architecture.counter import Counter
|
|
17
|
+
from json2vec.data.processing import apply, pad
|
|
18
|
+
from json2vec.structs.enums import Metric, Strata, TensorKey, Tokens
|
|
19
|
+
from json2vec.structs.packages import Parcel, Prediction
|
|
20
|
+
from json2vec.structs.tree import Address
|
|
21
|
+
from json2vec.tensorfields.base import (
|
|
22
|
+
DecoderBase,
|
|
23
|
+
EmbedderBase,
|
|
24
|
+
Plugin,
|
|
25
|
+
RequestBase,
|
|
26
|
+
TensorFieldBase,
|
|
27
|
+
)
|
|
28
|
+
|
|
29
|
+
if TYPE_CHECKING:
|
|
30
|
+
from json2vec.architecture.root import JSON2Vec
|
|
31
|
+
from json2vec.structs.experiment import Session, Structure
|
|
32
|
+
|
|
33
|
+
# Plugin entry for categorical fields; the request, tensor field, embedder,
# decoder, loss and write components below attach via ``@category.register``.
category: Plugin = Plugin("category")
|
|
34
|
+
|
|
35
|
+
class Vocabulary:
    """Local, cached view of a word list shared through a manager ``ListProxy``.

    A word's index is its position in ``master``. Known words are resolved from
    the local ``OrderedSet`` cache without touching the proxy. New words (only
    when ``update`` is true) are appended to ``master`` while holding ``lock``.
    """

    def __init__(self, master: ListProxy, lock: Lock):
        self.master: ListProxy[str] = master
        self.lock: Lock = lock
        self.vocab: OrderedSet[str] = OrderedSet(list(master))

    def _sync(self) -> None:
        """Append to the local cache the words added to ``master`` since the last sync.

        Only the unseen tail is read, instead of re-copying the whole shared list
        for every new word. This relies on the local cache always being a prefix
        of ``master``, which holds for the appends made in ``__call__``.
        NOTE(review): confirm nothing else inserts into or removes from
        ``master``. The caller must hold ``self.lock``.
        """
        for word in self.master[len(self.vocab):]:
            self.vocab.add(word)

    def __call__(self, word: str | None, update: bool) -> int | None:
        """Return the index of *word*, adding it to the vocabulary if *update*.

        Returns None for a None word, or for an unknown word when *update* is false.
        """
        if word is None:
            return None

        if word in self.vocab:
            return self.vocab.index(word)

        if not update:
            return None

        # Unknown locally: take the global lock, catch up with words added
        # elsewhere, then append the word if it is still missing.
        with self.lock:
            self._sync()

            if word not in self.vocab:
                self.vocab.add(word)
                self.master.append(word)

        return self.vocab.index(word)

    def __len__(self) -> int:
        """Size of the local cache (may lag ``master`` until the next sync)."""
        return len(self.vocab)
|
|
67
|
+
|
|
68
|
+
class OnlineVocabularyModel(torch.nn.Module):
    """Module that owns a growing word list and saves/loads it with checkpoints.

    The words live in a list created by a ``multiprocessing.Manager`` and are
    guarded by a manager lock. ``state`` hands out ``Vocabulary`` views of that
    list, and ``snapshot`` returns a locally cached copy.
    """

    def __init__(self):
        super().__init__()

        self.manager: SyncManager = Manager()
        self.master: ListProxy[str] = self.manager.list()
        self.lock: Lock = self.manager.Lock()
        self._snapshot_cache: list[str] | None = None
        self._snapshot_size: int = -1

    def _save_to_state_dict(self, state_dict, prefix, keep_vars):
        super()._save_to_state_dict(state_dict, prefix, keep_vars)

        # The vocabulary is stored as a plain list of words (not a tensor).
        state_dict[prefix + "vocabulary"] = list(self.master)

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata,
        strict, missing_keys, unexpected_keys, error_msgs
    ):
        key = prefix + "vocabulary"

        if key in state_dict:
            # Pop so the base implementation does not report it as unexpected.
            vocab: list[str] = state_dict.pop(key)
            self.master: ListProxy[str] = self.manager.list(vocab)
            self._snapshot_cache = None
            self._snapshot_size = -1
        elif strict:
            # Report like any other missing entry instead of raising KeyError, so
            # load_state_dict(strict=False) can proceed and strict loads get
            # torch's usual "Missing key(s)" error.
            missing_keys.append(key)

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata,
            strict, missing_keys, unexpected_keys, error_msgs
        )

    @property
    def state(self) -> Vocabulary:
        """A new ``Vocabulary`` view over the shared list and lock."""
        return Vocabulary(master=self.master, lock=self.lock)

    def snapshot(self) -> list[str]:
        """Return a local copy of the word list, refreshed when its length changes."""
        size = len(self.master)
        if self._snapshot_cache is None or self._snapshot_size != size:
            self._snapshot_cache = list(self.master)
            self._snapshot_size = size

        return self._snapshot_cache
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
@category.register
class Request(RequestBase):
    # Request schema for a categorical field. Field order and types are part of
    # the pydantic model and are kept exactly as declared.
    type: Literal["category"]
    max_vocab_size: Annotated[int, pydantic.Field(gt=0, default=10_000)]
    n_bands: Annotated[int, pydantic.Field(gt=0, default=8)]
    topk: Annotated[list[int], pydantic.Field(default_factory=list)]

    @pydantic.model_validator(mode="after")
    def check_topk(self):
        """Deduplicate and sort ``topk``, then reject invalid entries.

        Each k must be an int, positive, not 1, and smaller than
        ``max_vocab_size``. The first violation found raises ValueError.
        """
        self.topk = sorted(set(self.topk))

        for k in self.topk:
            problem = None
            if not isinstance(k, int):
                problem = "topk values must be integers"
            elif k <= 0:
                problem = "topk values must be positive"
            elif k == 1:
                problem = "topk values must not be 1"
            elif k >= self.max_vocab_size:
                problem = "topk values must be less than max_vocab_size"

            if problem is not None:
                raise ValueError(problem)

        return self
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
@category.register
@tensorclass
class TensorField(TensorFieldBase):
    """Batch of categorical values encoded as ``state`` and ``content`` tensors.

    ``state`` holds a ``Tokens`` value per position and ``content`` the
    vocabulary index (padding uses 0). ``trainable`` flags positions altered by
    :meth:`mask` / :meth:`prune`, and ``targets`` stores the pre-corruption
    copies of ``state`` and ``content``.
    """

    state: torch.Tensor
    content: torch.Tensor
    trainable: torch.Tensor
    targets: TensorDict[TensorKey, torch.Tensor]

    @classmethod
    def new(
        cls,
        values: list,
        address: Address,
        session: Session,
        strata: Strata,
        state: Vocabulary,
    ) -> TensorFieldBase:
        """Tokenise *values* with the vocabulary *state* and pad to the field shape.

        New words are added to the vocabulary only when *strata* is
        ``Strata.train``.
        """
        context_shape: tuple[int, ...] = session.structure.shapes[address]

        tokens = apply(values, partial(state, update=(strata == Strata.train)))

        # The vocabulary itself is not capped; oversized indices are rejected
        # later by the embedder, so only a warning is printed here.
        if len(state) > (max_vocab_size := session.structure.requests[address].max_vocab_size):
            print(f"Vocab in address {address} exceeds max vocab size of {max_vocab_size}")

        data, states = pad(
            nested=tokens,
            shape=(len(values), *context_shape),
            dtype=np.int64,
            pad_value=0,
        )

        state_tensor = torch.tensor(states, dtype=torch.int64)
        content = torch.tensor(data=data, dtype=torch.int64)

        return cls(
            state=state_tensor,
            content=content,
            trainable=torch.zeros_like(input=state_tensor, dtype=torch.bool),
            targets=TensorDict({}),
            batch_size=len(values),
        )

    def mask(self, p_mask: float):
        """Independently mask each position with probability *p_mask* (in place)."""
        is_masked = torch.rand_like(input=self.state, dtype=torch.float).lt(other=p_mask)

        # Record the uncorrupted values once, before the first corruption.
        if TensorKey.state not in self.targets.keys():
            self.targets[TensorKey.state] = self.state.clone()

        if TensorKey.content not in self.targets.keys():
            self.targets[TensorKey.content] = self.content.clone()

        self.state = self.state.masked_fill(is_masked, Tokens.masked.value)
        self.content = self.content.masked_fill(is_masked, 0)

        self.trainable |= is_masked

    def prune(self, p_prune: float = 1.0):
        """Prune whole batch rows with probability *p_prune* (in place)."""
        # One draw per batch row, broadcast over the remaining dimensions.
        is_pruned = (
            torch.rand(self.state.size(0), *([1] * (self.state.dim() - 1)), device=self.state.device)
            .lt(p_prune)
            .expand_as(self.state)
        )

        if TensorKey.state not in self.targets.keys():
            self.targets[TensorKey.state] = self.state.clone()

        if TensorKey.content not in self.targets.keys():
            self.targets[TensorKey.content] = self.content.clone()

        self.state = self.state.masked_fill(is_pruned, Tokens.pruned.value)
        self.content = self.content.masked_fill(is_pruned, 0)

        self.trainable |= is_pruned

    @classmethod
    def empty(
        cls,
        batch_size: int,
        address: Address,
        structure: Structure,
    ):
        """Build an all-pruned field of *batch_size* rows for *address*."""
        shape: tuple[int, ...] = (batch_size, *structure.shapes[address])

        # Use the enum's integer value and an explicit int64 dtype, matching the
        # state tensors produced by new() and the token comparisons elsewhere.
        state = torch.full(shape, Tokens.pruned.value, dtype=torch.int64)
        content = torch.zeros(shape, dtype=torch.int64)

        return cls(
            state=state,
            content=content,
            trainable=torch.zeros_like(input=state, dtype=torch.bool),
            targets=TensorDict({}),
            batch_size=batch_size,
        )
|
|
239
|
+
|
|
240
|
+
|
|
241
|
+
@category.register
class Embedder(EmbedderBase):
    """Embed a categorical field as state embedding + (valued-only) content embedding."""

    def __init__(self, structure: Structure, address: Address):
        super().__init__(structure=structure, address=address)

        request: Request = structure.requests[address]
        d_model = structure.d_model

        self.origin: Address = address
        self.destination: Address = request.parent.address
        self.max_vocab_size: int = request.max_vocab_size

        # Word list for this field; it is saved and loaded with the state dict.
        self.vocab: OnlineVocabularyModel = OnlineVocabularyModel()

        # Construction order (state table, then content table) is preserved so
        # random initialisation is unchanged.
        state_table = torch.nn.Embedding(num_embeddings=len(Tokens), embedding_dim=d_model)
        content_table = torch.nn.Embedding(num_embeddings=request.max_vocab_size, embedding_dim=d_model)
        self.embeddings = torch.nn.ModuleDict(
            {
                TensorKey.state.name: state_table,
                TensorKey.content.name: content_table,
            }
        )

    @beartype
    def forward(self, inputs: TensorFieldBase) -> Parcel:
        """Embed *inputs* into a ``Parcel`` of shape ``(*inputs.state.shape, d_model)``.

        Raises:
            ValueError: if a valued position holds an index >= ``max_vocab_size``.
        """
        batch_size, *field_dims = inputs.state.shape
        flat_state = inputs.state.reshape(-1)
        flat_content = inputs.content.reshape(-1)
        is_valued = flat_state.eq(Tokens.valued.value)

        if is_valued.any():
            out_of_range = flat_content.masked_select(is_valued) >= self.max_vocab_size
            if out_of_range.any().item():
                raise ValueError(f"Token in address {self.origin} exceeds max vocab size of {self.max_vocab_size}")

        # Non-valued positions are looked up at index 0 and then zeroed out.
        safe_content = flat_content.masked_fill(is_valued.logical_not(), 0)

        state_vectors = self.embeddings[TensorKey.state.name](flat_state)
        content_vectors = self.embeddings[TensorKey.content.name](safe_content) * is_valued.unsqueeze(-1)
        embeddings: torch.Tensor = (state_vectors + content_vectors).reshape(batch_size, *field_dims, -1)

        return Parcel(
            payload=embeddings,
            origin=self.origin,
            destination=self.destination,
            batch_size=batch_size,
        )

    @property
    def state(self) -> Vocabulary:
        """A ``Vocabulary`` view over this embedder's word list."""
        return self.vocab.state
|
|
297
|
+
|
|
298
|
+
|
|
299
|
+
|
|
300
|
+
@category.register
class Decoder(DecoderBase):
    """Project pooled features to state-token logits and vocabulary logits."""

    def __init__(self, structure: Structure, address: Address):
        super().__init__(structure=structure, address=address)

        request: RequestBase = structure.requests[address]
        n_tokens = len(Tokens)
        vocab_size = request.max_vocab_size

        # Modules are created in the same order as before (state then content
        # heads, then state then content counters) so initialisation matches.
        self.linears = torch.nn.ModuleDict(
            {
                TensorKey.state.name: torch.nn.Linear(in_features=structure.d_model, out_features=n_tokens),
                TensorKey.content.name: torch.nn.Linear(in_features=structure.d_model, out_features=vocab_size),
            }
        )

        self.counters = torch.nn.ModuleDict(
            {
                TensorKey.state.name: Counter(address=address, size=n_tokens),
                TensorKey.content.name: Counter(address=address, size=vocab_size),
            }
        )

    @beartype
    def decode(self, pooled: torch.Tensor) -> TensorDict[TensorKey, torch.Tensor]:
        """Return state and content logits for *pooled*."""
        logits = {key: self.linears[key.name](pooled) for key in (TensorKey.state, TensorKey.content)}
        return TensorDict(source=logits)
|
|
335
|
+
|
|
336
|
+
|
|
337
|
+
@category.register
def loss(
    module: JSON2Vec,
    prediction: Prediction,
    batch: TensorFieldBase,
    strata: Strata,
) -> torch.Tensor:
    """Compute the categorical-field loss and track its metrics.

    The loss is a class-weighted cross-entropy on ``state`` over trainable
    positions. When some trainable position has a ``Tokens.valued`` target, a
    class-weighted cross-entropy on ``content`` over those positions is added.
    Accuracy and top-k accuracy are tracked via ``module.track``.

    If no position is trainable, a zero loss (still attached to the graph) is
    returned and no metrics are tracked. Averaging over an empty selection
    would otherwise produce NaN.
    """
    decoder: Decoder = module.nodes[prediction.address].decoder
    N: int = batch.targets[TensorKey.state].numel()
    trainable = batch.trainable.reshape(N)

    state_inputs = prediction.payload[TensorKey.state].reshape(N, -1)
    state_targets = batch.targets[TensorKey.state].reshape(N)
    decoder.counters[TensorKey.state.name](batch.targets[TensorKey.state])

    if not trainable.any():
        # Multiply by zero to keep the result connected to the graph.
        return state_inputs.sum() * 0.0

    loss: torch.Tensor = module.track(
        (prediction.address, strata, Metric.loss, TensorKey.state),
        value=(
            torch.nn.functional.cross_entropy(
                input=state_inputs,
                target=state_targets,
                weight=decoder.counters[TensorKey.state.name].weight,
                reduction="none",
            )
            .masked_select(trainable)
            .mean()
        ),
    )

    module.track(
        (prediction.address, strata, Metric.accuracy, TensorKey.state),
        value=state_inputs.argmax(dim=1).eq(state_targets).masked_select(trainable).float().mean(),
    )

    valued = trainable & state_targets.eq(Tokens.valued.value)
    if not valued.any():
        return loss

    content_inputs = prediction.payload[TensorKey.content].reshape(N, -1)
    content_targets = batch.targets[TensorKey.content].reshape(N)
    # Class statistics use every valued target, not only the trainable ones.
    content_counter_values = content_targets.masked_select(state_targets.eq(Tokens.valued.value))
    if content_counter_values.numel() > 0:
        decoder.counters[TensorKey.content.name](content_counter_values)

    # Out-of-place add: the tensor returned by module.track may be retained for
    # logging, and `+=` would silently rewrite it.
    loss = loss + module.track(
        (prediction.address, strata, Metric.loss, TensorKey.content),
        value=(
            torch.nn.functional.cross_entropy(
                input=content_inputs,
                target=content_targets,
                weight=decoder.counters[TensorKey.content.name].weight,
                reduction="none",
            )
            .masked_select(valued)
            .mean()
        ),
    )

    for topk in module.session.structure.requests[prediction.address].topk:
        module.track(
            (prediction.address, strata, Metric.accuracy, f"top{topk}"),
            value=(
                content_inputs
                .topk(k=topk, dim=1)
                .indices.eq(content_targets.unsqueeze(1))
                .any(dim=1)
                .masked_select(valued).float().mean()
            ),
        )

    module.track(
        (prediction.address, strata, Metric.accuracy, TensorKey.content),
        value=content_inputs.argmax(dim=1).eq(content_targets).masked_select(valued).float().mean(),
    )

    return loss
|
|
413
|
+
|
|
414
|
+
|
|
415
|
+
@category.register
def write(module: JSON2Vec, prediction: Prediction):
    """Turn categorical-field logits into a serialisable prediction.

    Returns a dict with two entries:

    * ``state`` maps each ``Tokens`` name to its softmax probability array.
    * ``content`` holds ``value`` (most likely known word, or None when the
      vocabulary is empty), ``probability`` (its softmax probability,
      renormalised over known words), and ``topk`` (nested lists of
      ``{"label", "probability"}`` dicts for the largest requested k).

    The signature intentionally has no return annotation: ``Plugin.register``
    checks this function's annotations.
    """
    node = module.nodes[prediction.address]
    state_logits: torch.Tensor = prediction.payload[TensorKey.state]
    content_logits: torch.Tensor = prediction.payload[TensorKey.content]

    token_names = [token.name for token in Tokens]
    state_log_norm = state_logits.logsumexp(dim=-1, keepdim=True)
    state_distribution = (state_logits - state_log_norm).exp().detach().float().cpu().numpy()
    state_payload = {
        token: state_distribution[..., index]
        for index, token in enumerate(token_names)
    }

    vocab = np.array(node.embedder.vocab.snapshot(), dtype=object)
    content_shape = tuple(state_distribution.shape[:-1])
    content_labels = np.full(content_shape, None, dtype=object)
    content_probabilities = np.zeros(content_shape, dtype=np.float32)

    requested_ks: list[int] = module.session.structure.requests[prediction.address].topk
    max_requested_k: int = max(requested_ks, default=0)

    def _pack_candidates(labels: np.ndarray, probabilities: np.ndarray) -> list:
        """Recursively convert parallel label/probability arrays to nested lists of dicts."""
        if labels.ndim == 1:
            return [
                {"label": str(label), "probability": float(probability)}
                for label, probability in zip(labels.tolist(), probabilities.tolist())
            ]

        return [_pack_candidates(labels[index], probabilities[index]) for index in range(labels.shape[0])]

    def _empty_candidates(shape: tuple[int, ...]) -> list:
        """Nested empty lists matching *shape* (used when there is nothing to rank)."""
        if len(shape) == 0:
            return []

        return [_empty_candidates(shape[1:]) for _ in range(shape[0])]

    topk_payload: list = _empty_candidates(content_shape)

    # The vocabulary is not capped, so it can hold more words than the decoder
    # has output logits. Only the first words that have a logit can be predicted;
    # narrowing to len(vocab) in that case would raise.
    n_known = min(len(vocab), content_logits.shape[-1])

    if n_known > 0:
        narrow: torch.Tensor = content_logits.narrow(dim=-1, start=0, length=n_known)
        log_norm = narrow.logsumexp(dim=-1, keepdim=True)
        max_logits, max_indices = narrow.max(dim=-1)
        content_probabilities = (max_logits - log_norm.squeeze(-1)).exp().detach().float().cpu().numpy()

        max_indices_np: np.ndarray = max_indices.detach().cpu().numpy().astype(np.int32)
        content_labels = vocab[max_indices_np]

        if max_requested_k > 0:
            topk: int = min(max_requested_k, narrow.shape[-1])
            topk_logits, topk_indices = narrow.topk(k=topk, dim=-1)
            topk_probabilities = (topk_logits - log_norm).exp()

            topk_indices_np: np.ndarray = topk_indices.detach().cpu().numpy().astype(np.int32)
            topk_labels_np: np.ndarray = vocab[topk_indices_np]
            topk_probabilities_np: np.ndarray = topk_probabilities.detach().float().cpu().numpy()
            topk_payload = _pack_candidates(
                labels=topk_labels_np,
                probabilities=topk_probabilities_np,
            )

    return {
        TensorKey.state.name: state_payload,
        TensorKey.content.name: {
            TensorKey.value.name: content_labels,
            TensorKey.probability.name: content_probabilities,
            TensorKey.topk.name: topk_payload,
        },
    }
|