wolof-translate 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wolof_translate/__init__.py +73 -0
- wolof_translate/data/__init__.py +0 -0
- wolof_translate/data/dataset_v1.py +151 -0
- wolof_translate/data/dataset_v2.py +187 -0
- wolof_translate/data/dataset_v3.py +187 -0
- wolof_translate/data/dataset_v3_2.py +187 -0
- wolof_translate/data/dataset_v4.py +202 -0
- wolof_translate/data/dataset_v5.py +65 -0
- wolof_translate/models/__init__.py +0 -0
- wolof_translate/models/transformers/__init__.py +0 -0
- wolof_translate/models/transformers/main.py +865 -0
- wolof_translate/models/transformers/main_2.py +362 -0
- wolof_translate/models/transformers/optimization.py +41 -0
- wolof_translate/models/transformers/position.py +46 -0
- wolof_translate/models/transformers/size.py +44 -0
- wolof_translate/pipe/__init__.py +1 -0
- wolof_translate/pipe/nlp_pipeline.py +512 -0
- wolof_translate/tokenizers/__init__.py +0 -0
- wolof_translate/trainers/__init__.py +0 -0
- wolof_translate/trainers/transformer_trainer.py +760 -0
- wolof_translate/trainers/transformer_trainer_custom.py +882 -0
- wolof_translate/trainers/transformer_trainer_ml.py +925 -0
- wolof_translate/trainers/transformer_trainer_ml_.py +1042 -0
- wolof_translate/utils/__init__.py +1 -0
- wolof_translate/utils/bucket_iterator.py +143 -0
- wolof_translate/utils/database_manager.py +116 -0
- wolof_translate/utils/display_predictions.py +162 -0
- wolof_translate/utils/download_model.py +40 -0
- wolof_translate/utils/evaluate_custom.py +147 -0
- wolof_translate/utils/evaluation.py +74 -0
- wolof_translate/utils/extract_new_sentences.py +810 -0
- wolof_translate/utils/extract_poems.py +60 -0
- wolof_translate/utils/extract_sentences.py +562 -0
- wolof_translate/utils/improvements/__init__.py +0 -0
- wolof_translate/utils/improvements/end_marks.py +45 -0
- wolof_translate/utils/recuperate_datasets.py +94 -0
- wolof_translate/utils/recuperate_datasets_trunc.py +85 -0
- wolof_translate/utils/send_model.py +26 -0
- wolof_translate/utils/sent_corrections.py +169 -0
- wolof_translate/utils/sent_transformers.py +27 -0
- wolof_translate/utils/sent_unification.py +97 -0
- wolof_translate/utils/split_with_valid.py +72 -0
- wolof_translate/utils/tokenize_text.py +46 -0
- wolof_translate/utils/training.py +213 -0
- wolof_translate/utils/trunc_hg_training.py +196 -0
- wolof_translate-0.0.1.dist-info/METADATA +31 -0
- wolof_translate-0.0.1.dist-info/RECORD +49 -0
- wolof_translate-0.0.1.dist-info/WHEEL +5 -0
- wolof_translate-0.0.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,187 @@
|
|
|
1
|
+
from wolof_translate.utils.sent_transformers import TransformerSequences
|
|
2
|
+
from transformers import PreTrainedTokenizerFast
|
|
3
|
+
from torch.utils.data import Dataset
|
|
4
|
+
from typing import *
|
|
5
|
+
import pandas as pd
|
|
6
|
+
import torch
|
|
7
|
+
import re
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class T5SentenceDataset(Dataset):
    """Sentence-pair dataset encoded for a T5-style encoder-decoder.

    The corpus is read with ``pandas.read_csv``. Each item holds two parts:
    the encoded source sentence (ids and attention mask) and the encoded target
    sentence, obtained through the tokenizer's ``text_target`` argument.
    """

    def __init__(
        self,
        data_path: str,
        tokenizer: PreTrainedTokenizerFast,
        corpus_1: str = "french",
        corpus_2: str = "wolof",
        max_len: int = 38,
        truncation: bool = False,
        file_sep: str = ",",
        cp1_transformer: Union[TransformerSequences, None] = None,
        cp2_transformer: Union[TransformerSequences, None] = None,
        **kwargs
    ):
        """Load the corpus and store the encoding settings.

        Args:
            data_path: Location of the corpus file, forwarded to ``pandas.read_csv``.
            tokenizer: Tokenizer used to encode items and decode predictions.
            corpus_1: Column holding the sentences to translate.
            corpus_2: Column holding the reference translations.
            max_len: Base maximum length. The stored value is
                ``max_len + max_len // 5``.
            truncation: Forwarded as the tokenizer's ``truncation`` flag.
            file_sep: Column separator of the corpus file.
            cp1_transformer: Optional callable applied to each source sentence.
                Only the first element of its result is kept.
            cp2_transformer: Same as ``cp1_transformer``, for target sentences.
            **kwargs: Extra keyword arguments forwarded to ``pandas.read_csv``.
        """
        frame = pd.read_csv(data_path, sep=file_sep, **kwargs)
        self.__sentences = frame

        self.tokenizer = tokenizer

        # Materialise both columns as plain Python lists.
        self.sentences_1 = frame[corpus_1].to_list()
        self.sentences_2 = frame[corpus_2].to_list()
        self.length = len(self.sentences_1)

        # The effective length grows by an integer fifth of the requested one.
        self.max_len = max_len + max_len // 5
        self.truncation = truncation

        self.cp1_transformer = cp1_transformer
        self.cp2_transformer = cp2_transformer

    def __getitem__(self, index):
        """Encode the sentence pair stored at ``index``.

        Args:
            index (int): Position of the pair in the corpus.

        Returns:
            tuple: ``(input_ids, attention_mask, labels)``. Each element is a
            tensor with its leading batch dimension squeezed away.
        """
        source = self.sentences_1[index]
        target = self.sentences_2[index]

        if self.cp1_transformer is not None:
            source = self.cp1_transformer(source)[0]
        if self.cp2_transformer is not None:
            target = self.cp2_transformer(target)[0]

        # Explicitly terminate both sides with the tokenizer's EOS text.
        eos = self.tokenizer.eos_token
        source = source + eos
        target = target + eos

        encoded = self.tokenizer(
            source,
            truncation=self.truncation,
            max_length=self.max_len,
            padding="max_length",
            return_tensors="pt",
            text_target=target,
        )

        return tuple(
            tensor.squeeze(0)
            for tensor in (encoded.input_ids, encoded.attention_mask, encoded.labels)
        )

    def __len__(self):
        """Number of sentence pairs in the corpus."""
        return self.length

    def decode(self, labels: torch.Tensor):
        """Convert token ids back to text, dropping special tokens.

        Args:
            labels: A single sequence of ids (1-D) or a batch of sequences.

        Returns:
            list[str]: One decoded string per sequence.
        """
        batch = labels if labels.ndim >= 2 else labels.unsqueeze(0)
        return self.tokenizer.batch_decode(batch, skip_special_tokens=True)
