torchvision 0.1.1 → 0.1.2

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: c6ae301007e3310dfc378a7f76d24bd71cc048be6156b32d9e0705412817c565
4
- data.tar.gz: 423aee54316683ddf9d3893e04f8836dd16ade1408f6c046f6d8f6f337a09a46
3
+ metadata.gz: 107e429f990a063e57f6218ee6d4fbed3cff80aa6746f0794798d59e1c13b099
4
+ data.tar.gz: 556fdc4c413803d415ea5575747ec2239f18b0dff9b39a37a4fd6366adec37a6
5
5
  SHA512:
6
- metadata.gz: cf9c5a514c2ef42299161f97f077897c11b4708858018ffefcb282d6434003d49555ba54ea93ef0fb9e2f3cafdd7bd4395677e02bdfa5b1dcad4772aaeccc896
7
- data.tar.gz: 0e01311267e620be6dfbed77724baed42ad9c597a2b5452d762845597ae257932e76848e0069149598743d78dc08b0374b18897ce9a272640eed4a3cec21bbd9
6
+ metadata.gz: 21b712578516c146888be30bed64a6da6339a42974f5c88fa685278caa9231e4bc4b75e250af3ad37d5450cd205891b31b281cf23fa362456c7ae00998c2736d
7
+ data.tar.gz: ac86f13e8b5d6a400842ba37b1bf593139360a8df9a19ab4674e71a4327c821a6d5735774185dcbaed368ee735b00bf33d9faedb13d46cd168314da89d900c38
@@ -1,3 +1,8 @@
1
+ ## 0.1.2 (2020-04-29)
2
+
3
+ - Added CIFAR10, CIFAR100, FashionMNIST, and KMNIST datasets
4
+ - Added ResNet18 model
5
+
1
6
  ## 0.1.1 (2020-04-28)
2
7
 
3
8
  - Removed `mini_magick` for performance
