torchvision 0.1.0 → 0.2.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +35 -0
- data/LICENSE.txt +1 -1
- data/README.md +133 -5
- data/lib/torchvision.rb +40 -1
- data/lib/torchvision/datasets/cifar10.rb +117 -0
- data/lib/torchvision/datasets/cifar100.rb +41 -0
- data/lib/torchvision/datasets/dataset_folder.rb +91 -0
- data/lib/torchvision/datasets/fashion_mnist.rb +30 -0
- data/lib/torchvision/datasets/image_folder.rb +12 -0
- data/lib/torchvision/datasets/kmnist.rb +30 -0
- data/lib/torchvision/datasets/mnist.rb +47 -76
- data/lib/torchvision/datasets/vision_dataset.rb +67 -0
- data/lib/torchvision/models/alexnet.rb +42 -0
- data/lib/torchvision/models/basic_block.rb +46 -0
- data/lib/torchvision/models/bottleneck.rb +47 -0
- data/lib/torchvision/models/resnet.rb +129 -0
- data/lib/torchvision/models/resnet101.rb +9 -0
- data/lib/torchvision/models/resnet152.rb +9 -0
- data/lib/torchvision/models/resnet18.rb +9 -0
- data/lib/torchvision/models/resnet34.rb +9 -0
- data/lib/torchvision/models/resnet50.rb +9 -0
- data/lib/torchvision/models/resnext101_32x8d.rb +11 -0
- data/lib/torchvision/models/resnext50_32x4d.rb +11 -0
- data/lib/torchvision/models/vgg.rb +93 -0
- data/lib/torchvision/models/vgg11.rb +9 -0
- data/lib/torchvision/models/vgg11_bn.rb +9 -0
- data/lib/torchvision/models/vgg13.rb +9 -0
- data/lib/torchvision/models/vgg13_bn.rb +9 -0
- data/lib/torchvision/models/vgg16.rb +9 -0
- data/lib/torchvision/models/vgg16_bn.rb +9 -0
- data/lib/torchvision/models/vgg19.rb +9 -0
- data/lib/torchvision/models/vgg19_bn.rb +9 -0
- data/lib/torchvision/models/wide_resnet101_2.rb +10 -0
- data/lib/torchvision/models/wide_resnet50_2.rb +10 -0
- data/lib/torchvision/transforms/center_crop.rb +13 -0
- data/lib/torchvision/transforms/compose.rb +2 -2
- data/lib/torchvision/transforms/functional.rb +142 -7
- data/lib/torchvision/transforms/normalize.rb +2 -2
- data/lib/torchvision/transforms/random_horizontal_flip.rb +18 -0
- data/lib/torchvision/transforms/random_resized_crop.rb +70 -0
- data/lib/torchvision/transforms/random_vertical_flip.rb +18 -0
- data/lib/torchvision/transforms/resize.rb +13 -0
- data/lib/torchvision/transforms/to_tensor.rb +2 -2
- data/lib/torchvision/utils.rb +120 -0
- data/lib/torchvision/version.rb +1 -1
- metadata +50 -57
@@ -0,0 +1,46 @@
|
|
1
|
+
module TorchVision
  module Models
    # Residual block with two 3x3 convolutions, used by the ResNet-18/34 style networks.
    #
    # Main path: conv(3x3, stride) -> norm -> ReLU -> conv(3x3) -> norm.
    # The result is added to a shortcut (the input, or `downsample` applied to the input)
    # and passed through a final ReLU.
    class BasicBlock < Torch::NN::Module
      # @param inplanes [Integer] number of input channels
      # @param planes [Integer] number of output channels of both convolutions
      # @param stride [Integer] stride of the first convolution
      # @param downsample [#call, nil] optional module applied to the input to build the shortcut
      # @param groups [Integer] only 1 is supported
      # @param base_width [Integer] only 64 is supported
      # @param dilation [Integer] values greater than 1 are rejected
      # @param norm_layer [Class, nil] normalization layer class (defaults to Torch::NN::BatchNorm2d)
      # @raise [ArgumentError] when groups != 1 or base_width != 64
      # @raise [NotImplementedError] when dilation > 1
      def initialize(inplanes, planes, stride: 1, downsample: nil, groups: 1, base_width: 64, dilation: 1, norm_layer: nil)
        super()
        norm = norm_layer || Torch::NN::BatchNorm2d

        if groups != 1 || base_width != 64
          raise ArgumentError, "BasicBlock only supports groups=1 and base_width=64"
        end
        raise NotImplementedError, "Dilation > 1 not supported in BasicBlock" if dilation > 1

        # Submodules are created in a fixed order; keep it, since construction
        # consumes the RNG and determines the order modules are registered in.
        # When stride != 1, both @conv1 and @downsample reduce the spatial size.
        @conv1 = Torch::NN::Conv2d.new(inplanes, planes, 3, stride: stride, padding: 1, groups: 1, bias: false, dilation: 1)
        @bn1 = norm.new(planes)
        @relu = Torch::NN::ReLU.new(inplace: true)
        @conv2 = Torch::NN::Conv2d.new(planes, planes, 3, stride: 1, padding: 1, groups: 1, bias: false, dilation: 1)
        @bn2 = norm.new(planes)
        @downsample = downsample
        @stride = stride
      end

      # Runs the residual block on `x` and returns the activated sum of branch and shortcut.
      def forward(x)
        branch = @relu.call(@bn1.call(@conv1.call(x)))
        branch = @bn2.call(@conv2.call(branch))

        shortcut = @downsample ? @downsample.call(x) : x

        @relu.call(branch + shortcut)
      end

      class << self
        # Ratio of output channels to `planes`; this block does not widen its output.
        def expansion
          1
        end
      end
    end
  end
