rxnn 0.2.51.tar.gz → 0.2.53.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. {rxnn-0.2.51 → rxnn-0.2.53}/PKG-INFO +1 -1
  2. {rxnn-0.2.51 → rxnn-0.2.53}/pyproject.toml +1 -1
  3. {rxnn-0.2.51 → rxnn-0.2.53}/src/rxnn/experimental/attention.py +103 -60
  4. {rxnn-0.2.51 → rxnn-0.2.53}/src/rxnn/experimental/models.py +7 -3
  5. {rxnn-0.2.51 → rxnn-0.2.53}/src/rxnn/memory/attention.py +8 -3
  6. {rxnn-0.2.51 → rxnn-0.2.53}/src/rxnn/training/mrl.py +15 -1
  7. {rxnn-0.2.51 → rxnn-0.2.53}/src/rxnn/training/rl.py +11 -2
  8. {rxnn-0.2.51 → rxnn-0.2.53}/src/rxnn/training/utils.py +11 -0
  9. {rxnn-0.2.51 → rxnn-0.2.53}/src/rxnn/transformers/attention.py +3 -1
  10. {rxnn-0.2.51 → rxnn-0.2.53}/src/rxnn/transformers/layers.py +1 -1
  11. {rxnn-0.2.51 → rxnn-0.2.53}/LICENSE +0 -0
  12. {rxnn-0.2.51 → rxnn-0.2.53}/README.md +0 -0
  13. {rxnn-0.2.51 → rxnn-0.2.53}/src/rxnn/.DS_Store +0 -0
  14. {rxnn-0.2.51 → rxnn-0.2.53}/src/rxnn/__init__.py +0 -0
  15. {rxnn-0.2.51 → rxnn-0.2.53}/src/rxnn/experimental/__init__.py +0 -0
  16. {rxnn-0.2.51 → rxnn-0.2.53}/src/rxnn/experimental/moe.py +0 -0
  17. {rxnn-0.2.51 → rxnn-0.2.53}/src/rxnn/memory/__init__.py +0 -0
  18. {rxnn-0.2.51 → rxnn-0.2.53}/src/rxnn/memory/norm.py +0 -0
  19. {rxnn-0.2.51 → rxnn-0.2.53}/src/rxnn/memory/stm.py +0 -0
  20. {rxnn-0.2.51 → rxnn-0.2.53}/src/rxnn/rxt/__init__.py +0 -0
  21. {rxnn-0.2.51 → rxnn-0.2.53}/src/rxnn/rxt/models.py +0 -0
  22. {rxnn-0.2.51 → rxnn-0.2.53}/src/rxnn/training/__init__.py +0 -0
  23. {rxnn-0.2.51 → rxnn-0.2.53}/src/rxnn/training/base.py +0 -0
  24. {rxnn-0.2.51 → rxnn-0.2.53}/src/rxnn/training/bml.py +0 -0
  25. {rxnn-0.2.51 → rxnn-0.2.53}/src/rxnn/training/callbacks.py +0 -0
  26. {rxnn-0.2.51 → rxnn-0.2.53}/src/rxnn/training/dataset.py +0 -0
  27. {rxnn-0.2.51 → rxnn-0.2.53}/src/rxnn/training/ddp.py +0 -0
  28. {rxnn-0.2.51 → rxnn-0.2.53}/src/rxnn/training/models.py +0 -0
  29. {rxnn-0.2.51 → rxnn-0.2.53}/src/rxnn/training/reward.py +0 -0
  30. {rxnn-0.2.51 → rxnn-0.2.53}/src/rxnn/training/scheduler.py +0 -0
  31. {rxnn-0.2.51 → rxnn-0.2.53}/src/rxnn/training/tokenizer.py +0 -0
  32. {rxnn-0.2.51 → rxnn-0.2.53}/src/rxnn/transformers/__init__.py +0 -0
  33. {rxnn-0.2.51 → rxnn-0.2.53}/src/rxnn/transformers/ff.py +0 -0
  34. {rxnn-0.2.51 → rxnn-0.2.53}/src/rxnn/transformers/mask.py +0 -0
  35. {rxnn-0.2.51 → rxnn-0.2.53}/src/rxnn/transformers/models.py +0 -0
  36. {rxnn-0.2.51 → rxnn-0.2.53}/src/rxnn/transformers/moe.py +0 -0
  37. {rxnn-0.2.51 → rxnn-0.2.53}/src/rxnn/transformers/positional.py +0 -0
  38. {rxnn-0.2.51 → rxnn-0.2.53}/src/rxnn/transformers/sampler.py +0 -0
  39. {rxnn-0.2.51 → rxnn-0.2.53}/src/rxnn/utils.py +0 -0
{rxnn-0.2.51 → rxnn-0.2.53}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.2.51
+Version: 0.2.53
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
{rxnn-0.2.51 → rxnn-0.2.53}/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
 
 [tool.poetry]
 name = "rxnn"
-version = "0.2.51"
+version = "0.2.53"
 description = "RxNN: Reactive Neural Networks Platform"
 
 license = "Apache-2.0"
{rxnn-0.2.51 → rxnn-0.2.53}/src/rxnn/experimental/attention.py
@@ -318,76 +318,102 @@ class SparseQueryAttention(MultiHeadAttention):
 
 
 # Others
-
 class FlexAttention(MultiHeadAttention):
     def __init__(
-            self,
-            embed_dim: int,
-            num_heads: int,
-            num_global_tokens: int = 16,
-            window_size: int = 128,
-            **kwargs
+            self,
+            embed_dim: int,
+            num_heads: int,
+            dropout: float = 0.0,
+            rope: RotaryPositionalEmbedding = None,
+            rope_only_for_query: bool = False,
+            rope_only_for_keys: bool = False,
+            use_relative_embeddings: bool = False,
+            max_seq_len: int = 1024,
+            use_flash_attention: bool = True,
+            is_causal: bool = False,
+            use_bias: bool = False,
+            num_global_tokens: int = 16,
+            window_size: int = 128,
     ):
-        super().__init__(embed_dim, num_heads, **kwargs)
+        super(FlexAttention, self).__init__(
+            embed_dim,
+            num_heads,
+            dropout=dropout,
+            rope=rope,
+            rope_only_for_query=rope_only_for_query,
+            rope_only_for_keys=rope_only_for_keys,
+            use_relative_embeddings=use_relative_embeddings,
+            max_seq_len=max_seq_len,
+            use_flash_attention=use_flash_attention,
+            is_causal=is_causal,
+            use_bias=use_bias,
+        )
+        self.head_dim = embed_dim // num_heads
         self.num_global_tokens = num_global_tokens
         self.window_size = window_size
-        self.global_tokens = nn.Parameter(torch.zeros(1, num_global_tokens, embed_dim))
 
-    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask=None):
+        # Learnable global tokens
+        self.global_tokens = nn.Parameter(torch.randn(1, num_global_tokens, embed_dim))
+
+
+    def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
+        b, t, d = x.size()
+        return x.view(b, t, self.num_heads, self.head_dim).transpose(1, 2)
+
+    def _merge_heads(self, x: torch.Tensor) -> torch.Tensor:
+        b, h, t, d = x.size()
+        return self._transpose_output(x, b, t, h * d)
+
+    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: torch.Tensor = None):
         b, t, d = query.size()
-        head_dim = d // self.num_heads
 
-        # Split into global and local
-        x = torch.cat([self.global_tokens.expand(b, -1, -1), query], dim=1)
-        seq_len = x.size(1)
-        num_windows = (seq_len - self.num_global_tokens + self.window_size - 1) // self.window_size
+        # Prepend global tokens to the input query
+        global_tokens = self.global_tokens.expand(b, -1, -1)
+        x = torch.cat([global_tokens, query], dim=1)
 
         # Project Q, K, V
-        q, k, v = self._forward_qkv(x, key, value, b, seq_len, d)
-
-        # Process Global-to-Global Attention
-        global_q = q[:, :, :self.num_global_tokens]  # [B, H, G, head_dim]
-        global_k = k[:, :, :self.num_global_tokens]
-        global_v = v[:, :, :self.num_global_tokens]
-        global_attn = self._calculate_attn_weights(global_q, global_k, d) @ global_v
-
-        # Process Global-to-Local Attention
-        local_k = k[:, :, self.num_global_tokens:]  # [B, H, (num_windows * window_size), head_dim]
-        local_v = v[:, :, self.num_global_tokens:]
-        # Apply RoPE to local_k if needed
+        q = self._split_heads(self.q_proj(x))
+        k = self._split_heads(self.k_proj(key))
+        v = self._split_heads(self.v_proj(value))
+
+        # Apply RoPE
         if self.rope:
-            # Compute frequencies for entire local sequence
-            local_k = self.rope.forward_one(local_k)
-
-        global_local_attn = self._calculate_attn_weights(global_q, local_k, d) @ local_v
-
-        # Process Local-to-Local Attention (per window)
-        local_q = q[:, :, self.num_global_tokens:]  # [B, H, (num_windows * window_size), head_dim]
-        local_q = local_q.view(b, self.num_heads, num_windows, self.window_size, head_dim)
-        local_k = local_k.view(b, self.num_heads, num_windows, self.window_size, head_dim)
-        local_v = local_v.view(b, self.num_heads, num_windows, self.window_size, head_dim)
-
-        local_attn = []
-        for i in range(num_windows):
-            window_q = local_q[:, :, i]  # [B, H, window_size, head_dim]
-            window_k = local_k[:, :, i]
-            window_v = local_v[:, :, i]
-
-            # Apply RoPE to window_q and window_k
-            if self.rope:
-                # Compute frequencies for this window
-                window_q, window_k = self.rope(window_q, window_k)
-
-            # Calculate attention for this window
-            attn = self._calculate_attn_weights(window_q, window_k, d)
-            attn_i = torch.einsum('bhij, bhjd -> bhid', attn, window_v)
-            local_attn.append(attn_i)
-        local_attn = torch.cat(local_attn, dim=2).view(b, self.num_heads, -1, head_dim)
-
-        # Combine all attention outputs
-        combined_attn = torch.cat([global_attn, global_local_attn, local_attn], dim=2)
-        output = self._calculate_output(combined_attn, v, b, t, d)
-        return self.out_proj(output)
+            q, k = self._apply_rope(q, k, separate=True)
+
+        # Split Q into global and local parts
+        global_q = q[:, :, :self.num_global_tokens]  # (B, H, G, head_dim)
+        local_q = q[:, :, self.num_global_tokens:]  # (B, H, L, head_dim)
+        L = local_q.size(2)
+        S = k.size(2)
+
+        # Global attention: global_q attends to all K/V
+        global_attn = F.scaled_dot_product_attention(
+            global_q, k, v, attn_mask=mask if not self.is_causal else None,
+            dropout_p=self.dropout.p if self.training else 0, is_causal=self.is_causal
+        )
+
+        # Local attention: local_q attends to windowed K/V
+        # Vectorized window mask
+        indices = torch.arange(S, device=local_q.device)
+        local_pos = torch.arange(L, device=local_q.device)
+        local_window = (local_pos // self.window_size).unsqueeze(-1)  # (L, 1)
+        key_window = (indices // self.window_size).expand(L, -1)  # (L, S)
+        window_mask = (local_window == key_window).to(device=local_q.device)
+
+        local_attn = F.scaled_dot_product_attention(
+            local_q, k, v, attn_mask=window_mask if not self.is_causal else None,
+            dropout_p=self.dropout.p if self.training else 0, is_causal=self.is_causal
+        )
+
+        # Combine global and local attention outputs
+        attn = torch.cat([global_attn, local_attn], dim=2)  # (B, H, G+L, head_dim)
+
+        # Merge heads and project back
+        output = self._merge_heads(attn)  # (B, G+L, D)
+        output = self.out_proj(output)  # (B, G+L, D)
+
+        # Return only the local tokens (original query tokens)
+        return output[:, self.num_global_tokens:, :]
 
 
 class InfiniteAttention(MultiHeadAttention):
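The rewritten forward pass drops the per-window Python loop: global queries attend to the full key/value sequence, while local queries are restricted to their own fixed-size window through a boolean mask passed to F.scaled_dot_product_attention. A minimal standalone sketch of that window mask, with toy sizes not taken from the package:

import torch

window_size = 4
L, S = 8, 8  # local query positions and key positions

local_pos = torch.arange(L)
key_pos = torch.arange(S)

# True where the query position and the key position fall into the same window
window_mask = (local_pos // window_size).unsqueeze(-1) == (key_pos // window_size)  # (L, S), bool

print(window_mask.int())
# Rows 0-3 attend only to columns 0-3, rows 4-7 only to columns 4-7,
# mirroring the attn_mask used for the local attention call above.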
@@ -467,8 +493,10 @@ def init_experimental_attention(
         num_experts: int = None,
         num_query_experts: int = None,
         num_query_groups: int = None,
+        num_global_tokens: int = 16,
+        window_size: int = 128,
 ) -> MultiHeadAttention:
-    assert attention_type in ['gma', 'dma', 'sqa'], "Error, attention type should be one of: 'gma', 'dma', 'sqa'"
+    assert attention_type in ['gma', 'dma', 'sqa', 'flex'], "Error, attention type should be one of: 'gma', 'dma', 'sqa', 'flex'"
 
     if attention_type == "gma":
         return GroupedMoeAttention(
@@ -504,6 +532,21 @@
             num_query_experts=num_query_experts,
             num_query_groups=num_query_groups,
         )
+    elif attention_type == "flex":
+        return FlexAttention(
+            embed_dim,
+            num_heads,
+            dropout=dropout,
+            rope=rope,
+            max_seq_len=max_seq_len,
+            rope_only_for_query=rope_only_for_query,
+            rope_only_for_keys=rope_only_for_keys,
+            use_flash_attention=use_flash_attention,
+            is_causal=is_causal,
+            use_bias=use_bias,
+            num_global_tokens=num_global_tokens,
+            window_size=window_size,
+        )
     else:
         return SparseQueryAttention(
            embed_dim,
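A hedged usage sketch of the extended factory: the positional argument order mirrors the call made in models.py below, the hyperparameter values are illustrative, and the fourth positional argument is the group count consumed by the other variants.

import torch
from rxnn.experimental.attention import init_experimental_attention
from rxnn.transformers.positional import RotaryPositionalEmbedding

embed_dim, num_heads, seq_len = 256, 8, 1024
rope = RotaryPositionalEmbedding(embed_dim // num_heads, seq_len)

attention = init_experimental_attention(
    embed_dim, num_heads, 'flex', 4,
    rope=rope, dropout=0.1, max_seq_len=seq_len, is_causal=True,
    num_global_tokens=16, window_size=128,
)

x = torch.randn(2, seq_len, embed_dim)
out = attention(x, x, x)  # FlexAttention returns only the original query positions: (2, seq_len, embed_dim)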
{rxnn-0.2.51 → rxnn-0.2.53}/src/rxnn/experimental/models.py
@@ -32,6 +32,8 @@ class ExperimentalAttentionTransformerConfig(TypedDict):
     att_num_experts: int
     att_num_query_experts: int
     att_num_query_groups: int
+    att_num_global_tokens: int
+    att_window_size: int
 
 
 class ExperimentalAttentionTransformer(nn.Module, PyTorchModelHubMixin, pipeline_tag="text-generation", license="apache-2.0"):
@@ -63,13 +65,15 @@ class ExperimentalAttentionTransformer(nn.Module, PyTorchModelHubMixin, pipeline
             att_num_experts: int = None,
             att_num_query_experts: int = None,
             att_num_query_groups: int = None,
+            att_num_global_tokens: int = 16,
+            att_window_size: int = 128,
             **kwargs
     ):
         super(ExperimentalAttentionTransformer, self).__init__(**kwargs)
         assert ff_activation in ['relu', 'gelu',
                                  'swish', 'silu', 'linear',
                                  'sigmoid'], 'Feed-forward activation could be "relu", "gelu", "swish", "silu", "linear", "sigmoid".'
-        assert att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma', 'sqa'], 'Self-attention type could be "mha", "gqa", "mqa", "gma", "dma", "sqa".'
+        assert att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma', 'sqa', 'flex'], 'Self-attention type could be "mha", "gqa", "mqa", "gma", "dma", "sqa", "flex".'
 
         embedding = nn.Embedding(vocab_size, embed_dim)
         rope = RotaryPositionalEmbedding(embed_dim // att_heads, seq_len)
@@ -84,8 +88,8 @@ class ExperimentalAttentionTransformer(nn.Module, PyTorchModelHubMixin, pipeline
            att_init = lambda: init_experimental_attention(embed_dim, att_heads, att_type, att_groups, rope=rope,
                                                           use_flash_attention=use_flash_attention, dropout=att_dropout,
                                                           max_seq_len=seq_len, is_causal=True, num_experts=att_num_experts,
-                                                          num_query_experts=att_num_query_experts,
-                                                          num_query_groups=att_num_query_groups)
+                                                          num_query_experts=att_num_query_experts, num_query_groups=att_num_query_groups,
+                                                          num_global_tokens=att_num_global_tokens, window_size=att_window_size)
 
        use_moe_att = att_type in ['gma', 'dma']
 
{rxnn-0.2.51 → rxnn-0.2.53}/src/rxnn/memory/attention.py
@@ -12,6 +12,7 @@ class StmMemoryAttention(nn.Module):
             per_slot_gate: bool = False,
             init_gate: float = 0.0,
             use_dynamic_gate: bool = False,
+            use_tanh_gate: bool = False,
             *args,
             **kwargs
     ):
@@ -24,6 +25,7 @@ class StmMemoryAttention(nn.Module):
         self.use_gated_residual = use_gated_residual
         self.per_slot_gate = per_slot_gate
         self.use_dynamic_gate = use_dynamic_gate
+        self.use_tanh_gate = use_tanh_gate
         if self.use_gated_residual:
             gate_shape = (self.num_layers, self.stm.stm_size, 1) if self.per_slot_gate else (self.num_layers,)
             self.gate = nn.Parameter(torch.full(gate_shape, init_gate))
@@ -37,10 +39,13 @@ class StmMemoryAttention(nn.Module):
         if self.use_dynamic_gate:
             mean_dim = -1 if self.per_slot_gate else [1, 2]
             gate_input = gate * (new_layer_stm + layer_stm).mean(dim=mean_dim, keepdim=True)
-            layer_gate = torch.sigmoid(gate_input)
+            layer_gate = torch.tanh(gate_input) if self.use_tanh_gate else torch.sigmoid(gate_input)
         else:
-            layer_gate = torch.sigmoid(gate)
-        return layer_gate * new_layer_stm + (1 - layer_gate) * layer_stm
+            layer_gate = torch.tanh(gate) if self.use_tanh_gate else torch.sigmoid(gate)
+        if self.use_tanh_gate:
+            return (1 + layer_gate) * new_layer_stm + (1 - layer_gate) * layer_stm
+        else:
+            return layer_gate * new_layer_stm + (1 - layer_gate) * layer_stm
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         new_stm = torch.zeros_like(self.stm.memory)
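For reference, a small sketch contrasting the two gated-residual variants (illustrative shapes, formulas taken from the diff): the sigmoid gate interpolates between the new and previous short-term memory, while use_tanh_gate=True weights them as (1 ± tanh), so with the zero-initialized gate the update reduces to a plain residual sum instead of an average.

import torch

gate = torch.tensor(0.0)          # init_gate default
new_layer_stm = torch.randn(4, 8)
layer_stm = torch.randn(4, 8)

# Default sigmoid gate: convex interpolation (0.5 / 0.5 at gate=0)
g = torch.sigmoid(gate)
out_sigmoid = g * new_layer_stm + (1 - g) * layer_stm

# use_tanh_gate=True: (1 + t) * new + (1 - t) * old; at gate=0 this is new + old
t = torch.tanh(gate)
out_tanh = (1 + t) * new_layer_stm + (1 - t) * layer_stm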
{rxnn-0.2.51 → rxnn-0.2.53}/src/rxnn/training/mrl.py
@@ -9,7 +9,7 @@ import random, os
 from ..transformers.sampler import BatchSampler
 from .callbacks import MrlTrainerCallback
 from .dataset import MrlCurriculumDataset
-from .utils import smart_concat, smart_concat_critic_states, TokenizedDict
+from .utils import smart_concat, smart_concat_critic_states, TokenizedDict, get_gradient_norms
 from .rl import RlAlgorithm
 from .reward import MrlRewardMode, MrlRewardModel
 from .models import MrlActorAction, MrlActorModel, MrlCriticModel
@@ -109,6 +109,7 @@ class MRLTrainer:
             use_ddp: bool = False,
             use_amp: bool = False,
             dtype: torch.dtype = torch.float32,
+            debug_mode: bool = False,
     ):
         """
         Trainer for Memory Reinforcement Learning (MRL) algorithm for reactive models and Attention-Based Memory System.
@@ -139,6 +140,7 @@ class MRLTrainer:
         self.shared_freeze_embeddings = config.get('freeze_embeddings', False)
         self.freeze_embeddings = self.shared_freeze_embeddings
         self.use_memory_warmup = config.get('use_memory_warmup', False)
+        self.debug_mode = debug_mode
         # Internal update epochs config
         self.shared_update_epochs = config.get('update_epochs', 10)
         self.update_epochs = self.shared_update_epochs
@@ -566,6 +568,14 @@ class MRLTrainer:
         else:
             return main_loss
 
+    def _log_gradients(self):
+        encoder_total, encoder_mean = get_gradient_norms(self.actor.encoder)
+        decoder_total, decoder_mean = get_gradient_norms(self.actor.decoder)
+        mem_att_total, mem_att_mean = get_gradient_norms(self.actor.memory_attention)
+        print(f"Encoder grad norm - total: {encoder_total:.4f}, mean: {encoder_mean:.4f}")
+        print(f"Decoder grad norm - total: {decoder_total:.4f}, mean: {decoder_mean:.4f}")
+        print(f"Memory attention grad norm - total: {mem_att_total:.4f}, mean: {mem_att_mean:.4f}")
+
     def update_actor(self, state: tuple[TokenizedDict, TokenizedDict, TokenizedDict], action: TokenizedDict,
                      advantages: torch.Tensor, old_log_probs: torch.Tensor, epoch: int) -> float:
         # 1. Reset actor gradients
@@ -596,6 +606,8 @@
                 self.scaler.unscale_(self.optimizer)
                 torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
                                                error_if_nonfinite=False)
+                if self.debug_mode:
+                    self._log_gradients()
                 # 4.5 Run scaled optimization step
                 self.scaler.step(self.optimizer)
                 self.scaler.update()
@@ -613,6 +625,8 @@
                 # 4.4 Clip gradient norms
                 torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
                                                error_if_nonfinite=False)
+                if self.debug_mode:
+                    self._log_gradients()
                 # 4.5 Run scaled optimization step
                 self.optimizer.step()
                 # 5. Get float loss value for callbacks/writer
{rxnn-0.2.51 → rxnn-0.2.53}/src/rxnn/training/rl.py
@@ -36,7 +36,7 @@ class PPOConfig(TypedDict):
 
 
 class PPOAlgorithm(RlAlgorithm):
-    def __init__(self, config: Optional[PPOConfig] = None):
+    def __init__(self, config: Optional[PPOConfig] = None, debug_mode: bool = False):
        super(PPOAlgorithm, self).__init__()
 
        if config is None:
@@ -49,7 +49,8 @@ class PPOAlgorithm(RlAlgorithm):
         self.entropy_coef = config.get('entropy_coef', 0.01)
         self.use_distributed_advantage_norm = config.get('use_distributed_advantage_norm', False)
         self.clip_critic_values = config.get('clip_critic_values', True)
-        self.critic_value_clip = config.get('critic_value_clip', 10.0)
+        self.critic_value_clip = config.get('critic_value_clip', 20.0)
+        self.debug_mode = debug_mode
 
     def critic_loss(self, values: torch.Tensor, ref_values: torch.Tensor) -> torch.Tensor:
         # Critic loss with clipped values
@@ -96,6 +97,14 @@
 
         advantages = advantages.unsqueeze(-1)
 
+        if self.debug_mode:
+            print(
+                f"Logits stats: min={new_logits.min().item():.4f}, max={new_logits.max().item():.4f}, mean={new_logits.mean().item():.4f}")
+            print(
+                f"Ratio stats: min={ratio.min().item():.4f}, max={ratio.max().item():.4f}, mean={ratio.mean().item():.4f}")
+            print(
+                f"Advantage stats: min={advantages.min().item():.4f}, max={advantages.max().item():.4f}, mean={advantages.mean().item():.4f}")
+
         # c) Clipped surrogate loss
         surr1 = ratio * advantages
         surr2 = torch.clamp(ratio, 1.0 - self.clip_eps, 1.0 + self.clip_eps) * advantages
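A hedged sketch of enabling the new PPO diagnostics; only config keys visible in this diff are set, everything else falls back to its defaults.

from rxnn.training.rl import PPOAlgorithm

# debug_mode=True prints logits/ratio/advantage statistics on each policy update;
# critic_value_clip now defaults to 20.0 but can still be overridden explicitly.
algorithm = PPOAlgorithm(config={'critic_value_clip': 20.0}, debug_mode=True)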
{rxnn-0.2.51 → rxnn-0.2.53}/src/rxnn/training/utils.py
@@ -1,4 +1,5 @@
 import torch
+import torch.nn as nn
 from typing import TypedDict
 
 
@@ -142,3 +143,13 @@ def smart_concat(query: TokenizedDict, answer: TokenizedDict, max_length: int, p
         'input_ids': combined_ids,
         'attention_mask': combined_mask
     }
+
+def get_gradient_norms(model: nn.Module):
+    total_norm = 0
+    for p in model.parameters():
+        if p.grad is not None:
+            param_norm = p.grad.data.norm(2)
+            total_norm += param_norm.item() ** 2
+    total_norm = total_norm ** 0.5
+    mean_norm = total_norm / len(list(model.parameters()))
+    return total_norm, mean_norm
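A hedged usage sketch of the new helper, which is what MRLTrainer._log_gradients calls per sub-module: after backward(), it returns the total L2 gradient norm and that total divided by the number of parameters.

import torch
import torch.nn as nn
from rxnn.training.utils import get_gradient_norms

model = nn.Linear(16, 4)
loss = model(torch.randn(2, 16)).sum()
loss.backward()

total, mean = get_gradient_norms(model)
print(f"grad norm - total: {total:.4f}, mean: {mean:.4f}")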
{rxnn-0.2.51 → rxnn-0.2.53}/src/rxnn/transformers/attention.py
@@ -69,12 +69,14 @@ class MultiHeadAttention(nn.Module):
         v = self.v_proj(value).view(b, -1, self.num_heads, d // self.num_heads).transpose(1, 2)
         return q, k, v
 
-    def _apply_rope(self, q: torch.Tensor, k: torch.Tensor):
+    def _apply_rope(self, q: torch.Tensor, k: torch.Tensor, separate: bool = False):
         if self.rope is not None:
             if self.rope_only_for_query:
                 q = self.rope.forward_one(q)
             elif self.rope_only_for_keys:
                 k = self.rope.forward_one(k)
+            elif separate:
+                q, k = self.rope.forward_one(q), self.rope.forward_one(k)
             else:
                 q, k = self.rope(q, k)
         return q, k
{rxnn-0.2.51 → rxnn-0.2.53}/src/rxnn/transformers/layers.py
@@ -110,7 +110,7 @@ class ReactiveTransformerLayer(nn.Module):
         residual = x
         if not self.use_post_norm:
             x = self.norm2(x)
-        x = self.memory_cross_attention(x, stm, stm)
+        x = self.memory_cross_attention(x, stm, stm, mask=mask)
         x = residual + x
         if self.use_post_norm:
             x = self.norm2(x)