rxnn 0.2.53__py3-none-any.whl → 0.2.54__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/experimental/attention.py
CHANGED
@@ -367,7 +367,7 @@ class FlexAttention(MultiHeadAttention):
     def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: torch.Tensor = None):
         b, t, d = query.size()
 
-        # Prepend global tokens
+        # Prepend global tokens
         global_tokens = self.global_tokens.expand(b, -1, -1)
         x = torch.cat([global_tokens, query], dim=1)
 
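The context around this hunk shows how FlexAttention prepends its learned global tokens to the query sequence before attention; the second hunk below strips them again before returning. A minimal sketch of that shape bookkeeping, with assumed sizes and written outside the class purely for illustration:

import torch

# Assumed sizes for illustration only.
batch, seq_len, dim, num_global_tokens = 2, 32, 64, 4

# Learned global tokens, stored as (1, G, D) so they can be expanded per batch.
global_tokens = torch.nn.Parameter(torch.randn(1, num_global_tokens, dim))

query = torch.randn(batch, seq_len, dim)
x = torch.cat([global_tokens.expand(batch, -1, -1), query], dim=1)
print(x.shape)       # torch.Size([2, 36, 64]) -- G + T tokens go through attention

# After attention, only the original query positions are kept (see the
# return statement at the end of the next hunk).
output = x[:, num_global_tokens:, :]
print(output.shape)  # torch.Size([2, 32, 64])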
@@ -380,39 +380,42 @@ class FlexAttention(MultiHeadAttention):
         if self.rope:
             q, k = self._apply_rope(q, k, separate=True)
 
-        # Split Q into global and local
-        global_q = q[:, :, :self.num_global_tokens]  # (B, H, G,
-        local_q = q[:, :, self.num_global_tokens:]  # (B, H, L,
+        # Split Q into global and local
+        global_q = q[:, :, :self.num_global_tokens]  # (B, H, G, D)
+        local_q = q[:, :, self.num_global_tokens:]  # (B, H, L, D)
+
+        # Global attention
+        global_attn = F.scaled_dot_product_attention(global_q, k, v, attn_mask=mask, dropout_p=self.dropout.p if self.training else 0, is_causal=self.is_causal)
+
+        # Local attention with windowed slicing (no large masks)
         L = local_q.size(2)
         S = k.size(2)
+        windowed_attn = []
 
-
-
-
-
-        )
+        for i in range(0, L, self.window_size):
+            start = i
+            end = min(i + self.window_size, L)
+            window_q = local_q[:, :, start:end]  # (B, H, W, D)
 
-
-
-
-
-
-
-
-
-
-
-            dropout_p=self.dropout.p if self.training else 0, is_causal=self.is_causal
-        )
+            # Use only relevant keys/values (same window)
+            k_window_start = max(0, start - self.window_size)
+            k_window_end = min(S, end + self.window_size)
+            window_k = k[:, :, k_window_start:k_window_end]
+            window_v = v[:, :, k_window_start:k_window_end]
+
+
+            window_attn = F.scaled_dot_product_attention(window_q, window_k, window_v, attn_mask=None, dropout_p=self.dropout.p if self.training else 0, is_causal=self.is_causal)
+
+            windowed_attn.append(window_attn)
 
-        #
-
+        # Concat local attention
+        local_attn = torch.cat(windowed_attn, dim=2)
 
-        #
-
-        output = self.
+        # Combine global and local
+        attn = torch.cat([global_attn, local_attn], dim=2)
+        output = self._merge_heads(attn)
+        output = self.out_proj(output)
 
-        # Return only the local tokens (original query tokens)
         return output[:, self.num_global_tokens:, :]
 
 
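Taken together, the new forward pass computes full attention for the handful of global tokens and banded, windowed attention for the remaining local queries, so no full (L x S) attention mask is ever materialized. Below is a minimal, self-contained sketch of that pattern; the standalone helper name, signature, and test sizes are assumptions for illustration and omit the dropout, masking, and head-projection details of the actual FlexAttention module.

import torch
import torch.nn.functional as F

def windowed_flex_attention(q, k, v, num_global_tokens, window_size):
    # q, k, v: (B, H, T, D); the first num_global_tokens query positions
    # are the prepended global tokens. Illustrative helper only, not the
    # FlexAttention API.
    global_q = q[:, :, :num_global_tokens]   # (B, H, G, D)
    local_q = q[:, :, num_global_tokens:]    # (B, H, L, D)

    # Global queries attend to the whole key/value sequence.
    global_attn = F.scaled_dot_product_attention(global_q, k, v)

    L, S = local_q.size(2), k.size(2)
    windowed_attn = []
    for start in range(0, L, window_size):
        end = min(start + window_size, L)
        window_q = local_q[:, :, start:end]  # (B, H, W, D)
        # Each query window only sees a band of neighbouring keys/values,
        # so every per-window attention call stays small.
        k_start = max(0, start - window_size)
        k_end = min(S, end + window_size)
        windowed_attn.append(
            F.scaled_dot_product_attention(
                window_q, k[:, :, k_start:k_end], v[:, :, k_start:k_end]
            )
        )
    local_attn = torch.cat(windowed_attn, dim=2)  # (B, H, L, D)

    # Global and local outputs are concatenated back into one sequence.
    return torch.cat([global_attn, local_attn], dim=2)

# Smoke test with made-up sizes.
B, H, G, T, D, W = 2, 4, 4, 32, 16, 8
q = torch.randn(B, H, G + T, D)
k = torch.randn(B, H, G + T, D)
v = torch.randn(B, H, G + T, D)
print(windowed_flex_attention(q, k, v, num_global_tokens=G, window_size=W).shape)
# torch.Size([2, 4, 36, 16])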
rxnn-0.2.53.dist-info/RECORD → rxnn-0.2.54.dist-info/RECORD
CHANGED
@@ -1,7 +1,7 @@
 rxnn/.DS_Store,sha256=BxZLo9tFs48JMq6jhumiCnCPLTeCwl619CFSg4ClRAY,6148
 rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/experimental/attention.py,sha256=
+rxnn/experimental/attention.py,sha256=zQM6Og62IZVGogsDBReYrHSiRZmDaebl1FcH2e6sHyY,24589
 rxnn/experimental/models.py,sha256=HPOIRpnX_oiI10wsVC4J6rzo3T6dj10aNWGYpa9S1UU,5115
 rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -33,7 +33,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
 rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
+rxnn-0.2.54.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.54.dist-info/METADATA,sha256=t6l1VezLNpdpgaXaqB-YhrfAhEUlWZm9-wwzBZ_Xk34,25997
+rxnn-0.2.54.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.54.dist-info/RECORD,,