torchvision 0.1.0 → 0.2.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +35 -0
- data/LICENSE.txt +1 -1
- data/README.md +133 -5
- data/lib/torchvision.rb +40 -1
- data/lib/torchvision/datasets/cifar10.rb +117 -0
- data/lib/torchvision/datasets/cifar100.rb +41 -0
- data/lib/torchvision/datasets/dataset_folder.rb +91 -0
- data/lib/torchvision/datasets/fashion_mnist.rb +30 -0
- data/lib/torchvision/datasets/image_folder.rb +12 -0
- data/lib/torchvision/datasets/kmnist.rb +30 -0
- data/lib/torchvision/datasets/mnist.rb +47 -76
- data/lib/torchvision/datasets/vision_dataset.rb +67 -0
- data/lib/torchvision/models/alexnet.rb +42 -0
- data/lib/torchvision/models/basic_block.rb +46 -0
- data/lib/torchvision/models/bottleneck.rb +47 -0
- data/lib/torchvision/models/resnet.rb +129 -0
- data/lib/torchvision/models/resnet101.rb +9 -0
- data/lib/torchvision/models/resnet152.rb +9 -0
- data/lib/torchvision/models/resnet18.rb +9 -0
- data/lib/torchvision/models/resnet34.rb +9 -0
- data/lib/torchvision/models/resnet50.rb +9 -0
- data/lib/torchvision/models/resnext101_32x8d.rb +11 -0
- data/lib/torchvision/models/resnext50_32x4d.rb +11 -0
- data/lib/torchvision/models/vgg.rb +93 -0
- data/lib/torchvision/models/vgg11.rb +9 -0
- data/lib/torchvision/models/vgg11_bn.rb +9 -0
- data/lib/torchvision/models/vgg13.rb +9 -0
- data/lib/torchvision/models/vgg13_bn.rb +9 -0
- data/lib/torchvision/models/vgg16.rb +9 -0
- data/lib/torchvision/models/vgg16_bn.rb +9 -0
- data/lib/torchvision/models/vgg19.rb +9 -0
- data/lib/torchvision/models/vgg19_bn.rb +9 -0
- data/lib/torchvision/models/wide_resnet101_2.rb +10 -0
- data/lib/torchvision/models/wide_resnet50_2.rb +10 -0
- data/lib/torchvision/transforms/center_crop.rb +13 -0
- data/lib/torchvision/transforms/compose.rb +2 -2
- data/lib/torchvision/transforms/functional.rb +142 -7
- data/lib/torchvision/transforms/normalize.rb +2 -2
- data/lib/torchvision/transforms/random_horizontal_flip.rb +18 -0
- data/lib/torchvision/transforms/random_resized_crop.rb +70 -0
- data/lib/torchvision/transforms/random_vertical_flip.rb +18 -0
- data/lib/torchvision/transforms/resize.rb +13 -0
- data/lib/torchvision/transforms/to_tensor.rb +2 -2
- data/lib/torchvision/utils.rb +120 -0
- data/lib/torchvision/version.rb +1 -1
- metadata +50 -57
@@ -0,0 +1,91 @@
|
|
1
|
+
module TorchVision
  module Datasets
    # Generic dataset over a directory tree of the form
    #
    #   root/<class_name>/.../<file>
    #
    # Every immediate subdirectory of +root+ is a class, indexed in sorted
    # order. Every accepted regular file beneath it, at any depth, becomes one
    # sample labeled with that class index. Samples are opened with
    # Vips::Image.new_from_file when they are indexed.
    class DatasetFolder < VisionDataset
      attr_reader :classes, :class_to_idx

      # @param root [String] dataset root directory (stored in @root by VisionDataset)
      # @param extensions [Array<String>, nil] filename suffixes to accept, compared with
      #   end_with? against the downcased path; exactly one of +extensions+ and
      #   +is_valid_file+ must be given
      # @param transform [#call, nil] applied to each loaded sample
      # @param target_transform [#call, nil] applied to each class index
      # @param is_valid_file [#call, nil] predicate called with each candidate file path
      # @raise [ArgumentError] if both or neither of +extensions+ / +is_valid_file+ are given
      # @raise [RuntimeError] if no matching files are found
      def initialize(root, extensions: nil, transform: nil, target_transform: nil, is_valid_file: nil)
        super(root, transform: transform, target_transform: target_transform)
        classes, class_to_idx = find_classes(@root)
        samples = make_dataset(@root, class_to_idx, extensions, is_valid_file)
        if samples.empty?
          msg = "Found 0 files in subfolders of: #{@root}\n"
          msg += "Supported extensions are: #{extensions.join(",")}" unless extensions.nil?
          raise RuntimeError, msg
        end

        # Images are decoded lazily, one per #[] call.
        @loader = ->(path) { Vips::Image.new_from_file(path) }
        @extensions = extensions

        @classes = classes
        @class_to_idx = class_to_idx
        @samples = samples
        @targets = samples.map { |s| s[1] }
      end

      # Loads one sample from disk.
      #
      # @param index [Integer] position in the sorted sample list (negative counts from the end)
      # @return [Array] +[sample, target]+ after the optional transforms are applied
      # @raise [IndexError] if +index+ is out of range
      def [](index)
        path, target = @samples.fetch(index)
        sample = @loader.call(path)
        sample = @transform.call(sample) if @transform
        target = @target_transform.call(target) if @target_transform

        [sample, target]
      end

      # @return [Integer] number of samples found under root
      def size
        @samples.size
      end

      private

      # Lists the immediate subdirectories of +dir+ as class names.
      #
      # @return [Array(Array<String>, Hash{String => Integer})] sorted class names and
      #   a map from class name to its position in that sorted list
      def find_classes(dir)
        classes = Dir.children(dir).select { |d| File.directory?(File.join(dir, d)) }.sort
        class_to_idx = classes.each_with_index.to_h
        [classes, class_to_idx]
      end

      # Case-insensitive suffix check of +filename+ against +extensions+.
      def has_file_allowed_extension(filename, extensions)
        filename = filename.downcase
        extensions.any? { |ext| filename.end_with?(ext) }
      end

      # Collects [path, class_index] pairs for every accepted file under each class directory.
      #
      # @raise [ArgumentError] unless exactly one of +extensions+ / +is_valid_file+ is given
      def make_dataset(directory, class_to_idx, extensions, is_valid_file)
        directory = File.expand_path(directory)
        if extensions.nil? == is_valid_file.nil?
          raise ArgumentError, "Both extensions and is_valid_file cannot be None or not None at the same time"
        end
        is_valid_file = ->(x) { has_file_allowed_extension(x, extensions) } unless extensions.nil?

        instances = []
        class_to_idx.keys.sort.each do |target_class|
          class_index = class_to_idx[target_class]
          target_dir = File.join(directory, target_class)
          next unless File.directory?(target_dir)

          # "**/*" recurses into nested subdirectories; a bare "**" segment is
          # treated like "*" by Dir.glob and would only see the top level.
          Dir.glob("**/*", base: target_dir).sort.each do |fname|
            path = File.join(target_dir, fname)
            # Only regular files are samples; glob also returns directories.
            next unless File.file?(path)

            instances << [path, class_index] if is_valid_file.call(path)
          end
        end
        instances
      end
    end
  end
