torchvision 0.1.0 → 0.2.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +35 -0
- data/LICENSE.txt +1 -1
- data/README.md +133 -5
- data/lib/torchvision.rb +40 -1
- data/lib/torchvision/datasets/cifar10.rb +117 -0
- data/lib/torchvision/datasets/cifar100.rb +41 -0
- data/lib/torchvision/datasets/dataset_folder.rb +91 -0
- data/lib/torchvision/datasets/fashion_mnist.rb +30 -0
- data/lib/torchvision/datasets/image_folder.rb +12 -0
- data/lib/torchvision/datasets/kmnist.rb +30 -0
- data/lib/torchvision/datasets/mnist.rb +47 -76
- data/lib/torchvision/datasets/vision_dataset.rb +67 -0
- data/lib/torchvision/models/alexnet.rb +42 -0
- data/lib/torchvision/models/basic_block.rb +46 -0
- data/lib/torchvision/models/bottleneck.rb +47 -0
- data/lib/torchvision/models/resnet.rb +129 -0
- data/lib/torchvision/models/resnet101.rb +9 -0
- data/lib/torchvision/models/resnet152.rb +9 -0
- data/lib/torchvision/models/resnet18.rb +9 -0
- data/lib/torchvision/models/resnet34.rb +9 -0
- data/lib/torchvision/models/resnet50.rb +9 -0
- data/lib/torchvision/models/resnext101_32x8d.rb +11 -0
- data/lib/torchvision/models/resnext50_32x4d.rb +11 -0
- data/lib/torchvision/models/vgg.rb +93 -0
- data/lib/torchvision/models/vgg11.rb +9 -0
- data/lib/torchvision/models/vgg11_bn.rb +9 -0
- data/lib/torchvision/models/vgg13.rb +9 -0
- data/lib/torchvision/models/vgg13_bn.rb +9 -0
- data/lib/torchvision/models/vgg16.rb +9 -0
- data/lib/torchvision/models/vgg16_bn.rb +9 -0
- data/lib/torchvision/models/vgg19.rb +9 -0
- data/lib/torchvision/models/vgg19_bn.rb +9 -0
- data/lib/torchvision/models/wide_resnet101_2.rb +10 -0
- data/lib/torchvision/models/wide_resnet50_2.rb +10 -0
- data/lib/torchvision/transforms/center_crop.rb +13 -0
- data/lib/torchvision/transforms/compose.rb +2 -2
- data/lib/torchvision/transforms/functional.rb +142 -7
- data/lib/torchvision/transforms/normalize.rb +2 -2
- data/lib/torchvision/transforms/random_horizontal_flip.rb +18 -0
- data/lib/torchvision/transforms/random_resized_crop.rb +70 -0
- data/lib/torchvision/transforms/random_vertical_flip.rb +18 -0
- data/lib/torchvision/transforms/resize.rb +13 -0
- data/lib/torchvision/transforms/to_tensor.rb +2 -2
- data/lib/torchvision/utils.rb +120 -0
- data/lib/torchvision/version.rb +1 -1
- metadata +50 -57
@@ -0,0 +1,91 @@
|
|
1
|
+
module TorchVision
  module Datasets
    # Generic dataset over a directory tree laid out as
    #
    #   root/<class_name>/<any/nested/path>/<file>
    #
    # Each immediate subdirectory of +root+ is one class (sorted by name, indexed
    # from 0). Every accepted regular file below it, at any depth, becomes a
    # sample labelled with that class index. Files are opened with
    # Vips::Image.new_from_file when the sample is indexed.
    class DatasetFolder < VisionDataset
      attr_reader :classes

      # @param root [String] dataset root directory
      # @param extensions [Array<String>, String, nil] accepted file suffixes,
      #   matched case-insensitively against the path
      # @param transform [#call, nil] applied to each loaded sample
      # @param target_transform [#call, nil] applied to each class index
      # @param is_valid_file [#call, nil] predicate receiving the full file path
      # @raise [ArgumentError] if both or neither of +extensions+ and
      #   +is_valid_file+ are given
      # @raise [RuntimeError] if no matching files are found
      def initialize(root, extensions: nil, transform: nil, target_transform: nil, is_valid_file: nil)
        super(root, transform: transform, target_transform: target_transform)
        classes, class_to_idx = find_classes(@root)
        samples = make_dataset(@root, class_to_idx, extensions, is_valid_file)
        if samples.empty?
          msg = "Found 0 files in subfolders of: #{@root}\n"
          unless extensions.nil?
            msg += "Supported extensions are: #{Array(extensions).join(",")}"
          end
          raise RuntimeError, msg
        end

        @loader = lambda do |path|
          Vips::Image.new_from_file(path)
        end
        @extensions = extensions

        @classes = classes
        @class_to_idx = class_to_idx
        @samples = samples
        @targets = samples.map { |s| s[1] }
      end

      # Loads sample +index+ and applies the configured transforms.
      #
      # @return [Array] two elements: the (transformed) image and the (transformed) class index
      def [](index)
        path, target = @samples[index]
        sample = @loader.call(path)
        if @transform
          sample = @transform.call(sample)
        end
        if @target_transform
          target = @target_transform.call(target)
        end

        [sample, target]
      end

      # @return [Integer] number of samples
      def size
        @samples.size
      end

      private

      # Class names are the immediate subdirectories of +dir+, sorted.
      #
      # @return [Array] two elements: the sorted class names and a name => index hash
      def find_classes(dir)
        classes = Dir.children(dir).select { |d| File.directory?(File.join(dir, d)) }
        classes.sort!
        class_to_idx = classes.map.with_index.to_h
        [classes, class_to_idx]
      end

      # True when +filename+ ends with any of +extensions+ (case-insensitive).
      # A single String extension is accepted as well as an Array.
      def has_file_allowed_extension(filename, extensions)
        filename = filename.downcase
        Array(extensions).any? { |ext| filename.end_with?(ext.downcase) }
      end

      # Collects [path, class_index] pairs for every accepted regular file below
      # each class directory, in sorted order.
      #
      # @raise [ArgumentError] unless exactly one of +extensions+ / +is_valid_file+ is given
      def make_dataset(directory, class_to_idx, extensions, is_valid_file)
        directory = File.expand_path(directory)
        both_none = extensions.nil? && is_valid_file.nil?
        both_something = !extensions.nil? && !is_valid_file.nil?
        if both_none || both_something
          raise ArgumentError, "Both extensions and is_valid_file cannot be None or not None at the same time"
        end
        unless extensions.nil?
          is_valid_file = lambda do |x|
            has_file_allowed_extension(x, extensions)
          end
        end

        instances = []
        class_to_idx.keys.sort.each do |target_class|
          class_index = class_to_idx[target_class]
          target_dir = File.join(directory, target_class)
          next unless File.directory?(target_dir)

          # "**/*" descends into nested subdirectories; a bare "**" behaves like
          # "*" and would only list the top level. Directories are skipped so
          # only files reach the validity check and the image loader.
          Dir.glob("**/*", base: target_dir).sort.each do |fname|
            path = File.join(target_dir, fname)
            next unless File.file?(path)

            instances << [path, class_index] if is_valid_file.call(path)
          end
        end
        instances
      end
    end
  end
