torchvision 0.1.1 → 0.2.2
- checksums.yaml +4 -4
- data/CHANGELOG.md +35 -0
- data/LICENSE.txt +1 -1
- data/README.md +133 -7
- data/lib/torchvision.rb +39 -0
- data/lib/torchvision/datasets/cifar10.rb +117 -0
- data/lib/torchvision/datasets/cifar100.rb +41 -0
- data/lib/torchvision/datasets/dataset_folder.rb +91 -0
- data/lib/torchvision/datasets/fashion_mnist.rb +30 -0
- data/lib/torchvision/datasets/image_folder.rb +12 -0
- data/lib/torchvision/datasets/kmnist.rb +30 -0
- data/lib/torchvision/datasets/mnist.rb +47 -75
- data/lib/torchvision/datasets/vision_dataset.rb +67 -0
- data/lib/torchvision/models/alexnet.rb +42 -0
- data/lib/torchvision/models/basic_block.rb +46 -0
- data/lib/torchvision/models/bottleneck.rb +47 -0
- data/lib/torchvision/models/resnet.rb +129 -0
- data/lib/torchvision/models/resnet101.rb +9 -0
- data/lib/torchvision/models/resnet152.rb +9 -0
- data/lib/torchvision/models/resnet18.rb +9 -0
- data/lib/torchvision/models/resnet34.rb +9 -0
- data/lib/torchvision/models/resnet50.rb +9 -0
- data/lib/torchvision/models/resnext101_32x8d.rb +11 -0
- data/lib/torchvision/models/resnext50_32x4d.rb +11 -0
- data/lib/torchvision/models/vgg.rb +93 -0
- data/lib/torchvision/models/vgg11.rb +9 -0
- data/lib/torchvision/models/vgg11_bn.rb +9 -0
- data/lib/torchvision/models/vgg13.rb +9 -0
- data/lib/torchvision/models/vgg13_bn.rb +9 -0
- data/lib/torchvision/models/vgg16.rb +9 -0
- data/lib/torchvision/models/vgg16_bn.rb +9 -0
- data/lib/torchvision/models/vgg19.rb +9 -0
- data/lib/torchvision/models/vgg19_bn.rb +9 -0
- data/lib/torchvision/models/wide_resnet101_2.rb +10 -0
- data/lib/torchvision/models/wide_resnet50_2.rb +10 -0
- data/lib/torchvision/transforms/center_crop.rb +13 -0
- data/lib/torchvision/transforms/compose.rb +2 -2
- data/lib/torchvision/transforms/functional.rb +142 -8
- data/lib/torchvision/transforms/normalize.rb +2 -2
- data/lib/torchvision/transforms/random_horizontal_flip.rb +18 -0
- data/lib/torchvision/transforms/random_resized_crop.rb +70 -0
- data/lib/torchvision/transforms/random_vertical_flip.rb +18 -0
- data/lib/torchvision/transforms/resize.rb +13 -0
- data/lib/torchvision/transforms/to_tensor.rb +2 -2
- data/lib/torchvision/utils.rb +118 -0
- data/lib/torchvision/version.rb +1 -1
- metadata +51 -44
data/lib/torchvision/models/basic_block.rb
@@ -0,0 +1,46 @@
+module TorchVision
+  module Models
+    class BasicBlock < Torch::NN::Module
+      def initialize(inplanes, planes, stride: 1, downsample: nil, groups: 1, base_width: 64, dilation: 1, norm_layer: nil)
+        super()
+        norm_layer ||= Torch::NN::BatchNorm2d
+        if groups != 1 || base_width != 64
+          raise ArgumentError, "BasicBlock only supports groups=1 and base_width=64"
+        end
+        if dilation > 1
+          raise NotImplementedError, "Dilation > 1 not supported in BasicBlock"
+        end
+        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
+        @conv1 = Torch::NN::Conv2d.new(inplanes, planes, 3, stride: stride, padding: 1, groups: 1, bias: false, dilation: 1)
+        @bn1 = norm_layer.new(planes)
+        @relu = Torch::NN::ReLU.new(inplace: true)
+        @conv2 = Torch::NN::Conv2d.new(planes, planes, 3, stride: 1, padding: 1, groups: 1, bias: false, dilation: 1)
+        @bn2 = norm_layer.new(planes)
+        @downsample = downsample
+        @stride = stride
+      end
+
+      def forward(x)
+        identity = x
+
+        out = @conv1.call(x)
+        out = @bn1.call(out)
+        out = @relu.call(out)
+
+        out = @conv2.call(out)
+        out = @bn2.call(out)
+
+        identity = @downsample.call(x) if @downsample
+
+        out += identity
+        out = @relu.call(out)
+
+        out
+      end
+
+      def self.expansion
+        1
+      end
+    end
+  end
+end
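For orientation, here is a minimal sketch of exercising a BasicBlock on its own with torch-rb; the shapes and the stride-1 setup are illustrative assumptions, not part of this release:

require "torch"
require "torchvision"

# With stride 1 and matching channel counts, the identity shortcut needs no downsample.
block = TorchVision::Models::BasicBlock.new(64, 64)
x = Torch.randn(1, 64, 56, 56)  # assumed shape: batch 1, 64 channels, 56x56 feature map
y = block.call(x)               # Module#call dispatches to forward
puts y.shape.inspect            # the residual block preserves the shape: [1, 64, 56, 56]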
data/lib/torchvision/models/bottleneck.rb
@@ -0,0 +1,47 @@
+module TorchVision
+  module Models
+    class Bottleneck < Torch::NN::Module
+      def initialize(inplanes, planes, stride: 1, downsample: nil, groups: 1, base_width: 64, dilation: 1, norm_layer: nil)
+        super()
+        norm_layer ||= Torch::NN::BatchNorm2d
+        width = (planes * (base_width / 64.0)).to_i * groups
+        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
+        @conv1 = Torch::NN::Conv2d.new(inplanes, width, 1, stride: 1, bias: false)
+        @bn1 = norm_layer.new(width)
+        @conv2 = Torch::NN::Conv2d.new(width, width, 3, stride: stride, padding: dilation, groups: groups, bias: false, dilation: dilation)
+        @bn2 = norm_layer.new(width)
+        @conv3 = Torch::NN::Conv2d.new(width, planes * self.class.expansion, 1, stride: 1, bias: false)
+        @bn3 = norm_layer.new(planes * self.class.expansion)
+        @relu = Torch::NN::ReLU.new(inplace: true)
+        @downsample = downsample
+        @stride = stride
+      end
+
+      def forward(x)
+        identity = x
+
+        out = @conv1.call(x)
+        out = @bn1.call(out)
+        out = @relu.call(out)
+
+        out = @conv2.call(out)
+        out = @bn2.call(out)
+        out = @relu.call(out)
+
+        out = @conv3.call(out)
+        out = @bn3.call(out)
+
+        identity = @downsample.call(x) if @downsample
+
+        out += identity
+        out = @relu.call(out)
+
+        out
+      end
+
+      def self.expansion
+        4
+      end
+    end
+  end
+end
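Since Bottleneck.expansion is 4, the block emits planes * 4 output channels, so the shortcut needs a projection whenever its channels or stride disagree with the output; the ResNet diff below does exactly this inside _make_layer. A hedged sketch of constructing one by hand (shapes assumed for illustration):

expansion = TorchVision::Models::Bottleneck.expansion  # => 4
# 1x1 projection so the identity branch matches the 64 * 4 = 256 output channels
downsample = Torch::NN::Sequential.new(
  Torch::NN::Conv2d.new(64, 64 * expansion, 1, stride: 1, bias: false),
  Torch::NN::BatchNorm2d.new(64 * expansion)
)
block = TorchVision::Models::Bottleneck.new(64, 64, downsample: downsample)
x = Torch.randn(1, 64, 56, 56)
puts block.call(x).shape.inspect  # => [1, 256, 56, 56]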
data/lib/torchvision/models/resnet.rb
@@ -0,0 +1,129 @@
+module TorchVision
+  module Models
+    class ResNet < Torch::NN::Module
+      MODEL_URLS = {
+        "resnet18" => "https://download.pytorch.org/models/resnet18-5c106cde.pth",
+        "resnet34" => "https://download.pytorch.org/models/resnet34-333f7ec4.pth",
+        "resnet50" => "https://download.pytorch.org/models/resnet50-19c8e357.pth",
+        "resnet101" => "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth",
+        "resnet152" => "https://download.pytorch.org/models/resnet152-b121ed2d.pth",
+        "resnext50_32x4d" => "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
+        "resnext101_32x8d" => "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
+        "wide_resnet50_2" => "https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth",
+        "wide_resnet101_2" => "https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth"
+      }
+
+      def initialize(block, layers, num_classes: 1000, zero_init_residual: false,
+        groups: 1, width_per_group: 64, replace_stride_with_dilation: nil, norm_layer: nil)
+
+        super()
+        norm_layer ||= Torch::NN::BatchNorm2d
+        @norm_layer = norm_layer
+
+        @inplanes = 64
+        @dilation = 1
+        if replace_stride_with_dilation.nil?
+          # each element in the tuple indicates if we should replace
+          # the 2x2 stride with a dilated convolution instead
+          replace_stride_with_dilation = [false, false, false]
+        end
+        if replace_stride_with_dilation.length != 3
+          raise ArgumentError, "replace_stride_with_dilation should be nil or a 3-element tuple, got #{replace_stride_with_dilation}"
+        end
+        @groups = groups
+        @base_width = width_per_group
+        @conv1 = Torch::NN::Conv2d.new(3, @inplanes, 7, stride: 2, padding: 3, bias: false)
+        @bn1 = norm_layer.new(@inplanes)
+        @relu = Torch::NN::ReLU.new(inplace: true)
+        @maxpool = Torch::NN::MaxPool2d.new(3, stride: 2, padding: 1)
+        @layer1 = _make_layer(block, 64, layers[0])
+        @layer2 = _make_layer(block, 128, layers[1], stride: 2, dilate: replace_stride_with_dilation[0])
+        @layer3 = _make_layer(block, 256, layers[2], stride: 2, dilate: replace_stride_with_dilation[1])
+        @layer4 = _make_layer(block, 512, layers[3], stride: 2, dilate: replace_stride_with_dilation[2])
+        @avgpool = Torch::NN::AdaptiveAvgPool2d.new([1, 1])
+        @fc = Torch::NN::Linear.new(512 * block.expansion, num_classes)
+
+        modules.each do |m|
+          case m
+          when Torch::NN::Conv2d
+            Torch::NN::Init.kaiming_normal!(m.weight, mode: "fan_out", nonlinearity: "relu")
+          when Torch::NN::BatchNorm2d, Torch::NN::GroupNorm
+            Torch::NN::Init.constant!(m.weight, 1)
+            Torch::NN::Init.constant!(m.bias, 0)
+          end
+        end
+
+        # Zero-initialize the last BN in each residual branch,
+        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
+        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
+        if zero_init_residual
+          modules.each do |m|
+            case m
+            when Bottleneck
+              Torch::NN::Init.constant!(m.bn3.weight, 0)
+            when BasicBlock
+              Torch::NN::Init.constant!(m.bn2.weight, 0)
+            end
+          end
+        end
+      end
+
+      def _make_layer(block, planes, blocks, stride: 1, dilate: false)
+        norm_layer = @norm_layer
+        downsample = nil
+        previous_dilation = @dilation
+        if dilate
+          @dilation *= stride
+          stride = 1
+        end
+        if stride != 1 || @inplanes != planes * block.expansion
+          downsample = Torch::NN::Sequential.new(
+            Torch::NN::Conv2d.new(@inplanes, planes * block.expansion, 1, stride: stride, bias: false),
+            norm_layer.new(planes * block.expansion)
+          )
+        end
+
+        layers = []
+        layers << block.new(@inplanes, planes, stride: stride, downsample: downsample, groups: @groups, base_width: @base_width, dilation: previous_dilation, norm_layer: norm_layer)
+        @inplanes = planes * block.expansion
+        (blocks - 1).times do
+          layers << block.new(@inplanes, planes, groups: @groups, base_width: @base_width, dilation: @dilation, norm_layer: norm_layer)
+        end
+
+        Torch::NN::Sequential.new(*layers)
+      end
+
+      def _forward_impl(x)
+        x = @conv1.call(x)
+        x = @bn1.call(x)
+        x = @relu.call(x)
+        x = @maxpool.call(x)
+
+        x = @layer1.call(x)
+        x = @layer2.call(x)
+        x = @layer3.call(x)
+        x = @layer4.call(x)
+
+        x = @avgpool.call(x)
+        x = Torch.flatten(x, 1)
+        x = @fc.call(x)
+
+        x
+      end
+
+      def forward(x)
+        _forward_impl(x)
+      end
+
+      def self.make_model(arch, block, layers, pretrained: false, **kwargs)
+        model = ResNet.new(block, layers, **kwargs)
+        if pretrained
+          url = MODEL_URLS[arch]
+          state_dict = Torch::Hub.load_state_dict_from_url(url)
+          model.load_state_dict(state_dict)
+        end
+        model
+      end
+    end
+  end
+end
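The nine-to-ten-line resnet18.rb through wide_resnet101_2.rb files listed at the top are thin wrappers over make_model. Their exact contents are not shown in this diff; as an assumption about their shape (the layer counts below are the canonical torchvision recipes), a resnet18-style wrapper would look roughly like:

module TorchVision
  module Models
    module ResNet18
      def self.new(**kwargs)
        # BasicBlock with [2, 2, 2, 2] blocks per stage is the standard resnet18 recipe;
        # resnet50 would use Bottleneck with [3, 4, 6, 3]
        ResNet.make_model("resnet18", BasicBlock, [2, 2, 2, 2], **kwargs)
      end
    end
  end
end

with pretrained: true passing through **kwargs to make_model and pulling weights from MODEL_URLS, e.g. TorchVision::Models::ResNet18.new(pretrained: true).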
data/lib/torchvision/models/vgg.rb
@@ -0,0 +1,93 @@
+module TorchVision
+  module Models
+    class VGG < Torch::NN::Module
+      MODEL_URLS = {
+        "vgg11" => "https://download.pytorch.org/models/vgg11-bbd30ac9.pth",
+        "vgg13" => "https://download.pytorch.org/models/vgg13-c768596a.pth",
+        "vgg16" => "https://download.pytorch.org/models/vgg16-397923af.pth",
+        "vgg19" => "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
+        "vgg11_bn" => "https://download.pytorch.org/models/vgg11_bn-6002323d.pth",
+        "vgg13_bn" => "https://download.pytorch.org/models/vgg13_bn-abd245e5.pth",
+        "vgg16_bn" => "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth",
+        "vgg19_bn" => "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth"
+      }
+
+      def initialize(features, num_classes: 1000, init_weights: true)
+        super()
+        @features = features
+        @avgpool = Torch::NN::AdaptiveAvgPool2d.new([7, 7])
+        @classifier = Torch::NN::Sequential.new(
+          Torch::NN::Linear.new(512 * 7 * 7, 4096),
+          Torch::NN::ReLU.new(inplace: true),
+          Torch::NN::Dropout.new,
+          Torch::NN::Linear.new(4096, 4096),
+          Torch::NN::ReLU.new(inplace: true),
+          Torch::NN::Dropout.new,
+          Torch::NN::Linear.new(4096, num_classes)
+        )
+        _initialize_weights if init_weights
+      end
+
+      def forward(x)
+        x = @features.call(x)
+        x = @avgpool.call(x)
+        x = Torch.flatten(x, 1)
+        x = @classifier.call(x)
+        x
+      end
+
+      def _initialize_weights
+        modules.each do |m|
+          case m
+          when Torch::NN::Conv2d
+            Torch::NN::Init.kaiming_normal!(m.weight, mode: "fan_out", nonlinearity: "relu")
+            Torch::NN::Init.constant!(m.bias, 0) if m.bias
+          when Torch::NN::BatchNorm2d
+            Torch::NN::Init.constant!(m.weight, 1)
+            Torch::NN::Init.constant!(m.bias, 0)
+          when Torch::NN::Linear
+            Torch::NN::Init.normal!(m.weight, mean: 0, std: 0.01)
+            Torch::NN::Init.constant!(m.bias, 0)
+          end
+        end
+      end
+
+      CFGS = {
+        "A" => [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
+        "B" => [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
+        "D" => [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M", 512, 512, 512, "M"],
+        "E" => [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M", 512, 512, 512, 512, "M", 512, 512, 512, 512, "M"],
+      }
+
+      def self.make_model(arch, cfg, batch_norm, pretrained: false, **kwargs)
+        kwargs[:init_weights] = false if pretrained
+        model = VGG.new(make_layers(CFGS[cfg], batch_norm), **kwargs)
+        if pretrained
+          url = MODEL_URLS[arch]
+          state_dict = Torch::Hub.load_state_dict_from_url(url)
+          model.load_state_dict(state_dict)
+        end
+        model
+      end
+
+      def self.make_layers(cfg, batch_norm)
+        layers = []
+        in_channels = 3
+        cfg.each do |v|
+          if v == "M"
+            layers += [Torch::NN::MaxPool2d.new(2, stride: 2)]
+          else
+            conv2d = Torch::NN::Conv2d.new(in_channels, v, 3, padding: 1)
+            if batch_norm
+              layers += [conv2d, Torch::NN::BatchNorm2d.new(v), Torch::NN::ReLU.new(inplace: true)]
+            else
+              layers += [conv2d, Torch::NN::ReLU.new(inplace: true)]
+            end
+            in_channels = v
+          end
+        end
+        Torch::NN::Sequential.new(*layers)
+      end
+    end
+  end
+end
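Usage ties an arch string to its CFGS key: configuration "D" without batch norm is vgg16, "E" with batch norm is vgg19_bn, and so on. A hedged usage sketch (the input shape is an assumption; VGG is conventionally fed 224x224 RGB images):

# "D" selects the 16-layer configuration; false disables BatchNorm2d between conv and ReLU
model = TorchVision::Models::VGG.make_model("vgg16", "D", false)
x = Torch.randn(1, 3, 224, 224)
logits = model.call(x)
puts logits.shape.inspect  # => [1, 1000] class scores for the default num_classes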