torch-rb 0.3.4 → 0.4.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,48 +0,0 @@
1
# Native functions are exposed to Ruby through one generic signature
# (*args, **options); this module decides which generated C++ binding
# a given call should go to.
#
# The indirection is needed because LibTorch relies on C++ function
# overloading, which neither Ruby nor Python has.
#
# PyTorch takes the same approach, although its parser/dispatcher is
# written in C++.
#
# Generating one Ruby method per overload would also work, but resolving
# at call time lets positional and keyword arguments be used
# interchangeably, as in Python, which makes porting code easier.

module Torch
  module Native
    module Dispatcher
      class << self
        # Defines Ruby entry points for every group from
        # Generator.grouped_functions: module functions on Torch and
        # Torch::NN, instance methods on Torch::Tensor.
        def bind
          grouped = Generator.grouped_functions
          bind_functions(::Torch, :define_singleton_method, grouped[:torch])
          bind_functions(::Torch::Tensor, :define_method, grouped[:tensor])
          bind_functions(::Torch::NN, :define_singleton_method, grouped[:nn])
        end

        # Defines one Ruby method per ruby_name on +context+ (in name order).
        # Each method parses its arguments with a Parser built from all
        # overloads sharing that name, then sends the chosen C++ binding.
        #
        # @param context [Module] receiver the methods are defined on
        # @param def_method [Symbol] :define_method or :define_singleton_method
        # @param functions [Array<Function>] functions to expose
        def bind_functions(context, def_method, functions)
          instance_level = def_method == :define_method
          overloads_by_name = functions.group_by(&:ruby_name)

          overloads_by_name.keys.sort.each do |name|
            overloads = overloads_by_name[name]

            if instance_level
              # fresh Function objects, so dropping "self" here does not touch
              # the objects shared with the module-level (torch) group
              overloads = overloads
                .map { |f| Function.new(f.function) }
                .each { |f| f.args.reject! { |a| a[:name] == "self" } }
            end

            already_defined =
              if instance_level
                context.method_defined?(name)
              else
                context.respond_to?(name)
              end
            # hand-written Ruby definitions win, except for clone
            next if already_defined && name != "clone"

            parser = Parser.new(overloads)

            context.send(def_method, name) do |*args, **options|
              result = parser.parse(args, options)
              raise ArgumentError, result[:error] if result[:error]
              send(result[:name], *result[:args])
            end
          end
        end
      end
    end
  end
end

Torch::Native::Dispatcher.bind
@@ -1,115 +0,0 @@
1
module Torch
  module Native
    # Wraps one entry of native_functions.yaml and parses its "func" signature
    # (e.g. "add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor")
    # into the names and argument descriptions used by Generator and Parser.
    class Function
      # Trailing keyword arguments that are handed to C++ as a single
      # TensorOptions value; removed from #func and flagged via #tensor_options.
      TENSOR_OPTIONS_SIGNATURE = ", *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None)".freeze

      # Argument types that are left out of #args entirely.
      SKIPPED_ARG_TYPES = ["Generator?", "MemoryFormat", "MemoryFormat?"].freeze

      # function:       the raw YAML hash (not modified)
      # tensor_options: true when the signature ended with TENSOR_OPTIONS_SIGNATURE
      # func:           the signature with those keywords replaced by ")"
      attr_reader :function, :tensor_options, :func

      # @param function [Hash] YAML entry with a "func" key and optional
      #   "python_module" / "variants" keys. The hash is NOT modified, so the
      #   same hash can be used to build another Function (Dispatcher rebuilds
      #   tensor methods this way) and still see the original signature.
      def initialize(function)
        @function = function
        signature = function["func"]
        @tensor_options = signature.include?(TENSOR_OPTIONS_SIGNATURE)
        # non-destructive: sub! here would strip the caller's hash, so any
        # Function rebuilt from it would lose tensor_options
        @func = signature.sub(TENSOR_OPTIONS_SIGNATURE, ")")
      end

      # Overload name including any ".suffix", e.g. "add.Tensor".
      def name
        @name ||= func.split("(", 2).first
      end

      def python_module
        @python_module ||= @function["python_module"]
      end

      # e.g. ["function", "method"]; ["function"] when the entry has no variants.
      def variants
        @variants ||= (@function["variants"] || "function").split(", ")
      end

      # Parsed arguments, each a Hash with :name, :type, :default (converted to
      # a Ruby value), :pos (false for arguments after "*") and :has_default.
      # Arguments of SKIPPED_ARG_TYPES are omitted. The array is memoized.
      #
      # @raise [RuntimeError] "Unknown default: ..." for an unrecognized default literal
      def args
        @args ||= parse_args
      end

      # Number of mutated ("!") tensors in the return annotation.
      def out_size
        @out_size ||= func.split("->").last.count("!")
      end

      # Number of comma-separated return values.
      def ret_size
        @ret_size ||= func.split("->").last.split(", ").size
      end

      # True for out= overloads: they write into "!" arguments and are not
      # in-place ("_"-suffixed) variants.
      def out?
        out_size > 0 && base_name[-1] != "_"
      end

      # Ruby method name: "foo_" -> "foo!", "is_foo" -> "foo?", else unchanged.
      def ruby_name
        @ruby_name ||= begin
          name = base_name
          if name.end_with?("_")
            "#{name[0..-2]}!"
          elsif name.start_with?("is_")
            "#{name[3..-1]}?"
          else
            name
          end
        end
      end

      # Name of the generated C++ binding, e.g. "add.Tensor" -> "_add_tensor".
      def cpp_name
        @cpp_name ||= "_" + name.downcase.sub(".", "_")
      end

      # Name without the overload suffix, e.g. "add.Tensor" -> "add".
      def base_name
        @base_name ||= name.split(".").first
      end

      private

      # Splits the argument list of #func into argument hashes (see #args).
      def parse_args
        positional = true
        parsed = []
        args_str = func.split("(", 2).last.split(") ->").first

        args_str.split(", ").each do |a|
          # everything after the bare "*" is keyword-only
          if a == "*"
            positional = false
            next
          end

          type, _, declaration = a.rpartition(" ")
          arg_name, default = declaration.split("=")
          has_default = !default.nil?
          # converted before the skip check so unknown defaults still raise
          default = parse_default(default) if has_default
          next if SKIPPED_ARG_TYPES.include?(type)

          parsed << {name: arg_name, type: type, default: default, pos: positional, has_default: has_default}
        end

        parsed
      end

      # Converts a default-value literal from a signature into a Ruby value.
      def parse_default(literal)
        case literal
        when "True" then true
        when "False" then false
        when "None" then nil
        when /\A\-?\d+\z/ then literal.to_i
        when "[]" then []
        when "[0,1]" then [0, 1]
        when /\A\de\-\d+\z/, /\A\d+\.\d+\z/ then literal.to_f
        when "Mean" then "mean"
        when "contiguous_format" then literal
        when "long" then :long
        else raise "Unknown default: #{literal}"
        end
      end
    end
  end
