vectormesh 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
vectormesh/__init__.py ADDED
@@ -0,0 +1,28 @@
1
+ import sys
2
+ from importlib.metadata import version
3
+
4
+ from loguru import logger
5
+
6
+ from vectormesh.data import (
7
+ BaseVectorizer,
8
+ LabelEncoder,
9
+ RegexVectorizer,
10
+ VectorCache,
11
+ Vectorizer,
12
+ build,
13
+ )
14
+
15
from importlib.metadata import PackageNotFoundError

try:
    __version__ = version("vectormesh")
except PackageNotFoundError:
    # No installed distribution metadata (e.g. importing straight from a
    # source checkout that was never pip-installed): keep the package
    # importable with a placeholder version instead of failing the import.
    __version__ = "0.0.0+unknown"

# Public API of the top-level package (re-exported from vectormesh.data).
__all__ = [
    "VectorCache",
    "LabelEncoder",
    "build",
    "BaseVectorizer",
    "Vectorizer",
    "RegexVectorizer",
]

# Logging: INFO and above to stderr, DEBUG and above to a file rotated at 10 MB.
# NOTE(review): this runs at import time. `logger.remove()` with no argument
# drops *all* existing loguru handlers (including any configured by the host
# application), and the file path is relative, so the log is written under the
# current working directory. Consider moving this into an opt-in setup call.
logger.remove()
logger.add(sys.stderr, level="INFO")
logger.add("logs/dataset.log", rotation="10 MB", level="DEBUG")
@@ -0,0 +1,33 @@
1
+ """VectorMesh components module."""
2
+
3
+ from .aggregation import (
4
+ AttentionAggregator,
5
+ BaseAggregator,
6
+ MeanAggregator,
7
+ RNNAggregator,
8
+ )
9
+ from .connectors import Concatenate2D, Stack2D
10
+ from .gating import Gate, Highway, MoE, Skip
11
+ from .neural import Attention, NeuralNet, Projection
12
+ from .padding import DynamicPadding, FixedPadding
13
+ from .pipelines import Parallel, Serial
14
+
15
# Public API of vectormesh.components, grouped by source module.
__all__ = [
    # aggregation: (batch, chunks, dim) -> (batch, dim)
    "AttentionAggregator",
    "BaseAggregator",
    "MeanAggregator",
    "RNNAggregator",
    # connectors: merge tuples produced by parallel branches
    "Concatenate2D",
    "Stack2D",
    # gating / residual wrappers
    "Gate",
    "Highway",
    "Skip",
    "MoE",
    # basic neural layers
    "NeuralNet",
    "Projection",
    "Attention",
    # collate-time padding
    "DynamicPadding",
    "FixedPadding",
    # composition
    "Parallel",
    "Serial",
]
@@ -0,0 +1,80 @@
1
+ from abc import abstractmethod
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from beartype import beartype
6
+ from jaxtyping import Float, jaxtyped
7
+ from torch import Tensor
8
+
9
+ from vectormesh.types import BaseComponent
10
+
11
+
12
class BaseAggregator(BaseComponent):
    """Abstract pooling component that collapses a 3D tensor to 2D.

    Subclasses implement ``forward`` (the ``nn.Module`` entry point, so
    instances are callable like any other module) to reduce
    ``(batch, chunks, dim)`` to ``(batch, dim)``.
    """

    def __init__(self):
        super().__init__()

    @abstractmethod
    @jaxtyped(typechecker=beartype)
    def forward(
        self, embeddings: Float[Tensor, "batch _ dim"]
    ) -> Float[Tensor, "batch dim"]:
        """Aggregate from (batch, chunks, dim) to (batch, dim)."""
27
+
28
+
29
class MeanAggregator(BaseAggregator):
    """Pool chunks by their arithmetic mean (no learnable parameters).

    Every position along the chunk axis contributes equally, including any
    zero-padded positions present in the input.
    """

    @jaxtyped(typechecker=beartype)
    def forward(
        self, embeddings: Float[Tensor, "batch _ dim"]
    ) -> Float[Tensor, "batch dim"]:
        """Average over the chunk axis (axis 1)."""
        return torch.mean(embeddings, dim=1)
