torch-rb 0.3.4 → 0.4.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,373 @@
1
+ // adapted from PyTorch - python_arg_parser.h
2
+
3
+ #pragma once
4
+
5
+ #include <torch/torch.h>
6
+ #include <rice/Exception.hpp>
7
+
8
+ #include "templates.h"
9
+ #include "utils.h"
10
+
11
// Kinds of parameter a FunctionSignature can declare (stored in
// FunctionParameter::type_). Enumerator order is unchanged, so the underlying
// values are the same as before.
enum class ParameterType {
  TENSOR,
  SCALAR,
  INT64,
  DOUBLE,
  COMPLEX,
  TENSOR_LIST,
  INT_LIST,
  GENERATOR,
  BOOL,
  STORAGE,
  PYOBJECT,
  SCALARTYPE,
  LAYOUT,
  MEMORY_FORMAT,
  DEVICE,
  STRING,
  DIMNAME,
  DIMNAME_LIST,
  QSCHEME,
  FLOAT_LIST
};
16
+
17
// One parameter of a FunctionSignature, built from (part of) a signature
// format string. The constructor and member functions are defined outside
// this header.
struct FunctionParameter {
  FunctionParameter(const std::string& fmt, bool keyword_only);

  // Presumably reports whether `obj` is acceptable for this parameter at
  // position `argnum` — confirm against the out-of-view definition.
  bool check(VALUE obj, int argnum);

  void set_default_str(const std::string& str);
  // Type name used in user-facing error messages (e.g. by RubyArgs::intlist).
  std::string type_name() const;

  ParameterType type_;
  bool optional;
  bool allow_none;
  bool keyword_only;
  bool allow_numbers_as_tensors = false;
  // For list parameters: RubyArgs::intlist repeats a single Integer this many
  // times when size > 0.
  int size;
  std::string name;
  VALUE ruby_name;
  at::SmallVector<VALUE, 5> numpy_python_names;
  // Defaults that RubyArgs accessors return when the argument is nil.
  at::Scalar default_scalar;
  std::vector<int64_t> default_intlist;
  // Anonymous union: all defaults share storage, so presumably only the
  // member matching type_ is meaningful — confirm against set_default_str.
  union {
    bool default_bool;
    int64_t default_int;
    double default_double;
    double default_complex[2]; // see Scalar
    at::ScalarType default_scalartype;
    at::Layout default_layout;
  };
};
45
+
46
// One overload of a Ruby-facing function, built from a single format string
// passed to RubyArgParser. The constructor and member functions are defined
// outside this header.
struct FunctionSignature {
  // `index` is the position of `fmt` in the parser's format list.
  explicit FunctionSignature(const std::string& fmt, int index);

  // Tries to match self/args/kwargs against this signature, storing results in
  // `dst` (presumably one VALUE per parameter, nil when omitted — RubyArgs reads
  // dst[i] by parameter position). Returns whether it matched. RubyArgParser
  // ignores the result when raise_exception is true, so presumably it raises
  // on mismatch in that mode — confirm.
  bool parse(VALUE self, VALUE args, VALUE kwargs, std::vector<VALUE>& dst, bool raise_exception);

  std::string toString() const;

  std::string name;
  std::vector<FunctionParameter> params;
  // std::vector<py::handle> overloaded_args;
  // Argument-count bounds; RubyArgParser::print_error uses min_args/max_args
  // to decide which signatures are plausible for a failed call.
  ssize_t min_args;
  ssize_t max_args;
  ssize_t max_pos_args;
  // Position of the format string; copied into RubyArgs::idx on a match.
  int index;
  // Hidden signatures are skipped when choosing which error to report.
  bool hidden;
  // Deprecated signatures are moved after the others by RubyArgParser.
  bool deprecated;
  bool disable_torch_function;
};
64
+
65
// Typed accessors over the arguments matched by RubyArgParser::parse.
// Accessor argument `i` is a parameter position in the matched signature.
// Each accessor converts args[i]; most return the parameter's default (or an
// empty value) when args[i] is nil.
struct RubyArgs {
  RubyArgs(const FunctionSignature& signature, std::vector<VALUE> &args)
    : signature(signature)
    , args(args)
    , idx(signature.index) {}

