torch-rb 0.3.7 → 0.4.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +1 -1
- data/codegen/function.rb +134 -0
- data/codegen/generate_functions.rb +546 -0
- data/{lib/torch/native → codegen}/native_functions.yaml +0 -0
- data/ext/torch/ext.cpp +54 -75
- data/ext/torch/extconf.rb +2 -2
- data/ext/torch/nn_functions.h +6 -0
- data/ext/torch/ruby_arg_parser.cpp +593 -0
- data/ext/torch/ruby_arg_parser.h +373 -0
- data/ext/torch/{templates.hpp → templates.h} +30 -51
- data/ext/torch/tensor_functions.h +6 -0
- data/ext/torch/torch_functions.h +6 -0
- data/ext/torch/utils.h +42 -0
- data/ext/torch/{templates.cpp → wrap_outputs.h} +16 -15
- data/lib/torch.rb +0 -62
- data/lib/torch/nn/functional.rb +30 -16
- data/lib/torch/nn/init.rb +5 -19
- data/lib/torch/optim/adadelta.rb +1 -1
- data/lib/torch/optim/adam.rb +2 -2
- data/lib/torch/optim/adamax.rb +1 -1
- data/lib/torch/optim/adamw.rb +1 -1
- data/lib/torch/optim/asgd.rb +1 -1
- data/lib/torch/optim/sgd.rb +3 -3
- data/lib/torch/tensor.rb +25 -105
- data/lib/torch/version.rb +1 -1
- metadata +27 -9
- data/lib/torch/native/dispatcher.rb +0 -70
- data/lib/torch/native/function.rb +0 -200
- data/lib/torch/native/generator.rb +0 -178
- data/lib/torch/native/parser.rb +0 -117
@@ -0,0 +1,373 @@
|
|
1
|
+
// adapted from PyTorch - python_arg_parser.h
|
2
|
+
|
3
|
+
#pragma once
|
4
|
+
|
5
|
+
#include <torch/torch.h>
|
6
|
+
#include <rice/Exception.hpp>
|
7
|
+
|
8
|
+
#include "templates.h"
|
9
|
+
#include "utils.h"
|
10
|
+
|
11
|
+
// Tag for every kind of argument the parser understands; each
// FunctionParameter carries one of these to drive type checking.
// NOTE: enumerator order is significant (values are compared/serialized
// by the codegen), so do not reorder.
enum class ParameterType {
  TENSOR,
  SCALAR,
  INT64,
  DOUBLE,
  COMPLEX,
  TENSOR_LIST,
  INT_LIST,
  GENERATOR,
  BOOL,
  STORAGE,
  PYOBJECT,
  SCALARTYPE,
  LAYOUT,
  MEMORY_FORMAT,
  DEVICE,
  STRING,
  DIMNAME,
  DIMNAME_LIST,
  QSCHEME,
  FLOAT_LIST
};
|
16
|
+
|
17
|
+
struct FunctionParameter {
|
18
|
+
FunctionParameter(const std::string& fmt, bool keyword_only);
|
19
|
+
|
20
|
+
bool check(VALUE obj, int argnum);
|
21
|
+
|
22
|
+
void set_default_str(const std::string& str);
|
23
|
+
std::string type_name() const;
|
24
|
+
|
25
|
+
ParameterType type_;
|
26
|
+
bool optional;
|
27
|
+
bool allow_none;
|
28
|
+
bool keyword_only;
|
29
|
+
bool allow_numbers_as_tensors = false;
|
30
|
+
int size;
|
31
|
+
std::string name;
|
32
|
+
VALUE ruby_name;
|
33
|
+
at::SmallVector<VALUE, 5> numpy_python_names;
|
34
|
+
at::Scalar default_scalar;
|
35
|
+
std::vector<int64_t> default_intlist;
|
36
|
+
union {
|
37
|
+
bool default_bool;
|
38
|
+
int64_t default_int;
|
39
|
+
double default_double;
|
40
|
+
double default_complex[2]; // see Scalar
|
41
|
+
at::ScalarType default_scalartype;
|
42
|
+
at::Layout default_layout;
|
43
|
+
};
|
44
|
+
};
|
45
|
+
|
46
|
+
struct FunctionSignature {
|
47
|
+
explicit FunctionSignature(const std::string& fmt, int index);
|
48
|
+
|
49
|
+
bool parse(VALUE self, VALUE args, VALUE kwargs, std::vector<VALUE>& dst, bool raise_exception);
|
50
|
+
|
51
|
+
std::string toString() const;
|
52
|
+
|
53
|
+
std::string name;
|
54
|
+
std::vector<FunctionParameter> params;
|
55
|
+
// std::vector<py::handle> overloaded_args;
|
56
|
+
ssize_t min_args;
|
57
|
+
ssize_t max_args;
|
58
|
+
ssize_t max_pos_args;
|
59
|
+
int index;
|
60
|
+
bool hidden;
|
61
|
+
bool deprecated;
|
62
|
+
bool disable_torch_function;
|
63
|
+
};
|
64
|
+
|
65
|
+
struct RubyArgs {
|
66
|
+
RubyArgs(const FunctionSignature& signature, std::vector<VALUE> &args)
|
67
|
+
: signature(signature)
|
68
|
+
, args(args)
|
69
|
+
, idx(signature.index) {}
|
70
|
+
|
71
|
+
const FunctionSignature& signature;
|
72
|
+
std::vector<VALUE> args;
|
73
|
+
int idx;
|
74
|
+
|
75
|
+
inline at::Tensor tensor(int i);
|
76
|
+
inline OptionalTensor optionalTensor(int i);
|
77
|
+
inline at::Scalar scalar(int i);
|
78
|
+
// inline at::Scalar scalarWithDefault(int i, at::Scalar default_scalar);
|
79
|
+
inline std::vector<at::Tensor> tensorlist(int i);
|
80
|
+
template<int N>
|
81
|
+
inline std::array<at::Tensor, N> tensorlist_n(int i);
|
82
|
+
inline std::vector<int64_t> intlist(int i);
|
83
|
+
// inline c10::OptionalArray<int64_t> intlistOptional(int i);
|
84
|
+
// inline std::vector<int64_t> intlistWithDefault(int i, std::vector<int64_t> default_intlist);
|
85
|
+
inline c10::optional<at::Generator> generator(int i);
|
86
|
+
inline at::Storage storage(int i);
|
87
|
+
inline at::ScalarType scalartype(int i);
|
88
|
+
// inline at::ScalarType scalartypeWithDefault(int i, at::ScalarType default_scalartype);
|
89
|
+
inline c10::optional<at::ScalarType> scalartypeOptional(int i);
|
90
|
+
inline c10::optional<at::Scalar> scalarOptional(int i);
|
91
|
+
inline c10::optional<int64_t> toInt64Optional(int i);
|
92
|
+
inline c10::optional<bool> toBoolOptional(int i);
|
93
|
+
inline c10::optional<double> toDoubleOptional(int i);
|
94
|
+
// inline c10::OptionalArray<double> doublelistOptional(int i);
|
95
|
+
// inline at::Layout layout(int i);
|
96
|
+
// inline at::Layout layoutWithDefault(int i, at::Layout default_layout);
|
97
|
+
inline c10::optional<at::Layout> layoutOptional(int i);
|
98
|
+
inline at::Device device(int i);
|
99
|
+
// inline at::Device deviceWithDefault(int i, const at::Device& default_device);
|
100
|
+
// inline c10::optional<at::Device> deviceOptional(int i);
|
101
|
+
// inline at::Dimname dimname(int i);
|
102
|
+
// inline std::vector<at::Dimname> dimnamelist(int i);
|
103
|
+
// inline c10::optional<std::vector<at::Dimname>> toDimnameListOptional(int i);
|
104
|
+
inline at::MemoryFormat memoryformat(int i);
|
105
|
+
inline c10::optional<at::MemoryFormat> memoryformatOptional(int i);
|
106
|
+
// inline at::QScheme toQScheme(int i);
|
107
|
+
inline std::string string(int i);
|
108
|
+
// inline c10::optional<std::string> stringOptional(int i);
|
109
|
+
// inline PyObject* pyobject(int i);
|
110
|
+
inline int64_t toInt64(int i);
|
111
|
+
// inline int64_t toInt64WithDefault(int i, int64_t default_int);
|
112
|
+
inline double toDouble(int i);
|
113
|
+
// inline double toDoubleWithDefault(int i, double default_double);
|
114
|
+
// inline c10::complex<double> toComplex(int i);
|
115
|
+
// inline c10::complex<double> toComplexWithDefault(int i, c10::complex<double> default_complex);
|
116
|
+
inline bool toBool(int i);
|
117
|
+
// inline bool toBoolWithDefault(int i, bool default_bool);
|
118
|
+
inline bool isNone(int i);
|
119
|
+
};
|
120
|
+
|
121
|
+
inline at::Tensor RubyArgs::tensor(int i) {
|
122
|
+
return from_ruby<torch::Tensor>(args[i]);
|
123
|
+
}
|
124
|
+
|
125
|
+
inline OptionalTensor RubyArgs::optionalTensor(int i) {
|
126
|
+
if (NIL_P(args[i])) return OptionalTensor(Nil);
|
127
|
+
return tensor(i);
|
128
|
+
}
|
129
|
+
|
130
|
+
inline at::Scalar RubyArgs::scalar(int i) {
|
131
|
+
if (NIL_P(args[i])) return signature.params[i].default_scalar;
|
132
|
+
return from_ruby<torch::Scalar>(args[i]);
|
133
|
+
}
|
134
|
+
|
135
|
+
inline std::vector<at::Tensor> RubyArgs::tensorlist(int i) {
|
136
|
+
if (NIL_P(args[i])) return std::vector<at::Tensor>();
|
137
|
+
return from_ruby<std::vector<Tensor>>(args[i]);
|
138
|
+
}
|
139
|
+
|
140
|
+
template<int N>
|
141
|
+
inline std::array<at::Tensor, N> RubyArgs::tensorlist_n(int i) {
|
142
|
+
auto res = std::array<at::Tensor, N>();
|
143
|
+
if (NIL_P(args[i])) return res;
|
144
|
+
VALUE arg = args[i];
|
145
|
+
Check_Type(arg, T_ARRAY);
|
146
|
+
auto size = RARRAY_LEN(arg);
|
147
|
+
if (size != N) {
|
148
|
+
rb_raise(rb_eArgError, "expected array of %d elements but got %d", N, (int)size);
|
149
|
+
}
|
150
|
+
for (int idx = 0; idx < size; idx++) {
|
151
|
+
VALUE obj = rb_ary_entry(arg, idx);
|
152
|
+
res[idx] = from_ruby<Tensor>(obj);
|
153
|
+
}
|
154
|
+
return res;
|
155
|
+
}
|
156
|
+
|
157
|
+
inline std::vector<int64_t> RubyArgs::intlist(int i) {
|
158
|
+
if (NIL_P(args[i])) return signature.params[i].default_intlist;
|
159
|
+
|
160
|
+
VALUE arg = args[i];
|
161
|
+
auto size = signature.params[i].size;
|
162
|
+
if (size > 0 && FIXNUM_P(arg)) {
|
163
|
+
return std::vector<int64_t>(size, FIX2INT(arg));
|
164
|
+
}
|
165
|
+
|
166
|
+
size = RARRAY_LEN(arg);
|
167
|
+
std::vector<int64_t> res(size);
|
168
|
+
for (idx = 0; idx < size; idx++) {
|
169
|
+
VALUE obj = rb_ary_entry(arg, idx);
|
170
|
+
if (FIXNUM_P(obj)) {
|
171
|
+
res[idx] = from_ruby<int64_t>(obj);
|
172
|
+
} else {
|
173
|
+
rb_raise(rb_eArgError, "%s(): argument '%s' must be %s, but found element of type %s at pos %d",
|
174
|
+
signature.name.c_str(), signature.params[i].name.c_str(),
|
175
|
+
signature.params[i].type_name().c_str(), rb_obj_classname(obj), idx + 1);
|
176
|
+
}
|
177
|
+
}
|
178
|
+
return res;
|
179
|
+
}
|
180
|
+
|
181
|
+
inline c10::optional<at::Generator> RubyArgs::generator(int i) {
|
182
|
+
if (NIL_P(args[i])) return c10::nullopt;
|
183
|
+
throw std::runtime_error("generator not supported yet");
|
184
|
+
}
|
185
|
+
|
186
|
+
inline at::Storage RubyArgs::storage(int i) {
|
187
|
+
if (NIL_P(args[i])) return at::Storage();
|
188
|
+
throw std::runtime_error("storage not supported yet");
|
189
|
+
}
|
190
|
+
|
191
|
+
// Slot i as a dtype. Accepts a Ruby symbol such as :float32 or :int64;
// nil resolves to the signature's default, or — when that default is
// Undefined — to the current global default dtype.
inline ScalarType RubyArgs::scalartype(int i) {
  if (NIL_P(args[i])) {
    auto scalartype = signature.params[i].default_scalartype;
    return (scalartype == at::ScalarType::Undefined) ? at::typeMetaToScalarType(at::get_default_dtype()) : scalartype;
  }

  // Symbols are interned, so VALUE identity is a stable map key.
  // Built once on first call; rb_intern is safe after VM init.
  static std::unordered_map<VALUE, ScalarType> dtype_map = {
    {ID2SYM(rb_intern("uint8")), ScalarType::Byte},
    {ID2SYM(rb_intern("int8")), ScalarType::Char},
    {ID2SYM(rb_intern("short")), ScalarType::Short},
    {ID2SYM(rb_intern("int16")), ScalarType::Short},
    {ID2SYM(rb_intern("int")), ScalarType::Int},
    {ID2SYM(rb_intern("int32")), ScalarType::Int},
    {ID2SYM(rb_intern("long")), ScalarType::Long},
    {ID2SYM(rb_intern("int64")), ScalarType::Long},
    {ID2SYM(rb_intern("float")), ScalarType::Float},
    {ID2SYM(rb_intern("float32")), ScalarType::Float},
    {ID2SYM(rb_intern("double")), ScalarType::Double},
    {ID2SYM(rb_intern("float64")), ScalarType::Double},
    {ID2SYM(rb_intern("complex_half")), ScalarType::ComplexHalf},
    {ID2SYM(rb_intern("complex_float")), ScalarType::ComplexFloat},
    {ID2SYM(rb_intern("complex_double")), ScalarType::ComplexDouble},
    {ID2SYM(rb_intern("bool")), ScalarType::Bool},
    {ID2SYM(rb_intern("qint8")), ScalarType::QInt8},
    {ID2SYM(rb_intern("quint8")), ScalarType::QUInt8},
    {ID2SYM(rb_intern("qint32")), ScalarType::QInt32},
    {ID2SYM(rb_intern("bfloat16")), ScalarType::BFloat16},
  };

  auto it = dtype_map.find(args[i]);
  if (it == dtype_map.end()) {
    rb_raise(rb_eArgError, "invalid dtype: %s", THPUtils_unpackSymbol(args[i]).c_str());
  }
  return it->second;
}
|
226
|
+
|
227
|
+
// Optional variants: each returns c10::nullopt for a nil slot and
// otherwise delegates to the corresponding non-optional accessor.

