json2vec 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51) hide show
  1. json2vec/__init__.py +0 -0
  2. json2vec/__main__.py +32 -0
  3. json2vec/architecture/__init__.py +0 -0
  4. json2vec/architecture/attention.py +64 -0
  5. json2vec/architecture/counter.py +37 -0
  6. json2vec/architecture/encoder.py +88 -0
  7. json2vec/architecture/node.py +34 -0
  8. json2vec/architecture/pool.py +61 -0
  9. json2vec/architecture/root.py +338 -0
  10. json2vec/architecture/rotary.py +39 -0
  11. json2vec/data/__init__.py +0 -0
  12. json2vec/data/datasets.py +539 -0
  13. json2vec/data/processing.py +152 -0
  14. json2vec/entrypoints/__init__.py +3 -0
  15. json2vec/entrypoints/pipeline.py +174 -0
  16. json2vec/inference/__init__.py +0 -0
  17. json2vec/inference/callback.py +98 -0
  18. json2vec/inference/deployment.py +175 -0
  19. json2vec/logging/__init__.py +0 -0
  20. json2vec/logging/config.py +27 -0
  21. json2vec/logging/epoch.py +42 -0
  22. json2vec/logging/throughput.py +39 -0
  23. json2vec/logging/tracking.py +152 -0
  24. json2vec/processors/__init__.py +8 -0
  25. json2vec/processors/base.py +102 -0
  26. json2vec/processors/extensions/__init__.py +0 -0
  27. json2vec/processors/extensions/example.py +6 -0
  28. json2vec/processors/spec.py +8 -0
  29. json2vec/structs/__init__.py +0 -0
  30. json2vec/structs/enums.py +84 -0
  31. json2vec/structs/environment.py +138 -0
  32. json2vec/structs/experiment.py +330 -0
  33. json2vec/structs/packages.py +117 -0
  34. json2vec/structs/structure.py +70 -0
  35. json2vec/structs/tree.py +92 -0
  36. json2vec/tensorfields/__init__.py +8 -0
  37. json2vec/tensorfields/base.py +210 -0
  38. json2vec/tensorfields/extensions/__init__.py +0 -0
  39. json2vec/tensorfields/extensions/category.py +484 -0
  40. json2vec/tensorfields/extensions/dateparts.py +410 -0
  41. json2vec/tensorfields/extensions/entity.py +336 -0
  42. json2vec/tensorfields/extensions/number.py +400 -0
  43. json2vec/tensorfields/extensions/vector.py +279 -0
  44. json2vec/tensorfields/spec.py +8 -0
  45. json2vec-0.1.0.dist-info/METADATA +227 -0
  46. json2vec-0.1.0.dist-info/RECORD +51 -0
  47. json2vec-0.1.0.dist-info/WHEEL +5 -0
  48. json2vec-0.1.0.dist-info/entry_points.txt +2 -0
  49. json2vec-0.1.0.dist-info/licenses/LICENSE +178 -0
  50. json2vec-0.1.0.dist-info/licenses/NOTICE +8 -0
  51. json2vec-0.1.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,210 @@
1
+ from __future__ import annotations
2
+
3
+ import functools
4
+ import math
5
+ import re
6
+ from abc import ABC, abstractmethod
7
+ from typing import TYPE_CHECKING, Any, Callable, Type, TypeAlias
8
+
9
+ import pluggy
10
+ import torch
11
+ from tensordict import TensorDict
12
+
13
+ from json2vec.architecture.pool import LearnedQueryCrossAttention
14
+ from json2vec.structs.enums import Component, Metric, Strata, TensorKey
15
+ from json2vec.structs.packages import Parcel, Prediction
16
+ from json2vec.structs.tree import Address, Leaf, Node
17
+ from json2vec.tensorfields.spec import PluginSpec
18
+
19
+ if TYPE_CHECKING:
20
+ from json2vec.architecture.root import JSON2Vec
21
+ from json2vec.structs.experiment import Session
22
+ from json2vec.structs.structure import Structure
23
+
24
# Plugin manager for the "tensorfields" project; its hook specifications are
# taken from PluginSpec.
pm: pluggy.PluginManager = pluggy.PluginManager(project_name="tensorfields")
pm.add_hookspecs(module_or_class=PluginSpec)

