torch-rb 0.3.4 → 0.4.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +28 -0
- data/README.md +2 -1
- data/codegen/function.rb +134 -0
- data/codegen/generate_functions.rb +549 -0
- data/{lib/torch/native → codegen}/native_functions.yaml +0 -0
- data/ext/torch/ext.cpp +76 -87
- data/ext/torch/extconf.rb +5 -2
- data/ext/torch/nn_functions.h +6 -0
- data/ext/torch/ruby_arg_parser.cpp +593 -0
- data/ext/torch/ruby_arg_parser.h +373 -0
- data/ext/torch/{templates.hpp → templates.h} +87 -97
- data/ext/torch/tensor_functions.h +6 -0
- data/ext/torch/torch_functions.h +6 -0
- data/ext/torch/utils.h +42 -0
- data/ext/torch/{templates.cpp → wrap_outputs.h} +44 -7
- data/lib/torch.rb +51 -77
- data/lib/torch/nn/functional.rb +142 -18
- data/lib/torch/nn/init.rb +5 -19
- data/lib/torch/nn/leaky_relu.rb +3 -3
- data/lib/torch/nn/module.rb +9 -1
- data/lib/torch/nn/upsample.rb +31 -0
- data/lib/torch/optim/adadelta.rb +1 -1
- data/lib/torch/optim/adam.rb +2 -2
- data/lib/torch/optim/adamax.rb +1 -1
- data/lib/torch/optim/adamw.rb +1 -1
- data/lib/torch/optim/asgd.rb +1 -1
- data/lib/torch/optim/sgd.rb +3 -3
- data/lib/torch/tensor.rb +36 -115
- data/lib/torch/utils/data/data_loader.rb +2 -0
- data/lib/torch/utils/data/tensor_dataset.rb +2 -0
- data/lib/torch/version.rb +1 -1
- metadata +19 -14
- data/lib/torch/native/dispatcher.rb +0 -48
- data/lib/torch/native/function.rb +0 -115
- data/lib/torch/native/generator.rb +0 -163
- data/lib/torch/native/parser.rb +0 -140
@@ -0,0 +1,373 @@
|
|
1
|
+
// adapted from PyTorch - python_arg_parser.h
|
2
|
+
|
3
|
+
#pragma once
|
4
|
+
|
5
|
+
#include <torch/torch.h>
|
6
|
+
#include <rice/Exception.hpp>
|
7
|
+
|
8
|
+
#include "templates.h"
|
9
|
+
#include "utils.h"
|
10
|
+
|
11
|
+
// Tag for each kind of argument a FunctionSignature can declare.
// Adapted from ParameterType in PyTorch's python_arg_parser.h; the
// RubyArgs accessors below convert a parsed VALUE per one of these tags.
enum class ParameterType {
  TENSOR, SCALAR, INT64, DOUBLE, COMPLEX, TENSOR_LIST, INT_LIST, GENERATOR,
  BOOL, STORAGE, PYOBJECT, SCALARTYPE, LAYOUT, MEMORY_FORMAT, DEVICE, STRING,
  DIMNAME, DIMNAME_LIST, QSCHEME, FLOAT_LIST
};
|
16
|
+
|
17
|
+
// One formal parameter of a native-function signature, parsed from a
// format string fragment (implementation in ruby_arg_parser.cpp).
struct FunctionParameter {
  FunctionParameter(const std::string& fmt, bool keyword_only);

  // True when obj is an acceptable value for this parameter;
  // argnum is presumably used for error reporting — confirm in the .cpp.
  bool check(VALUE obj, int argnum);

  // Parses the default-value portion of the format string into the
  // default_* members below.
  void set_default_str(const std::string& str);
  std::string type_name() const;