  // Non-owning: the referenced signature must outlive this object
  // (presumably owned by a long-lived RubyArgParser — confirm at call sites).
  const FunctionSignature& signature;
  // Copy of the parsed argument vector passed to the constructor.
  std::vector<VALUE> args;
  // Index of the matched signature (from signature.index); presumably used
  // by callers to dispatch to the right overload.
  int idx;

  // Commented-out declarations below appear to be PyTorch PythonArgs
  // accessors that have not been ported yet.
  inline at::Tensor tensor(int i);
  inline OptionalTensor optionalTensor(int i);
  inline at::Scalar scalar(int i);
  // inline at::Scalar scalarWithDefault(int i, at::Scalar default_scalar);
  inline std::vector<at::Tensor> tensorlist(int i);
  template<int N>
  inline std::array<at::Tensor, N> tensorlist_n(int i);
  inline std::vector<int64_t> intlist(int i);
  // inline c10::OptionalArray<int64_t> intlistOptional(int i);
  // inline std::vector<int64_t> intlistWithDefault(int i, std::vector<int64_t> default_intlist);
  inline c10::optional<at::Generator> generator(int i);
  inline at::Storage storage(int i);
  inline at::ScalarType scalartype(int i);
  // inline at::ScalarType scalartypeWithDefault(int i, at::ScalarType default_scalartype);
  inline c10::optional<at::ScalarType> scalartypeOptional(int i);
  inline c10::optional<at::Scalar> scalarOptional(int i);
  inline c10::optional<int64_t> toInt64Optional(int i);
  inline c10::optional<bool> toBoolOptional(int i);
  inline c10::optional<double> toDoubleOptional(int i);
  // inline c10::OptionalArray<double> doublelistOptional(int i);
  // inline at::Layout layout(int i);
  // inline at::Layout layoutWithDefault(int i, at::Layout default_layout);
  inline c10::optional<at::Layout> layoutOptional(int i);
  inline at::Device device(int i);
  // inline at::Device deviceWithDefault(int i, const at::Device& default_device);
  // inline c10::optional<at::Device> deviceOptional(int i);
  // inline at::Dimname dimname(int i);
  // inline std::vector<at::Dimname> dimnamelist(int i);
  // inline c10::optional<std::vector<at::Dimname>> toDimnameListOptional(int i);
  inline at::MemoryFormat memoryformat(int i);
  inline c10::optional<at::MemoryFormat> memoryformatOptional(int i);
  // inline at::QScheme toQScheme(int i);
  inline std::string string(int i);
  // inline c10::optional<std::string> stringOptional(int i);
  // inline PyObject* pyobject(int i);
  inline int64_t toInt64(int i);
  // inline int64_t toInt64WithDefault(int i, int64_t default_int);
  inline double toDouble(int i);
  // inline double toDoubleWithDefault(int i, double default_double);
  // inline c10::complex<double> toComplex(int i);
  // inline c10::complex<double> toComplexWithDefault(int i, c10::complex<double> default_complex);
  inline bool toBool(int i);
  // inline bool toBoolWithDefault(int i, bool default_bool);
  inline bool isNone(int i);
};
120
+
121
+ inline at::Tensor RubyArgs::tensor(int i) {
122
+ return from_ruby<torch::Tensor>(args[i]);
123
+ }
124
+
125
+ inline OptionalTensor RubyArgs::optionalTensor(int i) {
126
+ if (NIL_P(args[i])) return OptionalTensor(Nil);
127
+ return tensor(i);
128
+ }
129
+
130
+ inline at::Scalar RubyArgs::scalar(int i) {
131
+ if (NIL_P(args[i])) return signature.params[i].default_scalar;
132
+ return from_ruby<torch::Scalar>(args[i]);
133
+ }
134
+
135
+ inline std::vector<at::Tensor> RubyArgs::tensorlist(int i) {
136
+ if (NIL_P(args[i])) return std::vector<at::Tensor>();
137
+ return from_ruby<std::vector<Tensor>>(args[i]);
138
+ }
139
+
140
+ template<int N>
141
+ inline std::array<at::Tensor, N> RubyArgs::tensorlist_n(int i) {
142
+ auto res = std::array<at::Tensor, N>();
143
+ if (NIL_P(args[i])) return res;
144
+ VALUE arg = args[i];
145
+ Check_Type(arg, T_ARRAY);
146
+ auto size = RARRAY_LEN(arg);
147
+ if (size != N) {
148
+ rb_raise(rb_eArgError, "expected array of %d elements but got %d", N, (int)size);
149
+ }
150
+ for (int idx = 0; idx < size; idx++) {
151
+ VALUE obj = rb_ary_entry(arg, idx);
152
+ res[idx] = from_ruby<Tensor>(obj);
153
+ }
154
+ return res;
155
+ }
156
+
157
+ inline std::vector<int64_t> RubyArgs::intlist(int i) {
158
+ if (NIL_P(args[i])) return signature.params[i].default_intlist;
159
+
160
+ VALUE arg = args[i];
161
+ auto size = signature.params[i].size;
162
+ if (size > 0 && FIXNUM_P(arg)) {
163
+ return std::vector<int64_t>(size, FIX2INT(arg));
164
+ }
165
+
166
+ size = RARRAY_LEN(arg);
167
+ std::vector<int64_t> res(size);
168
+ for (idx = 0; idx < size; idx++) {
169
+ VALUE obj = rb_ary_entry(arg, idx);
170
+ if (FIXNUM_P(obj)) {
171
+ res[idx] = from_ruby<int64_t>(obj);
172
+ } else {
173
+ rb_raise(rb_eArgError, "%s(): argument '%s' must be %s, but found element of type %s at pos %d",
174
+ signature.name.c_str(), signature.params[i].name.c_str(),
175
+ signature.params[i].type_name().c_str(), rb_obj_classname(obj), idx + 1);
176
+ }
177
+ }
178
+ return res;
179
+ }
180
+
181
+ inline c10::optional<at::Generator> RubyArgs::generator(int i) {
182
+ if (NIL_P(args[i])) return c10::nullopt;
183
+ throw std::runtime_error("generator not supported yet");
184
+ }
185
+
186
+ inline at::Storage RubyArgs::storage(int i) {
187
+ if (NIL_P(args[i])) return at::Storage();
188
+ throw std::runtime_error("storage not supported yet");
189
+ }
190
+
191
+ inline ScalarType RubyArgs::scalartype(int i) {
192
+ if (NIL_P(args[i])) {
193
+ auto scalartype = signature.params[i].default_scalartype;
194
+ return (scalartype == at::ScalarType::Undefined) ? at::typeMetaToScalarType(at::get_default_dtype()) : scalartype;
195
+ }
196
+
197
+ static std::unordered_map<VALUE, ScalarType> dtype_map = {
198
+ {ID2SYM(rb_intern("uint8")), ScalarType::Byte},
199
+ {ID2SYM(rb_intern("int8")), ScalarType::Char},
200
+ {ID2SYM(rb_intern("short")), ScalarType::Short},
201
+ {ID2SYM(rb_intern("int16")), ScalarType::Short},
202
+ {ID2SYM(rb_intern("int")), ScalarType::Int},
203
+ {ID2SYM(rb_intern("int32")), ScalarType::Int},
204
+ {ID2SYM(rb_intern("long")), ScalarType::Long},
205
+ {ID2SYM(rb_intern("int64")), ScalarType::Long},
206
+ {ID2SYM(rb_intern("float")), ScalarType::Float},
207
+ {ID2SYM(rb_intern("float32")), ScalarType::Float},
208
+ {ID2SYM(rb_intern("double")), ScalarType::Double},
209
+ {ID2SYM(rb_intern("float64")), ScalarType::Double},
210
+ {ID2SYM(rb_intern("complex_half")), ScalarType::ComplexHalf},
211
+ {ID2SYM(rb_intern("complex_float")), ScalarType::ComplexFloat},
212
+ {ID2SYM(rb_intern("complex_double")), ScalarType::ComplexDouble},
213
+ {ID2SYM(rb_intern("bool")), ScalarType::Bool},
214
+ {ID2SYM(rb_intern("qint8")), ScalarType::QInt8},
215
+ {ID2SYM(rb_intern("quint8")), ScalarType::QUInt8},
216
+ {ID2SYM(rb_intern("qint32")), ScalarType::QInt32},
217
+ {ID2SYM(rb_intern("bfloat16")), ScalarType::BFloat16},
218
+ };
219
+
220
+ auto it = dtype_map.find(args[i]);
221
+ if (it == dtype_map.end()) {
222
+ rb_raise(rb_eArgError, "invalid dtype: %s", THPUtils_unpackSymbol(args[i]).c_str());
223
+ }
224
+ return it->second;
225
+ }
226
+
227
+ inline c10::optional<ScalarType> RubyArgs::scalartypeOptional(int i) {
228
+ if (NIL_P(args[i])) return c10::nullopt;
229
+ return scalartype(i);
230
+ }
231
+
232
+ inline c10::optional<Scalar> RubyArgs::scalarOptional(int i) {
233
+ if (NIL_P(args[i])) return c10::nullopt;
234
+ return scalar(i);
235
+ }
236
+
237
+ inline c10::optional<int64_t> RubyArgs::toInt64Optional(int i) {
238
+ if (NIL_P(args[i])) return c10::nullopt;
239
+ return toInt64(i);
240
+ }
241
+
242
+ inline c10::optional<bool> RubyArgs::toBoolOptional(int i) {
243
+ if (NIL_P(args[i])) return c10::nullopt;
244
+ return toBool(i);
245
+ }
246
+
247
+ inline c10::optional<double> RubyArgs::toDoubleOptional(int i) {
248
+ if (NIL_P(args[i])) return c10::nullopt;
249
+ return toDouble(i);
250
+ }
251
+
252
+ inline c10::optional<at::Layout> RubyArgs::layoutOptional(int i) {
253
+ if (NIL_P(args[i])) return c10::nullopt;
254
+
255
+ static std::unordered_map<VALUE, Layout> layout_map = {
256
+ {ID2SYM(rb_intern("strided")), Layout::Strided},
257
+ };
258
+
259
+ auto it = layout_map.find(args[i]);
260
+ if (it == layout_map.end()) {
261
+ rb_raise(rb_eArgError, "invalid layout: %s", THPUtils_unpackSymbol(args[i]).c_str());
262
+ }
263
+ return it->second;
264
+ }
265
+
266
+ inline at::Device RubyArgs::device(int i) {
267
+ if (NIL_P(args[i])) {
268
+ return at::Device("cpu");
269
+ }
270
+ const std::string &device_str = THPUtils_unpackString(args[i]);
271
+ return at::Device(device_str);
272
+ }
273
+
274
+ inline at::MemoryFormat RubyArgs::memoryformat(int i) {
275
+ if (NIL_P(args[i])) return at::MemoryFormat::Contiguous;
276
+ throw std::runtime_error("memoryformat not supported yet");
277
+ }
278
+
279
+ inline c10::optional<at::MemoryFormat> RubyArgs::memoryformatOptional(int i) {
280
+ if (NIL_P(args[i])) return c10::nullopt;
281
+ return memoryformat(i);
282
+ }
283
+
284
+ inline std::string RubyArgs::string(int i) {
285
+ return from_ruby<std::string>(args[i]);
286
+ }
287
+
288
+ inline int64_t RubyArgs::toInt64(int i) {
289
+ if (NIL_P(args[i])) return signature.params[i].default_int;
290
+ return from_ruby<int64_t>(args[i]);
291
+ }
292
+
293
+ inline double RubyArgs::toDouble(int i) {
294
+ if (NIL_P(args[i])) return signature.params[i].default_double;
295
+ return from_ruby<double>(args[i]);
296
+ }
297
+
298
+ inline bool RubyArgs::toBool(int i) {
299
+ if (NIL_P(args[i])) return signature.params[i].default_bool;
300
+ return RTEST(args[i]);
301
+ }
302
+
303
+ inline bool RubyArgs::isNone(int i) {
304
+ return NIL_P(args[i]);
305
+ }
306
+
307
+ struct RubyArgParser {
308
+ std::vector<FunctionSignature> signatures_;
309
+ std::string function_name;
310
+ ssize_t max_args;
311
+
312
+ public:
313
+ RubyArgParser(std::vector<std::string> fmts) : max_args(0) {
314
+ int index = 0;
315
+ for (auto& fmt : fmts) {
316
+ signatures_.emplace_back(fmt, index);
317
+ ++index;
318
+ }
319
+ for (auto& signature : signatures_) {
320
+ if (signature.max_args > max_args) {
321
+ max_args = signature.max_args;
322
+ }
323
+ }
324
+ if (signatures_.size() > 0) {
325
+ function_name = signatures_[0].name;
326
+ }
327
+
328
+ // Check deprecated signatures last
329
+ std::stable_partition(signatures_.begin(), signatures_.end(),
330
+ [](const FunctionSignature & sig) {
331
+ return !sig.deprecated;
332
+ });
333
+ }
334
+
335
+ RubyArgs parse(VALUE self, int argc, VALUE* argv, std::vector<VALUE> &parsed_args) {
336
+ VALUE args, kwargs;
337
+ rb_scan_args(argc, argv, "*:", &args, &kwargs);
338
+
339
+ if (signatures_.size() == 1) {
340
+ auto& signature = signatures_[0];
341
+ signature.parse(self, args, kwargs, parsed_args, true);
342
+ return RubyArgs(signature, parsed_args);
343
+ }
344
+
345
+ for (auto& signature : signatures_) {
346
+ if (signature.parse(self, args, kwargs, parsed_args, false)) {
347
+ return RubyArgs(signature, parsed_args);
348
+ }
349
+ }
350
+
351
+ print_error(self, args, kwargs, parsed_args);
352
+
353
+ // TODO better message
354
+ rb_raise(rb_eArgError, "No matching signatures");
355
+ }
356
+
357
+ void print_error(VALUE self, VALUE args, VALUE kwargs, std::vector<VALUE>& parsed_args) {
358
+ ssize_t num_args = (NIL_P(args) ? 0 : RARRAY_LEN(args)) + (NIL_P(kwargs) ? 0 : RHASH_SIZE(kwargs));
359
+ std::vector<int> plausible_idxs;
360
+ ssize_t i = 0;
361
+ for (auto& signature : signatures_) {
362
+ if (num_args >= signature.min_args && num_args <= signature.max_args && !signature.hidden) {
363
+ plausible_idxs.push_back(i);
364
+ }
365
+ i++;
366
+ }
367
+
368
+ if (plausible_idxs.size() == 1) {
369
+ auto& signature = signatures_[plausible_idxs[0]];
370
+ signature.parse(self, args, kwargs, parsed_args, true);
371
+ }
372
+ }
373
+ };
@@ -10,75 +10,55 @@
10
10
  using namespace Rice;
