torch-rb 0.3.6 → 0.5.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,397 @@
1
+ // adapted from PyTorch - python_arg_parser.h
2
+
3
+ #pragma once
4
+
5
+ #include <torch/torch.h>
6
+ #include <rice/Exception.hpp>
7
+
8
+ #include "templates.h"
9
+ #include "utils.h"
10
+
11
// Kinds of parameter a FunctionSignature can declare. Adapted from PyTorch's
// python_arg_parser.h; the enumerator order is kept so the underlying
// integer values stay the same.
enum class ParameterType {
  TENSOR,
  SCALAR,
  INT64,
  DOUBLE,
  COMPLEX,
  TENSOR_LIST,
  INT_LIST,
  GENERATOR,
  BOOL,
  STORAGE,
  PYOBJECT,
  SCALARTYPE,
  LAYOUT,
  MEMORY_FORMAT,
  DEVICE,
  STRING,
  DIMNAME,
  DIMNAME_LIST,
  QSCHEME,
  FLOAT_LIST
};

+ struct FunctionParameter {
18
+ FunctionParameter(const std::string& fmt, bool keyword_only);
19
+
20
+ bool check(VALUE obj, int argnum);
21
+
22
+ void set_default_str(const std::string& str);
23
+ std::string type_name() const;
24
+
25
+ ParameterType type_;
26
+ bool optional;
27
+ bool allow_none;
28
+ bool keyword_only;
29
+ bool allow_numbers_as_tensors = false;
30
+ int size;
31
+ std::string name;
32
+ VALUE ruby_name;
33
+ at::SmallVector<VALUE, 5> numpy_python_names;
34
+ at::Scalar default_scalar;
35
+ std::vector<int64_t> default_intlist;
36
+ union {
37
+ bool default_bool;
38
+ int64_t default_int;
39
+ double default_double;
40
+ double default_complex[2]; // see Scalar
41
+ at::ScalarType default_scalartype;
42
+ at::Layout default_layout;
43
+ };
44
+ };
45
+
46
+ struct FunctionSignature {
47
+ explicit FunctionSignature(const std::string& fmt, int index);
48
+
49
+ bool parse(VALUE self, VALUE args, VALUE kwargs, std::vector<VALUE>& dst, bool raise_exception);
50
+
51
+ std::string toString() const;
52
+
53
+ std::string name;
54
+ std::vector<FunctionParameter> params;
55
+ // std::vector<py::handle> overloaded_args;
56
+ ssize_t min_args;
57
+ ssize_t max_args;
58
+ ssize_t max_pos_args;
59
+ int index;
60
+ bool hidden;
61
+ bool deprecated;
62
+ bool disable_torch_function;
63
+ };
64
+
65
+ struct RubyArgs {
66
+ RubyArgs(const FunctionSignature& signature, std::vector<VALUE> &args)
67
+ : signature(signature)
68
+ , args(args)
69
+ , idx(signature.index) {}
70
+
71
+ const FunctionSignature& signature;
72
+ std::vector<VALUE> args;
73
+ int idx;
74
+
75
+ inline at::Tensor tensor(int i);
76
+ inline OptionalTensor optionalTensor(int i);
77
+ inline at::Scalar scalar(int i);
78
+ // inline at::Scalar scalarWithDefault(int i, at::Scalar default_scalar);
79
+ inline std::vector<at::Tensor> tensorlist(int i);
80
+ template<int N>
81
+ inline std::array<at::Tensor, N> tensorlist_n(int i);
82
+ inline std::vector<int64_t> intlist(int i);
83
+ // inline c10::OptionalArray<int64_t> intlistOptional(int i);
84
+ // inline std::vector<int64_t> intlistWithDefault(int i, std::vector<int64_t> default_intlist);
85
+ inline c10::optional<at::Generator> generator(int i);
86
+ inline at::Storage storage(int i);
87
+ inline at::ScalarType scalartype(int i);
88
+ // inline at::ScalarType scalartypeWithDefault(int i, at::ScalarType default_scalartype);
89
+ inline c10::optional<at::ScalarType> scalartypeOptional(int i);
90
+ inline c10::optional<at::Scalar> scalarOptional(int i);
91
+ inline c10::optional<int64_t> toInt64Optional(int i);
92
+ inline c10::optional<bool> toBoolOptional(int i);
93
+ inline c10::optional<double> toDoubleOptional(int i);
94
+ inline c10::OptionalArray<double> doublelistOptional(int i);
95
+ // inline at::Layout layout(int i);
96
+ // inline at::Layout layoutWithDefault(int i, at::Layout default_layout);
97
+ inline c10::optional<at::Layout> layoutOptional(int i);
98
+ inline at::Device device(int i);
99
+ // inline at::Device deviceWithDefault(int i, const at::Device& default_device);
100
+ // inline c10::optional<at::Device> deviceOptional(int i);
101
+ // inline at::Dimname dimname(int i);
102
+ // inline std::vector<at::Dimname> dimnamelist(int i);
103
+ // inline c10::optional<std::vector<at::Dimname>> toDimnameListOptional(int i);
104
+ inline at::MemoryFormat memoryformat(int i);
105
+ inline c10::optional<at::MemoryFormat> memoryformatOptional(int i);
106
+ // inline at::QScheme toQScheme(int i);
107
+ inline std::string string(int i);
108
+ inline c10::optional<std::string> stringOptional(int i);
109
+ // inline PyObject* pyobject(int i);
110
+ inline int64_t toInt64(int i);
111
+ // inline int64_t toInt64WithDefault(int i, int64_t default_int);
112
+ inline double toDouble(int i);
113
+ // inline double toDoubleWithDefault(int i, double default_double);
114
+ // inline c10::complex<double> toComplex(int i);
115
+ // inline c10::complex<double> toComplexWithDefault(int i, c10::complex<double> default_complex);
116
+ inline bool toBool(int i);
117
+ // inline bool toBoolWithDefault(int i, bool default_bool);
118
+ inline bool isNone(int i);
119
+ };
120
+
121
+ inline at::Tensor RubyArgs::tensor(int i) {
122
+ return from_ruby<torch::Tensor>(args[i]);
123
+ }
124
+
125
+ inline OptionalTensor RubyArgs::optionalTensor(int i) {
126
+ if (NIL_P(args[i])) return OptionalTensor(Nil);
127
+ return tensor(i);
128
+ }
129
+
130
+ inline at::Scalar RubyArgs::scalar(int i) {
131
+ if (NIL_P(args[i])) return signature.params[i].default_scalar;
132
+ return from_ruby<torch::Scalar>(args[i]);
133
+ }
134
+
135
+ inline std::vector<at::Tensor> RubyArgs::tensorlist(int i) {
136
+ if (NIL_P(args[i])) return std::vector<at::Tensor>();
137
+ return from_ruby<std::vector<Tensor>>(args[i]);
138
+ }
139
+
140
+ template<int N>
141
+ inline std::array<at::Tensor, N> RubyArgs::tensorlist_n(int i) {
142
+ auto res = std::array<at::Tensor, N>();
143
+ if (NIL_P(args[i])) return res;
144
+ VALUE arg = args[i];
145
+ Check_Type(arg, T_ARRAY);
146
+ auto size = RARRAY_LEN(arg);
147
+ if (size != N) {
148
+ rb_raise(rb_eArgError, "expected array of %d elements but got %d", N, (int)size);
149
+ }
150
+ for (int idx = 0; idx < size; idx++) {
151
+ VALUE obj = rb_ary_entry(arg, idx);
152
+ res[idx] = from_ruby<Tensor>(obj);
153
+ }
154
+ return res;
155
+ }
156
+
157
+ inline std::vector<int64_t> RubyArgs::intlist(int i) {
158
+ if (NIL_P(args[i])) return signature.params[i].default_intlist;
159
+
160
+ VALUE arg = args[i];
161
+ auto size = signature.params[i].size;
162
+ if (size > 0 && FIXNUM_P(arg)) {
163
+ return std::vector<int64_t>(size, FIX2INT(arg));
164
+ }
165
+
166
+ size = RARRAY_LEN(arg);
167
+ std::vector<int64_t> res(size);
168
+ for (idx = 0; idx < size; idx++) {
169
+ VALUE obj = rb_ary_entry(arg, idx);
170
+ if (FIXNUM_P(obj)) {
171
+ res[idx] = from_ruby<int64_t>(obj);
172
+ } else {
173
+ rb_raise(rb_eArgError, "%s(): argument '%s' must be %s, but found element of type %s at pos %d",
174
+ signature.name.c_str(), signature.params[i].name.c_str(),
175
+ signature.params[i].type_name().c_str(), rb_obj_classname(obj), idx + 1);
176
+ }
177
+ }
178
+ return res;
179
+ }
180
+
181
+ inline c10::optional<at::Generator> RubyArgs::generator(int i) {
182
+ if (NIL_P(args[i])) return c10::nullopt;
183
+ throw std::runtime_error("generator not supported yet");
184
+ }
185
+
186
+ inline at::Storage RubyArgs::storage(int i) {
187
+ if (NIL_P(args[i])) return at::Storage();
188
+ throw std::runtime_error("storage not supported yet");
189
+ }
190
+
191
+ inline ScalarType RubyArgs::scalartype(int i) {
192
+ if (NIL_P(args[i])) {
193
+ auto scalartype = signature.params[i].default_scalartype;
194
+ return (scalartype == at::ScalarType::Undefined) ? at::typeMetaToScalarType(at::get_default_dtype()) : scalartype;
195
+ }
196
+
197
+ static std::unordered_map<VALUE, ScalarType> dtype_map = {
198
+ {ID2SYM(rb_intern("uint8")), ScalarType::Byte},
199
+ {ID2SYM(rb_intern("int8")), ScalarType::Char},
200
+ {ID2SYM(rb_intern("short")), ScalarType::Short},
201
+ {ID2SYM(rb_intern("int16")), ScalarType::Short},
202
+ {ID2SYM(rb_intern("int")), ScalarType::Int},
203
+ {ID2SYM(rb_intern("int32")), ScalarType::Int},
204
+ {ID2SYM(rb_intern("long")), ScalarType::Long},
205
+ {ID2SYM(rb_intern("int64")), ScalarType::Long},
206
+ {ID2SYM(rb_intern("float")), ScalarType::Float},
207
+ {ID2SYM(rb_intern("float32")), ScalarType::Float},
208
+ {ID2SYM(rb_intern("double")), ScalarType::Double},
209
+ {ID2SYM(rb_intern("float64")), ScalarType::Double},
210
+ {ID2SYM(rb_intern("complex_half")), ScalarType::ComplexHalf},
211
+ {ID2SYM(rb_intern("complex_float")), ScalarType::ComplexFloat},
212
+ {ID2SYM(rb_intern("complex_double")), ScalarType::ComplexDouble},
213
+ {ID2SYM(rb_intern("bool")), ScalarType::Bool},
214
+ {ID2SYM(rb_intern("qint8")), ScalarType::QInt8},
215
+ {ID2SYM(rb_intern("quint8")), ScalarType::QUInt8},
216
+ {ID2SYM(rb_intern("qint32")), ScalarType::QInt32},
217
+ {ID2SYM(rb_intern("bfloat16")), ScalarType::BFloat16},
218
+ };
219
+
220
+ auto it = dtype_map.find(args[i]);
221
+ if (it == dtype_map.end()) {
222
+ rb_raise(rb_eArgError, "invalid dtype: %s", THPUtils_unpackSymbol(args[i]).c_str());
223
+ }
224
+ return it->second;
225
+ }
226
+
227
+ inline c10::optional<ScalarType> RubyArgs::scalartypeOptional(int i) {
228
+ if (NIL_P(args[i])) return c10::nullopt;
229
+ return scalartype(i);
230
+ }
231
+
232
+ inline c10::optional<Scalar> RubyArgs::scalarOptional(int i) {
233
+ if (NIL_P(args[i])) return c10::nullopt;
234
+ return scalar(i);
235
+ }
236
+
237
+ inline c10::optional<int64_t> RubyArgs::toInt64Optional(int i) {
238
+ if (NIL_P(args[i])) return c10::nullopt;
239
+ return toInt64(i);
240
+ }
241
+
242
+ inline c10::optional<bool> RubyArgs::toBoolOptional(int i) {
243
+ if (NIL_P(args[i])) return c10::nullopt;
244
+ return toBool(i);
245
+ }
246
+
247
+ inline c10::optional<double> RubyArgs::toDoubleOptional(int i) {
248
+ if (NIL_P(args[i])) return c10::nullopt;
249
+ return toDouble(i);
250
+ }
251
+
252
+ inline c10::OptionalArray<double> RubyArgs::doublelistOptional(int i) {
253
+ if (NIL_P(args[i])) return {};
254
+
255
+ VALUE arg = args[i];
256
+ auto size = RARRAY_LEN(arg);
257
+ std::vector<double> res(size);
258
+ for (idx = 0; idx < size; idx++) {
259
+ VALUE obj = rb_ary_entry(arg, idx);
260
+ if (FIXNUM_P(obj) || RB_FLOAT_TYPE_P(obj)) {
261
+ res[idx] = from_ruby<double>(obj);
262
+ } else {
263
+ rb_raise(rb_eArgError, "%s(): argument '%s' must be %s, but found element of type %s at pos %d",
264
+ signature.name.c_str(), signature.params[i].name.c_str(),
265
+ signature.params[i].type_name().c_str(), rb_obj_classname(obj), idx + 1);
266
+ }
267
+ }
268
+ return res;
269
+ }
270
+
271
+ inline c10::optional<at::Layout> RubyArgs::layoutOptional(int i) {
272
+ if (NIL_P(args[i])) return c10::nullopt;
273
+
274
+ static std::unordered_map<VALUE, Layout> layout_map = {
275
+ {ID2SYM(rb_intern("strided")), Layout::Strided},
276
+ };
277
+
278
+ auto it = layout_map.find(args[i]);
279
+ if (it == layout_map.end()) {
280
+ rb_raise(rb_eArgError, "invalid layout: %s", THPUtils_unpackSymbol(args[i]).c_str());
281
+ }
282
+ return it->second;
283
+ }
284
+
285
+ inline at::Device RubyArgs::device(int i) {
286
+ if (NIL_P(args[i])) {
287
+ return at::Device("cpu");
288
+ }
289
+ const std::string &device_str = THPUtils_unpackString(args[i]);
290
+ return at::Device(device_str);
291
+ }
292
+
293
+ inline at::MemoryFormat RubyArgs::memoryformat(int i) {
294
+ if (NIL_P(args[i])) return at::MemoryFormat::Contiguous;
295
+ throw std::runtime_error("memoryformat not supported yet");
296
+ }
297
+
298
+ inline c10::optional<at::MemoryFormat> RubyArgs::memoryformatOptional(int i) {
299
+ if (NIL_P(args[i])) return c10::nullopt;
300
+ return memoryformat(i);
301
+ }
302
+
303
+ inline std::string RubyArgs::string(int i) {
304
+ return from_ruby<std::string>(args[i]);
305
+ }
306
+
307
+ inline c10::optional<std::string> RubyArgs::stringOptional(int i) {
308
+ if (!args[i]) return c10::nullopt;
309
+ return from_ruby<std::string>(args[i]);
310
+ }
311
+
312
+ inline int64_t RubyArgs::toInt64(int i) {
313
+ if (NIL_P(args[i])) return signature.params[i].default_int;
314
+ return from_ruby<int64_t>(args[i]);
315
+ }
316
+
317
+ inline double RubyArgs::toDouble(int i) {
318
+ if (NIL_P(args[i])) return signature.params[i].default_double;
319
+ return from_ruby<double>(args[i]);
320
+ }
321
+
322
+ inline bool RubyArgs::toBool(int i) {
323
+ if (NIL_P(args[i])) return signature.params[i].default_bool;
324
+ return RTEST(args[i]);
325
+ }
326
+
327
+ inline bool RubyArgs::isNone(int i) {
328
+ return NIL_P(args[i]);
329
+ }
330
+
331
+ struct RubyArgParser {
332
+ std::vector<FunctionSignature> signatures_;
333
+ std::string function_name;
334
+ ssize_t max_args;
335
+
336
+ public:
337
+ RubyArgParser(std::vector<std::string> fmts) : max_args(0) {
338
+ int index = 0;
339
+ for (auto& fmt : fmts) {
340
+ signatures_.emplace_back(fmt, index);
341
+ ++index;
342
+ }
343
+ for (auto& signature : signatures_) {
344
+ if (signature.max_args > max_args) {
345
+ max_args = signature.max_args;
346
+ }
347
+ }
348
+ if (signatures_.size() > 0) {
349
+ function_name = signatures_[0].name;
350
+ }
351
+
352
+ // Check deprecated signatures last
353
+ std::stable_partition(signatures_.begin(), signatures_.end(),
354
+ [](const FunctionSignature & sig) {
355
+ return !sig.deprecated;
356
+ });
357
+ }
358
+
359
+ RubyArgs parse(VALUE self, int argc, VALUE* argv, std::vector<VALUE> &parsed_args) {
360
+ VALUE args, kwargs;
361
+ rb_scan_args(argc, argv, "*:", &args, &kwargs);
362
+
363
+ if (signatures_.size() == 1) {
364
+ auto& signature = signatures_[0];
365
+ signature.parse(self, args, kwargs, parsed_args, true);
366
+ return RubyArgs(signature, parsed_args);
367
+ }
368
+
369
+ for (auto& signature : signatures_) {
370
+ if (signature.parse(self, args, kwargs, parsed_args, false)) {
371
+ return RubyArgs(signature, parsed_args);
372
+ }
373
+ }
374
+
375
+ print_error(self, args, kwargs, parsed_args);
376
+
377
+ // TODO better message
378
+ rb_raise(rb_eArgError, "No matching signatures");
379
+ }
380
+
381
+ void print_error(VALUE self, VALUE args, VALUE kwargs, std::vector<VALUE>& parsed_args) {
382
+ ssize_t num_args = (NIL_P(args) ? 0 : RARRAY_LEN(args)) + (NIL_P(kwargs) ? 0 : RHASH_SIZE(kwargs));
383
+ std::vector<int> plausible_idxs;
384
+ ssize_t i = 0;
385
+ for (auto& signature : signatures_) {
386
+ if (num_args >= signature.min_args && num_args <= signature.max_args && !signature.hidden) {
387
+ plausible_idxs.push_back(i);
388
+ }
389
+ i++;
390
+ }
391
+
392
+ if (plausible_idxs.size() == 1) {
393
+ auto& signature = signatures_[plausible_idxs[0]];
394
+ signature.parse(self, args, kwargs, parsed_args, true);
395
+ }
396
+ }
397
+ };
@@ -13,49 +13,53 @@ using torch::Device;
13
13
  using torch::Scalar;