end
@@ -1,163 +0,0 @@
1
- require "yaml"
2
- # use require_relative for
3
- # rake generate:functions (without bundle)
4
- require_relative "function"
5
-
6
module Torch
  module Native
    # Loads native_functions.yaml, decides which functions torch-rb can bind,
    # groups them by where they are exposed in Ruby, and writes the C++ glue
    # files (ext/torch/<group>_functions.{hpp,cpp}).
    module Generator
      # Argument types (matched as substrings) whose functions are postponed.
      SKIP_ARG_TYPES = ["bool[3]", "Dimname", "Layout", "Storage", "ConstQuantizerPtr"].freeze

      # Optional ("?") argument types the generated bindings can handle.
      SUPPORTED_OPTIONAL_TYPES = ["Tensor?", "Generator?", "int?", "ScalarType?", "Tensor?[]"].freeze

      class << self
        # Regenerates the torch, tensor and nn binding files.
        def generate_cpp_functions
          functions = grouped_functions
          generate_cpp_file("torch", :define_singleton_method, functions[:torch])
          generate_cpp_file("tensor", :define_method, functions[:tensor])
          generate_cpp_file("nn", :define_singleton_method, functions[:nn])
        end

        # Filters the parsed functions down to the bindable ones and groups them.
        #
        # @return [Hash{Symbol=>Array<Function>}] :torch (variant "function"),
        #   :tensor (variant "method") and :nn (python_module "nn"); a function
        #   with both variants appears under :torch and :tensor
        def grouped_functions
          # non-destructive filtering: #functions is memoized, and reject! would
          # permanently shrink that cached list
          bindable = functions.reject { |f| removed?(f) }.reject { |f| todo?(f) }

          nn_functions, other_functions = bindable.partition { |f| f.python_module == "nn" }
          {
            torch: other_functions.select { |f| f.variants.include?("function") },
            tensor: other_functions.select { |f| f.variants.include?("method") },
            nn: nn_functions
          }
        end

        private

        # Functions that are never bound: internal ("_"-prefixed) names,
        # backward kernels, and overloads taking named dimensions.
        def removed?(func)
          func.ruby_name.start_with?("_") ||
            func.ruby_name.end_with?("_backward") ||
            func.args.any? { |a| a[:type].include?("Dimname") }
        end

        # Functions postponed for now. To list them, print the #func of each
        # function for which removed? is false and todo? is true.
        def todo?(func)
          # call to 'range' is ambiguous
          return true if func.cpp_name == "_range"
          # native_functions.yaml is missing size argument for normal
          # https://pytorch.org/cppdocs/api/function_namespacetorch_1a80253fe5a3ded4716ec929a348adb4b9.html
          return true if func.base_name == "normal" && !func.out?

          func.args.any? do |a|
            type = a[:type]
            (type.include?("?") && !SUPPORTED_OPTIONAL_TYPES.include?(type)) ||
              SKIP_ARG_TYPES.any? { |skip| type.include?(skip) }
          end
        end

        # Writes ext/torch/<type>_functions.hpp and .cpp; the .cpp registers
        # every function on the module `m` under its cpp_name.
        def generate_cpp_file(type, def_method, functions)
          hpp_template = <<-TEMPLATE
