kiri_ocr-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
kiri_ocr/__init__.py ADDED
@@ -0,0 +1,14 @@
+ from .core import OCR
+ from .renderer import DocumentRenderer
+ from .model import LightweightOCR, CharacterSet
+ from .detector import TextDetector
+
+ __version__ = '0.1.0'
+
+ __all__ = [
+     'OCR',
+     'DocumentRenderer',
+     'LightweightOCR',
+     'CharacterSet',
+     'TextDetector',
+ ]
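
The package exposes its main entry points at the top level. A minimal usage sketch based on these exports; the model path and image filename below are placeholders, not files shipped with the wheel:

    from kiri_ocr import OCR

    # Load a trained model and run full-page OCR (detection + recognition)
    ocr = OCR(model_path='models/model.kiri', device='cpu', verbose=True)
    full_text, results = ocr.extract_text('page.png', mode='lines')
    print(full_text)
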
kiri_ocr/cli.py ADDED
@@ -0,0 +1,244 @@
+ import numpy as np
+ import json
+ import argparse
+ import sys
+ from pathlib import Path
+ import yaml
+
+ from .core import OCR
+ from .renderer import DocumentRenderer
+ from .training import train_command
+ from .generator import generate_command
+
+ DEFAULT_TRAIN_CONFIG = {
+     "height": 32,
+     "batch_size": 32,
+     "epochs": 2,
+     "hidden_size": 256,
+     "device": "cuda",
+     "output_dir": "models",
+     "train_labels": "data/train/labels.txt",
+     "val_labels": "data/val/labels.txt",
+     "lr": 0.001,
+     "weight_decay": 0.0001
+ }
+
+ def init_config(args):
+     path = args.output
+     # Ensure a .yaml extension if the user didn't specify one
+     if not path.endswith('.yaml') and not path.endswith('.yml'):
+         path += '.yaml'
+
+     with open(path, 'w') as f:
+         yaml.dump(DEFAULT_TRAIN_CONFIG, f, default_flow_style=False)
+     print(f"āœ“ Created default config at {path}")
+
+ def run_inference(args):
+     # Create output directory
+     output_dir = Path(args.output)
+     output_dir.mkdir(exist_ok=True)
+
+     print("\n" + "="*70)
+     print(" šŸ“„ Kiri OCR System")
+     print("="*70)
+
+     try:
+         # Initialize OCR
+         ocr = OCR(
+             model_path=args.model,
+             charset_path=args.charset,
+             language=args.language,
+             padding=args.padding,
+             device=args.device,
+             verbose=True
+         )
+
+         # Process document & extract text
+         full_text, results = ocr.extract_text(args.image, mode=args.mode, verbose=True)
+
+         # Save text
+         text_output = output_dir / 'extracted_text.txt'
+         with open(text_output, 'w', encoding='utf-8') as f:
+             f.write(full_text)
+         print(f"\nāœ“ Text saved to {text_output}")
+
+         # Save JSON
+         json_output = output_dir / 'ocr_results.json'
+         with open(json_output, 'w', encoding='utf-8') as f:
+             json.dump(results, f, indent=2, ensure_ascii=False)
+         print(f"āœ“ JSON saved to {json_output}")
+
+         # Render results
+         if not args.no_render:
+             renderer = DocumentRenderer()
+
+             # Boxes only
+             renderer.draw_boxes(
+                 args.image,
+                 results,
+                 output_path=str(output_dir / 'boxes.png')
+             )
+
+             # Boxes with text
+             renderer.draw_results(
+                 args.image,
+                 results,
+                 output_path=str(output_dir / 'ocr_result.png')
+             )
+
+             # HTML report
+             renderer.create_report(
+                 args.image,
+                 results,
+                 output_path=str(output_dir / 'report.html')
+             )
+
+         print("\n" + "="*70)
+         print(" āœ… Processing Complete!")
+         print("="*70)
+         print(f" Regions detected: {len(results)}")
+         if results:
+             print(f" Average confidence: {np.mean([r['confidence'] for r in results])*100:.2f}%")
+         print(f" Output directory: {output_dir}")
+         print("="*70 + "\n")
+
+     except Exception as e:
+         print(f"\nāŒ Error: {e}")
+         # Suggest model path if it seems to be missing
+         if "No such file" in str(e) and "model" in str(e):
+             print("\nTip: Make sure you have trained a model or specified the correct path with --model")
+             print(" Run: kiri-ocr train ... to train a model first.")
+
+ def merge_config(args, defaults):
+     """Merge defaults < config file < args"""
+     # Start with defaults
+     config = defaults.copy()
+
+     # Update with config file if provided
+     if hasattr(args, 'config') and args.config:
+         try:
+             with open(args.config, 'r') as f:
+                 if args.config.endswith('.json'):
+                     file_config = json.load(f)
+                 else:
+                     file_config = yaml.safe_load(f)
+
+             if file_config:
+                 config.update(file_config)
+                 print(f"Loaded config from {args.config}")
+         except Exception as e:
+             print(f"Error loading config file: {e}")
+             sys.exit(1)
+
+     # Update with explicit args (non-None)
+     for key, value in vars(args).items():
+         if value is not None and key in config:
+             config[key] = value
+
+     # Update args object
+     for key, value in config.items():
+         setattr(args, key, value)
+
+     return args
+
+ def main():
+     parser = argparse.ArgumentParser(description='Kiri OCR - Command Line Tool')
+     subparsers = parser.add_subparsers(dest='command', help='Command to run')
+
+     # === PREDICT / RUN ===
+     predict_parser = subparsers.add_parser('predict', help='Run OCR on an image')
+     predict_parser.add_argument('image', help='Path to document image')
+     predict_parser.add_argument('--mode', choices=['lines', 'words'], default='lines',
+                                 help='Detection mode (default: lines)')
+     predict_parser.add_argument('--model', default='models/model.kiri',
+                                 help='Path to model file')
+     predict_parser.add_argument('--charset', default='models/charset_lite.txt',
+                                 help='Path to charset file')
+     predict_parser.add_argument('--language', choices=['english', 'khmer', 'mixed'],
+                                 default='mixed', help='Language mode')
+     predict_parser.add_argument('--padding', type=int, default=10,
+                                 help='Padding around detected boxes in pixels (default: 10)')
+     predict_parser.add_argument('--output', '-o', default='output',
+                                 help='Output directory')
+     predict_parser.add_argument('--no-render', action='store_true',
+                                 help='Skip rendering (text only)')
+     predict_parser.add_argument('--device', choices=['cpu', 'cuda'], default='cpu',
+                                 help='Device to use')
+
+     # === TRAIN ===
+     train_parser = subparsers.add_parser('train', help='Train the OCR model')
+     train_parser.add_argument('--config', help='Path to config file (YAML/JSON)')
+     train_parser.add_argument('--train-labels', help='Path to training labels file')
+     train_parser.add_argument('--val-labels', help='Path to validation labels file')
+     train_parser.add_argument('--hf-dataset', help='HuggingFace dataset ID (e.g. mrrtmob/km_en_image_line)')
+     train_parser.add_argument('--hf-subset', help='HuggingFace dataset subset/config name')
+     train_parser.add_argument('--hf-train-split', default='train', help='Train split name (default: train)')
+     train_parser.add_argument('--hf-val-split', help='Validation split name')
+     train_parser.add_argument('--hf-image-col', default='image', help='Image column name')
+     train_parser.add_argument('--hf-text-col', default='text', help='Text column name')
+     train_parser.add_argument('--hf-val-percent', type=float, default=0.1, help='Validation fraction if no split found (default: 0.1)')
+     train_parser.add_argument('--output-dir', help='Directory to save model')
+     train_parser.add_argument('--epochs', type=int, help='Number of epochs')
+     train_parser.add_argument('--batch-size', type=int, help='Batch size')
+     train_parser.add_argument('--height', type=int, help='Image height')
+     train_parser.add_argument('--hidden-size', type=int, help='Hidden size')
+     train_parser.add_argument('--device', help='Device (cuda/cpu)')
+     train_parser.add_argument('--from-model', help='Path to pretrained model for fine-tuning')
+     train_parser.add_argument('--lr', type=float, help='Learning rate')
+     train_parser.add_argument('--weight-decay', type=float, help='Weight decay')
+
+     # === GENERATE ===
+     gen_parser = subparsers.add_parser('generate', help='Generate synthetic data')
+     gen_parser.add_argument('--train-file', '-t', required=True,
+                             help='Training text file (one line per sample)')
+     gen_parser.add_argument('--val-file', '-v', default=None,
+                             help='Validation text file (optional)')
+     gen_parser.add_argument('--output', '-o', default='data',
+                             help='Output directory')
+     gen_parser.add_argument('--language', '-l', choices=['english', 'khmer', 'mixed'],
+                             default='mixed', help='Language mode')
+     gen_parser.add_argument('--augment', '-a', type=int, default=1,
+                             help='Augmentation factor for training')
+     gen_parser.add_argument('--val-augment', type=int, default=1,
+                             help='Augmentation factor for validation')
+     gen_parser.add_argument('--height', type=int, default=32,
+                             help='Image height')
+     gen_parser.add_argument('--width', type=int, default=512,
+                             help='Image width')
+     gen_parser.add_argument('--fonts-dir', default='fonts',
+                             help='Directory containing .ttf font files')
+     gen_parser.add_argument('--font-mode', choices=['random', 'all'], default='random',
+                             help='Font selection mode: random (default) or all (iterate all fonts per line)')
+     gen_parser.add_argument('--random-augment', action='store_true',
+                             help='Apply random augmentations (noise, rotation) even if augmentation factor is 1')
+
+     # === INIT CONFIG ===
+     init_parser = subparsers.add_parser('init-config', help='Create default config file')
+     init_parser.add_argument('--output', '-o', default='config.yaml', help='Output file')
+
+     # Backward compatibility: a bare image path is treated as an implicit 'predict' command
+     if len(sys.argv) > 1 and sys.argv[1] not in ['predict', 'train', 'generate', 'init-config', '-h', '--help']:
+         sys.argv.insert(1, 'predict')
+
+     args = parser.parse_args()
+
+     if args.command == 'predict':
+         run_inference(args)
+     elif args.command == 'train':
+         # Merge config for training
+         args = merge_config(args, DEFAULT_TRAIN_CONFIG)
+         # Check required fields
+         if not args.train_labels and not args.hf_dataset:
+             print("āŒ Error: --train-labels or --hf-dataset is required")
+             sys.exit(1)
+
+         train_command(args)
+     elif args.command == 'generate':
+         generate_command(args)
+     elif args.command == 'init-config':
+         init_config(args)
+     else:
+         parser.print_help()
+
+ if __name__ == '__main__':
+     main()
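
The subcommands above can also be driven programmatically, since main() reads sys.argv (it is presumably wired to the kiri-ocr console script, though the entry-point declaration is not part of this diff). A small sketch; the image path is a placeholder, and the shell equivalent would be kiri-ocr predict page.png --device cpu -o output:

    import sys
    from kiri_ocr.cli import main

    # Same effect as: kiri-ocr predict page.png --device cpu -o output
    sys.argv = ['kiri-ocr', 'predict', 'page.png', '--device', 'cpu', '-o', 'output']
    main()
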
kiri_ocr/core.py ADDED
@@ -0,0 +1,306 @@
+ import torch
+ import cv2
+ import numpy as np
+ import sys
+ from PIL import Image
+ from pathlib import Path
+ import json
+
+ from .model import LightweightOCR, CharacterSet
+ from .detector import TextDetector
+
+ class OCR:
+     """Complete Document OCR System with Padding"""
+
+     def __init__(self, model_path='models/model.kiri',
+                  charset_path='models/charset_lite.txt',
+                  language='mixed',
+                  padding=10,
+                  device='cpu',
+                  verbose=False):
+         """
+         Args:
+             model_path: Path to trained model (.kiri or .pth)
+             charset_path: Path to character set (used if model doesn't contain charset)
+             language: 'english', 'khmer', or 'mixed'
+             padding: Pixels to pad around detected boxes (default: 10)
+             device: 'cpu' or 'cuda'
+             verbose: Whether to print loading/processing info
+         """
+         self.device = device
+         self.verbose = verbose
+         self.language = language
+         self.padding = padding
+
+         # Resolve model path
+         if not Path(model_path).exists():
+             # Try looking in package directory
+             pkg_dir = Path(__file__).parent
+             if (pkg_dir / model_path).exists():
+                 model_path = str(pkg_dir / model_path)
+             # Try looking in sibling 'models' package (if installed via setup.py with models package)
+             elif (pkg_dir.parent / 'models' / Path(model_path).name).exists():
+                 model_path = str(pkg_dir.parent / 'models' / Path(model_path).name)
+             # Fallback to .pth if .kiri not found
+             elif model_path.endswith('.kiri') and (Path(model_path).parent / 'model.pth').exists():
+                 model_path = str(Path(model_path).parent / 'model.pth')
+
+         if self.verbose:
+             print(f"šŸ“¦ Loading OCR model from {model_path}...")
+
+         try:
+             checkpoint = torch.load(model_path, map_location=device, weights_only=False)
+
+             # Load charset
+             if 'charset' in checkpoint:
+                 if self.verbose:
+                     print(f" āœ“ Found embedded charset ({len(checkpoint['charset'])} chars)")
+                 self.charset = CharacterSet.from_checkpoint(checkpoint)
+             else:
+                 # Fallback to charset file
+                 if not Path(charset_path).exists():
+                     # Try looking in package directory if not found locally
+                     pkg_dir = Path(__file__).parent
+                     if (pkg_dir / charset_path).exists():
+                         charset_path = str(pkg_dir / charset_path)
+
+                 if self.verbose:
+                     print(f" ā„¹ļø Loading charset from file: {charset_path}")
+                 self.charset = CharacterSet.load(charset_path)
+
+             # Initialize model
+             self.model = LightweightOCR(num_chars=len(self.charset)).eval()
+
+             # Load weights
+             if 'model_state_dict' in checkpoint:
+                 self.model.load_state_dict(checkpoint['model_state_dict'])
+             else:
+                 # Assume checkpoint IS state_dict (legacy/raw save)
+                 self.model.load_state_dict(checkpoint)
+
+         except RuntimeError as e:
+             if "size mismatch" in str(e):
+                 print(f"\nāŒ Error loading model: {e}")
+                 print("\nāš ļø CRITICAL: The model weights do not match the character set.")
+                 print(f" - Charset size: {len(self.charset)}")
+                 print(f" - Model file: {model_path}")
+                 sys.exit(1)
+             else:
+                 raise e
+         except Exception as e:
+             print(f"āŒ Error loading model: {e}")
+             raise e
+
+         self.model = self.model.to(device)
+
+         if self.verbose:
+             print(f"āœ“ Model loaded ({len(self.charset)} characters)")
+             print(f"āœ“ Box padding: {padding}px")
+
+         # Detector
+         self.detector = TextDetector()
+
+     def _preprocess_region(self, img, box, extra_padding=5):
+         """
+         Crop and preprocess a region with extra padding
+
+         Args:
+             img: Source image (numpy array)
+             box: Bounding box (x, y, w, h) - already has padding from detector
+             extra_padding: Additional padding when cropping (default: 5px)
+         """
+         img_h, img_w = img.shape[:2]
+         x, y, w, h = box
+
+         # Add extra padding (with boundary checks)
+         x_extra = max(0, x - extra_padding)
+         y_extra = max(0, y - extra_padding)
+         w_extra = min(img_w - x_extra, w + 2 * extra_padding)
+         h_extra = min(img_h - y_extra, h + 2 * extra_padding)
+
+         # Crop
+         roi = img[y_extra:y_extra+h_extra, x_extra:x_extra+w_extra]
+
+         if roi.size == 0:
+             return None
+
+         # Invert if dark background (Model expects Light BG / Dark Text)
+         if np.mean(roi) < 127:
+             roi = 255 - roi
+
+         # Convert to PIL
+         roi_pil = Image.fromarray(roi).convert('L')
+
+         # Resize maintaining aspect ratio
+         orig_w, orig_h = roi_pil.size
+         new_h = 32
+         new_w = int((orig_w / orig_h) * new_h)
+
+         # Ensure minimum width
+         if new_w < 32:
+             new_w = 32
+
+         roi_pil = roi_pil.resize((new_w, new_h), Image.LANCZOS)
+
+         # Normalize
+         roi_array = np.array(roi_pil) / 255.0
+         roi_tensor = torch.tensor(roi_array, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
+
+         return roi_tensor
+
+     def recognize_single_line_image(self, image_path):
+         """Recognize text from a single line image without detection"""
+         img = cv2.imread(str(image_path))
+         if img is None:
+             raise ValueError(f"Could not load image: {image_path}")
+
+         # Preprocess
+         # We need to resize height to 32 and keep aspect ratio
+         if len(img.shape) == 3:
+             img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+
+         # Invert if dark background (Model expects Light BG / Dark Text)
+         if np.mean(img) < 127:
+             img = 255 - img
+
+         img_pil = Image.fromarray(img)
+         w, h = img_pil.size
+         new_h = 32
+         new_w = int((w / h) * new_h)
+         if new_w < 32: new_w = 32
+
+         img_pil = img_pil.resize((new_w, new_h), Image.LANCZOS)
+         img_array = np.array(img_pil) / 255.0
+         img_tensor = torch.tensor(img_array, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
+
+         text, confidence = self.recognize_region(img_tensor)
+         return text, confidence
+
+     def recognize_region(self, image_tensor):
+         """Recognize text in a single region"""
+         image_tensor = image_tensor.to(self.device)
+
+         with torch.no_grad():
+             logits = self.model(image_tensor)
+
+         # CTC decoding (LightweightOCR)
+         probs = torch.softmax(logits, dim=-1)
+         preds = torch.argmax(probs, dim=-1)
+         preds = preds.squeeze().tolist()
+
+         # Handle single timestep
+         if not isinstance(preds, list):
+             preds = [preds]
+
+         # CTC decode
+         char_confidences = []
+         decoded_indices = []
+         previous_idx = -1
+
+         for i, idx in enumerate(preds):
+             if idx != previous_idx:
+                 if idx > 2:  # Skip BLANK, PAD, SOS
+                     decoded_indices.append(idx)
+                     char_confidences.append(probs[i, 0, idx].item())
+             previous_idx = idx
+
+         confidence = np.mean(char_confidences) if char_confidences else 0.0
+         text = self.charset.decode(decoded_indices)
+
+         return text, confidence
+
+     def process_document(self, image_path, mode='lines', verbose=False):
+         """
+         Process entire document
+
+         Args:
+             image_path: Path to document image
+             mode: 'lines' or 'words' for detection granularity
+             verbose: Whether to print progress
+
+         Returns:
+             List of dicts with 'box', 'text', 'confidence'
+         """
+         if verbose:
+             print(f"\nšŸ“„ Processing document: {image_path}")
+             print(f"šŸ”² Box padding: {self.padding}px")
+
+         # Detect text regions
+         if mode == 'lines':
+             boxes = self.detector.detect_lines(image_path)
+         else:
+             boxes = self.detector.detect_words(image_path)
+
+         if verbose:
+             print(f"šŸ” Detected {len(boxes)} regions")
+
+         # Load image
+         img = cv2.imread(str(image_path))
+
+         # Recognize each region
+         results = []
+         for i, box in enumerate(boxes, 1):
+             try:
+                 # Preprocess with extra padding
+                 region_tensor = self._preprocess_region(img, box, extra_padding=5)
+                 if region_tensor is None:
+                     continue
+
+                 # Recognize
+                 text, confidence = self.recognize_region(region_tensor)
+
+                 # Convert types for JSON serialization
+                 safe_box = [int(v) for v in box]
+                 safe_confidence = float(confidence)
+
+                 results.append({
+                     'box': safe_box,
+                     'text': text,
+                     'confidence': safe_confidence,
+                     'line_number': i
+                 })
+
+                 if verbose:
+                     print(f" {i:2d}. {text:50s} ({confidence*100:.1f}%)")
+
+             except Exception as e:
+                 if verbose:
+                     print(f" {i:2d}. [Error: {e}]")
+                 continue
+
+         return results
+
273
+ def extract_text(self, image_path, mode='lines', verbose=False):
274
+ """Extract all text from document as string"""
275
+ results = self.process_document(image_path, mode, verbose=verbose)
276
+
277
+ if not results:
278
+ return "", results
279
+
280
+ # Reconstruct text layout
281
+ # Sort by Y then X
282
+ results.sort(key=lambda r: (r['box'][1], r['box'][0]))
283
+
284
+ full_text = ""
285
+ for i, res in enumerate(results):
286
+ text = res['text']
287
+ box = res['box']
288
+ _, y, _, h = box
289
+
290
+ if i > 0:
291
+ prev_box = results[i-1]['box']
292
+ prev_y, prev_h = prev_box[1], prev_box[3]
293
+
294
+ # Check if on the same line (vertical center is close)
295
+ center_y = y + h/2
296
+ prev_center_y = prev_y + prev_h/2
297
+
298
+ # If vertical centers are close (within half height), assume same line
299
+ if abs(center_y - prev_center_y) < max(h, prev_h) / 2:
300
+ full_text += " " + text
301
+ else:
302
+ full_text += "\n" + text
303
+ else:
304
+ full_text += text
305
+
306
+ return full_text, results
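
For images that are already cropped to a single text line, the detection stage can be skipped via recognize_single_line_image. A short sketch assuming a trained model is available; 'line.png' is a placeholder filename:

    from kiri_ocr import OCR

    ocr = OCR(model_path='models/model.kiri', device='cpu')
    # Bypasses TextDetector and feeds the whole image to the CTC recognizer
    text, confidence = ocr.recognize_single_line_image('line.png')
    print(f'{text} ({confidence * 100:.1f}%)')
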