end
|
@@ -0,0 +1,30 @@
|
|
1
|
+
module TorchVision
  module Datasets
    # Fashion-MNIST: reuses all of MNIST's download/unpack/indexing logic and
    # only swaps in a different set of source archives.
    # https://github.com/zalandoresearch/fashion-mnist
    class FashionMNIST < MNIST
      private

      # Source archives with their expected SHA-256 digests, in the order
      # train images, train labels, test images, test labels.
      def resources
        base_url = "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com"
        archives = [
          ["train-images-idx3-ubyte.gz", "3aede38d61863908ad78613f6a32ed271626dd12800ba2636569512369268a84"],
          ["train-labels-idx1-ubyte.gz", "a04f17134ac03560a47e3764e11b92fc97de4d1bfaf8ba1a3aa29af54cc90845"],
          ["t10k-images-idx3-ubyte.gz", "346e55b948d973a97e58d2351dde16a484bd415d4595297633bb08f03db6a073"],
          ["t10k-labels-idx1-ubyte.gz", "67da17c76eaffca5446c3361aaab5c3cd6d1c2608764d35dfb1850b086bf8dd5"]
        ]
        archives.map { |file, digest| {url: "#{base_url}/#{file}", sha256: digest} }
      end
    end
  end
end
|
@@ -0,0 +1,12 @@
|
|
1
|
+
module TorchVision
  module Datasets
    # DatasetFolder specialised for image files. By default it accepts the
    # common image suffixes in IMG_EXTENSIONS. Callers may instead supply an
    # +is_valid_file+ predicate to choose files themselves.
    class ImageFolder < DatasetFolder
      IMG_EXTENSIONS = [".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp"].freeze

      # @param root [String] dataset root directory
      # @param transform [#call, nil] applied to each loaded image
      # @param target_transform [#call, nil] applied to each class index
      # @param is_valid_file [#call, nil] custom predicate on the full file path;
      #   when given, it replaces the IMG_EXTENSIONS filter
      def initialize(root, transform: nil, target_transform: nil, is_valid_file: nil)
        # DatasetFolder rejects extensions and is_valid_file being given together.
        # Apply the default extension filter only when no custom predicate is supplied.
        extensions = is_valid_file.nil? ? IMG_EXTENSIONS : nil
        super(root, extensions: extensions, transform: transform, target_transform: target_transform, is_valid_file: is_valid_file)
        @imgs = @samples
      end
    end
  end
