torch-rb 0.3.4 → 0.4.1

This diff shows the changes between publicly released versions of this package, as published to a supported public registry. It is provided for informational purposes only.

--- ext/torch/ext.cpp (0.3.4)
+++ ext/torch/ext.cpp (0.4.1)
@@ -7,13 +7,14 @@
 #include <rice/Constructor.hpp>
 #include <rice/Hash.hpp>
 
-#include "templates.hpp"
+#include "templates.h"
+#include "utils.h"
 
 // generated with:
 // rake generate:functions
-#include "torch_functions.hpp"
-#include "tensor_functions.hpp"
-#include "nn_functions.hpp"
+#include "torch_functions.h"
+#include "tensor_functions.h"
+#include "nn_functions.h"
 
 using namespace Rice;
 using torch::indexing::TensorIndex;
@@ -29,11 +30,47 @@ void handle_error(torch::Error const & ex)
   throw Exception(rb_eRuntimeError, ex.what_without_backtrace());
 }
 
+Class rb_cTensor;
+
 std::vector<TensorIndex> index_vector(Array a) {
-  auto indices = std::vector<TensorIndex>();
+  Object obj;
+
+  std::vector<TensorIndex> indices;
   indices.reserve(a.size());
+
   for (size_t i = 0; i < a.size(); i++) {
-    indices.push_back(from_ruby<TensorIndex>(a[i]));
+    obj = a[i];
+
+    if (obj.is_instance_of(rb_cInteger)) {
+      indices.push_back(from_ruby<int64_t>(obj));
+    } else if (obj.is_instance_of(rb_cRange)) {
+      torch::optional<int64_t> start_index = from_ruby<int64_t>(obj.call("begin"));
+      torch::optional<int64_t> stop_index = -1;
+
+      Object end = obj.call("end");
+      if (!end.is_nil()) {
+        stop_index = from_ruby<int64_t>(end);
+      }
+
+      Object exclude_end = obj.call("exclude_end?");
+      if (!exclude_end) {
+        if (stop_index.value() == -1) {
+          stop_index = torch::nullopt;
+        } else {
+          stop_index = stop_index.value() + 1;
+        }
+      }
+
+      indices.push_back(torch::indexing::Slice(start_index, stop_index));
+    } else if (obj.is_instance_of(rb_cTensor)) {
+      indices.push_back(from_ruby<Tensor>(obj));
+    } else if (obj.is_nil()) {
+      indices.push_back(torch::indexing::None);
+    } else if (obj == True || obj == False) {
+      indices.push_back(from_ruby<bool>(obj));
+    } else {
+      throw Exception(rb_eArgError, "Unsupported index type: %s", rb_obj_classname(obj));
+    }
   }
   return indices;
 }
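
The Range branch above maps Ruby's endpoint semantics onto torch::indexing::Slice: an inclusive end adds 1 to the stop index, exclude_end? leaves it as-is, and a nil end (stop_index still -1) becomes an open-ended slice. A standalone libtorch sketch of the resulting mapping (illustrative, not part of the gem):

// Illustration only: the Slice each kind of Ruby range turns into.
#include <torch/torch.h>

int main() {
  using torch::indexing::Slice;
  using torch::indexing::None;
  auto t = torch::arange(10);
  auto a = t.index({Slice(1, 4)});    // Ruby t[1..3]: inclusive end, stop becomes 3 + 1
  auto b = t.index({Slice(1, 3)});    // Ruby t[1...3]: exclude_end? is true, stop stays 3
  auto c = t.index({Slice(1, None)}); // Ruby t[1..]: nil end leaves stop at -1 -> nullopt
  return 0;
}
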
@@ -45,9 +82,10 @@ void Init_ext()
   rb_mTorch.add_handler<torch::Error>(handle_error);
   add_torch_functions(rb_mTorch);
 
-  Class rb_cTensor = define_class_under<torch::Tensor>(rb_mTorch, "Tensor");
+  rb_cTensor = define_class_under<torch::Tensor>(rb_mTorch, "Tensor");
   rb_cTensor.add_handler<torch::Error>(handle_error);
   add_tensor_functions(rb_cTensor);
+  THPVariableClass = rb_cTensor.value();
 
   Module rb_mNN = define_module_under(rb_mTorch, "NN");
   rb_mNN.add_handler<torch::Error>(handle_error);
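
Promoting rb_cTensor to a file-level global and exporting its VALUE as THPVariableClass wires the Tensor class into the new argument parser: THPVariable_Check, declared in ruby_arg_parser.h (not shown in this diff), presumably tests objects against this class when matching Tensor parameters.
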
@@ -68,13 +106,6 @@ void Init_ext()
       return generator.seed();
     });
 
-  Class rb_cTensorIndex = define_class_under<TensorIndex>(rb_mTorch, "TensorIndex")
-    .define_singleton_method("boolean", *[](bool value) { return TensorIndex(value); })
-    .define_singleton_method("integer", *[](int64_t value) { return TensorIndex(value); })
-    .define_singleton_method("tensor", *[](torch::Tensor& value) { return TensorIndex(value); })
-    .define_singleton_method("slice", *[](torch::optional<int64_t> start_index, torch::optional<int64_t> stop_index) { return TensorIndex(torch::indexing::Slice(start_index, stop_index)); })
-    .define_singleton_method("none", *[]() { return TensorIndex(torch::indexing::None); });
-
   // https://pytorch.org/cppdocs/api/structc10_1_1_i_value.html
   Class rb_cIValue = define_class_under<torch::IValue>(rb_mTorch, "IValue")
     .add_handler<torch::Error>(handle_error)
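
With index conversion now handled natively in index_vector above, the Ruby-visible TensorIndex factory class has no remaining callers and is dropped.
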
@@ -224,67 +255,6 @@ void Init_ext()
      *[] {
        return torch::get_parallel_info();
      })
-    // begin tensor creation
-    .define_singleton_method(
-      "_arange",
-      *[](Scalar start, Scalar end, Scalar step, const torch::TensorOptions &options) {
-        return torch::arange(start, end, step, options);
-      })
-    .define_singleton_method(
-      "_empty",
-      *[](IntArrayRef size, const torch::TensorOptions &options) {
-        return torch::empty(size, options);
-      })
-    .define_singleton_method(
-      "_eye",
-      *[](int64_t m, int64_t n, const torch::TensorOptions &options) {
-        return torch::eye(m, n, options);
-      })
-    .define_singleton_method(
-      "_full",
-      *[](IntArrayRef size, Scalar fill_value, const torch::TensorOptions& options) {
-        return torch::full(size, fill_value, options);
-      })
-    .define_singleton_method(
-      "_linspace",
-      *[](Scalar start, Scalar end, int64_t steps, const torch::TensorOptions& options) {
-        return torch::linspace(start, end, steps, options);
-      })
-    .define_singleton_method(
-      "_logspace",
-      *[](Scalar start, Scalar end, int64_t steps, double base, const torch::TensorOptions& options) {
-        return torch::logspace(start, end, steps, base, options);
-      })
-    .define_singleton_method(
-      "_ones",
-      *[](IntArrayRef size, const torch::TensorOptions &options) {
-        return torch::ones(size, options);
-      })
-    .define_singleton_method(
-      "_rand",
-      *[](IntArrayRef size, const torch::TensorOptions &options) {
-        return torch::rand(size, options);
-      })
-    .define_singleton_method(
-      "_randint",
-      *[](int64_t low, int64_t high, IntArrayRef size, const torch::TensorOptions &options) {
-        return torch::randint(low, high, size, options);
-      })
-    .define_singleton_method(
-      "_randn",
-      *[](IntArrayRef size, const torch::TensorOptions &options) {
-        return torch::randn(size, options);
-      })
-    .define_singleton_method(
-      "_randperm",
-      *[](int64_t n, const torch::TensorOptions &options) {
-        return torch::randperm(n, options);
-      })
-    .define_singleton_method(
-      "_zeros",
-      *[](IntArrayRef size, const torch::TensorOptions &options) {
-        return torch::zeros(size, options);
-      })
     // begin operations
     .define_singleton_method(
       "_save",
@@ -301,20 +271,15 @@ void Init_ext()
        // https://github.com/pytorch/pytorch/issues/20356#issuecomment-567663701
        return torch::pickle_load(v);
      })
-    .define_singleton_method(
-      "_binary_cross_entropy_with_logits",
-      *[](const Tensor &input, const Tensor &target, OptionalTensor weight, OptionalTensor pos_weight, MyReduction reduction) {
-        return torch::binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction);
-      })
     .define_singleton_method(
       "_from_blob",
-      *[](String s, IntArrayRef size, const torch::TensorOptions &options) {
+      *[](String s, std::vector<int64_t> size, const torch::TensorOptions &options) {
        void *data = const_cast<char *>(s.c_str());
        return torch::from_blob(data, size, options);
      })
     .define_singleton_method(
       "_tensor",
-      *[](Array a, IntArrayRef size, const torch::TensorOptions &options) {
+      *[](Array a, std::vector<int64_t> size, const torch::TensorOptions &options) {
        auto dtype = options.dtype();
        torch::Tensor t;
        if (dtype == torch::kBool) {
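
_from_blob and _tensor now take their size as std::vector<int64_t> rather than IntArrayRef. The torch::from_blob call that _from_blob wraps aliases the caller's buffer rather than copying it; a standalone sketch (the bytes and sizes here are illustrative):

// Illustration only: torch::from_blob aliases, it does not copy.
#include <torch/torch.h>
#include <string>

int main() {
  std::string s("\x00\x01\x02\x03\x04\x05", 6);         // raw bytes, as a Ruby String would supply
  std::vector<int64_t> size = {2, 3};
  void *data = const_cast<char *>(s.c_str());
  auto t = torch::from_blob(data, size, torch::kUInt8); // t points into s's buffer
  auto owned = t.clone();                               // copy out before s is freed
  return 0;
}
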
@@ -347,6 +312,25 @@ void Init_ext()
     .define_method("numel", &torch::Tensor::numel)
     .define_method("element_size", &torch::Tensor::element_size)
     .define_method("requires_grad", &torch::Tensor::requires_grad)
+    // in C++ for performance
+    .define_method(
+      "shape",
+      *[](Tensor& self) {
+        Array a;
+        for (auto &size : self.sizes()) {
+          a.push(size);
+        }
+        return a;
+      })
+    .define_method(
+      "_strides",
+      *[](Tensor& self) {
+        Array a;
+        for (auto &stride : self.strides()) {
+          a.push(stride);
+        }
+        return a;
+      })
     .define_method(
       "_index",
       *[](Tensor& self, Array indices) {
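
shape and _strides move into C++ to avoid one Ruby call per dimension; they simply repackage Tensor::sizes() and Tensor::strides() as Ruby arrays. For reference, a standalone sketch of the values they expose:

// Illustration only: sizes and strides of a contiguous 2x3x4 tensor.
#include <torch/torch.h>
#include <iostream>

int main() {
  auto t = torch::zeros({2, 3, 4});
  std::cout << t.sizes() << "\n";   // [2, 3, 4]
  std::cout << t.strides() << "\n"; // [12, 4, 1] -- row-major: stride[i] is the product of the sizes after i
  return 0;
}
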
@@ -379,11 +363,6 @@ void Init_ext()
      *[](Tensor& self, bool requires_grad) {
        return self.set_requires_grad(requires_grad);
      })
-    .define_method(
-      "_backward",
-      *[](Tensor& self, OptionalTensor gradient, bool create_graph, bool retain_graph) {
-        return self.backward(gradient, create_graph, retain_graph);
-      })
     .define_method(
       "grad",
       *[](Tensor& self) {
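
The next hunk touches _data, which serializes a tensor as a flat byte string: it adds a contiguous() guard, since on a non-contiguous view reading numel() * element_size() bytes from data_ptr() would yield elements in storage order rather than logical order, and it adds _data_ptr (for TorchVision) to expose the raw address as an integer so another extension can wrap the same memory. A standalone sketch of the failure mode the guard prevents:

// Illustration only: why _data forces contiguity before the flat byte copy.
#include <torch/torch.h>

int main() {
  auto t = torch::arange(6).reshape({2, 3}).t(); // transposed view, non-contiguous
  // Copying numel() * element_size() bytes straight from t.data_ptr() would
  // emit the storage in its original (untransposed) order.
  auto c = t.contiguous(); // materializes elements in t's logical order
  return 0;
}
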
@@ -430,9 +409,19 @@ void Init_ext()
          tensor = tensor.to(device);
        }
 
+        if (!tensor.is_contiguous()) {
+          tensor = tensor.contiguous();
+        }
+
        auto data_ptr = (const char *) tensor.data_ptr();
        return std::string(data_ptr, tensor.numel() * tensor.element_size());
      })
+    // for TorchVision
+    .define_method(
+      "_data_ptr",
+      *[](Tensor& self) {
+        return reinterpret_cast<uintptr_t>(self.data_ptr());
+      })
     // TODO figure out a better way to do this
     .define_method(
       "_flat_data",

--- ext/torch/extconf.rb (0.3.4)
+++ ext/torch/extconf.rb (0.4.1)
@@ -17,6 +17,9 @@ if have_library("omp") || have_library("gomp")
 end
 
 if apple_clang
+  # silence rice warnings
+  $CXXFLAGS += " -Wno-deprecated-declarations"
+
   # silence ruby/intern.h warning
   $CXXFLAGS += " -Wno-deprecated-register"
 
@@ -66,8 +69,8 @@ end
 
 # generate C++ functions
 puts "Generating C++ functions..."
-require_relative "../../lib/torch/native/generator"
-Torch::Native::Generator.generate_cpp_functions
+require_relative "../../codegen/generate_functions"
+generate_functions
 
 # create makefile
 create_makefile("torch/ext")
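
Build-time code generation moves from lib/torch/native/generator.rb to a top-level codegen/generate_functions.rb. This pairs with the ext.cpp changes above: the hand-written _arange through _zeros bindings were deleted there, presumably because generate_functions now emits equivalent bindings into torch_functions.h and friends.
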

--- /dev/null
+++ ext/torch/nn_functions.h (0.4.1)
@@ -0,0 +1,6 @@
+// generated by rake generate:functions
+// do not edit by hand
+
+#pragma once
+
+void add_nn_functions(Module m);

--- /dev/null
+++ ext/torch/ruby_arg_parser.cpp (0.4.1)
@@ -0,0 +1,593 @@
+// adapted from PyTorch - python_arg_parser.cpp
+
+#include "ruby_arg_parser.h"
+
+VALUE THPVariableClass = Qnil;
+
+static std::unordered_map<std::string, ParameterType> type_map = {
+  {"Tensor", ParameterType::TENSOR},
+  {"Scalar", ParameterType::SCALAR},
+  {"int64_t", ParameterType::INT64},
+  {"double", ParameterType::DOUBLE},
+  {"complex", ParameterType::COMPLEX},
+  {"TensorList", ParameterType::TENSOR_LIST},
+  {"IntArrayRef", ParameterType::INT_LIST},
+  {"ArrayRef<double>", ParameterType::FLOAT_LIST},
+  {"Generator", ParameterType::GENERATOR},
+  {"bool", ParameterType::BOOL},
+  {"Storage", ParameterType::STORAGE},
+  // {"PyObject*", ParameterType::PYOBJECT},
+  {"ScalarType", ParameterType::SCALARTYPE},
+  {"Layout", ParameterType::LAYOUT},
+  {"MemoryFormat", ParameterType::MEMORY_FORMAT},
+  {"QScheme", ParameterType::QSCHEME},
+  {"Device", ParameterType::DEVICE},
+  {"std::string", ParameterType::STRING},
+  {"Dimname", ParameterType::DIMNAME},
+  {"DimnameList", ParameterType::DIMNAME_LIST},
+};
+
+static const std::unordered_map<std::string, std::vector<std::string>> numpy_compatibility_arg_names = {
+  {"dim", {"axis"}},
+  {"keepdim", {"keepdims"}},
+  {"input", {"x", "a", "x1"}},
+  {"other", {"x2"}},
+};
+
+static bool should_allow_numbers_as_tensors(const std::string& name) {
+  static std::unordered_set<std::string> allowed = {
+    "add", "add_", "add_out",
+    "div", "div_", "div_out",
+    "mul", "mul_", "mul_out",
+    "sub", "sub_", "sub_out",
+    "true_divide", "true_divide_", "true_divide_out",
+    "floor_divide", "floor_divide_", "floor_divide_out"
+  };
+  return allowed.find(name) != allowed.end();
+}
+
+FunctionParameter::FunctionParameter(const std::string& fmt, bool keyword_only)
+  : optional(false)
+  , allow_none(false)
+  , keyword_only(keyword_only)
+  , size(0)
+  , default_scalar(0)
+{
+  auto space = fmt.find(' ');
+  if (space == std::string::npos) {
+    throw std::runtime_error("FunctionParameter(): missing type: " + fmt);
+  }
+
+  auto type_str = fmt.substr(0, space);
+
+  auto question = type_str.find('?');
+  if (question != std::string::npos) {
+    allow_none = true;
+    type_str = type_str.substr(0, question);
+  }
+
+  // Parse and remove brackets from type_str
+  auto bracket = type_str.find('[');
+  if (bracket != std::string::npos) {
+    auto size_str = type_str.substr(bracket + 1, type_str.length() - bracket - 2);
+    size = atoi(size_str.c_str());
+    type_str = type_str.substr(0, bracket);
+  }
+
+  auto name_str = fmt.substr(space + 1);
+  auto it = type_map.find(type_str);
+  if (it == type_map.end()) {
+    throw std::runtime_error("FunctionParameter(): invalid type string: " + type_str);
+  }
+  type_ = it->second;
+
+  auto eq = name_str.find('=');
+  if (eq != std::string::npos) {
+    name = name_str.substr(0, eq);
+    optional = true;
+    set_default_str(name_str.substr(eq + 1));
+  } else {
+    name = name_str;
+  }
+  ruby_name = THPUtils_internSymbol(name);
+  auto np_compat_it = numpy_compatibility_arg_names.find(name);
+  if (np_compat_it != numpy_compatibility_arg_names.end()) {
+    for (const auto& str: np_compat_it->second) {
+      numpy_python_names.push_back(THPUtils_internSymbol(str));
+    }
+  }
+}
+
+bool is_tensor_list(VALUE obj, int argnum, bool throw_error) {
+  if (!RB_TYPE_P(obj, T_ARRAY)) {
+    return false;
+  }
+  auto size = RARRAY_LEN(obj);
+  for (int idx = 0; idx < size; idx++) {
+    VALUE iobj = rb_ary_entry(obj, idx);
+    if (!THPVariable_Check(iobj)) {
+      if (throw_error) {
+        rb_raise(rb_eArgError, "expected Tensor as element %d in argument %d, but got %s",
+            static_cast<int>(idx), argnum, rb_obj_classname(obj));
+      }
+      return false;
+    }
+  }
+  return true;
+}
+
+// argnum is needed for raising the TypeError, it's used in the error message.
+auto FunctionParameter::check(VALUE obj, int argnum) -> bool
+{
+  switch (type_) {
+    case ParameterType::TENSOR: {
+      if (THPVariable_Check(obj)) {
+        return true;
+      }
+      return allow_numbers_as_tensors && THPUtils_checkScalar(obj);
+    }
+    case ParameterType::SCALAR:
+    case ParameterType::COMPLEX:
+      if (RB_TYPE_P(obj, T_COMPLEX)) {
+        return true;
+      }
+      // fallthrough
+    case ParameterType::DOUBLE: {
+      if (RB_FLOAT_TYPE_P(obj) || FIXNUM_P(obj)) {
+        return true;
+      }
+      if (THPVariable_Check(obj)) {
+        auto var = from_ruby<torch::Tensor>(obj);
+        return !var.requires_grad() && var.dim() == 0;
+      }
+      return false;
+    }
+    case ParameterType::INT64: {
+      if (FIXNUM_P(obj)) {
+        return true;
+      }
+      if (THPVariable_Check(obj)) {
+        auto var = from_ruby<torch::Tensor>(obj);
+        return at::isIntegralType(var.scalar_type(), /*includeBool=*/false) && !var.requires_grad() && var.dim() == 0;
+      }
+      return false;
+    }
+    case ParameterType::DIMNAME: return false; // return THPUtils_checkDimname(obj);
+    case ParameterType::DIMNAME_LIST: {
+      return false;
+      // if (THPUtils_checkDimnameList(obj)) {
+      //   return true;
+      // }
+      // // if a size is specified (e.g. DimnameList[1]) we also allow passing a single Dimname
+      // return size == 1 && THPUtils_checkDimname(obj);
+    }
+    case ParameterType::TENSOR_LIST: {
+      return is_tensor_list(obj, argnum, true /* throw_error */);
+    }
+    case ParameterType::INT_LIST: {
+      if (RB_TYPE_P(obj, T_ARRAY)) {
+        return true;
+      }
+      // if a size is specified (e.g. IntArrayRef[2]) we also allow passing a single int
+      return size > 0 && FIXNUM_P(obj);
+    }
+    case ParameterType::FLOAT_LIST: return (RB_TYPE_P(obj, T_ARRAY));
+    case ParameterType::GENERATOR: return false; // return THPGenerator_Check(obj);
+    case ParameterType::BOOL: return obj == Qtrue || obj == Qfalse;
+    case ParameterType::STORAGE: return false; // return isStorage(obj);
+    // case ParameterType::PYOBJECT: return true;
+    case ParameterType::SCALARTYPE: return SYMBOL_P(obj);
+    case ParameterType::LAYOUT: return SYMBOL_P(obj);
+    case ParameterType::MEMORY_FORMAT: return false; // return THPMemoryFormat_Check(obj);
+    case ParameterType::QSCHEME: return false; // return THPQScheme_Check(obj);
+    case ParameterType::DEVICE: return RB_TYPE_P(obj, T_STRING); // TODO check device
+    case ParameterType::STRING: return RB_TYPE_P(obj, T_STRING);
+    default: throw std::runtime_error("unknown parameter type");
+  }
+}
+
+std::string FunctionParameter::type_name() const {
+  switch (type_) {
+    case ParameterType::TENSOR: return "Tensor";
+    case ParameterType::SCALAR: return "Number";
+    case ParameterType::INT64: return "int";
+    case ParameterType::DOUBLE: return "float";
+    case ParameterType::COMPLEX: return "complex";
+    case ParameterType::TENSOR_LIST: return "array of Tensors";
+    case ParameterType::INT_LIST: return "array of ints";
+    case ParameterType::FLOAT_LIST: return "array of floats";
+    case ParameterType::GENERATOR: return "torch.Generator";
+    case ParameterType::BOOL: return "bool";
+    case ParameterType::STORAGE: return "torch.Storage";
+    // case ParameterType::PYOBJECT: return "object";
+    case ParameterType::SCALARTYPE: return "torch.dtype";
+    case ParameterType::LAYOUT: return "torch.layout";
+    case ParameterType::MEMORY_FORMAT: return "torch.memory_format";
+    case ParameterType::QSCHEME: return "torch.qscheme";
+    case ParameterType::DEVICE: return "torch.device";
+    case ParameterType::STRING: return "str";
+    case ParameterType::DIMNAME: return "name";
+    case ParameterType::DIMNAME_LIST: return "array of names";
+    default: throw std::runtime_error("unknown parameter type");
+  }
+}
+
+static inline c10::optional<int64_t> parse_as_integer(const std::string& s) {
+  if (s.empty())
+    return c10::nullopt;
+  char *str_end;
+  long ans = strtol(s.c_str(), &str_end, 0);
+  // *str_end == 0 if the entire string was parsed as an integer.
+  return (*str_end == 0) ? c10::optional<int64_t>(ans) : c10::nullopt;
+}
+
+/*
+Parse default value of IntArrayRef declared at native_functions.yaml
+
+There are two kinds of default values:
+1. IntArrayRef[2] x=1 (where size=2, value={1,1}
+2. IntArrayRef x={1,2,3} (where size=3, value={1,2,3}, note that there cannot be space after comma since native_parse.py uses ', ' to split args)
+*/
+static inline std::vector<int64_t> parse_intlist_args(const std::string& s, int64_t size) {
+  size_t n = s.size();
+
+  if (s.empty()) return std::vector<int64_t>();
+
+  // case 1. s is an int (e.g., s=2)
+  if (s[0] != '{') {
+    return std::vector<int64_t>(size, std::stol(s));
+  }
+
+  // case 2. s is a list of dims (e.g., s={1,2})
+
+  // since already checked left brace '{' above, here only checks right brace '}'
+  TORCH_CHECK(s[n - 1] == '}', "Default value of IntArrayRef is missing right brace '}', found ", s[n - 1]);
+
+  auto args = std::vector<int64_t>();
+  std::istringstream ss(s.substr(1, s.length() - 2)); // exclude '{' and '}'
+  std::string tok;
+
+  while(std::getline(ss, tok, ',')) {
+    args.emplace_back(std::stol(tok));
+  }
+  return args;
+}
+
+void FunctionParameter::set_default_str(const std::string& str) {
+  if (str == "None") {
+    allow_none = true;
+  }
+  if (type_ == ParameterType::TENSOR) {
+    if (str != "None") {
+      throw std::runtime_error("default value for Tensor must be none, got: " + str);
+    }
+  } else if (type_ == ParameterType::INT64) {
+    default_int = atol(str.c_str());
+  } else if (type_ == ParameterType::BOOL) {
+    default_bool = (str == "True" || str == "true");
+  } else if (type_ == ParameterType::DOUBLE) {
+    default_double = atof(str.c_str());
+  } else if (type_ == ParameterType::COMPLEX) {
+    default_complex[0] = atof(str.c_str()); // TODO: parse "x + xj"?
+    default_complex[1] = 0;
+  } else if (type_ == ParameterType::SCALAR) {
+    if (str != "None") {
+      // we sometimes rely on integer-vs-float values, e.g. with arange.
+      const auto as_integer = parse_as_integer(str);
+      default_scalar = as_integer.has_value() ? at::Scalar(as_integer.value()) :
+                                                at::Scalar(atof(str.c_str()));
+    }
+  } else if (type_ == ParameterType::INT_LIST) {
+    if (str != "None") {
+      default_intlist = parse_intlist_args(str, size);
+    }
+  } else if (type_ == ParameterType::FLOAT_LIST) {
+    if (str != "None") {
+      throw std::runtime_error("Defaults not supported for float[]");
+    }
+  } else if (type_ == ParameterType::SCALARTYPE) {
+    if (str == "None") {
+      default_scalartype = at::ScalarType::Undefined;
+    } else if (str == "torch.int64") {
+      default_scalartype = at::ScalarType::Long;
+    } else {
+      throw std::runtime_error("invalid default value for ScalarType: " + str);
+    }
+  } else if (type_ == ParameterType::LAYOUT) {
+    if (str == "None") {
+      TORCH_INTERNAL_ASSERT_DEBUG_ONLY(allow_none);
+    } else if (str == "torch.strided") {
+      default_layout = at::Layout::Strided;
+    } else if (str == "torch.sparse_coo") {
+      default_layout = at::Layout::Sparse;
+    } else {
+      throw std::runtime_error("invalid default value for layout: " + str);
+    }
+  } else if (type_ == ParameterType::DEVICE) {
+    if (str != "None") {
+      throw std::runtime_error("invalid device: " + str);
+    }
+  } else if (type_ == ParameterType::STRING) {
+    if (str != "None" && str != "") {
+      throw std::runtime_error("invalid default string: " + str);
+    }
+  }
+}
+
+FunctionSignature::FunctionSignature(const std::string& fmt, int index)
+  : min_args(0)
+  , max_args(0)
+  , max_pos_args(0)
+  , index(index)
+  , hidden(false)
+  , deprecated(false)
+{
+  auto open_paren = fmt.find('(');
+  if (open_paren == std::string::npos) {
+    throw std::runtime_error("missing opening parenthesis: " + fmt);
+  }
+  name = fmt.substr(0, open_paren);
+
+  bool allow_numbers_as_tensors = should_allow_numbers_as_tensors(name);
+
+  auto last_offset = open_paren + 1;
+  auto next_offset = last_offset;
+  bool keyword_only = false;
+  bool done = false;
+  while (!done) {
+    auto offset = fmt.find(", ", last_offset);
+    if (offset == std::string::npos) {
+      offset = fmt.find(')', last_offset);
+      done = true;
+      next_offset = offset + 1;
+      // this 'if' happens for an empty parameter list, i.e. fn().
+      if (offset == last_offset) {
+        last_offset = next_offset;
+        break;
+      }
+    } else {
+      next_offset = offset + 2;
+    }
+    if (offset == std::string::npos) {
+      throw std::runtime_error("missing closing parenthesis: " + fmt);
+    }
+    if (offset == last_offset) {
+      throw std::runtime_error("malformed signature: " + fmt);
+    }
+
+    auto param_str = fmt.substr(last_offset, offset - last_offset);
+    last_offset = next_offset;
+    if (param_str == "*") {
+      keyword_only = true;
+    } else {
+      params.emplace_back(param_str, keyword_only);
+      params.back().allow_numbers_as_tensors = allow_numbers_as_tensors;
+    }
+  }
+
+  if (fmt.substr(last_offset) == "|deprecated") {
+    hidden = true;
+    // TODO: raise warning when parsing deprecated signatures
+    deprecated = true;
+  } else if (fmt.substr(last_offset) == "|hidden") {
+    hidden = true;
+  }
+
+  max_args = params.size();
+
+  // count the number of non-optional args
+  for (auto& param : params) {
+    if (!param.optional) {
+      min_args++;
+    }
+    if (!param.keyword_only) {
+      max_pos_args++;
+    }
+  }
+}
+
+std::string FunctionSignature::toString() const {
+  // TODO: consider printing more proper schema strings with defaults, optionals, etc.
+  std::ostringstream ss;
+  bool keyword_already = false;
+  ss << "(";
+  int i = 0;
+  for (auto& param : params) {
+    if (i != 0) {
+      ss << ", ";
+    }
+    if (param.keyword_only && !keyword_already) {
+      ss << "*, ";
+      keyword_already = true;
+    }
+    ss << param.type_name() << " " << param.name;
+    i++;
+  }
+  ss << ")";
+  return ss.str();
+}
+
+[[noreturn]]
+static void extra_args(const FunctionSignature& signature, ssize_t nargs) {
+  const long max_pos_args = signature.max_pos_args;
+  const long min_args = signature.min_args;
+  const long nargs_ = nargs;
+  if (min_args != max_pos_args) {
+    rb_raise(rb_eArgError, "%s() takes from %ld to %ld positional arguments but %ld were given",
+        signature.name.c_str(), min_args, max_pos_args, nargs_);
+  }
+  rb_raise(rb_eArgError, "%s() takes %ld positional argument%s but %ld %s given",
+      signature.name.c_str(),
+      max_pos_args, max_pos_args == 1 ? "" : "s",
+      nargs_, nargs == 1 ? "was" : "were");
+}
+
+[[noreturn]]
+static void missing_args(const FunctionSignature& signature, int idx) {
+  int num_missing = 0;
+  std::stringstream ss;
+
+  auto& params = signature.params;
+  for (auto it = params.begin() + idx; it != params.end(); ++it) {
+    if (!it->optional) {
+      if (num_missing > 0) {
+        ss << ", ";
+      }
+      ss << '"' << it->name << '"';
+      num_missing++;
+    }
+  }
+
+  rb_raise(rb_eArgError, "%s() missing %d required positional argument%s: %s",
+      signature.name.c_str(),
+      num_missing,
+      num_missing == 1 ? "s" : "",
+      ss.str().c_str());
+}
+
+static ssize_t find_param(FunctionSignature& signature, VALUE name) {
+  ssize_t i = 0;
+  for (auto& param : signature.params) {
+    bool cmp = name == param.ruby_name;
+    if (cmp) {
+      return i;
+    }
+    i++;
+  }
+  return -1;
+}
+
+[[noreturn]]
+static void extra_kwargs(FunctionSignature& signature, VALUE kwargs, ssize_t num_pos_args) {
+  VALUE key;
+
+  VALUE keys = rb_funcall(kwargs, rb_intern("keys"), 0);
+  if (RARRAY_LEN(keys) > 0) {
+    key = rb_ary_entry(keys, 0);
+
+    if (!THPUtils_checkSymbol(key)) {
+      rb_raise(rb_eArgError, "keywords must be symbols, not %s", rb_obj_classname(key));
+    }
+
+    auto param_idx = find_param(signature, key);
+    if (param_idx < 0) {
+      rb_raise(rb_eArgError, "%s() got an unexpected keyword argument '%s'",
+          signature.name.c_str(), THPUtils_unpackSymbol(key).c_str());
+    }
+
+    if (param_idx < num_pos_args) {
+      rb_raise(rb_eArgError, "%s() got multiple values for argument '%s'",
+          signature.name.c_str(), THPUtils_unpackSymbol(key).c_str());
+    }
+  }
+
+  // this should never be hit
+  rb_raise(rb_eArgError, "invalid keyword arguments");
+}
+
+VALUE missing = Qundef;
+
+bool FunctionSignature::parse(VALUE self, VALUE args, VALUE kwargs, std::vector<VALUE> &dst, // NOLINT
+                              bool raise_exception) {
+  auto nargs = NIL_P(args) ? 0 : RARRAY_LEN(args);
+  ssize_t remaining_kwargs = NIL_P(kwargs) ? 0 : RHASH_SIZE(kwargs);
+  ssize_t arg_pos = 0;
+  bool allow_varargs_intlist = false;
+
+  // if there is a single positional IntArrayRef argument, i.e. expand(..), view(...),
+  // allow a var-args style IntArrayRef, so expand(5,3) behaves as expand((5,3))
+  if (max_pos_args == 1 && params[0].type_ == ParameterType::INT_LIST) {
+    allow_varargs_intlist = true;
+  }
+
+  if (nargs > max_pos_args && !allow_varargs_intlist) {
+    if (raise_exception) {
+      // foo() takes takes 2 positional arguments but 3 were given
+      extra_args(*this, nargs);
+    }
+    return false;
+  }
+
+  // if (!overloaded_args.empty()) {
+  //   overloaded_args.clear();
+  // }
+
+  int i = 0;
+  // if (self != nullptr && !THPVariable_CheckExact(self) && check_has_torch_function(self)) {
+  //   append_overloaded_arg(&this->overloaded_args, self);
+  // }
+  for (auto& param : params) {
+    VALUE obj = missing;
+    bool is_kwd = false;
+    if (arg_pos < nargs) {
+      // extra positional args given after single positional IntArrayRef arg
+      if (param.keyword_only) {
+        if (raise_exception) {
+          extra_args(*this, nargs);
+        }
+        return false;
+      }
+      obj = rb_ary_entry(args, arg_pos);
+    } else if (!NIL_P(kwargs)) {
+      obj = rb_hash_lookup2(kwargs, param.ruby_name, missing);
+      // for (VALUE numpy_name: param.numpy_python_names) {
+      //   if (obj) {
+      //     break;
+      //   }
+      //   obj = rb_hash_aref(kwargs, numpy_name);
+      // }
+      is_kwd = true;
+    }
+
+    if ((obj == missing && param.optional) || (NIL_P(obj) && param.allow_none)) {
+      dst[i++] = Qnil;
+    } else if (obj == missing) {
+      if (raise_exception) {
+        // foo() missing 1 required positional argument: "b"
+        missing_args(*this, i);
+      }
+      return false;
+    } else if (param.check(obj, i)) {
+      dst[i++] = obj;
+      // XXX: the Variable check is necessary because sizes become tensors when
+      // tracer is enabled. This behavior easily leads to ambiguities, and we
+      // should avoid having complex signatures that make use of it...
+    } else if (allow_varargs_intlist && arg_pos == 0 && !is_kwd &&
+               THPUtils_checkIndex(obj)) {
+      // take all positional arguments as this parameter
+      // e.g. permute(1, 2, 3) -> permute((1, 2, 3))
+      dst[i++] = args;
+      arg_pos = nargs;
+      continue;
+    } else if (raise_exception) {
+      if (is_kwd) {
+        // foo(): argument 'other' must be str, not int
+        rb_raise(rb_eArgError, "%s(): argument '%s' must be %s, not %s",
+            name.c_str(), param.name.c_str(), param.type_name().c_str(),
+            rb_obj_classname(obj));
+      } else {
+        // foo(): argument 'other' (position 2) must be str, not int
+        rb_raise(rb_eArgError, "%s(): argument '%s' (position %ld) must be %s, not %s",
+            name.c_str(), param.name.c_str(), static_cast<long>(arg_pos + 1),
+            param.type_name().c_str(), rb_obj_classname(obj));
+      }
+    } else {
+      return false;
+    }
+
+    if (!is_kwd) {
+      arg_pos++;
+    } else if (obj != missing) {
+      remaining_kwargs--;
+    }
+  }
+
+  if (remaining_kwargs > 0) {
+    if (raise_exception) {
+      // foo() got an unexpected keyword argument "b"
+      extra_kwargs(*this, kwargs, nargs);
+    }
+    return false;
+  }
+  return true;
+}
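
The parser above is driven by PyTorch-style schema strings. A hedged sketch of the grammar FunctionParameter and FunctionSignature accept, based solely on the constructors in this file (assumes ruby_arg_parser.h from this release is on the include path, and must run inside a Ruby process since the constructors intern Ruby symbols):

// Illustration only: exercising the schema grammar implemented above.
#include "ruby_arg_parser.h"

void sketch() {
  // Signature format: "name(Type param, Type param=default, *, keyword-only params)"
  FunctionSignature sig("add(Tensor other, *, Scalar alpha=1)", 0);
  // -> sig.min_args == 1, sig.max_pos_args == 1, sig.max_args == 2

  // Parameter format: "Type[N]? name=default"
  FunctionParameter p("IntArrayRef[2] stride=1", /*keyword_only=*/false);
  // -> p.size == 2, p.optional == true, p.default_intlist == {1, 1}
}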