40
+
41
+
42
class AttentionAggregator(BaseAggregator):
    """Pool chunk embeddings with a learned attention score per chunk.

    Each chunk vector is scored by a single ``Linear(hidden_size, 1)``; the
    scores are softmax-normalised over the chunk axis and used as weights in
    a weighted sum. No padding mask is applied, so the input is expected to
    have a fixed number of chunks.
    """

    def __init__(self, hidden_size: int):
        """Create the scoring layer mapping ``hidden_size`` features to 1 score."""
        super().__init__()
        self.attention = nn.Linear(hidden_size, 1)

    @jaxtyped(typechecker=beartype)
    def forward(
        self, embeddings: Float[Tensor, "batch _ dim"]
    ) -> Float[Tensor, "batch dim"]:
        scores = self.attention(embeddings)  # (batch, chunks, 1)
        weights = scores.softmax(dim=1)  # normalise across chunks
        weighted = embeddings * weights
        return weighted.sum(dim=1)
60
+
61
+
62
class RNNAggregator(BaseAggregator):
    """Aggregate chunks with a GRU and return the output at the last step.

    The GRU maps ``hidden_size`` -> ``hidden_size`` (batch-first), so the
    aggregated vector has the same width as each chunk. Note that with
    right-padded input the final step is taken after the padding positions
    have been consumed, not at the last real chunk.
    """

    def __init__(self, hidden_size: int):
        """Initialise the GRU (input and hidden size both ``hidden_size``)."""
        super().__init__()
        self.rnn = torch.nn.GRU(
            input_size=hidden_size, hidden_size=hidden_size, batch_first=True
        )

    @jaxtyped(typechecker=beartype)
    def forward(
        self, embeddings: Float[Tensor, "batch _ dim"]
    ) -> Float[Tensor, "batch dim"]:
        # output: (batch, chunks, hidden_size); keep only the final time step.
        output, _ = self.rnn(embeddings)
        return output[:, -1, :]
@@ -0,0 +1,40 @@
1
+ import torch
2
+ from beartype import beartype
3
+ from jaxtyping import Float, jaxtyped
4
+ from torch import Tensor
5
+
6
+ from vectormesh.types import BaseComponent
7
+
8
+
9
class Concatenate2D(BaseComponent):
    """Concatenate tuples from parallel branches at the last dimension.

    input: ((batch dim1), (batch dim2), ...)
    output: (batch ndim)

    where ndim = dim1 + dim2 + ...
    """

    @jaxtyped(typechecker=beartype)
    def forward(
        self, tensors: tuple[Float[Tensor, "batch _"], ...]
    ) -> Float[Tensor, "batch ndim"]:
        # "_" leaves each element's feature width unchecked: branches may have
        # different widths, and a shared named dim would force them to be equal
        # whenever more than one tuple element is type-checked.
        return torch.cat(tensors, dim=-1)
23
+
24
+
25
class Stack2D(BaseComponent):
    """Stack same-shaped tensors from parallel branches along a new axis 1.

    input : ((batch dim1), (batch dim1), ...)
    output: (batch nstack dim1)

    where nstack is the number of tensors in the tuple.
    """

    @jaxtyped(typechecker=beartype)
    def forward(
        self, tensors: tuple[Float[Tensor, "batch dim1"], ...]
    ) -> Float[Tensor, "batch nstack dim1"]:
        return torch.stack(tensors, dim=1)