data/README.md CHANGED
@@ -20,11 +20,33 @@ This library follows the [Python API](https://pytorch.org/docs/master/torchvisio
20
20
 
21
21
  ## Datasets
22
22
 
23
- MNIST dataset
23
+ Load a dataset
24
24
 
25
25
  ```ruby
26
- trainset = TorchVision::Datasets::MNIST.new("./data", train: true, download: true)
27
- trainset.size
26
+ TorchVision::Datasets::MNIST.new("./data", train: true, download: true)
27
+ ```
28
+
29
+ Supported datasets are:
30
+
31
+ - CIFAR10
32
+ - CIFAR100
33
+ - FashionMNIST
34
+ - KMNIST
35
+ - MNIST
36
+
37
+ ## Transforms
38
+
39
+ ```ruby
40
+ TorchVision::Transforms::Compose.new([
41
+ TorchVision::Transforms::ToTensor.new,
42
+ TorchVision::Transforms::Normalize.new([0.1307], [0.3081])
43
+ ])
44
+ ```
45
+
46
+ ## Models
47
+
48
+ ```ruby
49
+ TorchVision::Models::ResNet18.new
28
50
  ```
29
51
 
30
52
  ## Disclaimer
@@ -6,13 +6,25 @@ require "torch"
6
6
  require "digest"
7
7
  require "fileutils"
8
8
  require "net/http"
9
+ require "rubygems/package"
9
10
  require "tmpdir"
10
11
 
11
12
  # modules
12
13
  require "torchvision/version"
13
14
 
14
15
  # datasets
16
+ require "torchvision/datasets/vision_dataset"
17
+ require "torchvision/datasets/cifar10"
18
+ require "torchvision/datasets/cifar100"
15
19
  require "torchvision/datasets/mnist"
20
+ require "torchvision/datasets/fashion_mnist"
21
+ require "torchvision/datasets/kmnist"
22
+
23
+ # models
24
+ require "torchvision/models/basic_block"
25
+ require "torchvision/models/bottleneck"
26
+ require "torchvision/models/resnet"
27
+ require "torchvision/models/resnet18"
16
28
 
17
29
  # transforms
18
30
  require "torchvision/transforms/compose"
@@ -0,0 +1,116 @@
1
+ module TorchVision
2
+ module Datasets
3
+ class CIFAR10 < VisionDataset
4
+ # https://www.cs.toronto.edu/~kriz/cifar.html
5
+
6
+ def initialize(root, train: true, download: false, transform: nil, target_transform: nil)
7
+ super(root, transform: transform, target_transform: target_transform)
8
+ @train = train
9
+
10
+ self.download if download
11
+
12
+ if !_check_integrity
13
+ raise Error, "Dataset not found or corrupted. You can use download: true to download it"
14
+ end
15
+
16
+ downloaded_list = @train ? train_list : test_list
17
+
18
+ @data = String.new
19
+ @targets = String.new
20
+
21
+ downloaded_list.each do |file|
22
+ file_path = File.join(@root, base_folder, file[:filename])
23
+ File.open(file_path, "rb") do |f|
24
+ while !f.eof?
25
+ f.read(1) if multiple_labels?
26
+ @targets << f.read(1)
27
+ @data << f.read(3072)
28
+ end
29
+ end
30
+ end
31
+
32
+ @targets = @targets.unpack("C*")
33
+ # TODO switch i to -1 when Numo supports it
34
+ @data = Numo::UInt8.from_binary(@data).reshape(@targets.size, 3, 32, 32)
35
+ @data = @data.transpose(0, 2, 3, 1)
36
+ end
37
+
38
+ def size
39
+ @data.shape[0]
40
+ end
41
+
42
+ def [](index)
43
+ # TODO remove trues when Numo supports it
44
+ img, target = @data[index, true, true, true], @targets[index]
45
+
46
+ # TODO convert to image
47
+ img = @transform.call(img) if @transform
48
+
49
+ target = @target_transform.call(target) if @target_transform
50
+
51
+ [img, target]
52
+ end
53
+
54
+ def _check_integrity
55
+ root = @root
56
+ (train_list + test_list).each do |fentry|
57
+ fpath = File.join(root, base_folder, fentry[:filename])
58
+ return false unless check_integrity(fpath, fentry[:sha256])
59
+ end
60
+ true
61
+ end
62
+
63
+ def download
64
+ if _check_integrity
65
+ puts "Files already downloaded and verified"
66
+ return
67
+ end
68
+
69
+ download_file(url, download_root: @root, filename: filename, sha256: tgz_sha256)
70
+
71
+ path = File.join(@root, filename)
72
+ File.open(path, "rb") do |io|
73
+ Gem::Package.new("").extract_tar_gz(io, @root)
74
+ end
75
+ end
76
+
77
+ private
78
+
79
+ def base_folder
80
+ "cifar-10-batches-bin"
81
+ end
82
+
83
+ def url
84
+ "https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz"
85
+ end
86
+
87
+ def filename
88
+ "cifar-10-binary.tar.gz"
89
+ end
90
+
91
+ def tgz_sha256
92
+ "c4a38c50a1bc5f3a1c5537f2155ab9d68f9f25eb1ed8d9ddda3db29a59bca1dd"
93
+ end
94
+
95
+ def train_list
96
+ [
97
+ {filename: "data_batch_1.bin", sha256: "cee916563c9f80d84e3cc88e17fdc0941787f1244f00a67874d45b261883ada5"},
98
+ {filename: "data_batch_2.bin", sha256: "a591ca11fa1708a91ee40f54b3da4784ccd871ecf2137de63f51ada8b3fa57ed"},
99
+ {filename: "data_batch_3.bin", sha256: "bbe8596564c0f86427f876058170b84dac6670ddf06d79402899d93ceea26f67"},
100
+ {filename: "data_batch_4.bin", sha256: "014e562d6e23c72197cc727519169a60359f5eccd8945ad5a09d710285ff4e48"},
101
+ {filename: "data_batch_5.bin", sha256: "755304fc0b379caeae8c14f0dac912fbc7d6cd469eb67a1029a08a39453a9add"},
102
+ ]
103
+ end
104
+
105
+ def test_list
106
+ [
107
+ {filename: "test_batch.bin", sha256: "8e2eb146ae340b09e24670f29cabc6326dba54da8789dab6768acf480273f65b"}
108
+ ]
109
+ end
110
+
111
+ def multiple_labels?
112
+ false
113
+ end
114
+ end
115
+ end
116
+ end
@@ -0,0 +1,41 @@
1
+ module TorchVision
2
+ module Datasets
3
+ class CIFAR100 < CIFAR10
4
+ # https://www.cs.toronto.edu/~kriz/cifar.html
5
+
6
+ private
7
+
8
+ def base_folder
9
+ "cifar-100-binary"
10
+ end
11
+
12
+ def url
13
+ "https://www.cs.toronto.edu/~kriz/cifar-100-binary.tar.gz"
14
+ end
15
+
16
+ def filename
17
+ "cifar-100-binary.tar.gz"
18
+ end
19
+
20
+ def tgz_sha256
21
+ "58a81ae192c23a4be8b1804d68e518ed807d710a4eb253b1f2a199162a40d8ec"
22
+ end
23
+
24
+ def train_list
25
+ [
26
+ {filename: "train.bin", sha256: "f31298fc616915fa142368359df1c4ca2ae984d6915ca468b998a5ec6aeebf29"}
27
+ ]
28
+ end
29
+
30
+ def test_list
31
+ [
32
+ {filename: "test.bin", sha256: "d8b1e6b7b3bee4020055f0699b111f60b1af1e262aeb93a0b659061746f8224a"}
33
+ ]
34
+ end
35
+
36
+ def multiple_labels?
37
+ true
38
+ end
39
+ end
40
+ end
41
+ end
@@ -0,0 +1,30 @@
1
+ module TorchVision
2
+ module Datasets
3
+ class FashionMNIST < MNIST
4
+ # https://github.com/zalandoresearch/fashion-mnist
5
+
6
+ private
7
+
8
+ def resources
9
+ [
10
+ {
11
+ url: "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz",
12
+ sha256: "3aede38d61863908ad78613f6a32ed271626dd12800ba2636569512369268a84"
13
+ },
14
+ {
15
+ url: "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz",
16
+ sha256: "a04f17134ac03560a47e3764e11b92fc97de4d1bfaf8ba1a3aa29af54cc90845"
17
+ },
18
+ {
19
+ url: "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz",
20
+ sha256: "346e55b948d973a97e58d2351dde16a484bd415d4595297633bb08f03db6a073"
21
+ },
22
+ {
23
+ url: "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz",
24
+ sha256: "67da17c76eaffca5446c3361aaab5c3cd6d1c2608764d35dfb1850b086bf8dd5"
25
+ }
26
+ ]
27
+ end
28
+ end
29
+ end
30
+ end
@@ -0,0 +1,30 @@
1
+ module TorchVision
2
+ module Datasets
3
+ class KMNIST < MNIST
4
+ # https://github.com/rois-codh/kmnist
5
+
6
+ private
7
+
8
+ def resources
9
+ [
10
+ {
11
+ url: "http://codh.rois.ac.jp/kmnist/dataset/kmnist/train-images-idx3-ubyte.gz",
12
+ sha256: "51467d22d8cc72929e2a028a0428f2086b092bb31cfb79c69cc0a90ce135fde4"
13
+ },
14
+ {
15
+ url: "http://codh.rois.ac.jp/kmnist/dataset/kmnist/train-labels-idx1-ubyte.gz",
16
+ sha256: "e38f9ebcd0f3ebcdec7fc8eabdcdaef93bb0df8ea12bee65224341c8183d8e17"
17
+ },
18
+ {
19
+ url: "http://codh.rois.ac.jp/kmnist/dataset/kmnist/t10k-images-idx3-ubyte.gz",
20
+ sha256: "edd7a857845ad6bb1d0ba43fe7e794d164fe2dce499a1694695a792adfac43c5"
21
+ },
22
+ {
23
+ url: "http://codh.rois.ac.jp/kmnist/dataset/kmnist/t10k-labels-idx1-ubyte.gz",
24
+ sha256: "20bb9a0ef54c7db3efc55a92eef5582c109615df22683c380526788f98e42a1c"
25
+ }
26
+ ]
27
+ end
28
+ end
29
+ end
30
+ end
@@ -1,31 +1,11 @@
1
1
  module TorchVision
2
2
  module Datasets
3
- class MNIST
4
- RESOURCES = [
5
- {
6
- url: "http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz",
7
- sha256: "440fcabf73cc546fa21475e81ea370265605f56be210a4024d2ca8f203523609"
8
- },
9
- {
10
- url: "http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz",
11
- sha256: "3552534a0a558bbed6aed32b30c495cca23d567ec52cac8be1a0730e8010255c"
12
- },
13
- {
14
- url: "http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz",
15
- sha256: "8d422c7b0a1c1c79245a5bcf07fe86e33eeafee792b84584aec276f5a2dbc4e6"
16
- },
17
- {
18
- url: "http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz",
19
- sha256: "f7ae60f92e00ec6debd23a6088c31dbd2371eca3ffa0defaefb259924204aec6"
20
- }
21
- ]
22
- TRAINING_FILE = "training.pt"
23
- TEST_FILE = "test.pt"
24
-
25
- def initialize(root, train: true, download: false, transform: nil)
26
- @root = root
3
+ class MNIST < VisionDataset
4
+ # http://yann.lecun.com/exdb/mnist/
5
+
6
+ def initialize(root, train: true, download: false, transform: nil, target_transform: nil)
7
+ super(root, transform: transform, target_transform: target_transform)
27
8
  @train = train
28
- @transform = transform
29
9
 
30
10
  self.download if download
31
11
 
@@ -33,34 +13,36 @@ module TorchVision
33
13
  raise Error, "Dataset not found. You can use download: true to download it"
34
14
  end
35
15
 
36
- data_file = @train ? TRAINING_FILE : TEST_FILE
16
+ data_file = @train ? training_file : test_file
37
17
  @data, @targets = Torch.load(File.join(processed_folder, data_file))
38
18
  end
39
19
 
40
20
  def size
41
- @data.size[0]
21
+ @data.size(0)
42
22
  end
43
23
 
44
24
  def [](index)
45
- img = @data[index]
25
+ img, target = @data[index], @targets[index].item
26
+
27
+ # TODO convert to image
46
28
  img = @transform.call(img) if @transform
47
29
 
48
- target = @targets[index].item
30
+ target = @target_transform.call(target) if @target_transform
49
31
 
50
32
  [img, target]
51
33
  end
52
34
 
53
35
  def raw_folder
54
- File.join(@root, "MNIST", "raw")
36
+ File.join(@root, self.class.name.split("::").last, "raw")
55
37
  end
56
38
 
57
39
  def processed_folder
58
- File.join(@root, "MNIST", "processed")
40
+ File.join(@root, self.class.name.split("::").last, "processed")
59
41
  end
60
42
 
61
43
  def check_exists
62
- File.exist?(File.join(processed_folder, TRAINING_FILE)) &&
63
- File.exist?(File.join(processed_folder, TEST_FILE))
44
+ File.exist?(File.join(processed_folder, training_file)) &&
45
+ File.exist?(File.join(processed_folder, test_file))
64
46
  end
65
47
 
66
48
  def download
@@ -69,7 +51,7 @@ module TorchVision
69
51
  FileUtils.mkdir_p(raw_folder)
70
52
  FileUtils.mkdir_p(processed_folder)
71
53
 
72
- RESOURCES.each do |resource|
54
+ resources.each do |resource|
73
55
  filename = resource[:url].split("/").last
74
56
  download_file(resource[:url], download_root: raw_folder, filename: filename, sha256: resource[:sha256])
75
57
  end
@@ -85,14 +67,43 @@ module TorchVision
85
67
  unpack_mnist("t10k-labels-idx1-ubyte", 8, [10000])
86
68
  ]
