rababa 0.1.0 → 0.1.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/.github/workflows/python.yml +81 -0
- data/.github/workflows/release.yml +36 -0
- data/.github/workflows/ruby.yml +27 -0
- data/.gitignore +3 -0
- data/.rubocop.yml +1 -1
- data/CODE_OF_CONDUCT.md +13 -13
- data/README.adoc +80 -0
- data/Rakefile +1 -1
- data/docs/{research-arabic-diacritization-06-2021.md → research-arabic-diacritization-06-2021.adoc} +52 -37
- data/exe/rababa +1 -1
- data/lib/README.adoc +95 -0
- data/lib/rababa/diacritizer.rb +16 -8
- data/lib/rababa/encoders.rb +2 -2
- data/lib/rababa/harakats.rb +1 -1
- data/lib/rababa/reconcile.rb +1 -33
- data/lib/rababa/version.rb +1 -1
- data/models-data/README.adoc +6 -0
- data/python/README.adoc +211 -0
- data/python/config/cbhg.yml +1 -1
- data/python/config/test_cbhg.yml +51 -0
- data/python/dataset.py +23 -31
- data/python/diacritization_model_to_onnx.py +216 -15
- data/python/diacritizer.py +35 -31
- data/python/log_dir/CA_MSA.base.cbhg/models/README.adoc +2 -0
- data/python/log_dir/README.adoc +1 -0
- data/python/{requirement.txt → requirements.txt} +1 -1
- data/python/setup.py +32 -0
- data/python/trainer.py +10 -4
- data/python/util/reconcile_original_plus_diacritized.py +2 -0
- data/python/util/text_cleaners.py +59 -4
- data/rababa.gemspec +1 -1
- data/test-datasets/data-arabic-pointing/{Readme.md → README.adoc} +2 -1
- metadata +22 -18
- data/.github/workflows/main.yml +0 -18
- data/README.md +0 -73
- data/lib/README.md +0 -82
- data/models-data/README.md +0 -6
- data/python/README.md +0 -163
- data/python/log_dir/CA_MSA.base.cbhg/models/Readme.md +0 -2
- data/python/log_dir/README.md +0 -1
data/python/config/cbhg.yml
CHANGED
@@ -16,7 +16,7 @@ diacritics_separator: '*' # Required if the data already processed
 text_encoder: ArabicEncoderWithStartSymbol
 text_cleaner: valid_arabic_cleaners # a white list that uses only Arabic letters, punctuations, and a space
 max_len: 600 # sentences larger than this size will not be used
-
+reconcile: true

 max_steps: 2_000_000
 learning_rate: 0.001
data/python/config/test_cbhg.yml
ADDED
@@ -0,0 +1,51 @@
+session_name: base
+
+data_directory: "data"
+data_type: "CA_MSA"
+log_directory: "log_dir"
+load_training_data: true
+load_test_data: false
+load_validation_data: true
+n_training_examples: null # null load all training examples, good for fast loading
+n_test_examples: null # null load all test examples
+n_validation_examples: null # null load all validation examples
+test_file_name: "test.csv"
+is_data_preprocessed: false # The data file is organized as (original text | text | diacritics)
+data_separator: '|' # Required if the data already processed
+diacritics_separator: '*' # Required if the data already processed
+text_encoder: ArabicEncoderWithStartSymbol
+text_cleaner: valid_arabic_cleaners # a white list that uses only Arabic letters, punctuations, and a space
+max_len: 600 # sentences larger than this size will not be used
+reconcile: true
+
+max_steps: 50
+learning_rate: 0.001
+batch_size: 32
+adam_beta1: 0.9
+adam_beta2: 0.999
+use_decay: true
+weight_decay: 0.0
+embedding_dim: 256
+use_prenet: false
+prenet_sizes: [512, 256]
+cbhg_projections: [128, 256]
+cbhg_filters: 16
+cbhg_gru_units: 256
+post_cbhg_layers_units: [256, 256]
+post_cbhg_use_batch_norm: true
+
+use_mixed_precision: false
+optimizer_type: Adam
+device: cuda
+
+# LOGGING
+evaluate_frequency: 5000
+evaluate_with_error_rates_frequency: 5000
+n_predicted_text_tensorboard: 10 # To be written to the tensorboard
+model_save_frequency: 5000
+train_plotting_frequency: 50000000 # No plotting for this model
+n_steps_avg_losses: [100, 500, 1_000, 5_000] # command line display of average loss values for the last n steps
+error_rates_n_batches: 10000 # if calculating error rate is slow, then you can specify the number of batches to be calculated
+
+test_model_path: null # load the last saved model
+train_resume_model_path: null # load last saved model
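The new test_cbhg.yml mirrors cbhg.yml but caps training at max_steps: 50, which makes it suitable for quick smoke-test runs. A minimal sketch of inspecting it, assuming PyYAML and a checkout-relative path:

import yaml

# Sketch only: load the smoke-test config and compare the key settings.
with open("python/config/test_cbhg.yml", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

print(cfg["max_steps"])  # 50 here, versus 2_000_000 in cbhg.yml
print(cfg["max_len"])    # 600
print(cfg["reconcile"])  # True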
data/python/dataset.py
CHANGED
@@ -4,10 +4,13 @@ Loading the diacritization dataset
 
 import os
 
-
+import util.text_cleaners as cleaners
 import pandas as pd
 import torch
 import random
+import warnings
+from diacritization_evaluation import util
+
 from torch.utils.data import DataLoader, Dataset
 
 from config_manager import ConfigManager
@@ -15,7 +18,7 @@ from config_manager import ConfigManager
 
 class DiacritizationDataset(Dataset):
     """
-    The diacritization
+    The datasets for preprocessing for diacritization
     """
 
     def __init__(self, config_manager: ConfigManager, list_ids, data):
@@ -24,6 +27,7 @@ class DiacritizationDataset(Dataset):
         self.data = data
         self.text_encoder = config_manager.text_encoder
         self.config = config_manager.config
+        # print('config:: ', self.config)
 
     def __len__(self):
         "Denotes the total number of samples"
@@ -33,35 +37,22 @@ class DiacritizationDataset(Dataset):
         "Generates one sample of data"
         # Select sample
         id = self.list_ids[index]
-
-
-
-
-
-
-
-
-
-
-        encoding_failed = True
-        while encoding_failed:
-            try:
-                data = self.data[id]
-                data = self.text_encoder.clean(data)
-                text, inputs, diacritics = util.extract_haraqat(data)
-                encoding_failed = False
-            except:
-                print('dataset.py :: error with that data')
-                print('id: ', id)
-                print('data: ', data)
-                # text, inputs, diacritics = util.extract_haraqat(data[0])
-                id = random.randint(0, len(data))
-
-        inputs = torch.Tensor(self.text_encoder.input_to_sequence("".join(inputs)))
-        diacritics = torch.Tensor(self.text_encoder.target_to_sequence(diacritics))
-
-        return inputs, diacritics, text
+        data_orig = self.data[id].strip()
+        text, inputs, diacritics = cleaners.extract_haraqat(
+            self.text_encoder.clean(data_orig))
+
+        inputs = torch.Tensor(
+            self.text_encoder.input_to_sequence("".join(inputs)))
+        diacritics = torch.Tensor(
+            self.text_encoder.target_to_sequence(diacritics))
+
+        return inputs, diacritics, data_orig
 
+        #data = self.data[id]
+        #data = self.text_encoder.clean(data)
+        #text, inputs, diacritics = util.extract_haraqat(data)
+        #inputs = torch.Tensor(self.text_encoder.input_to_sequence("".join(inputs)))
+        #diacritics = torch.Tensor(self.text_encoder.target_to_sequence(diacritics))
 
 def collate_fn(data):
     """
@@ -164,7 +155,8 @@ def load_test_data(config_manager: ConfigManager, loader_parameters):
         config_manager, [idx for idx in range(len(test_data))], test_data
     )
 
-    test_iterator = DataLoader(test_dataset, collate_fn=collate_fn,
+    test_iterator = DataLoader(test_dataset, collate_fn=collate_fn,
+                               **loader_parameters)
 
    print(f"Length of test iterator = {len(test_iterator)}")
    return test_iterator
data/python/diacritization_model_to_onnx.py
CHANGED
@@ -3,21 +3,21 @@ import pickle
 
 import numpy as np
 
-from diacritizer import
+from diacritizer import Diacritizer
 
 
 """
 Key Params:
-max_len:
+max_len:
 is the max length for the arabic strings to be diacritized
-batch size:
+batch size:
 has to do with the model training and usage
 """
-max_len =
+max_len = 200 # 600 for the original length
 batch_size = 32
 
 
-"""
+"""
 example and mock data:
 we found that populating all the data, removing the zeros gives better results.
 """
@@ -25,7 +25,7 @@ src = torch.Tensor([[1 for i in range(max_len)]
                    for i in range(batch_size)]).long()
 lengths = torch.Tensor([max_len for i in range(batch_size)]).long()
 # example data
-batch_data = pickle.load( open('../models-data/
+batch_data = pickle.load( open('../models-data/batch_example_data.pkl', 'rb') )
 
 #target = batch_data['target']
 
@@ -37,7 +37,7 @@ model_kind_str = 'cbhg'
 config_str = 'config/cbhg.yml'
 load_model = True
 
-dia =
+dia = Diacritizer(config_str, model_kind_str, load_model)
 
 # set model to inference mode
 dia.model.to(dia.device)
@@ -58,13 +58,22 @@ import onnxruntime
 onnx_model_filename = '../models-data/diacritization_model.onnx'
 
 
+print(src.shape)
+
+#exit()
 # export model
-torch.onnx.export(dia.model,
-                  (src, lengths),
-                  onnx_model_filename,
-                  verbose=False,
-                  opset_version=11,
-                  input_names=['src', 'lengths']
+torch.onnx.export(dia.model,
+                  (src, lengths),
+                  onnx_model_filename,
+                  verbose=False,
+                  opset_version=11,
+                  input_names=['src', 'lengths'],
+                  output_names=['output'],
+                  dynamic_axes = {'src': [1], #[0,1,2], #[0,1,2],
+                                  #'input_2':{0:'batch'},
+                                  'output': [1]
+                                  })
+
 print('Model printed in rel. path:', onnx_model_filename)
 
 
@@ -94,10 +103,202 @@ ort_inputs = {ort_session.get_inputs()[0].name: src.detach().numpy().astype(np.i
 # run onnx model
 ort_outs = ort_session.run(None, ort_inputs)
 
+print('outs:: ', ort_outs)
 
-for i in range(batch_size):
-    np.testing.assert_allclose(torch_out['diacritics'][i].detach().numpy(), ort_outs[0][i], rtol=1e-03, atol=1e-03)
 
+print('src:: ', src.detach().numpy().astype(np.int64))
+print('lengths: ',lengths.detach().numpy().astype(np.int64))
+
+#exit()
+
+for i in range(batch_size):
+    np.testing.assert_allclose(torch_out['diacritics'][i].detach().numpy(),
+                               ort_outs[0][i], rtol=1e-03, atol=1e-03)
 
 print("\n!!!Exported model has been tested with ONNXRuntime, result looks good within given tolerance!!!")
 
+
+
+vec = [[41, 12, 40] for i in range(batch_size)]
+src = torch.Tensor(vec).long()
+
+lengths = torch.Tensor([3 for i in range(batch_size)]).long()
+
+ort_inputs = {ort_session.get_inputs()[0].name: src.detach().numpy().astype(np.int64),
+              ort_session.get_inputs()[1].name: lengths.detach().numpy().astype(np.int64)}
+
+
+#print('12345678910')
+#print(ort_session.get_inputs()[0].name)
+#print(ort_session.get_inputs()[1].name)
+
+print('run 3')
+ort_outs = ort_session.run(None, ort_inputs)
+print('outs:: ', ort_outs[0].shape)
+print('outs:: ', ort_outs[0][0][0])
+print('outs:: ', ort_outs[0][0][1])
+print('outs:: ', ort_outs[0][0][2])
+
+torch_out = dia.model(src, lengths)
+
+#print(torch_out['diacritics'][0])
+for i in range(batch_size):
+    np.testing.assert_allclose(torch_out['diacritics'][i].detach().numpy(), \
+                               ort_outs[0][i], rtol=1e-03, atol=1e-03)
+
+print('12345678910')
+print(ort_session.get_inputs()[0].name)
+print(ort_session.get_inputs()[1].name)
+
+#exit()
+
+
+
+"""
+Test ONNX model on randomized data
+"""
+
+import random
+test_id = 0
+
+print('***** Test MAX size :: Random Boolean vectors: *****')
+print(max_len)
+
+for test_run in range(3):
+
+    vec = [[random.randint(0,1) for i in range(max_len)]
+           for i in range(batch_size)]
+    src = torch.Tensor(vec).long()
+    lengths = torch.Tensor([max_len for i in range(batch_size)]).long()
+
+    """
+    with open('test_data/test'+str(test_id)+'.txt', 'w') as f:
+        for ll in src.detach().tolist():
+            for item in ll:
+                f.write("%s " % item)
+            f.write("\n")
+        f.close()
+    """
+    torch_out = dia.model(src, lengths)
+    """
+    my_list = torch_out['diacritics'].detach().numpy().tolist()
+    with open('test_data/test'+str(test_id)+'_torch.txt', 'w') as f:
+        for ll in my_list:
+            for item in ll:
+                for l in item:
+                    f.write("%s " % l)
+                f.write("\n")
+        f.close()
+    test_id+=1
+    """
+    # prepare onnx input
+    ort_inputs = {ort_session.get_inputs()[0].name: src.detach().numpy().astype(np.int64),
+                  ort_session.get_inputs()[1].name: lengths.detach().numpy().astype(np.int64)}
+
+    # run onnx model
+    ort_outs = ort_session.run(None, ort_inputs)
+
+    for i in range(batch_size):
+        np.testing.assert_allclose(torch_out['diacritics'][i].detach().numpy(), \
+                                   ort_outs[0][i], rtol=1e-03, atol=1e-03)
+
+    print('test :: ', test_run)
+    print("Result looks good within given tolerance!!!")
+
+
+print('***** Test MAX size :: Random float, vectors within 0:16 *****')
+print(max_len)
+
+for test_run in range(3):
+
+    vec = [[random.randint(0, 17) for i in range(max_len)]
+           for i in range(batch_size)]
+    src = torch.Tensor(vec).long()
+    """
+    with open('test_data/test'+str(test_id)+'.txt', 'w') as f:
+        for ll in src.detach().tolist():
+            for item in ll:
+                f.write("%s " % item)
+            f.write("\n")
+        f.close()
+    """
+    torch_out = dia.model(src, lengths)
+
+    #my_list = torch_out['diacritics'].detach().numpy().tolist()
+    """
+    with open('test_data/test'+str(test_id)+'_torch.txt', 'w') as f:
+        for ll in my_list:
+            for item in ll:
+                for l in item:
+                    f.write("%s " % l)
+                f.write("\n")
+        f.close()
+    test_id+=1
+    """
+    # prepare onnx input
+    ort_inputs = {ort_session.get_inputs()[0].name: src.detach().numpy().astype(np.int64),
+                  ort_session.get_inputs()[1].name: lengths.detach().numpy().astype(np.int64)}
+
+    # run onnx model
+    ort_outs = ort_session.run(None, ort_inputs)
+
+    for i in range(batch_size):
+        np.testing.assert_allclose(torch_out['diacritics'][i].detach().numpy(), \
+                                   ort_outs[0][i], rtol=1, atol=1)
+
+    print('test :: ', test_run)
+    print("Result looks good within given tolerance!!!")
+
+
+print('***** Test Dynamical sizes :: Random Boolean vectors: *****')
+
+for l in [2, 10, 40, 100, 150]:
+
+    print('length:: ', l)
+
+    vec = [[1 for i in range(l)] # random.randint(0,1)
+           for i in range(batch_size)]
+    src = torch.Tensor(vec).long()
+    lengths = torch.Tensor([l for i in range(batch_size)]).long()
+
+    torch_out = dia.model(src, lengths)
+
+    # prepare onnx input
+    ort_inputs = {ort_session.get_inputs()[0].name: src.detach().numpy().astype(np.int64),
+                  ort_session.get_inputs()[1].name: lengths.detach().numpy().astype(np.int64)}
+
+    # run onnx model
+    ort_outs = ort_session.run(None, ort_inputs)
+
+    for i in range(batch_size):
+        np.testing.assert_allclose(torch_out['diacritics'][i].detach().numpy(), \
+                                   ort_outs[0][i], rtol=1e-03, atol=1e-03)
+
+    print('test :: ', l)
+    print("Result looks good within given tolerance!!!")
+
+
+print('***** Test Dynamical sizes :: Random float, vectors within 0:16 *****')
+
+for l in [2, 10, 40, 100, 150]:
+
+    vec = [[random.randint(0, 17) for i in range(l)]
+           for i in range(batch_size)]
+    src = torch.Tensor(vec).long()
+    lengths = torch.Tensor([l for i in range(batch_size)]).long()
+
+    torch_out = dia.model(src, lengths)
+
+    # prepare onnx input
+    ort_inputs = {ort_session.get_inputs()[0].name: src.detach().numpy().astype(np.int64),
+                  ort_session.get_inputs()[1].name: lengths.detach().numpy().astype(np.int64)}
+
+    # run onnx model
+    ort_outs = ort_session.run(None, ort_inputs)
+
+    for i in range(batch_size):
+        np.testing.assert_allclose(torch_out['diacritics'][i].detach().numpy(), \
+                                   ort_outs[0][i], rtol=1, atol=1)
+
+    print('test :: ', l)
+    print("Result looks good within given tolerance!!!")
data/python/diacritizer.py
CHANGED
@@ -1,6 +1,8 @@
 from typing import Dict
 import torch
 import tqdm
+import pandas as pd
+import numpy as np
 from config_manager import ConfigManager
 from dataset import (DiacritizationDataset,
                      collate_fn)
@@ -9,7 +11,6 @@ from torch.utils.data import (DataLoader,
 import util.reconcile_original_plus_diacritized as reconcile
 
 
-
 class Diacritizer:
     def __init__(
         self, config_path: str, model_kind: str, load_model: bool = False
@@ -35,6 +36,7 @@ class Diacritizer:
 
     def diacritize_text(self, text: str):
         # convert string into indices
+        text = text.strip()
         seq = self.text_encoder.input_to_sequence(text)
         # transform indices into "batch data"
         batch_data = {'original': [text],
@@ -48,25 +50,31 @@ class Diacritizer:
         loader_params = {"batch_size": self.config_manager.config["batch_size"],
                          "shuffle": False,
                          "num_workers": 2}
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+        data_tmp = pd.read_csv(path,
+                               encoding="utf-8",
+                               sep=self.config_manager.config["data_separator"],
+                               header=None)
+
+        data = []
+        max_len = self.config_manager.config["max_len"]
+        for txt in [d[0] for d in data_tmp.values.tolist()]:
+            if len(txt) > max_len:
+                txt = txt[:max_len]
+                warnings.warn('Warning: text length cut for sentence: \n'+text)
+            data.append(txt)
+
+        list_ids = [idx for idx in range(len(data))]
+        dataset = DiacritizationDataset(self.config_manager,
+                                        list_ids,
+                                        data)
+
+        data_iterator = DataLoader(dataset,
+                                   collate_fn=collate_fn,
+                                   # **loader_params,
+                                   shuffle=False)
+
+        # print(f"Length of data iterator = {len(data_iterator)}")
         return data_iterator
 
     def diacritize_file(self, path: str):
@@ -75,6 +83,7 @@ class Diacritizer:
         diacritized_data = []
         for batch_inputs in tqdm.tqdm(data_iterator):
 
+            #batch_inputs["original"] = batch_inputs["original"].to(self.device)
             batch_inputs["src"] = batch_inputs["src"].to(self.device)
             batch_inputs["lengths"] = batch_inputs["lengths"].to('cpu')
             batch_inputs["target"] = batch_inputs["target"].to(self.device)
@@ -85,7 +94,7 @@ class Diacritizer:
         return diacritized_data
 
     def diacritize_batch(self, batch):
-        #print('batch: ',batch)
+        # print('batch: ',batch)
         self.model.eval()
         originals = batch['original']
         inputs = batch["src"]
@@ -93,25 +102,20 @@ class Diacritizer:
         outputs = self.model(inputs.to(self.device), lengths.to("cpu"))
         diacritics = outputs["diacritics"]
         predictions = torch.max(diacritics, 2).indices
-        sentences = []
 
+        sentences = []
         for src, prediction, original in zip(inputs, predictions, originals):
             sentence = self.text_encoder.combine_text_and_haraqat(
-
-
-            )
+                list(src.detach().cpu().numpy()),
+                list(prediction.detach().cpu().numpy()))
             # Diacritized strings, sentence have to be "reconciled"
             # with original strings, because the non arabic strings are removed
             # before being processed in nnet
-
+            if self.config['reconcile']:
+                sentence = reconcile.reconcile_strings(original, sentence)
             sentences.append(sentence)
 
         return sentences
 
     def diacritize_iterators(self, iterator):
         pass
-
-""" not needed
-class CBHGDiacritizer(Diacritizer):
-class Seq2SeqDiacritizer(Diacritizer):
-"""