torchvision 0.1.0 → 0.2.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +35 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +133 -5
  5. data/lib/torchvision.rb +40 -1
  6. data/lib/torchvision/datasets/cifar10.rb +117 -0
  7. data/lib/torchvision/datasets/cifar100.rb +41 -0
  8. data/lib/torchvision/datasets/dataset_folder.rb +91 -0
  9. data/lib/torchvision/datasets/fashion_mnist.rb +30 -0
  10. data/lib/torchvision/datasets/image_folder.rb +12 -0
  11. data/lib/torchvision/datasets/kmnist.rb +30 -0
  12. data/lib/torchvision/datasets/mnist.rb +47 -76
  13. data/lib/torchvision/datasets/vision_dataset.rb +67 -0
  14. data/lib/torchvision/models/alexnet.rb +42 -0
  15. data/lib/torchvision/models/basic_block.rb +46 -0
  16. data/lib/torchvision/models/bottleneck.rb +47 -0
  17. data/lib/torchvision/models/resnet.rb +129 -0
  18. data/lib/torchvision/models/resnet101.rb +9 -0
  19. data/lib/torchvision/models/resnet152.rb +9 -0
  20. data/lib/torchvision/models/resnet18.rb +9 -0
  21. data/lib/torchvision/models/resnet34.rb +9 -0
  22. data/lib/torchvision/models/resnet50.rb +9 -0
  23. data/lib/torchvision/models/resnext101_32x8d.rb +11 -0
  24. data/lib/torchvision/models/resnext50_32x4d.rb +11 -0
  25. data/lib/torchvision/models/vgg.rb +93 -0
  26. data/lib/torchvision/models/vgg11.rb +9 -0
  27. data/lib/torchvision/models/vgg11_bn.rb +9 -0
  28. data/lib/torchvision/models/vgg13.rb +9 -0
  29. data/lib/torchvision/models/vgg13_bn.rb +9 -0
  30. data/lib/torchvision/models/vgg16.rb +9 -0
  31. data/lib/torchvision/models/vgg16_bn.rb +9 -0
  32. data/lib/torchvision/models/vgg19.rb +9 -0
  33. data/lib/torchvision/models/vgg19_bn.rb +9 -0
  34. data/lib/torchvision/models/wide_resnet101_2.rb +10 -0
  35. data/lib/torchvision/models/wide_resnet50_2.rb +10 -0
  36. data/lib/torchvision/transforms/center_crop.rb +13 -0
  37. data/lib/torchvision/transforms/compose.rb +2 -2
  38. data/lib/torchvision/transforms/functional.rb +142 -7
  39. data/lib/torchvision/transforms/normalize.rb +2 -2
  40. data/lib/torchvision/transforms/random_horizontal_flip.rb +18 -0
  41. data/lib/torchvision/transforms/random_resized_crop.rb +70 -0
  42. data/lib/torchvision/transforms/random_vertical_flip.rb +18 -0
  43. data/lib/torchvision/transforms/resize.rb +13 -0
  44. data/lib/torchvision/transforms/to_tensor.rb +2 -2
  45. data/lib/torchvision/utils.rb +120 -0
  46. data/lib/torchvision/version.rb +1 -1
  47. metadata +50 -57