# Request classes are declared against this alias, which resolves to Leaf.
RequestBase: TypeAlias = Leaf
29
+
30
+
31
class EmbedderBase(torch.nn.Module):
    """Common base class for tensorfield embedders.

    The constructor takes ``structure`` and ``address`` so that every embedder
    shares one signature; the base class itself stores neither of them.
    """

    def __init__(self, structure: Structure, address: Address):
        super(EmbedderBase, self).__init__()
34
+
35
+
36
class DecoderBase(torch.nn.Module):
    """Base decoder: pool the incoming parcels, then delegate to ``decode``.

    Subclasses override ``decode`` to turn the pooled tensor into a
    ``TensorDict`` of outputs; ``forward`` wraps that result in a
    ``Prediction`` tagged with this decoder's address.
    """

    def __init__(self, structure: Structure, address: Address):
        super().__init__()

        request = structure.requests[address]

        self.address: Address = address
        self.sigma: torch.Tensor = torch.nn.Parameter(torch.zeros(1))

        # The pool's context length is the product of this address's shape.
        self.pool = LearnedQueryCrossAttention(
            n_context=math.prod(structure.shapes[address]),
            d_model=structure.d_model,
            nhead=request.n_heads,
            dropout=structure.dropout,
            n_linear=request.n_linear,
        )

    def decode(self, pooled: torch.Tensor) -> TensorDict[TensorKey, torch.Tensor]:
        """Map the pooled tensor to output tensors; must be overridden."""
        raise NotImplementedError("decoder must implement decode(pooled)")

    def forward(self, parcels: list[Parcel]) -> Prediction:
        """Concatenate the parcel payloads, pool them and decode the result.

        Raises:
            ValueError: if ``parcels`` is empty.
        """
        if not parcels:
            raise ValueError("decoder requires at least one parcel")

        # Leading and trailing sizes come from the first parcel; every payload
        # is flattened to (batch, -1, channels) and joined along dim 1.
        batch, *_, channels = parcels[0].payload.shape
        flattened = [parcel.payload.reshape(batch, -1, channels) for parcel in parcels]
        pooled = self.pool(torch.cat(flattened, dim=1))

        return Prediction(
            payload=self.decode(pooled),
            address=self.address,
            batch_size=pooled.shape[0],
        )
68
+
69
+
70
class TensorFieldBase(ABC):
    """Abstract interface that every tensorfield implementation provides.

    Concrete subclasses carry the four tensors declared below and implement
    the ``new`` constructor plus the ``mask`` and ``prune`` operations.
    """

    content: torch.Tensor
    state: torch.Tensor
    trainable: torch.Tensor
    targets: TensorDict[TensorKey, torch.Tensor]

    @classmethod
    @abstractmethod
    def new(
        cls,
        values: list,
        address: Address,
        session: Session,
        strata: Strata,
        state: Any,
    ) -> "TensorFieldBase":
        """Build a tensorfield from raw ``values`` for the given address."""
        raise NotImplementedError

    @abstractmethod
    def mask(self, p_mask: float):
        """Mask entries of this tensorfield; ``p_mask`` is a probability."""
        raise NotImplementedError

    @abstractmethod
    def prune(self, p_prune: float):
        """Prune entries of this tensorfield; ``p_prune`` is a probability."""
        # The receiver is an instance (this is not a classmethod), so it is
        # named ``self`` like in the concrete implementations.
        raise NotImplementedError
