graft-pytorch 0.1.7 (tar.gz)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- graft_pytorch-0.1.7/CITATION.cff +11 -0
- graft_pytorch-0.1.7/LICENSE +21 -0
- graft_pytorch-0.1.7/MANIFEST.in +36 -0
- graft_pytorch-0.1.7/PKG-INFO +302 -0
- graft_pytorch-0.1.7/README.md +242 -0
- graft_pytorch-0.1.7/examples/Finetune_BERT.ipynb +1413 -0
- graft_pytorch-0.1.7/examples/local_dataset_utilities.py +102 -0
- graft_pytorch-0.1.7/graft/__init__.py +20 -0
- graft_pytorch-0.1.7/graft/cli.py +62 -0
- graft_pytorch-0.1.7/graft/config.py +36 -0
- graft_pytorch-0.1.7/graft/decompositions.py +54 -0
- graft_pytorch-0.1.7/graft/genindices.py +122 -0
- graft_pytorch-0.1.7/graft/grad_dist.py +20 -0
- graft_pytorch-0.1.7/graft/models/BERT_model.py +40 -0
- graft_pytorch-0.1.7/graft/models/MobilenetV2.py +111 -0
- graft_pytorch-0.1.7/graft/models/ResNeXt.py +154 -0
- graft_pytorch-0.1.7/graft/models/__init__.py +22 -0
- graft_pytorch-0.1.7/graft/models/efficientnet.py +197 -0
- graft_pytorch-0.1.7/graft/models/efficientnetb7.py +268 -0
- graft_pytorch-0.1.7/graft/models/fashioncnn.py +69 -0
- graft_pytorch-0.1.7/graft/models/mobilenet.py +83 -0
- graft_pytorch-0.1.7/graft/models/resnet.py +564 -0
- graft_pytorch-0.1.7/graft/models/resnet9.py +72 -0
- graft_pytorch-0.1.7/graft/scheduler.py +63 -0
- graft_pytorch-0.1.7/graft/trainer.py +467 -0
- graft_pytorch-0.1.7/graft/utils/__init__.py +5 -0
- graft_pytorch-0.1.7/graft/utils/extras.py +37 -0
- graft_pytorch-0.1.7/graft/utils/generate.py +33 -0
- graft_pytorch-0.1.7/graft/utils/imagenetselloader.py +54 -0
- graft_pytorch-0.1.7/graft/utils/loader.py +293 -0
- graft_pytorch-0.1.7/graft/utils/model_mapper.py +45 -0
- graft_pytorch-0.1.7/graft/utils/pickler.py +27 -0
- graft_pytorch-0.1.7/graft_pytorch.egg-info/PKG-INFO +302 -0
- graft_pytorch-0.1.7/graft_pytorch.egg-info/SOURCES.txt +44 -0
- graft_pytorch-0.1.7/graft_pytorch.egg-info/dependency_links.txt +1 -0
- graft_pytorch-0.1.7/graft_pytorch.egg-info/entry_points.txt +2 -0
- graft_pytorch-0.1.7/graft_pytorch.egg-info/not-zip-safe +1 -0
- graft_pytorch-0.1.7/graft_pytorch.egg-info/requires.txt +29 -0
- graft_pytorch-0.1.7/graft_pytorch.egg-info/top_level.txt +1 -0
- graft_pytorch-0.1.7/pyproject.toml +110 -0
- graft_pytorch-0.1.7/requirements.txt +12 -0
- graft_pytorch-0.1.7/setup.cfg +4 -0
- graft_pytorch-0.1.7/setup.py +70 -0
- graft_pytorch-0.1.7/tests/test_genindices.py +130 -0
- graft_pytorch-0.1.7/tests/test_graft.py +141 -0
- graft_pytorch-0.1.7/tests/test_graft_e2e.py +106 -0

graft_pytorch-0.1.7/CITATION.cff
@@ -0,0 +1,11 @@
cff-version: 1.2.0
message: "If you use this software, please cite it as below."
authors:
- family-names: "Jha"
  given-names: "A"
  affiliation: "Skolkovo Institute of Science & Technology"
title: "GRAFT: Gradient-Aware Fast MaxVol Technique for Dynamic Data Sampling"
version: 1.0.0
doi: 10.5281/zenodo.16947530
date-released: 2025-01-01
url: "https://github.com/ashishjv1/GRAFT"

graft_pytorch-0.1.7/LICENSE
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 Ashish Jha

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

graft_pytorch-0.1.7/MANIFEST.in
@@ -0,0 +1,36 @@
# Include the README and other documentation files
include README.md
include LICENSE
include CITATION.cff
include requirements.txt
include pyproject.toml

# Include supplementary materials
recursive-include supplementary *

# Include example notebooks and scripts
recursive-include examples *.py *.ipynb

# Include test files
recursive-include tests *.py

# Exclude unnecessary files
global-exclude __pycache__
global-exclude *.py[co]
global-exclude *.so
global-exclude .DS_Store
global-exclude *.egg-info

# Exclude legacy files
exclude trainer.py
exclude GRAFT.py
exclude GRAFT_swinft.py

# Exclude build directories
prune build
prune dist
prune .git
prune .pytest_cache

# Include package data
recursive-include graft *.py

graft_pytorch-0.1.7/PKG-INFO
@@ -0,0 +1,302 @@
Metadata-Version: 2.4
Name: graft-pytorch
Version: 0.1.7
Summary: Gradient-Aware Fast MaxVol Technique for Dynamic Data Sampling
Home-page: https://github.com/ashishjv1/GRAFT
Author: Ashish Jha
Author-email: Ashish Jha <Ashish.Jha@skoltech.ru>
Maintainer-email: Ashish Jha <Ashish.Jha@skoltech.ru>
License: MIT
Project-URL: Homepage, https://github.com/ashishjv1/GRAFT
Project-URL: Repository, https://github.com/ashishjv1/GRAFT
Project-URL: Bug Reports, https://github.com/ashishjv1/GRAFT/issues
Project-URL: Documentation, https://github.com/ashishjv1/GRAFT/blob/main/README.md
Keywords: machine-learning,deep-learning,pytorch,data-sampling,gradient-based-sampling
Classifier: Development Status :: 3 - Alpha
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: MIT License
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Programming Language :: Python :: 3.10
Classifier: Programming Language :: Python :: 3.11
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Classifier: Topic :: Software Development :: Libraries :: Python Modules
Requires-Python: >=3.8
Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: torch>=1.9.0
Requires-Dist: torchvision>=0.10.0
Requires-Dist: numpy>=1.19.2
Requires-Dist: tqdm>=4.62.3
Requires-Dist: scikit-learn>=0.24.2
Requires-Dist: pillow>=8.3.1
Requires-Dist: matplotlib>=3.4.3
Requires-Dist: transformers>=4.0.0
Requires-Dist: medmnist>=2.0.0
Provides-Extra: dev
Requires-Dist: pytest>=6.2.5; extra == "dev"
Requires-Dist: pytest-cov; extra == "dev"
Requires-Dist: black; extra == "dev"
Requires-Dist: isort; extra == "dev"
Requires-Dist: flake8; extra == "dev"
Provides-Extra: tracking
Requires-Dist: wandb>=0.12.0; extra == "tracking"
Requires-Dist: eco2ai>=1.0.0; extra == "tracking"
Provides-Extra: all
Requires-Dist: pytest>=6.2.5; extra == "all"
Requires-Dist: pytest-cov; extra == "all"
Requires-Dist: black; extra == "all"
Requires-Dist: isort; extra == "all"
Requires-Dist: flake8; extra == "all"
Requires-Dist: wandb>=0.12.0; extra == "all"
Requires-Dist: eco2ai>=1.0.0; extra == "all"
Dynamic: author
Dynamic: home-page
Dynamic: license-file
Dynamic: requires-python

# GRAFT: Gradient-Aware Fast MaxVol Technique for Dynamic Data Sampling

[![PyPI version](https://badge.fury.io/py/graft-pytorch.svg)](https://badge.fury.io/py/graft-pytorch)
[![Python 3.8+](https://img.shields.io/badge/python-3.8+-blue.svg)](https://www.python.org/downloads/)
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)

A PyTorch implementation of smart sampling for efficient deep learning training.

## Overview
GRAFT uses gradient information and feature decomposition to select the most informative samples during training, reducing computation time while maintaining model performance.

## Features
- **Smart sample selection** using gradient-based importance scoring
- **Multi-architecture support** (ResNet, ResNeXt, EfficientNet, BERT)
- **Dataset compatibility** (CIFAR-10/100, TinyImageNet, Caltech256, medical datasets)
- **Experiment tracking** with Weights & Biases integration
- **Carbon footprint tracking** with eco2AI
- **Efficient training** with reduced computational overhead

## Installation

### From PyPI (Recommended)
```bash
pip install graft-pytorch
```

### With optional dependencies
```bash
# For experiment tracking
pip install "graft-pytorch[tracking]"

# For development
pip install "graft-pytorch[dev]"

# Everything
pip install "graft-pytorch[all]"
```

### From Source
```bash
git clone https://github.com/ashishjv1/GRAFT.git
cd GRAFT
pip install -e .
```

## Quick Start

### Command Line Interface
```bash
# Install and train with smart sampling
pip install graft-pytorch

# Basic training with GRAFT sampling on CIFAR-10
graft-train \
    --numEpochs=200 \
    --batch_size=128 \
    --device="cuda" \
    --optimizer="sgd" \
    --lr=0.1 \
    --numClasses=10 \
    --dataset="cifar10" \
    --model="resnet18" \
    --fraction=0.5 \
    --select_iter=25 \
    --warm_start
```

### Python API
```python
import torch
from graft import ModelTrainer, TrainingConfig
from graft.utils.loader import loader

# Load your dataset
trainloader, valloader, trainset, valset = loader(
    dataset="cifar10",
    trn_batch_size=128,
    val_batch_size=128
)

# Configure training with GRAFT
config = TrainingConfig(
    numEpochs=100,
    batch_size=128,
    device="cuda" if torch.cuda.is_available() else "cpu",
    model_name="resnet18",
    dataset_name="cifar10",
    trainloader=trainloader,
    valloader=valloader,
    trainset=trainset,
    optimizer_name="sgd",
    lr=0.1,
    fraction=0.5,        # Use 50% of data per epoch
    selection_iter=25,   # Reselect samples every 25 epochs
    warm_start=True      # Train on full data initially
)

# Train with smart sampling
trainer = ModelTrainer(config, trainloader, valloader, trainset)
train_stats, val_stats = trainer.train()

print(f"Best validation accuracy: {val_stats['best_acc']:.2%}")
```

### Advanced Usage
```python
from graft import feature_sel, sample_selection

# `MyCustomModel` and `dataloader` are placeholders for your own model and
# PyTorch DataLoader.
model = MyCustomModel()

# Feature decomposition for efficient sampling
features = feature_sel(dataloader, batch_size=128, device="cuda")

# Manual sample selection
selected_indices = sample_selection(
    dataloader, features, model, model.state_dict(),
    batch_size=128, fraction=0.3, select_iter=10,
    numEpochs=200, device="cuda", dataset="custom"
)
```

## Functionality Overview

### Core Components

#### 1. Smart Sample Selection
- **`sample_selection()`**: Selects the most informative samples using gradient-based importance
- **`feature_sel()`**: Performs feature decomposition for efficient sampling
- Reduces training time by 30-50% while maintaining model performance

#### 2. Supported Models
- **Vision Models**: ResNet, ResNeXt, EfficientNet, MobileNet, FashionCNN
- **Language Models**: BERT for sequence classification
- **Custom Models**: Easy integration with any PyTorch model

#### 3. Dataset Support
- **Computer Vision**: CIFAR-10/100, TinyImageNet, Caltech256
- **Medical Imaging**: Integration with MedMNIST datasets
- **Custom Datasets**: Support for any PyTorch DataLoader

#### 4. Training Features
- **Dynamic Sampling**: Adaptive sample selection during training
- **Warm Starting**: Begin with the full dataset, then switch to sampling
- **Experiment Tracking**: Built-in WandB integration
- **Carbon Tracking**: Monitor environmental impact with eco2AI

### Configuration Parameters

| Parameter | Description | Default | Options |
|-----------|-------------|---------|---------|
| `numEpochs` | Training epochs | 200 | Any integer |
| `batch_size` | Batch size | 128 | 32, 64, 128, 256+ |
| `device` | Computing device | "cuda" | "cpu", "cuda" |
| `model` | Model architecture | "resnet18" | "resnet18/50", "resnext", "efficientnet" |
| `fraction` | Data sampling ratio | 0.5 | 0.1 - 1.0 |
| `select_iter` | Reselection frequency (epochs) | 25 | Any integer |
| `optimizer` | Optimization algorithm | "sgd" | "sgd", "adam" |
| `lr` | Learning rate | 0.1 | 0.001 - 0.1 |
| `warm_start` | Use full data initially | False | True/False |
| `decomp` | Decomposition backend | "numpy" | "numpy", "torch" |

### Performance Benefits

- **Speed**: 30-50% faster training time
- **Memory**: Reduced memory usage through smart sampling
- **Accuracy**: Maintains or improves model performance
- **Efficiency**: Lower carbon footprint and energy consumption

## Package Structure
```
graft-pytorch/
├── graft/
│   ├── __init__.py          # Main package exports
│   ├── trainer.py           # Training orchestration
│   ├── genindices.py        # Sample selection algorithms
│   ├── decompositions.py    # Feature decomposition
│   ├── models/              # Supported architectures
│   │   ├── resnet.py        # ResNet implementations
│   │   ├── efficientnet.py  # EfficientNet models
│   │   └── BERT_model.py    # BERT for classification
│   └── utils/               # Utility functions
│       ├── loader.py        # Dataset loaders
│       └── model_mapper.py  # Model selection
├── tests/                   # Comprehensive test suite
├── examples/                # Usage examples
└── OIDC_SETUP.md            # Deployment configuration
```

## Contributing

We welcome contributions! Please see our [contribution guidelines](CONTRIBUTING.md) for details.

### Development Setup
```bash
# Clone the repository
git clone https://github.com/ashishjv1/GRAFT.git
cd GRAFT

# Install in development mode
pip install -e ".[dev]"

# Run tests
pytest tests/ -v

# Run linting
flake8 graft/ tests/
```

## License

This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.

## Citation

If you use GRAFT in your research, please cite our paper:

```bibtex
@misc{jha2025graftgradientawarefastmaxvol,
  title         = {GRAFT: Gradient-Aware Fast MaxVol Technique for Dynamic Data Sampling},
  author        = {Ashish Jha and Anh Huy Phan and Razan Dibo and Valentin Leplat},
  year          = {2025},
  eprint        = {2508.13653},
  archivePrefix = {arXiv},
  primaryClass  = {cs.LG},
  url           = {https://arxiv.org/abs/2508.13653}
}
```

## Acknowledgments

- Built using PyTorch
- Inspired by MaxVol techniques for data sampling
- Special thanks to the open-source community

---

**PyPI Package**: [graft-pytorch](https://pypi.org/project/graft-pytorch/)
**Paper**: [arXiv:2508.13653](https://arxiv.org/abs/2508.13653)
**Issues**: [GitHub Issues](https://github.com/ashishjv1/GRAFT/issues)
**Contact**: [Ashish Jha](mailto:Ashish.Jha@skoltech.ru)

graft_pytorch-0.1.7/README.md
@@ -0,0 +1,242 @@
# GRAFT: Gradient-Aware Fast MaxVol Technique for Dynamic Data Sampling

[![PyPI version](https://badge.fury.io/py/graft-pytorch.svg)](https://badge.fury.io/py/graft-pytorch)
[![Python 3.8+](https://img.shields.io/badge/python-3.8+-blue.svg)](https://www.python.org/downloads/)
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)

A PyTorch implementation of smart sampling for efficient deep learning training.

## Overview
GRAFT uses gradient information and feature decomposition to select the most informative samples during training, reducing computation time while maintaining model performance.

## Features
- **Smart sample selection** using gradient-based importance scoring
- **Multi-architecture support** (ResNet, ResNeXt, EfficientNet, BERT)
- **Dataset compatibility** (CIFAR-10/100, TinyImageNet, Caltech256, medical datasets)
- **Experiment tracking** with Weights & Biases integration
- **Carbon footprint tracking** with eco2AI
- **Efficient training** with reduced computational overhead

## Installation

### From PyPI (Recommended)
```bash
pip install graft-pytorch
```

### With optional dependencies
```bash
# For experiment tracking
pip install "graft-pytorch[tracking]"

# For development
pip install "graft-pytorch[dev]"

# Everything
pip install "graft-pytorch[all]"
```

### From Source
```bash
git clone https://github.com/ashishjv1/GRAFT.git
cd GRAFT
pip install -e .
```

## Quick Start

### Command Line Interface
```bash
# Install and train with smart sampling
pip install graft-pytorch

# Basic training with GRAFT sampling on CIFAR-10
graft-train \
    --numEpochs=200 \
    --batch_size=128 \
    --device="cuda" \
    --optimizer="sgd" \
    --lr=0.1 \
    --numClasses=10 \
    --dataset="cifar10" \
    --model="resnet18" \
    --fraction=0.5 \
    --select_iter=25 \
    --warm_start
```

### Python API
```python
import torch
from graft import ModelTrainer, TrainingConfig
from graft.utils.loader import loader

# Load your dataset
trainloader, valloader, trainset, valset = loader(
    dataset="cifar10",
    trn_batch_size=128,
    val_batch_size=128
)

# Configure training with GRAFT
config = TrainingConfig(
    numEpochs=100,
    batch_size=128,
    device="cuda" if torch.cuda.is_available() else "cpu",
    model_name="resnet18",
    dataset_name="cifar10",
    trainloader=trainloader,
    valloader=valloader,
    trainset=trainset,
    optimizer_name="sgd",
    lr=0.1,
    fraction=0.5,        # Use 50% of data per epoch
    selection_iter=25,   # Reselect samples every 25 epochs
    warm_start=True      # Train on full data initially
)

# Train with smart sampling
trainer = ModelTrainer(config, trainloader, valloader, trainset)
train_stats, val_stats = trainer.train()

print(f"Best validation accuracy: {val_stats['best_acc']:.2%}")
```
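
The `print` above suggests the returned stats are plain dicts of scalar metrics. Under that assumption, a small snippet like the following (not part of the package) can persist a run for later comparison:

```python
import json

# Assumes train_stats / val_stats are dicts of scalar metrics, as suggested
# by the val_stats['best_acc'] access above.
with open("graft_run_stats.json", "w") as f:
    json.dump({"train": train_stats, "val": val_stats}, f, indent=2, default=float)
```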

### Advanced Usage
```python
from graft import feature_sel, sample_selection

# `MyCustomModel` and `dataloader` are placeholders for your own model and
# PyTorch DataLoader.
model = MyCustomModel()

# Feature decomposition for efficient sampling
features = feature_sel(dataloader, batch_size=128, device="cuda")

# Manual sample selection
selected_indices = sample_selection(
    dataloader, features, model, model.state_dict(),
    batch_size=128, fraction=0.3, select_iter=10,
    numEpochs=200, device="cuda", dataset="custom"
)
```
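
Assuming `selected_indices` above is a sequence of integer dataset indices (its exact return type is not documented here), plain PyTorch utilities are enough to train on just the chosen fraction:

```python
from torch.utils.data import DataLoader, Subset

# Continuing the snippet above: wrap the selected indices in a Subset so a
# standard DataLoader only visits the chosen ~30% of the data.
sampled_loader = DataLoader(
    Subset(dataloader.dataset, list(selected_indices)),
    batch_size=128,
    shuffle=True,
)

for images, labels in sampled_loader:
    pass  # your usual training step runs here
```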

## Functionality Overview

### Core Components

#### 1. Smart Sample Selection
- **`sample_selection()`**: Selects the most informative samples using gradient-based importance
- **`feature_sel()`**: Performs feature decomposition for efficient sampling
- Reduces training time by 30-50% while maintaining model performance (a simplified sketch of the selection idea follows below)
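
"MaxVol" refers to picking a subset of rows whose feature submatrix has approximately maximal volume. The sketch below illustrates that idea with greedy orthogonal deflation (equivalent to pivoted Gram-Schmidt); it is **not** the package's actual `sample_selection()` implementation, and the feature matrix is random stand-in data:

```python
import torch

def greedy_maxvol(features: torch.Tensor, k: int) -> list:
    """Greedily pick k rows whose span has approximately maximal volume."""
    residual = features.clone().float()
    selected = []
    for _ in range(k):
        norms = residual.norm(dim=1)        # current importance of each row
        idx = int(torch.argmax(norms))
        selected.append(idx)
        v = residual[idx] / (norms[idx] + 1e-12)
        # Project the chosen direction out of every remaining row.
        residual = residual - (residual @ v).unsqueeze(1) * v.unsqueeze(0)
    return selected

feats = torch.randn(1000, 64)        # stand-in for per-sample gradient features
keep = greedy_maxvol(feats, k=300)   # retain a 30% fraction
```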

#### 2. Supported Models
- **Vision Models**: ResNet, ResNeXt, EfficientNet, MobileNet, FashionCNN
- **Language Models**: BERT for sequence classification
- **Custom Models**: Easy integration with any PyTorch model

#### 3. Dataset Support
- **Computer Vision**: CIFAR-10/100, TinyImageNet, Caltech256
- **Medical Imaging**: Integration with MedMNIST datasets
- **Custom Datasets**: Support for any PyTorch DataLoader

#### 4. Training Features
- **Dynamic Sampling**: Adaptive sample selection during training
- **Warm Starting**: Begin with the full dataset, then switch to sampling (see the schedule sketch below)
- **Experiment Tracking**: Built-in WandB integration
- **Carbon Tracking**: Monitor environmental impact with eco2AI
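
To make the interplay of `warm_start`, `fraction`, and `select_iter` concrete, here is a hypothetical schedule skeleton. The scoring function and the five-epoch warm-up length are placeholders, not the package's internals; GRAFT's real loop lives in `graft/trainer.py`:

```python
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset

# Toy dataset standing in for a real trainset.
dataset = TensorDataset(torch.randn(1024, 16), torch.randint(0, 10, (1024,)))
full_loader = DataLoader(dataset, batch_size=128, shuffle=True)

num_epochs, warm_epochs, select_iter, fraction = 100, 5, 25, 0.5
loader = full_loader  # warm start: begin on the full data

def score_samples() -> torch.Tensor:
    # Placeholder for gradient-based importance scoring.
    return torch.rand(len(dataset))

for epoch in range(num_epochs):
    past_warmup = epoch >= warm_epochs
    if past_warmup and (epoch - warm_epochs) % select_iter == 0:
        k = int(fraction * len(dataset))
        idx = torch.topk(score_samples(), k).indices.tolist()
        loader = DataLoader(Subset(dataset, idx), batch_size=128, shuffle=True)
    for xb, yb in loader:
        pass  # one optimization step per batch would go here
```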

### Configuration Parameters

| Parameter | Description | Default | Options |
|-----------|-------------|---------|---------|
| `numEpochs` | Training epochs | 200 | Any integer |
| `batch_size` | Batch size | 128 | 32, 64, 128, 256+ |
| `device` | Computing device | "cuda" | "cpu", "cuda" |
| `model` | Model architecture | "resnet18" | "resnet18/50", "resnext", "efficientnet" |
| `fraction` | Data sampling ratio | 0.5 | 0.1 - 1.0 |
| `select_iter` | Reselection frequency (epochs) | 25 | Any integer |
| `optimizer` | Optimization algorithm | "sgd" | "sgd", "adam" |
| `lr` | Learning rate | 0.1 | 0.001 - 0.1 |
| `warm_start` | Use full data initially | False | True/False |
| `decomp` | Decomposition backend | "numpy" | "numpy", "torch" |
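
As a second configuration sketch with non-default choices from the table, the snippet below reuses the loaders from the Python API example above. Keyword names follow that example and may not match the CLI flags exactly (e.g., `selection_iter` vs. `select_iter`):

```python
# Illustrative variant: CPU-only, Adam, a smaller fraction, faster reselection.
config = TrainingConfig(
    numEpochs=100,
    batch_size=64,
    device="cpu",               # fallback for machines without CUDA
    model_name="efficientnet",
    dataset_name="cifar100",
    trainloader=trainloader,
    valloader=valloader,
    trainset=trainset,
    optimizer_name="adam",
    lr=0.001,
    fraction=0.3,               # train on 30% of the data per epoch
    selection_iter=10,          # reselect more frequently
    warm_start=False,
)
```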

### Performance Benefits

- **Speed**: 30-50% faster training time
- **Memory**: Reduced memory usage through smart sampling
- **Accuracy**: Maintains or improves model performance
- **Efficiency**: Lower carbon footprint and energy consumption

## Package Structure
```
graft-pytorch/
├── graft/
│   ├── __init__.py          # Main package exports
│   ├── trainer.py           # Training orchestration
│   ├── genindices.py        # Sample selection algorithms
│   ├── decompositions.py    # Feature decomposition
│   ├── models/              # Supported architectures
│   │   ├── resnet.py        # ResNet implementations
│   │   ├── efficientnet.py  # EfficientNet models
│   │   └── BERT_model.py    # BERT for classification
│   └── utils/               # Utility functions
│       ├── loader.py        # Dataset loaders
│       └── model_mapper.py  # Model selection
├── tests/                   # Comprehensive test suite
├── examples/                # Usage examples
└── OIDC_SETUP.md            # Deployment configuration
```

## Contributing

We welcome contributions! Please see our [contribution guidelines](CONTRIBUTING.md) for details.

### Development Setup
```bash
# Clone the repository
git clone https://github.com/ashishjv1/GRAFT.git
cd GRAFT

# Install in development mode
pip install -e ".[dev]"

# Run tests
pytest tests/ -v

# Run linting
flake8 graft/ tests/
```

## License

This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.

## Citation

If you use GRAFT in your research, please cite our paper:

```bibtex
@misc{jha2025graftgradientawarefastmaxvol,
  title         = {GRAFT: Gradient-Aware Fast MaxVol Technique for Dynamic Data Sampling},
  author        = {Ashish Jha and Anh Huy Phan and Razan Dibo and Valentin Leplat},
  year          = {2025},
  eprint        = {2508.13653},
  archivePrefix = {arXiv},
  primaryClass  = {cs.LG},
  url           = {https://arxiv.org/abs/2508.13653}
}
```

## Acknowledgments

- Built using PyTorch
- Inspired by MaxVol techniques for data sampling
- Special thanks to the open-source community

---

**PyPI Package**: [graft-pytorch](https://pypi.org/project/graft-pytorch/)
**Paper**: [arXiv:2508.13653](https://arxiv.org/abs/2508.13653)
**Issues**: [GitHub Issues](https://github.com/ashishjv1/GRAFT/issues)
**Contact**: [Ashish Jha](mailto:Ashish.Jha@skoltech.ru)