kiri-ocr 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
kiri_ocr/generator.py ADDED
@@ -0,0 +1,570 @@
1
+ import os
2
+ import sys
3
+ import shutil
4
+ import random
5
+ from collections import Counter
6
+ import numpy as np
7
+ import cv2
8
+ from tqdm import tqdm
9
+ from pathlib import Path
10
+
11
+ # Try to import PIL for better text rendering
12
+ try:
13
+ from PIL import Image, ImageDraw, ImageFont
14
+ HAS_PIL = True
15
+ except ImportError:
16
+ HAS_PIL = False
17
+
18
class FontManager:
    """Load fonts from a directory and hand out random ones suited to a text.

    Every font file found is opened once per entry in ``FONT_SIZES``. Each
    successful load is stored as a ``(path, size, font)`` tuple in
    ``all_fonts`` and in exactly one of ``khmer_fonts`` / ``english_fonts``,
    chosen by matching the lower-cased file name against ``KHMER_KEYWORDS``.
    """

    # Font-file suffixes the loader accepts (compared case-insensitively).
    FONT_SUFFIXES = ('.ttf', '.otf')

    # Point sizes each font file is loaded at.
    FONT_SIZES = (28, 32, 36, 40, 44, 48)

    # Substrings of a lower-cased file name that mark a font as Khmer.
    # NOTE(review): the second entry looks like mis-encoded Khmer script;
    # confirm the file's encoding, otherwise this keyword can never match.
    KHMER_KEYWORDS = (
        'khmer', 'įž€įž˜įŸ’įž–įž»įž‡įž¶', 'battambang', 'siemreap',
        'bokor', 'moul', 'content', 'metal', 'freehand',
        'fasthand', 'noto', 'kh',
    )

    def __init__(self, language='mixed', fonts_dir='fonts'):
        """Remember the settings and load the fonts immediately.

        Args:
            language: Stored on the instance; this class does not read it.
            fonts_dir: Directory that is scanned for font files.
        """
        self.language = language
        self.fonts_dir = Path(fonts_dir)
        self.khmer_fonts = []
        self.english_fonts = []
        self.all_fonts = []
        self._load_fonts()

    def _find_font_files(self):
        """Return the font files in ``fonts_dir``, sorted, each listed once.

        The suffix is compared case-insensitively in a single directory scan.
        Globbing separate lower-case and upper-case patterns would return
        every file twice on case-insensitive filesystems and would miss
        mixed-case suffixes such as ``.Ttf``.
        """
        if not self.fonts_dir.is_dir():
            return []
        return sorted(
            entry for entry in self.fonts_dir.iterdir()
            if entry.is_file() and entry.suffix.lower() in self.FONT_SUFFIXES
        )

    def _load_fonts(self):
        """Fill the font pools from ``fonts_dir``, reporting progress on stdout.

        A missing directory is created (empty) and loading stops. A font that
        cannot be opened at some size is reported and skipped, so one bad file
        does not abort the whole scan.
        """
        print(f"\nšŸ” Loading fonts from: {self.fonts_dir.absolute()}")

        if not self.fonts_dir.exists():
            print(f" āš ļø Fonts directory not found: {self.fonts_dir}")
            print(f" Creating directory...")
            self.fonts_dir.mkdir(parents=True, exist_ok=True)
            print(f"\n āŒ No fonts found!")
            print(f"\n Please add .ttf font files to: {self.fonts_dir.absolute()}")
            return

        font_files = self._find_font_files()
        if not font_files:
            print(f"\n āŒ No font files found in {self.fonts_dir}")
            print(f"\n Please add .ttf or .otf files to this directory")
            return

        print(f" Found {len(font_files)} font files")

        for font_path in font_files:
            font_name = font_path.name.lower()
            is_khmer = any(keyword in font_name for keyword in self.KHMER_KEYWORDS)
            pool = self.khmer_fonts if is_khmer else self.english_fonts

            for size in self.FONT_SIZES:
                try:
                    font = ImageFont.truetype(str(font_path), size)
                except Exception as e:
                    # Deliberately broad: loading is best-effort per file/size.
                    print(f" āš ļø Failed to load {font_path.name} at size {size}: {e}")
                    continue
                entry = (str(font_path), size, font)
                self.all_fonts.append(entry)
                pool.append(entry)

        print(f"\n šŸ“Š Font Summary:")
        print(f" Total fonts : {len(self.all_fonts)} (across all sizes)")
        print(f" Khmer fonts : {len(self.khmer_fonts)}")
        print(f" English fonts : {len(self.english_fonts)}")

    def get_random_font(self, text):
        """Return a random loaded font suited to ``text``.

        Text containing any character in U+1780..U+17FF draws from the Khmer
        pool, all other text from the English pool. If the preferred pool is
        empty the choice falls back to every loaded font.

        Returns:
            The font object, or ``None`` when no font is loaded at all.
        """
        has_khmer = any('\u1780' <= c <= '\u17FF' for c in text)

        if has_khmer and self.khmer_fonts:
            font_pool = self.khmer_fonts
        elif not has_khmer and self.english_fonts:
            font_pool = self.english_fonts
        else:
            font_pool = self.all_fonts

        if not font_pool:
            return None

        _path, _size, font = random.choice(font_pool)
        return font
