json2vec 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- json2vec/__init__.py +0 -0
- json2vec/__main__.py +32 -0
- json2vec/architecture/__init__.py +0 -0
- json2vec/architecture/attention.py +64 -0
- json2vec/architecture/counter.py +37 -0
- json2vec/architecture/encoder.py +88 -0
- json2vec/architecture/node.py +34 -0
- json2vec/architecture/pool.py +61 -0
- json2vec/architecture/root.py +338 -0
- json2vec/architecture/rotary.py +39 -0
- json2vec/data/__init__.py +0 -0
- json2vec/data/datasets.py +539 -0
- json2vec/data/processing.py +152 -0
- json2vec/entrypoints/__init__.py +3 -0
- json2vec/entrypoints/pipeline.py +174 -0
- json2vec/inference/__init__.py +0 -0
- json2vec/inference/callback.py +98 -0
- json2vec/inference/deployment.py +175 -0
- json2vec/logging/__init__.py +0 -0
- json2vec/logging/config.py +27 -0
- json2vec/logging/epoch.py +42 -0
- json2vec/logging/throughput.py +39 -0
- json2vec/logging/tracking.py +152 -0
- json2vec/processors/__init__.py +8 -0
- json2vec/processors/base.py +102 -0
- json2vec/processors/extensions/__init__.py +0 -0
- json2vec/processors/extensions/example.py +6 -0
- json2vec/processors/spec.py +8 -0
- json2vec/structs/__init__.py +0 -0
- json2vec/structs/enums.py +84 -0
- json2vec/structs/environment.py +138 -0
- json2vec/structs/experiment.py +330 -0
- json2vec/structs/packages.py +117 -0
- json2vec/structs/structure.py +70 -0
- json2vec/structs/tree.py +92 -0
- json2vec/tensorfields/__init__.py +8 -0
- json2vec/tensorfields/base.py +210 -0
- json2vec/tensorfields/extensions/__init__.py +0 -0
- json2vec/tensorfields/extensions/category.py +484 -0
- json2vec/tensorfields/extensions/dateparts.py +410 -0
- json2vec/tensorfields/extensions/entity.py +336 -0
- json2vec/tensorfields/extensions/number.py +400 -0
- json2vec/tensorfields/extensions/vector.py +279 -0
- json2vec/tensorfields/spec.py +8 -0
- json2vec-0.1.0.dist-info/METADATA +227 -0
- json2vec-0.1.0.dist-info/RECORD +51 -0
- json2vec-0.1.0.dist-info/WHEEL +5 -0
- json2vec-0.1.0.dist-info/entry_points.txt +2 -0
- json2vec-0.1.0.dist-info/licenses/LICENSE +178 -0
- json2vec-0.1.0.dist-info/licenses/NOTICE +8 -0
- json2vec-0.1.0.dist-info/top_level.txt +1 -0
json2vec/__init__.py
ADDED
|
File without changes
|
json2vec/__main__.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import argparse
|
|
4
|
+
|
|
5
|
+
from loguru import logger
|
|
6
|
+
|
|
7
|
+
from json2vec.entrypoints import execute
|
|
8
|
+
from json2vec.structs.experiment import Experiment
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def train() -> None:
    """Command-line entry point: load an experiment and execute it.

    Recognised options (all strings):

    * ``--experiment``, ``--name``, ``--notes`` -- optional, ``None`` when omitted.
    * ``--experiments`` -- defaults to ``"experiments"`` and is passed
      positionally to ``Experiment.from_config``.

    The experiment is handed to ``execute`` and one log line is written for
    every ``(session_name, output)`` pair it returns.
    """
    parser = argparse.ArgumentParser()

    # The three optional overrides share one shape: a string that is unset by default.
    for flag in ("--experiment", "--name", "--notes"):
        parser.add_argument(flag, type=str, default=None)
    parser.add_argument("--experiments", type=str, default="experiments")

    cli = parser.parse_args()

    experiment: Experiment = Experiment.from_config(
        cli.experiments,
        experiment=cli.experiment,
        name=cli.name,
        notes=cli.notes,
    )

    for session_name, output in execute(experiment=experiment).items():
        logger.info(f"session={session_name} output={output}")


if __name__ == "__main__":
    train()
