json2vec 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51) hide show
  1. json2vec/__init__.py +0 -0
  2. json2vec/__main__.py +32 -0
  3. json2vec/architecture/__init__.py +0 -0
  4. json2vec/architecture/attention.py +64 -0
  5. json2vec/architecture/counter.py +37 -0
  6. json2vec/architecture/encoder.py +88 -0
  7. json2vec/architecture/node.py +34 -0
  8. json2vec/architecture/pool.py +61 -0
  9. json2vec/architecture/root.py +338 -0
  10. json2vec/architecture/rotary.py +39 -0
  11. json2vec/data/__init__.py +0 -0
  12. json2vec/data/datasets.py +539 -0
  13. json2vec/data/processing.py +152 -0
  14. json2vec/entrypoints/__init__.py +3 -0
  15. json2vec/entrypoints/pipeline.py +174 -0
  16. json2vec/inference/__init__.py +0 -0
  17. json2vec/inference/callback.py +98 -0
  18. json2vec/inference/deployment.py +175 -0
  19. json2vec/logging/__init__.py +0 -0
  20. json2vec/logging/config.py +27 -0
  21. json2vec/logging/epoch.py +42 -0
  22. json2vec/logging/throughput.py +39 -0
  23. json2vec/logging/tracking.py +152 -0
  24. json2vec/processors/__init__.py +8 -0
  25. json2vec/processors/base.py +102 -0
  26. json2vec/processors/extensions/__init__.py +0 -0
  27. json2vec/processors/extensions/example.py +6 -0
  28. json2vec/processors/spec.py +8 -0
  29. json2vec/structs/__init__.py +0 -0
  30. json2vec/structs/enums.py +84 -0
  31. json2vec/structs/environment.py +138 -0
  32. json2vec/structs/experiment.py +330 -0
  33. json2vec/structs/packages.py +117 -0
  34. json2vec/structs/structure.py +70 -0
  35. json2vec/structs/tree.py +92 -0
  36. json2vec/tensorfields/__init__.py +8 -0
  37. json2vec/tensorfields/base.py +210 -0
  38. json2vec/tensorfields/extensions/__init__.py +0 -0
  39. json2vec/tensorfields/extensions/category.py +484 -0
  40. json2vec/tensorfields/extensions/dateparts.py +410 -0
  41. json2vec/tensorfields/extensions/entity.py +336 -0
  42. json2vec/tensorfields/extensions/number.py +400 -0
  43. json2vec/tensorfields/extensions/vector.py +279 -0
  44. json2vec/tensorfields/spec.py +8 -0
  45. json2vec-0.1.0.dist-info/METADATA +227 -0
  46. json2vec-0.1.0.dist-info/RECORD +51 -0
  47. json2vec-0.1.0.dist-info/WHEEL +5 -0
  48. json2vec-0.1.0.dist-info/entry_points.txt +2 -0
  49. json2vec-0.1.0.dist-info/licenses/LICENSE +178 -0
  50. json2vec-0.1.0.dist-info/licenses/NOTICE +8 -0
  51. json2vec-0.1.0.dist-info/top_level.txt +1 -0
json2vec/__init__.py ADDED
File without changes
json2vec/__main__.py ADDED
@@ -0,0 +1,32 @@
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+
5
+ from loguru import logger
6
+
7
+ from json2vec.entrypoints import execute
8
+ from json2vec.structs.experiment import Experiment
9
+
10
+
11
def train() -> None:
    """Command-line entry point: build an ``Experiment`` from CLI flags and execute it.

    Flags ``--experiment``, ``--name`` and ``--notes`` are optional strings.
    ``--experiments`` defaults to ``"experiments"`` and is passed positionally to
    ``Experiment.from_config``. The output of every executed session is logged.
    """
    parser = argparse.ArgumentParser()
    for flag, default in (
        ("--experiment", None),
        ("--name", None),
        ("--notes", None),
        ("--experiments", "experiments"),
    ):
        parser.add_argument(flag, type=str, default=default)
    args = parser.parse_args()

    experiment: Experiment = Experiment.from_config(
        args.experiments,
        experiment=args.experiment,
        name=args.name,
        notes=args.notes,
    )

    for session_name, output in execute(experiment=experiment).items():
        logger.info(f"session={session_name} output={output}")


