torchvision 0.1.2 → 0.1.3

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 107e429f990a063e57f6218ee6d4fbed3cff80aa6746f0794798d59e1c13b099
4
- data.tar.gz: 556fdc4c413803d415ea5575747ec2239f18b0dff9b39a37a4fd6366adec37a6
3
+ metadata.gz: 44c29605f12dddf8196432223f2137ef9b9bef490996b718f5a9bdbc13dbd33f
4
+ data.tar.gz: 8790200a0ed8f7a275f99327431dc8be99a7578c473ec03411f9411cd10c6c93
5
5
  SHA512:
6
- metadata.gz: 21b712578516c146888be30bed64a6da6339a42974f5c88fa685278caa9231e4bc4b75e250af3ad37d5450cd205891b31b281cf23fa362456c7ae00998c2736d
7
- data.tar.gz: ac86f13e8b5d6a400842ba37b1bf593139360a8df9a19ab4674e71a4327c821a6d5735774185dcbaed368ee735b00bf33d9faedb13d46cd168314da89d900c38
6
+ metadata.gz: 65816ef10f524781553327f9634bb8818a7efa6fb072e81468949d937b5430dcddc7f6c8cf3b305c977dc4b1279d23b38fb034a4c277eecc1adcc0f2b8c99e3e
7
+ data.tar.gz: 01f485a78cd5a19c9a0dc987f4e46931b033ffb3a84af703b602b0e90e8c221c99b5e6318c86b18361562cc3ca582b3a0822de151a5edf63964002317a684bfa
@@ -1,3 +1,11 @@
1
+ ## 0.1.3 (2020-06-29)
2
+
3
+ - Added AlexNet model
4
+ - Added ResNet34, ResNet50, ResNet101, and ResNet152 models
5
+ - Added ResNeXt model
6
+ - Added VGG11, VGG13, VGG16, and VGG19 models
7
+ - Added Wide ResNet model
8
+
1
9
  ## 0.1.2 (2020-04-29)
2
10
 
3
11
  - Added CIFAR10, CIFAR100, FashionMNIST, and KMNIST datasets
