rxnn-0.1.52-py3-none-any.whl → rxnn-0.1.54-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/training/base.py CHANGED
@@ -50,6 +50,10 @@ class BaseTrainer(ABC):
         self.target_field_name = target_field_name
         self.total_tokens = 0
         self.total_steps = 0
+        self.validation_steps = 0
+        self.total_validation_steps = 0
+        self.epoch_steps = 0
+        self.current_epoch = 0
         self.gradient_accumulation_steps = gradient_accumulation_steps
         self.accumulated_loss = 0.0
         self.optimizer_step_count = 0
@@ -108,8 +112,10 @@ class BaseTrainer(ABC):
         scaler = torch.amp.GradScaler() if self.use_amp else None

         self.model.train()
-        for epoch in range(epochs):
+        for epoch in range(self.current_epoch, self.current_epoch + epochs):
             if self.is_running:
+                self.current_epoch = epoch
+                self.epoch_steps = 0
                 if train_sampler is not None:
                     train_sampler.set_epoch(epoch)
                 self._run_epoch(dataloader, epoch, optimizer, batch_size, scaler=scaler, scheduler=scheduler)
@@ -142,6 +148,7 @@ class BaseTrainer(ABC):
                 callback.on_batch_start(self.model, batch_idx, batch)
             if self.get_batch_size(batch) == batch_size:
                 self.total_steps += 1
+                self.epoch_steps = batch_idx
                 loss = self.train_step(batch, batch_idx)
                 orig_loss = loss.item()
                 self.accumulated_loss += orig_loss
@@ -174,25 +181,28 @@ class BaseTrainer(ABC):
                         self.writer.add_scalar(
                             'Loss/train',
                             loss_item,
-                            epoch * len(dataloader) + batch_idx
+                            self.total_steps,
                         )
                         self.writer.add_scalar(
-                            'Loss per epoch/train',
+                            'Loss/train last epoch',
                             loss_item,
                             batch_idx
                         )
                         self.writer.add_scalar(
                             'Perplexity/train',
                             torch.exp(torch.tensor(loss_item)),
-                            epoch * len(dataloader) + batch_idx
+                            self.total_steps,
                         )
                     self.accumulated_loss = 0.0
                     self.optimizer_step_count = 0

                 if self.writer:
                     self.total_tokens += batch['attention_mask'].sum().item()
-                    self.writer.add_scalar('Processed tokens', self.total_tokens,
-                                           epoch * len(dataloader) + batch_idx)
+                    self.writer.add_scalar(
+                        'Processed tokens',
+                        self.total_tokens,
+                        self.total_steps
+                    )

                 for callback in self.callbacks:
                     should_stop = callback.on_batch_end(self.model, batch_idx, orig_loss, batch)
@@ -200,6 +210,7 @@ class BaseTrainer(ABC):
                         self.is_running = False

         if self.validation_dataset:
+            self.validation_steps = 0
             val_loss, val_metrics = self.validate(batch_size)
             val_loss_tensor = torch.tensor(val_loss).to(self.device)
             if self.use_ddp:
@@ -270,6 +281,8 @@ class BaseTrainer(ABC):
         with torch.no_grad():
             for batch in val_dataloader:
                 if self.get_batch_size(batch) == batch_size:
+                    self.validation_steps += 1
+                    self.total_validation_steps += 1
                     loss, outputs = self.valid_step(batch)
                     val_loss += loss.item()

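Taken together, the base.py changes keep persistent counters on the trainer (`current_epoch`, `epoch_steps`, `validation_steps`, `total_validation_steps`) and use the monotonic `total_steps` counter as the TensorBoard x-axis instead of `epoch * len(dataloader) + batch_idx`, so a second call to the trainer continues the curves rather than restarting them at zero. A minimal sketch of the same bookkeeping pattern, not rxnn code; the `TinyTrainer` class, the dummy data and the `train_step` stub are illustrative assumptions:

import torch
from torch.utils.tensorboard import SummaryWriter

class TinyTrainer:
    def __init__(self):
        self.current_epoch = 0
        self.total_steps = 0
        self.epoch_steps = 0
        self.writer = SummaryWriter(log_dir='./runs/sketch')

    def train_step(self, batch: torch.Tensor) -> torch.Tensor:
        # stand-in for a real forward/backward pass
        return batch.float().mean()

    def __call__(self, dataloader, epochs: int):
        # same idea as the diff: continue counting from self.current_epoch,
        # so a second __call__ extends the previous run instead of restarting at epoch 0
        for epoch in range(self.current_epoch, self.current_epoch + epochs):
            self.current_epoch = epoch
            self.epoch_steps = 0
            for batch_idx, batch in enumerate(dataloader):
                self.total_steps += 1
                self.epoch_steps = batch_idx
                loss = self.train_step(batch)
                # global step as x-axis keeps the curve monotonic across epochs and re-runs
                self.writer.add_scalar('Loss/train', loss.item(), self.total_steps)

trainer = TinyTrainer()
data = [torch.randn(4, 8) for _ in range(10)]
trainer(data, epochs=2)   # runs epochs 0-1
trainer(data, epochs=2)   # resumes as epochs 2-3, steps keep increasing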
rxnn/training/bml.py CHANGED
@@ -91,8 +91,8 @@ class MLMTrainer(BaseTrainer):
                 self.writer.add_scalar('Router aux loss/Train', router_loss.item(), self.total_steps)
                 self.writer.add_scalar('Model loss/Train', main_loss.item(), self.total_steps)
             else:
-                self.writer.add_scalar('Router aux loss/Valid', router_loss.item(), self.total_steps)
-                self.writer.add_scalar('Model loss/Valid', main_loss.item(), self.total_steps)
+                self.writer.add_scalar('Router aux loss/Valid', router_loss.item(), self.total_validation_steps)
+                self.writer.add_scalar('Model loss/Valid', main_loss.item(), self.total_validation_steps)

         return loss

@@ -106,14 +106,25 @@ class MLMTrainer(BaseTrainer):
         with torch.no_grad():
             for batch in val_dataloader:
                 if self.get_batch_size(batch) == batch_size:
+                    self.total_validation_steps += 1
+                    self.validation_steps += 1
                     loss, logits = self.valid_step(batch)
                     val_loss += loss
+                    if self.writer is not None:
+                        self.writer.add_scalar('Loss/Valid total', loss.item(), self.total_validation_steps)
+                        self.writer.add_scalar('Perplexity/Valid', torch.exp(loss).item(), self.total_validation_steps)
+
                     labels = batch[self.target_field_name].to(self.device)
                     valid_indices = labels != -100
                     if valid_indices.any():
                         preds = logits.argmax(-1)
-                        correct += (preds[valid_indices] == labels[valid_indices]).sum()
-                        total += valid_indices.sum()
+                        batch_correct = (preds[valid_indices] == labels[valid_indices]).sum()
+                        batch_total = valid_indices.sum()
+                        batch_acc = (batch_correct / batch_total * 100) if total > 0 else torch.tensor(0.0).to(self.device)
+                        if self.writer is not None:
+                            self.writer.add_scalar('Accuracy/Valid total', batch_acc.item(), self.total_validation_steps)
+                        correct += batch_correct
+                        total += batch_total

         avg_loss = (val_loss / len(val_dataloader)).item()
         acc = (correct / total * 100) if total > 0 else torch.tensor(0.0).to(self.device)
@@ -197,15 +208,25 @@ class AutoregressiveTrainer(BaseTrainer):
         with torch.no_grad():
             for batch in val_dataloader:
                 if self.get_batch_size(batch) == batch_size:
+                    self.total_validation_steps += 1
+                    self.validation_steps += 1
                     loss, logits = self.valid_step(batch)
                     val_loss += loss
+                    if self.writer is not None:
+                        self.writer.add_scalar('Loss/Valid total', loss.item(), self.total_validation_steps)
+                        self.writer.add_scalar('Perplexity/Valid', torch.exp(loss).item(), self.total_validation_steps)
                     shifted_logits = logits[:, :-1].contiguous()
                     shifted_targets = batch[self.target_field_name][:, 1:].to(self.device).contiguous()
                     valid_indices = shifted_targets != -100
                     if valid_indices.any():
                         preds = shifted_logits.argmax(-1)
-                        correct += (preds[valid_indices] == shifted_targets[valid_indices]).sum()
-                        total += valid_indices.sum()
+                        batch_correct = (preds[valid_indices] == shifted_targets[valid_indices]).sum()
+                        batch_total = valid_indices.sum()
+                        batch_acc = (batch_correct / batch_total * 100) if total > 0 else torch.tensor(0.0).to(self.device)
+                        if self.writer is not None:
+                            self.writer.add_scalar('Accuracy/Valid total', batch_acc.item(), self.total_validation_steps)
+                        correct += batch_correct
+                        total += batch_total

         avg_loss = (val_loss / len(val_dataloader)).item()
         acc = (correct / total * 100) if total > 0 else torch.tensor(0.0).to(self.device)
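Both validate() loops now log per-batch loss, perplexity and accuracy against the running `total_validation_steps` counter, and accuracy is computed only over positions whose label is not -100. A minimal sketch of that masked-accuracy step, not rxnn code; the toy tensor shapes are assumptions:

import torch

def masked_batch_accuracy(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    valid_indices = labels != -100              # ignore padding / unmasked positions
    if not valid_indices.any():
        return torch.tensor(0.0)
    preds = logits.argmax(-1)
    batch_correct = (preds[valid_indices] == labels[valid_indices]).sum()
    batch_total = valid_indices.sum()
    return batch_correct / batch_total * 100    # percentage, like 'Accuracy/Valid total'

logits = torch.randn(2, 6, 32)                  # (batch, seq_len, vocab)
labels = torch.randint(0, 32, (2, 6))
labels[:, :3] = -100                            # mask out half the positions
print(masked_batch_accuracy(logits, labels))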
rxnn/training/dataset.py CHANGED
@@ -1,7 +1,8 @@
 import torch
 from torch.utils.data import Dataset
-from datasets import Dataset as HfDataset
-from transformers import PreTrainedTokenizer
+from datasets import Dataset as HfDataset, load_dataset, concatenate_datasets
+from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
+from .tokenizer import load_tokenizer_from_hf_hub

 from typing import Union

@@ -10,10 +11,9 @@ class BaseDataset(Dataset):
     def __init__(
             self,
             texts: Union[list[str], HfDataset],
-            tokenizer: PreTrainedTokenizer,
+            tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
             max_seq_len: int = 1024,
             hf_field: str = 'text',
-            merge_short_from: int = None,
             *args,
             **kwargs
     ):
@@ -22,7 +22,6 @@ class BaseDataset(Dataset):
         self.max_seq_len = max_seq_len
         self.texts = texts
         self.hf_field = hf_field
-        self.merge_short_from = merge_short_from

     def get_tokenized_text(self, idx: int):
         if isinstance(self.texts, list):
@@ -45,6 +44,93 @@ class BaseDataset(Dataset):

         return inputs

+    @classmethod
+    def from_hf_hub(
+            cls,
+            dataset_id: str,
+            subset: str = None,
+            split: str = 'train',
+            target_field: str = 'text',
+            tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] = None,
+            tokenizer_hub_id: str = None,
+            max_seq_len: int = 1024,
+            load_kwargs: dict = None,
+            load_tokenizer_kwargs: dict = None,
+            **kwargs
+    ):
+        """
+        Load dataset from HuggingFace Hub and convert it to RxNN training dataset.
+
+        One of the `tokenizer` or `tokenizer_hub_id` args must be provided. If both are provided, `tokenizer` will be used.
+
+        Args:
+            dataset_id (str): Hub dataset repository name
+            subset (str): Dataset subset
+            split (str): Dataset split (default: "train")
+            target_field (str): Name of dataset field used for training (default: "text")
+            tokenizer (PreTrainedTokenizer): HuggingFace Tokenizer used for training (default: None)
+            tokenizer_hub_id (str): HuggingFace Hub ID of tokenizer to load (default: None)
+            max_seq_len (int): Maximum sequence length for training (default: 1024)
+            load_kwargs (dict): Additional args for HuggingFace API load_dataset function
+            load_tokenizer_kwargs (dict): Additional args for loading tokenizer from HuggingFace API with `huggingface_hub.hf_hub_download`
+            **kwargs: Additional args for RxNN Dataset class
+        """
+        assert tokenizer is not None or tokenizer_hub_id is not None, "One of the `tokenizer` or `tokenizer_hub_id` args must be provided."
+
+        if tokenizer is None:
+            tokenizer = load_tokenizer_from_hf_hub(tokenizer_hub_id, **load_tokenizer_kwargs)
+
+        hf_dataset = load_dataset(dataset_id, subset, split=split, **load_kwargs)
+
+        return cls(hf_dataset, tokenizer, max_seq_len=max_seq_len, hf_field=target_field, **kwargs)
+
+    @classmethod
+    def concat_from_hf_hub(
+            cls,
+            dataset_ids: tuple[str],
+            subsets: tuple[str] = None,
+            split: str = 'train',
+            target_field: str = 'text',
+            tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] = None,
+            tokenizer_hub_id: str = None,
+            max_seq_len: int = 1024,
+            load_kwargs: dict = None,
+            load_tokenizer_kwargs: dict = None,
+            **kwargs
+    ):
+        """
+        Load and concatenate multiple datasets from HuggingFace Hub and convert them to RxNN training dataset.
+        All datasets should use the same split and target field. If it's not the case, just use `load_dataset` and pass the
+        result to RxNN dataset constructor directly.
+
+        One of the `tokenizer` or `tokenizer_hub_id` args must be provided. If both are provided, `tokenizer` will be used.
+
+        Args:
+            dataset_ids (tuple[str]): Hub dataset repository names
+            subsets (tuple[str]): Dataset subsets (default: None)
+            split (str): Dataset split (default: "train")
+            target_field (str): Name of dataset field used for training (default: "text")
+            tokenizer (PreTrainedTokenizer): HuggingFace Tokenizer used for training (default: None)
+            tokenizer_hub_id (str): HuggingFace Hub ID of tokenizer to load (default: None)
+            max_seq_len (int): Maximum sequence length for training (default: 1024)
+            load_kwargs (dict): Additional args for HuggingFace API load_dataset function
+            load_tokenizer_kwargs (dict): Additional args for loading tokenizer from HuggingFace API with `huggingface_hub.hf_hub_download`
+            **kwargs: Additional args for RxNN Dataset class
+        """
+        assert tokenizer is not None or tokenizer_hub_id is not None, "One of the `tokenizer` or `tokenizer_hub_id` args must be provided."
+
+        if tokenizer is None:
+            tokenizer = load_tokenizer_from_hf_hub(tokenizer_hub_id, **load_tokenizer_kwargs)
+
+        hf_datasets = [
+            load_dataset(dataset_id, subset, split=split, **load_kwargs) for dataset_id, subset in zip(dataset_ids, subsets)
+        ] if subsets is not None else [
+            load_dataset(dataset_id, split=split, **load_kwargs) for dataset_id in dataset_ids
+        ]
+        hf_dataset = concatenate_datasets(hf_datasets)
+
+        return cls(hf_dataset, tokenizer, max_seq_len=max_seq_len, hf_field=target_field, **kwargs)
+

 class JointLMDataset(BaseDataset):
     def __init__(
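The new classmethods wrap `datasets.load_dataset` / `concatenate_datasets` plus optional tokenizer loading, so a training dataset can be built in one call on any BaseDataset subclass. A hypothetical usage sketch: the repository IDs are placeholders, the calls go through `JointLMDataset` on the assumption that it accepts the same constructor arguments via `**kwargs`, and empty dicts are passed explicitly because the helpers unpack `load_kwargs`/`load_tokenizer_kwargs` with `**`:

from rxnn.training.dataset import JointLMDataset

# Single dataset, with the tokenizer pulled from the Hub by repo id (placeholder ids)
dataset = JointLMDataset.from_hf_hub(
    'my-org/my-text-corpus',
    split='train',
    target_field='text',
    tokenizer_hub_id='my-org/my-tokenizer',
    max_seq_len=1024,
    load_kwargs={},
    load_tokenizer_kwargs={},
)

# Several corpora sharing the same split/field, concatenated into one training dataset
combined = JointLMDataset.concat_from_hf_hub(
    ('my-org/corpus-a', 'my-org/corpus-b'),
    split='train',
    target_field='text',
    tokenizer_hub_id='my-org/my-tokenizer',
    load_kwargs={},
    load_tokenizer_kwargs={},
)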
rxnn/training/tokenizer.py CHANGED
@@ -206,3 +206,9 @@ class TokenizerTrainer:
         trainer = cls()
         trainer.load(tokenizer_file)
         return trainer
+
+def load_tokenizer_from_hf_hub(repo_id: str, **kwargs) -> PreTrainedTokenizerFast:
+    return TokenizerTrainer.from_pretrained(repo_id, **kwargs).get_hf_tokenizer()
+
+def load_tokenizer_from_file(path: str) -> PreTrainedTokenizerFast:
+    return TokenizerTrainer.hf_tokenizer_from_file(path)
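The two new module-level helpers return a ready-to-use `PreTrainedTokenizerFast`, either from a Hub repository (via `TokenizerTrainer.from_pretrained`) or from a local tokenizer file. A hypothetical usage sketch; the repo id and file path are placeholders:

from rxnn.training.tokenizer import load_tokenizer_from_hf_hub, load_tokenizer_from_file

hub_tokenizer = load_tokenizer_from_hf_hub('my-org/my-tokenizer')          # via TokenizerTrainer.from_pretrained
local_tokenizer = load_tokenizer_from_file('./tokenizer/tokenizer.json')   # via TokenizerTrainer.hf_tokenizer_from_file
print(type(hub_tokenizer).__name__, type(local_tokenizer).__name__)        # PreTrainedTokenizerFast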
rxnn/transformers/moe.py CHANGED
@@ -23,20 +23,21 @@ class MoeRouter(nn.Module):

     def calculate_aux_loss(self, top_k_indices: torch.Tensor, probs: torch.Tensor) -> torch.Tensor:
         # Get shapes
-        B, S, K = top_k_indices.shape # Batch, Sequence length, Top-K
+        T, K = top_k_indices.shape # Batch, Sequence length, Top-K

         # 1. Compute expert selection mask (one-hot encoded)
         expert_mask = F.one_hot(top_k_indices, self.num_experts).float() # (B, S, K, E)

         # 2. Total number of times each expert is selected
-        expert_usage = expert_mask.sum(dim=(0, 1, 2)) # (E,)
+        expert_usage = expert_mask.sum(dim=(0, 1)) # (E,)

         # 3. Fraction of tokens assigned to each expert
-        total_tokens = B * S * K
-        fraction_expert = expert_usage / total_tokens # (E,)
+        total_selections = T * K
+        fraction_expert = expert_usage / (total_selections + 1e-6) # (E,)

         # 4. Sum of probabilities for each expert's selected tokens
-        sum_probs = (probs.unsqueeze(-1) * expert_mask).sum(dim=(0, 1, 2)) # (E,)
+        probs_expanded = probs.unsqueeze(1).expand(-1, K, -1) # (B_K, K, E)
+        sum_probs = (probs_expanded * expert_mask).sum(dim=(0, 1))

         # 5. Average probability per expert (avoid division by zero)
         avg_probs = sum_probs / expert_usage.clamp(min=1e-6) # (E,)
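The aux-loss change assumes the router now works on flattened inputs: `top_k_indices` of shape (T, K) and `probs` of shape (T, E), where T is the number of tokens (batch × sequence) and E the number of experts. A minimal sketch of the same shape arithmetic on toy tensors, not rxnn code; the final Switch-style scaling by E is an assumption, since the diff does not show how the two terms are combined:

import torch
import torch.nn.functional as F

T, K, E = 6, 2, 4
probs = torch.softmax(torch.randn(T, E), dim=-1)            # router probabilities, (T, E)
top_k_indices = probs.topk(K, dim=-1).indices                # (T, K)

expert_mask = F.one_hot(top_k_indices, E).float()            # (T, K, E)
expert_usage = expert_mask.sum(dim=(0, 1))                   # (E,) selections per expert
fraction_expert = expert_usage / (T * K + 1e-6)              # (E,)

probs_expanded = probs.unsqueeze(1).expand(-1, K, -1)        # (T, K, E)
sum_probs = (probs_expanded * expert_mask).sum(dim=(0, 1))   # (E,)
avg_probs = sum_probs / expert_usage.clamp(min=1e-6)         # (E,)

aux_loss = (fraction_expert * avg_probs).sum() * E           # Switch-style balance term (assumed scaling)
print(aux_loss)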
rxnn-0.1.52.dist-info/METADATA → rxnn-0.1.54.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: rxnn
-Version: 0.1.52
+Version: 0.1.54
 Summary: RxNN: Reactive Neural Networks Platform
 License: Apache-2.0
 Keywords: deep-learning,ai,machine-learning
rxnn-0.1.52.dist-info/RECORD → rxnn-0.1.54.dist-info/RECORD CHANGED
@@ -9,23 +9,23 @@ rxnn/memory/stm.py,sha256=EsD8slSP4_9dLuq6aFPDmuFe8PWilxh90so5Z3nm-ig,2057
 rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/rxt/models.py,sha256=INTFeNcqzAsjyWhNtbBHL4Tx7tYDsaQHgm72tf6u20M,6918
 rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/training/base.py,sha256=QD8uS14jSyR5Y_8BgCaBQTKpsarerU3lyufsWsCq_6o,11227
-rxnn/training/bml.py,sha256=HtxSzI-WcpRclAuIccF_WoTZ24KzH5ZWKe8SnWgjjm4,17581
+rxnn/training/base.py,sha256=gEWASLSuWR8UF8b2e-DYqkBZ1lBx0VsIm4kGf9eWSHM,11678
+rxnn/training/bml.py,sha256=S1ZaXTybzeJH7uVFamCr4TPl2bLyZ5xmn_lSsjThTiM,19162
 rxnn/training/callbacks.py,sha256=_YfMKY_eFdc-tubhO9nYH2PXDZDQwlSI74FVOoCXpQg,22108
-rxnn/training/dataset.py,sha256=vQ5mDF3bA0HXya474n4D4iL8Mn3AEpJukgzFNVkxjGU,5106
+rxnn/training/dataset.py,sha256=JQuWSUdT5AnsrG6M_EsewoU6uroVHhg4K715nbtDx8A,9643
 rxnn/training/scheduler.py,sha256=ow6oALzWjWQmHSpcJEjv6tg4g4CDMvr73TypxfcefMc,712
-rxnn/training/tokenizer.py,sha256=4Y41f07uo2KPA_7bp3FCcwGKbXoS2hsckOoXUsXfQxY,8052
+rxnn/training/tokenizer.py,sha256=umaLByMBx_NMrQElA45HLm9gkuzyKWDTFaKVd-CjXl0,8344
 rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/transformers/attention.py,sha256=dC0UmC-_kjX8US6Sf0Fi5zw5kJ-P6orH3JDHeBB5gI8,15695
 rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
 rxnn/transformers/layers.py,sha256=OX8CsFY9A7uqH1SLwyexR_5BNlwheYrJHCGXjF8Q7HU,7186
 rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
 rxnn/transformers/models.py,sha256=QFzBrOR7tDp9d_T0HoIukBMfEbLxsCictV5p3e2ilxg,7552
-rxnn/transformers/moe.py,sha256=88-w4cQhYNcebdq4zBsdkaoFa4VxJi1LFXDKAAkfVLk,5791
+rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
 rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
 rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
-rxnn-0.1.52.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
-rxnn-0.1.52.dist-info/METADATA,sha256=aae9Bt0SpsDgugeHY-7Bi6SN3wWhXneD3Kbz1NMtxJo,16627
-rxnn-0.1.52.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
-rxnn-0.1.52.dist-info/RECORD,,
+rxnn-0.1.54.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.1.54.dist-info/METADATA,sha256=FF9XlvOeROGLpVR5pHuuceoeXTzbMNJhEusmQdfPTD0,16627
+rxnn-0.1.54.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
+rxnn-0.1.54.dist-info/RECORD,,