torchvision 0.1.2 → 0.1.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 107e429f990a063e57f6218ee6d4fbed3cff80aa6746f0794798d59e1c13b099
4
- data.tar.gz: 556fdc4c413803d415ea5575747ec2239f18b0dff9b39a37a4fd6366adec37a6
3
+ metadata.gz: 44c29605f12dddf8196432223f2137ef9b9bef490996b718f5a9bdbc13dbd33f
4
+ data.tar.gz: 8790200a0ed8f7a275f99327431dc8be99a7578c473ec03411f9411cd10c6c93
5
5
  SHA512:
6
- metadata.gz: 21b712578516c146888be30bed64a6da6339a42974f5c88fa685278caa9231e4bc4b75e250af3ad37d5450cd205891b31b281cf23fa362456c7ae00998c2736d
7
- data.tar.gz: ac86f13e8b5d6a400842ba37b1bf593139360a8df9a19ab4674e71a4327c821a6d5735774185dcbaed368ee735b00bf33d9faedb13d46cd168314da89d900c38
6
+ metadata.gz: 65816ef10f524781553327f9634bb8818a7efa6fb072e81468949d937b5430dcddc7f6c8cf3b305c977dc4b1279d23b38fb034a4c277eecc1adcc0f2b8c99e3e
7
+ data.tar.gz: 01f485a78cd5a19c9a0dc987f4e46931b033ffb3a84af703b602b0e90e8c221c99b5e6318c86b18361562cc3ca582b3a0822de151a5edf63964002317a684bfa
@@ -1,3 +1,11 @@
1
+ ## 0.1.3 (2020-06-29)
2
+
3
+ - Added AlexNet model
4
+ - Added ResNet34, ResNet50, ResNet101, and ResNet152 models
5
+ - Added ResNeXt50 (32x4d) and ResNeXt101 (32x8d) models
6
+ - Added VGG11, VGG13, VGG16, and VGG19 models (with and without batch normalization)
7
+ - Added Wide ResNet50-2 and Wide ResNet101-2 models
8
+
1
9
  ## 0.1.2 (2020-04-29)
2
10
 
3
11
  - Added CIFAR10, CIFAR100, FashionMNIST, and KMNIST datasets