14
14
  using torch::ScalarType;
15
15
  using torch::Tensor;
16
+ using torch::QScheme;
17
+ using torch::Generator;
18
+ using torch::TensorOptions;
19
+ using torch::Layout;
20
+ using torch::MemoryFormat;
21
+ using torch::IntArrayRef;
22
+ using torch::ArrayRef;
23
+ using torch::TensorList;
24
+ using torch::Storage;
25
+
26
// HANDLE_TH_ERRORS / END_HANDLE_TH_ERRORS wrap a binding body in try/catch
// and convert C++ exceptions into Ruby exceptions.
+ #define HANDLE_TH_ERRORS \
27
+ try {
28
+
29
// Catch order matters:
//  - torch::Error becomes RuntimeError, with the backtrace omitted;
//  - Rice::Exception is re-raised with its own Ruby class;
//  - any other std::exception becomes RuntimeError.
// NOTE(review): rb_raise longjmps out of the catch handler and bypasses C++
// unwinding, so the caught exception object is presumably never destroyed.
// Consider copying the message and raising after the try/catch -- confirm.
+ #define END_HANDLE_TH_ERRORS \
30
+ } catch (const torch::Error& ex) { \
31
+ rb_raise(rb_eRuntimeError, "%s", ex.what_without_backtrace()); \
32
+ } catch (const Rice::Exception& ex) { \
33
+ rb_raise(ex.class_of(), "%s", ex.what()); \
34
+ } catch (const std::exception& ex) { \
35
+ rb_raise(rb_eRuntimeError, "%s", ex.what()); \
36
+ }
16
37
 
