torch-rb 0.3.3 → 0.4.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +28 -0
- data/README.md +2 -1
- data/codegen/function.rb +134 -0
- data/codegen/generate_functions.rb +546 -0
- data/{lib/torch/native → codegen}/native_functions.yaml +0 -0
- data/ext/torch/ext.cpp +76 -87
- data/ext/torch/extconf.rb +5 -2
- data/ext/torch/nn_functions.h +6 -0
- data/ext/torch/ruby_arg_parser.cpp +593 -0
- data/ext/torch/ruby_arg_parser.h +373 -0
- data/ext/torch/{templates.hpp → templates.h} +87 -97
- data/ext/torch/tensor_functions.h +6 -0
- data/ext/torch/torch_functions.h +6 -0
- data/ext/torch/utils.h +42 -0
- data/ext/torch/{templates.cpp → wrap_outputs.h} +44 -7
- data/lib/torch.rb +56 -77
- data/lib/torch/nn/functional.rb +142 -18
- data/lib/torch/nn/init.rb +5 -19
- data/lib/torch/nn/leaky_relu.rb +3 -3
- data/lib/torch/nn/module.rb +9 -1
- data/lib/torch/nn/upsample.rb +31 -0
- data/lib/torch/optim/adadelta.rb +1 -1
- data/lib/torch/optim/adam.rb +2 -2
- data/lib/torch/optim/adamax.rb +1 -1
- data/lib/torch/optim/adamw.rb +1 -1
- data/lib/torch/optim/asgd.rb +1 -1
- data/lib/torch/optim/sgd.rb +3 -3
- data/lib/torch/tensor.rb +36 -115
- data/lib/torch/utils/data/data_loader.rb +2 -0
- data/lib/torch/utils/data/tensor_dataset.rb +2 -0
- data/lib/torch/version.rb +1 -1
- metadata +28 -9
- data/lib/torch/native/dispatcher.rb +0 -48
- data/lib/torch/native/function.rb +0 -115
- data/lib/torch/native/generator.rb +0 -163
- data/lib/torch/native/parser.rb +0 -140
@@ -1,48 +0,0 @@
|
|
1
|
-
# We use a generic interface for methods (*args, **options)
|
2
|
-
# and this class to determine the C++ method to call
|
3
|
-
#
|
4
|
-
# This is needed since LibTorch uses function overloading,
|
5
|
-
# which isn't available in Ruby or Python
|
6
|
-
#
|
7
|
-
# PyTorch uses this approach, but the parser/dispatcher is written in C++
|
8
|
-
#
|
9
|
-
# We could generate Ruby methods directly, but an advantage of this approach is
|
10
|
-
# positional and keyword arguments can be used interchangeably, as in Python,
|
11
|
-
# making it easier to port code
|
12
|
-
|
13
|
-
module Torch
  module Native
    # Defines the public Ruby entry points (Torch.*, Torch::Tensor#*,
    # Torch::NN.*) that forward to the generated C++ bindings once a Parser
    # has picked the matching overload.
    module Dispatcher
      class << self
        # Binds each group returned by Generator.grouped_functions to its
        # Ruby target.
        def bind
          groups = Generator.grouped_functions
          bind_functions(::Torch, :define_singleton_method, groups[:torch])
          bind_functions(::Torch::Tensor, :define_method, groups[:tensor])
          bind_functions(::Torch::NN, :define_singleton_method, groups[:nn])
        end

        # Defines one Ruby method per distinct ruby_name in +functions+ on
        # +context+, using +def_method+ (:define_method or
        # :define_singleton_method).
        #
        # Names that are already defined are skipped, except "clone". Each
        # generated method passes its (*args, **options) to a Parser. It raises
        # ArgumentError when the parser reports an error; otherwise it sends the
        # chosen cpp_name with the parsed positional arguments.
        def bind_functions(context, def_method, functions)
          for_instances = def_method == :define_method

          functions.group_by(&:ruby_name).sort_by(&:first).each do |name, overloads|
            if for_instances
              # Build fresh Function objects so that dropping "self" below does
              # not alter the args memoized on objects also used for the
              # module-level (:torch) bindings.
              overloads.map! { |f| Function.new(f.function) }
              overloads.each { |f| f.args.reject! { |a| a[:name] == "self" } }
            end

            taken =
              if for_instances
                context.method_defined?(name)
              else
                context.respond_to?(name)
              end
            # "clone" is rebound even though Ruby objects already respond to it
            next if taken && name != "clone"

            parser = Parser.new(overloads)
            context.send(def_method, name) do |*args, **options|
              parsed = parser.parse(args, options)
              error = parsed[:error]
              raise ArgumentError, error if error
              send(parsed[:name], *parsed[:args])
            end
          end
        end
      end
    end
  end