@@ -0,0 +1,132 @@
1
+ """Residual and gating components for skip connections and gated transformations."""
2
+
3
+ from typing import Optional
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ from beartype import beartype
9
+ from jaxtyping import Float, jaxtyped
10
+ from torch import Tensor
11
+
12
+ from vectormesh.types import BaseComponent
13
+
14
+
15
class Skip(BaseComponent):
    """Pre-norm residual block: ``transform(LN(x)) + residual``.

    The input is first layer-normalised (``LN``); ``residual`` is that
    normalised input, passed through ``projection`` when one is given.

    Args:
        transform: the pipeline applied to the normalised input.
        in_size: size of the input's last dimension; used for the LayerNorm.
        projection: optional pipeline on the residual path, e.g.
            ``Linear(in_size, out_size)`` when ``transform`` changes the width.
    """

    transform: nn.Module
    projection: Optional[nn.Module]
    layernorm: nn.LayerNorm

    def __init__(
        self,
        transform: nn.Module,
        in_size: int,
        projection: Optional[nn.Module] = None,
    ):
        super().__init__()
        self.transform = transform
        self.projection = projection
        self.layernorm = nn.LayerNorm(in_size)

    @jaxtyped(typechecker=beartype)
    def forward(self, x: Float[Tensor, "..."]) -> Float[Tensor, "..."]:
        # pre-norm (instead of post-norm) improves stability
        x = self.layernorm(x)
        # NOTE(review): the residual is taken from the *normalised* input,
        # unlike the common pre-norm form `x + transform(LN(x))`; confirm
        # this is intended.
        # Compare with None explicitly: container modules (nn.Sequential,
        # nn.ModuleList) define __len__, so their truthiness is not "present".
        residual = x if self.projection is None else self.projection(x)
        transformed = self.transform(x)
        return transformed + residual
45
+
46
+
47
class Gate(BaseComponent):
    """Element-wise sigmoid gate: ``sigmoid(W·x + b) * x``.

    Args:
        hidden_size: feature width; the gate projection is square so the
            output keeps the input's shape.
    """

    def __init__(self, hidden_size: int):
        super().__init__()
        self.project = nn.Linear(hidden_size, hidden_size)

    @jaxtyped(typechecker=beartype)
    def forward(self, x: Float[Tensor, "batch dim"]) -> Float[Tensor, "batch dim"]:
        gate = torch.sigmoid(self.project(x))
        return gate * x
57
+
58
+
59
class Highway(BaseComponent):
    """Highway layer on a layer-normalised input.

    With ``n = LayerNorm(x)`` and ``g = sigmoid(W·n + b)`` the output is
    ``g * transform(n) + (1 - g) * n``.

    Args:
        transform: module producing the candidate update; must preserve width.
        hidden_size: feature width used by the gate projection and LayerNorm.
    """

    def __init__(self, transform: nn.Module, hidden_size: int):
        super().__init__()
        self.transform = transform
        self.project = nn.Linear(hidden_size, hidden_size)
        self.norm = nn.LayerNorm(hidden_size)

    @jaxtyped(typechecker=beartype)
    def forward(self, x: Float[Tensor, "batch dim"]) -> Float[Tensor, "batch dim"]:
        # pre-norm (instead of post-norm) improves stability
        normed = self.norm(x)
        g = torch.sigmoid(self.project(normed))
        candidate = self.transform(normed)
        carry = 1 - g
        return g * candidate + carry * normed
