graft_pytorch-0.1.7-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
graft/utils/loader.py ADDED
@@ -0,0 +1,293 @@
+ import torchvision
+ import torchvision.transforms as transforms
+ import torch
+ from torch.utils.data import Dataset
+ import numpy as np
+ import os
+ import PIL.Image as Image
+
+ class CustomDataset(Dataset):
+     """Wraps pre-loaded data/target tensors, optionally pushing them to a device."""
+     def __init__(self, data, target, device=None, transform=None, isreg=False):
+         self.transform = transform
+         self.isreg = isreg
+         if device is not None:
+             # Push the entire dataset to the given device, e.g. "cuda:0".
+             self.data = data.float().to(device)
+             if isreg:
+                 self.targets = target.float().to(device)
+             else:
+                 self.targets = target.long().to(device)
+         else:
+             self.data = data.float()
+             if isreg:
+                 self.targets = target.float()
+             else:
+                 self.targets = target.long()
+
+     def __len__(self):
+         return len(self.targets)
+
+     def __getitem__(self, idx):
+         if torch.is_tensor(idx):
+             idx = idx.tolist()
+         sample_data = self.data[idx]
+         label = self.targets[idx]
+         if self.transform is not None:
+             sample_data = self.transform(sample_data)
+         return (sample_data, label)
+
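+ # Usage sketch (illustrative, not part of the original file; assumes x_tensor
+ # and y_tensor are pre-loaded tensors):
+ #   ds = CustomDataset(x_tensor, y_tensor, device="cuda:0")
+ #   sample, label = ds[0]
+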
+ class standard_scaling:
+     """Minimal standard scaler: subtract the per-feature mean, divide by the std."""
+     def __init__(self):
+         self.std = None
+         self.mean = None
+
+     def fit_transform(self, data):
+         self.std = np.std(data, axis=0)
+         self.mean = np.mean(data, axis=0)
+         transformed_data = np.subtract(data, self.mean)
+         transformed_data = np.divide(transformed_data, self.std)
+         return transformed_data
+
+     def transform(self, data):
+         transformed_data = np.subtract(data, self.mean)
+         transformed_data = np.divide(transformed_data, self.std)
+         return transformed_data
+
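+ # Usage sketch (illustrative): fit on training data, reuse the same statistics
+ # for validation data:
+ #   scaler = standard_scaling()
+ #   x_train = scaler.fit_transform(x_train)
+ #   x_val = scaler.transform(x_val)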
+
+ class TinyImageNet(Dataset):
+     def __init__(self, root, split='train', transform=None):
+         """
+         Args:
+             root (string): Directory with the tiny-imagenet-200 layout.
+             split (string): Either 'train', 'val', or 'test'.
+             transform (callable, optional): Optional transform to be applied
+                 on a sample.
+         """
+         self.root_dir = root
+         self.split = split
+         self.transform = transform
+         # wnids.txt lists the WordNet class ids, one per line.
+         with open(os.path.join(self.root_dir, 'wnids.txt'), 'r') as f:
+             self.classes = f.read().strip().split()
+         self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}
+         self.class_name = self._get_names()
+         self.images = self._load_images()
+
+     def _load_images(self):
+         images = []
+         if self.split == 'train':
+             # Training images are grouped in per-class subdirectories.
+             for cls in self.classes:
+                 cls_dir = os.path.join(self.root_dir, self.split, cls, 'images')
+                 for image_file in os.listdir(cls_dir):
+                     image_path = os.path.join(cls_dir, image_file)
+                     images.append((image_path, self.class_to_idx[cls]))
+         elif self.split == 'val':
+             # Validation images live in one folder; val_annotations.txt maps
+             # each file name to its class id.
+             val_dir = os.path.join(self.root_dir, self.split, 'images')
+             image_to_cls = {}
+             with open(os.path.join(self.root_dir, self.split, 'val_annotations.txt'), 'r') as f:
+                 for line in f.read().strip().split('\n'):
+                     fields = line.split()
+                     image_to_cls[fields[0].strip()] = fields[1].strip()
+             for image_file in os.listdir(val_dir):
+                 image_path = os.path.join(val_dir, image_file)
+                 images.append((image_path, self.class_to_idx[image_to_cls[image_file]]))
+         return images
+
+     def __len__(self):
+         return len(self.images)
+
+     def __getitem__(self, idx):
+         img_path, label = self.images[idx]
+         image = Image.open(img_path).convert('RGB')
+         if self.transform:
+             image = self.transform(image)
+         return image, label
+
+     def _get_names(self):
+         # words.txt maps each class id to comma-separated names; keep the first.
+         entity_dict = {}
+         with open(os.path.join(self.root_dir, 'words.txt'), 'r') as file:
+             for line in file:
+                 key, value = line.strip().split('\t')
+                 entity_dict[key] = value.strip().split(',')[0]
+         return entity_dict
+
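+ # Usage sketch (illustrative, assumes the standard tiny-imagenet-200 layout on disk):
+ #   val_ds = TinyImageNet('tiny-imagenet-200', split='val', transform=transforms.ToTensor())
+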
+ def loader(dataset, dirs="./cifar10", trn_batch_size=64, val_batch_size=64, tst_batch_size=1000):
+     """Build and return data loaders for the specified dataset.
+
+     `dirs` is only used by the "imagenet" branch (as the ImageFolder root);
+     the CIFAR datasets are downloaded to ./data.
+     """
+     if dataset.lower() == "cifar10":
+         transform_train = transforms.Compose([
+             transforms.RandomCrop(32, padding=4),
+             transforms.RandomHorizontalFlip(),
+             transforms.ToTensor(),
+             transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
+         ])
+
+         transform_test = transforms.Compose([
+             transforms.ToTensor(),
+             transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
+         ])
+
+         trainset = torchvision.datasets.CIFAR10(
+             root='./data', train=True, download=True, transform=transform_train)
+         trainloader = torch.utils.data.DataLoader(
+             trainset, batch_size=trn_batch_size, shuffle=True, num_workers=2)
+
+         testset = torchvision.datasets.CIFAR10(
+             root='./data', train=False, download=True, transform=transform_test)
+         testloader = torch.utils.data.DataLoader(
+             testset, batch_size=tst_batch_size, shuffle=False, num_workers=2)
+
+         return trainloader, testloader, trainset, testset
+
+     elif dataset.lower() == "cifar100":
+         transform_train = transforms.Compose([
+             transforms.RandomCrop(32, padding=4),
+             transforms.RandomHorizontalFlip(),
+             transforms.ToTensor(),
+             transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
+         ])
+
+         transform_test = transforms.Compose([
+             transforms.ToTensor(),
+             transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
+         ])
+
+         trainset = torchvision.datasets.CIFAR100(
+             root='./data', train=True, download=True, transform=transform_train)
+         trainloader = torch.utils.data.DataLoader(
+             trainset, batch_size=trn_batch_size, shuffle=True, num_workers=2)
+
+         testset = torchvision.datasets.CIFAR100(
+             root='./data', train=False, download=True, transform=transform_test)
+         testloader = torch.utils.data.DataLoader(
+             testset, batch_size=tst_batch_size, shuffle=False, num_workers=2)
+
+         return trainloader, testloader, trainset, testset
+
+     elif dataset.lower() == "imagenet":
+         normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                          std=[0.229, 0.224, 0.225])
+
+         train_transform = transforms.Compose([
+             transforms.RandomResizedCrop(224),
+             transforms.RandomHorizontalFlip(),
+             transforms.ToTensor(),
+             normalize,
+         ])
+
+         val_transform = transforms.Compose([
+             transforms.Resize(256),
+             transforms.CenterCrop(224),
+             transforms.ToTensor(),
+             normalize,
+         ])
+
+         train_dir = os.path.join(dirs, 'train')
+         val_dir = os.path.join(dirs, 'val')
+
+         train_dataset = torchvision.datasets.ImageFolder(train_dir, train_transform)
+         val_dataset = torchvision.datasets.ImageFolder(val_dir, val_transform)
+
+         train_loader = torch.utils.data.DataLoader(
+             train_dataset, batch_size=trn_batch_size, shuffle=True,
+             num_workers=4, pin_memory=True)
+
+         val_loader = torch.utils.data.DataLoader(
+             val_dataset, batch_size=val_batch_size, shuffle=False,
+             num_workers=4, pin_memory=True)
+
+         return train_loader, val_loader, train_dataset, val_dataset
+
+     else:
+         raise ValueError(f"Dataset {dataset} not supported")
+
+
+ if __name__ == "__main__":
+     # Basic smoke tests: download both CIFAR datasets and build their loaders.
+     try:
+         train_l, test_l, train_s, test_s = loader("cifar10")
+         train_l, test_l, train_s, test_s = loader("cifar100")
+         print("✓ All test cases passed successfully!")
+     except Exception as e:
+         print(f"✗ Test failed: {str(e)}")
graft/utils/model_mapper.py ADDED
@@ -0,0 +1,45 @@
+ from ..models import (ResNet18, ResNet50, EfficientNetB0, ResNet9, MobileNetV2,
+                       MobileNet, ResNet101, ResNet152, ResNeXt29_32x4d,
+                       ResNext50_32x4d, ResNext101_32x8d, ResNext101_64x4d,
+                       FashionCNN, bertmodel)
+
+
+ class ModelMapper:
+     def __init__(self, args):
+         self.args = args
+         self.model_mapping = {
+             "resnet18": ResNet18,
+             "resnet50": ResNet50,
+             "resnet9": ResNet9,
+             "mobilenetv2": MobileNetV2,
+             "mobilenet": MobileNet,
+             "resnet101": ResNet101,
+             "resnet152": ResNet152,
+             "efficientnetb0": EfficientNetB0,
+             "resnext": ResNeXt29_32x4d,
+             "resnext50": ResNext50_32x4d,
+             "resnext101_32": ResNext101_32x8d,
+             "resnext101_64": ResNext101_64x4d,
+             "fashioncnn": FashionCNN,
+             # "twolayernet": TwoLayerNet,
+             # "threelayernet": ThreeLayerNet,
+             "bert": bertmodel,
+         }
+
+     def get_model(self):
+         # Models are grouped by constructor signature.
+         model_group1 = ["resnext", "resnet9", "fashioncnn", "twolayernet", "threelayernet"]
+         model_group2 = ["bert"]
+
+         model_name = self.args.model.lower()
+         if model_name not in self.model_mapping:
+             raise ValueError(f"Model {self.args.model} is not available at this moment")
+         model_class = self.model_mapping[model_name]
+         if model_name in model_group1:
+             # These constructors take (in_channels, num_classes).
+             return model_class(self.args.in_chanls, self.args.numClasses).to(self.args.device)
+         elif model_name in model_group2:
+             # The BERT wrapper takes (device, num_classes).
+             return model_class(self.args.device, self.args.numClasses).to(self.args.device)
+         else:
+             return model_class(self.args.numClasses).to(self.args.device)
+
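+ # Usage sketch (illustrative, not part of the original file): `args` is assumed
+ # to expose .model, .numClasses, .device and, for group-1 models, .in_chanls:
+ #   model = ModelMapper(args).get_model()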
graft/utils/pickler.py ADDED
@@ -0,0 +1,27 @@
+ import pickle
+ import os
+
+ class pickler():
+
+     @staticmethod
+     def save_pickle(V, dirs, bs) -> None:
+         # Pickle V into <dirs>_pickle/V_<bs>.pkl, creating the folder if needed.
+         folder_name = f"{dirs}" + "_pickle"
+         if not os.path.exists(folder_name):
+             os.mkdir(folder_name)
+
+         file = os.path.join(folder_name, f"V_{bs}.pkl")
+         with open(file, 'wb') as f:
+             pickle.dump(V, f)
+
+         if os.path.exists(file):
+             print("pickle saved successfully")
+         else:
+             print("Failed to save pickle.")
+
+     @staticmethod
+     def load_pickle(dirs, bs):
+         file = os.path.join(f"{dirs}" + "_pickle", f"V_{bs}.pkl")
+         with open(file, 'rb') as f:
+             data = pickle.load(f)
+         return data
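+
+ # Usage sketch (illustrative, not part of the original file): round-trip any
+ # picklable object V through <dirs>_pickle/V_<bs>.pkl:
+ #   pickler.save_pickle(V, "./cifar10", 128)
+ #   V = pickler.load_pickle("./cifar10", 128)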
graft_pytorch-0.1.7.dist-info/METADATA ADDED
@@ -0,0 +1,302 @@
+ Metadata-Version: 2.4
+ Name: graft-pytorch
+ Version: 0.1.7
+ Summary: Gradient-Aware Fast MaxVol Technique for Dynamic Data Sampling
+ Home-page: https://github.com/ashishjv1/GRAFT
+ Author: Ashish Jha
+ Author-email: Ashish Jha <Ashish.Jha@skoltech.ru>
+ Maintainer-email: Ashish Jha <Ashish.Jha@skoltech.ru>
+ License: MIT
+ Project-URL: Homepage, https://github.com/ashishjv1/GRAFT
+ Project-URL: Repository, https://github.com/ashishjv1/GRAFT
+ Project-URL: Bug Reports, https://github.com/ashishjv1/GRAFT/issues
+ Project-URL: Documentation, https://github.com/ashishjv1/GRAFT/blob/main/README.md
+ Keywords: machine-learning,deep-learning,pytorch,data-sampling,gradient-based-sampling
+ Classifier: Development Status :: 3 - Alpha
+ Classifier: Intended Audience :: Developers
+ Classifier: Intended Audience :: Science/Research
+ Classifier: License :: OSI Approved :: MIT License
+ Classifier: Operating System :: OS Independent
+ Classifier: Programming Language :: Python :: 3
+ Classifier: Programming Language :: Python :: 3.8
+ Classifier: Programming Language :: Python :: 3.9
+ Classifier: Programming Language :: Python :: 3.10
+ Classifier: Programming Language :: Python :: 3.11
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+ Classifier: Topic :: Software Development :: Libraries :: Python Modules
+ Requires-Python: >=3.8
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ Requires-Dist: torch>=1.9.0
+ Requires-Dist: torchvision>=0.10.0
+ Requires-Dist: numpy>=1.19.2
+ Requires-Dist: tqdm>=4.62.3
+ Requires-Dist: scikit-learn>=0.24.2
+ Requires-Dist: pillow>=8.3.1
+ Requires-Dist: matplotlib>=3.4.3
+ Requires-Dist: transformers>=4.0.0
+ Requires-Dist: medmnist>=2.0.0
+ Provides-Extra: dev
+ Requires-Dist: pytest>=6.2.5; extra == "dev"
+ Requires-Dist: pytest-cov; extra == "dev"
+ Requires-Dist: black; extra == "dev"
+ Requires-Dist: isort; extra == "dev"
+ Requires-Dist: flake8; extra == "dev"
+ Provides-Extra: tracking
+ Requires-Dist: wandb>=0.12.0; extra == "tracking"
+ Requires-Dist: eco2ai>=1.0.0; extra == "tracking"
+ Provides-Extra: all
+ Requires-Dist: pytest>=6.2.5; extra == "all"
+ Requires-Dist: pytest-cov; extra == "all"
+ Requires-Dist: black; extra == "all"
+ Requires-Dist: isort; extra == "all"
+ Requires-Dist: flake8; extra == "all"
+ Requires-Dist: wandb>=0.12.0; extra == "all"
+ Requires-Dist: eco2ai>=1.0.0; extra == "all"
+ Dynamic: author
+ Dynamic: home-page
+ Dynamic: license-file
+ Dynamic: requires-python
+
+ # GRAFT: Gradient-Aware Fast MaxVol Technique for Dynamic Data Sampling
+
+ [![PyPI version](https://badge.fury.io/py/graft-pytorch.svg)](https://badge.fury.io/py/graft-pytorch)
+ [![Python 3.8+](https://img.shields.io/badge/python-3.8+-blue.svg)](https://www.python.org/downloads/)
+ [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
+
+ A PyTorch implementation of smart sampling for efficient deep learning training.
+
+ ## Overview
+ GRAFT uses gradient information and feature decomposition to select the most informative samples during training, reducing computation time while maintaining model performance.
+
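+ As a rough intuition (a minimal, illustrative sketch only, not GRAFT's actual MaxVol-based selection; `model`, `loss_fn`, and an unshuffled `scoring_loader` are assumed placeholders), samples whose loss gradients are large tend to be more informative:
+
+ ```python
+ import torch
+
+ def gradient_norm_scores(model, scoring_loader, loss_fn, device="cuda"):
+     """Cheap importance proxy: norm of the loss gradient w.r.t. each sample's logits."""
+     scores = []
+     for x, y in scoring_loader:  # must NOT shuffle, so scores align with dataset order
+         x, y = x.to(device), y.to(device)
+         logits = model(x)
+         grads = torch.autograd.grad(loss_fn(logits, y), logits)[0]
+         scores.append(grads.norm(dim=1).detach().cpu())
+     return torch.cat(scores)
+
+ # Keep the highest-scoring 50% of samples.
+ scores = gradient_norm_scores(model, scoring_loader, torch.nn.CrossEntropyLoss())
+ keep_indices = scores.topk(int(0.5 * len(scores))).indices
+ ```
+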
+ ## Features
+ - **Smart sample selection** using gradient-based importance scoring
+ - **Multi-architecture support** (ResNet, ResNeXt, EfficientNet, BERT)
+ - **Dataset compatibility** (CIFAR10/100, TinyImageNet, Caltech256, medical datasets)
+ - **Experiment tracking** with Weights & Biases integration
+ - **Carbon footprint tracking** with eco2AI
+ - **Efficient training** with reduced computational overhead
+
+ ## Installation
+
+ ### From PyPI (Recommended)
+ ```bash
+ pip install graft-pytorch
+ ```
+
+ ### With optional dependencies
+ ```bash
+ # For experiment tracking
+ pip install graft-pytorch[tracking]
+
+ # For development
+ pip install graft-pytorch[dev]
+
+ # Everything
+ pip install graft-pytorch[all]
+ ```
+
+ ### From Source
+ ```bash
+ git clone https://github.com/ashishjv1/GRAFT.git
+ cd GRAFT
+ pip install -e .
+ ```
+
+ ## Quick Start
+
+ ### Command Line Interface
+ ```bash
+ # Install and train with smart sampling
+ pip install graft-pytorch
+
+ # Basic training with GRAFT sampling on CIFAR-10
+ graft-train \
+     --numEpochs=200 \
+     --batch_size=128 \
+     --device="cuda" \
+     --optimizer="sgd" \
+     --lr=0.1 \
+     --numClasses=10 \
+     --dataset="cifar10" \
+     --model="resnet18" \
+     --fraction=0.5 \
+     --select_iter=25 \
+     --warm_start
+ ```
+
+ ### Python API
+ ```python
+ import torch
+ from graft import ModelTrainer, TrainingConfig
+ from graft.utils.loader import loader
+
+ # Load your dataset
+ trainloader, valloader, trainset, valset = loader(
+     dataset="cifar10",
+     trn_batch_size=128,
+     val_batch_size=128
+ )
+
+ # Configure training with GRAFT
+ config = TrainingConfig(
+     numEpochs=100,
+     batch_size=128,
+     device="cuda" if torch.cuda.is_available() else "cpu",
+     model_name="resnet18",
+     dataset_name="cifar10",
+     trainloader=trainloader,
+     valloader=valloader,
+     trainset=trainset,
+     optimizer_name="sgd",
+     lr=0.1,
+     fraction=0.5,        # Use 50% of data per epoch
+     selection_iter=25,   # Reselect samples every 25 epochs
+     warm_start=True      # Train on full data initially
+ )
+
+ # Train with smart sampling
+ trainer = ModelTrainer(config, trainloader, valloader, trainset)
+ train_stats, val_stats = trainer.train()
+
+ print(f"Best validation accuracy: {val_stats['best_acc']:.2%}")
+ ```
+
+ ### Advanced Usage
+ ```python
+ from graft import feature_sel, sample_selection
+
+ # Custom model and data selection
+ model = MyCustomModel()
+ data3 = feature_sel(dataloader, batch_size=128, device="cuda")
+
+ # Manual sample selection
+ selected_indices = sample_selection(
+     dataloader, data3, model, model.state_dict(),
+     batch_size=128, fraction=0.3, select_iter=10,
+     numEpochs=200, device="cuda", dataset="custom"
+ )
+ ```
+
+ ## Functionality Overview
+
+ ### Core Components
+
+ #### 1. Smart Sample Selection
+ - **`sample_selection()`**: selects the most informative samples using gradient-based importance (see the sketch below)
+ - **`feature_sel()`**: performs feature decomposition for efficient sampling
+ - Reduces training time by 30-50% while maintaining model performance
+
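+ Assuming `trainloader`, `trainset`, and `model` are already built (e.g. via `loader()` and `ModelMapper`), a minimal selection-then-retrain sketch, following the call signatures from the Advanced Usage example above, looks like this:
+
+ ```python
+ from torch.utils.data import DataLoader, Subset
+ from graft import feature_sel, sample_selection
+
+ # Decompose features once, then pick an informative subset of the training set.
+ data3 = feature_sel(trainloader, batch_size=128, device="cuda")
+ selected_indices = sample_selection(
+     trainloader, data3, model, model.state_dict(),
+     batch_size=128, fraction=0.5, select_iter=25,
+     numEpochs=200, device="cuda", dataset="cifar10",
+ )
+
+ # Train on the selected subset only.
+ subset_loader = DataLoader(Subset(trainset, selected_indices),
+                            batch_size=128, shuffle=True)
+ ```
+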
+ #### 2. Supported Models
+ - **Vision Models**: ResNet, ResNeXt, EfficientNet, MobileNet, FashionCNN
+ - **Language Models**: BERT for sequence classification
+ - **Custom Models**: easy integration with any PyTorch model
+
+ #### 3. Dataset Support
+ - **Computer Vision**: CIFAR-10/100, TinyImageNet, Caltech256
+ - **Medical Imaging**: integration with MedMNIST datasets
+ - **Custom Datasets**: support for any PyTorch DataLoader
+
+ #### 4. Training Features
+ - **Dynamic Sampling**: adaptive sample selection during training
+ - **Warm Starting**: begin with the full dataset, then switch to sampling
+ - **Experiment Tracking**: built-in WandB integration
+ - **Carbon Tracking**: monitor environmental impact with eco2AI
+
+ ### Configuration Parameters
+
+ | Parameter | Description | Default | Options |
+ |-----------|-------------|---------|---------|
+ | `numEpochs` | Training epochs | 200 | any integer |
+ | `batch_size` | Batch size | 128 | 32, 64, 128, 256+ |
+ | `device` | Computing device | "cuda" | "cpu", "cuda" |
+ | `model` | Model architecture | "resnet18" | "resnet18", "resnet50", "resnext", "efficientnet", ... |
+ | `fraction` | Data sampling ratio | 0.5 | 0.1 - 1.0 |
+ | `select_iter` | Reselection frequency (epochs) | 25 | any integer |
+ | `optimizer` | Optimization algorithm | "sgd" | "sgd", "adam" |
+ | `lr` | Learning rate | 0.1 | 0.001 - 0.1 |
+ | `warm_start` | Use full data initially | False | True/False |
+ | `decomp` | Decomposition backend | "numpy" | "numpy", "torch" |
+
+ ### Performance Benefits
+
+ - **Speed**: 30-50% faster training time
+ - **Memory**: reduced memory usage through smart sampling
+ - **Accuracy**: maintains or improves model performance
+ - **Efficiency**: lower carbon footprint and energy consumption
+
+ ## Package Structure
+ ```
+ graft-pytorch/
+ ├── graft/
+ │   ├── __init__.py          # Main package exports
+ │   ├── trainer.py           # Training orchestration
+ │   ├── genindices.py        # Sample selection algorithms
+ │   ├── decompositions.py    # Feature decomposition
+ │   ├── models/              # Supported architectures
+ │   │   ├── resnet.py        # ResNet implementations
+ │   │   ├── efficientnet.py  # EfficientNet models
+ │   │   └── BERT_model.py    # BERT for classification
+ │   └── utils/               # Utility functions
+ │       ├── loader.py        # Dataset loaders
+ │       └── model_mapper.py  # Model selection
+ ├── tests/                   # Comprehensive test suite
+ ├── examples/                # Usage examples
+ └── OIDC_SETUP.md            # Deployment configuration
+ ```
+
+ ## Contributing
+
+ We welcome contributions! Please see our [contribution guidelines](CONTRIBUTING.md) for details.
+
+ ### Development Setup
+ ```bash
+ # Clone the repository
+ git clone https://github.com/ashishjv1/GRAFT.git
+ cd GRAFT
+
+ # Install in development mode
+ pip install -e .[dev]
+
+ # Run tests
+ pytest tests/ -v
+
+ # Run linting
+ flake8 graft/ tests/
+ ```
+
+ ## License
+
+ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
+
+ ## Citation
+
+ If you use GRAFT in your research, please cite our paper:
+
+ ```bibtex
+ @misc{jha2025graftgradientawarefastmaxvol,
+     title = {GRAFT: Gradient-Aware Fast MaxVol Technique for Dynamic Data Sampling},
+     author = {Ashish Jha and Anh Huy Phan and Razan Dibo and Valentin Leplat},
+     year = {2025},
+     eprint = {2508.13653},
+     archivePrefix = {arXiv},
+     primaryClass = {cs.LG},
+     url = {https://arxiv.org/abs/2508.13653}
+ }
+ ```
+
+ ## Acknowledgments
+
+ - Built using PyTorch
+ - Inspired by MaxVol techniques for data sampling
+ - Special thanks to the open-source community
+
+ ---
+
+ **PyPI Package**: [graft-pytorch](https://pypi.org/project/graft-pytorch/)
+ **Paper**: [arXiv:2508.13653](https://arxiv.org/abs/2508.13653)
+ **Issues**: [GitHub Issues](https://github.com/ashishjv1/GRAFT/issues)
+ **Contact**: [Ashish Jha](mailto:Ashish.Jha@skoltech.ru)
+