torchvision 0.1.0 → 0.2.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +35 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +133 -5
  5. data/lib/torchvision.rb +40 -1
  6. data/lib/torchvision/datasets/cifar10.rb +117 -0
  7. data/lib/torchvision/datasets/cifar100.rb +41 -0
  8. data/lib/torchvision/datasets/dataset_folder.rb +91 -0
  9. data/lib/torchvision/datasets/fashion_mnist.rb +30 -0
  10. data/lib/torchvision/datasets/image_folder.rb +12 -0
  11. data/lib/torchvision/datasets/kmnist.rb +30 -0
  12. data/lib/torchvision/datasets/mnist.rb +47 -76
  13. data/lib/torchvision/datasets/vision_dataset.rb +67 -0
  14. data/lib/torchvision/models/alexnet.rb +42 -0
  15. data/lib/torchvision/models/basic_block.rb +46 -0
  16. data/lib/torchvision/models/bottleneck.rb +47 -0
  17. data/lib/torchvision/models/resnet.rb +129 -0
  18. data/lib/torchvision/models/resnet101.rb +9 -0
  19. data/lib/torchvision/models/resnet152.rb +9 -0
  20. data/lib/torchvision/models/resnet18.rb +9 -0
  21. data/lib/torchvision/models/resnet34.rb +9 -0
  22. data/lib/torchvision/models/resnet50.rb +9 -0
  23. data/lib/torchvision/models/resnext101_32x8d.rb +11 -0
  24. data/lib/torchvision/models/resnext50_32x4d.rb +11 -0
  25. data/lib/torchvision/models/vgg.rb +93 -0
  26. data/lib/torchvision/models/vgg11.rb +9 -0
  27. data/lib/torchvision/models/vgg11_bn.rb +9 -0
  28. data/lib/torchvision/models/vgg13.rb +9 -0
  29. data/lib/torchvision/models/vgg13_bn.rb +9 -0
  30. data/lib/torchvision/models/vgg16.rb +9 -0
  31. data/lib/torchvision/models/vgg16_bn.rb +9 -0
  32. data/lib/torchvision/models/vgg19.rb +9 -0
  33. data/lib/torchvision/models/vgg19_bn.rb +9 -0
  34. data/lib/torchvision/models/wide_resnet101_2.rb +10 -0
  35. data/lib/torchvision/models/wide_resnet50_2.rb +10 -0
  36. data/lib/torchvision/transforms/center_crop.rb +13 -0
  37. data/lib/torchvision/transforms/compose.rb +2 -2
  38. data/lib/torchvision/transforms/functional.rb +142 -7
  39. data/lib/torchvision/transforms/normalize.rb +2 -2
  40. data/lib/torchvision/transforms/random_horizontal_flip.rb +18 -0
  41. data/lib/torchvision/transforms/random_resized_crop.rb +70 -0
  42. data/lib/torchvision/transforms/random_vertical_flip.rb +18 -0
  43. data/lib/torchvision/transforms/resize.rb +13 -0
  44. data/lib/torchvision/transforms/to_tensor.rb +2 -2
  45. data/lib/torchvision/utils.rb +120 -0
  46. data/lib/torchvision/version.rb +1 -1
  47. metadata +50 -57
@@ -0,0 +1,9 @@
module TorchVision
  module Models
    # Factory namespace for the "vgg16" model.
    module VGG16
      class << self
        # Builds the model through VGG.make_model using the name "vgg16",
        # configuration "D" and a +false+ third flag (presumably batch
        # normalization -- confirm against VGG.make_model).
        # All keyword arguments are forwarded unchanged.
        def new(**options)
          VGG.make_model("vgg16", "D", false, **options)
        end
      end
    end
  end
end
@@ -0,0 +1,9 @@
module TorchVision
  module Models
    # Factory namespace for the "vgg16_bn" model.
    module VGG16BN
      class << self
        # Builds the model through VGG.make_model using the name "vgg16_bn",
        # configuration "D" and a +true+ third flag (presumably batch
        # normalization -- confirm against VGG.make_model).
        # All keyword arguments are forwarded unchanged.
        def new(**options)
          VGG.make_model("vgg16_bn", "D", true, **options)
        end
      end
    end
  end