87
69
 
88
- Torch.save(training_set, File.join(processed_folder, TRAINING_FILE))
89
- Torch.save(test_set, File.join(processed_folder, TEST_FILE))
70
+ Torch.save(training_set, File.join(processed_folder, training_file))
71
+ Torch.save(test_set, File.join(processed_folder, test_file))
90
72
 
91
73
  puts "Done!"
92
74
  end
93
75
 
94
76
  private
95
77
 
78
+ def resources
79
+ [
80
+ {
81
+ url: "http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz",
82
+ sha256: "440fcabf73cc546fa21475e81ea370265605f56be210a4024d2ca8f203523609"
83
+ },
84
+ {
85
+ url: "http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz",
86
+ sha256: "3552534a0a558bbed6aed32b30c495cca23d567ec52cac8be1a0730e8010255c"
87
+ },
88
+ {
89
+ url: "http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz",
90
+ sha256: "8d422c7b0a1c1c79245a5bcf07fe86e33eeafee792b84584aec276f5a2dbc4e6"
91
+ },
92
+ {
93
+ url: "http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz",
94
+ sha256: "f7ae60f92e00ec6debd23a6088c31dbd2371eca3ffa0defaefb259924204aec6"
95
+ }
96
+ ]
97
+ end
98
+
99
+ def training_file
100
+ "training.pt"
101
+ end
102
+
103
+ def test_file
104
+ "test.pt"
105
+ end
106
+
96
107
  def unpack_mnist(path, offset, shape)
