openocr-python 0.0.9__py3-none-any.whl → 0.1.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90)
  1. openocr/__init__.py +35 -1
  2. openocr/configs/dataset/rec/evaluation.yaml +41 -0
  3. openocr/configs/dataset/rec/ltb.yaml +9 -0
  4. openocr/configs/dataset/rec/mjsynth.yaml +11 -0
  5. openocr/configs/dataset/rec/openvino.yaml +25 -0
  6. openocr/configs/dataset/rec/ost.yaml +17 -0
  7. openocr/configs/dataset/rec/synthtext.yaml +7 -0
  8. openocr/configs/dataset/rec/test.yaml +77 -0
  9. openocr/configs/dataset/rec/textocr.yaml +13 -0
  10. openocr/configs/dataset/rec/textocr_horizontal.yaml +13 -0
  11. openocr/configs/dataset/rec/union14m_b.yaml +47 -0
  12. openocr/configs/dataset/rec/union14m_l_filtered.yaml +35 -0
  13. openocr/configs/rec/cmer/cmer.yml +127 -0
  14. openocr/configs/rec/mdiff4str/svtrv2_mdiffdecoder_base.yml +152 -0
  15. openocr/configs/rec/mdiff4str/svtrv2_mdiffdecoder_small.yml +152 -0
  16. openocr/configs/rec/unirec/focalsvtr_ardecoder_unirec.yml +114 -0
  17. openocr/configs/rec/unirec/opendoc_pipeline.yml +105 -0
  18. openocr/demo_gradio.py +28 -8
  19. openocr/demo_opendoc.py +572 -0
  20. openocr/demo_unirec.py +392 -0
  21. openocr/opendet/losses/__init__.py +5 -7
  22. openocr/opendet/preprocess/crop_resize.py +2 -1
  23. openocr/openocr.py +685 -0
  24. openocr/openrec/losses/__init__.py +8 -3
  25. openocr/openrec/losses/cmer_loss.py +12 -0
  26. openocr/openrec/losses/mdiff_loss.py +11 -0
  27. openocr/openrec/losses/unirec_loss.py +12 -0
  28. openocr/openrec/metrics/__init__.py +4 -1
  29. openocr/openrec/metrics/rec_metric_cmer.py +328 -0
  30. openocr/openrec/modeling/cmer_modeling/modeling_cmer.py +643 -0
  31. openocr/openrec/modeling/decoders/__init__.py +1 -0
  32. openocr/openrec/modeling/decoders/ctc_decoder.py +1 -1
  33. openocr/openrec/modeling/decoders/dan_decoder.py +4 -4
  34. openocr/openrec/modeling/decoders/dptr_parseq_clip_b_decoder.py +1563 -1398
  35. openocr/openrec/modeling/decoders/mdiff_decoder.py +587 -0
  36. openocr/openrec/modeling/decoders/smtr_decoder.py +99 -48
  37. openocr/openrec/modeling/unirec_modeling/configuration_unirec.py +166 -0
  38. openocr/openrec/modeling/unirec_modeling/modeling_unirec.py +433 -0
  39. openocr/openrec/optimizer/__init__.py +4 -3
  40. openocr/openrec/optimizer/lr.py +49 -0
  41. openocr/openrec/postprocess/__init__.py +2 -0
  42. openocr/openrec/postprocess/abinet_postprocess.py +1 -1
  43. openocr/openrec/postprocess/ar_postprocess.py +1 -1
  44. openocr/openrec/postprocess/cmer_postprocess.py +86 -0
  45. openocr/openrec/postprocess/cppd_postprocess.py +1 -1
  46. openocr/openrec/postprocess/igtr_postprocess.py +1 -1
  47. openocr/openrec/postprocess/lister_postprocess.py +1 -1
  48. openocr/openrec/postprocess/mgp_postprocess.py +1 -1
  49. openocr/openrec/postprocess/nrtr_postprocess.py +2 -2
  50. openocr/openrec/postprocess/smtr_postprocess.py +1 -1
  51. openocr/openrec/postprocess/srn_postprocess.py +1 -1
  52. openocr/openrec/postprocess/unirec_postprocess.py +58 -0
  53. openocr/openrec/postprocess/visionlan_postprocess.py +1 -1
  54. openocr/openrec/preprocess/__init__.py +5 -0
  55. openocr/openrec/preprocess/ce_label_encode.py +1 -1
  56. openocr/openrec/preprocess/cmer_label_encode.py +1025 -0
  57. openocr/openrec/preprocess/ctc_label_encode.py +1 -1
  58. openocr/openrec/preprocess/dptr_label_encode.py +177 -157
  59. openocr/openrec/preprocess/igtr_label_encode.py +4 -2
  60. openocr/openrec/preprocess/mdiff_label_encode.py +312 -0
  61. openocr/openrec/preprocess/rec_aug.py +128 -2
  62. openocr/openrec/preprocess/resize.py +57 -0
  63. openocr/openrec/preprocess/unirec_label_encode.py +62 -0
  64. openocr/tools/data/__init__.py +78 -55
  65. openocr/tools/data/cmer_web_dataset.py +310 -0
  66. openocr/tools/data/native_size_dataset.py +753 -0
  67. openocr/tools/data/native_size_sampler.py +158 -0
  68. openocr/tools/data/ratio_dataset_tvresize.py +2 -0
  69. openocr/tools/data/ratio_sampler.py +2 -1
  70. openocr/tools/download/download_dataset.py +38 -0
  71. openocr/tools/download/utils.py +28 -0
  72. openocr/tools/download_example_images.py +236 -0
  73. openocr/tools/engine/trainer.py +155 -39
  74. openocr/tools/eval_rec_all_ch.py +2 -2
  75. openocr/tools/infer_det.py +20 -2
  76. openocr/tools/infer_doc.py +898 -0
  77. openocr/tools/infer_doc_onnx.py +1172 -0
  78. openocr/tools/infer_e2e.py +27 -10
  79. openocr/tools/infer_rec.py +64 -15
  80. openocr/tools/infer_unirec_onnx.py +730 -0
  81. openocr/tools/to_markdown.py +468 -0
  82. openocr/tools/utils/ckpt.py +17 -5
  83. openocr/tools/utils/opendoc_onnx_utils/utils.py +1052 -0
  84. openocr_python-0.1.0.dev0.dist-info/METADATA +324 -0
  85. {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/RECORD +89 -45
  86. {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/WHEEL +1 -1
  87. openocr_python-0.1.0.dev0.dist-info/entry_points.txt +2 -0
  88. openocr_python-0.0.9.dist-info/METADATA +0 -149
  89. /openocr_python-0.0.9.dist-info/LICENCE → /openocr_python-0.1.0.dev0.dist-info/licenses/LICENSE +0 -0
  90. {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,158 @@
1
+ import numpy as np
2
+ import torch
3
+ from torch.utils.data import Sampler
4
+
5
+
6
def resize_image(original_width, original_height, max_width, max_height):
    """Fit an image size inside ``max_width x max_height``, keeping its aspect ratio.

    Sizes that already fit are returned unchanged. Larger sizes are scaled
    down so that the limiting side equals its bound exactly; the other side
    is truncated to an int and clamped to at least 1, because truncation
    alone yields 0 for extreme aspect ratios (e.g. 1x1000 into 100x100).

    Args:
        original_width: source width (must be non-zero height for the ratio).
        original_height: source height; 0 raises ``ZeroDivisionError``.
        max_width: maximum allowed width.
        max_height: maximum allowed height.

    Returns:
        tuple: ``(new_width, new_height)``.
    """
    # Width / height ratio of the source image.
    aspect_ratio = original_width / original_height

    if original_width <= max_width and original_height <= max_height:
        # Already within the bounds: no resize needed.
        return original_width, original_height

    if (max_width / max_height) >= aspect_ratio:
        # The box is relatively wider than the image: height is the limit.
        new_height = max_height
        new_width = max(1, int(new_height * aspect_ratio))
    else:
        # The image is relatively wider than the box: width is the limit.
        new_width = max_width
        new_height = max(1, int(new_width / aspect_ratio))
    return new_width, new_height
24
+
25
+
26
class NaSizeSampler(Sampler):
    """Batch sampler that groups samples into same-size buckets.

    ``data_source.img_label_pair_list`` maps a bucket key ``'<w_r>_<h_r>'``
    to the list of samples in that bucket. Each yielded batch is a list of
    rows ``[img_id, w_r, h_r, zoom_time]`` (plus a trailing resume flag when
    ``resume_iter > 0``), where ``img_id`` indexes into that bucket's list.

    Buckets are processed in ascending ``w_r * h_r`` order, and a bucket's
    batch size shrinks as its area grows, so that
    ``batch_size * w_r * h_r`` stays roughly bounded by the ``max_side``
    area (times ``min_bs * scale_ratio``).
    """

    def __init__(
            self,
            data_source,
            max_side=None,  # [w, h]
            min_bs=1,
            max_bs=1024,
            resume_iter=0,
            scale_ratio=2,
            seed=None):
        """
        Args:
            data_source: dataset exposing ``img_label_pair_list``,
                ``do_shuffle`` and ``mode``. Its ``max_side`` and ``seed``
                attributes, when present and not None, take precedence over
                the arguments of the same name.
            max_side: ``[w, h]`` whose product is the per-batch area budget.
                Used only when ``data_source`` provides none. Defaults to
                ``[64 * 15, 64 * 22]``.
            min_bs: multiplier applied to the area-derived batch size.
            max_bs: upper bound on any batch size.
            resume_iter: when > 0, every row of every batch gets a trailing
                flag: 1 for batches at positions ``<= resume_iter`` in the
                epoch order, 0 for the rest. Reset to 0 afterwards.
            scale_ratio: extra multiplier applied to the area-derived batch
                size.
            seed: shuffle seed used when ``data_source`` provides none.
        """
        self.data_source = data_source

        ds_seed = getattr(data_source, 'seed', None)
        self.seed = seed if ds_seed is None else ds_seed

        self.img_label_pair_list = data_source.img_label_pair_list
        self.shuffle = data_source.do_shuffle
        self.is_training = data_source.mode == 'train'

        ds_max_side = getattr(data_source, 'max_side', None)
        if ds_max_side is not None:
            max_side = ds_max_side
        elif max_side is None:
            max_side = [64 * 15, 64 * 22]
        max_area = max_side[0] * max_side[1]

        batch_list = []
        sorted_keys = sorted(
            self.img_label_pair_list.keys(),
            key=lambda k: int(k.split('_')[0]) * int(k.split('_')[1]))
        for key in sorted_keys:
            w_r, h_r = (int(v) for v in key.split('_'))
            num_samples = len(self.img_label_pair_list[key])
            if num_samples == 0:
                # Nothing to batch (and it would give a batch size of 0).
                continue

            current_bs = int((max_area // (w_r * h_r)) * min_bs * scale_ratio)
            # Clamp to >= 1: buckets larger than the area budget would
            # otherwise get batch size 0 and divide by zero below.
            current_bs = max(1, min(current_bs, max_bs, num_samples))
            batch_num = num_samples // current_bs
            all_indices = np.arange(num_samples, dtype=np.int64)
            drop = num_samples - current_bs * batch_num

            if self.is_training and drop > 0:
                # Training: complete the trailing partial batch with samples
                # re-drawn at random (with replacement) from this bucket.
                padding = np.random.choice(all_indices,
                                           current_bs - drop,
                                           replace=True)
                full_indices = np.append(all_indices, padding)
            else:
                full_indices = all_indices[:batch_num * current_bs]
            batch_list.extend(
                self._pack_batches(full_indices, current_bs, w_r, h_r))

            if not self.is_training and drop > 0:
                # Evaluation: keep the leftover samples as one smaller batch
                # so that no sample is skipped.
                batch_list.extend(
                    self._pack_batches(all_indices[batch_num * current_bs:],
                                       drop, w_r, h_r))

        self.fix_cobatch = 4  # not read inside this class
        self.batch_list = batch_list  # [[[img_id, w_r, h_r, zoom_time], ...], ...]
        self.length = len(self.batch_list)
        self.batchs_id_sort = list(range(self.length))
        self.batchs_in_one_epoch_id = self.batchs_id_sort.copy()
        self.is_shuffled = False
        self.epoch = 0
        self.resume_iter = resume_iter

        if self.shuffle or self.is_training:
            # Same seed on every process so all processes agree on the order.
            g = torch.Generator()
            g.manual_seed(self.seed)
            order = torch.randperm(len(self.batchs_in_one_epoch_id),
                                   generator=g).tolist()
            self.batchs_in_one_epoch_id = [
                self.batchs_in_one_epoch_id[i] for i in order
            ]

        if self.resume_iter > 0:
            # Tag each row with a flag: 1 for batches at positions up to and
            # including resume_iter in the epoch order, 0 afterwards.
            for position, batch_id in enumerate(self.batchs_in_one_epoch_id):
                flag = 1 if position <= self.resume_iter else 0
                for row in self.batch_list[batch_id]:
                    row.append(flag)
            self.resume_iter = 0

    @staticmethod
    def _pack_batches(indices, batch_size, w_r, h_r):
        """Split flat ``indices`` into batches of rows ``[img_id, w_r, h_r, zoom_time]``."""
        ids = indices.reshape(-1, batch_size, 1)
        # One integer in [-5, 50) per batch, shared by every row of that batch.
        # NOTE(review): presumably consumed by the dataset as a size jitter.
        zoom = np.random.randint(-5, 50, [ids.shape[0], 1, 1])
        zoom = np.tile(zoom, (1, ids.shape[1], 1))
        rows = np.concatenate(
            [ids, np.full_like(ids, w_r), np.full_like(ids, h_r), zoom],
            axis=-1)
        return rows.tolist()

    def __iter__(self):
        """Yield the batches in the (possibly shuffled) epoch order."""
        for batch_tuple_id in self.batchs_in_one_epoch_id:
            yield self.batch_list[batch_tuple_id]

    def set_epoch(self, epoch: int):
        """Record the epoch; the batch order fixed in ``__init__`` is not re-shuffled."""
        self.epoch = epoch

    def __len__(self):
        """Number of batches per epoch."""
        return self.length
@@ -167,6 +167,8 @@ class RatioDataSetTVResize(Dataset):
167
167
  valid_ratio = min(1.0, float(resized_w / imgW))
168
168
  data['image'] = img
169
169
  data['valid_ratio'] = valid_ratio
170
+ r = float(w) / float(h)
171
+ data['real_ratio'] = max(1, round(r))
170
172
  return data
171
173
 
172
174
  def get_lmdb_sample_info(self, txn, index):
@@ -56,7 +56,8 @@ class RatioSampler(Sampler):
56
56
  self.base_im_w = base_im_w
57
57
 
58
58
  # Get the GPU and node related information
59
- num_replicas = torch.cuda.device_count() if torch.cuda.is_available() else 1
59
+ num_replicas = torch.cuda.device_count() if torch.cuda.is_available(
60
+ ) else 1
60
61
  # rank = dist.get_rank()
61
62
  rank = (int(os.environ['LOCAL_RANK'])
62
63
  if 'LOCAL_RANK' in os.environ else 0)
@@ -0,0 +1,38 @@
1
+ import os
2
+ import sys
3
+
4
+ __dir__ = os.path.dirname(os.path.abspath(__file__))
5
+
6
+ sys.path.append(__dir__)
7
+ sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
8
+ sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..', '..')))
9
+
10
+ from engine import Config
11
+ from utility import ArgsParser
12
+ import download.utils
13
+ from torchvision.datasets.utils import extract_archive
14
+
15
+
16
def main(cfg):
    """Download every file listed in ``cfg`` and unpack archives into ``cfg['root']``.

    ``cfg`` must provide the keys read by ``download.utils.get_dataset_info``
    (``download_links``, ``filenames``, ``check_validity``) plus ``root``.
    Files whose name ends with ``.mdb`` are kept as downloaded; every other
    file is passed to torchvision's ``extract_archive`` and removed after
    extraction (``remove_finished=True``).

    Raises:
        ValueError: if ``download_links`` and ``filenames`` differ in length
            (``zip`` would otherwise silently skip the unmatched entries).
    """
    urls, filename_paths, check_validity = download.utils.get_dataset_info(cfg)
    if len(urls) != len(filename_paths):
        raise ValueError(
            f'download_links ({len(urls)}) and filenames '
            f'({len(filename_paths)}) must have the same length')

    for url, filename_path in zip(urls, filename_paths):
        print(f'Downloading {filename_path} from {url} . . .')
        download.utils.urlretrieve(url=url,
                                   filename=filename_path,
                                   check_validity=check_validity)
        # ``.mdb`` files (presumably LMDB data files) are used as-is;
        # everything else is treated as an archive.
        if not filename_path.endswith('.mdb'):
            extract_archive(from_path=filename_path,
                            to_path=cfg['root'],
                            remove_finished=True)

    print('Downloads finished!')
29
+
30
+
31
if __name__ == '__main__':
    # Load the config named on the command line, then apply overrides:
    # first every parsed flag (minus ``opt``), then the ``opt`` entries.
    cli_args = ArgsParser().parse_args()
    config = Config(cli_args.config)
    flag_dict = vars(cli_args)
    overrides = flag_dict.pop('opt')
    config.merge_dict(flag_dict)
    config.merge_dict(overrides)
    main(config.cfg)
@@ -0,0 +1,28 @@
1
+ import urllib
2
+ import ssl
3
+ from tqdm import tqdm
4
+ import os
5
+
6
+
7
def get_dataset_info(cfg):
    """Read the download manifest out of a dataset config mapping.

    Returns:
        tuple: ``(download_links, filenames, check_validity)``, taken from the
        config keys of the same names. A missing key raises ``KeyError``.
    """
    manifest_keys = ('download_links', 'filenames', 'check_validity')
    return tuple(cfg[name] for name in manifest_keys)
11
+
12
+
13
+ # Modified from torchvision as some datasets cant pass the certificate validity check:
14
+ # https://github.com/pytorch/vision/blob/868a3b42f4bffe29e4414ad7e4c7d9d0b4690ecb/torchvision/datasets/utils.py#L27C1-L32C40
15
def urlretrieve(url, filename, chunk_size=1024 * 32, check_validity=True):
    """Stream ``url`` into ``filename`` while showing a tqdm progress bar.

    Args:
        url: source URL.
        filename: destination path; its parent directory is created if it
            has one and it does not exist yet.
        chunk_size: number of bytes read per iteration.
        check_validity: when False, TLS hostname and certificate
            verification are disabled (some dataset hosts fail the check).
            Only use this for sources you trust.
    """
    # ``import urllib`` alone does not load the ``urllib.request`` submodule;
    # import it explicitly instead of relying on another module having done so.
    import urllib.request

    parent_dir = os.path.dirname(filename)
    if parent_dir:
        # os.makedirs('') raises, so skip it for bare filenames.
        os.makedirs(parent_dir, exist_ok=True)

    ctx = ssl.create_default_context()
    if not check_validity:
        # check_hostname must be disabled before verify_mode can be CERT_NONE.
        ctx.check_hostname = False
        ctx.verify_mode = ssl.CERT_NONE

    request = urllib.request.Request(url)
    with urllib.request.urlopen(request, context=ctx) as response:
        # Not every response type exposes ``length``; tqdm accepts None.
        total = getattr(response, 'length', None)
        with open(filename, 'wb') as fh, tqdm(total=total,
                                               unit='B',
                                               unit_scale=True) as pbar:
            while chunk := response.read(chunk_size):
                fh.write(chunk)
                pbar.update(len(chunk))
@@ -0,0 +1,236 @@
1
+ """Download example images from ModelScope dataset for demo purposes."""
2
+ import os
3
+ from pathlib import Path
4
+ import shutil
5
+
6
+
7
_EXAMPLE_IMAGE_SUFFIXES = ('.jpg', '.jpeg', '.png', '.bmp', '.gif')
_EXAMPLE_KINDS = ('ocr', 'doc', 'unirec')


def _example_cache_root():
    """Return the base cache directory shared by all download sources."""
    return Path.home() / '.cache' / 'openocr'


def _count_images(folder):
    """Count files directly inside ``folder`` with an image suffix (case-insensitive)."""
    if not folder.is_dir():
        return 0
    return sum(1 for f in folder.iterdir()
               if f.is_file() and f.suffix.lower() in _EXAMPLE_IMAGE_SUFFIXES)


def _example_subdirs(root):
    """Map each demo kind ('ocr', 'doc', 'unirec') to its sub-folder of ``root``."""
    root = Path(root)
    return {kind: root / kind for kind in _EXAMPLE_KINDS}


def _check_subdirs(subdirs):
    """Print one status line per folder; return True only if every folder has images."""
    all_valid = True
    for folder_name, folder_path in subdirs.items():
        if not folder_path.exists():
            print(f' ⚠️ {folder_name} folder not found')
            all_valid = False
            continue
        img_count = _count_images(folder_path)
        if img_count > 0:
            print(f' ✓ Found {folder_name} folder with {img_count} images')
        else:
            print(f' ⚠️ {folder_name} folder exists but has no images')
            all_valid = False
    return all_valid


def _try_modelscope():
    """Fetch the examples from ModelScope; return the subdirs, or None on failure."""
    print('🌐 Trying ModelScope (China mirror) first...')
    try:
        from modelscope.hub.snapshot_download import snapshot_download

        # NOTE(review): the ModelScope owner ('topdktu') differs from the
        # HuggingFace one ('topdu') — confirm this is intentional.
        cache_path = snapshot_download(
            repo_id='topdktu/openocr_test_images',
            repo_type='dataset',
            cache_dir=str(_example_cache_root()))
        print(f'✅ Dataset downloaded from ModelScope to {cache_path}')

        subdirs = _example_subdirs(cache_path)
        if _check_subdirs(subdirs):
            return subdirs
        print('⚠️ ModelScope download incomplete, trying HuggingFace...')
    except ImportError:
        # Listed before ``Exception`` so a missing package gets its own message.
        print('⚠️ modelscope not installed. Install with: pip install modelscope')
        print(' Trying HuggingFace...')
    except Exception as e:
        print(f'⚠️ ModelScope snapshot download failed: {e}')
        print(' Trying HuggingFace...')
    return None


def _try_huggingface():
    """Fetch the examples from HuggingFace.

    Returns:
        tuple: ``(subdirs, all_valid)``; ``subdirs`` is ``{}`` if the
        download itself failed.
    """
    print('🌐 Using HuggingFace...')
    try:
        from huggingface_hub import snapshot_download

        dataset_path = snapshot_download(
            repo_id='topdu/openocr_test_images',
            repo_type='dataset',
            cache_dir=str(_example_cache_root()))
        print(f'✅ Dataset downloaded to {dataset_path}')

        subdirs = _example_subdirs(dataset_path)
        return subdirs, _check_subdirs(subdirs)
    except ImportError:
        print('⚠️ huggingface_hub not installed. Install with: pip install huggingface_hub')
    except Exception as e:
        print(f'⚠️ HuggingFace download failed: {e}')
    return {}, False


def _try_github():
    """Fetch only the OCR examples from a GitHub release; return subdirs or None."""
    print('🌐 Trying GitHub releases as fallback for OCR examples...')
    try:
        import tarfile
        import tempfile
        import urllib.request

        ocr_url = 'https://github.com/Topdu/OpenOCR/releases/download/develop0.0.1/OCR_e2e_img.tar'
        cache_base = _example_cache_root() / 'openocr_examples'
        ocr_target = cache_base / 'ocr'

        with tempfile.TemporaryDirectory() as temp_dir:
            temp_path = Path(temp_dir)
            tar_path = temp_path / 'OCR_e2e_img.tar'

            print(f' Downloading from {ocr_url}...')
            urllib.request.urlretrieve(ocr_url, str(tar_path))

            print(' Extracting...')
            with tarfile.open(str(tar_path), 'r') as tar:
                if hasattr(tarfile, 'data_filter'):
                    # The archive comes from the network: refuse members that
                    # would escape the target directory.
                    tar.extractall(path=str(temp_path), filter='data')
                else:
                    tar.extractall(path=str(temp_path))

            ocr_source = temp_path / 'OCR_e2e_img'
            if not ocr_source.exists():
                # Do not report success when the expected folder is missing.
                print('⚠️ GitHub download failed: OCR_e2e_img not found in archive')
                return None

            cache_base.mkdir(parents=True, exist_ok=True)
            if ocr_target.exists():
                shutil.rmtree(str(ocr_target))
            shutil.copytree(str(ocr_source), str(ocr_target))

        subdirs = {
            'ocr': ocr_target,
            'doc': cache_base / 'doc',
            'unirec': cache_base / 'unirec'
        }
        # The release only ships OCR images; create empty doc/unirec folders.
        for key in ('doc', 'unirec'):
            subdirs[key].mkdir(parents=True, exist_ok=True)

        print(' ✓ OCR example images downloaded from GitHub to cache')
        return subdirs
    except Exception as e:
        print(f'⚠️ GitHub download failed: {e}')
        return None


def _report_subdirs(subdirs):
    """Print a final summary of the example-image directories."""
    if not subdirs:
        print('\n⚠️ No example image directories available')
        return
    print('\n📝 Example image directories:')
    for name, subdir in subdirs.items():
        if not subdir.exists():
            print(f' ⚠️ {name}: Directory not found at {subdir}')
        elif subdir.is_dir() and any(subdir.iterdir()):
            print(f' ✓ {name}: {_count_images(subdir)} images found in {subdir}')
        else:
            print(f' ⚠️ {name}: No images found in {subdir}')
            print(' You can manually add example images to this directory.')


def download_example_images():
    """Download the demo example images, trying several sources in turn.

    Order: ModelScope, then HuggingFace, then (OCR images only) a GitHub
    release archive. The first source whose folders all contain images
    wins. Failures are printed rather than raised.

    Returns:
        dict: ``{'ocr': path, 'doc': path, 'unirec': path}`` as strings, or
        ``{}`` when no source produced directories. If HuggingFace downloaded
        but was incomplete and GitHub failed, its (partial) paths are returned.
    """
    subdirs = {}
    print('📥 Downloading example images...')
    download_success = False

    try:
        modelscope_dirs = _try_modelscope()
        if modelscope_dirs is not None:
            subdirs, download_success = modelscope_dirs, True

        if not download_success:
            hf_dirs, download_success = _try_huggingface()
            if hf_dirs:
                subdirs = hf_dirs

        if not download_success:
            github_dirs = _try_github()
            if github_dirs is not None:
                subdirs, download_success = github_dirs, True

        if download_success:
            print('✅ Example images ready!')
        else:
            print('⚠️ Could not download example images automatically.')
    except Exception as e:
        # Top-level guard: this helper only feeds demos, so report and continue.
        print(f'❌ Download failed: {e}')
    finally:
        _report_subdirs(subdirs)

    return {k: str(v) for k, v in subdirs.items()}
201
+
202
+
203
def get_example_images_path(demo_type='ocr'):
    """Return the example-image directory for one demo type.

    Calls ``download_example_images()`` on every invocation.

    Args:
        demo_type: ``'ocr'``, ``'doc'`` or ``'unirec'``; any other value
            prints a warning and falls back to ``'ocr'``.

    Returns:
        str: the directory path, or ``''`` when no directories are available
        (e.g. every download source failed).
    """
    print(f'Getting example images for {demo_type}...')
    paths = download_example_images()

    if demo_type not in ('ocr', 'doc', 'unirec'):
        # Warn only for genuinely unknown types: a valid type can also be
        # absent from ``paths`` simply because all downloads failed.
        print(f'⚠️ Unknown demo type: {demo_type}')
        demo_type = 'ocr'
    return paths.get(demo_type, '')
222
+
223
+
224
if __name__ == '__main__':
    # Manual entry point: download the examples and print where they landed.
    import argparse

    parser = argparse.ArgumentParser(
        description='Download example images for OpenOCR demos')
    # No options are defined; parsing still provides --help and rejects
    # unexpected arguments. The result itself is not needed.
    parser.parse_args()

    paths = download_example_images()

    print('\n📁 Example image directories:')
    for demo_type, path in paths.items():
        print(f' {demo_type}: {path}')