end
|
47
|
-
|
48
|
-
# Define all Ruby entry points for the native functions when this file is loaded.
Torch::Native::Dispatcher.bind
|
@@ -1,115 +0,0 @@
|
|
1
|
-
module Torch
  module Native
    # Wraps one entry of native_functions.yaml and derives the pieces of its
    # schema string ("name.overload(arg, ...) -> returns") that the generator,
    # parser, and dispatcher need.
    class Function
      # Trailing keyword arguments that are collapsed into a single
      # TensorOptions argument; they are removed from the schema and tracked
      # via the tensor_options flag instead.
      TENSOR_OPTIONS_STR = ", *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None)".freeze

      attr_reader :function, :tensor_options

      # @param function [Hash] a YAML entry containing at least a "func" key.
      #   The hash is NOT modified. Re-wrapping it later with
      #   Function.new(f.function) therefore detects the tensor options again.
      #   Stripping them in place would make a second wrapper report
      #   tensor_options = false.
      def initialize(function)
        @function = function
        @tensor_options = function["func"].include?(TENSOR_OPTIONS_STR)
      end

      # Schema string with the tensor-options keywords removed.
      def func
        @func ||= @function["func"].sub(TENSOR_OPTIONS_STR, ")")
      end

      # Full name including any overload suffix, e.g. "add.Tensor".
      def name
        @name ||= func.split("(", 2).first
      end

      def python_module
        @python_module ||= @function["python_module"]
      end

      # Variant kinds ("function", "method"); defaults to ["function"] when
      # the entry has no "variants" key.
      def variants
        @variants ||= (@function["variants"] || "function").split(", ")
      end

      # Parsed argument list. Each entry is a Hash with :name, :type,
      # :default (converted to a Ruby value), :pos (false after a bare "*"),
      # and :has_default. Arguments typed Generator?, MemoryFormat, or
      # MemoryFormat? are omitted.
      #
      # @raise [RuntimeError] when a default literal is not recognized
      def args
        @args ||= begin
          parsed = []
          pos = true
          args_str = func.split("(", 2).last.split(") ->").first
          args_str.split(", ").each do |a|
            if a == "*"
              pos = false
              next
            end

            t, _, k = a.rpartition(" ")
            k, d = k.split("=")
            # decided before conversion, because "None" converts to nil
            has_default = !d.nil?
            d = parse_default(d) if d

            next if %w[Generator? MemoryFormat MemoryFormat?].include?(t)
            parsed << {name: k, type: t, default: d, pos: pos, has_default: has_default}
          end
          parsed
        end
      end

      # Number of mutable (output) return values, counted by "!" markers.
      def out_size
        @out_size ||= func.split("->").last.count("!")
      end

      # Number of comma-separated return values.
      def ret_size
        @ret_size ||= func.split("->").last.split(", ").size
      end

      # True for out= variants; in-place ops (trailing "_") also mark their
      # return with "!" but are excluded.
      def out?
        out_size > 0 && base_name[-1] != "_"
      end

      # Ruby method name: "add_" -> "add!", "is_floating_point" -> "floating_point?".
      def ruby_name
        @ruby_name ||= begin
          n = base_name
          if n.end_with?("_")
            "#{n[0..-2]}!"
          elsif n.start_with?("is_")
            "#{n[3..-1]}?"
          else
            n
          end
        end
      end

      # Name of the generated C++ binding, e.g. "add.Tensor" -> "_add_tensor".
      def cpp_name
        @cpp_name ||= "_" + name.downcase.sub(".", "_")
      end

      # Name without the overload suffix, e.g. "add.Tensor" -> "add".
      def base_name
        @base_name ||= name.split(".").first
      end

      private

      # Converts a default literal from the schema into a Ruby value.
      def parse_default(d)
        case d
        when "True"
          true
        when "False"
          false
        when "None"
          nil
        when /\A\-?\d+\z/
          d.to_i
        when "[]"
          []
        when "[0,1]"
          [0, 1]
        when /\A\de\-\d+\z/, /\A\d+\.\d+\z/
          d.to_f
        when "Mean"
          "mean"
        when "contiguous_format"
          d
        when "long"
          :long
        else
          raise "Unknown default: #{d}"
        end
      end
    end
  end
