torchvision 0.1.0 → 0.2.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +35 -0
- data/LICENSE.txt +1 -1
- data/README.md +133 -5
- data/lib/torchvision.rb +40 -1
- data/lib/torchvision/datasets/cifar10.rb +117 -0
- data/lib/torchvision/datasets/cifar100.rb +41 -0
- data/lib/torchvision/datasets/dataset_folder.rb +91 -0
- data/lib/torchvision/datasets/fashion_mnist.rb +30 -0
- data/lib/torchvision/datasets/image_folder.rb +12 -0
- data/lib/torchvision/datasets/kmnist.rb +30 -0
- data/lib/torchvision/datasets/mnist.rb +47 -76
- data/lib/torchvision/datasets/vision_dataset.rb +67 -0
- data/lib/torchvision/models/alexnet.rb +42 -0
- data/lib/torchvision/models/basic_block.rb +46 -0
- data/lib/torchvision/models/bottleneck.rb +47 -0
- data/lib/torchvision/models/resnet.rb +129 -0
- data/lib/torchvision/models/resnet101.rb +9 -0
- data/lib/torchvision/models/resnet152.rb +9 -0
- data/lib/torchvision/models/resnet18.rb +9 -0
- data/lib/torchvision/models/resnet34.rb +9 -0
- data/lib/torchvision/models/resnet50.rb +9 -0
- data/lib/torchvision/models/resnext101_32x8d.rb +11 -0
- data/lib/torchvision/models/resnext50_32x4d.rb +11 -0
- data/lib/torchvision/models/vgg.rb +93 -0
- data/lib/torchvision/models/vgg11.rb +9 -0
- data/lib/torchvision/models/vgg11_bn.rb +9 -0
- data/lib/torchvision/models/vgg13.rb +9 -0
- data/lib/torchvision/models/vgg13_bn.rb +9 -0
- data/lib/torchvision/models/vgg16.rb +9 -0
- data/lib/torchvision/models/vgg16_bn.rb +9 -0
- data/lib/torchvision/models/vgg19.rb +9 -0
- data/lib/torchvision/models/vgg19_bn.rb +9 -0
- data/lib/torchvision/models/wide_resnet101_2.rb +10 -0
- data/lib/torchvision/models/wide_resnet50_2.rb +10 -0
- data/lib/torchvision/transforms/center_crop.rb +13 -0
- data/lib/torchvision/transforms/compose.rb +2 -2
- data/lib/torchvision/transforms/functional.rb +142 -7
- data/lib/torchvision/transforms/normalize.rb +2 -2
- data/lib/torchvision/transforms/random_horizontal_flip.rb +18 -0
- data/lib/torchvision/transforms/random_resized_crop.rb +70 -0
- data/lib/torchvision/transforms/random_vertical_flip.rb +18 -0
- data/lib/torchvision/transforms/resize.rb +13 -0
- data/lib/torchvision/transforms/to_tensor.rb +2 -2
- data/lib/torchvision/utils.rb +120 -0
- data/lib/torchvision/version.rb +1 -1
- metadata +50 -57
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: bbb87c59c0f081c0de57ccdd62e30bfc551e1cb69523e4ffd498c997e1a2d8b3
|
4
|
+
data.tar.gz: 890da113706e659d57194980c5c9262075beb8398a75da2997c0812b70abe308
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 3445b62b7824ae16205034881d37c48ac4c70d7e5677014755ae5600632f9ce45168f41b0d3e98c8104eb8337e1566db4df3e0ad5ace5e6a46a5d213d01b6c8d
|
7
|
+
data.tar.gz: 93f22c385586ff8a010880676806f6bc9ba2f614c4c14886235a300ea6e2abce0f80c260e255644e1b4d24e6ecddfd21830dde2960a92a9492239e69622d4548
|
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,38 @@
|
|
1
|
+
## 0.2.1 (2021-03-14)
|
2
|
+
|
3
|
+
- Added `ImageFolder` and `DatasetFolder`
|
4
|
+
- Added `CenterCrop` and `RandomResizedCrop` transforms
|
5
|
+
- Added `crop` method
|
6
|
+
|
7
|
+
## 0.2.0 (2021-03-11)
|
8
|
+
|
9
|
+
- Added `RandomHorizontalFlip`, `RandomVerticalFlip`, and `Resize` transforms
|
10
|
+
- Added `save_image` method
|
11
|
+
- Added `data` and `targets` methods to datasets
|
12
|
+
- Removed support for Ruby < 2.6
|
13
|
+
|
14
|
+
Breaking changes
|
15
|
+
|
16
|
+
- Added dependency on libvips
|
17
|
+
- MNIST datasets return images instead of tensors
|
18
|
+
|
19
|
+
## 0.1.3 (2020-06-29)
|
20
|
+
|
21
|
+
- Added AlexNet model
|
22
|
+
- Added ResNet34, ResNet50, ResNet101, and ResNet152 models
|
23
|
+
- Added ResNeXt model
|
24
|
+
- Added VGG11, VGG13, VGG16, and VGG19 models
|
25
|
+
- Added Wide ResNet model
|
26
|
+
|
27
|
+
## 0.1.2 (2020-04-29)
|
28
|
+
|
29
|
+
- Added CIFAR10, CIFAR100, FashionMNIST, and KMNIST datasets
|
30
|
+
- Added ResNet18 model
|
31
|
+
|
32
|
+
## 0.1.1 (2020-04-28)
|
33
|
+
|
34
|
+
- Removed `mini_magick` for performance
|
35
|
+
|
1
36
|
## 0.1.0 (2020-04-27)
|
2
37
|
|
3
38
|
- First release
|
data/LICENSE.txt
CHANGED
data/README.md
CHANGED
@@ -2,10 +2,16 @@
|
|
2
2
|
|
3
3
|
:fire: Computer vision datasets, transforms, and models for Ruby
|
4
4
|
|
5
|
-
|
5
|
+
[](https://github.com/ankane/torchvision/actions)
|
6
6
|
|
7
7
|
## Installation
|
8
8
|
|
9
|
+
First, [install libvips](#libvips-installation). For Homebrew, use:
|
10
|
+
|
11
|
+
```sh
|
12
|
+
brew install vips
|
13
|
+
```
|
14
|
+
|
9
15
|
Add this line to your application’s Gemfile:
|
10
16
|
|
11
17
|
```ruby
|
@@ -14,17 +20,139 @@ gem 'torchvision'
|
|
14
20
|
|
15
21
|
## Getting Started
|
16
22
|
|
17
|
-
This library follows the [Python API](https://pytorch.org/docs/
|
23
|
+
This library follows the [Python API](https://pytorch.org/docs/stable/torchvision/index.html). Many methods and options are missing at the moment. PRs welcome!
|
24
|
+
|
25
|
+
## Examples
|
26
|
+
|
27
|
+
- [MNIST](https://github.com/ankane/torch.rb/tree/master/examples/mnist)
|
28
|
+
- [Transfer learning](https://github.com/ankane/torch.rb/tree/master/examples/transfer-learning)
|
29
|
+
- [Generative adversarial networks](https://github.com/ankane/torch.rb/tree/master/examples/gan)
|
18
30
|
|
19
31
|
## Datasets
|
20
32
|
|
21
|
-
|
33
|
+
Load a dataset
|
34
|
+
|
35
|
+
```ruby
|
36
|
+
TorchVision::Datasets::MNIST.new("./data", train: true, download: true)
|
37
|
+
```
|
38
|
+
|
39
|
+
Supported datasets are:
|
40
|
+
|
41
|
+
- CIFAR10
|
42
|
+
- CIFAR100
|
43
|
+
- FashionMNIST
|
44
|
+
- KMNIST
|
45
|
+
- MNIST
|
46
|
+
|
47
|
+
## Transforms
|
48
|
+
|
49
|
+
```ruby
|
50
|
+
TorchVision::Transforms::Compose.new([
|
51
|
+
TorchVision::Transforms::ToTensor.new,
|
52
|
+
TorchVision::Transforms::Normalize.new([0.1307], [0.3081])
|
53
|
+
])
|
54
|
+
```
|
55
|
+
|
56
|
+
Supported transforms are:
|
57
|
+
|
58
|
+
- CenterCrop
|
59
|
+
- Compose
|
60
|
+
- Normalize
|
61
|
+
- RandomHorizontalFlip
|
62
|
+
- RandomResizedCrop
|
63
|
+
- RandomVerticalFlip
|
64
|
+
- Resize
|
65
|
+
- ToTensor
|
66
|
+
|
67
|
+
## Models
|
68
|
+
|
69
|
+
- [AlexNet](#alexnet)
|
70
|
+
- [ResNet](#resnet)
|
71
|
+
- [ResNeXt](#resnext)
|
72
|
+
- [VGG](#vgg)
|
73
|
+
- [Wide ResNet](#wide-resnet)
|
74
|
+
|
75
|
+
### AlexNet
|
76
|
+
|
77
|
+
```ruby
|
78
|
+
TorchVision::Models::AlexNet.new
|
79
|
+
```
|
80
|
+
|
81
|
+
### ResNet
|
82
|
+
|
83
|
+
```ruby
|
84
|
+
TorchVision::Models::ResNet18.new
|
85
|
+
TorchVision::Models::ResNet34.new
|
86
|
+
TorchVision::Models::ResNet50.new
|
87
|
+
TorchVision::Models::ResNet101.new
|
88
|
+
TorchVision::Models::ResNet152.new
|
89
|
+
```
|
90
|
+
|
91
|
+
### ResNeXt
|
92
|
+
|
93
|
+
```ruby
|
94
|
+
TorchVision::Models::ResNext50_32x4d.new
|
95
|
+
TorchVision::Models::ResNext101_32x8d.new
|
96
|
+
```
|
97
|
+
|
98
|
+
### VGG
|
99
|
+
|
100
|
+
```ruby
|
101
|
+
TorchVision::Models::VGG11.new
|
102
|
+
TorchVision::Models::VGG11BN.new
|
103
|
+
TorchVision::Models::VGG13.new
|
104
|
+
TorchVision::Models::VGG13BN.new
|
105
|
+
TorchVision::Models::VGG16.new
|
106
|
+
TorchVision::Models::VGG16BN.new
|
107
|
+
TorchVision::Models::VGG19.new
|
108
|
+
TorchVision::Models::VGG19BN.new
|
109
|
+
```
|
110
|
+
|
111
|
+
### Wide ResNet
|
22
112
|
|
23
113
|
```ruby
|
24
|
-
|
25
|
-
|
114
|
+
TorchVision::Models::WideResNet50_2.new
|
115
|
+
TorchVision::Models::WideResNet101_2.new
|
26
116
|
```
|
27
117
|
|
118
|
+
## Pretrained Models
|
119
|
+
|
120
|
+
You can download pretrained models with [this script](pretrained.py).
|
121
|
+
|
122
|
+
```sh
|
123
|
+
pip install torchvision
|
124
|
+
python pretrained.py
|
125
|
+
```
|
126
|
+
|
127
|
+
And load them:
|
128
|
+
|
129
|
+
```ruby
|
130
|
+
net = TorchVision::Models::ResNet18.new
|
131
|
+
net.load_state_dict(Torch.load("net.pth"))
|
132
|
+
```
|
133
|
+
|
134
|
+
## libvips Installation
|
135
|
+
|
136
|
+
### Linux
|
137
|
+
|
138
|
+
Check your package manager. For Ubuntu, use:
|
139
|
+
|
140
|
+
```sh
|
141
|
+
sudo apt install libvips
|
142
|
+
```
|
143
|
+
|
144
|
+
You can also [build from source](https://libvips.github.io/libvips/install.html).
|
145
|
+
|
146
|
+
### Mac
|
147
|
+
|
148
|
+
```sh
|
149
|
+
brew install vips
|
150
|
+
```
|
151
|
+
|
152
|
+
### Windows
|
153
|
+
|
154
|
+
Check out [the options](https://libvips.github.io/libvips/install.html).
|
155
|
+
|
28
156
|
## Disclaimer
|
29
157
|
|
30
158
|
This library downloads and prepares public datasets. We don’t host any datasets. Be sure to adhere to the license for each dataset.
|
data/lib/torchvision.rb
CHANGED
@@ -1,23 +1,62 @@
|
|
1
1
|
# dependencies
|
2
|
-
require "mini_magick"
|
3
2
|
require "numo/narray"
|
3
|
+
require "vips"
|
4
4
|
require "torch"
|
5
5
|
|
6
6
|
# stdlib
|
7
7
|
require "digest"
|
8
8
|
require "fileutils"
|
9
9
|
require "net/http"
|
10
|
+
require "rubygems/package"
|
11
|
+
require "tmpdir"
|
10
12
|
|
11
13
|
# modules
|
14
|
+
require "torchvision/utils"
|
12
15
|
require "torchvision/version"
|
13
16
|
|
14
17
|
# datasets
|
18
|
+
require "torchvision/datasets/vision_dataset"
|
19
|
+
require "torchvision/datasets/dataset_folder"
|
20
|
+
require "torchvision/datasets/image_folder"
|
21
|
+
require "torchvision/datasets/cifar10"
|
22
|
+
require "torchvision/datasets/cifar100"
|
15
23
|
require "torchvision/datasets/mnist"
|
24
|
+
require "torchvision/datasets/fashion_mnist"
|
25
|
+
require "torchvision/datasets/kmnist"
|
26
|
+
|
27
|
+
# models
|
28
|
+
require "torchvision/models/alexnet"
|
29
|
+
require "torchvision/models/basic_block"
|
30
|
+
require "torchvision/models/bottleneck"
|
31
|
+
require "torchvision/models/resnet"
|
32
|
+
require "torchvision/models/resnet18"
|
33
|
+
require "torchvision/models/resnet34"
|
34
|
+
require "torchvision/models/resnet50"
|
35
|
+
require "torchvision/models/resnet101"
|
36
|
+
require "torchvision/models/resnet152"
|
37
|
+
require "torchvision/models/resnext50_32x4d"
|
38
|
+
require "torchvision/models/resnext101_32x8d"
|
39
|
+
require "torchvision/models/vgg"
|
40
|
+
require "torchvision/models/vgg11"
|
41
|
+
require "torchvision/models/vgg11_bn"
|
42
|
+
require "torchvision/models/vgg13"
|
43
|
+
require "torchvision/models/vgg13_bn"
|
44
|
+
require "torchvision/models/vgg16"
|
45
|
+
require "torchvision/models/vgg16_bn"
|
46
|
+
require "torchvision/models/vgg19"
|
47
|
+
require "torchvision/models/vgg19_bn"
|
48
|
+
require "torchvision/models/wide_resnet50_2"
|
49
|
+
require "torchvision/models/wide_resnet101_2"
|
16
50
|
|
17
51
|
# transforms
|
52
|
+
require "torchvision/transforms/center_crop"
|
18
53
|
require "torchvision/transforms/compose"
|
19
54
|
require "torchvision/transforms/functional"
|
20
55
|
require "torchvision/transforms/normalize"
|
56
|
+
require "torchvision/transforms/random_horizontal_flip"
|
57
|
+
require "torchvision/transforms/random_resized_crop"
|
58
|
+
require "torchvision/transforms/random_vertical_flip"
|
59
|
+
require "torchvision/transforms/resize"
|
21
60
|
require "torchvision/transforms/to_tensor"
|
22
61
|
|
23
62
|
module TorchVision
|
@@ -0,0 +1,117 @@
|
|
1
|
+
module TorchVision
  module Datasets
    # CIFAR-10 image classification dataset, read from the binary distribution
    # of the archive.
    # https://www.cs.toronto.edu/~kriz/cifar.html
    class CIFAR10 < VisionDataset
      # Pixel bytes per record: 3 channels x 32 rows x 32 columns.
      IMAGE_BYTES = 3 * 32 * 32
      private_constant :IMAGE_BYTES

      # @param root [String] directory containing (or receiving) the extracted archive
      # @param train [Boolean] load the training batch files when true, the test batch otherwise
      # @param download [Boolean] fetch and extract the archive before loading
      # @param transform [#call, nil] applied to each image returned by #[]
      # @param target_transform [#call, nil] applied to each label returned by #[]
      # @raise [Error] if any batch file fails the integrity check
      def initialize(root, train: true, download: false, transform: nil, target_transform: nil)
        super(root, transform: transform, target_transform: target_transform)
        @train = train

        # The `download` keyword shadows the method, so call it with an explicit receiver.
        self.download if download

        unless _check_integrity
          # The hint uses this Ruby API's keyword syntax (not Python's `download=True`).
          raise Error, "Dataset not found or corrupted. You can use download: true to download it"
        end

        @targets, @data = read_batches(@train ? train_list : test_list)
      end

      # @return [Integer] number of images loaded
      def size
        @data.shape[0]
      end

      # @param index [Integer] position of the sample
      # @return [Array] two-element array of the image (built by Utils.image_from_array,
      #   then passed through the transform if set) and its label (passed through
      #   target_transform if set)
      def [](index)
        # TODO remove trues when Numo supports it
        img = Utils.image_from_array(@data[index, true, true, true])
        target = @targets[index]

        img = @transform.call(img) if @transform
        target = @target_transform.call(target) if @target_transform

        [img, target]
      end

      # Checks every train and test batch file under root/base_folder with
      # check_integrity (defined outside this class) against its expected SHA-256.
      #
      # @return [Boolean] false as soon as one file fails, true otherwise
      def _check_integrity
        (train_list + test_list).all? do |fentry|
          check_integrity(File.join(@root, base_folder, fentry[:filename]), fentry[:sha256])
        end
      end

      # Downloads the archive into root (passing its expected SHA-256 to
      # download_file) and extracts it there. Does nothing if the batch files
      # already pass the integrity check.
      def download
        if _check_integrity
          puts "Files already downloaded and verified"
          return
        end

        download_file(url, download_root: @root, filename: filename, sha256: tgz_sha256)

        File.open(File.join(@root, filename), "rb") do |io|
          Gem::Package.new("").extract_tar_gz(io, @root)
        end
      end

      private

      # Reads every record of the given batch files.
      #
      # A record is: an extra leading label byte (only when multiple_labels? is
      # true; it is skipped), one label byte, then IMAGE_BYTES pixel bytes.
      #
      # @param files [Array<Hash>] entries with a :filename key, relative to base_folder
      # @return [Array] labels (Array<Integer>) and images as a Numo::UInt8 array
      #   shaped (n, 32, 32, 3)
      def read_batches(files)
        # Binary (ASCII-8BIT) buffers accumulating raw label and pixel bytes.
        labels = String.new
        pixels = String.new

        files.each do |file|
          File.open(File.join(@root, base_folder, file[:filename]), "rb") do |f|
            until f.eof?
              f.read(1) if multiple_labels?
              labels << f.read(1)
              pixels << f.read(IMAGE_BYTES)
            end
          end
        end

        targets = labels.unpack("C*")
        # TODO pass -1 instead of targets.size when Numo supports it
        images = Numo::UInt8.from_binary(pixels).reshape(targets.size, 3, 32, 32)
        # Records are channel-first (n, 3, 32, 32); reorder to (n, 32, 32, 3).
        [targets, images.transpose(0, 2, 3, 1)]
      end

      def base_folder
        "cifar-10-batches-bin"
      end

      def url
        "https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz"
      end

      def filename
        "cifar-10-binary.tar.gz"
      end

      def tgz_sha256
        "c4a38c50a1bc5f3a1c5537f2155ab9d68f9f25eb1ed8d9ddda3db29a59bca1dd"
      end

      def train_list
        [
          {filename: "data_batch_1.bin", sha256: "cee916563c9f80d84e3cc88e17fdc0941787f1244f00a67874d45b261883ada5"},
          {filename: "data_batch_2.bin", sha256: "a591ca11fa1708a91ee40f54b3da4784ccd871ecf2137de63f51ada8b3fa57ed"},
          {filename: "data_batch_3.bin", sha256: "bbe8596564c0f86427f876058170b84dac6670ddf06d79402899d93ceea26f67"},
          {filename: "data_batch_4.bin", sha256: "014e562d6e23c72197cc727519169a60359f5eccd8945ad5a09d710285ff4e48"},
          {filename: "data_batch_5.bin", sha256: "755304fc0b379caeae8c14f0dac912fbc7d6cd469eb67a1029a08a39453a9add"},
        ]
      end

      def test_list
        [
          {filename: "test_batch.bin", sha256: "8e2eb146ae340b09e24670f29cabc6326dba54da8789dab6768acf480273f65b"}
        ]
      end

      # Whether each record carries an extra leading label byte to skip.
      def multiple_labels?
        false
      end
    end
  end
end
|
@@ -0,0 +1,41 @@
|
|
1
|
+
module TorchVision
  module Datasets
    # CIFAR-100 dataset. Loading logic is inherited from CIFAR10; this class only
    # supplies its own archive location, folder, file names, and checksums, and
    # declares that each record has an extra leading label byte.
    # https://www.cs.toronto.edu/~kriz/cifar.html
    class CIFAR100 < CIFAR10
      # Records carry two label bytes; the parent's reader skips the first one.
      private def multiple_labels?
        true
      end

      private def url
        "https://www.cs.toronto.edu/~kriz/cifar-100-binary.tar.gz"
      end

      private def filename
        "cifar-100-binary.tar.gz"
      end

      private def tgz_sha256
        "58a81ae192c23a4be8b1804d68e518ed807d710a4eb253b1f2a199162a40d8ec"
      end

      # Directory (inside root) that the archive extracts to.
      private def base_folder
        "cifar-100-binary"
      end

      private def train_list
        [{filename: "train.bin", sha256: "f31298fc616915fa142368359df1c4ca2ae984d6915ca468b998a5ec6aeebf29"}]
      end

      private def test_list
        [{filename: "test.bin", sha256: "d8b1e6b7b3bee4020055f0699b111f60b1af1e262aeb93a0b659061746f8224a"}]
      end
    end
  end
end
|