torch-rb 0.3.2 → 0.3.7
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +28 -0
- data/README.md +7 -2
- data/ext/torch/ext.cpp +60 -20
- data/ext/torch/extconf.rb +3 -0
- data/ext/torch/templates.cpp +36 -0
- data/ext/torch/templates.hpp +81 -87
- data/lib/torch.rb +71 -19
- data/lib/torch/native/dispatcher.rb +30 -8
- data/lib/torch/native/function.rb +93 -4
- data/lib/torch/native/generator.rb +45 -41
- data/lib/torch/native/parser.rb +57 -76
- data/lib/torch/nn/functional.rb +112 -2
- data/lib/torch/nn/leaky_relu.rb +3 -3
- data/lib/torch/nn/module.rb +9 -1
- data/lib/torch/nn/upsample.rb +31 -0
- data/lib/torch/tensor.rb +45 -51
- data/lib/torch/utils/data/data_loader.rb +2 -0
- data/lib/torch/utils/data/tensor_dataset.rb +2 -0
- data/lib/torch/version.rb +1 -1
- metadata +3 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 8a1852ee3d1ecc7a29c23259b8c328a95030a270b7c11f37f22049177898652e
|
4
|
+
data.tar.gz: 56823f1815d3c0c4d5d5c01ef76d781b792b3e4e7c68c0332a149b883a54c7c8
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: bed15510cfeaa555d71f1e1f46ed8944893bd349a07c4316dcd63429fe76e13facd8794399ef97fc400d05796579f2e84822b62c98c71dc996e211ad04113ae2
|
7
|
+
data.tar.gz: aa05e3645e363eda27274323cdb7fb316342074d1d5afe8f7ee6bfd9819da7883b43d084beb6b29011c631c04fdddc8e6789db41c7c84c53ba9ed152d3338b09
|
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,31 @@
|
|
1
|
+
## 0.3.7 (2020-09-22)
|
2
|
+
|
3
|
+
- Improved performance
|
4
|
+
- Added `Upsample`
|
5
|
+
- Added support for passing tensor class to `type` method
|
6
|
+
- Fixed error with buffers on GPU
|
7
|
+
- Fixed error with `new_full`
|
8
|
+
- Fixed issue with `numo` method and non-contiguous tensors
|
9
|
+
|
10
|
+
## 0.3.6 (2020-09-17)
|
11
|
+
|
12
|
+
- Added `inplace` option for leaky ReLU
|
13
|
+
- Fixed error with methods that return a tensor list (`chunk`, `split`, and `unbind`)
|
14
|
+
- Fixed error with buffers on GPU
|
15
|
+
|
16
|
+
## 0.3.5 (2020-09-04)
|
17
|
+
|
18
|
+
- Fixed error with data loader (due to `dtype` of `randperm`)
|
19
|
+
|
20
|
+
## 0.3.4 (2020-08-26)
|
21
|
+
|
22
|
+
- Added `Torch.clamp` method
|
23
|
+
|
24
|
+
## 0.3.3 (2020-08-25)
|
25
|
+
|
26
|
+
- Added spectral ops
|
27
|
+
- Fixed tensor indexing
|
28
|
+
|
1
29
|
## 0.3.2 (2020-08-24)
|
2
30
|
|
3
31
|
- Added `enable_grad` method
|
data/README.md
CHANGED
@@ -2,7 +2,11 @@
|
|
2
2
|
|
3
3
|
:fire: Deep learning for Ruby, powered by [LibTorch](https://pytorch.org)
|
4
4
|
|
5
|
-
|
5
|
+
Check out:
|
6
|
+
|
7
|
+
- [TorchVision](https://github.com/ankane/torchvision) for computer vision tasks
|
8
|
+
- [TorchText](https://github.com/ankane/torchtext) for text and NLP tasks
|
9
|
+
- [TorchAudio](https://github.com/ankane/torchaudio) for audio tasks
|
6
10
|
|
7
11
|
[Build Status](https://travis-ci.org/ankane/torch.rb)
|
8
12
|
|
@@ -398,6 +402,7 @@ Here are a few full examples:
|
|
398
402
|
- [Image classification with MNIST](examples/mnist) ([日本語版](https://qiita.com/kojix2/items/c19c36dc1bf73ea93409))
|
399
403
|
- [Collaborative filtering with MovieLens](examples/movielens)
|
400
404
|
- [Sequence models and word embeddings](examples/nlp)
|
405
|
+
- [Generative adversarial networks](examples/gan)
|
401
406
|
|
402
407
|
## LibTorch Installation
|
403
408
|
|
@@ -411,7 +416,7 @@ Here’s the list of compatible versions.
|
|
411
416
|
|
412
417
|
Torch.rb | LibTorch
|
413
418
|
--- | ---
|
414
|
-
0.3.0-0.3.
|
419
|
+
0.3.0-0.3.4 | 1.6.0
|
415
420
|
0.2.0-0.2.7 | 1.5.0-1.5.1
|
416
421
|
0.1.8 | 1.4.0
|
417
422
|
0.1.0-0.1.7 | 1.3.1
|
data/ext/torch/ext.cpp
CHANGED
@@ -16,6 +16,7 @@
|
|
16
16
|
#include "nn_functions.hpp"
|
17
17
|
|
18
18
|
using namespace Rice;
|
19
|
+
using torch::indexing::TensorIndex;
|
19
20
|
|
20
21
|
// need to make a distinction between parameters and tensors
|
21
22
|
class Parameter: public torch::autograd::Variable {
|
@@ -28,6 +29,15 @@ void handle_error(torch::Error const & ex)
|
|
28
29
|
throw Exception(rb_eRuntimeError, ex.what_without_backtrace());
|
29
30
|
}
|
30
31
|
|
32
|
+
std::vector<TensorIndex> index_vector(Array a) {
|
33
|
+
auto indices = std::vector<TensorIndex>();
|
34
|
+
indices.reserve(a.size());
|
35
|
+
for (size_t i = 0; i < a.size(); i++) {
|
36
|
+
indices.push_back(from_ruby<TensorIndex>(a[i]));
|
37
|
+
}
|
38
|
+
return indices;
|
39
|
+
}
|
40
|
+
|
31
41
|
extern "C"
|
32
42
|
void Init_ext()
|
33
43
|
{
|
@@ -58,6 +68,13 @@ void Init_ext()
|
|
58
68
|
return generator.seed();
|
59
69
|
});
|
60
70
|
|
71
|
+
Class rb_cTensorIndex = define_class_under<TensorIndex>(rb_mTorch, "TensorIndex")
|
72
|
+
.define_singleton_method("boolean", *[](bool value) { return TensorIndex(value); })
|
73
|
+
.define_singleton_method("integer", *[](int64_t value) { return TensorIndex(value); })
|
74
|
+
.define_singleton_method("tensor", *[](torch::Tensor& value) { return TensorIndex(value); })
|
75
|
+
.define_singleton_method("slice", *[](torch::optional<int64_t> start_index, torch::optional<int64_t> stop_index) { return TensorIndex(torch::indexing::Slice(start_index, stop_index)); })
|
76
|
+
.define_singleton_method("none", *[]() { return TensorIndex(torch::indexing::None); });
|
77
|
+
|
61
78
|
// https://pytorch.org/cppdocs/api/structc10_1_1_i_value.html
|
62
79
|
Class rb_cIValue = define_class_under<torch::IValue>(rb_mTorch, "IValue")
|
63
80
|
.add_handler<torch::Error>(handle_error)
|
@@ -215,7 +232,7 @@ void Init_ext()
|
|
215
232
|
})
|
216
233
|
.define_singleton_method(
|
217
234
|
"_empty",
|
218
|
-
*[](
|
235
|
+
*[](std::vector<int64_t> size, const torch::TensorOptions &options) {
|
219
236
|
return torch::empty(size, options);
|
220
237
|
})
|
221
238
|
.define_singleton_method(
|
@@ -225,7 +242,7 @@ void Init_ext()
|
|
225
242
|
})
|
226
243
|
.define_singleton_method(
|
227
244
|
"_full",
|
228
|
-
*[](
|
245
|
+
*[](std::vector<int64_t> size, Scalar fill_value, const torch::TensorOptions& options) {
|
229
246
|
return torch::full(size, fill_value, options);
|
230
247
|
})
|
231
248
|
.define_singleton_method(
|
@@ -240,22 +257,22 @@ void Init_ext()
|
|
240
257
|
})
|
241
258
|
.define_singleton_method(
|
242
259
|
"_ones",
|
243
|
-
*[](
|
260
|
+
*[](std::vector<int64_t> size, const torch::TensorOptions &options) {
|
244
261
|
return torch::ones(size, options);
|
245
262
|
})
|
246
263
|
.define_singleton_method(
|
247
264
|
"_rand",
|
248
|
-
*[](
|
265
|
+
*[](std::vector<int64_t> size, const torch::TensorOptions &options) {
|
249
266
|
return torch::rand(size, options);
|
250
267
|
})
|
251
268
|
.define_singleton_method(
|
252
269
|
"_randint",
|
253
|
-
*[](int64_t low, int64_t high,
|
270
|
+
*[](int64_t low, int64_t high, std::vector<int64_t> size, const torch::TensorOptions &options) {
|
254
271
|
return torch::randint(low, high, size, options);
|
255
272
|
})
|
256
273
|
.define_singleton_method(
|
257
274
|
"_randn",
|
258
|
-
*[](
|
275
|
+
*[](std::vector<int64_t> size, const torch::TensorOptions &options) {
|
259
276
|
return torch::randn(size, options);
|
260
277
|
})
|
261
278
|
.define_singleton_method(
|
@@ -265,7 +282,7 @@ void Init_ext()
|
|
265
282
|
})
|
266
283
|
.define_singleton_method(
|
267
284
|
"_zeros",
|
268
|
-
*[](
|
285
|
+
*[](std::vector<int64_t> size, const torch::TensorOptions &options) {
|
269
286
|
return torch::zeros(size, options);
|
270
287
|
})
|
271
288
|
// begin operations
|
@@ -284,20 +301,15 @@ void Init_ext()
|
|
284
301
|
// https://github.com/pytorch/pytorch/issues/20356#issuecomment-567663701
|
285
302
|
return torch::pickle_load(v);
|
286
303
|
})
|
287
|
-
.define_singleton_method(
|
288
|
-
"_binary_cross_entropy_with_logits",
|
289
|
-
*[](const Tensor &input, const Tensor &target, OptionalTensor weight, OptionalTensor pos_weight, MyReduction reduction) {
|
290
|
-
return torch::binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction);
|
291
|
-
})
|
292
304
|
.define_singleton_method(
|
293
305
|
"_from_blob",
|
294
|
-
*[](String s,
|
306
|
+
*[](String s, std::vector<int64_t> size, const torch::TensorOptions &options) {
|
295
307
|
void *data = const_cast<char *>(s.c_str());
|
296
308
|
return torch::from_blob(data, size, options);
|
297
309
|
})
|
298
310
|
.define_singleton_method(
|
299
311
|
"_tensor",
|
300
|
-
*[](Array a,
|
312
|
+
*[](Array a, std::vector<int64_t> size, const torch::TensorOptions &options) {
|
301
313
|
auto dtype = options.dtype();
|
302
314
|
torch::Tensor t;
|
303
315
|
if (dtype == torch::kBool) {
|
@@ -330,6 +342,28 @@ void Init_ext()
|
|
330
342
|
.define_method("numel", &torch::Tensor::numel)
|
331
343
|
.define_method("element_size", &torch::Tensor::element_size)
|
332
344
|
.define_method("requires_grad", &torch::Tensor::requires_grad)
|
345
|
+
// in C++ for performance
|
346
|
+
.define_method(
|
347
|
+
"shape",
|
348
|
+
*[](Tensor& self) {
|
349
|
+
Array a;
|
350
|
+
for (auto &size : self.sizes()) {
|
351
|
+
a.push(size);
|
352
|
+
}
|
353
|
+
return a;
|
354
|
+
})
|
355
|
+
.define_method(
|
356
|
+
"_index",
|
357
|
+
*[](Tensor& self, Array indices) {
|
358
|
+
auto vec = index_vector(indices);
|
359
|
+
return self.index(vec);
|
360
|
+
})
|
361
|
+
.define_method(
|
362
|
+
"_index_put_custom",
|
363
|
+
*[](Tensor& self, Array indices, torch::Tensor& value) {
|
364
|
+
auto vec = index_vector(indices);
|
365
|
+
return self.index_put_(vec, value);
|
366
|
+
})
|
333
367
|
.define_method(
|
334
368
|
"contiguous?",
|
335
369
|
*[](Tensor& self) {
|
@@ -350,11 +384,6 @@ void Init_ext()
|
|
350
384
|
*[](Tensor& self, bool requires_grad) {
|
351
385
|
return self.set_requires_grad(requires_grad);
|
352
386
|
})
|
353
|
-
.define_method(
|
354
|
-
"_backward",
|
355
|
-
*[](Tensor& self, OptionalTensor gradient, bool create_graph, bool retain_graph) {
|
356
|
-
return self.backward(gradient, create_graph, retain_graph);
|
357
|
-
})
|
358
387
|
.define_method(
|
359
388
|
"grad",
|
360
389
|
*[](Tensor& self) {
|
@@ -401,9 +430,19 @@ void Init_ext()
|
|
401
430
|
tensor = tensor.to(device);
|
402
431
|
}
|
403
432
|
|
433
|
+
if (!tensor.is_contiguous()) {
|
434
|
+
tensor = tensor.contiguous();
|
435
|
+
}
|
436
|
+
|
404
437
|
auto data_ptr = (const char *) tensor.data_ptr();
|
405
438
|
return std::string(data_ptr, tensor.numel() * tensor.element_size());
|
406
439
|
})
|
440
|
+
// for TorchVision
|
441
|
+
.define_method(
|
442
|
+
"_data_ptr",
|
443
|
+
*[](Tensor& self) {
|
444
|
+
return reinterpret_cast<uintptr_t>(self.data_ptr());
|
445
|
+
})
|
407
446
|
// TODO figure out a better way to do this
|
408
447
|
.define_method(
|
409
448
|
"_flat_data",
|
@@ -508,6 +547,7 @@ void Init_ext()
|
|
508
547
|
});
|
509
548
|
|
510
549
|
Module rb_mInit = define_module_under(rb_mNN, "Init")
|
550
|
+
.add_handler<torch::Error>(handle_error)
|
511
551
|
.define_singleton_method(
|
512
552
|
"_calculate_gain",
|
513
553
|
*[](NonlinearityType nonlinearity, double param) {
|
@@ -594,8 +634,8 @@ void Init_ext()
|
|
594
634
|
});
|
595
635
|
|
596
636
|
Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")
|
597
|
-
.define_constructor(Constructor<torch::Device, std::string>())
|
598
637
|
.add_handler<torch::Error>(handle_error)
|
638
|
+
.define_constructor(Constructor<torch::Device, std::string>())
|
599
639
|
.define_method("index", &torch::Device::index)
|
600
640
|
.define_method("index?", &torch::Device::has_index)
|
601
641
|
.define_method(
|
data/ext/torch/extconf.rb
CHANGED
data/ext/torch/templates.cpp
CHANGED
@@ -2,6 +2,34 @@
|
|
2
2
|
#include <rice/Object.hpp>
|
3
3
|
#include "templates.hpp"
|
4
4
|
|
5
|
+
Object wrap(bool x) {
|
6
|
+
return to_ruby<bool>(x);
|
7
|
+
}
|
8
|
+
|
9
|
+
Object wrap(int64_t x) {
|
10
|
+
return to_ruby<int64_t>(x);
|
11
|
+
}
|
12
|
+
|
13
|
+
Object wrap(double x) {
|
14
|
+
return to_ruby<double>(x);
|
15
|
+
}
|
16
|
+
|
17
|
+
Object wrap(torch::Tensor x) {
|
18
|
+
return to_ruby<torch::Tensor>(x);
|
19
|
+
}
|
20
|
+
|
21
|
+
Object wrap(torch::Scalar x) {
|
22
|
+
return to_ruby<torch::Scalar>(x);
|
23
|
+
}
|
24
|
+
|
25
|
+
Object wrap(torch::ScalarType x) {
|
26
|
+
return to_ruby<torch::ScalarType>(x);
|
27
|
+
}
|
28
|
+
|
29
|
+
Object wrap(torch::QScheme x) {
|
30
|
+
return to_ruby<torch::QScheme>(x);
|
31
|
+
}
|
32
|
+
|
5
33
|
Object wrap(std::tuple<torch::Tensor, torch::Tensor> x) {
|
6
34
|
Array a;
|
7
35
|
a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
|
@@ -53,3 +81,11 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
|
|
53
81
|
a.push(to_ruby<int64_t>(std::get<3>(x)));
|
54
82
|
return Object(a);
|
55
83
|
}
|
84
|
+
|
85
|
+
Object wrap(std::vector<torch::Tensor> x) {
|
86
|
+
Array a;
|
87
|
+
for (auto& t : x) {
|
88
|
+
a.push(to_ruby<torch::Tensor>(t));
|
89
|
+
}
|
90
|
+
return Object(a);
|
91
|
+
}
|
data/ext/torch/templates.hpp
CHANGED
@@ -9,72 +9,35 @@
|
|
9
9
|
|
10
10
|
using namespace Rice;
|
11
11
|
|
12
|
-
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
Array a = Array(o);
|
19
|
-
for (size_t i = 0; i < a.size(); i++) {
|
20
|
-
vec.push_back(from_ruby<int64_t>(a[i]));
|
21
|
-
}
|
22
|
-
}
|
23
|
-
operator torch::IntArrayRef() {
|
24
|
-
return torch::IntArrayRef(vec);
|
25
|
-
}
|
26
|
-
};
|
12
|
+
using torch::Device;
|
13
|
+
using torch::Scalar;
|
14
|
+
using torch::ScalarType;
|
15
|
+
using torch::Tensor;
|
16
|
+
using torch::IntArrayRef;
|
17
|
+
using torch::TensorList;
|
27
18
|
|
28
19
|
template<>
|
29
20
|
inline
|
30
|
-
|
21
|
+
std::vector<int64_t> from_ruby<std::vector<int64_t>>(Object x)
|
31
22
|
{
|
32
|
-
|
23
|
+
Array a = Array(x);
|
24
|
+
std::vector<int64_t> vec(a.size());
|
25
|
+
for (size_t i = 0; i < a.size(); i++) {
|
26
|
+
vec[i] = from_ruby<int64_t>(a[i]);
|
27
|
+
}
|
28
|
+
return vec;
|
33
29
|
}
|
34
30
|
|
35
|
-
// for now
|
36
|
-
class Scalar {
|
37
|
-
torch::Scalar value;
|
38
|
-
public:
|
39
|
-
Scalar(Object o) {
|
40
|
-
// TODO cast based on Ruby type
|
41
|
-
if (o.rb_type() == T_FIXNUM) {
|
42
|
-
value = torch::Scalar(from_ruby<int64_t>(o));
|
43
|
-
} else {
|
44
|
-
value = torch::Scalar(from_ruby<float>(o));
|
45
|
-
}
|
46
|
-
}
|
47
|
-
operator torch::Scalar() {
|
48
|
-
return value;
|
49
|
-
}
|
50
|
-
};
|
51
|
-
|
52
31
|
template<>
|
53
32
|
inline
|
54
|
-
|
33
|
+
std::vector<Tensor> from_ruby<std::vector<Tensor>>(Object x)
|
55
34
|
{
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
TensorList(Object o) {
|
63
|
-
Array a = Array(o);
|
64
|
-
for (size_t i = 0; i < a.size(); i++) {
|
65
|
-
vec.push_back(from_ruby<torch::Tensor>(a[i]));
|
66
|
-
}
|
67
|
-
}
|
68
|
-
operator torch::TensorList() {
|
69
|
-
return torch::TensorList(vec);
|
70
|
-
}
|
71
|
-
};
|
72
|
-
|
73
|
-
template<>
|
74
|
-
inline
|
75
|
-
TensorList from_ruby<TensorList>(Object x)
|
76
|
-
{
|
77
|
-
return TensorList(x);
|
35
|
+
Array a = Array(x);
|
36
|
+
std::vector<Tensor> vec(a.size());
|
37
|
+
for (size_t i = 0; i < a.size(); i++) {
|
38
|
+
vec[i] = from_ruby<Tensor>(a[i]);
|
39
|
+
}
|
40
|
+
return vec;
|
78
41
|
}
|
79
42
|
|
80
43
|
class FanModeType {
|
@@ -174,8 +137,6 @@ MyReduction from_ruby<MyReduction>(Object x)
|
|
174
137
|
return MyReduction(x);
|
175
138
|
}
|
176
139
|
|
177
|
-
typedef torch::Tensor Tensor;
|
178
|
-
|
179
140
|
class OptionalTensor {
|
180
141
|
Object value;
|
181
142
|
public:
|
@@ -190,6 +151,17 @@ class OptionalTensor {
|
|
190
151
|
}
|
191
152
|
};
|
192
153
|
|
154
|
+
template<>
|
155
|
+
inline
|
156
|
+
Scalar from_ruby<Scalar>(Object x)
|
157
|
+
{
|
158
|
+
if (x.rb_type() == T_FIXNUM) {
|
159
|
+
return torch::Scalar(from_ruby<int64_t>(x));
|
160
|
+
} else {
|
161
|
+
return torch::Scalar(from_ruby<double>(x));
|
162
|
+
}
|
163
|
+
}
|
164
|
+
|
193
165
|
template<>
|
194
166
|
inline
|
195
167
|
OptionalTensor from_ruby<OptionalTensor>(Object x)
|
@@ -197,50 +169,72 @@ OptionalTensor from_ruby<OptionalTensor>(Object x)
|
|
197
169
|
return OptionalTensor(x);
|
198
170
|
}
|
199
171
|
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
}
|
172
|
+
template<>
|
173
|
+
inline
|
174
|
+
torch::optional<torch::ScalarType> from_ruby<torch::optional<torch::ScalarType>>(Object x)
|
175
|
+
{
|
176
|
+
if (x.is_nil()) {
|
177
|
+
return torch::nullopt;
|
178
|
+
} else {
|
179
|
+
return torch::optional<torch::ScalarType>{from_ruby<torch::ScalarType>(x)};
|
180
|
+
}
|
181
|
+
}
|
210
182
|
|
211
183
|
template<>
|
212
184
|
inline
|
213
|
-
|
185
|
+
torch::optional<int64_t> from_ruby<torch::optional<int64_t>>(Object x)
|
214
186
|
{
|
215
|
-
|
187
|
+
if (x.is_nil()) {
|
188
|
+
return torch::nullopt;
|
189
|
+
} else {
|
190
|
+
return torch::optional<int64_t>{from_ruby<int64_t>(x)};
|
191
|
+
}
|
216
192
|
}
|
217
193
|
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
return ScalarType(value);
|
229
|
-
}
|
230
|
-
};
|
194
|
+
template<>
|
195
|
+
inline
|
196
|
+
torch::optional<double> from_ruby<torch::optional<double>>(Object x)
|
197
|
+
{
|
198
|
+
if (x.is_nil()) {
|
199
|
+
return torch::nullopt;
|
200
|
+
} else {
|
201
|
+
return torch::optional<double>{from_ruby<double>(x)};
|
202
|
+
}
|
203
|
+
}
|
231
204
|
|
232
205
|
template<>
|
233
206
|
inline
|
234
|
-
|
207
|
+
torch::optional<bool> from_ruby<torch::optional<bool>>(Object x)
|
235
208
|
{
|
236
|
-
|
209
|
+
if (x.is_nil()) {
|
210
|
+
return torch::nullopt;
|
211
|
+
} else {
|
212
|
+
return torch::optional<bool>{from_ruby<bool>(x)};
|
213
|
+
}
|
237
214
|
}
|
238
215
|
|
239
|
-
|
216
|
+
template<>
|
217
|
+
inline
|
218
|
+
torch::optional<Scalar> from_ruby<torch::optional<Scalar>>(Object x)
|
219
|
+
{
|
220
|
+
if (x.is_nil()) {
|
221
|
+
return torch::nullopt;
|
222
|
+
} else {
|
223
|
+
return torch::optional<Scalar>{from_ruby<Scalar>(x)};
|
224
|
+
}
|
225
|
+
}
|
240
226
|
|
227
|
+
Object wrap(bool x);
|
228
|
+
Object wrap(int64_t x);
|
229
|
+
Object wrap(double x);
|
230
|
+
Object wrap(torch::Tensor x);
|
231
|
+
Object wrap(torch::Scalar x);
|
232
|
+
Object wrap(torch::ScalarType x);
|
233
|
+
Object wrap(torch::QScheme x);
|
241
234
|
Object wrap(std::tuple<torch::Tensor, torch::Tensor> x);
|
242
235
|
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x);
|
243
236
|
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x);
|
244
237
|
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x);
|
245
238
|
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x);
|
246
239
|
Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x);
|
240
|
+
Object wrap(std::vector<torch::Tensor> x);
|