rxnn 0.2.71__tar.gz → 0.2.73__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. {rxnn-0.2.71 → rxnn-0.2.73}/PKG-INFO +1 -1
  2. {rxnn-0.2.71 → rxnn-0.2.73}/pyproject.toml +1 -1
  3. rxnn-0.2.73/src/rxnn/memory/attention.py +150 -0
  4. rxnn-0.2.73/src/rxnn/memory/gate.py +60 -0
  5. {rxnn-0.2.71 → rxnn-0.2.73}/src/rxnn/rxt/models.py +153 -9
  6. {rxnn-0.2.71 → rxnn-0.2.73}/src/rxnn/training/mrl.py +3 -3
  7. {rxnn-0.2.71 → rxnn-0.2.73}/src/rxnn/transformers/layers.py +10 -16
  8. {rxnn-0.2.71 → rxnn-0.2.73}/src/rxnn/transformers/models.py +4 -4
  9. rxnn-0.2.71/src/rxnn/memory/attention.py +0 -89
  10. {rxnn-0.2.71 → rxnn-0.2.73}/LICENSE +0 -0
  11. {rxnn-0.2.71 → rxnn-0.2.73}/README.md +0 -0
  12. {rxnn-0.2.71 → rxnn-0.2.73}/src/rxnn/.DS_Store +0 -0
  13. {rxnn-0.2.71 → rxnn-0.2.73}/src/rxnn/__init__.py +0 -0
  14. {rxnn-0.2.71 → rxnn-0.2.73}/src/rxnn/experimental/__init__.py +0 -0
  15. {rxnn-0.2.71 → rxnn-0.2.73}/src/rxnn/experimental/attention.py +0 -0
  16. {rxnn-0.2.71 → rxnn-0.2.73}/src/rxnn/experimental/models.py +0 -0
  17. {rxnn-0.2.71 → rxnn-0.2.73}/src/rxnn/experimental/moe.py +0 -0
  18. {rxnn-0.2.71 → rxnn-0.2.73}/src/rxnn/memory/__init__.py +0 -0
  19. {rxnn-0.2.71 → rxnn-0.2.73}/src/rxnn/memory/norm.py +0 -0
  20. {rxnn-0.2.71 → rxnn-0.2.73}/src/rxnn/memory/stm.py +0 -0
  21. {rxnn-0.2.71 → rxnn-0.2.73}/src/rxnn/rxt/__init__.py +0 -0
  22. {rxnn-0.2.71 → rxnn-0.2.73}/src/rxnn/training/__init__.py +0 -0
  23. {rxnn-0.2.71 → rxnn-0.2.73}/src/rxnn/training/base.py +0 -0
  24. {rxnn-0.2.71 → rxnn-0.2.73}/src/rxnn/training/bml.py +0 -0
  25. {rxnn-0.2.71 → rxnn-0.2.73}/src/rxnn/training/callbacks.py +0 -0
  26. {rxnn-0.2.71 → rxnn-0.2.73}/src/rxnn/training/dataset.py +0 -0
  27. {rxnn-0.2.71 → rxnn-0.2.73}/src/rxnn/training/ddp.py +0 -0
  28. {rxnn-0.2.71 → rxnn-0.2.73}/src/rxnn/training/models.py +0 -0
  29. {rxnn-0.2.71 → rxnn-0.2.73}/src/rxnn/training/reward.py +0 -0
  30. {rxnn-0.2.71 → rxnn-0.2.73}/src/rxnn/training/rl.py +0 -0
  31. {rxnn-0.2.71 → rxnn-0.2.73}/src/rxnn/training/scheduler.py +0 -0
  32. {rxnn-0.2.71 → rxnn-0.2.73}/src/rxnn/training/tokenizer.py +0 -0
  33. {rxnn-0.2.71 → rxnn-0.2.73}/src/rxnn/training/utils.py +0 -0
  34. {rxnn-0.2.71 → rxnn-0.2.73}/src/rxnn/transformers/__init__.py +0 -0
  35. {rxnn-0.2.71 → rxnn-0.2.73}/src/rxnn/transformers/attention.py +0 -0
  36. {rxnn-0.2.71 → rxnn-0.2.73}/src/rxnn/transformers/ff.py +0 -0
  37. {rxnn-0.2.71 → rxnn-0.2.73}/src/rxnn/transformers/mask.py +0 -0
  38. {rxnn-0.2.71 → rxnn-0.2.73}/src/rxnn/transformers/moe.py +0 -0
  39. {rxnn-0.2.71 → rxnn-0.2.73}/src/rxnn/transformers/positional.py +0 -0
  40. {rxnn-0.2.71 → rxnn-0.2.73}/src/rxnn/transformers/sampler.py +0 -0
  41. {rxnn-0.2.71 → rxnn-0.2.73}/src/rxnn/utils.py +0 -0
{rxnn-0.2.71 → rxnn-0.2.73}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.2.71
+Version: 0.2.73
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
{rxnn-0.2.71 → rxnn-0.2.73}/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
 
 [tool.poetry]
 name = "rxnn"
-version = "0.2.71"
+version = "0.2.73"
 description = "RxNN: Reactive Neural Networks Platform"
 
 license = "Apache-2.0"
rxnn-0.2.73/src/rxnn/memory/attention.py (new file)
@@ -0,0 +1,150 @@
+import torch
+import torch.nn as nn
+from .stm import ShortTermMemory
+
+
+class StmMemoryAttention(nn.Module):
+    def __init__(
+            self,
+            stm: ShortTermMemory,
+            attention_layers: nn.ModuleList,
+            memory_norm_layers: nn.ModuleList,
+            memory_input_norm_layers: nn.ModuleList,
+            residual_gate_layers: nn.ModuleList,
+            debug_mode: bool = False,
+            debug_interval: int = 10,
+            *args,
+            **kwargs
+    ):
+        super(StmMemoryAttention, self).__init__(*args, **kwargs)
+        self.stm = stm
+        self.attention_layers = attention_layers
+        self.memory_norm_layers = memory_norm_layers
+        self.memory_input_norm_layers = memory_input_norm_layers
+        self.residual_gate_layers = residual_gate_layers
+        assert (len(self.attention_layers) == len(self.memory_norm_layers) ==
+                len(self.residual_gate_layers) == len(self.memory_input_norm_layers) ==
+                self.stm.memory.size(0))
+        self.num_layers = len(attention_layers)
+        self.debug_mode = debug_mode
+        self.debug_interval = debug_interval
+        self.debug_step = 0
+
+    def update_max_len(self, max_seq_len: int):
+        for i in range(self.num_layers):
+            if self.attention_layers[i].rope is not None:
+                self.attention_layers[i].rope.update_max_len(max_seq_len)
+
+    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
+        # 1. Process correct attention mask
+        if attention_mask is not None:
+            attention_mask = attention_mask.unsqueeze(1).unsqueeze(1).bool()
+        # 2. Init new empty STM
+        new_stm = torch.zeros_like(self.stm.memory)
+
+        # 3. Run Short-Term Memory update for all layers
+        for i in range(self.num_layers):
+            # 4. Get current layer STM value
+            layer_stm = self.stm(i)
+            # 5. Expand layer STM to batch size, if it's not in batch mode
+            if layer_stm.size(0) == 1:
+                layer_stm = layer_stm.expand(x.size(0), -1, -1)
+
+            # 6. Get encoded layer data and normalize it
+            encoded_layer_data = self.memory_input_norm_layers[i](x[i])
+            # 7. Normalize STM layer
+            normalized_layer_stm = self.memory_norm_layers[i](layer_stm)
+
+            # 8. Print normalization stats in debug mode
+            if self.debug_mode and self.training:
+                if self.debug_step != 0 and self.debug_step % self.debug_interval == 0:
+                    self.debug_step = 0
+                    print(f"Normalized STM stats - mean: {normalized_layer_stm.mean().item():.4f}, std: {normalized_layer_stm.std().item():.4f}")
+                else:
+                    self.debug_step += 1
+
+            # 9. Calculate memory attention
+            new_layer_stm = self.attention_layers[i](normalized_layer_stm, encoded_layer_data, encoded_layer_data, mask=attention_mask)
+            # 10. Combine new updated layer state with current STM state in residual gate
+            new_stm[i] = self.residual_gate_layers[i](layer_stm, new_layer_stm) # residual
+        # 11. Update all layers/models
+        self.stm.update_all(new_stm)
+        return self.stm.memory
+
+
+class InterlayerStmMemoryAttention(StmMemoryAttention):
+    def __init__(
+            self,
+            stm: ShortTermMemory,
+            attention_layers: nn.ModuleList,
+            memory_norm_layers: nn.ModuleList,
+            memory_input_norm_layers: nn.ModuleList,
+            residual_gate_layers: nn.ModuleList,
+            mean_attention_layers: nn.ModuleList,
+            mean_memory_norm_layers: nn.ModuleList,
+            mean_residual_gate_layers: nn.ModuleList,
+            mean_stm_norm: nn.Module,
+            debug_mode: bool = False,
+            debug_interval: int = 10,
+            **kwargs
+    ):
+        super(InterlayerStmMemoryAttention, self).__init__(
+            stm, attention_layers, memory_norm_layers, memory_input_norm_layers, residual_gate_layers,
+            debug_mode=debug_mode, debug_interval=debug_interval, **kwargs
+        )
+        self.mean_attention_layers = mean_attention_layers
+        self.mean_memory_norm_layers = mean_memory_norm_layers
+        self.mean_stm_norm = mean_stm_norm
+        self.mean_residual_gate_layers = mean_residual_gate_layers
+        assert (len(self.mean_attention_layers) == len(self.mean_memory_norm_layers) ==
+                len(self.mean_residual_gate_layers) == self.num_layers)
+
+    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
+        # 1. Process correct attention mask
+        if attention_mask is not None:
+            attention_mask = attention_mask.unsqueeze(1).unsqueeze(1).bool()
+        # 2. Init new empty STM
+        new_stm = torch.zeros_like(self.stm.memory)
+
+        # 3. Get mean STM value from layers for mean interlayer memory attention
+        mean_stm = self.stm.memory.mean(dim=0) # [batch_size, stm_size, embed_dim]
+        # 4. Normalize mean STM layer
+        normalized_mean_stm = self.mean_stm_norm(mean_stm)
+
+        # 5. Run Short-Term Memory update for all layers
+        for i in range(self.num_layers):
+            # 6. Get current layer STM value
+            layer_stm = self.stm(i)
+            # 7. Expand layer STM to batch size, if it's not in batch mode
+            if layer_stm.size(0) == 1:
+                layer_stm = layer_stm.expand(x.size(0), -1, -1)
+
+            # 8. Mean interlayer memory attention
+            # a) normalize STM layer value
+            pre_normalized_layer_stm = self.mean_memory_norm_layers[i](layer_stm)
+            # b) calculate attention between STM layer and mean value of all STM layers (from previous interaction)
+            interlayer_stm = self.mean_attention_layers[i](pre_normalized_layer_stm, normalized_mean_stm, normalized_mean_stm, mask=None)
+            # c) combine updated interlayer state with current STM state in residual gate
+            updated_layer_stm = self.mean_residual_gate_layers[i](layer_stm, interlayer_stm)
+
+            # 9. Main memory attention
+            # a) get encoded data for current layer and normalize it
+            encoded_layer_data = self.memory_input_norm_layers[i](x[i])
+            # b) normalize STM layer value
+            normalized_layer_stm = self.memory_norm_layers[i](updated_layer_stm)
+            # c) print normalized STM stats in debug mode
+            if self.debug_mode and self.training:
+                if self.debug_step != 0 and self.debug_step % self.debug_interval == 0:
+                    self.debug_step = 0
+                    print(
+                        f"Pre-Normalized STM stats - mean: {pre_normalized_layer_stm.mean().item():.4f}, std: {pre_normalized_layer_stm.std().item():.4f}")
+                    print(f"Normalized STM stats - mean: {normalized_layer_stm.mean().item():.4f}, std: {normalized_layer_stm.std().item():.4f}")
+                else:
+                    self.debug_step += 1
+            # d) calculate memory attention between STM layer and encoded data
+            new_layer_stm = self.attention_layers[i](normalized_layer_stm, encoded_layer_data, encoded_layer_data, mask=attention_mask)
+            # e) combine new updated layer STM with previous state in residual gate
+            new_stm[i] = self.residual_gate_layers[i](updated_layer_stm, new_layer_stm) # residual
+        # 10. Update all layers/models
+        self.stm.update_all(new_stm)
+        return self.stm.memory
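Illustrative example (not part of the package diff): the per-layer update that the new StmMemoryAttention.forward performs is "attend from the normalized STM slots (query) into the normalized encoder output for that layer (key/value), then merge the result with the old slots through a residual gate". A minimal standalone sketch of that single step, using torch.nn.MultiheadAttention and nn.RMSNorm (PyTorch >= 2.4) as stand-ins for rxnn's own attention and norm layers; the shapes below are made up for the demo.

import torch
import torch.nn as nn

# Toy shapes: batch of 2, 8 STM slots, 16-token encoded sequence, 64-dim embeddings.
batch, stm_size, seq_len, embed_dim = 2, 8, 16, 64
stm_slots = torch.randn(batch, stm_size, embed_dim)  # old STM state for one layer
encoded = torch.randn(batch, seq_len, embed_dim)     # encoder output for the same layer

stm_norm = nn.RMSNorm(embed_dim)    # stands in for memory_norm_layers[i]
input_norm = nn.RMSNorm(embed_dim)  # stands in for memory_input_norm_layers[i]
attn = nn.MultiheadAttention(embed_dim, num_heads=4, batch_first=True)  # stand-in attention layer

q = stm_norm(stm_slots)             # query: normalized STM slots
kv = input_norm(encoded)            # key/value: normalized encoded layer data
update, _ = attn(q, kv, kv)         # memory attention step

# Ungated residual merge, matching ResidualGate with use_gate=False (see gate.py below).
new_stm_slots = update + stm_slots
print(new_stm_slots.shape)          # torch.Size([2, 8, 64])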
rxnn-0.2.73/src/rxnn/memory/gate.py (new file)
@@ -0,0 +1,60 @@
+import torch
+import torch.nn as nn
+from typing import TypeAlias, Literal
+
+ResidualGateType: TypeAlias = Literal['static', 'elementwise', 'linear']
+
+
+class ResidualGate(nn.Module):
+    def __init__(
+            self,
+            stm_size: int,
+            use_gate: bool = False,
+            gate_type: ResidualGateType = 'static',
+            per_slot_gate: bool = True,
+            init_gate: float = 0.0,
+            use_tanh_gate: bool = True,
+            **kwargs,
+    ):
+        super(ResidualGate, self).__init__(**kwargs)
+        self.use_gate = use_gate
+        self.per_slot_gate = per_slot_gate
+        self.gate_type = gate_type
+        self.use_tanh_gate = use_tanh_gate
+
+        if self.use_gate:
+            if self.gate_type == 'linear':
+                self.gate = nn.Linear(stm_size, stm_size if self.per_slot_gate else 1)
+            else:
+                gate_shape = (stm_size, 1) if self.per_slot_gate else (1,)
+                self.gate = nn.Parameter(torch.full(gate_shape, init_gate))
+        else:
+            self.gate = None
+
+        self.gate_activation = nn.Tanh() if self.use_tanh_gate else nn.Sigmoid()
+
+    def _dynamic_gate(self, old_value: torch.Tensor, new_value: torch.Tensor):
+        if self.gate_type == 'linear':
+            mean_residual = (new_value + old_value).mean(dim=-1)
+            gate_input = self.gate(mean_residual).unsqueeze(-1)
+        else:
+            mean_dim = -1 if self.per_slot_gate else [1, 2]
+            gate_input = self.gate * (new_value + old_value).mean(dim=mean_dim, keepdim=True)
+        return self.gate_activation(gate_input)
+
+    def _calculate_output(self, layer_gate: torch.Tensor, old_value: torch.Tensor, new_value: torch.Tensor) -> torch.Tensor:
+        if self.use_tanh_gate:
+            return (1 + layer_gate) * new_value + (1 - layer_gate) * old_value
+        else:
+            return layer_gate * new_value + (1 - layer_gate) * old_value
+
+    def forward(self, old_value: torch.Tensor, new_value: torch.Tensor) -> torch.Tensor:
+        if not self.use_gate:
+            return new_value + old_value
+
+        if self.gate_type == 'static':
+            layer_gate = self.gate_activation(self.gate)
+        else:
+            layer_gate = self._dynamic_gate(old_value, new_value)
+
+        return self._calculate_output(layer_gate, old_value, new_value)
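Since gate.py is self-contained, it can be exercised directly. A short usage example (illustrative, not part of the diff; it assumes rxnn 0.2.73 is installed so that rxnn.memory.gate resolves, as the imports in rxt/models.py below indicate). With the static tanh gate initialized at 3.0, tanh(3.0) ≈ 0.995, so the merged state starts out weighting the freshly attended value at roughly 1.995x and the old slots at roughly 0.005x.

import torch
from rxnn.memory.gate import ResidualGate  # new module in 0.2.73

stm_size, embed_dim = 8, 64
old_state = torch.randn(2, stm_size, embed_dim)  # previous STM slots for one layer
new_state = torch.randn(2, stm_size, embed_dim)  # freshly attended update

# Static per-slot tanh gate, mirroring the new rxt/models.py defaults
# (residual_gate_type='static', residual_per_slot_gate=True, residual_gate_init=3.0).
gate = ResidualGate(stm_size, use_gate=True, gate_type='static',
                    per_slot_gate=True, init_gate=3.0, use_tanh_gate=True)
merged = gate(old_state, new_state)
print(merged.shape)  # torch.Size([2, 8, 64])

# With use_gate=False the module reduces to a plain residual sum.
plain = ResidualGate(stm_size, use_gate=False)
assert torch.allclose(plain(old_state, new_state), new_state + old_state)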
{rxnn-0.2.71 → rxnn-0.2.73}/src/rxnn/rxt/models.py
@@ -9,11 +9,13 @@ from ..transformers.models import ReactiveTransformerBase, ReactiveTransformerEn
 from ..transformers.ff import get_activation_layer
 from ..memory.stm import ShortTermMemory
 from ..memory.norm import init_memory_norm
-from ..memory.attention import StmMemoryAttention
+from ..memory.attention import StmMemoryAttention, InterlayerStmMemoryAttention
+from ..memory.gate import ResidualGate, ResidualGateType
 from ..utils import get_model_size
 from ..experimental.attention import init_experimental_attention
 
 
+
 class RxTAlphaComponentConfig(TypedDict):
     num_layers: int
     vocab_size: int
@@ -260,15 +262,15 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
             att_experts: int = None,
             att_query_experts: int = None,
             att_query_groups: int = None,
-            norm_type: str = 'rms',
+            norm_type: str = 'classic-rms',
             norm_init_gate: float = -2.0,
             norm_per_dim_scale: bool = False,
             norm_decay: float = 0.9,
             use_gated_residual: bool = False,
-            residual_per_slot_gate: bool = False,
-            residual_init_gate: float = 0.0,
-            use_dynamic_residual_gate: bool = False,
-            use_tanh_residual_gate: bool = False,
+            residual_per_slot_gate: bool = True,
+            residual_gate_init: float = 3.0,
+            residual_gate_type: ResidualGateType = 'static',
+            use_tanh_residual_gate: bool = True,
             debug_mode: bool = False,
             debug_interval: int = 10,
             **kwargs,
@@ -296,12 +298,153 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
         memory_norm_layers = nn.ModuleList([init_memory_norm(norm_type, embed_dim, stm_size, decay=norm_decay,
                                                               init_gate=norm_init_gate, per_dim_scale=norm_per_dim_scale)
                                             for _ in range(num_layers)])
+        memory_input_norm_layers = nn.ModuleList(nn.RMSNorm(embed_dim) for _ in range(num_layers))
         attention_layers = nn.ModuleList([att_init() for _ in range(num_layers)])
+        residual_gates = nn.ModuleList([
+            ResidualGate(
+                stm_size, use_gate=use_gated_residual, gate_type=residual_gate_type, per_slot_gate=residual_per_slot_gate,
+                init_gate=residual_gate_init, use_tanh_gate=use_tanh_residual_gate
+            ) for _ in range(num_layers)
+        ])
+
         self.model = StmMemoryAttention(
             stm, attention_layers, memory_norm_layers,
-            use_gated_residual=use_gated_residual, per_slot_gate=residual_per_slot_gate,
-            init_gate=residual_init_gate, use_dynamic_gate=use_dynamic_residual_gate,
-            use_tanh_gate=use_tanh_residual_gate, debug_mode=debug_mode, debug_interval=debug_interval,
+            memory_input_norm_layers, residual_gates,
+            debug_mode=debug_mode, debug_interval=debug_interval,
+        )
+
+    def freeze(self):
+        for param in self.parameters():
+            param.requires_grad = False
+
+    def unfreeze(self):
+        for param in self.parameters():
+            param.requires_grad = True
+
+    def load_shared_memory(self, stm: ShortTermMemory):
+        self.model.stm = stm
+
+    def update_max_len(self, max_seq_len: int):
+        self.model.update_max_len(max_seq_len)
+
+    def reset_memory(self, init_type: str = None):
+        self.model.stm.reset(init_type)
+
+    def clone_reset_memory(self):
+        self.model.stm.clone_detach_reset()
+
+    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
+        return self.model(x, attention_mask=attention_mask)
+
+
+class RxTAlphaInterlayerMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2.0"):
+    """RxT-Alpha (Reactive Transformer) memory attention model with interlayer STM attention"""
+
+    def __init__(
+            self,
+            num_layers: int = 12,
+            embed_dim: int = 512,
+            att_heads: int = 16,
+            seq_len: int = 1024,
+            stm_size: int = 1024,
+            use_flash_attention: bool = False,
+            att_dropout: float = 0.0,
+            att_groups: int = 1,
+            att_type: str = 'sqa',
+            att_experts: int = None,
+            att_query_experts: int = None,
+            att_query_groups: int = None,
+            interlayer_att_dropout: float = 0.0,
+            interlayer_att_groups: int = 1,
+            interlayer_att_type: str = 'sqa',
+            interlayer_att_experts: int = None,
+            interlayer_att_query_experts: int = None,
+            interlayer_att_query_groups: int = None,
+            norm_type: str = 'classic-rms',
+            norm_init_gate: float = -2.0,
+            norm_per_dim_scale: bool = False,
+            norm_decay: float = 0.9,
+            use_gated_residual: bool = False,
+            residual_per_slot_gate: bool = True,
+            residual_gate_init: float = 3.0,
+            residual_gate_type: ResidualGateType = 'static',
+            use_tanh_residual_gate: bool = True,
+            debug_mode: bool = False,
+            debug_interval: int = 10,
+            **kwargs,
+    ):
+        super(RxTAlphaInterlayerMemoryAttention, self).__init__(**kwargs)
+
+        assert att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma',
+                            'sqa'], 'Memory attention type could be "mha", "gqa", "mqa", "gma", "dma", "sqa".'
+
+        rope = RotaryPositionalEmbedding(embed_dim // att_heads, seq_len)
+        stm = ShortTermMemory(num_layers, embed_dim, stm_size)
+
+        if att_type in ['mha', 'gqa', 'mqa']:
+            att_init = lambda: init_attention(
+                embed_dim, att_heads, att_type, att_groups, rope=rope,
+                use_flash_attention=use_flash_attention, dropout=att_dropout,
+                max_seq_len=seq_len, is_causal=False, rope_only_for_keys=True
+            )
+        else:
+            att_init = lambda: init_experimental_attention(
+                embed_dim, att_heads, att_type, att_groups, rope=rope,
+                use_flash_attention=use_flash_attention, dropout=att_dropout,
+                max_seq_len=seq_len, is_causal=False, num_experts=att_experts,
+                num_query_experts=att_query_experts, num_query_groups=att_query_groups,
+                rope_only_for_keys=True
+            )
+
+        memory_norm_layers = nn.ModuleList([init_memory_norm(norm_type, embed_dim, stm_size, decay=norm_decay,
+                                                             init_gate=norm_init_gate, per_dim_scale=norm_per_dim_scale)
+                                            for _ in range(num_layers)])
+        memory_input_norm_layers = nn.ModuleList(nn.RMSNorm(embed_dim) for _ in range(num_layers))
+        attention_layers = nn.ModuleList([att_init() for _ in range(num_layers)])
+        residual_gates = nn.ModuleList([
+            ResidualGate(
+                stm_size, use_gate=use_gated_residual, gate_type=residual_gate_type,
+                per_slot_gate=residual_per_slot_gate,
+                init_gate=residual_gate_init, use_tanh_gate=use_tanh_residual_gate
+            ) for _ in range(num_layers)
+        ])
+
+        # Interlayer attention
+        if interlayer_att_type in ['mha', 'gqa', 'mqa']:
+            interlayer_att_init = lambda: init_attention(
+                embed_dim, att_heads, interlayer_att_type, interlayer_att_groups, rope=None,
+                use_flash_attention=use_flash_attention, dropout=interlayer_att_dropout, is_causal=False
+            )
+        else:
+            interlayer_att_init = lambda: init_experimental_attention(
+                embed_dim, att_heads, interlayer_att_type, interlayer_att_groups, rope=None,
+                use_flash_attention=use_flash_attention, dropout=interlayer_att_dropout, is_causal=False,
+                num_experts=interlayer_att_experts, num_query_experts=interlayer_att_query_experts, num_query_groups=interlayer_att_query_groups
+            )
+
+        mean_attention_layers = nn.ModuleList([interlayer_att_init() for _ in range(num_layers)])
+
+        mean_stm_norm = init_memory_norm(
+            norm_type, embed_dim, stm_size, decay=norm_decay,
+            init_gate=norm_init_gate, per_dim_scale=norm_per_dim_scale
+        )
+
+        mean_memory_norm_layers = nn.ModuleList([init_memory_norm(norm_type, embed_dim, stm_size, decay=norm_decay,
+                                                                  init_gate=norm_init_gate, per_dim_scale=norm_per_dim_scale)
+                                                 for _ in range(num_layers)])
+
+        mean_residual_gates = nn.ModuleList([
+            ResidualGate(
+                stm_size, use_gate=use_gated_residual, gate_type=residual_gate_type,
+                per_slot_gate=residual_per_slot_gate,
+                init_gate=residual_gate_init, use_tanh_gate=use_tanh_residual_gate
+            ) for _ in range(num_layers)
+        ])
+
+        self.model = InterlayerStmMemoryAttention(
+            stm, attention_layers, memory_norm_layers, memory_input_norm_layers, residual_gates,
+            mean_attention_layers, mean_memory_norm_layers, mean_residual_gates, mean_stm_norm,
+            debug_mode=debug_mode, debug_interval=debug_interval,
         )
 
     def freeze(self):
@@ -327,6 +470,7 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
         return self.model(x, attention_mask=attention_mask)
 
+
 class RxTAlphaCriticEncoder(RxTAlphaComponentBase, pipeline_tag="text-classification", license="apache-2.0"):
     """RxT-Alpha (Reactive Transformer) encoder model"""
 
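The net effect on configuration is that the old residual_init_gate / use_dynamic_residual_gate keywords are gone and gating is now described in ResidualGate terms (residual_gate_type, residual_gate_init, residual_per_slot_gate, use_tanh_residual_gate), with norm_type defaulting to 'classic-rms'. A hedged construction sketch for the new interlayer variant, whose full signature appears in the hunk above (the values are placeholders, not recommended settings, and the call assumes rxnn 0.2.73 is installed):

from rxnn.rxt.models import RxTAlphaInterlayerMemoryAttention

# Placeholder sizes; keyword names follow the 0.2.73 signature shown above.
mem_att = RxTAlphaInterlayerMemoryAttention(
    num_layers=6,
    embed_dim=256,
    att_heads=8,
    seq_len=512,
    stm_size=256,
    att_type='sqa',               # memory attention over encoder layer outputs
    interlayer_att_type='sqa',    # attention against the mean of all STM layers
    norm_type='classic-rms',      # new default (previously 'rms')
    use_gated_residual=True,
    residual_gate_type='static',  # 'static', 'elementwise' or 'linear' (ResidualGateType)
    residual_per_slot_gate=True,
    residual_gate_init=3.0,
    use_tanh_residual_gate=True,
)
print(sum(p.numel() for p in mem_att.parameters()))  # rough parameter-count check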
{rxnn-0.2.71 → rxnn-0.2.73}/src/rxnn/training/mrl.py
@@ -592,7 +592,7 @@ class MRLTrainer:
 
         router_loss = actor.moe_router_loss()
         if torch.isnan(router_loss).any():
-            print("NaN detected in router loss")
+            print("!!!!!!!!!!!!!!!!!!!!!! NaN detected in router loss")
         if router_loss is not None:
             return main_loss + self.moe_aux_loss_scale * router_loss
         else:
@@ -671,7 +671,7 @@ class MRLTrainer:
                 # 4.4 Unscale and clip gradient norms
                 self.scaler.unscale_(self.optimizer)
                 torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
-                                               error_if_nonfinite=self.debug_mode)
+                                               error_if_nonfinite=False)
                 if self.debug_mode and self.epoch_step['train'] % self.debug_interval == 0:
                     self._log_gradients(logits)
                 # 4.5 Run scaled optimization step
@@ -692,7 +692,7 @@ class MRLTrainer:
                 policy_loss.backward(retain_graph=True)
                 # 4.4 Clip gradient norms
                 torch.nn.utils.clip_grad_norm_(self.actor.unique_parameters(), max_norm=1.0,
-                                               error_if_nonfinite=self.debug_mode)
+                                               error_if_nonfinite=False)
                 if self.debug_mode and self.epoch_step['train'] % self.debug_interval == 0:
                     self._log_gradients(logits)
                 # 4.5 Run scaled optimization step
{rxnn-0.2.71 → rxnn-0.2.73}/src/rxnn/transformers/layers.py
@@ -1,5 +1,7 @@
 import torch
 import torch.nn as nn
+from poetry.console.commands import self
+
 from .attention import MultiHeadAttention
 from .ff import FeedForward, GatedFeedForward
 from .moe import MoeFeedForward, GatedMoeFeedForward
@@ -49,10 +51,12 @@ class ReactiveTransformerLayer(nn.Module):
             self.norm1 = nn.RMSNorm(embed_dim)
             self.norm2 = nn.RMSNorm(embed_dim)
             self.norm3 = nn.RMSNorm(embed_dim)
+            self.stm_norm = nn.RMSNorm(embed_dim)
         else:
             self.norm1 = nn.LayerNorm(embed_dim)
             self.norm2 = nn.LayerNorm(embed_dim)
             self.norm3 = nn.LayerNorm(embed_dim)
+            self.stm_norm = nn.LayerNorm(embed_dim)
         self.use_post_norm = use_post_norm
         self.use_moe = use_moe
         self.use_moe_att = use_moe_att
@@ -63,9 +67,11 @@ class ReactiveTransformerLayer(nn.Module):
         if with_norms:
             for param in self.norm2.parameters():
                 param.requires_grad_(is_trainable)
+            for param in self.stm_norm.parameters():
+                param.requires_grad_(is_trainable)
 
     def memory_parameters(self) -> list[nn.Parameter]:
-        return list(self.memory_cross_attention.parameters()) + list(self.norm2.parameters())
+        return list(self.memory_cross_attention.parameters()) + list(self.norm2.parameters()) + list(self.stm_norm.parameters())
 
     def not_memory_parameters(self) -> list[nn.Parameter]:
         return (list(self.attention.parameters()) + list(self.norm1.parameters()) +
@@ -102,11 +108,8 @@ class ReactiveTransformerLayer(nn.Module):
         residual = x
         if not self.use_post_norm:
             x = self.norm1(x)
-        if torch.isnan(x).any():
-            print("NaN detected in pre-norm (self-attention) output")
         x = self.attention(x, x, x, mask=mask)
-        if torch.isnan(x).any():
-            print("NaN detected in self-attention output")
+
         x = residual + x
         if self.use_post_norm:
             x = self.norm1(x)
@@ -114,18 +117,13 @@ class ReactiveTransformerLayer(nn.Module):
         residual = x
         if not self.use_post_norm:
             x = self.norm2(x)
-        if torch.isnan(x).any():
-            print("NaN detected in pre-norm (cross-attention) output")
 
+        # normalize STM and prepare STM mask
+        stm = self.stm_norm(stm)
         mem_mask = mask.squeeze(1).unsqueeze(-1).expand(-1, -1, -1, stm.size(1)) \
             if mask is not None else None
 
-        if torch.isnan(stm).any():
-            print("NaN detected in STM cross-attention input")
-
         x = self.memory_cross_attention(x, stm, stm, mask=mem_mask)
-        if torch.isnan(x).any():
-            print("NaN detected in cross-attention output")
         x = residual + x
         if self.use_post_norm:
             x = self.norm2(x)
@@ -134,11 +132,7 @@ class ReactiveTransformerLayer(nn.Module):
         residual = x
         if not self.use_post_norm:
             x = self.norm3(x)
-        if torch.isnan(x).any():
-            print("NaN detected in pre-norm (ff) output")
         x = self.ff(x)
-        if torch.isnan(x).any():
-            print("NaN detected in ff output")
         x = residual + x
         if self.use_post_norm:
             x = self.norm3(x)
{rxnn-0.2.71 → rxnn-0.2.73}/src/rxnn/transformers/models.py
@@ -94,7 +94,7 @@ class ReactiveTransformerDecoder(ReactiveTransformerBase):
     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
         x = super().forward(x)  # apply embeddings
         if torch.isnan(x).any():
-            print("NaN detected in decoder embedding output")
+            print("!!!!!!!!!!!!!!!!!!!!!! NaN detected in decoder embedding output")
         seq_len = x.size(1)
         if not self.use_flash_attention and self.use_relative_embedding:
             mask = create_causal_mask(seq_len, device=x.device)
@@ -112,7 +112,7 @@ class ReactiveTransformerDecoder(ReactiveTransformerBase):
         for i in range(self.num_own_layers):
             x = self._handle_layer(i, x, mask=mask)
             if torch.isnan(x).any():
-                print(f"NaN detected in {i}. decoder layer output")
+                print(f"!!!!!!!!!!!!!!!!!!!!!! NaN detected in {i}. decoder layer output")
         return self.head(self.head_norm(x) if self.use_head_norm else x)
 
 
@@ -122,7 +122,7 @@ class ReactiveTransformerEncoder(ReactiveTransformerBase):
     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]:
         x = super().forward(x)  # apply embeddings
         if torch.isnan(x).any():
-            print("NaN detected in encoder embedding output")
+            print("!!!!!!!!!!!!!!!!!!!!!! NaN detected in encoder embedding output")
         if attention_mask is not None:
             attention_mask = attention_mask.unsqueeze(1).unsqueeze(1).bool()
 
@@ -136,7 +136,7 @@ class ReactiveTransformerEncoder(ReactiveTransformerBase):
         for i in range(self.num_own_layers):
            x = self._handle_layer(i, x, mask=attention_mask)
            if torch.isnan(x).any():
-               print(f"NaN detected in {i}. encoder layer output")
+               print(f"!!!!!!!!!!!!!!!!!!!!!! NaN detected in {i}. encoder layer output")
            hidden_states.append(x)
         return x, torch.stack(hidden_states)
 
rxnn-0.2.71/src/rxnn/memory/attention.py (removed)
@@ -1,89 +0,0 @@
-import torch
-import torch.nn as nn
-from .stm import ShortTermMemory
-
-class StmMemoryAttention(nn.Module):
-    def __init__(
-            self,
-            stm: ShortTermMemory,
-            attention_layers: nn.ModuleList,
-            memory_norm_layers: nn.ModuleList,
-            use_gated_residual: bool = False,
-            per_slot_gate: bool = False,
-            init_gate: float = 0.0,
-            use_dynamic_gate: bool = False,
-            use_tanh_gate: bool = False,
-            debug_mode: bool = False,
-            debug_interval: int = 10,
-            *args,
-            **kwargs
-    ):
-        super(StmMemoryAttention, self).__init__(*args, **kwargs)
-        self.stm = stm
-        self.attention_layers = attention_layers
-        self.memory_norm_layers = memory_norm_layers
-        assert len(self.attention_layers) == len(self.memory_norm_layers) == self.stm.memory.size(0)
-        self.num_layers = len(attention_layers)
-        self.use_gated_residual = use_gated_residual
-        self.per_slot_gate = per_slot_gate
-        self.use_dynamic_gate = use_dynamic_gate
-        self.use_tanh_gate = use_tanh_gate
-        if self.use_gated_residual:
-            gate_shape = (self.num_layers, self.stm.stm_size, 1) if self.per_slot_gate else (self.num_layers,)
-            self.gate = nn.Parameter(torch.full(gate_shape, init_gate))
-
-        self.debug_mode = debug_mode
-        self.debug_interval = debug_interval
-        self.debug_step = 0
-
-    def update_max_len(self, max_seq_len: int):
-        for i in range(self.num_layers):
-            if self.attention_layers[i].rope is not None:
-                self.attention_layers[i].rope.update_max_len(max_seq_len)
-
-    def _residual_gate(self, gate: torch.Tensor, layer_stm: torch.Tensor, new_layer_stm: torch.Tensor) -> torch.Tensor:
-        if self.use_dynamic_gate:
-            mean_dim = -1 if self.per_slot_gate else [1, 2]
-            gate_input = gate * (new_layer_stm + layer_stm).mean(dim=mean_dim, keepdim=True)
-            layer_gate = torch.tanh(gate_input) if self.use_tanh_gate else torch.sigmoid(gate_input)
-        else:
-            layer_gate = torch.tanh(gate) if self.use_tanh_gate else torch.sigmoid(gate)
-        if self.use_tanh_gate:
-            return (1 + layer_gate) * new_layer_stm + (1 - layer_gate) * layer_stm
-        else:
-            return layer_gate * new_layer_stm + (1 - layer_gate) * layer_stm
-
-    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
-        if attention_mask is not None:
-            attention_mask = attention_mask.unsqueeze(1).unsqueeze(1).bool()
-        new_stm = torch.zeros_like(self.stm.memory)
-        for i in range(self.num_layers):
-            layer_stm = self.stm(i)
-            # expand layer STM to batch size, if it's not in batch mode
-            if layer_stm.size(0) == 1:
-                layer_stm = layer_stm.expand(x.size(0), -1, -1)
-            encoded_layer_data = x[i]
-            normalized_layer_stm = self.memory_norm_layers[i](layer_stm)
-            if torch.isnan(normalized_layer_stm).any():
-                print(f"NaN detected in {i} layer memory norm output")
-
-            if self.debug_mode and self.training:
-                if self.debug_step != 0 and self.debug_step % self.debug_interval == 0:
-                    self.debug_step = 0
-                    print(f"Normalized STM stats - mean: {normalized_layer_stm.mean().item():.4f}, std: {normalized_layer_stm.std().item():.4f}")
-                else:
-                    self.debug_step += 1
-
-            if torch.isnan(encoded_layer_data).any():
-                print(f"NaN detected in {i} layer encoded data input")
-
-            new_layer_stm = self.attention_layers[i](normalized_layer_stm, encoded_layer_data, encoded_layer_data, mask=attention_mask)
-            if torch.isnan(new_layer_stm).any():
-                print(f"NaN detected in {i} layer memory attention output")
-
-            if self.use_gated_residual:
-                new_stm[i] = self._residual_gate(self.gate[i], layer_stm, new_layer_stm) # gated residual
-            else:
-                new_stm[i] = new_layer_stm + layer_stm # residual
-        self.stm.update_all(new_stm)
-        return self.stm.memory