torch-rb 0.3.4 → 0.4.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +28 -0
- data/README.md +2 -1
- data/codegen/function.rb +134 -0
- data/codegen/generate_functions.rb +549 -0
- data/{lib/torch/native → codegen}/native_functions.yaml +0 -0
- data/ext/torch/ext.cpp +76 -87
- data/ext/torch/extconf.rb +5 -2
- data/ext/torch/nn_functions.h +6 -0
- data/ext/torch/ruby_arg_parser.cpp +593 -0
- data/ext/torch/ruby_arg_parser.h +373 -0
- data/ext/torch/{templates.hpp → templates.h} +87 -97
- data/ext/torch/tensor_functions.h +6 -0
- data/ext/torch/torch_functions.h +6 -0
- data/ext/torch/utils.h +42 -0
- data/ext/torch/{templates.cpp → wrap_outputs.h} +44 -7
- data/lib/torch.rb +51 -77
- data/lib/torch/nn/functional.rb +142 -18
- data/lib/torch/nn/init.rb +5 -19
- data/lib/torch/nn/leaky_relu.rb +3 -3
- data/lib/torch/nn/module.rb +9 -1
- data/lib/torch/nn/upsample.rb +31 -0
- data/lib/torch/optim/adadelta.rb +1 -1
- data/lib/torch/optim/adam.rb +2 -2
- data/lib/torch/optim/adamax.rb +1 -1
- data/lib/torch/optim/adamw.rb +1 -1
- data/lib/torch/optim/asgd.rb +1 -1
- data/lib/torch/optim/sgd.rb +3 -3
- data/lib/torch/tensor.rb +36 -115
- data/lib/torch/utils/data/data_loader.rb +2 -0
- data/lib/torch/utils/data/tensor_dataset.rb +2 -0
- data/lib/torch/version.rb +1 -1
- metadata +19 -14
- data/lib/torch/native/dispatcher.rb +0 -48
- data/lib/torch/native/function.rb +0 -115
- data/lib/torch/native/generator.rb +0 -163
- data/lib/torch/native/parser.rb +0 -140
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: a40eda909da15ec34573f3fa447519c3cb5553092d6eb35189d02c88154f1669
+  data.tar.gz: cb795ca4c53189534f306c874f90db531c3b14fd17a14d83a592d2e73dd255ff
 SHA512:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: 750e2103e8ab7b029f7f3acd439088ba41890d1aabab5761d1ea768f556ae7550d2cef91545e83eebb9b09ef9b76d9dac939f9fa80a7c1478dd7d4668a2a6c0e
+  data.tar.gz: d22f3569dc06cc653bfc0d2786eb952ffad858dd2b235f736ebbabbc2f26b73f4c5b92c3776f2004b8076818730de1f1012b7a9ef615a84909c806abf7e96f52
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,31 @@
+## 0.4.1 (2020-10-12)
+
+- Fixed installation error with Ruby < 2.7
+
+## 0.4.0 (2020-09-27)
+
+- Improved performance of methods
+- Improved performance of tensor indexing
+
+## 0.3.7 (2020-09-22)
+
+- Improved performance
+- Added `Upsample`
+- Added support for passing tensor class to `type` method
+- Fixed error with buffers on GPU
+- Fixed error with `new_full`
+- Fixed issue with `numo` method and non-contiguous tensors
+
+## 0.3.6 (2020-09-17)
+
+- Added `inplace` option for leaky ReLU
+- Fixed error with methods that return a tensor list (`chunk`, `split`, and `unbind`)
+- Fixed error with buffers on GPU
+
+## 0.3.5 (2020-09-04)
+
+- Fixed error with data loader (due to `dtype` of `randperm`)
+
 ## 0.3.4 (2020-08-26)
 
 - Added `Torch.clamp` method
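The 0.3.6/0.3.7 entries above add `Upsample` and an `inplace` option for leaky ReLU. A minimal usage sketch — the keyword names are assumed to mirror PyTorch's `nn.Upsample` and `nn.LeakyReLU`, not quoted from this diff:

```ruby
require "torch"

# Upsample, added in 0.3.7 (scale_factor:/mode: keywords are assumed)
up = Torch::NN::Upsample.new(scale_factor: 2, mode: "nearest")
x = Torch.rand(1, 1, 2, 2)
up.call(x).shape # => [1, 1, 4, 4]

# inplace leaky ReLU, added in 0.3.6 (inplace: keyword is assumed)
relu = Torch::NN::LeakyReLU.new(inplace: true)
relu.call(x)
```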
data/README.md
CHANGED
@@ -402,6 +402,7 @@ Here are a few full examples:
 - [Image classification with MNIST](examples/mnist) ([日本語版](https://qiita.com/kojix2/items/c19c36dc1bf73ea93409))
 - [Collaborative filtering with MovieLens](examples/movielens)
 - [Sequence models and word embeddings](examples/nlp)
+- [Generative adversarial networks](examples/gan)
 
 ## LibTorch Installation
 
@@ -415,7 +416,7 @@ Here’s the list of compatible versions.
 
 Torch.rb | LibTorch
 --- | ---
-0.3.0
+0.3.0+ | 1.6.0
 0.2.0-0.2.7 | 1.5.0-1.5.1
 0.1.8 | 1.4.0
 0.1.0-0.1.7 | 1.3.1
data/codegen/function.rb
ADDED
@@ -0,0 +1,134 @@
+class Function
+  attr_reader :definition, :params, :retvals
+
+  def initialize(definition)
+    @definition = definition
+    @params, @retvals = parse_func
+  end
+
+  def name
+    func.split("(", 2).first
+  end
+
+  def base_name
+    name.split(".").first
+  end
+
+  def func
+    definition["func"]
+  end
+
+  def python_module
+    definition["python_module"]
+  end
+
+  def variants
+    (definition["variants"] || "function").split(", ")
+  end
+
+  def out_index
+    params.index { |v| v[:modifier].to_s.include?("!") } if base_name[-1] != "_" && retvals.any?
+  end
+
+  def out?
+    !out_index.nil?
+  end
+
+  private
+
+  def parse_func
+    input, output = func.split(/\s*->\s*/)
+    [generate_params(input), generate_retvals(output)]
+  end
+
+  def generate_params(input)
+    input = input.split("(", 2).last.chomp(")").split(/\s*,\s+/)
+
+    keyword_only = false
+    params = []
+    input.each do |i|
+      if i == "*"
+        keyword_only = true
+        next
+      end
+
+      type, name = i.split(/\s+/)
+
+      if name.include?("=")
+        name, default = name.split("=", 2)
+      end
+
+      optional = false
+      if type.include?("?")
+        optional = true unless ["dtype", "device", "layout", "pin_memory"].include?(name)
+        type = type.delete("?")
+      end
+
+      type, modifier = extract_modifier(type)
+
+      if type.include?("[")
+        list_size = /\[(.*)\]/.match(type)[1]
+        list_size = nil if list_size.empty?
+      end
+
+      if name == "dtype" && (base_name.start_with?("randperm") || base_name == "tril_indices" || base_name == "triu_indices")
+        # dtype hack
+        # https://github.com/pytorch/pytorch/blob/v1.6.0/tools/autograd/gen_python_functions.py#L1307-L1311
+        default = "torch.int64"
+      end
+
+      params << {
+        name: name,
+        type: type,
+        default: default,
+        keyword_only: keyword_only,
+        optional: optional,
+        modifier: modifier,
+        list_size: list_size
+      }
+    end
+
+    if (params.map { |v| v[:name] } & ["dtype", "device", "layout", "pin_memory"]).size == 4
+      params << {
+        name: "requires_grad",
+        type: "bool",
+        default: "False",
+        keyword_only: true,
+        optional: false,
+        modifier: nil,
+        list_size: nil
+      }
+    end
+
+    params
+  end
+
+  def generate_retvals(output)
+    output =
+      if output == "()"
+        []
+      elsif output[0] == "("
+        output[1..-2].split(/\s*,\s*/)
+      else
+        [output]
+      end
+
+    retvals = []
+    output.each do |o|
+      type, name = o.split(/\s+/)
+      type, modifier = extract_modifier(type)
+      retvals << {name: name, type: type, modifier: modifier}
+    end
+    retvals
+  end
+
+  # Tensor(a), Tensor(a!), Tensor(a)[]
+  def extract_modifier(type)
+    if type.include?("(")
+      parts = type.split(/[\(\)]/, 3)
+      modifier = parts.delete_at(1)
+      type = parts.join("")
+    end
+    [type, modifier]
+  end
+end
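To see what `Function` extracts from a schema, here is a small sketch; the schema string is shaped like a `native_functions.yaml` entry but was chosen for illustration, not quoted from this diff:

```ruby
require_relative "function"

f = Function.new(
  "func" => "add.out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)"
)

f.name       # => "add.out"
f.base_name  # => "add"
f.variants   # => ["function"] (default when "variants" is absent)
f.params.map { |p| p[:name] } # => ["self", "other", "out"] ("*" only flips keyword_only)
f.out_index  # => 2, since "out" carries the "a!" modifier
f.out?       # => true
f.retvals    # => [{name: nil, type: "Tensor", modifier: "a!"}]
```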
data/codegen/generate_functions.rb
ADDED
@@ -0,0 +1,549 @@
+require "yaml"
+# use require_relative for
+# rake generate:function (without bundle)
+require_relative "function"
+
+def generate_functions
+  functions = load_functions
+  functions = skip_functions(functions)
+  functions = group_functions(functions)
+
+  generate_files("torch", :define_singleton_method, functions[:torch])
+  generate_files("tensor", :define_method, functions[:tensor])
+  generate_files("nn", :define_singleton_method, functions[:nn])
+end
+
+def load_functions
+  path = File.expand_path("native_functions.yaml", __dir__)
+  YAML.load_file(path).map { |f| Function.new(f) }.sort_by(&:name)
+end
+
+def skip_functions(functions)
+  functions.reject do |f|
+    f.base_name.start_with?("_") ||
+    f.base_name.include?("_backward") ||
+    f.base_name.include?("_forward") ||
+    f.base_name == "to" ||
+    # in ext.cpp
+    f.base_name == "index" ||
+    f.base_name == "index_put_" ||
+    # need to add to ext.cpp
+    f.base_name == "index_put" ||
+    # not supported yet
+    f.func.include?("Dimname") ||
+    f.func.include?("ConstQuantizerPtr")
+  end
+end
+
+def group_functions(functions)
+  nn_functions, other_functions = functions.partition { |f| f.python_module == "nn" }
+  torch_functions = other_functions.select { |f| f.variants.include?("function") }
+  tensor_functions = other_functions.select { |f| f.variants.include?("method") }
+
+  {torch: torch_functions, tensor: tensor_functions, nn: nn_functions}
+end
+
+def generate_files(type, def_method, functions)
+  method_defs = []
+  attach_defs = []
+  functions.group_by(&:base_name).each do |name, grouped_functions|
+    method_defs << generate_method_def(name, grouped_functions, type, def_method)
+    attach_defs << generate_attach_def(name, type, def_method)
+  end
+  write_header(type)
+  write_body(type, method_defs, attach_defs)
+end
+
+def write_header(type)
+  template = <<~EOS
+    // generated by rake generate:functions
+    // do not edit by hand
+
+    #pragma once
+
+    void add_%{type}_functions(Module m);
+  EOS
+
+  contents = template % {type: type}
+  write_file("#{type}_functions.h", contents)
+end
+
+def write_body(type, method_defs, attach_defs)
+  cuda_lazy_init = %{\n#include "torch/csrc/utils/cuda_lazy_init.h"\n} unless type == "nn"
+
+  template = <<~EOS
+    // generated by rake generate:functions
+    // do not edit by hand
+
+    #include <torch/torch.h>
+    #include <rice/Module.hpp>
+
+    #include "ruby_arg_parser.h"
+    #include "templates.h"
+    #include "wrap_outputs.h"
+    %{cuda_lazy_init}
+    %{method_defs}
+    void add_%{type}_functions(Module m) {
+      %{attach_defs}
+    }
+  EOS
+
+  contents = template % {
+    type: type,
+    method_defs: method_defs.join("\n"),
+    attach_defs: attach_defs.join("\n  "),
+    cuda_lazy_init: cuda_lazy_init
+  }
+  write_file("#{type}_functions.cpp", contents)
+end
+
+def write_file(name, contents)
+  path = File.expand_path("../ext/torch", __dir__)
+  File.write(File.join(path, name), contents)
+end
+
+def generate_attach_def(name, type, def_method)
+  ruby_name =
+    if name.end_with?("_")
+      "#{name[0..-2]}!"
+    elsif name.start_with?("is_")
+      "#{name[3..-1]}?"
+    else
+      name
+    end
+
+  ruby_name = "_#{ruby_name}" if ["size", "stride", "random!"].include?(ruby_name)
+
+  # cast for Ruby < 2.7 https://github.com/thisMagpie/fftw/issues/22#issuecomment-49508900
+  cast = RUBY_VERSION.to_f > 2.7 ? "" : "(VALUE (*)(...)) "
+
+  "rb_#{def_method}(m, \"#{ruby_name}\", #{cast}#{type}_#{name}, -1);"
+end
+
+def generate_method_def(name, functions, type, def_method)
+  assign_self = type == "tensor" ? "\n  Tensor& self = from_ruby<Tensor&>(self_);" : ""
+
+  functions = group_overloads(functions, type)
+  signatures = functions.map { |f| f["signature"] }
+  max_args = signatures.map { |s| s.count(",") - s.count("*") }.max + 1
+
+  template = <<~EOS
+    // #{name}
+    static VALUE #{type}_#{name}(int argc, VALUE* argv, VALUE self_)
+    {
+      HANDLE_TH_ERRORS#{assign_self}
+      static RubyArgParser parser({
+        #{signatures.map(&:inspect).join(",\n        ")}
+      });
+      std::vector<VALUE> parsed_args(#{max_args});
+      auto _r = parser.parse(self_, argc, argv, parsed_args);
+      #{add_dispatches(functions, def_method)}
+      END_HANDLE_TH_ERRORS
+    }
+  EOS
+end
+
+def indent(code)
+  code.split("\n").join("\n  ")
+end
+
+def add_dispatches(functions, def_method)
+  if functions.size == 1
+    add_dispatch(functions.first, def_method)
+  else
+    body = []
+    functions.each_with_index do |f, i|
+      body << "case #{i}: {
+        #{add_dispatch(f, def_method).split("\n").join("\n        ")}
+      }"
+    end
+
+    "switch (_r.idx) {
+      #{body.join("\n      ")}
+    }
+    RETURN_NIL"
+  end
+end
+
+def add_dispatch(function, def_method)
+  if function["out"] && function["out"] != function["base"]
+    base_code = generate_dispatch(function["base"], def_method)
+    out_code = generate_dispatch(function["out"], def_method)
+    out_index = function["out"].out_index
+
+    return "if (_r.isNone(#{out_index})) {
+      #{indent(base_code)}
+    } else {
+      #{indent(out_code)}
+    }"
+  else
+    generate_dispatch(function["base"], def_method)
+  end
+end
+
+def group_overloads(functions, type)
+  grouped = Hash.new { |hash, key| hash[key] = {} }
+
+  functions.each do |function|
+    signature = generate_signature(function, type, skip_out: true)
+    v = grouped[signature]
+    if function.out?
+      v["out"] = function
+      v["signature"] = generate_signature(function, type)
+
+      # for now
+      v["base"] ||= function
+    else
+      v["base"] = function
+      v["signature"] ||= signature
+    end
+  end
+
+  puts "Missing base: #{functions.first.name}" if grouped.any? { |_, v| !v["base"] }
+  sort_functions(grouped.values)
+end
+
+def sort_functions(functions)
+  # TODO
+  functions.sort_by { |f| f["out"] ? 1 : 0 }
+end
+
+def generate_dispatch(function, def_method)
+  cpp_name = function.base_name
+  cpp_name += "_out" if function.out?
+
+  remove_self = def_method == :define_method
+
+  params = function.params.map(&:dup)
+  set_param_position(params, remove_self)
+  params, opt_params = split_opt_params(params)
+  opt_index = opt_params.map { |v| v[:position] }.min if opt_params.any?
+
+  cpp_params = generate_dispatch_params(function, params)
+  if opt_index
+    cpp_params.insert(remove_self ? opt_index + 1 : opt_index, "const TensorOptions & options")
+  end
+
+  retval = generate_dispatch_retval(function)
+  dispatch_code = generate_dispatch_code(function, def_method, params, opt_index, remove_self)
+  function_code = generate_function_code(function, cpp_name, params, opt_index, remove_self)
+
+  out_var = generate_out_var(function.out_index, function.retvals.size) if function.out? && function.retvals.size > 1 && function.retvals.all? { |v| v[:type] == "Tensor" }
+  tensor_options = generate_tensor_options(function, opt_params) if opt_params.any?
+
+  "// #{function.func}#{tensor_options}#{out_var}
+  auto dispatch_#{cpp_name} = [](#{cpp_params.join(", ")}) -> #{retval} {
+    // in future, release GVL
+    #{dispatch_code}
+  };
+  #{function_code}"
+end
+
+def generate_out_var(out_index, size)
+  "\n  auto out = _r.tensorlist_n<#{size}>(#{out_index});"
+end
+
+def set_param_position(params, remove_self)
+  i = 0
+  params.each do |v|
+    next if remove_self && v[:name] == "self"
+    v[:position] = i
+    i += 1
+  end
+end
+
+def split_opt_params(params)
+  option_names = ["dtype", "device", "layout", "requires_grad", "pin_memory"]
+
+  opt_params, other_params = params.partition { |v, i| option_names.include?(v[:name]) }
+  if opt_params.size >= 4
+    [other_params, opt_params]
+  else
+    [params, []]
+  end
+end
+
+def generate_tensor_options(function, opt_params)
+  code = "\n  const auto options = TensorOptions()"
+  order = ["dtype", "device", "layout", "requires_grad", "pin_memory"]
+  opt_params.sort_by { |v| order.index(v[:name]) }.each do |opt|
+    i = opt[:position]
+
+    c =
+      case opt[:name]
+      when "dtype"
+        if function.base_name == "arange"
+          "dtype(_r.scalartypeOptional(#{i}))"
+        else
+          "dtype(_r.scalartype(#{i}))"
+        end
+      when "device"
+        "device(_r.device(#{i}))"
+      when "layout"
+        "layout(_r.layoutOptional(#{i}))"
+      when "requires_grad"
+        "requires_grad(_r.toBool(#{i}))"
+      when "pin_memory"
+        "pinned_memory(_r.toBool(#{i}))"
+      end
+
+    code += "\n      .#{c}"
+  end
+
+  "#{code};\n  torch::utils::maybe_initialize_cuda(options);"
+end
+
+def generate_function_code(function, cpp_name, params, opt_index, remove_self)
+  params = generate_function_params(function, params, remove_self)
+  if opt_index
+    opt_index += 1 if remove_self
+    params.insert(opt_index, "options")
+  end
+
+  code = "dispatch_#{cpp_name}(#{params.join(", ")})"
+  if function.retvals.empty?
+    "#{code};\nRETURN_NIL"
+  else
+    "return wrap(#{code});"
+  end
+end
+
+def generate_function_params(function, params, remove_self)
+  out_var = function.out? && function.retvals.size > 1 && function.retvals.all? { |v| v[:type] == "Tensor" }
+
+  i = 0
+  params.map do |param|
+    i += 1
+
+    next "self" if remove_self && param[:name] == "self"
+    if out_var && i > function.out_index
+      next "out[#{i - function.out_index - 1}]"
+    end
+
+    func =
+      case param[:type]
+      when "Tensor"
+        "tensor"
+      when "Tensor[]"
+        "tensorlist"
+      when /\Aint\[/
+        "intlist"
+      when "Scalar"
+        "scalar"
+      when "bool"
+        "toBool"
+      when "int"
+        "toInt64"
+      when "float"
+        "toDouble"
+      when "ScalarType"
+        "scalartype"
+      when "str"
+        "string"
+      when "Generator"
+        "generator"
+      when "MemoryFormat"
+        "memoryformat"
+      when "Storage"
+        "storage"
+      else
+        raise "Unknown type: #{param[:type]} (#{function.name})"
+      end
+
+    if param[:optional]
+      func =
+        case func
+        when "tensor"
+          if function.out?
+            "tensor"
+          else
+            "optionalTensor"
+          end
+        when "generator", "tensorlist", "intlist"
+          func
+        else
+          "#{func}Optional"
+        end
+    end
+
+    "_r.#{func}(#{param[:position]})"
+  end
+end
+
+def generate_dispatch_code(function, def_method, params, opt_index, remove_self)
+  # torch::empty sets requires_grad but at::empty doesn't
+  # https://github.com/pytorch/pytorch/issues/36455
+  prefix = remove_self ? "self." : (opt_index ? "torch::" : "at::")
+  dispatch = function.out? ? "#{function.base_name}_out" : function.base_name
+
+  params = params.map { |v| v[:name] }
+  params.reject! { |v| v == "self" } if remove_self
+  params.insert(opt_index, "options") if opt_index
+
+  if function.out_index
+    params.unshift(params.slice!(function.out_index, function.retvals.size))
+  end
+
+  code = "#{prefix}#{dispatch}(#{params.join(", ")});"
+  code = "return #{code}" unless function.retvals.empty?
+  code
+end
+
+def generate_dispatch_params(function, params)
+  params.map do |param|
+    type =
+      case param[:type]
+      when "Tensor"
+        if param[:optional]
+          if function.out?
+            "const Tensor &"
+          else
+            # TODO
+            # "const c10::optional<at::Tensor> &"
+            "const OptionalTensor &"
+          end
+        elsif param[:modifier]
+          if param[:modifier].include?("!") && function.retvals.size > 1
+            "Tensor &"
+          else
+            "Tensor"
+          end
+        else
+          "const Tensor &"
+        end
+      when "Tensor[]"
+        "TensorList"
+      when "int"
+        "int64_t"
+      when "float"
+        "double"
+      when /\Aint\[/
+        "IntArrayRef"
+      when "str"
+        "std::string"
+      when "Scalar", "bool", "ScalarType", "Layout", "Device", "Storage", "Generator", "MemoryFormat"
+        param[:type]
+      else
+        raise "Unknown type: #{param[:type]} (#{function.name})"
+      end
+
+    if param[:optional] && param[:type] != "Tensor"
+      type = "c10::optional<#{type}>"
+    end
+
+    "#{type} #{param[:name]}"
+  end
+end
+
+def generate_dispatch_retval(function)
+  types = function.retvals.map { |r| r[:type] }
+
+  case types
+  when []
+    "void"
+  when ["bool"]
+    "bool"
+  when ["int"]
+    "int64_t"
+  when ["float"]
+    "double"
+  when ["Scalar"]
+    "Scalar"
+  when ["ScalarType"]
+    "ScalarType"
+  when ["QScheme"]
+    "QScheme"
+  when ["Tensor"]
+    "Tensor"
+  when ["Tensor[]"]
+    "std::vector<Tensor>"
+  when ["Tensor", "Tensor"]
+    "std::tuple<Tensor,Tensor>"
+  when ["Tensor", "Tensor", "Tensor"]
+    "std::tuple<Tensor,Tensor,Tensor>"
+  when ["Tensor", "Tensor", "Tensor", "Tensor"]
+    "std::tuple<Tensor,Tensor,Tensor,Tensor>"
+  when ["Tensor", "Tensor", "Tensor", "Tensor", "Tensor"]
+    "std::tuple<Tensor,Tensor,Tensor,Tensor,Tensor>"
+  when ["Tensor", "Tensor", "float", "int"]
+    "std::tuple<Tensor,Tensor,float,int>"
+  else
+    raise "Unknown retvals: #{types}"
+  end
+end
+
+def generate_signature(function, type, skip_out: false)
+  params = function.params.dup
+  if function.out?
+    if skip_out
+      # remove out
+      params.slice!(function.out_index, function.retvals.size)
+    elsif function.retvals.size > 1 && params[function.out_index, function.retvals.size].all? { |r| r[:type] == "Tensor" }
+      # combine tensor into tensorlist
+      list_size = function.retvals.size
+      params.slice!(function.out_index, list_size)
+      params.insert(function.out_index, {name: "out", type: "Tensor[#{list_size}]", list_size: list_size, keyword_only: true})
+    end
+  end
+
+  parts = params.select { |v| !v[:keyword_only] && !(type == "tensor" && v[:name] == "self") }
+  keyword_only_parts = params.select { |v| v[:keyword_only] }
+  if keyword_only_parts.any?
+    parts << "*"
+    parts.concat(keyword_only_parts)
+  end
+
+  "#{function.base_name}(#{parts.map { |v| signature_param(v) }.join(", ")})"
+end
+
+def signature_param(param)
+  return "*" if param == "*"
+
+  name = param[:name]
+  name = "input" if name == "self"
+
+  sig = "#{signature_type(param)} #{name}"
+  case param[:default]
+  when nil
+    # do nothing
+  when "[]"
+    sig += "=None"
+  when "Mean"
+    sig += "=at::Reduction::Mean"
+  else
+    sig += "=#{param[:default]}"
+  end
+
+  # hack
+  sig += "=None" if param[:name] == "out"
+
+  sig
+end
+
+def signature_type(param)
+  type =
+    case param[:type]
+    when "Tensor", /\ATensor\([a-z]!?\)\z/
+      "Tensor"
+    when /\ATensor\[\d*\]\z/
+      "TensorList"
+    when /\ADimname\[\d*\]\z/
+      "DimnameList"
+    when /\Aint\[\d*\]\z/
+      "IntArrayRef"
+    when "int"
+      "int64_t"
+    when "float"
+      "double"
+    when "str"
+      "std::string"
+    when "Scalar", "Dimname", "bool", "ScalarType", "Layout", "Device", "Generator", "MemoryFormat", "Storage"
+      param[:type]
+    else
+      raise "Unknown type: #{param[:type]}"
+    end
+
+  type += "[#{param[:list_size]}]" if param[:list_size]
+  type += "?" if param[:optional]
+  type
+end
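The attach helper above encodes the gem's Ruby naming conventions: a trailing `_` becomes a bang method and an `is_` prefix becomes a `?` predicate. A quick sketch of calling it directly, with output as it would look on Ruby 3.x, where the function-pointer cast is skipped:

```ruby
require_relative "generate_functions"

generate_attach_def("abs", "tensor", :define_method)
# => rb_define_method(m, "abs", tensor_abs, -1);

generate_attach_def("clamp_", "tensor", :define_method)
# => rb_define_method(m, "clamp!", tensor_clamp_, -1);

generate_attach_def("is_floating_point", "tensor", :define_method)
# => rb_define_method(m, "floating_point?", tensor_is_floating_point, -1);
```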