end
|
@@ -0,0 +1,47 @@
|
|
1
|
+
module TorchVision
  module Models
    # Bottleneck residual block (1x1 -> 3x3 -> 1x1 convolutions), used by ResNet-50 and deeper,
    # ResNeXt and Wide ResNet.
    #
    # The inner width is `planes * (base_width / 64.0)` truncated, times `groups`.
    # The block outputs `planes * expansion` channels.
    class Bottleneck < Torch::NN::Module
      # @param inplanes [Integer] number of input channels
      # @param planes [Integer] base channel count; output has planes * expansion channels
      # @param stride [Integer] stride of the 3x3 convolution
      # @param downsample [#call, nil] optional module applied to the input to build the shortcut
      # @param groups [Integer] number of groups for the 3x3 convolution
      # @param base_width [Integer] width per group used to compute the inner width
      # @param dilation [Integer] dilation (and padding) of the 3x3 convolution
      # @param norm_layer [Class, nil] normalization layer class (defaults to Torch::NN::BatchNorm2d)
      def initialize(inplanes, planes, stride: 1, downsample: nil, groups: 1, base_width: 64, dilation: 1, norm_layer: nil)
        super()
        norm = norm_layer || Torch::NN::BatchNorm2d
        inner = (planes * (base_width / 64.0)).to_i * groups
        out_channels = planes * self.class.expansion

        # Construction order is kept fixed (RNG consumption / module registration order).
        # When stride != 1, both @conv2 and @downsample reduce the spatial size.
        @conv1 = Torch::NN::Conv2d.new(inplanes, inner, 1, stride: 1, bias: false)
        @bn1 = norm.new(inner)
        @conv2 = Torch::NN::Conv2d.new(inner, inner, 3, stride: stride, padding: dilation, groups: groups, bias: false, dilation: dilation)
        @bn2 = norm.new(inner)
        @conv3 = Torch::NN::Conv2d.new(inner, out_channels, 1, stride: 1, bias: false)
        @bn3 = norm.new(out_channels)
        @relu = Torch::NN::ReLU.new(inplace: true)
        @downsample = downsample
        @stride = stride
      end

      # Runs the bottleneck block on `x` and returns the activated sum of branch and shortcut.
      def forward(x)
        h = @relu.call(@bn1.call(@conv1.call(x)))
        h = @relu.call(@bn2.call(@conv2.call(h)))
        h = @bn3.call(@conv3.call(h))

        skip = @downsample ? @downsample.call(x) : x

        @relu.call(h + skip)
      end

      class << self
        # Ratio of output channels to `planes` for this block.
        def expansion
          4
        end
      end
    end
  end