109
+
110
class ImageRenderer:
    """Render a line of text into a fixed-size grayscale image."""

    def __init__(self, font_manager, image_height=32, image_width=512):
        """Store the font source and the output image dimensions.

        Args:
            font_manager: Object whose ``get_random_font(text)`` supplies fonts.
            image_height: Height of every rendered image, in pixels.
            image_width: Width of every rendered image, in pixels.
        """
        self.font_manager = font_manager
        self.image_height = image_height
        self.image_width = image_width

    def _is_text_supported(self, font, text):
        """Return False when ``font`` draws some character of ``text`` as "tofu".

        The glyph rendered for a code point expected to be undefined is taken
        as the reference "missing glyph". A character whose mask has the same
        bounding box and the same bytes as that reference counts as
        unsupported. Whitespace and control characters are not checked.

        The check is permissive: if no reference glyph can be obtained, or the
        check itself fails unexpectedly, the text is reported as supported.
        """
        try:
            ref_mask = None
            ref_bbox = None

            # Several candidates, in case a font cannot render one of them.
            for probe in ('\uFFFF', '\U0010FFFF', '\0'):
                try:
                    ref_mask = font.getmask(probe)
                    ref_bbox = ref_mask.getbbox()
                    if ref_mask:
                        break
                except Exception:
                    continue

            if ref_mask is None:
                # No reference available; do not block every text.
                return True

            ref_bytes = bytes(ref_mask)

            for char in text:
                # Empty glyphs are normal for spaces and control characters.
                if char.isspace() or ord(char) < 32:
                    continue
                try:
                    char_mask = font.getmask(char)
                    # Cheap bbox comparison first, full byte comparison second.
                    if char_mask.getbbox() == ref_bbox and bytes(char_mask) == ref_bytes:
                        return False
                except Exception:
                    # A character the font cannot even rasterise is unsupported.
                    return False

            return True
        except Exception:
            return True

    def render(self, text, augment=True, specific_font=None, retry_limit=10):
        """Render ``text`` to a grayscale image using PIL.

        Args:
            text: The string to draw.
            augment: When true, padding, colours and position are randomised
                and ``_apply_augmentations`` is run on the result.
            specific_font: Font to use. When omitted, random fonts from the
                font manager are tried.
            retry_limit: Maximum number of random fonts to try.

        Returns:
            The image as a numpy array resized/padded to the configured
            height and width, or ``None`` when no usable font supports
            ``text``.

        Raises:
            ImportError: If Pillow is not installed.
        """
        if not HAS_PIL:
            raise ImportError("Pillow library not found")

        if specific_font:
            font = specific_font
            if not self._is_text_supported(font, text):
                return None
        else:
            font = None
            for _ in range(retry_limit):
                candidate = self.font_manager.get_random_font(text)
                if candidate and self._is_text_supported(candidate, text):
                    font = candidate
                    break
            if font is None:
                return None

        # Measure the text on a throwaway canvas.
        draw = ImageDraw.Draw(Image.new('L', (1, 1)))
        try:
            # Newer Pillow: the bbox may start above/below the origin, so the
            # vertical offset is compensated when drawing.
            bbox = draw.textbbox((0, 0), text, font=font)
            text_w = bbox[2] - bbox[0]
            text_h = bbox[3] - bbox[1]
            offset_y = -bbox[1]
        except AttributeError:
            # Older Pillow without textbbox.
            text_w, text_h = draw.textsize(text, font=font)
            offset_y = 0

        padding_x = random.randint(10, 30) if augment else 20
        padding_y = random.randint(5, 15) if augment else 10

        img_w = text_w + padding_x * 2
        img_h = text_h + padding_y * 2

        # Light background, dark text.
        bg_color = random.randint(235, 255) if augment else 255
        text_color = random.randint(0, 30) if augment else 0

        img = Image.new('L', (img_w, img_h), color=bg_color)
        draw = ImageDraw.Draw(img)

        # Small positional jitter, well inside the padding.
        x = padding_x + (random.randint(-3, 3) if augment else 0)
        y = padding_y + offset_y + (random.randint(-2, 2) if augment else 0)

        draw.text((x, y), text, font=font, fill=text_color)

        img_array = np.array(img)
        if augment:
            img_array = self._apply_augmentations(img_array, bg_color)

        return self._resize_to_target(img_array)

    def _apply_augmentations(self, img, bg_color):
        """Randomly apply noise, blur, erosion/dilation and brightness/contrast.

        Each step is applied independently with its own probability.

        Rotation is intentionally not applied because it cut off text at the
        image borders; ``bg_color`` was only used as the border fill for that
        step and is kept so the signature stays stable.
        """
        # Gaussian noise
        if random.random() < 0.4:
            noise = np.random.randn(*img.shape) * random.uniform(3, 8)
            img = np.clip(img.astype(np.float32) + noise, 0, 255).astype(np.uint8)

        # Gaussian blur
        if random.random() < 0.3:
            kernel_size = random.choice([3, 5])
            img = cv2.GaussianBlur(img, (kernel_size, kernel_size), 0)

        # Morphological operations: thin or thicken strokes
        if random.random() < 0.2:
            kernel = np.ones((2, 2), np.uint8)
            if random.random() < 0.5:
                img = cv2.erode(img, kernel, iterations=1)
            else:
                img = cv2.dilate(img, kernel, iterations=1)

        # Brightness/Contrast
        if random.random() < 0.3:
            alpha = random.uniform(0.85, 1.15)  # Contrast
            beta = random.randint(-15, 15)  # Brightness
            img = cv2.convertScaleAbs(img, alpha=alpha, beta=beta)

        return img

    def _resize_to_target(self, img):
        """Scale ``img`` to the target height, then pad or squeeze to the width.

        Narrower results are padded on the right with an estimate of the
        background (mean of the last 10 columns, or white for very narrow
        images). Wider results are resized down to the target width.
        """
        h, w = img.shape[:2]

        # Scale to match height. The width is clamped because a very tall,
        # narrow input would otherwise round down to 0 and make resize fail.
        scale = self.image_height / h
        new_w = max(1, int(w * scale))

        img = cv2.resize(img, (new_w, self.image_height), interpolation=cv2.INTER_LINEAR)

        if new_w < self.image_width:
            bg_color = int(np.mean(img[:, -10:])) if new_w > 10 else 255
            padded = np.full((self.image_height, self.image_width), bg_color, dtype=np.uint8)
            padded[:, :new_w] = img
            img = padded
        elif new_w > self.image_width:
            img = cv2.resize(img, (self.image_width, self.image_height))

        return img
