torch-rb 0.3.4 → 0.4.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +28 -0
- data/README.md +2 -1
- data/codegen/function.rb +134 -0
- data/codegen/generate_functions.rb +549 -0
- data/{lib/torch/native → codegen}/native_functions.yaml +0 -0
- data/ext/torch/ext.cpp +76 -87
- data/ext/torch/extconf.rb +5 -2
- data/ext/torch/nn_functions.h +6 -0
- data/ext/torch/ruby_arg_parser.cpp +593 -0
- data/ext/torch/ruby_arg_parser.h +373 -0
- data/ext/torch/{templates.hpp → templates.h} +87 -97
- data/ext/torch/tensor_functions.h +6 -0
- data/ext/torch/torch_functions.h +6 -0
- data/ext/torch/utils.h +42 -0
- data/ext/torch/{templates.cpp → wrap_outputs.h} +44 -7
- data/lib/torch.rb +51 -77
- data/lib/torch/nn/functional.rb +142 -18
- data/lib/torch/nn/init.rb +5 -19
- data/lib/torch/nn/leaky_relu.rb +3 -3
- data/lib/torch/nn/module.rb +9 -1
- data/lib/torch/nn/upsample.rb +31 -0
- data/lib/torch/optim/adadelta.rb +1 -1
- data/lib/torch/optim/adam.rb +2 -2
- data/lib/torch/optim/adamax.rb +1 -1
- data/lib/torch/optim/adamw.rb +1 -1
- data/lib/torch/optim/asgd.rb +1 -1
- data/lib/torch/optim/sgd.rb +3 -3
- data/lib/torch/tensor.rb +36 -115
- data/lib/torch/utils/data/data_loader.rb +2 -0
- data/lib/torch/utils/data/tensor_dataset.rb +2 -0
- data/lib/torch/version.rb +1 -1
- metadata +19 -14
- data/lib/torch/native/dispatcher.rb +0 -48
- data/lib/torch/native/function.rb +0 -115
- data/lib/torch/native/generator.rb +0 -163
- data/lib/torch/native/parser.rb +0 -140
@@ -0,0 +1,373 @@
|
|
1
|
+
// adapted from PyTorch - python_arg_parser.h
|
2
|
+
|
3
|
+
#pragma once
|
4
|
+
|
5
|
+
#include <torch/torch.h>
|
6
|
+
#include <rice/Exception.hpp>
|
7
|
+
|
8
|
+
#include "templates.h"
|
9
|
+
#include "utils.h"
|
10
|
+
|
11
|
+
// Kinds of parameter that a signature format string can declare.
// Adapted from PyTorch's python_arg_parser.h, which is why a few kinds
// (e.g. PYOBJECT) keep their Python-flavoured names.
enum class ParameterType {
  TENSOR,
  SCALAR,
  INT64,
  DOUBLE,
  COMPLEX,
  TENSOR_LIST,
  INT_LIST,
  GENERATOR,
  BOOL,
  STORAGE,
  PYOBJECT,
  SCALARTYPE,
  LAYOUT,
  MEMORY_FORMAT,
  DEVICE,
  STRING,
  DIMNAME,
  DIMNAME_LIST,
  QSCHEME,
  FLOAT_LIST
};
|
16
|
+
|
17
|
+
struct FunctionParameter {
|
18
|
+
FunctionParameter(const std::string& fmt, bool keyword_only);
|
19
|
+
|
20
|
+
bool check(VALUE obj, int argnum);
|
21
|
+
|
22
|
+
void set_default_str(const std::string& str);
|
23
|
+
std::string type_name() const;
|
24
|
+
|
25
|
+
ParameterType type_;
|
26
|
+
bool optional;
|
27
|
+
bool allow_none;
|
28
|
+
bool keyword_only;
|
29
|
+
bool allow_numbers_as_tensors = false;
|
30
|
+
int size;
|
31
|
+
std::string name;
|
32
|
+
VALUE ruby_name;
|
33
|
+
at::SmallVector<VALUE, 5> numpy_python_names;
|
34
|
+
at::Scalar default_scalar;
|
35
|
+
std::vector<int64_t> default_intlist;
|
36
|
+
union {
|
37
|
+
bool default_bool;
|
38
|
+
int64_t default_int;
|
39
|
+
double default_double;
|
40
|
+
double default_complex[2]; // see Scalar
|
41
|
+
at::ScalarType default_scalartype;
|
42
|
+
at::Layout default_layout;
|
43
|
+
};
|
44
|
+
};
|
45
|
+
|
46
|
+
struct FunctionSignature {
|
47
|
+
explicit FunctionSignature(const std::string& fmt, int index);
|
48
|
+
|
49
|
+
bool parse(VALUE self, VALUE args, VALUE kwargs, std::vector<VALUE>& dst, bool raise_exception);
|
50
|
+
|
51
|
+
std::string toString() const;
|
52
|
+
|
53
|
+
std::string name;
|
54
|
+
std::vector<FunctionParameter> params;
|
55
|
+
// std::vector<py::handle> overloaded_args;
|
56
|
+
ssize_t min_args;
|
57
|
+
ssize_t max_args;
|
58
|
+
ssize_t max_pos_args;
|
59
|
+
int index;
|
60
|
+
bool hidden;
|
61
|
+
bool deprecated;
|
62
|
+
bool disable_torch_function;
|
63
|
+
};
|
64
|
+
|
65
|
+
struct RubyArgs {
|
66
|
+
RubyArgs(const FunctionSignature& signature, std::vector<VALUE> &args)
|
67
|
+
: signature(signature)
|
68
|
+
, args(args)
|
69
|
+
, idx(signature.index) {}
|
70
|
+
|
71
|
+
const FunctionSignature& signature;
|
72
|
+
std::vector<VALUE> args;
|
73
|
+
int idx;
|
74
|
+
|
75
|
+
inline at::Tensor tensor(int i);
|
76
|
+
inline OptionalTensor optionalTensor(int i);
|
77
|
+
inline at::Scalar scalar(int i);
|
78
|
+
// inline at::Scalar scalarWithDefault(int i, at::Scalar default_scalar);
|
79
|
+
inline std::vector<at::Tensor> tensorlist(int i);
|
80
|
+
template<int N>
|
81
|
+
inline std::array<at::Tensor, N> tensorlist_n(int i);
|
82
|
+
inline std::vector<int64_t> intlist(int i);
|
83
|
+
// inline c10::OptionalArray<int64_t> intlistOptional(int i);
|
84
|
+
// inline std::vector<int64_t> intlistWithDefault(int i, std::vector<int64_t> default_intlist);
|
85
|
+
inline c10::optional<at::Generator> generator(int i);
|
86
|
+
inline at::Storage storage(int i);
|
87
|
+
inline at::ScalarType scalartype(int i);
|
88
|
+
// inline at::ScalarType scalartypeWithDefault(int i, at::ScalarType default_scalartype);
|
89
|
+
inline c10::optional<at::ScalarType> scalartypeOptional(int i);
|
90
|
+
inline c10::optional<at::Scalar> scalarOptional(int i);
|
91
|
+
inline c10::optional<int64_t> toInt64Optional(int i);
|
92
|
+
inline c10::optional<bool> toBoolOptional(int i);
|
93
|
+
inline c10::optional<double> toDoubleOptional(int i);
|
94
|
+
// inline c10::OptionalArray<double> doublelistOptional(int i);
|
95
|
+
// inline at::Layout layout(int i);
|
96
|
+
// inline at::Layout layoutWithDefault(int i, at::Layout default_layout);
|
97
|
+
inline c10::optional<at::Layout> layoutOptional(int i);
|
98
|
+
inline at::Device device(int i);
|
99
|
+
// inline at::Device deviceWithDefault(int i, const at::Device& default_device);
|
100
|
+
// inline c10::optional<at::Device> deviceOptional(int i);
|
101
|
+
// inline at::Dimname dimname(int i);
|
102
|
+
// inline std::vector<at::Dimname> dimnamelist(int i);
|
103
|
+
// inline c10::optional<std::vector<at::Dimname>> toDimnameListOptional(int i);
|
104
|
+
inline at::MemoryFormat memoryformat(int i);
|
105
|
+
inline c10::optional<at::MemoryFormat> memoryformatOptional(int i);
|
106
|
+
// inline at::QScheme toQScheme(int i);
|
107
|
+
inline std::string string(int i);
|
108
|
+
// inline c10::optional<std::string> stringOptional(int i);
|
109
|
+
// inline PyObject* pyobject(int i);
|
110
|
+
inline int64_t toInt64(int i);
|
111
|
+
// inline int64_t toInt64WithDefault(int i, int64_t default_int);
|
112
|
+
inline double toDouble(int i);
|
113
|
+
// inline double toDoubleWithDefault(int i, double default_double);
|
114
|
+
// inline c10::complex<double> toComplex(int i);
|
115
|
+
// inline c10::complex<double> toComplexWithDefault(int i, c10::complex<double> default_complex);
|
116
|
+
inline bool toBool(int i);
|
117
|
+
// inline bool toBoolWithDefault(int i, bool default_bool);
|
118
|
+
inline bool isNone(int i);
|
119
|
+
};
|
120
|
+
|
121
|
+
// Converts argument i to a tensor through Rice's from_ruby; nil is not
// special-cased here.
inline at::Tensor RubyArgs::tensor(int i) {
  VALUE value = args[i];
  return from_ruby<torch::Tensor>(value);
}
|
124
|
+
|
125
|
+
// Like tensor(), but nil yields an OptionalTensor built from Rice's Nil.
inline OptionalTensor RubyArgs::optionalTensor(int i) {
  return NIL_P(args[i]) ? OptionalTensor(Nil) : OptionalTensor(tensor(i));
}
|
129
|
+
|
130
|
+
// Converts argument i to an at::Scalar; nil yields the parameter's
// default_scalar.
inline at::Scalar RubyArgs::scalar(int i) {
  if (!NIL_P(args[i])) {
    return from_ruby<torch::Scalar>(args[i]);
  }
  return signature.params[i].default_scalar;
}
|
134
|
+
|
135
|
+
inline std::vector<at::Tensor> RubyArgs::tensorlist(int i) {
|
136
|
+
if (NIL_P(args[i])) return std::vector<at::Tensor>();
|
137
|
+
return from_ruby<std::vector<Tensor>>(args[i]);
|
138
|
+
}
|
139
|
+
|
140
|
+
template<int N>
|
141
|
+
inline std::array<at::Tensor, N> RubyArgs::tensorlist_n(int i) {
|
142
|
+
auto res = std::array<at::Tensor, N>();
|
143
|
+
if (NIL_P(args[i])) return res;
|
144
|
+
VALUE arg = args[i];
|
145
|
+
Check_Type(arg, T_ARRAY);
|
146
|
+
auto size = RARRAY_LEN(arg);
|
147
|
+
if (size != N) {
|
148
|
+
rb_raise(rb_eArgError, "expected array of %d elements but got %d", N, (int)size);
|
149
|
+
}
|
150
|
+
for (int idx = 0; idx < size; idx++) {
|
151
|
+
VALUE obj = rb_ary_entry(arg, idx);
|
152
|
+
res[idx] = from_ruby<Tensor>(obj);
|
153
|
+
}
|
154
|
+
return res;
|
155
|
+
}
|
156
|
+
|
157
|
+
inline std::vector<int64_t> RubyArgs::intlist(int i) {
|
158
|
+
if (NIL_P(args[i])) return signature.params[i].default_intlist;
|
159
|
+
|
160
|
+
VALUE arg = args[i];
|
161
|
+
auto size = signature.params[i].size;
|
162
|
+
if (size > 0 && FIXNUM_P(arg)) {
|
163
|
+
return std::vector<int64_t>(size, FIX2INT(arg));
|
164
|
+
}
|
165
|
+
|
166
|
+
size = RARRAY_LEN(arg);
|
167
|
+
std::vector<int64_t> res(size);
|
168
|
+
for (idx = 0; idx < size; idx++) {
|
169
|
+
VALUE obj = rb_ary_entry(arg, idx);
|
170
|
+
if (FIXNUM_P(obj)) {
|
171
|
+
res[idx] = from_ruby<int64_t>(obj);
|
172
|
+
} else {
|
173
|
+
rb_raise(rb_eArgError, "%s(): argument '%s' must be %s, but found element of type %s at pos %d",
|
174
|
+
signature.name.c_str(), signature.params[i].name.c_str(),
|
175
|
+
signature.params[i].type_name().c_str(), rb_obj_classname(obj), idx + 1);
|
176
|
+
}
|
177
|
+
}
|
178
|
+
return res;
|
179
|
+
}
|
180
|
+
|
181
|
+
// Generators are not supported yet: nil maps to nullopt, and anything else
// throws std::runtime_error.
inline c10::optional<at::Generator> RubyArgs::generator(int i) {
  if (!NIL_P(args[i])) {
    throw std::runtime_error("generator not supported yet");
  }
  return c10::nullopt;
}
|
185
|
+
|
186
|
+
// Storage arguments are not supported yet: nil maps to an empty at::Storage,
// and anything else throws std::runtime_error.
inline at::Storage RubyArgs::storage(int i) {
  if (!NIL_P(args[i])) {
    throw std::runtime_error("storage not supported yet");
  }
  return at::Storage();
}
|
190
|
+
|
191
|
+
// Converts argument i (a dtype Symbol such as :float32 or :int64) to an
// at::ScalarType.
// For nil: returns the parameter's default_scalartype, or the global default
// dtype when that default is Undefined. For an unknown value: raises
// ArgumentError.
inline ScalarType RubyArgs::scalartype(int i) {
  VALUE value = args[i];
  if (NIL_P(value)) {
    const auto fallback = signature.params[i].default_scalartype;
    if (fallback != at::ScalarType::Undefined) {
      return fallback;
    }
    return at::typeMetaToScalarType(at::get_default_dtype());
  }

  // Built once, on the first non-nil call; keys are Symbol VALUEs.
  static std::unordered_map<VALUE, ScalarType> dtype_map = {
    {ID2SYM(rb_intern("uint8")), ScalarType::Byte},
    {ID2SYM(rb_intern("int8")), ScalarType::Char},
    {ID2SYM(rb_intern("short")), ScalarType::Short},
    {ID2SYM(rb_intern("int16")), ScalarType::Short},
    {ID2SYM(rb_intern("int")), ScalarType::Int},
    {ID2SYM(rb_intern("int32")), ScalarType::Int},
    {ID2SYM(rb_intern("long")), ScalarType::Long},
    {ID2SYM(rb_intern("int64")), ScalarType::Long},
    {ID2SYM(rb_intern("float")), ScalarType::Float},
    {ID2SYM(rb_intern("float32")), ScalarType::Float},
    {ID2SYM(rb_intern("double")), ScalarType::Double},
    {ID2SYM(rb_intern("float64")), ScalarType::Double},
    {ID2SYM(rb_intern("complex_half")), ScalarType::ComplexHalf},
    {ID2SYM(rb_intern("complex_float")), ScalarType::ComplexFloat},
    {ID2SYM(rb_intern("complex_double")), ScalarType::ComplexDouble},
    {ID2SYM(rb_intern("bool")), ScalarType::Bool},
    {ID2SYM(rb_intern("qint8")), ScalarType::QInt8},
    {ID2SYM(rb_intern("quint8")), ScalarType::QUInt8},
    {ID2SYM(rb_intern("qint32")), ScalarType::QInt32},
    {ID2SYM(rb_intern("bfloat16")), ScalarType::BFloat16},
  };

  const auto found = dtype_map.find(value);
  if (found != dtype_map.end()) {
    return found->second;
  }
  rb_raise(rb_eArgError, "invalid dtype: %s", THPUtils_unpackSymbol(value).c_str());
}
|
226
|
+
|
227
|
+
// Optional dtype: nil maps to nullopt (not to a default); otherwise this
// delegates to scalartype().
inline c10::optional<ScalarType> RubyArgs::scalartypeOptional(int i) {
  if (!NIL_P(args[i])) {
    return scalartype(i);
  }
  return c10::nullopt;
}
|
231
|
+
|
232
|
+
// Optional scalar: nil maps to nullopt; otherwise this delegates to scalar().
inline c10::optional<Scalar> RubyArgs::scalarOptional(int i) {
  if (!NIL_P(args[i])) {
    return scalar(i);
  }
  return c10::nullopt;
}
|
236
|
+
|
237
|
+
// Optional integer: nil maps to nullopt; otherwise this delegates to toInt64().
inline c10::optional<int64_t> RubyArgs::toInt64Optional(int i) {
  if (!NIL_P(args[i])) {
    return toInt64(i);
  }
  return c10::nullopt;
}
|
241
|
+
|
242
|
+
// Optional boolean: nil maps to nullopt; otherwise this delegates to toBool().
inline c10::optional<bool> RubyArgs::toBoolOptional(int i) {
  if (!NIL_P(args[i])) {
    return toBool(i);
  }
  return c10::nullopt;
}
|
246
|
+
|
247
|
+
// Optional float: nil maps to nullopt; otherwise this delegates to toDouble().
inline c10::optional<double> RubyArgs::toDoubleOptional(int i) {
  if (!NIL_P(args[i])) {
    return toDouble(i);
  }
  return c10::nullopt;
}
|
251
|
+
|
252
|
+
// Optional layout given as a Symbol. nil maps to nullopt, :strided maps to
// Layout::Strided, and any other value raises ArgumentError.
inline c10::optional<at::Layout> RubyArgs::layoutOptional(int i) {
  VALUE value = args[i];
  if (NIL_P(value)) {
    return c10::nullopt;
  }

  // Built once, on the first non-nil call.
  static std::unordered_map<VALUE, Layout> layout_map = {
    {ID2SYM(rb_intern("strided")), Layout::Strided},
  };

  const auto found = layout_map.find(value);
  if (found != layout_map.end()) {
    return found->second;
  }
  rb_raise(rb_eArgError, "invalid layout: %s", THPUtils_unpackSymbol(value).c_str());
}
|
265
|
+
|
266
|
+
// Converts argument i (a device String such as "cpu" or "cuda:0") to an
// at::Device; nil defaults to the CPU device.
inline at::Device RubyArgs::device(int i) {
  if (!NIL_P(args[i])) {
    const std::string& spec = THPUtils_unpackString(args[i]);
    return at::Device(spec);
  }
  return at::Device("cpu");
}
|
273
|
+
|
274
|
+
// Memory formats are not supported yet: nil maps to MemoryFormat::Contiguous,
// and anything else throws std::runtime_error.
inline at::MemoryFormat RubyArgs::memoryformat(int i) {
  if (!NIL_P(args[i])) {
    throw std::runtime_error("memoryformat not supported yet");
  }
  return at::MemoryFormat::Contiguous;
}
|
278
|
+
|
279
|
+
// Optional memory format: nil maps to nullopt; otherwise this delegates to
// memoryformat().
inline c10::optional<at::MemoryFormat> RubyArgs::memoryformatOptional(int i) {
  if (!NIL_P(args[i])) {
    return memoryformat(i);
  }
  return c10::nullopt;
}
|
283
|
+
|
284
|
+
inline std::string RubyArgs::string(int i) {
|
285
|
+
return from_ruby<std::string>(args[i]);
|
286
|
+
}
|
287
|
+
|
288
|
+
inline int64_t RubyArgs::toInt64(int i) {
|
289
|
+
if (NIL_P(args[i])) return signature.params[i].default_int;
|
290
|
+
return from_ruby<int64_t>(args[i]);
|
291
|
+
}
|
292
|
+
|
293
|
+
inline double RubyArgs::toDouble(int i) {
|
294
|
+
if (NIL_P(args[i])) return signature.params[i].default_double;
|
295
|
+
return from_ruby<double>(args[i]);
|
296
|
+
}
|
297
|
+
|
298
|
+
inline bool RubyArgs::toBool(int i) {
|
299
|
+
if (NIL_P(args[i])) return signature.params[i].default_bool;
|
300
|
+
return RTEST(args[i]);
|
301
|
+
}
|
302
|
+
|
303
|
+
inline bool RubyArgs::isNone(int i) {
|
304
|
+
return NIL_P(args[i]);
|
305
|
+
}
|
306
|
+
|
307
|
+
struct RubyArgParser {
|
308
|
+
std::vector<FunctionSignature> signatures_;
|
309
|
+
std::string function_name;
|
310
|
+
ssize_t max_args;
|
311
|
+
|
312
|
+
public:
|
313
|
+
RubyArgParser(std::vector<std::string> fmts) : max_args(0) {
|
314
|
+
int index = 0;
|
315
|
+
for (auto& fmt : fmts) {
|
316
|
+
signatures_.emplace_back(fmt, index);
|
317
|
+
++index;
|
318
|
+
}
|
319
|
+
for (auto& signature : signatures_) {
|
320
|
+
if (signature.max_args > max_args) {
|
321
|
+
max_args = signature.max_args;
|
322
|
+
}
|
323
|
+
}
|
324
|
+
if (signatures_.size() > 0) {
|
325
|
+
function_name = signatures_[0].name;
|
326
|
+
}
|
327
|
+
|
328
|
+
// Check deprecated signatures last
|
329
|
+
std::stable_partition(signatures_.begin(), signatures_.end(),
|
330
|
+
[](const FunctionSignature & sig) {
|
331
|
+
return !sig.deprecated;
|
332
|
+
});
|
333
|
+
}
|
334
|
+
|
335
|
+
RubyArgs parse(VALUE self, int argc, VALUE* argv, std::vector<VALUE> &parsed_args) {
|
336
|
+
VALUE args, kwargs;
|
337
|
+
rb_scan_args(argc, argv, "*:", &args, &kwargs);
|
338
|
+
|
339
|
+
if (signatures_.size() == 1) {
|
340
|
+
auto& signature = signatures_[0];
|
341
|
+
signature.parse(self, args, kwargs, parsed_args, true);
|
342
|
+
return RubyArgs(signature, parsed_args);
|
343
|
+
}
|
344
|
+
|
345
|
+
for (auto& signature : signatures_) {
|
346
|
+
if (signature.parse(self, args, kwargs, parsed_args, false)) {
|
347
|
+
return RubyArgs(signature, parsed_args);
|
348
|
+
}
|
349
|
+
}
|
350
|
+
|
351
|
+
print_error(self, args, kwargs, parsed_args);
|
352
|
+
|
353
|
+
// TODO better message
|
354
|
+
rb_raise(rb_eArgError, "No matching signatures");
|
355
|
+
}
|
356
|
+
|
357
|
+
void print_error(VALUE self, VALUE args, VALUE kwargs, std::vector<VALUE>& parsed_args) {
|
358
|
+
ssize_t num_args = (NIL_P(args) ? 0 : RARRAY_LEN(args)) + (NIL_P(kwargs) ? 0 : RHASH_SIZE(kwargs));
|
359
|
+
std::vector<int> plausible_idxs;
|
360
|
+
ssize_t i = 0;
|
361
|
+
for (auto& signature : signatures_) {
|
362
|
+
if (num_args >= signature.min_args && num_args <= signature.max_args && !signature.hidden) {
|
363
|
+
plausible_idxs.push_back(i);
|
364
|
+
}
|
365
|
+
i++;
|
366
|
+
}
|
367
|
+
|
368
|
+
if (plausible_idxs.size() == 1) {
|
369
|
+
auto& signature = signatures_[plausible_idxs[0]];
|
370
|
+
signature.parse(self, args, kwargs, parsed_args, true);
|
371
|
+
}
|
372
|
+
}
|
373
|
+
};
|
@@ -10,75 +10,55 @@
|
|
10
10
|
using namespace Rice;
|
11
11
|
|
12
12
|
using torch::Device;
|
13
|
+
using torch::Scalar;
|
13
14
|
using torch::ScalarType;
|
14
15
|
using torch::Tensor;
|
16
|
+
using torch::QScheme;
|
17
|
+
using torch::Generator;
|
18
|
+
using torch::TensorOptions;
|
19
|
+
using torch::Layout;
|
20
|
+
using torch::MemoryFormat;
|
21
|
+
using torch::IntArrayRef;
|
22
|
+
using torch::TensorList;
|
23
|
+
using torch::Storage;
|
24
|
+
|
25
|
+
// HANDLE_TH_ERRORS ... END_HANDLE_TH_ERRORS wrap a function body and
// translate C++ exceptions into Ruby exceptions:
//   torch::Error    -> RuntimeError (message without the C++ backtrace)
//   Rice::Exception -> the Ruby class returned by ex.class_of()
//   std::exception  -> RuntimeError (catch-all for other standard exceptions)
// NOTE(review): rb_raise longjmps out of the catch handler, so the caught
// exception object is never destroyed. Confirm this is acceptable.
#define HANDLE_TH_ERRORS try {