end
|
@@ -0,0 +1,129 @@
|
|
1
|
+
module TorchVision
  module Models
    # ResNet family backbone (ResNet, ResNeXt, Wide ResNet).
    #
    # Stem: 7x7 conv (stride 2) -> norm -> ReLU -> 3x3 max pool.
    # Then four stages of residual blocks with 64/128/256/512 base channels,
    # global average pooling, and a linear classifier.
    class ResNet < Torch::NN::Module
      # Pretrained weight locations, keyed by architecture name.
      MODEL_URLS = {
        "resnet18" => "https://download.pytorch.org/models/resnet18-5c106cde.pth",
        "resnet34" => "https://download.pytorch.org/models/resnet34-333f7ec4.pth",
        "resnet50" => "https://download.pytorch.org/models/resnet50-19c8e357.pth",
        "resnet101" => "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth",
        "resnet152" => "https://download.pytorch.org/models/resnet152-b121ed2d.pth",
        "resnext50_32x4d" => "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
        "resnext101_32x8d" => "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
        "wide_resnet50_2" => "https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth",
        "wide_resnet101_2" => "https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth"
      }.freeze

      # @param block [Class] residual block class (e.g. BasicBlock or Bottleneck); must respond to `.expansion`
      # @param layers [Array<Integer>] number of blocks in each of the four stages
      # @param positional_num_classes [Integer, nil] legacy positional form of `num_classes`
      # @param num_classes [Integer, nil] output size of the final linear layer (default 1000).
      #   Accepted as a keyword so that `make_model(..., num_classes: n)` works; takes
      #   precedence over the positional form.
      # @param zero_init_residual [Boolean] zero-initialize the last norm weight in each residual branch
      # @param groups [Integer] groups passed to every block
      # @param width_per_group [Integer] base width passed to every block
      # @param replace_stride_with_dilation [Array<Boolean>, nil] three flags for stages 2-4
      # @param norm_layer [Class, nil] normalization layer class (defaults to Torch::NN::BatchNorm2d)
      # @raise [ArgumentError] if replace_stride_with_dilation does not have 3 elements
      def initialize(block, layers, positional_num_classes = nil, num_classes: nil, zero_init_residual: false,
                     groups: 1, width_per_group: 64, replace_stride_with_dilation: nil, norm_layer: nil)

        super()
        num_classes ||= positional_num_classes || 1000
        norm_layer ||= Torch::NN::BatchNorm2d
        @norm_layer = norm_layer

        @inplanes = 64
        @dilation = 1
        # Each element indicates whether the corresponding stage (2, 3, 4) should
        # replace its 2x2 stride with a dilated convolution instead.
        replace_stride_with_dilation ||= [false, false, false]
        if replace_stride_with_dilation.length != 3
          raise ArgumentError, "replace_stride_with_dilation should be nil or a 3-element tuple, got #{replace_stride_with_dilation}"
        end
        @groups = groups
        @base_width = width_per_group

        # Construction order is significant: it fixes RNG consumption and the
        # order in which submodules are registered.
        @conv1 = Torch::NN::Conv2d.new(3, @inplanes, 7, stride: 2, padding: 3, bias: false)
        @bn1 = norm_layer.new(@inplanes)
        @relu = Torch::NN::ReLU.new(inplace: true)
        @maxpool = Torch::NN::MaxPool2d.new(3, stride: 2, padding: 1)
        @layer1 = _make_layer(block, 64, layers[0])
        @layer2 = _make_layer(block, 128, layers[1], stride: 2, dilate: replace_stride_with_dilation[0])
        @layer3 = _make_layer(block, 256, layers[2], stride: 2, dilate: replace_stride_with_dilation[1])
        @layer4 = _make_layer(block, 512, layers[3], stride: 2, dilate: replace_stride_with_dilation[2])
        @avgpool = Torch::NN::AdaptiveAvgPool2d.new([1, 1])
        @fc = Torch::NN::Linear.new(512 * block.expansion, num_classes)

        modules.each do |m|
          case m
          when Torch::NN::Conv2d
            Torch::NN::Init.kaiming_normal!(m.weight, mode: "fan_out", nonlinearity: "relu")
          when Torch::NN::BatchNorm2d, Torch::NN::GroupNorm
            Torch::NN::Init.constant!(m.weight, 1)
            Torch::NN::Init.constant!(m.bias, 0)
          end
        end

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        return unless zero_init_residual

        modules.each do |m|
          case m
          when Bottleneck
            Torch::NN::Init.constant!(m.bn3.weight, 0)
          when BasicBlock
            Torch::NN::Init.constant!(m.bn2.weight, 0)
          end
        end
      end

      # Builds one stage of `blocks` residual blocks with `planes` base channels.
      #
      # The first block receives `stride`, the (optional) downsample shortcut and the
      # dilation that was in effect before this stage. The remaining blocks use stride 1
      # and the current dilation. Updates @inplanes (and @dilation when `dilate`).
      #
      # @return [Torch::NN::Sequential]
      def _make_layer(block, planes, blocks, stride: 1, dilate: false)
        norm_layer = @norm_layer
        downsample = nil
        previous_dilation = @dilation
        if dilate
          @dilation *= stride
          stride = 1
        end
        # A projection shortcut is needed whenever the spatial size or channel count changes.
        if stride != 1 || @inplanes != planes * block.expansion
          downsample = Torch::NN::Sequential.new(
            Torch::NN::Conv2d.new(@inplanes, planes * block.expansion, 1, stride: stride, bias: false),
            norm_layer.new(planes * block.expansion)
          )
        end

        layers = []
        layers << block.new(@inplanes, planes, stride: stride, downsample: downsample, groups: @groups, base_width: @base_width, dilation: previous_dilation, norm_layer: norm_layer)
        @inplanes = planes * block.expansion
        (blocks - 1).times do
          layers << block.new(@inplanes, planes, groups: @groups, base_width: @base_width, dilation: @dilation, norm_layer: norm_layer)
        end

        Torch::NN::Sequential.new(*layers)
      end

      # Stem -> four stages -> global average pool -> flatten -> linear classifier.
      def _forward_impl(x)
        x = @conv1.call(x)
        x = @bn1.call(x)
        x = @relu.call(x)
        x = @maxpool.call(x)

        x = @layer1.call(x)
        x = @layer2.call(x)
        x = @layer3.call(x)
        x = @layer4.call(x)

        x = @avgpool.call(x)
        x = Torch.flatten(x, 1)
        @fc.call(x)
      end

      def forward(x)
        _forward_impl(x)
      end

      # Builds a ResNet and optionally loads pretrained weights for `arch`.
      #
      # @param arch [String] key into MODEL_URLS (used only when `pretrained`)
      # @param kwargs [Hash] forwarded to ResNet.new (e.g. num_classes:, groups:, width_per_group:)
      # @raise [ArgumentError] if `pretrained` is set and no weights are known for `arch`
      def self.make_model(arch, block, layers, pretrained: false, **kwargs)
        model = ResNet.new(block, layers, **kwargs)
        if pretrained
          url = MODEL_URLS.fetch(arch) { raise ArgumentError, "No pretrained weights available for #{arch}" }
          state_dict = Torch::Hub.load_state_dict_from_url(url)
          model.load_state_dict(state_dict)
        end
        model
      end
    end
  end
