torch-rb 0.3.6 → 0.3.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +9 -0
- data/README.md +1 -0
- data/ext/torch/ext.cpp +29 -9
- data/ext/torch/extconf.rb +3 -0
- data/ext/torch/templates.cpp +28 -0
- data/ext/torch/templates.hpp +23 -34
- data/lib/torch.rb +35 -0
- data/lib/torch/native/dispatcher.rb +30 -8
- data/lib/torch/native/function.rb +87 -6
- data/lib/torch/native/generator.rb +28 -18
- data/lib/torch/native/parser.rb +55 -86
- data/lib/torch/nn/functional.rb +106 -0
- data/lib/torch/nn/module.rb +4 -1
- data/lib/torch/nn/upsample.rb +31 -0
- data/lib/torch/tensor.rb +13 -7
- data/lib/torch/version.rb +1 -1
- metadata +3 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 8a1852ee3d1ecc7a29c23259b8c328a95030a270b7c11f37f22049177898652e
|
4
|
+
data.tar.gz: 56823f1815d3c0c4d5d5c01ef76d781b792b3e4e7c68c0332a149b883a54c7c8
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: bed15510cfeaa555d71f1e1f46ed8944893bd349a07c4316dcd63429fe76e13facd8794399ef97fc400d05796579f2e84822b62c98c71dc996e211ad04113ae2
|
7
|
+
data.tar.gz: aa05e3645e363eda27274323cdb7fb316342074d1d5afe8f7ee6bfd9819da7883b43d084beb6b29011c631c04fdddc8e6789db41c7c84c53ba9ed152d3338b09
|
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,12 @@
|
|
1
|
+
## 0.3.7 (2020-09-22)
|
2
|
+
|
3
|
+
- Improved performance
|
4
|
+
- Added `Upsample`
|
5
|
+
- Added support for passing tensor class to `type` method
|
6
|
+
- Fixed error with buffers on GPU
|
7
|
+
- Fixed error with `new_full`
|
8
|
+
- Fixed issue with `numo` method and non-contiguous tensors
|
9
|
+
|
1
10
|
## 0.3.6 (2020-09-17)
|
2
11
|
|
3
12
|
- Added `inplace` option for leaky ReLU
|
data/README.md
CHANGED
@@ -402,6 +402,7 @@ Here are a few full examples:
|
|
402
402
|
- [Image classification with MNIST](examples/mnist) ([日本語版](https://qiita.com/kojix2/items/c19c36dc1bf73ea93409))
|
403
403
|
- [Collaborative filtering with MovieLens](examples/movielens)
|
404
404
|
- [Sequence models and word embeddings](examples/nlp)
|
405
|
+
- [Generative adversarial networks](examples/gan)
|
405
406
|
|
406
407
|
## LibTorch Installation
|
407
408
|
|
data/ext/torch/ext.cpp
CHANGED
@@ -232,7 +232,7 @@ void Init_ext()
|
|
232
232
|
})
|
233
233
|
.define_singleton_method(
|
234
234
|
"_empty",
|
235
|
-
*[](
|
235
|
+
*[](std::vector<int64_t> size, const torch::TensorOptions &options) {
|
236
236
|
return torch::empty(size, options);
|
237
237
|
})
|
238
238
|
.define_singleton_method(
|
@@ -242,7 +242,7 @@ void Init_ext()
|
|
242
242
|
})
|
243
243
|
.define_singleton_method(
|
244
244
|
"_full",
|
245
|
-
*[](
|
245
|
+
*[](std::vector<int64_t> size, Scalar fill_value, const torch::TensorOptions& options) {
|
246
246
|
return torch::full(size, fill_value, options);
|
247
247
|
})
|
248
248
|
.define_singleton_method(
|
@@ -257,22 +257,22 @@ void Init_ext()
|
|
257
257
|
})
|
258
258
|
.define_singleton_method(
|
259
259
|
"_ones",
|
260
|
-
*[](
|
260
|
+
*[](std::vector<int64_t> size, const torch::TensorOptions &options) {
|
261
261
|
return torch::ones(size, options);
|
262
262
|
})
|
263
263
|
.define_singleton_method(
|
264
264
|
"_rand",
|
265
|
-
*[](
|
265
|
+
*[](std::vector<int64_t> size, const torch::TensorOptions &options) {
|
266
266
|
return torch::rand(size, options);
|
267
267
|
})
|
268
268
|
.define_singleton_method(
|
269
269
|
"_randint",
|
270
|
-
*[](int64_t low, int64_t high,
|
270
|
+
*[](int64_t low, int64_t high, std::vector<int64_t> size, const torch::TensorOptions &options) {
|
271
271
|
return torch::randint(low, high, size, options);
|
272
272
|
})
|
273
273
|
.define_singleton_method(
|
274
274
|
"_randn",
|
275
|
-
*[](
|
275
|
+
*[](std::vector<int64_t> size, const torch::TensorOptions &options) {
|
276
276
|
return torch::randn(size, options);
|
277
277
|
})
|
278
278
|
.define_singleton_method(
|
@@ -282,7 +282,7 @@ void Init_ext()
|
|
282
282
|
})
|
283
283
|
.define_singleton_method(
|
284
284
|
"_zeros",
|
285
|
-
*[](
|
285
|
+
*[](std::vector<int64_t> size, const torch::TensorOptions &options) {
|
286
286
|
return torch::zeros(size, options);
|
287
287
|
})
|
288
288
|
// begin operations
|
@@ -303,13 +303,13 @@ void Init_ext()
|
|
303
303
|
})
|
304
304
|
.define_singleton_method(
|
305
305
|
"_from_blob",
|
306
|
-
*[](String s,
|
306
|
+
*[](String s, std::vector<int64_t> size, const torch::TensorOptions &options) {
|
307
307
|
void *data = const_cast<char *>(s.c_str());
|
308
308
|
return torch::from_blob(data, size, options);
|
309
309
|
})
|
310
310
|
.define_singleton_method(
|
311
311
|
"_tensor",
|
312
|
-
*[](Array a,
|
312
|
+
*[](Array a, std::vector<int64_t> size, const torch::TensorOptions &options) {
|
313
313
|
auto dtype = options.dtype();
|
314
314
|
torch::Tensor t;
|
315
315
|
if (dtype == torch::kBool) {
|
@@ -342,6 +342,16 @@ void Init_ext()
|
|
342
342
|
.define_method("numel", &torch::Tensor::numel)
|
343
343
|
.define_method("element_size", &torch::Tensor::element_size)
|
344
344
|
.define_method("requires_grad", &torch::Tensor::requires_grad)
|
345
|
+
// in C++ for performance
|
346
|
+
.define_method(
|
347
|
+
"shape",
|
348
|
+
*[](Tensor& self) {
|
349
|
+
Array a;
|
350
|
+
for (auto &size : self.sizes()) {
|
351
|
+
a.push(size);
|
352
|
+
}
|
353
|
+
return a;
|
354
|
+
})
|
345
355
|
.define_method(
|
346
356
|
"_index",
|
347
357
|
*[](Tensor& self, Array indices) {
|
@@ -420,9 +430,19 @@ void Init_ext()
|
|
420
430
|
tensor = tensor.to(device);
|
421
431
|
}
|
422
432
|
|
433
|
+
if (!tensor.is_contiguous()) {
|
434
|
+
tensor = tensor.contiguous();
|
435
|
+
}
|
436
|
+
|
423
437
|
auto data_ptr = (const char *) tensor.data_ptr();
|
424
438
|
return std::string(data_ptr, tensor.numel() * tensor.element_size());
|
425
439
|
})
|
440
|
+
// for TorchVision
|
441
|
+
.define_method(
|
442
|
+
"_data_ptr",
|
443
|
+
*[](Tensor& self) {
|
444
|
+
return reinterpret_cast<uintptr_t>(self.data_ptr());
|
445
|
+
})
|
426
446
|
// TODO figure out a better way to do this
|
427
447
|
.define_method(
|
428
448
|
"_flat_data",
|
data/ext/torch/extconf.rb
CHANGED
data/ext/torch/templates.cpp
CHANGED
@@ -2,6 +2,34 @@
|
|
2
2
|
#include <rice/Object.hpp>
|
3
3
|
#include "templates.hpp"
|
4
4
|
|
5
|
+
Object wrap(bool x) {
|
6
|
+
return to_ruby<bool>(x);
|
7
|
+
}
|
8
|
+
|
9
|
+
Object wrap(int64_t x) {
|
10
|
+
return to_ruby<int64_t>(x);
|
11
|
+
}
|
12
|
+
|
13
|
+
Object wrap(double x) {
|
14
|
+
return to_ruby<double>(x);
|
15
|
+
}
|
16
|
+
|
17
|
+
Object wrap(torch::Tensor x) {
|
18
|
+
return to_ruby<torch::Tensor>(x);
|
19
|
+
}
|
20
|
+
|
21
|
+
Object wrap(torch::Scalar x) {
|
22
|
+
return to_ruby<torch::Scalar>(x);
|
23
|
+
}
|
24
|
+
|
25
|
+
Object wrap(torch::ScalarType x) {
|
26
|
+
return to_ruby<torch::ScalarType>(x);
|
27
|
+
}
|
28
|
+
|
29
|
+
Object wrap(torch::QScheme x) {
|
30
|
+
return to_ruby<torch::QScheme>(x);
|
31
|
+
}
|
32
|
+
|
5
33
|
Object wrap(std::tuple<torch::Tensor, torch::Tensor> x) {
|
6
34
|
Array a;
|
7
35
|
a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
|
data/ext/torch/templates.hpp
CHANGED
@@ -13,49 +13,31 @@ using torch::Device;
|
|
13
13
|
using torch::Scalar;
|
14
14
|
using torch::ScalarType;
|
15
15
|
using torch::Tensor;
|
16
|
-
|
17
|
-
|
18
|
-
// it doesn't own underlying data
|
19
|
-
class IntArrayRef {
|
20
|
-
std::vector<int64_t> vec;
|
21
|
-
public:
|
22
|
-
IntArrayRef(Object o) {
|
23
|
-
Array a = Array(o);
|
24
|
-
for (size_t i = 0; i < a.size(); i++) {
|
25
|
-
vec.push_back(from_ruby<int64_t>(a[i]));
|
26
|
-
}
|
27
|
-
}
|
28
|
-
operator torch::IntArrayRef() {
|
29
|
-
return torch::IntArrayRef(vec);
|
30
|
-
}
|
31
|
-
};
|
16
|
+
using torch::IntArrayRef;
|
17
|
+
using torch::TensorList;
|
32
18
|
|
33
19
|
template<>
|
34
20
|
inline
|
35
|
-
|
21
|
+
std::vector<int64_t> from_ruby<std::vector<int64_t>>(Object x)
|
36
22
|
{
|
37
|
-
|
23
|
+
Array a = Array(x);
|
24
|
+
std::vector<int64_t> vec(a.size());
|
25
|
+
for (size_t i = 0; i < a.size(); i++) {
|
26
|
+
vec[i] = from_ruby<int64_t>(a[i]);
|
27
|
+
}
|
28
|
+
return vec;
|
38
29
|
}
|
39
30
|
|
40
|
-
class TensorList {
|
41
|
-
std::vector<torch::Tensor> vec;
|
42
|
-
public:
|
43
|
-
TensorList(Object o) {
|
44
|
-
Array a = Array(o);
|
45
|
-
for (size_t i = 0; i < a.size(); i++) {
|
46
|
-
vec.push_back(from_ruby<torch::Tensor>(a[i]));
|
47
|
-
}
|
48
|
-
}
|
49
|
-
operator torch::TensorList() {
|
50
|
-
return torch::TensorList(vec);
|
51
|
-
}
|
52
|
-
};
|
53
|
-
|
54
31
|
template<>
|
55
32
|
inline
|
56
|
-
|
33
|
+
std::vector<Tensor> from_ruby<std::vector<Tensor>>(Object x)
|
57
34
|
{
|
58
|
-
|
35
|
+
Array a = Array(x);
|
36
|
+
std::vector<Tensor> vec(a.size());
|
37
|
+
for (size_t i = 0; i < a.size(); i++) {
|
38
|
+
vec[i] = from_ruby<Tensor>(a[i]);
|
39
|
+
}
|
40
|
+
return vec;
|
59
41
|
}
|
60
42
|
|
61
43
|
class FanModeType {
|
@@ -242,6 +224,13 @@ torch::optional<Scalar> from_ruby<torch::optional<Scalar>>(Object x)
|
|
242
224
|
}
|
243
225
|
}
|
244
226
|
|
227
|
+
Object wrap(bool x);
|
228
|
+
Object wrap(int64_t x);
|
229
|
+
Object wrap(double x);
|
230
|
+
Object wrap(torch::Tensor x);
|
231
|
+
Object wrap(torch::Scalar x);
|
232
|
+
Object wrap(torch::ScalarType x);
|
233
|
+
Object wrap(torch::QScheme x);
|
245
234
|
Object wrap(std::tuple<torch::Tensor, torch::Tensor> x);
|
246
235
|
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x);
|
247
236
|
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x);
|
data/lib/torch.rb
CHANGED
@@ -174,6 +174,9 @@ require "torch/nn/smooth_l1_loss"
|
|
174
174
|
require "torch/nn/soft_margin_loss"
|
175
175
|
require "torch/nn/triplet_margin_loss"
|
176
176
|
|
177
|
+
# nn vision
|
178
|
+
require "torch/nn/upsample"
|
179
|
+
|
177
180
|
# nn other
|
178
181
|
require "torch/nn/functional"
|
179
182
|
require "torch/nn/init"
|
@@ -196,6 +199,32 @@ module Torch
|
|
196
199
|
end
|
197
200
|
end
|
198
201
|
|
202
|
+
# legacy
|
203
|
+
# but may make it easier to port tutorials
|
204
|
+
module Autograd
|
205
|
+
class Variable
|
206
|
+
def self.new(x)
|
207
|
+
raise ArgumentError, "Variable data has to be a tensor, but got #{x.class.name}" unless x.is_a?(Tensor)
|
208
|
+
warn "[torch] The Variable API is deprecated. Use tensors with requires_grad: true instead."
|
209
|
+
x
|
210
|
+
end
|
211
|
+
end
|
212
|
+
end
|
213
|
+
|
214
|
+
# TODO move to C++
|
215
|
+
class ByteStorage
|
216
|
+
# private
|
217
|
+
attr_reader :bytes
|
218
|
+
|
219
|
+
def initialize(bytes)
|
220
|
+
@bytes = bytes
|
221
|
+
end
|
222
|
+
|
223
|
+
def self.from_buffer(bytes)
|
224
|
+
new(bytes)
|
225
|
+
end
|
226
|
+
end
|
227
|
+
|
199
228
|
# keys: https://pytorch.org/docs/stable/tensor_attributes.html#torch.torch.dtype
|
200
229
|
# values: https://github.com/pytorch/pytorch/blob/master/c10/core/ScalarType.h
|
201
230
|
DTYPE_TO_ENUM = {
|
@@ -224,18 +253,24 @@ module Torch
|
|
224
253
|
}
|
225
254
|
ENUM_TO_DTYPE = DTYPE_TO_ENUM.map(&:reverse).to_h
|
226
255
|
|
256
|
+
TENSOR_TYPE_CLASSES = []
|
257
|
+
|
227
258
|
def self._make_tensor_class(dtype, cuda = false)
|
228
259
|
cls = Class.new
|
229
260
|
device = cuda ? "cuda" : "cpu"
|
230
261
|
cls.define_singleton_method("new") do |*args|
|
231
262
|
if args.size == 1 && args.first.is_a?(Tensor)
|
232
263
|
args.first.send(dtype).to(device)
|
264
|
+
elsif args.size == 1 && args.first.is_a?(ByteStorage) && dtype == :uint8
|
265
|
+
bytes = args.first.bytes
|
266
|
+
Torch._from_blob(bytes, [bytes.bytesize], TensorOptions.new.dtype(DTYPE_TO_ENUM[dtype]))
|
233
267
|
elsif args.size == 1 && args.first.is_a?(Array)
|
234
268
|
Torch.tensor(args.first, dtype: dtype, device: device)
|
235
269
|
else
|
236
270
|
Torch.empty(*args, dtype: dtype, device: device)
|
237
271
|
end
|
238
272
|
end
|
273
|
+
TENSOR_TYPE_CLASSES << cls
|
239
274
|
cls
|
240
275
|
end
|
241
276
|
|
@@ -22,21 +22,43 @@ module Torch
|
|
22
22
|
end
|
23
23
|
|
24
24
|
def bind_functions(context, def_method, functions)
|
25
|
+
instance_method = def_method == :define_method
|
25
26
|
functions.group_by(&:ruby_name).sort_by { |g, _| g }.each do |name, funcs|
|
26
|
-
if
|
27
|
+
if instance_method
|
27
28
|
funcs.map! { |f| Function.new(f.function) }
|
28
|
-
funcs.each { |f| f.args.reject! { |a| a[:name] ==
|
29
|
+
funcs.each { |f| f.args.reject! { |a| a[:name] == :self } }
|
29
30
|
end
|
30
31
|
|
31
|
-
defined =
|
32
|
+
defined = instance_method ? context.method_defined?(name) : context.respond_to?(name)
|
32
33
|
next if defined && name != "clone"
|
33
34
|
|
34
|
-
parser
|
35
|
+
# skip parser when possible for performance
|
36
|
+
if funcs.size == 1 && funcs.first.args.size == 0
|
37
|
+
# functions with no arguments
|
38
|
+
if instance_method
|
39
|
+
context.send(:alias_method, name, funcs.first.cpp_name)
|
40
|
+
else
|
41
|
+
context.singleton_class.send(:alias_method, name, funcs.first.cpp_name)
|
42
|
+
end
|
43
|
+
elsif funcs.size == 2 && funcs.map { |f| f.arg_types.values }.sort == [["Scalar"], ["Tensor"]]
|
44
|
+
# functions that take a tensor or scalar
|
45
|
+
scalar_name, tensor_name = funcs.sort_by { |f| f.arg_types.values }.map(&:cpp_name)
|
46
|
+
context.send(def_method, name) do |other|
|
47
|
+
case other
|
48
|
+
when Tensor
|
49
|
+
send(tensor_name, other)
|
50
|
+
else
|
51
|
+
send(scalar_name, other)
|
52
|
+
end
|
53
|
+
end
|
54
|
+
else
|
55
|
+
parser = Parser.new(funcs)
|
35
56
|
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
57
|
+
context.send(def_method, name) do |*args, **options|
|
58
|
+
result = parser.parse(args, options)
|
59
|
+
raise ArgumentError, result[:error] if result[:error]
|
60
|
+
send(result[:name], *result[:args])
|
61
|
+
end
|
40
62
|
end
|
41
63
|
end
|
42
64
|
end
|
@@ -6,9 +6,10 @@ module Torch
|
|
6
6
|
def initialize(function)
|
7
7
|
@function = function
|
8
8
|
|
9
|
-
|
10
|
-
@
|
11
|
-
@function["func"].
|
9
|
+
# note: don't modify function in-place
|
10
|
+
@tensor_options_str = ", *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None)"
|
11
|
+
@tensor_options = @function["func"].include?(@tensor_options_str)
|
12
|
+
@out = out_size > 0 && base_name[-1] != "_"
|
12
13
|
end
|
13
14
|
|
14
15
|
def func
|
@@ -31,7 +32,7 @@ module Torch
|
|
31
32
|
@args ||= begin
|
32
33
|
args = []
|
33
34
|
pos = true
|
34
|
-
args_str = func.split("(", 2).last.split(") ->").first
|
35
|
+
args_str = func.sub(@tensor_options_str, ")").split("(", 2).last.split(") ->").first
|
35
36
|
args_str.split(", ").each do |a|
|
36
37
|
if a == "*"
|
37
38
|
pos = false
|
@@ -72,12 +73,88 @@ module Torch
|
|
72
73
|
next if t == "Generator?"
|
73
74
|
next if t == "MemoryFormat"
|
74
75
|
next if t == "MemoryFormat?"
|
75
|
-
args << {name: k, type: t, default: d, pos: pos, has_default: has_default}
|
76
|
+
args << {name: k.to_sym, type: t, default: d, pos: pos, has_default: has_default}
|
76
77
|
end
|
77
78
|
args
|
78
79
|
end
|
79
80
|
end
|
80
81
|
|
82
|
+
def arg_checkers
|
83
|
+
@arg_checkers ||= begin
|
84
|
+
checkers = {}
|
85
|
+
arg_types.each do |k, t|
|
86
|
+
checker =
|
87
|
+
case t
|
88
|
+
when "Tensor"
|
89
|
+
->(v) { v.is_a?(Tensor) }
|
90
|
+
when "Tensor?"
|
91
|
+
->(v) { v.nil? || v.is_a?(Tensor) }
|
92
|
+
when "Tensor[]", "Tensor?[]"
|
93
|
+
->(v) { v.is_a?(Array) && v.all? { |v2| v2.is_a?(Tensor) } }
|
94
|
+
when "int"
|
95
|
+
if k == :reduction
|
96
|
+
->(v) { v.is_a?(String) }
|
97
|
+
else
|
98
|
+
->(v) { v.is_a?(Integer) }
|
99
|
+
end
|
100
|
+
when "int?"
|
101
|
+
->(v) { v.is_a?(Integer) || v.nil? }
|
102
|
+
when "float?"
|
103
|
+
->(v) { v.is_a?(Numeric) || v.nil? }
|
104
|
+
when "bool?"
|
105
|
+
->(v) { v == true || v == false || v.nil? }
|
106
|
+
when "float"
|
107
|
+
->(v) { v.is_a?(Numeric) }
|
108
|
+
when /int\[.*\]/
|
109
|
+
->(v) { v.is_a?(Array) && v.all? { |v2| v2.is_a?(Integer) } }
|
110
|
+
when "Scalar"
|
111
|
+
->(v) { v.is_a?(Numeric) }
|
112
|
+
when "Scalar?"
|
113
|
+
->(v) { v.is_a?(Numeric) || v.nil? }
|
114
|
+
when "ScalarType"
|
115
|
+
->(v) { false } # not supported yet
|
116
|
+
when "ScalarType?"
|
117
|
+
->(v) { v.nil? }
|
118
|
+
when "bool"
|
119
|
+
->(v) { v == true || v == false }
|
120
|
+
when "str"
|
121
|
+
->(v) { v.is_a?(String) }
|
122
|
+
else
|
123
|
+
raise Error, "Unknown argument type: #{t}. Please report a bug with #{@name}."
|
124
|
+
end
|
125
|
+
checkers[k] = checker
|
126
|
+
end
|
127
|
+
checkers
|
128
|
+
end
|
129
|
+
end
|
130
|
+
|
131
|
+
def int_array_lengths
|
132
|
+
@int_array_lengths ||= begin
|
133
|
+
ret = {}
|
134
|
+
arg_types.each do |k, t|
|
135
|
+
if t.match?(/\Aint\[.+\]\z/)
|
136
|
+
size = t[4..-2]
|
137
|
+
raise Error, "Unknown size: #{size}. Please report a bug with #{@name}." unless size =~ /\A\d+\z/
|
138
|
+
ret[k] = size.to_i
|
139
|
+
end
|
140
|
+
end
|
141
|
+
ret
|
142
|
+
end
|
143
|
+
end
|
144
|
+
|
145
|
+
def arg_names
|
146
|
+
@arg_names ||= args.map { |a| a[:name] }
|
147
|
+
end
|
148
|
+
|
149
|
+
def arg_types
|
150
|
+
@arg_types ||= args.map { |a| [a[:name], a[:type].split("(").first] }.to_h
|
151
|
+
end
|
152
|
+
|
153
|
+
def arg_defaults
|
154
|
+
# TODO find out why can't use select here
|
155
|
+
@arg_defaults ||= args.map { |a| [a[:name], a[:default]] }.to_h
|
156
|
+
end
|
157
|
+
|
81
158
|
def out_size
|
82
159
|
@out_size ||= func.split("->").last.count("!")
|
83
160
|
end
|
@@ -90,8 +167,12 @@ module Torch
|
|
90
167
|
@ret_array ||= func.split("->").last.include?('[]')
|
91
168
|
end
|
92
169
|
|
170
|
+
def ret_void?
|
171
|
+
func.split("->").last.strip == "()"
|
172
|
+
end
|
173
|
+
|
93
174
|
def out?
|
94
|
-
|
175
|
+
@out
|
95
176
|
end
|
96
177
|
|
97
178
|
def ruby_name
|
@@ -72,16 +72,18 @@ void add_%{type}_functions(Module m);
|
|
72
72
|
#include <rice/Module.hpp>
|
73
73
|
#include "templates.hpp"
|
74
74
|
|
75
|
+
%{functions}
|
76
|
+
|
75
77
|
void add_%{type}_functions(Module m) {
|
76
|
-
|
77
|
-
%{functions};
|
78
|
+
%{add_functions}
|
78
79
|
}
|
79
80
|
TEMPLATE
|
80
81
|
|
81
82
|
cpp_defs = []
|
83
|
+
add_defs = []
|
82
84
|
functions.sort_by(&:cpp_name).each do |func|
|
83
85
|
fargs = func.args.dup #.select { |a| a[:type] != "Generator?" }
|
84
|
-
fargs << {name:
|
86
|
+
fargs << {name: :options, type: "TensorOptions"} if func.tensor_options
|
85
87
|
|
86
88
|
cpp_args = []
|
87
89
|
fargs.each do |a|
|
@@ -94,11 +96,9 @@ void add_%{type}_functions(Module m) {
|
|
94
96
|
"OptionalTensor"
|
95
97
|
when "ScalarType?"
|
96
98
|
"torch::optional<ScalarType>"
|
97
|
-
when "Tensor[]"
|
98
|
-
"TensorList"
|
99
|
-
when "Tensor?[]"
|
99
|
+
when "Tensor[]", "Tensor?[]"
|
100
100
|
# TODO make optional
|
101
|
-
"
|
101
|
+
"std::vector<Tensor>"
|
102
102
|
when "int"
|
103
103
|
"int64_t"
|
104
104
|
when "int?"
|
@@ -112,43 +112,53 @@ void add_%{type}_functions(Module m) {
|
|
112
112
|
when "float"
|
113
113
|
"double"
|
114
114
|
when /\Aint\[/
|
115
|
-
"
|
115
|
+
"std::vector<int64_t>"
|
116
116
|
when /Tensor\(\S!?\)/
|
117
117
|
"Tensor &"
|
118
118
|
when "str"
|
119
119
|
"std::string"
|
120
120
|
when "TensorOptions"
|
121
121
|
"const torch::TensorOptions &"
|
122
|
-
|
122
|
+
when "Layout?"
|
123
|
+
"torch::optional<Layout>"
|
124
|
+
when "Device?"
|
125
|
+
"torch::optional<Device>"
|
126
|
+
when "Scalar", "bool", "ScalarType", "Layout", "Device", "Storage"
|
123
127
|
a[:type]
|
128
|
+
else
|
129
|
+
raise "Unknown type: #{a[:type]}"
|
124
130
|
end
|
125
131
|
|
126
|
-
t = "MyReduction" if a[:name] ==
|
132
|
+
t = "MyReduction" if a[:name] == :reduction && t == "int64_t"
|
127
133
|
cpp_args << [t, a[:name]].join(" ").sub("& ", "&")
|
128
134
|
end
|
129
135
|
|
130
136
|
dispatch = func.out? ? "#{func.base_name}_out" : func.base_name
|
131
137
|
args = fargs.map { |a| a[:name] }
|
132
138
|
args.unshift(*args.pop(func.out_size)) if func.out?
|
133
|
-
args.delete(
|
139
|
+
args.delete(:self) if def_method == :define_method
|
134
140
|
|
135
141
|
prefix = def_method == :define_method ? "self." : "torch::"
|
136
142
|
|
137
143
|
body = "#{prefix}#{dispatch}(#{args.join(", ")})"
|
138
144
|
|
139
|
-
if func.
|
145
|
+
if func.cpp_name == "_fill_diagonal_"
|
146
|
+
body = "to_ruby<torch::Tensor>(#{body})"
|
147
|
+
elsif !func.ret_void?
|
140
148
|
body = "wrap(#{body})"
|
141
149
|
end
|
142
150
|
|
143
|
-
cpp_defs << "
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
151
|
+
cpp_defs << "// #{func.func}
|
152
|
+
static #{func.ret_void? ? "void" : "Object"} #{type}#{func.cpp_name}(#{cpp_args.join(", ")})
|
153
|
+
{
|
154
|
+
return #{body};
|
155
|
+
}"
|
156
|
+
|
157
|
+
add_defs << "m.#{def_method}(\"#{func.cpp_name}\", #{type}#{func.cpp_name});"
|
148
158
|
end
|
149
159
|
|
150
160
|
hpp_contents = hpp_template % {type: type}
|
151
|
-
cpp_contents = cpp_template % {type: type, functions: cpp_defs.join("\n ")}
|
161
|
+
cpp_contents = cpp_template % {type: type, functions: cpp_defs.join("\n\n"), add_functions: add_defs.join("\n ")}
|
152
162
|
|
153
163
|
path = File.expand_path("../../../ext/torch", __dir__)
|
154
164
|
File.write("#{path}/#{type}_functions.hpp", hpp_contents)
|
data/lib/torch/native/parser.rb
CHANGED
@@ -6,14 +6,24 @@ module Torch
|
|
6
6
|
@name = @functions.first.ruby_name
|
7
7
|
@min_args = @functions.map { |f| f.args.count { |a| a[:pos] && !a[:has_default] } }.min
|
8
8
|
@max_args = @functions.map { |f| f.args.count { |a| a[:pos] } }.max
|
9
|
+
@int_array_first = @functions.all? { |c| c.args.first && c.args.first[:type] == "int[]" }
|
9
10
|
end
|
10
11
|
|
12
|
+
# TODO improve performance
|
13
|
+
# possibly move to C++ (see python_arg_parser.cpp)
|
11
14
|
def parse(args, options)
|
12
15
|
candidates = @functions.dup
|
13
16
|
|
14
|
-
#
|
15
|
-
|
16
|
-
|
17
|
+
# TODO check candidates individually to see if they match
|
18
|
+
if @int_array_first
|
19
|
+
int_args = []
|
20
|
+
while args.first.is_a?(Integer)
|
21
|
+
int_args << args.shift
|
22
|
+
end
|
23
|
+
if int_args.any?
|
24
|
+
raise ArgumentError, "argument '#{candidates.first.args.first[:name]}' must be array of ints, but found element of type #{args.first.class.name} at pos #{int_args.size + 1}" if args.any?
|
25
|
+
args.unshift(int_args)
|
26
|
+
end
|
17
27
|
end
|
18
28
|
|
19
29
|
# TODO account for args passed as options here
|
@@ -25,99 +35,60 @@ module Torch
|
|
25
35
|
|
26
36
|
candidates.reject! { |f| args.size > f.args.size }
|
27
37
|
|
28
|
-
# exclude functions missing required options
|
29
|
-
candidates.reject! do |func|
|
30
|
-
# TODO make more generic
|
31
|
-
func.out? && !options[:out]
|
32
|
-
end
|
33
|
-
|
34
38
|
# handle out with multiple
|
35
39
|
# there should only be one match, so safe to modify all
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
end
|
44
|
-
|
45
|
-
# exclude functions where options don't match
|
46
|
-
options.each do |k, v|
|
47
|
-
candidates.select! do |func|
|
48
|
-
func.args.any? { |a| a[:name] == k.to_s }
|
40
|
+
if options[:out]
|
41
|
+
if (out_func = candidates.find { |f| f.out? }) && out_func.out_size > 1
|
42
|
+
out_args = out_func.args.last(2).map { |a| a[:name] }
|
43
|
+
out_args.zip(options.delete(:out)).each do |k, v|
|
44
|
+
options[k] = v
|
45
|
+
end
|
46
|
+
candidates = [out_func]
|
49
47
|
end
|
50
|
-
|
51
|
-
|
48
|
+
else
|
49
|
+
# exclude functions missing required options
|
50
|
+
candidates.reject!(&:out?)
|
52
51
|
end
|
53
52
|
|
54
|
-
final_values =
|
53
|
+
final_values = nil
|
55
54
|
|
56
55
|
# check args
|
57
|
-
candidates.
|
56
|
+
while (func = candidates.shift)
|
58
57
|
good = true
|
59
58
|
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
59
|
+
# set values
|
60
|
+
# TODO use array instead of hash?
|
61
|
+
values = {}
|
62
|
+
args.each_with_index do |a, i|
|
63
|
+
values[func.arg_names[i]] = a
|
64
|
+
end
|
65
|
+
options.each do |k, v|
|
66
|
+
values[k] = v
|
67
|
+
end
|
68
|
+
func.arg_defaults.each do |k, v|
|
69
|
+
values[k] = v unless values.key?(k)
|
70
|
+
end
|
71
|
+
func.int_array_lengths.each do |k, len|
|
72
|
+
values[k] = [values[k]] * len if values[k].is_a?(Integer)
|
64
73
|
end
|
65
74
|
|
66
|
-
|
75
|
+
arg_checkers = func.arg_checkers
|
67
76
|
|
68
77
|
values.each_key do |k|
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
when "Tensor"
|
75
|
-
v.is_a?(Tensor)
|
76
|
-
when "Tensor?"
|
77
|
-
v.nil? || v.is_a?(Tensor)
|
78
|
-
when "Tensor[]", "Tensor?[]"
|
79
|
-
v.is_a?(Array) && v.all? { |v2| v2.is_a?(Tensor) }
|
80
|
-
when "int"
|
81
|
-
if k == "reduction"
|
82
|
-
v.is_a?(String)
|
83
|
-
else
|
84
|
-
v.is_a?(Integer)
|
85
|
-
end
|
86
|
-
when "int?"
|
87
|
-
v.is_a?(Integer) || v.nil?
|
88
|
-
when "float?"
|
89
|
-
v.is_a?(Numeric) || v.nil?
|
90
|
-
when "bool?"
|
91
|
-
v == true || v == false || v.nil?
|
92
|
-
when "float"
|
93
|
-
v.is_a?(Numeric)
|
94
|
-
when /int\[.*\]/
|
95
|
-
if v.is_a?(Integer)
|
96
|
-
size = t[4..-2]
|
97
|
-
raise Error, "Unknown size: #{size}. Please report a bug with #{@name}." unless size =~ /\A\d+\z/
|
98
|
-
v = [v] * size.to_i
|
99
|
-
values[k] = v
|
100
|
-
end
|
101
|
-
v.is_a?(Array) && v.all? { |v2| v2.is_a?(Integer) }
|
102
|
-
when "Scalar"
|
103
|
-
v.is_a?(Numeric)
|
104
|
-
when "Scalar?"
|
105
|
-
v.is_a?(Numeric) || v.nil?
|
106
|
-
when "ScalarType"
|
107
|
-
false # not supported yet
|
108
|
-
when "ScalarType?"
|
109
|
-
v.nil?
|
110
|
-
when "bool"
|
111
|
-
v == true || v == false
|
112
|
-
when "str"
|
113
|
-
v.is_a?(String)
|
114
|
-
else
|
115
|
-
raise Error, "Unknown argument type: #{arg_types[k]}. Please report a bug with #{@name}."
|
78
|
+
unless arg_checkers.key?(k)
|
79
|
+
good = false
|
80
|
+
if candidates.empty?
|
81
|
+
# TODO show all bad keywords at once like Ruby?
|
82
|
+
return {error: "unknown keyword: #{k}"}
|
116
83
|
end
|
84
|
+
break
|
85
|
+
end
|
117
86
|
|
118
|
-
|
119
|
-
|
120
|
-
|
87
|
+
unless arg_checkers[k].call(values[k])
|
88
|
+
good = false
|
89
|
+
if candidates.empty?
|
90
|
+
t = func.arg_types[k]
|
91
|
+
k = :input if k == :self
|
121
92
|
return {error: "#{@name}(): argument '#{k}' must be #{t}"}
|
122
93
|
end
|
123
94
|
break
|
@@ -126,17 +97,15 @@ module Torch
|
|
126
97
|
|
127
98
|
if good
|
128
99
|
final_values = values
|
100
|
+
break
|
129
101
|
end
|
130
|
-
|
131
|
-
good
|
132
102
|
end
|
133
103
|
|
134
|
-
|
104
|
+
unless final_values
|
135
105
|
raise Error, "This should never happen. Please report a bug with #{@name}."
|
136
106
|
end
|
137
107
|
|
138
|
-
|
139
|
-
args = func.args.map { |a| final_values[a[:name]] }
|
108
|
+
args = func.arg_names.map { |k| final_values[k] }
|
140
109
|
args << TensorOptions.new.dtype(6) if func.tensor_options
|
141
110
|
{
|
142
111
|
name: func.cpp_name,
|
data/lib/torch/nn/functional.rb
CHANGED
@@ -469,6 +469,77 @@ module Torch
|
|
469
469
|
Torch.triplet_margin_loss(anchor, positive, negative, margin, p, eps, swap, reduction)
|
470
470
|
end
|
471
471
|
|
472
|
+
# vision
|
473
|
+
|
474
|
+
def interpolate(input, size: nil, scale_factor: nil, mode: "nearest", align_corners: nil, recompute_scale_factor: nil)
|
475
|
+
if ["nearest", "area"].include?(mode)
|
476
|
+
unless align_corners.nil?
|
477
|
+
raise ArgumentError, "align_corners option can only be set with the interpolating modes: linear | bilinear | bicubic | trilinear"
|
478
|
+
end
|
479
|
+
else
|
480
|
+
if align_corners.nil?
|
481
|
+
align_corners = false
|
482
|
+
end
|
483
|
+
end
|
484
|
+
|
485
|
+
scale_factor_len = input.dim - 2
|
486
|
+
scale_factor_list = [nil] * scale_factor_len
|
487
|
+
# default value of recompute_scale_factor is False
|
488
|
+
if !scale_factor.nil? && (recompute_scale_factor == false || recompute_scale_factor.nil?)
|
489
|
+
if scale_factor.is_a?(Array)
|
490
|
+
_scale_factor_repeated = scale_factor
|
491
|
+
else
|
492
|
+
_scale_factor_repeated = [scale_factor] * scale_factor_len
|
493
|
+
end
|
494
|
+
scale_factor_list = _scale_factor_repeated
|
495
|
+
end
|
496
|
+
|
497
|
+
# Give this variable a short name because it has to be repeated multiple times below.
|
498
|
+
sfl = scale_factor_list
|
499
|
+
|
500
|
+
closed_over_args = [input, size, scale_factor, recompute_scale_factor]
|
501
|
+
output_size = _interp_output_size(closed_over_args)
|
502
|
+
if input.dim == 3 && mode == "nearest"
|
503
|
+
NN.upsample_nearest1d(input, output_size, sfl[0])
|
504
|
+
elsif input.dim == 4 && mode == "nearest"
|
505
|
+
NN.upsample_nearest2d(input, output_size, sfl[0], sfl[1])
|
506
|
+
elsif input.dim == 5 && mode == "nearest"
|
507
|
+
NN.upsample_nearest3d(input, output_size, sfl[0], sfl[1], sfl[2])
|
508
|
+
elsif input.dim == 3 && mode == "area"
|
509
|
+
adaptive_avg_pool1d(input, output_size)
|
510
|
+
elsif input.dim == 4 && mode == "area"
|
511
|
+
adaptive_avg_pool2d(input, output_size)
|
512
|
+
elsif input.dim == 5 && mode == "area"
|
513
|
+
adaptive_avg_pool3d(input, output_size)
|
514
|
+
elsif input.dim == 3 && mode == "linear"
|
515
|
+
# assert align_corners is not None
|
516
|
+
NN.upsample_linear1d(input, output_size, align_corners, sfl[0])
|
517
|
+
elsif input.dim == 3 && mode == "bilinear"
|
518
|
+
raise ArgumentError, "Got 3D input, but bilinear mode needs 4D input"
|
519
|
+
elsif input.dim == 3 && mode == "trilinear"
|
520
|
+
raise ArgumentError, "Got 3D input, but trilinear mode needs 5D input"
|
521
|
+
elsif input.dim == 4 && mode == "linear"
|
522
|
+
raise ArgumentError, "Got 4D input, but linear mode needs 3D input"
|
523
|
+
elsif input.dim == 4 && mode == "bilinear"
|
524
|
+
# assert align_corners is not None
|
525
|
+
NN.upsample_bilinear2d(input, output_size, align_corners, sfl[0], sfl[1])
|
526
|
+
elsif input.dim == 4 && mode == "trilinear"
|
527
|
+
raise ArgumentError, "Got 4D input, but trilinear mode needs 5D input"
|
528
|
+
elsif input.dim == 5 && mode == "linear"
|
529
|
+
raise ArgumentError, "Got 5D input, but linear mode needs 3D input"
|
530
|
+
elsif input.dim == 5 && mode == "bilinear"
|
531
|
+
raise ArgumentError, "Got 5D input, but bilinear mode needs 4D input"
|
532
|
+
elsif input.dim == 5 && mode == "trilinear"
|
533
|
+
# assert align_corners is not None
|
534
|
+
NN.upsample_trilinear3d(input, output_size, align_corners, sfl[0], sfl[1], sfl[2])
|
535
|
+
elsif input.dim == 4 && mode == "bicubic"
|
536
|
+
# assert align_corners is not None
|
537
|
+
NN.upsample_bicubic2d(input, output_size, align_corners, sfl[0], sfl[1])
|
538
|
+
else
|
539
|
+
raise ArgumentError, "Input Error: Only 3D, 4D and 5D input Tensors supported (got #{input.dim}D) for the modes: nearest | linear | bilinear | bicubic | trilinear (got #{mode})"
|
540
|
+
end
|
541
|
+
end
|
542
|
+
|
472
543
|
private
|
473
544
|
|
474
545
|
def softmax_dim(ndim)
|
@@ -484,6 +555,41 @@ module Torch
|
|
484
555
|
out_size.zip(defaults.last(out_size.length)).map { |v, d| v || d }
|
485
556
|
end
|
486
557
|
end
|
558
|
+
|
559
|
+
def _interp_output_size(closed_over_args)
|
560
|
+
input, size, scale_factor, recompute_scale_factor = closed_over_args
|
561
|
+
dim = input.dim - 2
|
562
|
+
if size.nil? && scale_factor.nil?
|
563
|
+
raise ArgumentError, "either size or scale_factor should be defined"
|
564
|
+
end
|
565
|
+
if !size.nil? && !scale_factor.nil?
|
566
|
+
raise ArgumentError, "only one of size or scale_factor should be defined"
|
567
|
+
end
|
568
|
+
if !scale_factor.nil?
|
569
|
+
if scale_factor.is_a?(Array)
|
570
|
+
if scale_factor.length != dim
|
571
|
+
raise ArgumentError, "scale_factor shape must match input shape. Input is #{dim}D, scale_factor size is #{scale_factor.length}"
|
572
|
+
end
|
573
|
+
end
|
574
|
+
end
|
575
|
+
|
576
|
+
if !size.nil?
|
577
|
+
if size.is_a?(Array)
|
578
|
+
return size
|
579
|
+
else
|
580
|
+
return [size] * dim
|
581
|
+
end
|
582
|
+
end
|
583
|
+
|
584
|
+
raise "Failed assertion" if scale_factor.nil?
|
585
|
+
if scale_factor.is_a?(Array)
|
586
|
+
scale_factors = scale_factor
|
587
|
+
else
|
588
|
+
scale_factors = [scale_factor] * dim
|
589
|
+
end
|
590
|
+
|
591
|
+
dim.times.map { |i| (input.size(i + 2) * scale_factors[i]).floor }
|
592
|
+
end
|
487
593
|
end
|
488
594
|
end
|
489
595
|
|
data/lib/torch/nn/module.rb
CHANGED
@@ -0,0 +1,31 @@
|
|
1
|
+
module Torch
|
2
|
+
module NN
|
3
|
+
class Upsample < Module
|
4
|
+
def initialize(size: nil, scale_factor: nil, mode: "nearest", align_corners: nil)
|
5
|
+
super()
|
6
|
+
@size = size
|
7
|
+
if scale_factor.is_a?(Array)
|
8
|
+
@scale_factor = scale_factor.map(&:to_f)
|
9
|
+
else
|
10
|
+
@scale_factor = scale_factor ? scale_factor.to_f : nil
|
11
|
+
end
|
12
|
+
@mode = mode
|
13
|
+
@align_corners = align_corners
|
14
|
+
end
|
15
|
+
|
16
|
+
def forward(input)
|
17
|
+
F.interpolate(input, size: @size, scale_factor: @scale_factor, mode: @mode, align_corners: @align_corners)
|
18
|
+
end
|
19
|
+
|
20
|
+
def extra_inspect
|
21
|
+
if !@scale_factor.nil?
|
22
|
+
info = "scale_factor: #{@scale_factor.inspect}"
|
23
|
+
else
|
24
|
+
info = "size: #{@size.inspect}"
|
25
|
+
end
|
26
|
+
info += ", mode: #{@mode.inspect}"
|
27
|
+
info
|
28
|
+
end
|
29
|
+
end
|
30
|
+
end
|
31
|
+
end
|
data/lib/torch/tensor.rb
CHANGED
@@ -48,6 +48,11 @@ module Torch
|
|
48
48
|
end
|
49
49
|
|
50
50
|
def to(device = nil, dtype: nil, non_blocking: false, copy: false)
|
51
|
+
if device.is_a?(Symbol) && !dtype
|
52
|
+
dtype = device
|
53
|
+
device = nil
|
54
|
+
end
|
55
|
+
|
51
56
|
device ||= self.device
|
52
57
|
device = Device.new(device) if device.is_a?(String)
|
53
58
|
|
@@ -74,10 +79,6 @@ module Torch
|
|
74
79
|
end
|
75
80
|
end
|
76
81
|
|
77
|
-
def shape
|
78
|
-
dim.times.map { |i| size(i) }
|
79
|
-
end
|
80
|
-
|
81
82
|
# mirror Python len()
|
82
83
|
def length
|
83
84
|
size(0)
|
@@ -119,9 +120,14 @@ module Torch
|
|
119
120
|
end
|
120
121
|
|
121
122
|
def type(dtype)
|
122
|
-
|
123
|
-
|
124
|
-
|
123
|
+
if dtype.is_a?(Class)
|
124
|
+
raise Error, "Invalid type: #{dtype}" unless TENSOR_TYPE_CLASSES.include?(dtype)
|
125
|
+
dtype.new(self)
|
126
|
+
else
|
127
|
+
enum = DTYPE_TO_ENUM[dtype]
|
128
|
+
raise Error, "Invalid type: #{dtype}" unless enum
|
129
|
+
_type(enum)
|
130
|
+
end
|
125
131
|
end
|
126
132
|
|
127
133
|
def reshape(*size)
|
data/lib/torch/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.3.
|
4
|
+
version: 0.3.7
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2020-09-
|
11
|
+
date: 2020-09-23 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|
@@ -238,6 +238,7 @@ files:
|
|
238
238
|
- lib/torch/nn/tanhshrink.rb
|
239
239
|
- lib/torch/nn/triplet_margin_loss.rb
|
240
240
|
- lib/torch/nn/unfold.rb
|
241
|
+
- lib/torch/nn/upsample.rb
|
241
242
|
- lib/torch/nn/utils.rb
|
242
243
|
- lib/torch/nn/weighted_loss.rb
|
243
244
|
- lib/torch/nn/zero_pad2d.rb
|