torch-rb 0.3.7 → 0.4.0

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 8a1852ee3d1ecc7a29c23259b8c328a95030a270b7c11f37f22049177898652e
4
- data.tar.gz: 56823f1815d3c0c4d5d5c01ef76d781b792b3e4e7c68c0332a149b883a54c7c8
3
+ metadata.gz: 6294a5809ed3ea9b3d3f9954b351b8c8e5d81c37e9318ae5044d243320166b20
4
+ data.tar.gz: 57d8036a0449497cbf040c8650c4dcdbe3e0f09191951d77e6d907c6b3c1013d
5
5
  SHA512:
6
- metadata.gz: bed15510cfeaa555d71f1e1f46ed8944893bd349a07c4316dcd63429fe76e13facd8794399ef97fc400d05796579f2e84822b62c98c71dc996e211ad04113ae2
7
- data.tar.gz: aa05e3645e363eda27274323cdb7fb316342074d1d5afe8f7ee6bfd9819da7883b43d084beb6b29011c631c04fdddc8e6789db41c7c84c53ba9ed152d3338b09
6
+ metadata.gz: c7aaf862b43f72f5ca8e7f6f6aca079195c11420db081f9c4f18f161610598d975fb81a036689dcd3158d200c33e39e67841bded2e728f2a2fb1b5b8d9e9dfad
7
+ data.tar.gz: afaf4c344951629248f79bb03246f6879555a58438d07107dea008aa2433d08115d56b1f9464cf20065c276417d3d90a59aecbfefa304208ee81ae313867bad6
@@ -1,3 +1,8 @@
1
+ ## 0.4.0 (2020-09-27)
2
+
3
+ - Improved performance of methods
4
+ - Improved performance of tensor indexing
5
+
1
6
  ## 0.3.7 (2020-09-22)
2
7
 
3
8
  - Improved performance
data/README.md CHANGED
@@ -416,7 +416,7 @@ Here’s the list of compatible versions.
416
416
 
417
417
  Torch.rb | LibTorch
418
418
  --- | ---
419
- 0.3.0-0.3.4 | 1.6.0
419
+ 0.3.0+ | 1.6.0
420
420
  0.2.0-0.2.7 | 1.5.0-1.5.1
421
421
  0.1.8 | 1.4.0
422
422
  0.1.0-0.1.7 | 1.3.1
