torchvision 0.1.0 → 0.2.1

Sign up to get free protection for your applications and to get access to all the features.
Files changed (47) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +35 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +133 -5
  5. data/lib/torchvision.rb +40 -1
  6. data/lib/torchvision/datasets/cifar10.rb +117 -0
  7. data/lib/torchvision/datasets/cifar100.rb +41 -0
  8. data/lib/torchvision/datasets/dataset_folder.rb +91 -0
  9. data/lib/torchvision/datasets/fashion_mnist.rb +30 -0
  10. data/lib/torchvision/datasets/image_folder.rb +12 -0
  11. data/lib/torchvision/datasets/kmnist.rb +30 -0
  12. data/lib/torchvision/datasets/mnist.rb +47 -76
  13. data/lib/torchvision/datasets/vision_dataset.rb +67 -0
  14. data/lib/torchvision/models/alexnet.rb +42 -0
  15. data/lib/torchvision/models/basic_block.rb +46 -0
  16. data/lib/torchvision/models/bottleneck.rb +47 -0
  17. data/lib/torchvision/models/resnet.rb +129 -0
  18. data/lib/torchvision/models/resnet101.rb +9 -0
  19. data/lib/torchvision/models/resnet152.rb +9 -0
  20. data/lib/torchvision/models/resnet18.rb +9 -0
  21. data/lib/torchvision/models/resnet34.rb +9 -0
  22. data/lib/torchvision/models/resnet50.rb +9 -0
  23. data/lib/torchvision/models/resnext101_32x8d.rb +11 -0
  24. data/lib/torchvision/models/resnext50_32x4d.rb +11 -0
  25. data/lib/torchvision/models/vgg.rb +93 -0
  26. data/lib/torchvision/models/vgg11.rb +9 -0
  27. data/lib/torchvision/models/vgg11_bn.rb +9 -0
  28. data/lib/torchvision/models/vgg13.rb +9 -0
  29. data/lib/torchvision/models/vgg13_bn.rb +9 -0
  30. data/lib/torchvision/models/vgg16.rb +9 -0
  31. data/lib/torchvision/models/vgg16_bn.rb +9 -0
  32. data/lib/torchvision/models/vgg19.rb +9 -0
  33. data/lib/torchvision/models/vgg19_bn.rb +9 -0
  34. data/lib/torchvision/models/wide_resnet101_2.rb +10 -0
  35. data/lib/torchvision/models/wide_resnet50_2.rb +10 -0
  36. data/lib/torchvision/transforms/center_crop.rb +13 -0
  37. data/lib/torchvision/transforms/compose.rb +2 -2
  38. data/lib/torchvision/transforms/functional.rb +142 -7
  39. data/lib/torchvision/transforms/normalize.rb +2 -2
  40. data/lib/torchvision/transforms/random_horizontal_flip.rb +18 -0
  41. data/lib/torchvision/transforms/random_resized_crop.rb +70 -0
  42. data/lib/torchvision/transforms/random_vertical_flip.rb +18 -0
  43. data/lib/torchvision/transforms/resize.rb +13 -0
  44. data/lib/torchvision/transforms/to_tensor.rb +2 -2
  45. data/lib/torchvision/utils.rb +120 -0
  46. data/lib/torchvision/version.rb +1 -1
  47. metadata +50 -57
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 19fff61bb461e5fbf850702bf485a442b2013807ac28549bc427e5f3d8c7472b
4
- data.tar.gz: bc8004d2ca26e9022f2fa2b1663277bcedf701729f59aae248058a26e605a5ad
3
+ metadata.gz: bbb87c59c0f081c0de57ccdd62e30bfc551e1cb69523e4ffd498c997e1a2d8b3
4
+ data.tar.gz: 890da113706e659d57194980c5c9262075beb8398a75da2997c0812b70abe308
5
5
  SHA512:
6
- metadata.gz: fbd3d7292efa6ee2fd2c0ff8cb85659d37a19761b4d93a9a4923a9990d7400c849738913db12720e7d232d6fdb180c16f06da0ecc3601a922bf0036beb0b44bd
7
- data.tar.gz: '09b86d6b01f25d43ac65d9c3d0509b3488ed2108d57090553ae46526f841551cbdd16fefce1cb15ea7f326d946e3abb26d777c3c45e23538f8fb7753fdb6fec9'
6
+ metadata.gz: 3445b62b7824ae16205034881d37c48ac4c70d7e5677014755ae5600632f9ce45168f41b0d3e98c8104eb8337e1566db4df3e0ad5ace5e6a46a5d213d01b6c8d
7
+ data.tar.gz: 93f22c385586ff8a010880676806f6bc9ba2f614c4c14886235a300ea6e2abce0f80c260e255644e1b4d24e6ecddfd21830dde2960a92a9492239e69622d4548
data/CHANGELOG.md CHANGED
@@ -1,3 +1,38 @@
1
+ ## 0.2.1 (2021-03-14)
2
+
3
+ - Added `ImageFolder` and `DatasetFolder`
4
+ - Added `CenterCrop` and `RandomResizedCrop` transforms
5
+ - Added `crop` method
6
+
7
+ ## 0.2.0 (2021-03-11)
8
+
9
+ - Added `RandomHorizontalFlip`, `RandomVerticalFlip`, and `Resize` transforms
10
+ - Added `save_image` method
11
+ - Added `data` and `targets` methods to datasets
12
+ - Removed support for Ruby < 2.6
13
+
14
+ Breaking changes
15
+
16
+ - Added dependency on libvips
17
+ - MNIST datasets return images instead of tensors
18
+
19
+ ## 0.1.3 (2020-06-29)
20
+
21
+ - Added AlexNet model
22
+ - Added ResNet34, ResNet50, ResNet101, and ResNet152 models
23
+ - Added ResNeXt model
24
+ - Added VGG11, VGG13, VGG16, and VGG19 models
25
+ - Added Wide ResNet model
26
+
27
+ ## 0.1.2 (2020-04-29)
28
+
29
+ - Added CIFAR10, CIFAR100, FashionMNIST, and KMNIST datasets
30
+ - Added ResNet18 model
31
+
32
+ ## 0.1.1 (2020-04-28)
33
+
34
+ - Removed `mini_magick` for performance
35
+
1
36
  ## 0.1.0 (2020-04-27)
2
37
 
3
38
  - First release
data/LICENSE.txt CHANGED
@@ -1,7 +1,7 @@
1
1
  BSD 3-Clause License
2
2
 
3
- Copyright (c) Andrew Kane 2020,
4
3
  Copyright (c) Soumith Chintala 2016,
4
+ Copyright (c) Andrew Kane 2020-2021,
5
5
  All rights reserved.
6
6
 
7
7
  Redistribution and use in source and binary forms, with or without
data/README.md CHANGED
@@ -2,10 +2,16 @@
2
2
 
3
3
  :fire: Computer vision datasets, transforms, and models for Ruby
4
4
 