97
108
  path = File.join(raw_folder, "#{path}.gz")
98
109
  File.open(path, "rb") do |f|
@@ -101,45 +112,6 @@ module TorchVision
101
112
  Torch.tensor(Numo::UInt8.from_string(gz.read, shape))
102
113
  end
103
114
  end
104
-
105
- def download_file(url, download_root:, filename:, sha256:)
106
- FileUtils.mkdir_p(download_root)
107
-
108
- dest = File.join(download_root, filename)
109
- return dest if File.exist?(dest)
110
-
111
- temp_path = "#{Dir.tmpdir}/#{Time.now.to_f}" # TODO better name
112
-
113
- digest = Digest::SHA256.new
114
-
115
- uri = URI(url)
116
-
117
- # Net::HTTP automatically adds Accept-Encoding for compression
118
- # of response bodies and automatically decompresses gzip
119
- # and deflateresponses unless a Range header was sent.
120
- # https://ruby-doc.org/stdlib-2.6.4/libdoc/net/http/rdoc/Net/HTTP.html
121
- Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http|
122
- request = Net::HTTP::Get.new(uri)
123
-
124
- puts "Downloading #{url}..."
125
- File.open(temp_path, "wb") do |f|
126
- http.request(request) do |response|
127
- response.read_body do |chunk|
128
- f.write(chunk)
129
- digest.update(chunk)
130
- end
131
- end
132
- end
133
- end
134
-
135
- if digest.hexdigest != sha256
136
- raise Error, "Bad hash: #{digest.hexdigest}"
137
- end
138
-
139
- FileUtils.mv(temp_path, dest)
140
-
141
- dest
142
- end
143
115
  end
