torchvision 0.1.0 → 0.2.1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (47) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +35 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +133 -5
  5. data/lib/torchvision.rb +40 -1
  6. data/lib/torchvision/datasets/cifar10.rb +117 -0
  7. data/lib/torchvision/datasets/cifar100.rb +41 -0
  8. data/lib/torchvision/datasets/dataset_folder.rb +91 -0
  9. data/lib/torchvision/datasets/fashion_mnist.rb +30 -0
  10. data/lib/torchvision/datasets/image_folder.rb +12 -0
  11. data/lib/torchvision/datasets/kmnist.rb +30 -0
  12. data/lib/torchvision/datasets/mnist.rb +47 -76
  13. data/lib/torchvision/datasets/vision_dataset.rb +67 -0
  14. data/lib/torchvision/models/alexnet.rb +42 -0
  15. data/lib/torchvision/models/basic_block.rb +46 -0
  16. data/lib/torchvision/models/bottleneck.rb +47 -0
  17. data/lib/torchvision/models/resnet.rb +129 -0
  18. data/lib/torchvision/models/resnet101.rb +9 -0
  19. data/lib/torchvision/models/resnet152.rb +9 -0
  20. data/lib/torchvision/models/resnet18.rb +9 -0
  21. data/lib/torchvision/models/resnet34.rb +9 -0
  22. data/lib/torchvision/models/resnet50.rb +9 -0
  23. data/lib/torchvision/models/resnext101_32x8d.rb +11 -0
  24. data/lib/torchvision/models/resnext50_32x4d.rb +11 -0
  25. data/lib/torchvision/models/vgg.rb +93 -0
  26. data/lib/torchvision/models/vgg11.rb +9 -0
  27. data/lib/torchvision/models/vgg11_bn.rb +9 -0
  28. data/lib/torchvision/models/vgg13.rb +9 -0
  29. data/lib/torchvision/models/vgg13_bn.rb +9 -0
  30. data/lib/torchvision/models/vgg16.rb +9 -0
  31. data/lib/torchvision/models/vgg16_bn.rb +9 -0
  32. data/lib/torchvision/models/vgg19.rb +9 -0
  33. data/lib/torchvision/models/vgg19_bn.rb +9 -0
  34. data/lib/torchvision/models/wide_resnet101_2.rb +10 -0
  35. data/lib/torchvision/models/wide_resnet50_2.rb +10 -0
  36. data/lib/torchvision/transforms/center_crop.rb +13 -0
  37. data/lib/torchvision/transforms/compose.rb +2 -2
  38. data/lib/torchvision/transforms/functional.rb +142 -7
  39. data/lib/torchvision/transforms/normalize.rb +2 -2
  40. data/lib/torchvision/transforms/random_horizontal_flip.rb +18 -0
  41. data/lib/torchvision/transforms/random_resized_crop.rb +70 -0
  42. data/lib/torchvision/transforms/random_vertical_flip.rb +18 -0
  43. data/lib/torchvision/transforms/resize.rb +13 -0
  44. data/lib/torchvision/transforms/to_tensor.rb +2 -2
  45. data/lib/torchvision/utils.rb +120 -0
  46. data/lib/torchvision/version.rb +1 -1
  47. metadata +50 -57