|
|
108
|
+
|
|
109
|
+
class SentenceDataset(T5SentenceDataset):
    """Sentence-pair dataset that encodes source and target independently.

    Unlike the parent class, the target is tokenized in a separate call. Its
    attention mask is therefore returned alongside its ids.
    """

    def __init__(
        self,
        data_path: str,
        tokenizer: PreTrainedTokenizerFast,
        corpus_1: str = "french",
        corpus_2: str = "wolof",
        max_len: int = 38,
        truncation: bool = False,
        file_sep: str = ",",
        cp1_transformer: Union[TransformerSequences, None] = None,
        cp2_transformer: Union[TransformerSequences, None] = None,
        **kwargs
    ):
        """Forward every argument unchanged to ``T5SentenceDataset.__init__``."""
        super().__init__(
            data_path=data_path,
            tokenizer=tokenizer,
            corpus_1=corpus_1,
            corpus_2=corpus_2,
            max_len=max_len,
            truncation=truncation,
            file_sep=file_sep,
            cp1_transformer=cp1_transformer,
            cp2_transformer=cp2_transformer,
            **kwargs
        )

    def __getitem__(self, index):
        """Encode the sentence pair stored at ``index``.

        Args:
            index (int): Position of the pair in the corpus.

        Returns:
            tuple: ``(input_ids, attention_mask, label_ids, label_attention_mask)``.
            Each element is a tensor with its leading batch dimension squeezed away.
        """
        source, target = self.sentences_1[index], self.sentences_2[index]

        if self.cp1_transformer is not None:
            source = self.cp1_transformer(source)[0]
        if self.cp2_transformer is not None:
            target = self.cp2_transformer(target)[0]

        eos = self.tokenizer.eos_token
        source = source + eos
        target = target + eos

        # Both sides share the same length/padding settings.
        options = dict(
            truncation=self.truncation,
            max_length=self.max_len,
            padding="max_length",
            return_tensors="pt",
        )
        source_enc = self.tokenizer(source, **options)
        target_enc = self.tokenizer(target, **options)

        return (
            source_enc.input_ids.squeeze(0),
            source_enc.attention_mask.squeeze(0),
            target_enc.input_ids.squeeze(0),
            target_enc.attention_mask.squeeze(0),
        )