@@ -0,0 +1,46 @@
module TorchVision
  module Models
    # Residual block built from two 3x3 convolutions, each followed by a
    # normalization layer. The block input (optionally transformed by
    # +downsample+) is added to the result before the final ReLU.
    class BasicBlock < Torch::NN::Module
      class << self
        # Factor by which +planes+ is multiplied to get the block's output channels.
        def expansion
          1
        end
      end

      # @param inplanes [Integer] input channels of the first convolution
      # @param planes [Integer] output channels of both convolutions
      # @param stride [Integer] stride of the first convolution
      # @param downsample [#call, nil] optional transform applied to the input
      #   before it is added to the convolution branch
      # @param groups [Integer] must be 1
      # @param base_width [Integer] must be 64
      # @param dilation [Integer] must not exceed 1
      # @param norm_layer [Class, nil] normalization layer class; defaults to
      #   Torch::NN::BatchNorm2d
      # @raise [ArgumentError] if +groups+ is not 1 or +base_width+ is not 64
      # @raise [NotImplementedError] if +dilation+ is greater than 1
      def initialize(inplanes, planes, stride: 1, downsample: nil, groups: 1, base_width: 64, dilation: 1, norm_layer: nil)
        super()

        unless groups == 1 && base_width == 64
          raise ArgumentError, "BasicBlock only supports groups=1 and base_width=64"
        end
        raise NotImplementedError, "Dilation > 1 not supported in BasicBlock" if dilation > 1

        norm = norm_layer || Torch::NN::BatchNorm2d
        conv_options = {padding: 1, groups: 1, bias: false, dilation: 1}

        # Both conv1 and the downsample branch reduce the input when stride != 1
        @conv1 = Torch::NN::Conv2d.new(inplanes, planes, 3, stride: stride, **conv_options)
        @bn1 = norm.new(planes)
        @relu = Torch::NN::ReLU.new(inplace: true)
        @conv2 = Torch::NN::Conv2d.new(planes, planes, 3, stride: 1, **conv_options)
        @bn2 = norm.new(planes)
        @downsample = downsample
        @stride = stride
      end

      # Runs the block on +x+ and returns the activated sum of the convolution
      # branch and the shortcut branch.
      def forward(x)
        branch = @relu.call(@bn1.call(@conv1.call(x)))
        branch = @bn2.call(@conv2.call(branch))

        # Without a downsample module the shortcut is the untouched input
        shortcut = @downsample ? @downsample.call(x) : x

        @relu.call(branch + shortcut)
      end
    end
  end
end
@@ -0,0 +1,47 @@
module TorchVision
  module Models
    # Residual block built from a 1x1, a 3x3 and a 1x1 convolution, each followed
    # by a normalization layer. The block input (optionally transformed by
    # +downsample+) is added to the result before the final ReLU.
    class Bottleneck < Torch::NN::Module
      class << self
        # Factor by which +planes+ is multiplied to get the block's output channels.
        def expansion
          4
        end
      end

      # @param inplanes [Integer] input channels of the first convolution
      # @param planes [Integer] base channel count; the block outputs
      #   +planes * expansion+ channels
      # @param stride [Integer] stride of the 3x3 convolution
      # @param downsample [#call, nil] optional transform applied to the input
      #   before it is added to the convolution branch
      # @param groups [Integer] groups of the 3x3 convolution; also scales the
      #   intermediate width
      # @param base_width [Integer] scales the intermediate width relative to 64
      # @param dilation [Integer] dilation (and padding) of the 3x3 convolution
      # @param norm_layer [Class, nil] normalization layer class; defaults to
      #   Torch::NN::BatchNorm2d
      def initialize(inplanes, planes, stride: 1, downsample: nil, groups: 1, base_width: 64, dilation: 1, norm_layer: nil)
        super()

        norm = norm_layer || Torch::NN::BatchNorm2d
        scale = base_width / 64.0
        width = (planes * scale).to_i * groups
        out_planes = planes * self.class.expansion

        # Both conv2 and the downsample branch reduce the input when stride != 1
        @conv1 = Torch::NN::Conv2d.new(inplanes, width, 1, stride: 1, bias: false)
        @bn1 = norm.new(width)
        @conv2 = Torch::NN::Conv2d.new(width, width, 3, stride: stride, padding: dilation, groups: groups, bias: false, dilation: dilation)
        @bn2 = norm.new(width)
        @conv3 = Torch::NN::Conv2d.new(width, out_planes, 1, stride: 1, bias: false)
        @bn3 = norm.new(out_planes)
        @relu = Torch::NN::ReLU.new(inplace: true)
        @downsample = downsample
        @stride = stride
      end

      # Runs the block on +x+ and returns the activated sum of the convolution
      # branch and the shortcut branch.
      def forward(x)
        branch = @relu.call(@bn1.call(@conv1.call(x)))
        branch = @relu.call(@bn2.call(@conv2.call(branch)))
        branch = @bn3.call(@conv3.call(branch))

        # Without a downsample module the shortcut is the untouched input
        shortcut = @downsample ? @downsample.call(x) : x

        @relu.call(branch + shortcut)
      end
    end
  end
