torchvision 0.1.0 → 0.2.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +35 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +133 -5
  5. data/lib/torchvision.rb +40 -1
  6. data/lib/torchvision/datasets/cifar10.rb +117 -0
  7. data/lib/torchvision/datasets/cifar100.rb +41 -0
  8. data/lib/torchvision/datasets/dataset_folder.rb +91 -0
  9. data/lib/torchvision/datasets/fashion_mnist.rb +30 -0
  10. data/lib/torchvision/datasets/image_folder.rb +12 -0
  11. data/lib/torchvision/datasets/kmnist.rb +30 -0
  12. data/lib/torchvision/datasets/mnist.rb +47 -76
  13. data/lib/torchvision/datasets/vision_dataset.rb +67 -0
  14. data/lib/torchvision/models/alexnet.rb +42 -0
  15. data/lib/torchvision/models/basic_block.rb +46 -0
  16. data/lib/torchvision/models/bottleneck.rb +47 -0
  17. data/lib/torchvision/models/resnet.rb +129 -0
  18. data/lib/torchvision/models/resnet101.rb +9 -0
  19. data/lib/torchvision/models/resnet152.rb +9 -0
  20. data/lib/torchvision/models/resnet18.rb +9 -0
  21. data/lib/torchvision/models/resnet34.rb +9 -0
  22. data/lib/torchvision/models/resnet50.rb +9 -0
  23. data/lib/torchvision/models/resnext101_32x8d.rb +11 -0
  24. data/lib/torchvision/models/resnext50_32x4d.rb +11 -0
  25. data/lib/torchvision/models/vgg.rb +93 -0
  26. data/lib/torchvision/models/vgg11.rb +9 -0
  27. data/lib/torchvision/models/vgg11_bn.rb +9 -0
  28. data/lib/torchvision/models/vgg13.rb +9 -0
  29. data/lib/torchvision/models/vgg13_bn.rb +9 -0
  30. data/lib/torchvision/models/vgg16.rb +9 -0
  31. data/lib/torchvision/models/vgg16_bn.rb +9 -0
  32. data/lib/torchvision/models/vgg19.rb +9 -0
  33. data/lib/torchvision/models/vgg19_bn.rb +9 -0
  34. data/lib/torchvision/models/wide_resnet101_2.rb +10 -0
  35. data/lib/torchvision/models/wide_resnet50_2.rb +10 -0
  36. data/lib/torchvision/transforms/center_crop.rb +13 -0
  37. data/lib/torchvision/transforms/compose.rb +2 -2
  38. data/lib/torchvision/transforms/functional.rb +142 -7
  39. data/lib/torchvision/transforms/normalize.rb +2 -2
  40. data/lib/torchvision/transforms/random_horizontal_flip.rb +18 -0
  41. data/lib/torchvision/transforms/random_resized_crop.rb +70 -0
  42. data/lib/torchvision/transforms/random_vertical_flip.rb +18 -0
  43. data/lib/torchvision/transforms/resize.rb +13 -0
  44. data/lib/torchvision/transforms/to_tensor.rb +2 -2
  45. data/lib/torchvision/utils.rb +120 -0
  46. data/lib/torchvision/version.rb +1 -1
  47. metadata +50 -57
