torchvision 0.1.1 → 0.2.2

Sign up to get free protection for your applications and to get access to all the features.
Files changed (47) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +35 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +133 -7
  5. data/lib/torchvision.rb +39 -0
  6. data/lib/torchvision/datasets/cifar10.rb +117 -0
  7. data/lib/torchvision/datasets/cifar100.rb +41 -0
  8. data/lib/torchvision/datasets/dataset_folder.rb +91 -0
  9. data/lib/torchvision/datasets/fashion_mnist.rb +30 -0
  10. data/lib/torchvision/datasets/image_folder.rb +12 -0
  11. data/lib/torchvision/datasets/kmnist.rb +30 -0
  12. data/lib/torchvision/datasets/mnist.rb +47 -75
  13. data/lib/torchvision/datasets/vision_dataset.rb +67 -0
  14. data/lib/torchvision/models/alexnet.rb +42 -0
  15. data/lib/torchvision/models/basic_block.rb +46 -0
  16. data/lib/torchvision/models/bottleneck.rb +47 -0
  17. data/lib/torchvision/models/resnet.rb +129 -0
  18. data/lib/torchvision/models/resnet101.rb +9 -0
  19. data/lib/torchvision/models/resnet152.rb +9 -0
  20. data/lib/torchvision/models/resnet18.rb +9 -0
  21. data/lib/torchvision/models/resnet34.rb +9 -0
  22. data/lib/torchvision/models/resnet50.rb +9 -0
  23. data/lib/torchvision/models/resnext101_32x8d.rb +11 -0
  24. data/lib/torchvision/models/resnext50_32x4d.rb +11 -0
  25. data/lib/torchvision/models/vgg.rb +93 -0
  26. data/lib/torchvision/models/vgg11.rb +9 -0
  27. data/lib/torchvision/models/vgg11_bn.rb +9 -0
  28. data/lib/torchvision/models/vgg13.rb +9 -0
  29. data/lib/torchvision/models/vgg13_bn.rb +9 -0
  30. data/lib/torchvision/models/vgg16.rb +9 -0
  31. data/lib/torchvision/models/vgg16_bn.rb +9 -0
  32. data/lib/torchvision/models/vgg19.rb +9 -0
  33. data/lib/torchvision/models/vgg19_bn.rb +9 -0
  34. data/lib/torchvision/models/wide_resnet101_2.rb +10 -0
  35. data/lib/torchvision/models/wide_resnet50_2.rb +10 -0
  36. data/lib/torchvision/transforms/center_crop.rb +13 -0
  37. data/lib/torchvision/transforms/compose.rb +2 -2
  38. data/lib/torchvision/transforms/functional.rb +142 -8
  39. data/lib/torchvision/transforms/normalize.rb +2 -2
  40. data/lib/torchvision/transforms/random_horizontal_flip.rb +18 -0
  41. data/lib/torchvision/transforms/random_resized_crop.rb +70 -0
  42. data/lib/torchvision/transforms/random_vertical_flip.rb +18 -0
  43. data/lib/torchvision/transforms/resize.rb +13 -0
  44. data/lib/torchvision/transforms/to_tensor.rb +2 -2
  45. data/lib/torchvision/utils.rb +118 -0
  46. data/lib/torchvision/version.rb +1 -1
  47. metadata +51 -44