17
// This change removes the wrapper class below, which owned a vector so a
// torch::IntArrayRef view could point at it. The from_ruby specialization
// further down now returns an owning std::vector<int64_t> instead.
- // need to wrap torch::IntArrayRef() since
18
- // it doesn't own underlying data
19
- class IntArrayRef {
20
- std::vector<int64_t> vec;
21
- public:
22
- IntArrayRef(Object o) {
23
- Array a = Array(o);
24
- for (size_t i = 0; i < a.size(); i++) {
25
- vec.push_back(from_ruby<int64_t>(a[i]));
26
- }
27
- }
28
- operator torch::IntArrayRef() {
29
- return torch::IntArrayRef(vec);
30
- }
31
- };
38
// RETURN_NIL expands to `return Qnil;`.
+ #define RETURN_NIL \
39
+ return Qnil;
32
40
 
33
41
  // from_ruby specialization: converts a Ruby Array (via Rice Array) into an
  // owning std::vector<int64_t>, converting each element with
  // from_ruby<int64_t>. It replaces the old IntArrayRef wrapper.
  template<>
34
42
  inline
35
- IntArrayRef from_ruby<IntArrayRef>(Object x)
43
+ std::vector<int64_t> from_ruby<std::vector<int64_t>>(Object x)
36
44
  {
37
- return IntArrayRef(x);
45
+ Array a = Array(x);
46
+ std::vector<int64_t> vec(a.size());
47
+ for (size_t i = 0; i < a.size(); i++) {
48
+ vec[i] = from_ruby<int64_t>(a[i]);
49
+ }
50
+ return vec;
38
51
  }