@@ -0,0 +1,91 @@
1
module TorchVision
  module Datasets
    # Generic dataset whose samples are files laid out as one subdirectory
    # per class under +root+:
    #
    #   root/class_a/xxx.ext
    #   root/class_a/nested/yyy.ext
    #   root/class_b/zzz.ext
    #
    # Each sample is loaded with Vips::Image.new_from_file and paired with
    # the integer index of its class (classes are indexed in sorted order).
    class DatasetFolder < VisionDataset
      # Sorted list of class names (the subdirectory names found under root).
      attr_reader :classes

      # root             - directory containing one subdirectory per class
      # extensions       - array of allowed (lowercase) file name suffixes
      # transform        - optional callable applied to each loaded sample
      # target_transform - optional callable applied to each class index
      # is_valid_file    - optional callable taking a path, returning truthy
      #                    when the file should be included
      #
      # Exactly one of extensions / is_valid_file must be given.
      #
      # Raises ArgumentError when both or neither of extensions and
      # is_valid_file are given, and RuntimeError when no sample is found.
      def initialize(root, extensions: nil, transform: nil, target_transform: nil, is_valid_file: nil)
        super(root, transform: transform, target_transform: target_transform)
        classes, class_to_idx = find_classes(@root)
        samples = make_dataset(@root, class_to_idx, extensions, is_valid_file)
        if samples.empty?
          msg = "Found 0 files in subfolders of: #{@root}\n"
          unless extensions.nil?
            msg += "Supported extensions are: #{extensions.join(",")}"
          end
          raise RuntimeError, msg
        end

        @loader = lambda do |path|
          Vips::Image.new_from_file(path)
        end
        @extensions = extensions

        @classes = classes
        @class_to_idx = class_to_idx
        @samples = samples
        @targets = samples.map { |s| s[1] }
      end

      # Returns [sample, target] for the given index, after applying the
      # configured transform / target_transform (when present).
      def [](index)
        path, target = @samples[index]
        sample = @loader.call(path)
        sample = @transform.call(sample) if @transform
        target = @target_transform.call(target) if @target_transform

        [sample, target]
      end

      # Number of samples in the dataset.
      def size
        @samples.size
      end

      private

      # Returns [classes, class_to_idx]: the sorted names of the directories
      # directly under dir, and a hash mapping each name to its position.
      def find_classes(dir)
        classes = Dir.children(dir).select { |d| File.directory?(File.join(dir, d)) }
        classes.sort!
        class_to_idx = classes.map.with_index.to_h
        [classes, class_to_idx]
      end

      # True when the lowercased filename ends with one of the extensions.
      def has_file_allowed_extension(filename, extensions)
        filename = filename.downcase
        extensions.any? { |ext| filename.end_with?(ext) }
      end

      # Builds the list of [path, class_index] pairs for every valid file
      # found (at any depth) below each class directory.
      def make_dataset(directory, class_to_idx, extensions, is_valid_file)
        instances = []
        directory = File.expand_path(directory)
        both_none = extensions.nil? && is_valid_file.nil?
        both_something = !extensions.nil? && !is_valid_file.nil?
        if both_none || both_something
          raise ArgumentError, "Both extensions and is_valid_file cannot be None or not None at the same time"
        end
        unless extensions.nil?
          is_valid_file = lambda do |x|
            has_file_allowed_extension(x, extensions)
          end
        end
        class_to_idx.keys.sort.each do |target_class|
          class_index = class_to_idx[target_class]
          target_dir = File.join(directory, target_class)
          next unless File.directory?(target_dir)

          # "**" on its own behaves like "*" in Dir.glob; "**/*" is needed to
          # descend into nested directories.
          Dir.glob("**/*", base: target_dir).sort.each do |fname|
            path = File.join(target_dir, fname)
            # Only regular files are samples; directories are never candidates.
            next unless File.file?(path)

            instances << [path, class_index] if is_valid_file.call(path)
          end
        end
        instances
      end
    end
  end
end
@@ -0,0 +1,30 @@
1
module TorchVision
  module Datasets
    # Fashion-MNIST dataset; reuses the MNIST implementation and only
    # overrides the list of files to download.
    class FashionMNIST < MNIST
      # https://github.com/zalandoresearch/fashion-mnist

      private

      # Download URL and expected SHA-256 digest for each raw archive.
      def resources
        [
          [
            "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz",
            "3aede38d61863908ad78613f6a32ed271626dd12800ba2636569512369268a84"
          ],
          [
            "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz",
            "a04f17134ac03560a47e3764e11b92fc97de4d1bfaf8ba1a3aa29af54cc90845"
          ],
          [
            "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz",
            "346e55b948d973a97e58d2351dde16a484bd415d4595297633bb08f03db6a073"
          ],
          [
            "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz",
            "67da17c76eaffca5446c3361aaab5c3cd6d1c2608764d35dfb1850b086bf8dd5"
          ]
        ].map { |location, checksum| {url: location, sha256: checksum} }
      end
    end
  end
end
@@ -0,0 +1,12 @@
1
module TorchVision
  module Datasets
    # DatasetFolder specialised for images: files are selected by the
    # IMG_EXTENSIONS suffixes unless the caller supplies is_valid_file.
    class ImageFolder < DatasetFolder
      # File name suffixes accepted when no is_valid_file callable is given.
      IMG_EXTENSIONS = [".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp"].freeze

      # root             - directory containing one subdirectory per class
      # transform        - optional callable applied to each loaded image
      # target_transform - optional callable applied to each class index
      # is_valid_file    - optional callable taking a path; when given it
      #                    replaces the extension-based filter
      def initialize(root, transform: nil, target_transform: nil, is_valid_file: nil)
        # DatasetFolder rejects extensions and is_valid_file being given
        # together, so the default extensions are only passed when the caller
        # did not provide a custom filter.
        extensions = is_valid_file.nil? ? IMG_EXTENSIONS : nil
        super(root, extensions: extensions, transform: transform, target_transform: target_transform, is_valid_file: is_valid_file)
        @imgs = @samples
      end
    end
  end
