torch-rb 0.3.4 → 0.4.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: fa01dc7cfd168494f1066f4e692eb08dee7f419126b2d10a6eb5b2b22fe01526
4
- data.tar.gz: 9e69a7da9ecc85c51bd9cfb42b93f1b1aa6148958ed8bc16ffb0555b6d42159d
3
+ metadata.gz: a40eda909da15ec34573f3fa447519c3cb5553092d6eb35189d02c88154f1669
4
+ data.tar.gz: cb795ca4c53189534f306c874f90db531c3b14fd17a14d83a592d2e73dd255ff
5
5
  SHA512:
6
- metadata.gz: 36bc620212235e57a40d9e791b3eae85fbd73c7224a5445e2e94e49c7a3ec4cb638024193d8e47817e577a72ef96976fdf2851f3607fd3ea1712255bca0e1ec1
7
- data.tar.gz: 888971d06717b610020644ed720ae48fa645a90790d54692ac64a565434b550ce4f778688f0a22806f018656dd9d6183d6510fc458c8c5e82c0629951489279c
6
+ metadata.gz: 750e2103e8ab7b029f7f3acd439088ba41890d1aabab5761d1ea768f556ae7550d2cef91545e83eebb9b09ef9b76d9dac939f9fa80a7c1478dd7d4668a2a6c0e
7
+ data.tar.gz: d22f3569dc06cc653bfc0d2786eb952ffad858dd2b235f736ebbabbc2f26b73f4c5b92c3776f2004b8076818730de1f1012b7a9ef615a84909c806abf7e96f52
@@ -1,3 +1,31 @@
1
+ ## 0.4.1 (2020-10-12)
2
+
3
+ - Fixed installation error with Ruby < 2.7
4
+
5
+ ## 0.4.0 (2020-09-27)
6
+
7
+ - Improved performance of methods
8
+ - Improved performance of tensor indexing
9
+
10
+ ## 0.3.7 (2020-09-22)
11
+
12
+ - Improved performance
13
+ - Added `Upsample`
14
+ - Added support for passing tensor class to `type` method
15
+ - Fixed error with buffers on GPU
16
+ - Fixed error with `new_full`
17
+ - Fixed issue with `numo` method and non-contiguous tensors
18
+
19
+ ## 0.3.6 (2020-09-17)
20
+
21
+ - Added `inplace` option for leaky ReLU
22
+ - Fixed error with methods that return a tensor list (`chunk`, `split`, and `unbind`)
23
+ - Fixed error with buffers on GPU
24
+
25
+ ## 0.3.5 (2020-09-04)
26
+
27
+ - Fixed error with data loader (due to `dtype` of `randperm`)
28
+
1
29
  ## 0.3.4 (2020-08-26)
2
30
 
3
31
  - Added `Torch.clamp` method