@@ -0,0 +1,46 @@
module TorchVision
  module Models
    # Residual block made of two 3x3 convolutions, each followed by a norm layer,
    # with a shortcut connection added before the final ReLU.
    # Output channel count equals +planes+ (see .expansion, which is 1).
    class BasicBlock < Torch::NN::Module
      # @param inplanes [Integer] number of input channels
      # @param planes [Integer] number of output channels of both convolutions
      # @param stride [Integer] stride of the first convolution
      # @param downsample [Torch::NN::Module, nil] applied to the input to build the shortcut when present
      # @param groups [Integer] only 1 is supported
      # @param base_width [Integer] only 64 is supported
      # @param dilation [Integer] values greater than 1 are not supported
      # @param norm_layer [Class, nil] normalization layer class; Torch::NN::BatchNorm2d when nil
      # @raise [ArgumentError] when groups != 1 or base_width != 64
      # @raise [NotImplementedError] when dilation > 1
      def initialize(inplanes, planes, stride: 1, downsample: nil, groups: 1, base_width: 64, dilation: 1, norm_layer: nil)
        super()
        norm = norm_layer || Torch::NN::BatchNorm2d
        raise ArgumentError, "BasicBlock only supports groups=1 and base_width=64" if groups != 1 || base_width != 64
        raise NotImplementedError, "Dilation > 1 not supported in BasicBlock" if dilation > 1

        # Options shared by both 3x3 convolutions.
        conv_options = { padding: 1, groups: 1, bias: false, dilation: 1 }

        # When stride != 1, both conv1 and the downsample shortcut reduce the spatial size.
        @conv1 = Torch::NN::Conv2d.new(inplanes, planes, 3, stride: stride, **conv_options)
        @bn1 = norm.new(planes)
        @relu = Torch::NN::ReLU.new(inplace: true)
        @conv2 = Torch::NN::Conv2d.new(planes, planes, 3, stride: 1, **conv_options)
        @bn2 = norm.new(planes)
        @downsample = downsample
        @stride = stride
      end

      # conv1 -> bn1 -> relu -> conv2 -> bn2, plus shortcut, then relu.
      def forward(x)
        out = @relu.call(@bn1.call(@conv1.call(x)))
        out = @bn2.call(@conv2.call(out))
        shortcut = @downsample ? @downsample.call(x) : x
        @relu.call(out + shortcut)
      end

      # Ratio of output channels to +planes+.
      def self.expansion
        1
      end
    end
  end
end
@@ -0,0 +1,47 @@
module TorchVision
  module Models
    # Residual block of three convolutions (1x1 -> 3x3 -> 1x1), each followed by a
    # norm layer, with a shortcut added before the final ReLU.
    # Output channel count is +planes * expansion+ (expansion is 4).
    class Bottleneck < Torch::NN::Module
      # @param inplanes [Integer] number of input channels
      # @param planes [Integer] base channel count; the block outputs planes * expansion channels
      # @param stride [Integer] stride of the 3x3 convolution
      # @param downsample [Torch::NN::Module, nil] applied to the input to build the shortcut when present
      # @param groups [Integer] groups of the 3x3 convolution
      # @param base_width [Integer] width per group; inner width is (planes * base_width / 64.0).to_i * groups
      # @param dilation [Integer] dilation and padding of the 3x3 convolution
      # @param norm_layer [Class, nil] normalization layer class; Torch::NN::BatchNorm2d when nil
      def initialize(inplanes, planes, stride: 1, downsample: nil, groups: 1, base_width: 64, dilation: 1, norm_layer: nil)
        super()
        norm = norm_layer || Torch::NN::BatchNorm2d
        inner = (planes * (base_width / 64.0)).to_i * groups
        outer = planes * self.class.expansion

        # When stride != 1, both conv2 and the downsample shortcut reduce the spatial size.
        @conv1 = Torch::NN::Conv2d.new(inplanes, inner, 1, stride: 1, bias: false)
        @bn1 = norm.new(inner)
        @conv2 = Torch::NN::Conv2d.new(inner, inner, 3, stride: stride, padding: dilation, groups: groups, bias: false, dilation: dilation)
        @bn2 = norm.new(inner)
        @conv3 = Torch::NN::Conv2d.new(inner, outer, 1, stride: 1, bias: false)
        @bn3 = norm.new(outer)
        @relu = Torch::NN::ReLU.new(inplace: true)
        @downsample = downsample
        @stride = stride
      end

      # (conv1, bn1, relu) -> (conv2, bn2, relu) -> conv3 -> bn3, plus shortcut, then relu.
      def forward(x)
        out = x
        [[@conv1, @bn1], [@conv2, @bn2]].each do |conv, bn|
          out = @relu.call(bn.call(conv.call(out)))
        end
        out = @bn3.call(@conv3.call(out))
        shortcut = @downsample ? @downsample.call(x) : x
        @relu.call(out + shortcut)
      end

      # Ratio of output channels to +planes+.
      def self.expansion
        4
      end
    end
  end
