torchvision 0.2.0 → 0.2.1

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 9aa604602112403b7f6738a7bb014812deb528b9a4d480ba89cf7f0c6d01b59c
4
- data.tar.gz: 3ffa29d3ff5234040df51113d085a807d94a45c6998f8182db72bf60c25d67e7
3
+ metadata.gz: bbb87c59c0f081c0de57ccdd62e30bfc551e1cb69523e4ffd498c997e1a2d8b3
4
+ data.tar.gz: 890da113706e659d57194980c5c9262075beb8398a75da2997c0812b70abe308
5
5
  SHA512:
6
- metadata.gz: 89f61279aed314c84ae33c14efe266b849bf74a40a9f90beb650358610b479cb44f6cd247ff56aa89d0227ef6358973b5055fd1ce9bcf1e973402b2118bfe75d
7
- data.tar.gz: 4eb63bfd9a79bd3683238c186e0e70aba494028f54ce8f971a551f362eec4b4e55fe83c8e738c7669fde2c3671e3360a01b02c4d29f6c965640a6c08b3c4379d
6
+ metadata.gz: 3445b62b7824ae16205034881d37c48ac4c70d7e5677014755ae5600632f9ce45168f41b0d3e98c8104eb8337e1566db4df3e0ad5ace5e6a46a5d213d01b6c8d
7
+ data.tar.gz: 93f22c385586ff8a010880676806f6bc9ba2f614c4c14886235a300ea6e2abce0f80c260e255644e1b4d24e6ecddfd21830dde2960a92a9492239e69622d4548
data/CHANGELOG.md CHANGED
@@ -1,3 +1,9 @@
1
+ ## 0.2.1 (2021-03-14)
2
+
3
+ - Added `ImageFolder` and `DatasetFolder`
4
+ - Added `CenterCrop` and `RandomResizedCrop` transforms
5
+ - Added `crop` method
6
+
1
7
  ## 0.2.0 (2021-03-11)
2
8
 
3
9
  - Added `RandomHorizontalFlip`, `RandomVerticalFlip`, and `Resize` transforms
data/README.md CHANGED
@@ -6,7 +6,7 @@
6
6
 
7
7
  ## Installation
8
8
 
