torchvision 0.1.1 → 0.2.2
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +35 -0
- data/LICENSE.txt +1 -1
- data/README.md +133 -7
- data/lib/torchvision.rb +39 -0
- data/lib/torchvision/datasets/cifar10.rb +117 -0
- data/lib/torchvision/datasets/cifar100.rb +41 -0
- data/lib/torchvision/datasets/dataset_folder.rb +91 -0
- data/lib/torchvision/datasets/fashion_mnist.rb +30 -0
- data/lib/torchvision/datasets/image_folder.rb +12 -0
- data/lib/torchvision/datasets/kmnist.rb +30 -0
- data/lib/torchvision/datasets/mnist.rb +47 -75
- data/lib/torchvision/datasets/vision_dataset.rb +67 -0
- data/lib/torchvision/models/alexnet.rb +42 -0
- data/lib/torchvision/models/basic_block.rb +46 -0
- data/lib/torchvision/models/bottleneck.rb +47 -0
- data/lib/torchvision/models/resnet.rb +129 -0
- data/lib/torchvision/models/resnet101.rb +9 -0
- data/lib/torchvision/models/resnet152.rb +9 -0
- data/lib/torchvision/models/resnet18.rb +9 -0
- data/lib/torchvision/models/resnet34.rb +9 -0
- data/lib/torchvision/models/resnet50.rb +9 -0
- data/lib/torchvision/models/resnext101_32x8d.rb +11 -0
- data/lib/torchvision/models/resnext50_32x4d.rb +11 -0
- data/lib/torchvision/models/vgg.rb +93 -0
- data/lib/torchvision/models/vgg11.rb +9 -0
- data/lib/torchvision/models/vgg11_bn.rb +9 -0
- data/lib/torchvision/models/vgg13.rb +9 -0
- data/lib/torchvision/models/vgg13_bn.rb +9 -0
- data/lib/torchvision/models/vgg16.rb +9 -0
- data/lib/torchvision/models/vgg16_bn.rb +9 -0
- data/lib/torchvision/models/vgg19.rb +9 -0
- data/lib/torchvision/models/vgg19_bn.rb +9 -0
- data/lib/torchvision/models/wide_resnet101_2.rb +10 -0
- data/lib/torchvision/models/wide_resnet50_2.rb +10 -0
- data/lib/torchvision/transforms/center_crop.rb +13 -0
- data/lib/torchvision/transforms/compose.rb +2 -2
- data/lib/torchvision/transforms/functional.rb +142 -8
- data/lib/torchvision/transforms/normalize.rb +2 -2
- data/lib/torchvision/transforms/random_horizontal_flip.rb +18 -0
- data/lib/torchvision/transforms/random_resized_crop.rb +70 -0
- data/lib/torchvision/transforms/random_vertical_flip.rb +18 -0
- data/lib/torchvision/transforms/resize.rb +13 -0
- data/lib/torchvision/transforms/to_tensor.rb +2 -2
- data/lib/torchvision/utils.rb +118 -0
- data/lib/torchvision/version.rb +1 -1
- metadata +51 -44
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: d218068a8502ca9aa41ec3240043c55122c22e10944fb5d0fb402e95ccb48e99
|
4
|
+
data.tar.gz: 3d9288e590d4d9a570c7f6c501ffbd30ef968c58790e2beec42721672b34075f
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: c841175a05fdace23e273413c863dcb2135fcab66414a6600864e16e12375c65d37a62fa2f03a433852f946324f40192ca5656d1f3e87b72ec0207e1f92e585f
|
7
|
+
data.tar.gz: a9e98d5929bed556b59ef29ad690668dac7bfe7b5f46f87e92d9e8dab352793ee659dcd8894343264c60b75bfd0cc07177db00bb0a5828410d879ebeb22bf860
|
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,38 @@
|
|
1
|
+
## 0.2.2 (2021-05-23)
|
2
|
+
|
3
|
+
- Fixed error with ruby-vips 2.1.2
|
4
|
+
|
5
|
+
## 0.2.1 (2021-03-14)
|
6
|
+
|
7
|
+
- Added `ImageFolder` and `DatasetFolder`
|
8
|
+
- Added `CenterCrop` and `RandomResizedCrop` transforms
|
9
|
+
- Added `crop` method
|
10
|
+
|
11
|
+
## 0.2.0 (2021-03-11)
|
12
|
+
|
13
|
+
- Added `RandomHorizontalFlip`, `RandomVerticalFlip`, and `Resize` transforms
|
14
|
+
- Added `save_image` method
|
15
|
+
- Added `data` and `targets` methods to datasets
|
16
|
+
- Removed support for Ruby < 2.6
|
17
|
+
|
18
|
+
Breaking changes
|
19
|
+
|
20
|
+
- Added dependency on libvips
|
21
|
+
- MNIST datasets return images instead of tensors
|
22
|
+
|
23
|
+
## 0.1.3 (2020-06-29)
|
24
|
+
|
25
|
+
- Added AlexNet model
|
26
|
+
- Added ResNet34, ResNet50, ResNet101, and ResNet152 models
|
27
|
+
- Added ResNeXt model
|
28
|
+
- Added VGG11, VGG13, VGG16, and VGG19 models
|
29
|
+
- Added Wide ResNet model
|
30
|
+
|
31
|
+
## 0.1.2 (2020-04-29)
|
32
|
+
|
33
|
+
- Added CIFAR10, CIFAR100, FashionMNIST, and KMNIST datasets
|
34
|
+
- Added ResNet18 model
|
35
|
+
|
1
36
|
## 0.1.1 (2020-04-28)
|
2
37
|
|
3
38
|
- Removed `mini_magick` for performance
|
data/LICENSE.txt
CHANGED
data/README.md
CHANGED
@@ -2,12 +2,16 @@
|
|
2
2
|
|
3
3
|
:fire: Computer vision datasets, transforms, and models for Ruby
|
4
4
|
|
5
|
-
|
6
|
-
|
7
|
-
[![Build Status](https://travis-ci.org/ankane/torchvision.svg?branch=master)](https://travis-ci.org/ankane/torchvision)
|
5
|
+
[![Build Status](https://github.com/ankane/torchvision/workflows/build/badge.svg?branch=master)](https://github.com/ankane/torchvision/actions)
|
8
6
|
|
9
7
|
## Installation
|
10
8
|
|
9
|
+
First, [install libvips](#libvips-installation). For Homebrew, use:
|
10
|
+
|
11
|
+
```sh
|
12
|
+
brew install vips
|
13
|
+
```
|
14
|
+
|
11
15
|
Add this line to your application’s Gemfile:
|
12
16
|
|
13
17
|
```ruby
|
@@ -16,17 +20,139 @@ gem 'torchvision'
|
|
16
20
|
|
17
21
|
## Getting Started
|
18
22
|
|
19
|
-
This library follows the [Python API](https://pytorch.org/docs/
|
23
|
+
This library follows the [Python API](https://pytorch.org/docs/stable/torchvision/index.html). Many methods and options are missing at the moment. PRs welcome!
|
24
|
+
|
25
|
+
## Examples
|
26
|
+
|
27
|
+
- [MNIST](https://github.com/ankane/torch.rb/tree/master/examples/mnist)
|
28
|
+
- [Transfer learning](https://github.com/ankane/torch.rb/tree/master/examples/transfer-learning)
|
29
|
+
- [Generative adversarial networks](https://github.com/ankane/torch.rb/tree/master/examples/gan)
|
20
30
|
|
21
31
|
## Datasets
|
22
32
|
|
23
|
-
|
33
|
+
Load a dataset
|
34
|
+
|
35
|
+
```ruby
|
36
|
+
TorchVision::Datasets::MNIST.new("./data", train: true, download: true)
|
37
|
+
```
|
38
|
+
|
39
|
+
Supported datasets are:
|
40
|
+
|
41
|
+
- CIFAR10
|
42
|
+
- CIFAR100
|
43
|
+
- FashionMNIST
|
44
|
+
- KMNIST
|
45
|
+
- MNIST
|
46
|
+
|
47
|
+
## Transforms
|
48
|
+
|
49
|
+
```ruby
|
50
|
+
TorchVision::Transforms::Compose.new([
|
51
|
+
TorchVision::Transforms::ToTensor.new,
|
52
|
+
TorchVision::Transforms::Normalize.new([0.1307], [0.3081])
|
53
|
+
])
|
54
|
+
```
|
55
|
+
|
56
|
+
Supported transforms are:
|
57
|
+
|
58
|
+
- CenterCrop
|
59
|
+
- Compose
|
60
|
+
- Normalize
|
61
|
+
- RandomHorizontalFlip
|
62
|
+
- RandomResizedCrop
|
63
|
+
- RandomVerticalFlip
|
64
|
+
- Resize
|
65
|
+
- ToTensor
|
66
|
+
|
67
|
+
## Models
|
68
|
+
|
69
|
+
- [AlexNet](#alexnet)
|
70
|
+
- [ResNet](#resnet)
|
71
|
+
- [ResNeXt](#resnext)
|
72
|
+
- [VGG](#vgg)
|
73
|
+
- [Wide ResNet](#wide-resnet)
|
74
|
+
|
75
|
+
### AlexNet
|
24
76
|
|
25
77
|
```ruby
|
26
|
-
|
27
|
-
trainset.size
|
78
|
+
TorchVision::Models::AlexNet.new
|
28
79
|
```
|
29
80
|
|
81
|
+
### ResNet
|
82
|
+
|
83
|
+
```ruby
|
84
|
+
TorchVision::Models::ResNet18.new
|
85
|
+
TorchVision::Models::ResNet34.new
|
86
|
+
TorchVision::Models::ResNet50.new
|
87
|
+
TorchVision::Models::ResNet101.new
|
88
|
+
TorchVision::Models::ResNet152.new
|
89
|
+
```
|
90
|
+
|
91
|
+
### ResNeXt
|
92
|
+
|
93
|
+
```ruby
|
94
|
+
TorchVision::Models::ResNext50_32x4d.new
|
95
|
+
TorchVision::Models::ResNext101_32x8d.new
|
96
|
+
```
|
97
|
+
|
98
|
+
### VGG
|
99
|
+
|
100
|
+
```ruby
|
101
|
+
TorchVision::Models::VGG11.new
|
102
|
+
TorchVision::Models::VGG11BN.new
|
103
|
+
TorchVision::Models::VGG13.new
|
104
|
+
TorchVision::Models::VGG13BN.new
|
105
|
+
TorchVision::Models::VGG16.new
|
106
|
+
TorchVision::Models::VGG16BN.new
|
107
|
+
TorchVision::Models::VGG19.new
|
108
|
+
TorchVision::Models::VGG19BN.new
|
109
|
+
```
|
110
|
+
|
111
|
+
### Wide ResNet
|
112
|
+
|
113
|
+
```ruby
|
114
|
+
TorchVision::Models::WideResNet50_2.new
|
115
|
+
TorchVision::Models::WideResNet101_2.new
|
116
|
+
```
|
117
|
+
|
118
|
+
## Pretrained Models
|
119
|
+
|
120
|
+
You can download pretrained models with [this script](pretrained.py)
|
121
|
+
|
122
|
+
```sh
|
123
|
+
pip install torchvision
|
124
|
+
python pretrained.py
|
125
|
+
```
|
126
|
+
|
127
|
+
And load them
|
128
|
+
|
129
|
+
```ruby
|
130
|
+
net = TorchVision::Models::ResNet18.new
|
131
|
+
net.load_state_dict(Torch.load("net.pth"))
|
132
|
+
```
|
133
|
+
|
134
|
+
## libvips Installation
|
135
|
+
|
136
|
+
### Linux
|
137
|
+
|
138
|
+
Check your package manager. For Ubuntu, use:
|
139
|
+
|
140
|
+
```sh
|
141
|
+
sudo apt install libvips
|
142
|
+
```
|
143
|
+
|
144
|
+
You can also [build from source](https://libvips.github.io/libvips/install.html).
|
145
|
+
|
146
|
+
### Mac
|
147
|
+
|
148
|
+
```sh
|
149
|
+
brew install vips
|
150
|
+
```
|
151
|
+
|
152
|
+
### Windows
|
153
|
+
|
154
|
+
Check out [the options](https://libvips.github.io/libvips/install.html).
|
155
|
+
|
30
156
|
## Disclaimer
|
31
157
|
|
32
158
|
This library downloads and prepares public datasets. We don’t host any datasets. Be sure to adhere to the license for each dataset.
|
data/lib/torchvision.rb
CHANGED
@@ -1,23 +1,62 @@
|
|
1
1
|
# dependencies
|
2
2
|
require "numo/narray"
|
3
|
+
require "vips"
|
3
4
|
require "torch"
|
4
5
|
|
5
6
|
# stdlib
|
6
7
|
require "digest"
|
7
8
|
require "fileutils"
|
8
9
|
require "net/http"
|
10
|
+
require "rubygems/package"
|
9
11
|
require "tmpdir"
|
10
12
|
|
11
13
|
# modules
|
14
|
+
require "torchvision/utils"
|
12
15
|
require "torchvision/version"
|
13
16
|
|
14
17
|
# datasets
|
18
|
+
require "torchvision/datasets/vision_dataset"
|
19
|
+
require "torchvision/datasets/dataset_folder"
|
20
|
+
require "torchvision/datasets/image_folder"
|
21
|
+
require "torchvision/datasets/cifar10"
|
22
|
+
require "torchvision/datasets/cifar100"
|
15
23
|
require "torchvision/datasets/mnist"
|
24
|
+
require "torchvision/datasets/fashion_mnist"
|
25
|
+
require "torchvision/datasets/kmnist"
|
26
|
+
|
27
|
+
# models
|
28
|
+
require "torchvision/models/alexnet"
|
29
|
+
require "torchvision/models/basic_block"
|
30
|
+
require "torchvision/models/bottleneck"
|
31
|
+
require "torchvision/models/resnet"
|
32
|
+
require "torchvision/models/resnet18"
|
33
|
+
require "torchvision/models/resnet34"
|
34
|
+
require "torchvision/models/resnet50"
|
35
|
+
require "torchvision/models/resnet101"
|
36
|
+
require "torchvision/models/resnet152"
|
37
|
+
require "torchvision/models/resnext50_32x4d"
|
38
|
+
require "torchvision/models/resnext101_32x8d"
|
39
|
+
require "torchvision/models/vgg"
|
40
|
+
require "torchvision/models/vgg11"
|
41
|
+
require "torchvision/models/vgg11_bn"
|
42
|
+
require "torchvision/models/vgg13"
|
43
|
+
require "torchvision/models/vgg13_bn"
|
44
|
+
require "torchvision/models/vgg16"
|
45
|
+
require "torchvision/models/vgg16_bn"
|
46
|
+
require "torchvision/models/vgg19"
|
47
|
+
require "torchvision/models/vgg19_bn"
|
48
|
+
require "torchvision/models/wide_resnet50_2"
|
49
|
+
require "torchvision/models/wide_resnet101_2"
|
16
50
|
|
17
51
|
# transforms
|
52
|
+
require "torchvision/transforms/center_crop"
|
18
53
|
require "torchvision/transforms/compose"
|
19
54
|
require "torchvision/transforms/functional"
|
20
55
|
require "torchvision/transforms/normalize"
|
56
|
+
require "torchvision/transforms/random_horizontal_flip"
|
57
|
+
require "torchvision/transforms/random_resized_crop"
|
58
|
+
require "torchvision/transforms/random_vertical_flip"
|
59
|
+
require "torchvision/transforms/resize"
|
21
60
|
require "torchvision/transforms/to_tensor"
|
22
61
|
|
23
62
|
module TorchVision
|
@@ -0,0 +1,117 @@
|
|
1
|
+
module TorchVision
  module Datasets
    # CIFAR-10 dataset read from the binary ("-bin") version of the archive.
    # https://www.cs.toronto.edu/~kriz/cifar.html
    #
    # Subclasses (e.g. CIFAR100) reuse this loader and only override the
    # private configuration methods at the bottom of the class.
    class CIFAR10 < VisionDataset
      # Loads the train or test split into memory.
      #
      # @param root [String] directory that holds (or will receive) the extracted batches
      # @param train [Boolean] read train_list when true, test_list otherwise
      # @param download [Boolean] download and extract the archive into root first
      # @param transform [#call, nil] applied to each image returned by #[]
      # @param target_transform [#call, nil] applied to each label returned by #[]
      # @raise [Error] when any batch file is missing or fails its checksum
      def initialize(root, train: true, download: false, transform: nil, target_transform: nil)
        super(root, transform: transform, target_transform: target_transform)
        @train = train

        # the `download` keyword shadows the method, hence the explicit receiver
        self.download if download

        unless _check_integrity
          raise Error, "Dataset not found or corrupted. You can use download: true to download it"
        end

        downloaded_list = @train ? train_list : test_list

        # raw bytes are accumulated first and decoded once after all files are read
        @data = String.new
        @targets = String.new

        downloaded_list.each do |file|
          file_path = File.join(@root, base_folder, file[:filename])
          File.open(file_path, "rb") do |f|
            until f.eof?
              # when the format stores an extra label byte first (multiple_labels?),
              # skip it so the following byte is the label that gets kept
              f.read(1) if multiple_labels?
              @targets << f.read(1)
              @data << f.read(3072)
            end
          end
        end

        @targets = @targets.unpack("C*")
        # TODO switch to -1 when Numo supports it
        @data = Numo::UInt8.from_binary(@data).reshape(@targets.size, 3, 32, 32)
        # move axis 1 last: (N, 3, 32, 32) -> (N, 32, 32, 3); presumably the
        # channels-last layout Utils.image_from_array expects — confirm there
        @data = @data.transpose(0, 2, 3, 1)
      end

      # @return [Integer] number of samples in the loaded split
      def size
        @data.shape[0]
      end

      # Returns one sample with the optional transforms applied.
      #
      # @param index [Integer] sample position
      # @return [Array] pair of [image, target]
      def [](index)
        # TODO remove trues when Numo supports it
        img, target = @data[index, true, true, true], @targets[index]

        img = Utils.image_from_array(img)

        img = @transform.call(img) if @transform

        target = @target_transform.call(target) if @target_transform

        [img, target]
      end

      # Verifies every train and test file against its expected SHA-256.
      #
      # @return [Boolean] true only when all files are present and match
      def _check_integrity
        (train_list + test_list).each do |fentry|
          fpath = File.join(@root, base_folder, fentry[:filename])
          return false unless check_integrity(fpath, fentry[:sha256])
        end
        true
      end

      # Downloads the archive into root and extracts it there, unless all
      # batch files already pass the integrity check.
      def download
        if _check_integrity
          puts "Files already downloaded and verified"
          return
        end

        download_file(url, download_root: @root, filename: filename, sha256: tgz_sha256)

        path = File.join(@root, filename)
        File.open(path, "rb") do |io|
          Gem::Package.new("").extract_tar_gz(io, @root)
        end
      end

      private

      # folder (relative to root) containing the batch files after extraction
      def base_folder
        "cifar-10-batches-bin"
      end

      def url
        "https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz"
      end

      # name of the archive as saved inside root
      def filename
        "cifar-10-binary.tar.gz"
      end

      # expected SHA-256 of the downloaded archive
      def tgz_sha256
        "c4a38c50a1bc5f3a1c5537f2155ab9d68f9f25eb1ed8d9ddda3db29a59bca1dd"
      end

      def train_list
        [
          {filename: "data_batch_1.bin", sha256: "cee916563c9f80d84e3cc88e17fdc0941787f1244f00a67874d45b261883ada5"},
          {filename: "data_batch_2.bin", sha256: "a591ca11fa1708a91ee40f54b3da4784ccd871ecf2137de63f51ada8b3fa57ed"},
          {filename: "data_batch_3.bin", sha256: "bbe8596564c0f86427f876058170b84dac6670ddf06d79402899d93ceea26f67"},
          {filename: "data_batch_4.bin", sha256: "014e562d6e23c72197cc727519169a60359f5eccd8945ad5a09d710285ff4e48"},
          {filename: "data_batch_5.bin", sha256: "755304fc0b379caeae8c14f0dac912fbc7d6cd469eb67a1029a08a39453a9add"},
        ]
      end

      def test_list
        [
          {filename: "test_batch.bin", sha256: "8e2eb146ae340b09e24670f29cabc6326dba54da8789dab6768acf480273f65b"}
        ]
      end

      # whether each record carries one extra leading label byte to skip
      def multiple_labels?
        false
      end
    end
  end
end
|
@@ -0,0 +1,41 @@
|
|
1
|
+
module TorchVision
  module Datasets
    # CIFAR-100 dataset. All loading logic is inherited from CIFAR10; this
    # class only supplies the CIFAR-100 archive location, file names,
    # checksums, and record layout.
    # https://www.cs.toronto.edu/~kriz/cifar.html
    class CIFAR100 < CIFAR10
      private

      # Each record carries an extra label byte that the CIFAR10 loader skips.
      def multiple_labels?
        true
      end

      # Folder (relative to root) that holds train.bin / test.bin after extraction.
      def base_folder
        "cifar-100-binary"
      end

      # Name under which the archive is stored inside root.
      def filename
        "cifar-100-binary.tar.gz"
      end

      # Remote location of the archive.
      def url
        "https://www.cs.toronto.edu/~kriz/cifar-100-binary.tar.gz"
      end

      # Expected SHA-256 digest of the downloaded archive.
      def tgz_sha256
        "58a81ae192c23a4be8b1804d68e518ed807d710a4eb253b1f2a199162a40d8ec"
      end

      # Training split: a single file plus its SHA-256 digest.
      def train_list
        [{filename: "train.bin", sha256: "f31298fc616915fa142368359df1c4ca2ae984d6915ca468b998a5ec6aeebf29"}]
      end

      # Test split: a single file plus its SHA-256 digest.
      def test_list
        [{filename: "test.bin", sha256: "d8b1e6b7b3bee4020055f0699b111f60b1af1e262aeb93a0b659061746f8224a"}]
      end
    end
  end
end
|