torchvision 0.1.0 → 0.2.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +35 -0
- data/LICENSE.txt +1 -1
- data/README.md +133 -5
- data/lib/torchvision.rb +40 -1
- data/lib/torchvision/datasets/cifar10.rb +117 -0
- data/lib/torchvision/datasets/cifar100.rb +41 -0
- data/lib/torchvision/datasets/dataset_folder.rb +91 -0
- data/lib/torchvision/datasets/fashion_mnist.rb +30 -0
- data/lib/torchvision/datasets/image_folder.rb +12 -0
- data/lib/torchvision/datasets/kmnist.rb +30 -0
- data/lib/torchvision/datasets/mnist.rb +47 -76
- data/lib/torchvision/datasets/vision_dataset.rb +67 -0
- data/lib/torchvision/models/alexnet.rb +42 -0
- data/lib/torchvision/models/basic_block.rb +46 -0
- data/lib/torchvision/models/bottleneck.rb +47 -0
- data/lib/torchvision/models/resnet.rb +129 -0
- data/lib/torchvision/models/resnet101.rb +9 -0
- data/lib/torchvision/models/resnet152.rb +9 -0
- data/lib/torchvision/models/resnet18.rb +9 -0
- data/lib/torchvision/models/resnet34.rb +9 -0
- data/lib/torchvision/models/resnet50.rb +9 -0
- data/lib/torchvision/models/resnext101_32x8d.rb +11 -0
- data/lib/torchvision/models/resnext50_32x4d.rb +11 -0
- data/lib/torchvision/models/vgg.rb +93 -0
- data/lib/torchvision/models/vgg11.rb +9 -0
- data/lib/torchvision/models/vgg11_bn.rb +9 -0
- data/lib/torchvision/models/vgg13.rb +9 -0
- data/lib/torchvision/models/vgg13_bn.rb +9 -0
- data/lib/torchvision/models/vgg16.rb +9 -0
- data/lib/torchvision/models/vgg16_bn.rb +9 -0
- data/lib/torchvision/models/vgg19.rb +9 -0
- data/lib/torchvision/models/vgg19_bn.rb +9 -0
- data/lib/torchvision/models/wide_resnet101_2.rb +10 -0
- data/lib/torchvision/models/wide_resnet50_2.rb +10 -0
- data/lib/torchvision/transforms/center_crop.rb +13 -0
- data/lib/torchvision/transforms/compose.rb +2 -2
- data/lib/torchvision/transforms/functional.rb +142 -7
- data/lib/torchvision/transforms/normalize.rb +2 -2
- data/lib/torchvision/transforms/random_horizontal_flip.rb +18 -0
- data/lib/torchvision/transforms/random_resized_crop.rb +70 -0
- data/lib/torchvision/transforms/random_vertical_flip.rb +18 -0
- data/lib/torchvision/transforms/resize.rb +13 -0
- data/lib/torchvision/transforms/to_tensor.rb +2 -2
- data/lib/torchvision/utils.rb +120 -0
- data/lib/torchvision/version.rb +1 -1
- metadata +50 -57
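
Taken together, the new files add datasets (CIFAR-10/100, Fashion-MNIST, KMNIST, ImageFolder), models (AlexNet, the ResNet/ResNeXt/Wide ResNet family, VGG), several new transforms, and grid/save utilities. As rough orientation, a 0.2.1-style pipeline might look like the sketch below; the constructor arguments are illustrative assumptions inferred from the class names above, not an excerpt from the release.

```ruby
require "torchvision"

# Sketch only: compose transforms added or updated in this diff.
# Mean/std values are the usual ImageNet constants, assumed here.
transform = TorchVision::Transforms::Compose.new([
  TorchVision::Transforms::RandomResizedCrop.new(224),
  TorchVision::Transforms::RandomHorizontalFlip.new,
  TorchVision::Transforms::ToTensor.new,
  TorchVision::Transforms::Normalize.new([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# One of the model classes added under data/lib/torchvision/models.
model = TorchVision::Models::ResNet18.new
```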
**data/lib/torchvision/transforms/functional.rb**

```diff
@@ -22,19 +22,154 @@ module TorchVision
           if std.to_a.any? { |v| v == 0 }
             raise ArgumentError, "std evaluated to zero after conversion to #{dtype}, leading to division by zero."
           end
-
-
-
-
-
-
+          if mean.ndim == 1
+            mean = mean[0...mean.size(0), nil, nil]
+          end
+          if std.ndim == 1
+            std = std[0...std.size(0), nil, nil]
+          end
           tensor.sub!(mean).div!(std)
           tensor
         end
 
+        def resize(img, size)
+          raise "img should be Vips::Image. Got #{img.class.name}" unless img.is_a?(Vips::Image)
+
+          if size.is_a?(Integer)
+            w, h = img.size
+            if (w <= h && w == size) || (h <= w && h == size)
+              return img
+            end
+            if w < h
+              ow = size
+              oh = (size * h / w).to_i
+              img.thumbnail_image(ow, height: oh)
+            else
+              oh = size
+              ow = (size * w / h).to_i
+              img.thumbnail_image(ow, height: oh)
+            end
+          else
+            img.thumbnail_image(size[0], height: size[1], size: :force)
+          end
+        end
+
         # TODO improve
         def to_tensor(pic)
-
+          if !pic.is_a?(Numo::NArray) && !pic.is_a?(Vips::Image)
+            raise ArgumentError, "pic should be Vips::Image or Numo::NArray. Got #{pic.class.name}"
+          end
+
+          if pic.is_a?(Numo::NArray) && ![2, 3].include?(pic.ndim)
+            raise ArgumentError, "pic should be 2/3 dimensional. Got #{pic.ndim} dimensions."
+          end
+
+          if pic.is_a?(Numo::NArray)
+            if pic.ndim == 2
+              pic = pic.reshape(*pic.shape, 1)
+            end
+
+            img = Torch.from_numo(pic.transpose(2, 0, 1))
+            if img.dtype == :uint8
+              return img.float.div(255)
+            else
+              return img
+            end
+          end
+
+          case pic.format
+          when :uchar
+            img = Torch::ByteTensor.new(Torch::ByteStorage.from_buffer(pic.write_to_memory))
+          else
+            raise Error, "Format not supported yet: #{pic.format}"
+          end
+
+          img = img.view(pic.height, pic.width, pic.bands)
+          # put it from HWC to CHW format
+          img = img.permute([2, 0, 1]).contiguous
+          img.float.div(255)
+        end
+
+        def hflip(img)
+          if img.is_a?(Torch::Tensor)
+            assert_image_tensor(img)
+            img.flip(-1)
+          else
+            img.flip(:horizontal)
+          end
+        end
+
+        def vflip(img)
+          if img.is_a?(Torch::Tensor)
+            assert_image_tensor(img)
+            img.flip(-2)
+          else
+            img.flip(:vertical)
+          end
+        end
+
+        def crop(img, top, left, height, width)
+          if img.is_a?(Torch::Tensor)
+            assert_image_tensor(img)
+            indexes = [true] * (img.dim - 2)
+            img[*indexes, top...(top + height), left...(left + width)]
+          else
+            img.crop(left, top, width, height)
+          end
+        end
+
+        def center_crop(img, output_size)
+          if output_size.is_a?(Integer)
+            output_size = [output_size.to_i, output_size.to_i]
+          elsif output_size.is_a?(Array) && output_size.length == 1
+            output_size = [output_size[0], output_size[0]]
+          end
+
+          image_width, image_height = image_size(img)
+          crop_height, crop_width = output_size
+
+          if crop_width > image_width || crop_height > image_height
+            padding_ltrb = [
+              crop_width > image_width ? (crop_width - image_width).div(2) : 0,
+              crop_height > image_height ? (crop_height - image_height).div(2) : 0,
+              crop_width > image_width ? (crop_width - image_width + 1).div(2) : 0,
+              crop_height > image_height ? (crop_height - image_height + 1).div(2) : 0
+            ]
+            # TODO
+            img = pad(img, padding_ltrb, fill: 0)
+            image_width, image_height = image_size(img)
+            if crop_width == image_width && crop_height == image_height
+              return img
+            end
+          end
+
+          crop_top = ((image_height - crop_height) / 2.0).round
+          crop_left = ((image_width - crop_width) / 2.0).round
+          crop(img, crop_top, crop_left, crop_height, crop_width)
+        end
+
+        # TODO interpolation
+        def resized_crop(img, top, left, height, width, size)
+          img = crop(img, top, left, height, width)
+          img = resize(img, size) #, interpolation)
+          img
+        end
+
+        private
+
+        def image_size(img)
+          if img.is_a?(Torch::Tensor)
+            assert_image_tensor(img)
+            [img.shape[-1], img.shape[-2]]
+          else
+            [img.width, img.height]
+          end
+        end
+
+        def assert_image_tensor(img)
+          if img.ndim < 2
+            raise TypeError, "Tensor is not a torch image."
+          end
         end
       end
     end
```
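
The hunk above adds the functional core: resize and to_tensor work on Vips::Image inputs, while hflip/vflip/crop/center_crop also accept Torch::Tensor images. A minimal sketch, assuming ruby-vips is installed and image.jpg is a hypothetical local file (note that center_crop's padding branch is still marked TODO above):

```ruby
require "torchvision"

func = TorchVision::Transforms::Functional

img = Vips::Image.new_from_file("image.jpg") # hypothetical input path
img = func.resize(img, 256)                  # shorter side -> 256, aspect ratio kept
img = func.center_crop(img, 224)             # safe here: crop is smaller than the image
tensor = func.to_tensor(img)                 # uchar HWC -> float CHW scaled to [0, 1]
```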
**data/lib/torchvision/transforms/normalize.rb**

```diff
@@ -1,13 +1,13 @@
 module TorchVision
   module Transforms
-    class Normalize
+    class Normalize < Torch::NN::Module
       def initialize(mean, std, inplace: false)
         @mean = mean
         @std = std
         @inplace = inplace
       end
 
-      def
+      def forward(tensor)
         F.normalize(tensor, @mean, @std, inplace: @inplace)
       end
     end
```
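
Normalize now inherits from Torch::NN::Module and names its entry point forward, matching the other transforms in this release. A minimal sketch, assuming torch.rb dispatches call to forward and that F.normalize accepts plain arrays for mean and std:

```ruby
norm = TorchVision::Transforms::Normalize.new([0.1307], [0.3081])

tensor = Torch.rand(1, 28, 28)
out = norm.call(tensor) # Torch::NN::Module routes call to forward
```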
**data/lib/torchvision/transforms/random_horizontal_flip.rb**

```diff
@@ -0,0 +1,18 @@
+module TorchVision
+  module Transforms
+    class RandomHorizontalFlip < Torch::NN::Module
+      def initialize(p: 0.5)
+        super()
+        @p = p
+      end
+
+      def forward(img)
+        if Torch.rand(1).item < @p
+          F.hflip(img)
+        else
+          img
+        end
+      end
+    end
+  end
+end
```
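
The flip transforms share one pattern: draw a uniform sample and delegate to the functional helper. A quick sketch with a tensor image (values are arbitrary):

```ruby
flip = TorchVision::Transforms::RandomHorizontalFlip.new(p: 0.5)

img = Torch.tensor([[[1, 2], [3, 4]]]) # a 1x2x2 "image"
out = flip.call(img) # img unchanged or mirrored along the last axis, 50/50
```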
**data/lib/torchvision/transforms/random_resized_crop.rb**

```diff
@@ -0,0 +1,70 @@
+module TorchVision
+  module Transforms
+    class RandomResizedCrop < Torch::NN::Module
+      def initialize(size, scale: [0.08, 1.0], ratio: [3.0 / 4.0, 4.0 / 3.0])
+        super()
+        @size = setup_size(size, "Please provide only two dimensions (h, w) for size.")
+        # @interpolation = interpolation
+        @scale = scale
+        @ratio = ratio
+      end
+
+      def params(img, scale, ratio)
+        width, height = F.send(:image_size, img)
+        area = height * width
+
+        log_ratio = Torch.log(Torch.tensor(ratio))
+        10.times do
+          target_area = area * Torch.empty(1).uniform!(scale[0], scale[1]).item
+          aspect_ratio = Torch.exp(
+            Torch.empty(1).uniform!(log_ratio[0], log_ratio[1])
+          ).item
+
+          w = Math.sqrt(target_area * aspect_ratio).round
+          h = Math.sqrt(target_area / aspect_ratio).round
+
+          if 0 < w && w <= width && 0 < h && h <= height
+            i = Torch.randint(0, height - h + 1, size: [1]).item
+            j = Torch.randint(0, width - w + 1, size: [1]).item
+            return i, j, h, w
+          end
+        end
+
+        # Fallback to central crop
+        in_ratio = width.to_f / height.to_f
+        if in_ratio < ratio.min
+          w = width
+          h = (w / ratio.min).round
+        elsif in_ratio > ratio.max
+          h = height
+          w = (h * ratio.max).round
+        else # whole image
+          w = width
+          h = height
+        end
+        i = (height - h).div(2)
+        j = (width - w).div(2)
+        [i, j, h, w]
+      end
+
+      def forward(img)
+        i, j, h, w = params(img, @scale, @ratio)
+        F.resized_crop(img, i, j, h, w, @size) #, @interpolation)
+      end
+
+      private
+
+      def setup_size(size, error_msg)
+        if size.is_a?(Integer)
+          return [size, size]
+        end
+
+        if size.length != 2
+          raise ArgumentError, error_msg
+        end
+
+        size
+      end
+    end
+  end
+end
```
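
params makes up to ten attempts to sample a crop whose area falls within scale and whose aspect ratio is log-uniform within ratio, then falls back to a central crop; forward hands the result to F.resized_crop, which resizes with size: :force to the exact target. A usage sketch (the input path is hypothetical):

```ruby
rrc = TorchVision::Transforms::RandomResizedCrop.new(224, scale: [0.5, 1.0])

img = Vips::Image.new_from_file("image.jpg") # hypothetical input path
out = rrc.call(img) # a random crop of img, resized to 224x224

# params is public, so the sampled crop box can be inspected directly
i, j, h, w = rrc.params(img, [0.5, 1.0], [3.0 / 4.0, 4.0 / 3.0])
```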
**data/lib/torchvision/transforms/random_vertical_flip.rb**

```diff
@@ -0,0 +1,18 @@
+module TorchVision
+  module Transforms
+    class RandomVerticalFlip < Torch::NN::Module
+      def initialize(p: 0.5)
+        super()
+        @p = p
+      end
+
+      def forward(img)
+        if Torch.rand(1).item < @p
+          F.vflip(img)
+        else
+          img
+        end
+      end
+    end
+  end
+end
```
**data/lib/torchvision/utils.rb**

```diff
@@ -0,0 +1,120 @@
+module TorchVision
+  module Utils
+    class << self
+      def make_grid(tensor, nrow: 8, padding: 2, normalize: false, range: nil, scale_each: false, pad_value: 0)
+        unless Torch.tensor?(tensor) || (tensor.is_a?(Array) && tensor.all? { |t| Torch.tensor?(t) })
+          raise ArgumentError, "tensor or list of tensors expected, got #{tensor.class.name}"
+        end
+
+        # if list of tensors, convert to a 4D mini-batch Tensor
+        if tensor.is_a?(Array)
+          tensor = Torch.stack(tensor, dim: 0)
+        end
+
+        if tensor.dim == 2 # single image H x W
+          tensor = tensor.unsqueeze(0)
+        end
+        if tensor.dim == 3 # single image
+          if tensor.size(0) == 1 # if single-channel, convert to 3-channel
+            tensor = Torch.cat([tensor, tensor, tensor], 0)
+          end
+          tensor = tensor.unsqueeze(0)
+        end
+
+        if tensor.dim == 4 && tensor.size(1) == 1 # single-channel images
+          tensor = Torch.cat([tensor, tensor, tensor], 1)
+        end
+
+        if normalize
+          tensor = tensor.clone # avoid modifying tensor in-place
+          if !range.nil? && !range.is_a?(Array)
+            raise "range has to be an array (min, max) if specified. min and max are numbers"
+          end
+
+          norm_ip = lambda do |img, min, max|
+            img.clamp!(min, max)
+            img.add!(-min).div!(max - min + 1e-5)
+          end
+
+          norm_range = lambda do |t, range|
+            if !range.nil?
+              norm_ip.call(t, range[0], range[1])
+            else
+              norm_ip.call(t, t.min.to_f, t.max.to_f)
+            end
+          end
+
+          if scale_each
+            tensor.each do |t| # loop over mini-batch dimension
+              norm_range.call(t, range)
+            end
+          else
+            norm_range.call(tensor, range)
+          end
+        end
+
+        if tensor.size(0) == 1
+          return tensor.squeeze(0)
+        end
+
+        # make the mini-batch of images into a grid
+        nmaps = tensor.size(0)
+        xmaps = [nrow, nmaps].min
+        ymaps = (nmaps.to_f / xmaps).ceil
+        height, width = (tensor.size(2) + padding), (tensor.size(3) + padding)
+        num_channels = tensor.size(1)
+        grid = tensor.new_full([num_channels, height * ymaps + padding, width * xmaps + padding], pad_value)
+        k = 0
+        ymaps.times do |y|
+          xmaps.times do |x|
+            break if k >= nmaps
+            grid.narrow(1, y * height + padding, height - padding).narrow(2, x * width + padding, width - padding).copy!(tensor[k])
+            k += 1
+          end
+        end
+        grid
+      end
+
+      def save_image(tensor, fp, nrow: 8, padding: 2, normalize: false, range: nil, scale_each: false, pad_value: 0)
+        grid = make_grid(tensor, nrow: nrow, padding: padding, pad_value: pad_value, normalize: normalize, range: range, scale_each: scale_each)
+        # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
+        ndarr = grid.mul(255).add!(0.5).clamp!(0, 255).permute(1, 2, 0).to("cpu", dtype: :uint8)
+        im = image_from_array(ndarr)
+        im.write_to_file(fp)
+      end
+
+      # private
+      # Ruby-specific method
+      # TODO use Numo when bridge available
+      def image_from_array(array)
+        case array
+        when Torch::Tensor
+          # TODO support more dtypes
+          raise "Type not supported yet: #{array.dtype}" unless array.dtype == :uint8
+
+          array = array.contiguous unless array.contiguous?
+
+          width, height = array.shape
+          bands = array.shape[2] || 1
+          data = FFI::Pointer.new(:uint8, array._data_ptr)
+          data.define_singleton_method(:bytesize) do
+            array.numel * array.element_size
+          end
+
+          Vips::Image.new_from_memory(data, width, height, bands, :uchar)
+        when Numo::NArray
+          # TODO support more types
+          raise "Type not supported yet: #{array.class.name}" unless array.is_a?(Numo::UInt8)
+
+          width, height = array.shape
+          bands = array.shape[2] || 1
+          data = array.to_binary
+
+          Vips::Image.new_from_memory(data, width, height, bands, :uchar)
+        else
+          raise "Expected Torch::Tensor or Numo::NArray, not #{array.class.name}"
+        end
+      end
+    end
+  end
+end
```
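
make_grid tiles a batch (or array) of image tensors into one 3-channel tensor, and save_image runs the same grid through libvips to disk. A minimal sketch with random data; grid.png is an arbitrary output path:

```ruby
batch = Torch.rand(8, 3, 28, 28) # 8 random RGB "images"

grid = TorchVision::Utils.make_grid(batch, nrow: 4, padding: 2)
TorchVision::Utils.save_image(batch, "grid.png", nrow: 4, normalize: true)
```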