11
11
 
12
12
  using torch::Device;
13
+ using torch::Scalar;
13
14
  using torch::ScalarType;
14
15
  using torch::Tensor;
16
+ using torch::QScheme;
17
+ using torch::Generator;
18
+ using torch::TensorOptions;
19
+ using torch::Layout;
20
+ using torch::MemoryFormat;
21
+ using torch::IntArrayRef;
22
+ using torch::TensorList;
23
+ using torch::Storage;
24
+
25
// Translate C++ exceptions from a binding body into Ruby exceptions.
// Put HANDLE_TH_ERRORS at the start of the body and END_HANDLE_TH_ERRORS at
// the end; together they open and close a single try block.
//   torch::Error      -> RuntimeError (message without the C++ backtrace)
//   Rice::Exception   -> the Ruby class the Rice exception carries
//   std::exception    -> RuntimeError
// torch::Error is caught first; presumably it derives from std::exception, so
// this handler must stay ahead of the std::exception one.
// NOTE(review): rb_raise does not return and is called from inside the catch
// handlers, so normal C++ exception cleanup for the caught object may be
// skipped. Confirm this is acceptable for the toolchains in use.
#define HANDLE_TH_ERRORS \
  try {

#define END_HANDLE_TH_ERRORS \
  } catch (const torch::Error& ex) { \
    rb_raise(rb_eRuntimeError, "%s", ex.what_without_backtrace()); \
  } catch (const Rice::Exception& ex) { \
    rb_raise(ex.class_of(), "%s", ex.what()); \
  } catch (const std::exception& ex) { \
    rb_raise(rb_eRuntimeError, "%s", ex.what()); \
  }
