kiri-ocr 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kiri_ocr/__init__.py +14 -0
- kiri_ocr/cli.py +244 -0
- kiri_ocr/core.py +306 -0
- kiri_ocr/detector.py +374 -0
- kiri_ocr/generator.py +570 -0
- kiri_ocr/model.py +159 -0
- kiri_ocr/renderer.py +193 -0
- kiri_ocr/training.py +508 -0
- kiri_ocr-0.1.0.data/scripts/kiri-ocr +6 -0
- kiri_ocr-0.1.0.dist-info/METADATA +218 -0
- kiri_ocr-0.1.0.dist-info/RECORD +16 -0
- kiri_ocr-0.1.0.dist-info/WHEEL +5 -0
- kiri_ocr-0.1.0.dist-info/licenses/LICENSE +201 -0
- kiri_ocr-0.1.0.dist-info/top_level.txt +2 -0
- models/__init__.py +1 -0
- models/model.kiri +0 -0
kiri_ocr/training.py
ADDED
|
@@ -0,0 +1,508 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn as nn
|
|
3
|
+
import torch.optim as optim
|
|
4
|
+
from torch.utils.data import Dataset, DataLoader
|
|
5
|
+
from PIL import Image
|
|
6
|
+
import numpy as np
|
|
7
|
+
import os
|
|
8
|
+
import matplotlib.pyplot as plt
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
from tqdm import tqdm
|
|
11
|
+
|
|
12
|
+
try:
|
|
13
|
+
from datasets import load_dataset
|
|
14
|
+
except ImportError:
|
|
15
|
+
load_dataset = None
|
|
16
|
+
|
|
17
|
+
from .model import LightweightOCR, CharacterSet, save_checkpoint
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
# ========== DATASET ==========
|
|
21
|
+
class OCRDataset(Dataset):
    """Line-image OCR dataset backed by a tab-separated labels file.

    Each non-blank line of ``labels_file`` must be ``<image file name>\\t<text>``.
    Images are read from an ``images`` directory located next to the labels
    file. Every label's characters are registered with ``charset`` while the
    file is read, so the charset already covers this dataset once the
    constructor returns.

    Args:
        labels_file: Path to the labels file (read as UTF-8).
        charset: Character-set object providing ``add_chars(text)`` and
            ``encode(text)``.
        img_height: Height in pixels every image is resized to; the width is
            scaled to keep the aspect ratio.
    """

    def __init__(self, labels_file, charset, img_height=32):
        self.img_height = img_height
        self.charset = charset
        self.samples = []

        labels_path = Path(labels_file)
        self.img_dir = labels_path.parent / "images"

        skipped = 0
        with open(labels_file, "r", encoding="utf-8") as f:
            for line in f:
                stripped = line.strip()
                if not stripped:
                    continue  # blank lines carry no sample
                parts = stripped.split("\t")
                if len(parts) != 2:
                    # Missing or extra tab: the line cannot be parsed. Count it so
                    # the user finds out that data was dropped.
                    skipped += 1
                    continue
                img_name, text = parts
                charset.add_chars(text)
                self.samples.append((img_name, text))

        print(f" Loaded {len(self.samples)} samples")
        if skipped:
            print(f" Skipped {skipped} malformed line(s) in {labels_file}")

    def __len__(self):
        """Return the number of parsed samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            tuple: ``(image, target, text)``, where ``image`` is a float tensor
            of shape ``(1, img_height, W)`` with values in ``[0, 1]``,
            ``target`` is a LongTensor from ``charset.encode(text)``, and
            ``text`` is the raw label string.
        """
        img_name, text = self.samples[idx]
        img_path = self.img_dir / img_name

        # Use a context manager so the file handle is released promptly.
        # This matters when many DataLoader workers open images in parallel.
        with Image.open(img_path) as src:
            img = src.convert("L")

        # Resize to the target height while keeping the aspect ratio. Clamp the
        # width to at least 1 px: very narrow or tall images would otherwise
        # round to width 0, which PIL refuses to resize to.
        w, h = img.size
        new_h = self.img_height
        new_w = max(1, int(w * new_h / h))
        img = img.resize((new_w, new_h), Image.LANCZOS)

        # Scale to [0, 1] and add a leading channel axis: (H, W) -> (1, H, W).
        img = np.array(img).astype(np.float32) / 255.0
        img = torch.FloatTensor(img).unsqueeze(0)

        target = torch.LongTensor(self.charset.encode(text))

        return img, target, text
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
class HFOCRDataset(Dataset):
    """OCR dataset wrapping an already-loaded Hugging Face ``datasets`` dataset.

    The constructor scans ``text_col`` once and registers every character with
    ``charset``. Items are then produced on demand from ``image_col`` and
    ``text_col``.

    Args:
        dataset: Indexable dataset exposing ``column_names``. Rows are
            dict-like, and the image column is expected to yield PIL-style
            images (``.mode`` / ``.convert`` / ``.size`` / ``.resize`` are used).
        charset: Character-set object providing ``add_chars(text)`` and
            ``encode(text)``.
        img_height: Height in pixels every image is resized to; the width is
            scaled to keep the aspect ratio.
        image_col: Name of the image column.
        text_col: Name of the label column.

    Raises:
        ImportError: if the ``datasets`` library is not installed.
        ValueError: if ``text_col`` is not a column of ``dataset``.
    """

    def __init__(
        self, dataset, charset, img_height=32, image_col="image", text_col="text"
    ):
        if load_dataset is None:
            raise ImportError("Please install 'datasets' library: pip install datasets")

        self.img_height = img_height
        self.charset = charset
        self.image_col = image_col
        self.text_col = text_col
        self.dataset = dataset

        # Build the charset from every label before training starts.
        print(f" š Scanning charset from '{text_col}' column...")
        if text_col not in self.dataset.column_names:
            raise ValueError(
                f"Column '{text_col}' not found in dataset. Available: {self.dataset.column_names}"
            )

        # Iterate only the text column so images are not decoded during the scan.
        # This is best effort: if select_columns is unavailable or fails, scan
        # the full dataset instead. Catch Exception, not everything, so that
        # Ctrl-C still interrupts.
        try:
            iter_ds = self.dataset.select_columns([text_col])
        except Exception:
            iter_ds = self.dataset

        for item in tqdm(iter_ds, desc="Scanning charset"):
            text = item.get(text_col, "")
            if text:
                charset.add_chars(text)

        # Drop the temporary view before training allocates memory.
        del iter_ds
        import gc

        gc.collect()

    def __len__(self):
        """Return the number of rows in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Load row ``idx``.

        Returns:
            tuple: ``(image, target, text)``, where ``image`` is a float tensor
            of shape ``(1, img_height, W)`` with values in ``[0, 1]``,
            ``target`` is a LongTensor from ``charset.encode(text)``, and
            ``text`` is the label string. A missing label is returned as ``""``.
        """
        item = self.dataset[idx]
        img = item[self.image_col]
        # Treat a missing (None) label as empty text. The charset scan in
        # __init__ also skips falsy labels, so this keeps encode() from
        # receiving None.
        text = item[self.text_col] or ""

        if img.mode != "L":
            img = img.convert("L")

        # Resize to the target height while keeping the aspect ratio. Clamp the
        # width to at least 1 px so very narrow images do not round to width 0,
        # which PIL refuses to resize to.
        w, h = img.size
        new_h = self.img_height
        new_w = max(1, int(w * new_h / h))
        img = img.resize((new_w, new_h), Image.LANCZOS)

        # Scale to [0, 1] and add a leading channel axis: (H, W) -> (1, H, W).
        img = np.array(img).astype(np.float32) / 255.0
        img = torch.FloatTensor(img).unsqueeze(0)

        target = torch.LongTensor(self.charset.encode(text))

        return img, target, text
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def collate_fn(batch):
    """Merge variable-width OCR samples into one zero-padded batch.

    Args:
        batch: Sequence of ``(image, target, text)`` tuples. Each image is a
            ``(1, H, W_i)`` tensor; all images share the same ``H``.

    Returns:
        tuple: ``(images, targets, target_lengths, texts)``:
            * ``images``: float tensor ``(B, 1, H, max W_i)``. Each image is
              placed at the left, and the remaining columns stay 0.
            * ``targets``: every target concatenated into one 1-D tensor, the
              flat layout that CTC loss accepts.
            * ``target_lengths``: LongTensor holding each target's length.
            * ``texts``: tuple of the raw label strings.
    """
    images, targets, texts = zip(*batch)

    widest = max(im.size(2) for im in images)
    canvas = torch.zeros(len(images), 1, images[0].size(1), widest)

    # Each `slot` is a view into `canvas`, so writing into it fills the batch.
    for slot, im in zip(canvas, images):
        slot[:, :, : im.size(2)] = im

    lengths = torch.tensor([len(t) for t in targets], dtype=torch.long)
    flat_targets = torch.cat(targets)

    return canvas, flat_targets, lengths, texts
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def train_model(
    model,
    train_loader,
    val_loader,
    charset,
    num_epochs=200,
    device="cuda",
    save_dir="models",
    lr=0.001,
    weight_decay=0.0001,
):
    """Train ``model`` with CTC loss, validating and checkpointing every epoch.

    Each epoch does the following:
      1. One optimisation pass over ``train_loader`` (AdamW, gradient-norm
         clipping at 5.0).
      2. Evaluation of the mean CTC loss and exact-match accuracy on
         ``val_loader``.
      3. A ReduceLROnPlateau step on the validation loss (factor 0.5,
         patience 10).

    Files written to ``save_dir``:
      * ``checkpoint_epoch_<N>.kiri`` after every epoch.
      * ``model.kiri`` whenever accuracy improves, with ties broken by a lower
        validation loss.
      * ``training_history.png`` with loss and accuracy curves. This is best
        effort: failures are only printed.

    Args:
        model: Network whose output is treated as (time, batch, classes)
            logits for CTC. Class 0 is the CTC blank.
        train_loader: Loader yielding ``(images, targets, target_lengths,
            texts)`` batches, as built by ``collate_fn``.
        val_loader: Loader with the same batch format, used for validation.
        charset: Character set whose ``decode`` maps a sequence of predicted
            per-step indices to text. Used for the accuracy metric.
        num_epochs: Number of training epochs.
        device: Torch device to train on.
        save_dir: Output directory; created if missing.
        lr: AdamW learning rate.
        weight_decay: AdamW weight decay.

    Returns:
        dict: Per-epoch history with keys ``"train_loss"``, ``"val_loss"``
        and ``"acc"``.

    Raises:
        ValueError: if either loader yields no batches. Failing here avoids a
            ZeroDivisionError after a full epoch of wasted training.
    """
    if len(train_loader) == 0:
        raise ValueError("train_loader yields no batches")
    if len(val_loader) == 0:
        raise ValueError("val_loader yields no batches")

    os.makedirs(save_dir, exist_ok=True)

    model = model.to(device)
    criterion = nn.CTCLoss(blank=0, zero_infinity=True)
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=10
    )

    best_loss = float("inf")
    best_acc = 0

    print("\n" + "=" * 70)
    print(" š Training Lightweight OCR")
    print("=" * 70)

    history = {"train_loss": [], "val_loss": [], "acc": []}

    for epoch in range(num_epochs):
        train_loss = _train_one_epoch(
            model, train_loader, criterion, optimizer, device, epoch, num_epochs
        )
        val_loss, accuracy = _evaluate(model, val_loader, criterion, charset, device)

        scheduler.step(val_loss)

        print(
            f"Epoch [{epoch+1:3d}/{num_epochs}] "
            f"Train: {train_loss:.4f} | "
            f"Val: {val_loss:.4f} | "
            f"Acc: {accuracy:5.2f}%"
        )

        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)
        history["acc"].append(accuracy)

        # Save a checkpoint every epoch.
        save_checkpoint(
            model,
            charset,
            optimizer,
            epoch + 1,
            val_loss,
            accuracy,
            f"{save_dir}/checkpoint_epoch_{epoch+1}.kiri",
        )

        # Keep the best model: higher accuracy wins; on equal accuracy, a lower
        # validation loss wins.
        if accuracy > best_acc or (accuracy == best_acc and val_loss < best_loss):
            best_loss = val_loss
            best_acc = accuracy

            save_checkpoint(
                model,
                charset,
                optimizer,
                epoch + 1,
                val_loss,
                accuracy,
                f"{save_dir}/model.kiri",
            )
            print(f" ā Saved Best! Acc: {accuracy:.2f}%")

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    _plot_history(history, save_dir)

    print(f"\nš Training complete! Best accuracy: {best_acc:.2f}%\n")
    return history