end
@@ -0,0 +1,9 @@
module TorchVision
  module Models
    # Factory namespace for the "vgg19" model.
    module VGG19
      class << self
        # Builds the model through VGG.make_model using the name "vgg19",
        # configuration "E" and a +false+ third flag (presumably batch
        # normalization -- confirm against VGG.make_model).
        # All keyword arguments are forwarded unchanged.
        def new(**options)
          VGG.make_model("vgg19", "E", false, **options)
        end
      end
    end
  end
end
@@ -0,0 +1,9 @@
module TorchVision
  module Models
    # Factory namespace for the "vgg19_bn" model.
    module VGG19BN
      class << self
        # Builds the model through VGG.make_model using the name "vgg19_bn",
        # configuration "E" and a +true+ third flag (presumably batch
        # normalization -- confirm against VGG.make_model).
        # All keyword arguments are forwarded unchanged.
        def new(**options)
          VGG.make_model("vgg19_bn", "E", true, **options)
        end
      end
    end
  end
end
@@ -0,0 +1,10 @@
module TorchVision
  module Models
    # Factory namespace for the "wide_resnet101_2" model.
    module WideResNet101_2
      class << self
        # Builds the model through ResNet.make_model with Bottleneck blocks
        # and the layer layout [3, 4, 23, 3]. Any caller-supplied
        # :width_per_group is overridden with 64 * 2; the remaining keyword
        # arguments are forwarded unchanged.
        def new(**options)
          settings = options.merge(width_per_group: 64 * 2)
          ResNet.make_model("wide_resnet101_2", Bottleneck, [3, 4, 23, 3], **settings)
        end
      end
    end
  end
end
@@ -0,0 +1,10 @@
module TorchVision
  module Models
    # Factory namespace for the "wide_resnet50_2" model.
    module WideResNet50_2
      class << self
        # Builds the model through ResNet.make_model with Bottleneck blocks
        # and the layer layout [3, 4, 6, 3]. Any caller-supplied
        # :width_per_group is overridden with 64 * 2; the remaining keyword
        # arguments are forwarded unchanged.
        def new(**options)
          settings = options.merge(width_per_group: 64 * 2)
          ResNet.make_model("wide_resnet50_2", Bottleneck, [3, 4, 6, 3], **settings)
        end
      end
    end
  end
end
@@ -0,0 +1,13 @@
module TorchVision
  module Transforms
    # Transform that crops the central region of an image.
    class CenterCrop < Torch::NN::Module
      # @param size desired output size; passed verbatim to F.center_crop,
      #   which accepts an Integer or an Array
      def initialize(size)
        # Run the Torch::NN::Module initializer so the module's own state is
        # set up, matching the other transforms in this package
        # (e.g. RandomHorizontalFlip).
        super()
        @size = size
      end

      # Crops +img+ around its center via F.center_crop.
      def forward(img)
        F.center_crop(img, @size)
      end
    end
  end
end
@@ -1,11 +1,11 @@
1
1
  module TorchVision
2
2
  module Transforms
3
- class Compose
3
+ class Compose < Torch::NN::Module
4
4
  def initialize(transforms)
5
5
  @transforms = transforms
6
6
  end
7
7
 
8
- def call(img)
8
+ def forward(img)
9
9
  @transforms.each do |t|
10
10
  img = t.call(img)
11
11
  end
@@ -22,19 +22,154 @@ module TorchVision
22
22
  if std.to_a.any? { |v| v == 0 }
23
23
  raise ArgumentError, "std evaluated to zero after conversion to #{dtype}, leading to division by zero."
24
24
  end
25
- # if mean.ndim == 1
26
- # raise Torch::NotImplementedYet
27
- # end
28
- # if std.ndim == 1
29
- # raise Torch::NotImplementedYet
30
- # end
25
+ if mean.ndim == 1
26
+ mean = mean[0...mean.size(0), nil, nil]
27
+ end
28
+ if std.ndim == 1
29
+ std = std[0...std.size(0), nil, nil]
30
+ end
31
31
  tensor.sub!(mean).div!(std)
32
32
  tensor
33
33
  end
34
34
 
