torchvision 0.1.1 → 0.2.2

Sign up to get free protection for your applications and to get access to all the features.
Files changed (47) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +35 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +133 -7
  5. data/lib/torchvision.rb +39 -0
  6. data/lib/torchvision/datasets/cifar10.rb +117 -0
  7. data/lib/torchvision/datasets/cifar100.rb +41 -0
  8. data/lib/torchvision/datasets/dataset_folder.rb +91 -0
  9. data/lib/torchvision/datasets/fashion_mnist.rb +30 -0
  10. data/lib/torchvision/datasets/image_folder.rb +12 -0
  11. data/lib/torchvision/datasets/kmnist.rb +30 -0
  12. data/lib/torchvision/datasets/mnist.rb +47 -75
  13. data/lib/torchvision/datasets/vision_dataset.rb +67 -0
  14. data/lib/torchvision/models/alexnet.rb +42 -0
  15. data/lib/torchvision/models/basic_block.rb +46 -0
  16. data/lib/torchvision/models/bottleneck.rb +47 -0
  17. data/lib/torchvision/models/resnet.rb +129 -0
  18. data/lib/torchvision/models/resnet101.rb +9 -0
  19. data/lib/torchvision/models/resnet152.rb +9 -0
  20. data/lib/torchvision/models/resnet18.rb +9 -0
  21. data/lib/torchvision/models/resnet34.rb +9 -0
  22. data/lib/torchvision/models/resnet50.rb +9 -0
  23. data/lib/torchvision/models/resnext101_32x8d.rb +11 -0
  24. data/lib/torchvision/models/resnext50_32x4d.rb +11 -0
  25. data/lib/torchvision/models/vgg.rb +93 -0
  26. data/lib/torchvision/models/vgg11.rb +9 -0
  27. data/lib/torchvision/models/vgg11_bn.rb +9 -0
  28. data/lib/torchvision/models/vgg13.rb +9 -0
  29. data/lib/torchvision/models/vgg13_bn.rb +9 -0
  30. data/lib/torchvision/models/vgg16.rb +9 -0
  31. data/lib/torchvision/models/vgg16_bn.rb +9 -0
  32. data/lib/torchvision/models/vgg19.rb +9 -0
  33. data/lib/torchvision/models/vgg19_bn.rb +9 -0
  34. data/lib/torchvision/models/wide_resnet101_2.rb +10 -0
  35. data/lib/torchvision/models/wide_resnet50_2.rb +10 -0
  36. data/lib/torchvision/transforms/center_crop.rb +13 -0
  37. data/lib/torchvision/transforms/compose.rb +2 -2
  38. data/lib/torchvision/transforms/functional.rb +142 -8
  39. data/lib/torchvision/transforms/normalize.rb +2 -2
  40. data/lib/torchvision/transforms/random_horizontal_flip.rb +18 -0
  41. data/lib/torchvision/transforms/random_resized_crop.rb +70 -0
  42. data/lib/torchvision/transforms/random_vertical_flip.rb +18 -0
  43. data/lib/torchvision/transforms/resize.rb +13 -0
  44. data/lib/torchvision/transforms/to_tensor.rb +2 -2
  45. data/lib/torchvision/utils.rb +118 -0
  46. data/lib/torchvision/version.rb +1 -1
  47. metadata +51 -44