data/README.md CHANGED
@@ -45,8 +45,53 @@ TorchVision::Transforms::Compose.new([
45
45
 
46
46
  ## Models
47
47
 
48
+ - [AlexNet](#alexnet)
49
+ - [ResNet](#resnet)
50
+ - [ResNeXt](#resnext)
51
+ - [VGG](#vgg)
52
+ - [Wide ResNet](#wide-resnet)
53
+
54
+ ### AlexNet
55
+
56
+ ```ruby
57
+ TorchVision::Models::AlexNet.new
58
+ ```
59
+
60
+ ### ResNet
61
+
62
+ ```ruby
63
+ TorchVision::Models::ResNet18.new
64
+ TorchVision::Models::ResNet34.new
65
+ TorchVision::Models::ResNet50.new
66
+ TorchVision::Models::ResNet101.new
67
+ TorchVision::Models::ResNet152.new
68
+ ```
69
+
70
+ ### ResNeXt
71
+
72
+ ```ruby
73
+ TorchVision::Models::ResNext50_32x4d.new
74
+ TorchVision::Models::ResNext101_32x8d.new
75
+ ```
76
+
77
+ ### VGG
78
+
79
+ ```ruby
80
+ TorchVision::Models::VGG11.new
81
+ TorchVision::Models::VGG11BN.new
82
+ TorchVision::Models::VGG13.new
83
+ TorchVision::Models::VGG13BN.new
84
+ TorchVision::Models::VGG16.new
85
+ TorchVision::Models::VGG16BN.new
86
+ TorchVision::Models::VGG19.new
87
+ TorchVision::Models::VGG19BN.new
88
+ ```
89
+
90
+ ### Wide ResNet
91
+
48
92
  ```ruby
49
- TorchVision::Models::Resnet18.new
93
+ TorchVision::Models::WideResNet52_2.new
94
+ TorchVision::Models::WideResNet101_2.new
50
95
  ```
51
96
 
52
97
  ## Disclaimer
@@ -21,10 +21,28 @@ require "torchvision/datasets/fashion_mnist"
21
21
  require "torchvision/datasets/kmnist"
22
22
 
23
23
  # models
24
+ require "torchvision/models/alexnet"
24
25
  require "torchvision/models/basic_block"
25
26
  require "torchvision/models/bottleneck"
26
27
  require "torchvision/models/resnet"
27
28
  require "torchvision/models/resnet18"
29
+ require "torchvision/models/resnet34"
30
+ require "torchvision/models/resnet50"
31
+ require "torchvision/models/resnet101"
32
+ require "torchvision/models/resnet152"
33
+ require "torchvision/models/resnext50_32x4d"
34
+ require "torchvision/models/resnext101_32x8d"
35
+ require "torchvision/models/vgg"
36
+ require "torchvision/models/vgg11"
37
+ require "torchvision/models/vgg11_bn"
38
+ require "torchvision/models/vgg13"
39
+ require "torchvision/models/vgg13_bn"
40
+ require "torchvision/models/vgg16"
41
+ require "torchvision/models/vgg16_bn"
42
+ require "torchvision/models/vgg19"
43
+ require "torchvision/models/vgg19_bn"
44
+ require "torchvision/models/wide_resnet50_2"
45
+ require "torchvision/models/wide_resnet101_2"
28
46
 
29
47
  # transforms
30
48
  require "torchvision/transforms/compose"
@@ -0,0 +1,42 @@
1
module TorchVision
  module Models
    # AlexNet image classifier: five convolutional layers followed by a
    # three-layer fully connected head.
    #
    # @param num_classes [Integer] size of the final output layer
    class AlexNet < Torch::NN::Module
      def initialize(num_classes: 1000)
        super()
        # Small factories for the repeated layers; each call builds a fresh module.
        # Arguments are evaluated left to right, so construction (and therefore
        # random weight initialization) happens in the same order as written.
        relu = -> { Torch::NN::ReLU.new(inplace: true) }
        pool = -> { Torch::NN::MaxPool2d.new(3, stride: 2) }

        @features = Torch::NN::Sequential.new(
          Torch::NN::Conv2d.new(3, 64, 11, stride: 4, padding: 2), relu.call, pool.call,
          Torch::NN::Conv2d.new(64, 192, 5, padding: 2), relu.call, pool.call,
          Torch::NN::Conv2d.new(192, 384, 3, padding: 1), relu.call,
          Torch::NN::Conv2d.new(384, 256, 3, padding: 1), relu.call,
          Torch::NN::Conv2d.new(256, 256, 3, padding: 1), relu.call, pool.call
        )
        # Pools the feature map to a fixed 6x6 grid so the head's input size is constant.
        @avgpool = Torch::NN::AdaptiveAvgPool2d.new([6, 6])
        @classifier = Torch::NN::Sequential.new(
          Torch::NN::Dropout.new, Torch::NN::Linear.new(256 * 6 * 6, 4096), relu.call,
          Torch::NN::Dropout.new, Torch::NN::Linear.new(4096, 4096), relu.call,
          Torch::NN::Linear.new(4096, num_classes)
        )
      end

      # Runs features -> adaptive pool -> flatten (keeping dim 0) -> classifier.
      def forward(x)
        pooled = @avgpool.call(@features.call(x))
        @classifier.call(Torch.flatten(pooled, 1))
      end
    end
  end
end
@@ -1,6 +1,18 @@
1
1
  module TorchVision
2
2
  module Models
3
3
  class ResNet < Torch::NN::Module
4
# Download locations of pretrained ImageNet weights, keyed by the
# architecture name passed to ResNet.make_model. Frozen so the shared
# lookup table cannot be mutated in place by callers.
MODEL_URLS = {
  "resnet18" => "https://download.pytorch.org/models/resnet18-5c106cde.pth",
  "resnet34" => "https://download.pytorch.org/models/resnet34-333f7ec4.pth",
  "resnet50" => "https://download.pytorch.org/models/resnet50-19c8e357.pth",
  "resnet101" => "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth",
  "resnet152" => "https://download.pytorch.org/models/resnet152-b121ed2d.pth",
  "resnext50_32x4d" => "https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth",
  "resnext101_32x8d" => "https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth",
  "wide_resnet50_2" => "https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth",
  "wide_resnet101_2" => "https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth"
}.freeze
15
+
4
16
  def initialize(block, layers, num_classes=1000, zero_init_residual: false,
5
17
  groups: 1, width_per_group: 64, replace_stride_with_dilation: nil, norm_layer: nil)
6
18
 
@@ -102,6 +114,16 @@ module TorchVision
102
114
  def forward(x)
103
115
  _forward_impl(x)
104
116
  end
117
+
118
# Builds a ResNet-family model and optionally loads pretrained weights.
#
# @param arch [String] key into MODEL_URLS used when +pretrained+ is true
# @param block [Class] residual block class passed to ResNet.new
# @param layers [Array<Integer>] number of blocks in each stage
# @param pretrained [Boolean] download and load the weights for +arch+
# @param kwargs [Hash] remaining options forwarded to ResNet.new. +num_classes:+
#   is accepted as a keyword here (as AlexNet and VGG accept it) and forwarded
#   as ResNet#initialize's positional argument.
# @return [ResNet] the constructed model
# @raise [ArgumentError] if +pretrained+ is true and +arch+ has no known weights
def self.make_model(arch, block, layers, pretrained: false, **kwargs)
  # Resolve the URL before building the model so an unknown arch fails fast
  # instead of handing nil to the downloader.
  url = (MODEL_URLS.fetch(arch) { raise ArgumentError, "No pretrained weights for #{arch.inspect}" } if pretrained)

  # ResNet#initialize declares num_classes positionally; passing it through
  # **kwargs would be rejected as an unknown keyword.
  positional = [block, layers]
  positional << kwargs.delete(:num_classes) if kwargs.key?(:num_classes)

  model = ResNet.new(*positional, **kwargs)
  model.load_state_dict(Torch::Hub.load_state_dict_from_url(url)) if url
  model
end
105
127
  end
106
128
  end
107
129
  end
@@ -0,0 +1,9 @@
1
module TorchVision
  module Models
    # Factory for the 101-layer ResNet: Bottleneck blocks in stages of [3, 4, 23, 3].
    # All keywords (e.g. pretrained:) are forwarded unchanged to ResNet.make_model.
    module ResNet101
      class << self
        def new(**kwargs)
          stage_depths = [3, 4, 23, 3]
          ResNet.make_model("resnet101", Bottleneck, stage_depths, **kwargs)
        end
      end
    end
  end
end
@@ -0,0 +1,9 @@
1
module TorchVision
  module Models
    # Factory for the 152-layer ResNet: Bottleneck blocks in stages of [3, 8, 36, 3].
    # All keywords (e.g. pretrained:) are forwarded unchanged to ResNet.make_model.
    module ResNet152
      class << self
        def new(**kwargs)
          stage_depths = [3, 8, 36, 3]
          ResNet.make_model("resnet152", Bottleneck, stage_depths, **kwargs)
        end
      end
    end
  end
end
@@ -1,14 +1,8 @@
1
1
  module TorchVision
2
2
  module Models
3
3
# Factory for the 18-layer ResNet: BasicBlock units in stages of [2, 2, 2, 2].
# All keywords (e.g. pretrained:) are forwarded unchanged to ResNet.make_model.
module ResNet18
  class << self
    def new(**kwargs)
      stage_depths = [2, 2, 2, 2]
      ResNet.make_model("resnet18", BasicBlock, stage_depths, **kwargs)
    end
  end
end
14
8
  end
@@ -0,0 +1,9 @@
1
module TorchVision
  module Models
    # Factory for the 34-layer ResNet: BasicBlock units in stages of [3, 4, 6, 3].
    # All keywords (e.g. pretrained:) are forwarded unchanged to ResNet.make_model.
    module ResNet34
      class << self
        def new(**kwargs)
          stage_depths = [3, 4, 6, 3]
          ResNet.make_model("resnet34", BasicBlock, stage_depths, **kwargs)
        end
      end
    end
  end
end
@@ -0,0 +1,9 @@
1
module TorchVision
  module Models
    # Factory for the 50-layer ResNet: Bottleneck blocks in stages of [3, 4, 6, 3].
    # All keywords (e.g. pretrained:) are forwarded unchanged to ResNet.make_model.
    module ResNet50
      class << self
        def new(**kwargs)
          stage_depths = [3, 4, 6, 3]
          ResNet.make_model("resnet50", Bottleneck, stage_depths, **kwargs)
        end
      end
    end
  end
end
@@ -0,0 +1,11 @@
1
module TorchVision
  module Models
    # Factory for ResNeXt-101 32x8d: the ResNet-101 stage layout with 32 groups of width 8.
    # Caller-supplied groups:/width_per_group: are overridden; other keywords pass through.
    module ResNext101_32x8d
      class << self
        def new(**kwargs)
          options = kwargs.merge(groups: 32, width_per_group: 8)
          ResNet.make_model("resnext101_32x8d", Bottleneck, [3, 4, 23, 3], **options)
        end
      end
    end
  end
end
@@ -0,0 +1,11 @@
1
module TorchVision
  module Models
    # Factory for ResNeXt-50 32x4d: the ResNet-50 stage layout with 32 groups of width 4.
    # Caller-supplied groups:/width_per_group: are overridden; other keywords pass through.
    module ResNext50_32x4d
      class << self
        def new(**kwargs)
          options = kwargs.merge(groups: 32, width_per_group: 4)
          ResNet.make_model("resnext50_32x4d", Bottleneck, [3, 4, 6, 3], **options)
        end
      end
    end
  end
end
@@ -0,0 +1,93 @@
1
module TorchVision
  module Models
    # VGG image classifier: a configurable convolutional feature extractor
    # followed by a fixed three-layer fully connected head.
    class VGG < Torch::NN::Module
      # Download locations of pretrained ImageNet weights, keyed by the
      # architecture name passed to VGG.make_model.
      MODEL_URLS = {
        "vgg11" => "https://download.pytorch.org/models/vgg11-bbd30ac9.pth",
        "vgg13" => "https://download.pytorch.org/models/vgg13-c768596a.pth",
        "vgg16" => "https://download.pytorch.org/models/vgg16-397923af.pth",
        "vgg19" => "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth",
        "vgg11_bn" => "https://download.pytorch.org/models/vgg11_bn-6002323d.pth",
        "vgg13_bn" => "https://download.pytorch.org/models/vgg13_bn-abd245e5.pth",
        "vgg16_bn" => "https://download.pytorch.org/models/vgg16_bn-6c64b313.pth",
        "vgg19_bn" => "https://download.pytorch.org/models/vgg19_bn-c79401a0.pth"
      }.freeze

      # @param features [Torch::NN::Module] convolutional stack (see VGG.make_layers);
      #   its output must have 512 channels to match the classifier input
      # @param num_classes [Integer] size of the final output layer
      # @param init_weights [Boolean] apply _initialize_weights after construction
      def initialize(features, num_classes: 1000, init_weights: true)
        super()
        @features = features
        # Pools to a fixed 7x7 grid so the classifier input size is constant.
        @avgpool = Torch::NN::AdaptiveAvgPool2d.new([7, 7])
        @classifier = Torch::NN::Sequential.new(
          Torch::NN::Linear.new(512 * 7 * 7, 4096),
          Torch::NN::ReLU.new(inplace: true),
          Torch::NN::Dropout.new,
          Torch::NN::Linear.new(4096, 4096),
          Torch::NN::ReLU.new(inplace: true),
          Torch::NN::Dropout.new,
          Torch::NN::Linear.new(4096, num_classes)
        )
        _initialize_weights if init_weights
      end

      # Runs features -> adaptive pool -> flatten (keeping dim 0) -> classifier.
      def forward(x)
        x = @features.call(x)
        x = @avgpool.call(x)
        x = Torch.flatten(x, 1)
        @classifier.call(x)
      end

      # Initializes every submodule by type: Kaiming-normal (fan_out, relu) for
      # conv weights, ones/zeros for batch-norm, N(0, 0.01) for linear weights,
      # and zero for all biases.
      def _initialize_weights
        modules.each do |m|
          case m
          when Torch::NN::Conv2d
            Torch::NN::Init.kaiming_normal!(m.weight, mode: "fan_out", nonlinearity: "relu")
            Torch::NN::Init.constant!(m.bias, 0) if m.bias
          when Torch::NN::BatchNorm2d
            Torch::NN::Init.constant!(m.weight, 1)
            Torch::NN::Init.constant!(m.bias, 0)
          when Torch::NN::Linear
            Torch::NN::Init.normal!(m.weight, mean: 0, std: 0.01)
            Torch::NN::Init.constant!(m.bias, 0)
          end
        end
      end

      # Layer configurations: integers are 3x3 conv output channels, "M" is a
      # 2x2 max-pool. "A" = VGG11, "B" = VGG13, "D" = VGG16, "E" = VGG19.
      CFGS = {
        "A" => [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
        "B" => [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"],
        "D" => [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M", 512, 512, 512, "M"],
        "E" => [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M", 512, 512, 512, 512, "M", 512, 512, 512, 512, "M"],
      }.freeze

      # Builds a VGG model and optionally loads pretrained weights.
      #
      # @param arch [String] key into MODEL_URLS used when +pretrained+ is true
      # @param cfg [String] key into CFGS selecting the layer layout
      # @param batch_norm [Boolean] insert BatchNorm2d after each convolution
      # @param pretrained [Boolean] download and load the weights for +arch+
      #   (also skips the random weight initialization, which would be overwritten)
      # @param kwargs [Hash] forwarded to VGG.new
      # @raise [ArgumentError] for an unknown +cfg+, or an +arch+ without weights
      #   when +pretrained+ is true
      def self.make_model(arch, cfg, batch_norm, pretrained: false, **kwargs)
        layer_cfg = CFGS.fetch(cfg) { raise ArgumentError, "Unknown VGG configuration: #{cfg.inspect}" }
        # Resolve the URL before building the model so a bad arch fails fast.
        url = (MODEL_URLS.fetch(arch) { raise ArgumentError, "No pretrained weights for #{arch.inspect}" } if pretrained)

        kwargs[:init_weights] = false if pretrained
        model = VGG.new(make_layers(layer_cfg, batch_norm), **kwargs)
        model.load_state_dict(Torch::Hub.load_state_dict_from_url(url)) if url
        model
      end

      # Turns a CFGS entry into a Sequential of conv[/bn]/relu and max-pool
      # layers, starting from 3 input channels.
      def self.make_layers(cfg, batch_norm)
        layers = []
        in_channels = 3
        cfg.each do |v|
          if v == "M"
            layers << Torch::NN::MaxPool2d.new(2, stride: 2)
          else
            layers << Torch::NN::Conv2d.new(in_channels, v, 3, padding: 1)
            layers << Torch::NN::BatchNorm2d.new(v) if batch_norm
            layers << Torch::NN::ReLU.new(inplace: true)
            in_channels = v
          end
        end
        Torch::NN::Sequential.new(*layers)
      end
    end
  end
end
@@ -0,0 +1,9 @@
1
module TorchVision
  module Models
    # Factory for VGG-11 (configuration "A") without batch normalization.
    # All keywords (e.g. pretrained:) are forwarded unchanged to VGG.make_model.
    module VGG11
      class << self
        def new(**kwargs)
          with_batch_norm = false
          VGG.make_model("vgg11", "A", with_batch_norm, **kwargs)
        end
      end
    end
  end
end
@@ -0,0 +1,9 @@
1
module TorchVision
  module Models
    # Factory for VGG-11 (configuration "A") with batch normalization.
    # All keywords (e.g. pretrained:) are forwarded unchanged to VGG.make_model.
    module VGG11BN
      class << self
        def new(**kwargs)
          with_batch_norm = true
          VGG.make_model("vgg11_bn", "A", with_batch_norm, **kwargs)
        end
      end
    end
  end
end
@@ -0,0 +1,9 @@
1
module TorchVision
  module Models
    # Factory for VGG-13 (configuration "B") without batch normalization.
    # All keywords (e.g. pretrained:) are forwarded unchanged to VGG.make_model.
    module VGG13
      class << self
        def new(**kwargs)
          with_batch_norm = false
          VGG.make_model("vgg13", "B", with_batch_norm, **kwargs)
        end
      end
    end
  end
end
@@ -0,0 +1,9 @@
1
module TorchVision
  module Models
    # Factory for VGG-13 (configuration "B") with batch normalization.
    # All keywords (e.g. pretrained:) are forwarded unchanged to VGG.make_model.
    module VGG13BN
      class << self
        def new(**kwargs)
          with_batch_norm = true
          VGG.make_model("vgg13_bn", "B", with_batch_norm, **kwargs)
        end
      end
    end
  end
end
@@ -0,0 +1,9 @@
1
module TorchVision
  module Models
    # Factory for VGG-16 (configuration "D") without batch normalization.
    # All keywords (e.g. pretrained:) are forwarded unchanged to VGG.make_model.
    module VGG16
      class << self
        def new(**kwargs)
          with_batch_norm = false
          VGG.make_model("vgg16", "D", with_batch_norm, **kwargs)
        end
      end
    end
  end
end
@@ -0,0 +1,9 @@
1
module TorchVision
  module Models
    # Factory for VGG-16 (configuration "D") with batch normalization.
    # All keywords (e.g. pretrained:) are forwarded unchanged to VGG.make_model.
    module VGG16BN
      class << self
        def new(**kwargs)
          with_batch_norm = true
          VGG.make_model("vgg16_bn", "D", with_batch_norm, **kwargs)
        end
      end
    end
  end
end
@@ -0,0 +1,9 @@
1
module TorchVision
  module Models
    # Factory for VGG-19 (configuration "E") without batch normalization.
    # All keywords (e.g. pretrained:) are forwarded unchanged to VGG.make_model.
    module VGG19
      class << self
        def new(**kwargs)
          with_batch_norm = false
          VGG.make_model("vgg19", "E", with_batch_norm, **kwargs)
        end
      end
    end
  end
end
@@ -0,0 +1,9 @@
1
module TorchVision
  module Models
    # Factory for VGG-19 (configuration "E") with batch normalization.
    # All keywords (e.g. pretrained:) are forwarded unchanged to VGG.make_model.
    module VGG19BN
      class << self
        def new(**kwargs)
          with_batch_norm = true
          VGG.make_model("vgg19_bn", "E", with_batch_norm, **kwargs)
        end
      end
    end
  end
end
@@ -0,0 +1,10 @@
1
module TorchVision
  module Models
    # Factory for Wide ResNet-101-2: the ResNet-101 stage layout with bottleneck
    # width doubled (width_per_group 128 instead of the default 64). A caller-supplied
    # width_per_group: is overridden; other keywords pass through.
    module WideResNet101_2
      class << self
        def new(**kwargs)
          options = kwargs.merge(width_per_group: 64 * 2)
          ResNet.make_model("wide_resnet101_2", Bottleneck, [3, 4, 23, 3], **options)
        end
      end
    end
  end
end
@@ -0,0 +1,10 @@
1
module TorchVision
  module Models
    # Factory for Wide ResNet-50-2: the ResNet-50 stage layout with bottleneck
    # width doubled (width_per_group 128 instead of the default 64). A caller-supplied
    # width_per_group: is overridden; other keywords pass through.
    module WideResNet50_2
      class << self
        def new(**kwargs)
          options = kwargs.merge(width_per_group: 64 * 2)
          ResNet.make_model("wide_resnet50_2", Bottleneck, [3, 4, 6, 3], **options)
        end
      end
    end
  end
end
@@ -1,3 +1,3 @@
1
1
module TorchVision
  # Gem version string; frozen so the shared constant cannot be mutated in place.
  VERSION = "0.1.3".freeze
end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torchvision
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.1.2
4
+ version: 0.1.3
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2020-04-29 00:00:00.000000000 Z
11
+ date: 2020-06-30 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray
@@ -30,14 +30,14 @@ dependencies:
30
30
  requirements:
31
31
  - - ">="
32
32
  - !ruby/object:Gem::Version
33
- version: 0.2.4
33
+ version: 0.2.7
34
34
  type: :runtime
35
35
  prerelease: false
36
36
  version_requirements: !ruby/object:Gem::Requirement
37
37
  requirements:
38
38
  - - ">="
39
39
  - !ruby/object:Gem::Version
40
- version: 0.2.4
40
+ version: 0.2.7
41
41
  - !ruby/object:Gem::Dependency
42
42
  name: bundler
43
43
  requirement: !ruby/object:Gem::Requirement
@@ -96,10 +96,28 @@ files:
96
96
  - lib/torchvision/datasets/kmnist.rb
97
97
  - lib/torchvision/datasets/mnist.rb
98
98
  - lib/torchvision/datasets/vision_dataset.rb
99
+ - lib/torchvision/models/alexnet.rb
99
100
  - lib/torchvision/models/basic_block.rb
100
101
  - lib/torchvision/models/bottleneck.rb
101
102
  - lib/torchvision/models/resnet.rb
103
+ - lib/torchvision/models/resnet101.rb
104
+ - lib/torchvision/models/resnet152.rb
102
105
  - lib/torchvision/models/resnet18.rb
106
+ - lib/torchvision/models/resnet34.rb
107
+ - lib/torchvision/models/resnet50.rb
108
+ - lib/torchvision/models/resnext101_32x8d.rb
109
+ - lib/torchvision/models/resnext50_32x4d.rb
110
+ - lib/torchvision/models/vgg.rb
111
+ - lib/torchvision/models/vgg11.rb
112
+ - lib/torchvision/models/vgg11_bn.rb
113
+ - lib/torchvision/models/vgg13.rb
114
+ - lib/torchvision/models/vgg13_bn.rb
115
+ - lib/torchvision/models/vgg16.rb
116
+ - lib/torchvision/models/vgg16_bn.rb
117
+ - lib/torchvision/models/vgg19.rb
118
+ - lib/torchvision/models/vgg19_bn.rb
119
+ - lib/torchvision/models/wide_resnet101_2.rb
120
+ - lib/torchvision/models/wide_resnet50_2.rb
103
121
  - lib/torchvision/transforms/compose.rb
104
122
  - lib/torchvision/transforms/functional.rb
105
123
  - lib/torchvision/transforms/normalize.rb