graphvision-ai 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
PKG-INFO
@@ -0,0 +1,12 @@
+ Metadata-Version: 2.4
+ Name: graphvision-ai
+ Version: 0.1.0
+ Summary: Automatic Graph Classification and Data Extraction
+ Author: Aryan Gahlot
+ Description-Content-Type: text/markdown
+ Requires-Dist: torch
+ Requires-Dist: torchvision
+ Requires-Dist: opencv-python
+ Requires-Dist: easyocr
+ Requires-Dist: Pillow
+ Requires-Dist: numpy
graphvision/analyzer.py
@@ -0,0 +1,212 @@
+ import torch
+ import cv2
+ import easyocr
+ import numpy as np
+ from PIL import Image
+ from torchvision import transforms
+
+ from .model import (
+     load_classifier,
+     load_hbar_model,
+     load_vbar_model,
+     load_pie_model
+ )
+
+ class GraphAnalyzer:
+     def __init__(self):
+         # Automatically use Mac MPS, Nvidia CUDA, or CPU
+         self.device = torch.device("mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu"))
+
+         # 1. Load models
+         self.classifier = load_classifier(self.device)
+         self.hbar_model = load_hbar_model(self.device)
+         self.vbar_model = load_vbar_model(self.device)
+         self.pie_model = load_pie_model(self.device)
+
+         # 2. Set up the OCR engine
+         self.reader = easyocr.Reader(['en'])
+
+         # 3. ImageNet normalization expected by the ResNet backbones
+         self.data_transforms = transforms.Compose([
+             transforms.Resize((224, 224)),
+             transforms.ToTensor(),
+             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+         ])
+
+         # Class indices must match the folder order used when the classifier was trained
+         self.class_map = {
+             0: "dot_line",
+             1: "hbar_categorical",
+             2: "line",
+             3: "pie",
+             4: "vbar_categorical"
+         }
+
+     def analyze(self, image):
+         """Main routing function. Classifies the image and extracts data."""
+
+         # --- ROBUST IMAGE LOADING ---
+         if isinstance(image, str):
+             img_pil = Image.open(image).convert('RGB')
+             cv_img = cv2.imread(image)
+         elif isinstance(image, np.ndarray):
+             cv_img = image
+             img_pil = Image.fromarray(cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB))
+         elif isinstance(image, Image.Image):
+             img_pil = image.convert('RGB')
+             cv_img = cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR)
+         else:
+             raise TypeError("Input must be a file path (str), NumPy array (cv2), or PIL Image.")
+
+         # 1. Classify the graph type
+         image_tensor = self.data_transforms(img_pil).unsqueeze(0).to(self.device)
+
+         with torch.no_grad():
+             outputs = self.classifier(image_tensor)
+             predicted_class = torch.argmax(outputs, dim=1).item()
+
+         graph_type = self.class_map.get(predicted_class, "unknown")
+
+         # 2. Route to the correct OCR + prediction logic
+         if graph_type == "pie":
+             return self._extract_pie_data(cv_img, image_tensor)
+         elif graph_type == "vbar_categorical":
+             return self._extract_vbar_data(cv_img, image_tensor)
+         elif graph_type == "hbar_categorical":
+             return self._extract_hbar_data(cv_img, image_tensor)
+         else:
+             return {"type": graph_type, "status": f"Extraction logic for {graph_type} is not yet implemented"}
+
+     def _extract_pie_data(self, cv_img, image_tensor):
+         """Dedicated pie-chart extraction logic."""
+         with torch.no_grad():
+             preds = self.pie_model(image_tensor).squeeze().cpu().numpy() * 100.0
+
+         h, w, _ = cv_img.shape
+         gray = cv2.cvtColor(cv_img, cv2.COLOR_BGR2GRAY)
+         all_text_results = self.reader.readtext(gray, mag_ratio=2.5)
+
+         title = "Untitled"
+         raw_legend_names = []
+
+         for bbox, text, conf in all_text_results:
+             clean_text = text.strip()
+             if not clean_text: continue
+
+             x_center = (bbox[0][0] + bbox[2][0]) / 2
+             y_center = (bbox[0][1] + bbox[2][1]) / 2
+             x_pct = x_center / w
+             y_pct = y_center / h
+
+             if y_pct < 0.15:
+                 title = clean_text
+             else:
+                 if len(clean_text) > 2 and not clean_text.replace('.', '', 1).isdigit():
+                     if clean_text.lower() == "grav": clean_text = "Gray"  # common OCR misread
+                     raw_legend_names.append((y_pct, clean_text))
+
+         legend_names = [item[1] for item in sorted(raw_legend_names, key=lambda i: i[0])]
+
+         num_slices = len(legend_names)
+         if num_slices == 0:
+             valid_preds = [v for v in preds if v > 1.5]
+             num_slices = len(valid_preds)
+
+         num_slices = min(num_slices, 10)
+         slice_preds = preds[:num_slices]
+         total_pred = sum(slice_preds)
+
+         if total_pred > 0:
+             normalized_preds = [(v / total_pred) * 100.0 for v in slice_preds]
+         else:
+             normalized_preds = slice_preds
+
+         final_slices = {}
+         for i in range(num_slices):
+             slice_name = legend_names[i] if i < len(legend_names) else f"Unknown_Slice_{i+1}"
+             slice_value = round(normalized_preds[i], 2) if i < len(normalized_preds) else 0.0
+             final_slices[slice_name] = slice_value
+
+         return {
+             "type": "pie",
+             "title": title,
+             "data": final_slices
+         }
+
+     def _extract_vbar_data(self, cv_img, image_tensor):
+         """Dedicated vertical-bar-chart extraction logic."""
+         with torch.no_grad():
+             preds = self.vbar_model(image_tensor).squeeze().cpu().numpy() * 100.0
+
+         h, w, _ = cv_img.shape
+
+         # 1. Title
+         title_res = self.reader.readtext(cv_img[0:int(h*0.15), :])
+         title = title_res[0][1] if title_res else "Untitled"
+
+         # 2. X-axis labels (rotated crop to catch vertical text)
+         label_area = cv_img[int(h*0.70):h, :]
+         label_rot = cv2.rotate(label_area, cv2.ROTATE_90_COUNTERCLOCKWISE)
+         label_res = self.reader.readtext(label_rot)
+
+         x_labels = [res[1] for res in label_res if len(res[1]) > 1]
+         x_labels = x_labels[::-1]  # rotation reverses the left-to-right order; flip back
+
+         # 3. Match predictions to labels (noise threshold scales with the data)
+         max_val = np.max(preds)
+         noise_threshold = max_val * 0.05
+
+         valid_preds = [v for v in preds if v > noise_threshold]
+
+         count = max(len(x_labels), len(valid_preds))
+         count = min(count, 10)
+
+         final_values = [round(float(v), 2) for v in preds[:count]]
+
+         while len(x_labels) < len(final_values):
+             x_labels.append(f"Unknown_Label_{len(x_labels)+1}")
+
+         return {
+             "type": "vbar_categorical",
+             "title": title,
+             "x_axis_labels": x_labels,
+             "values": final_values
+         }
+
+     def _extract_hbar_data(self, cv_img, image_tensor):
+         """Dedicated horizontal-bar-chart extraction logic."""
+         with torch.no_grad():
+             preds = self.hbar_model(image_tensor).squeeze().cpu().numpy() * 100.0
+
+         h, w, _ = cv_img.shape
+
+         # 1. Title
+         title_res = self.reader.readtext(cv_img[0:int(h*0.15), :])
+         title = title_res[0][1] if title_res else "Untitled"
+
+         # 2. Labels (contrast boost to help OCR on the left margin)
+         label_area = cv_img[int(h*0.10):int(h*0.95), 0:int(w*0.45)]
+         gray = cv2.cvtColor(label_area, cv2.COLOR_BGR2GRAY)
+         contrast_img = cv2.convertScaleAbs(gray, alpha=1.5, beta=-50)
+
+         label_res = self.reader.readtext(contrast_img, mag_ratio=2.0)
+         y_labels = [res[1] for res in label_res if len(res[1]) > 1]
+         y_labels = y_labels[::-1]
+
+         # 3. Match predictions to labels
+         valid_preds = [v for v in preds if v > 2.0]
+
+         count = len(y_labels) if len(y_labels) > 0 else len(valid_preds)
+         count = min(count, 10)
+
+         final_values = [round(float(v), 2) for v in preds[:count]]
+
+         while len(y_labels) < len(final_values):
+             y_labels.append(f"Unknown_Label_{len(y_labels)+1}")
+
+         return {
+             "type": "hbar_categorical",
+             "title": title,
+             "y_axis_labels": y_labels,
+             "values": final_values
+         }
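The `GraphAnalyzer.analyze` method above is the package's single entry point: it classifies the chart image, then routes to the matching `_extract_*` helper. A minimal usage sketch (assuming `graphvision/__init__.py` re-exports `GraphAnalyzer` — otherwise import from `graphvision.analyzer` — and that `chart.png` is a local pie or bar chart):

    from graphvision import GraphAnalyzer

    analyzer = GraphAnalyzer()               # downloads the four weight files on first run
    result = analyzer.analyze("chart.png")   # also accepts a NumPy (cv2) array or a PIL Image

    print(result["type"])    # e.g. "pie"
    print(result["title"])   # OCR-recovered title, "Untitled" if none was found
    print(result.get("data") or result.get("values"))  # pie slice dict, or bar value list

Note that pie results carry a `data` dict keyed by legend name, while the bar extractors return parallel `x_axis_labels`/`y_axis_labels` and `values` lists.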
graphvision/model.py
@@ -0,0 +1,90 @@
+ import torch
+ import torch.nn as nn
+ from torchvision import models
+ import os
+ import urllib.request
+ import shutil
+
+ # --- AUTO-DOWNLOAD WEIGHTS SETUP ---
+ HOME_DIR = os.path.expanduser("~")
+ # Hidden cache folder in the user's home directory for the downloaded weights
+ WEIGHTS_DIR = os.path.join(HOME_DIR, ".graphvision_weights")
+ os.makedirs(WEIGHTS_DIR, exist_ok=True)
+
+ # Hugging Face repository hosting the pretrained weights
+ HF_BASE_URL = "https://huggingface.co/ShadowGard3n/graphvision/resolve/main/"
+
+ WEIGHT_URLS = {
+     "graph_classifier.pth": f"{HF_BASE_URL}graph_classifier.pth",
+     "pie_regressor_stable.pth": f"{HF_BASE_URL}pie_regressor_stable.pth",
+     "hbar_regressor_stable.pth": f"{HF_BASE_URL}hbar_regressor_stable.pth",
+     "vbar_regressor.pth": f"{HF_BASE_URL}vbar_regressor.pth"
+ }
+
+ def download_weight(filename):
+     """Downloads the weight file from Hugging Face if it doesn't already exist."""
+     filepath = os.path.join(WEIGHTS_DIR, filename)
+     if not os.path.exists(filepath):
+         print(f"Downloading {filename} from Hugging Face (this only happens once)...")
+         # Custom User-Agent header so Hugging Face doesn't block the request
+         req = urllib.request.Request(WEIGHT_URLS[filename], headers={'User-Agent': 'Mozilla/5.0'})
+         with urllib.request.urlopen(req) as response, open(filepath, 'wb') as out_file:
+             shutil.copyfileobj(response, out_file)
+     return filepath
+
+ # --- 1. MODEL ARCHITECTURES ---
+ class PieRegressor(nn.Module):
+     def __init__(self):
+         super(PieRegressor, self).__init__()
+         self.backbone = models.resnet18(weights=None)
+         num_ftrs = self.backbone.fc.in_features
+         self.backbone.fc = nn.Linear(num_ftrs, 10)  # up to 10 slice percentages
+         self.sigmoid = nn.Sigmoid()
+
+     def forward(self, x):
+         return self.sigmoid(self.backbone(x))
+
+ class BarRegressor(nn.Module):
+     def __init__(self):
+         super(BarRegressor, self).__init__()
+         self.backbone = models.resnet18(weights=None)
+         num_ftrs = self.backbone.fc.in_features
+         self.backbone.fc = nn.Linear(num_ftrs, 10)  # up to 10 bar values
+         self.sigmoid = nn.Sigmoid()
+
+     def forward(self, x):
+         return self.sigmoid(self.backbone(x))
+
+ # --- 2. LOADING FUNCTIONS ---
+ def load_classifier(device):
+     # Resolve the weight file, downloading it on first use
+     path = download_weight("graph_classifier.pth")
+
+     model = models.resnet18(weights=None)
+     num_ftrs = model.fc.in_features
+     # 5 output classes to match the saved checkpoint
+     model.fc = nn.Linear(num_ftrs, 5)
+
+     model.load_state_dict(torch.load(path, map_location=device))
+     return model.to(device).eval()
+
+ def load_pie_model(device):
+     # Resolve the weight file, downloading it on first use
+     path = download_weight("pie_regressor_stable.pth")
+     model = PieRegressor()
+     model.load_state_dict(torch.load(path, map_location=device))
+     return model.to(device).eval()
+
+ def load_hbar_model(device):
+     # Resolve the weight file, downloading it on first use
+     path = download_weight("hbar_regressor_stable.pth")
+     model = BarRegressor()
+     model.load_state_dict(torch.load(path, map_location=device))
+     return model.to(device).eval()
+
+ def load_vbar_model(device):
+     # Resolve the weight file, downloading it on first use
+     path = download_weight("vbar_regressor.pth")
+     model = BarRegressor()
+     model.load_state_dict(torch.load(path, map_location=device))
+     return model.to(device).eval()
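Because `download_weight` is a plain cached download into `~/.graphvision_weights`, the checkpoints can be pre-fetched before first use — handy on machines that will later run offline. A small warm-up sketch using only names defined in this module:

    # Pre-fetch all four checkpoints so GraphAnalyzer() does no network I/O later.
    from graphvision.model import WEIGHT_URLS, download_weight

    for name in WEIGHT_URLS:
        path = download_weight(name)   # no-op if the file is already cached
        print(f"{name} -> {path}")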
graphvision_ai.egg-info/PKG-INFO
@@ -0,0 +1,12 @@
+ Metadata-Version: 2.4
+ Name: graphvision-ai
+ Version: 0.1.0
+ Summary: Automatic Graph Classification and Data Extraction
+ Author: Aryan Gahlot
+ Description-Content-Type: text/markdown
+ Requires-Dist: torch
+ Requires-Dist: torchvision
+ Requires-Dist: opencv-python
+ Requires-Dist: easyocr
+ Requires-Dist: Pillow
+ Requires-Dist: numpy
graphvision_ai.egg-info/SOURCES.txt
@@ -0,0 +1,10 @@
+ README.md
+ pyproject.toml
+ graphvision/__init__.py
+ graphvision/analyzer.py
+ graphvision/model.py
+ graphvision_ai.egg-info/PKG-INFO
+ graphvision_ai.egg-info/SOURCES.txt
+ graphvision_ai.egg-info/dependency_links.txt
+ graphvision_ai.egg-info/requires.txt
+ graphvision_ai.egg-info/top_level.txt
graphvision_ai.egg-info/requires.txt
@@ -0,0 +1,6 @@
+ torch
+ torchvision
+ opencv-python
+ easyocr
+ Pillow
+ numpy
graphvision_ai.egg-info/top_level.txt
@@ -0,0 +1 @@
+ graphvision
pyproject.toml
@@ -0,0 +1,21 @@
+ [build-system]
+ requires = ["setuptools>=61.0"]
+ build-backend = "setuptools.build_meta"
+
+ [project]
+ name = "graphvision-ai"
+ version = "0.1.0"
+ description = "Automatic Graph Classification and Data Extraction"
+ readme = "README.md"
+ authors = [{name="Aryan Gahlot"}]
+ dependencies = [
+     "torch",
+     "torchvision",
+     "opencv-python",
+     "easyocr",
+     "Pillow",
+     "numpy"
+ ]
+
+ [tool.setuptools.packages.find]
+ include = ["graphvision"]
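With this configuration the sdist builds under any PEP 517 frontend; assuming the name on the registry matches `[project].name`, installation is the usual:

    pip install graphvision-ai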
setup.cfg
@@ -0,0 +1,4 @@
+ [egg_info]
+ tag_build =
+ tag_date = 0
+