144
116
  end
145
117
  end
@@ -0,0 +1,66 @@
1
+ module TorchVision
2
+ module Datasets
3
+ # TODO inherit Torch::Utils::Data::Dataset
4
+ class VisionDataset
5
+ def initialize(root, transforms: nil, transform: nil, target_transform: nil)
6
+ @root = root
7
+
8
+ has_transforms = !transforms.nil?
9
+ has_separate_transform = !transform.nil? || !target_transform.nil?
10
+ if has_transforms && has_separate_transform
11
+ raise ArgumentError, "Only transforms or transform/target_transform can be passed as argument"
12
+ end
13
+
14
+ @transform = transform
15
+ @target_transform = target_transform
16
+
17
+ if has_separate_transform
18
+ # transforms = StandardTransform.new(transform, target_transform)
19
+ end
20
+ @transforms = transforms
21
+ end
22
+
23
+ private
24
+
25
+ def download_file(url, download_root:, filename:, sha256:)
26
+ FileUtils.mkdir_p(download_root)
27
+
28
+ dest = File.join(download_root, filename)
29
+ return dest if File.exist?(dest)
30
+
31
+ temp_path = "#{Dir.tmpdir}/#{Time.now.to_f}" # TODO better name
32
+
33
+ uri = URI(url)
34
+
35
+ # Net::HTTP automatically adds Accept-Encoding for compression
36
+ # of response bodies and automatically decompresses gzip
37
+ # and deflate responses unless a Range header was sent.
38
+ # https://ruby-doc.org/stdlib-2.6.4/libdoc/net/http/rdoc/Net/HTTP.html
39
+ Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http|
40
+ request = Net::HTTP::Get.new(uri)
41
+
42
+ puts "Downloading #{url}..."
43
+ File.open(temp_path, "wb") do |f|
44
+ http.request(request) do |response|
45
+ response.read_body do |chunk|
46
+ f.write(chunk)
47
+ end
48
+ end
49
+ end
50
+ end
51
+
52
+ unless check_integrity(temp_path, sha256)
53
+ raise Error, "Bad hash"
54
+ end
55
+
56
+ FileUtils.mv(temp_path, dest)
57
+
58
+ dest
59
+ end
60
+
61
+ def check_integrity(path, sha256)
62
+ File.exist?(path) && Digest::SHA256.file(path).hexdigest == sha256
63
+ end
64
+ end
65
+ end
66
+ end
@@ -0,0 +1,46 @@
1
+ module TorchVision
2
+ module Models
3
+ class BasicBlock < Torch::NN::Module
4
+ def initialize(inplanes, planes, stride: 1, downsample: nil, groups: 1, base_width: 64, dilation: 1, norm_layer: nil)
5
+ super()
6
+ norm_layer ||= Torch::NN::BatchNorm2d
7
+ if groups != 1 || base_width != 64
8
+ raise ArgumentError, "BasicBlock only supports groups=1 and base_width=64"
9
+ end
10
+ if dilation > 1
11
+ raise NotImplementedError, "Dilation > 1 not supported in BasicBlock"
12
+ end
13
+ # Both self.conv1 and self.downsample layers downsample the input when stride != 1
14
+ @conv1 = Torch::NN::Conv2d.new(inplanes, planes, 3, stride: stride, padding: 1, groups: 1, bias: false, dilation: 1)
15
+ @bn1 = norm_layer.new(planes)
16
+ @relu = Torch::NN::ReLU.new(inplace: true)
17
+ @conv2 = Torch::NN::Conv2d.new(planes, planes, 3, stride: 1, padding: 1, groups: 1, bias: false, dilation: 1)
18
+ @bn2 = norm_layer.new(planes)
19
+ @downsample = downsample
20
+ @stride = stride
21
+ end
22
+
23
+ def forward(x)
24
+ identity = x
25
+
26
+ out = @conv1.call(x)
27
+ out = @bn1.call(out)
28
+ out = @relu.call(out)
29
+
30
+ out = @conv2.call(out)
31
+ out = @bn2.call(out)
32
+
33
+ identity = @downsample.call(x) if @downsample
34
+
35
+ out += identity
36
+ out = @relu.call(out)
37
+
38
+ out
39
+ end
40
+
41
+ def self.expansion
42
+ 1
43
+ end
44
+ end
45
+ end
46
+ end
@@ -0,0 +1,47 @@
1
+ module TorchVision
2
+ module Models
3
+ class Bottleneck < Torch::NN::Module
4
+ def initialize(inplanes, planes, stride: 1, downsample: nil, groups: 1, base_width: 64, dilation: 1, norm_layer: nil)
5
+ super()
6
+ norm_layer ||= Torch::NN::BatchNorm2d
7
+ width = (planes * (base_width / 64.0)).to_i * groups
8
+ # Both self.conv2 and self.downsample layers downsample the input when stride != 1
9
+ @conv1 = Torch::NN::Conv2d.new(inplanes, width, 1, stride: 1, bias: false)
10
+ @bn1 = norm_layer.new(width)
11
+ @conv2 = Torch::NN::Conv2d.new(width, width, 3, stride: stride, padding: dilation, groups: groups, bias: false, dilation: dilation)
12
+ @bn2 = norm_layer.new(width)
13
+ @conv3 = Torch::NN::Conv2d.new(width, planes * self.class.expansion, 1, stride: 1, bias: false)
14
+ @bn3 = norm_layer.new(planes * self.class.expansion)
15
+ @relu = Torch::NN::ReLU.new(inplace: true)
16
+ @downsample = downsample
17
+ @stride = stride
18
+ end
19
+
20
+ def forward(x)
21
+ identity = x
22
+
23
+ out = @conv1.call(x)
24
+ out = @bn1.call(out)
25
+ out = @relu.call(out)
26
+
27
+ out = @conv2.call(out)
28
+ out = @bn2.call(out)
29
+ out = @relu.call(out)
30
+
31
+ out = @conv3.call(out)
32
+ out = @bn3.call(out)
33
+
34
+ identity = @downsample.call(x) if @downsample
35
+
36
+ out += identity
37
+ out = @relu.call(out)
38
+
39
+ out
40
+ end
41
+
42
+ def self.expansion
43
+ 4
44
+ end
45
+ end
46
+ end
47
+ end
@@ -0,0 +1,107 @@
1
+ module TorchVision
2
+ module Models
3
+ class ResNet < Torch::NN::Module
4
+ def initialize(block, layers, num_classes=1000, zero_init_residual: false,
5
+ groups: 1, width_per_group: 64, replace_stride_with_dilation: nil, norm_layer: nil)
6
+
7
+ super()
8
+ norm_layer ||= Torch::NN::BatchNorm2d
9
+ @norm_layer = norm_layer
10
+
11
+ @inplanes = 64
12
+ @dilation = 1
13
+ if replace_stride_with_dilation.nil?
14
+ # each element in the tuple indicates if we should replace
15
+ # the 2x2 stride with a dilated convolution instead
16
+ replace_stride_with_dilation = [false, false, false]
17
+ end
18
+ if replace_stride_with_dilation.length != 3
19
+ raise ArgumentError, "replace_stride_with_dilation should be nil or a 3-element tuple, got #{replace_stride_with_dilation}"
20
+ end
21
+ @groups = groups
22
+ @base_width = width_per_group
23
+ @conv1 = Torch::NN::Conv2d.new(3, @inplanes, 7, stride: 2, padding: 3, bias: false)
24
+ @bn1 = norm_layer.new(@inplanes)
25
+ @relu = Torch::NN::ReLU.new(inplace: true)
26
+ @maxpool = Torch::NN::MaxPool2d.new(3, stride: 2, padding: 1)
27
+ @layer1 = _make_layer(block, 64, layers[0])
28
+ @layer2 = _make_layer(block, 128, layers[1], stride: 2, dilate: replace_stride_with_dilation[0])
29
+ @layer3 = _make_layer(block, 256, layers[2], stride: 2, dilate: replace_stride_with_dilation[1])
30
+ @layer4 = _make_layer(block, 512, layers[3], stride: 2, dilate: replace_stride_with_dilation[2])
31
+ @avgpool = Torch::NN::AdaptiveAvgPool2d.new([1, 1])
32
+ @fc = Torch::NN::Linear.new(512 * block.expansion, num_classes)
33
+
34
+ modules.each do |m|
35
+ case m
36
+ when Torch::NN::Conv2d
37
+ Torch::NN::Init.kaiming_normal!(m.weight, mode: "fan_out", nonlinearity: "relu")
38
+ when Torch::NN::BatchNorm2d, Torch::NN::GroupNorm
39
+ Torch::NN::Init.constant!(m.weight, 1)
40
+ Torch::NN::Init.constant!(m.bias, 0)
41
+ end
42
+ end
43
+
44
+ # Zero-initialize the last BN in each residual branch,
45
+ # so that the residual branch starts with zeros, and each residual block behaves like an identity.
46
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
47
+ if zero_init_residual
48
+ modules.each do |m|
49
+ case m
50
+ when Bottleneck
51
+ Torch::NN::Init.constant!(m.bn3.weight, 0)
52
+ when BasicBlock
53
+ Torch::NN::Init.constant!(m.bn2.weight, 0)
54
+ end
55
+ end
56
+ end
57
+ end
58
+
59
+ def _make_layer(block, planes, blocks, stride: 1, dilate: false)
60
+ norm_layer = @norm_layer
61
+ downsample = nil
62
+ previous_dilation = @dilation
63
+ if dilate
64
+ @dilation *= stride
65
+ stride = 1
66
+ end
67
+ if stride != 1 || @inplanes != planes * block.expansion
68
+ downsample = Torch::NN::Sequential.new(
69
+ Torch::NN::Conv2d.new(@inplanes, planes * block.expansion, 1, stride: stride, bias: false),
70
+ norm_layer.new(planes * block.expansion)
71
+ )
72
+ end
73
+
74
+ layers = []
75
+ layers << block.new(@inplanes, planes, stride: stride, downsample: downsample, groups: @groups, base_width: @base_width, dilation: previous_dilation, norm_layer: norm_layer)
76
+ @inplanes = planes * block.expansion
77
+ (blocks - 1).times do
78
+ layers << block.new(@inplanes, planes, groups: @groups, base_width: @base_width, dilation: @dilation, norm_layer: norm_layer)
79
+ end
80
+
81
+ Torch::NN::Sequential.new(*layers)
82
+ end
83
+
84
+ def _forward_impl(x)
85
+ x = @conv1.call(x)
86
+ x = @bn1.call(x)
87
+ x = @relu.call(x)
88
+ x = @maxpool.call(x)
89
+
90
+ x = @layer1.call(x)
91
+ x = @layer2.call(x)
92
+ x = @layer3.call(x)
93
+ x = @layer4.call(x)
94
+
95
+ x = @avgpool.call(x)
96
+ x = Torch.flatten(x, 1)
97
+ x = @fc.call(x)
98
+
99
+ x
100
+ end
101
+
102
+ def forward(x)
103
+ _forward_impl(x)
104
+ end
105
+ end
106
+ end
107
+ end
@@ -0,0 +1,15 @@
1
+ module TorchVision
2
+ module Models
3
+ module ResNet18
4
+ def self.new(pretrained: false, **kwargs)
5
+ model = ResNet.new(BasicBlock, [2, 2, 2, 2], **kwargs)
6
+ if pretrained
7
+ url = "https://download.pytorch.org/models/resnet18-5c106cde.pth"
8
+ state_dict = Torch::Hub.load_state_dict_from_url(url)
9
+ model.load_state_dict(state_dict)
10
+ end
11
+ model
12
+ end
13
+ end
14
+ end
15
+ end
@@ -22,18 +22,35 @@ module TorchVision
22
22
  if std.to_a.any? { |v| v == 0 }