end
@@ -0,0 +1,129 @@
module TorchVision
  module Models
    # ResNet image classifier, ported from PyTorch torchvision. The same class
    # builds the ResNeXt variants via the +groups+ / +width_per_group+ options.
    class ResNet < Torch::NN::Module
      # Pretrained weight URLs keyed by architecture name; consulted by make_model.
      MODEL_URLS = {
        "resnet18" => "https://download.pytorch.org/models/resnet18-5c106cde.pth",
        "resnet34" => "https://download.pytorch.org/models/resnet34-333f7ec4.pth",
        "resnet50" => "https://download.pytorch.org/models/resnet50-19c8e357.pth",
        "resnet101" => "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth",
        "resnet152" => "https://download.pytorch.org/models/resnet152-b121ed2d.pth",
        "resnext50_32x4d" => "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
        "resnext101_32x8d" => "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
        "wide_resnet50_2" => "https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth",
        "wide_resnet101_2" => "https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth"
      }.freeze

      # @param block [Class] residual block class (e.g. BasicBlock or Bottleneck); must respond to .expansion
      # @param layers [Array<Integer>] number of blocks in each of the four stages
      # @param num_classes_positional [Integer, nil] original positional way of passing num_classes
      # @param num_classes [Integer, nil] output size of the final Linear layer; 1000 when neither form is given
      # @param zero_init_residual [Boolean] zero the weight of the last norm layer in every residual block
      # @param groups [Integer] convolution groups passed to every block
      # @param width_per_group [Integer] passed to every block as base_width
      # @param replace_stride_with_dilation [Array<Boolean>, nil] per stage 2..4, use dilation instead of stride
      # @param norm_layer [Class, nil] normalization layer class; Torch::NN::BatchNorm2d when nil
      # @raise [ArgumentError] when num_classes is given both positionally and as a keyword with different
      #   values, or replace_stride_with_dilation does not have exactly 3 elements
      def initialize(block, layers, num_classes_positional = nil, num_classes: nil, zero_init_residual: false,
                     groups: 1, width_per_group: 64, replace_stride_with_dilation: nil, norm_layer: nil)
        super()
        # The model factories forward options only as keywords (**kwargs), so num_classes
        # must be accepted as a keyword; the positional form is kept for existing callers.
        if num_classes_positional && num_classes && num_classes_positional != num_classes
          raise ArgumentError,
                "num_classes given both positionally (#{num_classes_positional}) and as a keyword (#{num_classes})"
        end
        num_classes ||= num_classes_positional || 1000

        norm_layer ||= Torch::NN::BatchNorm2d
        @norm_layer = norm_layer

        @inplanes = 64
        @dilation = 1
        # each element indicates if we should replace the 2x2 stride of
        # stages 2, 3 and 4 with a dilated convolution instead
        replace_stride_with_dilation ||= [false, false, false]
        if replace_stride_with_dilation.length != 3
          raise ArgumentError, "replace_stride_with_dilation should be nil or a 3-element tuple, got #{replace_stride_with_dilation}"
        end
        @groups = groups
        @base_width = width_per_group
        @conv1 = Torch::NN::Conv2d.new(3, @inplanes, 7, stride: 2, padding: 3, bias: false)
        @bn1 = norm_layer.new(@inplanes)
        @relu = Torch::NN::ReLU.new(inplace: true)
        @maxpool = Torch::NN::MaxPool2d.new(3, stride: 2, padding: 1)
        @layer1 = _make_layer(block, 64, layers[0])
        @layer2 = _make_layer(block, 128, layers[1], stride: 2, dilate: replace_stride_with_dilation[0])
        @layer3 = _make_layer(block, 256, layers[2], stride: 2, dilate: replace_stride_with_dilation[1])
        @layer4 = _make_layer(block, 512, layers[3], stride: 2, dilate: replace_stride_with_dilation[2])
        @avgpool = Torch::NN::AdaptiveAvgPool2d.new([1, 1])
        @fc = Torch::NN::Linear.new(512 * block.expansion, num_classes)

        _initialize_weights(zero_init_residual)
      end

      # Builds one stage: a first block (which may stride/downsample) followed by
      # blocks - 1 blocks that keep the channel count. Updates @inplanes and
      # @dilation for the next stage.
      # @return [Torch::NN::Sequential]
      def _make_layer(block, planes, blocks, stride: 1, dilate: false)
        norm_layer = @norm_layer
        downsample = nil
        previous_dilation = @dilation
        if dilate
          @dilation *= stride
          stride = 1
        end
        # A projection shortcut is needed whenever the spatial size or channel count changes.
        if stride != 1 || @inplanes != planes * block.expansion
          downsample = Torch::NN::Sequential.new(
            Torch::NN::Conv2d.new(@inplanes, planes * block.expansion, 1, stride: stride, bias: false),
            norm_layer.new(planes * block.expansion)
          )
        end

        layers = []
        layers << block.new(@inplanes, planes, stride: stride, downsample: downsample, groups: @groups, base_width: @base_width, dilation: previous_dilation, norm_layer: norm_layer)
        @inplanes = planes * block.expansion
        (blocks - 1).times do
          layers << block.new(@inplanes, planes, groups: @groups, base_width: @base_width, dilation: @dilation, norm_layer: norm_layer)
        end

        Torch::NN::Sequential.new(*layers)
      end

      # Stem (conv1, bn1, relu, maxpool) -> four stages -> global average pool -> flatten -> fc.
      def _forward_impl(x)
        x = @conv1.call(x)
        x = @bn1.call(x)
        x = @relu.call(x)
        x = @maxpool.call(x)

        x = @layer1.call(x)
        x = @layer2.call(x)
        x = @layer3.call(x)
        x = @layer4.call(x)

        x = @avgpool.call(x)
        x = Torch.flatten(x, 1)
        x = @fc.call(x)

        x
      end

      def forward(x)
        _forward_impl(x)
      end

      # Builds a ResNet and optionally loads pretrained weights for +arch+.
      # @param arch [String] key into MODEL_URLS (only consulted when pretrained)
      # @param block [Class] residual block class
      # @param layers [Array<Integer>] number of blocks per stage
      # @param pretrained [Boolean] download and load the weights for +arch+
      # @param kwargs [Hash] forwarded to ResNet#initialize (e.g. num_classes:, groups:)
      # @raise [ArgumentError] when pretrained and +arch+ has no entry in MODEL_URLS
      # @return [ResNet]
      def self.make_model(arch, block, layers, pretrained: false, **kwargs)
        # Resolve the URL before building the model so an unknown arch fails fast and clearly.
        url = MODEL_URLS.fetch(arch) { raise ArgumentError, "No pretrained weights available for #{arch.inspect}" } if pretrained
        model = ResNet.new(block, layers, **kwargs)
        model.load_state_dict(Torch::Hub.load_state_dict_from_url(url)) if pretrained
        model
      end

      private

      # Kaiming-normal init for convolutions and unit weight / zero bias for norm layers.
      # Optionally zero-initializes the last norm layer in each residual branch, so every
      # residual block starts out behaving like an identity.
      # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
      def _initialize_weights(zero_init_residual)
        modules.each do |m|
          case m
          when Torch::NN::Conv2d
            Torch::NN::Init.kaiming_normal!(m.weight, mode: "fan_out", nonlinearity: "relu")
          when Torch::NN::BatchNorm2d, Torch::NN::GroupNorm
            Torch::NN::Init.constant!(m.weight, 1)
            Torch::NN::Init.constant!(m.bias, 0)
          end
        end

        return unless zero_init_residual

        modules.each do |m|
          case m
          when Bottleneck
            Torch::NN::Init.constant!(m.bn3.weight, 0)
          when BasicBlock
            Torch::NN::Init.constant!(m.bn2.weight, 0)
          end
        end
      end
    end
  end
