torchvision 0.1.2 → 0.1.3
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +8 -0
- data/README.md +46 -1
- data/lib/torchvision.rb +18 -0
- data/lib/torchvision/models/alexnet.rb +42 -0
- data/lib/torchvision/models/resnet.rb +22 -0
- data/lib/torchvision/models/resnet101.rb +9 -0
- data/lib/torchvision/models/resnet152.rb +9 -0
- data/lib/torchvision/models/resnet18.rb +2 -8
- data/lib/torchvision/models/resnet34.rb +9 -0
- data/lib/torchvision/models/resnet50.rb +9 -0
- data/lib/torchvision/models/resnext101_32x8d.rb +11 -0
- data/lib/torchvision/models/resnext50_32x4d.rb +11 -0
- data/lib/torchvision/models/vgg.rb +93 -0
- data/lib/torchvision/models/vgg11.rb +9 -0
- data/lib/torchvision/models/vgg11_bn.rb +9 -0
- data/lib/torchvision/models/vgg13.rb +9 -0
- data/lib/torchvision/models/vgg13_bn.rb +9 -0
- data/lib/torchvision/models/vgg16.rb +9 -0
- data/lib/torchvision/models/vgg16_bn.rb +9 -0
- data/lib/torchvision/models/vgg19.rb +9 -0
- data/lib/torchvision/models/vgg19_bn.rb +9 -0
- data/lib/torchvision/models/wide_resnet101_2.rb +10 -0
- data/lib/torchvision/models/wide_resnet50_2.rb +10 -0
- data/lib/torchvision/version.rb +1 -1
- metadata +22 -4
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 44c29605f12dddf8196432223f2137ef9b9bef490996b718f5a9bdbc13dbd33f
|
4
|
+
data.tar.gz: 8790200a0ed8f7a275f99327431dc8be99a7578c473ec03411f9411cd10c6c93
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 65816ef10f524781553327f9634bb8818a7efa6fb072e81468949d937b5430dcddc7f6c8cf3b305c977dc4b1279d23b38fb034a4c277eecc1adcc0f2b8c99e3e
|
7
|
+
data.tar.gz: 01f485a78cd5a19c9a0dc987f4e46931b033ffb3a84af703b602b0e90e8c221c99b5e6318c86b18361562cc3ca582b3a0822de151a5edf63964002317a684bfa
|
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,11 @@
|
|
1
|
+
## 0.1.3 (2020-06-29)
|
2
|
+
|
3
|
+
- Added AlexNet model
|
4
|
+
- Added ResNet34, ResNet50, ResNet101, and ResNet152 models
|
5
|
+
- Added ResNeXt model
|
6
|
+
- Added VGG11, VGG13, VGG16, and VGG19 models (with and without batch normalization)
|
7
|
+
- Added Wide ResNet model
|
8
|
+
|
1
9
|
## 0.1.2 (2020-04-29)
|
2
10
|
|
3
11
|
- Added CIFAR10, CIFAR100, FashionMNIST, and KMNIST datasets
|
data/README.md
CHANGED
@@ -45,8 +45,53 @@ TorchVision::Transforms::Compose.new([
|
|
45
45
|
|
46
46
|
## Models
|
47
47
|
|
48
|
+
- [AlexNet](#alexnet)
|
49
|
+
- [ResNet](#resnet)
|
50
|
+
- [ResNeXt](#resnext)
|
51
|
+
- [VGG](#vgg)
|
52
|
+
- [Wide ResNet](#wide-resnet)
|
53
|
+
|
54
|
+
### AlexNet
|
55
|
+
|
56
|
+
```ruby
|
57
|
+
TorchVision::Models::AlexNet.new
|
58
|
+
```
|
59
|
+
|
60
|
+
### ResNet
|
61
|
+
|
62
|
+
```ruby
|
63
|
+
TorchVision::Models::ResNet18.new
|
64
|
+
TorchVision::Models::ResNet34.new
|
65
|
+
TorchVision::Models::ResNet50.new
|
66
|
+
TorchVision::Models::ResNet101.new
|
67
|
+
TorchVision::Models::ResNet152.new
|
68
|
+
```
|
69
|
+
|
70
|
+
### ResNeXt
|
71
|
+
|
72
|
+
```ruby
|
73
|
+
TorchVision::Models::ResNext50_32x4d.new
|
74
|
+
TorchVision::Models::ResNext101_32x8d.new
|
75
|
+
```
|
76
|
+
|
77
|
+
### VGG
|
78
|
+
|
79
|
+
```ruby
|
80
|
+
TorchVision::Models::VGG11.new
|
81
|
+
TorchVision::Models::VGG11BN.new
|
82
|
+
TorchVision::Models::VGG13.new
|
83
|
+
TorchVision::Models::VGG13BN.new
|
84
|
+
TorchVision::Models::VGG16.new
|
85
|
+
TorchVision::Models::VGG16BN.new
|
86
|
+
TorchVision::Models::VGG19.new
|
87
|
+
TorchVision::Models::VGG19BN.new
|
88
|
+
```
|
89
|
+
|
90
|
+
### Wide ResNet
|
91
|
+
|
48
92
|
```ruby
|
49
|
-
TorchVision::Models::
|
93
|
+
TorchVision::Models::WideResNet50_2.new
|
94
|
+
TorchVision::Models::WideResNet101_2.new
|
50
95
|
```
|
51
96
|
|
52
97
|
## Disclaimer
|
data/lib/torchvision.rb
CHANGED
@@ -21,10 +21,28 @@ require "torchvision/datasets/fashion_mnist"
|
|
21
21
|
require "torchvision/datasets/kmnist"
|
22
22
|
|
23
23
|
# models
|
24
|
+
require "torchvision/models/alexnet"
|
24
25
|
require "torchvision/models/basic_block"
|
25
26
|
require "torchvision/models/bottleneck"
|
26
27
|
require "torchvision/models/resnet"
|
27
28
|
require "torchvision/models/resnet18"
|
29
|
+
require "torchvision/models/resnet34"
|
30
|
+
require "torchvision/models/resnet50"
|
31
|
+
require "torchvision/models/resnet101"
|
32
|
+
require "torchvision/models/resnet152"
|
33
|
+
require "torchvision/models/resnext50_32x4d"
|
34
|
+
require "torchvision/models/resnext101_32x8d"
|
35
|
+
require "torchvision/models/vgg"
|
36
|
+
require "torchvision/models/vgg11"
|
37
|
+
require "torchvision/models/vgg11_bn"
|
38
|
+
require "torchvision/models/vgg13"
|
39
|
+
require "torchvision/models/vgg13_bn"
|
40
|
+
require "torchvision/models/vgg16"
|
41
|
+
require "torchvision/models/vgg16_bn"
|
42
|
+
require "torchvision/models/vgg19"
|
43
|
+
require "torchvision/models/vgg19_bn"
|
44
|
+
require "torchvision/models/wide_resnet50_2"
|
45
|
+
require "torchvision/models/wide_resnet101_2"
|
28
46
|
|
29
47
|
# transforms
|
30
48
|
require "torchvision/transforms/compose"
|
@@ -0,0 +1,42 @@
|
|
1
|
+
module TorchVision
  module Models
    # AlexNet image classifier.
    #
    # The network has three parts:
    # - a feature stack of five convolutions, each followed by a ReLU,
    #   with three max-pool layers between them;
    # - adaptive average pooling down to 6x6;
    # - a classifier of three linear layers, with dropout before each
    #   of the first two.
    class AlexNet < Torch::NN::Module
      # @param num_classes [Integer] output size of the last linear layer
      def initialize(num_classes: 1000)
        super()

        # Layers are created in the same order as before, so parameter
        # initialization order is unchanged.
        feature_layers = [
          Torch::NN::Conv2d.new(3, 64, 11, stride: 4, padding: 2),
          Torch::NN::ReLU.new(inplace: true),
          Torch::NN::MaxPool2d.new(3, stride: 2),
          Torch::NN::Conv2d.new(64, 192, 5, padding: 2),
          Torch::NN::ReLU.new(inplace: true),
          Torch::NN::MaxPool2d.new(3, stride: 2),
          Torch::NN::Conv2d.new(192, 384, 3, padding: 1),
          Torch::NN::ReLU.new(inplace: true),
          Torch::NN::Conv2d.new(384, 256, 3, padding: 1),
          Torch::NN::ReLU.new(inplace: true),
          Torch::NN::Conv2d.new(256, 256, 3, padding: 1),
          Torch::NN::ReLU.new(inplace: true),
          Torch::NN::MaxPool2d.new(3, stride: 2)
        ]
        @features = Torch::NN::Sequential.new(*feature_layers)

        # Fixed 6x6 output lets the first linear layer take 256 * 6 * 6 inputs.
        @avgpool = Torch::NN::AdaptiveAvgPool2d.new([6, 6])

        classifier_layers = [
          Torch::NN::Dropout.new,
          Torch::NN::Linear.new(256 * 6 * 6, 4096),
          Torch::NN::ReLU.new(inplace: true),
          Torch::NN::Dropout.new,
          Torch::NN::Linear.new(4096, 4096),
          Torch::NN::ReLU.new(inplace: true),
          Torch::NN::Linear.new(4096, num_classes)
        ]
        @classifier = Torch::NN::Sequential.new(*classifier_layers)
      end

      # Runs features -> avgpool -> flatten (from dim 1) -> classifier.
      def forward(x)
        pooled = @avgpool.call(@features.call(x))
        @classifier.call(Torch.flatten(pooled, 1))
      end
    end
  end
end
|
@@ -1,6 +1,18 @@
|
|
1
1
|
module TorchVision
|
2
2
|
module Models
|
3
3
|
class ResNet < Torch::NN::Module
|
4
|
+
MODEL_URLS = {
|
5
|
+
"resnet18" => "https://download.pytorch.org/models/resnet18-5c106cde.pth",
|
6
|
+
"resnet34" => "https://download.pytorch.org/models/resnet34-333f7ec4.pth",
|
7
|
+
"resnet50" => "https://download.pytorch.org/models/resnet50-19c8e357.pth",
|
8
|
+
"resnet101" => "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth",
|
9
|
+
"resnet152" => "https://download.pytorch.org/models/resnet152-b121ed2d.pth",
|
10
|
+
"resnext50_32x4d" => "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
|
11
|
+
"resnext101_32x8d" => "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
|
12
|
+
"wide_resnet50_2" => "https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth",
|
13
|
+
"wide_resnet101_2" => "https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth"
|
14
|
+
}
|
15
|
+
|
4
16
|
def initialize(block, layers, num_classes=1000, zero_init_residual: false,
|
5
17
|
groups: 1, width_per_group: 64, replace_stride_with_dilation: nil, norm_layer: nil)
|
6
18
|
|
@@ -102,6 +114,16 @@ module TorchVision
|
|
102
114
|
def forward(x)
|
103
115
|
_forward_impl(x)
|
104
116
|
end
|
117
|
+
|
118
|
+
def self.make_model(arch, block, layers, pretrained: false, **kwargs)
|
119
|
+
model = ResNet.new(block, layers, **kwargs)
|
120
|
+
if pretrained
|
121
|
+
url = MODEL_URLS[arch]
|
122
|
+
state_dict = Torch::Hub.load_state_dict_from_url(url)
|
123
|
+
model.load_state_dict(state_dict)
|
124
|
+
end
|
125
|
+
model
|
126
|
+
end
|
105
127
|
end
|
106
128
|
end
|
107
129
|
end
|
@@ -1,14 +1,8 @@
|
|
1
1
|
module TorchVision
  module Models
    # Factory for the 18-layer ResNet.
    #
    # Builds a ResNet from BasicBlock units with per-stage layer counts
    # [2, 2, 2, 2]. It is a module rather than a class; `new` is a factory
    # method.
    module ResNet18
      class << self
        # Builds the model through ResNet.make_model under the "resnet18"
        # architecture key.
        #
        # @param kwargs [Hash] forwarded unchanged to ResNet.make_model
        #   (e.g. pretrained: true to load the published weights)
        # @return [ResNet]
        def new(**kwargs)
          ResNet.make_model("resnet18", BasicBlock, [2, 2, 2, 2], **kwargs)
        end
      end
    end
  end
end
|
@@ -0,0 +1,93 @@
|
|
1
|
+
module TorchVision
  module Models
    # VGG image classifier.
    #
    # The network has three parts:
    # - a configurable convolutional feature extractor (see make_layers / CFGS);
    # - adaptive average pooling to 7x7;
    # - a classifier of three linear layers, with ReLU and dropout after
    #   each of the first two.
    class VGG < Torch::NN::Module
      # Download locations of pretrained weights, keyed by architecture name.
      MODEL_URLS = {
        "vgg11" => "https://download.pytorch.org/models/vgg11-bbd30ac9.pth",
        "vgg13" => "https://download.pytorch.org/models/vgg13-c768596a.pth",
        "vgg16" => "https://download.pytorch.org/models/vgg16-397923af.pth",
        "vgg19" => "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
        "vgg11_bn" => "https://download.pytorch.org/models/vgg11_bn-6002323d.pth",
        "vgg13_bn" => "https://download.pytorch.org/models/vgg13_bn-abd245e5.pth",
        "vgg16_bn" => "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth",
        "vgg19_bn" => "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth"
      }

      # @param features [Torch::NN::Module] convolutional feature extractor,
      #   normally built by VGG.make_layers
      # @param num_classes [Integer] output size of the last linear layer
      # @param init_weights [Boolean] apply _initialize_weights after building
      def initialize(features, num_classes: 1000, init_weights: true)
        super()
        @features = features
        @avgpool = Torch::NN::AdaptiveAvgPool2d.new([7, 7])
        # 512 * 7 * 7: every CFGS entry ends with 512 channels, pooled to 7x7.
        # NOTE(review): a custom `features` must also produce 512 channels.
        @classifier = Torch::NN::Sequential.new(
          Torch::NN::Linear.new(512 * 7 * 7, 4096),
          Torch::NN::ReLU.new(inplace: true),
          Torch::NN::Dropout.new,
          Torch::NN::Linear.new(4096, 4096),
          Torch::NN::ReLU.new(inplace: true),
          Torch::NN::Dropout.new,
          Torch::NN::Linear.new(4096, num_classes)
        )
        _initialize_weights if init_weights
      end

      # Runs features -> avgpool -> flatten (from dim 1) -> classifier.
      def forward(x)
        x = @features.call(x)
        x = @avgpool.call(x)
        x = Torch.flatten(x, 1)
        x = @classifier.call(x)
        x
      end

      # Re-initializes every submodule in place:
      # - Conv2d: Kaiming-normal weights (fan_out, relu) and zero bias
      #   when a bias is present;
      # - BatchNorm2d: weight 1, bias 0;
      # - Linear: weights from N(0, 0.01), bias 0.
      def _initialize_weights
        modules.each do |m|
          case m
          when Torch::NN::Conv2d
            Torch::NN::Init.kaiming_normal!(m.weight, mode: "fan_out", nonlinearity: "relu")
            Torch::NN::Init.constant!(m.bias, 0) if m.bias
          when Torch::NN::BatchNorm2d
            Torch::NN::Init.constant!(m.weight, 1)
            Torch::NN::Init.constant!(m.bias, 0)
          when Torch::NN::Linear
            Torch::NN::Init.normal!(m.weight, mean: 0, std: 0.01)
            Torch::NN::Init.constant!(m.bias, 0)
          end
        end
      end

      # Feature-extractor layouts. Integers are conv output channels and "M"
      # is a max pool. A/B/D/E contain 8/10/13/16 conv layers, i.e.
      # VGG11/13/16/19 once the three linear layers are counted.
      CFGS = {
        "A" => [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
        "B" => [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
        "D" => [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M", 512, 512, 512, "M"],
        "E" => [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M", 512, 512, 512, 512, "M", 512, 512, 512, 512, "M"],
      }

      # Builds a VGG model for a named configuration, optionally loading
      # pretrained weights.
      #
      # @param arch [String] MODEL_URLS key; only consulted when pretrained
      # @param cfg [String] CFGS key ("A", "B", "D" or "E")
      # @param batch_norm [Boolean] insert BatchNorm2d after each convolution
      # @param pretrained [Boolean] download and load the weights for arch
      # @param kwargs [Hash] forwarded to VGG.new (num_classes:, init_weights:)
      # @return [VGG]
      # @raise [KeyError] if cfg is unknown, or arch is unknown while pretrained
      def self.make_model(arch, cfg, batch_norm, pretrained: false, **kwargs)
        # Resolve both lookups before allocating the (large) model, so a bad
        # name fails immediately with a KeyError. Otherwise a nil surfaces later.
        layer_cfg = CFGS.fetch(cfg)
        url = MODEL_URLS.fetch(arch) if pretrained

        # The loaded state dict replaces the parameters, so skip random init.
        # kwargs is a fresh Hash per call, so mutating it is local.
        kwargs[:init_weights] = false if pretrained
        model = VGG.new(make_layers(layer_cfg, batch_norm), **kwargs)
        if pretrained
          state_dict = Torch::Hub.load_state_dict_from_url(url)
          model.load_state_dict(state_dict)
        end
        model
      end

      # Builds the feature extractor described by cfg.
      #
      # Each Integer v adds:
      # - a 3x3 Conv2d with padding 1 (spatial size is preserved);
      # - a BatchNorm2d when batch_norm is true;
      # - a ReLU.
      # Each "M" adds a 2x2 max pool with stride 2.
      #
      # @param cfg [Array<Integer, String>] a CFGS-style layout
      # @param batch_norm [Boolean]
      # @return [Torch::NN::Sequential]
      def self.make_layers(cfg, batch_norm)
        layers = []
        # 3 input channels (presumably RGB images).
        in_channels = 3
        cfg.each do |v|
          if v == "M"
            layers << Torch::NN::MaxPool2d.new(2, stride: 2)
          else
            # Same construction order as before: conv, optional BN, then ReLU.
            layers << Torch::NN::Conv2d.new(in_channels, v, 3, padding: 1)
            layers << Torch::NN::BatchNorm2d.new(v) if batch_norm
            layers << Torch::NN::ReLU.new(inplace: true)
            in_channels = v
          end
        end
        Torch::NN::Sequential.new(*layers)
      end
    end
  end
end
|
data/lib/torchvision/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torchvision
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.1.2
|
4
|
+
version: 0.1.3
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2020-
|
11
|
+
date: 2020-06-30 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-narray
|
@@ -30,14 +30,14 @@ dependencies:
|
|
30
30
|
requirements:
|
31
31
|
- - ">="
|
32
32
|
- !ruby/object:Gem::Version
|
33
|
-
version: 0.2.
|
33
|
+
version: 0.2.7
|
34
34
|
type: :runtime
|
35
35
|
prerelease: false
|
36
36
|
version_requirements: !ruby/object:Gem::Requirement
|
37
37
|
requirements:
|
38
38
|
- - ">="
|
39
39
|
- !ruby/object:Gem::Version
|
40
|
-
version: 0.2.
|
40
|
+
version: 0.2.7
|
41
41
|
- !ruby/object:Gem::Dependency
|
42
42
|
name: bundler
|
43
43
|
requirement: !ruby/object:Gem::Requirement
|
@@ -96,10 +96,28 @@ files:
|
|
96
96
|
- lib/torchvision/datasets/kmnist.rb
|
97
97
|
- lib/torchvision/datasets/mnist.rb
|
98
98
|
- lib/torchvision/datasets/vision_dataset.rb
|
99
|
+
- lib/torchvision/models/alexnet.rb
|
99
100
|
- lib/torchvision/models/basic_block.rb
|
100
101
|
- lib/torchvision/models/bottleneck.rb
|
101
102
|
- lib/torchvision/models/resnet.rb
|
103
|
+
- lib/torchvision/models/resnet101.rb
|
104
|
+
- lib/torchvision/models/resnet152.rb
|
102
105
|
- lib/torchvision/models/resnet18.rb
|
106
|
+
- lib/torchvision/models/resnet34.rb
|
107
|
+
- lib/torchvision/models/resnet50.rb
|
108
|
+
- lib/torchvision/models/resnext101_32x8d.rb
|
109
|
+
- lib/torchvision/models/resnext50_32x4d.rb
|
110
|
+
- lib/torchvision/models/vgg.rb
|
111
|
+
- lib/torchvision/models/vgg11.rb
|
112
|
+
- lib/torchvision/models/vgg11_bn.rb
|
113
|
+
- lib/torchvision/models/vgg13.rb
|
114
|
+
- lib/torchvision/models/vgg13_bn.rb
|
115
|
+
- lib/torchvision/models/vgg16.rb
|
116
|
+
- lib/torchvision/models/vgg16_bn.rb
|
117
|
+
- lib/torchvision/models/vgg19.rb
|
118
|
+
- lib/torchvision/models/vgg19_bn.rb
|
119
|
+
- lib/torchvision/models/wide_resnet101_2.rb
|
120
|
+
- lib/torchvision/models/wide_resnet50_2.rb
|
103
121
|
- lib/torchvision/transforms/compose.rb
|
104
122
|
- lib/torchvision/transforms/functional.rb
|
105
123
|
- lib/torchvision/transforms/normalize.rb
|