rbx-proofreader 1.0.1__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- proofreader/core/config.py +10 -6
- proofreader/core/matcher.py +59 -37
- proofreader/core/ocr.py +48 -35
- proofreader/core/schema.py +8 -0
- proofreader/main.py +70 -19
- proofreader/train/clip_trainer.py +173 -0
- proofreader/train/emulator/generator.py +185 -137
- proofreader/train/{train.py → yolo_trainer.py} +5 -8
- rbx_proofreader-1.1.0.dist-info/METADATA +160 -0
- rbx_proofreader-1.1.0.dist-info/RECORD +17 -0
- {rbx_proofreader-1.0.1.dist-info → rbx_proofreader-1.1.0.dist-info}/WHEEL +1 -1
- proofreader/train/builder.py +0 -94
- rbx_proofreader-1.0.1.dist-info/METADATA +0 -128
- rbx_proofreader-1.0.1.dist-info/RECORD +0 -17
- {rbx_proofreader-1.0.1.dist-info → rbx_proofreader-1.1.0.dist-info}/licenses/LICENSE +0 -0
- {rbx_proofreader-1.0.1.dist-info → rbx_proofreader-1.1.0.dist-info}/top_level.txt +0 -0
proofreader/core/config.py
CHANGED
@@ -5,13 +5,17 @@ from pathlib import Path
 # --- BASE PATHS ---
 # Resolves to the 'proofreader' root directory
 BASE_DIR = Path(__file__).resolve().parent.parent.parent
+BASE_URL = "https://github.com/lucacrose/proofreader/releases/latest/download"
 
 # --- ASSETS & MODELS ---
 ASSETS_PATH = BASE_DIR / "assets"
 MODEL_PATH = ASSETS_PATH / "weights" / "yolo.pt"
-DB_PATH = ASSETS_PATH / "
-CACHE_PATH = ASSETS_PATH / "
+DB_PATH = ASSETS_PATH / "item_database.json"
+CACHE_PATH = ASSETS_PATH / "item_embeddings_bank.pt"
 THUMBNAILS_DIR = ASSETS_PATH / "thumbnails"
+TRAIN_THUMBNAILS_DIR = ASSETS_PATH / "train_data"
+CLASS_MAP_PATH = ASSETS_PATH / "class_mapping.json"
+CLIP_BEST_PATH = ASSETS_PATH / "weights" / "clip.pt"
 
 # --- TRAINING & EMULATOR ---
 TRAIN_DIR = BASE_DIR / "proofreader" / "train"
@@ -26,11 +30,11 @@ DEFAULT_TEMPLATE = TEMPLATES_DIR / "trade_ui.html"
 
 # --- HYPERPARAMETERS (Training Settings) ---
 TRAINING_CONFIG = {
-    "epochs":
+    "epochs": 240,             # Number of times the model sees the whole dataset
     "batch_size": 16,          # Number of images processed at once
     "img_size": 640,           # Standard YOLO resolution
-    "patience":
-    "close_mosaic_epochs":
+    "patience": 20,            # Stop early if no improvement for 20 epochs
+    "close_mosaic_epochs": 32  # Disable mosaic augmentation for the last N epochs
 }
 
 # --- AUGMENTER PROBABILITIES AND GENERATOR SETTINGS ---
@@ -82,7 +86,7 @@ AUGMENTER_CONFIG = {
 
 # Robustness Thresholds
 FUZZY_MATCH_CONFIDENCE_THRESHOLD = 60.0
-
+CERTAIN_VISUAL_CONF = 0.995
 
 # --- HARDWARE SETTINGS ---
 # Automatically detects if a GPU is available for faster training
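The renamed proofreader/train/yolo_trainer.py is not expanded in this diff, so the following is only a hedged sketch of how TRAINING_CONFIG's keys could map onto an ultralytics YOLO training call; the base weights file and dataset YAML are placeholders, not files shipped by this package.

    from ultralytics import YOLO

    from proofreader.core.config import TRAINING_CONFIG

    # Illustrative wiring only: "yolov8n.pt" and "dataset.yaml" are placeholders.
    model = YOLO("yolov8n.pt")
    model.train(
        data="dataset.yaml",
        epochs=TRAINING_CONFIG["epochs"],                     # 240
        batch=TRAINING_CONFIG["batch_size"],                  # 16
        imgsz=TRAINING_CONFIG["img_size"],                    # 640
        patience=TRAINING_CONFIG["patience"],                 # 20
        close_mosaic=TRAINING_CONFIG["close_mosaic_epochs"],  # 32
    )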
proofreader/core/matcher.py
CHANGED
@@ -1,34 +1,56 @@
 import torch
+import torch.nn as nn
+import torch.nn.functional as F
 import numpy as np
+import json
 import cv2
 from PIL import Image
-from 
-from 
-from 
+from torchvision import transforms
+from transformers import CLIPVisionModelWithProjection
+from typing import List
+from .schema import TradeLayout, ResolvedItem
+
+class CLIPItemEmbedder(nn.Module):
+    def __init__(self, num_classes, model_id="openai/clip-vit-base-patch32"):
+        super().__init__()
+        self.vision_encoder = CLIPVisionModelWithProjection.from_pretrained(model_id)
+        self.item_prototypes = nn.Embedding(num_classes, 512)
+        self.logit_scale = nn.Parameter(torch.ones([]) * 2.659)
+
+    def forward(self, pixel_values):
+        outputs = self.vision_encoder(pixel_values=pixel_values)
+        return F.normalize(outputs.image_embeds, p=2, dim=-1)
 
 class VisualMatcher:
-    def __init__(self,
+    def __init__(self, weights_path: str, mapping_path: str, item_db: List[dict], device: str = "cuda"):
         self.device = device
-        self.bank = embedding_bank
-        self.item_db = item_db
-        self.clip_processor = clip_processor
-        self.clip_model = clip_model
 
-
+        with open(mapping_path, "r") as f:
+            self.class_to_idx = json.load(f)
+        self.idx_to_class = {v: k for k, v in self.class_to_idx.items()}
+
         self.id_to_name = {str(i["id"]): i["name"] for i in item_db}
+        self.name_to_id = {str(i["name"]).lower().strip(): i["id"] for i in item_db}
 
-
-        self.
-        self.
+        num_classes = len(self.class_to_idx)
+        self.model = CLIPItemEmbedder(num_classes).to(self.device)
+        self.model.load_state_dict(torch.load(weights_path, map_location=self.device))
+        self.model.eval()
 
-
-
-
+        with torch.inference_mode():
+            self.bank_tensor = F.normalize(self.model.item_prototypes.weight, p=2, dim=-1)
+
+        self.preprocess = transforms.Compose([
+            transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC),
+            transforms.ToTensor(),
+            transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
+                                 (0.26862954, 0.26130258, 0.27577711)),
+        ])
 
-    def match_item_visuals(self, image: np.ndarray, layout: TradeLayout
-        items_to_process = []
+    def match_item_visuals(self, image: np.ndarray, layout: TradeLayout):
+        items_to_process: List[ResolvedItem] = []
         crops = []
-
+
         for side in (layout.outgoing.items, layout.incoming.items):
             for item in side:
                 if item.thumb_box:
@@ -36,32 +58,32 @@ class VisualMatcher:
                     crop = image[y1:y2, x1:x2]
                     if crop.size > 0:
                         pil_img = Image.fromarray(cv2.cvtColor(crop, cv2.COLOR_BGR2RGB))
-
+                        processed_crop = self.preprocess(pil_img)
+                        crops.append(processed_crop)
                         items_to_process.append(item)
 
         if not crops:
             return
 
-
+        batch_tensor = torch.stack(crops).to(self.device)
 
-        with torch.
-            query_features = self.
-
-
-
+        with torch.inference_mode():
+            query_features = self.model(batch_tensor)
+
+            logits = query_features @ self.bank_tensor.t() * self.model.logit_scale.exp()
+            topk_scores, topk_indices = logits.topk(k=5, dim=1)
+
+            probs = F.softmax(topk_scores.float(), dim=1)
+
+            best_idx_in_topk = probs.argmax(dim=1)
+            best_indices = topk_indices[torch.arange(len(topk_indices)), best_idx_in_topk]
+            best_probs = probs[torch.arange(len(probs)), best_idx_in_topk]
+
 
         for i, item in enumerate(items_to_process):
-
-            visual_conf = best_scores[i].item()
+            visual_idx = best_indices[i].item()
 
-
+            visual_match_id_str = self.idx_to_class[visual_idx]
 
-
-
-            item.id = int(visual_match_val)
-            item.name = self.id_to_name.get(str(visual_match_val), "Unknown Item")
-            else:
-                item.name = visual_match_val
-                item.id = self._get_id_from_name(visual_match_val)
-            else:
-                item.id = self._get_id_from_name(item.name)
+            item.visual_id = int(visual_match_id_str)
+            item.visual_conf = float(best_probs[i].item())
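The rewritten matcher classifies each thumbnail crop against learned per-item prototype embeddings instead of a precomputed embedding bank. A minimal self-contained sketch of that scoring step, using random tensors in place of CLIP features so it runs without downloading a model:

    import torch
    import torch.nn.functional as F

    # Stand-ins: 4 crop embeddings and 1000 item prototypes, all L2-normalized 512-d vectors.
    query_features = F.normalize(torch.randn(4, 512), dim=-1)
    bank_tensor = F.normalize(torch.randn(1000, 512), dim=-1)
    logit_scale = torch.tensor(2.659).exp()

    logits = query_features @ bank_tensor.t() * logit_scale  # temperature-scaled cosine similarity
    topk_scores, topk_indices = logits.topk(k=5, dim=1)      # 5 nearest prototypes per crop
    probs = F.softmax(topk_scores, dim=1)                    # confidence over the top-5 only

    best_in_topk = probs.argmax(dim=1)
    best_indices = topk_indices[torch.arange(4), best_in_topk]  # becomes item.visual_id via idx_to_class
    best_probs = probs[torch.arange(4), best_in_topk]           # becomes item.visual_conf
    print(best_indices.tolist(), best_probs.tolist())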
proofreader/core/ocr.py
CHANGED
@@ -3,17 +3,13 @@ import easyocr
 import numpy as np
 import re
 from rapidfuzz import process, utils
-from .schema import
+from .schema import TradeLayout
 from proofreader.core.config import FUZZY_MATCH_CONFIDENCE_THRESHOLD, OCR_LANGUAGES, OCR_USE_GPU
 
 class OCRReader:
     def __init__(self, item_list, languages=OCR_LANGUAGES, gpu=OCR_USE_GPU):
         self.reader = easyocr.Reader(languages, gpu=gpu)
-
-        self.item_names = []
-
-        for item in item_list:
-            self.item_names.append(item["name"])
+        self.item_names = [item["name"] for item in item_list]
 
     def _fuzzy_match_name(self, raw_text: str, threshold: float = FUZZY_MATCH_CONFIDENCE_THRESHOLD) -> str:
         if not raw_text or len(raw_text) < 2:
@@ -32,48 +28,65 @@ class OCRReader:
 
     def _clean_robux_text(self, raw_text: str) -> int:
         cleaned = raw_text.upper().strip()
-
         substitutions = {
             ',': '', '.': '', ' ': '',
             'S': '5', 'O': '0', 'I': '1',
             'L': '1', 'B': '8', 'G': '6'
         }
-
         for char, sub in substitutions.items():
             cleaned = cleaned.replace(char, sub)
 
         digits = re.findall(r'\d+', cleaned)
-
         return int("".join(digits)) if digits else 0
 
-    def 
-
+    def process_layout(self, image: np.ndarray, layout: TradeLayout, skip_if=None):
+        all_items = layout.outgoing.items + layout.incoming.items
+        crops = []
+        target_refs = []
+        STD_H = 64
 
-
-
-
-            return ""
-
-        gray = cv2.cvtColor(crop, cv2.COLOR_BGR2GRAY)
-
-        if is_robux:
-            gray = cv2.resize(gray, None, fx=3, fy=3, interpolation=cv2.INTER_CUBIC)
-            results = self.reader.readtext(gray, allowlist="0123456789,S ")
-        else:
-            results = self.reader.readtext(gray)
-
-        return " ".join([res[1] for res in results]).strip()
+        for item in all_items:
+            if skip_if and skip_if(item):
+                continue
 
-    def process_side(self, image: np.ndarray, side: TradeSide):
-        for item in side.items:
             if item.name_box:
-
-
+                x1, y1, x2, y2 = item.name_box.coords
+                crop = image[max(0, y1-2):y2+2, max(0, x1-2):x2+2]
+                if crop.size > 0:
+                    gray = cv2.cvtColor(crop, cv2.COLOR_BGR2GRAY)
+                    h, w = gray.shape
+                    new_w = int(w * (STD_H / h))
+                    resized = cv2.resize(gray, (new_w, STD_H), interpolation=cv2.INTER_LINEAR)
+                    crops.append(resized)
+                    target_refs.append({'type': 'item', 'obj': item})
+
+        for side in [layout.outgoing, layout.incoming]:
+            if side.robux and side.robux.value_box:
+                x1, y1, x2, y2 = side.robux.value_box.coords
+                crop = image[max(0, y1-2):y2+2, max(0, x1-2):x2+2]
+                if crop.size > 0:
+                    gray = cv2.cvtColor(crop, cv2.COLOR_BGR2GRAY)
+                    h, w = gray.shape
+                    new_w = int(w * (STD_H / h))
+                    resized = cv2.resize(gray, (new_w, STD_H), interpolation=cv2.INTER_LINEAR)
+                    crops.append(resized)
+                    target_refs.append({'type': 'robux', 'obj': side.robux})
 
-        if 
-
-
+        if not crops:
+            return
+
+        max_w = max(c.shape[1] for c in crops)
+        padded_crops = [cv2.copyMakeBorder(c, 0, 0, 0, max_w - c.shape[1], cv2.BORDER_CONSTANT, value=0) for c in crops]
+
+        batch_results = self.reader.readtext_batched(padded_crops, batch_size=len(padded_crops))
 
-
-
-
+        for i, res in enumerate(batch_results):
+            raw_text = " ".join([text_info[1] for text_info in res]).strip()
+            conf = np.mean([text_info[2] for text_info in res]) if res else 0.0
+
+            target = target_refs[i]
+            if target['type'] == 'item':
+                target['obj'].text_name = raw_text
+                target['obj'].text_conf = float(conf)
+            else:
+                target['obj'].value = self._clean_robux_text(raw_text)
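OCR results are now stored on the item (text_name / text_conf) instead of being resolved immediately; the Robux-cleaning helper is unchanged. Copied out as a standalone function to show what it does to a noisy read:

    import re

    def clean_robux_text(raw_text: str) -> int:
        # Same substitution table as OCRReader._clean_robux_text: map common OCR confusions to digits.
        cleaned = raw_text.upper().strip()
        for char, sub in {',': '', '.': '', ' ': '',
                          'S': '5', 'O': '0', 'I': '1',
                          'L': '1', 'B': '8', 'G': '6'}.items():
            cleaned = cleaned.replace(char, sub)
        digits = re.findall(r'\d+', cleaned)
        return int("".join(digits)) if digits else 0

    print(clean_robux_text("1,2S0"))  # -> 1250 ('S' misread for '5', comma dropped)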
proofreader/core/schema.py
CHANGED
@@ -15,6 +15,14 @@ class ResolvedItem:
     thumb_box: Optional[Box] = None
     name_box: Optional[Box] = None
 
+    visual_id: int = -1
+    visual_conf: float = 0
+
+    text_name: str = ""
+    text_conf: float = 0
+
+    _finalized: bool = False
+
 @dataclass
 class ResolvedRobux:
     value: int = 0
proofreader/main.py
CHANGED
@@ -9,7 +9,8 @@ from .core.detector import TradeDetector
 from .core.resolver import SpatialResolver
 from .core.ocr import OCRReader
 from .core.matcher import VisualMatcher
-from .core.config import DB_PATH,
+from .core.config import DB_PATH, MODEL_PATH, DEVICE, CLASS_MAP_PATH, CLIP_BEST_PATH, BASE_URL, CERTAIN_VISUAL_CONF
+from .core.schema import ResolvedItem
 
 class TradeEngine:
     def __init__(self):
@@ -28,38 +29,32 @@ class TradeEngine:
 
         self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(self.device)
         self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=True)
-
+
         with open(DB_PATH, "r") as f:
             item_db = json.load(f)
 
-        cache_data = torch.load(CACHE_PATH, weights_only=False)['embeddings']
-        self.embeddings = {k: torch.tensor(v).to(self.device) for k, v in cache_data.items()}
-
         self.detector = TradeDetector(MODEL_PATH)
         self.resolver = SpatialResolver()
-
         self.reader = OCRReader(item_db)
-
+
         self.matcher = VisualMatcher(
-            embedding_bank=self.embeddings,
             item_db=item_db,
-
-
+            weights_path=CLIP_BEST_PATH,
+            mapping_path=CLASS_MAP_PATH,
             device=self.device
         )
 
     def _ensure_assets(self):
-        BASE_URL = "https://github.com/lucacrose/proofreader"
-
         assets = {
-            DB_PATH: f"{BASE_URL}/
-
-
+            DB_PATH: f"{BASE_URL}/item_database.json",
+            MODEL_PATH: f"{BASE_URL}/yolo.pt",
+            CLIP_BEST_PATH: f"{BASE_URL}/clip.pt",
+            CLASS_MAP_PATH: f"{BASE_URL}/class_mapping.json"
         }
 
         for path, url in assets.items():
             if not path.exists():
-                print(f"📦 {path.name} missing. Downloading from
+                print(f"📦 {path.name} missing. Downloading from latest release...")
                 self._download_file(url, path)
 
     def _download_file(self, url, dest_path):
@@ -73,17 +68,73 @@ class TradeEngine:
                     f.write(chunk)
                     pbar.update(len(chunk))
 
+    def _final_judge(self, item: ResolvedItem):
+        if getattr(item, "_finalized", False):
+            return
+
+        v_id = item.visual_id
+        v_conf = item.visual_conf
+
+        ocr_name_raw = item.text_name.lower().strip()
+        ocr_id_direct = self.matcher.name_to_id.get(ocr_name_raw)
+        ocr_conf = item.text_conf / 100.0 if item.text_conf > 1 else item.text_conf
+
+        if v_id != -1 and v_id == ocr_id_direct:
+            item.id = v_id
+            item.name = self.matcher.id_to_name.get(str(v_id))
+            return
+
+        if v_conf > 0.85:
+            item.id = v_id
+            item.name = self.matcher.id_to_name.get(str(v_id))
+            return
+
+        if ocr_conf > 0.85 and ocr_id_direct:
+            item.id = ocr_id_direct
+            item.name = self.matcher.id_to_name.get(str(ocr_id_direct))
+            return
+
+        if len(ocr_name_raw) > 2:
+            fuzzy_name = self.reader._fuzzy_match_name(ocr_name_raw)
+            fuzzy_id = self.matcher.name_to_id.get(fuzzy_name.lower())
+
+            if fuzzy_id:
+                item.id = int(fuzzy_id)
+                item.name = fuzzy_name
+                return
+
+        if v_conf >= ocr_conf and v_id != -1:
+            item.id = v_id
+            item.name = self.matcher.id_to_name.get(str(v_id))
+        elif ocr_id_direct:
+            item.id = ocr_id_direct
+            item.name = self.matcher.id_to_name.get(str(ocr_id_direct))
+
     def process_image(self, image_path: str, conf_threshold: float) -> dict:
         if not os.path.exists(image_path):
             raise FileNotFoundError(f"Image not found: {image_path}")
 
         boxes = self.detector.detect(image_path, conf_threshold)
         layout = self.resolver.resolve(boxes)
-
         image = cv2.imread(image_path)
 
-        self.reader.process_layout(image, layout)
-
         self.matcher.match_item_visuals(image, layout)
 
+        for side in [layout.outgoing, layout.incoming]:
+            for item in side.items:
+                if item.visual_id != -1 and item.visual_conf >= CERTAIN_VISUAL_CONF:
+                    item.id = item.visual_id
+                    item.name = self.matcher.id_to_name.get(str(item.visual_id), "Unknown")
+                    item._finalized = True
+
+        self.reader.process_layout(
+            image,
+            layout,
+            skip_if=lambda item: getattr(item, "_finalized", False)
+        )
+
+        for side in [layout.outgoing, layout.incoming]:
+            for item in side.items:
+                self._final_judge(item)
+
         return layout.to_dict()
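Taken together, process_image now runs detection, visual matching, an early exit for near-certain visual hits, batched OCR on the remaining items, and a final arbitration pass. A hedged usage sketch; the screenshot path and confidence threshold below are placeholders, not values from the package.

    import json

    from proofreader.main import TradeEngine

    # _ensure_assets() fetches yolo.pt, clip.pt, class_mapping.json and item_database.json
    # from the release URL when they are missing locally.
    engine = TradeEngine()
    result = engine.process_image("trade_screenshot.png", conf_threshold=0.5)
    print(json.dumps(result, indent=2))  # resolved items and Robux for both trade sides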
proofreader/train/clip_trainer.py
ADDED
@@ -0,0 +1,173 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+from torchvision import datasets, transforms
+from transformers import CLIPVisionModelWithProjection
+from tqdm import tqdm
+from torch.amp import GradScaler, autocast
+from proofreader.core.config import CLASS_MAP_PATH, CLIP_BEST_PATH, DATASET_ROOT
+import os
+import json
+import numpy as np
+import random
+
+MODEL_ID = "openai/clip-vit-base-patch32"
+EPOCHS = 10
+BATCH_SIZE = 48
+LEARNING_RATE = 1e-5
+EMBEDDING_DIM = 512
+WEIGHT_DECAY = 0.1
+PATIENCE = 3     # Stop if no improvement for 3 epochs
+MIN_DELTA = 0.1  # Minimum % improvement to be considered "better"
+
+def set_seed(seed: int = 42):
+    random.seed(seed)
+    np.random.seed(seed)
+    os.environ["PYTHONHASHSEED"] = str(seed)
+
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+
+class CLIPItemEmbedder(nn.Module):
+    def __init__(self, num_classes):
+        super().__init__()
+        self.vision_encoder = CLIPVisionModelWithProjection.from_pretrained(MODEL_ID)
+        self.item_prototypes = nn.Embedding(num_classes, EMBEDDING_DIM)
+        self.logit_scale = nn.Parameter(torch.ones([]) * 2.659)
+
+    def forward(self, pixel_values, item_ids):
+        outputs = self.vision_encoder(pixel_values=pixel_values)
+        image_embeds = outputs.image_embeds
+
+        label_embeds = self.item_prototypes(item_ids)
+        label_embeds = F.normalize(label_embeds, p=2, dim=-1)
+
+        return image_embeds, label_embeds, self.logit_scale.exp()
+
+class EarlyStopper:
+    def __init__(self, patience=3, min_delta=0.05):
+        self.patience = patience
+        self.min_delta = min_delta
+        self.counter = 0
+        self.best_accuracy = 0
+        self.best_state = None
+
+    def check(self, current_accuracy, model):
+        if current_accuracy > (self.best_accuracy + self.min_delta):
+            self.best_accuracy = current_accuracy
+            self.best_state = getattr(model, "_orig_mod", model).state_dict()
+            self.counter = 0
+            return False, True
+        else:
+            self.counter += 1
+            return (self.counter >= self.patience), False
+
+def get_transforms():
+    return transforms.Compose([
+        transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
+        # Resolution Crush
+        transforms.RandomApply([
+            transforms.RandomChoice([transforms.Resize(128), transforms.Resize(64)]),
+            transforms.Resize(224),
+        ], p=0.3),
+        # Gaussian Blur
+        transforms.RandomApply([
+            transforms.GaussianBlur(kernel_size=(3, 5), sigma=(0.1, 2.0))
+        ], p=0.2),
+        transforms.CenterCrop(224),
+        transforms.RandomHorizontalFlip(),
+        transforms.ColorJitter(0.3, 0.3, 0.3),
+        transforms.ToTensor(),
+        transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
+                             (0.26862954, 0.26130258, 0.27577711)),
+    ])
+
+def train_clip():
+    set_seed(1)
+
+    torch.backends.cudnn.benchmark = True
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    dataset_path = f"{DATASET_ROOT}/classification"
+    full_dataset = datasets.ImageFolder(root=dataset_path, transform=get_transforms())
+    num_classes = len(full_dataset.classes)
+
+    with open(CLASS_MAP_PATH, "w") as f:
+        json.dump(full_dataset.class_to_idx, f, separators=(",", ":"))
+
+    train_size = int(0.95 * len(full_dataset))
+    train_dataset, val_dataset = torch.utils.data.random_split(full_dataset, [train_size, len(full_dataset)-train_size])
+
+    train_loader = DataLoader(
+        train_dataset, batch_size=BATCH_SIZE, shuffle=True,
+        num_workers=os.cpu_count(), pin_memory=True, prefetch_factor=2, persistent_workers=True
+    )
+    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, pin_memory=True)
+
+    model = CLIPItemEmbedder(num_classes).to(device)
+    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
+    scaler = GradScaler('cuda')
+    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
+    stopper = EarlyStopper(patience=PATIENCE, min_delta=MIN_DELTA)
+
+    print(f"Starting training for {num_classes} classes...")
+
+    for epoch in range(EPOCHS):
+        model.train()
+        loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")
+
+        for images, labels in loop:
+            images, labels = images.to(device, non_blocking=True), labels.to(device, non_blocking=True)
+            optimizer.zero_grad(set_to_none=True)
+
+            with autocast('cuda'):
+                img_emb, _, scale = model(images, labels)
+                img_emb = F.normalize(img_emb, p=2, dim=-1)
+
+                all_ids = torch.arange(num_classes, device=device)
+                prototypes = F.normalize(model.item_prototypes(all_ids), p=2, dim=-1)
+
+                logits = scale * img_emb @ prototypes.t()
+                loss = F.cross_entropy(logits, labels)
+
+            scaler.scale(loss).backward()
+            scaler.step(optimizer)
+            scaler.update()
+            loop.set_postfix(loss=f"{loss.item():.4f}")
+
+        scheduler.step()
+
+        model.eval()
+        correct, total = 0, 0
+        with torch.no_grad(), autocast('cuda'):
+            all_ids = torch.arange(num_classes).to(device)
+            prototypes = F.normalize(model.item_prototypes(all_ids), p=2, dim=-1)
+
+            for images, labels in val_loader:
+                images, labels = images.to(device), labels.to(device)
+                img_emb, _, _ = model(images, labels)
+                preds = (img_emb @ prototypes.t()).argmax(dim=-1)
+                correct += (preds == labels).sum().item()
+                total += labels.size(0)
+
+        val_acc = 100 * correct / total
+        print(f"Validation Accuracy: {val_acc:.2f}%")
+
+        stop_now, is_best = stopper.check(val_acc, model)
+        if is_best:
+            torch.save(stopper.best_state, CLIP_BEST_PATH)
+            print("Successfully saved new best model weights.")
+
+        if stop_now:
+            print(f"Stopping early. Best Accuracy was {stopper.best_accuracy:.2f}%")
+            break
+
+    print("Training finished.")
+
+if __name__ == "__main__":
+    train_clip()
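The trainer writes the two artifacts the new VisualMatcher consumes: the best checkpoint at CLIP_BEST_PATH and the ImageFolder class mapping at CLASS_MAP_PATH. A minimal way to kick it off, assuming the classification dataset already exists under {DATASET_ROOT}/classification with one sub-folder per item class:

    from proofreader.train.clip_trainer import train_clip

    # Equivalent to `python -m proofreader.train.clip_trainer`; expects
    # {DATASET_ROOT}/classification/<class>/<images> in torchvision ImageFolder layout.
    train_clip()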