9
- First, [install libvips](libvips-installation). For Homebrew, use:
9
+ First, [install libvips](#libvips-installation). For Homebrew, use:
10
10
 
11
11
  ```sh
12
12
  brew install vips
@@ -25,7 +25,8 @@ This library follows the [Python API](https://pytorch.org/docs/stable/torchvisio
25
25
  ## Examples
26
26
 
27
27
  - [MNIST](https://github.com/ankane/torch.rb/tree/master/examples/mnist)
28
- - [Generative Adversarial Networks](https://github.com/ankane/torch.rb/tree/master/examples/gan)
28
+ - [Transfer learning](https://github.com/ankane/torch.rb/tree/master/examples/transfer-learning)
29
+ - [Generative adversarial networks](https://github.com/ankane/torch.rb/tree/master/examples/gan)
29
30
 
30
31
  ## Datasets
31
32
 
@@ -54,9 +55,11 @@ TorchVision::Transforms::Compose.new([
54
55
 
55
56
  Supported transforms are:
56
57
 
58
+ - CenterCrop
57
59
  - Compose
58
60
  - Normalize
59
61
  - RandomHorizontalFlip
62
+ - RandomResizedCrop
60
63
  - RandomVerticalFlip
61
64
  - Resize
62
65
  - ToTensor
@@ -130,12 +133,16 @@ net.load_state_dict(Torch.load("net.pth"))
130
133
 
131
134
  ## libvips Installation
132
135
 
133
- ### Ubuntu
136
+ ### Linux
137
+
138
+ Check your package manager. For Ubuntu, use:
134
139
 
135
140
  ```sh
136
141
  sudo apt install libvips
137
142
  ```
138
143
 
144
+ You can also [build from source](https://libvips.github.io/libvips/install.html).
145
+
139
146
  ### Mac
140
147
 
141
148
  ```sh
data/lib/torchvision.rb CHANGED
@@ -16,6 +16,8 @@ require "torchvision/version"
16
16
 
17
17
  # datasets
18
18
  require "torchvision/datasets/vision_dataset"
19
+ require "torchvision/datasets/dataset_folder"
20
+ require "torchvision/datasets/image_folder"
19
21
  require "torchvision/datasets/cifar10"
20
22
  require "torchvision/datasets/cifar100"
21
23
  require "torchvision/datasets/mnist"
@@ -47,10 +49,12 @@ require "torchvision/models/wide_resnet50_2"
47
49
  require "torchvision/models/wide_resnet101_2"
48
50
 
49
51
  # transforms
52
+ require "torchvision/transforms/center_crop"
50
53
  require "torchvision/transforms/compose"
51
54
  require "torchvision/transforms/functional"
52
55
  require "torchvision/transforms/normalize"
53
56
  require "torchvision/transforms/random_horizontal_flip"
57
+ require "torchvision/transforms/random_resized_crop"
54
58
  require "torchvision/transforms/random_vertical_flip"
55
59
  require "torchvision/transforms/resize"
56
60
  require "torchvision/transforms/to_tensor"
@@ -0,0 +1,91 @@
1
module TorchVision
  module Datasets
    # Generic dataset for a directory tree laid out as root/<class_name>/<sample files>.
    #
    # Every immediate subdirectory of +root+ is a class; class indices follow the
    # sorted order of the subdirectory names. Samples are the regular files found
    # (recursively) below each class directory that pass either the +extensions+
    # filter or the +is_valid_file+ predicate -- exactly one of the two must be given.
    # Each sample is decoded with Vips::Image.new_from_file when it is accessed.
    class DatasetFolder < VisionDataset
      attr_reader :classes

      # @param root [String] dataset root directory
      # @param extensions [Array<String>, nil] accepted filename suffixes, compared
      #   against the lowercased filename (so they should be lowercase, e.g. ".png")
      # @param transform [#call, nil] applied to each loaded sample
      # @param target_transform [#call, nil] applied to each class index
      # @param is_valid_file [#call, nil] predicate called with a file's full path
      # @raise [ArgumentError] if both or neither of +extensions+ and +is_valid_file+ are given
      # @raise [RuntimeError] if no matching files are found
      def initialize(root, extensions: nil, transform: nil, target_transform: nil, is_valid_file: nil)
        super(root, transform: transform, target_transform: target_transform)
        classes, class_to_idx = find_classes(@root)
        samples = make_dataset(@root, class_to_idx, extensions, is_valid_file)
        if samples.empty?
          msg = "Found 0 files in subfolders of: #{@root}\n"
          msg += "Supported extensions are: #{extensions.join(",")}" unless extensions.nil?
          raise RuntimeError, msg
        end

        @loader = lambda do |path|
          Vips::Image.new_from_file(path)
        end
        @extensions = extensions

        @classes = classes
        @class_to_idx = class_to_idx
        @samples = samples
        @targets = samples.map { |s| s[1] }
      end

      # Loads the sample at +index+ and applies the configured transforms.
      #
      # @return [Array] [sample, target]
      def [](index)
        path, target = @samples[index]
        sample = @loader.call(path)
        sample = @transform.call(sample) if @transform
        target = @target_transform.call(target) if @target_transform

        [sample, target]
      end

      # @return [Integer] number of samples
      def size
        @samples.size
      end

      private

      # Lists the class subdirectories of +dir+ in sorted order.
      #
      # @return [Array(Array<String>, Hash{String => Integer})] class names and a
      #   map from class name to index
      def find_classes(dir)
        classes = Dir.children(dir).select { |d| File.directory?(File.join(dir, d)) }
        classes.sort!
        class_to_idx = classes.map.with_index.to_h
        [classes, class_to_idx]
      end

      # True if the lowercased +filename+ ends with one of +extensions+.
      def has_file_allowed_extension(filename, extensions)
        filename = filename.downcase
        extensions.any? { |ext| filename.end_with?(ext) }
      end

      # Collects [path, class_index] pairs for every accepted file below each
      # class directory.
      def make_dataset(directory, class_to_idx, extensions, is_valid_file)
        instances = []
        directory = File.expand_path(directory)
        both_none = extensions.nil? && is_valid_file.nil?
        both_something = !extensions.nil? && !is_valid_file.nil?
        if both_none || both_something
          raise ArgumentError, "Both extensions and is_valid_file cannot be None or not None at the same time"
        end
        unless extensions.nil?
          is_valid_file = lambda do |x|
            has_file_allowed_extension(x, extensions)
          end
        end
        class_to_idx.keys.sort.each do |target_class|
          class_index = class_to_idx[target_class]
          target_dir = File.join(directory, target_class)
          next unless File.directory?(target_dir)

          # A bare "**" glob matches like "*" (top level only); "**/*" also
          # descends into nested subdirectories.
          Dir.glob("**/*", base: target_dir).sort.each do |fname|
            path = File.join(target_dir, fname)
            # Only regular files are samples; the glob also returns directories.
            next unless File.file?(path)

            instances << [path, class_index] if is_valid_file.call(path)
          end
        end
        instances
      end
    end
  end
end
@@ -0,0 +1,12 @@
1
module TorchVision
  module Datasets
    # DatasetFolder preconfigured to accept common image file extensions.
    class ImageFolder < DatasetFolder
      IMG_EXTENSIONS = [".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp"].freeze

      # @param root [String] dataset root directory
      # @param transform [#call, nil] applied to each loaded image
      # @param target_transform [#call, nil] applied to each class index
      # @param is_valid_file [#call, nil] custom predicate on file paths. When
      #   given, it replaces the IMG_EXTENSIONS filter; DatasetFolder rejects
      #   receiving both an extension list and a predicate.
      def initialize(root, transform: nil, target_transform: nil, is_valid_file: nil)
        extensions = is_valid_file.nil? ? IMG_EXTENSIONS : nil
        super(root, extensions: extensions, transform: transform, target_transform: target_transform, is_valid_file: is_valid_file)
        @imgs = @samples
      end
    end
  end
end
@@ -1,7 +1,6 @@
1
1
  module TorchVision
2
2
  module Datasets
3
- # TODO inherit Torch::Utils::Data::Dataset
4
- class VisionDataset
3
+ class VisionDataset < Torch::Utils::Data::Dataset
5
4
  attr_reader :data, :targets
6
5
 
7
6
  def initialize(root, transforms: nil, transform: nil, target_transform: nil)
@@ -0,0 +1,13 @@
1
module TorchVision
  module Transforms
    # Transform that crops the center of an image; delegates to F.center_crop.
    class CenterCrop < Torch::NN::Module
      # @param size [Integer, Array<Integer>] crop size: an Integer or a
      #   one-element Array gives a square crop, a two-element Array is
      #   read as [height, width]
      def initialize(size)
        # Set up Torch::NN::Module's internal state (as RandomResizedCrop does).
        super()
        @size = size
      end

      def forward(img)
        F.center_crop(img, @size)
      end
    end
  end
end
@@ -34,21 +34,23 @@ module TorchVision
34
34
 
35
35
  def resize(img, size)
36
36
  raise "img should be Vips::Image. Got #{img.class.name}" unless img.is_a?(Vips::Image)
37
- # TODO support array size
38
- raise "Got inappropriate size arg: #{size}" unless size.is_a?(Integer)
39
37
 
40
- w, h = img.size
41
- if (w <= h && w == size) || (h <= w && h == size)
42
- return img
43
- end
44
- if w < h
45
- ow = size
46
- oh = (size * h / w).to_i
47
- img.thumbnail_image(ow, height: oh)
38
+ if size.is_a?(Integer)
39
+ w, h = img.size
40
+ if (w <= h && w == size) || (h <= w && h == size)
41
+ return img
42
+ end
43
+ if w < h
44
+ ow = size
45
+ oh = (size * h / w).to_i
46
+ img.thumbnail_image(ow, height: oh)
47
+ else
48
+ oh = size
49
+ ow = (size * w / h).to_i
50
+ img.thumbnail_image(ow, height: oh)
51
+ end
48
52
  else
49
- oh = size
50
- ow = (size * w / h).to_i
51
- img.thumbnail_image(ow, height: oh)
53
+ img.thumbnail_image(size[0], height: size[1], size: :force)
52
54
  end
53
55
  end
54
56
 
@@ -90,6 +92,7 @@ module TorchVision
90
92
 
91
93
  def hflip(img)
92
94
  if img.is_a?(Torch::Tensor)
95
+ assert_image_tensor(img)
93
96
  img.flip(-1)
94
97
  else
95
98
  img.flip(:horizontal)
@@ -98,11 +101,76 @@ module TorchVision
98
101
 
99
102
  def vflip(img)
100
103
  if img.is_a?(Torch::Tensor)
104
+ assert_image_tensor(img)
101
105
  img.flip(-2)
102
106
  else
103
107
  img.flip(:vertical)
104
108
  end
105
109
  end
110
+
111
# Cuts out the box whose top-left corner is (+top+, +left+) with the given
# +height+ and +width+. Tensors are sliced on their last two dimensions
# (leading dimensions are kept whole); other images use Vips'
# crop(left, top, width, height).
def crop(img, top, left, height, width)
  return img.crop(left, top, width, height) unless img.is_a?(Torch::Tensor)

  assert_image_tensor(img)
  leading = Array.new(img.dim - 2, true)
  img[*leading, top...(top + height), left...(left + width)]
end
120
+
121
# Crops the central region of +img+ (a Vips image or a tensor).
#
# +output_size+ may be an Integer (square crop), a one-element Array, or a
# [height, width] Array. When the crop is larger than the image, the image is
# first padded with +pad+ before cropping.
#
# Offsets are rounded half-to-even, matching Python's round() as used by
# PyTorch's torchvision, so both pick the same window when the margin is odd.
def center_crop(img, output_size)
  if output_size.is_a?(Integer)
    output_size = [output_size, output_size]
  elsif output_size.is_a?(Array) && output_size.length == 1
    output_size = [output_size[0], output_size[0]]
  end

  image_width, image_height = image_size(img)
  crop_height, crop_width = output_size

  if crop_width > image_width || crop_height > image_height
    padding_ltrb = [
      crop_width > image_width ? (crop_width - image_width).div(2) : 0,
      crop_height > image_height ? (crop_height - image_height).div(2) : 0,
      crop_width > image_width ? (crop_width - image_width + 1).div(2) : 0,
      crop_height > image_height ? (crop_height - image_height + 1).div(2) : 0
    ]
    # TODO: NOTE(review) `pad` is not defined in the visible code; confirm it
    # exists, otherwise this branch raises NoMethodError.
    img = pad(img, padding_ltrb, fill: 0)
    image_width, image_height = image_size(img)
    return img if crop_width == image_width && crop_height == image_height
  end

  crop_top = ((image_height - crop_height) / 2.0).round(half: :even)
  crop_left = ((image_width - crop_width) / 2.0).round(half: :even)
  crop(img, crop_top, crop_left, crop_height, crop_width)
end
150
+
151
# Crops +img+ to the box (+top+, +left+, +height+, +width+), then resizes
# the cropped image to +size+.
# TODO interpolation: resize(img, size) does not accept an interpolation mode yet.
def resized_crop(img, top, left, height, width, size)
  cropped = crop(img, top, left, height, width)
  resize(cropped, size)
end
157
+
158
+ private
159
+
160
# Returns [width, height]. For a tensor these are its last and second-to-last
# dimensions (after a rank check); otherwise the image's width/height readers.
def image_size(img)
  return [img.width, img.height] unless img.is_a?(Torch::Tensor)

  assert_image_tensor(img)
  dims = img.shape
  [dims[-1], dims[-2]]
end
168
+
169
# Raises TypeError when +img+ has fewer than two dimensions; returns nil otherwise.
def assert_image_tensor(img)
  raise TypeError, "Tensor is not a torch image." if img.ndim < 2
end
106
174
  end
107
175
  end
108
176
 
@@ -0,0 +1,70 @@
1
module TorchVision
  module Transforms
    # Crops a random region of an image and resizes it to +size+.
    #
    # The region's area is a random fraction (drawn from +scale+) of the image
    # area, and its aspect ratio (width / height) is drawn log-uniformly from
    # +ratio+. If no such region fits after 10 attempts, a center crop whose
    # aspect ratio is clamped to +ratio+ is used instead.
    class RandomResizedCrop < Torch::NN::Module
      # @param size [Integer, Array<Integer>] output size; an Integer or a
      #   one-element Array gives a square output, otherwise [height, width]
      # @param scale [Array<Float>] [min, max] fraction of the image area to crop
      # @param ratio [Array<Float>] [min, max] aspect ratio (width / height) of the crop
      # @raise [ArgumentError] if +size+ is an Array with a length other than 1 or 2
      def initialize(size, scale: [0.08, 1.0], ratio: [3.0 / 4.0, 4.0 / 3.0])
        super()
        @size = setup_size(size, "Please provide only two dimensions (h, w) for size.")
        # @interpolation = interpolation
        @scale = scale
        @ratio = ratio
      end

      # Samples crop parameters for +img+.
      #
      # @return [Array<Integer>] [top, left, height, width]
      def params(img, scale, ratio)
        width, height = F.send(:image_size, img)
        area = height * width

        log_ratio = Torch.log(Torch.tensor(ratio))
        10.times do
          target_area = area * Torch.empty(1).uniform!(scale[0], scale[1]).item
          # Pass plain Floats as the bounds rather than 0-dim tensors.
          aspect_ratio = Torch.exp(
            Torch.empty(1).uniform!(log_ratio[0].item, log_ratio[1].item)
          ).item

          w = Math.sqrt(target_area * aspect_ratio).round
          h = Math.sqrt(target_area / aspect_ratio).round

          if 0 < w && w <= width && 0 < h && h <= height
            i = Torch.randint(0, height - h + 1, size: [1]).item
            j = Torch.randint(0, width - w + 1, size: [1]).item
            return i, j, h, w
          end
        end

        # Fallback to central crop
        in_ratio = width.to_f / height.to_f
        if in_ratio < ratio.min
          w = width
          h = (w / ratio.min).round
        elsif in_ratio > ratio.max
          h = height
          w = (h * ratio.max).round
        else # whole image
          w = width
          h = height
        end
        i = (height - h).div(2)
        j = (width - w).div(2)
        [i, j, h, w]
      end

      def forward(img)
        i, j, h, w = params(img, @scale, @ratio)
        F.resized_crop(img, i, j, h, w, @size) #, @interpolation)
      end

      private

      # Normalizes +size+ to a two-element [height, width] Array.
      def setup_size(size, error_msg)
        return [size, size] if size.is_a?(Integer)
        return [size[0], size[0]] if size.length == 1

        raise ArgumentError, error_msg if size.length != 2

        size
      end
    end
  end
end
@@ -1,3 +1,3 @@
1
1
  module TorchVision
2
- VERSION = "0.2.0"
2
+ VERSION = "0.2.1"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torchvision
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.2.0
4
+ version: 0.2.1
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2021-03-11 00:00:00.000000000 Z
11
+ date: 2021-03-15 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: numo-narray
@@ -64,7 +64,9 @@ files:
64
64
  - lib/torchvision.rb
65
65
  - lib/torchvision/datasets/cifar10.rb
66
66
  - lib/torchvision/datasets/cifar100.rb
67
+ - lib/torchvision/datasets/dataset_folder.rb
67
68
  - lib/torchvision/datasets/fashion_mnist.rb
69
+ - lib/torchvision/datasets/image_folder.rb
68
70
  - lib/torchvision/datasets/kmnist.rb
69
71
  - lib/torchvision/datasets/mnist.rb
70
72
  - lib/torchvision/datasets/vision_dataset.rb
@@ -90,10 +92,12 @@ files:
90
92
  - lib/torchvision/models/vgg19_bn.rb
91
93
  - lib/torchvision/models/wide_resnet101_2.rb
92
94
  - lib/torchvision/models/wide_resnet50_2.rb
95
+ - lib/torchvision/transforms/center_crop.rb
93
96
  - lib/torchvision/transforms/compose.rb
94
97
  - lib/torchvision/transforms/functional.rb
95
98
  - lib/torchvision/transforms/normalize.rb
96
99
  - lib/torchvision/transforms/random_horizontal_flip.rb
100
+ - lib/torchvision/transforms/random_resized_crop.rb
97
101
  - lib/torchvision/transforms/random_vertical_flip.rb
98
102
  - lib/torchvision/transforms/resize.rb
99
103
  - lib/torchvision/transforms/to_tensor.rb