end
|
@@ -1,163 +0,0 @@
|
|
1
|
-
require "yaml"
|
2
|
-
# use require_relative for
|
3
|
-
# rake generate:functions (without bundle)
|
4
|
-
require_relative "function"
|
5
|
-
|
6
|
-
module Torch
  module Native
    # Loads native_functions.yaml, filters it down to functions that can be
    # bound, and writes the C++ binding sources for them.
    module Generator
      class << self
        # Writes {torch,tensor,nn}_functions.{hpp,cpp} into ext/torch.
        def generate_cpp_functions
          functions = grouped_functions
          generate_cpp_file("torch", :define_singleton_method, functions[:torch])
          generate_cpp_file("tensor", :define_method, functions[:tensor])
          generate_cpp_file("nn", :define_singleton_method, functions[:nn])
        end

        # @return [Hash{Symbol => Array<Function>}] bindable functions grouped as
        #   :torch (variant "function"), :tensor (variant "method"), and :nn
        #   (python_module "nn"). The same Function object may appear in both
        #   :torch and :tensor.
        def grouped_functions
          # argument types that are not supported yet
          skip_args = ["bool[3]", "Dimname", "Layout", "Storage", "ConstQuantizerPtr"]

          # Remove private, backward, and Dimname functions. Use the
          # non-destructive reject: `functions` is memoized, and reject! would
          # mutate that shared cache on every call.
          available = functions.reject do |f|
            f.ruby_name.start_with?("_") ||
              f.ruby_name.end_with?("_backward") ||
              f.args.any? { |a| a[:type].include?("Dimname") }
          end

          # separate out into todo
          todo_functions, supported =
            available.partition do |f|
              f.args.any? do |a|
                a[:type].include?("?") && !["Tensor?", "Generator?", "int?", "ScalarType?", "Tensor?[]"].include?(a[:type]) ||
                  skip_args.any? { |sa| a[:type].include?(sa) } ||
                  # call to 'range' is ambiguous
                  f.cpp_name == "_range" ||
                  # native_functions.yaml is missing size argument for normal
                  # https://pytorch.org/cppdocs/api/function_namespacetorch_1a80253fe5a3ded4716ec929a348adb4b9.html
                  (f.base_name == "normal" && !f.out?)
              end
            end

          # todo_functions.each do |f|
          #   puts f.func
          #   puts
          # end

          nn_functions, other_functions = supported.partition { |f| f.python_module == "nn" }
          torch_functions = other_functions.select { |f| f.variants.include?("function") }
          tensor_functions = other_functions.select { |f| f.variants.include?("method") }

          {torch: torch_functions, tensor: tensor_functions, nn: nn_functions}
        end

        private

        # Writes ext/torch/<type>_functions.{hpp,cpp}. The .cpp registers each
        # function on the Rice module under its cpp_name via +def_method+.
        def generate_cpp_file(type, def_method, functions)
          hpp_template = <<-TEMPLATE
// generated by rake generate:functions
// do not edit by hand

#pragma once

void add_%{type}_functions(Module m);
          TEMPLATE

          cpp_template = <<-TEMPLATE
