rxnn 0.1.80__tar.gz → 0.1.82__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32)
  1. {rxnn-0.1.80 → rxnn-0.1.82}/PKG-INFO +1 -1
  2. {rxnn-0.1.80 → rxnn-0.1.82}/pyproject.toml +1 -1
  3. rxnn-0.1.82/src/rxnn/training/dataset.py +800 -0
  4. rxnn-0.1.80/src/rxnn/training/dataset.py +0 -354
  5. {rxnn-0.1.80 → rxnn-0.1.82}/LICENSE +0 -0
  6. {rxnn-0.1.80 → rxnn-0.1.82}/README.md +0 -0
  7. {rxnn-0.1.80 → rxnn-0.1.82}/src/rxnn/__init__.py +0 -0
  8. {rxnn-0.1.80 → rxnn-0.1.82}/src/rxnn/experimental/__init__.py +0 -0
  9. {rxnn-0.1.80 → rxnn-0.1.82}/src/rxnn/experimental/attention.py +0 -0
  10. {rxnn-0.1.80 → rxnn-0.1.82}/src/rxnn/experimental/models.py +0 -0
  11. {rxnn-0.1.80 → rxnn-0.1.82}/src/rxnn/experimental/moe.py +0 -0
  12. {rxnn-0.1.80 → rxnn-0.1.82}/src/rxnn/memory/__init__.py +0 -0
  13. {rxnn-0.1.80 → rxnn-0.1.82}/src/rxnn/memory/norm.py +0 -0
  14. {rxnn-0.1.80 → rxnn-0.1.82}/src/rxnn/memory/stm.py +0 -0
  15. {rxnn-0.1.80 → rxnn-0.1.82}/src/rxnn/rxt/__init__.py +0 -0
  16. {rxnn-0.1.80 → rxnn-0.1.82}/src/rxnn/rxt/models.py +0 -0
  17. {rxnn-0.1.80 → rxnn-0.1.82}/src/rxnn/training/__init__.py +0 -0
  18. {rxnn-0.1.80 → rxnn-0.1.82}/src/rxnn/training/base.py +0 -0
  19. {rxnn-0.1.80 → rxnn-0.1.82}/src/rxnn/training/bml.py +0 -0
  20. {rxnn-0.1.80 → rxnn-0.1.82}/src/rxnn/training/callbacks.py +0 -0
  21. {rxnn-0.1.80 → rxnn-0.1.82}/src/rxnn/training/scheduler.py +0 -0
  22. {rxnn-0.1.80 → rxnn-0.1.82}/src/rxnn/training/tokenizer.py +0 -0
  23. {rxnn-0.1.80 → rxnn-0.1.82}/src/rxnn/transformers/__init__.py +0 -0
  24. {rxnn-0.1.80 → rxnn-0.1.82}/src/rxnn/transformers/attention.py +0 -0
  25. {rxnn-0.1.80 → rxnn-0.1.82}/src/rxnn/transformers/ff.py +0 -0
  26. {rxnn-0.1.80 → rxnn-0.1.82}/src/rxnn/transformers/layers.py +0 -0
  27. {rxnn-0.1.80 → rxnn-0.1.82}/src/rxnn/transformers/mask.py +0 -0
  28. {rxnn-0.1.80 → rxnn-0.1.82}/src/rxnn/transformers/models.py +0 -0
  29. {rxnn-0.1.80 → rxnn-0.1.82}/src/rxnn/transformers/moe.py +0 -0
  30. {rxnn-0.1.80 → rxnn-0.1.82}/src/rxnn/transformers/positional.py +0 -0
  31. {rxnn-0.1.80 → rxnn-0.1.82}/src/rxnn/transformers/sampler.py +0 -0
  32. {rxnn-0.1.80 → rxnn-0.1.82}/src/rxnn/utils.py +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.3
  Name: rxnn
- Version: 0.1.80
+ Version: 0.1.82
  Summary: RxNN: Reactive Neural Networks Platform
  License: Apache-2.0
  Keywords: deep-learning,ai,machine-learning
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
  
  [tool.poetry]
  name = "rxnn"
- version = "0.1.80"
+ version = "0.1.82"
  description = "RxNN: Reactive Neural Networks Platform"
  
  license = "Apache-2.0"
@@ -0,0 +1,800 @@
+ import torch
+ from torch.utils.data import Dataset
+ from datasets import Dataset as HfDataset, load_dataset, concatenate_datasets
+ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
+ from .tokenizer import load_tokenizer_from_hf_hub
+
+ from typing import Union
+
+
+ class BaseDataset(Dataset):
+     def __init__(
+             self,
+             texts: Union[list[str], HfDataset],
+             tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+             max_seq_len: int = 1024,
+             hf_field: str = 'text',
+             cache_tokenized: bool = False,
+             cache_remove_text: bool = True,
+             tokenize_in_background: bool = False,
+             batch_size: int = 1,
+             *args,
+             **kwargs
+     ):
+         super(BaseDataset, self).__init__(*args, **kwargs)
+         self.tokenizer = tokenizer
+         self.max_seq_len = max_seq_len
+         self.texts = texts
+         self.hf_field = hf_field
+         self.is_pre_tokenized = False
+         self.cache_tokenized = cache_tokenized
+         self.cache_remove_text = cache_remove_text
+         self.inputs = []
+         self.is_txt_list = isinstance(self.texts, list)
+         self.tokenize_in_background = tokenize_in_background
+         self.bg_next = []
+         self.bg_queue = None
+         self.batch_size = batch_size
+         self.last_idx = 0
+         if tokenize_in_background:
+             for i in range(self.batch_size):
+                 self.bg_next.append(self.get_tokenized_text(i))
+             self.last_idx = self.batch_size - 1
+
+
+     def __len__(self):
+         return len(self.texts if not self.is_pre_tokenized else self.inputs)
+
+     def get_tokenized_text(self, idx: int, txt: str = None):
+         if self.is_pre_tokenized:
+             return self.inputs[idx]
+         elif self.tokenize_in_background:
+             if idx == self.last_idx - self.batch_size:
+                 if self.bg_queue is not None:
+                     self.bg_next = self.bg_queue
+                     self.bg_queue = None
+                 # TODO: schedule tokenizing a batch in background
+             elif idx == self.last_idx:
+                 item = self.bg_next[idx]
+                 self.bg_next = []
+                 return item
+
+             if idx <= self.last_idx:
+                 if self.bg_queue is not None:
+                     self.bg_next = self.bg_queue
+                     self.bg_queue = None
+
+                 new_idx = idx - (self.last_idx - self.batch_size)
+                 if new_idx in self.bg_next:
+                     return self.bg_next[new_idx]
+                 else:
+                     if self.is_txt_list:
+                         text = self.texts[idx]
+                     else:
+                         text = self.texts[idx][self.hf_field]
+
+                     inputs = self.tokenizer(
+                         text,
+                         max_length=self.max_seq_len,
+                         truncation=True,
+                         padding='max_length',
+                         return_tensors='pt',
+                         return_attention_mask=True
+                     )
+                     if not (inputs['input_ids'][0] < self.tokenizer.vocab_size).all():
+                         inputs['input_ids'][0][
+                             (inputs['input_ids'][0] >= self.tokenizer.vocab_size)] = self.tokenizer.unk_token_id
+                     if not (inputs['input_ids'][0] >= 0).all():
+                         inputs['input_ids'][0][inputs['input_ids'][0] < 0] = self.tokenizer.unk_token_id
+
+                     return inputs
+         else:
+             if txt is not None:
+                 text = txt
+             elif self.is_txt_list:
+                 text = self.texts[idx]
+             else:
+                 text = self.texts[idx][self.hf_field]
+
+             inputs = self.tokenizer(
+                 text,
+                 max_length=self.max_seq_len,
+                 truncation=True,
+                 padding='max_length',
+                 return_tensors='pt',
+                 return_attention_mask=True
+             )
+             if not (inputs['input_ids'][0] < self.tokenizer.vocab_size).all():
+                 inputs['input_ids'][0][(inputs['input_ids'][0] >= self.tokenizer.vocab_size)] = self.tokenizer.unk_token_id
+             if not (inputs['input_ids'][0] >= 0).all():
+                 inputs['input_ids'][0][inputs['input_ids'][0] < 0] = self.tokenizer.unk_token_id
+
+             if self.cache_tokenized:
+                 self.inputs.append(inputs)
+                 if len(self.inputs) == len(self.texts):
+                     self.is_pre_tokenized = True
+                     if self.cache_remove_text:
+                         del self.texts
+                         self.texts = None
+
+             return inputs
+
+     def get_subset(self, size: float, from_start: bool = False, **kwargs) -> "BaseDataset":
+         split_point = int(len(self.texts) * ((1 - size) if not from_start else size))
+         if not isinstance(self.texts, list):
+             subset = self.texts.select(range(split_point, len(self.texts)) if not from_start else range(split_point))
+             self.texts = self.texts.select(range(split_point) if not from_start else range(split_point, len(self.texts)))
+         else:
+             subset = self.texts[split_point:-1] if not from_start else self.texts[0:split_point]
+             self.texts = self.texts[0:split_point] if not from_start else self.texts[split_point:-1]
+         return self.__class__(subset, self.tokenizer, max_seq_len=self.max_seq_len, hf_field=self.hf_field, **kwargs)
+
+     def pre_tokenize(self, verbose: bool = False, log_interval: int = 10_000, map_hf_ds_to_list: bool = True):
+         """
+         Pre-tokenizes all the items in the dataset, for faster training. Training with pre-tokenized
+         dataset could be even 2x faster.
+
+         !! This method has extremely high memory usage, when used with HuggingFace datasets,
+         because of converting it to list. Additionally, for the most optimal performance,
+         pre-tokenized items are in reversed order - it shouldn't matter for training, as
+         items are shuffled then by DataLoader, but you should keep that in mind in case
+         of reproducibility.
+
+         :param(bool) verbose:
+         :param(int) log_interval: Interval of verbose logs
+         """
+         if not self.is_pre_tokenized:
+             num_texts = len(self.texts)
+             txts = self.texts if self.is_txt_list else self.texts.to_list()
+             del self.texts
+             self.texts = None
+             for index in range(num_texts):
+                 item = txts.pop() if self.is_txt_list else txts.pop()[self.hf_field]
+                 self.inputs.append(self.get_tokenized_text(index, txt=item))
+                 if verbose and index % log_interval == 0:
+                     print(f'Processed {index + 1}/{num_texts}')
+             self.is_pre_tokenized = True
+
+
+     @classmethod
+     def from_hf_hub(
+             cls,
+             dataset_id: str,
+             subset: str = None,
+             split: str = 'train',
+             target_field: str = 'text',
+             tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] = None,
+             tokenizer_hub_id: str = None,
+             max_seq_len: int = 1024,
+             load_kwargs: dict = None,
+             load_tokenizer_kwargs: dict = None,
+             **kwargs
+     ):
+         """
+         Load dataset from HuggingFace Hub and convert it to RxNN training dataset.
+
+         One of the `tokenizer` or `tokenizer_hub_id` args must be provided. If both are provided, `tokenizer` will be used.
+
+         Args:
+             dataset_id (str): Hub dataset repository name
+             subset (str): Dataset subset
+             split (str): Dataset split (default: "train")
+             target_field (str): Name of dataset field used for training (default: "text")
+             tokenizer (PreTrainedTokenizer): HuggingFace Tokenizer used for training (default: None)
+             tokenizer_hub_id (str): HuggingFace Hub ID of tokenizer to load (default: None)
+             max_seq_len (int): Maximum sequence length for training (default: 1024)
+             load_kwargs (dict): Additional args for HuggingFace API load_dataset function
+             load_tokenizer_kwargs (dict): Additional args for loading tokenizer from HuggingFace API with `huggingface_hub.hf_hub_download`
+             **kwargs: Additional args for RxNN Dataset class
+         """
+         assert tokenizer is not None or tokenizer_hub_id is not None, "One of the `tokenizer` or `tokenizer_hub_id` args must be provided."
+
+         if tokenizer is None:
+             tokenizer = load_tokenizer_from_hf_hub(tokenizer_hub_id, **load_tokenizer_kwargs)
+
+         hf_dataset = load_dataset(dataset_id, subset, split=split, **load_kwargs)
+
+         return cls(hf_dataset, tokenizer, max_seq_len=max_seq_len, hf_field=target_field, **kwargs)
+
+     @classmethod
+     def concat_from_hf_hub(
+             cls,
+             dataset_ids: tuple[str],
+             subsets: tuple[str] = None,
+             split: str = 'train',
+             target_field: str = 'text',
+             tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] = None,
+             tokenizer_hub_id: str = None,
+             max_seq_len: int = 1024,
+             load_kwargs: dict = None,
+             load_tokenizer_kwargs: dict = None,
+             **kwargs
+     ):
+         """
+         Load and concatenate multiple datasets from HuggingFace Hub and convert them to RxNN training dataset.
+         All datasets should use the same split and target field. If it's not the case, just use `load_dataset` and pass the
+         result to RxNN dataset constructor directly.
+
+         One of the `tokenizer` or `tokenizer_hub_id` args must be provided. If both are provided, `tokenizer` will be used.
+
+         Args:
+             dataset_ids (tuple[str]): Hub dataset repository names
+             subsets (tuple[str]): Dataset subsets (default: None)
+             split (str): Dataset split (default: "train")
+             target_field (str): Name of dataset field used for training (default: "text")
+             tokenizer (PreTrainedTokenizer): HuggingFace Tokenizer used for training (default: None)
+             tokenizer_hub_id (str): HuggingFace Hub ID of tokenizer to load (default: None)
+             max_seq_len (int): Maximum sequence length for training (default: 1024)
+             load_kwargs (dict): Additional args for HuggingFace API load_dataset function
+             load_tokenizer_kwargs (dict): Additional args for loading tokenizer from HuggingFace API with `huggingface_hub.hf_hub_download`
+             **kwargs: Additional args for RxNN Dataset class
+         """
+         assert tokenizer is not None or tokenizer_hub_id is not None, "One of the `tokenizer` or `tokenizer_hub_id` args must be provided."
+
+         if tokenizer is None:
+             tokenizer = load_tokenizer_from_hf_hub(tokenizer_hub_id, **load_tokenizer_kwargs)
+
+         hf_datasets = [
+             load_dataset(dataset_id, subset, split=split, **load_kwargs) for dataset_id, subset in zip(dataset_ids, subsets)
+         ] if subsets is not None else [
+             load_dataset(dataset_id, split=split, **load_kwargs) for dataset_id in dataset_ids
+         ]
+         hf_dataset = concatenate_datasets(hf_datasets)
+
+         return cls(hf_dataset, tokenizer, max_seq_len=max_seq_len, hf_field=target_field, **kwargs)
+
+     @classmethod
+     def concat_from_hf_hub_with_subset(
+             cls,
+             dataset_ids: tuple[str],
+             subsets: tuple[str] = None,
+             split: str = 'train',
+             target_field: str = 'text',
+             tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] = None,
+             tokenizer_hub_id: str = None,
+             max_seq_len: int = 1024,
+             load_kwargs: dict = None,
+             load_tokenizer_kwargs: dict = None,
+             valid_size: float = 0.1,
+             **kwargs
+     ):
+         """
+         Load and concatenate multiple datasets from HuggingFace Hub, create validation split and convert them to RxNN training dataset.
+         All datasets should use the same split and target field. If it's not the case, just use `load_dataset` and pass the
+         result to RxNN dataset constructor directly.
+
+         One of the `tokenizer` or `tokenizer_hub_id` args must be provided. If both are provided, `tokenizer` will be used.
+
+         Args:
+             dataset_ids (tuple[str]): Hub dataset repository names
+             subsets (tuple[str]): Dataset subsets (default: None)
+             split (str): Dataset split (default: "train")
+             target_field (str): Name of dataset field used for training (default: "text")
+             tokenizer (PreTrainedTokenizer): HuggingFace Tokenizer used for training (default: None)
+             tokenizer_hub_id (str): HuggingFace Hub ID of tokenizer to load (default: None)
+             max_seq_len (int): Maximum sequence length for training (default: 1024)
+             load_kwargs (dict): Additional args for HuggingFace API load_dataset function
+             load_tokenizer_kwargs (dict): Additional args for loading tokenizer from HuggingFace API with `huggingface_hub.hf_hub_download`
+             valid_size (float): Size of validation dataset (default: 0.1)
+             **kwargs: Additional args for RxNN Dataset class
+         """
+         assert tokenizer is not None or tokenizer_hub_id is not None, "One of the `tokenizer` or `tokenizer_hub_id` args must be provided."
+
+         if tokenizer is None:
+             tokenizer = load_tokenizer_from_hf_hub(tokenizer_hub_id, **load_tokenizer_kwargs)
+
+         hf_datasets = [
+             load_dataset(dataset_id, subset, split=split, **load_kwargs) for dataset_id, subset in zip(dataset_ids, subsets)
+         ] if subsets is not None else [
+             load_dataset(dataset_id, split=split, **load_kwargs) for dataset_id in dataset_ids
+         ]
+
+         hf_ds_dicts = [dataset.train_test_split(test_size=valid_size) for dataset in hf_datasets]
+
+         hf_dataset = concatenate_datasets([ds_dict['train'] for ds_dict in hf_ds_dicts])
+         hf_valid_dataset = concatenate_datasets([ds_dict['test'] for ds_dict in hf_ds_dicts])
+
+         return cls(hf_dataset, tokenizer, max_seq_len=max_seq_len, hf_field=target_field, **kwargs), cls(hf_valid_dataset, tokenizer, max_seq_len=max_seq_len, hf_field=target_field, **kwargs)
+
+
+ class JointLMDataset(BaseDataset):
+     def __init__(
+             self,
+             texts: Union[list[str], HfDataset],
+             tokenizer: PreTrainedTokenizer,
+             max_seq_len: int = 1024,
+             mask_prob: float = 0.15,
+             hf_field: str = 'text',
+             *args,
+             **kwargs
+     ):
+         super(JointLMDataset, self).__init__(texts, tokenizer, max_seq_len, hf_field, *args, **kwargs)
+         self.mask_prob = mask_prob
+
+     def __getitem__(self, idx: int) -> dict[str, dict[str, torch.Tensor]]:
+         inputs = self.get_tokenized_text(idx)
+         encoder_input_ids = inputs['input_ids'][0]
+         attention_mask = inputs['attention_mask'][0]
+
+         decoder_input_ids = encoder_input_ids.clone()
+
+         encoder_labels = encoder_input_ids.clone()
+         decoder_targets = encoder_input_ids.clone()
+
+         # Create masked indices
+         masked_indices = torch.bernoulli(
+             torch.full(encoder_labels.shape, self.mask_prob)
+         ).bool() & attention_mask.bool()
+
+         # Apply mask
+         encoder_labels[~masked_indices] = -100
+         encoder_input_ids[masked_indices] = self.tokenizer.mask_token_id
+
+         return {
+             'decoder': {
+                 'input_ids': decoder_input_ids,
+                 'targets': decoder_targets,
+             },
+             'encoder': {
+                 'input_ids': encoder_input_ids,
+                 'labels': encoder_labels,
+             },
+             'attention_mask': attention_mask,
+         }
+
+
+ class MaskedLMDataset(BaseDataset):
+     def __init__(
+             self,
+             texts: Union[list[str], HfDataset],
+             tokenizer: PreTrainedTokenizer,
+             max_seq_len: int = 1024,
+             mask_prob: float = 0.15,
+             hf_field: str = 'text',
+             *args,
+             **kwargs
+     ):
+         super(MaskedLMDataset, self).__init__(texts, tokenizer, max_seq_len, hf_field, *args, **kwargs)
+         self.mask_prob = mask_prob
+
+     def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
+         inputs = self.get_tokenized_text(idx)
+
+         input_ids = inputs['input_ids'][0]
+         attention_mask = inputs['attention_mask'][0]
+         labels = input_ids.clone()
+
+         # Create masked indices
+         masked_indices = torch.bernoulli(
+             torch.full(labels.shape, self.mask_prob)
+         ).bool() & attention_mask.bool()
+
+         # Apply mask
+         labels[~masked_indices] = -100
+         input_ids[masked_indices] = self.tokenizer.mask_token_id
+
+         return {
+             'input_ids': input_ids,
+             'attention_mask': attention_mask,
+             'labels': labels
+         }
+
+
+ class AutoregressiveLMDataset(BaseDataset):
+     def __init__(
+             self,
+             texts: Union[list[str], HfDataset],
+             tokenizer: PreTrainedTokenizer,
+             max_seq_len: int = 1024,
+             hf_field: str = 'text',
+             *args,
+             **kwargs
+     ):
+         super(AutoregressiveLMDataset, self).__init__(texts, tokenizer, max_seq_len, hf_field, *args, **kwargs)
+
+     def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
+         inputs = self.get_tokenized_text(idx)
+
+         input_ids = inputs['input_ids'][0]
+         attention_mask = inputs['attention_mask'][0]
+         targets = input_ids.clone()
+
+         return {
+             'input_ids': input_ids,
+             'attention_mask': attention_mask,
+             'targets': targets
+         }
+
+ class BaseInteractionDataset(Dataset):
+     def __init__(
+             self,
+             interactions: Union[list[dict], HfDataset],
+             tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+             max_seq_len: int = 1024,
+             query_field: str = 'query',
+             answer_field: str = 'answer',
+             cache_tokenized: bool = False,
+             cache_remove_text: bool = True,
+             tokenize_in_background: bool = False,
+             batch_size: int = 1,
+             *args,
+             **kwargs
+     ):
+         super(BaseInteractionDataset, self).__init__(*args, **kwargs)
+         self.tokenizer = tokenizer
+         self.max_seq_len = max_seq_len
+         self.interactions = interactions
+         self.query_field = query_field
+         self.answer_field = answer_field
+         self.is_pre_tokenized = False
+         self.cache_tokenized = cache_tokenized
+         self.cache_remove_text = cache_remove_text
+         self.inputs = []
+         self.is_list = isinstance(self.interactions, list)
+         self.tokenize_in_background = tokenize_in_background
+         self.bg_next = []
+         self.bg_queue = None
+         self.batch_size = batch_size
+         self.last_idx = 0
+         if tokenize_in_background:
+             for i in range(self.batch_size):
+                 self.bg_next.append(self.get_tokenized_text(i))
+             self.last_idx = self.batch_size - 1
+
+
+     def __len__(self):
+         return len(self.interactions if not self.is_pre_tokenized else self.inputs)
+
+     def get_tokenized_text(self, idx: int, inter: dict = None):
+         if self.is_pre_tokenized:
+             return self.inputs[idx]
+         elif self.tokenize_in_background:
+             if idx == self.last_idx - self.batch_size:
+                 if self.bg_queue is not None:
+                     self.bg_next = self.bg_queue
+                     self.bg_queue = None
+                 # TODO: schedule tokenizing a batch in background
+             elif idx == self.last_idx:
+                 item = self.bg_next[idx]
+                 self.bg_next = []
+                 return item
+
+             if idx <= self.last_idx:
+                 if self.bg_queue is not None:
+                     self.bg_next = self.bg_queue
+                     self.bg_queue = None
+
+                 new_idx = idx - (self.last_idx - self.batch_size)
+                 if new_idx in self.bg_next:
+                     return self.bg_next[new_idx]
+                 else:
+                     interaction = self.interactions[idx]
+                     query = interaction[self.query_field]
+                     answer = interaction[self.answer_field]
+
+                     inputs = self.tokenizer(
+                         query,
+                         answer,
+                         max_length=self.max_seq_len,
+                         truncation=True,
+                         padding='max_length',
+                         return_tensors='pt',
+                         return_attention_mask=True
+                     )
+                     if not (inputs['input_ids'][0] < self.tokenizer.vocab_size).all():
+                         inputs['input_ids'][0][
+                             (inputs['input_ids'][0] >= self.tokenizer.vocab_size)] = self.tokenizer.unk_token_id
+                     if not (inputs['input_ids'][0] >= 0).all():
+                         inputs['input_ids'][0][inputs['input_ids'][0] < 0] = self.tokenizer.unk_token_id
+
+                     return inputs
+         else:
+             if inter is not None:
+                 interaction = inter
+             else:
+                 interaction = self.interactions[idx]
+             query = interaction[self.query_field]
+             answer = interaction[self.answer_field]
+
+             inputs = self.tokenizer(
+                 query,
+                 answer,
+                 max_length=self.max_seq_len,
+                 truncation=True,
+                 padding='max_length',
+                 return_tensors='pt',
+                 return_attention_mask=True
+             )
+             if not (inputs['input_ids'][0] < self.tokenizer.vocab_size).all():
+                 inputs['input_ids'][0][(inputs['input_ids'][0] >= self.tokenizer.vocab_size)] = self.tokenizer.unk_token_id
+             if not (inputs['input_ids'][0] >= 0).all():
+                 inputs['input_ids'][0][inputs['input_ids'][0] < 0] = self.tokenizer.unk_token_id
+
+             if self.cache_tokenized:
+                 self.inputs.append(inputs)
+                 if len(self.inputs) == len(self.interactions):
+                     self.is_pre_tokenized = True
+                     if self.cache_remove_text:
+                         del self.interactions
+                         self.interactions = None
+
+             return inputs
+
+     def get_subset(self, size: float, from_start: bool = False, **kwargs) -> "BaseInteractionDataset":
+         split_point = int(len(self.interactions) * ((1 - size) if not from_start else size))
+         if not isinstance(self.interactions, list):
+             subset = self.interactions.select(range(split_point, len(self.interactions)) if not from_start else range(split_point))
+             self.interactions = self.interactions.select(range(split_point) if not from_start else range(split_point, len(self.interactions)))
+         else:
+             subset = self.interactions[split_point:-1] if not from_start else self.interactions[0:split_point]
+             self.interactions = self.interactions[0:split_point] if not from_start else self.interactions[split_point:-1]
+         return self.__class__(subset, self.tokenizer, max_seq_len=self.max_seq_len, query_field=self.query_field, answer_field=self.answer_field, **kwargs)
+
+     def pre_tokenize(self, verbose: bool = False, log_interval: int = 10_000):
+         """
+         Pre-tokenizes all the items in the dataset, for faster training. Training with pre-tokenized
+         dataset could be even 2x faster.
+
+         !! This method has extremely high memory usage, when used with HuggingFace datasets,
+         because of converting it to list. Additionally, for the most optimal performance,
+         pre-tokenized items are in reversed order - it shouldn't matter for training, as
+         items are shuffled then by DataLoader, but you should keep that in mind in case
+         of reproducibility.
+
+         :param(bool) verbose:
+         :param(int) log_interval: Interval of verbose logs
+         """
+         if not self.is_pre_tokenized:
+             num_texts = len(self.interactions)
+             inters = self.interactions if self.is_list else self.interactions.to_list()
+             del self.interactions
+             self.interactions = None
+             for index in range(num_texts):
+                 self.inputs.append(self.get_tokenized_text(index, inter=inters.pop()))
+                 if verbose and index % log_interval == 0:
+                     print(f'Processed {index + 1}/{num_texts}')
+             self.is_pre_tokenized = True
+
+
+     @classmethod
+     def from_hf_hub(
+             cls,
+             dataset_id: str,
+             subset: str = None,
+             split: str = 'train',
+             target_fields: tuple[str, str] = ('query', 'answer'),
+             tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] = None,
+             tokenizer_hub_id: str = None,
+             max_seq_len: int = 1024,
+             load_kwargs: dict = None,
+             load_tokenizer_kwargs: dict = None,
+             **kwargs
+     ):
+         """
+         Load dataset from HuggingFace Hub and convert it to RxNN training dataset.
+
+         One of the `tokenizer` or `tokenizer_hub_id` args must be provided. If both are provided, `tokenizer` will be used.
+
+         Args:
+             dataset_id (str): Hub dataset repository name
+             subset (str): Dataset subset
+             split (str): Dataset split (default: "train")
+             target_fields (tuple): Name of dataset fields used for training (default: ("query", "answer"))
+             tokenizer (PreTrainedTokenizer): HuggingFace Tokenizer used for training (default: None)
+             tokenizer_hub_id (str): HuggingFace Hub ID of tokenizer to load (default: None)
+             max_seq_len (int): Maximum sequence length for training (default: 1024)
+             load_kwargs (dict): Additional args for HuggingFace API load_dataset function
+             load_tokenizer_kwargs (dict): Additional args for loading tokenizer from HuggingFace API with `huggingface_hub.hf_hub_download`
+             **kwargs: Additional args for RxNN Dataset class
+         """
+         assert tokenizer is not None or tokenizer_hub_id is not None, "One of the `tokenizer` or `tokenizer_hub_id` args must be provided."
+
+         if tokenizer is None:
+             tokenizer = load_tokenizer_from_hf_hub(tokenizer_hub_id, **load_tokenizer_kwargs)
+
+         hf_dataset = load_dataset(dataset_id, subset, split=split, **load_kwargs)
+
+         query_field, answer_field = target_fields
+
+         return cls(hf_dataset, tokenizer, max_seq_len=max_seq_len, query_field=query_field, answer_field=answer_field, **kwargs)
+
+     @classmethod
+     def concat_from_hf_hub(
+             cls,
+             dataset_ids: tuple[str],
+             subsets: tuple[str] = None,
+             split: str = 'train',
+             target_fields: tuple[str, str] = ('query', 'answer'),
+             tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] = None,
+             tokenizer_hub_id: str = None,
+             max_seq_len: int = 1024,
+             load_kwargs: dict = None,
+             load_tokenizer_kwargs: dict = None,
+             **kwargs
+     ):
+         """
+         Load and concatenate multiple datasets from HuggingFace Hub and convert them to RxNN training dataset.
+         All datasets should use the same split and target field. If it's not the case, just use `load_dataset` and pass the
+         result to RxNN dataset constructor directly.
+
+         One of the `tokenizer` or `tokenizer_hub_id` args must be provided. If both are provided, `tokenizer` will be used.
+
+         Args:
+             dataset_ids (tuple[str]): Hub dataset repository names
+             subsets (tuple[str]): Dataset subsets (default: None)
+             split (str): Dataset split (default: "train")
+             target_fields (tuple): Name of dataset field used for training (default: ("query", "answer"))
+             tokenizer (PreTrainedTokenizer): HuggingFace Tokenizer used for training (default: None)
+             tokenizer_hub_id (str): HuggingFace Hub ID of tokenizer to load (default: None)
+             max_seq_len (int): Maximum sequence length for training (default: 1024)
+             load_kwargs (dict): Additional args for HuggingFace API load_dataset function
+             load_tokenizer_kwargs (dict): Additional args for loading tokenizer from HuggingFace API with `huggingface_hub.hf_hub_download`
+             **kwargs: Additional args for RxNN Dataset class
+         """
+         assert tokenizer is not None or tokenizer_hub_id is not None, "One of the `tokenizer` or `tokenizer_hub_id` args must be provided."
+
+         if tokenizer is None:
+             tokenizer = load_tokenizer_from_hf_hub(tokenizer_hub_id, **load_tokenizer_kwargs)
+
+         hf_datasets = [
+             load_dataset(dataset_id, subset, split=split, **load_kwargs) for dataset_id, subset in zip(dataset_ids, subsets)
+         ] if subsets is not None else [
+             load_dataset(dataset_id, split=split, **load_kwargs) for dataset_id in dataset_ids
+         ]
+         hf_dataset = concatenate_datasets(hf_datasets)
+
+         query_field, answer_field = target_fields
+
+         return cls(hf_dataset, tokenizer, max_seq_len=max_seq_len, query_field=query_field, answer_field=answer_field, **kwargs)
+
+     @classmethod
+     def concat_from_hf_hub_with_subset(
+             cls,
+             dataset_ids: tuple[str],
+             subsets: tuple[str] = None,
+             split: str = 'train',
+             target_fields: tuple[str, str] = ('query', 'answer'),
+             tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] = None,
+             tokenizer_hub_id: str = None,
+             max_seq_len: int = 1024,
+             load_kwargs: dict = None,
+             load_tokenizer_kwargs: dict = None,
+             valid_size: float = 0.1,
+             **kwargs
+     ):
+         """
+         Load and concatenate multiple datasets from HuggingFace Hub, create validation split and convert them to RxNN training dataset.
+         All datasets should use the same split and target field. If it's not the case, just use `load_dataset` and pass the
+         result to RxNN dataset constructor directly.
+
+         One of the `tokenizer` or `tokenizer_hub_id` args must be provided. If both are provided, `tokenizer` will be used.
+
+         Args:
+             dataset_ids (tuple[str]): Hub dataset repository names
+             subsets (tuple[str]): Dataset subsets (default: None)
+             split (str): Dataset split (default: "train")
+             target_fields (tuple[str, str]): Name of dataset field used for training (default: "text")
+             tokenizer (PreTrainedTokenizer): HuggingFace Tokenizer used for training (default: None)
+             tokenizer_hub_id (str): HuggingFace Hub ID of tokenizer to load (default: None)
+             max_seq_len (int): Maximum sequence length for training (default: 1024)
+             load_kwargs (dict): Additional args for HuggingFace API load_dataset function
+             load_tokenizer_kwargs (dict): Additional args for loading tokenizer from HuggingFace API with `huggingface_hub.hf_hub_download`
+             valid_size (float): Size of validation dataset (default: 0.1)
+             **kwargs: Additional args for RxNN Dataset class
+         """
+         assert tokenizer is not None or tokenizer_hub_id is not None, "One of the `tokenizer` or `tokenizer_hub_id` args must be provided."
+
+         if tokenizer is None:
+             tokenizer = load_tokenizer_from_hf_hub(tokenizer_hub_id, **load_tokenizer_kwargs)
+
+         hf_datasets = [
+             load_dataset(dataset_id, subset, split=split, **load_kwargs) for dataset_id, subset in zip(dataset_ids, subsets)
+         ] if subsets is not None else [
+             load_dataset(dataset_id, split=split, **load_kwargs) for dataset_id in dataset_ids
+         ]
+
+         hf_ds_dicts = [dataset.train_test_split(test_size=valid_size) for dataset in hf_datasets]
+
+         hf_dataset = concatenate_datasets([ds_dict['train'] for ds_dict in hf_ds_dicts])
+         hf_valid_dataset = concatenate_datasets([ds_dict['test'] for ds_dict in hf_ds_dicts])
+
+         query_field, answer_field = target_fields
+
+         return cls(hf_dataset, tokenizer, max_seq_len=max_seq_len, query_field=query_field, answer_field=answer_field, **kwargs), cls(hf_valid_dataset, tokenizer, max_seq_len=max_seq_len, query_field=query_field, answer_field=answer_field, **kwargs)
+
+ class DecoderSftDataset(BaseInteractionDataset):
+     def __init__(
+             self,
+             interactions: Union[list[dict], HfDataset],
+             tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+             max_seq_len: int = 1024,
+             query_field: str = 'query',
+             answer_field: str = 'answer',
+             cache_tokenized: bool = False,
+             cache_remove_text: bool = True,
+             tokenize_in_background: bool = False,
+             batch_size: int = 1,
+             *args,
+             **kwargs
+     ):
+         super(DecoderSftDataset, self).__init__(
+             interactions,
+             tokenizer=tokenizer,
+             max_seq_len=max_seq_len,
+             query_field=query_field,
+             answer_field=answer_field,
+             cache_tokenized=cache_tokenized,
+             cache_remove_text=cache_remove_text,
+             tokenize_in_background=tokenize_in_background,
+             batch_size=batch_size,
+             *args,
+             **kwargs
+         )
+
+     def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
+         inputs = self.get_tokenized_text(idx)
+
+         input_ids = inputs['input_ids'][0]
+         attention_mask = inputs['attention_mask'][0]
+         targets = input_ids.clone()
+
+         return {
+             'input_ids': input_ids,
+             'attention_mask': attention_mask,
+             'targets': targets
+         }
+
+ class EncoderSftDataset(BaseInteractionDataset):
+     def __init__(
+             self,
+             interactions: Union[list[dict], HfDataset],
+             tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+             max_seq_len: int = 1024,
+             query_field: str = 'query',
+             answer_field: str = 'answer',
+             cache_tokenized: bool = False,
+             cache_remove_text: bool = True,
+             tokenize_in_background: bool = False,
+             batch_size: int = 1,
+             mask_prob: float = 0.15,
+             *args,
+             **kwargs
+     ):
+         super(EncoderSftDataset, self).__init__(
+             interactions,
+             tokenizer=tokenizer,
+             max_seq_len=max_seq_len,
+             query_field=query_field,
+             answer_field=answer_field,
+             cache_tokenized=cache_tokenized,
+             cache_remove_text=cache_remove_text,
+             tokenize_in_background=tokenize_in_background,
+             batch_size=batch_size,
+             *args,
+             **kwargs
+         )
+         self.mask_prob = mask_prob
+
+     def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
+         inputs = self.get_tokenized_text(idx)
+
+         input_ids = inputs['input_ids'][0]
+         if self.is_pre_tokenized:
+             input_ids = input_ids.clone()
+         attention_mask = inputs['attention_mask'][0]
+         labels = input_ids.clone()
+
+         # Create masked indices
+         masked_indices = torch.bernoulli(
+             torch.full(labels.shape, self.mask_prob)
+         ).bool() & attention_mask.bool()
+
+         # Apply mask
+         labels[~masked_indices] = -100
+         input_ids[masked_indices] = self.tokenizer.mask_token_id
+
+         return {
+             'input_ids': input_ids,
+             'attention_mask': attention_mask,
+             'labels': labels
+         }
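
The substantive change in 0.1.82 is this new training/dataset.py, which keeps the existing LM datasets and adds interaction-style (query/answer) datasets for SFT. Below is a minimal usage sketch based only on the signatures in the hunk above; the tokenizer and dataset hub IDs are placeholders, not real repositories.

from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from rxnn.training.dataset import DecoderSftDataset

# Placeholder hub IDs, for illustration only.
tokenizer = AutoTokenizer.from_pretrained('your-org/your-tokenizer')

train_ds = DecoderSftDataset.from_hf_hub(
    'your-org/your-sft-dataset',
    target_fields=('query', 'answer'),
    tokenizer=tokenizer,
    max_seq_len=1024,
    load_kwargs={},  # from_hf_hub unpacks this dict unconditionally, so pass {} rather than None
)
valid_ds = train_ds.get_subset(0.1)  # moves the last ~10% of interactions into a new dataset

loader = DataLoader(train_ds, batch_size=16, shuffle=True)
batch = next(iter(loader))  # dict with 'input_ids', 'attention_mask', 'targets'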
@@ -1,354 +0,0 @@
- import torch
- from torch.utils.data import Dataset
- from datasets import Dataset as HfDataset, load_dataset, concatenate_datasets
- from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
- from .tokenizer import load_tokenizer_from_hf_hub
-
- from typing import Union
-
-
- class BaseDataset(Dataset):
-     def __init__(
-             self,
-             texts: Union[list[str], HfDataset],
-             tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
-             max_seq_len: int = 1024,
-             hf_field: str = 'text',
-             cache_tokenized: bool = False,
-             cache_remove_text: bool = True,
-             tokenize_in_background: bool = False,
-             batch_size: int = 1,
-             *args,
-             **kwargs
-     ):
-         super(BaseDataset, self).__init__(*args, **kwargs)
-         self.tokenizer = tokenizer
-         self.max_seq_len = max_seq_len
-         self.texts = texts
-         self.hf_field = hf_field
-         self.is_pre_tokenized = False
-         self.cache_tokenized = cache_tokenized
-         self.cache_remove_text = cache_remove_text
-         self.inputs = []
-         self.is_txt_list = isinstance(self.texts, list)
-         self.tokenize_in_background = tokenize_in_background
-         self.bg_next = []
-         self.bg_queue = None
-         self.batch_size = batch_size
-         self.last_idx = 0
-         if tokenize_in_background:
-             for i in range(self.batch_size):
-                 self.bg_next.append(self.get_tokenized_text(i))
-             self.last_idx = self.batch_size - 1
-
-
-     def __len__(self):
-         return len(self.texts if not self.is_pre_tokenized else self.inputs)
-
-     def get_tokenized_text(self, idx: int, txt: str = None):
-         if self.is_pre_tokenized:
-             return self.inputs[idx]
-         elif self.tokenize_in_background:
-             if idx == self.last_idx - self.batch_size:
-                 if self.bg_queue is not None:
-                     self.bg_next = self.bg_queue
-                     self.bg_queue = None
-                 # TODO: schedule tokenizing a batch in background
-             elif idx == self.last_idx:
-                 item = self.bg_next[idx]
-                 self.bg_next = []
-                 return item
-
-             if idx <= self.last_idx:
-                 if self.bg_queue is not None:
-                     self.bg_next = self.bg_queue
-                     self.bg_queue = None
-
-                 new_idx = idx - (self.last_idx - self.batch_size)
-                 if new_idx in self.bg_next:
-                     return self.bg_next[new_idx]
-                 else:
-                     if self.is_txt_list:
-                         text = self.texts[idx]
-                     else:
-                         text = self.texts[idx][self.hf_field]
-
-                     inputs = self.tokenizer(
-                         text,
-                         max_length=self.max_seq_len,
-                         truncation=True,
-                         padding='max_length',
-                         return_tensors='pt',
-                         return_attention_mask=True
-                     )
-                     if not (inputs['input_ids'][0] < self.tokenizer.vocab_size).all():
-                         inputs['input_ids'][0][
-                             (inputs['input_ids'][0] >= self.tokenizer.vocab_size)] = self.tokenizer.unk_token_id
-                     if not (inputs['input_ids'][0] >= 0).all():
-                         inputs['input_ids'][0][inputs['input_ids'][0] < 0] = self.tokenizer.unk_token_id
-
-                     return inputs
-         else:
-             if txt is not None:
-                 text = txt
-             elif self.is_txt_list:
-                 text = self.texts[idx]
-             else:
-                 text = self.texts[idx][self.hf_field]
-
-             inputs = self.tokenizer(
-                 text,
-                 max_length=self.max_seq_len,
-                 truncation=True,
-                 padding='max_length',
-                 return_tensors='pt',
-                 return_attention_mask=True
-             )
-             if not (inputs['input_ids'][0] < self.tokenizer.vocab_size).all():
-                 inputs['input_ids'][0][(inputs['input_ids'][0] >= self.tokenizer.vocab_size)] = self.tokenizer.unk_token_id
-             if not (inputs['input_ids'][0] >= 0).all():
-                 inputs['input_ids'][0][inputs['input_ids'][0] < 0] = self.tokenizer.unk_token_id
-
-             if self.cache_tokenized:
-                 self.inputs.append(inputs)
-                 if len(self.inputs) == len(self.texts):
-                     self.is_pre_tokenized = True
-                     if self.cache_remove_text:
-                         del self.texts
-                         self.texts = None
-
-             return inputs
-
-     def get_subset(self, size: float, from_start: bool = False, **kwargs) -> "BaseDataset":
-         split_point = int(len(self.texts) * ((1 - size) if not from_start else size))
-         if not isinstance(self.texts, list):
-             subset = self.texts.select(range(split_point, len(self.texts)) if not from_start else range(split_point))
-             self.texts = self.texts.select(range(split_point) if not from_start else range(split_point, len(self.texts)))
-         else:
-             subset = self.texts[split_point:-1] if not from_start else self.texts[0:split_point]
-             self.texts = self.texts[0:split_point] if not from_start else self.texts[split_point:-1]
-         return self.__class__(subset, self.tokenizer, max_seq_len=self.max_seq_len, hf_field=self.hf_field, **kwargs)
-
-     def pre_tokenize(self, verbose: bool = False, log_interval: int = 10_000, map_hf_ds_to_list: bool = True):
-         """
-         Pre-tokenizes all the items in the dataset, for faster training. Training with pre-tokenized
-         dataset could be even 2x faster.
-
-         !! This method has extremely high memory usage, when used with HuggingFace datasets,
-         because of converting it to list. Additionally, for the most optimal performance,
-         pre-tokenized items are in reversed order - it shouldn't matter for training, as
-         items are shuffled then by DataLoader, but you should keep that in mind in case
-         of reproducibility.
-
-         :param(bool) verbose:
-         :param(int) log_interval: Interval of verbose logs
-         """
-         if not self.is_pre_tokenized:
-             num_texts = len(self.texts)
-             txts = self.texts if self.is_txt_list else self.texts.to_list()
-             del self.texts
-             self.texts = None
-             for index in range(num_texts):
-                 item = txts.pop() if self.is_txt_list else txts.pop()[self.hf_field]
-                 self.inputs.append(self.get_tokenized_text(index, txt=item))
-                 if verbose and index % log_interval == 0:
-                     print(f'Processed {index + 1}/{num_texts}')
-             self.is_pre_tokenized = True
-
-
-     @classmethod
-     def from_hf_hub(
-             cls,
-             dataset_id: str,
-             subset: str = None,
-             split: str = 'train',
-             target_field: str = 'text',
-             tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] = None,
-             tokenizer_hub_id: str = None,
-             max_seq_len: int = 1024,
-             load_kwargs: dict = None,
-             load_tokenizer_kwargs: dict = None,
-             **kwargs
-     ):
-         """
-         Load dataset from HuggingFace Hub and convert it to RxNN training dataset.
-
-         One of the `tokenizer` or `tokenizer_hub_id` args must be provided. If both are provided, `tokenizer` will be used.
-
-         Args:
-             dataset_id (str): Hub dataset repository name
-             subset (str): Dataset subset
-             split (str): Dataset split (default: "train")
-             target_field (str): Name of dataset field used for training (default: "text")
-             tokenizer (PreTrainedTokenizer): HuggingFace Tokenizer used for training (default: None)
-             tokenizer_hub_id (str): HuggingFace Hub ID of tokenizer to load (default: None)
-             max_seq_len (int): Maximum sequence length for training (default: 1024)
-             load_kwargs (dict): Additional args for HuggingFace API load_dataset function
-             load_tokenizer_kwargs (dict): Additional args for loading tokenizer from HuggingFace API with `huggingface_hub.hf_hub_download`
-             **kwargs: Additional args for RxNN Dataset class
-         """
-         assert tokenizer is not None or tokenizer_hub_id is not None, "One of the `tokenizer` or `tokenizer_hub_id` args must be provided."
-
-         if tokenizer is None:
-             tokenizer = load_tokenizer_from_hf_hub(tokenizer_hub_id, **load_tokenizer_kwargs)
-
-         hf_dataset = load_dataset(dataset_id, subset, split=split, **load_kwargs)
-
-         return cls(hf_dataset, tokenizer, max_seq_len=max_seq_len, hf_field=target_field, **kwargs)
-
-     @classmethod
-     def concat_from_hf_hub(
-             cls,
-             dataset_ids: tuple[str],
-             subsets: tuple[str] = None,
-             split: str = 'train',
-             target_field: str = 'text',
-             tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] = None,
-             tokenizer_hub_id: str = None,
-             max_seq_len: int = 1024,
-             load_kwargs: dict = None,
-             load_tokenizer_kwargs: dict = None,
-             **kwargs
-     ):
-         """
-         Load and concatenate multiple datasets from HuggingFace Hub and convert them to RxNN training dataset.
-         All datasets should use the same split and target field. If it's not the case, just use `load_dataset` and pass the
-         result to RxNN dataset constructor directly.
-
-         One of the `tokenizer` or `tokenizer_hub_id` args must be provided. If both are provided, `tokenizer` will be used.
-
-         Args:
-             dataset_ids (tuple[str]): Hub dataset repository names
-             subsets (tuple[str]): Dataset subsets (default: None)
-             split (str): Dataset split (default: "train")
-             target_field (str): Name of dataset field used for training (default: "text")
-             tokenizer (PreTrainedTokenizer): HuggingFace Tokenizer used for training (default: None)
-             tokenizer_hub_id (str): HuggingFace Hub ID of tokenizer to load (default: None)
-             max_seq_len (int): Maximum sequence length for training (default: 1024)
-             load_kwargs (dict): Additional args for HuggingFace API load_dataset function
-             load_tokenizer_kwargs (dict): Additional args for loading tokenizer from HuggingFace API with `huggingface_hub.hf_hub_download`
-             **kwargs: Additional args for RxNN Dataset class
-         """
-         assert tokenizer is not None or tokenizer_hub_id is not None, "One of the `tokenizer` or `tokenizer_hub_id` args must be provided."
-
-         if tokenizer is None:
-             tokenizer = load_tokenizer_from_hf_hub(tokenizer_hub_id, **load_tokenizer_kwargs)
-
-         hf_datasets = [
-             load_dataset(dataset_id, subset, split=split, **load_kwargs) for dataset_id, subset in zip(dataset_ids, subsets)
-         ] if subsets is not None else [
-             load_dataset(dataset_id, split=split, **load_kwargs) for dataset_id in dataset_ids
-         ]
-         hf_dataset = concatenate_datasets(hf_datasets)
-
-         return cls(hf_dataset, tokenizer, max_seq_len=max_seq_len, hf_field=target_field, **kwargs)
-
-
-
- class JointLMDataset(BaseDataset):
-     def __init__(
-             self,
-             texts: Union[list[str], HfDataset],
-             tokenizer: PreTrainedTokenizer,
-             max_seq_len: int = 1024,
-             mask_prob: float = 0.15,
-             hf_field: str = 'text',
-             *args,
-             **kwargs
-     ):
-         super(JointLMDataset, self).__init__(texts, tokenizer, max_seq_len, hf_field, *args, **kwargs)
-         self.mask_prob = mask_prob
-
-     def __getitem__(self, idx: int) -> dict[str, dict[str, torch.Tensor]]:
-         inputs = self.get_tokenized_text(idx)
-         encoder_input_ids = inputs['input_ids'][0]
-         attention_mask = inputs['attention_mask'][0]
-
-         decoder_input_ids = encoder_input_ids.clone()
-
-         encoder_labels = encoder_input_ids.clone()
-         decoder_targets = encoder_input_ids.clone()
-
-         # Create masked indices
-         masked_indices = torch.bernoulli(
-             torch.full(encoder_labels.shape, self.mask_prob)
-         ).bool() & attention_mask.bool()
-
-         # Apply mask
-         encoder_labels[~masked_indices] = -100
-         encoder_input_ids[masked_indices] = self.tokenizer.mask_token_id
-
-         return {
-             'decoder': {
-                 'input_ids': decoder_input_ids,
-                 'targets': decoder_targets,
-             },
-             'encoder': {
-                 'input_ids': encoder_input_ids,
-                 'labels': encoder_labels,
-             },
-             'attention_mask': attention_mask,
-         }
-
-
- class MaskedLMDataset(BaseDataset):
-     def __init__(
-             self,
-             texts: Union[list[str], HfDataset],
-             tokenizer: PreTrainedTokenizer,
-             max_seq_len: int = 1024,
-             mask_prob: float = 0.15,
-             hf_field: str = 'text',
-             *args,
-             **kwargs
-     ):
-         super(MaskedLMDataset, self).__init__(texts, tokenizer, max_seq_len, hf_field, *args, **kwargs)
-         self.mask_prob = mask_prob
-
-     def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
-         inputs = self.get_tokenized_text(idx)
-
-         input_ids = inputs['input_ids'][0]
-         attention_mask = inputs['attention_mask'][0]
-         labels = input_ids.clone()
-
-         # Create masked indices
-         masked_indices = torch.bernoulli(
-             torch.full(labels.shape, self.mask_prob)
-         ).bool() & attention_mask.bool()
-
-         # Apply mask
-         labels[~masked_indices] = -100
-         input_ids[masked_indices] = self.tokenizer.mask_token_id
-
-         return {
-             'input_ids': input_ids,
-             'attention_mask': attention_mask,
-             'labels': labels
-         }
-
-
- class AutoregressiveLMDataset(BaseDataset):
-     def __init__(
-             self,
-             texts: Union[list[str], HfDataset],
-             tokenizer: PreTrainedTokenizer,
-             max_seq_len: int = 1024,
-             hf_field: str = 'text',
-             *args,
-             **kwargs
-     ):
-         super(AutoregressiveLMDataset, self).__init__(texts, tokenizer, max_seq_len, hf_field, *args, **kwargs)
-
-     def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
-         inputs = self.get_tokenized_text(idx)
-
-         input_ids = inputs['input_ids'][0]
-         attention_mask = inputs['attention_mask'][0]
-         targets = input_ids.clone()
-
-         return {
-             'input_ids': input_ids,
-             'attention_mask': attention_mask,
-             'targets': targets
-         }
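
Compared with 0.1.80, the pre-existing text datasets above are carried over unchanged; besides the new interaction classes, BaseDataset only gains the concat_from_hf_hub_with_subset classmethod, which returns a (train, validation) pair. A hedged sketch of that method, inherited here by AutoregressiveLMDataset, with placeholder hub IDs:

from transformers import AutoTokenizer
from rxnn.training.dataset import AutoregressiveLMDataset

tokenizer = AutoTokenizer.from_pretrained('your-org/your-tokenizer')  # placeholder ID

train_ds, valid_ds = AutoregressiveLMDataset.concat_from_hf_hub_with_subset(
    ('your-org/corpus-a', 'your-org/corpus-b'),  # placeholder dataset IDs
    tokenizer=tokenizer,
    max_seq_len=1024,
    valid_size=0.1,     # 10% of each dataset goes into the concatenated validation split
    load_kwargs={},     # unpacked unconditionally, so pass {} rather than leaving the default None
)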