39
52
 
40
// This change removes the wrapper class below, which collected tensors so a
// torch::TensorList view could point at them. from_ruby<std::vector<Tensor>>
// takes its place.
- class TensorList {
41
- std::vector<torch::Tensor> vec;
42
- public:
43
- TensorList(Object o) {
44
- Array a = Array(o);
45
- for (size_t i = 0; i < a.size(); i++) {
46
- vec.push_back(from_ruby<torch::Tensor>(a[i]));
47
- }
48
- }
49
- operator torch::TensorList() {
50
- return torch::TensorList(vec);
51
- }
52
- };
53
-
54
53
  // from_ruby specialization: converts a Ruby Array (via Rice Array) into a
  // std::vector<Tensor>, converting each element with from_ruby<Tensor>.
  template<>
55
54
  inline
56
- TensorList from_ruby<TensorList>(Object x)
55
+ std::vector<Tensor> from_ruby<std::vector<Tensor>>(Object x)
57
56
  {
58
- return TensorList(x);
57
+ Array a = Array(x);
58
+ std::vector<Tensor> vec(a.size());
59
+ for (size_t i = 0; i < a.size(); i++) {
60
+ vec[i] = from_ruby<Tensor>(a[i]);
61
+ }
62
+ return vec;
59
63
  }
