vectormesh 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vectormesh/__init__.py +28 -0
- vectormesh/components/__init__.py +33 -0
- vectormesh/components/aggregation.py +80 -0
- vectormesh/components/connectors.py +40 -0
- vectormesh/components/gating.py +132 -0
- vectormesh/components/metrics.py +121 -0
- vectormesh/components/neural.py +62 -0
- vectormesh/components/padding.py +34 -0
- vectormesh/components/pipelines.py +41 -0
- vectormesh/data/__init__.py +16 -0
- vectormesh/data/cache.py +208 -0
- vectormesh/data/dataset.py +237 -0
- vectormesh/data/summarize.py +285 -0
- vectormesh/data/vectorizers.py +475 -0
- vectormesh/types.py +51 -0
- vectormesh-0.1.0.dist-info/METADATA +479 -0
- vectormesh-0.1.0.dist-info/RECORD +19 -0
- vectormesh-0.1.0.dist-info/WHEEL +4 -0
- vectormesh-0.1.0.dist-info/entry_points.txt +2 -0
vectormesh/__init__.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
"""VectorMesh package entry point: re-exports the data-layer public API."""

import sys
from importlib.metadata import PackageNotFoundError, version

from loguru import logger

from vectormesh.data import (
    BaseVectorizer,
    LabelEncoder,
    RegexVectorizer,
    VectorCache,
    Vectorizer,
    build,
)

try:
    __version__ = version("vectormesh")
except PackageNotFoundError:
    # Imported from a source tree that was never installed: there is no
    # distribution metadata to read, so report a placeholder instead of
    # making the whole package fail to import.
    __version__ = "0.0.0+unknown"

__all__ = [
    "VectorCache",
    "LabelEncoder",
    "build",
    "BaseVectorizer",
    "Vectorizer",
    "RegexVectorizer",
]

