torch-rb 0.3.2 → 0.3.7
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +28 -0
- data/README.md +7 -2
- data/ext/torch/ext.cpp +60 -20
- data/ext/torch/extconf.rb +3 -0
- data/ext/torch/templates.cpp +36 -0
- data/ext/torch/templates.hpp +81 -87
- data/lib/torch.rb +71 -19
- data/lib/torch/native/dispatcher.rb +30 -8
- data/lib/torch/native/function.rb +93 -4
- data/lib/torch/native/generator.rb +45 -41
- data/lib/torch/native/parser.rb +57 -76
- data/lib/torch/nn/functional.rb +112 -2
- data/lib/torch/nn/leaky_relu.rb +3 -3
- data/lib/torch/nn/module.rb +9 -1
- data/lib/torch/nn/upsample.rb +31 -0
- data/lib/torch/tensor.rb +45 -51
- data/lib/torch/utils/data/data_loader.rb +2 -0
- data/lib/torch/utils/data/tensor_dataset.rb +2 -0
- data/lib/torch/version.rb +1 -1
- metadata +3 -2
data/lib/torch.rb
CHANGED
@@ -174,6 +174,9 @@ require "torch/nn/smooth_l1_loss"
|
|
174
174
|
require "torch/nn/soft_margin_loss"
|
175
175
|
require "torch/nn/triplet_margin_loss"
|
176
176
|
|
177
|
+
# nn vision
|
178
|
+
require "torch/nn/upsample"
|
179
|
+
|
177
180
|
# nn other
|
178
181
|
require "torch/nn/functional"
|
179
182
|
require "torch/nn/init"
|
@@ -196,6 +199,32 @@ module Torch
|
|
196
199
|
end
|
197
200
|
end
|
198
201
|
|
202
|
+
# legacy
|
203
|
+
# but may make it easier to port tutorials
|
204
|
+
module Autograd
  # Stand-in for the deprecated Variable wrapper.
  class Variable
    class << self
      # Checks that +x+ is a Tensor, prints a deprecation warning and hands
      # the very same object back (no wrapper instance is ever created).
      #
      # Raises ArgumentError when +x+ is not a Tensor.
      def new(x)
        unless x.is_a?(Tensor)
          raise ArgumentError, "Variable data has to be a tensor, but got #{x.class.name}"
        end

        warn "[torch] The Variable API is deprecated. Use tensors with requires_grad: true instead."
        x
      end
    end
  end
end
|
213
|
+
|
214
|
+
# TODO move to C++
|
215
|
+
# Thin wrapper around a raw binary String.
class ByteStorage
  # The wrapped String. This reader has to stay public: the generated
  # tensor classes call `bytes` on the storage they are given.
  attr_reader :bytes

  # Wraps +bytes+ without copying it.
  #
  # Raises ArgumentError when +bytes+ is not a String, so that a wrong
  # argument fails here with a clear message instead of later, when
  # `bytesize` is called on it.
  def initialize(bytes)
    unless bytes.is_a?(String)
      raise ArgumentError, "ByteStorage data has to be a string, but got #{bytes.class.name}"
    end

    @bytes = bytes
  end

  # Alternate constructor; equivalent to ByteStorage.new(bytes).
  def self.from_buffer(bytes)
    new(bytes)
  end
end
|
227
|
+
|
199
228
|
# keys: https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.dtype
|
200
229
|
# values: https://github.com/pytorch/pytorch/blob/master/c10/core/ScalarType.h
|
201
230
|
DTYPE_TO_ENUM = {
|
@@ -224,40 +253,43 @@ module Torch
|
|
224
253
|
}
|
225
254
|
ENUM_TO_DTYPE = DTYPE_TO_ENUM.map(&:reverse).to_h
|
226
255
|
|
256
|
+
TENSOR_TYPE_CLASSES = []
|
257
|
+
|
227
258
|
# Creates an anonymous class whose `new` builds a tensor of the given
# +dtype+ on "cuda" (when +cuda+ is true) or "cpu", records the class in
# TENSOR_TYPE_CLASSES and returns it.
#
# `new` dispatches on its arguments:
# * a single Tensor       -> converted via `send(dtype)` and moved with `to(device)`
# * a single ByteStorage  -> (uint8 only) built from the storage's bytes
# * a single Array        -> Torch.tensor
# * anything else         -> Torch.empty(*args)
def self._make_tensor_class(dtype, cuda = false)
  tensor_class = Class.new
  device = cuda ? "cuda" : "cpu"

  tensor_class.define_singleton_method("new") do |*args|
    single = args.size == 1
    arg = args.first

    if single && arg.is_a?(Tensor)
      arg.send(dtype).to(device)
    elsif single && arg.is_a?(ByteStorage) && dtype == :uint8
      raw = arg.bytes
      Torch._from_blob(raw, [raw.bytesize], TensorOptions.new.dtype(DTYPE_TO_ENUM[dtype]))
    elsif single && arg.is_a?(Array)
      Torch.tensor(arg, dtype: dtype, device: device)
    else
      Torch.empty(*args, dtype: dtype, device: device)
    end
  end

  TENSOR_TYPE_CLASSES << tensor_class
  tensor_class
