torch-rb 0.3.4 → 0.4.1
This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
- checksums.yaml +4 -4
- data/CHANGELOG.md +28 -0
- data/README.md +2 -1
- data/codegen/function.rb +134 -0
- data/codegen/generate_functions.rb +549 -0
- data/{lib/torch/native → codegen}/native_functions.yaml +0 -0
- data/ext/torch/ext.cpp +76 -87
- data/ext/torch/extconf.rb +5 -2
- data/ext/torch/nn_functions.h +6 -0
- data/ext/torch/ruby_arg_parser.cpp +593 -0
- data/ext/torch/ruby_arg_parser.h +373 -0
- data/ext/torch/{templates.hpp → templates.h} +87 -97
- data/ext/torch/tensor_functions.h +6 -0
- data/ext/torch/torch_functions.h +6 -0
- data/ext/torch/utils.h +42 -0
- data/ext/torch/{templates.cpp → wrap_outputs.h} +44 -7
- data/lib/torch.rb +51 -77
- data/lib/torch/nn/functional.rb +142 -18
- data/lib/torch/nn/init.rb +5 -19
- data/lib/torch/nn/leaky_relu.rb +3 -3
- data/lib/torch/nn/module.rb +9 -1
- data/lib/torch/nn/upsample.rb +31 -0
- data/lib/torch/optim/adadelta.rb +1 -1
- data/lib/torch/optim/adam.rb +2 -2
- data/lib/torch/optim/adamax.rb +1 -1
- data/lib/torch/optim/adamw.rb +1 -1
- data/lib/torch/optim/asgd.rb +1 -1
- data/lib/torch/optim/sgd.rb +3 -3
- data/lib/torch/tensor.rb +36 -115
- data/lib/torch/utils/data/data_loader.rb +2 -0
- data/lib/torch/utils/data/tensor_dataset.rb +2 -0
- data/lib/torch/version.rb +1 -1
- metadata +19 -14
- data/lib/torch/native/dispatcher.rb +0 -48
- data/lib/torch/native/function.rb +0 -115
- data/lib/torch/native/generator.rb +0 -163
- data/lib/torch/native/parser.rb +0 -140
data/lib/torch/native/dispatcher.rb
DELETED
@@ -1,48 +0,0 @@
-# We use a generic interface for methods (*args, **options)
-# and this class to determine the C++ method to call
-#
-# This is needed since LibTorch uses function overloading,
-# which isn't available in Ruby or Python
-#
-# PyTorch uses this approach, but the parser/dispatcher is written in C++
-#
-# We could generate Ruby methods directly, but an advantage of this approach is
-# arguments and keyword arguments can be used interchangably like in Python,
-# making it easier to port code
-
-module Torch
-  module Native
-    module Dispatcher
-      class << self
-        def bind
-          functions = Generator.grouped_functions
-          bind_functions(::Torch, :define_singleton_method, functions[:torch])
-          bind_functions(::Torch::Tensor, :define_method, functions[:tensor])
-          bind_functions(::Torch::NN, :define_singleton_method, functions[:nn])
-        end
-
-        def bind_functions(context, def_method, functions)
-          functions.group_by(&:ruby_name).sort_by { |g, _| g }.each do |name, funcs|
-            if def_method == :define_method
-              funcs.map! { |f| Function.new(f.function) }
-              funcs.each { |f| f.args.reject! { |a| a[:name] == "self" } }
-            end
-
-            defined = def_method == :define_method ? context.method_defined?(name) : context.respond_to?(name)
-            next if defined && name != "clone"
-
-            parser = Parser.new(funcs)
-
-            context.send(def_method, name) do |*args, **options|
-              result = parser.parse(args, options)
-              raise ArgumentError, result[:error] if result[:error]
-              send(result[:name], *result[:args])
-            end
-          end
-        end
-      end
-    end
-  end
-end
-
-Torch::Native::Dispatcher.bind
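
The comment block at the top of dispatcher.rb is the clearest statement of the pre-0.4 design: every bound method accepted (*args, **options) and resolved the LibTorch overload at call time, in Ruby. A minimal sketch of what that made possible, assuming a working torch-rb 0.3.x install (the tensor values are only illustrative):

    require "torch"

    x = Torch.ones(2, 3)

    # Positional and keyword arguments were interchangeable, as in Python,
    # because Parser matched keywords against the schema argument names.
    Torch.sum(x, 0)
    Torch.sum(x, dim: 0)

Judging by the file list above (ruby_arg_parser.cpp, utils.h, and the codegen/ scripts), 0.4.x appears to move this per-call parsing into generated C++, mirroring PyTorch's own approach.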
data/lib/torch/native/function.rb
DELETED
@@ -1,115 +0,0 @@
-module Torch
-  module Native
-    class Function
-      attr_reader :function, :tensor_options
-
-      def initialize(function)
-        @function = function
-
-        tensor_options_str = ", *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None)"
-        @tensor_options = @function["func"].include?(tensor_options_str)
-        @function["func"].sub!(tensor_options_str, ")")
-      end
-
-      def func
-        @func ||= @function["func"]
-      end
-
-      def name
-        @name ||= func.split("(", 2).first
-      end
-
-      def python_module
-        @python_module ||= @function["python_module"]
-      end
-
-      def variants
-        @variants ||= (@function["variants"] || "function").split(", ")
-      end
-
-      def args
-        @args ||= begin
-          args = []
-          pos = true
-          args_str = func.split("(", 2).last.split(") ->").first
-          args_str.split(", ").each do |a|
-            if a == "*"
-              pos = false
-              next
-            end
-            t, _, k = a.rpartition(" ")
-            k, d = k.split("=")
-            has_default = !d.nil?
-
-            if d
-              d =
-                case d
-                when "True"
-                  true
-                when "False"
-                  false
-                when "None"
-                  nil
-                when /\A\-?\d+\z/
-                  d.to_i
-                when "[]"
-                  []
-                when "[0,1]"
-                  [0, 1]
-                when /\A\de\-\d+\z/, /\A\d+\.\d+\z/
-                  d.to_f
-                when "Mean"
-                  "mean"
-                when "contiguous_format"
-                  d
-                when "long"
-                  :long
-                else
-                  raise "Unknown default: #{d}"
-                end
-            end
-
-            next if t == "Generator?"
-            next if t == "MemoryFormat"
-            next if t == "MemoryFormat?"
-            args << {name: k, type: t, default: d, pos: pos, has_default: has_default}
-          end
-          args
-        end
-      end
-
-      def out_size
-        @out_size ||= func.split("->").last.count("!")
-      end
-
-      def ret_size
-        @ret_size ||= func.split("->").last.split(", ").size
-      end
-
-      def out?
-        out_size > 0 && base_name[-1] != "_"
-      end
-
-      def ruby_name
-        @ruby_name ||= begin
-          name = base_name
-          if name.end_with?("_")
-            "#{name[0..-2]}!"
-          elsif name.start_with?("is_")
-            "#{name[3..-1]}?"
-          else
-            name
-          end
-        end
-      end
-
-      def cpp_name
-        @cpp_name ||= "_" + name.downcase.sub(".", "_")
-      end
-
-      def base_name
-        @base_name ||= name.split(".").first
-      end
-    end
-  end
-end
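
To make the string handling in Function#args concrete, here is a hypothetical session; the schema string below is representative of native_functions.yaml entries, not copied from the file:

    f = Torch::Native::Function.new(
      "func" => "sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False) -> Tensor"
    )

    f.name       # => "sum.dim_IntList"
    f.base_name  # => "sum"
    f.ruby_name  # => "sum"
    f.cpp_name   # => "_sum_dim_intlist"
    f.out_size   # => 0 (no "!" in the return type)
    f.args.map { |a| a.values_at(:name, :type, :default) }
    # => [["self", "Tensor", nil], ["dim", "int[1]", nil], ["keepdim", "bool", false]]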
data/lib/torch/native/generator.rb
DELETED
@@ -1,163 +0,0 @@
-require "yaml"
-# use require_relative for
-# rake generate:function (without bundle)
-require_relative "function"
-
-module Torch
-  module Native
-    module Generator
-      class << self
-        def generate_cpp_functions
-          functions = grouped_functions
-          generate_cpp_file("torch", :define_singleton_method, functions[:torch])
-          generate_cpp_file("tensor", :define_method, functions[:tensor])
-          generate_cpp_file("nn", :define_singleton_method, functions[:nn])
-        end
-
-        def grouped_functions
-          functions = functions()
-
-          # skip functions
-          skip_args = ["bool[3]", "Dimname", "Layout", "Storage", "ConstQuantizerPtr"]
-
-          # remove functions
-          functions.reject! do |f|
-            f.ruby_name.start_with?("_") ||
-              f.ruby_name.end_with?("_backward") ||
-              f.args.any? { |a| a[:type].include?("Dimname") }
-          end
-
-          # separate out into todo
-          todo_functions, functions =
-            functions.partition do |f|
-              f.args.any? do |a|
-                a[:type].include?("?") && !["Tensor?", "Generator?", "int?", "ScalarType?", "Tensor?[]"].include?(a[:type]) ||
-                  skip_args.any? { |sa| a[:type].include?(sa) } ||
-                  # call to 'range' is ambiguous
-                  f.cpp_name == "_range" ||
-                  # native_functions.yaml is missing size argument for normal
-                  # https://pytorch.org/cppdocs/api/function_namespacetorch_1a80253fe5a3ded4716ec929a348adb4b9.html
-                  (f.base_name == "normal" && !f.out?)
-              end
-            end
-
-          # todo_functions.each do |f|
-          #   puts f.func
-          #   puts
-          # end
-
-          nn_functions, other_functions = functions.partition { |f| f.python_module == "nn" }
-          torch_functions = other_functions.select { |f| f.variants.include?("function") }
-          tensor_functions = other_functions.select { |f| f.variants.include?("method") }
-
-          {torch: torch_functions, tensor: tensor_functions, nn: nn_functions}
-        end
-
-        private
-
-        def generate_cpp_file(type, def_method, functions)
-          hpp_template = <<-TEMPLATE
-// generated by rake generate:functions
-// do not edit by hand
-
-#pragma once
-
-void add_%{type}_functions(Module m);
-          TEMPLATE
-
-          cpp_template = <<-TEMPLATE
-// generated by rake generate:functions
-// do not edit by hand
-
-#include <torch/torch.h>
-#include <rice/Module.hpp>
-#include "templates.hpp"
-
-void add_%{type}_functions(Module m) {
-  m
-    %{functions};
-}
-          TEMPLATE
-
-          cpp_defs = []
-          functions.sort_by(&:cpp_name).each do |func|
-            fargs = func.args.dup #.select { |a| a[:type] != "Generator?" }
-            fargs << {name: "options", type: "TensorOptions"} if func.tensor_options
-
-            cpp_args = []
-            fargs.each do |a|
-              t =
-                case a[:type]
-                when "Tensor"
-                  "const Tensor &"
-                when "Tensor?"
-                  # TODO better signature
-                  "OptionalTensor"
-                when "ScalarType?"
-                  "torch::optional<ScalarType>"
-                when "Tensor[]"
-                  "TensorList"
-                when "Tensor?[]"
-                  # TODO make optional
-                  "TensorList"
-                when "int"
-                  "int64_t"
-                when "int?"
-                  "torch::optional<int64_t>"
-                when "float"
-                  "double"
-                when /\Aint\[/
-                  "IntArrayRef"
-                when /Tensor\(\S!?\)/
-                  "Tensor &"
-                when "str"
-                  "std::string"
-                when "TensorOptions"
-                  "const torch::TensorOptions &"
-                else
-                  a[:type]
-                end
-
-              t = "MyReduction" if a[:name] == "reduction" && t == "int64_t"
-              cpp_args << [t, a[:name]].join(" ").sub("& ", "&")
-            end
-
-            dispatch = func.out? ? "#{func.base_name}_out" : func.base_name
-            args = fargs.map { |a| a[:name] }
-            args.unshift(*args.pop(func.out_size)) if func.out?
-            args.delete("self") if def_method == :define_method
-
-            prefix = def_method == :define_method ? "self." : "torch::"
-
-            body = "#{prefix}#{dispatch}(#{args.join(", ")})"
-            # TODO check type as well
-            if func.ret_size > 1
-              body = "wrap(#{body})"
-            end
-
-            cpp_defs << ".#{def_method}(
-    \"#{func.cpp_name}\",
-    *[](#{cpp_args.join(", ")}) {
-      return #{body};
-    })"
-          end
-
-          hpp_contents = hpp_template % {type: type}
-          cpp_contents = cpp_template % {type: type, functions: cpp_defs.join("\n    ")}
-
-          path = File.expand_path("../../../ext/torch", __dir__)
-          File.write("#{path}/#{type}_functions.hpp", hpp_contents)
-          File.write("#{path}/#{type}_functions.cpp", cpp_contents)
-        end
-
-        def functions
-          @native_functions ||= YAML.load_file(path).map { |f| Function.new(f) }
-        end
-
-        def path
-          File.expand_path("native_functions.yaml", __dir__)
-        end
-      end
-    end
-  end
-end
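
The heredocs above are plain Ruby format strings: generate_cpp_file fills %{type} and %{functions} and writes the result into ext/torch. A self-contained sketch of that interpolation, with one hand-written definition standing in for a generated cpp_defs entry (plausible shape, not actual generated output):

    # Minimal re-creation of the template mechanism in generate_cpp_file.
    cpp_template = <<~TEMPLATE
      // generated by rake generate:functions
      // do not edit by hand

      void add_%{type}_functions(Module m) {
        m
          %{functions};
      }
    TEMPLATE

    # Hand-written stand-in for one generated binding (illustrative only).
    sample_def = <<~DEF.strip
      .define_singleton_method(
          "_abs",
          *[](const Tensor &self) {
            return torch::abs(self);
          })
    DEF

    puts cpp_template % {type: "torch", functions: sample_def}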
data/lib/torch/native/parser.rb
DELETED
@@ -1,140 +0,0 @@
-module Torch
-  module Native
-    class Parser
-      def initialize(functions)
-        @functions = functions
-        @name = @functions.first.ruby_name
-        @min_args = @functions.map { |f| f.args.count { |a| a[:pos] && !a[:has_default] } }.min
-        @max_args = @functions.map { |f| f.args.count { |a| a[:pos] } }.max
-      end
-
-      def parse(args, options)
-        candidates = @functions.dup
-
-        # remove nil
-        while args.any? && args.last.nil?
-          args.pop
-        end
-
-        # TODO account for args passed as options here
-        if args.size < @min_args || args.size > @max_args
-          expected = String.new(@min_args.to_s)
-          expected += "..#{@max_args}" if @max_args != @min_args
-          return {error: "wrong number of arguments (given #{args.size}, expected #{expected})"}
-        end
-
-        candidates.reject! { |f| args.size > f.args.size }
-
-        # exclude functions missing required options
-        candidates.reject! do |func|
-          # TODO make more generic
-          func.out? && !options[:out]
-        end
-
-        # handle out with multiple
-        # there should only be one match, so safe to modify all
-        out_func = candidates.find { |f| f.out? }
-        if out_func && out_func.out_size > 1 && options[:out]
-          out_args = out_func.args.last(2).map { |a| a[:name] }
-          out_args.zip(options.delete(:out)).each do |k, v|
-            options[k.to_sym] = v
-          end
-          candidates = [out_func]
-        end
-
-        # exclude functions where options don't match
-        options.each do |k, v|
-          candidates.select! do |func|
-            func.args.any? { |a| a[:name] == k.to_s }
-          end
-          # TODO show all bad keywords at once like Ruby?
-          return {error: "unknown keyword: #{k}"} if candidates.empty?
-        end
-
-        final_values = {}
-
-        # check args
-        candidates.select! do |func|
-          good = true
-
-          values = args.zip(func.args).map { |a, fa| [fa[:name], a] }.to_h
-          values.merge!(options.map { |k, v| [k.to_s, v] }.to_h)
-          func.args.each do |fa|
-            values[fa[:name]] = fa[:default] if values[fa[:name]].nil?
-          end
-
-          arg_types = func.args.map { |a| [a[:name], a[:type]] }.to_h
-
-          values.each_key do |k|
-            v = values[k]
-            t = arg_types[k].split("(").first
-
-            good =
-              case t
-              when "Tensor"
-                v.is_a?(Tensor)
-              when "Tensor?"
-                v.nil? || v.is_a?(Tensor)
-              when "Tensor[]", "Tensor?[]"
-                v.is_a?(Array) && v.all? { |v2| v2.is_a?(Tensor) }
-              when "int"
-                if k == "reduction"
-                  v.is_a?(String)
-                else
-                  v.is_a?(Integer)
-                end
-              when "int?"
-                v.is_a?(Integer) || v.nil?
-              when "float"
-                v.is_a?(Numeric)
-              when /int\[.*\]/
-                if v.is_a?(Integer)
-                  size = t[4..-2]
-                  raise Error, "Unknown size: #{size}. Please report a bug with #{@name}." unless size =~ /\A\d+\z/
-                  v = [v] * size.to_i
-                  values[k] = v
-                end
-                v.is_a?(Array) && v.all? { |v2| v2.is_a?(Integer) }
-              when "Scalar"
-                v.is_a?(Numeric)
-              when "ScalarType?"
-                v.nil?
-              when "bool"
-                v == true || v == false
-              when "str"
-                v.is_a?(String)
-              else
-                raise Error, "Unknown argument type: #{arg_types[k]}. Please report a bug with #{@name}."
-              end
-
-            if !good
-              if candidates.size == 1
-                k = "input" if k == "self"
-                return {error: "#{@name}(): argument '#{k}' must be #{t}"}
-              end
-              break
-            end
-          end
-
-          if good
-            final_values = values
-          end
-
-          good
-        end
-
-        if candidates.size != 1
-          raise Error, "This should never happen. Please report a bug with #{@name}."
-        end
-
-        func = candidates.first
-        args = func.args.map { |a| final_values[a[:name]] }
-        args << TensorOptions.new.dtype(6) if func.tensor_options
-        {
-          name: func.cpp_name,
-          args: args
-        }
-      end
-    end
-  end
-end
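
Parser#parse is where an overload was actually chosen: candidates were filtered by arity, then by keyword names, then by the runtime types of the supplied values. A hypothetical 0.3.x session tracing behavior that follows from the code above (exact schemas and messages depend on native_functions.yaml):

    x = Torch.ones(2, 3)

    Torch.sum(x)       # resolves to sum(Tensor self)
    Torch.sum(x, 0)    # resolves to sum.dim_IntList; the int[.*] branch expands 0 to [0]

    # Error paths returned by parse and raised as ArgumentError by the dispatcher:
    Torch.sum(x, foo: 1)   # unknown keyword: foo
    Torch.abs("hi")        # abs(): argument 'input' must be Tensor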