torch-rb 0.3.7 → 0.4.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +1 -1
- data/codegen/function.rb +134 -0
- data/codegen/generate_functions.rb +546 -0
- data/{lib/torch/native → codegen}/native_functions.yaml +0 -0
- data/ext/torch/ext.cpp +54 -75
- data/ext/torch/extconf.rb +2 -2
- data/ext/torch/nn_functions.h +6 -0
- data/ext/torch/ruby_arg_parser.cpp +593 -0
- data/ext/torch/ruby_arg_parser.h +373 -0
- data/ext/torch/{templates.hpp → templates.h} +30 -51
- data/ext/torch/tensor_functions.h +6 -0
- data/ext/torch/torch_functions.h +6 -0
- data/ext/torch/utils.h +42 -0
- data/ext/torch/{templates.cpp → wrap_outputs.h} +16 -15
- data/lib/torch.rb +0 -62
- data/lib/torch/nn/functional.rb +30 -16
- data/lib/torch/nn/init.rb +5 -19
- data/lib/torch/optim/adadelta.rb +1 -1
- data/lib/torch/optim/adam.rb +2 -2
- data/lib/torch/optim/adamax.rb +1 -1
- data/lib/torch/optim/adamw.rb +1 -1
- data/lib/torch/optim/asgd.rb +1 -1
- data/lib/torch/optim/sgd.rb +3 -3
- data/lib/torch/tensor.rb +25 -105
- data/lib/torch/version.rb +1 -1
- metadata +27 -9
- data/lib/torch/native/dispatcher.rb +0 -70
- data/lib/torch/native/function.rb +0 -200
- data/lib/torch/native/generator.rb +0 -178
- data/lib/torch/native/parser.rb +0 -117
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 6294a5809ed3ea9b3d3f9954b351b8c8e5d81c37e9318ae5044d243320166b20
|
4
|
+
data.tar.gz: 57d8036a0449497cbf040c8650c4dcdbe3e0f09191951d77e6d907c6b3c1013d
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: c7aaf862b43f72f5ca8e7f6f6aca079195c11420db081f9c4f18f161610598d975fb81a036689dcd3158d200c33e39e67841bded2e728f2a2fb1b5b8d9e9dfad
|
7
|
+
data.tar.gz: afaf4c344951629248f79bb03246f6879555a58438d07107dea008aa2433d08115d56b1f9464cf20065c276417d3d90a59aecbfefa304208ee81ae313867bad6
|
data/CHANGELOG.md
CHANGED
data/README.md
CHANGED
data/codegen/function.rb
ADDED
@@ -0,0 +1,134 @@
|
|
1
|
+
# One operator schema from native_functions.yaml, with its "func" string
# parsed into parameter and return-value descriptors for code generation.
class Function
  attr_reader :definition, :params, :retvals

  # definition - Hash for a single YAML entry; "func" is required,
  #              "python_module" and "variants" are read when present.
  def initialize(definition)
    @definition = definition
    @params, @retvals = parse_func
  end

  # Schema name up to the opening parenthesis, overload suffix included
  # (e.g. "add.out").
  def name
    func.split("(", 2).first
  end

  # Schema name with any ".overload" suffix removed (e.g. "add").
  def base_name
    name.split(".").first
  end

  # The raw schema string stored under "func".
  def func
    definition["func"]
  end

  # The "python_module" value (e.g. "nn"), or nil when absent.
  def python_module
    definition["python_module"]
  end

  # Declared variants as an Array of Strings; an entry without
  # "variants" is treated as ["function"].
  def variants
    declared = definition["variants"] || "function"
    declared.split(", ")
  end

  # Index into params of the first parameter whose alias modifier contains
  # "!" (a written-to argument), or nil. Functions whose base name ends in
  # "_" and functions with no return values never report an out index.
  def out_index
    return nil if base_name.end_with?("_") || retvals.empty?

    params.index { |param| param[:modifier].to_s.include?("!") }
  end

  # True when out_index found an out parameter.
  def out?
    !!out_index
  end

  private

  # Splits the schema at "->" into its argument list and its return list.
  def parse_func
    signature, returns = func.split(/\s*->\s*/)
    [generate_params(signature), generate_retvals(returns)]
  end

  # Parses "name(Type a, Type? b=default, *, ...)" into an Array of param
  # Hashes with keys :name, :type, :default, :keyword_only, :optional,
  # :modifier and :list_size. When dtype, device, layout and pin_memory are
  # all present, a keyword-only bool requires_grad param is appended.
  def generate_params(input)
    option_names = ["dtype", "device", "layout", "pin_memory"]
    entries = input.split("(", 2).last.chomp(")").split(/\s*,\s+/)

    keyword_only = false
    result = []
    entries.each do |entry|
      # a bare "*" makes every following parameter keyword-only
      if entry == "*"
        keyword_only = true
        next
      end

      type, param_name = entry.split(/\s+/)
      param_name, default = param_name.split("=", 2) if param_name.include?("=")

      # "Type?" marks an optional parameter, except for the four
      # TensorOptions members, which are never flagged optional
      optional = false
      if type.include?("?")
        optional = !option_names.include?(param_name)
        type = type.delete("?")
      end

      type, modifier = extract_modifier(type)

      # "int[2]" -> "2"; "int[]" -> nil
      list_size = nil
      if type.include?("[")
        size = type[/\[(.*)\]/, 1]
        list_size = size unless size.empty?
      end

      # dtype hack
      # https://github.com/pytorch/pytorch/blob/v1.6.0/tools/autograd/gen_python_functions.py#L1307-L1311
      if param_name == "dtype" && (base_name.start_with?("randperm") || ["tril_indices", "triu_indices"].include?(base_name))
        default = "torch.int64"
      end

      result << {
        name: param_name,
        type: type,
        default: default,
        keyword_only: keyword_only,
        optional: optional,
        modifier: modifier,
        list_size: list_size
      }
    end

    if option_names.all? { |opt| result.any? { |param| param[:name] == opt } }
      result << {
        name: "requires_grad",
        type: "bool",
        default: "False",
        keyword_only: true,
        optional: false,
        modifier: nil,
        list_size: nil
      }
    end

    result
  end

  # Parses "()", "Type" or "(Type a, Type b)" into an Array of retval
  # Hashes (:name, :type, :modifier); :name is nil for unnamed returns.
  def generate_retvals(output)
    pieces =
      if output == "()"
        []
      elsif output.start_with?("(")
        output[1..-2].split(/\s*,\s*/)
      else
        [output]
      end

    pieces.map do |piece|
      type, retval_name = piece.split(/\s+/)
      type, modifier = extract_modifier(type)
      {name: retval_name, type: type, modifier: modifier}
    end
  end

  # Separates an alias annotation from a type:
  # "Tensor(a!)" -> ["Tensor", "a!"], "Tensor(a)[]" -> ["Tensor[]", "a"],
  # "int" -> ["int", nil].
  def extract_modifier(type)
    return [type, nil] unless type.include?("(")

    head, modifier, tail = type.split(/[\(\)]/, 3)
    [[head, tail].join, modifier]
  end
