torchvision 0.1.1 → 0.1.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +25 -3
- data/lib/torchvision.rb +12 -0
- data/lib/torchvision/datasets/cifar10.rb +116 -0
- data/lib/torchvision/datasets/cifar100.rb +41 -0
- data/lib/torchvision/datasets/fashion_mnist.rb +30 -0
- data/lib/torchvision/datasets/kmnist.rb +30 -0
- data/lib/torchvision/datasets/mnist.rb +47 -75
- data/lib/torchvision/datasets/vision_dataset.rb +66 -0
- data/lib/torchvision/models/basic_block.rb +46 -0
- data/lib/torchvision/models/bottleneck.rb +47 -0
- data/lib/torchvision/models/resnet.rb +107 -0
- data/lib/torchvision/models/resnet18.rb +15 -0
- data/lib/torchvision/transforms/functional.rb +23 -6
- data/lib/torchvision/version.rb +1 -1
- metadata +13 -4
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 107e429f990a063e57f6218ee6d4fbed3cff80aa6746f0794798d59e1c13b099
|
4
|
+
data.tar.gz: 556fdc4c413803d415ea5575747ec2239f18b0dff9b39a37a4fd6366adec37a6
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 21b712578516c146888be30bed64a6da6339a42974f5c88fa685278caa9231e4bc4b75e250af3ad37d5450cd205891b31b281cf23fa362456c7ae00998c2736d
|
7
|
+
data.tar.gz: ac86f13e8b5d6a400842ba37b1bf593139360a8df9a19ab4674e71a4327c821a6d5735774185dcbaed368ee735b00bf33d9faedb13d46cd168314da89d900c38
|
data/CHANGELOG.md
CHANGED
data/README.md
CHANGED
@@ -20,11 +20,33 @@ This library follows the [Python API](https://pytorch.org/docs/master/torchvisio
|
|
20
20
|
|
21
21
|
## Datasets
|
22
22
|
|
23
|
-
|
23
|
+
Load a dataset
|
24
24
|
|
25
25
|
```ruby
|
26
|
-
|
27
|
-
|
26
|
+
TorchVision::Datasets::MNIST.new("./data", train: true, download: true)
|
27
|
+
```
|
28
|
+
|
29
|
+
Supported datasets are:
|
30
|
+
|
31
|
+
- CIFAR10
|
32
|
+
- CIFAR100
|
33
|
+
- FashionMNIST
|
34
|
+
- KMNIST
|
35
|
+
- MNIST
|
36
|
+
|
37
|
+
## Transforms
|
38
|
+
|
39
|
+
```ruby
|
40
|
+
TorchVision::Transforms::Compose.new([
|
41
|
+
TorchVision::Transforms::ToTensor.new,
|
42
|
+
TorchVision::Transforms::Normalize.new([0.1307], [0.3081])
|
43
|
+
])
|
44
|
+
```
|
45
|
+
|
46
|
+
## Models
|
47
|
+
|
48
|
+
```ruby
|
49
|
+
TorchVision::Models::ResNet18.new
|
28
50
|
```
|
29
51
|
|
30
52
|
## Disclaimer
|
data/lib/torchvision.rb
CHANGED
@@ -6,13 +6,25 @@ require "torch"
|
|
6
6
|
require "digest"
|
7
7
|
require "fileutils"
|
8
8
|
require "net/http"
|
9
|
+
require "rubygems/package"
|
9
10
|
require "tmpdir"
|
10
11
|
|
11
12
|
# modules
|
12
13
|
require "torchvision/version"
|
13
14
|
|
14
15
|
# datasets
|
16
|
+
require "torchvision/datasets/vision_dataset"
|
17
|
+
require "torchvision/datasets/cifar10"
|
18
|
+
require "torchvision/datasets/cifar100"
|
15
19
|
require "torchvision/datasets/mnist"
|
20
|
+
require "torchvision/datasets/fashion_mnist"
|
21
|
+
require "torchvision/datasets/kmnist"
|
22
|
+
|
23
|
+
# models
|
24
|
+
require "torchvision/models/basic_block"
|
25
|
+
require "torchvision/models/bottleneck"
|
26
|
+
require "torchvision/models/resnet"
|
27
|
+
require "torchvision/models/resnet18"
|
16
28
|
|
17
29
|
# transforms
|
18
30
|
require "torchvision/transforms/compose"
|
@@ -0,0 +1,116 @@
|
|
1
|
+
module TorchVision
  module Datasets
    # CIFAR-10 dataset, loaded from the binary batch files of the official
    # archive.
    # https://www.cs.toronto.edu/~kriz/cifar.html
    #
    # Subclasses can point at a different archive by overriding the private
    # hooks at the bottom of the class (base_folder, url, filename,
    # tgz_sha256, train_list, test_list, multiple_labels?).
    class CIFAR10 < VisionDataset
      # Loads the train or test split into memory.
      #
      # @param root [String] directory that holds (or will hold) the archive
      #   and the extracted batch folder
      # @param train [Boolean] read train_list when true, test_list otherwise
      # @param download [Boolean] fetch and extract the archive first
      # @param transform [#call, nil] applied to each image in #[]
      # @param target_transform [#call, nil] applied to each target in #[]
      # @raise [Error] if any batch file is missing or fails its checksum
      def initialize(root, train: true, download: false, transform: nil, target_transform: nil)
        super(root, transform: transform, target_transform: target_transform)
        @train = train

        self.download if download

        unless _check_integrity
          raise Error, "Dataset not found or corrupted. You can use download: true to download it"
        end

        downloaded_list = @train ? train_list : test_list

        # binary accumulators for the raw label and pixel bytes
        @data = String.new
        @targets = String.new

        downloaded_list.each do |file|
          file_path = File.join(@root, base_folder, file[:filename])
          File.open(file_path, "rb") do |f|
            until f.eof?
              # records with multiple labels carry one extra leading byte, which is skipped
              f.read(1) if multiple_labels?
              @targets << f.read(1)
              @data << f.read(3072)
            end
          end
        end

        @targets = @targets.unpack("C*")
        # TODO switch i to -1 when Numo supports it
        @data = Numo::UInt8.from_binary(@data).reshape(@targets.size, 3, 32, 32)
        # move the size-3 axis last: (n, 3, 32, 32) -> (n, 32, 32, 3)
        @data = @data.transpose(0, 2, 3, 1)
      end

      # @return [Integer] number of samples in the loaded split
      def size
        @data.shape[0]
      end

      # @param index [Integer] sample position
      # @return [Array] [img, target] after the optional transforms
      def [](index)
        # TODO remove trues when Numo supports it
        img, target = @data[index, true, true, true], @targets[index]

        # TODO convert to image
        img = @transform.call(img) if @transform

        target = @target_transform.call(target) if @target_transform

        [img, target]
      end

      # @return [Boolean] true when every train and test batch file exists
      #   under the base folder and matches its expected SHA-256
      def _check_integrity
        (train_list + test_list).each do |fentry|
          fpath = File.join(@root, base_folder, fentry[:filename])
          return false unless check_integrity(fpath, fentry[:sha256])
        end
        true
      end

      # Downloads the archive into @root and extracts it there, unless the
      # batch files are already present and verified.
      def download
        if _check_integrity
          puts "Files already downloaded and verified"
          return
        end

        download_file(url, download_root: @root, filename: filename, sha256: tgz_sha256)

        path = File.join(@root, filename)
        File.open(path, "rb") do |io|
          # Gem::Package is used only for its tar.gz extraction helper
          Gem::Package.new("").extract_tar_gz(io, @root)
        end
      end

      private

      def base_folder
        "cifar-10-batches-bin"
      end

      def url
        "https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz"
      end

      def filename
        "cifar-10-binary.tar.gz"
      end

      def tgz_sha256
        "c4a38c50a1bc5f3a1c5537f2155ab9d68f9f25eb1ed8d9ddda3db29a59bca1dd"
      end

      def train_list
        [
          {filename: "data_batch_1.bin", sha256: "cee916563c9f80d84e3cc88e17fdc0941787f1244f00a67874d45b261883ada5"},
          {filename: "data_batch_2.bin", sha256: "a591ca11fa1708a91ee40f54b3da4784ccd871ecf2137de63f51ada8b3fa57ed"},
          {filename: "data_batch_3.bin", sha256: "bbe8596564c0f86427f876058170b84dac6670ddf06d79402899d93ceea26f67"},
          {filename: "data_batch_4.bin", sha256: "014e562d6e23c72197cc727519169a60359f5eccd8945ad5a09d710285ff4e48"},
          {filename: "data_batch_5.bin", sha256: "755304fc0b379caeae8c14f0dac912fbc7d6cd469eb67a1029a08a39453a9add"},
        ]
      end

      def test_list
        [
          {filename: "test_batch.bin", sha256: "8e2eb146ae340b09e24670f29cabc6326dba54da8789dab6768acf480273f65b"}
        ]
      end

      def multiple_labels?
        false
      end
    end
  end
end
|
@@ -0,0 +1,41 @@
|
|
1
|
+
module TorchVision
  module Datasets
    # CIFAR-100 variant of the CIFAR10 loader: only the archive location,
    # checksums, batch file names and the multiple-label flag are overridden.
    # https://www.cs.toronto.edu/~kriz/cifar.html
    class CIFAR100 < CIFAR10
      private

      # Folder inside the dataset root that contains the batch files.
      def base_folder
        "cifar-100-binary"
      end

      # Location of the archive to download.
      def url
        "https://www.cs.toronto.edu/~kriz/cifar-100-binary.tar.gz"
      end

      # Name the downloaded archive is stored under.
      def filename
        "cifar-100-binary.tar.gz"
      end

      # Expected SHA-256 of the archive.
      def tgz_sha256
        "58a81ae192c23a4be8b1804d68e518ed807d710a4eb253b1f2a199162a40d8ec"
      end

      # Batch files (with checksums) making up the training split.
      def train_list
        [{filename: "train.bin", sha256: "f31298fc616915fa142368359df1c4ca2ae984d6915ca468b998a5ec6aeebf29"}]
      end

      # Batch files (with checksums) making up the test split.
      def test_list
        [{filename: "test.bin", sha256: "d8b1e6b7b3bee4020055f0699b111f60b1af1e262aeb93a0b659061746f8224a"}]
      end

      # Each record carries an extra leading label byte.
      def multiple_labels?
        true
      end
    end
  end
end
|
@@ -0,0 +1,30 @@
|
|
1
|
+
module TorchVision
  module Datasets
    # Fashion-MNIST: an MNIST subclass that only replaces the download list.
    # https://github.com/zalandoresearch/fashion-mnist
    class FashionMNIST < MNIST
      private

      # Archives to fetch, returned as an array of {url:, sha256:} hashes.
      def resources
        [
          [
            "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz",
            "3aede38d61863908ad78613f6a32ed271626dd12800ba2636569512369268a84"
          ],
          [
            "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz",
            "a04f17134ac03560a47e3764e11b92fc97de4d1bfaf8ba1a3aa29af54cc90845"
          ],
          [
            "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz",
            "346e55b948d973a97e58d2351dde16a484bd415d4595297633bb08f03db6a073"
          ],
          [
            "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz",
            "67da17c76eaffca5446c3361aaab5c3cd6d1c2608764d35dfb1850b086bf8dd5"
          ]
        ].map { |address, checksum| {url: address, sha256: checksum} }
      end
    end
  end
end
|
@@ -0,0 +1,30 @@
|
|
1
|
+
module TorchVision
  module Datasets
    # Kuzushiji-MNIST: an MNIST subclass that only replaces the download list.
    # https://github.com/rois-codh/kmnist
    class KMNIST < MNIST
      private

      # Archives to fetch, returned as an array of {url:, sha256:} hashes.
      def resources
        [
          [
            "http://codh.rois.ac.jp/kmnist/dataset/kmnist/train-images-idx3-ubyte.gz",
            "51467d22d8cc72929e2a028a0428f2086b092bb31cfb79c69cc0a90ce135fde4"
          ],
          [
            "http://codh.rois.ac.jp/kmnist/dataset/kmnist/train-labels-idx1-ubyte.gz",
            "e38f9ebcd0f3ebcdec7fc8eabdcdaef93bb0df8ea12bee65224341c8183d8e17"
          ],
          [
            "http://codh.rois.ac.jp/kmnist/dataset/kmnist/t10k-images-idx3-ubyte.gz",
            "edd7a857845ad6bb1d0ba43fe7e794d164fe2dce499a1694695a792adfac43c5"
          ],
          [
            "http://codh.rois.ac.jp/kmnist/dataset/kmnist/t10k-labels-idx1-ubyte.gz",
            "20bb9a0ef54c7db3efc55a92eef5582c109615df22683c380526788f98e42a1c"
          ]
        ].map { |address, checksum| {url: address, sha256: checksum} }
      end
    end
  end
end
|
@@ -1,31 +1,11 @@
|
|
1
1
|
module TorchVision
|
2
2
|
module Datasets
|
3
|
-
class MNIST
|
4
|
-
|
5
|
-
|
6
|
-
|
7
|
-
|
8
|
-
},
|
9
|
-
{
|
10
|
-
url: "http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz",
|
11
|
-
sha256: "3552534a0a558bbed6aed32b30c495cca23d567ec52cac8be1a0730e8010255c"
|
12
|
-
},
|
13
|
-
{
|
14
|
-
url: "http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz",
|
15
|
-
sha256: "8d422c7b0a1c1c79245a5bcf07fe86e33eeafee792b84584aec276f5a2dbc4e6"
|
16
|
-
},
|
17
|
-
{
|
18
|
-
url: "http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz",
|
19
|
-
sha256: "f7ae60f92e00ec6debd23a6088c31dbd2371eca3ffa0defaefb259924204aec6"
|
20
|
-
}
|
21
|
-
]
|
22
|
-
TRAINING_FILE = "training.pt"
|
23
|
-
TEST_FILE = "test.pt"
|
24
|
-
|
25
|
-
def initialize(root, train: true, download: false, transform: nil)
|
26
|
-
@root = root
|
3
|
+
class MNIST < VisionDataset
|
4
|
+
# http://yann.lecun.com/exdb/mnist/
|
5
|
+
|
6
|
+
def initialize(root, train: true, download: false, transform: nil, target_transform: nil)
|
7
|
+
super(root, transform: transform, target_transform: target_transform)
|
27
8
|
@train = train
|
28
|
-
@transform = transform
|
29
9
|
|
30
10
|
self.download if download
|
31
11
|
|
@@ -33,34 +13,36 @@ module TorchVision
|
|
33
13
|
raise Error, "Dataset not found. You can use download: true to download it"
|
34
14
|
end
|
35
15
|
|
36
|
-
data_file = @train ?
|
16
|
+
data_file = @train ? training_file : test_file
|
37
17
|
@data, @targets = Torch.load(File.join(processed_folder, data_file))
|
38
18
|
end
|
39
19
|
|
40
20
|
def size
|
41
|
-
@data.size
|
21
|
+
@data.size(0)
|
42
22
|
end
|
43
23
|
|
44
24
|
def [](index)
|
45
|
-
img = @data[index]
|
25
|
+
img, target = @data[index], @targets[index].item
|
26
|
+
|
27
|
+
# TODO convert to image
|
46
28
|
img = @transform.call(img) if @transform
|
47
29
|
|
48
|
-
target = @
|
30
|
+
target = @target_transform.call(target) if @target_transform
|
49
31
|
|
50
32
|
[img, target]
|
51
33
|
end
|
52
34
|
|
53
35
|
def raw_folder
|
54
|
-
File.join(@root, "
|
36
|
+
File.join(@root, self.class.name.split("::").last, "raw")
|
55
37
|
end
|
56
38
|
|
57
39
|
def processed_folder
|
58
|
-
File.join(@root, "
|
40
|
+
File.join(@root, self.class.name.split("::").last, "processed")
|
59
41
|
end
|
60
42
|
|
61
43
|
def check_exists
|
62
|
-
File.exist?(File.join(processed_folder,
|
63
|
-
File.exist?(File.join(processed_folder,
|
44
|
+
File.exist?(File.join(processed_folder, training_file)) &&
|
45
|
+
File.exist?(File.join(processed_folder, test_file))
|
64
46
|
end
|
65
47
|
|
66
48
|
def download
|
@@ -69,7 +51,7 @@ module TorchVision
|
|
69
51
|
FileUtils.mkdir_p(raw_folder)
|
70
52
|
FileUtils.mkdir_p(processed_folder)
|
71
53
|
|
72
|
-
|
54
|
+
resources.each do |resource|
|
73
55
|
filename = resource[:url].split("/").last
|
74
56
|
download_file(resource[:url], download_root: raw_folder, filename: filename, sha256: resource[:sha256])
|
75
57
|
end
|
@@ -85,14 +67,43 @@ module TorchVision
|
|
85
67
|
unpack_mnist("t10k-labels-idx1-ubyte", 8, [10000])
|
86
68
|
]
|
87
69
|
|
88
|
-
Torch.save(training_set, File.join(processed_folder,
|
89
|
-
Torch.save(test_set, File.join(processed_folder,
|
70
|
+
Torch.save(training_set, File.join(processed_folder, training_file))
|
71
|
+
Torch.save(test_set, File.join(processed_folder, test_file))
|
90
72
|
|
91
73
|
puts "Done!"
|
92
74
|
end
|
93
75
|
|
94
76
|
private
|
95
77
|
|
78
|
+
def resources
|
79
|
+
[
|
80
|
+
{
|
81
|
+
url: "http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz",
|
82
|
+
sha256: "440fcabf73cc546fa21475e81ea370265605f56be210a4024d2ca8f203523609"
|
83
|
+
},
|
84
|
+
{
|
85
|
+
url: "http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz",
|
86
|
+
sha256: "3552534a0a558bbed6aed32b30c495cca23d567ec52cac8be1a0730e8010255c"
|
87
|
+
},
|
88
|
+
{
|
89
|
+
url: "http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz",
|
90
|
+
sha256: "8d422c7b0a1c1c79245a5bcf07fe86e33eeafee792b84584aec276f5a2dbc4e6"
|
91
|
+
},
|
92
|
+
{
|
93
|
+
url: "http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz",
|
94
|
+
sha256: "f7ae60f92e00ec6debd23a6088c31dbd2371eca3ffa0defaefb259924204aec6"
|
95
|
+
}
|
96
|
+
]
|
97
|
+
end
|
98
|
+
|
99
|
+
def training_file
|
100
|
+
"training.pt"
|
101
|
+
end
|
102
|
+
|
103
|
+
def test_file
|
104
|
+
"test.pt"
|
105
|
+
end
|
106
|
+
|
96
107
|
def unpack_mnist(path, offset, shape)
|
97
108
|
path = File.join(raw_folder, "#{path}.gz")
|
98
109
|
File.open(path, "rb") do |f|
|
@@ -101,45 +112,6 @@ module TorchVision
|
|
101
112
|
Torch.tensor(Numo::UInt8.from_string(gz.read, shape))
|
102
113
|
end
|
103
114
|
end
|
104
|
-
|
105
|
-
def download_file(url, download_root:, filename:, sha256:)
|
106
|
-
FileUtils.mkdir_p(download_root)
|
107
|
-
|
108
|
-
dest = File.join(download_root, filename)
|
109
|
-
return dest if File.exist?(dest)
|
110
|
-
|
111
|
-
temp_path = "#{Dir.tmpdir}/#{Time.now.to_f}" # TODO better name
|
112
|
-
|
113
|
-
digest = Digest::SHA256.new
|
114
|
-
|
115
|
-
uri = URI(url)
|
116
|
-
|
117
|
-
# Net::HTTP automatically adds Accept-Encoding for compression
|
118
|
-
# of response bodies and automatically decompresses gzip
|
119
|
-
# and deflateresponses unless a Range header was sent.
|
120
|
-
# https://ruby-doc.org/stdlib-2.6.4/libdoc/net/http/rdoc/Net/HTTP.html
|
121
|
-
Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http|
|
122
|
-
request = Net::HTTP::Get.new(uri)
|
123
|
-
|
124
|
-
puts "Downloading #{url}..."
|
125
|
-
File.open(temp_path, "wb") do |f|
|
126
|
-
http.request(request) do |response|
|
127
|
-
response.read_body do |chunk|
|
128
|
-
f.write(chunk)
|
129
|
-
digest.update(chunk)
|
130
|
-
end
|
131
|
-
end
|
132
|
-
end
|
133
|
-
end
|
134
|
-
|
135
|
-
if digest.hexdigest != sha256
|
136
|
-
raise Error, "Bad hash: #{digest.hexdigest}"
|
137
|
-
end
|
138
|
-
|
139
|
-
FileUtils.mv(temp_path, dest)
|
140
|
-
|
141
|
-
dest
|
142
|
-
end
|
143
115
|
end
|
144
116
|
end
|
145
117
|
end
|
@@ -0,0 +1,66 @@
|
|
1
|
+
module TorchVision
  module Datasets
    # Common base for the dataset classes: stores the dataset root and the
    # optional transform callables, and gives subclasses private helpers for
    # downloading a file and verifying its SHA-256.
    #
    # TODO inherit Torch::Utils::Data::Dataset
    class VisionDataset
      # @param root [String] dataset root directory (stored as @root)
      # @param transforms [Object, nil] combined transform (stored as @transforms)
      # @param transform [Object, nil] input transform (stored as @transform)
      # @param target_transform [Object, nil] target transform (stored as @target_transform)
      # @raise [ArgumentError] if transforms is given together with
      #   transform and/or target_transform
      def initialize(root, transforms: nil, transform: nil, target_transform: nil)
        @root = root

        has_transforms = !transforms.nil?
        has_separate_transform = !transform.nil? || !target_transform.nil?
        if has_transforms && has_separate_transform
          raise ArgumentError, "Only transforms or transform/target_transform can be passed as argument"
        end

        @transform = transform
        @target_transform = target_transform

        # TODO once StandardTransform exists, wrap the separate transforms:
        # transforms = StandardTransform.new(transform, target_transform) if has_separate_transform
        @transforms = transforms
      end

      private

      # Downloads url to download_root/filename unless that file already
      # exists, verifying the SHA-256 before the file is moved into place.
      #
      # @param url [String] source URL
      # @param download_root [String] destination directory (created if missing)
      # @param filename [String] destination file name
      # @param sha256 [String] expected hex digest of the downloaded bytes
      # @return [String] path of the destination file
      # @raise [Error] on a non-success HTTP status or a checksum mismatch
      def download_file(url, download_root:, filename:, sha256:)
        FileUtils.mkdir_p(download_root)

        dest = File.join(download_root, filename)
        return dest if File.exist?(dest)

        uri = URI(url)

        # Download into a private temporary directory: the name cannot collide
        # with a concurrent download, and the block form removes the partial
        # file when the request or the checksum verification fails.
        Dir.mktmpdir("torchvision") do |tmp_dir|
          temp_path = File.join(tmp_dir, "download")

          # Net::HTTP automatically adds Accept-Encoding for compression
          # of response bodies and automatically decompresses gzip
          # and deflate responses unless a Range header was sent.
          # https://ruby-doc.org/stdlib-2.6.4/libdoc/net/http/rdoc/Net/HTTP.html
          Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http|
            request = Net::HTTP::Get.new(uri)

            puts "Downloading #{url}..."
            File.open(temp_path, "wb") do |f|
              http.request(request) do |response|
                # fail with the real cause instead of hashing an error page
                unless response.is_a?(Net::HTTPSuccess)
                  raise Error, "Download failed (HTTP #{response.code}): #{url}"
                end

                response.read_body do |chunk|
                  f.write(chunk)
                end
              end
            end
          end

          unless check_integrity(temp_path, sha256)
            raise Error, "Bad hash"
          end

          FileUtils.mv(temp_path, dest)
        end

        dest
      end

      # @param path [String] file to verify
      # @param sha256 [String] expected hex digest
      # @return [Boolean] true when the file exists and its SHA-256 matches
      def check_integrity(path, sha256)
        File.exist?(path) && Digest::SHA256.file(path).hexdigest == sha256
      end
    end
  end
end
|
@@ -0,0 +1,46 @@
|
|
1
|
+
module TorchVision
  module Models
    # Residual block made of two 3x3 convolutions, each followed by a norm
    # layer, with the (optionally downsampled) input added back before the
    # final ReLU.
    class BasicBlock < Torch::NN::Module
      # @param inplanes [Integer] input channels of the first convolution
      # @param planes [Integer] output channels of both convolutions
      # @param stride [Integer] stride of the first convolution
      # @param downsample [#call, nil] applied to the input before the addition
      # @param groups [Integer] must be 1
      # @param base_width [Integer] must be 64
      # @param dilation [Integer] must not exceed 1
      # @param norm_layer [Class, nil] norm layer class, Torch::NN::BatchNorm2d when nil
      # @raise [ArgumentError] for unsupported groups/base_width
      # @raise [NotImplementedError] for dilation > 1
      def initialize(inplanes, planes, stride: 1, downsample: nil, groups: 1, base_width: 64, dilation: 1, norm_layer: nil)
        super()
        norm_layer ||= Torch::NN::BatchNorm2d

        unless groups == 1 && base_width == 64
          raise ArgumentError, "BasicBlock only supports groups=1 and base_width=64"
        end
        raise NotImplementedError, "Dilation > 1 not supported in BasicBlock" if dilation > 1

        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        @conv1 = Torch::NN::Conv2d.new(
          inplanes, planes, 3,
          stride: stride, padding: 1, groups: 1, bias: false, dilation: 1
        )
        @bn1 = norm_layer.new(planes)
        @relu = Torch::NN::ReLU.new(inplace: true)
        @conv2 = Torch::NN::Conv2d.new(
          planes, planes, 3,
          stride: 1, padding: 1, groups: 1, bias: false, dilation: 1
        )
        @bn2 = norm_layer.new(planes)
        @downsample = downsample
        @stride = stride
      end

      # Runs conv -> norm -> relu -> conv -> norm, adds the shortcut and
      # applies the final ReLU.
      def forward(x)
        residual = @relu.call(@bn1.call(@conv1.call(x)))
        residual = @bn2.call(@conv2.call(residual))

        shortcut = @downsample ? @downsample.call(x) : x

        @relu.call(residual + shortcut)
      end

      # Ratio between the block's output channels and `planes`.
      def self.expansion
        1
      end
    end
  end
end
|
@@ -0,0 +1,47 @@
|
|
1
|
+
module TorchVision
  module Models
    # Residual block made of a 1x1, a 3x3 and a 1x1 convolution, each followed
    # by a norm layer, with the (optionally downsampled) input added back
    # before the final ReLU.
    class Bottleneck < Torch::NN::Module
      # @param inplanes [Integer] input channels of the first convolution
      # @param planes [Integer] base channel count; the block outputs
      #   planes * expansion channels
      # @param stride [Integer] stride of the 3x3 convolution
      # @param downsample [#call, nil] applied to the input before the addition
      # @param groups [Integer] groups of the 3x3 convolution
      # @param base_width [Integer] scales the inner width relative to 64
      # @param dilation [Integer] dilation and padding of the 3x3 convolution
      # @param norm_layer [Class, nil] norm layer class, Torch::NN::BatchNorm2d when nil
      def initialize(inplanes, planes, stride: 1, downsample: nil, groups: 1, base_width: 64, dilation: 1, norm_layer: nil)
        super()
        norm_layer ||= Torch::NN::BatchNorm2d

        inner = (planes * (base_width / 64.0)).to_i * groups
        outer = planes * self.class.expansion

        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        @conv1 = Torch::NN::Conv2d.new(inplanes, inner, 1, stride: 1, bias: false)
        @bn1 = norm_layer.new(inner)
        @conv2 = Torch::NN::Conv2d.new(
          inner, inner, 3,
          stride: stride, padding: dilation, groups: groups, bias: false, dilation: dilation
        )
        @bn2 = norm_layer.new(inner)
        @conv3 = Torch::NN::Conv2d.new(inner, outer, 1, stride: 1, bias: false)
        @bn3 = norm_layer.new(outer)
        @relu = Torch::NN::ReLU.new(inplace: true)
        @downsample = downsample
        @stride = stride
      end

      # Runs the three conv/norm stages (ReLU after the first two), adds the
      # shortcut and applies the final ReLU.
      def forward(x)
        residual = @relu.call(@bn1.call(@conv1.call(x)))
        residual = @relu.call(@bn2.call(@conv2.call(residual)))
        residual = @bn3.call(@conv3.call(residual))

        shortcut = @downsample ? @downsample.call(x) : x

        @relu.call(residual + shortcut)
      end

      # Ratio between the block's output channels and `planes`.
      def self.expansion
        4
      end
    end
  end
end
|
@@ -0,0 +1,107 @@
|
|
1
|
+
module TorchVision
  module Models
    # Residual network: a convolutional stem, four stages built from `block`
    # instances, adaptive average pooling and a linear classifier.
    class ResNet < Torch::NN::Module
      # @param block [Class] block class; must respond to .expansion and accept
      #   the keyword arguments passed in #_make_layer
      # @param layers [Array<Integer>] number of blocks in each of the four stages
      # @param num_classes [Integer] output size of the final linear layer
      # @param zero_init_residual [Boolean] zero the weight of the last norm
      #   layer of every Bottleneck/BasicBlock
      # @param groups [Integer] forwarded to every block
      # @param width_per_group [Integer] forwarded to every block as base_width
      # @param replace_stride_with_dilation [Array<Boolean>, nil] one flag per
      #   stage 2..4
      # @param norm_layer [Class, nil] norm layer class, Torch::NN::BatchNorm2d when nil
      # @raise [ArgumentError] if replace_stride_with_dilation does not have 3 elements
      def initialize(block, layers, num_classes=1000, zero_init_residual: false,
        groups: 1, width_per_group: 64, replace_stride_with_dilation: nil, norm_layer: nil)

        super()
        norm_layer ||= Torch::NN::BatchNorm2d
        @norm_layer = norm_layer

        @inplanes = 64
        @dilation = 1

        # each element in the tuple indicates if we should replace
        # the 2x2 stride with a dilated convolution instead
        dilate_flags = replace_stride_with_dilation.nil? ? [false, false, false] : replace_stride_with_dilation
        unless dilate_flags.length == 3
          raise ArgumentError, "replace_stride_with_dilation should be nil or a 3-element tuple, got #{dilate_flags}"
        end

        @groups = groups
        @base_width = width_per_group

        # stem
        @conv1 = Torch::NN::Conv2d.new(3, @inplanes, 7, stride: 2, padding: 3, bias: false)
        @bn1 = norm_layer.new(@inplanes)
        @relu = Torch::NN::ReLU.new(inplace: true)
        @maxpool = Torch::NN::MaxPool2d.new(3, stride: 2, padding: 1)

        # residual stages
        @layer1 = _make_layer(block, 64, layers[0])
        @layer2 = _make_layer(block, 128, layers[1], stride: 2, dilate: dilate_flags[0])
        @layer3 = _make_layer(block, 256, layers[2], stride: 2, dilate: dilate_flags[1])
        @layer4 = _make_layer(block, 512, layers[3], stride: 2, dilate: dilate_flags[2])

        # head
        @avgpool = Torch::NN::AdaptiveAvgPool2d.new([1, 1])
        @fc = Torch::NN::Linear.new(512 * block.expansion, num_classes)

        modules.each do |m|
          if m.is_a?(Torch::NN::Conv2d)
            Torch::NN::Init.kaiming_normal!(m.weight, mode: "fan_out", nonlinearity: "relu")
          elsif m.is_a?(Torch::NN::BatchNorm2d) || m.is_a?(Torch::NN::GroupNorm)
            Torch::NN::Init.constant!(m.weight, 1)
            Torch::NN::Init.constant!(m.bias, 0)
          end
        end

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        return unless zero_init_residual

        modules.each do |m|
          if m.is_a?(Bottleneck)
            Torch::NN::Init.constant!(m.bn3.weight, 0)
          elsif m.is_a?(BasicBlock)
            Torch::NN::Init.constant!(m.bn2.weight, 0)
          end
        end
      end

      # Builds one stage of `blocks` blocks and advances @inplanes (and
      # @dilation when `dilate` is set) for the next stage.
      #
      # @return [Torch::NN::Sequential]
      def _make_layer(block, planes, blocks, stride: 1, dilate: false)
        norm_layer = @norm_layer
        first_dilation = @dilation
        out_planes = planes * block.expansion

        if dilate
          # trade the stride for dilation
          @dilation *= stride
          stride = 1
        end

        # the shortcut needs a projection whenever resolution or width changes
        shortcut = nil
        unless stride == 1 && @inplanes == out_planes
          shortcut = Torch::NN::Sequential.new(
            Torch::NN::Conv2d.new(@inplanes, out_planes, 1, stride: stride, bias: false),
            norm_layer.new(out_planes)
          )
        end

        stage = [
          block.new(
            @inplanes, planes,
            stride: stride, downsample: shortcut, groups: @groups,
            base_width: @base_width, dilation: first_dilation, norm_layer: norm_layer
          )
        ]
        @inplanes = out_planes
        1.upto(blocks - 1) do
          stage << block.new(
            @inplanes, planes,
            groups: @groups, base_width: @base_width, dilation: @dilation, norm_layer: norm_layer
          )
        end

        Torch::NN::Sequential.new(*stage)
      end

      # Stem, four stages, pooling, flatten from dimension 1, classifier.
      def _forward_impl(x)
        pipeline = [@conv1, @bn1, @relu, @maxpool, @layer1, @layer2, @layer3, @layer4, @avgpool]
        features = pipeline.reduce(x) { |acc, stage| stage.call(acc) }

        @fc.call(Torch.flatten(features, 1))
      end

      def forward(x)
        _forward_impl(x)
      end
    end
  end
end
|
@@ -0,0 +1,15 @@
|
|
1
|
+
module TorchVision
  module Models
    # Factory for a ResNet built from BasicBlock with two blocks per stage.
    module ResNet18
      # @param pretrained [Boolean] load the published state dict into the model
      # @param kwargs [Hash] forwarded to ResNet.new
      # @return [ResNet]
      def self.new(pretrained: false, **kwargs)
        ResNet.new(BasicBlock, [2, 2, 2, 2], **kwargs).tap do |net|
          next unless pretrained

          weights_url = "https://download.pytorch.org/models/resnet18-5c106cde.pth"
          net.load_state_dict(Torch::Hub.load_state_dict_from_url(weights_url))
        end
      end
    end
  end
end
|
@@ -22,18 +22,35 @@ module TorchVision
|
|
22
22
|
if std.to_a.any? { |v| v == 0 }
|
23
23
|
raise ArgumentError, "std evaluated to zero after conversion to #{dtype}, leading to division by zero."
|
24
24
|
end
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
25
|
+
if mean.ndim == 1
|
26
|
+
mean = mean[0...mean.size(0), nil, nil]
|
27
|
+
end
|
28
|
+
if std.ndim == 1
|
29
|
+
std = std[0...std.size(0), nil, nil]
|
30
|
+
end
|
31
31
|
tensor.sub!(mean).div!(std)
|
32
32
|
tensor
|
33
33
|
end
|
34
34
|
|
35
35
|
# TODO improve
|
36
36
|
# Converts a Numo::NArray or Torch::Tensor into a float tensor divided by 255.
#
# A 3-dimensional Numo::NArray has its last axis moved to the front before
# conversion; a tensor gets a new leading dimension instead.
#
# @param pic [Numo::NArray, Torch::Tensor]
# @return [Torch::Tensor]
# @raise [ArgumentError] for any other class, or an NArray that is not 2/3 dimensional
# @raise [Torch::NotImplementedYet] for a 2-dimensional NArray
def to_tensor(pic)
  if !pic.is_a?(Numo::NArray) && !pic.is_a?(Torch::Tensor)
    raise ArgumentError, "pic should be tensor or Numo::NArray. Got #{pic.class.name}"
  end

  if pic.is_a?(Numo::NArray) && ![2, 3].include?(pic.ndim)
    # report the rank with ndim, the same accessor the guard above uses
    raise ArgumentError, "pic should be 2/3 dimensional. Got #{pic.ndim} dimensions."
  end

  if pic.is_a?(Numo::NArray)
    if pic.ndim == 2
      raise Torch::NotImplementedYet
    end

    # move the last axis to the front: (a, b, c) -> (c, a, b)
    img = Torch.from_numo(pic.transpose(2, 0, 1))
    return img.float.div(255)
  end

  pic = pic.float
  pic.unsqueeze!(0).div!(255)
end
|
data/lib/torchvision/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torchvision
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.1.
|
4
|
+
version: 0.1.2
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2020-04-
|
11
|
+
date: 2020-04-29 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: numo-narray
|
@@ -30,14 +30,14 @@ dependencies:
|
|
30
30
|
requirements:
|
31
31
|
- - ">="
|
32
32
|
- !ruby/object:Gem::Version
|
33
|
-
version: 0.2.
|
33
|
+
version: 0.2.4
|
34
34
|
type: :runtime
|
35
35
|
prerelease: false
|
36
36
|
version_requirements: !ruby/object:Gem::Requirement
|
37
37
|
requirements:
|
38
38
|
- - ">="
|
39
39
|
- !ruby/object:Gem::Version
|
40
|
-
version: 0.2.
|
40
|
+
version: 0.2.4
|
41
41
|
- !ruby/object:Gem::Dependency
|
42
42
|
name: bundler
|
43
43
|
requirement: !ruby/object:Gem::Requirement
|
@@ -90,7 +90,16 @@ files:
|
|
90
90
|
- LICENSE.txt
|
91
91
|
- README.md
|
92
92
|
- lib/torchvision.rb
|
93
|
+
- lib/torchvision/datasets/cifar10.rb
|
94
|
+
- lib/torchvision/datasets/cifar100.rb
|
95
|
+
- lib/torchvision/datasets/fashion_mnist.rb
|
96
|
+
- lib/torchvision/datasets/kmnist.rb
|
93
97
|
- lib/torchvision/datasets/mnist.rb
|
98
|
+
- lib/torchvision/datasets/vision_dataset.rb
|
99
|
+
- lib/torchvision/models/basic_block.rb
|
100
|
+
- lib/torchvision/models/bottleneck.rb
|
101
|
+
- lib/torchvision/models/resnet.rb
|
102
|
+
- lib/torchvision/models/resnet18.rb
|
94
103
|
- lib/torchvision/transforms/compose.rb
|
95
104
|
- lib/torchvision/transforms/functional.rb
|
96
105
|
- lib/torchvision/transforms/normalize.rb
|