torchvision 0.2.0 → 0.2.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
-   metadata.gz: 9aa604602112403b7f6738a7bb014812deb528b9a4d480ba89cf7f0c6d01b59c
-   data.tar.gz: 3ffa29d3ff5234040df51113d085a807d94a45c6998f8182db72bf60c25d67e7
+   metadata.gz: bbb87c59c0f081c0de57ccdd62e30bfc551e1cb69523e4ffd498c997e1a2d8b3
+   data.tar.gz: 890da113706e659d57194980c5c9262075beb8398a75da2997c0812b70abe308
  SHA512:
-   metadata.gz: 89f61279aed314c84ae33c14efe266b849bf74a40a9f90beb650358610b479cb44f6cd247ff56aa89d0227ef6358973b5055fd1ce9bcf1e973402b2118bfe75d
-   data.tar.gz: 4eb63bfd9a79bd3683238c186e0e70aba494028f54ce8f971a551f362eec4b4e55fe83c8e738c7669fde2c3671e3360a01b02c4d29f6c965640a6c08b3c4379d
+   metadata.gz: 3445b62b7824ae16205034881d37c48ac4c70d7e5677014755ae5600632f9ce45168f41b0d3e98c8104eb8337e1566db4df3e0ad5ace5e6a46a5d213d01b6c8d
+   data.tar.gz: 93f22c385586ff8a010880676806f6bc9ba2f614c4c14886235a300ea6e2abce0f80c260e255644e1b4d24e6ecddfd21830dde2960a92a9492239e69622d4548
data/CHANGELOG.md CHANGED
@@ -1,3 +1,9 @@
+ ## 0.2.1 (2021-03-14)
+
+ - Added `ImageFolder` and `DatasetFolder`
+ - Added `CenterCrop` and `RandomResizedCrop` transforms
+ - Added `crop` method
+
  ## 0.2.0 (2021-03-11)

  - Added `RandomHorizontalFlip`, `RandomVerticalFlip`, and `Resize` transforms
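To put the 0.2.1 additions in context, here is a minimal usage sketch combining the new `ImageFolder` dataset with the new `RandomResizedCrop` transform (the `./data/train` path, the class subfolder names, and the crop size are illustrative, not part of the package):

```ruby
require "torchvision"

# RandomResizedCrop and CenterCrop are new in 0.2.1
transform = TorchVision::Transforms::Compose.new([
  TorchVision::Transforms::RandomResizedCrop.new(224),
  TorchVision::Transforms::ToTensor.new
])

# ImageFolder (also new) expects one subdirectory per class,
# e.g. ./data/train/cats and ./data/train/dogs
dataset = TorchVision::Datasets::ImageFolder.new("./data/train", transform: transform)
image, label = dataset[0]  # transformed sample and its class index
```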
data/README.md CHANGED
@@ -6,7 +6,7 @@
 
  ## Installation
 
- First, [install libvips](libvips-installation). For Homebrew, use:
+ First, [install libvips](#libvips-installation). For Homebrew, use:
 
  ```sh
  brew install vips
@@ -25,7 +25,8 @@ This library follows the [Python API](https://pytorch.org/docs/stable/torchvisio
  ## Examples
 
  - [MNIST](https://github.com/ankane/torch.rb/tree/master/examples/mnist)
- - [Generative Adversarial Networks](https://github.com/ankane/torch.rb/tree/master/examples/gan)
+ - [Transfer learning](https://github.com/ankane/torch.rb/tree/master/examples/transfer-learning)
+ - [Generative adversarial networks](https://github.com/ankane/torch.rb/tree/master/examples/gan)
 
  ## Datasets
 
@@ -54,9 +55,11 @@ TorchVision::Transforms::Compose.new([
 
  Supported transforms are:
 
+ - CenterCrop
  - Compose
  - Normalize
  - RandomHorizontalFlip
+ - RandomResizedCrop
  - RandomVerticalFlip
  - Resize
  - ToTensor
@@ -130,12 +133,16 @@ net.load_state_dict(Torch.load("net.pth"))
 
  ## libvips Installation
 
- ### Ubuntu
+ ### Linux
+
+ Check your package manager. For Ubuntu, use:
 
  ```sh
  sudo apt install libvips
  ```
 
+ You can also [build from source](https://libvips.github.io/libvips/install.html).
+
 
  ### Mac
  ```sh
data/lib/torchvision.rb CHANGED
@@ -16,6 +16,8 @@ require "torchvision/version"
 
  # datasets
  require "torchvision/datasets/vision_dataset"
+ require "torchvision/datasets/dataset_folder"
+ require "torchvision/datasets/image_folder"
  require "torchvision/datasets/cifar10"
  require "torchvision/datasets/cifar100"
  require "torchvision/datasets/mnist"
@@ -47,10 +49,12 @@ require "torchvision/models/wide_resnet50_2"
  require "torchvision/models/wide_resnet101_2"
 
  # transforms
+ require "torchvision/transforms/center_crop"
  require "torchvision/transforms/compose"
  require "torchvision/transforms/functional"
  require "torchvision/transforms/normalize"
  require "torchvision/transforms/random_horizontal_flip"
+ require "torchvision/transforms/random_resized_crop"
  require "torchvision/transforms/random_vertical_flip"
  require "torchvision/transforms/resize"
  require "torchvision/transforms/to_tensor"
data/lib/torchvision/datasets/dataset_folder.rb ADDED
@@ -0,0 +1,91 @@
+ module TorchVision
+   module Datasets
+     class DatasetFolder < VisionDataset
+       attr_reader :classes
+
+       def initialize(root, extensions: nil, transform: nil, target_transform: nil, is_valid_file: nil)
+         super(root, transform: transform, target_transform: target_transform)
+         classes, class_to_idx = find_classes(@root)
+         samples = make_dataset(@root, class_to_idx, extensions, is_valid_file)
+         if samples.empty?
+           msg = "Found 0 files in subfolders of: #{@root}\n"
+           unless extensions.nil?
+             msg += "Supported extensions are: #{extensions.join(",")}"
+           end
+           raise RuntimeError, msg
+         end
+
+         @loader = lambda do |path|
+           Vips::Image.new_from_file(path)
+         end
+         @extensions = extensions
+
+         @classes = classes
+         @class_to_idx = class_to_idx
+         @samples = samples
+         @targets = samples.map { |s| s[1] }
+       end
+
+       def [](index)
+         path, target = @samples[index]
+         sample = @loader.call(path)
+         if @transform
+           sample = @transform.call(sample)
+         end
+         if @target_transform
+           target = @target_transform.call(target)
+         end
+
+         [sample, target]
+       end
+
+       def size
+         @samples.size
+       end
+
+       private
+
+       def find_classes(dir)
+         classes = Dir.children(dir).select { |d| File.directory?(File.join(dir, d)) }
+         classes.sort!
+         class_to_idx = classes.map.with_index.to_h
+         [classes, class_to_idx]
+       end
+
+       def has_file_allowed_extension(filename, extensions)
+         filename = filename.downcase
+         extensions.any? { |ext| filename.end_with?(ext) }
+       end
+
+       def make_dataset(directory, class_to_idx, extensions, is_valid_file)
+         instances = []
+         directory = File.expand_path(directory)
+         both_none = extensions.nil? && is_valid_file.nil?
+         both_something = !extensions.nil? && !is_valid_file.nil?
+         if both_none || both_something
+           raise ArgumentError, "Both extensions and is_valid_file cannot be None or not None at the same time"
+         end
+         if !extensions.nil?
+           is_valid_file = lambda do |x|
+             has_file_allowed_extension(x, extensions)
+           end
+         end
+         class_to_idx.keys.sort.each do |target_class|
+           class_index = class_to_idx[target_class]
+           target_dir = File.join(directory, target_class)
+           if !File.directory?(target_dir)
+             next
+           end
+           Dir.glob("**", base: target_dir).sort.each do |fname|
+             path = File.join(target_dir, fname)
+             if is_valid_file.call(path)
+               item = [path, class_index]
+               instances << item
+             end
+           end
+         end
+         instances
+       end
+     end
+   end
+ end
data/lib/torchvision/datasets/image_folder.rb ADDED
@@ -0,0 +1,12 @@
+ module TorchVision
+   module Datasets
+     class ImageFolder < DatasetFolder
+       IMG_EXTENSIONS = [".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp"]
+
+       def initialize(root, transform: nil, target_transform: nil, is_valid_file: nil)
+         super(root, extensions: IMG_EXTENSIONS, transform: transform, target_transform: target_transform, is_valid_file: is_valid_file)
+         @imgs = @samples
+       end
+     end
+   end
+ end
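For orientation, a sketch of how `DatasetFolder`/`ImageFolder` map a directory tree to `(sample, target)` pairs; the `pets` layout below is hypothetical:

```ruby
# Hypothetical layout:
#   pets/
#     cat/  1.jpg  2.jpg
#     dog/  3.jpg
#
# Class names are the sorted subdirectory names, mapped to indices in order,
# so class_to_idx becomes {"cat" => 0, "dog" => 1}.
dataset = TorchVision::Datasets::ImageFolder.new("pets")
dataset.classes              # => ["cat", "dog"]
dataset.size                 # => 3
sample, target = dataset[0]  # Vips::Image loaded from pets/cat/1.jpg, target 0
```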
data/lib/torchvision/datasets/vision_dataset.rb CHANGED
@@ -1,7 +1,6 @@
  module TorchVision
    module Datasets
-     # TODO inherit Torch::Utils::Data::Dataset
-     class VisionDataset
+     class VisionDataset < Torch::Utils::Data::Dataset
        attr_reader :data, :targets
 
        def initialize(root, transforms: nil, transform: nil, target_transform: nil)
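Because `VisionDataset` now inherits from `Torch::Utils::Data::Dataset`, the folder datasets should plug into torch.rb's data loading. A hedged sketch (it assumes `Torch::Utils::Data::DataLoader` with `batch_size:`/`shuffle:` options and a `ToTensor`-based transform so samples can be batched; the `pets` path is hypothetical):

```ruby
transform = TorchVision::Transforms::Compose.new([
  TorchVision::Transforms::Resize.new(64),
  TorchVision::Transforms::CenterCrop.new(64),
  TorchVision::Transforms::ToTensor.new
])
dataset = TorchVision::Datasets::ImageFolder.new("pets", transform: transform)
loader = Torch::Utils::Data::DataLoader.new(dataset, batch_size: 32, shuffle: true)
loader.each do |images, labels|
  # training step goes here
end
```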
data/lib/torchvision/transforms/center_crop.rb ADDED
@@ -0,0 +1,13 @@
+ module TorchVision
+   module Transforms
+     class CenterCrop < Torch::NN::Module
+       def initialize(size)
+         @size = size
+       end
+
+       def forward(img)
+         F.center_crop(img, @size)
+       end
+     end
+   end
+ end
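`CenterCrop` delegates to the functional `center_crop` helper added in the next hunk: it takes the centered window of the requested size (the padding path for crops larger than the image is still marked TODO there). A small sketch, with an illustrative file path and sizes:

```ruby
img = Vips::Image.new_from_file("photo.jpg")  # hypothetical file, e.g. 640x480
out = TorchVision::Transforms::CenterCrop.new(224).forward(img)
# out is the central 224x224 region of img

# The functional helpers are also written to accept image tensors
# (...xHxW), cropping the last two dimensions.
tensor = Torch.rand(3, 256, 320)
patch = TorchVision::Transforms::CenterCrop.new([200, 200]).forward(tensor)
```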
data/lib/torchvision/transforms/functional.rb CHANGED
@@ -34,21 +34,23 @@ module TorchVision
 
        def resize(img, size)
          raise "img should be Vips::Image. Got #{img.class.name}" unless img.is_a?(Vips::Image)
-         # TODO support array size
-         raise "Got inappropriate size arg: #{size}" unless size.is_a?(Integer)
 
-         w, h = img.size
-         if (w <= h && w == size) || (h <= w && h == size)
-           return img
-         end
-         if w < h
-           ow = size
-           oh = (size * h / w).to_i
-           img.thumbnail_image(ow, height: oh)
+         if size.is_a?(Integer)
+           w, h = img.size
+           if (w <= h && w == size) || (h <= w && h == size)
+             return img
+           end
+           if w < h
+             ow = size
+             oh = (size * h / w).to_i
+             img.thumbnail_image(ow, height: oh)
+           else
+             oh = size
+             ow = (size * w / h).to_i
+             img.thumbnail_image(ow, height: oh)
+           end
          else
-           oh = size
-           ow = (size * w / h).to_i
-           img.thumbnail_image(ow, height: oh)
+           img.thumbnail_image(size[0], height: size[1], size: :force)
          end
        end
 
@@ -90,6 +92,7 @@ module TorchVision
 
        def hflip(img)
          if img.is_a?(Torch::Tensor)
+           assert_image_tensor(img)
            img.flip(-1)
          else
            img.flip(:horizontal)
@@ -98,11 +101,76 @@ module TorchVision
 
        def vflip(img)
          if img.is_a?(Torch::Tensor)
+           assert_image_tensor(img)
            img.flip(-2)
          else
            img.flip(:vertical)
          end
        end
+
+       def crop(img, top, left, height, width)
+         if img.is_a?(Torch::Tensor)
+           assert_image_tensor(img)
+           indexes = [true] * (img.dim - 2)
+           img[*indexes, top...(top + height), left...(left + width)]
+         else
+           img.crop(left, top, width, height)
+         end
+       end
+
+       def center_crop(img, output_size)
+         if output_size.is_a?(Integer)
+           output_size = [output_size.to_i, output_size.to_i]
+         elsif output_size.is_a?(Array) && output_size.length == 1
+           output_size = [output_size[0], output_size[0]]
+         end
+
+         image_width, image_height = image_size(img)
+         crop_height, crop_width = output_size
+
+         if crop_width > image_width || crop_height > image_height
+           padding_ltrb = [
+             crop_width > image_width ? (crop_width - image_width).div(2) : 0,
+             crop_height > image_height ? (crop_height - image_height).div(2) : 0,
+             crop_width > image_width ? (crop_width - image_width + 1).div(2) : 0,
+             crop_height > image_height ? (crop_height - image_height + 1).div(2) : 0
+           ]
+           # TODO
+           img = pad(img, padding_ltrb, fill: 0)
+           image_width, image_height = image_size(img)
+           if crop_width == image_width && crop_height == image_height
+             return img
+           end
+         end
+
+         crop_top = ((image_height - crop_height) / 2.0).round
+         crop_left = ((image_width - crop_width) / 2.0).round
+         crop(img, crop_top, crop_left, crop_height, crop_width)
+       end
+
+       # TODO interpolation
+       def resized_crop(img, top, left, height, width, size)
+         img = crop(img, top, left, height, width)
+         img = resize(img, size) #, interpolation)
+         img
+       end
+
+       private
+
+       def image_size(img)
+         if img.is_a?(Torch::Tensor)
+           assert_image_tensor(img)
+           [img.shape[-1], img.shape[-2]]
+         else
+           [img.width, img.height]
+         end
+       end
+
+       def assert_image_tensor(img)
+         if img.ndim < 2
+           raise TypeError, "Tensor is not a torch image."
+         end
+       end
      end
    end
 
data/lib/torchvision/transforms/random_resized_crop.rb ADDED
@@ -0,0 +1,70 @@
+ module TorchVision
+   module Transforms
+     class RandomResizedCrop < Torch::NN::Module
+       def initialize(size, scale: [0.08, 1.0], ratio: [3.0 / 4.0, 4.0 / 3.0])
+         super()
+         @size = setup_size(size, "Please provide only two dimensions (h, w) for size.")
+         # @interpolation = interpolation
+         @scale = scale
+         @ratio = ratio
+       end
+
+       def params(img, scale, ratio)
+         width, height = F.send(:image_size, img)
+         area = height * width
+
+         log_ratio = Torch.log(Torch.tensor(ratio))
+         10.times do
+           target_area = area * Torch.empty(1).uniform!(scale[0], scale[1]).item
+           aspect_ratio = Torch.exp(
+             Torch.empty(1).uniform!(log_ratio[0], log_ratio[1])
+           ).item
+
+           w = Math.sqrt(target_area * aspect_ratio).round
+           h = Math.sqrt(target_area / aspect_ratio).round
+
+           if 0 < w && w <= width && 0 < h && h <= height
+             i = Torch.randint(0, height - h + 1, size: [1]).item
+             j = Torch.randint(0, width - w + 1, size: [1]).item
+             return i, j, h, w
+           end
+         end
+
+         # Fallback to central crop
+         in_ratio = width.to_f / height.to_f
+         if in_ratio < ratio.min
+           w = width
+           h = (w / ratio.min).round
+         elsif in_ratio > ratio.max
+           h = height
+           w = (h * ratio.max).round
+         else # whole image
+           w = width
+           h = height
+         end
+         i = (height - h).div(2)
+         j = (width - w).div(2)
+         [i, j, h, w]
+       end
+
+       def forward(img)
+         i, j, h, w = params(img, @scale, @ratio)
+         F.resized_crop(img, i, j, h, w, @size) #, @interpolation)
+       end
+
+       private
+
+       def setup_size(size, error_msg)
+         if size.is_a?(Integer)
+           return [size, size]
+         end
+
+         if size.length != 2
+           raise ArgumentError, error_msg
+         end
+
+         size
+       end
+     end
+   end
+ end
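`RandomResizedCrop#params` samples a crop whose area falls within `scale` (as a fraction of the source area) and whose aspect ratio falls within `ratio`, falling back to a central crop after 10 failed attempts; `forward` then resizes that crop to `size`. A usage sketch (the path and parameter values are illustrative):

```ruby
transform = TorchVision::Transforms::RandomResizedCrop.new(
  [224, 224],
  scale: [0.5, 1.0],               # keep 50-100% of the source area
  ratio: [3.0 / 4.0, 4.0 / 3.0]    # allow mild aspect-ratio distortion
)
img = Vips::Image.new_from_file("photo.jpg")  # hypothetical path
out = transform.forward(img)                  # 224x224 Vips::Image
```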
data/lib/torchvision/version.rb CHANGED
@@ -1,3 +1,3 @@
  module TorchVision
-   VERSION = "0.2.0"
+   VERSION = "0.2.1"
  end
metadata CHANGED
@@ -1,14 +1,14 @@
  --- !ruby/object:Gem::Specification
  name: torchvision
  version: !ruby/object:Gem::Version
-   version: 0.2.0
+   version: 0.2.1
  platform: ruby
  authors:
  - Andrew Kane
  autorequire:
  bindir: bin
  cert_chain: []
- date: 2021-03-11 00:00:00.000000000 Z
+ date: 2021-03-15 00:00:00.000000000 Z
  dependencies:
  - !ruby/object:Gem::Dependency
    name: numo-narray
@@ -64,7 +64,9 @@ files:
  - lib/torchvision.rb
  - lib/torchvision/datasets/cifar10.rb
  - lib/torchvision/datasets/cifar100.rb
+ - lib/torchvision/datasets/dataset_folder.rb
  - lib/torchvision/datasets/fashion_mnist.rb
+ - lib/torchvision/datasets/image_folder.rb
  - lib/torchvision/datasets/kmnist.rb
  - lib/torchvision/datasets/mnist.rb
  - lib/torchvision/datasets/vision_dataset.rb
@@ -90,10 +92,12 @@ files:
  - lib/torchvision/models/vgg19_bn.rb
  - lib/torchvision/models/wide_resnet101_2.rb
  - lib/torchvision/models/wide_resnet50_2.rb
+ - lib/torchvision/transforms/center_crop.rb
  - lib/torchvision/transforms/compose.rb
  - lib/torchvision/transforms/functional.rb
  - lib/torchvision/transforms/normalize.rb
  - lib/torchvision/transforms/random_horizontal_flip.rb
+ - lib/torchvision/transforms/random_resized_crop.rb
  - lib/torchvision/transforms/random_vertical_flip.rb
  - lib/torchvision/transforms/resize.rb
  - lib/torchvision/transforms/to_tensor.rb