graft-pytorch 0.1.7__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. graft_pytorch-0.1.7/CITATION.cff +11 -0
  2. graft_pytorch-0.1.7/LICENSE +21 -0
  3. graft_pytorch-0.1.7/MANIFEST.in +36 -0
  4. graft_pytorch-0.1.7/PKG-INFO +302 -0
  5. graft_pytorch-0.1.7/README.md +242 -0
  6. graft_pytorch-0.1.7/examples/Finetune_BERT.ipynb +1413 -0
  7. graft_pytorch-0.1.7/examples/local_dataset_utilities.py +102 -0
  8. graft_pytorch-0.1.7/graft/__init__.py +20 -0
  9. graft_pytorch-0.1.7/graft/cli.py +62 -0
  10. graft_pytorch-0.1.7/graft/config.py +36 -0
  11. graft_pytorch-0.1.7/graft/decompositions.py +54 -0
  12. graft_pytorch-0.1.7/graft/genindices.py +122 -0
  13. graft_pytorch-0.1.7/graft/grad_dist.py +20 -0
  14. graft_pytorch-0.1.7/graft/models/BERT_model.py +40 -0
  15. graft_pytorch-0.1.7/graft/models/MobilenetV2.py +111 -0
  16. graft_pytorch-0.1.7/graft/models/ResNeXt.py +154 -0
  17. graft_pytorch-0.1.7/graft/models/__init__.py +22 -0
  18. graft_pytorch-0.1.7/graft/models/efficientnet.py +197 -0
  19. graft_pytorch-0.1.7/graft/models/efficientnetb7.py +268 -0
  20. graft_pytorch-0.1.7/graft/models/fashioncnn.py +69 -0
  21. graft_pytorch-0.1.7/graft/models/mobilenet.py +83 -0
  22. graft_pytorch-0.1.7/graft/models/resnet.py +564 -0
  23. graft_pytorch-0.1.7/graft/models/resnet9.py +72 -0
  24. graft_pytorch-0.1.7/graft/scheduler.py +63 -0
  25. graft_pytorch-0.1.7/graft/trainer.py +467 -0
  26. graft_pytorch-0.1.7/graft/utils/__init__.py +5 -0
  27. graft_pytorch-0.1.7/graft/utils/extras.py +37 -0
  28. graft_pytorch-0.1.7/graft/utils/generate.py +33 -0
  29. graft_pytorch-0.1.7/graft/utils/imagenetselloader.py +54 -0
  30. graft_pytorch-0.1.7/graft/utils/loader.py +293 -0
  31. graft_pytorch-0.1.7/graft/utils/model_mapper.py +45 -0
  32. graft_pytorch-0.1.7/graft/utils/pickler.py +27 -0
  33. graft_pytorch-0.1.7/graft_pytorch.egg-info/PKG-INFO +302 -0
  34. graft_pytorch-0.1.7/graft_pytorch.egg-info/SOURCES.txt +44 -0
  35. graft_pytorch-0.1.7/graft_pytorch.egg-info/dependency_links.txt +1 -0
  36. graft_pytorch-0.1.7/graft_pytorch.egg-info/entry_points.txt +2 -0
  37. graft_pytorch-0.1.7/graft_pytorch.egg-info/not-zip-safe +1 -0
  38. graft_pytorch-0.1.7/graft_pytorch.egg-info/requires.txt +29 -0
  39. graft_pytorch-0.1.7/graft_pytorch.egg-info/top_level.txt +1 -0
  40. graft_pytorch-0.1.7/pyproject.toml +110 -0
  41. graft_pytorch-0.1.7/requirements.txt +12 -0
  42. graft_pytorch-0.1.7/setup.cfg +4 -0
  43. graft_pytorch-0.1.7/setup.py +70 -0
  44. graft_pytorch-0.1.7/tests/test_genindices.py +130 -0
  45. graft_pytorch-0.1.7/tests/test_graft.py +141 -0
  46. graft_pytorch-0.1.7/tests/test_graft_e2e.py +106 -0
@@ -0,0 +1,11 @@
+ cff-version: 1.2.0
+ message: "If you use this software, please cite it as below."
+ authors:
+ - family-names: "Jha"
+   given-names: "A"
+   affiliation: "Skolkovo Institute of Science & Technology"
+ title: "GRAFT: Gradient-Aware Fast MaxVol Technique for Dynamic Data Sampling"
+ version: 1.0.0
+ doi: 10.5281/zenodo.16947530
+ date-released: 2025-01-01
+ url: "https://github.com/ashishjv1/GRAFT"
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2024 Ashish Jha
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
@@ -0,0 +1,36 @@
+ # Include the README and other documentation files
+ include README.md
+ include LICENSE
+ include CITATION.cff
+ include requirements.txt
+ include pyproject.toml
+
+ # Include supplementary materials
+ recursive-include supplementary *
+
+ # Include example notebooks and scripts
+ recursive-include examples *.py *.ipynb
+
+ # Include test files
+ recursive-include tests *.py
+
+ # Exclude unnecessary files
+ global-exclude __pycache__
+ global-exclude *.py[co]
+ global-exclude *.so
+ global-exclude .DS_Store
+ global-exclude *.egg-info
+
+ # Exclude legacy files
+ exclude trainer.py
+ exclude GRAFT.py
+ exclude GRAFT_swinft.py
+
+ # Exclude build directories
+ prune build
+ prune dist
+ prune .git
+ prune .pytest_cache
+
+ # Include package data
+ recursive-include graft *.py
@@ -0,0 +1,302 @@
+ Metadata-Version: 2.4
+ Name: graft-pytorch
+ Version: 0.1.7
+ Summary: Gradient-Aware Fast MaxVol Technique for Dynamic Data Sampling
+ Home-page: https://github.com/ashishjv1/GRAFT
+ Author: Ashish Jha
+ Author-email: Ashish Jha <Ashish.Jha@skoltech.ru>
+ Maintainer-email: Ashish Jha <Ashish.Jha@skoltech.ru>
+ License: MIT
+ Project-URL: Homepage, https://github.com/ashishjv1/GRAFT
+ Project-URL: Repository, https://github.com/ashishjv1/GRAFT
+ Project-URL: Bug Reports, https://github.com/ashishjv1/GRAFT/issues
+ Project-URL: Documentation, https://github.com/ashishjv1/GRAFT/blob/main/README.md
+ Keywords: machine-learning,deep-learning,pytorch,data-sampling,gradient-based-sampling
+ Classifier: Development Status :: 3 - Alpha
+ Classifier: Intended Audience :: Developers
+ Classifier: Intended Audience :: Science/Research
+ Classifier: License :: OSI Approved :: MIT License
+ Classifier: Operating System :: OS Independent
+ Classifier: Programming Language :: Python :: 3
+ Classifier: Programming Language :: Python :: 3.8
+ Classifier: Programming Language :: Python :: 3.9
+ Classifier: Programming Language :: Python :: 3.10
+ Classifier: Programming Language :: Python :: 3.11
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
+ Classifier: Topic :: Software Development :: Libraries :: Python Modules
+ Requires-Python: >=3.8
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ Requires-Dist: torch>=1.9.0
+ Requires-Dist: torchvision>=0.10.0
+ Requires-Dist: numpy>=1.19.2
+ Requires-Dist: tqdm>=4.62.3
+ Requires-Dist: scikit-learn>=0.24.2
+ Requires-Dist: pillow>=8.3.1
+ Requires-Dist: matplotlib>=3.4.3
+ Requires-Dist: transformers>=4.0.0
+ Requires-Dist: medmnist>=2.0.0
+ Provides-Extra: dev
+ Requires-Dist: pytest>=6.2.5; extra == "dev"
+ Requires-Dist: pytest-cov; extra == "dev"
+ Requires-Dist: black; extra == "dev"
+ Requires-Dist: isort; extra == "dev"
+ Requires-Dist: flake8; extra == "dev"
+ Provides-Extra: tracking
+ Requires-Dist: wandb>=0.12.0; extra == "tracking"
+ Requires-Dist: eco2ai>=1.0.0; extra == "tracking"
+ Provides-Extra: all
+ Requires-Dist: pytest>=6.2.5; extra == "all"
+ Requires-Dist: pytest-cov; extra == "all"
+ Requires-Dist: black; extra == "all"
+ Requires-Dist: isort; extra == "all"
+ Requires-Dist: flake8; extra == "all"
+ Requires-Dist: wandb>=0.12.0; extra == "all"
+ Requires-Dist: eco2ai>=1.0.0; extra == "all"
+ Dynamic: author
+ Dynamic: home-page
+ Dynamic: license-file
+ Dynamic: requires-python
+
+ # GRAFT: Gradient-Aware Fast MaxVol Technique for Dynamic Data Sampling
+
+ [![PyPI version](https://badge.fury.io/py/graft-pytorch.svg)](https://badge.fury.io/py/graft-pytorch)
+ [![Python 3.8+](https://img.shields.io/badge/python-3.8+-blue.svg)](https://www.python.org/downloads/)
+ [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
+
+ A PyTorch implementation of smart sampling for efficient deep learning training.
+
+ ## Overview
+ GRAFT uses gradient information and feature decomposition to select the most informative samples during training, reducing computation time while maintaining model performance.
+
+ ## Features
+ - **Smart sample selection** using gradient-based importance scoring
+ - **Multi-architecture support** (ResNet, ResNeXt, EfficientNet, BERT)
+ - **Dataset compatibility** (CIFAR-10/100, TinyImageNet, Caltech256, medical datasets)
+ - **Experiment tracking** with Weights & Biases integration
+ - **Carbon footprint tracking** with eco2AI
+ - **Efficient training** with reduced computational overhead
+
+ ## Installation
+
+ ### From PyPI (Recommended)
+ ```bash
+ pip install graft-pytorch
+ ```
+
+ ### With optional dependencies
+ ```bash
+ # For experiment tracking
+ pip install graft-pytorch[tracking]
+
+ # For development
+ pip install graft-pytorch[dev]
+
+ # Everything
+ pip install graft-pytorch[all]
+ ```
+
+ ### From Source
+ ```bash
+ git clone https://github.com/ashishjv1/GRAFT.git
+ cd GRAFT
+ pip install -e .
+ ```
+
+ ## Quick Start
+
+ ### Command Line Interface
+ ```bash
+ # Install and train with smart sampling
+ pip install graft-pytorch
+
+ # Basic training with GRAFT sampling on CIFAR-10
+ graft-train \
+     --numEpochs=200 \
+     --batch_size=128 \
+     --device="cuda" \
+     --optimizer="sgd" \
+     --lr=0.1 \
+     --numClasses=10 \
+     --dataset="cifar10" \
+     --model="resnet18" \
+     --fraction=0.5 \
+     --select_iter=25 \
+     --warm_start
+ ```
+
+ ### Python API
+ ```python
+ import torch
+ from graft import ModelTrainer, TrainingConfig
+ from graft.utils.loader import loader
+
+ # Load your dataset
+ trainloader, valloader, trainset, valset = loader(
+     dataset="cifar10",
+     trn_batch_size=128,
+     val_batch_size=128
+ )
+
+ # Configure training with GRAFT
+ config = TrainingConfig(
+     numEpochs=100,
+     batch_size=128,
+     device="cuda" if torch.cuda.is_available() else "cpu",
+     model_name="resnet18",
+     dataset_name="cifar10",
+     trainloader=trainloader,
+     valloader=valloader,
+     trainset=trainset,
+     optimizer_name="sgd",
+     lr=0.1,
+     fraction=0.5,        # Use 50% of data per epoch
+     selection_iter=25,   # Reselect samples every 25 epochs
+     warm_start=True      # Train on full data initially
+ )
+
+ # Train with smart sampling
+ trainer = ModelTrainer(config, trainloader, valloader, trainset)
+ train_stats, val_stats = trainer.train()
+
+ print(f"Best validation accuracy: {val_stats['best_acc']:.2%}")
+ ```
+
+ ### Advanced Usage
+ ```python
+ from graft import feature_sel, sample_selection
+
+ # `MyCustomModel` is any torch.nn.Module you define;
+ # `dataloader` is a standard PyTorch DataLoader over your training set.
+ model = MyCustomModel()
+ data3 = feature_sel(dataloader, batch_size=128, device="cuda")
+
+ # Manual sample selection
+ selected_indices = sample_selection(
+     dataloader, data3, model, model.state_dict(),
+     batch_size=128, fraction=0.3, select_iter=10,
+     numEpochs=200, device="cuda", dataset="custom"
+ )
+ ```
+
+ ## Functionality Overview
+
+ ### Core Components
+
+ #### 1. Smart Sample Selection
+ - **`sample_selection()`**: Selects the most informative samples using gradient-based importance
+ - **`feature_sel()`**: Performs feature decomposition for efficient sampling
+ - Reduces training time by 30-50% while maintaining model performance
+
+ #### 2. Supported Models
+ - **Vision Models**: ResNet, ResNeXt, EfficientNet, MobileNet, FashionCNN
+ - **Language Models**: BERT for sequence classification
+ - **Custom Models**: Easy integration with any PyTorch model
+
+ #### 3. Dataset Support
+ - **Computer Vision**: CIFAR-10/100, TinyImageNet, Caltech256
+ - **Medical Imaging**: Integration with MedMNIST datasets
+ - **Custom Datasets**: Support for any PyTorch DataLoader
+
+ #### 4. Training Features
+ - **Dynamic Sampling**: Adaptive sample selection during training
+ - **Warm Starting**: Begin with the full dataset, then switch to sampling
+ - **Experiment Tracking**: Built-in WandB integration
+ - **Carbon Tracking**: Monitor environmental impact with eco2AI
+
+ ### Configuration Parameters
+
+ | Parameter | Description | Default | Options |
+ |-----------|-------------|---------|---------|
+ | `numEpochs` | Training epochs | 200 | Any integer |
+ | `batch_size` | Batch size | 128 | 32, 64, 128, 256+ |
+ | `device` | Computing device | "cuda" | "cpu", "cuda" |
+ | `model` | Model architecture | "resnet18" | "resnet18/50", "resnext", "efficientnet" |
+ | `fraction` | Data sampling ratio | 0.5 | 0.1 - 1.0 |
+ | `select_iter` | Reselection frequency | 25 | Any integer |
+ | `optimizer` | Optimization algorithm | "sgd" | "sgd", "adam" |
+ | `lr` | Learning rate | 0.1 | 0.001 - 0.1 |
+ | `warm_start` | Use full data initially | False | True/False |
+ | `decomp` | Decomposition backend | "numpy" | "numpy", "torch" |
+
+ ### Performance Benefits
+
+ - **Speed**: 30-50% faster training time
+ - **Memory**: Reduced memory usage through smart sampling
+ - **Accuracy**: Maintains or improves model performance
+ - **Efficiency**: Lower carbon footprint and energy consumption
+
+ ## Package Structure
+ ```
+ graft-pytorch/
+ ├── graft/
+ │   ├── __init__.py          # Main package exports
+ │   ├── trainer.py           # Training orchestration
+ │   ├── genindices.py        # Sample selection algorithms
+ │   ├── decompositions.py    # Feature decomposition
+ │   ├── models/              # Supported architectures
+ │   │   ├── resnet.py        # ResNet implementations
+ │   │   ├── efficientnet.py  # EfficientNet models
+ │   │   └── BERT_model.py    # BERT for classification
+ │   └── utils/               # Utility functions
+ │       ├── loader.py        # Dataset loaders
+ │       └── model_mapper.py  # Model selection
+ ├── tests/                   # Comprehensive test suite
+ ├── examples/                # Usage examples
+ └── OIDC_SETUP.md            # Deployment configuration
+ ```
249
+ ## Contributing
250
+
251
+ We welcome contributions! Please see our [contribution guidelines](CONTRIBUTING.md) for details.
252
+
253
+ ### Development Setup
254
+ ```bash
255
+ # Clone the repository
256
+ git clone https://github.com/ashishjv1/GRAFT.git
257
+ cd GRAFT
258
+
259
+ # Install in development mode
260
+ pip install -e .[dev]
261
+
262
+ # Run tests
263
+ pytest tests/ -v
264
+
265
+ # Run linting
266
+ flake8 graft/ tests/
267
+ ```
268
+
269
+ ## License
270
+
271
+ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
272
+
273
+ ## Citation
274
+
275
+ If you use GRAFT in your research, please cite our paper:
276
+
277
+ ```bibtex
278
+ @misc{jha2025graftgradientawarefastmaxvol,
279
+ title = {GRAFT: Gradient-Aware Fast MaxVol Technique for Dynamic Data Sampling},
280
+ author = {Ashish Jha and Anh Huy Phan and Razan Dibo and Valentin Leplat},
281
+ year = {2025},
282
+ eprint = {2508.13653},
283
+ archivePrefix = {arXiv},
284
+ primaryClass = {cs.LG},
285
+ url = {https://arxiv.org/abs/2508.13653}
286
+ }
287
+ ```
288
+
289
+ ## Acknowledgments
290
+
291
+ - Built using PyTorch
292
+ - Inspired by MaxVol techniques for data sampling
293
+ - Special thanks to the open-source community
294
+
295
+ ---
296
+
297
+ **PyPI Package**: [graft-pytorch](https://pypi.org/project/graft-pytorch/)
298
+ **Paper**: [arXiv:2508.13653](https://arxiv.org/abs/2508.13653)
299
+ **Issues**: [GitHub Issues](https://github.com/ashishjv1/GRAFT/issues)
300
+ **Contact**: [Ashish Jha](mailto:Ashish.Jha@skoltech.ru)
301
+
302
+
@@ -0,0 +1,242 @@
+ # GRAFT: Gradient-Aware Fast MaxVol Technique for Dynamic Data Sampling
+
+ [![PyPI version](https://badge.fury.io/py/graft-pytorch.svg)](https://badge.fury.io/py/graft-pytorch)
+ [![Python 3.8+](https://img.shields.io/badge/python-3.8+-blue.svg)](https://www.python.org/downloads/)
+ [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
+
+ A PyTorch implementation of smart sampling for efficient deep learning training.
+
+ ## Overview
+ GRAFT uses gradient information and feature decomposition to select the most informative samples during training, reducing computation time while maintaining model performance.
+
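+ As a rough sketch of the underlying idea (not the package's exact algorithm), per-sample
+ importance can be scored by the magnitude of each sample's loss gradient with respect to
+ the model output; the helper below is illustrative only:
+
+ ```python
+ import torch
+ import torch.nn.functional as F
+
+ def gradient_importance_scores(model, inputs, targets):
+     """Score each sample by the norm of its loss gradient w.r.t. the logits
+     (a cheap proxy for the full parameter gradient)."""
+     logits = model(inputs)
+     # Keep the loss unreduced so every sample contributes its own gradient.
+     losses = F.cross_entropy(logits, targets, reduction="none")
+     grads = torch.autograd.grad(losses.sum(), logits)[0]  # (batch, num_classes)
+     return grads.norm(dim=1)  # one importance score per sample
+ ```
+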
+ ## Features
+ - **Smart sample selection** using gradient-based importance scoring
+ - **Multi-architecture support** (ResNet, ResNeXt, EfficientNet, BERT)
+ - **Dataset compatibility** (CIFAR-10/100, TinyImageNet, Caltech256, medical datasets)
+ - **Experiment tracking** with Weights & Biases integration
+ - **Carbon footprint tracking** with eco2AI
+ - **Efficient training** with reduced computational overhead
+
+ ## Installation
+
+ ### From PyPI (Recommended)
+ ```bash
+ pip install graft-pytorch
+ ```
+
+ ### With optional dependencies
+ ```bash
+ # For experiment tracking
+ pip install graft-pytorch[tracking]
+
+ # For development
+ pip install graft-pytorch[dev]
+
+ # Everything
+ pip install graft-pytorch[all]
+ ```
+
+ ### From Source
+ ```bash
+ git clone https://github.com/ashishjv1/GRAFT.git
+ cd GRAFT
+ pip install -e .
+ ```
+
+ ## Quick Start
+
+ ### Command Line Interface
+ ```bash
+ # Install and train with smart sampling
+ pip install graft-pytorch
+
+ # Basic training with GRAFT sampling on CIFAR-10
+ graft-train \
+     --numEpochs=200 \
+     --batch_size=128 \
+     --device="cuda" \
+     --optimizer="sgd" \
+     --lr=0.1 \
+     --numClasses=10 \
+     --dataset="cifar10" \
+     --model="resnet18" \
+     --fraction=0.5 \
+     --select_iter=25 \
+     --warm_start
+ ```
+
+ ### Python API
+ ```python
+ import torch
+ from graft import ModelTrainer, TrainingConfig
+ from graft.utils.loader import loader
+
+ # Load your dataset
+ trainloader, valloader, trainset, valset = loader(
+     dataset="cifar10",
+     trn_batch_size=128,
+     val_batch_size=128
+ )
+
+ # Configure training with GRAFT
+ config = TrainingConfig(
+     numEpochs=100,
+     batch_size=128,
+     device="cuda" if torch.cuda.is_available() else "cpu",
+     model_name="resnet18",
+     dataset_name="cifar10",
+     trainloader=trainloader,
+     valloader=valloader,
+     trainset=trainset,
+     optimizer_name="sgd",
+     lr=0.1,
+     fraction=0.5,        # Use 50% of data per epoch
+     selection_iter=25,   # Reselect samples every 25 epochs
+     warm_start=True      # Train on full data initially
+ )
+
+ # Train with smart sampling
+ trainer = ModelTrainer(config, trainloader, valloader, trainset)
+ train_stats, val_stats = trainer.train()
+
+ print(f"Best validation accuracy: {val_stats['best_acc']:.2%}")
+ ```
+
+ ### Advanced Usage
+ ```python
+ from graft import feature_sel, sample_selection
+
+ # `MyCustomModel` is any torch.nn.Module you define;
+ # `dataloader` is a standard PyTorch DataLoader over your training set.
+ model = MyCustomModel()
+ data3 = feature_sel(dataloader, batch_size=128, device="cuda")
+
+ # Manual sample selection
+ selected_indices = sample_selection(
+     dataloader, data3, model, model.state_dict(),
+     batch_size=128, fraction=0.3, select_iter=10,
+     numEpochs=200, device="cuda", dataset="custom"
+ )
+ ```
+
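+ The indices returned above can then drive a reduced loader for the next training
+ phase (a minimal sketch, assuming `selected_indices` holds integer positions into
+ `dataloader.dataset`):
+
+ ```python
+ from torch.utils.data import DataLoader, Subset
+
+ # Train only on the selected subset until the next reselection round.
+ subset = Subset(dataloader.dataset, list(selected_indices))
+ subset_loader = DataLoader(subset, batch_size=128, shuffle=True)
+ ```
+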
+ ## Functionality Overview
+
+ ### Core Components
+
+ #### 1. Smart Sample Selection
+ - **`sample_selection()`**: Selects the most informative samples using gradient-based importance
+ - **`feature_sel()`**: Performs feature decomposition for efficient sampling (see the sketch below)
+ - Reduces training time by 30-50% while maintaining model performance
+
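+ For intuition, MaxVol-style selection keeps the rows of a feature matrix that span
+ the most volume. A classic stand-in (not GRAFT's own implementation) is column-pivoted
+ QR from SciPy:
+
+ ```python
+ import numpy as np
+ from scipy.linalg import qr
+
+ def maxvol_like_select(features: np.ndarray, k: int) -> np.ndarray:
+     """Greedily pick k representative samples from an (n_samples, dim)
+     feature matrix via column-pivoted QR on its transpose."""
+     _, _, pivots = qr(features.T, pivoting=True)  # pivots rank samples by added volume
+     return pivots[:k]
+
+ rng = np.random.default_rng(0)
+ picked = maxvol_like_select(rng.standard_normal((1000, 64)), k=100)
+ ```
+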
+ #### 2. Supported Models
+ - **Vision Models**: ResNet, ResNeXt, EfficientNet, MobileNet, FashionCNN
+ - **Language Models**: BERT for sequence classification
+ - **Custom Models**: Easy integration with any PyTorch model
+
+ #### 3. Dataset Support
+ - **Computer Vision**: CIFAR-10/100, TinyImageNet, Caltech256
+ - **Medical Imaging**: Integration with MedMNIST datasets
+ - **Custom Datasets**: Support for any PyTorch DataLoader (see the sketch below)
+
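+ Plugging in your own data is a matter of exposing it through a standard map-style
+ `Dataset` (a minimal sketch; `MyImages` is hypothetical and the random tensors stand
+ in for your real data):
+
+ ```python
+ import torch
+ from torch.utils.data import DataLoader, Dataset
+
+ class MyImages(Dataset):
+     """Any map-style dataset returning (sample, label) pairs will do."""
+     def __init__(self, images, labels):
+         self.images, self.labels = images, labels
+     def __len__(self):
+         return len(self.labels)
+     def __getitem__(self, idx):
+         return self.images[idx], self.labels[idx]
+
+ train_x, train_y = torch.randn(500, 3, 32, 32), torch.randint(0, 10, (500,))
+ trainloader = DataLoader(MyImages(train_x, train_y), batch_size=128, shuffle=True)
+ ```
+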
+ #### 4. Training Features
+ - **Dynamic Sampling**: Adaptive sample selection during training (outlined below)
+ - **Warm Starting**: Begin with the full dataset, then switch to sampling
+ - **Experiment Tracking**: Built-in WandB integration
+ - **Carbon Tracking**: Monitor environmental impact with eco2AI
+
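+ Put together, the training loop alternates between a warm-up on the full data and
+ periodic reselection. A schematic outline (names and the warm-up length are
+ illustrative, not the package's internals):
+
+ ```python
+ from torch.utils.data import DataLoader, Subset
+
+ def train_with_dynamic_sampling(model, dataset, score_fn, train_epoch,
+                                 num_epochs=200, select_iter=25,
+                                 fraction=0.5, warm_epochs=25):
+     """Warm start on all data, then every `select_iter` epochs keep only
+     the top-`fraction` samples ranked by `score_fn`."""
+     loader = DataLoader(dataset, batch_size=128, shuffle=True)
+     for epoch in range(num_epochs):
+         past_warmup = epoch >= warm_epochs
+         if past_warmup and (epoch - warm_epochs) % select_iter == 0:
+             scores = score_fn(model, dataset)   # one score per sample (torch tensor)
+             k = int(fraction * len(dataset))
+             keep = scores.argsort(descending=True)[:k].tolist()
+             loader = DataLoader(Subset(dataset, keep), batch_size=128, shuffle=True)
+         train_epoch(model, loader)
+ ```
+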
+ ### Configuration Parameters
+
+ | Parameter | Description | Default | Options |
+ |-----------|-------------|---------|---------|
+ | `numEpochs` | Training epochs | 200 | Any integer |
+ | `batch_size` | Batch size | 128 | 32, 64, 128, 256+ |
+ | `device` | Computing device | "cuda" | "cpu", "cuda" |
+ | `model` | Model architecture | "resnet18" | "resnet18/50", "resnext", "efficientnet" |
+ | `fraction` | Data sampling ratio | 0.5 | 0.1 - 1.0 |
+ | `select_iter` | Reselection frequency | 25 | Any integer |
+ | `optimizer` | Optimization algorithm | "sgd" | "sgd", "adam" |
+ | `lr` | Learning rate | 0.1 | 0.001 - 0.1 |
+ | `warm_start` | Use full data initially | False | True/False |
+ | `decomp` | Decomposition backend | "numpy" | "numpy", "torch" |
+
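+ For example, with `numEpochs=200`, `select_iter=25`, and `fraction=0.5`, reselection
+ runs roughly eight times over the course of training, and each intervening epoch
+ touches only about half the training set.
+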
+ ### Performance Benefits
+
+ - **Speed**: 30-50% faster training time
+ - **Memory**: Reduced memory usage through smart sampling
+ - **Accuracy**: Maintains or improves model performance
+ - **Efficiency**: Lower carbon footprint and energy consumption
+
+ ## Package Structure
+ ```
+ graft-pytorch/
+ ├── graft/
+ │   ├── __init__.py          # Main package exports
+ │   ├── trainer.py           # Training orchestration
+ │   ├── genindices.py        # Sample selection algorithms
+ │   ├── decompositions.py    # Feature decomposition
+ │   ├── models/              # Supported architectures
+ │   │   ├── resnet.py        # ResNet implementations
+ │   │   ├── efficientnet.py  # EfficientNet models
+ │   │   └── BERT_model.py    # BERT for classification
+ │   └── utils/               # Utility functions
+ │       ├── loader.py        # Dataset loaders
+ │       └── model_mapper.py  # Model selection
+ ├── tests/                   # Comprehensive test suite
+ ├── examples/                # Usage examples
+ └── OIDC_SETUP.md            # Deployment configuration
+ ```
+
+ ## Contributing
+
+ We welcome contributions! Please see our [contribution guidelines](CONTRIBUTING.md) for details.
+
+ ### Development Setup
+ ```bash
+ # Clone the repository
+ git clone https://github.com/ashishjv1/GRAFT.git
+ cd GRAFT
+
+ # Install in development mode
+ pip install -e .[dev]
+
+ # Run tests
+ pytest tests/ -v
+
+ # Run linting
+ flake8 graft/ tests/
+ ```
+
+ ## License
+
+ This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
+
+ ## Citation
+
+ If you use GRAFT in your research, please cite our paper:
+
+ ```bibtex
+ @misc{jha2025graftgradientawarefastmaxvol,
+   title = {GRAFT: Gradient-Aware Fast MaxVol Technique for Dynamic Data Sampling},
+   author = {Ashish Jha and Anh Huy Phan and Razan Dibo and Valentin Leplat},
+   year = {2025},
+   eprint = {2508.13653},
+   archivePrefix = {arXiv},
+   primaryClass = {cs.LG},
+   url = {https://arxiv.org/abs/2508.13653}
+ }
+ ```
+
+ ## Acknowledgments
+
+ - Built using PyTorch
+ - Inspired by MaxVol techniques for data sampling
+ - Special thanks to the open-source community
+
+ ---
+
+ **PyPI Package**: [graft-pytorch](https://pypi.org/project/graft-pytorch/)
+ **Paper**: [arXiv:2508.13653](https://arxiv.org/abs/2508.13653)
+ **Issues**: [GitHub Issues](https://github.com/ashishjv1/GRAFT/issues)
+ **Contact**: [Ashish Jha](mailto:Ashish.Jha@skoltech.ru)