# Resizes a Vips::Image.
#
# When +size+ is an Integer the shorter edge is scaled to +size+ and the
# other edge is scaled by the same ratio (truncated to an integer); the
# image is returned untouched if its shorter edge already equals +size+.
# Otherwise +size+ is indexed as a pair and forwarded to thumbnail_image
# with size: :force.
#
# Raises RuntimeError when +img+ is not a Vips::Image.
def resize(img, size)
  raise "img should be Vips::Image. Got #{img.class.name}" unless img.is_a?(Vips::Image)

  # Explicit pair: forward both values and force the exact dimensions.
  return img.thumbnail_image(size[0], height: size[1], size: :force) unless size.is_a?(Integer)

  width, height = img.size
  already_sized = (width <= height && width == size) || (height <= width && height == size)
  return img if already_sized

  if width < height
    target_w = size
    target_h = (size * height / width).to_i
  else
    target_h = size
    target_w = (size * width / height).to_i
  end
  img.thumbnail_image(target_w, height: target_h)
end
56
+
# TODO improve
#
# Converts a Numo::NArray or Vips::Image into a channel-first tensor.
#
# Numo input must have 2 or 3 dimensions; a 2-D array gets a trailing
# axis of length 1, then the axes are reordered with transpose(2, 0, 1).
# A :uint8 result is converted to float and divided by 255, any other
# dtype is returned as-is.
#
# Vips input must have the :uchar format; its memory is viewed as
# (height, width, bands), permuted to bands-first, converted to float and
# divided by 255.
#
# Raises ArgumentError for an unsupported class or dimension count, and
# Error for an unsupported Vips format.
def to_tensor(pic)
  if !pic.is_a?(Numo::NArray) && !pic.is_a?(Vips::Image)
    raise ArgumentError, "pic should be Vips::Image or Numo::NArray. Got #{pic.class.name}"
  end

  if pic.is_a?(Numo::NArray)
    # Report the dimension count with the same accessor the guard uses
    # (ndim), so the intended ArgumentError is what the caller sees.
    unless [2, 3].include?(pic.ndim)
      raise ArgumentError, "pic should be 2/3 dimensional. Got #{pic.ndim} dimensions."
    end

    pic = pic.reshape(*pic.shape, 1) if pic.ndim == 2

    img = Torch.from_numo(pic.transpose(2, 0, 1))
    return img.dtype == :uint8 ? img.float.div(255) : img
  end

  case pic.format
  when :uchar
    img = Torch::ByteTensor.new(Torch::ByteStorage.from_buffer(pic.write_to_memory))
  else
    raise Error, "Format not supported yet: #{pic.format}"
  end

  img = img.view(pic.height, pic.width, pic.bands)
  # put it from HWC to CHW format
  img = img.permute([2, 0, 1]).contiguous
  img.float.div(255)
end
92
+
# Flips an image horizontally.
# Tensors are validated with assert_image_tensor and flipped along their
# last axis; any other object receives flip(:horizontal).
def hflip(img)
  return img.flip(:horizontal) unless img.is_a?(Torch::Tensor)

  assert_image_tensor(img)
  img.flip(-1)
end
101
+
# Flips an image vertically.
# Tensors are validated with assert_image_tensor and flipped along their
# second-to-last axis; any other object receives flip(:vertical).
def vflip(img)
  return img.flip(:vertical) unless img.is_a?(Torch::Tensor)

  assert_image_tensor(img)
  img.flip(-2)
end
110
+
# Extracts the region starting at (+top+, +left+) with the given +height+
# and +width+.
# Tensors are validated and sliced on their last two axes, keeping every
# leading axis whole; any other object receives
# crop(left, top, width, height).
def crop(img, top, left, height, width)
  return img.crop(left, top, width, height) unless img.is_a?(Torch::Tensor)

  assert_image_tensor(img)
  # One `true` per leading axis selects that axis in full.
  leading = Array.new(img.dim - 2, true)
  img[*leading, top...(top + height), left...(left + width)]
end
120
+
# Crops the central region of +img+.
#
# +output_size+ may be an Integer or a one-element Array (both expanded to
# a square) or a pair read as [crop_height, crop_width]. When the requested
# crop exceeds the image in either direction the image is first padded with
# zeros; if the padded image already has the requested size it is returned
# directly.
#
# NOTE(review): `pad` is not defined in the visible part of this file --
# confirm it exists before relying on the padding branch.
def center_crop(img, output_size)
  case output_size
  when Integer
    output_size = [output_size.to_i, output_size.to_i]
  when Array
    output_size = [output_size[0], output_size[0]] if output_size.length == 1
  end

  image_width, image_height = image_size(img)
  crop_height, crop_width = output_size

  too_wide = crop_width > image_width
  too_tall = crop_height > image_height
  if too_wide || too_tall
    # left, top, right, bottom; the odd pixel goes to the right/bottom side
    padding_ltrb = [
      too_wide ? (crop_width - image_width).div(2) : 0,
      too_tall ? (crop_height - image_height).div(2) : 0,
      too_wide ? (crop_width - image_width + 1).div(2) : 0,
      too_tall ? (crop_height - image_height + 1).div(2) : 0
    ]
    # TODO
    img = pad(img, padding_ltrb, fill: 0)
    image_width, image_height = image_size(img)
    return img if crop_width == image_width && crop_height == image_height
  end

  crop_top = ((image_height - crop_height) / 2.0).round
  crop_left = ((image_width - crop_width) / 2.0).round
  crop(img, crop_top, crop_left, crop_height, crop_width)
