graft-pytorch 0.1.7 (py3-none-any.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- graft/__init__.py +20 -0
- graft/cli.py +62 -0
- graft/config.py +36 -0
- graft/decompositions.py +54 -0
- graft/genindices.py +122 -0
- graft/grad_dist.py +20 -0
- graft/models/BERT_model.py +40 -0
- graft/models/MobilenetV2.py +111 -0
- graft/models/ResNeXt.py +154 -0
- graft/models/__init__.py +22 -0
- graft/models/efficientnet.py +197 -0
- graft/models/efficientnetb7.py +268 -0
- graft/models/fashioncnn.py +69 -0
- graft/models/mobilenet.py +83 -0
- graft/models/resnet.py +564 -0
- graft/models/resnet9.py +72 -0
- graft/scheduler.py +63 -0
- graft/trainer.py +467 -0
- graft/utils/__init__.py +5 -0
- graft/utils/extras.py +37 -0
- graft/utils/generate.py +33 -0
- graft/utils/imagenetselloader.py +54 -0
- graft/utils/loader.py +293 -0
- graft/utils/model_mapper.py +45 -0
- graft/utils/pickler.py +27 -0
- graft_pytorch-0.1.7.dist-info/METADATA +302 -0
- graft_pytorch-0.1.7.dist-info/RECORD +31 -0
- graft_pytorch-0.1.7.dist-info/WHEEL +5 -0
- graft_pytorch-0.1.7.dist-info/entry_points.txt +2 -0
- graft_pytorch-0.1.7.dist-info/licenses/LICENSE +21 -0
- graft_pytorch-0.1.7.dist-info/top_level.txt +1 -0
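The sections below reproduce the newly added source files. For orientation, once the wheel is installed (for example with `pip install graft-pytorch==0.1.7`; the exact registry is not stated here, PyPI is an assumption), the files listed above map directly onto the `graft` import namespace, as in this minimal sketch:

from graft.scheduler import WarmupCosineSchedule
from graft.trainer import ModelTrainer, TrainingConfig
from graft.utils.loader import loader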
graft/scheduler.py
ADDED
@@ -0,0 +1,63 @@
import logging
import math

from torch.optim.lr_scheduler import LambdaLR

logger = logging.getLogger(__name__)


class ConstantLRSchedule(LambdaLR):
    """Constant learning rate schedule."""

    def __init__(self, optimizer, last_epoch=-1):
        super(ConstantLRSchedule, self).__init__(optimizer, lambda _: 1.0, last_epoch=last_epoch)


class WarmupConstantSchedule(LambdaLR):
    """Linear warmup and then constant.
    Linearly increases the learning-rate multiplier from 0 to 1 over `warmup_steps` training steps.
    Keeps the multiplier equal to 1.0 after `warmup_steps`.
    """

    def __init__(self, optimizer, warmup_steps, last_epoch=-1):
        self.warmup_steps = warmup_steps
        super(WarmupConstantSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)

    def lr_lambda(self, step):
        if step < self.warmup_steps:
            return float(step) / float(max(1.0, self.warmup_steps))
        return 1.0


class WarmupLinearSchedule(LambdaLR):
    """Linear warmup and then linear decay.
    Linearly increases the learning-rate multiplier from 0 to 1 over `warmup_steps` training steps.
    Linearly decreases the multiplier from 1.0 to 0.0 over the remaining `t_total - warmup_steps` steps.
    """

    def __init__(self, optimizer, warmup_steps, t_total, last_epoch=-1):
        self.warmup_steps = warmup_steps
        self.t_total = t_total
        super(WarmupLinearSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)

    def lr_lambda(self, step):
        if step < self.warmup_steps:
            return float(step) / float(max(1, self.warmup_steps))
        return max(0.0, float(self.t_total - step) / float(max(1.0, self.t_total - self.warmup_steps)))


class WarmupCosineSchedule(LambdaLR):
    """Linear warmup and then cosine decay.
    Linearly increases the learning-rate multiplier from 0 to 1 over `warmup_steps` training steps.
    Decreases the multiplier from 1.0 to 0.0 over the remaining `t_total - warmup_steps` steps following a cosine curve.
    `cycles` (default 0.5, i.e. half a cosine period) controls how many cosine cycles are traversed after warmup.
    """

    def __init__(self, optimizer, warmup_steps, t_total, cycles=.5, last_epoch=-1):
        self.warmup_steps = warmup_steps
        self.t_total = t_total
        self.cycles = cycles
        super(WarmupCosineSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)

    def lr_lambda(self, step):
        if step < self.warmup_steps:
            return float(step) / float(max(1.0, self.warmup_steps))
        # progress after warmup
        progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps))
        return max(0.0, 0.5 * (1. + math.cos(math.pi * float(self.cycles) * 2.0 * progress)))
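For reference, a minimal sketch of how these LambdaLR-based schedules are attached to an optimizer; the model, learning rate, and step counts below are illustrative, not values taken from the package:

import torch
from graft.scheduler import WarmupCosineSchedule

model = torch.nn.Linear(10, 2)                            # illustrative model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Warm up for 100 steps, then decay to zero over the remaining 900 steps.
scheduler = WarmupCosineSchedule(optimizer, warmup_steps=100, t_total=1000)

for step in range(1000):
    optimizer.step()    # one training step (forward/backward omitted in this sketch)
    scheduler.step()    # advances the step counter fed to lr_lambda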
graft/trainer.py
ADDED
@@ -0,0 +1,467 @@
# Standard library imports
import os
import pickle
import copy
import gc
import argparse

# Third-party imports
import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm

# Local imports
from .utils.loader import loader
from .utils.model_mapper import ModelMapper
from .utils.imagenetselloader import imagenet_selloader
from .utils import pickler
from .decompositions import feature_sel
from .genindices import sample_selection

# Optional dependencies
try:
    import wandb
    WANDB_AVAILABLE = True
except ImportError:
    WANDB_AVAILABLE = False

try:
    import eco2ai
    ECO2AI_AVAILABLE = True
except ImportError:
    ECO2AI_AVAILABLE = False


class TrainingConfig:
    def __init__(self, numEpochs, batch_size, device, net,
                 model_name, dataset_name, trainloader, valloader,
                 trainset, data3, optimizer_name, lr, weight_decay,
                 grad_clip, fraction, selection_iter, warm_start,
                 imgntselloader, sched="cosine", multi_checkpoint=False,
                 use_wandb=True):  # Add use_wandb parameter

        self.numEpochs = numEpochs
        self.batch_size = batch_size
        self.device = device
        self.net = net
        self.model_name = model_name
        self.dataset_name = dataset_name
        self.trainloader = trainloader
        self.valloader = valloader
        self.trainset = trainset
        self.data3 = data3
        self.optimizer_name = optimizer_name
        self.lr = lr
        self.weight_decay = weight_decay
        self.grad_clip = grad_clip
        self.fraction = fraction
        self.selection_iter = selection_iter
        self.warm_start = warm_start
        self.imgntselloader = imgntselloader
        self.sched = sched
        self.multi_checkpoint = multi_checkpoint
        self.use_wandb = use_wandb and WANDB_AVAILABLE

    @classmethod
    def from_args(cls, args):
        return cls(
            numEpochs=args.numEpochs,
            batch_size=args.batch_size,
            device=args.device,
            net=None,  # Placeholder, will be set in the trainer
            model_name=args.model,
            dataset_name=args.dataset,
            trainloader=None,  # Placeholder, will be set in the trainer
            valloader=None,  # Placeholder, will be set in the trainer
            trainset=None,  # Placeholder, will be set in the trainer
            data3=None,  # Placeholder, will be set in the trainer
            optimizer_name=args.optimizer,
            lr=args.lr,
            weight_decay=args.weight_decay,
            grad_clip=args.grad_clip,
            fraction=args.fraction,
            selection_iter=args.select_iter,
            warm_start=args.warm_start,
            imgntselloader=None,  # Placeholder, will be set in the trainer
            sched="cosine",
            multi_checkpoint=False,
            use_wandb=getattr(args, 'use_wandb', True)
        )


class ModelTrainer:
    def __init__(self, config, model, trainloader, valloader, trainset, data3):
        self.config = config
        self.model = model
        self.trainloader = trainloader
        self.valloader = valloader
        self.trainset = trainset
        self.data3 = data3
        self.optimizer = None
        self.scheduler = None
        self.loss_fn = None
        self.curr_high = 0
        self.total = 0
        self.correct = 0
        self.trn_losses = list()
        self.val_losses = list()
        self.trn_acc = list()
        self.val_acc = list()
        self.selection = 0
        self.weight_decay = 1e-4

        self.dir_save = f"saved_models/{config.model_name}"
        self.save_dir = f"{self.dir_save}/multi_checkpoint"

        self._setup()

    def _setup(self):
        # Default to cross entropy loss unless specifically handling regression
        self.loss_fn = torch.nn.functional.cross_entropy

        # Create save directories
        if not os.path.exists(self.dir_save):
            os.makedirs(self.dir_save)
        if not os.path.exists(self.save_dir):
            os.makedirs(self.save_dir)

        if self.config.optimizer_name.lower() == "adam":
            self.optimizer = optim.Adam(self.model.parameters(), lr=self.config.lr)
        elif self.config.optimizer_name.lower() == "sgd":
            self.optimizer = optim.SGD(self.model.parameters(), lr=self.config.lr,
                                       momentum=0.9, weight_decay=self.config.weight_decay)
        else:
            raise ValueError(f"Unsupported optimizer: {self.config.optimizer_name}")

        if self.config.sched.lower() == "onecycle":
            self.scheduler = lr_scheduler.OneCycleLR(self.optimizer, self.config.lr,
                                                     epochs=self.config.numEpochs,
                                                     steps_per_epoch=len(self.trainloader))
        else:
            self.scheduler = lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=200)

        if self.config.use_wandb:
            if WANDB_AVAILABLE:
                wandb.login()
                if self.config.model_name.lower() == "efficientnet-b0":
                    model_name = "efficientnetb0"
                elif self.config.model_name.lower() == "efficientnet-b5":
                    model_name = "efficientnetb5"
                config = {"lr": self.config.lr, "batch_size": self.config.batch_size}
                config.update({"architecture": f'{self.model}'})
                wandb.init(project=f"Smart_Sampling_{self.config.model_name}_{self.config.dataset_name}",
                           config=config)

        self.main_trainloader = self.trainloader

    def train(self):
        train_stats = {
            'losses': self.trn_losses,
            'accuracies': self.trn_acc,
            'best_acc': max(self.trn_acc) if self.trn_acc else 0
        }

        val_stats = {
            'losses': self.val_losses,
            'accuracies': self.val_acc,
            'best_acc': self.curr_high
        }

        for epoch in range(self.config.numEpochs):
            self.model.train()
            tot_train_loss = 0
            before_lr = self.optimizer.param_groups[0]["lr"]
            pruned_samples = 0
            total_samples = 0

            if (epoch) % self.config.selection_iter == 0:
                if self.config.warm_start and self.selection == 0:
                    trainloader = self.trainloader
                    self.selection += 1
                else:
                    train_model = self.model
                    cached_state_dict = copy.deepcopy(train_model.state_dict())
                    clone_dict = copy.deepcopy(train_model.state_dict())

                    # Skip selection if no data3 available (for tests)
                    if self.data3 is None:
                        continue

                    if not self.config.imgntselloader:
                        indices = sample_selection(self.main_trainloader, self.data3, self.model,
                                                   clone_dict, self.config.batch_size, self.config.fraction,
                                                   self.config.selection_iter, self.config.numEpochs,
                                                   self.config.device, self.config.dataset_name)
                    else:
                        indices = sample_selection(self.config.imgntselloader, self.data3, self.model,
                                                   clone_dict, self.config.batch_size, self.config.fraction,
                                                   self.config.selection_iter, self.config.numEpochs,
                                                   self.config.device, self.config.dataset_name)

                    self.model.load_state_dict(cached_state_dict)

                    self.selection += 1

                    datasubset = Subset(self.trainset, indices)
                    new_trainloader = DataLoader(datasubset, batch_size=self.config.batch_size,
                                                 shuffle=True, pin_memory=False, num_workers=1)

                    self.trainloader = new_trainloader

                    del cached_state_dict
                    del clone_dict
                    del train_model
                    torch.cuda.empty_cache()
                    gc.collect()

            for _, (trainsamples, labels) in enumerate(tqdm(self.trainloader)):

                trainsamples = trainsamples.to(self.config.device)
                labels = labels.to(self.config.device)

                X = trainsamples
                Y = labels
                pred = self.model(X)

                # loss = torch.nn.functional.cross_entropy(pred, Y.to(device))
                loss = self.loss_fn(pred, Y.to(self.config.device))

                tot_train_loss += loss.item()

                self.optimizer.zero_grad()

                loss.backward()

                if self.config.grad_clip:
                    nn.utils.clip_grad_value_(self.model.parameters(), self.config.grad_clip)

                self.optimizer.step()

                # calculate accuracy
                _, predicted = torch.max(pred.cpu().data, 1)
                self.total += Y.size(0)

                self.correct += (predicted == Y.cpu()).sum().item()
                # accuracy = 100 * correct / total
                pruned_samples += len(trainsamples) - len(X)
                total_samples += len(trainsamples)

                if self.config.sched.lower() == "onecycle":
                    self.scheduler.step()

            if self.config.sched.lower() == "cosine":
                self.scheduler.step()

            after_lr = self.optimizer.param_groups[0]["lr"]

            print("Last Epoch [%d] -> Current Epoch [%d]: lr %.4f -> %.4f optimizer %s" % (epoch, epoch+1, before_lr, after_lr, self.config.optimizer_name))

            if epoch % 20 == 0:
                dir_parts = self.dir_save.split('/')
                current_dir = ''

                for part in dir_parts:
                    current_dir = os.path.join(current_dir, part)
                    if not os.path.exists(current_dir):
                        os.makedirs(current_dir)

                if not os.path.exists(self.save_dir):
                    os.makedirs(self.save_dir)

                if not os.path.exists(self.dir_save):
                    os.makedirs(self.dir_save)

                if self.config.selection_iter > self.config.numEpochs:
                    file_prefix = "Full"
                else:
                    file_prefix = "Sampled"

                if self.config.multi_checkpoint:
                    file_prefix += "_multi"

                filename = f"{file_prefix}_{self.config.dataset_name}_sch{self.config.sched}_si{self.config.selection_iter}_f{self.config.fraction}"
                if self.config.multi_checkpoint:
                    filename += f"_ep{epoch}"
                    torch.save(self.model.state_dict(), f"{self.save_dir}/{filename}.pth")
                else:
                    torch.save(self.model.state_dict(), f"{self.dir_save}/{filename}.pth")

            if (epoch+1) % 1 == 0:
                trn_loss = 0
                trn_correct = 0
                trn_total = 0
                val_loss = 0
                val_correct = 0
                val_total = 0
                self.model.eval()
                with torch.no_grad():
                    for _, (inputs, targets) in enumerate(self.trainloader):
                        inputs, targets = inputs.to(self.config.device), \
                            targets.to(self.config.device, non_blocking=True)
                        outputs = self.model(inputs)
                        loss = self.loss_fn(outputs, targets)
                        trn_loss += loss.item()
                        _, predicted = outputs.max(1)
                        trn_total += targets.size(0)
                        trn_correct += predicted.eq(targets).sum().item()
                self.trn_losses.append(trn_loss)
                self.trn_acc.append(trn_correct / trn_total)
                with torch.no_grad():
                    for _, (inputs, targets) in enumerate(self.valloader):
                        inputs, targets = inputs.to(self.config.device), \
                            targets.to(self.config.device, non_blocking=True)
                        outputs = self.model(inputs)
                        loss = self.loss_fn(outputs, targets)
                        val_loss += loss.item()
                        _, predicted = outputs.max(1)
                        val_total += targets.size(0)
                        val_correct += predicted.eq(targets).sum().item()
                self.val_losses.append(val_loss)
                self.val_acc.append(val_correct / val_total)

                if self.val_acc[-1] > self.curr_high:
                    self.curr_high = self.val_acc[-1]

                if self.config.use_wandb and WANDB_AVAILABLE:
                    wandb.log({
                        "Validation accuracy": self.curr_high,
                        "Val Loss": self.val_losses[-1]/100,
                        "loss": self.trn_losses[-1]/100,
                        "Train Accuracy": self.trn_acc[-1]*100,
                        "Epoch": epoch
                    })

                print("Epoch [{}/{}], Loss: {:.4f}, Train Accuracy: {:.2f}%".format(
                    epoch+1,
                    self.config.numEpochs,
                    self.trn_losses[-1],
                    self.trn_acc[-1]*100
                ))

                print("Highest Accuracy:", self.curr_high)
                print("Validation Accuracy:", self.val_acc[-1])
                print("Validation Loss", self.val_losses[-1])

        return train_stats, val_stats


def get_model(args):
    arguments = type('', (), {'model': args.model.lower(), 'numClasses': args.numClasses,
                              'device': args.device, 'in_chanls': args.inp_channels})()
    model_mapper = ModelMapper(arguments)
    return model_mapper.get_model()


def prepare_data(args, trainloader):
    if args.select_iter < args.numEpochs:
        imgntselloader = None
        pickle_dir = f"{args.dataset}_pickle"
        file = os.path.join(pickle_dir, f"V_{args.batch_size}.pkl")

        # Create pickle directory if it doesn't exist
        if not os.path.exists(pickle_dir):
            os.makedirs(pickle_dir)

        if os.path.exists(file):
            print("Loading existing pickle file")
            with open(file, 'rb') as f:
                data3 = pickle.load(f)
        else:
            print("Generating new pickle file")
            if args.dataset.lower() != "imagenet":
                V = feature_sel(trainloader, args.batch_size, device=args.device, decomp_type=args.decomp)
                data3 = V
                # Save pickle
                with open(file, 'wb') as f:
                    pickle.dump(V, f)
            else:
                imgntselloader = imagenet_selloader(args.dataset, dirs=args.dataset_dir,
                                                    trn_batch_size=args.batch_size,
                                                    val_batch_size=args.batch_size,
                                                    tst_batch_size=1000, resize=32)

                V = feature_sel(imgntselloader, args.batch_size, device=args.device, decomp_type=args.decomp)
                data3 = V

                with open(file, 'wb') as f:
                    pickle.dump(V, f)
    else:
        data3 = None

    if args.dataset.lower() == "imagenet" and not imgntselloader:
        imgntselloader = imagenet_selloader(args.dataset, dirs=args.dataset_dir,
                                            trn_batch_size=args.batch_size,
                                            val_batch_size=args.batch_size,
                                            tst_batch_size=1000, resize=32)

    return data3


def setup_tracker(args):
    if not ECO2AI_AVAILABLE:
        print("Warning: eco2ai not available, skipping emissions tracking")
        return None

    if args.warm_start:
        ttype = "warm"
    else:
        ttype = "nowarm"

    tracker = eco2ai.Tracker(
        project_name=f"{args.model}_dset-{args.dataset}_bs-{args.batch_size}",
        experiment_description="training DEIM_IS model",
        file_name=f"emission_-{args.model}_dset-{args.dataset}_bs-{args.batch_size}_epochs-{args.numEpochs}_fraction-{args.fraction}_{args.optimizer}_{ttype}.csv"
    )
    return tracker


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="Model Training with smart Sampling")
    parser.add_argument('--batch_size', default='128', type=int, required=True, help='(default=%(default)s)')
    parser.add_argument('--numEpochs', default='5', type=int, required=True, help='(default=%(default)s)')
    parser.add_argument('--numClasses', default='10', type=int, required=True, help='(default=%(default)s)')
    parser.add_argument('--lr', default='0.001', type=float, required=False, help='learning rate')
    parser.add_argument('--device', default='cuda', type=str, required=False, help='device to use for decompositions')
    parser.add_argument('--model', default='resnet50', type=str, required=False, help='model to train')
    parser.add_argument('--dataset', default="cifar10", type=str, required=False, help='Indicate the dataset')
    parser.add_argument('--dataset_dir', default="./cifar10", type=str, required=False, help='Imagenet folder')
    parser.add_argument('--pretrained', default=False, action='store_true', help='use pretrained or not')
    parser.add_argument('--weight_decay', default=0.0001, type=float, required=False, help='Weight Decay to be used')
    parser.add_argument('--inp_channels', default="3", type=int, required=False, help='Number of input channels')
    parser.add_argument('--save_pickle', default=False, action='store_true', help='to save or not to save U, S, V components')
    parser.add_argument('--decomp', default="numpy", type=str, required=False, help='To perform SVD using torch or numpy')
    parser.add_argument('--optimizer', default="sgd", type=str, required=True, help='Choice for optimizer')
    parser.add_argument('--select_iter', default="50", type=int, required=True, help='Data Selection Iteration')
    parser.add_argument('--fraction', default="0.50", type=float, required=True, help='fraction of data')
    parser.add_argument('--grad_clip', default=0.00, type=float, required=False, help='Gradient Clipping Value')
    parser.add_argument('--warm_start', default=False, action='store_true', help='Train with a warm-start')

    args = parser.parse_args()

    trainloader, valloader, trainset, valset = loader(dataset=args.dataset, dirs=args.dataset_dir, trn_batch_size=args.batch_size, val_batch_size=args.batch_size, tst_batch_size=1000)

    config = TrainingConfig.from_args(args)

    model = get_model(args)
    data3 = prepare_data(args, trainloader)

    trainer = ModelTrainer(config, model, trainloader, valloader, trainset, data3)

    tracker = setup_tracker(args)
    if tracker:
        tracker.start()

    train_stats, val_stats = trainer.train()

    if tracker:
        tracker.stop()
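For reference, the `__main__` block above can also be driven programmatically; a minimal sketch of the same flow, using a `SimpleNamespace` in place of parsed arguments (all values below are illustrative and mirror the argparse defaults):

from types import SimpleNamespace

from graft.trainer import TrainingConfig, ModelTrainer, get_model, prepare_data
from graft.utils.loader import loader

# Illustrative arguments mirroring the CLI flags defined in __main__.
args = SimpleNamespace(
    batch_size=128, numEpochs=5, numClasses=10, lr=0.001, device="cuda",  # or "cpu"
    model="resnet50", dataset="cifar10", dataset_dir="./cifar10",
    weight_decay=1e-4, inp_channels=3, decomp="numpy", optimizer="sgd",
    select_iter=50, fraction=0.5, grad_clip=0.0, warm_start=False,
    use_wandb=False,  # skip wandb logging in this sketch
)

trainloader, valloader, trainset, valset = loader(
    dataset=args.dataset, dirs=args.dataset_dir,
    trn_batch_size=args.batch_size, val_batch_size=args.batch_size,
    tst_batch_size=1000)

config = TrainingConfig.from_args(args)
model = get_model(args)
data3 = prepare_data(args, trainloader)   # cached decomposition (or None)

trainer = ModelTrainer(config, model, trainloader, valloader, trainset, data3)
train_stats, val_stats = trainer.train()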
graft/utils/__init__.py
ADDED
graft/utils/extras.py
ADDED
@@ -0,0 +1,37 @@
import torch


def cal_val(val_loader, model, device):
    val_acc = []
    val_losses = []
    val_loss = 0
    val_total = 0
    val_correct = 0
    for _, (inputs, targets) in enumerate(val_loader):
        inputs, targets = inputs.to(device), targets.to(device, non_blocking=True)
        outputs = model(inputs)
        loss = torch.nn.functional.cross_entropy(outputs, targets)
        val_loss += loss.item()
        _, predicted = outputs.max(1)
        val_total += targets.size(0)
        val_correct += predicted.eq(targets).sum().item()
    # val_losses.append(val_loss)
    val_acc.append(val_correct / val_total)

    return val_acc[-1], val_loss / len(val_loader)


def elements_provider(l):
    my_iterator = iter(l)

    def getter():
        nonlocal my_iterator
        while True:
            try:
                return next(my_iterator)
            except StopIteration:
                pass
            my_iterator = iter(l)

    return getter
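As a quick illustration of `elements_provider`: the returned getter cycles through the wrapped iterable indefinitely, restarting it whenever it is exhausted (the list below is an arbitrary example):

from graft.utils.extras import elements_provider

next_item = elements_provider([1, 2, 3])
values = [next_item() for _ in range(5)]
print(values)  # [1, 2, 3, 1, 2]  -> wraps around once the iterator is exhausted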
graft/utils/generate.py
ADDED
@@ -0,0 +1,33 @@
import os
from medmnist import DermaMNIST
from PIL import Image

# Initialize the dataset
train_dataset = DermaMNIST(split='train', download=True, size=224)
valid_dataset = DermaMNIST(split='val', download=True, size=224)
test_dataset = DermaMNIST(split='test', download=True, size=224)

# Define the root directory for the reorganized dataset
root_dir = 'DermaMNIST'

# Define the subdirectories
subdirs = ['train', 'valid', 'test']
classes = [str(i) for i in range(7)]  # Assuming class labels are 0 through 6

# Create directories
for subdir in subdirs:
    for cls in classes:
        os.makedirs(os.path.join(root_dir, subdir, cls), exist_ok=True)

def save_images(dataset, subdir):
    for idx, (img, label) in enumerate(dataset):
        label = str(int(label[0]))  # Convert label to int and then to string
        img_path = os.path.join(root_dir, subdir, label, f"{subdir}_{idx}.png")
        img.save(img_path)

# Save images to corresponding directories
save_images(train_dataset, 'train')
save_images(valid_dataset, 'valid')
save_images(test_dataset, 'test')

print("Dataset reorganized successfully.")
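After the script has written the class-per-folder layout under `DermaMNIST/`, the splits can be consumed with a standard torchvision `ImageFolder`; a minimal sketch (the transform and batch size here are illustrative):

import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

transform = transforms.ToTensor()  # illustrative; add resizing/normalization as needed

train_set = datasets.ImageFolder('DermaMNIST/train', transform=transform)
valid_set = datasets.ImageFolder('DermaMNIST/valid', transform=transform)

train_loader = DataLoader(train_set, batch_size=64, shuffle=True)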
graft/utils/imagenetselloader.py
ADDED
@@ -0,0 +1,54 @@
# from libauc.datasets import CheXpert
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch
import os


def imagenet_selloader(dataset, dirs="./imagenet", trn_batch_size=64, val_batch_size=64, tst_batch_size=1000, resize=32):

    if dataset.lower() == "imagenet":
        # Define the data transforms
        traindir = os.path.join(dirs, 'train')
        valdir = os.path.join(dirs, 'val')

        fullset = datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.Resize(resize),
                transforms.CenterCrop(resize),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            ])
        )

        # testset = datasets.ImageFolder(
        #     valdir,
        #     transforms.Compose([
        #         transforms.Resize(resize),
        #         transforms.CenterCrop(resize),
        #         transforms.ToTensor(),
        #         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        #     ])
        # )

        # Creating the Data Loaders
        trainloader = torch.utils.data.DataLoader(fullset, batch_size=trn_batch_size,
                                                  shuffle=False, pin_memory=True, num_workers=2)

        # valloader = torch.utils.data.DataLoader(testset, batch_size=val_batch_size,
        #                                         shuffle=False, pin_memory=True, num_workers=2)

        # testloader = torch.utils.data.DataLoader(testset, batch_size=tst_batch_size,
        #                                          shuffle=False, pin_memory=True, num_workers=1)

        return trainloader
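This selection loader is consumed by `prepare_data` in graft/trainer.py, where it feeds the feature decomposition step; a minimal sketch of that usage (paths, batch size, and device below are illustrative):

from graft.utils.imagenetselloader import imagenet_selloader
from graft.decompositions import feature_sel

# Non-shuffled, down-sized ImageNet loader used only for sample selection.
selloader = imagenet_selloader("imagenet", dirs="./imagenet",
                               trn_batch_size=128, val_batch_size=128,
                               tst_batch_size=1000, resize=32)

V = feature_sel(selloader, 128, device="cuda", decomp_type="numpy")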