@@ -0,0 +1,91 @@
1
module TorchVision
  module Datasets
    # Generic dataset whose samples are stored in one subdirectory per class:
    #
    #   root/class_a/xxx.ext
    #   root/class_a/nested/yyy.ext
    #   root/class_b/zzz.ext
    #
    # Class names are the sorted names of the subdirectories of +root+; each is
    # mapped to its position in that sorted list and that index is the target.
    class DatasetFolder < VisionDataset
      attr_reader :classes

      # root             - directory containing one subdirectory per class
      # extensions       - array of allowed file suffixes, compared against the
      #                    lowercased path (so they should be given in lowercase)
      # transform        - optional callable applied to each loaded sample
      # target_transform - optional callable applied to each class index
      # is_valid_file    - optional callable receiving a path; truthy keeps the file
      #
      # Exactly one of +extensions+ / +is_valid_file+ must be given, otherwise
      # ArgumentError is raised. Raises RuntimeError when no sample is found.
      def initialize(root, extensions: nil, transform: nil, target_transform: nil, is_valid_file: nil)
        super(root, transform: transform, target_transform: target_transform)
        classes, class_to_idx = find_classes(@root)
        samples = make_dataset(@root, class_to_idx, extensions, is_valid_file)
        if samples.empty?
          msg = "Found 0 files in subfolders of: #{@root}\n"
          msg += "Supported extensions are: #{extensions.join(",")}" unless extensions.nil?
          raise RuntimeError, msg
        end

        # Samples are loaded from disk on access, not up front.
        @loader = ->(path) { Vips::Image.new_from_file(path) }
        @extensions = extensions

        @classes = classes
        @class_to_idx = class_to_idx
        @samples = samples
        @targets = samples.map { |_path, class_index| class_index }
      end

      # Loads the sample at +index+ and returns [sample, target] after applying
      # the optional transforms.
      def [](index)
        path, target = @samples[index]
        sample = @loader.call(path)
        sample = @transform.call(sample) if @transform
        target = @target_transform.call(target) if @target_transform

        [sample, target]
      end

      # Number of samples found under the root directory.
      def size
        @samples.size
      end

      private

      # Returns [sorted class names, {class name => index}] for the
      # subdirectories of +dir+.
      def find_classes(dir)
        classes = Dir.children(dir).select { |d| File.directory?(File.join(dir, d)) }
        classes.sort!
        class_to_idx = classes.map.with_index.to_h
        [classes, class_to_idx]
      end

      # True when the lowercased +filename+ ends with one of +extensions+.
      def has_file_allowed_extension(filename, extensions)
        filename = filename.downcase
        extensions.any? { |ext| filename.end_with?(ext) }
      end

      # Builds the list of [path, class_index] pairs, walking every class
      # directory recursively in sorted order.
      def make_dataset(directory, class_to_idx, extensions, is_valid_file)
        directory = File.expand_path(directory)
        # Reject both "neither given" and "both given".
        if extensions.nil? == is_valid_file.nil?
          raise ArgumentError, "Both extensions and is_valid_file cannot be None or not None at the same time"
        end

        is_valid_file ||= ->(path) { has_file_allowed_extension(path, extensions) }

        class_to_idx.keys.sort.each_with_object([]) do |target_class, instances|
          target_dir = File.join(directory, target_class)
          next unless File.directory?(target_dir)

          class_index = class_to_idx[target_class]
          # A bare "**" pattern is equivalent to "*" and would only list the
          # immediate children; "**/*" also descends into nested directories.
          Dir.glob("**/*", base: target_dir).sort.each do |fname|
            path = File.join(target_dir, fname)
            # Directories are matched by the glob too; only real files can be samples.
            next unless File.file?(path)

            instances << [path, class_index] if is_valid_file.call(path)
          end
        end
      end
    end
  end
end
@@ -0,0 +1,30 @@
1
module TorchVision
  module Datasets
    # Fashion-MNIST: reuses the MNIST implementation and only swaps the
    # download locations and checksums.
    class FashionMNIST < MNIST
      # https://github.com/zalandoresearch/fashion-mnist

      private

      # Download descriptors as an array of { url:, sha256: } hashes, in the
      # order: train images, train labels, test images, test labels.
      def resources
        base_url = "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com"
        checksums = {
          "train-images-idx3-ubyte.gz" => "3aede38d61863908ad78613f6a32ed271626dd12800ba2636569512369268a84",
          "train-labels-idx1-ubyte.gz" => "a04f17134ac03560a47e3764e11b92fc97de4d1bfaf8ba1a3aa29af54cc90845",
          "t10k-images-idx3-ubyte.gz" => "346e55b948d973a97e58d2351dde16a484bd415d4595297633bb08f03db6a073",
          "t10k-labels-idx1-ubyte.gz" => "67da17c76eaffca5446c3361aaab5c3cd6d1c2608764d35dfb1850b086bf8dd5"
        }

        checksums.map do |file, digest|
          { url: "#{base_url}/#{file}", sha256: digest }
        end
      end
    end
  end
end
@@ -0,0 +1,12 @@
1
module TorchVision
  module Datasets
    # DatasetFolder preconfigured for common image file extensions.
    class ImageFolder < DatasetFolder
      IMG_EXTENSIONS = [".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp"]

      # root             - directory containing one subdirectory per class
      # transform        - optional callable applied to each loaded image
      # target_transform - optional callable applied to each class index
      # is_valid_file    - optional callable receiving a path; when given it
      #                    replaces the IMG_EXTENSIONS filter
      def initialize(root, transform: nil, target_transform: nil, is_valid_file: nil)
        # DatasetFolder raises ArgumentError when both extensions and
        # is_valid_file are supplied, so the default extension list is only
        # passed along when the caller did not provide a predicate.
        extensions = is_valid_file.nil? ? IMG_EXTENSIONS : nil
        super(root, extensions: extensions, transform: transform, target_transform: target_transform, is_valid_file: is_valid_file)
        @imgs = @samples
      end
    end
  end
