rxnn 0.1.80__py3-none-any.whl → 0.1.82__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/training/dataset.py
CHANGED
@@ -243,6 +243,58 @@ class BaseDataset(Dataset):
|
|
243
243
|
|
244
244
|
return cls(hf_dataset, tokenizer, max_seq_len=max_seq_len, hf_field=target_field, **kwargs)
|
245
245
|
|
246
|
+
@classmethod
def concat_from_hf_hub_with_subset(
        cls,
        dataset_ids: tuple[str],
        subsets: tuple[str] = None,
        split: str = 'train',
        target_field: str = 'text',
        tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] = None,
        tokenizer_hub_id: str = None,
        max_seq_len: int = 1024,
        load_kwargs: dict = None,
        load_tokenizer_kwargs: dict = None,
        valid_size: float = 0.1,
        **kwargs
):
    """
    Load and concatenate multiple datasets from HuggingFace Hub, create validation split and convert them to RxNN training dataset.

    Every source dataset is split separately with `train_test_split(test_size=valid_size)` before concatenation,
    so each source contributes to both the training and the validation dataset.
    All datasets should use the same split and target field. If it's not the case, just use `load_dataset` and pass the
    result to RxNN dataset constructor directly.

    One of the `tokenizer` or `tokenizer_hub_id` args must be provided. If both are provided, `tokenizer` will be used.

    Args:
        dataset_ids (tuple[str]): Hub dataset repository names
        subsets (tuple[str]): Dataset subsets, one per dataset id (default: None)
        split (str): Dataset split (default: "train")
        target_field (str): Name of dataset field used for training (default: "text")
        tokenizer (PreTrainedTokenizer): HuggingFace Tokenizer used for training (default: None)
        tokenizer_hub_id (str): HuggingFace Hub ID of tokenizer to load (default: None)
        max_seq_len (int): Maximum sequence length for training (default: 1024)
        load_kwargs (dict): Additional args for HuggingFace API load_dataset function (default: None - no extra args)
        load_tokenizer_kwargs (dict): Additional args for loading tokenizer from HuggingFace API with `huggingface_hub.hf_hub_download` (default: None - no extra args)
        valid_size (float): Size of validation dataset (default: 0.1)
        **kwargs: Additional args for RxNN Dataset class

    Returns:
        tuple: (training dataset, validation dataset), both instances of `cls`.

    Raises:
        ValueError: if `subsets` is provided and its length differs from `dataset_ids`.
    """
    assert tokenizer is not None or tokenizer_hub_id is not None, "One of the `tokenizer` or `tokenizer_hub_id` args must be provided."

    # Defaults are None to avoid shared mutable dicts; `**None` would raise TypeError, so normalize first.
    load_kwargs = {} if load_kwargs is None else load_kwargs
    load_tokenizer_kwargs = {} if load_tokenizer_kwargs is None else load_tokenizer_kwargs

    if tokenizer is None:
        tokenizer = load_tokenizer_from_hf_hub(tokenizer_hub_id, **load_tokenizer_kwargs)

    if subsets is not None and len(subsets) != len(dataset_ids):
        # zip() would otherwise silently drop the unmatched datasets
        raise ValueError(
            f"`subsets` must have the same length as `dataset_ids` ({len(subsets)} != {len(dataset_ids)})"
        )

    if subsets is not None:
        hf_datasets = [
            load_dataset(dataset_id, subset, split=split, **load_kwargs)
            for dataset_id, subset in zip(dataset_ids, subsets)
        ]
    else:
        hf_datasets = [load_dataset(dataset_id, split=split, **load_kwargs) for dataset_id in dataset_ids]

    # Split each source separately so the validation set mirrors the training mixture
    hf_ds_dicts = [dataset.train_test_split(test_size=valid_size) for dataset in hf_datasets]

    hf_dataset = concatenate_datasets([ds_dict['train'] for ds_dict in hf_ds_dicts])
    hf_valid_dataset = concatenate_datasets([ds_dict['test'] for ds_dict in hf_ds_dicts])

    train_ds = cls(hf_dataset, tokenizer, max_seq_len=max_seq_len, hf_field=target_field, **kwargs)
    valid_ds = cls(hf_valid_dataset, tokenizer, max_seq_len=max_seq_len, hf_field=target_field, **kwargs)
    return train_ds, valid_ds
