torch-rb 0.6.0 → 0.8.2
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +21 -0
- data/README.md +23 -41
- data/codegen/function.rb +2 -0
- data/codegen/generate_functions.rb +43 -6
- data/codegen/native_functions.yaml +2007 -1327
- data/ext/torch/backends.cpp +17 -0
- data/ext/torch/cuda.cpp +5 -5
- data/ext/torch/device.cpp +13 -6
- data/ext/torch/ext.cpp +22 -5
- data/ext/torch/extconf.rb +1 -3
- data/ext/torch/fft.cpp +13 -0
- data/ext/torch/fft_functions.h +6 -0
- data/ext/torch/ivalue.cpp +31 -33
- data/ext/torch/linalg.cpp +13 -0
- data/ext/torch/linalg_functions.h +6 -0
- data/ext/torch/nn.cpp +34 -34
- data/ext/torch/random.cpp +5 -5
- data/ext/torch/ruby_arg_parser.cpp +2 -2
- data/ext/torch/ruby_arg_parser.h +23 -12
- data/ext/torch/special.cpp +13 -0
- data/ext/torch/special_functions.h +6 -0
- data/ext/torch/templates.h +111 -133
- data/ext/torch/tensor.cpp +80 -67
- data/ext/torch/torch.cpp +30 -21
- data/ext/torch/utils.h +3 -4
- data/ext/torch/wrap_outputs.h +72 -65
- data/lib/torch/inspector.rb +5 -2
- data/lib/torch/nn/convnd.rb +2 -0
- data/lib/torch/nn/functional_attention.rb +241 -0
- data/lib/torch/nn/module.rb +2 -0
- data/lib/torch/nn/module_list.rb +49 -0
- data/lib/torch/nn/multihead_attention.rb +123 -0
- data/lib/torch/nn/transformer.rb +92 -0
- data/lib/torch/nn/transformer_decoder.rb +25 -0
- data/lib/torch/nn/transformer_decoder_layer.rb +43 -0
- data/lib/torch/nn/transformer_encoder.rb +25 -0
- data/lib/torch/nn/transformer_encoder_layer.rb +36 -0
- data/lib/torch/nn/utils.rb +16 -0
- data/lib/torch/tensor.rb +2 -0
- data/lib/torch/utils/data/data_loader.rb +2 -0
- data/lib/torch/version.rb +1 -1
- data/lib/torch.rb +11 -0
- metadata +20 -5
@@ -0,0 +1,17 @@
|
|
1
|
+
#include <torch/torch.h>
|
2
|
+
|
3
|
+
#include <rice/rice.hpp>
|
4
|
+
|
5
|
+
#include "utils.h"
|
6
|
+
|
7
|
+
void init_backends(Rice::Module& m) {
|
8
|
+
auto rb_mBackends = Rice::define_module_under(m, "Backends");
|
9
|
+
|
10
|
+
Rice::define_module_under(rb_mBackends, "OpenMP")
|
11
|
+
.add_handler<torch::Error>(handle_error)
|
12
|
+
.define_singleton_function("available?", &torch::hasOpenMP);
|
13
|
+
|
14
|
+
Rice::define_module_under(rb_mBackends, "MKL")
|
15
|
+
.add_handler<torch::Error>(handle_error)
|
16
|
+
.define_singleton_function("available?", &torch::hasMKL);
|
17
|
+
}
|
data/ext/torch/cuda.cpp
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
#include <torch/torch.h>
|
2
2
|
|
3
|
-
#include <rice/
|
3
|
+
#include <rice/rice.hpp>
|
4
4
|
|
5
5
|
#include "utils.h"
|
6
6
|
|
7
7
|
void init_cuda(Rice::Module& m) {
|
8
8
|
Rice::define_module_under(m, "CUDA")
|
9
9
|
.add_handler<torch::Error>(handle_error)
|
10
|
-
.
|
11
|
-
.
|
12
|
-
.
|
13
|
-
.
|
10
|
+
.define_singleton_function("available?", &torch::cuda::is_available)
|
11
|
+
.define_singleton_function("device_count", &torch::cuda::device_count)
|
12
|
+
.define_singleton_function("manual_seed", &torch::cuda::manual_seed)
|
13
|
+
.define_singleton_function("manual_seed_all", &torch::cuda::manual_seed_all);
|
14
14
|
}
|
data/ext/torch/device.cpp
CHANGED
@@ -1,19 +1,26 @@
|
|
1
1
|
#include <torch/torch.h>
|
2
2
|
|
3
|
-
#include <rice/
|
4
|
-
#include <rice/Module.hpp>
|
3
|
+
#include <rice/rice.hpp>
|
5
4
|
|
6
5
|
#include "utils.h"
|
7
6
|
|
8
7
|
void init_device(Rice::Module& m) {
|
9
8
|
Rice::define_class_under<torch::Device>(m, "Device")
|
10
9
|
.add_handler<torch::Error>(handle_error)
|
11
|
-
.define_constructor(Rice::Constructor<torch::Device, std::string
|
12
|
-
.define_method(
|
13
|
-
|
10
|
+
.define_constructor(Rice::Constructor<torch::Device, const std::string&>())
|
11
|
+
.define_method(
|
12
|
+
"index",
|
13
|
+
[](torch::Device& self) {
|
14
|
+
return self.index();
|
15
|
+
})
|
16
|
+
.define_method(
|
17
|
+
"index?",
|
18
|
+
[](torch::Device& self) {
|
19
|
+
return self.has_index();
|
20
|
+
})
|
14
21
|
.define_method(
|
15
22
|
"type",
|
16
|
-
|
23
|
+
[](torch::Device& self) {
|
17
24
|
std::stringstream s;
|
18
25
|
s << self.type();
|
19
26
|
return s.str();
|
data/ext/torch/ext.cpp
CHANGED
@@ -1,12 +1,18 @@
|
|
1
|
-
#include <
|
1
|
+
#include <torch/torch.h>
|
2
2
|
|
3
|
+
#include <rice/rice.hpp>
|
4
|
+
|
5
|
+
void init_fft(Rice::Module& m);
|
6
|
+
void init_linalg(Rice::Module& m);
|
3
7
|
void init_nn(Rice::Module& m);
|
4
|
-
void
|
8
|
+
void init_special(Rice::Module& m);
|
9
|
+
void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions);
|
5
10
|
void init_torch(Rice::Module& m);
|
6
11
|
|
12
|
+
void init_backends(Rice::Module& m);
|
7
13
|
void init_cuda(Rice::Module& m);
|
8
14
|
void init_device(Rice::Module& m);
|
9
|
-
void init_ivalue(Rice::Module& m);
|
15
|
+
void init_ivalue(Rice::Module& m, Rice::Class& rb_cIValue);
|
10
16
|
void init_random(Rice::Module& m);
|
11
17
|
|
12
18
|
extern "C"
|
@@ -14,13 +20,24 @@ void Init_ext()
|
|
14
20
|
{
|
15
21
|
auto m = Rice::define_module("Torch");
|
16
22
|
|
23
|
+
// need to define certain classes up front to keep Rice happy
|
24
|
+
auto rb_cIValue = Rice::define_class_under<torch::IValue>(m, "IValue")
|
25
|
+
.define_constructor(Rice::Constructor<torch::IValue>());
|
26
|
+
auto rb_cTensor = Rice::define_class_under<torch::Tensor>(m, "Tensor");
|
27
|
+
auto rb_cTensorOptions = Rice::define_class_under<torch::TensorOptions>(m, "TensorOptions")
|
28
|
+
.define_constructor(Rice::Constructor<torch::TensorOptions>());
|
29
|
+
|
17
30
|
// keep this order
|
18
31
|
init_torch(m);
|
19
|
-
init_tensor(m);
|
32
|
+
init_tensor(m, rb_cTensor, rb_cTensorOptions);
|
20
33
|
init_nn(m);
|
34
|
+
init_fft(m);
|
35
|
+
init_linalg(m);
|
36
|
+
init_special(m);
|
21
37
|
|
38
|
+
init_backends(m);
|
22
39
|
init_cuda(m);
|
23
40
|
init_device(m);
|
24
|
-
init_ivalue(m);
|
41
|
+
init_ivalue(m, rb_cIValue);
|
25
42
|
init_random(m);
|
26
43
|
}
|
data/ext/torch/extconf.rb
CHANGED
data/ext/torch/fft.cpp
ADDED
@@ -0,0 +1,13 @@
|
|
1
|
+
#include <torch/torch.h>
|
2
|
+
|
3
|
+
#include <rice/rice.hpp>
|
4
|
+
|
5
|
+
#include "fft_functions.h"
|
6
|
+
#include "templates.h"
|
7
|
+
#include "utils.h"
|
8
|
+
|
9
|
+
void init_fft(Rice::Module& m) {
|
10
|
+
auto rb_mFFT = Rice::define_module_under(m, "FFT");
|
11
|
+
rb_mFFT.add_handler<torch::Error>(handle_error);
|
12
|
+
add_fft_functions(rb_mFFT);
|
13
|
+
}
|
data/ext/torch/ivalue.cpp
CHANGED
@@ -1,18 +1,13 @@
|
|
1
1
|
#include <torch/torch.h>
|
2
2
|
|
3
|
-
#include <rice/
|
4
|
-
#include <rice/Constructor.hpp>
|
5
|
-
#include <rice/Hash.hpp>
|
6
|
-
#include <rice/Module.hpp>
|
7
|
-
#include <rice/String.hpp>
|
3
|
+
#include <rice/rice.hpp>
|
8
4
|
|
9
5
|
#include "utils.h"
|
10
6
|
|
11
|
-
void init_ivalue(Rice::Module& m) {
|
7
|
+
void init_ivalue(Rice::Module& m, Rice::Class& rb_cIValue) {
|
12
8
|
// https://pytorch.org/cppdocs/api/structc10_1_1_i_value.html
|
13
|
-
|
9
|
+
rb_cIValue
|
14
10
|
.add_handler<torch::Error>(handle_error)
|
15
|
-
.define_constructor(Rice::Constructor<torch::IValue>())
|
16
11
|
.define_method("bool?", &torch::IValue::isBool)
|
17
12
|
.define_method("bool_list?", &torch::IValue::isBoolList)
|
18
13
|
.define_method("capsule?", &torch::IValue::isCapsule)
|
@@ -39,95 +34,98 @@ void init_ivalue(Rice::Module& m) {
|
|
39
34
|
.define_method("tuple?", &torch::IValue::isTuple)
|
40
35
|
.define_method(
|
41
36
|
"to_bool",
|
42
|
-
|
37
|
+
[](torch::IValue& self) {
|
43
38
|
return self.toBool();
|
44
39
|
})
|
45
40
|
.define_method(
|
46
41
|
"to_double",
|
47
|
-
|
42
|
+
[](torch::IValue& self) {
|
48
43
|
return self.toDouble();
|
49
44
|
})
|
50
45
|
.define_method(
|
51
46
|
"to_int",
|
52
|
-
|
47
|
+
[](torch::IValue& self) {
|
53
48
|
return self.toInt();
|
54
49
|
})
|
55
50
|
.define_method(
|
56
51
|
"to_list",
|
57
|
-
|
52
|
+
[](torch::IValue& self) {
|
58
53
|
auto list = self.toListRef();
|
59
54
|
Rice::Array obj;
|
60
55
|
for (auto& elem : list) {
|
61
|
-
|
56
|
+
auto v = torch::IValue{elem};
|
57
|
+
obj.push(Rice::Object(Rice::detail::To_Ruby<torch::IValue>().convert(v)));
|
62
58
|
}
|
63
59
|
return obj;
|
64
60
|
})
|
65
61
|
.define_method(
|
66
62
|
"to_string_ref",
|
67
|
-
|
63
|
+
[](torch::IValue& self) {
|
68
64
|
return self.toStringRef();
|
69
65
|
})
|
70
66
|
.define_method(
|
71
67
|
"to_tensor",
|
72
|
-
|
68
|
+
[](torch::IValue& self) {
|
73
69
|
return self.toTensor();
|
74
70
|
})
|
75
71
|
.define_method(
|
76
72
|
"to_generic_dict",
|
77
|
-
|
73
|
+
[](torch::IValue& self) {
|
78
74
|
auto dict = self.toGenericDict();
|
79
75
|
Rice::Hash obj;
|
80
76
|
for (auto& pair : dict) {
|
81
|
-
|
77
|
+
auto k = torch::IValue{pair.key()};
|
78
|
+
auto v = torch::IValue{pair.value()};
|
79
|
+
obj[Rice::Object(Rice::detail::To_Ruby<torch::IValue>().convert(k))] = Rice::Object(Rice::detail::To_Ruby<torch::IValue>().convert(v));
|
82
80
|
}
|
83
81
|
return obj;
|
84
82
|
})
|
85
|
-
.
|
83
|
+
.define_singleton_function(
|
86
84
|
"from_tensor",
|
87
|
-
|
85
|
+
[](torch::Tensor& v) {
|
88
86
|
return torch::IValue(v);
|
89
87
|
})
|
90
88
|
// TODO create specialized list types?
|
91
|
-
.
|
89
|
+
.define_singleton_function(
|
92
90
|
"from_list",
|
93
|
-
|
91
|
+
[](Rice::Array obj) {
|
94
92
|
c10::impl::GenericList list(c10::AnyType::get());
|
95
93
|
for (auto entry : obj) {
|
96
|
-
list.push_back(
|
94
|
+
list.push_back(Rice::detail::From_Ruby<torch::IValue>().convert(entry.value()));
|
97
95
|
}
|
98
96
|
return torch::IValue(list);
|
99
97
|
})
|
100
|
-
.
|
98
|
+
.define_singleton_function(
|
101
99
|
"from_string",
|
102
|
-
|
100
|
+
[](Rice::String v) {
|
103
101
|
return torch::IValue(v.str());
|
104
102
|
})
|
105
|
-
.
|
103
|
+
.define_singleton_function(
|
106
104
|
"from_int",
|
107
|
-
|
105
|
+
[](int64_t v) {
|
108
106
|
return torch::IValue(v);
|
109
107
|
})
|
110
|
-
.
|
108
|
+
.define_singleton_function(
|
111
109
|
"from_double",
|
112
|
-
|
110
|
+
[](double v) {
|
113
111
|
return torch::IValue(v);
|
114
112
|
})
|
115
|
-
.
|
113
|
+
.define_singleton_function(
|
116
114
|
"from_bool",
|
117
|
-
|
115
|
+
[](bool v) {
|
118
116
|
return torch::IValue(v);
|
119
117
|
})
|
120
118
|
// see https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/python/pybind_utils.h
|
121
119
|
// createGenericDict and toIValue
|
122
|
-
.
|
120
|
+
.define_singleton_function(
|
123
121
|
"from_dict",
|
124
|
-
|
122
|
+
[](Rice::Hash obj) {
|
125
123
|
auto key_type = c10::AnyType::get();
|
126
124
|
auto value_type = c10::AnyType::get();
|
127
125
|
c10::impl::GenericDict elems(key_type, value_type);
|
128
126
|
elems.reserve(obj.size());
|
129
127
|
for (auto entry : obj) {
|
130
|
-
elems.insert(
|
128
|
+
elems.insert(Rice::detail::From_Ruby<torch::IValue>().convert(entry.first), Rice::detail::From_Ruby<torch::IValue>().convert((Rice::Object) entry.second));
|
131
129
|
}
|
132
130
|
return torch::IValue(std::move(elems));
|
133
131
|
});
|
@@ -0,0 +1,13 @@
|
|
1
|
+
#include <torch/torch.h>
|
2
|
+
|
3
|
+
#include <rice/rice.hpp>
|
4
|
+
|
5
|
+
#include "linalg_functions.h"
|
6
|
+
#include "templates.h"
|
7
|
+
#include "utils.h"
|
8
|
+
|
9
|
+
void init_linalg(Rice::Module& m) {
|
10
|
+
auto rb_mLinalg = Rice::define_module_under(m, "Linalg");
|
11
|
+
rb_mLinalg.add_handler<torch::Error>(handle_error);
|
12
|
+
add_linalg_functions(rb_mLinalg);
|
13
|
+
}
|
data/ext/torch/nn.cpp
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
#include <torch/torch.h>
|
2
2
|
|
3
|
-
#include <rice/
|
3
|
+
#include <rice/rice.hpp>
|
4
4
|
|
5
5
|
#include "nn_functions.h"
|
6
6
|
#include "templates.h"
|
@@ -19,74 +19,74 @@ void init_nn(Rice::Module& m) {
|
|
19
19
|
|
20
20
|
Rice::define_module_under(rb_mNN, "Init")
|
21
21
|
.add_handler<torch::Error>(handle_error)
|
22
|
-
.
|
22
|
+
.define_singleton_function(
|
23
23
|
"_calculate_gain",
|
24
|
-
|
24
|
+
[](NonlinearityType nonlinearity, double param) {
|
25
25
|
return torch::nn::init::calculate_gain(nonlinearity, param);
|
26
26
|
})
|
27
|
-
.
|
27
|
+
.define_singleton_function(
|
28
28
|
"_uniform!",
|
29
|
-
|
29
|
+
[](Tensor tensor, double low, double high) {
|
30
30
|
return torch::nn::init::uniform_(tensor, low, high);
|
31
31
|
})
|
32
|
-
.
|
32
|
+
.define_singleton_function(
|
33
33
|
"_normal!",
|
34
|
-
|
34
|
+
[](Tensor tensor, double mean, double std) {
|
35
35
|
return torch::nn::init::normal_(tensor, mean, std);
|
36
36
|
})
|
37
|
-
.
|
37
|
+
.define_singleton_function(
|
38
38
|
"_constant!",
|
39
|
-
|
39
|
+
[](Tensor tensor, Scalar value) {
|
40
40
|
return torch::nn::init::constant_(tensor, value);
|
41
41
|
})
|
42
|
-
.
|
42
|
+
.define_singleton_function(
|
43
43
|
"_ones!",
|
44
|
-
|
44
|
+
[](Tensor tensor) {
|
45
45
|
return torch::nn::init::ones_(tensor);
|
46
46
|
})
|
47
|
-
.
|
47
|
+
.define_singleton_function(
|
48
48
|
"_zeros!",
|
49
|
-
|
49
|
+
[](Tensor tensor) {
|
50
50
|
return torch::nn::init::zeros_(tensor);
|
51
51
|
})
|
52
|
-
.
|
52
|
+
.define_singleton_function(
|
53
53
|
"_eye!",
|
54
|
-
|
54
|
+
[](Tensor tensor) {
|
55
55
|
return torch::nn::init::eye_(tensor);
|
56
56
|
})
|
57
|
-
.
|
57
|
+
.define_singleton_function(
|
58
58
|
"_dirac!",
|
59
|
-
|
59
|
+
[](Tensor tensor) {
|
60
60
|
return torch::nn::init::dirac_(tensor);
|
61
61
|
})
|
62
|
-
.
|
62
|
+
.define_singleton_function(
|
63
63
|
"_xavier_uniform!",
|
64
|
-
|
64
|
+
[](Tensor tensor, double gain) {
|
65
65
|
return torch::nn::init::xavier_uniform_(tensor, gain);
|
66
66
|
})
|
67
|
-
.
|
67
|
+
.define_singleton_function(
|
68
68
|
"_xavier_normal!",
|
69
|
-
|
69
|
+
[](Tensor tensor, double gain) {
|
70
70
|
return torch::nn::init::xavier_normal_(tensor, gain);
|
71
71
|
})
|
72
|
-
.
|
72
|
+
.define_singleton_function(
|
73
73
|
"_kaiming_uniform!",
|
74
|
-
|
74
|
+
[](Tensor tensor, double a, FanModeType mode, NonlinearityType nonlinearity) {
|
75
75
|
return torch::nn::init::kaiming_uniform_(tensor, a, mode, nonlinearity);
|
76
76
|
})
|
77
|
-
.
|
77
|
+
.define_singleton_function(
|
78
78
|
"_kaiming_normal!",
|
79
|
-
|
79
|
+
[](Tensor tensor, double a, FanModeType mode, NonlinearityType nonlinearity) {
|
80
80
|
return torch::nn::init::kaiming_normal_(tensor, a, mode, nonlinearity);
|
81
81
|
})
|
82
|
-
.
|
82
|
+
.define_singleton_function(
|
83
83
|
"_orthogonal!",
|
84
|
-
|
84
|
+
[](Tensor tensor, double gain) {
|
85
85
|
return torch::nn::init::orthogonal_(tensor, gain);
|
86
86
|
})
|
87
|
-
.
|
87
|
+
.define_singleton_function(
|
88
88
|
"_sparse!",
|
89
|
-
|
89
|
+
[](Tensor tensor, double sparsity, double std) {
|
90
90
|
return torch::nn::init::sparse_(tensor, sparsity, std);
|
91
91
|
});
|
92
92
|
|
@@ -94,18 +94,18 @@ void init_nn(Rice::Module& m) {
|
|
94
94
|
.add_handler<torch::Error>(handle_error)
|
95
95
|
.define_method(
|
96
96
|
"grad",
|
97
|
-
|
97
|
+
[](Parameter& self) {
|
98
98
|
auto grad = self.grad();
|
99
|
-
return grad.defined() ?
|
99
|
+
return grad.defined() ? Object(Rice::detail::To_Ruby<torch::Tensor>().convert(grad)) : Nil;
|
100
100
|
})
|
101
101
|
.define_method(
|
102
102
|
"grad=",
|
103
|
-
|
103
|
+
[](Parameter& self, torch::Tensor& grad) {
|
104
104
|
self.mutable_grad() = grad;
|
105
105
|
})
|
106
|
-
.
|
106
|
+
.define_singleton_function(
|
107
107
|
"_make_subclass",
|
108
|
-
|
108
|
+
[](Tensor& rd, bool requires_grad) {
|
109
109
|
auto data = rd.detach();
|
110
110
|
data.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(true);
|
111
111
|
auto var = data.set_requires_grad(requires_grad);
|
data/ext/torch/random.cpp
CHANGED
@@ -1,20 +1,20 @@
|
|
1
1
|
#include <torch/torch.h>
|
2
2
|
|
3
|
-
#include <rice/
|
3
|
+
#include <rice/rice.hpp>
|
4
4
|
|
5
5
|
#include "utils.h"
|
6
6
|
|
7
7
|
void init_random(Rice::Module& m) {
|
8
8
|
Rice::define_module_under(m, "Random")
|
9
9
|
.add_handler<torch::Error>(handle_error)
|
10
|
-
.
|
10
|
+
.define_singleton_function(
|
11
11
|
"initial_seed",
|
12
|
-
|
12
|
+
[]() {
|
13
13
|
return at::detail::getDefaultCPUGenerator().current_seed();
|
14
14
|
})
|
15
|
-
.
|
15
|
+
.define_singleton_function(
|
16
16
|
"seed",
|
17
|
-
|
17
|
+
[]() {
|
18
18
|
// TODO set for CUDA when available
|
19
19
|
auto generator = at::detail::getDefaultCPUGenerator();
|
20
20
|
return generator.seed();
|
@@ -137,7 +137,7 @@ auto FunctionParameter::check(VALUE obj, int argnum) -> bool
|
|
137
137
|
return true;
|
138
138
|
}
|
139
139
|
if (THPVariable_Check(obj)) {
|
140
|
-
auto var =
|
140
|
+
auto var = Rice::detail::From_Ruby<torch::Tensor>().convert(obj);
|
141
141
|
return !var.requires_grad() && var.dim() == 0;
|
142
142
|
}
|
143
143
|
return false;
|
@@ -147,7 +147,7 @@ auto FunctionParameter::check(VALUE obj, int argnum) -> bool
|
|
147
147
|
return true;
|
148
148
|
}
|
149
149
|
if (THPVariable_Check(obj)) {
|
150
|
-
auto var =
|
150
|
+
auto var = Rice::detail::From_Ruby<torch::Tensor>().convert(obj);
|
151
151
|
return at::isIntegralType(var.scalar_type(), /*includeBool=*/false) && !var.requires_grad() && var.dim() == 0;
|
152
152
|
}
|
153
153
|
return false;
|
data/ext/torch/ruby_arg_parser.h
CHANGED
@@ -5,7 +5,7 @@
|
|
5
5
|
#include <sstream>
|
6
6
|
|
7
7
|
#include <torch/torch.h>
|
8
|
-
#include <rice/
|
8
|
+
#include <rice/rice.hpp>
|
9
9
|
|
10
10
|
#include "templates.h"
|
11
11
|
#include "utils.h"
|
@@ -78,6 +78,7 @@ struct RubyArgs {
|
|
78
78
|
inline OptionalTensor optionalTensor(int i);
|
79
79
|
inline at::Scalar scalar(int i);
|
80
80
|
// inline at::Scalar scalarWithDefault(int i, at::Scalar default_scalar);
|
81
|
+
inline std::vector<at::Scalar> scalarlist(int i);
|
81
82
|
inline std::vector<at::Tensor> tensorlist(int i);
|
82
83
|
template<int N>
|
83
84
|
inline std::array<at::Tensor, N> tensorlist_n(int i);
|
@@ -121,7 +122,7 @@ struct RubyArgs {
|
|
121
122
|
};
|
122
123
|
|
123
124
|
inline at::Tensor RubyArgs::tensor(int i) {
|
124
|
-
return
|
125
|
+
return Rice::detail::From_Ruby<torch::Tensor>().convert(args[i]);
|
125
126
|
}
|
126
127
|
|
127
128
|
inline OptionalTensor RubyArgs::optionalTensor(int i) {
|
@@ -131,12 +132,17 @@ inline OptionalTensor RubyArgs::optionalTensor(int i) {
|
|
131
132
|
|
132
133
|
inline at::Scalar RubyArgs::scalar(int i) {
|
133
134
|
if (NIL_P(args[i])) return signature.params[i].default_scalar;
|
134
|
-
return
|
135
|
+
return Rice::detail::From_Ruby<torch::Scalar>().convert(args[i]);
|
136
|
+
}
|
137
|
+
|
138
|
+
inline std::vector<at::Scalar> RubyArgs::scalarlist(int i) {
|
139
|
+
if (NIL_P(args[i])) return std::vector<at::Scalar>();
|
140
|
+
return Rice::detail::From_Ruby<std::vector<at::Scalar>>().convert(args[i]);
|
135
141
|
}
|
136
142
|
|
137
143
|
inline std::vector<at::Tensor> RubyArgs::tensorlist(int i) {
|
138
144
|
if (NIL_P(args[i])) return std::vector<at::Tensor>();
|
139
|
-
return
|
145
|
+
return Rice::detail::From_Ruby<std::vector<Tensor>>().convert(args[i]);
|
140
146
|
}
|
141
147
|
|
142
148
|
template<int N>
|
@@ -151,7 +157,7 @@ inline std::array<at::Tensor, N> RubyArgs::tensorlist_n(int i) {
|
|
151
157
|
}
|
152
158
|
for (int idx = 0; idx < size; idx++) {
|
153
159
|
VALUE obj = rb_ary_entry(arg, idx);
|
154
|
-
res[idx] =
|
160
|
+
res[idx] = Rice::detail::From_Ruby<Tensor>().convert(obj);
|
155
161
|
}
|
156
162
|
return res;
|
157
163
|
}
|
@@ -170,7 +176,7 @@ inline std::vector<int64_t> RubyArgs::intlist(int i) {
|
|
170
176
|
for (idx = 0; idx < size; idx++) {
|
171
177
|
VALUE obj = rb_ary_entry(arg, idx);
|
172
178
|
if (FIXNUM_P(obj)) {
|
173
|
-
res[idx] =
|
179
|
+
res[idx] = Rice::detail::From_Ruby<int64_t>().convert(obj);
|
174
180
|
} else {
|
175
181
|
rb_raise(rb_eArgError, "%s(): argument '%s' must be %s, but found element of type %s at pos %d",
|
176
182
|
signature.name.c_str(), signature.params[i].name.c_str(),
|
@@ -210,8 +216,13 @@ inline ScalarType RubyArgs::scalartype(int i) {
|
|
210
216
|
{ID2SYM(rb_intern("double")), ScalarType::Double},
|
211
217
|
{ID2SYM(rb_intern("float64")), ScalarType::Double},
|
212
218
|
{ID2SYM(rb_intern("complex_half")), ScalarType::ComplexHalf},
|
219
|
+
{ID2SYM(rb_intern("complex32")), ScalarType::ComplexHalf},
|
213
220
|
{ID2SYM(rb_intern("complex_float")), ScalarType::ComplexFloat},
|
221
|
+
{ID2SYM(rb_intern("cfloat")), ScalarType::ComplexFloat},
|
222
|
+
{ID2SYM(rb_intern("complex64")), ScalarType::ComplexFloat},
|
214
223
|
{ID2SYM(rb_intern("complex_double")), ScalarType::ComplexDouble},
|
224
|
+
{ID2SYM(rb_intern("cdouble")), ScalarType::ComplexDouble},
|
225
|
+
{ID2SYM(rb_intern("complex128")), ScalarType::ComplexDouble},
|
215
226
|
{ID2SYM(rb_intern("bool")), ScalarType::Bool},
|
216
227
|
{ID2SYM(rb_intern("qint8")), ScalarType::QInt8},
|
217
228
|
{ID2SYM(rb_intern("quint8")), ScalarType::QUInt8},
|
@@ -260,7 +271,7 @@ inline c10::OptionalArray<double> RubyArgs::doublelistOptional(int i) {
|
|
260
271
|
for (idx = 0; idx < size; idx++) {
|
261
272
|
VALUE obj = rb_ary_entry(arg, idx);
|
262
273
|
if (FIXNUM_P(obj) || RB_FLOAT_TYPE_P(obj)) {
|
263
|
-
res[idx] =
|
274
|
+
res[idx] = Rice::detail::From_Ruby<double>().convert(obj);
|
264
275
|
} else {
|
265
276
|
rb_raise(rb_eArgError, "%s(): argument '%s' must be %s, but found element of type %s at pos %d",
|
266
277
|
signature.name.c_str(), signature.params[i].name.c_str(),
|
@@ -303,22 +314,22 @@ inline c10::optional<at::MemoryFormat> RubyArgs::memoryformatOptional(int i) {
|
|
303
314
|
}
|
304
315
|
|
305
316
|
inline std::string RubyArgs::string(int i) {
|
306
|
-
return
|
317
|
+
return Rice::detail::From_Ruby<std::string>().convert(args[i]);
|
307
318
|
}
|
308
319
|
|
309
320
|
inline c10::optional<std::string> RubyArgs::stringOptional(int i) {
|
310
|
-
if (
|
311
|
-
return
|
321
|
+
if (NIL_P(args[i])) return c10::nullopt;
|
322
|
+
return Rice::detail::From_Ruby<std::string>().convert(args[i]);
|
312
323
|
}
|
313
324
|
|
314
325
|
inline int64_t RubyArgs::toInt64(int i) {
|
315
326
|
if (NIL_P(args[i])) return signature.params[i].default_int;
|
316
|
-
return
|
327
|
+
return Rice::detail::From_Ruby<int64_t>().convert(args[i]);
|
317
328
|
}
|
318
329
|
|
319
330
|
inline double RubyArgs::toDouble(int i) {
|
320
331
|
if (NIL_P(args[i])) return signature.params[i].default_double;
|
321
|
-
return
|
332
|
+
return Rice::detail::From_Ruby<double>().convert(args[i]);
|
322
333
|
}
|
323
334
|
|
324
335
|
inline bool RubyArgs::toBool(int i) {
|
@@ -0,0 +1,13 @@
|
|
1
|
+
#include <torch/torch.h>
|
2
|
+
|
3
|
+
#include <rice/rice.hpp>
|
4
|
+
|
5
|
+
#include "special_functions.h"
|
6
|
+
#include "templates.h"
|
7
|
+
#include "utils.h"
|
8
|
+
|
9
|
+
void init_special(Rice::Module& m) {
|
10
|
+
auto rb_mSpecial = Rice::define_module_under(m, "Special");
|
11
|
+
rb_mSpecial.add_handler<torch::Error>(handle_error);
|
12
|
+
add_special_functions(rb_mSpecial);
|
13
|
+
}
|