torch-rb 0.5.0 → 0.7.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +26 -0
- data/README.md +13 -4
- data/codegen/generate_functions.rb +13 -14
- data/codegen/native_functions.yaml +2355 -1396
- data/ext/torch/cuda.cpp +14 -0
- data/ext/torch/device.cpp +28 -0
- data/ext/torch/ext.cpp +26 -613
- data/ext/torch/extconf.rb +1 -4
- data/ext/torch/ivalue.cpp +132 -0
- data/ext/torch/nn.cpp +114 -0
- data/ext/torch/nn_functions.h +1 -1
- data/ext/torch/random.cpp +22 -0
- data/ext/torch/ruby_arg_parser.cpp +3 -3
- data/ext/torch/ruby_arg_parser.h +37 -16
- data/ext/torch/templates.h +110 -133
- data/ext/torch/tensor.cpp +320 -0
- data/ext/torch/tensor_functions.h +1 -1
- data/ext/torch/torch.cpp +95 -0
- data/ext/torch/torch_functions.h +1 -1
- data/ext/torch/utils.h +8 -2
- data/ext/torch/wrap_outputs.h +72 -65
- data/lib/torch.rb +19 -17
- data/lib/torch/inspector.rb +5 -2
- data/lib/torch/nn/linear.rb +2 -0
- data/lib/torch/nn/module.rb +107 -21
- data/lib/torch/nn/parameter.rb +1 -1
- data/lib/torch/tensor.rb +9 -0
- data/lib/torch/utils/data/data_loader.rb +1 -1
- data/lib/torch/version.rb +1 -1
- metadata +14 -91
data/ext/torch/extconf.rb
CHANGED
@@ -1,8 +1,6 @@
|
|
1
1
|
require "mkmf-rice"
|
2
2
|
|
3
|
-
|
4
|
-
|
5
|
-
$CXXFLAGS += " -std=c++14"
|
3
|
+
$CXXFLAGS += " -std=c++17 $(optflags)"
|
6
4
|
|
7
5
|
# change to 0 for Linux pre-cxx11 ABI version
|
8
6
|
$CXXFLAGS += " -D_GLIBCXX_USE_CXX11_ABI=1"
|
@@ -11,7 +9,6 @@ apple_clang = RbConfig::CONFIG["CC_VERSION_MESSAGE"] =~ /apple clang/i
|
|
11
9
|
|
12
10
|
# check omp first
|
13
11
|
if have_library("omp") || have_library("gomp")
|
14
|
-
$CXXFLAGS += " -DAT_PARALLEL_OPENMP=1"
|
15
12
|
$CXXFLAGS += " -Xclang" if apple_clang
|
16
13
|
$CXXFLAGS += " -fopenmp"
|
17
14
|
end
|
data/ext/torch/ivalue.cpp
ADDED
@@ -0,0 +1,132 @@
|
|
1
|
+
#include <torch/torch.h>
|
2
|
+
|
3
|
+
#include <rice/rice.hpp>
|
4
|
+
|
5
|
+
#include "utils.h"
|
6
|
+
|
7
|
+
// Registers the Ruby API of the IValue class, a binding of torch::IValue:
//   - `xxx?` predicates that forward to torch::IValue::isXxx,
//   - `to_*` converters returning Ruby values (lists and generic dicts are
//     converted element by element through To_Ruby<torch::IValue>),
//   - `from_*` singleton constructors that build an IValue from Ruby values.
// torch::Error raised by any of these bindings is routed through handle_error.
//
// @param m           not referenced by this function.
// @param rb_cIValue  the Ruby class the methods are defined on.
void init_ivalue(Rice::Module& m, Rice::Class& rb_cIValue) {
  // https://pytorch.org/cppdocs/api/structc10_1_1_i_value.html
  rb_cIValue
    .add_handler<torch::Error>(handle_error)
    .define_method("bool?", &torch::IValue::isBool)
    .define_method("bool_list?", &torch::IValue::isBoolList)
    .define_method("capsule?", &torch::IValue::isCapsule)
    .define_method("custom_class?", &torch::IValue::isCustomClass)
    .define_method("device?", &torch::IValue::isDevice)
    .define_method("double?", &torch::IValue::isDouble)
    .define_method("double_list?", &torch::IValue::isDoubleList)
    .define_method("future?", &torch::IValue::isFuture)
    // .define_method("generator?", &torch::IValue::isGenerator)
    .define_method("generic_dict?", &torch::IValue::isGenericDict)
    .define_method("list?", &torch::IValue::isList)
    .define_method("int?", &torch::IValue::isInt)
    .define_method("int_list?", &torch::IValue::isIntList)
    .define_method("module?", &torch::IValue::isModule)
    .define_method("none?", &torch::IValue::isNone)
    .define_method("object?", &torch::IValue::isObject)
    .define_method("ptr_type?", &torch::IValue::isPtrType)
    .define_method("py_object?", &torch::IValue::isPyObject)
    .define_method("r_ref?", &torch::IValue::isRRef)
    .define_method("scalar?", &torch::IValue::isScalar)
    .define_method("string?", &torch::IValue::isString)
    .define_method("tensor?", &torch::IValue::isTensor)
    .define_method("tensor_list?", &torch::IValue::isTensorList)
    .define_method("tuple?", &torch::IValue::isTuple)
    .define_method(
      "to_bool",
      [](torch::IValue& self) {
        return self.toBool();
      })
    .define_method(
      "to_double",
      [](torch::IValue& self) {
        return self.toDouble();
      })
    .define_method(
      "to_int",
      [](torch::IValue& self) {
        return self.toInt();
      })
    .define_method(
      "to_list",
      [](torch::IValue& self) {
        // Each element is copied into a standalone IValue and converted with
        // the IValue To_Ruby converter, so nested values convert recursively.
        auto list = self.toListRef();
        Rice::Array obj;
        for (auto& elem : list) {
          auto v = torch::IValue{elem};
          obj.push(Rice::Object(Rice::detail::To_Ruby<torch::IValue>().convert(v)));
        }
        return obj;
      })
    .define_method(
      "to_string_ref",
      [](torch::IValue& self) {
        return self.toStringRef();
      })
    .define_method(
      "to_tensor",
      [](torch::IValue& self) {
        return self.toTensor();
      })
    .define_method(
      "to_generic_dict",
      [](torch::IValue& self) {
        // Keys and values are both converted through To_Ruby<torch::IValue>.
        auto dict = self.toGenericDict();
        Rice::Hash obj;
        for (auto& pair : dict) {
          auto k = torch::IValue{pair.key()};
          auto v = torch::IValue{pair.value()};
          obj[Rice::Object(Rice::detail::To_Ruby<torch::IValue>().convert(k))] = Rice::Object(Rice::detail::To_Ruby<torch::IValue>().convert(v));
        }
        return obj;
      })
    .define_singleton_function(
      "from_tensor",
      [](torch::Tensor& v) {
        return torch::IValue(v);
      })
    // TODO create specialized list types?
    .define_singleton_function(
      "from_list",
      [](Rice::Array obj) {
        // Elements are typed as AnyType, so the result is a heterogeneous GenericList.
        c10::impl::GenericList list(c10::AnyType::get());
        // Pre-size the storage (as from_dict does) so the push_back loop does
        // not reallocate repeatedly for large arrays.
        list.reserve(static_cast<size_t>(obj.size()));
        for (auto entry : obj) {
          list.push_back(Rice::detail::From_Ruby<torch::IValue>().convert(entry.value()));
        }
        return torch::IValue(list);
      })
    .define_singleton_function(
      "from_string",
      [](Rice::String v) {
        return torch::IValue(v.str());
      })
    .define_singleton_function(
      "from_int",
      [](int64_t v) {
        return torch::IValue(v);
      })
    .define_singleton_function(
      "from_double",
      [](double v) {
        return torch::IValue(v);
      })
    .define_singleton_function(
      "from_bool",
      [](bool v) {
        return torch::IValue(v);
      })
    // see https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/python/pybind_utils.h
    // createGenericDict and toIValue
    .define_singleton_function(
      "from_dict",
      [](Rice::Hash obj) {
        // Keys and values are typed as AnyType, so any IValue is accepted for either.
        auto key_type = c10::AnyType::get();
        auto value_type = c10::AnyType::get();
        c10::impl::GenericDict elems(key_type, value_type);
        elems.reserve(obj.size());
        for (auto entry : obj) {
          elems.insert(Rice::detail::From_Ruby<torch::IValue>().convert(entry.first), Rice::detail::From_Ruby<torch::IValue>().convert((Rice::Object) entry.second));
        }
        return torch::IValue(std::move(elems));
      });
}
|
data/ext/torch/nn.cpp
ADDED
@@ -0,0 +1,114 @@
|
|
1
|
+
#include <torch/torch.h>
|
2
|
+
|
3
|
+
#include <rice/rice.hpp>
|
4
|
+
|
5
|
+
#include "nn_functions.h"
|
6
|
+
#include "templates.h"
|
7
|
+
#include "utils.h"
|
8
|
+
|
9
|
+
// need to make a distinction between parameters and tensors
|
10
|
+
// Distinct C++ type for Torch::NN::Parameter, so Rice can bind it as its own
// subclass of the torch::Tensor binding and keep parameters apart from plain
// tensors. It adds no state of its own.
class Parameter: public torch::autograd::Variable {
  public:
    // `t` is a named rvalue reference, i.e. an lvalue here. Moving it into the
    // base avoids copying the tensor handle (an extra refcount increment on the
    // underlying TensorImpl).
    Parameter(Tensor&& t) : torch::autograd::Variable(std::move(t)) { }
};
|
14
|
+
|
15
|
+
// Builds the Torch::NN Ruby module:
//   - the generated NN functions (add_nn_functions),
//   - Torch::NN::Init, whose underscore-prefixed singleton functions forward
//     to the matching torch::nn::init routines,
//   - Torch::NN::Parameter, bound as a subclass of the torch::Tensor binding,
//     with grad accessors and the _make_subclass factory.
// Every binding routes torch::Error through handle_error.
void init_nn(Rice::Module& m) {
  auto rb_mNN = Rice::define_module_under(m, "NN");
  rb_mNN.add_handler<torch::Error>(handle_error);
  add_nn_functions(rb_mNN);

  // ---- Torch::NN::Init ----------------------------------------------------
  // The error handler is attached first, before any function is defined.
  auto rb_mInit = Rice::define_module_under(rb_mNN, "Init");
  rb_mInit.add_handler<torch::Error>(handle_error);

  rb_mInit.define_singleton_function(
    "_calculate_gain",
    [](NonlinearityType kind, double arg) {
      return torch::nn::init::calculate_gain(kind, arg);
    });

  rb_mInit.define_singleton_function(
    "_uniform!",
    [](Tensor t, double lo, double hi) {
      return torch::nn::init::uniform_(t, lo, hi);
    });

  rb_mInit.define_singleton_function(
    "_normal!",
    [](Tensor t, double mu, double sigma) {
      return torch::nn::init::normal_(t, mu, sigma);
    });

  rb_mInit.define_singleton_function(
    "_constant!",
    [](Tensor t, Scalar fill) {
      return torch::nn::init::constant_(t, fill);
    });

  rb_mInit.define_singleton_function(
    "_ones!",
    [](Tensor t) {
      return torch::nn::init::ones_(t);
    });

  rb_mInit.define_singleton_function(
    "_zeros!",
    [](Tensor t) {
      return torch::nn::init::zeros_(t);
    });

  rb_mInit.define_singleton_function(
    "_eye!",
    [](Tensor t) {
      return torch::nn::init::eye_(t);
    });

  rb_mInit.define_singleton_function(
    "_dirac!",
    [](Tensor t) {
      return torch::nn::init::dirac_(t);
    });

  rb_mInit.define_singleton_function(
    "_xavier_uniform!",
    [](Tensor t, double g) {
      return torch::nn::init::xavier_uniform_(t, g);
    });

  rb_mInit.define_singleton_function(
    "_xavier_normal!",
    [](Tensor t, double g) {
      return torch::nn::init::xavier_normal_(t, g);
    });

  rb_mInit.define_singleton_function(
    "_kaiming_uniform!",
    [](Tensor t, double slope, FanModeType fan, NonlinearityType kind) {
      return torch::nn::init::kaiming_uniform_(t, slope, fan, kind);
    });

  rb_mInit.define_singleton_function(
    "_kaiming_normal!",
    [](Tensor t, double slope, FanModeType fan, NonlinearityType kind) {
      return torch::nn::init::kaiming_normal_(t, slope, fan, kind);
    });

  rb_mInit.define_singleton_function(
    "_orthogonal!",
    [](Tensor t, double g) {
      return torch::nn::init::orthogonal_(t, g);
    });

  rb_mInit.define_singleton_function(
    "_sparse!",
    [](Tensor t, double fraction, double sigma) {
      return torch::nn::init::sparse_(t, fraction, sigma);
    });

  // ---- Torch::NN::Parameter -----------------------------------------------
  auto rb_cParameter = Rice::define_class_under<Parameter, torch::Tensor>(rb_mNN, "Parameter");
  rb_cParameter.add_handler<torch::Error>(handle_error);

  // Returns the accumulated gradient, or nil while it is undefined.
  rb_cParameter.define_method(
    "grad",
    [](Parameter& self) {
      auto g = self.grad();
      return g.defined() ? Object(Rice::detail::To_Ruby<torch::Tensor>().convert(g)) : Nil;
    });

  // Overwrites the gradient slot with the given tensor.
  rb_cParameter.define_method(
    "grad=",
    [](Parameter& self, torch::Tensor& new_grad) {
      self.mutable_grad() = new_grad;
    });

  // Wraps a detached copy of `source` as a Parameter with the requested
  // requires_grad flag; metadata changes are re-enabled on the detached tensor.
  rb_cParameter.define_singleton_function(
    "_make_subclass",
    [](Tensor& source, bool requires_grad) {
      auto detached = source.detach();
      detached.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(true);
      auto var = detached.set_requires_grad(requires_grad);
      return Parameter(std::move(var));
    });
}
|
data/ext/torch/random.cpp
ADDED
@@ -0,0 +1,22 @@
|
|
1
|
+
#include <mutex>