def _ctc_forward(model, criterion, images, targets, target_lengths, device):
    """Run ``model`` on one batch; return ``(raw outputs, CTC loss)``."""
    images = images.to(device)
    targets = targets.to(device)
    target_lengths = target_lengths.to(device)

    outputs = model(images)

    # Dim 0 of `outputs` is treated as time and dim 1 as batch. Every sample
    # uses the full output length as its CTC input length.
    input_lengths = torch.full(
        size=(outputs.size(1),),
        fill_value=outputs.size(0),
        dtype=torch.long,
        device=device,
    )

    log_probs = nn.functional.log_softmax(outputs, dim=2)
    loss = criterion(log_probs, targets, input_lengths, target_lengths)
    return outputs, loss


def _train_one_epoch(model, loader, criterion, optimizer, device, epoch, num_epochs):
    """Run one optimisation pass over ``loader``; return the mean batch loss."""
    model.train()
    total_loss = 0.0

    pbar = tqdm(loader, desc=f"Epoch {epoch+1}/{num_epochs}", unit="batch")
    for images, targets, target_lengths, _texts in pbar:
        optimizer.zero_grad()
        _, loss = _ctc_forward(model, criterion, images, targets, target_lengths, device)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        optimizer.step()

        loss_value = loss.item()
        total_loss += loss_value
        pbar.set_postfix({"loss": loss_value})

    return total_loss / len(loader)


