rxnn 0.1.52__py3-none-any.whl → 0.1.54__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rxnn/training/base.py +19 -6
- rxnn/training/bml.py +27 -6
- rxnn/training/dataset.py +91 -5
- rxnn/training/tokenizer.py +6 -0
- rxnn/transformers/moe.py +6 -5
- {rxnn-0.1.52.dist-info → rxnn-0.1.54.dist-info}/METADATA +1 -1
- {rxnn-0.1.52.dist-info → rxnn-0.1.54.dist-info}/RECORD +9 -9
- {rxnn-0.1.52.dist-info → rxnn-0.1.54.dist-info}/LICENSE +0 -0
- {rxnn-0.1.52.dist-info → rxnn-0.1.54.dist-info}/WHEEL +0 -0
rxnn/training/base.py
CHANGED
@@ -50,6 +50,10 @@ class BaseTrainer(ABC):
         self.target_field_name = target_field_name
         self.total_tokens = 0
         self.total_steps = 0
+        self.validation_steps = 0
+        self.total_validation_steps = 0
+        self.epoch_steps = 0
+        self.current_epoch = 0
         self.gradient_accumulation_steps = gradient_accumulation_steps
         self.accumulated_loss = 0.0
         self.optimizer_step_count = 0
@@ -108,8 +112,10 @@ class BaseTrainer(ABC):
         scaler = torch.amp.GradScaler() if self.use_amp else None

         self.model.train()
-        for epoch in range(epochs):
+        for epoch in range(self.current_epoch, self.current_epoch + epochs):
             if self.is_running:
+                self.current_epoch = epoch
+                self.epoch_steps = 0
                 if train_sampler is not None:
                     train_sampler.set_epoch(epoch)
                 self._run_epoch(dataloader, epoch, optimizer, batch_size, scaler=scaler, scheduler=scheduler)
@@ -142,6 +148,7 @@ class BaseTrainer(ABC):
                 callback.on_batch_start(self.model, batch_idx, batch)
                 if self.get_batch_size(batch) == batch_size:
                     self.total_steps += 1
+                    self.epoch_steps = batch_idx
                     loss = self.train_step(batch, batch_idx)
                     orig_loss = loss.item()
                     self.accumulated_loss += orig_loss
@@ -174,25 +181,28 @@ class BaseTrainer(ABC):
                             self.writer.add_scalar(
                                 'Loss/train',
                                 loss_item,
-
+                                self.total_steps,
                             )
                             self.writer.add_scalar(
-                                'Loss
+                                'Loss/train last epoch',
                                 loss_item,
                                 batch_idx
                             )
                             self.writer.add_scalar(
                                 'Perplexity/train',
                                 torch.exp(torch.tensor(loss_item)),
-
+                                self.total_steps,
                             )
                         self.accumulated_loss = 0.0
                         self.optimizer_step_count = 0

                     if self.writer:
                         self.total_tokens += batch['attention_mask'].sum().item()
-                        self.writer.add_scalar(
-
+                        self.writer.add_scalar(
+                            'Processed tokens',
+                            self.total_tokens,
+                            self.total_steps
+                        )

                     for callback in self.callbacks:
                         should_stop = callback.on_batch_end(self.model, batch_idx, orig_loss, batch)
@@ -200,6 +210,7 @@ class BaseTrainer(ABC):
                 self.is_running = False

         if self.validation_dataset:
+            self.validation_steps = 0
             val_loss, val_metrics = self.validate(batch_size)
             val_loss_tensor = torch.tensor(val_loss).to(self.device)
             if self.use_ddp:
@@ -270,6 +281,8 @@ class BaseTrainer(ABC):
         with torch.no_grad():
             for batch in val_dataloader:
                 if self.get_batch_size(batch) == batch_size:
+                    self.validation_steps += 1
+                    self.total_validation_steps += 1
                     loss, outputs = self.valid_step(batch)
                     val_loss += loss.item()
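The key change in base.py is that the epoch loop now starts from `self.current_epoch` instead of 0, so repeated training calls continue the epoch numbering (and per-step TensorBoard logging) where the previous run left off, while the new counters (`epoch_steps`, `validation_steps`, `total_validation_steps`) track progress within an epoch and across validation passes. Below is a minimal sketch of that resumable-counter pattern; `SketchTrainer` is illustrative only and not part of the rxnn API, and the final increment is a simplification so the example resumes cleanly.

# Minimal sketch of the resumable epoch-counter pattern used above.
class SketchTrainer:
    def __init__(self):
        self.current_epoch = 0   # persists between training calls
        self.epoch_steps = 0     # batches seen in the current epoch

    def train(self, epochs: int):
        # Start counting from the stored epoch instead of 0, so a second
        # call continues the numbering instead of overwriting earlier epochs.
        start = self.current_epoch
        for epoch in range(start, start + epochs):
            self.current_epoch = epoch
            self.epoch_steps = 0
            for batch_idx in range(10):   # stand-in for the dataloader loop
                self.epoch_steps = batch_idx
        self.current_epoch += 1           # simplification: next call resumes after the last epoch

trainer = SketchTrainer()
trainer.train(2)                 # epochs 0 and 1
trainer.train(2)                 # epochs 2 and 3
print(trainer.current_epoch)     # 4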
rxnn/training/bml.py
CHANGED
@@ -91,8 +91,8 @@ class MLMTrainer(BaseTrainer):
             self.writer.add_scalar('Router aux loss/Train', router_loss.item(), self.total_steps)
             self.writer.add_scalar('Model loss/Train', main_loss.item(), self.total_steps)
         else:
-            self.writer.add_scalar('Router aux loss/Valid', router_loss.item(), self.
-            self.writer.add_scalar('Model loss/Valid', main_loss.item(), self.
+            self.writer.add_scalar('Router aux loss/Valid', router_loss.item(), self.total_validation_steps)
+            self.writer.add_scalar('Model loss/Valid', main_loss.item(), self.total_validation_steps)

         return loss

@@ -106,14 +106,25 @@ class MLMTrainer(BaseTrainer):
         with torch.no_grad():
             for batch in val_dataloader:
                 if self.get_batch_size(batch) == batch_size:
+                    self.total_validation_steps += 1
+                    self.validation_steps += 1
                     loss, logits = self.valid_step(batch)
                     val_loss += loss
+                    if self.writer is not None:
+                        self.writer.add_scalar('Loss/Valid total', loss.item(), self.total_validation_steps)
+                        self.writer.add_scalar('Perplexity/Valid', torch.exp(loss).item(), self.total_validation_steps)
+
                     labels = batch[self.target_field_name].to(self.device)
                     valid_indices = labels != -100
                     if valid_indices.any():
                         preds = logits.argmax(-1)
-
-
+                        batch_correct = (preds[valid_indices] == labels[valid_indices]).sum()
+                        batch_total = valid_indices.sum()
+                        batch_acc = (batch_correct / batch_total * 100) if total > 0 else torch.tensor(0.0).to(self.device)
+                        if self.writer is not None:
+                            self.writer.add_scalar('Accuracy/Valid total', batch_acc.item(), self.total_validation_steps)
+                        correct += batch_correct
+                        total += batch_total

         avg_loss = (val_loss / len(val_dataloader)).item()
         acc = (correct / total * 100) if total > 0 else torch.tensor(0.0).to(self.device)
@@ -197,15 +208,25 @@ class AutoregressiveTrainer(BaseTrainer):
         with torch.no_grad():
             for batch in val_dataloader:
                 if self.get_batch_size(batch) == batch_size:
+                    self.total_validation_steps += 1
+                    self.validation_steps += 1
                     loss, logits = self.valid_step(batch)
                     val_loss += loss
+                    if self.writer is not None:
+                        self.writer.add_scalar('Loss/Valid total', loss.item(), self.total_validation_steps)
+                        self.writer.add_scalar('Perplexity/Valid', torch.exp(loss).item(), self.total_validation_steps)
                     shifted_logits = logits[:, :-1].contiguous()
                     shifted_targets = batch[self.target_field_name][:, 1:].to(self.device).contiguous()
                     valid_indices = shifted_targets != -100
                     if valid_indices.any():
                         preds = shifted_logits.argmax(-1)
-
-
+                        batch_correct = (preds[valid_indices] == shifted_targets[valid_indices]).sum()
+                        batch_total = valid_indices.sum()
+                        batch_acc = (batch_correct / batch_total * 100) if total > 0 else torch.tensor(0.0).to(self.device)
+                        if self.writer is not None:
+                            self.writer.add_scalar('Accuracy/Valid total', batch_acc.item(), self.total_validation_steps)
+                        correct += batch_correct
+                        total += batch_total

         avg_loss = (val_loss / len(val_dataloader)).item()
         acc = (correct / total * 100) if total > 0 else torch.tensor(0.0).to(self.device)
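Both validation loops now log per-batch loss, perplexity, and accuracy, where accuracy is computed only over positions whose label is not the ignore index -100. Below is a self-contained sketch of that masked-accuracy computation, using small made-up tensors rather than the trainers' actual batches.

# Sketch of masked token accuracy with -100 as the ignore index,
# assuming logits of shape (batch, seq, vocab) and integer labels.
import torch

def masked_accuracy(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    valid = labels != -100                      # positions that count
    if not valid.any():
        return torch.tensor(0.0)
    preds = logits.argmax(-1)                   # most likely token per position
    correct = (preds[valid] == labels[valid]).sum()
    return correct / valid.sum() * 100          # percentage, like batch_acc above

logits = torch.randn(2, 4, 10)
labels = torch.tensor([[1, 2, -100, 3], [-100, 4, 5, -100]])
print(masked_accuracy(logits, labels))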
rxnn/training/dataset.py
CHANGED
@@ -1,7 +1,8 @@
 import torch
 from torch.utils.data import Dataset
-from datasets import Dataset as HfDataset
-from transformers import PreTrainedTokenizer
+from datasets import Dataset as HfDataset, load_dataset, concatenate_datasets
+from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
+from .tokenizer import load_tokenizer_from_hf_hub

 from typing import Union

@@ -10,10 +11,9 @@ class BaseDataset(Dataset):
     def __init__(
             self,
             texts: Union[list[str], HfDataset],
-            tokenizer: PreTrainedTokenizer,
+            tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
             max_seq_len: int = 1024,
             hf_field: str = 'text',
-            merge_short_from: int = None,
             *args,
             **kwargs
     ):
@@ -22,7 +22,6 @@ class BaseDataset(Dataset):
         self.max_seq_len = max_seq_len
         self.texts = texts
         self.hf_field = hf_field
-        self.merge_short_from = merge_short_from

     def get_tokenized_text(self, idx: int):
         if isinstance(self.texts, list):
@@ -45,6 +44,93 @@ class BaseDataset(Dataset):

         return inputs

+    @classmethod
+    def from_hf_hub(
+            cls,
+            dataset_id: str,
+            subset: str = None,
+            split: str = 'train',
+            target_field: str = 'text',
+            tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] = None,
+            tokenizer_hub_id: str = None,
+            max_seq_len: int = 1024,
+            load_kwargs: dict = None,
+            load_tokenizer_kwargs: dict = None,
+            **kwargs
+    ):
+        """
+        Load dataset from HuggingFace Hub and convert it to RxNN training dataset.
+
+        One of the `tokenizer` or `tokenizer_hub_id` args must be provided. If both are provided, `tokenizer` will be used.
+
+        Args:
+            dataset_id (str): Hub dataset repository name
+            subset (str): Dataset subset
+            split (str): Dataset split (default: "train")
+            target_field (str): Name of dataset field used for training (default: "text")
+            tokenizer (PreTrainedTokenizer): HuggingFace Tokenizer used for training (default: None)
+            tokenizer_hub_id (str): HuggingFace Hub ID of tokenizer to load (default: None)
+            max_seq_len (int): Maximum sequence length for training (default: 1024)
+            load_kwargs (dict): Additional args for HuggingFace API load_dataset function
+            load_tokenizer_kwargs (dict): Additional args for loading tokenizer from HuggingFace API with `huggingface_hub.hf_hub_download`
+            **kwargs: Additional args for RxNN Dataset class
+        """
+        assert tokenizer is not None or tokenizer_hub_id is not None, "One of the `tokenizer` or `tokenizer_hub_id` args must be provided."
+
+        if tokenizer is None:
+            tokenizer = load_tokenizer_from_hf_hub(tokenizer_hub_id, **load_tokenizer_kwargs)
+
+        hf_dataset = load_dataset(dataset_id, subset, split=split, **load_kwargs)
+
+        return cls(hf_dataset, tokenizer, max_seq_len=max_seq_len, hf_field=target_field, **kwargs)
+
+    @classmethod
+    def concat_from_hf_hub(
+            cls,
+            dataset_ids: tuple[str],
+            subsets: tuple[str] = None,
+            split: str = 'train',
+            target_field: str = 'text',
+            tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] = None,
+            tokenizer_hub_id: str = None,
+            max_seq_len: int = 1024,
+            load_kwargs: dict = None,
+            load_tokenizer_kwargs: dict = None,
+            **kwargs
+    ):
+        """
+        Load and concatenate multiple datasets from HuggingFace Hub and convert them to RxNN training dataset.
+        All datasets should use the same split and target field. If it's not the case, just use `load_dataset` and pass the
+        result to RxNN dataset constructor directly.
+
+        One of the `tokenizer` or `tokenizer_hub_id` args must be provided. If both are provided, `tokenizer` will be used.
+
+        Args:
+            dataset_ids (tuple[str]): Hub dataset repository names
+            subsets (tuple[str]): Dataset subsets (default: None)
+            split (str): Dataset split (default: "train")
+            target_field (str): Name of dataset field used for training (default: "text")
+            tokenizer (PreTrainedTokenizer): HuggingFace Tokenizer used for training (default: None)
+            tokenizer_hub_id (str): HuggingFace Hub ID of tokenizer to load (default: None)
+            max_seq_len (int): Maximum sequence length for training (default: 1024)
+            load_kwargs (dict): Additional args for HuggingFace API load_dataset function
+            load_tokenizer_kwargs (dict): Additional args for loading tokenizer from HuggingFace API with `huggingface_hub.hf_hub_download`
+            **kwargs: Additional args for RxNN Dataset class
+        """
+        assert tokenizer is not None or tokenizer_hub_id is not None, "One of the `tokenizer` or `tokenizer_hub_id` args must be provided."
+
+        if tokenizer is None:
+            tokenizer = load_tokenizer_from_hf_hub(tokenizer_hub_id, **load_tokenizer_kwargs)
+
+        hf_datasets = [
+            load_dataset(dataset_id, subset, split=split, **load_kwargs) for dataset_id, subset in zip(dataset_ids, subsets)
+        ] if subsets is not None else [
+            load_dataset(dataset_id, split=split, **load_kwargs) for dataset_id in dataset_ids
+        ]
+        hf_dataset = concatenate_datasets(hf_datasets)
+
+        return cls(hf_dataset, tokenizer, max_seq_len=max_seq_len, hf_field=target_field, **kwargs)
+

 class JointLMDataset(BaseDataset):
     def __init__(
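The two new classmethods wrap `datasets.load_dataset` (and `concatenate_datasets`) plus tokenizer loading, so a training dataset can be built directly from Hub repository IDs. Below is a hypothetical usage sketch, shown on `JointLMDataset` (a `BaseDataset` subclass visible in this diff); the dataset and tokenizer repository IDs are placeholders, and empty dicts are passed for `load_kwargs` / `load_tokenizer_kwargs` because the methods unpack them with `**`.

# Hypothetical usage of the new from_hf_hub / concat_from_hf_hub classmethods.
from rxnn.training.dataset import JointLMDataset

dataset = JointLMDataset.from_hf_hub(
    'some-org/some-text-dataset',                 # placeholder dataset repo
    split='train',
    target_field='text',
    tokenizer_hub_id='some-org/some-tokenizer',   # placeholder tokenizer repo
    max_seq_len=1024,
    load_kwargs={},                               # forwarded to datasets.load_dataset
    load_tokenizer_kwargs={},                     # forwarded to the tokenizer loader
)

combined = JointLMDataset.concat_from_hf_hub(
    ('some-org/dataset-a', 'some-org/dataset-b'), # placeholder dataset repos
    tokenizer_hub_id='some-org/some-tokenizer',
    load_kwargs={},
    load_tokenizer_kwargs={},
)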
rxnn/training/tokenizer.py
CHANGED
@@ -206,3 +206,9 @@ class TokenizerTrainer:
         trainer = cls()
         trainer.load(tokenizer_file)
         return trainer
+
+def load_tokenizer_from_hf_hub(repo_id: str, **kwargs) -> PreTrainedTokenizerFast:
+    return TokenizerTrainer.from_pretrained(repo_id, **kwargs).get_hf_tokenizer()
+
+def load_tokenizer_from_file(path: str) -> PreTrainedTokenizerFast:
+    return TokenizerTrainer.hf_tokenizer_from_file(path)
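The two new module-level helpers return a `PreTrainedTokenizerFast` either from a Hub repository (via `TokenizerTrainer.from_pretrained`) or from a local tokenizer file. A hypothetical usage sketch; the repository ID and file path below are placeholders.

from rxnn.training.tokenizer import load_tokenizer_from_hf_hub, load_tokenizer_from_file

hub_tokenizer = load_tokenizer_from_hf_hub('some-org/some-tokenizer')  # placeholder repo ID
local_tokenizer = load_tokenizer_from_file('./tokenizer.json')          # placeholder local path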
rxnn/transformers/moe.py
CHANGED
@@ -23,20 +23,21 @@ class MoeRouter(nn.Module):

     def calculate_aux_loss(self, top_k_indices: torch.Tensor, probs: torch.Tensor) -> torch.Tensor:
         # Get shapes
-
+        T, K = top_k_indices.shape  # Batch, Sequence length, Top-K

         # 1. Compute expert selection mask (one-hot encoded)
         expert_mask = F.one_hot(top_k_indices, self.num_experts).float()  # (B, S, K, E)

         # 2. Total number of times each expert is selected
-        expert_usage = expert_mask.sum(dim=(0, 1
+        expert_usage = expert_mask.sum(dim=(0, 1))  # (E,)

         # 3. Fraction of tokens assigned to each expert
-
-        fraction_expert = expert_usage /
+        total_selections = T * K
+        fraction_expert = expert_usage / (total_selections + 1e-6)  # (E,)

         # 4. Sum of probabilities for each expert's selected tokens
-
+        probs_expanded = probs.unsqueeze(1).expand(-1, K, -1)  # (B_K, K, E)
+        sum_probs = (probs_expanded * expert_mask).sum(dim=(0, 1))

         # 5. Average probability per expert (avoid division by zero)
         avg_probs = sum_probs / expert_usage.clamp(min=1e-6)  # (E,)
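The rewritten `calculate_aux_loss` flattens the routing decisions to shape `(T, K)` and computes, per expert, the fraction of routing slots it receives and its average router probability when selected. The hunk ends before the final reduction, so the closing line in the sketch below (`num_experts * (fraction * avg_prob).sum()`, the usual Switch-Transformer-style balance term) is an assumption for illustration; `moe_aux_loss` is a standalone sketch, not the `MoeRouter` method itself.

# Standalone sketch of a load-balancing auxiliary loss in the style of the
# hunk above, for flattened routing decisions of shape (T, K) over E experts.
import torch
import torch.nn.functional as F

def moe_aux_loss(top_k_indices: torch.Tensor, probs: torch.Tensor, num_experts: int) -> torch.Tensor:
    T, K = top_k_indices.shape
    expert_mask = F.one_hot(top_k_indices, num_experts).float()   # (T, K, E)
    expert_usage = expert_mask.sum(dim=(0, 1))                     # (E,) selections per expert
    fraction_expert = expert_usage / (T * K + 1e-6)                # share of routing slots
    probs_expanded = probs.unsqueeze(1).expand(-1, K, -1)          # (T, K, E)
    sum_probs = (probs_expanded * expert_mask).sum(dim=(0, 1))     # (E,)
    avg_probs = sum_probs / expert_usage.clamp(min=1e-6)           # mean router prob when selected
    return num_experts * (fraction_expert * avg_probs).sum()       # assumed final reduction

tokens, top_k, experts = 8, 2, 4
probs = torch.softmax(torch.randn(tokens, experts), dim=-1)
top_k_indices = probs.topk(top_k, dim=-1).indices
print(moe_aux_loss(top_k_indices, probs, experts))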
{rxnn-0.1.52.dist-info → rxnn-0.1.54.dist-info}/RECORD
CHANGED
@@ -9,23 +9,23 @@ rxnn/memory/stm.py,sha256=EsD8slSP4_9dLuq6aFPDmuFe8PWilxh90so5Z3nm-ig,2057
 rxnn/rxt/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/rxt/models.py,sha256=INTFeNcqzAsjyWhNtbBHL4Tx7tYDsaQHgm72tf6u20M,6918
 rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-rxnn/training/base.py,sha256=
-rxnn/training/bml.py,sha256=
+rxnn/training/base.py,sha256=gEWASLSuWR8UF8b2e-DYqkBZ1lBx0VsIm4kGf9eWSHM,11678
+rxnn/training/bml.py,sha256=S1ZaXTybzeJH7uVFamCr4TPl2bLyZ5xmn_lSsjThTiM,19162
 rxnn/training/callbacks.py,sha256=_YfMKY_eFdc-tubhO9nYH2PXDZDQwlSI74FVOoCXpQg,22108
-rxnn/training/dataset.py,sha256=
+rxnn/training/dataset.py,sha256=JQuWSUdT5AnsrG6M_EsewoU6uroVHhg4K715nbtDx8A,9643
 rxnn/training/scheduler.py,sha256=ow6oALzWjWQmHSpcJEjv6tg4g4CDMvr73TypxfcefMc,712
-rxnn/training/tokenizer.py,sha256=
+rxnn/training/tokenizer.py,sha256=umaLByMBx_NMrQElA45HLm9gkuzyKWDTFaKVd-CjXl0,8344
 rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/transformers/attention.py,sha256=dC0UmC-_kjX8US6Sf0Fi5zw5kJ-P6orH3JDHeBB5gI8,15695
 rxnn/transformers/ff.py,sha256=jJnuBDsnnX5uYC_WZH8cXAYrMnz0P-iX7MwcPivjRtI,2533
 rxnn/transformers/layers.py,sha256=OX8CsFY9A7uqH1SLwyexR_5BNlwheYrJHCGXjF8Q7HU,7186
 rxnn/transformers/mask.py,sha256=J0cfLVLt3SzS2ra3KcY4khrkhI975Dw4CjpUi3Sn25s,419
 rxnn/transformers/models.py,sha256=QFzBrOR7tDp9d_T0HoIukBMfEbLxsCictV5p3e2ilxg,7552
-rxnn/transformers/moe.py,sha256=
+rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=2l38RS0Dini3f6Z3LUHr3XwWzg1UK7fO2C6wazWDAYU,4292
 rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
 rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
-rxnn-0.1.
-rxnn-0.1.
-rxnn-0.1.
-rxnn-0.1.
+rxnn-0.1.54.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.1.54.dist-info/METADATA,sha256=FF9XlvOeROGLpVR5pHuuceoeXTzbMNJhEusmQdfPTD0,16627
+rxnn-0.1.54.dist-info/WHEEL,sha256=fGIA9gx4Qxk2KDKeNJCbOEwSrmLtjWCwzBz351GyrPQ,88
+rxnn-0.1.54.dist-info/RECORD,,
{rxnn-0.1.52.dist-info → rxnn-0.1.54.dist-info}/LICENSE
File without changes
{rxnn-0.1.52.dist-info → rxnn-0.1.54.dist-info}/WHEEL
File without changes