rxnn-0.2.71-py3-none-any.whl → rxnn-0.2.73-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/memory/attention.py CHANGED
@@ -2,17 +2,15 @@ import torch
 import torch.nn as nn
 from .stm import ShortTermMemory
 
+
 class StmMemoryAttention(nn.Module):
     def __init__(
         self,
         stm: ShortTermMemory,
         attention_layers: nn.ModuleList,
         memory_norm_layers: nn.ModuleList,
-        use_gated_residual: bool = False,
-        per_slot_gate: bool = False,
-        init_gate: float = 0.0,
-        use_dynamic_gate: bool = False,
-        use_tanh_gate: bool = False,
+        memory_input_norm_layers: nn.ModuleList,
+        residual_gate_layers: nn.ModuleList,
         debug_mode: bool = False,
         debug_interval: int = 10,
         *args,
@@ -22,16 +20,12 @@ class StmMemoryAttention(nn.Module):
         self.stm = stm
         self.attention_layers = attention_layers
         self.memory_norm_layers = memory_norm_layers
-        assert len(self.attention_layers) == len(self.memory_norm_layers) == self.stm.memory.size(0)
+        self.memory_input_norm_layers = memory_input_norm_layers
+        self.residual_gate_layers = residual_gate_layers
+        assert (len(self.attention_layers) == len(self.memory_norm_layers) ==
+                len(self.residual_gate_layers) == len(self.memory_input_norm_layers) ==
+                self.stm.memory.size(0))
         self.num_layers = len(attention_layers)
-        self.use_gated_residual = use_gated_residual
-        self.per_slot_gate = per_slot_gate
-        self.use_dynamic_gate = use_dynamic_gate
-        self.use_tanh_gate = use_tanh_gate
-        if self.use_gated_residual:
-            gate_shape = (self.num_layers, self.stm.stm_size, 1) if self.per_slot_gate else (self.num_layers,)
-            self.gate = nn.Parameter(torch.full(gate_shape, init_gate))
-
         self.debug_mode = debug_mode
         self.debug_interval = debug_interval
         self.debug_step = 0
@@ -41,32 +35,27 @@ class StmMemoryAttention(nn.Module):
             if self.attention_layers[i].rope is not None:
                 self.attention_layers[i].rope.update_max_len(max_seq_len)
 
-    def _residual_gate(self, gate: torch.Tensor, layer_stm: torch.Tensor, new_layer_stm: torch.Tensor) -> torch.Tensor:
-        if self.use_dynamic_gate:
-            mean_dim = -1 if self.per_slot_gate else [1, 2]
-            gate_input = gate * (new_layer_stm + layer_stm).mean(dim=mean_dim, keepdim=True)
-            layer_gate = torch.tanh(gate_input) if self.use_tanh_gate else torch.sigmoid(gate_input)
-        else:
-            layer_gate = torch.tanh(gate) if self.use_tanh_gate else torch.sigmoid(gate)
-        if self.use_tanh_gate:
-            return (1 + layer_gate) * new_layer_stm + (1 - layer_gate) * layer_stm
-        else:
-            return layer_gate * new_layer_stm + (1 - layer_gate) * layer_stm
-
     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
+        # 1. Process correct attention mask
         if attention_mask is not None:
             attention_mask = attention_mask.unsqueeze(1).unsqueeze(1).bool()
+        # 2. Init new empty STM
         new_stm = torch.zeros_like(self.stm.memory)
+
+        # 3. Run Short-Term Memory update for all layers
         for i in range(self.num_layers):
+            # 4. Get current layer STM value
             layer_stm = self.stm(i)
-            # expand layer STM to batch size, if it's not in batch mode
+            # 5. Expand layer STM to batch size, if it's not in batch mode
             if layer_stm.size(0) == 1:
                 layer_stm = layer_stm.expand(x.size(0), -1, -1)
-            encoded_layer_data = x[i]
+
+            # 6. Get encoded layer data and normalize it
+            encoded_layer_data = self.memory_input_norm_layers[i](x[i])
+            # 7. Normalize STM layer
             normalized_layer_stm = self.memory_norm_layers[i](layer_stm)
-            if torch.isnan(normalized_layer_stm).any():
-                print(f"NaN detected in {i} layer memory norm output")
 
+            # 8. Print normalization stats in debug mode
             if self.debug_mode and self.training:
                 if self.debug_step != 0 and self.debug_step % self.debug_interval == 0:
                     self.debug_step = 0
@@ -74,16 +63,88 @@ class StmMemoryAttention(nn.Module):
                 else:
                     self.debug_step += 1
 
-            if torch.isnan(encoded_layer_data).any():
-                print(f"NaN detected in {i} layer encoded data input")
-
+            # 9. Calculate memory attention
             new_layer_stm = self.attention_layers[i](normalized_layer_stm, encoded_layer_data, encoded_layer_data, mask=attention_mask)
-            if torch.isnan(new_layer_stm).any():
-                print(f"NaN detected in {i} layer memory attention output")
+            # 10. Combine new updated layer state with current STM state in residual gate
+            new_stm[i] = self.residual_gate_layers[i](layer_stm, new_layer_stm)  # residual
+        # 11. Update all layers/models
+        self.stm.update_all(new_stm)
+        return self.stm.memory
+
+
+class InterlayerStmMemoryAttention(StmMemoryAttention):
+    def __init__(
+        self,
+        stm: ShortTermMemory,
+        attention_layers: nn.ModuleList,
+        memory_norm_layers: nn.ModuleList,
+        memory_input_norm_layers: nn.ModuleList,
+        residual_gate_layers: nn.ModuleList,
+        mean_attention_layers: nn.ModuleList,
+        mean_memory_norm_layers: nn.ModuleList,
+        mean_residual_gate_layers: nn.ModuleList,
+        mean_stm_norm: nn.Module,
+        debug_mode: bool = False,
+        debug_interval: int = 10,
+        **kwargs
+    ):
+        super(InterlayerStmMemoryAttention, self).__init__(
+            stm, attention_layers, memory_norm_layers, memory_input_norm_layers, residual_gate_layers,
+            debug_mode=debug_mode, debug_interval=debug_interval, **kwargs
+        )
+        self.mean_attention_layers = mean_attention_layers
+        self.mean_memory_norm_layers = mean_memory_norm_layers
+        self.mean_stm_norm = mean_stm_norm
+        self.mean_residual_gate_layers = mean_residual_gate_layers
+        assert (len(self.mean_attention_layers) == len(self.mean_memory_norm_layers) ==
+                len(self.mean_residual_gate_layers) == self.num_layers)
 
-            if self.use_gated_residual:
-                new_stm[i] = self._residual_gate(self.gate[i], layer_stm, new_layer_stm)  # gated residual
-            else:
-                new_stm[i] = new_layer_stm + layer_stm  # residual
+    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
+        # 1. Process correct attention mask
+        if attention_mask is not None:
+            attention_mask = attention_mask.unsqueeze(1).unsqueeze(1).bool()
+        # 2. Init new empty STM
+        new_stm = torch.zeros_like(self.stm.memory)
+
+        # 3. Get mean STM value from layers for mean interlayer memory attention
+        mean_stm = self.stm.memory.mean(dim=0)  # [batch_size, stm_size, embed_dim]
+        # 4. Normalize mean STM layer
+        normalized_mean_stm = self.mean_stm_norm(mean_stm)
+
+        # 5. Run Short-Term Memory update for all layers
+        for i in range(self.num_layers):
+            # 6. Get current layer STM value
+            layer_stm = self.stm(i)
+            # 7. Expand layer STM to batch size, if it's not in batch mode
+            if layer_stm.size(0) == 1:
+                layer_stm = layer_stm.expand(x.size(0), -1, -1)
+
+            # 8. Mean interlayer memory attention
+            # a) normalize STM layer value
+            pre_normalized_layer_stm = self.mean_memory_norm_layers[i](layer_stm)
+            # b) calculate attention between STM layer and mean value of all STM layers (from previous interaction)
+            interlayer_stm = self.mean_attention_layers[i](pre_normalized_layer_stm, normalized_mean_stm, normalized_mean_stm, mask=None)
+            # c) combine updated interlayer state with current STM state in residual gate
+            updated_layer_stm = self.mean_residual_gate_layers[i](layer_stm, interlayer_stm)
+
+            # 9. Main memory attention
+            # a) get encoded data for current layer and normalize it
+            encoded_layer_data = self.memory_input_norm_layers[i](x[i])
+            # b) normalize STM layer value
+            normalized_layer_stm = self.memory_norm_layers[i](updated_layer_stm)
+            # c) print normalized STM stats in debug mode
+            if self.debug_mode and self.training:
+                if self.debug_step != 0 and self.debug_step % self.debug_interval == 0:
+                    self.debug_step = 0
+                    print(
+                        f"Pre-Normalized STM stats - mean: {pre_normalized_layer_stm.mean().item():.4f}, std: {pre_normalized_layer_stm.std().item():.4f}")
+                    print(f"Normalized STM stats - mean: {normalized_layer_stm.mean().item():.4f}, std: {normalized_layer_stm.std().item():.4f}")
+                else:
+                    self.debug_step += 1
+            # d) calculate memory attention between STM layer and encoded data
+            new_layer_stm = self.attention_layers[i](normalized_layer_stm, encoded_layer_data, encoded_layer_data, mask=attention_mask)
+            # e) combine new updated layer STM with previous state in residual gate
+            new_stm[i] = self.residual_gate_layers[i](updated_layer_stm, new_layer_stm)  # residual
+        # 10. Update all layers/models
         self.stm.update_all(new_stm)
         return self.stm.memory
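
For orientation, the per-layer update that the refactored StmMemoryAttention.forward now performs can be sketched outside the package. The snippet below is a standalone illustration rather than package code: nn.MultiheadAttention stands in for the configured memory attention layer, nn.RMSNorm for the input and STM norm layers, and a static per-slot tanh gate reproduces the gated-residual formula that is now delegated to ResidualGate (defined in the new gate.py below).

    # Standalone sketch of one layer's STM update (stand-ins, not the rxnn modules).
    import torch
    import torch.nn as nn

    batch, stm_size, seq_len, embed_dim = 2, 8, 16, 32
    layer_stm = torch.randn(batch, stm_size, embed_dim)   # current STM slots for one layer
    encoded = torch.randn(batch, seq_len, embed_dim)      # encoder hidden states for that layer

    input_norm = nn.RMSNorm(embed_dim)                    # stand-in for memory_input_norm_layers[i]
    stm_norm = nn.RMSNorm(embed_dim)                      # stand-in for memory_norm_layers[i]
    attn = nn.MultiheadAttention(embed_dim, num_heads=4, batch_first=True)  # stand-in for attention_layers[i]
    gate = nn.Parameter(torch.full((stm_size, 1), 3.0))   # static per-slot gate parameter

    q = stm_norm(layer_stm)                               # normalized STM slots as queries
    kv = input_norm(encoded)                              # normalized encoded data as keys/values
    new_layer_stm, _ = attn(q, kv, kv)                    # memory cross-attention update
    g = torch.tanh(gate)                                  # static tanh gate in (-1, 1)
    updated = (1 + g) * new_layer_stm + (1 - g) * layer_stm  # gated residual combination
    print(updated.shape)  # torch.Size([2, 8, 32])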
rxnn/memory/gate.py ADDED
@@ -0,0 +1,60 @@
+import torch
+import torch.nn as nn
+from typing import TypeAlias, Literal
+
+ResidualGateType: TypeAlias = Literal['static', 'elementwise', 'linear']
+
+
+class ResidualGate(nn.Module):
+    def __init__(
+        self,
+        stm_size: int,
+        use_gate: bool = False,
+        gate_type: ResidualGateType = 'static',
+        per_slot_gate: bool = True,
+        init_gate: float = 0.0,
+        use_tanh_gate: bool = True,
+        **kwargs,
+    ):
+        super(ResidualGate, self).__init__(**kwargs)
+        self.use_gate = use_gate
+        self.per_slot_gate = per_slot_gate
+        self.gate_type = gate_type
+        self.use_tanh_gate = use_tanh_gate
+
+        if self.use_gate:
+            if self.gate_type == 'linear':
+                self.gate = nn.Linear(stm_size, stm_size if self.per_slot_gate else 1)
+            else:
+                gate_shape = (stm_size, 1) if self.per_slot_gate else (1,)
+                self.gate = nn.Parameter(torch.full(gate_shape, init_gate))
+        else:
+            self.gate = None
+
+        self.gate_activation = nn.Tanh() if self.use_tanh_gate else nn.Sigmoid()
+
+    def _dynamic_gate(self, old_value: torch.Tensor, new_value: torch.Tensor):
+        if self.gate_type == 'linear':
+            mean_residual = (new_value + old_value).mean(dim=-1)
+            gate_input = self.gate(mean_residual).unsqueeze(-1)
+        else:
+            mean_dim = -1 if self.per_slot_gate else [1, 2]
+            gate_input = self.gate * (new_value + old_value).mean(dim=mean_dim, keepdim=True)
+        return self.gate_activation(gate_input)
+
+    def _calculate_output(self, layer_gate: torch.Tensor, old_value: torch.Tensor, new_value: torch.Tensor) -> torch.Tensor:
+        if self.use_tanh_gate:
+            return (1 + layer_gate) * new_value + (1 - layer_gate) * old_value
+        else:
+            return layer_gate * new_value + (1 - layer_gate) * old_value
+
+    def forward(self, old_value: torch.Tensor, new_value: torch.Tensor) -> torch.Tensor:
+        if not self.use_gate:
+            return new_value + old_value
+
+        if self.gate_type == 'static':
+            layer_gate = self.gate_activation(self.gate)
+        else:
+            layer_gate = self._dynamic_gate(old_value, new_value)
+
+        return self._calculate_output(layer_gate, old_value, new_value)
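
The full source of the new ResidualGate module is shown above, so its call contract can be illustrated directly. A minimal usage sketch, assuming rxnn 0.2.73 is installed and tensors follow the [batch, stm_size, embed_dim] layout used by the memory attention code:

    import torch
    from rxnn.memory.gate import ResidualGate

    old_value = torch.randn(2, 8, 32)   # previous STM state for one layer
    new_value = torch.randn(2, 8, 32)   # freshly attended STM update

    plain = ResidualGate(stm_size=8)    # use_gate=False: plain residual sum
    static = ResidualGate(stm_size=8, use_gate=True, gate_type='static',
                          per_slot_gate=True, init_gate=3.0, use_tanh_gate=True)
    linear = ResidualGate(stm_size=8, use_gate=True, gate_type='linear')  # gate computed from the residual

    print(plain(old_value, new_value).shape)   # torch.Size([2, 8, 32])
    print(static(old_value, new_value).shape)  # torch.Size([2, 8, 32])
    print(linear(old_value, new_value).shape)  # torch.Size([2, 8, 32])

The third option, gate_type='elementwise', takes the same dynamic path as 'linear' but scales the mean of the residual with the learned per-slot parameter before the activation, whereas 'static' applies the activation to the parameter alone.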
rxnn/rxt/models.py CHANGED
@@ -9,11 +9,13 @@ from ..transformers.models import ReactiveTransformerBase, ReactiveTransformerEn
 from ..transformers.ff import get_activation_layer
 from ..memory.stm import ShortTermMemory
 from ..memory.norm import init_memory_norm
-from ..memory.attention import StmMemoryAttention
+from ..memory.attention import StmMemoryAttention, InterlayerStmMemoryAttention
+from ..memory.gate import ResidualGate, ResidualGateType
 from ..utils import get_model_size
 from ..experimental.attention import init_experimental_attention
 
 
+
 class RxTAlphaComponentConfig(TypedDict):
     num_layers: int
     vocab_size: int
@@ -260,15 +262,15 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
         att_experts: int = None,
         att_query_experts: int = None,
         att_query_groups: int = None,
-        norm_type: str = 'rms',
+        norm_type: str = 'classic-rms',
         norm_init_gate: float = -2.0,
         norm_per_dim_scale: bool = False,
         norm_decay: float = 0.9,
         use_gated_residual: bool = False,
-        residual_per_slot_gate: bool = False,
-        residual_init_gate: float = 0.0,
-        use_dynamic_residual_gate: bool = False,
-        use_tanh_residual_gate: bool = False,
+        residual_per_slot_gate: bool = True,
+        residual_gate_init: float = 3.0,
+        residual_gate_type: ResidualGateType = 'static',
+        use_tanh_residual_gate: bool = True,
         debug_mode: bool = False,
         debug_interval: int = 10,
         **kwargs,
@@ -296,12 +298,153 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
         memory_norm_layers = nn.ModuleList([init_memory_norm(norm_type, embed_dim, stm_size, decay=norm_decay,
                                                               init_gate=norm_init_gate, per_dim_scale=norm_per_dim_scale)
                                             for _ in range(num_layers)])
+        memory_input_norm_layers = nn.ModuleList(nn.RMSNorm(embed_dim) for _ in range(num_layers))
         attention_layers = nn.ModuleList([att_init() for _ in range(num_layers)])
+        residual_gates = nn.ModuleList([
+            ResidualGate(
+                stm_size, use_gate=use_gated_residual, gate_type=residual_gate_type, per_slot_gate=residual_per_slot_gate,
+                init_gate=residual_gate_init, use_tanh_gate=use_tanh_residual_gate
+            ) for _ in range(num_layers)
+        ])
+
         self.model = StmMemoryAttention(
             stm, attention_layers, memory_norm_layers,
-            use_gated_residual=use_gated_residual, per_slot_gate=residual_per_slot_gate,
-            init_gate=residual_init_gate, use_dynamic_gate=use_dynamic_residual_gate,
-            use_tanh_gate=use_tanh_residual_gate, debug_mode=debug_mode, debug_interval=debug_interval,
+            memory_input_norm_layers, residual_gates,
+            debug_mode=debug_mode, debug_interval=debug_interval,
+        )
+
+    def freeze(self):
+        for param in self.parameters():
+            param.requires_grad = False
+
+    def unfreeze(self):
+        for param in self.parameters():
+            param.requires_grad = True
+
+    def load_shared_memory(self, stm: ShortTermMemory):
+        self.model.stm = stm
+
+    def update_max_len(self, max_seq_len: int):
+        self.model.update_max_len(max_seq_len)
+
+    def reset_memory(self, init_type: str = None):
+        self.model.stm.reset(init_type)
+
+    def clone_reset_memory(self):
+        self.model.stm.clone_detach_reset()
+
+    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
+        return self.model(x, attention_mask=attention_mask)
+
+
+class RxTAlphaInterlayerMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2.0"):
+    """RxT-Alpha (Reactive Transformer) memory attention model with interlayer STM attention"""
+
+    def __init__(
+        self,
+        num_layers: int = 12,
+        embed_dim: int = 512,
+        att_heads: int = 16,
+        seq_len: int = 1024,
+        stm_size: int = 1024,
+        use_flash_attention: bool = False,
+        att_dropout: float = 0.0,
+        att_groups: int = 1,
+        att_type: str = 'sqa',
+        att_experts: int = None,
+        att_query_experts: int = None,
+        att_query_groups: int = None,
+        interlayer_att_dropout: float = 0.0,
+        interlayer_att_groups: int = 1,
+        interlayer_att_type: str = 'sqa',
+        interlayer_att_experts: int = None,
+        interlayer_att_query_experts: int = None,
+        interlayer_att_query_groups: int = None,
+        norm_type: str = 'classic-rms',
+        norm_init_gate: float = -2.0,
+        norm_per_dim_scale: bool = False,
+        norm_decay: float = 0.9,
+        use_gated_residual: bool = False,
+        residual_per_slot_gate: bool = True,
+        residual_gate_init: float = 3.0,
+        residual_gate_type: ResidualGateType = 'static',
+        use_tanh_residual_gate: bool = True,
+        debug_mode: bool = False,
+        debug_interval: int = 10,
+        **kwargs,
+    ):
+        super(RxTAlphaInterlayerMemoryAttention, self).__init__(**kwargs)
+
+        assert att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma',
+                            'sqa'], 'Memory attention type could be "mha", "gqa", "mqa", "gma", "dma", "sqa".'
+
+        rope = RotaryPositionalEmbedding(embed_dim // att_heads, seq_len)
+        stm = ShortTermMemory(num_layers, embed_dim, stm_size)
+
+        if att_type in ['mha', 'gqa', 'mqa']:
+            att_init = lambda: init_attention(
+                embed_dim, att_heads, att_type, att_groups, rope=rope,
+                use_flash_attention=use_flash_attention, dropout=att_dropout,
+                max_seq_len=seq_len, is_causal=False, rope_only_for_keys=True
+            )
+        else:
+            att_init = lambda: init_experimental_attention(
+                embed_dim, att_heads, att_type, att_groups, rope=rope,
+                use_flash_attention=use_flash_attention, dropout=att_dropout,
+                max_seq_len=seq_len, is_causal=False, num_experts=att_experts,
+                num_query_experts=att_query_experts, num_query_groups=att_query_groups,
+                rope_only_for_keys=True
+            )
+
+        memory_norm_layers = nn.ModuleList([init_memory_norm(norm_type, embed_dim, stm_size, decay=norm_decay,
+                                                              init_gate=norm_init_gate, per_dim_scale=norm_per_dim_scale)
+                                            for _ in range(num_layers)])
+        memory_input_norm_layers = nn.ModuleList(nn.RMSNorm(embed_dim) for _ in range(num_layers))
+        attention_layers = nn.ModuleList([att_init() for _ in range(num_layers)])
+        residual_gates = nn.ModuleList([
+            ResidualGate(
+                stm_size, use_gate=use_gated_residual, gate_type=residual_gate_type,
+                per_slot_gate=residual_per_slot_gate,
+                init_gate=residual_gate_init, use_tanh_gate=use_tanh_residual_gate
+            ) for _ in range(num_layers)
+        ])
+
+        # Interlayer attention
+        if interlayer_att_type in ['mha', 'gqa', 'mqa']:
+            interlayer_att_init = lambda: init_attention(
+                embed_dim, att_heads, interlayer_att_type, interlayer_att_groups, rope=None,
+                use_flash_attention=use_flash_attention, dropout=interlayer_att_dropout, is_causal=False
+            )
+        else:
+            interlayer_att_init = lambda: init_experimental_attention(
+                embed_dim, att_heads, interlayer_att_type, interlayer_att_groups, rope=None,
+                use_flash_attention=use_flash_attention, dropout=interlayer_att_dropout, is_causal=False,
+                num_experts=interlayer_att_experts, num_query_experts=interlayer_att_query_experts, num_query_groups=interlayer_att_query_groups
+            )
+
+        mean_attention_layers = nn.ModuleList([interlayer_att_init() for _ in range(num_layers)])
+
+        mean_stm_norm = init_memory_norm(
+            norm_type, embed_dim, stm_size, decay=norm_decay,
+            init_gate=norm_init_gate, per_dim_scale=norm_per_dim_scale
+        )
+
+        mean_memory_norm_layers = nn.ModuleList([init_memory_norm(norm_type, embed_dim, stm_size, decay=norm_decay,
+                                                                   init_gate=norm_init_gate, per_dim_scale=norm_per_dim_scale)
+                                                 for _ in range(num_layers)])
+
+        mean_residual_gates = nn.ModuleList([
+            ResidualGate(
+                stm_size, use_gate=use_gated_residual, gate_type=residual_gate_type,
+                per_slot_gate=residual_per_slot_gate,
+                init_gate=residual_gate_init, use_tanh_gate=use_tanh_residual_gate
+            ) for _ in range(num_layers)
+        ])
+
+        self.model = InterlayerStmMemoryAttention(
+            stm, attention_layers, memory_norm_layers, memory_input_norm_layers, residual_gates,
+            mean_attention_layers, mean_memory_norm_layers, mean_residual_gates, mean_stm_norm,
+            debug_mode=debug_mode, debug_interval=debug_interval,
         )
 
     def freeze(self):
@@ -327,6 +470,7 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
         return self.model(x, attention_mask=attention_mask)
 
+
 class RxTAlphaCriticEncoder(RxTAlphaComponentBase, pipeline_tag="text-classification", license="apache-2.0"):
     """RxT-Alpha (Reactive Transformer) encoder model"""
 
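
Since RxTAlphaInterlayerMemoryAttention mirrors RxTAlphaMemoryAttention's constructor and adds the interlayer_* options, wiring it up looks roughly as follows. This is a hedged construction sketch, assuming rxnn 0.2.73 is installed; the hyperparameters are illustrative rather than recommended, and the forward input layout is inferred from the diff above, not from documentation.

    import torch
    from rxnn.rxt.models import RxTAlphaInterlayerMemoryAttention

    # Build the memory attention wrapper with gated residuals enabled.
    mem_att = RxTAlphaInterlayerMemoryAttention(
        num_layers=6, embed_dim=256, att_heads=8, seq_len=256, stm_size=256,
        att_type='sqa', interlayer_att_type='sqa',
        use_gated_residual=True, residual_gate_type='static',
    )

    # forward appears to expect per-layer encoder hidden states stacked along dim 0,
    # i.e. roughly [num_layers, batch, seq_len, embed_dim] (x[i] feeds layer i), plus an
    # optional [batch, seq_len] attention mask.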
rxnn/training/mrl.py CHANGED
@@ -592,7 +592,7 @@ class MRLTrainer:
 
         router_loss = actor.moe_router_loss()
         if torch.isnan(router_loss).any():
-            print("NaN detected in router loss")
+            print("!!!!!!!!!!!!!!!!!!!!!! NaN detected in router loss")
         if router_loss is not None:
             return main_loss + self.moe_aux_loss_scale * router_loss
         else:
@@ -671,7 +671,7 @@ class MRLTrainer:
         # 4.4 Unscale and clip gradient norms
         self.scaler.unscale_(self.optimizer)
         torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
-                                        error_if_nonfinite=self.debug_mode)
+                                        error_if_nonfinite=False)
         if self.debug_mode and self.epoch_step['train'] % self.debug_interval == 0:
             self._log_gradients(logits)
         # 4.5 Run scaled optimization step
@@ -692,7 +692,7 @@ class MRLTrainer:
         policy_loss.backward(retain_graph=True)
         # 4.4 Clip gradient norms
         torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
-                                        error_if_nonfinite=self.debug_mode)
+                                        error_if_nonfinite=False)
         if self.debug_mode and self.epoch_step['train'] % self.debug_interval == 0:
             self._log_gradients(logits)
         # 4.5 Run scaled optimization step
rxnn/transformers/layers.py CHANGED
@@ -1,5 +1,7 @@
 import torch
 import torch.nn as nn
+from poetry.console.commands import self
+
 from .attention import MultiHeadAttention
 from .ff import FeedForward, GatedFeedForward
 from .moe import MoeFeedForward, GatedMoeFeedForward
@@ -49,10 +51,12 @@ class ReactiveTransformerLayer(nn.Module):
             self.norm1 = nn.RMSNorm(embed_dim)
             self.norm2 = nn.RMSNorm(embed_dim)
             self.norm3 = nn.RMSNorm(embed_dim)
+            self.stm_norm = nn.RMSNorm(embed_dim)
         else:
             self.norm1 = nn.LayerNorm(embed_dim)
             self.norm2 = nn.LayerNorm(embed_dim)
             self.norm3 = nn.LayerNorm(embed_dim)
+            self.stm_norm = nn.LayerNorm(embed_dim)
         self.use_post_norm = use_post_norm
         self.use_moe = use_moe
         self.use_moe_att = use_moe_att
@@ -63,9 +67,11 @@ class ReactiveTransformerLayer(nn.Module):
         if with_norms:
             for param in self.norm2.parameters():
                 param.requires_grad_(is_trainable)
+            for param in self.stm_norm.parameters():
+                param.requires_grad_(is_trainable)
 
     def memory_parameters(self) -> list[nn.Parameter]:
-        return list(self.memory_cross_attention.parameters()) + list(self.norm2.parameters())
+        return list(self.memory_cross_attention.parameters()) + list(self.norm2.parameters()) + list(self.stm_norm.parameters())
 
     def not_memory_parameters(self) -> list[nn.Parameter]:
         return (list(self.attention.parameters()) + list(self.norm1.parameters()) +
@@ -102,11 +108,8 @@ class ReactiveTransformerLayer(nn.Module):
         residual = x
         if not self.use_post_norm:
             x = self.norm1(x)
-        if torch.isnan(x).any():
-            print("NaN detected in pre-norm (self-attention) output")
         x = self.attention(x, x, x, mask=mask)
-        if torch.isnan(x).any():
-            print("NaN detected in self-attention output")
+
         x = residual + x
         if self.use_post_norm:
             x = self.norm1(x)
@@ -114,18 +117,13 @@ class ReactiveTransformerLayer(nn.Module):
         residual = x
         if not self.use_post_norm:
             x = self.norm2(x)
-        if torch.isnan(x).any():
-            print("NaN detected in pre-norm (cross-attention) output")
 
+        # normalize STM and prepare STM mask
+        stm = self.stm_norm(stm)
         mem_mask = mask.squeeze(1).unsqueeze(-1).expand(-1, -1, -1, stm.size(1)) \
             if mask is not None else None
 
-        if torch.isnan(stm).any():
-            print("NaN detected in STM cross-attention input")
-
         x = self.memory_cross_attention(x, stm, stm, mask=mem_mask)
-        if torch.isnan(x).any():
-            print("NaN detected in cross-attention output")
         x = residual + x
         if self.use_post_norm:
             x = self.norm2(x)
@@ -134,11 +132,7 @@ class ReactiveTransformerLayer(nn.Module):
         residual = x
         if not self.use_post_norm:
             x = self.norm3(x)
-        if torch.isnan(x).any():
-            print("NaN detected in pre-norm (ff) output")
         x = self.ff(x)
-        if torch.isnan(x).any():
-            print("NaN detected in ff output")
         x = residual + x
         if self.use_post_norm:
             x = self.norm3(x)
rxnn/transformers/models.py CHANGED
@@ -94,7 +94,7 @@ class ReactiveTransformerDecoder(ReactiveTransformerBase):
     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
         x = super().forward(x)  # apply embeddings
         if torch.isnan(x).any():
-            print("NaN detected in decoder embedding output")
+            print("!!!!!!!!!!!!!!!!!!!!!! NaN detected in decoder embedding output")
         seq_len = x.size(1)
         if not self.use_flash_attention and self.use_relative_embedding:
             mask = create_causal_mask(seq_len, device=x.device)
@@ -112,7 +112,7 @@ class ReactiveTransformerDecoder(ReactiveTransformerBase):
         for i in range(self.num_own_layers):
             x = self._handle_layer(i, x, mask=mask)
             if torch.isnan(x).any():
-                print(f"NaN detected in {i}. decoder layer output")
+                print(f"!!!!!!!!!!!!!!!!!!!!!! NaN detected in {i}. decoder layer output")
         return self.head(self.head_norm(x) if self.use_head_norm else x)
 
 
@@ -122,7 +122,7 @@ class ReactiveTransformerEncoder(ReactiveTransformerBase):
     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]:
         x = super().forward(x)  # apply embeddings
         if torch.isnan(x).any():
-            print("NaN detected in encoder embedding output")
+            print("!!!!!!!!!!!!!!!!!!!!!! NaN detected in encoder embedding output")
         if attention_mask is not None:
             attention_mask = attention_mask.unsqueeze(1).unsqueeze(1).bool()
 
@@ -136,7 +136,7 @@ class ReactiveTransformerEncoder(ReactiveTransformerBase):
         for i in range(self.num_own_layers):
             x = self._handle_layer(i, x, mask=attention_mask)
             if torch.isnan(x).any():
-                print(f"NaN detected in {i}. encoder layer output")
+                print(f"!!!!!!!!!!!!!!!!!!!!!! NaN detected in {i}. encoder layer output")
             hidden_states.append(x)
         return x, torch.stack(hidden_states)
 
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.2.71
+Version: 0.2.73
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
@@ -5,11 +5,12 @@ rxnn/experimental/attention.py,sha256=jlNS82INjycNEfmk3HtkIacUvT_ELhaCO2g-kZTvhX
 rxnn/experimental/models.py,sha256=KheR1zSNJIaeVvpVAkEJwcuM5nOqQP0ZF08XhrtGJ8E,5387
 rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/memory/attention.py,sha256=O4ycW3KKP5hFYadgVh47LvGWJn9zNHz8vh9E9okC0h8,4223
+rxnn/memory/attention.py,sha256=CReYJZNA5JRED_QWqX-yKqEKZTRX6DNCAB8uFLZtKxI,7513
+rxnn/memory/gate.py,sha256=pR_H2y9C7S02QskoFAEC9Tmluut0k4GGlHgvZGiw6m4,2332
 rxnn/memory/norm.py,sha256=cVjjhCLqR5K6-321SP_ObG17y-ddlcTJeCTXvW4vpk0,6675
 rxnn/memory/stm.py,sha256=jv57gsH9XW19sLbxpRDqsp1yfsii_4Ef4Ncr_ztk-i4,3937
 rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/rxt/models.py,sha256=M_0nEfSgr5Wyv-Ku4TCLpIs5VndUccjtIR0wU0DSVRo,15574
+rxnn/rxt/models.py,sha256=Pb48Frl6HV4Wb9CZgYtmzch3k_4Jess3rhs7dY1I96k,22209
 rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/training/base.py,sha256=CqaArEZYOdH64nmKfx28U3GI46TzO4oNkjf_hrF23Cw,11835
 rxnn/training/bml.py,sha256=hw6gLpLkGvqLzxIvBg4MvCc5r8cHpEm2RDyh7nH6CtE,16914
@@ -17,7 +18,7 @@ rxnn/training/callbacks.py,sha256=rS8leuVFPVVfE5Zc8DMkUZhRIPN-vpPbUjowXE5TSBw,36
 rxnn/training/dataset.py,sha256=ruU6k33pQmpTqhxpjLFNdDJnCjcrBcGeFOzJqFahJDM,51880
 rxnn/training/ddp.py,sha256=VsNBjn3cY-uUj8hbsW7oKvb0_ZKnXnJ2KgObm-Mr9i4,836
 rxnn/training/models.py,sha256=ILkcqBV1MImnULnq-YDSSEf8cUdEbUgQaH0FRTsa4LA,9069
-rxnn/training/mrl.py,sha256=Ntkti6DDKipKa-AwTvo1WDOdIXOL3uXOhT-Xx29wR-w,67369
+rxnn/training/mrl.py,sha256=KUJAdUznquhf5UlcpV-QF5oKHDBEsDecMEVmMLQZw7w,67380
 rxnn/training/reward.py,sha256=uiSsBXmjMw2yv-1Bssy3RTlpU6zP8ape3490Sl-aT0M,16144
 rxnn/training/rl.py,sha256=hWtExxY-_pAmTOGYxyCNounUbaGWvLDVltC4sRC7MN4,7175
 rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
@@ -26,14 +27,14 @@ rxnn/training/utils.py,sha256=ngDCm654NL3UsPy190Er4XPc9HI-OyEV6tDLMgEEvQc,6219
 rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/transformers/attention.py,sha256=KRnKT6XUqAXElxV9y72mSpdTeiMgCKCCLqqxCFNTHmA,16372
 rxnn/transformers/ff.py,sha256=WDjO-H9XWInoWnUnxiseIH6Kx5GlHP0zGJygwhcb1gc,2589
-rxnn/transformers/layers.py,sha256=bcDP8vZ5dpTWWqMCkzrPG8yQA0D0G5VjnV2Nq9IO8Dc,8816
+rxnn/transformers/layers.py,sha256=fxjlbQG6cwxq-b2ei4DnohSQGH5gwy4GkfP9duTUvjw,8492
 rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
-rxnn/transformers/models.py,sha256=r4vNldYqCIpwMpXkFZvYbw0UBK3NE75qH7bc6OZ8YjE,11587
+rxnn/transformers/models.py,sha256=TP0H9do53Z0vd8kpHMISBzMpHE5X9QIHcy0B-iJHuNQ,11711
 rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
 rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.71.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
-rxnn-0.2.71.dist-info/METADATA,sha256=7BHHcFtImjPB57X2eRLgO4IFOSBNb7GOR5ytMaCttkI,60420
-rxnn-0.2.71.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
-rxnn-0.2.71.dist-info/RECORD,,
+rxnn-0.2.73.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.73.dist-info/METADATA,sha256=gtoRMeFgBuOZs4lRKl9JGUxZ2X4C9K78Ee-NHLMqW4E,60420
+rxnn-0.2.73.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.73.dist-info/RECORD,,