@@ -0,0 +1,91 @@
1
module TorchVision
  module Datasets
    # Generic dataset whose samples are files arranged in one directory per
    # class:
    #
    #   root/class_a/xxx.ext
    #   root/class_b/yyy.ext
    #
    # Every immediate sub-directory of +root+ is treated as a class. Class
    # names are sorted and numbered from 0. Samples are opened with
    # Vips::Image.new_from_file when they are accessed.
    class DatasetFolder < VisionDataset
      # Sorted list of class (sub-directory) names.
      attr_reader :classes

      # @param root [String] directory that contains one sub-directory per class
      # @param extensions [Array<String>, nil] accepted file-name suffixes; they
      #   are compared against the downcased path, so give them in lowercase
      # @param transform [#call, nil] applied to every loaded sample
      # @param target_transform [#call, nil] applied to every class index
      # @param is_valid_file [#call, nil] predicate that receives a file path
      # @raise [ArgumentError] unless exactly one of +extensions+ and
      #   +is_valid_file+ is given
      # @raise [RuntimeError] when no matching file is found
      def initialize(root, extensions: nil, transform: nil, target_transform: nil, is_valid_file: nil)
        super(root, transform: transform, target_transform: target_transform)
        classes, class_to_idx = find_classes(@root)
        samples = make_dataset(@root, class_to_idx, extensions, is_valid_file)
        if samples.empty?
          msg = "Found 0 files in subfolders of: #{@root}\n"
          msg += "Supported extensions are: #{extensions.join(",")}" unless extensions.nil?
          raise RuntimeError, msg
        end

        @loader = lambda do |path|
          Vips::Image.new_from_file(path)
        end
        @extensions = extensions

        @classes = classes
        @class_to_idx = class_to_idx
        @samples = samples
        @targets = samples.map { |s| s[1] }
      end

      # Loads the sample at +index+ and applies the configured transforms.
      #
      # @param index [Integer] position in the sample list
      # @return [Array] two-element array of [sample, target]
      def [](index)
        path, target = @samples[index]
        sample = @loader.call(path)
        sample = @transform.call(sample) if @transform
        target = @target_transform.call(target) if @target_transform

        [sample, target]
      end

      # @return [Integer] number of samples
      def size
        @samples.size
      end

      private

      # Lists the immediate sub-directories of +dir+.
      #
      # @return [Array] sorted class names and a Hash mapping each name to its
      #   zero-based index
      def find_classes(dir)
        classes = Dir.children(dir).select { |d| File.directory?(File.join(dir, d)) }
        classes.sort!
        class_to_idx = classes.map.with_index.to_h
        [classes, class_to_idx]
      end

      # True when the downcased +filename+ ends with any of +extensions+.
      def has_file_allowed_extension(filename, extensions)
        filename = filename.downcase
        extensions.any? { |ext| filename.end_with?(ext) }
      end

      # Collects [path, class_index] pairs for every accepted file below each
      # class directory.
      #
      # @raise [ArgumentError] unless exactly one of +extensions+ and
      #   +is_valid_file+ is given
      def make_dataset(directory, class_to_idx, extensions, is_valid_file)
        instances = []
        directory = File.expand_path(directory)
        both_none = extensions.nil? && is_valid_file.nil?
        both_something = !extensions.nil? && !is_valid_file.nil?
        if both_none || both_something
          raise ArgumentError, "Both extensions and is_valid_file cannot be None or not None at the same time"
        end

        unless extensions.nil?
          is_valid_file = lambda do |x|
            has_file_allowed_extension(x, extensions)
          end
        end

        class_to_idx.keys.sort.each do |target_class|
          class_index = class_to_idx[target_class]
          target_dir = File.join(directory, target_class)
          next unless File.directory?(target_dir)

          # "**" on its own behaves like "*" in Ruby and only lists the top
          # level; "**/*" also descends into nested directories.
          Dir.glob("**/*", base: target_dir).sort.each do |fname|
            path = File.join(target_dir, fname)
            # Directories can match the pattern too; only real files are samples.
            next unless File.file?(path)

            instances << [path, class_index] if is_valid_file.call(path)
          end
        end
        instances
      end
    end
  end
