rxnn 0.2.51__py3-none-any.whl → 0.2.53__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn/experimental/attention.py +103 -60
- rxnn/experimental/models.py +7 -3
- rxnn/memory/attention.py +8 -3
- rxnn/training/mrl.py +15 -1
- rxnn/training/rl.py +11 -2
- rxnn/training/utils.py +11 -0
- rxnn/transformers/attention.py +3 -1
- rxnn/transformers/layers.py +1 -1
- {rxnn-0.2.51.dist-info → rxnn-0.2.53.dist-info}/METADATA +1 -1
- {rxnn-0.2.51.dist-info → rxnn-0.2.53.dist-info}/RECORD +12 -12
- {rxnn-0.2.51.dist-info → rxnn-0.2.53.dist-info}/LICENSE +0 -0
- {rxnn-0.2.51.dist-info → rxnn-0.2.53.dist-info}/WHEEL +0 -0
rxnn/experimental/attention.py
CHANGED
@@ -318,76 +318,102 @@ class SparseQueryAttention(MultiHeadAttention):


 # Others
-… (1 line removed, not shown in this view)
 class FlexAttention(MultiHeadAttention):
     def __init__(
-… (6 lines removed, not shown in this view)
+            self,
+            embed_dim: int,
+            num_heads: int,
+            dropout: float = 0.0,
+            rope: RotaryPositionalEmbedding = None,
+            rope_only_for_query: bool = False,
+            rope_only_for_keys: bool = False,
+            use_relative_embeddings: bool = False,
+            max_seq_len: int = 1024,
+            use_flash_attention: bool = True,
+            is_causal: bool = False,
+            use_bias: bool = False,
+            num_global_tokens: int = 16,
+            window_size: int = 128,
     ):
-        super().__init__(
+        super(FlexAttention, self).__init__(
+            embed_dim,
+            num_heads,
+            dropout=dropout,
+            rope=rope,
+            rope_only_for_query=rope_only_for_query,
+            rope_only_for_keys=rope_only_for_keys,
+            use_relative_embeddings=use_relative_embeddings,
+            max_seq_len=max_seq_len,
+            use_flash_attention=use_flash_attention,
+            is_causal=is_causal,
+            use_bias=use_bias,
+        )
+        self.head_dim = embed_dim // num_heads
         self.num_global_tokens = num_global_tokens
         self.window_size = window_size
-        self.global_tokens = nn.Parameter(torch.zeros(1, num_global_tokens, embed_dim))

-… (1 line removed, not shown in this view)
+        # Learnable global tokens
+        self.global_tokens = nn.Parameter(torch.randn(1, num_global_tokens, embed_dim))
+
+
+    def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
+        b, t, d = x.size()
+        return x.view(b, t, self.num_heads, self.head_dim).transpose(1, 2)
+
+    def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
+        b, h, t, d = x.size()
+        return self._transpose_output(x, b, t, h * d)
+
+    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: torch.Tensor = None):
         b, t, d = query.size()
-        head_dim = d // self.num_heads

-… (3 lines removed, not shown in this view)
-        num_windows = (seq_len - self.num_global_tokens + self.window_size - 1) // self.window_size
+        # Prepend global tokens to the input query
+        global_tokens = self.global_tokens.expand(b, -1, -1)
+        x = torch.cat([global_tokens, query], dim=1)

         # Project Q, K, V
-… (5 lines removed, not shown in this view)
-        global_v = v[:, :, :self.num_global_tokens]
-        global_attn = self._calculate_attn_weights(global_q, global_k, d) @ global_v
-… (1 line removed, not shown in this view)
-        # Process Global-to-Local Attention
-        local_k = k[:, :, self.num_global_tokens:]  # [B, H, (num_windows * window_size), head_dim]
-        local_v = v[:, :, self.num_global_tokens:]
-        # Apply RoPE to local_k if needed
+        q = self._split_heads(self.q_proj(x))
+        k = self._split_heads(self.k_proj(key))
+        v = self._split_heads(self.v_proj(value))
+
+        # Apply RoPE
         if self.rope:
