rxnn-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,150 @@
+ import torch
+ import torch.nn as nn
+ from attention import MultiHeadAttention
+ from ff import FeedForward, GatedFeedForward
+ from moe import MoeFeedForward, GatedMoeFeedForward
+
+
+ class ReactiveTransformerLayer(nn.Module):
+     """Reactive Transformer layer - extending the classic Transformer layer with Memory Cross-Attention"""
+
+     def __init__(
+             self,
+             embed_dim: int,
+             ff_dim: int,
+             self_attention: MultiHeadAttention,
+             memory_cross_attention: MultiHeadAttention,
+             use_rms_norm: bool = False,
+             use_post_norm: bool = False,
+             ff_activation: nn.Module = nn.GELU(),
+             ff_dropout: float = 0.1,
+             use_gated: bool = False,
+             use_moe: bool = False,
+             num_experts: int = 1,
+             moe_top_k: int = 1,
+             *args,
+             **kwargs,
+     ):
+         super(ReactiveTransformerLayer, self).__init__(*args, **kwargs)
+
+         self.attention = self_attention
+
+         self.memory_cross_attention = memory_cross_attention
+
+         if use_gated:
+             if use_moe:
+                 self.ff = GatedMoeFeedForward(embed_dim, ff_dim, num_experts, ff_activation, top_k=moe_top_k,
+                                               dropout=ff_dropout)
+             else:
+                 self.ff = GatedFeedForward(embed_dim, ff_dim, ff_activation, dropout=ff_dropout)
+         else:
+             if use_moe:
+                 self.ff = MoeFeedForward(embed_dim, ff_dim, num_experts, ff_activation, top_k=moe_top_k,
+                                          dropout=ff_dropout)
+             else:
+                 self.ff = FeedForward(embed_dim, ff_dim, ff_activation, dropout=ff_dropout)
+
+         if use_rms_norm:
+             self.norm1 = nn.RMSNorm(embed_dim)
+             self.norm2 = nn.RMSNorm(embed_dim)
+             self.norm3 = nn.RMSNorm(embed_dim)
+         else:
+             self.norm1 = nn.LayerNorm(embed_dim)
+             self.norm2 = nn.LayerNorm(embed_dim)
+             self.norm3 = nn.LayerNorm(embed_dim)
+         self.use_post_norm = use_post_norm
+
+     def trainable_cross_attention_(self, is_trainable: bool):
+         for param in self.memory_cross_attention.parameters():
+             param.requires_grad_(is_trainable)
+
+     def forward(self, x: torch.Tensor, stm: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
+         # First step, self-attention
+         residual = x
+         if not self.use_post_norm:
+             x = self.norm1(x)
+         x = self.attention(x, x, x, mask=mask)
+         x = residual + x
+         if self.use_post_norm:
+             x = self.norm1(x)
+         # Second step, Memory cross-attention
+         residual = x
+         if not self.use_post_norm:
+             x = self.norm2(x)
+         x = self.memory_cross_attention(x, stm, stm)
+         x = residual + x
+         if self.use_post_norm:
+             x = self.norm2(x)
+
+         # Third step, Feed Forward network
+         residual = x
+         if not self.use_post_norm:
+             x = self.norm3(x)
+         x = self.ff(x)
+         x = residual + x
+         if self.use_post_norm:
+             x = self.norm3(x)
+         return x
+
+
+ class ClassicTransformerLayer(nn.Module):
+     """Classic Transformer layer - classic decoder-only/encoder-only Transformer layer with self-attention and Feed-Forward network."""
+
+     def __init__(
+             self,
+             embed_dim: int,
+             ff_dim: int,
+             self_attention: MultiHeadAttention,
+             use_rms_norm: bool = False,
+             use_post_norm: bool = False,
+             ff_activation: nn.Module = nn.GELU(),
+             ff_dropout: float = 0.1,
+             use_gated: bool = False,
+             use_moe: bool = False,
+             num_experts: int = 1,
+             moe_top_k: int = 1,
+             *args,
+             **kwargs,
+     ):
+         super(ClassicTransformerLayer, self).__init__(*args, **kwargs)
+
+         self.attention = self_attention
+
+         if use_gated:
+             if use_moe:
+                 self.ff = GatedMoeFeedForward(embed_dim, ff_dim, num_experts, top_k=moe_top_k, dropout=ff_dropout)
+             else:
+                 self.ff = GatedFeedForward(embed_dim, ff_dim, dropout=ff_dropout)
+         else:
+             if use_moe:
+                 self.ff = MoeFeedForward(embed_dim, ff_dim, num_experts, ff_activation, top_k=moe_top_k,
+                                          dropout=ff_dropout)
+             else:
+                 self.ff = FeedForward(embed_dim, ff_dim, ff_activation, dropout=ff_dropout)
+
+         if use_rms_norm:
+             self.norm1 = nn.RMSNorm(embed_dim)
+             self.norm2 = nn.RMSNorm(embed_dim)
+         else:
+             self.norm1 = nn.LayerNorm(embed_dim)
+             self.norm2 = nn.LayerNorm(embed_dim)
+         self.use_post_norm = use_post_norm
+
+     def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
+         # First step, self-attention
+         residual = x
+         if not self.use_post_norm:
+             x = self.norm1(x)
+         x = self.attention(x, x, x, mask=mask)
+         x = residual + x
+         if self.use_post_norm:
+             x = self.norm1(x)
+         # Second step, Feed Forward network
+         residual = x
+         if not self.use_post_norm:
+             x = self.norm2(x)
+         x = self.ff(x)
+         x = residual + x
+         if self.use_post_norm:
+             x = self.norm2(x)
+         return x
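
For orientation, here is a minimal usage sketch of the pre-norm ReactiveTransformerLayer. The MultiHeadAttention constructor signature (embed_dim, num_heads) is assumed for illustration only; the attention and feed-forward modules live in files not shown in this diff.

import torch

# Hypothetical setup - MultiHeadAttention's constructor is assumed, not taken from this diff
embed_dim, num_heads, seq_len, stm_slots = 256, 8, 32, 16
self_attn = MultiHeadAttention(embed_dim, num_heads)
mem_attn = MultiHeadAttention(embed_dim, num_heads)
layer = ReactiveTransformerLayer(embed_dim, ff_dim=1024, self_attention=self_attn,
                                 memory_cross_attention=mem_attn)

x = torch.randn(2, seq_len, embed_dim)      # token hidden states
stm = torch.randn(2, stm_slots, embed_dim)  # short-term memory slots for this layer
out = layer(x, stm)                         # same shape as x: [2, seq_len, embed_dim]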
@@ -0,0 +1,10 @@
+ import torch
+
+
+ def create_causal_mask(seq_len: int, device: torch.device) -> torch.Tensor:
+     """Create a causal (lower triangular) attention mask for a given sequence length."""
+     # Create a lower triangular boolean matrix
+     mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=device))
+     # Expand the mask to have the shape (1, 1, seq_len, seq_len)
+     mask = mask.unsqueeze(0).unsqueeze(0)
+     return mask
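
The resulting boolean mask broadcasts against per-head attention scores. As a small sketch of how the decoders below combine it with a padding mask:

import torch

causal = create_causal_mask(4, device=torch.device('cpu'))    # [1, 1, 4, 4] lower-triangular booleans
padding = torch.tensor([[1, 1, 1, 0]])                        # last position is padding
combined = causal & padding.unsqueeze(1).unsqueeze(1).bool()  # broadcasts to [1, 1, 4, 4]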
@@ -0,0 +1,168 @@
+ import torch
+ import torch.nn as nn
+ from positional import AbsolutePositionalEmbedding
+ from mask import create_causal_mask
+ from src.memory.stm import ShortTermMemory
+
+
+ class ReactiveTransformerBase(nn.Module):
+     """Base class for Reactive Transformer models - common logic for both decoders and encoders."""
+
+     def __init__(
+             self,
+             stm: ShortTermMemory,
+             embedding: nn.Embedding,
+             own_layers: nn.ModuleList,
+             shared_layers: nn.ModuleList = None,
+             absolute_embedding: AbsolutePositionalEmbedding = None,
+             use_flash_attention: bool = False,
+             *args,
+             **kwargs,
+     ):
+         super(ReactiveTransformerBase, self).__init__(*args, **kwargs)
+
+         self.embedding = embedding
+         self.stm = stm
+         self.pos_embedding = absolute_embedding
+         self.use_flash_attention = use_flash_attention
+
+         self.shared_layers = shared_layers
+         self.layers = own_layers
+         self.num_shared_layers = len(shared_layers) if shared_layers else 0
+         self.num_own_layers = len(own_layers) if own_layers else 0
+
+     def trainable_cross_attention_(self, is_trainable: bool):
+         for i in range(self.num_shared_layers):
+             self.shared_layers[i].trainable_cross_attention_(is_trainable)
+         for i in range(self.num_own_layers):
+             self.layers[i].trainable_cross_attention_(is_trainable)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         # Shared logic for encoders and decoders - apply embeddings and positional encoding
+         x = self.embedding(x)
+         if self.pos_embedding is not None:
+             x = self.pos_embedding(x)
+         return x
+
+
+ class ReactiveTransformerDecoder(ReactiveTransformerBase):
+     """Reactive Transformer decoder - extending the classic Transformer decoder with Memory Cross-Attention"""
+
+     def __init__(self, embed_dim: int, vocab_size: int, *args, **kwargs):
+         super(ReactiveTransformerDecoder, self).__init__(*args, **kwargs)
+         self.head = nn.Linear(embed_dim, vocab_size)
+
+     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
+         x = super().forward(x)  # apply embeddings
+         seq_len = x.size(1)
+         if not self.use_flash_attention:
+             mask = create_causal_mask(seq_len, device=x.device)
+             if attention_mask is not None:
+                 mask &= attention_mask.unsqueeze(1).unsqueeze(1).bool()
+         elif attention_mask is not None:
+             mask = attention_mask.unsqueeze(1).unsqueeze(1).bool()
+         else:
+             mask = None
+         # Process shared layers
+         if self.shared_layers is not None:
+             for i in range(self.num_shared_layers):
+                 layer_stm = self.stm(i)
+                 x = self.shared_layers[i](x, layer_stm, mask=mask)
+         # Process own layers
+         for i in range(self.num_own_layers):
+             layer_stm = self.stm(i)
+             x = self.layers[i](x, layer_stm, mask=mask)
+         return self.head(x)
+
+
+ class ReactiveTransformerEncoder(ReactiveTransformerBase):
+     """Reactive Transformer encoder - extending the classic Transformer encoder with Memory Cross-Attention"""
+
+     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]:
+         x = super().forward(x)  # apply embeddings
+         if attention_mask is not None:
+             attention_mask = attention_mask.unsqueeze(1).unsqueeze(1).bool()
+
+         hidden_states = []
+         # Process shared layers
+         if self.shared_layers is not None:
+             for i in range(self.num_shared_layers):
+                 layer_stm = self.stm(i)
+                 x = self.shared_layers[i](x, layer_stm, mask=attention_mask)
+                 hidden_states.append(x)
+         # Process own layers
+         for i in range(self.num_own_layers):
+             layer_stm = self.stm(i)
+             x = self.layers[i](x, layer_stm, mask=attention_mask)
+             hidden_states.append(x)
+         return x, torch.stack(hidden_states)
+
+
+ class ClassicTransformerBase(nn.Module):
+     """Base class for Classic Transformer models - common logic for both decoders and encoders."""
+
+     def __init__(
+             self,
+             embedding: nn.Embedding,
+             layers: nn.ModuleList,
+             absolute_embedding: AbsolutePositionalEmbedding = None,
+             use_flash_attention: bool = False,
+             *args,
+             **kwargs,
+     ):
+         super(ClassicTransformerBase, self).__init__(*args, **kwargs)
+
+         self.embedding = embedding
+         self.pos_embedding = absolute_embedding
+         self.use_flash_attention = use_flash_attention
+
+         self.layers = layers
+         self.num_layers = len(layers) if layers else 0
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         # Shared logic for encoders and decoders - apply embeddings and positional encoding
+         x = self.embedding(x)
+         if self.pos_embedding is not None:
+             x = self.pos_embedding(x)
+         return x
+
+
+ class ClassicTransformerDecoder(ClassicTransformerBase):
+     """Classic Transformer decoder - for decoder-only Transformer models"""
+
+     def __init__(self, embed_dim: int, vocab_size: int, *args, **kwargs):
+         super(ClassicTransformerDecoder, self).__init__(*args, **kwargs)
+         self.head = nn.Linear(embed_dim, vocab_size)
+
+     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
+         x = super().forward(x)  # apply embeddings
+         seq_len = x.size(1)
+         if not self.use_flash_attention:
+             mask = create_causal_mask(seq_len, device=x.device)
+             if attention_mask is not None:
+                 mask &= attention_mask.unsqueeze(1).unsqueeze(1).bool()
+         elif attention_mask is not None:
+             mask = attention_mask.unsqueeze(1).unsqueeze(1).bool()
+         else:
+             mask = None
+
+         # Process layers
+         for i in range(self.num_layers):
+             x = self.layers[i](x, mask=mask)
+         return self.head(x)
+
+
+ class ClassicTransformerEncoder(ClassicTransformerBase):
+     """Classic Transformer encoder - for encoder-only Transformer models"""
+
+     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]:
+         x = super().forward(x)  # apply embeddings
+         if attention_mask is not None:
+             attention_mask = attention_mask.unsqueeze(1).unsqueeze(1).bool()
+
+         hidden_states = []
+         # Process layers
+         for i in range(self.num_layers):
+             x = self.layers[i](x, mask=attention_mask)
+             hidden_states.append(x)
+         return x, torch.stack(hidden_states)
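
To show how these pieces compose, a hedged sketch of a small decoder-only model follows; the MultiHeadAttention constructor is again an assumption, and the dimensions are arbitrary.

import torch
import torch.nn as nn

vocab_size, embed_dim = 1000, 128
layers = nn.ModuleList([
    ClassicTransformerLayer(embed_dim, ff_dim=512, self_attention=MultiHeadAttention(embed_dim, 4))  # assumed ctor
    for _ in range(2)
])
decoder = ClassicTransformerDecoder(
    embed_dim, vocab_size,
    embedding=nn.Embedding(vocab_size, embed_dim),
    layers=layers,
)
tokens = torch.randint(0, vocab_size, (2, 16))
logits = decoder(tokens)  # [2, 16, vocab_size], causal mask applied internally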
@@ -0,0 +1,139 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ class MoeRouter(nn.Module):
+     """Mixture-of-Experts Router layer - computes routing weights for each expert."""
+
+     def __init__(self, embed_dim: int, num_experts: int, top_k: int = 1, *args, **kwargs):
+         super(MoeRouter, self).__init__(*args, **kwargs)
+         self.top_k = top_k
+         self.num_experts = num_experts
+         self.gate = nn.Linear(embed_dim, num_experts, bias=False)
+         self.aux_loss = 0.0  # For expert load balancing
+
+     def forward(self, x: torch.Tensor):
+         # x shape: [batch_size*seq_len, embed_dim]
+         logits = self.gate(x)
+         probs = F.softmax(logits, dim=-1)
+
+         # Expert load balancing loss
+         if self.training:
+             probs_for_bal = F.softmax(logits, dim=0)
+             self.aux_loss = (probs_for_bal.mean(dim=0) *
+                              torch.log(probs_for_bal.mean(dim=0) + 1e-9)).sum()
+
+         top_k_weights, top_k_indices = probs.topk(self.top_k, dim=-1)
+         top_k_weights = top_k_weights / (top_k_weights.sum(dim=-1, keepdim=True) + 1e-9)
+
+         return top_k_weights, top_k_indices
+
+
+ class MoeFeedForward(nn.Module):
+     """Mixture-of-Experts Feed-Forward layer - combines multiple experts into a single model."""
+
+     def __init__(
+             self,
+             embed_dim: int,
+             hidden_dim: int,
+             num_experts: int,
+             activation: nn.Module,
+             top_k: int = 1,
+             dropout: float = 0.0,
+             *args,
+             **kwargs
+     ):
+         super(MoeFeedForward, self).__init__(*args, **kwargs)
+         self.embed_dim = embed_dim
+         self.num_experts = num_experts
+         self.top_k = top_k
+
+         self.router = MoeRouter(embed_dim, num_experts, top_k)
+
+         # Batch all expert parameters together
+         self.w1 = nn.Parameter(torch.empty(num_experts, embed_dim, self._w1_dim_factor(hidden_dim)))
+         self.b1 = nn.Parameter(torch.zeros(num_experts, self._w1_dim_factor(hidden_dim)))
+         self.w2 = nn.Parameter(torch.empty(num_experts, hidden_dim, embed_dim))
+         self.b2 = nn.Parameter(torch.zeros(num_experts, embed_dim))
+         self.activation = activation
+         self.dropout = nn.Dropout(dropout)
+
+         # Initialize parameters
+         self._init_linear_parameters()
+         nn.init.zeros_(self.b1)
+         nn.init.zeros_(self.b2)
+
+     def _init_linear_parameters(self):
+         nn.init.kaiming_normal_(self.w1, nonlinearity='relu')
+         nn.init.kaiming_normal_(self.w2, nonlinearity='relu')
+
+     def _w1_dim_factor(self, hidden_dim: int) -> int:
+         return hidden_dim
+
+     def _activate(self, h: torch.Tensor):
+         return self.activation(h)
+
+     def forward(self, x: torch.Tensor):
+         orig_shape = x.shape
+         x = x.view(-1, self.embed_dim)  # [batch*seq_len, embed_dim]
+
+         # Get routing weights and indices
+         weights, indices = self.router(x)  # [batch*seq_len, top_k]
+
+         # Create expert masks and combine them with the routing weights
+         mask = F.one_hot(indices, self.num_experts).float()  # [batch*seq_len, top_k, num_experts]
+         weights = (weights.unsqueeze(-1) * mask).sum(dim=1)  # [batch*seq_len, num_experts]
+
+         # Expert computation
+         x = x.unsqueeze(1).expand(-1, self.num_experts, -1)  # [batch*seq_len, num_experts, embed_dim]
+
+         # First linear layer
+         h = torch.einsum('bie,ieh->bih', x, self.w1) + self.b1  # [batch*seq_len, num_experts, hidden_dim]
+         h = self._activate(h)
+         h = self.dropout(h)
+
+         # Second linear layer (projection back to embed_dim)
+         out = torch.einsum('bih,ihe->bie', h, self.w2) + self.b2  # [batch*seq_len, num_experts, embed_dim]
+
+         # Weighted sum of expert outputs
+         out = (out * weights.unsqueeze(-1)).sum(dim=1)  # [batch*seq_len, embed_dim]
+
+         return out.view(*orig_shape)
+
+
+ class GatedMoeFeedForward(MoeFeedForward):
+     """Gated Mixture-of-Experts Feed-Forward layer - enables GLU-based activations for MoE"""
+
+     def __init__(
+             self,
+             embed_dim: int,
+             hidden_dim: int,
+             num_experts: int,
+             activation: nn.Module = nn.SiLU(),
+             top_k: int = 1,
+             dropout: float = 0.1,
+             *args,
+             **kwargs
+     ):
+         super(GatedMoeFeedForward, self).__init__(
+             embed_dim=embed_dim,
+             hidden_dim=hidden_dim,
+             num_experts=num_experts,
+             activation=activation,
+             top_k=top_k,
+             dropout=dropout,
+             *args,
+             **kwargs
+         )
+
+     def _init_linear_parameters(self):
+         nn.init.kaiming_normal_(self.w1, nonlinearity='relu')
+         nn.init.kaiming_normal_(self.w2, nonlinearity='linear')
+
+     def _w1_dim_factor(self, hidden_dim: int) -> int:
+         return 2 * hidden_dim
+
+     def _activate(self, h: torch.Tensor):
+         a, b = h.chunk(2, dim=-1)
+         return a * self.activation(b)
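
A self-contained sketch of the gated MoE block in isolation (shapes are arbitrary):

import torch

moe_ff = GatedMoeFeedForward(embed_dim=64, hidden_dim=256, num_experts=4, top_k=2)
x = torch.randn(2, 10, 64)    # [batch, seq_len, embed_dim]
y = moe_ff(x)                 # routed through 2 of 4 experts per token, same shape as x
aux = moe_ff.router.aux_loss  # load-balancing term, updated only in training mode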
@@ -0,0 +1,105 @@
+ import torch
+ from torch import nn
+ import math
+
+
+ class RotaryPositionalEmbedding(nn.Module):
+     """Rotary Positional Embedding (RoPE) layer - recommended for positional encoding"""
+
+     def __init__(self, dim: int, max_seq_len: int = 4096, base: int = 10000, *args, **kwargs):
+         super(RotaryPositionalEmbedding, self).__init__(*args, **kwargs)
+         self.dim = dim
+         self.max_seq_len = max_seq_len
+         self.base = base
+         inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
+         self.register_buffer('inv_freq', inv_freq)
+         self.register_buffer('cache', None, persistent=False)
+
+     def forward(self, q: torch.Tensor, k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+         device = q.device
+         seq_len = q.size(-2)
+         # Prepare RoPE frequencies
+         freqs = self._prepare_freqs(seq_len, device)
+
+         # Apply the rotation to the queries
+         q_embed = self._rotate(q, freqs)
+         # Apply the rotation to the keys
+         k_embed = self._rotate(k, freqs)
+
+         return q_embed, k_embed
+
+     def forward_one(self, q: torch.Tensor) -> torch.Tensor:
+         device = q.device
+         seq_len = q.size(-2)
+         # Prepare RoPE frequencies
+         freqs = self._prepare_freqs(seq_len, device)
+
+         # Apply the rotation to the queries
+         q_embed = self._rotate(q, freqs)
+
+         return q_embed
+
+     def _prepare_freqs(self, seq_len: int, device: torch.device) -> torch.Tensor:
+         if self.cache is None or self.cache.size(0) < seq_len:
+             t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
+             freqs = torch.einsum('i,j->ij', t, self.inv_freq)
+             self.cache = freqs
+             return freqs[None, None, :, :]
+         else:
+             return self.cache[None, None, :seq_len, :]
+
+     def _rotate(self, x: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
+         x1 = x[..., 0::2]
+         x2 = x[..., 1::2]
+         # Apply the rotation
+         x_rotated1 = x1 * torch.cos(freqs) - x2 * torch.sin(freqs)
+         x_rotated2 = x1 * torch.sin(freqs) + x2 * torch.cos(freqs)
+         # Concatenate the rotated parts back together
+         x_rotated = torch.cat((x_rotated1, x_rotated2), dim=-1)
+         return x_rotated
+
+
+ class AbsolutePositionalEmbedding(nn.Module):
+     """Absolute Positional Embedding layer (legacy) - not recommended for memory-augmented Reactive Transformers"""
+
+     def __init__(self, max_seq_len: int, embed_dim: int, *args, **kwargs):
+         super(AbsolutePositionalEmbedding, self).__init__(*args, **kwargs)
+         self.max_seq_len = max_seq_len
+         self.embed_dim = embed_dim
+         self.position_embeddings = nn.Embedding(max_seq_len, embed_dim)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         # Create position indices
+         seq_len = x.size(1)
+         positions = torch.arange(seq_len, device=x.device).unsqueeze(0).expand(x.size(0), -1)
+         # Get position embeddings
+         pos_embeddings = self.position_embeddings(positions)
+         # Add position embeddings to the input embeddings
+         return x + pos_embeddings
+
+
+ class RelativePositionalEmbedding(nn.Module):
+     """Relative Positional Embedding layer (legacy) - not compatible with Flash Attention and not recommended for positional encoding"""
+
+     def __init__(self, max_seq_len: int, embed_dim: int, *args, **kwargs):
+         super(RelativePositionalEmbedding, self).__init__(*args, **kwargs)
+         self.max_seq_len = max_seq_len
+         self.embed_dim = embed_dim
+         self.position_embeddings = nn.Embedding(2 * max_seq_len - 1, embed_dim)
+         self.embed_dim_sqrt = math.sqrt(embed_dim)
+
+     def forward(self, q: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
+         q_len = q.size(2)
+         k_len = k.size(2)
+
+         # Create relative position indices
+         indices = torch.arange(q_len, device=q.device)[:, None] - torch.arange(k_len, device=k.device)[None, :]
+         indices += self.max_seq_len - 1  # Shift to non-negative
+         indices = torch.clamp(indices, 0, 2 * self.max_seq_len - 2)
+
+         # Get embeddings
+         rel_emb = self.position_embeddings(indices)
+
+         rel_emb = rel_emb.permute(2, 0, 1)
+         rel_pos_bias = torch.einsum('bhqd, dqk -> bhqk', q, rel_emb)
+         return rel_pos_bias / self.embed_dim_sqrt
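
A short sketch of applying RoPE to query/key tensors shaped [batch, heads, seq_len, head_dim]:

import torch

rope = RotaryPositionalEmbedding(dim=64)
q = torch.randn(2, 8, 32, 64)  # [batch, heads, seq_len, head_dim]
k = torch.randn(2, 8, 32, 64)
q_rot, k_rot = rope(q, k)      # same shapes, rotated by position-dependent angles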
@@ -0,0 +1,109 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from typing import Iterator
+
+ def sample(
+         logits: torch.Tensor,
+         temperature: float = 1.0,
+         top_k: int = None,
+         top_p: float = None,
+ ) -> torch.Tensor:
+     if temperature <= 0:
+         raise ValueError("Temperature must be > 0")
+
+     # Apply temperature
+     logits = logits / temperature
+
+     # Apply top-k filtering
+     if top_k is not None and top_k > 0:
+         indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
+         logits[indices_to_remove] = float('-inf')
+
+     # Apply top-p (nucleus) sampling
+     if top_p is not None and 0 < top_p <= 1.0:
+         sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+         cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+
+         # Remove tokens with cumulative probability above threshold
+         sorted_indices_to_remove = cumulative_probs > top_p
+         # Shift right to keep first token above threshold
+         sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
+         sorted_indices_to_remove[..., 0] = 0
+
+         # Scatter sorted indices back to original positions
+         indices_to_remove = sorted_indices_to_remove.scatter(
+             dim=-1,
+             index=sorted_indices,
+             src=sorted_indices_to_remove
+         )
+         logits[indices_to_remove] = float('-inf')
+
+     # Convert to probabilities
+     probs = F.softmax(logits, dim=-1)
+
+     # Sample from distribution
+     return torch.multinomial(probs, num_samples=1)
+
+ class Sampler:
+     def __init__(self, model: nn.Module, device: torch.device, end_token_id: int):
+         self.model = model.to(device)
+         self.device = device
+         self.end_token_id = end_token_id
+
+     def _generate_token(
+             self,
+             input_ids: torch.Tensor,
+             temperature: float,
+             top_k: int,
+             top_p: float,
+             attention_mask: torch.Tensor,
+     ) -> tuple[int, torch.Tensor, torch.Tensor]:
+         # Forward pass to get next token logits
+         outputs = self.model(input_ids, attention_mask=attention_mask)
+         next_token_logits = outputs[:, -1, :]  # Get logits for next token
+         # Apply sampling
+         next_token = sample(
+             next_token_logits,
+             temperature=temperature,
+             top_k=top_k,
+             top_p=top_p,
+         )
+         next_token = next_token.item()  # Extract scalar token
+         next_token_ten = torch.tensor([[next_token]], device=self.device)
+         next_input_ids = torch.cat([input_ids, next_token_ten], dim=1)
+         new_one = torch.ones(1, 1, dtype=torch.bool, device=self.device)
+         next_mask = torch.cat([attention_mask, new_one], dim=1) if attention_mask is not None else None
+         # Return the generated token and the updated inputs/mask
+         return (
+             next_token,
+             next_input_ids,
+             next_mask
+         )
+
+     def __call__(
+             self,
+             initial_tokens: torch.Tensor,
+             temperature: float = 1.0,
+             top_k: int = None,
+             top_p: float = None,
+             max_seq_len: int = 50,
+             attention_mask: torch.Tensor = None,
+             no_grad: bool = True,
+     ) -> Iterator[int]:
+         # Initial tokens are expected to already be a tensor on the target device
+         input_ids = initial_tokens
+
+         if no_grad:
+             with torch.no_grad():
+                 for _ in range(max_seq_len):
+                     next_token, input_ids, attention_mask = self._generate_token(input_ids, temperature, top_k, top_p, attention_mask)
+                     yield next_token
+                     if next_token == self.end_token_id:
+                         break
+         else:
+             for _ in range(max_seq_len):
+                 next_token, input_ids, attention_mask = self._generate_token(input_ids, temperature, top_k, top_p, attention_mask)
+                 yield next_token
+                 if next_token == self.end_token_id:
+                     break
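
A hedged usage sketch for the streaming sampler; any module whose forward returns logits of shape [batch, seq_len, vocab_size] works here, e.g. the hypothetical decoder sketched earlier, and end_token_id is an arbitrary placeholder.

import torch

sampler = Sampler(model=decoder, device=torch.device('cpu'), end_token_id=0)
prompt = torch.randint(1, 1000, (1, 8))  # [1, prompt_len]
generated = [token for token in sampler(prompt, temperature=0.8, top_k=50, max_seq_len=20)]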
src/utils.py ADDED
@@ -0,0 +1,14 @@
+ import torch
+
+ def human_format(num: int):
+     num = float('{:.3g}'.format(num))
+     magnitude = 0
+     while abs(num) >= 1000:
+         magnitude += 1
+         num /= 1000.0
+     return '{}{}'.format('{:f}'.format(num).rstrip('0').rstrip('.'), ['', 'K', 'M', 'B', 'T'][magnitude])
+
+
+ def get_model_size(model: torch.nn.Module):
+     trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+     return f'Model params {human_format(trainable_params)}'
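
For example, human_format turns raw parameter counts into compact strings, and get_model_size wraps it for a module:

human_format(123_456_789)                # '123M'
get_model_size(torch.nn.Linear(10, 10))  # 'Model params 110'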