torch-rb 0.1.5 → 0.1.6
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/README.md +1 -1
- data/ext/torch/ext.cpp +0 -170
- data/ext/torch/nn_functions.cpp +44 -24
- data/ext/torch/templates.cpp +55 -0
- data/ext/torch/templates.hpp +48 -0
- data/ext/torch/tensor_functions.cpp +76 -16
- data/ext/torch/torch_functions.cpp +165 -65
- data/lib/torch.rb +51 -42
- data/lib/torch/ext.bundle +0 -0
- data/lib/torch/native/dispatcher.rb +1 -1
- data/lib/torch/native/function.rb +36 -5
- data/lib/torch/native/generator.rb +26 -7
- data/lib/torch/native/parser.rb +51 -14
- data/lib/torch/nn/avg_pool1d.rb +18 -0
- data/lib/torch/nn/avg_pool2d.rb +7 -2
- data/lib/torch/nn/avg_pool3d.rb +19 -0
- data/lib/torch/nn/avg_poolnd.rb +1 -1
- data/lib/torch/nn/batch_norm.rb +75 -0
- data/lib/torch/nn/batch_norm1d.rb +11 -0
- data/lib/torch/nn/batch_norm2d.rb +11 -0
- data/lib/torch/nn/batch_norm3d.rb +11 -0
- data/lib/torch/nn/constant_pad1d.rb +10 -0
- data/lib/torch/nn/constant_pad2d.rb +10 -0
- data/lib/torch/nn/constant_pad3d.rb +10 -0
- data/lib/torch/nn/constant_padnd.rb +18 -0
- data/lib/torch/nn/conv1d.rb +22 -0
- data/lib/torch/nn/conv2d.rb +9 -17
- data/lib/torch/nn/conv3d.rb +22 -0
- data/lib/torch/nn/fold.rb +20 -0
- data/lib/torch/nn/functional.rb +320 -100
- data/lib/torch/nn/group_norm.rb +36 -0
- data/lib/torch/nn/gru.rb +49 -0
- data/lib/torch/nn/hardshrink.rb +18 -0
- data/lib/torch/nn/instance_norm.rb +20 -0
- data/lib/torch/nn/instance_norm1d.rb +18 -0
- data/lib/torch/nn/instance_norm2d.rb +11 -0
- data/lib/torch/nn/instance_norm3d.rb +11 -0
- data/lib/torch/nn/layer_norm.rb +35 -0
- data/lib/torch/nn/local_response_norm.rb +21 -0
- data/lib/torch/nn/log_sigmoid.rb +9 -0
- data/lib/torch/nn/lp_pool1d.rb +9 -0
- data/lib/torch/nn/lp_pool2d.rb +9 -0
- data/lib/torch/nn/lp_poolnd.rb +22 -0
- data/lib/torch/nn/lstm.rb +66 -0
- data/lib/torch/nn/max_pool1d.rb +9 -0
- data/lib/torch/nn/max_pool2d.rb +1 -1
- data/lib/torch/nn/max_pool3d.rb +9 -0
- data/lib/torch/nn/max_poolnd.rb +6 -6
- data/lib/torch/nn/max_unpool1d.rb +16 -0
- data/lib/torch/nn/max_unpool2d.rb +16 -0
- data/lib/torch/nn/max_unpool3d.rb +16 -0
- data/lib/torch/nn/max_unpoolnd.rb +9 -0
- data/lib/torch/nn/module.rb +7 -0
- data/lib/torch/nn/reflection_pad1d.rb +10 -0
- data/lib/torch/nn/reflection_pad2d.rb +10 -0
- data/lib/torch/nn/reflection_padnd.rb +13 -0
- data/lib/torch/nn/replication_pad1d.rb +10 -0
- data/lib/torch/nn/replication_pad2d.rb +10 -0
- data/lib/torch/nn/replication_pad3d.rb +10 -0
- data/lib/torch/nn/replication_padnd.rb +13 -0
- data/lib/torch/nn/rnn_base.rb +48 -4
- data/lib/torch/nn/softshrink.rb +18 -0
- data/lib/torch/nn/softsign.rb +9 -0
- data/lib/torch/nn/tanh.rb +9 -0
- data/lib/torch/nn/tanhshrink.rb +9 -0
- data/lib/torch/nn/unfold.rb +19 -0
- data/lib/torch/nn/utils.rb +25 -0
- data/lib/torch/nn/zero_pad2d.rb +9 -0
- data/lib/torch/tensor.rb +14 -25
- data/lib/torch/version.rb +1 -1
- metadata +50 -2
data/lib/torch.rb
CHANGED
@@ -29,6 +29,7 @@ require "torch/optim/lr_scheduler/step_lr"
|
|
29
29
|
|
30
30
|
# nn parameters
|
31
31
|
require "torch/nn/parameter"
|
32
|
+
require "torch/nn/utils"
|
32
33
|
|
33
34
|
# nn containers
|
34
35
|
require "torch/nn/module"
|
@@ -36,17 +37,61 @@ require "torch/nn/sequential"
|
|
36
37
|
|
37
38
|
# nn convolution layers
|
38
39
|
require "torch/nn/convnd"
|
40
|
+
require "torch/nn/conv1d"
|
39
41
|
require "torch/nn/conv2d"
|
42
|
+
require "torch/nn/conv3d"
|
43
|
+
require "torch/nn/unfold"
|
44
|
+
require "torch/nn/fold"
|
40
45
|
|
41
46
|
# nn pooling layers
|
42
47
|
require "torch/nn/max_poolnd"
|
48
|
+
require "torch/nn/max_pool1d"
|
43
49
|
require "torch/nn/max_pool2d"
|
50
|
+
require "torch/nn/max_pool3d"
|
51
|
+
require "torch/nn/max_unpoolnd"
|
52
|
+
require "torch/nn/max_unpool1d"
|
53
|
+
require "torch/nn/max_unpool2d"
|
54
|
+
require "torch/nn/max_unpool3d"
|
44
55
|
require "torch/nn/avg_poolnd"
|
56
|
+
require "torch/nn/avg_pool1d"
|
45
57
|
require "torch/nn/avg_pool2d"
|
58
|
+
require "torch/nn/avg_pool3d"
|
59
|
+
require "torch/nn/lp_poolnd"
|
60
|
+
require "torch/nn/lp_pool1d"
|
61
|
+
require "torch/nn/lp_pool2d"
|
62
|
+
|
63
|
+
# nn padding layers
|
64
|
+
require "torch/nn/reflection_padnd"
|
65
|
+
require "torch/nn/reflection_pad1d"
|
66
|
+
require "torch/nn/reflection_pad2d"
|
67
|
+
require "torch/nn/replication_padnd"
|
68
|
+
require "torch/nn/replication_pad1d"
|
69
|
+
require "torch/nn/replication_pad2d"
|
70
|
+
require "torch/nn/replication_pad3d"
|
71
|
+
require "torch/nn/constant_padnd"
|
72
|
+
require "torch/nn/constant_pad1d"
|
73
|
+
require "torch/nn/constant_pad2d"
|
74
|
+
require "torch/nn/constant_pad3d"
|
75
|
+
require "torch/nn/zero_pad2d"
|
76
|
+
|
77
|
+
# nn normalization layers
|
78
|
+
require "torch/nn/batch_norm"
|
79
|
+
require "torch/nn/batch_norm1d"
|
80
|
+
require "torch/nn/batch_norm2d"
|
81
|
+
require "torch/nn/batch_norm3d"
|
82
|
+
require "torch/nn/group_norm"
|
83
|
+
require "torch/nn/instance_norm"
|
84
|
+
require "torch/nn/instance_norm1d"
|
85
|
+
require "torch/nn/instance_norm2d"
|
86
|
+
require "torch/nn/instance_norm3d"
|
87
|
+
require "torch/nn/layer_norm"
|
88
|
+
require "torch/nn/local_response_norm"
|
46
89
|
|
47
90
|
# nn recurrent layers
|
48
91
|
require "torch/nn/rnn_base"
|
49
92
|
require "torch/nn/rnn"
|
93
|
+
require "torch/nn/lstm"
|
94
|
+
require "torch/nn/gru"
|
50
95
|
|
51
96
|
# nn linear layers
|
52
97
|
require "torch/nn/bilinear"
|
@@ -62,11 +107,17 @@ require "torch/nn/dropout3d"
|
|
62
107
|
require "torch/nn/feature_alpha_dropout"
|
63
108
|
|
64
109
|
# nn activations
|
110
|
+
require "torch/nn/hardshrink"
|
65
111
|
require "torch/nn/leaky_relu"
|
112
|
+
require "torch/nn/log_sigmoid"
|
66
113
|
require "torch/nn/prelu"
|
67
114
|
require "torch/nn/relu"
|
68
115
|
require "torch/nn/sigmoid"
|
69
116
|
require "torch/nn/softplus"
|
117
|
+
require "torch/nn/softshrink"
|
118
|
+
require "torch/nn/softsign"
|
119
|
+
require "torch/nn/tanh"
|
120
|
+
require "torch/nn/tanhshrink"
|
70
121
|
|
71
122
|
# nn activations other
|
72
123
|
require "torch/nn/log_softmax"
|
@@ -364,48 +415,6 @@ module Torch
|
|
364
415
|
zeros(input.size, like_options(input, options))
|
365
416
|
end
|
366
417
|
|
367
|
-
# --- begin operations ---
|
368
|
-
|
369
|
-
# TODO support out
|
370
|
-
def mean(input, dim = nil, keepdim: false)
|
371
|
-
if dim
|
372
|
-
_mean_dim(input, dim, keepdim)
|
373
|
-
else
|
374
|
-
_mean(input)
|
375
|
-
end
|
376
|
-
end
|
377
|
-
|
378
|
-
# TODO support dtype
|
379
|
-
def sum(input, dim = nil, keepdim: false)
|
380
|
-
if dim
|
381
|
-
_sum_dim(input, dim, keepdim)
|
382
|
-
else
|
383
|
-
_sum(input)
|
384
|
-
end
|
385
|
-
end
|
386
|
-
|
387
|
-
def topk(input, k)
|
388
|
-
_topk(input, k)
|
389
|
-
end
|
390
|
-
|
391
|
-
def max(input, dim = nil, keepdim: false, out: nil)
|
392
|
-
if dim
|
393
|
-
raise NotImplementedYet unless out
|
394
|
-
_max_out(out[0], out[1], input, dim, keepdim)
|
395
|
-
else
|
396
|
-
_max(input)
|
397
|
-
end
|
398
|
-
end
|
399
|
-
|
400
|
-
# TODO make dim keyword argument
|
401
|
-
def log_softmax(input, dim)
|
402
|
-
_log_softmax(input, dim)
|
403
|
-
end
|
404
|
-
|
405
|
-
def softmax(input, dim: nil)
|
406
|
-
_softmax(input, dim)
|
407
|
-
end
|
408
|
-
|
409
418
|
private
|
410
419
|
|
411
420
|
def tensor_size(size)
|
data/lib/torch/ext.bundle
CHANGED
Binary file
|
@@ -18,7 +18,7 @@ module Torch
|
|
18
18
|
functions = Generator.grouped_functions
|
19
19
|
bind_functions(::Torch, :define_singleton_method, functions[:torch])
|
20
20
|
bind_functions(::Torch::Tensor, :define_method, functions[:tensor])
|
21
|
-
|
21
|
+
bind_functions(::Torch::NN, :define_singleton_method, functions[:nn])
|
22
22
|
end
|
23
23
|
|
24
24
|
def bind_functions(context, def_method, functions)
|
@@ -35,11 +35,38 @@ module Torch
|
|
35
35
|
end
|
36
36
|
t, _, k = a.rpartition(" ")
|
37
37
|
k, d = k.split("=")
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
38
|
+
has_default = !d.nil?
|
39
|
+
|
40
|
+
if d
|
41
|
+
d =
|
42
|
+
case d
|
43
|
+
when "True"
|
44
|
+
true
|
45
|
+
when "False"
|
46
|
+
false
|
47
|
+
when "None"
|
48
|
+
nil
|
49
|
+
when /\A\-?\d+\z/
|
50
|
+
d.to_i
|
51
|
+
when "[]"
|
52
|
+
[]
|
53
|
+
when "[0,1]"
|
54
|
+
[0, 1]
|
55
|
+
when /\A\de\-\d+\z/, /\A\d+\.\d+\z/
|
56
|
+
d.to_f
|
57
|
+
when "Mean"
|
58
|
+
"mean"
|
59
|
+
when "contiguous_format"
|
60
|
+
d
|
61
|
+
when "long"
|
62
|
+
:long
|
63
|
+
else
|
64
|
+
raise "Unknown default: #{d}"
|
65
|
+
end
|
66
|
+
end
|
67
|
+
|
68
|
+
next if t == "Generator?"
|
69
|
+
args << {name: k, type: t, default: d, pos: pos, has_default: has_default}
|
43
70
|
end
|
44
71
|
args
|
45
72
|
end
|
@@ -49,6 +76,10 @@ module Torch
|
|
49
76
|
@out_size ||= func.split("->").last.count("!")
|
50
77
|
end
|
51
78
|
|
79
|
+
def ret_size
|
80
|
+
@ret_size ||= func.split("->").last.split(", ").size
|
81
|
+
end
|
82
|
+
|
52
83
|
def out?
|
53
84
|
out_size > 0 && base_name[-1] != "_"
|
54
85
|
end
|
@@ -17,14 +17,23 @@ module Torch
|
|
17
17
|
def grouped_functions
|
18
18
|
functions = functions()
|
19
19
|
|
20
|
-
#
|
20
|
+
# skip functions
|
21
21
|
skip_binding = ["unique_dim_consecutive", "einsum", "normal"]
|
22
|
-
skip_args = ["bool[3]", "Dimname", "
|
23
|
-
|
22
|
+
skip_args = ["bool[3]", "Dimname", "MemoryFormat", "Layout", "Storage", "ConstQuantizerPtr"]
|
23
|
+
|
24
|
+
# remove functions
|
25
|
+
functions.reject! do |f|
|
26
|
+
f.ruby_name.start_with?("_") ||
|
27
|
+
f.ruby_name.end_with?("_backward") ||
|
28
|
+
skip_binding.include?(f.ruby_name) ||
|
29
|
+
f.args.any? { |a| a[:type].include?("Dimname") }
|
30
|
+
end
|
31
|
+
|
32
|
+
# separate out into todo
|
24
33
|
todo_functions, functions =
|
25
34
|
functions.partition do |f|
|
26
35
|
f.args.any? do |a|
|
27
|
-
a[:type].include?("?") && !["Tensor?", "Generator?", "int?"].include?(a[:type]) ||
|
36
|
+
a[:type].include?("?") && !["Tensor?", "Generator?", "int?", "ScalarType?"].include?(a[:type]) ||
|
28
37
|
skip_args.any? { |sa| a[:type].include?(sa) }
|
29
38
|
end
|
30
39
|
end
|
@@ -33,7 +42,9 @@ module Torch
|
|
33
42
|
# there may be a better way to do this
|
34
43
|
optional_functions, functions = functions.partition { |f| f.args.any? { |a| a[:type] == "int?" } }
|
35
44
|
optional_functions.each do |f|
|
36
|
-
next if f.ruby_name
|
45
|
+
next if f.ruby_name == "cross"
|
46
|
+
next if f.ruby_name.start_with?("avg_pool") && f.out?
|
47
|
+
|
37
48
|
opt_args = f.args.select { |a| a[:type] == "int?" }
|
38
49
|
if opt_args.size == 1
|
39
50
|
sep = f.name.include?(".") ? "_" : "."
|
@@ -85,7 +96,7 @@ void add_%{type}_functions(Module m) {
|
|
85
96
|
|
86
97
|
cpp_defs = []
|
87
98
|
functions.sort_by(&:cpp_name).each do |func|
|
88
|
-
fargs = func.args
|
99
|
+
fargs = func.args #.select { |a| a[:type] != "Generator?" }
|
89
100
|
|
90
101
|
cpp_args = []
|
91
102
|
fargs.each do |a|
|
@@ -96,6 +107,8 @@ void add_%{type}_functions(Module m) {
|
|
96
107
|
when "Tensor?"
|
97
108
|
# TODO better signature
|
98
109
|
"OptionalTensor"
|
110
|
+
when "ScalarType?"
|
111
|
+
"OptionalScalarType"
|
99
112
|
when "Tensor[]"
|
100
113
|
"TensorList"
|
101
114
|
when "int"
|
@@ -121,10 +134,16 @@ void add_%{type}_functions(Module m) {
|
|
121
134
|
|
122
135
|
prefix = def_method == :define_method ? "self." : "torch::"
|
123
136
|
|
137
|
+
body = "#{prefix}#{dispatch}(#{args.join(", ")})"
|
138
|
+
# TODO check type as well
|
139
|
+
if func.ret_size > 1
|
140
|
+
body = "wrap(#{body})"
|
141
|
+
end
|
142
|
+
|
124
143
|
cpp_defs << ".#{def_method}(
|
125
144
|
\"#{func.cpp_name}\",
|
126
145
|
*[](#{cpp_args.join(", ")}) {
|
127
|
-
return #{
|
146
|
+
return #{body};
|
128
147
|
})"
|
129
148
|
end
|
130
149
|
|
data/lib/torch/native/parser.rb
CHANGED
@@ -4,19 +4,44 @@ module Torch
|
|
4
4
|
def initialize(functions)
|
5
5
|
@functions = functions
|
6
6
|
@name = @functions.first.ruby_name
|
7
|
-
@min_args = @functions.map { |f| f.args.count { |a| a[:pos] && a[:
|
7
|
+
@min_args = @functions.map { |f| f.args.count { |a| a[:pos] && !a[:has_default] } }.min
|
8
8
|
@max_args = @functions.map { |f| f.args.count { |a| a[:pos] } }.max
|
9
9
|
end
|
10
10
|
|
11
11
|
def parse(args, options)
|
12
12
|
candidates = @functions.dup
|
13
13
|
|
14
|
+
# remove nil
|
15
|
+
while args.any? && args.last.nil?
|
16
|
+
args.pop
|
17
|
+
end
|
18
|
+
|
19
|
+
# TODO account for args passed as options here
|
14
20
|
if args.size < @min_args || args.size > @max_args
|
15
21
|
expected = String.new(@min_args.to_s)
|
16
22
|
expected += "..#{@max_args}" if @max_args != @min_args
|
17
23
|
return {error: "wrong number of arguments (given #{args.size}, expected #{expected})"}
|
18
24
|
end
|
19
25
|
|
26
|
+
candidates.reject! { |f| args.size > f.args.size }
|
27
|
+
|
28
|
+
# exclude functions missing required options
|
29
|
+
candidates.reject! do |func|
|
30
|
+
# TODO make more generic
|
31
|
+
func.out? && !options[:out]
|
32
|
+
end
|
33
|
+
|
34
|
+
# handle out with multiple
|
35
|
+
# there should only be one match, so safe to modify all
|
36
|
+
out_func = candidates.find { |f| f.out? }
|
37
|
+
if out_func && out_func.out_size > 1 && options[:out]
|
38
|
+
out_args = out_func.args.last(2).map { |a| a[:name] }
|
39
|
+
out_args.zip(options.delete(:out)).each do |k, v|
|
40
|
+
options[k.to_sym] = v
|
41
|
+
end
|
42
|
+
candidates = [out_func]
|
43
|
+
end
|
44
|
+
|
20
45
|
# exclude functions where options don't match
|
21
46
|
options.each do |k, v|
|
22
47
|
candidates.select! do |func|
|
@@ -26,12 +51,6 @@ module Torch
|
|
26
51
|
return {error: "unknown keyword: #{k}"} if candidates.empty?
|
27
52
|
end
|
28
53
|
|
29
|
-
# exclude functions missing required options
|
30
|
-
candidates.reject! do |func|
|
31
|
-
# TODO make more generic
|
32
|
-
func.out? && !options[:out]
|
33
|
-
end
|
34
|
-
|
35
54
|
final_values = {}
|
36
55
|
|
37
56
|
# check args
|
@@ -41,29 +60,47 @@ module Torch
|
|
41
60
|
values = args.zip(func.args).map { |a, fa| [fa[:name], a] }.to_h
|
42
61
|
values.merge!(options.map { |k, v| [k.to_s, v] }.to_h)
|
43
62
|
func.args.each do |fa|
|
44
|
-
values[fa[:name]]
|
63
|
+
values[fa[:name]] = fa[:default] if values[fa[:name]].nil?
|
45
64
|
end
|
46
65
|
|
47
66
|
arg_types = func.args.map { |a| [a[:name], a[:type]] }.to_h
|
48
67
|
|
49
|
-
values.
|
68
|
+
values.each_key do |k|
|
69
|
+
v = values[k]
|
50
70
|
t = arg_types[k].split("(").first
|
71
|
+
|
51
72
|
good =
|
52
73
|
case t
|
53
74
|
when "Tensor"
|
54
75
|
v.is_a?(Tensor)
|
76
|
+
when "Tensor?"
|
77
|
+
v.nil? || v.is_a?(Tensor)
|
55
78
|
when "Tensor[]"
|
56
|
-
v.all? { |v2| v2.is_a?(Tensor) }
|
79
|
+
v.is_a?(Array) && v.all? { |v2| v2.is_a?(Tensor) }
|
57
80
|
when "int"
|
58
|
-
|
59
|
-
|
60
|
-
|
81
|
+
if k == "reduction"
|
82
|
+
v.is_a?(String)
|
83
|
+
else
|
84
|
+
v.is_a?(Integer)
|
85
|
+
end
|
86
|
+
when "float"
|
87
|
+
v.is_a?(Numeric)
|
88
|
+
when /int\[.*\]/
|
89
|
+
if v.is_a?(Integer)
|
90
|
+
size = t[4..-2]
|
91
|
+
raise Error, "Unknown size: #{size}. Please report a bug with #{@name}." unless size =~ /\A\d+\z/
|
92
|
+
v = [v] * size.to_i
|
93
|
+
values[k] = v
|
94
|
+
end
|
95
|
+
v.is_a?(Array) && v.all? { |v2| v2.is_a?(Integer) }
|
61
96
|
when "Scalar"
|
62
97
|
v.is_a?(Numeric)
|
98
|
+
when "ScalarType?"
|
99
|
+
v.nil?
|
63
100
|
when "bool"
|
64
101
|
v == true || v == false
|
65
102
|
else
|
66
|
-
raise Error, "Unknown argument type: #{arg_types[k]}. Please report a bug with #{@name}"
|
103
|
+
raise Error, "Unknown argument type: #{arg_types[k]}. Please report a bug with #{@name}."
|
67
104
|
end
|
68
105
|
|
69
106
|
if !good
|
@@ -0,0 +1,18 @@
|
|
1
|
+
module Torch
|
2
|
+
module NN
|
3
|
+
class AvgPool1d < AvgPoolNd
|
4
|
+
def initialize(kernel_size, stride: nil, padding: 0, ceil_mode: false, count_include_pad: true)
|
5
|
+
super()
|
6
|
+
@kernel_size = _single(kernel_size)
|
7
|
+
@stride = _single(stride || kernel_size)
|
8
|
+
@padding = _single(padding)
|
9
|
+
@ceil_mode = ceil_mode
|
10
|
+
@count_include_pad = count_include_pad
|
11
|
+
end
|
12
|
+
|
13
|
+
def forward(input)
|
14
|
+
F.avg_pool1d(input, @kernel_size, @stride, @padding, @ceil_mode, @count_include_pad)
|
15
|
+
end
|
16
|
+
end
|
17
|
+
end
|
18
|
+
end
|
data/lib/torch/nn/avg_pool2d.rb
CHANGED
@@ -1,13 +1,18 @@
|
|
1
1
|
module Torch
|
2
2
|
module NN
|
3
3
|
class AvgPool2d < AvgPoolNd
|
4
|
-
def initialize(kernel_size)
|
4
|
+
def initialize(kernel_size, stride: nil, padding: 0, ceil_mode: false, count_include_pad: true, divisor_override: nil)
|
5
5
|
super()
|
6
6
|
@kernel_size = kernel_size
|
7
|
+
@stride = stride || kernel_size
|
8
|
+
@padding = padding
|
9
|
+
@ceil_mode = ceil_mode
|
10
|
+
@count_include_pad = count_include_pad
|
11
|
+
@divisor_override = divisor_override
|
7
12
|
end
|
8
13
|
|
9
14
|
def forward(input)
|
10
|
-
F.avg_pool2d(input, @kernel_size)
|
15
|
+
F.avg_pool2d(input, @kernel_size, @stride, @padding, @ceil_mode, @count_include_pad, @divisor_override)
|
11
16
|
end
|
12
17
|
end
|
13
18
|
end
|
@@ -0,0 +1,19 @@
|
|
1
|
+
module Torch
|
2
|
+
module NN
|
3
|
+
class AvgPool3d < AvgPoolNd
|
4
|
+
def initialize(kernel_size, stride: nil, padding: 0, ceil_mode: false, count_include_pad: true, divisor_override: nil)
|
5
|
+
super()
|
6
|
+
@kernel_size = kernel_size
|
7
|
+
@stride = stride || kernel_size
|
8
|
+
@padding = padding
|
9
|
+
@ceil_mode = ceil_mode
|
10
|
+
@count_include_pad = count_include_pad
|
11
|
+
@divisor_override = divisor_override
|
12
|
+
end
|
13
|
+
|
14
|
+
def forward(input)
|
15
|
+
F.avg_pool3d(input, @kernel_size, @stride, @padding, @ceil_mode, @count_include_pad, @divisor_override)
|
16
|
+
end
|
17
|
+
end
|
18
|
+
end
|
19
|
+
end
|