75
+
76
+
77
class MoE(BaseComponent):
    """Sparsely-gated mixture of experts with (optionally noisy) top-k routing.

    See https://arxiv.org/abs/1701.06538 for paper.

    A linear router scores every expert per row of ``x``. Only the ``top_k``
    highest-scoring experts keep their logits (the rest are set to -inf), so
    the softmax distributes probability over the selected experts only. Each
    expert runs on just the rows routed to it, and its output is weighted by
    the router probability and summed into a ``(batch, out_size)`` result.

    Args:
        experts: iterable of modules; each must map ``(n, hidden_size)`` rows
            to ``(n, out_size)``.
        hidden_size: width of the input rows fed to the router.
        out_size: width of every expert's output and of the result.
        top_k: experts selected per row; must satisfy ``1 <= top_k <= len(experts)``.
        noisy_gating: in training mode, add Gaussian noise with learned,
            input-dependent std to the router logits.

    Raises:
        ValueError: if ``top_k`` is outside ``[1, len(experts)]``.
    """

    def __init__(self, experts, hidden_size, out_size, top_k, noisy_gating=True):
        super().__init__()
        # Materialise first so generators work with the len() calls below.
        experts = list(experts)
        num_experts = len(experts)
        if not 1 <= top_k <= num_experts:
            raise ValueError(
                f"top_k must be between 1 and the number of experts "
                f"({num_experts}), got {top_k}"
            )
        self.experts = nn.ModuleList(experts)
        self.router = nn.Linear(hidden_size, num_experts)

        self.w_noise = nn.Linear(hidden_size, num_experts)
        self.noisy_gating = noisy_gating
        self.top_k = top_k
        self.num_experts = num_experts
        self.out_size = out_size

    def forward(self, x: Tensor) -> Tensor:
        """Route each row of ``x`` (batch, hidden_size) to its top-k experts."""
        clean_logits = self.router(x)

        # self.training is automatically managed by .eval() and .train()
        if self.noisy_gating and self.training:
            raw_noise_stddev = self.w_noise(x)
            # softplus keeps the std positive; the +1e-2 floor keeps it nonzero.
            noise_stddev = F.softplus(raw_noise_stddev) + 1e-2
            noise = torch.randn_like(clean_logits) * noise_stddev
            noisy_logits = clean_logits + noise
        else:
            noisy_logits = clean_logits

        # We set non-top-k logits to -inf so Softmax drives them to absolute zero
        top_logits, top_indices = noisy_logits.topk(self.top_k, dim=1)
        full_logits = torch.full_like(noisy_logits, float("-inf"))
        full_logits.scatter_(1, top_indices, top_logits)

        router_probs = F.softmax(full_logits, dim=1)

        # Allocate with x's dtype and device so the result follows the input
        # (fp16/bf16/float64) instead of always being float32.
        final_output = x.new_zeros(x.size(0), self.out_size)

        for i in range(self.num_experts):
            mask = (top_indices == i).any(dim=1)

            if mask.any():
                expert_input = x[mask]
                expert_output = self.experts[i](expert_input)

                expert_weights = router_probs[mask, i].unsqueeze(-1)

                final_output[mask] += expert_output * expert_weights

        # TODO: the paper implements importance loss
        # to encourage balanced expert usage
        #
        # importance = router_probs.sum(dim=0)
        # imp_loss = (importance.std() / (importance.mean() + 1e-10)).pow(2)
        # return final_output, imp_loss

        return final_output
@@ -0,0 +1,121 @@
1
+ from __future__ import annotations
2
+
3
+ from abc import ABC, abstractmethod
4
+ from typing import Iterator, Union
5
+
6
+ import numpy as np
7
+ import torch
8
+
9
# Inputs accepted by Metric.__call__: torch tensors or numpy arrays.
NumericType = Union[torch.Tensor, np.ndarray]


class Metric(ABC):
    """Abstract metric that normalises inputs before delegating to ``_compute``.

    Calling an instance converts numpy inputs to tensors (sharing memory via
    ``torch.from_numpy``), moves ``yhat`` onto ``y``'s device, runs
    ``_compute`` and returns the result as a Python float.
    """

    @abstractmethod
    def _compute(self, y: torch.Tensor, yhat: torch.Tensor) -> torch.Tensor:
        """Return the metric value as a scalar tensor (subclass hook)."""

    def __call__(self, y: NumericType, yhat: NumericType) -> float:
        """Evaluate the metric on targets ``y`` and predictions ``yhat``."""
        y_t = torch.from_numpy(y) if isinstance(y, np.ndarray) else y
        yhat_t = torch.from_numpy(yhat) if isinstance(yhat, np.ndarray) else yhat

        # Both operands go to the targets' device.
        target_device = y_t.device
        y_t, yhat_t = y_t.to(target_device), yhat_t.to(target_device)

        value = self._compute(y_t, yhat_t)
        return float(value.cpu().detach())

    @abstractmethod
    def __repr__(self) -> str:
        """Short human-readable name of the metric."""