if __name__ == "__main__":
    train()
File without changes
@@ -0,0 +1,64 @@
1
+ import math
2
+
3
+ import torch
4
+
5
+ from json2vec.architecture.rotary import RotaryEmbedding
6
+
7
+
8
+ class RotaryMultiheadAttention(torch.nn.Module):
9
+ def __init__(self, d_model: int, nhead: int, dropout: float):
10
+ super().__init__()
11
+
12
+ if d_model % nhead != 0:
13
+ raise ValueError("d_model must be divisible by nhead")
14
+
15
+ self.d_model = d_model
16
+ self.nhead = nhead
17
+ self.head_dim = d_model // nhead
18
+ self.scale = 1.0 / math.sqrt(self.head_dim)
19
+
20
+ self.q_proj = torch.nn.Linear(d_model, d_model)
21
+ self.k_proj = torch.nn.Linear(d_model, d_model)
22
+ self.v_proj = torch.nn.Linear(d_model, d_model)
23
+ self.out_proj = torch.nn.Linear(d_model, d_model)
24
+
25
+ self.rotary = RotaryEmbedding(d_model=self.head_dim)
26
+ self.dropout = torch.nn.Dropout(p=dropout)
27
+
28
+ def splitheads(self, inputs: torch.Tensor) -> torch.Tensor:
29
+ batch, seq_len, _ = inputs.shape
30
+ return inputs.reshape(batch, seq_len, self.nhead, self.head_dim).transpose(1, 2)
31
+
32
+ def rotate(self, inputs: torch.Tensor) -> torch.Tensor:
33
+ batch, nhead, seq_len, head_dim = inputs.shape
34
+ rotated = self.rotary(inputs.reshape(batch * nhead, seq_len, head_dim))
35
+ return rotated.reshape(batch, nhead, seq_len, head_dim)
36
+
37
+ def forward(
38
+ self,
39
+ query: torch.Tensor,
40
+ key: torch.Tensor,
41
+ value: torch.Tensor,
42
+ key_padding_mask: torch.Tensor | None = None,
43
+ ) -> torch.Tensor:
44
+ q = self.rotate(self.splitheads(self.q_proj(query)))
45
+ k = self.rotate(self.splitheads(self.k_proj(key)))
46
+ v = self.splitheads(self.v_proj(value))
47
+
48
+ scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
49
+
50
+ if key_padding_mask is not None:
51
+ mask = key_padding_mask
52
+ all_masked = mask.all(dim=1)
53
+ if all_masked.any():
54
+ mask = mask.clone()
55
+ mask[all_masked, 0] = False
56
+ scores = scores.masked_fill(mask[:, None, None, :], torch.finfo(scores.dtype).min)
57
+
58
+ probs = torch.softmax(scores, dim=-1)
59
+ probs = self.dropout(probs)
60
+
61
+ context = torch.matmul(probs, v)
62
+ context = context.transpose(1, 2).reshape(query.shape[0], query.shape[1], self.d_model)
63
+
64
+ return self.out_proj(context)
@@ -0,0 +1,37 @@
1
+ import torch
2
+
3
+ from json2vec.structs.tree import Address
4
+
5
+
6
class Counter(torch.nn.Module):
    """Running per-index occurrence counter used to derive frequency-based weights.

    Counts live in the ``counts`` buffer (so they are saved with the state dict) and
    are only accumulated while the module is in training mode.

    Args:
        address: accepted for interface parity; not used by this module.
        size: number of distinct indices to count.
    """

    def __init__(self, address: "Address", size: int):
        super().__init__()

        self.size: int = size

        # init with ones to avoid division by zero
        # it doesn't matter much since we will normalize over time
        self.register_buffer("counts", torch.ones(size, dtype=torch.int64))
        # plain attribute (not a buffer): it is not stored in the state dict
        self.is_full: bool = False

    @torch.no_grad()
    def forward(self, values: torch.Tensor) -> torch.Tensor:
        """Count the occurrences in ``values`` while training, then return ``values`` unchanged.

        ``values`` may have any shape and memory layout. ``torch.bincount`` requires
        non-negative integers. Entries are presumably indices in ``[0, size)``; a
        larger entry makes ``bincount`` return more than ``size`` bins, and the
        in-place add then fails.
        """
        if self.training and not self.is_full:
            next_count_max = int(self.counts.max().item()) + int(values.numel())
            could_overflow = next_count_max > torch.iinfo(self.counts.dtype).max

            if could_overflow:
                # if we are approaching the max value, we stop counting and assume the counts are full
                self.is_full = True
                return values

            # reshape (not view): view(-1) raises on non-contiguous inputs such as
            # transposed or sliced tensors
            flat = values.reshape(-1)
            self.counts += torch.bincount(flat, minlength=self.counts.shape[0]).to(self.counts.dtype)

        return values

    @property
    @torch.no_grad()
    def weight(self) -> torch.Tensor:
        """Inverse-square-root frequency weights.

        The weights are rescaled so that ``sum(weight * counts) == sum(counts)``,
        i.e. their count-weighted mean is 1.
        """
        counts = self.counts.to(dtype=torch.float32)
        weights = counts.rsqrt()
        return weights * (counts.sum() / (weights * counts).sum())