  ParameterType type_;
  bool optional;
  bool allow_none;
  bool keyword_only;
  bool allow_numbers_as_tensors = false;
  // Declared list length; RubyArgs::intlist expands a single Fixnum to
  // `size` copies when size > 0.
  int size;
  std::string name;
  // NOTE(review): presumably the interned Ruby name used for kwargs
  // lookup — confirm against FunctionSignature::parse.
  VALUE ruby_name;
  at::SmallVector<VALUE, 5> numpy_python_names;
  at::Scalar default_scalar;
  std::vector<int64_t> default_intlist;
  // Only one default is meaningful per parameter, selected by type_.
  union {
    bool default_bool;
    int64_t default_int;
    double default_double;
    double default_complex[2]; // see Scalar
    at::ScalarType default_scalartype;
    at::Layout default_layout;
  };
};
|
45
|
+
|
46
|
+
// A full overload signature: the ordered parameters plus arity metadata
// used by RubyArgParser to select among overloads.
struct FunctionSignature {
  explicit FunctionSignature(const std::string& fmt, int index);

  // Matches args/kwargs against this signature, filling dst with one VALUE
  // per declared parameter. When raise_exception is true a mismatch is
  // reported to Ruby instead of returning false (see RubyArgParser::parse).
  bool parse(VALUE self, VALUE args, VALUE kwargs, std::vector<VALUE>& dst, bool raise_exception);

  std::string toString() const;

  std::string name;
  std::vector<FunctionParameter> params;
  // std::vector<py::handle> overloaded_args;
  ssize_t min_args;
  ssize_t max_args;
  ssize_t max_pos_args;
  // Position among this function's overloads; copied into RubyArgs::idx so
  // generated code can tell which overload matched.
  int index;
  bool hidden;      // hidden signatures are skipped by print_error
  bool deprecated;  // deprecated signatures are tried last (stable_partition)
  bool disable_torch_function;
};
|
64
|
+
|
65
|
+
// The result of a successful parse: the matched signature, the parsed
// VALUEs (one per declared parameter), and typed accessors that convert
// argument i on demand. Commented-out declarations mirror the PyTorch
// original and are not implemented yet.
struct RubyArgs {
  RubyArgs(const FunctionSignature& signature, std::vector<VALUE> &args)
    : signature(signature)
    , args(args)
    , idx(signature.index) {}

  const FunctionSignature& signature;
  std::vector<VALUE> args;
  // Index of the matched overload (FunctionSignature::index).
  int idx;

