graft-pytorch 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
graft/scheduler.py ADDED
@@ -0,0 +1,63 @@
+ import logging
+ import math
+
+ from torch.optim.lr_scheduler import LambdaLR
+
+ logger = logging.getLogger(__name__)
+
+ class ConstantLRSchedule(LambdaLR):
+     """ Constant learning rate schedule.
+     """
+     def __init__(self, optimizer, last_epoch=-1):
+         super(ConstantLRSchedule, self).__init__(optimizer, lambda _: 1.0, last_epoch=last_epoch)
+
+
+ class WarmupConstantSchedule(LambdaLR):
+     """ Linear warmup, then constant.
+     Linearly increases the learning rate multiplier from 0 to 1 over `warmup_steps` training steps.
+     Keeps the multiplier equal to 1 after `warmup_steps`.
+     """
+     def __init__(self, optimizer, warmup_steps, last_epoch=-1):
+         self.warmup_steps = warmup_steps
+         super(WarmupConstantSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)
+
+     def lr_lambda(self, step):
+         if step < self.warmup_steps:
+             return float(step) / float(max(1.0, self.warmup_steps))
+         return 1.0
+
+
+ class WarmupLinearSchedule(LambdaLR):
+     """ Linear warmup, then linear decay.
+     Linearly increases the learning rate multiplier from 0 to 1 over `warmup_steps` training steps.
+     Linearly decreases the multiplier from 1 to 0 over the remaining `t_total - warmup_steps` steps.
+     """
+     def __init__(self, optimizer, warmup_steps, t_total, last_epoch=-1):
+         self.warmup_steps = warmup_steps
+         self.t_total = t_total
+         super(WarmupLinearSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)
+
+     def lr_lambda(self, step):
+         if step < self.warmup_steps:
+             return float(step) / float(max(1, self.warmup_steps))
+         return max(0.0, float(self.t_total - step) / float(max(1.0, self.t_total - self.warmup_steps)))
+
+
+ class WarmupCosineSchedule(LambdaLR):
+     """ Linear warmup, then cosine decay.
+     Linearly increases the learning rate multiplier from 0 to 1 over `warmup_steps` training steps.
+     Decreases the multiplier from 1 to 0 over the remaining `t_total - warmup_steps` steps following a cosine curve.
+     With the default `cycles=0.5` this is half a cosine period; other values change how many cosine cycles are completed after warmup.
+     """
+     def __init__(self, optimizer, warmup_steps, t_total, cycles=.5, last_epoch=-1):
+         self.warmup_steps = warmup_steps
+         self.t_total = t_total
+         self.cycles = cycles
+         super(WarmupCosineSchedule, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)
+
+     def lr_lambda(self, step):
+         if step < self.warmup_steps:
+             return float(step) / float(max(1.0, self.warmup_steps))
+         # progress after warmup
+         progress = float(step - self.warmup_steps) / float(max(1, self.t_total - self.warmup_steps))
+         return max(0.0, 0.5 * (1. + math.cos(math.pi * float(self.cycles) * 2.0 * progress)))
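
Editor's note: a minimal sketch of how these schedules plug into a standard PyTorch loop. The toy parameter, SGD optimizer, and step counts are illustrative only and are not part of the package.

import torch
from torch.optim import SGD
from graft.scheduler import WarmupCosineSchedule

# Toy parameter and optimizer (illustrative only).
param = torch.nn.Parameter(torch.zeros(10))
optimizer = SGD([param], lr=0.1)

# Warm up for 100 steps, then cosine-decay the LR multiplier to 0 over the remaining 900 steps.
scheduler = WarmupCosineSchedule(optimizer, warmup_steps=100, t_total=1000)

for step in range(1000):
    optimizer.step()   # parameter update (loss.backward() omitted in this sketch)
    scheduler.step()   # advance the schedule once per training step

The same pattern applies to WarmupConstantSchedule and WarmupLinearSchedule; only the constructor arguments differ.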
graft/trainer.py ADDED
@@ -0,0 +1,467 @@
+ # Standard library imports
+ import os
+ import pickle
+ import copy
+ import gc
+ import argparse
+
+ # Third-party imports
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ import torch.optim.lr_scheduler as lr_scheduler
+ from torch.utils.data import DataLoader, Subset
+ from tqdm import tqdm
+
+ # Local imports
+ from .utils.loader import loader
+ from .utils.model_mapper import ModelMapper
+ from .utils.imagenetselloader import imagenet_selloader
+ from .utils import pickler
+ from .decompositions import feature_sel
+ from .genindices import sample_selection
+
+ # Optional dependencies
+ try:
+     import wandb
+     WANDB_AVAILABLE = True
+ except ImportError:
+     WANDB_AVAILABLE = False
+
+ try:
+     import eco2ai
+     ECO2AI_AVAILABLE = True
+ except ImportError:
+     ECO2AI_AVAILABLE = False
+
+
+ class TrainingConfig:
+     def __init__(self, numEpochs, batch_size, device, net,
+                  model_name, dataset_name, trainloader, valloader,
+                  trainset, data3, optimizer_name, lr, weight_decay,
+                  grad_clip, fraction, selection_iter, warm_start,
+                  imgntselloader, sched="cosine", multi_checkpoint=False,
+                  use_wandb=True):  # wandb logging can be disabled per run
+
+         self.numEpochs = numEpochs
+         self.batch_size = batch_size
+         self.device = device
+         self.net = net
+         self.model_name = model_name
+         self.dataset_name = dataset_name
+         self.trainloader = trainloader
+         self.valloader = valloader
+         self.trainset = trainset
+         self.data3 = data3
+         self.optimizer_name = optimizer_name
+         self.lr = lr
+         self.weight_decay = weight_decay
+         self.grad_clip = grad_clip
+         self.fraction = fraction
+         self.selection_iter = selection_iter
+         self.warm_start = warm_start
+         self.imgntselloader = imgntselloader
+         self.sched = sched
+         self.multi_checkpoint = multi_checkpoint
+         self.use_wandb = use_wandb and WANDB_AVAILABLE
+
+     @classmethod
+     def from_args(cls, args):
+         return cls(
+             numEpochs=args.numEpochs,
+             batch_size=args.batch_size,
+             device=args.device,
+             net=None,  # Placeholder, will be set in the trainer
+             model_name=args.model,
+             dataset_name=args.dataset,
+             trainloader=None,  # Placeholder, will be set in the trainer
+             valloader=None,  # Placeholder, will be set in the trainer
+             trainset=None,  # Placeholder, will be set in the trainer
+             data3=None,  # Placeholder, will be set in the trainer
+             optimizer_name=args.optimizer,
+             lr=args.lr,
+             weight_decay=args.weight_decay,
+             grad_clip=args.grad_clip,
+             fraction=args.fraction,
+             selection_iter=args.select_iter,
+             warm_start=args.warm_start,
+             imgntselloader=None,  # Placeholder, will be set in the trainer
+             sched="cosine",
+             multi_checkpoint=False,
+             use_wandb=getattr(args, 'use_wandb', True)
+         )
+
+
+ class ModelTrainer:
+     def __init__(self, config, model, trainloader, valloader, trainset, data3):
+         self.config = config
+         self.model = model
+         self.trainloader = trainloader
+         self.valloader = valloader
+         self.trainset = trainset
+         self.data3 = data3
+         self.optimizer = None
+         self.scheduler = None
+         self.loss_fn = None
+         self.curr_high = 0
+         self.total = 0
+         self.correct = 0
+         self.trn_losses = list()
+         self.val_losses = list()
+         self.trn_acc = list()
+         self.val_acc = list()
+         self.selection = 0
+         self.weight_decay = 1e-4
+
+         self.dir_save = f"saved_models/{config.model_name}"
+         self.save_dir = f"{self.dir_save}/multi_checkpoint"
+
+         self._setup()
+
+     def _setup(self):
+         # Default to cross entropy loss unless specifically handling regression
+         self.loss_fn = torch.nn.functional.cross_entropy
+
+         # Create save directories
+         if not os.path.exists(self.dir_save):
+             os.makedirs(self.dir_save)
+         if not os.path.exists(self.save_dir):
+             os.makedirs(self.save_dir)
+
+         if self.config.optimizer_name.lower() == "adam":
+             self.optimizer = optim.Adam(self.model.parameters(), lr=self.config.lr)
+         elif self.config.optimizer_name.lower() == "sgd":
+             self.optimizer = optim.SGD(self.model.parameters(), lr=self.config.lr,
+                                        momentum=0.9, weight_decay=self.config.weight_decay)
+         else:
+             raise ValueError(f"Unsupported optimizer: {self.config.optimizer_name}")
+
+         if self.config.sched.lower() == "onecycle":
+             self.scheduler = lr_scheduler.OneCycleLR(self.optimizer, self.config.lr,
+                                                      epochs=self.config.numEpochs,
+                                                      steps_per_epoch=len(self.trainloader))
+         else:
+             self.scheduler = lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=200)
+
+         if self.config.use_wandb:
+             if WANDB_AVAILABLE:
+                 wandb.login()
+                 if self.config.model_name.lower() == "efficientnet-b0":
+                     model_name = "efficientnetb0"
+                 elif self.config.model_name.lower() == "efficientnet-b5":
+                     model_name = "efficientnetb5"
+                 config = {"lr": self.config.lr, "batch_size": self.config.batch_size}
+                 config.update({"architecture": f'{self.model}'})
+                 wandb.init(project=f"Smart_Sampling_{self.config.model_name}_{self.config.dataset_name}",
+                            config=config)
+
+         self.main_trainloader = self.trainloader
+
+     def train(self):
+         train_stats = {
+             'losses': self.trn_losses,
+             'accuracies': self.trn_acc,
+             'best_acc': max(self.trn_acc) if self.trn_acc else 0
+         }
+
+         val_stats = {
+             'losses': self.val_losses,
+             'accuracies': self.val_acc,
+             'best_acc': self.curr_high
+         }
+
+         for epoch in range(self.config.numEpochs):
+             self.model.train()
+             tot_train_loss = 0
+             before_lr = self.optimizer.param_groups[0]["lr"]
+             pruned_samples = 0
+             total_samples = 0
+
+             if epoch % self.config.selection_iter == 0:
+                 if self.config.warm_start and self.selection == 0:
+                     trainloader = self.trainloader
+                     self.selection += 1
+                 else:
+                     train_model = self.model
+                     cached_state_dict = copy.deepcopy(train_model.state_dict())
+                     clone_dict = copy.deepcopy(train_model.state_dict())
+
+                     # Skip this epoch if no data3 is available (e.g., in tests)
+                     if self.data3 is None:
+                         continue
+
+                     if not self.config.imgntselloader:
+                         indices = sample_selection(self.main_trainloader, self.data3, self.model,
+                                                    clone_dict, self.config.batch_size, self.config.fraction,
+                                                    self.config.selection_iter, self.config.numEpochs,
+                                                    self.config.device, self.config.dataset_name)
+                     else:
+                         indices = sample_selection(self.config.imgntselloader, self.data3, self.model,
+                                                    clone_dict, self.config.batch_size, self.config.fraction,
+                                                    self.config.selection_iter, self.config.numEpochs,
+                                                    self.config.device, self.config.dataset_name)
+
+                     self.model.load_state_dict(cached_state_dict)
+
+                     self.selection += 1
+
+                     datasubset = Subset(self.trainset, indices)
+                     new_trainloader = DataLoader(datasubset, batch_size=self.config.batch_size,
+                                                  shuffle=True, pin_memory=False, num_workers=1)
+
+                     self.trainloader = new_trainloader
+
+                     del cached_state_dict
+                     del clone_dict
+                     del train_model
+                     torch.cuda.empty_cache()
+                     gc.collect()
+
+             for _, (trainsamples, labels) in enumerate(tqdm(self.trainloader)):
+
+                 trainsamples = trainsamples.to(self.config.device)
+                 labels = labels.to(self.config.device)
+
+                 X = trainsamples
+                 Y = labels
+                 pred = self.model(X)
+
+                 # loss = torch.nn.functional.cross_entropy(pred, Y.to(device))
+                 loss = self.loss_fn(pred, Y)
+
+                 tot_train_loss += loss.item()
+
+                 self.optimizer.zero_grad()
+
+                 loss.backward()
+
+                 if self.config.grad_clip:
+                     nn.utils.clip_grad_value_(self.model.parameters(), self.config.grad_clip)
+
+                 self.optimizer.step()
+
+                 # calculate accuracy
+                 _, predicted = torch.max(pred.cpu().data, 1)
+                 self.total += Y.size(0)
+
+                 self.correct += (predicted == Y.cpu()).sum().item()
+                 # accuracy = 100 * correct / total
+                 pruned_samples += len(trainsamples) - len(X)
+                 total_samples += len(trainsamples)
+
+                 # OneCycleLR was built with steps_per_epoch, so it advances once per batch
+                 if self.config.sched.lower() == "onecycle":
+                     self.scheduler.step()
+
+             # Cosine annealing advances once per epoch
+             if self.config.sched.lower() == "cosine":
+                 self.scheduler.step()
+
+             after_lr = self.optimizer.param_groups[0]["lr"]
+
+             print("Last Epoch [%d] -> Current Epoch [%d]: lr %.4f -> %.4f optimizer %s" % (
+                 epoch, epoch + 1, before_lr, after_lr, self.config.optimizer_name))
+
+             if epoch % 20 == 0:
+                 dir_parts = self.dir_save.split('/')
+                 current_dir = ''
+
+                 for part in dir_parts:
+                     current_dir = os.path.join(current_dir, part)
+                     if not os.path.exists(current_dir):
+                         os.makedirs(current_dir)
+
+                 if not os.path.exists(self.save_dir):
+                     os.makedirs(self.save_dir)
+
+                 if not os.path.exists(self.dir_save):
+                     os.makedirs(self.dir_save)
+
+                 if self.config.selection_iter > self.config.numEpochs:
+                     file_prefix = "Full"
+                 else:
+                     file_prefix = "Sampled"
+
+                 if self.config.multi_checkpoint:
+                     file_prefix += "_multi"
+
+                 filename = f"{file_prefix}_{self.config.dataset_name}_sch{self.config.sched}_si{self.config.selection_iter}_f{self.config.fraction}"
+                 if self.config.multi_checkpoint:
+                     filename += f"_ep{epoch}"
+                     torch.save(self.model.state_dict(), f"{self.save_dir}/{filename}.pth")
+                 else:
+                     torch.save(self.model.state_dict(), f"{self.dir_save}/{filename}.pth")
+
+             # Evaluate on the training and validation sets every epoch
+             if (epoch + 1) % 1 == 0:
+                 trn_loss = 0
+                 trn_correct = 0
+                 trn_total = 0
+                 val_loss = 0
+                 val_correct = 0
+                 val_total = 0
+                 self.model.eval()
+                 with torch.no_grad():
+                     for _, (inputs, targets) in enumerate(self.trainloader):
+                         inputs, targets = inputs.to(self.config.device), \
+                             targets.to(self.config.device, non_blocking=True)
+                         outputs = self.model(inputs)
+                         loss = self.loss_fn(outputs, targets)
+                         trn_loss += loss.item()
+                         _, predicted = outputs.max(1)
+                         trn_total += targets.size(0)
+                         trn_correct += predicted.eq(targets).sum().item()
+                     self.trn_losses.append(trn_loss)
+                     self.trn_acc.append(trn_correct / trn_total)
+                 with torch.no_grad():
+                     for _, (inputs, targets) in enumerate(self.valloader):
+                         inputs, targets = inputs.to(self.config.device), \
+                             targets.to(self.config.device, non_blocking=True)
+                         outputs = self.model(inputs)
+                         loss = self.loss_fn(outputs, targets)
+                         val_loss += loss.item()
+                         _, predicted = outputs.max(1)
+                         val_total += targets.size(0)
+                         val_correct += predicted.eq(targets).sum().item()
+                     self.val_losses.append(val_loss)
+                     self.val_acc.append(val_correct / val_total)
+
+                 if self.val_acc[-1] > self.curr_high:
+                     self.curr_high = self.val_acc[-1]
+
+                 if self.config.use_wandb and WANDB_AVAILABLE:
+                     wandb.log({
+                         "Validation accuracy": self.curr_high,
+                         "Val Loss": self.val_losses[-1] / 100,
+                         "loss": self.trn_losses[-1] / 100,
+                         "Train Accuracy": self.trn_acc[-1] * 100,
+                         "Epoch": epoch
+                     })
+
+                 print("Epoch [{}/{}], Loss: {:.4f}, Train Accuracy: {:.2f}%".format(
+                     epoch + 1,
+                     self.config.numEpochs,
+                     self.trn_losses[-1],
+                     self.trn_acc[-1] * 100
+                 ))
+
+         print("Highest Accuracy:", self.curr_high)
+         print("Validation Accuracy:", self.val_acc[-1])
+         print("Validation Loss", self.val_losses[-1])
+
+         return train_stats, val_stats
+
+
+ def get_model(args):
+     arguments = type('', (), {'model': args.model.lower(), 'numClasses': args.numClasses,
+                               'device': args.device, 'in_chanls': args.inp_channels})()
+     model_mapper = ModelMapper(arguments)
+     return model_mapper.get_model()
+
+
+ def prepare_data(args, trainloader):
+     # Initialised up front so the ImageNet check below never hits an undefined name
+     imgntselloader = None
+     if args.select_iter < args.numEpochs:
+         pickle_dir = f"{args.dataset}_pickle"
+         file = os.path.join(pickle_dir, f"V_{args.batch_size}.pkl")
+
+         # Create pickle directory if it doesn't exist
+         if not os.path.exists(pickle_dir):
+             os.makedirs(pickle_dir)
+
+         if os.path.exists(file):
+             print("Loading existing pickle file")
+             with open(file, 'rb') as f:
+                 data3 = pickle.load(f)
+         else:
+             print("Generating new pickle file")
+             if args.dataset.lower() != "imagenet":
+                 V = feature_sel(trainloader, args.batch_size, device=args.device, decomp_type=args.decomp)
+                 data3 = V
+                 # Save pickle
+                 with open(file, 'wb') as f:
+                     pickle.dump(V, f)
+             else:
+                 imgntselloader = imagenet_selloader(args.dataset, dirs=args.dataset_dir,
+                                                     trn_batch_size=args.batch_size,
+                                                     val_batch_size=args.batch_size,
+                                                     tst_batch_size=1000, resize=32)
+
+                 V = feature_sel(imgntselloader, args.batch_size, device=args.device, decomp_type=args.decomp)
+                 data3 = V
+
+                 with open(file, 'wb') as f:
+                     pickle.dump(V, f)
+     else:
+         data3 = None
+
+     if args.dataset.lower() == "imagenet" and not imgntselloader:
+         imgntselloader = imagenet_selloader(args.dataset, dirs=args.dataset_dir,
+                                             trn_batch_size=args.batch_size,
+                                             val_batch_size=args.batch_size,
+                                             tst_batch_size=1000, resize=32)
+
+     return data3
+
+
+ def setup_tracker(args):
+     if not ECO2AI_AVAILABLE:
+         print("Warning: eco2ai not available, skipping emissions tracking")
+         return None
+
+     if args.warm_start:
+         ttype = "warm"
+     else:
+         ttype = "nowarm"
+
+     tracker = eco2ai.Tracker(
+         project_name=f"{args.model}_dset-{args.dataset}_bs-{args.batch_size}",
+         experiment_description="training DEIM_IS model",
+         file_name=f"emission_-{args.model}_dset-{args.dataset}_bs-{args.batch_size}_epochs-{args.numEpochs}_fraction-{args.fraction}_{args.optimizer}_{ttype}.csv"
+     )
+     return tracker
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser(description="Model Training with Smart Sampling")
+     parser.add_argument('--batch_size', default=128, type=int, required=True, help='(default=%(default)s)')
+     parser.add_argument('--numEpochs', default=5, type=int, required=True, help='(default=%(default)s)')
+     parser.add_argument('--numClasses', default=10, type=int, required=True, help='(default=%(default)s)')
+     parser.add_argument('--lr', default=0.001, type=float, required=False, help='learning rate')
+     parser.add_argument('--device', default='cuda', type=str, required=False, help='device to use for decompositions')
+     parser.add_argument('--model', default='resnet50', type=str, required=False, help='model to train')
+     parser.add_argument('--dataset', default="cifar10", type=str, required=False, help='indicate the dataset')
+     parser.add_argument('--dataset_dir', default="./cifar10", type=str, required=False, help='ImageNet folder')
+     parser.add_argument('--pretrained', default=False, action='store_true', help='use a pretrained model or not')
+     parser.add_argument('--weight_decay', default=0.0001, type=float, required=False, help='weight decay to be used')
+     parser.add_argument('--inp_channels', default=3, type=int, required=False, help='number of input channels')
+     parser.add_argument('--save_pickle', default=False, action='store_true', help='whether to save the U, S, V components')
+     parser.add_argument('--decomp', default="numpy", type=str, required=False, help='perform SVD using torch or numpy')
+     parser.add_argument('--optimizer', default="sgd", type=str, required=True, help='choice of optimizer')
+     parser.add_argument('--select_iter', default=50, type=int, required=True, help='data selection iteration')
+     parser.add_argument('--fraction', default=0.50, type=float, required=True, help='fraction of data')
+     parser.add_argument('--grad_clip', default=0.00, type=float, required=False, help='gradient clipping value')
+     parser.add_argument('--warm_start', default=False, action='store_true', help='train with a warm start')
+
+     args = parser.parse_args()
+
+     trainloader, valloader, trainset, valset = loader(dataset=args.dataset, dirs=args.dataset_dir,
+                                                       trn_batch_size=args.batch_size,
+                                                       val_batch_size=args.batch_size, tst_batch_size=1000)
+
+     config = TrainingConfig.from_args(args)
+
+     model = get_model(args)
+     data3 = prepare_data(args, trainloader)
+
+     trainer = ModelTrainer(config, model, trainloader, valloader, trainset, data3)
+
+     tracker = setup_tracker(args)
+     if tracker:
+         tracker.start()
+
+     train_stats, val_stats = trainer.train()
+
+     if tracker:
+         tracker.stop()
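
Editor's note: the periodic selection step above boils down to rebuilding the training DataLoader over a Subset of the original dataset. A standalone sketch of that mechanism follows; the random indices stand in for whatever graft.genindices.sample_selection returns, and the tensor dataset is a stand-in for the real trainset.

import torch
from torch.utils.data import DataLoader, Subset, TensorDataset

# Stand-in dataset; in the trainer this is `trainset` from utils.loader.
trainset = TensorDataset(torch.randn(1000, 3, 32, 32), torch.randint(0, 10, (1000,)))

# Placeholder for the indices produced by sample_selection.
fraction = 0.5
indices = torch.randperm(len(trainset))[: int(fraction * len(trainset))].tolist()

# Rebuild the loader over the selected subset, as done every `selection_iter` epochs.
subset_loader = DataLoader(Subset(trainset, indices), batch_size=128,
                           shuffle=True, pin_memory=False, num_workers=1)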
@@ -0,0 +1,5 @@
+ # from .loader import loader
+ from .model_mapper import ModelMapper
+ from .pickler import pickler
+ from .extras import cal_val, elements_provider
+
graft/utils/extras.py ADDED
@@ -0,0 +1,37 @@
+ import torch
+
+
+ def cal_val(val_loader, model, device):
+     """Run `model` over `val_loader` and return (final accuracy, mean loss per batch)."""
+     val_acc = []
+     val_losses = []
+     val_loss = 0
+     val_total = 0
+     val_correct = 0
+     for _, (inputs, targets) in enumerate(val_loader):
+         inputs, targets = inputs.to(device), targets.to(device, non_blocking=True)
+         outputs = model(inputs)
+         loss = torch.nn.functional.cross_entropy(outputs, targets)
+         val_loss += loss.item()
+         _, predicted = outputs.max(1)
+         val_total += targets.size(0)
+         val_correct += predicted.eq(targets).sum().item()
+     # val_losses.append(val_loss)
+     val_acc.append(val_correct / val_total)
+
+     return val_acc[-1], val_loss / len(val_loader)
+
+
+ def elements_provider(l):
+     """Return a zero-argument callable that yields elements of `l`, restarting from the beginning when exhausted."""
+     my_iterator = iter(l)
+
+     def getter():
+         nonlocal my_iterator
+         while True:
+             try:
+                 return next(my_iterator)
+             except StopIteration:
+                 pass
+             # iterator exhausted: start over from the beginning of `l`
+             my_iterator = iter(l)
+
+     return getter
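
Editor's note: a minimal usage sketch for these helpers, assuming the package layout shown above (graft/utils/extras.py). cal_val expects a validation DataLoader, a model, and a device, and returns the accuracy together with the mean per-batch loss.

from graft.utils.extras import elements_provider

# elements_provider returns a callable that cycles through a sequence indefinitely.
next_item = elements_provider([1, 2, 3])
print(next_item(), next_item(), next_item(), next_item())  # 1 2 3 1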
@@ -0,0 +1,33 @@
+ import os
+ from medmnist import DermaMNIST
+ from PIL import Image
+
+ # Initialize the dataset
+ train_dataset = DermaMNIST(split='train', download=True, size=224)
+ valid_dataset = DermaMNIST(split='val', download=True, size=224)
+ test_dataset = DermaMNIST(split='test', download=True, size=224)
+
+ # Define the root directory for the reorganized dataset
+ root_dir = 'DermaMNIST'
+
+ # Define the subdirectories
+ subdirs = ['train', 'valid', 'test']
+ classes = [str(i) for i in range(7)]  # Assuming class labels are 0 through 6
+
+ # Create directories
+ for subdir in subdirs:
+     for cls in classes:
+         os.makedirs(os.path.join(root_dir, subdir, cls), exist_ok=True)
+
+ def save_images(dataset, subdir):
+     for idx, (img, label) in enumerate(dataset):
+         label = str(int(label[0]))  # Convert label to int and then to string
+         img_path = os.path.join(root_dir, subdir, label, f"{subdir}_{idx}.png")
+         img.save(img_path)
+
+ # Save images to corresponding directories
+ save_images(train_dataset, 'train')
+ save_images(valid_dataset, 'valid')
+ save_images(test_dataset, 'test')
+
+ print("Dataset reorganized successfully.")
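
Editor's note: once the images are written out in this class-per-folder layout, they can be read back with torchvision's generic ImageFolder loader. A brief sketch; the ToTensor transform is an illustrative choice, not part of the package.

import torchvision.datasets as datasets
import torchvision.transforms as transforms

# The reorganized tree: DermaMNIST/train/<class>/train_<idx>.png, etc.
train_ds = datasets.ImageFolder("DermaMNIST/train", transform=transforms.ToTensor())
print(len(train_ds), train_ds.classes)  # classes are inferred from the 0..6 folder names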
@@ -0,0 +1,54 @@
+ # from libauc.datasets import CheXpert
+ import torchvision.datasets as datasets
+ import torchvision.transforms as transforms
+ import torch
+ import os
+
+
+ def imagenet_selloader(dataset, dirs="./imagenet", trn_batch_size=64, val_batch_size=64, tst_batch_size=1000, resize=32):
+
+     if dataset.lower() == "imagenet":
+         # Define the data transforms
+         traindir = os.path.join(dirs, 'train')
+         valdir = os.path.join(dirs, 'val')
+
+         fullset = datasets.ImageFolder(
+             traindir,
+             transforms.Compose([
+                 transforms.Resize(resize),
+                 transforms.CenterCrop(resize),
+                 transforms.ToTensor(),
+                 transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+             ])
+         )
+
+         # testset = datasets.ImageFolder(
+         #     valdir,
+         #     transforms.Compose([
+         #         transforms.Resize(resize),
+         #         transforms.CenterCrop(resize),
+         #         transforms.ToTensor(),
+         #         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+         #     ])
+         # )
+
+         # Creating the Data Loaders
+         trainloader = torch.utils.data.DataLoader(fullset, batch_size=trn_batch_size,
+                                                   shuffle=False, pin_memory=True, num_workers=2)
+
+         # valloader = torch.utils.data.DataLoader(testset, batch_size=val_batch_size,
+         #                                         shuffle=False, pin_memory=True, num_workers=2)
+
+         # testloader = torch.utils.data.DataLoader(testset, batch_size=tst_batch_size,
+         #                                          shuffle=False, pin_memory=True, num_workers=1)
+
+         return trainloader
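
Editor's note: a minimal sketch of how this selection loader would be used. The module path follows the import in trainer.py (graft.utils.imagenetselloader), and the ./imagenet directory with a train/ subfolder of class folders is an assumption about the local data layout, not something the package provides.

from graft.utils.imagenetselloader import imagenet_selloader

# Expects an ImageNet-style directory with a train/ subfolder of class folders.
selloader = imagenet_selloader("imagenet", dirs="./imagenet",
                               trn_batch_size=64, resize=32)

images, labels = next(iter(selloader))
print(images.shape)  # e.g. torch.Size([64, 3, 32, 32])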