rxnn 0.1.12__py3-none-any.whl → 0.1.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn/experimental/attention.py +248 -1
- rxnn/transformers/layers.py +2 -2
- rxnn/transformers/models.py +2 -2
- rxnn/transformers/moe.py +46 -12
- rxnn/transformers/sampler.py +82 -22
- {rxnn-0.1.12.dist-info → rxnn-0.1.14.dist-info}/METADATA +1 -1
- {rxnn-0.1.12.dist-info → rxnn-0.1.14.dist-info}/RECORD +9 -9
- {rxnn-0.1.12.dist-info → rxnn-0.1.14.dist-info}/LICENSE +0 -0
- {rxnn-0.1.12.dist-info → rxnn-0.1.14.dist-info}/WHEEL +0 -0
rxnn/experimental/attention.py
CHANGED
@@ -1,6 +1,253 @@
 import torch
 from torch import nn
-from
+from ..transformers.attention import MultiHeadAttention, GroupedQueryAttention
+from ..transformers.positional import RotaryPositionalEmbedding
+from ..transformers.moe import MoeRouter
+
+# Created by Reactive AI
+
+class GroupedMoeAttention(GroupedQueryAttention):
+    """
+    Grouped MoE Attention (GMA) - GQA extended with Mixture-of-Experts (MoE) routing.
+    Instead of mapping keys/values to static head groups, it dynamically selects head expert groups. It has the same
+    number of total keys/values heads as query heads, but uses only a selected group for attention calculation.
+    - with num_groups set to 1, it will be MoE MultiQueryAttention
+
+    Compared to traditional GQA/MQA, it should provide better performance, because lot less data could be lost using
+    this approach - we are training the full number of keys/values heads, while using only a group.
+
+    In case of efficiency, it should be close to GQA/MQA linear performance, but with a small MoE routing overhead.
+
+    Optionally, it could use even more expert heads than attention heads - in example:
+    - 512 dim divided into 16 heads with 32 dim, using 4 head groups - may use i.e., 24 total expert heads - still only
+    4 will be used for attention calculation, while 16 is used to split dimensions (in that case it will have 16 query heads)
+    """
+    def __init__(
+            self,
+            embed_dim: int,
+            num_heads: int,
+            num_groups: int,
+            dropout: float = 0.0,
+            rope: RotaryPositionalEmbedding = None,
+            rope_only_for_query: bool = False,
+            use_relative_embeddings: bool = False,
+            max_seq_len: int = 1024,
+            use_flash_attention: bool = False,
+            is_causal: bool = False,
+            use_bias: bool = False,
+            num_experts: int = None,
+            *args,
+            **kwargs,
+    ):
+        self.num_experts = num_experts if num_experts is not None else num_heads
+        super(GroupedMoeAttention, self).__init__(
+            embed_dim,
+            num_heads,
+            num_groups=num_groups,
+            dropout=dropout,
+            rope=rope,
+            rope_only_for_query=rope_only_for_query,
+            use_relative_embeddings=use_relative_embeddings,
+            max_seq_len=max_seq_len,
+            use_flash_attention=use_flash_attention,
+            is_causal=is_causal,
+            use_bias=use_bias,
+            *args,
+            **kwargs,
+        )
+
+    def _init_kv(self, embed_dim: int):
+        self.router = MoeRouter(embed_dim, self.num_experts, top_k=self.num_groups)
+        hidden_dim = embed_dim // (self.num_heads // self.num_groups)
+        self.wk = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
+        self.bk = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
+        self.wv = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
+        self.bv = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
+        self._init_experts()
+
+    def _init_experts(self):
+        torch.nn.init.xavier_uniform_(self.wk)
+        torch.nn.init.xavier_uniform_(self.wv)
+        if self.use_bias:
+            torch.nn.init.zeros_(self.bk)
+            torch.nn.init.zeros_(self.bv)
+
+    def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int, skip_query_processing: bool = False):
+        # head_dim = d // self.num_heads
+        # group_heads = self.num_heads // self.num_groups
+        #
+        # # Process Query as in GQA
+        # q = self.q_proj(query).view(b, t, self.num_heads, head_dim).transpose(1, 2)
+        #
+        # # Process Key and Value with MoE routing
+        # key_flat = key.view(-1, d)
+        # weights, indices = self.router(key_flat)
+        # weights = weights.view(b, key.size(1), self.num_groups, 1)
+        # indices = indices.view(b, key.size(1), self.num_groups)
+        #
+        # # Compute all experts' K and V projections
+        # # Shape: (batch_size, seq_len, num_experts, head_dim * num_groups)
+        # k_all = torch.einsum(
+        #     'be, ehd -> bedh',
+        #     key_flat,
+        #     self.wk.view(self.num_experts, d, -1)
+        # ).view(b, key.size(1), self.num_experts, -1)
+        #
+        # v_all = torch.einsum(
+        #     'be, ehd -> bedh',
+        #     value.view(-1, d),
+        #     self.wv.view(self.num_experts, d, -1)
+        # ).view(b, value.size(1), self.num_experts, -1)
+        #
+        # # Select top_k experts and compute weighted sum
+        # selected_k = torch.gather(
+        #     k_all,
+        #     2,
+        #     indices.unsqueeze(-1).expand(-1, -1, -1, k_all.size(-1))
+        # )
+        # selected_v = torch.gather(
+        #     v_all,
+        #     2,
+        #     indices.unsqueeze(-1).expand(-1, -1, -1, v_all.size(-1))
+        # )
+        #
+        # selected_k = (selected_k * weights).sum(dim=2)
+        # selected_v = (selected_v * weights).sum(dim=2)
+        # # Reshape to GQA format: (B, G, S, head_dim)
+        # k = selected_k.view(b, key.size(1), self.num_groups, head_dim).transpose(1, 2)
+        # v = selected_v.view(b, value.size(1), self.num_groups, head_dim).transpose(1, 2)
+        #
+        # if not self.use_flash_attention:
+        #     group_heads = self.num_heads // self.num_groups
+        #
+        #     k = k.unsqueeze(2).expand(-1, -1, group_heads, -1, -1)  # (B, G, group_heads, S, head_dim)
+        #     v = v.unsqueeze(2).expand(-1, -1, group_heads, -1, -1)  # (B, G, group_heads, S, head_dim)
+        #
+        #     k = k.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
+        #     v = v.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
+        #
+        # return q, k, v
+        head_dim = d // self.num_heads
+
+        # Process Query as in GQA
+        q = self.q_proj(query).view(b, t, self.num_heads, head_dim).transpose(1, 2) if not skip_query_processing else query
+
+        # Process Key and Value with MoE routing
+        key_flat = key.view(-1, d)  # (B*S, d)
+        value_flat = value.view(-1, d)  # (B*S, d)
+
+        # Get routing indices and weights for K
+        weights_k, indices_k = self.router(key_flat)
+        indices_k = indices_k.view(-1, self.top_k)  # (B*S, top_k)
+        weights_k = weights_k.view(-1, self.top_k, 1)  # (B*S, top_k, 1)
+
+        # Select and compute K projections for only the top_k experts
+        selected_k_weights = self.k_experts[indices_k]  # (B*S, top_k, d, k_out_dim)
+        k_proj = torch.einsum('bd, behd -> beh', key_flat.unsqueeze(1), selected_k_weights)
+        selected_k = (k_proj * weights_k).sum(dim=1)  # (B*S, k_out_dim)
+        selected_k = selected_k.view(b, key.size(1), -1)  # (B, S, k_out_dim)
+
+        # Compute V using the same indices as K (since they share the same router)
+        selected_v_weights = self.v_experts[indices_k]
+        v_proj = torch.einsum('bd, behd -> beh', value_flat.unsqueeze(1), selected_v_weights)
+        selected_v = (v_proj * weights_k).sum(dim=1)
+        selected_v = selected_v.view(b, value.size(1), -1)  # (B, S, k_out_dim)
+
+        # Reshape to GQA format: (B, G, S, head_dim)
+        k = selected_k.view(b, key.size(1), self.num_groups, head_dim).transpose(1, 2)
+        v = selected_v.view(b, value.size(1), self.num_groups, head_dim).transpose(1, 2)
+
+        if not self.use_flash_attention:
+            group_heads = self.num_heads // self.num_groups
+
+            k = k.unsqueeze(2).expand(-1, -1, group_heads, -1, -1)  # (B, G, group_heads, S, head_dim)
+            v = v.unsqueeze(2).expand(-1, -1, group_heads, -1, -1)  # (B, G, group_heads, S, head_dim)
+
+            k = k.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
+            v = v.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
+
+        return q, k, v
+
+class SparseMoeAttention(GroupedMoeAttention):
+    """
+    Sparse MoE Attention (SMA) - Grouped MoE Attention extended even more for sublinear computational efficiency.
+    In addition to using Mixture-of-Experts (MoE) for key/value head groups, SMA is also using dynamically selected
+    query heads - with that approach, each token could attend to every other token, but only partially - only some part of
+    information from each token is used to identify related information parts from other tokens.
+
+    This solution could reduce the computational complexity of attention operation to sublinear level (<O(N))
+    """
+    def __init__(
+            self,
+            embed_dim: int,
+            num_heads: int,
+            num_groups: int,
+            dropout: float = 0.0,
+            rope: RotaryPositionalEmbedding = None,
+            rope_only_for_query: bool = False,
+            use_relative_embeddings: bool = False,
+            max_seq_len: int = 1024,
+            use_flash_attention: bool = False,
+            is_causal: bool = False,
+            use_bias: bool = False,
+            num_experts: int = None,
+            num_query_experts: int = None,
+            num_active_query_heads: int = None,
+            *args,
+            **kwargs,
+    ):
+        self.num_query_experts = num_query_experts if num_query_experts is not None else num_heads
+        self.num_active_query_heads = num_active_query_heads if num_active_query_heads is not None else num_groups
+        super(SparseMoeAttention, self).__init__(
+            embed_dim,
+            num_heads,
+            num_groups=num_groups,
+            dropout=dropout,
+            rope=rope,
+            rope_only_for_query=rope_only_for_query,
+            use_relative_embeddings=use_relative_embeddings,
+            max_seq_len=max_seq_len,
+            use_flash_attention=use_flash_attention,
+            is_causal=is_causal,
+            use_bias=use_bias,
+            num_experts=num_experts,
+            *args,
+            **kwargs,
+        )
+
+    def _init_q(self, embed_dim: int):
+        self.query_router = MoeRouter(embed_dim, self.num_query_experts, top_k=self.num_active_query_heads)
+        hidden_dim = embed_dim // (self.num_heads // self.num_groups)
+        self.wq = nn.Parameter(torch.empty(self.num_query_experts, embed_dim, hidden_dim))
+        self.bq = nn.Parameter(torch.zeros(self.num_query_experts, hidden_dim)) if self.use_bias else None
+        self._init_query_experts()
+
+    def _init_query_experts(self):
+        torch.nn.init.xavier_uniform_(self.wq)
+        if self.use_bias:
+            torch.nn.init.zeros_(self.bq)
+
+    def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int):
+        head_dim = d // self.num_heads
+
+        # Process Query with MoE routing
+        query_flat = query.view(-1, d)  # (B*T, d)
+        weights_q, indices_q = self.router_q(query_flat)
+        indices_q = indices_q.view(-1, self.top_k_q)  # (B*T, top_k_q)
+        weights_q = weights_q.view(-1, self.top_k_q, 1)  # (B*T, top_k_q, 1)
+
+        # Select and compute Q projections for top_k experts
+        selected_q_weights = self.q_experts[indices_q]  # (B*T, top_k_q, d, head_dim*num_heads)
+        q_proj = torch.einsum('bd, behd -> beh', query_flat.unsqueeze(1), selected_q_weights)
+        selected_q = (q_proj * weights_q).sum(dim=1)  # (B*T, head_dim*num_heads)
+        selected_q = selected_q.view(b, t, -1)  # (B, T, head_dim*num_heads)
+
+        q = selected_q.view(b, t, self.num_heads, head_dim).transpose(1, 2)  # (B, H, T, head_dim)
+
+        return super()._forward_qkv(q, key, value, b, t, d, skip_query_processing=True)
+
+
+# Others
 
 class FlexAttention(MultiHeadAttention):
     def __init__(
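Both new attention classes share one mechanism: route each token to a few expert projection matrices, compute only those experts' projections, and combine them with the router's gate weights (GMA does this for keys/values, SMA additionally for queries). The standalone sketch below illustrates that pattern in plain PyTorch; the tensor sizes, the stand-in linear router, and all variable names are assumptions for illustration, not the rxnn MoeRouter API.

import torch
import torch.nn.functional as F
from torch import nn

# Illustrative sizes only; not taken from the rxnn source.
batch, seq_len, embed_dim = 2, 8, 64
num_experts, top_k, group_dim = 6, 2, 16

router = nn.Linear(embed_dim, num_experts)          # stand-in for a learned router
expert_wk = nn.Parameter(torch.randn(num_experts, embed_dim, group_dim) * 0.02)

x = torch.randn(batch, seq_len, embed_dim)
x_flat = x.view(-1, embed_dim)                                # (B*S, D)

# Route each token to its top-k experts and renormalize the gate weights.
gate_probs = F.softmax(router(x_flat), dim=-1)                # (B*S, E)
weights, indices = torch.topk(gate_probs, top_k, dim=-1)      # (B*S, k)
weights = weights / weights.sum(dim=-1, keepdim=True)

# Gather only the selected experts' projection matrices and apply them per token.
selected_wk = expert_wk[indices]                              # (B*S, k, D, group_dim)
k_proj = torch.einsum('nd,nkdh->nkh', x_flat, selected_wk)    # (B*S, k, group_dim)
k = (k_proj * weights.unsqueeze(-1)).sum(dim=1)               # (B*S, group_dim)
k = k.view(batch, seq_len, group_dim)                         # gate-weighted key projection

In GMA this takes the place of the static per-group K/V projections of GQA; SMA applies the same routing to the query projection and then delegates back to the grouped K/V path.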
rxnn/transformers/layers.py
CHANGED
@@ -59,7 +59,7 @@ class ReactiveTransformerLayer(nn.Module):
         for param in self.memory_cross_attention.parameters():
             param.requires_grad_(is_trainable)
 
-    def
+    def moe_router_loss(self):
         return self.ff.router_loss() if self.use_moe else None
 
     def forward(self, x: torch.Tensor, stm: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
@@ -135,7 +135,7 @@ class ClassicTransformerLayer(nn.Module):
         self.use_post_norm = use_post_norm
         self.use_moe = use_moe
 
-    def
+    def moe_router_loss(self):
         return self.ff.router_loss() if self.use_moe else torch.tensor(0.0)
 
     def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
rxnn/transformers/models.py
CHANGED
@@ -37,7 +37,7 @@ class ReactiveTransformerBase(nn.Module):
         for i in range(self.num_own_layers):
             self.layers[i].trainable_cross_attention_(is_trainable)
 
-    def
+    def moe_router_loss(self):
         return torch.stack([self.layers[i].moe_router_loss() for i in range(self.num_own_layers) if self.layers[i].use_moe] + [
             self.shared_layers[i].moe_router_loss() for i in range(self.num_shared_layers) if self.shared_layers[i].use_moe]).mean()
 
@@ -123,7 +123,7 @@ class ClassicTransformerBase(nn.Module):
         self.layers = layers
         self.num_layers = len(layers) if layers else 0
 
-    def
+    def moe_router_loss(self):
         return torch.stack([self.layers[i].moe_router_loss() for i in range(self.num_layers) if self.layers[i].use_moe]).mean()
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
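Both model bases now expose a moe_router_loss() helper that averages the router auxiliary losses over all MoE layers into a single scalar. A minimal sketch of how such a term might be folded into a training objective follows; the training_step wiring, batch keys, and aux_loss_scale weighting are assumptions for illustration, not part of the rxnn training code.

import torch

def training_step(model, batch, criterion, aux_loss_scale: float = 0.01):
    # Task loss from the usual forward pass (shapes depend on the model).
    logits = model(batch["input_ids"])
    task_loss = criterion(logits.flatten(0, 1), batch["labels"].flatten())
    # moe_router_loss() is the new helper: mean auxiliary loss over MoE layers.
    router_loss = model.moe_router_loss()
    # aux_loss_scale is a hypothetical weighting, not an rxnn default.
    return task_loss + aux_loss_scale * router_loss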
rxnn/transformers/moe.py
CHANGED
@@ -77,29 +77,63 @@ class MoeFeedForward(nn.Module):
         return self.router.aux_loss
 
     def forward(self, x: torch.Tensor):
+        # orig_shape = x.shape
+        # x = x.view(-1, self.embed_dim)  # [batch*seq_len, embed_dim]
+        #
+        # # Get routing weights and indices
+        # weights, indices = self.router(x)  # [batch*seq_len, top_k]
+        #
+        # # Create expert masks and combine it with masks
+        # mask = F.one_hot(indices, self.num_experts).float()  # [batch*seq_len, top_k, num_experts]
+        # weights = (weights.unsqueeze(-1) * mask).sum(dim=1)  # [batch*seq_len, num_experts]
+        #
+        # # Expert computation
+        # x = x.unsqueeze(1).expand(-1, self.num_experts, -1)  # [batch*seq_len, num_experts, embed_dim]
+        #
+        # # First linear layer
+        # h = torch.einsum('bie,ieh->bih', x, self.w1) + self.b1  # [batch*seq_len, num_experts, hidden_dim]
+        # h = self._activate(h)
+        # h = self.dropout(h)
+        #
+        # # Second linear layer (projection back to embed_dim)
+        # out = torch.einsum('bih,ihe->bie', h, self.w2) + self.b2  # [batch*seq_len, num_experts, embed_dim]
+        #
+        # # Weighted sum of expert outputs
+        # out = (out * weights.unsqueeze(-1)).sum(dim=1)  # [batch*seq_len, embed_dim]
+        #
+        # return out.view(*orig_shape)
         orig_shape = x.shape
         x = x.view(-1, self.embed_dim)  # [batch*seq_len, embed_dim]
 
         # Get routing weights and indices
-        weights, indices = self.router(x)  # [batch*seq_len, top_k]
+        weights, indices = self.router(x)  # [batch*seq_len, top_k], [batch*seq_len, top_k]
 
-        #
-
-
+        # Flatten indices and weights
+        batch_size = x.size(0)
+        top_k = indices.size(1)
+        indices = indices.view(-1)  # [batch*seq_len * top_k]
+        weights = weights.view(-1, 1)  # [batch*seq_len * top_k, 1]
 
-        #
-
+        # Select only the relevant experts for each token
+        selected_w1 = self.w1[indices]  # [batch*seq_len * top_k, embed_dim, hidden_dim]
+        selected_b1 = self.b1[indices]  # [batch*seq_len * top_k, hidden_dim]
+        selected_w2 = self.w2[indices]  # [batch*seq_len * top_k, hidden_dim, embed_dim]
+        selected_b2 = self.b2[indices]  # [batch*seq_len * top_k, embed_dim]
 
-        #
-
+        # Reshape x for batched computation
+        x_expanded = x.unsqueeze(1).repeat(1, top_k, 1).view(-1, self.embed_dim)  # [batch*seq_len * top_k, embed_dim]
+
+        # Compute only the selected experts
+        h = torch.einsum('be, beh -> bh', x_expanded, selected_w1) + selected_b1
         h = self._activate(h)
         h = self.dropout(h)
 
-
-        out = torch.einsum('bih,ihe->bie', h, self.w2) + self.b2  # [batch*seq_len, num_experts, embed_dim]
+        out = torch.einsum('bh, bhe -> be', h, selected_w2) + selected_b2
 
-        #
-        out = (
+        # Reshape back and apply weights
+        out = out.view(batch_size, top_k, -1)  # [batch*seq_len, top_k, embed_dim]
+        weights = weights.view(batch_size, top_k, 1)  # [batch*seq_len, top_k, 1]
+        out = (out * weights).sum(dim=1)  # Weighted sum over top_k experts
 
         return out.view(*orig_shape)
 
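The rewritten MoeFeedForward.forward gathers weights only for each token's selected experts and evaluates those, instead of computing every expert and masking afterwards. A self-contained sketch of that gather-and-combine pattern, with made-up sizes and plain tensors standing in for the module's parameters and router, is shown below.

import torch
import torch.nn.functional as F

# Illustrative sizes only; the rxnn module keeps these as constructor arguments.
tokens, embed_dim, hidden_dim, num_experts, top_k = 16, 32, 64, 4, 2

w1 = torch.randn(num_experts, embed_dim, hidden_dim) * 0.02
b1 = torch.zeros(num_experts, hidden_dim)
w2 = torch.randn(num_experts, hidden_dim, embed_dim) * 0.02
b2 = torch.zeros(num_experts, embed_dim)

x = torch.randn(tokens, embed_dim)
gate = torch.randn(tokens, num_experts)                      # stand-in for the router logits
weights, indices = torch.topk(F.softmax(gate, dim=-1), top_k, dim=-1)

# Flatten (token, expert-slot) pairs so each pair is processed exactly once.
flat_idx = indices.reshape(-1)                               # (tokens*top_k,)
x_rep = x.unsqueeze(1).expand(-1, top_k, -1).reshape(-1, embed_dim)

# Evaluate only the gathered experts for their assigned tokens.
h = torch.einsum('ne,neh->nh', x_rep, w1[flat_idx]) + b1[flat_idx]
h = F.gelu(h)
out = torch.einsum('nh,nhe->ne', h, w2[flat_idx]) + b2[flat_idx]

# Weighted sum over each token's top-k expert outputs.
out = (out.view(tokens, top_k, embed_dim) * weights.unsqueeze(-1)).sum(dim=1)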
rxnn/transformers/sampler.py
CHANGED
@@ -1,13 +1,15 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from typing import Iterator
+from typing import Iterator, Union
+from transformers import PreTrainedTokenizerFast, PreTrainedTokenizer
+
 
 def sample(
-
-
-
-
+        logits: torch.Tensor,
+        temperature: float = 1.0,
+        top_k: int = None,
+        top_p: float = None,
 ) -> torch.Tensor:
     if temperature <= 0:
         raise ValueError("Temperature must be > 0")
@@ -45,6 +47,7 @@ def sample(
     # Sample from distribution
     return torch.multinomial(probs, num_samples=1)
 
+
 class Sampler:
     def __init__(self, model: nn.Module, device: torch.device, end_token_id: int):
         self.model = model.to(device)
@@ -52,12 +55,12 @@ class Sampler:
         self.end_token_id = end_token_id
 
     def _generate_token(
-
-
-
-
-
-
+            self,
+            input_ids: torch.Tensor,
+            temperature: float,
+            top_k: int,
+            top_p: float,
+            attention_mask: torch.Tensor,
     ) -> tuple[int, torch.Tensor, torch.Tensor]:
         # Forward pass to get next token logits
         outputs = self.model(input_ids, attention_mask=attention_mask)
@@ -82,14 +85,14 @@ class Sampler:
         )
 
     def __call__(
-
-
-
-
-
-
-
-
+            self,
+            initial_tokens: torch.Tensor,
+            temperature: float = 1.0,
+            top_k: int = None,
+            top_p: float = None,
+            max_seq_len: int = 50,
+            attention_mask: torch.Tensor = None,
+            no_grad: bool = True,
     ) -> Iterator[int]:
         # Convert initial tokens to tensor and move to device
         input_ids = initial_tokens
@@ -97,13 +100,70 @@ class Sampler:
         if no_grad:
             with torch.no_grad():
                 for _ in range(max_seq_len):
-                    next_token, input_ids, attention_mask = self._generate_token(input_ids, temperature, top_k, top_p,
+                    next_token, input_ids, attention_mask = self._generate_token(input_ids, temperature, top_k, top_p,
+                                                                                 attention_mask)
                     yield next_token
                     if next_token == self.end_token_id:
                         break
         else:
             for _ in range(max_seq_len):
-                next_token, input_ids, attention_mask = self._generate_token(input_ids, temperature, top_k, top_p,
+                next_token, input_ids, attention_mask = self._generate_token(input_ids, temperature, top_k, top_p,
+                                                                             attention_mask)
                 yield next_token
                 if next_token == self.end_token_id:
-                    break
+                    break
+
+
+class SampleDecoder:
+    def __init__(self, sampler: Sampler, tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast]):
+        self.sampler = sampler
+        self.tokenizer = tokenizer
+        self.device = self.sampler.device
+
+    def tokenize_input(self, text: str):
+        tokenized = self.tokenizer(
+            text,
+            max_length=256,
+            truncation=True,
+            padding=False,
+            return_tensors='pt',
+            return_attention_mask=True
+        )
+        tokenized['input_ids'] = tokenized['input_ids'][:, :-1].to(self.device)
+        tokenized['attention_mask'] = tokenized['attention_mask'][:, :-1].to(self.device)
+        del tokenized['token_type_ids']
+        return tokenized
+
+    def ids_iter(self, text: str, temperature: float = 0.1, top_p: float = 0.9, max_seq_len=256):
+        tokenized = self.tokenize_input(text)
+        return self.sampler(
+            tokenized['input_ids'],
+            temperature=temperature,
+            top_p=top_p,
+            max_seq_len=max_seq_len,
+            attention_mask=tokenized['attention_mask']
+        )
+
+    def txt_iter(self, text: str, temperature: float = 0.1, top_p: float = 0.9, max_seq_len=256):
+        return map(
+            lambda x: self.tokenizer.decode([x]).replace('Ċ', '\n').replace('Ġ', ' '),
+            self.ids_iter(text, temperature, top_p, max_seq_len)
+        )
+
+    def txt(self, text: str, temperature: float = 0.1, top_p: float = 0.9, max_seq_len=256):
+        return text + ''.join(self.txt_iter(text, temperature, top_p, max_seq_len))
+
+    def print_stream(self, text: str, temperature: float = 0.1, top_p: float = 0.9, max_seq_len=256):
+        print(text, end='')
+        resp = text
+        for txt_token in self.txt_iter(text, temperature=temperature, top_p=top_p, max_seq_len=max_seq_len):
+            print(txt_token, end='')
+            resp += txt_token
+        return resp
+
+    def __call__(self, text: str, print_stream: bool = False, temperature: float = 0.1, top_p: float = 0.9,
+                 max_seq_len=256):
+        if print_stream:
+            return self.print_stream(text, temperature=temperature, top_p=top_p, max_seq_len=max_seq_len)
+        else:
+            return self.txt(text, temperature=temperature, top_p=top_p, max_seq_len=max_seq_len)
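The sampler's signatures now spell out the temperature/top-k/top-p controls, and the new SampleDecoder wraps a Sampler with tokenization and text streaming. For reference, a standalone nucleus-sampling step in the same spirit (a sketch, not the package's sample function) could look like this:

import torch
import torch.nn.functional as F

def sample_next_token(logits: torch.Tensor, temperature: float = 1.0, top_p: float = None) -> torch.Tensor:
    # Temperature scaling, optional nucleus (top-p) filtering, then multinomial sampling.
    if temperature <= 0:
        raise ValueError("Temperature must be > 0")
    logits = logits / temperature
    if top_p is not None:
        sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
        sorted_probs = F.softmax(sorted_logits, dim=-1)
        # Drop tokens once the cumulative probability before them already exceeds top_p,
        # which always keeps at least the single most likely token.
        cutoff = sorted_probs.cumsum(dim=-1) - sorted_probs > top_p
        sorted_logits = sorted_logits.masked_fill(cutoff, float('-inf'))
        logits = torch.full_like(logits, float('-inf')).scatter(-1, sorted_idx, sorted_logits)
    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)

# Usage on a (batch, vocab) logits tensor:
# next_ids = sample_next_token(model_logits, temperature=0.7, top_p=0.9)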
{rxnn-0.1.12.dist-info → rxnn-0.1.14.dist-info}/RECORD
CHANGED
@@ -1,6 +1,6 @@
 rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/experimental/attention.py,sha256=
+rxnn/experimental/attention.py,sha256=HahcWU37FTfW8kwSTW8z_l7EtAVkJgvDDxLU8k3miHo,17101
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/memory/norm.py,sha256=Ofl8Q5NYEF9GQeO0bhM43tkTW91J0y6TSvTAOYMgloM,6278
 rxnn/memory/stm.py,sha256=EsD8slSP4_9dLuq6aFPDmuFe8PWilxh90so5Z3nm-ig,2057
@@ -16,14 +16,14 @@ rxnn/training/tokenizer.py,sha256=4Y41f07uo2KPA_7bp3FCcwGKbXoS2hsckOoXUsXfQxY,80
 rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/transformers/attention.py,sha256=FfEYE0THO73p_1eRupr2mcwfW4UbI_riIxkHfr8X_1c,14022
 rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
-rxnn/transformers/layers.py,sha256=
+rxnn/transformers/layers.py,sha256=HhIiykmrBgdsV4AbMQXr9t0cSo4gSIeN0dPtc8mDyOo,5629
 rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
-rxnn/transformers/models.py,sha256=
-rxnn/transformers/moe.py,sha256=
+rxnn/transformers/models.py,sha256=w-zB_8QB9-Fae-GkGgmVDNY-Ts_0gBeWcevpl9qzZVM,7169
+rxnn/transformers/moe.py,sha256=fFPTRcctCSc9OwHd0PhNb0nwHgNJY7dXfUtGreXtaho,6720
 rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
-rxnn/transformers/sampler.py,sha256=
+rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
 rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
-rxnn-0.1.
-rxnn-0.1.
-rxnn-0.1.
-rxnn-0.1.
+rxnn-0.1.14.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.1.14.dist-info/METADATA,sha256=YQDNMaHDrfVdOk44qEUczgLaNcrXApoqVmNX50yQDdM,14629
+rxnn-0.1.14.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
+rxnn-0.1.14.dist-info/RECORD,,
{rxnn-0.1.12.dist-info → rxnn-0.1.14.dist-info}/LICENSE
File without changes
{rxnn-0.1.12.dist-info → rxnn-0.1.14.dist-info}/WHEEL
File without changes