// generated by rake generate:functions
// do not edit by hand

#pragma once

void add_%{type}_functions(Module m);
          TEMPLATE

          cpp_template = <<-TEMPLATE
// generated by rake generate:functions
// do not edit by hand

#include <torch/torch.h>
#include <rice/Module.hpp>
#include "templates.hpp"

void add_%{type}_functions(Module m) {
m
%{functions};
}
          TEMPLATE

          cpp_defs = functions.sort_by(&:cpp_name).map { |func| cpp_definition(func, def_method) }

          hpp_contents = hpp_template % {type: type}
          cpp_contents = cpp_template % {type: type, functions: cpp_defs.join("\n ")}

          # named ext_dir so it does not shadow the #path method
          ext_dir = File.expand_path("../../../ext/torch", __dir__)
          File.write("#{ext_dir}/#{type}_functions.hpp", hpp_contents)
          File.write("#{ext_dir}/#{type}_functions.cpp", cpp_contents)
        end

        # One ".<def_method>(...)" registration: a lambda taking the C++
        # parameters and forwarding them to torch::<name> (module functions)
        # or self.<name> (Tensor methods).
        def cpp_definition(func, def_method)
          fargs = func.args.dup
          fargs << {name: "options", type: "TensorOptions"} if func.tensor_options

          cpp_args = fargs.map { |a| cpp_param(a) }

          dispatch = func.out? ? "#{func.base_name}_out" : func.base_name
          args = fargs.map { |a| a[:name] }
          # move the output tensors (the last out_size args) to the front for the *_out call
          args.unshift(*args.pop(func.out_size)) if func.out?
          args.delete("self") if def_method == :define_method

          prefix = def_method == :define_method ? "self." : "torch::"
          body = "#{prefix}#{dispatch}(#{args.join(", ")})"
          # TODO check type as well
          # NOTE(review): wrap() presumably converts a multi-value (tuple) result — see templates.hpp
          body = "wrap(#{body})" if func.ret_size > 1

          [
            ".#{def_method}(",
            "\"#{func.cpp_name}\",",
            "*[](#{cpp_args.join(", ")}) {",
            "return #{body};",
            "})"
          ].join("\n")
        end

        # C++ parameter declaration ("<type> <name>") for one argument hash.
        def cpp_param(arg)
          t =
            case arg[:type]
            when "Tensor" then "const Tensor &"
            # TODO better signature
            when "Tensor?" then "OptionalTensor"
            when "ScalarType?" then "torch::optional<ScalarType>"
            when "Tensor[]" then "TensorList"
            # TODO make optional
            when "Tensor?[]" then "TensorList"
            when "int" then "int64_t"
            when "int?" then "torch::optional<int64_t>"
            when "float" then "double"
            when /\Aint\[/ then "IntArrayRef"
            when /Tensor\(\S!?\)/ then "Tensor &"
            when "str" then "std::string"
            when "TensorOptions" then "const torch::TensorOptions &"
            else arg[:type]
            end

          # reduction is passed from Ruby as a String (Parser accepts a String
          # for it), so it gets the MyReduction type instead of int64_t
          t = "MyReduction" if arg[:name] == "reduction" && t == "int64_t"
          [t, arg[:name]].join(" ").sub("& ", "&")
        end

        # All entries of native_functions.yaml as Function objects (memoized;
        # callers must not mutate the returned array).
        def functions
          @native_functions ||= YAML.load_file(path).map { |f| Function.new(f) }
        end

        def path
          File.expand_path("native_functions.yaml", __dir__)
        end
      end
    end
  end
end
@@ -1,140 +0,0 @@
1
module Torch
  module Native
    # Resolves a Ruby call against the overloads sharing one Ruby name and
    # returns the generated C++ binding to invoke with its ordered arguments.
    class Parser
      # @param functions [Array<Function>] overloads for one Ruby name (non-empty)
      def initialize(functions)
        @functions = functions
        @name = @functions.first.ruby_name
        # positional arity bounds across all overloads, for the arity error
        @min_args = @functions.map { |f| f.args.count { |a| a[:pos] && !a[:has_default] } }.min
        @max_args = @functions.map { |f| f.args.count { |a| a[:pos] } }.max
      end

      # Picks the overload matching +args+ and +options+.
      #
      # @param args [Array] positional arguments; trailing nils are removed in place
      # @param options [Hash{Symbol=>Object}] keyword arguments; for a multi-output
      #   overload an :out array is expanded in place into one key per output
      # @return [Hash] {name: cpp_name, args: [...]} on success, or {error: message}
      #   when the call itself is wrong
      # @raise [Error] for an argument type this parser does not handle, or when
      #   resolution does not end with exactly one candidate
      def parse(args, options)
        candidates = @functions.dup

        # treat trailing nils as omitted arguments; check emptiness explicitly,
        # since Array#any? is false for arrays like [false, nil]
        args.pop while !args.empty? && args.last.nil?

        # TODO account for args passed as options here
        if args.size < @min_args || args.size > @max_args
          expected = @min_args == @max_args ? @min_args.to_s : "#{@min_args}..#{@max_args}"
          return {error: "wrong number of arguments (given #{args.size}, expected #{expected})"}
        end

        candidates.reject! { |f| args.size > f.args.size }

        # exclude functions missing required options
        # TODO make more generic
        candidates.reject! { |func| func.out? && !options[:out] }

        # out: [a, b, ...] for a multi-output overload is spread over that
        # overload's output arguments, which are its last out_size args
        # there should only be one match, so safe to modify all
        out_func = candidates.find(&:out?)
        if out_func && out_func.out_size > 1 && options[:out]
          out_names = out_func.args.last(out_func.out_size).map { |a| a[:name] }
          out_names.zip(options.delete(:out)).each do |k, v|
            options[k.to_sym] = v
          end
          candidates = [out_func]
        end

        # exclude functions where options don't match
        options.each_key do |k|
          candidates.select! { |func| func.args.any? { |a| a[:name] == k.to_s } }
          # TODO show all bad keywords at once like Ruby?
          return {error: "unknown keyword: #{k}"} if candidates.empty?
        end

        final_values = {}

        # check args
        candidates.select! do |func|
          good = true

          values = args.zip(func.args).map { |a, fa| [fa[:name], a] }.to_h
          values.merge!(options.map { |k, v| [k.to_s, v] }.to_h)
          func.args.each do |fa|
            values[fa[:name]] = fa[:default] if values[fa[:name]].nil?
          end

          arg_types = func.args.map { |a| [a[:name], a[:type]] }.to_h

          values.each_key do |k|
            v = values[k]
            # "Tensor(a!)" is checked as "Tensor"
            t = arg_types[k].split("(").first

            good =
              case t
              when "Tensor"
                v.is_a?(Tensor)
              when "Tensor?"
                v.nil? || v.is_a?(Tensor)
              when "Tensor[]", "Tensor?[]"
                v.is_a?(Array) && v.all? { |v2| v2.is_a?(Tensor) }
              when "int"
                # reduction is given as a String and converted on the C++ side
                if k == "reduction"
                  v.is_a?(String)
                else
                  v.is_a?(Integer)
                end
              when "int?"
                v.is_a?(Integer) || v.nil?
              when "float"
                v.is_a?(Numeric)
              when /int\[.*\]/
                # a single Integer is broadcast to the fixed size, e.g. int[2]
                if v.is_a?(Integer)
                  size = t[4..-2]
                  raise Error, "Unknown size: #{size}. Please report a bug with #{@name}." unless size =~ /\A\d+\z/
                  v = [v] * size.to_i
                  values[k] = v
                end
                v.is_a?(Array) && v.all? { |v2| v2.is_a?(Integer) }
              when "Scalar"
                v.is_a?(Numeric)
              when "ScalarType?"
                v.nil?
              when "bool"
                v == true || v == false
              when "str"
                v.is_a?(String)
              else
                raise Error, "Unknown argument type: #{arg_types[k]}. Please report a bug with #{@name}."
              end

            unless good
              # with a single candidate left, report the offending argument
              if candidates.size == 1
                k = "input" if k == "self"
                return {error: "#{@name}(): argument '#{k}' must be #{t}"}
              end
              break
            end
          end

          final_values = values if good

          good
        end

        if candidates.size != 1
          raise Error, "This should never happen. Please report a bug with #{@name}."
        end

        func = candidates.first
        call_args = func.args.map { |a| final_values[a[:name]] }
        # NOTE(review): dtype 6 is presumably Float in LibTorch's ScalarType enum — confirm
        call_args << TensorOptions.new.dtype(6) if func.tensor_options
        {
          name: func.cpp_name,
          args: call_args
        }
      end
    end
  end
end