torch-rb 0.3.7 → 0.4.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,178 +0,0 @@
1
- require "yaml"
2
- # use require_relative for
3
- # rake generate:function (without bundle)
4
- require_relative "function"
5
-
6
module Torch
  module Native
    # Builds the C++ binding sources (<type>_functions.hpp / .cpp) from the
    # function list loaded out of native_functions.yaml.
    module Generator
      # A function with an argument whose type contains one of these names is
      # not generated yet.
      SKIP_ARG_TYPES = ["Layout", "Storage", "ConstQuantizerPtr"].freeze

      # Argument types (as spelled in native_functions.yaml) that map to a
      # fixed C++ parameter type. Pattern-based types are handled in cpp_type.
      CPP_TYPES = {
        "Tensor" => "const Tensor &",
        # TODO better signature
        "Tensor?" => "OptionalTensor",
        "ScalarType?" => "torch::optional<ScalarType>",
        # TODO make optional
        "Tensor[]" => "std::vector<Tensor>",
        "Tensor?[]" => "std::vector<Tensor>",
        "int" => "int64_t",
        "int?" => "torch::optional<int64_t>",
        "float?" => "torch::optional<double>",
        "bool?" => "torch::optional<bool>",
        "Scalar?" => "torch::optional<torch::Scalar>",
        "float" => "double",
        "str" => "std::string",
        "TensorOptions" => "const torch::TensorOptions &",
        "Layout?" => "torch::optional<Layout>",
        "Device?" => "torch::optional<Device>",
        # these are passed through under their own name
        "Scalar" => "Scalar",
        "bool" => "bool",
        "ScalarType" => "ScalarType",
        "Layout" => "Layout",
        "Device" => "Device",
        "Storage" => "Storage"
      }.freeze

      class << self
        # Writes the torch, tensor and nn binding files.
        # Tensor functions are registered with define_method, the other two
        # groups with define_singleton_method.
        def generate_cpp_functions
          groups = grouped_functions
          generate_cpp_file("torch", :define_singleton_method, groups[:torch])
          generate_cpp_file("tensor", :define_method, groups[:tensor])
          generate_cpp_file("nn", :define_singleton_method, groups[:nn])
        end

        # Filters the native functions and splits them by destination.
        #
        # Returns a Hash with the keys :torch, :tensor and :nn. A function whose
        # variants include both "function" and "method" appears under both
        # :torch and :tensor.
        def grouped_functions
          # reject (not reject!) so the memoized list behind #functions is not
          # emptied as a side effect of grouping
          candidates = functions.reject do |f|
            f.ruby_name.start_with?("_") ||
              f.ruby_name.include?("_backward") ||
              f.args.any? { |a| a[:type].include?("Dimname") }
          end

          supported = candidates.reject { |f| todo?(f) }

          nn_functions, other_functions = supported.partition { |f| f.python_module == "nn" }
          torch_functions = other_functions.select { |f| f.variants.include?("function") }
          tensor_functions = other_functions.select { |f| f.variants.include?("method") }

          {torch: torch_functions, tensor: tensor_functions, nn: nn_functions}
        end

        private

        # True for functions that are recognised but not generated yet.
        # The name-based checks apply to the function as a whole, so they are
        # evaluated once rather than once per argument.
        def todo?(func)
          # call to 'range' is ambiguous
          return true if func.cpp_name == "_range"

          # native_functions.yaml is missing size argument for normal
          # https://pytorch.org/cppdocs/api/function_namespacetorch_1a80253fe5a3ded4716ec929a348adb4b9.html
          return true if func.base_name == "normal" && !func.out?

          func.args.any? do |a|
            SKIP_ARG_TYPES.any? { |skipped| a[:type].include?(skipped) }
          end
        end

        # Renders and writes <type>_functions.hpp and <type>_functions.cpp
        # into ext/torch (resolved relative to this file's directory).
        #
        # type       - file/function name prefix ("torch", "tensor" or "nn")
        # def_method - :define_method or :define_singleton_method
        # functions  - function objects to emit, sorted here by cpp_name
        def generate_cpp_file(type, def_method, functions)
          hpp_template = <<-TEMPLATE