data/README.md CHANGED
@@ -45,8 +45,53 @@ TorchVision::Transforms::Compose.new([
45
45
 
46
46
  ## Models
47
47
 
48
+ - [AlexNet](#alexnet)
49
+ - [ResNet](#resnet)
50
+ - [ResNeXt](#resnext)
51
+ - [VGG](#vgg)
52
+ - [Wide ResNet](#wide-resnet)
53
+
54
+ ### AlexNet
55
+
56
+ ```ruby
57
+ TorchVision::Models::AlexNet.new
58
+ ```
59
+
60
+ ### ResNet
61
+
62
+ ```ruby
63
+ TorchVision::Models::ResNet18.new
64
+ TorchVision::Models::ResNet34.new
65
+ TorchVision::Models::ResNet50.new
66
+ TorchVision::Models::ResNet101.new
67
+ TorchVision::Models::ResNet152.new
68
+ ```
69
+
70
+ ### ResNeXt
71
+
72
+ ```ruby
73
+ TorchVision::Models::ResNext50_32x4d.new
74
+ TorchVision::Models::ResNext101_32x8d.new
75
+ ```
76
+
77
+ ### VGG
78
+
79
+ ```ruby
80
+ TorchVision::Models::VGG11.new
81
+ TorchVision::Models::VGG11BN.new
82
+ TorchVision::Models::VGG13.new
83
+ TorchVision::Models::VGG13BN.new
84
+ TorchVision::Models::VGG16.new
85
+ TorchVision::Models::VGG16BN.new
86
+ TorchVision::Models::VGG19.new
87
+ TorchVision::Models::VGG19BN.new
88
+ ```
89
+
90
+ ### Wide ResNet
91
+
48
92
  ```ruby
49
- TorchVision::Models::Resnet18.new
93
+ TorchVision::Models::WideResNet52_2.new
94
+ TorchVision::Models::WideResNet101_2.new
50
95
  ```
51
96
 
52
97
  ## Disclaimer
@@ -21,10 +21,28 @@ require "torchvision/datasets/fashion_mnist"
21
21
  require "torchvision/datasets/kmnist"
22
22
 
23
23
  # models
24
+ require "torchvision/models/alexnet"
24
25
  require "torchvision/models/basic_block"
25
26
  require "torchvision/models/bottleneck"
26
27
  require "torchvision/models/resnet"
27
28
  require "torchvision/models/resnet18"
29
+ require "torchvision/models/resnet34"
30
+ require "torchvision/models/resnet50"
31
+ require "torchvision/models/resnet101"
32
+ require "torchvision/models/resnet152"
33
+ require "torchvision/models/resnext50_32x4d"
34
+ require "torchvision/models/resnext101_32x8d"
35
+ require "torchvision/models/vgg"
36
+ require "torchvision/models/vgg11"
37
+ require "torchvision/models/vgg11_bn"
38
+ require "torchvision/models/vgg13"
39
+ require "torchvision/models/vgg13_bn"
40
+ require "torchvision/models/vgg16"
41
+ require "torchvision/models/vgg16_bn"
42
+ require "torchvision/models/vgg19"
43
+ require "torchvision/models/vgg19_bn"
44
+ require "torchvision/models/wide_resnet50_2"
45
+ require "torchvision/models/wide_resnet101_2"
28
46
 
29
47
  # transforms
30
48
  require "torchvision/transforms/compose"
@@ -0,0 +1,42 @@
1
module TorchVision
  module Models
    # AlexNet image classifier: five convolutional layers (with ReLU and
    # max pooling) feeding a three-layer fully connected head with dropout.
    class AlexNet < Torch::NN::Module
      # @param num_classes [Integer] output size of the final Linear layer
      def initialize(num_classes: 1000)
        super()

        # Factories for the parameter-free layers that recur throughout.
        # Arguments are evaluated left to right, so layers are still created
        # in the same order as a flat listing would create them.
        relu = -> { Torch::NN::ReLU.new(inplace: true) }
        pool = -> { Torch::NN::MaxPool2d.new(3, stride: 2) }

        @features = Torch::NN::Sequential.new(
          Torch::NN::Conv2d.new(3, 64, 11, stride: 4, padding: 2), relu.call, pool.call,
          Torch::NN::Conv2d.new(64, 192, 5, padding: 2), relu.call, pool.call,
          Torch::NN::Conv2d.new(192, 384, 3, padding: 1), relu.call,
          Torch::NN::Conv2d.new(384, 256, 3, padding: 1), relu.call,
          Torch::NN::Conv2d.new(256, 256, 3, padding: 1), relu.call, pool.call
        )

        # Pools to a fixed 6x6 map so the head's input width is 256 * 6 * 6.
        @avgpool = Torch::NN::AdaptiveAvgPool2d.new([6, 6])

        @classifier = Torch::NN::Sequential.new(
          Torch::NN::Dropout.new,
          Torch::NN::Linear.new(256 * 6 * 6, 4096), relu.call,
          Torch::NN::Dropout.new,
          Torch::NN::Linear.new(4096, 4096), relu.call,
          Torch::NN::Linear.new(4096, num_classes)
        )
      end

      # Feature extractor -> adaptive average pool -> flatten from dim 1
      # onward -> classifier head.
      def forward(x)
        pooled = @avgpool.call(@features.call(x))
        @classifier.call(Torch.flatten(pooled, 1))
      end
    end
  end
end
@@ -1,6 +1,18 @@
1
1
  module TorchVision
2
2
  module Models
3
3
  class ResNet < Torch::NN::Module
4
# Pretrained weight URLs, keyed by the +arch+ name passed to make_model.
# Frozen so this shared lookup table cannot be mutated at runtime.
MODEL_URLS = {
  "resnet18" => "https://download.pytorch.org/models/resnet18-5c106cde.pth",
  "resnet34" => "https://download.pytorch.org/models/resnet34-333f7ec4.pth",
  "resnet50" => "https://download.pytorch.org/models/resnet50-19c8e357.pth",
  "resnet101" => "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth",
  "resnet152" => "https://download.pytorch.org/models/resnet152-b121ed2d.pth",
  "resnext50_32x4d" => "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
  "resnext101_32x8d" => "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
  "wide_resnet50_2" => "https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth",
  "wide_resnet101_2" => "https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth"
}.freeze
15
+
4
16
  def initialize(block, layers, num_classes=1000, zero_init_residual: false,
5
17
  groups: 1, width_per_group: 64, replace_stride_with_dilation: nil, norm_layer: nil)
6
18
 
@@ -102,6 +114,16 @@ module TorchVision
102
114
  def forward(x)
103
115
  _forward_impl(x)
104
116
  end
117
+
118
# Builds a ResNet variant and optionally loads pretrained weights.
#
# +num_classes+ is accepted as a keyword and passed positionally because
# ResNet#initialize declares it as an optional positional argument;
# leaving it inside **kwargs would raise "unknown keyword: :num_classes".
#
# @param arch [String] key into MODEL_URLS (consulted only when +pretrained+)
# @param block [Class] residual block class (e.g. BasicBlock or Bottleneck)
# @param layers [Array<Integer>] layer counts passed through to ResNet#initialize
# @param pretrained [Boolean] download and load the weights for +arch+
# @param num_classes [Integer] size of the final classification layer
#   (default matches ResNet#initialize)
# @param kwargs [Hash] remaining ResNet#initialize keywords
#   (groups:, width_per_group:, zero_init_residual:, ...)
# @return [ResNet]
# @raise [ArgumentError] if +pretrained+ and +arch+ has no entry in MODEL_URLS
def self.make_model(arch, block, layers, pretrained: false, num_classes: 1000, **kwargs)
  # Resolve the URL before building the model so a bad +arch+ fails fast
  # instead of surfacing later as a download error on a nil URL.
  if pretrained
    url = MODEL_URLS.fetch(arch) { raise ArgumentError, "No pretrained weights for #{arch.inspect}" }
  end

  model = ResNet.new(block, layers, num_classes, **kwargs)
  model.load_state_dict(Torch::Hub.load_state_dict_from_url(url)) if pretrained
  model
end
105
127
  end
106
128
  end
107
129
  end
@@ -0,0 +1,9 @@
1
module TorchVision
  module Models
    # ResNet-101 factory: Bottleneck blocks with layer counts [3, 4, 23, 3].
    module ResNet101
      class << self
        # Builds the network via ResNet.make_model, forwarding every keyword
        # option (e.g. pretrained:) unchanged.
        def new(**kwargs)
          layer_counts = [3, 4, 23, 3]
          ResNet.make_model("resnet101", Bottleneck, layer_counts, **kwargs)
        end
      end
    end
  end
end
@@ -0,0 +1,9 @@
1
module TorchVision
  module Models
    # ResNet-152 factory: Bottleneck blocks with layer counts [3, 8, 36, 3].
    module ResNet152
      class << self
        # Builds the network via ResNet.make_model, forwarding every keyword
        # option (e.g. pretrained:) unchanged.
        def new(**kwargs)
          layer_counts = [3, 8, 36, 3]
          ResNet.make_model("resnet152", Bottleneck, layer_counts, **kwargs)
        end
      end
    end
  end
end
@@ -1,14 +1,8 @@
1
1
  module TorchVision
2
2
  module Models
3
3
  module ResNet18
4
- def self.new(pretrained: false, **kwargs)
5
- model = ResNet.new(BasicBlock, [2, 2, 2, 2], **kwargs)
6
- if pretrained
7
- url = "https://download.pytorch.org/models/resnet18-5c106cde.pth"
8
- state_dict = Torch::Hub.load_state_dict_from_url(url)
9
- model.load_state_dict(state_dict)
10
- end
11
- model
4
+ def self.new(**kwargs)
5
+ ResNet.make_model("resnet18", BasicBlock, [2, 2, 2, 2], **kwargs)
12
6
  end
13
7
  end
14
8
  end
@@ -0,0 +1,9 @@
1
module TorchVision
  module Models
    # ResNet-34 factory: BasicBlock blocks with layer counts [3, 4, 6, 3].
    module ResNet34
      class << self
        # Builds the network via ResNet.make_model, forwarding every keyword
        # option (e.g. pretrained:) unchanged.
        def new(**kwargs)
          layer_counts = [3, 4, 6, 3]
          ResNet.make_model("resnet34", BasicBlock, layer_counts, **kwargs)
        end
      end
    end
  end
end
@@ -0,0 +1,9 @@
1
module TorchVision
  module Models
    # ResNet-50 factory: Bottleneck blocks with layer counts [3, 4, 6, 3].
    module ResNet50
      class << self
        # Builds the network via ResNet.make_model, forwarding every keyword
        # option (e.g. pretrained:) unchanged.
        def new(**kwargs)
          layer_counts = [3, 4, 6, 3]
          ResNet.make_model("resnet50", Bottleneck, layer_counts, **kwargs)
        end
      end
    end
  end
end
@@ -0,0 +1,11 @@
1
module TorchVision
  module Models
    # ResNeXt-101 (32x8d) factory: Bottleneck blocks with layer counts
    # [3, 4, 23, 3], 32 groups and a per-group width of 8.
    module ResNext101_32x8d
      class << self
        # Builds the network via ResNet.make_model. The groups: and
        # width_per_group: values are fixed by this architecture and
        # override anything the caller passes; other options pass through.
        def new(**kwargs)
          options = kwargs.merge(groups: 32, width_per_group: 8)
          ResNet.make_model("resnext101_32x8d", Bottleneck, [3, 4, 23, 3], **options)
        end
      end
    end
  end
end
@@ -0,0 +1,11 @@
1
module TorchVision
  module Models
    # ResNeXt-50 (32x4d) factory: Bottleneck blocks with layer counts
    # [3, 4, 6, 3], 32 groups and a per-group width of 4.
    module ResNext50_32x4d
      class << self
        # Builds the network via ResNet.make_model. The groups: and
        # width_per_group: values are fixed by this architecture and
        # override anything the caller passes; other options pass through.
        def new(**kwargs)
          options = kwargs.merge(groups: 32, width_per_group: 4)
          ResNet.make_model("resnext50_32x4d", Bottleneck, [3, 4, 6, 3], **options)
        end
      end
    end
  end
end
@@ -0,0 +1,93 @@
1
module TorchVision
  module Models
    # VGG image classifier: a convolutional feature extractor built from one
    # of the CFGS layouts, adaptive average pooling to 7x7, and a three-layer
    # fully connected head with dropout.
    class VGG < Torch::NN::Module
      # Pretrained weight URLs, keyed by the +arch+ name given to make_model.
      MODEL_URLS = {
        "vgg11" => "https://download.pytorch.org/models/vgg11-bbd30ac9.pth",
        "vgg13" => "https://download.pytorch.org/models/vgg13-c768596a.pth",
        "vgg16" => "https://download.pytorch.org/models/vgg16-397923af.pth",
        "vgg19" => "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
        "vgg11_bn" => "https://download.pytorch.org/models/vgg11_bn-6002323d.pth",
        "vgg13_bn" => "https://download.pytorch.org/models/vgg13_bn-abd245e5.pth",
        "vgg16_bn" => "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth",
        "vgg19_bn" => "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth"
      }.freeze

      # Feature-extractor layouts, keyed by make_model's +cfg+ argument.
      # Integers are Conv2d output channels; "M" is a 2x2, stride-2 max pool.
      CFGS = {
        "A" => [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"].freeze,
        "B" => [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"].freeze,
        "D" => [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M", 512, 512, 512, "M"].freeze,
        "E" => [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M", 512, 512, 512, 512, "M", 512, 512, 512, 512, "M"].freeze
      }.freeze

      # @param features [Torch::NN::Module] feature extractor (normally from
      #   make_layers); assumed to emit 512 channels, since the head's first
      #   Linear layer expects 512 * 7 * 7 inputs
      # @param num_classes [Integer] output size of the final Linear layer
      # @param init_weights [Boolean] run _initialize_weights after construction
      def initialize(features, num_classes: 1000, init_weights: true)
        super()
        @features = features
        @avgpool = Torch::NN::AdaptiveAvgPool2d.new([7, 7])
        @classifier = Torch::NN::Sequential.new(
          Torch::NN::Linear.new(512 * 7 * 7, 4096),
          Torch::NN::ReLU.new(inplace: true),
          Torch::NN::Dropout.new,
          Torch::NN::Linear.new(4096, 4096),
          Torch::NN::ReLU.new(inplace: true),
          Torch::NN::Dropout.new,
          Torch::NN::Linear.new(4096, num_classes)
        )
        _initialize_weights if init_weights
      end

      # Features -> 7x7 adaptive average pool -> flatten from dim 1 onward
      # -> classifier head.
      def forward(x)
        x = @features.call(x)
        x = @avgpool.call(x)
        x = Torch.flatten(x, 1)
        @classifier.call(x)
      end

      # Re-initializes parameters by layer type: Kaiming-normal (fan_out,
      # relu) weights and zero bias for Conv2d, ones/zeros for BatchNorm2d,
      # N(0, 0.01) weights and zero bias for Linear.
      def _initialize_weights
        modules.each do |m|
          case m
          when Torch::NN::Conv2d
            Torch::NN::Init.kaiming_normal!(m.weight, mode: "fan_out", nonlinearity: "relu")
            Torch::NN::Init.constant!(m.bias, 0) if m.bias
          when Torch::NN::BatchNorm2d
            Torch::NN::Init.constant!(m.weight, 1)
            Torch::NN::Init.constant!(m.bias, 0)
          when Torch::NN::Linear
            Torch::NN::Init.normal!(m.weight, mean: 0, std: 0.01)
            Torch::NN::Init.constant!(m.bias, 0)
          end
        end
      end

      # Builds a VGG network and optionally loads pretrained weights.
      #
      # @param arch [String] key into MODEL_URLS (consulted only when +pretrained+)
      # @param cfg [String] key into CFGS
      # @param batch_norm [Boolean] insert BatchNorm2d after each convolution
      # @param pretrained [Boolean] download and load the weights for +arch+
      # @param kwargs [Hash] forwarded to VGG#initialize (num_classes:, init_weights:)
      # @return [VGG]
      # @raise [ArgumentError] for an unknown +cfg+, or an unknown +arch+
      #   when +pretrained+ is set
      def self.make_model(arch, cfg, batch_norm, pretrained: false, **kwargs)
        # Validate both lookups up front so a bad key fails fast, before any
        # layers are allocated, instead of surfacing as a nil error later.
        layer_cfg = CFGS.fetch(cfg) { raise ArgumentError, "Unknown VGG config: #{cfg.inspect}" }
        if pretrained
          url = MODEL_URLS.fetch(arch) { raise ArgumentError, "No pretrained weights for #{arch.inspect}" }
          # Random init is wasted work when every weight is about to be overwritten.
          kwargs[:init_weights] = false
        end

        model = VGG.new(make_layers(layer_cfg, batch_norm), **kwargs)
        model.load_state_dict(Torch::Hub.load_state_dict_from_url(url)) if pretrained
        model
      end

      # Expands a CFGS layout into a Sequential of Conv2d / (BatchNorm2d) /
      # ReLU / MaxPool2d layers, starting from 3 input channels.
      def self.make_layers(cfg, batch_norm)
        layers = []
        in_channels = 3
        cfg.each do |v|
          if v == "M"
            layers << Torch::NN::MaxPool2d.new(2, stride: 2)
          else
            layers << Torch::NN::Conv2d.new(in_channels, v, 3, padding: 1)
            layers << Torch::NN::BatchNorm2d.new(v) if batch_norm
            layers << Torch::NN::ReLU.new(inplace: true)
            in_channels = v
          end
        end
        Torch::NN::Sequential.new(*layers)
      end
    end
  end
end
@@ -0,0 +1,9 @@
1
module TorchVision
  module Models
    # VGG-11 factory (CFGS layout "A") without batch normalization.
    module VGG11
      class << self
        # Builds the network through VGG.make_model; keyword options such as
        # pretrained:, num_classes: and init_weights: pass straight through.
        def new(**kwargs)
          with_batch_norm = false
          VGG.make_model("vgg11", "A", with_batch_norm, **kwargs)
        end
      end
    end
  end
end
@@ -0,0 +1,9 @@
1
module TorchVision
  module Models
    # VGG-11 factory (CFGS layout "A") with batch normalization.
    module VGG11BN
      class << self
        # Builds the network through VGG.make_model; keyword options such as
        # pretrained:, num_classes: and init_weights: pass straight through.
        def new(**kwargs)
          with_batch_norm = true
          VGG.make_model("vgg11_bn", "A", with_batch_norm, **kwargs)
        end
      end
    end
  end
end
@@ -0,0 +1,9 @@
1
module TorchVision
  module Models
    # VGG-13 factory (CFGS layout "B") without batch normalization.
    module VGG13
      class << self
        # Builds the network through VGG.make_model; keyword options such as
        # pretrained:, num_classes: and init_weights: pass straight through.
        def new(**kwargs)
          with_batch_norm = false
          VGG.make_model("vgg13", "B", with_batch_norm, **kwargs)
        end
      end
    end
  end
end
@@ -0,0 +1,9 @@
1
module TorchVision
  module Models
    # VGG-13 factory (CFGS layout "B") with batch normalization.
    module VGG13BN
      class << self
        # Builds the network through VGG.make_model; keyword options such as
        # pretrained:, num_classes: and init_weights: pass straight through.
        def new(**kwargs)
          with_batch_norm = true
          VGG.make_model("vgg13_bn", "B", with_batch_norm, **kwargs)
        end
      end
    end
  end
end
@@ -0,0 +1,9 @@
1
module TorchVision
  module Models
    # VGG-16 factory (CFGS layout "D") without batch normalization.
    module VGG16
      class << self
        # Builds the network through VGG.make_model; keyword options such as
        # pretrained:, num_classes: and init_weights: pass straight through.
        def new(**kwargs)
          with_batch_norm = false
          VGG.make_model("vgg16", "D", with_batch_norm, **kwargs)
        end
      end
    end
  end
end
@@ -0,0 +1,9 @@
1
module TorchVision
  module Models
    # VGG-16 factory (CFGS layout "D") with batch normalization.
    module VGG16BN
      class << self
        # Builds the network through VGG.make_model; keyword options such as
        # pretrained:, num_classes: and init_weights: pass straight through.
        def new(**kwargs)
          with_batch_norm = true
          VGG.make_model("vgg16_bn", "D", with_batch_norm, **kwargs)
        end
      end
    end
  end
end
@@ -0,0 +1,9 @@
1
module TorchVision
  module Models
    # VGG-19 factory (CFGS layout "E") without batch normalization.
    module VGG19
      class << self
        # Builds the network through VGG.make_model; keyword options such as
        # pretrained:, num_classes: and init_weights: pass straight through.
        def new(**kwargs)
          with_batch_norm = false
          VGG.make_model("vgg19", "E", with_batch_norm, **kwargs)
        end
      end
    end
  end
end
@@ -0,0 +1,9 @@
1
module TorchVision
  module Models
    # VGG-19 factory (CFGS layout "E") with batch normalization.
    module VGG19BN
      class << self
        # Builds the network through VGG.make_model; keyword options such as
        # pretrained:, num_classes: and init_weights: pass straight through.
        def new(**kwargs)
          with_batch_norm = true
          VGG.make_model("vgg19_bn", "E", with_batch_norm, **kwargs)
        end
      end
    end
  end
end
@@ -0,0 +1,10 @@
1
module TorchVision
  module Models
    # Wide ResNet-101-2 factory: Bottleneck blocks with layer counts
    # [3, 4, 23, 3] and a doubled per-group width (64 * 2).
    module WideResNet101_2
      class << self
        # Builds the network via ResNet.make_model. width_per_group: is fixed
        # by this architecture and overrides any caller value; other options
        # pass through unchanged.
        def new(**kwargs)
          options = kwargs.merge(width_per_group: 64 * 2)
          ResNet.make_model("wide_resnet101_2", Bottleneck, [3, 4, 23, 3], **options)
        end
      end
    end
  end
end
@@ -0,0 +1,10 @@
1
module TorchVision
  module Models
    # Wide ResNet-50-2 factory: Bottleneck blocks with layer counts
    # [3, 4, 6, 3] and a doubled per-group width (64 * 2).
    module WideResNet50_2
      class << self
        # Builds the network via ResNet.make_model. width_per_group: is fixed
        # by this architecture and overrides any caller value; other options
        # pass through unchanged.
        def new(**kwargs)
          options = kwargs.merge(width_per_group: 64 * 2)
          ResNet.make_model("wide_resnet50_2", Bottleneck, [3, 4, 6, 3], **options)
        end
      end
    end
  end
end
@@ -1,3 +1,3 @@
1
1
  module TorchVision
2
- VERSION = "0.1.2"
2
+ VERSION = "0.1.3"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torchvision
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.1.2
4
+ version: 0.1.3
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2020-04-29 00:00:00.000000000 Z
11
+ date: 2020-06-30 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray
@@ -30,14 +30,14 @@ dependencies:
30
30
  requirements:
31
31
  - - ">="
32
32
  - !ruby/object:Gem::Version
33
- version: 0.2.4
33
+ version: 0.2.7
34
34
  type: :runtime
35
35
  prerelease: false
36
36
  version_requirements: !ruby/object:Gem::Requirement
37
37
  requirements:
38
38
  - - ">="
39
39
  - !ruby/object:Gem::Version
40
- version: 0.2.4
40
+ version: 0.2.7
41
41
  - !ruby/object:Gem::Dependency
42
42
  name: bundler
43
43
  requirement: !ruby/object:Gem::Requirement
@@ -96,10 +96,28 @@ files:
96
96
  - lib/torchvision/datasets/kmnist.rb
97
97
  - lib/torchvision/datasets/mnist.rb
98
98
  - lib/torchvision/datasets/vision_dataset.rb
99
+ - lib/torchvision/models/alexnet.rb
99
100
  - lib/torchvision/models/basic_block.rb
100
101
  - lib/torchvision/models/bottleneck.rb
101
102
  - lib/torchvision/models/resnet.rb
103
+ - lib/torchvision/models/resnet101.rb
104
+ - lib/torchvision/models/resnet152.rb
102
105
  - lib/torchvision/models/resnet18.rb
106
+ - lib/torchvision/models/resnet34.rb
107
+ - lib/torchvision/models/resnet50.rb
108
+ - lib/torchvision/models/resnext101_32x8d.rb
109
+ - lib/torchvision/models/resnext50_32x4d.rb
110
+ - lib/torchvision/models/vgg.rb
111
+ - lib/torchvision/models/vgg11.rb
112
+ - lib/torchvision/models/vgg11_bn.rb
113
+ - lib/torchvision/models/vgg13.rb
114
+ - lib/torchvision/models/vgg13_bn.rb
115
+ - lib/torchvision/models/vgg16.rb
116
+ - lib/torchvision/models/vgg16_bn.rb
117
+ - lib/torchvision/models/vgg19.rb
118
+ - lib/torchvision/models/vgg19_bn.rb
119
+ - lib/torchvision/models/wide_resnet101_2.rb
120
+ - lib/torchvision/models/wide_resnet50_2.rb
103
121
  - lib/torchvision/transforms/compose.rb
104
122
  - lib/torchvision/transforms/functional.rb
105
123
  - lib/torchvision/transforms/normalize.rb