graphvision-ai 0.1.0 (graphvision_ai-0.1.0.tar.gz)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- graphvision_ai-0.1.0/PKG-INFO +12 -0
- graphvision_ai-0.1.0/README.md +0 -0
- graphvision_ai-0.1.0/graphvision/__init__.py +0 -0
- graphvision_ai-0.1.0/graphvision/analyzer.py +212 -0
- graphvision_ai-0.1.0/graphvision/model.py +90 -0
- graphvision_ai-0.1.0/graphvision_ai.egg-info/PKG-INFO +12 -0
- graphvision_ai-0.1.0/graphvision_ai.egg-info/SOURCES.txt +10 -0
- graphvision_ai-0.1.0/graphvision_ai.egg-info/dependency_links.txt +1 -0
- graphvision_ai-0.1.0/graphvision_ai.egg-info/requires.txt +6 -0
- graphvision_ai-0.1.0/graphvision_ai.egg-info/top_level.txt +1 -0
- graphvision_ai-0.1.0/pyproject.toml +21 -0
- graphvision_ai-0.1.0/setup.cfg +4 -0
graphvision_ai-0.1.0/PKG-INFO
@@ -0,0 +1,12 @@
+Metadata-Version: 2.4
+Name: graphvision-ai
+Version: 0.1.0
+Summary: Automatic Graph Classification and Data Extraction
+Author: Aryan Gahlot
+Description-Content-Type: text/markdown
+Requires-Dist: torch
+Requires-Dist: torchvision
+Requires-Dist: opencv-python
+Requires-Dist: easyocr
+Requires-Dist: Pillow
+Requires-Dist: numpy
graphvision_ai-0.1.0/README.md
File without changes

graphvision_ai-0.1.0/graphvision/__init__.py
File without changes
graphvision_ai-0.1.0/graphvision/analyzer.py
@@ -0,0 +1,212 @@
+import torch
+import cv2
+import easyocr
+import numpy as np
+from PIL import Image
+from torchvision import transforms
+
+from .model import (
+    load_classifier,
+    load_hbar_model,
+    load_vbar_model,
+    load_pie_model
+)
+
+class GraphAnalyzer:
+    def __init__(self):
+        # Automatically use Mac MPS, Nvidia CUDA, or CPU
+        self.device = torch.device("mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu"))
+
+        # 1. Load Models
+        self.classifier = load_classifier(self.device)
+        self.hbar_model = load_hbar_model(self.device)
+        self.vbar_model = load_vbar_model(self.device)
+        self.pie_model = load_pie_model(self.device)
+
+        # 2. Setup OCR Engine
+        self.reader = easyocr.Reader(['en'])
+
+        # 3. CRITICAL: Setup strict ResNet Normalization
+        self.data_transforms = transforms.Compose([
+            transforms.Resize((224, 224)),
+            transforms.ToTensor(),
+            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
+        ])
+
+        # Adjust these indices to match exactly how the notebook classifier sorted the folders
+        self.class_map = {
+            0: "dot_line",
+            1: "hbar_categorical",
+            2: "line",
+            3: "pie",
+            4: "vbar_categorical"
+        }
+
+    def analyze(self, image):
+        """Main routing function. Classifies the image and extracts data."""
+
+        # --- ROBUST IMAGE LOADING ---
+        if isinstance(image, str):
+            img_pil = Image.open(image).convert('RGB')
+            cv_img = cv2.imread(image)
+        elif isinstance(image, np.ndarray):
+            cv_img = image
+            img_pil = Image.fromarray(cv2.cvtColor(cv_img, cv2.COLOR_BGR2RGB))
+        elif isinstance(image, Image.Image):
+            img_pil = image.convert('RGB')
+            cv_img = cv2.cvtColor(np.array(img_pil), cv2.COLOR_RGB2BGR)
+        else:
+            raise TypeError("Input must be a file path (str), NumPy array (cv2), or PIL Image.")
+
+        # 1. Classify the graph type
+        image_tensor = self.data_transforms(img_pil).unsqueeze(0).to(self.device)
+
+        with torch.no_grad():
+            outputs = self.classifier(image_tensor)
+            predicted_class = torch.argmax(outputs, dim=1).item()
+
+        graph_type = self.class_map.get(predicted_class, "unknown")
+
+        # 2. Route to the correct OCR + Prediction logic
+        if graph_type == "pie":
+            return self._extract_pie_data(cv_img, image_tensor)
+        elif graph_type == "vbar_categorical":
+            return self._extract_vbar_data(cv_img, image_tensor)
+        elif graph_type == "hbar_categorical":
+            return self._extract_hbar_data(cv_img, image_tensor)
+        else:
+            return {"type": graph_type, "status": f"Extraction logic for {graph_type} is not completely implemented in this snippet"}
+
+    def _extract_pie_data(self, cv_img, image_tensor):
+        """The dedicated Pie Chart extraction logic"""
+        with torch.no_grad():
+            preds = self.pie_model(image_tensor).squeeze().cpu().numpy() * 100.0
+
+        h, w, _ = cv_img.shape
+        gray = cv2.cvtColor(cv_img, cv2.COLOR_BGR2GRAY)
+        all_text_results = self.reader.readtext(gray, mag_ratio=2.5)
+
+        title = "Untitled"
+        raw_legend_names = []
+
+        for bbox, text, conf in all_text_results:
+            clean_text = text.strip()
+            if not clean_text: continue
+
+            x_center = (bbox[0][0] + bbox[2][0]) / 2
+            y_center = (bbox[0][1] + bbox[2][1]) / 2
+            x_pct = x_center / w
+            y_pct = y_center / h
+
+            if y_pct < 0.15:
+                title = clean_text
+            elif y_pct > 0.15:
+                if len(clean_text) > 2 and not clean_text.replace('.', '', 1).isdigit():
+                    if clean_text.lower() == "grav": clean_text = "Gray"
+                    raw_legend_names.append((y_pct, clean_text))
+
+        legend_names = [item[1] for item in sorted(raw_legend_names, key=lambda i: i[0])]
+
+        num_slices = len(legend_names)
+        if num_slices == 0:
+            valid_preds = [v for v in preds if v > 1.5]
+            num_slices = len(valid_preds)
+
+        num_slices = min(num_slices, 10)
+        slice_preds = preds[:num_slices]
+        total_pred = sum(slice_preds)
+
+        if total_pred > 0:
+            normalized_preds = [(v / total_pred) * 100.0 for v in slice_preds]
+        else:
+            normalized_preds = slice_preds
+
+        final_slices = {}
+        for i in range(num_slices):
+            slice_name = legend_names[i] if i < len(legend_names) else f"Unknown_Slice_{i+1}"
+            slice_value = round(normalized_preds[i], 2) if i < len(normalized_preds) else 0.0
+            final_slices[slice_name] = slice_value
+
+        return {
+            "type": "pie",
+            "title": title,
+            "data": final_slices
+        }
+
+    def _extract_vbar_data(self, cv_img, image_tensor):
+        """The dedicated Vertical Bar Chart extraction logic"""
+        with torch.no_grad():
+            preds = self.vbar_model(image_tensor).squeeze().cpu().numpy() * 100.0
+
+        h, w, _ = cv_img.shape
+
+        # 1. Title
+        title_res = self.reader.readtext(cv_img[0:int(h*0.15), :])
+        title = title_res[0][1] if title_res else "Untitled"
+
+        # 2. X-Axis Labels (Rotated crop to catch vertical text)
+        label_area = cv_img[int(h*0.70):h, :]
+        label_rot = cv2.rotate(label_area, cv2.ROTATE_90_COUNTERCLOCKWISE)
+        label_res = self.reader.readtext(label_rot)
+
+        x_labels = [res[1] for res in label_res if len(res[1]) > 1]
+        x_labels = x_labels[::-1]
+
+        # 3. Intelligent Matching (The dynamic fix)
+        max_val = np.max(preds)
+        noise_threshold = max_val * 0.05
+
+        valid_preds = [v for v in preds if v > noise_threshold]
+
+        count = max(len(x_labels), len(valid_preds))
+        count = min(count, 10)
+
+        final_values = [round(float(v), 2) for v in preds[:count]]
+
+        while len(x_labels) < len(final_values):
+            x_labels.append(f"Unknown_Label_{len(x_labels)+1}")
+
+        return {
+            "type": "vbar_categorical",
+            "title": title,
+            "x_axis_labels": x_labels,
+            "values": final_values
+        }
+
+    def _extract_hbar_data(self, cv_img, image_tensor):
+        """The dedicated Horizontal Bar Chart extraction logic"""
+        with torch.no_grad():
+            preds = self.hbar_model(image_tensor).squeeze().cpu().numpy() * 100.0
+
+        h, w, _ = cv_img.shape
+
+        # 1. Title
+        title_res = self.reader.readtext(cv_img[0:int(h*0.15), :])
+        title = title_res[0][1] if title_res else "Untitled"
+
+        # 2. Labels (Smooth Contrast applied)
+        label_area = cv_img[int(h*0.10):int(h*0.95), 0:int(w*0.45)]
+        gray = cv2.cvtColor(label_area, cv2.COLOR_BGR2GRAY)
+        contrast_img = cv2.convertScaleAbs(gray, alpha=1.5, beta=-50)
+
+        label_res = self.reader.readtext(contrast_img, mag_ratio=2.0)
+        y_labels = [res[1] for res in label_res if len(res[1]) > 1]
+        y_labels = y_labels[::-1]
+
+        # 3. Intelligent Matching
+        valid_preds = [v for v in preds if v > 2.0]
+
+        count = len(y_labels) if len(y_labels) > 0 else len(valid_preds)
+        count = min(count, 10)
+
+        final_values = [round(float(v), 2) for v in preds[:count]]
+
+        while len(y_labels) < len(final_values):
+            y_labels.append(f"Unknown_Label_{len(y_labels)+1}")
+
+        return {
+            "type": "hbar_categorical",
+            "title": title,
+            "y_axis_labels": y_labels,
+            "values": final_values
+        }
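For orientation, here is a minimal usage sketch of the GraphAnalyzer class added above. This is not part of the package diff; the file name chart.png is a placeholder, and the import path follows the package layout shown in SOURCES.txt below.

    from graphvision.analyzer import GraphAnalyzer

    analyzer = GraphAnalyzer()              # first construction triggers the one-time weight download in model.py
    result = analyzer.analyze("chart.png")  # accepts a file path, a cv2/NumPy array, or a PIL Image
    print(result["type"], result.get("title", ""))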
graphvision_ai-0.1.0/graphvision/model.py
@@ -0,0 +1,90 @@
+import torch
+import torch.nn as nn
+from torchvision import models
+import os
+import urllib.request
+import shutil
+
+# --- AUTO-DOWNLOAD WEIGHTS SETUP ---
+HOME_DIR = os.path.expanduser("~")
+# Create a hidden folder in the user's home directory to store weights securely
+WEIGHTS_DIR = os.path.join(HOME_DIR, ".graphvision_weights")
+os.makedirs(WEIGHTS_DIR, exist_ok=True)
+
+# Your Hugging Face Repository
+HF_BASE_URL = "https://huggingface.co/ShadowGard3n/graphvision/resolve/main/"
+
+WEIGHT_URLS = {
+    "graph_classifier.pth": f"{HF_BASE_URL}graph_classifier.pth",
+    "pie_regressor_stable.pth": f"{HF_BASE_URL}pie_regressor_stable.pth",
+    "hbar_regressor_stable.pth": f"{HF_BASE_URL}hbar_regressor_stable.pth",
+    "vbar_regressor.pth": f"{HF_BASE_URL}vbar_regressor.pth"
+}
+
+def download_weight(filename):
+    """Downloads the weight file from Hugging Face if it doesn't already exist."""
+    filepath = os.path.join(WEIGHTS_DIR, filename)
+    if not os.path.exists(filepath):
+        print(f"Downloading {filename} from Hugging Face (This only happens once)...")
+        # Add a custom User-Agent header so Hugging Face doesn't block the request
+        req = urllib.request.Request(WEIGHT_URLS[filename], headers={'User-Agent': 'Mozilla/5.0'})
+        with urllib.request.urlopen(req) as response, open(filepath, 'wb') as out_file:
+            shutil.copyfileobj(response, out_file)
+    return filepath
+
+# --- 1. MODEL ARCHITECTURES ---
+class PieRegressor(nn.Module):
+    def __init__(self):
+        super(PieRegressor, self).__init__()
+        self.backbone = models.resnet18(weights=None)
+        num_ftrs = self.backbone.fc.in_features
+        self.backbone.fc = nn.Linear(num_ftrs, 10)
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, x):
+        return self.sigmoid(self.backbone(x))
+
+class BarRegressor(nn.Module):
+    def __init__(self):
+        super(BarRegressor, self).__init__()
+        self.backbone = models.resnet18(weights=None)
+        num_ftrs = self.backbone.fc.in_features
+        self.backbone.fc = nn.Linear(num_ftrs, 10)
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, x):
+        return self.sigmoid(self.backbone(x))
+
+# --- 2. SECURE LOADING FUNCTIONS ---
+def load_classifier(device):
+    # FIX: Use the download_weight function instead of os.path.join
+    path = download_weight("graph_classifier.pth")
+
+    model = models.resnet18(weights=None)
+    num_ftrs = model.fc.in_features
+    # Set to 5 classes to perfectly match the saved checkpoint
+    model.fc = nn.Linear(num_ftrs, 5)
+
+    model.load_state_dict(torch.load(path, map_location=device))
+    return model.to(device).eval()
+
+def load_pie_model(device):
+    # FIX: Use the download_weight function
+    path = download_weight("pie_regressor_stable.pth")
+    model = PieRegressor()
+    model.load_state_dict(torch.load(path, map_location=device))
+    return model.to(device).eval()
+
+def load_hbar_model(device):
+    # FIX: Use the download_weight function
+    path = download_weight("hbar_regressor_stable.pth")
+    model = BarRegressor()
+    model.load_state_dict(torch.load(path, map_location=device))
+    return model.to(device).eval()
+
+def load_vbar_model(device):
+    # FIX: Use the download_weight function
+    path = download_weight("vbar_regressor.pth")
+    model = BarRegressor()
+    model.load_state_dict(torch.load(path, map_location=device))
+    return model.to(device).eval()
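The loading functions above call torch.load on checkpoints fetched over the network. A hedged variant (a sketch, assuming PyTorch 1.13 or newer, where the weights_only flag exists) restricts unpickling to plain tensor data, which is all a state_dict needs:

    # Sketch, not the published code: weights_only=True refuses arbitrary
    # pickled objects and only reconstructs tensors (PyTorch >= 1.13).
    state_dict = torch.load(path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)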
graphvision_ai-0.1.0/graphvision_ai.egg-info/PKG-INFO
@@ -0,0 +1,12 @@
+Metadata-Version: 2.4
+Name: graphvision-ai
+Version: 0.1.0
+Summary: Automatic Graph Classification and Data Extraction
+Author: Aryan Gahlot
+Description-Content-Type: text/markdown
+Requires-Dist: torch
+Requires-Dist: torchvision
+Requires-Dist: opencv-python
+Requires-Dist: easyocr
+Requires-Dist: Pillow
+Requires-Dist: numpy
graphvision_ai-0.1.0/graphvision_ai.egg-info/SOURCES.txt
@@ -0,0 +1,10 @@
+README.md
+pyproject.toml
+graphvision/__init__.py
+graphvision/analyzer.py
+graphvision/model.py
+graphvision_ai.egg-info/PKG-INFO
+graphvision_ai.egg-info/SOURCES.txt
+graphvision_ai.egg-info/dependency_links.txt
+graphvision_ai.egg-info/requires.txt
+graphvision_ai.egg-info/top_level.txt
graphvision_ai-0.1.0/graphvision_ai.egg-info/dependency_links.txt
@@ -0,0 +1 @@
+
graphvision_ai-0.1.0/graphvision_ai.egg-info/top_level.txt
@@ -0,0 +1 @@
+graphvision
graphvision_ai-0.1.0/pyproject.toml
@@ -0,0 +1,21 @@
+[build-system]
+requires = ["setuptools>=61.0"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "graphvision-ai"
+version = "0.1.0"
+description = "Automatic Graph Classification and Data Extraction"
+readme = "README.md"
+authors = [{name="Aryan Gahlot"}]
+dependencies = [
+    "torch",
+    "torchvision",
+    "opencv-python",
+    "easyocr",
+    "Pillow",
+    "numpy"
+]
+
+[tool.setuptools.packages.find]
+include = ["graphvision"]
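A naming note implied by the metadata above: the distribution name is hyphenated while the import package (per top_level.txt) is not, so installation and import would look like this sketch:

    # Install: pip install graphvision-ai   (distribution name from [project].name)
    import graphvision                      # import name from top_level.txt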