end
150
+
# TODO interpolation
#
# Crops +img+ to the given box and then resizes the result to +size+.
def resized_crop(img, top, left, height, width, size)
  cropped = crop(img, top, left, height, width)
  resize(cropped, size) # , interpolation)
end
157
+
158
+ private
159
+
# Returns [width, height] for an image.
# For tensors (after validation) these are the last and second-to-last
# entries of the shape; other objects are asked for width and height.
def image_size(img)
  return [img.width, img.height] unless img.is_a?(Torch::Tensor)

  assert_image_tensor(img)
  dims = img.shape
  [dims[-1], dims[-2]]
end
168
+
# Raises TypeError unless +img+ has at least two dimensions.
def assert_image_tensor(img)
  raise TypeError, "Tensor is not a torch image." if img.ndim < 2
end
39
174
  end
40
175
  end
@@ -1,13 +1,13 @@
1
1
  module TorchVision
2
2
  module Transforms
# Transform that normalizes a tensor with the given mean and std.
class Normalize < Torch::NN::Module
  # @param mean passed verbatim to F.normalize
  # @param std passed verbatim to F.normalize
  # @param inplace forwarded as the inplace: keyword of F.normalize
  def initialize(mean, std, inplace: false)
    # Run the Torch::NN::Module initializer so the module's own state is
    # set up, matching the other transforms in this package
    # (e.g. RandomHorizontalFlip).
    super()
    @mean = mean
    @std = std
    @inplace = inplace
  end

  # Normalizes +tensor+ via F.normalize using the stored settings.
  def forward(tensor)
    F.normalize(tensor, @mean, @std, inplace: @inplace)
  end
end
@@ -0,0 +1,18 @@
module TorchVision
  module Transforms
    # Transform that horizontally flips its input with probability +p+.
    class RandomHorizontalFlip < Torch::NN::Module
      # @param p threshold compared against a single Torch.rand draw
      def initialize(p: 0.5)
        super()
        @p = p
      end

      # Draws one random number; returns F.hflip(img) when the draw is
      # below the threshold and +img+ itself otherwise.
      def forward(img)
        return img unless Torch.rand(1).item < @p

        F.hflip(img)
      end
    end
  end
end
@@ -0,0 +1,70 @@
module TorchVision
  module Transforms
    # Transform that crops a random region of an image and resizes it to a
    # fixed size.
    class RandomResizedCrop < Torch::NN::Module
      # @param size Integer (expanded to a square) or a two-element sequence
      # @param scale pair bounding the crop area as a fraction of the image area
      # @param ratio pair bounding the crop aspect ratio
      def initialize(size, scale: [0.08, 1.0], ratio: [3.0 / 4.0, 4.0 / 3.0])
        super()
        @size = setup_size(size, "Please provide only two dimensions (h, w) for size.")
        # @interpolation = interpolation
        @scale = scale
        @ratio = ratio
      end

      # Picks the crop box for +img+ and returns it as [top, left, height, width].
      #
      # Up to ten random candidates are tried: the area is drawn uniformly
      # from +scale+ and the aspect ratio uniformly in log space from
      # +ratio+. The first candidate that fits inside the image is placed
      # at a random offset. If none fits, a centered crop clamped to the
      # ratio bounds is used.
      def params(img, scale, ratio)
        width, height = F.send(:image_size, img)
        area = height * width

        log_ratio = Torch.log(Torch.tensor(ratio))
        10.times do
          target_area = area * Torch.empty(1).uniform!(scale[0], scale[1]).item
          aspect_ratio = Torch.exp(Torch.empty(1).uniform!(log_ratio[0], log_ratio[1])).item

          w = Math.sqrt(target_area * aspect_ratio).round
          h = Math.sqrt(target_area / aspect_ratio).round
          next unless 0 < w && w <= width && 0 < h && h <= height

          i = Torch.randint(0, height - h + 1, size: [1]).item
          j = Torch.randint(0, width - w + 1, size: [1]).item
          return i, j, h, w
        end

        # Fallback to central crop
        in_ratio = width.to_f / height.to_f
        w, h =
          if in_ratio < ratio.min
            [width, (width / ratio.min).round]
          elsif in_ratio > ratio.max
            [(height * ratio.max).round, height]
          else # whole image
            [width, height]
          end

        [(height - h).div(2), (width - w).div(2), h, w]
      end

      # Applies a freshly sampled crop to +img+ and resizes it to the
      # configured size.
      def forward(img)
        top, left, crop_h, crop_w = params(img, @scale, @ratio)
        F.resized_crop(img, top, left, crop_h, crop_w, @size) # , @interpolation)
      end

      private

      # Expands an Integer to [size, size]; otherwise requires exactly two
      # elements and returns +size+ unchanged. Raises ArgumentError with
      # +error_msg+ when the length is not 2.
      def setup_size(size, error_msg)
        return [size, size] if size.is_a?(Integer)
        raise ArgumentError, error_msg if size.length != 2

        size
      end
    end
  end
