torch-rb 0.3.7 → 0.5.1

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,200 +0,0 @@
1
module Torch
  module Native
    # One entry of native_functions.yaml. Wraps the raw Hash (a "func" schema
    # string plus optional "python_module" / "variants" keys) and derives,
    # lazily and memoized, the names, parsed argument list, argument checkers
    # and return-type info used by the code generator and the parser.
    class Function
      # Arg types that are parsed but never exposed as arguments.
      SKIPPED_ARG_TYPES = ["Generator?", "MemoryFormat", "MemoryFormat?"].freeze
      private_constant :SKIPPED_ARG_TYPES

      attr_reader :function, :tensor_options

      # @param function [Hash] a parsed native_functions.yaml entry (not modified)
      def initialize(function)
        @function = function

        # note: don't modify function in-place
        @tensor_options_str = ", *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None)"
        @tensor_options = @function["func"].include?(@tensor_options_str)
        # a "!" in the return annotation marks a written-to output; names ending
        # in "_" (which #ruby_name turns into bang methods) are not out variants
        @out = out_size > 0 && base_name[-1] != "_"
      end

      # @return [String] the full schema, e.g. "add.Tensor(Tensor self, ...) -> Tensor"
      def func
        @func ||= @function["func"]
      end

      # @return [String] schema name including any overload suffix, e.g. "add.Tensor"
      def name
        @name ||= func.split("(", 2).first
      end

      # @return [String, nil] the entry's "python_module" value, if any
      def python_module
        @python_module ||= @function["python_module"]
      end

      # @return [Array<String>] the "variants" list; ["function"] when absent
      def variants
        @variants ||= (@function["variants"] || "function").split(", ")
      end

      # Parsed argument list. Each entry is a Hash with :name (Symbol),
      # :type (String as written in the schema), :default (converted to a Ruby
      # value, nil when absent), :pos (true for args before the "*" marker)
      # and :has_default. The tensor-options tail and SKIPPED_ARG_TYPES are dropped.
      #
      # @raise [RuntimeError] on a default literal this parser doesn't know
      def args
        @args ||= begin
          args_str = func.sub(@tensor_options_str, ")").split("(", 2).last.split(") ->").first
          positional = true
          args_str.split(", ").each_with_object([]) do |a, list|
            # a bare "*" separates positional from keyword-only args
            if a == "*"
              positional = false
              next
            end

            t, _, k = a.rpartition(" ")
            k, d = k.split("=")
            has_default = !d.nil?
            # defaults are converted before the skip check, so an unknown
            # default raises even on a skipped arg
            d = parse_default(d) if d

            next if SKIPPED_ARG_TYPES.include?(t)

            list << {name: k.to_sym, type: t, default: d, pos: positional, has_default: has_default}
          end
        end
      end

      # One validation lambda per argument, keyed by argument name.
      #
      # @return [Hash{Symbol => Proc}]
      # @raise [Error] for an argument type with no checker
      def arg_checkers
        @arg_checkers ||= begin
          checkers = {}
          arg_types.each do |k, t|
            checker =
              case t
              when "Tensor"
                ->(v) { v.is_a?(Tensor) }
              when "Tensor?"
                ->(v) { v.nil? || v.is_a?(Tensor) }
              when "Tensor[]", "Tensor?[]"
                ->(v) { v.is_a?(Array) && v.all? { |v2| v2.is_a?(Tensor) } }
              when "int"
                # reduction is declared as int but passed from Ruby as a String
                if k == :reduction
                  ->(v) { v.is_a?(String) }
                else
                  ->(v) { v.is_a?(Integer) }
                end
              when "int?"
                ->(v) { v.is_a?(Integer) || v.nil? }
              when "float?"
                ->(v) { v.is_a?(Numeric) || v.nil? }
              when "bool?"
                ->(v) { v == true || v == false || v.nil? }
              when "float"
                ->(v) { v.is_a?(Numeric) }
              when /int\[.*\]/
                ->(v) { v.is_a?(Array) && v.all? { |v2| v2.is_a?(Integer) } }
              when "Scalar"
                ->(v) { v.is_a?(Numeric) }
              when "Scalar?"
                ->(v) { v.is_a?(Numeric) || v.nil? }
              when "ScalarType"
                ->(v) { false } # not supported yet
              when "ScalarType?"
                ->(v) { v.nil? }
              when "bool"
                ->(v) { v == true || v == false }
              when "str"
                ->(v) { v.is_a?(String) }
              else
                # use #name, not @name: @name is only set once #name has run
                raise Error, "Unknown argument type: #{t}. Please report a bug with #{name}."
              end
            checkers[k] = checker
          end
          checkers
        end
      end

      # Fixed sizes of int[N] arguments, e.g. {stride: 2} for "int[2] stride".
      #
      # @return [Hash{Symbol => Integer}]
      # @raise [Error] when N is not a literal integer
      def int_array_lengths
        @int_array_lengths ||= begin
          ret = {}
          arg_types.each do |k, t|
            next unless t.match?(/\Aint\[.+\]\z/)

            size = t[4..-2]
            raise Error, "Unknown size: #{size}. Please report a bug with #{name}." unless size =~ /\A\d+\z/

            ret[k] = size.to_i
          end
          ret
        end
      end

      # @return [Array<Symbol>] argument names in schema order
      def arg_names
        @arg_names ||= args.map { |a| a[:name] }
      end

      # @return [Hash{Symbol => String}] argument types with any alias
      #   annotation removed, e.g. "Tensor(a!)" becomes "Tensor"
      def arg_types
        @arg_types ||= args.map { |a| [a[:name], a[:type].split("(").first] }.to_h
      end

      # Every argument mapped to its default (nil when it has none).
      # NOTE(review): required args are deliberately kept rather than using
      # `select`. Parser#parse fills missing keys from this hash and then runs
      # every key's checker, so an omitted required arg is still validated.
      # Presumably that is why the old TODO found `select` unusable; confirm.
      def arg_defaults
        @arg_defaults ||= args.map { |a| [a[:name], a[:default]] }.to_h
      end

      # @return [Integer] number of "!"-annotated (written-to) returns
      def out_size
        @out_size ||= returns.count("!")
      end

      # @return [Integer] number of comma-separated return entries
      def ret_size
        @ret_size ||= returns.split(", ").size
      end

      # @return [Boolean] whether a return type is a list ("[]")
      def ret_array?
        # `||=` would recompute a memoized false on every call
        return @ret_array if defined?(@ret_array)

        @ret_array = returns.include?("[]")
      end

      # @return [Boolean] whether the schema returns "()"
      def ret_void?
        returns.strip == "()"
      end

      def out?
        @out
      end

      # Ruby method name: trailing "_" becomes "!", leading "is_" becomes a
      # "?" suffix, e.g. "add_" -> "add!", "is_floating_point" -> "floating_point?".
      def ruby_name
        @ruby_name ||= begin
          name = base_name
          if name.end_with?("_")
            "#{name[0..-2]}!"
          elsif name.start_with?("is_")
            "#{name[3..-1]}?"
          else
            name
          end
        end
      end

      # C++ wrapper suffix, e.g. "add.Tensor" -> "_add_tensor".
      def cpp_name
        @cpp_name ||= "_" + name.downcase.sub(".", "_")
      end

      # Schema name without the overload suffix, e.g. "add.Tensor" -> "add".
      def base_name
        @base_name ||= name.split(".").first
      end

      private

      # Text after the last "->" in the schema (the return list).
      def returns
        @returns ||= func.split("->").last
      end

      # Converts a schema default literal to its Ruby value.
      def parse_default(d)
        case d
        when "True"
          true
        when "False"
          false
        when "None"
          nil
        when /\A\-?\d+\z/
          d.to_i
        when "[]"
          []
        when "[0,1]"
          [0, 1]
        when /\A\de\-\d+\z/, /\A\d+\.\d+\z/
          d.to_f
        when "Mean"
          "mean"
        when "contiguous_format"
          d
        when "long"
          :long
        else
          raise "Unknown default: #{d}"
        end
      end
    end
  end