end
|
@@ -0,0 +1,93 @@
|
|
1
|
+
module TorchVision
  module Models
    # VGG network: a configurable stack of 3x3 convolutions and 2x2 max pools ("features").
    # This is followed by adaptive average pooling to 7x7 and a three-layer fully connected classifier.
    class VGG < Torch::NN::Module
      # Pretrained weight locations, keyed by architecture name.
      MODEL_URLS = {
        "vgg11" => "https://download.pytorch.org/models/vgg11-bbd30ac9.pth",
        "vgg13" => "https://download.pytorch.org/models/vgg13-c768596a.pth",
        "vgg16" => "https://download.pytorch.org/models/vgg16-397923af.pth",
        "vgg19" => "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
        "vgg11_bn" => "https://download.pytorch.org/models/vgg11_bn-6002323d.pth",
        "vgg13_bn" => "https://download.pytorch.org/models/vgg13_bn-abd245e5.pth",
        "vgg16_bn" => "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth",
        "vgg19_bn" => "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth"
      }.freeze

      # Layer configurations: integers are conv output channels, "M" is a 2x2 max pool.
      CFGS = {
        "A" => [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"].freeze,
        "B" => [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"].freeze,
        "D" => [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M", 512, 512, 512, "M"].freeze,
        "E" => [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M", 512, 512, 512, 512, "M", 512, 512, 512, 512, "M"].freeze
      }.freeze

      # @param features [Torch::NN::Module] convolutional feature extractor (see make_layers)
      # @param num_classes [Integer] output size of the final linear layer
      # @param init_weights [Boolean] apply the custom initialization in _initialize_weights
      def initialize(features, num_classes: 1000, init_weights: true)
        super()
        @features = features
        @avgpool = Torch::NN::AdaptiveAvgPool2d.new([7, 7])
        @classifier = Torch::NN::Sequential.new(
          Torch::NN::Linear.new(512 * 7 * 7, 4096),
          Torch::NN::ReLU.new(inplace: true),
          Torch::NN::Dropout.new,
          Torch::NN::Linear.new(4096, 4096),
          Torch::NN::ReLU.new(inplace: true),
          Torch::NN::Dropout.new,
          Torch::NN::Linear.new(4096, num_classes)
        )
        _initialize_weights if init_weights
      end

      # features -> 7x7 average pool -> flatten -> classifier.
      def forward(x)
        x = @features.call(x)
        x = @avgpool.call(x)
        x = Torch.flatten(x, 1)
        @classifier.call(x)
      end

      # Initializes each submodule by type, as listed below.
      # - Conv2d: Kaiming-normal weights (fan_out, relu) and zero bias.
      # - BatchNorm2d: weight 1 and bias 0.
      # - Linear: N(0, 0.01) weights and zero bias.
      def _initialize_weights
        modules.each do |m|
          case m
          when Torch::NN::Conv2d
            Torch::NN::Init.kaiming_normal!(m.weight, mode: "fan_out", nonlinearity: "relu")
            Torch::NN::Init.constant!(m.bias, 0) if m.bias
          when Torch::NN::BatchNorm2d
            Torch::NN::Init.constant!(m.weight, 1)
            Torch::NN::Init.constant!(m.bias, 0)
          when Torch::NN::Linear
            Torch::NN::Init.normal!(m.weight, mean: 0, std: 0.01)
            Torch::NN::Init.constant!(m.bias, 0)
          end
        end
      end

      # Builds a VGG model and optionally loads pretrained weights.
      #
      # @param arch [String] key into MODEL_URLS (used only when `pretrained`)
      # @param cfg [String] key into CFGS ("A", "B", "D" or "E")
      # @param batch_norm [Boolean] insert BatchNorm2d after each convolution
      # @param kwargs [Hash] forwarded to VGG.new; init_weights is forced off when `pretrained`,
      #   since the weights are overwritten anyway
      # @raise [KeyError] if `cfg` is not a known configuration
      # @raise [ArgumentError] if `pretrained` is set and no weights are known for `arch`
      def self.make_model(arch, cfg, batch_norm, pretrained: false, **kwargs)
        kwargs[:init_weights] = false if pretrained
        model = VGG.new(make_layers(CFGS.fetch(cfg), batch_norm), **kwargs)
        if pretrained
          url = MODEL_URLS.fetch(arch) { raise ArgumentError, "No pretrained weights available for #{arch}" }
          state_dict = Torch::Hub.load_state_dict_from_url(url)
          model.load_state_dict(state_dict)
        end
        model
      end

      # Turns a configuration list into a Sequential feature extractor with 3 input channels.
      #
      # @param cfg [Array<Integer, String>] conv channel counts and "M" max-pool markers
      # @param batch_norm [Boolean] insert BatchNorm2d between each conv and its ReLU
      # @return [Torch::NN::Sequential]
      def self.make_layers(cfg, batch_norm)
        layers = []
        in_channels = 3
        cfg.each do |v|
          if v == "M"
            layers << Torch::NN::MaxPool2d.new(2, stride: 2)
          else
            # Modules are appended in construction order: conv, optional BN, ReLU.
            layers << Torch::NN::Conv2d.new(in_channels, v, 3, padding: 1)
            layers << Torch::NN::BatchNorm2d.new(v) if batch_norm
            layers << Torch::NN::ReLU.new(inplace: true)
            in_channels = v
          end
        end
        Torch::NN::Sequential.new(*layers)
      end
    end
  end
end
|