@@ -0,0 +1,88 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING
4
+
5
+ import torch
6
+
7
+ from json2vec.architecture.attention import RotaryMultiheadAttention
8
+ from json2vec.architecture.pool import LearnedQueryCrossAttention
9
+ from json2vec.structs.packages import Parcel
10
+ from json2vec.structs.tree import Address
11
+
12
+ if TYPE_CHECKING:
13
+ from json2vec.structs.structure import Structure
14
+
15
+
16
class RotaryTransformerEncoderLayer(torch.nn.Module):
    """Pre-norm transformer encoder layer built on rotary self-attention.

    Args:
        d_model: model width.
        nhead: number of attention heads.
        dropout: dropout used in attention and in the feed-forward network.
        ffn_multiplier: hidden width of the feed-forward network as a multiple of ``d_model``.
    """

    def __init__(self, d_model: int, nhead: int, dropout: float, ffn_multiplier: int = 4):
        super().__init__()

        # each norm is applied to the *input* of its sub-layer (pre-norm residual layout)
        self.attention_norm = torch.nn.LayerNorm(normalized_shape=d_model)
        self.ffn_norm = torch.nn.LayerNorm(normalized_shape=d_model)

        self.attention = RotaryMultiheadAttention(d_model=d_model, nhead=nhead, dropout=dropout)

        width = ffn_multiplier * d_model
        self.ffn = torch.nn.Sequential(
            torch.nn.Linear(in_features=d_model, out_features=width),
            torch.nn.GELU(),
            torch.nn.Dropout(p=dropout),
            torch.nn.Linear(in_features=width, out_features=d_model),
            torch.nn.Dropout(p=dropout),
        )

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Self-attention then feed-forward, each wrapped in a residual connection."""
        normed = self.attention_norm(inputs)
        attended = inputs + self.attention(normed, normed, normed)
        return attended + self.ffn(self.ffn_norm(attended))
38
+
39
+
40
class ContextEncoder(torch.nn.Module):
    """Encode the parcels arriving at one context address and pool them for its parent.

    ``context.n_layers`` rotary encoder layers process the concatenated payloads.
    A learned-query cross-attention then pools them into ``context.n_outputs``
    vectors per sequence.
    """

    def __init__(self, structure: Structure, address: Address):
        super().__init__()

        context = structure.contexts[address]

        self.origin: Address = address
        self.destination: Address = context.parent.address

        self.encoder = torch.nn.ModuleList(
            [
                RotaryTransformerEncoderLayer(
                    d_model=structure.d_model,
                    nhead=context.n_heads,
                    dropout=structure.dropout,
                )
                for _ in range(context.n_layers)
            ]
        )

        self.pool = LearnedQueryCrossAttention(
            n_context=context.n_outputs,
            d_model=structure.d_model,
            nhead=context.n_heads,
            dropout=structure.dropout,
            n_linear=context.n_linear,
        )

    def forward(self, parcels: list[Parcel]) -> Parcel:
        """Concatenate parcel payloads along dim -2, encode them, and return the pooled parcel.

        Payloads are unpacked as ``(N, *dims, L, C)``. Everything before ``L`` is
        flattened into the batch for encoding and pooling. When ``dims`` is non-empty,
        the reshape folds its last entry into the pooled-output axis.
        """
        stacked: torch.Tensor = torch.cat([parcel.payload for parcel in parcels], dim=-2)
        batch_size, *middle, length, channels = stacked.shape

        hidden: torch.Tensor = stacked.reshape(-1, length, channels)
        for layer in self.encoder:
            hidden = layer(hidden)

        pooled: torch.Tensor = self.pool(hidden).reshape(batch_size, *middle[:-1], -1, channels)

        return Parcel(
            payload=pooled,
            origin=self.origin,
            destination=self.destination,
            batch_size=batch_size,
        )
@@ -0,0 +1,34 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING
4
+
5
+ import torch
6
+
7
+ from json2vec.architecture.encoder import ContextEncoder
8
+ from json2vec.structs.tree import Address, Node
9
+ from json2vec.tensorfields.base import (
10
+ TENSORFIELDS,
11
+ DecoderBase,
12
+ EmbedderBase,
13
+ Plugin,
14
+ )
15
+
16
if TYPE_CHECKING:
    # the structure type lives in json2vec.structs.structure (there is no structs.config module)
    from json2vec.structs.structure import Structure


class NodeModule(torch.nn.Module):
    """Per-address container for the learnable parts of one node of the structure.

    Request addresses get the ``Embedder`` and ``Decoder`` of the tensorfield plugin
    registered for the request's ``type``. Context addresses get a ``ContextEncoder``.

    Raises:
        ValueError: if ``address`` is neither a request nor a context of ``structure``.
    """

    def __init__(self, structure: Structure, address: Address):
        super().__init__()

        if address in structure.requests:
            request: Node = structure.requests[address]
            plugin: Plugin = TENSORFIELDS[request.type]
            self.embedder: EmbedderBase = plugin.Embedder(structure=structure, address=address)
            self.decoder: DecoderBase = plugin.Decoder(structure=structure, address=address)

        elif address in structure.contexts:
            self.encoder: ContextEncoder = ContextEncoder(structure=structure, address=address)

        else:
            raise ValueError(f"address {address!r} is neither a request nor a context of the structure")
@@ -0,0 +1,61 @@
1
+ import torch
2
+
3
+ from json2vec.architecture.attention import RotaryMultiheadAttention
4
+
5
+
6
class CrossAttentionBlock(torch.nn.Module):
    """Pre-norm cross-attention block: queries attend to ``memory``, then a feed-forward layer.

    Only the queries are layer-normed before attention; ``memory`` is used as given.
    """

    def __init__(self, d_model: int, nhead: int, dropout: float, ffn_multiplier: int):
        super().__init__()

        self.attention_norm = torch.nn.LayerNorm(normalized_shape=d_model)
        self.ffn_norm = torch.nn.LayerNorm(normalized_shape=d_model)
        self.attention = RotaryMultiheadAttention(d_model=d_model, nhead=nhead, dropout=dropout)

        expanded = ffn_multiplier * d_model
        self.ffn = torch.nn.Sequential(
            torch.nn.Linear(in_features=d_model, out_features=expanded),
            torch.nn.GELU(),
            torch.nn.Dropout(p=dropout),
            torch.nn.Linear(in_features=expanded, out_features=d_model),
            torch.nn.Dropout(p=dropout),
        )

    def forward(self, queries: torch.Tensor, memory: torch.Tensor) -> torch.Tensor:
        """Return updated queries; both sub-layers are wrapped in residual connections."""
        mixed = queries + self.attention(self.attention_norm(queries), memory, memory)
        return mixed + self.ffn(self.ffn_norm(mixed))
27
+
28
+
29
class LearnedQueryCrossAttention(torch.nn.Module):
    """Pool a sequence into ``n_context`` vectors using learned queries.

    ``n_linear`` stacked ``CrossAttentionBlock`` layers refine the learned queries
    against the input ``memory``, and the result is layer-normed.
    """

    def __init__(
        self,
        n_context: int,
        d_model: int,
        nhead: int,
        dropout: float,
        n_linear: int = 1,
        ffn_multiplier: int = 4,
    ):
        super().__init__()

        # one learned query row per pooled output, initialised with a small std
        self.queries = torch.nn.Parameter(torch.normal(mean=0.0, std=1e-2, size=(n_context, d_model)))
        self.blocks = torch.nn.ModuleList(
            CrossAttentionBlock(
                d_model=d_model,
                nhead=nhead,
                dropout=dropout,
                ffn_multiplier=ffn_multiplier,
            )
            for _ in range(n_linear)
        )
        self.norm = torch.nn.LayerNorm(normalized_shape=d_model)

    def forward(self, memory: torch.Tensor) -> torch.Tensor:
        """Map a 3-D ``memory`` of shape ``(batch, seq, d_model)`` to ``(batch, n_context, d_model)``."""
        batch_size, _seq_len, _width = memory.shape
        pooled = self.queries[None].expand(batch_size, -1, -1)

        for block in self.blocks:
            pooled = block(queries=pooled, memory=memory)

        return self.norm(pooled)
@@ -0,0 +1,338 @@
1
+ import math
2
+ import traceback
3
+ from collections import defaultdict
4
+ from functools import cache, partialmethod, wraps
5
+ from typing import Any, NotRequired, Self, TypedDict
6
+
7
+ import lightning.pytorch as lit
8
+ import torch
9
+ from beartype import beartype
10
+ from loguru import logger
11
+ from tensordict import TensorDict
12
+
13
+ from json2vec.architecture.encoder import ContextEncoder
14
+ from json2vec.architecture.node import NodeModule
15
+ from json2vec.data.datasets import dataloader, mock
16
+ from json2vec.structs.enums import Metric, Strata, TensorKey
17
+ from json2vec.structs.experiment import Session
18
+ from json2vec.structs.packages import Embedding, Parcel, Prediction
19
+ from json2vec.structs.tree import Address
20
+ from json2vec.tensorfields.base import (
21
+ TENSORFIELDS,
22
+ DecoderBase,
23
+ EmbedderBase,
24
+ Plugin,
25
+ RequestBase,
26
+ TensorFieldBase,
27
+ )
28
+
29
+
30
+ class Output(TypedDict):
31
+ loss: NotRequired[torch.Tensor]
32
+ predictions: list[Prediction]
33
+
34
+
35
@beartype
def step(
    module: "JSON2Vec",
    batch: TensorDict[Address, TensorFieldBase],
    batch_idx: int,
    strata: Strata,
) -> Output:
    """Shared body of the Lightning ``training/validation/test/predict_step`` hooks.

    Runs ``module.forward(batch)``. For ``Strata.predict`` the predictions are
    returned without a loss. Otherwise each non-``Embedding`` prediction is scored
    by its tensorfield plugin's ``loss``. Each score is scaled by the request's
    ``weight``, the scores are summed, and the sum is logged via ``module.track``.

    ``batch_idx`` is not used here; it is part of the Lightning hook signature.
    """
    predictions: list[Prediction] = module.forward(batch)

    if strata == Strata.predict:
        return Output(predictions=predictions)

    weighted: list[torch.Tensor] = []
    for prediction in predictions:
        if isinstance(prediction, Embedding):
            # embeddings are pass-through outputs and carry no supervised target
            continue

        request: RequestBase = module.session.structure.requests[prediction.address]
        plugin: Plugin = TENSORFIELDS[request.type]

        field_loss: torch.Tensor = plugin.loss(
            module=module,
            prediction=prediction,
            batch=batch[prediction.address],
            strata=strata,
        )
        weighted.append(field_loss * torch.tensor(request.weight))

    if not weighted:
        # under idealistic circumstances this would never happen.
        # but with small mask rates, batch sizes, and flat input data it is possible
        logger.warning("no trainable fields in batch, returning zero loss")
        zero: torch.Tensor = torch.tensor(0.0, device=batch.device, requires_grad=True)
        return Output(loss=zero, predictions=[])

    total: torch.Tensor = module.track((Metric.loss, strata), value=torch.stack(weighted).sum())

    return Output(loss=total, predictions=predictions)
71
+
72
+
73
def compile(fn):
    """Decorator: compile the model returned by ``fn`` with ``thunder`` when possible.

    ``thunder`` is optional. If it cannot be imported, or if compilation raises, the
    uncompiled model is returned. The reason is reported through ``logger`` rather
    than printed to stderr. Note that this name shadows the builtin ``compile``
    inside this module.
    """

    @wraps(fn)
    def wrapper(*args, **kwargs):
        model = fn(*args, **kwargs)

        try:
            import thunder
        except ImportError as error:
            # a missing optional dependency is expected: report it without a traceback
            logger.info(f"[thunder] unavailable ({error}). Returning uncompiled model instead.")
            return model
        except Exception:
            logger.opt(exception=True).warning("[thunder] import failed. Returning uncompiled model instead.")
            return model

        try:
            compiled = thunder.compile(model)
        except Exception:
            # compilation is best-effort; keep the traceback, but route it through loguru
            logger.opt(exception=True).warning("[thunder] compilation failed. Returning uncompiled model instead.")
            return model

        logger.info("successfully compiled module with thunder")
        return compiled

    return wrapper
88
+
89
+
90
@cache
def groupname(names: tuple[str, ...]) -> str:
    """Build a ``"group/key"`` metric name from a tuple of name parts.

    The first part becomes the group, and the remaining parts are joined with ``":"``.
    Every part is lower-cased, and any ``"/"`` inside a part becomes ``":"``, so the
    only ``"/"`` in the result is the group separator.

    Raises:
        ValueError: if fewer than two parts are given.
    """
    # explicit check instead of ``assert``, which is stripped under ``python -O``
    if len(names) < 2:
        raise ValueError(f"groupname needs at least two name parts, got {names!r}")

    group, *keys = (part.replace("/", ":").lower() for part in names)

    return f"{group}/{':'.join(keys)}"
99
+
100
+
101
class JSON2Vec(lit.LightningModule):
    """LightningModule that embeds JSON fields, encodes them up the structure and decodes fields.

    One ``NodeModule`` is built per request (field) address and per context address
    of ``session.structure``. ``forward`` works in three passes:
      1. it embeds the non-pruned requests;
      2. it encodes contexts level by level, from the deepest level towards the root;
      3. it decodes requests from the parcels of the addresses in their ``heritage``.
    The Lightning ``*_step`` and ``*_dataloader`` hooks are bound at the end of the
    class with ``partialmethod``.
    """

    @beartype
    def __init__(self, session: Session):

        super().__init__()

        self.session: Session = session

        # one NodeModule per request (embedder + decoder) and per context (encoder)
        self.nodes: torch.nn.ModuleDict[str, NodeModule] = torch.nn.ModuleDict()

        for address in self.session.structure.requests | self.session.structure.contexts:
            self.nodes[address] = NodeModule(structure=self.session.structure, address=address)

        # example batch exposed to Lightning (used e.g. for model summaries)
        self.example_input_array = mock(structure=session.structure)

        logger.bind(
            component="model",
            session=self.session.name,
            structure=self.session.structure.name,
            requests=len(self.session.structure.requests),
            contexts=len(self.session.structure.contexts),
            outputs=len(self.session.output),
        ).info("initialized JSON2Vec module")

    def track(self, names: tuple[str, ...], /, value: torch.Tensor) -> torch.Tensor:
        """Log ``value`` as an epoch-level, distributed-synced metric and return it unchanged."""
        self.log(
            name=groupname(names),
            value=value,
            on_step=False,
            on_epoch=True,
            sync_dist=True,
            batch_size=self.session.structure.batch_size,
        )

        return value

    @property
    def state(self) -> dict[Address, Any]:
        """``embedder.state`` of every node whose embedder exposes one, keyed by address."""
        return {
            address: node.embedder.state
            for address, node in self.nodes.items()
            if hasattr(node, "embedder") and hasattr(node.embedder, "state")
        }

    @beartype
    def forward(self, inputs: TensorDict[Address, TensorFieldBase]) -> list[Prediction]:
        """Embed, encode and decode one batch.

        The result contains:
          - an ``Embedding`` for every embedded or encoded address listed in
            ``session.output``;
          - a decoder prediction for every request that is pruned, or that has at
            least one ``trainable`` entry.
        """
        processed: dict[Address, list[Parcel]] = defaultdict(list)
        outgoing: dict[Address, Parcel] = {}
        predictions: list[Prediction] = []

        # 1. embed the leaves (requests); pruned requests are hidden from the encoder
        for address in self.session.structure.requests.keys():
            if address in self.session.pruned:
                continue

            tensorfield: TensorFieldBase = inputs[address]
            embedder: EmbedderBase = self.nodes[address].embedder
            embedding: Parcel = embedder(tensorfield)
            processed[embedding.destination].append(embedding)
            outgoing[embedding.origin] = embedding

            if address in self.session.output:
                predictions.append(Embedding.from_parcel(embedding))

        # 2. DAG traversal from leaves to root
        for depth in reversed(self.session.structure.depthwise):
            # these are order-independent within the same depth
            for address in depth:

                if len(processed[address]) == 0:
                    continue

                encoder: ContextEncoder = self.nodes[address].encoder
                encoding: Parcel = encoder(processed[address])
                processed[encoding.destination].append(encoding)
                outgoing[encoding.origin] = encoding

                if address in self.session.output:
                    predictions.append(Embedding.from_parcel(encoding))

        # 3. decode every request that is pruned or has at least one trainable entry
        for address in self.session.structure.requests.keys():
            # test the cheap membership first: pruned requests are always decoded, and
            # this avoids reading their inputs and a device sync from ``torch.any``
            if address not in self.session.pruned and not torch.any(inputs[address].trainable):
                continue

            heritage: list[Address] = self.session.structure.requests[address].heritage
            parcels: list[Parcel] = [
                outgoing[ancestor]
                for ancestor in heritage
                if ancestor not in self.session.pruned and ancestor in outgoing
            ]

            decoder: DecoderBase = self.nodes[address].decoder
            predictions.append(decoder(parcels))

        return predictions

    @beartype
    def configure_optimizers(self):
        """Configure AdamW and, when the step count is known, a warmup + cosine LR schedule.

        Two parameter groups are used. Biases, parameters with ``ndim <= 1`` and
        parameters with ``"norm"`` in their name get no weight decay; all other
        trainable parameters use ``session.weight_decay``. If the trainer reports
        ``estimated_stepping_batches > 0``, the LR warms up linearly over
        ``warmup_ratio`` of the steps. It then follows a cosine decay down to
        ``min_lr_ratio`` times the base LR, updated every step.

        Raises:
            ValueError: if ``session.learning_rate`` is not set.
        """

        if self.session.learning_rate is None:
            raise ValueError("learning_rate must be defined for optimizer configuration")

        class GroupedParameter(TypedDict):
            params: list[torch.nn.Parameter]
            weight_decay: float

        params: dict[str, GroupedParameter] = dict(
            with_decay=GroupedParameter(params=[], weight_decay=self.session.weight_decay),
            no_decay=GroupedParameter(params=[], weight_decay=0.0),
        )

        for name, parameter in self.named_parameters():
            if not parameter.requires_grad:
                continue

            if name.endswith(".bias") or parameter.ndim <= 1 or "norm" in name.lower():
                params["no_decay"]["params"].append(parameter)
            else:
                params["with_decay"]["params"].append(parameter)

        optimizer = torch.optim.AdamW(params=list(params.values()), lr=self.session.learning_rate, betas=(0.9, 0.95))
        trainable_parameters: int = sum(parameter.numel() for parameter in self.parameters() if parameter.requires_grad)
        logger.bind(
            component="optimizer",
            session=self.session.name,
            learning_rate=self.session.learning_rate,
            weight_decay=self.session.weight_decay,
            trainable_parameters=trainable_parameters,
            warmup_ratio=self.session.warmup_ratio,
            min_lr_ratio=self.session.min_lr_ratio,
        ).info("configured AdamW optimizer")

        total = int(getattr(self.trainer, "estimated_stepping_batches", 0) or 0)

        if total <= 0:
            return optimizer

        warmup = max(1, int(total * self.session.warmup_ratio))
        min_lr_ratio = self.session.min_lr_ratio

        def schedule(step: int) -> float:
            # LR multiplier: linear warmup, then cosine from 1.0 down to min_lr_ratio
            if step < warmup:
                return float(step + 1) / float(warmup)

            ratio = float(step - warmup) / float(max(1, total - warmup))

            progress = min(1.0, ratio)

            cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
            return min_lr_ratio + (1.0 - min_lr_ratio) * cosine

        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, schedule)

        return dict(
            optimizer=optimizer,
            lr_scheduler=dict(
                scheduler=scheduler,
                interval="step",
                frequency=1,
            ),
        )

    def on_save_checkpoint(self, checkpoint):
        """Store the session (as ``model_dump()``) under ``checkpoint["session"]``."""
        logger.bind(component="checkpoint", session=self.session.name).info("serializing session")
        checkpoint["session"] = self.session.model_dump()

    def on_load_checkpoint(self, checkpoint):
        """Restore the session from the checkpoint when the module has none.

        Raises:
            ValueError: if no session is available from either source.
        """
        logger.bind(component="checkpoint").info("loading session from checkpoint payload")
        if "session" in checkpoint and getattr(self, "session", None) is None:
            self.session = Session.model_validate(checkpoint["session"])

        if getattr(self, "session", None) is None:
            raise ValueError("missing session in checkpoint and constructor")

    @classmethod
    def get_or_create(
        cls,
        session: Session | None = None,
        checkpoint: str | None = None,
    ) -> Self:
        """Create a fresh model from ``session``, or restore one from a checkpoint file.

        When ``checkpoint`` is given, it is read with ``torch.load``. An explicitly
        passed ``session`` takes precedence over the ``"session"`` stored in the
        file. ``"state_dict"`` is then loaded non-strictly, and any missing or
        unexpected keys are logged as a warning.

        Raises:
            ValueError: if neither ``session`` nor ``checkpoint`` is given.
        """

        if checkpoint is None:
            logger.bind(component="model_factory").info("creating new JSON2Vec model")
            if session is None:
                raise ValueError("session is required when checkpoint is not provided")

            model: "JSON2Vec" = cls(session=session)

            return model

        logger.bind(component="model_factory", checkpoint=checkpoint).info("loading JSON2Vec model from checkpoint")
        # NOTE(review): weights_only=False unpickles arbitrary Python objects (presumably
        # needed for the stored session payload); only load checkpoints from trusted sources.
        state = torch.load(checkpoint, weights_only=False, map_location="cpu")

        model = cls(
            session=session or Session.model_validate(state["session"]),
        )

        # strict=False tolerates architecture drift, but report what was skipped instead of
        # silently leaving those parameters at their fresh initialization
        incompatible = model.load_state_dict(state_dict=state["state_dict"], strict=False)
        if incompatible.missing_keys or incompatible.unexpected_keys:
            logger.bind(
                component="model_factory",
                checkpoint=checkpoint,
                missing_keys=list(incompatible.missing_keys),
                unexpected_keys=list(incompatible.unexpected_keys),
            ).warning("checkpoint state_dict did not match the model; mismatched keys were ignored")
        logger.bind(component="model_factory", checkpoint=checkpoint).info("restored model state from checkpoint")

        return model

    def write(self, predictions: list[Prediction]) -> tuple[dict[Address, dict[str, Any]], dict[Address, dict[str, Any]]]:
        """Serialize predictions into ``(supervised, embeddings)``, both keyed by address.

        ``Embedding`` predictions are written with ``Embedding.write``. Every other
        prediction is written by its tensorfield plugin's ``write`` and serialized
        with ``Prediction.serialize``; it is skipped when the plugin returns ``None``.
        """

        supervised: dict[Address, dict[str, Any]] = {}
        embeddings: dict[Address, dict[str, Any]] = {}

        for prediction in predictions:

            if isinstance(prediction, Embedding):
                embeddings[prediction.address] = Embedding.write(prediction)
                continue

            request: RequestBase = self.session.structure.requests[prediction.address]

            extension: Plugin = TENSORFIELDS[request.type]

            scribed: dict[TensorKey, Any] | None = extension.write(module=self, prediction=prediction)

            if scribed is not None:
                supervised[prediction.address] = Prediction.serialize(scribed)

        return supervised, embeddings

    # Lightning hooks: every stage shares ``step`` / ``dataloader`` with its own strata
    training_step = partialmethod(step, strata=Strata.train)
    validation_step = partialmethod(step, strata=Strata.validate)
    test_step = partialmethod(step, strata=Strata.test)
    predict_step = partialmethod(step, strata=Strata.predict)

    train_dataloader = partialmethod(dataloader, strata=Strata.train)
    val_dataloader = partialmethod(dataloader, strata=Strata.validate)
    test_dataloader = partialmethod(dataloader, strata=Strata.test)
    predict_dataloader = partialmethod(dataloader, strata=Strata.predict)