end
@@ -0,0 +1,30 @@
1
module TorchVision
  module Datasets
    # MNIST subclass that only overrides the list of files to download.
    class FashionMNIST < MNIST
      # https://github.com/zalandoresearch/fashion-mnist

      private

      # @return [Array<Hash>] one entry per archive, each with the download
      #   :url and the expected :sha256 hex digest
      def resources
        train_images = {
          url: "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz",
          sha256: "3aede38d61863908ad78613f6a32ed271626dd12800ba2636569512369268a84"
        }
        train_labels = {
          url: "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz",
          sha256: "a04f17134ac03560a47e3764e11b92fc97de4d1bfaf8ba1a3aa29af54cc90845"
        }
        test_images = {
          url: "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz",
          sha256: "346e55b948d973a97e58d2351dde16a484bd415d4595297633bb08f03db6a073"
        }
        test_labels = {
          url: "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz",
          sha256: "67da17c76eaffca5446c3361aaab5c3cd6d1c2608764d35dfb1850b086bf8dd5"
        }

        [train_images, train_labels, test_images, test_labels]
      end
    end
  end
end
@@ -0,0 +1,12 @@
1
module TorchVision
  module Datasets
    # DatasetFolder preconfigured for common image file extensions.
    class ImageFolder < DatasetFolder
      # File-name suffixes accepted when no custom predicate is supplied.
      IMG_EXTENSIONS = [".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp"]

      # @param root [String] directory that contains one sub-directory per class
      # @param transform [#call, nil] applied to every loaded image
      # @param target_transform [#call, nil] applied to every class index
      # @param is_valid_file [#call, nil] predicate that receives a file path;
      #   when given it replaces the IMG_EXTENSIONS filter
      def initialize(root, transform: nil, target_transform: nil, is_valid_file: nil)
        # DatasetFolder raises ArgumentError when both an extension list and a
        # predicate are supplied, so fall back to IMG_EXTENSIONS only when the
        # caller did not pass is_valid_file.
        extensions = is_valid_file.nil? ? IMG_EXTENSIONS : nil
        super(root, extensions: extensions, transform: transform, target_transform: target_transform, is_valid_file: is_valid_file)
        @imgs = @samples
      end
    end
  end
end
@@ -0,0 +1,30 @@
1
module TorchVision
  module Datasets
    # MNIST subclass that only overrides the list of files to download.
    class KMNIST < MNIST
      # https://github.com/rois-codh/kmnist

      private

      # @return [Array<Hash>] one entry per archive, each with the download
      #   :url and the expected :sha256 hex digest
      def resources
        train_images = {
          url: "http://codh.rois.ac.jp/kmnist/dataset/kmnist/train-images-idx3-ubyte.gz",
          sha256: "51467d22d8cc72929e2a028a0428f2086b092bb31cfb79c69cc0a90ce135fde4"
        }
        train_labels = {
          url: "http://codh.rois.ac.jp/kmnist/dataset/kmnist/train-labels-idx1-ubyte.gz",
          sha256: "e38f9ebcd0f3ebcdec7fc8eabdcdaef93bb0df8ea12bee65224341c8183d8e17"
        }
        test_images = {
          url: "http://codh.rois.ac.jp/kmnist/dataset/kmnist/t10k-images-idx3-ubyte.gz",
          sha256: "edd7a857845ad6bb1d0ba43fe7e794d164fe2dce499a1694695a792adfac43c5"
        }
        test_labels = {
          url: "http://codh.rois.ac.jp/kmnist/dataset/kmnist/t10k-labels-idx1-ubyte.gz",
          sha256: "20bb9a0ef54c7db3efc55a92eef5582c109615df22683c380526788f98e42a1c"
        }

        [train_images, train_labels, test_images, test_labels]
      end
    end
  end
end
@@ -1,31 +1,10 @@
1
1
  module TorchVision
2
2
  module Datasets
