torchvision 0.1.1 → 0.2.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +35 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +133 -7
  5. data/lib/torchvision.rb +39 -0
  6. data/lib/torchvision/datasets/cifar10.rb +117 -0
  7. data/lib/torchvision/datasets/cifar100.rb +41 -0
  8. data/lib/torchvision/datasets/dataset_folder.rb +91 -0
  9. data/lib/torchvision/datasets/fashion_mnist.rb +30 -0
  10. data/lib/torchvision/datasets/image_folder.rb +12 -0
  11. data/lib/torchvision/datasets/kmnist.rb +30 -0
  12. data/lib/torchvision/datasets/mnist.rb +47 -75
  13. data/lib/torchvision/datasets/vision_dataset.rb +67 -0
  14. data/lib/torchvision/models/alexnet.rb +42 -0
  15. data/lib/torchvision/models/basic_block.rb +46 -0
  16. data/lib/torchvision/models/bottleneck.rb +47 -0
  17. data/lib/torchvision/models/resnet.rb +129 -0
  18. data/lib/torchvision/models/resnet101.rb +9 -0
  19. data/lib/torchvision/models/resnet152.rb +9 -0
  20. data/lib/torchvision/models/resnet18.rb +9 -0
  21. data/lib/torchvision/models/resnet34.rb +9 -0
  22. data/lib/torchvision/models/resnet50.rb +9 -0
  23. data/lib/torchvision/models/resnext101_32x8d.rb +11 -0
  24. data/lib/torchvision/models/resnext50_32x4d.rb +11 -0
  25. data/lib/torchvision/models/vgg.rb +93 -0
  26. data/lib/torchvision/models/vgg11.rb +9 -0
  27. data/lib/torchvision/models/vgg11_bn.rb +9 -0
  28. data/lib/torchvision/models/vgg13.rb +9 -0
  29. data/lib/torchvision/models/vgg13_bn.rb +9 -0
  30. data/lib/torchvision/models/vgg16.rb +9 -0
  31. data/lib/torchvision/models/vgg16_bn.rb +9 -0
  32. data/lib/torchvision/models/vgg19.rb +9 -0
  33. data/lib/torchvision/models/vgg19_bn.rb +9 -0
  34. data/lib/torchvision/models/wide_resnet101_2.rb +10 -0
  35. data/lib/torchvision/models/wide_resnet50_2.rb +10 -0
  36. data/lib/torchvision/transforms/center_crop.rb +13 -0
  37. data/lib/torchvision/transforms/compose.rb +2 -2
  38. data/lib/torchvision/transforms/functional.rb +142 -8
  39. data/lib/torchvision/transforms/normalize.rb +2 -2
  40. data/lib/torchvision/transforms/random_horizontal_flip.rb +18 -0
  41. data/lib/torchvision/transforms/random_resized_crop.rb +70 -0
  42. data/lib/torchvision/transforms/random_vertical_flip.rb +18 -0
  43. data/lib/torchvision/transforms/resize.rb +13 -0
  44. data/lib/torchvision/transforms/to_tensor.rb +2 -2
  45. data/lib/torchvision/utils.rb +118 -0
  46. data/lib/torchvision/version.rb +1 -1
  47. metadata +51 -44
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: c6ae301007e3310dfc378a7f76d24bd71cc048be6156b32d9e0705412817c565
4
- data.tar.gz: 423aee54316683ddf9d3893e04f8836dd16ade1408f6c046f6d8f6f337a09a46
3
+ metadata.gz: d218068a8502ca9aa41ec3240043c55122c22e10944fb5d0fb402e95ccb48e99
4
+ data.tar.gz: 3d9288e590d4d9a570c7f6c501ffbd30ef968c58790e2beec42721672b34075f
5
5
  SHA512:
6
- metadata.gz: cf9c5a514c2ef42299161f97f077897c11b4708858018ffefcb282d6434003d49555ba54ea93ef0fb9e2f3cafdd7bd4395677e02bdfa5b1dcad4772aaeccc896
7
- data.tar.gz: 0e01311267e620be6dfbed77724baed42ad9c597a2b5452d762845597ae257932e76848e0069149598743d78dc08b0374b18897ce9a272640eed4a3cec21bbd9
6
+ metadata.gz: c841175a05fdace23e273413c863dcb2135fcab66414a6600864e16e12375c65d37a62fa2f03a433852f946324f40192ca5656d1f3e87b72ec0207e1f92e585f
7
+ data.tar.gz: a9e98d5929bed556b59ef29ad690668dac7bfe7b5f46f87e92d9e8dab352793ee659dcd8894343264c60b75bfd0cc07177db00bb0a5828410d879ebeb22bf860
data/CHANGELOG.md CHANGED
@@ -1,3 +1,38 @@
1
+ ## 0.2.2 (2021-05-23)
2
+
3
+ - Fixed error with ruby-vips 2.1.2
4
+
5
+ ## 0.2.1 (2021-03-14)
6
+
7
+ - Added `ImageFolder` and `DatasetFolder`
8
+ - Added `CenterCrop` and `RandomResizedCrop` transforms
9
+ - Added `crop` method
10
+
11
+ ## 0.2.0 (2021-03-11)
12
+
13
+ - Added `RandomHorizontalFlip`, `RandomVerticalFlip`, and `Resize` transforms
14
+ - Added `save_image` method
15
+ - Added `data` and `targets` methods to datasets
16
+ - Removed support for Ruby < 2.6
17
+
18
+ Breaking changes
19
+
20
+ - Added dependency on libvips
21
+ - MNIST datasets return images instead of tensors
22
+
23
+ ## 0.1.3 (2020-06-29)
24
+
25
+ - Added AlexNet model
26
+ - Added ResNet34, ResNet50, ResNet101, and ResNet152 models
27
+ - Added ResNeXt model
28
+ - Added VGG11, VGG13, VGG16, and VGG19 models
29
+ - Added Wide ResNet model
30
+
31
+ ## 0.1.2 (2020-04-29)
32
+
33
+ - Added CIFAR10, CIFAR100, FashionMNIST, and KMNIST datasets
34
+ - Added ResNet18 model
35
+
1
36
  ## 0.1.1 (2020-04-28)
2
37
 
3
38
  - Removed `mini_magick` for performance
data/LICENSE.txt CHANGED
@@ -1,7 +1,7 @@
1
1
  BSD 3-Clause License
2
2
 
3
- Copyright (c) Andrew Kane 2020,
4
3
  Copyright (c) Soumith Chintala 2016,
4
+ Copyright (c) Andrew Kane 2020-2021,
5
5
  All rights reserved.
6
6
 
7
7
  Redistribution and use in source and binary forms, with or without
data/README.md CHANGED
@@ -2,12 +2,16 @@
2
2
 
3
3
  :fire: Computer vision datasets, transforms, and models for Ruby
4
4
 