end
@@ -0,0 +1,30 @@
1
module TorchVision
  module Datasets
    # Kuzushiji-MNIST dataset; reuses the MNIST implementation and only
    # overrides the list of files to download.
    class KMNIST < MNIST
      # https://github.com/rois-codh/kmnist

      private

      # Download URL and expected SHA-256 digest for each raw archive.
      def resources
        [
          [
            "http://codh.rois.ac.jp/kmnist/dataset/kmnist/train-images-idx3-ubyte.gz",
            "51467d22d8cc72929e2a028a0428f2086b092bb31cfb79c69cc0a90ce135fde4"
          ],
          [
            "http://codh.rois.ac.jp/kmnist/dataset/kmnist/train-labels-idx1-ubyte.gz",
            "e38f9ebcd0f3ebcdec7fc8eabdcdaef93bb0df8ea12bee65224341c8183d8e17"
          ],
          [
            "http://codh.rois.ac.jp/kmnist/dataset/kmnist/t10k-images-idx3-ubyte.gz",
            "edd7a857845ad6bb1d0ba43fe7e794d164fe2dce499a1694695a792adfac43c5"
          ],
          [
            "http://codh.rois.ac.jp/kmnist/dataset/kmnist/t10k-labels-idx1-ubyte.gz",
            "20bb9a0ef54c7db3efc55a92eef5582c109615df22683c380526788f98e42a1c"
          ]
        ].map { |location, checksum| {url: location, sha256: checksum} }
      end
    end
  end
end
@@ -1,31 +1,10 @@
1
1
  module TorchVision
2
2
  module Datasets
3
- class MNIST
4
- RESOURCES = [
5
- {
6
- url: "http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz",
7
- sha256: "440fcabf73cc546fa21475e81ea370265605f56be210a4024d2ca8f203523609"
8
- },
9
- {
10
- url: "http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz",
11
- sha256: "3552534a0a558bbed6aed32b30c495cca23d567ec52cac8be1a0730e8010255c"
12
- },
13
- {
14
- url: "http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz",
15
- sha256: "8d422c7b0a1c1c79245a5bcf07fe86e33eeafee792b84584aec276f5a2dbc4e6"
16
- },
17
- {
18
- url: "http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz",
19
- sha256: "f7ae60f92e00ec6debd23a6088c31dbd2371eca3ffa0defaefb259924204aec6"
20
- }
21
- ]
22
- TRAINING_FILE = "training.pt"
23
- TEST_FILE = "test.pt"
24
-
25
- def initialize(root, train: true, download: false, transform: nil)
26
- @root = root
3
+ class MNIST < VisionDataset
4
+ # http://yann.lecun.com/exdb/mnist/
5
+ def initialize(root, train: true, download: false, transform: nil, target_transform: nil)
6
+ super(root, transform: transform, target_transform: target_transform)
27
7
  @train = train
28
- @transform = transform
29
8
 
30
9
  self.download if download
31
10
 
@@ -33,35 +12,37 @@ module TorchVision
33
12
  raise Error, "Dataset not found. You can use download: true to download it"
34
13
  end
35
14
 
36
- data_file = @train ? TRAINING_FILE : TEST_FILE
15
+ data_file = @train ? training_file : test_file
37
16
  @data, @targets = Torch.load(File.join(processed_folder, data_file))
38
17
  end
39
18
 
40
19
  def size
41
- @data.size[0]
20
+ @data.size(0)
42
21
  end
43
22
 
44
23
  def [](index)
45
- img = @data[index]
46
- img = MiniMagick::Image.import_pixels(img.numo.to_binary, img.size(0), img.size(1), 8, "gray")
24
+ img, target = @data[index], @targets[index].item
25
+
26
+ img = Utils.image_from_array(img)
27
+
47
28
  img = @transform.call(img) if @transform
48
29
 
49
- target = @targets[index].item
30
+ target = @target_transform.call(target) if @target_transform
50
31
 
51
32
  [img, target]
52
33
  end
53
34
 
54
35
  def raw_folder
55
- File.join(@root, "MNIST", "raw")
36
+ File.join(@root, self.class.name.split("::").last, "raw")
56
37
  end
57
38
 
58
39
  def processed_folder
59
- File.join(@root, "MNIST", "processed")
40
+ File.join(@root, self.class.name.split("::").last, "processed")
60
41
  end
61
42
 
62
43
  def check_exists
