torch-rb 0.3.3 → 0.4.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,48 +0,0 @@
1
# Native functions are exposed through a generic Ruby interface
# (*args, **options), and this module decides at call time which
# C++ function to invoke.
#
# The indirection is needed because LibTorch relies on function
# overloading, which neither Ruby nor Python supports.
#
# PyTorch takes the same approach, except that its parser/dispatcher
# is written in C++.
#
# Generating one Ruby method per overload would also work, but resolving
# the overload at runtime lets positional and keyword arguments be used
# interchangeably, as in Python, which makes ported code easier to write.

module Torch
  module Native
    module Dispatcher
      class << self
        # Defines Ruby entry points for the generated native functions:
        # singleton methods on Torch and Torch::NN, instance methods on
        # Torch::Tensor.
        def bind
          grouped = Generator.grouped_functions
          bind_functions(::Torch, :define_singleton_method, grouped[:torch])
          bind_functions(::Torch::Tensor, :define_method, grouped[:tensor])
          bind_functions(::Torch::NN, :define_singleton_method, grouped[:nn])
        end

        # Defines one Ruby method per distinct ruby_name in `functions`,
        # backed by a Parser that picks the matching overload per call.
        #
        # context    - the module/class receiving the methods
        # def_method - :define_method or :define_singleton_method
        # functions  - Function objects to bind
        #
        # Names that already exist on `context` are left alone, except "clone".
        # Returns the [name, overloads] pairs in name order.
        def bind_functions(context, def_method, functions)
          for_instances = def_method == :define_method
          named_groups = functions.group_by(&:ruby_name).sort_by(&:first)

          named_groups.each do |name, overloads|
            if for_instances
              # Instance methods receive the tensor as the receiver, so "self"
              # is dropped from the argument list. Fresh Function objects are
              # built first (presumably so the Function instances shared with
              # the torch group keep their "self" argument).
              overloads.map! { |f| Function.new(f.function) }
              overloads.each { |f| f.args.reject! { |a| a[:name] == "self" } }
            end

            already_defined =
              if for_instances
                context.method_defined?(name)
              else
                context.respond_to?(name)
              end
            next if already_defined && name != "clone"

            overload_parser = Parser.new(overloads)
            context.send(def_method, name) do |*args, **options|
              parsed = overload_parser.parse(args, options)
              error = parsed[:error]
              raise ArgumentError, error if error
              send(parsed[:name], *parsed[:args])
            end
          end
        end
      end
    end
  end
end