5
- This gem is currently experimental. There may be breaking changes between each release. Please report any issues you experience.
6
-
7
- [![Build Status](https://travis-ci.org/ankane/torchvision.svg?branch=master)](https://travis-ci.org/ankane/torchvision)
5
+ [![Build Status](https://github.com/ankane/torchvision/workflows/build/badge.svg?branch=master)](https://github.com/ankane/torchvision/actions)
8
6
 
9
7
  ## Installation
10
8
 
9
+ First, [install libvips](#libvips-installation). For Homebrew, use:
10
+
11
+ ```sh
12
+ brew install vips
13
+ ```
14
+
11
15
  Add this line to your application’s Gemfile:
12
16
 
13
17
  ```ruby
@@ -16,17 +20,139 @@ gem 'torchvision'
16
20
 
17
21
  ## Getting Started
18
22
 
19
- This library follows the [Python API](https://pytorch.org/docs/master/torchvision/). Many methods and options are missing at the moment. PRs welcome!
23
+ This library follows the [Python API](https://pytorch.org/docs/stable/torchvision/index.html). Many methods and options are missing at the moment. PRs welcome!
24
+
25
+ ## Examples
26
+
27
+ - [MNIST](https://github.com/ankane/torch.rb/tree/master/examples/mnist)
28
+ - [Transfer learning](https://github.com/ankane/torch.rb/tree/master/examples/transfer-learning)
29
+ - [Generative adversarial networks](https://github.com/ankane/torch.rb/tree/master/examples/gan)
20
30
 
21
31
  ## Datasets
22
32
 
23
- MNIST dataset
33
+ Load a dataset
34
+
35
+ ```ruby
36
+ TorchVision::Datasets::MNIST.new("./data", train: true, download: true)
37
+ ```
38
+
39
+ Supported datasets are:
40
+
41
+ - CIFAR10
42
+ - CIFAR100
43
+ - FashionMNIST
44
+ - KMNIST
45
+ - MNIST
46
+
47
+ ## Transforms
48
+
49
+ ```ruby
50
+ TorchVision::Transforms::Compose.new([
51
+ TorchVision::Transforms::ToTensor.new,
52
+ TorchVision::Transforms::Normalize.new([0.1307], [0.3081])
53
+ ])
54
+ ```
55
+
56
+ Supported transforms are:
57
+
58
+ - CenterCrop
59
+ - Compose
60
+ - Normalize
61
+ - RandomHorizontalFlip
62
+ - RandomResizedCrop
63
+ - RandomVerticalFlip
64
+ - Resize
65
+ - ToTensor
66
+
67
+ ## Models
68
+
69
+ - [AlexNet](#alexnet)
70
+ - [ResNet](#resnet)
71
+ - [ResNeXt](#resnext)
72
+ - [VGG](#vgg)
73
+ - [Wide ResNet](#wide-resnet)
74
+
75
+ ### AlexNet
24
76
 
25
77
  ```ruby
26
- trainset = TorchVision::Datasets::MNIST.new("./data", train: true, download: true)
27
- trainset.size
78
+ TorchVision::Models::AlexNet.new
28
79
  ```
29
80
 
81
+ ### ResNet
82
+
83
+ ```ruby
84
+ TorchVision::Models::ResNet18.new
85
+ TorchVision::Models::ResNet34.new
86
+ TorchVision::Models::ResNet50.new
87
+ TorchVision::Models::ResNet101.new
88
+ TorchVision::Models::ResNet152.new
89
+ ```
90
+
91
+ ### ResNeXt
92
+
93
+ ```ruby
94
+ TorchVision::Models::ResNext50_32x4d.new
95
+ TorchVision::Models::ResNext101_32x8d.new
96
+ ```
97
+
98
+ ### VGG
99
+
100
+ ```ruby
101
+ TorchVision::Models::VGG11.new
102
+ TorchVision::Models::VGG11BN.new
103
+ TorchVision::Models::VGG13.new
104
+ TorchVision::Models::VGG13BN.new
105
+ TorchVision::Models::VGG16.new
106
+ TorchVision::Models::VGG16BN.new
107
+ TorchVision::Models::VGG19.new
108
+ TorchVision::Models::VGG19BN.new
109
+ ```
110
+
111
+ ### Wide ResNet
112
+
113
+ ```ruby
114
+ TorchVision::Models::WideResNet50_2.new
115
+ TorchVision::Models::WideResNet101_2.new
116
+ ```
117
+
118
+ ## Pretrained Models
119
+
120
+ You can download pretrained models with [this script](pretrained.py):
121
+
122
+ ```sh
123
+ pip install torchvision
124
+ python pretrained.py
125
+ ```
126
+
127
+ Then load them with:
128
+
129
+ ```ruby
130
+ net = TorchVision::Models::ResNet18.new
131
+ net.load_state_dict(Torch.load("net.pth"))
132
+ ```
133
+
134
+ ## libvips Installation
135
+
136
+ ### Linux
137
+
138
+ Check your package manager. For Ubuntu, use:
139
+
140
+ ```sh
141
+ sudo apt install libvips
142
+ ```
143
+
144
+ You can also [build from source](https://libvips.github.io/libvips/install.html).
145
+
146
+ ### Mac
147
+
148
+ ```sh
149
+ brew install vips
150
+ ```
151
+
152
+ ### Windows
153
+
154
+ Check out [the options](https://libvips.github.io/libvips/install.html).
155
+
30
156
  ## Disclaimer
31
157
 
32
158
  This library downloads and prepares public datasets. We don’t host any datasets. Be sure to adhere to the license for each dataset.
data/lib/torchvision.rb CHANGED
@@ -1,23 +1,62 @@
1
1
  # dependencies
2
2
  require "numo/narray"
3
+ require "vips"
3
4
  require "torch"
4
5
 
5
6
  # stdlib
6
7
  require "digest"
7
8
  require "fileutils"
8
9
  require "net/http"
10
+ require "rubygems/package"
9
11
  require "tmpdir"
10
12
 
11
13
  # modules
14
+ require "torchvision/utils"
12
15
  require "torchvision/version"
13
16
 
14
17
  # datasets
18
+ require "torchvision/datasets/vision_dataset"
19
+ require "torchvision/datasets/dataset_folder"
20
+ require "torchvision/datasets/image_folder"
21
+ require "torchvision/datasets/cifar10"
22
+ require "torchvision/datasets/cifar100"
15
23
  require "torchvision/datasets/mnist"
24
+ require "torchvision/datasets/fashion_mnist"
25
+ require "torchvision/datasets/kmnist"
26
+
27
+ # models
28
+ require "torchvision/models/alexnet"
29
+ require "torchvision/models/basic_block"
30
+ require "torchvision/models/bottleneck"
31
+ require "torchvision/models/resnet"
32
+ require "torchvision/models/resnet18"
33
+ require "torchvision/models/resnet34"
34
+ require "torchvision/models/resnet50"
35
+ require "torchvision/models/resnet101"
36
+ require "torchvision/models/resnet152"
37
+ require "torchvision/models/resnext50_32x4d"
38
+ require "torchvision/models/resnext101_32x8d"
39
+ require "torchvision/models/vgg"
40
+ require "torchvision/models/vgg11"
41
+ require "torchvision/models/vgg11_bn"
42
+ require "torchvision/models/vgg13"
43
+ require "torchvision/models/vgg13_bn"
44
+ require "torchvision/models/vgg16"
45
+ require "torchvision/models/vgg16_bn"
46
+ require "torchvision/models/vgg19"
47
+ require "torchvision/models/vgg19_bn"
48
+ require "torchvision/models/wide_resnet50_2"
49
+ require "torchvision/models/wide_resnet101_2"
16
50
 
17
51
  # transforms
52
+ require "torchvision/transforms/center_crop"
18
53
  require "torchvision/transforms/compose"
19
54
  require "torchvision/transforms/functional"
20
55
  require "torchvision/transforms/normalize"
56
+ require "torchvision/transforms/random_horizontal_flip"
57
+ require "torchvision/transforms/random_resized_crop"
58
+ require "torchvision/transforms/random_vertical_flip"
59
+ require "torchvision/transforms/resize"
21
60
  require "torchvision/transforms/to_tensor"
22
61
 
23
62
  module TorchVision
@@ -0,0 +1,117 @@
1
module TorchVision
  module Datasets
    # CIFAR-10 dataset loaded from the binary distribution.
    # https://www.cs.toronto.edu/~kriz/cifar.html
    #
    # Subclasses (e.g. CIFAR100) customize the archive location, file lists,
    # checksums and record layout by overriding the private helpers below.
    class CIFAR10 < VisionDataset
      # @param root [String] directory that holds (or will receive) the dataset
      # @param train [Boolean] load the training batches when true, the test batch otherwise
      # @param download [Boolean] fetch and extract the archive into root first
      # @param transform [#call, nil] applied to each image returned by #[]
      # @param target_transform [#call, nil] applied to each label returned by #[]
      # @raise [Error] if any expected batch file (train AND test) fails check_integrity
      def initialize(root, train: true, download: false, transform: nil, target_transform: nil)
        super(root, transform: transform, target_transform: target_transform)
        @train = train

        # `download` is shadowed by the keyword argument, hence the explicit receiver.
        self.download if download

        unless _check_integrity
          raise Error, "Dataset not found or corrupted. You can use download: true to download it"
        end

        downloaded_list = @train ? train_list : test_list

        # Raw byte buffers accumulated across all batch files.
        @data = String.new
        @targets = String.new

        downloaded_list.each do |file|
          file_path = File.join(@root, base_folder, file[:filename])
          File.open(file_path, "rb") do |f|
            until f.eof?
              # When a record has two label bytes, discard the first and keep
              # the second as the target.
              f.read(1) if multiple_labels?
              @targets << f.read(1)
              # 3072 = 3 * 32 * 32 pixel bytes per record.
              @data << f.read(3072)
            end
          end
        end

        @targets = @targets.unpack("C*")
        # Pixels are stored as (3, 32, 32) per record; transpose to (N, 32, 32, 3).
        # TODO switch @targets.size to -1 when Numo supports it
        @data = Numo::UInt8.from_binary(@data).reshape(@targets.size, 3, 32, 32)
        @data = @data.transpose(0, 2, 3, 1)
      end

      # Number of samples in the loaded split.
      def size
        @data.shape[0]
      end

      # Returns the sample at +index+ as [image, target].
      # The image is built by Utils.image_from_array from the (32, 32, 3) slice
      # and then passed through the transform, if any. The target is the
      # Integer label, passed through the target transform, if any.
      def [](index)
        # TODO remove trues when Numo supports it
        img, target = @data[index, true, true, true], @targets[index]

        img = Utils.image_from_array(img)

        img = @transform.call(img) if @transform

        target = @target_transform.call(target) if @target_transform

        [img, target]
      end

      # True only when every train and test batch file passes check_integrity
      # against its recorded SHA-256, regardless of which split is loaded.
      def _check_integrity
        root = @root
        (train_list + test_list).each do |fentry|
          fpath = File.join(root, base_folder, fentry[:filename])
          return false unless check_integrity(fpath, fentry[:sha256])
        end
        true
      end

      # Downloads the archive into root and extracts it there, unless every
      # expected batch file already passes the integrity check.
      def download
        if _check_integrity
          puts "Files already downloaded and verified"
          return
        end

        download_file(url, download_root: @root, filename: filename, sha256: tgz_sha256)

        path = File.join(@root, filename)
        File.open(path, "rb") do |io|
          Gem::Package.new("").extract_tar_gz(io, @root)
        end
      end

      private

      # Folder the archive extracts to, relative to root.
      def base_folder
        "cifar-10-batches-bin"
      end

      def url
        "https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz"
      end

      def filename
        "cifar-10-binary.tar.gz"
      end

      # Expected SHA-256 of the downloaded archive.
      def tgz_sha256
        "c4a38c50a1bc5f3a1c5537f2155ab9d68f9f25eb1ed8d9ddda3db29a59bca1dd"
      end

      def train_list
        [
          {filename: "data_batch_1.bin", sha256: "cee916563c9f80d84e3cc88e17fdc0941787f1244f00a67874d45b261883ada5"},
          {filename: "data_batch_2.bin", sha256: "a591ca11fa1708a91ee40f54b3da4784ccd871ecf2137de63f51ada8b3fa57ed"},
          {filename: "data_batch_3.bin", sha256: "bbe8596564c0f86427f876058170b84dac6670ddf06d79402899d93ceea26f67"},
          {filename: "data_batch_4.bin", sha256: "014e562d6e23c72197cc727519169a60359f5eccd8945ad5a09d710285ff4e48"},
          {filename: "data_batch_5.bin", sha256: "755304fc0b379caeae8c14f0dac912fbc7d6cd469eb67a1029a08a39453a9add"},
        ]
      end

      def test_list
        [
          {filename: "test_batch.bin", sha256: "8e2eb146ae340b09e24670f29cabc6326dba54da8789dab6768acf480273f65b"}
        ]
      end

      # CIFAR-10 records carry a single label byte.
      def multiple_labels?
        false
      end
    end
  end
end
@@ -0,0 +1,41 @@
1
module TorchVision
  module Datasets
    # CIFAR-100 dataset loaded from the binary distribution.
    # https://www.cs.toronto.edu/~kriz/cifar.html
    #
    # Loading, indexing and downloading are all inherited from CIFAR10; this
    # class only swaps in the CIFAR-100 archive, file names, checksums and
    # record layout through private overrides.
    class CIFAR100 < CIFAR10
      # Folder the archive extracts to, relative to root.
      private def base_folder
        "cifar-100-binary"
      end

      # Remote location of the binary archive.
      private def url
        "https://www.cs.toronto.edu/~kriz/cifar-100-binary.tar.gz"
      end

      # Local file name used for the downloaded archive.
      private def filename
        "cifar-100-binary.tar.gz"
      end

      # Expected SHA-256 of the downloaded archive.
      private def tgz_sha256
        "58a81ae192c23a4be8b1804d68e518ed807d710a4eb253b1f2a199162a40d8ec"
      end

      # The training split lives in a single file.
      private def train_list
        [{filename: "train.bin", sha256: "f31298fc616915fa142368359df1c4ca2ae984d6915ca468b998a5ec6aeebf29"}]
      end

      # The test split lives in a single file.
      private def test_list
        [{filename: "test.bin", sha256: "d8b1e6b7b3bee4020055f0699b111f60b1af1e262aeb93a0b659061746f8224a"}]
      end

      # Returning true makes the CIFAR10 loader read and discard one byte
      # before each label byte. Presumably the skipped byte is the coarse
      # label and the kept one the fine label; confirm against the dataset docs.
      private def multiple_labels?
        true
      end
    end
  end
end