kiri-ocr 0.1.0__py3-none-any.whl

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
kiri_ocr/training.py ADDED
@@ -0,0 +1,508 @@
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ from torch.utils.data import Dataset, DataLoader
+ from PIL import Image
+ import numpy as np
+ import os
+ import matplotlib.pyplot as plt
+ from pathlib import Path
+ from tqdm import tqdm
+
+ try:
+     from datasets import load_dataset
+ except ImportError:
+     load_dataset = None
+
+ from .model import LightweightOCR, CharacterSet, save_checkpoint
+
+
+ # ========== DATASET ==========
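+ # OCRDataset reads a tab-separated labels file ("<image_name>\t<text>" per line)
+ # and loads the referenced images from an "images/" directory next to that file.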
+ class OCRDataset(Dataset):
+     def __init__(self, labels_file, charset, img_height=32):
+         self.img_height = img_height
+         self.charset = charset
+         self.samples = []
+
+         labels_path = Path(labels_file)
+         self.img_dir = labels_path.parent / "images"
+
+         with open(labels_file, "r", encoding="utf-8") as f:
+             for line in f:
+                 parts = line.strip().split("\t")
+                 if len(parts) == 2:
+                     img_name, text = parts
+                     charset.add_chars(text)
+                     self.samples.append((img_name, text))
+
+         print(f" Loaded {len(self.samples)} samples")
+
+     def __len__(self):
+         return len(self.samples)
+
+     def __getitem__(self, idx):
+         img_name, text = self.samples[idx]
+         img_path = self.img_dir / img_name
+
+         img = Image.open(img_path).convert("L")
+
+         # Resize maintaining aspect ratio
+         w, h = img.size
+         new_h = self.img_height
+         new_w = int(w * new_h / h)
+         img = img.resize((new_w, new_h), Image.LANCZOS)
+
+         # Normalize
+         img = np.array(img).astype(np.float32) / 255.0
+         img = torch.FloatTensor(img).unsqueeze(0)
+
+         target = self.charset.encode(text)
+         target = torch.LongTensor(target)
+
+         return img, target, text
+
+
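+ # Wrapper around a Hugging Face `datasets` split: scans the text column once to
+ # build the charset, then decodes and preprocesses images lazily per item.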
+ class HFOCRDataset(Dataset):
+     def __init__(
+         self, dataset, charset, img_height=32, image_col="image", text_col="text"
+     ):
+         if load_dataset is None:
+             raise ImportError("Please install 'datasets' library: pip install datasets")
+
+         self.img_height = img_height
+         self.charset = charset
+         self.image_col = image_col
+         self.text_col = text_col
+         self.dataset = dataset
+
+         # Iterate to build charset
+         print(f" šŸ”„ Scanning charset from '{text_col}' column...")
+         if text_col not in self.dataset.column_names:
+             raise ValueError(
+                 f"Column '{text_col}' not found in dataset. Available: {self.dataset.column_names}"
+             )
+
+         # Optimize: select only text column to avoid decoding images during scan
+         try:
+             iter_ds = self.dataset.select_columns([text_col])
+         except Exception:
+             iter_ds = self.dataset
+
+         for item in tqdm(iter_ds, desc="Scanning charset"):
+             text = item.get(text_col, "")
+             if text:
+                 charset.add_chars(text)
+
+         # Clear memory
+         del iter_ds
+         import gc
+         gc.collect()
+
+     def __len__(self):
+         return len(self.dataset)
+
+     def __getitem__(self, idx):
+         item = self.dataset[idx]
+         img = item[self.image_col]
+         text = item[self.text_col]
+
+         if img.mode != "L":
+             img = img.convert("L")
+
+         # Resize maintaining aspect ratio
+         w, h = img.size
+         new_h = self.img_height
+         new_w = int(w * new_h / h)
+         img = img.resize((new_w, new_h), Image.LANCZOS)
+
+         # Normalize
+         img = np.array(img).astype(np.float32) / 255.0
+         img = torch.FloatTensor(img).unsqueeze(0)
+
+         target = self.charset.encode(text)
+         target = torch.LongTensor(target)
+
+         return img, target, text
+
+
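+ # Collate: right-pad every line image to the widest image in the batch and
+ # concatenate the label sequences, since nn.CTCLoss accepts flattened targets
+ # together with per-sample target lengths.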
+ def collate_fn(batch):
+     images, targets, texts = zip(*batch)
+
+     max_width = max([img.size(2) for img in images])
+     batch_size = len(images)
+     height = images[0].size(1)
+
+     padded_images = torch.zeros(batch_size, 1, height, max_width)
+
+     for i, img in enumerate(images):
+         w = img.size(2)
+         padded_images[i, :, :, :w] = img
+
+     target_lengths = torch.LongTensor([len(t) for t in targets])
+     targets_concat = torch.cat(targets)
+
+     return padded_images, targets_concat, target_lengths, texts
+
+
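+ # CTC training loop: AdamW with ReduceLROnPlateau, a checkpoint saved every
+ # epoch, and the best model (by exact-match line accuracy) kept as model.kiri.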
+ def train_model(
+     model,
+     train_loader,
+     val_loader,
+     charset,
+     num_epochs=200,
+     device="cuda",
+     save_dir="models",
+     lr=0.001,
+     weight_decay=0.0001,
+ ):
+     os.makedirs(save_dir, exist_ok=True)
+
+     model = model.to(device)
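+     # blank=0 reserves index 0 for the CTC blank token (CharacterSet is assumed
+     # to map real characters to indices >= 1)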
+     criterion = nn.CTCLoss(blank=0, zero_infinity=True)
+     optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
+     scheduler = optim.lr_scheduler.ReduceLROnPlateau(
+         optimizer, mode="min", factor=0.5, patience=10
+     )
+
+     best_loss = float("inf")
+     best_acc = 0
+
+     print("\n" + "=" * 70)
+     print(" šŸš€ Training Lightweight OCR")
+     print("=" * 70)
+
+     history = {"train_loss": [], "val_loss": [], "acc": []}
+
+     for epoch in range(num_epochs):
+         # Training
+         model.train()
+         train_loss = 0
+
+         # Progress bar for training
+         pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch")
+
+         for batch_idx, (images, targets, target_lengths, texts) in enumerate(pbar):
+             images = images.to(device)
+             targets = targets.to(device)
+             target_lengths = target_lengths.to(device)
+
+             optimizer.zero_grad()
+             outputs = model(images)
+
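+             # Model output is time-major (T, N, C); every sample uses the full
+             # output length T as its CTC input length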
+             input_lengths = torch.full(
+                 size=(outputs.size(1),),
+                 fill_value=outputs.size(0),
+                 dtype=torch.long,
+                 device=device,
+             )
+
+             log_probs = nn.functional.log_softmax(outputs, dim=2)
+             loss = criterion(log_probs, targets, input_lengths, target_lengths)
+
+             loss.backward()
+             torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
+             optimizer.step()
+
+             train_loss += loss.item()
+
+             # Update progress bar
+             pbar.set_postfix({"loss": loss.item()})
+
+         train_loss /= len(train_loader)
+
+         # Validation
+         model.eval()
+         val_loss = 0
+         correct = 0
+         total = 0
+
+         with torch.no_grad():
+             for images, targets, target_lengths, texts in val_loader:
+                 images = images.to(device)
+                 targets = targets.to(device)
+                 target_lengths = target_lengths.to(device)
+
+                 outputs = model(images)
+
+                 input_lengths = torch.full(
+                     size=(outputs.size(1),),
+                     fill_value=outputs.size(0),
+                     dtype=torch.long,
+                     device=device,
+                 )
+
+                 log_probs = nn.functional.log_softmax(outputs, dim=2)
+                 loss = criterion(log_probs, targets, input_lengths, target_lengths)
+                 val_loss += loss.item()
+
+                 # Accuracy
+                 _, preds = outputs.max(2)
+                 preds = preds.transpose(0, 1)
+
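+                 # Greedy (best-path) decode per sample; charset.decode is assumed
+                 # to collapse repeats and drop blank tokens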
+                 for i, text in enumerate(texts):
+                     pred_indices = preds[i].cpu().numpy()
+                     pred_text = charset.decode(pred_indices)
+                     if pred_text == text:
+                         correct += 1
+                     total += 1
+
+         val_loss /= len(val_loader)
+         accuracy = correct / total * 100
+
+         scheduler.step(val_loss)
+
+         print(
+             f"Epoch [{epoch+1:3d}/{num_epochs}] "
+             f"Train: {train_loss:.4f} | "
+             f"Val: {val_loss:.4f} | "
+             f"Acc: {accuracy:5.2f}%"
+         )
+
+         # Store history
+         history["train_loss"].append(train_loss)
+         history["val_loss"].append(val_loss)
+         history["acc"].append(accuracy)
+
+         # Save checkpoint every epoch
+         save_checkpoint(
+             model,
+             charset,
+             optimizer,
+             epoch + 1,
+             val_loss,
+             accuracy,
+             f"{save_dir}/checkpoint_epoch_{epoch+1}.kiri",
+         )
+
+         # Save best model
+         if accuracy > best_acc or (accuracy == best_acc and val_loss < best_loss):
+             best_loss = val_loss
+             best_acc = accuracy
+
+             save_checkpoint(
+                 model,
+                 charset,
+                 optimizer,
+                 epoch + 1,
+                 val_loss,
+                 accuracy,
+                 f"{save_dir}/model.kiri",
+             )
+             print(f" āœ“ Saved Best! Acc: {accuracy:.2f}%")
+
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+
+     # Plot history
+     try:
+         plt.figure(figsize=(12, 5))
+
+         plt.subplot(1, 2, 1)
+         plt.plot(history["train_loss"], label="Train Loss")
+         plt.plot(history["val_loss"], label="Val Loss")
+         plt.xlabel("Epoch")
+         plt.ylabel("Loss")
+         plt.title("Training Loss")
+         plt.legend()
+         plt.grid(True)
+
+         plt.subplot(1, 2, 2)
+         plt.plot(history["acc"], label="Validation Accuracy", color="green")
+         plt.xlabel("Epoch")
+         plt.ylabel("Accuracy (%)")
+         plt.title("Training Accuracy")
+         plt.legend()
+         plt.grid(True)
+
+         plt.tight_layout()
+         plt.savefig(f"{save_dir}/training_history.png")
+         print(f"\nšŸ“Š Training plot saved to {save_dir}/training_history.png")
+     except Exception as e:
+         print(f"\nāš ļø Failed to save plot: {e}")
+
+     print(f"\nšŸŽ‰ Training complete! Best accuracy: {best_acc:.2f}%\n")
+
+
+ def train_command(args):
+     IMAGE_HEIGHT = args.height
+     BATCH_SIZE = args.batch_size
+     NUM_EPOCHS = args.epochs
+     HIDDEN_SIZE = args.hidden_size
+
+     device = args.device
+     if device == "cuda" and not torch.cuda.is_available():
+         print("āš ļø CUDA not available, using CPU")
+         device = "cpu"
+
+     print(f"\nšŸ–„ļø Device: {device}\n")
+
+     # Load charset from pretrained model if available
+     charset = CharacterSet()
+     if (
+         hasattr(args, "from_model")
+         and args.from_model
+         and os.path.exists(args.from_model)
+     ):
+         try:
+             checkpoint = torch.load(args.from_model, map_location="cpu")
+             if "charset" in checkpoint:
+                 print(f"šŸ”„ Loading charset from {args.from_model}")
+                 charset = CharacterSet.from_checkpoint(checkpoint)
+         except Exception:
+             pass  # Will be handled later or just start fresh
+
+     print("šŸ“‚ Loading datasets...")
+
+     train_dataset = None
+     val_dataset = None
+
+     if hasattr(args, "hf_dataset") and args.hf_dataset:
+         print(f" ā¬‡ļø Loading HF dataset: {args.hf_dataset}")
+         subset = getattr(args, "hf_subset", None)
+
+         # Load train split
+         try:
+             ds = load_dataset(args.hf_dataset, subset, split=args.hf_train_split)
+         except Exception as e:
+             print(f"āŒ Error loading dataset: {e}")
+             return
+
+         val_ds = None
+         # Try finding val split
+         val_splits = (
+             [args.hf_val_split] if args.hf_val_split else ["val", "test", "validation"]
+         )
+
+         for split in val_splits:
+             if not split:
+                 continue
+             try:
+                 # We check if we can load it
+                 candidate = load_dataset(args.hf_dataset, subset, split=split)
+                 val_ds = candidate
+                 print(f" āœ“ Found validation split: '{split}'")
+                 break
+             except Exception:
+                 pass
+
+         if val_ds is None:
+             print(
+                 f" āš ļø No validation split found. Splitting {args.hf_val_percent*100}% from train..."
+             )
+             try:
+                 split_ds = ds.train_test_split(test_size=args.hf_val_percent)
+                 ds = split_ds["train"]
+                 val_ds = split_ds["test"]
+             except Exception as e:
+                 print(f" āš ļø Could not split dataset: {e}. Using train for validation.")
+                 val_ds = ds
+
+         train_dataset = HFOCRDataset(
+             ds,
+             charset,
+             img_height=IMAGE_HEIGHT,
+             image_col=args.hf_image_col,
+             text_col=args.hf_text_col,
+         )
+         val_dataset = HFOCRDataset(
+             val_ds,
+             charset,
+             img_height=IMAGE_HEIGHT,
+             image_col=args.hf_image_col,
+             text_col=args.hf_text_col,
+         )
+     else:
+         if not os.path.exists(args.train_labels):
+             print(f"āŒ Training labels not found: {args.train_labels}")
+             return
+
+         train_dataset = OCRDataset(args.train_labels, charset, img_height=IMAGE_HEIGHT)
+
+         if args.val_labels and os.path.exists(args.val_labels):
+             val_dataset = OCRDataset(args.val_labels, charset, img_height=IMAGE_HEIGHT)
+         else:
+             print(
+                 "āš ļø Validation labels not found. Using training set for validation (not recommended)."
+             )
+             val_dataset = train_dataset
+
+     print(f"\nšŸ“Š Dataset: {len(train_dataset)} train, {len(val_dataset)} val")
+     print(f"šŸ“ Characters: {len(charset)}\n")
+
+     # Clear memory before training
+     import gc
+     gc.collect()
+     if torch.cuda.is_available():
+         torch.cuda.empty_cache()
+
+     os.makedirs(args.output_dir, exist_ok=True)
+     # charset.save(f'{args.output_dir}/charset_lite.txt')  # Not needed if using .kiri
+
+     train_loader = DataLoader(
+         train_dataset,
+         batch_size=BATCH_SIZE,
+         shuffle=True,
+         num_workers=4 if device == "cuda" else 0,
+         collate_fn=collate_fn,
+     )
+     val_loader = DataLoader(
+         val_dataset,
+         batch_size=BATCH_SIZE,
+         shuffle=False,
+         num_workers=4 if device == "cuda" else 0,
+         collate_fn=collate_fn,
+     )
+
+     model = LightweightOCR(num_chars=len(charset), hidden_size=HIDDEN_SIZE)
+
+     # Load pretrained model if specified
+     if hasattr(args, "from_model") and args.from_model:
+         print(f"šŸ”„ Loading pretrained model from {args.from_model}")
+         if os.path.exists(args.from_model):
+             try:
+                 checkpoint = torch.load(args.from_model, map_location=device)
+                 if "model_state_dict" in checkpoint:
+                     state_dict = checkpoint["model_state_dict"]
+                     model_dict = model.state_dict()
+
+                     # Filter out unnecessary keys or shape mismatches
+                     pretrained_dict = {
+                         k: v
+                         for k, v in state_dict.items()
+                         if k in model_dict and v.shape == model_dict[k].shape
+                     }
+
+                     # Log what was skipped
+                     skipped_keys = [k for k in state_dict if k not in pretrained_dict]
+                     if skipped_keys:
+                         print(
+                             f" āš ļø Skipped mismatched layers (fine-tuning): {skipped_keys[:5]} ..."
+                         )
+
+                     # Overwrite entries in the existing state dict
+                     model_dict.update(pretrained_dict)
+                     model.load_state_dict(model_dict)
+                     print(" āœ“ Weights loaded successfully")
+                 else:
+                     print(" āš ļø Invalid checkpoint format (missing model_state_dict)")
+             except Exception as e:
+                 print(f" āŒ Error loading model: {e}")
+         else:
+             print(f" āŒ Pretrained model not found: {args.from_model}")
+
+     total_params = sum(p.numel() for p in model.parameters())
+     model_size_mb = total_params * 4 / 1024 / 1024
+
+     print(f"šŸ—ļø Model: {total_params:,} params ({model_size_mb:.2f} MB)\n")
+
+     train_model(
+         model,
+         train_loader,
+         val_loader,
+         charset,
+         num_epochs=NUM_EPOCHS,
+         device=device,
+         save_dir=args.output_dir,
+         lr=args.lr,
+         weight_decay=args.weight_decay,
+     )
@@ -0,0 +1,6 @@
+ #!python
+ import sys
+ from kiri_ocr.cli import main
+
+ if __name__ == '__main__':
+     sys.exit(main())
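For orientation, the training code above can also be driven directly from Python instead of through the CLI wrapper. The sketch below is illustrative only: the labels-file paths, batch size, epoch count, and hidden_size=256 are assumed placeholder values, while CharacterSet, LightweightOCR, OCRDataset, collate_fn, and train_model are the names shipped in this diff.

    from torch.utils.data import DataLoader
    from kiri_ocr.model import CharacterSet, LightweightOCR
    from kiri_ocr.training import OCRDataset, collate_fn, train_model

    charset = CharacterSet()
    # Labels files use the tab-separated "<image_name>\t<text>" format read by OCRDataset,
    # with images in an "images/" directory next to each labels file (paths are hypothetical).
    train_ds = OCRDataset("data/train/labels.txt", charset, img_height=32)
    val_ds = OCRDataset("data/val/labels.txt", charset, img_height=32)

    train_loader = DataLoader(train_ds, batch_size=32, shuffle=True, collate_fn=collate_fn)
    val_loader = DataLoader(val_ds, batch_size=32, shuffle=False, collate_fn=collate_fn)

    # hidden_size=256 is an assumed value; the CLI takes it from args.hidden_size
    model = LightweightOCR(num_chars=len(charset), hidden_size=256)
    train_model(model, train_loader, val_loader, charset, num_epochs=50, device="cpu", save_dir="models")

Building both datasets before constructing the model matters, because the datasets populate the shared charset as they load, and the model's output size is taken from len(charset).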