wolof_translate-0.0.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wolof_translate/__init__.py +73 -0
- wolof_translate/data/__init__.py +0 -0
- wolof_translate/data/dataset_v1.py +151 -0
- wolof_translate/data/dataset_v2.py +187 -0
- wolof_translate/data/dataset_v3.py +187 -0
- wolof_translate/data/dataset_v3_2.py +187 -0
- wolof_translate/data/dataset_v4.py +202 -0
- wolof_translate/data/dataset_v5.py +65 -0
- wolof_translate/models/__init__.py +0 -0
- wolof_translate/models/transformers/__init__.py +0 -0
- wolof_translate/models/transformers/main.py +865 -0
- wolof_translate/models/transformers/main_2.py +362 -0
- wolof_translate/models/transformers/optimization.py +41 -0
- wolof_translate/models/transformers/position.py +46 -0
- wolof_translate/models/transformers/size.py +44 -0
- wolof_translate/pipe/__init__.py +1 -0
- wolof_translate/pipe/nlp_pipeline.py +512 -0
- wolof_translate/tokenizers/__init__.py +0 -0
- wolof_translate/trainers/__init__.py +0 -0
- wolof_translate/trainers/transformer_trainer.py +760 -0
- wolof_translate/trainers/transformer_trainer_custom.py +882 -0
- wolof_translate/trainers/transformer_trainer_ml.py +925 -0
- wolof_translate/trainers/transformer_trainer_ml_.py +1042 -0
- wolof_translate/utils/__init__.py +1 -0
- wolof_translate/utils/bucket_iterator.py +143 -0
- wolof_translate/utils/database_manager.py +116 -0
- wolof_translate/utils/display_predictions.py +162 -0
- wolof_translate/utils/download_model.py +40 -0
- wolof_translate/utils/evaluate_custom.py +147 -0
- wolof_translate/utils/evaluation.py +74 -0
- wolof_translate/utils/extract_new_sentences.py +810 -0
- wolof_translate/utils/extract_poems.py +60 -0
- wolof_translate/utils/extract_sentences.py +562 -0
- wolof_translate/utils/improvements/__init__.py +0 -0
- wolof_translate/utils/improvements/end_marks.py +45 -0
- wolof_translate/utils/recuperate_datasets.py +94 -0
- wolof_translate/utils/recuperate_datasets_trunc.py +85 -0
- wolof_translate/utils/send_model.py +26 -0
- wolof_translate/utils/sent_corrections.py +169 -0
- wolof_translate/utils/sent_transformers.py +27 -0
- wolof_translate/utils/sent_unification.py +97 -0
- wolof_translate/utils/split_with_valid.py +72 -0
- wolof_translate/utils/tokenize_text.py +46 -0
- wolof_translate/utils/training.py +213 -0
- wolof_translate/utils/trunc_hg_training.py +196 -0
- wolof_translate-0.0.1.dist-info/METADATA +31 -0
- wolof_translate-0.0.1.dist-info/RECORD +49 -0
- wolof_translate-0.0.1.dist-info/WHEEL +5 -0
- wolof_translate-0.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1 @@

# from wolof_translate.utils.tokenize_text import tokenization
wolof_translate/utils/bucket_iterator.py
@@ -0,0 +1,143 @@

import torch
import numpy as np
from typing import *
from torch.utils.data import Sampler
from torch.nn.utils.rnn import pad_sequence
from math import ceil


class SequenceLengthBatchSampler(Sampler):
    def __init__(self, dataset, boundaries, batch_sizes, input_key=None, label_key=None, drop_unique=True):
        self.dataset = dataset
        self.boundaries = boundaries
        self.batch_sizes = batch_sizes
        self.data_info = {}
        self.drop_unique = drop_unique

        # Initialize dictionary with indices and element lengths
        for i in range(len(dataset)):
            data = dataset[i]
            length = (
                max(len(data[0]), len(data[2]))
                if (input_key is None and label_key is None)
                else max(len(data[input_key]), len(data[label_key]))
            )
            self.data_info[i] = {"index": i, "length": length}

        self.calculate_length()

    def calculate_length(self):
        self.batches = []

        # Sort indices based on element length
        sorted_indices = sorted(self.data_info.keys(), key=lambda i: self.data_info[i]["length"])

        # Group indices into buckets of sequences whose length falls under each boundary
        for boundary in self.boundaries:
            batch = [i for i in sorted_indices if self.data_info[i]["length"] <= boundary]  # Filter indices based on length boundary
            self.batches.append(batch)
            sorted_indices = [i for i in sorted_indices if i not in batch]  # Remove processed indices

        # Add remaining indices to the last bucket
        self.batches.append(sorted_indices)

        # Calculate the total number of batches, kept consistent with __iter__:
        # the trailing partial batch of a bucket is dropped only when it is a
        # singleton and drop_unique is set
        self.length = sum(
            len(batch) // batch_size
            + (1 if len(batch) % batch_size > 0
               and (len(batch) % batch_size != 1 or not self.drop_unique)
               else 0)
            for batch, batch_size in zip(self.batches, self.batch_sizes)
        )

    def __iter__(self):
        # indices = list(self.data_info.keys())  # Get indices from the data_info dictionary
        # np.random.shuffle(indices)  # Shuffle the indices

        # Yield batches with the corresponding batch sizes
        for batch_indices, batch_size in zip(self.batches, self.batch_sizes):
            num_batches = len(batch_indices) // batch_size

            for i in range(num_batches):
                # Recuperate the current bucket
                current_bucket = batch_indices[i * batch_size: (i + 1) * batch_size]

                # Shuffle the current bucket
                np.random.shuffle(current_bucket)

                # Yield the current bucket
                yield [self.data_info[j]["index"] for j in current_bucket]

            remaining_indices = len(batch_indices) % batch_size

            # Yield the trailing partial batch unless it is a singleton and
            # drop_unique is set (the parentheses matter: the original
            # `a and b or c` parsed as `(a and b) or c`)
            if remaining_indices > 0 and (remaining_indices != 1 or not self.drop_unique):
                # Recuperate the current bucket
                current_bucket = batch_indices[-remaining_indices:]

                # Shuffle the current bucket
                np.random.shuffle(current_bucket)

                # Yield the current (shuffled) bucket
                yield [self.data_info[j]["index"] for j in current_bucket]

    def __len__(self):
        return self.length


class BucketSampler(Sampler):
    def __init__(self, dataset, batch_size, sort_key=lambda x, index_1, index_2: max(len(x[index_1]), len(x[index_2])), input_key: Union[str, int] = 0, label_key: Union[str, int] = 1):
        self.dataset = dataset
        self.batch_size = batch_size
        self.sort_key = sort_key
        self.index_1 = input_key
        self.index_2 = label_key

        # Sort the dataset indices by sequence length and cut them into batches
        indices = np.argsort([self.sort_key(self.dataset[i], self.index_1, self.index_2) for i in range(len(self.dataset))])
        self.batches = [indices[i:i + self.batch_size] for i in range(0, len(indices), self.batch_size)]

    def __iter__(self):
        if self.batch_size > 1:
            np.random.shuffle(self.batches)
        for batch in self.batches:
            yield batch.tolist()

    def __len__(self):
        return ceil(len(self.dataset) / self.batch_size)


def collate_fn(batch):
    # Separate the input sequences, input masks, target sequences, and target masks
    input_seqs, input_masks, target_seqs, target_masks = zip(*batch)

    # Pad each group of sequences to the length of its longest member
    padded_input_seqs = pad_sequence(input_seqs, batch_first=True)
    padded_target_seqs = pad_sequence(target_seqs, batch_first=True)
    padded_input_masks = pad_sequence(input_masks, batch_first=True)
    padded_target_masks = pad_sequence(target_masks, batch_first=True)

    return padded_input_seqs, padded_input_masks, padded_target_seqs, padded_target_masks


def collate_fn_trunc(batch, max_len, eos_token_id, pad_token_id):
    # Separate the input sequences, input masks, target sequences, and target masks
    input_seqs, input_masks, target_seqs, target_masks = zip(*batch)

    # Pad, then truncate every group of sequences to max_len
    padded_input_seqs = pad_sequence(input_seqs, batch_first=True)[:, :max_len]
    padded_target_seqs = pad_sequence(target_seqs, batch_first=True)[:, :max_len]

    # If truncation cut a sequence short (its last token is neither eos nor pad),
    # force the last token to be eos
    padded_input_seqs[:, -1:][(padded_input_seqs[:, -1:] != eos_token_id) & (padded_input_seqs[:, -1:] != pad_token_id)] = eos_token_id
    padded_target_seqs[:, -1:][(padded_target_seqs[:, -1:] != eos_token_id) & (padded_target_seqs[:, -1:] != pad_token_id)] = eos_token_id

    padded_input_masks = pad_sequence(input_masks, batch_first=True)[:, :max_len]
    padded_target_masks = pad_sequence(target_masks, batch_first=True)[:, :max_len]

    return padded_input_seqs, padded_input_masks, padded_target_seqs, padded_target_masks
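
A minimal usage sketch of the bucketing sampler above (not part of the packaged file), assuming a dataset whose items are (input_ids, input_mask, label_ids, label_mask) tensors; the boundary and batch-size values are illustrative:

import torch
from torch.utils.data import DataLoader

# toy dataset: variable-length source/target pairs with all-ones masks
data = [
    (torch.ones(n, dtype=torch.long), torch.ones(n, dtype=torch.long),
     torch.ones(n + 2, dtype=torch.long), torch.ones(n + 2, dtype=torch.long))
    for n in (5, 8, 12, 20, 33)
]

# one batch size per boundary, plus one for the remaining long sequences
sampler = SequenceLengthBatchSampler(data, boundaries=[10, 25], batch_sizes=[4, 2, 1])
loader = DataLoader(data, batch_sampler=sampler, collate_fn=collate_fn)

for inputs, input_masks, targets, target_masks in loader:
    print(inputs.shape, targets.shape)  # each batch is padded to its own max length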
wolof_translate/utils/database_manager.py
@@ -0,0 +1,116 @@

from pymongo.mongo_client import MongoClient
from pymongo.server_api import ServerApi
import pandas as pd


class TranslationMongoDBManager:
    def __init__(self, uri: str, database: str):
        # recuperate the client
        self.client = MongoClient(uri)

        # recuperate the database
        self.db = self.client.get_database(database)

    def insert_documents(self, documents: list, collection: str = "sentences"):
        # insert documents inside a collection
        results = self.db[collection].insert_many(documents)

        return results

    def insert_document(self, document: dict, collection: str = "sentences"):
        assert "_id" not in document

        # get the id of the last sentence (recuperate the max id and add 1 to it)
        max_id = self.get_max_id(collection)

        # assign the new id to the document
        document["_id"] = max_id + 1

        results = self.db[collection].insert_one(document)

        return results

    def update_document(
        self,
        id: int,
        document: dict,
        collection: str = "sentences",
        update_collection: str = "updated",
    ):
        # recuperate the document to update
        upd_sent = self.db[collection].find_one({"_id": {"$eq": id}})

        # apply the update in place
        self.db[collection].update_one(
            {"_id": {"$eq": upd_sent["_id"]}}, {"$set": document}
        )

        # archive the previous version in the update collection
        upd_sent["_id"] = len(list(self.db[update_collection].find()))

        results = self.db[update_collection].insert_one(upd_sent)

        return results

    def delete_document(
        self, id: int, collection: str = "sentences", del_collection: str = "deleted"
    ):
        # recuperate the document to delete
        del_sent = self.db[collection].find_one({"_id": {"$eq": id}})

        # delete the sentence
        self.db[collection].delete_one({"_id": {"$eq": del_sent["_id"]}})

        # archive the sentence in the deleted collection
        del_sent["_id"] = len(list(self.db[del_collection].find()))

        results = self.db[del_collection].insert_one(del_sent)

        return results

    def get_max_id(self, collection: str = "sentences"):
        # recuperate the maximum id (assumes the collection is non-empty)
        id = list(self.db[collection].find().sort("_id", -1).limit(1))[0]["_id"]

        return id

    def save_data_frames(
        self,
        sentences_path: str,
        deleted_path: str,
        collection: str = "sentences",
        del_collection: str = "deleted",
    ):
        # recuperate the current corpora
        new_corpora = pd.DataFrame(list(self.db[collection].find()))

        # recuperate the deleted sentences as a data frame
        deleted_df = pd.DataFrame(list(self.db[del_collection].find()))

        # move the Mongo ids into the index so they are not written out
        new_corpora.set_index("_id", inplace=True)

        deleted_df.set_index("_id", inplace=True)

        # save the data frames as csv files
        new_corpora.to_csv(sentences_path, index=False)

        deleted_df.to_csv(deleted_path, index=False)

    def load_data_frames(
        self, collection: str = "sentences", del_collection: str = "deleted"
    ):
        # recuperate the current corpora
        new_corpora = pd.DataFrame(list(self.db[collection].find()))

        # recuperate the deleted sentences as a data frame
        deleted_df = pd.DataFrame(list(self.db[del_collection].find()))

        return new_corpora, deleted_df
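
A minimal usage sketch of the manager (not part of the packaged file); the connection string and document fields are placeholders, and note that get_max_id, and therefore insert_document, assumes the collection already holds at least one document:

manager = TranslationMongoDBManager(
    uri="mongodb+srv://user:password@cluster.example.net/",  # placeholder URI
    database="translation",
)

# append a sentence pair (the collection must be non-empty for get_max_id)
manager.insert_document({"french": "Bonjour.", "wolof": "Salaamaalekum."})

# pull the current and deleted corpora back as data frames
corpora, deleted = manager.load_data_frames()
print(corpora.head())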
wolof_translate/utils/display_predictions.py
@@ -0,0 +1,162 @@

import plotly.graph_objects as go
from tabulate import tabulate
import plotly.io as pio
import pandas as pd
import textwrap


def display_samples(
    data_frame: pd.DataFrame,
    n_samples: int = 40,
    seed: int = 0,
    header_color: str = "paleturquoise",
    cells_color: str = "lavender",
    width: int = 600,
    height: int = 1000,
    save_sample: bool = True,
    table_caption: str = "",
    label: str = "",
    filename: str = "samples",
    lang: str = "eng",
    show: bool = True,
):
    """Display a random sample of the data frame.

    Args:
        data_frame (pd.DataFrame): The data frame to display.
        n_samples (int, optional): The number of samples. Defaults to 40.
        seed (int, optional): The generator's seed. Defaults to 0.
        header_color (str, optional): The header color. Defaults to 'paleturquoise'.
        cells_color (str, optional): The cells' color. Defaults to 'lavender'.
        width (int): The width of the figure. Defaults to 600.
        height (int): The height of the figure. Defaults to 1000.
        save_sample (bool): Whether to save the sample as a csv file. Defaults to True.
        filename (str): The stem of the csv file to save. Defaults to 'samples'.
        lang (str): The language: 'fr' for French or 'eng' for English. Defaults to 'eng'.
        show (bool): Whether to display the figure. Defaults to True.

    Returns:
        go.Figure: The figure.
    """
    # get the samples from the data frame
    # samples = data_frame.sample(n_samples, random_state=seed).tail(13)
    samples = data_frame.sample(n_samples, random_state=seed)

    if lang == "fr":
        samples.columns = ["Phrases Originales", "Phrases Cibles", "Prédictions"]

    elif lang == "eng":
        samples.columns = ["Source Sentences", "Target Sentences", "Predictions"]

    # trace the figure
    fig = go.Figure(
        data=go.Table(
            header=dict(
                values=list(samples.columns),
                fill_color=header_color,
                align="center",
                font=dict(size=14, color="black"),  # header font style
                height=40,
            ),
            cells=dict(
                values=[samples[col] for col in samples.columns],
                fill_color=cells_color,
                line=dict(color="rgb(204, 204, 204)", width=1),
                font=dict(size=12, color="black"),
                height=30,
                align="left",
            ),
            columnwidth=[400, 400, 400],
        )
    )

    # customize the table layout
    fig.update_layout(
        width=width,  # set the overall table width
        height=height,  # set the overall table height
        margin=dict(l=0, r=0, t=0, b=0),  # remove margins
        paper_bgcolor="rgba(0,0,0,0)",  # transparent background
        plot_bgcolor="rgba(0,0,0,0)",  # transparent plot area background
    )

    # display the figure
    if show:
        fig.show()

    # save the sample as a csv file
    if save_sample:
        samples.to_csv(f"{filename}_{lang}.csv", index=False, encoding="utf-16")

    return fig


def save_go_figure_as_image(
    fig, path: str, scale: int = 3, width: int = 600, height: int = 1000
):
    # save the figure as an image (requires the kaleido package)
    pio.write_image(fig, path, format="png", scale=scale, width=width, height=height)


def escape_latex(text):
    """
    Escape special characters in text for LaTeX.
    """
    special_chars = ["_", "&", "%", "$", "#", "{", "}", "~", "^"]
    for char in special_chars:
        text = text.replace(char, "\\" + char)
    return text


def wrap_text(text, max_width):
    """
    Wrap long text to fit into the table cells.
    """
    return "\n".join(textwrap.wrap(text, width=max_width))


def save_latex_table(
    data_frame,
    table_caption="",
    label="",
    filename="table.tex",
    max_cell_width: int = 100,
    wrap_long_text: bool = True,
):
    """
    Convert a pandas DataFrame to a LaTeX table and save it to a file.

    Parameters:
        data_frame (pandas.DataFrame): The DataFrame to convert to a LaTeX table.
        table_caption (str): Optional caption for the table.
        label (str): Optional label for referencing the table in the document.
        filename (str): The name of the file to save the LaTeX code.

    Returns:
        None
    """
    # Convert the DataFrame to a LaTeX tabular representation
    latex_table = data_frame.to_latex(
        index=False, escape=False, column_format=r"p{.425\linewidth}p{.425\linewidth}"
    )

    # Wrap the tabular emitted by to_latex (which already contains its own
    # header row and \toprule/\midrule/\bottomrule rules, so it must not be
    # nested inside a second tabular) in a table environment with the caption
    # and label
    latex_table = (
        "\\begin{table}\n"
        "    \\centering\n"
        "    \\small\n"
        f"    \\caption{{{table_caption}}}\n"
        f"{latex_table}"
        f"    \\label{{{label}}}\n"
        "\\end{table}"
    )

    with open(filename, "w", encoding="utf-8") as f:
        f.write(latex_table)
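
A minimal usage sketch of display_samples (not part of the packaged file), with an illustrative three-column frame; the function expects exactly three columns (source, target, prediction):

import pandas as pd

df = pd.DataFrame({
    "source": ["Bonjour.", "Merci beaucoup."],
    "target": ["Salaamaalekum.", "Jërëjëf."],
    "prediction": ["Salaamaalekum.", "Jërëjëf bu baax."],
})

fig = display_samples(df, n_samples=2, save_sample=False, show=False)

# exporting to png goes through plotly's static-image engine (kaleido)
save_go_figure_as_image(fig, "samples.png", width=600, height=300)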
wolof_translate/utils/download_model.py
@@ -0,0 +1,40 @@

import shutil
import wandb
import glob
import os


def transfer_model(artifact_dir: str, model_name: str):
    """Transfer a downloaded artifact into another directory.

    Args:
        artifact_dir (str): The directory of the artifact
        model_name (str): The name of the model
    """
    # transfer the model inside the artifact to data/checkpoints/name_of_model
    os.makedirs(model_name, exist_ok=True)
    for file in glob.glob(f"{artifact_dir}/*"):
        shutil.copy(file, model_name)

    # delete the artifact
    shutil.rmtree(artifact_dir)


def download_artifact(artifact_name: str, model_name: str, type_: str = "dataset"):
    """Download an artifact from Weights & Biases and store it in a directory.

    Args:
        artifact_name (str): name of the artifact
        model_name (str): name of the model
        type_ (str): type of the artifact. Defaults to 'dataset'.
    """
    # download the wandb artifact
    run = wandb.init()
    artifact = run.use_artifact(artifact_name, type=type_)
    artifact_dir = artifact.download()

    # transfer the artifact into another directory
    transfer_model(artifact_dir, model_name)

    # finish the wandb run
    wandb.finish()
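
A minimal usage sketch (not part of the packaged file); the artifact and directory names are placeholders, and a logged-in wandb session is assumed:

download_artifact(
    artifact_name="my-entity/wolof-translate/fr_wf_model:v0",  # placeholder name
    model_name="data/checkpoints/fr_wf_model",
    type_="model",
)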
wolof_translate/utils/evaluate_custom.py
@@ -0,0 +1,147 @@

from tokenizers import Tokenizer
from typing import *
import numpy as np
import evaluate


class TranslationEvaluation:
    # note: despite the `tokenizers.Tokenizer` annotation, the class relies on
    # `batch_decode` and `pad_token_id`, i.e. the Hugging Face transformers
    # tokenizer interface
    def __init__(
        self,
        tokenizer: Tokenizer,
        decoder: Union[Callable, None] = None,
        next_gen: bool = False,
    ):
        self.tokenizer = tokenizer

        self.decoder = decoder

        self.bleu = evaluate.load("sacrebleu")

        self.rouge = evaluate.load("rouge")

        self.accuracy = evaluate.load("accuracy")

        self.next_gen = next_gen

    def postprocess_text(self, preds, labels):
        for i in range(len(labels)):
            pred = preds[i].strip()

            label = labels[i].strip()

            if self.next_gen:
                # keep only the suffix where the prediction starts to diverge
                # from the label (used when the model is fed the label prefix)
                new_pred = ""

                new_label = ""

                for j in range(len(label)):
                    if pred[:j] != label[:j]:
                        new_pred = pred[j:]

                        new_label = label[j:]

                        break

                preds[i] = new_pred

                labels[i] = new_label

            else:
                preds[i] = pred

                labels[i] = [label]  # sacrebleu expects a list of references

        return preds, labels

    def postprocess_codes(self, preds: np.ndarray, labels: np.ndarray):
        # weight out padding positions (assumes the pad token id is 0)
        label_weights = (labels != 0).astype(float).tolist()

        preds = preds.tolist()

        labels = labels.tolist()

        return preds, labels, label_weights

    def compute_metrics(
        self, eval_preds, rouge: bool = True, bleu: bool = True, accuracy: bool = False
    ):
        preds, labels = eval_preds

        if isinstance(preds, tuple):
            preds = preds[0]

        decoded_preds = (
            self.tokenizer.batch_decode(preds, skip_special_tokens=True)
            if not self.decoder
            else self.decoder(preds)
        )

        decoded_labels = (
            self.tokenizer.batch_decode(labels, skip_special_tokens=True)
            if not self.decoder
            else self.decoder(labels)
        )

        result = {}

        if accuracy:
            pred_codes, label_codes, sample_weight = self.postprocess_codes(
                preds, labels
            )

            # token-level accuracy, averaged over the sequences of the batch
            accuracy_result = np.mean(
                [
                    self.accuracy.compute(
                        predictions=pred_codes[i],
                        references=label_codes[i],
                        sample_weight=sample_weight[i],
                    )["accuracy"]
                    for i in range(len(pred_codes))
                ]
            )

            result["accuracy"] = accuracy_result

        if bleu or rouge:
            decoded_preds, decoded_labels = self.postprocess_text(
                decoded_preds, decoded_labels
            )

            if bleu:
                bleu_result = self.bleu.compute(
                    predictions=decoded_preds, references=decoded_labels
                )

                result["bleu"] = bleu_result["score"]

            if rouge:
                rouge_result = self.rouge.compute(
                    predictions=decoded_preds, references=decoded_labels
                )

                result.update(rouge_result)

        prediction_lens = [
            np.count_nonzero(np.array(pred) != self.tokenizer.pad_token_id)
            for pred in preds
        ]

        result["gen_len"] = np.mean(prediction_lens)

        result = {k: round(v, 4) for k, v in result.items()}

        return result
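
A minimal usage sketch (not part of the packaged file). A Hugging Face tokenizer stands in for the `tokenizers.Tokenizer` annotation, since the class only needs batch_decode and pad_token_id; the token ids are illustrative, and the evaluate, sacrebleu and rouge_score packages must be installed:

import numpy as np
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")
evaluation = TranslationEvaluation(tokenizer)

preds = np.array([[100, 200, 1, 0]])   # 1 = </s>, 0 = <pad> for T5
labels = np.array([[100, 200, 1, 0]])

# BLEU only; rouge and accuracy can be toggled with the keyword flags
print(evaluation.compute_metrics((preds, labels), rouge=False))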
wolof_translate/utils/evaluation.py
@@ -0,0 +1,74 @@

from tokenizers import Tokenizer
from typing import *
import numpy as np
import evaluate


class TranslationEvaluation:
    def __init__(
        self,
        tokenizer: Tokenizer,
        decoder: Union[Callable, None] = None,
        metric=None,
    ):
        self.tokenizer = tokenizer

        self.decoder = decoder

        # load sacrebleu lazily rather than in a default argument, which
        # would run at import time
        self.metric = metric if metric is not None else evaluate.load("sacrebleu")

    def postprocess_text(self, preds, labels):
        preds = [pred.strip() for pred in preds]

        labels = [[label.strip()] for label in labels]  # sacrebleu expects a list of references

        return preds, labels

    def compute_metrics(self, eval_preds):
        preds, labels = eval_preds

        if isinstance(preds, tuple):
            preds = preds[0]

        decoded_preds = (
            self.tokenizer.batch_decode(preds, skip_special_tokens=True)
            if not self.decoder
            else self.decoder(preds)
        )

        # replace the -100 ignore index used for label padding before decoding
        labels = np.where(labels != -100, labels, self.tokenizer.pad_token_id)

        decoded_labels = (
            self.tokenizer.batch_decode(labels, skip_special_tokens=True)
            if not self.decoder
            else self.decoder(labels)
        )

        decoded_preds, decoded_labels = self.postprocess_text(
            decoded_preds, decoded_labels
        )

        result = self.metric.compute(
            predictions=decoded_preds, references=decoded_labels
        )

        result = {"bleu": result["score"]}

        prediction_lens = [
            np.count_nonzero(np.array(pred) != self.tokenizer.pad_token_id)
            for pred in preds
        ]

        result["gen_len"] = np.mean(prediction_lens)

        result = {k: round(v, 4) for k, v in result.items()}

        return result
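
A minimal usage sketch (not part of the packaged file). The -100 entries below stand for the ignore index that Hugging Face trainers use to pad labels, which compute_metrics maps back to the pad token before decoding; because of that convention, evaluation.compute_metrics can also be passed directly as the compute_metrics argument of a Seq2SeqTrainer run with predict_with_generate=True:

import numpy as np
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")
evaluation = TranslationEvaluation(tokenizer)

preds = np.array([[100, 200, 1, 0]])
labels = np.array([[100, 200, 1, -100]])  # -100 marks ignored label positions

print(evaluation.compute_metrics((preds, labels)))  # {'bleu': ..., 'gen_len': 3.0}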