torchvision 0.2.0 → 0.2.1
This diff covers publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between versions as they appear in the public registry.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/README.md +10 -3
- data/lib/torchvision.rb +4 -0
- data/lib/torchvision/datasets/dataset_folder.rb +91 -0
- data/lib/torchvision/datasets/image_folder.rb +12 -0
- data/lib/torchvision/datasets/vision_dataset.rb +1 -2
- data/lib/torchvision/transforms/center_crop.rb +13 -0
- data/lib/torchvision/transforms/functional.rb +81 -13
- data/lib/torchvision/transforms/random_resized_crop.rb +70 -0
- data/lib/torchvision/version.rb +1 -1
- metadata +6 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: bbb87c59c0f081c0de57ccdd62e30bfc551e1cb69523e4ffd498c997e1a2d8b3
+  data.tar.gz: 890da113706e659d57194980c5c9262075beb8398a75da2997c0812b70abe308
 SHA512:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 3445b62b7824ae16205034881d37c48ac4c70d7e5677014755ae5600632f9ce45168f41b0d3e98c8104eb8337e1566db4df3e0ad5ace5e6a46a5d213d01b6c8d
+  data.tar.gz: 93f22c385586ff8a010880676806f6bc9ba2f614c4c14886235a300ea6e2abce0f80c260e255644e1b4d24e6ecddfd21830dde2960a92a9492239e69622d4548
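These SHA256/SHA512 values cover the gem's inner metadata.gz and data.tar.gz archives, not the .gem file itself. A minimal sketch of checking them locally, assuming the gem was fetched first (e.g. via `gem fetch torchvision -v 0.2.1`; the path is hypothetical):

```ruby
require "digest"
require "rubygems/package"

# A .gem file is a plain tar archive containing metadata.gz,
# data.tar.gz, and checksums.yaml.gz
Gem::Package::TarReader.new(File.open("torchvision-0.2.1.gem", "rb")) do |tar|
  tar.each do |entry|
    next unless ["metadata.gz", "data.tar.gz"].include?(entry.full_name)
    # compare these digests against the SHA256 values above
    puts "#{entry.full_name}: #{Digest::SHA256.hexdigest(entry.read)}"
  end
end
```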
data/CHANGELOG.md
CHANGED
data/README.md
CHANGED
@@ -6,7 +6,7 @@
 
 ## Installation
 
-First, [install libvips](libvips-installation). For Homebrew, use:
+First, [install libvips](#libvips-installation). For Homebrew, use:
 
 ```sh
 brew install vips
@@ -25,7 +25,8 @@ This library follows the [Python API](https://pytorch.org/docs/stable/torchvisio
 ## Examples
 
 - [MNIST](https://github.com/ankane/torch.rb/tree/master/examples/mnist)
-- [
+- [Transfer learning](https://github.com/ankane/torch.rb/tree/master/examples/transfer-learning)
+- [Generative adversarial networks](https://github.com/ankane/torch.rb/tree/master/examples/gan)
 
 ## Datasets
 
@@ -54,9 +55,11 @@ TorchVision::Transforms::Compose.new([
 
 Supported transforms are:
 
+- CenterCrop
 - Compose
 - Normalize
 - RandomHorizontalFlip
+- RandomResizedCrop
 - RandomVerticalFlip
 - Resize
 - ToTensor
@@ -130,12 +133,16 @@ net.load_state_dict(Torch.load("net.pth"))
 
 ## libvips Installation
 
-###
+### Linux
+
+Check your package manager. For Ubuntu, use:
 
 ```sh
 sudo apt install libvips
 ```
 
+You can also [build from source](https://libvips.github.io/libvips/install.html).
+
 ### Mac
 
 ```sh
data/lib/torchvision.rb
CHANGED
@@ -16,6 +16,8 @@ require "torchvision/version"
 
 # datasets
 require "torchvision/datasets/vision_dataset"
+require "torchvision/datasets/dataset_folder"
+require "torchvision/datasets/image_folder"
 require "torchvision/datasets/cifar10"
 require "torchvision/datasets/cifar100"
 require "torchvision/datasets/mnist"
@@ -47,10 +49,12 @@ require "torchvision/models/wide_resnet50_2"
 require "torchvision/models/wide_resnet101_2"
 
 # transforms
+require "torchvision/transforms/center_crop"
 require "torchvision/transforms/compose"
 require "torchvision/transforms/functional"
 require "torchvision/transforms/normalize"
 require "torchvision/transforms/random_horizontal_flip"
+require "torchvision/transforms/random_resized_crop"
 require "torchvision/transforms/random_vertical_flip"
 require "torchvision/transforms/resize"
 require "torchvision/transforms/to_tensor"
data/lib/torchvision/datasets/dataset_folder.rb
ADDED
@@ -0,0 +1,91 @@
+module TorchVision
+  module Datasets
+    class DatasetFolder < VisionDataset
+      attr_reader :classes
+
+      def initialize(root, extensions: nil, transform: nil, target_transform: nil, is_valid_file: nil)
+        super(root, transform: transform, target_transform: target_transform)
+        classes, class_to_idx = find_classes(@root)
+        samples = make_dataset(@root, class_to_idx, extensions, is_valid_file)
+        if samples.empty?
+          msg = "Found 0 files in subfolders of: #{@root}\n"
+          unless extensions.nil?
+            msg += "Supported extensions are: #{extensions.join(",")}"
+          end
+          raise RuntimeError, msg
+        end
+
+        @loader = lambda do |path|
+          Vips::Image.new_from_file(path)
+        end
+        @extensions = extensions
+
+        @classes = classes
+        @class_to_idx = class_to_idx
+        @samples = samples
+        @targets = samples.map { |s| s[1] }
+      end
+
+      def [](index)
+        path, target = @samples[index]
+        sample = @loader.call(path)
+        if @transform
+          sample = @transform.call(sample)
+        end
+        if @target_transform
+          target = @target_transform.call(target)
+        end
+
+        [sample, target]
+      end
+
+      def size
+        @samples.size
+      end
+
+      private
+
+      def find_classes(dir)
+        classes = Dir.children(dir).select { |d| File.directory?(File.join(dir, d)) }
+        classes.sort!
+        class_to_idx = classes.map.with_index.to_h
+        [classes, class_to_idx]
+      end
+
+      def has_file_allowed_extension(filename, extensions)
+        filename = filename.downcase
+        extensions.any? { |ext| filename.end_with?(ext) }
+      end
+
+      def make_dataset(directory, class_to_idx, extensions, is_valid_file)
+        instances = []
+        directory = File.expand_path(directory)
+        both_none = extensions.nil? && is_valid_file.nil?
+        both_something = !extensions.nil? && !is_valid_file.nil?
+        if both_none || both_something
+          raise ArgumentError, "Both extensions and is_valid_file cannot be None or not None at the same time"
+        end
+        if !extensions.nil?
+          is_valid_file = lambda do |x|
+            has_file_allowed_extension(x, extensions)
+          end
+        end
+        class_to_idx.keys.sort.each do |target_class|
+          class_index = class_to_idx[target_class]
+          target_dir = File.join(directory, target_class)
+          if !File.directory?(target_dir)
+            next
+          end
+          Dir.glob("**", base: target_dir).sort.each do |fname|
+            path = File.join(target_dir, fname)
+            if is_valid_file.call(path)
+              item = [path, class_index]
+              instances << item
+            end
+          end
+        end
+        instances
+      end
+    end
+  end
+end
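As a usage note: exactly one of `extensions:` or `is_valid_file:` must be supplied (per the ArgumentError above), and samples are `[path, class_index]` pairs built from sorted subfolder names. A sketch with hypothetical paths:

```ruby
# root/
#   cat/1.png
#   dog/2.png
dataset = TorchVision::Datasets::DatasetFolder.new("root", extensions: [".png"])

dataset.classes             # => ["cat", "dog"] (sorted subfolder names)
sample, target = dataset[0] # Vips::Image (after any transform) and its class index
dataset.size                # => total number of matching files
```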
data/lib/torchvision/datasets/image_folder.rb
ADDED
@@ -0,0 +1,12 @@
+module TorchVision
+  module Datasets
+    class ImageFolder < DatasetFolder
+      IMG_EXTENSIONS = [".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp"]
+
+      def initialize(root, transform: nil, target_transform: nil, is_valid_file: nil)
+        super(root, extensions: IMG_EXTENSIONS, transform: transform, target_transform: target_transform, is_valid_file: is_valid_file)
+        @imgs = @samples
+      end
+    end
+  end
+end
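ImageFolder is just DatasetFolder preloaded with the image extension whitelist, so the common case needs only a root directory (the path below is hypothetical) and an optional transform:

```ruby
dataset = TorchVision::Datasets::ImageFolder.new(
  "flowers/train",
  transform: TorchVision::Transforms::Compose.new([
    TorchVision::Transforms::RandomResizedCrop.new(224),
    TorchVision::Transforms::ToTensor.new
  ])
)
img, label = dataset[0] # transformed image, integer class index
```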
data/lib/torchvision/datasets/vision_dataset.rb
CHANGED
@@ -1,7 +1,6 @@
 module TorchVision
   module Datasets
-
-    class VisionDataset
+    class VisionDataset < Torch::Utils::Data::Dataset
       attr_reader :data, :targets
 
       def initialize(root, transforms: nil, transform: nil, target_transform: nil)
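Inheriting from Torch::Utils::Data::Dataset is what lets these datasets, which implement `[]` and `size`, plug into torch.rb's data loading. A sketch, assuming torch.rb's standard DataLoader API and the `dataset` from the ImageFolder sketch above; the batch size is illustrative:

```ruby
loader = Torch::Utils::Data::DataLoader.new(dataset, batch_size: 32, shuffle: true)
loader.each do |images, labels|
  # images and labels arrive batched, ready for a training step
end
```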
data/lib/torchvision/transforms/functional.rb
CHANGED
@@ -34,21 +34,23 @@ module TorchVision
 
       def resize(img, size)
         raise "img should be Vips::Image. Got #{img.class.name}" unless img.is_a?(Vips::Image)
-        # TODO support array size
-        raise "Got inappropriate size arg: #{size}" unless size.is_a?(Integer)
 
-        w, h = img.size
-        if (w <= h && w == size) || (h <= w && h == size)
-          return img
-        end
-        if w < h
-          ow = size
-          oh = (size * h / w).to_i
-          img.thumbnail_image(ow, height: oh)
+        if size.is_a?(Integer)
+          w, h = img.size
+          if (w <= h && w == size) || (h <= w && h == size)
+            return img
+          end
+          if w < h
+            ow = size
+            oh = (size * h / w).to_i
+            img.thumbnail_image(ow, height: oh)
+          else
+            oh = size
+            ow = (size * w / h).to_i
+            img.thumbnail_image(ow, height: oh)
+          end
         else
-          oh = size
-          ow = (size * w / h).to_i
-          img.thumbnail_image(ow, height: oh)
+          img.thumbnail_image(size[0], height: size[1], size: :force)
         end
       end
 
@@ -90,6 +92,7 @@ module TorchVision
 
       def hflip(img)
         if img.is_a?(Torch::Tensor)
+          assert_image_tensor(img)
          img.flip(-1)
         else
           img.flip(:horizontal)
@@ -98,11 +101,76 @@ module TorchVision
 
       def vflip(img)
         if img.is_a?(Torch::Tensor)
+          assert_image_tensor(img)
           img.flip(-2)
         else
           img.flip(:vertical)
         end
       end
+
+      def crop(img, top, left, height, width)
+        if img.is_a?(Torch::Tensor)
+          assert_image_tensor(img)
+          indexes = [true] * (img.dim - 2)
+          img[*indexes, top...(top + height), left...(left + width)]
+        else
+          img.crop(left, top, width, height)
+        end
+      end
+
+      def center_crop(img, output_size)
+        if output_size.is_a?(Integer)
+          output_size = [output_size.to_i, output_size.to_i]
+        elsif output_size.is_a?(Array) && output_size.length == 1
+          output_size = [output_size[0], output_size[0]]
+        end
+
+        image_width, image_height = image_size(img)
+        crop_height, crop_width = output_size
+
+        if crop_width > image_width || crop_height > image_height
+          padding_ltrb = [
+            crop_width > image_width ? (crop_width - image_width).div(2) : 0,
+            crop_height > image_height ? (crop_height - image_height).div(2) : 0,
+            crop_width > image_width ? (crop_width - image_width + 1).div(2) : 0,
+            crop_height > image_height ? (crop_height - image_height + 1).div(2) : 0
+          ]
+          # TODO
+          img = pad(img, padding_ltrb, fill: 0)
+          image_width, image_height = image_size(img)
+          if crop_width == image_width && crop_height == image_height
+            return img
+          end
+        end
+
+        crop_top = ((image_height - crop_height) / 2.0).round
+        crop_left = ((image_width - crop_width) / 2.0).round
+        crop(img, crop_top, crop_left, crop_height, crop_width)
+      end
+
+      # TODO interpolation
+      def resized_crop(img, top, left, height, width, size)
+        img = crop(img, top, left, height, width)
+        img = resize(img, size) #, interpolation)
+        img
+      end
+
+      private
+
+      def image_size(img)
+        if img.is_a?(Torch::Tensor)
+          assert_image_tensor(img)
+          [img.shape[-1], img.shape[-2]]
+        else
+          [img.width, img.height]
+        end
+      end
+
+      def assert_image_tensor(img)
+        if img.ndim < 2
+          raise TypeError, "Tensor is not a torch image."
+        end
+      end
     end
   end
 
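One detail worth noting in `crop`: the functional API keeps PyTorch's `(top, left, height, width)` argument order, while ruby-vips' `Image#crop` expects `(left, top, width, height)`, so the call is reordered internally. A sketch on a synthetic image; calling the Functional methods directly is assumed to work the same way the transform classes call them via `F`:

```ruby
img = Vips::Image.black(640, 480) # width x height

# 100-pixel-tall, 200-pixel-wide patch, top-left corner at row 10, column 20
patch = TorchVision::Transforms::Functional.crop(img, 10, 20, 100, 200)
[patch.width, patch.height] # => [200, 100]

# center_crop accepts an integer for a square output
square = TorchVision::Transforms::Functional.center_crop(img, 224)
[square.width, square.height] # => [224, 224]
```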
data/lib/torchvision/transforms/random_resized_crop.rb
ADDED
@@ -0,0 +1,70 @@
+module TorchVision
+  module Transforms
+    class RandomResizedCrop < Torch::NN::Module
+      def initialize(size, scale: [0.08, 1.0], ratio: [3.0 / 4.0, 4.0 / 3.0])
+        super()
+        @size = setup_size(size, "Please provide only two dimensions (h, w) for size.")
+        # @interpolation = interpolation
+        @scale = scale
+        @ratio = ratio
+      end
+
+      def params(img, scale, ratio)
+        width, height = F.send(:image_size, img)
+        area = height * width
+
+        log_ratio = Torch.log(Torch.tensor(ratio))
+        10.times do
+          target_area = area * Torch.empty(1).uniform!(scale[0], scale[1]).item
+          aspect_ratio = Torch.exp(
+            Torch.empty(1).uniform!(log_ratio[0], log_ratio[1])
+          ).item
+
+          w = Math.sqrt(target_area * aspect_ratio).round
+          h = Math.sqrt(target_area / aspect_ratio).round
+
+          if 0 < w && w <= width && 0 < h && h <= height
+            i = Torch.randint(0, height - h + 1, size: [1]).item
+            j = Torch.randint(0, width - w + 1, size: [1]).item
+            return i, j, h, w
+          end
+        end
+
+        # Fallback to central crop
+        in_ratio = width.to_f / height.to_f
+        if in_ratio < ratio.min
+          w = width
+          h = (w / ratio.min).round
+        elsif in_ratio > ratio.max
+          h = height
+          w = (h * ratio.max).round
+        else # whole image
+          w = width
+          h = height
+        end
+        i = (height - h).div(2)
+        j = (width - w).div(2)
+        [i, j, h, w]
+      end
+
+      def forward(img)
+        i, j, h, w = params(img, @scale, @ratio)
+        F.resized_crop(img, i, j, h, w, @size) #, @interpolation)
+      end
+
+      private
+
+      def setup_size(size, error_msg)
+        if size.is_a?(Integer)
+          return [size, size]
+        end
+
+        if size.length != 2
+          raise ArgumentError, error_msg
+        end
+
+        size
+      end
+    end
+  end
+end
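`params` makes up to 10 attempts to sample a crop whose area lies within `scale` (as a fraction of the image area) and whose aspect ratio lies within `ratio`, then falls back to a center crop. Usage is a one-liner; a sketch with an illustrative scale range and a hypothetical file, assuming modules are invoked via `call` as elsewhere in the gem:

```ruby
crop = TorchVision::Transforms::RandomResizedCrop.new(224, scale: [0.5, 1.0])
img = Vips::Image.new_from_file("photo.jpg")
out = crop.call(img) # random region, resized to 224x224
```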
data/lib/torchvision/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
 --- !ruby/object:Gem::Specification
 name: torchvision
 version: !ruby/object:Gem::Version
-  version: 0.2.0
+  version: 0.2.1
 platform: ruby
 authors:
 - Andrew Kane
 autorequire:
 bindir: bin
 cert_chain: []
-date: 2021-03-
+date: 2021-03-15 00:00:00.000000000 Z
 dependencies:
 - !ruby/object:Gem::Dependency
   name: numo-narray
@@ -64,7 +64,9 @@ files:
 - lib/torchvision.rb
 - lib/torchvision/datasets/cifar10.rb
 - lib/torchvision/datasets/cifar100.rb
+- lib/torchvision/datasets/dataset_folder.rb
 - lib/torchvision/datasets/fashion_mnist.rb
+- lib/torchvision/datasets/image_folder.rb
 - lib/torchvision/datasets/kmnist.rb
 - lib/torchvision/datasets/mnist.rb
 - lib/torchvision/datasets/vision_dataset.rb
@@ -90,10 +92,12 @@ files:
 - lib/torchvision/models/vgg19_bn.rb
 - lib/torchvision/models/wide_resnet101_2.rb
 - lib/torchvision/models/wide_resnet50_2.rb
+- lib/torchvision/transforms/center_crop.rb
 - lib/torchvision/transforms/compose.rb
 - lib/torchvision/transforms/functional.rb
 - lib/torchvision/transforms/normalize.rb
 - lib/torchvision/transforms/random_horizontal_flip.rb
+- lib/torchvision/transforms/random_resized_crop.rb
 - lib/torchvision/transforms/random_vertical_flip.rb
 - lib/torchvision/transforms/resize.rb
 - lib/torchvision/transforms/to_tensor.rb