end
@@ -0,0 +1,18 @@
module TorchVision
  module Transforms
    # Transform that vertically flips its input with probability +p+.
    class RandomVerticalFlip < Torch::NN::Module
      # @param p threshold compared against a single Torch.rand draw
      def initialize(p: 0.5)
        super()
        @p = p
      end

      # Draws one random number; returns F.vflip(img) when the draw is
      # below the threshold and +img+ itself otherwise.
      def forward(img)
        return img unless Torch.rand(1).item < @p

        F.vflip(img)
      end
    end
  end
end
@@ -0,0 +1,13 @@
module TorchVision
  module Transforms
    # Transform that resizes an image to a configured size.
    class Resize < Torch::NN::Module
      # @param size target size; passed verbatim to F.resize, which accepts
      #   an Integer or an indexable pair
      def initialize(size)
        # Run the Torch::NN::Module initializer so the module's own state is
        # set up, matching the other transforms in this package
        # (e.g. RandomHorizontalFlip).
        super()
        @size = size
      end

      # Resizes +img+ via F.resize.
      def forward(img)
        F.resize(img, @size)
      end
    end
  end
end
@@ -1,7 +1,7 @@
1
1
  module TorchVision
2
2
  module Transforms
# Transform that converts its input to a tensor.
class ToTensor < Torch::NN::Module
  # Delegates the conversion of +pic+ to F.to_tensor and returns its result.
  def forward(pic)
    F.to_tensor(pic)
  end
