torch-rb 0.3.7 → 0.4.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +1 -1
- data/codegen/function.rb +134 -0
- data/codegen/generate_functions.rb +546 -0
- data/{lib/torch/native → codegen}/native_functions.yaml +0 -0
- data/ext/torch/ext.cpp +54 -75
- data/ext/torch/extconf.rb +2 -2
- data/ext/torch/nn_functions.h +6 -0
- data/ext/torch/ruby_arg_parser.cpp +593 -0
- data/ext/torch/ruby_arg_parser.h +373 -0
- data/ext/torch/{templates.hpp → templates.h} +30 -51
- data/ext/torch/tensor_functions.h +6 -0
- data/ext/torch/torch_functions.h +6 -0
- data/ext/torch/utils.h +42 -0
- data/ext/torch/{templates.cpp → wrap_outputs.h} +16 -15
- data/lib/torch.rb +0 -62
- data/lib/torch/nn/functional.rb +30 -16
- data/lib/torch/nn/init.rb +5 -19
- data/lib/torch/optim/adadelta.rb +1 -1
- data/lib/torch/optim/adam.rb +2 -2
- data/lib/torch/optim/adamax.rb +1 -1
- data/lib/torch/optim/adamw.rb +1 -1
- data/lib/torch/optim/asgd.rb +1 -1
- data/lib/torch/optim/sgd.rb +3 -3
- data/lib/torch/tensor.rb +25 -105
- data/lib/torch/version.rb +1 -1
- metadata +27 -9
- data/lib/torch/native/dispatcher.rb +0 -70
- data/lib/torch/native/function.rb +0 -200
- data/lib/torch/native/generator.rb +0 -178
- data/lib/torch/native/parser.rb +0 -117
@@ -1,178 +0,0 @@
|
|
1
|
-
require "yaml"
|
2
|
-
# use require_relative for
|
3
|
-
# rake generate:function (without bundle)
|
4
|
-
require_relative "function"
|
5
|
-
|
6
|
-
module Torch
  module Native
    # Code generator that turns the entries of native_functions.yaml into
    # Rice binding sources (<type>_functions.hpp / <type>_functions.cpp)
    # for the torch, tensor and nn groups. Run via `rake generate:functions`.
    module Generator
      class << self
        # Writes the torch, tensor and nn binding files into ext/torch.
        # torch and nn functions are registered with define_singleton_method,
        # tensor functions with define_method.
        def generate_cpp_functions
          functions = grouped_functions
          generate_cpp_file("torch", :define_singleton_method, functions[:torch])
          generate_cpp_file("tensor", :define_method, functions[:tensor])
          generate_cpp_file("nn", :define_singleton_method, functions[:nn])
        end

        # Filters the parsed native functions down to the bindable ones and
        # groups them by target.
        #
        # Does not modify the memoized list returned by +functions+.
        #
        # @return [Hash{Symbol => Array<Function>}] keys :torch, :tensor, :nn.
        #   A non-nn function whose variants include both "function" and
        #   "method" appears under both :torch and :tensor.
        def grouped_functions
          # remove functions: underscore-prefixed, backward, and Dimname overloads
          candidates = functions.reject do |f|
            f.ruby_name.start_with?("_") ||
              f.ruby_name.include?("_backward") ||
              f.args.any? { |a| a[:type].include?("Dimname") }
          end

          # separate out into todo (not generated); to inspect them:
          #   todo_functions.each { |f| puts f.func; puts }
          _todo_functions, supported = candidates.partition { |f| unsupported?(f) }

          nn_functions, other_functions = supported.partition { |f| f.python_module == "nn" }
          torch_functions = other_functions.select { |f| f.variants.include?("function") }
          tensor_functions = other_functions.select { |f| f.variants.include?("method") }

          {torch: torch_functions, tensor: tensor_functions, nn: nn_functions}
        end

        private

        # Whether the generator must skip +f+ for now.
        # These checks are per function and independent of any single argument,
        # so they also apply to functions without arguments.
        def unsupported?(f)
          # call to 'range' is ambiguous
          return true if f.cpp_name == "_range"
          # native_functions.yaml is missing size argument for normal
          # https://pytorch.org/cppdocs/api/function_namespacetorch_1a80253fe5a3ded4716ec929a348adb4b9.html
          return true if f.base_name == "normal" && !f.out?

          # argument types the bindings cannot handle yet
          skip_args = ["Layout", "Storage", "ConstQuantizerPtr"]
          f.args.any? { |a| skip_args.any? { |sa| a[:type].include?(sa) } }
        end

        # Renders the header and source for one binding group and writes them
        # to ext/torch/<type>_functions.hpp and ext/torch/<type>_functions.cpp.
        #
        # @param type [String] group name; also prefixes every generated C++ function
        # @param def_method [Symbol] Rice registration method (:define_method or
        #   :define_singleton_method)
        # @param functions [Array<Function>] functions to bind
        def generate_cpp_file(type, def_method, functions)
          hpp_template = <<-TEMPLATE
