torch-rb 0.3.4 → 0.4.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +28 -0
- data/README.md +2 -1
- data/codegen/function.rb +134 -0
- data/codegen/generate_functions.rb +549 -0
- data/{lib/torch/native → codegen}/native_functions.yaml +0 -0
- data/ext/torch/ext.cpp +76 -87
- data/ext/torch/extconf.rb +5 -2
- data/ext/torch/nn_functions.h +6 -0
- data/ext/torch/ruby_arg_parser.cpp +593 -0
- data/ext/torch/ruby_arg_parser.h +373 -0
- data/ext/torch/{templates.hpp → templates.h} +87 -97
- data/ext/torch/tensor_functions.h +6 -0
- data/ext/torch/torch_functions.h +6 -0
- data/ext/torch/utils.h +42 -0
- data/ext/torch/{templates.cpp → wrap_outputs.h} +44 -7
- data/lib/torch.rb +51 -77
- data/lib/torch/nn/functional.rb +142 -18
- data/lib/torch/nn/init.rb +5 -19
- data/lib/torch/nn/leaky_relu.rb +3 -3
- data/lib/torch/nn/module.rb +9 -1
- data/lib/torch/nn/upsample.rb +31 -0
- data/lib/torch/optim/adadelta.rb +1 -1
- data/lib/torch/optim/adam.rb +2 -2
- data/lib/torch/optim/adamax.rb +1 -1
- data/lib/torch/optim/adamw.rb +1 -1
- data/lib/torch/optim/asgd.rb +1 -1
- data/lib/torch/optim/sgd.rb +3 -3
- data/lib/torch/tensor.rb +36 -115
- data/lib/torch/utils/data/data_loader.rb +2 -0
- data/lib/torch/utils/data/tensor_dataset.rb +2 -0
- data/lib/torch/version.rb +1 -1
- metadata +19 -14
- data/lib/torch/native/dispatcher.rb +0 -48
- data/lib/torch/native/function.rb +0 -115
- data/lib/torch/native/generator.rb +0 -163
- data/lib/torch/native/parser.rb +0 -140
File without changes
|
data/ext/torch/ext.cpp
CHANGED
@@ -7,13 +7,14 @@
|
|
7
7
|
#include <rice/Constructor.hpp>
|
8
8
|
#include <rice/Hash.hpp>
|
9
9
|
|
10
|
-
#include "templates.
|
10
|
+
#include "templates.h"
|
11
|
+
#include "utils.h"
|
11
12
|
|
12
13
|
// generated with:
|
13
14
|
// rake generate:functions
|
14
|
-
#include "torch_functions.
|
15
|
-
#include "tensor_functions.
|
16
|
-
#include "nn_functions.
|
15
|
+
#include "torch_functions.h"
|
16
|
+
#include "tensor_functions.h"
|
17
|
+
#include "nn_functions.h"
|
17
18
|
|
18
19
|
using namespace Rice;
|
19
20
|
using torch::indexing::TensorIndex;
|
@@ -29,11 +30,47 @@ void handle_error(torch::Error const & ex)
|
|
29
30
|
throw Exception(rb_eRuntimeError, ex.what_without_backtrace());
|
30
31
|
}
|
31
32
|
|
33
|
+
// Ruby class wrapping torch::Tensor. Assigned in Init_ext via
// define_class_under<torch::Tensor>(rb_mTorch, "Tensor"); index_vector uses it
// to recognize Tensor elements in an index array.
Class rb_cTensor;
|
34
|
+
|
32
35
|
// Converts the Ruby index array given to Tensor#[] / Tensor#[]= into LibTorch
// TensorIndex values. Supported elements:
//   Integer      -> integer index
//   Range        -> Slice (beginless and endless ranges included)
//   Tensor       -> tensor (mask / advanced) index
//   nil          -> torch::indexing::None
//   true / false -> boolean index
// Any other element raises ArgumentError.
std::vector<TensorIndex> index_vector(Array a) {
  Object obj;

  std::vector<TensorIndex> indices;
  indices.reserve(a.size());

  for (size_t i = 0; i < a.size(); i++) {
    obj = a[i];

    if (obj.is_instance_of(rb_cInteger)) {
      indices.push_back(from_ruby<int64_t>(obj));
    } else if (obj.is_instance_of(rb_cRange)) {
      // nullopt start / stop mean "from the first element" / "through the last"
      torch::optional<int64_t> start_index = torch::nullopt;
      torch::optional<int64_t> stop_index = torch::nullopt;

      // beginless ranges (..n / ...n) start at the beginning
      Object range_begin = obj.call("begin");
      if (!range_begin.is_nil()) {
        start_index = from_ruby<int64_t>(range_begin);
      }

      // endless ranges (n.. and n...) run through the end regardless of
      // exclude_end?, so stop_index stays nullopt for them
      Object range_end = obj.call("end");
      if (!range_end.is_nil()) {
        int64_t stop = from_ruby<int64_t>(range_end);
        Object exclude_end = obj.call("exclude_end?");
        if (!exclude_end) {
          // Slice's stop is exclusive, so an inclusive end steps one past it;
          // an inclusive end of -1 (e.g. 1..-1) means "through the last
          // element", which is expressed by leaving stop_index as nullopt
          if (stop != -1) {
            stop_index = stop + 1;
          }
        } else {
          stop_index = stop;
        }
      }

      indices.push_back(torch::indexing::Slice(start_index, stop_index));
    } else if (obj.is_instance_of(rb_cTensor)) {
      indices.push_back(from_ruby<Tensor>(obj));
    } else if (obj.is_nil()) {
      indices.push_back(torch::indexing::None);
    } else if (obj == True || obj == False) {
      indices.push_back(from_ruby<bool>(obj));
    } else {
      throw Exception(rb_eArgError, "Unsupported index type: %s", rb_obj_classname(obj));
    }
  }
  return indices;
}
|
@@ -45,9 +82,10 @@ void Init_ext()
|
|
45
82
|
rb_mTorch.add_handler<torch::Error>(handle_error);
|
46
83
|
add_torch_functions(rb_mTorch);
|
47
84
|
|
48
|
-
|
85
|
+
rb_cTensor = define_class_under<torch::Tensor>(rb_mTorch, "Tensor");
|
49
86
|
rb_cTensor.add_handler<torch::Error>(handle_error);
|
50
87
|
add_tensor_functions(rb_cTensor);
|
88
|
+
THPVariableClass = rb_cTensor.value();
|
51
89
|
|
52
90
|
Module rb_mNN = define_module_under(rb_mTorch, "NN");
|
53
91
|
rb_mNN.add_handler<torch::Error>(handle_error);
|
@@ -68,13 +106,6 @@ void Init_ext()
|
|
68
106
|
return generator.seed();
|
69
107
|
});
|
70
108
|
|
71
|
-
Class rb_cTensorIndex = define_class_under<TensorIndex>(rb_mTorch, "TensorIndex")
|
72
|
-
.define_singleton_method("boolean", *[](bool value) { return TensorIndex(value); })
|
73
|
-
.define_singleton_method("integer", *[](int64_t value) { return TensorIndex(value); })
|
74
|
-
.define_singleton_method("tensor", *[](torch::Tensor& value) { return TensorIndex(value); })
|
75
|
-
.define_singleton_method("slice", *[](torch::optional<int64_t> start_index, torch::optional<int64_t> stop_index) { return TensorIndex(torch::indexing::Slice(start_index, stop_index)); })
|
76
|
-
.define_singleton_method("none", *[]() { return TensorIndex(torch::indexing::None); });
|
77
|
-
|
78
109
|
// https://pytorch.org/cppdocs/api/structc10_1_1_i_value.html
|
79
110
|
Class rb_cIValue = define_class_under<torch::IValue>(rb_mTorch, "IValue")
|
80
111
|
.add_handler<torch::Error>(handle_error)
|
@@ -224,67 +255,6 @@ void Init_ext()
|
|
224
255
|
*[] {
|
225
256
|
return torch::get_parallel_info();
|
226
257
|
})
|
227
|
-
// begin tensor creation
|
228
|
-
.define_singleton_method(
|
229
|
-
"_arange",
|
230
|
-
*[](Scalar start, Scalar end, Scalar step, const torch::TensorOptions &options) {
|
231
|
-
return torch::arange(start, end, step, options);
|
232
|
-
})
|
233
|
-
.define_singleton_method(
|
234
|
-
"_empty",
|
235
|
-
*[](IntArrayRef size, const torch::TensorOptions &options) {
|
236
|
-
return torch::empty(size, options);
|
237
|
-
})
|
238
|
-
.define_singleton_method(
|
239
|
-
"_eye",
|
240
|
-
*[](int64_t m, int64_t n, const torch::TensorOptions &options) {
|
241
|
-
return torch::eye(m, n, options);
|
242
|
-
})
|
243
|
-
.define_singleton_method(
|
244
|
-
"_full",
|
245
|
-
*[](IntArrayRef size, Scalar fill_value, const torch::TensorOptions& options) {
|
246
|
-
return torch::full(size, fill_value, options);
|
247
|
-
})
|
248
|
-
.define_singleton_method(
|
249
|
-
"_linspace",
|
250
|
-
*[](Scalar start, Scalar end, int64_t steps, const torch::TensorOptions& options) {
|
251
|
-
return torch::linspace(start, end, steps, options);
|
252
|
-
})
|
253
|
-
.define_singleton_method(
|
254
|
-
"_logspace",
|
255
|
-
*[](Scalar start, Scalar end, int64_t steps, double base, const torch::TensorOptions& options) {
|
256
|
-
return torch::logspace(start, end, steps, base, options);
|
257
|
-
})
|
258
|
-
.define_singleton_method(
|
259
|
-
"_ones",
|
260
|
-
*[](IntArrayRef size, const torch::TensorOptions &options) {
|
261
|
-
return torch::ones(size, options);
|
262
|
-
})
|
263
|
-
.define_singleton_method(
|
264
|
-
"_rand",
|
265
|
-
*[](IntArrayRef size, const torch::TensorOptions &options) {
|
266
|
-
return torch::rand(size, options);
|
267
|
-
})
|
268
|
-
.define_singleton_method(
|
269
|
-
"_randint",
|
270
|
-
*[](int64_t low, int64_t high, IntArrayRef size, const torch::TensorOptions &options) {
|
271
|
-
return torch::randint(low, high, size, options);
|
272
|
-
})
|
273
|
-
.define_singleton_method(
|
274
|
-
"_randn",
|
275
|
-
*[](IntArrayRef size, const torch::TensorOptions &options) {
|
276
|
-
return torch::randn(size, options);
|
277
|
-
})
|
278
|
-
.define_singleton_method(
|
279
|
-
"_randperm",
|
280
|
-
*[](int64_t n, const torch::TensorOptions &options) {
|
281
|
-
return torch::randperm(n, options);
|
282
|
-
})
|
283
|
-
.define_singleton_method(
|
284
|
-
"_zeros",
|
285
|
-
*[](IntArrayRef size, const torch::TensorOptions &options) {
|
286
|
-
return torch::zeros(size, options);
|
287
|
-
})
|
288
258
|
// begin operations
|
289
259
|
.define_singleton_method(
|
290
260
|
"_save",
|
@@ -301,20 +271,15 @@ void Init_ext()
|
|
301
271
|
// https://github.com/pytorch/pytorch/issues/20356#issuecomment-567663701
|
302
272
|
return torch::pickle_load(v);
|
303
273
|
})
|
304
|
-
.define_singleton_method(
|
305
|
-
"_binary_cross_entropy_with_logits",
|
306
|
-
*[](const Tensor &input, const Tensor &target, OptionalTensor weight, OptionalTensor pos_weight, MyReduction reduction) {
|
307
|
-
return torch::binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction);
|
308
|
-
})
|
309
274
|
.define_singleton_method(
|
310
275
|
"_from_blob",
|
311
|
-
*[](String s,
|
276
|
+
*[](String s, std::vector<int64_t> size, const torch::TensorOptions &options) {
|
312
277
|
void *data = const_cast<char *>(s.c_str());
|
313
278
|
return torch::from_blob(data, size, options);
|
314
279
|
})
|
315
280
|
.define_singleton_method(
|
316
281
|
"_tensor",
|
317
|
-
*[](Array a,
|
282
|
+
*[](Array a, std::vector<int64_t> size, const torch::TensorOptions &options) {
|
318
283
|
auto dtype = options.dtype();
|
319
284
|
torch::Tensor t;
|
320
285
|
if (dtype == torch::kBool) {
|
@@ -347,6 +312,25 @@ void Init_ext()
|
|
347
312
|
.define_method("numel", &torch::Tensor::numel)
|
348
313
|
.define_method("element_size", &torch::Tensor::element_size)
|
349
314
|
.define_method("requires_grad", &torch::Tensor::requires_grad)
|
315
|
+
// in C++ for performance
|
316
|
+
.define_method(
|
317
|
+
"shape",
|
318
|
+
*[](Tensor& self) {
|
319
|
+
Array a;
|
320
|
+
for (auto &size : self.sizes()) {
|
321
|
+
a.push(size);
|
322
|
+
}
|
323
|
+
return a;
|
324
|
+
})
|
325
|
+
.define_method(
|
326
|
+
"_strides",
|
327
|
+
*[](Tensor& self) {
|
328
|
+
Array a;
|
329
|
+
for (auto &stride : self.strides()) {
|
330
|
+
a.push(stride);
|
331
|
+
}
|
332
|
+
return a;
|
333
|
+
})
|
350
334
|
.define_method(
|
351
335
|
"_index",
|
352
336
|
*[](Tensor& self, Array indices) {
|
@@ -379,11 +363,6 @@ void Init_ext()
|
|
379
363
|
*[](Tensor& self, bool requires_grad) {
|
380
364
|
return self.set_requires_grad(requires_grad);
|
381
365
|
})
|
382
|
-
.define_method(
|
383
|
-
"_backward",
|
384
|
-
*[](Tensor& self, OptionalTensor gradient, bool create_graph, bool retain_graph) {
|
385
|
-
return self.backward(gradient, create_graph, retain_graph);
|
386
|
-
})
|
387
366
|
.define_method(
|
388
367
|
"grad",
|
389
368
|
*[](Tensor& self) {
|
@@ -430,9 +409,19 @@ void Init_ext()
|
|
430
409
|
tensor = tensor.to(device);
|
431
410
|
}
|
432
411
|
|
412
|
+
if (!tensor.is_contiguous()) {
|
413
|
+
tensor = tensor.contiguous();
|
414
|
+
}
|
415
|
+
|
433
416
|
auto data_ptr = (const char *) tensor.data_ptr();
|
434
417
|
return std::string(data_ptr, tensor.numel() * tensor.element_size());
|
435
418
|
})
|
419
|
+
// for TorchVision
|
420
|
+
.define_method(
|
421
|
+
"_data_ptr",
|
422
|
+
*[](Tensor& self) {
|
423
|
+
return reinterpret_cast<uintptr_t>(self.data_ptr());
|
424
|
+
})
|
436
425
|
// TODO figure out a better way to do this
|
437
426
|
.define_method(
|
438
427
|
"_flat_data",
|
data/ext/torch/extconf.rb
CHANGED
@@ -17,6 +17,9 @@ if have_library("omp") || have_library("gomp")
|
|
17
17
|
end
|
18
18
|
|
19
19
|
if apple_clang
|
20
|
+
# silence rice warnings
|
21
|
+
$CXXFLAGS += " -Wno-deprecated-declarations"
|
22
|
+
|
20
23
|
# silence ruby/intern.h warning
|
21
24
|
$CXXFLAGS += " -Wno-deprecated-register"
|
22
25
|
|
@@ -66,8 +69,8 @@ end
|
|
66
69
|
|
67
70
|
# generate C++ functions
|
68
71
|
puts "Generating C++ functions..."
|
69
|
-
require_relative "../../
|
70
|
-
|
72
|
+
require_relative "../../codegen/generate_functions"
|
73
|
+
generate_functions
|
71
74
|
|
72
75
|
# create makefile
|
73
76
|
create_makefile("torch/ext")
|
@@ -0,0 +1,593 @@
|
|
1
|
+
// adapted from PyTorch - python_arg_parser.cpp
|
2
|
+
|
3
|
+
#include "ruby_arg_parser.h"
|
4
|
+
|
5
|
+
// Ruby class object for Torch Tensors. Starts as Qnil and is assigned
// rb_cTensor.value() in Init_ext (ext.cpp).
// NOTE(review): presumably consulted by THPVariable_Check in
// ruby_arg_parser.h — confirm against that header.
VALUE THPVariableClass = Qnil;
|
6
|
+
|
7
|
+
static std::unordered_map<std::string, ParameterType> type_map = {
|
8
|
+
{"Tensor", ParameterType::TENSOR},
|
9
|
+
{"Scalar", ParameterType::SCALAR},
|
10
|
+
{"int64_t", ParameterType::INT64},
|
11
|
+
{"double", ParameterType::DOUBLE},
|
12
|
+
{"complex", ParameterType::COMPLEX},
|
13
|
+
{"TensorList", ParameterType::TENSOR_LIST},
|
14
|
+
{"IntArrayRef", ParameterType::INT_LIST},
|
15
|
+
{"ArrayRef<double>", ParameterType::FLOAT_LIST},
|
16
|
+
{"Generator", ParameterType::GENERATOR},
|
17
|
+
{"bool", ParameterType::BOOL},
|
18
|
+
{"Storage", ParameterType::STORAGE},
|
19
|
+
// {"PyObject*", ParameterType::PYOBJECT},
|
20
|
+
{"ScalarType", ParameterType::SCALARTYPE},
|
21
|
+
{"Layout", ParameterType::LAYOUT},
|
22
|
+
{"MemoryFormat", ParameterType::MEMORY_FORMAT},
|
23
|
+
{"QScheme", ParameterType::QSCHEME},
|
24
|
+
{"Device", ParameterType::DEVICE},
|
25
|
+
{"std::string", ParameterType::STRING},
|
26
|
+
{"Dimname", ParameterType::DIMNAME},
|
27
|
+
{"DimnameList", ParameterType::DIMNAME_LIST},
|
28
|
+
};
|
29
|
+
|
30
|
+
// NumPy-style keyword aliases for parameter names, keyed by the canonical
// name (e.g. "axis" is an alias of "dim").
static const std::unordered_map<std::string, std::vector<std::string>> numpy_compatibility_arg_names{
  {"input", {"x", "a", "x1"}},
  {"other", {"x2"}},
  {"dim", {"axis"}},
  {"keepdim", {"keepdims"}},
};
|
36
|
+
|
37
|
+
// Returns true for the binary arithmetic ops (plain, in-place "_" and "_out"
// variants) whose Tensor parameters may also receive plain Ruby numbers.
static bool should_allow_numbers_as_tensors(const std::string& name) {
  static const std::unordered_set<std::string> kNumberFriendlyOps{
    "add", "add_", "add_out",
    "div", "div_", "div_out",
    "mul", "mul_", "mul_out",
    "sub", "sub_", "sub_out",
    "true_divide", "true_divide_", "true_divide_out",
    "floor_divide", "floor_divide_", "floor_divide_out",
  };
  return kNumberFriendlyOps.count(name) != 0;
}
|
48
|
+
|
49
|
+
// Builds one parameter from its signature text, shaped like
//   <type>[<size>]?<space><name>[=<default>]
// e.g. "IntArrayRef[2] stride=1" or "Tensor? weight=None".
// Throws std::runtime_error on a missing or unknown type.
FunctionParameter::FunctionParameter(const std::string& fmt, bool keyword_only)
  : optional(false)
  , allow_none(false)
  , keyword_only(keyword_only)
  , size(0)
  , default_scalar(0)
{
  const auto sep = fmt.find(' ');
  if (sep == std::string::npos) {
    throw std::runtime_error("FunctionParameter(): missing type: " + fmt);
  }

  std::string type_str = fmt.substr(0, sep);

  // a '?' in the type marks the parameter as nullable; drop it and the rest
  const auto qmark = type_str.find('?');
  if (qmark != std::string::npos) {
    allow_none = true;
    type_str.erase(qmark);
  }

  // an "[N]" suffix records a fixed list size, e.g. IntArrayRef[2]
  const auto lbracket = type_str.find('[');
  if (lbracket != std::string::npos) {
    const std::string digits = type_str.substr(lbracket + 1, type_str.length() - lbracket - 2);
    size = atoi(digits.c_str());
    type_str.erase(lbracket);
  }

  const std::string name_str = fmt.substr(sep + 1);

  const auto type_it = type_map.find(type_str);
  if (type_it == type_map.end()) {
    throw std::runtime_error("FunctionParameter(): invalid type string: " + type_str);
  }
  type_ = type_it->second;

  // "name=default" makes the parameter optional and parses its default
  const auto eq = name_str.find('=');
  if (eq == std::string::npos) {
    name = name_str;
  } else {
    name = name_str.substr(0, eq);
    optional = true;
    set_default_str(name_str.substr(eq + 1));
  }
  ruby_name = THPUtils_internSymbol(name);

  // record NumPy-style alias symbols (e.g. :axis for :dim), if any
  const auto alias_it = numpy_compatibility_arg_names.find(name);
  if (alias_it != numpy_compatibility_arg_names.end()) {
    for (const auto& alias : alias_it->second) {
      numpy_python_names.push_back(THPUtils_internSymbol(alias));
    }
  }
}
|
100
|
+
|
101
|
+
// Returns true when obj is a Ruby Array whose every element is a Tensor.
// Returns false for non-Arrays.
// If an element is not a Tensor and throw_error is set, raises ArgumentError
// naming the element index, the argument number, and the offending element's
// class. Otherwise returns false.
bool is_tensor_list(VALUE obj, int argnum, bool throw_error) {
  if (!RB_TYPE_P(obj, T_ARRAY)) {
    return false;
  }
  auto size = RARRAY_LEN(obj);
  // long, to match RARRAY_LEN's return type
  for (long idx = 0; idx < size; idx++) {
    VALUE iobj = rb_ary_entry(obj, idx);
    if (!THPVariable_Check(iobj)) {
      if (throw_error) {
        // report the class of the bad element, not of the enclosing Array
        rb_raise(rb_eArgError, "expected Tensor as element %d in argument %d, but got %s",
          static_cast<int>(idx), argnum, rb_obj_classname(iobj));
      }
      return false;
    }
  }
  return true;
}
|
118
|
+
|
119
|
+
// argnum is needed for raising the TypeError, it's used in the error message.
|
120
|
+
auto FunctionParameter::check(VALUE obj, int argnum) -> bool
|
121
|
+
{
|
122
|
+
switch (type_) {
|
123
|
+
case ParameterType::TENSOR: {
|
124
|
+
if (THPVariable_Check(obj)) {
|
125
|
+
return true;
|
126
|
+
}
|
127
|
+
return allow_numbers_as_tensors && THPUtils_checkScalar(obj);
|
128
|
+
}
|
129
|
+
case ParameterType::SCALAR:
|
130
|
+
case ParameterType::COMPLEX:
|
131
|
+
if (RB_TYPE_P(obj, T_COMPLEX)) {
|
132
|
+
return true;
|
133
|
+
}
|
134
|
+
// fallthrough
|
135
|
+
case ParameterType::DOUBLE: {
|
136
|
+
if (RB_FLOAT_TYPE_P(obj) || FIXNUM_P(obj)) {
|
137
|
+
return true;
|
138
|
+
}
|
139
|
+
if (THPVariable_Check(obj)) {
|
140
|
+
auto var = from_ruby<torch::Tensor>(obj);
|
141
|
+
return !var.requires_grad() && var.dim() == 0;
|
142
|
+
}
|
143
|
+
return false;
|
144
|
+
}
|
145
|
+
case ParameterType::INT64: {
|
146
|
+
if (FIXNUM_P(obj)) {
|
147
|
+
return true;
|
148
|
+
}
|
149
|
+
if (THPVariable_Check(obj)) {
|
150
|
+
auto var = from_ruby<torch::Tensor>(obj);
|
151
|
+
return at::isIntegralType(var.scalar_type(), /*includeBool=*/false) && !var.requires_grad() && var.dim() == 0;
|
152
|
+
}
|
153
|
+
return false;
|
154
|
+
}
|
155
|
+
case ParameterType::DIMNAME: return false; // return THPUtils_checkDimname(obj);
|
156
|
+
case ParameterType::DIMNAME_LIST: {
|
157
|
+
return false;
|
158
|
+
// if (THPUtils_checkDimnameList(obj)) {
|
159
|
+
// return true;
|
160
|
+
// }
|
161
|
+
// // if a size is specified (e.g. DimnameList[1]) we also allow passing a single Dimname
|
162
|
+
// return size == 1 && THPUtils_checkDimname(obj);
|
163
|
+
}
|
164
|
+
case ParameterType::TENSOR_LIST: {
|
165
|
+
return is_tensor_list(obj, argnum, true /* throw_error */);
|
166
|
+
}
|
167
|
+
case ParameterType::INT_LIST: {
|
168
|
+
if (RB_TYPE_P(obj, T_ARRAY)) {
|
169
|
+
return true;
|
170
|
+
}
|
171
|
+
// if a size is specified (e.g. IntArrayRef[2]) we also allow passing a single int
|
172
|
+
return size > 0 && FIXNUM_P(obj);
|
173
|
+
}
|
174
|
+
case ParameterType::FLOAT_LIST: return (RB_TYPE_P(obj, T_ARRAY));
|
175
|
+
case ParameterType::GENERATOR: return false; // return THPGenerator_Check(obj);
|
176
|
+
case ParameterType::BOOL: return obj == Qtrue || obj == Qfalse;
|
177
|
+
case ParameterType::STORAGE: return false; // return isStorage(obj);
|
178
|
+
// case ParameterType::PYOBJECT: return true;
|
179
|
+
case ParameterType::SCALARTYPE: return SYMBOL_P(obj);
|
180
|
+
case ParameterType::LAYOUT: return SYMBOL_P(obj);
|
181
|
+
case ParameterType::MEMORY_FORMAT: return false; // return THPMemoryFormat_Check(obj);
|
182
|
+
case ParameterType::QSCHEME: return false; // return THPQScheme_Check(obj);
|
183
|
+
case ParameterType::DEVICE: return RB_TYPE_P(obj, T_STRING); // TODO check device
|
184
|
+
case ParameterType::STRING: return RB_TYPE_P(obj, T_STRING);
|
185
|
+
default: throw std::runtime_error("unknown parameter type");
|
186
|
+
}
|
187
|
+
}
|
188
|
+
|
189
|
+
std::string FunctionParameter::type_name() const {
|
190
|
+
switch (type_) {
|
191
|
+
case ParameterType::TENSOR: return "Tensor";
|
192
|
+
case ParameterType::SCALAR: return "Number";
|
193
|
+
case ParameterType::INT64: return "int";
|
194
|
+
case ParameterType::DOUBLE: return "float";
|
195
|
+
case ParameterType::COMPLEX: return "complex";
|
196
|
+
case ParameterType::TENSOR_LIST: return "array of Tensors";
|
197
|
+
case ParameterType::INT_LIST: return "array of ints";
|
198
|
+
case ParameterType::FLOAT_LIST: return "array of floats";
|
199
|
+
case ParameterType::GENERATOR: return "torch.Generator";
|
200
|
+
case ParameterType::BOOL: return "bool";
|
201
|
+
case ParameterType::STORAGE: return "torch.Storage";
|
202
|
+
// case ParameterType::PYOBJECT: return "object";
|
203
|
+
case ParameterType::SCALARTYPE: return "torch.dtype";
|
204
|
+
case ParameterType::LAYOUT: return "torch.layout";
|
205
|
+
case ParameterType::MEMORY_FORMAT: return "torch.memory_format";
|
206
|
+
case ParameterType::QSCHEME: return "torch.qscheme";
|
207
|
+
case ParameterType::DEVICE: return "torch.device";
|
208
|
+
case ParameterType::STRING: return "str";
|
209
|
+
case ParameterType::DIMNAME: return "name";
|
210
|
+
case ParameterType::DIMNAME_LIST: return "array of names";
|
211
|
+
default: throw std::runtime_error("unknown parameter type");
|
212
|
+
}
|
213
|
+
}
|
214
|
+
|
215
|
+
// Parses s as an integer only if the WHOLE string is an integer literal.
// Uses base 0, so the prefix selects the base (e.g. "10", "0x1f", "017").
// Returns nullopt for an empty string or trailing non-digit text such as
// "1.5" or "None".
static inline c10::optional<int64_t> parse_as_integer(const std::string& s) {
  if (s.empty())
    return c10::nullopt;
  char *str_end;
  // strtoll rather than strtol: long is only 32 bits on LLP64 targets
  // (e.g. Windows), where strtol would clamp large int64 defaults to LONG_MAX.
  long long ans = strtoll(s.c_str(), &str_end, 0);
  // *str_end == 0 if the entire string was parsed as an integer.
  return (*str_end == 0) ? c10::optional<int64_t>(ans) : c10::nullopt;
}
|
223
|
+
|
224
|
+
/*
|
225
|
+
Parse default value of IntArrayRef declared at native_functions.yaml
|
226
|
+
|
227
|
+
There are two kinds of default values:
|
228
|
+
1. IntArrayRef[2] x=1 (where size=2, value={1,1}
|
229
|
+
2. IntArrayRef x={1,2,3} (where size=3, value={1,2,3}, note that there cannot be space after comma since native_parse.py uses ', ' to split args)
|
230
|
+
*/
|
231
|
+
static inline std::vector<int64_t> parse_intlist_args(const std::string& s, int64_t size) {
|
232
|
+
size_t n = s.size();
|
233
|
+
|
234
|
+
if (s.empty()) return std::vector<int64_t>();
|
235
|
+
|
236
|
+
// case 1. s is an int (e.g., s=2)
|
237
|
+
if (s[0] != '{') {
|
238
|
+
return std::vector<int64_t>(size, std::stol(s));
|
239
|
+
}
|
240
|
+
|
241
|
+
// case 2. s is a list of dims (e.g., s={1,2})
|
242
|
+
|
243
|
+
// since already checked left brace '{' above, here only checks right brace '}'
|
244
|
+
TORCH_CHECK(s[n - 1] == '}', "Default value of IntArrayRef is missing right brace '}', found ", s[n - 1]);
|
245
|
+
|
246
|
+
auto args = std::vector<int64_t>();
|
247
|
+
std::istringstream ss(s.substr(1, s.length() - 2)); // exclude '{' and '}'
|
248
|
+
std::string tok;
|
249
|
+
|
250
|
+
while(std::getline(ss, tok, ',')) {
|
251
|
+
args.emplace_back(std::stol(tok));
|
252
|
+
}
|
253
|
+
return args;
|
254
|
+
}
|
255
|
+
|
256
|
+
void FunctionParameter::set_default_str(const std::string& str) {
|
257
|
+
if (str == "None") {
|
258
|
+
allow_none = true;
|
259
|
+
}
|
260
|
+
if (type_ == ParameterType::TENSOR) {
|
261
|
+
if (str != "None") {
|
262
|
+
throw std::runtime_error("default value for Tensor must be none, got: " + str);
|
263
|
+
}
|
264
|
+
} else if (type_ == ParameterType::INT64) {
|
265
|
+
default_int = atol(str.c_str());
|
266
|
+
} else if (type_ == ParameterType::BOOL) {
|
267
|
+
default_bool = (str == "True" || str == "true");
|
268
|
+
} else if (type_ == ParameterType::DOUBLE) {
|
269
|
+
default_double = atof(str.c_str());
|
270
|
+
} else if (type_ == ParameterType::COMPLEX) {
|
271
|
+
default_complex[0] = atof(str.c_str()); // TODO: parse "x + xj"?
|
272
|
+
default_complex[1] = 0;
|
273
|
+
} else if (type_ == ParameterType::SCALAR) {
|
274
|
+
if (str != "None") {
|
275
|
+
// we sometimes rely on integer-vs-float values, e.g. with arange.
|
276
|
+
const auto as_integer = parse_as_integer(str);
|
277
|
+
default_scalar = as_integer.has_value() ? at::Scalar(as_integer.value()) :
|
278
|
+
at::Scalar(atof(str.c_str()));
|
279
|
+
}
|
280
|
+
} else if (type_ == ParameterType::INT_LIST) {
|
281
|
+
if (str != "None") {
|
282
|
+
default_intlist = parse_intlist_args(str, size);
|
283
|
+
}
|
284
|
+
} else if (type_ == ParameterType::FLOAT_LIST) {
|
285
|
+
if (str != "None") {
|
286
|
+
throw std::runtime_error("Defaults not supported for float[]");
|
287
|
+
}
|
288
|
+
} else if (type_ == ParameterType::SCALARTYPE) {
|
289
|
+
if (str == "None") {
|
290
|
+
default_scalartype = at::ScalarType::Undefined;
|
291
|
+
} else if (str == "torch.int64") {
|
292
|
+
default_scalartype = at::ScalarType::Long;
|
293
|
+
} else {
|
294
|
+
throw std::runtime_error("invalid default value for ScalarType: " + str);
|
295
|
+
}
|
296
|
+
} else if (type_ == ParameterType::LAYOUT) {
|
297
|
+
if (str == "None") {
|
298
|
+
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(allow_none);
|
299
|
+
} else if (str == "torch.strided") {
|
300
|
+
default_layout = at::Layout::Strided;
|
301
|
+
} else if (str == "torch.sparse_coo") {
|
302
|
+
default_layout = at::Layout::Sparse;
|
303
|
+
} else {
|
304
|
+
throw std::runtime_error("invalid default value for layout: " + str);
|
305
|
+
}
|
306
|
+
} else if (type_ == ParameterType::DEVICE) {
|
307
|
+
if (str != "None") {
|
308
|
+
throw std::runtime_error("invalid device: " + str);
|
309
|
+
}
|
310
|
+
} else if (type_ == ParameterType::STRING) {
|
311
|
+
if (str != "None" && str != "") {
|
312
|
+
throw std::runtime_error("invalid default string: " + str);
|
313
|
+
}
|
314
|
+
}
|
315
|
+
}
|
316
|
+
|
317
|
+
// Parses a signature string such as
//   "add(Tensor input, Tensor other, *, Scalar alpha=1)"
// into params. A bare "*" makes every later parameter keyword-only. An
// optional "|deprecated" or "|hidden" suffix may follow the closing
// parenthesis. Also computes min_args (parameters without a default),
// max_args and max_pos_args (parameters before "*").
// Throws std::runtime_error for malformed signatures.
FunctionSignature::FunctionSignature(const std::string& fmt, int index)
  : min_args(0)
  , max_args(0)
  , max_pos_args(0)
  , index(index)
  , hidden(false)
  , deprecated(false)
{
  const auto open_paren = fmt.find('(');
  if (open_paren == std::string::npos) {
    throw std::runtime_error("missing opening parenthesis: " + fmt);
  }
  name = fmt.substr(0, open_paren);

  const bool numbers_as_tensors = should_allow_numbers_as_tensors(name);

  // walk the ", "-separated parameter list up to the closing ')'
  auto cursor = open_paren + 1;
  bool after_star = false;
  bool last_param = false;
  while (!last_param) {
    auto param_end = fmt.find(", ", cursor);
    std::string::size_type resume;
    if (param_end != std::string::npos) {
      resume = param_end + 2;
    } else {
      last_param = true;
      param_end = fmt.find(')', cursor);
      resume = param_end + 1;
      if (param_end == cursor) {
        // empty parameter list, i.e. fn()
        cursor = resume;
        break;
      }
    }
    if (param_end == std::string::npos) {
      throw std::runtime_error("missing closing parenthesis: " + fmt);
    }
    if (param_end == cursor) {
      throw std::runtime_error("malformed signature: " + fmt);
    }

    const auto param_str = fmt.substr(cursor, param_end - cursor);
    cursor = resume;
    if (param_str == "*") {
      after_star = true;
    } else {
      params.emplace_back(param_str, after_star);
      params.back().allow_numbers_as_tensors = numbers_as_tensors;
    }
  }

  const auto suffix = fmt.substr(cursor);
  if (suffix == "|deprecated") {
    hidden = true;
    // TODO: raise warning when parsing deprecated signatures
    deprecated = true;
  } else if (suffix == "|hidden") {
    hidden = true;
  }

  max_args = params.size();

  // required = has no default; positional = declared before "*"
  for (const auto& param : params) {
    if (!param.optional) {
      min_args++;
    }
    if (!param.keyword_only) {
      max_pos_args++;
    }
  }
}
|
388
|
+
|
389
|
+
std::string FunctionSignature::toString() const {
|
390
|
+
// TODO: consider printing more proper schema strings with defaults, optionals, etc.
|
391
|
+
std::ostringstream ss;
|
392
|
+
bool keyword_already = false;
|
393
|
+
ss << "(";
|
394
|
+
int i = 0;
|
395
|
+
for (auto& param : params) {
|
396
|
+
if (i != 0) {
|
397
|
+
ss << ", ";
|
398
|
+
}
|
399
|
+
if (param.keyword_only && !keyword_already) {
|
400
|
+
ss << "*, ";
|
401
|
+
keyword_already = true;
|
402
|
+
}
|
403
|
+
ss << param.type_name() << " " << param.name;
|
404
|
+
i++;
|
405
|
+
}
|
406
|
+
ss << ")";
|
407
|
+
return ss.str();
|
408
|
+
}
|
409
|
+
|
410
|
+
[[noreturn]]
|
411
|
+
static void extra_args(const FunctionSignature& signature, ssize_t nargs) {
|
412
|
+
const long max_pos_args = signature.max_pos_args;
|
413
|
+
const long min_args = signature.min_args;
|
414
|
+
const long nargs_ = nargs;
|
415
|
+
if (min_args != max_pos_args) {
|
416
|
+
rb_raise(rb_eArgError, "%s() takes from %ld to %ld positional arguments but %ld were given",
|
417
|
+
signature.name.c_str(), min_args, max_pos_args, nargs_);
|
418
|
+
}
|
419
|
+
rb_raise(rb_eArgError, "%s() takes %ld positional argument%s but %ld %s given",
|
420
|
+
signature.name.c_str(),
|
421
|
+
max_pos_args, max_pos_args == 1 ? "" : "s",
|
422
|
+
nargs_, nargs == 1 ? "was" : "were");
|
423
|
+
}
|
424
|
+
|
425
|
+
[[noreturn]]
|
426
|
+
static void missing_args(const FunctionSignature& signature, int idx) {
|
427
|
+
int num_missing = 0;
|
428
|
+
std::stringstream ss;
|
429
|
+
|
430
|
+
auto& params = signature.params;
|
431
|
+
for (auto it = params.begin() + idx; it != params.end(); ++it) {
|
432
|
+
if (!it->optional) {
|
433
|
+
if (num_missing > 0) {
|
434
|
+
ss << ", ";
|
435
|
+
}
|
436
|
+
ss << '"' << it->name << '"';
|
437
|
+
num_missing++;
|
438
|
+
}
|
439
|
+
}
|
440
|
+
|
441
|
+
rb_raise(rb_eArgError, "%s() missing %d required positional argument%s: %s",
|
442
|
+
signature.name.c_str(),
|
443
|
+
num_missing,
|
444
|
+
num_missing == 1 ? "s" : "",
|
445
|
+
ss.str().c_str());
|
446
|
+
}
|
447
|
+
|
448
|
+
// Returns the index of the parameter whose Ruby symbol equals name, or -1 if
// the signature has no such parameter.
static ssize_t find_param(FunctionSignature& signature, VALUE name) {
  const auto& params = signature.params;
  for (size_t pos = 0; pos < params.size(); ++pos) {
    if (params[pos].ruby_name == name) {
      return static_cast<ssize_t>(pos);
    }
  }
  return -1;
}
|
459
|
+
|
460
|
+
[[noreturn]]
|
461
|
+
static void extra_kwargs(FunctionSignature& signature, VALUE kwargs, ssize_t num_pos_args) {
|
462
|
+
VALUE key;
|
463
|
+
|
464
|
+
VALUE keys = rb_funcall(kwargs, rb_intern("keys"), 0);
|
465
|
+
if (RARRAY_LEN(keys) > 0) {
|
466
|
+
key = rb_ary_entry(keys, 0);
|
467
|
+
|
468
|
+
if (!THPUtils_checkSymbol(key)) {
|
469
|
+
rb_raise(rb_eArgError, "keywords must be symbols, not %s", rb_obj_classname(key));
|
470
|
+
}
|
471
|
+
|
472
|
+
auto param_idx = find_param(signature, key);
|
473
|
+
if (param_idx < 0) {
|
474
|
+
rb_raise(rb_eArgError, "%s() got an unexpected keyword argument '%s'",
|
475
|
+
signature.name.c_str(), THPUtils_unpackSymbol(key).c_str());
|
476
|
+
}
|
477
|
+
|
478
|
+
if (param_idx < num_pos_args) {
|
479
|
+
rb_raise(rb_eArgError, "%s() got multiple values for argument '%s'",
|
480
|
+
signature.name.c_str(), THPUtils_unpackSymbol(key).c_str());
|
481
|
+
}
|
482
|
+
}
|
483
|
+
|
484
|
+
// this should never be hit
|
485
|
+
rb_raise(rb_eArgError, "invalid keyword arguments");
|
486
|
+
}
|
487
|
+
|
488
|
+
// Sentinel meaning "argument not supplied". Qundef is used because it can
// never be a real Ruby argument, so it stays distinguishable from an explicit
// nil. FunctionSignature::parse passes it as the rb_hash_lookup2 default and
// compares against it.
VALUE missing = Qundef;
|
489
|
+
|
490
|
+
// Matches positional args (a Ruby Array, or nil) and kwargs (a Hash, or nil)
// against this signature. Writes one VALUE per parameter into dst; omitted
// optional parameters and nil for nullable parameters become Qnil.
// Returns false on any mismatch. When raise_exception is true, raises
// ArgumentError with a descriptive message instead.
// NOTE: PyTorch's __torch_function__ overload collection and NumPy-alias
// keyword lookup are not ported here.
bool FunctionSignature::parse(VALUE self, VALUE args, VALUE kwargs, std::vector<VALUE> &dst, // NOLINT
                              bool raise_exception) {
  const auto nargs = NIL_P(args) ? 0 : RARRAY_LEN(args);
  ssize_t remaining_kwargs = NIL_P(kwargs) ? 0 : RHASH_SIZE(kwargs);
  ssize_t arg_pos = 0;

  // A lone positional IntArrayRef parameter (expand, view, permute, ...) also
  // accepts its elements as varargs: expand(5, 3) behaves like expand([5, 3]).
  const bool allow_varargs_intlist =
    max_pos_args == 1 && params[0].type_ == ParameterType::INT_LIST;

  if (nargs > max_pos_args && !allow_varargs_intlist) {
    if (raise_exception) {
      // foo() takes 2 positional arguments but 3 were given
      extra_args(*this, nargs);
    }
    return false;
  }

  int slot = 0;
  for (auto& param : params) {
    VALUE value = missing;
    bool from_kwargs = false;

    if (arg_pos < nargs) {
      // extra positional args given after single positional IntArrayRef arg
      if (param.keyword_only) {
        if (raise_exception) {
          extra_args(*this, nargs);
        }
        return false;
      }
      value = rb_ary_entry(args, arg_pos);
    } else if (!NIL_P(kwargs)) {
      value = rb_hash_lookup2(kwargs, param.ruby_name, missing);
      from_kwargs = true;
    }

    const bool absent = (value == missing);
    if ((absent && param.optional) || (NIL_P(value) && param.allow_none)) {
      dst[slot++] = Qnil;
    } else if (absent) {
      if (raise_exception) {
        // foo() missing 1 required positional argument: "b"
        missing_args(*this, slot);
      }
      return false;
    } else if (param.check(value, slot)) {
      dst[slot++] = value;
    } else if (allow_varargs_intlist && arg_pos == 0 && !from_kwargs &&
               THPUtils_checkIndex(value)) {
      // take all positional arguments as this parameter
      // e.g. permute(1, 2, 3) -> permute([1, 2, 3])
      dst[slot++] = args;
      arg_pos = nargs;
      continue;
    } else if (!raise_exception) {
      return false;
    } else if (from_kwargs) {
      // foo(): argument 'other' must be str, not int
      rb_raise(rb_eArgError, "%s(): argument '%s' must be %s, not %s",
        name.c_str(), param.name.c_str(), param.type_name().c_str(),
        rb_obj_classname(value));
    } else {
      // foo(): argument 'other' (position 2) must be str, not int
      rb_raise(rb_eArgError, "%s(): argument '%s' (position %ld) must be %s, not %s",
        name.c_str(), param.name.c_str(), static_cast<long>(arg_pos + 1),
        param.type_name().c_str(), rb_obj_classname(value));
    }

    if (!from_kwargs) {
      arg_pos++;
    } else if (!absent) {
      remaining_kwargs--;
    }
  }

  if (remaining_kwargs > 0) {
    if (raise_exception) {
      // foo() got an unexpected keyword argument "b"
      extra_kwargs(*this, kwargs, nargs);
    }
    return false;
  }
  return true;
}
|