xfmr-zem 0.2.4__py3-none-any.whl → 0.2.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xfmr_zem/cli.py +32 -3
- xfmr_zem/client.py +59 -8
- xfmr_zem/server.py +21 -4
- xfmr_zem/servers/data_juicer/server.py +1 -1
- xfmr_zem/servers/instruction_gen/server.py +1 -1
- xfmr_zem/servers/io/server.py +1 -1
- xfmr_zem/servers/llm/parameters.yml +10 -0
- xfmr_zem/servers/nemo_curator/server.py +1 -1
- xfmr_zem/servers/ocr/deepdoc_vietocr/__init__.py +90 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/implementations.py +1286 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/layout_recognizer.py +562 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/ocr.py +512 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/onnx/.gitattributes +35 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/onnx/README.md +5 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/onnx/ocr.res +6623 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/operators.py +725 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/phases.py +191 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/pipeline.py +561 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/postprocess.py +370 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/recognizer.py +436 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/table_structure_recognizer.py +569 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/utils/__init__.py +81 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/utils/file_utils.py +246 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/__init__.py +0 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/config/base.yml +58 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/config/vgg-seq2seq.yml +38 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/model/__init__.py +0 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/model/backbone/cnn.py +25 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/model/backbone/vgg.py +51 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/model/seqmodel/seq2seq.py +175 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/model/transformerocr.py +29 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/model/vocab.py +36 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/tool/config.py +37 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/tool/translate.py +111 -0
- xfmr_zem/servers/ocr/engines.py +242 -0
- xfmr_zem/servers/ocr/install_models.py +63 -0
- xfmr_zem/servers/ocr/parameters.yml +4 -0
- xfmr_zem/servers/ocr/server.py +102 -0
- xfmr_zem/servers/profiler/parameters.yml +4 -0
- xfmr_zem/servers/sinks/parameters.yml +6 -0
- xfmr_zem/servers/unstructured/parameters.yml +6 -0
- xfmr_zem/servers/unstructured/server.py +62 -0
- xfmr_zem/zenml_wrapper.py +20 -7
- {xfmr_zem-0.2.4.dist-info → xfmr_zem-0.2.6.dist-info}/METADATA +20 -1
- xfmr_zem-0.2.6.dist-info/RECORD +58 -0
- xfmr_zem-0.2.4.dist-info/RECORD +0 -23
- /xfmr_zem/servers/data_juicer/{parameter.yaml → parameters.yml} +0 -0
- /xfmr_zem/servers/instruction_gen/{parameter.yaml → parameters.yml} +0 -0
- /xfmr_zem/servers/io/{parameter.yaml → parameters.yml} +0 -0
- /xfmr_zem/servers/nemo_curator/{parameter.yaml → parameters.yml} +0 -0
- {xfmr_zem-0.2.4.dist-info → xfmr_zem-0.2.6.dist-info}/WHEEL +0 -0
- {xfmr_zem-0.2.4.dist-info → xfmr_zem-0.2.6.dist-info}/entry_points.txt +0 -0
- {xfmr_zem-0.2.4.dist-info → xfmr_zem-0.2.6.dist-info}/licenses/LICENSE +0 -0
xfmr_zem/servers/ocr/deepdoc_vietocr/ocr.py
@@ -0,0 +1,512 @@
+#
+#  Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+#
+
+import logging
+import copy
+import time
+import os
+import sys
+from pathlib import Path
+
+# External libs
+import numpy as np
+import cv2
+import onnxruntime as ort
+import torch
+from PIL import Image
+
+# Internal modules
+from . import operators
+from .postprocess import build_post_process
+from .vietocr.tool.config import Cfg
+
+# Handle VietOCR import
+try:
+    from .vietocr.tool.predictor import Predictor
+except ImportError:
+    from .vietocr.tool.translate import Predictor
+
+def get_project_base_directory():
+    return Path(__file__).resolve().parent
+
+loaded_models = {}
+
+def load_model(model_dir, nm, device_id: int | None = None):
+    model_file_path = os.path.join(model_dir, nm + ".onnx")
+    model_cached_tag = model_file_path + str(device_id) if device_id is not None else model_file_path
+
+    global loaded_models
+    loaded_model = loaded_models.get(model_cached_tag)
+    if loaded_model:
+        logging.info(f"load_model {model_file_path} reuses cached model")
+        return loaded_model
+
+    if not os.path.exists(model_file_path):
+        raise ValueError("Model file not found: {}".format(
+            model_file_path))
+
+    def cuda_is_available():
+        try:
+            import torch
+            if torch.cuda.is_available() and torch.cuda.device_count() > (device_id if device_id else 0):
+                return True
+        except Exception:
+            return False
+        return False
+
+    options = ort.SessionOptions()
+    options.enable_cpu_mem_arena = False
+    options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
+    options.intra_op_num_threads = 2
+    options.inter_op_num_threads = 2
+
+    run_options = ort.RunOptions()
+
+    providers = ['CPUExecutionProvider']
+    if cuda_is_available():
+        cuda_provider_options = {
+            "device_id": device_id if device_id else 0,
+            "gpu_mem_limit": 512 * 1024 * 1024,
+            "arena_extend_strategy": "kNextPowerOfTwo",
+        }
+        providers = [('CUDAExecutionProvider', cuda_provider_options)]
+        run_options.add_run_config_entry("memory.enable_memory_arena_shrinkage", "gpu:" + str(device_id if device_id else 0))
+        logging.info(f"load_model {model_file_path} uses GPU")
+    else:
+        run_options.add_run_config_entry("memory.enable_memory_arena_shrinkage", "cpu")
+        logging.info(f"load_model {model_file_path} uses CPU")
+
+    sess = ort.InferenceSession(model_file_path, sess_options=options, providers=providers)
+
+    loaded_model = (sess, run_options)
+    loaded_models[model_cached_tag] = loaded_model
+    return loaded_model
+
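load_model returns a (session, run_options) pair and memoizes it in loaded_models, keyed by model path plus device id. A minimal sketch of how a caller would drive the cached session; the directory, the model name "det", and the input shape are illustrative assumptions, not taken from this diff:

    import numpy as np

    sess, run_options = load_model("/path/to/onnx_dir", "det")  # loads /path/to/onnx_dir/det.onnx
    feed = {sess.get_inputs()[0].name: np.zeros((1, 3, 960, 960), dtype=np.float32)}  # dummy NCHW input
    outputs = sess.run(None, feed, run_options=run_options)     # None -> fetch all model outputs
    sess2, _ = load_model("/path/to/onnx_dir", "det")           # second call hits the cache; sess2 is sess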
+def download_file(url, file_path):
+    import urllib.request
+    from tqdm import tqdm
+
+    class DownloadProgressBar(tqdm):
+        def update_to(self, b=1, bsize=1, tsize=None):
+            if tsize is not None:
+                self.total = tsize
+            self.update(b * bsize - self.n)
+
+    logging.info(f"Downloading {url} to {file_path}")
+    os.makedirs(os.path.dirname(file_path), exist_ok=True)
+
+    try:
+        with DownloadProgressBar(unit='B', unit_scale=True,
+                                 miniters=1, desc=url.split('/')[-1]) as t:
+            urllib.request.urlretrieve(url, filename=file_path, reporthook=t.update_to)
+        logging.info("Download completed.")
+    except Exception as e:
+        logging.error(f"Failed to download model: {e}")
+        if os.path.exists(file_path):
+            os.remove(file_path)
+        raise e
+
+def create_operators(op_param_list, global_config=None):
+    ops = []
+    for operator in op_param_list:
+        assert isinstance(operator, dict) and len(operator) == 1, "yaml format error"
+        op_name = list(operator)[0]
+        param = {} if operator[op_name] is None else operator[op_name]
+        if global_config is not None:
+            param.update(global_config)
+        op = getattr(operators, op_name)(**param)
+        ops.append(op)
+    return ops
+
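create_operators expects a YAML-derived list of single-key dicts: the key names an operator class in operators.py, the value holds its keyword arguments (None means no arguments). A hedged sketch reusing operator names from the detector configuration below; the dummy image is an assumption:

    import numpy as np

    pre_ops = create_operators([
        {"DetResizeForTest": {"limit_side_len": 960, "limit_type": "max"}},
        {"ToCHWImage": None},  # None -> instantiated with no kwargs
    ])
    data = {"image": np.zeros((64, 128, 3), dtype=np.uint8)}  # dummy BGR image
    for op in pre_ops:
        data = op(data)  # each operator rewrites the shared dict in sequence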
+class TextRecognizer:
+    def __init__(self, model_dir=None, device_id: int | None = None):
+        # VietOCR Configuration
+        config = Cfg.load_config_from_name('vgg-seq2seq')
+        weights_path = os.path.join(get_project_base_directory(), "vietocr", "weight", "vgg_seq2seq.pth")
+        config['weights'] = weights_path
+        config['cnn']['pretrained'] = True
+        config['device'] = 'cpu'  # Optimized with Quantization in translate.py
+
+        logging.info("Initializing VietOCR (Text Recognition)...")
+        self.detector = Predictor(config)
+
+    def __call__(self, img_list):
+        results = []
+        for img in img_list:
+            # Ensure PIL Image
+            if isinstance(img, np.ndarray):
+                img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
+
+            try:
+                text = self.detector.predict(img)
+                results.append((text, 1.0))
+            except Exception as e:
+                logging.warning(f"VietOCR prediction failed: {e}")
+                results.append(("", 0.0))
+        return results, 0.0
+
+
+class TextRecognizerPaddleOCR:
+    """
+    PaddleOCR Recognition Model (Python API)
+    Replaces VietOCR with a higher-accuracy recognition model
+
+    Uses PaddleOCR library directly instead of ONNX for simpler setup
+    """
+    def __init__(self, model_dir=None, device_id: int | None = None):
+        from paddleocr import PaddleOCR
+
+        # Initialize PaddleOCR with the English language pack (supports Latin alphabet + Vietnamese)
+        logging.info("Initializing PaddleOCR (English/Latin Recognition)...")
+
+        # Initialize with only lang parameter (most compatible)
+        self.ocr = PaddleOCR(lang='en')
+
+    def __call__(self, img_list):
+        """Recognize text from image crops"""
+        results = []
+
+        for img in img_list:
+            try:
+                # Convert to numpy if PIL Image
+                if isinstance(img, Image.Image):
+                    img = np.array(img)
+
+                # Convert to RGB if needed
+                if len(img.shape) == 2:
+                    img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
+                elif img.shape[2] == 4:
+                    img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)
+
+                # Use PaddleOCR's text recognizer directly (skip detection)
+                # Format image for recognition (convert to correct format)
+                if img.shape[0] < 32:  # Too small height
+                    img = cv2.resize(img, (int(img.shape[1] * 32 / img.shape[0]), 32))
+
+                # Call recognizer directly
+                rec_result = self.ocr.rec(img)
+
+                # Extract text and confidence from recognizer output
+                if rec_result and len(rec_result) > 0:
+                    # rec_result format: [(text, confidence), ...]
+                    if isinstance(rec_result[0], (tuple, list)) and len(rec_result[0]) >= 2:
+                        text = str(rec_result[0][0])
+                        confidence = float(rec_result[0][1])
+                    else:
+                        text = str(rec_result[0])
+                        confidence = 1.0
+                else:
+                    text = ""
+                    confidence = 0.0
+
+                results.append((text, float(confidence)))
+
+            except Exception as e:
+                logging.warning(f"PaddleOCR recognition failed: {e}")
+                results.append(("", 0.0))
+
+        return results, 0.0
+
+
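Both recognizer classes share one call contract: given a list of cropped line images, they return (results, elapse), where results holds (text, confidence) pairs in input order and elapse is a placeholder 0.0. A hedged sketch; the blank crop is illustrative and the paddleocr package must be installed:

    rec = TextRecognizerPaddleOCR()
    crops = [np.full((48, 320, 3), 255, dtype=np.uint8)]  # one white 48x320 crop
    results, _ = rec(crops)
    text, score = results[0]  # ("", 0.0) when nothing is recognized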
+class TextDetector:
+    def __init__(self, model_dir, device_id: int | None = None):
+        # ONNX Model Path
+        self.model_path = os.path.join(model_dir, "det.onnx")
+
+        # Download if not exists
+        if not os.path.exists(self.model_path):
+            logging.info("ONNX detection model not found. Downloading from monkt/paddleocr-onnx...")
+            url = "https://huggingface.co/monkt/paddleocr-onnx/resolve/main/detection/v5/det.onnx"
+            download_file(url, self.model_path)
+
+        # Initialize ONNX Session
+        sess_options = ort.SessionOptions()
+        sess_options.intra_op_num_threads = 2
+        sess_options.inter_op_num_threads = 2
+
+        providers = ['CPUExecutionProvider']
+        # Enable CUDA if requested/available
+        if device_id is not None and torch.cuda.is_available():
+            pass  # Stick to CPU as requested, or add 'CUDAExecutionProvider' if needed.
+
+        logging.info(f"Loading ONNX Text Detector from {self.model_path}")
+        self.session = ort.InferenceSession(self.model_path, sess_options, providers=providers)
+        self.input_tensor_name = self.session.get_inputs()[0].name
+
+        # Preprocess Configuration (Standard DBNet / PP-OCR)
+        pre_process_list = [{"DetResizeForTest": {"limit_side_len": 960, "limit_type": "max"}},
+                            {"NormalizeImage": {"std": [0.229, 0.224, 0.225], "mean": [0.485, 0.456, 0.406], "scale": "1./255.", "order": "hwc"}},
+                            {"ToCHWImage": None},
+                            {"KeepKeys": {"keep_keys": ["image", "shape"]}}]
+        self.preprocess_op = create_operators(pre_process_list)
+
+        # Postprocess Configuration
+        postprocess_params = {
+            "name": "DBPostProcess",
+            "thresh": 0.3,
+            "box_thresh": 0.6,
+            "max_candidates": 1000,
+            "unclip_ratio": 1.5,
+            "use_dilation": False,
+            "score_mode": "fast",
+            "box_type": "quad"
+        }
+        self.postprocess_op = build_post_process(postprocess_params)
+
+    def order_points_clockwise(self, pts):
+        rect = np.zeros((4, 2), dtype="float32")
+        s = pts.sum(axis=1)
+        rect[0] = pts[np.argmin(s)]
+        rect[2] = pts[np.argmax(s)]
+        tmp = np.delete(pts, (np.argmin(s), np.argmax(s)), axis=0)
+        diff = np.diff(np.array(tmp), axis=1)
+        rect[1] = tmp[np.argmin(diff)]
+        rect[3] = tmp[np.argmax(diff)]
+        return rect
+
+    def filter_tag_det_res(self, dt_boxes, image_shape):
+        img_height, img_width = image_shape[0:2]
+        dt_boxes_new = []
+        for box in dt_boxes:
+            if isinstance(box, list):
+                box = np.array(box)
+            box = self.order_points_clockwise(box)
+            box[:, 0] = np.clip(box[:, 0], 0, img_width - 1)
+            box[:, 1] = np.clip(box[:, 1], 0, img_height - 1)
+
+            rect_width = int(np.linalg.norm(box[0] - box[1]))
+            rect_height = int(np.linalg.norm(box[0] - box[3]))
+            if rect_width <= 3 or rect_height <= 3:
+                continue
+            dt_boxes_new.append(box)
+        return np.array(dt_boxes_new)
+
+    def transform(self, data, ops=None):
+        if ops is None:
+            ops = []
+        for op in ops:
+            data = op(data)
+            if data is None:
+                return None
+        return data
+
+    def __call__(self, img):
+        ori_im = img.copy()
+        data = {'image': img}
+
+        st = time.time()
+
+        # Preprocess
+        data = self.transform(data, self.preprocess_op)
+        if data is None:
+            return None, 0
+
+        img, shape_list = data
+
+        img = np.expand_dims(img, axis=0)
+        shape_list = np.expand_dims(shape_list, axis=0)
+
+        # Inference
+        input_dict = {self.input_tensor_name: img}
+        outputs = self.session.run(None, input_dict)
+
+        # Postprocess
+        post_result = self.postprocess_op({"maps": outputs[0]}, shape_list)
+        dt_boxes = post_result[0]['points']
+
+        if dt_boxes is None or len(dt_boxes) == 0:
+            return None, time.time() - st
+
+        dt_boxes = self.filter_tag_det_res(dt_boxes, ori_im.shape)
+
+        return dt_boxes, time.time() - st
+
+
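TextDetector wraps a DBNet-style ONNX detector: the preprocess chain resizes and normalizes, the session produces a probability map, and DBPostProcess turns it into quadrilaterals that are then clipped to the image and filtered of near-degenerate boxes. A hedged sketch of calling it standalone; the paths are illustrative:

    detector = TextDetector("/path/to/onnx_dir")  # downloads det.onnx on first use
    image = cv2.imread("page.png")                # BGR ndarray
    boxes, elapse = detector(image)
    if boxes is not None:
        print(f"{len(boxes)} text quads in {elapse:.3f}s; each box is a 4x2 array")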
+class OCR:
+    def __init__(self, model_dir=None, use_paddleocr_rec=True):
+        """
+        Initialize OCR pipeline
+
+        Args:
+            model_dir: Directory for model files
+            use_paddleocr_rec: If True, use PaddleOCR recognition (recommended)
+                               If False, use VietOCR recognition (legacy)
+        """
+        if not model_dir:
+            model_dir = os.path.join(get_project_base_directory(), "onnx")
+
+        os.makedirs(model_dir, exist_ok=True)
+
+        # Detect parallel devices (optional, mostly for GPU)
+        try:
+            import torch.cuda
+            parallel_devices = torch.cuda.device_count() if torch.cuda.is_available() else 0
+        except Exception:
+            parallel_devices = 0
+
+        # Log which recognizer is being used
+        recognizer_name = "PaddleOCR" if use_paddleocr_rec else "VietOCR"
+        logging.info(f"Initializing OCR with {recognizer_name} recognizer")
+
+        # Initialize Detector and Recognizer
+        if parallel_devices > 0:
+            self.text_detector = []
+            self.text_recognizer = []
+            for device_id in range(parallel_devices):
+                self.text_detector.append(TextDetector(model_dir, device_id))
+
+                # Choose recognizer based on flag
+                if use_paddleocr_rec:
+                    self.text_recognizer.append(TextRecognizerPaddleOCR(model_dir, device_id))
+                else:
+                    self.text_recognizer.append(TextRecognizer(model_dir, device_id))
+        else:
+            self.text_detector = [TextDetector(model_dir, 0)]
+
+            # Choose recognizer based on flag
+            if use_paddleocr_rec:
+                self.text_recognizer = [TextRecognizerPaddleOCR(model_dir, 0)]
+            else:
+                self.text_recognizer = [TextRecognizer(model_dir, 0)]
+
+        self.drop_score = 0.5
+
+    def get_rotate_crop_image(self, img, points):
+        assert len(points) == 4, "shape of points must be 4*2"
+        img_crop_width = int(
+            max(
+                np.linalg.norm(points[0] - points[1]),
+                np.linalg.norm(points[2] - points[3])))
+        img_crop_height = int(
+            max(
+                np.linalg.norm(points[0] - points[3]),
+                np.linalg.norm(points[1] - points[2])))
+        pts_std = np.float32([[0, 0], [img_crop_width, 0],
+                              [img_crop_width, img_crop_height],
+                              [0, img_crop_height]])
+        M = cv2.getPerspectiveTransform(points, pts_std)
+        dst_img = cv2.warpPerspective(
+            img,
+            M, (img_crop_width, img_crop_height),
+            borderMode=cv2.BORDER_REPLICATE,
+            flags=cv2.INTER_CUBIC)
+        dst_img_height, dst_img_width = dst_img.shape[0:2]
+        if dst_img_height * 1.0 / dst_img_width >= 1.5:
+            dst_img = np.rot90(dst_img)
+        return dst_img
+
+    def sorted_boxes(self, dt_boxes):
+        num_boxes = dt_boxes.shape[0]
+        sorted_boxes = sorted(dt_boxes, key=lambda x: (x[0][1], x[0][0]))
+        _boxes = list(sorted_boxes)
+
+        for i in range(num_boxes - 1):
+            for j in range(i, -1, -1):
+                if abs(_boxes[j + 1][0][1] - _boxes[j][0][1]) < 10 and \
+                        (_boxes[j + 1][0][0] < _boxes[j][0][0]):
+                    tmp = _boxes[j]
+                    _boxes[j] = _boxes[j + 1]
+                    _boxes[j + 1] = tmp
+                else:
+                    break
+        return _boxes
+
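sorted_boxes imposes reading order: an initial sort by top-left (y, x), then a bubble pass that swaps neighbours whose top edges sit within 10 px of each other, so boxes on the same visual row run left to right. A toy check, assuming an OCR instance named ocr:

    boxes = np.array([
        [[120, 8], [180, 8], [180, 28], [120, 28]],  # row 1, right (slightly higher)
        [[10, 10], [80, 10], [80, 30], [10, 30]],    # row 1, left
        [[10, 60], [90, 60], [90, 82], [10, 82]],    # row 2
    ], dtype=np.float32)
    ordered = ocr.sorted_boxes(boxes)
    # the y-sort alone would put "right" first; the 10 px row band restores left -> right -> row 2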
+    def detect(self, img, device_id: int | None = None):
+        if device_id is None: device_id = 0
+
+        # Handle list if accessed by index
+        if isinstance(self.text_detector, list):
+            detector = self.text_detector[device_id if device_id < len(self.text_detector) else 0]
+        else:
+            detector = self.text_detector
+
+        dt_boxes, elapse = detector(img)
+
+        if dt_boxes is None or len(dt_boxes) == 0:
+            return []
+
+        return zip(self.sorted_boxes(dt_boxes), [("", 0) for _ in range(len(dt_boxes))])
+
+    def recognize(self, ori_im, box, device_id: int | None = None):
+        if device_id is None: device_id = 0
+
+        if isinstance(self.text_recognizer, list):
+            recognizer = self.text_recognizer[device_id if device_id < len(self.text_recognizer) else 0]
+        else:
+            recognizer = self.text_recognizer
+
+        img_crop = self.get_rotate_crop_image(ori_im, box)
+        rec_res, elapse = recognizer([img_crop])
+        text, score = rec_res[0]
+        if score < self.drop_score:
+            return ""
+        return text
+
+    def recognize_batch(self, img_list, device_id: int | None = None):
+        if device_id is None: device_id = 0
+        if isinstance(self.text_recognizer, list):
+            recognizer = self.text_recognizer[device_id if device_id < len(self.text_recognizer) else 0]
+        else:
+            recognizer = self.text_recognizer
+
+        rec_res, elapse = recognizer(img_list)
+        texts = []
+        for i in range(len(rec_res)):
+            text, score = rec_res[i]
+            if score < self.drop_score:
+                text = ""
+            texts.append(text)
+        return texts
+
+    def __call__(self, img, device_id=0, cls=True):
+        if device_id is None: device_id = 0
+
+        # Access detector/recognizer safely
+        if isinstance(self.text_detector, list):
+            detector = self.text_detector[device_id if device_id < len(self.text_detector) else 0]
+            recognizer = self.text_recognizer[device_id if device_id < len(self.text_recognizer) else 0]
+        else:
+            detector = self.text_detector
+            recognizer = self.text_recognizer
+
+        # Detection
+        start = time.time()
+        ori_im = img.copy()
+        dt_boxes, det_time = detector(img)
+
+        if dt_boxes is None or len(dt_boxes) == 0:
+            return []
+
+        # Crop
+        img_crop_list = []
+        dt_boxes = self.sorted_boxes(dt_boxes)
+        for bno in range(len(dt_boxes)):
+            tmp_box = copy.deepcopy(dt_boxes[bno])
+            img_crop = self.get_rotate_crop_image(ori_im, tmp_box)
+            img_crop_list.append(img_crop)
+
+        # Recognition
+        rec_res, rec_time = recognizer(img_crop_list)
+
+        # Filter
+        filter_boxes, filter_rec_res = [], []
+        for box, rec_result in zip(dt_boxes, rec_res):
+            text, score = rec_result
+            if score >= self.drop_score:
+                filter_boxes.append(box)
+                filter_rec_res.append(rec_result)
+
+        return list(zip([a.tolist() for a in filter_boxes], filter_rec_res))
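End to end, OCR.__call__ runs detect, sort into reading order, perspective-crop each quad, recognize, then drop results scoring under drop_score (0.5). A hedged usage sketch; the image path is illustrative:

    import cv2
    from xfmr_zem.servers.ocr.deepdoc_vietocr.ocr import OCR

    ocr = OCR()  # PaddleOCR recognizer by default; pass use_paddleocr_rec=False for VietOCR
    img = cv2.imread("scan.png")
    for box, (text, score) in ocr(img):
        print(round(score, 2), text, box)  # box is a 4x2 list of corner points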
xfmr_zem/servers/ocr/deepdoc_vietocr/onnx/.gitattributes
@@ -0,0 +1,35 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text