torch-rb 0.3.4 → 0.4.1

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,373 @@
1
+ // adapted from PyTorch - python_arg_parser.h
2
+
3
+ #pragma once
4
+
5
+ #include <torch/torch.h>
6
+ #include <rice/Exception.hpp>
7
+
8
+ #include "templates.h"
9
+ #include "utils.h"
10
+
11
// Kinds of parameter a signature format string can declare (adapted from
// PyTorch). The enumerator order determines the underlying values, so new
// kinds should be appended at the end.
enum class ParameterType {
  TENSOR,
  SCALAR,
  INT64,
  DOUBLE,
  COMPLEX,
  TENSOR_LIST,
  INT_LIST,
  GENERATOR,
  BOOL,
  STORAGE,
  PYOBJECT,
  SCALARTYPE,
  LAYOUT,
  MEMORY_FORMAT,
  DEVICE,
  STRING,
  DIMNAME,
  DIMNAME_LIST,
  QSCHEME,
  FLOAT_LIST
};
16
+
17
+ struct FunctionParameter {
18
+ FunctionParameter(const std::string& fmt, bool keyword_only);
19
+
20
+ bool check(VALUE obj, int argnum);
21
+
22
+ void set_default_str(const std::string& str);
23
+ std::string type_name() const;
24
+
25
+ ParameterType type_;
26
+ bool optional;
27
+ bool allow_none;
28
+ bool keyword_only;
29
+ bool allow_numbers_as_tensors = false;
30
+ int size;
31
+ std::string name;
32
+ VALUE ruby_name;
33
+ at::SmallVector<VALUE, 5> numpy_python_names;
34
+ at::Scalar default_scalar;
35
+ std::vector<int64_t> default_intlist;
36
+ union {
37
+ bool default_bool;
38
+ int64_t default_int;
39
+ double default_double;
40
+ double default_complex[2]; // see Scalar
41
+ at::ScalarType default_scalartype;
42
+ at::Layout default_layout;
43
+ };
44
+ };
45
+
46
+ struct FunctionSignature {
47
+ explicit FunctionSignature(const std::string& fmt, int index);
48
+
49
+ bool parse(VALUE self, VALUE args, VALUE kwargs, std::vector<VALUE>& dst, bool raise_exception);
50
+
51
+ std::string toString() const;
52
+
53
+ std::string name;
54
+ std::vector<FunctionParameter> params;
55
+ // std::vector<py::handle> overloaded_args;
56
+ ssize_t min_args;
57
+ ssize_t max_args;
58
+ ssize_t max_pos_args;
59
+ int index;
60
+ bool hidden;
61
+ bool deprecated;
62
+ bool disable_torch_function;
63
+ };
64
+
65
+ struct RubyArgs {
66
+ RubyArgs(const FunctionSignature& signature, std::vector<VALUE> &args)
67
+ : signature(signature)
68
+ , args(args)
69
+ , idx(signature.index) {}
70
+
71
+ const FunctionSignature& signature;
72
+ std::vector<VALUE> args;
73
+ int idx;
74
+
75
+ inline at::Tensor tensor(int i);
76
+ inline OptionalTensor optionalTensor(int i);
77
+ inline at::Scalar scalar(int i);
78
+ // inline at::Scalar scalarWithDefault(int i, at::Scalar default_scalar);
79
+ inline std::vector<at::Tensor> tensorlist(int i);
80
+ template<int N>
81
+ inline std::array<at::Tensor, N> tensorlist_n(int i);
82
+ inline std::vector<int64_t> intlist(int i);
83
+ // inline c10::OptionalArray<int64_t> intlistOptional(int i);
84
+ // inline std::vector<int64_t> intlistWithDefault(int i, std::vector<int64_t> default_intlist);
85
+ inline c10::optional<at::Generator> generator(int i);
86
+ inline at::Storage storage(int i);
87
+ inline at::ScalarType scalartype(int i);
88
+ // inline at::ScalarType scalartypeWithDefault(int i, at::ScalarType default_scalartype);
89
+ inline c10::optional<at::ScalarType> scalartypeOptional(int i);
90
+ inline c10::optional<at::Scalar> scalarOptional(int i);
91
+ inline c10::optional<int64_t> toInt64Optional(int i);
92
+ inline c10::optional<bool> toBoolOptional(int i);
93
+ inline c10::optional<double> toDoubleOptional(int i);
94
+ // inline c10::OptionalArray<double> doublelistOptional(int i);
95
+ // inline at::Layout layout(int i);
96
+ // inline at::Layout layoutWithDefault(int i, at::Layout default_layout);
97
+ inline c10::optional<at::Layout> layoutOptional(int i);
98
+ inline at::Device device(int i);
99
+ // inline at::Device deviceWithDefault(int i, const at::Device& default_device);
100
+ // inline c10::optional<at::Device> deviceOptional(int i);
101
+ // inline at::Dimname dimname(int i);
102
+ // inline std::vector<at::Dimname> dimnamelist(int i);
103
+ // inline c10::optional<std::vector<at::Dimname>> toDimnameListOptional(int i);
104
+ inline at::MemoryFormat memoryformat(int i);
105
+ inline c10::optional<at::MemoryFormat> memoryformatOptional(int i);
106
+ // inline at::QScheme toQScheme(int i);
107
+ inline std::string string(int i);
108
+ // inline c10::optional<std::string> stringOptional(int i);
109
+ // inline PyObject* pyobject(int i);
110
+ inline int64_t toInt64(int i);
111
+ // inline int64_t toInt64WithDefault(int i, int64_t default_int);
112
+ inline double toDouble(int i);
113
+ // inline double toDoubleWithDefault(int i, double default_double);
114
+ // inline c10::complex<double> toComplex(int i);
115
+ // inline c10::complex<double> toComplexWithDefault(int i, c10::complex<double> default_complex);
116
+ inline bool toBool(int i);
117
+ // inline bool toBoolWithDefault(int i, bool default_bool);
118
+ inline bool isNone(int i);
119
+ };
120
+
121
+ inline at::Tensor RubyArgs::tensor(int i) {
122
+ return from_ruby<torch::Tensor>(args[i]);
123
+ }
124
+
125
+ inline OptionalTensor RubyArgs::optionalTensor(int i) {
126
+ if (NIL_P(args[i])) return OptionalTensor(Nil);
127
+ return tensor(i);
128
+ }
129
+
130
+ inline at::Scalar RubyArgs::scalar(int i) {
131
+ if (NIL_P(args[i])) return signature.params[i].default_scalar;
132
+ return from_ruby<torch::Scalar>(args[i]);
133
+ }
134
+
135
+ inline std::vector<at::Tensor> RubyArgs::tensorlist(int i) {
136
+ if (NIL_P(args[i])) return std::vector<at::Tensor>();
137
+ return from_ruby<std::vector<Tensor>>(args[i]);
138
+ }
139
+
140
+ template<int N>
141
+ inline std::array<at::Tensor, N> RubyArgs::tensorlist_n(int i) {
142
+ auto res = std::array<at::Tensor, N>();
143
+ if (NIL_P(args[i])) return res;
144
+ VALUE arg = args[i];
145
+ Check_Type(arg, T_ARRAY);
146
+ auto size = RARRAY_LEN(arg);
147
+ if (size != N) {
148
+ rb_raise(rb_eArgError, "expected array of %d elements but got %d", N, (int)size);
149
+ }
150
+ for (int idx = 0; idx < size; idx++) {
151
+ VALUE obj = rb_ary_entry(arg, idx);
152
+ res[idx] = from_ruby<Tensor>(obj);
153
+ }
154
+ return res;
155
+ }
156
+
157
+ inline std::vector<int64_t> RubyArgs::intlist(int i) {
158
+ if (NIL_P(args[i])) return signature.params[i].default_intlist;
159
+
160
+ VALUE arg = args[i];
161
+ auto size = signature.params[i].size;
162
+ if (size > 0 && FIXNUM_P(arg)) {
163
+ return std::vector<int64_t>(size, FIX2INT(arg));
164
+ }
165
+
166
+ size = RARRAY_LEN(arg);
167
+ std::vector<int64_t> res(size);
168
+ for (idx = 0; idx < size; idx++) {
169
+ VALUE obj = rb_ary_entry(arg, idx);
170
+ if (FIXNUM_P(obj)) {
171
+ res[idx] = from_ruby<int64_t>(obj);
172
+ } else {
173
+ rb_raise(rb_eArgError, "%s(): argument '%s' must be %s, but found element of type %s at pos %d",
174
+ signature.name.c_str(), signature.params[i].name.c_str(),
175
+ signature.params[i].type_name().c_str(), rb_obj_classname(obj), idx + 1);
176
+ }
177
+ }
178
+ return res;
179
+ }
180
+
181
+ inline c10::optional<at::Generator> RubyArgs::generator(int i) {
182
+ if (NIL_P(args[i])) return c10::nullopt;
183
+ throw std::runtime_error("generator not supported yet");
184
+ }
185
+
186
+ inline at::Storage RubyArgs::storage(int i) {
187
+ if (NIL_P(args[i])) return at::Storage();
188
+ throw std::runtime_error("storage not supported yet");
189
+ }
190
+
191
+ inline ScalarType RubyArgs::scalartype(int i) {
192
+ if (NIL_P(args[i])) {
193
+ auto scalartype = signature.params[i].default_scalartype;
194
+ return (scalartype == at::ScalarType::Undefined) ? at::typeMetaToScalarType(at::get_default_dtype()) : scalartype;
195
+ }
196
+
197
+ static std::unordered_map<VALUE, ScalarType> dtype_map = {
198
+ {ID2SYM(rb_intern("uint8")), ScalarType::Byte},
199
+ {ID2SYM(rb_intern("int8")), ScalarType::Char},
200
+ {ID2SYM(rb_intern("short")), ScalarType::Short},
201
+ {ID2SYM(rb_intern("int16")), ScalarType::Short},
202
+ {ID2SYM(rb_intern("int")), ScalarType::Int},
203
+ {ID2SYM(rb_intern("int32")), ScalarType::Int},
204
+ {ID2SYM(rb_intern("long")), ScalarType::Long},
205
+ {ID2SYM(rb_intern("int64")), ScalarType::Long},
206
+ {ID2SYM(rb_intern("float")), ScalarType::Float},
207
+ {ID2SYM(rb_intern("float32")), ScalarType::Float},
208
+ {ID2SYM(rb_intern("double")), ScalarType::Double},
209
+ {ID2SYM(rb_intern("float64")), ScalarType::Double},
210
+ {ID2SYM(rb_intern("complex_half")), ScalarType::ComplexHalf},
211
+ {ID2SYM(rb_intern("complex_float")), ScalarType::ComplexFloat},
212
+ {ID2SYM(rb_intern("complex_double")), ScalarType::ComplexDouble},
213
+ {ID2SYM(rb_intern("bool")), ScalarType::Bool},
214
+ {ID2SYM(rb_intern("qint8")), ScalarType::QInt8},
215
+ {ID2SYM(rb_intern("quint8")), ScalarType::QUInt8},
216
+ {ID2SYM(rb_intern("qint32")), ScalarType::QInt32},
217
+ {ID2SYM(rb_intern("bfloat16")), ScalarType::BFloat16},
218
+ };
219
+
220
+ auto it = dtype_map.find(args[i]);
221
+ if (it == dtype_map.end()) {
222
+ rb_raise(rb_eArgError, "invalid dtype: %s", THPUtils_unpackSymbol(args[i]).c_str());
223
+ }
224
+ return it->second;
225
+ }
226
+
227
+ inline c10::optional<ScalarType> RubyArgs::scalartypeOptional(int i) {
228
+ if (NIL_P(args[i])) return c10::nullopt;
229
+ return scalartype(i);
230
+ }
231
+
232
+ inline c10::optional<Scalar> RubyArgs::scalarOptional(int i) {
233
+ if (NIL_P(args[i])) return c10::nullopt;
234
+ return scalar(i);
235
+ }
236
+
237
+ inline c10::optional<int64_t> RubyArgs::toInt64Optional(int i) {
238
+ if (NIL_P(args[i])) return c10::nullopt;
239
+ return toInt64(i);
240
+ }
241
+
242
+ inline c10::optional<bool> RubyArgs::toBoolOptional(int i) {
243
+ if (NIL_P(args[i])) return c10::nullopt;
244
+ return toBool(i);
245
+ }
246
+
247
+ inline c10::optional<double> RubyArgs::toDoubleOptional(int i) {
248
+ if (NIL_P(args[i])) return c10::nullopt;
249
+ return toDouble(i);
250
+ }
251
+
252
+ inline c10::optional<at::Layout> RubyArgs::layoutOptional(int i) {
253
+ if (NIL_P(args[i])) return c10::nullopt;
254
+
255
+ static std::unordered_map<VALUE, Layout> layout_map = {
256
+ {ID2SYM(rb_intern("strided")), Layout::Strided},
257
+ };
258
+
259
+ auto it = layout_map.find(args[i]);
260
+ if (it == layout_map.end()) {
261
+ rb_raise(rb_eArgError, "invalid layout: %s", THPUtils_unpackSymbol(args[i]).c_str());
262
+ }
263
+ return it->second;
264
+ }
265
+
266
+ inline at::Device RubyArgs::device(int i) {
267
+ if (NIL_P(args[i])) {
268
+ return at::Device("cpu");
269
+ }
270
+ const std::string &device_str = THPUtils_unpackString(args[i]);
271
+ return at::Device(device_str);
272
+ }
273
+
274
+ inline at::MemoryFormat RubyArgs::memoryformat(int i) {
275
+ if (NIL_P(args[i])) return at::MemoryFormat::Contiguous;
276
+ throw std::runtime_error("memoryformat not supported yet");
277
+ }
278
+
279
+ inline c10::optional<at::MemoryFormat> RubyArgs::memoryformatOptional(int i) {
280
+ if (NIL_P(args[i])) return c10::nullopt;
281
+ return memoryformat(i);
282
+ }
283
+
284
+ inline std::string RubyArgs::string(int i) {
285
+ return from_ruby<std::string>(args[i]);
286
+ }
287
+
288
+ inline int64_t RubyArgs::toInt64(int i) {
289
+ if (NIL_P(args[i])) return signature.params[i].default_int;
290
+ return from_ruby<int64_t>(args[i]);
291
+ }
292
+
293
+ inline double RubyArgs::toDouble(int i) {
294
+ if (NIL_P(args[i])) return signature.params[i].default_double;
295
+ return from_ruby<double>(args[i]);
296
+ }
297
+
298
+ inline bool RubyArgs::toBool(int i) {
299
+ if (NIL_P(args[i])) return signature.params[i].default_bool;
300
+ return RTEST(args[i]);
301
+ }
302
+
303
+ inline bool RubyArgs::isNone(int i) {
304
+ return NIL_P(args[i]);
305
+ }
306
+
307
+ struct RubyArgParser {
308
+ std::vector<FunctionSignature> signatures_;
309
+ std::string function_name;
310
+ ssize_t max_args;
311
+
312
+ public:
313
+ RubyArgParser(std::vector<std::string> fmts) : max_args(0) {
314
+ int index = 0;
315
+ for (auto& fmt : fmts) {
316
+ signatures_.emplace_back(fmt, index);
317
+ ++index;
318
+ }
319
+ for (auto& signature : signatures_) {
320
+ if (signature.max_args > max_args) {
321
+ max_args = signature.max_args;
322
+ }
323
+ }
324
+ if (signatures_.size() > 0) {
325
+ function_name = signatures_[0].name;
326
+ }
327
+
328
+ // Check deprecated signatures last
329
+ std::stable_partition(signatures_.begin(), signatures_.end(),
330
+ [](const FunctionSignature & sig) {
331
+ return !sig.deprecated;
332
+ });
333
+ }
334
+
335
+ RubyArgs parse(VALUE self, int argc, VALUE* argv, std::vector<VALUE> &parsed_args) {
336
+ VALUE args, kwargs;
337
+ rb_scan_args(argc, argv, "*:", &args, &kwargs);
338
+
339
+ if (signatures_.size() == 1) {
340
+ auto& signature = signatures_[0];
341
+ signature.parse(self, args, kwargs, parsed_args, true);
342
+ return RubyArgs(signature, parsed_args);
343
+ }
344
+
345
+ for (auto& signature : signatures_) {
346
+ if (signature.parse(self, args, kwargs, parsed_args, false)) {
347
+ return RubyArgs(signature, parsed_args);
348
+ }
349
+ }
350
+
351
+ print_error(self, args, kwargs, parsed_args);
352
+
353
+ // TODO better message
354
+ rb_raise(rb_eArgError, "No matching signatures");
355
+ }
356
+
357
+ void print_error(VALUE self, VALUE args, VALUE kwargs, std::vector<VALUE>& parsed_args) {
358
+ ssize_t num_args = (NIL_P(args) ? 0 : RARRAY_LEN(args)) + (NIL_P(kwargs) ? 0 : RHASH_SIZE(kwargs));
359
+ std::vector<int> plausible_idxs;
360
+ ssize_t i = 0;
361
+ for (auto& signature : signatures_) {
362
+ if (num_args >= signature.min_args && num_args <= signature.max_args && !signature.hidden) {
363
+ plausible_idxs.push_back(i);
364
+ }
365
+ i++;
366
+ }
367
+
368
+ if (plausible_idxs.size() == 1) {
369
+ auto& signature = signatures_[plausible_idxs[0]];
370
+ signature.parse(self, args, kwargs, parsed_args, true);
371
+ }
372
+ }
373
+ };
@@ -10,75 +10,55 @@
10
10
  using namespace Rice;