3
- class MNIST
4
- RESOURCES = [
5
- {
6
- url: "http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz",
7
- sha256: "440fcabf73cc546fa21475e81ea370265605f56be210a4024d2ca8f203523609"
8
- },
9
- {
10
- url: "http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz",
11
- sha256: "3552534a0a558bbed6aed32b30c495cca23d567ec52cac8be1a0730e8010255c"
12
- },
13
- {
14
- url: "http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz",
15
- sha256: "8d422c7b0a1c1c79245a5bcf07fe86e33eeafee792b84584aec276f5a2dbc4e6"
16
- },
17
- {
18
- url: "http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz",
19
- sha256: "f7ae60f92e00ec6debd23a6088c31dbd2371eca3ffa0defaefb259924204aec6"
20
- }
21
- ]
22
- TRAINING_FILE = "training.pt"
23
- TEST_FILE = "test.pt"
24
-
25
- def initialize(root, train: true, download: false, transform: nil)
26
- @root = root
3
+ class MNIST < VisionDataset
4
+ # http://yann.lecun.com/exdb/mnist/
5
+ def initialize(root, train: true, download: false, transform: nil, target_transform: nil)
6
+ super(root, transform: transform, target_transform: target_transform)
27
7
  @train = train
28
- @transform = transform
29
8
 
30
9
  self.download if download
31
10
 
@@ -33,35 +12,37 @@ module TorchVision
33
12
  raise Error, "Dataset not found. You can use download: true to download it"
34
13
  end
35
14
 
36
- data_file = @train ? TRAINING_FILE : TEST_FILE
15
+ data_file = @train ? training_file : test_file
37
16
  @data, @targets = Torch.load(File.join(processed_folder, data_file))
38
17
  end
39
18
 
40
19
  def size
41
- @data.size[0]
20
+ @data.size(0)
42
21
  end
43
22
 
44
23
  def [](index)
45
- img = @data[index]
46
- img = MiniMagick::Image.import_pixels(img.numo.to_binary, img.size(0), img.size(1), 8, "gray")
24
+ img, target = @data[index], @targets[index].item
25
+
26
+ img = Utils.image_from_array(img)
27
+
47
28
  img = @transform.call(img) if @transform
48
29
 
49
- target = @targets[index].item
30
+ target = @target_transform.call(target) if @target_transform
50
31
 
51
32
  [img, target]
52
33
  end
53
34
 
54
35
  def raw_folder
55
- File.join(@root, "MNIST", "raw")
36
+ File.join(@root, self.class.name.split("::").last, "raw")
56
37
  end
57
38
 
58
39
  def processed_folder
59
- File.join(@root, "MNIST", "processed")
40
+ File.join(@root, self.class.name.split("::").last, "processed")
60
41
  end
61
42
 
62
43
  def check_exists
63
- File.exist?(File.join(processed_folder, TRAINING_FILE)) &&
64
- File.exist?(File.join(processed_folder, TEST_FILE))
44
+ File.exist?(File.join(processed_folder, training_file)) &&
45
+ File.exist?(File.join(processed_folder, test_file))
65
46
  end
66
47
 
67
48
  def download
@@ -70,7 +51,7 @@ module TorchVision
70
51
  FileUtils.mkdir_p(raw_folder)
71
52
  FileUtils.mkdir_p(processed_folder)
72
53
 
73
- RESOURCES.each do |resource|
54
+ resources.each do |resource|
74
55
  filename = resource[:url].split("/").last
75
56
  download_file(resource[:url], download_root: raw_folder, filename: filename, sha256: resource[:sha256])
76
57
  end
@@ -86,14 +67,43 @@ module TorchVision
86
67
  unpack_mnist("t10k-labels-idx1-ubyte", 8, [10000])
87
68
  ]
88
69
 
89
- Torch.save(training_set, File.join(processed_folder, TRAINING_FILE))
90
- Torch.save(test_set, File.join(processed_folder, TEST_FILE))
70
+ Torch.save(training_set, File.join(processed_folder, training_file))
71
+ Torch.save(test_set, File.join(processed_folder, test_file))
91
72
 
92
73
  puts "Done!"
93
74
  end
94
75
 
95
76
  private
96
77
 