23
23
  raise ArgumentError, "std evaluated to zero after conversion to #{dtype}, leading to division by zero."
24
24
  end
25
- # if mean.ndim == 1
26
- # raise Torch::NotImplementedYet
27
- # end
28
- # if std.ndim == 1
29
- # raise Torch::NotImplementedYet
30
- # end
25
+ if mean.ndim == 1
26
+ mean = mean[0...mean.size(0), nil, nil]
27
+ end
28
+ if std.ndim == 1
29
+ std = std[0...std.size(0), nil, nil]
30
+ end
31
31
  tensor.sub!(mean).div!(std)
32
32
  tensor
33
33
  end
34
34
 
35
35
  # TODO improve
36
36
  def to_tensor(pic)
37
+ if !pic.is_a?(Numo::NArray) && !pic.is_a?(Torch::Tensor)
38
+ raise ArgumentError, "pic should be tensor or Numo::NArray. Got #{pic.class.name}"
39
+ end
40
+
41
+ if pic.is_a?(Numo::NArray) && ![2, 3].include?(pic.ndim)
42
+ raise ArgumentError, "pic should be 2/3 dimensional. Got #{pic.ndim} dimensions."
43
+ end
44
+
45
+ if pic.is_a?(Numo::NArray)
46
+ if pic.ndim == 2
47
+ raise Torch::NotImplementedYet
48
+ end
49
+
50
+ img = Torch.from_numo(pic.transpose(2, 0, 1))
51
+ return img.float.div(255)
52
+ end
53
+
37
54
  pic = pic.float