end
@@ -0,0 +1,9 @@
module TorchVision
  module Models
    # Factory for ResNet-101: Bottleneck blocks with [3, 4, 23, 3] blocks per stage.
    module ResNet101
      class << self
        # Builds the model; keyword options (e.g. pretrained:) are forwarded to
        # ResNet.make_model under the "resnet101" architecture key.
        def new(**options)
          stage_depths = [3, 4, 23, 3]
          ResNet.make_model("resnet101", Bottleneck, stage_depths, **options)
        end
      end
    end
  end
end
@@ -0,0 +1,9 @@
module TorchVision
  module Models
    # Factory for ResNet-152: Bottleneck blocks with [3, 8, 36, 3] blocks per stage.
    module ResNet152
      class << self
        # Builds the model; keyword options (e.g. pretrained:) are forwarded to
        # ResNet.make_model under the "resnet152" architecture key.
        def new(**options)
          stage_depths = [3, 8, 36, 3]
          ResNet.make_model("resnet152", Bottleneck, stage_depths, **options)
        end
      end
    end
  end
end
@@ -0,0 +1,9 @@
module TorchVision
  module Models
    # Factory for ResNet-18: BasicBlock blocks with [2, 2, 2, 2] blocks per stage.
    module ResNet18
      class << self
        # Builds the model; keyword options (e.g. pretrained:) are forwarded to
        # ResNet.make_model under the "resnet18" architecture key.
        def new(**options)
          stage_depths = [2, 2, 2, 2]
          ResNet.make_model("resnet18", BasicBlock, stage_depths, **options)
        end
      end
    end
  end