5
- This gem is currently experimental. There may be breaking changes between each release. Please report any issues you experience.
5
+ [![Build Status](https://github.com/ankane/torchvision/workflows/build/badge.svg?branch=master)](https://github.com/ankane/torchvision/actions)
6
6
 
7
7
  ## Installation
8
8
 
9
+ First, [install libvips](#libvips-installation). For Homebrew, use:
10
+
11
+ ```sh
12
+ brew install vips
13
+ ```
14
+
9
15
  Add this line to your application’s Gemfile:
10
16
 
11
17
  ```ruby
@@ -14,17 +20,139 @@ gem 'torchvision'
14
20
 
15
21
  ## Getting Started
16
22
 
17
- This library follows the [Python API](https://pytorch.org/docs/master/torchvision/). Many methods and options are missing at the moment. PRs welcome!
23
+ This library follows the [Python API](https://pytorch.org/docs/stable/torchvision/index.html). Many methods and options are missing at the moment. PRs welcome!
24
+
25
+ ## Examples
26
+
27
+ - [MNIST](https://github.com/ankane/torch.rb/tree/master/examples/mnist)
28
+ - [Transfer learning](https://github.com/ankane/torch.rb/tree/master/examples/transfer-learning)
29
+ - [Generative adversarial networks](https://github.com/ankane/torch.rb/tree/master/examples/gan)
18
30
 
19
31
  ## Datasets
20
32
 
21
- MNIST dataset
33
+ Load a dataset
34
+
35
+ ```ruby
36
+ TorchVision::Datasets::MNIST.new("./data", train: true, download: true)
37
+ ```
38
+
39
+ Supported datasets are:
40
+
41
+ - CIFAR10
42
+ - CIFAR100
43
+ - FashionMNIST
44
+ - KMNIST
45
+ - MNIST
46
+
47
+ ## Transforms
48
+
49
+ ```ruby
50
+ TorchVision::Transforms::Compose.new([
51
+ TorchVision::Transforms::ToTensor.new,
52
+ TorchVision::Transforms::Normalize.new([0.1307], [0.3081])
53
+ ])
54
+ ```
55
+
56
+ Supported transforms are:
57
+
58
+ - CenterCrop
59
+ - Compose
60
+ - Normalize
61
+ - RandomHorizontalFlip
62
+ - RandomResizedCrop
63
+ - RandomVerticalFlip
64
+ - Resize
65
+ - ToTensor
66
+
67
+ ## Models
68
+
69
+ - [AlexNet](#alexnet)
70
+ - [ResNet](#resnet)
71
+ - [ResNeXt](#resnext)
72
+ - [VGG](#vgg)
73
+ - [Wide ResNet](#wide-resnet)
74
+
75
+ ### AlexNet
76
+
77
+ ```ruby
78
+ TorchVision::Models::AlexNet.new
79
+ ```
80
+
81
+ ### ResNet
82
+
83
+ ```ruby
84
+ TorchVision::Models::ResNet18.new
85
+ TorchVision::Models::ResNet34.new
86
+ TorchVision::Models::ResNet50.new
87
+ TorchVision::Models::ResNet101.new
88
+ TorchVision::Models::ResNet152.new
89
+ ```
90
+
91
+ ### ResNeXt
92
+
93
+ ```ruby
94
+ TorchVision::Models::ResNext50_32x4d.new
95
+ TorchVision::Models::ResNext101_32x8d.new
96
+ ```
97
+
98
+ ### VGG
99
+
100
+ ```ruby
101
+ TorchVision::Models::VGG11.new
102
+ TorchVision::Models::VGG11BN.new
103
+ TorchVision::Models::VGG13.new
104
+ TorchVision::Models::VGG13BN.new
105
+ TorchVision::Models::VGG16.new
106
+ TorchVision::Models::VGG16BN.new
107
+ TorchVision::Models::VGG19.new
108
+ TorchVision::Models::VGG19BN.new
109
+ ```
110
+
111
+ ### Wide ResNet
22
112
 
23
113
  ```ruby
24
- trainset = TorchVision::Datasets::MNIST.new("./data", train: true, download: true)
25
- trainset.size
114
+ TorchVision::Models::WideResNet50_2.new
115
+ TorchVision::Models::WideResNet101_2.new
26
116
  ```
27
117
 
118
+ ## Pretrained Models
119
+
120
+ You can download pretrained models with [this script](pretrained.py)
121
+
122
+ ```sh
123
+ pip install torchvision
124
+ python pretrained.py
125
+ ```
126
+
127
+ And load them
128
+
129
+ ```ruby
130
+ net = TorchVision::Models::ResNet18.new
131
+ net.load_state_dict(Torch.load("net.pth"))
132
+ ```
133
+
134
+ ## libvips Installation
135
+
136
+ ### Linux
137
+
138
+ Check your package manager. For Ubuntu, use:
139
+
140
+ ```sh
141
+ sudo apt install libvips
142
+ ```
143
+
144
+ You can also [build from source](https://libvips.github.io/libvips/install.html).
145
+
146
+ ### Mac
147
+
148
+ ```sh
149
+ brew install vips
150
+ ```
151
+
152
+ ### Windows
153
+
154
+ Check out [the options](https://libvips.github.io/libvips/install.html).
155
+
28
156
  ## Disclaimer
29
157
 
30
158
  This library downloads and prepares public datasets. We don’t host any datasets. Be sure to adhere to the license for each dataset.
data/lib/torchvision.rb CHANGED
@@ -1,23 +1,62 @@
1
1
  # dependencies
2
- require "mini_magick"
3
2
  require "numo/narray"
3
+ require "vips"
4
4
  require "torch"
5
5
 
6
6
  # stdlib
7
7
  require "digest"
8
8
  require "fileutils"
9
9
  require "net/http"
10
+ require "rubygems/package"
11
+ require "tmpdir"
10
12
 
11
13
  # modules
14
+ require "torchvision/utils"
12
15
  require "torchvision/version"
13
16
 
14
17
  # datasets
18
+ require "torchvision/datasets/vision_dataset"
19
+ require "torchvision/datasets/dataset_folder"
20
+ require "torchvision/datasets/image_folder"
21
+ require "torchvision/datasets/cifar10"
22
+ require "torchvision/datasets/cifar100"
15
23
  require "torchvision/datasets/mnist"
24
+ require "torchvision/datasets/fashion_mnist"
25
+ require "torchvision/datasets/kmnist"
26
+
27
+ # models
28
+ require "torchvision/models/alexnet"
29
+ require "torchvision/models/basic_block"
30
+ require "torchvision/models/bottleneck"
31
+ require "torchvision/models/resnet"
32
+ require "torchvision/models/resnet18"
33
+ require "torchvision/models/resnet34"
34
+ require "torchvision/models/resnet50"
35
+ require "torchvision/models/resnet101"
36
+ require "torchvision/models/resnet152"
37
+ require "torchvision/models/resnext50_32x4d"
38
+ require "torchvision/models/resnext101_32x8d"
39
+ require "torchvision/models/vgg"
40
+ require "torchvision/models/vgg11"
41
+ require "torchvision/models/vgg11_bn"
42
+ require "torchvision/models/vgg13"
43
+ require "torchvision/models/vgg13_bn"
44
+ require "torchvision/models/vgg16"
45
+ require "torchvision/models/vgg16_bn"
46
+ require "torchvision/models/vgg19"
47
+ require "torchvision/models/vgg19_bn"
48
+ require "torchvision/models/wide_resnet50_2"
49
+ require "torchvision/models/wide_resnet101_2"
16
50
 
17
51
  # transforms
52
+ require "torchvision/transforms/center_crop"
18
53
  require "torchvision/transforms/compose"
19
54
  require "torchvision/transforms/functional"
20
55
  require "torchvision/transforms/normalize"
56
+ require "torchvision/transforms/random_horizontal_flip"
57
+ require "torchvision/transforms/random_resized_crop"
58
+ require "torchvision/transforms/random_vertical_flip"
59
+ require "torchvision/transforms/resize"
21
60
  require "torchvision/transforms/to_tensor"
22
61
 
23
62
  module TorchVision
@@ -0,0 +1,117 @@
1
+ module TorchVision
2
+ module Datasets
3
+ class CIFAR10 < VisionDataset
4
+ # https://www.cs.toronto.edu/~kriz/cifar.html
5
+
6
+ def initialize(root, train: true, download: false, transform: nil, target_transform: nil)
7
+ super(root, transform: transform, target_transform: target_transform)
8
+ @train = train
9
+
10
+ self.download if download
11
+
12
+ if !_check_integrity
13
+ raise Error, "Dataset not found or corrupted. You can use download=True to download it"
14
+ end
15
+
16
+ downloaded_list = @train ? train_list : test_list
17
+
18
+ @data = String.new
19
+ @targets = String.new
20
+
21
+ downloaded_list.each do |file|
22
+ file_path = File.join(@root, base_folder, file[:filename])
23
+ File.open(file_path, "rb") do |f|
24
+ while !f.eof?
25
+ f.read(1) if multiple_labels?
26
+ @targets << f.read(1)
27
+ @data << f.read(3072)
28
+ end
29
+ end
30
+ end
31
+
32
+ @targets = @targets.unpack("C*")
33
+ # TODO switch i to -1 when Numo supports it
34
+ @data = Numo::UInt8.from_binary(@data).reshape(@targets.size, 3, 32, 32)
35
+ @data = @data.transpose(0, 2, 3, 1)
36
+ end
37
+
38
+ def size
39
+ @data.shape[0]
40
+ end
41
+
42
+ def [](index)
43
+ # TODO remove trues when Numo supports it
44
+ img, target = @data[index, true, true, true], @targets[index]
45
+
46
+ img = Utils.image_from_array(img)
47
+
48
+ img = @transform.call(img) if @transform
49
+
50
+ target = @target_transform.call(target) if @target_transform
51
+
52
+ [img, target]
53
+ end
54
+
55
+ def _check_integrity
56
+ root = @root
57
+ (train_list + test_list).each do |fentry|
58
+ fpath = File.join(root, base_folder, fentry[:filename])
59
+ return false unless check_integrity(fpath, fentry[:sha256])
60
+ end
61
+ true
62
+ end
63
+
64
+ def download
65
+ if _check_integrity
66
+ puts "Files already downloaded and verified"
67
+ return
68
+ end
69
+
70
+ download_file(url, download_root: @root, filename: filename, sha256: tgz_sha256)
71
+
72
+ path = File.join(@root, filename)
73
+ File.open(path, "rb") do |io|
74
+ Gem::Package.new("").extract_tar_gz(io, @root)
75
+ end
76
+ end
77
+
78
+ private
79
+
80
+ def base_folder
81
+ "cifar-10-batches-bin"
82
+ end
83
+
84
+ def url
85
+ "https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz"
86
+ end
87
+
88
+ def filename
89
+ "cifar-10-binary.tar.gz"
90
+ end
91
+
92
+ def tgz_sha256
93
+ "c4a38c50a1bc5f3a1c5537f2155ab9d68f9f25eb1ed8d9ddda3db29a59bca1dd"
94
+ end
95
+
96
+ def train_list
97
+ [
98
+ {filename: "data_batch_1.bin", sha256: "cee916563c9f80d84e3cc88e17fdc0941787f1244f00a67874d45b261883ada5"},
99
+ {filename: "data_batch_2.bin", sha256: "a591ca11fa1708a91ee40f54b3da4784ccd871ecf2137de63f51ada8b3fa57ed"},
100
+ {filename: "data_batch_3.bin", sha256: "bbe8596564c0f86427f876058170b84dac6670ddf06d79402899d93ceea26f67"},
101
+ {filename: "data_batch_4.bin", sha256: "014e562d6e23c72197cc727519169a60359f5eccd8945ad5a09d710285ff4e48"},
102
+ {filename: "data_batch_5.bin", sha256: "755304fc0b379caeae8c14f0dac912fbc7d6cd469eb67a1029a08a39453a9add"},
103
+ ]
104
+ end
105
+
106
+ def test_list
107
+ [
108
+ {filename: "test_batch.bin", sha256: "8e2eb146ae340b09e24670f29cabc6326dba54da8789dab6768acf480273f65b"}
109
+ ]
110
+ end
111
+
112
+ def multiple_labels?
113
+ false
114
+ end
115
+ end
116
+ end
117
+ end
@@ -0,0 +1,41 @@
1
+ module TorchVision
2
+ module Datasets
3
+ class CIFAR100 < CIFAR10
4
+ # https://www.cs.toronto.edu/~kriz/cifar.html
5
+
6
+ private
7
+
8
+ def base_folder
9
+ "cifar-100-binary"
10
+ end
11
+
12
+ def url
13
+ "https://www.cs.toronto.edu/~kriz/cifar-100-binary.tar.gz"
14
+ end
15
+
16
+ def filename
17
+ "cifar-100-binary.tar.gz"
18
+ end
19
+
20
+ def tgz_sha256
21
+ "58a81ae192c23a4be8b1804d68e518ed807d710a4eb253b1f2a199162a40d8ec"
22
+ end
23
+
24
+ def train_list
25
+ [
26
+ {filename: "train.bin", sha256: "f31298fc616915fa142368359df1c4ca2ae984d6915ca468b998a5ec6aeebf29"}
27
+ ]
28
+ end
29
+
30
+ def test_list
31
+ [
32
+ {filename: "test.bin", sha256: "d8b1e6b7b3bee4020055f0699b111f60b1af1e262aeb93a0b659061746f8224a"}
33
+ ]
34
+ end
35
+
36
+ def multiple_labels?
37
+ true
38
+ end
39
+ end
40
+ end
41
+ end