rababa 0.1.0 → 0.1.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. checksums.yaml +4 -4
  2. data/.github/workflows/python.yml +81 -0
  3. data/.github/workflows/release.yml +36 -0
  4. data/.github/workflows/ruby.yml +27 -0
  5. data/.gitignore +3 -0
  6. data/.rubocop.yml +1 -1
  7. data/CODE_OF_CONDUCT.md +13 -13
  8. data/README.adoc +80 -0
  9. data/Rakefile +1 -1
  10. data/docs/{research-arabic-diacritization-06-2021.md → research-arabic-diacritization-06-2021.adoc} +52 -37
  11. data/exe/rababa +1 -1
  12. data/lib/README.adoc +95 -0
  13. data/lib/rababa/diacritizer.rb +16 -8
  14. data/lib/rababa/encoders.rb +2 -2
  15. data/lib/rababa/harakats.rb +1 -1
  16. data/lib/rababa/reconcile.rb +1 -33
  17. data/lib/rababa/version.rb +1 -1
  18. data/models-data/README.adoc +6 -0
  19. data/python/README.adoc +211 -0
  20. data/python/config/cbhg.yml +1 -1
  21. data/python/config/test_cbhg.yml +51 -0
  22. data/python/dataset.py +23 -31
  23. data/python/diacritization_model_to_onnx.py +216 -15
  24. data/python/diacritizer.py +35 -31
  25. data/python/log_dir/CA_MSA.base.cbhg/models/README.adoc +2 -0
  26. data/python/log_dir/README.adoc +1 -0
  27. data/python/{requirement.txt → requirements.txt} +1 -1
  28. data/python/setup.py +32 -0
  29. data/python/trainer.py +10 -4
  30. data/python/util/reconcile_original_plus_diacritized.py +2 -0
  31. data/python/util/text_cleaners.py +59 -4
  32. data/rababa.gemspec +1 -1
  33. data/test-datasets/data-arabic-pointing/{Readme.md → README.adoc} +2 -1
  34. metadata +22 -18
  35. data/.github/workflows/main.yml +0 -18
  36. data/README.md +0 -73
  37. data/lib/README.md +0 -82
  38. data/models-data/README.md +0 -6
  39. data/python/README.md +0 -163
  40. data/python/log_dir/CA_MSA.base.cbhg/models/Readme.md +0 -2
  41. data/python/log_dir/README.md +0 -1
data/python/config/cbhg.yml CHANGED
@@ -16,7 +16,7 @@ diacritics_separator: '*' # Required if the data already processed
  text_encoder: ArabicEncoderWithStartSymbol
  text_cleaner: valid_arabic_cleaners # a white list that uses only Arabic letters, punctuations, and a space
  max_len: 600 # sentences larger than this size will not be used
-
+ reconcile: true
 
  max_steps: 2_000_000
  learning_rate: 0.001
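
The `reconcile` flag introduced here is consumed in data/python/diacritizer.py further down in this diff. As a hedged illustration only (the `postprocess` wrapper is hypothetical; `reconcile_strings` is the helper from python/util/reconcile_original_plus_diacritized.py):

    from util.reconcile_original_plus_diacritized import reconcile_strings

    def postprocess(original: str, diacritized: str, config: dict) -> str:
        # Re-insert the non-Arabic characters stripped before inference,
        # but only when the config enables reconciliation.
        if config.get("reconcile", False):
            return reconcile_strings(original, diacritized)
        return diacritized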
data/python/config/test_cbhg.yml ADDED
@@ -0,0 +1,51 @@
+ session_name: base
+
+ data_directory: "data"
+ data_type: "CA_MSA"
+ log_directory: "log_dir"
+ load_training_data: true
+ load_test_data: false
+ load_validation_data: true
+ n_training_examples: null # null load all training examples, good for fast loading
+ n_test_examples: null # null load all test examples
+ n_validation_examples: null # null load all validation examples
+ test_file_name: "test.csv"
+ is_data_preprocessed: false # The data file is organized as (original text | text | diacritics)
+ data_separator: '|' # Required if the data already processed
+ diacritics_separator: '*' # Required if the data already processed
+ text_encoder: ArabicEncoderWithStartSymbol
+ text_cleaner: valid_arabic_cleaners # a white list that uses only Arabic letters, punctuations, and a space
+ max_len: 600 # sentences larger than this size will not be used
+ reconcile: true
+
+ max_steps: 50
+ learning_rate: 0.001
+ batch_size: 32
+ adam_beta1: 0.9
+ adam_beta2: 0.999
+ use_decay: true
+ weight_decay: 0.0
+ embedding_dim: 256
+ use_prenet: false
+ prenet_sizes: [512, 256]
+ cbhg_projections: [128, 256]
+ cbhg_filters: 16
+ cbhg_gru_units: 256
+ post_cbhg_layers_units: [256, 256]
+ post_cbhg_use_batch_norm: true
+
+ use_mixed_precision: false
+ optimizer_type: Adam
+ device: cuda
+
+ # LOGGING
+ evaluate_frequency: 5000
+ evaluate_with_error_rates_frequency: 5000
+ n_predicted_text_tensorboard: 10 # To be written to the tensorboard
+ model_save_frequency: 5000
+ train_plotting_frequency: 50000000 # No plotting for this model
+ n_steps_avg_losses: [100, 500, 1_000, 5_000] # command line display of average loss values for the last n steps
+ error_rates_n_batches: 10000 # if calculating error rate is slow, then you can specify the number of batches to be calculated
+
+ test_model_path: null # load the last saved model
+ train_resume_model_path: null # load last saved model
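
For reference, a minimal sketch of reading this new test configuration outside the package's ConfigManager; PyYAML and the relative path are assumptions, not part of the diff:

    # Hedged sketch: load data/python/config/test_cbhg.yml with PyYAML and check
    # the fields this release introduces; the "cpu" override is a hypothetical CI tweak.
    import yaml

    with open("python/config/test_cbhg.yml", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    assert config["max_steps"] == 50      # short smoke-training run
    assert config["reconcile"] is True    # flag added in 0.1.1
    config["device"] = "cpu"              # hypothetical override for machines without CUDA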
data/python/dataset.py CHANGED
@@ -4,10 +4,13 @@ Loading the diacritization dataset
 
  import os
 
- from diacritization_evaluation import util
+ import util.text_cleaners as cleaners
  import pandas as pd
  import torch
  import random
+ import warnings
+ from diacritization_evaluation import util
+
 
  from torch.utils.data import DataLoader, Dataset
 
  from config_manager import ConfigManager
@@ -15,7 +18,7 @@ from config_manager import ConfigManager
 
  class DiacritizationDataset(Dataset):
      """
-     The diacritization dataset
+     The datasets for preprocessing for diacritization
      """
 
      def __init__(self, config_manager: ConfigManager, list_ids, data):
@@ -24,6 +27,7 @@ class DiacritizationDataset(Dataset):
          self.data = data
          self.text_encoder = config_manager.text_encoder
          self.config = config_manager.config
+         # print('config:: ', self.config)
 
      def __len__(self):
          "Denotes the total number of samples"
@@ -33,35 +37,22 @@ class DiacritizationDataset(Dataset):
          "Generates one sample of data"
          # Select sample
          id = self.list_ids[index]
-         if self.config["is_data_preprocessed"]:
-             data = self.data.iloc[id]
-             inputs = torch.Tensor(self.text_encoder.input_to_sequence(data[1]))
-             targets = torch.Tensor(
-                 self.text_encoder.target_to_sequence(
-                     data[2].split(self.config["diacritics_separator"])
-                 )
-             )
-             return inputs, targets, data[0]
-
-         encoding_failed = True
-         while encoding_failed:
-             try:
-                 data = self.data[id]
-                 data = self.text_encoder.clean(data)
-                 text, inputs, diacritics = util.extract_haraqat(data)
-                 encoding_failed = False
-             except:
-                 print('dataset.py :: error with that data')
-                 print('id: ', id)
-                 print('data: ', data)
-                 # text, inputs, diacritics = util.extract_haraqat(data[0])
-                 id = random.randint(0, len(data))
-
-         inputs = torch.Tensor(self.text_encoder.input_to_sequence("".join(inputs)))
-         diacritics = torch.Tensor(self.text_encoder.target_to_sequence(diacritics))
-
-         return inputs, diacritics, text
+         data_orig = self.data[id].strip()
+         text, inputs, diacritics = cleaners.extract_haraqat(
+             self.text_encoder.clean(data_orig))
+
+         inputs = torch.Tensor(
+             self.text_encoder.input_to_sequence("".join(inputs)))
+         diacritics = torch.Tensor(
+             self.text_encoder.target_to_sequence(diacritics))
+
+         return inputs, diacritics, data_orig
 
+         #data = self.data[id]
+         #data = self.text_encoder.clean(data)
+         #text, inputs, diacritics = util.extract_haraqat(data)
+         #inputs = torch.Tensor(self.text_encoder.input_to_sequence("".join(inputs)))
+         #diacritics = torch.Tensor(self.text_encoder.target_to_sequence(diacritics))
 
  def collate_fn(data):
      """
@@ -164,7 +155,8 @@ def load_test_data(config_manager: ConfigManager, loader_parameters):
          config_manager, [idx for idx in range(len(test_data))], test_data
      )
 
-     test_iterator = DataLoader(test_dataset, collate_fn=collate_fn, **loader_parameters)
+     test_iterator = DataLoader(test_dataset, collate_fn=collate_fn,
+                                **loader_parameters)
 
      print(f"Length of test iterator = {len(test_iterator)}")
      return test_iterator
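
The rewritten __getitem__ now always treats samples as raw text, cleaning them and extracting haraqat on the fly. A hedged usage sketch mirroring load_test_data above; the config path, data path, and ConfigManager constructor arguments are assumptions:

    # Build a DiacritizationDataset from plain-text sentences and batch it with collate_fn.
    from torch.utils.data import DataLoader
    from config_manager import ConfigManager
    from dataset import DiacritizationDataset, collate_fn

    config_manager = ConfigManager("config/test_cbhg.yml", "cbhg")  # assumed constructor arguments
    with open("data/CA_MSA/test.csv", encoding="utf8") as f:        # illustrative data path
        sentences = [line for line in f
                     if len(line) <= config_manager.config["max_len"]]

    dataset = DiacritizationDataset(config_manager, list(range(len(sentences))), sentences)
    iterator = DataLoader(dataset, collate_fn=collate_fn, batch_size=32, shuffle=False)
    print(f"Length of test iterator = {len(iterator)}")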
data/python/diacritization_model_to_onnx.py CHANGED
@@ -3,21 +3,21 @@ import pickle
 
  import numpy as np
 
- from diacritizer import CBHGDiacritizer
+ from diacritizer import Diacritizer
 
 
  """
  Key Params:
-     max_len: 
+     max_len:
          is the max length for the arabic strings to be diacritized
-     batch size: 
+     batch size:
          has to do with the model training and usage
  """
- max_len = 300 # 600 for the original length
+ max_len = 200 # 600 for the original length
  batch_size = 32
 
 
- """ 
+ """
  example and mock data:
  we found that populating all the data, removing the zeros gives better results.
  """
@@ -25,7 +25,7 @@ src = torch.Tensor([[1 for i in range(max_len)]
                      for i in range(batch_size)]).long()
  lengths = torch.Tensor([max_len for i in range(batch_size)]).long()
  # example data
- batch_data = pickle.load( open('../models-data/batch_data.pkl', 'rb') )
+ batch_data = pickle.load( open('../models-data/batch_example_data.pkl', 'rb') )
 
  #target = batch_data['target']
 
@@ -37,7 +37,7 @@ model_kind_str = 'cbhg'
  config_str = 'config/cbhg.yml'
  load_model = True
 
- dia = CBHGDiacritizer(config_str, model_kind_str, load_model)
+ dia = Diacritizer(config_str, model_kind_str, load_model)
 
  # set model to inference mode
  dia.model.to(dia.device)
@@ -58,13 +58,22 @@ import onnxruntime
  onnx_model_filename = '../models-data/diacritization_model.onnx'
 
 
+ print(src.shape)
+
+ #exit()
  # export model
- torch.onnx.export(dia.model,
-                   (src, lengths),
-                   onnx_model_filename,
-                   verbose=False,
-                   opset_version=11,
-                   input_names=['src', 'lengths'])
+ torch.onnx.export(dia.model,
+                   (src, lengths),
+                   onnx_model_filename,
+                   verbose=False,
+                   opset_version=11,
+                   input_names=['src', 'lengths'],
+                   output_names=['output'],
+                   dynamic_axes = {'src': [1], #[0,1,2], #[0,1,2],
+                                   #'input_2':{0:'batch'},
+                                   'output': [1]
+                                   })
+
  print('Model printed in rel. path:', onnx_model_filename)
 
 
@@ -94,10 +103,202 @@ ort_inputs = {ort_session.get_inputs()[0].name: src.detach().numpy().astype(np.i
  # run onnx model
  ort_outs = ort_session.run(None, ort_inputs)
 
+ print('outs:: ', ort_outs)
 
- for i in range(batch_size):
-     np.testing.assert_allclose(torch_out['diacritics'][i].detach().numpy(), ort_outs[0][i], rtol=1e-03, atol=1e-03)
 
+ print('src:: ', src.detach().numpy().astype(np.int64))
+ print('lengths: ',lengths.detach().numpy().astype(np.int64))
+
+ #exit()
+
+ for i in range(batch_size):
+     np.testing.assert_allclose(torch_out['diacritics'][i].detach().numpy(),
+                                ort_outs[0][i], rtol=1e-03, atol=1e-03)
 
  print("\n!!!Exported model has been tested with ONNXRuntime, result looks good within given tolerance!!!")
 
+
+
+ vec = [[41, 12, 40] for i in range(batch_size)]
+ src = torch.Tensor(vec).long()
+
+ lengths = torch.Tensor([3 for i in range(batch_size)]).long()
+
+ ort_inputs = {ort_session.get_inputs()[0].name: src.detach().numpy().astype(np.int64),
+               ort_session.get_inputs()[1].name: lengths.detach().numpy().astype(np.int64)}
+
+
+ #print('12345678910')
+ #print(ort_session.get_inputs()[0].name)
+ #print(ort_session.get_inputs()[1].name)
+
+ print('run 3')
+ ort_outs = ort_session.run(None, ort_inputs)
+ print('outs:: ', ort_outs[0].shape)
+ print('outs:: ', ort_outs[0][0][0])
+ print('outs:: ', ort_outs[0][0][1])
+ print('outs:: ', ort_outs[0][0][2])
+
+ torch_out = dia.model(src, lengths)
+
+ #print(torch_out['diacritics'][0])
+ for i in range(batch_size):
+     np.testing.assert_allclose(torch_out['diacritics'][i].detach().numpy(), \
+                                ort_outs[0][i], rtol=1e-03, atol=1e-03)
+
+ print('12345678910')
+ print(ort_session.get_inputs()[0].name)
+ print(ort_session.get_inputs()[1].name)
+
+ #exit()
+
+
+
+ """
+ Test ONNX model on randomized data
+ """
+
+ import random
+ test_id = 0
+
+ print('***** Test MAX size :: Random Boolean vectors: *****')
+ print(max_len)
+
+ for test_run in range(3):
+
+     vec = [[random.randint(0,1) for i in range(max_len)]
+            for i in range(batch_size)]
+     src = torch.Tensor(vec).long()
+     lengths = torch.Tensor([max_len for i in range(batch_size)]).long()
+
+     """
+     with open('test_data/test'+str(test_id)+'.txt', 'w') as f:
+         for ll in src.detach().tolist():
+             for item in ll:
+                 f.write("%s " % item)
+             f.write("\n")
+         f.close()
+     """
+     torch_out = dia.model(src, lengths)
+     """
+     my_list = torch_out['diacritics'].detach().numpy().tolist()
+     with open('test_data/test'+str(test_id)+'_torch.txt', 'w') as f:
+         for ll in my_list:
+             for item in ll:
+                 for l in item:
+                     f.write("%s " % l)
+                 f.write("\n")
+         f.close()
+     test_id+=1
+     """
+     # prepare onnx input
+     ort_inputs = {ort_session.get_inputs()[0].name: src.detach().numpy().astype(np.int64),
+                   ort_session.get_inputs()[1].name: lengths.detach().numpy().astype(np.int64)}
+
+     # run onnx model
+     ort_outs = ort_session.run(None, ort_inputs)
+
+     for i in range(batch_size):
+         np.testing.assert_allclose(torch_out['diacritics'][i].detach().numpy(), \
+                                    ort_outs[0][i], rtol=1e-03, atol=1e-03)
+
+     print('test :: ', test_run)
+     print("Result looks good within given tolerance!!!")
+
+
+ print('***** Test MAX size :: Random float, vectors within 0:16 *****')
+ print(max_len)
+
+ for test_run in range(3):
+
+     vec = [[random.randint(0, 17) for i in range(max_len)]
+            for i in range(batch_size)]
+     src = torch.Tensor(vec).long()
+     """
+     with open('test_data/test'+str(test_id)+'.txt', 'w') as f:
+         for ll in src.detach().tolist():
+             for item in ll:
+                 f.write("%s " % item)
+             f.write("\n")
+         f.close()
+     """
+     torch_out = dia.model(src, lengths)
+
+     #my_list = torch_out['diacritics'].detach().numpy().tolist()
+     """
+     with open('test_data/test'+str(test_id)+'_torch.txt', 'w') as f:
+         for ll in my_list:
+             for item in ll:
+                 for l in item:
+                     f.write("%s " % l)
+                 f.write("\n")
+         f.close()
+     test_id+=1
+     """
+     # prepare onnx input
+     ort_inputs = {ort_session.get_inputs()[0].name: src.detach().numpy().astype(np.int64),
+                   ort_session.get_inputs()[1].name: lengths.detach().numpy().astype(np.int64)}
+
+     # run onnx model
+     ort_outs = ort_session.run(None, ort_inputs)
+
+     for i in range(batch_size):
+         np.testing.assert_allclose(torch_out['diacritics'][i].detach().numpy(), \
+                                    ort_outs[0][i], rtol=1, atol=1)
+
+     print('test :: ', test_run)
+     print("Result looks good within given tolerance!!!")
+
+
+ print('***** Test Dynamical sizes :: Random Boolean vectors: *****')
+
+ for l in [2, 10, 40, 100, 150]:
+
+     print('length:: ', l)
+
+     vec = [[1 for i in range(l)] # random.randint(0,1)
+            for i in range(batch_size)]
+     src = torch.Tensor(vec).long()
+     lengths = torch.Tensor([l for i in range(batch_size)]).long()
+
+     torch_out = dia.model(src, lengths)
+
+     # prepare onnx input
+     ort_inputs = {ort_session.get_inputs()[0].name: src.detach().numpy().astype(np.int64),
+                   ort_session.get_inputs()[1].name: lengths.detach().numpy().astype(np.int64)}
+
+     # run onnx model
+     ort_outs = ort_session.run(None, ort_inputs)
+
+     for i in range(batch_size):
+         np.testing.assert_allclose(torch_out['diacritics'][i].detach().numpy(), \
+                                    ort_outs[0][i], rtol=1e-03, atol=1e-03)
+
+     print('test :: ', l)
+     print("Result looks good within given tolerance!!!")
+
+
+ print('***** Test Dynamical sizes :: Random float, vectors within 0:16 *****')
+
+ for l in [2, 10, 40, 100, 150]:
+
+     vec = [[random.randint(0, 17) for i in range(l)]
+            for i in range(batch_size)]
+     src = torch.Tensor(vec).long()
+     lengths = torch.Tensor([l for i in range(batch_size)]).long()
+
+     torch_out = dia.model(src, lengths)
+
+     # prepare onnx input
+     ort_inputs = {ort_session.get_inputs()[0].name: src.detach().numpy().astype(np.int64),
+                   ort_session.get_inputs()[1].name: lengths.detach().numpy().astype(np.int64)}
+
+     # run onnx model
+     ort_outs = ort_session.run(None, ort_inputs)
+
+     for i in range(batch_size):
+         np.testing.assert_allclose(torch_out['diacritics'][i].detach().numpy(), \
+                                    ort_outs[0][i], rtol=1, atol=1)
+
+     print('test :: ', l)
+     print("Result looks good within given tolerance!!!")
data/python/diacritizer.py CHANGED
@@ -1,6 +1,8 @@
  from typing import Dict
  import torch
  import tqdm
+ import pandas as pd
+ import numpy as np
  from config_manager import ConfigManager
  from dataset import (DiacritizationDataset,
                       collate_fn)
@@ -9,7 +11,6 @@ from torch.utils.data import (DataLoader,
  import util.reconcile_original_plus_diacritized as reconcile
 
 
-
  class Diacritizer:
      def __init__(
          self, config_path: str, model_kind: str, load_model: bool = False
@@ -35,6 +36,7 @@ class Diacritizer:
 
      def diacritize_text(self, text: str):
          # convert string into indices
+         text = text.strip()
          seq = self.text_encoder.input_to_sequence(text)
          # transform indices into "batch data"
          batch_data = {'original': [text],
@@ -48,25 +50,31 @@ class Diacritizer:
          loader_params = {"batch_size": self.config_manager.config["batch_size"],
                           "shuffle": False,
                           "num_workers": 2}
-         # data processed or not, specs in config file
-         if self.config_manager.config["is_data_preprocessed"]:
-             data = pd.read_csv(path,
-                                encoding="utf-8",
-                                sep=self.config_manager.config["data_separator"],
-                                nrows=self.config_manager.config["n_validation_examples"],
-                                header=None)
-
-             # data = data[data[0] <= config_manager.config["max_len"]]
-             dataset = DiacritizationDataset(self.config_manager, data.index, data)
-         else:
-             with open(path, encoding="utf8") as file:
-                 data = file.readlines()
-             data = [text for text in data if len(text) <= self.config_manager.config["max_len"]]
-             dataset = DiacritizationDataset(
-                 self.config_manager, [idx for idx in range(len(data))], data)
-
-         data_iterator = DataLoader(dataset, collate_fn=collate_fn, **loader_params)
-         # print(f"Length of data iterator = {len(valid_iterator)}")
+
+         data_tmp = pd.read_csv(path,
+                                encoding="utf-8",
+                                sep=self.config_manager.config["data_separator"],
+                                header=None)
+
+         data = []
+         max_len = self.config_manager.config["max_len"]
+         for txt in [d[0] for d in data_tmp.values.tolist()]:
+             if len(txt) > max_len:
+                 txt = txt[:max_len]
+                 warnings.warn('Warning: text length cut for sentence: \n'+text)
+             data.append(txt)
+
+         list_ids = [idx for idx in range(len(data))]
+         dataset = DiacritizationDataset(self.config_manager,
+                                         list_ids,
+                                         data)
+
+         data_iterator = DataLoader(dataset,
+                                    collate_fn=collate_fn,
+                                    # **loader_params,
+                                    shuffle=False)
+
+         # print(f"Length of data iterator = {len(data_iterator)}")
          return data_iterator
 
      def diacritize_file(self, path: str):
@@ -75,6 +83,7 @@ class Diacritizer:
          diacritized_data = []
          for batch_inputs in tqdm.tqdm(data_iterator):
 
+             #batch_inputs["original"] = batch_inputs["original"].to(self.device)
              batch_inputs["src"] = batch_inputs["src"].to(self.device)
              batch_inputs["lengths"] = batch_inputs["lengths"].to('cpu')
              batch_inputs["target"] = batch_inputs["target"].to(self.device)
@@ -85,7 +94,7 @@ class Diacritizer:
          return diacritized_data
 
      def diacritize_batch(self, batch):
-         #print('batch: ',batch)
+         # print('batch: ',batch)
          self.model.eval()
          originals = batch['original']
          inputs = batch["src"]
@@ -93,25 +102,20 @@ class Diacritizer:
          outputs = self.model(inputs.to(self.device), lengths.to("cpu"))
          diacritics = outputs["diacritics"]
          predictions = torch.max(diacritics, 2).indices
-         sentences = []
 
+         sentences = []
          for src, prediction, original in zip(inputs, predictions, originals):
              sentence = self.text_encoder.combine_text_and_haraqat(
-                 list(src.detach().cpu().numpy()),
-                 list(prediction.detach().cpu().numpy()),
-             )
+                 list(src.detach().cpu().numpy()),
+                 list(prediction.detach().cpu().numpy()))
              # Diacritized strings, sentence have to be "reconciled"
              # with original strings, because the non arabic strings are removed
              # before being processed in nnet
-             sentence = reconcile.reconcile_strings(original, sentence)
+             if self.config['reconcile']:
+                 sentence = reconcile.reconcile_strings(original, sentence)
              sentences.append(sentence)
 
          return sentences
 
      def diacritize_iterators(self, iterator):
          pass
-
-     """ not needed
-     class CBHGDiacritizer(Diacritizer):
-     class Seq2SeqDiacritizer(Diacritizer):
-     """