end
@@ -0,0 +1,129 @@
module TorchVision
  module Models
    # ResNet-style classifier: a 7x7 stem convolution with normalization, ReLU
    # and max pooling, four stages of residual blocks, adaptive average pooling
    # and a final linear layer.
    class ResNet < Torch::NN::Module
      # Checkpoint URLs keyed by architecture name; used by .make_model when
      # +pretrained+ is requested.
      MODEL_URLS = {
        "resnet18" => "https://download.pytorch.org/models/resnet18-5c106cde.pth",
        "resnet34" => "https://download.pytorch.org/models/resnet34-333f7ec4.pth",
        "resnet50" => "https://download.pytorch.org/models/resnet50-19c8e357.pth",
        "resnet101" => "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth",
        "resnet152" => "https://download.pytorch.org/models/resnet152-b121ed2d.pth",
        "resnext50_32x4d" => "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
        "resnext101_32x8d" => "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
        "wide_resnet50_2" => "https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth",
        "wide_resnet101_2" => "https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth"
      }

      # @param block [Class] residual block class; must respond to +expansion+
      #   and accept the keyword arguments passed by #_make_layer
      # @param layers [Array<Integer>] number of blocks in each of the four stages
      # @param num_classes [Integer] output size of the final linear layer
      #   (positional argument)
      # @param zero_init_residual [Boolean] when true, zero the weight of the
      #   last normalization layer of every Bottleneck / BasicBlock
      # @param groups [Integer] forwarded to every block as +groups+
      # @param width_per_group [Integer] forwarded to every block as +base_width+
      # @param replace_stride_with_dilation [Array<Boolean>, nil] three flags,
      #   one for each of stages 2-4, selecting dilation instead of stride
      # @param norm_layer [Class, nil] normalization layer class; defaults to
      #   Torch::NN::BatchNorm2d
      # @raise [ArgumentError] if +replace_stride_with_dilation+ does not have
      #   exactly three elements
      def initialize(block, layers, num_classes=1000, zero_init_residual: false,
        groups: 1, width_per_group: 64, replace_stride_with_dilation: nil, norm_layer: nil)

        super()
        norm_layer ||= Torch::NN::BatchNorm2d
        @norm_layer = norm_layer

        # @inplanes and @dilation are running state updated by #_make_layer
        @inplanes = 64
        @dilation = 1
        if replace_stride_with_dilation.nil?
          # each element in the tuple indicates if we should replace
          # the 2x2 stride with a dilated convolution instead
          replace_stride_with_dilation = [false, false, false]
        end
        if replace_stride_with_dilation.length != 3
          raise ArgumentError, "replace_stride_with_dilation should be nil or a 3-element tuple, got #{replace_stride_with_dilation}"
        end
        @groups = groups
        @base_width = width_per_group
        @conv1 = Torch::NN::Conv2d.new(3, @inplanes, 7, stride: 2, padding: 3, bias: false)
        @bn1 = norm_layer.new(@inplanes)
        @relu = Torch::NN::ReLU.new(inplace: true)
        @maxpool = Torch::NN::MaxPool2d.new(3, stride: 2, padding: 1)
        @layer1 = _make_layer(block, 64, layers[0])
        @layer2 = _make_layer(block, 128, layers[1], stride: 2, dilate: replace_stride_with_dilation[0])
        @layer3 = _make_layer(block, 256, layers[2], stride: 2, dilate: replace_stride_with_dilation[1])
        @layer4 = _make_layer(block, 512, layers[3], stride: 2, dilate: replace_stride_with_dilation[2])
        @avgpool = Torch::NN::AdaptiveAvgPool2d.new([1, 1])
        @fc = Torch::NN::Linear.new(512 * block.expansion, num_classes)

        # Initialize convolution weights with Kaiming-normal and set every
        # normalization layer to weight 1 / bias 0
        modules.each do |m|
          case m
          when Torch::NN::Conv2d
            Torch::NN::Init.kaiming_normal!(m.weight, mode: "fan_out", nonlinearity: "relu")
          when Torch::NN::BatchNorm2d, Torch::NN::GroupNorm
            Torch::NN::Init.constant!(m.weight, 1)
            Torch::NN::Init.constant!(m.bias, 0)
          end
        end

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual
          modules.each do |m|
            case m
            when Bottleneck
              Torch::NN::Init.constant!(m.bn3.weight, 0)
            when BasicBlock
              Torch::NN::Init.constant!(m.bn2.weight, 0)
            end
          end
        end
      end

      # Builds one stage of +blocks+ residual blocks as a Torch::NN::Sequential.
      #
      # Updates @inplanes to the stage's output channels and, when +dilate+ is
      # true, multiplies @dilation by +stride+ and uses stride 1 instead.
      #
      # @param block [Class] residual block class
      # @param planes [Integer] base channel count of the stage
      # @param blocks [Integer] number of blocks in the stage
      # @param stride [Integer] stride of the first block
      # @param dilate [Boolean] replace the stride with dilation
      def _make_layer(block, planes, blocks, stride: 1, dilate: false)
        norm_layer = @norm_layer
        downsample = nil
        previous_dilation = @dilation
        if dilate
          @dilation *= stride
          stride = 1
        end
        # The shortcut needs a projection whenever the first block changes the
        # spatial size or the channel count
        if stride != 1 || @inplanes != planes * block.expansion
          downsample = Torch::NN::Sequential.new(
            Torch::NN::Conv2d.new(@inplanes, planes * block.expansion, 1, stride: stride, bias: false),
            norm_layer.new(planes * block.expansion)
          )
        end

        layers = []
        # Only the first block receives the stride, the downsample branch and
        # the dilation that was in effect before this stage
        layers << block.new(@inplanes, planes, stride: stride, downsample: downsample, groups: @groups, base_width: @base_width, dilation: previous_dilation, norm_layer: norm_layer)
        @inplanes = planes * block.expansion
        (blocks - 1).times do
          layers << block.new(@inplanes, planes, groups: @groups, base_width: @base_width, dilation: @dilation, norm_layer: norm_layer)
        end

        Torch::NN::Sequential.new(*layers)
      end

      # Applies the stem, the four stages, pooling, flattening from dimension 1
      # and the linear layer to +x+, returning the linear layer's output.
      def _forward_impl(x)
        x = @conv1.call(x)
        x = @bn1.call(x)
        x = @relu.call(x)
        x = @maxpool.call(x)

        x = @layer1.call(x)
        x = @layer2.call(x)
        x = @layer3.call(x)
        x = @layer4.call(x)

        x = @avgpool.call(x)
        x = Torch.flatten(x, 1)
        x = @fc.call(x)

        x
      end

      # Forward pass; delegates to #_forward_impl.
      def forward(x)
        _forward_impl(x)
      end

      # Builds a ResNet and optionally loads the checkpoint registered for +arch+.
      #
      # @param arch [String] key into MODEL_URLS
      # @param block [Class] residual block class
      # @param layers [Array<Integer>] number of blocks in each of the four stages
      # @param pretrained [Boolean] load the checkpoint from MODEL_URLS[arch]
      # @param num_classes [Integer] output size of the final linear layer
      # @param kwargs remaining keyword arguments for ResNet#initialize
      # @raise [ArgumentError] if +pretrained+ is set and +arch+ has no URL
      # @return [ResNet]
      def self.make_model(arch, block, layers, pretrained: false, num_classes: 1000, **kwargs)
        # Resolve the URL before building the model so an unknown arch fails
        # fast with a clear message instead of passing nil to the downloader
        url = MODEL_URLS.fetch(arch) { raise ArgumentError, "No pretrained weights available for #{arch}" } if pretrained

        # ResNet#initialize takes num_classes positionally; forwarding it inside
        # **kwargs would raise "unknown keyword: :num_classes"
        model = ResNet.new(block, layers, num_classes, **kwargs)
        if pretrained
          state_dict = Torch::Hub.load_state_dict_from_url(url)
          model.load_state_dict(state_dict)
        end
        model
      end
    end
  end
