torch-rb 0.3.7 → 0.5.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz: 8a1852ee3d1ecc7a29c23259b8c328a95030a270b7c11f37f22049177898652e
-  data.tar.gz: 56823f1815d3c0c4d5d5c01ef76d781b792b3e4e7c68c0332a149b883a54c7c8
+  metadata.gz: 68b4f1b0599e97803bd0e07efdd7fff96e0e04369005dbd72db190df3cf4e1b4
+  data.tar.gz: 302f75c8b43b25ac06e49c7d8047c755a58d38d86e67c6b8df4f00a8579a2676
 SHA512:
-  metadata.gz: bed15510cfeaa555d71f1e1f46ed8944893bd349a07c4316dcd63429fe76e13facd8794399ef97fc400d05796579f2e84822b62c98c71dc996e211ad04113ae2
-  data.tar.gz: aa05e3645e363eda27274323cdb7fb316342074d1d5afe8f7ee6bfd9819da7883b43d084beb6b29011c631c04fdddc8e6789db41c7c84c53ba9ed152d3338b09
+  metadata.gz: d18a01b8d18f659fb2da9e4cf6e700de260c04f7400ce62e5850d79f03f5e58e5b700730c7bdddab0148e60be080b5b85bb48e53527653372527461da46863d7
+  data.tar.gz: 8aaa9e2ee75e2a1a64ae506a59eca35e150bf48568dc4c905fb30df6a61d56ab745383f31525fd3c9e6a259cf33adf154ba8d70476b75fd124b035dbf9b48f4f
data/CHANGELOG.md CHANGED
@@ -1,3 +1,26 @@
+## 0.5.1 (2020-10-28)
+
+- Fixed error with tensor classes and no arguments
+- Fixed error with `stft` and `clamp` methods
+
+## 0.5.0 (2020-10-28)
+
+- Updated LibTorch to 1.7.0
+- Removed deprecated overload for `addcmul!` and `addcdiv!`
+
+## 0.4.2 (2020-10-27)
+
+- Fixed errors with optimizer options
+
+## 0.4.1 (2020-10-12)
+
+- Fixed installation error with Ruby < 2.7
+
+## 0.4.0 (2020-09-27)
+
+- Improved performance of methods
+- Improved performance of tensor indexing
+
 ## 0.3.7 (2020-09-22)
 
 - Improved performance
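The 0.5.0 entry above tracks a LibTorch 1.7.0 change: the deprecated `addcmul!`/`addcdiv!` overload that took the scalar as the first argument is gone, and the scalar now goes in the `value:` keyword. A minimal before/after sketch (the tensor values are illustrative, not from the diff):

t = Torch.ones(3)
t1 = Torch.full([3], 2.0)
t2 = Torch.full([3], 3.0)

# removed overload (scalar first), no longer accepted as of 0.5.0:
# t.addcmul!(0.5, t1, t2)

# current form: tensors first, scalar as a keyword argument
t.addcmul!(t1, t2, value: 0.5)  # each element: 1 + 0.5 * (2 * 3) = 4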
data/README.md CHANGED
@@ -416,7 +416,8 @@ Here’s the list of compatible versions.
 
 Torch.rb | LibTorch
 --- | ---
-0.3.0-0.3.4 | 1.6.0
+0.5.0+ | 1.7.0
+0.3.0+ | 1.6.0
 0.2.0-0.2.7 | 1.5.0-1.5.1
 0.1.8 | 1.4.0
 0.1.0-0.1.7 | 1.3.1
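Since each Torch.rb series targets a single LibTorch release, it helps to pin the gem to the row matching the installed LibTorch. A hypothetical Gemfile pin for a machine with LibTorch 1.7.0:

# Gemfile (hypothetical pin; adjust to your LibTorch version per the table above)
gem "torch-rb", "~> 0.5.0"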
data/codegen/function.rb ADDED
@@ -0,0 +1,134 @@
+class Function
+  attr_reader :definition, :params, :retvals
+
+  def initialize(definition)
+    @definition = definition
+    @params, @retvals = parse_func
+  end
+
+  def name
+    func.split("(", 2).first
+  end
+
+  def base_name
+    name.split(".").first
+  end
+
+  def func
+    definition["func"]
+  end
+
+  def python_module
+    definition["python_module"]
+  end
+
+  def variants
+    (definition["variants"] || "function").split(", ")
+  end
+
+  def out_index
+    params.index { |v| v[:modifier].to_s.include?("!") } if base_name[-1] != "_" && retvals.any?
+  end
+
+  def out?
+    !out_index.nil?
+  end
+
+  private
+
+  def parse_func
+    input, output = func.split(/\s*->\s*/)
+    [generate_params(input), generate_retvals(output)]
+  end
+
+  def generate_params(input)
+    input = input.split("(", 2).last.chomp(")").split(/\s*,\s+/)
+
+    keyword_only = false
+    params = []
+    input.each do |i|
+      if i == "*"
+        keyword_only = true
+        next
+      end
+
+      type, name = i.split(/\s+/)
+
+      if name.include?("=")
+        name, default = name.split("=", 2)
+      end
+
+      optional = false
+      if type.include?("?")
+        optional = true unless ["dtype", "device", "layout", "pin_memory"].include?(name)
+        type = type.delete("?")
+      end
+
+      type, modifier = extract_modifier(type)
+
+      if type.include?("[")
+        list_size = /\[(.*)\]/.match(type)[1]
+        list_size = nil if list_size.empty?
+      end
+
+      if name == "dtype" && (base_name.start_with?("randperm") || base_name == "tril_indices" || base_name == "triu_indices")
+        # dtype hack
+        # https://github.com/pytorch/pytorch/blob/v1.6.0/tools/autograd/gen_python_functions.py#L1307-L1311
+        default = "torch.int64"
+      end
+
+      params << {
+        name: name,
+        type: type,
+        default: default,
+        keyword_only: keyword_only,
+        optional: optional,
+        modifier: modifier,
+        list_size: list_size
+      }
+    end
+
+    if (params.map { |v| v[:name] } & ["dtype", "device", "layout", "pin_memory"]).size == 4
+      params << {
+        name: "requires_grad",
+        type: "bool",
+        default: "False",
+        keyword_only: true,
+        optional: false,
+        modifier: nil,
+        list_size: nil
+      }
+    end
+
+    params
+  end
+
+  def generate_retvals(output)
+    output =
+      if output == "()"
+        []
+      elsif output[0] == "("
+        output[1..-2].split(/\s*,\s*/)
+      else
+        [output]
+      end
+
+    retvals = []
+    output.each do |o|
+      type, name = o.split(/\s+/)
+      type, modifier = extract_modifier(type)
+      retvals << {name: name, type: type, modifier: modifier}
+    end
+    retvals
+  end
+
+  # Tensor(a), Tensor(a!), Tensor(a)[]
+  def extract_modifier(type)
+    if type.include?("(")
+      parts = type.split(/[\(\)]/, 3)
+      modifier = parts.delete_at(1)
+      type = parts.join("")
+    end
+    [type, modifier]
+  end
+end
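To make the parsing concrete, here is a sketch of how Function consumes one schema entry in the native_functions.yaml format. The definition hash is hand-written for illustration, not copied from the shipped YAML:

defn = {
  "func" => "add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor",
  "variants" => "function, method"
}

f = Function.new(defn)
f.name                          # => "add.Tensor"
f.base_name                     # => "add"
f.variants                      # => ["function", "method"]
f.out?                          # => false (no Tensor(a!) out parameter)
f.params.map { |p| p[:name] }   # => ["self", "other", "alpha"]
f.params.last[:keyword_only]    # => true (alpha follows the "*" marker)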
data/codegen/generate_functions.rb ADDED
@@ -0,0 +1,557 @@
+require "yaml"
+# use require_relative for
+# rake generate:function (without bundle)
+require_relative "function"
+
+def generate_functions
+  functions = load_functions
+  functions = skip_functions(functions)
+  functions = group_functions(functions)
+
+  generate_files("torch", :define_singleton_method, functions[:torch])
+  generate_files("tensor", :define_method, functions[:tensor])
+  generate_files("nn", :define_singleton_method, functions[:nn])
+end
+
+def load_functions
+  path = File.expand_path("native_functions.yaml", __dir__)
+  YAML.load_file(path).map { |f| Function.new(f) }.sort_by(&:name)
+end
+
+def skip_functions(functions)
+  functions.reject do |f|
+    f.base_name.start_with?("_") ||
+    f.base_name.include?("_backward") ||
+    f.base_name.include?("_forward") ||
+    f.base_name == "to" ||
+    # in ext.cpp
+    f.base_name == "index" ||
+    f.base_name == "index_put_" ||
+    # need to add to ext.cpp
+    f.base_name == "index_put" ||
+    # not supported yet
+    f.func.include?("Dimname") ||
+    f.func.include?("ConstQuantizerPtr")
+  end
+end
+
+def group_functions(functions)
+  nn_functions, other_functions = functions.partition { |f| f.python_module == "nn" }
+  torch_functions = other_functions.select { |f| f.variants.include?("function") }
+  tensor_functions = other_functions.select { |f| f.variants.include?("method") }
+
+  {torch: torch_functions, tensor: tensor_functions, nn: nn_functions}
+end
+
+def generate_files(type, def_method, functions)
+  method_defs = []
+  attach_defs = []
+  functions.group_by(&:base_name).each do |name, grouped_functions|
+    method_defs << generate_method_def(name, grouped_functions, type, def_method)
+    attach_defs << generate_attach_def(name, type, def_method)
+  end
+  write_header(type)
+  write_body(type, method_defs, attach_defs)
+end
+
+def write_header(type)
+  template = <<~EOS
+    // generated by rake generate:functions
+    // do not edit by hand
+
+    #pragma once
+
+    void add_%{type}_functions(Module m);
+  EOS
+
+  contents = template % {type: type}
+  write_file("#{type}_functions.h", contents)
+end
+
+def write_body(type, method_defs, attach_defs)
+  cuda_lazy_init = %{\n#include "torch/csrc/utils/cuda_lazy_init.h"\n} unless type == "nn"
+
+  template = <<~EOS
+    // generated by rake generate:functions
+    // do not edit by hand
+
+    #include <torch/torch.h>
+    #include <rice/Module.hpp>
+
+    #include "ruby_arg_parser.h"
+    #include "templates.h"
+    #include "wrap_outputs.h"
+    %{cuda_lazy_init}
+    %{method_defs}
+    void add_%{type}_functions(Module m) {
+      %{attach_defs}
+    }
+  EOS
+
+  contents = template % {
+    type: type,
+    method_defs: method_defs.join("\n"),
+    attach_defs: attach_defs.join("\n "),
+    cuda_lazy_init: cuda_lazy_init
+  }
+  write_file("#{type}_functions.cpp", contents)
+end
+
+def write_file(name, contents)
+  path = File.expand_path("../ext/torch", __dir__)
+  File.write(File.join(path, name), contents)
+end
+
+def generate_attach_def(name, type, def_method)
+  ruby_name =
+    if name.end_with?("_")
+      "#{name[0..-2]}!"
+    elsif name.start_with?("is_")
+      "#{name[3..-1]}?"
+    else
+      name
+    end
+
+  ruby_name = "_#{ruby_name}" if ["size", "stride", "random!", "stft"].include?(ruby_name)
+
+  # cast for Ruby < 2.7 https://github.com/thisMagpie/fftw/issues/22#issuecomment-49508900
+  cast = RUBY_VERSION.to_f > 2.7 ? "" : "(VALUE (*)(...)) "
+
+  "rb_#{def_method}(m, \"#{ruby_name}\", #{cast}#{type}_#{name}, -1);"
+end
+
+def generate_method_def(name, functions, type, def_method)
+  assign_self = type == "tensor" ? "\n Tensor& self = from_ruby<Tensor&>(self_);" : ""
+
+  functions = group_overloads(functions, type)
+  signatures = functions.map { |f| f["signature"] }
+  max_args = signatures.map { |s| s.count(",") - s.count("*") }.max + 1
+
+  template = <<~EOS
+    // #{name}
+    static VALUE #{type}_#{name}(int argc, VALUE* argv, VALUE self_)
+    {
+      HANDLE_TH_ERRORS#{assign_self}
+      static RubyArgParser parser({
+        #{signatures.map(&:inspect).join(",\n ")}
+      });
+      std::vector<VALUE> parsed_args(#{max_args});
+      auto _r = parser.parse(self_, argc, argv, parsed_args);
+      #{add_dispatches(functions, def_method)}
+      END_HANDLE_TH_ERRORS
+    }
+  EOS
+end
+
+def indent(code)
+  code.split("\n").join("\n ")
+end
+
+def add_dispatches(functions, def_method)
+  if functions.size == 1
+    add_dispatch(functions.first, def_method)
+  else
+    body = []
+    functions.each_with_index do |f, i|
+      body << "case #{i}: {
+  #{add_dispatch(f, def_method).split("\n").join("\n ")}
+}"
+    end
+
+    "switch (_r.idx) {
+  #{body.join("\n ")}
+}
+RETURN_NIL"
+  end
+end
+
+def add_dispatch(function, def_method)
+  if function["out"] && function["out"] != function["base"]
+    base_code = generate_dispatch(function["base"], def_method)
+    out_code = generate_dispatch(function["out"], def_method)
+    out_index = function["out"].out_index
+
+    return "if (_r.isNone(#{out_index})) {
+  #{indent(base_code)}
+} else {
+  #{indent(out_code)}
+}"
+  else
+    generate_dispatch(function["base"], def_method)
+  end
+end
+
+def group_overloads(functions, type)
+  grouped = Hash.new { |hash, key| hash[key] = {} }
+
+  functions.each do |function|
+    signature = generate_signature(function, type, skip_out: true)
+    v = grouped[signature]
+    if function.out?
+      v["out"] = function
+      v["signature"] = generate_signature(function, type)
+
+      # for now
+      v["base"] ||= function
+    else
+      v["base"] = function
+      v["signature"] ||= signature
+    end
+  end
+
+  puts "Missing base: #{functions.first.name}" if grouped.any? { |_, v| !v["base"] }
+  sort_functions(grouped.values)
+end
+
+def sort_functions(functions)
+  # TODO
+  functions.sort_by { |f| f["out"] ? 1 : 0 }
+end
+
+def generate_dispatch(function, def_method)
+  cpp_name = function.base_name
+  cpp_name += "_out" if function.out?
+
+  remove_self = def_method == :define_method
+
+  params = function.params.map(&:dup)
+  set_param_position(params, remove_self)
+  params, opt_params = split_opt_params(params)
+  opt_index = opt_params.map { |v| v[:position] }.min if opt_params.any?
+
+  cpp_params = generate_dispatch_params(function, params)
+  if opt_index
+    cpp_params.insert(remove_self ? opt_index + 1 : opt_index, "const TensorOptions & options")
+  end
+
+  retval = generate_dispatch_retval(function)
+  dispatch_code = generate_dispatch_code(function, def_method, params, opt_index, remove_self)
+  function_code = generate_function_code(function, cpp_name, params, opt_index, remove_self)
+
+  out_var = generate_out_var(function.out_index, function.retvals.size) if function.out? && function.retvals.size > 1 && function.retvals.all? { |v| v[:type] == "Tensor" }
+  tensor_options = generate_tensor_options(function, opt_params) if opt_params.any?
+
+  "// #{function.func}#{tensor_options}#{out_var}
+auto dispatch_#{cpp_name} = [](#{cpp_params.join(", ")}) -> #{retval} {
+  // in future, release GVL
+  #{dispatch_code}
+};
+#{function_code}"
+end
+
+def generate_out_var(out_index, size)
+  "\n auto out = _r.tensorlist_n<#{size}>(#{out_index});"
+end
+
+def set_param_position(params, remove_self)
+  i = 0
+  params.each do |v|
+    next if remove_self && v[:name] == "self"
+    v[:position] = i
+    i += 1
+  end
+end
+
+def split_opt_params(params)
+  option_names = ["dtype", "device", "layout", "requires_grad", "pin_memory"]
+
+  opt_params, other_params = params.partition { |v, i| option_names.include?(v[:name]) }
+  if opt_params.size >= 4
+    [other_params, opt_params]
+  else
+    [params, []]
+  end
+end
+
+def generate_tensor_options(function, opt_params)
+  code = "\n const auto options = TensorOptions()"
+  order = ["dtype", "device", "layout", "requires_grad", "pin_memory"]
+  opt_params.sort_by { |v| order.index(v[:name]) }.each do |opt|
+    i = opt[:position]
+
+    c =
+      case opt[:name]
+      when "dtype"
+        if function.base_name == "arange"
+          "dtype(_r.scalartypeOptional(#{i}))"
+        else
+          "dtype(_r.scalartype(#{i}))"
+        end
+      when "device"
+        "device(_r.device(#{i}))"
+      when "layout"
+        "layout(_r.layoutOptional(#{i}))"
+      when "requires_grad"
+        "requires_grad(_r.toBool(#{i}))"
+      when "pin_memory"
+        "pinned_memory(_r.toBool(#{i}))"
+      end
+
+    code += "\n .#{c}"
+  end
+
+  "#{code};\n torch::utils::maybe_initialize_cuda(options);"
+end
+
+def generate_function_code(function, cpp_name, params, opt_index, remove_self)
+  params = generate_function_params(function, params, remove_self)
+  if opt_index
+    opt_index += 1 if remove_self
+    params.insert(opt_index, "options")
+  end
+
+  code = "dispatch_#{cpp_name}(#{params.join(", ")})"
+  if function.retvals.empty?
+    "#{code};\nRETURN_NIL"
+  else
+    "return wrap(#{code});"
+  end
+end
+
+def generate_function_params(function, params, remove_self)
+  out_var = function.out? && function.retvals.size > 1 && function.retvals.all? { |v| v[:type] == "Tensor" }
+
+  i = 0
+  params.map do |param|
+    i += 1
+
+    next "self" if remove_self && param[:name] == "self"
+    if out_var && i > function.out_index
+      next "out[#{i - function.out_index - 1}]"
+    end
+
+    func =
+      case param[:type]
+      when "Tensor"
+        "tensor"
+      when "Tensor[]"
+        "tensorlist"
+      when /\Aint\[/
+        "intlist"
+      when "float[]"
+        "doublelist"
+      when "Scalar"
+        "scalar"
+      when "bool"
+        "toBool"
+      when "int"
+        "toInt64"
+      when "float"
+        "toDouble"
+      when "ScalarType"
+        "scalartype"
+      when "str"
+        "string"
+      when "Generator"
+        "generator"
+      when "MemoryFormat"
+        "memoryformat"
+      when "Storage"
+        "storage"
+      else
+        raise "Unknown type: #{param[:type]} (#{function.name})"
+      end
+
+    if param[:optional]
+      func =
+        case func
+        when "tensor"
+          if function.out?
+            "tensor"
+          else
+            "optionalTensor"
+          end
+        when "generator", "tensorlist", "intlist"
+          func
+        else
+          "#{func}Optional"
+        end
+    end
+
+    "_r.#{func}(#{param[:position]})"
+  end
+end
+
+def generate_dispatch_code(function, def_method, params, opt_index, remove_self)
+  # torch::empty sets requires_grad but at::empty doesn't
+  # https://github.com/pytorch/pytorch/issues/36455
+  prefix = remove_self ? "self." : (opt_index ? "torch::" : "at::")
+  dispatch = function.out? ? "#{function.base_name}_out" : function.base_name
+
+  params = params.map { |v| v[:name] }
+  params.reject! { |v| v == "self" } if remove_self
+  params.insert(opt_index, "options") if opt_index
+
+  if function.out_index
+    params.unshift(params.slice!(function.out_index, function.retvals.size))
+  end
+
+  code = "#{prefix}#{dispatch}(#{params.join(", ")});"
+  code = "return #{code}" unless function.retvals.empty?
+  code
+end
+
+def generate_dispatch_params(function, params)
+  params.map do |param|
+    type =
+      case param[:type]
+      when "Tensor"
+        if param[:optional]
+          if function.out?
+            "const Tensor &"
+          else
+            # TODO
+            # "const c10::optional<at::Tensor> &"
+            "const OptionalTensor &"
+          end
+        elsif param[:modifier]
+          if param[:modifier].include?("!") && function.retvals.size > 1
+            "Tensor &"
+          else
+            "Tensor"
+          end
+        else
+          "const Tensor &"
+        end
+      when "Tensor[]"
+        "TensorList"
+      when "int"
+        "int64_t"
+      when "float"
+        "double"
+      when /\Aint\[/
+        "IntArrayRef"
+      when "float[]"
+        "ArrayRef<double>"
+      when "str"
+        "std::string"
+      when "Scalar", "bool", "ScalarType", "Layout", "Device", "Storage", "Generator", "MemoryFormat"
+        param[:type]
+      else
+        raise "Unknown type: #{param[:type]} (#{function.name})"
+      end
+
+    if param[:optional] && param[:type] != "Tensor"
+      type = "c10::optional<#{type}>"
+    end
+
+    "#{type} #{param[:name]}"
+  end
+end
+
+def generate_dispatch_retval(function)
+  types = function.retvals.map { |r| r[:type] }
+
+  case types
+  when []
+    "void"
+  when ["bool"]
+    "bool"
+  when ["int"]
+    "int64_t"
+  when ["float"]
+    "double"
+  when ["Scalar"]
+    "Scalar"
+  when ["ScalarType"]
+    "ScalarType"
+  when ["QScheme"]
+    "QScheme"
+  when ["Tensor"]
+    "Tensor"
+  when ["Tensor[]"]
+    "std::vector<Tensor>"
+  when ["Tensor", "Tensor"]
+    "std::tuple<Tensor,Tensor>"
+  when ["Tensor", "Tensor", "Tensor"]
+    "std::tuple<Tensor,Tensor,Tensor>"
+  when ["Tensor", "Tensor", "Tensor", "Tensor"]
+    "std::tuple<Tensor,Tensor,Tensor,Tensor>"
+  when ["Tensor", "Tensor", "Tensor", "Tensor", "Tensor"]
+    "std::tuple<Tensor,Tensor,Tensor,Tensor,Tensor>"
+  when ["Tensor", "Tensor", "float", "int"]
+    "std::tuple<Tensor,Tensor,double,int>"
+  when ["float", "float"]
+    "std::tuple<double,double>"
+  else
+    raise "Unknown retvals: #{types}"
+  end
+end
+
+def generate_signature(function, type, skip_out: false)
+  params = function.params.dup
+  if function.out?
+    if skip_out
+      # remove out
+      params.slice!(function.out_index, function.retvals.size)
+    elsif function.retvals.size > 1 && params[function.out_index, function.retvals.size].all? { |r| r[:type] == "Tensor" }
+      # combine tensor into tensorlist
+      list_size = function.retvals.size
+      params.slice!(function.out_index, list_size)
+      params.insert(function.out_index, {name: "out", type: "Tensor[#{list_size}]", list_size: list_size, keyword_only: true})
+    end
+  end
+
+  parts = params.select { |v| !v[:keyword_only] && !(type == "tensor" && v[:name] == "self") }
+  keyword_only_parts = params.select { |v| v[:keyword_only] }
+  if keyword_only_parts.any?
+    parts << "*"
+    parts.concat(keyword_only_parts)
+  end
+
+  "#{function.base_name}(#{parts.map { |v| signature_param(v) }.join(", ")})"
+end
+
+def signature_param(param)
+  return "*" if param == "*"
+
+  name = param[:name]
+  name = "input" if name == "self"
+
+  sig = "#{signature_type(param)} #{name}"
+  case param[:default]
+  when nil
+    # do nothing
+  when "[]"
+    sig += "=None"
+  when "Mean"
+    sig += "=at::Reduction::Mean"
+  else
+    sig += "=#{param[:default]}"
+  end
+
+  # hack
+  sig += "=None" if param[:name] == "out"
+
+  sig
+end
+
+def signature_type(param)
+  type =
+    case param[:type]
+    when "Tensor", /\ATensor\([a-z]!?\)\z/
+      "Tensor"
+    when /\ATensor\[\d*\]\z/
+      "TensorList"
+    when /\ADimname\[\d*\]\z/
+      "DimnameList"
+    when /\Aint\[\d*\]\z/
+      "IntArrayRef"
+    when "int"
+      "int64_t"
+    when "float"
+      "double"
+    when "str"
+      "std::string"
+    when "Scalar", "Dimname", "bool", "ScalarType", "Layout", "Device", "Generator", "MemoryFormat", "Storage"
+      param[:type]
+    when "float[]"
+      "ArrayRef<double>"
+    else
+      raise "Unknown type: #{param[:type]}"
+    end
+
+  type += "[#{param[:list_size]}]" if param[:list_size]
+  type += "?" if param[:optional]
+  type
+end
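generate_functions is the entry point: it reads native_functions.yaml, drops unsupported entries, groups the rest by module, and regenerates torch_functions.cpp, tensor_functions.cpp, and nn_functions.cpp under ext/torch. A plausible Rake wiring, going by the "rake generate:function" comment at the top of the file (the actual Rakefile is not part of this diff):

# Rakefile sketch (hypothetical; task name inferred from the comments above)
namespace :generate do
  task :functions do
    require_relative "codegen/generate_functions"
    generate_functions
  end
end

Each attach def it emits is a single line of C. For example, on Ruby 3.0 (where no function-pointer cast is prepended), generate_attach_def("add_", "tensor", :define_method) produces:

rb_define_method(m, "add!", tensor_add_, -1);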