95
+
96
+
97
+ TENSORFIELDS: dict[str, "Plugin"] = {}
98
+
99
+
100
+ class Plugin:
101
+ def __init__(self, name: str):
102
+ if not isinstance(name, str):
103
+ raise TypeError("Plugin name must be a string")
104
+
105
+ # should start with a letter and contain only lowercase letters, numbers, and underscores
106
+ if not re.match(r"^[a-z0-9_]+$", name):
107
+ raise ValueError("Plugin name must consist of lowercase letters, numbers, and underscores only")
108
+
109
+ self.name: str = name
110
+ self.components: dict[Component, Callable | Type] = {}
111
+
112
+ if name in TENSORFIELDS:
113
+ raise ValueError(f"Plugin '{name}' already registered")
114
+
115
+ TENSORFIELDS[name] = self
116
+
117
+ def register(self, obj: Type | Callable) -> Type | Callable:
118
+ if not hasattr(obj, "__name__"):
119
+ raise NameError(f"Object {obj} does not have a name")
120
+
121
+ name: str = str(obj.__name__)
122
+
123
+ if name in self.components:
124
+ raise ValueError(f"Component '{name}' already registered in plugin '{self.name}'")
125
+
126
+ match name:
127
+ case Component.Request:
128
+ if not isinstance(obj, type):
129
+ raise TypeError("Request must be a class type")
130
+
131
+ if not issubclass(obj, Node):
132
+ raise TypeError("Request must be a subclass of Node")
133
+
134
+ # for attr in Leaf.__annotations__.keys():
135
+ # if not hasattr(obj, attr):
136
+ # raise AttributeError(f"Request class must have a '{attr}' attribute")
137
+
138
+ # if getattr(obj, "type") != self.name:
139
+ # raise ValueError(
140
+ # f"Request class 'type' attribute must be '{self.name}', got '{getattr(obj, 'type')}'"
141
+ # )
142
+
143
+ case Component.TensorField:
144
+ if not isinstance(obj, type):
145
+ raise TypeError("TensorField must be a class type")
146
+
147
+ if not issubclass(obj, TensorFieldBase):
148
+ raise TypeError("TensorField must be a subclass of TensorFieldBase")
149
+
150
+ case Component.Embedder:
151
+ if not isinstance(obj, type):
152
+ raise TypeError("Embedder must be a class type")
153
+
154
+ if not issubclass(obj, EmbedderBase):
155
+ raise TypeError("Embedder must be a subclass of EmbedderBase")
156
+
157
+ # confirm the init method is expecting structure and address
158
+ init_params = list(obj.__init__.__annotations__.keys())
159
+ if "structure" not in init_params or "address" not in init_params:
160
+ raise TypeError("Embedder __init__ method must accept 'structure' and 'address' parameters")
161
+
162
+ case Component.Decoder:
163
+ if not isinstance(obj, type):
164
+ raise TypeError("Decoder must be a class type")
165
+
166
+ if not issubclass(obj, DecoderBase):
167
+ raise TypeError("Decoder must be a subclass of DecoderBase")
168
+
169
+ init_params = list(obj.__init__.__annotations__.keys())
170
+ if "structure" not in init_params or "address" not in init_params:
171
+ raise TypeError("Decoder __init__ method must accept 'structure' and 'address' parameters")
172
+
173
+ case Component.loss:
174
+ if not callable(obj):
175
+ raise TypeError("Loss must be a callable function")
176
+
177
+ expected_params: list[str] = ["module", "prediction", "batch", "strata"]
178
+ func_params: list[str] = list(obj.__annotations__.keys())
179
+
180
+ if not set(expected_params).issubset(set(func_params)):
181
+ raise TypeError(
182
+ f"Loss function must accept the following parameters: {expected_params}, got {func_params}"
183
+ )
184
+
185
+ case Component.write:
186
+ if not callable(obj):
187
+ raise TypeError("Write must be a callable function")
188
+
189
+ # check the signature of the function
190
+ expected_params: list[str] = ["module", "prediction"]
191
+ func_params: list[str] = list(obj.__annotations__.keys())
192
+
193
+ if func_params != expected_params:
194
+ raise TypeError(
195
+ f"Write function must accept the following parameters: {expected_params}, got {func_params}"
196
+ )
197
+
198
+ self.components[name] = obj
199
+
200
+ return obj
201
+
202
+ @functools.cache
203
+ def __getattr__(self, key: Component) -> Callable | Type:
204
+ if key not in Component:
205
+ raise ValueError(f"Component '{key}' is not a valid Component enum value")
206
+
207
+ if key in self.components:
208
+ return self.components[key]
209
+
210
+ raise AttributeError(f"Plugin '{self.name}' has no component '{key}'")
File without changes
@@ -0,0 +1,484 @@
1
+ from __future__ import annotations
2
+
3
+ from functools import partial
4
+ from multiprocessing import Manager
5
+ from multiprocessing.managers import ListProxy, SyncManager
6
+ from multiprocessing.synchronize import Lock
7
+ from typing import TYPE_CHECKING, Annotated, Literal
8
+
9
+ import numpy as np
10
+ import pydantic
11
+ import torch
12
+ from beartype import beartype
13
+ from ordered_set import OrderedSet
14
+ from tensordict import TensorDict, tensorclass
15
+
16
+ from json2vec.architecture.counter import Counter
17
+ from json2vec.data.processing import apply, pad
18
+ from json2vec.structs.enums import Metric, Strata, TensorKey, Tokens
19
+ from json2vec.structs.packages import Parcel, Prediction
20
+ from json2vec.structs.tree import Address
21
+ from json2vec.tensorfields.base import (
22
+ DecoderBase,
23
+ EmbedderBase,
24
+ Plugin,
25
+ RequestBase,
26
+ TensorFieldBase,
27
+ )
28
+
29
+ if TYPE_CHECKING:
30
+ from json2vec.architecture.root import JSON2Vec
31
+ from json2vec.structs.experiment import Session, Structure
32
+
33
# Plugin instance under which the components defined in this module are registered.
category: Plugin = Plugin(name="category")
34
+
35
class Vocabulary:
    """Callable word-to-index lookup backed by a shared master list.

    Each instance keeps its own ordered copy of the master list and only goes
    back to the master (under the lock) when it meets an unknown word and is
    allowed to update.
    """

    def __init__(self, master: ListProxy, lock: Lock):
        self.master: ListProxy[str] = master
        self.lock: Lock = lock
        self.vocab: OrderedSet[str] = OrderedSet(list(master))

    def __call__(self, word: str, update: bool) -> int | None:
        """Return the index of ``word``, or ``None`` when it has none.

        ``None`` is returned for a ``None`` word and for an unknown word when
        ``update`` is false. With ``update`` true, an unknown word is appended
        to the master list and its new index is returned.
        """
        if word is None:
            return None

        if word not in self.vocab:
            if not update:
                return None

            # Unknown locally: take the lock, refresh the local copy from the
            # master, and append the word only if it is still missing.
            with self.lock:
                self.vocab = OrderedSet(list(self.master))

                if word not in self.vocab:
                    self.vocab.add(word)
                    self.master.append(word)

                return self.vocab.index(word)

        return self.vocab.index(word)

    def __len__(self) -> int:
        """Number of words in the local copy."""
        return len(self.vocab)