15
36
 
16
- // need to wrap torch::IntArrayRef() since
17
- // it doesn't own underlying data
18
- class IntArrayRef {
19
- std::vector<int64_t> vec;
20
- public:
21
- IntArrayRef(Object o) {
22
- Array a = Array(o);
23
- for (size_t i = 0; i < a.size(); i++) {
24
- vec.push_back(from_ruby<int64_t>(a[i]));
25
- }
26
- }
27
- operator torch::IntArrayRef() {
28
- return torch::IntArrayRef(vec);
29
- }
30
- };
31
-
32
- template<>
33
- inline
34
- IntArrayRef from_ruby<IntArrayRef>(Object x)
35
- {
36
- return IntArrayRef(x);
37
- }
38
-
39
- // for now
40
- class Scalar {
41
- torch::Scalar value;
42
- public:
43
- Scalar(Object o) {
44
- // TODO cast based on Ruby type
45
- if (o.rb_type() == T_FIXNUM) {
46
- value = torch::Scalar(from_ruby<int64_t>(o));
47
- } else {
48
- value = torch::Scalar(from_ruby<float>(o));
49
- }
50
- }
51
- operator torch::Scalar() {
52
- return value;
53
- }
54
- };
37
+ #define RETURN_NIL \
38
+ return Qnil;
55
39
 