end
@@ -0,0 +1,9 @@
module TorchVision
  module Models
    # Factory for the "resnet101" architecture: Bottleneck blocks arranged as
    # [3, 4, 23, 3].
    module ResNet101
      class << self
        # Forwards every keyword argument to ResNet.make_model and returns its result.
        def new(**options)
          ResNet.make_model("resnet101", Bottleneck, [3, 4, 23, 3], **options)
        end
      end
    end
  end
end
@@ -0,0 +1,9 @@
module TorchVision
  module Models
    # Factory for the "resnet152" architecture: Bottleneck blocks arranged as
    # [3, 8, 36, 3].
    module ResNet152
      class << self
        # Forwards every keyword argument to ResNet.make_model and returns its result.
        def new(**options)
          ResNet.make_model("resnet152", Bottleneck, [3, 8, 36, 3], **options)
        end
      end
    end
  end
end
@@ -0,0 +1,9 @@
module TorchVision
  module Models
    # Factory for the "resnet18" architecture: BasicBlock blocks arranged as
    # [2, 2, 2, 2].
    module ResNet18
      class << self
        # Forwards every keyword argument to ResNet.make_model and returns its result.
        def new(**options)
          ResNet.make_model("resnet18", BasicBlock, [2, 2, 2, 2], **options)
        end
      end
    end
  end