67
+
68
class OnlineVocabularyModel(torch.nn.Module):
    """Module that owns a manager-backed vocabulary list and checkpoints it.

    The words live in a ``multiprocessing`` manager list (``master``) paired
    with a manager lock. The list is written to, and restored from, the
    module's state dict under the key ``"<prefix>vocabulary"``.
    """

    def __init__(self):
        super().__init__()

        self.manager: SyncManager = Manager()
        self.master: ListProxy[str] = self.manager.list()
        self.lock: Lock = self.manager.Lock()
        # Plain-list copy of ``master`` and the length it was taken at.
        self._snapshot_cache: list[str] | None = None
        self._snapshot_size: int = -1

    def _save_to_state_dict(self, state_dict, prefix, keep_vars):
        """Add the current vocabulary, as a plain list, to ``state_dict``."""
        super()._save_to_state_dict(state_dict, prefix, keep_vars)

        state_dict[prefix + "vocabulary"] = list(self.master)

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata,
        strict, missing_keys, unexpected_keys, error_msgs
    ):
        """Restore the vocabulary from ``state_dict`` if it is present.

        The entry is popped before delegating so the base implementation does
        not see it. When the entry is absent the current vocabulary is kept
        and the key is reported through ``missing_keys`` instead of raising
        ``KeyError``, so ``load_state_dict`` can apply its own ``strict``
        handling.
        """
        key = prefix + "vocabulary"

        if key in state_dict:
            vocab: list[str] = state_dict.pop(key)
            # NOTE(review): this rebinds ``master`` to a new proxy; Vocabulary
            # objects created earlier still point at the previous list.
            self.master = self.manager.list(vocab)
            self._snapshot_cache = None
            self._snapshot_size = -1
        elif strict:
            missing_keys.append(key)

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata,
            strict, missing_keys, unexpected_keys, error_msgs
        )

    @property
    def state(self) -> Vocabulary:
        """A new ``Vocabulary`` bound to the shared list and lock."""
        return Vocabulary(master=self.master, lock=self.lock)

    def snapshot(self) -> list[str]:
        """Return a plain-list copy of the vocabulary.

        The copy is cached and only rebuilt when the length of the master
        list differs from the length seen when the cache was filled.
        """
        size = len(self.master)
        if self._snapshot_cache is None or self._snapshot_size != size:
            self._snapshot_cache = list(self.master)
            self._snapshot_size = size

        return self._snapshot_cache
