torch-rb 0.3.7 → 0.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +1 -1
- data/codegen/function.rb +134 -0
- data/codegen/generate_functions.rb +546 -0
- data/{lib/torch/native → codegen}/native_functions.yaml +0 -0
- data/ext/torch/ext.cpp +54 -75
- data/ext/torch/extconf.rb +2 -2
- data/ext/torch/nn_functions.h +6 -0
- data/ext/torch/ruby_arg_parser.cpp +593 -0
- data/ext/torch/ruby_arg_parser.h +373 -0
- data/ext/torch/{templates.hpp → templates.h} +30 -51
- data/ext/torch/tensor_functions.h +6 -0
- data/ext/torch/torch_functions.h +6 -0
- data/ext/torch/utils.h +42 -0
- data/ext/torch/{templates.cpp → wrap_outputs.h} +16 -15
- data/lib/torch.rb +0 -62
- data/lib/torch/nn/functional.rb +30 -16
- data/lib/torch/nn/init.rb +5 -19
- data/lib/torch/optim/adadelta.rb +1 -1
- data/lib/torch/optim/adam.rb +2 -2
- data/lib/torch/optim/adamax.rb +1 -1
- data/lib/torch/optim/adamw.rb +1 -1
- data/lib/torch/optim/asgd.rb +1 -1
- data/lib/torch/optim/sgd.rb +3 -3
- data/lib/torch/tensor.rb +25 -105
- data/lib/torch/version.rb +1 -1
- metadata +27 -9
- data/lib/torch/native/dispatcher.rb +0 -70
- data/lib/torch/native/function.rb +0 -200
- data/lib/torch/native/generator.rb +0 -178
- data/lib/torch/native/parser.rb +0 -117
File without changes
|
data/ext/torch/ext.cpp
CHANGED
@@ -7,13 +7,14 @@
|
|
7
7
|
#include <rice/Constructor.hpp>
|
8
8
|
#include <rice/Hash.hpp>
|
9
9
|
|
10
|
-
#include "templates.
|
10
|
+
#include "templates.h"
|
11
|
+
#include "utils.h"
|
11
12
|
|
12
13
|
// generated with:
|
13
14
|
// rake generate:functions
|
14
|
-
#include "torch_functions.
|
15
|
-
#include "tensor_functions.
|
16
|
-
#include "nn_functions.
|
15
|
+
#include "torch_functions.h"
|
16
|
+
#include "tensor_functions.h"
|
17
|
+
#include "nn_functions.h"
|
17
18
|
|
18
19
|
using namespace Rice;
|
19
20
|
using torch::indexing::TensorIndex;
|
@@ -29,11 +30,47 @@ void handle_error(torch::Error const & ex)
|
|
29
30
|
throw Exception(rb_eRuntimeError, ex.what_without_backtrace());
|
30
31
|
}
|
31
32
|
|
33
|
+
Class rb_cTensor;
|
34
|
+
|
32
35
|
// Converts a Ruby array of index objects into libtorch TensorIndex values.
//
// Supported element kinds: Integer, Range (converted to a Slice), the Tensor
// class registered in rb_cTensor, nil (converted to None) and true/false.
// Any other element raises ArgumentError.
//
// Range handling: a nil begin or end leaves that side of the slice open.
// When exclude_end? is false the end is bumped by one before being handed to
// Slice; an inclusive end of -1 means "through the last element" and is
// therefore left open instead of being bumped to 0.
std::vector<TensorIndex> index_vector(Array a) {
  Object obj;

  std::vector<TensorIndex> indices;
  indices.reserve(a.size());

  for (size_t i = 0; i < a.size(); i++) {
    obj = a[i];

    if (obj.is_instance_of(rb_cInteger)) {
      indices.push_back(from_ruby<int64_t>(obj));
    } else if (obj.is_instance_of(rb_cRange)) {
      torch::optional<int64_t> start_index = torch::nullopt;
      torch::optional<int64_t> stop_index = torch::nullopt;

      // beginless range (..n): keep the start open instead of converting nil
      Object begin = obj.call("begin");
      if (!begin.is_nil()) {
        start_index = from_ruby<int64_t>(begin);
      }

      // endless range (n.. or n...): keep the stop open
      Object end = obj.call("end");
      if (!end.is_nil()) {
        stop_index = from_ruby<int64_t>(end);
      }

      // only an explicit, inclusive end needs adjusting
      Object exclude_end = obj.call("exclude_end?");
      if (stop_index.has_value() && !exclude_end) {
        if (stop_index.value() == -1) {
          stop_index = torch::nullopt;
        } else {
          stop_index = stop_index.value() + 1;
        }
      }

      indices.push_back(torch::indexing::Slice(start_index, stop_index));
    } else if (obj.is_instance_of(rb_cTensor)) {
      indices.push_back(from_ruby<Tensor>(obj));
    } else if (obj.is_nil()) {
      indices.push_back(torch::indexing::None);
    } else if (obj == True || obj == False) {
      indices.push_back(from_ruby<bool>(obj));
    } else {
      throw Exception(rb_eArgError, "Unsupported index type: %s", rb_obj_classname(obj));
    }
  }
  return indices;
}
|
@@ -45,9 +82,10 @@ void Init_ext()
|
|
45
82
|
rb_mTorch.add_handler<torch::Error>(handle_error);
|
46
83
|
add_torch_functions(rb_mTorch);
|
47
84
|
|
48
|
-
|
85
|
+
rb_cTensor = define_class_under<torch::Tensor>(rb_mTorch, "Tensor");
|
49
86
|
rb_cTensor.add_handler<torch::Error>(handle_error);
|
50
87
|
add_tensor_functions(rb_cTensor);
|
88
|
+
THPVariableClass = rb_cTensor.value();
|
51
89
|
|
52
90
|
Module rb_mNN = define_module_under(rb_mTorch, "NN");
|
53
91
|
rb_mNN.add_handler<torch::Error>(handle_error);
|
@@ -68,13 +106,6 @@ void Init_ext()
|
|
68
106
|
return generator.seed();
|
69
107
|
});
|
70
108
|
|
71
|
-
Class rb_cTensorIndex = define_class_under<TensorIndex>(rb_mTorch, "TensorIndex")
|
72
|
-
.define_singleton_method("boolean", *[](bool value) { return TensorIndex(value); })
|
73
|
-
.define_singleton_method("integer", *[](int64_t value) { return TensorIndex(value); })
|
74
|
-
.define_singleton_method("tensor", *[](torch::Tensor& value) { return TensorIndex(value); })
|
75
|
-
.define_singleton_method("slice", *[](torch::optional<int64_t> start_index, torch::optional<int64_t> stop_index) { return TensorIndex(torch::indexing::Slice(start_index, stop_index)); })
|
76
|
-
.define_singleton_method("none", *[]() { return TensorIndex(torch::indexing::None); });
|
77
|
-
|
78
109
|
// https://pytorch.org/cppdocs/api/structc10_1_1_i_value.html
|
79
110
|
Class rb_cIValue = define_class_under<torch::IValue>(rb_mTorch, "IValue")
|
80
111
|
.add_handler<torch::Error>(handle_error)
|
@@ -224,67 +255,6 @@ void Init_ext()
|
|
224
255
|
*[] {
|
225
256
|
return torch::get_parallel_info();
|
226
257
|
})
|
227
|
-
// begin tensor creation
|
228
|
-
.define_singleton_method(
|
229
|
-
"_arange",
|
230
|
-
*[](Scalar start, Scalar end, Scalar step, const torch::TensorOptions &options) {
|
231
|
-
return torch::arange(start, end, step, options);
|
232
|
-
})
|
233
|
-
.define_singleton_method(
|
234
|
-
"_empty",
|
235
|
-
*[](std::vector<int64_t> size, const torch::TensorOptions &options) {
|
236
|
-
return torch::empty(size, options);
|
237
|
-
})
|
238
|
-
.define_singleton_method(
|
239
|
-
"_eye",
|
240
|
-
*[](int64_t m, int64_t n, const torch::TensorOptions &options) {
|
241
|
-
return torch::eye(m, n, options);
|
242
|
-
})
|
243
|
-
.define_singleton_method(
|
244
|
-
"_full",
|
245
|
-
*[](std::vector<int64_t> size, Scalar fill_value, const torch::TensorOptions& options) {
|
246
|
-
return torch::full(size, fill_value, options);
|
247
|
-
})
|
248
|
-
.define_singleton_method(
|
249
|
-
"_linspace",
|
250
|
-
*[](Scalar start, Scalar end, int64_t steps, const torch::TensorOptions& options) {
|
251
|
-
return torch::linspace(start, end, steps, options);
|
252
|
-
})
|
253
|
-
.define_singleton_method(
|
254
|
-
"_logspace",
|
255
|
-
*[](Scalar start, Scalar end, int64_t steps, double base, const torch::TensorOptions& options) {
|
256
|
-
return torch::logspace(start, end, steps, base, options);
|
257
|
-
})
|
258
|
-
.define_singleton_method(
|
259
|
-
"_ones",
|
260
|
-
*[](std::vector<int64_t> size, const torch::TensorOptions &options) {
|
261
|
-
return torch::ones(size, options);
|
262
|
-
})
|
263
|
-
.define_singleton_method(
|
264
|
-
"_rand",
|
265
|
-
*[](std::vector<int64_t> size, const torch::TensorOptions &options) {
|
266
|
-
return torch::rand(size, options);
|
267
|
-
})
|
268
|
-
.define_singleton_method(
|
269
|
-
"_randint",
|
270
|
-
*[](int64_t low, int64_t high, std::vector<int64_t> size, const torch::TensorOptions &options) {
|
271
|
-
return torch::randint(low, high, size, options);
|
272
|
-
})
|
273
|
-
.define_singleton_method(
|
274
|
-
"_randn",
|
275
|
-
*[](std::vector<int64_t> size, const torch::TensorOptions &options) {
|
276
|
-
return torch::randn(size, options);
|
277
|
-
})
|
278
|
-
.define_singleton_method(
|
279
|
-
"_randperm",
|
280
|
-
*[](int64_t n, const torch::TensorOptions &options) {
|
281
|
-
return torch::randperm(n, options);
|
282
|
-
})
|
283
|
-
.define_singleton_method(
|
284
|
-
"_zeros",
|
285
|
-
*[](std::vector<int64_t> size, const torch::TensorOptions &options) {
|
286
|
-
return torch::zeros(size, options);
|
287
|
-
})
|
288
258
|
// begin operations
|
289
259
|
.define_singleton_method(
|
290
260
|
"_save",
|
@@ -352,6 +322,15 @@ void Init_ext()
|
|
352
322
|
}
|
353
323
|
return a;
|
354
324
|
})
|
325
|
+
.define_method(
|
326
|
+
"_strides",
|
327
|
+
*[](Tensor& self) {
|
328
|
+
Array a;
|
329
|
+
for (auto &stride : self.strides()) {
|
330
|
+
a.push(stride);
|
331
|
+
}
|
332
|
+
return a;
|
333
|
+
})
|
355
334
|
.define_method(
|
356
335
|
"_index",
|
357
336
|
*[](Tensor& self, Array indices) {
|
data/ext/torch/extconf.rb
CHANGED
@@ -69,8 +69,8 @@ end
|
|
69
69
|
|
70
70
|
# generate C++ functions
|
71
71
|
puts "Generating C++ functions..."
|
72
|
-
require_relative "../../
|
73
|
-
|
72
|
+
require_relative "../../codegen/generate_functions"
|
73
|
+
generate_functions
|
74
74
|
|
75
75
|
# create makefile
|
76
76
|
create_makefile("torch/ext")
|
@@ -0,0 +1,593 @@
|
|
1
|
+
// adapted from PyTorch - python_arg_parser.cpp
|
2
|
+
|
3
|
+
#include "ruby_arg_parser.h"
|
4
|
+
|
5
|
+
VALUE THPVariableClass = Qnil;
|
6
|
+
|
7
|
+
static std::unordered_map<std::string, ParameterType> type_map = {
|
8
|
+
{"Tensor", ParameterType::TENSOR},
|
9
|
+
{"Scalar", ParameterType::SCALAR},
|
10
|
+
{"int64_t", ParameterType::INT64},
|
11
|
+
{"double", ParameterType::DOUBLE},
|
12
|
+
{"complex", ParameterType::COMPLEX},
|
13
|
+
{"TensorList", ParameterType::TENSOR_LIST},
|
14
|
+
{"IntArrayRef", ParameterType::INT_LIST},
|
15
|
+
{"ArrayRef<double>", ParameterType::FLOAT_LIST},
|
16
|
+
{"Generator", ParameterType::GENERATOR},
|
17
|
+
{"bool", ParameterType::BOOL},
|
18
|
+
{"Storage", ParameterType::STORAGE},
|
19
|
+
// {"PyObject*", ParameterType::PYOBJECT},
|
20
|
+
{"ScalarType", ParameterType::SCALARTYPE},
|
21
|
+
{"Layout", ParameterType::LAYOUT},
|
22
|
+
{"MemoryFormat", ParameterType::MEMORY_FORMAT},
|
23
|
+
{"QScheme", ParameterType::QSCHEME},
|
24
|
+
{"Device", ParameterType::DEVICE},
|
25
|
+
{"std::string", ParameterType::STRING},
|
26
|
+
{"Dimname", ParameterType::DIMNAME},
|
27
|
+
{"DimnameList", ParameterType::DIMNAME_LIST},
|
28
|
+
};
|
29
|
+
|
30
|
+
static const std::unordered_map<std::string, std::vector<std::string>> numpy_compatibility_arg_names = {
|
31
|
+
{"dim", {"axis"}},
|
32
|
+
{"keepdim", {"keepdims"}},
|
33
|
+
{"input", {"x", "a", "x1"}},
|
34
|
+
{"other", {"x2"}},
|
35
|
+
};
|
36
|
+
|
37
|
+
// Returns true when `name` is one of the arithmetic functions listed below.
// FunctionSignature uses the result to set allow_numbers_as_tensors on each
// of its parameters.
static bool should_allow_numbers_as_tensors(const std::string& name) {
  // Built once on first call and never modified afterwards, hence const.
  static const std::unordered_set<std::string> allowed = {
    "add", "add_", "add_out",
    "div", "div_", "div_out",
    "mul", "mul_", "mul_out",
    "sub", "sub_", "sub_out",
    "true_divide", "true_divide_", "true_divide_out",
    "floor_divide", "floor_divide_", "floor_divide_out"
  };
  return allowed.find(name) != allowed.end();
}
|
48
|
+
|
49
|
+
// Builds a parameter description from a "<type> <name>[=<default>]" string.
//
// The type is everything before the first space. A '?' in it sets
// allow_none, and a bracketed number (e.g. IntArrayRef[2]) is stored in
// `size`. The remaining type text must be a key of type_map. A name
// followed by '=' makes the parameter optional and passes the text after
// '=' to set_default_str().
//
// Throws std::runtime_error when the space separator is missing or the type
// is unknown.
FunctionParameter::FunctionParameter(const std::string& fmt, bool keyword_only)
  : optional(false)
  , allow_none(false)
  , keyword_only(keyword_only)
  , size(0)
  , default_scalar(0)
{
  const auto sep = fmt.find(' ');
  if (sep == std::string::npos) {
    throw std::runtime_error("FunctionParameter(): missing type: " + fmt);
  }

  std::string type_part = fmt.substr(0, sep);
  const std::string name_part = fmt.substr(sep + 1);

  // "Type?" -> nil is accepted; drop the marker and anything after it
  const auto qmark = type_part.find('?');
  if (qmark != std::string::npos) {
    allow_none = true;
    type_part.erase(qmark);
  }

  // "Type[N]" -> remember N and strip the bracketed suffix
  const auto lbracket = type_part.find('[');
  if (lbracket != std::string::npos) {
    const auto digits = type_part.substr(lbracket + 1, type_part.length() - lbracket - 2);
    size = atoi(digits.c_str());
    type_part.erase(lbracket);
  }

  const auto entry = type_map.find(type_part);
  if (entry == type_map.end()) {
    throw std::runtime_error("FunctionParameter(): invalid type string: " + type_part);
  }
  type_ = entry->second;

  const auto eq_pos = name_part.find('=');
  if (eq_pos == std::string::npos) {
    name = name_part;
  } else {
    name = name_part.substr(0, eq_pos);
    optional = true;
    set_default_str(name_part.substr(eq_pos + 1));
  }

  ruby_name = THPUtils_internSymbol(name);

  // register the alternative (numpy-style) spellings of this name, if any
  const auto aliases = numpy_compatibility_arg_names.find(name);
  if (aliases != numpy_compatibility_arg_names.end()) {
    for (const auto& alias : aliases->second) {
      numpy_python_names.push_back(THPUtils_internSymbol(alias));
    }
  }
}
|
100
|
+
|
101
|
+
// Returns true when `obj` is a Ruby Array whose every element passes
// THPVariable_Check.
//
// A non-array always yields false without raising. When an element fails the
// check and `throw_error` is set, ArgumentError is raised naming the element
// index, `argnum` and the class of the offending element; otherwise false is
// returned.
bool is_tensor_list(VALUE obj, int argnum, bool throw_error) {
  if (!RB_TYPE_P(obj, T_ARRAY)) {
    return false;
  }
  const long size = RARRAY_LEN(obj);
  for (long idx = 0; idx < size; idx++) {
    VALUE iobj = rb_ary_entry(obj, idx);
    if (!THPVariable_Check(iobj)) {
      if (throw_error) {
        // report the class of the bad element, not of the enclosing array
        rb_raise(rb_eArgError, "expected Tensor as element %d in argument %d, but got %s",
          static_cast<int>(idx), argnum, rb_obj_classname(iobj));
      }
      return false;
    }
  }
  return true;
}
|
118
|
+
|
119
|
+
// argnum is needed for raising the TypeError, it's used in the error message.
|
120
|
+
auto FunctionParameter::check(VALUE obj, int argnum) -> bool
|
121
|
+
{
|
122
|
+
switch (type_) {
|
123
|
+
case ParameterType::TENSOR: {
|
124
|
+
if (THPVariable_Check(obj)) {
|
125
|
+
return true;
|
126
|
+
}
|
127
|
+
return allow_numbers_as_tensors && THPUtils_checkScalar(obj);
|
128
|
+
}
|
129
|
+
case ParameterType::SCALAR:
|
130
|
+
case ParameterType::COMPLEX:
|
131
|
+
if (RB_TYPE_P(obj, T_COMPLEX)) {
|
132
|
+
return true;
|
133
|
+
}
|
134
|
+
// fallthrough
|
135
|
+
case ParameterType::DOUBLE: {
|
136
|
+
if (RB_FLOAT_TYPE_P(obj) || FIXNUM_P(obj)) {
|
137
|
+
return true;
|
138
|
+
}
|
139
|
+
if (THPVariable_Check(obj)) {
|
140
|
+
auto var = from_ruby<torch::Tensor>(obj);
|
141
|
+
return !var.requires_grad() && var.dim() == 0;
|
142
|
+
}
|
143
|
+
return false;
|
144
|
+
}
|
145
|
+
case ParameterType::INT64: {
|
146
|
+
if (FIXNUM_P(obj)) {
|
147
|
+
return true;
|
148
|
+
}
|
149
|
+
if (THPVariable_Check(obj)) {
|
150
|
+
auto var = from_ruby<torch::Tensor>(obj);
|
151
|
+
return at::isIntegralType(var.scalar_type(), /*includeBool=*/false) && !var.requires_grad() && var.dim() == 0;
|
152
|
+
}
|
153
|
+
return false;
|
154
|
+
}
|
155
|
+
case ParameterType::DIMNAME: return false; // return THPUtils_checkDimname(obj);
|
156
|
+
case ParameterType::DIMNAME_LIST: {
|
157
|
+
return false;
|
158
|
+
// if (THPUtils_checkDimnameList(obj)) {
|
159
|
+
// return true;
|
160
|
+
// }
|
161
|
+
// // if a size is specified (e.g. DimnameList[1]) we also allow passing a single Dimname
|
162
|
+
// return size == 1 && THPUtils_checkDimname(obj);
|
163
|
+
}
|
164
|
+
case ParameterType::TENSOR_LIST: {
|
165
|
+
return is_tensor_list(obj, argnum, true /* throw_error */);
|
166
|
+
}
|
167
|
+
case ParameterType::INT_LIST: {
|
168
|
+
if (RB_TYPE_P(obj, T_ARRAY)) {
|
169
|
+
return true;
|
170
|
+
}
|
171
|
+
// if a size is specified (e.g. IntArrayRef[2]) we also allow passing a single int
|
172
|
+
return size > 0 && FIXNUM_P(obj);
|
173
|
+
}
|
174
|
+
case ParameterType::FLOAT_LIST: return (RB_TYPE_P(obj, T_ARRAY));
|
175
|
+
case ParameterType::GENERATOR: return false; // return THPGenerator_Check(obj);
|
176
|
+
case ParameterType::BOOL: return obj == Qtrue || obj == Qfalse;
|
177
|
+
case ParameterType::STORAGE: return false; // return isStorage(obj);
|
178
|
+
// case ParameterType::PYOBJECT: return true;
|
179
|
+
case ParameterType::SCALARTYPE: return SYMBOL_P(obj);
|
180
|
+
case ParameterType::LAYOUT: return SYMBOL_P(obj);
|
181
|
+
case ParameterType::MEMORY_FORMAT: return false; // return THPMemoryFormat_Check(obj);
|
182
|
+
case ParameterType::QSCHEME: return false; // return THPQScheme_Check(obj);
|
183
|
+
case ParameterType::DEVICE: return RB_TYPE_P(obj, T_STRING); // TODO check device
|
184
|
+
case ParameterType::STRING: return RB_TYPE_P(obj, T_STRING);
|
185
|
+
default: throw std::runtime_error("unknown parameter type");
|
186
|
+
}
|
187
|
+
}
|
188
|
+
|
189
|
+
std::string FunctionParameter::type_name() const {
|
190
|
+
switch (type_) {
|
191
|
+
case ParameterType::TENSOR: return "Tensor";
|
192
|
+
case ParameterType::SCALAR: return "Number";
|
193
|
+
case ParameterType::INT64: return "int";
|
194
|
+
case ParameterType::DOUBLE: return "float";
|
195
|
+
case ParameterType::COMPLEX: return "complex";
|
196
|
+
case ParameterType::TENSOR_LIST: return "array of Tensors";
|
197
|
+
case ParameterType::INT_LIST: return "array of ints";
|
198
|
+
case ParameterType::FLOAT_LIST: return "array of floats";
|
199
|
+
case ParameterType::GENERATOR: return "torch.Generator";
|
200
|
+
case ParameterType::BOOL: return "bool";
|
201
|
+
case ParameterType::STORAGE: return "torch.Storage";
|
202
|
+
// case ParameterType::PYOBJECT: return "object";
|
203
|
+
case ParameterType::SCALARTYPE: return "torch.dtype";
|
204
|
+
case ParameterType::LAYOUT: return "torch.layout";
|
205
|
+
case ParameterType::MEMORY_FORMAT: return "torch.memory_format";
|
206
|
+
case ParameterType::QSCHEME: return "torch.qscheme";
|
207
|
+
case ParameterType::DEVICE: return "torch.device";
|
208
|
+
case ParameterType::STRING: return "str";
|
209
|
+
case ParameterType::DIMNAME: return "name";
|
210
|
+
case ParameterType::DIMNAME_LIST: return "array of names";
|
211
|
+
default: throw std::runtime_error("unknown parameter type");
|
212
|
+
}
|
213
|
+
}
|
214
|
+
|
215
|
+
// Parses `s` as an integer, with base auto-detection (base argument 0).
// Returns nullopt when `s` is empty or is not consumed entirely.
static inline c10::optional<int64_t> parse_as_integer(const std::string& s) {
  if (s.empty())
    return c10::nullopt;
  char *str_end;
  // strtoll rather than strtol: `long` is only 32 bits on LLP64 platforms,
  // while the result is stored as int64_t.
  long long ans = strtoll(s.c_str(), &str_end, 0);
  // *str_end == 0 if the entire string was parsed as an integer.
  return (*str_end == 0) ? c10::optional<int64_t>(static_cast<int64_t>(ans)) : c10::nullopt;
}
|
223
|
+
|
224
|
+
/*
|
225
|
+
Parse default value of IntArrayRef declared at native_functions.yaml
|
226
|
+
|
227
|
+
There are two kinds of default values:
|
228
|
+
1. IntArrayRef[2] x=1 (where size=2, value={1,1}
|
229
|
+
2. IntArrayRef x={1,2,3} (where size=3, value={1,2,3}, note that there cannot be space after comma since native_parse.py uses ', ' to split args)
|
230
|
+
*/
|
231
|
+
static inline std::vector<int64_t> parse_intlist_args(const std::string& s, int64_t size) {
|
232
|
+
size_t n = s.size();
|
233
|
+
|
234
|
+
if (s.empty()) return std::vector<int64_t>();
|
235
|
+
|
236
|
+
// case 1. s is an int (e.g., s=2)
|
237
|
+
if (s[0] != '{') {
|
238
|
+
return std::vector<int64_t>(size, std::stol(s));
|
239
|
+
}
|
240
|
+
|
241
|
+
// case 2. s is a list of dims (e.g., s={1,2})
|
242
|
+
|
243
|
+
// since already checked left brace '{' above, here only checks right brace '}'
|
244
|
+
TORCH_CHECK(s[n - 1] == '}', "Default value of IntArrayRef is missing right brace '}', found ", s[n - 1]);
|
245
|
+
|
246
|
+
auto args = std::vector<int64_t>();
|
247
|
+
std::istringstream ss(s.substr(1, s.length() - 2)); // exclude '{' and '}'
|
248
|
+
std::string tok;
|
249
|
+
|
250
|
+
while(std::getline(ss, tok, ',')) {
|
251
|
+
args.emplace_back(std::stol(tok));
|
252
|
+
}
|
253
|
+
return args;
|
254
|
+
}
|
255
|
+
|
256
|
+
void FunctionParameter::set_default_str(const std::string& str) {
|
257
|
+
if (str == "None") {
|
258
|
+
allow_none = true;
|
259
|
+
}
|
260
|
+
if (type_ == ParameterType::TENSOR) {
|
261
|
+
if (str != "None") {
|
262
|
+
throw std::runtime_error("default value for Tensor must be none, got: " + str);
|
263
|
+
}
|
264
|
+
} else if (type_ == ParameterType::INT64) {
|
265
|
+
default_int = atol(str.c_str());
|
266
|
+
} else if (type_ == ParameterType::BOOL) {
|
267
|
+
default_bool = (str == "True" || str == "true");
|
268
|
+
} else if (type_ == ParameterType::DOUBLE) {
|
269
|
+
default_double = atof(str.c_str());
|
270
|
+
} else if (type_ == ParameterType::COMPLEX) {
|
271
|
+
default_complex[0] = atof(str.c_str()); // TODO: parse "x + xj"?
|
272
|
+
default_complex[1] = 0;
|
273
|
+
} else if (type_ == ParameterType::SCALAR) {
|
274
|
+
if (str != "None") {
|
275
|
+
// we sometimes rely on integer-vs-float values, e.g. with arange.
|
276
|
+
const auto as_integer = parse_as_integer(str);
|
277
|
+
default_scalar = as_integer.has_value() ? at::Scalar(as_integer.value()) :
|
278
|
+
at::Scalar(atof(str.c_str()));
|
279
|
+
}
|
280
|
+
} else if (type_ == ParameterType::INT_LIST) {
|
281
|
+
if (str != "None") {
|
282
|
+
default_intlist = parse_intlist_args(str, size);
|
283
|
+
}
|
284
|
+
} else if (type_ == ParameterType::FLOAT_LIST) {
|
285
|
+
if (str != "None") {
|
286
|
+
throw std::runtime_error("Defaults not supported for float[]");
|
287
|
+
}
|
288
|
+
} else if (type_ == ParameterType::SCALARTYPE) {
|
289
|
+
if (str == "None") {
|
290
|
+
default_scalartype = at::ScalarType::Undefined;
|
291
|
+
} else if (str == "torch.int64") {
|
292
|
+
default_scalartype = at::ScalarType::Long;
|
293
|
+
} else {
|
294
|
+
throw std::runtime_error("invalid default value for ScalarType: " + str);
|
295
|
+
}
|
296
|
+
} else if (type_ == ParameterType::LAYOUT) {
|
297
|
+
if (str == "None") {
|
298
|
+
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(allow_none);
|
299
|
+
} else if (str == "torch.strided") {
|
300
|
+
default_layout = at::Layout::Strided;
|
301
|
+
} else if (str == "torch.sparse_coo") {
|
302
|
+
default_layout = at::Layout::Sparse;
|
303
|
+
} else {
|
304
|
+
throw std::runtime_error("invalid default value for layout: " + str);
|
305
|
+
}
|
306
|
+
} else if (type_ == ParameterType::DEVICE) {
|
307
|
+
if (str != "None") {
|
308
|
+
throw std::runtime_error("invalid device: " + str);
|
309
|
+
}
|
310
|
+
} else if (type_ == ParameterType::STRING) {
|
311
|
+
if (str != "None" && str != "") {
|
312
|
+
throw std::runtime_error("invalid default string: " + str);
|
313
|
+
}
|
314
|
+
}
|
315
|
+
}
|
316
|
+
|
317
|
+
// Parses a signature string of the form "name(param, param, *, param)" with
// an optional "|deprecated" or "|hidden" suffix.
//
// Parameters are separated by ", ". A bare "*" makes every later parameter
// keyword-only. min_args counts non-optional parameters, max_pos_args counts
// parameters that are not keyword-only, and max_args counts all of them.
//
// Throws std::runtime_error when a parenthesis is missing or a parameter is
// empty.
FunctionSignature::FunctionSignature(const std::string& fmt, int index)
  : min_args(0)
  , max_args(0)
  , max_pos_args(0)
  , index(index)
  , hidden(false)
  , deprecated(false)
{
  const auto paren = fmt.find('(');
  if (paren == std::string::npos) {
    throw std::runtime_error("missing opening parenthesis: " + fmt);
  }
  name = fmt.substr(0, paren);

  const bool numbers_ok = should_allow_numbers_as_tensors(name);

  auto cursor = paren + 1;  // start of the parameter currently being read
  bool kw_only = false;
  for (bool finished = false; !finished;) {
    auto stop = fmt.find(", ", cursor);
    std::string::size_type resume = 0;
    if (stop != std::string::npos) {
      resume = stop + 2;
    } else {
      // no further separator: the last parameter runs up to ')'
      stop = fmt.find(')', cursor);
      finished = true;
      resume = stop + 1;
      // this 'if' happens for an empty parameter list, i.e. fn().
      if (stop == cursor) {
        cursor = resume;
        break;
      }
    }
    if (stop == std::string::npos) {
      throw std::runtime_error("missing closing parenthesis: " + fmt);
    }
    if (stop == cursor) {
      throw std::runtime_error("malformed signature: " + fmt);
    }

    const auto token = fmt.substr(cursor, stop - cursor);
    cursor = resume;
    if (token == "*") {
      kw_only = true;
      continue;
    }
    params.emplace_back(token, kw_only);
    params.back().allow_numbers_as_tensors = numbers_ok;
  }

  const auto suffix = fmt.substr(cursor);
  if (suffix == "|deprecated") {
    hidden = true;
    // TODO: raise warning when parsing deprecated signatures
    deprecated = true;
  } else if (suffix == "|hidden") {
    hidden = true;
  }

  max_args = params.size();

  // tally required and positional parameters
  for (const auto& p : params) {
    if (!p.optional) {
      ++min_args;
    }
    if (!p.keyword_only) {
      ++max_pos_args;
    }
  }
}
|
388
|
+
|
389
|
+
std::string FunctionSignature::toString() const {
|
390
|
+
// TODO: consider printing more proper schema strings with defaults, optionals, etc.
|
391
|
+
std::ostringstream ss;
|
392
|
+
bool keyword_already = false;
|
393
|
+
ss << "(";
|
394
|
+
int i = 0;
|
395
|
+
for (auto& param : params) {
|
396
|
+
if (i != 0) {
|
397
|
+
ss << ", ";
|
398
|
+
}
|
399
|
+
if (param.keyword_only && !keyword_already) {
|
400
|
+
ss << "*, ";
|
401
|
+
keyword_already = true;
|
402
|
+
}
|
403
|
+
ss << param.type_name() << " " << param.name;
|
404
|
+
i++;
|
405
|
+
}
|
406
|
+
ss << ")";
|
407
|
+
return ss.str();
|
408
|
+
}
|
409
|
+
|
410
|
+
[[noreturn]]
|
411
|
+
static void extra_args(const FunctionSignature& signature, ssize_t nargs) {
|
412
|
+
const long max_pos_args = signature.max_pos_args;
|
413
|
+
const long min_args = signature.min_args;
|
414
|
+
const long nargs_ = nargs;
|
415
|
+
if (min_args != max_pos_args) {
|
416
|
+
rb_raise(rb_eArgError, "%s() takes from %ld to %ld positional arguments but %ld were given",
|
417
|
+
signature.name.c_str(), min_args, max_pos_args, nargs_);
|
418
|
+
}
|
419
|
+
rb_raise(rb_eArgError, "%s() takes %ld positional argument%s but %ld %s given",
|
420
|
+
signature.name.c_str(),
|
421
|
+
max_pos_args, max_pos_args == 1 ? "" : "s",
|
422
|
+
nargs_, nargs == 1 ? "was" : "were");
|
423
|
+
}
|
424
|
+
|
425
|
+
[[noreturn]]
|
426
|
+
static void missing_args(const FunctionSignature& signature, int idx) {
|
427
|
+
int num_missing = 0;
|
428
|
+
std::stringstream ss;
|
429
|
+
|
430
|
+
auto& params = signature.params;
|
431
|
+
for (auto it = params.begin() + idx; it != params.end(); ++it) {
|
432
|
+
if (!it->optional) {
|
433
|
+
if (num_missing > 0) {
|
434
|
+
ss << ", ";
|
435
|
+
}
|
436
|
+
ss << '"' << it->name << '"';
|
437
|
+
num_missing++;
|
438
|
+
}
|
439
|
+
}
|
440
|
+
|
441
|
+
rb_raise(rb_eArgError, "%s() missing %d required positional argument%s: %s",
|
442
|
+
signature.name.c_str(),
|
443
|
+
num_missing,
|
444
|
+
num_missing == 1 ? "s" : "",
|
445
|
+
ss.str().c_str());
|
446
|
+
}
|
447
|
+
|
448
|
+
// Returns the position of the parameter whose ruby_name equals `name`, or -1
// when no parameter of `signature` matches.
static ssize_t find_param(FunctionSignature& signature, VALUE name) {
  auto& params = signature.params;
  for (size_t pos = 0; pos < params.size(); ++pos) {
    if (name == params[pos].ruby_name) {
      return static_cast<ssize_t>(pos);
    }
  }
  return -1;
}
|
459
|
+
|
460
|
+
[[noreturn]]
|
461
|
+
static void extra_kwargs(FunctionSignature& signature, VALUE kwargs, ssize_t num_pos_args) {
|
462
|
+
VALUE key;
|
463
|
+
|
464
|
+
VALUE keys = rb_funcall(kwargs, rb_intern("keys"), 0);
|
465
|
+
if (RARRAY_LEN(keys) > 0) {
|
466
|
+
key = rb_ary_entry(keys, 0);
|
467
|
+
|
468
|
+
if (!THPUtils_checkSymbol(key)) {
|
469
|
+
rb_raise(rb_eArgError, "keywords must be symbols, not %s", rb_obj_classname(key));
|
470
|
+
}
|
471
|
+
|
472
|
+
auto param_idx = find_param(signature, key);
|
473
|
+
if (param_idx < 0) {
|
474
|
+
rb_raise(rb_eArgError, "%s() got an unexpected keyword argument '%s'",
|
475
|
+
signature.name.c_str(), THPUtils_unpackSymbol(key).c_str());
|
476
|
+
}
|
477
|
+
|
478
|
+
if (param_idx < num_pos_args) {
|
479
|
+
rb_raise(rb_eArgError, "%s() got multiple values for argument '%s'",
|
480
|
+
signature.name.c_str(), THPUtils_unpackSymbol(key).c_str());
|
481
|
+
}
|
482
|
+
}
|
483
|
+
|
484
|
+
// this should never be hit
|
485
|
+
rb_raise(rb_eArgError, "invalid keyword arguments");
|
486
|
+
}
|
487
|
+
|
488
|
+
VALUE missing = Qundef;
|
489
|
+
|
490
|
+
// Matches positional `args` (Ruby Array or nil) and `kwargs` (Ruby Hash or
// nil) against this signature, writing one VALUE per parameter into `dst`
// in parameter order. Parameters that were omitted but are optional, or
// that were given as nil where nil is allowed, are stored as Qnil.
//
// Returns true on a match. On a mismatch it returns false, or raises
// ArgumentError when `raise_exception` is set. `dst` is indexed without
// bounds checks, so it must already hold at least one slot per parameter.
// `self` is not used.
bool FunctionSignature::parse(VALUE self, VALUE args, VALUE kwargs, std::vector<VALUE> &dst,  // NOLINT
                              bool raise_exception) {
  const bool have_kwargs = !NIL_P(kwargs);
  auto nargs = NIL_P(args) ? 0 : RARRAY_LEN(args);
  ssize_t remaining_kwargs = have_kwargs ? RHASH_SIZE(kwargs) : 0;
  ssize_t arg_pos = 0;

  // if there is a single positional IntArrayRef argument, i.e. expand(..), view(...),
  // allow a var-args style IntArrayRef, so expand(5,3) behaves as expand((5,3))
  const bool allow_varargs_intlist =
      max_pos_args == 1 && params[0].type_ == ParameterType::INT_LIST;

  if (!allow_varargs_intlist && nargs > max_pos_args) {
    if (raise_exception) {
      // foo() takes 2 positional arguments but 3 were given
      extra_args(*this, nargs);
    }
    return false;
  }

  int slot = 0;  // next index of dst to fill
  for (auto& param : params) {
    VALUE obj = missing;
    bool is_kwd = false;
    if (arg_pos < nargs) {
      // extra positional args given after single positional IntArrayRef arg
      if (param.keyword_only) {
        if (raise_exception) {
          extra_args(*this, nargs);
        }
        return false;
      }
      obj = rb_ary_entry(args, arg_pos);
    } else if (have_kwargs) {
      // only the primary name is looked up; the alternative names collected
      // by FunctionParameter are not consulted here
      obj = rb_hash_lookup2(kwargs, param.ruby_name, missing);
      is_kwd = true;
    }

    const bool absent = (obj == missing);
    if ((absent && param.optional) || (NIL_P(obj) && param.allow_none)) {
      dst[slot++] = Qnil;
    } else if (absent) {
      if (raise_exception) {
        // foo() missing 1 required positional argument: "b"
        missing_args(*this, slot);
      }
      return false;
    } else if (param.check(obj, slot)) {
      dst[slot++] = obj;
    } else if (allow_varargs_intlist && arg_pos == 0 && !is_kwd &&
               THPUtils_checkIndex(obj)) {
      // take all positional arguments as this parameter
      // e.g. permute(1, 2, 3) -> permute((1, 2, 3))
      dst[slot++] = args;
      arg_pos = nargs;
      continue;
    } else if (!raise_exception) {
      return false;
    } else if (is_kwd) {
      // foo(): argument 'other' must be str, not int
      rb_raise(rb_eArgError, "%s(): argument '%s' must be %s, not %s",
          name.c_str(), param.name.c_str(), param.type_name().c_str(),
          rb_obj_classname(obj));
    } else {
      // foo(): argument 'other' (position 2) must be str, not int
      rb_raise(rb_eArgError, "%s(): argument '%s' (position %ld) must be %s, not %s",
          name.c_str(), param.name.c_str(), static_cast<long>(arg_pos + 1),
          param.type_name().c_str(), rb_obj_classname(obj));
    }

    // account for what was consumed: a positional slot or a keyword entry
    if (is_kwd) {
      if (!absent) {
        remaining_kwargs--;
      }
    } else {
      arg_pos++;
    }
  }

  if (remaining_kwargs > 0) {
    if (raise_exception) {
      // foo() got an unexpected keyword argument "b"
      extra_kwargs(*this, kwargs, nargs);
    }
    return false;
  }
  return true;
}
|