42
+
43
+
44
class MAE(Metric):
    """Mean absolute error: ``mean(|y - yhat|)`` over all elements."""

    def _compute(self, y: torch.Tensor, yhat: torch.Tensor) -> torch.Tensor:
        residuals = y - yhat
        return residuals.abs().mean()

    def __repr__(self) -> str:
        return "MAE"
50
+
51
+
52
class MASE(Metric):
    """Mean Absolute Scaled Error.

    The model's MAE is divided by ``scale``. ``scale`` is the mean, over the
    first ``len(train)`` batches of ``train.stream()``, of the per-batch MAE of
    a naive forecast that reuses the last ``horizon`` input steps.

    Args:
        train: object exposing ``stream()`` (yielding ``(x, y)`` batches) and
            ``__len__`` (number of batches to consume).
        horizon: number of steps being forecast.

    Raises:
        ValueError: if ``len(train)`` is 0, since no scale can be computed.
    """

    def __init__(self, train: Iterator, horizon: int) -> None:
        self.horizon = horizon
        with torch.no_grad():
            self.scale = self._calculate_scale(train)

    def _calculate_scale(self, train: Iterator) -> torch.Tensor:
        """Mean per-batch MAE of the naive forecast over the training stream."""
        # Create the iterator once: calling iter() inside the loop would
        # restart a re-iterable stream (list, DataLoader) on every step and
        # only ever see its first batch.
        batches = iter(train.stream())  # type: ignore
        errors = []
        for _ in range(len(train)):  # type: ignore
            x, y = next(batches)
            yhat = self._naive_predict(x)
            errors.append(torch.mean(torch.abs(y - yhat)))
        if not errors:
            raise ValueError("MASE needs at least one training batch to compute its scale")
        return torch.mean(torch.stack(errors))

    def _naive_predict(self, x: torch.Tensor) -> torch.Tensor:
        # Last `horizon` steps along the second-to-last axis, with a trailing
        # size-1 feature axis squeezed (assumes x is (..., time, 1) — confirm).
        return x[..., -self.horizon :, :].squeeze(-1)

    def _compute(self, y: torch.Tensor, yhat: torch.Tensor) -> torch.Tensor:
        mae = torch.mean(torch.abs(y - yhat))
        return mae / self.scale

    def __repr__(self) -> str:
        return f"MASE(scale={self.scale:.3f})"
77
+
78
+
79
class Accuracy(Metric):
    """Fraction of rows whose arg-max over axis 1 of ``yhat`` equals ``y``."""

    def _compute(self, y: torch.Tensor, yhat: torch.Tensor) -> torch.Tensor:
        hits = yhat.argmax(dim=1).eq(y)
        return hits.float().mean()

    def __repr__(self) -> str:
        return "Accuracy"
86
+
87
+
88
class F1Score(Metric):
    """F1 score for multi-label logits.

    ``yhat`` is passed through a sigmoid and thresholded to get 0/1
    predictions. ``average="micro"`` pools true/false positive/negative counts
    over every element; ``average="macro"`` computes F1 per column (axis 0
    reduction) and averages the columns.

    Args:
        average: ``"micro"`` or ``"macro"``; anything else raises ValueError
            when the metric is evaluated.
        threshold: probability above which a label counts as predicted.
        epsilon: added to the denominator to avoid division by zero.
    """

    def __init__(self, average="micro", threshold=0.5, epsilon=1e-7):
        self.average = average
        self.threshold = threshold
        self.epsilon = epsilon

    def _compute(self, y: torch.Tensor, yhat: torch.Tensor) -> torch.Tensor:
        predicted = (torch.sigmoid(yhat) > self.threshold).float()

        true_pos = predicted * y
        false_pos = predicted * (1 - y)
        false_neg = (1 - predicted) * y
        parts = (true_pos, false_pos, false_neg)

        if self.average == "micro":
            tp_n, fp_n, fn_n = (part.sum() for part in parts)
        elif self.average == "macro":
            tp_n, fp_n, fn_n = (part.sum(dim=0) for part in parts)
        else:
            raise ValueError("Average must be 'micro' or 'macro'")

        score = (2 * tp_n) / (2 * tp_n + fp_n + fn_n + self.epsilon)
        return score if self.average == "micro" else score.mean()

    def __repr__(self) -> str:
        return f"F1({self.average})"