300
+
301
class DatasetGenerator:
    """Generate train/validation image datasets from text files.

    Each split is written as ``<dir>/images/<split>_NNNNNN.png`` plus a
    ``<dir>/labels.txt`` file holding ``filename<TAB>text`` lines.
    """

    def __init__(self, language='mixed', image_height=32, image_width=512, fonts_dir='fonts'):
        """Create the font manager and renderer used for every split."""
        self.language = language
        self.image_height = image_height
        self.image_width = image_width

        self.font_manager = FontManager(language, fonts_dir=fonts_dir)
        self.renderer = ImageRenderer(self.font_manager, image_height, image_width)

    def generate_dataset(
        self,
        train_file,
        val_file=None,
        output_dir='data',
        train_augment=100,
        val_augment=1,
        font_mode='random',
        random_augment=False,
        retry_limit=10
    ):
        """Generate the training and validation splits under ``output_dir``.

        If ``output_dir`` already has content, the user is asked whether to
        continue (append to it) or start from scratch (delete it). Ctrl-C or
        end-of-input at that prompt aborts without generating anything.

        Args:
            train_file: Text file with one training line per row.
            val_file: Optional validation text file. When it is missing, a
                random 10% of ``train_file`` is used instead.
            output_dir: Root directory of the dataset.
            train_augment: Samples per training line (random font mode).
            val_augment: Samples per validation line.
            font_mode: ``'all'`` renders each line with every loaded font;
                anything else picks random fonts. Only the training split
                uses this value.
            random_augment: Force augmentation for the training split even
                when ``train_augment`` is 1.
            retry_limit: Random fonts tried per sample.
        """
        print("\n" + "="*70)
        print(" šŸ“ø Dataset Generation")
        print("="*70)

        append_mode = False
        if os.path.exists(output_dir) and os.listdir(output_dir):
            print(f"\nāš ļø Dataset directory '{output_dir}' already exists.")
            while True:
                try:
                    choice = input(" Do you want to (c)ontinue generating or (s)tart from scratch? [c/s]: ").lower().strip()
                    if choice == 'c':
                        append_mode = True
                        print(" šŸ”„ Continuing generation (appending to existing)...")
                        break
                    elif choice == 's':
                        print(" šŸ—‘ļø Cleaning up existing directory...")
                        shutil.rmtree(output_dir)
                        break
                except (KeyboardInterrupt, EOFError):
                    # EOFError: stdin is closed (non-interactive run); there is
                    # no answer to wait for, so abort like Ctrl-C.
                    print("\nAborted.")
                    return

        os.makedirs(output_dir, exist_ok=True)
        os.makedirs(f"{output_dir}/train/images", exist_ok=True)
        os.makedirs(f"{output_dir}/val/images", exist_ok=True)

        print("\nšŸ“š Generating TRAINING set...")
        train_data = self._generate_split(
            train_file,
            f"{output_dir}/train",
            augment_factor=train_augment,
            split_name='train',
            font_mode=font_mode,
            random_augment=random_augment,
            retry_limit=retry_limit,
            append=append_mode
        )

        if val_file and os.path.exists(val_file):
            print("\nšŸ“š Generating VALIDATION set (from separate file)...")
            val_data = self._generate_split(
                val_file,
                f"{output_dir}/val",
                augment_factor=val_augment,
                split_name='val',
                retry_limit=retry_limit,
                append=append_mode
            )
        else:
            # NOTE: the training split above was generated from the complete
            # train_file, so the lines chosen here also occur in training.
            print("\nšŸ“š Generating VALIDATION set (split from training)...")
            with open(train_file, 'r', encoding='utf-8') as f:
                all_lines = [line.strip() for line in f if line.strip()]

            random.shuffle(all_lines)
            split_idx = int(len(all_lines) * 0.9)
            val_lines = all_lines[split_idx:]

            temp_val = f"{output_dir}/temp_val.txt"
            with open(temp_val, 'w', encoding='utf-8') as f:
                f.write('\n'.join(val_lines))

            try:
                val_data = self._generate_split(
                    temp_val,
                    f"{output_dir}/val",
                    augment_factor=val_augment,
                    split_name='val',
                    retry_limit=retry_limit,
                    append=append_mode
                )
            finally:
                # Never leave the scratch file inside the dataset directory.
                os.remove(temp_val)

        print("\n" + "="*70)
        print(" āœ… Dataset Generation Complete!")
        print("="*70)
        print(f" Train: {len(train_data):,} samples")
        print(f" Val: {len(val_data):,} samples")
        print(f" Output: {output_dir}")
        print("="*70 + "\n")

    def _scan_existing(self, output_dir, split_name):
        """Inspect a split directory that is about to be appended to.

        Returns:
            ``(start_idx, existing_counts)``: the first unused image index
            (one past the highest ``<split_name>_<n>.png`` found) and a
            Counter of how many labels already exist per text line. Both stay
            at their empty defaults when there is no ``images`` directory.
        """
        existing_counts = Counter()
        start_idx = 0

        images_dir = f"{output_dir}/images"
        if not os.path.exists(images_dir):
            return start_idx, existing_counts

        # 1. Determine start index from the numbered image files.
        prefix = f"{split_name}_"
        indices = []
        for fname in os.listdir(images_dir):
            if not (fname.endswith('.png') and fname.startswith(prefix)):
                continue
            try:
                indices.append(int(fname[len(prefix):-4]))
            except ValueError:
                continue

        if indices:
            start_idx = max(indices) + 1
            print(f" šŸ‘‰ Appending starting from index {start_idx}")

        # 2. Count existing samples per text line.
        labels_path = f"{output_dir}/labels.txt"
        if os.path.exists(labels_path):
            print(f" šŸ“– Analyzing existing labels to resume...")
            try:
                with open(labels_path, 'r', encoding='utf-8') as f:
                    for line in f:
                        parts = line.strip().split('\t', 1)
                        if len(parts) == 2:
                            existing_counts[parts[1]] += 1
                print(f" āœ“ Found {sum(existing_counts.values())} existing samples")
            except (OSError, UnicodeError) as e:
                print(f" āš ļø Could not read labels file: {e}")

        return start_idx, existing_counts

    def _plan_samples(self, lines, augment_factor, font_mode, retry_limit, append, existing_counts):
        """Build the (unshuffled) list of samples that still has to be rendered.

        Each sample is ``{'text': line, 'font': font_or_None}``. ``font`` is
        ``None`` in random mode, where the renderer picks the font itself.
        """
        samples = []
        skipped_count = 0

        if font_mode == 'all':
            print(f" Using ALL fonts mode (augment factor ignored for font selection)")
            fonts_list = self.font_manager.all_fonts
            print(f" Iterating {len(fonts_list)} fonts per line...")

            for line in lines:
                # Which fonts were already rendered is not tracked, so a line
                # is skipped only when it has at least one sample per font.
                if append and existing_counts[line] >= len(fonts_list):
                    skipped_count += len(fonts_list)
                    continue
                samples.extend({'text': line, 'font': entry[2]} for entry in fonts_list)
        else:
            print(f" Using RANDOM fonts mode (retry limit: {retry_limit} attempts per sample)")
            for line in lines:
                have = existing_counts[line] if append else 0
                remaining = max(0, augment_factor - have)
                if remaining < augment_factor:
                    skipped_count += augment_factor - remaining
                samples.extend({'text': line, 'font': None} for _ in range(remaining))

        if skipped_count > 0:
            print(f" ā­ļø Skipping {skipped_count} already generated samples")

        return samples

    def _generate_split(self, text_file, output_dir, augment_factor, split_name, font_mode='random', random_augment=False, retry_limit=10, append=False):
        """Generate one split (train/val) and return the samples written.

        Args:
            text_file: File with one text line per row; blank rows are ignored.
            output_dir: Split directory; its ``images`` sub-directory must
                already exist.
            augment_factor: Samples per line in random font mode. Values
                above 1 also switch augmentation on.
            split_name: Prefix for the image file names.
            font_mode: ``'all'`` or random (any other value).
            random_augment: Switch augmentation on regardless of
                ``augment_factor``.
            retry_limit: Random fonts tried per sample (random mode only).
            append: Continue an existing split instead of overwriting it.

        Returns:
            The sample dicts whose image and label were actually written.
        """
        with open(text_file, 'r', encoding='utf-8') as f:
            lines = [line.strip() for line in f if line.strip()]

        print(f" Loaded {len(lines)} lines from {text_file}")

        if append:
            start_idx, existing_counts = self._scan_existing(output_dir, split_name)
        else:
            start_idx, existing_counts = 0, Counter()

        samples = self._plan_samples(
            lines, augment_factor, font_mode, retry_limit, append, existing_counts
        )
        random.shuffle(samples)

        # Loop invariants: a fixed font gets exactly one attempt.
        current_retry_limit = 1 if font_mode == 'all' else retry_limit
        augment = (augment_factor > 1) or random_augment

        print(f" Generating {len(samples)} images...")

        generated = []
        file_mode = 'a' if append else 'w'
        # The context manager closes labels.txt even if generation is interrupted.
        with open(f"{output_dir}/labels.txt", file_mode, encoding='utf-8') as labels_file:
            for idx, sample in enumerate(tqdm(samples, desc=" Generating", unit="img")):
                text = sample['text']

                try:
                    img = self.renderer.render(
                        text,
                        augment=augment,
                        specific_font=sample['font'],
                        retry_limit=current_retry_limit
                    )

                    if img is None:
                        # No font could render this text; skip it silently.
                        continue

                    img_filename = f"{split_name}_{start_idx + idx:06d}.png"
                    img_path = f"{output_dir}/images/{img_filename}"

                    # imwrite reports failure through its return value; a
                    # label must not be written for an image that is missing.
                    if not cv2.imwrite(img_path, img):
                        raise OSError(f"could not write {img_path}")

                    labels_file.write(f"{img_filename}\t{text}\n")
                    generated.append(sample)

                except Exception as e:
                    # Best effort: one bad sample must not stop the whole run.
                    print(f" ⚠ Failed for '{text[:30]}...': {e}")

        print(f" āœ“ Generated {len(generated)} / {len(samples)} images\n")

        return generated
538
+
539
def generate_command(args):
    """Entry point of the dataset generation command.

    Reads the generation settings from ``args``, builds a DatasetGenerator
    and runs it.

    Returns:
        1 when the training file does not exist or Pillow is not installed,
        otherwise 0.
    """
    train_path = args.train_file

    # Guard clauses: bail out before any generator is built.
    if not os.path.exists(train_path):
        print(f"\nāŒ Error: Training file not found: {args.train_file}\n")
        return 1

    if not HAS_PIL:
        print("āŒ PIL/Pillow not found - install with: pip install Pillow")
        return 1

    # Optional arguments fall back to their defaults when absent.
    attempts = getattr(args, 'retry_limit', 10)

    dataset_builder = DatasetGenerator(
        language=args.language,
        image_height=args.height,
        image_width=args.width,
        fonts_dir=args.fonts_dir,
    )

    dataset_builder.generate_dataset(
        train_file=args.train_file,
        val_file=args.val_file,
        output_dir=args.output,
        train_augment=args.augment,
        val_augment=args.val_augment,
        font_mode=getattr(args, 'font_mode', 'random'),
        random_augment=getattr(args, 'random_augment', False),
        retry_limit=attempts,
    )
    return 0