end
|
@@ -0,0 +1,30 @@
|
|
1
|
+
module TorchVision
  module Datasets
    # Kuzushiji-MNIST: reuses all of MNIST's download/unpack/indexing logic and
    # only swaps in a different set of source archives.
    # https://github.com/rois-codh/kmnist
    class KMNIST < MNIST
      private

      # Source archives with their expected SHA-256 digests, in the order
      # train images, train labels, test images, test labels.
      def resources
        base_url = "http://codh.rois.ac.jp/kmnist/dataset/kmnist"
        archives = [
          ["train-images-idx3-ubyte.gz", "51467d22d8cc72929e2a028a0428f2086b092bb31cfb79c69cc0a90ce135fde4"],
          ["train-labels-idx1-ubyte.gz", "e38f9ebcd0f3ebcdec7fc8eabdcdaef93bb0df8ea12bee65224341c8183d8e17"],
          ["t10k-images-idx3-ubyte.gz", "edd7a857845ad6bb1d0ba43fe7e794d164fe2dce499a1694695a792adfac43c5"],
          ["t10k-labels-idx1-ubyte.gz", "20bb9a0ef54c7db3efc55a92eef5582c109615df22683c380526788f98e42a1c"]
        ]
        archives.map { |file, digest| {url: "#{base_url}/#{file}", sha256: digest} }
      end
    end
  end