60
64
 
61
65
  class FanModeType {
@@ -124,48 +128,21 @@ NonlinearityType from_ruby<NonlinearityType>(Object x)
124
128
  return NonlinearityType(x);
125
129
  }
126
130
 
127
// This change removes MyReduction, which converted a reduction name to a
// torch::Reduction value. It also rewrites OptionalTensor:
//  - it now stores a torch::Tensor;
//  - a nil Object becomes an empty (default-constructed) tensor at
//    construction time;
//  - a Tensor can be wrapped directly;
//  - the conversion operator simply returns the stored value.
- class MyReduction {
128
- Object value;
131
+ class OptionalTensor {
132
+ torch::Tensor value;
129
133
  public:
130
- MyReduction(Object o) {
131
- value = o;
132
- }
133
- operator int64_t() {
134
- if (value.is_nil()) {
135
- return torch::Reduction::None;
136
- }
137
-
138
- std::string s = String(value).str();
139
- if (s == "mean") {
140
- return torch::Reduction::Mean;
141
- } else if (s == "sum") {
142
- return torch::Reduction::Sum;
143
- } else if (s == "none") {
144
- return torch::Reduction::None;
134
+ OptionalTensor(Object o) {
135
+ if (o.is_nil()) {
136
+ value = {};
145
137
  } else {
146
- throw std::runtime_error("Unsupported reduction: " + s);
138
+ value = from_ruby<torch::Tensor>(o);
147
139
  }
148
140
  }
149
- };
150
-
151
- template<>
152
- inline
153
- MyReduction from_ruby<MyReduction>(Object x)
154
- {
155
- return MyReduction(x);
156
- }
157
-
158
- class OptionalTensor {
159
- Object value;
160
- public:
161
- OptionalTensor(Object o) {
141
+ OptionalTensor(torch::Tensor o) {
162
142
  value = o;
163
143
  }
164
- operator torch::Tensor() {
165
- if (value.is_nil()) {
166
- return {};
167
- }
168
- return from_ruby<torch::Tensor>(value);
144
+ operator torch::Tensor() const {
145
+ return value;
169
146
  }
170
147
  };
171
148
 
@@ -241,11 +218,3 @@ torch::optional<Scalar> from_ruby<torch::optional<Scalar>>(Object x)
241
218
  return torch::optional<Scalar>{from_ruby<Scalar>(x)};
242
219
  }
243
220
  }
244
-
245
// This change removes the wrap() declarations below for tuple and vector
// results from this header. Presumably they are declared elsewhere now;
// confirm.
- Object wrap(std::tuple<torch::Tensor, torch::Tensor> x);
246
- Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x);
247
- Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x);
248
- Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x);
249
- Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x);
250
- Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x);
251
- Object wrap(std::vector<torch::Tensor> x);