63
- File.exist?(File.join(processed_folder, TRAINING_FILE)) &&
64
- File.exist?(File.join(processed_folder, TEST_FILE))
44
+ File.exist?(File.join(processed_folder, training_file)) &&
45
+ File.exist?(File.join(processed_folder, test_file))
65
46
  end
66
47
 
67
48
  def download
@@ -70,7 +51,7 @@ module TorchVision
70
51
  FileUtils.mkdir_p(raw_folder)
71
52
  FileUtils.mkdir_p(processed_folder)
72
53
 
73
- RESOURCES.each do |resource|
54
+ resources.each do |resource|
74
55
  filename = resource[:url].split("/").last
75
56
  download_file(resource[:url], download_root: raw_folder, filename: filename, sha256: resource[:sha256])
76
57
  end
@@ -86,14 +67,43 @@ module TorchVision
86
67
  unpack_mnist("t10k-labels-idx1-ubyte", 8, [10000])
87
68
  ]
88
69
 
89
- Torch.save(training_set, File.join(processed_folder, TRAINING_FILE))
90
- Torch.save(test_set, File.join(processed_folder, TEST_FILE))
70
+ Torch.save(training_set, File.join(processed_folder, training_file))
71
+ Torch.save(test_set, File.join(processed_folder, test_file))
91
72
 
92
73
  puts "Done!"
93
74
  end
94
75
 
95
76
  private
96
77
 
78
+ def resources
79
+ [
80
+ {
81
+ url: "http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz",
82
+ sha256: "440fcabf73cc546fa21475e81ea370265605f56be210a4024d2ca8f203523609"
83
+ },
84
+ {
85
+ url: "http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz",
86
+ sha256: "3552534a0a558bbed6aed32b30c495cca23d567ec52cac8be1a0730e8010255c"
87
+ },
88
+ {
89
+ url: "http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz",
90
+ sha256: "8d422c7b0a1c1c79245a5bcf07fe86e33eeafee792b84584aec276f5a2dbc4e6"
91
+ },
92
+ {
93
+ url: "http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz",
94
+ sha256: "f7ae60f92e00ec6debd23a6088c31dbd2371eca3ffa0defaefb259924204aec6"
95
+ }
96
+ ]
97
+ end
98
+
99
+ def training_file
100
+ "training.pt"
101
+ end
102
+
103
+ def test_file
104
+ "test.pt"
105
+ end
106
+
97
107
  def unpack_mnist(path, offset, shape)
98
108
  path = File.join(raw_folder, "#{path}.gz")
99
109
  File.open(path, "rb") do |f|
@@ -102,45 +112,6 @@ module TorchVision
102
112
  Torch.tensor(Numo::UInt8.from_string(gz.read, shape))
103
113
  end
104
114
  end
105
-
106
- def download_file(url, download_root:, filename:, sha256:)
107
- FileUtils.mkdir_p(download_root)
108
-
109
- dest = File.join(download_root, filename)
110
- return dest if File.exist?(dest)
111
-
112
- temp_path = "#{Dir.tmpdir}/#{Time.now.to_f}" # TODO better name
113
-
114
- digest = Digest::SHA256.new
115
-
116
- uri = URI(url)
117
-
118
- # Net::HTTP automatically adds Accept-Encoding for compression
119
- # of response bodies and automatically decompresses gzip
120
- # and deflateresponses unless a Range header was sent.
121
- # https://ruby-doc.org/stdlib-2.6.4/libdoc/net/http/rdoc/Net/HTTP.html
122
- Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http|
123
- request = Net::HTTP::Get.new(uri)
124
-
125
- puts "Downloading #{url}..."
126
- File.open(temp_path, "wb") do |f|
127
- http.request(request) do |response|
128
- response.read_body do |chunk|
129
- f.write(chunk)
130
- digest.update(chunk)
131
- end
132
- end
133
- end
134
- end
135
-
136
- if digest.hexdigest != sha256
137
- raise Error, "Bad hash: #{digest.hexdigest}"
138
- end
139
-
140
- FileUtils.mv(temp_path, dest)
141
-
142
- dest
143
- end
144
115
  end
145
116
  end
146
117
  end