end
|
@@ -0,0 +1,30 @@
|
|
1
|
+
module TorchVision
  module Datasets
    # Fashion-MNIST: same IDX archive layout as MNIST, so only the list of
    # files to download differs from the parent class.
    # https://github.com/zalandoresearch/fashion-mnist
    class FashionMNIST < MNIST
      private

      # Archives to fetch, each paired with its expected SHA-256 digest.
      #
      # @return [Array<Hash>] entries with +:url+ and +:sha256+ keys, in
      #   train-images, train-labels, t10k-images, t10k-labels order
      def resources
        mirror = "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/"
        digests = {
          "train-images-idx3-ubyte.gz" => "3aede38d61863908ad78613f6a32ed271626dd12800ba2636569512369268a84",
          "train-labels-idx1-ubyte.gz" => "a04f17134ac03560a47e3764e11b92fc97de4d1bfaf8ba1a3aa29af54cc90845",
          "t10k-images-idx3-ubyte.gz" => "346e55b948d973a97e58d2351dde16a484bd415d4595297633bb08f03db6a073",
          "t10k-labels-idx1-ubyte.gz" => "67da17c76eaffca5446c3361aaab5c3cd6d1c2608764d35dfb1850b086bf8dd5"
        }
        digests.map { |archive, digest| {url: mirror + archive, sha256: digest} }
      end
    end
  end
