kiri-ocr 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kiri_ocr/__init__.py +14 -0
- kiri_ocr/cli.py +244 -0
- kiri_ocr/core.py +306 -0
- kiri_ocr/detector.py +374 -0
- kiri_ocr/generator.py +570 -0
- kiri_ocr/model.py +159 -0
- kiri_ocr/renderer.py +193 -0
- kiri_ocr/training.py +508 -0
- kiri_ocr-0.1.0.data/scripts/kiri-ocr +6 -0
- kiri_ocr-0.1.0.dist-info/METADATA +218 -0
- kiri_ocr-0.1.0.dist-info/RECORD +16 -0
- kiri_ocr-0.1.0.dist-info/WHEEL +5 -0
- kiri_ocr-0.1.0.dist-info/licenses/LICENSE +201 -0
- kiri_ocr-0.1.0.dist-info/top_level.txt +2 -0
- models/__init__.py +1 -0
- models/model.kiri +0 -0
kiri_ocr/generator.py
ADDED
|
@@ -0,0 +1,570 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import sys
|
|
3
|
+
import shutil
|
|
4
|
+
import random
|
|
5
|
+
from collections import Counter
|
|
6
|
+
import numpy as np
|
|
7
|
+
import cv2
|
|
8
|
+
from tqdm import tqdm
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
|
|
11
|
+
# Try to import PIL for better text rendering
|
|
12
|
+
try:
|
|
13
|
+
from PIL import Image, ImageDraw, ImageFont
|
|
14
|
+
HAS_PIL = True
|
|
15
|
+
except ImportError:
|
|
16
|
+
HAS_PIL = False
|
|
17
|
+
|
|
18
|
+
class FontManager:
    """Manage font loading from a project fonts directory.

    Every font file found directly inside ``fonts_dir`` is loaded at each
    size in ``FONT_SIZES``. Each successfully loaded ``(path, size, font)``
    tuple is stored in ``all_fonts``. Based on a file-name heuristic, it is
    also stored in either ``khmer_fonts`` or ``english_fonts``.
    """

    # Point sizes at which every font file is loaded.
    FONT_SIZES = (28, 32, 36, 40, 44, 48)

    # Font file extensions, compared case-insensitively.
    FONT_EXTENSIONS = ('.ttf', '.otf')

    # Lower-cased file-name substrings that mark a font as Khmer.
    # NOTE(review): broad entries such as 'kh', 'noto' and 'content' also
    # match many non-Khmer fonts (e.g. Noto Latin fonts); kept unchanged.
    KHMER_KEYWORDS = (
        'khmer', 'កម្ពុជា', 'battambang', 'siemreap',
        'bokor', 'moul', 'content', 'metal', 'freehand',
        'fasthand', 'noto', 'kh',
    )

    def __init__(self, language='mixed', fonts_dir='fonts'):
        """Load all fonts from *fonts_dir* immediately.

        Args:
            language: Stored on the instance; not used by the loading or
                selection logic in this class.
            fonts_dir: Directory to scan for font files (created if missing).
        """
        self.language = language
        self.fonts_dir = Path(fonts_dir)
        self.khmer_fonts = []    # (path, size, font) tuples classified as Khmer
        self.english_fonts = []  # (path, size, font) tuples for all other fonts
        self.all_fonts = []      # every successfully loaded (path, size, font)
        self._load_fonts()

    def _discover_font_files(self):
        """Return the font files directly inside ``fonts_dir``, sorted by name.

        Extensions are matched case-insensitively against a single directory
        listing. Each file is therefore returned once, even where pattern
        matching is case-insensitive (the separate ``*.ttf``/``*.TTF`` globs
        found it twice there). Mixed-case extensions such as ``.Ttf`` are
        also found.
        """
        return sorted(
            (
                path
                for path in self.fonts_dir.iterdir()
                if path.is_file() and path.suffix.lower() in self.FONT_EXTENSIONS
            ),
            key=lambda path: path.name,
        )

    def _load_fonts(self):
        """Populate the font lists from ``fonts_dir`` and print a summary.

        If ``fonts_dir`` does not exist it is created and all lists stay
        empty. A font file/size combination that fails to load is reported
        and skipped.
        """
        print(f"\nš Loading fonts from: {self.fonts_dir.absolute()}")

        if not self.fonts_dir.exists():
            print(f" ā ļø Fonts directory not found: {self.fonts_dir}")
            print(" Creating directory...")
            self.fonts_dir.mkdir(parents=True, exist_ok=True)
            print("\n ā No fonts found!")
            print(f"\n Please add .ttf font files to: {self.fonts_dir.absolute()}")
            return

        font_files = self._discover_font_files()
        if not font_files:
            print(f"\n ā No font files found in {self.fonts_dir}")
            print("\n Please add .ttf or .otf files to this directory")
            return

        print(f" Found {len(font_files)} font files")

        for font_path in font_files:
            font_name = font_path.name.lower()
            is_khmer = any(keyword in font_name for keyword in self.KHMER_KEYWORDS)
            bucket = self.khmer_fonts if is_khmer else self.english_fonts

            for size in self.FONT_SIZES:
                try:
                    font = ImageFont.truetype(str(font_path), size)
                except Exception as e:
                    print(f" ā ļø Failed to load {font_path.name} at size {size}: {e}")
                    continue
                entry = (str(font_path), size, font)
                self.all_fonts.append(entry)
                bucket.append(entry)

        print("\n š Font Summary:")
        print(f" Total fonts : {len(self.all_fonts)} (across all sizes)")
        print(f" Khmer fonts : {len(self.khmer_fonts)}")
        print(f" English fonts : {len(self.english_fonts)}")

    def get_random_font(self, text):
        """Return a random loaded font suited to *text*, or ``None``.

        Text containing any character in U+1780..U+17FF (the Khmer block)
        draws from ``khmer_fonts``; other text draws from ``english_fonts``.
        When the preferred pool is empty, any loaded font may be returned.
        ``None`` is returned when no fonts were loaded at all.
        """
        has_khmer = any('\u1780' <= c <= '\u17FF' for c in text)
        preferred = self.khmer_fonts if has_khmer else self.english_fonts
        font_pool = preferred or self.all_fonts
        if not font_pool:
            return None
        # Entries are (path, size, font); callers only need the font object.
        return random.choice(font_pool)[2]