end
@@ -1,178 +0,0 @@
1
- require "yaml"
2
- # use require_relative for
3
- # rake generate:function (without bundle)
4
- require_relative "function"
5
-
6
module Torch
  module Native
    # Generates the C++ binding sources under ext/torch from the entries in
    # native_functions.yaml (via `rake generate:functions`).
    module Generator
      # Argument types (matched by substring) the generator can't bind yet.
      SKIP_ARG_TYPES = ["Layout", "Storage", "ConstQuantizerPtr"].freeze
      private_constant :SKIP_ARG_TYPES

      class << self
        # Writes {torch,tensor,nn}_functions.{hpp,cpp}. torch and nn functions
        # are registered with define_singleton_method, tensor functions with
        # define_method.
        def generate_cpp_functions
          groups = grouped_functions
          generate_cpp_file("torch", :define_singleton_method, groups[:torch])
          generate_cpp_file("tensor", :define_method, groups[:tensor])
          generate_cpp_file("nn", :define_singleton_method, groups[:nn])
        end

        # Filters the native functions down to the ones bindings are generated
        # for, then groups them by target.
        #
        # @return [Hash{Symbol => Array<Function>}] keys :torch, :tensor, :nn
        #   (a function with both variants appears under :torch and :tensor)
        def grouped_functions
          # non-destructive on purpose: `functions` is memoized, so filtering it
          # in place (reject!) would shrink the cached list for later callers
          supported = functions.reject do |f|
            f.ruby_name.start_with?("_") ||
              f.ruby_name.include?("_backward") ||
              f.args.any? { |a| a[:type].include?("Dimname") }
          end

          # separate out into todo
          # (to inspect them: todo_functions.each { |f| puts f.func; puts })
          _todo_functions, supported = supported.partition { |f| todo?(f) }

          nn_functions, other_functions = supported.partition { |f| f.python_module == "nn" }
          torch_functions = other_functions.select { |f| f.variants.include?("function") }
          tensor_functions = other_functions.select { |f| f.variants.include?("method") }

          {torch: torch_functions, tensor: tensor_functions, nn: nn_functions}
        end

        private

        # Whether bindings for this function can't be generated yet.
        def todo?(f)
          # call to 'range' is ambiguous
          return true if f.cpp_name == "_range"
          # native_functions.yaml is missing size argument for normal
          # https://pytorch.org/cppdocs/api/function_namespacetorch_1a80253fe5a3ded4716ec929a348adb4b9.html
          return true if f.base_name == "normal" && !f.out?

          f.args.any? { |a| SKIP_ARG_TYPES.any? { |sa| a[:type].include?(sa) } }
        end

        # Maps a native_functions.yaml argument type to a C++ parameter type.
        #
        # @raise [RuntimeError] for an unmapped type
        def cpp_type(type)
          case type
          when "Tensor"
            "const Tensor &"
          when "Tensor?"
            # TODO better signature
            "OptionalTensor"
          when "ScalarType?"
            "torch::optional<ScalarType>"
          when "Tensor[]", "Tensor?[]"
            # TODO make optional
            "std::vector<Tensor>"
          when "int"
            "int64_t"
          when "int?"
            "torch::optional<int64_t>"
          when "float?"
            "torch::optional<double>"
          when "bool?"
            "torch::optional<bool>"
          when "Scalar?"
            "torch::optional<torch::Scalar>"
          when "float"
            "double"
          when /\Aint\[/
            "std::vector<int64_t>"
          when /Tensor\(\S!?\)/
            "Tensor &"
          when "str"
            "std::string"
          when "TensorOptions"
            "const torch::TensorOptions &"
          when "Layout?"
            "torch::optional<Layout>"
          when "Device?"
            "torch::optional<Device>"
          when "Scalar", "bool", "ScalarType", "Layout", "Device", "Storage"
            type
          else
            raise "Unknown type: #{type}"
          end
        end

        # Writes ext/torch/<type>_functions.hpp and .cpp: one static wrapper
        # per function plus add_<type>_functions(Module m), which registers
        # each wrapper on m via def_method.
        #
        # @param type [String] "torch", "tensor" or "nn"; also the file prefix
        # @param def_method [Symbol] :define_method or :define_singleton_method
        # @param functions [Array<Function>]
        def generate_cpp_file(type, def_method, functions)
          hpp_template = <<~TEMPLATE
            // generated by rake generate:functions
            // do not edit by hand

            #pragma once

            void add_%{type}_functions(Module m);
          TEMPLATE

          cpp_template = <<~TEMPLATE
            // generated by rake generate:functions
            // do not edit by hand

            #include <torch/torch.h>
            #include <rice/Module.hpp>
            #include "templates.hpp"

            %{functions}

            void add_%{type}_functions(Module m) {
            %{add_functions}
            }
          TEMPLATE

          cpp_defs = []
          add_defs = []
          functions.sort_by(&:cpp_name).each do |func|
            fargs = func.args.dup
            fargs << {name: :options, type: "TensorOptions"} if func.tensor_options

            cpp_args = fargs.map do |a|
              t = cpp_type(a[:type])
              t = "MyReduction" if a[:name] == :reduction && t == "int64_t"
              [t, a[:name]].join(" ").sub("& ", "&")
            end

            dispatch = func.out? ? "#{func.base_name}_out" : func.base_name
            call_args = fargs.map { |a| a[:name] }
            # move the trailing out args to the front of the *_out call
            call_args.unshift(*call_args.pop(func.out_size)) if func.out?
            # tensor methods call through `self.` rather than passing it
            call_args.delete(:self) if def_method == :define_method

            prefix = def_method == :define_method ? "self." : "torch::"
            body = "#{prefix}#{dispatch}(#{call_args.join(", ")})"

            if func.cpp_name == "_fill_diagonal_"
              body = "to_ruby<torch::Tensor>(#{body})"
            elsif !func.ret_void?
              body = "wrap(#{body})"
            end

            return_type = func.ret_void? ? "void" : "Object"
            cpp_defs << [
              "// #{func.func}",
              "static #{return_type} #{type}#{func.cpp_name}(#{cpp_args.join(", ")})",
              "{",
              "return #{body};",
              "}"
            ].join("\n")

            add_defs << "m.#{def_method}(\"#{func.cpp_name}\", #{type}#{func.cpp_name});"
          end

          hpp_contents = hpp_template % {type: type}
          cpp_contents = cpp_template % {type: type, functions: cpp_defs.join("\n\n"), add_functions: add_defs.join("\n ")}

          ext_dir = File.expand_path("../../../ext/torch", __dir__)
          File.write("#{ext_dir}/#{type}_functions.hpp", hpp_contents)
          File.write("#{ext_dir}/#{type}_functions.cpp", cpp_contents)
        end

        # Memoized Function list for every entry in native_functions.yaml.
        # Shared cache: callers must not mutate the returned array.
        def functions
          @native_functions ||= YAML.load_file(path).map { |f| Function.new(f) }
        end

        # Path of native_functions.yaml, located next to this file.
        def path
          File.expand_path("native_functions.yaml", __dir__)
        end
      end
    end
  end
end
@@ -1,117 +0,0 @@
1
module Torch
  module Native
    # Resolves a Ruby call against the overloads of one native function. It
    # picks the first overload whose argument checkers accept the values and
    # returns the C++ binding name plus the ordered argument list.
    class Parser
      # @param functions [Array<Function>] overloads sharing one Ruby name
      #   (the first one's ruby_name is used in error messages)
      def initialize(functions)
        @functions = functions
        @name = @functions.first.ruby_name
        # positional-arity range across all overloads
        @min_args = @functions.map { |f| f.args.count { |a| a[:pos] && !a[:has_default] } }.min
        @max_args = @functions.map { |f| f.args.count { |a| a[:pos] } }.max
        # when every overload starts with int[], f(1, 2) is accepted as f([1, 2])
        @int_array_first = @functions.all? { |c| c.args.first && c.args.first[:type] == "int[]" }
      end

      # TODO improve performance
      # possibly move to C++ (see python_arg_parser.cpp)
      #
      # @param args [Array] positional arguments (not modified)
      # @param options [Hash{Symbol => Object}] keyword arguments (not modified).
      #   A multi-output overload takes its outputs as an Array under :out.
      # @return [Hash] {name: cpp_name, args: [...]} on a match, or {error: msg}
      #   for a wrong arity, an unknown keyword or a mistyped argument
      # @raise [ArgumentError] when leading ints are followed by a non-int
      # @raise [Error] when no candidate is left to match or report on
      def parse(args, options)
        # work on copies: the int[] folding and :out expansion below are
        # destructive and must not leak into the caller's objects
        args = args.dup
        options = options.dup
        candidates = @functions.dup

        # TODO check candidates individually to see if they match
        if @int_array_first
          int_args = []
          int_args << args.shift while args.first.is_a?(Integer)
          if int_args.any?
            raise ArgumentError, "argument '#{candidates.first.args.first[:name]}' must be array of ints, but found element of type #{args.first.class.name} at pos #{int_args.size + 1}" if args.any?

            args.unshift(int_args)
          end
        end

        # TODO account for args passed as options here
        if args.size < @min_args || args.size > @max_args
          expected = String.new(@min_args.to_s)
          expected += "..#{@max_args}" if @max_args != @min_args
          return {error: "wrong number of arguments (given #{args.size}, expected #{expected})"}
        end

        candidates.reject! { |f| args.size > f.args.size }

        # handle out with multiple
        # there should only be one match, so safe to modify all
        if options[:out]
          out_func = candidates.find(&:out?)
          if out_func && out_func.out_size > 1
            # spread the :out array over all out_size trailing out args, not a
            # fixed two, so overloads with more outputs are assigned correctly
            out_names = out_func.args.last(out_func.out_size).map { |a| a[:name] }
            out_names.zip(options.delete(:out)).each { |k, v| options[k] = v }
            candidates = [out_func]
          end
        else
          # exclude functions missing required options
          candidates.reject!(&:out?)
        end

        final_values = nil

        # check args; errors are only reported once the last candidate fails
        while (func = candidates.shift)
          values = candidate_values(func, args, options)
          arg_checkers = func.arg_checkers
          good = true

          values.each_key do |k|
            unless arg_checkers.key?(k)
              good = false
              if candidates.empty?
                # TODO show all bad keywords at once like Ruby?
                return {error: "unknown keyword: #{k}"}
              end
              break
            end

            unless arg_checkers[k].call(values[k])
              good = false
              if candidates.empty?
                t = func.arg_types[k]
                k = :input if k == :self
                return {error: "#{@name}(): argument '#{k}' must be #{t}"}
              end
              break
            end
          end

          if good
            final_values = values
            break
          end
        end

        unless final_values
          raise Error, "This should never happen. Please report a bug with #{@name}."
        end

        call_args = func.arg_names.map { |k| final_values[k] }
        # NOTE(review): dtype 6 is presumably float32 (c10 ScalarType) — confirm
        call_args << TensorOptions.new.dtype(6) if func.tensor_options
        {
          name: func.cpp_name,
          args: call_args
        }
      end

      private

      # Builds the name => value map for one candidate. Positional args are
      # assigned by position, then keywords, then defaults for anything still
      # missing. A bare Integer given for an int[N] arg is repeated N times.
      def candidate_values(func, args, options)
        values = {}
        args.each_with_index { |a, i| values[func.arg_names[i]] = a }
        options.each { |k, v| values[k] = v }
        func.arg_defaults.each { |k, v| values[k] = v unless values.key?(k) }
        func.int_array_lengths.each do |k, len|
          values[k] = [values[k]] * len if values[k].is_a?(Integer)
        end
        values
      end
    end
  end
end