torch-rb 0.3.7 → 0.5.1

This diff compares the contents of publicly released package versions as they appear in their respective public registries; it is provided for informational purposes only.
@@ -1,200 +0,0 @@
-module Torch
-  module Native
-    class Function
-      attr_reader :function, :tensor_options
-
-      def initialize(function)
-        @function = function
-
-        # note: don't modify function in-place
-        @tensor_options_str = ", *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None)"
-        @tensor_options = @function["func"].include?(@tensor_options_str)
-        @out = out_size > 0 && base_name[-1] != "_"
-      end
-
-      def func
-        @func ||= @function["func"]
-      end
-
-      def name
-        @name ||= func.split("(", 2).first
-      end
-
-      def python_module
-        @python_module ||= @function["python_module"]
-      end
-
-      def variants
-        @variants ||= (@function["variants"] || "function").split(", ")
-      end
-
-      def args
-        @args ||= begin
-          args = []
-          pos = true
-          args_str = func.sub(@tensor_options_str, ")").split("(", 2).last.split(") ->").first
-          args_str.split(", ").each do |a|
-            if a == "*"
-              pos = false
-              next
-            end
-            t, _, k = a.rpartition(" ")
-            k, d = k.split("=")
-            has_default = !d.nil?
-
-            if d
-              d =
-                case d
-                when "True"
-                  true
-                when "False"
-                  false
-                when "None"
-                  nil
-                when /\A\-?\d+\z/
-                  d.to_i
-                when "[]"
-                  []
-                when "[0,1]"
-                  [0, 1]
-                when /\A\de\-\d+\z/, /\A\d+\.\d+\z/
-                  d.to_f
-                when "Mean"
-                  "mean"
-                when "contiguous_format"
-                  d
-                when "long"
-                  :long
-                else
-                  raise "Unknown default: #{d}"
-                end
-            end
-
-            next if t == "Generator?"
-            next if t == "MemoryFormat"
-            next if t == "MemoryFormat?"
-            args << {name: k.to_sym, type: t, default: d, pos: pos, has_default: has_default}
-          end
-          args
-        end
-      end
-
-      def arg_checkers
-        @arg_checkers ||= begin
-          checkers = {}
-          arg_types.each do |k, t|
-            checker =
-              case t
-              when "Tensor"
-                ->(v) { v.is_a?(Tensor) }
-              when "Tensor?"
-                ->(v) { v.nil? || v.is_a?(Tensor) }
-              when "Tensor[]", "Tensor?[]"
-                ->(v) { v.is_a?(Array) && v.all? { |v2| v2.is_a?(Tensor) } }
-              when "int"
-                if k == :reduction
-                  ->(v) { v.is_a?(String) }
-                else
-                  ->(v) { v.is_a?(Integer) }
-                end
-              when "int?"
-                ->(v) { v.is_a?(Integer) || v.nil? }
-              when "float?"
-                ->(v) { v.is_a?(Numeric) || v.nil? }
-              when "bool?"
-                ->(v) { v == true || v == false || v.nil? }
-              when "float"
-                ->(v) { v.is_a?(Numeric) }
-              when /int\[.*\]/
-                ->(v) { v.is_a?(Array) && v.all? { |v2| v2.is_a?(Integer) } }
-              when "Scalar"
-                ->(v) { v.is_a?(Numeric) }
-              when "Scalar?"
-                ->(v) { v.is_a?(Numeric) || v.nil? }
-              when "ScalarType"
-                ->(v) { false } # not supported yet
-              when "ScalarType?"
-                ->(v) { v.nil? }
-              when "bool"
-                ->(v) { v == true || v == false }
-              when "str"
-                ->(v) { v.is_a?(String) }
-              else
-                raise Error, "Unknown argument type: #{t}. Please report a bug with #{@name}."
-              end
-            checkers[k] = checker
-          end
-          checkers
-        end
-      end
-
-      def int_array_lengths
-        @int_array_lengths ||= begin
-          ret = {}
-          arg_types.each do |k, t|
-            if t.match?(/\Aint\[.+\]\z/)
-              size = t[4..-2]
-              raise Error, "Unknown size: #{size}. Please report a bug with #{@name}." unless size =~ /\A\d+\z/
-              ret[k] = size.to_i
-            end
-          end
-          ret
-        end
-      end
-
-      def arg_names
-        @arg_names ||= args.map { |a| a[:name] }
-      end
-
-      def arg_types
-        @arg_types ||= args.map { |a| [a[:name], a[:type].split("(").first] }.to_h
-      end
-
-      def arg_defaults
-        # TODO find out why can't use select here
-        @arg_defaults ||= args.map { |a| [a[:name], a[:default]] }.to_h
-      end
-
-      def out_size
-        @out_size ||= func.split("->").last.count("!")
-      end
-
-      def ret_size
-        @ret_size ||= func.split("->").last.split(", ").size
-      end
-
-      def ret_array?
-        @ret_array ||= func.split("->").last.include?('[]')
-      end
-
-      def ret_void?
-        func.split("->").last.strip == "()"
-      end
-
-      def out?
-        @out
-      end
-
-      def ruby_name
-        @ruby_name ||= begin
-          name = base_name
-          if name.end_with?("_")
-            "#{name[0..-2]}!"
-          elsif name.start_with?("is_")
-            "#{name[3..-1]}?"
-          else
-            name
-          end
-        end
-      end
-
-      def cpp_name
-        @cpp_name ||= "_" + name.downcase.sub(".", "_")
-      end
-
-      def base_name
-        @base_name ||= name.split(".").first
-      end
-    end
-  end
-end
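
For context (not part of the diff itself): the removed Torch::Native::Function class wrapped one entry of PyTorch's native_functions.yaml and parsed its schema string into argument metadata. A minimal sketch of how it was exercised, using a hypothetical schema entry written in that format:

    # Hypothetical entry in the native_functions.yaml style; Function only
    # needs a hash with a "func" schema string (plus optional metadata).
    entry = {
      "func"     => "addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor",
      "variants" => "function, method"
    }

    f = Torch::Native::Function.new(entry)
    f.name                       # => "addmm"
    f.cpp_name                   # => "_addmm"
    f.variants                   # => ["function", "method"]
    f.args.map { |a| a[:name] }  # => [:self, :mat1, :mat2, :beta, :alpha]
    f.out?                       # => false (no "(a!)" outputs in the return type)
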
@@ -1,178 +0,0 @@
-require "yaml"
-# use require_relative for
-# rake generate:function (without bundle)
-require_relative "function"
-
-module Torch
-  module Native
-    module Generator
-      class << self
-        def generate_cpp_functions
-          functions = grouped_functions
-          generate_cpp_file("torch", :define_singleton_method, functions[:torch])
-          generate_cpp_file("tensor", :define_method, functions[:tensor])
-          generate_cpp_file("nn", :define_singleton_method, functions[:nn])
-        end
-
-        def grouped_functions
-          functions = functions()
-
-          # skip functions
-          skip_args = ["Layout", "Storage", "ConstQuantizerPtr"]
-
-          # remove functions
-          functions.reject! do |f|
-            f.ruby_name.start_with?("_") ||
-              f.ruby_name.include?("_backward") ||
-              f.args.any? { |a| a[:type].include?("Dimname") }
-          end
-
-          # separate out into todo
-          todo_functions, functions =
-            functions.partition do |f|
-              f.args.any? do |a|
-                skip_args.any? { |sa| a[:type].include?(sa) } ||
-                  # call to 'range' is ambiguous
-                  f.cpp_name == "_range" ||
-                  # native_functions.yaml is missing size argument for normal
-                  # https://pytorch.org/cppdocs/api/function_namespacetorch_1a80253fe5a3ded4716ec929a348adb4b9.html
-                  (f.base_name == "normal" && !f.out?)
-              end
-            end
-
-          # todo_functions.each do |f|
-          #   puts f.func
-          #   puts
-          # end
-
-          nn_functions, other_functions = functions.partition { |f| f.python_module == "nn" }
-          torch_functions = other_functions.select { |f| f.variants.include?("function") }
-          tensor_functions = other_functions.select { |f| f.variants.include?("method") }
-
-          {torch: torch_functions, tensor: tensor_functions, nn: nn_functions}
-        end
-
-        private
-
-        def generate_cpp_file(type, def_method, functions)
-          hpp_template = <<-TEMPLATE
-// generated by rake generate:functions
-// do not edit by hand
-
-#pragma once
-
-void add_%{type}_functions(Module m);
-          TEMPLATE
-
-          cpp_template = <<-TEMPLATE
-// generated by rake generate:functions
-// do not edit by hand
-
-#include <torch/torch.h>
-#include <rice/Module.hpp>
-#include "templates.hpp"
-
-%{functions}
-
-void add_%{type}_functions(Module m) {
-  %{add_functions}
-}
-          TEMPLATE
-
-          cpp_defs = []
-          add_defs = []
-          functions.sort_by(&:cpp_name).each do |func|
-            fargs = func.args.dup #.select { |a| a[:type] != "Generator?" }
-            fargs << {name: :options, type: "TensorOptions"} if func.tensor_options
-
-            cpp_args = []
-            fargs.each do |a|
-              t =
-                case a[:type]
-                when "Tensor"
-                  "const Tensor &"
-                when "Tensor?"
-                  # TODO better signature
-                  "OptionalTensor"
-                when "ScalarType?"
-                  "torch::optional<ScalarType>"
-                when "Tensor[]", "Tensor?[]"
-                  # TODO make optional
-                  "std::vector<Tensor>"
-                when "int"
-                  "int64_t"
-                when "int?"
-                  "torch::optional<int64_t>"
-                when "float?"
-                  "torch::optional<double>"
-                when "bool?"
-                  "torch::optional<bool>"
-                when "Scalar?"
-                  "torch::optional<torch::Scalar>"
-                when "float"
-                  "double"
-                when /\Aint\[/
-                  "std::vector<int64_t>"
-                when /Tensor\(\S!?\)/
-                  "Tensor &"
-                when "str"
-                  "std::string"
-                when "TensorOptions"
-                  "const torch::TensorOptions &"
-                when "Layout?"
-                  "torch::optional<Layout>"
-                when "Device?"
-                  "torch::optional<Device>"
-                when "Scalar", "bool", "ScalarType", "Layout", "Device", "Storage"
-                  a[:type]
-                else
-                  raise "Unknown type: #{a[:type]}"
-                end
-
-              t = "MyReduction" if a[:name] == :reduction && t == "int64_t"
-              cpp_args << [t, a[:name]].join(" ").sub("& ", "&")
-            end
-
-            dispatch = func.out? ? "#{func.base_name}_out" : func.base_name
-            args = fargs.map { |a| a[:name] }
-            args.unshift(*args.pop(func.out_size)) if func.out?
-            args.delete(:self) if def_method == :define_method
-
-            prefix = def_method == :define_method ? "self." : "torch::"
-
-            body = "#{prefix}#{dispatch}(#{args.join(", ")})"
-
-            if func.cpp_name == "_fill_diagonal_"
-              body = "to_ruby<torch::Tensor>(#{body})"
-            elsif !func.ret_void?
-              body = "wrap(#{body})"
-            end
-
-            cpp_defs << "// #{func.func}
-static #{func.ret_void? ? "void" : "Object"} #{type}#{func.cpp_name}(#{cpp_args.join(", ")})
-{
-  return #{body};
-}"
-
-            add_defs << "m.#{def_method}(\"#{func.cpp_name}\", #{type}#{func.cpp_name});"
-          end
-
-          hpp_contents = hpp_template % {type: type}
-          cpp_contents = cpp_template % {type: type, functions: cpp_defs.join("\n\n"), add_functions: add_defs.join("\n ")}
-
-          path = File.expand_path("../../../ext/torch", __dir__)
-          File.write("#{path}/#{type}_functions.hpp", hpp_contents)
-          File.write("#{path}/#{type}_functions.cpp", cpp_contents)
-        end
-
-        def functions
-          @native_functions ||= YAML.load_file(path).map { |f| Function.new(f) }
-        end
-
-        def path
-          File.expand_path("native_functions.yaml", __dir__)
-        end
-      end
-    end
-  end
-end
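
Again for context rather than part of the diff: the removed generator read native_functions.yaml from its own directory and wrote C++ binding files into ext/torch. A rough sketch of driving it directly, assuming it is run from the gem's source root and that the file lived at lib/torch/native/generator.rb (the comments in the generated headers refer to a rake generate:functions task):

    # Loads the schema file that sat next to generator.rb and regenerates
    # ext/torch/{torch,tensor,nn}_functions.{hpp,cpp}.
    require_relative "lib/torch/native/generator"

    groups = Torch::Native::Generator.grouped_functions
    groups.keys        # => [:torch, :tensor, :nn]
    groups[:nn].size   # number of nn-module schemas that made the cut

    Torch::Native::Generator.generate_cpp_functions
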
@@ -1,117 +0,0 @@
-module Torch
-  module Native
-    class Parser
-      def initialize(functions)
-        @functions = functions
-        @name = @functions.first.ruby_name
-        @min_args = @functions.map { |f| f.args.count { |a| a[:pos] && !a[:has_default] } }.min
-        @max_args = @functions.map { |f| f.args.count { |a| a[:pos] } }.max
-        @int_array_first = @functions.all? { |c| c.args.first && c.args.first[:type] == "int[]" }
-      end
-
-      # TODO improve performance
-      # possibly move to C++ (see python_arg_parser.cpp)
-      def parse(args, options)
-        candidates = @functions.dup
-
-        # TODO check candidates individually to see if they match
-        if @int_array_first
-          int_args = []
-          while args.first.is_a?(Integer)
-            int_args << args.shift
-          end
-          if int_args.any?
-            raise ArgumentError, "argument '#{candidates.first.args.first[:name]}' must be array of ints, but found element of type #{args.first.class.name} at pos #{int_args.size + 1}" if args.any?
-            args.unshift(int_args)
-          end
-        end
-
-        # TODO account for args passed as options here
-        if args.size < @min_args || args.size > @max_args
-          expected = String.new(@min_args.to_s)
-          expected += "..#{@max_args}" if @max_args != @min_args
-          return {error: "wrong number of arguments (given #{args.size}, expected #{expected})"}
-        end
-
-        candidates.reject! { |f| args.size > f.args.size }
-
-        # handle out with multiple
-        # there should only be one match, so safe to modify all
-        if options[:out]
-          if (out_func = candidates.find { |f| f.out? }) && out_func.out_size > 1
-            out_args = out_func.args.last(2).map { |a| a[:name] }
-            out_args.zip(options.delete(:out)).each do |k, v|
-              options[k] = v
-            end
-            candidates = [out_func]
-          end
-        else
-          # exclude functions missing required options
-          candidates.reject!(&:out?)
-        end
-
-        final_values = nil
-
-        # check args
-        while (func = candidates.shift)
-          good = true
-
-          # set values
-          # TODO use array instead of hash?
-          values = {}
-          args.each_with_index do |a, i|
-            values[func.arg_names[i]] = a
-          end
-          options.each do |k, v|
-            values[k] = v
-          end
-          func.arg_defaults.each do |k, v|
-            values[k] = v unless values.key?(k)
-          end
-          func.int_array_lengths.each do |k, len|
-            values[k] = [values[k]] * len if values[k].is_a?(Integer)
-          end
-
-          arg_checkers = func.arg_checkers
-
-          values.each_key do |k|
-            unless arg_checkers.key?(k)
-              good = false
-              if candidates.empty?
-                # TODO show all bad keywords at once like Ruby?
-                return {error: "unknown keyword: #{k}"}
-              end
-              break
-            end
-
-            unless arg_checkers[k].call(values[k])
-              good = false
-              if candidates.empty?
-                t = func.arg_types[k]
-                k = :input if k == :self
-                return {error: "#{@name}(): argument '#{k}' must be #{t}"}
-              end
-              break
-            end
-          end
-
-          if good
-            final_values = values
-            break
-          end
-        end
-
-        unless final_values
-          raise Error, "This should never happen. Please report a bug with #{@name}."
-        end
-
-        args = func.arg_names.map { |k| final_values[k] }
-        args << TensorOptions.new.dtype(6) if func.tensor_options
-        {
-          name: func.cpp_name,
-          args: args
-        }
-      end
-    end
-  end
-end
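
Finally, a hedged sketch (not part of the diff) of how the removed Parser resolved overloads at call time; the two schemas below are illustrative, not taken from the package:

    # Parser receives all Function overloads for one name and picks the first
    # candidate whose positional/keyword arguments pass its type checkers,
    # returning the C++ binding name plus the ordered argument list.
    schemas = [
      {"func" => "sum(Tensor self, *, ScalarType? dtype=None) -> Tensor"},
      {"func" => "sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"}
    ]
    overloads = schemas.map { |s| Torch::Native::Function.new(s) }
    parser = Torch::Native::Parser.new(overloads)

    x = Torch.tensor([[1.0, 2.0], [3.0, 4.0]])
    parser.parse([x], {})
    # => {name: "_sum", args: [x, nil]}
    parser.parse([x, 0], {keepdim: true})
    # => {name: "_sum_dim_intlist", args: [x, [0], true, nil]}  (dim expanded to int[1])
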