end
|
@@ -0,0 +1,546 @@
|
|
1
|
+
require "yaml"
|
2
|
+
# use require_relative for
|
3
|
+
# rake generate:function (without bundle)
|
4
|
+
require_relative "function"
|
5
|
+
|
6
|
+
# Entry point for `rake generate:functions`: loads native_functions.yaml,
# drops unsupported schemas, splits the rest into torch/tensor/nn groups
# and writes one header/source pair per group.
def generate_functions
  grouped = group_functions(skip_functions(load_functions))

  generate_files("torch", :define_singleton_method, grouped[:torch])
  generate_files("tensor", :define_method, grouped[:tensor])
  generate_files("nn", :define_singleton_method, grouped[:nn])
end
|
15
|
+
|
16
|
+
# Reads native_functions.yaml (next to this script) and returns one
# Function per entry, ordered by full schema name.
def load_functions
  definitions = YAML.load_file(File.expand_path("native_functions.yaml", __dir__))
  definitions.map { |definition| Function.new(definition) }.sort_by(&:name)
end
|
20
|
+
|
21
|
+
# Removes schemas the generator does not bind: private ("_"-prefixed) and
# backward/forward helpers, names handled outside codegen, and schemas
# using types that are not supported yet.
def skip_functions(functions)
  # "index" and "index_put_" are in ext.cpp; "index_put" still needs to be
  # added to ext.cpp; "to" is skipped as well
  excluded_names = ["to", "index", "index_put_", "index_put"]
  # not supported yet
  unsupported_markers = ["Dimname", "ConstQuantizerPtr"]

  functions.reject do |f|
    name = f.base_name
    name.start_with?("_") ||
      name.include?("_backward") ||
      name.include?("_forward") ||
      excluded_names.include?(name) ||
      unsupported_markers.any? { |marker| f.func.include?(marker) }
  end
