kiri-ocr 0.1.0__py3-none-any.whl
- kiri_ocr/__init__.py +14 -0
- kiri_ocr/cli.py +244 -0
- kiri_ocr/core.py +306 -0
- kiri_ocr/detector.py +374 -0
- kiri_ocr/generator.py +570 -0
- kiri_ocr/model.py +159 -0
- kiri_ocr/renderer.py +193 -0
- kiri_ocr/training.py +508 -0
- kiri_ocr-0.1.0.data/scripts/kiri-ocr +6 -0
- kiri_ocr-0.1.0.dist-info/METADATA +218 -0
- kiri_ocr-0.1.0.dist-info/RECORD +16 -0
- kiri_ocr-0.1.0.dist-info/WHEEL +5 -0
- kiri_ocr-0.1.0.dist-info/licenses/LICENSE +201 -0
- kiri_ocr-0.1.0.dist-info/top_level.txt +2 -0
- models/__init__.py +1 -0
- models/model.kiri +0 -0
kiri_ocr/model.py
ADDED
@@ -0,0 +1,159 @@
import torch
import torch.nn as nn
from pathlib import Path


# ========== CHARACTER SET ==========
class CharacterSet:
    """Manages the character-to-index mapping."""
    def __init__(self):
        self.chars = ['BLANK', 'PAD', 'SOS', ' ']  # Special tokens + space
        self.char2idx = {'BLANK': 0, 'PAD': 1, 'SOS': 2, ' ': 3}
        self.idx2char = {0: 'BLANK', 1: 'PAD', 2: 'SOS', 3: ' '}

    def add_chars(self, text):
        """Add new characters from text."""
        for char in text:
            if char not in self.char2idx:
                idx = len(self.chars)
                self.chars.append(char)
                self.char2idx[char] = idx
                self.idx2char[idx] = char

    def encode(self, text):
        """Convert text to indices (unknown characters map to BLANK)."""
        return [self.char2idx.get(char, 0) for char in text]

    def decode(self, indices):
        """Convert indices to text (greedy CTC decode: collapse repeats, drop specials)."""
        chars = []
        prev_idx = None
        for idx in indices:
            if idx > 2 and idx != prev_idx:  # Skip BLANK, PAD, SOS
                chars.append(self.idx2char.get(idx, ''))
            prev_idx = idx
        return ''.join(chars)

    def __len__(self):
        return len(self.chars)

    def save(self, path):
        with open(path, 'w', encoding='utf-8') as f:
            for char in self.chars:
                f.write(char + '\n')

    @classmethod
    def load(cls, path):
        charset = cls()
        with open(path, 'r', encoding='utf-8') as f:
            chars = [line.rstrip('\n') for line in f]
        charset.chars = chars
        charset.char2idx = {char: idx for idx, char in enumerate(chars)}
        charset.idx2char = {idx: char for idx, char in enumerate(chars)}
        return charset

    @classmethod
    def from_checkpoint(cls, checkpoint):
        """Load the charset from a model checkpoint dictionary."""
        charset = cls()
        if 'charset' in checkpoint:
            charset.chars = checkpoint['charset']
            charset.char2idx = {char: idx for idx, char in enumerate(charset.chars)}
            charset.idx2char = {idx: char for idx, char in enumerate(charset.chars)}
        return charset


def save_checkpoint(model, charset, optimizer, epoch, val_loss, accuracy, path):
    """Save a model checkpoint with the charset included."""
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'val_loss': val_loss,
        'accuracy': accuracy,
        'charset': charset.chars
    }, path)


# ========== LIGHTWEIGHT OCR MODEL ==========
class LightweightOCR(nn.Module):
    """
    Lightweight CRNN-style OCR model (~13 MB):
    CNN backbone -> two BiLSTM layers -> per-timestep classifier for CTC.
    """
    def __init__(self, num_chars, hidden_size=256):
        super().__init__()

        # CNN backbone
        self.cnn = nn.Sequential(
            # Block 1: [B, 1, 32, W] -> [B, 32, 32, W]
            self._conv_block(1, 32, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),  # -> [B, 32, 16, W/2]

            # Block 2: [B, 32, 16, W/2] -> [B, 64, 16, W/2]
            self._conv_block(32, 64, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=2, stride=2),  # -> [B, 64, 8, W/4]

            # Block 3: [B, 64, 8, W/4] -> [B, 128, 8, W/4]
            self._conv_block(64, 128, kernel_size=3, stride=1, padding=1),
            self._conv_block(128, 128, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1)),  # -> [B, 128, 4, W/4]

            # Block 4: [B, 128, 4, W/4] -> [B, 256, 4, W/4]
            self._conv_block(128, 256, kernel_size=3, stride=1, padding=1),
            self._conv_block(256, 256, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(kernel_size=(4, 1), stride=(4, 1)),  # -> [B, 256, 1, W/4]
        )

        # Two-layer BiLSTM with an intermediate projection
        self.lstm1 = nn.LSTM(
            256, hidden_size,
            bidirectional=True,
            batch_first=True
        )

        self.intermediate_linear = nn.Linear(hidden_size * 2, hidden_size)

        self.lstm2 = nn.LSTM(
            hidden_size, hidden_size,
            bidirectional=True,
            batch_first=True
        )

        # Output layer
        self.fc = nn.Linear(hidden_size * 2, num_chars)

    def _conv_block(self, in_ch, out_ch, **kwargs):
        """Convolutional block: Conv2d + BatchNorm + ReLU."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, bias=False, **kwargs),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """
        Args:
            x: [B, 1, H, W]
        Returns:
            output: [T, B, num_chars]
        """
        # CNN features
        features = self.cnn(x)  # [B, 256, 1, W']

        # Remove the height dimension
        features = features.squeeze(2)  # [B, 256, W']
        features = features.permute(0, 2, 1)  # [B, W', 256]

        # First LSTM layer
        rnn_out, _ = self.lstm1(features)  # [B, W', hidden*2]

        # Intermediate projection
        rnn_out = self.intermediate_linear(rnn_out)  # [B, W', hidden]
        rnn_out = torch.relu(rnn_out)

        # Second LSTM layer
        rnn_out, _ = self.lstm2(rnn_out)  # [B, W', hidden*2]

        # Fully connected
        logits = self.fc(rnn_out)  # [B, W', num_chars]

        # Time-major layout for CTC loss
        return logits.permute(1, 0, 2)  # [W', B, num_chars]
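For orientation, here is a minimal usage sketch (not part of the wheel) tying the pieces above together. The 32-pixel input height follows the shape comments in the CNN blocks; the dummy width of 128 and the greedy argmax decode are illustrative assumptions.

import torch
from kiri_ocr.model import CharacterSet, LightweightOCR

# Build a charset from sample text, then size the model to it
charset = CharacterSet()
charset.add_chars("hello world")
model = LightweightOCR(num_chars=len(charset))
model.eval()

# Dummy grayscale line image: [B, 1, H, W] with H = 32 (per the shape comments)
x = torch.randn(1, 1, 32, 128)
with torch.no_grad():
    logits = model(x)  # [T, B, num_chars], T = W/4 = 32

# Greedy decode: argmax per timestep, then collapse repeats/specials
indices = logits.argmax(dim=2)[:, 0].tolist()
print(charset.decode(indices))

With untrained weights this prints noise, but it shows the contract: the model emits one logit vector per horizontal timestep, and CharacterSet.decode collapses the argmax sequence CTC-style.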
kiri_ocr/renderer.py
ADDED
@@ -0,0 +1,193 @@
import cv2
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from pathlib import Path


class DocumentRenderer:
    """Render OCR results on an image."""

    def __init__(self, font_path=None, font_size=12):
        self.font_size = font_size

        # Try to load the requested font for text rendering
        self.font = None
        if font_path and Path(font_path).exists():
            try:
                self.font = ImageFont.truetype(font_path, font_size)
            except Exception:
                pass

        # Fallback fonts
        if self.font is None:
            # Priority list, including Khmer fonts
            candidate_fonts = [
                'fonts/KhmerOSbattambang.ttf',
                'fonts/Battambang-Regular.ttf',
                'fonts/NotoSansKhmer-Regular.ttf',
                'Arial.ttf',
                'DejaVuSans.ttf',
                'NotoSans-Regular.ttf'
            ]

            # Also try any TTF found in the fonts/ directory
            if Path('fonts').exists():
                candidate_fonts = [str(f) for f in Path('fonts').glob('*.ttf')] + candidate_fonts

            for font_name in candidate_fonts:
                try:
                    self.font = ImageFont.truetype(font_name, font_size)
                    break
                except Exception:
                    continue

    def draw_boxes(self, image_path, results, output_path='output_boxes.png'):
        """Draw bounding boxes only."""
        img = cv2.imread(str(image_path))

        for result in results:
            x, y, w, h = result['box']
            conf = result['confidence']

            # Color based on confidence (BGR)
            if conf > 0.9:
                color = (0, 255, 0)    # Green
            elif conf > 0.7:
                color = (0, 165, 255)  # Orange
            else:
                color = (0, 0, 255)    # Red

            # Draw box
            cv2.rectangle(img, (x, y), (x + w, y + h), color, 2)

            # Draw line number
            if 'line_number' in result:
                cv2.putText(img, str(result['line_number']), (x - 5, y - 5),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

        cv2.imwrite(output_path, img)
        print(f"\nBoxes saved to {output_path}")

        return img

    def draw_results(self, image_path, results, output_path='output_ocr.png',
                     show_text=True, show_confidence=True):
        """Draw boxes with the recognized text."""
        img = cv2.imread(str(image_path))

        # Convert to PIL for better text rendering
        img_pil = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        draw = ImageDraw.Draw(img_pil)

        for result in results:
            x, y, w, h = result['box']
            text = result['text']
            conf = result['confidence']

            # Color based on confidence (RGB)
            if conf > 0.9:
                color = (0, 255, 0)
            elif conf > 0.7:
                color = (255, 165, 0)
            else:
                color = (255, 0, 0)

            # Draw box
            draw.rectangle([x, y, x + w, y + h], outline=color, width=2)

            if show_text:
                # Prepare label
                label = text[:50]  # Limit length
                if show_confidence:
                    label += f" ({conf*100:.0f}%)"

                # Background for text
                if self.font:
                    try:
                        # Calculate text position
                        left, top, right, bottom = draw.textbbox((0, 0), label, font=self.font)
                        text_height = bottom - top
                        text_y = y - text_height - 5

                        bbox = draw.textbbox((x, text_y), label, font=self.font)
                        # Add padding
                        bbox = (bbox[0] - 2, bbox[1] - 2, bbox[2] + 2, bbox[3] + 2)

                        draw.rectangle(bbox, fill=color)
                        draw.text((x, text_y), label, fill=(255, 255, 255), font=self.font)
                    except Exception:
                        # Fallback to OpenCV text rendering
                        img_cv = cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR)
                        cv2.putText(img_cv, label, (x, y - 10),
                                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 1)
                        img_pil = Image.fromarray(cv2.cvtColor(img_cv, cv2.COLOR_BGR2RGB))
                        draw = ImageDraw.Draw(img_pil)

        # Save
        img_pil.save(output_path)
        print(f"Results saved to {output_path}")

        return img_pil

    def create_report(self, image_path, results, output_path='ocr_report.html'):
        """Create an HTML report."""
        html = f"""
<!DOCTYPE html>
<html>
<head>
    <meta charset="UTF-8">
    <title>OCR Report</title>
    <style>
        body {{ font-family: Arial, sans-serif; margin: 20px; }}
        .header {{ background: #4CAF50; color: white; padding: 20px; border-radius: 5px; }}
        .result {{ padding: 10px; margin: 10px 0; border-left: 4px solid #4CAF50; }}
        .conf-high {{ border-color: #4CAF50; background: #f1f8f4; }}
        .conf-medium {{ border-color: #FF9800; background: #fff8f1; }}
        .conf-low {{ border-color: #F44336; background: #fef1f1; }}
        .text {{ font-size: 18px; font-weight: bold; margin-bottom: 5px; }}
        .confidence {{ color: #666; font-size: 14px; }}
        .stats {{ background: #f5f5f5; padding: 15px; border-radius: 5px; margin: 20px 0; }}
        .full-text {{ background: #f5f5f5; padding: 20px; white-space: pre-wrap;
                      font-family: monospace; border-radius: 5px; }}
    </style>
</head>
<body>
    <div class="header">
        <h1>OCR Report</h1>
        <p>Document: {Path(image_path).name}</p>
    </div>

    <div class="stats">
        <strong>Statistics:</strong><br>
        Total Regions: {len(results)}<br>
        Average Confidence: {np.mean([r['confidence'] for r in results])*100:.2f}%<br>
        High Confidence (>90%): {sum(1 for r in results if r['confidence'] > 0.9)}<br>
        Medium Confidence (70-90%): {sum(1 for r in results if 0.7 < r['confidence'] <= 0.9)}<br>
        Low Confidence (<70%): {sum(1 for r in results if r['confidence'] <= 0.7)}
    </div>

    <h2>Full Text</h2>
    <div class="full-text">{"<br>".join([r['text'] for r in results])}</div>

    <h2>Detailed Results</h2>
"""

        for i, result in enumerate(results, 1):
            conf = result['confidence']
            conf_class = 'conf-high' if conf > 0.9 else ('conf-medium' if conf > 0.7 else 'conf-low')

            html += f"""
            <div class="result {conf_class}">
                <div class="text">{i}. {result['text']}</div>
                <div class="confidence">Confidence: {conf*100:.2f}% | Box: {result['box']}</div>
            </div>
            """

        html += """
</body>
</html>
"""

        with open(output_path, 'w', encoding='utf-8') as f:
            f.write(html)

        print(f"Report saved to {output_path}")