@@ -0,0 +1,134 @@
1
# Parsed form of one entry from native_functions.yaml.
#
# `definition` is the raw YAML hash. Its "func" value is a schema string such as
#   "add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)"
# which is split into parameter hashes (`params`) and return-value hashes (`retvals`).
class Function
  # Parameter names that together form a TensorOptions bundle. They are never
  # marked optional, and when all of them are present an extra keyword-only
  # `requires_grad` parameter is appended.
  TENSOR_OPTION_NAMES = %w[dtype device layout pin_memory].freeze

  attr_reader :definition, :params, :retvals

  # @param definition [Hash] one YAML entry; must contain a "func" key
  def initialize(definition)
    @definition = definition
    @params, @retvals = parse_func
  end

  # Full overload name, e.g. "add.out".
  def name
    func.split("(", 2).first
  end

  # Overload name without the ".suffix", e.g. "add" for "add.out".
  def base_name
    name.split(".").first
  end

  def func
    definition["func"]
  end

  def python_module
    definition["python_module"]
  end

  # Variants listed in the YAML entry; defaults to ["function"].
  def variants
    (definition["variants"] || "function").split(", ")
  end

  # Index of the first parameter whose modifier contains "!" (a mutated
  # argument). Returns nil for names ending in "_" and for functions without
  # return values.
  def out_index
    params.index { |v| v[:modifier].to_s.include?("!") } if base_name[-1] != "_" && retvals.any?
  end

  def out?
    !out_index.nil?
  end

  private

  def parse_func
    input, output = func.split(/\s*->\s*/)
    [generate_params(input), generate_retvals(output)]
  end

  # Turns "name(Type a, Type? b=default, *, ...)" into an array of param hashes.
  def generate_params(input)
    input = input.split("(", 2).last.chomp(")").split(/\s*,\s+/)

    keyword_only = false
    params = []
    input.each do |i|
      # a bare "*" marks every following parameter as keyword-only
      if i == "*"
        keyword_only = true
        next
      end

      type, name = i.split(/\s+/)

      if name.include?("=")
        name, default = name.split("=", 2)
      end

      # "?" marks an optional type, except for the tensor-option params
      optional = false
      if type.include?("?")
        optional = true unless TENSOR_OPTION_NAMES.include?(name)
        type = type.delete("?")
      end

      type, modifier = extract_modifier(type)

      # "int[2]" => "2"; "int[]" => nil
      if type.include?("[")
        list_size = /\[(.*)\]/.match(type)[1]
        list_size = nil if list_size.empty?
      end

      if name == "dtype" && (base_name.start_with?("randperm") || base_name == "tril_indices" || base_name == "triu_indices")
        # dtype hack
        # https://github.com/pytorch/pytorch/blob/v1.6.0/tools/autograd/gen_python_functions.py#L1307-L1311
        default = "torch.int64"
      end

      params << {
        name: name,
        type: type,
        default: default,
        keyword_only: keyword_only,
        optional: optional,
        modifier: modifier,
        list_size: list_size
      }
    end

    if (params.map { |v| v[:name] } & TENSOR_OPTION_NAMES).size == TENSOR_OPTION_NAMES.size
      params << {
        name: "requires_grad",
        type: "bool",
        default: "False",
        keyword_only: true,
        optional: false,
        modifier: nil,
        list_size: nil
      }
    end

    params
  end

  # Turns "Tensor", "()" or "(Tensor a, Tensor b)" into an array of retval hashes.
  def generate_retvals(output)
    output =
      if output == "()"
        []
      elsif output[0] == "("
        output[1..-2].split(/\s*,\s*/)
      else
        [output]
      end

    retvals = []
    output.each do |o|
      type, name = o.split(/\s+/)
      type, modifier = extract_modifier(type)
      retvals << {name: name, type: type, modifier: modifier}
    end
    retvals
  end

  # Splits an alias annotation out of a type:
  # Tensor(a) => ["Tensor", "a"], Tensor(a!) => ["Tensor", "a!"],
  # Tensor(a)[] => ["Tensor[]", "a"], Tensor => ["Tensor", nil]
  def extract_modifier(type)
    if type.include?("(")
      parts = type.split(/[\(\)]/, 3)
      modifier = parts.delete_at(1)
      type = parts.join("")
    end
    [type, modifier]
  end
end
@@ -0,0 +1,546 @@
1
+ require "yaml"
2
+ # use require_relative for
3
+ # rake generate:function (without bundle)
4
+ require_relative "function"
5
+
6
# Entry point for `rake generate:functions`: loads the YAML definitions,
# drops unsupported entries, groups the remainder, and writes one
# header/source pair for each of the torch, tensor and nn groups.
def generate_functions
  grouped = group_functions(skip_functions(load_functions))

  generate_files("torch", :define_singleton_method, grouped[:torch])
  generate_files("tensor", :define_method, grouped[:tensor])
  generate_files("nn", :define_singleton_method, grouped[:nn])
end
15
+
16
# Reads native_functions.yaml (next to this file) and returns its entries as
# Function objects, sorted by full overload name.
def load_functions
  definitions = YAML.load_file(File.expand_path("native_functions.yaml", __dir__))
  definitions.map { |definition| Function.new(definition) }.sort_by(&:name)
end
20
+
21
# Removes functions that are not generated: private ("_"-prefixed) names,
# backward/forward helpers, names handled by hand, and schemas using types
# that are not supported yet.
def skip_functions(functions)
  functions.reject do |f|
    name = f.base_name
    next true if name.start_with?("_")
    next true if name.include?("_backward") || name.include?("_forward")
    # "index" and "index_put_" are in ext.cpp; "index_put" still needs to be added there
    next true if %w[to index index_put_ index_put].include?(name)

    # not supported yet
    f.func.include?("Dimname") || f.func.include?("ConstQuantizerPtr")
  end
