rbx-proofreader 1.0.1__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
proofreader/core/config.py CHANGED
@@ -5,13 +5,17 @@ from pathlib import Path
  # --- BASE PATHS ---
  # Resolves to the 'proofreader' root directory
  BASE_DIR = Path(__file__).resolve().parent.parent.parent
+ BASE_URL = "https://github.com/lucacrose/proofreader"

  # --- ASSETS & MODELS ---
  ASSETS_PATH = BASE_DIR / "assets"
  MODEL_PATH = ASSETS_PATH / "weights" / "yolo.pt"
- DB_PATH = ASSETS_PATH / "db.json"
- CACHE_PATH = ASSETS_PATH / "embedding_bank.pt"
+ DB_PATH = ASSETS_PATH / "item_database.json"
+ CACHE_PATH = ASSETS_PATH / "item_embeddings_bank.pt"
  THUMBNAILS_DIR = ASSETS_PATH / "thumbnails"
+ TRAIN_THUMBNAILS_DIR = ASSETS_PATH / "train_data"
+ CLASS_MAP_PATH = ASSETS_PATH / "class_mapping.json"
+ CLIP_BEST_PATH = ASSETS_PATH / "weights" / "clip.pt"

  # --- TRAINING & EMULATOR ---
  TRAIN_DIR = BASE_DIR / "proofreader" / "train"
@@ -26,11 +30,11 @@ DEFAULT_TEMPLATE = TEMPLATES_DIR / "trade_ui.html"

  # --- HYPERPARAMETERS (Training Settings) ---
  TRAINING_CONFIG = {
-     "epochs": 100, # Number of times the model sees the whole dataset
+     "epochs": 240, # Number of times the model sees the whole dataset
      "batch_size": 16, # Number of images processed at once
      "img_size": 640, # Standard YOLO resolution
-     "patience": 10, # Stop early if no improvement for 10 epochs
-     "close_mosaic_epochs": 10 # Disable mosaic augmentation for the last N epochs
+     "patience": 20, # Stop early if no improvement for 20 epochs
+     "close_mosaic_epochs": 32 # Disable mosaic augmentation for the last N epochs
  }

  # --- AUGMENTER PROBABILITIES AND GENERATOR SETTINGS ---
@@ -82,7 +86,7 @@ AUGMENTER_CONFIG = {

  # Robustness Thresholds
  FUZZY_MATCH_CONFIDENCE_THRESHOLD = 60.0
- VISUAL_MATCH_THRESHOLD = 0.88
+ CERTAIN_VISUAL_CONF = 0.995

  # --- HARDWARE SETTINGS ---
  # Automatically detects if a GPU is available for faster training
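The training keys above line up with Ultralytics YOLO's train() arguments (epochs, batch, imgsz, patience, close_mosaic). The diff does not include the training entry point itself, so the following is only a minimal sketch of how these values would typically be passed through; the base checkpoint and dataset YAML names are assumptions, not the package's actual code:

    from ultralytics import YOLO
    from proofreader.core.config import TRAINING_CONFIG, DEVICE

    model = YOLO("yolov8n.pt")  # hypothetical base checkpoint
    model.train(
        data="dataset.yaml",  # hypothetical dataset config
        epochs=TRAINING_CONFIG["epochs"],               # 240 as of v1.1.1
        batch=TRAINING_CONFIG["batch_size"],
        imgsz=TRAINING_CONFIG["img_size"],
        patience=TRAINING_CONFIG["patience"],
        close_mosaic=TRAINING_CONFIG["close_mosaic_epochs"],
        device=DEVICE,
    )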
proofreader/core/matcher.py CHANGED
@@ -1,34 +1,56 @@
  import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
  import numpy as np
+ import json
  import cv2
  from PIL import Image
- from typing import Dict, List, Any
- from .schema import TradeLayout
- from proofreader.core.config import VISUAL_MATCH_THRESHOLD
+ from torchvision import transforms
+ from transformers import CLIPVisionModelWithProjection
+ from typing import List
+ from .schema import TradeLayout, ResolvedItem
+
+ class CLIPItemEmbedder(nn.Module):
+     def __init__(self, num_classes, model_id="openai/clip-vit-base-patch32"):
+         super().__init__()
+         self.vision_encoder = CLIPVisionModelWithProjection.from_pretrained(model_id)
+         self.item_prototypes = nn.Embedding(num_classes, 512)
+         self.logit_scale = nn.Parameter(torch.ones([]) * 2.659)
+
+     def forward(self, pixel_values):
+         outputs = self.vision_encoder(pixel_values=pixel_values)
+         return F.normalize(outputs.image_embeds, p=2, dim=-1)

  class VisualMatcher:
-     def __init__(self, embedding_bank: Dict[str, np.ndarray], item_db: List[dict], clip_processor: Any, clip_model: Any, device: str = "cuda"):
+     def __init__(self, weights_path: str, mapping_path: str, item_db: List[dict], device: str = "cuda"):
          self.device = device
-         self.bank = embedding_bank
-         self.item_db = item_db
-         self.clip_processor = clip_processor
-         self.clip_model = clip_model

-         self.name_to_id = {str(i["name"]).lower().strip(): i["id"] for i in item_db}
+         with open(mapping_path, "r") as f:
+             self.class_to_idx = json.load(f)
+         self.idx_to_class = {v: k for k, v in self.class_to_idx.items()}
+
          self.id_to_name = {str(i["id"]): i["name"] for i in item_db}
+         self.name_to_id = {str(i["name"]).lower().strip(): i["id"] for i in item_db}

-         self.bank_names = list(embedding_bank.keys())
-         self.bank_tensor = torch.stack([embedding_bank[name] for name in self.bank_names]).to(self.device)
-         self.bank_tensor = torch.nn.functional.normalize(self.bank_tensor, dim=1)
+         num_classes = len(self.class_to_idx)
+         self.model = CLIPItemEmbedder(num_classes).to(self.device)
+         self.model.load_state_dict(torch.load(weights_path, map_location=self.device))
+         self.model.eval()

-     def _get_id_from_name(self, name: str) -> str:
-         item = next((i for i in self.item_db if i["name"] == name), None)
-         return item["id"] if item else 0
+         with torch.inference_mode():
+             self.bank_tensor = F.normalize(self.model.item_prototypes.weight, p=2, dim=-1)
+
+         self.preprocess = transforms.Compose([
+             transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC),
+             transforms.ToTensor(),
+             transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
+                                  (0.26862954, 0.26130258, 0.27577711)),
+         ])

-     def match_item_visuals(self, image: np.ndarray, layout: TradeLayout, similarity_threshold: float = VISUAL_MATCH_THRESHOLD):
-         items_to_process = []
+     def match_item_visuals(self, image: np.ndarray, layout: TradeLayout):
+         items_to_process: List[ResolvedItem] = []
          crops = []
-
+
          for side in (layout.outgoing.items, layout.incoming.items):
              for item in side:
                  if item.thumb_box:
@@ -36,32 +58,32 @@ class VisualMatcher:
                      crop = image[y1:y2, x1:x2]
                      if crop.size > 0:
                          pil_img = Image.fromarray(cv2.cvtColor(crop, cv2.COLOR_BGR2RGB))
-                         crops.append(pil_img)
+                         processed_crop = self.preprocess(pil_img)
+                         crops.append(processed_crop)
                          items_to_process.append(item)

          if not crops:
              return

-         inputs = self.clip_processor(images=crops, return_tensors="pt", padding=True).to(self.device)
+         batch_tensor = torch.stack(crops).to(self.device)

-         with torch.no_grad():
-             query_features = self.clip_model.get_image_features(**inputs)
-             query_features = torch.nn.functional.normalize(query_features, dim=1)
-             similarities = torch.matmul(query_features, self.bank_tensor.T)
-             best_scores, best_indices = torch.max(similarities, dim=1)
+         with torch.inference_mode():
+             query_features = self.model(batch_tensor)
+
+         logits = query_features @ self.bank_tensor.t() * self.model.logit_scale.exp()
+         topk_scores, topk_indices = logits.topk(k=5, dim=1)
+
+         probs = F.softmax(topk_scores.float(), dim=1)
+
+         best_idx_in_topk = probs.argmax(dim=1)
+         best_indices = topk_indices[torch.arange(len(topk_indices)), best_idx_in_topk]
+         best_probs = probs[torch.arange(len(probs)), best_idx_in_topk]
+

          for i, item in enumerate(items_to_process):
-             visual_match_val = self.bank_names[best_indices[i]]
-             visual_conf = best_scores[i].item()
+             visual_idx = best_indices[i].item()

-             is_ocr_valid = item.name.lower().strip() in self.name_to_id if item.name else False
+             visual_match_id_str = self.idx_to_class[visual_idx]

-             if (not is_ocr_valid or visual_conf > 0.95) and visual_conf >= similarity_threshold:
-                 if str(visual_match_val).isdigit():
-                     item.id = int(visual_match_val)
-                     item.name = self.id_to_name.get(str(visual_match_val), "Unknown Item")
-                 else:
-                     item.name = visual_match_val
-                     item.id = self._get_id_from_name(visual_match_val)
-             else:
-                 item.id = self._get_id_from_name(item.name)
+             item.visual_id = int(visual_match_id_str)
+             item.visual_conf = float(best_probs[i].item())
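A note on the new confidence scale: visual_conf is no longer a raw cosine similarity compared against VISUAL_MATCH_THRESHOLD; it is a softmax over the five highest prototype logits (cosine similarity scaled by logit_scale.exp()), i.e. a relative probability among the top-5 candidates. An illustrative calculation, assuming the scale stays near its initialization of exp(2.659), roughly 14.3 (the trained value may differ):

    import torch
    import torch.nn.functional as F

    # Hypothetical top-5 cosine similarities between one crop embedding and the class prototypes.
    sims = torch.tensor([[0.92, 0.71, 0.65, 0.60, 0.58]])

    scale = torch.tensor(2.659).exp()        # ~14.3, the CLIPItemEmbedder initialization
    probs = F.softmax(sims * scale, dim=1)   # ~[0.92, 0.05, 0.02, 0.01, 0.01]
    print(probs.max().item())                # ~0.92, below CERTAIN_VISUAL_CONF (0.995)

Reaching CERTAIN_VISUAL_CONF = 0.995 therefore requires the best prototype to dominate the other four candidates by a wide margin; anything less is passed on to OCR and the final arbitration step in main.py.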
proofreader/core/ocr.py CHANGED
@@ -3,17 +3,13 @@ import easyocr
  import numpy as np
  import re
  from rapidfuzz import process, utils
- from .schema import Box, TradeLayout, TradeSide
+ from .schema import TradeLayout
  from proofreader.core.config import FUZZY_MATCH_CONFIDENCE_THRESHOLD, OCR_LANGUAGES, OCR_USE_GPU

  class OCRReader:
      def __init__(self, item_list, languages=OCR_LANGUAGES, gpu=OCR_USE_GPU):
          self.reader = easyocr.Reader(languages, gpu=gpu)
-
-         self.item_names = []
-
-         for item in item_list:
-             self.item_names.append(item["name"])
+         self.item_names = [item["name"] for item in item_list]

      def _fuzzy_match_name(self, raw_text: str, threshold: float = FUZZY_MATCH_CONFIDENCE_THRESHOLD) -> str:
          if not raw_text or len(raw_text) < 2:
@@ -32,48 +28,65 @@ class OCRReader:

      def _clean_robux_text(self, raw_text: str) -> int:
          cleaned = raw_text.upper().strip()
-
          substitutions = {
              ',': '', '.': '', ' ': '',
              'S': '5', 'O': '0', 'I': '1',
              'L': '1', 'B': '8', 'G': '6'
          }
-
          for char, sub in substitutions.items():
              cleaned = cleaned.replace(char, sub)

          digits = re.findall(r'\d+', cleaned)
-
          return int("".join(digits)) if digits else 0

-     def _get_text_from_box(self, image: np.ndarray, box: Box, is_robux: bool = False) -> str:
-         x1, y1, x2, y2 = box.coords
+     def process_layout(self, image: np.ndarray, layout: TradeLayout, skip_if=None):
+         all_items = layout.outgoing.items + layout.incoming.items
+         crops = []
+         target_refs = []
+         STD_H = 64

-         crop = image[max(0, y1-2):y2+2, max(0, x1-2):x2+2]
-
-         if crop.size == 0:
-             return ""
-
-         gray = cv2.cvtColor(crop, cv2.COLOR_BGR2GRAY)
-
-         if is_robux:
-             gray = cv2.resize(gray, None, fx=3, fy=3, interpolation=cv2.INTER_CUBIC)
-             results = self.reader.readtext(gray, allowlist="0123456789,S ")
-         else:
-             results = self.reader.readtext(gray)
-
-         return " ".join([res[1] for res in results]).strip()
+         for item in all_items:
+             if skip_if and skip_if(item):
+                 continue

-     def process_side(self, image: np.ndarray, side: TradeSide):
-         for item in side.items:
              if item.name_box:
-                 raw_name = self._get_text_from_box(image, item.name_box)
-                 item.name = self._fuzzy_match_name(raw_name)
+                 x1, y1, x2, y2 = item.name_box.coords
+                 crop = image[max(0, y1-2):y2+2, max(0, x1-2):x2+2]
+                 if crop.size > 0:
+                     gray = cv2.cvtColor(crop, cv2.COLOR_BGR2GRAY)
+                     h, w = gray.shape
+                     new_w = int(w * (STD_H / h))
+                     resized = cv2.resize(gray, (new_w, STD_H), interpolation=cv2.INTER_LINEAR)
+                     crops.append(resized)
+                     target_refs.append({'type': 'item', 'obj': item})
+
+         for side in [layout.outgoing, layout.incoming]:
+             if side.robux and side.robux.value_box:
+                 x1, y1, x2, y2 = side.robux.value_box.coords
+                 crop = image[max(0, y1-2):y2+2, max(0, x1-2):x2+2]
+                 if crop.size > 0:
+                     gray = cv2.cvtColor(crop, cv2.COLOR_BGR2GRAY)
+                     h, w = gray.shape
+                     new_w = int(w * (STD_H / h))
+                     resized = cv2.resize(gray, (new_w, STD_H), interpolation=cv2.INTER_LINEAR)
+                     crops.append(resized)
+                     target_refs.append({'type': 'robux', 'obj': side.robux})

-         if side.robux and side.robux.value_box:
-             raw_val = self._get_text_from_box(image, side.robux.value_box, is_robux=True)
-             side.robux.value = self._clean_robux_text(raw_val)
+         if not crops:
+             return
+
+         max_w = max(c.shape[1] for c in crops)
+         padded_crops = [cv2.copyMakeBorder(c, 0, 0, 0, max_w - c.shape[1], cv2.BORDER_CONSTANT, value=0) for c in crops]
+
+         batch_results = self.reader.readtext_batched(padded_crops, batch_size=len(padded_crops))

-     def process_layout(self, image: str, layout: TradeLayout):
-         self.process_side(image, layout.outgoing)
-         self.process_side(image, layout.incoming)
+         for i, res in enumerate(batch_results):
+             raw_text = " ".join([text_info[1] for text_info in res]).strip()
+             conf = np.mean([text_info[2] for text_info in res]) if res else 0.0
+
+             target = target_refs[i]
+             if target['type'] == 'item':
+                 target['obj'].text_name = raw_text
+                 target['obj'].text_conf = float(conf)
+             else:
+                 target['obj'].value = self._clean_robux_text(raw_text)
proofreader/core/schema.py CHANGED
@@ -15,6 +15,14 @@ class ResolvedItem:
      thumb_box: Optional[Box] = None
      name_box: Optional[Box] = None

+     visual_id: int = -1
+     visual_conf: float = 0
+
+     text_name: str = ""
+     text_conf: float = 0
+
+     _finalized: bool = False
+
  @dataclass
  class ResolvedRobux:
      value: int = 0
proofreader/main.py CHANGED
@@ -9,7 +9,8 @@ from .core.detector import TradeDetector
  from .core.resolver import SpatialResolver
  from .core.ocr import OCRReader
  from .core.matcher import VisualMatcher
- from .core.config import DB_PATH, CACHE_PATH, MODEL_PATH, DEVICE
+ from .core.config import DB_PATH, MODEL_PATH, DEVICE, CLASS_MAP_PATH, CLIP_BEST_PATH, BASE_URL, CERTAIN_VISUAL_CONF
+ from .core.schema import ResolvedItem

  class TradeEngine:
      def __init__(self):
@@ -28,38 +29,32 @@

          self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(self.device)
          self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=True)
-
+
          with open(DB_PATH, "r") as f:
              item_db = json.load(f)

-         cache_data = torch.load(CACHE_PATH, weights_only=False)['embeddings']
-         self.embeddings = {k: torch.tensor(v).to(self.device) for k, v in cache_data.items()}
-
          self.detector = TradeDetector(MODEL_PATH)
          self.resolver = SpatialResolver()
-
          self.reader = OCRReader(item_db)
-
+
          self.matcher = VisualMatcher(
-             embedding_bank=self.embeddings,
              item_db=item_db,
-             clip_processor=self.clip_processor,
-             clip_model=self.clip_model,
+             weights_path=CLIP_BEST_PATH,
+             mapping_path=CLASS_MAP_PATH,
              device=self.device
          )

      def _ensure_assets(self):
-         BASE_URL = "https://github.com/lucacrose/proofreader"
-
          assets = {
-             DB_PATH: f"{BASE_URL}/releases/download/v1.0.0/db.json",
-             CACHE_PATH: f"{BASE_URL}/releases/download/v1.0.0/embedding_bank.pt",
-             MODEL_PATH: f"{BASE_URL}/releases/download/v1.0.0/yolo.pt"
+             DB_PATH: f"{BASE_URL}/releases/download/v1.1.0/item_database.json",
+             MODEL_PATH: f"{BASE_URL}/releases/download/v1.1.0/yolo.pt",
+             CLIP_BEST_PATH: f"{BASE_URL}/releases/download/v1.1.0/clip.pt",
+             CLASS_MAP_PATH: f"{BASE_URL}/releases/download/v1.1.0/class_mapping.json"
          }

          for path, url in assets.items():
              if not path.exists():
-                 print(f"📦 {path.name} missing. Downloading from published release...")
+                 print(f"📦 {path.name} missing. Downloading from latest release...")
                  self._download_file(url, path)

      def _download_file(self, url, dest_path):
@@ -73,17 +68,73 @@
                      f.write(chunk)
                      pbar.update(len(chunk))

+     def _final_judge(self, item: ResolvedItem):
+         if getattr(item, "_finalized", False):
+             return
+
+         v_id = item.visual_id
+         v_conf = item.visual_conf
+
+         ocr_name_raw = item.text_name.lower().strip()
+         ocr_id_direct = self.matcher.name_to_id.get(ocr_name_raw)
+         ocr_conf = item.text_conf / 100.0 if item.text_conf > 1 else item.text_conf
+
+         if v_id != -1 and v_id == ocr_id_direct:
+             item.id = v_id
+             item.name = self.matcher.id_to_name.get(str(v_id))
+             return
+
+         if v_conf > 0.85:
+             item.id = v_id
+             item.name = self.matcher.id_to_name.get(str(v_id))
+             return
+
+         if ocr_conf > 0.85 and ocr_id_direct:
+             item.id = ocr_id_direct
+             item.name = self.matcher.id_to_name.get(str(ocr_id_direct))
+             return
+
+         if len(ocr_name_raw) > 2:
+             fuzzy_name = self.reader._fuzzy_match_name(ocr_name_raw)
+             fuzzy_id = self.matcher.name_to_id.get(fuzzy_name.lower())
+
+             if fuzzy_id:
+                 item.id = int(fuzzy_id)
+                 item.name = fuzzy_name
+                 return
+
+         if v_conf >= ocr_conf and v_id != -1:
+             item.id = v_id
+             item.name = self.matcher.id_to_name.get(str(v_id))
+         elif ocr_id_direct:
+             item.id = ocr_id_direct
+             item.name = self.matcher.id_to_name.get(str(ocr_id_direct))
+
      def process_image(self, image_path: str, conf_threshold: float) -> dict:
          if not os.path.exists(image_path):
              raise FileNotFoundError(f"Image not found: {image_path}")

          boxes = self.detector.detect(image_path, conf_threshold)
          layout = self.resolver.resolve(boxes)
-
          image = cv2.imread(image_path)

-         self.reader.process_layout(image, layout)
-
          self.matcher.match_item_visuals(image, layout)

+         for side in [layout.outgoing, layout.incoming]:
+             for item in side.items:
+                 if item.visual_id != -1 and item.visual_conf >= CERTAIN_VISUAL_CONF:
+                     item.id = item.visual_id
+                     item.name = self.matcher.id_to_name.get(str(item.visual_id), "Unknown")
+                     item._finalized = True
+
+         self.reader.process_layout(
+             image,
+             layout,
+             skip_if=lambda item: getattr(item, "_finalized", False)
+         )
+
+         for side in [layout.outgoing, layout.incoming]:
+             for item in side.items:
+                 self._final_judge(item)
+
          return layout.to_dict()
@@ -0,0 +1,173 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.utils.data import DataLoader
+ from torchvision import datasets, transforms
+ from transformers import CLIPVisionModelWithProjection
+ from tqdm import tqdm
+ from torch.amp import GradScaler, autocast
+ from proofreader.core.config import CLASS_MAP_PATH, CLIP_BEST_PATH, DATASET_ROOT
+ import os
+ import json
+ import numpy as np
+ import random
+
+ MODEL_ID = "openai/clip-vit-base-patch32"
+ EPOCHS = 10
+ BATCH_SIZE = 48
+ LEARNING_RATE = 1e-5
+ EMBEDDING_DIM = 512
+ WEIGHT_DECAY = 0.1
+ PATIENCE = 3 # Stop if no improvement for 3 epochs
+ MIN_DELTA = 0.1 # Minimum % improvement to be considered "better"
+
+ def set_seed(seed: int = 42):
+     random.seed(seed)
+     np.random.seed(seed)
+     os.environ["PYTHONHASHSEED"] = str(seed)
+
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+
+     torch.backends.cudnn.deterministic = True
+     torch.backends.cudnn.benchmark = False
+
+ class CLIPItemEmbedder(nn.Module):
+     def __init__(self, num_classes):
+         super().__init__()
+         self.vision_encoder = CLIPVisionModelWithProjection.from_pretrained(MODEL_ID)
+         self.item_prototypes = nn.Embedding(num_classes, EMBEDDING_DIM)
+         self.logit_scale = nn.Parameter(torch.ones([]) * 2.659)
+
+     def forward(self, pixel_values, item_ids):
+         outputs = self.vision_encoder(pixel_values=pixel_values)
+         image_embeds = outputs.image_embeds
+
+         label_embeds = self.item_prototypes(item_ids)
+         label_embeds = F.normalize(label_embeds, p=2, dim=-1)
+
+         return image_embeds, label_embeds, self.logit_scale.exp()
+
+ class EarlyStopper:
+     def __init__(self, patience=3, min_delta=0.05):
+         self.patience = patience
+         self.min_delta = min_delta
+         self.counter = 0
+         self.best_accuracy = 0
+         self.best_state = None
+
+     def check(self, current_accuracy, model):
+         if current_accuracy > (self.best_accuracy + self.min_delta):
+             self.best_accuracy = current_accuracy
+             self.best_state = getattr(model, "_orig_mod", model).state_dict()
+             self.counter = 0
+             return False, True
+         else:
+             self.counter += 1
+             return (self.counter >= self.patience), False
+
+ def get_transforms():
+     return transforms.Compose([
+         transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
+         # Resolution Crush
+         transforms.RandomApply([
+             transforms.RandomChoice([transforms.Resize(128), transforms.Resize(64)]),
+             transforms.Resize(224),
+         ], p=0.3),
+         # Gaussian Blur
+         transforms.RandomApply([
+             transforms.GaussianBlur(kernel_size=(3, 5), sigma=(0.1, 2.0))
+         ], p=0.2),
+         transforms.CenterCrop(224),
+         transforms.RandomHorizontalFlip(),
+         transforms.ColorJitter(0.3, 0.3, 0.3),
+         transforms.ToTensor(),
+         transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
+                              (0.26862954, 0.26130258, 0.27577711)),
+     ])
+
+ def train_clip():
+     set_seed(1)
+
+     torch.backends.cudnn.benchmark = True
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+     dataset_path = f"{DATASET_ROOT}/classification"
+     full_dataset = datasets.ImageFolder(root=dataset_path, transform=get_transforms())
+     num_classes = len(full_dataset.classes)
+
+     with open(CLASS_MAP_PATH, "w") as f:
+         json.dump(full_dataset.class_to_idx, f, separators=(",", ":"))
+
+     train_size = int(0.95 * len(full_dataset))
+     train_dataset, val_dataset = torch.utils.data.random_split(full_dataset, [train_size, len(full_dataset)-train_size])
+
+     train_loader = DataLoader(
+         train_dataset, batch_size=BATCH_SIZE, shuffle=True,
+         num_workers=os.cpu_count(), pin_memory=True, prefetch_factor=2, persistent_workers=True
+     )
+     val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, pin_memory=True)
+
+     model = CLIPItemEmbedder(num_classes).to(device)
+     optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
+     scaler = GradScaler('cuda')
+     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
+     stopper = EarlyStopper(patience=PATIENCE, min_delta=MIN_DELTA)
+
+     print(f"Starting training for {num_classes} classes...")
+
+     for epoch in range(EPOCHS):
+         model.train()
+         loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")
+
+         for images, labels in loop:
+             images, labels = images.to(device, non_blocking=True), labels.to(device, non_blocking=True)
+             optimizer.zero_grad(set_to_none=True)
+
+             with autocast('cuda'):
+                 img_emb, _, scale = model(images, labels)
+                 img_emb = F.normalize(img_emb, p=2, dim=-1)
+
+                 all_ids = torch.arange(num_classes, device=device)
+                 prototypes = F.normalize(model.item_prototypes(all_ids), p=2, dim=-1)
+
+                 logits = scale * img_emb @ prototypes.t()
+                 loss = F.cross_entropy(logits, labels)
+
+             scaler.scale(loss).backward()
+             scaler.step(optimizer)
+             scaler.update()
+             loop.set_postfix(loss=f"{loss.item():.4f}")
+
+         scheduler.step()
+
+         model.eval()
+         correct, total = 0, 0
+         with torch.no_grad(), autocast('cuda'):
+             all_ids = torch.arange(num_classes).to(device)
+             prototypes = F.normalize(model.item_prototypes(all_ids), p=2, dim=-1)
+
+             for images, labels in val_loader:
+                 images, labels = images.to(device), labels.to(device)
+                 img_emb, _, _ = model(images, labels)
+                 preds = (img_emb @ prototypes.t()).argmax(dim=-1)
+                 correct += (preds == labels).sum().item()
+                 total += labels.size(0)
+
+         val_acc = 100 * correct / total
+         print(f"Validation Accuracy: {val_acc:.2f}%")
+
+         stop_now, is_best = stopper.check(val_acc, model)
+         if is_best:
+             torch.save(stopper.best_state, CLIP_BEST_PATH)
+             print("Successfully saved new best model weights.")
+
+         if stop_now:
+             print(f"Stopping early. Best Accuracy was {stopper.best_accuracy:.2f}%")
+             break
+
+     print("Training finished.")
+
+ if __name__ == "__main__":
+     train_clip()