torchvision 0.1.3 → 0.2.0

checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
- metadata.gz: 44c29605f12dddf8196432223f2137ef9b9bef490996b718f5a9bdbc13dbd33f
- data.tar.gz: 8790200a0ed8f7a275f99327431dc8be99a7578c473ec03411f9411cd10c6c93
+ metadata.gz: 9aa604602112403b7f6738a7bb014812deb528b9a4d480ba89cf7f0c6d01b59c
+ data.tar.gz: 3ffa29d3ff5234040df51113d085a807d94a45c6998f8182db72bf60c25d67e7
  SHA512:
- metadata.gz: 65816ef10f524781553327f9634bb8818a7efa6fb072e81468949d937b5430dcddc7f6c8cf3b305c977dc4b1279d23b38fb034a4c277eecc1adcc0f2b8c99e3e
- data.tar.gz: 01f485a78cd5a19c9a0dc987f4e46931b033ffb3a84af703b602b0e90e8c221c99b5e6318c86b18361562cc3ca582b3a0822de151a5edf63964002317a684bfa
+ metadata.gz: 89f61279aed314c84ae33c14efe266b849bf74a40a9f90beb650358610b479cb44f6cd247ff56aa89d0227ef6358973b5055fd1ce9bcf1e973402b2118bfe75d
+ data.tar.gz: 4eb63bfd9a79bd3683238c186e0e70aba494028f54ce8f971a551f362eec4b4e55fe83c8e738c7669fde2c3671e3360a01b02c4d29f6c965640a6c08b3c4379d
data/CHANGELOG.md CHANGED
@@ -1,3 +1,15 @@
+ ## 0.2.0 (2021-03-11)
+
+ - Added `RandomHorizontalFlip`, `RandomVerticalFlip`, and `Resize` transforms
+ - Added `save_image` method
+ - Added `data` and `targets` methods to datasets
+ - Removed support for Ruby < 2.6
+
+ Breaking changes
+
+ - Added dependency on libvips
+ - MNIST datasets return images instead of tensors
+
  ## 0.1.3 (2020-06-29)
 
  - Added AlexNet model
data/LICENSE.txt CHANGED
@@ -1,7 +1,7 @@
  BSD 3-Clause License
 
- Copyright (c) Andrew Kane 2020,
  Copyright (c) Soumith Chintala 2016,
+ Copyright (c) Andrew Kane 2020-2021,
  All rights reserved.
 
  Redistribution and use in source and binary forms, with or without
data/README.md CHANGED
@@ -2,12 +2,16 @@
 
  :fire: Computer vision datasets, transforms, and models for Ruby
 
- This gem is currently experimental. There may be breaking changes between each release. Please report any issues you experience.
-
- [![Build Status](https://travis-ci.org/ankane/torchvision.svg?branch=master)](https://travis-ci.org/ankane/torchvision)
+ [![Build Status](https://github.com/ankane/torchvision/workflows/build/badge.svg?branch=master)](https://github.com/ankane/torchvision/actions)
 
  ## Installation
 
+ First, [install libvips](#libvips-installation). For Homebrew, use:
+
+ ```sh
+ brew install vips
+ ```
+
  Add this line to your application’s Gemfile:
 
  ```ruby
@@ -16,7 +20,12 @@ gem 'torchvision'
 
  ## Getting Started
 
- This library follows the [Python API](https://pytorch.org/docs/master/torchvision/). Many methods and options are missing at the moment. PRs welcome!
+ This library follows the [Python API](https://pytorch.org/docs/stable/torchvision/index.html). Many methods and options are missing at the moment. PRs welcome!
+
+ ## Examples
+
+ - [MNIST](https://github.com/ankane/torch.rb/tree/master/examples/mnist)
+ - [Generative Adversarial Networks](https://github.com/ankane/torch.rb/tree/master/examples/gan)
 
  ## Datasets
 
@@ -43,6 +52,15 @@ TorchVision::Transforms::Compose.new([
  ])
  ```
 
+ Supported transforms are:
+
+ - Compose
+ - Normalize
+ - RandomHorizontalFlip
+ - RandomVerticalFlip
+ - Resize
+ - ToTensor
+
  ## Models
 
  - [AlexNet](#alexnet)
@@ -94,6 +112,40 @@ TorchVision::Models::WideResNet52_2.new
  TorchVision::Models::WideResNet101_2.new
  ```
 
+ ## Pretrained Models
+
+ You can download pretrained models with [this script](pretrained.py):
+
+ ```sh
+ pip install torchvision
+ python pretrained.py
+ ```
+
+ And load them:
+
+ ```ruby
+ net = TorchVision::Models::ResNet18.new
+ net.load_state_dict(Torch.load("net.pth"))
+ ```
+
+ ## libvips Installation
+
+ ### Ubuntu
+
+ ```sh
+ sudo apt install libvips
+ ```
+
+ ### Mac
+
+ ```sh
+ brew install vips
+ ```
+
+ ### Windows
+
+ Check out [the options](https://libvips.github.io/libvips/install.html).
+
  ## Disclaimer
 
  This library downloads and prepares public datasets. We don’t host any datasets. Be sure to adhere to the license for each dataset.
data/lib/torchvision.rb CHANGED
@@ -1,5 +1,6 @@
  # dependencies
  require "numo/narray"
+ require "vips"
  require "torch"
 
  # stdlib
@@ -10,6 +11,7 @@ require "rubygems/package"
  require "tmpdir"
 
  # modules
+ require "torchvision/utils"
  require "torchvision/version"
 
  # datasets
@@ -48,6 +50,9 @@ require "torchvision/models/wide_resnet101_2"
  require "torchvision/transforms/compose"
  require "torchvision/transforms/functional"
  require "torchvision/transforms/normalize"
+ require "torchvision/transforms/random_horizontal_flip"
+ require "torchvision/transforms/random_vertical_flip"
+ require "torchvision/transforms/resize"
  require "torchvision/transforms/to_tensor"
 
  module TorchVision
@@ -43,7 +43,8 @@ module TorchVision
  # TODO remove trues when Numo supports it
  img, target = @data[index, true, true, true], @targets[index]
 
- # TODO convert to image
+ img = Utils.image_from_array(img)
+
  img = @transform.call(img) if @transform
 
  target = @target_transform.call(target) if @target_transform
@@ -2,7 +2,6 @@ module TorchVision
  module Datasets
  class MNIST < VisionDataset
  # http://yann.lecun.com/exdb/mnist/
-
  def initialize(root, train: true, download: false, transform: nil, target_transform: nil)
  super(root, transform: transform, target_transform: target_transform)
  @train = train
@@ -24,7 +23,8 @@ module TorchVision
  def [](index)
  img, target = @data[index], @targets[index].item
 
- # TODO convert to image
+ img = Utils.image_from_array(img)
+
  img = @transform.call(img) if @transform
 
  target = @target_transform.call(target) if @target_transform
@@ -2,6 +2,8 @@ module TorchVision
  module Datasets
  # TODO inherit Torch::Utils::Data::Dataset
  class VisionDataset
+ attr_reader :data, :targets
+
  def initialize(root, transforms: nil, transform: nil, target_transform: nil)
  @root = root
 
@@ -1,11 +1,11 @@
  module TorchVision
  module Transforms
- class Compose
+ class Compose < Torch::NN::Module
  def initialize(transforms)
  @transforms = transforms
  end
 
- def call(img)
+ def forward(img)
  @transforms.each do |t|
  img = t.call(img)
  end
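
The diff above turns `Compose` into a `Torch::NN::Module` whose `forward` threads the image through each transform via `call`. The chaining itself is plain Ruby and can be sketched without the gem; the lambdas below are hypothetical stand-ins for real transforms:

```ruby
# Minimal sketch of Compose's chaining, assuming each transform
# responds to #call (as torch-rb module instances do).
class MiniCompose
  def initialize(transforms)
    @transforms = transforms
  end

  # Thread the input through every transform, in order
  def call(x)
    @transforms.each { |t| x = t.call(x) }
    x
  end
end

double = ->(x) { x * 2 }   # stand-in transform
add_one = ->(x) { x + 1 }  # stand-in transform
pipeline = MiniCompose.new([double, add_one])
pipeline.call(3) # => 7
```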
@@ -32,10 +32,30 @@ module TorchVision
  tensor
  end
 
+ def resize(img, size)
+ raise "img should be Vips::Image. Got #{img.class.name}" unless img.is_a?(Vips::Image)
+ # TODO support array size
+ raise "Got inappropriate size arg: #{size}" unless size.is_a?(Integer)
+
+ w, h = img.size
+ if (w <= h && w == size) || (h <= w && h == size)
+ return img
+ end
+ if w < h
+ ow = size
+ oh = (size * h / w).to_i
+ img.thumbnail_image(ow, height: oh)
+ else
+ oh = size
+ ow = (size * w / h).to_i
+ img.thumbnail_image(ow, height: oh)
+ end
+ end
+
  # TODO improve
  def to_tensor(pic)
- if !pic.is_a?(Numo::NArray) && !pic.is_a?(Torch::Tensor)
- raise ArgumentError, "pic should be tensor or Numo::NArray. Got #{pic.class.name}"
+ if !pic.is_a?(Numo::NArray) && !pic.is_a?(Vips::Image)
+ raise ArgumentError, "pic should be Vips::Image or Numo::NArray. Got #{pic.class.name}"
  end
 
  if pic.is_a?(Numo::NArray) && ![2, 3].include?(pic.ndim)
@@ -44,15 +64,44 @@ module TorchVision
 
  if pic.is_a?(Numo::NArray)
  if pic.ndim == 2
- raise Torch::NotImplementedYet
+ pic = pic.reshape(*pic.shape, 1)
  end
 
  img = Torch.from_numo(pic.transpose(2, 0, 1))
- return img.float.div(255)
+ if img.dtype == :uint8
+ return img.float.div(255)
+ else
+ return img
+ end
+ end
+
+ case pic.format
+ when :uchar
+ img = Torch::ByteTensor.new(Torch::ByteStorage.from_buffer(pic.write_to_memory))
+ else
+ raise Error, "Format not supported yet: #{pic.format}"
  end
 
- pic = pic.float
- pic.unsqueeze!(0).div!(255)
+ img = img.view(pic.height, pic.width, pic.bands)
+ # put it from HWC to CHW format
+ img = img.permute([2, 0, 1]).contiguous
+ img.float.div(255)
+ end
+
+ def hflip(img)
+ if img.is_a?(Torch::Tensor)
+ img.flip(-1)
+ else
+ img.flip(:horizontal)
+ end
+ end
+
+ def vflip(img)
+ if img.is_a?(Torch::Tensor)
+ img.flip(-2)
+ else
+ img.flip(:vertical)
+ end
  end
  end
  end
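
The new `resize` scales the shorter side to `size` and the longer side proportionally, returning the image unchanged when the shorter side already matches. That size arithmetic is independent of libvips and can be checked in plain Ruby; this is a sketch, not the gem's code:

```ruby
# Compute the output dimensions resize would request, mirroring the
# integer arithmetic in Transforms::Functional#resize above.
def resize_dims(w, h, size)
  # shorter side already matches: no-op
  return [w, h] if (w <= h && w == size) || (h <= w && h == size)
  if w < h
    [size, (size * h / w).to_i]   # width is the shorter side
  else
    [(size * w / h).to_i, size]   # height is the shorter side
  end
end

resize_dims(400, 200, 100) # => [200, 100] (aspect ratio preserved)
resize_dims(100, 300, 100) # => [100, 300] (shorter side already 100)
```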
@@ -1,13 +1,13 @@
  module TorchVision
  module Transforms
- class Normalize
+ class Normalize < Torch::NN::Module
  def initialize(mean, std, inplace: false)
  @mean = mean
  @std = std
  @inplace = inplace
  end
 
- def call(tensor)
+ def forward(tensor)
  F.normalize(tensor, @mean, @std, inplace: @inplace)
  end
  end
@@ -0,0 +1,18 @@
+ module TorchVision
+ module Transforms
+ class RandomHorizontalFlip < Torch::NN::Module
+ def initialize(p: 0.5)
+ super()
+ @p = p
+ end
+
+ def forward(img)
+ if Torch.rand(1).item < @p
+ F.hflip(img)
+ else
+ img
+ end
+ end
+ end
+ end
+ end
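
`RandomHorizontalFlip` applies `F.hflip` with probability `p` by comparing a uniform sample against the threshold. The gating logic can be sketched with `Kernel#rand` in place of `Torch.rand`, using a hypothetical `flip` lambda as the transform:

```ruby
# Apply a transform with probability p; mirrors the forward method
# above, but with Kernel#rand instead of Torch.rand.
def maybe_apply(img, p, transform)
  rand < p ? transform.call(img) : img
end

flip = ->(row) { row.reverse } # stand-in for F.hflip on one pixel row
maybe_apply([1, 2, 3], 1.0, flip) # => [3, 2, 1] (p = 1 always flips)
maybe_apply([1, 2, 3], 0.0, flip) # => [1, 2, 3] (p = 0 never flips)
```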
@@ -0,0 +1,18 @@
+ module TorchVision
+ module Transforms
+ class RandomVerticalFlip < Torch::NN::Module
+ def initialize(p: 0.5)
+ super()
+ @p = p
+ end
+
+ def forward(img)
+ if Torch.rand(1).item < @p
+ F.vflip(img)
+ else
+ img
+ end
+ end
+ end
+ end
+ end
@@ -0,0 +1,13 @@
+ module TorchVision
+ module Transforms
+ class Resize < Torch::NN::Module
+ def initialize(size)
+ @size = size
+ end
+
+ def forward(img)
+ F.resize(img, @size)
+ end
+ end
+ end
+ end
@@ -1,7 +1,7 @@
  module TorchVision
  module Transforms
- class ToTensor
- def call(pic)
+ class ToTensor < Torch::NN::Module
+ def forward(pic)
  F.to_tensor(pic)
  end
  end
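
`ToTensor` delegates to `F.to_tensor`, which converts an H×W×C byte image into a C×H×W float tensor scaled to [0, 1]. The permutation and scaling can be sketched with nested arrays standing in for tensors (an illustration only, not the gem's implementation):

```ruby
# Sketch of the HWC -> CHW permutation and 0..255 -> 0.0..1.0 scaling
# performed by to_tensor, using nested arrays instead of tensors.
def to_chw_float(hwc)
  height = hwc.length
  width = hwc[0].length
  channels = hwc[0][0].length
  Array.new(channels) do |c|
    Array.new(height) do |y|
      Array.new(width) { |x| hwc[y][x][c] / 255.0 }
    end
  end
end

hwc = [[[255, 0, 0]]] # one red pixel: H=1, W=1, C=3
to_chw_float(hwc) # => [[[1.0]], [[0.0]], [[0.0]]]
```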
@@ -0,0 +1,120 @@
+ module TorchVision
+ module Utils
+ class << self
+ def make_grid(tensor, nrow: 8, padding: 2, normalize: false, range: nil, scale_each: false, pad_value: 0)
+ unless Torch.tensor?(tensor) || (tensor.is_a?(Array) && tensor.all? { |t| Torch.tensor?(t) })
+ raise ArgumentError, "tensor or list of tensors expected, got #{tensor.class.name}"
+ end
+
+ # if list of tensors, convert to a 4D mini-batch Tensor
+ if tensor.is_a?(Array)
+ tensor = Torch.stack(tensor, dim: 0)
+ end
+
+ if tensor.dim == 2 # single image H x W
+ tensor = tensor.unsqueeze(0)
+ end
+ if tensor.dim == 3 # single image
+ if tensor.size(0) == 1 # if single-channel, convert to 3-channel
+ tensor = Torch.cat([tensor, tensor, tensor], 0)
+ end
+ tensor = tensor.unsqueeze(0)
+ end
+
+ if tensor.dim == 4 && tensor.size(1) == 1 # single-channel images
+ tensor = Torch.cat([tensor, tensor, tensor], 1)
+ end
+
+ if normalize
+ tensor = tensor.clone # avoid modifying tensor in-place
+ if !range.nil? && !range.is_a?(Array)
+ raise "range has to be an array (min, max) if specified. min and max are numbers"
+ end
+
+ norm_ip = lambda do |img, min, max|
+ img.clamp!(min, max)
+ img.add!(-min).div!(max - min + 1e-5)
+ end
+
+ norm_range = lambda do |t, range|
+ if !range.nil?
+ norm_ip.call(t, range[0], range[1])
+ else
+ norm_ip.call(t, t.min.to_f, t.max.to_f)
+ end
+ end
+
+ if scale_each
+ tensor.each do |t| # loop over mini-batch dimension
+ norm_range.call(t, range)
+ end
+ else
+ norm_range.call(tensor, range)
+ end
+ end
+
+ if tensor.size(0) == 1
+ return tensor.squeeze(0)
+ end
+
+ # make the mini-batch of images into a grid
+ nmaps = tensor.size(0)
+ xmaps = [nrow, nmaps].min
+ ymaps = (nmaps.to_f / xmaps).ceil
+ height, width = (tensor.size(2) + padding), (tensor.size(3) + padding)
+ num_channels = tensor.size(1)
+ grid = tensor.new_full([num_channels, height * ymaps + padding, width * xmaps + padding], pad_value)
+ k = 0
+ ymaps.times do |y|
+ xmaps.times do |x|
+ break if k >= nmaps
+ grid.narrow(1, y * height + padding, height - padding).narrow(2, x * width + padding, width - padding).copy!(tensor[k])
+ k += 1
+ end
+ end
+ grid
+ end
+
+ def save_image(tensor, fp, nrow: 8, padding: 2, normalize: false, range: nil, scale_each: false, pad_value: 0)
+ grid = make_grid(tensor, nrow: nrow, padding: padding, pad_value: pad_value, normalize: normalize, range: range, scale_each: scale_each)
+ # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
+ ndarr = grid.mul(255).add!(0.5).clamp!(0, 255).permute(1, 2, 0).to("cpu", dtype: :uint8)
+ im = image_from_array(ndarr)
+ im.write_to_file(fp)
+ end
+
+ # private
+ # Ruby-specific method
+ # TODO use Numo when bridge available
+ def image_from_array(array)
+ case array
+ when Torch::Tensor
+ # TODO support more dtypes
+ raise "Type not supported yet: #{array.dtype}" unless array.dtype == :uint8
+
+ array = array.contiguous unless array.contiguous?
+
+ width, height = array.shape
+ bands = array.shape[2] || 1
+ data = FFI::Pointer.new(:uint8, array._data_ptr)
+ data.define_singleton_method(:bytesize) do
+ array.numel * array.element_size
+ end
+
+ Vips::Image.new_from_memory(data, width, height, bands, :uchar)
+ when Numo::NArray
+ # TODO support more types
+ raise "Type not supported yet: #{array.class.name}" unless array.is_a?(Numo::UInt8)
+
+ width, height = array.shape
+ bands = array.shape[2] || 1
+ data = array.to_binary
+
+ Vips::Image.new_from_memory(data, width, height, bands, :uchar)
+ else
+ raise "Expected Torch::Tensor or Numo::NArray, not #{array.class.name}"
+ end
+ end
+ end
+ end
+ end
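
`make_grid` arranges N images into at most `nrow` columns, with `padding` pixels around each cell. The grid geometry follows directly from that arithmetic and can be checked in plain Ruby (a sketch of the dimension computation only, not the tensor copying):

```ruby
# Grid geometry used by make_grid: xmaps columns, ymaps rows, plus
# padding between and around cells, matching the diff above.
def grid_shape(nmaps, img_h, img_w, nrow: 8, padding: 2)
  xmaps = [nrow, nmaps].min             # columns
  ymaps = (nmaps.to_f / xmaps).ceil     # rows
  height = img_h + padding              # cell height incl. padding
  width = img_w + padding               # cell width incl. padding
  [height * ymaps + padding, width * xmaps + padding]
end

# 10 MNIST-sized (28x28) images: 8 columns x 2 rows
grid_shape(10, 28, 28) # => [62, 242]
```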
@@ -1,3 +1,3 @@
  module TorchVision
- VERSION = "0.1.3"
+ VERSION = "0.2.0"
  end
metadata CHANGED
@@ -1,14 +1,14 @@
  --- !ruby/object:Gem::Specification
  name: torchvision
  version: !ruby/object:Gem::Version
- version: 0.1.3
+ version: 0.2.0
  platform: ruby
  authors:
  - Andrew Kane
- autorequire:
+ autorequire:
  bindir: bin
  cert_chain: []
- date: 2020-06-30 00:00:00.000000000 Z
+ date: 2021-03-11 00:00:00.000000000 Z
  dependencies:
  - !ruby/object:Gem::Dependency
  name: numo-narray
@@ -25,63 +25,35 @@ dependencies:
  - !ruby/object:Gem::Version
  version: '0'
  - !ruby/object:Gem::Dependency
- name: torch-rb
+ name: ruby-vips
  requirement: !ruby/object:Gem::Requirement
  requirements:
  - - ">="
  - !ruby/object:Gem::Version
- version: 0.2.7
+ version: '2.1'
  type: :runtime
  prerelease: false
  version_requirements: !ruby/object:Gem::Requirement
  requirements:
  - - ">="
  - !ruby/object:Gem::Version
- version: 0.2.7
- - !ruby/object:Gem::Dependency
- name: bundler
- requirement: !ruby/object:Gem::Requirement
- requirements:
- - - ">="
- - !ruby/object:Gem::Version
- version: '0'
- type: :development
- prerelease: false
- version_requirements: !ruby/object:Gem::Requirement
- requirements:
- - - ">="
- - !ruby/object:Gem::Version
- version: '0'
+ version: '2.1'
  - !ruby/object:Gem::Dependency
- name: rake
- requirement: !ruby/object:Gem::Requirement
- requirements:
- - - ">="
- - !ruby/object:Gem::Version
- version: '0'
- type: :development
- prerelease: false
- version_requirements: !ruby/object:Gem::Requirement
- requirements:
- - - ">="
- - !ruby/object:Gem::Version
- version: '0'
- - !ruby/object:Gem::Dependency
- name: minitest
+ name: torch-rb
  requirement: !ruby/object:Gem::Requirement
  requirements:
  - - ">="
  - !ruby/object:Gem::Version
- version: '5'
- type: :development
+ version: 0.3.7
+ type: :runtime
  prerelease: false
  version_requirements: !ruby/object:Gem::Requirement
  requirements:
  - - ">="
  - !ruby/object:Gem::Version
- version: '5'
- description:
- email: andrew@chartkick.com
+ version: 0.3.7
+ description:
+ email: andrew@ankane.org
  executables: []
  extensions: []
  extra_rdoc_files: []
@@ -121,13 +93,17 @@ files:
  - lib/torchvision/transforms/compose.rb
  - lib/torchvision/transforms/functional.rb
  - lib/torchvision/transforms/normalize.rb
+ - lib/torchvision/transforms/random_horizontal_flip.rb
+ - lib/torchvision/transforms/random_vertical_flip.rb
+ - lib/torchvision/transforms/resize.rb
  - lib/torchvision/transforms/to_tensor.rb
+ - lib/torchvision/utils.rb
  - lib/torchvision/version.rb
  homepage: https://github.com/ankane/torchvision
  licenses:
  - BSD-3-Clause
  metadata: {}
- post_install_message:
+ post_install_message:
  rdoc_options: []
  require_paths:
  - lib
@@ -135,15 +111,15 @@ required_ruby_version: !ruby/object:Gem::Requirement
  requirements:
  - - ">="
  - !ruby/object:Gem::Version
- version: '2.4'
+ version: '2.6'
  required_rubygems_version: !ruby/object:Gem::Requirement
  requirements:
  - - ">="
  - !ruby/object:Gem::Version
  version: '0'
  requirements: []
- rubygems_version: 3.1.2
- signing_key:
+ rubygems_version: 3.2.3
+ signing_key:
  specification_version: 4
  summary: Computer vision datasets, transforms, and models for Ruby
  test_files: []