|
246
298
|
|
247
299
|
|
248
300
|
class JointLMDataset(BaseDataset):
|
@@ -352,3 +404,397 @@ class AutoregressiveLMDataset(BaseDataset):
|
|
352
404
|
'attention_mask': attention_mask,
|
353
405
|
'targets': targets
|
354
406
|
}
|
407
|
+
|
408
|
+
class BaseInteractionDataset(Dataset):
    """
    Base dataset of query/answer interactions, tokenized as text pairs.

    Interactions come from a list of dicts or a HuggingFace dataset. Each one is tokenized as a
    (query, answer) pair, truncated and padded to `max_seq_len`. Token ids outside `[0, tokenizer.vocab_size)`
    are replaced with `tokenizer.unk_token_id`.

    Optional modes:
      - `cache_tokenized`: tokenized items are cached per index on first access. Once every index is cached,
        the dataset switches to pre-tokenized mode (items stored in index order) and drops the raw
        interactions when `cache_remove_text` is set.
      - `tokenize_in_background`: the first `batch_size` items are tokenized up front into a prefetch buffer.
        Scheduling further batches in background is not implemented yet (TODO); other items are tokenized
        synchronously on access.
    """

    def __init__(
            self,
            interactions: Union[list[dict], HfDataset],
            tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
            max_seq_len: int = 1024,
            query_field: str = 'query',
            answer_field: str = 'answer',
            cache_tokenized: bool = False,
            cache_remove_text: bool = True,
            tokenize_in_background: bool = False,
            batch_size: int = 1,
            *args,
            **kwargs
    ):
        super(BaseInteractionDataset, self).__init__(*args, **kwargs)
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.interactions = interactions
        self.query_field = query_field
        self.answer_field = answer_field
        self.is_pre_tokenized = False
        self.cache_tokenized = cache_tokenized
        self.cache_remove_text = cache_remove_text
        # Filled only when all items are tokenized (pre_tokenize or a complete cache), in a stable order
        self.inputs = []
        # Partial cache keyed by dataset index, used while `cache_tokenized` is filling up
        self._cache = {}
        self.is_list = isinstance(self.interactions, list)
        self.tokenize_in_background = tokenize_in_background
        # Prefetch buffer: holds items for indices (last_idx - len(bg_next) + 1) .. last_idx
        self.bg_next = []
        self.bg_queue = None  # reserved for the not-yet-implemented background tokenization
        self.batch_size = batch_size
        self.last_idx = 0
        if tokenize_in_background:
            # Tokenize directly; going through get_tokenized_text here would read the still-empty buffer
            prefetch_count = min(self.batch_size, len(self.interactions))
            self.bg_next = [self._tokenize_interaction(self.interactions[i]) for i in range(prefetch_count)]
            self.last_idx = len(self.bg_next) - 1

    def __len__(self):
        return len(self.interactions if not self.is_pre_tokenized else self.inputs)

    def _tokenize_interaction(self, interaction: dict):
        """Tokenize one interaction dict as a (query, answer) pair; invalid token ids become `unk_token_id`."""
        inputs = self.tokenizer(
            interaction[self.query_field],
            interaction[self.answer_field],
            max_length=self.max_seq_len,
            truncation=True,
            padding='max_length',
            return_tensors='pt',
            return_attention_mask=True
        )
        # View into the returned tensor, so the assignment below updates `inputs` in place
        input_ids = inputs['input_ids'][0]
        invalid = (input_ids >= self.tokenizer.vocab_size) | (input_ids < 0)
        if invalid.any():
            input_ids[invalid] = self.tokenizer.unk_token_id
        return inputs

    def get_tokenized_text(self, idx: int, inter: dict = None):
        """
        Return the tokenized item for index `idx`.

        Args:
            idx (int): dataset index
            inter (dict): optional interaction to tokenize instead of `self.interactions[idx]`.
                Results for an explicit `inter` are never cached.

        Returns:
            The tokenizer output (with `input_ids` and `attention_mask`) for the item.
        """
        if self.is_pre_tokenized:
            return self.inputs[idx]

        if inter is not None:
            return self._tokenize_interaction(inter)

        if self.tokenize_in_background:
            first_idx = self.last_idx - len(self.bg_next) + 1
            offset = idx - first_idx
            if 0 <= offset < len(self.bg_next):
                return self.bg_next[offset]
            # TODO: schedule tokenizing the next batch in background; until then tokenize synchronously
            return self._tokenize_interaction(self.interactions[idx])

        if not self.cache_tokenized:
            return self._tokenize_interaction(self.interactions[idx])

        num_items = len(self.interactions)
        key = idx + num_items if idx < 0 else idx
        cached = self._cache.get(key)
        if cached is not None:
            return cached

        inputs = self._tokenize_interaction(self.interactions[key])
        self._cache[key] = inputs
        if len(self._cache) == num_items:
            # Every index is cached - switch to pre-tokenized mode, keeping index order
            self.inputs = [self._cache[i] for i in range(num_items)]
            self._cache = {}
            self.is_pre_tokenized = True
            if self.cache_remove_text:
                self.interactions = None
        return inputs

    def get_subset(self, size: float, from_start: bool = False, **kwargs) -> "BaseInteractionDataset":
        """
        Move a fraction of the interactions into a new dataset of the same class.

        This dataset keeps the remaining interactions. With `from_start=False` the subset is the last `size`
        fraction; with `from_start=True` it is the first `size` fraction. Requires the raw interactions
        (`self.interactions` must not have been dropped by pre-tokenization).

        Args:
            size (float): fraction of interactions moved into the subset
            from_start (bool): take the subset from the beginning instead of the end (default: False)
            **kwargs: additional args for the subset's constructor

        Returns:
            BaseInteractionDataset: new dataset holding the split-off interactions
        """
        total = len(self.interactions)
        split_point = int(total * ((1 - size) if not from_start else size))
        if not isinstance(self.interactions, list):
            subset = self.interactions.select(range(split_point, total) if not from_start else range(split_point))
            self.interactions = self.interactions.select(range(split_point) if not from_start else range(split_point, total))
        else:
            # Slice to the end - `[split_point:-1]` would silently drop the last interaction
            subset = self.interactions[split_point:] if not from_start else self.interactions[:split_point]
            self.interactions = self.interactions[:split_point] if not from_start else self.interactions[split_point:]

        # Indices changed, so anything cached or prefetched by index is stale
        self._cache = {}
        self.bg_next = []
        self.last_idx = -1

        return self.__class__(subset, self.tokenizer, max_seq_len=self.max_seq_len, query_field=self.query_field, answer_field=self.answer_field, **kwargs)

    def pre_tokenize(self, verbose: bool = False, log_interval: int = 10_000):
        """
        Pre-tokenizes all the items in the dataset, for faster training. Training with pre-tokenized
        dataset could be even 2x faster.

        !! This method has extremely high memory usage, when used with HuggingFace datasets,
        because of converting it to list. Additionally, for the most optimal performance,
        pre-tokenized items are in reversed order - it shouldn't matter for training, as
        items are shuffled then by DataLoader, but you should keep that in mind in case
        of reproducibility. When the dataset holds a plain list, that list is emptied in place.

        :param(bool) verbose: Print progress logs
        :param(int) log_interval: Interval of verbose logs
        """
        if not self.is_pre_tokenized:
            num_texts = len(self.interactions)
            inters = self.interactions if self.is_list else self.interactions.to_list()
            self.interactions = None
            # Rebuild from scratch: drop any partial per-index cache
            self._cache = {}
            self.inputs = []
            for index in range(num_texts):
                # pop() frees each raw item right after tokenizing it (hence the reversed order)
                self.inputs.append(self._tokenize_interaction(inters.pop()))
                if verbose and index % log_interval == 0:
                    print(f'Processed {index + 1}/{num_texts}')
            self.is_pre_tokenized = True

    @classmethod
    def from_hf_hub(
            cls,
            dataset_id: str,
            subset: str = None,
            split: str = 'train',
            target_fields: tuple[str, str] = ('query', 'answer'),
            tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] = None,
            tokenizer_hub_id: str = None,
            max_seq_len: int = 1024,
            load_kwargs: dict = None,
            load_tokenizer_kwargs: dict = None,
            **kwargs
    ):
        """
        Load dataset from HuggingFace Hub and convert it to RxNN training dataset.

        One of the `tokenizer` or `tokenizer_hub_id` args must be provided. If both are provided, `tokenizer` will be used.

        Args:
            dataset_id (str): Hub dataset repository name
            subset (str): Dataset subset
            split (str): Dataset split (default: "train")
            target_fields (tuple): Name of dataset fields used for training (default: ("query", "answer"))
            tokenizer (PreTrainedTokenizer): HuggingFace Tokenizer used for training (default: None)
            tokenizer_hub_id (str): HuggingFace Hub ID of tokenizer to load (default: None)
            max_seq_len (int): Maximum sequence length for training (default: 1024)
            load_kwargs (dict): Additional args for HuggingFace API load_dataset function (default: None - no extra args)
            load_tokenizer_kwargs (dict): Additional args for loading tokenizer from HuggingFace API with `huggingface_hub.hf_hub_download` (default: None - no extra args)
            **kwargs: Additional args for RxNN Dataset class
        """
        assert tokenizer is not None or tokenizer_hub_id is not None, "One of the `tokenizer` or `tokenizer_hub_id` args must be provided."

        # `**None` would raise TypeError
        load_kwargs = {} if load_kwargs is None else load_kwargs
        load_tokenizer_kwargs = {} if load_tokenizer_kwargs is None else load_tokenizer_kwargs

        if tokenizer is None:
            tokenizer = load_tokenizer_from_hf_hub(tokenizer_hub_id, **load_tokenizer_kwargs)

        hf_dataset = load_dataset(dataset_id, subset, split=split, **load_kwargs)

        query_field, answer_field = target_fields

        return cls(hf_dataset, tokenizer, max_seq_len=max_seq_len, query_field=query_field, answer_field=answer_field, **kwargs)

    @classmethod
    def concat_from_hf_hub(
            cls,
            dataset_ids: tuple[str],
            subsets: tuple[str] = None,
            split: str = 'train',
            target_fields: tuple[str, str] = ('query', 'answer'),
            tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] = None,
            tokenizer_hub_id: str = None,
            max_seq_len: int = 1024,
            load_kwargs: dict = None,
            load_tokenizer_kwargs: dict = None,
            **kwargs
    ):
        """
        Load and concatenate multiple datasets from HuggingFace Hub and convert them to RxNN training dataset.
        All datasets should use the same split and target field. If it's not the case, just use `load_dataset` and pass the
        result to RxNN dataset constructor directly.

        One of the `tokenizer` or `tokenizer_hub_id` args must be provided. If both are provided, `tokenizer` will be used.

        Args:
            dataset_ids (tuple[str]): Hub dataset repository names
            subsets (tuple[str]): Dataset subsets, one per dataset id (default: None)
            split (str): Dataset split (default: "train")
            target_fields (tuple): Name of dataset field used for training (default: ("query", "answer"))
            tokenizer (PreTrainedTokenizer): HuggingFace Tokenizer used for training (default: None)
            tokenizer_hub_id (str): HuggingFace Hub ID of tokenizer to load (default: None)
            max_seq_len (int): Maximum sequence length for training (default: 1024)
            load_kwargs (dict): Additional args for HuggingFace API load_dataset function (default: None - no extra args)
            load_tokenizer_kwargs (dict): Additional args for loading tokenizer from HuggingFace API with `huggingface_hub.hf_hub_download` (default: None - no extra args)
            **kwargs: Additional args for RxNN Dataset class

        Raises:
            ValueError: if `subsets` is provided and its length differs from `dataset_ids`.
        """
        assert tokenizer is not None or tokenizer_hub_id is not None, "One of the `tokenizer` or `tokenizer_hub_id` args must be provided."

        load_kwargs = {} if load_kwargs is None else load_kwargs
        load_tokenizer_kwargs = {} if load_tokenizer_kwargs is None else load_tokenizer_kwargs

        if tokenizer is None:
            tokenizer = load_tokenizer_from_hf_hub(tokenizer_hub_id, **load_tokenizer_kwargs)

        if subsets is not None and len(subsets) != len(dataset_ids):
            # zip() would otherwise silently drop the unmatched datasets
            raise ValueError(
                f"`subsets` must have the same length as `dataset_ids` ({len(subsets)} != {len(dataset_ids)})"
            )

        if subsets is not None:
            hf_datasets = [
                load_dataset(dataset_id, subset, split=split, **load_kwargs)
                for dataset_id, subset in zip(dataset_ids, subsets)
            ]
        else:
            hf_datasets = [load_dataset(dataset_id, split=split, **load_kwargs) for dataset_id in dataset_ids]
        hf_dataset = concatenate_datasets(hf_datasets)

        query_field, answer_field = target_fields

        return cls(hf_dataset, tokenizer, max_seq_len=max_seq_len, query_field=query_field, answer_field=answer_field, **kwargs)

    @classmethod
    def concat_from_hf_hub_with_subset(
            cls,
            dataset_ids: tuple[str],
            subsets: tuple[str] = None,
            split: str = 'train',
            target_fields: tuple[str, str] = ('query', 'answer'),
            tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] = None,
            tokenizer_hub_id: str = None,
            max_seq_len: int = 1024,
            load_kwargs: dict = None,
            load_tokenizer_kwargs: dict = None,
            valid_size: float = 0.1,
            **kwargs
    ):
        """
        Load and concatenate multiple datasets from HuggingFace Hub, create validation split and convert them to RxNN training dataset.
        Every source dataset is split separately before concatenation, so each source contributes to both results.
        All datasets should use the same split and target field. If it's not the case, just use `load_dataset` and pass the
        result to RxNN dataset constructor directly.

        One of the `tokenizer` or `tokenizer_hub_id` args must be provided. If both are provided, `tokenizer` will be used.

        Args:
            dataset_ids (tuple[str]): Hub dataset repository names
            subsets (tuple[str]): Dataset subsets, one per dataset id (default: None)
            split (str): Dataset split (default: "train")
            target_fields (tuple[str, str]): Name of dataset fields used for training (default: ("query", "answer"))
            tokenizer (PreTrainedTokenizer): HuggingFace Tokenizer used for training (default: None)
            tokenizer_hub_id (str): HuggingFace Hub ID of tokenizer to load (default: None)
            max_seq_len (int): Maximum sequence length for training (default: 1024)
            load_kwargs (dict): Additional args for HuggingFace API load_dataset function (default: None - no extra args)
            load_tokenizer_kwargs (dict): Additional args for loading tokenizer from HuggingFace API with `huggingface_hub.hf_hub_download` (default: None - no extra args)
            valid_size (float): Size of validation dataset (default: 0.1)
            **kwargs: Additional args for RxNN Dataset class

        Returns:
            tuple: (training dataset, validation dataset), both instances of `cls`.

        Raises:
            ValueError: if `subsets` is provided and its length differs from `dataset_ids`.
        """
        assert tokenizer is not None or tokenizer_hub_id is not None, "One of the `tokenizer` or `tokenizer_hub_id` args must be provided."

        load_kwargs = {} if load_kwargs is None else load_kwargs
        load_tokenizer_kwargs = {} if load_tokenizer_kwargs is None else load_tokenizer_kwargs

        if tokenizer is None:
            tokenizer = load_tokenizer_from_hf_hub(tokenizer_hub_id, **load_tokenizer_kwargs)

        if subsets is not None and len(subsets) != len(dataset_ids):
            raise ValueError(
                f"`subsets` must have the same length as `dataset_ids` ({len(subsets)} != {len(dataset_ids)})"
            )

        if subsets is not None:
            hf_datasets = [
                load_dataset(dataset_id, subset, split=split, **load_kwargs)
                for dataset_id, subset in zip(dataset_ids, subsets)
            ]
        else:
            hf_datasets = [load_dataset(dataset_id, split=split, **load_kwargs) for dataset_id in dataset_ids]

        hf_ds_dicts = [dataset.train_test_split(test_size=valid_size) for dataset in hf_datasets]

        hf_dataset = concatenate_datasets([ds_dict['train'] for ds_dict in hf_ds_dicts])
        hf_valid_dataset = concatenate_datasets([ds_dict['test'] for ds_dict in hf_ds_dicts])

        query_field, answer_field = target_fields

        train_ds = cls(hf_dataset, tokenizer, max_seq_len=max_seq_len, query_field=query_field, answer_field=answer_field, **kwargs)
        valid_ds = cls(hf_valid_dataset, tokenizer, max_seq_len=max_seq_len, query_field=query_field, answer_field=answer_field, **kwargs)
        return train_ds, valid_ds