#define END_HANDLE_TH_ERRORS \
  } catch (const torch::Error& ex) { \
    rb_raise(rb_eRuntimeError, "%s", ex.what_without_backtrace()); \
  } catch (const Rice::Exception& ex) { \
    rb_raise(ex.class_of(), "%s", ex.what()); \
  } catch (const std::exception& ex) { \
    rb_raise(rb_eRuntimeError, "%s", ex.what()); \
  }
|
15
36
|
|
16
|
-
|
17
|
-
|
18
|
-
class IntArrayRef {
|
19
|
-
std::vector<int64_t> vec;
|
20
|
-
public:
|
21
|
-
IntArrayRef(Object o) {
|
22
|
-
Array a = Array(o);
|
23
|
-
for (size_t i = 0; i < a.size(); i++) {
|
24
|
-
vec.push_back(from_ruby<int64_t>(a[i]));
|
25
|
-
}
|
26
|
-
}
|
27
|
-
operator torch::IntArrayRef() {
|
28
|
-
return torch::IntArrayRef(vec);
|
29
|
-
}
|
30
|
-
};
|
31
|
-
|
32
|
-
template<>
|
33
|
-
inline
|
34
|
-
IntArrayRef from_ruby<IntArrayRef>(Object x)
|
35
|
-
{
|
36
|
-
return IntArrayRef(x);
|
37
|
-
}
|
38
|
-
|
39
|
-
// for now
|
40
|
-
class Scalar {
|
41
|
-
torch::Scalar value;
|
42
|
-
public:
|
43
|
-
Scalar(Object o) {
|
44
|
-
// TODO cast based on Ruby type
|
45
|
-
if (o.rb_type() == T_FIXNUM) {
|
46
|
-
value = torch::Scalar(from_ruby<int64_t>(o));
|
47
|
-
} else {
|
48
|
-
value = torch::Scalar(from_ruby<float>(o));
|
49
|
-
}
|
50
|
-
}
|
51
|
-
operator torch::Scalar() {
|
52
|
-
return value;
|
53
|
-
}
|
54
|
-
};
|
37
|
+
// Expands to `return Qnil;`, i.e. return Ruby nil from a VALUE-returning function.
#define RETURN_NIL return Qnil;
|
55
39
|
|
56
40
|
template<>
|
57
41
|
inline
|
58
|
-
|
42
|
+
std::vector<int64_t> from_ruby<std::vector<int64_t>>(Object x)
|
59
43
|
{
|
60
|
-
|
44
|
+
Array a = Array(x);
|
45
|
+
std::vector<int64_t> vec(a.size());
|
46
|
+
for (size_t i = 0; i < a.size(); i++) {
|
47
|
+
vec[i] = from_ruby<int64_t>(a[i]);
|
48
|
+
}
|
49
|
+
return vec;
|
61
50
|
}
|
62
51
|
|
63
|
-
class TensorList {
|
64
|
-
std::vector<torch::Tensor> vec;
|
65
|
-
public:
|
66
|
-
TensorList(Object o) {
|
67
|
-
Array a = Array(o);
|
68
|
-
for (size_t i = 0; i < a.size(); i++) {
|
69
|
-
vec.push_back(from_ruby<torch::Tensor>(a[i]));
|
70
|
-
}
|
71
|
-
}
|
72
|
-
operator torch::TensorList() {
|
73
|
-
return torch::TensorList(vec);
|
74
|
-
}
|
75
|
-
};
|
76
|
-
|
77
52
|
template<>
|
78
53
|
inline
|
79
|
-
|
54
|
+
std::vector<Tensor> from_ruby<std::vector<Tensor>>(Object x)
|
80
55
|
{
|
81
|
-
|
56
|
+
Array a = Array(x);
|
57
|
+
std::vector<Tensor> vec(a.size());
|
58
|
+
for (size_t i = 0; i < a.size(); i++) {
|
59
|
+
vec[i] = from_ruby<Tensor>(a[i]);
|
60
|
+
}
|
61
|
+
return vec;
|
82
62
|
}
|
83
63
|
|
84
64
|
class FanModeType {
|
@@ -147,51 +127,35 @@ NonlinearityType from_ruby<NonlinearityType>(Object x)
|
|
147
127
|
return NonlinearityType(x);
|
148
128
|
}
|
149
129
|
|
150
|
-
class
|
151
|
-
|
130
|
+
class OptionalTensor {
|
131
|
+
torch::Tensor value;
|
152
132
|
public:
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
operator int64_t() {
|
157
|
-
if (value.is_nil()) {
|
158
|
-
return torch::Reduction::None;
|
159
|
-
}
|
160
|
-
|
161
|
-
std::string s = String(value).str();
|
162
|
-
if (s == "mean") {
|
163
|
-
return torch::Reduction::Mean;
|
164
|
-
} else if (s == "sum") {
|
165
|
-
return torch::Reduction::Sum;
|
166
|
-
} else if (s == "none") {
|
167
|
-
return torch::Reduction::None;
|
133
|
+
OptionalTensor(Object o) {
|
134
|
+
if (o.is_nil()) {
|
135
|
+
value = {};
|
168
136
|
} else {
|
169
|
-
|
137
|
+
value = from_ruby<torch::Tensor>(o);
|
170
138
|
}
|
171
139
|
}
|
140
|
+
OptionalTensor(torch::Tensor o) {
|
141
|
+
value = o;
|
142
|
+
}
|
143
|
+
operator torch::Tensor() const {
|
144
|
+
return value;
|
145
|
+
}
|
172
146
|
};
|
173
147
|
|
174
148
|
template<>
|
175
149
|
inline
|
176
|
-
|
150
|
+
Scalar from_ruby<Scalar>(Object x)
|
177
151
|
{
|
178
|
-
|
152
|
+
if (x.rb_type() == T_FIXNUM) {
|
153
|
+
return torch::Scalar(from_ruby<int64_t>(x));
|
154
|
+
} else {
|
155
|
+
return torch::Scalar(from_ruby<double>(x));
|
156
|
+
}
|
179
157
|
}
|
180
158
|
|
181
|
-
class OptionalTensor {
|
182
|
-
Object value;
|
183
|
-
public:
|
184
|
-
OptionalTensor(Object o) {
|
185
|
-
value = o;
|
186
|
-
}
|
187
|
-
operator torch::Tensor() {
|
188
|
-
if (value.is_nil()) {
|
189
|
-
return {};
|
190
|
-
}
|
191
|
-
return from_ruby<torch::Tensor>(value);
|
192
|
-
}
|
193
|
-
};
|
194
|
-
|
195
159
|
template<>
|
196
160
|
inline
|
197
161
|
OptionalTensor from_ruby<OptionalTensor>(Object x)
|
@@ -221,9 +185,35 @@ torch::optional<int64_t> from_ruby<torch::optional<int64_t>>(Object x)
|
|
221
185
|
}
|
222
186
|
}
|
223
187
|
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
|
188
|
+
template<>
|
189
|
+
inline
|
190
|
+
torch::optional<double> from_ruby<torch::optional<double>>(Object x)
|
191
|
+
{
|
192
|
+
if (x.is_nil()) {
|
193
|
+
return torch::nullopt;
|
194
|
+
} else {
|
195
|
+
return torch::optional<double>{from_ruby<double>(x)};
|
196
|
+
}
|
197
|
+
}
|
198
|
+
|
199
|
+
template<>
|
200
|
+
inline
|
201
|
+
torch::optional<bool> from_ruby<torch::optional<bool>>(Object x)
|
202
|
+
{
|
203
|
+
if (x.is_nil()) {
|
204
|
+
return torch::nullopt;
|
205
|
+
} else {
|
206
|
+
return torch::optional<bool>{from_ruby<bool>(x)};
|
207
|
+
}
|
208
|
+
}
|
209
|
+
|
210
|
+
template<>
|
211
|
+
inline
|
212
|
+
torch::optional<Scalar> from_ruby<torch::optional<Scalar>>(Object x)
|
213
|
+
{
|
214
|
+
if (x.is_nil()) {
|
215
|
+
return torch::nullopt;
|
216
|
+
} else {
|
217
|
+
return torch::optional<Scalar>{from_ruby<Scalar>(x)};
|
218
|
+
}
|
219
|
+
}
|