torchvision 0.2.0 → 0.2.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/README.md +10 -3
- data/lib/torchvision.rb +4 -0
- data/lib/torchvision/datasets/dataset_folder.rb +91 -0
- data/lib/torchvision/datasets/image_folder.rb +12 -0
- data/lib/torchvision/datasets/vision_dataset.rb +1 -2
- data/lib/torchvision/transforms/center_crop.rb +13 -0
- data/lib/torchvision/transforms/functional.rb +81 -13
- data/lib/torchvision/transforms/random_resized_crop.rb +70 -0
- data/lib/torchvision/version.rb +1 -1
- metadata +6 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: bbb87c59c0f081c0de57ccdd62e30bfc551e1cb69523e4ffd498c997e1a2d8b3
|
4
|
+
data.tar.gz: 890da113706e659d57194980c5c9262075beb8398a75da2997c0812b70abe308
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 3445b62b7824ae16205034881d37c48ac4c70d7e5677014755ae5600632f9ce45168f41b0d3e98c8104eb8337e1566db4df3e0ad5ace5e6a46a5d213d01b6c8d
|
7
|
+
data.tar.gz: 93f22c385586ff8a010880676806f6bc9ba2f614c4c14886235a300ea6e2abce0f80c260e255644e1b4d24e6ecddfd21830dde2960a92a9492239e69622d4548
|
data/CHANGELOG.md
CHANGED
data/README.md
CHANGED
@@ -6,7 +6,7 @@
|
|
6
6
|
|
7
7
|
## Installation
|
8
8
|
|
9
|
-
First, [install libvips](libvips-installation). For Homebrew, use:
|
9
|
+
First, [install libvips](#libvips-installation). For Homebrew, use:
|
10
10
|
|
11
11
|
```sh
|
12
12
|
brew install vips
|
@@ -25,7 +25,8 @@ This library follows the [Python API](https://pytorch.org/docs/stable/torchvisio
|
|
25
25
|
## Examples
|
26
26
|
|
27
27
|
- [MNIST](https://github.com/ankane/torch.rb/tree/master/examples/mnist)
|
28
|
-
- [
|
28
|
+
- [Transfer learning](https://github.com/ankane/torch.rb/tree/master/examples/transfer-learning)
|
29
|
+
- [Generative adversarial networks](https://github.com/ankane/torch.rb/tree/master/examples/gan)
|
29
30
|
|
30
31
|
## Datasets
|
31
32
|
|
@@ -54,9 +55,11 @@ TorchVision::Transforms::Compose.new([
|
|
54
55
|
|
55
56
|
Supported transforms are:
|
56
57
|
|
58
|
+
- CenterCrop
|
57
59
|
- Compose
|
58
60
|
- Normalize
|
59
61
|
- RandomHorizontalFlip
|
62
|
+
- RandomResizedCrop
|
60
63
|
- RandomVerticalFlip
|
61
64
|
- Resize
|
62
65
|
- ToTensor
|
@@ -130,12 +133,16 @@ net.load_state_dict(Torch.load("net.pth"))
|
|
130
133
|
|
131
134
|
## libvips Installation
|
132
135
|
|
133
|
-
###
|
136
|
+
### Linux
|
137
|
+
|
138
|
+
Check your package manager. For Ubuntu, use:
|
134
139
|
|
135
140
|
```sh
|
136
141
|
sudo apt install libvips
|
137
142
|
```
|
138
143
|
|
144
|
+
You can also [build from source](https://libvips.github.io/libvips/install.html).
|
145
|
+
|
139
146
|
### Mac
|
140
147
|
|
141
148
|
```sh
|
data/lib/torchvision.rb
CHANGED
@@ -16,6 +16,8 @@ require "torchvision/version"
|
|
16
16
|
|
17
17
|
# datasets
|
18
18
|
require "torchvision/datasets/vision_dataset"
|
19
|
+
require "torchvision/datasets/dataset_folder"
|
20
|
+
require "torchvision/datasets/image_folder"
|
19
21
|
require "torchvision/datasets/cifar10"
|
20
22
|
require "torchvision/datasets/cifar100"
|
21
23
|
require "torchvision/datasets/mnist"
|
@@ -47,10 +49,12 @@ require "torchvision/models/wide_resnet50_2"
|
|
47
49
|
require "torchvision/models/wide_resnet101_2"
|
48
50
|
|
49
51
|
# transforms
|
52
|
+
require "torchvision/transforms/center_crop"
|
50
53
|
require "torchvision/transforms/compose"
|
51
54
|
require "torchvision/transforms/functional"
|
52
55
|
require "torchvision/transforms/normalize"
|
53
56
|
require "torchvision/transforms/random_horizontal_flip"
|
57
|
+
require "torchvision/transforms/random_resized_crop"
|
54
58
|
require "torchvision/transforms/random_vertical_flip"
|
55
59
|
require "torchvision/transforms/resize"
|
56
60
|
require "torchvision/transforms/to_tensor"
|
@@ -0,0 +1,91 @@
|
|
1
|
+
module TorchVision
  module Datasets
    # Generic dataset whose samples are arranged on disk as
    #
    #   root/<class_name>/<file>
    #   root/<class_name>/<subdir>/<file>
    #
    # Every immediate sub-directory of +root+ is treated as one class. Class
    # names are sorted and mapped to consecutive integer indexes.
    #
    # Exactly one of +extensions+ (list of lower-case file suffixes) or
    # +is_valid_file+ (callable taking a path, returning truthy/falsy) must be
    # given; supplying both or neither raises ArgumentError.
    #
    # NOTE(review): relies on VisionDataset#initialize to set @root, @transform
    # and @target_transform - that class is not fully visible here.
    class DatasetFolder < VisionDataset
      attr_reader :classes

      # @param root [String] directory containing one sub-directory per class
      # @param extensions [Array<String>, nil] allowed file suffixes
      # @param transform [#call, nil] applied to each loaded sample
      # @param target_transform [#call, nil] applied to each class index
      # @param is_valid_file [#call, nil] custom predicate for file paths
      # @raise [ArgumentError] if extensions / is_valid_file are both nil or both set
      # @raise [RuntimeError] if no matching file is found
      def initialize(root, extensions: nil, transform: nil, target_transform: nil, is_valid_file: nil)
        super(root, transform: transform, target_transform: target_transform)
        classes, class_to_idx = find_classes(@root)
        samples = make_dataset(@root, class_to_idx, extensions, is_valid_file)
        if samples.empty?
          msg = "Found 0 files in subfolders of: #{@root}\n"
          msg += "Supported extensions are: #{extensions.join(",")}" unless extensions.nil?
          raise RuntimeError, msg
        end

        # Samples are loaded lazily, one file per #[] call.
        @loader = lambda do |path|
          Vips::Image.new_from_file(path)
        end
        @extensions = extensions

        @classes = classes
        @class_to_idx = class_to_idx
        @samples = samples
        @targets = samples.map { |s| s[1] }
      end

      # Loads the sample at +index+ and applies the optional transforms.
      #
      # @return [Array] two-element array [sample, target]
      def [](index)
        path, target = @samples[index]
        sample = @loader.call(path)
        sample = @transform.call(sample) if @transform
        target = @target_transform.call(target) if @target_transform

        [sample, target]
      end

      # @return [Integer] number of samples found under the root directory
      def size
        @samples.size
      end

      private

      # Returns [sorted class names, {class name => index}] for the immediate
      # sub-directories of +dir+.
      def find_classes(dir)
        classes = Dir.children(dir).select { |d| File.directory?(File.join(dir, d)) }
        classes.sort!
        class_to_idx = classes.map.with_index.to_h
        [classes, class_to_idx]
      end

      # Case-insensitive suffix check; +extensions+ are expected in lower case.
      def has_file_allowed_extension(filename, extensions)
        filename.downcase.end_with?(*extensions)
      end

      # Builds the list of [path, class_index] pairs.
      #
      # Each class directory is searched recursively ("**/*"); a bare "**"
      # glob behaves like "*" and would miss files in nested folders. Only
      # regular files are offered to the validity predicate, so a directory
      # that happens to be named like an image is never returned as a sample.
      def make_dataset(directory, class_to_idx, extensions, is_valid_file)
        directory = File.expand_path(directory)
        if extensions.nil? == is_valid_file.nil?
          raise ArgumentError, "Both extensions and is_valid_file cannot be None or not None at the same time"
        end

        is_valid_file ||= lambda do |x|
          has_file_allowed_extension(x, extensions)
        end

        instances = []
        class_to_idx.keys.sort.each do |target_class|
          target_dir = File.join(directory, target_class)
          next unless File.directory?(target_dir)

          class_index = class_to_idx[target_class]
          Dir.glob("**/*", base: target_dir).sort.each do |fname|
            path = File.join(target_dir, fname)
            next unless File.file?(path) && is_valid_file.call(path)

            instances << [path, class_index]
          end
        end
        instances
      end
    end
  end
end
|
@@ -0,0 +1,12 @@
|
|
1
|
+
module TorchVision
  module Datasets
    # DatasetFolder specialised for images: files are selected by the common
    # image suffixes in IMG_EXTENSIONS unless the caller supplies its own
    # +is_valid_file+ predicate.
    class ImageFolder < DatasetFolder
      IMG_EXTENSIONS = [".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp"]

      # @param root [String] directory containing one sub-directory per class
      # @param transform [#call, nil] applied to each loaded image
      # @param target_transform [#call, nil] applied to each class index
      # @param is_valid_file [#call, nil] custom predicate replacing the
      #   extension filter
      def initialize(root, transform: nil, target_transform: nil, is_valid_file: nil)
        # DatasetFolder rejects extensions and is_valid_file being set together,
        # so the default extension list is only used when no predicate is given.
        extensions = is_valid_file.nil? ? IMG_EXTENSIONS : nil
        super(root, extensions: extensions, transform: transform, target_transform: target_transform, is_valid_file: is_valid_file)
        @imgs = @samples
      end
    end
  end
end
|
@@ -1,7 +1,6 @@
|
|
1
1
|
module TorchVision
|
2
2
|
module Datasets
|
3
|
-
|
4
|
-
class VisionDataset
|
3
|
+
class VisionDataset < Torch::Utils::Data::Dataset
|
5
4
|
attr_reader :data, :targets
|
6
5
|
|
7
6
|
def initialize(root, transforms: nil, transform: nil, target_transform: nil)
|
@@ -34,21 +34,23 @@ module TorchVision
|
|
34
34
|
|
35
35
|
# Resizes a Vips::Image.
#
# +size+ may be:
# * an Integer (or a one-element array): the shorter edge is scaled to that
#   length and the other edge is scaled with integer arithmetic to keep the
#   aspect ratio; the image is returned untouched when the shorter edge
#   already matches.
# * a two-element array [height, width]: the image is forced to exactly that
#   size. The (h, w) order matches center_crop and RandomResizedCrop.
#
# Raises RuntimeError when +img+ is not a Vips::Image or +size+ has an
# unsupported shape.
def resize(img, size)
  raise "img should be Vips::Image. Got #{img.class.name}" unless img.is_a?(Vips::Image)

  size = size[0] if size.is_a?(Array) && size.length == 1

  if size.is_a?(Integer)
    # NOTE(review): img.size is destructured as [width, height] - confirm against ruby-vips.
    w, h = img.size
    return img if (w <= h && w == size) || (h <= w && h == size)

    if w < h
      ow = size
      oh = (size * h / w).to_i
    else
      oh = size
      ow = (size * w / h).to_i
    end
    img.thumbnail_image(ow, height: oh)
  elsif size.is_a?(Array) && size.length == 2
    height, width = size
    # thumbnail_image takes the width positionally and the height as an option.
    img.thumbnail_image(width, height: height, size: :force)
  else
    raise "Got inappropriate size arg: #{size}"
  end
end
|
54
56
|
|
@@ -90,6 +92,7 @@ module TorchVision
|
|
90
92
|
|
91
93
|
def hflip(img)
|
92
94
|
if img.is_a?(Torch::Tensor)
|
95
|
+
assert_image_tensor(img)
|
93
96
|
img.flip(-1)
|
94
97
|
else
|
95
98
|
img.flip(:horizontal)
|
@@ -98,11 +101,76 @@ module TorchVision
|
|
98
101
|
|
99
102
|
# Flips an image vertically.
#
# Tensors are validated with assert_image_tensor and flipped along their
# second-to-last dimension; any other image object is asked to flip itself
# with the :vertical direction.
def vflip(img)
  return img.flip(:vertical) unless img.is_a?(Torch::Tensor)

  assert_image_tensor(img)
  img.flip(-2)
end
|
110
|
+
|
111
|
+
# Cuts the rectangle starting at (+top+, +left+) with the given +height+ and
# +width+ out of +img+.
#
# Tensors are sliced on their last two dimensions, keeping every leading
# dimension whole; other images are delegated to their own #crop, which takes
# (left, top, width, height).
def crop(img, top, left, height, width)
  return img.crop(left, top, width, height) unless img.is_a?(Torch::Tensor)

  assert_image_tensor(img)
  # One `true` index per leading dimension selects that dimension entirely.
  leading = Array.new(img.dim - 2, true)
  rows = top...(top + height)
  cols = left...(left + width)
  img[*leading, rows, cols]
end
|
120
|
+
|
121
|
+
# Crops +img+ around its centre.
#
# +output_size+ is an Integer (square crop), a one-element array (square
# crop) or a [height, width] pair. When the requested crop is larger than the
# image in either direction the image is first padded with zeros so that it
# is at least as large as the crop.
def center_crop(img, output_size)
  case output_size
  when Integer
    output_size = [output_size.to_i, output_size.to_i]
  when Array
    output_size = [output_size[0], output_size[0]] if output_size.length == 1
  end

  image_width, image_height = image_size(img)
  crop_height, crop_width = output_size

  if crop_width > image_width || crop_height > image_height
    missing_w = crop_width > image_width
    missing_h = crop_height > image_height
    # Padding order is left, top, right, bottom; an odd remainder goes to the
    # right / bottom side.
    pad_left = missing_w ? (crop_width - image_width).div(2) : 0
    pad_top = missing_h ? (crop_height - image_height).div(2) : 0
    pad_right = missing_w ? (crop_width - image_width + 1).div(2) : 0
    pad_bottom = missing_h ? (crop_height - image_height + 1).div(2) : 0
    padding_ltrb = [pad_left, pad_top, pad_right, pad_bottom]
    # TODO
    # NOTE(review): `pad` is not defined in the visible part of this module - confirm it exists.
    img = pad(img, padding_ltrb, fill: 0)
    image_width, image_height = image_size(img)
    return img if crop_width == image_width && crop_height == image_height
  end

  crop_top = ((image_height - crop_height) / 2.0).round
  crop_left = ((image_width - crop_width) / 2.0).round
  crop(img, crop_top, crop_left, crop_height, crop_width)
end
|
150
|
+
|
151
|
+
# TODO interpolation
|
152
|
+
# Crops the region (+top+, +left+, +height+, +width+) out of +img+ and then
# resizes the result to +size+. Interpolation is not configurable yet.
def resized_crop(img, top, left, height, width, size)
  cropped = crop(img, top, left, height, width)
  resize(cropped, size) #, interpolation)
end
|
157
|
+
|
158
|
+
private
|
159
|
+
|
160
|
+
# Returns [width, height] of +img+.
#
# For tensors the last dimension is reported as the width and the
# second-to-last as the height; other images answer through #width / #height.
def image_size(img)
  return [img.width, img.height] unless img.is_a?(Torch::Tensor)

  assert_image_tensor(img)
  shape = img.shape
  [shape[-1], shape[-2]]
end
|
168
|
+
|
169
|
+
# Guards tensor code paths: raises TypeError when +img+ has fewer than two
# dimensions; returns nil otherwise.
def assert_image_tensor(img)
  raise TypeError, "Tensor is not a torch image." if img.ndim < 2
end
|
106
174
|
end
|
107
175
|
end
|
108
176
|
|
@@ -0,0 +1,70 @@
|
|
1
|
+
module TorchVision
  module Transforms
    # Crops a random region of the image and resizes it to a fixed size.
    #
    # The region's area is a random fraction of the image area taken from
    # +scale+, and its aspect ratio is drawn log-uniformly from +ratio+. If no
    # region fitting inside the image is found after 10 attempts, a central
    # crop whose aspect ratio is clamped to +ratio+ is used instead.
    class RandomResizedCrop < Torch::NN::Module
      # @param size [Integer, Array<Integer>] output size; an Integer or a
      #   one-element array gives a square, otherwise (h, w)
      # @param scale [Array<Float>] lower / upper bound of the area fraction
      # @param ratio [Array<Numeric>] lower / upper bound of the aspect ratio
      # @raise [ArgumentError] if size has more than two dimensions
      def initialize(size, scale: [0.08, 1.0], ratio: [3.0 / 4.0, 4.0 / 3.0])
        super()
        @size = setup_size(size, "Please provide only two dimensions (h, w) for size.")
        # @interpolation = interpolation
        @scale = scale
        @ratio = ratio
      end

      # Picks the crop rectangle for +img+.
      #
      # @return [Array<Integer>] [top, left, height, width]
      def params(img, scale, ratio)
        width, height = F.send(:image_size, img)
        area = height * width

        log_ratio = Torch.log(Torch.tensor(ratio))
        10.times do
          target_area = area * Torch.empty(1).uniform!(scale[0], scale[1]).item
          aspect_ratio = Torch.exp(
            Torch.empty(1).uniform!(log_ratio[0], log_ratio[1])
          ).item

          w = Math.sqrt(target_area * aspect_ratio).round
          h = Math.sqrt(target_area / aspect_ratio).round

          if 0 < w && w <= width && 0 < h && h <= height
            i = Torch.randint(0, height - h + 1, size: [1]).item
            j = Torch.randint(0, width - w + 1, size: [1]).item
            return i, j, h, w
          end
        end

        # Fallback to central crop
        in_ratio = width.to_f / height.to_f
        if in_ratio < ratio.min
          w = width
          # fdiv keeps the division in floating point even when the ratio
          # bounds were given as Integers.
          h = w.fdiv(ratio.min).round
        elsif in_ratio > ratio.max
          h = height
          w = (h * ratio.max).round
        else # whole image
          w = width
          h = height
        end
        i = (height - h).div(2)
        j = (width - w).div(2)
        [i, j, h, w]
      end

      # Applies a freshly sampled crop to +img+ and resizes it to the
      # configured size.
      def forward(img)
        i, j, h, w = params(img, @scale, @ratio)
        F.resized_crop(img, i, j, h, w, @size) #, @interpolation)
      end

      private

      # Normalises +size+ to a two-element array. An Integer or a one-element
      # array (also accepted by center_crop) becomes a square.
      def setup_size(size, error_msg)
        return [size, size] if size.is_a?(Integer)
        return [size[0], size[0]] if size.is_a?(Array) && size.length == 1

        raise ArgumentError, error_msg if size.length != 2

        size
      end
    end
  end
end
|
data/lib/torchvision/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torchvision
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.2.0
|
4
|
+
version: 0.2.1
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2021-03-
|
11
|
+
date: 2021-03-15 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-narray
|
@@ -64,7 +64,9 @@ files:
|
|
64
64
|
- lib/torchvision.rb
|
65
65
|
- lib/torchvision/datasets/cifar10.rb
|
66
66
|
- lib/torchvision/datasets/cifar100.rb
|
67
|
+
- lib/torchvision/datasets/dataset_folder.rb
|
67
68
|
- lib/torchvision/datasets/fashion_mnist.rb
|
69
|
+
- lib/torchvision/datasets/image_folder.rb
|
68
70
|
- lib/torchvision/datasets/kmnist.rb
|
69
71
|
- lib/torchvision/datasets/mnist.rb
|
70
72
|
- lib/torchvision/datasets/vision_dataset.rb
|
@@ -90,10 +92,12 @@ files:
|
|
90
92
|
- lib/torchvision/models/vgg19_bn.rb
|
91
93
|
- lib/torchvision/models/wide_resnet101_2.rb
|
92
94
|
- lib/torchvision/models/wide_resnet50_2.rb
|
95
|
+
- lib/torchvision/transforms/center_crop.rb
|
93
96
|
- lib/torchvision/transforms/compose.rb
|
94
97
|
- lib/torchvision/transforms/functional.rb
|
95
98
|
- lib/torchvision/transforms/normalize.rb
|
96
99
|
- lib/torchvision/transforms/random_horizontal_flip.rb
|
100
|
+
- lib/torchvision/transforms/random_resized_crop.rb
|
97
101
|
- lib/torchvision/transforms/random_vertical_flip.rb
|
98
102
|
- lib/torchvision/transforms/resize.rb
|
99
103
|
- lib/torchvision/transforms/to_tensor.rb
|