end
@@ -0,0 +1,30 @@
1
module TorchVision
  module Datasets
    # Kuzushiji-MNIST: reuses the MNIST implementation and only swaps the
    # download locations and checksums.
    class KMNIST < MNIST
      # https://github.com/rois-codh/kmnist

      private

      # Download descriptors as an array of { url:, sha256: } hashes, in the
      # order: train images, train labels, test images, test labels.
      def resources
        base_url = "http://codh.rois.ac.jp/kmnist/dataset/kmnist"
        checksums = {
          "train-images-idx3-ubyte.gz" => "51467d22d8cc72929e2a028a0428f2086b092bb31cfb79c69cc0a90ce135fde4",
          "train-labels-idx1-ubyte.gz" => "e38f9ebcd0f3ebcdec7fc8eabdcdaef93bb0df8ea12bee65224341c8183d8e17",
          "t10k-images-idx3-ubyte.gz" => "edd7a857845ad6bb1d0ba43fe7e794d164fe2dce499a1694695a792adfac43c5",
          "t10k-labels-idx1-ubyte.gz" => "20bb9a0ef54c7db3efc55a92eef5582c109615df22683c380526788f98e42a1c"
        }

        checksums.map do |file, digest|
          { url: "#{base_url}/#{file}", sha256: digest }
        end
      end
    end
  end
end
@@ -1,31 +1,10 @@
1
1
  module TorchVision
2
2
  module Datasets
3
- class MNIST
4
- RESOURCES = [
5
- {
6
- url: "http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz",
7
- sha256: "440fcabf73cc546fa21475e81ea370265605f56be210a4024d2ca8f203523609"
8
- },
9
- {
10
- url: "http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz",
11
- sha256: "3552534a0a558bbed6aed32b30c495cca23d567ec52cac8be1a0730e8010255c"
12
- },
13
- {
14
- url: "http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz",
15
- sha256: "8d422c7b0a1c1c79245a5bcf07fe86e33eeafee792b84584aec276f5a2dbc4e6"
16
- },
17
- {
18
- url: "http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz",
19
- sha256: "f7ae60f92e00ec6debd23a6088c31dbd2371eca3ffa0defaefb259924204aec6"
20
- }
21
- ]
22
- TRAINING_FILE = "training.pt"
23
- TEST_FILE = "test.pt"
24
-
25
- def initialize(root, train: true, download: false, transform: nil)
26
- @root = root
3
+ class MNIST < VisionDataset
4
+ # http://yann.lecun.com/exdb/mnist/
5
+ def initialize(root, train: true, download: false, transform: nil, target_transform: nil)
6
+ super(root, transform: transform, target_transform: target_transform)
27
7
  @train = train
28
- @transform = transform
29
8
 
30
9
  self.download if download
31
10
 
@@ -33,34 +12,37 @@ module TorchVision
33
12
  raise Error, "Dataset not found. You can use download: true to download it"
34
13
  end
35
14
 
36
- data_file = @train ? TRAINING_FILE : TEST_FILE
15
+ data_file = @train ? training_file : test_file
37
16
  @data, @targets = Torch.load(File.join(processed_folder, data_file))
38
17
  end
39
18
 
40
19
  def size
41
- @data.size[0]
20
+ @data.size(0)
42
21
  end
43
22
 
44
23
  def [](index)
45
- img = @data[index]
24
+ img, target = @data[index], @targets[index].item
25
+
26
+ img = Utils.image_from_array(img)
27
+
46
28
  img = @transform.call(img) if @transform
47
29
 
48
- target = @targets[index].item
30
+ target = @target_transform.call(target) if @target_transform
49
31
 
50
32
  [img, target]
51
33
  end
52
34
 
53
35
  def raw_folder
54
- File.join(@root, "MNIST", "raw")
36
+ File.join(@root, self.class.name.split("::").last, "raw")
55
37
  end
56
38
 
57
39
  def processed_folder
58
- File.join(@root, "MNIST", "processed")
40
+ File.join(@root, self.class.name.split("::").last, "processed")
59
41
  end
60
42
 
61
43
  def check_exists
62
- File.exist?(File.join(processed_folder, TRAINING_FILE)) &&
63
- File.exist?(File.join(processed_folder, TEST_FILE))
44
+ File.exist?(File.join(processed_folder, training_file)) &&
45
+ File.exist?(File.join(processed_folder, test_file))
64
46
  end
