torchvision 0.1.1 → 0.2.2

Files changed (47)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +35 -0
  3. data/LICENSE.txt +1 -1
  4. data/README.md +133 -7
  5. data/lib/torchvision.rb +39 -0
  6. data/lib/torchvision/datasets/cifar10.rb +117 -0
  7. data/lib/torchvision/datasets/cifar100.rb +41 -0
  8. data/lib/torchvision/datasets/dataset_folder.rb +91 -0
  9. data/lib/torchvision/datasets/fashion_mnist.rb +30 -0
  10. data/lib/torchvision/datasets/image_folder.rb +12 -0
  11. data/lib/torchvision/datasets/kmnist.rb +30 -0
  12. data/lib/torchvision/datasets/mnist.rb +47 -75
  13. data/lib/torchvision/datasets/vision_dataset.rb +67 -0
  14. data/lib/torchvision/models/alexnet.rb +42 -0
  15. data/lib/torchvision/models/basic_block.rb +46 -0
  16. data/lib/torchvision/models/bottleneck.rb +47 -0
  17. data/lib/torchvision/models/resnet.rb +129 -0
  18. data/lib/torchvision/models/resnet101.rb +9 -0
  19. data/lib/torchvision/models/resnet152.rb +9 -0
  20. data/lib/torchvision/models/resnet18.rb +9 -0
  21. data/lib/torchvision/models/resnet34.rb +9 -0
  22. data/lib/torchvision/models/resnet50.rb +9 -0
  23. data/lib/torchvision/models/resnext101_32x8d.rb +11 -0
  24. data/lib/torchvision/models/resnext50_32x4d.rb +11 -0
  25. data/lib/torchvision/models/vgg.rb +93 -0
  26. data/lib/torchvision/models/vgg11.rb +9 -0
  27. data/lib/torchvision/models/vgg11_bn.rb +9 -0
  28. data/lib/torchvision/models/vgg13.rb +9 -0
  29. data/lib/torchvision/models/vgg13_bn.rb +9 -0
  30. data/lib/torchvision/models/vgg16.rb +9 -0
  31. data/lib/torchvision/models/vgg16_bn.rb +9 -0
  32. data/lib/torchvision/models/vgg19.rb +9 -0
  33. data/lib/torchvision/models/vgg19_bn.rb +9 -0
  34. data/lib/torchvision/models/wide_resnet101_2.rb +10 -0
  35. data/lib/torchvision/models/wide_resnet50_2.rb +10 -0
  36. data/lib/torchvision/transforms/center_crop.rb +13 -0
  37. data/lib/torchvision/transforms/compose.rb +2 -2
  38. data/lib/torchvision/transforms/functional.rb +142 -8
  39. data/lib/torchvision/transforms/normalize.rb +2 -2
  40. data/lib/torchvision/transforms/random_horizontal_flip.rb +18 -0
  41. data/lib/torchvision/transforms/random_resized_crop.rb +70 -0
  42. data/lib/torchvision/transforms/random_vertical_flip.rb +18 -0
  43. data/lib/torchvision/transforms/resize.rb +13 -0
  44. data/lib/torchvision/transforms/to_tensor.rb +2 -2
  45. data/lib/torchvision/utils.rb +118 -0
  46. data/lib/torchvision/version.rb +1 -1
  47. metadata +51 -44
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz: c6ae301007e3310dfc378a7f76d24bd71cc048be6156b32d9e0705412817c565
-  data.tar.gz: 423aee54316683ddf9d3893e04f8836dd16ade1408f6c046f6d8f6f337a09a46
+  metadata.gz: d218068a8502ca9aa41ec3240043c55122c22e10944fb5d0fb402e95ccb48e99
+  data.tar.gz: 3d9288e590d4d9a570c7f6c501ffbd30ef968c58790e2beec42721672b34075f
 SHA512:
-  metadata.gz: cf9c5a514c2ef42299161f97f077897c11b4708858018ffefcb282d6434003d49555ba54ea93ef0fb9e2f3cafdd7bd4395677e02bdfa5b1dcad4772aaeccc896
-  data.tar.gz: 0e01311267e620be6dfbed77724baed42ad9c597a2b5452d762845597ae257932e76848e0069149598743d78dc08b0374b18897ce9a272640eed4a3cec21bbd9
+  metadata.gz: c841175a05fdace23e273413c863dcb2135fcab66414a6600864e16e12375c65d37a62fa2f03a433852f946324f40192ca5656d1f3e87b72ec0207e1f92e585f
+  data.tar.gz: a9e98d5929bed556b59ef29ad690668dac7bfe7b5f46f87e92d9e8dab352793ee659dcd8894343264c60b75bfd0cc07177db00bb0a5828410d879ebeb22bf860
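The checksums.yaml hunk records SHA256 and SHA512 digests of the gem's metadata.gz and data.tar.gz, and the new CIFAR datasets likewise compare downloaded files against SHA256 values. A minimal sketch of recomputing such a digest with Ruby's standard `digest` library is below; `file_sha256` and `verify_sha256` are hypothetical helper names, not the gem's own `check_integrity`, whose implementation is not part of this excerpt.

```ruby
require "digest"

# Hypothetical helper: hex-encoded SHA256 of a file. Digest::SHA256.file
# reads the file in chunks, so large archives are not loaded into memory whole.
def file_sha256(path)
  Digest::SHA256.file(path).hexdigest
end

# Hypothetical helper: true only if the file exists and its digest matches
# the expected hex string (for example, a value copied from checksums.yaml).
def verify_sha256(path, expected)
  File.exist?(path) && file_sha256(path) == expected.downcase
end
```

For the SHA512 entries, `Digest::SHA512.file(path).hexdigest` works the same way.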
data/CHANGELOG.md CHANGED
@@ -1,3 +1,38 @@
+## 0.2.2 (2021-05-23)
+
+- Fixed error with ruby-vips 2.1.2
+
+## 0.2.1 (2021-03-14)
+
+- Added `ImageFolder` and `DatasetFolder`
+- Added `CenterCrop` and `RandomResizedCrop` transforms
+- Added `crop` method
+
+## 0.2.0 (2021-03-11)
+
+- Added `RandomHorizontalFlip`, `RandomVerticalFlip`, and `Resize` transforms
+- Added `save_image` method
+- Added `data` and `targets` methods to datasets
+- Removed support for Ruby < 2.6
+
+Breaking changes
+
+- Added dependency on libvips
+- MNIST datasets return images instead of tensors
+
+## 0.1.3 (2020-06-29)
+
+- Added AlexNet model
+- Added ResNet34, ResNet50, ResNet101, and ResNet152 models
+- Added ResNeXt model
+- Added VGG11, VGG13, VGG16, and VGG19 models
+- Added Wide ResNet model
+
+## 0.1.2 (2020-04-29)
+
+- Added CIFAR10, CIFAR100, FashionMNIST, and KMNIST datasets
+- Added ResNet18 model
+
 ## 0.1.1 (2020-04-28)
 
 - Removed `mini_magick` for performance
data/LICENSE.txt CHANGED
@@ -1,7 +1,7 @@
 BSD 3-Clause License
 
-Copyright (c) Andrew Kane 2020,
 Copyright (c) Soumith Chintala 2016,
+Copyright (c) Andrew Kane 2020-2021,
 All rights reserved.
 
 Redistribution and use in source and binary forms, with or without
data/README.md CHANGED
@@ -2,12 +2,16 @@
 
 :fire: Computer vision datasets, transforms, and models for Ruby
 
-This gem is currently experimental. There may be breaking changes between each release. Please report any issues you experience.
-
-[![Build Status](https://travis-ci.org/ankane/torchvision.svg?branch=master)](https://travis-ci.org/ankane/torchvision)
+[![Build Status](https://github.com/ankane/torchvision/workflows/build/badge.svg?branch=master)](https://github.com/ankane/torchvision/actions)
 
 ## Installation
 
+First, [install libvips](#libvips-installation). For Homebrew, use:
+
+```sh
+brew install vips
+```
+
 Add this line to your application’s Gemfile:
 
 ```ruby
@@ -16,17 +20,139 @@ gem 'torchvision'
 
 ## Getting Started
 
-This library follows the [Python API](https://pytorch.org/docs/master/torchvision/). Many methods and options are missing at the moment. PRs welcome!
+This library follows the [Python API](https://pytorch.org/docs/stable/torchvision/index.html). Many methods and options are missing at the moment. PRs welcome!
+
+## Examples
+
+- [MNIST](https://github.com/ankane/torch.rb/tree/master/examples/mnist)
+- [Transfer learning](https://github.com/ankane/torch.rb/tree/master/examples/transfer-learning)
+- [Generative adversarial networks](https://github.com/ankane/torch.rb/tree/master/examples/gan)
 
 ## Datasets
 
-MNIST dataset
+Load a dataset
+
+```ruby
+TorchVision::Datasets::MNIST.new("./data", train: true, download: true)
+```
+
+Supported datasets are:
+
+- CIFAR10
+- CIFAR100
+- FashionMNIST
+- KMNIST
+- MNIST
+
+## Transforms
+
+```ruby
+TorchVision::Transforms::Compose.new([
+  TorchVision::Transforms::ToTensor.new,
+  TorchVision::Transforms::Normalize.new([0.1307], [0.3081])
+])
+```
+
+Supported transforms are:
+
+- CenterCrop
+- Compose
+- Normalize
+- RandomHorizontalFlip
+- RandomResizedCrop
+- RandomVerticalFlip
+- Resize
+- ToTensor
+
+## Models
+
+- [AlexNet](#alexnet)
+- [ResNet](#resnet)
+- [ResNeXt](#resnext)
+- [VGG](#vgg)
+- [Wide ResNet](#wide-resnet)
+
+### AlexNet
 
 ```ruby
-trainset = TorchVision::Datasets::MNIST.new("./data", train: true, download: true)
-trainset.size
+TorchVision::Models::AlexNet.new
 ```
 
+### ResNet
+
+```ruby
+TorchVision::Models::ResNet18.new
+TorchVision::Models::ResNet34.new
+TorchVision::Models::ResNet50.new
+TorchVision::Models::ResNet101.new
+TorchVision::Models::ResNet152.new
+```
+
+### ResNeXt
+
+```ruby
+TorchVision::Models::ResNext50_32x4d.new
+TorchVision::Models::ResNext101_32x8d.new
+```
+
+### VGG
+
+```ruby
+TorchVision::Models::VGG11.new
+TorchVision::Models::VGG11BN.new
+TorchVision::Models::VGG13.new
+TorchVision::Models::VGG13BN.new
+TorchVision::Models::VGG16.new
+TorchVision::Models::VGG16BN.new
+TorchVision::Models::VGG19.new
+TorchVision::Models::VGG19BN.new
+```
+
+### Wide ResNet
+
+```ruby
+TorchVision::Models::WideResNet50_2.new
+TorchVision::Models::WideResNet101_2.new
+```
+
+## Pretrained Models
+
+You can download pretrained models with [this script](pretrained.py)
+
+```sh
+pip install torchvision
+python pretrained.py
+```
+
+And load them
+
+```ruby
+net = TorchVision::Models::ResNet18.new
+net.load_state_dict(Torch.load("net.pth"))
+```
+
+## libvips Installation
+
+### Linux
+
+Check your package manager. For Ubuntu, use:
+
+```sh
+sudo apt install libvips
+```
+
+You can also [build from source](https://libvips.github.io/libvips/install.html).
+
+### Mac
+
+```sh
+brew install vips
+```
+
+### Windows
+
+Check out [the options](https://libvips.github.io/libvips/install.html).
+
 ## Disclaimer
 
 This library downloads and prepares public datasets. We don’t host any datasets. Be sure to adhere to the license for each dataset.
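The README's Transforms example chains `ToTensor` and `Normalize` with `Compose`. Assuming these follow the Python API (ToTensor scales 0 to 255 pixel values into [0, 1], and Normalize computes `(x - mean) / std` per channel), the arithmetic for a single-channel image can be sketched in plain Ruby. `compose`, `to_unit_range`, and `normalize` are hypothetical stand-ins rather than the gem's classes, and they operate on nested arrays instead of `Torch::Tensor`s.

```ruby
# Hypothetical stand-in for ToTensor's scaling step: map 0..255 integers to
# floats in [0.0, 1.0]. (The real transform also produces a channel-first
# tensor; that part is omitted here.)
def to_unit_range(rows)
  rows.map { |row| row.map { |v| v / 255.0 } }
end

# Hypothetical stand-in for Normalize on one channel: (x - mean) / std.
def normalize(rows, mean, std)
  rows.map { |row| row.map { |v| (v - mean) / std } }
end

# Compose-style chaining: each step receives the previous step's output.
def compose(steps)
  ->(input) { steps.reduce(input) { |acc, step| step.call(acc) } }
end

# Same single-channel mean/std as the README example.
pipeline = compose([
  method(:to_unit_range),
  ->(rows) { normalize(rows, 0.1307, 0.3081) }
])

out = pipeline.call([[0, 255]])
# out[0][0] is (0.0 - 0.1307) / 0.3081, roughly -0.424
# out[0][1] is (1.0 - 0.1307) / 0.3081, roughly 2.821
```

Because `Compose` only feeds each output into the next step, the order matters: normalizing before scaling would subtract the mean from raw 0 to 255 values.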
data/lib/torchvision.rb CHANGED
@@ -1,23 +1,62 @@
 # dependencies
 require "numo/narray"
+require "vips"
 require "torch"
 
 # stdlib
 require "digest"
 require "fileutils"
 require "net/http"
+require "rubygems/package"
 require "tmpdir"
 
 # modules
+require "torchvision/utils"
 require "torchvision/version"
 
 # datasets
+require "torchvision/datasets/vision_dataset"
+require "torchvision/datasets/dataset_folder"
+require "torchvision/datasets/image_folder"
+require "torchvision/datasets/cifar10"
+require "torchvision/datasets/cifar100"
 require "torchvision/datasets/mnist"
+require "torchvision/datasets/fashion_mnist"
+require "torchvision/datasets/kmnist"
+
+# models
+require "torchvision/models/alexnet"
+require "torchvision/models/basic_block"
+require "torchvision/models/bottleneck"
+require "torchvision/models/resnet"
+require "torchvision/models/resnet18"
+require "torchvision/models/resnet34"
+require "torchvision/models/resnet50"
+require "torchvision/models/resnet101"
+require "torchvision/models/resnet152"
+require "torchvision/models/resnext50_32x4d"
+require "torchvision/models/resnext101_32x8d"
+require "torchvision/models/vgg"
+require "torchvision/models/vgg11"
+require "torchvision/models/vgg11_bn"
+require "torchvision/models/vgg13"
+require "torchvision/models/vgg13_bn"
+require "torchvision/models/vgg16"
+require "torchvision/models/vgg16_bn"
+require "torchvision/models/vgg19"
+require "torchvision/models/vgg19_bn"
+require "torchvision/models/wide_resnet50_2"
+require "torchvision/models/wide_resnet101_2"
 
 # transforms
+require "torchvision/transforms/center_crop"
 require "torchvision/transforms/compose"
 require "torchvision/transforms/functional"
 require "torchvision/transforms/normalize"
+require "torchvision/transforms/random_horizontal_flip"
+require "torchvision/transforms/random_resized_crop"
+require "torchvision/transforms/random_vertical_flip"
+require "torchvision/transforms/resize"
 require "torchvision/transforms/to_tensor"
 
 module TorchVision
data/lib/torchvision/datasets/cifar10.rb ADDED
@@ -0,0 +1,117 @@
+module TorchVision
+  module Datasets
+    class CIFAR10 < VisionDataset
+      # https://www.cs.toronto.edu/~kriz/cifar.html
+
+      def initialize(root, train: true, download: false, transform: nil, target_transform: nil)
+        super(root, transform: transform, target_transform: target_transform)
+        @train = train
+
+        self.download if download
+
+        if !_check_integrity
+          raise Error, "Dataset not found or corrupted. You can use download: true to download it"
+        end
+
+        downloaded_list = @train ? train_list : test_list
+
+        @data = String.new
+        @targets = String.new
+
+        downloaded_list.each do |file|
+          file_path = File.join(@root, base_folder, file[:filename])
+          File.open(file_path, "rb") do |f|
+            while !f.eof?
+              f.read(1) if multiple_labels?
+              @targets << f.read(1)
+              @data << f.read(3072)
+            end
+          end
+        end
+
+        @targets = @targets.unpack("C*")
+        # TODO switch i to -1 when Numo supports it
+        @data = Numo::UInt8.from_binary(@data).reshape(@targets.size, 3, 32, 32)
+        @data = @data.transpose(0, 2, 3, 1)
+      end
+
+      def size
+        @data.shape[0]
+      end
+
+      def [](index)
+        # TODO remove trues when Numo supports it
+        img, target = @data[index, true, true, true], @targets[index]
+
+        img = Utils.image_from_array(img)
+
+        img = @transform.call(img) if @transform
+
+        target = @target_transform.call(target) if @target_transform
+
+        [img, target]
+      end
+
+      def _check_integrity
+        root = @root
+        (train_list + test_list).each do |fentry|
+          fpath = File.join(root, base_folder, fentry[:filename])
+          return false unless check_integrity(fpath, fentry[:sha256])
+        end
+        true
+      end
+
+      def download
+        if _check_integrity
+          puts "Files already downloaded and verified"
+          return
+        end
+
+        download_file(url, download_root: @root, filename: filename, sha256: tgz_sha256)
+
+        path = File.join(@root, filename)
+        File.open(path, "rb") do |io|
+          Gem::Package.new("").extract_tar_gz(io, @root)
+        end
+      end
+
+      private
+
+      def base_folder
+        "cifar-10-batches-bin"
+      end
+
+      def url
+        "https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz"
+      end
+
+      def filename
+        "cifar-10-binary.tar.gz"
+      end
+
+      def tgz_sha256
+        "c4a38c50a1bc5f3a1c5537f2155ab9d68f9f25eb1ed8d9ddda3db29a59bca1dd"
+      end
+
+      def train_list
+        [
+          {filename: "data_batch_1.bin", sha256: "cee916563c9f80d84e3cc88e17fdc0941787f1244f00a67874d45b261883ada5"},
+          {filename: "data_batch_2.bin", sha256: "a591ca11fa1708a91ee40f54b3da4784ccd871ecf2137de63f51ada8b3fa57ed"},
+          {filename: "data_batch_3.bin", sha256: "bbe8596564c0f86427f876058170b84dac6670ddf06d79402899d93ceea26f67"},
+          {filename: "data_batch_4.bin", sha256: "014e562d6e23c72197cc727519169a60359f5eccd8945ad5a09d710285ff4e48"},
+          {filename: "data_batch_5.bin", sha256: "755304fc0b379caeae8c14f0dac912fbc7d6cd469eb67a1029a08a39453a9add"},
+        ]
+      end
+
+      def test_list
+        [
+          {filename: "test_batch.bin", sha256: "8e2eb146ae340b09e24670f29cabc6326dba54da8789dab6768acf480273f65b"}
+        ]
+      end
+
+      def multiple_labels?
+        false
+      end
+    end
+  end
+end
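The `CIFAR10` constructor reads each record as one label byte plus 3072 pixel bytes, reshapes the pixels to `(N, 3, 32, 32)`, and transposes them to `(N, 32, 32, 3)`. Per the CIFAR page linked in the class, those 3072 bytes are the red, green, and blue 32x32 planes, each stored in row-major order. A dependency-free sketch of the same layout change with plain Ruby arrays (no Numo) is below; `parse_cifar10` is a hypothetical name, and unpacking whole batch files this way is much slower and more memory-hungry than the Numo version in the diff.

```ruby
SIDE = 32              # CIFAR images are 32x32
PLANE = SIDE * SIDE    # 1024 bytes per color plane
RECORD = 1 + 3 * PLANE # 1 label byte + 3072 pixel bytes

# Hypothetical helper: split raw CIFAR-10 batch bytes into labels and images.
# Each image is SIDE rows of SIDE [r, g, b] pixels (height x width x channel),
# the layout the diff gets from reshape(n, 3, 32, 32).transpose(0, 2, 3, 1).
def parse_cifar10(bytes)
  unless (bytes.bytesize % RECORD).zero?
    raise ArgumentError, "byte count is not a multiple of #{RECORD}"
  end

  labels = []
  images = []
  bytes.unpack("C*").each_slice(RECORD) do |record|
    labels << record[0]
    red, green, blue = record[1..].each_slice(PLANE).to_a
    images << Array.new(SIDE) do |y|
      Array.new(SIDE) do |x|
        i = y * SIDE + x # row-major offset within each color plane
        [red[i], green[i], blue[i]]
      end
    end
  end
  [labels, images]
end
```

Indexing `[red[i], green[i], blue[i]]` is exactly what the `transpose(0, 2, 3, 1)` call achieves in bulk: the channel axis moves from first to last.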
data/lib/torchvision/datasets/cifar100.rb ADDED
@@ -0,0 +1,41 @@
+module TorchVision
+  module Datasets
+    class CIFAR100 < CIFAR10
+      # https://www.cs.toronto.edu/~kriz/cifar.html
+
+      private
+
+      def base_folder
+        "cifar-100-binary"
+      end
+
+      def url
+        "https://www.cs.toronto.edu/~kriz/cifar-100-binary.tar.gz"
+      end
+
+      def filename
+        "cifar-100-binary.tar.gz"
+      end
+
+      def tgz_sha256
+        "58a81ae192c23a4be8b1804d68e518ed807d710a4eb253b1f2a199162a40d8ec"
+      end
+
+      def train_list
+        [
+          {filename: "train.bin", sha256: "f31298fc616915fa142368359df1c4ca2ae984d6915ca468b998a5ec6aeebf29"}
+        ]
+      end
+
+      def test_list
+        [
+          {filename: "test.bin", sha256: "d8b1e6b7b3bee4020055f0699b111f60b1af1e262aeb93a0b659061746f8224a"}
+        ]
+      end
+
+      def multiple_labels?
+        true
+      end
+    end
+  end
+end
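`CIFAR100` inherits the entire reading loop from `CIFAR10` and overrides only private hook methods; `multiple_labels?` returning `true` makes the loop discard one extra byte per record. Per the CIFAR page, CIFAR-100 binary records begin with a coarse (superclass) label byte followed by a fine label byte, so the dataset keeps the fine labels. The stripped-down sketch below shows that hook pattern with hypothetical reader classes that only collect labels; they are illustrations, not classes from the gem.

```ruby
require "stringio"

# Hypothetical, simplified reader mirroring the loop in CIFAR10#initialize.
class TinyCifar10Reader
  PIXELS = 3 * 32 * 32 # 3072 pixel bytes per record

  # Return the label of each record in a binary string.
  def labels(bytes)
    io = StringIO.new(bytes)
    result = []
    until io.eof?
      io.read(1) if multiple_labels? # discard the coarse label byte
      result << io.read(1).unpack1("C")
      io.read(PIXELS) # pixel data is not needed for this sketch
    end
    result
  end

  private

  # Hook that subclasses override, as CIFAR100 does in the diff.
  def multiple_labels?
    false
  end
end

# Only the hook changes; the reading loop is inherited unchanged.
class TinyCifar100Reader < TinyCifar10Reader
  private

  def multiple_labels?
    true
  end
end
```

The same mechanism is what lets `base_folder`, `url`, `filename`, `tgz_sha256`, and the file lists differ between the two datasets without duplicating any parsing code.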