inline c10::optional<ScalarType> RubyArgs::scalartypeOptional(int i) {
  if (NIL_P(args[i])) return c10::nullopt;
  return scalartype(i);
}

inline c10::optional<Scalar> RubyArgs::scalarOptional(int i) {
  if (NIL_P(args[i])) return c10::nullopt;
  return scalar(i);
}

inline c10::optional<int64_t> RubyArgs::toInt64Optional(int i) {
  if (NIL_P(args[i])) return c10::nullopt;
  return toInt64(i);
}

inline c10::optional<bool> RubyArgs::toBoolOptional(int i) {
  if (NIL_P(args[i])) return c10::nullopt;
  return toBool(i);
}

inline c10::optional<double> RubyArgs::toDoubleOptional(int i) {
  if (NIL_P(args[i])) return c10::nullopt;
  return toDouble(i);
}
|
251
|
+
|
252
|
+
// Slot i as a layout symbol; only :strided is recognized so far.
// nil yields nullopt; any other symbol raises ArgumentError.
inline c10::optional<at::Layout> RubyArgs::layoutOptional(int i) {
  if (NIL_P(args[i])) return c10::nullopt;

  // Symbols are interned, so VALUE identity is a stable map key.
  static std::unordered_map<VALUE, Layout> layout_map = {
    {ID2SYM(rb_intern("strided")), Layout::Strided},
  };

  auto it = layout_map.find(args[i]);
  if (it == layout_map.end()) {
    rb_raise(rb_eArgError, "invalid layout: %s", THPUtils_unpackSymbol(args[i]).c_str());
  }
  return it->second;
}
|
265
|
+
|
266
|
+
// Slot i as a device, parsed from a Ruby string such as "cuda:0".
// NOTE(review): nil is hard-coded to CPU here rather than consulting a
// configurable default device — confirm this matches callers' expectations.
inline at::Device RubyArgs::device(int i) {
  if (NIL_P(args[i])) {
    return at::Device("cpu");
  }
  const std::string &device_str = THPUtils_unpackString(args[i]);
  return at::Device(device_str);
}
|
273
|
+
|
274
|
+
inline at::MemoryFormat RubyArgs::memoryformat(int i) {
|
275
|
+
if (NIL_P(args[i])) return at::MemoryFormat::Contiguous;
|
276
|
+
throw std::runtime_error("memoryformat not supported yet");
|
277
|
+
}
|
278
|
+
|
279
|
+
inline c10::optional<at::MemoryFormat> RubyArgs::memoryformatOptional(int i) {
|
280
|
+
if (NIL_P(args[i])) return c10::nullopt;
|
281
|
+
return memoryformat(i);
|
282
|
+
}
|
283
|
+
|
284
|
+
inline std::string RubyArgs::string(int i) {
|
285
|
+
return from_ruby<std::string>(args[i]);
|
286
|
+
}
|
287
|
+
|
288
|
+
inline int64_t RubyArgs::toInt64(int i) {
|
289
|
+
if (NIL_P(args[i])) return signature.params[i].default_int;
|
290
|
+
return from_ruby<int64_t>(args[i]);
|
291
|
+
}
|
292
|
+
|
293
|
+
inline double RubyArgs::toDouble(int i) {
|
294
|
+
if (NIL_P(args[i])) return signature.params[i].default_double;
|
295
|
+
return from_ruby<double>(args[i]);
|
296
|
+
}
|
297
|
+
|
298
|
+
inline bool RubyArgs::toBool(int i) {
|
299
|
+
if (NIL_P(args[i])) return signature.params[i].default_bool;
|
300
|
+
return RTEST(args[i]);
|
301
|
+
}
|
302
|
+
|
303
|
+
inline bool RubyArgs::isNone(int i) {
|
304
|
+
return NIL_P(args[i]);
|
305
|
+
}
|
306
|
+
|
307
|
+
struct RubyArgParser {
|
308
|
+
std::vector<FunctionSignature> signatures_;
|
309
|
+
std::string function_name;
|
310
|
+
ssize_t max_args;
|
311
|
+
|
312
|
+
public:
|
313
|
+
RubyArgParser(std::vector<std::string> fmts) : max_args(0) {
|
314
|
+
int index = 0;
|
315
|
+
for (auto& fmt : fmts) {
|
316
|
+
signatures_.emplace_back(fmt, index);
|
317
|
+
++index;
|
318
|
+
}
|
319
|
+
for (auto& signature : signatures_) {
|
320
|
+
if (signature.max_args > max_args) {
|
321
|
+
max_args = signature.max_args;
|
322
|
+
}
|
323
|
+
}
|
324
|
+
if (signatures_.size() > 0) {
|
325
|
+
function_name = signatures_[0].name;
|
326
|
+
}
|
327
|
+
|
328
|
+
// Check deprecated signatures last
|
329
|
+
std::stable_partition(signatures_.begin(), signatures_.end(),
|
330
|
+
[](const FunctionSignature & sig) {
|
331
|
+
return !sig.deprecated;
|
332
|
+
});
|
333
|
+
}
|
334
|
+
|
335
|
+
RubyArgs parse(VALUE self, int argc, VALUE* argv, std::vector<VALUE> &parsed_args) {
|
336
|
+
VALUE args, kwargs;
|
337
|
+
rb_scan_args(argc, argv, "*:", &args, &kwargs);
|
338
|
+
|
339
|
+
if (signatures_.size() == 1) {
|
340
|
+
auto& signature = signatures_[0];
|
341
|
+
signature.parse(self, args, kwargs, parsed_args, true);
|
342
|
+
return RubyArgs(signature, parsed_args);
|
343
|
+
}
|
344
|
+
|
345
|
+
for (auto& signature : signatures_) {
|
346
|
+
if (signature.parse(self, args, kwargs, parsed_args, false)) {
|
347
|
+
return RubyArgs(signature, parsed_args);
|
348
|
+
}
|
349
|
+
}
|
350
|
+
|
351
|
+
print_error(self, args, kwargs, parsed_args);
|
352
|
+
|
353
|
+
// TODO better message
|
354
|
+
rb_raise(rb_eArgError, "No matching signatures");
|
355
|
+
}
|
356
|
+
|
357
|
+
void print_error(VALUE self, VALUE args, VALUE kwargs, std::vector<VALUE>& parsed_args) {
|
358
|
+
ssize_t num_args = (NIL_P(args) ? 0 : RARRAY_LEN(args)) + (NIL_P(kwargs) ? 0 : RHASH_SIZE(kwargs));
|
359
|
+
std::vector<int> plausible_idxs;
|
360
|
+
ssize_t i = 0;
|
361
|
+
for (auto& signature : signatures_) {
|
362
|
+
if (num_args >= signature.min_args && num_args <= signature.max_args && !signature.hidden) {
|
363
|
+
plausible_idxs.push_back(i);
|
364
|
+
}
|
365
|
+
i++;
|
366
|
+
}
|
367
|
+
|
368
|
+
if (plausible_idxs.size() == 1) {
|
369
|
+
auto& signature = signatures_[plausible_idxs[0]];
|
370
|
+
signature.parse(self, args, kwargs, parsed_args, true);
|
371
|
+
}
|
372
|
+
}
|
373
|
+
};
|
@@ -13,8 +13,29 @@ using torch::Device;
|
|
13
13
|
using torch::Scalar;
|
14
14
|
using torch::ScalarType;
|
15
15
|
using torch::Tensor;
|
16
|
+
using torch::QScheme;
|
17
|
+
using torch::Generator;
|
18
|
+
using torch::TensorOptions;
|
19
|
+
using torch::Layout;
|
20
|
+
using torch::MemoryFormat;
|
16
21
|
using torch::IntArrayRef;
|
17
22
|
using torch::TensorList;
|
23
|
+
using torch::Storage;
|
24
|
+
|
25
|
+
// Bracket generated binding bodies so C++ exceptions become Ruby
// exceptions instead of crossing the extension boundary.
#define HANDLE_TH_ERRORS \
  try {

// Ordered most- to least-specific: torch errors get the backtrace-free
// message, Rice exceptions keep their Ruby class, everything else maps
// to RuntimeError.
#define END_HANDLE_TH_ERRORS \
  } catch (const torch::Error& ex) { \
    rb_raise(rb_eRuntimeError, "%s", ex.what_without_backtrace()); \
  } catch (const Rice::Exception& ex) { \
    rb_raise(ex.class_of(), "%s", ex.what()); \
  } catch (const std::exception& ex) { \
    rb_raise(rb_eRuntimeError, "%s", ex.what()); \
  }