65
47
 
66
48
  def download
@@ -69,7 +51,7 @@ module TorchVision
69
51
  FileUtils.mkdir_p(raw_folder)
70
52
  FileUtils.mkdir_p(processed_folder)
71
53
 
72
- RESOURCES.each do |resource|
54
+ resources.each do |resource|
73
55
  filename = resource[:url].split("/").last
74
56
  download_file(resource[:url], download_root: raw_folder, filename: filename, sha256: resource[:sha256])
75
57
  end
@@ -85,14 +67,43 @@ module TorchVision
85
67
  unpack_mnist("t10k-labels-idx1-ubyte", 8, [10000])
86
68
  ]
87
69
 
88
- Torch.save(training_set, File.join(processed_folder, TRAINING_FILE))
89
- Torch.save(test_set, File.join(processed_folder, TEST_FILE))
70
+ Torch.save(training_set, File.join(processed_folder, training_file))
71
+ Torch.save(test_set, File.join(processed_folder, test_file))
90
72
 
91
73
  puts "Done!"
92
74
  end
93
75
 
94
76
  private
95
77
 
78
+ def resources
79
+ [
80
+ {
81
+ url: "http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz",
82
+ sha256: "440fcabf73cc546fa21475e81ea370265605f56be210a4024d2ca8f203523609"
83
+ },
84
+ {
85
+ url: "http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz",
86
+ sha256: "3552534a0a558bbed6aed32b30c495cca23d567ec52cac8be1a0730e8010255c"
87
+ },
88
+ {
89
+ url: "http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz",
90
+ sha256: "8d422c7b0a1c1c79245a5bcf07fe86e33eeafee792b84584aec276f5a2dbc4e6"
91
+ },
92
+ {
93
+ url: "http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz",
94
+ sha256: "f7ae60f92e00ec6debd23a6088c31dbd2371eca3ffa0defaefb259924204aec6"
95
+ }
96
+ ]
97
+ end
98
+
99
+ def training_file
100
+ "training.pt"
101
+ end
102
+
103
+ def test_file
104
+ "test.pt"
105
+ end
106
+
96
107
  def unpack_mnist(path, offset, shape)
97
108
  path = File.join(raw_folder, "#{path}.gz")
98
109
  File.open(path, "rb") do |f|
@@ -101,45 +112,6 @@ module TorchVision
101
112
  Torch.tensor(Numo::UInt8.from_string(gz.read, shape))
102
113
  end
103
114
  end
104
-
105
- def download_file(url, download_root:, filename:, sha256:)
106
- FileUtils.mkdir_p(download_root)
107
-
108
- dest = File.join(download_root, filename)
109
- return dest if File.exist?(dest)
110
-
111
- temp_path = "#{Dir.tmpdir}/#{Time.now.to_f}" # TODO better name
112
-
113
- digest = Digest::SHA256.new
114
-
115
- uri = URI(url)
116
-
117
- # Net::HTTP automatically adds Accept-Encoding for compression
118
- # of response bodies and automatically decompresses gzip
119
- # and deflateresponses unless a Range header was sent.
120
- # https://ruby-doc.org/stdlib-2.6.4/libdoc/net/http/rdoc/Net/HTTP.html
121
- Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http|
122
- request = Net::HTTP::Get.new(uri)
123
-
124
- puts "Downloading #{url}..."
125
- File.open(temp_path, "wb") do |f|
126
- http.request(request) do |response|
127
- response.read_body do |chunk|
128
- f.write(chunk)
129
- digest.update(chunk)
130
- end
131
- end
132
- end
133
- end
134
-
135
- if digest.hexdigest != sha256
136
- raise Error, "Bad hash: #{digest.hexdigest}"
137
- end
138
-
139
- FileUtils.mv(temp_path, dest)
140
-
141
- dest
142
- end
143
115
  end
144
116
  end
145
117
  end