#include <torch/torch.h>

#include <rice/rice.hpp>

#include "utils.h"
|
6
|
+
|
7
|
+
// Defines Torch::Random with seed helpers backed by ATen's default CPU
// generator. torch::Error raised by these bindings goes through handle_error.
void init_random(Rice::Module& m) {
  Rice::define_module_under(m, "Random")
    .add_handler<torch::Error>(handle_error)
    .define_singleton_function(
      "initial_seed",
      []() {
        // Current seed of the default CPU generator.
        return at::detail::getDefaultCPUGenerator().current_seed();
      })
    .define_singleton_function(
      "seed",
      []() {
        // TODO set for CUDA when available
        // at::Generator is a handle; this copy still refers to the shared
        // default CPU generator state.
        auto generator = at::detail::getDefaultCPUGenerator();
        // ATen generators are not thread-safe: their state must be guarded by
        // the generator's own mutex while it is mutated (ATen
        // Note [Acquire lock when using random generators]), which is also
        // what PyTorch's Generator.seed does.
        std::lock_guard<std::mutex> lock(generator.mutex());
        return generator.seed();
      });
}
|
data/ext/torch/ruby_arg_parser.cpp
CHANGED
@@ -137,7 +137,7 @@ auto FunctionParameter::check(VALUE obj, int argnum) -> bool
|
|
137
137
|
return true;
|
138
138
|
}
|
139
139
|
if (THPVariable_Check(obj)) {
|
140
|
-
auto var =
|
140
|
+
auto var = Rice::detail::From_Ruby<torch::Tensor>().convert(obj);
|
141
141
|
return !var.requires_grad() && var.dim() == 0;
|
142
142
|
}
|
143
143
|
return false;
|
@@ -147,7 +147,7 @@ auto FunctionParameter::check(VALUE obj, int argnum) -> bool
|
|
147
147
|
return true;
|
148
148
|
}
|
149
149
|
if (THPVariable_Check(obj)) {
|
150
|
-
auto var =
|
150
|
+
auto var = Rice::detail::From_Ruby<torch::Tensor>().convert(obj);
|
151
151
|
return at::isIntegralType(var.scalar_type(), /*includeBool=*/false) && !var.requires_grad() && var.dim() == 0;
|
152
152
|
}
|
153
153
|
return false;
|
@@ -487,7 +487,7 @@ static void extra_kwargs(FunctionSignature& signature, VALUE kwargs, ssize_t num
|
|
487
487
|
|
488
488
|
VALUE missing = Qundef;
|
489
489
|
|
490
|
-
bool FunctionSignature::parse(VALUE self, VALUE args, VALUE kwargs,
|
490
|
+
bool FunctionSignature::parse(VALUE self, VALUE args, VALUE kwargs, VALUE dst[], // NOLINT
|
491
491
|
bool raise_exception) {
|
492
492
|
auto nargs = NIL_P(args) ? 0 : RARRAY_LEN(args);
|
493
493
|
ssize_t remaining_kwargs = NIL_P(kwargs) ? 0 : RHASH_SIZE(kwargs);
|
data/ext/torch/ruby_arg_parser.h
CHANGED
@@ -2,8 +2,10 @@
|
|
2
2
|
|
3
3
|
#pragma once
|
4
4
|
|
5
|
+
#include <sstream>
|
6
|
+
|
5
7
|
#include <torch/torch.h>
|
6
|
-
#include <rice/
|
8
|
+
#include <rice/rice.hpp>
|
7
9
|
|
8
10
|
#include "templates.h"
|
9
11
|
#include "utils.h"
|
@@ -46,7 +48,7 @@ struct FunctionParameter {
|
|
46
48
|
struct FunctionSignature {
|
47
49
|
explicit FunctionSignature(const std::string& fmt, int index);
|
48
50
|
|
49
|
-
bool parse(VALUE self, VALUE args, VALUE kwargs,
|
51
|
+
bool parse(VALUE self, VALUE args, VALUE kwargs, VALUE dst[], bool raise_exception);
|
50
52
|
|
51
53
|
std::string toString() const;
|
52
54
|
|
@@ -63,13 +65,13 @@ struct FunctionSignature {
|
|
63
65
|
};
|
64
66
|
|
65
67
|
struct RubyArgs {
|
66
|
-
RubyArgs(const FunctionSignature& signature,
|
68
|
+
RubyArgs(const FunctionSignature& signature, VALUE* args)
|
67
69
|
: signature(signature)
|
68
70
|
, args(args)
|
69
71
|
, idx(signature.index) {}
|
70
72
|
|
71
73
|
const FunctionSignature& signature;
|
72
|
-
|
74
|
+
VALUE* args;
|
73
75
|
int idx;
|
74
76
|
|
75
77
|
inline at::Tensor tensor(int i);
|
@@ -119,7 +121,7 @@ struct RubyArgs {
|
|
119
121
|
};
|
120
122
|
|
121
123
|
// Converts parsed argument `i` to an at::Tensor with Rice's torch::Tensor converter.
inline at::Tensor RubyArgs::tensor(int i) {
  const VALUE value = args[i];
  return Rice::detail::From_Ruby<torch::Tensor>().convert(value);
}
|
124
126
|
|
125
127
|
inline OptionalTensor RubyArgs::optionalTensor(int i) {
|
@@ -129,12 +131,12 @@ inline OptionalTensor RubyArgs::optionalTensor(int i) {
|
|
129
131
|
|
130
132
|
inline at::Scalar RubyArgs::scalar(int i) {
|
131
133
|
if (NIL_P(args[i])) return signature.params[i].default_scalar;
|
132
|
-
return
|
134
|
+
return Rice::detail::From_Ruby<torch::Scalar>().convert(args[i]);
|
133
135
|
}
|
134
136
|
|
135
137
|
inline std::vector<at::Tensor> RubyArgs::tensorlist(int i) {
|
136
138
|
if (NIL_P(args[i])) return std::vector<at::Tensor>();
|
137
|
-
return
|
139
|
+
return Rice::detail::From_Ruby<std::vector<Tensor>>().convert(args[i]);
|
138
140
|
}
|
139
141
|
|
140
142
|
template<int N>
|
@@ -149,7 +151,7 @@ inline std::array<at::Tensor, N> RubyArgs::tensorlist_n(int i) {
|
|
149
151
|
}
|
150
152
|
for (int idx = 0; idx < size; idx++) {
|
151
153
|
VALUE obj = rb_ary_entry(arg, idx);
|
152
|
-
res[idx] =
|
154
|
+
res[idx] = Rice::detail::From_Ruby<Tensor>().convert(obj);
|
153
155
|
}
|
154
156
|
return res;
|
155
157
|
}
|
@@ -168,7 +170,7 @@ inline std::vector<int64_t> RubyArgs::intlist(int i) {
|
|
168
170
|
for (idx = 0; idx < size; idx++) {
|
169
171
|
VALUE obj = rb_ary_entry(arg, idx);
|
170
172
|
if (FIXNUM_P(obj)) {
|
171
|
-
res[idx] =
|
173
|
+
res[idx] = Rice::detail::From_Ruby<int64_t>().convert(obj);
|
172
174
|
} else {
|
173
175
|
rb_raise(rb_eArgError, "%s(): argument '%s' must be %s, but found element of type %s at pos %d",
|
174
176
|
signature.name.c_str(), signature.params[i].name.c_str(),
|
@@ -208,8 +210,13 @@ inline ScalarType RubyArgs::scalartype(int i) {
|
|
208
210
|
{ID2SYM(rb_intern("double")), ScalarType::Double},
|
209
211
|
{ID2SYM(rb_intern("float64")), ScalarType::Double},
|
210
212
|
{ID2SYM(rb_intern("complex_half")), ScalarType::ComplexHalf},
|
213
|
+
{ID2SYM(rb_intern("complex32")), ScalarType::ComplexHalf},
|
211
214
|
{ID2SYM(rb_intern("complex_float")), ScalarType::ComplexFloat},
|
215
|
+
{ID2SYM(rb_intern("cfloat")), ScalarType::ComplexFloat},
|
216
|
+
{ID2SYM(rb_intern("complex64")), ScalarType::ComplexFloat},
|
212
217
|
{ID2SYM(rb_intern("complex_double")), ScalarType::ComplexDouble},
|
218
|
+
{ID2SYM(rb_intern("cdouble")), ScalarType::ComplexDouble},
|
219
|
+
{ID2SYM(rb_intern("complex128")), ScalarType::ComplexDouble},
|
213
220
|
{ID2SYM(rb_intern("bool")), ScalarType::Bool},
|
214
221
|
{ID2SYM(rb_intern("qint8")), ScalarType::QInt8},
|
215
222
|
{ID2SYM(rb_intern("quint8")), ScalarType::QUInt8},
|
@@ -258,7 +265,7 @@ inline c10::OptionalArray<double> RubyArgs::doublelistOptional(int i) {
|
|
258
265
|
for (idx = 0; idx < size; idx++) {
|
259
266
|
VALUE obj = rb_ary_entry(arg, idx);
|
260
267
|
if (FIXNUM_P(obj) || RB_FLOAT_TYPE_P(obj)) {
|
261
|
-
res[idx] =
|
268
|
+
res[idx] = Rice::detail::From_Ruby<double>().convert(obj);
|
262
269
|
} else {
|
263
270
|
rb_raise(rb_eArgError, "%s(): argument '%s' must be %s, but found element of type %s at pos %d",
|
264
271
|
signature.name.c_str(), signature.params[i].name.c_str(),
|
@@ -301,22 +308,22 @@ inline c10::optional<at::MemoryFormat> RubyArgs::memoryformatOptional(int i) {
|
|
301
308
|
}
|
302
309
|
|
303
310
|
inline std::string RubyArgs::string(int i) {
|
304
|
-
return
|
311
|
+
return Rice::detail::From_Ruby<std::string>().convert(args[i]);
|
305
312
|
}
|
306
313
|
|
307
314
|
inline c10::optional<std::string> RubyArgs::stringOptional(int i) {
|
308
315
|
if (!args[i]) return c10::nullopt;
|
309
|
-
return
|
316
|
+
return Rice::detail::From_Ruby<std::string>().convert(args[i]);
|
310
317
|
}
|
311
318
|
|
312
319
|
inline int64_t RubyArgs::toInt64(int i) {
|
313
320
|
if (NIL_P(args[i])) return signature.params[i].default_int;
|
314
|
-
return
|
321
|
+
return Rice::detail::From_Ruby<int64_t>().convert(args[i]);
|
315
322
|
}
|
316
323
|
|
317
324
|
inline double RubyArgs::toDouble(int i) {
|
318
325
|
if (NIL_P(args[i])) return signature.params[i].default_double;
|
319
|
-
return
|
326
|
+
return Rice::detail::From_Ruby<double>().convert(args[i]);
|
320
327
|
}
|
321
328
|
|
322
329
|
inline bool RubyArgs::toBool(int i) {
|
@@ -328,6 +335,12 @@ inline bool RubyArgs::isNone(int i) {
|
|
328
335
|
return NIL_P(args[i]);
|
329
336
|
}
|
330
337
|
|
338
|
+
template<int N>
|
339
|
+
struct ParsedArgs {
|
340
|
+
ParsedArgs() : args() { }
|
341
|
+
VALUE args[N];
|
342
|
+
};
|
343
|
+
|
331
344
|
struct RubyArgParser {
|
332
345
|
std::vector<FunctionSignature> signatures_;
|
333
346
|
std::string function_name;
|
@@ -356,7 +369,15 @@ struct RubyArgParser {
|
|
356
369
|
});
|
357
370
|
}
|
358
371
|
|
359
|
-
|
372
|
+
template<int N>
|
373
|
+
inline RubyArgs parse(VALUE self, int argc, VALUE* argv, ParsedArgs<N> &dst) {
|
374
|
+
if (N < max_args) {
|
375
|
+
rb_raise(rb_eArgError, "RubyArgParser: dst ParsedArgs buffer does not have enough capacity, expected %d (got %d)", (int)max_args, N);
|
376
|
+
}
|
377
|
+
return raw_parse(self, argc, argv, dst.args);
|
378
|
+
}
|
379
|
+
|
380
|
+
inline RubyArgs raw_parse(VALUE self, int argc, VALUE* argv, VALUE parsed_args[]) {
|
360
381
|
VALUE args, kwargs;
|
361
382
|
rb_scan_args(argc, argv, "*:", &args, &kwargs);
|
362
383
|
|
@@ -378,7 +399,7 @@ struct RubyArgParser {
|
|
378
399
|
rb_raise(rb_eArgError, "No matching signatures");
|
379
400
|
}
|
380
401
|
|
381
|
-
void print_error(VALUE self, VALUE args, VALUE kwargs,
|
402
|
+
void print_error(VALUE self, VALUE args, VALUE kwargs, VALUE parsed_args[]) {
|
382
403
|
ssize_t num_args = (NIL_P(args) ? 0 : RARRAY_LEN(args)) + (NIL_P(kwargs) ? 0 : RHASH_SIZE(kwargs));
|
383
404
|
std::vector<int> plausible_idxs;
|
384
405
|
ssize_t i = 0;
|