|
|
109
|
+
|
|
110
|
+
class ImageRenderer:
    """Render single text lines to fixed-size grayscale image arrays."""

    def __init__(self, font_manager, image_height=32, image_width=512):
        """
        Args:
            font_manager: Object providing ``get_random_font(text)``; used by
                ``render`` when no specific font is given.
            image_height: Height in pixels of every rendered image.
            image_width: Width in pixels of every rendered image.
        """
        self.font_manager = font_manager
        self.image_height = image_height
        self.image_width = image_width

    def _is_text_supported(self, font, text):
        """Check whether *font* has real glyphs for every character in *text*.

        The glyph rendered for a code point the font cannot have (its
        "notdef"/tofu box) is used as a reference. Returns ``False`` if any
        non-space, non-control character renders to a mask identical to
        that reference, or if getting a character's mask raises. Returns
        ``True`` when no reference glyph can be obtained or the check itself
        fails (permissive, so fonts are not rejected wholesale).
        """
        try:
            # Get the glyph for a strictly undefined character to use as reference.
            # Several candidates, since some may raise for a given font.
            undefined_chars = ['\uFFFF', '\U0010FFFF', '\0']
            ref_mask = None
            ref_bbox = None

            for uc in undefined_chars:
                try:
                    ref_mask = font.getmask(uc)
                    ref_bbox = ref_mask.getbbox()
                    if ref_mask:
                        break
                except Exception:
                    continue

            if ref_mask is None:
                # Can't determine reference, assume supported to avoid blocking everything
                return True

            ref_bytes = bytes(ref_mask)

            for char in text:
                # Skip spaces and control characters (they often have empty glyphs which is fine)
                if char.isspace() or ord(char) < 32:
                    continue

                try:
                    char_mask = font.getmask(char)
                    char_bbox = char_mask.getbbox()

                    # Compare with reference "notdef" glyph: cheap bbox check first,
                    # then a full pixel comparison only when the boxes match.
                    if char_bbox == ref_bbox:
                        if bytes(char_mask) == ref_bytes:
                            # It's a tofu/box!
                            return False
                except Exception:
                    # Error getting mask implies issue
                    return False

            return True
        except Exception:
            # If check fails, be permissive
            return True

    @staticmethod
    def _canvas_layout(bbox, padding_x, padding_y):
        """Compute canvas size and drawing origin for a text bounding box.

        Args:
            bbox: ``(left, top, right, bottom)`` of the text's ink when drawn
                at ``(0, 0)`` (as returned by ``ImageDraw.textbbox``).
            padding_x: Horizontal padding on each side, in pixels.
            padding_y: Vertical padding on each side, in pixels.

        Returns:
            ``(img_w, img_h, origin_x, origin_y)``: the canvas size, and the
            point at which to draw so the ink box starts exactly
            ``padding_x``/``padding_y`` pixels from the top-left corner.
            Both bbox offsets are compensated; previously only the vertical
            one was, which shifted or clipped glyphs with a left bearing.
        """
        left, top, right, bottom = bbox
        img_w = (right - left) + padding_x * 2
        img_h = (bottom - top) + padding_y * 2
        return img_w, img_h, padding_x - left, padding_y - top

    def render(self, text, augment=True, specific_font=None, retry_limit=10):
        """Render *text* to an ``(image_height, image_width)`` uint8 array.

        Args:
            text: Text line to draw.
            augment: Randomize padding, colors and position, and apply
                ``_apply_augmentations``; otherwise use fixed values.
            specific_font: Font to use. If it cannot draw every character,
                ``None`` is returned.
            retry_limit: Number of random fonts tried when no specific font
                is given.

        Returns:
            The image array, or ``None`` when no suitable font was found.

        Raises:
            ImportError: If Pillow is not installed.
        """
        if not HAS_PIL:
            raise ImportError("Pillow library not found")

        # Pick a font that can draw every character of the text.
        if specific_font:
            font = specific_font
            if not self._is_text_supported(font, text):
                return None
        else:
            font = None
            for _ in range(retry_limit):
                candidate = self.font_manager.get_random_font(text)
                if candidate and self._is_text_supported(candidate, text):
                    font = candidate
                    break
            if font is None:
                return None

        # Measure the text's ink box when drawn at the origin.
        probe = ImageDraw.Draw(Image.new('L', (1, 1)))
        try:
            bbox = probe.textbbox((0, 0), text, font=font)
        except AttributeError:
            # Old Pillow without textbbox: textsize only reports the size,
            # so treat the ink box as starting at the origin.
            text_w, text_h = probe.textsize(text, font=font)
            bbox = (0, 0, text_w, text_h)

        padding_x = random.randint(10, 30) if augment else 20
        padding_y = random.randint(5, 15) if augment else 10
        img_w, img_h, origin_x, origin_y = self._canvas_layout(bbox, padding_x, padding_y)

        bg_color = random.randint(235, 255) if augment else 255
        text_color = random.randint(0, 30) if augment else 0

        img = Image.new('L', (img_w, img_h), color=bg_color)
        # Small positional jitter when augmenting.
        x = origin_x + (random.randint(-3, 3) if augment else 0)
        y = origin_y + (random.randint(-2, 2) if augment else 0)
        ImageDraw.Draw(img).text((x, y), text, font=font, fill=text_color)

        img_array = np.array(img)
        if augment:
            img_array = self._apply_augmentations(img_array, bg_color)

        return self._resize_to_target(img_array)

    def _apply_augmentations(self, img, bg_color):
        """Apply random noise, blur, morphology and brightness/contrast changes.

        Args:
            img: 2-D uint8 grayscale image array.
            bg_color: Background gray level of *img*. Currently unused,
                because rotation (which would need it as the border fill)
                is disabled.

        Returns:
            The augmented uint8 image array.
        """
        # Gaussian noise
        if random.random() < 0.4:
            noise = np.random.randn(*img.shape) * random.uniform(3, 8)
            img = np.clip(img.astype(np.float32) + noise, 0, 255).astype(np.uint8)

        # Gaussian blur
        if random.random() < 0.3:
            kernel_size = random.choice([3, 5])
            img = cv2.GaussianBlur(img, (kernel_size, kernel_size), 0)

        # Rotation is intentionally not applied: even small angles cut off
        # text at the image edges.

        # Morphological operations: erode/dilate with a 2x2 kernel
        # (thickens/thins dark text on a light background).
        if random.random() < 0.2:
            kernel = np.ones((2, 2), np.uint8)
            if random.random() < 0.5:
                img = cv2.erode(img, kernel, iterations=1)
            else:
                img = cv2.dilate(img, kernel, iterations=1)

        # Brightness/Contrast
        if random.random() < 0.3:
            alpha = random.uniform(0.85, 1.15)  # Contrast
            beta = random.randint(-15, 15)  # Brightness
            img = cv2.convertScaleAbs(img, alpha=alpha, beta=beta)

        return img

    def _resize_to_target(self, img):
        """Resize *img* to ``(image_height, image_width)``.

        The image is first scaled to ``image_height``, keeping its aspect
        ratio. A narrower result is padded on the right with the mean of its
        rightmost 10 columns (255 if it is 10 pixels wide or less). A wider
        result is squeezed horizontally to ``image_width``.
        """
        h, w = img.shape[:2]

        # Scale to match height
        scale = self.image_height / h
        new_w = int(w * scale)

        img = cv2.resize(img, (new_w, self.image_height), interpolation=cv2.INTER_LINEAR)

        if new_w < self.image_width:
            # Pad right
            bg_color = int(np.mean(img[:, -10:])) if new_w > 10 else 255
            padded = np.full((self.image_height, self.image_width), bg_color, dtype=np.uint8)
            padded[:, :new_w] = img
            img = padded
        elif new_w > self.image_width:
            # Resize to fit
            img = cv2.resize(img, (self.image_width, self.image_height))

        return img