56
40
  template<>
57
41
  inline
58
- Scalar from_ruby<Scalar>(Object x)
42
+ std::vector<int64_t> from_ruby<std::vector<int64_t>>(Object x)
59
43
  {
60
- return Scalar(x);
44
+ Array a = Array(x);
45
+ std::vector<int64_t> vec(a.size());
46
+ for (size_t i = 0; i < a.size(); i++) {
47
+ vec[i] = from_ruby<int64_t>(a[i]);
48
+ }
49
+ return vec;
61
50
  }
62
51
 
63
- class TensorList {
64
- std::vector<torch::Tensor> vec;
65
- public:
66
- TensorList(Object o) {
67
- Array a = Array(o);
68
- for (size_t i = 0; i < a.size(); i++) {
69
- vec.push_back(from_ruby<torch::Tensor>(a[i]));
70
- }
71
- }
72
- operator torch::TensorList() {
73
- return torch::TensorList(vec);
74
- }
75
- };
76
-
77
52
  template<>
78
53
  inline
79
- TensorList from_ruby<TensorList>(Object x)
54
+ std::vector<Tensor> from_ruby<std::vector<Tensor>>(Object x)
80
55
  {
81
- return TensorList(x);
56
+ Array a = Array(x);
57
+ std::vector<Tensor> vec(a.size());
58
+ for (size_t i = 0; i < a.size(); i++) {
59
+ vec[i] = from_ruby<Tensor>(a[i]);
60
+ }
61
+ return vec;
82
62
  }
