torchvision 0.1.0

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml ADDED
@@ -0,0 +1,7 @@
1
+ ---
2
+ SHA256:
3
+ metadata.gz: 19fff61bb461e5fbf850702bf485a442b2013807ac28549bc427e5f3d8c7472b
4
+ data.tar.gz: bc8004d2ca26e9022f2fa2b1663277bcedf701729f59aae248058a26e605a5ad
5
+ SHA512:
6
+ metadata.gz: fbd3d7292efa6ee2fd2c0ff8cb85659d37a19761b4d93a9a4923a9990d7400c849738913db12720e7d232d6fdb180c16f06da0ecc3601a922bf0036beb0b44bd
7
+ data.tar.gz: '09b86d6b01f25d43ac65d9c3d0509b3488ed2108d57090553ae46526f841551cbdd16fefce1cb15ea7f326d946e3abb26d777c3c45e23538f8fb7753fdb6fec9'
data/CHANGELOG.md ADDED
@@ -0,0 +1,3 @@
1
+ ## 0.1.0 (2020-04-27)
2
+
3
+ - First release
data/LICENSE.txt ADDED
@@ -0,0 +1,30 @@
1
+ BSD 3-Clause License
2
+
3
+ Copyright (c) Andrew Kane 2020,
4
+ Copyright (c) Soumith Chintala 2016,
5
+ All rights reserved.
6
+
7
+ Redistribution and use in source and binary forms, with or without
8
+ modification, are permitted provided that the following conditions are met:
9
+
10
+ * Redistributions of source code must retain the above copyright notice, this
11
+ list of conditions and the following disclaimer.
12
+
13
+ * Redistributions in binary form must reproduce the above copyright notice,
14
+ this list of conditions and the following disclaimer in the documentation
15
+ and/or other materials provided with the distribution.
16
+
17
+ * Neither the name of the copyright holder nor the names of its
18
+ contributors may be used to endorse or promote products derived from
19
+ this software without specific prior written permission.
20
+
21
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
22
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
23
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
25
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
26
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
27
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
28
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
29
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
30
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
data/README.md ADDED
@@ -0,0 +1,54 @@
1
+ # TorchVision
2
+
3
+ :fire: Computer vision datasets, transforms, and models for Ruby
4
+
5
+ This gem is currently experimental. There may be breaking changes between each release. Please report any issues you experience.
6
+
7
+ ## Installation
8
+
9
+ Add this line to your application’s Gemfile:
10
+
11
+ ```ruby
12
+ gem 'torchvision'
13
+ ```
14
+
15
+ ## Getting Started
16
+
17
+ This library follows the [Python API](https://pytorch.org/docs/master/torchvision/). Many methods and options are missing at the moment. PRs welcome!
18
+
19
+ ## Datasets
20
+
21
+ MNIST dataset
22
+
23
+ ```ruby
24
+ trainset = TorchVision::Datasets::MNIST.new("./data", train: true, download: true)
25
+ trainset.size
26
+ ```
27
+
28
+ ## Disclaimer
29
+
30
+ This library downloads and prepares public datasets. We don’t host any datasets. Be sure to adhere to the license for each dataset.
31
+
32
+ If you’re a dataset owner and wish to update any details or remove it from this project, let us know.
33
+
34
+ ## History
35
+
36
+ View the [changelog](https://github.com/ankane/torchvision/blob/master/CHANGELOG.md).
37
+
38
+ ## Contributing
39
+
40
+ Everyone is encouraged to help improve this project. Here are a few ways you can help:
41
+
42
+ - [Report bugs](https://github.com/ankane/torchvision/issues)
43
+ - Fix bugs and [submit pull requests](https://github.com/ankane/torchvision/pulls)
44
+ - Write, clarify, or fix documentation
45
+ - Suggest or add new features
46
+
47
+ To get started with development:
48
+
49
+ ```sh
50
+ git clone https://github.com/ankane/torchvision.git
51
+ cd torchvision
52
+ bundle install
53
+ bundle exec rake test
54
+ ```
@@ -0,0 +1,25 @@
1
# dependencies
require "mini_magick"
require "numo/narray"
require "torch"

# stdlib
require "digest"
require "fileutils"
require "net/http"
require "tmpdir"
require "zlib"

# modules
require "torchvision/version"

# datasets
require "torchvision/datasets/mnist"

# transforms
require "torchvision/transforms/compose"
require "torchvision/transforms/functional"
require "torchvision/transforms/normalize"
require "torchvision/transforms/to_tensor"
22
+
23
# Top-level namespace of the gem.
module TorchVision
  # Library-specific error class; a plain StandardError subclass so callers
  # can rescue failures raised by this gem separately from other errors.
  class Error < StandardError
  end
end
@@ -0,0 +1,146 @@
1
module TorchVision
  module Datasets
    # MNIST dataset.
    #
    # Downloads the four gzip archives listed in RESOURCES into
    # "<root>/MNIST/raw", converts them to tensors saved under
    # "<root>/MNIST/processed", and exposes the selected split through
    # #size and #[].
    class MNIST
      # Source archives and the SHA-256 each download must match.
      RESOURCES = [
        {
          url: "http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz",
          sha256: "440fcabf73cc546fa21475e81ea370265605f56be210a4024d2ca8f203523609"
        },
        {
          url: "http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz",
          sha256: "3552534a0a558bbed6aed32b30c495cca23d567ec52cac8be1a0730e8010255c"
        },
        {
          url: "http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz",
          sha256: "8d422c7b0a1c1c79245a5bcf07fe86e33eeafee792b84584aec276f5a2dbc4e6"
        },
        {
          url: "http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz",
          sha256: "f7ae60f92e00ec6debd23a6088c31dbd2371eca3ffa0defaefb259924204aec6"
        }
      ].freeze
      TRAINING_FILE = "training.pt".freeze
      TEST_FILE = "test.pt".freeze

      # @param root [String] directory under which "MNIST/raw" and "MNIST/processed" live
      # @param train [Boolean] load TRAINING_FILE when true, TEST_FILE otherwise
      # @param download [Boolean] fetch and process the dataset first if it is missing
      # @param transform [#call, nil] optional callable applied to each image in #[]
      # @raise [TorchVision::Error] if the processed files are not present
      def initialize(root, train: true, download: false, transform: nil)
        @root = root
        @train = train
        @transform = transform

        # The keyword argument shadows the method name, hence the explicit receiver.
        self.download if download

        unless check_exists
          raise Error, "Dataset not found. You can use download: true to download it"
        end

        data_file = @train ? TRAINING_FILE : TEST_FILE
        @data, @targets = Torch.load(File.join(processed_folder, data_file))
      end

      # Number of samples: the first dimension of the loaded data tensor.
      def size
        @data.size[0]
      end

      # Returns [image, target] for the sample at +index+.
      #
      # The image is built as an 8-bit "gray" MiniMagick image and then passed
      # through the transform, if one was given.
      def [](index)
        img = @data[index]
        # import_pixels expects (blob, columns, rows, ...). The image tensor is
        # indexed (row, column), so dim 1 is the width and dim 0 the height.
        height = img.size(0)
        width = img.size(1)
        img = MiniMagick::Image.import_pixels(img.numo.to_binary, width, height, 8, "gray")
        img = @transform.call(img) if @transform

        target = @targets[index].item

        [img, target]
      end

      # Directory holding the downloaded .gz archives.
      def raw_folder
        File.join(@root, "MNIST", "raw")
      end

      # Directory holding the saved tensor files.
      def processed_folder
        File.join(@root, "MNIST", "processed")
      end

      # True only when both the training and the test file exist.
      def check_exists
        File.exist?(File.join(processed_folder, TRAINING_FILE)) &&
          File.exist?(File.join(processed_folder, TEST_FILE))
      end

      # Downloads every resource and writes the processed training/test files.
      # Does nothing when the processed files already exist.
      #
      # @raise [TorchVision::Error] on a failed download or checksum mismatch
      def download
        return if check_exists

        FileUtils.mkdir_p(raw_folder)
        FileUtils.mkdir_p(processed_folder)

        RESOURCES.each do |resource|
          filename = resource[:url].split("/").last
          download_file(resource[:url], download_root: raw_folder, filename: filename, sha256: resource[:sha256])
        end

        puts "Processing..."

        training_set = [
          unpack_mnist("train-images-idx3-ubyte", 16, [60000, 28, 28]),
          unpack_mnist("train-labels-idx1-ubyte", 8, [60000])
        ]
        test_set = [
          unpack_mnist("t10k-images-idx3-ubyte", 16, [10000, 28, 28]),
          unpack_mnist("t10k-labels-idx1-ubyte", 8, [10000])
        ]

        Torch.save(training_set, File.join(processed_folder, TRAINING_FILE))
        Torch.save(test_set, File.join(processed_folder, TEST_FILE))

        puts "Done!"
      end

      private

      # Reads "<raw_folder>/<path>.gz", skips the first +offset+ bytes of the
      # decompressed stream, and returns the rest as a UInt8 tensor of +shape+.
      def unpack_mnist(path, offset, shape)
        path = File.join(raw_folder, "#{path}.gz")
        # The block form closes both the gzip reader and the underlying file.
        Zlib::GzipReader.open(path) do |gz|
          gz.read(offset)
          Torch.tensor(Numo::UInt8.from_string(gz.read, shape))
        end
      end

      # Downloads +url+ to "<download_root>/<filename>" unless that file
      # already exists, verifying the SHA-256 of the received bytes.
      #
      # The body is streamed to a scratch file in Dir.tmpdir and moved into
      # place only after the checksum matches; the scratch file is removed on
      # every failure path.
      #
      # @return [String] path of the downloaded file
      # @raise [TorchVision::Error] on a non-2xx response or checksum mismatch
      def download_file(url, download_root:, filename:, sha256:)
        FileUtils.mkdir_p(download_root)

        dest = File.join(download_root, filename)
        return dest if File.exist?(dest)

        # PID plus a random component keeps concurrent downloads from colliding.
        temp_path = File.join(Dir.tmpdir, "torchvision-#{Process.pid}-#{rand(1 << 32)}-#{filename}")

        digest = Digest::SHA256.new

        uri = URI(url)

        begin
          # Net::HTTP automatically adds Accept-Encoding for compression
          # of response bodies and automatically decompresses gzip
          # and deflate responses unless a Range header was sent.
          # https://ruby-doc.org/stdlib-2.6.4/libdoc/net/http/rdoc/Net/HTTP.html
          Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http|
            request = Net::HTTP::Get.new(uri)

            puts "Downloading #{url}..."
            File.open(temp_path, "wb") do |f|
              http.request(request) do |response|
                # Fail early with the real cause instead of hashing an error page.
                unless response.is_a?(Net::HTTPSuccess)
                  raise Error, "Download failed: #{response.code} #{response.message}"
                end

                response.read_body do |chunk|
                  f.write(chunk)
                  digest.update(chunk)
                end
              end
            end
          end

          if digest.hexdigest != sha256
            raise Error, "Bad hash: #{digest.hexdigest}"
          end

          FileUtils.mv(temp_path, dest)
        ensure
          # No-op after a successful move; cleans up after any failure.
          FileUtils.rm_f(temp_path)
        end

        dest
      end
    end
  end
end
@@ -0,0 +1,16 @@
1
module TorchVision
  module Transforms
    # Chains several transforms into a single callable.
    class Compose
      # @param transforms [Array<#call>] transforms to apply, in order
      def initialize(transforms)
        @transforms = transforms
      end

      # Feeds +img+ through every transform in order, passing each result to
      # the next one, and returns the final result.
      def call(img)
        @transforms.reduce(img) { |current, step| step.call(current) }
      end
    end
  end
end
@@ -0,0 +1,45 @@
1
module TorchVision
  module Transforms
    # Stateless helpers used by the transform classes.
    class Functional
      class << self
        # Subtracts +mean+ from +tensor+ and divides by +std+.
        #
        # @param tensor [Torch::Tensor] must have exactly 3 dimensions
        # @param mean [Object] converted with Torch.tensor to the tensor's dtype and device
        # @param std [Object] converted with Torch.tensor to the tensor's dtype and device
        # @param inplace [Boolean] modify +tensor+ itself instead of a clone
        # @return [Torch::Tensor] the normalized tensor
        # @raise [ArgumentError] for a non-tensor, a tensor that is not 3-D, or a zero std entry
        def normalize(tensor, mean, std, inplace: false)
          raise ArgumentError, "tensor should be a torch tensor. Got #{tensor.class.name}" unless Torch.tensor?(tensor)

          unless tensor.ndimension == 3
            raise ArgumentError, "Expected tensor to be a tensor image of size (C, H, W). Got tensor.size() = #{tensor.size}"
          end

          result = inplace ? tensor : tensor.clone

          kind = result.dtype
          device = result.device
          # TODO Torch.as_tensor
          mean_t = Torch.tensor(mean, dtype: kind, device: device)
          std_t = Torch.tensor(std, dtype: kind, device: device)

          # TODO
          zero_std = std_t.to_a.any? { |value| value == 0 }
          if zero_std
            raise ArgumentError, "std evaluated to zero after conversion to #{kind}, leading to division by zero."
          end

          # NOTE(review): 1-D mean/std are used as-is, so they presumably
          # broadcast against the last dimension rather than the first.
          # Special handling for the 1-D case was left unimplemented here;
          # confirm the intended layout.
          result.sub!(mean_t).div!(std_t)
          result
        end

        # TODO improve
        # Builds a float tensor from the value returned by pic.get_pixels.
        def to_tensor(pic)
          pixels = pic.get_pixels
          Torch.tensor(pixels, dtype: :float)
        end
      end
    end

    # shortcut
    F = Functional
  end
end
@@ -0,0 +1,15 @@
1
module TorchVision
  module Transforms
    # Callable wrapper around F.normalize that remembers its arguments.
    class Normalize
      # @param mean [Object] forwarded to F.normalize as the mean
      # @param std [Object] forwarded to F.normalize as the std
      # @param inplace [Boolean] forwarded to F.normalize as +inplace:+
      def initialize(mean, std, inplace: false)
        @mean, @std, @inplace = mean, std, inplace
      end

      # Normalizes +tensor+ with the stored mean, std and inplace flag.
      def call(tensor)
        F.normalize(tensor, @mean, @std, inplace: @inplace)
      end
    end
  end
end
@@ -0,0 +1,9 @@
1
module TorchVision
  module Transforms
    # Callable wrapper around F.to_tensor, usable as a step in Compose.
    class ToTensor
      # Converts +pic+ by delegating to F.to_tensor and returns its result.
      def call(pic)
        F.to_tensor(pic)
      end
    end
  end
end
@@ -0,0 +1,3 @@
1
module TorchVision
  # Gem version string; frozen so it cannot be mutated at runtime.
  VERSION = "0.1.0".freeze
end
metadata ADDED
@@ -0,0 +1,136 @@
1
+ --- !ruby/object:Gem::Specification
2
+ name: torchvision
3
+ version: !ruby/object:Gem::Version
4
+ version: 0.1.0
5
+ platform: ruby
6
+ authors:
7
+ - Andrew Kane
8
+ autorequire:
9
+ bindir: bin
10
+ cert_chain: []
11
+ date: 2020-04-27 00:00:00.000000000 Z
12
+ dependencies:
13
+ - !ruby/object:Gem::Dependency
14
+ name: mini_magick
15
+ requirement: !ruby/object:Gem::Requirement
16
+ requirements:
17
+ - - ">="
18
+ - !ruby/object:Gem::Version
19
+ version: '0'
20
+ type: :runtime
21
+ prerelease: false
22
+ version_requirements: !ruby/object:Gem::Requirement
23
+ requirements:
24
+ - - ">="
25
+ - !ruby/object:Gem::Version
26
+ version: '0'
27
+ - !ruby/object:Gem::Dependency
28
+ name: numo-narray
29
+ requirement: !ruby/object:Gem::Requirement
30
+ requirements:
31
+ - - ">="
32
+ - !ruby/object:Gem::Version
33
+ version: '0'
34
+ type: :runtime
35
+ prerelease: false
36
+ version_requirements: !ruby/object:Gem::Requirement
37
+ requirements:
38
+ - - ">="
39
+ - !ruby/object:Gem::Version
40
+ version: '0'
41
+ - !ruby/object:Gem::Dependency
42
+ name: torch-rb
43
+ requirement: !ruby/object:Gem::Requirement
44
+ requirements:
45
+ - - ">="
46
+ - !ruby/object:Gem::Version
47
+ version: 0.2.2
48
+ type: :runtime
49
+ prerelease: false
50
+ version_requirements: !ruby/object:Gem::Requirement
51
+ requirements:
52
+ - - ">="
53
+ - !ruby/object:Gem::Version
54
+ version: 0.2.2
55
+ - !ruby/object:Gem::Dependency
56
+ name: bundler
57
+ requirement: !ruby/object:Gem::Requirement
58
+ requirements:
59
+ - - ">="
60
+ - !ruby/object:Gem::Version
61
+ version: '0'
62
+ type: :development
63
+ prerelease: false
64
+ version_requirements: !ruby/object:Gem::Requirement
65
+ requirements:
66
+ - - ">="
67
+ - !ruby/object:Gem::Version
68
+ version: '0'
69
+ - !ruby/object:Gem::Dependency
70
+ name: rake
71
+ requirement: !ruby/object:Gem::Requirement
72
+ requirements:
73
+ - - ">="
74
+ - !ruby/object:Gem::Version
75
+ version: '0'
76
+ type: :development
77
+ prerelease: false
78
+ version_requirements: !ruby/object:Gem::Requirement
79
+ requirements:
80
+ - - ">="
81
+ - !ruby/object:Gem::Version
82
+ version: '0'
83
+ - !ruby/object:Gem::Dependency
84
+ name: minitest
85
+ requirement: !ruby/object:Gem::Requirement
86
+ requirements:
87
+ - - ">="
88
+ - !ruby/object:Gem::Version
89
+ version: '5'
90
+ type: :development
91
+ prerelease: false
92
+ version_requirements: !ruby/object:Gem::Requirement
93
+ requirements:
94
+ - - ">="
95
+ - !ruby/object:Gem::Version
96
+ version: '5'
97
+ description:
98
+ email: andrew@chartkick.com
99
+ executables: []
100
+ extensions: []
101
+ extra_rdoc_files: []
102
+ files:
103
+ - CHANGELOG.md
104
+ - LICENSE.txt
105
+ - README.md
106
+ - lib/torchvision.rb
107
+ - lib/torchvision/datasets/mnist.rb
108
+ - lib/torchvision/transforms/compose.rb
109
+ - lib/torchvision/transforms/functional.rb
110
+ - lib/torchvision/transforms/normalize.rb
111
+ - lib/torchvision/transforms/to_tensor.rb
112
+ - lib/torchvision/version.rb
113
+ homepage: https://github.com/ankane/torchvision
114
+ licenses:
115
+ - BSD-3-Clause
116
+ metadata: {}
117
+ post_install_message:
118
+ rdoc_options: []
119
+ require_paths:
120
+ - lib
121
+ required_ruby_version: !ruby/object:Gem::Requirement
122
+ requirements:
123
+ - - ">="
124
+ - !ruby/object:Gem::Version
125
+ version: '2.4'
126
+ required_rubygems_version: !ruby/object:Gem::Requirement
127
+ requirements:
128
+ - - ">="
129
+ - !ruby/object:Gem::Version
130
+ version: '0'
131
+ requirements: []
132
+ rubygems_version: 3.1.2
133
+ signing_key:
134
+ specification_version: 4
135
+ summary: Computer vision datasets, transforms, and models for Ruby
136
+ test_files: []