# NOTE(review): this configures loguru for the whole process at import time.
# `logger.remove()` drops every sink registered before this package was
# imported, and the file sink path is relative to the current working
# directory. Consider moving this into an opt-in setup function.
logger.remove()
logger.add(sys.stderr, level="INFO")
logger.add("logs/dataset.log", rotation="10 MB", level="DEBUG")
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
"""VectorMesh components module."""
|
|
2
|
+
|
|
3
|
+
from .aggregation import (
|
|
4
|
+
AttentionAggregator,
|
|
5
|
+
BaseAggregator,
|
|
6
|
+
MeanAggregator,
|
|
7
|
+
RNNAggregator,
|
|
8
|
+
)
|
|
9
|
+
from .connectors import Concatenate2D, Stack2D
|
|
10
|
+
from .gating import Gate, Highway, MoE, Skip
|
|
11
|
+
from .neural import Attention, NeuralNet, Projection
|
|
12
|
+
from .padding import DynamicPadding, FixedPadding
|
|
13
|
+
from .pipelines import Parallel, Serial
|
|
14
|
+
|
|
15
|
+
# Public API of ``vectormesh.components``, grouped by the submodule each name
# is imported from above.
__all__ = [
    # from .aggregation
    "AttentionAggregator",
    "BaseAggregator",
    "MeanAggregator",
    "RNNAggregator",
    # from .connectors
    "Concatenate2D",
    "Stack2D",
    # from .gating
    "Gate",
    "Highway",
    "Skip",
    "MoE",
    # from .neural
    "NeuralNet",
    "Projection",
    "Attention",
    # from .padding
    "DynamicPadding",
    "FixedPadding",
    # from .pipelines
    "Parallel",
    "Serial",
]
|
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
from abc import abstractmethod
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
import torch.nn as nn
|
|
5
|
+
from beartype import beartype
|
|
6
|
+
from jaxtyping import Float, jaxtyped
|
|
7
|
+
from torch import Tensor
|
|
8
|
+
|
|
9
|
+
from vectormesh.types import BaseComponent
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class BaseAggregator(BaseComponent):
    """Abstract reducer that collapses the chunk axis of a 3D tensor.

    Subclasses turn ``(batch, chunks, dim)`` into ``(batch, dim)``. The hook is
    named ``forward`` for compatibility with ``nn.Module``.
    """

    def __init__(self):
        super().__init__()

    @abstractmethod
    @jaxtyped(typechecker=beartype)
    def forward(
        self, embeddings: Float[Tensor, "batch _ dim"]
    ) -> Float[Tensor, "batch dim"]:
        """Reduce ``(batch, chunks, dim)`` to ``(batch, dim)``; implemented by subclasses."""
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class MeanAggregator(BaseAggregator):
    """Collapse the chunk axis by averaging; this aggregator has no parameters.

    NOTE(review): every chunk is averaged, including zero chunks added by
    padding (e.g. ``DynamicPadding``) — confirm callers expect that.
    """

    @jaxtyped(typechecker=beartype)
    def forward(
        self, embeddings: Float[Tensor, "batch _ dim"]
    ) -> Float[Tensor, "batch dim"]:
        """Average ``embeddings`` over axis 1, the chunk axis."""
        return torch.mean(embeddings, dim=1)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class AttentionAggregator(BaseAggregator):
    """Collapse the chunk axis with a learned, softmax-normalised weighting.

    A single linear layer scores every chunk. The scores are softmaxed across
    chunks and used for a weighted sum of the chunk embeddings. No padding mask
    is applied, so the chunk count is expected to be fixed per batch.
    """

    def __init__(self, hidden_size: int):
        """Create the scoring layer that maps ``hidden_size`` features to one score."""
        super().__init__()
        self.attention = nn.Linear(hidden_size, 1)

    @jaxtyped(typechecker=beartype)
    def forward(
        self, embeddings: Float[Tensor, "batch _ dim"]
    ) -> Float[Tensor, "batch dim"]:
        scores = self.attention(embeddings)  # (batch, chunks, 1)
        weights = scores.softmax(dim=1)  # normalise across chunks
        return torch.sum(weights * embeddings, dim=1)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class RNNAggregator(BaseAggregator):
    """Collapse the chunk axis by running a GRU over the chunks.

    Returns the GRU output at the last chunk position. For the single-layer,
    unidirectional GRU built here, that equals the final hidden state.

    NOTE(review): no sequence lengths are passed to the GRU. If the chunk axis
    was zero-padded, the last position may be a padding step — confirm.
    """

    def __init__(self, hidden_size: int):
        """Build a batch-first GRU whose input and hidden sizes are both ``hidden_size``."""
        super().__init__()
        self.rnn = torch.nn.GRU(
            input_size=hidden_size, hidden_size=hidden_size, batch_first=True
        )

    @jaxtyped(typechecker=beartype)
    def forward(
        self, embeddings: Float[Tensor, "batch _ dim"]
    ) -> Float[Tensor, "batch dim"]:
        # batch_first=True, so output is (batch, chunks, hidden_size).
        output, _ = self.rnn(embeddings)
        return output[:, -1, :]
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from beartype import beartype
|
|
3
|
+
from jaxtyping import Float, jaxtyped
|
|
4
|
+
from torch import Tensor
|
|
5
|
+
|
|
6
|
+
from vectormesh.types import BaseComponent
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class Concatenate2D(BaseComponent):
    """Concatenate the 2D outputs of parallel branches along the last axis.

    input:  ((batch, dim1), (batch, dim2), ...)
    output: (batch, ndim) where ndim = dim1 + dim2 + ...
    """

    @jaxtyped(typechecker=beartype)
    def forward(
        self, tensors: tuple[Float[Tensor, "batch _"], ...]
    ) -> Float[Tensor, "batch ndim"]:
        # The feature axis is deliberately anonymous ("_"). jaxtyping binds a
        # named dimension to one size per call, so "batch dim" on every tuple
        # element would require all branches to share one width, contradicting
        # the dim1 + dim2 + ... contract. Only "batch" has to agree.
        return torch.cat(tensors, dim=-1)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class Stack2D(BaseComponent):
    """Stack same-shaped 2D outputs of parallel branches along a new axis 1.

    input:  ((batch, dim1), (batch, dim1), ...)
    output: (batch, nstack, dim1) where nstack is the number of tensors
    """

    @jaxtyped(typechecker=beartype)
    def forward(
        self, tensors: tuple[Float[Tensor, "batch dim1"], ...]
    ) -> Float[Tensor, "batch nstack dim1"]:
        # torch.stack requires identical shapes. The shared "dim1" name in the
        # annotation states the same requirement.
        return torch.stack(tensors, dim=1)
