torch-rb 0.3.7 → 0.4.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,178 +0,0 @@
1
- require "yaml"
2
- # use require_relative for
3
- # rake generate:function (without bundle)
4
- require_relative "function"
5
-
6
module Torch
  module Native
    # Generates the C++ binding sources (<type>_functions.hpp / .cpp under
    # ext/torch) from the function definitions in native_functions.yaml.
    module Generator
      # Argument types the generator cannot bind yet; any function using one
      # of them is set aside instead of being generated.
      SKIP_ARG_TYPES = ["Layout", "Storage", "ConstQuantizerPtr"].freeze

      # Exact-match mapping from a native_functions.yaml argument type to the
      # C++ parameter type used in the generated wrapper signature.
      # Pattern-based types (int[N], Tensor(a!)) are handled in .cpp_type.
      CPP_TYPES = {
        "Tensor" => "const Tensor &",
        # TODO better signature
        "Tensor?" => "OptionalTensor",
        "ScalarType?" => "torch::optional<ScalarType>",
        # TODO make optional
        "Tensor[]" => "std::vector<Tensor>",
        "Tensor?[]" => "std::vector<Tensor>",
        "int" => "int64_t",
        "int?" => "torch::optional<int64_t>",
        "float?" => "torch::optional<double>",
        "bool?" => "torch::optional<bool>",
        "Scalar?" => "torch::optional<torch::Scalar>",
        "float" => "double",
        "str" => "std::string",
        "TensorOptions" => "const torch::TensorOptions &",
        "Layout?" => "torch::optional<Layout>",
        "Device?" => "torch::optional<Device>",
        # these are passed through unchanged
        "Scalar" => "Scalar",
        "bool" => "bool",
        "ScalarType" => "ScalarType",
        "Layout" => "Layout",
        "Device" => "Device",
        "Storage" => "Storage"
      }.freeze

      class << self
        # Writes the torch, tensor and nn binding files.
        # torch and nn functions are registered with define_singleton_method,
        # tensor functions with define_method.
        def generate_cpp_functions
          groups = grouped_functions
          generate_cpp_file("torch", :define_singleton_method, groups[:torch])
          generate_cpp_file("tensor", :define_method, groups[:tensor])
          generate_cpp_file("nn", :define_singleton_method, groups[:nn])
        end

        # Filters the native functions down to the ones that can be generated
        # and splits them by destination.
        #
        # Returns a Hash with keys :torch, :tensor and :nn, each an Array of
        # function objects. A function whose variants include both "function"
        # and "method" appears in both :torch and :tensor.
        def grouped_functions
          # reject (not reject!) so the memoized list held by #functions is
          # left intact for later callers
          candidates = functions.reject do |f|
            f.ruby_name.start_with?("_") ||
              f.ruby_name.include?("_backward") ||
              f.args.any? { |a| a[:type].include?("Dimname") }
          end

          # functions in the first group are known but not supported yet;
          # inspect them here (e.g. puts f.func) when working on support
          _todo_functions, supported = candidates.partition { |f| unsupported?(f) }

          nn_functions, other_functions = supported.partition { |f| f.python_module == "nn" }
          torch_functions = other_functions.select { |f| f.variants.include?("function") }
          tensor_functions = other_functions.select { |f| f.variants.include?("method") }

          {torch: torch_functions, tensor: tensor_functions, nn: nn_functions}
        end

        private

        # True when the function has to be set aside for now. The name-based
        # checks apply to the function as a whole, so they are evaluated once
        # and also cover functions that declare no arguments.
        def unsupported?(f)
          # call to 'range' is ambiguous
          return true if f.cpp_name == "_range"

          # native_functions.yaml is missing size argument for normal
          # https://pytorch.org/cppdocs/api/function_namespacetorch_1a80253fe5a3ded4716ec929a348adb4b9.html
          return true if f.base_name == "normal" && !f.out?

          f.args.any? { |a| SKIP_ARG_TYPES.any? { |sa| a[:type].include?(sa) } }
        end

        # Maps a yaml argument type to its C++ parameter type.
        # Raises RuntimeError for a type that has no known mapping.
        def cpp_type(type)
          CPP_TYPES.fetch(type) do
            case type
            when /\Aint\[/
              "std::vector<int64_t>"
            when /Tensor\(\S!?\)/
              "Tensor &"
            else
              raise "Unknown type: #{type}"
            end
          end
        end

        # Builds one C++ parameter declaration ("<type> <name>") for an
        # argument hash with :name and :type keys.
        def cpp_arg(arg)
          t = cpp_type(arg[:type])
          # integer arguments named reduction use the MyReduction type instead
          t = "MyReduction" if arg[:name] == :reduction && t == "int64_t"
          # attach the reference marker to the name: "Tensor & x" -> "Tensor &x"
          [t, arg[:name]].join(" ").sub("& ", "&")
        end

        # Renders and writes <type>_functions.hpp and <type>_functions.cpp.
        #
        # type       - file/function name prefix ("torch", "tensor" or "nn")
        # def_method - :define_method (calls are dispatched on self and the
        #              self argument is dropped from the call) or
        #              :define_singleton_method (calls go through torch::)
        # functions  - Array of function objects to generate wrappers for
        def generate_cpp_file(type, def_method, functions)
          hpp_template = <<~TEMPLATE
            // generated by rake generate:functions
            // do not edit by hand

            #pragma once

            void add_%{type}_functions(Module m);
          TEMPLATE

          cpp_template = <<~TEMPLATE
            // generated by rake generate:functions
            // do not edit by hand

            #include <torch/torch.h>
            #include <rice/Module.hpp>
            #include "templates.hpp"

            %{functions}

            void add_%{type}_functions(Module m) {
            %{add_functions}
            }
          TEMPLATE

          instance_level = def_method == :define_method
          prefix = instance_level ? "self." : "torch::"

          cpp_defs = []
          add_defs = []
          functions.sort_by(&:cpp_name).each do |func|
            fargs = func.args.dup
            fargs << {name: :options, type: "TensorOptions"} if func.tensor_options

            cpp_args = fargs.map { |a| cpp_arg(a) }

            dispatch = func.out? ? "#{func.base_name}_out" : func.base_name
            call_args = fargs.map { |a| a[:name] }
            # out arguments are declared last but are passed first in the call
            call_args.unshift(*call_args.pop(func.out_size)) if func.out?
            call_args.delete(:self) if instance_level

            body = "#{prefix}#{dispatch}(#{call_args.join(", ")})"

            if func.cpp_name == "_fill_diagonal_"
              body = "to_ruby<torch::Tensor>(#{body})"
            elsif !func.ret_void?
              body = "wrap(#{body})"
            end

            return_type = func.ret_void? ? "void" : "Object"
            cpp_defs << [
              "// #{func.func}",
              "static #{return_type} #{type}#{func.cpp_name}(#{cpp_args.join(", ")})",
              "{",
              "return #{body};",
              "}"
            ].join("\n")

            add_defs << "m.#{def_method}(\"#{func.cpp_name}\", #{type}#{func.cpp_name});"
          end

          hpp_contents = hpp_template % {type: type}
          cpp_contents = cpp_template % {type: type, functions: cpp_defs.join("\n\n"), add_functions: add_defs.join("\n ")}

          dir = File.expand_path("../../../ext/torch", __dir__)
          File.write(File.join(dir, "#{type}_functions.hpp"), hpp_contents)
          File.write(File.join(dir, "#{type}_functions.cpp"), cpp_contents)
        end

        # All functions declared in native_functions.yaml, wrapped in
        # Function objects. Loaded once and memoized.
        def functions
          @native_functions ||= YAML.load_file(path).map { |f| Function.new(f) }
        end

        # Absolute path of native_functions.yaml (next to this file).
        def path
          File.expand_path("native_functions.yaml", __dir__)
        end
      end
    end
  end
end
@@ -1,117 +0,0 @@
1
module Torch
  module Native
    # Resolves a Ruby call (positional arguments plus keyword options) to one
    # of several native function overloads that share a Ruby name.
    class Parser
      # functions - non-empty Array of overload objects. Each is expected to
      #             respond to ruby_name, cpp_name, args, arg_names,
      #             arg_defaults, arg_types, arg_checkers, int_array_lengths,
      #             out?, out_size and tensor_options.
      def initialize(functions)
        @functions = functions
        @name = @functions.first.ruby_name
        @min_args = @functions.map { |f| f.args.count { |a| a[:pos] && !a[:has_default] } }.min
        @max_args = @functions.map { |f| f.args.count { |a| a[:pos] } }.max
        # when every overload starts with an int[] argument, leading integers
        # may be given individually instead of as an array
        @int_array_first = @functions.all? { |c| c.args.first && c.args.first[:type] == "int[]" }
      end

      # Picks the first overload whose argument checkers accept the call.
      #
      # args    - Array of positional arguments
      # options - Hash of keyword arguments (Symbol keys)
      #
      # Neither args nor options is modified.
      #
      # Returns {name: <cpp_name>, args: <ordered argument Array>} on success
      # or {error: <message>} when the call matches no overload.
      # Raises ArgumentError when leading integers (collected for an int[]
      # argument) are followed by other positional arguments.
      # Raises Error when no overload is left to try.
      #
      # TODO improve performance
      # possibly move to C++ (see python_arg_parser.cpp)
      def parse(args, options)
        # work on copies: the steps below shift/unshift args and rewrite :out
        args = args.dup
        options = options.dup
        candidates = @functions.dup

        # TODO check candidates individually to see if they match
        collapse_leading_ints!(args, candidates) if @int_array_first

        # TODO account for args passed as options here
        if args.size < @min_args || args.size > @max_args
          expected = @max_args == @min_args ? @min_args.to_s : "#{@min_args}..#{@max_args}"
          return {error: "wrong number of arguments (given #{args.size}, expected #{expected})"}
        end

        candidates.reject! { |f| args.size > f.args.size }

        if options[:out]
          # handle out with multiple
          # there should only be one match, so safe to modify all
          out_func = candidates.find(&:out?)
          if out_func && out_func.out_size > 1
            # the out arguments are the last out_size declared arguments;
            # spread the values given as out: [...] over their names
            out_names = out_func.args.last(out_func.out_size).map { |a| a[:name] }
            out_names.zip(options.delete(:out)).each do |k, v|
              options[k] = v
            end
            candidates = [out_func]
          end
        else
          # exclude functions missing required options
          candidates.reject!(&:out?)
        end

        final_values = nil
        func = nil

        while (func = candidates.shift)
          values = bind_values(func, args, options)
          problem = first_problem(func, values)

          unless problem
            final_values = values
            break
          end

          # only the last remaining overload reports its mismatch
          next unless candidates.empty?

          kind, key = problem
          # TODO show all bad keywords at once like Ruby?
          return {error: "unknown keyword: #{key}"} if kind == :unknown_keyword

          type = func.arg_types[key]
          label = key == :self ? :input : key
          return {error: "#{@name}(): argument '#{label}' must be #{type}"}
        end

        unless final_values
          raise Error, "This should never happen. Please report a bug with #{@name}."
        end

        call_args = func.arg_names.map { |k| final_values[k] }
        # NOTE(review): 6 is presumably the ScalarType code for float32 - confirm
        call_args << TensorOptions.new.dtype(6) if func.tensor_options
        {
          name: func.cpp_name,
          args: call_args
        }
      end

      private

      # Replaces a run of leading Integer arguments with a single Array
      # holding them (in place, on the given args Array).
      def collapse_leading_ints!(args, candidates)
        ints = []
        ints << args.shift while args.first.is_a?(Integer)
        return if ints.empty?

        if args.any?
          raise ArgumentError, "argument '#{candidates.first.args.first[:name]}' must be array of ints, but found element of type #{args.first.class.name} at pos #{ints.size + 1}"
        end
        args.unshift(ints)
      end

      # Builds the name => value Hash for one overload: positional arguments
      # first, then options, then defaults for anything still missing.
      # A lone Integer given for a fixed-length int array is repeated to fill it.
      # TODO use array instead of hash?
      def bind_values(func, args, options)
        values = {}
        names = func.arg_names
        args.each_with_index do |a, i|
          values[names[i]] = a
        end
        options.each do |k, v|
          values[k] = v
        end
        func.arg_defaults.each do |k, v|
          values[k] = v unless values.key?(k)
        end
        func.int_array_lengths.each do |k, len|
          values[k] = [values[k]] * len if values[k].is_a?(Integer)
        end
        values
      end

      # Returns nil when every value is accepted by the overload, otherwise
      # [:unknown_keyword, key] or [:bad_type, key] for the first offending key.
      def first_problem(func, values)
        checkers = func.arg_checkers
        values.each_key do |k|
          return [:unknown_keyword, k] unless checkers.key?(k)
          return [:bad_type, k] unless checkers[k].call(values[k])
        end
        nil
      end
    end
  end
end