torchvision 0.1.3 → 0.2.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
- metadata.gz: 44c29605f12dddf8196432223f2137ef9b9bef490996b718f5a9bdbc13dbd33f
- data.tar.gz: 8790200a0ed8f7a275f99327431dc8be99a7578c473ec03411f9411cd10c6c93
+ metadata.gz: 9aa604602112403b7f6738a7bb014812deb528b9a4d480ba89cf7f0c6d01b59c
+ data.tar.gz: 3ffa29d3ff5234040df51113d085a807d94a45c6998f8182db72bf60c25d67e7
  SHA512:
- metadata.gz: 65816ef10f524781553327f9634bb8818a7efa6fb072e81468949d937b5430dcddc7f6c8cf3b305c977dc4b1279d23b38fb034a4c277eecc1adcc0f2b8c99e3e
- data.tar.gz: 01f485a78cd5a19c9a0dc987f4e46931b033ffb3a84af703b602b0e90e8c221c99b5e6318c86b18361562cc3ca582b3a0822de151a5edf63964002317a684bfa
+ metadata.gz: 89f61279aed314c84ae33c14efe266b849bf74a40a9f90beb650358610b479cb44f6cd247ff56aa89d0227ef6358973b5055fd1ce9bcf1e973402b2118bfe75d
+ data.tar.gz: 4eb63bfd9a79bd3683238c186e0e70aba494028f54ce8f971a551f362eec4b4e55fe83c8e738c7669fde2c3671e3360a01b02c4d29f6c965640a6c08b3c4379d
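checksums.yaml records SHA256 and SHA512 digests for the two archives packed inside the .gem (metadata.gz and data.tar.gz). The following is a minimal sketch of checking one entry against a recorded digest with Ruby's standard `digest` library. It assumes you already have the entry's raw bytes, for example from unpacking the .gem tar archive, which is not shown. The helper name `digest_matches?` is hypothetical.

```ruby
require "digest/sha2"

# Hypothetical helper: compare raw bytes against a hex digest as recorded in
# checksums.yaml. algorithm selects the SHA256 or SHA512 section.
def digest_matches?(bytes, expected_hex, algorithm: :SHA256)
  klass =
    case algorithm
    when :SHA256 then Digest::SHA256
    when :SHA512 then Digest::SHA512
    else raise ArgumentError, "unsupported algorithm: #{algorithm}"
    end
  klass.hexdigest(bytes) == expected_hex.downcase
end

# Usage with the well-known SHA-256 test vector for "abc"
p digest_matches?("abc", "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad") # => true
```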
data/CHANGELOG.md CHANGED
@@ -1,3 +1,15 @@
+ ## 0.2.0 (2021-03-11)
+
+ - Added `RandomHorizontalFlip`, `RandomVerticalFlip`, and `Resize` transforms
+ - Added `save_image` method
+ - Added `data` and `targets` methods to datasets
+ - Removed support for Ruby < 2.6
+
+ Breaking changes
+
+ - Added dependency on libvips
+ - MNIST datasets return images instead of tensors
+
  ## 0.1.3 (2020-06-29)
 
  - Added AlexNet model
data/LICENSE.txt CHANGED
@@ -1,7 +1,7 @@
  BSD 3-Clause License
 
- Copyright (c) Andrew Kane 2020,
  Copyright (c) Soumith Chintala 2016,
+ Copyright (c) Andrew Kane 2020-2021,
  All rights reserved.
 
  Redistribution and use in source and binary forms, with or without
data/README.md CHANGED
@@ -2,12 +2,16 @@
 
  :fire: Computer vision datasets, transforms, and models for Ruby
 
- This gem is currently experimental. There may be breaking changes between each release. Please report any issues you experience.
-
- [![Build Status](https://travis-ci.org/ankane/torchvision.svg?branch=master)](https://travis-ci.org/ankane/torchvision)
+ [![Build Status](https://github.com/ankane/torchvision/workflows/build/badge.svg?branch=master)](https://github.com/ankane/torchvision/actions)
 
  ## Installation
 
+ First, [install libvips](libvips-installation). For Homebrew, use:
+
+ ```sh
+ brew install vips
+ ```
+
  Add this line to your application’s Gemfile:
 
  ```ruby
@@ -16,7 +20,12 @@ gem 'torchvision'
 
  ## Getting Started
 
- This library follows the [Python API](https://pytorch.org/docs/master/torchvision/). Many methods and options are missing at the moment. PRs welcome!
+ This library follows the [Python API](https://pytorch.org/docs/stable/torchvision/index.html). Many methods and options are missing at the moment. PRs welcome!
+
+ ## Examples
+
+ - [MNIST](https://github.com/ankane/torch.rb/tree/master/examples/mnist)
+ - [Generative Adversarial Networks](https://github.com/ankane/torch.rb/tree/master/examples/gan)
 
  ## Datasets
 
@@ -43,6 +52,15 @@ TorchVision::Transforms::Compose.new([
  ])
  ```
 
+ Supported transforms are:
+
+ - Compose
+ - Normalize
+ - RandomHorizontalFlip
+ - RandomVerticalFlip
+ - Resize
+ - ToTensor
+
  ## Models
 
  - [AlexNet](#alexnet)
@@ -94,6 +112,40 @@ TorchVision::Models::WideResNet52_2.new
  TorchVision::Models::WideResNet101_2.new
  ```
 
+ ## Pretrained Models
+
+ You can download pretrained models with [this script](pretrained.py)
+
+ ```sh
+ pip install torchvision
+ python pretrained.py
+ ```
+
+ And load them
+
+ ```ruby
+ net = TorchVision::Models::ResNet18.new
+ net.load_state_dict(Torch.load("net.pth"))
+ ```
+
+ ## libvips Installation
+
+ ### Ubuntu
+
+ ```sh
+ sudo apt install libvips
+ ```
+
+ ### Mac
+
+ ```sh
+ brew install vips
+ ```
+
+ ### Windows
+
+ Check out [the options](https://libvips.github.io/libvips/install.html).
+
  ## Disclaimer
 
  This library downloads and prepares public datasets. We don’t host any datasets. Be sure to adhere to the license for each dataset.
data/lib/torchvision.rb CHANGED
@@ -1,5 +1,6 @@
  # dependencies
  require "numo/narray"
+ require "vips"
  require "torch"
 
  # stdlib
@@ -10,6 +11,7 @@ require "rubygems/package"
  require "tmpdir"
 
  # modules
+ require "torchvision/utils"
  require "torchvision/version"
 
  # datasets
@@ -48,6 +50,9 @@ require "torchvision/models/wide_resnet101_2"
  require "torchvision/transforms/compose"
  require "torchvision/transforms/functional"
  require "torchvision/transforms/normalize"
+ require "torchvision/transforms/random_horizontal_flip"
+ require "torchvision/transforms/random_vertical_flip"
+ require "torchvision/transforms/resize"
  require "torchvision/transforms/to_tensor"
 
  module TorchVision
@@ -43,7 +43,8 @@ module TorchVision
  # TODO remove trues when Numo supports it
  img, target = @data[index, true, true, true], @targets[index]
 
- # TODO convert to image
+ img = Utils.image_from_array(img)
+
  img = @transform.call(img) if @transform
 
  target = @target_transform.call(target) if @target_transform
@@ -2,7 +2,6 @@ module TorchVision
  module Datasets
  class MNIST < VisionDataset
  # http://yann.lecun.com/exdb/mnist/
-
  def initialize(root, train: true, download: false, transform: nil, target_transform: nil)
  super(root, transform: transform, target_transform: target_transform)
  @train = train
@@ -24,7 +23,8 @@ module TorchVision
  def [](index)
  img, target = @data[index], @targets[index].item
 
- # TODO convert to image
+ img = Utils.image_from_array(img)
+
  img = @transform.call(img) if @transform
 
  target = @target_transform.call(target) if @target_transform
@@ -2,6 +2,8 @@ module TorchVision
  module Datasets
  # TODO inherit Torch::Utils::Data::Dataset
  class VisionDataset
+ attr_reader :data, :targets
+
  def initialize(root, transforms: nil, transform: nil, target_transform: nil)
  @root = root
 
@@ -1,11 +1,11 @@
  module TorchVision
  module Transforms
- class Compose
+ class Compose < Torch::NN::Module
  def initialize(transforms)
  @transforms = transforms
  end
 
- def call(img)
+ def forward(img)
  @transforms.each do |t|
  img = t.call(img)
  end
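In 0.2.0, Compose, Normalize, ToTensor, and the new transforms subclass Torch::NN::Module and define `forward` instead of `call`, while Compose still invokes each step with `t.call(img)`. This appears to rely on torch-rb's Module#call dispatching to `forward`. Below is a dependency-free sketch of the same pattern. `MiniModule`, `MiniCompose`, `AddOne`, and `Double` are hypothetical stand-ins. The hunk above stops before the end of Compose#forward, so the sketch assumes it returns the final `img`.

```ruby
# Hypothetical stand-in for Torch::NN::Module: `call` forwards to `forward`,
# so subclasses only implement `forward`.
class MiniModule
  def call(*args)
    forward(*args)
  end
end

# Mirrors the Compose transform: apply each step in order, feeding the
# output of one step into the next.
class MiniCompose < MiniModule
  def initialize(transforms)
    @transforms = transforms
  end

  def forward(img)
    @transforms.each do |t|
      img = t.call(img)
    end
    img
  end
end

class AddOne < MiniModule
  def forward(x)
    x + 1
  end
end

class Double < MiniModule
  def forward(x)
    x * 2
  end
end

pipeline = MiniCompose.new([AddOne.new, Double.new])
p pipeline.call(3) # => 8, since (3 + 1) * 2
```

Because each step is only required to respond to `call`, a plain lambda also works as a step in this sketch.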
@@ -32,10 +32,30 @@ module TorchVision
  tensor
  end
 
+ def resize(img, size)
+ raise "img should be Vips::Image. Got #{img.class.name}" unless img.is_a?(Vips::Image)
+ # TODO support array size
+ raise "Got inappropriate size arg: #{size}" unless size.is_a?(Integer)
+
+ w, h = img.size
+ if (w <= h && w == size) || (h <= w && h == size)
+ return img
+ end
+ if w < h
+ ow = size
+ oh = (size * h / w).to_i
+ img.thumbnail_image(ow, height: oh)
+ else
+ oh = size
+ ow = (size * w / h).to_i
+ img.thumbnail_image(ow, height: oh)
+ end
+ end
+
  # TODO improve
  def to_tensor(pic)
- if !pic.is_a?(Numo::NArray) && !pic.is_a?(Torch::Tensor)
- raise ArgumentError, "pic should be tensor or Numo::NArray. Got #{pic.class.name}"
+ if !pic.is_a?(Numo::NArray) && !pic.is_a?(Vips::Image)
+ raise ArgumentError, "pic should be Vips::Image or Numo::NArray. Got #{pic.class.name}"
  end
 
  if pic.is_a?(Numo::NArray) && ![2, 3].include?(pic.ndim)
@@ -44,15 +64,44 @@ module TorchVision
 
  if pic.is_a?(Numo::NArray)
  if pic.ndim == 2
- raise Torch::NotImplementedYet
+ pic = pic.reshape(*pic.shape, 1)
  end
 
  img = Torch.from_numo(pic.transpose(2, 0, 1))
- return img.float.div(255)
+ if img.dtype == :uint8
+ return img.float.div(255)
+ else
+ return img
+ end
+ end
+
+ case pic.format
+ when :uchar
+ img = Torch::ByteTensor.new(Torch::ByteStorage.from_buffer(pic.write_to_memory))
+ else
+ raise Error, "Format not supported yet: #{pic.format}"
  end
 
- pic = pic.float
- pic.unsqueeze!(0).div!(255)
+ img = img.view(pic.height, pic.width, pic.bands)
+ # put it from HWC to CHW format
+ img = img.permute([2, 0, 1]).contiguous
+ img.float.div(255)
+ end
+
+ def hflip(img)
+ if img.is_a?(Torch::Tensor)
+ img.flip(-1)
+ else
+ img.flip(:horizontal)
+ end
+ end
+
+ def vflip(img)
+ if img.is_a?(Torch::Tensor)
+ img.flip(-2)
+ else
+ img.flip(:vertical)
+ end
  end
  end
  end
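For a `:uchar` Vips::Image, the new `to_tensor` reads the raw bytes as height × width × bands (HWC), permutes them to CHW, and scales them to [0, 1] by dividing by 255. Below is a dependency-free sketch of the same reshaping on nested Ruby arrays, which stand in for tensors. The name `hwc_to_chw_scaled` is hypothetical, and the Numo branch is not modeled.

```ruby
# Convert an HWC nested array of 0..255 integers into a CHW nested array of
# floats in [0, 1], mirroring view(h, w, bands) -> permute([2, 0, 1]) -> div(255).
def hwc_to_chw_scaled(hwc)
  height = hwc.length
  width = hwc[0].length
  bands = hwc[0][0].length
  Array.new(bands) do |c|
    Array.new(height) do |y|
      Array.new(width) { |x| hwc[y][x][c] / 255.0 }
    end
  end
end

# One row, two pixels, two bands
img = [[[0, 255], [51, 102]]]
p hwc_to_chw_scaled(img) # => [[[0.0, 0.2]], [[1.0, 0.4]]]
```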
@@ -1,13 +1,13 @@
  module TorchVision
  module Transforms
- class Normalize
+ class Normalize < Torch::NN::Module
  def initialize(mean, std, inplace: false)
  @mean = mean
  @std = std
  @inplace = inplace
  end
 
- def call(tensor)
+ def forward(tensor)
  F.normalize(tensor, @mean, @std, inplace: @inplace)
  end
  end
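Normalize delegates to `F.normalize`, whose body is not part of this diff. Assuming it follows the Python torchvision definition, output[c] = (input[c] - mean[c]) / std[c] for each channel c. Below is a plain-Ruby sketch of that formula on a CHW nested array. The name `normalize_chw` is hypothetical.

```ruby
# Per-channel normalization of a CHW nested array:
# out[c][y][x] = (img[c][y][x] - mean[c]) / std[c]
def normalize_chw(img, mean, std)
  img.each_with_index.map do |channel, c|
    channel.map { |row| row.map { |v| (v - mean[c]) / std[c].to_f } }
  end
end

img = [[[0.5, 1.0]], [[0.0, 0.5]]]
p normalize_chw(img, [0.5, 0.5], [0.5, 0.5]) # => [[[0.0, 1.0]], [[-1.0, 0.0]]]
```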
@@ -0,0 +1,18 @@
+ module TorchVision
+ module Transforms
+ class RandomHorizontalFlip < Torch::NN::Module
+ def initialize(p: 0.5)
+ super()
+ @p = p
+ end
+
+ def forward(img)
+ if Torch.rand(1).item < @p
+ F.hflip(img)
+ else
+ img
+ end
+ end
+ end
+ end
+ end
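RandomHorizontalFlip draws one uniform number and flips the image when that number falls below `p`. For a tensor, `F.hflip` reverses the last (width) dimension with `flip(-1)`. Below is a plain-Ruby sketch on CHW nested arrays, in which Ruby's `Random` stands in for `Torch.rand`. The names `hflip_chw` and `random_hflip` are hypothetical.

```ruby
# Reverse the width axis (last dimension) of a CHW nested array,
# like tensor.flip(-1). Returns a new array; the input is not modified.
def hflip_chw(img)
  img.map { |channel| channel.map(&:reverse) }
end

# Flip with probability p, mirroring `if Torch.rand(1).item < @p`.
def random_hflip(img, p: 0.5, rng: Random.new)
  rng.rand < p ? hflip_chw(img) : img
end

img = [[[1, 2, 3], [4, 5, 6]]] # 1 channel, 2 rows, 3 columns
p random_hflip(img, p: 1.0) # => [[[3, 2, 1], [6, 5, 4]]]
p random_hflip(img, p: 0.0) # => [[[1, 2, 3], [4, 5, 6]]]
```

Passing a seeded `Random.new(seed)` as `rng` makes the sequence of flip decisions reproducible.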
@@ -0,0 +1,18 @@
+ module TorchVision
+ module Transforms
+ class RandomVerticalFlip < Torch::NN::Module
+ def initialize(p: 0.5)
+ super()
+ @p = p
+ end
+
+ def forward(img)
+ if Torch.rand(1).item < @p
+ F.vflip(img)
+ else
+ img
+ end
+ end
+ end
+ end
+ end
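RandomVerticalFlip has the same structure, except that `F.vflip` reverses the height dimension (`flip(-2)` on a CHW tensor). Below is the corresponding sketch on nested arrays. The names `vflip_chw` and `random_vflip` are hypothetical.

```ruby
# Reverse the height axis (second-to-last dimension) of a CHW nested array,
# like tensor.flip(-2): the row order within each channel is reversed.
def vflip_chw(img)
  img.map(&:reverse)
end

# Flip with probability p, mirroring `if Torch.rand(1).item < @p`.
def random_vflip(img, p: 0.5, rng: Random.new)
  rng.rand < p ? vflip_chw(img) : img
end

img = [[[1, 2, 3], [4, 5, 6]]]
p random_vflip(img, p: 1.0) # => [[[4, 5, 6], [1, 2, 3]]]
```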
@@ -0,0 +1,13 @@
+ module TorchVision
+ module Transforms
+ class Resize < Torch::NN::Module
+ def initialize(size)
+ @size = size
+ end
+
+ def forward(img)
+ F.resize(img, @size)
+ end
+ end
+ end
+ end
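`F.resize` (added in functional.rb above) accepts only an Integer size. It scales the shorter side to that size while keeping the aspect ratio, and returns the image unchanged when the shorter side already matches. Below, the target-size arithmetic is isolated as a hypothetical `resize_dims` helper that returns `[width, height]` without calling libvips.

```ruby
# Compute the [width, height] that F.resize would request from
# thumbnail_image for an Integer size (shorter side -> size).
def resize_dims(w, h, size)
  raise ArgumentError, "size must be an Integer" unless size.is_a?(Integer)

  # Shorter side already equals size: the image is returned unchanged
  return [w, h] if (w <= h && w == size) || (h <= w && h == size)

  if w < h
    [size, (size * h / w).to_i] # Integer division floors, like .to_i
  else
    [(size * w / h).to_i, size]
  end
end

p resize_dims(640, 480, 240) # => [320, 240]
p resize_dims(480, 640, 240) # => [240, 320]
```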
@@ -1,7 +1,7 @@
  module TorchVision
  module Transforms
- class ToTensor
- def call(pic)
+ class ToTensor < Torch::NN::Module
+ def forward(pic)
  F.to_tensor(pic)
  end
  end
@@ -0,0 +1,120 @@
+ module TorchVision
+ module Utils
+ class << self
+ def make_grid(tensor, nrow: 8, padding: 2, normalize: false, range: nil, scale_each: false, pad_value: 0)
+ unless Torch.tensor?(tensor) || (tensor.is_a?(Array) && tensor.all? { |t| Torch.tensor?(t) })
+ raise ArgumentError, "tensor or list of tensors expected, got #{tensor.class.name}"
+ end
+
+ # if list of tensors, convert to a 4D mini-batch Tensor
+ if tensor.is_a?(Array)
+ tensor = Torch.stack(tensor, dim: 0)
+ end
+
+ if tensor.dim == 2 # single image H x W
+ tensor = tensor.unsqueeze(0)
+ end
+ if tensor.dim == 3 # single image
+ if tensor.size(0) == 1 # if single-channel, convert to 3-channel
+ tensor = Torch.cat([tensor, tensor, tensor], 0)
+ end
+ tensor = tensor.unsqueeze(0)
+ end
+
+ if tensor.dim == 4 && tensor.size(1) == 1 # single-channel images
+ tensor = Torch.cat([tensor, tensor, tensor], 1)
+ end
+
+ if normalize
+ tensor = tensor.clone # avoid modifying tensor in-place
+ if !range.nil? && !range.is_a?(Array)
+ raise "range has to be an array (min, max) if specified. min and max are numbers"
+ end
+
+ norm_ip = lambda do |img, min, max|
+ img.clamp!(min, max)
+ img.add!(-min).div!(max - min + 1e-5)
+ end
+
+ norm_range = lambda do |t, range|
+ if !range.nil?
+ norm_ip.call(t, range[0], range[1])
+ else
+ norm_ip.call(t, t.min.to_f, t.max.to_f)
+ end
+ end
+
+ if scale_each
+ tensor.each do |t| # loop over mini-batch dimension
+ norm_range.call(t, range)
+ end
+ else
+ norm_range.call(tensor, range)
+ end
+ end
+
+ if tensor.size(0) == 1
+ return tensor.squeeze(0)
+ end
+
+ # make the mini-batch of images into a grid
+ nmaps = tensor.size(0)
+ xmaps = [nrow, nmaps].min
+ ymaps = (nmaps.to_f / xmaps).ceil
+ height, width = (tensor.size(2) + padding), (tensor.size(3) + padding)
+ num_channels = tensor.size(1)
+ grid = tensor.new_full([num_channels, height * ymaps + padding, width * xmaps + padding], pad_value)
+ k = 0
+ ymaps.times do |y|
+ xmaps.times do |x|
+ break if k >= nmaps
+ grid.narrow(1, y * height + padding, height - padding).narrow(2, x * width + padding, width - padding).copy!(tensor[k])
+ k += 1
+ end
+ end
+ grid
+ end
+
+ def save_image(tensor, fp, nrow: 8, padding: 2, normalize: false, range: nil, scale_each: false, pad_value: 0)
+ grid = make_grid(tensor, nrow: nrow, padding: padding, pad_value: pad_value, normalize: normalize, range: range, scale_each: scale_each)
+ # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
+ ndarr = grid.mul(255).add!(0.5).clamp!(0, 255).permute(1, 2, 0).to("cpu", dtype: :uint8)
+ im = image_from_array(ndarr)
+ im.write_to_file(fp)
+ end
+
+ # private
+ # Ruby-specific method
+ # TODO use Numo when bridge available
+ def image_from_array(array)
+ case array
+ when Torch::Tensor
+ # TODO support more dtypes
+ raise "Type not supported yet: #{array.dtype}" unless array.dtype == :uint8
+
+ array = array.contiguous unless array.contiguous?
+
+ width, height = array.shape
+ bands = array.shape[2] || 1
+ data = FFI::Pointer.new(:uint8, array._data_ptr)
+ data.define_singleton_method(:bytesize) do
+ array.numel * array.element_size
+ end
+
+ Vips::Image.new_from_memory(data, width, height, bands, :uchar)
+ when Numo::NArray
+ # TODO support more types
+ raise "Type not supported yet: #{array.class.name}" unless array.is_a?(Numo::UInt8)
+
+ width, height = array.shape
+ bands = array.shape[2] || 1
+ data = array.to_binary
+
+ Vips::Image.new_from_memory(data, width, height, bands, :uchar)
+ else
+ raise "Expected Torch::Tensor or Numo::NArray, not #{array.class.name}"
+ end
+ end
+ end
+ end
+ end
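`make_grid` lays out `nmaps` images in rows of at most `nrow`. Each cell is the image plus `padding`, and image k lands in row k / xmaps, column k % xmaps. `save_image` then converts [0, 1] floats to bytes: it multiplies by 255, adds 0.5, clamps to [0, 255], and casts to uint8, so each value rounds to the nearest integer. Below, the arithmetic of both steps is shown in plain Ruby without tensors. The helper names `grid_layout` and `to_byte` are hypothetical, and the early return for a single image is not modeled.

```ruby
# Grid geometry used by make_grid for nmaps images of size h x w.
# Returns the grid's [height, width] and the top-left offset of each image.
def grid_layout(nmaps, h, w, nrow: 8, padding: 2)
  xmaps = [nrow, nmaps].min
  ymaps = (nmaps.to_f / xmaps).ceil
  cell_h = h + padding
  cell_w = w + padding
  size = [cell_h * ymaps + padding, cell_w * xmaps + padding]
  offsets = Array.new(nmaps) do |k|
    y, x = k.divmod(xmaps) # same order as the nested ymaps/xmaps loops
    [y * cell_h + padding, x * cell_w + padding]
  end
  [size, offsets]
end

# save_image's quantization: scale to [0, 255], add 0.5, clamp, truncate.
def to_byte(v)
  (v * 255 + 0.5).clamp(0, 255).to_i
end

size, offsets = grid_layout(10, 28, 28)
p size       # => [62, 242]
p offsets[9] # => [32, 32]
p [0.0, 0.5, 1.0, -0.1].map { |v| to_byte(v) } # => [0, 128, 255, 0]
```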
@@ -1,3 +1,3 @@
  module TorchVision
- VERSION = "0.1.3"
+ VERSION = "0.2.0"
  end
metadata CHANGED
@@ -1,14 +1,14 @@
  --- !ruby/object:Gem::Specification
  name: torchvision
  version: !ruby/object:Gem::Version
- version: 0.1.3
+ version: 0.2.0
  platform: ruby
  authors:
  - Andrew Kane
- autorequire:
+ autorequire:
  bindir: bin
  cert_chain: []
- date: 2020-06-30 00:00:00.000000000 Z
+ date: 2021-03-11 00:00:00.000000000 Z
  dependencies:
  - !ruby/object:Gem::Dependency
  name: numo-narray
@@ -25,63 +25,35 @@ dependencies:
  - !ruby/object:Gem::Version
  version: '0'
  - !ruby/object:Gem::Dependency
- name: torch-rb
+ name: ruby-vips
  requirement: !ruby/object:Gem::Requirement
  requirements:
  - - ">="
  - !ruby/object:Gem::Version
- version: 0.2.7
+ version: '2.1'
  type: :runtime
  prerelease: false
  version_requirements: !ruby/object:Gem::Requirement
  requirements:
  - - ">="
  - !ruby/object:Gem::Version
- version: 0.2.7
- - !ruby/object:Gem::Dependency
- name: bundler
- requirement: !ruby/object:Gem::Requirement
- requirements:
- - - ">="
- - !ruby/object:Gem::Version
- version: '0'
- type: :development
- prerelease: false
- version_requirements: !ruby/object:Gem::Requirement
- requirements:
- - - ">="
- - !ruby/object:Gem::Version
- version: '0'
+ version: '2.1'
  - !ruby/object:Gem::Dependency
- name: rake
- requirement: !ruby/object:Gem::Requirement
- requirements:
- - - ">="
- - !ruby/object:Gem::Version
- version: '0'
- type: :development
- prerelease: false
- version_requirements: !ruby/object:Gem::Requirement
- requirements:
- - - ">="
- - !ruby/object:Gem::Version
- version: '0'
- - !ruby/object:Gem::Dependency
- name: minitest
+ name: torch-rb
  requirement: !ruby/object:Gem::Requirement
  requirements:
  - - ">="
  - !ruby/object:Gem::Version
- version: '5'
- type: :development
+ version: 0.3.7
+ type: :runtime
  prerelease: false
  version_requirements: !ruby/object:Gem::Requirement
  requirements:
  - - ">="
  - !ruby/object:Gem::Version
- version: '5'
- description:
- email: andrew@chartkick.com
+ version: 0.3.7
+ description:
+ email: andrew@ankane.org
  executables: []
  extensions: []
  extra_rdoc_files: []
@@ -121,13 +93,17 @@ files:
  - lib/torchvision/transforms/compose.rb
  - lib/torchvision/transforms/functional.rb
  - lib/torchvision/transforms/normalize.rb
+ - lib/torchvision/transforms/random_horizontal_flip.rb
+ - lib/torchvision/transforms/random_vertical_flip.rb
+ - lib/torchvision/transforms/resize.rb
  - lib/torchvision/transforms/to_tensor.rb
+ - lib/torchvision/utils.rb
  - lib/torchvision/version.rb
  homepage: https://github.com/ankane/torchvision
  licenses:
  - BSD-3-Clause
  metadata: {}
- post_install_message:
+ post_install_message:
  rdoc_options: []
  require_paths:
  - lib
@@ -135,15 +111,15 @@ required_ruby_version: !ruby/object:Gem::Requirement
  requirements:
  - - ">="
  - !ruby/object:Gem::Version
- version: '2.4'
+ version: '2.6'
  required_rubygems_version: !ruby/object:Gem::Requirement
  requirements:
  - - ">="
  - !ruby/object:Gem::Version
  version: '0'
  requirements: []
- rubygems_version: 3.1.2
- signing_key:
+ rubygems_version: 3.2.3
+ signing_key:
  specification_version: 4
  summary: Computer vision datasets, transforms, and models for Ruby
  test_files: []
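The updated metadata requires Ruby >= 2.6, ruby-vips >= 2.1, and torch-rb >= 0.3.7. Below is a sketch of checking candidate versions against those constraints with RubyGems' own `Gem::Requirement` and `Gem::Version`. These compare version segments numerically rather than as strings, so 0.3.10 satisfies >= 0.3.7. The `REQUIREMENTS` hash and the `satisfies?` helper are illustrative and are not part of the gem.

```ruby
require "rubygems"

# Constraints copied from the 0.2.0 metadata above (illustrative structure).
REQUIREMENTS = {
  "ruby" => Gem::Requirement.new(">= 2.6"),
  "ruby-vips" => Gem::Requirement.new(">= 2.1"),
  "torch-rb" => Gem::Requirement.new(">= 0.3.7")
}.freeze

# True when the given version string satisfies the named constraint.
def satisfies?(name, version)
  REQUIREMENTS.fetch(name).satisfied_by?(Gem::Version.new(version))
end

p satisfies?("ruby", "2.5.8")      # => false
p satisfies?("torch-rb", "0.3.10") # => true
p satisfies?("ruby", RUBY_VERSION) # true on Ruby 2.6 or newer
```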