end
|
241
276
|
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
253
|
-
|
254
|
-
|
255
|
-
|
256
|
-
|
257
|
-
|
258
|
-
CUDA::IntTensor = _make_tensor_class(:int32, true)
|
259
|
-
CUDA::LongTensor = _make_tensor_class(:int64, true)
|
260
|
-
CUDA::BoolTensor = _make_tensor_class(:bool, true)
|
277
|
+
# dtype => name of the tensor-type class defined for that dtype
DTYPE_TO_CLASS = [
  [:float32, "FloatTensor"],
  [:float64, "DoubleTensor"],
  [:float16, "HalfTensor"],
  [:uint8, "ByteTensor"],
  [:int8, "CharTensor"],
  [:int16, "ShortTensor"],
  [:int32, "IntTensor"],
  [:int64, "LongTensor"],
  [:bool, "BoolTensor"]
].to_h

# define every class twice: on this module (cpu) and on CUDA
DTYPE_TO_CLASS.each_pair do |dtype, class_name|
  cpu_class = _make_tensor_class(dtype, false)
  const_set(class_name, cpu_class)

  cuda_class = _make_tensor_class(dtype, true)
  CUDA.const_set(class_name, cuda_class)
end
|
261
293
|
|
262
294
|
class << self
|
263
295
|
# Torch.float, Torch.long, etc
|
@@ -388,6 +420,10 @@ module Torch
|
|
388
420
|
end
|
389
421
|
|
390
422
|
# Calls _randperm with +n+ and the tensor options built from +options+.
# When no (or a nil/false) :dtype is given, :int64 is used.
def randperm(n, **options)
  # dtype hack in Python
  # https://github.com/pytorch/pytorch/blob/v1.6.0/tools/autograd/gen_python_functions.py#L1307-L1311
  options[:dtype] = :int64 unless options[:dtype]

  built_options = tensor_options(**options)
  _randperm(n, built_options)
end
|
393
429
|
|
@@ -460,6 +496,22 @@ module Torch
|
|
460
496
|
zeros(input.size, **like_options(input, options))
|
461
497
|
end
|
462
498
|
|
499
|
+
# Optionally pads +input+ on both sides by n_fft / 2 (when +center+ is true,
# using +pad_mode+) and then delegates to _stft.
#
# input      - object responding to dim, size, view and shape
# n_fft      - numeric; half of it (integer division) is the pad width
# hop_length, win_length, window, normalized, onesided - passed through to
#              _stft unchanged
def stft(input, n_fft, hop_length: nil, win_length: nil, window: nil, center: true, pad_mode: "reflect", normalized: false, onesided: true)
  if center
    signal_dim = input.dim
    # View the signal with leading singleton dimensions so it has at least
    # three dimensions before padding. Array#* raises ArgumentError for a
    # negative count, so the count is clamped at zero for inputs that
    # already have more than three dimensions.
    leading_dims = [3 - signal_dim, 0].max
    extended_shape = [1] * leading_dims + input.size
    pad = n_fft.div(2) # div already returns an Integer
    input = NN::F.pad(input.view(extended_shape), [pad, pad], mode: pad_mode)
    # drop the leading dimensions that were added above
    input = input.view(input.shape[-signal_dim..-1])
  end
  _stft(input, n_fft, hop_length, win_length, window, normalized, onesided)
end
|
509
|
+
|
510
|
+
# Applies a lower bound and then an upper bound to +tensor+ by calling
# _clamp_min followed by _clamp_max.
#
# Either bound may be nil to clamp on one side only; the corresponding
# call is then skipped.
#
# Raises ArgumentError when both +min+ and +max+ are nil.
def clamp(tensor, min, max)
  raise ArgumentError, "At least one of min or max must not be nil" if min.nil? && max.nil?

  tensor = _clamp_min(tensor, min) unless min.nil?
  tensor = _clamp_max(tensor, max) unless max.nil?
  tensor
