graft-pytorch 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
graft/__init__.py ADDED
@@ -0,0 +1,20 @@
+ """
+ GRAFT: Gradient-Aware Fast MaxVol Technique for Dynamic Data Sampling
+
+ A PyTorch implementation of smart sampling for efficient deep learning training.
+ """
+
+ __version__ = "0.1.7"
+ __author__ = "Ashish Jha"
+ __email__ = "ashish.jha@skoltech.ru"
+
+ from .trainer import ModelTrainer, TrainingConfig
+ from .decompositions import feature_sel
+ from .genindices import sample_selection
+
+ __all__ = [
+     "ModelTrainer",
+     "TrainingConfig",
+     "feature_sel",
+     "sample_selection",
+ ]
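For reference, the public API re-exported above can be imported directly (names as declared in `__all__`):

import graft

print(graft.__version__)  # "0.1.7"

from graft import ModelTrainer, TrainingConfig, feature_sel, sample_selection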
graft/cli.py ADDED
@@ -0,0 +1,62 @@
+ #!/usr/bin/env python3
+ """
+ Command-line interface for GRAFT training.
+ """
+
+ def main():
+     """Main entry point for the CLI."""
+     import argparse
+     from .trainer import TrainingConfig, get_model, prepare_data, ModelTrainer, setup_tracker
+     from .utils.loader import loader
+
+     # Create argument parser (moved from trainer.py).
+     # Defaults are plain int/float literals rather than strings, and arguments
+     # that document a default are no longer also marked required=True, which
+     # previously made those defaults unreachable.
+     parser = argparse.ArgumentParser(description="Model training with smart sampling")
+     parser.add_argument('--batch_size', default=128, type=int, help='(default=%(default)s)')
+     parser.add_argument('--numEpochs', default=5, type=int, help='(default=%(default)s)')
+     parser.add_argument('--numClasses', default=10, type=int, help='(default=%(default)s)')
+     parser.add_argument('--lr', default=0.001, type=float, help='learning rate')
+     parser.add_argument('--device', default='cuda', type=str, help='device to use for decompositions')
+     parser.add_argument('--model', default='resnet50', type=str, help='model to train')
+     parser.add_argument('--dataset', default='cifar10', type=str, help='dataset to use')
+     parser.add_argument('--dataset_dir', default='./cifar10', type=str, help='dataset folder')
+     parser.add_argument('--pretrained', default=False, action='store_true', help='use a pretrained model')
+     parser.add_argument('--weight_decay', default=0.0001, type=float, help='weight decay to be used')
+     parser.add_argument('--inp_channels', default=3, type=int, help='number of input channels')
+     parser.add_argument('--save_pickle', default=False, action='store_true', help='whether to save the U, S, V components')
+     parser.add_argument('--decomp', default='numpy', type=str, help='perform SVD using "torch" or "numpy"')
+     parser.add_argument('--optimizer', default='sgd', type=str, help='choice of optimizer')
+     parser.add_argument('--select_iter', default=50, type=int, help='data selection interval (epochs)')
+     parser.add_argument('--fraction', default=0.50, type=float, help='fraction of data to keep')
+     parser.add_argument('--grad_clip', default=0.00, type=float, help='gradient clipping value')
+     parser.add_argument('--warm_start', default=False, action='store_true', help='train with a warm start')
+
+     args = parser.parse_args()
+
+     trainloader, valloader, trainset, valset = loader(
+         dataset=args.dataset,
+         dirs=args.dataset_dir,
+         trn_batch_size=args.batch_size,
+         val_batch_size=args.batch_size,
+         tst_batch_size=1000
+     )
+
+     config = TrainingConfig.from_args(args)
+     model = get_model(args)
+     data3 = prepare_data(args, trainloader)
+
+     trainer = ModelTrainer(config, model, trainloader, valloader, trainset, data3)
+
+     tracker = setup_tracker(args)
+     if tracker:
+         tracker.start()
+
+     train_stats, val_stats = trainer.train()
+
+     if tracker:
+         tracker.stop()
+
+ if __name__ == '__main__':
+     main()
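An illustrative way to drive this entry point from Python, equivalent to `python -m graft.cli ...` on the command line (flag names come from the parser above; the values are just examples):

import sys
from graft.cli import main

sys.argv = ["graft", "--batch_size", "128", "--numEpochs", "200",
            "--numClasses", "10", "--optimizer", "sgd",
            "--select_iter", "50", "--fraction", "0.5",
            "--model", "resnet18", "--dataset", "cifar10"]
main()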
graft/config.py ADDED
@@ -0,0 +1,36 @@
+ from dataclasses import dataclass
+
+ @dataclass
+ class TrainingConfig:
+     model_name: str
+     dataset_name: str
+     num_epochs: int
+     batch_size: int
+     lr: float
+     device: str
+     optimizer: str
+     weight_decay: float
+     grad_clip: float
+     fraction: float
+     selection_iter: int
+     warm_start: bool
+     sched: str = "cosine"
+     num_workers: int = 4
+
+     @classmethod
+     def from_args(cls, args):
+         return cls(
+             model_name=args.model,
+             dataset_name=args.dataset,
+             num_epochs=args.numEpochs,
+             batch_size=args.batch_size,
+             lr=args.lr,
+             device=args.device,
+             optimizer=args.optimizer,
+             weight_decay=args.weight_decay,
+             grad_clip=args.grad_clip,
+             fraction=args.fraction,
+             selection_iter=args.select_iter,
+             warm_start=args.warm_start,
+             sched=getattr(args, 'sched', 'cosine')
+         )
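Because TrainingConfig is a plain dataclass, it can also be built directly, bypassing argparse; the values here are illustrative:

config = TrainingConfig(
    model_name="resnet18",
    dataset_name="cifar10",
    num_epochs=200,
    batch_size=128,
    lr=0.001,
    device="cuda",
    optimizer="sgd",
    weight_decay=1e-4,
    grad_clip=0.0,
    fraction=0.5,
    selection_iter=50,
    warm_start=False,
)  # sched and num_workers keep their defaults ("cosine", 4)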
graft/decompositions.py ADDED
@@ -0,0 +1,54 @@
+ import numpy as np
+ from tqdm import tqdm
+ import torch
+
+ def index_sel(vh, r):
+     '''
+     Performs a maximum-volume (MaxVol-style) index selection on the batch-wise
+     Vt factors produced by feature_sel.
+     vh - the right singular vectors (Vt) of one batch
+     r  - the desired rank (number of indices to select)
+     '''
+     # Run on GPU when CUDA is available and the input is a torch tensor
+     device = "cuda" if torch.cuda.is_available() and isinstance(vh, torch.Tensor) else "cpu"
+
+     if isinstance(vh, torch.Tensor):
+         # Clone so the in-place elimination steps below do not mutate the caller's tensor
+         V = torch.transpose(vh, 0, 1).to(device).clone()
+         V = V[:, :r]
+         icol = []
+         for i in range(0, r):
+             # Pick the row with the largest absolute entry in column i
+             # (previously the comparison scanned the whole matrix, which could
+             # return row indices belonging to other columns)
+             col_i = torch.where(torch.abs(V[:, i]) == torch.max(torch.abs(V[:, i])))[0]
+             V[:, i+1:] = V[:, i+1:] - (V[:, 0:i+1] @ (torch.pinverse(V[col_i, 0:i+1]) @ V[col_i, i+1:]))
+             icol.append(col_i.cpu().numpy())
+     else:  # Assume a NumPy array
+         # Copy for the same reason: np.transpose returns a view
+         V = np.transpose(vh).copy()
+         V = V[:, :r]
+         icol = []
+         for i in range(0, r):
+             col_i = np.where(np.abs(V[:, i]) == np.max(np.abs(V[:, i])))[0]
+             V[:, i+1:] = V[:, i+1:] - (V[:, 0:i+1] @ (np.linalg.pinv(V[col_i, 0:i+1]) @ V[col_i, i+1:]))
+             icol.append(col_i)
+     return icol
+
+
+ def feature_sel(trainloader, batch_size, device, decomp_type="numpy"):
+     '''
+     Performs an SVD of each batch using either torch or numpy.
+     trainloader - the trainloader on which training will be performed
+     batch_size  - batch size used for training (note: the training and
+                   decomposition batch sizes must be the same)
+     device      - "cuda" can only be used with torch; if device is "cuda" and
+                   decomp_type is "numpy", "cpu" is used instead
+     decomp_type - perform the decomposition with "torch" or "numpy"
+     '''
+     V_list = []
+     for _, (trainsamples, _) in enumerate(tqdm(trainloader)):
+         if decomp_type == "torch":
+             _, _, Vt = torch.linalg.svd(torch.reshape(trainsamples.to(device), (-1, trainsamples.shape[0])), full_matrices=False)
+         else:
+             _, _, Vt = np.linalg.svd(np.reshape(trainsamples.cpu().numpy(), (-1, trainsamples.shape[0])), full_matrices=False)
+         V_list.append(Vt)
+
+     return V_list
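A minimal end-to-end sketch of the two functions on synthetic data (the DataLoader and shapes here are illustrative, not part of the package):

import itertools
import torch
from torch.utils.data import DataLoader, TensorDataset
from graft.decompositions import feature_sel, index_sel

# 256 fake CIFAR-like images in batches of 64
data = torch.randn(256, 3, 32, 32)
labels = torch.randint(0, 10, (256,))
loader = DataLoader(TensorDataset(data, labels), batch_size=64)

# One Vt factor per batch; each Vt has shape (64, 64) here
V_list = feature_sel(loader, batch_size=64, device="cpu", decomp_type="numpy")

# Select up to 8 representative within-batch indices from the first batch
idx = index_sel(V_list[0], r=8)
idx = sorted(set(itertools.chain(*idx)))
print(idx)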
graft/genindices.py ADDED
@@ -0,0 +1,122 @@
+ import torch
+ import itertools
+ from .decompositions import index_sel
+ from tqdm import tqdm
+ from .grad_dist import calnorm
+ import numpy as np
+ import math
+ import gc
+
+
+ def process_indices(indices):
+     '''
+     Processes per-batch indices into a single list of cumulative indices.
+     '''
+     l2 = indices[0]
+     for i in range(len(indices) - 1):
+         l2 = l2 + list(np.array(l2[-1]) + np.array(indices[i + 1]))
+     return l2
+
+
+ def sample_selection(trainloader, data3, net, clone_dict, batch_size, fraction, sel_iter, numEpochs, device, dataset_name):
+     # Note: seeds should be set by the caller for reproducibility
+
+     if dataset_name.lower() == 'boston':
+         loss_fn = torch.nn.MSELoss(reduction='mean')
+     else:
+         loss_fn = torch.nn.functional.cross_entropy
+     assert numEpochs > sel_iter, "Number of epochs must be greater than sel_iter"
+     indices = []
+     len_ranks = batch_size * fraction
+     min_range = int(len_ranks - (len_ranks * fraction))
+     max_range = int(len_ranks + (len_ranks * fraction))
+
+     if max_range - min_range < 1:
+         # Degenerate window: fall back to all ranks up to max_range and keep
+         # every one of them as a candidate. (The original passed a tuple to
+         # np.arange and an array as np.random.choice's size, both of which
+         # raise at runtime.)
+         ranks = np.arange(1, max_range, 1, dtype=int)
+         num_selections = int(numEpochs / sel_iter)
+         candidates = len(ranks)
+     else:
+         ranks = np.arange(min_range, max_range, 1, dtype=int)
+         num_selections = int(numEpochs / sel_iter)
+         candidates = math.ceil(len(ranks) / num_selections)
+
+     candidate_ranks = list(np.random.choice(list(ranks), size=candidates, replace=False))
+     if len(candidate_ranks) > 3:
+         candidate_ranks = list(np.random.choice(list(candidate_ranks), size=3, replace=False))
+     print("current selected rank candidates:", candidate_ranks)
+
+     # Track how many samples the per-batch winners cover
+     total_samples = len(trainloader.dataset)
+     selected_count = 0
+
+     for _, ((trainsamples, labels), V) in enumerate(tqdm(zip(trainloader, data3), desc="Sample Selection")):
+
+         net.load_state_dict(clone_dict)
+         trainsamples = trainsamples.to(device)
+         labels = labels.to(device)
+
+         A = np.reshape(trainsamples.detach().cpu().numpy(), (-1, trainsamples.shape[0]))
+         out, _ = net(trainsamples, last=True, freeze=True)
+
+         loss = loss_fn(out, labels).sum()
+         l0_grad = torch.autograd.grad(loss, out)[0]
+         distance_dict = {}
+         for rank in candidate_ranks:
+             net.load_state_dict(clone_dict)
+             idx2 = index_sel(V, min(rank, A.shape[1]))
+             idx2 = list(set(itertools.chain(*idx2)))
+             if dataset_name == "boston":
+                 out_idx, _ = net(trainsamples[idx2, :], last=True, freeze=True)
+             else:
+                 out_idx, _ = net(trainsamples[idx2, :, :, :], last=True, freeze=True)
+             loss_idx = loss_fn(out_idx, labels[idx2]).sum()
+             l0_idx_grad = torch.autograd.grad(loss_idx, out_idx)[0]
+             distance = calnorm(l0_idx_grad, l0_grad)
+             distance_dict[tuple(idx2)] = distance
+
+         # Keep the candidate whose gradient best matches the full batch
+         best_idx = list(min(distance_dict, key=distance_dict.get))
+         indices.append(best_idx)
+         selected_count += len(best_idx)
+
+     success_rate = (selected_count / total_samples) * 100
+     print(f"Sample selection complete - selected {selected_count}/{total_samples} samples ({success_rate:.2f}%)")
+
+     del clone_dict
+     del net
+     torch.cuda.empty_cache()
+     gc.collect()
+
+     # Flatten the per-batch winners into one list of indices
+     batch_indices = []
+     for batch_idx in indices:
+         if isinstance(batch_idx, list):
+             batch_indices.extend(batch_idx)
+         else:
+             batch_indices.append(batch_idx)
+
+     # Convert to occurrence counts/scores (note: these are positions within
+     # each batch, as produced by index_sel)
+     scores = np.zeros(len(trainloader.dataset))
+     for idx in batch_indices:
+         scores[idx] += 1
+
+     # Select the top fraction by score; argsort indices are already unique,
+     # so we only sort them for stable downstream indexing
+     num_to_select = int(len(trainloader.dataset) * fraction)
+     selected_indices = np.sort(np.argsort(scores)[::-1][:num_to_select])
+
+     final_success_rate = (len(selected_indices) / len(trainloader.dataset)) * 100
+     print(f"Final selection - kept {len(selected_indices)}/{len(trainloader.dataset)} samples ({final_success_rate:.2f}%)")
+
+     return selected_indices
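The returned dataset-level indices are meant to drive a reduced training loader. A hedged sketch of one way the trainer could consume them, using standard torch.utils.data (not part of GRAFT), with trainloader, data3, net, clone_dict, and trainset as set up in cli.py:

from torch.utils.data import DataLoader, Subset

selected_indices = sample_selection(trainloader, data3, net, clone_dict,
                                    batch_size=128, fraction=0.5, sel_iter=50,
                                    numEpochs=200, device="cuda",
                                    dataset_name="cifar10")
subset_loader = DataLoader(Subset(trainset, selected_indices.tolist()),
                           batch_size=128, shuffle=True)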
graft/grad_dist.py ADDED
@@ -0,0 +1,20 @@
+
+ import torch
+ import warnings
+
+ # Suppress specific NumPy 2.0 warnings for backward compatibility
+ warnings.filterwarnings("ignore", message="__array_wrap__ must accept context and return_scalar arguments")
+
+ def calnorm(idxgrads, fgrads):
+     # Keep everything in PyTorch to avoid NumPy 2.0 warnings
+     ss_grad = torch.transpose(idxgrads.clone().detach().cpu(), 0, 1)
+     b_ = fgrads.sum(dim=0).detach().cpu()
+
+     # Use PyTorch's pinverse instead of NumPy's
+     pinverse = torch.pinverse(ss_grad.float())
+     x = torch.matmul(pinverse, b_.float())
+
+     # Calculate the residual norm
+     norm_residual = torch.norm(torch.matmul(ss_grad.float(), x) - b_.float())
+     return norm_residual
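calnorm returns the residual of the least-squares problem min_x ||G_s^T x - g||, where G_s holds the subset's per-sample output gradients and g is the summed full-batch gradient: a small residual means the subset's gradients nearly span the full-batch gradient. A quick synthetic check (shapes are illustrative):

import torch
from graft.grad_dist import calnorm

g_full = torch.randn(64, 10)   # per-sample output gradients of a full batch
g_sub = g_full[:4]             # gradients of a 4-sample subset

# A rank-4 G_s cannot span a generic 10-d target, so the residual is > 0;
# growing the subset toward rank 10 drives it to ~0.
print(calnorm(g_sub, g_full))
print(calnorm(g_full[:32], g_full))  # ~0: full row rank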
@@ -0,0 +1,40 @@
+ from transformers import AutoModelForSequenceClassification
+ import torch.nn as nn
+
+
+ class bertmodel(nn.Module):
+     def __init__(self, device, numlabels=2):
+         super(bertmodel, self).__init__()
+         self.model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=numlabels)
+         self.device = device
+
+     def forward(self, x, indices=None, last=False, freeze=False):
+         # indices defaults to None so the model can be called like the vision
+         # models, which take no index argument
+         if freeze and last and not indices:
+             # Freeze the backbone and train only the classifier head
+             for param in self.model.parameters():
+                 param.requires_grad = False
+             for param in self.model.classifier.parameters():
+                 param.requires_grad = True
+
+             output = self.model(x["input_ids"].to(self.device),
+                                 attention_mask=x["attention_mask"].to(self.device), labels=x["label"].to(self.device))
+         elif freeze and last and indices:
+             # Same freezing, but only on the selected subset of the batch
+             for param in self.model.parameters():
+                 param.requires_grad = False
+             for param in self.model.classifier.parameters():
+                 param.requires_grad = True
+
+             output = self.model(x["input_ids"][indices].to(self.device),
+                                 attention_mask=x["attention_mask"][indices].to(self.device), labels=x["label"][indices].to(self.device))
+         else:
+             output = self.model(x["input_ids"].to(self.device),
+                                 attention_mask=x["attention_mask"].to(self.device), labels=x["label"].to(self.device))
+
+         return output
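A usage sketch, assuming this module is importable and the transformers tokenizer is available (it downloads DistilBERT weights on first run); the batch dict keys match what forward expects:

from transformers import AutoTokenizer
import torch

tok = AutoTokenizer.from_pretrained("distilbert-base-uncased")
net = bertmodel(device="cpu", numlabels=2)  # bertmodel as defined above

enc = tok(["a great movie", "a dull movie"], padding=True, return_tensors="pt")
batch = {"input_ids": enc["input_ids"],
         "attention_mask": enc["attention_mask"],
         "label": torch.tensor([1, 0])}

# Classifier-only gradients on the selected sample 0
out = net(batch, indices=[0], last=True, freeze=True)
print(out.loss, out.logits.shape)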
@@ -0,0 +1,111 @@
+ '''MobileNetV2 in PyTorch.
+
+ Reference:
+ Inverted Residuals and Linear Bottlenecks: Mobile Networks for Classification, Detection and Segmentation
+ https://arxiv.org/abs/1801.04381
+ '''
+
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ class Block(nn.Module):
+     '''expand + depthwise + pointwise'''
+
+     def __init__(self, in_planes, out_planes, expansion, stride):
+         super(Block, self).__init__()
+         self.stride = stride
+
+         planes = expansion * in_planes
+         self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, stride=1, padding=0, bias=False)
+         self.bn1 = nn.BatchNorm2d(planes)
+         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, groups=planes, bias=False)
+         self.bn2 = nn.BatchNorm2d(planes)
+         self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False)
+         self.bn3 = nn.BatchNorm2d(out_planes)
+
+         self.shortcut = nn.Sequential()
+         if stride == 1 and in_planes != out_planes:
+             self.shortcut = nn.Sequential(
+                 nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False),
+                 nn.BatchNorm2d(out_planes),
+             )
+
+     def forward(self, x):
+         out = F.relu(self.bn1(self.conv1(x)))
+         out = F.relu(self.bn2(self.conv2(out)))
+         out = self.bn3(self.conv3(out))
+         out = out + self.shortcut(x) if self.stride == 1 else out
+         return out
+
+
+ class MobileNetV2(nn.Module):
+     # (expansion, out_planes, num_blocks, stride)
+     cfg = [(1,  16, 1, 1),
+            (6,  24, 2, 1),  # NOTE: change stride 2 -> 1 for CIFAR10
+            (6,  32, 3, 2),
+            (6,  64, 4, 2),
+            (6,  96, 3, 1),
+            (6, 160, 3, 2),
+            (6, 320, 1, 1)]
+
+     def __init__(self, num_classes=10):
+         super(MobileNetV2, self).__init__()
+         self.embDim = 1280
+
+         # NOTE: change conv1 stride 2 -> 1 for CIFAR10
+         self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1, bias=False)
+         self.bn1 = nn.BatchNorm2d(32)
+         self.layers = self._make_layers(in_planes=32)
+         self.conv2 = nn.Conv2d(320, 1280, kernel_size=1, stride=1, padding=0, bias=False)
+         self.bn2 = nn.BatchNorm2d(1280)
+         self.linear = nn.Linear(1280, num_classes)
+
+     def _make_layers(self, in_planes):
+         layers = []
+         for expansion, out_planes, num_blocks, stride in self.cfg:
+             strides = [stride] + [1]*(num_blocks-1)
+             for stride in strides:
+                 layers.append(Block(in_planes, out_planes, expansion, stride))
+                 in_planes = out_planes
+         return nn.Sequential(*layers)
+
+     def forward(self, x, last=False, freeze=False):
+         if freeze:
+             with torch.no_grad():
+                 out = F.relu(self.bn1(self.conv1(x)))
+                 out = self.layers(out)
+                 out = F.relu(self.bn2(self.conv2(out)))
+                 # NOTE: change pooling kernel_size 7 -> 4 for CIFAR10
+                 out = F.avg_pool2d(out, 4)
+                 e = out.view(out.size(0), -1)
+         else:
+             out = F.relu(self.bn1(self.conv1(x)))
+             out = self.layers(out)
+             out = F.relu(self.bn2(self.conv2(out)))
+             # NOTE: change pooling kernel_size 7 -> 4 for CIFAR10
+             out = F.avg_pool2d(out, 4)
+             e = out.view(out.size(0), -1)
+         out = self.linear(e)
+         if last:
+             return out, e
+         else:
+             return out
+
+     def get_embedding_dim(self):
+         return self.embDim
+
+
+ def test():
+     net = MobileNetV2()
+     x = torch.randn(2, 3, 32, 32)
+     y = net(x)
+     print(y.size())
+
+ # test()
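The GRAFT-specific forward hooks in action, mirroring test() above:

import torch

net = MobileNetV2(num_classes=10)  # MobileNetV2 as defined above
x = torch.randn(2, 3, 32, 32)

logits, emb = net(x, last=True)      # logits plus the 1280-d embedding
frozen_logits = net(x, freeze=True)  # backbone runs under torch.no_grad();
                                     # only the linear head tracks gradients
print(logits.shape, emb.shape, frozen_logits.requires_grad)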
@@ -0,0 +1,154 @@
+ """
+ ResNeXt is a simple, highly modularized network architecture for image classification. The
+ network is constructed by repeating a building block that aggregates a set of transformations
+ with the same topology. The simple design results in a homogeneous, multi-branch architecture
+ that has only a few hyper-parameters to set. This strategy exposes a new dimension, referred
+ to as "cardinality" (the size of the set of transformations), as an essential factor in
+ addition to the dimensions of depth and width.
+
+ We can think of cardinality as a set of separate conv blocks that together represent the same
+ complexity as a single combined block.
+
+ Blog: https://towardsdatascience.com/review-resnext-1st-runner-up-of-ilsvrc-2016-image-classification-15d7f17b42ac
+
+ #### Citation ####
+
+ PyTorch Code: https://github.com/prlz77/ResNeXt.pytorch
+
+ @article{Xie2016,
+     title={Aggregated Residual Transformations for Deep Neural Networks},
+     author={Saining Xie and Ross Girshick and Piotr Dollár and Zhuowen Tu and Kaiming He},
+     journal={arXiv preprint arXiv:1611.05431},
+     year={2016}
+ }
+
+ """
+
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch
+
+
+ class Block(nn.Module):
+     """Grouped convolution block.
+
+     groups (int) controls the connections between inputs and outputs;
+     in_channels and out_channels must both be divisible by groups. For example:
+
+     At groups=1, all inputs are convolved to all outputs.
+     At groups=2, the operation becomes equivalent to two conv layers side by
+     side, each seeing half the input channels and producing half the output
+     channels, with both subsequently concatenated.
+     At groups=in_channels, each input channel is convolved with its own set
+     of filters (of size out_channels/in_channels).
+
+     Here the groups parameter in Conv2d splits the output channels by cardinality.
+     """
+     expansion = 2
+
+     def __init__(self, in_planes, cardinality=32, bottleneck_width=4, stride=1):
+         super(Block, self).__init__()
+         group_width = cardinality * bottleneck_width
+         self.conv1 = nn.Conv2d(in_planes, group_width, kernel_size=1, bias=False)
+         self.bn1 = nn.BatchNorm2d(group_width)
+         # groups=cardinality divides the out_channels by the cardinality,
+         # e.g. 128 channels into 32 groups of 4
+         self.conv2 = nn.Conv2d(group_width, group_width, kernel_size=3, stride=stride, padding=1, groups=cardinality, bias=False)
+         self.bn2 = nn.BatchNorm2d(group_width)
+         self.conv3 = nn.Conv2d(group_width, self.expansion*group_width, kernel_size=1, bias=False)
+         self.bn3 = nn.BatchNorm2d(self.expansion*group_width)
+
+         self.shortcut = nn.Sequential()
+         if stride != 1 or in_planes != self.expansion*group_width:
+             self.shortcut = nn.Sequential(
+                 nn.Conv2d(in_planes, self.expansion*group_width, kernel_size=1, stride=stride, bias=False),
+                 nn.BatchNorm2d(self.expansion*group_width)
+             )
+
+     def forward(self, x):
+         out = F.relu(self.bn1(self.conv1(x)))
+         out = F.relu(self.bn2(self.conv2(out)))
+         out = self.bn3(self.conv3(out))
+         out += self.shortcut(x)
+         out = F.relu(out)
+         return out
+
+
+ class ResNeXt(nn.Module):
+     def __init__(self, input_channel, num_blocks, cardinality, bottleneck_width, n_classes=10):
+         super(ResNeXt, self).__init__()
+         self.cardinality = cardinality
+         self.bottleneck_width = bottleneck_width
+         self.in_planes = 64
+
+         self.conv1 = nn.Conv2d(input_channel, 64, kernel_size=1, bias=False)
+         self.bn1 = nn.BatchNorm2d(64)
+         self.layer1 = self._make_layer(num_blocks[0], 1)
+         self.layer2 = self._make_layer(num_blocks[1], 2)
+         self.layer3 = self._make_layer(num_blocks[2], 2)
+         # self.layer4 = self._make_layer(num_blocks[3], 2)
+         self.pool = nn.AdaptiveAvgPool2d(1)
+         self.linear = nn.Linear(cardinality*bottleneck_width*8, n_classes)
+
+     def _make_layer(self, num_blocks, stride):
+         strides = [stride] + [1]*(num_blocks-1)
+         layers = []
+         for stride in strides:
+             layers.append(Block(self.in_planes, self.cardinality, self.bottleneck_width, stride))
+             self.in_planes = Block.expansion * self.cardinality * self.bottleneck_width
+         # Double the bottleneck width after each stage.
+         self.bottleneck_width *= 2
+         return nn.Sequential(*layers)
+
+     def forward(self, x, last=False, freeze=False):
+         if freeze:
+             with torch.no_grad():
+                 out = F.relu(self.bn1(self.conv1(x)))
+                 out = self.layer1(out)
+                 out = self.layer2(out)
+                 out = self.layer3(out)
+                 out = self.pool(out)
+                 e = out.view(out.size(0), -1)
+         else:
+             out = F.relu(self.bn1(self.conv1(x)))
+             out = self.layer1(out)
+             out = self.layer2(out)
+             out = self.layer3(out)
+             out = self.pool(out)
+             e = out.view(out.size(0), -1)
+         out = self.linear(e)
+         if last:
+             return out, e
+         else:
+             return out
+
+
+ def ResNeXt29_2x64d(input_channel, n_classes):
+     return ResNeXt(input_channel=input_channel, num_blocks=[3,3,3], cardinality=2, bottleneck_width=64, n_classes=n_classes)
+
+ def ResNeXt29_4x64d(input_channel, n_classes):
+     return ResNeXt(input_channel, num_blocks=[3,3,3], cardinality=4, bottleneck_width=64, n_classes=n_classes)
+
+ def ResNeXt29_8x64d(input_channel, n_classes):
+     return ResNeXt(input_channel, num_blocks=[3,3,3], cardinality=8, bottleneck_width=64, n_classes=n_classes)
+
+ def ResNeXt29_32x4d(input_channel, n_classes):
+     return ResNeXt(input_channel, num_blocks=[3,3,3], cardinality=32, bottleneck_width=4, n_classes=n_classes)
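A quick shape check for the CIFAR-sized variants, mirroring the MobileNetV2 test above:

import torch

net = ResNeXt29_32x4d(input_channel=3, n_classes=10)  # factory defined above
x = torch.randn(2, 3, 32, 32)
logits, emb = net(x, last=True)
print(logits.size(), emb.size())  # torch.Size([2, 10]) torch.Size([2, 1024])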
@@ -0,0 +1,22 @@
+ from .efficientnet import EfficientNetB0
+ from .efficientnetb7 import EfficientNet
+ from .ResNeXt import ResNeXt29_2x64d
+ from .ResNeXt import ResNeXt29_4x64d
+ from .ResNeXt import ResNeXt29_8x64d
+ from .ResNeXt import ResNeXt29_32x4d
+ from .resnet import ResNext50_32x4d
+ from .resnet import ResNext101_32x8d
+ from .resnet import ResNext101_64x4d
+ from .mobilenet import MobileNet
+ from .MobilenetV2 import MobileNetV2
+ from .resnet import ResNet18
+ from .resnet import ResNet34
+ from .resnet import ResNet50
+ from .resnet import ResNet101
+ from .resnet import ResNet152
+ from .resnet9 import ResNet9
+ from .fashioncnn import FashionCNN
+ # from .regression import TwoLayerNet
+ # from .regression import ThreeLayerNet
+ from .BERT_model import bertmodel
+ # from .regression import DualNet