38
55
  pic.unsqueeze!(0).div!(255)
39
56
  end
@@ -1,3 +1,3 @@
1
1
  module TorchVision
2
- VERSION = "0.1.1"
2
+ VERSION = "0.1.2"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torchvision
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.1.1
4
+ version: 0.1.2
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2020-04-28 00:00:00.000000000 Z
11
+ date: 2020-04-29 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray
@@ -30,14 +30,14 @@ dependencies:
30
30
  requirements:
31
31
  - - ">="
32
32
  - !ruby/object:Gem::Version
33
- version: 0.2.3
33
+ version: 0.2.4
34
34
  type: :runtime
35
35
  prerelease: false
36
36
  version_requirements: !ruby/object:Gem::Requirement
37
37
  requirements:
38
38
  - - ">="
39
39
  - !ruby/object:Gem::Version
40
- version: 0.2.3
40
+ version: 0.2.4
41
41
  - !ruby/object:Gem::Dependency
42
42
  name: bundler
43
43
  requirement: !ruby/object:Gem::Requirement
@@ -90,7 +90,16 @@ files:
90
90
  - LICENSE.txt
91
91
  - README.md
92
92
  - lib/torchvision.rb
93
+ - lib/torchvision/datasets/cifar10.rb
94
+ - lib/torchvision/datasets/cifar100.rb
95
+ - lib/torchvision/datasets/fashion_mnist.rb
96
+ - lib/torchvision/datasets/kmnist.rb
93
97
  - lib/torchvision/datasets/mnist.rb
98
+ - lib/torchvision/datasets/vision_dataset.rb
99
+ - lib/torchvision/models/basic_block.rb
100
+ - lib/torchvision/models/bottleneck.rb
101
+ - lib/torchvision/models/resnet.rb
102
+ - lib/torchvision/models/resnet18.rb
94
103
  - lib/torchvision/transforms/compose.rb
95
104
  - lib/torchvision/transforms/functional.rb
96
105
  - lib/torchvision/transforms/normalize.rb