|
|
@@ -0,0 +1,132 @@
|
|
|
1
|
+
"""Residual and gating components for skip connections and gated transformations."""
|
|
2
|
+
|
|
3
|
+
from typing import Optional
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
import torch.nn as nn
|
|
7
|
+
import torch.nn.functional as F
|
|
8
|
+
from beartype import beartype
|
|
9
|
+
from jaxtyping import Float, jaxtyped
|
|
10
|
+
from torch import Tensor
|
|
11
|
+
|
|
12
|
+
from vectormesh.types import BaseComponent
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class Skip(BaseComponent):
    """Pre-norm residual block: ``output = transform(layernorm(x)) + residual(x)``.

    ``residual(x)`` is ``projection(x)`` when a projection is given, otherwise
    ``x`` unchanged.

    Args:
        transform: module (pipeline) applied to the layer-normalised input.
        in_size: size of the last input dimension, used by the LayerNorm.
        projection: optional module, e.g. ``Linear(in_size, out_size)``, that
            maps the residual to the transform's output size when
            ``transform`` changes the dimensionality.
    """

    transform: nn.Module
    projection: Optional[nn.Module]
    layernorm: nn.LayerNorm

    def __init__(
        self,
        transform: nn.Module,
        in_size: int,
        projection: Optional[nn.Module] = None,
    ):
        super().__init__()
        self.transform = transform
        self.projection = projection
        self.layernorm = nn.LayerNorm(in_size)

    @jaxtyped(typechecker=beartype)
    def forward(self, x: Float[Tensor, "..."]) -> Float[Tensor, "..."]:
        # Pre-norm: only the transform branch sees the normalised input. The
        # residual carries the raw input, so the skip path stays an identity
        # (or a plain projection), which is where pre-norm's stability comes
        # from.
        # Use `is not None` rather than truthiness: container modules such as
        # nn.Sequential define __len__, so an empty one is falsy.
        residual = x if self.projection is None else self.projection(x)
        transformed = self.transform(self.layernorm(x))
        return transformed + residual
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class Gate(BaseComponent):
    """Element-wise sigmoid gate: ``output = sigmoid(W·x + b) * x``.

    ``W`` is a square ``hidden_size x hidden_size`` linear layer, so the output
    has the same shape as the input.
    """

    def __init__(self, hidden_size: int):
        super().__init__()
        self.project = nn.Linear(hidden_size, hidden_size)

    @jaxtyped(typechecker=beartype)
    def forward(self, x: Float[Tensor, "batch dim"]) -> Float[Tensor, "batch dim"]:
        gate_values = self.project(x).sigmoid()
        return gate_values * x
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class Highway(BaseComponent):
    """Pre-norm highway layer: ``G * T(norm(x)) + (1 - G) * x``.

    The gate is ``G = sigmoid(W·norm(x) + b)``. ``transform`` must preserve
    the ``hidden_size`` feature width so it can be mixed element-wise with
    ``x``.

    Args:
        transform: module applied to the layer-normalised input.
        hidden_size: size of the last input dimension.
    """

    def __init__(self, transform: nn.Module, hidden_size: int):
        super().__init__()
        self.transform = transform
        self.project = nn.Linear(hidden_size, hidden_size)
        self.norm = nn.LayerNorm(hidden_size)

    @jaxtyped(typechecker=beartype)
    def forward(self, x: Float[Tensor, "batch dim"]) -> Float[Tensor, "batch dim"]:
        # Pre-norm: the gate and the transform read the normalised input,
        # while the carry path keeps the raw x. In the highway formulation
        # the carried signal is the untouched input.
        normed = self.norm(x)
        gate = F.sigmoid(self.project(normed))
        transformed = self.transform(normed)
        return gate * transformed + (1 - gate) * x
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
class MoE(BaseComponent):
    """Sparsely-gated mixture of experts with (optionally noisy) top-k routing.

    See https://arxiv.org/abs/1701.06538 for the paper.

    A linear router scores every expert for each row of ``x``. Only the
    ``top_k`` highest scores are kept and softmax-normalised. Each row's output
    is the probability-weighted sum of its selected experts' outputs. ``x`` is
    expected to be 2D ``(batch, hidden_size)``, because routing uses axis 1
    and the output buffer is ``(batch, out_size)``.

    Args:
        experts: iterable of modules mapping ``(n, hidden_size)`` to
            ``(n, out_size)``.
        hidden_size: size of the last input dimension.
        out_size: output size of every expert, and of this module.
        top_k: number of experts evaluated per row; must be in
            ``[1, number of experts]``.
        noisy_gating: add learned Gaussian noise to the router logits while
            in training mode.

    Raises:
        ValueError: if ``top_k`` is outside ``[1, number of experts]``.
    """

    def __init__(
        self,
        experts,
        hidden_size: int,
        out_size: int,
        top_k: int,
        noisy_gating: bool = True,
    ):
        super().__init__()
        self.experts = nn.ModuleList(experts)
        # Count from the ModuleList so any iterable of experts (e.g. a
        # generator) works, not only sized sequences.
        self.num_experts = len(self.experts)
        if not 1 <= top_k <= self.num_experts:
            raise ValueError(
                f"top_k must be between 1 and the number of experts "
                f"({self.num_experts}), got {top_k}"
            )
        self.router = nn.Linear(hidden_size, self.num_experts)

        self.w_noise = nn.Linear(hidden_size, self.num_experts)
        self.noisy_gating = noisy_gating
        self.top_k = top_k
        self.out_size = out_size

    def forward(self, x: Tensor) -> Tensor:
        clean_logits = self.router(x)

        # self.training is automatically managed by .eval() and .train()
        if self.noisy_gating and self.training:
            # softplus keeps the learned stddev positive; +1e-2 keeps it non-zero.
            noise_stddev = F.softplus(self.w_noise(x)) + 1e-2
            noise = torch.randn_like(clean_logits) * noise_stddev
            noisy_logits = clean_logits + noise
        else:
            noisy_logits = clean_logits

        # Non-top-k logits become -inf so softmax gives them exactly zero.
        top_logits, top_indices = noisy_logits.topk(self.top_k, dim=1)
        full_logits = torch.full_like(noisy_logits, float("-inf"))
        full_logits.scatter_(1, top_indices, top_logits)
        router_probs = F.softmax(full_logits, dim=1)

        # Allocate with x's dtype and device so half/bfloat16 models get an
        # output in their own dtype, not a float32 buffer.
        final_output = x.new_zeros(x.size(0), self.out_size)

        for i, expert in enumerate(self.experts):
            # Rows that selected expert i among their top-k choices.
            mask = (top_indices == i).any(dim=1)
            if not mask.any():
                continue
            expert_output = expert(x[mask])
            expert_weights = router_probs[mask, i].unsqueeze(-1)
            final_output[mask] += expert_output * expert_weights

        # TODO: the paper implements importance loss
        # to encourage balanced expert usage
        #
        # importance = router_probs.sum(dim=0)
        # imp_loss = (importance.std() / (importance.mean() + 1e-10)).pow(2)
        # return final_output, imp_loss

        return final_output
