torch-rb 0.3.2 → 0.3.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +28 -0
- data/README.md +7 -2
- data/ext/torch/ext.cpp +60 -20
- data/ext/torch/extconf.rb +3 -0
- data/ext/torch/templates.cpp +36 -0
- data/ext/torch/templates.hpp +81 -87
- data/lib/torch.rb +71 -19
- data/lib/torch/native/dispatcher.rb +30 -8
- data/lib/torch/native/function.rb +93 -4
- data/lib/torch/native/generator.rb +45 -41
- data/lib/torch/native/parser.rb +57 -76
- data/lib/torch/nn/functional.rb +112 -2
- data/lib/torch/nn/leaky_relu.rb +3 -3
- data/lib/torch/nn/module.rb +9 -1
- data/lib/torch/nn/upsample.rb +31 -0
- data/lib/torch/tensor.rb +45 -51
- data/lib/torch/utils/data/data_loader.rb +2 -0
- data/lib/torch/utils/data/tensor_dataset.rb +2 -0
- data/lib/torch/version.rb +1 -1
- metadata +3 -2
data/lib/torch.rb
CHANGED
@@ -174,6 +174,9 @@ require "torch/nn/smooth_l1_loss"
|
|
174
174
|
require "torch/nn/soft_margin_loss"
|
175
175
|
require "torch/nn/triplet_margin_loss"
|
176
176
|
|
177
|
+
# nn vision
|
178
|
+
require "torch/nn/upsample"
|
179
|
+
|
177
180
|
# nn other
|
178
181
|
require "torch/nn/functional"
|
179
182
|
require "torch/nn/init"
|
@@ -196,6 +199,32 @@ module Torch
|
|
196
199
|
end
|
197
200
|
end
|
198
201
|
|
202
|
+
# legacy
|
203
|
+
# but may make it easier to port tutorials
|
204
|
+
module Autograd
|
205
|
+
class Variable
|
206
|
+
def self.new(x)
|
207
|
+
raise ArgumentError, "Variable data has to be a tensor, but got #{x.class.name}" unless x.is_a?(Tensor)
|
208
|
+
warn "[torch] The Variable API is deprecated. Use tensors with requires_grad: true instead."
|
209
|
+
x
|
210
|
+
end
|
211
|
+
end
|
212
|
+
end
|
213
|
+
|
214
|
+
# TODO move to C++
|
215
|
+
class ByteStorage
|
216
|
+
# private
|
217
|
+
attr_reader :bytes
|
218
|
+
|
219
|
+
def initialize(bytes)
|
220
|
+
@bytes = bytes
|
221
|
+
end
|
222
|
+
|
223
|
+
def self.from_buffer(bytes)
|
224
|
+
new(bytes)
|
225
|
+
end
|
226
|
+
end
|
227
|
+
|
199
228
|
# keys: https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.dtype
|
200
229
|
# values: https://github.com/pytorch/pytorch/blob/master/c10/core/ScalarType.h
|
201
230
|
DTYPE_TO_ENUM = {
|
@@ -224,40 +253,43 @@ module Torch
|
|
224
253
|
}
|
225
254
|
ENUM_TO_DTYPE = DTYPE_TO_ENUM.map(&:reverse).to_h
|
226
255
|
|
256
|
+
TENSOR_TYPE_CLASSES = []
|
257
|
+
|
227
258
|
def self._make_tensor_class(dtype, cuda = false)
|
228
259
|
cls = Class.new
|
229
260
|
device = cuda ? "cuda" : "cpu"
|
230
261
|
cls.define_singleton_method("new") do |*args|
|
231
262
|
if args.size == 1 && args.first.is_a?(Tensor)
|
232
263
|
args.first.send(dtype).to(device)
|
264
|
+
elsif args.size == 1 && args.first.is_a?(ByteStorage) && dtype == :uint8
|
265
|
+
bytes = args.first.bytes
|
266
|
+
Torch._from_blob(bytes, [bytes.bytesize], TensorOptions.new.dtype(DTYPE_TO_ENUM[dtype]))
|
233
267
|
elsif args.size == 1 && args.first.is_a?(Array)
|
234
268
|
Torch.tensor(args.first, dtype: dtype, device: device)
|
235
269
|
else
|
236
270
|
Torch.empty(*args, dtype: dtype, device: device)
|
237
271
|
end
|
238
272
|
end
|
273
|
+
TENSOR_TYPE_CLASSES << cls
|
239
274
|
cls
|
240
275
|
end
|
241
276
|
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
253
|
-
|
254
|
-
|
255
|
-
|
256
|
-
|
257
|
-
|
258
|
-
CUDA::IntTensor = _make_tensor_class(:int32, true)
|
259
|
-
CUDA::LongTensor = _make_tensor_class(:int64, true)
|
260
|
-
CUDA::BoolTensor = _make_tensor_class(:bool, true)
|
277
|
+
DTYPE_TO_CLASS = {
|
278
|
+
float32: "FloatTensor",
|
279
|
+
float64: "DoubleTensor",
|
280
|
+
float16: "HalfTensor",
|
281
|
+
uint8: "ByteTensor",
|
282
|
+
int8: "CharTensor",
|
283
|
+
int16: "ShortTensor",
|
284
|
+
int32: "IntTensor",
|
285
|
+
int64: "LongTensor",
|
286
|
+
bool: "BoolTensor"
|
287
|
+
}
|
288
|
+
|
289
|
+
DTYPE_TO_CLASS.each do |dtype, class_name|
|
290
|
+
const_set(class_name, _make_tensor_class(dtype))
|
291
|
+
CUDA.const_set(class_name, _make_tensor_class(dtype, true))
|
292
|
+
end
|
261
293
|
|
262
294
|
class << self
|
263
295
|
# Torch.float, Torch.long, etc
|
@@ -388,6 +420,10 @@ module Torch
|
|
388
420
|
end
|
389
421
|
|
390
422
|
def randperm(n, **options)
|
423
|
+
# dtype hack in Python
|
424
|
+
# https://github.com/pytorch/pytorch/blob/v1.6.0/tools/autograd/gen_python_functions.py#L1307-L1311
|
425
|
+
options[:dtype] ||= :int64
|
426
|
+
|
391
427
|
_randperm(n, tensor_options(**options))
|
392
428
|
end
|
393
429
|
|
@@ -460,6 +496,22 @@ module Torch
|
|
460
496
|
zeros(input.size, **like_options(input, options))
|
461
497
|
end
|
462
498
|
|
499
|
+
def stft(input, n_fft, hop_length: nil, win_length: nil, window: nil, center: true, pad_mode: "reflect", normalized: false, onesided: true)
|
500
|
+
if center
|
501
|
+
signal_dim = input.dim
|
502
|
+
extended_shape = [1] * (3 - signal_dim) + input.size
|
503
|
+
pad = n_fft.div(2).to_i
|
504
|
+
input = NN::F.pad(input.view(extended_shape), [pad, pad], mode: pad_mode)
|
505
|
+
input = input.view(input.shape[-signal_dim..-1])
|
506
|
+
end
|
507
|
+
_stft(input, n_fft, hop_length, win_length, window, normalized, onesided)
|
508
|
+
end
|
509
|
+
|
510
|
+
def clamp(tensor, min, max)
|
511
|
+
tensor = _clamp_min(tensor, min)
|
512
|
+
_clamp_max(tensor, max)
|
513
|
+
end
|
514
|
+
|
463
515
|
private
|
464
516
|
|
465
517
|
def to_ivalue(obj)
|
@@ -22,21 +22,43 @@ module Torch
|
|
22
22
|
end
|
23
23
|
|
24
24
|
def bind_functions(context, def_method, functions)
|
25
|
+
instance_method = def_method == :define_method
|
25
26
|
functions.group_by(&:ruby_name).sort_by { |g, _| g }.each do |name, funcs|
|
26
|
-
if
|
27
|
+
if instance_method
|
27
28
|
funcs.map! { |f| Function.new(f.function) }
|
28
|
-
funcs.each { |f| f.args.reject! { |a| a[:name] ==
|
29
|
+
funcs.each { |f| f.args.reject! { |a| a[:name] == :self } }
|
29
30
|
end
|
30
31
|
|
31
|
-
defined =
|
32
|
+
defined = instance_method ? context.method_defined?(name) : context.respond_to?(name)
|
32
33
|
next if defined && name != "clone"
|
33
34
|
|
34
|
-
parser
|
35
|
+
# skip parser when possible for performance
|
36
|
+
if funcs.size == 1 && funcs.first.args.size == 0
|
37
|
+
# functions with no arguments
|
38
|
+
if instance_method
|
39
|
+
context.send(:alias_method, name, funcs.first.cpp_name)
|
40
|
+
else
|
41
|
+
context.singleton_class.send(:alias_method, name, funcs.first.cpp_name)
|
42
|
+
end
|
43
|
+
elsif funcs.size == 2 && funcs.map { |f| f.arg_types.values }.sort == [["Scalar"], ["Tensor"]]
|
44
|
+
# functions that take a tensor or scalar
|
45
|
+
scalar_name, tensor_name = funcs.sort_by { |f| f.arg_types.values }.map(&:cpp_name)
|
46
|
+
context.send(def_method, name) do |other|
|
47
|
+
case other
|
48
|
+
when Tensor
|
49
|
+
send(tensor_name, other)
|
50
|
+
else
|
51
|
+
send(scalar_name, other)
|
52
|
+
end
|
53
|
+
end
|
54
|
+
else
|
55
|
+
parser = Parser.new(funcs)
|
35
56
|
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
57
|
+
context.send(def_method, name) do |*args, **options|
|
58
|
+
result = parser.parse(args, options)
|
59
|
+
raise ArgumentError, result[:error] if result[:error]
|
60
|
+
send(result[:name], *result[:args])
|
61
|
+
end
|
40
62
|
end
|
41
63
|
end
|
42
64
|
end
|
@@ -1,10 +1,15 @@
|
|
1
1
|
module Torch
|
2
2
|
module Native
|
3
3
|
class Function
|
4
|
-
attr_reader :function
|
4
|
+
attr_reader :function, :tensor_options
|
5
5
|
|
6
6
|
def initialize(function)
|
7
7
|
@function = function
|
8
|
+
|
9
|
+
# note: don't modify function in-place
|
10
|
+
@tensor_options_str = ", *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None)"
|
11
|
+
@tensor_options = @function["func"].include?(@tensor_options_str)
|
12
|
+
@out = out_size > 0 && base_name[-1] != "_"
|
8
13
|
end
|
9
14
|
|
10
15
|
def func
|
@@ -27,7 +32,7 @@ module Torch
|
|
27
32
|
@args ||= begin
|
28
33
|
args = []
|
29
34
|
pos = true
|
30
|
-
args_str = func.split("(", 2).last.split(") ->").first
|
35
|
+
args_str = func.sub(@tensor_options_str, ")").split("(", 2).last.split(") ->").first
|
31
36
|
args_str.split(", ").each do |a|
|
32
37
|
if a == "*"
|
33
38
|
pos = false
|
@@ -68,12 +73,88 @@ module Torch
|
|
68
73
|
next if t == "Generator?"
|
69
74
|
next if t == "MemoryFormat"
|
70
75
|
next if t == "MemoryFormat?"
|
71
|
-
args << {name: k, type: t, default: d, pos: pos, has_default: has_default}
|
76
|
+
args << {name: k.to_sym, type: t, default: d, pos: pos, has_default: has_default}
|
72
77
|
end
|
73
78
|
args
|
74
79
|
end
|
75
80
|
end
|
76
81
|
|
82
|
+
def arg_checkers
|
83
|
+
@arg_checkers ||= begin
|
84
|
+
checkers = {}
|
85
|
+
arg_types.each do |k, t|
|
86
|
+
checker =
|
87
|
+
case t
|
88
|
+
when "Tensor"
|
89
|
+
->(v) { v.is_a?(Tensor) }
|
90
|
+
when "Tensor?"
|
91
|
+
->(v) { v.nil? || v.is_a?(Tensor) }
|
92
|
+
when "Tensor[]", "Tensor?[]"
|
93
|
+
->(v) { v.is_a?(Array) && v.all? { |v2| v2.is_a?(Tensor) } }
|
94
|
+
when "int"
|
95
|
+
if k == :reduction
|
96
|
+
->(v) { v.is_a?(String) }
|
97
|
+
else
|
98
|
+
->(v) { v.is_a?(Integer) }
|
99
|
+
end
|
100
|
+
when "int?"
|
101
|
+
->(v) { v.is_a?(Integer) || v.nil? }
|
102
|
+
when "float?"
|
103
|
+
->(v) { v.is_a?(Numeric) || v.nil? }
|
104
|
+
when "bool?"
|
105
|
+
->(v) { v == true || v == false || v.nil? }
|
106
|
+
when "float"
|
107
|
+
->(v) { v.is_a?(Numeric) }
|
108
|
+
when /int\[.*\]/
|
109
|
+
->(v) { v.is_a?(Array) && v.all? { |v2| v2.is_a?(Integer) } }
|
110
|
+
when "Scalar"
|
111
|
+
->(v) { v.is_a?(Numeric) }
|
112
|
+
when "Scalar?"
|
113
|
+
->(v) { v.is_a?(Numeric) || v.nil? }
|
114
|
+
when "ScalarType"
|
115
|
+
->(v) { false } # not supported yet
|
116
|
+
when "ScalarType?"
|
117
|
+
->(v) { v.nil? }
|
118
|
+
when "bool"
|
119
|
+
->(v) { v == true || v == false }
|
120
|
+
when "str"
|
121
|
+
->(v) { v.is_a?(String) }
|
122
|
+
else
|
123
|
+
raise Error, "Unknown argument type: #{t}. Please report a bug with #{@name}."
|
124
|
+
end
|
125
|
+
checkers[k] = checker
|
126
|
+
end
|
127
|
+
checkers
|
128
|
+
end
|
129
|
+
end
|
130
|
+
|
131
|
+
def int_array_lengths
|
132
|
+
@int_array_lengths ||= begin
|
133
|
+
ret = {}
|
134
|
+
arg_types.each do |k, t|
|
135
|
+
if t.match?(/\Aint\[.+\]\z/)
|
136
|
+
size = t[4..-2]
|
137
|
+
raise Error, "Unknown size: #{size}. Please report a bug with #{@name}." unless size =~ /\A\d+\z/
|
138
|
+
ret[k] = size.to_i
|
139
|
+
end
|
140
|
+
end
|
141
|
+
ret
|
142
|
+
end
|
143
|
+
end
|
144
|
+
|
145
|
+
def arg_names
|
146
|
+
@arg_names ||= args.map { |a| a[:name] }
|
147
|
+
end
|
148
|
+
|
149
|
+
def arg_types
|
150
|
+
@arg_types ||= args.map { |a| [a[:name], a[:type].split("(").first] }.to_h
|
151
|
+
end
|
152
|
+
|
153
|
+
def arg_defaults
|
154
|
+
# TODO find out why can't use select here
|
155
|
+
@arg_defaults ||= args.map { |a| [a[:name], a[:default]] }.to_h
|
156
|
+
end
|
157
|
+
|
77
158
|
def out_size
|
78
159
|
@out_size ||= func.split("->").last.count("!")
|
79
160
|
end
|
@@ -82,8 +163,16 @@ module Torch
|
|
82
163
|
@ret_size ||= func.split("->").last.split(", ").size
|
83
164
|
end
|
84
165
|
|
166
|
+
def ret_array?
|
167
|
+
@ret_array ||= func.split("->").last.include?('[]')
|
168
|
+
end
|
169
|
+
|
170
|
+
def ret_void?
|
171
|
+
func.split("->").last.strip == "()"
|
172
|
+
end
|
173
|
+
|
85
174
|
def out?
|
86
|
-
|
175
|
+
@out
|
87
176
|
end
|
88
177
|
|
89
178
|
def ruby_name
|
@@ -18,12 +18,12 @@ module Torch
|
|
18
18
|
functions = functions()
|
19
19
|
|
20
20
|
# skip functions
|
21
|
-
skip_args = ["
|
21
|
+
skip_args = ["Layout", "Storage", "ConstQuantizerPtr"]
|
22
22
|
|
23
23
|
# remove functions
|
24
24
|
functions.reject! do |f|
|
25
25
|
f.ruby_name.start_with?("_") ||
|
26
|
-
f.ruby_name.
|
26
|
+
f.ruby_name.include?("_backward") ||
|
27
27
|
f.args.any? { |a| a[:type].include?("Dimname") }
|
28
28
|
end
|
29
29
|
|
@@ -31,32 +31,15 @@ module Torch
|
|
31
31
|
todo_functions, functions =
|
32
32
|
functions.partition do |f|
|
33
33
|
f.args.any? do |a|
|
34
|
-
a[:type].include?("?") && !["Tensor?", "Generator?", "int?", "ScalarType?", "Tensor?[]"].include?(a[:type]) ||
|
35
34
|
skip_args.any? { |sa| a[:type].include?(sa) } ||
|
35
|
+
# call to 'range' is ambiguous
|
36
|
+
f.cpp_name == "_range" ||
|
36
37
|
# native_functions.yaml is missing size argument for normal
|
37
38
|
# https://pytorch.org/cppdocs/api/function_namespacetorch_1a80253fe5a3ded4716ec929a348adb4b9.html
|
38
39
|
(f.base_name == "normal" && !f.out?)
|
39
40
|
end
|
40
41
|
end
|
41
42
|
|
42
|
-
# generate additional functions for optional arguments
|
43
|
-
# there may be a better way to do this
|
44
|
-
optional_functions, functions = functions.partition { |f| f.args.any? { |a| a[:type] == "int?" } }
|
45
|
-
optional_functions.each do |f|
|
46
|
-
next if f.ruby_name == "cross"
|
47
|
-
next if f.ruby_name.start_with?("avg_pool") && f.out?
|
48
|
-
|
49
|
-
opt_args = f.args.select { |a| a[:type] == "int?" }
|
50
|
-
if opt_args.size == 1
|
51
|
-
sep = f.name.include?(".") ? "_" : "."
|
52
|
-
f1 = Function.new(f.function.merge("func" => f.func.sub("(", "#{sep}#{opt_args.first[:name]}(").gsub("int?", "int")))
|
53
|
-
# TODO only remove some arguments
|
54
|
-
f2 = Function.new(f.function.merge("func" => f.func.sub(/, int\?.+\) ->/, ") ->")))
|
55
|
-
functions << f1
|
56
|
-
functions << f2
|
57
|
-
end
|
58
|
-
end
|
59
|
-
|
60
43
|
# todo_functions.each do |f|
|
61
44
|
# puts f.func
|
62
45
|
# puts
|
@@ -89,15 +72,18 @@ void add_%{type}_functions(Module m);
|
|
89
72
|
#include <rice/Module.hpp>
|
90
73
|
#include "templates.hpp"
|
91
74
|
|
75
|
+
%{functions}
|
76
|
+
|
92
77
|
void add_%{type}_functions(Module m) {
|
93
|
-
|
94
|
-
%{functions};
|
78
|
+
%{add_functions}
|
95
79
|
}
|
96
80
|
TEMPLATE
|
97
81
|
|
98
82
|
cpp_defs = []
|
83
|
+
add_defs = []
|
99
84
|
functions.sort_by(&:cpp_name).each do |func|
|
100
|
-
fargs = func.args #.select { |a| a[:type] != "Generator?" }
|
85
|
+
fargs = func.args.dup #.select { |a| a[:type] != "Generator?" }
|
86
|
+
fargs << {name: :options, type: "TensorOptions"} if func.tensor_options
|
101
87
|
|
102
88
|
cpp_args = []
|
103
89
|
fargs.each do |a|
|
@@ -109,52 +95,70 @@ void add_%{type}_functions(Module m) {
|
|
109
95
|
# TODO better signature
|
110
96
|
"OptionalTensor"
|
111
97
|
when "ScalarType?"
|
112
|
-
"
|
113
|
-
when "Tensor[]"
|
114
|
-
"TensorList"
|
115
|
-
when "Tensor?[]"
|
98
|
+
"torch::optional<ScalarType>"
|
99
|
+
when "Tensor[]", "Tensor?[]"
|
116
100
|
# TODO make optional
|
117
|
-
"
|
101
|
+
"std::vector<Tensor>"
|
118
102
|
when "int"
|
119
103
|
"int64_t"
|
104
|
+
when "int?"
|
105
|
+
"torch::optional<int64_t>"
|
106
|
+
when "float?"
|
107
|
+
"torch::optional<double>"
|
108
|
+
when "bool?"
|
109
|
+
"torch::optional<bool>"
|
110
|
+
when "Scalar?"
|
111
|
+
"torch::optional<torch::Scalar>"
|
120
112
|
when "float"
|
121
113
|
"double"
|
122
114
|
when /\Aint\[/
|
123
|
-
"
|
115
|
+
"std::vector<int64_t>"
|
124
116
|
when /Tensor\(\S!?\)/
|
125
117
|
"Tensor &"
|
126
118
|
when "str"
|
127
119
|
"std::string"
|
128
|
-
|
120
|
+
when "TensorOptions"
|
121
|
+
"const torch::TensorOptions &"
|
122
|
+
when "Layout?"
|
123
|
+
"torch::optional<Layout>"
|
124
|
+
when "Device?"
|
125
|
+
"torch::optional<Device>"
|
126
|
+
when "Scalar", "bool", "ScalarType", "Layout", "Device", "Storage"
|
129
127
|
a[:type]
|
128
|
+
else
|
129
|
+
raise "Unknown type: #{a[:type]}"
|
130
130
|
end
|
131
131
|
|
132
|
-
t = "MyReduction" if a[:name] ==
|
132
|
+
t = "MyReduction" if a[:name] == :reduction && t == "int64_t"
|
133
133
|
cpp_args << [t, a[:name]].join(" ").sub("& ", "&")
|
134
134
|
end
|
135
135
|
|
136
136
|
dispatch = func.out? ? "#{func.base_name}_out" : func.base_name
|
137
137
|
args = fargs.map { |a| a[:name] }
|
138
138
|
args.unshift(*args.pop(func.out_size)) if func.out?
|
139
|
-
args.delete(
|
139
|
+
args.delete(:self) if def_method == :define_method
|
140
140
|
|
141
141
|
prefix = def_method == :define_method ? "self." : "torch::"
|
142
142
|
|
143
143
|
body = "#{prefix}#{dispatch}(#{args.join(", ")})"
|
144
|
-
|
145
|
-
if func.
|
144
|
+
|
145
|
+
if func.cpp_name == "_fill_diagonal_"
|
146
|
+
body = "to_ruby<torch::Tensor>(#{body})"
|
147
|
+
elsif !func.ret_void?
|
146
148
|
body = "wrap(#{body})"
|
147
149
|
end
|
148
150
|
|
149
|
-
cpp_defs << "
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
151
|
+
cpp_defs << "// #{func.func}
|
152
|
+
static #{func.ret_void? ? "void" : "Object"} #{type}#{func.cpp_name}(#{cpp_args.join(", ")})
|
153
|
+
{
|
154
|
+
return #{body};
|
155
|
+
}"
|
156
|
+
|
157
|
+
add_defs << "m.#{def_method}(\"#{func.cpp_name}\", #{type}#{func.cpp_name});"
|
154
158
|
end
|
155
159
|
|
156
160
|
hpp_contents = hpp_template % {type: type}
|
157
|
-
cpp_contents = cpp_template % {type: type, functions: cpp_defs.join("\n ")}
|
161
|
+
cpp_contents = cpp_template % {type: type, functions: cpp_defs.join("\n\n"), add_functions: add_defs.join("\n ")}
|
158
162
|
|
159
163
|
path = File.expand_path("../../../ext/torch", __dir__)
|
160
164
|
File.write("#{path}/#{type}_functions.hpp", hpp_contents)
|