// Used by generated void-returning bindings.
#define RETURN_NIL \
  return Qnil;
|
18
39
|
|
19
40
|
template<>
|
20
41
|
inline
|
@@ -106,48 +127,21 @@ NonlinearityType from_ruby<NonlinearityType>(Object x)
|
|
106
127
|
return NonlinearityType(x);
|
107
128
|
}
|
108
129
|
|
109
|
-
class
|
110
|
-
|
130
|
+
class OptionalTensor {
|
131
|
+
torch::Tensor value;
|
111
132
|
public:
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
operator int64_t() {
|
116
|
-
if (value.is_nil()) {
|
117
|
-
return torch::Reduction::None;
|
118
|
-
}
|
119
|
-
|
120
|
-
std::string s = String(value).str();
|
121
|
-
if (s == "mean") {
|
122
|
-
return torch::Reduction::Mean;
|
123
|
-
} else if (s == "sum") {
|
124
|
-
return torch::Reduction::Sum;
|
125
|
-
} else if (s == "none") {
|
126
|
-
return torch::Reduction::None;
|
133
|
+
OptionalTensor(Object o) {
|
134
|
+
if (o.is_nil()) {
|
135
|
+
value = {};
|
127
136
|
} else {
|
128
|
-
|
137
|
+
value = from_ruby<torch::Tensor>(o);
|
129
138
|
}
|
130
139
|
}
|
131
|
-
|
132
|
-
|
133
|
-
template<>
|
134
|
-
inline
|
135
|
-
MyReduction from_ruby<MyReduction>(Object x)
|
136
|
-
{
|
137
|
-
return MyReduction(x);
|
138
|
-
}
|
139
|
-
|
140
|
-
class OptionalTensor {
|
141
|
-
Object value;
|
142
|
-
public:
|
143
|
-
OptionalTensor(Object o) {
|
140
|
+
OptionalTensor(torch::Tensor o) {
|
144
141
|
value = o;
|
145
142
|
}
|
146
|
-
operator torch::Tensor() {
|
147
|
-
|
148
|
-
return {};
|
149
|
-
}
|
150
|
-
return from_ruby<torch::Tensor>(value);
|
143
|
+
operator torch::Tensor() const {
|
144
|
+
return value;
|
151
145
|
}
|
152
146
|
};
|
153
147
|
|
@@ -223,18 +217,3 @@ torch::optional<Scalar> from_ruby<torch::optional<Scalar>>(Object x)
|
|
223
217
|
return torch::optional<Scalar>{from_ruby<Scalar>(x)};
|
224
218
|
}
|
225
219
|
}
|
226
|
-
|
227
|
-
Object wrap(bool x);
|
228
|
-
Object wrap(int64_t x);
|
229
|
-
Object wrap(double x);
|
230
|
-
Object wrap(torch::Tensor x);
|
231
|
-
Object wrap(torch::Scalar x);
|
232
|
-
Object wrap(torch::ScalarType x);
|
233
|
-
Object wrap(torch::QScheme x);
|
234
|
-
Object wrap(std::tuple<torch::Tensor, torch::Tensor> x);
|
235
|
-
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x);
|
236
|
-
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x);
|
237
|
-
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x);
|
238
|
-
Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x);
|
239
|
-
Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x);
|
240
|
-
Object wrap(std::vector<torch::Tensor> x);
|