83
63
 
84
64
  class FanModeType {
@@ -147,51 +127,35 @@ NonlinearityType from_ruby<NonlinearityType>(Object x)
147
127
  return NonlinearityType(x);
148
128
  }
149
129
 
150
- class MyReduction {
151
- Object value;
130
+ class OptionalTensor {
131
+ torch::Tensor value;
152
132
  public:
153
- MyReduction(Object o) {
154
- value = o;
155
- }
156
- operator int64_t() {
157
- if (value.is_nil()) {
158
- return torch::Reduction::None;
159
- }
160
-
161
- std::string s = String(value).str();
162
- if (s == "mean") {
163
- return torch::Reduction::Mean;
164
- } else if (s == "sum") {
165
- return torch::Reduction::Sum;
166
- } else if (s == "none") {
167
- return torch::Reduction::None;
133
+ OptionalTensor(Object o) {
134
+ if (o.is_nil()) {
135
+ value = {};
168
136
  } else {
169
- throw std::runtime_error("Unsupported reduction: " + s);
137
+ value = from_ruby<torch::Tensor>(o);
170
138
  }
171
139
  }
140
+ OptionalTensor(torch::Tensor o) {
141
+ value = o;
142
+ }
143
+ operator torch::Tensor() const {
144
+ return value;
145
+ }
172
146
  };
173
147
 
174
148
  template<>
175
149
  inline
176
- MyReduction from_ruby<MyReduction>(Object x)
150
+ Scalar from_ruby<Scalar>(Object x)
177
151
  {
178
- return MyReduction(x);
152
+ if (x.rb_type() == T_FIXNUM) {
153
+ return torch::Scalar(from_ruby<int64_t>(x));
154
+ } else {
155
+ return torch::Scalar(from_ruby<double>(x));
156
+ }
179
157
  }
180
158
 
181
- class OptionalTensor {
182
- Object value;
183
- public:
184
- OptionalTensor(Object o) {
185
- value = o;
186
- }
187
- operator torch::Tensor() {
188
- if (value.is_nil()) {
189
- return {};
190
- }
191
- return from_ruby<torch::Tensor>(value);
192
- }
193
- };
194
-
195
159
  template<>
196
160
  inline
197
161
  OptionalTensor from_ruby<OptionalTensor>(Object x)
@@ -221,9 +185,35 @@ torch::optional<int64_t> from_ruby<torch::optional<int64_t>>(Object x)
221
185
  }
222
186
  }
223
187
 
224
- Object wrap(std::tuple<torch::Tensor, torch::Tensor> x);
225
- Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x);
226
- Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x);
227
- Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x);
228
- Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x);
229
- Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x);
188
+ template<>
189
+ inline
190
+ torch::optional<double> from_ruby<torch::optional<double>>(Object x)
191
+ {
192
+ if (x.is_nil()) {
193
+ return torch::nullopt;
194
+ } else {
195
+ return torch::optional<double>{from_ruby<double>(x)};
196
+ }
197
+ }
198
+
199
+ template<>
200
+ inline
201
+ torch::optional<bool> from_ruby<torch::optional<bool>>(Object x)
202
+ {
203
+ if (x.is_nil()) {
204
+ return torch::nullopt;
205
+ } else {
206
+ return torch::optional<bool>{from_ruby<bool>(x)};
207
+ }
208
+ }
209
+
210
+ template<>
211
+ inline
212
+ torch::optional<Scalar> from_ruby<torch::optional<Scalar>>(Object x)
213
+ {
214
+ if (x.is_nil()) {
215
+ return torch::nullopt;
216
+ } else {
217
+ return torch::optional<Scalar>{from_ruby<Scalar>(x)};
218
+ }
219
+ }