@@ -0,0 +1,67 @@
1
module TorchVision
  module Datasets
    # Common base for the datasets in this module: stores the root path and the
    # transform callables, and provides a checksum-verified download helper.
    class VisionDataset < Torch::Utils::Data::Dataset
      attr_reader :data, :targets

      # root             - dataset root directory
      # transforms       - combined transform; mutually exclusive with the two below
      # transform        - optional callable for samples
      # target_transform - optional callable for targets
      #
      # Raises ArgumentError when +transforms+ is combined with +transform+ or
      # +target_transform+.
      def initialize(root, transforms: nil, transform: nil, target_transform: nil)
        @root = root

        has_transforms = !transforms.nil?
        has_separate_transform = !transform.nil? || !target_transform.nil?
        if has_transforms && has_separate_transform
          raise ArgumentError, "Only transforms or transform/target_transform can be passed as argument"
        end

        @transform = transform
        @target_transform = target_transform

        # Not implemented yet; when has_separate_transform is true this would be:
        # transforms = StandardTransform.new(transform, target_transform)
        @transforms = transforms
      end

      private

      # Downloads +url+ to +download_root+/+filename+ unless that file already
      # exists, verifying the SHA-256 checksum before moving it into place.
      #
      # Returns the destination path. Raises Error on a non-success HTTP status
      # or a checksum mismatch; the temporary file is removed in either case.
      def download_file(url, download_root:, filename:, sha256:)
        FileUtils.mkdir_p(download_root)

        dest = File.join(download_root, filename)
        return dest if File.exist?(dest)

        temp_path = "#{Dir.tmpdir}/#{Time.now.to_f}" # TODO better name

        uri = URI(url)

        begin
          # Net::HTTP automatically adds Accept-Encoding for compression
          # of response bodies and automatically decompresses gzip
          # and deflate responses unless a Range header was sent.
          # https://ruby-doc.org/stdlib-2.6.4/libdoc/net/http/rdoc/Net/HTTP.html
          Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http|
            request = Net::HTTP::Get.new(uri)

            puts "Downloading #{url}..."
            File.open(temp_path, "wb") do |f|
              http.request(request) do |response|
                # Fail with a meaningful message instead of hashing an error page.
                unless response.is_a?(Net::HTTPSuccess)
                  raise Error, "Download failed (HTTP #{response.code}): #{url}"
                end

                response.read_body do |chunk|
                  f.write(chunk)
                end
              end
            end
          end

          unless check_integrity(temp_path, sha256)
            raise Error, "Bad hash"
          end

          FileUtils.mv(temp_path, dest)
        ensure
          # No-op after a successful move; otherwise removes the partial or
          # corrupt download so failed attempts do not pile up in the tmpdir.
          FileUtils.rm_f(temp_path)
        end

        dest
      end

      # True when +path+ exists and its SHA-256 hex digest equals +sha256+.
      def check_integrity(path, sha256)
        File.exist?(path) && Digest::SHA256.file(path).hexdigest == sha256
      end
    end
  end
end
@@ -0,0 +1,42 @@
1
module TorchVision
  module Models
    # AlexNet image classifier: a convolutional feature extractor, an adaptive
    # average pool to 6x6, and a three-layer fully connected classifier.
    class AlexNet < Torch::NN::Module
      # num_classes - size of the final linear layer's output (default 1000)
      def initialize(num_classes: 1000)
        super()
        nn = Torch::NN

        feature_layers = [
          nn::Conv2d.new(3, 64, 11, stride: 4, padding: 2),
          nn::ReLU.new(inplace: true),
          nn::MaxPool2d.new(3, stride: 2),
          nn::Conv2d.new(64, 192, 5, padding: 2),
          nn::ReLU.new(inplace: true),
          nn::MaxPool2d.new(3, stride: 2),
          nn::Conv2d.new(192, 384, 3, padding: 1),
          nn::ReLU.new(inplace: true),
          nn::Conv2d.new(384, 256, 3, padding: 1),
          nn::ReLU.new(inplace: true),
          nn::Conv2d.new(256, 256, 3, padding: 1),
          nn::ReLU.new(inplace: true),
          nn::MaxPool2d.new(3, stride: 2)
        ]
        @features = nn::Sequential.new(*feature_layers)

        @avgpool = nn::AdaptiveAvgPool2d.new([6, 6])

        # 256 channels * 6 * 6 pooled positions feed the first linear layer.
        classifier_layers = [
          nn::Dropout.new,
          nn::Linear.new(256 * 6 * 6, 4096),
          nn::ReLU.new(inplace: true),
          nn::Dropout.new,
          nn::Linear.new(4096, 4096),
          nn::ReLU.new(inplace: true),
          nn::Linear.new(4096, num_classes)
        ]
        @classifier = nn::Sequential.new(*classifier_layers)
      end

      # Runs features -> avgpool -> flatten (from dimension 1) -> classifier
      # and returns the classifier output.
      def forward(x)
        pooled = @avgpool.call(@features.call(x))
        @classifier.call(Torch.flatten(pooled, 1))
      end
    end
  end
end