def _evaluate(model, loader, criterion, charset, device):
    """Return ``(mean CTC loss, exact-match accuracy in %)`` over ``loader``."""
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for images, targets, target_lengths, texts in loader:
            outputs, loss = _ctc_forward(
                model, criterion, images, targets, target_lengths, device
            )
            total_loss += loss.item()

            # Greedy prediction: best class per time step, rearranged to
            # (batch, time). Copy to the CPU once per batch, not per sample.
            preds = outputs.argmax(2).transpose(0, 1).cpu().numpy()
            for pred_indices, text in zip(preds, texts):
                if charset.decode(pred_indices) == text:
                    correct += 1
                total += 1

    accuracy = correct / total * 100 if total else 0.0
    return total_loss / len(loader), accuracy


def _plot_history(history, save_dir):
    """Save loss/accuracy curves to ``<save_dir>/training_history.png``.

    This is best effort: any failure is printed and not raised. The figure is
    always closed so repeated calls do not accumulate open figures.
    """
    fig = None
    try:
        fig = plt.figure(figsize=(12, 5))

        plt.subplot(1, 2, 1)
        plt.plot(history["train_loss"], label="Train Loss")
        plt.plot(history["val_loss"], label="Val Loss")
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.title("Training Loss")
        plt.legend()
        plt.grid(True)

        plt.subplot(1, 2, 2)
        plt.plot(history["acc"], label="Validation Accuracy", color="green")
        plt.xlabel("Epoch")
        plt.ylabel("Accuracy (%)")
        plt.title("Training Accuracy")
        plt.legend()
        plt.grid(True)

        plt.tight_layout()
        plt.savefig(f"{save_dir}/training_history.png")
        print(f"\nš Training plot saved to {save_dir}/training_history.png")
    except Exception as e:
        print(f"\nā ļø Failed to save plot: {e}")
    finally:
        if fig is not None:
            plt.close(fig)