110
+
111
+
112
# Request schema for "category" leaves: the vocabulary cap, a band count, and
# the k values for which top-k results are produced.
@category.register
class Request(RequestBase):
    type: Literal["category"]
    max_vocab_size: Annotated[int, pydantic.Field(gt=0, default=10_000)]
    n_bands: Annotated[int, pydantic.Field(gt=0, default=8)]
    topk: Annotated[list[int], pydantic.Field(default_factory=list)]

    @pydantic.model_validator(mode="after")
    def check_topk(self):
        # Duplicates are dropped and the remaining values kept in ascending order.
        self.topk = sorted(set(self.topk))

        # Each value must be an integer k with 1 < k < max_vocab_size.
        for k in self.topk:
            if not isinstance(k, int):
                raise ValueError("topk values must be integers")

            if k < 1:
                raise ValueError("topk values must be positive")

            if k == 1:
                raise ValueError("topk values must not be 1")

            if not k < self.max_vocab_size:
                raise ValueError("topk values must be less than max_vocab_size")

        return self
139
+
140
+
141
+
142
# Tensorfield for categorical values.
#   state     - per-position token id (compared against ``Tokens`` values)
#   content   - per-position vocabulary index
#   trainable - bool mask of positions hidden by ``mask`` / ``prune``
#   targets   - copies of state/content taken before the first hiding step
@category.register
@tensorclass
class TensorField(TensorFieldBase):
    state: torch.Tensor
    content: torch.Tensor
    trainable: torch.Tensor
    targets: TensorDict[TensorKey, torch.Tensor]

    @classmethod
    def new(
        cls,
        values: list,
        address: Address,
        session: Session,
        strata: Strata,
        state: Vocabulary,
    ) -> TensorFieldBase:
        """Tokenise ``values`` with the vocabulary and pad them into tensors.

        The vocabulary is only allowed to grow when ``strata`` is
        ``Strata.train``. A message is printed (nothing is raised) when the
        vocabulary has outgrown the request's ``max_vocab_size``.
        """
        context_shape: tuple[int, ...] = session.structure.shapes[address]

        # presumably ``apply`` maps the callable over the nested values - verify
        tokens = apply(values, partial(state, update=(strata == Strata.train)))

        if len(state) > (max_vocab_size := session.structure.requests[address].max_vocab_size):
            print(f"Vocab in address {address} exceeds max vocab size of {max_vocab_size}")

        data, states = pad(
            nested=tokens,
            shape=(len(values), *context_shape),
            dtype=np.int64,
            pad_value=0,
        )

        state_tensor = torch.tensor(states, dtype=torch.int64)
        content = torch.tensor(data=data, dtype=torch.int64)

        return cls(
            state=state_tensor,
            content=content,
            trainable=torch.zeros_like(input=state_tensor, dtype=torch.bool),
            targets=TensorDict({}),
            batch_size=len(values),
        )

    def mask(self, p_mask: float):
        """Hide individual positions, each independently with probability ``p_mask``.

        Hidden positions get the ``Tokens.masked`` state and zero content and
        are flagged in ``trainable``. The untouched state/content are saved in
        ``targets`` the first time any hiding step runs.
        """
        mask_token = torch.full_like(input=self.state, fill_value=Tokens.masked.value)
        is_masked = torch.rand_like(input=self.state, dtype=torch.float).lt(other=p_mask)

        if TensorKey.state not in self.targets.keys():
            self.targets[TensorKey.state] = self.state.clone()

        if TensorKey.content not in self.targets.keys():
            self.targets[TensorKey.content] = self.content.clone()

        self.state = self.state.masked_scatter(is_masked, mask_token)
        self.content = self.content.masked_scatter(is_masked, torch.zeros_like(input=self.content))

        self.trainable |= is_masked

    def prune(self, p_prune: float = 1.0):
        """Hide whole leading-index entries, each with probability ``p_prune``.

        One random draw is made per index of the first dimension and expanded
        over the remaining dimensions. Hidden positions get the
        ``Tokens.pruned`` state and zero content and are flagged in
        ``trainable``.
        """
        prune_tokens = torch.full_like(input=self.state, fill_value=Tokens.pruned.value)

        is_pruned = (
            torch.rand(self.state.size(0), *([1] * (len(self.state.shape) - 1)), device=self.state.device)
            .lt(p_prune)
            .expand_as(self.state)
        )

        if TensorKey.state not in self.targets.keys():
            self.targets[TensorKey.state] = self.state.clone()

        if TensorKey.content not in self.targets.keys():
            self.targets[TensorKey.content] = self.content.clone()

        self.state = self.state.masked_scatter(is_pruned, prune_tokens)
        self.content = self.content.masked_scatter(is_pruned, torch.zeros_like(input=self.content))

        self.trainable |= is_pruned

    @classmethod
    def empty(
        cls,
        batch_size: int,
        address: Address,
        structure: Structure,
    ):
        """Build a tensorfield in which every position is pruned.

        The state is filled with the integer value of ``Tokens.pruned`` and
        created as int64, matching the state tensors built by ``new``.
        """
        shape: tuple[int, ...] = (batch_size, *structure.shapes[address])

        state = torch.full(shape, Tokens.pruned.value, dtype=torch.int64)
        content = torch.zeros(shape, dtype=torch.int64)

        return cls(
            state=state,
            content=content,
            trainable=torch.zeros_like(input=state, dtype=torch.bool),
            targets=TensorDict({}),
            batch_size=batch_size,
        )