end
37
+
38
# Buckets functions by where they are attached: nn-module functions, torch
# functions ("function" variant), and Tensor methods ("method" variant).
# A non-nn function that has both variants appears in both buckets.
def group_functions(functions)
  nn, rest = functions.partition { |f| f.python_module == "nn" }
  {
    torch: rest.select { |f| f.variants.include?("function") },
    tensor: rest.select { |f| f.variants.include?("method") },
    nn: nn
  }
end
45
+
46
# Generates the C++ method definitions and attach calls for one group
# (one entry per base name), then writes <type>_functions.h and .cpp.
def generate_files(type, def_method, functions)
  by_name = functions.group_by(&:base_name)
  method_defs = by_name.map { |name, overloads| generate_method_def(name, overloads, type, def_method) }
  attach_defs = by_name.keys.map { |name| generate_attach_def(name, type, def_method) }

  write_header(type)
  write_body(type, method_defs, attach_defs)
end
56
+
57
# Writes <type>_functions.h, which declares add_<type>_functions.
def write_header(type)
  header_template = <<~EOS
    // generated by rake generate:functions
    // do not edit by hand

    #pragma once

    void add_%{type}_functions(Module m);
  EOS

  write_file("#{type}_functions.h", format(header_template, type: type))
end
70
+
71
# Writes <type>_functions.cpp: the includes, every generated method
# definition, and add_<type>_functions, which runs all the attach calls.
# The cuda_lazy_init include is emitted for every group except "nn".
def write_body(type, method_defs, attach_defs)
  cuda_lazy_init = type == "nn" ? nil : %{\n#include "torch/csrc/utils/cuda_lazy_init.h"\n}

  body_template = <<~EOS
    // generated by rake generate:functions
    // do not edit by hand

    #include <torch/torch.h>
    #include <rice/Module.hpp>

    #include "ruby_arg_parser.h"
    #include "templates.h"
    #include "wrap_outputs.h"
    %{cuda_lazy_init}
    %{method_defs}
    void add_%{type}_functions(Module m) {
    %{attach_defs}
    }
  EOS

  contents = format(
    body_template,
    type: type,
    method_defs: method_defs.join("\n"),
    attach_defs: attach_defs.join("\n "),
    cuda_lazy_init: cuda_lazy_init
  )
  write_file("#{type}_functions.cpp", contents)
end
99
+
100
# Writes a generated file into ext/torch (resolved relative to this file's directory).
def write_file(name, contents)
  target = File.join(File.expand_path("../ext/torch", __dir__), name)
  File.write(target, contents)
end
104
+
105
# Builds the rb_define_* call that attaches <type>_<name> to Ruby.
# Ruby naming: a trailing "_" becomes "!", an "is_" prefix becomes a "?"
# suffix, and size/stride/random! get a leading "_".
def generate_attach_def(name, type, def_method)
  ruby_name =
    case name
    when /_\z/ then "#{name.chomp("_")}!"
    when /\Ais_/ then "#{name.delete_prefix("is_")}?"
    else name
    end

  ruby_name = "_#{ruby_name}" if %w[size stride random!].include?(ruby_name)

  %(rb_#{def_method}(m, "#{ruby_name}", #{type}_#{name}, -1);)
end
119
+
120
# Generates the static C++ function <type>_<name>. It parses the Ruby
# arguments against every overload signature and dispatches to the
# matching overload.
def generate_method_def(name, functions, type, def_method)
  self_decl = type == "tensor" ? "\n Tensor& self = from_ruby<Tensor&>(self_);" : ""

  overloads = group_overloads(functions, type)
  signatures = overloads.map { |overload| overload["signature"] }
  # parameters per signature = commas + 1, minus one for a "*" separator
  max_args = signatures.map { |sig| sig.count(",") - sig.count("*") }.max + 1
  parser_entries = signatures.map(&:inspect).join(",\n ")

  <<~EOS
    // #{name}
    static VALUE #{type}_#{name}(int argc, VALUE* argv, VALUE self_)
    {
    HANDLE_TH_ERRORS#{self_decl}
    static RubyArgParser parser({
    #{parser_entries}
    });
    std::vector<VALUE> parsed_args(#{max_args});
    auto _r = parser.parse(self_, argc, argv, parsed_args);
    #{add_dispatches(overloads, def_method)}
    END_HANDLE_TH_ERRORS
    }
  EOS
end
142
+
143
# Puts a single space after every newline between lines of `code`.
# Trailing newlines are dropped, because String#split discards trailing empty fields.
def indent(code)
  code_lines = code.split("\n")
  code_lines.join("\n ")
end
146
+
147
# Emits the dispatch code for a group of overloads. A single overload is
# emitted inline; several overloads become a `switch (_r.idx)` with one
# case per overload, in order.
def add_dispatches(functions, def_method)
  return add_dispatch(functions.first, def_method) if functions.size == 1

  cases = functions.each_with_index.map do |function, idx|
    "case #{idx}: {\n#{indent(add_dispatch(function, def_method))}\n}"
  end
  "switch (_r.idx) {\n#{cases.join("\n ")}\n}\nRETURN_NIL"
end
164
+
165
# Emits the dispatch for one grouped overload. When it has a distinct out=
# variant, the code branches at runtime on whether the out argument
# (at the out variant's out_index) is None.
def add_dispatch(function, def_method)
  base = function["base"]
  out = function["out"]
  return generate_dispatch(base, def_method) unless out && out != base

  [
    "if (_r.isNone(#{out.out_index})) {",
    indent(generate_dispatch(base, def_method)),
    "} else {",
    indent(generate_dispatch(out, def_method)),
    "}"
  ].join("\n")
end
180
+
181
# Groups functions whose signatures match once the out= params are removed.
# Each group is a hash with "base", "out" (optional) and "signature" keys.
# An out= variant's full signature wins over the base one.
def group_overloads(functions, type)
  groups = Hash.new { |hash, key| hash[key] = {} }

  functions.each do |function|
    key = generate_signature(function, type, skip_out: true)
    group = groups[key]

    unless function.out?
      group["base"] = function
      group["signature"] ||= key
      next
    end

    group["out"] = function
    group["signature"] = generate_signature(function, type)
    # for now, the out variant stands in as base until a non-out one is seen
    group["base"] ||= function
  end

  puts "Missing base: #{functions.first.name}" unless groups.values.all? { |group| group["base"] }
  sort_functions(groups.values)
end
202
+
203
# Orders grouped overloads so that groups without an out= variant come
# first. Ruby's sort_by is not stable, so the original index is used as a
# tie-breaker. This keeps the generated parser and switch order
# deterministic and faithful to the input order.
#
# @param functions [Array<Hash>] grouped overloads (see group_overloads)
# @return [Array<Hash>] the same hashes, reordered
def sort_functions(functions)
  # TODO: finer ordering between overloads
  functions.each_with_index
           .sort_by { |f, i| [f["out"] ? 1 : 0, i] }
           .map(&:first)
end
207
+
208
# Generates the C++ for one overload:
# - a comment with the schema, plus any TensorOptions / out tensorlist setup
# - a `dispatch_<name>` lambda that calls into ATen/torch
# - the code that calls the lambda with the parsed Ruby arguments
def generate_dispatch(function, def_method)
  cpp_name = function.out? ? "#{function.base_name}_out" : function.base_name
  remove_self = def_method == :define_method

  params = function.params.map(&:dup)
  set_param_position(params, remove_self)
  params, opt_params = split_opt_params(params)
  opt_index = (opt_params.map { |v| v[:position] }.min unless opt_params.empty?)

  lambda_params = generate_dispatch_params(function, params)
  # `params` still contains "self" (set_param_position only skips its
  # position), so for methods the options slot shifts right by one
  lambda_params.insert(remove_self ? opt_index + 1 : opt_index, "const TensorOptions & options") if opt_index

  retval = generate_dispatch_retval(function)
  dispatch_code = generate_dispatch_code(function, def_method, params, opt_index, remove_self)
  function_code = generate_function_code(function, cpp_name, params, opt_index, remove_self)

  multi_tensor_out = function.out? && function.retvals.size > 1 && function.retvals.all? { |v| v[:type] == "Tensor" }
  out_var = multi_tensor_out ? generate_out_var(function.out_index, function.retvals.size) : nil
  tensor_options = opt_params.empty? ? nil : generate_tensor_options(function, opt_params)

  [
    "// #{function.func}#{tensor_options}#{out_var}",
    "auto dispatch_#{cpp_name} = [](#{lambda_params.join(", ")}) -> #{retval} {",
    "// in future, release GVL",
    dispatch_code,
    "};",
    function_code
  ].join("\n")
end
238
+
239
# C++ line that reads the `size` out tensors as one fixed-size tensorlist
# from the parsed argument at `out_index`.
def generate_out_var(out_index, size)
  format("\n auto out = _r.tensorlist_n<%s>(%s);", size, out_index)
end
242
+
243
# Assigns :position (index into the parsed Ruby arguments) to each param,
# in place. With remove_self, "self" gets no position and the numbering
# skips it. Returns `params`.
def set_param_position(params, remove_self)
  numbered = remove_self ? params.reject { |v| v[:name] == "self" } : params
  numbered.each_with_index { |v, idx| v[:position] = idx }
  params
end
251
+
252
# Separates the TensorOptions-related params (dtype, device, layout,
# requires_grad, pin_memory) from the rest.
#
# @param params [Array<Hash>] param hashes with a :name key
# @return [Array(Array<Hash>, Array<Hash>)] [other_params, opt_params] when
#   at least four option params are present; otherwise [params, []], so the
#   option params stay ordinary arguments
def split_opt_params(params)
  option_names = ["dtype", "device", "layout", "requires_grad", "pin_memory"]

  # partition yields one element at a time; there is no index argument
  opt_params, other_params = params.partition { |v| option_names.include?(v[:name]) }
  if opt_params.size >= 4
    [other_params, opt_params]
  else
    [params, []]
  end
end
262
+
263
# Builds the C++ statement that constructs `options` (a TensorOptions) from
# the parsed option arguments, in a fixed order, then calls
# maybe_initialize_cuda on it. `arange` reads dtype through the Optional accessor.
def generate_tensor_options(function, opt_params)
  setters = {
    "dtype" => ->(i) { function.base_name == "arange" ? "dtype(_r.scalartypeOptional(#{i}))" : "dtype(_r.scalartype(#{i}))" },
    "device" => ->(i) { "device(_r.device(#{i}))" },
    "layout" => ->(i) { "layout(_r.layoutOptional(#{i}))" },
    "requires_grad" => ->(i) { "requires_grad(_r.toBool(#{i}))" },
    "pin_memory" => ->(i) { "pinned_memory(_r.toBool(#{i}))" }
  }
  order = setters.keys

  chain = opt_params.sort_by { |opt| order.index(opt[:name]) }.map do |opt|
    setter = setters[opt[:name]]
    "\n .#{setter && setter.call(opt[:position])}"
  end

  "\n const auto options = TensorOptions()#{chain.join};\n torch::utils::maybe_initialize_cuda(options);"
end
292
+
293
# Builds the C++ call to dispatch_<cpp_name> with the parsed Ruby arguments.
# The result is wrapped and returned, or followed by RETURN_NIL when the
# function returns nothing.
def generate_function_code(function, cpp_name, params, opt_index, remove_self)
  args = generate_function_params(function, params, remove_self)
  # "self" stays in the argument list for methods, so options shift right by one
  args.insert(remove_self ? opt_index + 1 : opt_index, "options") if opt_index

  call = "dispatch_#{cpp_name}(#{args.join(", ")})"
  function.retvals.empty? ? "#{call};\nRETURN_NIL" : "return wrap(#{call});"
end
307
+
308
# Maps each param to the C++ expression that reads it from the parser
# result `_r`, e.g. "_r.tensor(0)". For methods, "self" is passed through
# as-is. For out= overloads with several Tensor outputs, the out params are
# indexed into the `out` tensorlist. Raises RuntimeError on unknown types.
def generate_function_params(function, params, remove_self)
  out_var = function.out? && function.retvals.size > 1 && function.retvals.all? { |v| v[:type] == "Tensor" }

  accessors = {
    "Tensor" => "tensor",
    "Tensor[]" => "tensorlist",
    "Scalar" => "scalar",
    "bool" => "toBool",
    "int" => "toInt64",
    "float" => "toDouble",
    "ScalarType" => "scalartype",
    "str" => "string",
    "Generator" => "generator",
    "MemoryFormat" => "memoryformat",
    "Storage" => "storage"
  }

  params.each_with_index.map do |param, idx|
    next "self" if remove_self && param[:name] == "self"
    next "out[#{idx - function.out_index}]" if out_var && idx >= function.out_index

    accessor = /\Aint\[/.match?(param[:type].to_s) ? "intlist" : accessors[param[:type]]
    raise "Unknown type: #{param[:type]} (#{function.name})" unless accessor

    if param[:optional]
      accessor =
        case accessor
        when "tensor" then function.out? ? "tensor" : "optionalTensor"
        when "generator", "tensorlist", "intlist" then accessor
        else "#{accessor}Optional"
        end
    end

    "_r.#{accessor}(#{param[:position]})"
  end
end
369
+
370
# Builds the body of the dispatch lambda: a call to self.<fn> (methods),
# torch::<fn> (when TensorOptions are passed) or at::<fn>. out= overloads
# call <fn>_out with the out params moved to the front of the argument list.
def generate_dispatch_code(function, def_method, params, opt_index, remove_self)
  # torch::empty sets requires_grad by at::empty doesn't
  # https://github.com/pytorch/pytorch/issues/36455
  prefix =
    if remove_self then "self."
    elsif opt_index then "torch::"
    else "at::"
    end
  dispatch = function.out? ? "#{function.base_name}_out" : function.base_name

  args = params.map { |v| v[:name] }
  args -= ["self"] if remove_self
  args.insert(opt_index, "options") if opt_index
  args.unshift(args.slice!(function.out_index, function.retvals.size)) if function.out_index

  call = "#{prefix}#{dispatch}(#{args.join(", ")});"
  function.retvals.empty? ? call : "return #{call}"
end
388
+
389
# Maps each param to its C++ lambda parameter declaration ("<type> <name>").
#
# @param function [Function] the overload being generated
# @param params [Array<Hash>] param hashes (:name, :type, :optional, :modifier)
# @return [Array<String>] e.g. ["const Tensor & self", "int64_t dim"]
# @raise [RuntimeError] for a type with no known C++ mapping
def generate_dispatch_params(function, params)
  params.map do |param|
    type =
      case param[:type]
      when "Tensor"
        if param[:optional]
          if function.out?
            "const Tensor &"
          else
            # TODO
            # "const c10::optional<at::Tensor> &"
            "const OptionalTensor &"
          end
        elsif param[:modifier]
          # mutated tensors are passed by non-const reference only when the
          # function has several return values
          if param[:modifier].include?("!") && function.retvals.size > 1
            "Tensor &"
          else
            "Tensor"
          end
        else
          "const Tensor &"
        end
      when "Tensor[]"
        "TensorList"
      when "int"
        "int64_t"
      when "float"
        "double"
      when /\Aint\[/
        "IntArrayRef"
      when "str"
        "std::string"
      when "Scalar", "bool", "ScalarType", "Layout", "Device", "Storage", "Generator", "MemoryFormat"
        param[:type]
      else
        raise "Unknown type: #{param[:type]} (#{function.name})"
      end

    # optional Tensors were handled above; other optional types get wrapped
    type = "c10::optional<#{type}>" if param[:optional] && param[:type] != "Tensor"

    "#{type} #{param[:name]}"
  end
end
434
+
435
# C++ return type of the dispatch lambda, derived from the retval types.
# Raises RuntimeError for combinations that are not mapped.
def generate_dispatch_retval(function)
  types = function.retvals.map { |r| r[:type] }

  single = {
    "bool" => "bool",
    "int" => "int64_t",
    "float" => "double",
    "Scalar" => "Scalar",
    "ScalarType" => "ScalarType",
    "QScheme" => "QScheme",
    "Tensor" => "Tensor",
    "Tensor[]" => "std::vector<Tensor>"
  }

  return "void" if types.empty?
  return single[types.first] if types.size == 1 && single.key?(types.first)

  all_tensors = types.all? { |t| t == "Tensor" }
  return "std::tuple<#{types.join(",")}>" if all_tensors && types.size.between?(2, 5)
  return "std::tuple<Tensor,Tensor,float,int>" if types == ["Tensor", "Tensor", "float", "int"]

  raise "Unknown retvals: #{types}"
end
471
+
472
# Builds the parser signature string, e.g. "add(Tensor input, Scalar other, *, Tensor out=None)".
# With skip_out, out= params are dropped. Otherwise several Tensor outs are
# merged into one keyword-only "out" tensorlist. For Tensor methods, "self"
# is omitted from the positional params.
def generate_signature(function, type, skip_out: false)
  params = function.params.dup

  if function.out?
    out_at = function.out_index
    out_count = function.retvals.size
    if skip_out
      params.slice!(out_at, out_count)
    elsif out_count > 1 && params[out_at, out_count].all? { |r| r[:type] == "Tensor" }
      params[out_at, out_count] = [{name: "out", type: "Tensor[#{out_count}]", list_size: out_count, keyword_only: true}]
    end
  end

  skip_self = type == "tensor"
  positional = params.reject { |v| v[:keyword_only] || (skip_self && v[:name] == "self") }
  keyword = params.select { |v| v[:keyword_only] }
  parts = keyword.empty? ? positional : positional + ["*"] + keyword

  "#{function.base_name}(#{parts.map { |v| signature_param(v) }.join(", ")})"
end
495
+
496
# Renders one signature entry as "<type> <name>[=default]". "self" is shown
# as "input", and "*" passes through unchanged.
def signature_param(param)
  return "*" if param == "*"

  display_name = param[:name] == "self" ? "input" : param[:name]
  default_suffix =
    case param[:default]
    when nil then ""
    when "[]" then "=None"
    when "Mean" then "=at::Reduction::Mean"
    else "=#{param[:default]}"
    end
  # hack: a param named "out" always gets "=None" appended
  out_suffix = param[:name] == "out" ? "=None" : ""

  "#{signature_type(param)} #{display_name}#{default_suffix}#{out_suffix}"
end
519
+
520
# Maps a schema type to the type name used in parser signatures, adding
# "[N]" for sized lists and "?" for optional params.
#
# @param param [Hash] with :type, and optionally :list_size and :optional
# @return [String] e.g. "Tensor", "TensorList[2]", "IntArrayRef?"
# @raise [RuntimeError] for unmapped types
def signature_type(param)
  type =
    case param[:type]
    when "Tensor", /\ATensor\([a-z]!?\)\z/
      "Tensor"
    # anchored with \A: "\T" is a literal "T", not a start-of-string anchor
    when /\ATensor\[\d*\]\z/
      "TensorList"
    when /\ADimname\[\d*\]\z/
      "DimnameList"
    when /\Aint\[\d*\]\z/
      "IntArrayRef"
    when "int"
      "int64_t"
    when "float"
      "double"
    when "str"
      "std::string"
    when "Scalar", "Dimname", "bool", "ScalarType", "Layout", "Device", "Generator", "MemoryFormat", "Storage"
      param[:type]
    else
      raise "Unknown type: #{param[:type]}"
    end

  type += "[#{param[:list_size]}]" if param[:list_size]
  type += "?" if param[:optional]
  type
end