|
|
300
|
+
|
|
301
|
+
class DatasetGenerator:
    """Generate a synthetic OCR dataset (train/val splits) from text files.

    Each text line is rendered to images with ``ImageRenderer``. A split
    directory contains ``images/<split>_<index:06d>.png`` files and a
    tab-separated ``labels.txt`` mapping each image file name to its text.
    """

    def __init__(self, language='mixed', image_height=32, image_width=512, fonts_dir='fonts'):
        """
        Args:
            language: Passed through to ``FontManager``.
            image_height: Height in pixels of generated images.
            image_width: Width in pixels of generated images.
            fonts_dir: Directory the ``FontManager`` loads fonts from.
        """
        self.language = language
        self.image_height = image_height
        self.image_width = image_width

        self.font_manager = FontManager(language, fonts_dir=fonts_dir)
        self.renderer = ImageRenderer(self.font_manager, image_height, image_width)

    def generate_dataset(
        self,
        train_file,
        val_file=None,
        output_dir='data',
        train_augment=100,
        val_augment=1,
        font_mode='random',
        random_augment=False,
        retry_limit=10,
        val_split=0.1,
        split_seed=0,
    ):
        """Generate the training and validation splits under *output_dir*.

        If *output_dir* already has content, the user is asked whether to
        continue (append) or start from scratch (the directory is deleted).

        Args:
            train_file: UTF-8 text file with one sample text per line.
            val_file: Optional UTF-8 validation text file. When it is not
                given or does not exist, a ``val_split`` fraction of the
                training lines is held out for validation instead. Those
                lines are then NOT rendered into the training set.
            output_dir: Root output directory (gets ``train/`` and ``val/``).
            train_augment: Images per training line in random font mode.
            val_augment: Images per validation line.
            font_mode: ``'all'`` renders every training line with every
                loaded font; anything else uses random fonts. Validation
                always uses random fonts.
            random_augment: Augment training images even when
                ``train_augment`` is 1.
            retry_limit: Random fonts tried per sample before skipping it.
            val_split: Fraction of training lines held out when no
                validation file is used.
            split_seed: Seed for the held-out shuffle, so resumed runs hold
                out the same lines; ``None`` gives a non-reproducible split.
        """
        print("\n" + "="*70)
        print(" šø Dataset Generation")
        print("="*70)

        append_mode = self._confirm_existing_output(output_dir)
        if append_mode is None:
            return

        # Create output dirs
        os.makedirs(output_dir, exist_ok=True)
        os.makedirs(f"{output_dir}/train/images", exist_ok=True)
        os.makedirs(f"{output_dir}/val/images", exist_ok=True)

        train_kwargs = dict(
            augment_factor=train_augment,
            split_name='train',
            font_mode=font_mode,
            random_augment=random_augment,
            retry_limit=retry_limit,
            append=append_mode,
        )
        val_kwargs = dict(
            augment_factor=val_augment,
            split_name='val',
            retry_limit=retry_limit,
            append=append_mode,
        )

        if val_file and os.path.exists(val_file):
            print("\nš Generating TRAINING set...")
            train_data = self._generate_split(train_file, f"{output_dir}/train", **train_kwargs)

            print("\nš Generating VALIDATION set (from separate file)...")
            val_data = self._generate_split(val_file, f"{output_dir}/val", **val_kwargs)
        else:
            # Split before generating anything, so validation lines are
            # held out of the training set instead of being trained on too.
            with open(train_file, 'r', encoding='utf-8') as f:
                all_lines = [line.strip() for line in f if line.strip()]
            train_lines, val_lines = self._split_lines(all_lines, val_split, split_seed)
            print(f" Holding out {len(val_lines)} of {len(all_lines)} lines for validation")

            print("\nš Generating TRAINING set...")
            train_data = self._generate_split(
                f"{train_file} (training part)",
                f"{output_dir}/train",
                lines=train_lines,
                **train_kwargs,
            )

            print("\nš Generating VALIDATION set (split from training)...")
            val_data = self._generate_split(
                f"{train_file} (held-out part)",
                f"{output_dir}/val",
                lines=val_lines,
                **val_kwargs,
            )

        print("\n" + "="*70)
        print(" ✅ Dataset Generation Complete!")
        print("="*70)
        print(f" Train: {len(train_data):,} samples")
        print(f" Val: {len(val_data):,} samples")
        print(f" Output: {output_dir}")
        print("="*70 + "\n")

    @staticmethod
    def _confirm_existing_output(output_dir):
        """Decide how to treat existing content in *output_dir*.

        Returns:
            ``False`` when there is nothing to resume (the directory is
            missing or empty) or when the user chose to start from scratch
            (the directory is deleted). ``True`` when the user chose to
            continue (append). ``None`` when the prompt was aborted with
            Ctrl+C or end of input.
        """
        if not (os.path.exists(output_dir) and os.listdir(output_dir)):
            return False

        print(f"\nā ļø Dataset directory '{output_dir}' already exists.")
        while True:
            try:
                choice = input(" Do you want to (c)ontinue generating or (s)tart from scratch? [c/s]: ").lower().strip()
            except (KeyboardInterrupt, EOFError):
                # EOFError: stdin closed (non-interactive run); treat like Ctrl+C.
                print("\nAborted.")
                return None
            if choice == 'c':
                print(" š Continuing generation (appending to existing)...")
                return True
            if choice == 's':
                print(" šļø Cleaning up existing directory...")
                shutil.rmtree(output_dir)
                return False

    @staticmethod
    def _split_lines(lines, val_split=0.1, seed=0):
        """Shuffle *lines* reproducibly and split them into train/val parts.

        Returns:
            ``(train_lines, val_lines)``. For two or more lines both parts are
            non-empty and disjoint by position. With fewer than two lines
            nothing can be held out, so both parts are the whole (shuffled)
            list.
        """
        shuffled = list(lines)
        random.Random(seed).shuffle(shuffled)
        if len(shuffled) < 2:
            return shuffled, list(shuffled)
        split_idx = int(len(shuffled) * (1 - val_split))
        split_idx = min(max(split_idx, 1), len(shuffled) - 1)
        return shuffled[:split_idx], shuffled[split_idx:]

    @staticmethod
    def _scan_existing(output_dir, split_name):
        """Inspect an existing split directory so generation can resume.

        Returns:
            ``(start_idx, counts)``. ``start_idx`` is one past the highest
            index among ``images/<split_name>_<index>.png`` (0 if there are
            none). ``counts`` is a ``Counter`` of labelled images per text,
            read from ``labels.txt``; it may be partial if reading fails.
        """
        start_idx = 0
        prefix = f"{split_name}_"
        indices = []
        for fname in os.listdir(f"{output_dir}/images"):
            if fname.endswith('.png') and fname.startswith(prefix):
                try:
                    indices.append(int(fname[len(prefix):-4]))
                except ValueError:
                    continue
        if indices:
            start_idx = max(indices) + 1
            print(f" š Appending starting from index {start_idx}")

        counts = Counter()
        labels_path = f"{output_dir}/labels.txt"
        if os.path.exists(labels_path):
            print(" š Analyzing existing labels to resume...")
            try:
                with open(labels_path, 'r', encoding='utf-8') as f:
                    for line in f:
                        parts = line.strip().split('\t', 1)
                        if len(parts) == 2:
                            counts[parts[1]] += 1
                print(f" ā Found {sum(counts.values())} existing samples")
            except Exception as e:
                print(f" ā ļø Could not read labels file: {e}")
        return start_idx, counts

    def _generate_split(self, text_file, output_dir, augment_factor, split_name, font_mode='random', random_augment=False, retry_limit=10, append=False, lines=None):
        """Render one split (train/val) into *output_dir*.

        Args:
            text_file: UTF-8 file with one text per line. When *lines* is
                given, the file is not read and the value is only shown in
                progress messages.
            output_dir: Split directory; images go to ``images/`` and labels
                go to ``labels.txt``.
            augment_factor: Images per line in random font mode. A value
                above 1 also enables augmentation.
            split_name: Image file-name prefix (``<split_name>_NNNNNN.png``).
            font_mode: ``'all'`` renders each line once with every loaded
                font; anything else draws random fonts.
            random_augment: Augment even when ``augment_factor`` <= 1.
            retry_limit: Random fonts tried per sample in random mode.
            append: Resume into existing output. Existing labels are kept,
                numbering continues, and only missing samples are generated.
            lines: Optional pre-loaded texts used instead of *text_file*.

        Returns:
            The sample dicts (``{'text', 'font'}``) whose image and label
            were actually written.
        """
        if lines is None:
            with open(text_file, 'r', encoding='utf-8') as f:
                lines = [line.strip() for line in f if line.strip()]

        print(f" Loaded {len(lines)} lines from {text_file}")

        images_dir = f"{output_dir}/images"
        labels_path = f"{output_dir}/labels.txt"

        # Load existing progress if appending
        existing_counts = Counter()
        start_idx = 0
        if append and os.path.exists(images_dir):
            start_idx, existing_counts = self._scan_existing(output_dir, split_name)

        # Plan samples
        samples = []
        skipped_count = 0

        if font_mode == 'all':
            print(" Using ALL fonts mode (augment factor ignored for font selection)")
            fonts_list = self.font_manager.all_fonts
            print(f" Iterating {len(fonts_list)} fonts per line...")

            for line in lines:
                # Which fonts were already rendered is not tracked, so a line
                # is skipped only once it has at least one sample per font.
                if append and existing_counts[line] >= len(fonts_list):
                    skipped_count += len(fonts_list)
                    continue
                samples.extend({'text': line, 'font': font_tuple[2]} for font_tuple in fonts_list)
        else:
            print(f" Using RANDOM fonts mode (retry limit: {retry_limit} attempts per sample)")
            for line in lines:
                have = existing_counts[line] if append else 0
                remaining = max(0, augment_factor - have)
                if remaining < augment_factor:
                    skipped_count += augment_factor - remaining
                samples.extend({'text': line, 'font': None} for _ in range(remaining))

        if skipped_count > 0:
            print(f" āļø Skipping {skipped_count} already generated samples")

        random.shuffle(samples)

        # Hoisted per-split rendering options.
        use_augment = (augment_factor > 1) or random_augment
        current_retry_limit = 1 if font_mode == 'all' else retry_limit

        generated = []
        with open(labels_path, 'a' if append else 'w', encoding='utf-8') as labels_file:
            print(f" Generating {len(samples)} images...")

            for idx, sample in enumerate(tqdm(samples, desc=" Generating", unit="img")):
                text = sample['text']
                try:
                    img = self.renderer.render(
                        text,
                        augment=use_augment,
                        specific_font=sample['font'],
                        retry_limit=current_retry_limit,
                    )
                    if img is None:
                        continue

                    img_filename = f"{split_name}_{start_idx + idx:06d}.png"
                    img_path = f"{images_dir}/{img_filename}"
                    # cv2.imwrite reports failure by returning False rather than
                    # raising; never write a label for an image that is missing.
                    if not cv2.imwrite(img_path, img):
                        print(f" Failed to write image: {img_path}")
                        continue

                    labels_file.write(f"{img_filename}\t{text}\n")
                    generated.append(sample)
                except Exception as e:
                    print(f" ā Failed for '{text[:30]}...': {e}")
                    continue

        print(f" ā Generated {len(generated)} / {len(samples)} images\n")

        return generated