239
+
240
+
241
@category.register
class Embedder(EmbedderBase):
    """Embed a category tensorfield as state embedding plus content embedding.

    The content embedding only contributes at positions whose state equals
    ``Tokens.valued``. The embedder also owns the vocabulary model for its
    address.
    """

    def __init__(self, structure: Structure, address: Address):
        super().__init__(structure=structure, address=address)

        request: Request = structure.requests[address]

        self.origin: Address = address
        self.destination: Address = request.parent.address
        self.max_vocab_size: int = request.max_vocab_size

        self.vocab: OnlineVocabularyModel = OnlineVocabularyModel()

        # One table per tensor key: states index into the Tokens enum,
        # contents index into the vocabulary.
        state_table = torch.nn.Embedding(
            num_embeddings=len(Tokens),
            embedding_dim=structure.d_model,
        )
        content_table = torch.nn.Embedding(
            num_embeddings=request.max_vocab_size,
            embedding_dim=structure.d_model,
        )
        self.embeddings = torch.nn.ModuleDict(
            {
                TensorKey.state.name: state_table,
                TensorKey.content.name: content_table,
            }
        )

    @beartype
    def forward(self, inputs: TensorFieldBase) -> Parcel:
        """Embed ``inputs`` and wrap the result in a ``Parcel``.

        Raises:
            ValueError: if a valued position holds a content index that is not
                below ``max_vocab_size``.
        """
        batch, *inner = inputs.state.shape
        flat_state = inputs.state.reshape(-1)
        flat_content = inputs.content.reshape(-1)
        has_value = flat_state.eq(Tokens.valued.value)

        if has_value.any():
            overflow = flat_content.masked_select(has_value) >= self.max_vocab_size
            if overflow.any().item():
                raise ValueError(f"Token in address {self.origin} exceeds max vocab size of {self.max_vocab_size}")

        # Non-valued positions look up index 0 and are zeroed by the mask below.
        lookup_ids = flat_content.masked_fill(~has_value, 0)

        state_part = self.embeddings[TensorKey.state.name](flat_state)
        content_part = self.embeddings[TensorKey.content.name](lookup_ids) * has_value.unsqueeze(-1)
        payload: torch.Tensor = (state_part + content_part).reshape(batch, *inner, -1)

        return Parcel(
            payload=payload,
            origin=self.origin,
            destination=self.destination,
            batch_size=batch,
        )

    @property
    def state(self) -> Vocabulary:
        """A ``Vocabulary`` view of this embedder's vocabulary model."""
        return self.vocab.state