end
@@ -0,0 +1,9 @@
module TorchVision
  module Models
    # Factory for the "resnet34" architecture: BasicBlock blocks arranged as
    # [3, 4, 6, 3].
    module ResNet34
      class << self
        # Forwards every keyword argument to ResNet.make_model and returns its result.
        def new(**options)
          ResNet.make_model("resnet34", BasicBlock, [3, 4, 6, 3], **options)
        end
      end
    end
  end
end
@@ -0,0 +1,9 @@
module TorchVision
  module Models
    # Factory for the "resnet50" architecture: Bottleneck blocks arranged as
    # [3, 4, 6, 3].
    module ResNet50
      class << self
        # Forwards every keyword argument to ResNet.make_model and returns its result.
        def new(**options)
          ResNet.make_model("resnet50", Bottleneck, [3, 4, 6, 3], **options)
        end
      end
    end
  end
end
@@ -0,0 +1,11 @@
module TorchVision
  module Models
    # Factory for the "resnext101_32x8d" architecture: Bottleneck blocks arranged
    # as [3, 4, 23, 3] with groups fixed to 32 and width_per_group fixed to 8.
    module ResNext101_32x8d
      class << self
        # Forwards the keyword arguments to ResNet.make_model. Any caller-supplied
        # :groups or :width_per_group is overridden by the fixed values.
        def new(**options)
          forced = {groups: 32, width_per_group: 8}
          ResNet.make_model("resnext101_32x8d", Bottleneck, [3, 4, 23, 3], **options.merge(forced))
        end
      end
    end
  end
end
@@ -0,0 +1,11 @@
module TorchVision
  module Models
    # Factory for the "resnext50_32x4d" architecture: Bottleneck blocks arranged
    # as [3, 4, 6, 3] with groups fixed to 32 and width_per_group fixed to 4.
    module ResNext50_32x4d
      class << self
        # Forwards the keyword arguments to ResNet.make_model. Any caller-supplied
        # :groups or :width_per_group is overridden by the fixed values.
        def new(**options)
          forced = {groups: 32, width_per_group: 4}
          ResNet.make_model("resnext50_32x4d", Bottleneck, [3, 4, 6, 3], **options.merge(forced))
        end
      end
    end
  end
