torchvision 0.1.0 → 0.2.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +35 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +133 -5
  5. data/lib/torchvision.rb +40 -1
  6. data/lib/torchvision/datasets/cifar10.rb +117 -0
  7. data/lib/torchvision/datasets/cifar100.rb +41 -0
  8. data/lib/torchvision/datasets/dataset_folder.rb +91 -0
  9. data/lib/torchvision/datasets/fashion_mnist.rb +30 -0
  10. data/lib/torchvision/datasets/image_folder.rb +12 -0
  11. data/lib/torchvision/datasets/kmnist.rb +30 -0
  12. data/lib/torchvision/datasets/mnist.rb +47 -76
  13. data/lib/torchvision/datasets/vision_dataset.rb +67 -0
  14. data/lib/torchvision/models/alexnet.rb +42 -0
  15. data/lib/torchvision/models/basic_block.rb +46 -0
  16. data/lib/torchvision/models/bottleneck.rb +47 -0
  17. data/lib/torchvision/models/resnet.rb +129 -0
  18. data/lib/torchvision/models/resnet101.rb +9 -0
  19. data/lib/torchvision/models/resnet152.rb +9 -0
  20. data/lib/torchvision/models/resnet18.rb +9 -0
  21. data/lib/torchvision/models/resnet34.rb +9 -0
  22. data/lib/torchvision/models/resnet50.rb +9 -0
  23. data/lib/torchvision/models/resnext101_32x8d.rb +11 -0
  24. data/lib/torchvision/models/resnext50_32x4d.rb +11 -0
  25. data/lib/torchvision/models/vgg.rb +93 -0
  26. data/lib/torchvision/models/vgg11.rb +9 -0
  27. data/lib/torchvision/models/vgg11_bn.rb +9 -0
  28. data/lib/torchvision/models/vgg13.rb +9 -0
  29. data/lib/torchvision/models/vgg13_bn.rb +9 -0
  30. data/lib/torchvision/models/vgg16.rb +9 -0
  31. data/lib/torchvision/models/vgg16_bn.rb +9 -0
  32. data/lib/torchvision/models/vgg19.rb +9 -0
  33. data/lib/torchvision/models/vgg19_bn.rb +9 -0
  34. data/lib/torchvision/models/wide_resnet101_2.rb +10 -0
  35. data/lib/torchvision/models/wide_resnet50_2.rb +10 -0
  36. data/lib/torchvision/transforms/center_crop.rb +13 -0
  37. data/lib/torchvision/transforms/compose.rb +2 -2
  38. data/lib/torchvision/transforms/functional.rb +142 -7
  39. data/lib/torchvision/transforms/normalize.rb +2 -2
  40. data/lib/torchvision/transforms/random_horizontal_flip.rb +18 -0
  41. data/lib/torchvision/transforms/random_resized_crop.rb +70 -0
  42. data/lib/torchvision/transforms/random_vertical_flip.rb +18 -0
  43. data/lib/torchvision/transforms/resize.rb +13 -0
  44. data/lib/torchvision/transforms/to_tensor.rb +2 -2
  45. data/lib/torchvision/utils.rb +120 -0
  46. data/lib/torchvision/version.rb +1 -1
  47. metadata +50 -57
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 19fff61bb461e5fbf850702bf485a442b2013807ac28549bc427e5f3d8c7472b
4
- data.tar.gz: bc8004d2ca26e9022f2fa2b1663277bcedf701729f59aae248058a26e605a5ad
3
+ metadata.gz: bbb87c59c0f081c0de57ccdd62e30bfc551e1cb69523e4ffd498c997e1a2d8b3
4
+ data.tar.gz: 890da113706e659d57194980c5c9262075beb8398a75da2997c0812b70abe308
5
5
  SHA512:
6
- metadata.gz: fbd3d7292efa6ee2fd2c0ff8cb85659d37a19761b4d93a9a4923a9990d7400c849738913db12720e7d232d6fdb180c16f06da0ecc3601a922bf0036beb0b44bd
7
- data.tar.gz: '09b86d6b01f25d43ac65d9c3d0509b3488ed2108d57090553ae46526f841551cbdd16fefce1cb15ea7f326d946e3abb26d777c3c45e23538f8fb7753fdb6fec9'
6
+ metadata.gz: 3445b62b7824ae16205034881d37c48ac4c70d7e5677014755ae5600632f9ce45168f41b0d3e98c8104eb8337e1566db4df3e0ad5ace5e6a46a5d213d01b6c8d
7
+ data.tar.gz: 93f22c385586ff8a010880676806f6bc9ba2f614c4c14886235a300ea6e2abce0f80c260e255644e1b4d24e6ecddfd21830dde2960a92a9492239e69622d4548
data/CHANGELOG.md CHANGED
@@ -1,3 +1,38 @@
1
+ ## 0.2.1 (2021-03-14)
2
+
3
+ - Added `ImageFolder` and `DatasetFolder`
4
+ - Added `CenterCrop` and `RandomResizedCrop` transforms
5
+ - Added `crop` method
6
+
7
+ ## 0.2.0 (2021-03-11)
8
+
9
+ - Added `RandomHorizontalFlip`, `RandomVerticalFlip`, and `Resize` transforms
10
+ - Added `save_image` method
11
+ - Added `data` and `targets` methods to datasets
12
+ - Removed support for Ruby < 2.6
13
+
14
+ Breaking changes
15
+
16
+ - Added dependency on libvips
17
+ - MNIST datasets return images instead of tensors
18
+
19
+ ## 0.1.3 (2020-06-29)
20
+
21
+ - Added AlexNet model
22
+ - Added ResNet34, ResNet50, ResNet101, and ResNet152 models
23
+ - Added ResNeXt model
24
+ - Added VGG11, VGG13, VGG16, and VGG19 models
25
+ - Added Wide ResNet model
26
+
27
+ ## 0.1.2 (2020-04-29)
28
+
29
+ - Added CIFAR10, CIFAR100, FashionMNIST, and KMNIST datasets
30
+ - Added ResNet18 model
31
+
32
+ ## 0.1.1 (2020-04-28)
33
+
34
+ - Removed `mini_magick` for performance
35
+
1
36
  ## 0.1.0 (2020-04-27)
2
37
 
3
38
  - First release
data/LICENSE.txt CHANGED
@@ -1,7 +1,7 @@
1
1
  BSD 3-Clause License
2
2
 
3
- Copyright (c) Andrew Kane 2020,
4
3
  Copyright (c) Soumith Chintala 2016,
4
+ Copyright (c) Andrew Kane 2020-2021,
5
5
  All rights reserved.
6
6
 
7
7
  Redistribution and use in source and binary forms, with or without
data/README.md CHANGED
@@ -2,10 +2,16 @@
2
2
 
3
3
  :fire: Computer vision datasets, transforms, and models for Ruby
4
4
 