end
@@ -0,0 +1,120 @@
module TorchVision
  # Helpers for arranging image tensors into grids and writing them to disk.
  module Utils
    class << self
      # Arranges a batch of images into a single grid tensor.
      #
      # @param tensor a tensor, or an Array of tensors that is stacked into
      #   a batch; 2-D and 3-D inputs are promoted to a one-image batch and
      #   single-channel images are repeated to three channels
      # @param nrow maximum number of images per grid row
      # @param padding number of pixels between images and around the grid
      # @param normalize when true, works on a clone and rescales values
      #   using +range+ (or the tensor's own min/max when +range+ is nil)
      # @param range optional [min, max] Array used for normalization
      # @param scale_each when true, each image of the batch is normalized
      #   separately
      # @param pad_value fill value of the grid background
      # @return the grid tensor, or the single image when the batch holds
      #   exactly one image
      # @raise [ArgumentError] when +tensor+ is neither a tensor nor an
      #   Array of tensors
      # @raise [RuntimeError] when +range+ is given but is not an Array
      def make_grid(tensor, nrow: 8, padding: 2, normalize: false, range: nil, scale_each: false, pad_value: 0)
        unless Torch.tensor?(tensor) || (tensor.is_a?(Array) && tensor.all? { |t| Torch.tensor?(t) })
          raise ArgumentError, "tensor or list of tensors expected, got #{tensor.class.name}"
        end

        # if list of tensors, convert to a 4D mini-batch Tensor
        if tensor.is_a?(Array)
          tensor = Torch.stack(tensor, dim: 0)
        end

        if tensor.dim == 2 # single image H x W
          tensor = tensor.unsqueeze(0)
        end
        if tensor.dim == 3 # single image
          if tensor.size(0) == 1 # if single-channel, convert to 3-channel
            tensor = Torch.cat([tensor, tensor, tensor], 0)
          end
          tensor = tensor.unsqueeze(0)
        end

        if tensor.dim == 4 && tensor.size(1) == 1 # single-channel images
          tensor = Torch.cat([tensor, tensor, tensor], 1)
        end

        if normalize
          tensor = tensor.clone # avoid modifying tensor in-place
          if !range.nil? && !range.is_a?(Array)
            raise "range has to be an array (min, max) if specified. min and max are numbers"
          end

          # Clamp to [min, max] and rescale in place; the 1e-5 keeps the
          # divisor non-zero when min == max.
          norm_ip = lambda do |img, min, max|
            img.clamp!(min, max)
            img.add!(-min).div!(max - min + 1e-5)
          end

          norm_range = lambda do |t, range|
            if !range.nil?
              norm_ip.call(t, range[0], range[1])
            else
              norm_ip.call(t, t.min.to_f, t.max.to_f)
            end
          end

          if scale_each
            tensor.each do |t| # loop over mini-batch dimension
              norm_range.call(t, range)
            end
          else
            norm_range.call(tensor, range)
          end
        end

        if tensor.size(0) == 1
          return tensor.squeeze(0)
        end

        # make the mini-batch of images into a grid
        nmaps = tensor.size(0)
        xmaps = [nrow, nmaps].min
        ymaps = (nmaps.to_f / xmaps).ceil
        height, width = (tensor.size(2) + padding), (tensor.size(3) + padding)
        num_channels = tensor.size(1)
        grid = tensor.new_full([num_channels, height * ymaps + padding, width * xmaps + padding], pad_value)
        k = 0
        ymaps.times do |y|
          xmaps.times do |x|
            break if k >= nmaps
            grid.narrow(1, y * height + padding, height - padding).narrow(2, x * width + padding, width - padding).copy!(tensor[k])
            k += 1
          end
        end
        grid
      end

      # Builds a grid with make_grid (same keyword arguments) and writes it
      # to +fp+. The grid is scaled to 0..255, reordered from channel-first
      # to height x width x channel, converted to uint8 and handed to
      # image_from_array.
      def save_image(tensor, fp, nrow: 8, padding: 2, normalize: false, range: nil, scale_each: false, pad_value: 0)
        grid = make_grid(tensor, nrow: nrow, padding: padding, pad_value: pad_value, normalize: normalize, range: range, scale_each: scale_each)
        # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
        ndarr = grid.mul(255).add!(0.5).clamp!(0, 255).permute(1, 2, 0).to("cpu", dtype: :uint8)
        im = image_from_array(ndarr)
        im.write_to_file(fp)
      end

      # private
      # Ruby-specific method
      # TODO use Numo when bridge available
      #
      # Wraps a uint8 array in a Vips::Image.
      #
      # The array is read as height x width (x bands): the first shape
      # entry is the number of rows, which is the layout save_image
      # produces with permute(1, 2, 0). A missing third entry means one band.
      #
      # @param array a :uint8 Torch::Tensor or a Numo::UInt8
      # @raise [RuntimeError] for other dtypes or classes
      def image_from_array(array)
        case array
        when Torch::Tensor
          # TODO support more dtypes
          raise "Type not supported yet: #{array.dtype}" unless array.dtype == :uint8

          array = array.contiguous unless array.contiguous?

          # Rows come first in the buffer, so shape is [height, width, ...].
          height, width = array.shape
          bands = array.shape[2] || 1
          data = FFI::Pointer.new(:uint8, array._data_ptr)
          # new_from_memory is handed a raw pointer; give it the byte count
          # of the tensor. The closure also keeps +array+ referenced.
          data.define_singleton_method(:bytesize) do
            array.numel * array.element_size
          end

          Vips::Image.new_from_memory(data, width, height, bands, :uchar)
        when Numo::NArray
          # TODO support more types
          raise "Type not supported yet: #{array.class.name}" unless array.is_a?(Numo::UInt8)

          # Rows come first in the buffer, so shape is [height, width, ...].
          height, width = array.shape
          bands = array.shape[2] || 1
          data = array.to_binary

          Vips::Image.new_from_memory(data, width, height, bands, :uchar)
        else
          raise "Expected Torch::Tensor or Numo::NArray, not #{array.class.name}"
        end
      end
    end
  end
end