graft_pytorch-0.1.7-py3-none-any.whl
This diff represents the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the package contents as they appear in the public registry.
- graft/__init__.py +20 -0
- graft/cli.py +62 -0
- graft/config.py +36 -0
- graft/decompositions.py +54 -0
- graft/genindices.py +122 -0
- graft/grad_dist.py +20 -0
- graft/models/BERT_model.py +40 -0
- graft/models/MobilenetV2.py +111 -0
- graft/models/ResNeXt.py +154 -0
- graft/models/__init__.py +22 -0
- graft/models/efficientnet.py +197 -0
- graft/models/efficientnetb7.py +268 -0
- graft/models/fashioncnn.py +69 -0
- graft/models/mobilenet.py +83 -0
- graft/models/resnet.py +564 -0
- graft/models/resnet9.py +72 -0
- graft/scheduler.py +63 -0
- graft/trainer.py +467 -0
- graft/utils/__init__.py +5 -0
- graft/utils/extras.py +37 -0
- graft/utils/generate.py +33 -0
- graft/utils/imagenetselloader.py +54 -0
- graft/utils/loader.py +293 -0
- graft/utils/model_mapper.py +45 -0
- graft/utils/pickler.py +27 -0
- graft_pytorch-0.1.7.dist-info/METADATA +302 -0
- graft_pytorch-0.1.7.dist-info/RECORD +31 -0
- graft_pytorch-0.1.7.dist-info/WHEEL +5 -0
- graft_pytorch-0.1.7.dist-info/entry_points.txt +2 -0
- graft_pytorch-0.1.7.dist-info/licenses/LICENSE +21 -0
- graft_pytorch-0.1.7.dist-info/top_level.txt +1 -0
graft/utils/loader.py
ADDED
@@ -0,0 +1,293 @@
import torchvision
import torchvision.transforms as transforms
import torch
from torch.utils.data import Dataset
import numpy as np
import os
import PIL.Image as Image


class CustomDataset(Dataset):
    """Dataset wrapper around pre-loaded data/target tensors."""

    def __init__(self, data, target, device=None, transform=None, isreg=False):
        self.transform = transform
        self.isreg = isreg
        if device is not None:
            # Push the entire dataset to the given device, e.g. cuda:0
            self.data = data.float().to(device)
            if isreg:
                self.targets = target.float().to(device)
            else:
                self.targets = target.long().to(device)
        else:
            self.data = data.float()
            if isreg:
                self.targets = target.float()
            else:
                self.targets = target.long()

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        sample_data = self.data[idx]
        label = self.targets[idx]
        if self.transform is not None:
            sample_data = self.transform(sample_data)
        return (sample_data, label)


class standard_scaling:
    """Per-feature standardization (zero mean, unit variance), sklearn-style."""

    def __init__(self):
        self.std = None
        self.mean = None

    def fit_transform(self, data):
        self.std = np.std(data, axis=0)
        self.mean = np.mean(data, axis=0)
        transformed_data = np.subtract(data, self.mean)
        transformed_data = np.divide(transformed_data, self.std)
        return transformed_data

    def transform(self, data):
        transformed_data = np.subtract(data, self.mean)
        transformed_data = np.divide(transformed_data, self.std)
        return transformed_data


class TinyImageNet(Dataset):
    def __init__(self, root, split='train', transform=None):
        """
        Args:
            root (string): Directory with all the images.
            split (string): Either 'train', 'val', or 'test'.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.root_dir = root
        self.split = split
        self.transform = transform
        with open(os.path.join(self.root_dir, 'wnids.txt'), 'r') as f:
            self.classes = f.read().strip().split()
        self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}
        self.class_name = self._get_names()
        self.images = self._load_images()

    def _load_images(self):
        images = []
        if self.split == 'train':
            for cls in self.classes:
                cls_dir = os.path.join(self.root_dir, self.split, cls, 'images')
                for image_file in os.listdir(cls_dir):
                    image_path = os.path.join(cls_dir, image_file)
                    images.append((image_path, self.class_to_idx[cls]))
        elif self.split == 'val':
            # Validation images live in a flat directory; val_annotations.txt
            # maps each file name to its class.
            val_dir = os.path.join(self.root_dir, self.split, 'images')
            image_to_cls = {}
            with open(os.path.join(self.root_dir, self.split, 'val_annotations.txt'), 'r') as f:
                for line in f.read().strip().split('\n'):
                    image_to_cls[line.split()[0].strip()] = line.split()[1].strip()
            for image_file in os.listdir(val_dir):
                image_path = os.path.join(val_dir, image_file)
                images.append((image_path, self.class_to_idx[image_to_cls[image_file]]))
        return images

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path, label = self.images[idx]
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, label

    def _get_names(self):
        entity_dict = {}
        # words.txt maps each WordNet id to a comma-separated list of names;
        # keep only the first name per id.
        with open(os.path.join(self.root_dir, 'words.txt'), 'r') as file:
            for line in file:
                key, value = line.strip().split('\t')
                first = value.strip().split(',')
                entity_dict[key] = first[0]
        return entity_dict


def loader(dataset, dirs="./cifar10", trn_batch_size=64, val_batch_size=64, tst_batch_size=1000):
    """Load and return data loaders for the specified dataset."""

    if dataset.lower() == "cifar10":
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])

        trainset = torchvision.datasets.CIFAR10(
            root='./data', train=True, download=True, transform=transform_train)
        trainloader = torch.utils.data.DataLoader(
            trainset, batch_size=trn_batch_size, shuffle=True, num_workers=2)

        testset = torchvision.datasets.CIFAR10(
            root='./data', train=False, download=True, transform=transform_test)
        testloader = torch.utils.data.DataLoader(
            testset, batch_size=tst_batch_size, shuffle=False, num_workers=2)

        return trainloader, testloader, trainset, testset

    elif dataset.lower() == "cifar100":
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
        ])

        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
        ])

        trainset = torchvision.datasets.CIFAR100(
            root='./data', train=True, download=True, transform=transform_train)
        trainloader = torch.utils.data.DataLoader(
            trainset, batch_size=trn_batch_size, shuffle=True, num_workers=2)

        testset = torchvision.datasets.CIFAR100(
            root='./data', train=False, download=True, transform=transform_test)
        testloader = torch.utils.data.DataLoader(
            testset, batch_size=tst_batch_size, shuffle=False, num_workers=2)

        return trainloader, testloader, trainset, testset

    elif dataset.lower() == "imagenet":
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])

        val_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])

        train_dir = os.path.join(dirs, 'train')
        val_dir = os.path.join(dirs, 'val')

        train_dataset = torchvision.datasets.ImageFolder(train_dir, train_transform)
        val_dataset = torchvision.datasets.ImageFolder(val_dir, val_transform)

        train_loader = torch.utils.data.DataLoader(
            train_dataset, batch_size=trn_batch_size, shuffle=True,
            num_workers=4, pin_memory=True)

        val_loader = torch.utils.data.DataLoader(
            val_dataset, batch_size=val_batch_size, shuffle=False,
            num_workers=4, pin_memory=True)

        return train_loader, val_loader, train_dataset, val_dataset

    else:
        raise ValueError(f"Dataset {dataset} not supported")


if __name__ == "__main__":
    # Run some basic smoke tests
    try:
        train_l, test_l, train_s, test_s = loader("cifar10")
        train_l, test_l, train_s, test_s = loader("cifar100")
        print("✓ All test cases passed successfully!")
    except Exception as e:
        print(f"✗ Test failed: {str(e)}")
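For orientation, here is a minimal usage sketch of the three entry points above. The tensor shapes, class count, and the `./tiny-imagenet-200` directory are illustrative assumptions, not fixed by the package; `loader()` itself downloads CIFAR data to `./data` on first use.

```python
import torch
import torchvision.transforms as transforms
from graft.utils.loader import loader, CustomDataset, TinyImageNet

# Convenience loader: returns (trainloader, testloader, trainset, testset)
trainloader, testloader, trainset, testset = loader("cifar10", trn_batch_size=128)

# Wrapping tensors that are already in memory (shapes are illustrative)
data = torch.randn(1000, 3, 32, 32)      # hypothetical images
targets = torch.randint(0, 10, (1000,))  # hypothetical labels
ds = CustomDataset(data, targets,
                   device="cuda" if torch.cuda.is_available() else None)

# TinyImageNet expects the standard tiny-imagenet-200 layout
# (wnids.txt, words.txt, train/, and val/ with val_annotations.txt)
tiny = TinyImageNet("./tiny-imagenet-200", split="val",
                    transform=transforms.ToTensor())
```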
graft/utils/model_mapper.py
ADDED
@@ -0,0 +1,45 @@
from ..models import (ResNet18, ResNet50, EfficientNetB0, ResNet9, MobileNetV2,
                      MobileNet, ResNet101, ResNet152, ResNeXt29_32x4d,
                      ResNext50_32x4d, ResNext101_32x8d, ResNext101_64x4d,
                      FashionCNN, bertmodel)


class ModelMapper:
    """Maps a model name from parsed arguments to an instantiated model."""

    def __init__(self, args):
        self.args = args
        self.model_mapping = {
            "resnet18": ResNet18,
            "resnet50": ResNet50,
            "resnet9": ResNet9,
            "mobilenetv2": MobileNetV2,
            "mobilenet": MobileNet,
            "resnet101": ResNet101,
            "resnet152": ResNet152,
            "efficientnetb0": EfficientNetB0,
            "resnext": ResNeXt29_32x4d,
            "resnext50": ResNext50_32x4d,
            "resnext101_32": ResNext101_32x8d,
            "resnext101_64": ResNext101_64x4d,
            "fashioncnn": FashionCNN,
            # "twolayernet": TwoLayerNet,
            # "threelayernet": ThreeLayerNet,
            "bert": bertmodel,
        }

    def get_model(self):
        # Models grouped by constructor signature:
        # group1 takes (in_channels, num_classes), group2 takes (device, num_classes),
        # everything else takes (num_classes,).
        model_group1 = ["resnext", "resnet9", "fashioncnn", "twolayernet", "threelayernet"]
        model_group2 = ["bert"]

        model_name = self.args.model.lower()
        if model_name in self.model_mapping:
            model_class = self.model_mapping[model_name]
            if model_name in model_group1:
                return model_class(self.args.in_chanls, self.args.numClasses).to(self.args.device)
            elif model_name in model_group2:
                return model_class(self.args.device, self.args.numClasses).to(self.args.device)
            else:
                return model_class(self.args.numClasses).to(self.args.device)
        else:
            raise ValueError(f"Model {self.args.model} is not available at this moment")
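A minimal sketch of driving `ModelMapper` without a CLI, using a plain namespace in place of parsed arguments. The attribute names (`model`, `numClasses`, `in_chanls`, `device`) are the ones the class reads; the values are illustrative.

```python
from types import SimpleNamespace
from graft.utils.model_mapper import ModelMapper

# Stand-in for argparse results; values are illustrative
args = SimpleNamespace(model="resnet18", numClasses=10, in_chanls=3, device="cpu")

# "resnet18" falls in the default branch of get_model(),
# so this builds ResNet18 with 10 classes and moves it to CPU
model = ModelMapper(args).get_model()
```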
graft/utils/pickler.py
ADDED
@@ -0,0 +1,27 @@
import pickle
import os


class pickler():
    """Helpers for saving/loading objects to '<dirs>_pickle/V_<bs>.pkl' files."""

    @staticmethod
    def save_pickle(V, dirs, bs) -> None:
        folder_name = f"{dirs}_pickle"
        if not os.path.exists(folder_name):
            os.mkdir(folder_name)

        file = os.path.join(folder_name, f"V_{bs}.pkl")
        with open(file, 'wb') as f:
            pickle.dump(V, f)

        if os.path.exists(file):
            print("pickle saved successfully")
        else:
            print("Failed to save pickle.")

    @staticmethod
    def load_pickle(dirs, bs):
        file = os.path.join(f"{dirs}_pickle", f"V_{bs}.pkl")
        with open(file, 'rb') as f:
            data = pickle.load(f)
        return data
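A round-trip sketch of the helpers above; with these (illustrative) arguments the file lands at `./cifar10_pickle/V_128.pkl`.

```python
from graft.utils.pickler import pickler

V = [[0.1, 0.2], [0.3, 0.4]]              # any picklable object, e.g. a feature matrix
pickler.save_pickle(V, "./cifar10", 128)  # writes ./cifar10_pickle/V_128.pkl
restored = pickler.load_pickle("./cifar10", 128)
assert restored == V
```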
graft_pytorch-0.1.7.dist-info/METADATA
ADDED
@@ -0,0 +1,302 @@
Metadata-Version: 2.4
Name: graft-pytorch
Version: 0.1.7
Summary: Gradient-Aware Fast MaxVol Technique for Dynamic Data Sampling
Home-page: https://github.com/ashishjv1/GRAFT
Author: Ashish Jha
Author-email: Ashish Jha <Ashish.Jha@skoltech.ru>
Maintainer-email: Ashish Jha <Ashish.Jha@skoltech.ru>
License: MIT
Project-URL: Homepage, https://github.com/ashishjv1/GRAFT
Project-URL: Repository, https://github.com/ashishjv1/GRAFT
Project-URL: Bug Reports, https://github.com/ashishjv1/GRAFT/issues
Project-URL: Documentation, https://github.com/ashishjv1/GRAFT/blob/main/README.md
Keywords: machine-learning,deep-learning,pytorch,data-sampling,gradient-based-sampling
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Topic :: Software Development :: Libraries :: Python Modules
Requires-Python: >=3.8
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: torch>=1.9.0
Requires-Dist: torchvision>=0.10.0
Requires-Dist: numpy>=1.19.2
Requires-Dist: tqdm>=4.62.3
Requires-Dist: scikit-learn>=0.24.2
Requires-Dist: pillow>=8.3.1
Requires-Dist: matplotlib>=3.4.3
Requires-Dist: transformers>=4.0.0
Requires-Dist: medmnist>=2.0.0
Provides-Extra: dev
Requires-Dist: pytest>=6.2.5; extra == "dev"
Requires-Dist: pytest-cov; extra == "dev"
Requires-Dist: black; extra == "dev"
Requires-Dist: isort; extra == "dev"
Requires-Dist: flake8; extra == "dev"
Provides-Extra: tracking
Requires-Dist: wandb>=0.12.0; extra == "tracking"
Requires-Dist: eco2ai>=1.0.0; extra == "tracking"
Provides-Extra: all
Requires-Dist: pytest>=6.2.5; extra == "all"
Requires-Dist: pytest-cov; extra == "all"
Requires-Dist: black; extra == "all"
Requires-Dist: isort; extra == "all"
Requires-Dist: flake8; extra == "all"
Requires-Dist: wandb>=0.12.0; extra == "all"
Requires-Dist: eco2ai>=1.0.0; extra == "all"
Dynamic: author
Dynamic: home-page
Dynamic: license-file
Dynamic: requires-python

# GRAFT: Gradient-Aware Fast MaxVol Technique for Dynamic Data Sampling

[PyPI version](https://badge.fury.io/py/graft-pytorch)
[Python 3.8+](https://www.python.org/downloads/)
[License: MIT](https://opensource.org/licenses/MIT)

A PyTorch implementation of smart sampling for efficient deep learning training.

## Overview

GRAFT uses gradient information and feature decomposition to select the most informative samples during training, reducing computation time while maintaining model performance.

## Features

- **Smart sample selection** using gradient-based importance scoring
- **Multi-architecture support** (ResNet, ResNeXt, EfficientNet, BERT)
- **Dataset compatibility** (CIFAR10/100, TinyImageNet, Caltech256, medical datasets)
- **Experiment tracking** with Weights & Biases integration
- **Carbon footprint tracking** with eco2AI
- **Efficient training** with reduced computational overhead

## Installation

### From PyPI (Recommended)
```bash
pip install graft-pytorch
```

### With optional dependencies
```bash
# For experiment tracking
pip install graft-pytorch[tracking]

# For development
pip install graft-pytorch[dev]

# Everything
pip install graft-pytorch[all]
```

### From Source
```bash
git clone https://github.com/ashishjv1/GRAFT.git
cd GRAFT
pip install -e .
```

## Quick Start

### Command Line Interface
```bash
# Install and train with smart sampling
pip install graft-pytorch

# Basic training with GRAFT sampling on CIFAR-10
graft-train \
    --numEpochs=200 \
    --batch_size=128 \
    --device="cuda" \
    --optimizer="sgd" \
    --lr=0.1 \
    --numClasses=10 \
    --dataset="cifar10" \
    --model="resnet18" \
    --fraction=0.5 \
    --select_iter=25 \
    --warm_start
```

### Python API
```python
import torch
from graft import ModelTrainer, TrainingConfig
from graft.utils.loader import loader

# Load your dataset
trainloader, valloader, trainset, valset = loader(
    dataset="cifar10",
    trn_batch_size=128,
    val_batch_size=128
)

# Configure training with GRAFT
config = TrainingConfig(
    numEpochs=100,
    batch_size=128,
    device="cuda" if torch.cuda.is_available() else "cpu",
    model_name="resnet18",
    dataset_name="cifar10",
    trainloader=trainloader,
    valloader=valloader,
    trainset=trainset,
    optimizer_name="sgd",
    lr=0.1,
    fraction=0.5,        # Use 50% of data per epoch
    selection_iter=25,   # Reselect samples every 25 epochs
    warm_start=True      # Train on full data initially
)

# Train with smart sampling
trainer = ModelTrainer(config, trainloader, valloader, trainset)
train_stats, val_stats = trainer.train()

print(f"Best validation accuracy: {val_stats['best_acc']:.2%}")
```

### Advanced Usage
```python
from graft import feature_sel, sample_selection
import torch.nn as nn

# Custom model and data selection
model = MyCustomModel()
data3 = feature_sel(dataloader, batch_size=128, device="cuda")

# Manual sample selection
selected_indices = sample_selection(
    dataloader, data3, model, model.state_dict(),
    batch_size=128, fraction=0.3, select_iter=10,
    numEpochs=200, device="cuda", dataset="custom"
)
```

## Functionality Overview

### Core Components

#### 1. Smart Sample Selection
- **`sample_selection()`**: Selects the most informative samples using gradient-based importance
- **`feature_sel()`**: Performs feature decomposition for efficient sampling
- Reduces training time by 30-50% while maintaining model performance

#### 2. Supported Models
- **Vision Models**: ResNet, ResNeXt, EfficientNet, MobileNet, FashionCNN
- **Language Models**: BERT for sequence classification
- **Custom Models**: Easy integration with any PyTorch model

#### 3. Dataset Support
- **Computer Vision**: CIFAR-10/100, TinyImageNet, Caltech256
- **Medical Imaging**: Integration with MedMNIST datasets
- **Custom Datasets**: Support for any PyTorch DataLoader

#### 4. Training Features
- **Dynamic Sampling**: Adaptive sample selection during training
- **Warm Starting**: Begin with the full dataset, then switch to sampling
- **Experiment Tracking**: Built-in WandB integration
- **Carbon Tracking**: Monitor environmental impact with eco2AI

### Configuration Parameters

| Parameter | Description | Default | Options |
|-----------|-------------|---------|---------|
| `numEpochs` | Training epochs | 200 | Any integer |
| `batch_size` | Batch size | 128 | 32, 64, 128, 256+ |
| `device` | Computing device | "cuda" | "cpu", "cuda" |
| `model` | Model architecture | "resnet18" | "resnet18/50", "resnext", "efficientnet" |
| `fraction` | Data sampling ratio | 0.5 | 0.1 - 1.0 |
| `select_iter` | Reselection frequency | 25 | Any integer |
| `optimizer` | Optimization algorithm | "sgd" | "sgd", "adam" |
| `lr` | Learning rate | 0.1 | 0.001 - 0.1 |
| `warm_start` | Use full data initially | False | True/False |
| `decomp` | Decomposition backend | "numpy" | "numpy", "torch" |

### Performance Benefits

- **Speed**: 30-50% faster training time
- **Memory**: Reduced memory usage through smart sampling
- **Accuracy**: Maintains or improves model performance
- **Efficiency**: Lower carbon footprint and energy consumption

## Package Structure
```
graft-pytorch/
├── graft/
│   ├── __init__.py          # Main package exports
│   ├── trainer.py           # Training orchestration
│   ├── genindices.py        # Sample selection algorithms
│   ├── decompositions.py    # Feature decomposition
│   ├── models/              # Supported architectures
│   │   ├── resnet.py        # ResNet implementations
│   │   ├── efficientnet.py  # EfficientNet models
│   │   └── BERT_model.py    # BERT for classification
│   └── utils/               # Utility functions
│       ├── loader.py        # Dataset loaders
│       └── model_mapper.py  # Model selection
├── tests/                   # Comprehensive test suite
├── examples/                # Usage examples
└── OIDC_SETUP.md            # Deployment configuration
```

## Contributing

We welcome contributions! Please see our [contribution guidelines](CONTRIBUTING.md) for details.

### Development Setup
```bash
# Clone the repository
git clone https://github.com/ashishjv1/GRAFT.git
cd GRAFT

# Install in development mode
pip install -e .[dev]

# Run tests
pytest tests/ -v

# Run linting
flake8 graft/ tests/
```

## License

This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.

## Citation

If you use GRAFT in your research, please cite our paper:

```bibtex
@misc{jha2025graftgradientawarefastmaxvol,
  title         = {GRAFT: Gradient-Aware Fast MaxVol Technique for Dynamic Data Sampling},
  author        = {Ashish Jha and Anh Huy Phan and Razan Dibo and Valentin Leplat},
  year          = {2025},
  eprint        = {2508.13653},
  archivePrefix = {arXiv},
  primaryClass  = {cs.LG},
  url           = {https://arxiv.org/abs/2508.13653}
}
```

## Acknowledgments

- Built using PyTorch
- Inspired by MaxVol techniques for data sampling
- Special thanks to the open-source community

---

**PyPI Package**: [graft-pytorch](https://pypi.org/project/graft-pytorch/)
**Paper**: [arXiv:2508.13653](https://arxiv.org/abs/2508.13653)
**Issues**: [GitHub Issues](https://github.com/ashishjv1/GRAFT/issues)
**Contact**: [Ashish Jha](mailto:Ashish.Jha@skoltech.ru)