11
11
 
12
12
  using torch::Device;
13
+ using torch::Scalar;
13
14
  using torch::ScalarType;
14
15
  using torch::Tensor;
16
+ using torch::QScheme;
17
+ using torch::Generator;
18
+ using torch::TensorOptions;
19
+ using torch::Layout;
20
+ using torch::MemoryFormat;
21
+ using torch::IntArrayRef;
22
+ using torch::TensorList;
23
+ using torch::Storage;
24
+
25
// HANDLE_TH_ERRORS opens a try block around a binding body and
// END_HANDLE_TH_ERRORS closes it, re-raising C++ exceptions in Ruby:
//   torch::Error    -> RuntimeError (message without the C++ backtrace)
//   Rice::Exception -> the Ruby class carried by the exception
//   std::exception  -> RuntimeError
// NOTE(review): rb_raise longjmps out of the catch handler, so the caught C++
// exception object is never released; consider raising after the handler exits.
#define HANDLE_TH_ERRORS \
  try {

#define END_HANDLE_TH_ERRORS \
  } \
  catch (const torch::Error& ex) { \
    rb_raise(rb_eRuntimeError, "%s", ex.what_without_backtrace()); \
  } \
  catch (const Rice::Exception& ex) { \
    rb_raise(ex.class_of(), "%s", ex.what()); \
  } \
  catch (const std::exception& ex) { \
    rb_raise(rb_eRuntimeError, "%s", ex.what()); \
  }
