openocr-python 0.0.9__py3-none-any.whl → 0.1.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openocr/__init__.py +35 -1
- openocr/configs/dataset/rec/evaluation.yaml +41 -0
- openocr/configs/dataset/rec/ltb.yaml +9 -0
- openocr/configs/dataset/rec/mjsynth.yaml +11 -0
- openocr/configs/dataset/rec/openvino.yaml +25 -0
- openocr/configs/dataset/rec/ost.yaml +17 -0
- openocr/configs/dataset/rec/synthtext.yaml +7 -0
- openocr/configs/dataset/rec/test.yaml +77 -0
- openocr/configs/dataset/rec/textocr.yaml +13 -0
- openocr/configs/dataset/rec/textocr_horizontal.yaml +13 -0
- openocr/configs/dataset/rec/union14m_b.yaml +47 -0
- openocr/configs/dataset/rec/union14m_l_filtered.yaml +35 -0
- openocr/configs/rec/cmer/cmer.yml +127 -0
- openocr/configs/rec/mdiff4str/svtrv2_mdiffdecoder_base.yml +152 -0
- openocr/configs/rec/mdiff4str/svtrv2_mdiffdecoder_small.yml +152 -0
- openocr/configs/rec/unirec/focalsvtr_ardecoder_unirec.yml +114 -0
- openocr/configs/rec/unirec/opendoc_pipeline.yml +105 -0
- openocr/demo_gradio.py +28 -8
- openocr/demo_opendoc.py +572 -0
- openocr/demo_unirec.py +392 -0
- openocr/opendet/losses/__init__.py +5 -7
- openocr/opendet/preprocess/crop_resize.py +2 -1
- openocr/openocr.py +685 -0
- openocr/openrec/losses/__init__.py +8 -3
- openocr/openrec/losses/cmer_loss.py +12 -0
- openocr/openrec/losses/mdiff_loss.py +11 -0
- openocr/openrec/losses/unirec_loss.py +12 -0
- openocr/openrec/metrics/__init__.py +4 -1
- openocr/openrec/metrics/rec_metric_cmer.py +328 -0
- openocr/openrec/modeling/cmer_modeling/modeling_cmer.py +643 -0
- openocr/openrec/modeling/decoders/__init__.py +1 -0
- openocr/openrec/modeling/decoders/ctc_decoder.py +1 -1
- openocr/openrec/modeling/decoders/dan_decoder.py +4 -4
- openocr/openrec/modeling/decoders/dptr_parseq_clip_b_decoder.py +1563 -1398
- openocr/openrec/modeling/decoders/mdiff_decoder.py +587 -0
- openocr/openrec/modeling/decoders/smtr_decoder.py +99 -48
- openocr/openrec/modeling/unirec_modeling/configuration_unirec.py +166 -0
- openocr/openrec/modeling/unirec_modeling/modeling_unirec.py +433 -0
- openocr/openrec/optimizer/__init__.py +4 -3
- openocr/openrec/optimizer/lr.py +49 -0
- openocr/openrec/postprocess/__init__.py +2 -0
- openocr/openrec/postprocess/abinet_postprocess.py +1 -1
- openocr/openrec/postprocess/ar_postprocess.py +1 -1
- openocr/openrec/postprocess/cmer_postprocess.py +86 -0
- openocr/openrec/postprocess/cppd_postprocess.py +1 -1
- openocr/openrec/postprocess/igtr_postprocess.py +1 -1
- openocr/openrec/postprocess/lister_postprocess.py +1 -1
- openocr/openrec/postprocess/mgp_postprocess.py +1 -1
- openocr/openrec/postprocess/nrtr_postprocess.py +2 -2
- openocr/openrec/postprocess/smtr_postprocess.py +1 -1
- openocr/openrec/postprocess/srn_postprocess.py +1 -1
- openocr/openrec/postprocess/unirec_postprocess.py +58 -0
- openocr/openrec/postprocess/visionlan_postprocess.py +1 -1
- openocr/openrec/preprocess/__init__.py +5 -0
- openocr/openrec/preprocess/ce_label_encode.py +1 -1
- openocr/openrec/preprocess/cmer_label_encode.py +1025 -0
- openocr/openrec/preprocess/ctc_label_encode.py +1 -1
- openocr/openrec/preprocess/dptr_label_encode.py +177 -157
- openocr/openrec/preprocess/igtr_label_encode.py +4 -2
- openocr/openrec/preprocess/mdiff_label_encode.py +312 -0
- openocr/openrec/preprocess/rec_aug.py +128 -2
- openocr/openrec/preprocess/resize.py +57 -0
- openocr/openrec/preprocess/unirec_label_encode.py +62 -0
- openocr/tools/data/__init__.py +78 -55
- openocr/tools/data/cmer_web_dataset.py +310 -0
- openocr/tools/data/native_size_dataset.py +753 -0
- openocr/tools/data/native_size_sampler.py +158 -0
- openocr/tools/data/ratio_dataset_tvresize.py +2 -0
- openocr/tools/data/ratio_sampler.py +2 -1
- openocr/tools/download/download_dataset.py +38 -0
- openocr/tools/download/utils.py +28 -0
- openocr/tools/download_example_images.py +236 -0
- openocr/tools/engine/trainer.py +155 -39
- openocr/tools/eval_rec_all_ch.py +2 -2
- openocr/tools/infer_det.py +20 -2
- openocr/tools/infer_doc.py +898 -0
- openocr/tools/infer_doc_onnx.py +1172 -0
- openocr/tools/infer_e2e.py +27 -10
- openocr/tools/infer_rec.py +64 -15
- openocr/tools/infer_unirec_onnx.py +730 -0
- openocr/tools/to_markdown.py +468 -0
- openocr/tools/utils/ckpt.py +17 -5
- openocr/tools/utils/opendoc_onnx_utils/utils.py +1052 -0
- openocr_python-0.1.0.dev0.dist-info/METADATA +324 -0
- {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/RECORD +89 -45
- {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/WHEEL +1 -1
- openocr_python-0.1.0.dev0.dist-info/entry_points.txt +2 -0
- openocr_python-0.0.9.dist-info/METADATA +0 -149
- /openocr_python-0.0.9.dist-info/LICENCE → /openocr_python-0.1.0.dev0.dist-info/licenses/LICENSE +0 -0
- {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/top_level.txt +0 -0
openocr/tools/engine/trainer.py
CHANGED
|
@@ -4,6 +4,7 @@ import random
|
|
|
4
4
|
import time
|
|
5
5
|
|
|
6
6
|
import numpy as np
|
|
7
|
+
import torch.amp
|
|
7
8
|
from tqdm import tqdm
|
|
8
9
|
|
|
9
10
|
import torch
|
|
@@ -16,6 +17,14 @@ from tools.utils.utility import AverageMeter
|
|
|
16
17
|
|
|
17
18
|
__all__ = ['Trainer']
|
|
18
19
|
|
|
20
|
+
import torch.distributed as dist
|
|
21
|
+
|
|
22
|
+
rank = int(os.environ.get('RANK', 0)) # torchrun 会提供 RANK
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def is_main_process():
|
|
26
|
+
return (not dist.is_available() or not dist.is_initialized() or rank == 0)
|
|
27
|
+
|
|
19
28
|
|
|
20
29
|
def get_parameter_number(model):
|
|
21
30
|
total_num = sum(p.numel() for p in model.parameters())
|
|
@@ -52,9 +61,14 @@ class Trainer(object):
|
|
|
52
61
|
os.makedirs(self.cfg['Global']['output_dir'], exist_ok=True)
|
|
53
62
|
|
|
54
63
|
self.writer = None
|
|
55
|
-
if
|
|
56
|
-
|
|
64
|
+
if is_main_process(
|
|
65
|
+
) and self.cfg['Global']['use_tensorboard'] and 'train' in mode:
|
|
66
|
+
import wandb
|
|
57
67
|
from torch.utils.tensorboard import SummaryWriter
|
|
68
|
+
wandb.init(project='demo-sync-tb',
|
|
69
|
+
name=self.cfg['Global'].get('run_name',
|
|
70
|
+
'log_wandb_openocr'),
|
|
71
|
+
sync_tensorboard=True)
|
|
58
72
|
|
|
59
73
|
self.writer = SummaryWriter(self.cfg['Global']['output_dir'])
|
|
60
74
|
|
|
@@ -74,9 +88,10 @@ class Trainer(object):
|
|
|
74
88
|
# build data loader
|
|
75
89
|
self.train_dataloader = None
|
|
76
90
|
if 'train' in mode:
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
91
|
+
if is_main_process():
|
|
92
|
+
cfg.save(
|
|
93
|
+
os.path.join(self.cfg['Global']['output_dir'],
|
|
94
|
+
'config.yml'), self.cfg)
|
|
80
95
|
self.train_dataloader = build_dataloader(self.cfg,
|
|
81
96
|
'Train',
|
|
82
97
|
self.logger,
|
|
@@ -107,16 +122,23 @@ class Trainer(object):
|
|
|
107
122
|
self.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
|
|
108
123
|
self.model)
|
|
109
124
|
self.logger.info('convert_sync_batchnorm')
|
|
110
|
-
|
|
125
|
+
self.accumulation_steps = self.cfg['Global'].get(
|
|
126
|
+
'accumulation_steps', 1)
|
|
111
127
|
from openrec.optimizer import build_optimizer
|
|
112
128
|
self.optimizer, self.lr_scheduler = None, None
|
|
129
|
+
epochs = self.cfg['Global']['epoch_num']
|
|
130
|
+
try:
|
|
131
|
+
step_each_epoch = len(self.train_dataloader)
|
|
132
|
+
except TypeError:
|
|
133
|
+
# 针对 IterableDataset 的处理
|
|
134
|
+
step_each_epoch = self.cfg['Global'].get('total_iter_steps', 100000)
|
|
113
135
|
if self.train_dataloader is not None:
|
|
114
136
|
# build optim
|
|
115
137
|
self.optimizer, self.lr_scheduler = build_optimizer(
|
|
116
138
|
self.cfg['Optimizer'],
|
|
117
139
|
self.cfg['LRScheduler'],
|
|
118
|
-
epochs=
|
|
119
|
-
step_each_epoch=
|
|
140
|
+
epochs=epochs,
|
|
141
|
+
step_each_epoch=step_each_epoch,
|
|
120
142
|
model=self.model,
|
|
121
143
|
)
|
|
122
144
|
self.grad_clip_val = self.cfg['Global'].get('grad_clip_val', 0)
|
|
@@ -129,7 +151,7 @@ class Trainer(object):
|
|
|
129
151
|
self.model, [self.local_rank], find_unused_parameters=False)
|
|
130
152
|
|
|
131
153
|
# amp
|
|
132
|
-
self.scaler = (torch.
|
|
154
|
+
self.scaler = (torch.amp.GradScaler() if self.cfg['Global'].get(
|
|
133
155
|
'use_amp', False) else None)
|
|
134
156
|
|
|
135
157
|
self.logger.info(
|
|
@@ -146,9 +168,28 @@ class Trainer(object):
|
|
|
146
168
|
self.cfg['PostProcess'], self.cfg['Global'])
|
|
147
169
|
# build model
|
|
148
170
|
# for rec algorithm
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
171
|
+
self.use_transformers = self.cfg['Global'].get('use_transformers',
|
|
172
|
+
False)
|
|
173
|
+
if self.use_transformers:
|
|
174
|
+
if self.cfg['Architecture']['algorithm'] == 'UniRec':
|
|
175
|
+
from openrec.modeling.unirec_modeling.modeling_unirec import UniRecForConditionalGenerationNew
|
|
176
|
+
from openrec.modeling.unirec_modeling.configuration_unirec import UniRecConfig
|
|
177
|
+
cfg_vlm = UniRecConfig.from_pretrained(
|
|
178
|
+
self.cfg['Global']['vlm_ocr_config'])
|
|
179
|
+
cfg_vlm._attn_implementation = 'flash_attention_2'
|
|
180
|
+
# cfg_vlm._attn_implementation = "eager"
|
|
181
|
+
# cfg_vlm._attn_implementation = "sdpa"
|
|
182
|
+
self.model = UniRecForConditionalGenerationNew(config=cfg_vlm)
|
|
183
|
+
elif self.cfg['Architecture']['algorithm'] == 'CMER':
|
|
184
|
+
from openrec.modeling.cmer_modeling.modeling_cmer import CMER, CMERConfig
|
|
185
|
+
cfg_model = CMERConfig(
|
|
186
|
+
self.cfg['Architecture']['vision_config'],
|
|
187
|
+
self.cfg['Architecture']['decoder_config'])
|
|
188
|
+
self.model = CMER(config=cfg_model)
|
|
189
|
+
else:
|
|
190
|
+
char_num = self.post_process_class.get_character_num()
|
|
191
|
+
self.cfg['Architecture']['Decoder']['out_channels'] = char_num
|
|
192
|
+
self.model = build_rec_model(self.cfg['Architecture'])
|
|
152
193
|
# build loss
|
|
153
194
|
self.loss_class = build_rec_loss(self.cfg['Loss'])
|
|
154
195
|
# build metric
|
|
@@ -247,34 +288,88 @@ class Trainer(object):
|
|
|
247
288
|
train_batch_cost = 0.0
|
|
248
289
|
reader_start = time.time()
|
|
249
290
|
eta_meter = AverageMeter()
|
|
250
|
-
|
|
291
|
+
save_iter_step = self.cfg['Global'].get('save_iter_step',
|
|
292
|
+
[10e10, 2000])
|
|
293
|
+
start_save_iter = save_iter_step[0]
|
|
294
|
+
save_iter_step = save_iter_step[1]
|
|
295
|
+
|
|
296
|
+
if self.cfg['Global'].get('resume_from_iter',
|
|
297
|
+
False): # for unirec resume training
|
|
298
|
+
if self.cfg['Global']['checkpoints'] is None:
|
|
299
|
+
raise ValueError(
|
|
300
|
+
'resume_from_iter is True, but checkpoints is None')
|
|
301
|
+
start_epoch = start_epoch - 1
|
|
302
|
+
self.resume_iter = global_step
|
|
303
|
+
iter_model_file_name = os.path.basename(
|
|
304
|
+
self.cfg['Global']['checkpoints'])
|
|
305
|
+
last_whole_epoch_global_step = iter_model_file_name.split('_')[1]
|
|
306
|
+
self.cfg['Train']['sampler'][
|
|
307
|
+
'resume_iter'] = self.resume_iter - last_whole_epoch_global_step
|
|
308
|
+
|
|
309
|
+
last_whole_epoch_global_step = 0
|
|
251
310
|
for epoch in range(start_epoch, epoch_num + 1):
|
|
252
|
-
if self.
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
311
|
+
if not self.cfg['Global'].get('resume_from_iter',
|
|
312
|
+
False): # for unirec resume training
|
|
313
|
+
if 'sampler' in self.cfg['Train']:
|
|
314
|
+
self.cfg['Train']['sampler']['resume_iter'] = 0
|
|
315
|
+
if hasattr(self.train_dataloader, "dataset") and self.train_dataloader.dataset is not None:
|
|
316
|
+
if self.train_dataloader.dataset.need_reset and epoch > 1:
|
|
317
|
+
self.train_dataloader = build_dataloader(self.cfg,
|
|
318
|
+
'Train',
|
|
319
|
+
self.logger,
|
|
320
|
+
epoch=epoch,
|
|
321
|
+
task=self.task)
|
|
258
322
|
|
|
259
323
|
for idx, batch in enumerate(self.train_dataloader):
|
|
324
|
+
if self.cfg['Global'].get('resume_from_iter',
|
|
325
|
+
False): # for unirec resume training
|
|
326
|
+
if global_step != self.resume_iter:
|
|
327
|
+
global_step += 1
|
|
328
|
+
if is_main_process(
|
|
329
|
+
) and global_step % print_batch_step == 0:
|
|
330
|
+
self.logger.info(
|
|
331
|
+
f'skip iter {global_step}, resume from iter {self.resume_iter}'
|
|
332
|
+
)
|
|
333
|
+
continue
|
|
334
|
+
else:
|
|
335
|
+
global_step += 1
|
|
336
|
+
self.cfg['Global']['resume_from_iter'] = False
|
|
337
|
+
self.logger.info(
|
|
338
|
+
f'resume from iter {self.resume_iter}, start training from iter {global_step}'
|
|
339
|
+
)
|
|
340
|
+
continue
|
|
341
|
+
|
|
260
342
|
batch_tensor = [t.to(self.device) for t in batch]
|
|
261
343
|
batch_numpy = [t.numpy() for t in batch]
|
|
262
|
-
self.optimizer.zero_grad()
|
|
263
344
|
train_reader_cost += time.time() - reader_start
|
|
264
345
|
# use amp
|
|
265
346
|
if self.scaler:
|
|
266
|
-
with torch.
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
347
|
+
with torch.amp.autocast(device_type=self.device.type,
|
|
348
|
+
dtype=torch.bfloat16):
|
|
349
|
+
if self.use_transformers:
|
|
350
|
+
inputs = {
|
|
351
|
+
'pixel_values': batch_tensor[0],
|
|
352
|
+
'input_ids': None,
|
|
353
|
+
'attention_mask': None,
|
|
354
|
+
'labels': batch_tensor[1],
|
|
355
|
+
'length': batch_tensor[2]
|
|
356
|
+
}
|
|
357
|
+
preds = self.model(**inputs)
|
|
358
|
+
else:
|
|
359
|
+
preds = self.model(batch_tensor[0],
|
|
360
|
+
data=batch_tensor[1:])
|
|
270
361
|
loss = self.loss_class(preds, batch_tensor)
|
|
362
|
+
loss['loss'] = loss['loss'] / self.accumulation_steps
|
|
271
363
|
self.scaler.scale(loss['loss']).backward()
|
|
272
|
-
if self.
|
|
273
|
-
|
|
274
|
-
self.
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
364
|
+
if (global_step + 1) % self.accumulation_steps == 0:
|
|
365
|
+
if self.grad_clip_val > 0:
|
|
366
|
+
self.scaler.unscale_(self.optimizer)
|
|
367
|
+
torch.nn.utils.clip_grad_norm_(
|
|
368
|
+
self.model.parameters(),
|
|
369
|
+
max_norm=self.grad_clip_val)
|
|
370
|
+
self.scaler.step(self.optimizer)
|
|
371
|
+
self.scaler.update()
|
|
372
|
+
self.optimizer.zero_grad(set_to_none=True)
|
|
278
373
|
else:
|
|
279
374
|
preds = self.model(batch_tensor[0], data=batch_tensor[1:])
|
|
280
375
|
loss = self.loss_class(preds, batch_tensor)
|
|
@@ -300,8 +395,14 @@ class Trainer(object):
|
|
|
300
395
|
global_step += 1
|
|
301
396
|
total_samples += len(batch[0])
|
|
302
397
|
|
|
303
|
-
|
|
398
|
+
try:
|
|
399
|
+
self.lr_scheduler.step()
|
|
400
|
+
except Exception as e:
|
|
401
|
+
self.logger.info(
|
|
402
|
+
f'lr_scheduler step error, {e}, please check your config'
|
|
403
|
+
)
|
|
304
404
|
|
|
405
|
+
loss['loss'] = loss['loss'] * self.accumulation_steps
|
|
305
406
|
# logger
|
|
306
407
|
stats = {
|
|
307
408
|
k: float(v)
|
|
@@ -315,8 +416,9 @@ class Trainer(object):
|
|
|
315
416
|
for k, v in train_stats.get().items():
|
|
316
417
|
self.writer.add_scalar(f'TRAIN/{k}', v, global_step)
|
|
317
418
|
|
|
318
|
-
if
|
|
319
|
-
(global_step > 0 and global_step % print_batch_step == 0)
|
|
419
|
+
if is_main_process() and (
|
|
420
|
+
(global_step > 0 and global_step % print_batch_step == 0)
|
|
421
|
+
or (idx >= len(self.train_dataloader) - 1)):
|
|
320
422
|
logs = train_stats.log()
|
|
321
423
|
|
|
322
424
|
eta_sec = (
|
|
@@ -337,16 +439,31 @@ class Trainer(object):
|
|
|
337
439
|
train_batch_cost = 0.0
|
|
338
440
|
reader_start = time.time()
|
|
339
441
|
# eval iter step
|
|
340
|
-
if (global_step > start_eval_step and
|
|
341
|
-
|
|
442
|
+
if is_main_process() and (global_step > start_eval_step and
|
|
443
|
+
(global_step - start_eval_step) %
|
|
444
|
+
eval_batch_step == 0):
|
|
342
445
|
self.eval_step(global_step, epoch)
|
|
446
|
+
# save iter step
|
|
447
|
+
if is_main_process(
|
|
448
|
+
) and global_step > start_save_iter and global_step % save_iter_step == 0:
|
|
449
|
+
save_ckpt(
|
|
450
|
+
self.model,
|
|
451
|
+
self.cfg,
|
|
452
|
+
self.optimizer,
|
|
453
|
+
self.lr_scheduler,
|
|
454
|
+
epoch,
|
|
455
|
+
global_step,
|
|
456
|
+
self.best_metric,
|
|
457
|
+
is_best=False,
|
|
458
|
+
prefix=
|
|
459
|
+
f'iter_{last_whole_epoch_global_step}_{global_step}')
|
|
343
460
|
|
|
344
461
|
# eval epoch step
|
|
345
|
-
if
|
|
462
|
+
if is_main_process() and epoch > start_eval_epoch and (
|
|
346
463
|
epoch - start_eval_epoch) % eval_epoch_step == 0:
|
|
347
464
|
self.eval_step(global_step, epoch)
|
|
348
465
|
|
|
349
|
-
if
|
|
466
|
+
if is_main_process():
|
|
350
467
|
save_ckpt(self.model,
|
|
351
468
|
self.cfg,
|
|
352
469
|
self.optimizer,
|
|
@@ -367,14 +484,13 @@ class Trainer(object):
|
|
|
367
484
|
self.best_metric,
|
|
368
485
|
is_best=False,
|
|
369
486
|
prefix='epoch_' + str(epoch))
|
|
370
|
-
|
|
487
|
+
last_whole_epoch_global_step = global_step
|
|
371
488
|
best_str = f"best metric, {', '.join(['{}: {}'.format(k, v) for k, v in self.best_metric.items()])}"
|
|
372
489
|
self.logger.info(best_str)
|
|
373
490
|
if self.writer is not None:
|
|
374
491
|
self.writer.close()
|
|
375
492
|
if torch.cuda.device_count() > 1:
|
|
376
493
|
torch.distributed.barrier()
|
|
377
|
-
torch.distributed.destroy_process_group()
|
|
378
494
|
|
|
379
495
|
def eval_step(self, global_step, epoch):
|
|
380
496
|
cur_metric = self.eval()
|
openocr/tools/eval_rec_all_ch.py
CHANGED
|
@@ -95,7 +95,7 @@ def main():
|
|
|
95
95
|
acc_each_ignore_space_symbol.append(
|
|
96
96
|
metric['acc_ignore_space_symbol'] * 100)
|
|
97
97
|
acc_each_lower_ignore_space_symbol.append(
|
|
98
|
-
metric['
|
|
98
|
+
metric['acc_ignore_space_lower_symbol'] * 100)
|
|
99
99
|
acc_each_dis.append(metric['norm_edit_dis'])
|
|
100
100
|
acc_each_num.append(metric['num_samples'])
|
|
101
101
|
|
|
@@ -148,7 +148,7 @@ def main():
|
|
|
148
148
|
] + [avg1.sum().tolist()])
|
|
149
149
|
avg1 = np.array(acc_each_lower_ignore_space_symbol) * np.array(
|
|
150
150
|
acc_each_num) / sum(acc_each_num)
|
|
151
|
-
csv_w.writerow(['
|
|
151
|
+
csv_w.writerow(['acc_ignore_space_lower_symbol'] +
|
|
152
152
|
acc_each_lower_ignore_space_symbol + [
|
|
153
153
|
sum(acc_each_lower_ignore_space_symbol) /
|
|
154
154
|
len(acc_each_lower_ignore_space_symbol)
|
openocr/tools/infer_det.py
CHANGED
|
@@ -123,18 +123,36 @@ class OpenDetector(object):
|
|
|
123
123
|
config=None,
|
|
124
124
|
backend='torch',
|
|
125
125
|
onnx_model_path=None,
|
|
126
|
+
use_gpu='auto',
|
|
126
127
|
numId=0):
|
|
127
128
|
"""
|
|
128
129
|
Args:
|
|
129
130
|
config (dict, optional): 配置信息。默认为None。
|
|
130
131
|
backend (str): 'torch' 或 'onnx'
|
|
131
132
|
onnx_model_path (str): ONNX模型路径(仅当backend='onnx'时需要)
|
|
133
|
+
use_gpu (str, optional): GPU使用策略,可选值为'auto'/'true'/'false'。默认为'auto'。
|
|
132
134
|
numId (int, optional): 设备编号。默认为0。
|
|
133
135
|
"""
|
|
134
136
|
|
|
135
137
|
if config is None:
|
|
136
138
|
config = Config(DEFAULT_CFG_PATH_DET).cfg
|
|
137
139
|
|
|
140
|
+
# Parse use_gpu parameter
|
|
141
|
+
if use_gpu == 'auto':
|
|
142
|
+
try:
|
|
143
|
+
import torch
|
|
144
|
+
device = 'gpu' if torch.cuda.is_available() else 'cpu'
|
|
145
|
+
except:
|
|
146
|
+
device = 'cpu'
|
|
147
|
+
elif use_gpu == 'true':
|
|
148
|
+
device = 'gpu'
|
|
149
|
+
elif use_gpu == 'false':
|
|
150
|
+
device = 'cpu'
|
|
151
|
+
else:
|
|
152
|
+
raise ValueError(f"use_gpu must be 'auto', 'true', or 'false', got '{use_gpu}'")
|
|
153
|
+
|
|
154
|
+
config['Global']['device'] = device
|
|
155
|
+
|
|
138
156
|
self._init_common(config)
|
|
139
157
|
backend = backend if config['Global'].get(
|
|
140
158
|
'backend', None) is None else config['Global']['backend']
|
|
@@ -160,7 +178,7 @@ class OpenDetector(object):
|
|
|
160
178
|
else:
|
|
161
179
|
raise ValueError('ONNX模式需要指定onnx_model_path参数')
|
|
162
180
|
self.onnx_det_engine = ONNXEngine(
|
|
163
|
-
onnx_model_path, use_gpu=
|
|
181
|
+
onnx_model_path, use_gpu=(device == 'gpu'))
|
|
164
182
|
else:
|
|
165
183
|
raise ValueError("backend参数必须是'torch'或'onnx'")
|
|
166
184
|
|
|
@@ -269,7 +287,7 @@ class OpenDetector(object):
|
|
|
269
287
|
|
|
270
288
|
info = {'boxes': post_result[0]['points'], 'elapse': t_cost}
|
|
271
289
|
if return_mask:
|
|
272
|
-
if isinstance(preds['maps'], self.torch.Tensor):
|
|
290
|
+
if self.backend == 'torch' and isinstance(preds['maps'], self.torch.Tensor):
|
|
273
291
|
mask = preds['maps'].detach().cpu().numpy()
|
|
274
292
|
else:
|
|
275
293
|
mask = preds['maps']
|