rxnn 0.1.83__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff compares publicly available package versions as released to a supported public registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.
- rxnn/.DS_Store +0 -0
- rxnn/experimental/attention.py +5 -0
- rxnn/memory/attention.py +42 -0
- rxnn/memory/stm.py +55 -12
- rxnn/rxt/models.py +71 -0
- rxnn/training/bml.py +2 -59
- rxnn/training/callbacks.py +302 -39
- rxnn/training/dataset.py +344 -1
- rxnn/training/models.py +142 -0
- rxnn/training/mrl.py +808 -0
- rxnn/training/reward.py +111 -0
- rxnn/training/rl.py +69 -0
- rxnn/training/utils.py +148 -0
- rxnn/transformers/attention.py +10 -0
- rxnn/transformers/layers.py +6 -0
- rxnn/transformers/models.py +16 -4
- rxnn/transformers/positional.py +7 -0
- rxnn/transformers/sampler.py +283 -9
- {rxnn-0.1.83.dist-info → rxnn-0.2.1.dist-info}/METADATA +11 -9
- rxnn-0.2.1.dist-info/RECORD +38 -0
- rxnn-0.1.83.dist-info/RECORD +0 -31
- {rxnn-0.1.83.dist-info → rxnn-0.2.1.dist-info}/LICENSE +0 -0
- {rxnn-0.1.83.dist-info → rxnn-0.2.1.dist-info}/WHEEL +0 -0
rxnn/training/dataset.py
CHANGED
```diff
@@ -4,7 +4,7 @@ from datasets import Dataset as HfDataset, load_dataset, concatenate_datasets
 from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
 from .tokenizer import load_tokenizer_from_hf_hub
 
-from typing import Union
+from typing import Union, TypedDict, Optional, TypeAlias, Any
 
 
 class BaseDataset(Dataset):
```
```diff
@@ -189,6 +189,12 @@ class BaseDataset(Dataset):
         """
         assert tokenizer is not None or tokenizer_hub_id is not None, "One of the `tokenizer` or `tokenizer_hub_id` args must be provided."
 
+        if load_kwargs is None:
+            load_kwargs = {}
+
+        if load_tokenizer_kwargs is None:
+            load_tokenizer_kwargs = {}
+
         if tokenizer is None:
             tokenizer = load_tokenizer_from_hf_hub(tokenizer_hub_id, **load_tokenizer_kwargs)
 
@@ -231,6 +237,12 @@ class BaseDataset(Dataset):
         """
         assert tokenizer is not None or tokenizer_hub_id is not None, "One of the `tokenizer` or `tokenizer_hub_id` args must be provided."
 
+        if load_kwargs is None:
+            load_kwargs = {}
+
+        if load_tokenizer_kwargs is None:
+            load_tokenizer_kwargs = {}
+
         if tokenizer is None:
             tokenizer = load_tokenizer_from_hf_hub(tokenizer_hub_id, **load_tokenizer_kwargs)
 
@@ -280,6 +292,12 @@ class BaseDataset(Dataset):
         """
         assert tokenizer is not None or tokenizer_hub_id is not None, "One of the `tokenizer` or `tokenizer_hub_id` args must be provided."
 
+        if load_kwargs is None:
+            load_kwargs = {}
+
+        if load_tokenizer_kwargs is None:
+            load_tokenizer_kwargs = {}
+
         if tokenizer is None:
             tokenizer = load_tokenizer_from_hf_hub(tokenizer_hub_id, **load_tokenizer_kwargs)
 
@@ -593,6 +611,12 @@ class BaseInteractionDataset(Dataset):
         """
         assert tokenizer is not None or tokenizer_hub_id is not None, "One of the `tokenizer` or `tokenizer_hub_id` args must be provided."
 
+        if load_kwargs is None:
+            load_kwargs = {}
+
+        if load_tokenizer_kwargs is None:
+            load_tokenizer_kwargs = {}
+
         if tokenizer is None:
             tokenizer = load_tokenizer_from_hf_hub(tokenizer_hub_id, **load_tokenizer_kwargs)
 
@@ -637,6 +661,12 @@ class BaseInteractionDataset(Dataset):
         """
         assert tokenizer is not None or tokenizer_hub_id is not None, "One of the `tokenizer` or `tokenizer_hub_id` args must be provided."
 
+        if load_kwargs is None:
+            load_kwargs = {}
+
+        if load_tokenizer_kwargs is None:
+            load_tokenizer_kwargs = {}
+
         if tokenizer is None:
             tokenizer = load_tokenizer_from_hf_hub(tokenizer_hub_id, **load_tokenizer_kwargs)
 
@@ -688,6 +718,12 @@ class BaseInteractionDataset(Dataset):
         """
         assert tokenizer is not None or tokenizer_hub_id is not None, "One of the `tokenizer` or `tokenizer_hub_id` args must be provided."
 
+        if load_kwargs is None:
+            load_kwargs = {}
+
+        if load_tokenizer_kwargs is None:
+            load_tokenizer_kwargs = {}
+
         if tokenizer is None:
             tokenizer = load_tokenizer_from_hf_hub(tokenizer_hub_id, **load_tokenizer_kwargs)
 
```
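All six hunks apply the same defensive pattern: optional dict arguments default to `None` and are replaced with fresh dicts inside the method, the standard idiom for avoiding Python's shared-mutable-default pitfall. A minimal sketch of the pitfall being avoided (hypothetical functions, not part of the package):

```python
def bad(opts: dict = {}):      # one dict object shared across ALL calls
    opts.setdefault('seen', 0)
    opts['seen'] += 1
    return opts['seen']

def good(opts: dict = None):   # the pattern used in these hunks
    if opts is None:
        opts = {}              # fresh dict on every call
    opts.setdefault('seen', 0)
    opts['seen'] += 1
    return opts['seen']

assert bad() == 1 and bad() == 2    # state leaks between calls
assert good() == 1 and good() == 1  # no leakage
```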
```diff
@@ -802,3 +838,310 @@ class EncoderSftDataset(BaseInteractionDataset):
             'attention_mask': attention_mask,
             'labels': labels
         }
+
+MrlDataItem: TypeAlias = dict[str, Union[dict[str, torch.Tensor], list[dict[str, dict[str, torch.Tensor]]]]]
+
+class MrlCurriculumDataset(Dataset):
+    def __init__(
+            self,
+            episodes: Union[list[dict], HfDataset],
+            tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+            max_seq_len: int = 1024,
+            query_field: str = 'query',
+            answer_field: str = 'answer',
+            interactions_field: str = 'interactions',
+            query_token: str = '[Q]',
+            answer_token: str = '[A]',
+            bos_token: str = '[BOS]',
+            eos_token: str = '[EOS]',
+            **kwargs,
+    ):
+        super(MrlCurriculumDataset, self).__init__(**kwargs)
+        self.episodes = episodes
+        self.tokenizer = tokenizer
+        self.max_seq_len = max_seq_len
+        self.query_field = query_field
+        self.answer_field = answer_field
+        self.interactions_field = interactions_field
+        self.query_token = query_token
+        self.answer_token = answer_token
+        self.bos_token = bos_token
+        self.eos_token = eos_token
+        self.is_pre_tokenized = False
+        self.is_list = isinstance(self.episodes, list)
+        self.inputs = []
+
+    def _tokenize_manual_interaction(self, query: str, answer: str) -> dict[str, dict[str, torch.Tensor]]:
+        # Manually construct query: [BOS][Q]query
+        query_text = f"{self.bos_token}{self.query_token}{query}"
+        query_enc = self.tokenizer(
+            query_text,
+            max_length=self.max_seq_len,
+            padding='max_length',
+            truncation=True,
+            return_tensors='pt',
+            add_special_tokens=False  # Critical: we control all tokens
+        )
+
+        # Manually construct answer: [A]answer[EOS]
+        answer_text = f"{self.answer_token}{answer}{self.eos_token}"
+        answer_enc = self.tokenizer(
+            answer_text,
+            max_length=self.max_seq_len,
+            padding='max_length',
+            truncation=True,
+            return_tensors='pt',
+            add_special_tokens=False  # Critical: we control all tokens
+        )
+
+        return {
+            'query': {
+                'input_ids': query_enc['input_ids'][0],
+                'attention_mask': query_enc['attention_mask'][0],
+            },
+            'answer': {
+                'input_ids': answer_enc['input_ids'][0],
+                'attention_mask': answer_enc['attention_mask'][0],
+            }
+        }
+
+    def get_tokenized_item(self, idx: int, episode: dict = None) -> MrlDataItem:
+        if self.is_pre_tokenized:
+            return self.inputs[idx]
+        else:
+            item = self.episodes[idx] if episode is None else episode
+            query = item[self.query_field]
+            answer = item[self.answer_field]
+            interactions = item[self.interactions_field]
+
+            initial = self._tokenize_manual_interaction(query, answer)
+            follow_ups = [self._tokenize_manual_interaction(inter['query'], inter['answer']) for inter in interactions]
+
+            return {
+                **initial,
+                'interactions': follow_ups,
+            }
+
+    def __getitem__(self, idx: int) -> MrlDataItem:
+        return self.get_tokenized_item(idx)
+
+    def __len__(self) -> int:
+        return len(self.episodes)
+
+    def get_subset(self, size: float, from_start: bool = False, **kwargs) -> "MrlCurriculumDataset":
+        split_point = int(len(self.episodes) * ((1 - size) if not from_start else size))
+        if not isinstance(self.episodes, list):
+            subset = self.episodes.select(range(split_point, len(self.episodes)) if not from_start else range(split_point))
+            self.episodes = self.episodes.select(range(split_point) if not from_start else range(split_point, len(self.episodes)))
+        else:
+            subset = self.episodes[split_point:] if not from_start else self.episodes[0:split_point]
+            self.episodes = self.episodes[0:split_point] if not from_start else self.episodes[split_point:]
+        return self.__class__(subset, self.tokenizer, query_field=self.query_field, answer_field=self.answer_field, interactions_field=self.interactions_field, **kwargs)
+
+    def pre_tokenize(self, verbose: bool = False, log_interval: int = 10_000, keep_order: bool = False):
+        """
+        Pre-tokenizes all the items in the dataset for faster training. Training with a
+        pre-tokenized dataset can be up to 2x faster.
+
+        !! This method has extremely high memory usage when used with HuggingFace datasets,
+        because the dataset is converted to a list. Additionally, for optimal performance,
+        pre-tokenized items are stored in reversed order - this shouldn't matter for training,
+        as items are then shuffled by the DataLoader, but keep it in mind for reproducibility.
+
+        Args:
+            verbose (bool): Should display logs (default: False)
+            log_interval (int): Display logs every log_interval iterations (default: 10_000)
+            keep_order (bool): Keep tokenized items in the original order - by default they are reversed for faster processing (default: False)
+        """
+        if not self.is_pre_tokenized:
+            num_episodes = len(self.episodes)
+            eps = self.episodes if self.is_list else self.episodes.to_list()
+            del self.episodes
+            self.episodes = None
+            for index in range(num_episodes):
+                self.inputs.append(self.get_tokenized_item(index, episode=eps.pop() if not keep_order else eps[index]))
+                if verbose and index % log_interval == 0:
+                    print(f'Processed {index + 1}/{num_episodes}')
+            del eps
+            self.is_pre_tokenized = True
+
+    @classmethod
+    def from_hf_hub(
+            cls,
+            dataset_id: str,
+            mrl_subset: str,
+            tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+            split: str = 'train',
+            query_field: str = 'query',
+            answer_field: str = 'answer',
+            interactions_field: str = 'interactions',
+            load_kwargs: dict = None,
+            **kwargs
+    ):
+        """
+        Load a dataset from the HuggingFace Hub and convert it to an RxNN training dataset.
+
+        Args:
+            dataset_id (str): Hub dataset repository name
+            mrl_subset (str): Dataset subset
+            tokenizer (Union[PreTrainedTokenizer, PreTrainedTokenizerFast]): Tokenizer
+            split (str): Dataset split (default: "train")
+            query_field (str): Query field (default: "query")
+            answer_field (str): Answer field (default: "answer")
+            interactions_field (str): Interactions field (default: "interactions")
+            load_kwargs (dict): Additional args for the HuggingFace load_dataset function
+            **kwargs: Additional args for the RxNN Dataset class
+        """
+        if load_kwargs is None:
+            load_kwargs = {}
+
+        hf_dataset = load_dataset(dataset_id, mrl_subset, split=split, **load_kwargs)
+
+        return cls(hf_dataset, tokenizer, query_field=query_field, answer_field=answer_field, interactions_field=interactions_field, **kwargs)
+
+    @staticmethod
+    def collate_mrl_batch(batch: list[MrlDataItem]) -> MrlDataItem:
+        """Collate function for the MRL curriculum dataset with nested interactions"""
+        def collate_interaction_batch(interaction_batch: Union[list[dict[str, dict[str, torch.Tensor]]], tuple[Any]]) -> dict[str, dict[str, torch.Tensor]]:
+            """Helper to collate a batch of interactions"""
+            return {
+                'query': {
+                    'input_ids': torch.stack([x['query']['input_ids'] for x in interaction_batch]),
+                    'attention_mask': torch.stack([x['query']['attention_mask'] for x in interaction_batch]),
+                },
+                'answer': {
+                    'input_ids': torch.stack([x['answer']['input_ids'] for x in interaction_batch]),
+                    'attention_mask': torch.stack([x['answer']['attention_mask'] for x in interaction_batch]),
+                }
+            }
+
+        batch_interactions = [x['interactions'] for x in batch]
+        transposed_interactions = list(zip(*batch_interactions))
+
+        return {
+            **collate_interaction_batch(batch),  # Collate initial query and answer
+            'interactions': [
+                collate_interaction_batch(step_batch) for step_batch in transposed_interactions
+            ]
+        }
+
```
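A minimal usage sketch for the new curriculum dataset and its nested collate function. The tokenizer repo id and episode data below are hypothetical placeholders, not part of the package:

```python
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from rxnn.training.dataset import MrlCurriculumDataset

# Hypothetical tokenizer and episode data - replace with your own.
tokenizer = AutoTokenizer.from_pretrained('your-org/your-tokenizer')
episodes = [
    {
        'query': 'What is RxNN?',
        'answer': 'A reactive neural network framework.',
        'interactions': [
            {'query': 'Who maintains it?', 'answer': 'Reactive AI.'},
        ],
    },
]

dataset = MrlCurriculumDataset(episodes, tokenizer, max_seq_len=256)

# The custom collate function keeps the nested per-step structure intact:
# each batch carries 'query'/'answer' tensors plus a list of per-step
# interaction batches (episodes in a batch must have equal step counts).
loader = DataLoader(dataset, batch_size=1, shuffle=True,
                    collate_fn=MrlCurriculumDataset.collate_mrl_batch)

batch = next(iter(loader))
print(batch['query']['input_ids'].shape)  # (batch, max_seq_len)
print(len(batch['interactions']))         # number of follow-up steps
```

Note that `padding='max_length'` in `_tokenize_manual_interaction` is what makes the plain `torch.stack` calls in the collate helper safe.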
The same hunk continues with the curriculum step typings and the `MrlDatasets` collection:

```diff
+class MrlDatasetItem(TypedDict):
+    steps: int
+    is_long_range: bool
+    dataset: MrlCurriculumDataset
+    eval_dataset: Optional[MrlCurriculumDataset]
+
+class MrlDatasetLoadItem(TypedDict):
+    subset_name: str
+    steps: int
+    is_long_range: bool
+
+class MrlDatasets:
+    def __init__(self, datasets: list[MrlDatasetItem]):
+        self.datasets = datasets
+
+    def __iter__(self):
+        return iter(self.datasets)
+
+    def __getitem__(self, idx: int) -> MrlDatasetItem:
+        return self.datasets[idx]
+
+    def __len__(self):
+        return len(self.datasets)
+
+    def __call__(self, steps: int, is_long_range: bool = False):
+        for dataset in self.datasets:
+            if dataset['steps'] == steps and dataset['is_long_range'] == is_long_range:
+                return dataset
+        return None
+
+    @property
+    def is_pre_tokenized(self) -> bool:
+        train_tokenized = all(item['dataset'].is_pre_tokenized for item in self.datasets)
+        eval_tokenized = all(item['eval_dataset'].is_pre_tokenized for item in self.datasets if item['eval_dataset'] is not None)
+        return train_tokenized and eval_tokenized
+
+    def pre_tokenize(self, verbose: bool = False, log_interval: int = 10_000, keep_order: bool = False):
+        """
+        Pre-tokenizes all the inner datasets.
+
+        !! This method has extremely high memory usage when used with HuggingFace datasets,
+        because the datasets are converted to lists. Additionally, for optimal performance,
+        pre-tokenized items are stored in reversed order - this shouldn't matter for training,
+        as items are then shuffled by the DataLoader, but keep it in mind for reproducibility.
+
+        Args:
+            verbose (bool): Should display logs (default: False)
+            log_interval (int): Display logs every log_interval iterations (default: 10_000)
+            keep_order (bool): Keep tokenized items in the original order - by default they are reversed for faster processing (default: False)
+        """
+        if not self.is_pre_tokenized:
+            for item in self.datasets:
+                item['dataset'].pre_tokenize(verbose, log_interval, keep_order)
+                if item['eval_dataset'] is not None:
+                    item['eval_dataset'].pre_tokenize(verbose, log_interval, keep_order)
+
+    @classmethod
+    def from_hf_hub(
+            cls,
+            dataset_id: str,
+            tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+            mrl_curriculum_steps: Union[list[MrlDatasetLoadItem], tuple[MrlDatasetLoadItem]],
+            split: str = 'train',
+            query_field: str = 'query',
+            answer_field: str = 'answer',
+            interactions_field: str = 'interactions',
+            load_kwargs: dict = None,
+            mrl_ds_kwargs: dict = None,
+            eval_split: str = None,
+    ):
+        """
+        Load MRL curriculum datasets from the HuggingFace Hub and convert them to RxNN training datasets.
+
+        Args:
+            dataset_id (str): Hub dataset repository name
+            tokenizer (Union[PreTrainedTokenizer, PreTrainedTokenizerFast]): Tokenizer
+            mrl_curriculum_steps (list[MrlDatasetLoadItem]): MRL curriculum steps configuration
+            split (str): Dataset split (default: "train")
+            query_field (str): Query field (default: "query")
+            answer_field (str): Answer field (default: "answer")
+            interactions_field (str): Interactions field (default: "interactions")
+            load_kwargs (dict): Additional args for the HuggingFace load_dataset function
+            mrl_ds_kwargs (dict): Additional args for the RxNN MrlCurriculumDataset class
+            eval_split (str): Also load an evaluation/validation split (default: None)
+        """
+        if load_kwargs is None:
+            load_kwargs = {}
+        if mrl_ds_kwargs is None:
+            mrl_ds_kwargs = {}
+
+        def load_subset(subset_name: str, load_split: str):
+            return MrlCurriculumDataset.from_hf_hub(
+                dataset_id,
+                subset_name,
+                tokenizer=tokenizer,
+                query_field=query_field,
+                answer_field=answer_field,
+                interactions_field=interactions_field,
+                split=load_split,
+                load_kwargs=load_kwargs,
+                **mrl_ds_kwargs,
+            )
+
+        def dataset_item(item: MrlDatasetLoadItem) -> MrlDatasetItem:
+            return {
+                'steps': item['steps'],
+                'is_long_range': item['is_long_range'],
+                'dataset': load_subset(item['subset_name'], split),
+                'eval_dataset': load_subset(item['subset_name'], eval_split) if eval_split is not None else None,
+            }
+
+        mrl_datasets = [dataset_item(item) for item in mrl_curriculum_steps]
+
+        return cls(mrl_datasets)
```
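A hedged end-to-end sketch of loading a full curriculum; the repo id and subset names are hypothetical placeholders:

```python
from transformers import AutoTokenizer
from rxnn.training.dataset import MrlDatasets

tokenizer = AutoTokenizer.from_pretrained('your-org/your-tokenizer')  # hypothetical

# Each entry names a Hub subset plus its curriculum metadata.
curriculum = [
    {'subset_name': 'steps-4', 'steps': 4, 'is_long_range': False},
    {'subset_name': 'steps-8-lr', 'steps': 8, 'is_long_range': True},
]

datasets = MrlDatasets.from_hf_hub(
    'your-org/your-mrl-dataset',  # hypothetical repo id
    tokenizer,
    curriculum,
    eval_split='validation',      # also loads an eval dataset per subset
)

datasets.pre_tokenize(verbose=True)      # tokenize every train/eval subset once
stage = datasets(8, is_long_range=True)  # __call__ looks up a curriculum stage
train_ds, eval_ds = stage['dataset'], stage['eval_dataset']
```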
rxnn/training/models.py
ADDED
@@ -0,0 +1,142 @@

```python
import torch
import torch.nn as nn
from enum import Enum
from huggingface_hub import PyTorchModelHubMixin
from ..transformers.models import ReactiveTransformerEncoder, ReactiveTransformerDecoder

class MLMHead(nn.Module, PyTorchModelHubMixin, license="apache-2.0"):
    def __init__(self, embed_dim: int, vocab_size: int, *args, **kwargs):
        super(MLMHead, self).__init__(*args, **kwargs)
        self.dense = nn.Linear(embed_dim, embed_dim)
        self.act = nn.GELU()
        self.layer_norm = nn.LayerNorm(embed_dim)
        self.decoder = nn.Linear(embed_dim, vocab_size)

    def forward(self, hidden_states):
        x = self.dense(hidden_states)
        x = self.act(x)
        x = self.layer_norm(x)
        return self.decoder(x)


class MLMTrainingModel(nn.Module):
    def __init__(
            self,
            encoder: ReactiveTransformerEncoder,
            mlm_head: MLMHead,
            *args,
            **kwargs
    ):
        super(MLMTrainingModel, self).__init__(*args, **kwargs)
        self.encoder = encoder
        self.mlm_head = mlm_head

    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
        h, _ = self.encoder(x, attention_mask=attention_mask)
        y = self.mlm_head(h)
        return y

class JointTrainingModel(nn.Module):
    def __init__(
            self,
            encoder: ReactiveTransformerEncoder,
            decoder: ReactiveTransformerDecoder,
            mlm_head: MLMHead,
            *args,
            **kwargs
    ):
        super(JointTrainingModel, self).__init__(*args, **kwargs)
        self.encoder = encoder
        self.mlm_head = mlm_head
        self.decoder = decoder

    def forward(self, x_e: torch.Tensor, x_d: torch.Tensor, attention_mask: torch.Tensor = None) -> tuple[torch.Tensor, torch.Tensor]:
        encoder_result, _ = self.encoder(x_e, attention_mask=attention_mask)
        y_e = self.mlm_head(encoder_result)
        y_d = self.decoder(x_d, attention_mask=attention_mask)
        return y_e, y_d
```
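A minimal sketch of wiring the MLM head to an encoder for masked-language-model pretraining. The `encoder` instance, dimensions, and masking scheme below are placeholders, not from the package:

```python
import torch
import torch.nn.functional as F

# Placeholder dimensions; a real `encoder` would be built from rxnn's
# transformer classes elsewhere.
embed_dim, vocab_size = 512, 32_000
mlm_head = MLMHead(embed_dim=embed_dim, vocab_size=vocab_size)
model = MLMTrainingModel(encoder, mlm_head)  # `encoder` assumed built elsewhere

input_ids = torch.randint(0, vocab_size, (2, 128))
labels = input_ids.clone()  # in practice, -100 at every non-masked position
logits = model(input_ids)   # (batch, seq_len, vocab_size)
loss = F.cross_entropy(logits.view(-1, vocab_size), labels.view(-1),
                       ignore_index=-100)
```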
The file continues with the MRL actor/critic models:

```python
class MrlActorAction(Enum):
    DECODE = 1
    UPDATE = 2

class MrlActorModel(nn.Module):
    def __init__(
            self,
            encoder: nn.Module,
            decoder: nn.Module,
            memory_attention: nn.Module,
            **kwargs
    ):
        super(MrlActorModel, self).__init__(**kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self.memory_attention = memory_attention

    def freeze_components(self):
        """Freeze encoder/decoder except memory-related layers."""
        if getattr(self.encoder, 'freeze_without_memory', None) is not None:
            self.encoder.freeze_without_memory()
        else:
            for param in self.encoder.parameters():
                param.requires_grad = False
            self.encoder.model.trainable_cross_attention_(True)
        if getattr(self.decoder, 'freeze_without_memory', None) is not None:
            self.decoder.freeze_without_memory()
        else:
            for param in self.decoder.parameters():
                param.requires_grad = False
            self.decoder.model.trainable_cross_attention_(True)
        # Unfreeze memory attention
        for param in self.memory_attention.parameters():
            param.requires_grad = True

    def unfreeze_components(self):
        """Unfreeze all components after initial training."""
        if getattr(self.encoder, 'unfreeze_all', None) is not None:
            self.encoder.unfreeze_all()
        else:
            for param in self.encoder.parameters():
                param.requires_grad = True
        if getattr(self.decoder, 'unfreeze_all', None) is not None:
            self.decoder.unfreeze_all()
        else:
            for param in self.decoder.parameters():
                param.requires_grad = True
        for param in self.memory_attention.parameters():
            param.requires_grad = True

    def reset_memory(self):
        self.memory_attention.reset_memory()

    def unique_parameters(self):
        return list(set(
            list(self.encoder.parameters()) +
            list(self.decoder.parameters()) +
            list(self.memory_attention.parameters())
        ))

    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None, action: MrlActorAction = MrlActorAction.DECODE) -> torch.Tensor:
        if action == MrlActorAction.DECODE:
            return self.decoder(x, attention_mask=attention_mask)
        else:
            _, ed = self.encoder(x, attention_mask=attention_mask)
            return self.memory_attention(ed, attention_mask=attention_mask)

class MrlCriticModel(nn.Module):
    def __init__(self, encoder: nn.Module, embed_dim: int, **kwargs):
        super(MrlCriticModel, self).__init__(**kwargs)
        self.encoder = encoder
        self.value_head = nn.Linear(embed_dim, 1)

    def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
        x, _ = self.encoder(x, attention_mask=attention_mask)

        # Mean-pool over the sequence, ignoring padded positions when a mask is given
        if attention_mask is not None:
            x = x * attention_mask.unsqueeze(-1)
            x = x.sum(dim=1) / attention_mask.sum(dim=1, keepdim=True)
        else:
            x = x.mean(dim=1)

        return self.value_head(x)
```
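A hedged sketch of how the actor's two-phase forward and the critic's pooled value head might be driven during MRL. The component instances (`encoder`, `decoder`, `memory_attention`, `critic_encoder`) and dimensions are placeholders assumed to be built elsewhere, e.g. from rxnn.rxt.models:

```python
import torch

# Placeholder components - construct real ones elsewhere.
actor = MrlActorModel(encoder, decoder, memory_attention)
critic = MrlCriticModel(critic_encoder, embed_dim=512)

actor.freeze_components()  # stage 1: train only memory-related layers

input_ids = torch.randint(0, 32_000, (2, 128))  # placeholder token ids
mask = torch.ones_like(input_ids)

# DECODE: produce logits conditioned on the current short-term memory.
logits = actor(input_ids, attention_mask=mask, action=MrlActorAction.DECODE)

# UPDATE: encode the processed interaction and merge it into memory.
updated_stm = actor(input_ids, attention_mask=mask, action=MrlActorAction.UPDATE)

# The critic scores each sequence with a single pooled value.
values = critic(input_ids, attention_mask=mask)  # shape: (batch, 1)
```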