|
704
|
+
|
705
|
+
class DecoderSftDataset(BaseInteractionDataset):
    """
    Interaction dataset for decoder (autoregressive) supervised fine-tuning.

    Every item is the tokenized (query, answer) pair. `targets` is a separate copy of `input_ids`;
    any shifting for next-token prediction is left to the consumer.
    """

    def __init__(
            self,
            interactions: Union[list[dict], HfDataset],
            tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
            max_seq_len: int = 1024,
            query_field: str = 'query',
            answer_field: str = 'answer',
            cache_tokenized: bool = False,
            cache_remove_text: bool = True,
            tokenize_in_background: bool = False,
            batch_size: int = 1,
            *args,
            **kwargs
    ):
        # All configuration is handled by the base class; this subclass only shapes the items.
        super().__init__(
            interactions,
            tokenizer=tokenizer,
            max_seq_len=max_seq_len,
            query_field=query_field,
            answer_field=answer_field,
            cache_tokenized=cache_tokenized,
            cache_remove_text=cache_remove_text,
            tokenize_in_background=tokenize_in_background,
            batch_size=batch_size,
            *args,
            **kwargs
        )

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        """Return `input_ids`, `attention_mask` and `targets` (a clone of `input_ids`) for item `idx`."""
        encoded = self.get_tokenized_text(idx)
        token_ids = encoded['input_ids'][0]
        mask = encoded['attention_mask'][0]
        return dict(
            input_ids=token_ids,
            attention_mask=mask,
            targets=token_ids.clone(),
        )