end
|
37
|
+
|
38
|
+
# Buckets functions by where they are bound: :nn for python_module "nn",
# otherwise :torch for the "function" variant and :tensor for the
# "method" variant (a function can appear in both).
def group_functions(functions)
  nn, rest = functions.partition { |f| f.python_module == "nn" }

  {
    torch: rest.select { |f| f.variants.include?("function") },
    tensor: rest.select { |f| f.variants.include?("method") },
    nn: nn
  }
end
|
45
|
+
|
46
|
+
# Writes <type>_functions.h and <type>_functions.cpp. Functions sharing a
# base name become one C++ entry point (and one Ruby method) that
# dispatches among the overloads.
def generate_files(type, def_method, functions)
  overloads_by_name = functions.group_by(&:base_name)

  method_defs = overloads_by_name.map do |name, overloads|
    generate_method_def(name, overloads, type, def_method)
  end
  attach_defs = overloads_by_name.keys.map { |name| generate_attach_def(name, type, def_method) }

  write_header(type)
  write_body(type, method_defs, attach_defs)
end
|
56
|
+
|
57
|
+
# Writes <type>_functions.h, which declares add_<type>_functions(Module m).
def write_header(type)
  template = <<~EOS
    // generated by rake generate:functions
    // do not edit by hand

    #pragma once

    void add_%{type}_functions(Module m);
  EOS

  write_file("#{type}_functions.h", format(template, type: type))
end
|
70
|
+
|
71
|
+
# Writes <type>_functions.cpp: the generated method definitions followed by
# add_<type>_functions(Module m), which attaches them. The cuda_lazy_init
# include is emitted for every type except "nn".
def write_body(type, method_defs, attach_defs)
  cuda_include = type == "nn" ? nil : "\n#include \"torch/csrc/utils/cuda_lazy_init.h\"\n"

  template = <<~EOS
    // generated by rake generate:functions
    // do not edit by hand

    #include <torch/torch.h>
    #include <rice/Module.hpp>

    #include "ruby_arg_parser.h"
    #include "templates.h"
    #include "wrap_outputs.h"
    %{cuda_lazy_init}
    %{method_defs}
    void add_%{type}_functions(Module m) {
    %{attach_defs}
    }
  EOS

  contents = format(
    template,
    type: type,
    method_defs: method_defs.join("\n"),
    attach_defs: attach_defs.join("\n "),
    cuda_lazy_init: cuda_include
  )
  write_file("#{type}_functions.cpp", contents)
end
|
99
|
+
|
100
|
+
# Writes contents to ext/torch/<name>, resolved relative to this script's
# directory, and returns File.write's result.
def write_file(name, contents)
  output_path = File.join(File.expand_path("../ext/torch", __dir__), name)
  File.write(output_path, contents)