5
- This gem is currently experimental. There may be breaking changes between each release. Please report any issues you experience.
5
+ [![Build Status](https://github.com/ankane/torchvision/workflows/build/badge.svg?branch=master)](https://github.com/ankane/torchvision/actions)
6
6
 
7
7
  ## Installation
8
8
 
9
+ First, [install libvips](#libvips-installation). For Homebrew, use:
10
+
11
+ ```sh
12
+ brew install vips
13
+ ```
14
+
9
15
  Add this line to your application’s Gemfile:
10
16
 
11
17
  ```ruby
@@ -14,17 +20,139 @@ gem 'torchvision'
14
20
 
15
21
  ## Getting Started
16
22
 
17
- This library follows the [Python API](https://pytorch.org/docs/master/torchvision/). Many methods and options are missing at the moment. PRs welcome!
23
+ This library follows the [Python API](https://pytorch.org/docs/stable/torchvision/index.html). Many methods and options are missing at the moment. PRs welcome!
24
+
25
+ ## Examples
26
+
27
+ - [MNIST](https://github.com/ankane/torch.rb/tree/master/examples/mnist)
28
+ - [Transfer learning](https://github.com/ankane/torch.rb/tree/master/examples/transfer-learning)
29
+ - [Generative adversarial networks](https://github.com/ankane/torch.rb/tree/master/examples/gan)
18
30
 
19
31
  ## Datasets
20
32
 
21
- MNIST dataset
33
+ Load a dataset
34
+
35
+ ```ruby
36
+ TorchVision::Datasets::MNIST.new("./data", train: true, download: true)
37
+ ```
38
+
39
+ Supported datasets are:
40
+
41
+ - CIFAR10
42
+ - CIFAR100
43
+ - FashionMNIST
44
+ - KMNIST
45
+ - MNIST
46
+
47
+ ## Transforms
48
+
49
+ ```ruby
50
+ TorchVision::Transforms::Compose.new([
51
+ TorchVision::Transforms::ToTensor.new,
52
+ TorchVision::Transforms::Normalize.new([0.1307], [0.3081])
53
+ ])
54
+ ```
55
+
56
+ Supported transforms are:
57
+
58
+ - CenterCrop
59
+ - Compose
60
+ - Normalize
61
+ - RandomHorizontalFlip
62
+ - RandomResizedCrop
63
+ - RandomVerticalFlip
64
+ - Resize
65
+ - ToTensor
66
+
67
+ ## Models
68
+
69
+ - [AlexNet](#alexnet)
70
+ - [ResNet](#resnet)
71
+ - [ResNeXt](#resnext)
72
+ - [VGG](#vgg)
73
+ - [Wide ResNet](#wide-resnet)
74
+
75
+ ### AlexNet
76
+
77
+ ```ruby
78
+ TorchVision::Models::AlexNet.new
79
+ ```
80
+
81
+ ### ResNet
82
+
83
+ ```ruby
84
+ TorchVision::Models::ResNet18.new
85
+ TorchVision::Models::ResNet34.new
86
+ TorchVision::Models::ResNet50.new
87
+ TorchVision::Models::ResNet101.new
88
+ TorchVision::Models::ResNet152.new
89
+ ```
90
+
91
+ ### ResNeXt
92
+
93
+ ```ruby
94
+ TorchVision::Models::ResNext50_32x4d.new
95
+ TorchVision::Models::ResNext101_32x8d.new
96
+ ```
97
+
98
+ ### VGG
99
+
100
+ ```ruby
101
+ TorchVision::Models::VGG11.new
102
+ TorchVision::Models::VGG11BN.new
103
+ TorchVision::Models::VGG13.new
104
+ TorchVision::Models::VGG13BN.new
105
+ TorchVision::Models::VGG16.new
106
+ TorchVision::Models::VGG16BN.new
107
+ TorchVision::Models::VGG19.new
108
+ TorchVision::Models::VGG19BN.new
109
+ ```
110
+
111
+ ### Wide ResNet
22
112
 
23
113
  ```ruby
24
- trainset = TorchVision::Datasets::MNIST.new("./data", train: true, download: true)
25
- trainset.size
114
+ TorchVision::Models::WideResNet50_2.new
115
+ TorchVision::Models::WideResNet101_2.new
26
116
  ```
27
117
 
118
+ ## Pretrained Models
119
+
120
+ You can download pretrained models with [this script](pretrained.py)
121
+
122
+ ```sh
123
+ pip install torchvision
124
+ python pretrained.py
125
+ ```
126
+
127
+ And load them
128
+
129
+ ```ruby
130
+ net = TorchVision::Models::ResNet18.new
131
+ net.load_state_dict(Torch.load("net.pth"))
132
+ ```
133
+
134
+ ## libvips Installation
135
+
136
+ ### Linux
137
+
138
+ Check your package manager. For Ubuntu, use:
139
+
140
+ ```sh
141
+ sudo apt install libvips
142
+ ```
143
+
144
+ You can also [build from source](https://libvips.github.io/libvips/install.html).
145
+
146
+ ### Mac
147
+
148
+ ```sh
149
+ brew install vips
150
+ ```
151
+
152
+ ### Windows
153
+
154
+ Check out [the options](https://libvips.github.io/libvips/install.html).
155
+
28
156
  ## Disclaimer
29
157
 
30
158
  This library downloads and prepares public datasets. We don’t host any datasets. Be sure to adhere to the license for each dataset.
data/lib/torchvision.rb CHANGED
@@ -1,23 +1,62 @@
1
1
  # dependencies
2
- require "mini_magick"
3
2
  require "numo/narray"
3
+ require "vips"
4
4
  require "torch"
5
5
 
6
6
  # stdlib
7
7
  require "digest"
8
8
  require "fileutils"
9
9
  require "net/http"
10
+ require "rubygems/package"
11
+ require "tmpdir"
10
12
 
11
13
  # modules
14
+ require "torchvision/utils"
12
15
  require "torchvision/version"
13
16
 
14
17
  # datasets
18
+ require "torchvision/datasets/vision_dataset"
19
+ require "torchvision/datasets/dataset_folder"
20
+ require "torchvision/datasets/image_folder"
21
+ require "torchvision/datasets/cifar10"
22
+ require "torchvision/datasets/cifar100"
15
23
  require "torchvision/datasets/mnist"
24
+ require "torchvision/datasets/fashion_mnist"
25
+ require "torchvision/datasets/kmnist"
26
+
27
+ # models
28
+ require "torchvision/models/alexnet"
29
+ require "torchvision/models/basic_block"
30
+ require "torchvision/models/bottleneck"
31
+ require "torchvision/models/resnet"
32
+ require "torchvision/models/resnet18"
33
+ require "torchvision/models/resnet34"
34
+ require "torchvision/models/resnet50"
35
+ require "torchvision/models/resnet101"
36
+ require "torchvision/models/resnet152"
37
+ require "torchvision/models/resnext50_32x4d"
38
+ require "torchvision/models/resnext101_32x8d"
39
+ require "torchvision/models/vgg"
40
+ require "torchvision/models/vgg11"
41
+ require "torchvision/models/vgg11_bn"
42
+ require "torchvision/models/vgg13"
43
+ require "torchvision/models/vgg13_bn"
44
+ require "torchvision/models/vgg16"
45
+ require "torchvision/models/vgg16_bn"
46
+ require "torchvision/models/vgg19"
47
+ require "torchvision/models/vgg19_bn"
48
+ require "torchvision/models/wide_resnet50_2"
49
+ require "torchvision/models/wide_resnet101_2"
16
50
 
17
51
  # transforms
52
+ require "torchvision/transforms/center_crop"
18
53
  require "torchvision/transforms/compose"
19
54
  require "torchvision/transforms/functional"
20
55
  require "torchvision/transforms/normalize"
56
+ require "torchvision/transforms/random_horizontal_flip"
57
+ require "torchvision/transforms/random_resized_crop"
58
+ require "torchvision/transforms/random_vertical_flip"
59
+ require "torchvision/transforms/resize"
21
60
  require "torchvision/transforms/to_tensor"
22
61
 
23
62
  module TorchVision
@@ -0,0 +1,117 @@
1
module TorchVision
  module Datasets
    # CIFAR-10 dataset read from the binary distribution.
    #
    # https://www.cs.toronto.edu/~kriz/cifar.html
    #
    # Each record in a batch file is read as:
    #   - one extra byte that is discarded when #multiple_labels? is true,
    #   - one target byte,
    #   - 3072 bytes of pixel data.
    # Pixel data is reshaped to (N, 3, 32, 32) and then transposed to
    # (N, 32, 32, 3) before being handed to Utils.image_from_array.
    class CIFAR10 < VisionDataset
      # @param root [String] directory that holds (or will receive) the dataset
      # @param train [Boolean] load the training batches when true, the test batch otherwise
      # @param download [Boolean] fetch and extract the archive before loading
      # @param transform [#call, nil] applied to each image returned by #[]
      # @param target_transform [#call, nil] applied to each target returned by #[]
      # @raise [Error] when a batch file is missing or fails its checksum
      def initialize(root, train: true, download: false, transform: nil, target_transform: nil)
        super(root, transform: transform, target_transform: target_transform)
        @train = train

        # the `download` keyword shadows the method, so an explicit receiver is required
        self.download if download

        unless _check_integrity
          raise Error, "Dataset not found or corrupted. You can use download: true to download it"
        end

        downloaded_list = @train ? train_list : test_list
        skip_extra_byte = multiple_labels?

        data = String.new
        targets = String.new

        downloaded_list.each do |file|
          file_path = File.join(@root, base_folder, file[:filename])
          File.open(file_path, "rb") do |f|
            until f.eof?
              # subclasses that return true from multiple_labels? have one
              # extra leading byte per record that is not used as the target
              f.read(1) if skip_extra_byte
              targets << f.read(1)
              data << f.read(3072)
            end
          end
        end

        @targets = targets.unpack("C*")
        # TODO switch i to -1 when Numo supports it
        @data = Numo::UInt8.from_binary(data).reshape(@targets.size, 3, 32, 32)
        # reorder (N, 3, 32, 32) into (N, 32, 32, 3)
        @data = @data.transpose(0, 2, 3, 1)
      end

      # @return [Integer] number of loaded samples
      def size
        @data.shape[0]
      end

      # @param index [Integer] sample index
      # @return [Array] two-element [image, target], each passed through its transform if one was given
      def [](index)
        # TODO remove trues when Numo supports it
        img, target = @data[index, true, true, true], @targets[index]

        img = Utils.image_from_array(img)

        img = @transform.call(img) if @transform

        target = @target_transform.call(target) if @target_transform

        [img, target]
      end

      # Verifies every train and test batch file against its recorded SHA-256.
      #
      # @return [Boolean] true only if all files pass check_integrity
      def _check_integrity
        (train_list + test_list).all? do |fentry|
          check_integrity(File.join(@root, base_folder, fentry[:filename]), fentry[:sha256])
        end
      end

      # Fetches the archive into root via download_file and extracts it there.
      # Does nothing (besides a notice) when the batch files already verify.
      def download
        if _check_integrity
          puts "Files already downloaded and verified"
          return
        end

        download_file(url, download_root: @root, filename: filename, sha256: tgz_sha256)

        path = File.join(@root, filename)
        File.open(path, "rb") do |io|
          Gem::Package.new("").extract_tar_gz(io, @root)
        end
      end

      private

      # subdirectory of root that holds the batch files
      def base_folder
        "cifar-10-batches-bin"
      end

      def url
        "https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz"
      end

      def filename
        "cifar-10-binary.tar.gz"
      end

      def tgz_sha256
        "c4a38c50a1bc5f3a1c5537f2155ab9d68f9f25eb1ed8d9ddda3db29a59bca1dd"
      end

      def train_list
        [
          {filename: "data_batch_1.bin", sha256: "cee916563c9f80d84e3cc88e17fdc0941787f1244f00a67874d45b261883ada5"},
          {filename: "data_batch_2.bin", sha256: "a591ca11fa1708a91ee40f54b3da4784ccd871ecf2137de63f51ada8b3fa57ed"},
          {filename: "data_batch_3.bin", sha256: "bbe8596564c0f86427f876058170b84dac6670ddf06d79402899d93ceea26f67"},
          {filename: "data_batch_4.bin", sha256: "014e562d6e23c72197cc727519169a60359f5eccd8945ad5a09d710285ff4e48"},
          {filename: "data_batch_5.bin", sha256: "755304fc0b379caeae8c14f0dac912fbc7d6cd469eb67a1029a08a39453a9add"},
        ]
      end

      def test_list
        [
          {filename: "test_batch.bin", sha256: "8e2eb146ae340b09e24670f29cabc6326dba54da8789dab6768acf480273f65b"}
        ]
      end

      # when true, one byte per record is skipped before the target byte
      def multiple_labels?
        false
      end
    end
  end
end
@@ -0,0 +1,41 @@
1
module TorchVision
  module Datasets
    # CIFAR-100 dataset. Reuses all of CIFAR10's loading logic and only swaps
    # the archive location, file names, checksums, and record layout flag.
    #
    # https://www.cs.toronto.edu/~kriz/cifar.html
    class CIFAR100 < CIFAR10
      # subdirectory of root that holds train.bin / test.bin
      private def base_folder
        "cifar-100-binary"
      end

      # source of the archive passed to download_file
      private def url
        "https://www.cs.toronto.edu/~kriz/cifar-100-binary.tar.gz"
      end

      # name the archive is saved under inside root
      private def filename
        "cifar-100-binary.tar.gz"
      end

      # expected SHA-256 of the downloaded archive
      private def tgz_sha256
        "58a81ae192c23a4be8b1804d68e518ed807d710a4eb253b1f2a199162a40d8ec"
      end

      # a single file holds every training record
      private def train_list
        [{filename: "train.bin", sha256: "f31298fc616915fa142368359df1c4ca2ae984d6915ca468b998a5ec6aeebf29"}]
      end

      # a single file holds every test record
      private def test_list
        [{filename: "test.bin", sha256: "d8b1e6b7b3bee4020055f0699b111f60b1af1e262aeb93a0b659061746f8224a"}]
      end

      # each record starts with an extra byte that CIFAR10's loader skips
      # before reading the target byte
      private def multiple_labels?
        true
      end
    end
  end
end