torch-rb 0.3.6 → 0.5.0
- checksums.yaml +4 -4
- data/CHANGELOG.md +27 -0
- data/README.md +3 -1
- data/codegen/function.rb +134 -0
- data/codegen/generate_functions.rb +557 -0
- data/{lib/torch/native → codegen}/native_functions.yaml +2363 -714
- data/ext/torch/ext.cpp +78 -89
- data/ext/torch/extconf.rb +5 -2
- data/ext/torch/nn_functions.h +6 -0
- data/ext/torch/ruby_arg_parser.cpp +593 -0
- data/ext/torch/ruby_arg_parser.h +397 -0
- data/ext/torch/{templates.hpp → templates.h} +46 -77
- data/ext/torch/tensor_functions.h +6 -0
- data/ext/torch/torch_functions.h +6 -0
- data/ext/torch/utils.h +42 -0
- data/ext/torch/{templates.cpp → wrap_outputs.h} +44 -8
- data/lib/torch.rb +35 -62
- data/lib/torch/nn/functional.rb +136 -16
- data/lib/torch/nn/init.rb +5 -19
- data/lib/torch/nn/module.rb +4 -1
- data/lib/torch/nn/upsample.rb +31 -0
- data/lib/torch/optim/adadelta.rb +4 -4
- data/lib/torch/optim/adagrad.rb +3 -3
- data/lib/torch/optim/adam.rb +4 -4
- data/lib/torch/optim/adamax.rb +3 -3
- data/lib/torch/optim/adamw.rb +3 -3
- data/lib/torch/optim/asgd.rb +2 -2
- data/lib/torch/optim/rmsprop.rb +7 -7
- data/lib/torch/optim/rprop.rb +1 -1
- data/lib/torch/optim/sgd.rb +5 -5
- data/lib/torch/tensor.rb +36 -110
- data/lib/torch/version.rb +1 -1
- metadata +19 -14
- data/lib/torch/native/dispatcher.rb +0 -48
- data/lib/torch/native/function.rb +0 -119
- data/lib/torch/native/generator.rb +0 -168
- data/lib/torch/native/parser.rb +0 -148
data/lib/torch/version.rb
CHANGED
@@ -1,3 +1,3 @@
 module Torch
-  VERSION = "0.3.6"
+  VERSION = "0.5.0"
 end
metadata
CHANGED
@@ -1,14 +1,14 @@
 --- !ruby/object:Gem::Specification
 name: torch-rb
 version: !ruby/object:Gem::Version
-  version: 0.3.6
+  version: 0.5.0
 platform: ruby
 authors:
 - Andrew Kane
-autorequire:
+autorequire:
 bindir: bin
 cert_chain: []
-date: 2020-
+date: 2020-10-28 00:00:00.000000000 Z
 dependencies:
 - !ruby/object:Gem::Dependency
   name: rice
@@ -108,7 +108,7 @@ dependencies:
   - - ">="
     - !ruby/object:Gem::Version
       version: 0.1.1
-description:
+description:
 email: andrew@chartkick.com
 executables: []
 extensions:
@@ -118,19 +118,23 @@ files:
 - CHANGELOG.md
 - LICENSE.txt
 - README.md
+- codegen/function.rb
+- codegen/generate_functions.rb
+- codegen/native_functions.yaml
 - ext/torch/ext.cpp
 - ext/torch/extconf.rb
-- ext/torch/
-- ext/torch/
+- ext/torch/nn_functions.h
+- ext/torch/ruby_arg_parser.cpp
+- ext/torch/ruby_arg_parser.h
+- ext/torch/templates.h
+- ext/torch/tensor_functions.h
+- ext/torch/torch_functions.h
+- ext/torch/utils.h
+- ext/torch/wrap_outputs.h
 - lib/torch-rb.rb
 - lib/torch.rb
 - lib/torch/hub.rb
 - lib/torch/inspector.rb
-- lib/torch/native/dispatcher.rb
-- lib/torch/native/function.rb
-- lib/torch/native/generator.rb
-- lib/torch/native/native_functions.yaml
-- lib/torch/native/parser.rb
 - lib/torch/nn/adaptive_avg_pool1d.rb
 - lib/torch/nn/adaptive_avg_pool2d.rb
 - lib/torch/nn/adaptive_avg_pool3d.rb
@@ -238,6 +242,7 @@ files:
 - lib/torch/nn/tanhshrink.rb
 - lib/torch/nn/triplet_margin_loss.rb
 - lib/torch/nn/unfold.rb
+- lib/torch/nn/upsample.rb
 - lib/torch/nn/utils.rb
 - lib/torch/nn/weighted_loss.rb
 - lib/torch/nn/zero_pad2d.rb
@@ -269,7 +274,7 @@ homepage: https://github.com/ankane/torch.rb
 licenses:
 - BSD-3-Clause
 metadata: {}
-post_install_message:
+post_install_message:
 rdoc_options: []
 require_paths:
 - lib
@@ -284,8 +289,8 @@ required_rubygems_version: !ruby/object:Gem::Requirement
   - !ruby/object:Gem::Version
     version: '0'
 requirements: []
-rubygems_version: 3.1.
-signing_key:
+rubygems_version: 3.1.4
+signing_key:
 specification_version: 4
 summary: Deep learning for Ruby, powered by LibTorch
 test_files: []
data/lib/torch/native/dispatcher.rb
DELETED
@@ -1,48 +0,0 @@
-# We use a generic interface for methods (*args, **options)
-# and this class to determine the C++ method to call
-#
-# This is needed since LibTorch uses function overloading,
-# which isn't available in Ruby or Python
-#
-# PyTorch uses this approach, but the parser/dispatcher is written in C++
-#
-# We could generate Ruby methods directly, but an advantage of this approach is
-# arguments and keyword arguments can be used interchangeably like in Python,
-# making it easier to port code
-
-module Torch
-  module Native
-    module Dispatcher
-      class << self
-        def bind
-          functions = Generator.grouped_functions
-          bind_functions(::Torch, :define_singleton_method, functions[:torch])
-          bind_functions(::Torch::Tensor, :define_method, functions[:tensor])
-          bind_functions(::Torch::NN, :define_singleton_method, functions[:nn])
-        end
-
-        def bind_functions(context, def_method, functions)
-          functions.group_by(&:ruby_name).sort_by { |g, _| g }.each do |name, funcs|
-            if def_method == :define_method
-              funcs.map! { |f| Function.new(f.function) }
-              funcs.each { |f| f.args.reject! { |a| a[:name] == "self" } }
-            end
-
-            defined = def_method == :define_method ? context.method_defined?(name) : context.respond_to?(name)
-            next if defined && name != "clone"
-
-            parser = Parser.new(funcs)
-
-            context.send(def_method, name) do |*args, **options|
-              result = parser.parse(args, options)
-              raise ArgumentError, result[:error] if result[:error]
-              send(result[:name], *result[:args])
-            end
-          end
-        end
-      end
-    end
-  end
-end
-
-Torch::Native::Dispatcher.bind
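The comment block at the top of this deleted file summarizes the old scheme: every bound method shared one generic (*args, **options) signature, and a Parser picked the matching C++ overload at call time. A minimal sketch of that pattern, assuming a hypothetical overload table (FakeTorch and OVERLOADS are illustrative stand-ins, not torch-rb API):

# Sketch of runtime overload dispatch via define_singleton_method.
# FakeTorch and OVERLOADS are hypothetical stand-ins for illustration.
module FakeTorch
  OVERLOADS = {
    "add" => [
      { cpp_name: "_add_int",   types: [Integer, Integer] },
      { cpp_name: "_add_float", types: [Integer, Float] }
    ]
  }

  OVERLOADS.each do |name, funcs|
    define_singleton_method(name) do |*args|
      # a real Parser also checks keyword names and defaults, not just types
      match = funcs.find { |f| f[:types].zip(args).all? { |t, a| a.is_a?(t) } }
      raise ArgumentError, "no matching overload for #{name}" unless match
      match[:cpp_name] # torch-rb would send(...) to the C++ binding here
    end
  end
end

FakeTorch.add(1, 2)   # => "_add_int"
FakeTorch.add(1, 2.0) # => "_add_float"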
data/lib/torch/native/function.rb
DELETED
@@ -1,119 +0,0 @@
-module Torch
-  module Native
-    class Function
-      attr_reader :function, :tensor_options
-
-      def initialize(function)
-        @function = function
-
-        tensor_options_str = ", *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None)"
-        @tensor_options = @function["func"].include?(tensor_options_str)
-        @function["func"].sub!(tensor_options_str, ")")
-      end
-
-      def func
-        @func ||= @function["func"]
-      end
-
-      def name
-        @name ||= func.split("(", 2).first
-      end
-
-      def python_module
-        @python_module ||= @function["python_module"]
-      end
-
-      def variants
-        @variants ||= (@function["variants"] || "function").split(", ")
-      end
-
-      def args
-        @args ||= begin
-          args = []
-          pos = true
-          args_str = func.split("(", 2).last.split(") ->").first
-          args_str.split(", ").each do |a|
-            if a == "*"
-              pos = false
-              next
-            end
-            t, _, k = a.rpartition(" ")
-            k, d = k.split("=")
-            has_default = !d.nil?
-
-            if d
-              d =
-                case d
-                when "True"
-                  true
-                when "False"
-                  false
-                when "None"
-                  nil
-                when /\A\-?\d+\z/
-                  d.to_i
-                when "[]"
-                  []
-                when "[0,1]"
-                  [0, 1]
-                when /\A\de\-\d+\z/, /\A\d+\.\d+\z/
-                  d.to_f
-                when "Mean"
-                  "mean"
-                when "contiguous_format"
-                  d
-                when "long"
-                  :long
-                else
-                  raise "Unknown default: #{d}"
-                end
-            end
-
-            next if t == "Generator?"
-            next if t == "MemoryFormat"
-            next if t == "MemoryFormat?"
-            args << {name: k, type: t, default: d, pos: pos, has_default: has_default}
-          end
-          args
-        end
-      end
-
-      def out_size
-        @out_size ||= func.split("->").last.count("!")
-      end
-
-      def ret_size
-        @ret_size ||= func.split("->").last.split(", ").size
-      end
-
-      def ret_array?
-        @ret_array ||= func.split("->").last.include?('[]')
-      end
-
-      def out?
-        out_size > 0 && base_name[-1] != "_"
-      end
-
-      def ruby_name
-        @ruby_name ||= begin
-          name = base_name
-          if name.end_with?("_")
-            "#{name[0..-2]}!"
-          elsif name.start_with?("is_")
-            "#{name[3..-1]}?"
-          else
-            name
-          end
-        end
-      end
-
-      def cpp_name
-        @cpp_name ||= "_" + name.downcase.sub(".", "_")
-      end
-
-      def base_name
-        @base_name ||= name.split(".").first
-      end
-    end
-  end
-end
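For context, each Function wrapped one entry of native_functions.yaml. A hypothetical round trip on torch-rb 0.3.6, where this class still existed (the signature below is typical of the YAML but chosen for illustration):

# Illustrative use of the deleted Function class (requires torch-rb 0.3.6).
f = Torch::Native::Function.new(
  "func" => "softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor"
)
f.name                      # => "softmax.int"
f.base_name                 # => "softmax"
f.ruby_name                 # => "softmax"
f.args.map { |a| a[:name] } # => ["self", "dim", "dtype"]
f.ret_size                  # => 1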
data/lib/torch/native/generator.rb
DELETED
@@ -1,168 +0,0 @@
-require "yaml"
-# use require_relative for
-# rake generate:function (without bundle)
-require_relative "function"
-
-module Torch
-  module Native
-    module Generator
-      class << self
-        def generate_cpp_functions
-          functions = grouped_functions
-          generate_cpp_file("torch", :define_singleton_method, functions[:torch])
-          generate_cpp_file("tensor", :define_method, functions[:tensor])
-          generate_cpp_file("nn", :define_singleton_method, functions[:nn])
-        end
-
-        def grouped_functions
-          functions = functions()
-
-          # skip functions
-          skip_args = ["Layout", "Storage", "ConstQuantizerPtr"]
-
-          # remove functions
-          functions.reject! do |f|
-            f.ruby_name.start_with?("_") ||
-              f.ruby_name.include?("_backward") ||
-              f.args.any? { |a| a[:type].include?("Dimname") }
-          end
-
-          # separate out into todo
-          todo_functions, functions =
-            functions.partition do |f|
-              f.args.any? do |a|
-                skip_args.any? { |sa| a[:type].include?(sa) } ||
-                  # call to 'range' is ambiguous
-                  f.cpp_name == "_range" ||
-                  # native_functions.yaml is missing size argument for normal
-                  # https://pytorch.org/cppdocs/api/function_namespacetorch_1a80253fe5a3ded4716ec929a348adb4b9.html
-                  (f.base_name == "normal" && !f.out?)
-              end
-            end
-
-          # todo_functions.each do |f|
-          #   puts f.func
-          #   puts
-          # end
-
-          nn_functions, other_functions = functions.partition { |f| f.python_module == "nn" }
-          torch_functions = other_functions.select { |f| f.variants.include?("function") }
-          tensor_functions = other_functions.select { |f| f.variants.include?("method") }
-
-          {torch: torch_functions, tensor: tensor_functions, nn: nn_functions}
-        end
-
-        private
-
-        def generate_cpp_file(type, def_method, functions)
-          hpp_template = <<-TEMPLATE
-// generated by rake generate:functions
-// do not edit by hand
-
-#pragma once
-
-void add_%{type}_functions(Module m);
-          TEMPLATE
-
-          cpp_template = <<-TEMPLATE
-// generated by rake generate:functions
-// do not edit by hand
-
-#include <torch/torch.h>
-#include <rice/Module.hpp>
-#include "templates.hpp"
-
-void add_%{type}_functions(Module m) {
-  m
-  %{functions};
-}
-          TEMPLATE
-
-          cpp_defs = []
-          functions.sort_by(&:cpp_name).each do |func|
-            fargs = func.args.dup #.select { |a| a[:type] != "Generator?" }
-            fargs << {name: "options", type: "TensorOptions"} if func.tensor_options
-
-            cpp_args = []
-            fargs.each do |a|
-              t =
-                case a[:type]
-                when "Tensor"
-                  "const Tensor &"
-                when "Tensor?"
-                  # TODO better signature
-                  "OptionalTensor"
-                when "ScalarType?"
-                  "torch::optional<ScalarType>"
-                when "Tensor[]"
-                  "TensorList"
-                when "Tensor?[]"
-                  # TODO make optional
-                  "TensorList"
-                when "int"
-                  "int64_t"
-                when "int?"
-                  "torch::optional<int64_t>"
-                when "float?"
-                  "torch::optional<double>"
-                when "bool?"
-                  "torch::optional<bool>"
-                when "Scalar?"
-                  "torch::optional<torch::Scalar>"
-                when "float"
-                  "double"
-                when /\Aint\[/
-                  "IntArrayRef"
-                when /Tensor\(\S!?\)/
-                  "Tensor &"
-                when "str"
-                  "std::string"
-                when "TensorOptions"
-                  "const torch::TensorOptions &"
-                else
-                  a[:type]
-                end
-
-              t = "MyReduction" if a[:name] == "reduction" && t == "int64_t"
-              cpp_args << [t, a[:name]].join(" ").sub("& ", "&")
-            end
-
-            dispatch = func.out? ? "#{func.base_name}_out" : func.base_name
-            args = fargs.map { |a| a[:name] }
-            args.unshift(*args.pop(func.out_size)) if func.out?
-            args.delete("self") if def_method == :define_method
-
-            prefix = def_method == :define_method ? "self." : "torch::"
-
-            body = "#{prefix}#{dispatch}(#{args.join(", ")})"
-
-            if func.ret_size > 1 || func.ret_array?
-              body = "wrap(#{body})"
-            end
-
-            cpp_defs << ".#{def_method}(
-    \"#{func.cpp_name}\",
-    *[](#{cpp_args.join(", ")}) {
-      return #{body};
-    })"
-          end
-
-          hpp_contents = hpp_template % {type: type}
-          cpp_contents = cpp_template % {type: type, functions: cpp_defs.join("\n    ")}
-
-          path = File.expand_path("../../../ext/torch", __dir__)
-          File.write("#{path}/#{type}_functions.hpp", hpp_contents)
-          File.write("#{path}/#{type}_functions.cpp", cpp_contents)
-        end
-
-        def functions
-          @native_functions ||= YAML.load_file(path).map { |f| Function.new(f) }
-        end
-
-        def path
-          File.expand_path("native_functions.yaml", __dir__)
-        end
-      end
-    end
-  end
-end
data/lib/torch/native/parser.rb
DELETED
@@ -1,148 +0,0 @@
-module Torch
-  module Native
-    class Parser
-      def initialize(functions)
-        @functions = functions
-        @name = @functions.first.ruby_name
-        @min_args = @functions.map { |f| f.args.count { |a| a[:pos] && !a[:has_default] } }.min
-        @max_args = @functions.map { |f| f.args.count { |a| a[:pos] } }.max
-      end
-
-      def parse(args, options)
-        candidates = @functions.dup
-
-        # remove nil
-        while args.any? && args.last.nil?
-          args.pop
-        end
-
-        # TODO account for args passed as options here
-        if args.size < @min_args || args.size > @max_args
-          expected = String.new(@min_args.to_s)
-          expected += "..#{@max_args}" if @max_args != @min_args
-          return {error: "wrong number of arguments (given #{args.size}, expected #{expected})"}
-        end
-
-        candidates.reject! { |f| args.size > f.args.size }
-
-        # exclude functions missing required options
-        candidates.reject! do |func|
-          # TODO make more generic
-          func.out? && !options[:out]
-        end
-
-        # handle out with multiple
-        # there should only be one match, so safe to modify all
-        out_func = candidates.find { |f| f.out? }
-        if out_func && out_func.out_size > 1 && options[:out]
-          out_args = out_func.args.last(2).map { |a| a[:name] }
-          out_args.zip(options.delete(:out)).each do |k, v|
-            options[k.to_sym] = v
-          end
-          candidates = [out_func]
-        end
-
-        # exclude functions where options don't match
-        options.each do |k, v|
-          candidates.select! do |func|
-            func.args.any? { |a| a[:name] == k.to_s }
-          end
-          # TODO show all bad keywords at once like Ruby?
-          return {error: "unknown keyword: #{k}"} if candidates.empty?
-        end
-
-        final_values = {}
-
-        # check args
-        candidates.select! do |func|
-          good = true
-
-          values = args.zip(func.args).map { |a, fa| [fa[:name], a] }.to_h
-          values.merge!(options.map { |k, v| [k.to_s, v] }.to_h)
-          func.args.each do |fa|
-            values[fa[:name]] = fa[:default] if values[fa[:name]].nil?
-          end
-
-          arg_types = func.args.map { |a| [a[:name], a[:type]] }.to_h
-
-          values.each_key do |k|
-            v = values[k]
-            t = arg_types[k].split("(").first
-
-            good =
-              case t
-              when "Tensor"
-                v.is_a?(Tensor)
-              when "Tensor?"
-                v.nil? || v.is_a?(Tensor)
-              when "Tensor[]", "Tensor?[]"
-                v.is_a?(Array) && v.all? { |v2| v2.is_a?(Tensor) }
-              when "int"
-                if k == "reduction"
-                  v.is_a?(String)
-                else
-                  v.is_a?(Integer)
-                end
-              when "int?"
-                v.is_a?(Integer) || v.nil?
-              when "float?"
-                v.is_a?(Numeric) || v.nil?
-              when "bool?"
-                v == true || v == false || v.nil?
-              when "float"
-                v.is_a?(Numeric)
-              when /int\[.*\]/
-                if v.is_a?(Integer)
-                  size = t[4..-2]
-                  raise Error, "Unknown size: #{size}. Please report a bug with #{@name}." unless size =~ /\A\d+\z/
-                  v = [v] * size.to_i
-                  values[k] = v
-                end
-                v.is_a?(Array) && v.all? { |v2| v2.is_a?(Integer) }
-              when "Scalar"
-                v.is_a?(Numeric)
-              when "Scalar?"
-                v.is_a?(Numeric) || v.nil?
-              when "ScalarType"
-                false # not supported yet
-              when "ScalarType?"
-                v.nil?
-              when "bool"
-                v == true || v == false
-              when "str"
-                v.is_a?(String)
-              else
-                raise Error, "Unknown argument type: #{arg_types[k]}. Please report a bug with #{@name}."
-              end
-
-            if !good
-              if candidates.size == 1
-                k = "input" if k == "self"
-                return {error: "#{@name}(): argument '#{k}' must be #{t}"}
-              end
-              break
-            end
-          end
-
-          if good
-            final_values = values
-          end
-
-          good
-        end
-
-        if candidates.size != 1
-          raise Error, "This should never happen. Please report a bug with #{@name}."
-        end
-
-        func = candidates.first
-        args = func.args.map { |a| final_values[a[:name]] }
-        args << TensorOptions.new.dtype(6) if func.tensor_options
-        {
-          name: func.cpp_name,
-          args: args
-        }
-      end
-    end
-  end
-end
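Taken together with Function, overload resolution on torch-rb 0.3.6 worked roughly like this (a hypothetical session; the signatures are trimmed for illustration):

# Hypothetical session showing Parser picking between two trimmed overloads.
add_tensor = Torch::Native::Function.new(
  "func" => "add.Tensor(Tensor self, Tensor other) -> Tensor"
)
add_scalar = Torch::Native::Function.new(
  "func" => "add.Scalar(Tensor self, Scalar other) -> Tensor"
)

parser = Torch::Native::Parser.new([add_tensor, add_scalar])
result = parser.parse([Torch.tensor([1, 2]), 3], {})
result[:name] # => "_add_scalar" (the Numeric argument ruled out add.Tensor)
result[:args] # => [tensor, 3], ready to send to the C++ binding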