// generated by rake generate:functions
// do not edit by hand

#pragma once

void add_%{type}_functions(Module m);
          TEMPLATE

          cpp_template = <<-TEMPLATE
// generated by rake generate:functions
// do not edit by hand

#include <torch/torch.h>
#include <rice/Module.hpp>
#include "templates.hpp"

%{functions}

void add_%{type}_functions(Module m) {
%{add_functions}
}
          TEMPLATE

          cpp_defs = []
          add_defs = []
          functions.sort_by(&:cpp_name).each do |func|
            cpp_defs << cpp_definition(type, def_method, func)
            add_defs << "m.#{def_method}(\"#{func.cpp_name}\", #{type}#{func.cpp_name});"
          end

          hpp_contents = hpp_template % {type: type}
          cpp_contents = cpp_template % {type: type, functions: cpp_defs.join("\n\n"), add_functions: add_defs.join("\n ")}

          out_dir = File.expand_path("../../../ext/torch", __dir__)
          File.write("#{out_dir}/#{type}_functions.hpp", hpp_contents)
          File.write("#{out_dir}/#{type}_functions.cpp", cpp_contents)
        end

        # Returns the C++ source of the static wrapper for one function.
        def cpp_definition(type, def_method, func)
          fargs = func.args.dup #.select { |a| a[:type] != "Generator?" }
          fargs << {name: :options, type: "TensorOptions"} if func.tensor_options

          cpp_args = fargs.map do |a|
            t = cpp_type(a[:type])
            t = "MyReduction" if a[:name] == :reduction && t == "int64_t"
            # "Tensor &self" rather than "Tensor & self"
            [t, a[:name]].join(" ").sub("& ", "&")
          end

          dispatch = func.out? ? "#{func.base_name}_out" : func.base_name
          call_args = fargs.map { |a| a[:name] }
          # the *_out call takes the trailing out arguments first
          call_args.unshift(*call_args.pop(func.out_size)) if func.out?
          # for methods, self is the receiver instead of an argument
          call_args.delete(:self) if def_method == :define_method

          prefix = def_method == :define_method ? "self." : "torch::"
          body = "#{prefix}#{dispatch}(#{call_args.join(", ")})"

          if func.cpp_name == "_fill_diagonal_"
            body = "to_ruby<torch::Tensor>(#{body})"
          elsif !func.ret_void?
            body = "wrap(#{body})"
          end

          [
            "// #{func.func}",
            "static #{func.ret_void? ? "void" : "Object"} #{type}#{func.cpp_name}(#{cpp_args.join(", ")})",
            "{",
            "return #{body};",
            "}"
          ].join("\n")
        end

        # Maps a native_functions.yaml argument type to its C++ parameter type.
        # Raises RuntimeError for a type that has no mapping.
        def cpp_type(type)
          CPP_TYPES.fetch(type) do
            case type
            when /\Aint\[/
              "std::vector<int64_t>"
            when /Tensor\(\S!?\)/
              # annotated tensors such as Tensor(a!)
              "Tensor &"
            else
              raise "Unknown type: #{type}"
            end
          end
        end

        # Function objects for every entry of native_functions.yaml (memoized).
        def functions
          @native_functions ||= YAML.load_file(path).map { |f| Function.new(f) }
        end

        # Absolute path of native_functions.yaml, which sits next to this file.
        def path
          File.expand_path("native_functions.yaml", __dir__)
        end
      end
    end
  end
end
@@ -1,117 +0,0 @@
1
module Torch
  module Native
    # Chooses which overload of a native function a Ruby call maps to and
    # arranges the call's values in that overload's argument order.
    class Parser
      # functions - the overloads sharing one Ruby name. Each must respond to
      #             ruby_name, cpp_name, args, arg_names, arg_defaults,
      #             arg_checkers, arg_types, int_array_lengths, out?, out_size
      #             and tensor_options.
      def initialize(functions)
        @functions = functions
        @name = @functions.first.ruby_name
        # fewest required / most accepted positional arguments over all overloads
        @min_args = @functions.map { |f| f.args.count { |a| a[:pos] && !a[:has_default] } }.min
        @max_args = @functions.map { |f| f.args.count { |a| a[:pos] } }.max
        # true when every overload takes an int[] first, so leading integers
        # can be gathered into that array
        @int_array_first = @functions.all? { |c| c.args.first && c.args.first[:type] == "int[]" }
      end

      # Matches positional args and keyword options against the overloads.
      #
      # args    - Array of positional arguments (not modified)
      # options - Hash of keyword arguments (not modified)
      #
      # Returns {name: cpp_name, args: ordered_values} for the first overload
      # that accepts the call, or {error: message} when none does.
      # Raises ArgumentError when leading integers are followed by other
      # arguments for an int[]-first function.
      # Raises Error when no overload is left to try and no error was produced.
      #
      # TODO improve performance
      # possibly move to C++ (see python_arg_parser.cpp)
      def parse(args, options)
        # work on copies: both collections are rewritten below
        args = args.dup
        options = options.dup
        candidates = @functions.dup

        # TODO check candidates individually to see if they match
        pack_leading_ints!(args) if @int_array_first

        # TODO account for args passed as options here
        if args.size < @min_args || args.size > @max_args
          expected = @max_args == @min_args ? @min_args.to_s : "#{@min_args}..#{@max_args}"
          return {error: "wrong number of arguments (given #{args.size}, expected #{expected})"}
        end

        candidates.reject! { |f| args.size > f.args.size }

        # handle out with multiple
        # there should only be one match, so safe to modify all
        if options[:out]
          out_func = candidates.find(&:out?)
          if out_func && out_func.out_size > 1
            # spread out: [a, b, ...] over the names of the trailing out
            # arguments; use out_size so overloads with more than two outputs
            # are mapped completely
            out_names = out_func.args.last(out_func.out_size).map { |a| a[:name] }
            out_names.zip(options.delete(:out)).each do |k, v|
              options[k] = v
            end
            candidates = [out_func]
          end
        else
          # exclude functions missing required options
          candidates.reject!(&:out?)
        end

        final_values = nil

        # check args; an error is only reported for the last candidate tried
        while (func = candidates.shift)
          good = true
          values = candidate_values(func, args, options)
          arg_checkers = func.arg_checkers

          values.each_key do |k|
            unless arg_checkers.key?(k)
              good = false
              # TODO show all bad keywords at once like Ruby?
              return {error: "unknown keyword: #{k}"} if candidates.empty?
              break
            end

            unless arg_checkers[k].call(values[k])
              good = false
              if candidates.empty?
                t = func.arg_types[k]
                k = :input if k == :self
                return {error: "#{@name}(): argument '#{k}' must be #{t}"}
              end
              break
            end
          end

          if good
            final_values = values
            break
          end
        end

        unless final_values
          raise Error, "This should never happen. Please report a bug with #{@name}."
        end

        args = func.arg_names.map { |k| final_values[k] }
        # NOTE(review): 6 is presumably the dtype code for 32-bit float - confirm
        args << TensorOptions.new.dtype(6) if func.tensor_options
        {
          name: func.cpp_name,
          args: args
        }
      end

      private

      # Replaces a run of leading Integer arguments with one array holding
      # them, e.g. (2, 3) becomes ([2, 3]). Modifies args in place.
      def pack_leading_ints!(args)
        int_args = []
        int_args << args.shift while args.first.is_a?(Integer)
        return if int_args.empty?

        if args.any?
          raise ArgumentError, "argument '#{@functions.first.args.first[:name]}' must be array of ints, but found element of type #{args.first.class.name} at pos #{int_args.size + 1}"
        end
        args.unshift(int_args)
      end

      # Builds the name => value hash for one candidate: positional values
      # first, then options, then defaults for anything still missing.
      # TODO use array instead of hash?
      def candidate_values(func, args, options)
        values = {}
        args.each_with_index do |a, i|
          values[func.arg_names[i]] = a
        end
        values.merge!(options)
        func.arg_defaults.each do |k, v|
          values[k] = v unless values.key?(k)
        end
        # a single Integer given for a fixed-length int array is repeated
        func.int_array_lengths.each do |k, len|
          values[k] = [values[k]] * len if values[k].is_a?(Integer)
        end
        values
      end
    end
  end
end