rxnn 0.2.72__py3-none-any.whl → 0.2.74__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/memory/attention.py CHANGED
@@ -2,17 +2,15 @@ import torch
 import torch.nn as nn
 from .stm import ShortTermMemory
 
+
 class StmMemoryAttention(nn.Module):
     def __init__(
             self,
             stm: ShortTermMemory,
             attention_layers: nn.ModuleList,
             memory_norm_layers: nn.ModuleList,
-            use_gated_residual: bool = False,
-            per_slot_gate: bool = False,
-            init_gate: float = 0.0,
-            use_dynamic_gate: bool = False,
-            use_tanh_gate: bool = False,
+            memory_input_norm_layers: nn.ModuleList,
+            residual_gate_layers: nn.ModuleList,
             debug_mode: bool = False,
             debug_interval: int = 10,
             *args,
@@ -22,16 +20,12 @@ class StmMemoryAttention(nn.Module):
         self.stm = stm
         self.attention_layers = attention_layers
         self.memory_norm_layers = memory_norm_layers
-        assert len(self.attention_layers) == len(self.memory_norm_layers) == self.stm.memory.size(0)
+        self.memory_input_norm_layers = memory_input_norm_layers
+        self.residual_gate_layers = residual_gate_layers
+        assert (len(self.attention_layers) == len(self.memory_norm_layers) ==
+                len(self.residual_gate_layers) == len(self.memory_input_norm_layers) ==
+                self.stm.memory.size(0))
         self.num_layers = len(attention_layers)
-        self.use_gated_residual = use_gated_residual
-        self.per_slot_gate = per_slot_gate
-        self.use_dynamic_gate = use_dynamic_gate
-        self.use_tanh_gate = use_tanh_gate
-        if self.use_gated_residual:
-            gate_shape = (self.num_layers, self.stm.stm_size, 1) if self.per_slot_gate else (self.num_layers,)
-            self.gate = nn.Parameter(torch.full(gate_shape, init_gate))
-
         self.debug_mode = debug_mode
         self.debug_interval = debug_interval
         self.debug_step = 0
@@ -41,32 +35,27 @@ class StmMemoryAttention(nn.Module):
         if self.attention_layers[i].rope is not None:
             self.attention_layers[i].rope.update_max_len(max_seq_len)
 
-    def _residual_gate(self, gate: torch.Tensor, layer_stm: torch.Tensor, new_layer_stm: torch.Tensor) -> torch.Tensor:
-        if self.use_dynamic_gate:
-            mean_dim = -1 if self.per_slot_gate else [1, 2]
-            gate_input = gate * (new_layer_stm + layer_stm).mean(dim=mean_dim, keepdim=True)
-            layer_gate = torch.tanh(gate_input) if self.use_tanh_gate else torch.sigmoid(gate_input)
-        else:
-            layer_gate = torch.tanh(gate) if self.use_tanh_gate else torch.sigmoid(gate)
-        if self.use_tanh_gate:
-            return (1 + layer_gate) * new_layer_stm + (1 - layer_gate) * layer_stm
-        else:
-            return layer_gate * new_layer_stm + (1 - layer_gate) * layer_stm
-
     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
+        # 1. Process correct attention mask
         if attention_mask is not None:
             attention_mask = attention_mask.unsqueeze(1).unsqueeze(1).bool()
+        # 2. Init new empty STM
         new_stm = torch.zeros_like(self.stm.memory)
+
+        # 3. Run Short-Term Memory update for all layers
         for i in range(self.num_layers):
+            # 4. Get current layer STM value
            layer_stm = self.stm(i)
-            # expand layer STM to batch size, if it's not in batch mode
+            # 5. Expand layer STM to batch size, if it's not in batch mode
             if layer_stm.size(0) == 1:
                 layer_stm = layer_stm.expand(x.size(0), -1, -1)
-            encoded_layer_data = x[i]
+
+            # 6. Get encoded layer data and normalize it
+            encoded_layer_data = self.memory_input_norm_layers[i](x[i])
+            # 7. Normalize STM layer
             normalized_layer_stm = self.memory_norm_layers[i](layer_stm)
-            if torch.isnan(normalized_layer_stm).any():
-                print(f"!!!!!!!!!!!!!!!!!!!!!! NaN detected in {i} layer memory norm output")
 
+            # 8. Print normalization stats in debug mode
             if self.debug_mode and self.training:
                 if self.debug_step != 0 and self.debug_step % self.debug_interval == 0:
                     self.debug_step = 0
@@ -74,16 +63,88 @@ class StmMemoryAttention(nn.Module):
                 else:
                     self.debug_step += 1
 
-            if torch.isnan(encoded_layer_data).any():
-                print(f"!!!!!!!!!!!!!!!!!!!!!! NaN detected in {i} layer encoded data input")
-
+            # 9. Calculate memory attention
             new_layer_stm = self.attention_layers[i](normalized_layer_stm, encoded_layer_data, encoded_layer_data, mask=attention_mask)
-            if torch.isnan(new_layer_stm).any():
-                print(f"!!!!!!!!!!!!!!!!!!!!!! NaN detected in {i} layer memory attention output")
+            # 10. Combine new updated layer state with current STM state in residual gate
+            new_stm[i] = self.residual_gate_layers[i](layer_stm, new_layer_stm) # residual
+        # 11. Update all layers/models
+        self.stm.update_all(new_stm)
+        return self.stm.memory
+
+
+class InterlayerStmMemoryAttention(StmMemoryAttention):
+    def __init__(
+            self,
+            stm: ShortTermMemory,
+            attention_layers: nn.ModuleList,
+            memory_norm_layers: nn.ModuleList,
+            memory_input_norm_layers: nn.ModuleList,
+            residual_gate_layers: nn.ModuleList,
+            mean_attention_layers: nn.ModuleList,
+            mean_memory_norm_layers: nn.ModuleList,
+            mean_residual_gate_layers: nn.ModuleList,
+            mean_stm_norm: nn.Module,
+            debug_mode: bool = False,
+            debug_interval: int = 10,
+            **kwargs
+    ):
+        super(InterlayerStmMemoryAttention, self).__init__(
+            stm, attention_layers, memory_norm_layers, memory_input_norm_layers, residual_gate_layers,
+            debug_mode=debug_mode, debug_interval=debug_interval, **kwargs
+        )
+        self.mean_attention_layers = mean_attention_layers
+        self.mean_memory_norm_layers = mean_memory_norm_layers
+        self.mean_stm_norm = mean_stm_norm
+        self.mean_residual_gate_layers = mean_residual_gate_layers
+        assert (len(self.mean_attention_layers) == len(self.mean_memory_norm_layers) ==
+                len(self.mean_residual_gate_layers) == self.num_layers)
 
-            if self.use_gated_residual:
-                new_stm[i] = self._residual_gate(self.gate[i], layer_stm, new_layer_stm) # gated residual
-            else:
-                new_stm[i] = new_layer_stm + layer_stm # residual
+    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
+        # 1. Process correct attention mask
+        if attention_mask is not None:
+            attention_mask = attention_mask.unsqueeze(1).unsqueeze(1).bool()
+        # 2. Init new empty STM
+        new_stm = torch.zeros_like(self.stm.memory)
+
+        # 3. Get mean STM value from layers for mean interlayer memory attention
+        mean_stm = self.stm.memory.mean(dim=0) # [batch_size, stm_size, embed_dim]
+        # 4. Normalize mean STM layer
+        normalized_mean_stm = self.mean_stm_norm(mean_stm)
+
+        # 5. Run Short-Term Memory update for all layers
+        for i in range(self.num_layers):
+            # 6. Get current layer STM value
+            layer_stm = self.stm(i)
+            # 7. Expand layer STM to batch size, if it's not in batch mode
+            if layer_stm.size(0) == 1:
+                layer_stm = layer_stm.expand(x.size(0), -1, -1)
+
+            # 8. Mean interlayer memory attention
+            # a) normalize STM layer value
+            pre_normalized_layer_stm = self.mean_memory_norm_layers[i](layer_stm)
+            # b) calculate attention between STM layer and mean value of all STM layers (from previous interaction)
+            interlayer_stm = self.mean_attention_layers[i](pre_normalized_layer_stm, normalized_mean_stm, normalized_mean_stm, mask=None)
+            # c) combine updated interlayer state with current STM state in residual gate
+            updated_layer_stm = self.mean_residual_gate_layers[i](layer_stm, interlayer_stm)
+
+            # 9. Main memory attention
+            # a) get encoded data for current layer and normalize it
+            encoded_layer_data = self.memory_input_norm_layers[i](x[i])
+            # b) normalize STM layer value
+            normalized_layer_stm = self.memory_norm_layers[i](updated_layer_stm)
+            # c) print normalized STM stats in debug mode
+            if self.debug_mode and self.training:
+                if self.debug_step != 0 and self.debug_step % self.debug_interval == 0:
+                    self.debug_step = 0
+                    print(
+                        f"Pre-Normalized STM stats - mean: {pre_normalized_layer_stm.mean().item():.4f}, std: {pre_normalized_layer_stm.std().item():.4f}")
+                    print(f"Normalized STM stats - mean: {normalized_layer_stm.mean().item():.4f}, std: {normalized_layer_stm.std().item():.4f}")
+                else:
+                    self.debug_step += 1
+            # d) calculate memory attention between STM layer and encoded data
+            new_layer_stm = self.attention_layers[i](normalized_layer_stm, encoded_layer_data, encoded_layer_data, mask=attention_mask)
+            # e) combine new updated layer STM with previous state in residual gate
+            new_stm[i] = self.residual_gate_layers[i](updated_layer_stm, new_layer_stm) # residual
+        # 10. Update all layers/models
         self.stm.update_all(new_stm)
         return self.stm.memory
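
For orientation, a minimal stand-alone sketch of the per-layer update flow that the refactored forward() implements: normalize the encoded input, normalize the STM slots, run memory attention with the STM as queries, then combine old and new state through a residual gate. torch.nn.MultiheadAttention stands in for rxnn's attention layers, only ResidualGate comes from the new rxnn/memory/gate.py shown below, and all shapes and sizes are illustrative rather than taken from the release:

import torch
import torch.nn as nn
from rxnn.memory.gate import ResidualGate  # new module added in 0.2.74

embed_dim, stm_size, seq_len, batch = 64, 16, 32, 2
input_norm = nn.RMSNorm(embed_dim)   # plays the role of memory_input_norm_layers[i]
stm_norm = nn.RMSNorm(embed_dim)     # plays the role of memory_norm_layers[i]
attn = nn.MultiheadAttention(embed_dim, num_heads=4, batch_first=True)  # stand-in attention layer
gate = ResidualGate(stm_size, use_gate=True, gate_type='static',
                    per_slot_gate=True, init_gate=3.0, use_tanh_gate=True)

layer_stm = torch.randn(batch, stm_size, embed_dim)            # current STM slots for one layer
encoded = input_norm(torch.randn(batch, seq_len, embed_dim))   # encoded layer data (step 6)
normalized_stm = stm_norm(layer_stm)                           # step 7
new_layer_stm, _ = attn(normalized_stm, encoded, encoded)      # step 9: STM queries attend to the data
updated_stm = gate(layer_stm, new_layer_stm)                   # step 10: gated residual combine
print(updated_stm.shape)  # torch.Size([2, 16, 64])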
rxnn/memory/gate.py ADDED
@@ -0,0 +1,60 @@
+import torch
+import torch.nn as nn
+from typing import TypeAlias, Literal
+
+ResidualGateType: TypeAlias = Literal['static', 'elementwise', 'linear']
+
+
+class ResidualGate(nn.Module):
+    def __init__(
+            self,
+            stm_size: int,
+            use_gate: bool = False,
+            gate_type: ResidualGateType = 'static',
+            per_slot_gate: bool = True,
+            init_gate: float = 0.0,
+            use_tanh_gate: bool = True,
+            **kwargs,
+    ):
+        super(ResidualGate, self).__init__(**kwargs)
+        self.use_gate = use_gate
+        self.per_slot_gate = per_slot_gate
+        self.gate_type = gate_type
+        self.use_tanh_gate = use_tanh_gate
+
+        if self.use_gate:
+            if self.gate_type == 'linear':
+                self.gate = nn.Linear(stm_size, stm_size if self.per_slot_gate else 1)
+            else:
+                gate_shape = (stm_size, 1) if self.per_slot_gate else (1,)
+                self.gate = nn.Parameter(torch.full(gate_shape, init_gate))
+        else:
+            self.gate = None
+
+        self.gate_activation = nn.Tanh() if self.use_tanh_gate else nn.Sigmoid()
+
+    def _dynamic_gate(self, old_value: torch.Tensor, new_value: torch.Tensor):
+        if self.gate_type == 'linear':
+            mean_residual = (new_value + old_value).mean(dim=-1)
+            gate_input = self.gate(mean_residual).unsqueeze(-1)
+        else:
+            mean_dim = -1 if self.per_slot_gate else [1, 2]
+            gate_input = self.gate * (new_value + old_value).mean(dim=mean_dim, keepdim=True)
+        return self.gate_activation(gate_input)
+
+    def _calculate_output(self, layer_gate: torch.Tensor, old_value: torch.Tensor, new_value: torch.Tensor) -> torch.Tensor:
+        if self.use_tanh_gate:
+            return (1 + layer_gate) * new_value + (1 - layer_gate) * old_value
+        else:
+            return layer_gate * new_value + (1 - layer_gate) * old_value
+
+    def forward(self, old_value: torch.Tensor, new_value: torch.Tensor) -> torch.Tensor:
+        if not self.use_gate:
+            return new_value + old_value
+
+        if self.gate_type == 'static':
+            layer_gate = self.gate_activation(self.gate)
+        else:
+            layer_gate = self._dynamic_gate(old_value, new_value)
+
+        return self._calculate_output(layer_gate, old_value, new_value)
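
A short usage sketch of the gate above, based only on this file; the [batch, stm_size, embed_dim] layout follows the STM tensors used in memory/attention.py and the concrete numbers are illustrative. With use_tanh_gate=True the output is (1 + tanh(g)) * new + (1 - tanh(g)) * old, so a large positive init_gate makes the update dominated by the new state:

import torch
from rxnn.memory.gate import ResidualGate

# Static per-slot gate with tanh activation (the settings rxt/models.py below passes
# when use_gated_residual is enabled).
static_gate = ResidualGate(stm_size=8, use_gate=True, gate_type='static',
                           per_slot_gate=True, init_gate=3.0, use_tanh_gate=True)
# Dynamic 'linear' gate: a linear layer over the embedding-mean of (new + old)
# produces per-slot gate values.
linear_gate = ResidualGate(stm_size=8, use_gate=True, gate_type='linear',
                           per_slot_gate=True, use_tanh_gate=True)

old_stm = torch.randn(2, 8, 16)   # [batch, stm_size, embed_dim]
new_stm = torch.randn(2, 8, 16)
print(static_gate(old_stm, new_stm).shape)  # torch.Size([2, 8, 16])
print(linear_gate(old_stm, new_stm).shape)  # torch.Size([2, 8, 16])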
rxnn/rxt/models.py CHANGED
@@ -9,7 +9,8 @@ from ..transformers.models import ReactiveTransformerBase, ReactiveTransformerEn
 from ..transformers.ff import get_activation_layer
 from ..memory.stm import ShortTermMemory
 from ..memory.norm import init_memory_norm
-from ..memory.attention import StmMemoryAttention
+from ..memory.attention import StmMemoryAttention, InterlayerStmMemoryAttention
+from ..memory.gate import ResidualGate, ResidualGateType
 from ..utils import get_model_size
 from ..experimental.attention import init_experimental_attention
 
@@ -260,15 +261,15 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
             att_experts: int = None,
             att_query_experts: int = None,
             att_query_groups: int = None,
-            norm_type: str = 'rms',
+            norm_type: str = 'classic-rms',
             norm_init_gate: float = -2.0,
             norm_per_dim_scale: bool = False,
             norm_decay: float = 0.9,
             use_gated_residual: bool = False,
-            residual_per_slot_gate: bool = False,
-            residual_init_gate: float = 0.0,
-            use_dynamic_residual_gate: bool = False,
-            use_tanh_residual_gate: bool = False,
+            residual_per_slot_gate: bool = True,
+            residual_gate_init: float = 3.0,
+            residual_gate_type: ResidualGateType = 'static',
+            use_tanh_residual_gate: bool = True,
             debug_mode: bool = False,
             debug_interval: int = 10,
             **kwargs,
@@ -296,12 +297,19 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
         memory_norm_layers = nn.ModuleList([init_memory_norm(norm_type, embed_dim, stm_size, decay=norm_decay,
                                                              init_gate=norm_init_gate, per_dim_scale=norm_per_dim_scale)
                                             for _ in range(num_layers)])
+        memory_input_norm_layers = nn.ModuleList(nn.RMSNorm(embed_dim) for _ in range(num_layers))
         attention_layers = nn.ModuleList([att_init() for _ in range(num_layers)])
+        residual_gates = nn.ModuleList([
+            ResidualGate(
+                stm_size, use_gate=use_gated_residual, gate_type=residual_gate_type, per_slot_gate=residual_per_slot_gate,
+                init_gate=residual_gate_init, use_tanh_gate=use_tanh_residual_gate
+            ) for _ in range(num_layers)
+        ])
+
         self.model = StmMemoryAttention(
             stm, attention_layers, memory_norm_layers,
-            use_gated_residual=use_gated_residual, per_slot_gate=residual_per_slot_gate,
-            init_gate=residual_init_gate, use_dynamic_gate=use_dynamic_residual_gate,
-            use_tanh_gate=use_tanh_residual_gate, debug_mode=debug_mode, debug_interval=debug_interval,
+            memory_input_norm_layers, residual_gates,
+            debug_mode=debug_mode, debug_interval=debug_interval,
         )
 
     def freeze(self):
@@ -327,6 +335,141 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
         return self.model(x, attention_mask=attention_mask)
 
+
+class RxTAlphaInterlayerMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2.0"):
+    """RxT-Alpha (Reactive Transformer) memory attention model with interlayer STM attention"""
+
+    def __init__(
+            self,
+            num_layers: int = 12,
+            embed_dim: int = 512,
+            att_heads: int = 16,
+            seq_len: int = 1024,
+            stm_size: int = 1024,
+            use_flash_attention: bool = False,
+            att_dropout: float = 0.0,
+            att_groups: int = 1,
+            att_type: str = 'sqa',
+            att_experts: int = None,
+            att_query_experts: int = None,
+            att_query_groups: int = None,
+            interlayer_att_dropout: float = 0.0,
+            interlayer_att_groups: int = 1,
+            interlayer_att_type: str = 'sqa',
+            interlayer_att_experts: int = None,
+            interlayer_att_query_experts: int = None,
+            interlayer_att_query_groups: int = None,
+            norm_type: str = 'classic-rms',
+            norm_init_gate: float = -2.0,
+            norm_per_dim_scale: bool = False,
+            norm_decay: float = 0.9,
+            use_gated_residual: bool = False,
+            residual_per_slot_gate: bool = True,
+            residual_gate_init: float = 3.0,
+            residual_gate_type: ResidualGateType = 'static',
+            use_tanh_residual_gate: bool = True,
+            debug_mode: bool = False,
+            debug_interval: int = 10,
+            **kwargs,
+    ):
+        super(RxTAlphaInterlayerMemoryAttention, self).__init__(**kwargs)
+
+        assert att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma',
+                            'sqa'], 'Memory attention type could be "mha", "gqa", "mqa", "gma", "dma", "sqa".'
+
+        rope = RotaryPositionalEmbedding(embed_dim // att_heads, seq_len)
+        stm = ShortTermMemory(num_layers, embed_dim, stm_size)
+
+        if att_type in ['mha', 'gqa', 'mqa']:
+            att_init = lambda: init_attention(
+                embed_dim, att_heads, att_type, att_groups, rope=rope,
+                use_flash_attention=use_flash_attention, dropout=att_dropout,
+                max_seq_len=seq_len, is_causal=False, rope_only_for_keys=True
+            )
+        else:
+            att_init = lambda: init_experimental_attention(
+                embed_dim, att_heads, att_type, att_groups, rope=rope,
+                use_flash_attention=use_flash_attention, dropout=att_dropout,
+                max_seq_len=seq_len, is_causal=False, num_experts=att_experts,
+                num_query_experts=att_query_experts, num_query_groups=att_query_groups,
+                rope_only_for_keys=True
+            )
+
+        memory_norm_layers = nn.ModuleList([init_memory_norm(norm_type, embed_dim, stm_size, decay=norm_decay,
+                                                             init_gate=norm_init_gate, per_dim_scale=norm_per_dim_scale)
+                                            for _ in range(num_layers)])
+        memory_input_norm_layers = nn.ModuleList(nn.RMSNorm(embed_dim) for _ in range(num_layers))
+        attention_layers = nn.ModuleList([att_init() for _ in range(num_layers)])
+        residual_gates = nn.ModuleList([
+            ResidualGate(
+                stm_size, use_gate=use_gated_residual, gate_type=residual_gate_type,
+                per_slot_gate=residual_per_slot_gate,
+                init_gate=residual_gate_init, use_tanh_gate=use_tanh_residual_gate
+            ) for _ in range(num_layers)
+        ])
+
+        # Interlayer attention
+        if interlayer_att_type in ['mha', 'gqa', 'mqa']:
+            interlayer_att_init = lambda: init_attention(
+                embed_dim, att_heads, interlayer_att_type, interlayer_att_groups, rope=None,
+                use_flash_attention=use_flash_attention, dropout=interlayer_att_dropout, is_causal=False
+            )
+        else:
+            interlayer_att_init = lambda: init_experimental_attention(
+                embed_dim, att_heads, interlayer_att_type, interlayer_att_groups, rope=None,
+                use_flash_attention=use_flash_attention, dropout=interlayer_att_dropout, is_causal=False,
+                num_experts=interlayer_att_experts, num_query_experts=interlayer_att_query_experts, num_query_groups=interlayer_att_query_groups
+            )
+
+        mean_attention_layers = nn.ModuleList([interlayer_att_init() for _ in range(num_layers)])
+
+        mean_stm_norm = init_memory_norm(
+            norm_type, embed_dim, stm_size, decay=norm_decay,
+            init_gate=norm_init_gate, per_dim_scale=norm_per_dim_scale
+        )
+
+        mean_memory_norm_layers = nn.ModuleList([init_memory_norm(norm_type, embed_dim, stm_size, decay=norm_decay,
+                                                                  init_gate=norm_init_gate, per_dim_scale=norm_per_dim_scale)
+                                                 for _ in range(num_layers)])
+
+        mean_residual_gates = nn.ModuleList([
+            ResidualGate(
+                stm_size, use_gate=use_gated_residual, gate_type=residual_gate_type,
+                per_slot_gate=residual_per_slot_gate,
+                init_gate=residual_gate_init, use_tanh_gate=use_tanh_residual_gate
+            ) for _ in range(num_layers)
+        ])
+
+        self.model = InterlayerStmMemoryAttention(
+            stm, attention_layers, memory_norm_layers, memory_input_norm_layers, residual_gates,
+            mean_attention_layers, mean_memory_norm_layers, mean_residual_gates, mean_stm_norm,
+            debug_mode=debug_mode, debug_interval=debug_interval,
+        )
+
+    def freeze(self):
+        for param in self.parameters():
+            param.requires_grad = False
+
+    def unfreeze(self):
+        for param in self.parameters():
+            param.requires_grad = True
+
+    def load_shared_memory(self, stm: ShortTermMemory):
+        self.model.stm = stm
+
+    def update_max_len(self, max_seq_len: int):
+        self.model.update_max_len(max_seq_len)
+
+    def reset_memory(self, init_type: str = None):
+        self.model.stm.reset(init_type)
+
+    def clone_reset_memory(self):
+        self.model.stm.clone_detach_reset()
+
+    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
+        return self.model(x, attention_mask=attention_mask)
+
+
 class RxTAlphaCriticEncoder(RxTAlphaComponentBase, pipeline_tag="text-classification", license="apache-2.0"):
     """RxT-Alpha (Reactive Transformer) encoder model"""
 
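A hedged construction sketch for the new RxTAlphaInterlayerMemoryAttention wrapper, using only constructor arguments and methods visible in this diff; the sizes and the choice of 'gqa' attention are illustrative, and the snippet assumes this rxnn release is installed rather than being a verified example from the package:

from rxnn.rxt.models import RxTAlphaInterlayerMemoryAttention

mem_att = RxTAlphaInterlayerMemoryAttention(
    num_layers=6, embed_dim=256, att_heads=8, seq_len=512, stm_size=256,
    att_type='gqa', att_groups=4,                        # main memory attention
    interlayer_att_type='gqa', interlayer_att_groups=4,  # mean interlayer attention
    use_gated_residual=True, residual_gate_type='static',
)

mem_att.freeze()    # e.g. while other components are being trained
mem_att.unfreeze()  # before the memory attention training stage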
rxnn/transformers/layers.py CHANGED
@@ -49,10 +49,12 @@ class ReactiveTransformerLayer(nn.Module):
             self.norm1 = nn.RMSNorm(embed_dim)
             self.norm2 = nn.RMSNorm(embed_dim)
             self.norm3 = nn.RMSNorm(embed_dim)
+            self.stm_norm = nn.RMSNorm(embed_dim)
         else:
             self.norm1 = nn.LayerNorm(embed_dim)
             self.norm2 = nn.LayerNorm(embed_dim)
             self.norm3 = nn.LayerNorm(embed_dim)
+            self.stm_norm = nn.LayerNorm(embed_dim)
         self.use_post_norm = use_post_norm
         self.use_moe = use_moe
         self.use_moe_att = use_moe_att
@@ -63,9 +65,11 @@ class ReactiveTransformerLayer(nn.Module):
         if with_norms:
             for param in self.norm2.parameters():
                 param.requires_grad_(is_trainable)
+            for param in self.stm_norm.parameters():
+                param.requires_grad_(is_trainable)
 
     def memory_parameters(self) -> list[nn.Parameter]:
-        return list(self.memory_cross_attention.parameters()) + list(self.norm2.parameters())
+        return list(self.memory_cross_attention.parameters()) + list(self.norm2.parameters()) + list(self.stm_norm.parameters())
 
     def not_memory_parameters(self) -> list[nn.Parameter]:
         return (list(self.attention.parameters()) + list(self.norm1.parameters()) +
@@ -102,11 +106,8 @@ class ReactiveTransformerLayer(nn.Module):
         residual = x
         if not self.use_post_norm:
             x = self.norm1(x)
-        if torch.isnan(x).any():
-            print("!!!!!!!!!!!!!!!!!!!!!! !!!!!!!!!!!!!!!!!!!!!! NaN detected in pre-norm (self-attention) output")
         x = self.attention(x, x, x, mask=mask)
-        if torch.isnan(x).any():
-            print("!!!!!!!!!!!!!!!!!!!!!! !!!!!!!!!!!!!!!!!!!!!! NaN detected in self-attention output")
+
         x = residual + x
         if self.use_post_norm:
             x = self.norm1(x)
@@ -114,18 +115,13 @@ class ReactiveTransformerLayer(nn.Module):
         residual = x
         if not self.use_post_norm:
             x = self.norm2(x)
-        if torch.isnan(x).any():
-            print("!!!!!!!!!!!!!!!!!!!!!! NaN detected in pre-norm (cross-attention) output")
 
+        # normalize STM and prepare STM mask
+        stm = self.stm_norm(stm)
         mem_mask = mask.squeeze(1).unsqueeze(-1).expand(-1, -1, -1, stm.size(1)) \
             if mask is not None else None
 
-        if torch.isnan(stm).any():
-            print("!!!!!!!!!!!!!!!!!!!!!! NaN detected in STM cross-attention input")
-
         x = self.memory_cross_attention(x, stm, stm, mask=mem_mask)
-        if torch.isnan(x).any():
-            print("!!!!!!!!!!!!!!!!!!!!!! NaN detected in cross-attention output")
         x = residual + x
         if self.use_post_norm:
             x = self.norm2(x)
@@ -134,11 +130,7 @@ class ReactiveTransformerLayer(nn.Module):
         residual = x
         if not self.use_post_norm:
             x = self.norm3(x)
-        if torch.isnan(x).any():
-            print("!!!!!!!!!!!!!!!!!!!!!! NaN detected in pre-norm (ff) output")
         x = self.ff(x)
-        if torch.isnan(x).any():
-            print("!!!!!!!!!!!!!!!!!!!!!! NaN detected in ff output")
         x = residual + x
         if self.use_post_norm:
             x = self.norm3(x)
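
Because memory_parameters() now also returns the stm_norm weights, training code that splits a layer into memory and non-memory parameter groups will pick the new norm up automatically. The stand-in module below (not rxnn code; names and sizes are illustrative) shows the optimizer-group pattern this split enables:

import torch
import torch.nn as nn

class TinyLayerWithMemory(nn.Module):
    # Minimal stand-in exposing the same two accessors as ReactiveTransformerLayer.
    def __init__(self, dim: int = 32):
        super().__init__()
        self.attention = nn.Linear(dim, dim)               # stand-in for self-attention
        self.memory_cross_attention = nn.Linear(dim, dim)  # stand-in for memory cross-attention
        self.norm2 = nn.RMSNorm(dim)
        self.stm_norm = nn.RMSNorm(dim)                    # newly counted as a memory parameter

    def memory_parameters(self):
        return list(self.memory_cross_attention.parameters()) + \
               list(self.norm2.parameters()) + list(self.stm_norm.parameters())

    def not_memory_parameters(self):
        return list(self.attention.parameters())

layer = TinyLayerWithMemory()
optimizer = torch.optim.AdamW([
    {'params': layer.memory_parameters(), 'lr': 5e-4},    # memory path gets its own learning rate
    {'params': layer.not_memory_parameters(), 'lr': 1e-4},
])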
rxnn/utils.py CHANGED
@@ -1,11 +1,6 @@
 import random, gc
-from typing import Optional, Union, List, Dict, Any
-
 import torch
 import numpy as np
-from huggingface_hub import PyTorchModelHubMixin
-from huggingface_hub.hub_mixin import DataclassInstance
-
 
 def human_format(num: int):
     """Format numbers to human-readable format."""
rxnn-0.2.72.dist-info/METADATA → rxnn-0.2.74.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.2.72
+Version: 0.2.74
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
rxnn-0.2.72.dist-info/RECORD → rxnn-0.2.74.dist-info/RECORD CHANGED
@@ -5,11 +5,12 @@ rxnn/experimental/attention.py,sha256=jlNS82INjycNEfmk3HtkIacUvT_ELhaCO2g-kZTvhX
 rxnn/experimental/models.py,sha256=KheR1zSNJIaeVvpVAkEJwcuM5nOqQP0ZF08XhrtGJ8E,5387
 rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/memory/attention.py,sha256=el-vlkA7OYFNisdYbaQMxSphSG7Px6oDx1aO_3lFIs4,4316
+rxnn/memory/attention.py,sha256=CReYJZNA5JRED_QWqX-yKqEKZTRX6DNCAB8uFLZtKxI,7513
+rxnn/memory/gate.py,sha256=pR_H2y9C7S02QskoFAEC9Tmluut0k4GGlHgvZGiw6m4,2332
 rxnn/memory/norm.py,sha256=cVjjhCLqR5K6-321SP_ObG17y-ddlcTJeCTXvW4vpk0,6675
 rxnn/memory/stm.py,sha256=jv57gsH9XW19sLbxpRDqsp1yfsii_4Ef4Ncr_ztk-i4,3937
 rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/rxt/models.py,sha256=M_0nEfSgr5Wyv-Ku4TCLpIs5VndUccjtIR0wU0DSVRo,15574
+rxnn/rxt/models.py,sha256=zgRgNUVYYuniiB1xt7HdQYgmhep6e5ybxv3PU0lcfoU,22208
 rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/training/base.py,sha256=CqaArEZYOdH64nmKfx28U3GI46TzO4oNkjf_hrF23Cw,11835
 rxnn/training/bml.py,sha256=hw6gLpLkGvqLzxIvBg4MvCc5r8cHpEm2RDyh7nH6CtE,16914
@@ -26,14 +27,14 @@ rxnn/training/utils.py,sha256=ngDCm654NL3UsPy190Er4XPc9HI-OyEV6tDLMgEEvQc,6219
 rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/transformers/attention.py,sha256=KRnKT6XUqAXElxV9y72mSpdTeiMgCKCCLqqxCFNTHmA,16372
 rxnn/transformers/ff.py,sha256=WDjO-H9XWInoWnUnxiseIH6Kx5GlHP0zGJygwhcb1gc,2589
-rxnn/transformers/layers.py,sha256=VOKCURq9HQu0Uf0uk1cmyvj3u-rnoyJoZZ9Y-kSSihQ,9095
+rxnn/transformers/layers.py,sha256=7iwLZ4De4kw3-YA5p2-adCvTgeqeLC-lXcFAlhN_-AA,8450
 rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
 rxnn/transformers/models.py,sha256=TP0H9do53Z0vd8kpHMISBzMpHE5X9QIHcy0B-iJHuNQ,11711
 rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
-rxnn/utils.py,sha256=jnPmhehnRojRolgDxgRA_XPdcx_nUNT5tuDmrV0b-w0,1155
-rxnn-0.2.72.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
-rxnn-0.2.72.dist-info/METADATA,sha256=FyIccoN8UysI4TKozaOrjckG0rxSelStVz7Yi3y8wXM,60420
-rxnn-0.2.72.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
-rxnn-0.2.72.dist-info/RECORD,,
+rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
+rxnn-0.2.74.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.74.dist-info/METADATA,sha256=8Qdcz_IBj0olbOBdNfRmDOrVN5rD8pghakpfXzTgl3E,60420
+rxnn-0.2.74.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.74.dist-info/RECORD,,