|
|
324
|
+
|
|
325
|
+
|
|
326
|
+
def train_command(args):
    """CLI entry point: build datasets and a model from ``args``, then train.

    Attributes read from ``args``:
      * Always: ``height``, ``batch_size``, ``epochs``, ``hidden_size``,
        ``device``, ``output_dir``, ``lr``, ``weight_decay``.
      * Optionally: ``from_model``, a checkpoint providing the initial charset
        and weights.
      * Hugging Face source: ``hf_dataset`` with ``hf_subset``,
        ``hf_train_split``, ``hf_val_split``, ``hf_val_percent``,
        ``hf_image_col`` and ``hf_text_col``.
      * Local source (when ``hf_dataset`` is not set): ``train_labels`` and
        ``val_labels``.

    Data-loading problems are printed and the function returns early. It does
    not raise for them.
    """
    IMAGE_HEIGHT = args.height
    BATCH_SIZE = args.batch_size
    NUM_EPOCHS = args.epochs
    HIDDEN_SIZE = args.hidden_size

    device = args.device
    # Accept indexed devices such as "cuda:1" as well as plain "cuda".
    use_cuda = str(device).startswith("cuda")
    if use_cuda and not torch.cuda.is_available():
        print("ā ļø CUDA not available, using CPU")
        device = "cpu"
        use_cuda = False

    print(f"\nš„ļø Device: {device}\n")

    # Read the optional pretrained checkpoint once. It supplies both the
    # starting charset (now) and the initial weights (after the model is built).
    from_model = getattr(args, "from_model", None)
    checkpoint, checkpoint_error = None, None
    if from_model and os.path.exists(from_model):
        checkpoint, checkpoint_error = _load_checkpoint(from_model)

    charset = CharacterSet()
    if checkpoint is not None:
        try:
            if "charset" in checkpoint:
                print(f"š Loading charset from {from_model}")
                charset = CharacterSet.from_checkpoint(checkpoint)
        except Exception as e:
            # Best effort: fall back to a charset built from the training data.
            print(f" Could not load charset from checkpoint, starting fresh: {e}")

    print("š Loading datasets...")

    if getattr(args, "hf_dataset", None):
        loaded = _load_hf_datasets(args, charset, IMAGE_HEIGHT)
    else:
        loaded = _load_local_datasets(args, charset, IMAGE_HEIGHT)
    if loaded is None:
        return
    train_dataset, val_dataset = loaded

    print(f"\nš Dataset: {len(train_dataset)} train, {len(val_dataset)} val")
    print(f"š Characters: {len(charset)}\n")

    # Free memory used by the dataset scans before training.
    import gc

    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    os.makedirs(args.output_dir, exist_ok=True)

    num_workers = 4 if use_cuda else 0
    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=num_workers,
        collate_fn=collate_fn,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=collate_fn,
    )

    model = LightweightOCR(num_chars=len(charset), hidden_size=HIDDEN_SIZE)

    if from_model:
        print(f"š Loading pretrained model from {from_model}")
        if not os.path.exists(from_model):
            print(f" ā Pretrained model not found: {from_model}")
        elif checkpoint is None:
            print(f" ā Error loading model: {checkpoint_error}")
        else:
            try:
                _apply_pretrained_weights(model, checkpoint)
            except Exception as e:
                print(f" ā Error loading model: {e}")

    total_params = sum(p.numel() for p in model.parameters())
    # Size estimate assumes 4 bytes (float32) per parameter.
    model_size_mb = total_params * 4 / 1024 / 1024

    print(f"šļø Model: {total_params:,} params ({model_size_mb:.2f} MB)\n")

    train_model(
        model,
        train_loader,
        val_loader,
        charset,
        num_epochs=NUM_EPOCHS,
        device=device,
        save_dir=args.output_dir,
        lr=args.lr,
        weight_decay=args.weight_decay,
    )


