rababa 0.1.0 → 0.1.1
- checksums.yaml +4 -4
- data/.github/workflows/python.yml +81 -0
- data/.github/workflows/release.yml +36 -0
- data/.github/workflows/ruby.yml +27 -0
- data/.gitignore +3 -0
- data/.rubocop.yml +1 -1
- data/CODE_OF_CONDUCT.md +13 -13
- data/README.adoc +80 -0
- data/Rakefile +1 -1
- data/docs/{research-arabic-diacritization-06-2021.md → research-arabic-diacritization-06-2021.adoc} +52 -37
- data/exe/rababa +1 -1
- data/lib/README.adoc +95 -0
- data/lib/rababa/diacritizer.rb +16 -8
- data/lib/rababa/encoders.rb +2 -2
- data/lib/rababa/harakats.rb +1 -1
- data/lib/rababa/reconcile.rb +1 -33
- data/lib/rababa/version.rb +1 -1
- data/models-data/README.adoc +6 -0
- data/python/README.adoc +211 -0
- data/python/config/cbhg.yml +1 -1
- data/python/config/test_cbhg.yml +51 -0
- data/python/dataset.py +23 -31
- data/python/diacritization_model_to_onnx.py +216 -15
- data/python/diacritizer.py +35 -31
- data/python/log_dir/CA_MSA.base.cbhg/models/README.adoc +2 -0
- data/python/log_dir/README.adoc +1 -0
- data/python/{requirement.txt → requirements.txt} +1 -1
- data/python/setup.py +32 -0
- data/python/trainer.py +10 -4
- data/python/util/reconcile_original_plus_diacritized.py +2 -0
- data/python/util/text_cleaners.py +59 -4
- data/rababa.gemspec +1 -1
- data/test-datasets/data-arabic-pointing/{Readme.md → README.adoc} +2 -1
- metadata +22 -18
- data/.github/workflows/main.yml +0 -18
- data/README.md +0 -73
- data/lib/README.md +0 -82
- data/models-data/README.md +0 -6
- data/python/README.md +0 -163
- data/python/log_dir/CA_MSA.base.cbhg/models/Readme.md +0 -2
- data/python/log_dir/README.md +0 -1
data/python/config/cbhg.yml
CHANGED

@@ -16,7 +16,7 @@ diacritics_separator: '*' # Required if the data already processed
 text_encoder: ArabicEncoderWithStartSymbol
 text_cleaner: valid_arabic_cleaners # a white list that uses only Arabic letters, punctuations, and a space
 max_len: 600 # sentences larger than this size will not be used
-
+reconcile: true
 
 max_steps: 2_000_000
 learning_rate: 0.001
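The notable change here is the new reconcile: true flag; diacritizer.py (further down in this diff) checks it before merging model output back into the original string. A minimal sketch of reading the flag, assuming PyYAML (the project's own ConfigManager may wrap this differently):

import yaml  # assumption: PyYAML is available

with open("config/cbhg.yml", encoding="utf-8") as f:
    config = yaml.safe_load(f)

# YAML 1.1 parses `true` as a bool and `2_000_000` as the int 2000000
if config["reconcile"]:
    print("reconcile enabled, max_len =", config["max_len"])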
data/python/config/test_cbhg.yml
ADDED

@@ -0,0 +1,51 @@
+session_name: base
+
+data_directory: "data"
+data_type: "CA_MSA"
+log_directory: "log_dir"
+load_training_data: true
+load_test_data: false
+load_validation_data: true
+n_training_examples: null # null load all training examples, good for fast loading
+n_test_examples: null # null load all test examples
+n_validation_examples: null # null load all validation examples
+test_file_name: "test.csv"
+is_data_preprocessed: false # The data file is organized as (original text | text | diacritics)
+data_separator: '|' # Required if the data already processed
+diacritics_separator: '*' # Required if the data already processed
+text_encoder: ArabicEncoderWithStartSymbol
+text_cleaner: valid_arabic_cleaners # a white list that uses only Arabic letters, punctuations, and a space
+max_len: 600 # sentences larger than this size will not be used
+reconcile: true
+
+max_steps: 50
+learning_rate: 0.001
+batch_size: 32
+adam_beta1: 0.9
+adam_beta2: 0.999
+use_decay: true
+weight_decay: 0.0
+embedding_dim: 256
+use_prenet: false
+prenet_sizes: [512, 256]
+cbhg_projections: [128, 256]
+cbhg_filters: 16
+cbhg_gru_units: 256
+post_cbhg_layers_units: [256, 256]
+post_cbhg_use_batch_norm: true
+
+use_mixed_precision: false
+optimizer_type: Adam
+device: cuda
+
+# LOGGING
+evaluate_frequency: 5000
+evaluate_with_error_rates_frequency: 5000
+n_predicted_text_tensorboard: 10 # To be written to the tensorboard
+model_save_frequency: 5000
+train_plotting_frequency: 50000000 # No plotting for this model
+n_steps_avg_losses: [100, 500, 1_000, 5_000] # command line display of average loss values for the last n steps
+error_rates_n_batches: 10000 # if calculating error rate is slow, then you can specify the number of batches to be calculated
+
+test_model_path: null # load the last saved model
+train_resume_model_path: null # load last saved model
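As far as these hunks show, test_cbhg.yml mirrors config/cbhg.yml except for max_steps (50 rather than 2_000_000), presumably so the GitHub workflows added in this release can exercise a full training loop quickly. A hedged sketch for spotting drift between the two configs (assumes PyYAML; paths relative to data/python/):

import yaml

with open("config/cbhg.yml") as f:
    full = yaml.safe_load(f)
with open("config/test_cbhg.yml") as f:
    test = yaml.safe_load(f)

# report every key whose value differs between the two configs
for key in sorted(set(full) | set(test)):
    if full.get(key) != test.get(key):
        print(key, ":", full.get(key), "->", test.get(key))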
data/python/dataset.py
CHANGED

@@ -4,10 +4,13 @@ Loading the diacritization dataset
 
 import os
 
-
+import util.text_cleaners as cleaners
 import pandas as pd
 import torch
 import random
+import warnings
+from diacritization_evaluation import util
+
 from torch.utils.data import DataLoader, Dataset
 
 from config_manager import ConfigManager

@@ -15,7 +18,7 @@ from config_manager import ConfigManager
 
 class DiacritizationDataset(Dataset):
     """
-    The diacritization
+    The datasets for preprocessing for diacritization
     """
 
     def __init__(self, config_manager: ConfigManager, list_ids, data):

@@ -24,6 +27,7 @@ class DiacritizationDataset(Dataset):
         self.data = data
         self.text_encoder = config_manager.text_encoder
         self.config = config_manager.config
+        # print('config:: ', self.config)
 
     def __len__(self):
         "Denotes the total number of samples"

@@ -33,35 +37,22 @@ class DiacritizationDataset(Dataset):
         "Generates one sample of data"
         # Select sample
         id = self.list_ids[index]
-
-
-
-
-
-
-
-
-
-
-        encoding_failed = True
-        while encoding_failed:
-            try:
-                data = self.data[id]
-                data = self.text_encoder.clean(data)
-                text, inputs, diacritics = util.extract_haraqat(data)
-                encoding_failed = False
-            except:
-                print('dataset.py :: error with that data')
-                print('id: ', id)
-                print('data: ', data)
-                # text, inputs, diacritics = util.extract_haraqat(data[0])
-                id = random.randint(0, len(data))
-
-        inputs = torch.Tensor(self.text_encoder.input_to_sequence("".join(inputs)))
-        diacritics = torch.Tensor(self.text_encoder.target_to_sequence(diacritics))
-
-        return inputs, diacritics, text
+        data_orig = self.data[id].strip()
+        text, inputs, diacritics = cleaners.extract_haraqat(
+            self.text_encoder.clean(data_orig))
+
+        inputs = torch.Tensor(
+            self.text_encoder.input_to_sequence("".join(inputs)))
+        diacritics = torch.Tensor(
+            self.text_encoder.target_to_sequence(diacritics))
+
+        return inputs, diacritics, data_orig
 
+        #data = self.data[id]
+        #data = self.text_encoder.clean(data)
+        #text, inputs, diacritics = util.extract_haraqat(data)
+        #inputs = torch.Tensor(self.text_encoder.input_to_sequence("".join(inputs)))
+        #diacritics = torch.Tensor(self.text_encoder.target_to_sequence(diacritics))
 
 def collate_fn(data):
     """

@@ -164,7 +155,8 @@ def load_test_data(config_manager: ConfigManager, loader_parameters):
         config_manager, [idx for idx in range(len(test_data))], test_data
     )
 
-    test_iterator = DataLoader(test_dataset, collate_fn=collate_fn,
+    test_iterator = DataLoader(test_dataset, collate_fn=collate_fn,
+                               **loader_parameters)
 
     print(f"Length of test iterator = {len(test_iterator)}")
     return test_iterator
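Two behavioral points in this dataset.py rewrite: __getitem__ no longer retries a random other sample when a line fails to encode (a malformed line now raises instead of being silently replaced), and it returns the stripped original string rather than the cleaned text, which diacritizer.py needs for reconciliation. A sketch of how the class is consumed, following the names in the diff; the ConfigManager constructor arguments are an assumption:

from torch.utils.data import DataLoader
from config_manager import ConfigManager
from dataset import DiacritizationDataset, collate_fn

config_manager = ConfigManager("config/test_cbhg.yml", "cbhg")  # assumed signature
lines = ["one Arabic sentence per entry"]
dataset = DiacritizationDataset(config_manager, list(range(len(lines))), lines)

# collate_fn batches the (inputs, diacritics, original) triples into the
# dict consumed by diacritizer.py: keys "src", "lengths", "target", "original"
loader = DataLoader(dataset, batch_size=1, collate_fn=collate_fn, shuffle=False)
for batch in loader:
    print(batch["src"].shape, batch["lengths"])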
data/python/diacritization_model_to_onnx.py
CHANGED

@@ -3,21 +3,21 @@ import pickle
 
 import numpy as np
 
-from diacritizer import
+from diacritizer import Diacritizer
 
 
 """
 Key Params:
-max_len:
+max_len:
 is the max length for the arabic strings to be diacritized
-batch size:
+batch size:
 has to do with the model training and usage
 """
-max_len =
+max_len = 200 # 600 for the original length
 batch_size = 32
 
 
-"""
+"""
 example and mock data:
 we found that populating all the data, removing the zeros gives better results.
 """

@@ -25,7 +25,7 @@ src = torch.Tensor([[1 for i in range(max_len)]
                     for i in range(batch_size)]).long()
 lengths = torch.Tensor([max_len for i in range(batch_size)]).long()
 # example data
-batch_data = pickle.load( open('../models-data/
+batch_data = pickle.load( open('../models-data/batch_example_data.pkl', 'rb') )
 
 #target = batch_data['target']
 

@@ -37,7 +37,7 @@ model_kind_str = 'cbhg'
 config_str = 'config/cbhg.yml'
 load_model = True
 
-dia =
+dia = Diacritizer(config_str, model_kind_str, load_model)
 
 # set model to inference mode
 dia.model.to(dia.device)

@@ -58,13 +58,22 @@ import onnxruntime
 onnx_model_filename = '../models-data/diacritization_model.onnx'
 
 
+print(src.shape)
+
+#exit()
 # export model
-torch.onnx.export(dia.model,
-                  (src, lengths),
-                  onnx_model_filename,
-                  verbose=False,
-                  opset_version=11,
-                  input_names=['src', 'lengths']
+torch.onnx.export(dia.model,
+                  (src, lengths),
+                  onnx_model_filename,
+                  verbose=False,
+                  opset_version=11,
+                  input_names=['src', 'lengths'],
+                  output_names=['output'],
+                  dynamic_axes = {'src': [1], #[0,1,2], #[0,1,2],
+                                  #'input_2':{0:'batch'},
+                                  'output': [1]
+                                  })
+
 print('Model printed in rel. path:', onnx_model_filename)
 
 

@@ -94,10 +103,202 @@ ort_inputs = {ort_session.get_inputs()[0].name: src.detach().numpy().astype(np.int64),
 # run onnx model
 ort_outs = ort_session.run(None, ort_inputs)
 
+print('outs:: ', ort_outs)
 
-for i in range(batch_size):
-    np.testing.assert_allclose(torch_out['diacritics'][i].detach().numpy(), ort_outs[0][i], rtol=1e-03, atol=1e-03)
 
+print('src:: ', src.detach().numpy().astype(np.int64))
+print('lengths: ',lengths.detach().numpy().astype(np.int64))
+
+#exit()
+
+for i in range(batch_size):
+    np.testing.assert_allclose(torch_out['diacritics'][i].detach().numpy(),
+                               ort_outs[0][i], rtol=1e-03, atol=1e-03)
 
 print("\n!!!Exported model has been tested with ONNXRuntime, result looks good within given tolerance!!!")
 
+
+
+vec = [[41, 12, 40] for i in range(batch_size)]
+src = torch.Tensor(vec).long()
+
+lengths = torch.Tensor([3 for i in range(batch_size)]).long()
+
+ort_inputs = {ort_session.get_inputs()[0].name: src.detach().numpy().astype(np.int64),
+              ort_session.get_inputs()[1].name: lengths.detach().numpy().astype(np.int64)}
+
+
+#print('12345678910')
+#print(ort_session.get_inputs()[0].name)
+#print(ort_session.get_inputs()[1].name)
+
+print('run 3')
+ort_outs = ort_session.run(None, ort_inputs)
+print('outs:: ', ort_outs[0].shape)
+print('outs:: ', ort_outs[0][0][0])
+print('outs:: ', ort_outs[0][0][1])
+print('outs:: ', ort_outs[0][0][2])
+
+torch_out = dia.model(src, lengths)
+
+#print(torch_out['diacritics'][0])
+for i in range(batch_size):
+    np.testing.assert_allclose(torch_out['diacritics'][i].detach().numpy(), \
+                               ort_outs[0][i], rtol=1e-03, atol=1e-03)
+
+print('12345678910')
+print(ort_session.get_inputs()[0].name)
+print(ort_session.get_inputs()[1].name)
+
+#exit()
+
+
+
+"""
+Test ONNX model on randomized data
+"""
+
+import random
+test_id = 0
+
+print('***** Test MAX size :: Random Boolean vectors: *****')
+print(max_len)
+
+for test_run in range(3):
+
+    vec = [[random.randint(0,1) for i in range(max_len)]
+           for i in range(batch_size)]
+    src = torch.Tensor(vec).long()
+    lengths = torch.Tensor([max_len for i in range(batch_size)]).long()
+
+    """
+    with open('test_data/test'+str(test_id)+'.txt', 'w') as f:
+        for ll in src.detach().tolist():
+            for item in ll:
+                f.write("%s " % item)
+            f.write("\n")
+        f.close()
+    """
+    torch_out = dia.model(src, lengths)
+    """
+    my_list = torch_out['diacritics'].detach().numpy().tolist()
+    with open('test_data/test'+str(test_id)+'_torch.txt', 'w') as f:
+        for ll in my_list:
+            for item in ll:
+                for l in item:
+                    f.write("%s " % l)
+            f.write("\n")
+        f.close()
+    test_id+=1
+    """
+    # prepare onnx input
+    ort_inputs = {ort_session.get_inputs()[0].name: src.detach().numpy().astype(np.int64),
+                  ort_session.get_inputs()[1].name: lengths.detach().numpy().astype(np.int64)}
+
+    # run onnx model
+    ort_outs = ort_session.run(None, ort_inputs)
+
+    for i in range(batch_size):
+        np.testing.assert_allclose(torch_out['diacritics'][i].detach().numpy(), \
+                                   ort_outs[0][i], rtol=1e-03, atol=1e-03)
+
+    print('test :: ', test_run)
+    print("Result looks good within given tolerance!!!")
+
+
+print('***** Test MAX size :: Random float, vectors within 0:16 *****')
+print(max_len)
+
+for test_run in range(3):
+
+    vec = [[random.randint(0, 17) for i in range(max_len)]
+           for i in range(batch_size)]
+    src = torch.Tensor(vec).long()
+    """
+    with open('test_data/test'+str(test_id)+'.txt', 'w') as f:
+        for ll in src.detach().tolist():
+            for item in ll:
+                f.write("%s " % item)
+            f.write("\n")
+        f.close()
+    """
+    torch_out = dia.model(src, lengths)
+
+    #my_list = torch_out['diacritics'].detach().numpy().tolist()
+    """
+    with open('test_data/test'+str(test_id)+'_torch.txt', 'w') as f:
+        for ll in my_list:
+            for item in ll:
+                for l in item:
+                    f.write("%s " % l)
+            f.write("\n")
+        f.close()
+    test_id+=1
+    """
+    # prepare onnx input
+    ort_inputs = {ort_session.get_inputs()[0].name: src.detach().numpy().astype(np.int64),
+                  ort_session.get_inputs()[1].name: lengths.detach().numpy().astype(np.int64)}
+
+    # run onnx model
+    ort_outs = ort_session.run(None, ort_inputs)
+
+    for i in range(batch_size):
+        np.testing.assert_allclose(torch_out['diacritics'][i].detach().numpy(), \
+                                   ort_outs[0][i], rtol=1, atol=1)
+
+    print('test :: ', test_run)
+    print("Result looks good within given tolerance!!!")
+
+
+print('***** Test Dynamical sizes :: Random Boolean vectors: *****')
+
+for l in [2, 10, 40, 100, 150]:
+
+    print('length:: ', l)
+
+    vec = [[1 for i in range(l)] # random.randint(0,1)
+           for i in range(batch_size)]
+    src = torch.Tensor(vec).long()
+    lengths = torch.Tensor([l for i in range(batch_size)]).long()
+
+    torch_out = dia.model(src, lengths)
+
+    # prepare onnx input
+    ort_inputs = {ort_session.get_inputs()[0].name: src.detach().numpy().astype(np.int64),
+                  ort_session.get_inputs()[1].name: lengths.detach().numpy().astype(np.int64)}
+
+    # run onnx model
+    ort_outs = ort_session.run(None, ort_inputs)
+
+    for i in range(batch_size):
+        np.testing.assert_allclose(torch_out['diacritics'][i].detach().numpy(), \
+                                   ort_outs[0][i], rtol=1e-03, atol=1e-03)
+
+    print('test :: ', l)
+    print("Result looks good within given tolerance!!!")
+
+
+print('***** Test Dynamical sizes :: Random float, vectors within 0:16 *****')
+
+for l in [2, 10, 40, 100, 150]:
+
+    vec = [[random.randint(0, 17) for i in range(l)]
+           for i in range(batch_size)]
+    src = torch.Tensor(vec).long()
+    lengths = torch.Tensor([l for i in range(batch_size)]).long()
+
+    torch_out = dia.model(src, lengths)
+
+    # prepare onnx input
+    ort_inputs = {ort_session.get_inputs()[0].name: src.detach().numpy().astype(np.int64),
+                  ort_session.get_inputs()[1].name: lengths.detach().numpy().astype(np.int64)}
+
+    # run onnx model
+    ort_outs = ort_session.run(None, ort_inputs)
+
+    for i in range(batch_size):
+        np.testing.assert_allclose(torch_out['diacritics'][i].detach().numpy(), \
+                                   ort_outs[0][i], rtol=1, atol=1)
+
+    print('test :: ', l)
+    print("Result looks good within given tolerance!!!")
data/python/diacritizer.py
CHANGED

@@ -1,6 +1,8 @@
 from typing import Dict
 import torch
 import tqdm
+import pandas as pd
+import numpy as np
 from config_manager import ConfigManager
 from dataset import (DiacritizationDataset,
                      collate_fn)

@@ -9,7 +11,6 @@ from torch.utils.data import (DataLoader,
 import util.reconcile_original_plus_diacritized as reconcile
 
 
-
 class Diacritizer:
     def __init__(
         self, config_path: str, model_kind: str, load_model: bool = False

@@ -35,6 +36,7 @@ class Diacritizer:
 
     def diacritize_text(self, text: str):
         # convert string into indices
+        text = text.strip()
         seq = self.text_encoder.input_to_sequence(text)
         # transform indices into "batch data"
         batch_data = {'original': [text],

@@ -48,25 +50,31 @@ class Diacritizer:
         loader_params = {"batch_size": self.config_manager.config["batch_size"],
                          "shuffle": False,
                          "num_workers": 2}
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+        data_tmp = pd.read_csv(path,
+                               encoding="utf-8",
+                               sep=self.config_manager.config["data_separator"],
+                               header=None)
+
+        data = []
+        max_len = self.config_manager.config["max_len"]
+        for txt in [d[0] for d in data_tmp.values.tolist()]:
+            if len(txt) > max_len:
+                txt = txt[:max_len]
+                warnings.warn('Warning: text length cut for sentence: \n'+text)
+            data.append(txt)
+
+        list_ids = [idx for idx in range(len(data))]
+        dataset = DiacritizationDataset(self.config_manager,
+                                        list_ids,
+                                        data)
+
+        data_iterator = DataLoader(dataset,
+                                   collate_fn=collate_fn,
+                                   # **loader_params,
+                                   shuffle=False)
+
+        # print(f"Length of data iterator = {len(data_iterator)}")
         return data_iterator
 
     def diacritize_file(self, path: str):

@@ -75,6 +83,7 @@ class Diacritizer:
         diacritized_data = []
         for batch_inputs in tqdm.tqdm(data_iterator):
 
+            #batch_inputs["original"] = batch_inputs["original"].to(self.device)
             batch_inputs["src"] = batch_inputs["src"].to(self.device)
             batch_inputs["lengths"] = batch_inputs["lengths"].to('cpu')
             batch_inputs["target"] = batch_inputs["target"].to(self.device)

@@ -85,7 +94,7 @@ class Diacritizer:
         return diacritized_data
 
     def diacritize_batch(self, batch):
-        #print('batch: ',batch)
+        # print('batch: ',batch)
         self.model.eval()
         originals = batch['original']
         inputs = batch["src"]

@@ -93,25 +102,20 @@ class Diacritizer:
         outputs = self.model(inputs.to(self.device), lengths.to("cpu"))
         diacritics = outputs["diacritics"]
         predictions = torch.max(diacritics, 2).indices
-        sentences = []
 
+        sentences = []
         for src, prediction, original in zip(inputs, predictions, originals):
             sentence = self.text_encoder.combine_text_and_haraqat(
-
-
-            )
+                list(src.detach().cpu().numpy()),
+                list(prediction.detach().cpu().numpy()))
             # Diacritized strings, sentence have to be "reconciled"
             # with original strings, because the non arabic strings are removed
             # before being processed in nnet
-
+            if self.config['reconcile']:
+                sentence = reconcile.reconcile_strings(original, sentence)
             sentences.append(sentence)
 
         return sentences
 
     def diacritize_iterators(self, iterator):
         pass
-
-""" not needed
-class CBHGDiacritizer(Diacritizer):
-class Seq2SeqDiacritizer(Diacritizer):
-"""