rxnn 0.2.24__py3-none-any.whl → 0.2.26__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/memory/attention.py CHANGED
@@ -35,7 +35,6 @@ class StmMemoryAttention(nn.Module):
             encoded_layer_data = x[i]
             normalized_layer_stm = self.memory_norm_layers[i](layer_stm)
             new_layer_stm = self.attention_layers[i](normalized_layer_stm, encoded_layer_data, encoded_layer_data, mask=mask)
-            # self.stm.update_layer(i, new_layer_stm + layer_stm)
             new_stm[i] = new_layer_stm + layer_stm # residual
         self.stm.update_all(new_stm)
         return self.stm.memory
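The net effect of this change: the commented-out per-layer, in-place update is gone, and each layer's residual result is collected in new_stm and committed once with update_all after the loop. A minimal standalone sketch of that flow, with toy tensors and an identity placeholder for attention (these names are illustrative, not the library's API):

    import torch

    num_layers, num_slots, dim = 2, 4, 8
    stm = [torch.zeros(num_slots, dim) for _ in range(num_layers)]  # current memory
    x = [torch.randn(num_slots, dim) for _ in range(num_layers)]    # encoded layer data

    new_stm = [None] * num_layers
    for i in range(num_layers):
        layer_stm = stm[i]
        new_layer_stm = x[i]                     # stands in for attention_layers[i](...)
        new_stm[i] = new_layer_stm + layer_stm   # residual, as in the diff
    stm = new_stm                                # single commit, like stm.update_all(new_stm)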
rxnn/memory/norm.py CHANGED
@@ -7,10 +7,11 @@ class AdaptivePositionalMemoryNorm(nn.Module):
         self,
         num_slots: int,
         dim: int,
-        decay: float = 0.99,
+        decay: float = 0.9,
         use_scale: bool = True,
         use_gate: bool = True,
-        init_gate: float = -4.0
+        init_gate: float = -2.0,
+        per_dim_scale: bool = False,
     ):
         super(AdaptivePositionalMemoryNorm, self).__init__()
         self.use_gate = use_gate
@@ -20,39 +21,38 @@ class AdaptivePositionalMemoryNorm(nn.Module):
         self.eps = 1e-6

         # Learnable parameters
-        self.scale = nn.Parameter(torch.ones(num_slots, 1, dim)) if use_scale else None
-        self.gate = nn.Parameter(torch.full((num_slots, 1, 1), init_gate)) if use_gate else None
+        scale_shape = (num_slots, 1) if not per_dim_scale else (dim,)
+        self.scale = nn.Parameter(torch.ones(*scale_shape)) if use_scale else None
+        self.gate = nn.Parameter(torch.full((num_slots, 1), init_gate)) if use_gate else None

         # EMA buffers
         self.register_buffer("ema_rms", torch.ones(num_slots, 1))

         # Initialize parameters
         if self.scale is not None:
-            nn.init.normal_(self.scale, mean=1.0, std=0.01)
+            nn.init.normal_(self.scale, mean=1.0, std=0.1)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        # x shape: [batch_size, num_slots, dim]
-        batch_size = x.size(0)
-
         # Calculate current RMS per slot
-        current_rms = x.pow(2).mean(dim=-1, keepdim=True).sqrt() # [batch, slots, 1]
-        slot_rms = current_rms.mean(dim=0) # [slots, 1] (average over batch)
+        # x: [batch_size, num_slots, dim]
+        current_rms = x.pow(2).mean(dim=-1, keepdim=True).sqrt() # [batch, num_slots, 1]
+        slot_rms = current_rms.mean(dim=0) # [num_slots, 1] (average over batch)

         # Update EMA during training
         if self.training:
-            self.ema_rms = self.decay * self.ema_rms + (1 - self.decay) * slot_rms.detach()
+            self.ema_rms = self.decay * self.ema_rms + (1 - self.decay) * slot_rms.detach() # [num_slots, 1]

         # Normalize using EMA statistics
-        x_norm = x * torch.rsqrt(self.ema_rms + self.eps)
+        x_norm = x * torch.rsqrt(self.ema_rms + self.eps) # [batch_size, num_slots, dim] * [num_slots, 1]

         # Apply learned scale per slot
         if self.scale is not None:
-            x_norm = x_norm * self.scale
+            x_norm = x_norm * self.scale # [batch_size, num_slots, dim] * [num_slots, 1] or [dim]

         # Apply gating mechanism
         if self.use_gate:
-            gate = torch.sigmoid(self.gate) # [slots, 1, 1]
-            return gate * x_norm + (1 - gate) * x
+            gate = torch.sigmoid(self.gate) # [num_slots, 1]
+            return gate * x_norm + (1 - gate) * x # [batch_size, num_slots, dim] * [num_slots, 1]

         return x_norm

@@ -77,7 +77,7 @@ class AdaptiveRMSMemoryNorm(nn.Module):
         # x shape: [batch_size, num_slots, dim]
         if self.training and hasattr(self, 'ema_rms'):
             # Compute current RMS across all slots and batch (scalar)
-            current_rms = x.pow(2).mean(-1).mean().sqrt()
+            current_rms = x.pow(2).mean(dim=-1).mean().sqrt()
             self.ema_rms = self.ema_rms * self.decay + current_rms * (1 - self.decay)
             rms = self.ema_rms
         else:
@@ -150,24 +150,26 @@ class MemoryNormConfig(TypedDict):
     use_gate: bool
     init_gate: float
     init_scale: float
+    per_dim_scale: bool

 def init_memory_norm(
         norm_type: str,
         dim: int,
         num_slots: int = None,
-        decay: float = 0.99,
+        decay: float = 0.9,
         use_scale: bool = True,
         use_gate: bool = True,
-        init_gate: float = -4.0,
+        init_gate: float = -2.0,
         init_scale: float = 1.0,
+        per_dim_scale: bool = False,
 ) -> nn.Module:
-    assert norm_type in ["layer", "rms", "adaptive", "positional"]
-    if norm_type == "layer":
+    assert norm_type in ['layer', 'rms', 'adaptive', 'positional']
+    if norm_type == 'layer':
         return MemoryLayerNorm(dim, use_gate, init_scale, init_gate)
-    elif norm_type == "rms":
+    elif norm_type == 'rms':
         return SimpleRMSMemoryNorm(dim, use_gate, init_scale, init_gate)
-    elif norm_type == "adaptive":
+    elif norm_type == 'adaptive':
         return AdaptiveRMSMemoryNorm(dim, use_gate, decay, init_scale, init_gate)
-    elif norm_type == "positional":
-        return AdaptivePositionalMemoryNorm(num_slots, dim, decay, use_scale, use_gate, init_gate)
+    elif norm_type == 'positional':
+        return AdaptivePositionalMemoryNorm(num_slots, dim, decay, use_scale, use_gate, init_gate, per_dim_scale)
     return MemoryLayerNorm(dim, use_gate, init_scale, init_gate)
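The retuned defaults matter for gating behavior: sigmoid(-2.0) ≈ 0.12 versus sigmoid(-4.0) ≈ 0.02, so a freshly initialized gate now admits noticeably more of the normalized signal, and the faster EMA decay (0.9 vs 0.99) tracks RMS shifts sooner. A usage sketch for the new flag; the import path follows the file above, and the shape comments restate what the diff documents:

    import torch
    from rxnn.memory.norm import init_memory_norm

    # per_dim_scale=False -> learned scale of shape [num_slots, 1] (one factor per slot)
    # per_dim_scale=True  -> learned scale of shape [dim] (one factor per feature)
    norm = init_memory_norm('positional', dim=64, num_slots=16,
                            decay=0.9, init_gate=-2.0, per_dim_scale=True)

    x = torch.randn(2, 16, 64)   # [batch_size, num_slots, dim]
    print(norm(x).shape)         # torch.Size([2, 16, 64])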
rxnn/rxt/models.py CHANGED
@@ -13,6 +13,7 @@ from ..memory.attention import StmMemoryAttention
 from ..utils import get_model_size
 from ..experimental.attention import init_experimental_attention

+
 class RxTAlphaComponentConfig(TypedDict):
     num_layers: int
     vocab_size: int
@@ -76,8 +77,10 @@ class RxTAlphaComponentBase(nn.Module, PyTorchModelHubMixin):
         assert ff_activation in ['relu', 'gelu',
                                  'swish', 'silu', 'linear',
                                  'sigmoid'], 'Feed-forward activation could be "relu", "gelu", "swish", "silu", "linear", "sigmoid".'
-        assert self_att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma', 'sqa'], 'Self-attention type could be "mha", "gqa", "mqa", "gma", "dma", "sqa".'
-        assert cross_att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma', 'sqa'], 'Memory cross-attention type could be "mha", "gqa", "mqa", "gma", "dma", "sqa".'
+        assert self_att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma',
+                                 'sqa'], 'Self-attention type could be "mha", "gqa", "mqa", "gma", "dma", "sqa".'
+        assert cross_att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma',
+                                  'sqa'], 'Memory cross-attention type could be "mha", "gqa", "mqa", "gma", "dma", "sqa".'

         embedding = nn.Embedding(vocab_size, embed_dim)
         rope = RotaryPositionalEmbedding(embed_dim // att_heads, seq_len)
@@ -92,20 +95,25 @@ class RxTAlphaComponentBase(nn.Module, PyTorchModelHubMixin):
         else:
             att_init = lambda: init_experimental_attention(embed_dim, att_heads, self_att_type, att_groups, rope=rope,
                                                            use_flash_attention=use_flash_attention, dropout=att_dropout,
-                                                           max_seq_len=seq_len, is_causal=is_causal, num_experts=att_experts,
+                                                           max_seq_len=seq_len, is_causal=is_causal,
+                                                           num_experts=att_experts,
                                                            num_query_experts=att_query_experts,
                                                            num_query_groups=att_query_groups)

         if cross_att_type in ['mha', 'gqa', 'mqa']:
             cross_att_init = lambda: init_attention(embed_dim, att_heads, cross_att_type, att_groups, rope=rope,
-                                              use_flash_attention=use_flash_attention, dropout=att_dropout,
-                                              max_seq_len=seq_len, is_causal=is_causal, rope_only_for_query=True)
+                                                    use_flash_attention=use_flash_attention, dropout=att_dropout,
+                                                    max_seq_len=seq_len, is_causal=is_causal, rope_only_for_query=True)
         else:
-            cross_att_init = lambda: init_experimental_attention(embed_dim, att_heads, cross_att_type, cross_att_groups or att_groups, rope=rope,
-                                                           use_flash_attention=use_flash_attention, dropout=att_dropout,
-                                                           max_seq_len=seq_len, is_causal=is_causal, num_experts=att_experts,
-                                                           num_query_experts=att_query_experts,
-                                                           num_query_groups=cross_att_query_groups or att_query_groups, rope_only_for_query=True)
+            cross_att_init = lambda: init_experimental_attention(embed_dim, att_heads, cross_att_type,
+                                                                 cross_att_groups or att_groups, rope=rope,
+                                                                 use_flash_attention=use_flash_attention,
+                                                                 dropout=att_dropout,
+                                                                 max_seq_len=seq_len, is_causal=is_causal,
+                                                                 num_experts=att_experts,
+                                                                 num_query_experts=att_query_experts,
+                                                                 num_query_groups=cross_att_query_groups or att_query_groups,
+                                                                 rope_only_for_query=True)

         layers = nn.ModuleList([
             ReactiveTransformerLayer(
@@ -137,6 +145,12 @@ class RxTAlphaComponentBase(nn.Module, PyTorchModelHubMixin):
     def load_shared_memory(self, stm: ShortTermMemory):
         self.model.stm = stm

+    def memory_parameters(self) -> list[nn.Parameter]:
+        return self.model.memory_parameters()
+
+    def not_memory_parameters(self) -> list[nn.Parameter]:
+        return self.model.not_memory_parameters()
+
     def freeze_without_memory(self, unfreeze_norms: bool = True):
         for param in self.model.parameters():
             param.requires_grad_(False)
@@ -211,20 +225,9 @@ class RxTAlphaDecoder(RxTAlphaComponentBase, pipeline_tag="text-generation", lic
         return self.model(x, attention_mask=attention_mask)


-def build_rxt_alpha_for_pretraining(
-        encoder_config: RxTAlphaComponentConfig,
-        decoder_config: RxTAlphaComponentConfig,
-) -> tuple[RxTAlphaEncoder, RxTAlphaDecoder]:
-    encoder = RxTAlphaEncoder(**encoder_config)
-    decoder = RxTAlphaDecoder(**decoder_config)
-
-    encoder.load_shared_memory(decoder.model.stm)
-    encoder.load_shared_embedding(decoder.model.embedding)
-
-    return encoder, decoder
-
 class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2.0"):
     """RxT-Alpha (Reactive Transformer) memory attention model"""
+
     def __init__(
             self,
             num_layers: int = 12,
@@ -234,17 +237,21 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
             stm_size: int = 1024,
             use_flash_attention: bool = False,
             att_dropout: float = 0.0,
-            norm_type: str = 'rms',
             att_groups: int = 1,
             att_type: str = 'sqa',
             att_experts: int = None,
             att_query_experts: int = None,
             att_query_groups: int = None,
+            norm_type: str = 'rms',
+            norm_init_gate: float = -2.0,
+            norm_per_dim_scale: bool = False,
+            norm_decay: float = 0.9,
             **kwargs,
     ):
         super(RxTAlphaMemoryAttention, self).__init__(**kwargs)

-        assert att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma', 'sqa'], 'Memory attention type could be "mha", "gqa", "mqa", "gma", "dma", "sqa".'
+        assert att_type in ['mha', 'gqa', 'mqa', 'gma', 'dma',
+                            'sqa'], 'Memory attention type could be "mha", "gqa", "mqa", "gma", "dma", "sqa".'

         rope = RotaryPositionalEmbedding(embed_dim // att_heads, seq_len)
         stm = ShortTermMemory(num_layers, embed_dim, stm_size)
@@ -256,11 +263,14 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2
         else:
             att_init = lambda: init_experimental_attention(embed_dim, att_heads, att_type, att_groups, rope=rope,
                                                            use_flash_attention=use_flash_attention, dropout=att_dropout,
-                                                           max_seq_len=seq_len, is_causal=False, num_experts=att_experts,
+                                                           max_seq_len=seq_len, is_causal=False,
+                                                           num_experts=att_experts,
                                                            num_query_experts=att_query_experts,
                                                            num_query_groups=att_query_groups, rope_only_for_keys=True)

-        memory_norm_layers = nn.ModuleList([init_memory_norm(norm_type, embed_dim, stm_size) for _ in range(num_layers)])
+        memory_norm_layers = nn.ModuleList([init_memory_norm(norm_type, embed_dim, stm_size, decay=norm_decay,
+                                                             init_gate=norm_init_gate, per_dim_scale=norm_per_dim_scale)
+                                            for _ in range(num_layers)])
         attention_layers = nn.ModuleList([att_init() for _ in range(num_layers)])
         self.model = StmMemoryAttention(stm, attention_layers, memory_norm_layers)

@@ -283,4 +293,3 @@ class RxTAlphaMemoryAttention(nn.Module, PyTorchModelHubMixin, license="apache-2

     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
         return self.model(x, attention_mask=attention_mask)
-
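On the component side, memory_parameters() and not_memory_parameters() are now forwarded from the wrapped model, and RxTAlphaMemoryAttention threads its norm configuration through to init_memory_norm. A construction sketch; the norm_* values are the new defaults from this diff, while the dimension arguments are purely illustrative:

    from rxnn.rxt.models import RxTAlphaMemoryAttention

    mem_att = RxTAlphaMemoryAttention(
        num_layers=12,
        embed_dim=512,             # illustrative
        att_heads=16,              # illustrative
        seq_len=1024,              # illustrative
        stm_size=1024,
        att_type='sqa',
        norm_type='positional',    # routes to AdaptivePositionalMemoryNorm
        norm_decay=0.9,            # EMA decay for per-slot RMS statistics
        norm_init_gate=-2.0,       # sigmoid(-2.0) ~= 0.12 at initialization
        norm_per_dim_scale=False,  # per-slot scale of shape [num_slots, 1]
    )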
rxnn/training/models.py CHANGED
@@ -124,6 +124,19 @@ class MrlActorModel(nn.Module):
     def reset_memory(self):
         self.memory_attention.reset_memory()

+    def memory_parameters(self) -> list[nn.Parameter]:
+        return list(set(
+            self.encoder.memory_parameters() +
+            self.decoder.memory_parameters() +
+            self.memory_attention.parameters()
+        ))
+
+    def not_memory_parameters(self) -> list[nn.Parameter]:
+        return list(set(
+            self.encoder.not_memory_parameters() +
+            self.decoder.not_memory_parameters()
+        ))
+
     def unique_parameters(self):
         return list(set(
             list(self.encoder.parameters()) +
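Both helpers funnel the concatenated lists through set(), mirroring unique_parameters() just below: tensors hash by identity, so parameters shared between encoder and decoder (a shared embedding, for instance) reach the optimizer exactly once. A toy demonstration of the dedup; one caveat, stated tentatively: Module.parameters() yields a generator, so the memory_attention term above may need a list(...) wrapper for the + concatenation to succeed.

    import torch
    import torch.nn as nn

    shared = nn.Parameter(torch.zeros(4))  # e.g. an embedding used by both components
    encoder_params = [shared, nn.Parameter(torch.zeros(2))]
    decoder_params = [shared, nn.Parameter(torch.zeros(3))]

    unique = list(set(encoder_params + decoder_params))
    print(len(encoder_params + decoder_params), len(unique))  # 4 3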
rxnn/training/mrl.py CHANGED
@@ -17,6 +17,8 @@ from .models import MrlActorAction, MrlActorModel, MrlCriticModel

 class MrlConfig(TypedDict):
     lr: float
+    separate_memory_lr: Optional[bool]
+    memory_lr: Optional[float]
     critic_lr: float
     max_seq_len: int
     critic_max_len: int
@@ -42,7 +44,9 @@ class CurriculumConfig(TypedDict):
     random_resets_from: Optional[int]
     random_resets_ratio: Optional[float]
     reward_model: Optional[MrlRewardModel]
+    separate_memory_lr: Optional[bool]
     lr: Optional[float]
+    memory_lr: Optional[float]
     critic_lr: Optional[float]
     weight_decay: Optional[float]
     critic_weight_decay: Optional[float]
@@ -84,6 +88,7 @@ class MRLTrainer:
             use_amp: bool = False,
             dtype: torch.dtype = torch.float32,
             callbacks: list[MrlTrainerCallback] = None,
+
     ):
         """
         Trainer for Memory Reinforcement Learning (MRL) in Reactive Transformer.
@@ -123,15 +128,25 @@ class MRLTrainer:
         self.use_amp = use_amp
         self.dtype = dtype

-        self.base_optim_config = {
-            'lr': config.get('lr', 3e-4),
-            'critic_lr': config.get('critic_lr', 1e-4),
-            'weight_decay': config.get('weight_decay', 0.01),
-            'critic_weight_decay': config.get('critic_weight_decay', 0.01),
-        }
+        self.separate_memory_lr = config.get('separate_memory_lr', False)
+
+        if self.separate_memory_lr:
+            self.base_optim_config = {
+                'lr': (config.get('lr', 3e-4), config.get('memory_lr', 5e-4)),
+                'critic_lr': config.get('critic_lr', 1e-4),
+                'weight_decay': config.get('weight_decay', 0.01),
+                'critic_weight_decay': config.get('critic_weight_decay', 0.01),
+            }
+        else:
+            self.base_optim_config = {
+                'lr': config.get('lr', 3e-4),
+                'critic_lr': config.get('critic_lr', 1e-4),
+                'weight_decay': config.get('weight_decay', 0.01),
+                'critic_weight_decay': config.get('critic_weight_decay', 0.01),
+            }

         # Optimizers
-        self.optimizer, self.critic_optimizer = self._init_optimizers(**self.base_optim_config)
+        self.optimizer, self.critic_optimizer = self._init_optimizers(**self.base_optim_config, separate_memory_lr=self.separate_memory_lr)

         self.scaler = torch.amp.GradScaler() if self.use_amp else None
         self.critic_scaler = torch.amp.GradScaler() if self.use_amp else None
@@ -158,18 +173,28 @@ class MRLTrainer:
         self.global_epoch = 0
         self.global_epochs_count = 0

-    def _init_optimizers(self, lr: float, critic_lr: float, weight_decay: float, critic_weight_decay: float):
-        optimizer = torch.optim.AdamW(
-            self.actor.unique_parameters(),
-            lr=lr,
-            weight_decay=weight_decay,
-        )
+    def _init_optimizers(self, lr: Union[float, tuple[float, float]], critic_lr: float, weight_decay: float, critic_weight_decay: float, separate_memory_lr: bool = False) -> tuple[torch.optim.Optimizer, torch.optim.Optimizer]:
+        if separate_memory_lr:
+            rest_lr, memory_lr = lr
+            optimizer = torch.optim.AdamW([
+                { 'params': self.actor.not_memory_parameters(), 'lr': rest_lr },
+                { 'params': self.actor.memory_parameters(), 'lr': memory_lr },
+            ],
+                weight_decay=weight_decay,
+            )
+        else:
+            optimizer = torch.optim.AdamW(
+                self.actor.unique_parameters(),
+                lr=lr,
+                weight_decay=weight_decay,
+            )

         critic_optimizer = torch.optim.AdamW(
             self.critic.parameters(),
             lr=critic_lr,
             weight_decay=critic_weight_decay,
         )
+
         return optimizer, critic_optimizer

@@ -722,12 +747,13 @@ class MRLTrainer:
         self.strategy = config.get('strategy',
                                    MrlStrategy.MULTI_STEP_STRATEGY)  # MRL strategy for given curriculum stage
         self.reward = config.get('reward_model', self.shared_reward_model)  # MRL Reward Model for curriculum stage
-        if config['lr'] is not None or config['critic_lr'] is not None or config['weight_decay'] is not None or config['critic_weight_decay'] is not None:
+        if config['lr'] is not None or config['critic_lr'] is not None or config['weight_decay'] is not None or config['critic_weight_decay'] is not None or (config['separate_memory_lr'] and config['memory_lr'] is not None):
             self.optimizer, self.critic_optimizer = self._init_optimizers(
-                lr=config.get('lr', self.base_optim_config['lr']),
+                lr=(config.get('lr', self.base_optim_config['lr'][0]), config.get('memory_lr', self.base_optim_config['lr'][1])) if config.get('separate_memory_lr', False) else config.get('lr', self.base_optim_config['lr']),
                 critic_lr=config.get('critic_lr', self.base_optim_config['critic_lr']),
                 weight_decay=config.get('weight_decay', self.base_optim_config['weight_decay']),
-                critic_weight_decay=config.get('critic_weight_decay', self.base_optim_config['critic_weight_decay'])
+                critic_weight_decay=config.get('critic_weight_decay', self.base_optim_config['critic_weight_decay']),
+                separate_memory_lr=config.get('separate_memory_lr', False),
             )

         # 2. Get epochs and random resets configs
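With separate_memory_lr enabled, _init_optimizers builds a single AdamW over two parameter groups instead of one flat learning rate. A self-contained sketch with stand-in parameter lists; the 3e-4 and 5e-4 values are the defaults read from MrlConfig above:

    import torch
    import torch.nn as nn

    model = nn.Linear(8, 8)
    rest_params = [model.weight]   # stands in for actor.not_memory_parameters()
    memory_params = [model.bias]   # stands in for actor.memory_parameters()

    optimizer = torch.optim.AdamW(
        [
            {'params': rest_params, 'lr': 3e-4},    # config 'lr' default
            {'params': memory_params, 'lr': 5e-4},  # config 'memory_lr' default
        ],
        weight_decay=0.01,  # shared across both groups
    )
    print([group['lr'] for group in optimizer.param_groups])  # [0.0003, 0.0005]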
rxnn/transformers/layers.py CHANGED
@@ -64,6 +64,13 @@ class ReactiveTransformerLayer(nn.Module):
         for param in self.norm2.parameters():
             param.requires_grad_(is_trainable)

+    def memory_parameters(self) -> list[nn.Parameter]:
+        return list(self.memory_cross_attention.parameters()) + list(self.norm2.parameters())
+
+    def not_memory_parameters(self) -> list[nn.Parameter]:
+        memory_params = self.memory_parameters()
+        return [param for param in self.parameters() if param not in memory_params]
+
     def update_max_len(self, max_seq_len: int):
         if self.attention.rope is not None:
             self.attention.rope.update_max_len(max_seq_len)
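One detail worth flagging in not_memory_parameters(): `param not in memory_params` uses Python's membership test, which tries identity first and then falls back to `==`; Tensor.__eq__ compares elementwise, so testing a parameter against a non-identical one can raise instead of returning False. A defensive, identity-keyed variant (a sketch of the apparent intent, not the library's code):

    import torch.nn as nn

    def params_excluding(module: nn.Module, excluded: list[nn.Parameter]) -> list[nn.Parameter]:
        # Keyed on id(), so elementwise Tensor.__eq__ is never invoked
        excluded_ids = {id(p) for p in excluded}
        return [p for p in module.parameters() if id(p) not in excluded_ids]

    layer = nn.Linear(4, 4)
    print(len(params_excluding(layer, [layer.bias])))  # 1 (just the weight)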
rxnn/transformers/models.py CHANGED
@@ -39,6 +39,16 @@ class ReactiveTransformerBase(nn.Module):
         for i in range(self.num_own_layers):
             self.layers[i].trainable_cross_attention_(is_trainable, with_norms)

+    def memory_parameters(self) -> list[nn.Parameter]:
+        own = [param for layer in self.layers for param in layer.memory_parameters()]
+        shared = [param for layer in self.shared_layers for param in layer.memory_parameters()] if self.shared_layers else []
+        return own + shared
+
+    def not_memory_parameters(self) -> list[nn.Parameter]:
+        own = [param for layer in self.layers for param in layer.not_memory_parameters()]
+        shared = [param for layer in self.shared_layers for param in layer.not_memory_parameters()] if self.shared_layers else []
+        return own + shared
+
     def moe_router_loss(self):
         return torch.stack([self.layers[i].moe_router_loss() for i in range(self.num_own_layers) if self.layers[i].use_moe or self.layers[i].use_moe_att] + [
             self.shared_layers[i].moe_router_loss() for i in range(self.num_shared_layers) if self.shared_layers[i].use_moe or self.shared_layers[i].use_moe_att]).mean()
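ReactiveTransformerBase performs the actual aggregation, walking both its own layers and any shared layers; this is the list that ultimately lands in MRLTrainer's memory-lr parameter group. A toy version of the own-plus-shared pattern (ToyLayer is a stand-in, not the library's layer class):

    import torch.nn as nn

    class ToyLayer(nn.Module):
        def __init__(self):
            super().__init__()
            self.attention = nn.Linear(4, 4)
            self.memory_cross_attention = nn.Linear(4, 4)

        def memory_parameters(self) -> list[nn.Parameter]:
            return list(self.memory_cross_attention.parameters())

    layers = nn.ModuleList([ToyLayer(), ToyLayer()])
    shared_layers = None  # e.g. a component with no shared layers

    own = [p for layer in layers for p in layer.memory_parameters()]
    shared = [p for layer in shared_layers for p in layer.memory_parameters()] if shared_layers else []
    print(len(own + shared))  # 4: weight + bias from each layer's cross-attention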
rxnn-0.2.24.dist-info/METADATA → rxnn-0.2.26.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.2.24
+Version: 0.2.26
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
rxnn-0.2.24.dist-info/RECORD → rxnn-0.2.26.dist-info/RECORD CHANGED
@@ -5,18 +5,18 @@ rxnn/experimental/attention.py,sha256=46qwZLJuZMpIBrZ-r9DaQEPPmmZkO464C3Tkm_Mq-c
 rxnn/experimental/models.py,sha256=foBo0n0ufvBnfIdJomiEg3CuSOiWSt-q5ako7vzYxx4,4888
 rxnn/experimental/moe.py,sha256=jHZ1QhpWiVQOswVpFmuH7b2IUOPf0Uuf-I2Ddwsd7Us,6140
 rxnn/memory/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/memory/attention.py,sha256=p-r8DK3iVhNn-JAESVzIXDCG8gk1R_-x5xHclZ5jgb0,1813
-rxnn/memory/norm.py,sha256=Ofl8Q5NYEF9GQeO0bhM43tkTW91J0y6TSvTAOYMgloM,6278
+rxnn/memory/attention.py,sha256=PXVBZQYNsRraZh7QDBgUOdPy3lTI8B0d8CzduojBjG0,1747
+rxnn/memory/norm.py,sha256=E98jOQEuIOFFhlkvS8s4fFN-D4tLO6vaOqnObv1oVmA,6592
 rxnn/memory/stm.py,sha256=IH_3INw7FdI013t56ui3Zq9GPUq-k3HeZGjx6BerS4g,3888
 rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/rxt/models.py,sha256=3gCYD_OXvQc8GaXQvRCSj1OcYOSHayWlpP5lsg9wMMk,12389
+rxnn/rxt/models.py,sha256=r8wZeeNTC2VAhiiNe4y7LrbnB4wjFu_cupKiGkpdgjI,13002
 rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/training/base.py,sha256=_xik1GXE4RJ_nxwqLQ1ccXA5pRtBCi-jL-jeRFBdHBU,11851
 rxnn/training/bml.py,sha256=FJszaQXOLx2ZHBa1CQpyMrG8i4Kj14E-gzDAEK_Ei5k,17272
 rxnn/training/callbacks.py,sha256=-N0MQPpZQaUWCINdTOsjul4bDGbGr2JgQBqOIXBLS6o,35053
 rxnn/training/dataset.py,sha256=7hTilFWPpqUEc6zNcMqBPjxFKxCfvTKKF3E8tVlwccQ,51250
-rxnn/training/models.py,sha256=5fl1hESVj2Hakqz5to8ZJzw5Q4_RKZAUq2bn6nRiPV8,6045
-rxnn/training/mrl.py,sha256=14wx3pVha15B7eRWPRgoxRtV5dPtBI0yadIHOYZjX6k,43275
+rxnn/training/models.py,sha256=_TrFwrQ_m6NDPalrafd8faPRyCnDFFFtN_gfzavaCFs,6474
+rxnn/training/mrl.py,sha256=hDsKQTaQcEVmnJruD3TxHZJJzDWu5I6Rq2HVDLj8ADU,44747
 rxnn/training/reward.py,sha256=7MTVdNm5HnWmt6zFDi3TAYmnVSL_-24riOoY2F7z4x8,11290
 rxnn/training/rl.py,sha256=j-KNLoZjhaEKasYNOc8DxHtwvknAgAJFwvXKot6otFA,3272
 rxnn/training/scheduler.py,sha256=LcjU35mEwz2U5x3U6tLfeeYlBqMxbFSxYzJYuXkWbSY,1408
@@ -25,14 +25,14 @@ rxnn/training/utils.py,sha256=Bw8nZLKIt7NQpUVCYkb_79kWKChVFOYgYXwODo4SvNc,5718
 rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/transformers/attention.py,sha256=d0Igo1Nrn76BphbHrzekiKJfT3RCy4iSXSB6FLAOTgc,16247
 rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
-rxnn/transformers/layers.py,sha256=UQZbrAg1UAttPASeqS7BP1a4JalktThmRMzX99Qghss,7618
+rxnn/transformers/layers.py,sha256=LXSY829fIHSCmFmClhQ6B7I5aKbiOqy9mZmwlJG_r7U,7961
 rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
-rxnn/transformers/models.py,sha256=_2qO1SASHtKvTW3dW-Dy9HEmAvoNVC1_addm2tM9Zbs,8325
+rxnn/transformers/models.py,sha256=QwVxYN9DrKllEpOiFoAx4CiThOWafeTa-OAY7L6gN0Y,8929
 rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
 rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.24.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
-rxnn-0.2.24.dist-info/METADATA,sha256=PrVfcCd8NBFtFnD8lAJqU7UW3lLEc-Tr7MQhK6obvuo,25960
-rxnn-0.2.24.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
-rxnn-0.2.24.dist-info/RECORD,,
+rxnn-0.2.26.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.26.dist-info/METADATA,sha256=XDqI42X3zLRAAKZlVLmstm24KFPP_MfvDtObG9GBc0Y,25960
+rxnn-0.2.26.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.26.dist-info/RECORD,,
rxnn-0.2.26.dist-info/LICENSE: file without changes
rxnn-0.2.26.dist-info/WHEEL: file without changes