end
@@ -0,0 +1,9 @@
module TorchVision
  module Models
    # Factory for ResNet-34: BasicBlock blocks with [3, 4, 6, 3] blocks per stage.
    module ResNet34
      class << self
        # Builds the model; keyword options (e.g. pretrained:) are forwarded to
        # ResNet.make_model under the "resnet34" architecture key.
        def new(**options)
          stage_depths = [3, 4, 6, 3]
          ResNet.make_model("resnet34", BasicBlock, stage_depths, **options)
        end
      end
    end
  end
end
@@ -0,0 +1,9 @@
module TorchVision
  module Models
    # Factory for ResNet-50: Bottleneck blocks with [3, 4, 6, 3] blocks per stage.
    module ResNet50
      class << self
        # Builds the model; keyword options (e.g. pretrained:) are forwarded to
        # ResNet.make_model under the "resnet50" architecture key.
        def new(**options)
          stage_depths = [3, 4, 6, 3]
          ResNet.make_model("resnet50", Bottleneck, stage_depths, **options)
        end
      end
    end
  end
end
@@ -0,0 +1,11 @@
module TorchVision
  module Models
    # Factory for ResNeXt-101 32x8d: Bottleneck blocks with [3, 4, 23, 3] blocks
    # per stage, 32 groups and a width of 8 per group.
    module ResNext101_32x8d
      class << self
        # Builds the model. groups: and width_per_group: are always forced to 32 and 8,
        # overriding any caller-supplied values; other options are forwarded unchanged.
        def new(**options)
          options = options.merge(groups: 32, width_per_group: 8)
          ResNet.make_model("resnext101_32x8d", Bottleneck, [3, 4, 23, 3], **options)
        end
      end
    end
  end
