rxnn 0.2.14__py3-none-any.whl → 0.2.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rxnn/training/dataset.py
CHANGED
@@ -41,7 +41,6 @@ class BaseDataset(Dataset):
             self.bg_next.append(self.get_tokenized_text(i))
             self.last_idx = self.batch_size - 1

-
     def __len__(self):
         return len(self.texts if not self.is_pre_tokenized else self.inputs)

@@ -105,7 +104,8 @@ class BaseDataset(Dataset):
             return_attention_mask=True
         )
         if not (inputs['input_ids'][0] < self.tokenizer.vocab_size).all():
-            inputs['input_ids'][0][(inputs['input_ids'][0] >= self.tokenizer.vocab_size)] = self.tokenizer.unk_token_id
+            inputs['input_ids'][0][
+                (inputs['input_ids'][0] >= self.tokenizer.vocab_size)] = self.tokenizer.unk_token_id
         if not (inputs['input_ids'][0] >= 0).all():
             inputs['input_ids'][0][inputs['input_ids'][0] < 0] = self.tokenizer.unk_token_id

@@ -123,7 +123,8 @@ class BaseDataset(Dataset):
         split_point = int(len(self.texts) * ((1 - size) if not from_start else size))
         if not isinstance(self.texts, list):
             subset = self.texts.select(range(split_point, len(self.texts)) if not from_start else range(split_point))
-            self.texts = self.texts.select(range(split_point) if not from_start else range(split_point, len(self.texts)))
+            self.texts = self.texts.select(
+                range(split_point) if not from_start else range(split_point, len(self.texts)))
         else:
             subset = self.texts[split_point:-1] if not from_start else self.texts[0:split_point]
             self.texts = self.texts[0:split_point] if not from_start else self.texts[split_point:-1]
@@ -155,7 +156,6 @@ class BaseDataset(Dataset):
                 print(f'Processed {index + 1}/{num_texts}')
         self.is_pre_tokenized = True

-
     @classmethod
     def from_hf_hub(
             cls,
@@ -247,7 +247,8 @@ class BaseDataset(Dataset):
         tokenizer = load_tokenizer_from_hf_hub(tokenizer_hub_id, **load_tokenizer_kwargs)

         hf_datasets = [
-            load_dataset(dataset_id, subset, split=split, **load_kwargs) for dataset_id, subset in zip(dataset_ids, subsets)
+            load_dataset(dataset_id, subset, split=split, **load_kwargs) for dataset_id, subset in
+            zip(dataset_ids, subsets)
         ] if subsets is not None else [
             load_dataset(dataset_id, split=split, **load_kwargs) for dataset_id in dataset_ids
         ]
@@ -302,7 +303,8 @@ class BaseDataset(Dataset):
         tokenizer = load_tokenizer_from_hf_hub(tokenizer_hub_id, **load_tokenizer_kwargs)

         hf_datasets = [
-            load_dataset(dataset_id, subset, split=split, **load_kwargs) for dataset_id, subset in zip(dataset_ids, subsets)
+            load_dataset(dataset_id, subset, split=split, **load_kwargs) for dataset_id, subset in
+            zip(dataset_ids, subsets)
         ] if subsets is not None else [
             load_dataset(dataset_id, split=split, **load_kwargs) for dataset_id in dataset_ids
         ]
@@ -312,7 +314,8 @@ class BaseDataset(Dataset):
         hf_dataset = concatenate_datasets([ds_dict['train'] for ds_dict in hf_ds_dicts])
         hf_valid_dataset = concatenate_datasets([ds_dict['test'] for ds_dict in hf_ds_dicts])

-        return cls(hf_dataset, tokenizer, max_seq_len=max_seq_len, hf_field=target_field, **kwargs), cls(hf_valid_dataset, tokenizer, max_seq_len=max_seq_len, hf_field=target_field, **kwargs)
+        return cls(hf_dataset, tokenizer, max_seq_len=max_seq_len, hf_field=target_field, **kwargs), cls(
+            hf_valid_dataset, tokenizer, max_seq_len=max_seq_len, hf_field=target_field, **kwargs)


 class JointLMDataset(BaseDataset):
@@ -427,6 +430,7 @@ class AutoregressiveLMDataset(BaseDataset):
             'targets': targets
         }

+
 class BaseInteractionDataset(Dataset):
     def __init__(
             self,
@@ -463,7 +467,6 @@ class BaseInteractionDataset(Dataset):
             self.bg_next.append(self.get_tokenized_text(i))
             self.last_idx = self.batch_size - 1

-
     def __len__(self):
         return len(self.interactions if not self.is_pre_tokenized else self.inputs)

@@ -528,7 +531,8 @@ class BaseInteractionDataset(Dataset):
             return_attention_mask=True
         )
         if not (inputs['input_ids'][0] < self.tokenizer.vocab_size).all():
-            inputs['input_ids'][0][(inputs['input_ids'][0] >= self.tokenizer.vocab_size)] = self.tokenizer.unk_token_id
+            inputs['input_ids'][0][
+                (inputs['input_ids'][0] >= self.tokenizer.vocab_size)] = self.tokenizer.unk_token_id
         if not (inputs['input_ids'][0] >= 0).all():
             inputs['input_ids'][0][inputs['input_ids'][0] < 0] = self.tokenizer.unk_token_id

@@ -545,12 +549,16 @@ class BaseInteractionDataset(Dataset):
     def get_subset(self, size: float, from_start: bool = False, **kwargs) -> "BaseInteractionDataset":
         split_point = int(len(self.interactions) * ((1 - size) if not from_start else size))
         if not isinstance(self.interactions, list):
-            subset = self.interactions.select(
-                range(split_point, len(self.interactions)) if not from_start else range(split_point))
+            subset = self.interactions.select(
+                range(split_point, len(self.interactions)) if not from_start else range(split_point))
+            self.interactions = self.interactions.select(
+                range(split_point) if not from_start else range(split_point, len(self.interactions)))
         else:
             subset = self.interactions[split_point:-1] if not from_start else self.interactions[0:split_point]
-            self.interactions = self.interactions[0:split_point] if not from_start else self.interactions[
-                split_point:-1]
+            self.interactions = self.interactions[0:split_point] if not from_start else self.interactions[
+                split_point:-1]
+        return self.__class__(subset, self.tokenizer, max_seq_len=self.max_seq_len, query_field=self.query_field,
+                              answer_field=self.answer_field, **kwargs)

     def pre_tokenize(self, verbose: bool = False, log_interval: int = 10_000):
         """
@@ -577,7 +585,6 @@ class BaseInteractionDataset(Dataset):
                 print(f'Processed {index + 1}/{num_texts}')
         self.is_pre_tokenized = True

-
     @classmethod
     def from_hf_hub(
             cls,
@@ -624,7 +631,8 @@ class BaseInteractionDataset(Dataset):

         query_field, answer_field = target_fields

-        return cls(hf_dataset, tokenizer, max_seq_len=max_seq_len, query_field=query_field, answer_field=answer_field, **kwargs)
+        return cls(hf_dataset, tokenizer, max_seq_len=max_seq_len, query_field=query_field, answer_field=answer_field,
+                   **kwargs)

     @classmethod
     def concat_from_hf_hub(
@@ -671,7 +679,8 @@ class BaseInteractionDataset(Dataset):
         tokenizer = load_tokenizer_from_hf_hub(tokenizer_hub_id, **load_tokenizer_kwargs)

         hf_datasets = [
-            load_dataset(dataset_id, subset, split=split, **load_kwargs) for dataset_id, subset in zip(dataset_ids, subsets)
+            load_dataset(dataset_id, subset, split=split, **load_kwargs) for dataset_id, subset in
+            zip(dataset_ids, subsets)
         ] if subsets is not None else [
             load_dataset(dataset_id, split=split, **load_kwargs) for dataset_id in dataset_ids
         ]
@@ -679,7 +688,8 @@ class BaseInteractionDataset(Dataset):

         query_field, answer_field = target_fields

-        return cls(hf_dataset, tokenizer, max_seq_len=max_seq_len, query_field=query_field, answer_field=answer_field, **kwargs)
+        return cls(hf_dataset, tokenizer, max_seq_len=max_seq_len, query_field=query_field, answer_field=answer_field,
+                   **kwargs)

     @classmethod
     def concat_from_hf_hub_with_subset(
@@ -728,7 +738,8 @@ class BaseInteractionDataset(Dataset):
         tokenizer = load_tokenizer_from_hf_hub(tokenizer_hub_id, **load_tokenizer_kwargs)

         hf_datasets = [
-            load_dataset(dataset_id, subset, split=split, **load_kwargs) for dataset_id, subset in zip(dataset_ids, subsets)
+            load_dataset(dataset_id, subset, split=split, **load_kwargs) for dataset_id, subset in
+            zip(dataset_ids, subsets)
         ] if subsets is not None else [
             load_dataset(dataset_id, split=split, **load_kwargs) for dataset_id in dataset_ids
         ]
@@ -740,22 +751,25 @@ class BaseInteractionDataset(Dataset):

         query_field, answer_field = target_fields

-        return cls(hf_dataset, tokenizer, max_seq_len=max_seq_len, query_field=query_field, answer_field=answer_field,
+        return cls(hf_dataset, tokenizer, max_seq_len=max_seq_len, query_field=query_field, answer_field=answer_field,
+                   **kwargs), cls(hf_valid_dataset, tokenizer, max_seq_len=max_seq_len, query_field=query_field,
+                   answer_field=answer_field, **kwargs)
+

 class DecoderSftDataset(BaseInteractionDataset):
     def __init__(
-
-
-
-
-
-
-
-
-
-
-
-
+            self,
+            interactions: Union[list[dict], HfDataset],
+            tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+            max_seq_len: int = 1024,
+            query_field: str = 'query',
+            answer_field: str = 'answer',
+            cache_tokenized: bool = False,
+            cache_remove_text: bool = True,
+            tokenize_in_background: bool = False,
+            batch_size: int = 1,
+            *args,
+            **kwargs
     ):
         super(DecoderSftDataset, self).__init__(
             interactions,
@@ -784,21 +798,22 @@ class DecoderSftDataset(BaseInteractionDataset):
             'targets': targets
         }

+
 class EncoderSftDataset(BaseInteractionDataset):
     def __init__(
-
-
-
-
-
-
-
-
-
-
-
-
-
+            self,
+            interactions: Union[list[dict], HfDataset],
+            tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+            max_seq_len: int = 1024,
+            query_field: str = 'query',
+            answer_field: str = 'answer',
+            cache_tokenized: bool = False,
+            cache_remove_text: bool = True,
+            tokenize_in_background: bool = False,
+            batch_size: int = 1,
+            mask_prob: float = 0.15,
+            *args,
+            **kwargs
     ):
         super(EncoderSftDataset, self).__init__(
             interactions,
@@ -839,8 +854,10 @@ class EncoderSftDataset(BaseInteractionDataset):
             'labels': labels
         }

+
 MrlDataItem: TypeAlias = dict[str, Union[dict[str, torch.Tensor], list[dict[str, dict[str, torch.Tensor]]]]]

+
 class MrlCurriculumDataset(Dataset):
     def __init__(
             self,
@@ -931,12 +948,16 @@ class MrlCurriculumDataset(Dataset):
     def get_subset(self, size: float, from_start: bool = False, **kwargs) -> "MRlCurriculumDataset":
         split_point = int(len(self.episodes) * ((1 - size) if not from_start else size))
         if not isinstance(self.episodes, list):
-            subset = self.episodes.select(
-                range(split_point, len(self.episodes)) if not from_start else range(split_point))
+            subset = self.episodes.select(
+                range(split_point, len(self.episodes)) if not from_start else range(split_point))
+            self.episodes = self.episodes.select(
+                range(split_point) if not from_start else range(split_point, len(self.episodes)))
         else:
             subset = self.episodes[split_point:-1] if not from_start else self.episodes[0:split_point]
             self.episodes = self.episodes[0:split_point] if not from_start else self.episodes[split_point:-1]
-        return self.__class__(subset, tokenizer=self.tokenizer, query_field=self.query_field,
+        return self.__class__(subset, tokenizer=self.tokenizer, query_field=self.query_field,
+                              answer_field=self.answer_field, interactions_field=self.interactions_field,
+                              max_seq_len=self.max_seq_len, **kwargs)

     def pre_tokenize(self, verbose: bool = False, log_interval: int = 10_000, keep_order: bool = False):
         """
@@ -977,6 +998,7 @@ class MrlCurriculumDataset(Dataset):
             answer_field: str = 'answer',
             interactions_field: str = 'interactions',
             load_kwargs: dict = None,
+            max_seq_len: int = 1024,
             **kwargs
     ):
         """
@@ -993,19 +1015,23 @@ class MrlCurriculumDataset(Dataset):
             answer_field (str): Answer field (default: "answer")
             interactions_field (str): Interactions field (default: "interactions")
             load_kwargs (dict): Additional args for HuggingFace API load_dataset function
+            max_seq_len (int): Maximum sequence length (default: 1024)
             **kwargs: Additional args for RxNN Dataset class
         """
         if load_kwargs is None:
-
+            load_kwargs = {}

         hf_dataset = load_dataset(dataset_id, mrl_subset, split=split, **load_kwargs)

-        return cls(hf_dataset, tokenizer, query_field=query_field, answer_field=answer_field,
+        return cls(hf_dataset, tokenizer, query_field=query_field, answer_field=answer_field,
+                   interactions_field=interactions_field, max_seq_len=max_seq_len, **kwargs)

     @staticmethod
     def collate_mrl_batch(batch: list[MrlDataItem]) -> MrlDataItem:
         """Collate function for MRL curriculum dataset with nested interactions"""
-
+
+        def collate_interaction_batch(interaction_batch: Union[list[dict[str, dict[str, torch.Tensor]]], tuple[Any]]) -> \
+                dict[str, dict[str, torch.Tensor]]:
             """Helper to collate a batch of interactions"""
             return {
                 'query': {
@@ -1022,23 +1048,26 @@ class MrlCurriculumDataset(Dataset):
         transposed_interactions = list(zip(*batch_interactions))

         return {
-            **collate_interaction_batch(batch),
+            **collate_interaction_batch(batch),  # Collate initial query and answer
             'interactions': [
                 collate_interaction_batch(step_batch) for step_batch in transposed_interactions
             ]
         }

+
 class MrlDatasetItem(TypedDict):
     steps: int
     is_long_range: bool
     dataset: MrlCurriculumDataset
     eval_dataset: Optional[MrlCurriculumDataset]

+
 class MrlDatasetLoadItem(TypedDict):
     subset_name: str
     steps: int
     is_long_range: bool

+
 class MrlDatasets:
     def __init__(self, datasets: list[MrlDatasetItem]):
         self.datasets = datasets
@@ -1061,7 +1090,8 @@ class MrlDatasets:
     @property
     def is_pre_tokenized(self) -> bool:
         train_tokenized = all(item['dataset'].is_pre_tokenized for item in self.datasets)
-        eval_tokenized = all(item['eval_dataset'].is_pre_tokenized for item in self.datasets if item['eval_dataset'] is not None)
+        eval_tokenized = all(
+            item['eval_dataset'].is_pre_tokenized for item in self.datasets if item['eval_dataset'] is not None)
         return train_tokenized and eval_tokenized

     def pre_tokenize(self, verbose: bool = False, log_interval: int = 10_000, keep_order: bool = False):
@@ -1147,4 +1177,4 @@ class MrlDatasets:

         mrl_datasets = [dataset_item(item) for item in mrl_curriculum_steps]

-        return cls(mrl_datasets)
+        return cls(mrl_datasets)
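For orientation, most hunks above are line-wrapping changes, while the get_subset hunks also add an in-place select(...) reassignment and a previously missing return in BaseInteractionDataset.get_subset. The following is a minimal standalone sketch of that head/tail split pattern, assuming the HuggingFace datasets package is installed; split_items is a hypothetical helper name introduced only for illustration, and the slicing expressions are copied from the get_subset methods shown in the diff.

from typing import Union

from datasets import Dataset as HfDataset


def split_items(items: Union[list, HfDataset], size: float, from_start: bool = False):
    # Hypothetical standalone helper: in the package, this logic lives inside the
    # dataset classes' get_subset methods, which return new dataset instances.
    split_point = int(len(items) * ((1 - size) if not from_start else size))
    if not isinstance(items, list):
        # Arrow-backed HuggingFace datasets are sliced with .select(range(...)),
        # matching the select() calls in the hunks above.
        subset = items.select(range(split_point, len(items)) if not from_start else range(split_point))
        remainder = items.select(range(split_point) if not from_start else range(split_point, len(items)))
    else:
        # Plain lists use slicing; the :-1 upper bound mirrors the original code.
        subset = items[split_point:-1] if not from_start else items[0:split_point]
        remainder = items[0:split_point] if not from_start else items[split_point:-1]
    return subset, remainder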
rxnn-0.2.14.dist-info/RECORD → rxnn-0.2.16.dist-info/RECORD
CHANGED
@@ -14,7 +14,7 @@ rxnn/training/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 rxnn/training/base.py,sha256=_xik1GXE4RJ_nxwqLQ1ccXA5pRtBCi-jL-jeRFBdHBU,11851
 rxnn/training/bml.py,sha256=FJszaQXOLx2ZHBa1CQpyMrG8i4Kj14E-gzDAEK_Ei5k,17272
 rxnn/training/callbacks.py,sha256=-N0MQPpZQaUWCINdTOsjul4bDGbGr2JgQBqOIXBLS6o,35053
-rxnn/training/dataset.py,sha256=
+rxnn/training/dataset.py,sha256=7hTilFWPpqUEc6zNcMqBPjxFKxCfvTKKF3E8tVlwccQ,51250
 rxnn/training/models.py,sha256=renPa5YH443XNTMFI-YTCwi5vNp3QzwF5UXedNd5hDk,5187
 rxnn/training/mrl.py,sha256=CezloyaXOKrc_F_eDt99EZ1fmKAMCCCMh5Ry6vF82Ro,39607
 rxnn/training/reward.py,sha256=C0ToTz-u-L-qyBd2yJ1HlvVPS110OChYj9ZhD6iSSMU,5654
@@ -32,7 +32,7 @@ rxnn/transformers/moe.py,sha256=j6jEx6Ip0zttlUZKKn82azxo95lkLZs-H2GLSMD88hY,5859
 rxnn/transformers/positional.py,sha256=1PjcJybUzeQlIKJI4tahAGZcYgCRCL0otxs7mpsNuzM,4410
 rxnn/transformers/sampler.py,sha256=t6iiQTdLQ0TakUWnnhKkb5DKF2F_9-thXHBydDF3fxg,17389
 rxnn/utils.py,sha256=ihb6OTyDtPiocB_lOvnq7eOkjjpCkgs8wxvXUBNQ7mM,996
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
-rxnn-0.2.
+rxnn-0.2.16.dist-info/LICENSE,sha256=C8coDFIUYuOcke4JLPwTqahQUCyXyGq6WOaigOkx8tY,11275
+rxnn-0.2.16.dist-info/METADATA,sha256=rSE80lLQZTqjOO9CJfLtyIQMCmgLep1suzXj2_DmPNI,25960
+rxnn-0.2.16.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+rxnn-0.2.16.dist-info/RECORD,,