torch-rb 0.5.2 → 0.8.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +28 -0
- data/README.md +14 -4
- data/codegen/function.rb +2 -0
- data/codegen/generate_functions.rb +48 -10
- data/codegen/native_functions.yaml +3318 -1679
- data/ext/torch/backends.cpp +17 -0
- data/ext/torch/cuda.cpp +14 -0
- data/ext/torch/device.cpp +28 -0
- data/ext/torch/ext.cpp +34 -613
- data/ext/torch/extconf.rb +1 -4
- data/ext/torch/fft.cpp +13 -0
- data/ext/torch/fft_functions.h +6 -0
- data/ext/torch/ivalue.cpp +132 -0
- data/ext/torch/linalg.cpp +13 -0
- data/ext/torch/linalg_functions.h +6 -0
- data/ext/torch/nn.cpp +114 -0
- data/ext/torch/nn_functions.h +1 -1
- data/ext/torch/random.cpp +22 -0
- data/ext/torch/ruby_arg_parser.cpp +3 -3
- data/ext/torch/ruby_arg_parser.h +44 -17
- data/ext/torch/special.cpp +13 -0
- data/ext/torch/special_functions.h +6 -0
- data/ext/torch/templates.h +111 -133
- data/ext/torch/tensor.cpp +320 -0
- data/ext/torch/tensor_functions.h +1 -1
- data/ext/torch/torch.cpp +95 -0
- data/ext/torch/torch_functions.h +1 -1
- data/ext/torch/utils.h +8 -2
- data/ext/torch/wrap_outputs.h +72 -65
- data/lib/torch.rb +14 -10
- data/lib/torch/inspector.rb +5 -2
- data/lib/torch/nn/linear.rb +2 -0
- data/lib/torch/nn/module.rb +107 -21
- data/lib/torch/nn/parameter.rb +1 -1
- data/lib/torch/tensor.rb +4 -0
- data/lib/torch/utils/data/data_loader.rb +1 -1
- data/lib/torch/version.rb +1 -1
- metadata +21 -91
data/ext/torch/extconf.rb
CHANGED
@@ -1,8 +1,6 @@
|
|
1
1
|
require "mkmf-rice"
|
2
2
|
|
3
|
-
|
4
|
-
|
5
|
-
$CXXFLAGS += " -std=c++14"
|
3
|
+
$CXXFLAGS += " -std=c++17 $(optflags)"
|
6
4
|
|
7
5
|
# change to 0 for Linux pre-cxx11 ABI version
|
8
6
|
$CXXFLAGS += " -D_GLIBCXX_USE_CXX11_ABI=1"
|
@@ -11,7 +9,6 @@ apple_clang = RbConfig::CONFIG["CC_VERSION_MESSAGE"] =~ /apple clang/i
|
|
11
9
|
|
12
10
|
# check omp first
|
13
11
|
if have_library("omp") || have_library("gomp")
|
14
|
-
$CXXFLAGS += " -DAT_PARALLEL_OPENMP=1"
|
15
12
|
$CXXFLAGS += " -Xclang" if apple_clang
|
16
13
|
$CXXFLAGS += " -fopenmp"
|
17
14
|
end
|
data/ext/torch/fft.cpp
ADDED
@@ -0,0 +1,13 @@
|
|
1
|
+
#include <torch/torch.h>
|
2
|
+
|
3
|
+
#include <rice/rice.hpp>
|
4
|
+
|
5
|
+
#include "fft_functions.h"
|
6
|
+
#include "templates.h"
|
7
|
+
#include "utils.h"
|
8
|
+
|
9
|
+
// Creates the "FFT" module nested under `m`, installs the torch::Error
// handler on it and then registers the FFT functions declared in
// fft_functions.h.
void init_fft(Rice::Module& m) {
  auto fft_module = Rice::define_module_under(m, "FFT");

  // Register the error handler before any function is attached.
  fft_module.add_handler<torch::Error>(handle_error);

  add_fft_functions(fft_module);
}
|
@@ -0,0 +1,132 @@
|
|
1
|
+
#include <torch/torch.h>
|
2
|
+
|
3
|
+
#include <rice/rice.hpp>
|
4
|
+
|
5
|
+
#include "utils.h"
|
6
|
+
|
7
|
+
// Registers the Ruby-visible API of torch::IValue on `rb_cIValue`:
//   * `xxx?`   predicates forwarding to the IValue::isXxx member functions
//   * `to_xxx` extractors that unwrap the held value
//   * `from_xxx` singleton factories that build an IValue from a Ruby value
// The module argument `m` is not used by this function.
void init_ivalue(Rice::Module& m, Rice::Class& rb_cIValue) {
  // https://pytorch.org/cppdocs/api/structc10_1_1_i_value.html
  rb_cIValue
    .add_handler<torch::Error>(handle_error)

    // ---- type predicates ----
    .define_method("bool?", &torch::IValue::isBool)
    .define_method("bool_list?", &torch::IValue::isBoolList)
    .define_method("capsule?", &torch::IValue::isCapsule)
    .define_method("custom_class?", &torch::IValue::isCustomClass)
    .define_method("device?", &torch::IValue::isDevice)
    .define_method("double?", &torch::IValue::isDouble)
    .define_method("double_list?", &torch::IValue::isDoubleList)
    .define_method("future?", &torch::IValue::isFuture)
    // .define_method("generator?", &torch::IValue::isGenerator)
    .define_method("generic_dict?", &torch::IValue::isGenericDict)
    .define_method("list?", &torch::IValue::isList)
    .define_method("int?", &torch::IValue::isInt)
    .define_method("int_list?", &torch::IValue::isIntList)
    .define_method("module?", &torch::IValue::isModule)
    .define_method("none?", &torch::IValue::isNone)
    .define_method("object?", &torch::IValue::isObject)
    .define_method("ptr_type?", &torch::IValue::isPtrType)
    .define_method("py_object?", &torch::IValue::isPyObject)
    .define_method("r_ref?", &torch::IValue::isRRef)
    .define_method("scalar?", &torch::IValue::isScalar)
    .define_method("string?", &torch::IValue::isString)
    .define_method("tensor?", &torch::IValue::isTensor)
    .define_method("tensor_list?", &torch::IValue::isTensorList)
    .define_method("tuple?", &torch::IValue::isTuple)

    // ---- extractors ----
    .define_method("to_bool", [](torch::IValue& ivalue) { return ivalue.toBool(); })
    .define_method("to_double", [](torch::IValue& ivalue) { return ivalue.toDouble(); })
    .define_method("to_int", [](torch::IValue& ivalue) { return ivalue.toInt(); })
    .define_method(
      "to_list",
      [](torch::IValue& ivalue) {
        // Fetch the elements first, then build the Ruby array from them.
        auto elements = ivalue.toListRef();
        Rice::Array result;
        for (auto& element : elements) {
          auto item = torch::IValue{element};
          result.push(Rice::Object(Rice::detail::To_Ruby<torch::IValue>().convert(item)));
        }
        return result;
      })
    .define_method("to_string_ref", [](torch::IValue& ivalue) { return ivalue.toStringRef(); })
    .define_method("to_tensor", [](torch::IValue& ivalue) { return ivalue.toTensor(); })
    .define_method(
      "to_generic_dict",
      [](torch::IValue& ivalue) {
        auto entries = ivalue.toGenericDict();
        Rice::Hash result;
        for (auto& entry : entries) {
          auto key = torch::IValue{entry.key()};
          auto value = torch::IValue{entry.value()};
          // Both key and value are handed to Ruby as wrapped IValue objects.
          result[Rice::Object(Rice::detail::To_Ruby<torch::IValue>().convert(key))] =
            Rice::Object(Rice::detail::To_Ruby<torch::IValue>().convert(value));
        }
        return result;
      })

    // ---- factories ----
    .define_singleton_function("from_tensor", [](torch::Tensor& tensor) { return torch::IValue(tensor); })
    // TODO create specialized list types?
    .define_singleton_function(
      "from_list",
      [](Rice::Array items) {
        // Elements are stored in an untyped (AnyType) generic list.
        c10::impl::GenericList generic_list(c10::AnyType::get());
        for (auto item : items) {
          generic_list.push_back(Rice::detail::From_Ruby<torch::IValue>().convert(item.value()));
        }
        return torch::IValue(generic_list);
      })
    .define_singleton_function("from_string", [](Rice::String text) { return torch::IValue(text.str()); })
    .define_singleton_function("from_int", [](int64_t number) { return torch::IValue(number); })
    .define_singleton_function("from_double", [](double number) { return torch::IValue(number); })
    .define_singleton_function("from_bool", [](bool flag) { return torch::IValue(flag); })
    // see https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/python/pybind_utils.h
    // createGenericDict and toIValue
    .define_singleton_function(
      "from_dict",
      [](Rice::Hash items) {
        // Keys and values are both typed as AnyType.
        c10::impl::GenericDict generic_dict(c10::AnyType::get(), c10::AnyType::get());
        generic_dict.reserve(items.size());
        for (auto item : items) {
          generic_dict.insert(
            Rice::detail::From_Ruby<torch::IValue>().convert(item.first),
            Rice::detail::From_Ruby<torch::IValue>().convert(static_cast<Rice::Object>(item.second)));
        }
        return torch::IValue(std::move(generic_dict));
      });
}
|
@@ -0,0 +1,13 @@
|
|
1
|
+
#include <torch/torch.h>
|
2
|
+
|
3
|
+
#include <rice/rice.hpp>
|
4
|
+
|
5
|
+
#include "linalg_functions.h"
|
6
|
+
#include "templates.h"
|
7
|
+
#include "utils.h"
|
8
|
+
|
9
|
+
// Creates the "Linalg" module nested under `m`, installs the torch::Error
// handler on it and then registers the functions declared in
// linalg_functions.h.
void init_linalg(Rice::Module& m) {
  auto linalg_module = Rice::define_module_under(m, "Linalg");

  // Register the error handler before any function is attached.
  linalg_module.add_handler<torch::Error>(handle_error);

  add_linalg_functions(linalg_module);
}
|
data/ext/torch/nn.cpp
ADDED
@@ -0,0 +1,114 @@
|
|
1
|
+
#include <torch/torch.h>
|
2
|
+
|
3
|
+
#include <rice/rice.hpp>
|
4
|
+
|
5
|
+
#include "nn_functions.h"
|
6
|
+
#include "templates.h"
|
7
|
+
#include "utils.h"
|
8
|
+
|
9
|
+
// need to make a distinction between parameters and tensors
|
10
|
+
// Separate C++ type deriving from torch::autograd::Variable so that
// parameters can be registered and recognised independently of plain tensors.
class Parameter: public torch::autograd::Variable {
  public:
    // Takes ownership of `t`. A named rvalue reference is an lvalue inside the
    // constructor, so it must be passed through std::move for the base class
    // to be move-constructed rather than copy-constructed.
    Parameter(Tensor&& t) : torch::autograd::Variable(std::move(t)) { }
};
|
14
|
+
|
15
|
+
// Builds the "NN" module under `m`:
//   * attaches the functions declared in nn_functions.h,
//   * defines the NN::Init module whose singleton functions forward to the
//     matching torch::nn::init routines,
//   * defines the NN::Parameter class (registered with torch::Tensor as base)
//     with `grad`, `grad=` and the `_make_subclass` factory.
void init_nn(Rice::Module& m) {
  namespace nn_init = torch::nn::init;

  auto nn_module = Rice::define_module_under(m, "NN");
  nn_module.add_handler<torch::Error>(handle_error);
  add_nn_functions(nn_module);

  Rice::define_module_under(nn_module, "Init")
    .add_handler<torch::Error>(handle_error)
    .define_singleton_function(
      "_calculate_gain",
      [](NonlinearityType nonlinearity, double param) { return nn_init::calculate_gain(nonlinearity, param); })
    .define_singleton_function(
      "_uniform!",
      [](Tensor t, double low, double high) { return nn_init::uniform_(t, low, high); })
    .define_singleton_function(
      "_normal!",
      [](Tensor t, double mean, double std) { return nn_init::normal_(t, mean, std); })
    .define_singleton_function(
      "_constant!",
      [](Tensor t, Scalar value) { return nn_init::constant_(t, value); })
    .define_singleton_function(
      "_ones!",
      [](Tensor t) { return nn_init::ones_(t); })
    .define_singleton_function(
      "_zeros!",
      [](Tensor t) { return nn_init::zeros_(t); })
    .define_singleton_function(
      "_eye!",
      [](Tensor t) { return nn_init::eye_(t); })
    .define_singleton_function(
      "_dirac!",
      [](Tensor t) { return nn_init::dirac_(t); })
    .define_singleton_function(
      "_xavier_uniform!",
      [](Tensor t, double gain) { return nn_init::xavier_uniform_(t, gain); })
    .define_singleton_function(
      "_xavier_normal!",
      [](Tensor t, double gain) { return nn_init::xavier_normal_(t, gain); })
    .define_singleton_function(
      "_kaiming_uniform!",
      [](Tensor t, double a, FanModeType mode, NonlinearityType nonlinearity) {
        return nn_init::kaiming_uniform_(t, a, mode, nonlinearity);
      })
    .define_singleton_function(
      "_kaiming_normal!",
      [](Tensor t, double a, FanModeType mode, NonlinearityType nonlinearity) {
        return nn_init::kaiming_normal_(t, a, mode, nonlinearity);
      })
    .define_singleton_function(
      "_orthogonal!",
      [](Tensor t, double gain) { return nn_init::orthogonal_(t, gain); })
    .define_singleton_function(
      "_sparse!",
      [](Tensor t, double sparsity, double std) { return nn_init::sparse_(t, sparsity, std); });

  Rice::define_class_under<Parameter, torch::Tensor>(nn_module, "Parameter")
    .add_handler<torch::Error>(handle_error)
    .define_method(
      "grad",
      [](Parameter& self) -> Object {
        // An undefined gradient is reported to Ruby as nil.
        auto gradient = self.grad();
        if (!gradient.defined()) {
          return Nil;
        }
        return Object(Rice::detail::To_Ruby<torch::Tensor>().convert(gradient));
      })
    .define_method(
      "grad=",
      [](Parameter& self, torch::Tensor& grad) {
        self.mutable_grad() = grad;
      })
    .define_singleton_function(
      "_make_subclass",
      [](Tensor& rd, bool requires_grad) {
        // Work on a detached tensor, allow its metadata to change, then set
        // requires_grad before wrapping the result as a Parameter.
        auto detached = rd.detach();
        detached.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(true);
        auto tracked = detached.set_requires_grad(requires_grad);
        return Parameter(std::move(tracked));
      });
}
|
data/ext/torch/nn_functions.h
CHANGED
@@ -0,0 +1,22 @@
|
|
1
|
+
#include <torch/torch.h>
|
2
|
+
|
3
|
+
#include <rice/rice.hpp>
|
4
|
+
|
5
|
+
#include "utils.h"
|
6
|
+
|
7
|
+
// Defines the "Random" module under `m` with two singleton functions that
// operate on the default CPU generator:
//   initial_seed -> result of current_seed()
//   seed         -> result of seed()
void init_random(Rice::Module& m) {
  Rice::define_module_under(m, "Random")
    .add_handler<torch::Error>(handle_error)
    .define_singleton_function(
      "initial_seed",
      []() { return at::detail::getDefaultCPUGenerator().current_seed(); })
    .define_singleton_function(
      "seed",
      []() {
        // TODO set for CUDA when available
        // seed() is invoked on a local generator object, not on the
        // expression returned by getDefaultCPUGenerator() directly.
        auto cpu_generator = at::detail::getDefaultCPUGenerator();
        return cpu_generator.seed();
      });
}
|
@@ -137,7 +137,7 @@ auto FunctionParameter::check(VALUE obj, int argnum) -> bool
|
|
137
137
|
return true;
|
138
138
|
}
|
139
139
|
if (THPVariable_Check(obj)) {
|
140
|
-
auto var =
|
140
|
+
auto var = Rice::detail::From_Ruby<torch::Tensor>().convert(obj);
|
141
141
|
return !var.requires_grad() && var.dim() == 0;
|
142
142
|
}
|
143
143
|
return false;
|
@@ -147,7 +147,7 @@ auto FunctionParameter::check(VALUE obj, int argnum) -> bool
|
|
147
147
|
return true;
|
148
148
|
}
|
149
149
|
if (THPVariable_Check(obj)) {
|
150
|
-
auto var =
|
150
|
+
auto var = Rice::detail::From_Ruby<torch::Tensor>().convert(obj);
|
151
151
|
return at::isIntegralType(var.scalar_type(), /*includeBool=*/false) && !var.requires_grad() && var.dim() == 0;
|
152
152
|
}
|
153
153
|
return false;
|
@@ -487,7 +487,7 @@ static void extra_kwargs(FunctionSignature& signature, VALUE kwargs, ssize_t num
|
|
487
487
|
|
488
488
|
VALUE missing = Qundef;
|
489
489
|
|
490
|
-
bool FunctionSignature::parse(VALUE self, VALUE args, VALUE kwargs,
|
490
|
+
bool FunctionSignature::parse(VALUE self, VALUE args, VALUE kwargs, VALUE dst[], // NOLINT
|
491
491
|
bool raise_exception) {
|
492
492
|
auto nargs = NIL_P(args) ? 0 : RARRAY_LEN(args);
|
493
493
|
ssize_t remaining_kwargs = NIL_P(kwargs) ? 0 : RHASH_SIZE(kwargs);
|
data/ext/torch/ruby_arg_parser.h
CHANGED
@@ -2,8 +2,10 @@
|
|
2
2
|
|
3
3
|
#pragma once
|
4
4
|
|
5
|
+
#include <sstream>
|
6
|
+
|
5
7
|
#include <torch/torch.h>
|
6
|
-
#include <rice/
|
8
|
+
#include <rice/rice.hpp>
|
7
9
|
|
8
10
|
#include "templates.h"
|
9
11
|
#include "utils.h"
|
@@ -46,7 +48,7 @@ struct FunctionParameter {
|
|
46
48
|
struct FunctionSignature {
|
47
49
|
explicit FunctionSignature(const std::string& fmt, int index);
|
48
50
|
|
49
|
-
bool parse(VALUE self, VALUE args, VALUE kwargs,
|
51
|
+
bool parse(VALUE self, VALUE args, VALUE kwargs, VALUE dst[], bool raise_exception);
|
50
52
|
|
51
53
|
std::string toString() const;
|
52
54
|
|
@@ -63,19 +65,20 @@ struct FunctionSignature {
|
|
63
65
|
};
|
64
66
|
|
65
67
|
struct RubyArgs {
|
66
|
-
RubyArgs(const FunctionSignature& signature,
|
68
|
+
RubyArgs(const FunctionSignature& signature, VALUE* args)
|
67
69
|
: signature(signature)
|
68
70
|
, args(args)
|
69
71
|
, idx(signature.index) {}
|
70
72
|
|
71
73
|
const FunctionSignature& signature;
|
72
|
-
|
74
|
+
VALUE* args;
|
73
75
|
int idx;
|
74
76
|
|
75
77
|
inline at::Tensor tensor(int i);
|
76
78
|
inline OptionalTensor optionalTensor(int i);
|
77
79
|
inline at::Scalar scalar(int i);
|
78
80
|
// inline at::Scalar scalarWithDefault(int i, at::Scalar default_scalar);
|
81
|
+
inline std::vector<at::Scalar> scalarlist(int i);
|
79
82
|
inline std::vector<at::Tensor> tensorlist(int i);
|
80
83
|
template<int N>
|
81
84
|
inline std::array<at::Tensor, N> tensorlist_n(int i);
|
@@ -119,7 +122,7 @@ struct RubyArgs {
|
|
119
122
|
};
|
120
123
|
|
121
124
|
inline at::Tensor RubyArgs::tensor(int i) {
|
122
|
-
return
|
125
|
+
return Rice::detail::From_Ruby<torch::Tensor>().convert(args[i]);
|
123
126
|
}
|
124
127
|
|
125
128
|
inline OptionalTensor RubyArgs::optionalTensor(int i) {
|
@@ -129,12 +132,17 @@ inline OptionalTensor RubyArgs::optionalTensor(int i) {
|
|
129
132
|
|
130
133
|
inline at::Scalar RubyArgs::scalar(int i) {
|
131
134
|
if (NIL_P(args[i])) return signature.params[i].default_scalar;
|
132
|
-
return
|
135
|
+
return Rice::detail::From_Ruby<torch::Scalar>().convert(args[i]);
|
136
|
+
}
|
137
|
+
|
138
|
+
inline std::vector<at::Scalar> RubyArgs::scalarlist(int i) {
|
139
|
+
if (NIL_P(args[i])) return std::vector<at::Scalar>();
|
140
|
+
return Rice::detail::From_Ruby<std::vector<at::Scalar>>().convert(args[i]);
|
133
141
|
}
|
134
142
|
|
135
143
|
inline std::vector<at::Tensor> RubyArgs::tensorlist(int i) {
|
136
144
|
if (NIL_P(args[i])) return std::vector<at::Tensor>();
|
137
|
-
return
|
145
|
+
return Rice::detail::From_Ruby<std::vector<Tensor>>().convert(args[i]);
|
138
146
|
}
|
139
147
|
|
140
148
|
template<int N>
|
@@ -149,7 +157,7 @@ inline std::array<at::Tensor, N> RubyArgs::tensorlist_n(int i) {
|
|
149
157
|
}
|
150
158
|
for (int idx = 0; idx < size; idx++) {
|
151
159
|
VALUE obj = rb_ary_entry(arg, idx);
|
152
|
-
res[idx] =
|
160
|
+
res[idx] = Rice::detail::From_Ruby<Tensor>().convert(obj);
|
153
161
|
}
|
154
162
|
return res;
|
155
163
|
}
|
@@ -168,7 +176,7 @@ inline std::vector<int64_t> RubyArgs::intlist(int i) {
|
|
168
176
|
for (idx = 0; idx < size; idx++) {
|
169
177
|
VALUE obj = rb_ary_entry(arg, idx);
|
170
178
|
if (FIXNUM_P(obj)) {
|
171
|
-
res[idx] =
|
179
|
+
res[idx] = Rice::detail::From_Ruby<int64_t>().convert(obj);
|
172
180
|
} else {
|
173
181
|
rb_raise(rb_eArgError, "%s(): argument '%s' must be %s, but found element of type %s at pos %d",
|
174
182
|
signature.name.c_str(), signature.params[i].name.c_str(),
|
@@ -208,8 +216,13 @@ inline ScalarType RubyArgs::scalartype(int i) {
|
|
208
216
|
{ID2SYM(rb_intern("double")), ScalarType::Double},
|
209
217
|
{ID2SYM(rb_intern("float64")), ScalarType::Double},
|
210
218
|
{ID2SYM(rb_intern("complex_half")), ScalarType::ComplexHalf},
|
219
|
+
{ID2SYM(rb_intern("complex32")), ScalarType::ComplexHalf},
|
211
220
|
{ID2SYM(rb_intern("complex_float")), ScalarType::ComplexFloat},
|
221
|
+
{ID2SYM(rb_intern("cfloat")), ScalarType::ComplexFloat},
|
222
|
+
{ID2SYM(rb_intern("complex64")), ScalarType::ComplexFloat},
|
212
223
|
{ID2SYM(rb_intern("complex_double")), ScalarType::ComplexDouble},
|
224
|
+
{ID2SYM(rb_intern("cdouble")), ScalarType::ComplexDouble},
|
225
|
+
{ID2SYM(rb_intern("complex128")), ScalarType::ComplexDouble},
|
213
226
|
{ID2SYM(rb_intern("bool")), ScalarType::Bool},
|
214
227
|
{ID2SYM(rb_intern("qint8")), ScalarType::QInt8},
|
215
228
|
{ID2SYM(rb_intern("quint8")), ScalarType::QUInt8},
|
@@ -258,7 +271,7 @@ inline c10::OptionalArray<double> RubyArgs::doublelistOptional(int i) {
|
|
258
271
|
for (idx = 0; idx < size; idx++) {
|
259
272
|
VALUE obj = rb_ary_entry(arg, idx);
|
260
273
|
if (FIXNUM_P(obj) || RB_FLOAT_TYPE_P(obj)) {
|
261
|
-
res[idx] =
|
274
|
+
res[idx] = Rice::detail::From_Ruby<double>().convert(obj);
|
262
275
|
} else {
|
263
276
|
rb_raise(rb_eArgError, "%s(): argument '%s' must be %s, but found element of type %s at pos %d",
|
264
277
|
signature.name.c_str(), signature.params[i].name.c_str(),
|
@@ -301,22 +314,22 @@ inline c10::optional<at::MemoryFormat> RubyArgs::memoryformatOptional(int i) {
|
|
301
314
|
}
|
302
315
|
|
303
316
|
inline std::string RubyArgs::string(int i) {
|
304
|
-
return
|
317
|
+
return Rice::detail::From_Ruby<std::string>().convert(args[i]);
|
305
318
|
}
|
306
319
|
|
307
320
|
inline c10::optional<std::string> RubyArgs::stringOptional(int i) {
|
308
|
-
if (
|
309
|
-
return
|
321
|
+
if (NIL_P(args[i])) return c10::nullopt;
|
322
|
+
return Rice::detail::From_Ruby<std::string>().convert(args[i]);
|
310
323
|
}
|
311
324
|
|
312
325
|
inline int64_t RubyArgs::toInt64(int i) {
|
313
326
|
if (NIL_P(args[i])) return signature.params[i].default_int;
|
314
|
-
return
|
327
|
+
return Rice::detail::From_Ruby<int64_t>().convert(args[i]);
|
315
328
|
}
|
316
329
|
|
317
330
|
inline double RubyArgs::toDouble(int i) {
|
318
331
|
if (NIL_P(args[i])) return signature.params[i].default_double;
|
319
|
-
return
|
332
|
+
return Rice::detail::From_Ruby<double>().convert(args[i]);
|
320
333
|
}
|
321
334
|
|
322
335
|
inline bool RubyArgs::toBool(int i) {
|
@@ -328,6 +341,12 @@ inline bool RubyArgs::isNone(int i) {
|
|
328
341
|
return NIL_P(args[i]);
|
329
342
|
}
|
330
343
|
|
344
|
+
template<int N>
|
345
|
+
struct ParsedArgs {
|
346
|
+
ParsedArgs() : args() { }
|
347
|
+
VALUE args[N];
|
348
|
+
};
|
349
|
+
|
331
350
|
struct RubyArgParser {
|
332
351
|
std::vector<FunctionSignature> signatures_;
|
333
352
|
std::string function_name;
|
@@ -356,7 +375,15 @@ struct RubyArgParser {
|
|
356
375
|
});
|
357
376
|
}
|
358
377
|
|
359
|
-
|
378
|
+
template<int N>
|
379
|
+
inline RubyArgs parse(VALUE self, int argc, VALUE* argv, ParsedArgs<N> &dst) {
|
380
|
+
if (N < max_args) {
|
381
|
+
rb_raise(rb_eArgError, "RubyArgParser: dst ParsedArgs buffer does not have enough capacity, expected %d (got %d)", (int)max_args, N);
|
382
|
+
}
|
383
|
+
return raw_parse(self, argc, argv, dst.args);
|
384
|
+
}
|
385
|
+
|
386
|
+
inline RubyArgs raw_parse(VALUE self, int argc, VALUE* argv, VALUE parsed_args[]) {
|
360
387
|
VALUE args, kwargs;
|
361
388
|
rb_scan_args(argc, argv, "*:", &args, &kwargs);
|
362
389
|
|
@@ -378,7 +405,7 @@ struct RubyArgParser {
|
|
378
405
|
rb_raise(rb_eArgError, "No matching signatures");
|
379
406
|
}
|
380
407
|
|
381
|
-
void print_error(VALUE self, VALUE args, VALUE kwargs,
|
408
|
+
void print_error(VALUE self, VALUE args, VALUE kwargs, VALUE parsed_args[]) {
|
382
409
|
ssize_t num_args = (NIL_P(args) ? 0 : RARRAY_LEN(args)) + (NIL_P(kwargs) ? 0 : RHASH_SIZE(kwargs));
|
383
410
|
std::vector<int> plausible_idxs;
|
384
411
|
ssize_t i = 0;
|