15
36
 
16
- // need to wrap torch::IntArrayRef() since
17
- // it doesn't own underlying data
18
- class IntArrayRef {
19
- std::vector<int64_t> vec;
20
- public:
21
- IntArrayRef(Object o) {
22
- Array a = Array(o);
23
- for (size_t i = 0; i < a.size(); i++) {
24
- vec.push_back(from_ruby<int64_t>(a[i]));
25
- }
26
- }
27
- operator torch::IntArrayRef() {
28
- return torch::IntArrayRef(vec);
29
- }
30
- };
31
-
32
- template<>
33
- inline
34
- IntArrayRef from_ruby<IntArrayRef>(Object x)
35
- {
36
- return IntArrayRef(x);
37
- }
38
-
39
- // for now
40
- class Scalar {
41
- torch::Scalar value;
42
- public:
43
- Scalar(Object o) {
44
- // TODO cast based on Ruby type
45
- if (o.rb_type() == T_FIXNUM) {
46
- value = torch::Scalar(from_ruby<int64_t>(o));
47
- } else {
48
- value = torch::Scalar(from_ruby<float>(o));
49
- }
50
- }
51
- operator torch::Scalar() {
52
- return value;
53
- }
54
- };
37
// Expands to a statement returning Ruby nil from a binding function.
#define RETURN_NIL return Qnil;
55
39
 
