openocr-python 0.0.9__py3-none-any.whl → 0.1.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openocr/__init__.py +35 -1
- openocr/configs/dataset/rec/evaluation.yaml +41 -0
- openocr/configs/dataset/rec/ltb.yaml +9 -0
- openocr/configs/dataset/rec/mjsynth.yaml +11 -0
- openocr/configs/dataset/rec/openvino.yaml +25 -0
- openocr/configs/dataset/rec/ost.yaml +17 -0
- openocr/configs/dataset/rec/synthtext.yaml +7 -0
- openocr/configs/dataset/rec/test.yaml +77 -0
- openocr/configs/dataset/rec/textocr.yaml +13 -0
- openocr/configs/dataset/rec/textocr_horizontal.yaml +13 -0
- openocr/configs/dataset/rec/union14m_b.yaml +47 -0
- openocr/configs/dataset/rec/union14m_l_filtered.yaml +35 -0
- openocr/configs/rec/cmer/cmer.yml +127 -0
- openocr/configs/rec/mdiff4str/svtrv2_mdiffdecoder_base.yml +152 -0
- openocr/configs/rec/mdiff4str/svtrv2_mdiffdecoder_small.yml +152 -0
- openocr/configs/rec/unirec/focalsvtr_ardecoder_unirec.yml +114 -0
- openocr/configs/rec/unirec/opendoc_pipeline.yml +105 -0
- openocr/demo_gradio.py +28 -8
- openocr/demo_opendoc.py +572 -0
- openocr/demo_unirec.py +392 -0
- openocr/opendet/losses/__init__.py +5 -7
- openocr/opendet/preprocess/crop_resize.py +2 -1
- openocr/openocr.py +685 -0
- openocr/openrec/losses/__init__.py +8 -3
- openocr/openrec/losses/cmer_loss.py +12 -0
- openocr/openrec/losses/mdiff_loss.py +11 -0
- openocr/openrec/losses/unirec_loss.py +12 -0
- openocr/openrec/metrics/__init__.py +4 -1
- openocr/openrec/metrics/rec_metric_cmer.py +328 -0
- openocr/openrec/modeling/cmer_modeling/modeling_cmer.py +643 -0
- openocr/openrec/modeling/decoders/__init__.py +1 -0
- openocr/openrec/modeling/decoders/ctc_decoder.py +1 -1
- openocr/openrec/modeling/decoders/dan_decoder.py +4 -4
- openocr/openrec/modeling/decoders/dptr_parseq_clip_b_decoder.py +1563 -1398
- openocr/openrec/modeling/decoders/mdiff_decoder.py +587 -0
- openocr/openrec/modeling/decoders/smtr_decoder.py +99 -48
- openocr/openrec/modeling/unirec_modeling/configuration_unirec.py +166 -0
- openocr/openrec/modeling/unirec_modeling/modeling_unirec.py +433 -0
- openocr/openrec/optimizer/__init__.py +4 -3
- openocr/openrec/optimizer/lr.py +49 -0
- openocr/openrec/postprocess/__init__.py +2 -0
- openocr/openrec/postprocess/abinet_postprocess.py +1 -1
- openocr/openrec/postprocess/ar_postprocess.py +1 -1
- openocr/openrec/postprocess/cmer_postprocess.py +86 -0
- openocr/openrec/postprocess/cppd_postprocess.py +1 -1
- openocr/openrec/postprocess/igtr_postprocess.py +1 -1
- openocr/openrec/postprocess/lister_postprocess.py +1 -1
- openocr/openrec/postprocess/mgp_postprocess.py +1 -1
- openocr/openrec/postprocess/nrtr_postprocess.py +2 -2
- openocr/openrec/postprocess/smtr_postprocess.py +1 -1
- openocr/openrec/postprocess/srn_postprocess.py +1 -1
- openocr/openrec/postprocess/unirec_postprocess.py +58 -0
- openocr/openrec/postprocess/visionlan_postprocess.py +1 -1
- openocr/openrec/preprocess/__init__.py +5 -0
- openocr/openrec/preprocess/ce_label_encode.py +1 -1
- openocr/openrec/preprocess/cmer_label_encode.py +1025 -0
- openocr/openrec/preprocess/ctc_label_encode.py +1 -1
- openocr/openrec/preprocess/dptr_label_encode.py +177 -157
- openocr/openrec/preprocess/igtr_label_encode.py +4 -2
- openocr/openrec/preprocess/mdiff_label_encode.py +312 -0
- openocr/openrec/preprocess/rec_aug.py +128 -2
- openocr/openrec/preprocess/resize.py +57 -0
- openocr/openrec/preprocess/unirec_label_encode.py +62 -0
- openocr/tools/data/__init__.py +78 -55
- openocr/tools/data/cmer_web_dataset.py +310 -0
- openocr/tools/data/native_size_dataset.py +753 -0
- openocr/tools/data/native_size_sampler.py +158 -0
- openocr/tools/data/ratio_dataset_tvresize.py +2 -0
- openocr/tools/data/ratio_sampler.py +2 -1
- openocr/tools/download/download_dataset.py +38 -0
- openocr/tools/download/utils.py +28 -0
- openocr/tools/download_example_images.py +236 -0
- openocr/tools/engine/trainer.py +155 -39
- openocr/tools/eval_rec_all_ch.py +2 -2
- openocr/tools/infer_det.py +20 -2
- openocr/tools/infer_doc.py +898 -0
- openocr/tools/infer_doc_onnx.py +1172 -0
- openocr/tools/infer_e2e.py +27 -10
- openocr/tools/infer_rec.py +64 -15
- openocr/tools/infer_unirec_onnx.py +730 -0
- openocr/tools/to_markdown.py +468 -0
- openocr/tools/utils/ckpt.py +17 -5
- openocr/tools/utils/opendoc_onnx_utils/utils.py +1052 -0
- openocr_python-0.1.0.dev0.dist-info/METADATA +324 -0
- {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/RECORD +89 -45
- {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/WHEEL +1 -1
- openocr_python-0.1.0.dev0.dist-info/entry_points.txt +2 -0
- openocr_python-0.0.9.dist-info/METADATA +0 -149
- /openocr_python-0.0.9.dist-info/LICENCE → /openocr_python-0.1.0.dev0.dist-info/licenses/LICENSE +0 -0
- {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,158 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import torch
|
|
3
|
+
from torch.utils.data import Sampler
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def resize_image(original_width, original_height, max_width, max_height):
    """Shrink a (width, height) pair to fit a bounding box, keeping its ratio.

    A size that already fits inside ``max_width`` x ``max_height`` is returned
    unchanged (the function never enlarges). Otherwise the limiting side is set
    to its maximum and the other side is derived from the aspect ratio and
    truncated with ``int``.

    Args:
        original_width: Width of the source image.
        original_height: Height of the source image. A value of 0 raises
            ``ZeroDivisionError`` because the aspect ratio is computed first.
        max_width: Largest allowed width.
        max_height: Largest allowed height.

    Returns:
        Tuple ``(new_width, new_height)``.
    """
    ratio = original_width / original_height

    # Nothing to do when the image already fits inside the box.
    if original_width <= max_width and original_height <= max_height:
        return original_width, original_height

    if (max_width / max_height) >= ratio:
        # The box is relatively wider than the image: height is the limit.
        return int(max_height * ratio), max_height

    # Otherwise the width is the limit.
    return max_width, int(max_width / ratio)
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class NaSizeSampler(Sampler):
    """Batch sampler that groups samples sharing the same ``"<w>_<h>"`` key.

    ``data_source.img_label_pair_list`` maps keys of the form ``"<w>_<h>"`` to
    the samples of that size bucket. Buckets are visited in ascending
    ``w * h`` order and split into batches whose size shrinks as the bucket
    area grows. Every yielded batch is a list of
    ``[sample_index, w, h, zoom]`` entries (plus a trailing resume flag when
    ``resume_iter > 0``), where ``sample_index`` is the position inside the
    bucket and ``zoom`` is a random integer in ``[-5, 50)`` shared by the
    whole batch (``random_zoom_time``; how it is consumed is decided by the
    dataset, which is not visible here).
    """

    def __init__(
            self,
            data_source,
            max_side=(64 * 15, 64 * 22),  # w,h
            min_bs=1,
            max_bs=1024,
            resume_iter=0,
            scale_ratio=2,
            seed=None):
        """Build the full list of batches once, up front.

        Args:
            data_source: Dataset exposing ``seed``, ``img_label_pair_list``,
                ``do_shuffle``, ``mode`` and ``max_side``.
            max_side: Ignored; ``data_source.max_side`` (w, h) is used
                instead. Kept for interface compatibility.
            min_bs: Multiplier applied to the area-derived batch size.
            max_bs: Upper bound for the batch size of any bucket.
            resume_iter: When > 0, every entry gets an extra trailing flag:
                1 for batches at epoch positions ``0..resume_iter``
                (inclusive), 0 for the remaining ones.
            scale_ratio: Second multiplier applied to the batch size.
            seed: Ignored; ``data_source.seed`` seeds the batch shuffle.
        """
        self.data_source = data_source

        self.seed = data_source.seed

        self.img_label_pair_list = data_source.img_label_pair_list
        self.shuffle = data_source.do_shuffle
        self.is_training = data_source.mode == 'train'

        max_side = data_source.max_side
        max_area = max_side[0] * max_side[1]
        batch_list = []
        sorted_keys = sorted(self.img_label_pair_list.keys(),
                             key=self._key_area)
        for key in sorted_keys:
            w_r, h_r = key.split('_')
            w_r = int(w_r)
            h_r = int(h_r)
            num_samples = len(self.img_label_pair_list[key])

            current_bs = int((max_area // (w_r * h_r)) * min_bs * scale_ratio)
            # Clamp to at least 1: a bucket larger than the maximum area (or
            # an empty bucket) would otherwise give a batch size of 0 and a
            # ZeroDivisionError in the floor division below.
            current_bs = max(1, min(current_bs, max_bs, num_samples))
            batch_num = num_samples // current_bs
            all_indices = np.arange(num_samples, dtype=np.int64)

            # Number of samples left over after the full batches.
            drop = num_samples - current_bs * batch_num
            if self.is_training and drop > 0:
                # Training: pad the leftover up to one more full batch by
                # re-drawing random samples (with replacement) from the bucket.
                padding = np.random.choice(all_indices,
                                           current_bs - drop,
                                           replace=True)
                indices = np.append(all_indices, padding)
            else:
                indices = all_indices[:batch_num * current_bs]
            batch_list.extend(
                self._build_batches(indices, current_bs, w_r, h_r))

            if not self.is_training and drop > 0:
                # Evaluation: keep the leftover as one smaller final batch.
                batch_list.extend(
                    self._build_batches(all_indices[batch_num * current_bs:],
                                        drop, w_r, h_r))

        self.fix_cobatch = 4
        self.batch_list = batch_list  # [[[img_id, w_r, h_r, zoom_time], ...], ...]
        self.length = len(self.batch_list)
        self.batchs_id_sort = list(range(self.length))
        self.batchs_in_one_epoch_id = self.batchs_id_sort.copy()
        self.is_shuffled = False
        self.resume_iter = resume_iter
        if self.shuffle or self.is_training:
            g = torch.Generator()
            g.manual_seed(self.seed)  # use the same seed in every process
            random_indices = torch.randperm(len(self.batchs_in_one_epoch_id),
                                            generator=g).tolist()
            self.batchs_in_one_epoch_id = [
                self.batchs_in_one_epoch_id[i] for i in random_indices
            ]
        if self.resume_iter > 0:
            # Tag every entry with a resume flag based on its batch's position
            # in the epoch order.
            for iter_, batch_id in enumerate(self.batchs_in_one_epoch_id):
                flag = 1 if iter_ <= self.resume_iter else 0
                for sample in self.batch_list[batch_id]:
                    sample.append(flag)
            self.resume_iter = 0

    @staticmethod
    def _key_area(key):
        """Return ``w * h`` for a ``"<w>_<h>"`` bucket key."""
        parts = key.split('_')
        return int(parts[0]) * int(parts[1])

    @staticmethod
    def _build_batches(indices, batch_size, w_r, h_r):
        """Split ``indices`` into batches of ``[index, w_r, h_r, zoom]`` rows.

        One random integer in ``[-5, 50)`` is drawn per batch and repeated for
        each of its rows.
        """
        ids = indices.reshape(-1, batch_size, 1)
        w_r_batch = np.full_like(ids, w_r)
        h_r_batch = np.full_like(ids, h_r)
        random_zoom_time = np.random.randint(-5, 50, [ids.shape[0], 1, 1])
        random_zoom_time = np.tile(random_zoom_time, (1, ids.shape[1], 1))
        batches = np.concatenate(
            [ids, w_r_batch, h_r_batch, random_zoom_time], axis=-1)
        return batches.tolist()

    def __iter__(self):
        """Yield the batches in the (possibly shuffled) epoch order."""
        for batch_tuple_id in self.batchs_in_one_epoch_id:
            yield self.batch_list[batch_tuple_id]

    def set_epoch(self, epoch: int):
        """Record the current epoch number; the batch order is not changed."""
        self.epoch = epoch

    def __len__(self):
        """Return the number of batches."""
        return self.length
|
|
@@ -167,6 +167,8 @@ class RatioDataSetTVResize(Dataset):
|
|
|
167
167
|
valid_ratio = min(1.0, float(resized_w / imgW))
|
|
168
168
|
data['image'] = img
|
|
169
169
|
data['valid_ratio'] = valid_ratio
|
|
170
|
+
r = float(w) / float(h)
|
|
171
|
+
data['real_ratio'] = max(1, round(r))
|
|
170
172
|
return data
|
|
171
173
|
|
|
172
174
|
def get_lmdb_sample_info(self, txn, index):
|
|
@@ -56,7 +56,8 @@ class RatioSampler(Sampler):
|
|
|
56
56
|
self.base_im_w = base_im_w
|
|
57
57
|
|
|
58
58
|
# Get the GPU and node related information
|
|
59
|
-
num_replicas = torch.cuda.device_count() if torch.cuda.is_available(
|
|
59
|
+
num_replicas = torch.cuda.device_count() if torch.cuda.is_available(
|
|
60
|
+
) else 1
|
|
60
61
|
# rank = dist.get_rank()
|
|
61
62
|
rank = (int(os.environ['LOCAL_RANK'])
|
|
62
63
|
if 'LOCAL_RANK' in os.environ else 0)
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import sys
|
|
3
|
+
|
|
4
|
+
# Directory containing this script.
__dir__ = os.path.dirname(os.path.abspath(__file__))

# Put this directory and its two parent directories on the import path;
# presumably this is what lets the imports below (``engine``, ``utility``,
# ``download``) resolve when the file is run as a script - TODO confirm.
sys.path.append(__dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..', '..')))
|
|
9
|
+
|
|
10
|
+
from engine import Config
|
|
11
|
+
from utility import ArgsParser
|
|
12
|
+
import download.utils
|
|
13
|
+
from torchvision.datasets.utils import extract_archive
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def main(cfg):
    """Download every file listed in ``cfg`` and unpack the archives.

    The URLs, target paths and the certificate-check switch come from
    ``download.utils.get_dataset_info(cfg)``. Each file is fetched with
    ``download.utils.urlretrieve``; anything whose path does not end in
    ``.mdb`` is then extracted into ``cfg['root']`` and the downloaded archive
    is removed.

    Args:
        cfg: Mapping providing the keys read by ``get_dataset_info`` plus
            ``'root'``.
    """
    links, targets, verify_cert = download.utils.get_dataset_info(cfg)
    for link, target in zip(links, targets):
        print(f'Downloading {target} from {link} . . .')
        download.utils.urlretrieve(url=link,
                                   filename=target,
                                   check_validity=verify_cert)
        if target.endswith('.mdb'):
            # .mdb files are kept exactly as downloaded.
            continue
        extract_archive(from_path=target,
                        to_path=cfg['root'],
                        remove_finished=True)

    print('Downloads finished!')
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
if __name__ == '__main__':
    # Script entry point: load the config file named on the command line,
    # overlay the command-line values on it, then run the download.
    FLAGS = ArgsParser().parse_args()
    cfg = Config(FLAGS.config)
    FLAGS = vars(FLAGS)
    # 'opt' is removed from the flags and merged separately, after them.
    opt = FLAGS.pop('opt')
    cfg.merge_dict(FLAGS)
    cfg.merge_dict(opt)
    main(cfg.cfg)
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
import urllib
|
|
2
|
+
import ssl
|
|
3
|
+
from tqdm import tqdm
|
|
4
|
+
import os
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def get_dataset_info(cfg):
    """Pull the download settings out of a dataset config.

    Args:
        cfg: Mapping with the keys ``'download_links'``, ``'filenames'`` and
            ``'check_validity'``.

    Returns:
        Tuple ``(download_links, filenames, check_validity)``.

    Raises:
        KeyError: If one of the three keys is missing.
    """
    return cfg['download_links'], cfg['filenames'], cfg['check_validity']
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
# Modified from torchvision, as some datasets can't pass the certificate validity check:
|
|
14
|
+
# https://github.com/pytorch/vision/blob/868a3b42f4bffe29e4414ad7e4c7d9d0b4690ecb/torchvision/datasets/utils.py#L27C1-L32C40
|
|
15
|
+
def urlretrieve(url, filename, chunk_size=1024 * 32, check_validity=True):
    """Stream ``url`` into ``filename`` while showing a tqdm progress bar.

    Args:
        url: Address to download.
        filename: Destination path. Its parent directory is created when the
            path has one.
        chunk_size: Number of bytes requested per read.
        check_validity: When False, TLS certificate and hostname verification
            are switched off for this request.
    """
    # ``import urllib`` alone does not guarantee that the ``urllib.request``
    # submodule is loaded, so import it explicitly.
    import urllib.request

    target_dir = os.path.dirname(filename)
    if target_dir:
        # A bare file name has no directory part and os.makedirs('') raises.
        os.makedirs(target_dir, exist_ok=True)

    ctx = ssl.create_default_context()
    if not check_validity:
        # NOTE(security): this disables certificate and hostname checks, so
        # the download is not protected against tampering in transit.
        ctx.check_hostname = False
        ctx.verify_mode = ssl.CERT_NONE

    request = urllib.request.Request(url)
    with urllib.request.urlopen(request, context=ctx) as response:
        # Non-HTTP responses have no ``length``; tqdm accepts total=None.
        total = getattr(response, 'length', None)
        with open(filename, 'wb') as fh, tqdm(total=total,
                                              unit='B',
                                              unit_scale=True) as pbar:
            while chunk := response.read(chunk_size):
                fh.write(chunk)
                pbar.update(len(chunk))
|
|
@@ -0,0 +1,236 @@
|
|
|
1
|
+
"""Download example images from ModelScope dataset for demo purposes."""
|
|
2
|
+
import os
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
import shutil
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
# File suffixes (lower-case) that count as example images.
_EXAMPLE_IMAGE_SUFFIXES = ('.jpg', '.jpeg', '.png', '.bmp', '.gif')


def _count_example_images(folder):
    """Count regular files directly inside ``folder`` with an image suffix."""
    return sum(1 for f in folder.glob('*')
               if f.is_file() and f.suffix.lower() in _EXAMPLE_IMAGE_SUFFIXES)


def _example_subdirs(root):
    """Map each demo type to its sub-folder of ``root``."""
    root = Path(root)
    return {'ocr': root / 'ocr', 'doc': root / 'doc', 'unirec': root / 'unirec'}


def _verify_example_subdirs(subdirs):
    """Print the state of each folder; True when all exist and hold images."""
    all_folders_valid = True
    for folder_name, folder_path in subdirs.items():
        if not folder_path.exists():
            print(f' ⚠️ {folder_name} folder not found')
            all_folders_valid = False
            continue
        img_count = _count_example_images(folder_path)
        if img_count > 0:
            print(f' ✓ Found {folder_name} folder with {img_count} images')
        else:
            print(f' ⚠️ {folder_name} folder exists but has no images')
            all_folders_valid = False
    return all_folders_valid


def download_example_images():
    """Download example images, trying three sources in turn.

    Order of attempts: the ModelScope dataset, then the HuggingFace dataset,
    then a GitHub release archive that only contains the OCR examples. A
    source counts as successful when its ``ocr``, ``doc`` and ``unirec``
    folders all exist and contain images (for GitHub: when the ``ocr`` folder
    exists in the cache). Failures are printed, never raised.

    Returns:
        Dict mapping ``'ocr'``, ``'doc'`` and ``'unirec'`` to directory paths
        (as strings), or an empty dict when no source produced any folders.
    """
    # Will use dataset cache path folders directly
    subdirs = {}
    cache_root = str(Path.home() / '.cache' / 'openocr')

    print('📥 Downloading example images...')

    download_success = False

    try:
        # Try ModelScope first (default)
        print('🌐 Trying ModelScope (China mirror) first...')
        try:
            from modelscope.hub.snapshot_download import \
                snapshot_download as ms_snapshot_download
        except ImportError:
            # Handled separately so the install hint is actually shown; a
            # generic ``except Exception`` around the import would hide it.
            print('⚠️ modelscope not installed. Install with: pip install modelscope')
            print(' Trying HuggingFace...')
        else:
            try:
                cache_path = ms_snapshot_download(
                    repo_id='topdktu/openocr_test_images',
                    repo_type='dataset',
                    cache_dir=cache_root)

                print(f'✅ Dataset downloaded from ModelScope to {cache_path}')

                # Use dataset cache path folders directly
                subdirs = _example_subdirs(cache_path)
                if _verify_example_subdirs(subdirs):
                    download_success = True
                else:
                    print('⚠️ ModelScope download incomplete, trying HuggingFace...')
                    subdirs = {}
            except Exception as e:
                print(f'⚠️ ModelScope snapshot download failed: {e}')
                print(' Trying HuggingFace...')

        if not download_success:
            # Try HuggingFace
            print('🌐 Using HuggingFace...')
            try:
                from huggingface_hub import snapshot_download

                # Download entire dataset
                dataset_path = snapshot_download(
                    repo_id='topdu/openocr_test_images',
                    repo_type='dataset',
                    cache_dir=cache_root)

                print(f'✅ Dataset downloaded to {dataset_path}')

                # Use dataset cache path folders directly
                subdirs = _example_subdirs(dataset_path)
                if _verify_example_subdirs(subdirs):
                    download_success = True

            except ImportError:
                print('⚠️ huggingface_hub not installed. Install with: pip install huggingface_hub')
            except Exception as e:
                print(f'⚠️ HuggingFace download failed: {e}')

        # Try GitHub releases as fallback for OCR examples only
        if not download_success:
            print('🌐 Trying GitHub releases as fallback for OCR examples...')
            try:
                import urllib.request
                import tarfile
                import tempfile

                ocr_url = 'https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/OCR_e2e_img.tar'

                # Use temp directory for download
                with tempfile.TemporaryDirectory() as temp_dir:
                    temp_path = Path(temp_dir)
                    tar_path = temp_path / 'OCR_e2e_img.tar'

                    print(f' Downloading from {ocr_url}...')
                    urllib.request.urlretrieve(ocr_url, str(tar_path))

                    print(' Extracting...')
                    with tarfile.open(str(tar_path), 'r') as tar:
                        if hasattr(tarfile, 'data_filter'):
                            # The archive comes from the network: use the
                            # 'data' filter where this Python provides it, so
                            # members cannot be written outside temp_path.
                            tar.extractall(path=str(temp_path), filter='data')
                        else:
                            tar.extractall(path=str(temp_path))

                    # Move to cache directory
                    cache_base = Path.home() / '.cache' / 'openocr' / 'openocr_examples'
                    cache_base.mkdir(parents=True, exist_ok=True)

                    # Copy extracted files to cache (replacing an older copy)
                    ocr_source = temp_path / 'OCR_e2e_img'
                    ocr_target = cache_base / 'ocr'
                    if ocr_source.exists():
                        if ocr_target.exists():
                            shutil.rmtree(str(ocr_target))
                        shutil.copytree(str(ocr_source), str(ocr_target))

                # Set subdirs for GitHub download
                subdirs = {
                    'ocr': ocr_target,
                    'doc': cache_base / 'doc',
                    'unirec': cache_base / 'unirec'
                }

                # Create empty directories for doc and unirec if they don't exist
                for key in ['doc', 'unirec']:
                    subdirs[key].mkdir(parents=True, exist_ok=True)

                # Only claim success when an OCR folder is really available.
                if ocr_target.exists():
                    print(' ✓ OCR example images downloaded from GitHub to cache')
                    download_success = True
                else:
                    print(' ⚠️ GitHub archive did not contain OCR_e2e_img; no OCR examples copied')

            except Exception as e:
                print(f'⚠️ GitHub download failed: {e}')

        if download_success:
            print('✅ Example images ready!')
        else:
            print('⚠️ Could not download example images automatically.')

    except Exception as e:
        print(f'❌ Download failed: {e}')

    finally:
        # Verify directories
        if subdirs:
            print('\n📝 Example image directories:')
            for name, subdir in subdirs.items():
                if not subdir.exists():
                    print(f' ⚠️ {name}: Directory not found at {subdir}')
                elif not any(subdir.iterdir()):
                    print(f' ⚠️ {name}: No images found in {subdir}')
                    print(' You can manually add example images to this directory.')
                else:
                    # Same suffix list as the validation step above.
                    img_count = _count_example_images(subdir)
                    print(f' ✓ {name}: {img_count} images found in {subdir}')
        else:
            print('\n⚠️ No example image directories available')

    return {k: str(v) for k, v in subdirs.items()}
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
def get_example_images_path(demo_type='ocr'):
    """Return the example-image directory for one demo type.

    Runs ``download_example_images()`` on every call and looks the requested
    type up in its result.

    Args:
        demo_type: Type of demo ('ocr', 'doc', or 'unirec').

    Returns:
        The directory path for ``demo_type``. When that key is absent from the
        download result, a warning is printed and the 'ocr' path is returned
        instead, or '' if that is absent too.
    """
    print(f'Getting example images for {demo_type}...')
    available = download_example_images()

    if demo_type not in available:
        print(f'⚠️ Unknown demo type: {demo_type}')
        return available.get('ocr', '')

    return available[demo_type]
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
if __name__ == '__main__':
    # Manual check: run the download and list the resulting directories.
    import argparse

    # The parser defines no options of its own; it provides --help and
    # rejects unexpected arguments.
    parser = argparse.ArgumentParser(description='Download example images for OpenOCR demos')

    args = parser.parse_args()

    paths = download_example_images()

    print('\n📁 Example image directories:')
    for demo_type, path in paths.items():
        print(f' {demo_type}: {path}')
|