end
|
@@ -0,0 +1,12 @@
|
|
1
|
+
module TorchVision
  module Datasets
    # DatasetFolder preset for image files: accepts the common image suffixes
    # in IMG_EXTENSIONS unless the caller supplies its own +is_valid_file+.
    class ImageFolder < DatasetFolder
      IMG_EXTENSIONS = [".jpg", ".jpeg", ".png", ".ppm", ".bmp", ".pgm", ".tif", ".tiff", ".webp"].freeze

      # @return [Array<Array>] the same [path, class_index] pairs as @samples
      attr_reader :imgs

      # @param root [String] dataset root directory
      # @param transform [#call, nil] applied to each loaded image
      # @param target_transform [#call, nil] applied to each class index
      # @param is_valid_file [#call, nil] custom file predicate; replaces the
      #   IMG_EXTENSIONS filter when given
      def initialize(root, transform: nil, target_transform: nil, is_valid_file: nil)
        # DatasetFolder requires exactly one of extensions / is_valid_file, so the
        # default extension filter is only used when no predicate was supplied.
        # Passing both would always raise ArgumentError.
        extensions = is_valid_file.nil? ? IMG_EXTENSIONS : nil
        super(root, extensions: extensions, transform: transform, target_transform: target_transform, is_valid_file: is_valid_file)
        @imgs = @samples
      end
    end
  end
end
|
@@ -0,0 +1,30 @@
|
|
1
|
+
module TorchVision
  module Datasets
    # Kuzushiji-MNIST: same IDX archive layout as MNIST, so only the list of
    # files to download differs from the parent class.
    # https://github.com/rois-codh/kmnist
    class KMNIST < MNIST
      private

      # Archives to fetch, each paired with its expected SHA-256 digest.
      #
      # @return [Array<Hash>] entries with +:url+ and +:sha256+ keys, in
      #   train-images, train-labels, t10k-images, t10k-labels order
      def resources
        mirror = "http://codh.rois.ac.jp/kmnist/dataset/kmnist/"
        digests = {
          "train-images-idx3-ubyte.gz" => "51467d22d8cc72929e2a028a0428f2086b092bb31cfb79c69cc0a90ce135fde4",
          "train-labels-idx1-ubyte.gz" => "e38f9ebcd0f3ebcdec7fc8eabdcdaef93bb0df8ea12bee65224341c8183d8e17",
          "t10k-images-idx3-ubyte.gz" => "edd7a857845ad6bb1d0ba43fe7e794d164fe2dce499a1694695a792adfac43c5",
          "t10k-labels-idx1-ubyte.gz" => "20bb9a0ef54c7db3efc55a92eef5582c109615df22683c380526788f98e42a1c"
        }
        digests.map { |archive, digest| {url: mirror + archive, sha256: digest} }
      end
    end
  end