  inline at::Tensor tensor(int i);
  inline OptionalTensor optionalTensor(int i);
  inline at::Scalar scalar(int i);
  // inline at::Scalar scalarWithDefault(int i, at::Scalar default_scalar);
  inline std::vector<at::Tensor> tensorlist(int i);
  template<int N>
  inline std::array<at::Tensor, N> tensorlist_n(int i);
  inline std::vector<int64_t> intlist(int i);
  // inline c10::OptionalArray<int64_t> intlistOptional(int i);
  // inline std::vector<int64_t> intlistWithDefault(int i, std::vector<int64_t> default_intlist);
  inline c10::optional<at::Generator> generator(int i);
  inline at::Storage storage(int i);
  inline at::ScalarType scalartype(int i);
  // inline at::ScalarType scalartypeWithDefault(int i, at::ScalarType default_scalartype);
  inline c10::optional<at::ScalarType> scalartypeOptional(int i);
  inline c10::optional<at::Scalar> scalarOptional(int i);
  inline c10::optional<int64_t> toInt64Optional(int i);
  inline c10::optional<bool> toBoolOptional(int i);
  inline c10::optional<double> toDoubleOptional(int i);
  // inline c10::OptionalArray<double> doublelistOptional(int i);
  // inline at::Layout layout(int i);
  // inline at::Layout layoutWithDefault(int i, at::Layout default_layout);
  inline c10::optional<at::Layout> layoutOptional(int i);
  inline at::Device device(int i);
  // inline at::Device deviceWithDefault(int i, const at::Device& default_device);
  // inline c10::optional<at::Device> deviceOptional(int i);
  // inline at::Dimname dimname(int i);
  // inline std::vector<at::Dimname> dimnamelist(int i);
  // inline c10::optional<std::vector<at::Dimname>> toDimnameListOptional(int i);
  inline at::MemoryFormat memoryformat(int i);
  inline c10::optional<at::MemoryFormat> memoryformatOptional(int i);
  // inline at::QScheme toQScheme(int i);
  inline std::string string(int i);
  // inline c10::optional<std::string> stringOptional(int i);
  // inline PyObject* pyobject(int i);
  inline int64_t toInt64(int i);
  // inline int64_t toInt64WithDefault(int i, int64_t default_int);
  inline double toDouble(int i);
  // inline double toDoubleWithDefault(int i, double default_double);
  // inline c10::complex<double> toComplex(int i);
  // inline c10::complex<double> toComplexWithDefault(int i, c10::complex<double> default_complex);
  inline bool toBool(int i);
  // inline bool toBoolWithDefault(int i, bool default_bool);
  inline bool isNone(int i);
};
|
120
|
+
|
121
|
+
// Convert parsed argument i to a Tensor.
inline at::Tensor RubyArgs::tensor(int i) {
  VALUE obj = args[i];
  return from_ruby<torch::Tensor>(obj);
}
|
124
|
+
|
125
|
+
// Convert parsed argument i to an OptionalTensor; nil maps to the
// nil-holding OptionalTensor.
inline OptionalTensor RubyArgs::optionalTensor(int i) {
  if (!NIL_P(args[i])) return tensor(i);
  return OptionalTensor(Nil);
}
|
129
|
+
|
130
|
+
// Convert parsed argument i to a Scalar, falling back to the parameter's
// declared default when the argument is nil.
inline at::Scalar RubyArgs::scalar(int i) {
  VALUE obj = args[i];
  if (!NIL_P(obj)) return from_ruby<torch::Scalar>(obj);
  return signature.params[i].default_scalar;
}
|
134
|
+
|
135
|
+
inline std::vector<at::Tensor> RubyArgs::tensorlist(int i) {
|
136
|
+
if (NIL_P(args[i])) return std::vector<at::Tensor>();
|
137
|
+
return from_ruby<std::vector<Tensor>>(args[i]);
|
138
|
+
}
|
139
|
+
|
140
|
+
// Convert parsed argument i to exactly N tensors; nil yields N
// default-constructed tensors. Raises ArgError on a length mismatch.
template<int N>
inline std::array<at::Tensor, N> RubyArgs::tensorlist_n(int i) {
  auto res = std::array<at::Tensor, N>();
  if (NIL_P(args[i])) return res;
  VALUE arg = args[i];
  Check_Type(arg, T_ARRAY);
  auto size = RARRAY_LEN(arg);
  if (size != N) {
    rb_raise(rb_eArgError, "expected array of %d elements but got %d", N, (int)size);
  }
  for (int idx = 0; idx < size; idx++) {
    VALUE obj = rb_ary_entry(arg, idx);
    res[idx] = from_ruby<Tensor>(obj);
  }
  return res;
}
|
156
|
+
|
157
|
+
inline std::vector<int64_t> RubyArgs::intlist(int i) {
|
158
|
+
if (NIL_P(args[i])) return signature.params[i].default_intlist;
|
159
|
+
|
160
|
+
VALUE arg = args[i];
|
161
|
+
auto size = signature.params[i].size;
|
162
|
+
if (size > 0 && FIXNUM_P(arg)) {
|
163
|
+
return std::vector<int64_t>(size, FIX2INT(arg));
|
164
|
+
}
|
165
|
+
|
166
|
+
size = RARRAY_LEN(arg);
|
167
|
+
std::vector<int64_t> res(size);
|
168
|
+
for (idx = 0; idx < size; idx++) {
|
169
|
+
VALUE obj = rb_ary_entry(arg, idx);
|
170
|
+
if (FIXNUM_P(obj)) {
|
171
|
+
res[idx] = from_ruby<int64_t>(obj);
|
172
|
+
} else {
|
173
|
+
rb_raise(rb_eArgError, "%s(): argument '%s' must be %s, but found element of type %s at pos %d",
|
174
|
+
signature.name.c_str(), signature.params[i].name.c_str(),
|
175
|
+
signature.params[i].type_name().c_str(), rb_obj_classname(obj), idx + 1);
|
176
|
+
}
|
177
|
+
}
|
178
|
+
return res;
|
179
|
+
}
|
180
|
+
|
181
|
+
// Generators are not supported yet: nil parses to nullopt, anything else
// is an error.
inline c10::optional<at::Generator> RubyArgs::generator(int i) {
  if (!NIL_P(args[i])) {
    throw std::runtime_error("generator not supported yet");
  }
  return c10::nullopt;
}
|
185
|
+
|
186
|
+
// Storages are not supported yet: nil parses to an empty Storage, anything
// else is an error.
inline at::Storage RubyArgs::storage(int i) {
  if (!NIL_P(args[i])) {
    throw std::runtime_error("storage not supported yet");
  }
  return at::Storage();
}
|
190
|
+
|
191
|
+
// Convert parsed argument i (a Symbol such as :float32) to a ScalarType.
// nil falls back to the parameter's declared default, or to the process-wide
// default dtype when no default was declared.
inline ScalarType RubyArgs::scalartype(int i) {
  if (NIL_P(args[i])) {
    auto scalartype = signature.params[i].default_scalartype;
    return (scalartype == at::ScalarType::Undefined) ? at::typeMetaToScalarType(at::get_default_dtype()) : scalartype;
  }

  // Built once on first call; rb_intern'd symbols are never collected, so
  // caching the VALUE keys across calls is safe.
  static std::unordered_map<VALUE, ScalarType> dtype_map = {
    {ID2SYM(rb_intern("uint8")), ScalarType::Byte},
    {ID2SYM(rb_intern("int8")), ScalarType::Char},
    {ID2SYM(rb_intern("short")), ScalarType::Short},
    {ID2SYM(rb_intern("int16")), ScalarType::Short},
    {ID2SYM(rb_intern("int")), ScalarType::Int},
    {ID2SYM(rb_intern("int32")), ScalarType::Int},
    {ID2SYM(rb_intern("long")), ScalarType::Long},
    {ID2SYM(rb_intern("int64")), ScalarType::Long},
    {ID2SYM(rb_intern("float")), ScalarType::Float},
    {ID2SYM(rb_intern("float32")), ScalarType::Float},
    {ID2SYM(rb_intern("double")), ScalarType::Double},
    {ID2SYM(rb_intern("float64")), ScalarType::Double},
    {ID2SYM(rb_intern("complex_half")), ScalarType::ComplexHalf},
    {ID2SYM(rb_intern("complex_float")), ScalarType::ComplexFloat},
    {ID2SYM(rb_intern("complex_double")), ScalarType::ComplexDouble},
    {ID2SYM(rb_intern("bool")), ScalarType::Bool},
    {ID2SYM(rb_intern("qint8")), ScalarType::QInt8},
    {ID2SYM(rb_intern("quint8")), ScalarType::QUInt8},
    {ID2SYM(rb_intern("qint32")), ScalarType::QInt32},
    {ID2SYM(rb_intern("bfloat16")), ScalarType::BFloat16},
  };

  auto it = dtype_map.find(args[i]);
  if (it == dtype_map.end()) {
    // rb_raise does not return, so falling through is impossible here.
    rb_raise(rb_eArgError, "invalid dtype: %s", THPUtils_unpackSymbol(args[i]).c_str());
  }
  return it->second;
}
|
226
|
+
|
227
|
+
// Like scalartype(i), but nil yields nullopt instead of a default dtype.
inline c10::optional<ScalarType> RubyArgs::scalartypeOptional(int i) {
  if (!isNone(i)) return scalartype(i);
  return c10::nullopt;
}
|
231
|
+
|
232
|
+
// Like scalar(i), but nil yields nullopt instead of the declared default.
inline c10::optional<Scalar> RubyArgs::scalarOptional(int i) {
  if (!isNone(i)) return scalar(i);
  return c10::nullopt;
}
|
236
|
+
|
237
|
+
// Like toInt64(i), but nil yields nullopt instead of the declared default.
inline c10::optional<int64_t> RubyArgs::toInt64Optional(int i) {
  if (!isNone(i)) return toInt64(i);
  return c10::nullopt;
}
|
241
|
+
|
242
|
+
// Like toBool(i), but nil yields nullopt instead of the declared default.
inline c10::optional<bool> RubyArgs::toBoolOptional(int i) {
  if (!isNone(i)) return toBool(i);
  return c10::nullopt;
}
|
246
|
+
|
247
|
+
// Like toDouble(i), but nil yields nullopt instead of the declared default.
inline c10::optional<double> RubyArgs::toDoubleOptional(int i) {
  if (!isNone(i)) return toDouble(i);
  return c10::nullopt;
}
|
251
|
+
|
252
|
+
// Convert parsed argument i (a Symbol) to a Layout, or nullopt for nil.
// Only :strided is recognized; anything else raises ArgError.
inline c10::optional<at::Layout> RubyArgs::layoutOptional(int i) {
  if (NIL_P(args[i])) return c10::nullopt;

  // Built once on first call; rb_intern'd symbols are never collected.
  static std::unordered_map<VALUE, Layout> layout_map = {
    {ID2SYM(rb_intern("strided")), Layout::Strided},
  };

  auto it = layout_map.find(args[i]);
  if (it == layout_map.end()) {
    rb_raise(rb_eArgError, "invalid layout: %s", THPUtils_unpackSymbol(args[i]).c_str());
  }
  return it->second;
}
|
265
|
+
|
266
|
+
// Convert parsed argument i (a device string such as "cuda:0") to a
// Device; nil defaults to the CPU device.
inline at::Device RubyArgs::device(int i) {
  if (NIL_P(args[i])) {
    return at::Device("cpu");
  }
  std::string device_str = THPUtils_unpackString(args[i]);
  return at::Device(device_str);
}
|
273
|
+
|
274
|
+
// Memory formats are not supported yet: nil parses to Contiguous, anything
// else is an error.
inline at::MemoryFormat RubyArgs::memoryformat(int i) {
  if (!NIL_P(args[i])) {
    throw std::runtime_error("memoryformat not supported yet");
  }
  return at::MemoryFormat::Contiguous;
}
|
278
|
+
|
279
|
+
// Like memoryformat(i), but nil yields nullopt instead of Contiguous.
inline c10::optional<at::MemoryFormat> RubyArgs::memoryformatOptional(int i) {
  if (!isNone(i)) return memoryformat(i);
  return c10::nullopt;
}
|
283
|
+
|
284
|
+
inline std::string RubyArgs::string(int i) {
|
285
|
+
return from_ruby<std::string>(args[i]);
|
286
|
+
}
|
287
|
+
|
288
|
+
inline int64_t RubyArgs::toInt64(int i) {
|
289
|
+
if (NIL_P(args[i])) return signature.params[i].default_int;
|
290
|
+
return from_ruby<int64_t>(args[i]);
|
291
|
+
}
|
292
|
+
|
293
|
+
inline double RubyArgs::toDouble(int i) {
|
294
|
+
if (NIL_P(args[i])) return signature.params[i].default_double;
|
295
|
+
return from_ruby<double>(args[i]);
|
296
|
+
}
|
297
|
+
|
298
|
+
inline bool RubyArgs::toBool(int i) {
|
299
|
+
if (NIL_P(args[i])) return signature.params[i].default_bool;
|
300
|
+
return RTEST(args[i]);
|
301
|
+
}
|
302
|
+
|
303
|
+
inline bool RubyArgs::isNone(int i) {
|
304
|
+
return NIL_P(args[i]);
|
305
|
+
}
|
306
|
+
|
307
|
+
// Dispatches a Ruby (argc, argv) call across a set of overload signatures,
// mirroring PyTorch's PythonArgParser. Construct once per native function
// with the format strings of all its overloads, then call parse() per call.
struct RubyArgParser {
  std::vector<FunctionSignature> signatures_;
  std::string function_name;
  ssize_t max_args;