78
+ def resources
79
+ [
80
+ {
81
+ url: "http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz",
82
+ sha256: "440fcabf73cc546fa21475e81ea370265605f56be210a4024d2ca8f203523609"
83
+ },
84
+ {
85
+ url: "http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz",
86
+ sha256: "3552534a0a558bbed6aed32b30c495cca23d567ec52cac8be1a0730e8010255c"
87
+ },
88
+ {
89
+ url: "http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz",
90
+ sha256: "8d422c7b0a1c1c79245a5bcf07fe86e33eeafee792b84584aec276f5a2dbc4e6"
91
+ },
92
+ {
93
+ url: "http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz",
94
+ sha256: "f7ae60f92e00ec6debd23a6088c31dbd2371eca3ffa0defaefb259924204aec6"
95
+ }
96
+ ]
97
+ end
98
+
99
+ def training_file
100
+ "training.pt"
101
+ end
102
+
103
+ def test_file
104
+ "test.pt"
105
+ end
106
+
97
107
  def unpack_mnist(path, offset, shape)
98
108
  path = File.join(raw_folder, "#{path}.gz")
99
109
  File.open(path, "rb") do |f|
@@ -102,45 +112,6 @@ module TorchVision
102
112
  Torch.tensor(Numo::UInt8.from_string(gz.read, shape))
103
113
  end
104
114
  end
105
-
106
- def download_file(url, download_root:, filename:, sha256:)
107
- FileUtils.mkdir_p(download_root)
108
-
109
- dest = File.join(download_root, filename)
110
- return dest if File.exist?(dest)
111
-
112
- temp_path = "#{Dir.tmpdir}/#{Time.now.to_f}" # TODO better name
113
-
114
- digest = Digest::SHA256.new
115
-
116
- uri = URI(url)
117
-
118
- # Net::HTTP automatically adds Accept-Encoding for compression
119
- # of response bodies and automatically decompresses gzip
120
- # and deflateresponses unless a Range header was sent.
121
- # https://ruby-doc.org/stdlib-2.6.4/libdoc/net/http/rdoc/Net/HTTP.html
122
- Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http|
123
- request = Net::HTTP::Get.new(uri)
124
-
125
- puts "Downloading #{url}..."
126
- File.open(temp_path, "wb") do |f|
127
- http.request(request) do |response|
128
- response.read_body do |chunk|
129
- f.write(chunk)
130
- digest.update(chunk)
131
- end
132
- end
133
- end
134
- end
135
-
136
- if digest.hexdigest != sha256
137
- raise Error, "Bad hash: #{digest.hexdigest}"
138
- end
139
-
140
- FileUtils.mv(temp_path, dest)
141
-
142
- dest
143
- end
144
115
  end
145
116
  end
146
117
  end