56
40
  template<>
57
41
  inline
58
- Scalar from_ruby<Scalar>(Object x)
42
+ std::vector<int64_t> from_ruby<std::vector<int64_t>>(Object x)
59
43
  {
60
- return Scalar(x);
44
+ Array a = Array(x);
45
+ std::vector<int64_t> vec(a.size());
46
+ for (size_t i = 0; i < a.size(); i++) {
47
+ vec[i] = from_ruby<int64_t>(a[i]);
48
+ }
49
+ return vec;
61
50
  }
62
51
 
63
- class TensorList {
64
- std::vector<torch::Tensor> vec;
65
- public:
66
- TensorList(Object o) {
67
- Array a = Array(o);
68
- for (size_t i = 0; i < a.size(); i++) {
69
- vec.push_back(from_ruby<torch::Tensor>(a[i]));
70
- }
71
- }
72
- operator torch::TensorList() {
73
- return torch::TensorList(vec);
74
- }
75
- };
76
-
77
52
  template<>
78
53
  inline
79
- TensorList from_ruby<TensorList>(Object x)
54
+ std::vector<Tensor> from_ruby<std::vector<Tensor>>(Object x)
80
55
  {
81
- return TensorList(x);
56
+ Array a = Array(x);
57
+ std::vector<Tensor> vec(a.size());
58
+ for (size_t i = 0; i < a.size(); i++) {
59
+ vec[i] = from_ruby<Tensor>(a[i]);
60
+ }
61
+ return vec;
82
62
  }