  public:
    RubyArgParser(std::vector<std::string> fmts) : max_args(0) {
      int index = 0;
      for (auto& fmt : fmts) {
        signatures_.emplace_back(fmt, index);
        ++index;
      }
      // max_args = widest arity across all overloads.
      for (auto& signature : signatures_) {
        if (signature.max_args > max_args) {
          max_args = signature.max_args;
        }
      }
      if (signatures_.size() > 0) {
        function_name = signatures_[0].name;
      }

      // Check deprecated signatures last
      std::stable_partition(signatures_.begin(), signatures_.end(),
        [](const FunctionSignature & sig) {
          return !sig.deprecated;
        });
    }

    // Splits argv into positional args and a trailing kwargs hash, then
    // returns the first signature that matches. Raises ArgError if none do.
    RubyArgs parse(VALUE self, int argc, VALUE* argv, std::vector<VALUE> &parsed_args) {
      VALUE args, kwargs;
      rb_scan_args(argc, argv, "*:", &args, &kwargs);

      // Fast path: a single overload parses with raise_exception=true so
      // the error, if any, is specific to that signature.
      if (signatures_.size() == 1) {
        auto& signature = signatures_[0];
        signature.parse(self, args, kwargs, parsed_args, true);
        return RubyArgs(signature, parsed_args);
      }

      // Multiple overloads: first (non-raising) match wins.
      for (auto& signature : signatures_) {
        if (signature.parse(self, args, kwargs, parsed_args, false)) {
          return RubyArgs(signature, parsed_args);
        }
      }

      // No overload matched; print_error may raise a precise error first.
      print_error(self, args, kwargs, parsed_args);

      // TODO better message
      // rb_raise does not return, so no return value is needed below.
      rb_raise(rb_eArgError, "No matching signatures");
    }

