rxnn 0.1.83__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
rxnn/training/dataset.py CHANGED
@@ -4,7 +4,7 @@ from datasets import Dataset as HfDataset, load_dataset, concatenate_datasets
  from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
  from .tokenizer import load_tokenizer_from_hf_hub

- from typing import Union
+ from typing import Union, TypedDict, Optional, TypeAlias, Any


  class BaseDataset(Dataset):
@@ -189,6 +189,12 @@ class BaseDataset(Dataset):
          """
          assert tokenizer is not None or tokenizer_hub_id is not None, "One of the `tokenizer` or `tokenizer_hub_id` args must be provided."

+         if load_kwargs is None:
+             load_kwargs = {}
+
+         if load_tokenizer_kwargs is None:
+             load_tokenizer_kwargs = {}
+
          if tokenizer is None:
              tokenizer = load_tokenizer_from_hf_hub(tokenizer_hub_id, **load_tokenizer_kwargs)

@@ -231,6 +237,12 @@ class BaseDataset(Dataset):
          """
          assert tokenizer is not None or tokenizer_hub_id is not None, "One of the `tokenizer` or `tokenizer_hub_id` args must be provided."

+         if load_kwargs is None:
+             load_kwargs = {}
+
+         if load_tokenizer_kwargs is None:
+             load_tokenizer_kwargs = {}
+
          if tokenizer is None:
              tokenizer = load_tokenizer_from_hf_hub(tokenizer_hub_id, **load_tokenizer_kwargs)

@@ -280,6 +292,12 @@ class BaseDataset(Dataset):
          """
          assert tokenizer is not None or tokenizer_hub_id is not None, "One of the `tokenizer` or `tokenizer_hub_id` args must be provided."

+         if load_kwargs is None:
+             load_kwargs = {}
+
+         if load_tokenizer_kwargs is None:
+             load_tokenizer_kwargs = {}
+
          if tokenizer is None:
              tokenizer = load_tokenizer_from_hf_hub(tokenizer_hub_id, **load_tokenizer_kwargs)

@@ -593,6 +611,12 @@ class BaseInteractionDataset(Dataset):
          """
          assert tokenizer is not None or tokenizer_hub_id is not None, "One of the `tokenizer` or `tokenizer_hub_id` args must be provided."

+         if load_kwargs is None:
+             load_kwargs = {}
+
+         if load_tokenizer_kwargs is None:
+             load_tokenizer_kwargs = {}
+
          if tokenizer is None:
              tokenizer = load_tokenizer_from_hf_hub(tokenizer_hub_id, **load_tokenizer_kwargs)

@@ -637,6 +661,12 @@ class BaseInteractionDataset(Dataset):
          """
          assert tokenizer is not None or tokenizer_hub_id is not None, "One of the `tokenizer` or `tokenizer_hub_id` args must be provided."

+         if load_kwargs is None:
+             load_kwargs = {}
+
+         if load_tokenizer_kwargs is None:
+             load_tokenizer_kwargs = {}
+
          if tokenizer is None:
              tokenizer = load_tokenizer_from_hf_hub(tokenizer_hub_id, **load_tokenizer_kwargs)

@@ -688,6 +718,12 @@ class BaseInteractionDataset(Dataset):
          """
          assert tokenizer is not None or tokenizer_hub_id is not None, "One of the `tokenizer` or `tokenizer_hub_id` args must be provided."

+         if load_kwargs is None:
+             load_kwargs = {}
+
+         if load_tokenizer_kwargs is None:
+             load_tokenizer_kwargs = {}
+
          if tokenizer is None:
              tokenizer = load_tokenizer_from_hf_hub(tokenizer_hub_id, **load_tokenizer_kwargs)

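Each of the six hunks above adds the same guard to a different `from_hf_hub`-style loader in `BaseDataset` and `BaseInteractionDataset`: `load_kwargs` and `load_tokenizer_kwargs` now default to `None` and are normalised to empty dicts inside the method before being splatted into `load_dataset` / `load_tokenizer_from_hf_hub`. A minimal sketch of the pattern in isolation (the function below is illustrative only, not part of the package):

    from typing import Optional

    def load_with_defaults(dataset_id: str, load_kwargs: Optional[dict] = None) -> dict:
        # Normalise None to an empty dict, as the hunks above do, instead of
        # relying on a shared mutable default like `load_kwargs: dict = {}`.
        if load_kwargs is None:
            load_kwargs = {}
        return {"dataset_id": dataset_id, **load_kwargs}

    load_with_defaults("org/dataset")                        # kwargs omitted entirely
    load_with_defaults("org/dataset", {"revision": "main"})  # explicit kwargs still work

The practical effect for callers is that the keyword-argument dicts can simply be omitted.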
@@ -802,3 +838,310 @@ class EncoderSftDataset(BaseInteractionDataset):
              'attention_mask': attention_mask,
              'labels': labels
          }
+
+ MrlDataItem: TypeAlias = dict[str, Union[dict[str, torch.Tensor], list[dict[str, dict[str, torch.Tensor]]]]]
+
+ class MrlCurriculumDataset(Dataset):
+     def __init__(
+             self,
+             episodes: Union[list[dict], HfDataset],
+             tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+             max_seq_len: int = 1024,
+             query_field: str = 'query',
+             answer_field: str = 'answer',
+             interactions_field: str = 'interactions',
+             query_token: str = '[Q]',
+             answer_token: str = '[A]',
+             bos_token: str = '[BOS]',
+             eos_token: str = '[EOS]',
+             **kwargs,
+     ):
+         super(MrlCurriculumDataset, self).__init__(**kwargs)
+         self.episodes = episodes
+         self.tokenizer = tokenizer
+         self.max_seq_len = max_seq_len
+         self.query_field = query_field
+         self.answer_field = answer_field
+         self.interactions_field = interactions_field
+         self.query_token = query_token
+         self.answer_token = answer_token
+         self.bos_token = bos_token
+         self.eos_token = eos_token
+         self.is_pre_tokenized = False
+         self.is_list = isinstance(self.episodes, list)
+         self.inputs = []
+
+     def _tokenize_manual_interaction(self, query: str, answer: str) -> dict[str, dict[str, torch.Tensor]]:
+         # Manually construct query: [BOS][Q]query
+         query_text = f"{self.bos_token}{self.query_token}{query}"
+         query_enc = self.tokenizer(
+             query_text,
+             max_length=self.max_seq_len,
+             padding='max_length',
+             truncation=True,
+             return_tensors='pt',
+             add_special_tokens=False  # Critical: We control all tokens
+         )
+
+         # Manually construct answer: [A]answer[EOS]
+         answer_text = f"{self.answer_token}{answer}{self.eos_token}"
+         answer_enc = self.tokenizer(
+             answer_text,
+             max_length=self.max_seq_len,
+             padding='max_length',
+             truncation=True,
+             return_tensors='pt',
+             add_special_tokens=False  # Critical: We control all tokens
+         )
+
+         return {
+             'query': {
+                 'input_ids': query_enc['input_ids'][0],
+                 'attention_mask': query_enc['attention_mask'][0],
+             },
+             'answer': {
+                 'input_ids': answer_enc['input_ids'][0],
+                 'attention_mask': answer_enc['attention_mask'][0],
+             }
+         }
+
+     def get_tokenized_item(self, idx: int, episode: dict = None) -> MrlDataItem:
+         if self.is_pre_tokenized:
+             return self.inputs[idx]
+         else:
+             item = self.episodes[idx] if episode is None else episode
+             query = item[self.query_field]
+             answer = item[self.answer_field]
+             interactions = item[self.interactions_field]
+
+             initial = self._tokenize_manual_interaction(query, answer)
+             follow_ups = [self._tokenize_manual_interaction(inter['query'], inter['answer']) for inter in interactions]
+
+             return {
+                 **initial,
+                 'interactions': follow_ups,
+             }
+
+     def __getitem__(self, idx: int) -> MrlDataItem:
+         return self.get_tokenized_item(idx)
+
+     def __len__(self) -> int:
+         return len(self.episodes)
+
+     def get_subset(self, size: float, from_start: bool = False, **kwargs) -> "MRlCurriculumDataset":
+         split_point = int(len(self.episodes) * ((1 - size) if not from_start else size))
+         if not isinstance(self.episodes, list):
+             subset = self.episodes.select(range(split_point, len(self.episodes)) if not from_start else range(split_point))
+             self.episodes = self.episodes.select(range(split_point) if not from_start else range(split_point, len(self.episodes)))
+         else:
+             subset = self.episodes[split_point:-1] if not from_start else self.episodes[0:split_point]
+             self.episodes = self.episodes[0:split_point] if not from_start else self.episodes[split_point:-1]
+         return self.__class__(subset, query_field=self.query_field, answer_field=self.answer_field, interactions_field=self.interactions_field, **kwargs)
+
+     def pre_tokenize(self, verbose: bool = False, log_interval: int = 10_000, keep_order: bool = False):
+         """
+         Pre-tokenizes all the items in the dataset, for faster training. Training with pre-tokenized
+         dataset could be even 2x faster.
+
+         !! This method has extremely high memory usage, when used with HuggingFace datasets,
+         because of converting it to list. Additionally, for the most optimal performance,
+         pre-tokenized items are in reversed order - it shouldn't matter for training, as
+         items are shuffled then by DataLoader, but you should keep that in mind in case
+         of reproducibility.
+
+         Args:
+             verbose (bool): Should display logs (default: False)
+             log_interval (int): Display logs every log_interval iterations (default: 10_000)
+             keep_order (bool): Keep tokenized items in the same order - by default they are reversed for faster processing (default: False)
+         """
+         if not self.is_pre_tokenized:
+             num_episodes = len(self.episodes)
+             eps = self.episodes if self.is_list else self.episodes.to_list()
+             del self.episodes
+             self.episodes = None
+             for index in range(num_episodes):
+                 self.inputs.append(self.get_tokenized_item(index, episode=eps.pop() if not keep_order else eps[index]))
+                 if verbose and index % log_interval == 0:
+                     print(f'Processed {index + 1}/{num_episodes}')
+             del eps
+             self.is_pre_tokenized = True
+
+     @classmethod
+     def from_hf_hub(
+             cls,
+             dataset_id: str,
+             mrl_subset: str,
+             tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+             split: str = 'train',
+             query_field: str = 'query',
+             answer_field: str = 'answer',
+             interactions_field: str = 'interactions',
+             load_kwargs: dict = None,
+             **kwargs
+     ):
+         """
+         Load dataset from HuggingFace Hub and convert it to RxNN training dataset.
+
+         One of the `tokenizer` or `tokenizer_hub_id` args must be provided. If both are provided, `tokenizer` will be used.
+
+         Args:
+             dataset_id (str): Hub dataset repository name
+             mrl_subset (str): Dataset subset
+             tokenizer (Union[PreTrainedTokenizer, PreTrainedTokenizerFast]): Tokenizer
+             split (str): Dataset split (default: "train")
+             query_field (str): Query field (default: "query")
+             answer_field (str): Answer field (default: "answer")
+             interactions_field (str): Interactions field (default: "interactions")
+             load_kwargs (dict): Additional args for HuggingFace API load_dataset function
+             **kwargs: Additional args for RxNN Dataset class
+         """
+         if load_kwargs is None:
+             load_kwargs = {}
+
+         hf_dataset = load_dataset(dataset_id, mrl_subset, split=split, **load_kwargs)
+
+         return cls(hf_dataset, tokenizer, query_field=query_field, answer_field=answer_field, interactions_field=interactions_field, **kwargs)
+
+     @staticmethod
+     def collate_mrl_batch(batch: list[MrlDataItem]) -> MrlDataItem:
+         """Collate function for MRL curriculum dataset with nested interactions"""
+         def collate_interaction_batch(interaction_batch: Union[list[dict[str, dict[str, torch.Tensor]]], tuple[Any]]) -> dict[str, dict[str, torch.Tensor]]:
+             """Helper to collate a batch of interactions"""
+             return {
+                 'query': {
+                     'input_ids': torch.stack([x['query']['input_ids'] for x in interaction_batch]),
+                     'attention_mask': torch.stack([x['query']['attention_mask'] for x in interaction_batch]),
+                 },
+                 'answer': {
+                     'input_ids': torch.stack([x['answer']['input_ids'] for x in interaction_batch]),
+                     'attention_mask': torch.stack([x['answer']['attention_mask'] for x in interaction_batch]),
+                 }
+             }
+
+         batch_interactions = [x['interactions'] for x in batch]
+         transposed_interactions = list(zip(*batch_interactions))
+
+         return {
+             **collate_interaction_batch(batch),  # Collate initial query and answer
+             'interactions': [
+                 collate_interaction_batch(step_batch) for step_batch in transposed_interactions
+             ]
+         }
+
+ class MrlDatasetItem(TypedDict):
+     steps: int
+     is_long_range: bool
+     dataset: MrlCurriculumDataset
+     eval_dataset: Optional[MrlCurriculumDataset]
+
+ class MrlDatasetLoadItem(TypedDict):
+     subset_name: str
+     steps: int
+     is_long_range: bool
+
+ class MrlDatasets:
+     def __init__(self, datasets: list[MrlDatasetItem]):
+         self.datasets = datasets
+
+     def __iter__(self):
+         return iter(self.datasets)
+
+     def __getitem__(self, idx: int) -> MrlDatasetItem:
+         return self.datasets[idx]
+
+     def __len__(self):
+         return len(self.datasets)
+
+     def __call__(self, steps: int, is_long_range: bool = False):
+         for dataset in self.datasets:
+             if dataset['steps'] == steps and dataset['is_long_range'] == is_long_range:
+                 return dataset
+         return None
+
+     @property
+     def is_pre_tokenized(self) -> bool:
+         train_tokenized = all(item['dataset'].is_pre_tokenized for item in self.datasets)
+         eval_tokenized = all(item['eval_dataset'].is_pre_tokenized for item in self.datasets if item['eval_dataset'] is not None)
+         return train_tokenized and eval_tokenized
+
+     def pre_tokenize(self, verbose: bool = False, log_interval: int = 10_000, keep_order: bool = False):
+         """
+         Pre-tokenizes all the inner datasets
+
+         !! This method has extremely high memory usage, when used with HuggingFace datasets,
+         because of converting it to list. Additionally, for the most optimal performance,
+         pre-tokenized items are in reversed order - it shouldn't matter for training, as
+         items are shuffled then by DataLoader, but you should keep that in mind in case
+         of reproducibility.
+
+         Args:
+             verbose (bool): Should display logs (default: False)
+             log_interval (int): Display logs every log_interval iterations (default: 10_000)
+             keep_order (bool): Keep tokenized items in the same order - by default they are reversed for faster processing (default: False)
+         """
+         if not self.is_pre_tokenized:
+             for item in self.datasets:
+                 item['dataset'].pre_tokenize(verbose, log_interval, keep_order)
+                 if item['eval_dataset'] is not None:
+                     item['eval_dataset'].pre_tokenize(verbose, log_interval, keep_order)
+
+     @classmethod
+     def from_hf_hub(
+             cls,
+             dataset_id: str,
+             tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
+             mrl_curriculum_steps: Union[list[MrlDatasetLoadItem], tuple[MrlDatasetLoadItem]],
+             split: str = 'train',
+             query_field: str = 'query',
+             answer_field: str = 'answer',
+             interactions_field: str = 'interactions',
+             load_kwargs: dict = None,
+             mrl_ds_kwargs: dict = None,
+             eval_split: str = None,
+     ):
+         """
+         Load dataset from HuggingFace Hub and convert it to RxNN training dataset.
+
+         One of the `tokenizer` or `tokenizer_hub_id` args must be provided. If both are provided, `tokenizer` will be used.
+
+         Args:
+             dataset_id (str): Hub dataset repository name
+             tokenizer (Union[PreTrainedTokenizer, PreTrainedTokenizerFast]): Tokenizer
+             mrl_curriculum_steps (list[MrlDatasetLoadItem]): MRL Curriculum steps configuration
+             split (str): Dataset split (default: "train")
+             query_field (str): Query field (default: "query")
+             answer_field (str): Answer field (default: "answer")
+             interactions_field (str): Interactions field (default: "interactions")
+             load_kwargs (dict): Additional args for HuggingFace API load_dataset function
+             mrl_ds_kwargs (dict): Additional args for RxNN MrlCurriculumDataset class
+             eval_split (str): Load also evaluation/validation split (default: None)
+         """
+         if load_kwargs is None:
+             load_kwargs = {}
+         if mrl_ds_kwargs is None:
+             mrl_ds_kwargs = {}
+
+         def load_subset(subset_name: str, load_split: str):
+             return MrlCurriculumDataset.from_hf_hub(
+                 dataset_id,
+                 subset_name,
+                 tokenizer=tokenizer,
+                 query_field=query_field,
+                 answer_field=answer_field,
+                 interactions_field=interactions_field,
+                 split=load_split,
+                 load_kwargs=load_kwargs,
+                 **mrl_ds_kwargs,
+             )
+
+         def dataset_item(item: MrlDatasetLoadItem) -> MrlDatasetItem:
+             return {
+                 'steps': item['steps'],
+                 'is_long_range': item['is_long_range'],
+                 'dataset': load_subset(item['subset_name'], split),
+                 'eval_dataset': load_subset(item['subset_name'], eval_split) if eval_split is not None else None,
+             }
+
+         mrl_datasets = [dataset_item(item) for item in mrl_curriculum_steps]
+
+         return cls(mrl_datasets)
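The new `MrlCurriculumDataset` tokenizes each episode into an initial query/answer pair (built manually as `[BOS][Q]query` and `[A]answer[EOS]`) plus a list of follow-up interactions, and its `collate_mrl_batch` static method transposes the per-item interaction lists so that each curriculum step is batched separately. A hedged usage sketch, assuming a Hub dataset whose records expose `query`, `answer` and `interactions` fields and whose episodes all contain the same number of follow-ups; the dataset and tokenizer ids below are placeholders, not taken from the diff:

    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer

    from rxnn.training.dataset import MrlCurriculumDataset

    tokenizer = AutoTokenizer.from_pretrained("org/tokenizer")  # placeholder id

    dataset = MrlCurriculumDataset.from_hf_hub(
        "org/mrl-curriculum",   # placeholder dataset_id
        "steps-4",              # placeholder mrl_subset name
        tokenizer,
        split="train",
        max_seq_len=1024,       # forwarded to MrlCurriculumDataset.__init__ via **kwargs
    )

    loader = DataLoader(
        dataset,
        batch_size=16,
        shuffle=True,
        collate_fn=MrlCurriculumDataset.collate_mrl_batch,
    )

    batch = next(iter(loader))
    # batch['query']['input_ids'] has shape [batch, seq_len]; batch['interactions'] is a
    # list with one collated {'query': ..., 'answer': ...} dict per follow-up step.

`MrlDatasets.from_hf_hub` wraps the same loader for a whole curriculum: it takes a list of `MrlDatasetLoadItem` entries and builds one `MrlCurriculumDataset` (plus an optional eval split) per curriculum stage.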
@@ -0,0 +1,142 @@
+ import torch
+ import torch.nn as nn
+ from enum import Enum
+ from huggingface_hub import PyTorchModelHubMixin
+ from ..transformers.models import ReactiveTransformerEncoder, ReactiveTransformerDecoder
+
+ class MLMHead(nn.Module, PyTorchModelHubMixin, license="apache-2.0"):
+     def __init__(self, embed_dim: int, vocab_size: int, *args, **kwargs):
+         super(MLMHead, self).__init__(*args, **kwargs)
+         self.dense = nn.Linear(embed_dim, embed_dim)
+         self.act = nn.GELU()
+         self.layer_norm = nn.LayerNorm(embed_dim)
+         self.decoder = nn.Linear(embed_dim, vocab_size)
+
+     def forward(self, hidden_states):
+         x = self.dense(hidden_states)
+         x = self.act(x)
+         x = self.layer_norm(x)
+         return self.decoder(x)
+
+
+ class MLMTrainingModel(nn.Module):
+     def __init__(
+             self,
+             encoder: ReactiveTransformerEncoder,
+             mlm_head: MLMHead,
+             *args,
+             **kwargs
+     ):
+         super(MLMTrainingModel, self).__init__(*args, **kwargs)
+         self.encoder = encoder
+         self.mlm_head = mlm_head
+
+     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
+         h, _ = self.encoder(x, attention_mask=attention_mask)
+         y = self.mlm_head(h)
+         return y
+
+ class JointTrainingModel(nn.Module):
+     def __init__(
+             self,
+             encoder: ReactiveTransformerEncoder,
+             decoder: ReactiveTransformerDecoder,
+             mlm_head: MLMHead,
+             *args,
+             **kwargs
+     ):
+         super(JointTrainingModel, self).__init__(*args, **kwargs)
+         self.encoder = encoder
+         self.mlm_head = mlm_head
+         self.decoder = decoder
+
+     def forward(self, x_e: torch.Tensor, x_d: torch.Tensor, attention_mask: torch.Tensor = None) -> tuple[
+         torch.Tensor, torch.Tensor]:
+         encoder_result, _ = self.encoder(x_e, attention_mask=attention_mask)
+         y_e = self.mlm_head(encoder_result)
+         y_d = self.decoder(x_d, attention_mask=attention_mask)
+         return y_e, y_d
+
+ class MrlActorAction(Enum):
+     DECODE = 1
+     UPDATE = 2
+
+ class MrlActorModel(nn.Module):
+     def __init__(
+             self,
+             encoder: nn.Module,
+             decoder: nn.Module,
+             memory_attention: nn.Module,
+             **kwargs
+     ):
+         super(MrlActorModel, self).__init__(**kwargs)
+         self.encoder = encoder
+         self.decoder = decoder
+         self.memory_attention = memory_attention
+
+     def freeze_components(self):
+         """Freeze encoder/decoder except memory-related layers."""
+         if self.encoder.freeze_without_memory is not None:
+             self.encoder.freeze_without_memory()
+         else:
+             for param in self.encoder.parameters():
+                 param.requires_grad = False
+             self.encoder.model.trainable_cross_attention_(True)
+         if self.decoder.freeze_without_memory is not None:
+             self.decoder.freeze_without_memory()
+         else:
+             for param in self.decoder.parameters():
+                 param.requires_grad = False
+             self.decoder.model.trainable_cross_attention_(True)
+         # Unfreeze memory attention
+         for param in self.memory_attention.parameters():
+             param.requires_grad = True
+
+     def unfreeze_components(self):
+         """Unfreeze all components after initial training."""
+         if self.encoder.unfreeze_all is not None:
+             self.encoder.unfreeze_all()
+         else:
+             for param in self.encoder.parameters():
+                 param.requires_grad = True
+         if self.decoder.unfreeze_all is not None:
+             self.decoder.unfreeze_all()
+         else:
+             for param in self.decoder.parameters():
+                 param.requires_grad = True
+         for param in self.memory_attention.parameters():
+             param.requires_grad = True
+
+     def reset_memory(self):
+         self.memory_attention.reset_memory()
+
+     def unique_parameters(self):
+         return list(set(
+             list(self.encoder.parameters()) +
+             list(self.decoder.parameters()) +
+             list(self.memory_attention.parameters())
+         ))
+
+     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None, action: MrlActorAction = MrlActorAction.DECODE) -> torch.Tensor:
+         if action == MrlActorAction.DECODE:
+             return self.decoder(x, attention_mask=attention_mask)
+         else:
+             _, ed = self.encoder(x, attention_mask=attention_mask)
+             return self.memory_attention(ed, attention_mask=attention_mask)
+
+ class MrlCriticModel(nn.Module):
+     def __init__(self, encoder: nn.Module, embed_dim: int, **kwargs):
+         super(MrlCriticModel, self).__init__(**kwargs)
+         self.encoder = encoder
+         self.value_head = nn.Linear(embed_dim, 1)
+
+     def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None) -> torch.Tensor:
+         x, _ = self.encoder(x, attention_mask=attention_mask)
+
+         if attention_mask is not None:
+             x = x * attention_mask.unsqueeze(-1)
+             x = x.sum(dim=1) / attention_mask.sum(dim=1, keepdim=True)
+         else:
+             x = x.mean(dim=1)
+
+         return self.value_head(x)
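The hunk above adds the training-side model wrappers for MRL: `MrlActorModel` routes a forward pass either through the decoder (`DECODE`) or through the encoder and then memory attention on the encoder's second output (`UPDATE`), while `MrlCriticModel` mean-pools the encoder output under the attention mask and maps it to a scalar value. A hedged sketch of the intended dispatch, assuming already-constructed RxNN `encoder`, `decoder` and `memory_attention` modules and assuming the new file is importable as `rxnn.training.models` (its file path is not shown in this diff):

    import torch
    import torch.nn as nn

    from rxnn.training.models import MrlActorAction, MrlActorModel  # assumed module path

    def mrl_actor_demo(encoder: nn.Module, decoder: nn.Module, memory_attention: nn.Module) -> None:
        # `encoder`, `decoder` and `memory_attention` are assumed compatible RxNN components.
        actor = MrlActorModel(encoder, decoder, memory_attention)
        actor.freeze_components()  # initial stage: only memory-related layers stay trainable

        input_ids = torch.randint(0, 32_000, (2, 256))
        attention_mask = torch.ones_like(input_ids)

        # DECODE: produce decoder logits for the current sequence.
        logits = actor(input_ids, attention_mask=attention_mask, action=MrlActorAction.DECODE)

        # UPDATE: encode the interaction and feed the encoder's second output to memory attention.
        memory_out = actor(input_ids, attention_mask=attention_mask, action=MrlActorAction.UPDATE)

        actor.unfreeze_components()  # later stages: all components trainable again
        actor.reset_memory()         # clear short-term memory between episodes

`unique_parameters()` deduplicates parameters shared between the encoder, decoder and memory attention, which is the list an optimizer would typically be given.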