Torch::Native::Dispatcher.bind
@@ -1,115 +0,0 @@
1
module Torch
  module Native
    # Wraps one entry of native_functions.yaml and derives from its schema
    # string the Ruby name, the C++ binding name, and the parsed argument list.
    class Function
      # Trailing keyword-only arguments that the C++ bindings take as a single
      # TensorOptions parameter instead of individually.
      TENSOR_OPTIONS_STR = ", *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None)"
      # Argument types that are left out of #args.
      SKIPPED_ARG_TYPES = ["Generator?", "MemoryFormat", "MemoryFormat?"].freeze
      private_constant :TENSOR_OPTIONS_STR, :SKIPPED_ARG_TYPES

      # function       - the YAML entry this object was built from (not modified)
      # tensor_options - true when the schema ended with the TensorOptions arguments
      # func           - the schema string with those arguments removed
      attr_reader :function, :tensor_options, :func

      # function - a native_functions.yaml entry; its "func" key holds the schema.
      def initialize(function)
        @function = function
        schema = function["func"]
        @tensor_options = schema.include?(TENSOR_OPTIONS_STR)
        # Strip the options into a copy instead of editing the YAML entry in place.
        # Building another Function from the same hash (as the dispatcher does for
        # tensor methods) therefore still detects the tensor options.
        @func = schema.sub(TENSOR_OPTIONS_STR, ")")
      end

      # Schema name (the text before "("), including any ".overload" suffix.
      def name
        @name ||= func.split("(", 2).first
      end

      def python_module
        @python_module ||= @function["python_module"]
      end

      # Variants listed in the entry; ["function"] when none are given.
      def variants
        @variants ||= (@function["variants"] || "function").split(", ")
      end

      # Parsed arguments, one Hash per argument with keys
      # :name, :type, :default, :pos and :has_default.
      # :pos is false for arguments after the "*" marker.
      # Arguments typed Generator?, MemoryFormat or MemoryFormat? are omitted.
      # Raises RuntimeError on a default value it does not recognize.
      def args
        @args ||= parse_args
      end

      # Number of mutable (a!-style) outputs in the return type.
      def out_size
        @out_size ||= func.split("->").last.count("!")
      end

      # Number of comma-separated values in the return type.
      def ret_size
        @ret_size ||= func.split("->").last.split(", ").size
      end

      # True for out= variants: they write to outputs but are not in-place (trailing "_").
      def out?
        out_size > 0 && base_name[-1] != "_"
      end

      # Ruby method name: "foo_" becomes "foo!" and "is_foo" becomes "foo?".
      def ruby_name
        @ruby_name ||=
          if base_name.end_with?("_")
            "#{base_name[0..-2]}!"
          elsif base_name.start_with?("is_")
            "#{base_name[3..-1]}?"
          else
            base_name
          end
      end

      # Name the generated C++ binding is registered under, e.g. "add.Tensor" -> "_add_tensor".
      def cpp_name
        @cpp_name ||= "_" + name.downcase.sub(".", "_")
      end

      # Name without the ".overload" suffix.
      def base_name
        @base_name ||= name.split(".").first
      end

      private

      # Splits the parenthesized argument list of the schema into argument hashes.
      def parse_args
        positional = true
        args_str = func.split("(", 2).last.split(") ->").first

        args_str.split(", ").each_with_object([]) do |arg_str, parsed|
          # a bare "*" marks the start of keyword-only arguments
          if arg_str == "*"
            positional = false
            next
          end

          type, _, rest = arg_str.rpartition(" ")
          arg_name, default_str = rest.split("=")
          has_default = !default_str.nil?
          default = default_str ? parse_default(default_str) : nil

          next if SKIPPED_ARG_TYPES.include?(type)

          parsed << {name: arg_name, type: type, default: default, pos: positional, has_default: has_default}
        end
      end

      # Converts a default value written in schema syntax into a Ruby value.
      def parse_default(d)
        case d
        when "True" then true
        when "False" then false
        when "None" then nil
        when /\A-?\d+\z/ then d.to_i
        when "[]" then []
        when "[0,1]" then [0, 1]
        # scientific notation (e.g. 1e-05) and decimals, optionally negative
        when /\A-?\d+(?:\.\d+)?e[-+]?\d+\z/, /\A-?\d+\.\d+\z/ then d.to_f
        when "Mean" then "mean"
        when "contiguous_format" then d
        when "long" then :long
        else raise "Unknown default: #{d}"
        end
      end
    end
  end
end
@@ -1,163 +0,0 @@
1
- require "yaml"
2
- # use require_relative for
3
- # rake generate:function (without bundle)
4
- require_relative "function"
5
-
6
module Torch
  module Native
    # Loads native_functions.yaml, filters it down to the functions that can
    # currently be bound, and writes the C++ Rice bindings for them.
    module Generator
      # Optional ("?") argument types the generated bindings can handle.
      SUPPORTED_OPTIONAL_TYPES = ["Tensor?", "Generator?", "int?", "ScalarType?", "Tensor?[]"].freeze
      # A function is skipped if any argument type contains one of these.
      SKIP_ARG_TYPES = ["bool[3]", "Dimname", "Layout", "Storage", "ConstQuantizerPtr"].freeze
      private_constant :SUPPORTED_OPTIONAL_TYPES, :SKIP_ARG_TYPES

      class << self
        # Writes ext/torch/{torch,tensor,nn}_functions.{hpp,cpp}.
        def generate_cpp_functions
          functions = grouped_functions
          generate_cpp_file("torch", :define_singleton_method, functions[:torch])
          generate_cpp_file("tensor", :define_method, functions[:tensor])
          generate_cpp_file("nn", :define_singleton_method, functions[:nn])
        end

        # Returns {torch:, tensor:, nn:} arrays of bindable Function objects.
        # Does not modify the memoized function list.
        def grouped_functions
          # drop private, backward and named-tensor (Dimname) functions
          supported = functions.reject do |f|
            f.ruby_name.start_with?("_") ||
              f.ruby_name.end_with?("_backward") ||
              f.args.any? { |a| a[:type].include?("Dimname") }
          end

          # set aside functions whose signatures cannot be bound yet
          _todo_functions, bindable = supported.partition { |f| todo?(f) }

          nn_functions, other_functions = bindable.partition { |f| f.python_module == "nn" }
          torch_functions = other_functions.select { |f| f.variants.include?("function") }
          tensor_functions = other_functions.select { |f| f.variants.include?("method") }

          {torch: torch_functions, tensor: tensor_functions, nn: nn_functions}
        end

        private

        # True when func cannot be bound by the generator yet.
        def todo?(func)
          # call to 'range' is ambiguous
          return true if func.cpp_name == "_range"
          # native_functions.yaml is missing size argument for normal
          # https://pytorch.org/cppdocs/api/function_namespacetorch_1a80253fe5a3ded4716ec929a348adb4b9.html
          return true if func.base_name == "normal" && !func.out?

          func.args.any? do |a|
            type = a[:type]
            (type.include?("?") && !SUPPORTED_OPTIONAL_TYPES.include?(type)) ||
              SKIP_ARG_TYPES.any? { |sa| type.include?(sa) }
          end
        end

        # Writes the header and source file registering `functions` on a Rice Module.
        def generate_cpp_file(type, def_method, functions)
          hpp_template = <<-TEMPLATE