    // On total mismatch: if exactly one visible signature is arity-plausible,
    // re-parse it with raise_exception=true to produce a signature-specific
    // ArgError; otherwise fall through to the generic raise in parse().
    void print_error(VALUE self, VALUE args, VALUE kwargs, std::vector<VALUE>& parsed_args) {
      ssize_t num_args = (NIL_P(args) ? 0 : RARRAY_LEN(args)) + (NIL_P(kwargs) ? 0 : RHASH_SIZE(kwargs));
      std::vector<int> plausible_idxs;
      ssize_t i = 0;
      for (auto& signature : signatures_) {
        if (num_args >= signature.min_args && num_args <= signature.max_args && !signature.hidden) {
          plausible_idxs.push_back(i);
        }
        i++;
      }

      if (plausible_idxs.size() == 1) {
        auto& signature = signatures_[plausible_idxs[0]];
        signature.parse(self, args, kwargs, parsed_args, true);
      }
    }
};
|
@@ -10,75 +10,55 @@
|
|
10
10
|
using namespace Rice;
|
11
11
|
|
12
12
|
using torch::Device;
|
13
|
+
using torch::Scalar;
|
13
14
|
using torch::ScalarType;
|
14
15
|
using torch::Tensor;
|
16
|
+
using torch::QScheme;
|
17
|
+
using torch::Generator;
|
18
|
+
using torch::TensorOptions;
|
19
|
+
using torch::Layout;
|
20
|
+
using torch::MemoryFormat;
|
21
|
+
using torch::IntArrayRef;
|
22
|
+
using torch::TensorList;
|
23
|
+
using torch::Storage;
|
24
|
+
|
25
|
+
#define HANDLE_TH_ERRORS \
|
26
|
+
try {
|
27
|
+
|
28
|
+
#define END_HANDLE_TH_ERRORS \
|
29
|
+
} catch (const torch::Error& ex) { \
|
30
|
+
rb_raise(rb_eRuntimeError, "%s", ex.what_without_backtrace()); \
|
31
|
+
} catch (const Rice::Exception& ex) { \
|
32
|
+
rb_raise(ex.class_of(), "%s", ex.what()); \
|
33
|
+
} catch (const std::exception& ex) { \
|
34
|
+
rb_raise(rb_eRuntimeError, "%s", ex.what()); \
|
35
|
+
}
|
15
36
|
|
16
|
-
|
17
|
-
|
18
|
-
class IntArrayRef {
|
19
|
-
std::vector<int64_t> vec;
|
20
|
-
public:
|
21
|
-
IntArrayRef(Object o) {
|
22
|
-
Array a = Array(o);
|
23
|
-
for (size_t i = 0; i < a.size(); i++) {
|
24
|
-
vec.push_back(from_ruby<int64_t>(a[i]));
|
25
|
-
}
|
26
|
-
}
|
27
|
-
operator torch::IntArrayRef() {
|
28
|
-
return torch::IntArrayRef(vec);
|
29
|
-
}
|
30
|
-
};
|
31
|
-
|
32
|
-
template<>
|
33
|
-
inline
|
34
|
-
IntArrayRef from_ruby<IntArrayRef>(Object x)
|
35
|
-
{
|
36
|
-
return IntArrayRef(x);
|
37
|
-
}
|
38
|
-
|
39
|
-
// for now
|
40
|
-
class Scalar {
|
41
|
-
torch::Scalar value;
|
42
|
-
public:
|
43
|
-
Scalar(Object o) {
|
44
|
-
// TODO cast based on Ruby type
|
45
|
-
if (o.rb_type() == T_FIXNUM) {
|
46
|
-
value = torch::Scalar(from_ruby<int64_t>(o));
|
47
|
-
} else {
|
48
|
-
value = torch::Scalar(from_ruby<float>(o));
|
49
|
-
}
|
50
|
-
}
|
51
|
-
operator torch::Scalar() {
|
52
|
-
return value;
|
53
|
-
}
|
54
|
-
};
|
37
|
+
#define RETURN_NIL \
|
38
|
+
return Qnil;
|
55
39
|
|
56
40
|
template<>
|
57
41
|
inline
|
58
|
-
|
42
|
+
std::vector<int64_t> from_ruby<std::vector<int64_t>>(Object x)
|
59
43
|
{
|
60
|
-
|
44
|
+
Array a = Array(x);
|
45
|
+
std::vector<int64_t> vec(a.size());
|
46
|
+
for (size_t i = 0; i < a.size(); i++) {
|
47
|
+
vec[i] = from_ruby<int64_t>(a[i]);
|
48
|
+
}
|
49
|
+
return vec;
|
61
50
|
}
|
62
51
|
|
63
|
-
class TensorList {
|
64
|
-
std::vector<torch::Tensor> vec;
|
65
|
-
public:
|
66
|
-
TensorList(Object o) {
|
67
|
-
Array a = Array(o);
|
68
|
-
for (size_t i = 0; i < a.size(); i++) {
|
69
|
-
vec.push_back(from_ruby<torch::Tensor>(a[i]));
|
70
|
-
}
|
71
|
-
}
|
72
|
-
operator torch::TensorList() {
|
73
|
-
return torch::TensorList(vec);
|
74
|
-
}
|
75
|
-
};
|
76
|
-
|
77
52
|
template<>
|
78
53
|
inline
|
79
|
-
|
54
|
+
std::vector<Tensor> from_ruby<std::vector<Tensor>>(Object x)
|
80
55
|
{
|
81
|
-
|
56
|
+
Array a = Array(x);
|
57
|
+
std::vector<Tensor> vec(a.size());
|
58
|
+
for (size_t i = 0; i < a.size(); i++) {
|
59
|
+
vec[i] = from_ruby<Tensor>(a[i]);
|
60
|
+
}
|
61
|
+
return vec;
|
82
62
|
}
|
83
63
|
|
84
64
|
class FanModeType {
|
@@ -147,51 +127,35 @@ NonlinearityType from_ruby<NonlinearityType>(Object x)
|
|
147
127
|
return NonlinearityType(x);
|
148
128
|
}
|
149
129
|
|
150
|
-
class
|
151
|
-
|
130
|
+
class OptionalTensor {
|
131
|
+
torch::Tensor value;
|
152
132
|
public:
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
operator int64_t() {
|
157
|
-
if (value.is_nil()) {
|
158
|
-
return torch::Reduction::None;
|
159
|
-
}
|
160
|
-
|
161
|
-
std::string s = String(value).str();
|
162
|
-
if (s == "mean") {
|
163
|
-
return torch::Reduction::Mean;
|
164
|
-
} else if (s == "sum") {
|
165
|
-
return torch::Reduction::Sum;
|
166
|
-
} else if (s == "none") {
|
167
|
-
return torch::Reduction::None;
|
133
|
+
OptionalTensor(Object o) {
|
134
|
+
if (o.is_nil()) {
|
135
|
+
value = {};
|
168
136
|
} else {
|
169
|
-
|
137
|
+
value = from_ruby<torch::Tensor>(o);
|
170
138
|
}
|
171
139
|
}
|
140
|
+
OptionalTensor(torch::Tensor o) {
|
141
|
+
value = o;
|
142
|
+
}
|
143
|
+
operator torch::Tensor() const {
|
144
|
+
return value;
|
145
|
+
}
|
172
146
|
};
|
173
147
|
|
174
148
|
template<>
|
175
149
|
inline
|
176
|
-
|
150
|
+
Scalar from_ruby<Scalar>(Object x)
|
177
151
|
{
|
178
|
-
|
152
|
+
if (x.rb_type() == T_FIXNUM) {
|
153
|
+
return torch::Scalar(from_ruby<int64_t>(x));
|
154
|
+
} else {
|
155
|
+
return torch::Scalar(from_ruby<double>(x));
|
156
|
+
}
|
179
157
|
}
|
180
158
|
|
181
|
-
class OptionalTensor {
|
182
|
-
Object value;
|
183
|
-
public:
|
184
|
-
OptionalTensor(Object o) {
|
185
|
-
value = o;
|
186
|
-
}
|
187
|
-
operator torch::Tensor() {
|
188
|
-
if (value.is_nil()) {
|
189
|
-
return {};
|
190
|
-
}
|
191
|
-
return from_ruby<torch::Tensor>(value);
|
192
|
-
}
|
193
|
-
};
|
194
|
-
|
195
159
|
template<>
|
196
160
|
inline
|
197
161
|
OptionalTensor from_ruby<OptionalTensor>(Object x)
|
@@ -221,9 +185,35 @@ torch::optional<int64_t> from_ruby<torch::optional<int64_t>>(Object x)
|
|
221
185
|
}
|
222
186
|
}
|
223
187
|
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
|
188
|
+
template<>
|
189
|
+
inline
|
190
|
+
torch::optional<double> from_ruby<torch::optional<double>>(Object x)
|
191
|
+
{
|
192
|
+
if (x.is_nil()) {
|
193
|
+
return torch::nullopt;
|
194
|
+
} else {
|
195
|
+
return torch::optional<double>{from_ruby<double>(x)};
|
196
|
+
}
|
197
|
+
}
|
198
|
+
|
199
|
+
template<>
|
200
|
+
inline
|
201
|
+
torch::optional<bool> from_ruby<torch::optional<bool>>(Object x)
|
202
|
+
{
|
203
|
+
if (x.is_nil()) {
|
204
|
+
return torch::nullopt;
|
205
|
+
} else {
|
206
|
+
return torch::optional<bool>{from_ruby<bool>(x)};
|
207
|
+
}
|
208
|
+
}
|
209
|
+
|
210
|
+
template<>
|
211
|
+
inline
|
212
|
+
torch::optional<Scalar> from_ruby<torch::optional<Scalar>>(Object x)
|
213
|
+
{
|
214
|
+
if (x.is_nil()) {
|
215
|
+
return torch::nullopt;
|
216
|
+
} else {
|
217
|
+
return torch::optional<Scalar>{from_ruby<Scalar>(x)};
|
218
|
+
}
|
219
|
+
}
|