torch-rb 0.3.7 → 0.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +1 -1
- data/codegen/function.rb +134 -0
- data/codegen/generate_functions.rb +546 -0
- data/{lib/torch/native → codegen}/native_functions.yaml +0 -0
- data/ext/torch/ext.cpp +54 -75
- data/ext/torch/extconf.rb +2 -2
- data/ext/torch/nn_functions.h +6 -0
- data/ext/torch/ruby_arg_parser.cpp +593 -0
- data/ext/torch/ruby_arg_parser.h +373 -0
- data/ext/torch/{templates.hpp → templates.h} +30 -51
- data/ext/torch/tensor_functions.h +6 -0
- data/ext/torch/torch_functions.h +6 -0
- data/ext/torch/utils.h +42 -0
- data/ext/torch/{templates.cpp → wrap_outputs.h} +16 -15
- data/lib/torch.rb +0 -62
- data/lib/torch/nn/functional.rb +30 -16
- data/lib/torch/nn/init.rb +5 -19
- data/lib/torch/optim/adadelta.rb +1 -1
- data/lib/torch/optim/adam.rb +2 -2
- data/lib/torch/optim/adamax.rb +1 -1
- data/lib/torch/optim/adamw.rb +1 -1
- data/lib/torch/optim/asgd.rb +1 -1
- data/lib/torch/optim/sgd.rb +3 -3
- data/lib/torch/tensor.rb +25 -105
- data/lib/torch/version.rb +1 -1
- metadata +27 -9
- data/lib/torch/native/dispatcher.rb +0 -70
- data/lib/torch/native/function.rb +0 -200
- data/lib/torch/native/generator.rb +0 -178
- data/lib/torch/native/parser.rb +0 -117
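The headline change is visible in the file list: the Ruby-side dispatch layer under lib/torch/native (dispatcher.rb, function.rb, generator.rb, parser.rb, shown deleted below) is removed, code generation moves to a new top-level codegen/ directory, and argument parsing moves into the C++ extension via the new ext/torch/ruby_arg_parser.{cpp,h}. The public API is meant to be unchanged; a minimal sketch of a call whose dispatch path moved (assuming a working install of the gem):

```ruby
require "torch"

x = Torch.tensor([1, 2, 3])

# In 0.3.7, this call was matched to an overload by Torch::Native::Parser
# in Ruby; in 0.4.0 the same resolution happens in ruby_arg_parser.cpp.
y = Torch.add(x, x, alpha: 2)
```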
data/lib/torch/native/generator.rb DELETED
@@ -1,178 +0,0 @@
-require "yaml"
-# use require_relative for
-# rake generate:function (without bundle)
-require_relative "function"
-
-module Torch
-  module Native
-    module Generator
-      class << self
-        def generate_cpp_functions
-          functions = grouped_functions
-          generate_cpp_file("torch", :define_singleton_method, functions[:torch])
-          generate_cpp_file("tensor", :define_method, functions[:tensor])
-          generate_cpp_file("nn", :define_singleton_method, functions[:nn])
-        end
-
-        def grouped_functions
-          functions = functions()
-
-          # skip functions
-          skip_args = ["Layout", "Storage", "ConstQuantizerPtr"]
-
-          # remove functions
-          functions.reject! do |f|
-            f.ruby_name.start_with?("_") ||
-              f.ruby_name.include?("_backward") ||
-              f.args.any? { |a| a[:type].include?("Dimname") }
-          end
-
-          # separate out into todo
-          todo_functions, functions =
-            functions.partition do |f|
-              f.args.any? do |a|
-                skip_args.any? { |sa| a[:type].include?(sa) } ||
-                  # call to 'range' is ambiguous
-                  f.cpp_name == "_range" ||
-                  # native_functions.yaml is missing size argument for normal
-                  # https://pytorch.org/cppdocs/api/function_namespacetorch_1a80253fe5a3ded4716ec929a348adb4b9.html
-                  (f.base_name == "normal" && !f.out?)
-              end
-            end
-
-          # todo_functions.each do |f|
-          #   puts f.func
-          #   puts
-          # end
-
-          nn_functions, other_functions = functions.partition { |f| f.python_module == "nn" }
-          torch_functions = other_functions.select { |f| f.variants.include?("function") }
-          tensor_functions = other_functions.select { |f| f.variants.include?("method") }
-
-          {torch: torch_functions, tensor: tensor_functions, nn: nn_functions}
-        end
-
-        private
-
-        def generate_cpp_file(type, def_method, functions)
-          hpp_template = <<-TEMPLATE
-// generated by rake generate:functions
-// do not edit by hand
-
-#pragma once
-
-void add_%{type}_functions(Module m);
-          TEMPLATE
-
-          cpp_template = <<-TEMPLATE
-// generated by rake generate:functions
-// do not edit by hand
-
-#include <torch/torch.h>
-#include <rice/Module.hpp>
-#include "templates.hpp"
-
-%{functions}
-
-void add_%{type}_functions(Module m) {
-  %{add_functions}
-}
-          TEMPLATE
-
-          cpp_defs = []
-          add_defs = []
-          functions.sort_by(&:cpp_name).each do |func|
-            fargs = func.args.dup #.select { |a| a[:type] != "Generator?" }
-            fargs << {name: :options, type: "TensorOptions"} if func.tensor_options
-
-            cpp_args = []
-            fargs.each do |a|
-              t =
-                case a[:type]
-                when "Tensor"
-                  "const Tensor &"
-                when "Tensor?"
-                  # TODO better signature
-                  "OptionalTensor"
-                when "ScalarType?"
-                  "torch::optional<ScalarType>"
-                when "Tensor[]", "Tensor?[]"
-                  # TODO make optional
-                  "std::vector<Tensor>"
-                when "int"
-                  "int64_t"
-                when "int?"
-                  "torch::optional<int64_t>"
-                when "float?"
-                  "torch::optional<double>"
-                when "bool?"
-                  "torch::optional<bool>"
-                when "Scalar?"
-                  "torch::optional<torch::Scalar>"
-                when "float"
-                  "double"
-                when /\Aint\[/
-                  "std::vector<int64_t>"
-                when /Tensor\(\S!?\)/
-                  "Tensor &"
-                when "str"
-                  "std::string"
-                when "TensorOptions"
-                  "const torch::TensorOptions &"
-                when "Layout?"
-                  "torch::optional<Layout>"
-                when "Device?"
-                  "torch::optional<Device>"
-                when "Scalar", "bool", "ScalarType", "Layout", "Device", "Storage"
-                  a[:type]
-                else
-                  raise "Unknown type: #{a[:type]}"
-                end
-
-              t = "MyReduction" if a[:name] == :reduction && t == "int64_t"
-              cpp_args << [t, a[:name]].join(" ").sub("& ", "&")
-            end
-
-            dispatch = func.out? ? "#{func.base_name}_out" : func.base_name
-            args = fargs.map { |a| a[:name] }
-            args.unshift(*args.pop(func.out_size)) if func.out?
-            args.delete(:self) if def_method == :define_method
-
-            prefix = def_method == :define_method ? "self." : "torch::"
-
-            body = "#{prefix}#{dispatch}(#{args.join(", ")})"
-
-            if func.cpp_name == "_fill_diagonal_"
-              body = "to_ruby<torch::Tensor>(#{body})"
-            elsif !func.ret_void?
-              body = "wrap(#{body})"
-            end
-
-            cpp_defs << "// #{func.func}
-static #{func.ret_void? ? "void" : "Object"} #{type}#{func.cpp_name}(#{cpp_args.join(", ")})
-{
-  return #{body};
-}"
-
-            add_defs << "m.#{def_method}(\"#{func.cpp_name}\", #{type}#{func.cpp_name});"
-          end
-
-          hpp_contents = hpp_template % {type: type}
-          cpp_contents = cpp_template % {type: type, functions: cpp_defs.join("\n\n"), add_functions: add_defs.join("\n ")}
-
-          path = File.expand_path("../../../ext/torch", __dir__)
-          File.write("#{path}/#{type}_functions.hpp", hpp_contents)
-          File.write("#{path}/#{type}_functions.cpp", cpp_contents)
-        end
-
-        def functions
-          @native_functions ||= YAML.load_file(path).map { |f| Function.new(f) }
-        end
-
-        def path
-          File.expand_path("native_functions.yaml", __dir__)
-        end
-      end
-    end
-  end
-end
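For reference, the deleted generator was driven by a rake task (the template comments above name it as rake generate:functions). A minimal sketch of the old flow, assuming the require path matches the file's location under lib/:

```ruby
# Hypothetical driver for the deleted generator (0.3.7 layout).
require "torch/native/generator"

# Reads native_functions.yaml from the same directory, filters out
# unsupported signatures (Dimname, Layout, Storage, ...), groups the rest
# into torch/tensor/nn, and writes {type}_functions.{hpp,cpp} files under
# ext/torch for the extension build to compile.
Torch::Native::Generator.generate_cpp_functions
```

In 0.4.0 this job moves to codegen/generate_functions.rb (with native_functions.yaml relocated alongside it), and the generated headers show up in the file list above as tensor_functions.h, torch_functions.h, and nn_functions.h.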
data/lib/torch/native/parser.rb DELETED
@@ -1,117 +0,0 @@
-module Torch
-  module Native
-    class Parser
-      def initialize(functions)
-        @functions = functions
-        @name = @functions.first.ruby_name
-        @min_args = @functions.map { |f| f.args.count { |a| a[:pos] && !a[:has_default] } }.min
-        @max_args = @functions.map { |f| f.args.count { |a| a[:pos] } }.max
-        @int_array_first = @functions.all? { |c| c.args.first && c.args.first[:type] == "int[]" }
-      end
-
-      # TODO improve performance
-      # possibly move to C++ (see python_arg_parser.cpp)
-      def parse(args, options)
-        candidates = @functions.dup
-
-        # TODO check candidates individually to see if they match
-        if @int_array_first
-          int_args = []
-          while args.first.is_a?(Integer)
-            int_args << args.shift
-          end
-          if int_args.any?
-            raise ArgumentError, "argument '#{candidates.first.args.first[:name]}' must be array of ints, but found element of type #{args.first.class.name} at pos #{int_args.size + 1}" if args.any?
-            args.unshift(int_args)
-          end
-        end
-
-        # TODO account for args passed as options here
-        if args.size < @min_args || args.size > @max_args
-          expected = String.new(@min_args.to_s)
-          expected += "..#{@max_args}" if @max_args != @min_args
-          return {error: "wrong number of arguments (given #{args.size}, expected #{expected})"}
-        end
-
-        candidates.reject! { |f| args.size > f.args.size }
-
-        # handle out with multiple
-        # there should only be one match, so safe to modify all
-        if options[:out]
-          if (out_func = candidates.find { |f| f.out? }) && out_func.out_size > 1
-            out_args = out_func.args.last(2).map { |a| a[:name] }
-            out_args.zip(options.delete(:out)).each do |k, v|
-              options[k] = v
-            end
-            candidates = [out_func]
-          end
-        else
-          # exclude functions missing required options
-          candidates.reject!(&:out?)
-        end
-
-        final_values = nil
-
-        # check args
-        while (func = candidates.shift)
-          good = true
-
-          # set values
-          # TODO use array instead of hash?
-          values = {}
-          args.each_with_index do |a, i|
-            values[func.arg_names[i]] = a
-          end
-          options.each do |k, v|
-            values[k] = v
-          end
-          func.arg_defaults.each do |k, v|
-            values[k] = v unless values.key?(k)
-          end
-          func.int_array_lengths.each do |k, len|
-            values[k] = [values[k]] * len if values[k].is_a?(Integer)
-          end
-
-          arg_checkers = func.arg_checkers
-
-          values.each_key do |k|
-            unless arg_checkers.key?(k)
-              good = false
-              if candidates.empty?
-                # TODO show all bad keywords at once like Ruby?
-                return {error: "unknown keyword: #{k}"}
-              end
-              break
-            end
-
-            unless arg_checkers[k].call(values[k])
-              good = false
-              if candidates.empty?
-                t = func.arg_types[k]
-                k = :input if k == :self
-                return {error: "#{@name}(): argument '#{k}' must be #{t}"}
-              end
-              break
-            end
-          end
-
-          if good
-            final_values = values
-            break
-          end
-        end
-
-        unless final_values
-          raise Error, "This should never happen. Please report a bug with #{@name}."
-        end
-
-        args = func.arg_names.map { |k| final_values[k] }
-        args << TensorOptions.new.dtype(6) if func.tensor_options
-        {
-          name: func.cpp_name,
-          args: args
-        }
-      end
-    end
-  end
-end
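The contract between this parser and the (also deleted) dispatcher.rb is what the new C++ parser has to replicate: given the overloads registered for one name, parse returns either the resolved C++ function name with arguments in declaration order, or an error hash. A rough sketch, with a hypothetical `overloads` standing in for the Torch::Native::Function objects of one ruby_name:

```ruby
# Hypothetical usage of the deleted parser (0.3.7 layout).
parser = Torch::Native::Parser.new(overloads)

result = parser.parse([x, y], {alpha: 2})
# On success: {name: "add", args: [x, y, 2]} — the matched cpp_name plus
# values in declaration order, ready to send to the generated C++ function.
# On failure: {error: "add(): argument 'alpha' must be Scalar"} or similar.
```

The TODO at the top of parse ("possibly move to C++ (see python_arg_parser.cpp)") is exactly what 0.4.0 does: the new ext/torch/ruby_arg_parser.{cpp,h} port this overload-matching logic to C++, following the python_arg_parser design from PyTorch itself.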