|
|
@@ -0,0 +1,121 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from abc import ABC, abstractmethod
|
|
4
|
+
from typing import Iterator, Union
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
import torch
|
|
8
|
+
|
|
9
|
+
# Inputs accepted by Metric.__call__.
NumericType = Union[torch.Tensor, np.ndarray]


class Metric(ABC):
    """Abstract metric: converts inputs to tensors, then defers to ``_compute``.

    Subclasses implement ``_compute`` (tensors in, single-element tensor out)
    and ``__repr__``. Calling a metric instance returns a plain Python float.
    """

    @abstractmethod
    def _compute(self, y: torch.Tensor, yhat: torch.Tensor) -> torch.Tensor:
        """Return the metric value for targets ``y`` and predictions ``yhat``."""

    def __call__(self, y: NumericType, yhat: NumericType) -> float:
        """Evaluate the metric on tensors and/or NumPy arrays.

        NumPy inputs are wrapped with ``torch.from_numpy``. ``yhat`` is moved
        to ``y``'s device before ``_compute`` runs.
        """
        target = torch.from_numpy(y) if isinstance(y, np.ndarray) else y
        prediction = torch.from_numpy(yhat) if isinstance(yhat, np.ndarray) else yhat
        prediction = prediction.to(target.device)

        value = self._compute(target, prediction)
        # Detach and move to CPU so float() works regardless of grad/device.
        return float(value.cpu().detach())

    @abstractmethod
    def __repr__(self) -> str:
        """Short, human-readable name of the metric."""
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
class MAE(Metric):
    """Mean absolute error: ``mean(|y - yhat|)`` over all elements."""

    def _compute(self, y: torch.Tensor, yhat: torch.Tensor) -> torch.Tensor:
        residuals = y - yhat
        return residuals.abs().mean()

    def __repr__(self) -> str:
        return "MAE"
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
class MASE(Metric):
    """Mean absolute scaled error.

    The MAE of the predictions is divided by the MAE of a naive forecaster on
    the training data. The naive forecast repeats the last ``horizon`` steps
    of each input window. A value below 1 means a lower MAE than that
    baseline.

    Args:
        train: object exposing ``stream()`` (an iterable of ``(x, y)``
            batches) and ``__len__`` (how many batches to consume).
        horizon: number of trailing steps used as the naive forecast.
    """

    def __init__(self, train: Iterator, horizon: int) -> None:
        self.horizon = horizon
        with torch.no_grad():
            self.scale = self._calculate_scale(train)

    def _calculate_scale(self, train: Iterator) -> torch.Tensor:
        """Average the naive forecaster's per-batch MAE over ``len(train)`` batches."""
        # Create the iterator exactly once. Calling iter() on every step only
        # works for generators; for a re-iterable stream (e.g. a DataLoader)
        # it restarts and yields the first batch len(train) times.
        batches = iter(train.stream())  # type: ignore
        errors = []
        for _ in range(len(train)):  # type: ignore
            x, y = next(batches)
            yhat = self._naive_predict(x)
            errors.append(torch.mean(torch.abs(y - yhat)))
        return torch.mean(torch.stack(errors))

    def _naive_predict(self, x: torch.Tensor) -> torch.Tensor:
        # Last `horizon` entries along the second-to-last axis. squeeze(-1)
        # drops a trailing size-1 feature axis (presumably univariate
        # (batch, time, 1) windows — confirm against the data source).
        return x[..., -self.horizon :, :].squeeze(-1)

    def _compute(self, y: torch.Tensor, yhat: torch.Tensor) -> torch.Tensor:
        mae = torch.mean(torch.abs(y - yhat))
        return mae / self.scale

    def __repr__(self) -> str:
        return f"MASE(scale={self.scale:.3f})"
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
class Accuracy(Metric):
    """Fraction of rows whose argmax over axis 1 of ``yhat`` equals ``y``."""

    def _compute(self, y: torch.Tensor, yhat: torch.Tensor) -> torch.Tensor:
        hits = torch.eq(yhat.argmax(dim=1), y)
        return hits.float().mean()

    def __repr__(self) -> str:
        return "Accuracy"
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
class F1Score(Metric):
    """F1 score for (multi-label) binary predictions given as logits.

    ``yhat`` is passed through a sigmoid and thresholded. ``y`` holds 0/1
    targets (float, integer or bool) broadcastable to ``yhat``.

    Args:
        average: ``"micro"`` pools TP/FP/FN over all entries; ``"macro"``
            computes F1 per column (summing over axis 0) and averages them.
        threshold: probability above which a label is predicted positive.
        epsilon: added to the denominator to avoid division by zero.

    Raises:
        ValueError: when evaluated with an ``average`` other than
            ``"micro"`` or ``"macro"``.
    """

    def __init__(
        self, average: str = "micro", threshold: float = 0.5, epsilon: float = 1e-7
    ):
        self.average = average
        self.threshold = threshold
        self.epsilon = epsilon

    def _compute(self, y: torch.Tensor, yhat: torch.Tensor) -> torch.Tensor:
        y_prob = torch.sigmoid(yhat)
        y_pred = (y_prob > self.threshold).float()

        # torch forbids `1 - y` on bool tensors, so cast non-float targets to
        # the prediction dtype. Integer 0/1 targets give identical results.
        if not y.is_floating_point():
            y = y.float()

        tp = y_pred * y
        fp = y_pred * (1 - y)
        fn = (1 - y_pred) * y

        if self.average == "micro":
            tp_sum = tp.sum()
            fp_sum = fp.sum()
            fn_sum = fn.sum()
            f1 = (2 * tp_sum) / (2 * tp_sum + fp_sum + fn_sum + self.epsilon)

        elif self.average == "macro":
            tp_sum = tp.sum(dim=0)
            fp_sum = fp.sum(dim=0)
            fn_sum = fn.sum(dim=0)
            f1_per_class = (2 * tp_sum) / (2 * tp_sum + fp_sum + fn_sum + self.epsilon)
            f1 = f1_per_class.mean()

        else:
            raise ValueError("Average must be 'micro' or 'macro'")

        return f1

    def __repr__(self) -> str:
        return f"F1({self.average})"
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
import torch.nn as nn
|
|
2
|
+
from beartype import beartype
|
|
3
|
+
from jaxtyping import Float, jaxtyped
|
|
4
|
+
from torch import Tensor
|
|
5
|
+
|
|
6
|
+
from vectormesh.types import BaseComponent
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class NeuralNet(BaseComponent):
    """Feed-forward block: ``Linear(h, h) -> GELU -> Linear(h, out)``.

    Args:
        hidden_size: width of the input and of the hidden layer.
        out_size: width of the output.
    """

    def __init__(self, hidden_size: int, out_size: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.out_size = out_size
        self.fc1 = nn.Linear(hidden_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, out_size)
        self.activation = nn.GELU()

    @jaxtyped(typechecker=beartype)
    def forward(
        self, x: Float[Tensor, "batch {self.hidden_size}"]
    ) -> Float[Tensor, "batch {self.out_size}"]:
        hidden = self.activation(self.fc1(x))
        return self.fc2(hidden)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class Projection(BaseComponent):
    """Single affine map from ``hidden_size`` to ``out_size`` features.

    Args:
        hidden_size: width of the input.
        out_size: width of the output.
    """

    def __init__(self, hidden_size: int, out_size: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.out_size = out_size
        self.proj = nn.Linear(hidden_size, out_size)

    @jaxtyped(typechecker=beartype)
    def forward(
        self, x: Float[Tensor, "batch {self.hidden_size}"]
    ) -> Float[Tensor, "batch {self.out_size}"]:
        projected = self.proj(x)
        return projected
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class Attention(nn.Module):
    """Multi-head self-attention backed by ``nn.MultiheadAttention``.

    Query, key and value are all the input sequence. Attention weights are not
    returned. ``nn.MultiheadAttention`` requires ``hidden_size`` to be
    divisible by ``num_heads``; ``dropout`` is its attention-weight dropout.
    """

    def __init__(self, hidden_size: int, num_heads: int = 8, dropout: float = 0.1):
        super().__init__()
        self.hidden_size = hidden_size
        self.attn = nn.MultiheadAttention(
            hidden_size,
            num_heads,
            dropout=dropout,
            batch_first=True,
        )

    @jaxtyped(typechecker=beartype)
    def forward(
        self, x: Float[Tensor, "batch seq {self.hidden_size}"]
    ) -> Float[Tensor, "batch seq {self.hidden_size}"]:
        context, _unused_weights = self.attn(
            query=x, key=x, value=x, need_weights=False
        )
        return context
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
import torch.nn.functional as F
|
|
4
|
+
from beartype import beartype
|
|
5
|
+
from jaxtyping import Float, jaxtyped
|
|
6
|
+
from torch import Tensor
|
|
7
|
+
from torch.nn.utils.rnn import pad_sequence
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class DynamicPadding:
    """Collate variable-length sequences by zero-padding to the batch maximum."""

    @jaxtyped(typechecker=beartype)
    def __call__(self, embeddings: list[Any]) -> Float[Tensor, "batch _ dim"]:
        """Pad sequences to the maximum length in the batch.

        Elements are presumably ``(chunks, dim)`` tensors (confirm with the
        caller). The result is ``(batch, max_chunks_in_batch, dim)``.
        """
        batch = pad_sequence(embeddings, batch_first=True)
        return batch
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class FixedPadding:
    """Collate sequences into exactly ``max_chunks`` chunks per row.

    Rows are first zero-padded to the batch maximum. The result is then either
    zero-padded further or truncated along the chunk axis (axis 1).
    """

    def __init__(self, max_chunks: int):
        self.max_chunks = max_chunks

    @jaxtyped(typechecker=beartype)
    def __call__(
        self, embeddings: list[Any]
    ) -> Float[Tensor, "batch {self.max_chunks} dim"]:
        padded = pad_sequence(embeddings, batch_first=True)

        missing = self.max_chunks - padded.shape[1]
        if missing > 0:
            # F.pad lists pads from the last axis backwards:
            # (dim_before, dim_after, chunks_before, chunks_after).
            return F.pad(padded, (0, 0, 0, missing))
        # Batches that are already long enough are cut to max_chunks.
        return padded[:, : self.max_chunks, :]
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
import torch.nn as nn
|
|
2
|
+
from beartype import beartype
|
|
3
|
+
from jaxtyping import jaxtyped
|
|
4
|
+
|
|
5
|
+
from vectormesh.types import TensorInput
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class Serial(nn.Module):
    """Chain components so each one consumes the previous one's output.

    No shape logic lives here. Tensor checking is left to the jaxtyping
    decorators on the components that have them.
    """

    components: nn.ModuleList

    def __init__(self, components: list[nn.Module]):
        super().__init__()
        self.components = nn.ModuleList(components)

    @jaxtyped(typechecker=beartype)
    def forward(self, tensors: TensorInput) -> TensorInput:
        """Feed ``tensors`` through every component in order; return the last output."""
        out = tensors
        for stage in self.components:
            out = stage(out)
        return out
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class Parallel(nn.Module):
    """Run branches side by side and return their outputs as a tuple.

    - A single tensor is broadcast: every branch receives the same input.
    - A tuple/list input is routed element-wise: branch ``i`` receives
      ``tensors[i]``. The element count must equal the number of branches.

    Raises:
        ValueError: if a tuple/list input's length differs from the number of
            branches.
    """

    branches: nn.ModuleList

    def __init__(self, branches: list[nn.Module]):
        super().__init__()
        self.branches = nn.ModuleList(branches)

    @jaxtyped(typechecker=beartype)
    def forward(self, tensors: TensorInput) -> TensorInput:
        if isinstance(tensors, (tuple, list)):
            # zip() would silently drop unmatched branches or inputs.
            if len(tensors) != len(self.branches):
                raise ValueError(
                    f"Parallel has {len(self.branches)} branches but received "
                    f"{len(tensors)} inputs"
                )
            return tuple(branch(t) for branch, t in zip(self.branches, tensors))
        # A bare tensor must not be zipped: that iterates its first (batch)
        # axis and hands each branch a single row. Broadcast it instead.
        return tuple(branch(tensors) for branch in self.branches)
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
"""VectorMesh data components."""
|
|
2
|
+
|
|
3
|
+
from .cache import VectorCache
|
|
4
|
+
from .dataset import Collate, LabelEncoder, OneHot, build
|
|
5
|
+
from .vectorizers import BaseVectorizer, RegexVectorizer, Vectorizer
|
|
6
|
+
|
|
7
|
+
# Public API of ``vectormesh.data``, grouped by the submodule each name is
# imported from above.
__all__ = [
    # from .cache
    "VectorCache",
    # from .dataset
    "LabelEncoder",
    "OneHot",
    "Collate",
    "build",
    # from .vectorizers
    "BaseVectorizer",
    "Vectorizer",
    "RegexVectorizer",
]
|