rxnn 0.2.72__tar.gz → 0.2.74__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. {rxnn-0.2.72 → rxnn-0.2.74}/PKG-INFO +1 -1
  2. {rxnn-0.2.72 → rxnn-0.2.74}/pyproject.toml +1 -1
  3. rxnn-0.2.74/src/rxnn/memory/attention.py +150 -0
  4. rxnn-0.2.74/src/rxnn/memory/gate.py +60 -0
  5. {rxnn-0.2.72 → rxnn-0.2.74}/src/rxnn/rxt/models.py +152 -9
  6. {rxnn-0.2.72 → rxnn-0.2.74}/src/rxnn/transformers/layers.py +8 -16
  7. {rxnn-0.2.72 → rxnn-0.2.74}/src/rxnn/utils.py +0 -5
  8. rxnn-0.2.72/src/rxnn/memory/attention.py +0 -89
  9. {rxnn-0.2.72 → rxnn-0.2.74}/LICENSE +0 -0
  10. {rxnn-0.2.72 → rxnn-0.2.74}/README.md +0 -0
  11. {rxnn-0.2.72 → rxnn-0.2.74}/src/rxnn/.DS_Store +0 -0
  12. {rxnn-0.2.72 → rxnn-0.2.74}/src/rxnn/__init__.py +0 -0
  13. {rxnn-0.2.72 → rxnn-0.2.74}/src/rxnn/experimental/__init__.py +0 -0
  14. {rxnn-0.2.72 → rxnn-0.2.74}/src/rxnn/experimental/attention.py +0 -0
  15. {rxnn-0.2.72 → rxnn-0.2.74}/src/rxnn/experimental/models.py +0 -0
  16. {rxnn-0.2.72 → rxnn-0.2.74}/src/rxnn/experimental/moe.py +0 -0
  17. {rxnn-0.2.72 → rxnn-0.2.74}/src/rxnn/memory/__init__.py +0 -0
  18. {rxnn-0.2.72 → rxnn-0.2.74}/src/rxnn/memory/norm.py +0 -0
  19. {rxnn-0.2.72 → rxnn-0.2.74}/src/rxnn/memory/stm.py +0 -0
  20. {rxnn-0.2.72 → rxnn-0.2.74}/src/rxnn/rxt/__init__.py +0 -0
  21. {rxnn-0.2.72 → rxnn-0.2.74}/src/rxnn/training/__init__.py +0 -0
  22. {rxnn-0.2.72 → rxnn-0.2.74}/src/rxnn/training/base.py +0 -0
  23. {rxnn-0.2.72 → rxnn-0.2.74}/src/rxnn/training/bml.py +0 -0
  24. {rxnn-0.2.72 → rxnn-0.2.74}/src/rxnn/training/callbacks.py +0 -0
  25. {rxnn-0.2.72 → rxnn-0.2.74}/src/rxnn/training/dataset.py +0 -0
  26. {rxnn-0.2.72 → rxnn-0.2.74}/src/rxnn/training/ddp.py +0 -0
  27. {rxnn-0.2.72 → rxnn-0.2.74}/src/rxnn/training/models.py +0 -0
  28. {rxnn-0.2.72 → rxnn-0.2.74}/src/rxnn/training/mrl.py +0 -0
  29. {rxnn-0.2.72 → rxnn-0.2.74}/src/rxnn/training/reward.py +0 -0
  30. {rxnn-0.2.72 → rxnn-0.2.74}/src/rxnn/training/rl.py +0 -0
  31. {rxnn-0.2.72 → rxnn-0.2.74}/src/rxnn/training/scheduler.py +0 -0
  32. {rxnn-0.2.72 → rxnn-0.2.74}/src/rxnn/training/tokenizer.py +0 -0
  33. {rxnn-0.2.72 → rxnn-0.2.74}/src/rxnn/training/utils.py +0 -0
  34. {rxnn-0.2.72 → rxnn-0.2.74}/src/rxnn/transformers/__init__.py +0 -0
  35. {rxnn-0.2.72 → rxnn-0.2.74}/src/rxnn/transformers/attention.py +0 -0
  36. {rxnn-0.2.72 → rxnn-0.2.74}/src/rxnn/transformers/ff.py +0 -0
  37. {rxnn-0.2.72 → rxnn-0.2.74}/src/rxnn/transformers/mask.py +0 -0
  38. {rxnn-0.2.72 → rxnn-0.2.74}/src/rxnn/transformers/models.py +0 -0
  39. {rxnn-0.2.72 → rxnn-0.2.74}/src/rxnn/transformers/moe.py +0 -0
  40. {rxnn-0.2.72 → rxnn-0.2.74}/src/rxnn/transformers/positional.py +0 -0
  41. {rxnn-0.2.72 → rxnn-0.2.74}/src/rxnn/transformers/sampler.py +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.3
  Name: rxnn
- Version: 0.2.72
+ Version: 0.2.74
  Summary: RxNN: Reactive Neural Networks Platform
  License: Apache-2.0
  Keywords: deep-learning,ai,machine-learning
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"

  [tool.poetry]
  name = "rxnn"
- version = "0.2.72"
+ version = "0.2.74"
  description = "RxNN: Reactive Neural Networks Platform"

  license = "Apache-2.0"
@@ -0,0 +1,150 @@
+ import torch
+ import torch.nn as nn
+ from .stm import ShortTermMemory
+
+
+ class StmMemoryAttention(nn.Module):
+     def __init__(
+             self,
+             stm: ShortTermMemory,
+             attention_layers: nn.ModuleList,
+             memory_norm_layers: nn.ModuleList,
+             memory_input_norm_layers: nn.ModuleList,
+             residual_gate_layers: nn.ModuleList,
+             debug_mode: bool = False,
+             debug_interval: int = 10,
+             *args,
+             **kwargs
+     ):
+         super(StmMemoryAttention, self).__init__(*args, **kwargs)
+         self.stm = stm
+         self.attention_layers = attention_layers
+         self.memory_norm_layers = memory_norm_layers
+         self.memory_input_norm_layers = memory_input_norm_layers
+         self.residual_gate_layers = residual_gate_layers
+         assert (len(self.attention_layers) == len(self.memory_norm_layers) ==
+                 len(self.residual_gate_layers) == len(self.memory_input_norm_layers) ==
+                 self.stm.memory.size(0))
+         self.num_layers = len(attention_layers)
+         self.debug_mode = debug_mode
+         self.debug_interval = debug_interval
+         self.debug_step = 0
+
+     def update_max_len(self, max_seq_len: int):
+         for i in range(self.num_layers):
+             if self.attention_layers[i].rope is not None:
+                 self.attention_layers[i].rope.update_max_len(max_seq_len)
+
+     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
+         # 1. Process correct attention mask
+         if attention_mask is not None:
+             attention_mask = attention_mask.unsqueeze(1).unsqueeze(1).bool()
+         # 2. Init new empty STM
+         new_stm = torch.zeros_like(self.stm.memory)
+
+         # 3. Run Short-Term Memory update for all layers
+         for i in range(self.num_layers):
+             # 4. Get current layer STM value
+             layer_stm = self.stm(i)
+             # 5. Expand layer STM to batch size, if it's not in batch mode
+             if layer_stm.size(0) == 1:
+                 layer_stm = layer_stm.expand(x.size(0), -1, -1)
+
+             # 6. Get encoded layer data and normalize it
+             encoded_layer_data = self.memory_input_norm_layers[i](x[i])
+             # 7. Normalize STM layer
+             normalized_layer_stm = self.memory_norm_layers[i](layer_stm)
+
+             # 8. Print normalization stats in debug mode
+             if self.debug_mode and self.training:
+                 if self.debug_step != 0 and self.debug_step % self.debug_interval == 0:
+                     self.debug_step = 0
+                     print(f"Normalized STM stats - mean: {normalized_layer_stm.mean().item():.4f}, std: {normalized_layer_stm.std().item():.4f}")
+                 else:
+                     self.debug_step += 1
+
+             # 9. Calculate memory attention
+             new_layer_stm = self.attention_layers[i](normalized_layer_stm, encoded_layer_data, encoded_layer_data, mask=attention_mask)
+             # 10. Combine new updated layer state with current STM state in residual gate
+             new_stm[i] = self.residual_gate_layers[i](layer_stm, new_layer_stm)  # residual
+         # 11. Update all layers/models
+         self.stm.update_all(new_stm)
+         return self.stm.memory
+
+
+ class InterlayerStmMemoryAttention(StmMemoryAttention):
+     def __init__(
+             self,
+             stm: ShortTermMemory,
+             attention_layers: nn.ModuleList,
+             memory_norm_layers: nn.ModuleList,
+             memory_input_norm_layers: nn.ModuleList,
+             residual_gate_layers: nn.ModuleList,
+             mean_attention_layers: nn.ModuleList,
+             mean_memory_norm_layers: nn.ModuleList,
+             mean_residual_gate_layers: nn.ModuleList,
+             mean_stm_norm: nn.Module,
+             debug_mode: bool = False,
+             debug_interval: int = 10,
+             **kwargs
+     ):
+         super(InterlayerStmMemoryAttention, self).__init__(
+             stm, attention_layers, memory_norm_layers, memory_input_norm_layers, residual_gate_layers,
+             debug_mode=debug_mode, debug_interval=debug_interval, **kwargs
+         )
+         self.mean_attention_layers = mean_attention_layers
+         self.mean_memory_norm_layers = mean_memory_norm_layers
+         self.mean_stm_norm = mean_stm_norm
+         self.mean_residual_gate_layers = mean_residual_gate_layers
+         assert (len(self.mean_attention_layers) == len(self.mean_memory_norm_layers) ==
+                 len(self.mean_residual_gate_layers) == self.num_layers)
+
+     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
+         # 1. Process correct attention mask
+         if attention_mask is not None:
+             attention_mask = attention_mask.unsqueeze(1).unsqueeze(1).bool()
+         # 2. Init new empty STM
+         new_stm = torch.zeros_like(self.stm.memory)
+
+         # 3. Get mean STM value from layers for mean interlayer memory attention
+         mean_stm = self.stm.memory.mean(dim=0)  # [batch_size, stm_size, embed_dim]
+         # 4. Normalize mean STM layer
+         normalized_mean_stm = self.mean_stm_norm(mean_stm)
+
+         # 5. Run Short-Term Memory update for all layers
+         for i in range(self.num_layers):
+             # 6. Get current layer STM value
+             layer_stm = self.stm(i)
+             # 7. Expand layer STM to batch size, if it's not in batch mode
+             if layer_stm.size(0) == 1:
+                 layer_stm = layer_stm.expand(x.size(0), -1, -1)
+
+             # 8. Mean interlayer memory attention
+             # a) normalize STM layer value
+             pre_normalized_layer_stm = self.mean_memory_norm_layers[i](layer_stm)
+             # b) calculate attention between STM layer and mean value of all STM layers (from previous interaction)
+             interlayer_stm = self.mean_attention_layers[i](pre_normalized_layer_stm, normalized_mean_stm, normalized_mean_stm, mask=None)
+             # c) combine updated interlayer state with current STM state in residual gate
+             updated_layer_stm = self.mean_residual_gate_layers[i](layer_stm, interlayer_stm)
+
+             # 9. Main memory attention
+             # a) get encoded data for current layer and normalize it
+             encoded_layer_data = self.memory_input_norm_layers[i](x[i])
+             # b) normalize STM layer value
+             normalized_layer_stm = self.memory_norm_layers[i](updated_layer_stm)
+             # c) print normalized STM stats in debug mode
+             if self.debug_mode and self.training:
+                 if self.debug_step != 0 and self.debug_step % self.debug_interval == 0:
+                     self.debug_step = 0
+                     print(
+                         f"Pre-Normalized STM stats - mean: {pre_normalized_layer_stm.mean().item():.4f}, std: {pre_normalized_layer_stm.std().item():.4f}")
+                     print(f"Normalized STM stats - mean: {normalized_layer_stm.mean().item():.4f}, std: {normalized_layer_stm.std().item():.4f}")
+                 else:
+                     self.debug_step += 1
+             # d) calculate memory attention between STM layer and encoded data
+             new_layer_stm = self.attention_layers[i](normalized_layer_stm, encoded_layer_data, encoded_layer_data, mask=attention_mask)
+             # e) combine new updated layer STM with previous state in residual gate
+             new_stm[i] = self.residual_gate_layers[i](updated_layer_stm, new_layer_stm)  # residual
+         # 10. Update all layers/models
+         self.stm.update_all(new_stm)
+         return self.stm.memory
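The per-layer update pattern in the new src/rxnn/memory/attention.py (normalize the layer's STM slots and the encoded input, cross-attend from memory to the encoded data, then blend old and new state) can be illustrated with a minimal, self-contained sketch. This is not the rxnn API: it uses nn.MultiheadAttention as a stand-in for the package's attention layers, a plain residual instead of ResidualGate, and shapes chosen only for illustration.

import torch
import torch.nn as nn

# Hypothetical sizes, not taken from rxnn defaults.
batch, stm_size, seq_len, embed_dim, heads = 2, 64, 256, 512, 16

stm_slots = torch.randn(batch, stm_size, embed_dim)   # one layer's STM state
encoded = torch.randn(batch, seq_len, embed_dim)      # encoder output for that layer

stm_norm = nn.RMSNorm(embed_dim)     # stands in for memory_norm_layers[i]
input_norm = nn.RMSNorm(embed_dim)   # stands in for memory_input_norm_layers[i]
attn = nn.MultiheadAttention(embed_dim, heads, batch_first=True)  # stand-in attention

q = stm_norm(stm_slots)              # queries come from the (normalized) memory
kv = input_norm(encoded)             # keys/values come from the encoded interaction
update, _ = attn(q, kv, kv)          # memory cross-attention

new_stm_slots = stm_slots + update   # ungated residual; rxnn routes this through ResidualGate
print(new_stm_slots.shape)           # torch.Size([2, 64, 512])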
@@ -0,0 +1,60 @@
+ import torch
+ import torch.nn as nn
+ from typing import TypeAlias, Literal
+
+ ResidualGateType: TypeAlias = Literal['static', 'elementwise', 'linear']
+
+
+ class ResidualGate(nn.Module):
+     def __init__(
+             self,
+             stm_size: int,
+             use_gate: bool = False,
+             gate_type: ResidualGateType = 'static',
+             per_slot_gate: bool = True,
+             init_gate: float = 0.0,
+             use_tanh_gate: bool = True,
+             **kwargs,
+     ):
+         super(ResidualGate, self).__init__(**kwargs)
+         self.use_gate = use_gate
+         self.per_slot_gate = per_slot_gate
+         self.gate_type = gate_type
+         self.use_tanh_gate = use_tanh_gate
+
+         if self.use_gate:
+             if self.gate_type == 'linear':
+                 self.gate = nn.Linear(stm_size, stm_size if self.per_slot_gate else 1)
+             else:
+                 gate_shape = (stm_size, 1) if self.per_slot_gate else (1,)
+                 self.gate = nn.Parameter(torch.full(gate_shape, init_gate))
+         else:
+             self.gate = None
+
+         self.gate_activation = nn.Tanh() if self.use_tanh_gate else nn.Sigmoid()
+
+     def _dynamic_gate(self, old_value: torch.Tensor, new_value: torch.Tensor):
+         if self.gate_type == 'linear':
+             mean_residual = (new_value + old_value).mean(dim=-1)
+             gate_input = self.gate(mean_residual).unsqueeze(-1)
+         else:
+             mean_dim = -1 if self.per_slot_gate else [1, 2]
+             gate_input = self.gate * (new_value + old_value).mean(dim=mean_dim, keepdim=True)
+         return self.gate_activation(gate_input)
+
+     def _calculate_output(self, layer_gate: torch.Tensor, old_value: torch.Tensor, new_value: torch.Tensor) -> torch.Tensor:
+         if self.use_tanh_gate:
+             return (1 + layer_gate) * new_value + (1 - layer_gate) * old_value
+         else:
+             return layer_gate * new_value + (1 - layer_gate) * old_value
+
+     def forward(self, old_value: torch.Tensor, new_value: torch.Tensor) -> torch.Tensor:
+         if not self.use_gate:
+             return new_value + old_value
+
+         if self.gate_type == 'static':
+             layer_gate = self.gate_activation(self.gate)
+         else:
+             layer_gate = self._dynamic_gate(old_value, new_value)
+
+         return self._calculate_output(layer_gate, old_value, new_value)
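A short usage sketch of the new ResidualGate from src/rxnn/memory/gate.py as defined above. The module blends a previous and an updated memory state; with use_gate=False it reduces to a plain residual sum. The tensor shapes are illustrative assumptions.

import torch
from rxnn.memory.gate import ResidualGate

stm_size, embed_dim = 1024, 512
old_stm = torch.randn(2, stm_size, embed_dim)   # previous layer STM state (assumed shape)
new_stm = torch.randn(2, stm_size, embed_dim)   # freshly attended update

# Static, per-slot tanh gate (matching the new defaults in RxTAlphaMemoryAttention).
gate = ResidualGate(stm_size, use_gate=True, gate_type='static',
                    per_slot_gate=True, init_gate=3.0, use_tanh_gate=True)
blended = gate(old_stm, new_stm)                # (1 + tanh(g)) * new + (1 - tanh(g)) * old
print(blended.shape)                            # torch.Size([2, 1024, 512])

# With the gate disabled the module is just a residual sum.
plain = ResidualGate(stm_size, use_gate=False)
assert torch.allclose(plain(old_stm, new_stm), old_stm + new_stm)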
@@ -9,7 +9,8 @@ from ..transformers.models import ReactiveTransformerBase, ReactiveTransformerEn
  from ..transformers.ff import get_activation_layer
  from ..memory.stm import ShortTermMemory
  from ..memory.norm import init_memory_norm
- from ..memory.attention import StmMemoryAttention
+ from ..memory.attention import StmMemoryAttention, InterlayerStmMemoryAttention
+ from ..memory.gate import ResidualGate, ResidualGateType
  from ..utils import get_model_size
  from ..experimental.attention import init_experimental_attention

@@ -260,15 +261,15 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
              att_experts: int = None,
              att_query_experts: int = None,
              att_query_groups: int = None,
-             norm_type: str = 'rms',
+             norm_type: str = 'classic-rms',
              norm_init_gate: float = -2.0,
              norm_per_dim_scale: bool = False,
              norm_decay: float = 0.9,
              use_gated_residual: bool = False,
-             residual_per_slot_gate: bool = False,
-             residual_init_gate: float = 0.0,
-             use_dynamic_residual_gate: bool = False,
-             use_tanh_residual_gate: bool = False,
+             residual_per_slot_gate: bool = True,
+             residual_gate_init: float = 3.0,
+             residual_gate_type: ResidualGateType = 'static',
+             use_tanh_residual_gate: bool = True,
              debug_mode: bool = False,
              debug_interval: int = 10,
              **kwargs,
@@ -296,12 +297,19 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
          memory_norm_layers = nn.ModuleList([init_memory_norm(norm_type, embed_dim, stm_size, decay=norm_decay,
                                                                init_gate=norm_init_gate, per_dim_scale=norm_per_dim_scale)
                                              for _ in range(num_layers)])
+         memory_input_norm_layers = nn.ModuleList(nn.RMSNorm(embed_dim) for _ in range(num_layers))
          attention_layers = nn.ModuleList([att_init() for _ in range(num_layers)])
+         residual_gates = nn.ModuleList([
+             ResidualGate(
+                 stm_size, use_gate=use_gated_residual, gate_type=residual_gate_type, per_slot_gate=residual_per_slot_gate,
+                 init_gate=residual_gate_init, use_tanh_gate=use_tanh_residual_gate
+             ) for _ in range(num_layers)
+         ])
+
          self.model = StmMemoryAttention(
              stm, attention_layers, memory_norm_layers,
-             use_gated_residual=use_gated_residual, per_slot_gate=residual_per_slot_gate,
-             init_gate=residual_init_gate, use_dynamic_gate=use_dynamic_residual_gate,
-             use_tanh_gate=use_tanh_residual_gate, debug_mode=debug_mode, debug_interval=debug_interval,
+             memory_input_norm_layers, residual_gates,
+             debug_mode=debug_mode, debug_interval=debug_interval,
          )

      def freeze(self):
@@ -327,6 +335,141 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
      def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
          return self.model(x, attention_mask=attention_mask)

+
+ class RxTAlphaInterlayerMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2.0"):
+     """RxT-Alpha (Reactive Transformer) memory attention model with interlayer STM attention"""
+
+     def __init__(
+             self,
+             num_layers: int = 12,
+             embed_dim: int = 512,
+             att_heads: int = 16,
+             seq_len: int = 1024,
+             stm_size: int = 1024,
+             use_flash_attention: bool = False,
+             att_dropout: float = 0.0,
+             att_groups: int = 1,
+             att_type: str = 'sqa',
+             att_experts: int = None,
+             att_query_experts: int = None,
+             att_query_groups: int = None,
+             interlayer_att_dropout: float = 0.0,
+             interlayer_att_groups: int = 1,
+             interlayer_att_type: str = 'sqa',
+             interlayer_att_experts: int = None,
+             interlayer_att_query_experts: int = None,
+             interlayer_att_query_groups: int = None,
+             norm_type: str = 'classic-rms',
+             norm_init_gate: float = -2.0,
+             norm_per_dim_scale: bool = False,
+             norm_decay: float = 0.9,
+             use_gated_residual: bool = False,
+             residual_per_slot_gate: bool = True,
+             residual_gate_init: float = 3.0,
+             residual_gate_type: ResidualGateType = 'static',
+             use_tanh_residual_gate: bool = True,
+             debug_mode: bool = False,
+             debug_interval: int = 10,
+             **kwargs,
+     ):
+         super(RxTAlphaInterlayerMemoryAttention, self).__init__(**kwargs)
+
+         assert att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma',
+                             'sqa'], 'Memory attention type could be "mha", "gqa", "mqa", "gma", "dma", "sqa".'
+
+         rope = RotaryPositionalEmbedding(embed_dim // att_heads, seq_len)
+         stm = ShortTermMemory(num_layers, embed_dim, stm_size)
+
+         if att_type in ['mha', 'gqa', 'mqa']:
+             att_init = lambda: init_attention(
+                 embed_dim, att_heads, att_type, att_groups, rope=rope,
+                 use_flash_attention=use_flash_attention, dropout=att_dropout,
+                 max_seq_len=seq_len, is_causal=False, rope_only_for_keys=True
+             )
+         else:
+             att_init = lambda: init_experimental_attention(
+                 embed_dim, att_heads, att_type, att_groups, rope=rope,
+                 use_flash_attention=use_flash_attention, dropout=att_dropout,
+                 max_seq_len=seq_len, is_causal=False, num_experts=att_experts,
+                 num_query_experts=att_query_experts, num_query_groups=att_query_groups,
+                 rope_only_for_keys=True
+             )
+
+         memory_norm_layers = nn.ModuleList([init_memory_norm(norm_type, embed_dim, stm_size, decay=norm_decay,
+                                                               init_gate=norm_init_gate, per_dim_scale=norm_per_dim_scale)
+                                             for _ in range(num_layers)])
+         memory_input_norm_layers = nn.ModuleList(nn.RMSNorm(embed_dim) for _ in range(num_layers))
+         attention_layers = nn.ModuleList([att_init() for _ in range(num_layers)])
+         residual_gates = nn.ModuleList([
+             ResidualGate(
+                 stm_size, use_gate=use_gated_residual, gate_type=residual_gate_type,
+                 per_slot_gate=residual_per_slot_gate,
+                 init_gate=residual_gate_init, use_tanh_gate=use_tanh_residual_gate
+             ) for _ in range(num_layers)
+         ])
+
+         # Interlayer attention
+         if interlayer_att_type in ['mha', 'gqa', 'mqa']:
+             interlayer_att_init = lambda: init_attention(
+                 embed_dim, att_heads, interlayer_att_type, interlayer_att_groups, rope=None,
+                 use_flash_attention=use_flash_attention, dropout=interlayer_att_dropout, is_causal=False
+             )
+         else:
+             interlayer_att_init = lambda: init_experimental_attention(
+                 embed_dim, att_heads, interlayer_att_type, interlayer_att_groups, rope=None,
+                 use_flash_attention=use_flash_attention, dropout=interlayer_att_dropout, is_causal=False,
+                 num_experts=interlayer_att_experts, num_query_experts=interlayer_att_query_experts, num_query_groups=interlayer_att_query_groups
+             )
+
+         mean_attention_layers = nn.ModuleList([interlayer_att_init() for _ in range(num_layers)])
+
+         mean_stm_norm = init_memory_norm(
+             norm_type, embed_dim, stm_size, decay=norm_decay,
+             init_gate=norm_init_gate, per_dim_scale=norm_per_dim_scale
+         )
+
+         mean_memory_norm_layers = nn.ModuleList([init_memory_norm(norm_type, embed_dim, stm_size, decay=norm_decay,
+                                                                    init_gate=norm_init_gate, per_dim_scale=norm_per_dim_scale)
+                                                  for _ in range(num_layers)])
+
+         mean_residual_gates = nn.ModuleList([
+             ResidualGate(
+                 stm_size, use_gate=use_gated_residual, gate_type=residual_gate_type,
+                 per_slot_gate=residual_per_slot_gate,
+                 init_gate=residual_gate_init, use_tanh_gate=use_tanh_residual_gate
+             ) for _ in range(num_layers)
+         ])
+
+         self.model = InterlayerStmMemoryAttention(
+             stm, attention_layers, memory_norm_layers, memory_input_norm_layers, residual_gates,
+             mean_attention_layers, mean_memory_norm_layers, mean_residual_gates, mean_stm_norm,
+             debug_mode=debug_mode, debug_interval=debug_interval,
+         )
+
+     def freeze(self):
+         for param in self.parameters():
+             param.requires_grad = False
+
+     def unfreeze(self):
+         for param in self.parameters():
+             param.requires_grad = True
+
+     def load_shared_memory(self, stm: ShortTermMemory):
+         self.model.stm = stm
+
+     def update_max_len(self, max_seq_len: int):
+         self.model.update_max_len(max_seq_len)
+
+     def reset_memory(self, init_type: str = None):
+         self.model.stm.reset(init_type)
+
+     def clone_reset_memory(self):
+         self.model.stm.clone_detach_reset()
+
+     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
+         return self.model(x, attention_mask=attention_mask)
+
+
  class RxTAlphaCriticEncoder(RxTAlphaComponentBase, pipeline_tag="text-classification", license="apache-2.0"):
      """RxT-Alpha (Reactive Transformer) encoder model"""

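A hedged construction sketch for the new RxTAlphaInterlayerMemoryAttention wrapper added to src/rxnn/rxt/models.py, using only parameters visible in the signature above. The exact layout expected for x in forward follows the memory-attention modules (one encoded tensor per layer), so the call is only indicated in a comment rather than run on concrete data.

from rxnn.rxt.models import RxTAlphaInterlayerMemoryAttention

# All values below are the signature defaults, spelled out for the knobs
# that changed or were introduced in this release.
mem_att = RxTAlphaInterlayerMemoryAttention(
    num_layers=12,
    embed_dim=512,
    att_heads=16,
    stm_size=1024,
    att_type='sqa',
    interlayer_att_type='sqa',
    norm_type='classic-rms',
    use_gated_residual=False,       # set True to enable ResidualGate blending
    residual_gate_type='static',
    residual_gate_init=3.0,
)

mem_att.freeze()                    # disables gradients (simple parameter loop, shown above)
# mem_att(x, attention_mask=mask) runs the interlayer + main memory attention update per layer.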
@@ -49,10 +49,12 @@ class ReactiveTransformerLayer(nn.Module):
              self.norm1 = nn.RMSNorm(embed_dim)
              self.norm2 = nn.RMSNorm(embed_dim)
              self.norm3 = nn.RMSNorm(embed_dim)
+             self.stm_norm = nn.RMSNorm(embed_dim)
          else:
              self.norm1 = nn.LayerNorm(embed_dim)
              self.norm2 = nn.LayerNorm(embed_dim)
              self.norm3 = nn.LayerNorm(embed_dim)
+             self.stm_norm = nn.LayerNorm(embed_dim)
          self.use_post_norm = use_post_norm
          self.use_moe = use_moe
          self.use_moe_att = use_moe_att
@@ -63,9 +65,11 @@ class ReactiveTransformerLayer(nn.Module):
          if with_norms:
              for param in self.norm2.parameters():
                  param.requires_grad_(is_trainable)
+             for param in self.stm_norm.parameters():
+                 param.requires_grad_(is_trainable)

      def memory_parameters(self) -> list[nn.Parameter]:
-         return list(self.memory_cross_attention.parameters()) + list(self.norm2.parameters())
+         return list(self.memory_cross_attention.parameters()) + list(self.norm2.parameters()) + list(self.stm_norm.parameters())

      def not_memory_parameters(self) -> list[nn.Parameter]:
          return (list(self.attention.parameters()) + list(self.norm1.parameters()) +
@@ -102,11 +106,8 @@ class ReactiveTransformerLayer(nn.Module):
          residual = x
          if not self.use_post_norm:
              x = self.norm1(x)
-         if torch.isnan(x).any():
-             print("!!!!!!!!!!!!!!!!!!!!!! !!!!!!!!!!!!!!!!!!!!!! NaN detected in pre-norm (self-attention) output")
          x = self.attention(x, x, x, mask=mask)
-         if torch.isnan(x).any():
-             print("!!!!!!!!!!!!!!!!!!!!!! !!!!!!!!!!!!!!!!!!!!!! NaN detected in self-attention output")
+
          x = residual + x
          if self.use_post_norm:
              x = self.norm1(x)
@@ -114,18 +115,13 @@ class ReactiveTransformerLayer(nn.Module):
          residual = x
          if not self.use_post_norm:
              x = self.norm2(x)
-         if torch.isnan(x).any():
-             print("!!!!!!!!!!!!!!!!!!!!!! NaN detected in pre-norm (cross-attention) output")

+         # normalize STM and prepare STM mask
+         stm = self.stm_norm(stm)
          mem_mask = mask.squeeze(1).unsqueeze(-1).expand(-1, -1, -1, stm.size(1)) \
              if mask is not None else None

-         if torch.isnan(stm).any():
-             print("!!!!!!!!!!!!!!!!!!!!!! NaN detected in STM cross-attention input")
-
          x = self.memory_cross_attention(x, stm, stm, mask=mem_mask)
-         if torch.isnan(x).any():
-             print("!!!!!!!!!!!!!!!!!!!!!! NaN detected in cross-attention output")
          x = residual + x
          if self.use_post_norm:
              x = self.norm2(x)
@@ -134,11 +130,7 @@ class ReactiveTransformerLayer(nn.Module):
          residual = x
          if not self.use_post_norm:
              x = self.norm3(x)
-         if torch.isnan(x).any():
-             print("!!!!!!!!!!!!!!!!!!!!!! NaN detected in pre-norm (ff) output")
          x = self.ff(x)
-         if torch.isnan(x).any():
-             print("!!!!!!!!!!!!!!!!!!!!!! NaN detected in ff output")
          x = residual + x
          if self.use_post_norm:
              x = self.norm3(x)
@@ -1,11 +1,6 @@
  import random, gc
- from typing import Optional, Union, List, Dict, Any
-
  import torch
  import numpy as np
- from huggingface_hub import PyTorchModelHubMixin
- from huggingface_hub.hub_mixin import DataclassInstance
-

  def human_format(num: int):
      """Format numbers to human-readable format."""
@@ -1,89 +0,0 @@
- import torch
- import torch.nn as nn
- from .stm import ShortTermMemory
-
- class StmMemoryAttention(nn.Module):
-     def __init__(
-             self,
-             stm: ShortTermMemory,
-             attention_layers: nn.ModuleList,
-             memory_norm_layers: nn.ModuleList,
-             use_gated_residual: bool = False,
-             per_slot_gate: bool = False,
-             init_gate: float = 0.0,
-             use_dynamic_gate: bool = False,
-             use_tanh_gate: bool = False,
-             debug_mode: bool = False,
-             debug_interval: int = 10,
-             *args,
-             **kwargs
-     ):
-         super(StmMemoryAttention, self).__init__(*args, **kwargs)
-         self.stm = stm
-         self.attention_layers = attention_layers
-         self.memory_norm_layers = memory_norm_layers
-         assert len(self.attention_layers) == len(self.memory_norm_layers) == self.stm.memory.size(0)
-         self.num_layers = len(attention_layers)
-         self.use_gated_residual = use_gated_residual
-         self.per_slot_gate = per_slot_gate
-         self.use_dynamic_gate = use_dynamic_gate
-         self.use_tanh_gate = use_tanh_gate
-         if self.use_gated_residual:
-             gate_shape = (self.num_layers, self.stm.stm_size, 1) if self.per_slot_gate else (self.num_layers,)
-             self.gate = nn.Parameter(torch.full(gate_shape, init_gate))
-
-         self.debug_mode = debug_mode
-         self.debug_interval = debug_interval
-         self.debug_step = 0
-
-     def update_max_len(self, max_seq_len: int):
-         for i in range(self.num_layers):
-             if self.attention_layers[i].rope is not None:
-                 self.attention_layers[i].rope.update_max_len(max_seq_len)
-
-     def _residual_gate(self, gate: torch.Tensor, layer_stm: torch.Tensor, new_layer_stm: torch.Tensor) -> torch.Tensor:
-         if self.use_dynamic_gate:
-             mean_dim = -1 if self.per_slot_gate else [1, 2]
-             gate_input = gate * (new_layer_stm + layer_stm).mean(dim=mean_dim, keepdim=True)
-             layer_gate = torch.tanh(gate_input) if self.use_tanh_gate else torch.sigmoid(gate_input)
-         else:
-             layer_gate = torch.tanh(gate) if self.use_tanh_gate else torch.sigmoid(gate)
-         if self.use_tanh_gate:
-             return (1 + layer_gate) * new_layer_stm + (1 - layer_gate) * layer_stm
-         else:
-             return layer_gate * new_layer_stm + (1 - layer_gate) * layer_stm
-
-     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
-         if attention_mask is not None:
-             attention_mask = attention_mask.unsqueeze(1).unsqueeze(1).bool()
-         new_stm = torch.zeros_like(self.stm.memory)
-         for i in range(self.num_layers):
-             layer_stm = self.stm(i)
-             # expand layer STM to batch size, if it's not in batch mode
-             if layer_stm.size(0) == 1:
-                 layer_stm = layer_stm.expand(x.size(0), -1, -1)
-             encoded_layer_data = x[i]
-             normalized_layer_stm = self.memory_norm_layers[i](layer_stm)
-             if torch.isnan(normalized_layer_stm).any():
-                 print(f"!!!!!!!!!!!!!!!!!!!!!! NaN detected in {i} layer memory norm output")
-
-             if self.debug_mode and self.training:
-                 if self.debug_step != 0 and self.debug_step % self.debug_interval == 0:
-                     self.debug_step = 0
-                     print(f"Normalized STM stats - mean: {normalized_layer_stm.mean().item():.4f}, std: {normalized_layer_stm.std().item():.4f}")
-                 else:
-                     self.debug_step += 1
-
-             if torch.isnan(encoded_layer_data).any():
-                 print(f"!!!!!!!!!!!!!!!!!!!!!! NaN detected in {i} layer encoded data input")
-
-             new_layer_stm = self.attention_layers[i](normalized_layer_stm, encoded_layer_data, encoded_layer_data, mask=attention_mask)
-             if torch.isnan(new_layer_stm).any():
-                 print(f"!!!!!!!!!!!!!!!!!!!!!! NaN detected in {i} layer memory attention output")
-
-             if self.use_gated_residual:
-                 new_stm[i] = self._residual_gate(self.gate[i], layer_stm, new_layer_stm)  # gated residual
-             else:
-                 new_stm[i] = new_layer_stm + layer_stm  # residual
-         self.stm.update_all(new_stm)
-         return self.stm.memory
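The deleted rxnn-0.2.72 src/rxnn/memory/attention.py above carried the residual-gate logic inline (a single shared nn.Parameter plus _residual_gate); in 0.2.74 that responsibility moves into per-layer ResidualGate modules from memory/gate.py. A rough, illustrative mapping of the old constructor flags onto the new gate arguments (not an official migration guide; values are examples only):

import torch.nn as nn
from rxnn.memory.gate import ResidualGate

num_layers, stm_size = 12, 1024   # illustrative values

# Old (<= 0.2.72): gating configured on StmMemoryAttention itself, e.g.
#   use_gated_residual=True, per_slot_gate=True, init_gate=0.0,
#   use_dynamic_gate=False, use_tanh_gate=True
# New (0.2.74): the same behaviour lives in per-layer ResidualGate modules;
# gate_type='static' roughly corresponds to the old use_dynamic_gate=False path.
residual_gates = nn.ModuleList([
    ResidualGate(stm_size, use_gate=True, gate_type='static',
                 per_slot_gate=True, init_gate=0.0, use_tanh_gate=True)
    for _ in range(num_layers)
])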