297
+
298
+
299
+
300
@category.register
class Decoder(DecoderBase):
    """Decode pooled features into state logits and content logits.

    For each of the two tensor keys the decoder holds one linear head and one
    ``Counter`` of the same width.
    """

    def __init__(self, structure: Structure, address: Address):
        super().__init__(structure=structure, address=address)

        request: RequestBase = structure.requests[address]

        # Output width per head: the Tokens enum for state, the vocabulary cap
        # for content.
        widths: dict[str, int] = {
            TensorKey.state.name: len(Tokens),
            TensorKey.content.name: request.max_vocab_size,
        }

        self.linears = torch.nn.ModuleDict(
            {
                name: torch.nn.Linear(in_features=structure.d_model, out_features=width)
                for name, width in widths.items()
            }
        )

        self.counters = torch.nn.ModuleDict(
            {name: Counter(address=address, size=width) for name, width in widths.items()}
        )

    @beartype
    def decode(self, pooled: torch.Tensor) -> TensorDict[TensorKey, torch.Tensor]:
        """Apply both linear heads to ``pooled``."""
        heads = (TensorKey.state, TensorKey.content)
        return TensorDict(source={key: self.linears[key.name](pooled) for key in heads})
335
+
336
+
337
@category.register
def loss(
    module: JSON2Vec,
    prediction: Prediction,
    batch: TensorFieldBase,
    strata: Strata,
) -> torch.Tensor:
    """Compute and track the category loss and accuracy metrics.

    The state loss is a counter-weighted cross-entropy averaged over the
    trainable positions. If any trainable position has a ``Tokens.valued``
    target, the content loss (averaged over those positions) is added and the
    content accuracy plus every requested top-k accuracy are tracked.

    Returns:
        The state loss, or state loss plus content loss.
    """
    decoder: Decoder = module.nodes[prediction.address].decoder
    n_positions: int = batch.targets[TensorKey.state].numel()
    trainable = batch.trainable.reshape(n_positions)

    state_inputs = prediction.payload[TensorKey.state].reshape(n_positions, -1)
    state_targets = batch.targets[TensorKey.state].reshape(n_positions)
    # The counter is fed before its weight is read for the loss below.
    decoder.counters[TensorKey.state.name](batch.targets[TensorKey.state])

    state_loss: torch.Tensor = module.track(
        (prediction.address, strata, Metric.loss, TensorKey.state),
        value=(
            torch.nn.functional.cross_entropy(
                input=state_inputs,
                target=state_targets,
                weight=decoder.counters[TensorKey.state.name].weight,
                reduction="none",
            )
            .masked_select(trainable)
            .mean()
        )
    )

    module.track(
        (prediction.address, strata, Metric.accuracy, TensorKey.state),
        value=state_inputs.argmax(dim=1).eq(state_targets).masked_select(trainable).float().mean(),
    )

    valued = trainable & state_targets.eq(Tokens.valued.value)
    if not valued.any():
        return state_loss

    content_inputs = prediction.payload[TensorKey.content].reshape(n_positions, -1)
    content_targets = batch.targets[TensorKey.content].reshape(n_positions)
    # The content counter sees every valued target, trainable or not.
    content_counter_values = content_targets.masked_select(state_targets.eq(Tokens.valued.value))
    if content_counter_values.numel() > 0:
        decoder.counters[TensorKey.content.name](content_counter_values)

    content_loss: torch.Tensor = module.track(
        (prediction.address, strata, Metric.loss, TensorKey.content),
        value=(
            torch.nn.functional.cross_entropy(
                input=content_inputs,
                target=content_targets,
                weight=decoder.counters[TensorKey.content.name].weight,
                reduction="none",
            )
            .masked_select(valued)
            .mean()
        )
    )

    # Out-of-place sum: an in-place "+=" would modify the tensor object that
    # was returned by (and may still be referenced from) the state-loss track
    # call. NOTE(review): whether ``track`` keeps that object is not visible here.
    total = state_loss + content_loss

    for topk in module.session.structure.requests[prediction.address].topk:
        module.track(
            (prediction.address, strata, Metric.accuracy, f"top{topk}"),
            value=(
                content_inputs
                .topk(k=topk, dim=1)
                .indices.eq(content_targets.unsqueeze(1))
                .any(dim=1)
                .masked_select(valued).float().mean()
            )
        )

    module.track(
        (prediction.address, strata, Metric.accuracy, TensorKey.content),
        value=content_inputs.argmax(dim=1).eq(content_targets).masked_select(valued).float().mean(),
    )

    return total