// generated by rake generate:functions
// do not edit by hand

#include <torch/torch.h>
#include <rice/Module.hpp>
#include "templates.hpp"

void add_%{type}_functions(Module m) {
m
%{functions};
}
          TEMPLATE

          cpp_defs = []
          functions.sort_by(&:cpp_name).each do |func|
            fargs = func.args.dup
            fargs << {name: "options", type: "TensorOptions"} if func.tensor_options

            cpp_args = fargs.map do |a|
              t =
                case a[:type]
                when "Tensor"
                  "const Tensor &"
                when "Tensor?"
                  # TODO better signature
                  "OptionalTensor"
                when "ScalarType?"
                  "torch::optional<ScalarType>"
                when "Tensor[]"
                  "TensorList"
                when "Tensor?[]"
                  # TODO make optional
                  "TensorList"
                when "int"
                  "int64_t"
                when "int?"
                  "torch::optional<int64_t>"
                when "float"
                  "double"
                when /\Aint\[/
                  "IntArrayRef"
                when /Tensor\(\S!?\)/
                  "Tensor &"
                when "str"
                  "std::string"
                when "TensorOptions"
                  "const torch::TensorOptions &"
                else
                  a[:type]
                end

              # reduction is passed from Ruby as a string and converted in C++
              t = "MyReduction" if a[:name] == "reduction" && t == "int64_t"
              [t, a[:name]].join(" ").sub("& ", "&")
            end

            dispatch = func.out? ? "#{func.base_name}_out" : func.base_name
            args = fargs.map { |a| a[:name] }
            # the C++ *_out functions take the output tensors first
            args.unshift(*args.pop(func.out_size)) if func.out?
            args.delete("self") if def_method == :define_method

            prefix = def_method == :define_method ? "self." : "torch::"

            body = "#{prefix}#{dispatch}(#{args.join(", ")})"
            # TODO check type as well
            body = "wrap(#{body})" if func.ret_size > 1

            cpp_defs << [
              ".#{def_method}(",
              "\"#{func.cpp_name}\",",
              "*[](#{cpp_args.join(", ")}) {",
              "return #{body};",
              "})"
            ].join("\n")
          end

          hpp_contents = hpp_template % {type: type}
          cpp_contents = cpp_template % {type: type, functions: cpp_defs.join("\n ")}

          ext_dir = File.expand_path("../../../ext/torch", __dir__)
          File.write("#{ext_dir}/#{type}_functions.hpp", hpp_contents)
          File.write("#{ext_dir}/#{type}_functions.cpp", cpp_contents)
        end

        # All entries of native_functions.yaml wrapped as Function (memoized).
        def functions
          @native_functions ||= YAML.load_file(path).map { |f| Function.new(f) }
        end

        def path
          File.expand_path("native_functions.yaml", __dir__)
        end
      end
    end
  end
end
|
data/lib/torch/native/parser.rb
DELETED
@@ -1,140 +0,0 @@
|
|
1
|
-
module Torch
  module Native
    # Chooses which overload of a native function matches a Ruby call. It then
    # turns the call's positional and keyword arguments into the positional
    # argument list expected by the generated C++ binding.
    class Parser
      # @param functions [Array<Function>] all overloads sharing one ruby_name
      def initialize(functions)
        @functions = functions
        @name = @functions.first.ruby_name
        # arity bounds across all overloads, for a Ruby-style arity message
        @min_args = @functions.map { |f| f.args.count { |a| a[:pos] && !a[:has_default] } }.min
        @max_args = @functions.map { |f| f.args.count { |a| a[:pos] } }.max
      end

      # @param args [Array] positional arguments. Trailing nils are popped
      #   in place.
      # @param options [Hash{Symbol => Object}] keyword arguments. A
      #   multi-tensor :out is expanded in place.
      # @return [Hash] {name:, args:} for the selected C++ function, or
      #   {error:} describing why the call is invalid
      # @raise [Torch::Error] on an unknown argument type, or when overload
      #   resolution does not end with exactly one candidate
      def parse(args, options)
        candidates = @functions.dup

        # remove nil
        while args.any? && args.last.nil?
          args.pop
        end

        # TODO account for args passed as options here
        if args.size < @min_args || args.size > @max_args
          expected = String.new(@min_args.to_s)
          expected += "..#{@max_args}" if @max_args != @min_args
          return {error: "wrong number of arguments (given #{args.size}, expected #{expected})"}
        end

        candidates.reject! { |f| args.size > f.args.size }

        # exclude functions missing required options
        # TODO make more generic
        candidates.reject! { |func| func.out? && !options[:out] }

        # Spread out: [t1, t2, ...] over the trailing output arguments of a
        # function with several outputs. Use out_size rather than a fixed
        # count, so three-output functions such as svd (U, S, V) work. There
        # should only be one such overload, so it is safe to narrow to it.
        out_func = candidates.find(&:out?)
        if out_func && out_func.out_size > 1 && options[:out]
          out_names = out_func.args.last(out_func.out_size).map { |a| a[:name] }
          out_names.zip(options.delete(:out)).each do |k, v|
            options[k.to_sym] = v
          end
          candidates = [out_func]
        end

        # exclude functions where options don't match
        options.each_key do |k|
          candidates.select! { |func| func.args.any? { |a| a[:name] == k.to_s } }
          # TODO show all bad keywords at once like Ruby?
          return {error: "unknown keyword: #{k}"} if candidates.empty?
        end

        final_values = {}

        # check args
        candidates.select! do |func|
          good = true

          values = args.zip(func.args).map { |a, fa| [fa[:name], a] }.to_h
          values.merge!(options.map { |k, v| [k.to_s, v] }.to_h)
          func.args.each do |fa|
            values[fa[:name]] = fa[:default] if values[fa[:name]].nil?
          end

          arg_types = func.args.map { |a| [a[:name], a[:type]] }.to_h

          values.each_key do |k|
            v = values[k]
            # strip alias annotations such as "Tensor(a!)"
            t = arg_types[k].split("(").first

            good =
              case t
              when "Tensor"
                v.is_a?(Tensor)
              when "Tensor?"
                v.nil? || v.is_a?(Tensor)
              when "Tensor[]", "Tensor?[]"
                v.is_a?(Array) && v.all? { |v2| v2.is_a?(Tensor) }
              when "int"
                if k == "reduction"
                  v.is_a?(String)
                else
                  v.is_a?(Integer)
                end
              when "int?"
                v.is_a?(Integer) || v.nil?
              when "float"
                v.is_a?(Numeric)
              when /int\[.*\]/
                # broadcast a scalar to the fixed-size list, e.g. int[2]
                if v.is_a?(Integer)
                  size = t[4..-2]
                  raise Error, "Unknown size: #{size}. Please report a bug with #{@name}." unless size =~ /\A\d+\z/
                  v = [v] * size.to_i
                  values[k] = v
                end
                v.is_a?(Array) && v.all? { |v2| v2.is_a?(Integer) }
              when "Scalar"
                v.is_a?(Numeric)
              when "ScalarType?"
                v.nil?
              when "bool"
                v == true || v == false
              when "str"
                v.is_a?(String)
              else
                raise Error, "Unknown argument type: #{arg_types[k]}. Please report a bug with #{@name}."
              end

            unless good
              # only report a type error when there is no other overload to try
              if candidates.size == 1
                k = "input" if k == "self"
                return {error: "#{@name}(): argument '#{k}' must be #{t}"}
              end
              break
            end
          end

          final_values = values if good
          good
        end

        if candidates.size != 1
          raise Error, "This should never happen. Please report a bug with #{@name}."
        end

        func = candidates.first
        call_args = func.args.map { |a| final_values[a[:name]] }
        # NOTE(review): 6 is assumed to be ScalarType::Float. dtype/device
        # keywords are stripped from the schema, so they are not forwarded.
        call_args << TensorOptions.new.dtype(6) if func.tensor_options
        {
          name: func.cpp_name,
          args: call_args
        }
      end
    end
  end
end
|