end
|
@@ -1,31 +1,10 @@
|
|
1
1
|
module TorchVision
|
2
2
|
module Datasets
|
3
|
-
class MNIST
|
4
|
-
|
5
|
-
|
6
|
-
|
7
|
-
sha256: "440fcabf73cc546fa21475e81ea370265605f56be210a4024d2ca8f203523609"
|
8
|
-
},
|
9
|
-
{
|
10
|
-
url: "http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz",
|
11
|
-
sha256: "3552534a0a558bbed6aed32b30c495cca23d567ec52cac8be1a0730e8010255c"
|
12
|
-
},
|
13
|
-
{
|
14
|
-
url: "http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz",
|
15
|
-
sha256: "8d422c7b0a1c1c79245a5bcf07fe86e33eeafee792b84584aec276f5a2dbc4e6"
|
16
|
-
},
|
17
|
-
{
|
18
|
-
url: "http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz",
|
19
|
-
sha256: "f7ae60f92e00ec6debd23a6088c31dbd2371eca3ffa0defaefb259924204aec6"
|
20
|
-
}
|
21
|
-
]
|
22
|
-
TRAINING_FILE = "training.pt"
|
23
|
-
TEST_FILE = "test.pt"
|
24
|
-
|
25
|
-
def initialize(root, train: true, download: false, transform: nil)
|
26
|
-
@root = root
|
3
|
+
class MNIST < VisionDataset
|
4
|
+
# http://yann.lecun.com/exdb/mnist/
|
5
|
+
def initialize(root, train: true, download: false, transform: nil, target_transform: nil)
|
6
|
+
super(root, transform: transform, target_transform: target_transform)
|
27
7
|
@train = train
|
28
|
-
@transform = transform
|
29
8
|
|
30
9
|
self.download if download
|
31
10
|
|
@@ -33,35 +12,37 @@ module TorchVision
|
|
33
12
|
raise Error, "Dataset not found. You can use download: true to download it"
|
34
13
|
end
|
35
14
|
|
36
|
-
data_file = @train ?
|
15
|
+
data_file = @train ? training_file : test_file
|
37
16
|
@data, @targets = Torch.load(File.join(processed_folder, data_file))
|
38
17
|
end
|
39
18
|
|
40
19
|
def size
|
41
|
-
@data.size
|
20
|
+
@data.size(0)
|
42
21
|
end
|
43
22
|
|
44
23
|
def [](index)
|
45
|
-
img = @data[index]
|
46
|
-
|
24
|
+
img, target = @data[index], @targets[index].item
|
25
|
+
|
26
|
+
img = Utils.image_from_array(img)
|
27
|
+
|
47
28
|
img = @transform.call(img) if @transform
|
48
29
|
|
49
|
-
target = @
|
30
|
+
target = @target_transform.call(target) if @target_transform
|
50
31
|
|
51
32
|
[img, target]
|
52
33
|
end
|
53
34
|
|
54
35
|
def raw_folder
|
55
|
-
File.join(@root, "
|
36
|
+
File.join(@root, self.class.name.split("::").last, "raw")
|
56
37
|
end
|
57
38
|
|
58
39
|
def processed_folder
|
59
|
-
File.join(@root, "
|
40
|
+
File.join(@root, self.class.name.split("::").last, "processed")
|
60
41
|
end
|
61
42
|
|
62
43
|
def check_exists
|
63
|
-
File.exist?(File.join(processed_folder,
|
64
|
-
File.exist?(File.join(processed_folder,
|
44
|
+
File.exist?(File.join(processed_folder, training_file)) &&
|
45
|
+
File.exist?(File.join(processed_folder, test_file))
|
65
46
|
end
|
66
47
|
|
67
48
|
def download
|
@@ -70,7 +51,7 @@ module TorchVision
|
|
70
51
|
FileUtils.mkdir_p(raw_folder)
|
71
52
|
FileUtils.mkdir_p(processed_folder)
|
72
53
|
|
73
|
-
|
54
|
+
resources.each do |resource|
|
74
55
|
filename = resource[:url].split("/").last
|
75
56
|
download_file(resource[:url], download_root: raw_folder, filename: filename, sha256: resource[:sha256])
|
76
57
|
end
|
@@ -86,14 +67,43 @@ module TorchVision
|
|
86
67
|
unpack_mnist("t10k-labels-idx1-ubyte", 8, [10000])
|
87
68
|
]
|
88
69
|
|
89
|
-
Torch.save(training_set, File.join(processed_folder,
|
90
|
-
Torch.save(test_set, File.join(processed_folder,
|
70
|
+
Torch.save(training_set, File.join(processed_folder, training_file))
|
71
|
+
Torch.save(test_set, File.join(processed_folder, test_file))
|
91
72
|
|
92
73
|
puts "Done!"
|
93
74
|
end
|
94
75
|
|
95
76
|
private
|
96
77
|
|
78
|
+
def resources
|
79
|
+
[
|
80
|
+
{
|
81
|
+
url: "http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz",
|
82
|
+
sha256: "440fcabf73cc546fa21475e81ea370265605f56be210a4024d2ca8f203523609"
|
83
|
+
},
|
84
|
+
{
|
85
|
+
url: "http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz",
|
86
|
+
sha256: "3552534a0a558bbed6aed32b30c495cca23d567ec52cac8be1a0730e8010255c"
|
87
|
+
},
|
88
|
+
{
|
89
|
+
url: "http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz",
|
90
|
+
sha256: "8d422c7b0a1c1c79245a5bcf07fe86e33eeafee792b84584aec276f5a2dbc4e6"
|
91
|
+
},
|
92
|
+
{
|
93
|
+
url: "http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz",
|
94
|
+
sha256: "f7ae60f92e00ec6debd23a6088c31dbd2371eca3ffa0defaefb259924204aec6"
|
95
|
+
}
|
96
|
+
]
|
97
|
+
end
|
98
|
+
|
99
|
+
def training_file
|
100
|
+
"training.pt"
|
101
|
+
end
|
102
|
+
|
103
|
+
def test_file
|
104
|
+
"test.pt"
|
105
|
+
end
|
106
|
+
|
97
107
|
def unpack_mnist(path, offset, shape)
|
98
108
|
path = File.join(raw_folder, "#{path}.gz")
|
99
109
|
File.open(path, "rb") do |f|
|
@@ -102,45 +112,6 @@ module TorchVision
|
|
102
112
|
Torch.tensor(Numo::UInt8.from_string(gz.read, shape))
|
103
113
|
end
|
104
114
|
end
|
105
|
-
|
106
|
-
def download_file(url, download_root:, filename:, sha256:)
|
107
|
-
FileUtils.mkdir_p(download_root)
|
108
|
-
|
109
|
-
dest = File.join(download_root, filename)
|
110
|
-
return dest if File.exist?(dest)
|
111
|
-
|
112
|
-
temp_path = "#{Dir.tmpdir}/#{Time.now.to_f}" # TODO better name
|
113
|
-
|
114
|
-
digest = Digest::SHA256.new
|
115
|
-
|
116
|
-
uri = URI(url)
|
117
|
-
|
118
|
-
# Net::HTTP automatically adds Accept-Encoding for compression
|
119
|
-
# of response bodies and automatically decompresses gzip
|
120
|
-
# and deflateresponses unless a Range header was sent.
|
121
|
-
# https://ruby-doc.org/stdlib-2.6.4/libdoc/net/http/rdoc/Net/HTTP.html
|
122
|
-
Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http|
|
123
|
-
request = Net::HTTP::Get.new(uri)
|
124
|
-
|
125
|
-
puts "Downloading #{url}..."
|
126
|
-
File.open(temp_path, "wb") do |f|
|
127
|
-
http.request(request) do |response|
|
128
|
-
response.read_body do |chunk|
|
129
|
-
f.write(chunk)
|
130
|
-
digest.update(chunk)
|
131
|
-
end
|
132
|
-
end
|
133
|
-
end
|
134
|
-
end
|
135
|
-
|
136
|
-
if digest.hexdigest != sha256
|
137
|
-
raise Error, "Bad hash: #{digest.hexdigest}"
|
138
|
-
end
|
139
|
-
|
140
|
-
FileUtils.mv(temp_path, dest)
|
141
|
-
|
142
|
-
dest
|
143
|
-
end
|
144
115
|
end
|
145
116
|
end
|
146
117
|
end
|
@@ -0,0 +1,67 @@
|
|
1
|
+
module TorchVision
  module Datasets
    # Base class for the vision datasets. It stores the dataset root and the
    # optional transforms, and provides private helpers to download a file and
    # verify its SHA-256 digest.
    class VisionDataset < Torch::Utils::Data::Dataset
      attr_reader :data, :targets

      # @param root [String] root directory of the dataset
      # @param transforms [#call, nil] joint transform; mutually exclusive with
      #   +transform+ / +target_transform+
      # @param transform [#call, nil] transform applied to samples
      # @param target_transform [#call, nil] transform applied to targets
      # @raise [ArgumentError] when +transforms+ is combined with +transform+ or
      #   +target_transform+
      def initialize(root, transforms: nil, transform: nil, target_transform: nil)
        @root = root

        has_transforms = !transforms.nil?
        has_separate_transform = !transform.nil? || !target_transform.nil?
        if has_transforms && has_separate_transform
          raise ArgumentError, "Only transforms or transform/target_transform can be passed as argument"
        end

        @transform = transform
        @target_transform = target_transform

        # TODO: when separate transforms are given, wrap them into a combined
        # transform (StandardTransform.new(transform, target_transform) in the
        # Python library). Until then @transforms stays nil in that case.
        @transforms = transforms
      end

      private

      # Downloads +url+ to download_root/filename unless that file already
      # exists, verifying the SHA-256 digest before moving it into place.
      #
      # @return [String] path of the downloaded (or pre-existing) file
      # @raise [Error] on a non-success HTTP response or a digest mismatch
      def download_file(url, download_root:, filename:, sha256:)
        FileUtils.mkdir_p(download_root)

        dest = File.join(download_root, filename)
        return dest if File.exist?(dest)

        # Include the pid so concurrent processes never share a temp file
        temp_path = File.join(Dir.tmpdir, "torchvision-#{Process.pid}-#{Time.now.to_f}")

        uri = URI(url)

        begin
          # Net::HTTP automatically adds Accept-Encoding for compression
          # of response bodies and automatically decompresses gzip
          # and deflate responses unless a Range header was sent.
          # https://ruby-doc.org/stdlib-2.6.4/libdoc/net/http/rdoc/Net/HTTP.html
          Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http|
            request = Net::HTTP::Get.new(uri)

            puts "Downloading #{url}..."
            File.open(temp_path, "wb") do |f|
              http.request(request) do |response|
                # Fail clearly on redirects/errors instead of saving the error
                # page and reporting it later as a hash mismatch
                unless response.is_a?(Net::HTTPSuccess)
                  raise Error, "Bad response: #{response.code} #{response.message}"
                end

                response.read_body do |chunk|
                  f.write(chunk)
                end
              end
            end
          end

          unless check_integrity(temp_path, sha256)
            raise Error, "Bad hash"
          end

          FileUtils.mv(temp_path, dest)
        ensure
          # Removes partial/corrupt downloads on failure; no-op after a successful mv
          FileUtils.rm_f(temp_path)
        end

        dest
      end

      # @return [Boolean] true if +path+ exists and its SHA-256 hex digest equals +sha256+
      def check_integrity(path, sha256)
        File.exist?(path) && Digest::SHA256.file(path).hexdigest == sha256
      end
    end
  end
end
|
@@ -0,0 +1,42 @@
|
|
1
|
+
module TorchVision
  module Models
    # AlexNet image classifier with these parts:
    # - five convolutional layers, with ReLU after each and max pooling after
    #   the 1st, 2nd and 5th;
    # - adaptive average pooling to 6x6;
    # - a fully connected head of three Linear layers with dropout.
    class AlexNet < Torch::NN::Module
      # @param num_classes [Integer] width of the final Linear layer
      def initialize(num_classes: 1000)
        super()
        nn = Torch::NN

        @features = nn::Sequential.new(
          nn::Conv2d.new(3, 64, 11, stride: 4, padding: 2), nn::ReLU.new(inplace: true),
          nn::MaxPool2d.new(3, stride: 2),
          nn::Conv2d.new(64, 192, 5, padding: 2), nn::ReLU.new(inplace: true),
          nn::MaxPool2d.new(3, stride: 2),
          nn::Conv2d.new(192, 384, 3, padding: 1), nn::ReLU.new(inplace: true),
          nn::Conv2d.new(384, 256, 3, padding: 1), nn::ReLU.new(inplace: true),
          nn::Conv2d.new(256, 256, 3, padding: 1), nn::ReLU.new(inplace: true),
          nn::MaxPool2d.new(3, stride: 2),
        )
        @avgpool = nn::AdaptiveAvgPool2d.new([6, 6])
        @classifier = nn::Sequential.new(
          nn::Dropout.new,
          nn::Linear.new(256 * 6 * 6, 4096), nn::ReLU.new(inplace: true),
          nn::Dropout.new,
          nn::Linear.new(4096, 4096), nn::ReLU.new(inplace: true),
          nn::Linear.new(4096, num_classes)
        )
      end

      # features -> avgpool -> flatten (keeping dim 0) -> classifier
      def forward(x)
        pooled = @avgpool.call(@features.call(x))
        @classifier.call(Torch.flatten(pooled, 1))
      end
    end
  end
end
|