// generated by rake generate:functions
// do not edit by hand

#pragma once

void add_%{type}_functions(Module m);
          TEMPLATE

          cpp_template = <<-TEMPLATE
// generated by rake generate:functions
// do not edit by hand

#include <torch/torch.h>
#include <rice/Module.hpp>
#include "templates.hpp"

%{functions}

void add_%{type}_functions(Module m) {
%{add_functions}
}
          TEMPLATE

          cpp_defs = []
          add_defs = []
          functions.sort_by(&:cpp_name).each do |func|
            cpp_defs << cpp_definition(type, def_method, func)
            add_defs << "m.#{def_method}(\"#{func.cpp_name}\", #{type}#{func.cpp_name});"
          end

          hpp_contents = hpp_template % {type: type}
          cpp_contents = cpp_template % {type: type, functions: cpp_defs.join("\n\n"), add_functions: add_defs.join("\n ")}

          # named ext_dir so it does not shadow the private #path method
          ext_dir = File.expand_path("../../../ext/torch", __dir__)
          File.write("#{ext_dir}/#{type}_functions.hpp", hpp_contents)
          File.write("#{ext_dir}/#{type}_functions.cpp", cpp_contents)
        end

        # Builds the C++ wrapper for one function, preceded by its schema
        # (func.func) as a comment.
        #
        # @return [String] C++ source for the wrapper (no trailing newline)
        # @raise [RuntimeError] via cpp_type for unknown argument types
        def cpp_definition(type, def_method, func)
          fargs = func.args.dup
          fargs << {name: :options, type: "TensorOptions"} if func.tensor_options

          cpp_args = fargs.map do |a|
            t = cpp_type(a[:type])
            t = "MyReduction" if a[:name] == :reduction && t == "int64_t"
            [t, a[:name]].join(" ").sub("& ", "&")
          end

          dispatch = func.out? ? "#{func.base_name}_out" : func.base_name
          args = fargs.map { |a| a[:name] }
          # for *_out dispatch the trailing out_size output args are passed first
          args.unshift(*args.pop(func.out_size)) if func.out?
          # instance methods call through self, so self is not an argument
          args.delete(:self) if def_method == :define_method

          prefix = def_method == :define_method ? "self." : "torch::"
          body = "#{prefix}#{dispatch}(#{args.join(", ")})"

          if func.cpp_name == "_fill_diagonal_"
            body = "to_ruby<torch::Tensor>(#{body})"
          elsif !func.ret_void?
            body = "wrap(#{body})"
          end

          "// #{func.func}\n" \
            "static #{func.ret_void? ? "void" : "Object"} #{type}#{func.cpp_name}(#{cpp_args.join(", ")})\n" \
            "{\n" \
            "return #{body};\n" \
            "}"
        end

        # Maps a native_functions.yaml argument type to a C++ parameter type.
        # Branch order matters for the regex patterns and is kept as is.
        #
        # @raise [RuntimeError] for an unknown type
        def cpp_type(type)
          case type
          when "Tensor" then "const Tensor &"
          when "Tensor?" then "OptionalTensor" # TODO better signature
          when "ScalarType?" then "torch::optional<ScalarType>"
          when "Tensor[]", "Tensor?[]" then "std::vector<Tensor>" # TODO make optional
          when "int" then "int64_t"
          when "int?" then "torch::optional<int64_t>"
          when "float?" then "torch::optional<double>"
          when "bool?" then "torch::optional<bool>"
          when "Scalar?" then "torch::optional<torch::Scalar>"
          when "float" then "double"
          when /\Aint\[/ then "std::vector<int64_t>"
          when /Tensor\(\S!?\)/ then "Tensor &"
          when "str" then "std::string"
          when "TensorOptions" then "const torch::TensorOptions &"
          when "Layout?" then "torch::optional<Layout>"
          when "Device?" then "torch::optional<Device>"
          when "Scalar", "bool", "ScalarType", "Layout", "Device", "Storage" then type
          else raise "Unknown type: #{type}"
          end
        end

        # Parsed entries of native_functions.yaml, memoized. The YAML file is
        # bundled with the gem, not external input.
        def functions
          @native_functions ||= YAML.load_file(path).map { |f| Function.new(f) }
        end

        # Location of native_functions.yaml (next to this file).
        def path
          File.expand_path("native_functions.yaml", __dir__)
        end
      end
    end
  end
end
|
data/lib/torch/native/parser.rb
DELETED
@@ -1,117 +0,0 @@
|
|
1
|
-
module Torch
  module Native
    # Resolves a Ruby call (positional args + keyword options) against the
    # overloads of one native function. Returns the chosen C++ function name
    # with its arguments in declaration order.
    class Parser
      # @param functions [Array<Function>] overloads sharing one Ruby name;
      #   assumed non-empty (the first one supplies the name)
      def initialize(functions)
        @functions = functions
        @name = @functions.first.ruby_name
        # range of accepted positional-argument counts across all overloads
        @min_args = @functions.map { |f| f.args.count { |a| a[:pos] && !a[:has_default] } }.min
        @max_args = @functions.map { |f| f.args.count { |a| a[:pos] } }.max
        # when every overload starts with int[], leading Integer args are
        # collapsed into that array, e.g. f(2, 3) -> f([2, 3])
        @int_array_first = @functions.all? { |c| c.args.first && c.args.first[:type] == "int[]" }
      end

      # TODO improve performance
      # possibly move to C++ (see python_arg_parser.cpp)
      #
      # Picks the first overload whose arg checkers accept every supplied value.
      # Mutates +args+ (when collapsing leading ints). Also mutates +options+
      # (when expanding a multi-output :out into the named output args).
      #
      # @param args [Array] positional arguments
      # @param options [Hash{Symbol => Object}] keyword arguments
      # @return [Hash] {name: cpp_name, args: [...]} on success, or {error: message}
      # @raise [ArgumentError] if leading ints are followed by other positional args
      # @raise [Error] if no overload matched but no error message was produced
      def parse(args, options)
        candidates = @functions.dup

        # TODO check candidates individually to see if they match
        if @int_array_first
          int_args = []
          while args.first.is_a?(Integer)
            int_args << args.shift
          end
          if int_args.any?
            raise ArgumentError, "argument '#{candidates.first.args.first[:name]}' must be array of ints, but found element of type #{args.first.class.name} at pos #{int_args.size + 1}" if args.any?
            args.unshift(int_args)
          end
        end

        # TODO account for args passed as options here
        if args.size < @min_args || args.size > @max_args
          expected = String.new(@min_args.to_s)
          expected += "..#{@max_args}" if @max_args != @min_args
          return {error: "wrong number of arguments (given #{args.size}, expected #{expected})"}
        end

        candidates.reject! { |f| args.size > f.args.size }

        # handle out with multiple
        # there should only be one match, so safe to modify all
        if options[:out]
          if (out_func = candidates.find { |f| f.out? }) && out_func.out_size > 1
            # The outputs are the last out_size declared args. Map every element
            # of the :out array onto its own named arg (not a fixed two).
            out_args = out_func.args.last(out_func.out_size).map { |a| a[:name] }
            out_args.zip(options.delete(:out)).each do |k, v|
              options[k] = v
            end
            candidates = [out_func]
          end
        else
          # exclude functions missing required options
          candidates.reject!(&:out?)
        end

        final_values = nil

        # check args: the first candidate that accepts every value wins; the
        # error message is only reported for the last remaining candidate
        while (func = candidates.shift)
          good = true

          # set values
          # TODO use array instead of hash?
          values = {}
          args.each_with_index do |a, i|
            values[func.arg_names[i]] = a
          end
          options.each do |k, v|
            values[k] = v
          end
          func.arg_defaults.each do |k, v|
            values[k] = v unless values.key?(k)
          end
          # an Integer given for a fixed-length int[] is repeated to that length
          func.int_array_lengths.each do |k, len|
            values[k] = [values[k]] * len if values[k].is_a?(Integer)
          end

          arg_checkers = func.arg_checkers

          values.each_key do |k|
            unless arg_checkers.key?(k)
              good = false
              if candidates.empty?
                # TODO show all bad keywords at once like Ruby?
                return {error: "unknown keyword: #{k}"}
              end
              break
            end

            unless arg_checkers[k].call(values[k])
              good = false
              if candidates.empty?
                t = func.arg_types[k]
                k = :input if k == :self
                return {error: "#{@name}(): argument '#{k}' must be #{t}"}
              end
              break
            end
          end

          if good
            final_values = values
            break
          end
        end

        unless final_values
          raise Error, "This should never happen. Please report a bug with #{@name}."
        end

        # func is still bound to the matching overload after the loop breaks
        args = func.arg_names.map { |k| final_values[k] }
        # dtype code 6 -- presumably float32 in torch's ScalarType enum; TODO confirm
        args << TensorOptions.new.dtype(6) if func.tensor_options
        {
          name: func.cpp_name,
          args: args
        }
      end
    end
  end
end
|