-… (32 lines removed, not shown in this view)
+            q, k = self._apply_rope(q, k, separate=True)
+
+        # Split Q into global and local parts
+        global_q = q[:, :, :self.num_global_tokens]  # (B, H, G, head_dim)
+        local_q = q[:, :, self.num_global_tokens:]  # (B, H, L, head_dim)
+        L = local_q.size(2)
+        S = k.size(2)
+
+        # Global attention: global_q attends to all K/V
+        global_attn = F.scaled_dot_product_attention(
+            global_q, k, v, attn_mask=mask if not self.is_causal else None,
+            dropout_p=self.dropout.p if self.training else 0, is_causal=self.is_causal
+        )
+
+        # Local attention: local_q attends to windowed K/V
+        # Vectorized window mask
+        indices = torch.arange(S, device=local_q.device)
+        local_pos = torch.arange(L, device=local_q.device)
+        local_window = (local_pos // self.window_size).unsqueeze(-1)  # (L, 1)
+        key_window = (indices // self.window_size).expand(L, -1)  # (L, S)
+        window_mask = (local_window == key_window).to(device=local_q.device)
+
+        local_attn = F.scaled_dot_product_attention(
+            local_q, k, v, attn_mask=window_mask if not self.is_causal else None,
+            dropout_p=self.dropout.p if self.training else 0, is_causal=self.is_causal
+        )
+
+        # Combine global and local attention outputs
+        attn = torch.cat([global_attn, local_attn], dim=2)  # (B, H, G+L, head_dim)
+
+        # Merge heads and project back
+        output = self._merge_heads(attn)  # (B, G+L, D)
+        output = self.out_proj(output)  # (B, G+L, D)
+
+        # Return only the local tokens (original query tokens)
+        return output[:, self.num_global_tokens:, :]


 class InfiniteAttention(MultiHeadAttention):
@@ -467,8 +493,10 @@ def init_experimental_attention(
         num_experts: int = None,
         num_query_experts: int = None,
         num_query_groups: int = None,
+        num_global_tokens: int = 16,
+        window_size: int = 128,
 ) -> MultiHeadAttention:
-    assert attention_type in ['gma', 'dma', 'sqa'], "Error, attention type should be one of: 'gma', 'dma', 'sqa'"
+    assert attention_type in ['gma', 'dma', 'sqa', 'flex'], "Error, attention type should be one of: 'gma', 'dma', 'sqa', 'flex'"

     if attention_type == "gma":
         return GroupedMoeAttention(
@@ -504,6 +532,21 @@ def init_experimental_attention(
             num_query_experts=num_query_experts,
             num_query_groups=num_query_groups,
         )
+    elif attention_type == "flex":
+        return FlexAttention(
+            embed_dim,
+            num_heads,
+            dropout=dropout,
+            rope=rope,
+            max_seq_len=max_seq_len,
+            rope_only_for_query=rope_only_for_query,
+            rope_only_for_keys=rope_only_for_keys,
+            use_flash_attention=use_flash_attention,
+            is_causal=is_causal,
+            use_bias=use_bias,
+            num_global_tokens=num_global_tokens,
+            window_size=window_size,
+        )
     else:
         return SparseQueryAttention(
             embed_dim,
rxnn/experimental/models.py
CHANGED
@@ -32,6 +32,8 @@ class ExperimentalAttentionTransformerConfig(TypedDict):
     att_num_experts: int
     att_num_query_experts: int
     att_num_query_groups: int
+    att_num_global_tokens: int
+    att_window_size: int


 class ExperimentalAttentionTransformer(nn.Module, PyTorchModelHubMixin, pipeline_tag="text-generation", license="apache-2.0"):
@@ -63,13 +65,15 @@ class ExperimentalAttentionTransformer(nn.Module, PyTorchModelHubMixin, pipeline
             att_num_experts: int = None,
             att_num_query_experts: int = None,
             att_num_query_groups: int = None,
+            att_num_global_tokens: int = 16,
+            att_window_size: int = 128,
             **kwargs
     ):
         super(ExperimentalAttentionTransformer, self).__init__(**kwargs)
         assert ff_activation in ['relu', 'gelu',
                                  'swish', 'silu', 'linear',
                                  'sigmoid'], 'Feed-forward activation could be "relu", "gelu", "swish", "silu", "linear", "sigmoid".'
-        assert att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma', 'sqa'], 'Self-attention type could be "mha", "gqa", "mqa", "gma", "dma", "sqa".'
+        assert att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma', 'sqa', 'flex'], 'Self-attention type could be "mha", "gqa", "mqa", "gma", "dma", "sqa", "flex".'

         embedding = nn.Embedding(vocab_size, embed_dim)
         rope = RotaryPositionalEmbedding(embed_dim // att_heads, seq_len)
@@ -84,8 +88,8 @@ class ExperimentalAttentionTransformer(nn.Module, PyTorchModelHubMixin, pipeline
         att_init = lambda: init_experimental_attention(embed_dim, att_heads, att_type, att_groups, rope=rope,
                                                        use_flash_attention=use_flash_attention, dropout=att_dropout,
                                                        max_seq_len=seq_len, is_causal=True, num_experts=att_num_experts,
-                                                       num_query_experts=att_num_query_experts,
-… (1 line removed, not shown in this view)
+                                                       num_query_experts=att_num_query_experts, num_query_groups=att_num_query_groups,
+                                                       num_global_tokens=att_num_global_tokens, window_size=att_window_size)

         use_moe_att = att_type in ['gma', 'dma']

rxnn/memory/attention.py
CHANGED
@@ -12,6 +12,7 @@ class StmMemoryAttention(nn.Module):
             per_slot_gate: bool = False,
             init_gate: float = 0.0,
             use_dynamic_gate: bool = False,
+            use_tanh_gate: bool = False,
             *args,
             **kwargs
     ):
@@ -24,6 +25,7 @@ class StmMemoryAttention(nn.Module):
         self.use_gated_residual = use_gated_residual
         self.per_slot_gate = per_slot_gate
         self.use_dynamic_gate = use_dynamic_gate
+        self.use_tanh_gate = use_tanh_gate
         if self.use_gated_residual:
             gate_shape = (self.num_layers, self.stm.stm_size, 1) if self.per_slot_gate else (self.num_layers,)
             self.gate = nn.Parameter(torch.full(gate_shape, init_gate))
@@ -37,10 +39,13 @@ class StmMemoryAttention(nn.Module):
         if self.use_dynamic_gate:
             mean_dim = -1 if self.per_slot_gate else [1, 2]
             gate_input = gate * (new_layer_stm + layer_stm).mean(dim=mean_dim, keepdim=True)
-            layer_gate = torch.sigmoid(gate_input)
+            layer_gate = torch.tanh(gate_input) if self.use_tanh_gate else torch.sigmoid(gate_input)
         else:
-            layer_gate = torch.sigmoid(gate)
-… (1 line removed, not shown in this view)
+            layer_gate = torch.tanh(gate) if self.use_tanh_gate else torch.sigmoid(gate)
+        if self.use_tanh_gate:
+            return (1 + layer_gate) * new_layer_stm + (1 - layer_gate) * layer_stm
+        else:
+            return layer_gate * new_layer_stm + (1 - layer_gate) * layer_stm

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         new_stm = torch.zeros_like(self.stm.memory)
rxnn/training/mrl.py
CHANGED
@@ -9,7 +9,7 @@ import random, os
 from ..transformers.sampler import BatchSampler
 from .callbacks import MrlTrainerCallback
 from .dataset import MrlCurriculumDataset
-from .utils import smart_concat, smart_concat_critic_states, TokenizedDict
+from .utils import smart_concat, smart_concat_critic_states, TokenizedDict, get_gradient_norms
 from .rl import RlAlgorithm
 from .reward import MrlRewardMode, MrlRewardModel
 from .models import MrlActorAction, MrlActorModel, MrlCriticModel
@@ -109,6 +109,7 @@ class MRLTrainer:
             use_ddp: bool = False,
             use_amp: bool = False,
             dtype: torch.dtype = torch.float32,
+            debug_mode: bool = False,
     ):
         """
         Trainer for Memory Reinforcement Learning (MRL) algorithm for reactive models and Attention-Based Memory System.
@@ -139,6 +140,7 @@ class MRLTrainer:
         self.shared_freeze_embeddings = config.get('freeze_embeddings', False)
         self.freeze_embeddings = self.shared_freeze_embeddings
         self.use_memory_warmup = config.get('use_memory_warmup', False)
+        self.debug_mode = debug_mode
         # Internal update epochs config
         self.shared_update_epochs = config.get('update_epochs', 10)
         self.update_epochs = self.shared_update_epochs
@@ -566,6 +568,14 @@ class MRLTrainer:
         else:
             return main_loss

+    def _log_gradients(self):
+        encoder_total, encoder_mean = get_gradient_norms(self.actor.encoder)
+        decoder_total, decoder_mean = get_gradient_norms(self.actor.decoder)
+        mem_att_total, mem_att_mean = get_gradient_norms(self.actor.memory_attention)
+        print(f"Encoder grad norm - total: {encoder_total:.4f}, mean: {encoder_mean:.4f}")
+        print(f"Decoder grad norm - total: {decoder_total:.4f}, mean: {decoder_mean:.4f}")
+        print(f"Memory attention grad norm - total: {mem_att_total:.4f}, mean: {mem_att_mean:.4f}")
+
     def update_actor(self, state: tuple[TokenizedDict, TokenizedDict, TokenizedDict], action: TokenizedDict,
                      advantages: torch.Tensor, old_log_probs: torch.Tensor, epoch: int) -> float:
         # 1. Reset actor gradients
@@ -596,6 +606,8 @@ class MRLTrainer:
             self.scaler.unscale_(self.optimizer)
             torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
                                            error_if_nonfinite=False)
+            if self.debug_mode:
+                self._log_gradients()
             # 4.5 Run scaled optimization step
             self.scaler.step(self.optimizer)
             self.scaler.update()
@@ -613,6 +625,8 @@ class MRLTrainer:
             # 4.4 Clip gradient norms
             torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
                                            error_if_nonfinite=False)
+            if self.debug_mode:
+                self._log_gradients()
             # 4.5 Run scaled optimization step
             self.optimizer.step()
             # 5. Get float loss value for callbacks/writer
rxnn/training/rl.py
CHANGED
@@ -36,7 +36,7 @@ class PPOConfig(TypedDict):


 class PPOAlgorithm(RlAlgorithm):
-    def __init__(self, config: Optional[PPOConfig] = None):
+    def __init__(self, config: Optional[PPOConfig] = None, debug_mode: bool = False):
         super(PPOAlgorithm, self).__init__()

         if config is None:
@@ -49,7 +49,8 @@ class PPOAlgorithm(RlAlgorithm):
         self.entropy_coef = config.get('entropy_coef', 0.01)
         self.use_distributed_advantage_norm = config.get('use_distributed_advantage_norm', False)
         self.clip_critic_values = config.get('clip_critic_values', True)
-        self.critic_value_clip = config.get('critic_value_clip',
+        self.critic_value_clip = config.get('critic_value_clip', 20.0)
+        self.debug_mode = debug_mode

     def critic_loss(self, values: torch.Tensor, ref_values: torch.Tensor) -> torch.Tensor:
         # Critic loss with clipped values
@@ -96,6 +97,14 @@ class PPOAlgorithm(RlAlgorithm):

         advantages = advantages.unsqueeze(-1)

+        if self.debug_mode:
+            print(
+                f"Logits stats: min={new_logits.min().item():.4f}, max={new_logits.max().item():.4f}, mean={new_logits.mean().item():.4f}")
+            print(
+                f"Ratio stats: min={ratio.min().item():.4f}, max={ratio.max().item():.4f}, mean={ratio.mean().item():.4f}")
+            print(
+                f"Advantage stats: min={advantages.min().item():.4f}, max={advantages.max().item():.4f}, mean={advantages.mean().item():.4f}")
+
         # c) Clipped surrogate loss
         surr1 = ratio * advantages
         surr2 = torch.clamp(ratio, 1.0 - self.clip_eps, 1.0 + self.clip_eps) * advantages
rxnn/training/utils.py
CHANGED
@@ -1,4 +1,5 @@
 import torch
+import torch.nn as nn
 from typing import TypedDict


@@ -142,3 +143,13 @@ def smart_concat(query: TokenizedDict, answer: TokenizedDict, max_length: int, p
         'input_ids': combined_ids,
         'attention_mask': combined_mask
     }
+
+def get_gradient_norms(model: nn.Module):
+    total_norm = 0
+    for p in model.parameters():
+        if p.grad is not None:
+            param_norm = p.grad.data.norm(2)
+            total_norm += param_norm.item() ** 2
+    total_norm = total_norm ** 0.5
+    mean_norm = total_norm / len(list(model.parameters()))
+    return total_norm, mean_norm
rxnn/transformers/attention.py
CHANGED
@@ -69,12 +69,14 @@ class MultiHeadAttention(nn.Module):
         v = self.v_proj(value).view(b, -1, self.num_heads, d // self.num_heads).transpose(1, 2)
         return q, k, v

-    def _apply_rope(self, q: torch.Tensor, k: torch.Tensor):
+    def _apply_rope(self, q: torch.Tensor, k: torch.Tensor, separate: bool = False):
         if self.rope is not None:
             if self.rope_only_for_query:
                 q = self.rope.forward_one(q)
             elif self.rope_only_for_keys:
                 k = self.rope.forward_one(k)
+            elif separate:
+                q, k = self.rope.forward_one(q), self.rope.forward_one(k)
             else:
                 q, k = self.rope(q, k)
         return q, k
rxnn/transformers/layers.py
CHANGED
@@ -110,7 +110,7 @@ class ReactiveTransformerLayer(nn.Module):
         residual = x
         if not self.use_post_norm:
             x = self.norm2(x)
-        x = self.memory_cross_attention(x, stm, stm)
+        x = self.memory_cross_attention(x, stm, stm, mask=mask)
         x = residual + x
         if self.use_post_norm:
             x = self.norm2(x)
{rxnn-0.2.51.dist-info → rxnn-0.2.53.dist-info}/RECORD
CHANGED
@@ -1,11 +1,11 @@
 rxnn/.DS_Store,sha256=BxZLo9tFs48JMq6jhumiCnCPLTeCwl619CFSg4ClRAY,6148
 rxnn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/experimental/attention.py,sha256=
-rxnn/experimental/models.py,sha256=
+rxnn/experimental/attention.py,sha256=z9zahN1KTDcVabUB0RwWsT6oK8MTuz65haVwYRHAzy4,24689
+rxnn/experimental/models.py,sha256=HPOIRpnX_oiI10wsVC4J6rzo3T6dj10aNWGYpa9S1UU,5115
 rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/memory/attention.py,sha256=
+rxnn/memory/attention.py,sha256=wnYjd3UnmzAA79-7QxpMoEk3O1qRy8LmW1JcE_Fotck,3094
 rxnn/memory/norm.py,sha256=cVjjhCLqR5K6-321SP_ObG17y-ddlcTJeCTXvW4vpk0,6675
 rxnn/memory/stm.py,sha256=jv57gsH9XW19sLbxpRDqsp1yfsii_4Ef4Ncr_ztk-i4,3937
 rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -17,23 +17,23 @@ rxnn/training/callbacks.py,sha256=rS8leuVFPVVfE5Zc8DMkUZhRIPN-vpPbUjowXE5TSBw,36
 rxnn/training/dataset.py,sha256=7hTilFWPpqUEc6zNcMqBPjxFKxCfvTKKF3E8tVlwccQ,51250
 rxnn/training/ddp.py,sha256=VsNBjn3cY-uUj8hbsW7oKvb0_ZKnXnJ2KgObm-Mr9i4,836
 rxnn/training/models.py,sha256=CS6mjD338knXmCbMZ3bCpOlA-DR3kmQUOSj5u5F6jII,9002
-rxnn/training/mrl.py,sha256=
+rxnn/training/mrl.py,sha256=185rZsaFVaAt4mYksuflzKPiDuEaHpjsc3vPzxt9ax0,61862
 rxnn/training/reward.py,sha256=uiSsBXmjMw2yv-1Bssy3RTlpU6zP8ape3490Sl-aT0M,16144
-rxnn/training/rl.py,sha256=
+rxnn/training/rl.py,sha256=iBEPC_gfydXtWkVORO3REMWvOtx60-0xB7MFzfghUK8,6825
 rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
 rxnn/training/tokenizer.py,sha256=umaLByMBx_NMrQElA45HLm9gkuzyKWDTFaKVd-CjXl0,8344
-rxnn/training/utils.py,sha256=
+rxnn/training/utils.py,sha256=C0OS2RAGQ3L7D_G3CWupu_BpAFhkovMByBKm355Ibfc,6087
 rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/transformers/attention.py,sha256=
+rxnn/transformers/attention.py,sha256=KRnKT6XUqAXElxV9y72mSpdTeiMgCKCCLqqxCFNTHmA,16372
 rxnn/transformers/ff.py,sha256=WDjO-H9XWInoWnUnxiseIH6Kx5GlHP0zGJygwhcb1gc,2589
-rxnn/transformers/layers.py,sha256=
+rxnn/transformers/layers.py,sha256=wG8C9doafpLUsGtUTg-xdrHt7EQMEdB10vcSD-O1nVg,7999
 rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
 rxnn/transformers/models.py,sha256=7ypPNFFnacdZjvaLVue1KR2PmMSdVYsbCMQSunXDL70,10720
 rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
 rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
+rxnn-0.2.53.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.53.dist-info/METADATA,sha256=CqOV8zz8qcGYyZXbeRv7jLMd40AeF8iU7vyfRsMuslE,25997
+rxnn-0.2.53.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.53.dist-info/RECORD,,
{rxnn-0.2.51.dist-info → rxnn-0.2.53.dist-info}/LICENSE
File without changes
{rxnn-0.2.51.dist-info → rxnn-0.2.53.dist-info}/WHEEL
File without changes