|
|
538
|
+
|
|
539
|
+
def generate_command(args):
    """CLI entry point: generate a dataset from parsed command-line *args*.

    Args:
        args: Namespace with ``train_file``, ``output``, ``augment``,
            ``val_augment``, ``language``, ``height``, ``width`` and
            ``fonts_dir``. Optional: ``val_file`` (default ``None``),
            ``font_mode`` (default ``'random'``), ``random_augment``
            (default ``False``) and ``retry_limit`` (default 10).

    Returns:
        ``1`` if the training file is missing or Pillow is unavailable,
        otherwise ``0`` once generation has run.
    """
    if not os.path.exists(args.train_file):
        print(f"\nā Error: Training file not found: {args.train_file}\n")
        return 1

    if not HAS_PIL:
        print("ā PIL/Pillow not found - install with: pip install Pillow")
        return 1

    # A validation file that does not exist is otherwise silently replaced by
    # a split of the training file; tell the user.
    val_file = getattr(args, 'val_file', None)
    if val_file and not os.path.exists(val_file):
        print(f"\nWarning: validation file not found: {val_file} "
              "(a validation split will be taken from the training file)")

    generator = DatasetGenerator(
        language=args.language,
        image_height=args.height,
        image_width=args.width,
        fonts_dir=args.fonts_dir,
    )

    generator.generate_dataset(
        train_file=args.train_file,
        val_file=val_file,
        output_dir=args.output,
        train_augment=args.augment,
        val_augment=args.val_augment,
        font_mode=getattr(args, 'font_mode', 'random'),
        random_augment=getattr(args, 'random_augment', False),
        retry_limit=getattr(args, 'retry_limit', 10),
    )
    return 0
|