torch-rb 0.3.2 → 0.3.7
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +28 -0
- data/README.md +7 -2
- data/ext/torch/ext.cpp +60 -20
- data/ext/torch/extconf.rb +3 -0
- data/ext/torch/templates.cpp +36 -0
- data/ext/torch/templates.hpp +81 -87
- data/lib/torch.rb +71 -19
- data/lib/torch/native/dispatcher.rb +30 -8
- data/lib/torch/native/function.rb +93 -4
- data/lib/torch/native/generator.rb +45 -41
- data/lib/torch/native/parser.rb +57 -76
- data/lib/torch/nn/functional.rb +112 -2
- data/lib/torch/nn/leaky_relu.rb +3 -3
- data/lib/torch/nn/module.rb +9 -1
- data/lib/torch/nn/upsample.rb +31 -0
- data/lib/torch/tensor.rb +45 -51
- data/lib/torch/utils/data/data_loader.rb +2 -0
- data/lib/torch/utils/data/tensor_dataset.rb +2 -0
- data/lib/torch/version.rb +1 -1
- metadata +3 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 8a1852ee3d1ecc7a29c23259b8c328a95030a270b7c11f37f22049177898652e
|
4
|
+
data.tar.gz: 56823f1815d3c0c4d5d5c01ef76d781b792b3e4e7c68c0332a149b883a54c7c8
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: bed15510cfeaa555d71f1e1f46ed8944893bd349a07c4316dcd63429fe76e13facd8794399ef97fc400d05796579f2e84822b62c98c71dc996e211ad04113ae2
|
7
|
+
data.tar.gz: aa05e3645e363eda27274323cdb7fb316342074d1d5afe8f7ee6bfd9819da7883b43d084beb6b29011c631c04fdddc8e6789db41c7c84c53ba9ed152d3338b09
|
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,31 @@
|
|
1
|
+
## 0.3.7 (2020-09-22)
|
2
|
+
|
3
|
+
- Improved performance
|
4
|
+
- Added `Upsample`
|
5
|
+
- Added support for passing tensor class to `type` method
|
6
|
+
- Fixed error with buffers on GPU
|
7
|
+
- Fixed error with `new_full`
|
8
|
+
- Fixed issue with `numo` method and non-contiguous tensors
|
9
|
+
|
10
|
+
## 0.3.6 (2020-09-17)
|
11
|
+
|
12
|
+
- Added `inplace` option for leaky ReLU
|
13
|
+
- Fixed error with methods that return a tensor list (`chunk`, `split`, and `unbind`)
|
14
|
+
- Fixed error with buffers on GPU
|
15
|
+
|
16
|
+
## 0.3.5 (2020-09-04)
|
17
|
+
|
18
|
+
- Fixed error with data loader (due to `dtype` of `randperm`)
|
19
|
+
|
20
|
+
## 0.3.4 (2020-08-26)
|
21
|
+
|
22
|
+
- Added `Torch.clamp` method
|
23
|
+
|
24
|
+
## 0.3.3 (2020-08-25)
|
25
|
+
|
26
|
+
- Added spectral ops
|
27
|
+
- Fixed tensor indexing
|
28
|
+
|
1
29
|
## 0.3.2 (2020-08-24)
|
2
30
|
|
3
31
|
- Added `enable_grad` method
|
data/README.md
CHANGED
@@ -2,7 +2,11 @@
|
|
2
2
|
|
3
3
|
:fire: Deep learning for Ruby, powered by [LibTorch](https://pytorch.org)
|
4
4
|
|
5
|
-
|
5
|
+
Check out:
|
6
|
+
|
7
|
+
- [TorchVision](https://github.com/ankane/torchvision) for computer vision tasks
|
8
|
+
- [TorchText](https://github.com/ankane/torchtext) for text and NLP tasks
|
9
|
+
- [TorchAudio](https://github.com/ankane/torchaudio) for audio tasks
|
6
10
|
|
7
11
|
[![Build Status](https://travis-ci.org/ankane/torch.rb.svg?branch=master)](https://travis-ci.org/ankane/torch.rb)
|
8
12
|
|
@@ -398,6 +402,7 @@ Here are a few full examples:
|
|
398
402
|
- [Image classification with MNIST](examples/mnist) ([日本語版](https://qiita.com/kojix2/items/c19c36dc1bf73ea93409))
|
399
403
|
- [Collaborative filtering with MovieLens](examples/movielens)
|
400
404
|
- [Sequence models and word embeddings](examples/nlp)
|
405
|
+
- [Generative adversarial networks](examples/gan)
|
401
406
|
|
402
407
|
## LibTorch Installation
|
403
408
|
|
@@ -411,7 +416,7 @@ Here’s the list of compatible versions.
|
|
411
416
|
|
412
417
|
Torch.rb | LibTorch
|
413
418
|
--- | ---
|
414
|
-
0.3.0-0.3.
|
419
|
+
0.3.0-0.3.4 | 1.6.0
|
415
420
|
0.2.0-0.2.7 | 1.5.0-1.5.1
|
416
421
|
0.1.8 | 1.4.0
|
417
422
|
0.1.0-0.1.7 | 1.3.1
|
data/ext/torch/ext.cpp
CHANGED
@@ -16,6 +16,7 @@
|
|
16
16
|
#include "nn_functions.hpp"
|
17
17
|
|
18
18
|
using namespace Rice;
|
19
|
+
using torch::indexing::TensorIndex;
|
19
20
|
|
20
21
|
// need to make a distinction between parameters and tensors
|
21
22
|
class Parameter: public torch::autograd::Variable {
|
@@ -28,6 +29,15 @@ void handle_error(torch::Error const & ex)
|
|
28
29
|
throw Exception(rb_eRuntimeError, ex.what_without_backtrace());
|
29
30
|
}
|
30
31
|
|
32
|
+
std::vector<TensorIndex> index_vector(Array a) {
|
33
|
+
auto indices = std::vector<TensorIndex>();
|
34
|
+
indices.reserve(a.size());
|
35
|
+
for (size_t i = 0; i < a.size(); i++) {
|
36
|
+
indices.push_back(from_ruby<TensorIndex>(a[i]));
|
37
|
+
}
|
38
|
+
return indices;
|
39
|
+
}
|
40
|
+
|
31
41
|
extern "C"
|
32
42
|
void Init_ext()
|
33
43
|
{
|
@@ -58,6 +68,13 @@ void Init_ext()
|
|
58
68
|
return generator.seed();
|
59
69
|
});
|
60
70
|
|
71
|
+
Class rb_cTensorIndex = define_class_under<TensorIndex>(rb_mTorch, "TensorIndex")
|
72
|
+
.define_singleton_method("boolean", *[](bool value) { return TensorIndex(value); })
|
73
|
+
.define_singleton_method("integer", *[](int64_t value) { return TensorIndex(value); })
|
74
|
+
.define_singleton_method("tensor", *[](torch::Tensor& value) { return TensorIndex(value); })
|
75
|
+
.define_singleton_method("slice", *[](torch::optional<int64_t> start_index, torch::optional<int64_t> stop_index) { return TensorIndex(torch::indexing::Slice(start_index, stop_index)); })
|
76
|
+
.define_singleton_method("none", *[]() { return TensorIndex(torch::indexing::None); });
|
77
|
+
|
61
78
|
// https://pytorch.org/cppdocs/api/structc10_1_1_i_value.html
|
62
79
|
Class rb_cIValue = define_class_under<torch::IValue>(rb_mTorch, "IValue")
|
63
80
|
.add_handler<torch::Error>(handle_error)
|
@@ -215,7 +232,7 @@ void Init_ext()
|
|
215
232
|
})
|
216
233
|
.define_singleton_method(
|
217
234
|
"_empty",
|
218
|
-
*[](
|
235
|
+
*[](std::vector<int64_t> size, const torch::TensorOptions &options) {
|
219
236
|
return torch::empty(size, options);
|
220
237
|
})
|
221
238
|
.define_singleton_method(
|
@@ -225,7 +242,7 @@ void Init_ext()
|
|
225
242
|
})
|
226
243
|
.define_singleton_method(
|
227
244
|
"_full",
|
228
|
-
*[](
|
245
|
+
*[](std::vector<int64_t> size, Scalar fill_value, const torch::TensorOptions& options) {
|
229
246
|
return torch::full(size, fill_value, options);
|
230
247
|
})
|
231
248
|
.define_singleton_method(
|
@@ -240,22 +257,22 @@ void Init_ext()
|
|
240
257
|
})
|
241
258
|
.define_singleton_method(
|
242
259
|
"_ones",
|
243
|
-
*[](
|
260
|
+
*[](std::vector<int64_t> size, const torch::TensorOptions &options) {
|
244
261
|
return torch::ones(size, options);
|
245
262
|
})
|
246
263
|
.define_singleton_method(
|
247
264
|
"_rand",
|
248
|
-
*[](
|
265
|
+
*[](std::vector<int64_t> size, const torch::TensorOptions &options) {
|
249
266
|
return torch::rand(size, options);
|
250
267
|
})
|
251
268
|
.define_singleton_method(
|
252
269
|
"_randint",
|
253
|
-
*[](int64_t low, int64_t high,
|
270
|
+
*[](int64_t low, int64_t high, std::vector<int64_t> size, const torch::TensorOptions &options) {
|
254
271
|
return torch::randint(low, high, size, options);
|
255
272
|
})
|
256
273
|
.define_singleton_method(
|
257
274
|
"_randn",
|
258
|
-
*[](
|
275
|
+
*[](std::vector<int64_t> size, const torch::TensorOptions &options) {
|
259
276
|
return torch::randn(size, options);
|
260
277
|
})
|
261
278
|
.define_singleton_method(
|
@@ -265,7 +282,7 @@ void Init_ext()
|
|
265
282
|
})
|
266
283
|
.define_singleton_method(
|
267
284
|
"_zeros",
|
268
|
-
*[](
|
285
|
+
*[](std::vector<int64_t> size, const torch::TensorOptions &options) {
|
269
286
|
return torch::zeros(size, options);
|
270
287
|
})
|
271
288
|
// begin operations
|
@@ -284,20 +301,15 @@ void Init_ext()
|
|
284
301
|
// https://github.com/pytorch/pytorch/issues/20356#issuecomment-567663701
|
285
302
|
return torch::pickle_load(v);
|
286
303
|
})
|
287
|
-
.define_singleton_method(
|
288
|
-
"_binary_cross_entropy_with_logits",
|
289
|
-
*[](const Tensor &input, const Tensor &target, OptionalTensor weight, OptionalTensor pos_weight, MyReduction reduction) {
|
290
|
-
return torch::binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction);
|
291
|
-
})
|
292
304
|
.define_singleton_method(
|
293
305
|
"_from_blob",
|
294
|
-
*[](String s,
|
306
|
+
*[](String s, std::vector<int64_t> size, const torch::TensorOptions &options) {
|
295
307
|
void *data = const_cast<char *>(s.c_str());
|
296
308
|
return torch::from_blob(data, size, options);
|
297
309
|
})
|
298
310
|
.define_singleton_method(
|
299
311
|
"_tensor",
|
300
|
-
*[](Array a,
|
312
|
+
*[](Array a, std::vector<int64_t> size, const torch::TensorOptions &options) {
|
301
313
|
auto dtype = options.dtype();
|
302
314
|
torch::Tensor t;
|
303
315
|
if (dtype == torch::kBool) {
|
@@ -330,6 +342,28 @@ void Init_ext()
|
|
330
342
|
.define_method("numel", &torch::Tensor::numel)
|
331
343
|
.define_method("element_size", &torch::Tensor::element_size)
|
332
344
|
.define_method("requires_grad", &torch::Tensor::requires_grad)
|
345
|
+
// in C++ for performance
|
346
|
+
.define_method(
|
347
|
+
"shape",
|
348
|
+
*[](Tensor& self) {
|
349
|
+
Array a;
|
350
|
+
for (auto &size : self.sizes()) {
|
351
|
+
a.push(size);
|
352
|
+
}
|
353
|
+
return a;
|
354
|
+
})
|
355
|
+
.define_method(
|
356
|
+
"_index",
|
357
|
+
*[](Tensor& self, Array indices) {
|
358
|
+
auto vec = index_vector(indices);
|
359
|
+
return self.index(vec);
|
360
|
+
})
|
361
|
+
.define_method(
|
362
|
+
"_index_put_custom",
|
363
|
+
*[](Tensor& self, Array indices, torch::Tensor& value) {
|
364
|
+
auto vec = index_vector(indices);
|
365
|
+
return self.index_put_(vec, value);
|
366
|
+
})
|
333
367
|
.define_method(
|
334
368
|
"contiguous?",
|
335
369
|
*[](Tensor& self) {
|
@@ -350,11 +384,6 @@ void Init_ext()
|
|
350
384
|
*[](Tensor& self, bool requires_grad) {
|
351
385
|
return self.set_requires_grad(requires_grad);
|
352
386
|
})
|
353
|
-
.define_method(
|
354
|
-
"_backward",
|
355
|
-
*[](Tensor& self, OptionalTensor gradient, bool create_graph, bool retain_graph) {
|
356
|
-
return self.backward(gradient, create_graph, retain_graph);
|
357
|
-
})
|
358
387
|
.define_method(
|
359
388
|
"grad",
|
360
389
|
*[](Tensor& self) {
|
@@ -401,9 +430,19 @@ void Init_ext()
|
|
401
430
|
tensor = tensor.to(device);
|
402
431
|
}
|
403
432
|
|
433
|
+
if (!tensor.is_contiguous()) {
|
434
|
+
tensor = tensor.contiguous();
|
435
|
+
}
|
436
|
+
|
404
437
|
auto data_ptr = (const char *) tensor.data_ptr();
|
405
438
|
return std::string(data_ptr, tensor.numel() * tensor.element_size());
|
406
439
|
})
|
440
|
+
// for TorchVision
|
441
|
+
.define_method(
|
442
|
+
"_data_ptr",
|
443
|
+
*[](Tensor& self) {
|
444
|
+
return reinterpret_cast<uintptr_t>(self.data_ptr());
|
445
|
+
})
|
407
446
|
// TODO figure out a better way to do this
|
408
447
|
.define_method(
|
409
448
|
"_flat_data",
|
@@ -508,6 +547,7 @@ void Init_ext()
|
|
508
547
|
});
|
509
548
|
|
510
549
|
Module rb_mInit = define_module_under(rb_mNN, "Init")
|
550
|
+
.add_handler<torch::Error>(handle_error)
|
511
551
|
.define_singleton_method(
|
512
552
|
"_calculate_gain",
|
513
553
|
*[](NonlinearityType nonlinearity, double param) {
|
@@ -594,8 +634,8 @@ void Init_ext()
|
|
594
634
|
});
|
595
635
|
|
596
636
|
Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")
|
597
|
-
.define_constructor(Constructor<torch::Device, std::string>())
|
598
637
|
.add_handler<torch::Error>(handle_error)
|
638
|
+
.define_constructor(Constructor<torch::Device, std::string>())
|
599
639
|
.define_method("index", &torch::Device::index)
|
600
640
|
.define_method("index?", &torch::Device::has_index)
|
601
641
|
.define_method(
|
data/ext/torch/extconf.rb
CHANGED
data/ext/torch/templates.cpp
CHANGED
@@ -2,6 +2,34 @@
|
|
2
2
|
#include <rice/Object.hpp>
|
3
3
|
#include "templates.hpp"
|
4
4
|
|
5
|
+
Object wrap(bool x) {
|
6
|
+
return to_ruby<bool>(x);
|
7
|
+
}
|
8
|
+
|
9
|
+
Object wrap(int64_t x) {
|
10
|
+
return to_ruby<int64_t>(x);
|
11
|
+
}
|
12
|
+
|
13
|
+
Object wrap(double x) {
|
14
|
+
return to_ruby<double>(x);
|
15
|
+
}
|
16
|
+
|
17
|
+
Object wrap(torch::Tensor x) {
|
18
|
+
return to_ruby<torch::Tensor>(x);
|
19
|
+
}
|
20
|
+
|
21
|
+
Object wrap(torch::Scalar x) {
|
22
|
+
return to_ruby<torch::Scalar>(x);
|
23
|
+
}
|
24
|
+
|
25
|
+
Object wrap(torch::ScalarType x) {
|
26
|
+
return to_ruby<torch::ScalarType>(x);
|
27
|
+
}
|
28
|
+
|
29
|
+
Object wrap(torch::QScheme x) {
|
30
|
+
return to_ruby<torch::QScheme>(x);
|
31
|
+
}
|
32
|
+
|
5
33
|
Object wrap(std::tuple<torch::Tensor, torch::Tensor> x) {
|
6
34
|
Array a;
|
7
35
|
a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
|
@@ -53,3 +81,11 @@ Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x) {
|
|
53
81
|
a.push(to_ruby<int64_t>(std::get<3>(x)));
|
54
82
|
return Object(a);
|
55
83
|
}
|
84
|
+
|
85
|
+
Object wrap(std::vector<torch::Tensor> x) {
|
86
|
+
Array a;
|
87
|
+
for (auto& t : x) {
|
88
|
+
a.push(to_ruby<torch::Tensor>(t));
|
89
|
+
}
|
90
|
+
return Object(a);
|
91
|
+
}
|
data/ext/torch/templates.hpp
CHANGED
@@ -9,72 +9,35 @@
|
|
9
9
|
|
10
10
|
using namespace Rice;
|
11
11
|
|
12
|
-
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
Array a = Array(o);
|
19
|
-
for (size_t i = 0; i < a.size(); i++) {
|
20
|
-
vec.push_back(from_ruby<int64_t>(a[i]));
|
21
|
-
}
|
22
|
-
}
|
23
|
-
operator torch::IntArrayRef() {
|
24
|
-
return torch::IntArrayRef(vec);
|
25
|
-
}
|
26
|
-
};
|
12
|
+
using torch::Device;
|
13
|
+
using torch::Scalar;
|
14
|
+
using torch::ScalarType;
|
15
|
+
using torch::Tensor;
|
16
|
+
using torch::IntArrayRef;
|
17
|
+
using torch::TensorList;
|
27
18
|
|
28
19
|
template<>
|
29
20
|
inline
|
30
|
-
|
21
|
+
std::vector<int64_t> from_ruby<std::vector<int64_t>>(Object x)
|
31
22
|
{
|
32
|
-
|
23
|
+
Array a = Array(x);
|
24
|
+
std::vector<int64_t> vec(a.size());
|
25
|
+
for (size_t i = 0; i < a.size(); i++) {
|
26
|
+
vec[i] = from_ruby<int64_t>(a[i]);
|
27
|
+
}
|
28
|
+
return vec;
|
33
29
|
}
|
34
30
|
|
35
|
-
// for now
|
36
|
-
class Scalar {
|
37
|
-
torch::Scalar value;
|
38
|
-
public:
|
39
|
-
Scalar(Object o) {
|
40
|
-
// TODO cast based on Ruby type
|
41
|
-
if (o.rb_type() == T_FIXNUM) {
|
42
|
-
value = torch::Scalar(from_ruby<int64_t>(o));
|
43
|
-
} else {
|
44
|
-
value = torch::Scalar(from_ruby<float>(o));
|
45
|
-
}
|
46
|
-
}
|
47
|
-
operator torch::Scalar() {
|
48
|
-
return value;
|
49
|
-
}
|
50
|
-
};
|
51
|
-
|
52
31
|
template<>
|
53
32
|
inline
|
54
|
-
|
33
|
+
std::vector<Tensor> from_ruby<std::vector<Tensor>>(Object x)
|
55
34
|
{
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
TensorList(Object o) {
|
63
|
-
Array a = Array(o);
|
64
|
-
for (size_t i = 0; i < a.size(); i++) {
|
65
|
-
vec.push_back(from_ruby<torch::Tensor>(a[i]));
|
66
|
-
}
|
67
|
-
}
|
68
|
-
operator torch::TensorList() {
|
69
|
-
return torch::TensorList(vec);
|
70
|
-
}
|
71
|
-
};
|
72
|
-
|
73
|
-
template<>
|
74
|
-
inline
|
75
|
-
TensorList from_ruby<TensorList>(Object x)
|
76
|
-
{
|
77
|
-
return TensorList(x);
|
35
|
+
Array a = Array(x);
|
36
|
+
std::vector<Tensor> vec(a.size());
|
37
|
+
for (size_t i = 0; i < a.size(); i++) {
|
38
|
+
vec[i] = from_ruby<Tensor>(a[i]);
|
39
|
+
}
|
40
|
+
return vec;
|
78
41
|
}
|
79
42
|
|
80
43
|
class FanModeType {
|
@@ -174,8 +137,6 @@ MyReduction from_ruby<MyReduction>(Object x)
|
|
174
137
|
return MyReduction(x);
|
175
138
|
}
|
176
139
|
|
177
|
-
typedef torch::Tensor Tensor;
|
178
|
-
|
179
140
|
class OptionalTensor {
|
180
141
|
Object value;
|
181
142
|
public:
|
@@ -190,6 +151,17 @@ class OptionalTensor {
|
|
190
151
|
}
|
191
152
|
};
|
192
153
|
|
154
|
+
template<>
|
155
|
+
inline
|
156
|
+
Scalar from_ruby<Scalar>(Object x)
|
157
|
+
{
|
158
|
+
if (x.rb_type() == T_FIXNUM) {
|
159
|
+
return torch::Scalar(from_ruby<int64_t>(x));
|
160
|
+
} else {
|
161
|
+
return torch::Scalar(from_ruby<double>(x));
|
162
|
+
}
|
163
|
+
}
|
164
|
+
|
193
165
|
template<>
|
194
166
|
inline
|
195
167
|
OptionalTensor from_ruby<OptionalTensor>(Object x)
|
@@ -197,50 +169,72 @@ OptionalTensor from_ruby<OptionalTensor>(Object x)
|
|
197
169
|
return OptionalTensor(x);
|
198
170
|
}
|
199
171
|
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
}
|
172
|
+
template<>
|
173
|
+
inline
|
174
|
+
torch::optional<torch::ScalarType> from_ruby<torch::optional<torch::ScalarType>>(Object x)
|
175
|
+
{
|
176
|
+
if (x.is_nil()) {
|
177
|
+
return torch::nullopt;
|
178
|
+
} else {
|
179
|
+
return torch::optional<torch::ScalarType>{from_ruby<torch::ScalarType>(x)};
|
180
|
+
}
|
181
|
+
}
|
210
182
|
|
211
183
|
template<>
|
212
184
|
inline
|
213
|
-
|
185
|
+
torch::optional<int64_t> from_ruby<torch::optional<int64_t>>(Object x)
|
214
186
|
{
|
215
|
-
|
187
|
+
if (x.is_nil()) {
|
188
|
+
return torch::nullopt;
|
189
|
+
} else {
|
190
|
+
return torch::optional<int64_t>{from_ruby<int64_t>(x)};
|
191
|
+
}
|
216
192
|
}
|
217
193
|
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
return ScalarType(value);
|
229
|
-
}
|
230
|
-
};
|
194
|
+
template<>
|
195
|
+
inline
|
196
|
+
torch::optional<double> from_ruby<torch::optional<double>>(Object x)
|
197
|
+
{
|
198
|
+
if (x.is_nil()) {
|
199
|
+
return torch::nullopt;
|
200
|
+
} else {
|
201
|
+
return torch::optional<double>{from_ruby<double>(x)};
|
202
|
+
}
|
203
|
+
}
|
231
204
|
|
232
205
|
template<>
|
233
206
|
inline
|
234
|
-
|
207
|
+
torch::optional<bool> from_ruby<torch::optional<bool>>(Object x)
|
235
208
|
{
|
236
|
-
|
209
|
+
if (x.is_nil()) {
|
210
|
+
return torch::nullopt;
|
211
|
+
} else {
|
212
|
+
return torch::optional<bool>{from_ruby<bool>(x)};
|
213
|
+
}
|
237
214
|
}
|
238
215
|
|
239
|
-
|
216
|
+
template<>
|
217
|
+
inline
|
218
|
+
torch::optional<Scalar> from_ruby<torch::optional<Scalar>>(Object x)
|
219
|
+
{
|
220
|
+
if (x.is_nil()) {
|
221
|
+
return torch::nullopt;
|
222
|
+
} else {
|
223
|
+
return torch::optional<Scalar>{from_ruby<Scalar>(x)};
|
224
|
+
}
|
225
|
+
}
|
240
226
|
|
227
|
+
Object wrap(bool x);
|
228
|
+
Object wrap(int64_t x);
|
229
|
+
Object wrap(double x);
|
230
|
+
Object wrap(torch::Tensor x);
|
231
|
+
Object wrap(torch::Scalar x);
|
232
|
+
Object wrap(torch::ScalarType x);
|
233
|
+
Object wrap(torch::QScheme x);
|
241
234
|
Object wrap(std::tuple<torch::Tensor, torch::Tensor> x);
|
242
235
|
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x);
|
243
236
|
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x);
|
244
237
|
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x);
|
245
238
|
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x);
|
246
239
|
Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x);
|
240
|
+
Object wrap(std::vector<torch::Tensor> x);
|