kiri-ocr 0.1.0__py3-none-any.whl

kiri_ocr/model.py ADDED
@@ -0,0 +1,159 @@
+ import torch
+ import torch.nn as nn
+ from pathlib import Path
+
+ # ========== CHARACTER SET ==========
+ class CharacterSet:
+     """Manages character to index mapping"""
+     def __init__(self):
+         self.chars = ['BLANK', 'PAD', 'SOS', ' ']  # Special tokens + Space
+         self.char2idx = {'BLANK': 0, 'PAD': 1, 'SOS': 2, ' ': 3}
+         self.idx2char = {0: 'BLANK', 1: 'PAD', 2: 'SOS', 3: ' '}
+
+     def add_chars(self, text):
+         """Add new characters from text"""
+         for char in text:
+             if char not in self.char2idx:
+                 idx = len(self.chars)
+                 self.chars.append(char)
+                 self.char2idx[char] = idx
+                 self.idx2char[idx] = char
+
+     def encode(self, text):
+         """Convert text to indices (unknown characters map to BLANK)"""
+         return [self.char2idx.get(char, 0) for char in text]
+
+     def decode(self, indices):
+         """Convert indices to text (greedy CTC decode)"""
+         chars = []
+         prev_idx = None
+         for idx in indices:
+             if idx > 2 and idx != prev_idx:  # Skip BLANK, PAD, SOS; collapse repeats
+                 chars.append(self.idx2char.get(idx, ''))
+             prev_idx = idx
+         return ''.join(chars)
+
+     def __len__(self):
+         return len(self.chars)
+
+     def save(self, path):
+         with open(path, 'w', encoding='utf-8') as f:
+             for char in self.chars:
+                 f.write(char + '\n')
+
+     @classmethod
+     def load(cls, path):
+         charset = cls()
+         with open(path, 'r', encoding='utf-8') as f:
+             chars = [line.rstrip('\n') for line in f]
+         charset.chars = chars
+         charset.char2idx = {char: idx for idx, char in enumerate(chars)}
+         charset.idx2char = {idx: char for idx, char in enumerate(chars)}
+         return charset
+
+     @classmethod
+     def from_checkpoint(cls, checkpoint):
+         """Load charset from model checkpoint dictionary"""
+         charset = cls()
+         if 'charset' in checkpoint:
+             charset.chars = checkpoint['charset']
+             charset.char2idx = {char: idx for idx, char in enumerate(charset.chars)}
+             charset.idx2char = {idx: char for idx, char in enumerate(charset.chars)}
+         return charset
+
+ def save_checkpoint(model, charset, optimizer, epoch, val_loss, accuracy, path):
+     """Save model checkpoint with charset included"""
+     torch.save({
+         'epoch': epoch,
+         'model_state_dict': model.state_dict(),
+         'optimizer_state_dict': optimizer.state_dict(),
+         'val_loss': val_loss,
+         'accuracy': accuracy,
+         'charset': charset.chars,
+     }, path)
+
+ # ========== LIGHTWEIGHT OCR MODEL ==========
+ class LightweightOCR(nn.Module):
+     """
+     Lightweight CRNN-style OCR model (~13 MB):
+     CNN backbone + two-layer bidirectional LSTM + CTC output head.
+     """
+     def __init__(self, num_chars, hidden_size=256):
+         super().__init__()
+
+         # CNN backbone
+         self.cnn = nn.Sequential(
+             # Block 1: [B, 1, 32, W] -> [B, 32, 32, W]
+             self._conv_block(1, 32, kernel_size=3, stride=1, padding=1),
+             nn.MaxPool2d(kernel_size=2, stride=2),  # -> [B, 32, 16, W/2]
+
+             # Block 2: [B, 32, 16, W/2] -> [B, 64, 16, W/2]
+             self._conv_block(32, 64, kernel_size=3, stride=1, padding=1),
+             nn.MaxPool2d(kernel_size=2, stride=2),  # -> [B, 64, 8, W/4]
+
+             # Block 3: [B, 64, 8, W/4] -> [B, 128, 8, W/4]
+             self._conv_block(64, 128, kernel_size=3, stride=1, padding=1),
+             self._conv_block(128, 128, kernel_size=3, stride=1, padding=1),
+             nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1)),  # -> [B, 128, 4, W/4]
+
+             # Block 4: [B, 128, 4, W/4] -> [B, 256, 4, W/4]
+             self._conv_block(128, 256, kernel_size=3, stride=1, padding=1),
+             self._conv_block(256, 256, kernel_size=3, stride=1, padding=1),
+             nn.MaxPool2d(kernel_size=(4, 1), stride=(4, 1)),  # -> [B, 256, 1, W/4]
+         )
+
+         # Two-layer LSTM with intermediate projection
+         self.lstm1 = nn.LSTM(
+             256, hidden_size,
+             bidirectional=True,
+             batch_first=True,
+         )
+
+         self.intermediate_linear = nn.Linear(hidden_size * 2, hidden_size)
+
+         self.lstm2 = nn.LSTM(
+             hidden_size, hidden_size,
+             bidirectional=True,
+             batch_first=True,
+         )
+
+         # Output layer
+         self.fc = nn.Linear(hidden_size * 2, num_chars)
+
+     def _conv_block(self, in_ch, out_ch, **kwargs):
+         """Convolutional block with BatchNorm and ReLU"""
+         return nn.Sequential(
+             nn.Conv2d(in_ch, out_ch, bias=False, **kwargs),
+             nn.BatchNorm2d(out_ch),
+             nn.ReLU(inplace=True),
+         )
+
+     def forward(self, x):
+         """
+         Args:
+             x: [B, 1, H, W]
+         Returns:
+             output: [T, B, num_chars] (time-major, as expected by nn.CTCLoss)
+         """
+         # CNN features
+         features = self.cnn(x)  # [B, 256, 1, W']
+
+         # Remove height dimension
+         features = features.squeeze(2)  # [B, 256, W']
+         features = features.permute(0, 2, 1)  # [B, W', 256]
+
+         # First LSTM layer
+         rnn_out, _ = self.lstm1(features)  # [B, W', hidden*2]
+
+         # Intermediate projection
+         rnn_out = self.intermediate_linear(rnn_out)  # [B, W', hidden]
+         rnn_out = torch.relu(rnn_out)
+
+         # Second LSTM layer
+         rnn_out, _ = self.lstm2(rnn_out)  # [B, W', hidden*2]
+
+         # Fully connected
+         logits = self.fc(rnn_out)  # [B, W', num_chars]
+
+         # Permute to time-major for CTC
+         return logits.permute(1, 0, 2)  # [W', B, num_chars]
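
To make the shape comments above concrete, here is a minimal usage sketch of the model together with `CharacterSet`'s greedy CTC decoding. The 32×128 input size and the sample text are illustrative assumptions, not values fixed by the package:

```python
import torch
from kiri_ocr.model import CharacterSet, LightweightOCR

charset = CharacterSet()
charset.add_chars("hello world")  # grow the vocabulary from training text

model = LightweightOCR(num_chars=len(charset))
model.eval()

# Two grayscale line images: height 32, width 128.
# The CNN halves the width twice, so T = W' = 128 / 4 = 32 time steps.
x = torch.randn(2, 1, 32, 128)
with torch.no_grad():
    logits = model(x)  # [T, B, num_chars] = [32, 2, len(charset)]

# Greedy CTC decode for the first image: argmax per time step, then
# CharacterSet.decode collapses repeats and drops the special tokens.
indices = logits[:, 0, :].argmax(dim=-1).tolist()
print(charset.decode(indices))
```

The same pairing drives the checkpoint helpers: `save_checkpoint` embeds `charset.chars` in the saved dictionary, and `CharacterSet.from_checkpoint(torch.load(path))` rebuilds the mapping at inference time.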
kiri_ocr/renderer.py ADDED
@@ -0,0 +1,193 @@
+ import cv2
+ import numpy as np
+ from PIL import Image, ImageDraw, ImageFont
+ from pathlib import Path
+
+ class DocumentRenderer:
+     """Render OCR results on an image"""
+
+     def __init__(self, font_path=None, font_size=12):
+         self.font_size = font_size
+
+         # Try to load the requested font for text rendering
+         self.font = None
+         if font_path and Path(font_path).exists():
+             try:
+                 self.font = ImageFont.truetype(font_path, font_size)
+             except Exception:
+                 pass
+
+         # Fallback fonts
+         if self.font is None:
+             # Priority list including Khmer fonts
+             candidate_fonts = [
+                 'fonts/KhmerOSbattambang.ttf',
+                 'fonts/Battambang-Regular.ttf',
+                 'fonts/NotoSansKhmer-Regular.ttf',
+                 'Arial.ttf',
+                 'DejaVuSans.ttf',
+                 'NotoSans-Regular.ttf',
+             ]
+
+             # Prefer any TTF found in the fonts/ directory
+             if Path('fonts').exists():
+                 candidate_fonts = [str(f) for f in Path('fonts').glob('*.ttf')] + candidate_fonts
+
+             for font_name in candidate_fonts:
+                 try:
+                     self.font = ImageFont.truetype(font_name, font_size)
+                     break
+                 except Exception:
+                     continue
+
+     def draw_boxes(self, image_path, results, output_path='output_boxes.png'):
+         """Draw bounding boxes only"""
+         img = cv2.imread(str(image_path))
+
+         for result in results:
+             x, y, w, h = result['box']
+             conf = result['confidence']
+
+             # Color based on confidence (BGR order for OpenCV)
+             if conf > 0.9:
+                 color = (0, 255, 0)  # Green
+             elif conf > 0.7:
+                 color = (0, 165, 255)  # Orange
+             else:
+                 color = (0, 0, 255)  # Red
+
+             # Draw box
+             cv2.rectangle(img, (x, y), (x + w, y + h), color, 2)
+
+             # Draw line number
+             if 'line_number' in result:
+                 cv2.putText(img, str(result['line_number']), (x - 5, y - 5),
+                             cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
+
+         cv2.imwrite(output_path, img)
+         print(f"\n✓ Boxes saved to {output_path}")
+
+         return img
+
+     def draw_results(self, image_path, results, output_path='output_ocr.png',
+                      show_text=True, show_confidence=True):
+         """Draw boxes with recognized text"""
+         img = cv2.imread(str(image_path))
+
+         # Convert to PIL for better text rendering
+         img_pil = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
+         draw = ImageDraw.Draw(img_pil)
+
+         for result in results:
+             x, y, w, h = result['box']
+             text = result['text']
+             conf = result['confidence']
+
+             # Color based on confidence (RGB order for PIL)
+             if conf > 0.9:
+                 color = (0, 255, 0)
+             elif conf > 0.7:
+                 color = (255, 165, 0)
+             else:
+                 color = (255, 0, 0)
+
+             # Draw box
+             draw.rectangle([x, y, x + w, y + h], outline=color, width=2)
+
+             if show_text:
+                 # Prepare label
+                 label = text[:50]  # Limit length
+                 if show_confidence:
+                     label += f" ({conf*100:.0f}%)"
+
+                 # Background for text
+                 if self.font:
+                     try:
+                         # Calculate text position
+                         left, top, right, bottom = draw.textbbox((0, 0), label, font=self.font)
+                         text_height = bottom - top
+                         text_y = y - text_height - 5
+
+                         bbox = draw.textbbox((x, text_y), label, font=self.font)
+                         # Add padding
+                         bbox = (bbox[0] - 2, bbox[1] - 2, bbox[2] + 2, bbox[3] + 2)
+
+                         draw.rectangle(bbox, fill=color)
+                         draw.text((x, text_y), label, fill=(255, 255, 255), font=self.font)
+                     except Exception:
+                         # Fallback to OpenCV text rendering
+                         img_cv = cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR)
+                         cv2.putText(img_cv, label, (x, y - 10),
+                                     cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
+                         img_pil = Image.fromarray(cv2.cvtColor(img_cv, cv2.COLOR_BGR2RGB))
+                         draw = ImageDraw.Draw(img_pil)
+
+         # Save
+         img_pil.save(output_path)
+         print(f"✓ Results saved to {output_path}")
+
+         return img_pil
+
+     def create_report(self, image_path, results, output_path='ocr_report.html'):
+         """Create an HTML report"""
+         html = f"""
+ <!DOCTYPE html>
+ <html>
+ <head>
+     <meta charset="UTF-8">
+     <title>OCR Report</title>
+     <style>
+         body {{ font-family: Arial, sans-serif; margin: 20px; }}
+         .header {{ background: #4CAF50; color: white; padding: 20px; border-radius: 5px; }}
+         .result {{ padding: 10px; margin: 10px 0; border-left: 4px solid #4CAF50; }}
+         .conf-high {{ border-color: #4CAF50; background: #f1f8f4; }}
+         .conf-medium {{ border-color: #FF9800; background: #fff8f1; }}
+         .conf-low {{ border-color: #F44336; background: #fef1f1; }}
+         .text {{ font-size: 18px; font-weight: bold; margin-bottom: 5px; }}
+         .confidence {{ color: #666; font-size: 14px; }}
+         .stats {{ background: #f5f5f5; padding: 15px; border-radius: 5px; margin: 20px 0; }}
+         .full-text {{ background: #f5f5f5; padding: 20px; white-space: pre-wrap;
+                       font-family: monospace; border-radius: 5px; }}
+     </style>
+ </head>
+ <body>
+     <div class="header">
+         <h1>šŸ“„ OCR Report</h1>
+         <p>Document: {Path(image_path).name}</p>
+     </div>
+
+     <div class="stats">
+         <strong>Statistics:</strong><br>
+         Total Regions: {len(results)}<br>
+         Average Confidence: {np.mean([r['confidence'] for r in results])*100:.2f}%<br>
+         High Confidence (&gt;90%): {sum(1 for r in results if r['confidence'] > 0.9)}<br>
+         Medium Confidence (70-90%): {sum(1 for r in results if 0.7 < r['confidence'] <= 0.9)}<br>
+         Low Confidence (&le;70%): {sum(1 for r in results if r['confidence'] <= 0.7)}
+     </div>
+
+     <h2>šŸ“ Full Text</h2>
+     <div class="full-text">{"<br>".join([r['text'] for r in results])}</div>
+
+     <h2>šŸ“‹ Detailed Results</h2>
+ """
+
+         for i, result in enumerate(results, 1):
+             conf = result['confidence']
+             conf_class = 'conf-high' if conf > 0.9 else ('conf-medium' if conf > 0.7 else 'conf-low')
+
+             html += f"""
+     <div class="result {conf_class}">
+         <div class="text">{i}. {result['text']}</div>
+         <div class="confidence">Confidence: {conf*100:.2f}% | Box: {result['box']}</div>
+     </div>
+ """
+
+         html += """
+ </body>
+ </html>
+ """
+
+         with open(output_path, 'w', encoding='utf-8') as f:
+             f.write(html)
+
+         print(f"✓ Report saved to {output_path}")
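
A minimal usage sketch of the renderer follows. The result dictionaries mirror the keys the methods above read ('box' as (x, y, w, h), 'text', 'confidence', and an optional 'line_number'); the file names and values here are placeholders:

```python
from kiri_ocr.renderer import DocumentRenderer

# Each entry carries the keys the renderer reads: 'box' = (x, y, w, h),
# 'text', 'confidence' in [0, 1], and an optional 'line_number'.
results = [
    {'box': (40, 30, 220, 28), 'text': 'Hello world', 'confidence': 0.96, 'line_number': 1},
    {'box': (40, 70, 180, 28), 'text': 'Second line', 'confidence': 0.74, 'line_number': 2},
]

renderer = DocumentRenderer(font_size=14)
renderer.draw_boxes('page.png', results)    # color-coded boxes only
renderer.draw_results('page.png', results)  # boxes plus recognized text
renderer.create_report('page.png', results, output_path='ocr_report.html')
```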