end
|
@@ -1,31 +1,10 @@
|
|
1
1
|
module TorchVision
|
2
2
|
module Datasets
|
3
|
-
class MNIST
|
4
|
-
|
5
|
-
|
6
|
-
|
7
|
-
sha256: "440fcabf73cc546fa21475e81ea370265605f56be210a4024d2ca8f203523609"
|
8
|
-
},
|
9
|
-
{
|
10
|
-
url: "http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz",
|
11
|
-
sha256: "3552534a0a558bbed6aed32b30c495cca23d567ec52cac8be1a0730e8010255c"
|
12
|
-
},
|
13
|
-
{
|
14
|
-
url: "http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz",
|
15
|
-
sha256: "8d422c7b0a1c1c79245a5bcf07fe86e33eeafee792b84584aec276f5a2dbc4e6"
|
16
|
-
},
|
17
|
-
{
|
18
|
-
url: "http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz",
|
19
|
-
sha256: "f7ae60f92e00ec6debd23a6088c31dbd2371eca3ffa0defaefb259924204aec6"
|
20
|
-
}
|
21
|
-
]
|
22
|
-
TRAINING_FILE = "training.pt"
|
23
|
-
TEST_FILE = "test.pt"
|
24
|
-
|
25
|
-
def initialize(root, train: true, download: false, transform: nil)
|
26
|
-
@root = root
|
3
|
+
class MNIST < VisionDataset
|
4
|
+
# http://yann.lecun.com/exdb/mnist/
|
5
|
+
def initialize(root, train: true, download: false, transform: nil, target_transform: nil)
|
6
|
+
super(root, transform: transform, target_transform: target_transform)
|
27
7
|
@train = train
|
28
|
-
@transform = transform
|
29
8
|
|
30
9
|
self.download if download
|
31
10
|
|
@@ -33,35 +12,37 @@ module TorchVision
|
|
33
12
|
raise Error, "Dataset not found. You can use download: true to download it"
|
34
13
|
end
|
35
14
|
|
36
|
-
data_file = @train ?
|
15
|
+
data_file = @train ? training_file : test_file
|
37
16
|
@data, @targets = Torch.load(File.join(processed_folder, data_file))
|
38
17
|
end
|
39
18
|
|
40
19
|
def size
|
41
|
-
@data.size
|
20
|
+
@data.size(0)
|
42
21
|
end
|
43
22
|
|
44
23
|
def [](index)
|
45
|
-
img = @data[index]
|
46
|
-
|
24
|
+
img, target = @data[index], @targets[index].item
|
25
|
+
|
26
|
+
img = Utils.image_from_array(img)
|
27
|
+
|
47
28
|
img = @transform.call(img) if @transform
|
48
29
|
|
49
|
-
target = @
|
30
|
+
target = @target_transform.call(target) if @target_transform
|
50
31
|
|
51
32
|
[img, target]
|
52
33
|
end
|
53
34
|
|
54
35
|
def raw_folder
|
55
|
-
File.join(@root, "
|
36
|
+
File.join(@root, self.class.name.split("::").last, "raw")
|
56
37
|
end
|
57
38
|
|
58
39
|
def processed_folder
|
59
|
-
File.join(@root, "
|
40
|
+
File.join(@root, self.class.name.split("::").last, "processed")
|
60
41
|
end
|
61
42
|
|
62
43
|
def check_exists
|
63
|
-
File.exist?(File.join(processed_folder,
|
64
|
-
File.exist?(File.join(processed_folder,
|
44
|
+
File.exist?(File.join(processed_folder, training_file)) &&
|
45
|
+
File.exist?(File.join(processed_folder, test_file))
|
65
46
|
end
|
66
47
|
|
67
48
|
def download
|
@@ -70,7 +51,7 @@ module TorchVision
|
|
70
51
|
FileUtils.mkdir_p(raw_folder)
|
71
52
|
FileUtils.mkdir_p(processed_folder)
|
72
53
|
|
73
|
-
|
54
|
+
resources.each do |resource|
|
74
55
|
filename = resource[:url].split("/").last
|
75
56
|
download_file(resource[:url], download_root: raw_folder, filename: filename, sha256: resource[:sha256])
|
76
57
|
end
|
@@ -86,14 +67,43 @@ module TorchVision
|
|
86
67
|
unpack_mnist("t10k-labels-idx1-ubyte", 8, [10000])
|
87
68
|
]
|
88
69
|
|
89
|
-
Torch.save(training_set, File.join(processed_folder,
|
90
|
-
Torch.save(test_set, File.join(processed_folder,
|
70
|
+
Torch.save(training_set, File.join(processed_folder, training_file))
|
71
|
+
Torch.save(test_set, File.join(processed_folder, test_file))
|
91
72
|
|
92
73
|
puts "Done!"
|
93
74
|
end
|
94
75
|
|
95
76
|
private
|
96
77
|
|
78
|
+
def resources
|
79
|
+
[
|
80
|
+
{
|
81
|
+
url: "http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz",
|
82
|
+
sha256: "440fcabf73cc546fa21475e81ea370265605f56be210a4024d2ca8f203523609"
|
83
|
+
},
|
84
|
+
{
|
85
|
+
url: "http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz",
|
86
|
+
sha256: "3552534a0a558bbed6aed32b30c495cca23d567ec52cac8be1a0730e8010255c"
|
87
|
+
},
|
88
|
+
{
|
89
|
+
url: "http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz",
|
90
|
+
sha256: "8d422c7b0a1c1c79245a5bcf07fe86e33eeafee792b84584aec276f5a2dbc4e6"
|
91
|
+
},
|
92
|
+
{
|
93
|
+
url: "http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz",
|
94
|
+
sha256: "f7ae60f92e00ec6debd23a6088c31dbd2371eca3ffa0defaefb259924204aec6"
|
95
|
+
}
|
96
|
+
]
|
97
|
+
end
|
98
|
+
|
99
|
+
def training_file
|
100
|
+
"training.pt"
|
101
|
+
end
|
102
|
+
|
103
|
+
def test_file
|
104
|
+
"test.pt"
|
105
|
+
end
|
106
|
+
|
97
107
|
def unpack_mnist(path, offset, shape)
|
98
108
|
path = File.join(raw_folder, "#{path}.gz")
|
99
109
|
File.open(path, "rb") do |f|
|
@@ -102,45 +112,6 @@ module TorchVision
|
|
102
112
|
Torch.tensor(Numo::UInt8.from_string(gz.read, shape))
|
103
113
|
end
|
104
114
|
end
|
105
|
-
|
106
|
-
def download_file(url, download_root:, filename:, sha256:)
|
107
|
-
FileUtils.mkdir_p(download_root)
|
108
|
-
|
109
|
-
dest = File.join(download_root, filename)
|
110
|
-
return dest if File.exist?(dest)
|
111
|
-
|
112
|
-
temp_path = "#{Dir.tmpdir}/#{Time.now.to_f}" # TODO better name
|
113
|
-
|
114
|
-
digest = Digest::SHA256.new
|
115
|
-
|
116
|
-
uri = URI(url)
|
117
|
-
|
118
|
-
# Net::HTTP automatically adds Accept-Encoding for compression
|
119
|
-
# of response bodies and automatically decompresses gzip
|
120
|
-
# and deflateresponses unless a Range header was sent.
|
121
|
-
# https://ruby-doc.org/stdlib-2.6.4/libdoc/net/http/rdoc/Net/HTTP.html
|
122
|
-
Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http|
|
123
|
-
request = Net::HTTP::Get.new(uri)
|
124
|
-
|
125
|
-
puts "Downloading #{url}..."
|
126
|
-
File.open(temp_path, "wb") do |f|
|
127
|
-
http.request(request) do |response|
|
128
|
-
response.read_body do |chunk|
|
129
|
-
f.write(chunk)
|
130
|
-
digest.update(chunk)
|
131
|
-
end
|
132
|
-
end
|
133
|
-
end
|
134
|
-
end
|
135
|
-
|
136
|
-
if digest.hexdigest != sha256
|
137
|
-
raise Error, "Bad hash: #{digest.hexdigest}"
|
138
|
-
end
|
139
|
-
|
140
|
-
FileUtils.mv(temp_path, dest)
|
141
|
-
|
142
|
-
dest
|
143
|
-
end
|
144
115
|
end
|
145
116
|
end
|
146
117
|
end
|
@@ -0,0 +1,67 @@
|
|
1
|
+
module TorchVision
  module Datasets
    # Common base for the vision datasets: stores the root directory and the
    # optional transforms, and provides a checksum-verified file downloader.
    class VisionDataset < Torch::Utils::Data::Dataset
      attr_reader :data, :targets

      # @param root [String] directory the dataset reads from / downloads into
      # @param transforms [#call, nil] joint transform; mutually exclusive with
      #   +transform+ / +target_transform+
      # @param transform [#call, nil] stored as @transform for subclasses to apply to inputs
      # @param target_transform [#call, nil] stored as @target_transform for subclasses to apply to targets
      # @raise [ArgumentError] if +transforms+ is combined with either separate transform
      def initialize(root, transforms: nil, transform: nil, target_transform: nil)
        @root = root

        has_transforms = !transforms.nil?
        has_separate_transform = !transform.nil? || !target_transform.nil?
        if has_transforms && has_separate_transform
          raise ArgumentError, "Only transforms or transform/target_transform can be passed as argument"
        end

        @transform = transform
        @target_transform = target_transform
        # TODO: Python torchvision wraps transform/target_transform in a
        # StandardTransform here. That is not implemented yet, so @transforms
        # stays nil when only the separate transforms are given.
        @transforms = transforms
      end

      private

      # Downloads +url+ to download_root/filename unless that file already exists.
      #
      # The body is streamed to a temp file, checked against +sha256+, and only
      # then moved into place, so +dest+ never holds a partial or corrupt download.
      #
      # @return [String] path of the downloaded (or already present) file
      # @raise [Error] on a non-2xx HTTP response or a checksum mismatch
      def download_file(url, download_root:, filename:, sha256:)
        FileUtils.mkdir_p(download_root)

        dest = File.join(download_root, filename)
        return dest if File.exist?(dest)

        # pid + timestamp keeps concurrent downloads from sharing a temp file.
        temp_path = File.join(Dir.tmpdir, "torchvision-#{Process.pid}-#{Time.now.to_f}-#{filename}")

        begin
          uri = URI(url)

          # Net::HTTP automatically adds Accept-Encoding for compression
          # of response bodies and automatically decompresses gzip
          # and deflate responses unless a Range header was sent.
          # https://ruby-doc.org/stdlib-2.6.4/libdoc/net/http/rdoc/Net/HTTP.html
          Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http|
            request = Net::HTTP::Get.new(uri)

            puts "Downloading #{url}..."
            File.open(temp_path, "wb") do |f|
              http.request(request) do |response|
                # Without this, an error page or redirect body would be hashed
                # and misreported as a checksum failure.
                unless response.is_a?(Net::HTTPSuccess)
                  raise Error, "Download failed: #{response.code} #{response.message}"
                end

                response.read_body { |chunk| f.write(chunk) }
              end
            end
          end

          raise Error, "Bad hash" unless check_integrity(temp_path, sha256)

          FileUtils.mv(temp_path, dest)
        ensure
          # No-op after a successful move; otherwise removes the partial download.
          FileUtils.rm_f(temp_path)
        end

        dest
      end

      # @return [Boolean] true when +path+ exists and its SHA-256 hex digest equals +sha256+
      def check_integrity(path, sha256)
        File.exist?(path) && Digest::SHA256.file(path).hexdigest == sha256
      end
    end
  end
end
|
@@ -0,0 +1,42 @@
|
|
1
|
+
module TorchVision
  module Models
    # AlexNet: five convolution layers, an adaptive average pool to 6x6, and a
    # three-layer fully connected classifier with dropout.
    class AlexNet < Torch::NN::Module
      # @param num_classes [Integer] output size of the final Linear layer
      # @param dropout [Float] drop probability for both Dropout layers in the
      #   classifier (0.5 keeps the previous behavior)
      def initialize(num_classes: 1000, dropout: 0.5)
        super()
        @features = Torch::NN::Sequential.new(
          Torch::NN::Conv2d.new(3, 64, 11, stride: 4, padding: 2),
          Torch::NN::ReLU.new(inplace: true),
          Torch::NN::MaxPool2d.new(3, stride: 2),
          Torch::NN::Conv2d.new(64, 192, 5, padding: 2),
          Torch::NN::ReLU.new(inplace: true),
          Torch::NN::MaxPool2d.new(3, stride: 2),
          Torch::NN::Conv2d.new(192, 384, 3, padding: 1),
          Torch::NN::ReLU.new(inplace: true),
          Torch::NN::Conv2d.new(384, 256, 3, padding: 1),
          Torch::NN::ReLU.new(inplace: true),
          Torch::NN::Conv2d.new(256, 256, 3, padding: 1),
          Torch::NN::ReLU.new(inplace: true),
          Torch::NN::MaxPool2d.new(3, stride: 2)
        )
        # Fixes the spatial size at 6x6 so the classifier input is always 256 * 6 * 6.
        @avgpool = Torch::NN::AdaptiveAvgPool2d.new([6, 6])
        @classifier = Torch::NN::Sequential.new(
          Torch::NN::Dropout.new(p: dropout),
          Torch::NN::Linear.new(256 * 6 * 6, 4096),
          Torch::NN::ReLU.new(inplace: true),
          Torch::NN::Dropout.new(p: dropout),
          Torch::NN::Linear.new(4096, 4096),
          Torch::NN::ReLU.new(inplace: true),
          Torch::NN::Linear.new(4096, num_classes)
        )
      end

      # @param x [Torch::Tensor] batch of 3-channel images, presumably (N, 3, H, W);
      #   TODO confirm the minimum H/W accepted by the conv/pool stack
      # @return [Torch::Tensor] class scores of shape (N, num_classes)
      def forward(x)
        x = @features.call(x)
        x = @avgpool.call(x)
        # Keep the batch dimension, flatten the rest for the Linear layers.
        x = Torch.flatten(x, 1)
        @classifier.call(x)
      end
    end
  end
end
|