83
63
 
84
64
  class FanModeType {
@@ -147,51 +127,35 @@ NonlinearityType from_ruby<NonlinearityType>(Object x)
147
127
  return NonlinearityType(x);
148
128
  }
149
129
 
150
- class MyReduction {
151
- Object value;
130
+ class OptionalTensor {
131
+ torch::Tensor value;
152
132
  public:
153
- MyReduction(Object o) {
154
- value = o;
155
- }
156
- operator int64_t() {
157
- if (value.is_nil()) {
158
- return torch::Reduction::None;
159
- }
160
-
161
- std::string s = String(value).str();
162
- if (s == "mean") {
163
- return torch::Reduction::Mean;
164
- } else if (s == "sum") {
165
- return torch::Reduction::Sum;
166
- } else if (s == "none") {
167
- return torch::Reduction::None;
133
+ OptionalTensor(Object o) {
134
+ if (o.is_nil()) {
135
+ value = {};
168
136
  } else {
169
- throw std::runtime_error("Unsupported reduction: " + s);
137
+ value = from_ruby<torch::Tensor>(o);
170
138
  }
171
139
  }
140
+ OptionalTensor(torch::Tensor o) {
141
+ value = o;
142
+ }
143
+ operator torch::Tensor() const {
144
+ return value;
145
+ }
172
146
  };
173
147
 
174
148
  template<>
175
149
  inline
176
- MyReduction from_ruby<MyReduction>(Object x)
150
+ Scalar from_ruby<Scalar>(Object x)
177
151
  {
178
- return MyReduction(x);
152
+ if (x.rb_type() == T_FIXNUM) {
153
+ return torch::Scalar(from_ruby<int64_t>(x));
154
+ } else {
155
+ return torch::Scalar(from_ruby<double>(x));
156
+ }
179
157
  }
180
158
 
181
- class OptionalTensor {
182
- Object value;
183
- public:
184
- OptionalTensor(Object o) {
185
- value = o;
186
- }
187
- operator torch::Tensor() {
188
- if (value.is_nil()) {
189
- return {};
190
- }
191
- return from_ruby<torch::Tensor>(value);
192
- }
193
- };
194
-
195
159
  template<>
196
160
  inline
197
161
  OptionalTensor from_ruby<OptionalTensor>(Object x)
@@ -221,9 +185,35 @@ torch::optional<int64_t> from_ruby<torch::optional<int64_t>>(Object x)
221
185
  }
222
186
  }
223
187
 
224
- Object wrap(std::tuple<torch::Tensor, torch::Tensor> x);
225
- Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor> x);
226
- Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x);
227
- Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor> x);
228
- Object wrap(std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, int64_t> x);
229
- Object wrap(std::tuple<torch::Tensor, torch::Tensor, double, int64_t> x);
188
+ template<>
189
+ inline
190
+ torch::optional<double> from_ruby<torch::optional<double>>(Object x)
191
+ {
192
+ if (x.is_nil()) {
193
+ return torch::nullopt;
194
+ } else {
195
+ return torch::optional<double>{from_ruby<double>(x)};
196
+ }
197
+ }
198
+
199
+ template<>
200
+ inline
201
+ torch::optional<bool> from_ruby<torch::optional<bool>>(Object x)
202
+ {
203
+ if (x.is_nil()) {
204
+ return torch::nullopt;
205
+ } else {
206
+ return torch::optional<bool>{from_ruby<bool>(x)};
207
+ }
208
+ }
209
+
210
+ template<>
211
+ inline
212
+ torch::optional<Scalar> from_ruby<torch::optional<Scalar>>(Object x)
213
+ {
214
+ if (x.is_nil()) {
215
+ return torch::nullopt;
216
+ } else {
217
+ return torch::optional<Scalar>{from_ruby<Scalar>(x)};
218
+ }
219
+ }