@@ -0,0 +1,62 @@
1
+ import torch.nn as nn
2
+ from beartype import beartype
3
+ from jaxtyping import Float, jaxtyped
4
+ from torch import Tensor
5
+
6
+ from vectormesh.types import BaseComponent
7
+
8
+
9
class NeuralNet(BaseComponent):
    """Feed-forward head: ``Linear(h, h) -> GELU -> Linear(h, out)``.

    Args:
        hidden_size: input width, also used as the hidden-layer width.
        out_size: output width.
    """

    def __init__(self, hidden_size: int, out_size: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.out_size = out_size
        self.fc1 = nn.Linear(hidden_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, out_size)
        self.activation = nn.GELU()

    @jaxtyped(typechecker=beartype)
    def forward(
        self, x: Float[Tensor, "batch {self.hidden_size}"]
    ) -> Float[Tensor, "batch {self.out_size}"]:
        hidden = self.fc1(x)
        hidden = self.activation(hidden)
        return self.fc2(hidden)
25
+
26
+
27
class Projection(BaseComponent):
    """Single affine map ``Linear(hidden_size, out_size)``.

    Args:
        hidden_size: input width.
        out_size: output width.
    """

    def __init__(self, hidden_size: int, out_size: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.out_size = out_size
        self.proj = nn.Linear(hidden_size, out_size)

    @jaxtyped(typechecker=beartype)
    def forward(
        self, x: Float[Tensor, "batch {self.hidden_size}"]
    ) -> Float[Tensor, "batch {self.out_size}"]:
        projected = self.proj(x)
        return projected
41
+
42
+
43
class Attention(nn.Module):
    """Multi-head self-attention wrapping ``nn.MultiheadAttention``.

    Inputs are batch-first ``(batch, seq, hidden_size)``; the output has the
    same shape. Attention weights are not requested or returned.

    Args:
        hidden_size: embedding width (must be divisible by ``num_heads``).
        num_heads: number of attention heads.
        dropout: dropout probability on the attention weights.
    """

    def __init__(self, hidden_size: int, num_heads: int = 8, dropout: float = 0.1):
        super().__init__()
        self.hidden_size = hidden_size
        self.attn = nn.MultiheadAttention(
            embed_dim=hidden_size,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )

    @jaxtyped(typechecker=beartype)
    def forward(
        self, x: Float[Tensor, "batch seq {self.hidden_size}"]
    ) -> Float[Tensor, "batch seq {self.hidden_size}"]:
        # Self-attention: the same tensor acts as query, key and value.
        query = key = value = x
        attended, _unused_weights = self.attn(query, key, value, need_weights=False)
        return attended
@@ -0,0 +1,34 @@
1
+ from typing import Any
2
+
3
+ import torch.nn.functional as F
4
+ from beartype import beartype
5
+ from jaxtyping import Float, jaxtyped
6
+ from torch import Tensor
7
+ from torch.nn.utils.rnn import pad_sequence
8
+
9
+
10
class DynamicPadding:
    """Collate a list of sequences by padding them to the longest one.

    Uses ``pad_sequence`` with its default padding value and
    ``batch_first=True``, so the batch axis comes first in the result.
    """

    @jaxtyped(typechecker=beartype)
    def __call__(self, embeddings: list[Any]) -> Float[Tensor, "batch _ dim"]:
        """Pad sequences to the maximum length in the batch."""
        batch = pad_sequence(embeddings, batch_first=True)
        return batch
+ )
18
+
19
+
20
class FixedPadding:
    """Pad or truncate a batch of sequences to exactly ``max_chunks`` steps.

    Each sequence is cut to its first ``max_chunks`` rows, the batch is
    right-padded with ``pad_sequence`` (batch-first), and if the longest
    remaining sequence is still shorter than ``max_chunks`` the step axis is
    zero-padded up to it.

    Args:
        max_chunks: target number of steps; must be >= 1.

    Raises:
        ValueError: if ``max_chunks`` is smaller than 1.
    """

    def __init__(self, max_chunks: int):
        if max_chunks < 1:
            raise ValueError(f"max_chunks must be >= 1, got {max_chunks}")
        self.max_chunks = max_chunks

    @jaxtyped(typechecker=beartype)
    def __call__(
        self, embeddings: list[Any]
    ) -> Float[Tensor, "batch {self.max_chunks} dim"]:
        # Truncate before padding so a single very long sequence does not
        # force allocating a (batch, longest, dim) tensor that is then sliced.
        truncated = [seq[: self.max_chunks] for seq in embeddings]
        padded = pad_sequence(truncated, batch_first=True)

        shortfall = self.max_chunks - padded.shape[1]
        if shortfall > 0:
            # (0, 0) leaves the feature axis alone; (0, shortfall) appends
            # zero steps at the end of the step axis.
            padded = F.pad(padded, (0, 0, 0, shortfall))
        return padded
@@ -0,0 +1,41 @@
1
+ import torch.nn as nn
2
+ from beartype import beartype
3
+ from jaxtyping import jaxtyped
4
+
5
+ from vectormesh.types import TensorInput
6
+
7
+
8
class Serial(nn.Module):
    """Run components one after another, feeding each the previous output.

    No shape logic lives here; each component's own jaxtyping decorator
    checks its inputs and outputs.
    """

    components: nn.ModuleList

    def __init__(self, components: list[nn.Module]):
        super().__init__()
        self.components = nn.ModuleList(components)

    @jaxtyped(typechecker=beartype)
    def forward(self, tensors: TensorInput) -> TensorInput:
        """Execute pipeline. Type checking via component decorators."""
        out = tensors
        for stage in self.components:
            out = stage(out)
        return out
26
+
27
+
28
class Parallel(nn.Module):
    """Parallel composition - runs branches independently and returns a tuple.

    * If the input is a tuple/list, branch ``i`` receives element ``i``; the
      number of elements must equal the number of branches.
    * If the input is a single tensor, every branch receives that same tensor.

    Raises:
        ValueError: if a tuple/list input's length differs from the number of
            branches.
    """

    branches: nn.ModuleList

    def __init__(self, branches: list[nn.Module]):
        super().__init__()
        self.branches = nn.ModuleList(branches)

    @jaxtyped(typechecker=beartype)
    def forward(self, tensors: TensorInput) -> TensorInput:
        if isinstance(tensors, (tuple, list)):
            # zip() would silently drop extra inputs or branches.
            if len(tensors) != len(self.branches):
                raise ValueError(
                    f"Parallel received {len(tensors)} inputs for "
                    f"{len(self.branches)} branches"
                )
            return tuple(branch(t) for branch, t in zip(self.branches, tensors))
        # A single tensor: zipping it would split it along its first (batch)
        # axis, so broadcast the whole tensor to every branch instead.
        return tuple(branch(tensors) for branch in self.branches)
@@ -0,0 +1,16 @@
1
+ """VectorMesh data components."""
2
+
3
+ from .cache import VectorCache
4
+ from .dataset import Collate, LabelEncoder, OneHot, build
5
+ from .vectorizers import BaseVectorizer, RegexVectorizer, Vectorizer
6
+
7
# Public API of vectormesh.data.
__all__ = [
    # caching
    "VectorCache",
    # dataset helpers
    "LabelEncoder",
    "OneHot",
    "Collate",
    "build",
    # vectorizers
    "BaseVectorizer",
    "Vectorizer",
    "RegexVectorizer",
]
+ ]