|
|
File without changes
|
|
@@ -0,0 +1,64 @@
|
|
|
1
|
+
import math
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
from json2vec.architecture.rotary import RotaryEmbedding
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class RotaryMultiheadAttention(torch.nn.Module):
|
|
9
|
+
def __init__(self, d_model: int, nhead: int, dropout: float):
|
|
10
|
+
super().__init__()
|
|
11
|
+
|
|
12
|
+
if d_model % nhead != 0:
|
|
13
|
+
raise ValueError("d_model must be divisible by nhead")
|
|
14
|
+
|
|
15
|
+
self.d_model = d_model
|
|
16
|
+
self.nhead = nhead
|
|
17
|
+
self.head_dim = d_model // nhead
|
|
18
|
+
self.scale = 1.0 / math.sqrt(self.head_dim)
|
|
19
|
+
|
|
20
|
+
self.q_proj = torch.nn.Linear(d_model, d_model)
|
|
21
|
+
self.k_proj = torch.nn.Linear(d_model, d_model)
|
|
22
|
+
self.v_proj = torch.nn.Linear(d_model, d_model)
|
|
23
|
+
self.out_proj = torch.nn.Linear(d_model, d_model)
|
|
24
|
+
|
|
25
|
+
self.rotary = RotaryEmbedding(d_model=self.head_dim)
|
|
26
|
+
self.dropout = torch.nn.Dropout(p=dropout)
|
|
27
|
+
|
|
28
|
+
def splitheads(self, inputs: torch.Tensor) -> torch.Tensor:
|
|
29
|
+
batch, seq_len, _ = inputs.shape
|
|
30
|
+
return inputs.reshape(batch, seq_len, self.nhead, self.head_dim).transpose(1, 2)
|
|
31
|
+
|
|
32
|
+
def rotate(self, inputs: torch.Tensor) -> torch.Tensor:
|
|
33
|
+
batch, nhead, seq_len, head_dim = inputs.shape
|
|
34
|
+
rotated = self.rotary(inputs.reshape(batch * nhead, seq_len, head_dim))
|
|
35
|
+
return rotated.reshape(batch, nhead, seq_len, head_dim)
|
|
36
|
+
|
|
37
|
+
def forward(
|
|
38
|
+
self,
|
|
39
|
+
query: torch.Tensor,
|
|
40
|
+
key: torch.Tensor,
|
|
41
|
+
value: torch.Tensor,
|
|
42
|
+
key_padding_mask: torch.Tensor | None = None,
|
|
43
|
+
) -> torch.Tensor:
|
|
44
|
+
q = self.rotate(self.splitheads(self.q_proj(query)))
|
|
45
|
+
k = self.rotate(self.splitheads(self.k_proj(key)))
|
|
46
|
+
v = self.splitheads(self.v_proj(value))
|
|
47
|
+
|
|
48
|
+
scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
|
|
49
|
+
|
|
50
|
+
if key_padding_mask is not None:
|
|
51
|
+
mask = key_padding_mask
|
|
52
|
+
all_masked = mask.all(dim=1)
|
|
53
|
+
if all_masked.any():
|
|
54
|
+
mask = mask.clone()
|
|
55
|
+
mask[all_masked, 0] = False
|
|
56
|
+
scores = scores.masked_fill(mask[:, None, None, :], torch.finfo(scores.dtype).min)
|
|
57
|
+
|
|
58
|
+
probs = torch.softmax(scores, dim=-1)
|
|
59
|
+
probs = self.dropout(probs)
|
|
60
|
+
|
|
61
|
+
context = torch.matmul(probs, v)
|
|
62
|
+
context = context.transpose(1, 2).reshape(query.shape[0], query.shape[1], self.d_model)
|
|
63
|
+
|
|
64
|
+
return self.out_proj(context)
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
from json2vec.structs.tree import Address
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class Counter(torch.nn.Module):
    """Running histogram of integer ids, exposed as per-id weights.

    ``forward`` is a pass-through: it returns ``values`` unchanged and, while
    the module is in training mode, adds the occurrences of every id to the
    ``counts`` buffer.  ``weight`` turns those counts into inverse-square-root
    weights.

    Args:
        address: accepted for signature compatibility; not used by this class.
        size: number of distinct ids, i.e. the length of the ``counts`` buffer.
    """

    def __init__(self, address: Address, size: int):
        super().__init__()

        self.size: int = size

        # Every bin starts at one so ``weight`` never evaluates rsqrt(0); the
        # offset becomes negligible as real counts accumulate.
        self.register_buffer("counts", torch.ones(size, dtype=torch.int64))
        # Plain attribute (not a buffer), so it is not part of the state dict.
        self.is_full: bool = False

    @torch.no_grad()
    def forward(self, values: torch.Tensor):
        """Record the ids in ``values`` (training mode only) and return ``values``.

        Counting stops permanently once adding ``values.numel()`` to the
        current largest bin could exceed the buffer dtype's maximum.
        """
        if self.training and not self.is_full:
            next_count_max = int(self.counts.max().item()) + int(values.numel())
            could_overflow = next_count_max > torch.iinfo(self.counts.dtype).max

            if could_overflow:
                # Approaching the dtype limit: freeze the histogram as it is.
                self.is_full = True
                return values

            # ``reshape`` rather than ``view``: ``view(-1)`` raises for
            # non-contiguous inputs (for example a transposed id tensor).
            flat = values.reshape(-1)
            self.counts += torch.bincount(flat, minlength=self.counts.shape[0]).to(self.counts.dtype)

        return values

    @property
    @torch.no_grad()
    def weight(self) -> torch.Tensor:
        """Per-id weights proportional to ``1 / sqrt(count)``.

        The weights are rescaled so that ``(weight * counts).sum()`` equals
        ``counts.sum()``, i.e. the count-weighted mean weight is one.
        """
        counts = self.counts.to(dtype=torch.float32)
        weights = counts.rsqrt()
        return weights * (counts.sum() / (weights * counts).sum())
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
|
|
7
|
+
from json2vec.architecture.attention import RotaryMultiheadAttention
|
|
8
|
+
from json2vec.architecture.pool import LearnedQueryCrossAttention
|
|
9
|
+
from json2vec.structs.packages import Parcel
|
|
10
|
+
from json2vec.structs.tree import Address
|
|
11
|
+
|
|
12
|
+
if TYPE_CHECKING:
|
|
13
|
+
from json2vec.structs.structure import Structure
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class RotaryTransformerEncoderLayer(torch.nn.Module):
    """Pre-norm transformer layer: rotary self-attention followed by a GELU feed-forward.

    Both sub-layers are wrapped in a residual connection, and each sees a
    layer-normalised copy of its input.
    """

    def __init__(self, d_model: int, nhead: int, dropout: float, ffn_multiplier: int = 4):
        """Build the two layer norms, the attention module and the feed-forward stack.

        The feed-forward hidden width is ``d_model * ffn_multiplier``.
        """
        super().__init__()

        self.attention_norm = torch.nn.LayerNorm(normalized_shape=d_model)
        self.ffn_norm = torch.nn.LayerNorm(normalized_shape=d_model)

        self.attention = RotaryMultiheadAttention(d_model=d_model, nhead=nhead, dropout=dropout)

        width = d_model * ffn_multiplier
        feed_forward = [
            torch.nn.Linear(in_features=d_model, out_features=width),
            torch.nn.GELU(),
            torch.nn.Dropout(p=dropout),
            torch.nn.Linear(in_features=width, out_features=d_model),
            torch.nn.Dropout(p=dropout),
        ]
        self.ffn = torch.nn.Sequential(*feed_forward)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Return ``inputs`` after the attention and feed-forward residual updates."""
        attention_in = self.attention_norm(inputs)
        after_attention = inputs + self.attention(attention_in, attention_in, attention_in)
        return after_attention + self.ffn(self.ffn_norm(after_attention))
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class ContextEncoder(torch.nn.Module):
    """Encode the parcels that arrive at one context address and pool them.

    The configuration is read from ``structure.contexts[address]``:
    ``n_layers`` rotary transformer layers are followed by a learned-query
    cross-attention pool that emits ``n_outputs`` vectors.  The result is
    addressed to the context's parent.
    """

    def __init__(self, structure: Structure, address: Address):
        super().__init__()

        context = structure.contexts[address]

        self.origin: Address = address
        self.destination: Address = context.parent.address

        self.encoder = torch.nn.ModuleList(
            [
                RotaryTransformerEncoderLayer(
                    d_model=structure.d_model,
                    nhead=context.n_heads,
                    dropout=structure.dropout,
                )
                for _ in range(context.n_layers)
            ]
        )

        self.pool = LearnedQueryCrossAttention(
            n_context=context.n_outputs,
            d_model=structure.d_model,
            nhead=context.n_heads,
            dropout=structure.dropout,
            n_linear=context.n_linear,
        )

    def forward(self, parcels: list[Parcel]) -> Parcel:
        """Concatenate the parcel payloads, encode them and return the pooled parcel.

        Payloads are joined along their second-to-last axis.  Every axis
        between the first and the last two is folded into the batch for the
        encoder and the pool; afterwards all but the last of those axes are
        restored, and the last one is merged with the pooled output axis.
        """
        stacked: torch.Tensor = torch.cat([parcel.payload for parcel in parcels], dim=-2)
        N, *dims, L, C = stacked.shape

        hidden: torch.Tensor = stacked.reshape(-1, L, C)
        for layer in self.encoder:
            hidden = layer(hidden)

        pooled: torch.Tensor = self.pool(hidden).reshape(N, *dims[:-1], -1, C)

        return Parcel(
            payload=pooled,
            origin=self.origin,
            destination=self.destination,
            batch_size=N,
        )
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
|
|
7
|
+
from json2vec.architecture.encoder import ContextEncoder
|
|
8
|
+
from json2vec.structs.tree import Address, Node
|
|
9
|
+
from json2vec.tensorfields.base import (
|
|
10
|
+
TENSORFIELDS,
|
|
11
|
+
DecoderBase,
|
|
12
|
+
EmbedderBase,
|
|
13
|
+
Plugin,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
if TYPE_CHECKING:
|
|
17
|
+
from json2vec.structs.config import Structure
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class NodeModule(torch.nn.Module):
    """Container for the sub-modules that belong to one address of the structure.

    * A request address gets an ``embedder`` and a ``decoder``, both taken
      from the plugin registered in ``TENSORFIELDS`` for the request's type.
    * A context address gets an ``encoder`` (a ``ContextEncoder``).

    Raises:
        ValueError: if ``address`` is neither a request nor a context.
    """

    def __init__(self, structure: Structure, address: Address):
        super().__init__()

        if address in structure.requests:
            request: Node = structure.requests[address]
            plugin: Plugin = TENSORFIELDS[request.type]
            self.embedder: EmbedderBase = plugin.Embedder(structure=structure, address=address)
            self.decoder: DecoderBase = plugin.Decoder(structure=structure, address=address)
            return

        if address in structure.contexts:
            self.encoder: ContextEncoder = ContextEncoder(structure=structure, address=address)
            return

        # Requests are checked first, then contexts; anything else is unknown.
        raise ValueError("how did we get here?")
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
from json2vec.architecture.attention import RotaryMultiheadAttention
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class CrossAttentionBlock(torch.nn.Module):
    """Residual cross-attention block: queries attend over a memory, then a feed-forward.

    Only the queries are layer-normalised before attention; the memory is
    passed to the attention module as given, as both keys and values.
    """

    def __init__(self, d_model: int, nhead: int, dropout: float, ffn_multiplier: int):
        """Build the norms, the attention module and a feed-forward of width ``d_model * ffn_multiplier``."""
        super().__init__()

        self.attention_norm = torch.nn.LayerNorm(normalized_shape=d_model)
        self.ffn_norm = torch.nn.LayerNorm(normalized_shape=d_model)
        self.attention = RotaryMultiheadAttention(d_model=d_model, nhead=nhead, dropout=dropout)

        width = d_model * ffn_multiplier
        feed_forward = [
            torch.nn.Linear(in_features=d_model, out_features=width),
            torch.nn.GELU(),
            torch.nn.Dropout(p=dropout),
            torch.nn.Linear(in_features=width, out_features=d_model),
            torch.nn.Dropout(p=dropout),
        ]
        self.ffn = torch.nn.Sequential(*feed_forward)

    def forward(self, queries: torch.Tensor, memory: torch.Tensor) -> torch.Tensor:
        """Return the queries after the attention and feed-forward residual updates."""
        normed_queries = self.attention_norm(queries)
        updated = queries + self.attention(normed_queries, memory, memory)
        return updated + self.ffn(self.ffn_norm(updated))
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class LearnedQueryCrossAttention(torch.nn.Module):
    """Pool a sequence into ``n_context`` vectors using learned queries.

    A trainable ``(n_context, d_model)`` query matrix is broadcast over the
    batch and refined by ``n_linear`` cross-attention blocks that read from
    the input sequence.  The result is layer-normalised.
    """

    def __init__(
        self,
        n_context: int,
        d_model: int,
        nhead: int,
        dropout: float,
        n_linear: int = 1,
        ffn_multiplier: int = 4,
    ):
        super().__init__()

        # Queries start as small Gaussian noise (std 1e-2).
        initial_queries = torch.normal(mean=0.0, std=1e-2, size=(n_context, d_model))
        self.queries = torch.nn.Parameter(initial_queries)

        self.blocks = torch.nn.ModuleList(
            [
                CrossAttentionBlock(
                    d_model=d_model,
                    nhead=nhead,
                    dropout=dropout,
                    ffn_multiplier=ffn_multiplier,
                )
                for _ in range(n_linear)
            ]
        )
        self.norm = torch.nn.LayerNorm(normalized_shape=d_model)

    def forward(self, memory: torch.Tensor) -> torch.Tensor:
        """Return the pooled ``(batch, n_context, d_model)`` tensor for a 3-D ``memory``."""
        batch, _, _ = memory.shape
        pooled = self.queries[None].expand(batch, -1, -1)

        for block in self.blocks:
            pooled = block(queries=pooled, memory=memory)

        return self.norm(pooled)
|
|
@@ -0,0 +1,338 @@
|
|
|
1
|
+
import math
|
|
2
|
+
import traceback
|
|
3
|
+
from collections import defaultdict
|
|
4
|
+
from functools import cache, partialmethod, wraps
|
|
5
|
+
from typing import Any, NotRequired, Self, TypedDict
|
|
6
|
+
|
|
7
|
+
import lightning.pytorch as lit
|
|
8
|
+
import torch
|
|
9
|
+
from beartype import beartype
|
|
10
|
+
from loguru import logger
|
|
11
|
+
from tensordict import TensorDict
|
|
12
|
+
|
|
13
|
+
from json2vec.architecture.encoder import ContextEncoder
|
|
14
|
+
from json2vec.architecture.node import NodeModule
|
|
15
|
+
from json2vec.data.datasets import dataloader, mock
|
|
16
|
+
from json2vec.structs.enums import Metric, Strata, TensorKey
|
|
17
|
+
from json2vec.structs.experiment import Session
|
|
18
|
+
from json2vec.structs.packages import Embedding, Parcel, Prediction
|
|
19
|
+
from json2vec.structs.tree import Address
|
|
20
|
+
from json2vec.tensorfields.base import (
|
|
21
|
+
TENSORFIELDS,
|
|
22
|
+
DecoderBase,
|
|
23
|
+
EmbedderBase,
|
|
24
|
+
Plugin,
|
|
25
|
+
RequestBase,
|
|
26
|
+
TensorFieldBase,
|
|
27
|
+
)
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class Output(TypedDict):
    """Dictionary returned by ``step``.

    ``predictions`` is always present; ``loss`` is declared ``NotRequired``
    and may be left out.
    """

    # Optional objective value for the batch.
    loss: NotRequired[torch.Tensor]
    # Items produced by the model's forward pass for this batch.
    predictions: list[Prediction]
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@beartype
def step(
    module: "JSON2Vec",
    batch: TensorDict[Address, TensorFieldBase],
    batch_idx: int,
    strata: Strata,
) -> Output:
    """Run one forward pass and, outside of prediction, compute the weighted loss.

    Args:
        module: the model; also provides the session and ``track``.
        batch: the batch, indexed by address.
        batch_idx: index of the batch (not used in the body).
        strata: which phase is running; ``Strata.predict`` skips the loss.

    Returns:
        ``Output`` with the predictions and, except for ``Strata.predict``,
        a scalar ``loss``: the sum over all non-embedding predictions of the
        plugin loss multiplied by the request's ``weight``.
    """
    predictions: list[Prediction] = module.forward(batch)

    if strata == Strata.predict:
        return Output(predictions=predictions)

    # Embedding outputs carry no target, so they never contribute to the loss.
    supervised = [item for item in predictions if not isinstance(item, Embedding)]

    weighted: list[torch.Tensor] = []
    for prediction in supervised:
        address: Address = prediction.address
        request: RequestBase = module.session.structure.requests[address]
        extension: Plugin = TENSORFIELDS[request.type]

        field_loss: torch.Tensor = extension.loss(
            module=module,
            prediction=prediction,
            batch=batch[address],
            strata=strata,
        )
        weighted.append(field_loss * torch.tensor(request.weight))

    if not weighted:
        # Ideally unreachable, but small mask rates, small batches or flat
        # input data can leave a batch without any trainable field.
        # NOTE(review): this branch returns an empty prediction list, which
        # also drops any Embedding entries -- confirm that is intended.
        logger.warning("no trainable fields in batch, returning zero loss")
        zero: torch.Tensor = torch.tensor(0.0, device=batch.device, requires_grad=True)
        return Output(loss=zero, predictions=[])

    total: torch.Tensor = module.track((Metric.loss, strata), value=torch.stack(weighted).sum())

    return Output(loss=total, predictions=predictions)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
def compile(fn):
    """Decorator that tries to compile the object returned by ``fn`` with ``thunder``.

    The wrapped callable is invoked as usual.  Its result is passed to
    ``thunder.compile``; if importing or running thunder raises, the
    traceback is printed and the uncompiled result is returned instead
    (best effort -- failures are never propagated).
    """

    @wraps(fn)
    def wrapper(*args, **kwargs):
        built = fn(*args, **kwargs)

        try:
            # Imported lazily so thunder stays an optional dependency.
            import thunder

            built = thunder.compile(built)
            logger.info("successfully compiled module with thunder")
        except Exception:
            traceback.print_exc()
            logger.info("[thunder] Returning uncompiled model instead.")

        return built

    return wrapper
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
@cache
def groupname(names: tuple[str, ...]) -> str:
    """Build a ``group/key`` metric name from two or more name parts.

    Every part is lower-cased and has ``/`` replaced by ``:``.  The first
    part becomes the group; the remaining parts are joined with ``:`` to form
    the key.  Results are memoised, so ``names`` must be hashable.
    """
    assert len(names) > 1

    cleaned = [part.replace("/", ":").lower() for part in names]
    group, keys = cleaned[0], cleaned[1:]

    return group + "/" + ":".join(keys)
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
class JSON2Vec(lit.LightningModule):
    """Lightning module that embeds request fields, encodes contexts and decodes targets.

    One ``NodeModule`` is held per address in the session's structure
    (requests and contexts).  ``forward`` embeds the leaves, walks the
    contexts from the deepest level to the root, and finally decodes every
    request that is trainable in the batch or pruned in the session.

    The step and dataloader hooks at the bottom of the class are thin
    ``partialmethod`` bindings of the module-level ``step`` function and of
    ``dataloader`` for each ``Strata`` member.
    """

    @beartype
    def __init__(self, session: Session):
        super().__init__()

        self.session: Session = session

        self.nodes: torch.nn.ModuleDict[str, NodeModule] = torch.nn.ModuleDict()

        for address in self.session.structure.requests | self.session.structure.contexts:
            self.nodes[address] = NodeModule(structure=self.session.structure, address=address)

        self.example_input_array = mock(structure=session.structure)

        logger.bind(
            component="model",
            session=self.session.name,
            structure=self.session.structure.name,
            requests=len(self.session.structure.requests),
            contexts=len(self.session.structure.contexts),
            outputs=len(self.session.output),
        ).info("initialized JSON2Vec module")

    def track(self, names: tuple[str, ...], /, value: torch.Tensor) -> torch.Tensor:
        """Log ``value`` as an epoch-level metric named ``groupname(names)`` and return it."""
        self.log(
            name=groupname(names),
            value=value,
            on_step=False,
            on_epoch=True,
            sync_dist=True,
            batch_size=self.session.structure.batch_size,
        )

        return value

    @property
    def state(self) -> dict[Address, Any]:
        """Map each address to its embedder's ``state``, for embedders that expose one."""
        return {
            address: node.embedder.state
            for address, node in self.nodes.items()
            if hasattr(node, "embedder") and hasattr(node.embedder, "state")
        }

    @beartype
    def forward(self, inputs: TensorDict[Address, TensorFieldBase]) -> list[Prediction]:
        """Embed, encode and decode one batch.

        Returns a list that contains, in order: an ``Embedding`` for every
        request or context address listed in ``session.output``, followed by
        one decoder output per request that has any trainable entry in
        ``inputs`` or is listed in ``session.pruned``.
        """
        processed: dict[Address, list[Parcel]] = defaultdict(list)
        outgoing: dict[Address, Parcel] = {}
        predictions: list[Prediction] = []

        # 1) Embed every request field that is not pruned and route the parcel
        #    to its destination.
        for address in self.session.structure.requests.keys():
            if address in self.session.pruned:
                continue

            tensorfield: TensorFieldBase = inputs[address]
            embedder: EmbedderBase = self.nodes[address].embedder
            embedding: Parcel = embedder(tensorfield)
            processed[embedding.destination].append(embedding)
            outgoing[embedding.origin] = embedding

            if address in self.session.output:
                predictions.append(Embedding.from_parcel(embedding))

        # 2) DAG traversal from leaves to root.
        for depth in reversed(self.session.structure.depthwise):
            # Addresses within the same depth are order-independent.
            for address in depth:

                if len(processed[address]) == 0:
                    continue

                encoder: ContextEncoder = self.nodes[address].encoder
                encoding: Parcel = encoder(processed[address])
                processed[encoding.destination].append(encoding)
                outgoing[encoding.origin] = encoding

                if address in self.session.output:
                    predictions.append(Embedding.from_parcel(encoding))

        # 3) Decode each request that needs a prediction.
        for address, request in self.session.structure.requests.items():
            # The cheap membership test goes first so pruned fields never pay
            # for the tensor reduction.
            if (address in self.session.pruned) or torch.any(inputs[address].trainable):

                heritage: list[Address] = request.heritage
                parcels: list[Parcel] = [
                    outgoing[ancestor] for ancestor in heritage
                    if ancestor not in self.session.pruned and ancestor in outgoing
                ]

                decoder: DecoderBase = self.nodes[address].decoder
                predictions.append(decoder(parcels))

        return predictions

    @beartype
    def configure_optimizers(self):
        """Create the AdamW optimizer and, when the step count is known, its LR schedule.

        Parameters are split into two groups: biases, parameters with at most
        one dimension and anything with "norm" in its name get no weight
        decay; all others use ``session.weight_decay``.

        The schedule is a linear warm-up over ``warmup_ratio`` of the
        estimated steps followed by a cosine decay down to ``min_lr_ratio``
        of the base learning rate.  If the trainer cannot provide a positive,
        finite step estimate, the bare optimizer is returned.

        Raises:
            ValueError: if ``session.learning_rate`` is ``None``.
        """
        if self.session.learning_rate is None:
            raise ValueError("learning_rate must be defined for optimizer configuration")

        class GroupedParameter(TypedDict):
            params: list[torch.nn.Parameter]
            weight_decay: float

        params: dict[str, GroupedParameter] = dict(
            with_decay=GroupedParameter(params=[], weight_decay=self.session.weight_decay),
            no_decay=GroupedParameter(params=[], weight_decay=0.0),
        )

        for name, parameter in self.named_parameters():
            if not parameter.requires_grad:
                continue

            if name.endswith(".bias") or parameter.ndim <= 1 or "norm" in name.lower():
                params["no_decay"]["params"].append(parameter)
            else:
                params["with_decay"]["params"].append(parameter)

        optimizer = torch.optim.AdamW(params=list(params.values()), lr=self.session.learning_rate, betas=(0.9, 0.95))
        trainable_parameters: int = sum(parameter.numel() for parameter in self.parameters() if parameter.requires_grad)
        logger.bind(
            component="optimizer",
            session=self.session.name,
            learning_rate=self.session.learning_rate,
            weight_decay=self.session.weight_decay,
            trainable_parameters=trainable_parameters,
            warmup_ratio=self.session.warmup_ratio,
            min_lr_ratio=self.session.min_lr_ratio,
        ).info("configured AdamW optimizer")

        # The trainer may report an infinite estimate for unbounded training;
        # int(inf) would raise OverflowError, so treat it as "unknown".
        estimate = getattr(self.trainer, "estimated_stepping_batches", 0) or 0
        total = int(estimate) if math.isfinite(estimate) else 0

        if total <= 0:
            return optimizer

        warmup = max(1, int(total * self.session.warmup_ratio))
        min_lr_ratio = self.session.min_lr_ratio

        def schedule(step: int) -> float:
            """Learning-rate multiplier for optimizer step ``step``."""
            if step < warmup:
                return float(step + 1) / float(warmup)

            ratio = float(step - warmup) / float(max(1, total - warmup))

            # Clamp so steps past the estimate stay at the minimum rate.
            progress = min(1.0, ratio)

            cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
            return min_lr_ratio + (1.0 - min_lr_ratio) * cosine

        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, schedule)

        return dict(
            optimizer=optimizer,
            lr_scheduler=dict(
                scheduler=scheduler,
                interval="step",
                frequency=1,
            ),
        )

    def on_save_checkpoint(self, checkpoint):
        """Store the serialized session under ``checkpoint["session"]``."""
        logger.bind(component="checkpoint", session=self.session.name).info("serializing session")
        checkpoint["session"] = self.session.model_dump()

    def on_load_checkpoint(self, checkpoint):
        """Adopt the checkpoint's session when the module has none.

        Raises:
            ValueError: if no session is available from either source.
        """
        logger.bind(component="checkpoint").info("loading session from checkpoint payload")
        if "session" in checkpoint and getattr(self, "session", None) is None:
            self.session = Session.model_validate(checkpoint["session"])

        if getattr(self, "session", None) is None:
            raise ValueError("missing session in checkpoint and constructor")

    @classmethod
    def get_or_create(
        cls,
        session: Session | None = None,
        checkpoint: str | None = None,
    ) -> Self:
        """Create a fresh model, or rebuild one from a checkpoint file.

        Args:
            session: session to build the model from.  Required without a
                checkpoint; with a checkpoint it overrides the stored session.
            checkpoint: path of a checkpoint containing ``"state_dict"`` and,
                unless ``session`` is given, ``"session"``.

        Raises:
            ValueError: if neither ``session`` nor ``checkpoint`` is given.
        """
        if checkpoint is None:
            logger.bind(component="model_factory").info("creating new JSON2Vec model")
            if session is None:
                raise ValueError("session is required when checkpoint is not provided")

            return cls(session=session)

        logger.bind(component="model_factory", checkpoint=checkpoint).info("loading JSON2Vec model from checkpoint")
        # SECURITY: weights_only=False unpickles arbitrary objects; only load
        # checkpoints from trusted sources.
        state = torch.load(checkpoint, weights_only=False, map_location="cpu")

        model: "JSON2Vec" = cls(
            session=session or Session.model_validate(state["session"]),
        )

        # strict=False tolerates mismatches, so report them instead of
        # silently returning a partially restored model.
        incompatible = model.load_state_dict(state_dict=state["state_dict"], strict=False)
        if incompatible.missing_keys or incompatible.unexpected_keys:
            logger.bind(
                component="model_factory",
                checkpoint=checkpoint,
                missing_keys=list(incompatible.missing_keys),
                unexpected_keys=list(incompatible.unexpected_keys),
            ).warning("checkpoint state_dict does not fully match the model")

        logger.bind(component="model_factory", checkpoint=checkpoint).info("restored model state from checkpoint")

        return model

    def write(self, predictions: list[Prediction]) -> tuple[dict[Address, dict[str, Any]], dict[Address, dict[str, Any]]]:
        """Serialize predictions into ``(supervised, embeddings)`` dictionaries keyed by address.

        ``Embedding`` items go through ``Embedding.write``.  Every other
        prediction is written by the plugin registered for its request type
        and is included only when that plugin returns something other than
        ``None``.
        """
        supervised: dict[Address, dict[str, Any]] = {}
        embeddings: dict[Address, dict[str, Any]] = {}

        for prediction in predictions:

            if isinstance(prediction, Embedding):
                embeddings[prediction.address] = Embedding.write(prediction)
                continue

            request: RequestBase = self.session.structure.requests[prediction.address]
            extension: Plugin = TENSORFIELDS[request.type]

            scribed: dict[TensorKey, Any] | None = extension.write(module=self, prediction=prediction)

            if scribed is not None:
                supervised[prediction.address] = Prediction.serialize(scribed)

        return supervised, embeddings

    # Lightning step hooks: the shared ``step`` function bound to each strata.
    training_step = partialmethod(step, strata=Strata.train)
    validation_step = partialmethod(step, strata=Strata.validate)
    test_step = partialmethod(step, strata=Strata.test)
    predict_step = partialmethod(step, strata=Strata.predict)

    # Lightning dataloader hooks: ``dataloader`` bound to each strata.
    train_dataloader = partialmethod(dataloader, strata=Strata.train)
    val_dataloader = partialmethod(dataloader, strata=Strata.validate)
    test_dataloader = partialmethod(dataloader, strata=Strata.test)
    predict_dataloader = partialmethod(dataloader, strata=Strata.predict)
|