@@ -0,0 +1,67 @@
1
module TorchVision
  module Datasets
    # Common base for the datasets in this namespace: stores the root
    # directory and the optional transforms, and provides a checksum-verified
    # download helper for subclasses.
    class VisionDataset < Torch::Utils::Data::Dataset
      attr_reader :data, :targets

      # root             - root directory of the dataset
      # transforms       - optional callable for the (sample, target) pair
      # transform        - optional callable applied to the sample
      # target_transform - optional callable applied to the target
      #
      # Raises ArgumentError when transforms is combined with transform
      # and/or target_transform.
      def initialize(root, transforms: nil, transform: nil, target_transform: nil)
        @root = root

        has_transforms = !transforms.nil?
        has_separate_transform = !transform.nil? || !target_transform.nil?
        if has_transforms && has_separate_transform
          raise ArgumentError, "Only transforms or transform/target_transform can be passed as argument"
        end

        @transform = transform
        @target_transform = target_transform

        # TODO: when separate transforms are given, wrap them, e.g.
        # transforms = StandardTransform.new(transform, target_transform)
        @transforms = transforms
      end

      private

      # Downloads url to download_root/filename unless that file already
      # exists, verifying the SHA-256 digest before moving it into place.
      #
      # Returns the destination path.
      # Raises Error on a non-success HTTP response or a digest mismatch.
      def download_file(url, download_root:, filename:, sha256:)
        FileUtils.mkdir_p(download_root)

        dest = File.join(download_root, filename)
        return dest if File.exist?(dest)

        temp_path = "#{Dir.tmpdir}/#{Time.now.to_f}" # TODO better name

        begin
          uri = URI(url)

          # Net::HTTP automatically adds Accept-Encoding for compression
          # of response bodies and automatically decompresses gzip
          # and deflate responses unless a Range header was sent.
          # https://ruby-doc.org/stdlib-2.6.4/libdoc/net/http/rdoc/Net/HTTP.html
          Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http|
            request = Net::HTTP::Get.new(uri)

            puts "Downloading #{url}..."
            File.open(temp_path, "wb") do |f|
              http.request(request) do |response|
                # Fail early with the real cause instead of hashing an
                # error page and reporting a misleading "Bad hash".
                unless response.is_a?(Net::HTTPSuccess)
                  raise Error, "Bad response: #{response.code} #{response.message}"
                end

                response.read_body do |chunk|
                  f.write(chunk)
                end
              end
            end
          end

          unless check_integrity(temp_path, sha256)
            raise Error, "Bad hash"
          end

          FileUtils.mv(temp_path, dest)
        ensure
          # Never leave a partial or rejected download behind; this is a
          # no-op once the file has been moved to dest.
          FileUtils.rm_f(temp_path)
        end

        dest
      end

      # True when path exists and its SHA-256 hex digest equals sha256.
      def check_integrity(path, sha256)
        File.exist?(path) && Digest::SHA256.file(path).hexdigest == sha256
      end
    end
  end
end
@@ -0,0 +1,42 @@
1
module TorchVision
  module Models
    # AlexNet: a convolutional feature stack, adaptive average pooling to
    # 6x6, and a three-layer fully connected classifier.
    class AlexNet < Torch::NN::Module
      # num_classes - number of outputs of the final linear layer
      def initialize(num_classes: 1000)
        super()

        feature_layers = [
          Torch::NN::Conv2d.new(3, 64, 11, stride: 4, padding: 2),
          Torch::NN::ReLU.new(inplace: true),
          Torch::NN::MaxPool2d.new(3, stride: 2),
          Torch::NN::Conv2d.new(64, 192, 5, padding: 2),
          Torch::NN::ReLU.new(inplace: true),
          Torch::NN::MaxPool2d.new(3, stride: 2),
          Torch::NN::Conv2d.new(192, 384, 3, padding: 1),
          Torch::NN::ReLU.new(inplace: true),
          Torch::NN::Conv2d.new(384, 256, 3, padding: 1),
          Torch::NN::ReLU.new(inplace: true),
          Torch::NN::Conv2d.new(256, 256, 3, padding: 1),
          Torch::NN::ReLU.new(inplace: true),
          Torch::NN::MaxPool2d.new(3, stride: 2)
        ]
        @features = Torch::NN::Sequential.new(*feature_layers)

        @avgpool = Torch::NN::AdaptiveAvgPool2d.new([6, 6])

        classifier_layers = [
          Torch::NN::Dropout.new,
          Torch::NN::Linear.new(256 * 6 * 6, 4096),
          Torch::NN::ReLU.new(inplace: true),
          Torch::NN::Dropout.new,
          Torch::NN::Linear.new(4096, 4096),
          Torch::NN::ReLU.new(inplace: true),
          Torch::NN::Linear.new(4096, num_classes)
        ]
        @classifier = Torch::NN::Sequential.new(*classifier_layers)
      end

      # Runs x through the features, the pooling layer, flattens from
      # dimension 1 onward, and returns the classifier output.
      def forward(x)
        pooled = @avgpool.call(@features.call(x))
        @classifier.call(Torch.flatten(pooled, 1))
      end
    end
  end
end