end
|
514
|
+
|
463
515
|
private
|
464
516
|
|
465
517
|
def to_ivalue(obj)
|
@@ -22,21 +22,43 @@ module Torch
|
|
22
22
|
end
|
23
23
|
|
24
24
|
def bind_functions(context, def_method, functions)
|
25
|
+
instance_method = def_method == :define_method
|
25
26
|
functions.group_by(&:ruby_name).sort_by { |g, _| g }.each do |name, funcs|
|
26
|
-
if
|
27
|
+
if instance_method
|
27
28
|
funcs.map! { |f| Function.new(f.function) }
|
28
|
-
funcs.each { |f| f.args.reject! { |a| a[:name] ==
|
29
|
+
funcs.each { |f| f.args.reject! { |a| a[:name] == :self } }
|
29
30
|
end
|
30
31
|
|
31
|
-
defined =
|
32
|
+
defined = instance_method ? context.method_defined?(name) : context.respond_to?(name)
|
32
33
|
next if defined && name != "clone"
|
33
34
|
|
34
|
-
parser
|
35
|
+
# skip parser when possible for performance
|
36
|
+
if funcs.size == 1 && funcs.first.args.size == 0
|
37
|
+
# functions with no arguments
|
38
|
+
if instance_method
|
39
|
+
context.send(:alias_method, name, funcs.first.cpp_name)
|
40
|
+
else
|
41
|
+
context.singleton_class.send(:alias_method, name, funcs.first.cpp_name)
|
42
|
+
end
|
43
|
+
elsif funcs.size == 2 && funcs.map { |f| f.arg_types.values }.sort == [["Scalar"], ["Tensor"]]
|
44
|
+
# functions that take a tensor or scalar
|
45
|
+
scalar_name, tensor_name = funcs.sort_by { |f| f.arg_types.values }.map(&:cpp_name)
|
46
|
+
context.send(def_method, name) do |other|
|
47
|
+
case other
|
48
|
+
when Tensor
|
49
|
+
send(tensor_name, other)
|
50
|
+
else
|
51
|
+
send(scalar_name, other)
|
52
|
+
end
|
53
|
+
end
|
54
|
+
else
|
55
|
+
parser = Parser.new(funcs)
|
35
56
|
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
57
|
+
context.send(def_method, name) do |*args, **options|
|
58
|
+
result = parser.parse(args, options)
|
59
|
+
raise ArgumentError, result[:error] if result[:error]
|
60
|
+
send(result[:name], *result[:args])
|
61
|
+
end
|
40
62
|
end
|
41
63
|
end
|
42
64
|
end
|
@@ -1,10 +1,15 @@
|
|
1
1
|
module Torch
|
2
2
|
module Native
|
3
3
|
class Function
|
4
|
-
attr_reader :function
|
4
|
+
attr_reader :function, :tensor_options
|
5
5
|
|
6
6
|
# Stores the function definition hash and derives two flags from it:
# @tensor_options - whether the "func" signature ends with the standard
#                   dtype/layout/device/pin_memory keyword arguments
# @out            - whether out_size is positive and base_name does not
#                   end in an underscore
def initialize(function)
  @function = function

  # note: don't modify function in-place
  @tensor_options_str = ", *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None)"
  signature = @function["func"]
  @tensor_options = signature.include?(@tensor_options_str)
  @out = out_size.positive? && !base_name.end_with?("_")
end
|
9
14
|
|
10
15
|
def func
|
@@ -27,7 +32,7 @@ module Torch
|
|
27
32
|
@args ||= begin
|
28
33
|
args = []
|
29
34
|
pos = true
|
30
|
-
args_str = func.split("(", 2).last.split(") ->").first
|
35
|
+
args_str = func.sub(@tensor_options_str, ")").split("(", 2).last.split(") ->").first
|
31
36
|
args_str.split(", ").each do |a|
|
32
37
|
if a == "*"
|
33
38
|
pos = false
|
@@ -68,12 +73,88 @@ module Torch
|
|
68
73
|
next if t == "Generator?"
|
69
74
|
next if t == "MemoryFormat"
|
70
75
|
next if t == "MemoryFormat?"
|
71
|
-
args << {name: k, type: t, default: d, pos: pos, has_default: has_default}
|
76
|
+
args << {name: k.to_sym, type: t, default: d, pos: pos, has_default: has_default}
|
72
77
|
end
|
73
78
|
args
|
74
79
|
end
|
75
80
|
end
|
76
81
|
|
82
|
+
# Builds (once) and returns a Hash mapping each argument name from
# arg_types to a lambda. The lambda takes a Ruby value and returns true
# when the value is acceptable for the argument's declared type.
#
# Raises Error for a declared type that has no checker.
def arg_checkers
  @arg_checkers ||= begin
    checkers = {}
    arg_types.each do |k, t|
      checker =
        case t
        when "Tensor"
          ->(v) { v.is_a?(Tensor) }
        when "Tensor?"
          ->(v) { v.nil? || v.is_a?(Tensor) }
        when "Tensor[]", "Tensor?[]"
          ->(v) { v.is_a?(Array) && v.all? { |v2| v2.is_a?(Tensor) } }
        when "int"
          # an "int" argument named reduction is checked as a String
          if k == :reduction
            ->(v) { v.is_a?(String) }
          else
            ->(v) { v.is_a?(Integer) }
          end
        when "int?"
          ->(v) { v.is_a?(Integer) || v.nil? }
        when "float?"
          ->(v) { v.is_a?(Numeric) || v.nil? }
        when "bool?"
          ->(v) { v == true || v == false || v.nil? }
        when "float"
          ->(v) { v.is_a?(Numeric) }
        when /int\[.*\]/
          ->(v) { v.is_a?(Array) && v.all? { |v2| v2.is_a?(Integer) } }
        when "Scalar"
          ->(v) { v.is_a?(Numeric) }
        when "Scalar?"
          ->(v) { v.is_a?(Numeric) || v.nil? }
        when "ScalarType"
          ->(v) { false } # not supported yet
        when "ScalarType?"
          ->(v) { v.nil? }
        when "bool"
          ->(v) { v == true || v == false }
        when "str"
          ->(v) { v.is_a?(String) }
        else
          # NOTE(review): @name is not assigned in initialize, so
          # interpolating it directly can leave the function name blank.
          # The `name` reader is used instead - confirm it is defined on
          # this class.
          raise Error, "Unknown argument type: #{t}. Please report a bug with #{name}."
        end
      checkers[k] = checker
    end
    checkers
  end
end
|
130
|
+
|
131
|
+
# Builds (once) and returns a Hash mapping the name of every argument
# whose type is a sized int array ("int[N]") to the Integer N.
# Arguments of any other type, including the unsized "int[]", are left out.
#
# Raises Error when the text between the brackets is not all digits.
def int_array_lengths
  @int_array_lengths ||= begin
    ret = {}
    arg_types.each do |k, t|
      next unless t.match?(/\Aint\[.+\]\z/)

      # strip the leading "int[" and the trailing "]"
      size = t[4..-2]
      # NOTE(review): @name is not assigned in initialize, so the `name`
      # reader is used for the message - confirm it is defined on this class.
      raise Error, "Unknown size: #{size}. Please report a bug with #{name}." unless size.match?(/\A\d+\z/)

      ret[k] = size.to_i
    end
    ret
  end
end
|
144
|
+
|
145
|
+
# Names of all arguments, in the order given by args (memoized).
def arg_names
  return @arg_names if @arg_names

  @arg_names = args.map { |arg| arg[:name] }
end
|
148
|
+
|
149
|
+
# Hash of argument name => type (memoized). Anything from the first "("
# onwards is cut off the type, e.g. "Tensor(a!)" becomes "Tensor".
def arg_types
  @arg_types ||= args.each_with_object({}) do |arg, types|
    types[arg[:name]] = arg[:type].split("(").first
  end
end
|
152
|
+
|
153
|
+
# Hash of argument name => default value (memoized). Every argument gets
# an entry; the value is whatever is stored under :default.
def arg_defaults
  # TODO find out why can't use select here
  @arg_defaults ||= args.each_with_object({}) do |arg, defaults|
    defaults[arg[:name]] = arg[:default]
  end
end
|
157
|
+
|
77
158
|
# Number of "!" characters in the part of func after the last "->"
# (memoized).
def out_size
  @out_size ||= begin
    returns = func.split("->").last
    returns.count("!")
  end
end
|
@@ -82,8 +163,16 @@ module Torch
|
|
82
163
|
@ret_size ||= func.split("->").last.split(", ").size
|
83
164
|
end
|
84
165
|
|
166
|
+
# Whether the part of func after the last "->" contains "[]".
#
# The answer is cached with a defined? check rather than ||=, because ||=
# would recompute it on every call whenever the answer is false.
def ret_array?
  return @ret_array if defined?(@ret_array)

  @ret_array = func.split("->").last.include?("[]")
end
|
169
|
+
|
170
|
+
# Whether the part of func after the last "->" is "()" once surrounding
# whitespace is stripped.
def ret_void?
  returns = func.split("->").last
  returns.strip == "()"
end
|
173
|
+
|
85
174
|
def out?
|
86
|
-
|
175
|
+
@out
|
87
176
|
end
|
88
177
|
|
89
178
|
def ruby_name
|
@@ -18,12 +18,12 @@ module Torch
|
|
18
18
|
functions = functions()
|
19
19
|
|
20
20
|
# skip functions
|
21
|
-
skip_args = ["
|
21
|
+
skip_args = ["Layout", "Storage", "ConstQuantizerPtr"]
|
22
22
|
|
23
23
|
# remove functions
|
24
24
|
functions.reject! do |f|
|
25
25
|
f.ruby_name.start_with?("_") ||
|
26
|
-
f.ruby_name.
|
26
|
+
f.ruby_name.include?("_backward") ||
|
27
27
|
f.args.any? { |a| a[:type].include?("Dimname") }
|
28
28
|
end
|
29
29
|
|
@@ -31,32 +31,15 @@ module Torch
|
|
31
31
|
todo_functions, functions =
|
32
32
|
functions.partition do |f|
|
33
33
|
f.args.any? do |a|
|
34
|
-
a[:type].include?("?") && !["Tensor?", "Generator?", "int?", "ScalarType?", "Tensor?[]"].include?(a[:type]) ||
|
35
34
|
skip_args.any? { |sa| a[:type].include?(sa) } ||
|
35
|
+
# call to 'range' is ambiguous
|
36
|
+
f.cpp_name == "_range" ||
|
36
37
|
# native_functions.yaml is missing size argument for normal
|
37
38
|
# https://pytorch.org/cppdocs/api/function_namespacetorch_1a80253fe5a3ded4716ec929a348adb4b9.html
|
38
39
|
(f.base_name == "normal" && !f.out?)
|
39
40
|
end
|
40
41
|
end
|
41
42
|
|
42
|
-
# generate additional functions for optional arguments
|
43
|
-
# there may be a better way to do this
|
44
|
-
optional_functions, functions = functions.partition { |f| f.args.any? { |a| a[:type] == "int?" } }
|
45
|
-
optional_functions.each do |f|
|
46
|
-
next if f.ruby_name == "cross"
|
47
|
-
next if f.ruby_name.start_with?("avg_pool") && f.out?
|
48
|
-
|
49
|
-
opt_args = f.args.select { |a| a[:type] == "int?" }
|
50
|
-
if opt_args.size == 1
|
51
|
-
sep = f.name.include?(".") ? "_" : "."
|
52
|
-
f1 = Function.new(f.function.merge("func" => f.func.sub("(", "#{sep}#{opt_args.first[:name]}(").gsub("int?", "int")))
|
53
|
-
# TODO only remove some arguments
|
54
|
-
f2 = Function.new(f.function.merge("func" => f.func.sub(/, int\?.+\) ->/, ") ->")))
|
55
|
-
functions << f1
|
56
|
-
functions << f2
|
57
|
-
end
|
58
|
-
end
|
59
|
-
|
60
43
|
# todo_functions.each do |f|
|
61
44
|
# puts f.func
|
62
45
|
# puts
|
@@ -89,15 +72,18 @@ void add_%{type}_functions(Module m);
|
|
89
72
|
#include <rice/Module.hpp>
|
90
73
|
#include "templates.hpp"
|
91
74
|
|
75
|
+
%{functions}
|
76
|
+
|
92
77
|
void add_%{type}_functions(Module m) {
|
93
|
-
|
94
|
-
%{functions};
|
78
|
+
%{add_functions}
|
95
79
|
}
|
96
80
|
TEMPLATE
|
97
81
|
|
98
82
|
cpp_defs = []
|
83
|
+
add_defs = []
|
99
84
|
functions.sort_by(&:cpp_name).each do |func|
|
100
|
-
fargs = func.args #.select { |a| a[:type] != "Generator?" }
|
85
|
+
fargs = func.args.dup #.select { |a| a[:type] != "Generator?" }
|
86
|
+
fargs << {name: :options, type: "TensorOptions"} if func.tensor_options
|
101
87
|
|
102
88
|
cpp_args = []
|
103
89
|
fargs.each do |a|
|
@@ -109,52 +95,70 @@ void add_%{type}_functions(Module m) {
|
|
109
95
|
# TODO better signature
|
110
96
|
"OptionalTensor"
|
111
97
|
when "ScalarType?"
|
112
|
-
"
|
113
|
-
when "Tensor[]"
|
114
|
-
"TensorList"
|
115
|
-
when "Tensor?[]"
|
98
|
+
"torch::optional<ScalarType>"
|
99
|
+
when "Tensor[]", "Tensor?[]"
|
116
100
|
# TODO make optional
|
117
|
-
"
|
101
|
+
"std::vector<Tensor>"
|
118
102
|
when "int"
|
119
103
|
"int64_t"
|
104
|
+
when "int?"
|
105
|
+
"torch::optional<int64_t>"
|
106
|
+
when "float?"
|
107
|
+
"torch::optional<double>"
|
108
|
+
when "bool?"
|
109
|
+
"torch::optional<bool>"
|
110
|
+
when "Scalar?"
|
111
|
+
"torch::optional<torch::Scalar>"
|
120
112
|
when "float"
|
121
113
|
"double"
|
122
114
|
when /\Aint\[/
|
123
|
-
"
|
115
|
+
"std::vector<int64_t>"
|
124
116
|
when /Tensor\(\S!?\)/
|
125
117
|
"Tensor &"
|
126
118
|
when "str"
|
127
119
|
"std::string"
|
128
|
-
|
120
|
+
when "TensorOptions"
|
121
|
+
"const torch::TensorOptions &"
|
122
|
+
when "Layout?"
|
123
|
+
"torch::optional<Layout>"
|
124
|
+
when "Device?"
|
125
|
+
"torch::optional<Device>"
|
126
|
+
when "Scalar", "bool", "ScalarType", "Layout", "Device", "Storage"
|
129
127
|
a[:type]
|
128
|
+
else
|
129
|
+
raise "Unknown type: #{a[:type]}"
|
130
130
|
end
|
131
131
|
|
132
|
-
t = "MyReduction" if a[:name] ==
|
132
|
+
t = "MyReduction" if a[:name] == :reduction && t == "int64_t"
|
133
133
|
cpp_args << [t, a[:name]].join(" ").sub("& ", "&")
|
134
134
|
end
|
135
135
|
|
136
136
|
dispatch = func.out? ? "#{func.base_name}_out" : func.base_name
|
137
137
|
args = fargs.map { |a| a[:name] }
|
138
138
|
args.unshift(*args.pop(func.out_size)) if func.out?
|
139
|
-
args.delete(
|
139
|
+
args.delete(:self) if def_method == :define_method
|
140
140
|
|
141
141
|
prefix = def_method == :define_method ? "self." : "torch::"
|
142
142
|
|
143
143
|
body = "#{prefix}#{dispatch}(#{args.join(", ")})"
|
144
|
-
|
145
|
-
if func.
|
144
|
+
|
145
|
+
if func.cpp_name == "_fill_diagonal_"
|
146
|
+
body = "to_ruby<torch::Tensor>(#{body})"
|
147
|
+
elsif !func.ret_void?
|
146
148
|
body = "wrap(#{body})"
|
147
149
|
end
|
148
150
|
|
149
|
-
cpp_defs << "
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
151
|
+
cpp_defs << "// #{func.func}
|
152
|
+
static #{func.ret_void? ? "void" : "Object"} #{type}#{func.cpp_name}(#{cpp_args.join(", ")})
|
153
|
+
{
|
154
|
+
return #{body};
|
155
|
+
}"
|
156
|
+
|
157
|
+
add_defs << "m.#{def_method}(\"#{func.cpp_name}\", #{type}#{func.cpp_name});"
|
154
158
|
end
|
155
159
|
|
156
160
|
hpp_contents = hpp_template % {type: type}
|
157
|
-
cpp_contents = cpp_template % {type: type, functions: cpp_defs.join("\n ")}
|
161
|
+
cpp_contents = cpp_template % {type: type, functions: cpp_defs.join("\n\n"), add_functions: add_defs.join("\n ")}
|
158
162
|
|
159
163
|
path = File.expand_path("../../../ext/torch", __dir__)
|
160
164
|
File.write("#{path}/#{type}_functions.hpp", hpp_contents)
|