kiri-ocr 0.1.0__py3-none-any.whl
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their public registries.
- kiri_ocr/__init__.py +14 -0
- kiri_ocr/cli.py +244 -0
- kiri_ocr/core.py +306 -0
- kiri_ocr/detector.py +374 -0
- kiri_ocr/generator.py +570 -0
- kiri_ocr/model.py +159 -0
- kiri_ocr/renderer.py +193 -0
- kiri_ocr/training.py +508 -0
- kiri_ocr-0.1.0.data/scripts/kiri-ocr +6 -0
- kiri_ocr-0.1.0.dist-info/METADATA +218 -0
- kiri_ocr-0.1.0.dist-info/RECORD +16 -0
- kiri_ocr-0.1.0.dist-info/WHEEL +5 -0
- kiri_ocr-0.1.0.dist-info/licenses/LICENSE +201 -0
- kiri_ocr-0.1.0.dist-info/top_level.txt +2 -0
- models/__init__.py +1 -0
- models/model.kiri +0 -0
kiri_ocr/__init__.py
ADDED
@@ -0,0 +1,14 @@
+from .core import OCR
+from .renderer import DocumentRenderer
+from .model import LightweightOCR, CharacterSet
+from .detector import TextDetector
+
+__version__ = '0.1.0'
+
+__all__ = [
+    'OCR',
+    'DocumentRenderer',
+    'LightweightOCR',
+    'CharacterSet',
+    'TextDetector',
+]
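Note: the top-level package simply re-exports the main classes, so downstream code can import them directly. A minimal sketch of that public surface, assuming the wheel and its dependencies are installed:

import kiri_ocr
from kiri_ocr import OCR, DocumentRenderer, TextDetector

print(kiri_ocr.__version__)  # '0.1.0'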
kiri_ocr/cli.py
ADDED
@@ -0,0 +1,244 @@
+import numpy as np
+import json
+import argparse
+import sys
+from pathlib import Path
+import yaml
+
+from .core import OCR
+from .renderer import DocumentRenderer
+from .training import train_command
+from .generator import generate_command
+
+DEFAULT_TRAIN_CONFIG = {
+    "height": 32,
+    "batch_size": 32,
+    "epochs": 2,
+    "hidden_size": 256,
+    "device": "cuda",
+    "output_dir": "models",
+    "train_labels": "data/train/labels.txt",
+    "val_labels": "data/val/labels.txt",
+    "lr": 0.001,
+    "weight_decay": 0.0001
+}
+
+def init_config(args):
+    path = args.output
+    # Ensure .yaml extension if default or user didn't specify
+    if path == 'config.yaml' and not path.endswith('.yaml') and not path.endswith('.yml'):
+        pass  # user provided custom name
+
+    with open(path, 'w') as f:
+        yaml.dump(DEFAULT_TRAIN_CONFIG, f, default_flow_style=False)
+    print(f"✓ Created default config at {path}")
+
+def run_inference(args):
+    # Create output directory
+    output_dir = Path(args.output)
+    output_dir.mkdir(exist_ok=True)
+
+    print("\n" + "="*70)
+    print(" Kiri OCR System")
+    print("="*70)
+
+    try:
+        # Initialize OCR
+        ocr = OCR(
+            model_path=args.model,
+            charset_path=args.charset,
+            language=args.language,
+            padding=args.padding,
+            device=args.device,
+            verbose=True
+        )
+
+        # Process document & extract text
+        full_text, results = ocr.extract_text(args.image, mode=args.mode, verbose=True)
+
+        # Save text
+        text_output = output_dir / 'extracted_text.txt'
+        with open(text_output, 'w', encoding='utf-8') as f:
+            f.write(full_text)
+        print(f"\n✓ Text saved to {text_output}")
+
+        # Save JSON
+        json_output = output_dir / 'ocr_results.json'
+        with open(json_output, 'w', encoding='utf-8') as f:
+            json.dump(results, f, indent=2, ensure_ascii=False)
+        print(f"✓ JSON saved to {json_output}")
+
+        # Render results
+        if not args.no_render:
+            renderer = DocumentRenderer()
+
+            # Boxes only
+            renderer.draw_boxes(
+                args.image,
+                results,
+                output_path=str(output_dir / 'boxes.png')
+            )
+
+            # Boxes with text
+            renderer.draw_results(
+                args.image,
+                results,
+                output_path=str(output_dir / 'ocr_result.png')
+            )
+
+            # HTML report
+            renderer.create_report(
+                args.image,
+                results,
+                output_path=str(output_dir / 'report.html')
+            )
+
+        print("\n" + "="*70)
+        print(" ✓ Processing Complete!")
+        print("="*70)
+        print(f" Regions detected: {len(results)}")
+        if results:
+            print(f" Average confidence: {np.mean([r['confidence'] for r in results])*100:.2f}%")
+        print(f" Output directory: {output_dir}")
+        print("="*70 + "\n")
+
+    except Exception as e:
+        print(f"\n✗ Error: {e}")
+        # Suggest model path if it seems to be missing
+        if "No such file" in str(e) and "model" in str(e):
+            print("\nTip: Make sure you have trained a model or specified the correct path with --model")
+            print("     Run: kiri-ocr train ... to train a model first.")
+
+def merge_config(args, defaults):
+    """Merge defaults < config file < args"""
+    # Start with defaults
+    config = defaults.copy()
+
+    # Update with config file if provided
+    if hasattr(args, 'config') and args.config:
+        try:
+            with open(args.config, 'r') as f:
+                if args.config.endswith('.json'):
+                    file_config = json.load(f)
+                else:
+                    file_config = yaml.safe_load(f)
+
+            if file_config:
+                config.update(file_config)
+                print(f"Loaded config from {args.config}")
+        except Exception as e:
+            print(f"Error loading config file: {e}")
+            sys.exit(1)
+
+    # Update with explicit args (non-None)
+    for key, value in vars(args).items():
+        if value is not None and key in config:
+            config[key] = value
+
+    # Update args object
+    for key, value in config.items():
+        setattr(args, key, value)
+
+    return args
+
+def main():
+    parser = argparse.ArgumentParser(description='Kiri OCR - Command Line Tool')
+    subparsers = parser.add_subparsers(dest='command', help='Command to run')
+
+    # === PREDICT / RUN ===
+    predict_parser = subparsers.add_parser('predict', help='Run OCR on an image')
+    predict_parser.add_argument('image', help='Path to document image')
+    predict_parser.add_argument('--mode', choices=['lines', 'words'], default='lines',
+                                help='Detection mode (default: lines)')
+    predict_parser.add_argument('--model', default='models/model.kiri',
+                                help='Path to model file')
+    predict_parser.add_argument('--charset', default='models/charset_lite.txt',
+                                help='Path to charset file')
+    predict_parser.add_argument('--language', choices=['english', 'khmer', 'mixed'],
+                                default='mixed', help='Language mode')
+    predict_parser.add_argument('--padding', type=int, default=10,
+                                help='Padding around detected boxes in pixels (default: 10)')
+    predict_parser.add_argument('--output', '-o', default='output',
+                                help='Output directory')
+    predict_parser.add_argument('--no-render', action='store_true',
+                                help='Skip rendering (text only)')
+    predict_parser.add_argument('--device', choices=['cpu', 'cuda'], default='cpu',
+                                help='Device to use')
+
+    # === TRAIN ===
+    train_parser = subparsers.add_parser('train', help='Train the OCR model')
+    train_parser.add_argument('--config', help='Path to config file (YAML/JSON)')
+    train_parser.add_argument('--train-labels', help='Path to training labels file')
+    train_parser.add_argument('--val-labels', help='Path to validation labels file')
+    train_parser.add_argument('--hf-dataset', help='HuggingFace dataset ID (e.g. mrrtmob/km_en_image_line)')
+    train_parser.add_argument('--hf-subset', help='HuggingFace dataset subset/config name')
+    train_parser.add_argument('--hf-train-split', default='train', help='Train split name (default: train)')
+    train_parser.add_argument('--hf-val-split', help='Validation split name')
+    train_parser.add_argument('--hf-image-col', default='image', help='Image column name')
+    train_parser.add_argument('--hf-text-col', default='text', help='Text column name')
+    train_parser.add_argument('--hf-val-percent', type=float, default=0.1, help='Val % if no split found (default: 0.1)')
+    train_parser.add_argument('--output-dir', help='Directory to save model')
+    train_parser.add_argument('--epochs', type=int, help='Number of epochs')
+    train_parser.add_argument('--batch-size', type=int, help='Batch size')
+    train_parser.add_argument('--height', type=int, help='Image height')
+    train_parser.add_argument('--hidden-size', type=int, help='Hidden size')
+    train_parser.add_argument('--device', help='Device (cuda/cpu)')
+    train_parser.add_argument('--from-model', help='Path to pretrained model for fine-tuning')
+    train_parser.add_argument('--lr', type=float, help='Learning rate')
+    train_parser.add_argument('--weight-decay', type=float, help='Weight decay')
+
+    # === GENERATE ===
+    gen_parser = subparsers.add_parser('generate', help='Generate synthetic data')
+    gen_parser.add_argument('--train-file', '-t', required=True,
+                            help='Training text file (one line per sample)')
+    gen_parser.add_argument('--val-file', '-v', default=None,
+                            help='Validation text file (optional)')
+    gen_parser.add_argument('--output', '-o', default='data',
+                            help='Output directory')
+    gen_parser.add_argument('--language', '-l', choices=['english', 'khmer', 'mixed'],
+                            default='mixed', help='Language mode')
+    gen_parser.add_argument('--augment', '-a', type=int, default=1,
+                            help='Augmentation factor for training')
+    gen_parser.add_argument('--val-augment', type=int, default=1,
+                            help='Augmentation factor for validation')
+    gen_parser.add_argument('--height', type=int, default=32,
+                            help='Image height')
+    gen_parser.add_argument('--width', type=int, default=512,
+                            help='Image width')
+    gen_parser.add_argument('--fonts-dir', default='fonts',
+                            help='Directory containing .ttf font files')
+    gen_parser.add_argument('--font-mode', choices=['random', 'all'], default='random',
+                            help='Font selection mode: random (default) or all (iterate all fonts per line)')
+    gen_parser.add_argument('--random-augment', action='store_true',
+                            help='Apply random augmentations (noise, rotation) even if augmentation factor is 1')
+
+    # === INIT CONFIG ===
+    init_parser = subparsers.add_parser('init-config', help='Create default config file')
+    init_parser.add_argument('--output', '-o', default='config.yaml', help='Output file')
+
+    # Backward compatibility: treat a bare image path as the 'predict' command
+    if len(sys.argv) > 1 and sys.argv[1] not in ['predict', 'train', 'generate', 'init-config', '-h', '--help']:
+        sys.argv.insert(1, 'predict')
+
+    args = parser.parse_args()
+
+    if args.command == 'predict':
+        run_inference(args)
+    elif args.command == 'train':
+        # Merge config for training
+        args = merge_config(args, DEFAULT_TRAIN_CONFIG)
+        # Check required fields
+        if not args.train_labels and not args.hf_dataset:
+            print("✗ Error: --train-labels or --hf-dataset is required")
+            sys.exit(1)
+
+        train_command(args)
+    elif args.command == 'generate':
+        generate_command(args)
+    elif args.command == 'init-config':
+        init_config(args)
+    else:
+        parser.print_help()
+
+if __name__ == '__main__':
+    main()
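Note: cli.py wires four subcommands (predict, train, generate, init-config) behind the single kiri-ocr entry point, rewriting a bare image path into a predict invocation for backward compatibility. For train, merge_config applies the precedence defaults < config file < explicit CLI flags. A minimal sketch of what that merge produces, assuming the package and its dependencies are installed; the Namespace fields below are illustrative:

import argparse
from kiri_ocr.cli import merge_config, DEFAULT_TRAIN_CONFIG

# Simulates `kiri-ocr train --epochs 10` with no --config file.
args = argparse.Namespace(config=None, epochs=10, lr=None, batch_size=None)
args = merge_config(args, DEFAULT_TRAIN_CONFIG)

print(args.epochs)      # 10     (explicit flag wins)
print(args.lr)          # 0.001  (falls back to DEFAULT_TRAIN_CONFIG)
print(args.output_dir)  # 'models'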
kiri_ocr/core.py
ADDED
@@ -0,0 +1,306 @@
+import torch
+import cv2
+import numpy as np
+import sys
+from PIL import Image
+from pathlib import Path
+import json
+
+from .model import LightweightOCR, CharacterSet
+from .detector import TextDetector
+
+class OCR:
+    """Complete Document OCR System with Padding"""
+
+    def __init__(self, model_path='models/model.kiri',
+                 charset_path='models/charset_lite.txt',
+                 language='mixed',
+                 padding=10,
+                 device='cpu',
+                 verbose=False):
+        """
+        Args:
+            model_path: Path to trained model (.kiri or .pth)
+            charset_path: Path to character set (used if model doesn't contain charset)
+            language: 'english', 'khmer', or 'mixed'
+            padding: Pixels to pad around detected boxes (default: 10)
+            device: 'cpu' or 'cuda'
+            verbose: Whether to print loading/processing info
+        """
+        self.device = device
+        self.verbose = verbose
+        self.language = language
+        self.padding = padding
+
+        # Resolve model path
+        if not Path(model_path).exists():
+            # Try looking in package directory
+            pkg_dir = Path(__file__).parent
+            if (pkg_dir / model_path).exists():
+                model_path = str(pkg_dir / model_path)
+            # Try looking in sibling 'models' package (if installed via setup.py with models package)
+            elif (pkg_dir.parent / 'models' / Path(model_path).name).exists():
+                model_path = str(pkg_dir.parent / 'models' / Path(model_path).name)
+            # Fallback to .pth if .kiri not found
+            elif model_path.endswith('.kiri') and (Path(model_path).parent / 'model.pth').exists():
+                model_path = str(Path(model_path).parent / 'model.pth')
+
+        if self.verbose:
+            print(f"Loading OCR model from {model_path}...")
+
+        try:
+            checkpoint = torch.load(model_path, map_location=device, weights_only=False)
+
+            # Load charset
+            if 'charset' in checkpoint:
+                if self.verbose:
+                    print(f"  ✓ Found embedded charset ({len(checkpoint['charset'])} chars)")
+                self.charset = CharacterSet.from_checkpoint(checkpoint)
+            else:
+                # Fallback to charset file
+                if not Path(charset_path).exists():
+                    # Try looking in package directory if not found locally
+                    pkg_dir = Path(__file__).parent
+                    if (pkg_dir / charset_path).exists():
+                        charset_path = str(pkg_dir / charset_path)
+
+                if self.verbose:
+                    print(f"  ℹ️ Loading charset from file: {charset_path}")
+                self.charset = CharacterSet.load(charset_path)
+
+            # Initialize model
+            self.model = LightweightOCR(num_chars=len(self.charset)).eval()
+
+            # Load weights
+            if 'model_state_dict' in checkpoint:
+                self.model.load_state_dict(checkpoint['model_state_dict'])
+            else:
+                # Assume checkpoint IS state_dict (legacy/raw save)
+                self.model.load_state_dict(checkpoint)
+
+        except RuntimeError as e:
+            if "size mismatch" in str(e):
+                print(f"\n✗ Error loading model: {e}")
+                print("\n⚠️ CRITICAL: The model weights do not match the character set.")
+                print(f"   - Charset size: {len(self.charset)}")
+                print(f"   - Model file: {model_path}")
+                sys.exit(1)
+            else:
+                raise e
+        except Exception as e:
+            print(f"✗ Error loading model: {e}")
+            raise e
+
+        self.model = self.model.to(device)
+
+        if self.verbose:
+            print(f"✓ Model loaded ({len(self.charset)} characters)")
+            print(f"✓ Box padding: {padding}px")
+
+        # Detector
+        self.detector = TextDetector()
+
+    def _preprocess_region(self, img, box, extra_padding=5):
+        """
+        Crop and preprocess a region with extra padding
+
+        Args:
+            img: Source image (numpy array)
+            box: Bounding box (x, y, w, h) - already has padding from detector
+            extra_padding: Additional padding when cropping (default: 5px)
+        """
+        img_h, img_w = img.shape[:2]
+        x, y, w, h = box
+
+        # Add extra padding (with boundary checks)
+        x_extra = max(0, x - extra_padding)
+        y_extra = max(0, y - extra_padding)
+        w_extra = min(img_w - x_extra, w + 2 * extra_padding)
+        h_extra = min(img_h - y_extra, h + 2 * extra_padding)
+
+        # Crop
+        roi = img[y_extra:y_extra+h_extra, x_extra:x_extra+w_extra]
+
+        if roi.size == 0:
+            return None
+
+        # Invert if dark background (model expects light BG / dark text)
+        if np.mean(roi) < 127:
+            roi = 255 - roi
+
+        # Convert to PIL
+        roi_pil = Image.fromarray(roi).convert('L')
+
+        # Resize maintaining aspect ratio
+        orig_w, orig_h = roi_pil.size
+        new_h = 32
+        new_w = int((orig_w / orig_h) * new_h)
+
+        # Ensure minimum width
+        if new_w < 32:
+            new_w = 32
+
+        roi_pil = roi_pil.resize((new_w, new_h), Image.LANCZOS)
+
+        # Normalize
+        roi_array = np.array(roi_pil) / 255.0
+        roi_tensor = torch.tensor(roi_array, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
+
+        return roi_tensor
+
+    def recognize_single_line_image(self, image_path):
+        """Recognize text from a single line image without detection"""
+        img = cv2.imread(str(image_path))
+        if img is None:
+            raise ValueError(f"Could not load image: {image_path}")
+
+        # Preprocess: resize height to 32 and keep aspect ratio
+        if len(img.shape) == 3:
+            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+
+        # Invert if dark background (model expects light BG / dark text)
+        if np.mean(img) < 127:
+            img = 255 - img
+
+        img_pil = Image.fromarray(img)
+        w, h = img_pil.size
+        new_h = 32
+        new_w = int((w / h) * new_h)
+        if new_w < 32:
+            new_w = 32
+
+        img_pil = img_pil.resize((new_w, new_h), Image.LANCZOS)
+        img_array = np.array(img_pil) / 255.0
+        img_tensor = torch.tensor(img_array, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
+
+        text, confidence = self.recognize_region(img_tensor)
+        return text, confidence
+
+    def recognize_region(self, image_tensor):
+        """Recognize text in a single region"""
+        image_tensor = image_tensor.to(self.device)
+
+        with torch.no_grad():
+            logits = self.model(image_tensor)
+
+        # CTC decoding (LightweightOCR)
+        probs = torch.softmax(logits, dim=-1)
+        preds = torch.argmax(probs, dim=-1)
+        preds = preds.squeeze().tolist()
+
+        # Handle single timestep
+        if not isinstance(preds, list):
+            preds = [preds]
+
+        # CTC decode: collapse repeats, skip special tokens
+        char_confidences = []
+        decoded_indices = []
+        previous_idx = -1
+
+        for i, idx in enumerate(preds):
+            if idx != previous_idx:
+                if idx > 2:  # Skip BLANK, PAD, SOS
+                    decoded_indices.append(idx)
+                    char_confidences.append(probs[i, 0, idx].item())
+            previous_idx = idx
+
+        confidence = np.mean(char_confidences) if char_confidences else 0.0
+        text = self.charset.decode(decoded_indices)
+
+        return text, confidence
+
+    def process_document(self, image_path, mode='lines', verbose=False):
+        """
+        Process an entire document
+
+        Args:
+            image_path: Path to document image
+            mode: 'lines' or 'words' for detection granularity
+            verbose: Whether to print progress
+
+        Returns:
+            List of dicts with 'box', 'text', 'confidence'
+        """
+        if verbose:
+            print(f"\nProcessing document: {image_path}")
+            print(f"Box padding: {self.padding}px")
+
+        # Detect text regions
+        if mode == 'lines':
+            boxes = self.detector.detect_lines(image_path)
+        else:
+            boxes = self.detector.detect_words(image_path)
+
+        if verbose:
+            print(f"Detected {len(boxes)} regions")
+
+        # Load image
+        img = cv2.imread(str(image_path))
+
+        # Recognize each region
+        results = []
+        for i, box in enumerate(boxes, 1):
+            try:
+                # Preprocess with extra padding
+                region_tensor = self._preprocess_region(img, box, extra_padding=5)
+                if region_tensor is None:
+                    continue
+
+                # Recognize
+                text, confidence = self.recognize_region(region_tensor)
+
+                # Convert types for JSON serialization
+                safe_box = [int(v) for v in box]
+                safe_confidence = float(confidence)
+
+                results.append({
+                    'box': safe_box,
+                    'text': text,
+                    'confidence': safe_confidence,
+                    'line_number': i
+                })
+
+                if verbose:
+                    print(f"  {i:2d}. {text:50s} ({confidence*100:.1f}%)")
+
+            except Exception as e:
+                if verbose:
+                    print(f"  {i:2d}. [Error: {e}]")
+                continue
+
+        return results
+
+    def extract_text(self, image_path, mode='lines', verbose=False):
+        """Extract all text from the document as a string"""
+        results = self.process_document(image_path, mode, verbose=verbose)
+
+        if not results:
+            return "", results
+
+        # Reconstruct text layout: sort by Y then X
+        results.sort(key=lambda r: (r['box'][1], r['box'][0]))
+
+        full_text = ""
+        for i, res in enumerate(results):
+            text = res['text']
+            box = res['box']
+            _, y, _, h = box
+
+            if i > 0:
+                prev_box = results[i-1]['box']
+                prev_y, prev_h = prev_box[1], prev_box[3]
+
+                # Check if on the same line (vertical centers are close)
+                center_y = y + h/2
+                prev_center_y = prev_y + prev_h/2
+
+                # If vertical centers are within half the larger height, assume same line
+                if abs(center_y - prev_center_y) < max(h, prev_h) / 2:
+                    full_text += " " + text
+                else:
+                    full_text += "\n" + text
+            else:
+                full_text += text
+
+        return full_text, results
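Note: core.py is the inference entry point: OCR loads the checkpoint (preferring an embedded charset over the charset file), uses TextDetector to find line or word boxes, crops and normalizes each region to a height of 32 px, and CTC-decodes the model output. A minimal usage sketch, assuming a trained model exists at the default path; the image file names are illustrative:

from kiri_ocr import OCR

ocr = OCR(model_path='models/model.kiri', device='cpu', verbose=True)

# Full pipeline: detect regions, recognize each, reassemble reading order.
full_text, results = ocr.extract_text('document.png', mode='lines')
print(full_text)
print(results[0])  # {'box': [x, y, w, h], 'text': ..., 'confidence': ..., 'line_number': 1}

# Pre-cropped single-line image, skipping detection.
text, confidence = ocr.recognize_single_line_image('line.png')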