end
|
104
|
+
|
105
|
+
# Returns the C++ statement that registers <type>_<name> on the module.
# The Ruby name follows Ruby conventions: a trailing "_" becomes "!" and an
# "is_" prefix becomes a trailing "?".
def generate_attach_def(name, type, def_method)
  ruby_name =
    case name
    when /_\z/ then "#{name.chomp("_")}!"
    when /\Ais_/ then "#{name.sub(/\Ais_/, "")}?"
    else name
    end

  # these three are registered under an underscore-prefixed name
  # (presumably wrapped by hand-written Ruby methods — TODO confirm)
  ruby_name = "_#{ruby_name}" if %w[size stride random!].include?(ruby_name)

  %(rb_#{def_method}(m, "#{ruby_name}", #{type}_#{name}, -1);)
end
|
119
|
+
|
120
|
+
# Builds the C++ entry point for one Ruby method. The function parses argv
# against every overload signature with a static RubyArgParser, then
# dispatches on the matched overload.
def generate_method_def(name, functions, type, def_method)
  # Tensor methods bind the receiver (self_) to a Tensor& named self
  self_binding = type == "tensor" ? "\n Tensor& self = from_ruby<Tensor&>(self_);" : ""

  overloads = group_overloads(functions, type)
  signatures = overloads.map { |overload| overload["signature"] }
  # slots for the widest signature: commas + 1, minus the bare "*" marker
  # (assumes "*" appears in a signature only as that marker)
  max_args = signatures.map { |sig| sig.count(",") - sig.count("*") }.max + 1

  <<~EOS
    // #{name}
    static VALUE #{type}_#{name}(int argc, VALUE* argv, VALUE self_)
    {
    HANDLE_TH_ERRORS#{self_binding}
    static RubyArgParser parser({
    #{signatures.map(&:inspect).join(",\n ")}
    });
    std::vector<VALUE> parsed_args(#{max_args});
    auto _r = parser.parse(self_, argc, argv, parsed_args);
    #{add_dispatches(overloads, def_method)}
    END_HANDLE_TH_ERRORS
    }
  EOS
end
|
142
|
+
|
143
|
+
# Prefixes every line after the first with one space. String#split drops
# trailing empty lines, so trailing newlines are not preserved.
def indent(code)
  lines = code.split("\n")
  lines.join("\n ")
end
|
146
|
+
|
147
|
+
# Emits the C++ that routes a parsed call to the matching overload.
#
# functions  - overload entries from group_overloads (in parser order)
# def_method - :define_method or :define_singleton_method
#
# With a single overload its dispatch code is returned directly. Otherwise
# each overload becomes a `case i:` of a switch on `_r.idx`, where i is the
# overload's position in `functions`. RETURN_NIL follows the switch.
def add_dispatches(functions, def_method)
  return add_dispatch(functions.first, def_method) if functions.size == 1

  cases = functions.each_with_index.map do |function, i|
    # reuse the shared indent helper instead of re-implementing it inline
    "case #{i}: {\n#{indent(add_dispatch(function, def_method))}\n}"
  end

  "switch (_r.idx) {\n#{cases.join("\n ")}\n}\nRETURN_NIL"
end
|
164
|
+
|
165
|
+
# Emits the dispatch code for one overload entry. When the entry has a
# distinct out variant, the code branches on whether the out argument was
# passed (`_r.isNone(out_index)`): base variant if not, out variant if so.
def add_dispatch(function, def_method)
  base = function["base"]
  out = function["out"]
  return generate_dispatch(base, def_method) unless out && out != base

  base_code = generate_dispatch(base, def_method)
  out_code = generate_dispatch(out, def_method)

  "if (_r.isNone(#{out.out_index})) {\n#{indent(base_code)}\n} else {\n#{indent(out_code)}\n}"
end
|
180
|
+
|
181
|
+
# Merges the schemas for one base name into overload entries, keyed by
# their signature with out params removed, so a function and its out=
# variant share an entry.
#
# Each entry Hash has "base", "signature" and, when an out variant exists,
# "out". The out variant's full signature then replaces "signature".
# Returns the entries in sort_functions order.
def group_overloads(functions, type)
  entries = Hash.new { |hash, key| hash[key] = {} }

  functions.each do |function|
    key = generate_signature(function, type, skip_out: true)
    entry = entries[key]

    unless function.out?
      entry["base"] = function
      entry["signature"] ||= key
      next
    end

    entry["out"] = function
    entry["signature"] = generate_signature(function, type)
    # for now, an out variant with no non-out sibling (so far) is also the base
    entry["base"] ||= function
  end

  # NOTE(review): both branches above set "base", so this looks unreachable
  puts "Missing base: #{functions.first.name}" if entries.values.any? { |entry| !entry["base"] }
  sort_functions(entries.values)
end
|
202
|
+
|
203
|
+
# Orders overload entries so those without an out variant come first.
# Ruby's sort_by is not guaranteed to be stable, so ties are broken by
# original position. This keeps the generated parser signature order (and
# the matching `_r.idx` cases) deterministic.
#
# functions - Array of overload entry Hashes (see group_overloads)
# Returns a new Array.
def sort_functions(functions)
  # TODO: smarter overload ordering
  functions
    .each_with_index
    .sort_by { |function, position| [function["out"] ? 1 : 0, position] }
    .map(&:first)
end
|
207
|
+
|
208
|
+
# Emits C++ for calling one Function. The output has three parts:
# a comment with the schema plus any TensorOptions/out setup,
# a `dispatch_<cpp_name>` lambda wrapping the ATen call, and the statement
# that invokes the lambda with the parsed arguments.
def generate_dispatch(function, def_method)
  cpp_name = function.out? ? "#{function.base_name}_out" : function.base_name
  # Tensor methods get self from the receiver, not from the argument list
  remove_self = def_method == :define_method

  params = function.params.map(&:dup)
  set_param_position(params, remove_self)
  params, opt_params = split_opt_params(params)
  opt_index = opt_params.map { |param| param[:position] }.min unless opt_params.empty?

  cpp_params = generate_dispatch_params(function, params)
  cpp_params.insert(opt_index + (remove_self ? 1 : 0), "const TensorOptions & options") if opt_index

  retval = generate_dispatch_retval(function)
  dispatch_code = generate_dispatch_code(function, def_method, params, opt_index, remove_self)
  function_code = generate_function_code(function, cpp_name, params, opt_index, remove_self)

  multi_tensor_out = function.out? && function.retvals.size > 1 && function.retvals.all? { |v| v[:type] == "Tensor" }
  out_var = generate_out_var(function.out_index, function.retvals.size) if multi_tensor_out
  tensor_options = generate_tensor_options(function, opt_params) if opt_params.any?

  [
    "// #{function.func}#{tensor_options}#{out_var}",
    "auto dispatch_#{cpp_name} = [](#{cpp_params.join(", ")}) -> #{retval} {",
    "// in future, release GVL",
    dispatch_code,
    "};",
    function_code
  ].join("\n")
end
|
238
|
+
|
239
|
+
# Emits the C++ line that reads `size` out tensors, starting at parser
# index out_index, into a local named `out`.
def generate_out_var(out_index, size)
  "\n auto out = _r.tensorlist_n<%s>(%s);" % [size, out_index]
end
|
242
|
+
|
243
|
+
# Assigns each param Hash its parser argument index under :position,
# mutating the Hashes in place. When remove_self is true, "self" is
# skipped (it gets no :position) and does not consume an index.
# Returns params.
def set_param_position(params, remove_self)
  params
    .reject { |param| remove_self && param[:name] == "self" }
    .each_with_index { |param, position| param[:position] = position }
  params
end
|
251
|
+
|
252
|
+
# Separates TensorOptions members (dtype, device, layout, requires_grad,
# pin_memory) from the other params.
#
# params - Array of param Hashes
#
# Returns [other_params, opt_params] when at least four option members
# are present, so they can be packed into a single TensorOptions argument.
# Otherwise returns [params, []] and nothing is packed.
def split_opt_params(params)
  option_names = ["dtype", "device", "layout", "requires_grad", "pin_memory"]

  # partition yields a single param Hash per element (no index)
  opt_params, other_params = params.partition { |param| option_names.include?(param[:name]) }
  opt_params.size >= 4 ? [other_params, opt_params] : [params, []]
end
|
262
|
+
|
263
|
+
# Emits C++ that builds `options` (a TensorOptions) from the parsed option
# arguments in a fixed order, then calls maybe_initialize_cuda on it.
# "arange" reads its dtype through scalartypeOptional; all other functions
# use scalartype.
def generate_tensor_options(function, opt_params)
  order = ["dtype", "device", "layout", "requires_grad", "pin_memory"]

  setters = opt_params.sort_by { |opt| order.index(opt[:name]) }.map do |opt|
    position = opt[:position]
    case opt[:name]
    when "dtype"
      accessor = function.base_name == "arange" ? "scalartypeOptional" : "scalartype"
      "dtype(_r.#{accessor}(#{position}))"
    when "device" then "device(_r.device(#{position}))"
    when "layout" then "layout(_r.layoutOptional(#{position}))"
    when "requires_grad" then "requires_grad(_r.toBool(#{position}))"
    when "pin_memory" then "pinned_memory(_r.toBool(#{position}))"
    end
  end

  builder = "\n const auto options = TensorOptions()" + setters.map { |setter| "\n .#{setter}" }.join
  "#{builder};\n torch::utils::maybe_initialize_cuda(options);"
end
|
292
|
+
|
293
|
+
# Emits the statement that calls dispatch_<cpp_name> with the parsed
# arguments. `options` is spliced in at opt_index, shifted by one when self
# leads the argument list. The result is wrapped for Ruby; functions with
# no return values are followed by RETURN_NIL instead.
def generate_function_code(function, cpp_name, params, opt_index, remove_self)
  args = generate_function_params(function, params, remove_self)
  args.insert(remove_self ? opt_index + 1 : opt_index, "options") if opt_index

  call = "dispatch_#{cpp_name}(#{args.join(", ")})"
  function.retvals.empty? ? "#{call};\nRETURN_NIL" : "return wrap(#{call});"
end
|
307
|
+
|
308
|
+
# Maps each param to the C++ expression that reads it from the parser
# (`_r.<accessor>(<position>)`).
#
# "self" stays the literal `self` for Tensor methods. When the function has
# several Tensor out params, every param from out_index on reads from the
# `out` tensor list instead. Raises RuntimeError for unsupported types.
def generate_function_params(function, params, remove_self)
  out_var = function.out? && function.retvals.size > 1 && function.retvals.all? { |v| v[:type] == "Tensor" }

  params.each_with_index.map do |param, index|
    next "self" if remove_self && param[:name] == "self"
    next "out[#{index - function.out_index}]" if out_var && index >= function.out_index

    accessor =
      case param[:type]
      when "Tensor" then "tensor"
      when "Tensor[]" then "tensorlist"
      when /\Aint\[/ then "intlist"
      when "Scalar" then "scalar"
      when "bool" then "toBool"
      when "int" then "toInt64"
      when "float" then "toDouble"
      when "ScalarType" then "scalartype"
      when "str" then "string"
      when "Generator" then "generator"
      when "MemoryFormat" then "memoryformat"
      when "Storage" then "storage"
      else raise "Unknown type: #{param[:type]} (#{function.name})"
      end

    if param[:optional]
      accessor =
        case accessor
        when "tensor" then function.out? ? "tensor" : "optionalTensor"
        when "generator", "tensorlist", "intlist" then accessor
        else "#{accessor}Optional"
        end
    end

    "_r.#{accessor}(#{param[:position]})"
  end
end
|
369
|
+
|
370
|
+
# Builds the body of a dispatch lambda: the underlying ATen/torch call.
#
# function    - Function being dispatched
# def_method  - :define_method or :define_singleton_method (unused here)
# params      - param Hashes (TensorOptions members already split out)
# opt_index   - where the packed `options` argument goes, or nil
# remove_self - true for Tensor methods, which call `self.<name>(...)`
#
# Returns a C++ statement, prefixed with "return " unless the function has
# no return values.
def generate_dispatch_code(function, def_method, params, opt_index, remove_self)
  # torch::empty sets requires_grad but at::empty doesn't, so calls that
  # take TensorOptions go through the torch:: namespace
  # https://github.com/pytorch/pytorch/issues/36455
  prefix =
    if remove_self
      "self."
    elsif opt_index
      "torch::"
    else
      "at::"
    end
  dispatch = function.out? ? "#{function.base_name}_out" : function.base_name

  args = params.map { |v| v[:name] }
  args.reject! { |v| v == "self" } if remove_self
  args.insert(opt_index, "options") if opt_index

  if function.out_index
    # move the out params to the front of the argument list; splat so the
    # list stays flat instead of relying on Array#join flattening a nested array
    # NOTE(review): out_index is an index into function.params, taken before
    # the self removal / options insertion above; this assumes those never
    # shift the out params — confirm
    args.unshift(*args.slice!(function.out_index, function.retvals.size))
  end

  code = "#{prefix}#{dispatch}(#{args.join(", ")});"
  function.retvals.empty? ? code : "return #{code}"
end
|
388
|
+
|
389
|
+
# Maps each param to a C++ lambda parameter declaration ("<type> <name>").
#
# Non-Tensor optionals become c10::optional<T>. Optional Tensors become
# OptionalTensor, or a plain const reference for out functions. Written-to
# Tensors are taken by reference only when the function returns several
# values. Raises RuntimeError for unsupported types.
def generate_dispatch_params(function, params)
  params.map do |param|
    type =
      case param[:type]
      when "Tensor"
        if param[:optional]
          # TODO: "const c10::optional<at::Tensor> &"
          function.out? ? "const Tensor &" : "const OptionalTensor &"
        elsif param[:modifier]
          param[:modifier].include?("!") && function.retvals.size > 1 ? "Tensor &" : "Tensor"
        else
          "const Tensor &"
        end
      when "Tensor[]"
        "TensorList"
      when "int"
        "int64_t"
      when "float"
        "double"
      when /\Aint\[/
        "IntArrayRef"
      when "str"
        "std::string"
      when "Scalar", "bool", "ScalarType", "Layout", "Device", "Storage", "Generator", "MemoryFormat"
        param[:type]
      else
        raise "Unknown type: #{param[:type]} (#{function.name})"
      end

    type = "c10::optional<#{type}>" if param[:optional] && param[:type] != "Tensor"

    "#{type} #{param[:name]}"
  end
end
|
434
|
+
|
435
|
+
# Maps a function's return types to the C++ return type of its dispatch
# lambda. Schema "float" and "int" map to double and int64_t, consistent
# with the parameter mapping, so no return value is narrowed.
# Raises RuntimeError for unsupported return-type combinations.
def generate_dispatch_retval(function)
  types = function.retvals.map { |r| r[:type] }

  case types
  when []
    "void"
  when ["bool"]
    "bool"
  when ["int"]
    "int64_t"
  when ["float"]
    "double"
  when ["Scalar"]
    "Scalar"
  when ["ScalarType"]
    "ScalarType"
  when ["QScheme"]
    "QScheme"
  when ["Tensor"]
    "Tensor"
  when ["Tensor[]"]
    "std::vector<Tensor>"
  when ["Tensor", "Tensor"]
    "std::tuple<Tensor,Tensor>"
  when ["Tensor", "Tensor", "Tensor"]
    "std::tuple<Tensor,Tensor,Tensor>"
  when ["Tensor", "Tensor", "Tensor", "Tensor"]
    "std::tuple<Tensor,Tensor,Tensor,Tensor>"
  when ["Tensor", "Tensor", "Tensor", "Tensor", "Tensor"]
    "std::tuple<Tensor,Tensor,Tensor,Tensor,Tensor>"
  when ["Tensor", "Tensor", "float", "int"]
    # was std::tuple<Tensor,Tensor,float,int>, which narrowed the double result
    "std::tuple<Tensor,Tensor,double,int64_t>"
  else
    raise "Unknown retvals: #{types}"
  end
end
|
471
|
+
|
472
|
+
# Builds the RubyArgParser signature string for a Function, e.g.
# "add(Tensor input, Tensor other, *, Scalar alpha=1)".
#
# With skip_out: true, out params are dropped. Otherwise several Tensor out
# params are merged into one keyword-only "Tensor[N] out" list. For Tensor
# methods (type == "tensor"), the positional self is omitted.
def generate_signature(function, type, skip_out: false)
  params = function.params.dup

  if function.out?
    out_at = function.out_index
    out_count = function.retvals.size
    if skip_out
      # remove out
      params.slice!(out_at, out_count)
    elsif out_count > 1 && params[out_at, out_count].all? { |r| r[:type] == "Tensor" }
      # combine tensor into tensorlist
      params[out_at, out_count] = [{name: "out", type: "Tensor[#{out_count}]", list_size: out_count, keyword_only: true}]
    end
  end

  positional = params.reject { |v| v[:keyword_only] || (type == "tensor" && v[:name] == "self") }
  keyword = params.select { |v| v[:keyword_only] }
  parts = keyword.empty? ? positional : positional + ["*"] + keyword

  "#{function.base_name}(#{parts.map { |v| signature_param(v) }.join(", ")})"
end
|
495
|
+
|
496
|
+
# Renders one signature entry as "<type> <name>[=default]"; the bare "*"
# keyword marker passes through unchanged. self is renamed to input,
# "[]" defaults render as None, and "Mean" renders as at::Reduction::Mean.
def signature_param(param)
  return "*" if param == "*"

  display_name = param[:name] == "self" ? "input" : param[:name]

  default_suffix =
    case param[:default]
    when nil then ""
    when "[]" then "=None"
    when "Mean" then "=at::Reduction::Mean"
    else "=#{param[:default]}"
    end

  # hack: any param named "out" gets "=None" appended
  out_suffix = param[:name] == "out" ? "=None" : ""

  "#{signature_type(param)} #{display_name}#{default_suffix}#{out_suffix}"
end
|
519
|
+
|
520
|
+
# Maps a parsed param to the type token used in RubyArgParser signatures.
# Appends "[N]" for fixed-size lists and "?" for optional params.
#
# param - Hash with :type, and optionally :list_size / :optional
# Returns a String; raises RuntimeError for unsupported types.
def signature_type(param)
  type =
    case param[:type]
    when "Tensor", /\ATensor\([a-z]!?\)\z/
      "Tensor"
    when /\ATensor\[\d*\]\z/
      # anchored with \A (was a stray \T, which left the start unanchored)
      "TensorList"
    when /\ADimname\[\d*\]\z/
      "DimnameList"
    when /\Aint\[\d*\]\z/
      "IntArrayRef"
    when "int"
      "int64_t"
    when "float"
      "double"
    when "str"
      "std::string"
    when "Scalar", "Dimname", "bool", "ScalarType", "Layout", "Device", "Generator", "MemoryFormat", "Storage"
      param[:type]
    else
      raise "Unknown type: #{param[:type]}"
    end

  type += "[#{param[:list_size]}]" if param[:list_size]
  type += "?" if param[:optional]
  type
end
|