end
@@ -0,0 +1,11 @@
module TorchVision
  module Models
    # Factory for ResNeXt-50 32x4d: Bottleneck blocks with [3, 4, 6, 3] blocks
    # per stage, 32 groups and a width of 4 per group.
    module ResNext50_32x4d
      class << self
        # Builds the model. groups: and width_per_group: are always forced to 32 and 4,
        # overriding any caller-supplied values; other options are forwarded unchanged.
        def new(**options)
          options = options.merge(groups: 32, width_per_group: 4)
          ResNet.make_model("resnext50_32x4d", Bottleneck, [3, 4, 6, 3], **options)
        end
      end
    end
  end
end
@@ -0,0 +1,93 @@
module TorchVision
  module Models
    # VGG image classifier, ported from PyTorch torchvision: a convolutional
    # feature extractor followed by a 7x7 adaptive average pool and a
    # three-layer fully connected classifier with dropout.
    class VGG < Torch::NN::Module
      # Pretrained weight URLs keyed by architecture name; consulted by make_model.
      MODEL_URLS = {
        "vgg11" => "https://download.pytorch.org/models/vgg11-bbd30ac9.pth",
        "vgg13" => "https://download.pytorch.org/models/vgg13-c768596a.pth",
        "vgg16" => "https://download.pytorch.org/models/vgg16-397923af.pth",
        "vgg19" => "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
        "vgg11_bn" => "https://download.pytorch.org/models/vgg11_bn-6002323d.pth",
        "vgg13_bn" => "https://download.pytorch.org/models/vgg13_bn-abd245e5.pth",
        "vgg16_bn" => "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth",
        "vgg19_bn" => "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth"
      }.freeze

      # @param features [Torch::NN::Module] feature extractor (see make_layers); its output
      #   is expected to have 512 channels, as the classifier's first Linear layer assumes
      # @param num_classes [Integer] output size of the final Linear layer
      # @param init_weights [Boolean] run _initialize_weights after construction
      def initialize(features, num_classes: 1000, init_weights: true)
        super()
        @features = features
        @avgpool = Torch::NN::AdaptiveAvgPool2d.new([7, 7])
        @classifier = Torch::NN::Sequential.new(
          Torch::NN::Linear.new(512 * 7 * 7, 4096),
          Torch::NN::ReLU.new(inplace: true),
          Torch::NN::Dropout.new,
          Torch::NN::Linear.new(4096, 4096),
          Torch::NN::ReLU.new(inplace: true),
          Torch::NN::Dropout.new,
          Torch::NN::Linear.new(4096, num_classes)
        )
        _initialize_weights if init_weights
      end

      # features -> avgpool -> flatten from dim 1 -> classifier.
      def forward(x)
        x = @features.call(x)
        x = @avgpool.call(x)
        x = Torch.flatten(x, 1)
        x = @classifier.call(x)
        x
      end

      # Kaiming-normal init for convolutions (zero bias), unit weight / zero bias
      # for BatchNorm2d, and N(0, 0.01) weight / zero bias for Linear layers.
      def _initialize_weights
        modules.each do |m|
          case m
          when Torch::NN::Conv2d
            Torch::NN::Init.kaiming_normal!(m.weight, mode: "fan_out", nonlinearity: "relu")
            Torch::NN::Init.constant!(m.bias, 0) if m.bias
          when Torch::NN::BatchNorm2d
            Torch::NN::Init.constant!(m.weight, 1)
            Torch::NN::Init.constant!(m.bias, 0)
          when Torch::NN::Linear
            Torch::NN::Init.normal!(m.weight, mean: 0, std: 0.01)
            Torch::NN::Init.constant!(m.bias, 0)
          end
        end
      end

      # Feature-extractor layouts: an Integer is a 3x3 convolution with that many
      # output channels, "M" is a 2x2 max pool with stride 2.
      CFGS = {
        "A" => [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
        "B" => [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
        "D" => [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M", 512, 512, 512, "M"],
        "E" => [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M", 512, 512, 512, 512, "M", 512, 512, 512, 512, "M"],
      }.freeze

      # Builds a VGG model and optionally loads pretrained weights for +arch+.
      # @param arch [String] key into MODEL_URLS (only consulted when pretrained)
      # @param cfg [String] key into CFGS
      # @param batch_norm [Boolean] insert BatchNorm2d after every convolution
      # @param pretrained [Boolean] download and load weights; also disables init_weights
      # @param kwargs [Hash] forwarded to VGG#initialize (e.g. num_classes:)
      # @raise [ArgumentError] for an unknown +cfg+, or an unknown +arch+ when pretrained
      # @return [VGG]
      def self.make_model(arch, cfg, batch_norm, pretrained: false, **kwargs)
        layout = CFGS.fetch(cfg) { raise ArgumentError, "Unknown VGG configuration: #{cfg.inspect}" }
        if pretrained
          # Resolve the URL before building the model so an unknown arch fails fast and clearly.
          url = MODEL_URLS.fetch(arch) { raise ArgumentError, "No pretrained weights available for #{arch.inspect}" }
          # Random init is wasted work when the weights are about to be overwritten.
          kwargs[:init_weights] = false
        end
        model = VGG.new(make_layers(layout, batch_norm), **kwargs)
        model.load_state_dict(Torch::Hub.load_state_dict_from_url(url)) if pretrained
        model
      end

      # Turns a CFGS layout into a Sequential feature extractor over 3 input channels.
      # Each convolution is followed by BatchNorm2d (when +batch_norm+) and a ReLU.
      # @return [Torch::NN::Sequential]
      def self.make_layers(cfg, batch_norm)
        layers = []
        in_channels = 3
        cfg.each do |v|
          if v == "M"
            layers << Torch::NN::MaxPool2d.new(2, stride: 2)
          else
            layers << Torch::NN::Conv2d.new(in_channels, v, 3, padding: 1)
            layers << Torch::NN::BatchNorm2d.new(v) if batch_norm
            layers << Torch::NN::ReLU.new(inplace: true)
            in_channels = v
          end
        end
        Torch::NN::Sequential.new(*layers)
      end
    end
  end
end
@@ -0,0 +1,9 @@
module TorchVision
  module Models
    # Factory for VGG-11 (layout "A") without batch normalization.
    module VGG11
      class << self
        # Builds the model; keyword options (e.g. pretrained:, num_classes:) are
        # forwarded to VGG.make_model under the "vgg11" architecture key.
        def new(**options)
          with_batch_norm = false
          VGG.make_model("vgg11", "A", with_batch_norm, **options)
        end
      end
    end
  end
end
@@ -0,0 +1,9 @@
module TorchVision
  module Models
    # Factory for VGG-11 (layout "A") with batch normalization.
    module VGG11BN
      class << self
        # Builds the model; keyword options (e.g. pretrained:, num_classes:) are
        # forwarded to VGG.make_model under the "vgg11_bn" architecture key.
        def new(**options)
          with_batch_norm = true
          VGG.make_model("vgg11_bn", "A", with_batch_norm, **options)
        end
      end
    end
  end
end
@@ -0,0 +1,9 @@
module TorchVision
  module Models
    # Factory for VGG-13 (layout "B") without batch normalization.
    module VGG13
      class << self
        # Builds the model; keyword options (e.g. pretrained:, num_classes:) are
        # forwarded to VGG.make_model under the "vgg13" architecture key.
        def new(**options)
          with_batch_norm = false
          VGG.make_model("vgg13", "B", with_batch_norm, **options)
        end
      end
    end
  end
end
@@ -0,0 +1,9 @@
module TorchVision
  module Models
    # Factory for VGG-13 (layout "B") with batch normalization.
    module VGG13BN
      class << self
        # Builds the model; keyword options (e.g. pretrained:, num_classes:) are
        # forwarded to VGG.make_model under the "vgg13_bn" architecture key.
        def new(**options)
          with_batch_norm = true
          VGG.make_model("vgg13_bn", "B", with_batch_norm, **options)
        end
      end
    end
  end
end