xfmr-zem 0.2.4-py3-none-any.whl → 0.2.6-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xfmr_zem/cli.py +32 -3
- xfmr_zem/client.py +59 -8
- xfmr_zem/server.py +21 -4
- xfmr_zem/servers/data_juicer/server.py +1 -1
- xfmr_zem/servers/instruction_gen/server.py +1 -1
- xfmr_zem/servers/io/server.py +1 -1
- xfmr_zem/servers/llm/parameters.yml +10 -0
- xfmr_zem/servers/nemo_curator/server.py +1 -1
- xfmr_zem/servers/ocr/deepdoc_vietocr/__init__.py +90 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/implementations.py +1286 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/layout_recognizer.py +562 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/ocr.py +512 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/onnx/.gitattributes +35 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/onnx/README.md +5 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/onnx/ocr.res +6623 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/operators.py +725 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/phases.py +191 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/pipeline.py +561 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/postprocess.py +370 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/recognizer.py +436 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/table_structure_recognizer.py +569 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/utils/__init__.py +81 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/utils/file_utils.py +246 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/__init__.py +0 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/config/base.yml +58 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/config/vgg-seq2seq.yml +38 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/model/__init__.py +0 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/model/backbone/cnn.py +25 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/model/backbone/vgg.py +51 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/model/seqmodel/seq2seq.py +175 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/model/transformerocr.py +29 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/model/vocab.py +36 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/tool/config.py +37 -0
- xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/tool/translate.py +111 -0
- xfmr_zem/servers/ocr/engines.py +242 -0
- xfmr_zem/servers/ocr/install_models.py +63 -0
- xfmr_zem/servers/ocr/parameters.yml +4 -0
- xfmr_zem/servers/ocr/server.py +102 -0
- xfmr_zem/servers/profiler/parameters.yml +4 -0
- xfmr_zem/servers/sinks/parameters.yml +6 -0
- xfmr_zem/servers/unstructured/parameters.yml +6 -0
- xfmr_zem/servers/unstructured/server.py +62 -0
- xfmr_zem/zenml_wrapper.py +20 -7
- {xfmr_zem-0.2.4.dist-info → xfmr_zem-0.2.6.dist-info}/METADATA +20 -1
- xfmr_zem-0.2.6.dist-info/RECORD +58 -0
- xfmr_zem-0.2.4.dist-info/RECORD +0 -23
- /xfmr_zem/servers/data_juicer/{parameter.yaml → parameters.yml} +0 -0
- /xfmr_zem/servers/instruction_gen/{parameter.yaml → parameters.yml} +0 -0
- /xfmr_zem/servers/io/{parameter.yaml → parameters.yml} +0 -0
- /xfmr_zem/servers/nemo_curator/{parameter.yaml → parameters.yml} +0 -0
- {xfmr_zem-0.2.4.dist-info → xfmr_zem-0.2.6.dist-info}/WHEEL +0 -0
- {xfmr_zem-0.2.4.dist-info → xfmr_zem-0.2.6.dist-info}/entry_points.txt +0 -0
- {xfmr_zem-0.2.4.dist-info → xfmr_zem-0.2.6.dist-info}/licenses/LICENSE +0 -0
xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/tool/config.py
@@ -0,0 +1,37 @@

```python
import yaml

def load_config(config_file):
    with open(config_file, encoding='utf-8') as f:
        config = yaml.safe_load(f)

    return config

class Cfg(dict):
    def __init__(self, config_dict):
        super(Cfg, self).__init__(**config_dict)
        self.__dict__ = self

    @staticmethod
    def load_config_from_file(fname, base_file=None):
        from pathlib import Path
        if base_file is None:
            base_file = Path(__file__).resolve().parent.parent / 'config' / 'base.yml'

        base_config = load_config(base_file)

        with open(fname, encoding='utf-8') as f:
            config = yaml.safe_load(f)
        base_config.update(config)

        return Cfg(base_config)

    @staticmethod
    def load_config_from_name(name):
        from pathlib import Path
        config_dir = Path(__file__).resolve().parent.parent / 'config'
        return Cfg.load_config_from_file(config_dir / f'{name}.yml')

    def save(self, fname):
        with open(fname, 'w') as outfile:
            yaml.dump(dict(self), outfile, default_flow_style=False, allow_unicode=True)
```
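For orientation, a minimal usage sketch of the new `Cfg` helper. It assumes the `base.yml` and `vgg-seq2seq.yml` files added under `vietocr/config` in this release; the output path is hypothetical.

```python
# Hypothetical usage sketch of the Cfg helper added above.
from xfmr_zem.servers.ocr.deepdoc_vietocr.vietocr.tool.config import Cfg

# Loads base.yml first, then overlays vgg-seq2seq.yml on top of it.
cfg = Cfg.load_config_from_name('vgg-seq2seq')

cfg['device'] = 'cpu'  # behaves like a plain dict...
print(cfg.device)      # ...and like an attribute bag, since __init__ sets self.__dict__ = self

cfg.save('/tmp/effective_config.yml')  # dumps the merged config back to YAML
```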
xfmr_zem/servers/ocr/deepdoc_vietocr/vietocr/tool/translate.py
@@ -0,0 +1,111 @@

```python
import os
import torch
import numpy as np
import cv2
from ..model.vocab import Vocab
from ..model.transformerocr import VietOCR
import math
from PIL import Image


def translate(img, model, max_seq_length=128, sos_token=1, eos_token=2):
    """data: BxCxHxW"""
    model.eval()
    device = img.device

    with torch.no_grad():
        src = model.cnn(img)
        memory = model.transformer.forward_encoder(src)

        translated_sentence = [[sos_token] * len(img)]
        max_length = 0

        while max_length <= max_seq_length and not all(np.any(np.asarray(translated_sentence).T == eos_token, axis=1)):
            tgt_inp = torch.LongTensor(translated_sentence).to(device)
            output, memory = model.transformer.forward_decoder(tgt_inp, memory)
            output = output.to('cpu')

            values, indices = torch.topk(output, 1)
            indices = indices[:, -1, 0]
            indices = indices.tolist()

            translated_sentence.append(indices)
            max_length += 1

            del output

        translated_sentence = np.asarray(translated_sentence).T

    return translated_sentence


def build_model(config):
    vocab = Vocab(config['vocab'])
    device = config['device']

    model = VietOCR(len(vocab),
                    config['backbone'],
                    config['cnn'],
                    config['transformer'],
                    config['seq_modeling'])

    model = model.to(device)

    if 'weights' in config and config['weights']:
        weights = config['weights']
        if weights.startswith('http'):
            # Logic for downloading could go here, but we assume local path for now
            pass

        if os.path.exists(weights):
            model.load_state_dict(torch.load(weights, map_location=device))
            model.eval()
        else:
            import logging
            logging.warning(f"Weight file not found: {weights}")

    return model, vocab

def resize(w, h, expected_height, image_min_width, image_max_width):
    new_w = int(expected_height * float(w) / float(h))
    round_to = 10
    new_w = math.ceil(new_w / round_to) * round_to
    new_w = max(new_w, image_min_width)
    new_w = min(new_w, image_max_width)

    return new_w, expected_height

def process_image(image, image_height, image_min_width, image_max_width):
    img = image.convert('RGB')

    w, h = img.size
    new_w, image_height = resize(w, h, image_height, image_min_width, image_max_width)

    img = img.resize((new_w, image_height), Image.LANCZOS)

    img = np.asarray(img).transpose(2, 0, 1)
    img = img / 255
    return img


def process_input(image, image_height, image_min_width, image_max_width):
    img = process_image(image, image_height, image_min_width, image_max_width)
    img = img[np.newaxis, ...]
    img = torch.FloatTensor(img)
    return img


class Predictor:
    def __init__(self, config):
        self.model, self.vocab = build_model(config)
        self.config = config

    def predict(self, img):
        img_input = process_input(img, self.config['dataset']['image_height'],
                                  self.config['dataset']['image_min_width'],
                                  self.config['dataset']['image_max_width'])
        img_input = img_input.to(self.config['device'])
        s = translate(img_input, self.model)
        s = s[0].tolist()
        s = self.vocab.decode(s)
        return s
```
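Here `translate()` performs greedy autoregressive decoding (top-k with k=1) until every sequence in the batch has emitted `eos_token`, and `Predictor` wraps preprocessing, decoding, and vocabulary decoding. A hedged end-to-end sketch follows; the weights path and input image are hypothetical, and the config name assumes the packaged `vgg-seq2seq.yml`.

```python
# Sketch only: the weights path is a hypothetical local file; see
# install_models.py below for where the release actually fetches it.
from PIL import Image
from xfmr_zem.servers.ocr.deepdoc_vietocr.vietocr.tool.config import Cfg
from xfmr_zem.servers.ocr.deepdoc_vietocr.vietocr.tool.translate import Predictor

config = Cfg.load_config_from_name('vgg-seq2seq')
config['device'] = 'cpu'
config['weights'] = './vgg_seq2seq.pth'  # hypothetical path; a missing file only logs a warning

predictor = Predictor(config)
line_image = Image.open('line_crop.png')  # a single cropped text line, not a full page
print(predictor.predict(line_image))
```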
xfmr_zem/servers/ocr/engines.py
@@ -0,0 +1,242 @@

```python
import os
import abc
from typing import Dict, Any, List
from PIL import Image
from loguru import logger

class OCREngineBase(abc.ABC):
    """
    Abstract Base Class for OCR Engines (Dependency Inversion & Open/Closed).
    """
    @abc.abstractmethod
    def process(self, image_path: str) -> Dict[str, Any]:
        """Process an image and return extracted text and metadata."""
        pass

class TesseractEngine(OCREngineBase):
    """
    Lightweight OCR using Tesseract (Fast & Simple).
    """
    def __init__(self):
        logger.debug("TesseractEngine: Initializing...")
        try:
            import pytesseract
            import shutil

            logger.debug("TesseractEngine: Checking for tesseract binary...")
            # Check if tesseract binary exists
            if not shutil.which("tesseract"):
                raise RuntimeError(
                    "Tesseract binary not found. To use the 'tesseract' engine, "
                    "please install it using: sudo apt install tesseract-ocr"
                )

            self.pytesseract = pytesseract
            logger.debug("TesseractEngine: Initialization complete")
        except ImportError:
            logger.error("pytesseract not installed. Please install with 'pip install pytesseract'")
            raise

    def process(self, image_path: str) -> Dict[str, Any]:
        logger.info(f"Using Tesseract to process: {image_path}")
        image = Image.open(image_path)
        text = self.pytesseract.image_to_string(image)
        return {
            "text": text,
            "engine": "tesseract",
            "metadata": {"format": image.format, "size": image.size}
        }

class PaddleEngine(OCREngineBase):
    """
    Medium-weight OCR using PaddleOCR (High accuracy for multi-language).
    """
    def __init__(self):
        logger.debug("PaddleEngine: Initializing...")
        try:
            logger.debug("PaddleEngine: Importing PaddleOCR...")
            from paddleocr import PaddleOCR
            logger.debug("PaddleEngine: Creating PaddleOCR instance (use_angle_cls=True, lang='en')...")
            self.ocr = PaddleOCR(use_angle_cls=True, lang='en')  # Default to English
            logger.debug("PaddleEngine: Initialization complete")
        except ImportError:
            logger.error("paddleocr not installed. Please install with 'pip install paddleocr paddlepaddle'")
            raise

    def process(self, image_path: str) -> Dict[str, Any]:
        logger.info(f"Using PaddleOCR to process: {image_path}")
        result = self.ocr.ocr(image_path, cls=True)

        full_text = []
        scores = []
        for line in result:
            if line:
                for res in line:
                    full_text.append(res[1][0])
                    scores.append(float(res[1][1]))

        return {
            "text": "\n".join(full_text),
            "engine": "paddleocr",
            "metadata": {"avg_confidence": sum(scores) / len(scores) if scores else 0}
        }

class HuggingFaceVLEngine(OCREngineBase):
    """
    Advanced OCR using Hugging Face Vision Language Models (e.g. Qwen2-VL, Molmo).
    """
    def __init__(self, model_id: str = "Qwen/Qwen2-VL-2B-Instruct"):
        self.model_id = model_id or "Qwen/Qwen2-VL-2B-Instruct"
        self.model = None
        self.processor = None

    def _lazy_load(self):
        if self.model is None:
            try:
                logger.debug(f"HuggingFaceVLEngine: Starting lazy load for model: {self.model_id}")
                import torch
                logger.debug(f"HuggingFaceVLEngine: PyTorch version={torch.__version__}, CUDA available={torch.cuda.is_available()}")
                from transformers import AutoModelForVision2Seq, AutoProcessor

                logger.info(f"Loading Hugging Face VL model: {self.model_id} (this may take a while)...")
                logger.debug(f"HuggingFaceVLEngine: Loading processor from {self.model_id}...")
                self.processor = AutoProcessor.from_pretrained(self.model_id)
                logger.debug("HuggingFaceVLEngine: Processor loaded successfully")

                # Use GPU if available
                device = "cuda" if torch.cuda.is_available() else "cpu"

                # Default to float32
                dtype = torch.float32
                if device == "cuda":
                    # Check compute capability
                    cc_major = torch.cuda.get_device_properties(0).major
                    # Pascal (6.1) has poor FP16 performance, so use FP32.
                    # Volta (7.0) and newer have good FP16 performance.
                    if cc_major >= 7:
                        dtype = torch.float16

                logger.debug(f"HuggingFaceVLEngine: Loading model with dtype={dtype}, device='{device}'...")
                # Using AutoModelForVision2Seq for generality
                self.model = AutoModelForVision2Seq.from_pretrained(
                    self.model_id,
                    torch_dtype=dtype,
                    device_map=None,
                    trust_remote_code=True
                ).to(device)
                self._device = device
                logger.debug(f"HuggingFaceVLEngine: Model loaded successfully on {device}")
            except ImportError:
                logger.error("transformers/torch not installed. Required for HuggingFace-VL.")
                raise
            except Exception as e:
                logger.error(f"Error loading model {self.model_id}: {e}")
                raise

    def process(self, image_path: str) -> Dict[str, Any]:
        self._lazy_load()
        logger.info(f"Using {self.model_id} via HuggingFaceVLEngine to process: {image_path}")

        image = Image.open(image_path).convert("RGB")

        # Use proper chat template format for Qwen2-VL
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": "Extract all text from this image exactly as it appears."}
                ]
            }
        ]

        # Apply chat template for proper formatting
        text_input = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = self.processor(text=[text_input], images=[image], return_tensors="pt", padding=True).to(self._device)

        generated_ids = self.model.generate(**inputs, max_new_tokens=512)
        # Only decode new tokens
        generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
        text = self.processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True)[0]

        return {
            "text": text,
            "engine": "huggingface",
            "metadata": {"model_id": self.model_id}
        }

class VietOCREngine(OCREngineBase):
    """
    Specialized Vietnamese OCR using built-in Deep-ocr DocumentPipeline (Layout + OCR + MD).
    """
    def __init__(self):
        logger.debug("VietOCREngine: Initializing...")
        try:
            logger.debug("VietOCREngine: Importing DocumentPipeline and components...")
            from xfmr_zem.servers.ocr.deepdoc_vietocr.pipeline import DocumentPipeline
            from xfmr_zem.servers.ocr.deepdoc_vietocr.implementations import (
                PaddleStructureV3Analyzer,
                PaddleOCRTextDetector,
                VietOCRRecognizer,
                VietnameseTextPostProcessor,
                SmartMarkdownReconstruction
            )

            logger.info("Initializing Internal Deep-ocr DocumentPipeline for Vietnamese...")
            logger.debug("VietOCREngine: Creating PaddleStructureV3Analyzer...")
            layout_analyzer = PaddleStructureV3Analyzer()
            logger.debug("VietOCREngine: Creating PaddleOCRTextDetector...")
            text_detector = PaddleOCRTextDetector()
            logger.debug("VietOCREngine: Creating VietOCRRecognizer...")
            text_recognizer = VietOCRRecognizer()
            logger.debug("VietOCREngine: Creating VietnameseTextPostProcessor...")
            post_processor = VietnameseTextPostProcessor()
            logger.debug("VietOCREngine: Creating SmartMarkdownReconstruction...")
            reconstructor = SmartMarkdownReconstruction()

            logger.debug("VietOCREngine: Assembling DocumentPipeline...")
            self.pipeline = DocumentPipeline(
                layout_analyzer=layout_analyzer,
                text_detector=text_detector,
                text_recognizer=text_recognizer,
                post_processor=post_processor,
                reconstructor=reconstructor
            )
            logger.debug("VietOCREngine: Initialization complete")
        except Exception as e:
            logger.error(f"Error loading internal Deep-ocr components: {e}")
            import traceback
            logger.error(traceback.format_exc())
            raise

    def process(self, image_path: str) -> Dict[str, Any]:
        logger.info(f"Using Internal Deep-ocr (DocumentPipeline) to process: {image_path}")
        from PIL import Image

        img = Image.open(image_path)

        # pipeline.process returns reconstructed markdown text
        markdown_text = self.pipeline.process(img)

        return {
            "text": markdown_text,
            "engine": "deepdoc_vietocr",
            "metadata": {"format": "markdown"}
        }

class OCREngineFactory:
    """
    Factory to create OCR engines (Switching strategy).
    """
    @staticmethod
    def get_engine(engine_type: str, **kwargs) -> OCREngineBase:
        if engine_type == "tesseract":
            return TesseractEngine()
        elif engine_type == "paddle":
            return PaddleEngine()
        elif engine_type == "huggingface" or engine_type == "qwen":
            return HuggingFaceVLEngine(model_id=kwargs.get("model_id"))
        elif engine_type == "viet":
            return VietOCREngine()
        else:
            raise ValueError(f"Unknown engine type: {engine_type}")
```
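The factory is the only construction path the OCR server uses, so callers can swap strategies without touching call sites. A brief sketch; the input image path is hypothetical, and the tesseract engine additionally requires the system `tesseract` binary.

```python
# Minimal sketch of the factory; 'scan.png' is a hypothetical input.
from xfmr_zem.servers.ocr.engines import OCREngineFactory

engine = OCREngineFactory.get_engine("tesseract")
result = engine.process("scan.png")
print(result["engine"], result["metadata"])

# Same call surface for the heavier strategies (Open/Closed at the call site):
vl = OCREngineFactory.get_engine("huggingface", model_id="Qwen/Qwen2-VL-2B-Instruct")
print(vl.process("scan.png")["text"])
```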
xfmr_zem/servers/ocr/install_models.py
@@ -0,0 +1,63 @@

```python
import os
import urllib.request
from tqdm import tqdm
import logging
from pathlib import Path

logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')

def download_file(url, file_path):
    class DownloadProgressBar(tqdm):
        def update_to(self, b=1, bsize=1, tsize=None):
            if tsize is not None:
                self.total = tsize
            self.update(b * bsize - self.n)

    if os.path.exists(file_path):
        logging.info(f"File already exists: {file_path}")
        return

    logging.info(f"Downloading {url} to {file_path}")
    os.makedirs(os.path.dirname(file_path), exist_ok=True)

    try:
        with DownloadProgressBar(unit='B', unit_scale=True,
                                 miniters=1, desc=url.split('/')[-1]) as t:
            urllib.request.urlretrieve(url, filename=file_path, reporthook=t.update_to)
        logging.info("Download completed.")
    except Exception as e:
        logging.error(f"Failed to download: {e}")
        if os.path.exists(file_path):
            os.remove(file_path)
        raise e

def main():
    base_dir = Path(__file__).resolve().parent / "deepdoc_vietocr"

    models = [
        # Detection
        ("https://huggingface.co/monkt/paddleocr-onnx/resolve/main/detection/v5/det.onnx",
         base_dir / "onnx" / "det.onnx"),

        # Layout Analysis
        ("https://huggingface.co/monkt/paddleocr-onnx/resolve/main/layout/v1/layout.onnx",
         base_dir / "onnx" / "layout.onnx"),

        # Table Structure
        ("https://huggingface.co/monkt/paddleocr-onnx/resolve/main/tsr/v1/tsr.onnx",
         base_dir / "onnx" / "tsr.onnx"),

        # VietOCR Recognition
        ("https://github.com/p_nhm/vietocr-weights/raw/main/vgg_seq2seq.pth",
         base_dir / "vietocr" / "weight" / "vgg_seq2seq.pth")
    ]

    logging.info("Starting OCR model installation...")
    for url, path in models:
        try:
            download_file(url, str(path))
        except Exception as e:
            logging.error(f"Skipping {path} due to error: {e}")

if __name__ == "__main__":
    main()
```
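`download_file` is idempotent (existing files are skipped) and removes partial files on failure, so it can also be reused for a single artifact outside `main()`. A sketch with a hypothetical destination path:

```python
# Sketch: reusing download_file for one artifact. The URL is the detection
# model from the list above; the destination directory is hypothetical.
from xfmr_zem.servers.ocr.install_models import download_file

download_file(
    "https://huggingface.co/monkt/paddleocr-onnx/resolve/main/detection/v5/det.onnx",
    "models/det.onnx",
)
# Running it again is a no-op: the existing file is detected and skipped.
```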
xfmr_zem/servers/ocr/server.py
@@ -0,0 +1,102 @@

```python
import os
import pandas as pd
from xfmr_zem.server import ZemServer
from xfmr_zem.servers.ocr.engines import OCREngineFactory
from loguru import logger
from PIL import Image
import io

# Initialize ZemServer for OCR
mcp = ZemServer("ocr")

def extract_pdf_pages(file_path: str, engine: str, ocr_engine, model_id: str = None):
    """Helper to process PDF pages with optional OCR for scanned content."""
    import fitz  # PyMuPDF

    results = []
    doc = fitz.open(file_path)

    for page_num in range(len(doc)):
        page = doc[page_num]
        text = page.get_text().strip()

        # Determine if we need to OCR (Strategy: text is too short or empty)
        is_scanned = len(text) < 50

        if is_scanned:
            logger.info(f"Page {page_num + 1} appears scanned. Running OCR with {engine}...")
            # Render page to image for OCR
            pix = page.get_pixmap(matrix=fitz.Matrix(2, 2))  # 2x zoom for better OCR
            img_data = pix.tobytes("png")
            img = Image.open(io.BytesIO(img_data))

            # Temporary save for engine compatibility (engines expect path)
            temp_path = f"/tmp/ocr_page_{page_num}.png"
            img.save(temp_path)

            try:
                ocr_result = ocr_engine.process(temp_path)
                final_text = ocr_result["text"]
                source = f"{engine}_ocr"
            finally:
                if os.path.exists(temp_path):
                    os.remove(temp_path)
        else:
            final_text = text
            source = "digital_pdf"

        results.append({
            "text": final_text,
            "page": page_num + 1,
            "engine": source,
            "metadata": {"file": file_path, "is_scanned": is_scanned}
        })

    doc.close()
    return results

@mcp.tool()
async def extract_text(file_path: str, engine: str = "tesseract", model_id: str = None) -> pd.DataFrame:
    """
    Extracts text from an image or PDF using the specified OCR engine.
    For PDFs, it will automatically handle scanned pages using the OCR engine.

    Args:
        file_path: Path to the image or PDF file.
        engine: The OCR engine to use ("tesseract", "paddle", "huggingface", "viet"). Defaults to "tesseract".
        model_id: Optional model ID for the 'huggingface' engine.
    """
    logger.info(f"OCR Extraction: {file_path} using {engine}")

    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")

    try:
        # Get engine from factory
        ocr_engine = OCREngineFactory.get_engine(engine, model_id=model_id)

        # Handle PDF vs Image
        if file_path.lower().endswith(".pdf"):
            logger.info(f"Processing PDF file: {file_path}")
            data = extract_pdf_pages(file_path, engine, ocr_engine, model_id)
            df = pd.DataFrame(data)
        else:
            # Process image
            result = ocr_engine.process(file_path)
            df = pd.DataFrame([{
                "text": result["text"],
                "engine": result["engine"],
                "metadata": result["metadata"]
            }])

        logger.info(f"Successfully extracted text from {file_path}")
        return df.to_dict(orient="records")

    except Exception as e:
        logger.error(f"OCR Error with {engine}: {e}")
        import traceback
        logger.error(traceback.format_exc())
        raise RuntimeError(f"OCR failed: {str(e)}")

if __name__ == "__main__":
    mcp.run()
```
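Assuming ZemServer's `@mcp.tool()` decorator returns the wrapped coroutine (as FastMCP-style decorators commonly do), the tool can be exercised directly for testing; the input file is hypothetical. Note that despite the annotation, `extract_text` returns `df.to_dict(orient="records")`, i.e. a list of dicts.

```python
# Hedged sketch: assumes @mcp.tool() returns the underlying coroutine,
# which may differ depending on the ZemServer implementation.
import asyncio
from xfmr_zem.servers.ocr.server import extract_text

records = asyncio.run(extract_text("report.pdf", engine="paddle"))  # hypothetical file
for rec in records:
    # PDF records carry a 1-based page number and whether OCR was needed.
    print(rec["page"], rec["engine"], rec["metadata"]["is_scanned"])
```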
xfmr_zem/servers/unstructured/server.py
@@ -0,0 +1,62 @@

```python
import os
import pandas as pd
from xfmr_zem.server import ZemServer
from unstructured.partition.auto import partition
from loguru import logger

# Initialize ZemServer for Unstructured
mcp = ZemServer("unstructured")

@mcp.tool()
async def parse_document(file_path: str, strategy: str = "fast") -> pd.DataFrame:
    """
    Parses a document (PDF, DOCX, HTML, etc.) and returns all text segments as a DataFrame.

    Args:
        file_path: Path to the document file.
        strategy: Partitioning strategy ("fast", "hi_res", "ocr_only"). Defaults to "fast".
    """
    logger.info(f"Parsing document: {file_path} with strategy: {strategy}")

    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")

    # Use unstructured to partition the file
    elements = partition(filename=file_path, strategy=strategy)

    # Convert elements to a list of dicts
    data = []
    for el in elements:
        data.append({
            "text": str(el),
            "type": el.category,
            "element_id": el.id,
            "metadata": el.metadata.to_dict() if hasattr(el, "metadata") else {}
        })

    df = pd.DataFrame(data)
    logger.info(f"Extracted {len(df)} elements from {file_path}")
    return df

@mcp.tool()
async def extract_tables(file_path: str) -> pd.DataFrame:
    """
    Specifically extracts tables from a document and returns them.
    Note: Requires 'hi_res' strategy internally.
    """
    logger.info(f"Extracting tables from: {file_path}")

    # Partition with hi_res to get table structure
    elements = partition(filename=file_path, strategy="hi_res")

    # Filter for Table elements
    tables = [str(el) for el in elements if el.category == "Table"]

    if not tables:
        logger.warning(f"No tables found in {file_path}")
        return pd.DataFrame(columns=["table_content"])

    return pd.DataFrame({"table_content": tables})

if __name__ == "__main__":
    mcp.run()
```
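A final sketch under the same decorator assumption as the OCR server above; the input document is hypothetical. Unlike `extract_text`, these tools return `pandas.DataFrame` objects directly.

```python
# Hedged sketch: assumes @mcp.tool() returns the underlying coroutine.
import asyncio
from xfmr_zem.servers.unstructured.server import parse_document, extract_tables

df = asyncio.run(parse_document("contract.docx", strategy="fast"))  # hypothetical file
print(df[["type", "text"]].head())

tables = asyncio.run(extract_tables("contract.docx"))  # forces "hi_res" internally
print(f"{len(tables)} table(s) found")
```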