def _load_checkpoint(path):
    """Read a checkpoint onto the CPU; return ``(checkpoint, None)`` or ``(None, error)``."""
    try:
        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code. Only pass trusted checkpoint files via from_model.
        return torch.load(path, map_location="cpu"), None
    except Exception as e:
        return None, e


def _load_hf_datasets(args, charset, image_height):
    """Build ``(train, val)`` HFOCRDataset objects from ``args.hf_*``.

    The validation split comes from, in order of preference:
      1. ``args.hf_val_split``, or the first loadable of "val", "test" and
         "validation".
      2. A random ``hf_val_percent`` hold-out carved from the train split.
      3. The train split itself.

    Returns ``None`` (after printing why) if the train split cannot be loaded.
    """
    print(f" ā¬ļø Loading HF dataset: {args.hf_dataset}")
    subset = getattr(args, "hf_subset", None)

    if load_dataset is None:
        print("Error loading dataset: the 'datasets' library is not installed (pip install datasets)")
        return None

    try:
        ds = load_dataset(args.hf_dataset, subset, split=args.hf_train_split)
    except Exception as e:
        print(f"ā Error loading dataset: {e}")
        return None

    val_ds = None
    val_splits = (
        [args.hf_val_split] if args.hf_val_split else ["val", "test", "validation"]
    )
    for split in val_splits:
        try:
            val_ds = load_dataset(args.hf_dataset, subset, split=split)
        except Exception:
            # Split missing or not loadable; try the next candidate name.
            # Catch Exception, not everything, so that Ctrl-C still interrupts.
            continue
        print(f" ā Found validation split: '{split}'")
        break

    if val_ds is None:
        print(
            f" ā ļø No validation split found. Splitting {args.hf_val_percent*100}% from train..."
        )
        try:
            split_ds = ds.train_test_split(test_size=args.hf_val_percent)
            ds = split_ds["train"]
            val_ds = split_ds["test"]
        except Exception as e:
            print(f" ā ļø Could not split dataset: {e}. Using train for validation.")
            val_ds = ds

    train_dataset = HFOCRDataset(
        ds,
        charset,
        img_height=image_height,
        image_col=args.hf_image_col,
        text_col=args.hf_text_col,
    )
    val_dataset = HFOCRDataset(
        val_ds,
        charset,
        img_height=image_height,
        image_col=args.hf_image_col,
        text_col=args.hf_text_col,
    )
    return train_dataset, val_dataset


