torch-rb 0.6.0 → 0.8.2
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +21 -0
- data/README.md +23 -41
- data/codegen/function.rb +2 -0
- data/codegen/generate_functions.rb +43 -6
- data/codegen/native_functions.yaml +2007 -1327
- data/ext/torch/backends.cpp +17 -0
- data/ext/torch/cuda.cpp +5 -5
- data/ext/torch/device.cpp +13 -6
- data/ext/torch/ext.cpp +22 -5
- data/ext/torch/extconf.rb +1 -3
- data/ext/torch/fft.cpp +13 -0
- data/ext/torch/fft_functions.h +6 -0
- data/ext/torch/ivalue.cpp +31 -33
- data/ext/torch/linalg.cpp +13 -0
- data/ext/torch/linalg_functions.h +6 -0
- data/ext/torch/nn.cpp +34 -34
- data/ext/torch/random.cpp +5 -5
- data/ext/torch/ruby_arg_parser.cpp +2 -2
- data/ext/torch/ruby_arg_parser.h +23 -12
- data/ext/torch/special.cpp +13 -0
- data/ext/torch/special_functions.h +6 -0
- data/ext/torch/templates.h +111 -133
- data/ext/torch/tensor.cpp +80 -67
- data/ext/torch/torch.cpp +30 -21
- data/ext/torch/utils.h +3 -4
- data/ext/torch/wrap_outputs.h +72 -65
- data/lib/torch/inspector.rb +5 -2
- data/lib/torch/nn/convnd.rb +2 -0
- data/lib/torch/nn/functional_attention.rb +241 -0
- data/lib/torch/nn/module.rb +2 -0
- data/lib/torch/nn/module_list.rb +49 -0
- data/lib/torch/nn/multihead_attention.rb +123 -0
- data/lib/torch/nn/transformer.rb +92 -0
- data/lib/torch/nn/transformer_decoder.rb +25 -0
- data/lib/torch/nn/transformer_decoder_layer.rb +43 -0
- data/lib/torch/nn/transformer_encoder.rb +25 -0
- data/lib/torch/nn/transformer_encoder_layer.rb +36 -0
- data/lib/torch/nn/utils.rb +16 -0
- data/lib/torch/tensor.rb +2 -0
- data/lib/torch/utils/data/data_loader.rb +2 -0
- data/lib/torch/version.rb +1 -1
- data/lib/torch.rb +11 -0
- metadata +20 -5
@@ -0,0 +1,17 @@
|
|
1
|
+
#include <torch/torch.h>
|
2
|
+
|
3
|
+
#include <rice/rice.hpp>
|
4
|
+
|
5
|
+
#include "utils.h"
|
6
|
+
|
7
|
+
void init_backends(Rice::Module& m) {
|
8
|
+
auto rb_mBackends = Rice::define_module_under(m, "Backends");
|
9
|
+
|
10
|
+
Rice::define_module_under(rb_mBackends, "OpenMP")
|
11
|
+
.add_handler<torch::Error>(handle_error)
|
12
|
+
.define_singleton_function("available?", &torch::hasOpenMP);
|
13
|
+
|
14
|
+
Rice::define_module_under(rb_mBackends, "MKL")
|
15
|
+
.add_handler<torch::Error>(handle_error)
|
16
|
+
.define_singleton_function("available?", &torch::hasMKL);
|
17
|
+
}
|
data/ext/torch/cuda.cpp
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
#include <torch/torch.h>
|
2
2
|
|
3
|
-
#include <rice/
|
3
|
+
#include <rice/rice.hpp>
|
4
4
|
|
5
5
|
#include "utils.h"
|
6
6
|
|
7
7
|
void init_cuda(Rice::Module& m) {
|
8
8
|
Rice::define_module_under(m, "CUDA")
|
9
9
|
.add_handler<torch::Error>(handle_error)
|
10
|
-
.
|
11
|
-
.
|
12
|
-
.
|
13
|
-
.
|
10
|
+
.define_singleton_function("available?", &torch::cuda::is_available)
|
11
|
+
.define_singleton_function("device_count", &torch::cuda::device_count)
|
12
|
+
.define_singleton_function("manual_seed", &torch::cuda::manual_seed)
|
13
|
+
.define_singleton_function("manual_seed_all", &torch::cuda::manual_seed_all);
|
14
14
|
}
|
data/ext/torch/device.cpp
CHANGED
@@ -1,19 +1,26 @@
|
|
1
1
|
#include <torch/torch.h>
|
2
2
|
|
3
|
-
#include <rice/
|
4
|
-
#include <rice/Module.hpp>
|
3
|
+
#include <rice/rice.hpp>
|
5
4
|
|
6
5
|
#include "utils.h"
|
7
6
|
|
8
7
|
void init_device(Rice::Module& m) {
|
9
8
|
Rice::define_class_under<torch::Device>(m, "Device")
|
10
9
|
.add_handler<torch::Error>(handle_error)
|
11
|
-
.define_constructor(Rice::Constructor<torch::Device, std::string
|
12
|
-
.define_method(
|
13
|
-
|
10
|
+
.define_constructor(Rice::Constructor<torch::Device, const std::string&>())
|
11
|
+
.define_method(
|
12
|
+
"index",
|
13
|
+
[](torch::Device& self) {
|
14
|
+
return self.index();
|
15
|
+
})
|
16
|
+
.define_method(
|
17
|
+
"index?",
|
18
|
+
[](torch::Device& self) {
|
19
|
+
return self.has_index();
|
20
|
+
})
|
14
21
|
.define_method(
|
15
22
|
"type",
|
16
|
-
|
23
|
+
[](torch::Device& self) {
|
17
24
|
std::stringstream s;
|
18
25
|
s << self.type();
|
19
26
|
return s.str();
|
data/ext/torch/ext.cpp
CHANGED
@@ -1,12 +1,18 @@
|
|
1
|
-
#include <
|
1
|
+
#include <torch/torch.h>
|
2
2
|
|
3
|
+
#include <rice/rice.hpp>
|
4
|
+
|
5
|
+
void init_fft(Rice::Module& m);
|
6
|
+
void init_linalg(Rice::Module& m);
|
3
7
|
void init_nn(Rice::Module& m);
|
4
|
-
void
|
8
|
+
void init_special(Rice::Module& m);
|
9
|
+
void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions);
|
5
10
|
void init_torch(Rice::Module& m);
|
6
11
|
|
12
|
+
void init_backends(Rice::Module& m);
|
7
13
|
void init_cuda(Rice::Module& m);
|
8
14
|
void init_device(Rice::Module& m);
|
9
|
-
void init_ivalue(Rice::Module& m);
|
15
|
+
void init_ivalue(Rice::Module& m, Rice::Class& rb_cIValue);
|
10
16
|
void init_random(Rice::Module& m);
|
11
17
|
|
12
18
|
extern "C"
|
@@ -14,13 +20,24 @@ void Init_ext()
|
|
14
20
|
{
|
15
21
|
auto m = Rice::define_module("Torch");
|
16
22
|
|
23
|
+
// need to define certain classes up front to keep Rice happy
|
24
|
+
auto rb_cIValue = Rice::define_class_under<torch::IValue>(m, "IValue")
|
25
|
+
.define_constructor(Rice::Constructor<torch::IValue>());
|
26
|
+
auto rb_cTensor = Rice::define_class_under<torch::Tensor>(m, "Tensor");
|
27
|
+
auto rb_cTensorOptions = Rice::define_class_under<torch::TensorOptions>(m, "TensorOptions")
|
28
|
+
.define_constructor(Rice::Constructor<torch::TensorOptions>());
|
29
|
+
|
17
30
|
// keep this order
|
18
31
|
init_torch(m);
|
19
|
-
init_tensor(m);
|
32
|
+
init_tensor(m, rb_cTensor, rb_cTensorOptions);
|
20
33
|
init_nn(m);
|
34
|
+
init_fft(m);
|
35
|
+
init_linalg(m);
|
36
|
+
init_special(m);
|
21
37
|
|
38
|
+
init_backends(m);
|
22
39
|
init_cuda(m);
|
23
40
|
init_device(m);
|
24
|
-
init_ivalue(m);
|
41
|
+
init_ivalue(m, rb_cIValue);
|
25
42
|
init_random(m);
|
26
43
|
}
|
data/ext/torch/extconf.rb
CHANGED
data/ext/torch/fft.cpp
ADDED
@@ -0,0 +1,13 @@
|
|
1
|
+
#include <torch/torch.h>
|
2
|
+
|
3
|
+
#include <rice/rice.hpp>
|
4
|
+
|
5
|
+
#include "fft_functions.h"
|
6
|
+
#include "templates.h"
|
7
|
+
#include "utils.h"
|
8
|
+
|
9
|
+
void init_fft(Rice::Module& m) {
|
10
|
+
auto rb_mFFT = Rice::define_module_under(m, "FFT");
|
11
|
+
rb_mFFT.add_handler<torch::Error>(handle_error);
|
12
|
+
add_fft_functions(rb_mFFT);
|
13
|
+
}
|
data/ext/torch/ivalue.cpp
CHANGED
@@ -1,18 +1,13 @@
|
|
1
1
|
#include <torch/torch.h>
|
2
2
|
|
3
|
-
#include <rice/
|
4
|
-
#include <rice/Constructor.hpp>
|
5
|
-
#include <rice/Hash.hpp>
|
6
|
-
#include <rice/Module.hpp>
|
7
|
-
#include <rice/String.hpp>
|
3
|
+
#include <rice/rice.hpp>
|
8
4
|
|
9
5
|
#include "utils.h"
|
10
6
|
|
11
|
-
void init_ivalue(Rice::Module& m) {
|
7
|
+
void init_ivalue(Rice::Module& m, Rice::Class& rb_cIValue) {
|
12
8
|
// https://pytorch.org/cppdocs/api/structc10_1_1_i_value.html
|
13
|
-
|
9
|
+
rb_cIValue
|
14
10
|
.add_handler<torch::Error>(handle_error)
|
15
|
-
.define_constructor(Rice::Constructor<torch::IValue>())
|
16
11
|
.define_method("bool?", &torch::IValue::isBool)
|
17
12
|
.define_method("bool_list?", &torch::IValue::isBoolList)
|
18
13
|
.define_method("capsule?", &torch::IValue::isCapsule)
|
@@ -39,95 +34,98 @@ void init_ivalue(Rice::Module& m) {
|
|
39
34
|
.define_method("tuple?", &torch::IValue::isTuple)
|
40
35
|
.define_method(
|
41
36
|
"to_bool",
|
42
|
-
|
37
|
+
[](torch::IValue& self) {
|
43
38
|
return self.toBool();
|
44
39
|
})
|
45
40
|
.define_method(
|
46
41
|
"to_double",
|
47
|
-
|
42
|
+
[](torch::IValue& self) {
|
48
43
|
return self.toDouble();
|
49
44
|
})
|
50
45
|
.define_method(
|
51
46
|
"to_int",
|
52
|
-
|
47
|
+
[](torch::IValue& self) {
|
53
48
|
return self.toInt();
|
54
49
|
})
|
55
50
|
.define_method(
|
56
51
|
"to_list",
|
57
|
-
|
52
|
+
[](torch::IValue& self) {
|
58
53
|
auto list = self.toListRef();
|
59
54
|
Rice::Array obj;
|
60
55
|
for (auto& elem : list) {
|
61
|
-
|
56
|
+
auto v = torch::IValue{elem};
|
57
|
+
obj.push(Rice::Object(Rice::detail::To_Ruby<torch::IValue>().convert(v)));
|
62
58
|
}
|
63
59
|
return obj;
|
64
60
|
})
|
65
61
|
.define_method(
|
66
62
|
"to_string_ref",
|
67
|
-
|
63
|
+
[](torch::IValue& self) {
|
68
64
|
return self.toStringRef();
|
69
65
|
})
|
70
66
|
.define_method(
|
71
67
|
"to_tensor",
|
72
|
-
|
68
|
+
[](torch::IValue& self) {
|
73
69
|
return self.toTensor();
|
74
70
|
})
|
75
71
|
.define_method(
|
76
72
|
"to_generic_dict",
|
77
|
-
|
73
|
+
[](torch::IValue& self) {
|
78
74
|
auto dict = self.toGenericDict();
|
79
75
|
Rice::Hash obj;
|
80
76
|
for (auto& pair : dict) {
|
81
|
-
|
77
|
+
auto k = torch::IValue{pair.key()};
|
78
|
+
auto v = torch::IValue{pair.value()};
|
79
|
+
obj[Rice::Object(Rice::detail::To_Ruby<torch::IValue>().convert(k))] = Rice::Object(Rice::detail::To_Ruby<torch::IValue>().convert(v));
|
82
80
|
}
|
83
81
|
return obj;
|
84
82
|
})
|
85
|
-
.
|
83
|
+
.define_singleton_function(
|
86
84
|
"from_tensor",
|
87
|
-
|
85
|
+
[](torch::Tensor& v) {
|
88
86
|
return torch::IValue(v);
|
89
87
|
})
|
90
88
|
// TODO create specialized list types?
|
91
|
-
.
|
89
|
+
.define_singleton_function(
|
92
90
|
"from_list",
|
93
|
-
|
91
|
+
[](Rice::Array obj) {
|
94
92
|
c10::impl::GenericList list(c10::AnyType::get());
|
95
93
|
for (auto entry : obj) {
|
96
|
-
list.push_back(
|
94
|
+
list.push_back(Rice::detail::From_Ruby<torch::IValue>().convert(entry.value()));
|
97
95
|
}
|
98
96
|
return torch::IValue(list);
|
99
97
|
})
|
100
|
-
.
|
98
|
+
.define_singleton_function(
|
101
99
|
"from_string",
|
102
|
-
|
100
|
+
[](Rice::String v) {
|
103
101
|
return torch::IValue(v.str());
|
104
102
|
})
|
105
|
-
.
|
103
|
+
.define_singleton_function(
|
106
104
|
"from_int",
|
107
|
-
|
105
|
+
[](int64_t v) {
|
108
106
|
return torch::IValue(v);
|
109
107
|
})
|
110
|
-
.
|
108
|
+
.define_singleton_function(
|
111
109
|
"from_double",
|
112
|
-
|
110
|
+
[](double v) {
|
113
111
|
return torch::IValue(v);
|
114
112
|
})
|
115
|
-
.
|
113
|
+
.define_singleton_function(
|
116
114
|
"from_bool",
|
117
|
-
|
115
|
+
[](bool v) {
|
118
116
|
return torch::IValue(v);
|
119
117
|
})
|
120
118
|
// see https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/python/pybind_utils.h
|
121
119
|
// createGenericDict and toIValue
|
122
|
-
.
|
120
|
+
.define_singleton_function(
|
123
121
|
"from_dict",
|
124
|
-
|
122
|
+
[](Rice::Hash obj) {
|
125
123
|
auto key_type = c10::AnyType::get();
|
126
124
|
auto value_type = c10::AnyType::get();
|
127
125
|
c10::impl::GenericDict elems(key_type, value_type);
|
128
126
|
elems.reserve(obj.size());
|
129
127
|
for (auto entry : obj) {
|
130
|
-
elems.insert(
|
128
|
+
elems.insert(Rice::detail::From_Ruby<torch::IValue>().convert(entry.first), Rice::detail::From_Ruby<torch::IValue>().convert((Rice::Object) entry.second));
|
131
129
|
}
|
132
130
|
return torch::IValue(std::move(elems));
|
133
131
|
});
|
@@ -0,0 +1,13 @@
|
|
1
|
+
#include <torch/torch.h>
|
2
|
+
|
3
|
+
#include <rice/rice.hpp>
|
4
|
+
|
5
|
+
#include "linalg_functions.h"
|
6
|
+
#include "templates.h"
|
7
|
+
#include "utils.h"
|
8
|
+
|
9
|
+
void init_linalg(Rice::Module& m) {
|
10
|
+
auto rb_mLinalg = Rice::define_module_under(m, "Linalg");
|
11
|
+
rb_mLinalg.add_handler<torch::Error>(handle_error);
|
12
|
+
add_linalg_functions(rb_mLinalg);
|
13
|
+
}
|
data/ext/torch/nn.cpp
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
#include <torch/torch.h>
|
2
2
|
|
3
|
-
#include <rice/
|
3
|
+
#include <rice/rice.hpp>
|
4
4
|
|
5
5
|
#include "nn_functions.h"
|
6
6
|
#include "templates.h"
|
@@ -19,74 +19,74 @@ void init_nn(Rice::Module& m) {
|
|
19
19
|
|
20
20
|
Rice::define_module_under(rb_mNN, "Init")
|
21
21
|
.add_handler<torch::Error>(handle_error)
|
22
|
-
.
|
22
|
+
.define_singleton_function(
|
23
23
|
"_calculate_gain",
|
24
|
-
|
24
|
+
[](NonlinearityType nonlinearity, double param) {
|
25
25
|
return torch::nn::init::calculate_gain(nonlinearity, param);
|
26
26
|
})
|
27
|
-
.
|
27
|
+
.define_singleton_function(
|
28
28
|
"_uniform!",
|
29
|
-
|
29
|
+
[](Tensor tensor, double low, double high) {
|
30
30
|
return torch::nn::init::uniform_(tensor, low, high);
|
31
31
|
})
|
32
|
-
.
|
32
|
+
.define_singleton_function(
|
33
33
|
"_normal!",
|
34
|
-
|
34
|
+
[](Tensor tensor, double mean, double std) {
|
35
35
|
return torch::nn::init::normal_(tensor, mean, std);
|
36
36
|
})
|
37
|
-
.
|
37
|
+
.define_singleton_function(
|
38
38
|
"_constant!",
|
39
|
-
|
39
|
+
[](Tensor tensor, Scalar value) {
|
40
40
|
return torch::nn::init::constant_(tensor, value);
|
41
41
|
})
|
42
|
-
.
|
42
|
+
.define_singleton_function(
|
43
43
|
"_ones!",
|
44
|
-
|
44
|
+
[](Tensor tensor) {
|
45
45
|
return torch::nn::init::ones_(tensor);
|
46
46
|
})
|
47
|
-
.
|
47
|
+
.define_singleton_function(
|
48
48
|
"_zeros!",
|
49
|
-
|
49
|
+
[](Tensor tensor) {
|
50
50
|
return torch::nn::init::zeros_(tensor);
|
51
51
|
})
|
52
|
-
.
|
52
|
+
.define_singleton_function(
|
53
53
|
"_eye!",
|
54
|
-
|
54
|
+
[](Tensor tensor) {
|
55
55
|
return torch::nn::init::eye_(tensor);
|
56
56
|
})
|
57
|
-
.
|
57
|
+
.define_singleton_function(
|
58
58
|
"_dirac!",
|
59
|
-
|
59
|
+
[](Tensor tensor) {
|
60
60
|
return torch::nn::init::dirac_(tensor);
|
61
61
|
})
|
62
|
-
.
|
62
|
+
.define_singleton_function(
|
63
63
|
"_xavier_uniform!",
|
64
|
-
|
64
|
+
[](Tensor tensor, double gain) {
|
65
65
|
return torch::nn::init::xavier_uniform_(tensor, gain);
|
66
66
|
})
|
67
|
-
.
|
67
|
+
.define_singleton_function(
|
68
68
|
"_xavier_normal!",
|
69
|
-
|
69
|
+
[](Tensor tensor, double gain) {
|
70
70
|
return torch::nn::init::xavier_normal_(tensor, gain);
|
71
71
|
})
|
72
|
-
.
|
72
|
+
.define_singleton_function(
|
73
73
|
"_kaiming_uniform!",
|
74
|
-
|
74
|
+
[](Tensor tensor, double a, FanModeType mode, NonlinearityType nonlinearity) {
|
75
75
|
return torch::nn::init::kaiming_uniform_(tensor, a, mode, nonlinearity);
|
76
76
|
})
|
77
|
-
.
|
77
|
+
.define_singleton_function(
|
78
78
|
"_kaiming_normal!",
|
79
|
-
|
79
|
+
[](Tensor tensor, double a, FanModeType mode, NonlinearityType nonlinearity) {
|
80
80
|
return torch::nn::init::kaiming_normal_(tensor, a, mode, nonlinearity);
|
81
81
|
})
|
82
|
-
.
|
82
|
+
.define_singleton_function(
|
83
83
|
"_orthogonal!",
|
84
|
-
|
84
|
+
[](Tensor tensor, double gain) {
|
85
85
|
return torch::nn::init::orthogonal_(tensor, gain);
|
86
86
|
})
|
87
|
-
.
|
87
|
+
.define_singleton_function(
|
88
88
|
"_sparse!",
|
89
|
-
|
89
|
+
[](Tensor tensor, double sparsity, double std) {
|
90
90
|
return torch::nn::init::sparse_(tensor, sparsity, std);
|
91
91
|
});
|
92
92
|
|
@@ -94,18 +94,18 @@ void init_nn(Rice::Module& m) {
|
|
94
94
|
.add_handler<torch::Error>(handle_error)
|
95
95
|
.define_method(
|
96
96
|
"grad",
|
97
|
-
|
97
|
+
[](Parameter& self) {
|
98
98
|
auto grad = self.grad();
|
99
|
-
return grad.defined() ?
|
99
|
+
return grad.defined() ? Object(Rice::detail::To_Ruby<torch::Tensor>().convert(grad)) : Nil;
|
100
100
|
})
|
101
101
|
.define_method(
|
102
102
|
"grad=",
|
103
|
-
|
103
|
+
[](Parameter& self, torch::Tensor& grad) {
|
104
104
|
self.mutable_grad() = grad;
|
105
105
|
})
|
106
|
-
.
|
106
|
+
.define_singleton_function(
|
107
107
|
"_make_subclass",
|
108
|
-
|
108
|
+
[](Tensor& rd, bool requires_grad) {
|
109
109
|
auto data = rd.detach();
|
110
110
|
data.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(true);
|
111
111
|
auto var = data.set_requires_grad(requires_grad);
|
data/ext/torch/random.cpp
CHANGED
@@ -1,20 +1,20 @@
|
|
1
1
|
#include <torch/torch.h>
|
2
2
|
|
3
|
-
#include <rice/
|
3
|
+
#include <rice/rice.hpp>
|
4
4
|
|
5
5
|
#include "utils.h"
|
6
6
|
|
7
7
|
void init_random(Rice::Module& m) {
|
8
8
|
Rice::define_module_under(m, "Random")
|
9
9
|
.add_handler<torch::Error>(handle_error)
|
10
|
-
.
|
10
|
+
.define_singleton_function(
|
11
11
|
"initial_seed",
|
12
|
-
|
12
|
+
[]() {
|
13
13
|
return at::detail::getDefaultCPUGenerator().current_seed();
|
14
14
|
})
|
15
|
-
.
|
15
|
+
.define_singleton_function(
|
16
16
|
"seed",
|
17
|
-
|
17
|
+
[]() {
|
18
18
|
// TODO set for CUDA when available
|
19
19
|
auto generator = at::detail::getDefaultCPUGenerator();
|
20
20
|
return generator.seed();
|
@@ -137,7 +137,7 @@ auto FunctionParameter::check(VALUE obj, int argnum) -> bool
|
|
137
137
|
return true;
|
138
138
|
}
|
139
139
|
if (THPVariable_Check(obj)) {
|
140
|
-
auto var =
|
140
|
+
auto var = Rice::detail::From_Ruby<torch::Tensor>().convert(obj);
|
141
141
|
return !var.requires_grad() && var.dim() == 0;
|
142
142
|
}
|
143
143
|
return false;
|
@@ -147,7 +147,7 @@ auto FunctionParameter::check(VALUE obj, int argnum) -> bool
|
|
147
147
|
return true;
|
148
148
|
}
|
149
149
|
if (THPVariable_Check(obj)) {
|
150
|
-
auto var =
|
150
|
+
auto var = Rice::detail::From_Ruby<torch::Tensor>().convert(obj);
|
151
151
|
return at::isIntegralType(var.scalar_type(), /*includeBool=*/false) && !var.requires_grad() && var.dim() == 0;
|
152
152
|
}
|
153
153
|
return false;
|
data/ext/torch/ruby_arg_parser.h
CHANGED
@@ -5,7 +5,7 @@
|
|
5
5
|
#include <sstream>
|
6
6
|
|
7
7
|
#include <torch/torch.h>
|
8
|
-
#include <rice/
|
8
|
+
#include <rice/rice.hpp>
|
9
9
|
|
10
10
|
#include "templates.h"
|
11
11
|
#include "utils.h"
|
@@ -78,6 +78,7 @@ struct RubyArgs {
|
|
78
78
|
inline OptionalTensor optionalTensor(int i);
|
79
79
|
inline at::Scalar scalar(int i);
|
80
80
|
// inline at::Scalar scalarWithDefault(int i, at::Scalar default_scalar);
|
81
|
+
inline std::vector<at::Scalar> scalarlist(int i);
|
81
82
|
inline std::vector<at::Tensor> tensorlist(int i);
|
82
83
|
template<int N>
|
83
84
|
inline std::array<at::Tensor, N> tensorlist_n(int i);
|
@@ -121,7 +122,7 @@ struct RubyArgs {
|
|
121
122
|
};
|
122
123
|
|
123
124
|
inline at::Tensor RubyArgs::tensor(int i) {
|
124
|
-
return
|
125
|
+
return Rice::detail::From_Ruby<torch::Tensor>().convert(args[i]);
|
125
126
|
}
|
126
127
|
|
127
128
|
inline OptionalTensor RubyArgs::optionalTensor(int i) {
|
@@ -131,12 +132,17 @@ inline OptionalTensor RubyArgs::optionalTensor(int i) {
|
|
131
132
|
|
132
133
|
inline at::Scalar RubyArgs::scalar(int i) {
|
133
134
|
if (NIL_P(args[i])) return signature.params[i].default_scalar;
|
134
|
-
return
|
135
|
+
return Rice::detail::From_Ruby<torch::Scalar>().convert(args[i]);
|
136
|
+
}
|
137
|
+
|
138
|
+
inline std::vector<at::Scalar> RubyArgs::scalarlist(int i) {
|
139
|
+
if (NIL_P(args[i])) return std::vector<at::Scalar>();
|
140
|
+
return Rice::detail::From_Ruby<std::vector<at::Scalar>>().convert(args[i]);
|
135
141
|
}
|
136
142
|
|
137
143
|
inline std::vector<at::Tensor> RubyArgs::tensorlist(int i) {
|
138
144
|
if (NIL_P(args[i])) return std::vector<at::Tensor>();
|
139
|
-
return
|
145
|
+
return Rice::detail::From_Ruby<std::vector<Tensor>>().convert(args[i]);
|
140
146
|
}
|
141
147
|
|
142
148
|
template<int N>
|
@@ -151,7 +157,7 @@ inline std::array<at::Tensor, N> RubyArgs::tensorlist_n(int i) {
|
|
151
157
|
}
|
152
158
|
for (int idx = 0; idx < size; idx++) {
|
153
159
|
VALUE obj = rb_ary_entry(arg, idx);
|
154
|
-
res[idx] =
|
160
|
+
res[idx] = Rice::detail::From_Ruby<Tensor>().convert(obj);
|
155
161
|
}
|
156
162
|
return res;
|
157
163
|
}
|
@@ -170,7 +176,7 @@ inline std::vector<int64_t> RubyArgs::intlist(int i) {
|
|
170
176
|
for (idx = 0; idx < size; idx++) {
|
171
177
|
VALUE obj = rb_ary_entry(arg, idx);
|
172
178
|
if (FIXNUM_P(obj)) {
|
173
|
-
res[idx] =
|
179
|
+
res[idx] = Rice::detail::From_Ruby<int64_t>().convert(obj);
|
174
180
|
} else {
|
175
181
|
rb_raise(rb_eArgError, "%s(): argument '%s' must be %s, but found element of type %s at pos %d",
|
176
182
|
signature.name.c_str(), signature.params[i].name.c_str(),
|
@@ -210,8 +216,13 @@ inline ScalarType RubyArgs::scalartype(int i) {
|
|
210
216
|
{ID2SYM(rb_intern("double")), ScalarType::Double},
|
211
217
|
{ID2SYM(rb_intern("float64")), ScalarType::Double},
|
212
218
|
{ID2SYM(rb_intern("complex_half")), ScalarType::ComplexHalf},
|
219
|
+
{ID2SYM(rb_intern("complex32")), ScalarType::ComplexHalf},
|
213
220
|
{ID2SYM(rb_intern("complex_float")), ScalarType::ComplexFloat},
|
221
|
+
{ID2SYM(rb_intern("cfloat")), ScalarType::ComplexFloat},
|
222
|
+
{ID2SYM(rb_intern("complex64")), ScalarType::ComplexFloat},
|
214
223
|
{ID2SYM(rb_intern("complex_double")), ScalarType::ComplexDouble},
|
224
|
+
{ID2SYM(rb_intern("cdouble")), ScalarType::ComplexDouble},
|
225
|
+
{ID2SYM(rb_intern("complex128")), ScalarType::ComplexDouble},
|
215
226
|
{ID2SYM(rb_intern("bool")), ScalarType::Bool},
|
216
227
|
{ID2SYM(rb_intern("qint8")), ScalarType::QInt8},
|
217
228
|
{ID2SYM(rb_intern("quint8")), ScalarType::QUInt8},
|
@@ -260,7 +271,7 @@ inline c10::OptionalArray<double> RubyArgs::doublelistOptional(int i) {
|
|
260
271
|
for (idx = 0; idx < size; idx++) {
|
261
272
|
VALUE obj = rb_ary_entry(arg, idx);
|
262
273
|
if (FIXNUM_P(obj) || RB_FLOAT_TYPE_P(obj)) {
|
263
|
-
res[idx] =
|
274
|
+
res[idx] = Rice::detail::From_Ruby<double>().convert(obj);
|
264
275
|
} else {
|
265
276
|
rb_raise(rb_eArgError, "%s(): argument '%s' must be %s, but found element of type %s at pos %d",
|
266
277
|
signature.name.c_str(), signature.params[i].name.c_str(),
|
@@ -303,22 +314,22 @@ inline c10::optional<at::MemoryFormat> RubyArgs::memoryformatOptional(int i) {
|
|
303
314
|
}
|
304
315
|
|
305
316
|
inline std::string RubyArgs::string(int i) {
|
306
|
-
return
|
317
|
+
return Rice::detail::From_Ruby<std::string>().convert(args[i]);
|
307
318
|
}
|
308
319
|
|
309
320
|
inline c10::optional<std::string> RubyArgs::stringOptional(int i) {
|
310
|
-
if (
|
311
|
-
return
|
321
|
+
if (NIL_P(args[i])) return c10::nullopt;
|
322
|
+
return Rice::detail::From_Ruby<std::string>().convert(args[i]);
|
312
323
|
}
|
313
324
|
|
314
325
|
inline int64_t RubyArgs::toInt64(int i) {
|
315
326
|
if (NIL_P(args[i])) return signature.params[i].default_int;
|
316
|
-
return
|
327
|
+
return Rice::detail::From_Ruby<int64_t>().convert(args[i]);
|
317
328
|
}
|
318
329
|
|
319
330
|
inline double RubyArgs::toDouble(int i) {
|
320
331
|
if (NIL_P(args[i])) return signature.params[i].default_double;
|
321
|
-
return
|
332
|
+
return Rice::detail::From_Ruby<double>().convert(args[i]);
|
322
333
|
}
|
323
334
|
|
324
335
|
inline bool RubyArgs::toBool(int i) {
|
@@ -0,0 +1,13 @@
|
|
1
|
+
#include <torch/torch.h>
|
2
|
+
|
3
|
+
#include <rice/rice.hpp>
|
4
|
+
|
5
|
+
#include "special_functions.h"
|
6
|
+
#include "templates.h"
|
7
|
+
#include "utils.h"
|
8
|
+
|
9
|
+
void init_special(Rice::Module& m) {
|
10
|
+
auto rb_mSpecial = Rice::define_module_under(m, "Special");
|
11
|
+
rb_mSpecial.add_handler<torch::Error>(handle_error);
|
12
|
+
add_special_functions(rb_mSpecial);
|
13
|
+
}
|