|
|
@@ -0,0 +1,202 @@
|
|
|
1
|
+
from wolof_translate.utils.sent_transformers import TransformerSequences
|
|
2
|
+
from transformers import PreTrainedTokenizerFast
|
|
3
|
+
from torch.utils.data import Dataset
|
|
4
|
+
from typing import *
|
|
5
|
+
import pandas as pd
|
|
6
|
+
import torch
|
|
7
|
+
import re
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class T5SentenceDataset(Dataset):
    """Sentence-pair dataset encoded for a T5-style encoder-decoder.

    The corpus is read with ``pandas.read_csv``. Each item holds two parts:
    the encoded source sentence (ids and attention mask) and the encoded target
    sentence, obtained through the tokenizer's ``text_target`` argument.
    """

    def __init__(
        self,
        data_path: str,
        tokenizer: PreTrainedTokenizerFast,
        corpus_1: str = "french",
        corpus_2: str = "wolof",
        max_len: Union[int, None] = None,
        truncation: bool = False,
        file_sep: str = ",",
        cp1_transformer: Union[TransformerSequences, None] = None,
        cp2_transformer: Union[TransformerSequences, None] = None,
        add_bos_token: bool = False,
        **kwargs
    ):
        """Load the corpus and store the encoding settings.

        Args:
            data_path: Location of the corpus file, forwarded to ``pandas.read_csv``.
            tokenizer: Tokenizer used to encode items and decode predictions.
            corpus_1: Column holding the sentences to translate.
            corpus_2: Column holding the reference translations.
            max_len: Base maximum length, or None. When given, the stored value
                is ``max_len + max_len // 5``.
            truncation: Forwarded as the tokenizer's ``truncation`` flag.
            file_sep: Column separator of the corpus file.
            cp1_transformer: Optional callable applied to each source sentence.
                Only the first element of its result is kept.
            cp2_transformer: Same as ``cp1_transformer``, for target sentences.
            add_bos_token: Stored as ``self.add_bos``. This class's ``__getitem__``
                does not read it; subclasses may.
            **kwargs: Extra keyword arguments forwarded to ``pandas.read_csv``.
        """
        # let us recuperate the data frame
        self.__sentences = pd.read_csv(data_path, sep=file_sep, **kwargs)

        # let us recuperate the tokenizer
        self.tokenizer = tokenizer

        # recuperate both corpora's sentences
        self.sentences_1 = self.__sentences[corpus_1].to_list()
        self.sentences_2 = self.__sentences[corpus_2].to_list()

        # recuperate the length
        self.length = len(self.sentences_1)

        # let us recuperate the max len (enlarged by an integer fifth when given)
        self.max_len = None if max_len is None else max_len + max_len // 5

        # let us recuperate the truncation argument
        self.truncation = truncation

        # let us initialize the transformers
        self.cp1_transformer = cp1_transformer
        self.cp2_transformer = cp2_transformer

        # see if we add a beginning of the sentence
        self.add_bos = add_bos_token

        # let us recuperate the special tokens (used by `decode` for cleanup)
        self.special_tokens = tokenizer.convert_ids_to_tokens(tokenizer.all_special_ids)

    def __getitem__(self, index):
        """Encode the sentence pair stored at ``index``.

        Args:
            index (int): Position of the pair in the corpus.

        Returns:
            tuple: ``(input_ids, attention_mask, labels)``. Each element is a
            tensor with its leading batch dimension squeezed away.
        """
        sentence_1 = self.sentences_1[index]
        sentence_2 = self.sentences_2[index]

        # apply transformers if necessary
        if self.cp1_transformer is not None:
            sentence_1 = self.cp1_transformer(sentence_1)[0]

        if self.cp2_transformer is not None:
            sentence_2 = self.cp2_transformer(sentence_2)[0]

        sentence_1 = sentence_1 + self.tokenizer.eos_token
        sentence_2 = sentence_2 + self.tokenizer.eos_token

        # let us encode the sentences (we provide the second sentence as labels to the tokenizer)
        # NOTE(review): with max_len=None, padding="max_length" defers to the
        # tokenizer's model_max_length (unlike SentenceDataset, which disables
        # padding in that case) — confirm this is intended.
        data = self.tokenizer(
            sentence_1,
            truncation=self.truncation,
            max_length=self.max_len,
            padding="max_length",
            return_tensors="pt",
            text_target=sentence_2,
        )

        return (
            data.input_ids.squeeze(0),
            data.attention_mask.squeeze(0),
            data.labels.squeeze(0),
        )

    def __len__(self):
        """Number of sentence pairs in the corpus."""
        return self.length

    def decode(self, labels: torch.Tensor):
        """Convert token ids back to text and strip any leftover special tokens.

        Args:
            labels: A single sequence of ids (1-D) or a batch of sequences.

        Returns:
            list[str]: One decoded string per sequence.
        """
        if labels.ndim < 2:
            labels = labels.unsqueeze(0)

        sentences = self.tokenizer.batch_decode(labels, skip_special_tokens=True)

        # Special tokens are matched literally. Without escaping, "[CLS]" would
        # act as a character class deleting every 'C', 'L' and 'S', and
        # "<|endoftext|>" would be split on '|'. Empty/None tokens are skipped, and
        # longer tokens come first so one token's prefix never pre-empts it.
        escaped = [
            re.escape(token)
            for token in sorted(
                {tok for tok in self.special_tokens if tok}, key=len, reverse=True
            )
        ]

        if not escaped:
            return sentences

        pattern = re.compile("|".join(escaped))

        return [pattern.sub("", sentence) for sentence in sentences]
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
class SentenceDataset(T5SentenceDataset):
    """Sentence-pair dataset that encodes source and target independently.

    Padding to ``self.max_len`` is applied only when a maximum length was
    given. Otherwise sequences keep their natural length.
    """

    def __init__(
        self,
        data_path: str,
        tokenizer: PreTrainedTokenizerFast,
        corpus_1: str = "french",
        corpus_2: str = "wolof",
        max_len: Union[int, None] = None,
        truncation: bool = False,
        file_sep: str = ",",
        cp1_transformer: Union[TransformerSequences, None] = None,
        cp2_transformer: Union[TransformerSequences, None] = None,
        add_bos_token: bool = False,
        **kwargs
    ):
        """Forward every argument unchanged to ``T5SentenceDataset.__init__``.

        ``add_bos_token=True`` makes ``__getitem__`` prepend the tokenizer's BOS
        token to both sentences, provided the tokenizer defines one.
        """
        super().__init__(
            data_path,
            tokenizer,
            corpus_1,
            corpus_2,
            max_len,
            truncation,
            file_sep,
            cp1_transformer,
            cp2_transformer,
            add_bos_token,
            **kwargs
        )

    def __getitem__(self, index):
        """Encode the sentence pair stored at ``index``.

        Args:
            index (int): Position of the pair in the corpus.

        Returns:
            tuple: ``(input_ids, attention_mask, label_ids, label_attention_mask)``.
            Each element is a tensor with its leading batch dimension squeezed away.
        """
        sentence_1 = self.sentences_1[index]
        sentence_2 = self.sentences_2[index]

        # apply transformers if necessary
        if self.cp1_transformer is not None:
            sentence_1 = self.cp1_transformer(sentence_1)[0]

        if self.cp2_transformer is not None:
            sentence_2 = self.cp2_transformer(sentence_2)[0]

        # Prepend the BOS token when requested. Previously it was computed and
        # then discarded, so `add_bos_token` had no effect. Tokenizers without
        # a BOS token (bos_token is None) are left untouched.
        # NOTE(review): applied to both source and target — confirm intended.
        bos_token = self.tokenizer.bos_token if self.add_bos else None
        if bos_token:
            sentence_1 = bos_token + sentence_1
            sentence_2 = bos_token + sentence_2

        # only pad when a maximum length is known
        padding = "max_length" if self.max_len is not None else False

        # let us encode the source sentences
        data = self.tokenizer(
            sentence_1,
            truncation=self.truncation,
            max_length=self.max_len,
            padding=padding,
            return_tensors="pt",
        )

        # let us encode the target sentences (used as labels)
        labels = self.tokenizer(
            sentence_2,
            truncation=self.truncation,
            max_length=self.max_len,
            padding=padding,
            return_tensors="pt",
        )

        return (
            data.input_ids.squeeze(0),
            data.attention_mask.squeeze(0),
            labels.input_ids.squeeze(0),
            labels.attention_mask.squeeze(0),
        )
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
from wolof_translate.utils.sent_transformers import TransformerSequences
|
|
2
|
+
from wolof_translate.data.dataset_v4 import T5SentenceDataset
|
|
3
|
+
from transformers import PreTrainedTokenizerFast
|
|
4
|
+
from torch.utils.data import Dataset
|
|
5
|
+
from typing import *
|
|
6
|
+
import pandas as pd
|
|
7
|
+
import torch
|
|
8
|
+
import re
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class SentenceDataset(T5SentenceDataset):
    """Sentence-pair dataset returning unpadded token ids.

    No padding or truncation is applied, so items have variable lengths.
    Batching them presumably relies on a padding collate function elsewhere in
    the project — confirm.
    """

    def __init__(
        self,
        data_path: str,
        tokenizer: PreTrainedTokenizerFast,
        corpus_1: str = "french",
        corpus_2: str = "wolof",
        file_sep: str = ",",
        cp1_transformer: Union[TransformerSequences, None] = None,
        cp2_transformer: Union[TransformerSequences, None] = None,
        **kwargs
    ):
        """Load the corpus through ``T5SentenceDataset.__init__``.

        ``max_len`` is fixed to 0 and ``truncation`` to False, because this
        class's ``__getitem__`` never pads or truncates. Remaining keyword
        arguments are forwarded to the parent, and from there to
        ``pandas.read_csv``.
        """
        # `cp2_transformer` and `**kwargs` must be separate arguments: the
        # previous `cp2_transformer**kwargs` was a power operation that raised
        # TypeError on every construction.
        super().__init__(
            data_path,
            tokenizer,
            corpus_1,
            corpus_2,
            0,
            False,
            file_sep,
            cp1_transformer,
            cp2_transformer,
            **kwargs
        )

    def __getitem__(self, index):
        """Encode the sentence pair stored at ``index`` without padding.

        Args:
            index (int): Position of the pair in the corpus.

        Returns:
            tuple: ``(input_ids, label_ids)``, each a 1-D tensor of token ids.
        """
        sentence_1 = self.sentences_1[index]
        sentence_2 = self.sentences_2[index]

        # apply transformers if necessary
        if self.cp1_transformer is not None:
            sentence_1 = self.cp1_transformer(sentence_1)[0]

        if self.cp2_transformer is not None:
            sentence_2 = self.cp2_transformer(sentence_2)[0]

        # Request tensors explicitly. Without `return_tensors`, `input_ids` is
        # a plain list and `.squeeze(0)` raises AttributeError.
        data = self.tokenizer(sentence_1, return_tensors="pt")
        labels = self.tokenizer(sentence_2, return_tensors="pt")

        return (data.input_ids.squeeze(0), labels.input_ids.squeeze(0))
|
|
File without changes
|
|
File without changes
|