wolof-translate 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. wolof_translate/__init__.py +73 -0
  2. wolof_translate/data/__init__.py +0 -0
  3. wolof_translate/data/dataset_v1.py +151 -0
  4. wolof_translate/data/dataset_v2.py +187 -0
  5. wolof_translate/data/dataset_v3.py +187 -0
  6. wolof_translate/data/dataset_v3_2.py +187 -0
  7. wolof_translate/data/dataset_v4.py +202 -0
  8. wolof_translate/data/dataset_v5.py +65 -0
  9. wolof_translate/models/__init__.py +0 -0
  10. wolof_translate/models/transformers/__init__.py +0 -0
  11. wolof_translate/models/transformers/main.py +865 -0
  12. wolof_translate/models/transformers/main_2.py +362 -0
  13. wolof_translate/models/transformers/optimization.py +41 -0
  14. wolof_translate/models/transformers/position.py +46 -0
  15. wolof_translate/models/transformers/size.py +44 -0
  16. wolof_translate/pipe/__init__.py +1 -0
  17. wolof_translate/pipe/nlp_pipeline.py +512 -0
  18. wolof_translate/tokenizers/__init__.py +0 -0
  19. wolof_translate/trainers/__init__.py +0 -0
  20. wolof_translate/trainers/transformer_trainer.py +760 -0
  21. wolof_translate/trainers/transformer_trainer_custom.py +882 -0
  22. wolof_translate/trainers/transformer_trainer_ml.py +925 -0
  23. wolof_translate/trainers/transformer_trainer_ml_.py +1042 -0
  24. wolof_translate/utils/__init__.py +1 -0
  25. wolof_translate/utils/bucket_iterator.py +143 -0
  26. wolof_translate/utils/database_manager.py +116 -0
  27. wolof_translate/utils/display_predictions.py +162 -0
  28. wolof_translate/utils/download_model.py +40 -0
  29. wolof_translate/utils/evaluate_custom.py +147 -0
  30. wolof_translate/utils/evaluation.py +74 -0
  31. wolof_translate/utils/extract_new_sentences.py +810 -0
  32. wolof_translate/utils/extract_poems.py +60 -0
  33. wolof_translate/utils/extract_sentences.py +562 -0
  34. wolof_translate/utils/improvements/__init__.py +0 -0
  35. wolof_translate/utils/improvements/end_marks.py +45 -0
  36. wolof_translate/utils/recuperate_datasets.py +94 -0
  37. wolof_translate/utils/recuperate_datasets_trunc.py +85 -0
  38. wolof_translate/utils/send_model.py +26 -0
  39. wolof_translate/utils/sent_corrections.py +169 -0
  40. wolof_translate/utils/sent_transformers.py +27 -0
  41. wolof_translate/utils/sent_unification.py +97 -0
  42. wolof_translate/utils/split_with_valid.py +72 -0
  43. wolof_translate/utils/tokenize_text.py +46 -0
  44. wolof_translate/utils/training.py +213 -0
  45. wolof_translate/utils/trunc_hg_training.py +196 -0
  46. wolof_translate-0.0.1.dist-info/METADATA +31 -0
  47. wolof_translate-0.0.1.dist-info/RECORD +49 -0
  48. wolof_translate-0.0.1.dist-info/WHEEL +5 -0
  49. wolof_translate-0.0.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1 @@
+ # from wolof_translate.utils.tokenize_text import tokenization
wolof_translate/utils/bucket_iterator.py
@@ -0,0 +1,143 @@
+ import torch
+ import numpy as np
+ from typing import *
+ from torch.utils.data import Sampler
+ from torch.nn.utils.rnn import pad_sequence
+ from math import ceil
+
+
+ class SequenceLengthBatchSampler(Sampler):
+     def __init__(self, dataset, boundaries, batch_sizes, input_key=None, label_key=None, drop_unique=True):
+         self.dataset = dataset
+         self.boundaries = boundaries
+         self.batch_sizes = batch_sizes
+         self.data_info = {}
+         self.drop_unique = drop_unique
+
+         # Initialize a dictionary with indices and element lengths
+         for i in range(len(dataset)):
+             data = dataset[i]
+             length = (
+                 max(len(data[0]), len(data[2]))
+                 if (input_key is None and label_key is None)
+                 else max(len(data[input_key]), len(data[label_key]))
+             )
+             self.data_info[i] = {"index": i, "length": length}
+
+         self.calculate_length()
+
+     def calculate_length(self):
+         self.batches = []
+
+         # Sort indices based on element length
+         sorted_indices = sorted(self.data_info.keys(), key=lambda i: self.data_info[i]["length"])
+
+         # Group indices into buckets of sequences with similar lengths
+         for boundary in self.boundaries:
+             batch = [i for i in sorted_indices if self.data_info[i]["length"] <= boundary]  # Filter indices based on the length boundary
+             self.batches.append(batch)
+             sorted_indices = [i for i in sorted_indices if i not in batch]  # Remove processed indices
+
+         # Add remaining indices to the last bucket
+         self.batches.append(sorted_indices)
+
+         # Total number of batches: full batches per bucket, plus any remainder that is
+         # kept (a remainder of a single element is dropped when drop_unique is set)
+         self.length = sum(
+             len(batch) // batch_size
+             + (
+                 1
+                 if len(batch) % batch_size > 0
+                 and (len(batch) % batch_size != 1 or not self.drop_unique)
+                 else 0
+             )
+             for batch, batch_size in zip(self.batches, self.batch_sizes)
+         )
+
+     def __iter__(self):
+         # Yield batches with the corresponding batch sizes
+         for batch_indices, batch_size in zip(self.batches, self.batch_sizes):
+             num_batches = len(batch_indices) // batch_size
+
+             for i in range(num_batches):
+                 # Retrieve the current bucket
+                 current_bucket = batch_indices[i * batch_size: (i + 1) * batch_size]
+
+                 # Shuffle the current bucket
+                 np.random.shuffle(current_bucket)
+
+                 # Yield the current bucket
+                 yield [self.data_info[j]["index"] for j in current_bucket]
+
+             remaining_indices = len(batch_indices) % batch_size
+
+             if remaining_indices > 0 and (remaining_indices != 1 or not self.drop_unique):
+
+                 # Retrieve the remaining indices
+                 current_bucket = batch_indices[-remaining_indices:]
+
+                 # Shuffle the remaining indices
+                 np.random.shuffle(current_bucket)
+
+                 # Yield the remaining indices
+                 yield [self.data_info[j]["index"] for j in current_bucket]
+
+     def __len__(self):
+         return self.length
+
+
+ class BucketSampler(Sampler):
+     def __init__(self, dataset, batch_size, sort_key=lambda x, index_1, index_2: max(len(x[index_1]), len(x[index_2])), input_key: Union[str, int] = 0, label_key: Union[str, int] = 1):
+         self.dataset = dataset
+         self.batch_size = batch_size
+         self.sort_key = sort_key
+         self.index_1 = input_key
+         self.index_2 = label_key
+         indices = np.argsort([self.sort_key(self.dataset[i], self.index_1, self.index_2) for i in range(len(self.dataset))])
+         self.batches = [indices[i:i + self.batch_size] for i in range(0, len(indices), self.batch_size)]
+
+     def __iter__(self):
+         if self.batch_size > 1:
+             np.random.shuffle(self.batches)
+         for batch in self.batches:
+             yield batch.tolist()
+
+     def __len__(self):
+         return ceil(len(self.dataset) / self.batch_size)
+
+
+ def collate_fn(batch):
+     # Separate the input sequences, input masks, target sequences, and target masks
+     input_seqs, input_masks, target_seqs, target_masks = zip(*batch)
+
+     # Pad the input sequences to the same length
+     padded_input_seqs = pad_sequence(input_seqs, batch_first=True)
+
+     # Pad the target sequences to the same length
+     padded_target_seqs = pad_sequence(target_seqs, batch_first=True)
+
+     # Pad the input masks to the same length
+     padded_input_masks = pad_sequence(input_masks, batch_first=True)
+
+     # Pad the target masks to the same length
+     padded_target_masks = pad_sequence(target_masks, batch_first=True)
+
+     return padded_input_seqs, padded_input_masks, padded_target_seqs, padded_target_masks
+
+
+ def collate_fn_trunc(batch, max_len, eos_token_id, pad_token_id):
+     # Separate the input sequences, input masks, target sequences, and target masks
+     input_seqs, input_masks, target_seqs, target_masks = zip(*batch)
+
+     # Pad the input sequences to the same length and truncate them to max_len
+     padded_input_seqs = pad_sequence(input_seqs, batch_first=True)[:, :max_len]
+
+     # Pad the target sequences to the same length and truncate them to max_len
+     padded_target_seqs = pad_sequence(target_seqs, batch_first=True)[:, :max_len]
+
+     # Restore the EOS token id at the last position when truncation removed it
+     # (i.e. the last position holds neither the EOS token nor a padding token)
+     padded_input_seqs[:, -1:][(padded_input_seqs[:, -1:] != eos_token_id) & (padded_input_seqs[:, -1:] != pad_token_id)] = eos_token_id
+
+     padded_target_seqs[:, -1:][(padded_target_seqs[:, -1:] != eos_token_id) & (padded_target_seqs[:, -1:] != pad_token_id)] = eos_token_id
+
+     # Pad the input masks to the same length and truncate them to max_len
+     padded_input_masks = pad_sequence(input_masks, batch_first=True)[:, :max_len]
+
+     # Pad the target masks to the same length and truncate them to max_len
+     padded_target_masks = pad_sequence(target_masks, batch_first=True)[:, :max_len]
+
+     return padded_input_seqs, padded_input_masks, padded_target_seqs, padded_target_masks
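The samplers above yield lists of dataset indices, so they plug into a PyTorch DataLoader through the batch_sampler argument, with collate_fn handling the padding. Below is a minimal usage sketch (not part of the package); the toy dataset, boundaries, and batch sizes are illustrative assumptions. BucketSampler is used the same way, with a single batch_size.

import torch
from torch.utils.data import DataLoader

# toy dataset: each item is (input_ids, input_mask, target_ids, target_mask)
pairs = [
    (torch.ones(n, dtype=torch.long), torch.ones(n, dtype=torch.long),
     torch.ones(n + 2, dtype=torch.long), torch.ones(n + 2, dtype=torch.long))
    for n in range(5, 40)
]

# bucket the pairs by length (<= 16 tokens, <= 32 tokens, and the rest),
# with one batch size per bucket
sampler = SequenceLengthBatchSampler(pairs, boundaries=[16, 32], batch_sizes=[8, 4, 2])

loader = DataLoader(pairs, batch_sampler=sampler, collate_fn=collate_fn)

for inputs, input_masks, targets, target_masks in loader:
    # sequences inside a batch have similar lengths, so little padding is wasted
    pass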
wolof_translate/utils/database_manager.py
@@ -0,0 +1,116 @@
+ from pymongo.mongo_client import MongoClient
+ from pymongo.server_api import ServerApi
+ import pandas as pd
+
+
+ class TranslationMongoDBManager:
+     def __init__(self, uri: str, database: str):
+
+         # get the client
+         self.client = MongoClient(uri)
+
+         # get the database
+         self.db = self.client.get_database(database)
+
+     def insert_documents(self, documents: list, collection: str = "sentences"):
+
+         # insert several documents into a collection
+         results = self.db[collection].insert_many(documents)
+
+         return results
+
+     def insert_document(self, document: dict, collection: str = "sentences"):
+
+         assert "_id" not in document
+
+         # get the id of the last sentence (retrieve the max id and add 1 to it)
+         max_id = self.get_max_id(collection)
+
+         # add the new sentence
+         document["_id"] = max_id + 1
+
+         results = self.db[collection].insert_one(document)
+
+         return results
+
+     def update_document(
+         self,
+         id: int,
+         document: dict,
+         collection: str = "sentences",
+         update_collection: str = "updated",
+     ):
+
+         # retrieve the document to update
+         upd_sent = self.db[collection].find_one({"_id": {"$eq": id}})
+
+         # apply the update to the document
+         self.db[collection].update_one(
+             {"_id": {"$eq": upd_sent["_id"]}}, {"$set": document}
+         )
+
+         # archive the previous version in the updated-sentences collection
+         upd_sent["_id"] = len(list(self.db[update_collection].find()))
+
+         results = self.db[update_collection].insert_one(upd_sent)
+
+         return results
+
+     def delete_document(
+         self, id: int, collection: str = "sentences", del_collection: str = "deleted"
+     ):
+
+         # retrieve the document to delete
+         del_sent = self.db[collection].find_one({"_id": {"$eq": id}})
+
+         # delete the sentence
+         self.db[collection].delete_one({"_id": {"$eq": del_sent["_id"]}})
+
+         # archive the sentence in the deleted-sentences collection
+         del_sent["_id"] = len(list(self.db[del_collection].find()))
+
+         results = self.db[del_collection].insert_one(del_sent)
+
+         return results
+
+     def get_max_id(self, collection: str = "sentences"):
+
+         # retrieve the maximum id
+         id = list(self.db[collection].find().sort("_id", -1).limit(1))[0]["_id"]
+
+         return id
+
+     def save_data_frames(
+         self,
+         sentences_path: str,
+         deleted_path: str,
+         collection: str = "sentences",
+         del_collection: str = "deleted",
+     ):
+
+         # retrieve the current corpora
+         new_corpora = pd.DataFrame(list(self.db[collection].find()))
+
+         # retrieve the deleted sentences as a DataFrame
+         deleted_df = pd.DataFrame(list(self.db[del_collection].find()))
+
+         # save the data frames as csv files (the "_id" column is dropped)
+         new_corpora.set_index("_id", inplace=True)
+
+         deleted_df.set_index("_id", inplace=True)
+
+         new_corpora.to_csv(sentences_path, index=False)
+
+         deleted_df.to_csv(deleted_path, index=False)
+
+     def load_data_frames(
+         self, collection: str = "sentences", del_collection: str = "deleted"
+     ):
+
+         # retrieve the current corpora
+         new_corpora = pd.DataFrame(list(self.db[collection].find()))
+
+         # retrieve the deleted sentences as a DataFrame
+         deleted_df = pd.DataFrame(list(self.db[del_collection].find()))
+
+         return new_corpora, deleted_df
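A hypothetical session with the manager is sketched below; the connection string, database name, and document fields are placeholders, not values shipped with the package.

manager = TranslationMongoDBManager(
    uri="mongodb+srv://user:password@cluster.example.mongodb.net",
    database="wolof_corpus",  # assumed database name
)

# append a new sentence pair; its "_id" becomes the current max id + 1
manager.insert_document({"french": "Merci beaucoup.", "wolof": "Jërëjëf."})

# correct an existing pair; the previous version is archived in the "updated" collection
manager.update_document(3, {"wolof": "Jërëjëf bu baax."})

# export the current corpus and the deleted sentences to CSV files
manager.save_data_frames("sentences.csv", "deleted.csv")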
wolof_translate/utils/display_predictions.py
@@ -0,0 +1,162 @@
+ import plotly.graph_objects as go
+ from tabulate import tabulate
+ import plotly.io as pio
+ import pandas as pd
+ import textwrap
+
+
+ def display_samples(
+     data_frame: pd.DataFrame,
+     n_samples: int = 40,
+     seed: int = 0,
+     header_color: str = "paleturquoise",
+     cells_color: str = "lavender",
+     width: int = 600,
+     height: int = 1000,
+     save_sample: bool = True,
+     table_caption: str = "",
+     label: str = "",
+     filename: str = "samples.csv",
+     lang: str = "eng",
+     show: bool = True,
+ ):
+     """Display a random sample of the data frame.
+
+     Args:
+         data_frame (pd.DataFrame): The data frame to display.
+         n_samples (int, optional): The number of samples. Defaults to 40.
+         seed (int, optional): The random generator's seed. Defaults to 0.
+         header_color (str, optional): The header color. Defaults to 'paleturquoise'.
+         cells_color (str, optional): The cells' color. Defaults to 'lavender'.
+         width (int): The width of the figure. Defaults to 600.
+         height (int): The height of the figure. Defaults to 1000.
+         save_sample (bool): Whether to save the sampled sentences as a CSV file. Defaults to True.
+         filename (str): The base name of the saved CSV file. Defaults to 'samples.csv'.
+         lang (str): The language of the column headers: 'fr' for French or 'eng' for English. Defaults to 'eng'.
+         show (bool): Whether to display the figure. Defaults to True.
+
+     Returns:
+         go.Figure: The figure.
+     """
+
+     # get the samples from the data frame
+     samples = data_frame.sample(n_samples, random_state=seed)
+
+     if lang == "fr":
+
+         samples.columns = ["Phrases Originales", "Phrases Cibles", "Prédictions"]
+
+     elif lang == "eng":
+
+         samples.columns = ["Source Sentences", "Target Sentences", "Predictions"]
+
+     # build the table figure
+     fig = go.Figure(
+         data=go.Table(
+             header=dict(
+                 values=list(samples.columns),
+                 fill_color=header_color,
+                 align="center",
+                 font=dict(size=14, color="black"),  # Header font style
+                 height=40,
+             ),
+             cells=dict(
+                 values=[samples[col] for col in samples.columns],
+                 fill_color=cells_color,
+                 line=dict(color="rgb(204, 204, 204)", width=1),
+                 font=dict(size=12, color="black"),
+                 height=30,
+                 align="left",
+             ),
+             columnwidth=[400, 400, 400],
+         )
+     )
+
+     # Customize the table layout
+     fig.update_layout(
+         width=width,  # Set the overall table width
+         height=height,  # Set the overall table height
+         margin=dict(l=0, r=0, t=0, b=0),  # Remove margins
+         paper_bgcolor="rgba(0,0,0,0)",  # Transparent background
+         plot_bgcolor="rgba(0,0,0,0)",  # Transparent plot area background
+     )
+
+     # display the figure
+     if show:
+         fig.show()
+
+     # save the sampled sentences as a CSV file
+     if save_sample:
+         samples.to_csv(f"{filename}_{lang}.csv", index=False, encoding="utf-16")
+
+     return fig
+
+
+ def save_go_figure_as_image(
+     fig, path: str, scale: int = 3, width: int = 600, height: int = 1000
+ ):
+
+     # save the figure as an image
+     pio.write_image(fig, path, format="png", scale=scale, width=width, height=height)
+
+
+ def escape_latex(text):
+     """
+     Escape special characters in text for LaTeX.
+     """
+     special_chars = ["_", "&", "%", "$", "#", "{", "}", "~", "^"]
+     for char in special_chars:
+         text = text.replace(char, "\\" + char)
+     return text
+
+
+ def wrap_text(text, max_width):
+     """
+     Wrap long text to fit into the table cells.
+     """
+     return "\n".join(textwrap.wrap(text, width=max_width))
+
+
+ def save_latex_table(
+     data_frame,
+     table_caption="",
+     label="",
+     filename="table.tex",
+     max_cell_width: int = 100,
+     wrap_long_text: bool = True,
+ ):
+     """
+     Convert a pandas DataFrame to a LaTeX table and save it to a file.
+
+     Parameters:
+         data_frame (pandas.DataFrame): The DataFrame to convert to a LaTeX table.
+         table_caption (str): Optional caption for the table.
+         label (str): Optional label for referencing the table in the document.
+         filename (str): The name of the file to save the LaTeX code.
+         max_cell_width (int): The maximum number of characters per line inside a cell.
+         wrap_long_text (bool): Whether to wrap long cell contents before the conversion.
+
+     Returns:
+         None
+     """
+     # Optionally wrap long cell contents so they fit into the table cells
+     if wrap_long_text:
+         data_frame = data_frame.applymap(lambda cell: wrap_text(str(cell), max_cell_width))
+
+     # Convert the DataFrame to a LaTeX tabular representation
+     latex_table = data_frame.to_latex(
+         index=False, escape=False, column_format=len(data_frame.columns) * r"p{.425\linewidth}"
+     )
+
+     # Wrap the tabular representation in a table environment with the caption and label
+     latex_table = (
+         "\\begin{table}\n"
+         "    \\centering\n"
+         "    \\small\n"
+         f"    \\caption{{{table_caption}}}\n\n"
+         f"{latex_table}"
+         f"    \\label{{{label}}}\n"
+         "\\end{table}"
+     )
+
+     with open(filename, "w", encoding="utf-8") as f:
+
+         f.write(latex_table)
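As a rough illustration, the helpers above could be chained as follows; the three-column frame and its contents are made up, and the column names are only placeholders for source, reference, and predicted sentences.

import pandas as pd

predictions = pd.DataFrame(
    {
        "source": ["Merci beaucoup.", "Bonjour."],
        "target": ["Jërëjëf.", "Salaam aleekum."],
        "prediction": ["Jërëjëf bu baax.", "Salaam aleekum."],
    }
)

# render the sample as a Plotly table and export it as a PNG image
fig = display_samples(predictions, n_samples=2, lang="eng", show=False, save_sample=False)
save_go_figure_as_image(fig, "samples.png")

# write the same sample as a LaTeX table
save_latex_table(predictions, table_caption="Sample predictions", label="tab:samples", filename="samples.tex")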
wolof_translate/utils/download_model.py
@@ -0,0 +1,40 @@
+ import shutil
+ import wandb
+ import glob
+ import os
+
+
+ def transfer_model(artifact_dir: str, model_name: str):
+     """Transfer a downloaded artifact to another directory.
+
+     Args:
+         artifact_dir (str): The directory of the artifact
+         model_name (str): The name of the model directory (e.g. data/checkpoints/<model_name>)
+     """
+     # copy the files inside the artifact to the model directory
+     os.makedirs(model_name, exist_ok=True)
+     for file in glob.glob(f"{artifact_dir}/*"):
+         shutil.copy(file, model_name)
+
+     # delete the artifact directory
+     shutil.rmtree(artifact_dir)
+
+
+ def download_artifact(artifact_name: str, model_name: str, type_: str = "dataset"):
+     """Download an artifact from Weights & Biases and store it in a directory.
+
+     Args:
+         artifact_name (str): The name of the artifact
+         model_name (str): The name of the model directory
+         type_ (str): The type of the artifact. Defaults to 'dataset'.
+     """
+     # download the wandb artifact
+     run = wandb.init()
+     artifact = run.use_artifact(artifact_name, type=type_)
+     artifact_dir = artifact.download()
+
+     # transfer the artifact into another directory
+     transfer_model(artifact_dir, model_name)
+
+     # finish the wandb run
+     wandb.finish()
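A hypothetical call is shown below; the artifact path and the target directory follow the data/checkpoints/<model_name> layout mentioned in the comments, but both names are assumptions rather than values defined by the package.

download_artifact(
    artifact_name="my-team/wolof-translate/transformer_checkpoints:v0",  # assumed artifact
    model_name="data/checkpoints/transformer_fw_v1",                     # assumed target directory
    type_="dataset",
)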
wolof_translate/utils/evaluate_custom.py
@@ -0,0 +1,147 @@
+ from tokenizers import Tokenizer
+ from typing import *
+ import numpy as np
+ import evaluate
+
+
+ class TranslationEvaluation:
+     def __init__(
+         self,
+         tokenizer: Tokenizer,
+         decoder: Union[Callable, None] = None,
+         next_gen: bool = False,
+     ):
+
+         self.tokenizer = tokenizer
+
+         self.decoder = decoder
+
+         self.bleu = evaluate.load("sacrebleu")
+
+         self.rouge = evaluate.load("rouge")
+
+         self.accuracy = evaluate.load("accuracy")
+
+         self.next_gen = next_gen
+
+     def postprocess_text(self, preds, labels):
+
+         for i in range(len(labels)):
+
+             pred = preds[i].strip()
+
+             label = labels[i].strip()
+
+             if self.next_gen:
+
+                 # keep only the suffixes after the first position where the
+                 # prediction and the reference diverge
+                 new_pred = ""
+
+                 new_label = ""
+
+                 for j in range(len(label)):
+
+                     if pred[:j] != label[:j]:
+
+                         new_pred = pred[j:]
+
+                         new_label = label[j:]
+
+                         break
+
+                 preds[i] = new_pred
+
+                 labels[i] = new_label
+
+             else:
+
+                 preds[i] = pred
+
+                 labels[i] = [label]
+
+         return preds, labels
+
+     def postprocess_codes(self, preds: np.ndarray, labels: np.ndarray):
+
+         # padding positions (id 0) receive a zero weight in the accuracy
+         label_weights = (labels != 0).astype(float).tolist()
+
+         preds = preds.tolist()
+
+         labels = labels.tolist()
+
+         return preds, labels, label_weights
+
+     def compute_metrics(
+         self, eval_preds, rouge: bool = True, bleu: bool = True, accuracy: bool = False
+     ):
+
+         preds, labels = eval_preds
+
+         if isinstance(preds, tuple):
+
+             preds = preds[0]
+
+         decoded_preds = (
+             self.tokenizer.batch_decode(preds, skip_special_tokens=True)
+             if not self.decoder
+             else self.decoder(preds)
+         )
+
+         decoded_labels = (
+             self.tokenizer.batch_decode(labels, skip_special_tokens=True)
+             if not self.decoder
+             else self.decoder(labels)
+         )
+
+         result = {}
+
+         if accuracy:
+
+             pred_codes, label_codes, sample_weight = self.postprocess_codes(
+                 preds, labels
+             )
+
+             # token-level accuracy, averaged over the sequences of the batch
+             accuracy_result = np.mean(
+                 [
+                     self.accuracy.compute(
+                         predictions=pred_codes[i],
+                         references=label_codes[i],
+                         sample_weight=sample_weight[i],
+                     )["accuracy"]
+                     for i in range(len(pred_codes))
+                 ]
+             )
+
+             result["accuracy"] = accuracy_result
+
+         if bleu or rouge:
+
+             decoded_preds, decoded_labels = self.postprocess_text(
+                 decoded_preds, decoded_labels
+             )
+
+         if bleu:
+
+             bleu_result = self.bleu.compute(
+                 predictions=decoded_preds, references=decoded_labels
+             )
+
+             result["bleu"] = bleu_result["score"]
+
+         if rouge:
+
+             rouge_result = self.rouge.compute(
+                 predictions=decoded_preds, references=decoded_labels
+             )
+
+             result.update(rouge_result)
+
+         # average length of the generated sequences (non-padding tokens)
+         prediction_lens = [
+             np.count_nonzero(np.array(pred) != self.tokenizer.pad_token_id)
+             for pred in preds
+         ]
+
+         result["gen_len"] = np.mean(prediction_lens)
+
+         result = {k: round(v, 4) for k, v in result.items()}
+
+         return result
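A minimal sketch of a direct call, assuming a Hugging Face tokenizer (the t5-small checkpoint and the sentence are illustrative only, not values used by the package):

import numpy as np
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")  # assumed checkpoint
evaluator = TranslationEvaluation(tokenizer)

# pretend the model reproduced the reference exactly
ids = tokenizer("waaw, jërëjëf").input_ids
preds = np.array([ids])
labels = np.array([ids])

print(evaluator.compute_metrics((preds, labels), rouge=True, bleu=True))
# e.g. {'bleu': 100.0, 'rouge1': 1.0, ..., 'gen_len': ...}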
wolof_translate/utils/evaluation.py
@@ -0,0 +1,74 @@
+ from tokenizers import Tokenizer
+ from typing import *
+ import numpy as np
+ import evaluate
+
+
+ class TranslationEvaluation:
+     def __init__(
+         self,
+         tokenizer: Tokenizer,
+         decoder: Union[Callable, None] = None,
+         metric=None,
+     ):
+
+         self.tokenizer = tokenizer
+
+         self.decoder = decoder
+
+         # load sacrebleu lazily so that the metric is only fetched when needed
+         self.metric = metric if metric is not None else evaluate.load("sacrebleu")
+
+     def postprocess_text(self, preds, labels):
+
+         preds = [pred.strip() for pred in preds]
+
+         labels = [[label.strip()] for label in labels]
+
+         return preds, labels
+
+     def compute_metrics(self, eval_preds):
+
+         preds, labels = eval_preds
+
+         if isinstance(preds, tuple):
+
+             preds = preds[0]
+
+         decoded_preds = (
+             self.tokenizer.batch_decode(preds, skip_special_tokens=True)
+             if not self.decoder
+             else self.decoder(preds)
+         )
+
+         # replace the -100 ignore index by the pad token id before decoding
+         labels = np.where(labels != -100, labels, self.tokenizer.pad_token_id)
+
+         decoded_labels = (
+             self.tokenizer.batch_decode(labels, skip_special_tokens=True)
+             if not self.decoder
+             else self.decoder(labels)
+         )
+
+         decoded_preds, decoded_labels = self.postprocess_text(
+             decoded_preds, decoded_labels
+         )
+
+         result = self.metric.compute(
+             predictions=decoded_preds, references=decoded_labels
+         )
+
+         result = {"bleu": result["score"]}
+
+         # average length of the generated sequences (non-padding tokens)
+         prediction_lens = [
+             np.count_nonzero(np.array(pred) != self.tokenizer.pad_token_id)
+             for pred in preds
+         ]
+
+         result["gen_len"] = np.mean(prediction_lens)
+
+         result = {k: round(v, 4) for k, v in result.items()}
+
+         return result
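This second variant differs from evaluate_custom.py mainly in that it maps the -100 ignore index used by Hugging Face label collators back to the pad token before decoding, which makes it usable as a compute_metrics callback for a Seq2SeqTrainer with predict_with_generate enabled. A small direct-call sketch, with an assumed t5-small tokenizer and made-up token ids:

import numpy as np
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")  # assumed checkpoint
evaluation = TranslationEvaluation(tokenizer)

ids = tokenizer("jërëjëf").input_ids
preds = np.array([ids])
# -100 marks positions that the loss ignores; they are mapped to the pad token
labels = np.array([ids + [-100, -100]])

print(evaluation.compute_metrics((preds, labels)))  # {'bleu': ..., 'gen_len': ...}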