torchvision 0.1.1 → 0.1.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
- metadata.gz: c6ae301007e3310dfc378a7f76d24bd71cc048be6156b32d9e0705412817c565
- data.tar.gz: 423aee54316683ddf9d3893e04f8836dd16ade1408f6c046f6d8f6f337a09a46
+ metadata.gz: 107e429f990a063e57f6218ee6d4fbed3cff80aa6746f0794798d59e1c13b099
+ data.tar.gz: 556fdc4c413803d415ea5575747ec2239f18b0dff9b39a37a4fd6366adec37a6
  SHA512:
- metadata.gz: cf9c5a514c2ef42299161f97f077897c11b4708858018ffefcb282d6434003d49555ba54ea93ef0fb9e2f3cafdd7bd4395677e02bdfa5b1dcad4772aaeccc896
- data.tar.gz: 0e01311267e620be6dfbed77724baed42ad9c597a2b5452d762845597ae257932e76848e0069149598743d78dc08b0374b18897ce9a272640eed4a3cec21bbd9
+ metadata.gz: 21b712578516c146888be30bed64a6da6339a42974f5c88fa685278caa9231e4bc4b75e250af3ad37d5450cd205891b31b281cf23fa362456c7ae00998c2736d
+ data.tar.gz: ac86f13e8b5d6a400842ba37b1bf593139360a8df9a19ab4674e71a4327c821a6d5735774185dcbaed368ee735b00bf33d9faedb13d46cd168314da89d900c38
data/CHANGELOG.md CHANGED
@@ -1,3 +1,8 @@
+ ## 0.1.2 (2020-04-29)
+
+ - Added CIFAR10, CIFAR100, FashionMNIST, and KMNIST datasets
+ - Added ResNet18 model
+
  ## 0.1.1 (2020-04-28)

  - Removed `mini_magick` for performance
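Taken together, the new datasets and the transform pipeline from the README further down are used like this; a minimal sketch, assuming the MNIST mean/std values (0.1307/0.3081) shown in the README:

```ruby
require "torchvision"

transform = TorchVision::Transforms::Compose.new([
  TorchVision::Transforms::ToTensor.new,
  TorchVision::Transforms::Normalize.new([0.1307], [0.3081])
])

# every dataset takes the same root/train/download/transform options
trainset = TorchVision::Datasets::MNIST.new("./data", train: true, download: true, transform: transform)
img, target = trainset[0]  # normalized float tensor and its Integer label
```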
data/README.md CHANGED
@@ -20,11 +20,33 @@ This library follows the [Python API](https://pytorch.org/docs/master/torchvisio

  ## Datasets

- MNIST dataset
+ Load a dataset

  ```ruby
- trainset = TorchVision::Datasets::MNIST.new("./data", train: true, download: true)
- trainset.size
+ TorchVision::Datasets::MNIST.new("./data", train: true, download: true)
+ ```
+
+ Supported datasets are:
+
+ - CIFAR10
+ - CIFAR100
+ - FashionMNIST
+ - KMNIST
+ - MNIST
+
+ ## Transforms
+
+ ```ruby
+ TorchVision::Transforms::Compose.new([
+ TorchVision::Transforms::ToTensor.new,
+ TorchVision::Transforms::Normalize.new([0.1307], [0.3081])
+ ])
+ ```
+
+ ## Models
+
+ ```ruby
+ TorchVision::Models::ResNet18.new
  ```

  ## Disclaimer
data/lib/torchvision.rb CHANGED
@@ -6,13 +6,25 @@ require "torch"
  require "digest"
  require "fileutils"
  require "net/http"
+ require "rubygems/package"
  require "tmpdir"

  # modules
  require "torchvision/version"

  # datasets
+ require "torchvision/datasets/vision_dataset"
+ require "torchvision/datasets/cifar10"
+ require "torchvision/datasets/cifar100"
  require "torchvision/datasets/mnist"
+ require "torchvision/datasets/fashion_mnist"
+ require "torchvision/datasets/kmnist"
+
+ # models
+ require "torchvision/models/basic_block"
+ require "torchvision/models/bottleneck"
+ require "torchvision/models/resnet"
+ require "torchvision/models/resnet18"

  # transforms
  require "torchvision/transforms/compose"
data/lib/torchvision/datasets/cifar10.rb ADDED
@@ -0,0 +1,116 @@
+ module TorchVision
+ module Datasets
+ class CIFAR10 < VisionDataset
+ # https://www.cs.toronto.edu/~kriz/cifar.html
+
+ def initialize(root, train: true, download: false, transform: nil, target_transform: nil)
+ super(root, transform: transform, target_transform: target_transform)
+ @train = train
+
+ self.download if download
+
+ if !_check_integrity
+ raise Error, "Dataset not found or corrupted. You can use download: true to download it"
+ end
+
+ downloaded_list = @train ? train_list : test_list
+
+ @data = String.new
+ @targets = String.new
+
+ downloaded_list.each do |file|
+ file_path = File.join(@root, base_folder, file[:filename])
+ File.open(file_path, "rb") do |f|
+ while !f.eof?
+ f.read(1) if multiple_labels?
+ @targets << f.read(1)
+ @data << f.read(3072)
+ end
+ end
+ end
+
+ @targets = @targets.unpack("C*")
+ # TODO switch i to -1 when Numo supports it
+ @data = Numo::UInt8.from_binary(@data).reshape(@targets.size, 3, 32, 32)
+ @data = @data.transpose(0, 2, 3, 1)
+ end
+
+ def size
+ @data.shape[0]
+ end
+
+ def [](index)
+ # TODO remove trues when Numo supports it
+ img, target = @data[index, true, true, true], @targets[index]
+
+ # TODO convert to image
+ img = @transform.call(img) if @transform
+
+ target = @target_transform.call(target) if @target_transform
+
+ [img, target]
+ end
+
+ def _check_integrity
+ root = @root
+ (train_list + test_list).each do |fentry|
+ fpath = File.join(root, base_folder, fentry[:filename])
+ return false unless check_integrity(fpath, fentry[:sha256])
+ end
+ true
+ end
+
+ def download
+ if _check_integrity
+ puts "Files already downloaded and verified"
+ return
+ end
+
+ download_file(url, download_root: @root, filename: filename, sha256: tgz_sha256)
+
+ path = File.join(@root, filename)
+ File.open(path, "rb") do |io|
+ Gem::Package.new("").extract_tar_gz(io, @root)
+ end
+ end
+
+ private
+
+ def base_folder
+ "cifar-10-batches-bin"
+ end
+
+ def url
+ "https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz"
+ end
+
+ def filename
+ "cifar-10-binary.tar.gz"
+ end
+
+ def tgz_sha256
+ "c4a38c50a1bc5f3a1c5537f2155ab9d68f9f25eb1ed8d9ddda3db29a59bca1dd"
+ end
+
+ def train_list
+ [
+ {filename: "data_batch_1.bin", sha256: "cee916563c9f80d84e3cc88e17fdc0941787f1244f00a67874d45b261883ada5"},
+ {filename: "data_batch_2.bin", sha256: "a591ca11fa1708a91ee40f54b3da4784ccd871ecf2137de63f51ada8b3fa57ed"},
+ {filename: "data_batch_3.bin", sha256: "bbe8596564c0f86427f876058170b84dac6670ddf06d79402899d93ceea26f67"},
+ {filename: "data_batch_4.bin", sha256: "014e562d6e23c72197cc727519169a60359f5eccd8945ad5a09d710285ff4e48"},
+ {filename: "data_batch_5.bin", sha256: "755304fc0b379caeae8c14f0dac912fbc7d6cd469eb67a1029a08a39453a9add"},
+ ]
+ end
+
+ def test_list
+ [
+ {filename: "test_batch.bin", sha256: "8e2eb146ae340b09e24670f29cabc6326dba54da8789dab6768acf480273f65b"}
+ ]
+ end
+
+ def multiple_labels?
+ false
+ end
+ end
+ end
+ end
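The class above follows the same interface as the MNIST example in the README; loading and indexing it looks roughly like this (a minimal sketch, assuming the `./data` root used elsewhere in the README):

```ruby
trainset = TorchVision::Datasets::CIFAR10.new("./data", train: true, download: true)

trainset.size     # => 50000, the five 10,000-image training batches
img, target = trainset[0]
# img is a Numo::UInt8 view of shape [32, 32, 3] (HWC after the transpose above),
# target is the Integer class label read from the binary batch file
```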
data/lib/torchvision/datasets/cifar100.rb ADDED
@@ -0,0 +1,41 @@
+ module TorchVision
+ module Datasets
+ class CIFAR100 < CIFAR10
+ # https://www.cs.toronto.edu/~kriz/cifar.html
+
+ private
+
+ def base_folder
+ "cifar-100-binary"
+ end
+
+ def url
+ "https://www.cs.toronto.edu/~kriz/cifar-100-binary.tar.gz"
+ end
+
+ def filename
+ "cifar-100-binary.tar.gz"
+ end
+
+ def tgz_sha256
+ "58a81ae192c23a4be8b1804d68e518ed807d710a4eb253b1f2a199162a40d8ec"
+ end
+
+ def train_list
+ [
+ {filename: "train.bin", sha256: "f31298fc616915fa142368359df1c4ca2ae984d6915ca468b998a5ec6aeebf29"}
+ ]
+ end
+
+ def test_list
+ [
+ {filename: "test.bin", sha256: "d8b1e6b7b3bee4020055f0699b111f60b1af1e262aeb93a0b659061746f8224a"}
+ ]
+ end
+
+ def multiple_labels?
+ true
+ end
+ end
+ end
+ end
data/lib/torchvision/datasets/fashion_mnist.rb ADDED
@@ -0,0 +1,30 @@
+ module TorchVision
+ module Datasets
+ class FashionMNIST < MNIST
+ # https://github.com/zalandoresearch/fashion-mnist
+
+ private
+
+ def resources
+ [
+ {
+ url: "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz",
+ sha256: "3aede38d61863908ad78613f6a32ed271626dd12800ba2636569512369268a84"
+ },
+ {
+ url: "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz",
+ sha256: "a04f17134ac03560a47e3764e11b92fc97de4d1bfaf8ba1a3aa29af54cc90845"
+ },
+ {
+ url: "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz",
+ sha256: "346e55b948d973a97e58d2351dde16a484bd415d4595297633bb08f03db6a073"
+ },
+ {
+ url: "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz",
+ sha256: "67da17c76eaffca5446c3361aaab5c3cd6d1c2608764d35dfb1850b086bf8dd5"
+ }
+ ]
+ end
+ end
+ end
+ end
data/lib/torchvision/datasets/kmnist.rb ADDED
@@ -0,0 +1,30 @@
+ module TorchVision
+ module Datasets
+ class KMNIST < MNIST
+ # https://github.com/rois-codh/kmnist
+
+ private
+
+ def resources
+ [
+ {
+ url: "http://codh.rois.ac.jp/kmnist/dataset/kmnist/train-images-idx3-ubyte.gz",
+ sha256: "51467d22d8cc72929e2a028a0428f2086b092bb31cfb79c69cc0a90ce135fde4"
+ },
+ {
+ url: "http://codh.rois.ac.jp/kmnist/dataset/kmnist/train-labels-idx1-ubyte.gz",
+ sha256: "e38f9ebcd0f3ebcdec7fc8eabdcdaef93bb0df8ea12bee65224341c8183d8e17"
+ },
+ {
+ url: "http://codh.rois.ac.jp/kmnist/dataset/kmnist/t10k-images-idx3-ubyte.gz",
+ sha256: "edd7a857845ad6bb1d0ba43fe7e794d164fe2dce499a1694695a792adfac43c5"
+ },
+ {
+ url: "http://codh.rois.ac.jp/kmnist/dataset/kmnist/t10k-labels-idx1-ubyte.gz",
+ sha256: "20bb9a0ef54c7db3efc55a92eef5582c109615df22683c380526788f98e42a1c"
+ }
+ ]
+ end
+ end
+ end
+ end
data/lib/torchvision/datasets/mnist.rb CHANGED
@@ -1,31 +1,11 @@
  module TorchVision
  module Datasets
- class MNIST
- RESOURCES = [
- {
- url: "http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz",
- sha256: "440fcabf73cc546fa21475e81ea370265605f56be210a4024d2ca8f203523609"
- },
- {
- url: "http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz",
- sha256: "3552534a0a558bbed6aed32b30c495cca23d567ec52cac8be1a0730e8010255c"
- },
- {
- url: "http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz",
- sha256: "8d422c7b0a1c1c79245a5bcf07fe86e33eeafee792b84584aec276f5a2dbc4e6"
- },
- {
- url: "http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz",
- sha256: "f7ae60f92e00ec6debd23a6088c31dbd2371eca3ffa0defaefb259924204aec6"
- }
- ]
- TRAINING_FILE = "training.pt"
- TEST_FILE = "test.pt"
-
- def initialize(root, train: true, download: false, transform: nil)
- @root = root
+ class MNIST < VisionDataset
+ # http://yann.lecun.com/exdb/mnist/
+
+ def initialize(root, train: true, download: false, transform: nil, target_transform: nil)
+ super(root, transform: transform, target_transform: target_transform)
  @train = train
- @transform = transform

  self.download if download

@@ -33,34 +13,36 @@ module TorchVision
  raise Error, "Dataset not found. You can use download: true to download it"
  end

- data_file = @train ? TRAINING_FILE : TEST_FILE
+ data_file = @train ? training_file : test_file
  @data, @targets = Torch.load(File.join(processed_folder, data_file))
  end

  def size
- @data.size[0]
+ @data.size(0)
  end

  def [](index)
- img = @data[index]
+ img, target = @data[index], @targets[index].item
+
+ # TODO convert to image
  img = @transform.call(img) if @transform

- target = @targets[index].item
+ target = @target_transform.call(target) if @target_transform

  [img, target]
  end

  def raw_folder
- File.join(@root, "MNIST", "raw")
+ File.join(@root, self.class.name.split("::").last, "raw")
  end

  def processed_folder
- File.join(@root, "MNIST", "processed")
+ File.join(@root, self.class.name.split("::").last, "processed")
  end

  def check_exists
- File.exist?(File.join(processed_folder, TRAINING_FILE)) &&
- File.exist?(File.join(processed_folder, TEST_FILE))
+ File.exist?(File.join(processed_folder, training_file)) &&
+ File.exist?(File.join(processed_folder, test_file))
  end

  def download
@@ -69,7 +51,7 @@ module TorchVision
  FileUtils.mkdir_p(raw_folder)
  FileUtils.mkdir_p(processed_folder)

- RESOURCES.each do |resource|
+ resources.each do |resource|
  filename = resource[:url].split("/").last
  download_file(resource[:url], download_root: raw_folder, filename: filename, sha256: resource[:sha256])
  end
@@ -85,14 +67,43 @@ module TorchVision
  unpack_mnist("t10k-labels-idx1-ubyte", 8, [10000])
  ]

- Torch.save(training_set, File.join(processed_folder, TRAINING_FILE))
- Torch.save(test_set, File.join(processed_folder, TEST_FILE))
+ Torch.save(training_set, File.join(processed_folder, training_file))
+ Torch.save(test_set, File.join(processed_folder, test_file))

  puts "Done!"
  end

  private

+ def resources
+ [
+ {
+ url: "http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz",
+ sha256: "440fcabf73cc546fa21475e81ea370265605f56be210a4024d2ca8f203523609"
+ },
+ {
+ url: "http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz",
+ sha256: "3552534a0a558bbed6aed32b30c495cca23d567ec52cac8be1a0730e8010255c"
+ },
+ {
+ url: "http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz",
+ sha256: "8d422c7b0a1c1c79245a5bcf07fe86e33eeafee792b84584aec276f5a2dbc4e6"
+ },
+ {
+ url: "http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz",
+ sha256: "f7ae60f92e00ec6debd23a6088c31dbd2371eca3ffa0defaefb259924204aec6"
+ }
+ ]
+ end
+
+ def training_file
+ "training.pt"
+ end
+
+ def test_file
+ "test.pt"
+ end
+
  def unpack_mnist(path, offset, shape)
  path = File.join(raw_folder, "#{path}.gz")
  File.open(path, "rb") do |f|
@@ -101,45 +112,6 @@ module TorchVision
  Torch.tensor(Numo::UInt8.from_string(gz.read, shape))
  end
  end
-
- def download_file(url, download_root:, filename:, sha256:)
- FileUtils.mkdir_p(download_root)
-
- dest = File.join(download_root, filename)
- return dest if File.exist?(dest)
-
- temp_path = "#{Dir.tmpdir}/#{Time.now.to_f}" # TODO better name
-
- digest = Digest::SHA256.new
-
- uri = URI(url)
-
- # Net::HTTP automatically adds Accept-Encoding for compression
- # of response bodies and automatically decompresses gzip
- # and deflateresponses unless a Range header was sent.
- # https://ruby-doc.org/stdlib-2.6.4/libdoc/net/http/rdoc/Net/HTTP.html
- Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http|
- request = Net::HTTP::Get.new(uri)
-
- puts "Downloading #{url}..."
- File.open(temp_path, "wb") do |f|
- http.request(request) do |response|
- response.read_body do |chunk|
- f.write(chunk)
- digest.update(chunk)
- end
- end
- end
- end
-
- if digest.hexdigest != sha256
- raise Error, "Bad hash: #{digest.hexdigest}"
- end
-
- FileUtils.mv(temp_path, dest)
-
- dest
- end
  end
  end
  end
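Because the cache folders are now derived from the class name instead of the hard-coded "MNIST" string, each MNIST variant keeps its own raw/processed directories; for example (a small sketch of the path arithmetic above):

```ruby
fmnist = TorchVision::Datasets::FashionMNIST.new("./data", download: true)
fmnist.raw_folder        # => "./data/FashionMNIST/raw"
fmnist.processed_folder  # => "./data/FashionMNIST/processed"
```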
data/lib/torchvision/datasets/vision_dataset.rb ADDED
@@ -0,0 +1,66 @@
+ module TorchVision
+ module Datasets
+ # TODO inherit Torch::Utils::Data::Dataset
+ class VisionDataset
+ def initialize(root, transforms: nil, transform: nil, target_transform: nil)
+ @root = root
+
+ has_transforms = !transforms.nil?
+ has_separate_transform = !transform.nil? || !target_transform.nil?
+ if has_transforms && has_separate_transform
+ raise ArgumentError, "Only transforms or transform/target_transform can be passed as argument"
+ end
+
+ @transform = transform
+ @target_transform = target_transform
+
+ if has_separate_transform
+ # transforms = StandardTransform.new(transform, target_transform)
+ end
+ @transforms = transforms
+ end
+
+ private
+
+ def download_file(url, download_root:, filename:, sha256:)
+ FileUtils.mkdir_p(download_root)
+
+ dest = File.join(download_root, filename)
+ return dest if File.exist?(dest)
+
+ temp_path = "#{Dir.tmpdir}/#{Time.now.to_f}" # TODO better name
+
+ uri = URI(url)
+
+ # Net::HTTP automatically adds Accept-Encoding for compression
+ # of response bodies and automatically decompresses gzip
+ # and deflate responses unless a Range header was sent.
+ # https://ruby-doc.org/stdlib-2.6.4/libdoc/net/http/rdoc/Net/HTTP.html
+ Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http|
+ request = Net::HTTP::Get.new(uri)
+
+ puts "Downloading #{url}..."
+ File.open(temp_path, "wb") do |f|
+ http.request(request) do |response|
+ response.read_body do |chunk|
+ f.write(chunk)
+ end
+ end
+ end
+ end
+
+ unless check_integrity(temp_path, sha256)
+ raise Error, "Bad hash"
+ end
+
+ FileUtils.mv(temp_path, dest)
+
+ dest
+ end
+
+ def check_integrity(path, sha256)
+ File.exist?(path) && Digest::SHA256.file(path).hexdigest == sha256
+ end
+ end
+ end
+ end
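VisionDataset centralizes the download and checksum helpers that previously lived inside MNIST, so a new dataset only has to call them; a hypothetical subclass might look like this (the class name, URL, and digest below are illustrative, not part of the gem):

```ruby
module TorchVision
  module Datasets
    # Hypothetical example of reusing the shared helpers from VisionDataset
    class TinyImages < VisionDataset
      URL = "https://example.org/tiny-images.bin.gz"
      SHA256 = "0" * 64  # placeholder digest

      def initialize(root, download: false, transform: nil)
        super(root, transform: transform)
        # download_file writes to root/tiny-images.bin.gz and raises Error on a checksum mismatch
        download_file(URL, download_root: root, filename: "tiny-images.bin.gz", sha256: SHA256) if download
      end
    end
  end
end
```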
data/lib/torchvision/models/basic_block.rb ADDED
@@ -0,0 +1,46 @@
+ module TorchVision
+ module Models
+ class BasicBlock < Torch::NN::Module
+ def initialize(inplanes, planes, stride: 1, downsample: nil, groups: 1, base_width: 64, dilation: 1, norm_layer: nil)
+ super()
+ norm_layer ||= Torch::NN::BatchNorm2d
+ if groups != 1 || base_width != 64
+ raise ArgumentError, "BasicBlock only supports groups=1 and base_width=64"
+ end
+ if dilation > 1
+ raise NotImplementedError, "Dilation > 1 not supported in BasicBlock"
+ end
+ # Both self.conv1 and self.downsample layers downsample the input when stride != 1
+ @conv1 = Torch::NN::Conv2d.new(inplanes, planes, 3, stride: stride, padding: 1, groups: 1, bias: false, dilation: 1)
+ @bn1 = norm_layer.new(planes)
+ @relu = Torch::NN::ReLU.new(inplace: true)
+ @conv2 = Torch::NN::Conv2d.new(planes, planes, 3, stride: 1, padding: 1, groups: 1, bias: false, dilation: 1)
+ @bn2 = norm_layer.new(planes)
+ @downsample = downsample
+ @stride = stride
+ end
+
+ def forward(x)
+ identity = x
+
+ out = @conv1.call(x)
+ out = @bn1.call(out)
+ out = @relu.call(out)
+
+ out = @conv2.call(out)
+ out = @bn2.call(out)
+
+ identity = @downsample.call(x) if @downsample
+
+ out += identity
+ out = @relu.call(out)
+
+ out
+ end
+
+ def self.expansion
+ 1
+ end
+ end
+ end
+ end
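As a quick smoke test of the block above, a stride-1 block with matching input and output channels maps a feature map to one of the same shape (a minimal sketch; `Torch.randn` and `call` come from the torch-rb dependency):

```ruby
block = TorchVision::Models::BasicBlock.new(64, 64)
x = Torch.randn(1, 64, 56, 56)  # batch of one 64-channel feature map
y = block.call(x)               # residual output, same shape as x
```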
data/lib/torchvision/models/bottleneck.rb ADDED
@@ -0,0 +1,47 @@
+ module TorchVision
+ module Models
+ class Bottleneck < Torch::NN::Module
+ def initialize(inplanes, planes, stride: 1, downsample: nil, groups: 1, base_width: 64, dilation: 1, norm_layer: nil)
+ super()
+ norm_layer ||= Torch::NN::BatchNorm2d
+ width = (planes * (base_width / 64.0)).to_i * groups
+ # Both self.conv2 and self.downsample layers downsample the input when stride != 1
+ @conv1 = Torch::NN::Conv2d.new(inplanes, width, 1, stride: 1, bias: false)
+ @bn1 = norm_layer.new(width)
+ @conv2 = Torch::NN::Conv2d.new(width, width, 3, stride: stride, padding: dilation, groups: groups, bias: false, dilation: dilation)
+ @bn2 = norm_layer.new(width)
+ @conv3 = Torch::NN::Conv2d.new(width, planes * self.class.expansion, 1, stride: 1, bias: false)
+ @bn3 = norm_layer.new(planes * self.class.expansion)
+ @relu = Torch::NN::ReLU.new(inplace: true)
+ @downsample = downsample
+ @stride = stride
+ end
+
+ def forward(x)
+ identity = x
+
+ out = @conv1.call(x)
+ out = @bn1.call(out)
+ out = @relu.call(out)
+
+ out = @conv2.call(out)
+ out = @bn2.call(out)
+ out = @relu.call(out)
+
+ out = @conv3.call(out)
+ out = @bn3.call(out)
+
+ identity = @downsample.call(x) if @downsample
+
+ out += identity
+ out = @relu.call(out)
+
+ out
+ end
+
+ def self.expansion
+ 4
+ end
+ end
+ end
+ end
data/lib/torchvision/models/resnet.rb ADDED
@@ -0,0 +1,107 @@
+ module TorchVision
+ module Models
+ class ResNet < Torch::NN::Module
+ def initialize(block, layers, num_classes=1000, zero_init_residual: false,
+ groups: 1, width_per_group: 64, replace_stride_with_dilation: nil, norm_layer: nil)
+
+ super()
+ norm_layer ||= Torch::NN::BatchNorm2d
+ @norm_layer = norm_layer
+
+ @inplanes = 64
+ @dilation = 1
+ if replace_stride_with_dilation.nil?
+ # each element in the tuple indicates if we should replace
+ # the 2x2 stride with a dilated convolution instead
+ replace_stride_with_dilation = [false, false, false]
+ end
+ if replace_stride_with_dilation.length != 3
+ raise ArgumentError, "replace_stride_with_dilation should be nil or a 3-element tuple, got #{replace_stride_with_dilation}"
+ end
+ @groups = groups
+ @base_width = width_per_group
+ @conv1 = Torch::NN::Conv2d.new(3, @inplanes, 7, stride: 2, padding: 3, bias: false)
+ @bn1 = norm_layer.new(@inplanes)
+ @relu = Torch::NN::ReLU.new(inplace: true)
+ @maxpool = Torch::NN::MaxPool2d.new(3, stride: 2, padding: 1)
+ @layer1 = _make_layer(block, 64, layers[0])
+ @layer2 = _make_layer(block, 128, layers[1], stride: 2, dilate: replace_stride_with_dilation[0])
+ @layer3 = _make_layer(block, 256, layers[2], stride: 2, dilate: replace_stride_with_dilation[1])
+ @layer4 = _make_layer(block, 512, layers[3], stride: 2, dilate: replace_stride_with_dilation[2])
+ @avgpool = Torch::NN::AdaptiveAvgPool2d.new([1, 1])
+ @fc = Torch::NN::Linear.new(512 * block.expansion, num_classes)
+
+ modules.each do |m|
+ case m
+ when Torch::NN::Conv2d
+ Torch::NN::Init.kaiming_normal!(m.weight, mode: "fan_out", nonlinearity: "relu")
+ when Torch::NN::BatchNorm2d, Torch::NN::GroupNorm
+ Torch::NN::Init.constant!(m.weight, 1)
+ Torch::NN::Init.constant!(m.bias, 0)
+ end
+ end
+
+ # Zero-initialize the last BN in each residual branch,
+ # so that the residual branch starts with zeros, and each residual block behaves like an identity.
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
+ if zero_init_residual
+ modules.each do |m|
+ case m
+ when Bottleneck
+ Torch::NN::Init.constant!(m.bn3.weight, 0)
+ when BasicBlock
+ Torch::NN::Init.constant!(m.bn2.weight, 0)
+ end
+ end
+ end
+ end
+
+ def _make_layer(block, planes, blocks, stride: 1, dilate: false)
+ norm_layer = @norm_layer
+ downsample = nil
+ previous_dilation = @dilation
+ if dilate
+ @dilation *= stride
+ stride = 1
+ end
+ if stride != 1 || @inplanes != planes * block.expansion
+ downsample = Torch::NN::Sequential.new(
+ Torch::NN::Conv2d.new(@inplanes, planes * block.expansion, 1, stride: stride, bias: false),
+ norm_layer.new(planes * block.expansion)
+ )
+ end
+
+ layers = []
+ layers << block.new(@inplanes, planes, stride: stride, downsample: downsample, groups: @groups, base_width: @base_width, dilation: previous_dilation, norm_layer: norm_layer)
+ @inplanes = planes * block.expansion
+ (blocks - 1).times do
+ layers << block.new(@inplanes, planes, groups: @groups, base_width: @base_width, dilation: @dilation, norm_layer: norm_layer)
+ end
+
+ Torch::NN::Sequential.new(*layers)
+ end
+
+ def _forward_impl(x)
+ x = @conv1.call(x)
+ x = @bn1.call(x)
+ x = @relu.call(x)
+ x = @maxpool.call(x)
+
+ x = @layer1.call(x)
+ x = @layer2.call(x)
+ x = @layer3.call(x)
+ x = @layer4.call(x)
+
+ x = @avgpool.call(x)
+ x = Torch.flatten(x, 1)
+ x = @fc.call(x)
+
+ x
+ end
+
+ def forward(x)
+ _forward_impl(x)
+ end
+ end
+ end
+ end
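Only ResNet18 gets a named constructor in this release, but `ResNet.new` itself is generic; the standard deeper layouts would be built like this (illustrative calls, not shipped helpers):

```ruby
# ResNet-34-style network: BasicBlock with [3, 4, 6, 3] blocks per stage
model34 = TorchVision::Models::ResNet.new(TorchVision::Models::BasicBlock, [3, 4, 6, 3])

# ResNet-50-style network: Bottleneck blocks with the same stage counts
model50 = TorchVision::Models::ResNet.new(TorchVision::Models::Bottleneck, [3, 4, 6, 3])
```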
data/lib/torchvision/models/resnet18.rb ADDED
@@ -0,0 +1,15 @@
+ module TorchVision
+ module Models
+ module ResNet18
+ def self.new(pretrained: false, **kwargs)
+ model = ResNet.new(BasicBlock, [2, 2, 2, 2], **kwargs)
+ if pretrained
+ url = "https://download.pytorch.org/models/resnet18-5c106cde.pth"
+ state_dict = Torch::Hub.load_state_dict_from_url(url)
+ model.load_state_dict(state_dict)
+ end
+ model
+ end
+ end
+ end
+ end
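The wrapper above returns a plain `ResNet` instance, so a forward pass looks like this (a minimal sketch; `Torch.randn` is from the torch-rb dependency, and `pretrained: true` additionally requires downloading the PyTorch weights):

```ruby
model = TorchVision::Models::ResNet18.new  # randomly initialized weights
x = Torch.randn(1, 3, 224, 224)            # one 3-channel 224x224 image
logits = model.call(x)                     # tensor of shape [1, 1000]
```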
data/lib/torchvision/transforms/functional.rb CHANGED
@@ -22,18 +22,35 @@ module TorchVision
  if std.to_a.any? { |v| v == 0 }
  raise ArgumentError, "std evaluated to zero after conversion to #{dtype}, leading to division by zero."
  end
- # if mean.ndim == 1
- # raise Torch::NotImplementedYet
- # end
- # if std.ndim == 1
- # raise Torch::NotImplementedYet
- # end
+ if mean.ndim == 1
+ mean = mean[0...mean.size(0), nil, nil]
+ end
+ if std.ndim == 1
+ std = std[0...std.size(0), nil, nil]
+ end
  tensor.sub!(mean).div!(std)
  tensor
  end

  # TODO improve
  def to_tensor(pic)
+ if !pic.is_a?(Numo::NArray) && !pic.is_a?(Torch::Tensor)
+ raise ArgumentError, "pic should be tensor or Numo::NArray. Got #{pic.class.name}"
+ end
+
+ if pic.is_a?(Numo::NArray) && ![2, 3].include?(pic.ndim)
+ raise ArgumentError, "pic should be 2/3 dimensional. Got #{pic.ndim} dimensions."
+ end
+
+ if pic.is_a?(Numo::NArray)
+ if pic.ndim == 2
+ raise Torch::NotImplementedYet
+ end
+
+ img = Torch.from_numo(pic.transpose(2, 0, 1))
+ return img.float.div(255)
+ end
+
  pic = pic.float
  pic.unsqueeze!(0).div!(255)
  end
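The new Numo branch in `to_tensor` converts an HWC `Numo::NArray` image into a CHW float tensor scaled to [0, 1]; roughly (a minimal sketch, using a synthetic image in place of a real one):

```ruby
pic = Numo::UInt8.new(32, 32, 3).seq                      # fake 32x32 RGB image, HWC layout
tensor = TorchVision::Transforms::ToTensor.new.call(pic)  # shape [3, 32, 32], values in [0, 1]
```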
data/lib/torchvision/version.rb CHANGED
@@ -1,3 +1,3 @@
  module TorchVision
- VERSION = "0.1.1"
+ VERSION = "0.1.2"
  end
metadata CHANGED
@@ -1,14 +1,14 @@
  --- !ruby/object:Gem::Specification
  name: torchvision
  version: !ruby/object:Gem::Version
- version: 0.1.1
+ version: 0.1.2
  platform: ruby
  authors:
  - Andrew Kane
  autorequire:
  bindir: bin
  cert_chain: []
- date: 2020-04-28 00:00:00.000000000 Z
+ date: 2020-04-29 00:00:00.000000000 Z
  dependencies:
  - !ruby/object:Gem::Dependency
  name: numo-narray
@@ -30,14 +30,14 @@ dependencies:
  requirements:
  - - ">="
  - !ruby/object:Gem::Version
- version: 0.2.3
+ version: 0.2.4
  type: :runtime
  prerelease: false
  version_requirements: !ruby/object:Gem::Requirement
  requirements:
  - - ">="
  - !ruby/object:Gem::Version
- version: 0.2.3
+ version: 0.2.4
  - !ruby/object:Gem::Dependency
  name: bundler
  requirement: !ruby/object:Gem::Requirement
@@ -90,7 +90,16 @@ files:
  - LICENSE.txt
  - README.md
  - lib/torchvision.rb
+ - lib/torchvision/datasets/cifar10.rb
+ - lib/torchvision/datasets/cifar100.rb
+ - lib/torchvision/datasets/fashion_mnist.rb
+ - lib/torchvision/datasets/kmnist.rb
  - lib/torchvision/datasets/mnist.rb
+ - lib/torchvision/datasets/vision_dataset.rb
+ - lib/torchvision/models/basic_block.rb
+ - lib/torchvision/models/bottleneck.rb
+ - lib/torchvision/models/resnet.rb
+ - lib/torchvision/models/resnet18.rb
  - lib/torchvision/transforms/compose.rb
  - lib/torchvision/transforms/functional.rb
  - lib/torchvision/transforms/normalize.rb