def _load_local_datasets(args, charset, image_height):
    """Build ``(train, val)`` OCRDataset objects from local label files.

    Falls back to using the training set for validation when ``val_labels`` is
    missing. Returns ``None`` (after printing why) if ``train_labels`` does not
    exist.
    """
    if not os.path.exists(args.train_labels):
        print(f"ā Training labels not found: {args.train_labels}")
        return None

    train_dataset = OCRDataset(args.train_labels, charset, img_height=image_height)

    if args.val_labels and os.path.exists(args.val_labels):
        val_dataset = OCRDataset(args.val_labels, charset, img_height=image_height)
    else:
        print(
            "ā ļø Validation labels not found. Using training set for validation (not recommended)."
        )
        val_dataset = train_dataset

    return train_dataset, val_dataset


def _apply_pretrained_weights(model, checkpoint):
    """Copy shape-compatible tensors from ``checkpoint["model_state_dict"]`` into ``model``.

    Layers whose name or shape does not match are skipped, so fine-tuning with
    a grown charset still reuses every compatible layer.
    """
    if "model_state_dict" not in checkpoint:
        print(" ā ļø Invalid checkpoint format (missing model_state_dict)")
        return

    state_dict = checkpoint["model_state_dict"]
    model_dict = model.state_dict()

    # Keep only entries the current model has, with identical shapes.
    pretrained_dict = {
        k: v
        for k, v in state_dict.items()
        if k in model_dict and v.shape == model_dict[k].shape
    }

    skipped_keys = [k for k in state_dict if k not in pretrained_dict]
    if skipped_keys:
        print(
            f" ā ļø Skipped mismatched layers (fine-tuning): {skipped_keys[:5]} ..."
        )

    # Overwrite the matching entries in the model's own state dict, then load it.
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)
    print(" ā Weights loaded successfully")
|