torchvision 0.1.1 → 0.1.2
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +25 -3
- data/lib/torchvision.rb +12 -0
- data/lib/torchvision/datasets/cifar10.rb +116 -0
- data/lib/torchvision/datasets/cifar100.rb +41 -0
- data/lib/torchvision/datasets/fashion_mnist.rb +30 -0
- data/lib/torchvision/datasets/kmnist.rb +30 -0
- data/lib/torchvision/datasets/mnist.rb +47 -75
- data/lib/torchvision/datasets/vision_dataset.rb +66 -0
- data/lib/torchvision/models/basic_block.rb +46 -0
- data/lib/torchvision/models/bottleneck.rb +47 -0
- data/lib/torchvision/models/resnet.rb +107 -0
- data/lib/torchvision/models/resnet18.rb +15 -0
- data/lib/torchvision/transforms/functional.rb +23 -6
- data/lib/torchvision/version.rb +1 -1
- metadata +13 -4
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 107e429f990a063e57f6218ee6d4fbed3cff80aa6746f0794798d59e1c13b099
|
4
|
+
data.tar.gz: 556fdc4c413803d415ea5575747ec2239f18b0dff9b39a37a4fd6366adec37a6
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 21b712578516c146888be30bed64a6da6339a42974f5c88fa685278caa9231e4bc4b75e250af3ad37d5450cd205891b31b281cf23fa362456c7ae00998c2736d
|
7
|
+
data.tar.gz: ac86f13e8b5d6a400842ba37b1bf593139360a8df9a19ab4674e71a4327c821a6d5735774185dcbaed368ee735b00bf33d9faedb13d46cd168314da89d900c38
|
data/CHANGELOG.md
CHANGED
data/README.md
CHANGED
@@ -20,11 +20,33 @@ This library follows the [Python API](https://pytorch.org/docs/master/torchvisio
|
|
20
20
|
|
21
21
|
## Datasets
|
22
22
|
|
23
|
-
|
23
|
+
Load a dataset
|
24
24
|
|
25
25
|
```ruby
|
26
|
-
|
27
|
-
|
26
|
+
TorchVision::Datasets::MNIST.new("./data", train: true, download: true)
|
27
|
+
```
|
28
|
+
|
29
|
+
Supported datasets are:
|
30
|
+
|
31
|
+
- CIFAR10
|
32
|
+
- CIFAR100
|
33
|
+
- FashionMNIST
|
34
|
+
- KMNIST
|
35
|
+
- MNIST
|
36
|
+
|
37
|
+
## Transforms
|
38
|
+
|
39
|
+
```ruby
|
40
|
+
TorchVision::Transforms::Compose.new([
|
41
|
+
TorchVision::Transforms::ToTensor.new,
|
42
|
+
TorchVision::Transforms::Normalize.new([0.1307], [0.3081])
|
43
|
+
])
|
44
|
+
```
|
45
|
+
|
46
|
+
## Models
|
47
|
+
|
48
|
+
```ruby
|
49
|
+
TorchVision::Models::Resnet18.new
|
28
50
|
```
|
29
51
|
|
30
52
|
## Disclaimer
|
data/lib/torchvision.rb
CHANGED
@@ -6,13 +6,25 @@ require "torch"
|
|
6
6
|
require "digest"
|
7
7
|
require "fileutils"
|
8
8
|
require "net/http"
|
9
|
+
require "rubygems/package"
|
9
10
|
require "tmpdir"
|
10
11
|
|
11
12
|
# modules
|
12
13
|
require "torchvision/version"
|
13
14
|
|
14
15
|
# datasets
|
16
|
+
require "torchvision/datasets/vision_dataset"
|
17
|
+
require "torchvision/datasets/cifar10"
|
18
|
+
require "torchvision/datasets/cifar100"
|
15
19
|
require "torchvision/datasets/mnist"
|
20
|
+
require "torchvision/datasets/fashion_mnist"
|
21
|
+
require "torchvision/datasets/kmnist"
|
22
|
+
|
23
|
+
# models
|
24
|
+
require "torchvision/models/basic_block"
|
25
|
+
require "torchvision/models/bottleneck"
|
26
|
+
require "torchvision/models/resnet"
|
27
|
+
require "torchvision/models/resnet18"
|
16
28
|
|
17
29
|
# transforms
|
18
30
|
require "torchvision/transforms/compose"
|
@@ -0,0 +1,116 @@
|
|
1
|
+
module TorchVision
  module Datasets
    # CIFAR-10 dataset read from the binary distribution.
    # https://www.cs.toronto.edu/~kriz/cifar.html
    #
    # Each record in a batch file is read as label byte(s) followed by 3072
    # image bytes, which are reshaped to (N, 3, 32, 32) and then transposed
    # to (N, 32, 32, 3).
    class CIFAR10 < VisionDataset
      # @param root [String] directory the archive is downloaded to and extracted in
      # @param train [Boolean] load the training batches (true) or the test batch (false)
      # @param download [Boolean] download and extract the archive first when needed
      # @param transform [#call, nil] applied to each image in #[]
      # @param target_transform [#call, nil] applied to each label in #[]
      # @raise [Error] when any batch file is missing or fails its SHA256 check
      def initialize(root, train: true, download: false, transform: nil, target_transform: nil)
        super(root, transform: transform, target_transform: target_transform)
        @train = train

        self.download if download

        unless _check_integrity
          # Ruby keyword syntax, matching the other datasets' error messages
          raise Error, "Dataset not found or corrupted. You can use download: true to download it"
        end

        downloaded_list = @train ? train_list : test_list

        # String.new is binary (ASCII-8BIT), so raw bytes are appended unchanged
        @data = String.new
        @targets = String.new

        downloaded_list.each do |file|
          file_path = File.join(@root, base_folder, file[:filename])
          File.open(file_path, "rb") do |f|
            until f.eof?
              # when records carry two label bytes (see multiple_labels?),
              # the first is skipped and the second is used as the target
              f.read(1) if multiple_labels?
              @targets << f.read(1)
              @data << f.read(3072)
            end
          end
        end

        @targets = @targets.unpack("C*")
        # TODO pass -1 for the first dimension when Numo supports it
        @data = Numo::UInt8.from_binary(@data).reshape(@targets.size, 3, 32, 32)
        # (N, C, H, W) -> (N, H, W, C)
        @data = @data.transpose(0, 2, 3, 1)
      end

      # Number of images loaded.
      def size
        @data.shape[0]
      end

      # Returns [image, target] for index, with the optional transforms applied.
      def [](index)
        # TODO remove trues when Numo supports it
        img, target = @data[index, true, true, true], @targets[index]

        # TODO convert to image
        img = @transform.call(img) if @transform

        target = @target_transform.call(target) if @target_transform

        [img, target]
      end

      # True when every train and test batch file exists with the expected SHA256.
      def _check_integrity
        root = @root
        (train_list + test_list).each do |fentry|
          fpath = File.join(root, base_folder, fentry[:filename])
          return false unless check_integrity(fpath, fentry[:sha256])
        end
        true
      end

      # Downloads the archive (verifying its SHA256) and extracts it into root,
      # unless the batch files are already present and verified.
      def download
        if _check_integrity
          puts "Files already downloaded and verified"
          return
        end

        download_file(url, download_root: @root, filename: filename, sha256: tgz_sha256)

        path = File.join(@root, filename)
        File.open(path, "rb") do |io|
          Gem::Package.new("").extract_tar_gz(io, @root)
        end
      end

      private

      # Directory (relative to root) holding the extracted batch files.
      def base_folder
        "cifar-10-batches-bin"
      end

      def url
        "https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz"
      end

      def filename
        "cifar-10-binary.tar.gz"
      end

      def tgz_sha256
        "c4a38c50a1bc5f3a1c5537f2155ab9d68f9f25eb1ed8d9ddda3db29a59bca1dd"
      end

      def train_list
        [
          {filename: "data_batch_1.bin", sha256: "cee916563c9f80d84e3cc88e17fdc0941787f1244f00a67874d45b261883ada5"},
          {filename: "data_batch_2.bin", sha256: "a591ca11fa1708a91ee40f54b3da4784ccd871ecf2137de63f51ada8b3fa57ed"},
          {filename: "data_batch_3.bin", sha256: "bbe8596564c0f86427f876058170b84dac6670ddf06d79402899d93ceea26f67"},
          {filename: "data_batch_4.bin", sha256: "014e562d6e23c72197cc727519169a60359f5eccd8945ad5a09d710285ff4e48"},
          {filename: "data_batch_5.bin", sha256: "755304fc0b379caeae8c14f0dac912fbc7d6cd469eb67a1029a08a39453a9add"},
        ]
      end

      def test_list
        [
          {filename: "test_batch.bin", sha256: "8e2eb146ae340b09e24670f29cabc6326dba54da8789dab6768acf480273f65b"}
        ]
      end

      # Whether each record starts with an extra label byte that is skipped.
      def multiple_labels?
        false
      end
    end
  end
end
|
@@ -0,0 +1,41 @@
|
|
1
|
+
module TorchVision
  module Datasets
    # CIFAR-100 dataset. Loading is inherited from CIFAR10; this class only
    # supplies the archive location, checksums, file names, and the fact that
    # each record carries two label bytes.
    # https://www.cs.toronto.edu/~kriz/cifar.html
    class CIFAR100 < CIFAR10
      # Directory (relative to root) holding the extracted .bin files.
      private def base_folder
        "cifar-100-binary"
      end

      # Download location of the archive.
      private def url
        "https://www.cs.toronto.edu/~kriz/cifar-100-binary.tar.gz"
      end

      # Local file name of the downloaded archive.
      private def filename
        "cifar-100-binary.tar.gz"
      end

      # Expected SHA256 of the downloaded archive.
      private def tgz_sha256
        "58a81ae192c23a4be8b1804d68e518ed807d710a4eb253b1f2a199162a40d8ec"
      end

      # Training split: a single file.
      private def train_list
        [{filename: "train.bin", sha256: "f31298fc616915fa142368359df1c4ca2ae984d6915ca468b998a5ec6aeebf29"}]
      end

      # Test split: a single file.
      private def test_list
        [{filename: "test.bin", sha256: "d8b1e6b7b3bee4020055f0699b111f60b1af1e262aeb93a0b659061746f8224a"}]
      end

      # Records hold two label bytes; CIFAR10 skips the first and keeps the second.
      private def multiple_labels?
        true
      end
    end
  end
end
|
@@ -0,0 +1,30 @@
|
|
1
|
+
module TorchVision
  module Datasets
    # Fashion-MNIST: same file layout as MNIST, served from a different host.
    # https://github.com/zalandoresearch/fashion-mnist
    class FashionMNIST < MNIST
      # Files MNIST#download fetches, each with its expected SHA256.
      private def resources
        [
          {url: "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz", sha256: "3aede38d61863908ad78613f6a32ed271626dd12800ba2636569512369268a84"},
          {url: "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz", sha256: "a04f17134ac03560a47e3764e11b92fc97de4d1bfaf8ba1a3aa29af54cc90845"},
          {url: "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz", sha256: "346e55b948d973a97e58d2351dde16a484bd415d4595297633bb08f03db6a073"},
          {url: "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz", sha256: "67da17c76eaffca5446c3361aaab5c3cd6d1c2608764d35dfb1850b086bf8dd5"}
        ]
      end
    end
  end
end
|
@@ -0,0 +1,30 @@
|
|
1
|
+
module TorchVision
  module Datasets
    # Kuzushiji-MNIST: same file layout as MNIST, served from a different host.
    # https://github.com/rois-codh/kmnist
    class KMNIST < MNIST
      # Files MNIST#download fetches, each with its expected SHA256.
      private def resources
        [
          {url: "http://codh.rois.ac.jp/kmnist/dataset/kmnist/train-images-idx3-ubyte.gz", sha256: "51467d22d8cc72929e2a028a0428f2086b092bb31cfb79c69cc0a90ce135fde4"},
          {url: "http://codh.rois.ac.jp/kmnist/dataset/kmnist/train-labels-idx1-ubyte.gz", sha256: "e38f9ebcd0f3ebcdec7fc8eabdcdaef93bb0df8ea12bee65224341c8183d8e17"},
          {url: "http://codh.rois.ac.jp/kmnist/dataset/kmnist/t10k-images-idx3-ubyte.gz", sha256: "edd7a857845ad6bb1d0ba43fe7e794d164fe2dce499a1694695a792adfac43c5"},
          {url: "http://codh.rois.ac.jp/kmnist/dataset/kmnist/t10k-labels-idx1-ubyte.gz", sha256: "20bb9a0ef54c7db3efc55a92eef5582c109615df22683c380526788f98e42a1c"}
        ]
      end
    end
  end
end
|
@@ -1,31 +1,11 @@
|
|
1
1
|
module TorchVision
|
2
2
|
module Datasets
|
3
|
-
class MNIST
|
4
|
-
|
5
|
-
|
6
|
-
|
7
|
-
|
8
|
-
},
|
9
|
-
{
|
10
|
-
url: "http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz",
|
11
|
-
sha256: "3552534a0a558bbed6aed32b30c495cca23d567ec52cac8be1a0730e8010255c"
|
12
|
-
},
|
13
|
-
{
|
14
|
-
url: "http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz",
|
15
|
-
sha256: "8d422c7b0a1c1c79245a5bcf07fe86e33eeafee792b84584aec276f5a2dbc4e6"
|
16
|
-
},
|
17
|
-
{
|
18
|
-
url: "http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz",
|
19
|
-
sha256: "f7ae60f92e00ec6debd23a6088c31dbd2371eca3ffa0defaefb259924204aec6"
|
20
|
-
}
|
21
|
-
]
|
22
|
-
TRAINING_FILE = "training.pt"
|
23
|
-
TEST_FILE = "test.pt"
|
24
|
-
|
25
|
-
def initialize(root, train: true, download: false, transform: nil)
|
26
|
-
@root = root
|
3
|
+
class MNIST < VisionDataset
|
4
|
+
# http://yann.lecun.com/exdb/mnist/
|
5
|
+
|
6
|
+
def initialize(root, train: true, download: false, transform: nil, target_transform: nil)
|
7
|
+
super(root, transform: transform, target_transform: target_transform)
|
27
8
|
@train = train
|
28
|
-
@transform = transform
|
29
9
|
|
30
10
|
self.download if download
|
31
11
|
|
@@ -33,34 +13,36 @@ module TorchVision
|
|
33
13
|
raise Error, "Dataset not found. You can use download: true to download it"
|
34
14
|
end
|
35
15
|
|
36
|
-
data_file = @train ?
|
16
|
+
data_file = @train ? training_file : test_file
|
37
17
|
@data, @targets = Torch.load(File.join(processed_folder, data_file))
|
38
18
|
end
|
39
19
|
|
40
20
|
def size
|
41
|
-
@data.size
|
21
|
+
@data.size(0)
|
42
22
|
end
|
43
23
|
|
44
24
|
def [](index)
|
45
|
-
img = @data[index]
|
25
|
+
img, target = @data[index], @targets[index].item
|
26
|
+
|
27
|
+
# TODO convert to image
|
46
28
|
img = @transform.call(img) if @transform
|
47
29
|
|
48
|
-
target = @
|
30
|
+
target = @target_transform.call(target) if @target_transform
|
49
31
|
|
50
32
|
[img, target]
|
51
33
|
end
|
52
34
|
|
53
35
|
def raw_folder
|
54
|
-
File.join(@root, "
|
36
|
+
File.join(@root, self.class.name.split("::").last, "raw")
|
55
37
|
end
|
56
38
|
|
57
39
|
def processed_folder
|
58
|
-
File.join(@root, "
|
40
|
+
File.join(@root, self.class.name.split("::").last, "processed")
|
59
41
|
end
|
60
42
|
|
61
43
|
def check_exists
|
62
|
-
File.exist?(File.join(processed_folder,
|
63
|
-
File.exist?(File.join(processed_folder,
|
44
|
+
File.exist?(File.join(processed_folder, training_file)) &&
|
45
|
+
File.exist?(File.join(processed_folder, test_file))
|
64
46
|
end
|
65
47
|
|
66
48
|
def download
|
@@ -69,7 +51,7 @@ module TorchVision
|
|
69
51
|
FileUtils.mkdir_p(raw_folder)
|
70
52
|
FileUtils.mkdir_p(processed_folder)
|
71
53
|
|
72
|
-
|
54
|
+
resources.each do |resource|
|
73
55
|
filename = resource[:url].split("/").last
|
74
56
|
download_file(resource[:url], download_root: raw_folder, filename: filename, sha256: resource[:sha256])
|
75
57
|
end
|
@@ -85,14 +67,43 @@ module TorchVision
|
|
85
67
|
unpack_mnist("t10k-labels-idx1-ubyte", 8, [10000])
|
86
68
|
]
|
87
69
|
|
88
|
-
Torch.save(training_set, File.join(processed_folder,
|
89
|
-
Torch.save(test_set, File.join(processed_folder,
|
70
|
+
Torch.save(training_set, File.join(processed_folder, training_file))
|
71
|
+
Torch.save(test_set, File.join(processed_folder, test_file))
|
90
72
|
|
91
73
|
puts "Done!"
|
92
74
|
end
|
93
75
|
|
94
76
|
private
|
95
77
|
|
78
|
+
def resources
|
79
|
+
[
|
80
|
+
{
|
81
|
+
url: "http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz",
|
82
|
+
sha256: "440fcabf73cc546fa21475e81ea370265605f56be210a4024d2ca8f203523609"
|
83
|
+
},
|
84
|
+
{
|
85
|
+
url: "http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz",
|
86
|
+
sha256: "3552534a0a558bbed6aed32b30c495cca23d567ec52cac8be1a0730e8010255c"
|
87
|
+
},
|
88
|
+
{
|
89
|
+
url: "http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz",
|
90
|
+
sha256: "8d422c7b0a1c1c79245a5bcf07fe86e33eeafee792b84584aec276f5a2dbc4e6"
|
91
|
+
},
|
92
|
+
{
|
93
|
+
url: "http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz",
|
94
|
+
sha256: "f7ae60f92e00ec6debd23a6088c31dbd2371eca3ffa0defaefb259924204aec6"
|
95
|
+
}
|
96
|
+
]
|
97
|
+
end
|
98
|
+
|
99
|
+
def training_file
|
100
|
+
"training.pt"
|
101
|
+
end
|
102
|
+
|
103
|
+
def test_file
|
104
|
+
"test.pt"
|
105
|
+
end
|
106
|
+
|
96
107
|
def unpack_mnist(path, offset, shape)
|
97
108
|
path = File.join(raw_folder, "#{path}.gz")
|
98
109
|
File.open(path, "rb") do |f|
|
@@ -101,45 +112,6 @@ module TorchVision
|
|
101
112
|
Torch.tensor(Numo::UInt8.from_string(gz.read, shape))
|
102
113
|
end
|
103
114
|
end
|
104
|
-
|
105
|
-
def download_file(url, download_root:, filename:, sha256:)
|
106
|
-
FileUtils.mkdir_p(download_root)
|
107
|
-
|
108
|
-
dest = File.join(download_root, filename)
|
109
|
-
return dest if File.exist?(dest)
|
110
|
-
|
111
|
-
temp_path = "#{Dir.tmpdir}/#{Time.now.to_f}" # TODO better name
|
112
|
-
|
113
|
-
digest = Digest::SHA256.new
|
114
|
-
|
115
|
-
uri = URI(url)
|
116
|
-
|
117
|
-
# Net::HTTP automatically adds Accept-Encoding for compression
|
118
|
-
# of response bodies and automatically decompresses gzip
|
119
|
-
# and deflateresponses unless a Range header was sent.
|
120
|
-
# https://ruby-doc.org/stdlib-2.6.4/libdoc/net/http/rdoc/Net/HTTP.html
|
121
|
-
Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http|
|
122
|
-
request = Net::HTTP::Get.new(uri)
|
123
|
-
|
124
|
-
puts "Downloading #{url}..."
|
125
|
-
File.open(temp_path, "wb") do |f|
|
126
|
-
http.request(request) do |response|
|
127
|
-
response.read_body do |chunk|
|
128
|
-
f.write(chunk)
|
129
|
-
digest.update(chunk)
|
130
|
-
end
|
131
|
-
end
|
132
|
-
end
|
133
|
-
end
|
134
|
-
|
135
|
-
if digest.hexdigest != sha256
|
136
|
-
raise Error, "Bad hash: #{digest.hexdigest}"
|
137
|
-
end
|
138
|
-
|
139
|
-
FileUtils.mv(temp_path, dest)
|
140
|
-
|
141
|
-
dest
|
142
|
-
end
|
143
115
|
end
|
144
116
|
end
|
145
117
|
end
|
@@ -0,0 +1,66 @@
|
|
1
|
+
module TorchVision
  module Datasets
    # Base class for the image datasets. Stores the root directory and the
    # optional transforms, and provides download / checksum helpers.
    # TODO inherit Torch::Utils::Data::Dataset
    class VisionDataset
      # @param root [String] directory holding the dataset files
      # @param transforms [#call, nil] joint transform; mutually exclusive with transform/target_transform
      # @param transform [#call, nil] transform applied to each input
      # @param target_transform [#call, nil] transform applied to each target
      # @raise [ArgumentError] if transforms is given together with transform or target_transform
      def initialize(root, transforms: nil, transform: nil, target_transform: nil)
        @root = root

        has_transforms = !transforms.nil?
        has_separate_transform = !transform.nil? || !target_transform.nil?
        if has_transforms && has_separate_transform
          raise ArgumentError, "Only transforms or transform/target_transform can be passed as argument"
        end

        @transform = transform
        @target_transform = target_transform

        # TODO when has_separate_transform, wrap transform/target_transform in a
        # StandardTransform and store it in @transforms (not implemented yet)
        @transforms = transforms
      end

      private

      # Downloads url to download_root/filename and verifies its SHA256.
      #
      # An existing destination file is returned as-is without re-hashing.
      # The download is written inside a fresh temporary directory, which
      # Dir.mktmpdir removes when the block exits (including on error). A
      # failed or corrupt download therefore never leaves a file at dest or
      # temp data behind.
      #
      # @return [String] path of the downloaded file
      # @raise [Error] on a non-success HTTP response or a checksum mismatch
      def download_file(url, download_root:, filename:, sha256:)
        FileUtils.mkdir_p(download_root)

        dest = File.join(download_root, filename)
        return dest if File.exist?(dest)

        uri = URI(url)

        Dir.mktmpdir do |tmpdir|
          temp_path = File.join(tmpdir, filename)

          # Net::HTTP automatically adds Accept-Encoding for compression
          # of response bodies and automatically decompresses gzip
          # and deflate responses unless a Range header was sent.
          # https://ruby-doc.org/stdlib-2.6.4/libdoc/net/http/rdoc/Net/HTTP.html
          Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http|
            request = Net::HTTP::Get.new(uri)

            puts "Downloading #{url}..."
            File.open(temp_path, "wb") do |f|
              http.request(request) do |response|
                # Net::HTTP does not follow redirects; fail clearly instead of
                # saving an error/redirect page and reporting a bad hash
                unless response.is_a?(Net::HTTPSuccess)
                  raise Error, "Download failed: #{response.code} #{response.message}"
                end

                response.read_body do |chunk|
                  f.write(chunk)
                end
              end
            end
          end

          actual = Digest::SHA256.file(temp_path).hexdigest
          raise Error, "Bad hash: #{actual}" if actual != sha256

          FileUtils.mv(temp_path, dest)
        end

        dest
      end

      # True when path exists and its SHA256 hex digest equals sha256.
      def check_integrity(path, sha256)
        File.exist?(path) && Digest::SHA256.file(path).hexdigest == sha256
      end
    end
  end
end
|
@@ -0,0 +1,46 @@
|
|
1
|
+
module TorchVision
  module Models
    # ResNet residual block: two 3x3 convolutions with a shortcut connection.
    # Output has planes * expansion (= planes) channels.
    class BasicBlock < Torch::NN::Module
      # @param inplanes [Integer] input channels
      # @param planes [Integer] output channels of both convolutions
      # @param stride [Integer] stride of the first convolution
      # @param downsample [#call, nil] applied to the input to form the shortcut when given
      # @param norm_layer [Class, nil] normalization layer class (BatchNorm2d by default)
      # @raise [ArgumentError] unless groups == 1 and base_width == 64
      # @raise [NotImplementedError] when dilation > 1
      def initialize(inplanes, planes, stride: 1, downsample: nil, groups: 1, base_width: 64, dilation: 1, norm_layer: nil)
        super()
        norm = norm_layer || Torch::NN::BatchNorm2d
        raise ArgumentError, "BasicBlock only supports groups=1 and base_width=64" if groups != 1 || base_width != 64
        raise NotImplementedError, "Dilation > 1 not supported in BasicBlock" if dilation > 1

        # conv1 applies the stride; so does downsample when one is given
        @conv1 = Torch::NN::Conv2d.new(inplanes, planes, 3, stride: stride, padding: 1, groups: 1, bias: false, dilation: 1)
        @bn1 = norm.new(planes)
        @relu = Torch::NN::ReLU.new(inplace: true)
        @conv2 = Torch::NN::Conv2d.new(planes, planes, 3, stride: 1, padding: 1, groups: 1, bias: false, dilation: 1)
        @bn2 = norm.new(planes)
        @downsample = downsample
        @stride = stride
      end

      # conv-bn-relu, conv-bn, add the shortcut, relu.
      def forward(x)
        out = @relu.call(@bn1.call(@conv1.call(x)))
        out = @bn2.call(@conv2.call(out))
        shortcut = @downsample ? @downsample.call(x) : x
        @relu.call(out + shortcut)
      end

      # Ratio of output channels to planes.
      def self.expansion
        1
      end
    end
  end
end
|
@@ -0,0 +1,47 @@
|
|
1
|
+
module TorchVision
  module Models
    # ResNet bottleneck block: 1x1 reduce, 3x3, 1x1 expand, plus a shortcut.
    # Output has planes * expansion (= 4 * planes) channels.
    class Bottleneck < Torch::NN::Module
      # @param inplanes [Integer] input channels
      # @param planes [Integer] base channel count; output is planes * expansion
      # @param stride [Integer] stride of the 3x3 convolution
      # @param downsample [#call, nil] applied to the input to form the shortcut when given
      # @param groups [Integer] groups of the 3x3 convolution
      # @param base_width [Integer] inner width scales as planes * base_width / 64 * groups
      # @param dilation [Integer] dilation (and padding) of the 3x3 convolution
      # @param norm_layer [Class, nil] normalization layer class (BatchNorm2d by default)
      def initialize(inplanes, planes, stride: 1, downsample: nil, groups: 1, base_width: 64, dilation: 1, norm_layer: nil)
        super()
        norm = norm_layer || Torch::NN::BatchNorm2d
        width = (planes * (base_width / 64.0)).to_i * groups
        out_planes = planes * self.class.expansion

        # conv2 applies the stride; so does downsample when one is given
        @conv1 = Torch::NN::Conv2d.new(inplanes, width, 1, stride: 1, bias: false)
        @bn1 = norm.new(width)
        @conv2 = Torch::NN::Conv2d.new(width, width, 3, stride: stride, padding: dilation, groups: groups, bias: false, dilation: dilation)
        @bn2 = norm.new(width)
        @conv3 = Torch::NN::Conv2d.new(width, out_planes, 1, stride: 1, bias: false)
        @bn3 = norm.new(out_planes)
        @relu = Torch::NN::ReLU.new(inplace: true)
        @downsample = downsample
        @stride = stride
      end

      # Three conv-bn stages (relu after the first two), add the shortcut, relu.
      def forward(x)
        out = @relu.call(@bn1.call(@conv1.call(x)))
        out = @relu.call(@bn2.call(@conv2.call(out)))
        out = @bn3.call(@conv3.call(out))
        shortcut = @downsample ? @downsample.call(x) : x
        @relu.call(out + shortcut)
      end

      # Ratio of output channels to planes.
      def self.expansion
        4
      end
    end
  end
end
|
@@ -0,0 +1,107 @@
|
|
1
|
+
module TorchVision
  module Models
    # ResNet network assembled from residual blocks.
    #
    # block is a block class such as BasicBlock or Bottleneck (it must respond
    # to .new with the keywords used in _make_layer, and to .expansion);
    # layers gives the number of blocks in each of the four stages.
    class ResNet < Torch::NN::Module
      def initialize(block, layers, num_classes = 1000, zero_init_residual: false,
        groups: 1, width_per_group: 64, replace_stride_with_dilation: nil, norm_layer: nil)

        super()
        @norm_layer = norm_layer || Torch::NN::BatchNorm2d

        @inplanes = 64
        @dilation = 1
        # one flag per stage after the first: replace its stride-2 with a dilated convolution
        replace_stride_with_dilation = [false, false, false] if replace_stride_with_dilation.nil?
        unless replace_stride_with_dilation.length == 3
          raise ArgumentError, "replace_stride_with_dilation should be nil or a 3-element tuple, got #{replace_stride_with_dilation}"
        end
        @groups = groups
        @base_width = width_per_group

        # stem
        @conv1 = Torch::NN::Conv2d.new(3, @inplanes, 7, stride: 2, padding: 3, bias: false)
        @bn1 = @norm_layer.new(@inplanes)
        @relu = Torch::NN::ReLU.new(inplace: true)
        @maxpool = Torch::NN::MaxPool2d.new(3, stride: 2, padding: 1)

        # residual stages (built in order; _make_layer updates @inplanes / @dilation)
        @layer1 = _make_layer(block, 64, layers[0])
        @layer2 = _make_layer(block, 128, layers[1], stride: 2, dilate: replace_stride_with_dilation[0])
        @layer3 = _make_layer(block, 256, layers[2], stride: 2, dilate: replace_stride_with_dilation[1])
        @layer4 = _make_layer(block, 512, layers[3], stride: 2, dilate: replace_stride_with_dilation[2])

        # head
        @avgpool = Torch::NN::AdaptiveAvgPool2d.new([1, 1])
        @fc = Torch::NN::Linear.new(512 * block.expansion, num_classes)

        modules.each do |m|
          if m.is_a?(Torch::NN::Conv2d)
            Torch::NN::Init.kaiming_normal!(m.weight, mode: "fan_out", nonlinearity: "relu")
          elsif m.is_a?(Torch::NN::BatchNorm2d) || m.is_a?(Torch::NN::GroupNorm)
            Torch::NN::Init.constant!(m.weight, 1)
            Torch::NN::Init.constant!(m.bias, 0)
          end
        end

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        return unless zero_init_residual

        modules.each do |m|
          if m.is_a?(Bottleneck)
            Torch::NN::Init.constant!(m.bn3.weight, 0)
          elsif m.is_a?(BasicBlock)
            Torch::NN::Init.constant!(m.bn2.weight, 0)
          end
        end
      end

      # Builds one stage of `blocks` residual blocks with `planes` base channels.
      # Only the first block gets the stride and the (optional) downsample shortcut.
      # Side effects: updates @inplanes (and @dilation when dilate is true) for the next stage.
      def _make_layer(block, planes, blocks, stride: 1, dilate: false)
        previous_dilation = @dilation
        if dilate
          @dilation *= stride
          stride = 1
        end

        out_planes = planes * block.expansion
        downsample =
          if stride != 1 || @inplanes != out_planes
            Torch::NN::Sequential.new(
              Torch::NN::Conv2d.new(@inplanes, out_planes, 1, stride: stride, bias: false),
              @norm_layer.new(out_planes)
            )
          end

        layers = [
          block.new(@inplanes, planes, stride: stride, downsample: downsample, groups: @groups,
            base_width: @base_width, dilation: previous_dilation, norm_layer: @norm_layer)
        ]
        @inplanes = out_planes
        (1...blocks).each do
          layers << block.new(@inplanes, planes, groups: @groups, base_width: @base_width, dilation: @dilation, norm_layer: @norm_layer)
        end

        Torch::NN::Sequential.new(*layers)
      end

      # Stem, four stages and average pooling, then flatten and the classifier.
      def _forward_impl(x)
        pooled = [@conv1, @bn1, @relu, @maxpool, @layer1, @layer2, @layer3, @layer4, @avgpool]
          .reduce(x) { |acc, layer| layer.call(acc) }
        @fc.call(Torch.flatten(pooled, 1))
      end

      def forward(x)
        _forward_impl(x)
      end
    end
  end
end
|
@@ -0,0 +1,15 @@
|
|
1
|
+
module TorchVision
  module Models
    # Factory for the 18-layer ResNet (BasicBlock, stages of [2, 2, 2, 2]).
    module ResNet18
      # Builds a ResNet-18 model.
      #
      # @param pretrained [Boolean] load the state dict from
      #   https://download.pytorch.org/models/resnet18-5c106cde.pth
      # @param num_classes [Integer] output size of the final Linear layer
      # @param kwargs [Hash] remaining ResNet keyword options (zero_init_residual:, groups:, ...)
      # @return [ResNet]
      def self.new(pretrained: false, num_classes: 1000, **kwargs)
        # ResNet#initialize takes num_classes positionally, so it cannot be
        # forwarded through **kwargs (that raised "unknown keyword: num_classes")
        model = ResNet.new(BasicBlock, [2, 2, 2, 2], num_classes, **kwargs)
        if pretrained
          url = "https://download.pytorch.org/models/resnet18-5c106cde.pth"
          state_dict = Torch::Hub.load_state_dict_from_url(url)
          model.load_state_dict(state_dict)
        end
        model
      end
    end
  end
end
|
@@ -22,18 +22,35 @@ module TorchVision
|
|
22
22
|
if std.to_a.any? { |v| v == 0 }
|
23
23
|
raise ArgumentError, "std evaluated to zero after conversion to #{dtype}, leading to division by zero."
|
24
24
|
end
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
25
|
+
if mean.ndim == 1
|
26
|
+
mean = mean[0...mean.size(0), nil, nil]
|
27
|
+
end
|
28
|
+
if std.ndim == 1
|
29
|
+
std = std[0...std.size(0), nil, nil]
|
30
|
+
end
|
31
31
|
tensor.sub!(mean).div!(std)
|
32
32
|
tensor
|
33
33
|
end
|
34
34
|
|
35
35
|
# TODO improve
|
36
36
|
# Converts an image to a float tensor divided by 255.
#
# A Numo::NArray must be 3-dimensional; its last axis is moved to the front
# (transpose(2, 0, 1), presumably HWC -> CHW) before conversion. A
# Torch::Tensor is converted with #float, then gets a leading dimension via
# unsqueeze!(0) and is divided by 255.
#
# @param pic [Numo::NArray, Torch::Tensor]
# @return [Torch::Tensor]
# @raise [ArgumentError] if pic is neither type, or is a Numo array that is not 2/3-dimensional
# @raise [Torch::NotImplementedYet] for 2-dimensional Numo arrays
def to_tensor(pic)
  unless pic.is_a?(Numo::NArray) || pic.is_a?(Torch::Tensor)
    raise ArgumentError, "pic should be tensor or Numo::NArray. Got #{pic.class.name}"
  end

  if pic.is_a?(Numo::NArray)
    unless [2, 3].include?(pic.ndim)
      # report ndim, the same attribute the check uses (Numo arrays expose ndim)
      raise ArgumentError, "pic should be 2/3 dimensional. Got #{pic.ndim} dimensions."
    end
    raise Torch::NotImplementedYet if pic.ndim == 2

    img = Torch.from_numo(pic.transpose(2, 0, 1))
    return img.float.div(255)
  end

  pic = pic.float
  pic.unsqueeze!(0).div!(255)
end
|
data/lib/torchvision/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torchvision
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.1.
|
4
|
+
version: 0.1.2
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2020-04-
|
11
|
+
date: 2020-04-29 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-narray
|
@@ -30,14 +30,14 @@ dependencies:
|
|
30
30
|
requirements:
|
31
31
|
- - ">="
|
32
32
|
- !ruby/object:Gem::Version
|
33
|
-
version: 0.2.
|
33
|
+
version: 0.2.4
|
34
34
|
type: :runtime
|
35
35
|
prerelease: false
|
36
36
|
version_requirements: !ruby/object:Gem::Requirement
|
37
37
|
requirements:
|
38
38
|
- - ">="
|
39
39
|
- !ruby/object:Gem::Version
|
40
|
-
version: 0.2.
|
40
|
+
version: 0.2.4
|
41
41
|
- !ruby/object:Gem::Dependency
|
42
42
|
name: bundler
|
43
43
|
requirement: !ruby/object:Gem::Requirement
|
@@ -90,7 +90,16 @@ files:
|
|
90
90
|
- LICENSE.txt
|
91
91
|
- README.md
|
92
92
|
- lib/torchvision.rb
|
93
|
+
- lib/torchvision/datasets/cifar10.rb
|
94
|
+
- lib/torchvision/datasets/cifar100.rb
|
95
|
+
- lib/torchvision/datasets/fashion_mnist.rb
|
96
|
+
- lib/torchvision/datasets/kmnist.rb
|
93
97
|
- lib/torchvision/datasets/mnist.rb
|
98
|
+
- lib/torchvision/datasets/vision_dataset.rb
|
99
|
+
- lib/torchvision/models/basic_block.rb
|
100
|
+
- lib/torchvision/models/bottleneck.rb
|
101
|
+
- lib/torchvision/models/resnet.rb
|
102
|
+
- lib/torchvision/models/resnet18.rb
|
94
103
|
- lib/torchvision/transforms/compose.rb
|
95
104
|
- lib/torchvision/transforms/functional.rb
|
96
105
|
- lib/torchvision/transforms/normalize.rb
|