end
@@ -0,0 +1,93 @@
module TorchVision
  module Models
    # VGG-style classifier: a configurable convolutional feature extractor,
    # adaptive average pooling to 7x7, and a three-layer fully connected head.
    class VGG < Torch::NN::Module
      # Checkpoint URLs keyed by architecture name; used by .make_model when
      # +pretrained+ is requested.
      MODEL_URLS = {
        "vgg11" => "https://download.pytorch.org/models/vgg11-bbd30ac9.pth",
        "vgg13" => "https://download.pytorch.org/models/vgg13-c768596a.pth",
        "vgg16" => "https://download.pytorch.org/models/vgg16-397923af.pth",
        "vgg19" => "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
        "vgg11_bn" => "https://download.pytorch.org/models/vgg11_bn-6002323d.pth",
        "vgg13_bn" => "https://download.pytorch.org/models/vgg13_bn-abd245e5.pth",
        "vgg16_bn" => "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth",
        "vgg19_bn" => "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth"
      }

      # Feature-extractor layouts keyed by configuration letter. An Integer adds
      # a 3x3 convolution with that many output channels; "M" adds a max pool.
      CFGS = {
        "A" => [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
        "B" => [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
        "D" => [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M", 512, 512, 512, "M"],
        "E" => [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M", 512, 512, 512, 512, "M", 512, 512, 512, 512, "M"],
      }

      # Builds a VGG for +arch+ from the +cfg+ layout and optionally loads the
      # checkpoint registered for +arch+.
      #
      # @param arch [String] key into MODEL_URLS
      # @param cfg [String] key into CFGS
      # @param batch_norm [Boolean] add BatchNorm2d after every convolution
      # @param pretrained [Boolean] load the checkpoint; also forces
      #   +init_weights+ to false
      # @param kwargs remaining keyword arguments for VGG#initialize
      # @return [VGG]
      def self.make_model(arch, cfg, batch_norm, pretrained: false, **kwargs)
        options = pretrained ? kwargs.merge(init_weights: false) : kwargs
        model = VGG.new(make_layers(CFGS[cfg], batch_norm), **options)
        return model unless pretrained

        model.load_state_dict(Torch::Hub.load_state_dict_from_url(MODEL_URLS[arch]))
        model
      end

      # Turns a layout array into a Torch::NN::Sequential feature extractor that
      # starts from 3 input channels.
      #
      # @param cfg [Array<Integer, String>] layout, see CFGS
      # @param batch_norm [Boolean] add BatchNorm2d after every convolution
      def self.make_layers(cfg, batch_norm)
        channels = 3
        stack = cfg.flat_map do |entry|
          next [Torch::NN::MaxPool2d.new(2, stride: 2)] if entry == "M"

          group = [Torch::NN::Conv2d.new(channels, entry, 3, padding: 1)]
          group << Torch::NN::BatchNorm2d.new(entry) if batch_norm
          group << Torch::NN::ReLU.new(inplace: true)
          # The next convolution consumes what this one produced
          channels = entry
          group
        end
        Torch::NN::Sequential.new(*stack)
      end

      # @param features [#call] feature extractor applied first in #forward
      # @param num_classes [Integer] output size of the last linear layer
      # @param init_weights [Boolean] run #_initialize_weights after construction
      def initialize(features, num_classes: 1000, init_weights: true)
        super()
        @features = features
        @avgpool = Torch::NN::AdaptiveAvgPool2d.new([7, 7])
        head = [
          Torch::NN::Linear.new(512 * 7 * 7, 4096),
          Torch::NN::ReLU.new(inplace: true),
          Torch::NN::Dropout.new,
          Torch::NN::Linear.new(4096, 4096),
          Torch::NN::ReLU.new(inplace: true),
          Torch::NN::Dropout.new,
          Torch::NN::Linear.new(4096, num_classes)
        ]
        @classifier = Torch::NN::Sequential.new(*head)
        _initialize_weights if init_weights
      end

      # Applies features, pooling, flattening from dimension 1 and the
      # classifier to +x+, returning the classifier's output.
      def forward(x)
        pooled = @avgpool.call(@features.call(x))
        @classifier.call(Torch.flatten(pooled, 1))
      end

      # Re-initializes every Conv2d (Kaiming-normal weight, zero bias when a
      # bias exists), BatchNorm2d (weight 1, bias 0) and Linear (normal weight
      # with std 0.01, zero bias) found in +modules+.
      def _initialize_weights
        init = Torch::NN::Init
        modules.each do |layer|
          if layer.is_a?(Torch::NN::Conv2d)
            init.kaiming_normal!(layer.weight, mode: "fan_out", nonlinearity: "relu")
            init.constant!(layer.bias, 0) if layer.bias
          elsif layer.is_a?(Torch::NN::BatchNorm2d)
            init.constant!(layer.weight, 1)
            init.constant!(layer.bias, 0)
          elsif layer.is_a?(Torch::NN::Linear)
            init.normal!(layer.weight, mean: 0, std: 0.01)
            init.constant!(layer.bias, 0)
          end
        end
      end
    end
  end
end
@@ -0,0 +1,9 @@
module TorchVision
  module Models
    # Factory for the "vgg11" architecture: layout "A" without batch normalization.
    module VGG11
      class << self
        # Forwards every keyword argument to VGG.make_model and returns its result.
        def new(**options)
          VGG.make_model("vgg11", "A", false, **options)
        end
      end
    end
  end
end
@@ -0,0 +1,9 @@
module TorchVision
  module Models
    # Factory for the "vgg11_bn" architecture: layout "A" with batch normalization.
    module VGG11BN
      class << self
        # Forwards every keyword argument to VGG.make_model and returns its result.
        def new(**options)
          VGG.make_model("vgg11_bn", "A", true, **options)
        end
      end
    end
  end
end
@@ -0,0 +1,9 @@
module TorchVision
  module Models
    # Factory for the "vgg13" architecture: layout "B" without batch normalization.
    module VGG13
      class << self
        # Forwards every keyword argument to VGG.make_model and returns its result.
        def new(**options)
          VGG.make_model("vgg13", "B", false, **options)
        end
      end
    end
  end
end
@@ -0,0 +1,9 @@
module TorchVision
  module Models
    # Factory for the "vgg13_bn" architecture: layout "B" with batch normalization.
    module VGG13BN
      class << self
        # Forwards every keyword argument to VGG.make_model and returns its result.
        def new(**options)
          VGG.make_model("vgg13_bn", "B", true, **options)
        end
      end
    end
  end
end