rxnn 0.1.25__py3-none-any.whl → 0.1.27__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn/experimental/attention.py +258 -2
- rxnn/experimental/models.py +1 -1
- rxnn/transformers/layers.py +3 -2
- {rxnn-0.1.25.dist-info → rxnn-0.1.27.dist-info}/METADATA +1 -1
- {rxnn-0.1.25.dist-info → rxnn-0.1.27.dist-info}/RECORD +7 -7
- {rxnn-0.1.25.dist-info → rxnn-0.1.27.dist-info}/LICENSE +0 -0
- {rxnn-0.1.25.dist-info → rxnn-0.1.27.dist-info}/WHEEL +0 -0
rxnn/experimental/attention.py
CHANGED
@@ -238,6 +238,230 @@ class DeepMoeAttention(GroupedMoeAttention):
         # Key/Value processing
         return super()._forward_qkv(q, key, value, b, t, d, skip_query_processing=True)
 
+# Vectorized
+
+class GroupedMoeAttentionVectorized(GroupedQueryAttention):
+    """
+    Vectorized implementation calculates all expert heads for each token and selecting active tokens later. Linear layers
+    for Attention are rather small, compared to MoE Feed Forward layers, so it's possible that it will be faster than filtering
+    experts - it has to be tested.
+
+    Grouped MoE Attention (GMA) - GQA extended with Mixture-of-Experts (MoE) routing.
+
+    Instead of mapping keys/values to static head groups, it dynamically selects head expert groups. It has the same
+    number of total keys/values heads as query heads, but uses only a selected group for attention calculation.
+    - with num_groups set to 1, it will be MoE MultiQueryAttention
+
+    Compared to traditional GQA/MQA, it should provide better performance, because lot less data could be lost using
+    this approach - we are training the full number of keys/values heads, while using only a group.
+
+    In case of efficiency, it should be close to GQA/MQA linear performance, but with a small MoE routing overhead.
+
+    Optionally, it could use even more expert heads than attention heads - in example:
+    - 512 dim divided into 16 heads with 32 dim, using 4 head groups - may use i.e., 24 total expert heads - still only
+    4 will be used for attention calculation, while 16 is used to split dimensions (in that case it will have 16 query heads)
+
+    © 2025 Adam Filipek
+    """
+
+    def __init__(
+            self,
+            embed_dim: int,
+            num_heads: int,
+            num_groups: int,
+            dropout: float = 0.0,
+            rope: RotaryPositionalEmbedding = None,
+            rope_only_for_query: bool = False,
+            use_relative_embeddings: bool = False,
+            max_seq_len: int = 1024,
+            use_flash_attention: bool = False,
+            is_causal: bool = False,
+            use_bias: bool = False,
+            num_experts: int = None,
+            *args,
+            **kwargs,
+    ):
+        self.num_experts = num_experts if num_experts is not None else num_heads
+        super(GroupedMoeAttentionVectorized, self).__init__(
+            embed_dim,
+            num_heads,
+            num_groups=num_groups,
+            dropout=dropout,
+            rope=rope,
+            rope_only_for_query=rope_only_for_query,
+            use_relative_embeddings=use_relative_embeddings,
+            max_seq_len=max_seq_len,
+            use_flash_attention=use_flash_attention,
+            is_causal=is_causal,
+            use_bias=use_bias,
+            *args,
+            **kwargs,
+        )
+
+    def _init_kv(self, embed_dim: int):
+        self.router = MoeRouter(embed_dim, self.num_experts, top_k=self.num_groups)
+        hidden_dim = embed_dim // self.num_heads
+        self.wk = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
+        self.bk = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
+        self.wv = nn.Parameter(torch.empty(self.num_experts, embed_dim, hidden_dim))
+        self.bv = nn.Parameter(torch.zeros(self.num_experts, hidden_dim)) if self.use_bias else None
+        self._init_experts()
+
+    def _init_experts(self):
+        torch.nn.init.xavier_uniform_(self.wk)
+        torch.nn.init.xavier_uniform_(self.wv)
+        if self.use_bias:
+            torch.nn.init.zeros_(self.bk)
+            torch.nn.init.zeros_(self.bv)
+
+    def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int,
+                     skip_query_processing: bool = False):
+        head_dim = d // self.num_heads
+
+        # Process Query as in GQA
+        q = self.q_proj(query).view(b, t, self.num_heads, -1).transpose(1, 2)
+
+        # Key/Value MoE routing
+        B, S, D = key.shape
+        key_flat = key.reshape(B * S, D)
+        weights, indices = self.router(key_flat)  # (B*S, num_groups), (B*S, num_groups)
+        weights = weights.view(B, S, self.num_groups, 1)
+        indices = indices.view(B, S, self.num_groups)
+
+        # Compute all experts' projections
+        # Shape: (B*S, num_experts, head_dim)
+        k_all = torch.einsum('be,ehd->beh', key_flat, self.wk)  # [B*S, num_experts, head_dim]
+        v_all = torch.einsum('be,ehd->beh', value.view(B*S, D), self.wv)
+
+        # Reshape to [B, S, num_experts, head_dim]
+        k_all = k_all.view(B, S, self.num_experts, -1)
+        v_all = v_all.view(B, S, self.num_experts, -1)
+
+        # Gather top-k experts and weights
+        # Expand indices to [B, S, num_groups, head_dim]
+        expanded_indices = indices.unsqueeze(-1).expand(-1, -1, -1, k_all.size(-1))
+        selected_k = torch.gather(k_all, 2, expanded_indices)  # [B, S, num_groups, head_dim]
+        selected_v = torch.gather(v_all, 2, expanded_indices)
+
+        # Weighted sum
+        weighted_k = (selected_k * weights).sum(dim=2)  # [B, S, head_dim]
+        weighted_v = (selected_v * weights).sum(dim=2)
+
+        # Reshape to GQA format
+        k = weighted_k.view(B, S, self.num_groups, -1).permute(0, 2, 1, 3)  # [B, G, S, head_dim]
+        v = weighted_v.view(B, S, self.num_groups, -1).permute(0, 2, 1, 3)
+
+        if not self.use_flash_attention:
+            group_heads = self.num_heads // self.num_groups
+
+            k = k.unsqueeze(2).expand(-1, -1, group_heads, -1, -1)  # (B, G, group_heads, S, head_dim)
+            v = v.unsqueeze(2).expand(-1, -1, group_heads, -1, -1)  # (B, G, group_heads, S, head_dim)
+
+            k = k.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
+            v = v.flatten(start_dim=1, end_dim=2)  # (B, H, S, head_dim)
+
+        return q, k, v
+
+
+class DeepMoeAttentionVectorized(GroupedMoeAttentionVectorized):
+    """
+    Vectorized implementation calculates all expert heads for each token and selecting active tokens later. Linear layers
+    for Attention are rather small, compared to MoE Feed Forward layers, so it's possible that it will be faster than filtering
+    experts - it has to be tested.
+
+    Deep MoE Attention (SMA) - Grouped MoE Attention extended even more for sublinear computational efficiency.
+
+    In addition to using Mixture-of-Experts (MoE) for key/value head groups, SMA is also using dynamically selected
+    query heads - with that approach, each token could attend to every other token, but only partially - only some part of
+    information from each token is used to identify related information parts from other tokens. So, DMA is not spatially
+    sparse (has access to all tokens), but rather structurally sparse (has access only to the part of token's information).
+
+    This solution could reduce the computational complexity of attention operation to sublinear level (<O(N)) and provide
+    a viable and efficient alternative to spatial sparse attention mechanisms like Flex Attention.
+
+    © 2025 Adam Filipek
+    """
+
+    def __init__(
+            self,
+            embed_dim: int,
+            num_heads: int,
+            num_groups: int,
+            dropout: float = 0.0,
+            rope: RotaryPositionalEmbedding = None,
+            rope_only_for_query: bool = False,
+            use_relative_embeddings: bool = False,
+            max_seq_len: int = 1024,
+            use_flash_attention: bool = False,
+            is_causal: bool = False,
+            use_bias: bool = False,
+            num_experts: int = None,
+            num_query_experts: int = None,
+            num_query_groups: int = None,
+            *args,
+            **kwargs,
+    ):
+        self.num_query_experts = num_query_experts if num_query_experts is not None else num_heads
+        self.num_query_groups = num_query_groups if num_query_groups is not None else num_groups
+        super(DeepMoeAttentionVectorized, self).__init__(
+            embed_dim,
+            num_heads,
+            num_groups=num_groups,
+            dropout=dropout,
+            rope=rope,
+            rope_only_for_query=rope_only_for_query,
+            use_relative_embeddings=use_relative_embeddings,
+            max_seq_len=max_seq_len,
+            use_flash_attention=use_flash_attention,
+            is_causal=is_causal,
+            use_bias=use_bias,
+            num_experts=num_experts,
+            *args,
+            **kwargs,
+        )
+
+    def _init_q(self, embed_dim: int):
+        self.query_router = MoeRouter(embed_dim, self.num_query_experts, top_k=self.num_query_groups)
+        hidden_dim = embed_dim // self.num_heads
+        self.wq = nn.Parameter(torch.empty(self.num_query_experts, embed_dim, hidden_dim))
+        self.bq = nn.Parameter(torch.zeros(self.num_query_experts, hidden_dim)) if self.use_bias else None
+        self._init_query_experts()
+
+    def _init_query_experts(self):
+        torch.nn.init.xavier_uniform_(self.wq)
+        if self.use_bias:
+            torch.nn.init.zeros_(self.bq)
+
+    def _init_out(self, embed_dim: int):
+        """Initialize output projection"""
+        self.out_proj = nn.Linear(embed_dim // (self.num_heads // self.num_groups), embed_dim)
+
+    def _transpose_output(self, attn_output: torch.Tensor, b: int, t: int, d: int):
+        """Transpose attention output back to (B, T, D) shape"""
+        hidden_dim = d // self.num_heads * self.num_query_groups
+        return attn_output.transpose(1, 2).contiguous().view(b, t, hidden_dim)
+
+    def _forward_qkv(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, b: int, t: int, d: int):
+        B, T, D = query.shape
+        query_flat = query.reshape(B * T, D)
+        weights_q, indices_q = self.query_router(query_flat)
+        weights_q = weights_q.view(B, T, self.num_query_groups, 1)
+        indices_q = indices_q.view(B, T, self.num_query_groups)
+
+        # Compute all query experts
+        q_all = torch.einsum('be,ehd->beh', query_flat, self.wq)  # [B*T, num_query_experts, head_dim]
+        q_all = q_all.view(B, T, self.num_query_experts, -1)
+
+        # Gather top-k experts
+        expanded_indices = indices_q.unsqueeze(-1).expand(-1, -1, -1, q_all.size(-1))
+        selected_q = torch.gather(q_all, 2, expanded_indices)  # [B, T, num_query_groups, head_dim]
+
+        # Weighted sum
+        q = (selected_q * weights_q).sum(dim=2)  # [B, T, head_dim]
+        q = q.view(B, T, self.num_query_groups, -1).permute(0, 2, 1, 3)  # [B, H_q, T, head_dim]
+
+        return super()._forward_qkv(q, key, value, b, t, d, skip_query_processing=True)
+
 
 # Others
 
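Both new classes follow the same vectorized routing pattern their docstrings describe: every token is projected through all expert heads at once, then only the top-k experts chosen by the router are gathered and blended with the routing weights (for keys/values in GroupedMoeAttentionVectorized, and additionally for queries in DeepMoeAttentionVectorized). The standalone sketch below illustrates that gather-and-blend step with made-up shapes and a faked router output; it is not the package's code, and every name in it is illustrative.

import torch

# Illustrative only - not rxnn code. Shapes and the "router" scores are made up.
B, S, D = 2, 8, 512            # batch, sequence length, model dim
num_experts, top_k = 24, 4     # more expert heads than active groups, as the docstring allows
head_dim = D // 16             # e.g. 16 heads of 32 dims each

tokens = torch.randn(B * S, D)                     # flattened tokens
w_expert = torch.randn(num_experts, D, head_dim)   # one projection matrix per expert head

# 1) Project through *all* experts at once: (B*S, num_experts, head_dim)
all_proj = torch.einsum('td,edh->teh', tokens, w_expert)

# 2) A router would supply these; here we fake scores and take the top-k experts per token.
scores = torch.softmax(torch.randn(B * S, num_experts), dim=-1)
weights, indices = torch.topk(scores, k=top_k, dim=-1)          # (B*S, top_k) each

# 3) Gather the selected experts' outputs and blend them with the routing weights.
gather_idx = indices.unsqueeze(-1).expand(-1, -1, head_dim)     # (B*S, top_k, head_dim)
selected = torch.gather(all_proj, 1, gather_idx)                # (B*S, top_k, head_dim)
blended = (selected * weights.unsqueeze(-1)).sum(dim=1)         # (B*S, head_dim)

print(blended.shape)  # torch.Size([16, 32])

In the classes above, MoeRouter supplies the weights and indices, and the blended projections are then reshaped into the (B, groups, seq_len, head_dim) layout used by the rest of the attention path.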
@@ -389,7 +613,7 @@ def init_moe_attention(
         num_query_experts: int = None,
         num_query_groups: int = None,
 ) -> GroupedQueryAttention:
-    assert attention_type in ['gma', 'dma'], "Error, attention type should be one of: 'gma', 'dma', 'gma_v', 'dma_v'"
+    assert attention_type in ['gma', 'dma', 'gma_v', 'dma_v'], "Error, attention type should be one of: 'gma', 'dma', 'gma_v', 'dma_v'"
 
     if attention_type == "gma":
         return GroupedMoeAttention(
@@ -406,7 +630,7 @@ def init_moe_attention(
             use_bias=use_bias,
             num_experts=num_experts,
         )
-    else:
+    elif attention_type == "dma":
         return DeepMoeAttention(
             embed_dim,
             num_heads,
@@ -423,3 +647,35 @@ def init_moe_attention(
             num_query_experts=num_query_experts,
             num_query_groups=num_query_groups,
         )
+    elif attention_type == "gma_v":
+        return GroupedMoeAttentionVectorized(
+            embed_dim,
+            num_heads,
+            gqa_groups,
+            dropout=dropout,
+            rope=rope,
+            use_relative_embeddings=use_relative_embeddings,
+            max_seq_len=max_seq_len,
+            rope_only_for_query=rope_only_for_query,
+            use_flash_attention=use_flash_attention,
+            is_causal=is_causal,
+            use_bias=use_bias,
+            num_experts=num_experts,
+        )
+    else:
+        return DeepMoeAttentionVectorized(
+            embed_dim,
+            num_heads,
+            gqa_groups,
+            dropout=dropout,
+            rope=rope,
+            use_relative_embeddings=use_relative_embeddings,
+            max_seq_len=max_seq_len,
+            rope_only_for_query=rope_only_for_query,
+            use_flash_attention=use_flash_attention,
+            is_causal=is_causal,
+            use_bias=use_bias,
+            num_experts=num_experts,
+            num_query_experts=num_query_experts,
+            num_query_groups=num_query_groups,
+        )
rxnn/experimental/models.py
CHANGED
@@ -65,7 +65,7 @@ class MoeAttentionTransformer(nn.Module, PyTorchModelHubMixin, pipeline_tag="tex
         assert ff_activation in ['relu', 'gelu',
                                  'swish', 'silu', 'linear',
                                  'sigmoid'], 'Feed-forward activation could be "relu", "gelu", "swish", "silu", "linear", "sigmoid".'
-        assert att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma'], 'Self-attention type could be "mha", "gqa", "mqa", "gma", "dma"'
+        assert att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma', 'gma_v', 'dma_v'], 'Self-attention type could be "mha", "gqa", "mqa", "gma", "dma", "gma_v", "dma_v".'
 
         embedding = nn.Embedding(vocab_size, embed_dim)
         rope = RotaryPositionalEmbedding(embed_dim // att_heads, seq_len)
rxnn/transformers/layers.py
CHANGED
@@ -132,9 +132,10 @@ class ClassicTransformerLayer(nn.Module):
 
         if use_gated:
             if use_moe:
-                self.ff = GatedMoeFeedForward(embed_dim, ff_dim, num_experts, top_k=moe_top_k,
+                self.ff = GatedMoeFeedForward(embed_dim, ff_dim, num_experts, ff_activation, top_k=moe_top_k,
+                                              dropout=ff_dropout)
             else:
-                self.ff = GatedFeedForward(embed_dim, ff_dim, dropout=ff_dropout)
+                self.ff = GatedFeedForward(embed_dim, ff_dim, ff_activation, dropout=ff_dropout)
         else:
             if use_moe:
                 self.ff = MoeFeedForward(embed_dim, ff_dim, num_experts, ff_activation, top_k=moe_top_k,
{rxnn-0.1.25.dist-info → rxnn-0.1.27.dist-info}/RECORD
CHANGED
@@ -1,7 +1,7 @@
 rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/experimental/attention.py,sha256=
-rxnn/experimental/models.py,sha256=
+rxnn/experimental/attention.py,sha256=csasMRxL4nq2dS7pc9WdS4bvCB70ZVgsR7LTHV2jEJ0,29388
+rxnn/experimental/models.py,sha256=QEuFBB9iEg5AbKQLwGJkAwPjMfaVeTqazhKDWPRkm7o,4598
 rxnn/experimental/moe.py,sha256=PhiaNr3FwR2Zv2a0tfj6sfZ4iyhLo3Jyp2DwXq19qZQ,7935
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/memory/norm.py,sha256=Ofl8Q5NYEF9GQeO0bhM43tkTW91J0y6TSvTAOYMgloM,6278
@@ -18,14 +18,14 @@ rxnn/training/tokenizer.py,sha256=4Y41f07uo2KPA_7bp3FCcwGKbXoS2hsckOoXUsXfQxY,80
 rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/transformers/attention.py,sha256=FHATZVf_kt3OHnG02zEeG9QdUXLncKDjrhyT28Pk0E4,14185
 rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
-rxnn/transformers/layers.py,sha256=
+rxnn/transformers/layers.py,sha256=n_jZTqEF_vLkF31AkB5XGErfm2sQFd9CRqJUHKRFkKI,6956
 rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
 rxnn/transformers/models.py,sha256=w-zB_8QB9-Fae-GkGgmVDNY-Ts_0gBeWcevpl9qzZVM,7169
 rxnn/transformers/moe.py,sha256=6Cffyo0QjmEWc4rK1ncOmLRCQbY0OpQJ4D7xH_4nTN4,4738
 rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
 rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
 rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
-rxnn-0.1.
-rxnn-0.1.
-rxnn-0.1.
-rxnn-0.1.
+rxnn-0.1.27.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.1.27.dist-info/METADATA,sha256=XjcqSdhjTRsCvvP-o981Ihp4k5PFRCQUSLVsPZ_NVPw,16627
+rxnn-0.1.27.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
+rxnn-0.1.27.dist-info/RECORD,,
File without changes
|
File without changes
|