torch-rb 0.3.6 → 0.3.7
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/README.md +1 -0
- data/ext/torch/ext.cpp +29 -9
- data/ext/torch/extconf.rb +3 -0
- data/ext/torch/templates.cpp +28 -0
- data/ext/torch/templates.hpp +23 -34
- data/lib/torch.rb +35 -0
- data/lib/torch/native/dispatcher.rb +30 -8
- data/lib/torch/native/function.rb +87 -6
- data/lib/torch/native/generator.rb +28 -18
- data/lib/torch/native/parser.rb +55 -86
- data/lib/torch/nn/functional.rb +106 -0
- data/lib/torch/nn/module.rb +4 -1
- data/lib/torch/nn/upsample.rb +31 -0
- data/lib/torch/tensor.rb +13 -7
- data/lib/torch/version.rb +1 -1
- metadata +3 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 8a1852ee3d1ecc7a29c23259b8c328a95030a270b7c11f37f22049177898652e
|
4
|
+
data.tar.gz: 56823f1815d3c0c4d5d5c01ef76d781b792b3e4e7c68c0332a149b883a54c7c8
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: bed15510cfeaa555d71f1e1f46ed8944893bd349a07c4316dcd63429fe76e13facd8794399ef97fc400d05796579f2e84822b62c98c71dc996e211ad04113ae2
|
7
|
+
data.tar.gz: aa05e3645e363eda27274323cdb7fb316342074d1d5afe8f7ee6bfd9819da7883b43d084beb6b29011c631c04fdddc8e6789db41c7c84c53ba9ed152d3338b09
|
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,12 @@
|
|
1
|
+
## 0.3.7 (2020-09-22)
|
2
|
+
|
3
|
+
- Improved performance
|
4
|
+
- Added `Upsample`
|
5
|
+
- Added support for passing tensor class to `type` method
|
6
|
+
- Fixed error with buffers on GPU
|
7
|
+
- Fixed error with `new_full`
|
8
|
+
- Fixed issue with `numo` method and non-contiguous tensors
|
9
|
+
|
1
10
|
## 0.3.6 (2020-09-17)
|
2
11
|
|
3
12
|
- Added `inplace` option for leaky ReLU
|
data/README.md
CHANGED
@@ -402,6 +402,7 @@ Here are a few full examples:
|
|
402
402
|
- [Image classification with MNIST](examples/mnist) ([日本語版](https://qiita.com/kojix2/items/c19c36dc1bf73ea93409))
|
403
403
|
- [Collaborative filtering with MovieLens](examples/movielens)
|
404
404
|
- [Sequence models and word embeddings](examples/nlp)
|
405
|
+
- [Generative adversarial networks](examples/gan)
|
405
406
|
|
406
407
|
## LibTorch Installation
|
407
408
|
|
data/ext/torch/ext.cpp
CHANGED
@@ -232,7 +232,7 @@ void Init_ext()
|
|
232
232
|
})
|
233
233
|
.define_singleton_method(
|
234
234
|
"_empty",
|
235
|
-
*[](
|
235
|
+
*[](std::vector<int64_t> size, const torch::TensorOptions &options) {
|
236
236
|
return torch::empty(size, options);
|
237
237
|
})
|
238
238
|
.define_singleton_method(
|
@@ -242,7 +242,7 @@ void Init_ext()
|
|
242
242
|
})
|
243
243
|
.define_singleton_method(
|
244
244
|
"_full",
|
245
|
-
*[](
|
245
|
+
*[](std::vector<int64_t> size, Scalar fill_value, const torch::TensorOptions& options) {
|
246
246
|
return torch::full(size, fill_value, options);
|
247
247
|
})
|
248
248
|
.define_singleton_method(
|
@@ -257,22 +257,22 @@ void Init_ext()
|
|
257
257
|
})
|
258
258
|
.define_singleton_method(
|
259
259
|
"_ones",
|
260
|
-
*[](
|
260
|
+
*[](std::vector<int64_t> size, const torch::TensorOptions &options) {
|
261
261
|
return torch::ones(size, options);
|
262
262
|
})
|
263
263
|
.define_singleton_method(
|
264
264
|
"_rand",
|
265
|
-
*[](
|
265
|
+
*[](std::vector<int64_t> size, const torch::TensorOptions &options) {
|
266
266
|
return torch::rand(size, options);
|
267
267
|
})
|
268
268
|
.define_singleton_method(
|
269
269
|
"_randint",
|
270
|
-
*[](int64_t low, int64_t high,
|
270
|
+
*[](int64_t low, int64_t high, std::vector<int64_t> size, const torch::TensorOptions &options) {
|
271
271
|
return torch::randint(low, high, size, options);
|
272
272
|
})
|
273
273
|
.define_singleton_method(
|
274
274
|
"_randn",
|
275
|
-
*[](
|
275
|
+
*[](std::vector<int64_t> size, const torch::TensorOptions &options) {
|
276
276
|
return torch::randn(size, options);
|
277
277
|
})
|
278
278
|
.define_singleton_method(
|
@@ -282,7 +282,7 @@ void Init_ext()
|
|
282
282
|
})
|
283
283
|
.define_singleton_method(
|
284
284
|
"_zeros",
|
285
|
-
*[](
|
285
|
+
*[](std::vector<int64_t> size, const torch::TensorOptions &options) {
|
286
286
|
return torch::zeros(size, options);
|
287
287
|
})
|
288
288
|
// begin operations
|
@@ -303,13 +303,13 @@ void Init_ext()
|
|
303
303
|
})
|
304
304
|
.define_singleton_method(
|
305
305
|
"_from_blob",
|
306
|
-
*[](String s,
|
306
|
+
*[](String s, std::vector<int64_t> size, const torch::TensorOptions &options) {
|
307
307
|
void *data = const_cast<char *>(s.c_str());
|
308
308
|
return torch::from_blob(data, size, options);
|
309
309
|
})
|
310
310
|
.define_singleton_method(
|
311
311
|
"_tensor",
|
312
|
-
*[](Array a,
|
312
|
+
*[](Array a, std::vector<int64_t> size, const torch::TensorOptions &options) {
|
313
313
|
auto dtype = options.dtype();
|
314
314
|
torch::Tensor t;
|
315
315
|
if (dtype == torch::kBool) {
|
@@ -342,6 +342,16 @@ void Init_ext()
|
|
342
342
|
.define_method("numel", &torch::Tensor::numel)
|
343
343
|
.define_method("element_size", &torch::Tensor::element_size)
|
344
344
|
.define_method("requires_grad", &torch::Tensor::requires_grad)
|
345
|
+
// in C++ for performance
|
346
|
+
.define_method(
|
347
|
+
"shape",
|
348
|
+
*[](Tensor& self) {
|
349
|
+
Array a;
|
350
|
+
for (auto &size : self.sizes()) {
|
351
|
+
a.push(size);
|
352
|
+
}
|
353
|
+
return a;
|
354
|
+
})
|
345
355
|
.define_method(
|
346
356
|
"_index",
|
347
357
|
*[](Tensor& self, Array indices) {
|
@@ -420,9 +430,19 @@ void Init_ext()
|
|
420
430
|
tensor = tensor.to(device);
|
421
431
|
}
|
422
432
|
|
433
|
+
if (!tensor.is_contiguous()) {
|
434
|
+
tensor = tensor.contiguous();
|
435
|
+
}
|
436
|
+
|
423
437
|
auto data_ptr = (const char *) tensor.data_ptr();
|
424
438
|
return std::string(data_ptr, tensor.numel() * tensor.element_size());
|
425
439
|
})
|
440
|
+
// for TorchVision
|
441
|
+
.define_method(
|
442
|
+
"_data_ptr",
|
443
|
+
*[](Tensor& self) {
|
444
|
+
return reinterpret_cast<uintptr_t>(self.data_ptr());
|
445
|
+
})
|
426
446
|
// TODO figure out a better way to do this
|
427
447
|
.define_method(
|
428
448
|
"_flat_data",
|
data/ext/torch/extconf.rb
CHANGED
data/ext/torch/templates.cpp
CHANGED
@@ -2,6 +2,34 @@
|
|
2
2
|
#include <rice/Object.hpp>
|
3
3
|
#include "templates.hpp"
|
4
4
|
|
5
|
+
Object wrap(bool x) {
|
6
|
+
return to_ruby<bool>(x);
|
7
|
+
}
|
8
|
+
|
9
|
+
Object wrap(int64_t x) {
|
10
|
+
return to_ruby<int64_t>(x);
|
11
|
+
}
|
12
|
+
|
13
|
+
Object wrap(double x) {
|
14
|
+
return to_ruby<double>(x);
|
15
|
+
}
|
16
|
+
|
17
|
+
Object wrap(torch::Tensor x) {
|
18
|
+
return to_ruby<torch::Tensor>(x);
|
19
|
+
}
|
20
|
+
|
21
|
+
Object wrap(torch::Scalar x) {
|
22
|
+
return to_ruby<torch::Scalar>(x);
|
23
|
+
}
|
24
|
+
|
25
|
+
Object wrap(torch::ScalarType x) {
|
26
|
+
return to_ruby<torch::ScalarType>(x);
|
27
|
+
}
|
28
|
+
|
29
|
+
Object wrap(torch::QScheme x) {
|
30
|
+
return to_ruby<torch::QScheme>(x);
|
31
|
+
}
|
32
|
+
|
5
33
|
Object wrap(std::tuple<torch::Tensor, torch::Tensor> x) {
|
6
34
|
Array a;
|
7
35
|
a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
|
data/ext/torch/templates.hpp
CHANGED
@@ -13,49 +13,31 @@ using torch::Device;
|
|
13
13
|
using torch::Scalar;
|
14
14
|
using torch::ScalarType;
|
15
15
|
using torch::Tensor;
|
16
|
-
|
17
|
-
|
18
|
-
// it doesn't own underlying data
|
19
|
-
class IntArrayRef {
|
20
|
-
std::vector<int64_t> vec;
|
21
|
-
public:
|
22
|
-
IntArrayRef(Object o) {
|
23
|
-
Array a = Array(o);
|
24
|
-
for (size_t i = 0; i < a.size(); i++) {
|
25
|
-
vec.push_back(from_ruby<int64_t>(a[i]));
|
26
|
-
}
|
27
|
-
}
|
28
|
-
operator torch::IntArrayRef() {
|
29
|
-
return torch::IntArrayRef(vec);
|
30
|
-
}
|
31
|
-
};
|
16
|
+
using torch::IntArrayRef;
|
17
|
+
using torch::TensorList;
|
32
18
|
|
33
19
|
template<>
|
34
20
|
inline
|
35
|
-
|
21
|
+
std::vector<int64_t> from_ruby<std::vector<int64_t>>(Object x)
|
36
22
|
{
|
37
|
-
|
23
|
+
Array a = Array(x);
|
24
|
+
std::vector<int64_t> vec(a.size());
|
25
|
+
for (size_t i = 0; i < a.size(); i++) {
|
26
|
+
vec[i] = from_ruby<int64_t>(a[i]);
|
27
|
+
}
|
28
|
+
return vec;
|
38
29
|
}
|
39
30
|
|
40
|
-
class TensorList {
|
41
|
-
std::vector<torch::Tensor> vec;
|
42
|
-
public:
|
43
|
-
TensorList(Object o) {
|
44
|
-
Array a = Array(o);
|
45
|
-
for (size_t i = 0; i < a.size(); i++) {
|
46
|
-
vec.push_back(from_ruby<torch::Tensor>(a[i]));
|
47
|
-
}
|
48
|
-
}
|
49
|
-
operator torch::TensorList() {
|
50
|
-
return torch::TensorList(vec);
|
51
|
-
}
|
52
|
-
};
|
53
|
-
|
54
31
|
template<>
|
55
32
|
inline
|
56
|
-
|
33
|
+
std::vector<Tensor> from_ruby<std::vector<Tensor>>(Object x)
|
57
34
|
{
|
58
|
-
|
35
|
+
Array a = Array(x);
|
36
|
+
std::vector<Tensor> vec(a.size());
|
37
|
+
for (size_t i = 0; i < a.size(); i++) {
|
38
|
+
vec[i] = from_ruby<Tensor>(a[i]);
|
39
|
+
}
|
40
|
+
return vec;
|
59
41
|
}
|
60
42
|
|
61
43
|
class FanModeType {
|
@@ -242,6 +224,13 @@ torch::optional<Scalar> from_ruby<torch::optional<Scalar>>(Object x)
|
|
242
224
|
}
|
243
225
|
}
|
244
226
|
|
227
|
+
Object wrap(bool x);
|
228
|
+
Object wrap(int64_t x);
|
229
|
+
Object wrap(double x);
|
230
|
+
Object wrap(torch::Tensor x);
|
231
|
+
Object wrap(torch::Scalar x);
|
232
|
+
Object wrap(torch::ScalarType x);
|
233
|
+
Object wrap(torch::QScheme x);
|
245
234
|
Object wrap(std::tuple<torch::Tensor, torch::Tensor> x);
|
246
235
|
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x);
|
247
236
|
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x);
|
data/lib/torch.rb
CHANGED
@@ -174,6 +174,9 @@ require "torch/nn/smooth_l1_loss"
|
|
174
174
|
require "torch/nn/soft_margin_loss"
|
175
175
|
require "torch/nn/triplet_margin_loss"
|
176
176
|
|
177
|
+
# nn vision
|
178
|
+
require "torch/nn/upsample"
|
179
|
+
|
177
180
|
# nn other
|
178
181
|
require "torch/nn/functional"
|
179
182
|
require "torch/nn/init"
|
@@ -196,6 +199,32 @@ module Torch
|
|
196
199
|
end
|
197
200
|
end
|
198
201
|
|
202
|
+
# legacy
|
203
|
+
# but may make it easier to port tutorials
|
204
|
+
module Autograd
|
205
|
+
class Variable
|
206
|
+
def self.new(x)
|
207
|
+
raise ArgumentError, "Variable data has to be a tensor, but got #{x.class.name}" unless x.is_a?(Tensor)
|
208
|
+
warn "[torch] The Variable API is deprecated. Use tensors with requires_grad: true instead."
|
209
|
+
x
|
210
|
+
end
|
211
|
+
end
|
212
|
+
end
|
213
|
+
|
214
|
+
# TODO move to C++
|
215
|
+
class ByteStorage
|
216
|
+
# private
|
217
|
+
attr_reader :bytes
|
218
|
+
|
219
|
+
def initialize(bytes)
|
220
|
+
@bytes = bytes
|
221
|
+
end
|
222
|
+
|
223
|
+
def self.from_buffer(bytes)
|
224
|
+
new(bytes)
|
225
|
+
end
|
226
|
+
end
|
227
|
+
|
199
228
|
# keys: https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.dtype
|
200
229
|
# values: https://github.com/pytorch/pytorch/blob/master/c10/core/ScalarType.h
|
201
230
|
DTYPE_TO_ENUM = {
|
@@ -224,18 +253,24 @@ module Torch
|
|
224
253
|
}
|
225
254
|
ENUM_TO_DTYPE = DTYPE_TO_ENUM.map(&:reverse).to_h
|
226
255
|
|
256
|
+
TENSOR_TYPE_CLASSES = []
|
257
|
+
|
227
258
|
def self._make_tensor_class(dtype, cuda = false)
|
228
259
|
cls = Class.new
|
229
260
|
device = cuda ? "cuda" : "cpu"
|
230
261
|
cls.define_singleton_method("new") do |*args|
|
231
262
|
if args.size == 1 && args.first.is_a?(Tensor)
|
232
263
|
args.first.send(dtype).to(device)
|
264
|
+
elsif args.size == 1 && args.first.is_a?(ByteStorage) && dtype == :uint8
|
265
|
+
bytes = args.first.bytes
|
266
|
+
Torch._from_blob(bytes, [bytes.bytesize], TensorOptions.new.dtype(DTYPE_TO_ENUM[dtype]))
|
233
267
|
elsif args.size == 1 && args.first.is_a?(Array)
|
234
268
|
Torch.tensor(args.first, dtype: dtype, device: device)
|
235
269
|
else
|
236
270
|
Torch.empty(*args, dtype: dtype, device: device)
|
237
271
|
end
|
238
272
|
end
|
273
|
+
TENSOR_TYPE_CLASSES << cls
|
239
274
|
cls
|
240
275
|
end
|
241
276
|
|
data/lib/torch/native/dispatcher.rb
CHANGED
@@ -22,21 +22,43 @@ module Torch
|
|
22
22
|
end
|
23
23
|
|
24
24
|
def bind_functions(context, def_method, functions)
|
25
|
+
instance_method = def_method == :define_method
|
25
26
|
functions.group_by(&:ruby_name).sort_by { |g, _| g }.each do |name, funcs|
|
26
|
-
if
|
27
|
+
if instance_method
|
27
28
|
funcs.map! { |f| Function.new(f.function) }
|
28
|
-
funcs.each { |f| f.args.reject! { |a| a[:name] ==
|
29
|
+
funcs.each { |f| f.args.reject! { |a| a[:name] == :self } }
|
29
30
|
end
|
30
31
|
|
31
|
-
defined =
|
32
|
+
defined = instance_method ? context.method_defined?(name) : context.respond_to?(name)
|
32
33
|
next if defined && name != "clone"
|
33
34
|
|
34
|
-
parser
|
35
|
+
# skip parser when possible for performance
|
36
|
+
if funcs.size == 1 && funcs.first.args.size == 0
|
37
|
+
# functions with no arguments
|
38
|
+
if instance_method
|
39
|
+
context.send(:alias_method, name, funcs.first.cpp_name)
|
40
|
+
else
|
41
|
+
context.singleton_class.send(:alias_method, name, funcs.first.cpp_name)
|
42
|
+
end
|
43
|
+
elsif funcs.size == 2 && funcs.map { |f| f.arg_types.values }.sort == [["Scalar"], ["Tensor"]]
|
44
|
+
# functions that take a tensor or scalar
|
45
|
+
scalar_name, tensor_name = funcs.sort_by { |f| f.arg_types.values }.map(&:cpp_name)
|
46
|
+
context.send(def_method, name) do |other|
|
47
|
+
case other
|
48
|
+
when Tensor
|
49
|
+
send(tensor_name, other)
|
50
|
+
else
|
51
|
+
send(scalar_name, other)
|
52
|
+
end
|
53
|
+
end
|
54
|
+
else
|
55
|
+
parser = Parser.new(funcs)
|
35
56
|
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
57
|
+
context.send(def_method, name) do |*args, **options|
|
58
|
+
result = parser.parse(args, options)
|
59
|
+
raise ArgumentError, result[:error] if result[:error]
|
60
|
+
send(result[:name], *result[:args])
|
61
|
+
end
|
40
62
|
end
|
41
63
|
end
|
42
64
|
end
|
data/lib/torch/native/function.rb
CHANGED
@@ -6,9 +6,10 @@ module Torch
|
|
6
6
|
def initialize(function)
|
7
7
|
@function = function
|
8
8
|
|
9
|
-
|
10
|
-
@
|
11
|
-
@function["func"].
|
9
|
+
# note: don't modify function in-place
|
10
|
+
@tensor_options_str = ", *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None)"
|
11
|
+
@tensor_options = @function["func"].include?(@tensor_options_str)
|
12
|
+
@out = out_size > 0 && base_name[-1] != "_"
|
12
13
|
end
|
13
14
|
|
14
15
|
def func
|
@@ -31,7 +32,7 @@ module Torch
|
|
31
32
|
@args ||= begin
|
32
33
|
args = []
|
33
34
|
pos = true
|
34
|
-
args_str = func.split("(", 2).last.split(") ->").first
|
35
|
+
args_str = func.sub(@tensor_options_str, ")").split("(", 2).last.split(") ->").first
|
35
36
|
args_str.split(", ").each do |a|
|
36
37
|
if a == "*"
|
37
38
|
pos = false
|
@@ -72,12 +73,88 @@ module Torch
|
|
72
73
|
next if t == "Generator?"
|
73
74
|
next if t == "MemoryFormat"
|
74
75
|
next if t == "MemoryFormat?"
|
75
|
-
args << {name: k, type: t, default: d, pos: pos, has_default: has_default}
|
76
|
+
args << {name: k.to_sym, type: t, default: d, pos: pos, has_default: has_default}
|
76
77
|
end
|
77
78
|
args
|
78
79
|
end
|
79
80
|
end
|
80
81
|
|
82
|
+
def arg_checkers
|
83
|
+
@arg_checkers ||= begin
|
84
|
+
checkers = {}
|
85
|
+
arg_types.each do |k, t|
|
86
|
+
checker =
|
87
|
+
case t
|
88
|
+
when "Tensor"
|
89
|
+
->(v) { v.is_a?(Tensor) }
|
90
|
+
when "Tensor?"
|
91
|
+
->(v) { v.nil? || v.is_a?(Tensor) }
|
92
|
+
when "Tensor[]", "Tensor?[]"
|
93
|
+
->(v) { v.is_a?(Array) && v.all? { |v2| v2.is_a?(Tensor) } }
|
94
|
+
when "int"
|
95
|
+
if k == :reduction
|
96
|
+
->(v) { v.is_a?(String) }
|
97
|
+
else
|
98
|
+
->(v) { v.is_a?(Integer) }
|
99
|
+
end
|
100
|
+
when "int?"
|
101
|
+
->(v) { v.is_a?(Integer) || v.nil? }
|
102
|
+
when "float?"
|
103
|
+
->(v) { v.is_a?(Numeric) || v.nil? }
|
104
|
+
when "bool?"
|
105
|
+
->(v) { v == true || v == false || v.nil? }
|
106
|
+
when "float"
|
107
|
+
->(v) { v.is_a?(Numeric) }
|
108
|
+
when /int\[.*\]/
|
109
|
+
->(v) { v.is_a?(Array) && v.all? { |v2| v2.is_a?(Integer) } }
|
110
|
+
when "Scalar"
|
111
|
+
->(v) { v.is_a?(Numeric) }
|
112
|
+
when "Scalar?"
|
113
|
+
->(v) { v.is_a?(Numeric) || v.nil? }
|
114
|
+
when "ScalarType"
|
115
|
+
->(v) { false } # not supported yet
|
116
|
+
when "ScalarType?"
|
117
|
+
->(v) { v.nil? }
|
118
|
+
when "bool"
|
119
|
+
->(v) { v == true || v == false }
|
120
|
+
when "str"
|
121
|
+
->(v) { v.is_a?(String) }
|
122
|
+
else
|
123
|
+
raise Error, "Unknown argument type: #{t}. Please report a bug with #{@name}."
|
124
|
+
end
|
125
|
+
checkers[k] = checker
|
126
|
+
end
|
127
|
+
checkers
|
128
|
+
end
|
129
|
+
end
|
130
|
+
|
131
|
+
def int_array_lengths
|
132
|
+
@int_array_lengths ||= begin
|
133
|
+
ret = {}
|
134
|
+
arg_types.each do |k, t|
|
135
|
+
if t.match?(/\Aint\[.+\]\z/)
|
136
|
+
size = t[4..-2]
|
137
|
+
raise Error, "Unknown size: #{size}. Please report a bug with #{@name}." unless size =~ /\A\d+\z/
|
138
|
+
ret[k] = size.to_i
|
139
|
+
end
|
140
|
+
end
|
141
|
+
ret
|
142
|
+
end
|
143
|
+
end
|
144
|
+
|
145
|
+
def arg_names
|
146
|
+
@arg_names ||= args.map { |a| a[:name] }
|
147
|
+
end
|
148
|
+
|
149
|
+
def arg_types
|
150
|
+
@arg_types ||= args.map { |a| [a[:name], a[:type].split("(").first] }.to_h
|
151
|
+
end
|
152
|
+
|
153
|
+
def arg_defaults
|
154
|
+
# TODO find out why can't use select here
|
155
|
+
@arg_defaults ||= args.map { |a| [a[:name], a[:default]] }.to_h
|
156
|
+
end
|
157
|
+
|
81
158
|
def out_size
|
82
159
|
@out_size ||= func.split("->").last.count("!")
|
83
160
|
end
|
@@ -90,8 +167,12 @@ module Torch
|
|
90
167
|
@ret_array ||= func.split("->").last.include?('[]')
|
91
168
|
end
|
92
169
|
|
170
|
+
def ret_void?
|
171
|
+
func.split("->").last.strip == "()"
|
172
|
+
end
|
173
|
+
|
93
174
|
def out?
|
94
|
-
|
175
|
+
@out
|
95
176
|
end
|
96
177
|
|
97
178
|
def ruby_name
|
data/lib/torch/native/generator.rb
CHANGED
@@ -72,16 +72,18 @@ void add_%{type}_functions(Module m);
|
|
72
72
|
#include <rice/Module.hpp>
|
73
73
|
#include "templates.hpp"
|
74
74
|
|
75
|
+
%{functions}
|
76
|
+
|
75
77
|
void add_%{type}_functions(Module m) {
|
76
|
-
|
77
|
-
%{functions};
|
78
|
+
%{add_functions}
|
78
79
|
}
|
79
80
|
TEMPLATE
|
80
81
|
|
81
82
|
cpp_defs = []
|
83
|
+
add_defs = []
|
82
84
|
functions.sort_by(&:cpp_name).each do |func|
|
83
85
|
fargs = func.args.dup #.select { |a| a[:type] != "Generator?" }
|
84
|
-
fargs << {name:
|
86
|
+
fargs << {name: :options, type: "TensorOptions"} if func.tensor_options
|
85
87
|
|
86
88
|
cpp_args = []
|
87
89
|
fargs.each do |a|
|
@@ -94,11 +96,9 @@ void add_%{type}_functions(Module m) {
|
|
94
96
|
"OptionalTensor"
|
95
97
|
when "ScalarType?"
|
96
98
|
"torch::optional<ScalarType>"
|
97
|
-
when "Tensor[]"
|
98
|
-
"TensorList"
|
99
|
-
when "Tensor?[]"
|
99
|
+
when "Tensor[]", "Tensor?[]"
|
100
100
|
# TODO make optional
|
101
|
-
"
|
101
|
+
"std::vector<Tensor>"
|
102
102
|
when "int"
|
103
103
|
"int64_t"
|
104
104
|
when "int?"
|
@@ -112,43 +112,53 @@ void add_%{type}_functions(Module m) {
|
|
112
112
|
when "float"
|
113
113
|
"double"
|
114
114
|
when /\Aint\[/
|
115
|
-
"
|
115
|
+
"std::vector<int64_t>"
|
116
116
|
when /Tensor\(\S!?\)/
|
117
117
|
"Tensor &"
|
118
118
|
when "str"
|
119
119
|
"std::string"
|
120
120
|
when "TensorOptions"
|
121
121
|
"const torch::TensorOptions &"
|
122
|
-
|
122
|
+
when "Layout?"
|
123
|
+
"torch::optional<Layout>"
|
124
|
+
when "Device?"
|
125
|
+
"torch::optional<Device>"
|
126
|
+
when "Scalar", "bool", "ScalarType", "Layout", "Device", "Storage"
|
123
127
|
a[:type]
|
128
|
+
else
|
129
|
+
raise "Unknown type: #{a[:type]}"
|
124
130
|
end
|
125
131
|
|
126
|
-
t = "MyReduction" if a[:name] ==
|
132
|
+
t = "MyReduction" if a[:name] == :reduction && t == "int64_t"
|
127
133
|
cpp_args << [t, a[:name]].join(" ").sub("& ", "&")
|
128
134
|
end
|
129
135
|
|
130
136
|
dispatch = func.out? ? "#{func.base_name}_out" : func.base_name
|
131
137
|
args = fargs.map { |a| a[:name] }
|
132
138
|
args.unshift(*args.pop(func.out_size)) if func.out?
|
133
|
-
args.delete(
|
139
|
+
args.delete(:self) if def_method == :define_method
|
134
140
|
|
135
141
|
prefix = def_method == :define_method ? "self." : "torch::"
|
136
142
|
|
137
143
|
body = "#{prefix}#{dispatch}(#{args.join(", ")})"
|
138
144
|
|
139
|
-
if func.
|
145
|
+
if func.cpp_name == "_fill_diagonal_"
|
146
|
+
body = "to_ruby<torch::Tensor>(#{body})"
|
147
|
+
elsif !func.ret_void?
|
140
148
|
body = "wrap(#{body})"
|
141
149
|
end
|
142
150
|
|
143
|
-
cpp_defs << "
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
151
|
+
cpp_defs << "// #{func.func}
|
152
|
+
static #{func.ret_void? ? "void" : "Object"} #{type}#{func.cpp_name}(#{cpp_args.join(", ")})
|
153
|
+
{
|
154
|
+
return #{body};
|
155
|
+
}"
|
156
|
+
|
157
|
+
add_defs << "m.#{def_method}(\"#{func.cpp_name}\", #{type}#{func.cpp_name});"
|
148
158
|
end
|
149
159
|
|
150
160
|
hpp_contents = hpp_template % {type: type}
|
151
|
-
cpp_contents = cpp_template % {type: type, functions: cpp_defs.join("\n ")}
|
161
|
+
cpp_contents = cpp_template % {type: type, functions: cpp_defs.join("\n\n"), add_functions: add_defs.join("\n ")}
|
152
162
|
|
153
163
|
path = File.expand_path("../../../ext/torch", __dir__)
|
154
164
|
File.write("#{path}/#{type}_functions.hpp", hpp_contents)
|
data/lib/torch/native/parser.rb
CHANGED
@@ -6,14 +6,24 @@ module Torch
|
|
6
6
|
@name = @functions.first.ruby_name
|
7
7
|
@min_args = @functions.map { |f| f.args.count { |a| a[:pos] && !a[:has_default] } }.min
|
8
8
|
@max_args = @functions.map { |f| f.args.count { |a| a[:pos] } }.max
|
9
|
+
@int_array_first = @functions.all? { |c| c.args.first && c.args.first[:type] == "int[]" }
|
9
10
|
end
|
10
11
|
|
12
|
+
# TODO improve performance
|
13
|
+
# possibly move to C++ (see python_arg_parser.cpp)
|
11
14
|
def parse(args, options)
|
12
15
|
candidates = @functions.dup
|
13
16
|
|
14
|
-
#
|
15
|
-
|
16
|
-
|
17
|
+
# TODO check candidates individually to see if they match
|
18
|
+
if @int_array_first
|
19
|
+
int_args = []
|
20
|
+
while args.first.is_a?(Integer)
|
21
|
+
int_args << args.shift
|
22
|
+
end
|
23
|
+
if int_args.any?
|
24
|
+
raise ArgumentError, "argument '#{candidates.first.args.first[:name]}' must be array of ints, but found element of type #{args.first.class.name} at pos #{int_args.size + 1}" if args.any?
|
25
|
+
args.unshift(int_args)
|
26
|
+
end
|
17
27
|
end
|
18
28
|
|
19
29
|
# TODO account for args passed as options here
|
@@ -25,99 +35,60 @@ module Torch
|
|
25
35
|
|
26
36
|
candidates.reject! { |f| args.size > f.args.size }
|
27
37
|
|
28
|
-
# exclude functions missing required options
|
29
|
-
candidates.reject! do |func|
|
30
|
-
# TODO make more generic
|
31
|
-
func.out? && !options[:out]
|
32
|
-
end
|
33
|
-
|
34
38
|
# handle out with multiple
|
35
39
|
# there should only be one match, so safe to modify all
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
end
|
44
|
-
|
45
|
-
# exclude functions where options don't match
|
46
|
-
options.each do |k, v|
|
47
|
-
candidates.select! do |func|
|
48
|
-
func.args.any? { |a| a[:name] == k.to_s }
|
40
|
+
if options[:out]
|
41
|
+
if (out_func = candidates.find { |f| f.out? }) && out_func.out_size > 1
|
42
|
+
out_args = out_func.args.last(2).map { |a| a[:name] }
|
43
|
+
out_args.zip(options.delete(:out)).each do |k, v|
|
44
|
+
options[k] = v
|
45
|
+
end
|
46
|
+
candidates = [out_func]
|
49
47
|
end
|
50
|
-
|
51
|
-
|
48
|
+
else
|
49
|
+
# exclude functions missing required options
|
50
|
+
candidates.reject!(&:out?)
|
52
51
|
end
|
53
52
|
|
54
|
-
final_values =
|
53
|
+
final_values = nil
|
55
54
|
|
56
55
|
# check args
|
57
|
-
candidates.
|
56
|
+
while (func = candidates.shift)
|
58
57
|
good = true
|
59
58
|
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
59
|
+
# set values
|
60
|
+
# TODO use array instead of hash?
|
61
|
+
values = {}
|
62
|
+
args.each_with_index do |a, i|
|
63
|
+
values[func.arg_names[i]] = a
|
64
|
+
end
|
65
|
+
options.each do |k, v|
|
66
|
+
values[k] = v
|
67
|
+
end
|
68
|
+
func.arg_defaults.each do |k, v|
|
69
|
+
values[k] = v unless values.key?(k)
|
70
|
+
end
|
71
|
+
func.int_array_lengths.each do |k, len|
|
72
|
+
values[k] = [values[k]] * len if values[k].is_a?(Integer)
|
64
73
|
end
|
65
74
|
|
66
|
-
|
75
|
+
arg_checkers = func.arg_checkers
|
67
76
|
|
68
77
|
values.each_key do |k|
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
when "Tensor"
|
75
|
-
v.is_a?(Tensor)
|
76
|
-
when "Tensor?"
|
77
|
-
v.nil? || v.is_a?(Tensor)
|
78
|
-
when "Tensor[]", "Tensor?[]"
|
79
|
-
v.is_a?(Array) && v.all? { |v2| v2.is_a?(Tensor) }
|
80
|
-
when "int"
|
81
|
-
if k == "reduction"
|
82
|
-
v.is_a?(String)
|
83
|
-
else
|
84
|
-
v.is_a?(Integer)
|
85
|
-
end
|
86
|
-
when "int?"
|
87
|
-
v.is_a?(Integer) || v.nil?
|
88
|
-
when "float?"
|
89
|
-
v.is_a?(Numeric) || v.nil?
|
90
|
-
when "bool?"
|
91
|
-
v == true || v == false || v.nil?
|
92
|
-
when "float"
|
93
|
-
v.is_a?(Numeric)
|
94
|
-
when /int\[.*\]/
|
95
|
-
if v.is_a?(Integer)
|
96
|
-
size = t[4..-2]
|
97
|
-
raise Error, "Unknown size: #{size}. Please report a bug with #{@name}." unless size =~ /\A\d+\z/
|
98
|
-
v = [v] * size.to_i
|
99
|
-
values[k] = v
|
100
|
-
end
|
101
|
-
v.is_a?(Array) && v.all? { |v2| v2.is_a?(Integer) }
|
102
|
-
when "Scalar"
|
103
|
-
v.is_a?(Numeric)
|
104
|
-
when "Scalar?"
|
105
|
-
v.is_a?(Numeric) || v.nil?
|
106
|
-
when "ScalarType"
|
107
|
-
false # not supported yet
|
108
|
-
when "ScalarType?"
|
109
|
-
v.nil?
|
110
|
-
when "bool"
|
111
|
-
v == true || v == false
|
112
|
-
when "str"
|
113
|
-
v.is_a?(String)
|
114
|
-
else
|
115
|
-
raise Error, "Unknown argument type: #{arg_types[k]}. Please report a bug with #{@name}."
|
78
|
+
unless arg_checkers.key?(k)
|
79
|
+
good = false
|
80
|
+
if candidates.empty?
|
81
|
+
# TODO show all bad keywords at once like Ruby?
|
82
|
+
return {error: "unknown keyword: #{k}"}
|
116
83
|
end
|
84
|
+
break
|
85
|
+
end
|
117
86
|
|
118
|
-
|
119
|
-
|
120
|
-
|
87
|
+
unless arg_checkers[k].call(values[k])
|
88
|
+
good = false
|
89
|
+
if candidates.empty?
|
90
|
+
t = func.arg_types[k]
|
91
|
+
k = :input if k == :self
|
121
92
|
return {error: "#{@name}(): argument '#{k}' must be #{t}"}
|
122
93
|
end
|
123
94
|
break
|
@@ -126,17 +97,15 @@ module Torch
|
|
126
97
|
|
127
98
|
if good
|
128
99
|
final_values = values
|
100
|
+
break
|
129
101
|
end
|
130
|
-
|
131
|
-
good
|
132
102
|
end
|
133
103
|
|
134
|
-
|
104
|
+
unless final_values
|
135
105
|
raise Error, "This should never happen. Please report a bug with #{@name}."
|
136
106
|
end
|
137
107
|
|
138
|
-
|
139
|
-
args = func.args.map { |a| final_values[a[:name]] }
|
108
|
+
args = func.arg_names.map { |k| final_values[k] }
|
140
109
|
args << TensorOptions.new.dtype(6) if func.tensor_options
|
141
110
|
{
|
142
111
|
name: func.cpp_name,
|
data/lib/torch/nn/functional.rb
CHANGED
@@ -469,6 +469,77 @@ module Torch
|
|
469
469
|
Torch.triplet_margin_loss(anchor, positive, negative, margin, p, eps, swap, reduction)
|
470
470
|
end
|
471
471
|
|
472
|
+
# vision
|
473
|
+
|
474
|
+
def interpolate(input, size: nil, scale_factor: nil, mode: "nearest", align_corners: nil, recompute_scale_factor: nil)
|
475
|
+
if ["nearest", "area"].include?(mode)
|
476
|
+
unless align_corners.nil?
|
477
|
+
raise ArgumentError, "align_corners option can only be set with the interpolating modes: linear | bilinear | bicubic | trilinear"
|
478
|
+
end
|
479
|
+
else
|
480
|
+
if align_corners.nil?
|
481
|
+
align_corners = false
|
482
|
+
end
|
483
|
+
end
|
484
|
+
|
485
|
+
scale_factor_len = input.dim - 2
|
486
|
+
scale_factor_list = [nil] * scale_factor_len
|
487
|
+
# default value of recompute_scale_factor is False
|
488
|
+
if !scale_factor.nil? && (recompute_scale_factor == false || recompute_scale_factor.nil?)
|
489
|
+
if scale_factor.is_a?(Array)
|
490
|
+
_scale_factor_repeated = scale_factor
|
491
|
+
else
|
492
|
+
_scale_factor_repeated = [scale_factor] * scale_factor_len
|
493
|
+
end
|
494
|
+
scale_factor_list = _scale_factor_repeated
|
495
|
+
end
|
496
|
+
|
497
|
+
# Give this variable a short name because it has to be repeated multiple times below.
|
498
|
+
sfl = scale_factor_list
|
499
|
+
|
500
|
+
closed_over_args = [input, size, scale_factor, recompute_scale_factor]
|
501
|
+
output_size = _interp_output_size(closed_over_args)
|
502
|
+
if input.dim == 3 && mode == "nearest"
|
503
|
+
NN.upsample_nearest1d(input, output_size, sfl[0])
|
504
|
+
elsif input.dim == 4 && mode == "nearest"
|
505
|
+
NN.upsample_nearest2d(input, output_size, sfl[0], sfl[1])
|
506
|
+
elsif input.dim == 5 && mode == "nearest"
|
507
|
+
NN.upsample_nearest3d(input, output_size, sfl[0], sfl[1], sfl[2])
|
508
|
+
elsif input.dim == 3 && mode == "area"
|
509
|
+
adaptive_avg_pool1d(input, output_size)
|
510
|
+
elsif input.dim == 4 && mode == "area"
|
511
|
+
adaptive_avg_pool2d(input, output_size)
|
512
|
+
elsif input.dim == 5 && mode == "area"
|
513
|
+
adaptive_avg_pool3d(input, output_size)
|
514
|
+
elsif input.dim == 3 && mode == "linear"
|
515
|
+
# assert align_corners is not None
|
516
|
+
NN.upsample_linear1d(input, output_size, align_corners, sfl[0])
|
517
|
+
elsif input.dim == 3 && mode == "bilinear"
|
518
|
+
raise ArgumentError, "Got 3D input, but bilinear mode needs 4D input"
|
519
|
+
elsif input.dim == 3 && mode == "trilinear"
|
520
|
+
raise ArgumentError, "Got 3D input, but trilinear mode needs 5D input"
|
521
|
+
elsif input.dim == 4 && mode == "linear"
|
522
|
+
raise ArgumentError, "Got 4D input, but linear mode needs 3D input"
|
523
|
+
elsif input.dim == 4 && mode == "bilinear"
|
524
|
+
# assert align_corners is not None
|
525
|
+
NN.upsample_bilinear2d(input, output_size, align_corners, sfl[0], sfl[1])
|
526
|
+
elsif input.dim == 4 && mode == "trilinear"
|
527
|
+
raise ArgumentError, "Got 4D input, but trilinear mode needs 5D input"
|
528
|
+
elsif input.dim == 5 && mode == "linear"
|
529
|
+
raise ArgumentError, "Got 5D input, but linear mode needs 3D input"
|
530
|
+
elsif input.dim == 5 && mode == "bilinear"
|
531
|
+
raise ArgumentError, "Got 5D input, but bilinear mode needs 4D input"
|
532
|
+
elsif input.dim == 5 && mode == "trilinear"
|
533
|
+
# assert align_corners is not None
|
534
|
+
NN.upsample_trilinear3d(input, output_size, align_corners, sfl[0], sfl[1], sfl[2])
|
535
|
+
elsif input.dim == 4 && mode == "bicubic"
|
536
|
+
# assert align_corners is not None
|
537
|
+
NN.upsample_bicubic2d(input, output_size, align_corners, sfl[0], sfl[1])
|
538
|
+
else
|
539
|
+
raise ArgumentError, "Input Error: Only 3D, 4D and 5D input Tensors supported (got #{input.dim}D) for the modes: nearest | linear | bilinear | bicubic | trilinear (got #{mode})"
|
540
|
+
end
|
541
|
+
end
|
542
|
+
|
472
543
|
private
|
473
544
|
|
474
545
|
def softmax_dim(ndim)
|
@@ -484,6 +555,41 @@ module Torch
|
|
484
555
|
out_size.zip(defaults.last(out_size.length)).map { |v, d| v || d }
|
485
556
|
end
|
486
557
|
end
|
558
|
+
|
559
|
+
def _interp_output_size(closed_over_args)
|
560
|
+
input, size, scale_factor, recompute_scale_factor = closed_over_args
|
561
|
+
dim = input.dim - 2
|
562
|
+
if size.nil? && scale_factor.nil?
|
563
|
+
raise ArgumentError, "either size or scale_factor should be defined"
|
564
|
+
end
|
565
|
+
if !size.nil? && !scale_factor.nil?
|
566
|
+
raise ArgumentError, "only one of size or scale_factor should be defined"
|
567
|
+
end
|
568
|
+
if !scale_factor.nil?
|
569
|
+
if scale_factor.is_a?(Array)
|
570
|
+
if scale_factor.length != dim
|
571
|
+
raise ArgumentError, "scale_factor shape must match input shape. Input is #{dim}D, scale_factor size is #{scale_factor.length}"
|
572
|
+
end
|
573
|
+
end
|
574
|
+
end
|
575
|
+
|
576
|
+
if !size.nil?
|
577
|
+
if size.is_a?(Array)
|
578
|
+
return size
|
579
|
+
else
|
580
|
+
return [size] * dim
|
581
|
+
end
|
582
|
+
end
|
583
|
+
|
584
|
+
raise "Failed assertion" if scale_factor.nil?
|
585
|
+
if scale_factor.is_a?(Array)
|
586
|
+
scale_factors = scale_factor
|
587
|
+
else
|
588
|
+
scale_factors = [scale_factor] * dim
|
589
|
+
end
|
590
|
+
|
591
|
+
dim.times.map { |i| (input.size(i + 2) * scale_factors[i]).floor }
|
592
|
+
end
|
487
593
|
end
|
488
594
|
end
|
489
595
|
|
data/lib/torch/nn/upsample.rb
CHANGED
@@ -0,0 +1,31 @@
|
|
1
|
+
module Torch
|
2
|
+
module NN
|
3
|
+
class Upsample < Module
|
4
|
+
def initialize(size: nil, scale_factor: nil, mode: "nearest", align_corners: nil)
|
5
|
+
super()
|
6
|
+
@size = size
|
7
|
+
if scale_factor.is_a?(Array)
|
8
|
+
@scale_factor = scale_factor.map(&:to_f)
|
9
|
+
else
|
10
|
+
@scale_factor = scale_factor ? scale_factor.to_f : nil
|
11
|
+
end
|
12
|
+
@mode = mode
|
13
|
+
@align_corners = align_corners
|
14
|
+
end
|
15
|
+
|
16
|
+
def forward(input)
|
17
|
+
F.interpolate(input, size: @size, scale_factor: @scale_factor, mode: @mode, align_corners: @align_corners)
|
18
|
+
end
|
19
|
+
|
20
|
+
def extra_inspect
|
21
|
+
if !@scale_factor.nil?
|
22
|
+
info = "scale_factor: #{@scale_factor.inspect}"
|
23
|
+
else
|
24
|
+
info = "size: #{@size.inspect}"
|
25
|
+
end
|
26
|
+
info += ", mode: #{@mode.inspect}"
|
27
|
+
info
|
28
|
+
end
|
29
|
+
end
|
30
|
+
end
|
31
|
+
end
|
data/lib/torch/tensor.rb
CHANGED
@@ -48,6 +48,11 @@ module Torch
|
|
48
48
|
end
|
49
49
|
|
50
50
|
def to(device = nil, dtype: nil, non_blocking: false, copy: false)
|
51
|
+
if device.is_a?(Symbol) && !dtype
|
52
|
+
dtype = device
|
53
|
+
device = nil
|
54
|
+
end
|
55
|
+
|
51
56
|
device ||= self.device
|
52
57
|
device = Device.new(device) if device.is_a?(String)
|
53
58
|
|
@@ -74,10 +79,6 @@ module Torch
|
|
74
79
|
end
|
75
80
|
end
|
76
81
|
|
77
|
-
def shape
|
78
|
-
dim.times.map { |i| size(i) }
|
79
|
-
end
|
80
|
-
|
81
82
|
# mirror Python len()
|
82
83
|
def length
|
83
84
|
size(0)
|
@@ -119,9 +120,14 @@ module Torch
|
|
119
120
|
end
|
120
121
|
|
121
122
|
def type(dtype)
|
122
|
-
|
123
|
-
|
124
|
-
|
123
|
+
if dtype.is_a?(Class)
|
124
|
+
raise Error, "Invalid type: #{dtype}" unless TENSOR_TYPE_CLASSES.include?(dtype)
|
125
|
+
dtype.new(self)
|
126
|
+
else
|
127
|
+
enum = DTYPE_TO_ENUM[dtype]
|
128
|
+
raise Error, "Invalid type: #{dtype}" unless enum
|
129
|
+
_type(enum)
|
130
|
+
end
|
125
131
|
end
|
126
132
|
|
127
133
|
def reshape(*size)
|
data/lib/torch/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.3.
|
4
|
+
version: 0.3.7
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2020-09-
|
11
|
+
date: 2020-09-23 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|
@@ -238,6 +238,7 @@ files:
|
|
238
238
|
- lib/torch/nn/tanhshrink.rb
|
239
239
|
- lib/torch/nn/triplet_margin_loss.rb
|
240
240
|
- lib/torch/nn/unfold.rb
|
241
|
+
- lib/torch/nn/upsample.rb
|
241
242
|
- lib/torch/nn/utils.rb
|
242
243
|
- lib/torch/nn/weighted_loss.rb
|
243
244
|
- lib/torch/nn/zero_pad2d.rb
|