torch-rb 0.1.5 → 0.1.6
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/README.md +1 -1
- data/ext/torch/ext.cpp +0 -170
- data/ext/torch/nn_functions.cpp +44 -24
- data/ext/torch/templates.cpp +55 -0
- data/ext/torch/templates.hpp +48 -0
- data/ext/torch/tensor_functions.cpp +76 -16
- data/ext/torch/torch_functions.cpp +165 -65
- data/lib/torch.rb +51 -42
- data/lib/torch/ext.bundle +0 -0
- data/lib/torch/native/dispatcher.rb +1 -1
- data/lib/torch/native/function.rb +36 -5
- data/lib/torch/native/generator.rb +26 -7
- data/lib/torch/native/parser.rb +51 -14
- data/lib/torch/nn/avg_pool1d.rb +18 -0
- data/lib/torch/nn/avg_pool2d.rb +7 -2
- data/lib/torch/nn/avg_pool3d.rb +19 -0
- data/lib/torch/nn/avg_poolnd.rb +1 -1
- data/lib/torch/nn/batch_norm.rb +75 -0
- data/lib/torch/nn/batch_norm1d.rb +11 -0
- data/lib/torch/nn/batch_norm2d.rb +11 -0
- data/lib/torch/nn/batch_norm3d.rb +11 -0
- data/lib/torch/nn/constant_pad1d.rb +10 -0
- data/lib/torch/nn/constant_pad2d.rb +10 -0
- data/lib/torch/nn/constant_pad3d.rb +10 -0
- data/lib/torch/nn/constant_padnd.rb +18 -0
- data/lib/torch/nn/conv1d.rb +22 -0
- data/lib/torch/nn/conv2d.rb +9 -17
- data/lib/torch/nn/conv3d.rb +22 -0
- data/lib/torch/nn/fold.rb +20 -0
- data/lib/torch/nn/functional.rb +320 -100
- data/lib/torch/nn/group_norm.rb +36 -0
- data/lib/torch/nn/gru.rb +49 -0
- data/lib/torch/nn/hardshrink.rb +18 -0
- data/lib/torch/nn/instance_norm.rb +20 -0
- data/lib/torch/nn/instance_norm1d.rb +18 -0
- data/lib/torch/nn/instance_norm2d.rb +11 -0
- data/lib/torch/nn/instance_norm3d.rb +11 -0
- data/lib/torch/nn/layer_norm.rb +35 -0
- data/lib/torch/nn/local_response_norm.rb +21 -0
- data/lib/torch/nn/log_sigmoid.rb +9 -0
- data/lib/torch/nn/lp_pool1d.rb +9 -0
- data/lib/torch/nn/lp_pool2d.rb +9 -0
- data/lib/torch/nn/lp_poolnd.rb +22 -0
- data/lib/torch/nn/lstm.rb +66 -0
- data/lib/torch/nn/max_pool1d.rb +9 -0
- data/lib/torch/nn/max_pool2d.rb +1 -1
- data/lib/torch/nn/max_pool3d.rb +9 -0
- data/lib/torch/nn/max_poolnd.rb +6 -6
- data/lib/torch/nn/max_unpool1d.rb +16 -0
- data/lib/torch/nn/max_unpool2d.rb +16 -0
- data/lib/torch/nn/max_unpool3d.rb +16 -0
- data/lib/torch/nn/max_unpoolnd.rb +9 -0
- data/lib/torch/nn/module.rb +7 -0
- data/lib/torch/nn/reflection_pad1d.rb +10 -0
- data/lib/torch/nn/reflection_pad2d.rb +10 -0
- data/lib/torch/nn/reflection_padnd.rb +13 -0
- data/lib/torch/nn/replication_pad1d.rb +10 -0
- data/lib/torch/nn/replication_pad2d.rb +10 -0
- data/lib/torch/nn/replication_pad3d.rb +10 -0
- data/lib/torch/nn/replication_padnd.rb +13 -0
- data/lib/torch/nn/rnn_base.rb +48 -4
- data/lib/torch/nn/softshrink.rb +18 -0
- data/lib/torch/nn/softsign.rb +9 -0
- data/lib/torch/nn/tanh.rb +9 -0
- data/lib/torch/nn/tanhshrink.rb +9 -0
- data/lib/torch/nn/unfold.rb +19 -0
- data/lib/torch/nn/utils.rb +25 -0
- data/lib/torch/nn/zero_pad2d.rb +9 -0
- data/lib/torch/tensor.rb +14 -25
- data/lib/torch/version.rb +1 -1
- metadata +50 -2
data/lib/torch.rb
CHANGED
@@ -29,6 +29,7 @@ require "torch/optim/lr_scheduler/step_lr"
|
|
29
29
|
|
30
30
|
# nn parameters
|
31
31
|
require "torch/nn/parameter"
|
32
|
+
require "torch/nn/utils"
|
32
33
|
|
33
34
|
# nn containers
|
34
35
|
require "torch/nn/module"
|
@@ -36,17 +37,61 @@ require "torch/nn/sequential"
|
|
36
37
|
|
37
38
|
# nn convolution layers
|
38
39
|
require "torch/nn/convnd"
|
40
|
+
require "torch/nn/conv1d"
|
39
41
|
require "torch/nn/conv2d"
|
42
|
+
require "torch/nn/conv3d"
|
43
|
+
require "torch/nn/unfold"
|
44
|
+
require "torch/nn/fold"
|
40
45
|
|
41
46
|
# nn pooling layers
|
42
47
|
require "torch/nn/max_poolnd"
|
48
|
+
require "torch/nn/max_pool1d"
|
43
49
|
require "torch/nn/max_pool2d"
|
50
|
+
require "torch/nn/max_pool3d"
|
51
|
+
require "torch/nn/max_unpoolnd"
|
52
|
+
require "torch/nn/max_unpool1d"
|
53
|
+
require "torch/nn/max_unpool2d"
|
54
|
+
require "torch/nn/max_unpool3d"
|
44
55
|
require "torch/nn/avg_poolnd"
|
56
|
+
require "torch/nn/avg_pool1d"
|
45
57
|
require "torch/nn/avg_pool2d"
|
58
|
+
require "torch/nn/avg_pool3d"
|
59
|
+
require "torch/nn/lp_poolnd"
|
60
|
+
require "torch/nn/lp_pool1d"
|
61
|
+
require "torch/nn/lp_pool2d"
|
62
|
+
|
63
|
+
# nn padding layers
|
64
|
+
require "torch/nn/reflection_padnd"
|
65
|
+
require "torch/nn/reflection_pad1d"
|
66
|
+
require "torch/nn/reflection_pad2d"
|
67
|
+
require "torch/nn/replication_padnd"
|
68
|
+
require "torch/nn/replication_pad1d"
|
69
|
+
require "torch/nn/replication_pad2d"
|
70
|
+
require "torch/nn/replication_pad3d"
|
71
|
+
require "torch/nn/constant_padnd"
|
72
|
+
require "torch/nn/constant_pad1d"
|
73
|
+
require "torch/nn/constant_pad2d"
|
74
|
+
require "torch/nn/constant_pad3d"
|
75
|
+
require "torch/nn/zero_pad2d"
|
76
|
+
|
77
|
+
# nn normalization layers
|
78
|
+
require "torch/nn/batch_norm"
|
79
|
+
require "torch/nn/batch_norm1d"
|
80
|
+
require "torch/nn/batch_norm2d"
|
81
|
+
require "torch/nn/batch_norm3d"
|
82
|
+
require "torch/nn/group_norm"
|
83
|
+
require "torch/nn/instance_norm"
|
84
|
+
require "torch/nn/instance_norm1d"
|
85
|
+
require "torch/nn/instance_norm2d"
|
86
|
+
require "torch/nn/instance_norm3d"
|
87
|
+
require "torch/nn/layer_norm"
|
88
|
+
require "torch/nn/local_response_norm"
|
46
89
|
|
47
90
|
# nn recurrent layers
|
48
91
|
require "torch/nn/rnn_base"
|
49
92
|
require "torch/nn/rnn"
|
93
|
+
require "torch/nn/lstm"
|
94
|
+
require "torch/nn/gru"
|
50
95
|
|
51
96
|
# nn linear layers
|
52
97
|
require "torch/nn/bilinear"
|
@@ -62,11 +107,17 @@ require "torch/nn/dropout3d"
|
|
62
107
|
require "torch/nn/feature_alpha_dropout"
|
63
108
|
|
64
109
|
# nn activations
|
110
|
+
require "torch/nn/hardshrink"
|
65
111
|
require "torch/nn/leaky_relu"
|
112
|
+
require "torch/nn/log_sigmoid"
|
66
113
|
require "torch/nn/prelu"
|
67
114
|
require "torch/nn/relu"
|
68
115
|
require "torch/nn/sigmoid"
|
69
116
|
require "torch/nn/softplus"
|
117
|
+
require "torch/nn/softshrink"
|
118
|
+
require "torch/nn/softsign"
|
119
|
+
require "torch/nn/tanh"
|
120
|
+
require "torch/nn/tanhshrink"
|
70
121
|
|
71
122
|
# nn activations other
|
72
123
|
require "torch/nn/log_softmax"
|
@@ -364,48 +415,6 @@ module Torch
|
|
364
415
|
zeros(input.size, like_options(input, options))
|
365
416
|
end
|
366
417
|
|
367
|
-
# --- begin operations ---
|
368
|
-
|
369
|
-
# TODO support out
|
370
|
-
def mean(input, dim = nil, keepdim: false)
|
371
|
-
if dim
|
372
|
-
_mean_dim(input, dim, keepdim)
|
373
|
-
else
|
374
|
-
_mean(input)
|
375
|
-
end
|
376
|
-
end
|
377
|
-
|
378
|
-
# TODO support dtype
|
379
|
-
def sum(input, dim = nil, keepdim: false)
|
380
|
-
if dim
|
381
|
-
_sum_dim(input, dim, keepdim)
|
382
|
-
else
|
383
|
-
_sum(input)
|
384
|
-
end
|
385
|
-
end
|
386
|
-
|
387
|
-
def topk(input, k)
|
388
|
-
_topk(input, k)
|
389
|
-
end
|
390
|
-
|
391
|
-
def max(input, dim = nil, keepdim: false, out: nil)
|
392
|
-
if dim
|
393
|
-
raise NotImplementedYet unless out
|
394
|
-
_max_out(out[0], out[1], input, dim, keepdim)
|
395
|
-
else
|
396
|
-
_max(input)
|
397
|
-
end
|
398
|
-
end
|
399
|
-
|
400
|
-
# TODO make dim keyword argument
|
401
|
-
def log_softmax(input, dim)
|
402
|
-
_log_softmax(input, dim)
|
403
|
-
end
|
404
|
-
|
405
|
-
def softmax(input, dim: nil)
|
406
|
-
_softmax(input, dim)
|
407
|
-
end
|
408
|
-
|
409
418
|
private
|
410
419
|
|
411
420
|
def tensor_size(size)
|
data/lib/torch/ext.bundle
CHANGED
Binary file
|
@@ -18,7 +18,7 @@ module Torch
|
|
18
18
|
functions = Generator.grouped_functions
|
19
19
|
bind_functions(::Torch, :define_singleton_method, functions[:torch])
|
20
20
|
bind_functions(::Torch::Tensor, :define_method, functions[:tensor])
|
21
|
-
|
21
|
+
bind_functions(::Torch::NN, :define_singleton_method, functions[:nn])
|
22
22
|
end
|
23
23
|
|
24
24
|
def bind_functions(context, def_method, functions)
|
@@ -35,11 +35,38 @@ module Torch
|
|
35
35
|
end
|
36
36
|
t, _, k = a.rpartition(" ")
|
37
37
|
k, d = k.split("=")
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
38
|
+
has_default = !d.nil?
|
39
|
+
|
40
|
+
if d
|
41
|
+
d =
|
42
|
+
case d
|
43
|
+
when "True"
|
44
|
+
true
|
45
|
+
when "False"
|
46
|
+
false
|
47
|
+
when "None"
|
48
|
+
nil
|
49
|
+
when /\A\-?\d+\z/
|
50
|
+
d.to_i
|
51
|
+
when "[]"
|
52
|
+
[]
|
53
|
+
when "[0,1]"
|
54
|
+
[0, 1]
|
55
|
+
when /\A\de\-\d+\z/, /\A\d+\.\d+\z/
|
56
|
+
d.to_f
|
57
|
+
when "Mean"
|
58
|
+
"mean"
|
59
|
+
when "contiguous_format"
|
60
|
+
d
|
61
|
+
when "long"
|
62
|
+
:long
|
63
|
+
else
|
64
|
+
raise "Unknown default: #{d}"
|
65
|
+
end
|
66
|
+
end
|
67
|
+
|
68
|
+
next if t == "Generator?"
|
69
|
+
args << {name: k, type: t, default: d, pos: pos, has_default: has_default}
|
43
70
|
end
|
44
71
|
args
|
45
72
|
end
|
@@ -49,6 +76,10 @@ module Torch
|
|
49
76
|
@out_size ||= func.split("->").last.count("!")
|
50
77
|
end
|
51
78
|
|
79
|
+
def ret_size
|
80
|
+
@ret_size ||= func.split("->").last.split(", ").size
|
81
|
+
end
|
82
|
+
|
52
83
|
def out?
|
53
84
|
out_size > 0 && base_name[-1] != "_"
|
54
85
|
end
|
@@ -17,14 +17,23 @@ module Torch
|
|
17
17
|
def grouped_functions
|
18
18
|
functions = functions()
|
19
19
|
|
20
|
-
#
|
20
|
+
# skip functions
|
21
21
|
skip_binding = ["unique_dim_consecutive", "einsum", "normal"]
|
22
|
-
skip_args = ["bool[3]", "Dimname", "
|
23
|
-
|
22
|
+
skip_args = ["bool[3]", "Dimname", "MemoryFormat", "Layout", "Storage", "ConstQuantizerPtr"]
|
23
|
+
|
24
|
+
# remove functions
|
25
|
+
functions.reject! do |f|
|
26
|
+
f.ruby_name.start_with?("_") ||
|
27
|
+
f.ruby_name.end_with?("_backward") ||
|
28
|
+
skip_binding.include?(f.ruby_name) ||
|
29
|
+
f.args.any? { |a| a[:type].include?("Dimname") }
|
30
|
+
end
|
31
|
+
|
32
|
+
# separate out into todo
|
24
33
|
todo_functions, functions =
|
25
34
|
functions.partition do |f|
|
26
35
|
f.args.any? do |a|
|
27
|
-
a[:type].include?("?") && !["Tensor?", "Generator?", "int?"].include?(a[:type]) ||
|
36
|
+
a[:type].include?("?") && !["Tensor?", "Generator?", "int?", "ScalarType?"].include?(a[:type]) ||
|
28
37
|
skip_args.any? { |sa| a[:type].include?(sa) }
|
29
38
|
end
|
30
39
|
end
|
@@ -33,7 +42,9 @@ module Torch
|
|
33
42
|
# there may be a better way to do this
|
34
43
|
optional_functions, functions = functions.partition { |f| f.args.any? { |a| a[:type] == "int?" } }
|
35
44
|
optional_functions.each do |f|
|
36
|
-
next if f.ruby_name
|
45
|
+
next if f.ruby_name == "cross"
|
46
|
+
next if f.ruby_name.start_with?("avg_pool") && f.out?
|
47
|
+
|
37
48
|
opt_args = f.args.select { |a| a[:type] == "int?" }
|
38
49
|
if opt_args.size == 1
|
39
50
|
sep = f.name.include?(".") ? "_" : "."
|
@@ -85,7 +96,7 @@ void add_%{type}_functions(Module m) {
|
|
85
96
|
|
86
97
|
cpp_defs = []
|
87
98
|
functions.sort_by(&:cpp_name).each do |func|
|
88
|
-
fargs = func.args
|
99
|
+
fargs = func.args #.select { |a| a[:type] != "Generator?" }
|
89
100
|
|
90
101
|
cpp_args = []
|
91
102
|
fargs.each do |a|
|
@@ -96,6 +107,8 @@ void add_%{type}_functions(Module m) {
|
|
96
107
|
when "Tensor?"
|
97
108
|
# TODO better signature
|
98
109
|
"OptionalTensor"
|
110
|
+
when "ScalarType?"
|
111
|
+
"OptionalScalarType"
|
99
112
|
when "Tensor[]"
|
100
113
|
"TensorList"
|
101
114
|
when "int"
|
@@ -121,10 +134,16 @@ void add_%{type}_functions(Module m) {
|
|
121
134
|
|
122
135
|
prefix = def_method == :define_method ? "self." : "torch::"
|
123
136
|
|
137
|
+
body = "#{prefix}#{dispatch}(#{args.join(", ")})"
|
138
|
+
# TODO check type as well
|
139
|
+
if func.ret_size > 1
|
140
|
+
body = "wrap(#{body})"
|
141
|
+
end
|
142
|
+
|
124
143
|
cpp_defs << ".#{def_method}(
|
125
144
|
\"#{func.cpp_name}\",
|
126
145
|
*[](#{cpp_args.join(", ")}) {
|
127
|
-
return #{
|
146
|
+
return #{body};
|
128
147
|
})"
|
129
148
|
end
|
130
149
|
|
data/lib/torch/native/parser.rb
CHANGED
@@ -4,19 +4,44 @@ module Torch
|
|
4
4
|
def initialize(functions)
|
5
5
|
@functions = functions
|
6
6
|
@name = @functions.first.ruby_name
|
7
|
-
@min_args = @functions.map { |f| f.args.count { |a| a[:pos] && a[:
|
7
|
+
@min_args = @functions.map { |f| f.args.count { |a| a[:pos] && !a[:has_default] } }.min
|
8
8
|
@max_args = @functions.map { |f| f.args.count { |a| a[:pos] } }.max
|
9
9
|
end
|
10
10
|
|
11
11
|
def parse(args, options)
|
12
12
|
candidates = @functions.dup
|
13
13
|
|
14
|
+
# remove nil
|
15
|
+
while args.any? && args.last.nil?
|
16
|
+
args.pop
|
17
|
+
end
|
18
|
+
|
19
|
+
# TODO account for args passed as options here
|
14
20
|
if args.size < @min_args || args.size > @max_args
|
15
21
|
expected = String.new(@min_args.to_s)
|
16
22
|
expected += "..#{@max_args}" if @max_args != @min_args
|
17
23
|
return {error: "wrong number of arguments (given #{args.size}, expected #{expected})"}
|
18
24
|
end
|
19
25
|
|
26
|
+
candidates.reject! { |f| args.size > f.args.size }
|
27
|
+
|
28
|
+
# exclude functions missing required options
|
29
|
+
candidates.reject! do |func|
|
30
|
+
# TODO make more generic
|
31
|
+
func.out? && !options[:out]
|
32
|
+
end
|
33
|
+
|
34
|
+
# handle out with multiple
|
35
|
+
# there should only be one match, so safe to modify all
|
36
|
+
out_func = candidates.find { |f| f.out? }
|
37
|
+
if out_func && out_func.out_size > 1 && options[:out]
|
38
|
+
out_args = out_func.args.last(2).map { |a| a[:name] }
|
39
|
+
out_args.zip(options.delete(:out)).each do |k, v|
|
40
|
+
options[k.to_sym] = v
|
41
|
+
end
|
42
|
+
candidates = [out_func]
|
43
|
+
end
|
44
|
+
|
20
45
|
# exclude functions where options don't match
|
21
46
|
options.each do |k, v|
|
22
47
|
candidates.select! do |func|
|
@@ -26,12 +51,6 @@ module Torch
|
|
26
51
|
return {error: "unknown keyword: #{k}"} if candidates.empty?
|
27
52
|
end
|
28
53
|
|
29
|
-
# exclude functions missing required options
|
30
|
-
candidates.reject! do |func|
|
31
|
-
# TODO make more generic
|
32
|
-
func.out? && !options[:out]
|
33
|
-
end
|
34
|
-
|
35
54
|
final_values = {}
|
36
55
|
|
37
56
|
# check args
|
@@ -41,29 +60,47 @@ module Torch
|
|
41
60
|
values = args.zip(func.args).map { |a, fa| [fa[:name], a] }.to_h
|
42
61
|
values.merge!(options.map { |k, v| [k.to_s, v] }.to_h)
|
43
62
|
func.args.each do |fa|
|
44
|
-
values[fa[:name]]
|
63
|
+
values[fa[:name]] = fa[:default] if values[fa[:name]].nil?
|
45
64
|
end
|
46
65
|
|
47
66
|
arg_types = func.args.map { |a| [a[:name], a[:type]] }.to_h
|
48
67
|
|
49
|
-
values.
|
68
|
+
values.each_key do |k|
|
69
|
+
v = values[k]
|
50
70
|
t = arg_types[k].split("(").first
|
71
|
+
|
51
72
|
good =
|
52
73
|
case t
|
53
74
|
when "Tensor"
|
54
75
|
v.is_a?(Tensor)
|
76
|
+
when "Tensor?"
|
77
|
+
v.nil? || v.is_a?(Tensor)
|
55
78
|
when "Tensor[]"
|
56
|
-
v.all? { |v2| v2.is_a?(Tensor) }
|
79
|
+
v.is_a?(Array) && v.all? { |v2| v2.is_a?(Tensor) }
|
57
80
|
when "int"
|
58
|
-
|
59
|
-
|
60
|
-
|
81
|
+
if k == "reduction"
|
82
|
+
v.is_a?(String)
|
83
|
+
else
|
84
|
+
v.is_a?(Integer)
|
85
|
+
end
|
86
|
+
when "float"
|
87
|
+
v.is_a?(Numeric)
|
88
|
+
when /int\[.*\]/
|
89
|
+
if v.is_a?(Integer)
|
90
|
+
size = t[4..-2]
|
91
|
+
raise Error, "Unknown size: #{size}. Please report a bug with #{@name}." unless size =~ /\A\d+\z/
|
92
|
+
v = [v] * size.to_i
|
93
|
+
values[k] = v
|
94
|
+
end
|
95
|
+
v.is_a?(Array) && v.all? { |v2| v2.is_a?(Integer) }
|
61
96
|
when "Scalar"
|
62
97
|
v.is_a?(Numeric)
|
98
|
+
when "ScalarType?"
|
99
|
+
v.nil?
|
63
100
|
when "bool"
|
64
101
|
v == true || v == false
|
65
102
|
else
|
66
|
-
raise Error, "Unknown argument type: #{arg_types[k]}. Please report a bug with #{@name}"
|
103
|
+
raise Error, "Unknown argument type: #{arg_types[k]}. Please report a bug with #{@name}."
|
67
104
|
end
|
68
105
|
|
69
106
|
if !good
|
@@ -0,0 +1,18 @@
|
|
1
|
+
module Torch
|
2
|
+
module NN
|
3
|
+
class AvgPool1d < AvgPoolNd
|
4
|
+
def initialize(kernel_size, stride: nil, padding: 0, ceil_mode: false, count_include_pad: true)
|
5
|
+
super()
|
6
|
+
@kernel_size = _single(kernel_size)
|
7
|
+
@stride = _single(stride || kernel_size)
|
8
|
+
@padding = _single(padding)
|
9
|
+
@ceil_mode = ceil_mode
|
10
|
+
@count_include_pad = count_include_pad
|
11
|
+
end
|
12
|
+
|
13
|
+
def forward(input)
|
14
|
+
F.avg_pool1d(input, @kernel_size, @stride, @padding, @ceil_mode, @count_include_pad)
|
15
|
+
end
|
16
|
+
end
|
17
|
+
end
|
18
|
+
end
|
data/lib/torch/nn/avg_pool2d.rb
CHANGED
@@ -1,13 +1,18 @@
|
|
1
1
|
module Torch
|
2
2
|
module NN
|
3
3
|
class AvgPool2d < AvgPoolNd
|
4
|
-
def initialize(kernel_size)
|
4
|
+
def initialize(kernel_size, stride: nil, padding: 0, ceil_mode: false, count_include_pad: true, divisor_override: nil)
|
5
5
|
super()
|
6
6
|
@kernel_size = kernel_size
|
7
|
+
@stride = stride || kernel_size
|
8
|
+
@padding = padding
|
9
|
+
@ceil_mode = ceil_mode
|
10
|
+
@count_include_pad = count_include_pad
|
11
|
+
@divisor_override = divisor_override
|
7
12
|
end
|
8
13
|
|
9
14
|
def forward(input)
|
10
|
-
F.avg_pool2d(input, @kernel_size)
|
15
|
+
F.avg_pool2d(input, @kernel_size, @stride, @padding, @ceil_mode, @count_include_pad, @divisor_override)
|
11
16
|
end
|
12
17
|
end
|
13
18
|
end
|
@@ -0,0 +1,19 @@
|
|
1
|
+
module Torch
|
2
|
+
module NN
|
3
|
+
class AvgPool3d < AvgPoolNd
|
4
|
+
def initialize(kernel_size, stride: nil, padding: 0, ceil_mode: false, count_include_pad: true, divisor_override: nil)
|
5
|
+
super()
|
6
|
+
@kernel_size = kernel_size
|
7
|
+
@stride = stride || kernel_size
|
8
|
+
@padding = padding
|
9
|
+
@ceil_mode = ceil_mode
|
10
|
+
@count_include_pad = count_include_pad
|
11
|
+
@divisor_override = divisor_override
|
12
|
+
end
|
13
|
+
|
14
|
+
def forward(input)
|
15
|
+
F.avg_pool3d(input, @kernel_size, @stride, @padding, @ceil_mode, @count_include_pad, @divisor_override)
|
16
|
+
end
|
17
|
+
end
|
18
|
+
end
|
19
|
+
end
|