413
+
414
+
415
# Exactly ``module`` and ``prediction`` are annotated here: Plugin.register
# validates the annotated parameter names of a write function.
@category.register
def write(module: JSON2Vec, prediction: Prediction):
    """Turn a category prediction into plain labels and probabilities.

    Returns a dict with two entries:
      * the state entry maps every ``Tokens`` name to its probability array;
      * the content entry holds the most likely label, its probability and,
        when top-k values are requested, nested lists of
        ``{"label", "probability"}`` candidates.

    Content probabilities are normalised over the known vocabulary only. With
    an empty vocabulary the labels are ``None``, the probabilities are zero
    and the candidate lists are empty.
    """
    node = module.nodes[prediction.address]
    state_logits: torch.Tensor = prediction.payload[TensorKey.state]
    content_logits: torch.Tensor = prediction.payload[TensorKey.content]

    # exp(logit - logsumexp) is the softmax over the last dimension.
    state_log_norm = state_logits.logsumexp(dim=-1, keepdim=True)
    state_distribution = (state_logits - state_log_norm).exp().detach().float().cpu().numpy()
    state_payload = {token.name: state_distribution[..., index] for index, token in enumerate(Tokens)}

    vocab = np.array(node.embedder.vocab.snapshot(), dtype=object)
    content_shape = tuple(state_distribution.shape[:-1])
    content_labels = np.full(content_shape, None, dtype=object)
    content_probabilities = np.zeros(content_shape, dtype=np.float32)

    requested_ks: list[int] = module.session.structure.requests[prediction.address].topk
    max_requested_k: int = max(requested_ks, default=0)

    def _pack_candidates(labels: np.ndarray, probabilities: np.ndarray) -> list[dict[str, float]] | list:
        # Recurse down to the last axis, where each (label, probability) pair
        # becomes one candidate dict.
        if labels.ndim == 1:
            return [
                {"label": str(label), "probability": float(probability)}
                for label, probability in zip(labels.tolist(), probabilities.tolist())
            ]

        return [_pack_candidates(labels[index], probabilities[index]) for index in range(labels.shape[0])]

    def _empty_candidates(shape: tuple[int, ...]) -> list | None:
        # Nested lists shaped like ``shape`` with an empty list at every leaf.
        if len(shape) == 0:
            return []

        return [_empty_candidates(shape[1:]) for _ in range(shape[0])]

    topk_payload: list | None = _empty_candidates(content_shape)

    # The shared vocabulary can hold more entries than the content head has
    # outputs (TensorField.new only prints when that happens), so the slice is
    # clamped to the logits' width instead of letting ``narrow`` raise.
    n_known: int = min(len(vocab), content_logits.shape[-1])

    if n_known > 0:
        narrow: torch.Tensor = content_logits.narrow(dim=-1, start=0, length=n_known)
        log_norm = narrow.logsumexp(dim=-1, keepdim=True)
        max_logits, max_indices = narrow.max(dim=-1)
        content_probabilities = (max_logits - log_norm.squeeze(-1)).exp().detach().float().cpu().numpy()

        max_indices_np: np.ndarray = max_indices.detach().cpu().numpy().astype(np.int32)
        content_labels = vocab[max_indices_np]

        if max_requested_k > 0:
            # Only the largest requested k is materialised, capped at the
            # number of known classes.
            topk: int = min(max_requested_k, narrow.shape[-1])
            topk_logits, topk_indices = narrow.topk(k=topk, dim=-1)
            topk_probabilities = (topk_logits - log_norm).exp()

            topk_indices_np: np.ndarray = topk_indices.detach().cpu().numpy().astype(np.int32)
            topk_labels_np: np.ndarray = vocab[topk_indices_np]
            topk_probabilities_np: np.ndarray = topk_probabilities.detach().float().cpu().numpy()
            topk_payload = _pack_candidates(
                labels=topk_labels_np,
                probabilities=topk_probabilities_np,
            )

    return {
        TensorKey.state.name: state_payload,
        TensorKey.content.name: {
            TensorKey.value.name: content_labels,
            TensorKey.probability.name: content_probabilities,
            TensorKey.topk.name: topk_payload,
        },
    }