torch-rb 0.3.7 → 0.4.0

checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
- metadata.gz: 8a1852ee3d1ecc7a29c23259b8c328a95030a270b7c11f37f22049177898652e
- data.tar.gz: 56823f1815d3c0c4d5d5c01ef76d781b792b3e4e7c68c0332a149b883a54c7c8
+ metadata.gz: 6294a5809ed3ea9b3d3f9954b351b8c8e5d81c37e9318ae5044d243320166b20
+ data.tar.gz: 57d8036a0449497cbf040c8650c4dcdbe3e0f09191951d77e6d907c6b3c1013d
  SHA512:
- metadata.gz: bed15510cfeaa555d71f1e1f46ed8944893bd349a07c4316dcd63429fe76e13facd8794399ef97fc400d05796579f2e84822b62c98c71dc996e211ad04113ae2
- data.tar.gz: aa05e3645e363eda27274323cdb7fb316342074d1d5afe8f7ee6bfd9819da7883b43d084beb6b29011c631c04fdddc8e6789db41c7c84c53ba9ed152d3338b09
+ metadata.gz: c7aaf862b43f72f5ca8e7f6f6aca079195c11420db081f9c4f18f161610598d975fb81a036689dcd3158d200c33e39e67841bded2e728f2a2fb1b5b8d9e9dfad
+ data.tar.gz: afaf4c344951629248f79bb03246f6879555a58438d07107dea008aa2433d08115d56b1f9464cf20065c276417d3d90a59aecbfefa304208ee81ae313867bad6
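
The checksums above let a consumer verify a downloaded gem archive. A minimal Ruby sketch (illustrative only, not part of the diff; assumes the 0.4.0 data.tar.gz has been fetched to the current directory):

require "digest"

expected = "57d8036a0449497cbf040c8650c4dcdbe3e0f09191951d77e6d907c6b3c1013d"
# recompute the archive's SHA256 and compare with the published value above
Digest::SHA256.file("data.tar.gz").hexdigest == expected # => true for an intact download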
data/CHANGELOG.md CHANGED
@@ -1,3 +1,8 @@
+ ## 0.4.0 (2020-09-27)
+
+ - Improved performance of methods
+ - Improved performance of tensor indexing
+
  ## 0.3.7 (2020-09-22)

  - Improved performance
data/README.md CHANGED
@@ -416,7 +416,7 @@ Here’s the list of compatible versions.

  Torch.rb | LibTorch
  --- | ---
- 0.3.0-0.3.4 | 1.6.0
+ 0.3.0+ | 1.6.0
  0.2.0-0.2.7 | 1.5.0-1.5.1
  0.1.8 | 1.4.0
  0.1.0-0.1.7 | 1.3.1
data/codegen/function.rb ADDED
@@ -0,0 +1,134 @@
+ class Function
+   attr_reader :definition, :params, :retvals
+
+   def initialize(definition)
+     @definition = definition
+     @params, @retvals = parse_func
+   end
+
+   def name
+     func.split("(", 2).first
+   end
+
+   def base_name
+     name.split(".").first
+   end
+
+   def func
+     definition["func"]
+   end
+
+   def python_module
+     definition["python_module"]
+   end
+
+   def variants
+     (definition["variants"] || "function").split(", ")
+   end
+
+   def out_index
+     params.index { |v| v[:modifier].to_s.include?("!") } if base_name[-1] != "_" && retvals.any?
+   end
+
+   def out?
+     !out_index.nil?
+   end
+
+   private
+
+   def parse_func
+     input, output = func.split(/\s*->\s*/)
+     [generate_params(input), generate_retvals(output)]
+   end
+
+   def generate_params(input)
+     input = input.split("(", 2).last.chomp(")").split(/\s*,\s+/)
+
+     keyword_only = false
+     params = []
+     input.each do |i|
+       if i == "*"
+         keyword_only = true
+         next
+       end
+
+       type, name = i.split(/\s+/)
+
+       if name.include?("=")
+         name, default = name.split("=", 2)
+       end
+
+       optional = false
+       if type.include?("?")
+         optional = true unless ["dtype", "device", "layout", "pin_memory"].include?(name)
+         type = type.delete("?")
+       end
+
+       type, modifier = extract_modifier(type)
+
+       if type.include?("[")
+         list_size = /\[(.*)\]/.match(type)[1]
+         list_size = nil if list_size.empty?
+       end
+
+       if name == "dtype" && (base_name.start_with?("randperm") || base_name == "tril_indices" || base_name == "triu_indices")
+         # dtype hack
+         # https://github.com/pytorch/pytorch/blob/v1.6.0/tools/autograd/gen_python_functions.py#L1307-L1311
+         default = "torch.int64"
+       end
+
+       params << {
+         name: name,
+         type: type,
+         default: default,
+         keyword_only: keyword_only,
+         optional: optional,
+         modifier: modifier,
+         list_size: list_size
+       }
+     end
+
+     if (params.map { |v| v[:name] } & ["dtype", "device", "layout", "pin_memory"]).size == 4
+       params << {
+         name: "requires_grad",
+         type: "bool",
+         default: "False",
+         keyword_only: true,
+         optional: false,
+         modifier: nil,
+         list_size: nil
+       }
+     end
+
+     params
+   end
+
+   def generate_retvals(output)
+     output =
+       if output == "()"
+         []
+       elsif output[0] == "("
+         output[1..-2].split(/\s*,\s*/)
+       else
+         [output]
+       end
+
+     retvals = []
+     output.each do |o|
+       type, name = o.split(/\s+/)
+       type, modifier = extract_modifier(type)
+       retvals << {name: name, type: type, modifier: modifier}
+     end
+     retvals
+   end
+
+   # Tensor(a), Tensor(a!), Tensor(a)[]
+   def extract_modifier(type)
+     if type.include?("(")
+       parts = type.split(/[\(\)]/, 3)
+       modifier = parts.delete_at(1)
+       type = parts.join("")
+     end
+     [type, modifier]
+   end
+ end
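
For illustration, here is a minimal sketch (not part of the diff) of how the Function class above parses one schema entry; the add.out line is a typical native_functions.yaml signature used as sample input, and the require path assumes the class is saved next to the script as function.rb:

require "yaml"
require_relative "function"

definition = YAML.safe_load(<<~YAML)
  func: add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
  variants: function
YAML

f = Function.new(definition)
f.name                         # => "add.out"
f.base_name                    # => "add"
f.out?                         # => true (the Tensor(a!) parameter marks an out variant)
f.params.map { |p| p[:name] }  # => ["self", "other", "alpha", "out"]

These parsed params and retvals are what the generator in the next file turns into C++ bindings.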
data/codegen/generate_functions.rb ADDED
@@ -0,0 +1,546 @@
+ require "yaml"
+ # use require_relative for
+ # rake generate:functions (without bundle)
+ require_relative "function"
+
+ def generate_functions
+   functions = load_functions
+   functions = skip_functions(functions)
+   functions = group_functions(functions)
+
+   generate_files("torch", :define_singleton_method, functions[:torch])
+   generate_files("tensor", :define_method, functions[:tensor])
+   generate_files("nn", :define_singleton_method, functions[:nn])
+ end
+
+ def load_functions
+   path = File.expand_path("native_functions.yaml", __dir__)
+   YAML.load_file(path).map { |f| Function.new(f) }.sort_by(&:name)
+ end
+
+ def skip_functions(functions)
+   functions.reject do |f|
+     f.base_name.start_with?("_") ||
+     f.base_name.include?("_backward") ||
+     f.base_name.include?("_forward") ||
+     f.base_name == "to" ||
+     # in ext.cpp
+     f.base_name == "index" ||
+     f.base_name == "index_put_" ||
+     # need to add to ext.cpp
+     f.base_name == "index_put" ||
+     # not supported yet
+     f.func.include?("Dimname") ||
+     f.func.include?("ConstQuantizerPtr")
+   end
+ end
+
+ def group_functions(functions)
+   nn_functions, other_functions = functions.partition { |f| f.python_module == "nn" }
+   torch_functions = other_functions.select { |f| f.variants.include?("function") }
+   tensor_functions = other_functions.select { |f| f.variants.include?("method") }
+
+   {torch: torch_functions, tensor: tensor_functions, nn: nn_functions}
+ end
+
+ def generate_files(type, def_method, functions)
+   method_defs = []
+   attach_defs = []
+   functions.group_by(&:base_name).each do |name, grouped_functions|
+     method_defs << generate_method_def(name, grouped_functions, type, def_method)
+     attach_defs << generate_attach_def(name, type, def_method)
+   end
+   write_header(type)
+   write_body(type, method_defs, attach_defs)
+ end
+
+ def write_header(type)
+   template = <<~EOS
+     // generated by rake generate:functions
+     // do not edit by hand
+
+     #pragma once
+
+     void add_%{type}_functions(Module m);
+   EOS
+
+   contents = template % {type: type}
+   write_file("#{type}_functions.h", contents)
+ end
+
+ def write_body(type, method_defs, attach_defs)
+   cuda_lazy_init = %{\n#include "torch/csrc/utils/cuda_lazy_init.h"\n} unless type == "nn"
+
+   template = <<~EOS
+     // generated by rake generate:functions
+     // do not edit by hand
+
+     #include <torch/torch.h>
+     #include <rice/Module.hpp>
+
+     #include "ruby_arg_parser.h"
+     #include "templates.h"
+     #include "wrap_outputs.h"
+     %{cuda_lazy_init}
+     %{method_defs}
+     void add_%{type}_functions(Module m) {
+       %{attach_defs}
+     }
+   EOS
+
+   contents = template % {
+     type: type,
+     method_defs: method_defs.join("\n"),
+     attach_defs: attach_defs.join("\n "),
+     cuda_lazy_init: cuda_lazy_init
+   }
+   write_file("#{type}_functions.cpp", contents)
+ end
+
+ def write_file(name, contents)
+   path = File.expand_path("../ext/torch", __dir__)
+   File.write(File.join(path, name), contents)
+ end
+
+ def generate_attach_def(name, type, def_method)
+   ruby_name =
+     if name.end_with?("_")
+       "#{name[0..-2]}!"
+     elsif name.start_with?("is_")
+       "#{name[3..-1]}?"
+     else
+       name
+     end
+
+   ruby_name = "_#{ruby_name}" if ["size", "stride", "random!"].include?(ruby_name)
+
+   "rb_#{def_method}(m, \"#{ruby_name}\", #{type}_#{name}, -1);"
+ end
+
+ def generate_method_def(name, functions, type, def_method)
+   assign_self = type == "tensor" ? "\n Tensor& self = from_ruby<Tensor&>(self_);" : ""
+
+   functions = group_overloads(functions, type)
+   signatures = functions.map { |f| f["signature"] }
+   max_args = signatures.map { |s| s.count(",") - s.count("*") }.max + 1
+
+   template = <<~EOS
+     // #{name}
+     static VALUE #{type}_#{name}(int argc, VALUE* argv, VALUE self_)
+     {
+       HANDLE_TH_ERRORS#{assign_self}
+       static RubyArgParser parser({
+         #{signatures.map(&:inspect).join(",\n ")}
+       });
+       std::vector<VALUE> parsed_args(#{max_args});
+       auto _r = parser.parse(self_, argc, argv, parsed_args);
+       #{add_dispatches(functions, def_method)}
+       END_HANDLE_TH_ERRORS
+     }
+   EOS
+ end
+
+ def indent(code)
+   code.split("\n").join("\n ")
+ end
+
+ def add_dispatches(functions, def_method)
+   if functions.size == 1
+     add_dispatch(functions.first, def_method)
+   else
+     body = []
+     functions.each_with_index do |f, i|
+       body << "case #{i}: {
+         #{add_dispatch(f, def_method).split("\n").join("\n ")}
+       }"
+     end
+
+     "switch (_r.idx) {
+       #{body.join("\n ")}
+     }
+     RETURN_NIL"
+   end
+ end
+
+ def add_dispatch(function, def_method)
+   if function["out"] && function["out"] != function["base"]
+     base_code = generate_dispatch(function["base"], def_method)
+     out_code = generate_dispatch(function["out"], def_method)
+     out_index = function["out"].out_index
+
+     return "if (_r.isNone(#{out_index})) {
+       #{indent(base_code)}
+     } else {
+       #{indent(out_code)}
+     }"
+   else
+     generate_dispatch(function["base"], def_method)
+   end
+ end
+
+ def group_overloads(functions, type)
+   grouped = Hash.new { |hash, key| hash[key] = {} }
+
+   functions.each do |function|
+     signature = generate_signature(function, type, skip_out: true)
+     v = grouped[signature]
+     if function.out?
+       v["out"] = function
+       v["signature"] = generate_signature(function, type)
+
+       # for now
+       v["base"] ||= function
+     else
+       v["base"] = function
+       v["signature"] ||= signature
+     end
+   end
+
+   puts "Missing base: #{functions.first.name}" if grouped.any? { |_, v| !v["base"] }
+   sort_functions(grouped.values)
+ end
+
+ def sort_functions(functions)
+   # TODO
+   functions.sort_by { |f| f["out"] ? 1 : 0 }
+ end
+
+ def generate_dispatch(function, def_method)
+   cpp_name = function.base_name
+   cpp_name += "_out" if function.out?
+
+   remove_self = def_method == :define_method
+
+   params = function.params.map(&:dup)
+   set_param_position(params, remove_self)
+   params, opt_params = split_opt_params(params)
+   opt_index = opt_params.map { |v| v[:position] }.min if opt_params.any?
+
+   cpp_params = generate_dispatch_params(function, params)
+   if opt_index
+     cpp_params.insert(remove_self ? opt_index + 1 : opt_index, "const TensorOptions & options")
+   end
+
+   retval = generate_dispatch_retval(function)
+   dispatch_code = generate_dispatch_code(function, def_method, params, opt_index, remove_self)
+   function_code = generate_function_code(function, cpp_name, params, opt_index, remove_self)
+
+   out_var = generate_out_var(function.out_index, function.retvals.size) if function.out? && function.retvals.size > 1 && function.retvals.all? { |v| v[:type] == "Tensor" }
+   tensor_options = generate_tensor_options(function, opt_params) if opt_params.any?
+
+   "// #{function.func}#{tensor_options}#{out_var}
+   auto dispatch_#{cpp_name} = [](#{cpp_params.join(", ")}) -> #{retval} {
+     // in future, release GVL
+     #{dispatch_code}
+   };
+   #{function_code}"
+ end
+
+ def generate_out_var(out_index, size)
+   "\n auto out = _r.tensorlist_n<#{size}>(#{out_index});"
+ end
+
+ def set_param_position(params, remove_self)
+   i = 0
+   params.each do |v|
+     next if remove_self && v[:name] == "self"
+     v[:position] = i
+     i += 1
+   end
+ end
+
+ def split_opt_params(params)
+   option_names = ["dtype", "device", "layout", "requires_grad", "pin_memory"]
+
+   opt_params, other_params = params.partition { |v, i| option_names.include?(v[:name]) }
+   if opt_params.size >= 4
+     [other_params, opt_params]
+   else
+     [params, []]
+   end
+ end
+
+ def generate_tensor_options(function, opt_params)
+   code = "\n const auto options = TensorOptions()"
+   order = ["dtype", "device", "layout", "requires_grad", "pin_memory"]
+   opt_params.sort_by { |v| order.index(v[:name]) }.each do |opt|
+     i = opt[:position]
+
+     c =
+       case opt[:name]
+       when "dtype"
+         if function.base_name == "arange"
+           "dtype(_r.scalartypeOptional(#{i}))"
+         else
+           "dtype(_r.scalartype(#{i}))"
+         end
+       when "device"
+         "device(_r.device(#{i}))"
+       when "layout"
+         "layout(_r.layoutOptional(#{i}))"
+       when "requires_grad"
+         "requires_grad(_r.toBool(#{i}))"
+       when "pin_memory"
+         "pinned_memory(_r.toBool(#{i}))"
+       end
+
+     code += "\n .#{c}"
+   end
+
+   "#{code};\n torch::utils::maybe_initialize_cuda(options);"
+ end
+
+ def generate_function_code(function, cpp_name, params, opt_index, remove_self)
+   params = generate_function_params(function, params, remove_self)
+   if opt_index
+     opt_index += 1 if remove_self
+     params.insert(opt_index, "options")
+   end
+
+   code = "dispatch_#{cpp_name}(#{params.join(", ")})"
+   if function.retvals.empty?
+     "#{code};\nRETURN_NIL"
+   else
+     "return wrap(#{code});"
+   end
+ end
+
+ def generate_function_params(function, params, remove_self)
+   out_var = function.out? && function.retvals.size > 1 && function.retvals.all? { |v| v[:type] == "Tensor" }
+
+   i = 0
+   params.map do |param|
+     i += 1
+
+     next "self" if remove_self && param[:name] == "self"
+     if out_var && i > function.out_index
+       next "out[#{i - function.out_index - 1}]"
+     end
+
+     func =
+       case param[:type]
+       when "Tensor"
+         "tensor"
+       when "Tensor[]"
+         "tensorlist"
+       when /\Aint\[/
+         "intlist"
+       when "Scalar"
+         "scalar"
+       when "bool"
+         "toBool"
+       when "int"
+         "toInt64"
+       when "float"
+         "toDouble"
+       when "ScalarType"
+         "scalartype"
+       when "str"
+         "string"
+       when "Generator"
+         "generator"
+       when "MemoryFormat"
+         "memoryformat"
+       when "Storage"
+         "storage"
+       else
+         raise "Unknown type: #{param[:type]} (#{function.name})"
+       end
+
+     if param[:optional]
+       func =
+         case func
+         when "tensor"
+           if function.out?
+             "tensor"
+           else
+             "optionalTensor"
+           end
+         when "generator", "tensorlist", "intlist"
+           func
+         else
+           "#{func}Optional"
+         end
+     end
+
+     "_r.#{func}(#{param[:position]})"
+   end
+ end
+
+ def generate_dispatch_code(function, def_method, params, opt_index, remove_self)
+   # torch::empty sets requires_grad but at::empty doesn't
+   # https://github.com/pytorch/pytorch/issues/36455
+   prefix = remove_self ? "self." : (opt_index ? "torch::" : "at::")
+   dispatch = function.out? ? "#{function.base_name}_out" : function.base_name
+
+   params = params.map { |v| v[:name] }
+   params.reject! { |v| v == "self" } if remove_self
+   params.insert(opt_index, "options") if opt_index
+
+   if function.out_index
+     params.unshift(params.slice!(function.out_index, function.retvals.size))
+   end
+
+   code = "#{prefix}#{dispatch}(#{params.join(", ")});"
+   code = "return #{code}" unless function.retvals.empty?
+   code
+ end
+
+ def generate_dispatch_params(function, params)
+   params.map do |param|
+     type =
+       case param[:type]
+       when "Tensor"
+         if param[:optional]
+           if function.out?
+             "const Tensor &"
+           else
+             # TODO
+             # "const c10::optional<at::Tensor> &"
+             "const OptionalTensor &"
+           end
+         elsif param[:modifier]
+           if param[:modifier].include?("!") && function.retvals.size > 1
+             "Tensor &"
+           else
+             "Tensor"
+           end
+         else
+           "const Tensor &"
+         end
+       when "Tensor[]"
+         "TensorList"
+       when "int"
+         "int64_t"
+       when "float"
+         "double"
+       when /\Aint\[/
+         "IntArrayRef"
+       when "str"
+         "std::string"
+       when "Scalar", "bool", "ScalarType", "Layout", "Device", "Storage", "Generator", "MemoryFormat"
+         param[:type]
+       else
+         raise "Unknown type: #{param[:type]} (#{function.name})"
+       end
+
+     if param[:optional] && param[:type] != "Tensor"
+       type = "c10::optional<#{type}>"
+     end
+
+     "#{type} #{param[:name]}"
+   end
+ end
+
+ def generate_dispatch_retval(function)
+   types = function.retvals.map { |r| r[:type] }
+
+   case types
+   when []
+     "void"
+   when ["bool"]
+     "bool"
+   when ["int"]
+     "int64_t"
+   when ["float"]
+     "double"
+   when ["Scalar"]
+     "Scalar"
+   when ["ScalarType"]
+     "ScalarType"
+   when ["QScheme"]
+     "QScheme"
+   when ["Tensor"]
+     "Tensor"
+   when ["Tensor[]"]
+     "std::vector<Tensor>"
+   when ["Tensor", "Tensor"]
+     "std::tuple<Tensor,Tensor>"
+   when ["Tensor", "Tensor", "Tensor"]
+     "std::tuple<Tensor,Tensor,Tensor>"
+   when ["Tensor", "Tensor", "Tensor", "Tensor"]
+     "std::tuple<Tensor,Tensor,Tensor,Tensor>"
+   when ["Tensor", "Tensor", "Tensor", "Tensor", "Tensor"]
+     "std::tuple<Tensor,Tensor,Tensor,Tensor,Tensor>"
+   when ["Tensor", "Tensor", "float", "int"]
+     "std::tuple<Tensor,Tensor,float,int>"
+   else
+     raise "Unknown retvals: #{types}"
+   end
+ end
+
+ def generate_signature(function, type, skip_out: false)
+   params = function.params.dup
+   if function.out?
+     if skip_out
+       # remove out
+       params.slice!(function.out_index, function.retvals.size)
+     elsif function.retvals.size > 1 && params[function.out_index, function.retvals.size].all? { |r| r[:type] == "Tensor" }
+       # combine tensor into tensorlist
+       list_size = function.retvals.size
+       params.slice!(function.out_index, list_size)
+       params.insert(function.out_index, {name: "out", type: "Tensor[#{list_size}]", list_size: list_size, keyword_only: true})
+     end
+   end
+
+   parts = params.select { |v| !v[:keyword_only] && !(type == "tensor" && v[:name] == "self") }
+   keyword_only_parts = params.select { |v| v[:keyword_only] }
+   if keyword_only_parts.any?
+     parts << "*"
+     parts.concat(keyword_only_parts)
+   end
+
+   "#{function.base_name}(#{parts.map { |v| signature_param(v) }.join(", ")})"
+ end
+
+ def signature_param(param)
+   return "*" if param == "*"
+
+   name = param[:name]
+   name = "input" if name == "self"
+
+   sig = "#{signature_type(param)} #{name}"
+   case param[:default]
+   when nil
+     # do nothing
+   when "[]"
+     sig += "=None"
+   when "Mean"
+     sig += "=at::Reduction::Mean"
+   else
+     sig += "=#{param[:default]}"
+   end
+
+   # hack
+   sig += "=None" if param[:name] == "out"
+
+   sig
+ end
+
+ def signature_type(param)
+   type =
+     case param[:type]
+     when "Tensor", /\ATensor\([a-z]!?\)\z/
+       "Tensor"
+     when /\ATensor\[\d*\]\z/
+       "TensorList"
+     when /\ADimname\[\d*\]\z/
+       "DimnameList"
+     when /\Aint\[\d*\]\z/
+       "IntArrayRef"
+     when "int"
+       "int64_t"
+     when "float"
+       "double"
+     when "str"
+       "std::string"
+     when "Scalar", "Dimname", "bool", "ScalarType", "Layout", "Device", "Generator", "MemoryFormat", "Storage"
+       param[:type]
+     else
+       raise "Unknown type: #{param[:type]}"
+     end
+
+   type += "[#{param[:list_size]}]" if param[:list_size]
+   type += "?" if param[:optional]
+   type
+ end
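
The generated files are labeled "generated by rake generate:functions". A hypothetical Rakefile wiring for that task (a sketch only; the codegen/ directory name is inferred from the require_relative and "../ext/torch" references above):

# Rakefile (sketch)
namespace :generate do
  # regenerates ext/torch/{torch,tensor,nn}_functions.{h,cpp}
  task :functions do
    require_relative "codegen/generate_functions"
    generate_functions
  end
end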