data/README.md CHANGED
@@ -402,6 +402,7 @@ Here are a few full examples:
402
402
  - [Image classification with MNIST](examples/mnist) ([日本語版](https://qiita.com/kojix2/items/c19c36dc1bf73ea93409))
403
403
  - [Collaborative filtering with MovieLens](examples/movielens)
404
404
  - [Sequence models and word embeddings](examples/nlp)
405
+ - [Generative adversarial networks](examples/gan)
405
406
 
406
407
  ## LibTorch Installation
407
408
 
@@ -415,7 +416,7 @@ Here’s the list of compatible versions.
415
416
 
416
417
  Torch.rb | LibTorch
417
418
  --- | ---
418
- 0.3.0-0.3.4 | 1.6.0
419
+ 0.3.0+ | 1.6.0
419
420
  0.2.0-0.2.7 | 1.5.0-1.5.1
420
421
  0.1.8 | 1.4.0
421
422
  0.1.0-0.1.7 | 1.3.1
@@ -0,0 +1,134 @@
1
# Parsed view of one native function definition.
#
# `definition` is a Hash holding a "func" schema string, for example
#   "add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)"
# and, optionally, "python_module" and "variants" entries.
class Function
  attr_reader :definition, :params, :retvals

  # Stores the raw definition and eagerly parses its schema into
  # +params+ (Array of Hashes) and +retvals+ (Array of Hashes).
  def initialize(definition)
    @definition = definition
    @params, @retvals = parse_func
  end

  # Everything before the first "(" of the schema, e.g. "add.out".
  def name
    head, = func.split("(", 2)
    head
  end

  # The name without its overload suffix, e.g. "add" for "add.out".
  def base_name
    name.split(".")[0]
  end

  # The raw schema string.
  def func
    definition["func"]
  end

  def python_module
    definition["python_module"]
  end

  # Variants as an Array of Strings; a missing entry means ["function"].
  def variants
    raw = definition["variants"] || "function"
    raw.split(", ")
  end

  # Index of the first parameter whose modifier contains "!", or nil.
  # Always nil for names ending in "_" and for functions without return values.
  def out_index
    return nil if base_name.end_with?("_") || retvals.empty?

    params.index { |param| param[:modifier].to_s.include?("!") }
  end

  def out?
    !out_index.nil?
  end

  private

  # Splits the schema around "->" and parses both halves.
  def parse_func
    input, output = func.split(/\s*->\s*/)
    params = generate_params(input)
    [params, generate_retvals(output)]
  end

  # Parses the parenthesized parameter list into an Array of Hashes with the
  # keys :name, :type, :default, :keyword_only, :optional, :modifier, :list_size.
  def generate_params(input)
    # the separator needs whitespace after the comma, so commas that are
    # directly followed by a non-space character do not split a parameter
    tokens = input.split("(", 2).last.chomp(")").split(/\s*,\s+/)

    # dtype hack
    # https://github.com/pytorch/pytorch/blob/v1.6.0/tools/autograd/gen_python_functions.py#L1307-L1311
    int64_dtype = base_name.start_with?("randperm") || ["tril_indices", "triu_indices"].include?(base_name)

    keyword_only = false
    params = []
    tokens.each do |token|
      # a bare "*" marks every following parameter as keyword-only
      if token == "*"
        keyword_only = true
        next
      end

      type, name = token.split(/\s+/)
      name, default = name.split("=", 2)

      # a "?" suffix marks the parameter optional, except for tensor options
      optional = type.include?("?") && !["dtype", "device", "layout", "pin_memory"].include?(name)
      type = type.delete("?")

      type, modifier = extract_modifier(type)

      list_size = nil
      if type.include?("[")
        bracketed = /\[(.*)\]/.match(type)[1]
        list_size = bracketed unless bracketed.empty?
      end

      default = "torch.int64" if name == "dtype" && int64_dtype

      params << {
        name: name,
        type: type,
        default: default,
        keyword_only: keyword_only,
        optional: optional,
        modifier: modifier,
        list_size: list_size
      }
    end

    # functions taking all four tensor options also get a requires_grad keyword
    names = params.map { |param| param[:name] }
    if ["dtype", "device", "layout", "pin_memory"].all? { |option| names.include?(option) }
      params << {
        name: "requires_grad",
        type: "bool",
        default: "False",
        keyword_only: true,
        optional: false,
        modifier: nil,
        list_size: nil
      }
    end

    params
  end

  # Parses the return part of the schema ("()", a single type, or a
  # parenthesized list) into an Array of {name:, type:, modifier:} Hashes.
  def generate_retvals(output)
    entries =
      if output == "()"
        []
      elsif output.start_with?("(")
        output[1..-2].split(/\s*,\s*/)
      else
        [output]
      end

    entries.map do |entry|
      type, name = entry.split(/\s+/)
      type, modifier = extract_modifier(type)
      {name: name, type: type, modifier: modifier}
    end
  end

  # Tensor(a), Tensor(a!), Tensor(a)[]
  # Returns [type without the parenthesized part, the parenthesized part or nil].
  def extract_modifier(type)
    return [type, nil] unless type.include?("(")

    head, modifier, tail = type.split(/[\(\)]/, 3)
    ["#{head}#{tail}", modifier]
  end
end
@@ -0,0 +1,549 @@
1
+ require "yaml"
2
+ # use require_relative for
3
+ # rake generate:functions (without bundle)
4
+ require_relative "function"
5
+
6
# Entry point of the generator: loads the function definitions, drops the
# skipped ones, groups them, and writes one header/source pair for each of
# the "torch", "tensor" and "nn" groups.
def generate_functions
  grouped = group_functions(skip_functions(load_functions))

  # module-level groups are attached as singleton methods, tensor as instance methods
  generate_files("torch", :define_singleton_method, grouped[:torch])
  generate_files("tensor", :define_method, grouped[:tensor])
  generate_files("nn", :define_singleton_method, grouped[:nn])
end
15
+
16
# Reads native_functions.yaml from this file's directory and returns one
# Function per entry, sorted by name.
def load_functions
  yaml_path = File.expand_path("native_functions.yaml", __dir__)
  entries = YAML.load_file(yaml_path)
  entries.map { |entry| Function.new(entry) }.sort_by(&:name)
end
20
+
21
# Returns the functions the generator should handle, rejecting those that are
# underscore-prefixed, backward/forward variants, handled elsewhere, or that
# use types which are not supported yet.
def skip_functions(functions)
  name_fragments = ["_backward", "_forward"]
  excluded_names = [
    "to",
    # in ext.cpp
    "index",
    "index_put_",
    # need to add to ext.cpp
    "index_put"
  ]
  # not supported yet
  unsupported_types = ["Dimname", "ConstQuantizerPtr"]

  functions.reject do |f|
    base = f.base_name
    base.start_with?("_") ||
      name_fragments.any? { |fragment| base.include?(fragment) } ||
      excluded_names.include?(base) ||
      unsupported_types.any? { |t| f.func.include?(t) }
  end
end
37
+
38
# Splits functions into {torch:, tensor:, nn:} groups.
# Functions whose python_module is "nn" go to :nn only; the others go to
# :torch when they have the "function" variant and to :tensor when they have
# the "method" variant (possibly both).
def group_functions(functions)
  groups = {torch: [], tensor: [], nn: []}

  functions.each do |f|
    if f.python_module == "nn"
      groups[:nn] << f
      next
    end

    variants = f.variants
    groups[:torch] << f if variants.include?("function")
    groups[:tensor] << f if variants.include?("method")
  end

  groups
end
45
+
46
# Generates the C++ method definitions and the attach statements for every
# base name in +functions+, then writes the header and source file for +type+.
def generate_files(type, def_method, functions)
  generated = functions.group_by(&:base_name).map do |name, overloads|
    [
      generate_method_def(name, overloads, type, def_method),
      generate_attach_def(name, type, def_method)
    ]
  end

  method_defs = generated.map(&:first)
  attach_defs = generated.map(&:last)

  write_header(type)
  write_body(type, method_defs, attach_defs)
end
56
+
57
# Writes "<type>_functions.h", which declares add_<type>_functions(Module m).
def write_header(type)
  contents = format(<<~EOS, type: type)
    // generated by rake generate:functions
    // do not edit by hand

    #pragma once

    void add_%{type}_functions(Module m);
  EOS

  write_file("#{type}_functions.h", contents)
end
70
+
71
# Writes "<type>_functions.cpp" through write_file.
#
# The file is rendered from a heredoc template containing, in order: the fixed
# includes, an extra cuda_lazy_init include (left empty when type is "nn"),
# the method definitions joined by newlines, and an
# add_<type>_functions(Module m) function whose body is the attach statements.
#
# type        - String used in the file name and in the function name
# method_defs - Array of C++ function definition Strings
# attach_defs - Array of C++ statement Strings placed inside add_<type>_functions
#
# NOTE(review): whitespace inside the template and the join separators may have
# been collapsed by the rendering this text came from -- confirm against the
# real file before relying on the exact layout of the generated C++.
+ def write_body(type, method_defs, attach_defs)
72
+ cuda_lazy_init = %{\n#include "torch/csrc/utils/cuda_lazy_init.h"\n} unless type == "nn"
73
+
74
+ template = <<~EOS
75
+ // generated by rake generate:functions
76
+ // do not edit by hand
77
+
78
+ #include <torch/torch.h>
79
+ #include <rice/Module.hpp>
80
+
81
+ #include "ruby_arg_parser.h"
82
+ #include "templates.h"
83
+ #include "wrap_outputs.h"
84
+ %{cuda_lazy_init}
85
+ %{method_defs}
86
+ void add_%{type}_functions(Module m) {
87
+ %{attach_defs}
88
+ }
89
+ EOS
90
+
91
+ contents = template % {
92
+ type: type,
93
+ method_defs: method_defs.join("\n"),
94
+ attach_defs: attach_defs.join("\n "),
95
+ cuda_lazy_init: cuda_lazy_init
96
+ }
97
+ write_file("#{type}_functions.cpp", contents)
98
+ end
99
+
100
# Writes +contents+ to the file +name+ inside ../ext/torch (relative to this
# file's directory).
def write_file(name, contents)
  ext_dir = File.expand_path("../ext/torch", __dir__)
  target = File.join(ext_dir, name)
  File.write(target, contents)
end
104
+
105
# Returns the C statement that attaches the generated function to module `m`.
#
# The Ruby-facing name is derived from the native name: a trailing "_" becomes
# "!", a leading "is_" is dropped in favour of a trailing "?", and "size",
# "stride" and "random!" get a leading underscore.
def generate_attach_def(name, type, def_method)
  ruby_name =
    case name
    when /_\z/
      "#{name.chomp("_")}!"
    when /\Ais_/
      "#{name.sub(/\Ais_/, "")}?"
    else
      name
    end

  ruby_name = "_#{ruby_name}" if ["size", "stride", "random!"].include?(ruby_name)

  # cast for Ruby < 2.7 https://github.com/thisMagpie/fftw/issues/22#issuecomment-49508900
  # NOTE(review): RUBY_VERSION.to_f is 2.7 for every 2.7.x release, so this
  # condition also emits the cast on Ruby 2.7 itself -- confirm that is intended.
  cast = RUBY_VERSION.to_f > 2.7 ? "" : "(VALUE (*)(...)) "

  %{rb_#{def_method}(m, "#{ruby_name}", #{cast}#{type}_#{name}, -1);}
end
122
+
123
# Returns the C++ source of one static function named <type>_<name> taking
# (int argc, VALUE* argv, VALUE self_).
#
# The overloads in +functions+ are first grouped by group_overloads; their
# "signature" strings are passed to a static RubyArgParser. parsed_args is
# sized to max_args, computed as the largest (comma count - "*" count) over
# all signatures, plus one. The dispatch body comes from add_dispatches.
# When +type+ is "tensor", an extra line binds `self` from `self_`.
#
# name       - base name shared by the overloads
# functions  - Array of Function objects with that base name
# type       - "torch", "tensor" or "nn" as passed by generate_files
# def_method - forwarded to add_dispatches
#
# NOTE(review): indentation inside the template and the join separators may
# have been collapsed by the rendering this text came from -- confirm.
+ def generate_method_def(name, functions, type, def_method)
124
+ assign_self = type == "tensor" ? "\n Tensor& self = from_ruby<Tensor&>(self_);" : ""
125
+
126
+ functions = group_overloads(functions, type)
127
+ signatures = functions.map { |f| f["signature"] }
128
+ max_args = signatures.map { |s| s.count(",") - s.count("*") }.max + 1
129
+
130
+ template = <<~EOS
131
+ // #{name}
132
+ static VALUE #{type}_#{name}(int argc, VALUE* argv, VALUE self_)
133
+ {
134
+ HANDLE_TH_ERRORS#{assign_self}
135
+ static RubyArgParser parser({
136
+ #{signatures.map(&:inspect).join(",\n ")}
137
+ });
138
+ std::vector<VALUE> parsed_args(#{max_args});
139
+ auto _r = parser.parse(self_, argc, argv, parsed_args);
140
+ #{add_dispatches(functions, def_method)}
141
+ END_HANDLE_TH_ERRORS
142
+ }
143
+ EOS
144
+ end
145
+
146
# Re-joins the lines of +code+ so that every line after the first starts with
# the separator's leading whitespace.
def indent(code)
  lines = code.split("\n")
  lines.join("\n ")
end
149
+
150
# Returns the C++ dispatch body for a list of grouped overloads.
#
# With a single overload the result of add_dispatch is returned directly.
# With several, a `switch (_r.idx)` is emitted with one `case <i>` per
# overload (i is the overload's position in +functions+), followed by
# RETURN_NIL.
#
# functions  - Array of grouped overload Hashes (as built by group_overloads)
# def_method - forwarded to add_dispatch
#
# NOTE(review): the string literals below span several lines; their internal
# whitespace may have been collapsed by the rendering this text came from.
+ def add_dispatches(functions, def_method)
151
+ if functions.size == 1
152
+ add_dispatch(functions.first, def_method)
153
+ else
154
+ body = []
155
+ functions.each_with_index do |f, i|
156
+ body << "case #{i}: {
157
+ #{add_dispatch(f, def_method).split("\n").join("\n ")}
158
+ }"
159
+ end
160
+
161
+ "switch (_r.idx) {
162
+ #{body.join("\n ")}
163
+ }
164
+ RETURN_NIL"
165
+ end
166
+ end
167
+
168
# Returns the C++ dispatch code for one grouped overload Hash.
#
# When the Hash has an "out" entry that differs from its "base" entry, the
# generated code checks `_r.isNone(<out_index>)` and runs the base dispatch if
# the out argument is absent, otherwise the out dispatch. In every other case
# only the base dispatch is generated.
#
# function   - Hash with "base" and optionally "out" Function entries
# def_method - forwarded to generate_dispatch
#
# NOTE(review): the returned string literal spans several lines; its internal
# whitespace may have been collapsed by the rendering this text came from.
+ def add_dispatch(function, def_method)
169
+ if function["out"] && function["out"] != function["base"]
170
+ base_code = generate_dispatch(function["base"], def_method)
171
+ out_code = generate_dispatch(function["out"], def_method)
172
+ out_index = function["out"].out_index
173
+
174
+ return "if (_r.isNone(#{out_index})) {
175
+ #{indent(base_code)}
176
+ } else {
177
+ #{indent(out_code)}
178
+ }"
179
+ else
180
+ generate_dispatch(function["base"], def_method)
181
+ end
182
+ end
183
+
184
# Groups overloads that share the same signature once their out parameters are
# removed. Each group is a Hash with the keys "base", "signature" and, when an
# out variant exists, "out". The groups are returned via sort_functions.
def group_overloads(functions, type)
  grouped = {}

  functions.each do |function|
    key = generate_signature(function, type, skip_out: true)
    entry = (grouped[key] ||= {})

    if function.out?
      entry["out"] = function
      # the out variant's full signature replaces any signature stored so far
      entry["signature"] = generate_signature(function, type)

      # for now
      entry["base"] ||= function
    else
      entry["base"] = function
      entry["signature"] ||= key
    end
  end

  entries = grouped.values
  puts "Missing base: #{functions.first.name}" if entries.any? { |entry| !entry["base"] }
  sort_functions(entries)
end
205
+
206
# Orders grouped overloads so that entries without an "out" variant come
# before entries with one.
#
# Enumerable#sort_by is not guaranteed to be stable, so sorting on the 0/1 key
# alone leaves the relative order inside each half unspecified. The original
# index is used as a tie-breaker to keep the incoming order, which makes the
# order of the generated parser signatures deterministic.
#
# functions - Array of Hashes (as built by group_overloads)
#
# Returns a new Array; the input is not modified.
def sort_functions(functions)
  # TODO
  functions
    .each_with_index
    .sort_by { |f, index| [f["out"] ? 1 : 0, index] }
    .map(&:first)
end
210
+
211
# Returns the C++ code that dispatches one Function: a comment with the
# schema, optional TensorOptions / out-tensor-list setup, a lambda named
# dispatch_<cpp_name>, and the statement that calls it.
#
# cpp_name is the base name, with "_out" appended for out variants.
# remove_self is true when def_method is :define_method.
# Parameters are copied (map(&:dup)) before positions are assigned, so the
# Function's own param Hashes are not modified.
# When split_opt_params finds tensor-option parameters, they are replaced by a
# single "const TensorOptions & options" lambda parameter inserted at the
# smallest option position (shifted by one when remove_self is true).
# The `out` tensor list is only generated for out variants with more than one
# return value, all of type "Tensor".
#
# function   - Function to dispatch
# def_method - :define_method or :define_singleton_method
#
# NOTE(review): the returned string literal spans several lines; its internal
# whitespace may have been collapsed by the rendering this text came from.
+ def generate_dispatch(function, def_method)
212
+ cpp_name = function.base_name
213
+ cpp_name += "_out" if function.out?
214
+
215
+ remove_self = def_method == :define_method
216
+
217
+ params = function.params.map(&:dup)
218
+ set_param_position(params, remove_self)
219
+ params, opt_params = split_opt_params(params)
220
+ opt_index = opt_params.map { |v| v[:position] }.min if opt_params.any?
221
+
222
+ cpp_params = generate_dispatch_params(function, params)
223
+ if opt_index
224
+ cpp_params.insert(remove_self ? opt_index + 1 : opt_index, "const TensorOptions & options")
225
+ end
226
+
227
+ retval = generate_dispatch_retval(function)
228
+ dispatch_code = generate_dispatch_code(function, def_method, params, opt_index, remove_self)
229
+ function_code = generate_function_code(function, cpp_name, params, opt_index, remove_self)
230
+
231
+ out_var = generate_out_var(function.out_index, function.retvals.size) if function.out? && function.retvals.size > 1 && function.retvals.all? { |v| v[:type] == "Tensor" }
232
+ tensor_options = generate_tensor_options(function, opt_params) if opt_params.any?
233
+
234
+ "// #{function.func}#{tensor_options}#{out_var}
235
+ auto dispatch_#{cpp_name} = [](#{cpp_params.join(", ")}) -> #{retval} {
236
+ // in future, release GVL
237
+ #{dispatch_code}
238
+ };
239
+ #{function_code}"
240
+ end
241
+
242
# Returns the C++ line that reads +size+ out tensors, starting at parsed
# argument +out_index+, into a local named `out`.
def generate_out_var(out_index, size)
  format("\n auto out = _r.tensorlist_n<%s>(%s);", size, out_index)
end
245
+
246
# Assigns a zero-based :position to every param Hash, in order.
# When +remove_self+ is true the "self" param is skipped: it gets no position
# and does not consume an index. Mutates and returns +params+.
def set_param_position(params, remove_self)
  position = 0
  params.each do |param|
    unless remove_self && param[:name] == "self"
      param[:position] = position
      position += 1
    end
  end
end
254
+
255
# Separates the tensor-option params (dtype, device, layout, requires_grad,
# pin_memory) from the rest.
#
# Returns [other_params, opt_params] when at least four option params are
# present; otherwise returns [params, []] and leaves everything together.
def split_opt_params(params)
  option_names = ["dtype", "device", "layout", "requires_grad", "pin_memory"]

  opt_params, other_params = params.partition { |param| option_names.include?(param[:name]) }
  return [params, []] if opt_params.size < 4

  [other_params, opt_params]
end
265
+
266
# Returns the C++ lines that build a `TensorOptions options` value from the
# parsed option arguments and then call
# torch::utils::maybe_initialize_cuda(options).
#
# The setters are emitted in the fixed order dtype, device, layout,
# requires_grad, pin_memory, each reading the argument at the param's
# :position. For "arange" the dtype is read with scalartypeOptional.
def generate_tensor_options(function, opt_params)
  order = ["dtype", "device", "layout", "requires_grad", "pin_memory"]

  setters = opt_params.sort_by { |opt| order.index(opt[:name]) }.map do |opt|
    position = opt[:position]

    call =
      case opt[:name]
      when "dtype"
        reader = function.base_name == "arange" ? "scalartypeOptional" : "scalartype"
        "dtype(_r.#{reader}(#{position}))"
      when "device"
        "device(_r.device(#{position}))"
      when "layout"
        "layout(_r.layoutOptional(#{position}))"
      when "requires_grad"
        "requires_grad(_r.toBool(#{position}))"
      when "pin_memory"
        "pinned_memory(_r.toBool(#{position}))"
      end

    "\n .#{call}"
  end

  "\n const auto options = TensorOptions()#{setters.join};\n torch::utils::maybe_initialize_cuda(options);"
end
295
+
296
# Returns the C++ statement(s) that call dispatch_<cpp_name> with the parsed
# arguments. When +opt_index+ is set, "options" is inserted into the argument
# list at that index (shifted by one when +remove_self+ is true).
# Functions without return values yield the call followed by RETURN_NIL;
# all others wrap the call result and return it.
def generate_function_code(function, cpp_name, params, opt_index, remove_self)
  args = generate_function_params(function, params, remove_self)
  args.insert(remove_self ? opt_index + 1 : opt_index, "options") if opt_index

  call = "dispatch_#{cpp_name}(#{args.join(", ")})"
  function.retvals.empty? ? "#{call};\nRETURN_NIL" : "return wrap(#{call});"
end
310
+
311
# Returns, for every param, the C++ expression that produces its value at the
# call site: "self", an element of the `out` tensor list, or a typed read
# from the parsed arguments (`_r.<reader>(<position>)`).
#
# Raises RuntimeError for a param type without a known reader.
def generate_function_params(function, params, remove_self)
  # multi-tensor out variants read their outputs from the `out` list
  tuple_out = function.out? && function.retvals.size > 1 && function.retvals.all? { |v| v[:type] == "Tensor" }

  readers = {
    "Tensor" => "tensor",
    "Tensor[]" => "tensorlist",
    "Scalar" => "scalar",
    "bool" => "toBool",
    "int" => "toInt64",
    "float" => "toDouble",
    "ScalarType" => "scalartype",
    "str" => "string",
    "Generator" => "generator",
    "MemoryFormat" => "memoryformat",
    "Storage" => "storage"
  }

  params.each_with_index.map do |param, index|
    next "self" if remove_self && param[:name] == "self"
    next "out[#{index - function.out_index}]" if tuple_out && index >= function.out_index

    type = param[:type]
    reader = readers[type]
    reader ||= "intlist" if /\Aint\[/ === type
    raise "Unknown type: #{param[:type]} (#{function.name})" unless reader

    if param[:optional]
      reader =
        if reader == "tensor"
          function.out? ? "tensor" : "optionalTensor"
        elsif ["generator", "tensorlist", "intlist"].include?(reader)
          reader
        else
          "#{reader}Optional"
        end
    end

    "_r.#{reader}(#{param[:position]})"
  end
end
372
+
373
# Returns the C++ statement placed inside the dispatch lambda: the call to
# the native function, prefixed with `return` unless the function has no
# return values. For out variants the out arguments are moved to the front
# and "_out" is appended to the function name.
def generate_dispatch_code(function, def_method, params, opt_index, remove_self)
  # torch::empty sets requires_grad but at::empty doesn't
  # https://github.com/pytorch/pytorch/issues/36455
  prefix =
    if remove_self
      "self."
    elsif opt_index
      "torch::"
    else
      "at::"
    end

  callee = function.base_name
  callee = "#{callee}_out" if function.out?

  args = params.map { |v| v[:name] }
  args.delete("self") if remove_self
  args.insert(opt_index, "options") if opt_index

  out_index = function.out_index
  if out_index
    # the removed slice is pushed as one nested element; join flattens it
    out_args = args.slice!(out_index, function.retvals.size)
    args.unshift(out_args)
  end

  call = "#{prefix}#{callee}(#{args.join(", ")});"
  function.retvals.empty? ? call : "return #{call}"
end
391
+
392
# Returns the C++ parameter declarations ("<type> <name>") of the dispatch
# lambda, one String per param.
#
# Tensor params become "const Tensor &" by default, "const OptionalTensor &"
# when optional (plain "const Tensor &" for out variants), and "Tensor" or
# "Tensor &" when they carry a modifier ("Tensor &" only for a "!" modifier
# on a function with more than one return value). Every other optional param
# is wrapped in c10::optional<...>.
#
# function - object responding to out?, retvals and name
# params   - Array of param Hashes (:type, :name, :optional, :modifier)
#
# Raises RuntimeError for a param type without a known C++ type.
def generate_dispatch_params(function, params)
  params.map do |param|
    type =
      case param[:type]
      when "Tensor"
        if param[:optional]
          if function.out?
            "const Tensor &"
          else
            # TODO
            # "const c10::optional<at::Tensor> &"
            "const OptionalTensor &"
          end
        elsif param[:modifier]
          if param[:modifier].include?("!") && function.retvals.size > 1
            "Tensor &"
          else
            "Tensor"
          end
        else
          "const Tensor &"
        end
      when "Tensor[]"
        "TensorList"
      when "int"
        "int64_t"
      when "float"
        "double"
      when /\Aint\[/
        "IntArrayRef"
      when "str"
        "std::string"
      # these schema type names are used unchanged as C++ type names
      # ("Storage" was previously listed twice, which triggers a
      # duplicated-when warning)
      when "Scalar", "bool", "ScalarType", "Layout", "Device", "Storage", "Generator", "MemoryFormat"
        param[:type]
      else
        raise "Unknown type: #{param[:type]} (#{function.name})"
      end

    # optional tensors are already handled by the Tensor branch above
    if param[:optional] && param[:type] != "Tensor"
      type = "c10::optional<#{type}>"
    end

    "#{type} #{param[:name]}"
  end
end
437
+
438
# Returns the C++ return type of the dispatch lambda for the function's list
# of return value types.
#
# Raises RuntimeError when the combination of return types is not known.
def generate_dispatch_retval(function)
  types = function.retvals.map { |r| r[:type] }

  cpp_types = {
    [] => "void",
    ["bool"] => "bool",
    ["int"] => "int64_t",
    ["float"] => "double",
    ["Scalar"] => "Scalar",
    ["ScalarType"] => "ScalarType",
    ["QScheme"] => "QScheme",
    ["Tensor"] => "Tensor",
    ["Tensor[]"] => "std::vector<Tensor>",
    ["Tensor", "Tensor"] => "std::tuple<Tensor,Tensor>",
    ["Tensor", "Tensor", "Tensor"] => "std::tuple<Tensor,Tensor,Tensor>",
    ["Tensor", "Tensor", "Tensor", "Tensor"] => "std::tuple<Tensor,Tensor,Tensor,Tensor>",
    ["Tensor", "Tensor", "Tensor", "Tensor", "Tensor"] => "std::tuple<Tensor,Tensor,Tensor,Tensor,Tensor>",
    # NOTE(review): single "float"/"int" map to double/int64_t above, but
    # this tuple keeps float/int -- confirm against the wrap() overloads
    ["Tensor", "Tensor", "float", "int"] => "std::tuple<Tensor,Tensor,float,int>"
  }

  cpp_types.fetch(types) { raise "Unknown retvals: #{types}" }
end
474
+
475
# Builds the parser signature string "<base_name>(<params>)" for a function.
#
# For out variants, the out params are either dropped (skip_out: true) or,
# when there are several and all are Tensors, merged into a single
# keyword-only "out" tensor list. For type "tensor" the "self" param is left
# out of the positional part. Keyword-only params follow a "*" separator.
def generate_signature(function, type, skip_out: false)
  params = function.params.dup

  if function.out?
    out_index = function.out_index
    out_count = function.retvals.size

    if skip_out
      # remove out
      params.slice!(out_index, out_count)
    elsif out_count > 1 && params[out_index, out_count].all? { |r| r[:type] == "Tensor" }
      # combine tensor into tensorlist
      merged = {name: "out", type: "Tensor[#{out_count}]", list_size: out_count, keyword_only: true}
      params.slice!(out_index, out_count)
      params.insert(out_index, merged)
    end
  end

  keyword_only, positional = params.partition { |v| v[:keyword_only] }
  parts = positional.reject { |v| type == "tensor" && v[:name] == "self" }
  parts += ["*", *keyword_only] if keyword_only.any?

  rendered = parts.map { |v| signature_param(v) }
  "#{function.base_name}(#{rendered.join(", ")})"
end
498
+
499
# Renders one signature element: "*" is passed through, a param Hash becomes
# "<type> <name>" followed by "=<default>" when it has a default.
# "self" is renamed to "input"; the defaults "[]" and "Mean" are rewritten to
# "None" and "at::Reduction::Mean".
def signature_param(param)
  return "*" if param == "*"

  name = param[:name] == "self" ? "input" : param[:name]
  sig = "#{signature_type(param)} #{name}"

  default = param[:default]
  unless default.nil?
    rendered =
      case default
      when "[]" then "None"
      when "Mean" then "at::Reduction::Mean"
      else default
      end
    sig += "=#{rendered}"
  end

  # hack
  sig += "=None" if param[:name] == "out"

  sig
end
522
+
523
# Returns the type part of a parser signature for one param Hash: the mapped
# type name, followed by "[<list_size>]" when :list_size is set and "?" when
# the param is optional.
#
# Fixes over the previous version:
# - the Tensor list pattern is anchored with \A (it used "\T", which only
#   matched a literal "T" and left the start of the string unanchored)
# - Dimname lists map to "DimnameList" (was misspelled "DirnameList")
#
# Raises RuntimeError for an unknown type.
def signature_type(param)
  type =
    case param[:type]
    when "Tensor", /\ATensor\([a-z]!?\)\z/
      "Tensor"
    when /\ATensor\[\d*\]\z/
      "TensorList"
    when /\ADimname\[\d*\]\z/
      "DimnameList"
    when /\Aint\[\d*\]\z/
      "IntArrayRef"
    when "int"
      "int64_t"
    when "float"
      "double"
    when "str"
      "std::string"
    # these schema type names are used unchanged
    when "Scalar", "Dimname", "bool", "ScalarType", "Layout", "Device", "Generator", "MemoryFormat", "Storage"
      param[:type]
    else
      raise "Unknown type: #{param[:type]}"
    end

  type += "[#{param[:list_size]}]" if param[:list_size]
  type += "?" if param[:optional]
  type
end