// generated by rake generate:functions
// do not edit by hand

#pragma once

void add_%{type}_functions(Module m);
          TEMPLATE

          cpp_template = <<-TEMPLATE
// generated by rake generate:functions
// do not edit by hand

#include <torch/torch.h>
#include <rice/Module.hpp>
#include "templates.hpp"

void add_%{type}_functions(Module m) {
  m
    %{functions};
}
          TEMPLATE

          cpp_defs = functions.sort_by(&:cpp_name).map { |func| cpp_definition(func, def_method) }

          hpp_contents = hpp_template % {type: type}
          cpp_contents = cpp_template % {type: type, functions: cpp_defs.join("\n    ")}

          path = File.expand_path("../../../ext/torch", __dir__)
          File.write("#{path}/#{type}_functions.hpp", hpp_contents)
          File.write("#{path}/#{type}_functions.cpp", cpp_contents)
        end

        # Builds one chained `.define_method(...)` / `.define_singleton_method(...)`
        # call that registers func's C++ lambda under func.cpp_name.
        def cpp_definition(func, def_method)
          fargs = func.args.dup
          fargs << {name: "options", type: "TensorOptions"} if func.tensor_options

          cpp_args = fargs.map do |a|
            t = cpp_type(a[:type])
            # reduction arrives from Ruby as a String; MyReduction is presumably
            # the C++ wrapper type that converts it
            t = "MyReduction" if a[:name] == "reduction" && t == "int64_t"
            [t, a[:name]].join(" ").sub("& ", "&")
          end

          instance_method = def_method == :define_method
          dispatch = func.out? ? "#{func.base_name}_out" : func.base_name
          args = fargs.map { |a| a[:name] }
          # the C++ *_out functions take the output tensors first
          args.unshift(*args.pop(func.out_size)) if func.out?
          args.delete("self") if instance_method

          prefix = instance_method ? "self." : "torch::"
          body = "#{prefix}#{dispatch}(#{args.join(", ")})"
          # TODO check type as well
          body = "wrap(#{body})" if func.ret_size > 1

          ".#{def_method}(
    \"#{func.cpp_name}\",
    *[](#{cpp_args.join(", ")}) {
      return #{body};
    })"
        end

        # Maps a native_functions.yaml argument type to a C++ parameter type.
        def cpp_type(type)
          case type
          when "Tensor" then "const Tensor &"
          when "Tensor?" then "OptionalTensor" # TODO better signature
          when "ScalarType?" then "torch::optional<ScalarType>"
          when "Tensor[]" then "TensorList"
          when "Tensor?[]" then "TensorList" # TODO make optional
          when "int" then "int64_t"
          when "int?" then "torch::optional<int64_t>"
          when "float" then "double"
          when /\Aint\[/ then "IntArrayRef"
          when /Tensor\(\S!?\)/ then "Tensor &"
          when "str" then "std::string"
          when "TensorOptions" then "const torch::TensorOptions &"
          else type
          end
        end

        # All entries of native_functions.yaml wrapped as Function objects (memoized).
        def functions
          @native_functions ||= YAML.load_file(path).map { |f| Function.new(f) }
        end

        def path
          File.expand_path("native_functions.yaml", __dir__)
        end
      end
    end
  end
end
@@ -1,140 +0,0 @@
1
module Torch
  module Native
    # Picks which overload of a native function to call for a given set of
    # Ruby arguments by type-checking them against each overload's schema.
    class Parser
      # functions - the Function overloads that share one Ruby name (non-empty)
      def initialize(functions)
        @functions = functions
        @name = @functions.first.ruby_name
        # bounds on how many positional arguments any overload accepts
        @min_args = @functions.map { |f| f.args.count { |a| a[:pos] && !a[:has_default] } }.min
        @max_args = @functions.map { |f| f.args.count { |a| a[:pos] } }.max
      end

      # Resolves a call to exactly one overload.
      #
      # args    - positional arguments; trailing nils are removed in place
      # options - keyword arguments; an Array passed as :out to a multi-output
      #           overload is spread over its out arguments in place
      #
      # Returns {name:, args:} (the C++ binding name and its full argument list)
      # or {error: message} for an invalid call.
      # Raises Torch::Error for unknown schema types or an ambiguous resolution.
      def parse(args, options)
        candidates = @functions.dup

        # trailing nils are treated as omitted arguments
        # (check emptiness, not any?, so false/nil-only argument lists are handled too)
        args.pop while !args.empty? && args.last.nil?

        # TODO account for args passed as options here
        if args.size < @min_args || args.size > @max_args
          expected = @min_args.to_s
          expected += "..#{@max_args}" if @max_args != @min_args
          return {error: "wrong number of arguments (given #{args.size}, expected #{expected})"}
        end

        candidates.reject! { |f| args.size > f.args.size }

        # exclude functions missing required options
        # TODO make more generic
        candidates.reject! { |f| f.out? && !options[:out] }

        # handle out with multiple outputs: spread out: [a, b, ...] over the
        # function's trailing out arguments (all out_size of them, not just two)
        # there should only be one match, so safe to modify all
        out_func = candidates.find(&:out?)
        if out_func && out_func.out_size > 1 && options[:out]
          out_names = out_func.args.last(out_func.out_size).map { |a| a[:name] }
          out_names.zip(options.delete(:out)).each do |out_name, tensor|
            options[out_name.to_sym] = tensor
          end
          candidates = [out_func]
        end

        # exclude functions where options don't match
        options.each_key do |k|
          candidates.select! { |func| func.args.any? { |a| a[:name] == k.to_s } }
          # TODO show all bad keywords at once like Ruby?
          return {error: "unknown keyword: #{k}"} if candidates.empty?
        end

        final_values = {}
        # a type error is reported to the user only when there is a single
        # overload to blame; otherwise the overload is just discarded
        single_candidate = candidates.size == 1

        # check args
        candidates.select! do |func|
          values = args.zip(func.args).map { |a, fa| [fa[:name], a] }.to_h
          values.merge!(options.map { |k, v| [k.to_s, v] }.to_h)
          func.args.each do |fa|
            values[fa[:name]] = fa[:default] if values[fa[:name]].nil?
          end

          arg_types = func.args.map { |a| [a[:name], a[:type]] }.to_h

          good = true
          values.each_key do |k|
            v = values[k]
            # drop alias annotations such as "Tensor(a!)" -> "Tensor"
            t = arg_types[k].split("(").first

            good =
              case t
              when "Tensor"
                v.is_a?(Tensor)
              when "Tensor?"
                v.nil? || v.is_a?(Tensor)
              when "Tensor[]", "Tensor?[]"
                v.is_a?(Array) && v.all? { |v2| v2.is_a?(Tensor) }
              when "int"
                # reduction is passed from Ruby as a string such as "mean"
                k == "reduction" ? v.is_a?(String) : v.is_a?(Integer)
              when "int?"
                v.is_a?(Integer) || v.nil?
              when "float"
                v.is_a?(Numeric)
              when /int\[.*\]/
                # a single Integer is expanded to fill a fixed-size list, e.g. int[2]
                if v.is_a?(Integer)
                  size = t[4..-2]
                  raise Error, "Unknown size: #{size}. Please report a bug with #{@name}." unless size =~ /\A\d+\z/
                  v = [v] * size.to_i
                  values[k] = v
                end
                v.is_a?(Array) && v.all? { |v2| v2.is_a?(Integer) }
              when "Scalar"
                v.is_a?(Numeric)
              when "ScalarType?"
                v.nil?
              when "bool"
                v == true || v == false
              when "str"
                v.is_a?(String)
              else
                raise Error, "Unknown argument type: #{arg_types[k]}. Please report a bug with #{@name}."
              end

            next if good

            if single_candidate
              k = "input" if k == "self"
              return {error: "#{@name}(): argument '#{k}' must be #{t}"}
            end
            break
          end

          final_values = values if good
          good
        end

        if candidates.size != 1
          raise Error, "This should never happen. Please report a bug with #{@name}."
        end

        func = candidates.first
        call_args = func.args.map { |a| final_values[a[:name]] }
        # NOTE(review): 6 is presumably c10::ScalarType::Float (the default dtype) — confirm
        call_args << TensorOptions.new.dtype(6) if func.tensor_options
        {
          name: func.cpp_name,
          args: call_args
        }
      end
    end
  end
end