torch-rb 0.3.4 → 0.4.1

checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
- metadata.gz: fa01dc7cfd168494f1066f4e692eb08dee7f419126b2d10a6eb5b2b22fe01526
- data.tar.gz: 9e69a7da9ecc85c51bd9cfb42b93f1b1aa6148958ed8bc16ffb0555b6d42159d
+ metadata.gz: a40eda909da15ec34573f3fa447519c3cb5553092d6eb35189d02c88154f1669
+ data.tar.gz: cb795ca4c53189534f306c874f90db531c3b14fd17a14d83a592d2e73dd255ff
  SHA512:
- metadata.gz: 36bc620212235e57a40d9e791b3eae85fbd73c7224a5445e2e94e49c7a3ec4cb638024193d8e47817e577a72ef96976fdf2851f3607fd3ea1712255bca0e1ec1
- data.tar.gz: 888971d06717b610020644ed720ae48fa645a90790d54692ac64a565434b550ce4f778688f0a22806f018656dd9d6183d6510fc458c8c5e82c0629951489279c
+ metadata.gz: 750e2103e8ab7b029f7f3acd439088ba41890d1aabab5761d1ea768f556ae7550d2cef91545e83eebb9b09ef9b76d9dac939f9fa80a7c1478dd7d4668a2a6c0e
+ data.tar.gz: d22f3569dc06cc653bfc0d2786eb952ffad858dd2b235f736ebbabbc2f26b73f4c5b92c3776f2004b8076818730de1f1012b7a9ef615a84909c806abf7e96f52
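
These checksums cover the two archives packed inside the `.gem` file (itself a plain tar). A minimal sketch of recomputing them locally to compare against `checksums.yaml`; the filename is an assumption (e.g. from `gem fetch torch-rb -v 0.4.1`):

```ruby
require "digest"
require "rubygems/package"

# Print the SHA256 of the two embedded archives so they can be compared
# against the values in checksums.yaml above.
File.open("torch-rb-0.4.1.gem", "rb") do |io|
  Gem::Package::TarReader.new(io).each do |entry|
    next unless ["metadata.gz", "data.tar.gz"].include?(entry.full_name)
    puts "#{entry.full_name}: #{Digest::SHA256.hexdigest(entry.read)}"
  end
end
```
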
data/CHANGELOG.md CHANGED
@@ -1,3 +1,31 @@
+ ## 0.4.1 (2020-10-12)
+
+ - Fixed installation error with Ruby < 2.7
+
+ ## 0.4.0 (2020-09-27)
+
+ - Improved performance of methods
+ - Improved performance of tensor indexing
+
+ ## 0.3.7 (2020-09-22)
+
+ - Improved performance
+ - Added `Upsample`
+ - Added support for passing tensor class to `type` method
+ - Fixed error with buffers on GPU
+ - Fixed error with `new_full`
+ - Fixed issue with `numo` method and non-contiguous tensors
+
+ ## 0.3.6 (2020-09-17)
+
+ - Added `inplace` option for leaky ReLU
+ - Fixed error with methods that return a tensor list (`chunk`, `split`, and `unbind`)
+ - Fixed error with buffers on GPU
+
+ ## 0.3.5 (2020-09-04)
+
+ - Fixed error with data loader (due to `dtype` of `randperm`)
+
  ## 0.3.4 (2020-08-26)

  - Added `Torch.clamp` method
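
For readers skimming the changelog, here is a rough sketch of how the 0.3.6 and 0.3.7 additions look in use. Option and class names follow the PyTorch API that Torch.rb mirrors, so treat them as illustrative rather than canonical:

```ruby
require "torch"

x = Torch.randn(1, 1, 4, 4)

# 0.3.6: `inplace` option for leaky ReLU
act = Torch::NN::LeakyReLU.new(inplace: true)
act.call(x)

# 0.3.7: `Upsample` (keyword names assumed to mirror torch.nn.Upsample)
up = Torch::NN::Upsample.new(scale_factor: 2, mode: "nearest")
up.call(x)

# 0.3.7: `type` also accepts a tensor class (class name assumed)
x.type(Torch::DoubleTensor)
```
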
data/README.md CHANGED
@@ -402,6 +402,7 @@ Here are a few full examples:
  - [Image classification with MNIST](examples/mnist) ([日本語版](https://qiita.com/kojix2/items/c19c36dc1bf73ea93409))
  - [Collaborative filtering with MovieLens](examples/movielens)
  - [Sequence models and word embeddings](examples/nlp)
+ - [Generative adversarial networks](examples/gan)

  ## LibTorch Installation

@@ -415,7 +416,7 @@ Here’s the list of compatible versions.

  Torch.rb | LibTorch
  --- | ---
- 0.3.0-0.3.4 | 1.6.0
+ 0.3.0+ | 1.6.0
  0.2.0-0.2.7 | 1.5.0-1.5.1
  0.1.8 | 1.4.0
  0.1.0-0.1.7 | 1.3.1
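
The row above widens the LibTorch 1.6.0 pin from `0.3.0-0.3.4` to `0.3.0+`, which covers this 0.4.x release. A sketch of the corresponding setup per the README's LibTorch installation flow; the path is a placeholder:

```ruby
# Gemfile — Torch.rb 0.3.0+ (including 0.4.1) builds against LibTorch 1.6.0
gem "torch-rb", "~> 0.4.1"

# Then point the native extension at an unpacked LibTorch before installing:
#   bundle config build.torch-rb --with-torch-dir=/path/to/libtorch
#   bundle install
```
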
data/codegen/function.rb ADDED
@@ -0,0 +1,134 @@
+ class Function
+   attr_reader :definition, :params, :retvals
+
+   def initialize(definition)
+     @definition = definition
+     @params, @retvals = parse_func
+   end
+
+   def name
+     func.split("(", 2).first
+   end
+
+   def base_name
+     name.split(".").first
+   end
+
+   def func
+     definition["func"]
+   end
+
+   def python_module
+     definition["python_module"]
+   end
+
+   def variants
+     (definition["variants"] || "function").split(", ")
+   end
+
+   def out_index
+     params.index { |v| v[:modifier].to_s.include?("!") } if base_name[-1] != "_" && retvals.any?
+   end
+
+   def out?
+     !out_index.nil?
+   end
+
+   private
+
+   def parse_func
+     input, output = func.split(/\s*->\s*/)
+     [generate_params(input), generate_retvals(output)]
+   end
+
+   def generate_params(input)
+     input = input.split("(", 2).last.chomp(")").split(/\s*,\s+/)
+
+     keyword_only = false
+     params = []
+     input.each do |i|
+       if i == "*"
+         keyword_only = true
+         next
+       end
+
+       type, name = i.split(/\s+/)
+
+       if name.include?("=")
+         name, default = name.split("=", 2)
+       end
+
+       optional = false
+       if type.include?("?")
+         optional = true unless ["dtype", "device", "layout", "pin_memory"].include?(name)
+         type = type.delete("?")
+       end
+
+       type, modifier = extract_modifier(type)
+
+       if type.include?("[")
+         list_size = /\[(.*)\]/.match(type)[1]
+         list_size = nil if list_size.empty?
+       end
+
+       if name == "dtype" && (base_name.start_with?("randperm") || base_name == "tril_indices" || base_name == "triu_indices")
+         # dtype hack
+         # https://github.com/pytorch/pytorch/blob/v1.6.0/tools/autograd/gen_python_functions.py#L1307-L1311
+         default = "torch.int64"
+       end
+
+       params << {
+         name: name,
+         type: type,
+         default: default,
+         keyword_only: keyword_only,
+         optional: optional,
+         modifier: modifier,
+         list_size: list_size
+       }
+     end
+
+     if (params.map { |v| v[:name] } & ["dtype", "device", "layout", "pin_memory"]).size == 4
+       params << {
+         name: "requires_grad",
+         type: "bool",
+         default: "False",
+         keyword_only: true,
+         optional: false,
+         modifier: nil,
+         list_size: nil
+       }
+     end
+
+     params
+   end
+
+   def generate_retvals(output)
+     output =
+       if output == "()"
+         []
+       elsif output[0] == "("
+         output[1..-2].split(/\s*,\s*/)
+       else
+         [output]
+       end
+
+     retvals = []
+     output.each do |o|
+       type, name = o.split(/\s+/)
+       type, modifier = extract_modifier(type)
+       retvals << {name: name, type: type, modifier: modifier}
+     end
+     retvals
+   end
+
+   # Tensor(a), Tensor(a!), Tensor(a)[]
+   def extract_modifier(type)
+     if type.include?("(")
+       parts = type.split(/[\(\)]/, 3)
+       modifier = parts.delete_at(1)
+       type = parts.join("")
+     end
+     [type, modifier]
+   end
+ end
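
`Function` wraps one entry of PyTorch's `native_functions.yaml`. A quick sketch of what it extracts from a schema string; the entry below is hypothetical but follows the upstream format:

```ruby
require_relative "function"

# Hypothetical entry in the native_functions.yaml format
f = Function.new(
  "func" => "add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)",
  "variants" => "function"
)

f.base_name                    # => "add"
f.variants                     # => ["function"]
f.params.map { |p| p[:name] }  # => ["self", "other", "alpha", "out"]
f.params.last[:modifier]       # => "a!" (the `!` marks the out tensor)
f.out?                         # => true
f.out_index                    # => 3
```
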
data/codegen/generate_functions.rb ADDED
@@ -0,0 +1,549 @@
+ require "yaml"
+ # use require_relative for
+ # rake generate:function (without bundle)
+ require_relative "function"
+
+ def generate_functions
+   functions = load_functions
+   functions = skip_functions(functions)
+   functions = group_functions(functions)
+
+   generate_files("torch", :define_singleton_method, functions[:torch])
+   generate_files("tensor", :define_method, functions[:tensor])
+   generate_files("nn", :define_singleton_method, functions[:nn])
+ end
+
+ def load_functions
+   path = File.expand_path("native_functions.yaml", __dir__)
+   YAML.load_file(path).map { |f| Function.new(f) }.sort_by(&:name)
+ end
+
+ def skip_functions(functions)
+   functions.reject do |f|
+     f.base_name.start_with?("_") ||
+       f.base_name.include?("_backward") ||
+       f.base_name.include?("_forward") ||
+       f.base_name == "to" ||
+       # in ext.cpp
+       f.base_name == "index" ||
+       f.base_name == "index_put_" ||
+       # need to add to ext.cpp
+       f.base_name == "index_put" ||
+       # not supported yet
+       f.func.include?("Dimname") ||
+       f.func.include?("ConstQuantizerPtr")
+   end
+ end
+
+ def group_functions(functions)
+   nn_functions, other_functions = functions.partition { |f| f.python_module == "nn" }
+   torch_functions = other_functions.select { |f| f.variants.include?("function") }
+   tensor_functions = other_functions.select { |f| f.variants.include?("method") }
+
+   {torch: torch_functions, tensor: tensor_functions, nn: nn_functions}
+ end
+
+ def generate_files(type, def_method, functions)
+   method_defs = []
+   attach_defs = []
+   functions.group_by(&:base_name).each do |name, grouped_functions|
+     method_defs << generate_method_def(name, grouped_functions, type, def_method)
+     attach_defs << generate_attach_def(name, type, def_method)
+   end
+   write_header(type)
+   write_body(type, method_defs, attach_defs)
+ end
+
+ def write_header(type)
+   template = <<~EOS
+     // generated by rake generate:functions
+     // do not edit by hand
+
+     #pragma once
+
+     void add_%{type}_functions(Module m);
+   EOS
+
+   contents = template % {type: type}
+   write_file("#{type}_functions.h", contents)
+ end
+
+ def write_body(type, method_defs, attach_defs)
+   cuda_lazy_init = %{\n#include "torch/csrc/utils/cuda_lazy_init.h"\n} unless type == "nn"
+
+   template = <<~EOS
+     // generated by rake generate:functions
+     // do not edit by hand
+
+     #include <torch/torch.h>
+     #include <rice/Module.hpp>
+
+     #include "ruby_arg_parser.h"
+     #include "templates.h"
+     #include "wrap_outputs.h"
+     %{cuda_lazy_init}
+     %{method_defs}
+     void add_%{type}_functions(Module m) {
+       %{attach_defs}
+     }
+   EOS
+
+   contents = template % {
+     type: type,
+     method_defs: method_defs.join("\n"),
+     attach_defs: attach_defs.join("\n "),
+     cuda_lazy_init: cuda_lazy_init
+   }
+   write_file("#{type}_functions.cpp", contents)
+ end
+
+ def write_file(name, contents)
+   path = File.expand_path("../ext/torch", __dir__)
+   File.write(File.join(path, name), contents)
+ end
+
+ def generate_attach_def(name, type, def_method)
+   ruby_name =
+     if name.end_with?("_")
+       "#{name[0..-2]}!"
+     elsif name.start_with?("is_")
+       "#{name[3..-1]}?"
+     else
+       name
+     end
+
+   ruby_name = "_#{ruby_name}" if ["size", "stride", "random!"].include?(ruby_name)
+
+   # cast for Ruby < 2.7 https://github.com/thisMagpie/fftw/issues/22#issuecomment-49508900
+   cast = RUBY_VERSION.to_f > 2.7 ? "" : "(VALUE (*)(...)) "
+
+   "rb_#{def_method}(m, \"#{ruby_name}\", #{cast}#{type}_#{name}, -1);"
+ end
+
+ def generate_method_def(name, functions, type, def_method)
+   assign_self = type == "tensor" ? "\n Tensor& self = from_ruby<Tensor&>(self_);" : ""
+
+   functions = group_overloads(functions, type)
+   signatures = functions.map { |f| f["signature"] }
+   max_args = signatures.map { |s| s.count(",") - s.count("*") }.max + 1
+
+   template = <<~EOS
+     // #{name}
+     static VALUE #{type}_#{name}(int argc, VALUE* argv, VALUE self_)
+     {
+       HANDLE_TH_ERRORS#{assign_self}
+       static RubyArgParser parser({
+         #{signatures.map(&:inspect).join(",\n ")}
+       });
+       std::vector<VALUE> parsed_args(#{max_args});
+       auto _r = parser.parse(self_, argc, argv, parsed_args);
+       #{add_dispatches(functions, def_method)}
+       END_HANDLE_TH_ERRORS
+     }
+   EOS
+ end
+
+ def indent(code)
+   code.split("\n").join("\n ")
+ end
+
+ def add_dispatches(functions, def_method)
+   if functions.size == 1
+     add_dispatch(functions.first, def_method)
+   else
+     body = []
+     functions.each_with_index do |f, i|
+       body << "case #{i}: {
+ #{add_dispatch(f, def_method).split("\n").join("\n ")}
+ }"
+     end
+
+     "switch (_r.idx) {
+ #{body.join("\n ")}
+ }
+ RETURN_NIL"
+   end
+ end
+
+ def add_dispatch(function, def_method)
+   if function["out"] && function["out"] != function["base"]
+     base_code = generate_dispatch(function["base"], def_method)
+     out_code = generate_dispatch(function["out"], def_method)
+     out_index = function["out"].out_index
+
+     return "if (_r.isNone(#{out_index})) {
+ #{indent(base_code)}
+ } else {
+ #{indent(out_code)}
+ }"
+   else
+     generate_dispatch(function["base"], def_method)
+   end
+ end
+
+ def group_overloads(functions, type)
+   grouped = Hash.new { |hash, key| hash[key] = {} }
+
+   functions.each do |function|
+     signature = generate_signature(function, type, skip_out: true)
+     v = grouped[signature]
+     if function.out?
+       v["out"] = function
+       v["signature"] = generate_signature(function, type)
+
+       # for now
+       v["base"] ||= function
+     else
+       v["base"] = function
+       v["signature"] ||= signature
+     end
+   end
+
+   puts "Missing base: #{functions.first.name}" if grouped.any? { |_, v| !v["base"] }
+   sort_functions(grouped.values)
+ end
+
+ def sort_functions(functions)
+   # TODO
+   functions.sort_by { |f| f["out"] ? 1 : 0 }
+ end
+
+ def generate_dispatch(function, def_method)
+   cpp_name = function.base_name
+   cpp_name += "_out" if function.out?
+
+   remove_self = def_method == :define_method
+
+   params = function.params.map(&:dup)
+   set_param_position(params, remove_self)
+   params, opt_params = split_opt_params(params)
+   opt_index = opt_params.map { |v| v[:position] }.min if opt_params.any?
+
+   cpp_params = generate_dispatch_params(function, params)
+   if opt_index
+     cpp_params.insert(remove_self ? opt_index + 1 : opt_index, "const TensorOptions & options")
+   end
+
+   retval = generate_dispatch_retval(function)
+   dispatch_code = generate_dispatch_code(function, def_method, params, opt_index, remove_self)
+   function_code = generate_function_code(function, cpp_name, params, opt_index, remove_self)
+
+   out_var = generate_out_var(function.out_index, function.retvals.size) if function.out? && function.retvals.size > 1 && function.retvals.all? { |v| v[:type] == "Tensor" }
+   tensor_options = generate_tensor_options(function, opt_params) if opt_params.any?
+
+   "// #{function.func}#{tensor_options}#{out_var}
+ auto dispatch_#{cpp_name} = [](#{cpp_params.join(", ")}) -> #{retval} {
+ // in future, release GVL
+ #{dispatch_code}
+ };
+ #{function_code}"
+ end
+
+ def generate_out_var(out_index, size)
+   "\n auto out = _r.tensorlist_n<#{size}>(#{out_index});"
+ end
+
+ def set_param_position(params, remove_self)
+   i = 0
+   params.each do |v|
+     next if remove_self && v[:name] == "self"
+     v[:position] = i
+     i += 1
+   end
+ end
+
+ def split_opt_params(params)
+   option_names = ["dtype", "device", "layout", "requires_grad", "pin_memory"]
+
+   opt_params, other_params = params.partition { |v| option_names.include?(v[:name]) }
+   if opt_params.size >= 4
+     [other_params, opt_params]
+   else
+     [params, []]
+   end
+ end
+
+ def generate_tensor_options(function, opt_params)
+   code = "\n const auto options = TensorOptions()"
+   order = ["dtype", "device", "layout", "requires_grad", "pin_memory"]
+   opt_params.sort_by { |v| order.index(v[:name]) }.each do |opt|
+     i = opt[:position]
+
+     c =
+       case opt[:name]
+       when "dtype"
+         if function.base_name == "arange"
+           "dtype(_r.scalartypeOptional(#{i}))"
+         else
+           "dtype(_r.scalartype(#{i}))"
+         end
+       when "device"
+         "device(_r.device(#{i}))"
+       when "layout"
+         "layout(_r.layoutOptional(#{i}))"
+       when "requires_grad"
+         "requires_grad(_r.toBool(#{i}))"
+       when "pin_memory"
+         "pinned_memory(_r.toBool(#{i}))"
+       end
+
+     code += "\n .#{c}"
+   end
+
+   "#{code};\n torch::utils::maybe_initialize_cuda(options);"
+ end
+
+ def generate_function_code(function, cpp_name, params, opt_index, remove_self)
+   params = generate_function_params(function, params, remove_self)
+   if opt_index
+     opt_index += 1 if remove_self
+     params.insert(opt_index, "options")
+   end
+
+   code = "dispatch_#{cpp_name}(#{params.join(", ")})"
+   if function.retvals.empty?
+     "#{code};\nRETURN_NIL"
+   else
+     "return wrap(#{code});"
+   end
+ end
+
+ def generate_function_params(function, params, remove_self)
+   out_var = function.out? && function.retvals.size > 1 && function.retvals.all? { |v| v[:type] == "Tensor" }
+
+   i = 0
+   params.map do |param|
+     i += 1
+
+     next "self" if remove_self && param[:name] == "self"
+     if out_var && i > function.out_index
+       next "out[#{i - function.out_index - 1}]"
+     end
+
+     func =
+       case param[:type]
+       when "Tensor"
+         "tensor"
+       when "Tensor[]"
+         "tensorlist"
+       when /\Aint\[/
+         "intlist"
+       when "Scalar"
+         "scalar"
+       when "bool"
+         "toBool"
+       when "int"
+         "toInt64"
+       when "float"
+         "toDouble"
+       when "ScalarType"
+         "scalartype"
+       when "str"
+         "string"
+       when "Generator"
+         "generator"
+       when "MemoryFormat"
+         "memoryformat"
+       when "Storage"
+         "storage"
+       else
+         raise "Unknown type: #{param[:type]} (#{function.name})"
+       end
+
+     if param[:optional]
+       func =
+         case func
+         when "tensor"
+           if function.out?
+             "tensor"
+           else
+             "optionalTensor"
+           end
+         when "generator", "tensorlist", "intlist"
+           func
+         else
+           "#{func}Optional"
+         end
+     end
+
+     "_r.#{func}(#{param[:position]})"
+   end
+ end
+
+ def generate_dispatch_code(function, def_method, params, opt_index, remove_self)
+   # torch::empty sets requires_grad but at::empty doesn't
+   # https://github.com/pytorch/pytorch/issues/36455
+   prefix = remove_self ? "self." : (opt_index ? "torch::" : "at::")
+   dispatch = function.out? ? "#{function.base_name}_out" : function.base_name
+
+   params = params.map { |v| v[:name] }
+   params.reject! { |v| v == "self" } if remove_self
+   params.insert(opt_index, "options") if opt_index
+
+   if function.out_index
+     params.unshift(params.slice!(function.out_index, function.retvals.size))
+   end
+
+   code = "#{prefix}#{dispatch}(#{params.join(", ")});"
+   code = "return #{code}" unless function.retvals.empty?
+   code
+ end
+
+ def generate_dispatch_params(function, params)
+   params.map do |param|
+     type =
+       case param[:type]
+       when "Tensor"
+         if param[:optional]
+           if function.out?
+             "const Tensor &"
+           else
+             # TODO
+             # "const c10::optional<at::Tensor> &"
+             "const OptionalTensor &"
+           end
+         elsif param[:modifier]
+           if param[:modifier].include?("!") && function.retvals.size > 1
+             "Tensor &"
+           else
+             "Tensor"
+           end
+         else
+           "const Tensor &"
+         end
+       when "Tensor[]"
+         "TensorList"
+       when "int"
+         "int64_t"
+       when "float"
+         "double"
+       when /\Aint\[/
+         "IntArrayRef"
+       when "str"
+         "std::string"
+       when "Scalar", "bool", "ScalarType", "Layout", "Device", "Storage", "Generator", "MemoryFormat"
+         param[:type]
+       else
+         raise "Unknown type: #{param[:type]} (#{function.name})"
+       end
+
+     if param[:optional] && param[:type] != "Tensor"
+       type = "c10::optional<#{type}>"
+     end
+
+     "#{type} #{param[:name]}"
+   end
+ end
+
+ def generate_dispatch_retval(function)
+   types = function.retvals.map { |r| r[:type] }
+
+   case types
+   when []
+     "void"
+   when ["bool"]
+     "bool"
+   when ["int"]
+     "int64_t"
+   when ["float"]
+     "double"
+   when ["Scalar"]
+     "Scalar"
+   when ["ScalarType"]
+     "ScalarType"
+   when ["QScheme"]
+     "QScheme"
+   when ["Tensor"]
+     "Tensor"
+   when ["Tensor[]"]
+     "std::vector<Tensor>"
+   when ["Tensor", "Tensor"]
+     "std::tuple<Tensor,Tensor>"
+   when ["Tensor", "Tensor", "Tensor"]
+     "std::tuple<Tensor,Tensor,Tensor>"
+   when ["Tensor", "Tensor", "Tensor", "Tensor"]
+     "std::tuple<Tensor,Tensor,Tensor,Tensor>"
+   when ["Tensor", "Tensor", "Tensor", "Tensor", "Tensor"]
+     "std::tuple<Tensor,Tensor,Tensor,Tensor,Tensor>"
+   when ["Tensor", "Tensor", "float", "int"]
+     "std::tuple<Tensor,Tensor,float,int>"
+   else
+     raise "Unknown retvals: #{types}"
+   end
+ end
+
+ def generate_signature(function, type, skip_out: false)
+   params = function.params.dup
+   if function.out?
+     if skip_out
+       # remove out
+       params.slice!(function.out_index, function.retvals.size)
+     elsif function.retvals.size > 1 && params[function.out_index, function.retvals.size].all? { |r| r[:type] == "Tensor" }
+       # combine tensor into tensorlist
+       list_size = function.retvals.size
+       params.slice!(function.out_index, list_size)
+       params.insert(function.out_index, {name: "out", type: "Tensor[#{list_size}]", list_size: list_size, keyword_only: true})
+     end
+   end
+
+   parts = params.select { |v| !v[:keyword_only] && !(type == "tensor" && v[:name] == "self") }
+   keyword_only_parts = params.select { |v| v[:keyword_only] }
+   if keyword_only_parts.any?
+     parts << "*"
+     parts.concat(keyword_only_parts)
+   end
+
+   "#{function.base_name}(#{parts.map { |v| signature_param(v) }.join(", ")})"
+ end
+
+ def signature_param(param)
+   return "*" if param == "*"
+
+   name = param[:name]
+   name = "input" if name == "self"
+
+   sig = "#{signature_type(param)} #{name}"
+   case param[:default]
+   when nil
+     # do nothing
+   when "[]"
+     sig += "=None"
+   when "Mean"
+     sig += "=at::Reduction::Mean"
+   else
+     sig += "=#{param[:default]}"
+   end
+
+   # hack
+   sig += "=None" if param[:name] == "out"
+
+   sig
+ end
+
+ def signature_type(param)
+   type =
+     case param[:type]
+     when "Tensor", /\ATensor\([a-z]!?\)\z/
+       "Tensor"
+     when /\ATensor\[\d*\]\z/
+       "TensorList"
+     when /\ADimname\[\d*\]\z/
+ "DirnameList"
+     when /\Aint\[\d*\]\z/
+       "IntArrayRef"
+     when "int"
+       "int64_t"
+     when "float"
+       "double"
+     when "str"
+       "std::string"
+     when "Scalar", "Dimname", "bool", "ScalarType", "Layout", "Device", "Generator", "MemoryFormat", "Storage"
+       param[:type]
+     else
+       raise "Unknown type: #{param[:type]}"
+     end
+
+   type += "[#{param[:list_size]}]" if param[:list_size]
+   type += "?" if param[:optional]
+   type
+ end
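
One detail worth noting: `generate_attach_def` is where C-level names become idiomatic Ruby names; a trailing `_` becomes `!`, an `is_` prefix becomes a `?` predicate, and `size`, `stride`, and `random!` get a leading underscore. A quick sketch of the mapping; results shown for Ruby 3.x, where the function-pointer cast is omitted:

```ruby
require_relative "generate_functions"

generate_attach_def("abs_", "tensor", :define_method)
# => "rb_define_method(m, \"abs!\", tensor_abs_, -1);"

generate_attach_def("is_floating_point", "tensor", :define_method)
# => "rb_define_method(m, \"floating_point?\", tensor_is_floating_point, -1);"

generate_attach_def("size", "tensor", :define_method)
# => "rb_define_method(m, \"_size\", tensor_size, -1);"
```
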