torch-rb 0.3.4 → 0.4.1

ext/torch/ext.cpp
@@ -7,13 +7,14 @@
 #include <rice/Constructor.hpp>
 #include <rice/Hash.hpp>
 
-#include "templates.hpp"
+#include "templates.h"
+#include "utils.h"
 
 // generated with:
 // rake generate:functions
-#include "torch_functions.hpp"
-#include "tensor_functions.hpp"
-#include "nn_functions.hpp"
+#include "torch_functions.h"
+#include "tensor_functions.h"
+#include "nn_functions.h"
 
 using namespace Rice;
 using torch::indexing::TensorIndex;
@@ -29,11 +30,47 @@ void handle_error(torch::Error const & ex)
   throw Exception(rb_eRuntimeError, ex.what_without_backtrace());
 }
 
+Class rb_cTensor;
+
 std::vector<TensorIndex> index_vector(Array a) {
-  auto indices = std::vector<TensorIndex>();
+  Object obj;
+
+  std::vector<TensorIndex> indices;
   indices.reserve(a.size());
+
   for (size_t i = 0; i < a.size(); i++) {
-    indices.push_back(from_ruby<TensorIndex>(a[i]));
+    obj = a[i];
+
+    if (obj.is_instance_of(rb_cInteger)) {
+      indices.push_back(from_ruby<int64_t>(obj));
+    } else if (obj.is_instance_of(rb_cRange)) {
+      torch::optional<int64_t> start_index = from_ruby<int64_t>(obj.call("begin"));
+      torch::optional<int64_t> stop_index = -1;
+
+      Object end = obj.call("end");
+      if (!end.is_nil()) {
+        stop_index = from_ruby<int64_t>(end);
+      }
+
+      Object exclude_end = obj.call("exclude_end?");
+      if (!exclude_end) {
+        if (stop_index.value() == -1) {
+          stop_index = torch::nullopt;
+        } else {
+          stop_index = stop_index.value() + 1;
+        }
+      }
+
+      indices.push_back(torch::indexing::Slice(start_index, stop_index));
+    } else if (obj.is_instance_of(rb_cTensor)) {
+      indices.push_back(from_ruby<Tensor>(obj));
+    } else if (obj.is_nil()) {
+      indices.push_back(torch::indexing::None);
+    } else if (obj == True || obj == False) {
+      indices.push_back(from_ruby<bool>(obj));
+    } else {
+      throw Exception(rb_eArgError, "Unsupported index type: %s", rb_obj_classname(obj));
+    }
   }
   return indices;
 }
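
The new index_vector dispatches on the Ruby class of each index instead of requiring TensorIndex wrapper objects. A rough sketch of what this supports from Ruby (illustrative; the Tensor#[] plumbing lives in the gem's Ruby layer):

  require "torch"

  t = Torch.arange(12).reshape([3, 4])
  t[0]                     # Integer -> single row
  t[0..1]                  # inclusive Range -> Slice(0, 2)
  t[0...2]                 # exclusive Range -> also Slice(0, 2)
  t[1..]                   # endless Range -> slice through the last element
  t[nil]                   # nil -> None (inserts a dimension)
  t[true]                  # booleans are valid indices
  t[Torch.tensor([0, 2])]  # Tensor -> advanced indexing
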
@@ -45,9 +82,10 @@ void Init_ext()
   rb_mTorch.add_handler<torch::Error>(handle_error);
   add_torch_functions(rb_mTorch);
 
-  Class rb_cTensor = define_class_under<torch::Tensor>(rb_mTorch, "Tensor");
+  rb_cTensor = define_class_under<torch::Tensor>(rb_mTorch, "Tensor");
   rb_cTensor.add_handler<torch::Error>(handle_error);
   add_tensor_functions(rb_cTensor);
+  THPVariableClass = rb_cTensor.value();
 
   Module rb_mNN = define_module_under(rb_mTorch, "NN");
   rb_mNN.add_handler<torch::Error>(handle_error);
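
rb_cTensor is promoted to a global, and its VALUE is stored in THPVariableClass so the new argument parser (ruby_arg_parser.cpp, later in this diff) can cheaply recognize tensors. THPVariable_Check is the C-level analogue of an is_a? test against the class registered here:

  t = Torch.ones(1)
  t.is_a?(Torch::Tensor)  # => true
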
@@ -68,13 +106,6 @@ void Init_ext()
       return generator.seed();
     });
 
-  Class rb_cTensorIndex = define_class_under<TensorIndex>(rb_mTorch, "TensorIndex")
-    .define_singleton_method("boolean", *[](bool value) { return TensorIndex(value); })
-    .define_singleton_method("integer", *[](int64_t value) { return TensorIndex(value); })
-    .define_singleton_method("tensor", *[](torch::Tensor& value) { return TensorIndex(value); })
-    .define_singleton_method("slice", *[](torch::optional<int64_t> start_index, torch::optional<int64_t> stop_index) { return TensorIndex(torch::indexing::Slice(start_index, stop_index)); })
-    .define_singleton_method("none", *[]() { return TensorIndex(torch::indexing::None); });
-
   // https://pytorch.org/cppdocs/api/structc10_1_1_i_value.html
   Class rb_cIValue = define_class_under<torch::IValue>(rb_mTorch, "IValue")
     .add_handler<torch::Error>(handle_error)
@@ -224,67 +255,6 @@ void Init_ext()
       *[] {
         return torch::get_parallel_info();
       })
-    // begin tensor creation
-    .define_singleton_method(
-      "_arange",
-      *[](Scalar start, Scalar end, Scalar step, const torch::TensorOptions &options) {
-        return torch::arange(start, end, step, options);
-      })
-    .define_singleton_method(
-      "_empty",
-      *[](IntArrayRef size, const torch::TensorOptions &options) {
-        return torch::empty(size, options);
-      })
-    .define_singleton_method(
-      "_eye",
-      *[](int64_t m, int64_t n, const torch::TensorOptions &options) {
-        return torch::eye(m, n, options);
-      })
-    .define_singleton_method(
-      "_full",
-      *[](IntArrayRef size, Scalar fill_value, const torch::TensorOptions& options) {
-        return torch::full(size, fill_value, options);
-      })
-    .define_singleton_method(
-      "_linspace",
-      *[](Scalar start, Scalar end, int64_t steps, const torch::TensorOptions& options) {
-        return torch::linspace(start, end, steps, options);
-      })
-    .define_singleton_method(
-      "_logspace",
-      *[](Scalar start, Scalar end, int64_t steps, double base, const torch::TensorOptions& options) {
-        return torch::logspace(start, end, steps, base, options);
-      })
-    .define_singleton_method(
-      "_ones",
-      *[](IntArrayRef size, const torch::TensorOptions &options) {
-        return torch::ones(size, options);
-      })
-    .define_singleton_method(
-      "_rand",
-      *[](IntArrayRef size, const torch::TensorOptions &options) {
-        return torch::rand(size, options);
-      })
-    .define_singleton_method(
-      "_randint",
-      *[](int64_t low, int64_t high, IntArrayRef size, const torch::TensorOptions &options) {
-        return torch::randint(low, high, size, options);
-      })
-    .define_singleton_method(
-      "_randn",
-      *[](IntArrayRef size, const torch::TensorOptions &options) {
-        return torch::randn(size, options);
-      })
-    .define_singleton_method(
-      "_randperm",
-      *[](int64_t n, const torch::TensorOptions &options) {
-        return torch::randperm(n, options);
-      })
-    .define_singleton_method(
-      "_zeros",
-      *[](IntArrayRef size, const torch::TensorOptions &options) {
-        return torch::zeros(size, options);
-      })
     // begin operations
     .define_singleton_method(
       "_save",
@@ -301,20 +271,15 @@ void Init_ext()
         // https://github.com/pytorch/pytorch/issues/20356#issuecomment-567663701
         return torch::pickle_load(v);
       })
-    .define_singleton_method(
-      "_binary_cross_entropy_with_logits",
-      *[](const Tensor &input, const Tensor &target, OptionalTensor weight, OptionalTensor pos_weight, MyReduction reduction) {
-        return torch::binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction);
-      })
     .define_singleton_method(
       "_from_blob",
-      *[](String s, IntArrayRef size, const torch::TensorOptions &options) {
+      *[](String s, std::vector<int64_t> size, const torch::TensorOptions &options) {
         void *data = const_cast<char *>(s.c_str());
         return torch::from_blob(data, size, options);
       })
     .define_singleton_method(
       "_tensor",
-      *[](Array a, IntArrayRef size, const torch::TensorOptions &options) {
+      *[](Array a, std::vector<int64_t> size, const torch::TensorOptions &options) {
         auto dtype = options.dtype();
         torch::Tensor t;
         if (dtype == torch::kBool) {
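
The hand-written creation bindings (_arange through _zeros) and _binary_cross_entropy_with_logits are removed above, presumably because the code-generated functions now cover them, while _from_blob and _tensor switch from the gem's IntArrayRef alias to a plain std::vector<int64_t> for the size argument. Public construction from Ruby should be unchanged, e.g.:

  Torch.arange(0, 10, 2)                     # => tensor([0, 2, 4, 6, 8])
  Torch.ones(2, 3)
  Torch.tensor([[1, 2], [3, 4]])             # dtype inferred as int64
  Torch.tensor([1.5, 2.5], dtype: :float32)
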
@@ -347,6 +312,25 @@ void Init_ext()
     .define_method("numel", &torch::Tensor::numel)
     .define_method("element_size", &torch::Tensor::element_size)
     .define_method("requires_grad", &torch::Tensor::requires_grad)
+    // in C++ for performance
+    .define_method(
+      "shape",
+      *[](Tensor& self) {
+        Array a;
+        for (auto &size : self.sizes()) {
+          a.push(size);
+        }
+        return a;
+      })
+    .define_method(
+      "_strides",
+      *[](Tensor& self) {
+        Array a;
+        for (auto &stride : self.strides()) {
+          a.push(stride);
+        }
+        return a;
+      })
     .define_method(
       "_index",
       *[](Tensor& self, Array indices) {
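
shape is implemented in C++ since it is called constantly and the round trip through Ruby was overhead; _strides presumably backs a private Ruby-side helper (the leading underscore marks it internal by the gem's convention):

  t = Torch.zeros(2, 3)
  t.shape  # => [2, 3]
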
@@ -379,11 +363,6 @@ void Init_ext()
       *[](Tensor& self, bool requires_grad) {
         return self.set_requires_grad(requires_grad);
       })
-    .define_method(
-      "_backward",
-      *[](Tensor& self, OptionalTensor gradient, bool create_graph, bool retain_graph) {
-        return self.backward(gradient, create_graph, retain_graph);
-      })
     .define_method(
       "grad",
       *[](Tensor& self) {
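
The hand-written _backward binding is dropped as well; backward is presumably among the generated tensor functions now, with its optional arguments handled by the new parser. Typical autograd usage is unchanged:

  x = Torch.tensor([2.0, 3.0], requires_grad: true)
  y = (x * x).sum
  y.backward
  x.grad  # => tensor([4., 6.])
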
@@ -430,9 +409,19 @@ void Init_ext()
           tensor = tensor.to(device);
         }
 
+        if (!tensor.is_contiguous()) {
+          tensor = tensor.contiguous();
+        }
+
         auto data_ptr = (const char *) tensor.data_ptr();
         return std::string(data_ptr, tensor.numel() * tensor.element_size());
       })
+    // for TorchVision
+    .define_method(
+      "_data_ptr",
+      *[](Tensor& self) {
+        return reinterpret_cast<uintptr_t>(self.data_ptr());
+      })
     // TODO figure out a better way to do this
     .define_method(
       "_flat_data",
ext/torch/extconf.rb
@@ -17,6 +17,9 @@ if have_library("omp") || have_library("gomp")
 end
 
 if apple_clang
+  # silence rice warnings
+  $CXXFLAGS += " -Wno-deprecated-declarations"
+
   # silence ruby/intern.h warning
   $CXXFLAGS += " -Wno-deprecated-register"
 
@@ -66,8 +69,8 @@ end
 
 # generate C++ functions
 puts "Generating C++ functions..."
-require_relative "../../lib/torch/native/generator"
-Torch::Native::Generator.generate_cpp_functions
+require_relative "../../codegen/generate_functions"
+generate_functions
 
 # create makefile
 create_makefile("torch/ext")
ext/torch/nn_functions.h
@@ -0,0 +1,6 @@
+// generated by rake generate:functions
+// do not edit by hand
+
+#pragma once
+
+void add_nn_functions(Module m);
ext/torch/ruby_arg_parser.cpp
@@ -0,0 +1,593 @@
+// adapted from PyTorch - python_arg_parser.cpp
+
+#include "ruby_arg_parser.h"
+
+VALUE THPVariableClass = Qnil;
+
+static std::unordered_map<std::string, ParameterType> type_map = {
+  {"Tensor", ParameterType::TENSOR},
+  {"Scalar", ParameterType::SCALAR},
+  {"int64_t", ParameterType::INT64},
+  {"double", ParameterType::DOUBLE},
+  {"complex", ParameterType::COMPLEX},
+  {"TensorList", ParameterType::TENSOR_LIST},
+  {"IntArrayRef", ParameterType::INT_LIST},
+  {"ArrayRef<double>", ParameterType::FLOAT_LIST},
+  {"Generator", ParameterType::GENERATOR},
+  {"bool", ParameterType::BOOL},
+  {"Storage", ParameterType::STORAGE},
+  // {"PyObject*", ParameterType::PYOBJECT},
+  {"ScalarType", ParameterType::SCALARTYPE},
+  {"Layout", ParameterType::LAYOUT},
+  {"MemoryFormat", ParameterType::MEMORY_FORMAT},
+  {"QScheme", ParameterType::QSCHEME},
+  {"Device", ParameterType::DEVICE},
+  {"std::string", ParameterType::STRING},
+  {"Dimname", ParameterType::DIMNAME},
+  {"DimnameList", ParameterType::DIMNAME_LIST},
+};
+
+static const std::unordered_map<std::string, std::vector<std::string>> numpy_compatibility_arg_names = {
+  {"dim", {"axis"}},
+  {"keepdim", {"keepdims"}},
+  {"input", {"x", "a", "x1"}},
+  {"other", {"x2"}},
+};
+
+static bool should_allow_numbers_as_tensors(const std::string& name) {
+  static std::unordered_set<std::string> allowed = {
+    "add", "add_", "add_out",
+    "div", "div_", "div_out",
+    "mul", "mul_", "mul_out",
+    "sub", "sub_", "sub_out",
+    "true_divide", "true_divide_", "true_divide_out",
+    "floor_divide", "floor_divide_", "floor_divide_out"
+  };
+  return allowed.find(name) != allowed.end();
+}
+
+FunctionParameter::FunctionParameter(const std::string& fmt, bool keyword_only)
+  : optional(false)
+  , allow_none(false)
+  , keyword_only(keyword_only)
+  , size(0)
+  , default_scalar(0)
+{
+  auto space = fmt.find(' ');
+  if (space == std::string::npos) {
+    throw std::runtime_error("FunctionParameter(): missing type: " + fmt);
+  }
+
+  auto type_str = fmt.substr(0, space);
+
+  auto question = type_str.find('?');
+  if (question != std::string::npos) {
+    allow_none = true;
+    type_str = type_str.substr(0, question);
+  }
+
+  // Parse and remove brackets from type_str
+  auto bracket = type_str.find('[');
+  if (bracket != std::string::npos) {
+    auto size_str = type_str.substr(bracket + 1, type_str.length() - bracket - 2);
+    size = atoi(size_str.c_str());
+    type_str = type_str.substr(0, bracket);
+  }
+
+  auto name_str = fmt.substr(space + 1);
+  auto it = type_map.find(type_str);
+  if (it == type_map.end()) {
+    throw std::runtime_error("FunctionParameter(): invalid type string: " + type_str);
+  }
+  type_ = it->second;
+
+  auto eq = name_str.find('=');
+  if (eq != std::string::npos) {
+    name = name_str.substr(0, eq);
+    optional = true;
+    set_default_str(name_str.substr(eq + 1));
+  } else {
+    name = name_str;
+  }
+  ruby_name = THPUtils_internSymbol(name);
+  auto np_compat_it = numpy_compatibility_arg_names.find(name);
+  if (np_compat_it != numpy_compatibility_arg_names.end()) {
+    for (const auto& str: np_compat_it->second) {
+      numpy_python_names.push_back(THPUtils_internSymbol(str));
+    }
+  }
+}
+
+bool is_tensor_list(VALUE obj, int argnum, bool throw_error) {
+  if (!RB_TYPE_P(obj, T_ARRAY)) {
+    return false;
+  }
+  auto size = RARRAY_LEN(obj);
+  for (int idx = 0; idx < size; idx++) {
+    VALUE iobj = rb_ary_entry(obj, idx);
+    if (!THPVariable_Check(iobj)) {
+      if (throw_error) {
+        rb_raise(rb_eArgError, "expected Tensor as element %d in argument %d, but got %s",
+            static_cast<int>(idx), argnum, rb_obj_classname(obj));
+      }
+      return false;
+    }
+  }
+  return true;
+}
+
+// argnum is needed for raising the TypeError, it's used in the error message.
+auto FunctionParameter::check(VALUE obj, int argnum) -> bool
+{
+  switch (type_) {
+    case ParameterType::TENSOR: {
+      if (THPVariable_Check(obj)) {
+        return true;
+      }
+      return allow_numbers_as_tensors && THPUtils_checkScalar(obj);
+    }
+    case ParameterType::SCALAR:
+    case ParameterType::COMPLEX:
+      if (RB_TYPE_P(obj, T_COMPLEX)) {
+        return true;
+      }
+      // fallthrough
+    case ParameterType::DOUBLE: {
+      if (RB_FLOAT_TYPE_P(obj) || FIXNUM_P(obj)) {
+        return true;
+      }
+      if (THPVariable_Check(obj)) {
+        auto var = from_ruby<torch::Tensor>(obj);
+        return !var.requires_grad() && var.dim() == 0;
+      }
+      return false;
+    }
+    case ParameterType::INT64: {
+      if (FIXNUM_P(obj)) {
+        return true;
+      }
+      if (THPVariable_Check(obj)) {
+        auto var = from_ruby<torch::Tensor>(obj);
+        return at::isIntegralType(var.scalar_type(), /*includeBool=*/false) && !var.requires_grad() && var.dim() == 0;
+      }
+      return false;
+    }
+    case ParameterType::DIMNAME: return false; // return THPUtils_checkDimname(obj);
+    case ParameterType::DIMNAME_LIST: {
+      return false;
+      // if (THPUtils_checkDimnameList(obj)) {
+      //   return true;
+      // }
+      // // if a size is specified (e.g. DimnameList[1]) we also allow passing a single Dimname
+      // return size == 1 && THPUtils_checkDimname(obj);
+    }
+    case ParameterType::TENSOR_LIST: {
+      return is_tensor_list(obj, argnum, true /* throw_error */);
+    }
+    case ParameterType::INT_LIST: {
+      if (RB_TYPE_P(obj, T_ARRAY)) {
+        return true;
+      }
+      // if a size is specified (e.g. IntArrayRef[2]) we also allow passing a single int
+      return size > 0 && FIXNUM_P(obj);
+    }
+    case ParameterType::FLOAT_LIST: return (RB_TYPE_P(obj, T_ARRAY));
+    case ParameterType::GENERATOR: return false; // return THPGenerator_Check(obj);
+    case ParameterType::BOOL: return obj == Qtrue || obj == Qfalse;
+    case ParameterType::STORAGE: return false; // return isStorage(obj);
+    // case ParameterType::PYOBJECT: return true;
+    case ParameterType::SCALARTYPE: return SYMBOL_P(obj);
+    case ParameterType::LAYOUT: return SYMBOL_P(obj);
+    case ParameterType::MEMORY_FORMAT: return false; // return THPMemoryFormat_Check(obj);
+    case ParameterType::QSCHEME: return false; // return THPQScheme_Check(obj);
+    case ParameterType::DEVICE: return RB_TYPE_P(obj, T_STRING); // TODO check device
+    case ParameterType::STRING: return RB_TYPE_P(obj, T_STRING);
+    default: throw std::runtime_error("unknown parameter type");
+  }
+}
+
+std::string FunctionParameter::type_name() const {
+  switch (type_) {
+    case ParameterType::TENSOR: return "Tensor";
+    case ParameterType::SCALAR: return "Number";
+    case ParameterType::INT64: return "int";
+    case ParameterType::DOUBLE: return "float";
+    case ParameterType::COMPLEX: return "complex";
+    case ParameterType::TENSOR_LIST: return "array of Tensors";
+    case ParameterType::INT_LIST: return "array of ints";
+    case ParameterType::FLOAT_LIST: return "array of floats";
+    case ParameterType::GENERATOR: return "torch.Generator";
+    case ParameterType::BOOL: return "bool";
+    case ParameterType::STORAGE: return "torch.Storage";
+    // case ParameterType::PYOBJECT: return "object";
+    case ParameterType::SCALARTYPE: return "torch.dtype";
+    case ParameterType::LAYOUT: return "torch.layout";
+    case ParameterType::MEMORY_FORMAT: return "torch.memory_format";
+    case ParameterType::QSCHEME: return "torch.qscheme";
+    case ParameterType::DEVICE: return "torch.device";
+    case ParameterType::STRING: return "str";
+    case ParameterType::DIMNAME: return "name";
+    case ParameterType::DIMNAME_LIST: return "array of names";
+    default: throw std::runtime_error("unknown parameter type");
+  }
+}
+
+static inline c10::optional<int64_t> parse_as_integer(const std::string& s) {
+  if (s.empty())
+    return c10::nullopt;
+  char *str_end;
+  long ans = strtol(s.c_str(), &str_end, 0);
+  // *str_end == 0 if the entire string was parsed as an integer.
+  return (*str_end == 0) ? c10::optional<int64_t>(ans) : c10::nullopt;
+}
+
+/*
+Parse default value of IntArrayRef declared at native_functions.yaml
+
+There are two kinds of default values:
+1. IntArrayRef[2] x=1 (where size=2, value={1,1})
+2. IntArrayRef x={1,2,3} (where size=3, value={1,2,3}; note that there cannot be a space after the comma since native_parse.py uses ', ' to split args)
+*/
+static inline std::vector<int64_t> parse_intlist_args(const std::string& s, int64_t size) {
+  size_t n = s.size();
+
+  if (s.empty()) return std::vector<int64_t>();
+
+  // case 1. s is an int (e.g., s=2)
+  if (s[0] != '{') {
+    return std::vector<int64_t>(size, std::stol(s));
+  }
+
+  // case 2. s is a list of dims (e.g., s={1,2})
+
+  // since already checked left brace '{' above, here only checks right brace '}'
+  TORCH_CHECK(s[n - 1] == '}', "Default value of IntArrayRef is missing right brace '}', found ", s[n - 1]);
+
+  auto args = std::vector<int64_t>();
+  std::istringstream ss(s.substr(1, s.length() - 2)); // exclude '{' and '}'
+  std::string tok;
+
+  while (std::getline(ss, tok, ',')) {
+    args.emplace_back(std::stol(tok));
+  }
+  return args;
+}
+
+void FunctionParameter::set_default_str(const std::string& str) {
+  if (str == "None") {
+    allow_none = true;
+  }
+  if (type_ == ParameterType::TENSOR) {
+    if (str != "None") {
+      throw std::runtime_error("default value for Tensor must be none, got: " + str);
+    }
+  } else if (type_ == ParameterType::INT64) {
+    default_int = atol(str.c_str());
+  } else if (type_ == ParameterType::BOOL) {
+    default_bool = (str == "True" || str == "true");
+  } else if (type_ == ParameterType::DOUBLE) {
+    default_double = atof(str.c_str());
+  } else if (type_ == ParameterType::COMPLEX) {
+    default_complex[0] = atof(str.c_str()); // TODO: parse "x + xj"?
+    default_complex[1] = 0;
+  } else if (type_ == ParameterType::SCALAR) {
+    if (str != "None") {
+      // we sometimes rely on integer-vs-float values, e.g. with arange.
+      const auto as_integer = parse_as_integer(str);
+      default_scalar = as_integer.has_value() ? at::Scalar(as_integer.value()) :
+                                                at::Scalar(atof(str.c_str()));
+    }
+  } else if (type_ == ParameterType::INT_LIST) {
+    if (str != "None") {
+      default_intlist = parse_intlist_args(str, size);
+    }
+  } else if (type_ == ParameterType::FLOAT_LIST) {
+    if (str != "None") {
+      throw std::runtime_error("Defaults not supported for float[]");
+    }
+  } else if (type_ == ParameterType::SCALARTYPE) {
+    if (str == "None") {
+      default_scalartype = at::ScalarType::Undefined;
+    } else if (str == "torch.int64") {
+      default_scalartype = at::ScalarType::Long;
+    } else {
+      throw std::runtime_error("invalid default value for ScalarType: " + str);
+    }
+  } else if (type_ == ParameterType::LAYOUT) {
+    if (str == "None") {
+      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(allow_none);
+    } else if (str == "torch.strided") {
+      default_layout = at::Layout::Strided;
+    } else if (str == "torch.sparse_coo") {
+      default_layout = at::Layout::Sparse;
+    } else {
+      throw std::runtime_error("invalid default value for layout: " + str);
+    }
+  } else if (type_ == ParameterType::DEVICE) {
+    if (str != "None") {
+      throw std::runtime_error("invalid device: " + str);
+    }
+  } else if (type_ == ParameterType::STRING) {
+    if (str != "None" && str != "") {
+      throw std::runtime_error("invalid default string: " + str);
+    }
+  }
+}
+
+FunctionSignature::FunctionSignature(const std::string& fmt, int index)
+  : min_args(0)
+  , max_args(0)
+  , max_pos_args(0)
+  , index(index)
+  , hidden(false)
+  , deprecated(false)
+{
+  auto open_paren = fmt.find('(');
+  if (open_paren == std::string::npos) {
+    throw std::runtime_error("missing opening parenthesis: " + fmt);
+  }
+  name = fmt.substr(0, open_paren);
+
+  bool allow_numbers_as_tensors = should_allow_numbers_as_tensors(name);
+
+  auto last_offset = open_paren + 1;
+  auto next_offset = last_offset;
+  bool keyword_only = false;
+  bool done = false;
+  while (!done) {
+    auto offset = fmt.find(", ", last_offset);
+    if (offset == std::string::npos) {
+      offset = fmt.find(')', last_offset);
+      done = true;
+      next_offset = offset + 1;
+      // this 'if' happens for an empty parameter list, i.e. fn().
+      if (offset == last_offset) {
+        last_offset = next_offset;
+        break;
+      }
+    } else {
+      next_offset = offset + 2;
+    }
+    if (offset == std::string::npos) {
+      throw std::runtime_error("missing closing parenthesis: " + fmt);
+    }
+    if (offset == last_offset) {
+      throw std::runtime_error("malformed signature: " + fmt);
+    }
+
+    auto param_str = fmt.substr(last_offset, offset - last_offset);
+    last_offset = next_offset;
+    if (param_str == "*") {
+      keyword_only = true;
+    } else {
+      params.emplace_back(param_str, keyword_only);
+      params.back().allow_numbers_as_tensors = allow_numbers_as_tensors;
+    }
+  }
+
+  if (fmt.substr(last_offset) == "|deprecated") {
+    hidden = true;
+    // TODO: raise warning when parsing deprecated signatures
+    deprecated = true;
+  } else if (fmt.substr(last_offset) == "|hidden") {
+    hidden = true;
+  }
+
+  max_args = params.size();
+
+  // count the number of non-optional args
+  for (auto& param : params) {
+    if (!param.optional) {
+      min_args++;
+    }
+    if (!param.keyword_only) {
+      max_pos_args++;
+    }
+  }
+}
+
+std::string FunctionSignature::toString() const {
+  // TODO: consider printing more proper schema strings with defaults, optionals, etc.
+  std::ostringstream ss;
+  bool keyword_already = false;
+  ss << "(";
+  int i = 0;
+  for (auto& param : params) {
+    if (i != 0) {
+      ss << ", ";
+    }
+    if (param.keyword_only && !keyword_already) {
+      ss << "*, ";
+      keyword_already = true;
+    }
+    ss << param.type_name() << " " << param.name;
+    i++;
+  }
+  ss << ")";
+  return ss.str();
+}
+
+[[noreturn]]
+static void extra_args(const FunctionSignature& signature, ssize_t nargs) {
+  const long max_pos_args = signature.max_pos_args;
+  const long min_args = signature.min_args;
+  const long nargs_ = nargs;
+  if (min_args != max_pos_args) {
+    rb_raise(rb_eArgError, "%s() takes from %ld to %ld positional arguments but %ld were given",
+        signature.name.c_str(), min_args, max_pos_args, nargs_);
+  }
+  rb_raise(rb_eArgError, "%s() takes %ld positional argument%s but %ld %s given",
+      signature.name.c_str(),
+      max_pos_args, max_pos_args == 1 ? "" : "s",
+      nargs_, nargs == 1 ? "was" : "were");
+}
+
+[[noreturn]]
+static void missing_args(const FunctionSignature& signature, int idx) {
+  int num_missing = 0;
+  std::stringstream ss;
+
+  auto& params = signature.params;
+  for (auto it = params.begin() + idx; it != params.end(); ++it) {
+    if (!it->optional) {
+      if (num_missing > 0) {
+        ss << ", ";
+      }
+      ss << '"' << it->name << '"';
+      num_missing++;
+    }
+  }
+
+  rb_raise(rb_eArgError, "%s() missing %d required positional argument%s: %s",
+      signature.name.c_str(),
+      num_missing,
+      num_missing == 1 ? "" : "s",
+      ss.str().c_str());
+}
+
+static ssize_t find_param(FunctionSignature& signature, VALUE name) {
+  ssize_t i = 0;
+  for (auto& param : signature.params) {
+    bool cmp = name == param.ruby_name;
+    if (cmp) {
+      return i;
+    }
+    i++;
+  }
+  return -1;
+}
+
+[[noreturn]]
+static void extra_kwargs(FunctionSignature& signature, VALUE kwargs, ssize_t num_pos_args) {
+  VALUE key;
+
+  VALUE keys = rb_funcall(kwargs, rb_intern("keys"), 0);
+  if (RARRAY_LEN(keys) > 0) {
+    key = rb_ary_entry(keys, 0);
+
+    if (!THPUtils_checkSymbol(key)) {
+      rb_raise(rb_eArgError, "keywords must be symbols, not %s", rb_obj_classname(key));
+    }
+
+    auto param_idx = find_param(signature, key);
+    if (param_idx < 0) {
+      rb_raise(rb_eArgError, "%s() got an unexpected keyword argument '%s'",
+          signature.name.c_str(), THPUtils_unpackSymbol(key).c_str());
+    }
+
+    if (param_idx < num_pos_args) {
+      rb_raise(rb_eArgError, "%s() got multiple values for argument '%s'",
+          signature.name.c_str(), THPUtils_unpackSymbol(key).c_str());
+    }
+  }
+
+  // this should never be hit
+  rb_raise(rb_eArgError, "invalid keyword arguments");
+}
+
+VALUE missing = Qundef;
+
+bool FunctionSignature::parse(VALUE self, VALUE args, VALUE kwargs, std::vector<VALUE> &dst, // NOLINT
+                              bool raise_exception) {
+  auto nargs = NIL_P(args) ? 0 : RARRAY_LEN(args);
+  ssize_t remaining_kwargs = NIL_P(kwargs) ? 0 : RHASH_SIZE(kwargs);
+  ssize_t arg_pos = 0;
+  bool allow_varargs_intlist = false;
+
+  // if there is a single positional IntArrayRef argument, i.e. expand(..), view(...),
+  // allow a var-args style IntArrayRef, so expand(5,3) behaves as expand((5,3))
+  if (max_pos_args == 1 && params[0].type_ == ParameterType::INT_LIST) {
+    allow_varargs_intlist = true;
+  }
+
+  if (nargs > max_pos_args && !allow_varargs_intlist) {
+    if (raise_exception) {
+      // foo() takes 2 positional arguments but 3 were given
+      extra_args(*this, nargs);
+    }
+    return false;
+  }
+
+  // if (!overloaded_args.empty()) {
+  //   overloaded_args.clear();
+  // }
+
+  int i = 0;
+  // if (self != nullptr && !THPVariable_CheckExact(self) && check_has_torch_function(self)) {
+  //   append_overloaded_arg(&this->overloaded_args, self);
+  // }
+  for (auto& param : params) {
+    VALUE obj = missing;
+    bool is_kwd = false;
+    if (arg_pos < nargs) {
+      // extra positional args given after single positional IntArrayRef arg
+      if (param.keyword_only) {
+        if (raise_exception) {
+          extra_args(*this, nargs);
+        }
+        return false;
+      }
+      obj = rb_ary_entry(args, arg_pos);
+    } else if (!NIL_P(kwargs)) {
+      obj = rb_hash_lookup2(kwargs, param.ruby_name, missing);
+      // for (VALUE numpy_name: param.numpy_python_names) {
+      //   if (obj) {
+      //     break;
+      //   }
+      //   obj = rb_hash_aref(kwargs, numpy_name);
+      // }
+      is_kwd = true;
+    }
+
+    if ((obj == missing && param.optional) || (NIL_P(obj) && param.allow_none)) {
+      dst[i++] = Qnil;
+    } else if (obj == missing) {
+      if (raise_exception) {
+        // foo() missing 1 required positional argument: "b"
+        missing_args(*this, i);
+      }
+      return false;
+    } else if (param.check(obj, i)) {
+      dst[i++] = obj;
+    // XXX: the Variable check is necessary because sizes become tensors when
+    // tracer is enabled. This behavior easily leads to ambiguities, and we
+    // should avoid having complex signatures that make use of it...
+    } else if (allow_varargs_intlist && arg_pos == 0 && !is_kwd &&
+               THPUtils_checkIndex(obj)) {
+      // take all positional arguments as this parameter
+      // e.g. permute(1, 2, 3) -> permute((1, 2, 3))
+      dst[i++] = args;
+      arg_pos = nargs;
+      continue;
+    } else if (raise_exception) {
+      if (is_kwd) {
+        // foo(): argument 'other' must be str, not int
+        rb_raise(rb_eArgError, "%s(): argument '%s' must be %s, not %s",
+            name.c_str(), param.name.c_str(), param.type_name().c_str(),
+            rb_obj_classname(obj));
+      } else {
+        // foo(): argument 'other' (position 2) must be str, not int
+        rb_raise(rb_eArgError, "%s(): argument '%s' (position %ld) must be %s, not %s",
+            name.c_str(), param.name.c_str(), static_cast<long>(arg_pos + 1),
+            param.type_name().c_str(), rb_obj_classname(obj));
+      }
+    } else {
+      return false;
+    }
+
+    if (!is_kwd) {
+      arg_pos++;
+    } else if (obj != missing) {
+      remaining_kwargs--;
+    }
+  }
+
+  if (remaining_kwargs > 0) {
+    if (raise_exception) {
+      // foo() got an unexpected keyword argument "b"
+      extra_kwargs(*this, kwargs, nargs);
+    }
+    return false;
+  }
+  return true;
+}