|
746
|
+
|
747
|
+
class EncoderSftDataset(BaseInteractionDataset):
    """
    Interaction dataset for encoder (masked language model) supervised fine-tuning.

    Each access returns a freshly, randomly masked copy of the tokenized (query, answer) pair.
    """

    def __init__(
            self,
            interactions: Union[list[dict], HfDataset],
            tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
            max_seq_len: int = 1024,
            query_field: str = 'query',
            answer_field: str = 'answer',
            cache_tokenized: bool = False,
            cache_remove_text: bool = True,
            tokenize_in_background: bool = False,
            batch_size: int = 1,
            mask_prob: float = 0.15,
            *args,
            **kwargs
    ):
        super(EncoderSftDataset, self).__init__(
            interactions,
            tokenizer=tokenizer,
            max_seq_len=max_seq_len,
            query_field=query_field,
            answer_field=answer_field,
            cache_tokenized=cache_tokenized,
            cache_remove_text=cache_remove_text,
            tokenize_in_background=tokenize_in_background,
            batch_size=batch_size,
            *args,
            **kwargs
        )
        # Probability that each non-padding token is replaced by the mask token
        self.mask_prob = mask_prob

    def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
        """
        Return a randomly masked copy of item `idx`.

        Every position with a non-zero `attention_mask` is masked independently with probability `mask_prob`.
        At masked positions `input_ids` is set to `tokenizer.mask_token_id` and `labels` keeps the original id.
        At all other positions `labels` is -100.
        NOTE(review): special tokens added by the tokenizer are not excluded from masking; confirm this is intended.
        """
        inputs = self.get_tokenized_text(idx)

        # Always clone: the tokenized item may be an object reused across calls (pre-tokenized inputs,
        # `cache_tokenized` storage or the prefetch buffer), and the masking below writes in place.
        input_ids = inputs['input_ids'][0].clone()
        attention_mask = inputs['attention_mask'][0]
        labels = input_ids.clone()

        # Create masked indices (never on padding)
        masked_indices = torch.bernoulli(
            torch.full(labels.shape, self.mask_prob)
        ).bool() & attention_mask.bool()

        # Apply mask
        labels[~masked_indices] = -100
        input_ids[masked_indices] = self.tokenizer.mask_token_id

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels
        }