@@ -0,0 +1,67 @@
1
module TorchVision
  module Datasets
    # Shared base for the datasets in this namespace: stores the root
    # directory and the optional transforms, and offers a checksum-verified
    # download helper to subclasses.
    class VisionDataset < Torch::Utils::Data::Dataset
      attr_reader :data, :targets

      # @param root [String] dataset root directory
      # @param transforms [Object, nil] combined transform; stored as given
      # @param transform [#call, nil] transform for samples
      # @param target_transform [#call, nil] transform for targets
      # @raise [ArgumentError] when +transforms+ is combined with +transform+
      #   or +target_transform+
      def initialize(root, transforms: nil, transform: nil, target_transform: nil)
        @root = root

        has_transforms = !transforms.nil?
        has_separate_transform = !transform.nil? || !target_transform.nil?
        if has_transforms && has_separate_transform
          raise ArgumentError, "Only transforms or transform/target_transform can be passed as argument"
        end

        @transform = transform
        @target_transform = target_transform

        # TODO: when separate transforms are given, wrap them:
        #   transforms = StandardTransform.new(transform, target_transform)
        @transforms = transforms
      end

      private

      # Downloads +url+ to +download_root+/+filename+ unless that file already
      # exists, verifying the SHA-256 digest before moving it into place.
      #
      # @param url [String] source URL
      # @param download_root [String] destination directory (created if missing)
      # @param filename [String] destination file name
      # @param sha256 [String] expected hex digest of the downloaded bytes
      # @return [String] path of the destination file
      # @raise [Error] when the digest does not match
      def download_file(url, download_root:, filename:, sha256:)
        FileUtils.mkdir_p(download_root)

        dest = File.join(download_root, filename)
        return dest if File.exist?(dest)

        # The pid keeps concurrent processes from sharing one timestamp-based name.
        temp_path = File.join(Dir.tmpdir, "torchvision-#{Process.pid}-#{Time.now.to_f}")

        begin
          uri = URI(url)

          # Net::HTTP automatically adds Accept-Encoding for compression
          # of response bodies and automatically decompresses gzip
          # and deflate responses unless a Range header was sent.
          # https://ruby-doc.org/stdlib-2.6.4/libdoc/net/http/rdoc/Net/HTTP.html
          Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http|
            request = Net::HTTP::Get.new(uri)

            puts "Downloading #{url}..."
            File.open(temp_path, "wb") do |f|
              http.request(request) do |response|
                response.read_body do |chunk|
                  f.write(chunk)
                end
              end
            end
          end

          raise Error, "Bad hash" unless check_integrity(temp_path, sha256)

          FileUtils.mv(temp_path, dest)
        ensure
          # Remove a partial or rejected download; a no-op after a successful move.
          FileUtils.rm_f(temp_path)
        end

        dest
      end

      # @return [Boolean] true when +path+ exists and its SHA-256 hex digest
      #   equals +sha256+
      def check_integrity(path, sha256)
        File.exist?(path) && Digest::SHA256.file(path).hexdigest == sha256
      end
    end
  end
end
@@ -0,0 +1,42 @@
1
module TorchVision
  module Models
    # Network made of a convolutional feature stack, a 6x6 adaptive average
    # pool and a three-layer fully connected classifier.
    class AlexNet < Torch::NN::Module
      # @param num_classes [Integer] width of the final linear layer
      def initialize(num_classes: 1000)
        super()

        feature_layers = [
          Torch::NN::Conv2d.new(3, 64, 11, stride: 4, padding: 2),
          Torch::NN::ReLU.new(inplace: true),
          Torch::NN::MaxPool2d.new(3, stride: 2),
          Torch::NN::Conv2d.new(64, 192, 5, padding: 2),
          Torch::NN::ReLU.new(inplace: true),
          Torch::NN::MaxPool2d.new(3, stride: 2),
          Torch::NN::Conv2d.new(192, 384, 3, padding: 1),
          Torch::NN::ReLU.new(inplace: true),
          Torch::NN::Conv2d.new(384, 256, 3, padding: 1),
          Torch::NN::ReLU.new(inplace: true),
          Torch::NN::Conv2d.new(256, 256, 3, padding: 1),
          Torch::NN::ReLU.new(inplace: true),
          Torch::NN::MaxPool2d.new(3, stride: 2)
        ]
        @features = Torch::NN::Sequential.new(*feature_layers)

        @avgpool = Torch::NN::AdaptiveAvgPool2d.new([6, 6])

        classifier_layers = [
          Torch::NN::Dropout.new,
          Torch::NN::Linear.new(256 * 6 * 6, 4096),
          Torch::NN::ReLU.new(inplace: true),
          Torch::NN::Dropout.new,
          Torch::NN::Linear.new(4096, 4096),
          Torch::NN::ReLU.new(inplace: true),
          Torch::NN::Linear.new(4096, num_classes)
        ]
        @classifier = Torch::NN::Sequential.new(*classifier_layers)
      end

      # Runs the features, pooling and classifier stages in order, flattening
      # from dimension 1 before the classifier.
      #
      # @param x [Torch::Tensor] input batch (the first convolution takes 3
      #   input channels)
      # @return [Torch::Tensor] output of the classifier
      def forward(x)
        pooled = @avgpool.call(@features.call(x))
        @classifier.call(Torch.flatten(pooled, 1))
      end
    end
  end
end