|
@@ -12,7 +12,7 @@ rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
12
12
|
rxnn/training/base.py,sha256=xPMA2Bg9-oUZvSZg67ls2p7Gk9pZ9IHUiIJwUzSe2K8,11766
|
13
13
|
rxnn/training/bml.py,sha256=S1ZaXTybzeJH7uVFamCr4TPl2bLyZ5xmn_lSsjThTiM,19162
|
14
14
|
rxnn/training/callbacks.py,sha256=_YfMKY_eFdc-tubhO9nYH2PXDZDQwlSI74FVOoCXpQg,22108
|
15
|
-
rxnn/training/dataset.py,sha256=
|
15
|
+
rxnn/training/dataset.py,sha256=xI7bbARRWifunVX6HakCroSFqkM401BQmxfsf9pDeY4,35621
|
16
16
|
rxnn/training/scheduler.py,sha256=ow6oALzWjWQmHSpcJEjv6tg4g4CDMvr73TypxfcefMc,712
|
17
17
|
rxnn/training/tokenizer.py,sha256=umaLByMBx_NMrQElA45HLm9gkuzyKWDTFaKVd-CjXl0,8344
|
18
18
|
rxnn/transformers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
@@ -25,7 +25,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
|
|
25
25
|
rxnn/transformers/positional.py,sha256=ge-kaS6WnWnPGnWVp25ZK5bVkmhBUNCaELaN2rN_fSY,4097
|
26
26
|
rxnn/transformers/sampler.py,sha256=poWBpxg1iuK5gEJtxHkk5VVfS9V48hs2Olqdhy_Gw8c,6548
|
27
27
|
rxnn/utils.py,sha256=d5U8i5ukovgDyqiycc2AoxObTz_eF_bgo2MKvdtJ98s,467
|
28
|
-
rxnn-0.1.
|
29
|
-
rxnn-0.1.
|
30
|
-
rxnn-0.1.
|
31
|
-
rxnn-0.1.
|
28
|
+
rxnn-0.1.82.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
|
29
|
+
rxnn-0.1.82.dist-info/METADATA,sha256=xhip3_H9uGKIHKfyTnR0vk_a9zr0TzTIr8buNIiDUQY,16589
|
30
|
+
rxnn-0.1.82.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
|
31
|
+
rxnn-0.1.82.dist-info/RECORD,,
|
File without changes
|