torch-rb 0.6.0 → 0.8.2

Files changed (44)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +21 -0
  3. data/README.md +23 -41
  4. data/codegen/function.rb +2 -0
  5. data/codegen/generate_functions.rb +43 -6
  6. data/codegen/native_functions.yaml +2007 -1327
  7. data/ext/torch/backends.cpp +17 -0
  8. data/ext/torch/cuda.cpp +5 -5
  9. data/ext/torch/device.cpp +13 -6
  10. data/ext/torch/ext.cpp +22 -5
  11. data/ext/torch/extconf.rb +1 -3
  12. data/ext/torch/fft.cpp +13 -0
  13. data/ext/torch/fft_functions.h +6 -0
  14. data/ext/torch/ivalue.cpp +31 -33
  15. data/ext/torch/linalg.cpp +13 -0
  16. data/ext/torch/linalg_functions.h +6 -0
  17. data/ext/torch/nn.cpp +34 -34
  18. data/ext/torch/random.cpp +5 -5
  19. data/ext/torch/ruby_arg_parser.cpp +2 -2
  20. data/ext/torch/ruby_arg_parser.h +23 -12
  21. data/ext/torch/special.cpp +13 -0
  22. data/ext/torch/special_functions.h +6 -0
  23. data/ext/torch/templates.h +111 -133
  24. data/ext/torch/tensor.cpp +80 -67
  25. data/ext/torch/torch.cpp +30 -21
  26. data/ext/torch/utils.h +3 -4
  27. data/ext/torch/wrap_outputs.h +72 -65
  28. data/lib/torch/inspector.rb +5 -2
  29. data/lib/torch/nn/convnd.rb +2 -0
  30. data/lib/torch/nn/functional_attention.rb +241 -0
  31. data/lib/torch/nn/module.rb +2 -0
  32. data/lib/torch/nn/module_list.rb +49 -0
  33. data/lib/torch/nn/multihead_attention.rb +123 -0
  34. data/lib/torch/nn/transformer.rb +92 -0
  35. data/lib/torch/nn/transformer_decoder.rb +25 -0
  36. data/lib/torch/nn/transformer_decoder_layer.rb +43 -0
  37. data/lib/torch/nn/transformer_encoder.rb +25 -0
  38. data/lib/torch/nn/transformer_encoder_layer.rb +36 -0
  39. data/lib/torch/nn/utils.rb +16 -0
  40. data/lib/torch/tensor.rb +2 -0
  41. data/lib/torch/utils/data/data_loader.rb +2 -0
  42. data/lib/torch/version.rb +1 -1
  43. data/lib/torch.rb +11 -0
  44. metadata +20 -5
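Most of the churn in the C++ extension comes from upgrading the Rice bindings from the 3.x per-class headers to Rice 4: includes like <rice/Array.hpp> collapse into the single <rice/rice.hpp>, and free-function `from_ruby<T>`/`to_ruby<T>` specializations become `Rice::detail::From_Ruby<T>`/`To_Ruby<T>` classes paired with a `Type<T>` trait. On the Ruby side, the new files add multi-head attention and transformer layers (`Torch::NN::MultiheadAttention`, `Torch::NN::Transformer`, and friends) plus new `fft.cpp`, `linalg.cpp`, and `special.cpp` bindings. As a minimal sketch of the Rice 4 conversion protocol the diffs below follow, using a made-up `Celsius` type that is not part of torch-rb:

#include <rice/rice.hpp>

// Made-up example type; torch-rb converts Tensor, Scalar, etc. the same way.
struct Celsius { double degrees; };

namespace Rice::detail
{
  // Rice 4 checks Type<T>::verify() before it will bind T anywhere; an
  // always-true specialization opts the type in without registering a wrapper.
  template<>
  struct Type<Celsius>
  {
    static bool verify() { return true; }
  };

  // Rice 3's `template<> Celsius from_ruby<Celsius>(Object)` becomes a
  // class whose convert() takes a raw VALUE.
  template<>
  class From_Ruby<Celsius>
  {
  public:
    Celsius convert(VALUE x)
    {
      return Celsius{From_Ruby<double>().convert(x)};
    }
  };
}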
data/ext/torch/templates.h CHANGED
@@ -4,8 +4,7 @@
 #undef isfinite
 #endif
 
-#include <rice/Array.hpp>
-#include <rice/Object.hpp>
+#include <rice/rice.hpp>
 
 using namespace Rice;
 
@@ -22,6 +21,10 @@ using torch::IntArrayRef;
 using torch::ArrayRef;
 using torch::TensorList;
 using torch::Storage;
+using ScalarList = ArrayRef<Scalar>;
+
+using torch::nn::init::FanModeType;
+using torch::nn::init::NonlinearityType;
 
 #define HANDLE_TH_ERRORS \
   try {
@@ -38,37 +41,42 @@ using torch::Storage;
 #define RETURN_NIL \
   return Qnil;
 
-template<>
-inline
-std::vector<int64_t> from_ruby<std::vector<int64_t>>(Object x)
-{
-  Array a = Array(x);
-  std::vector<int64_t> vec(a.size());
-  for (long i = 0; i < a.size(); i++) {
-    vec[i] = from_ruby<int64_t>(a[i]);
-  }
-  return vec;
-}
+class OptionalTensor {
+  torch::Tensor value;
+  public:
+    OptionalTensor(Object o) {
+      if (o.is_nil()) {
+        value = {};
+      } else {
+        value = Rice::detail::From_Ruby<torch::Tensor>().convert(o.value());
+      }
+    }
+    OptionalTensor(torch::Tensor o) {
+      value = o;
+    }
+    operator torch::Tensor() const {
+      return value;
+    }
+};
 
-template<>
-inline
-std::vector<Tensor> from_ruby<std::vector<Tensor>>(Object x)
+namespace Rice::detail
 {
-  Array a = Array(x);
-  std::vector<Tensor> vec(a.size());
-  for (long i = 0; i < a.size(); i++) {
-    vec[i] = from_ruby<Tensor>(a[i]);
-  }
-  return vec;
-}
+  template<>
+  struct Type<FanModeType>
+  {
+    static bool verify()
+    {
+      return true;
+    }
+  };
 
-class FanModeType {
-  std::string s;
+  template<>
+  class From_Ruby<FanModeType>
+  {
   public:
-    FanModeType(Object o) {
-      s = String(o).str();
-    }
-    operator torch::nn::init::FanModeType() {
+    FanModeType convert(VALUE x)
+    {
+      auto s = String(x).str();
       if (s == "fan_in") {
         return torch::kFanIn;
       } else if (s == "fan_out") {
@@ -77,22 +85,24 @@ class FanModeType {
         throw std::runtime_error("Unsupported nonlinearity type: " + s);
       }
     }
-};
-
-template<>
-inline
-FanModeType from_ruby<FanModeType>(Object x)
-{
-  return FanModeType(x);
-}
+  };
+
+  template<>
+  struct Type<NonlinearityType>
+  {
+    static bool verify()
+    {
+      return true;
+    }
+  };
 
-class NonlinearityType {
-  std::string s;
+  template<>
+  class From_Ruby<NonlinearityType>
+  {
   public:
-    NonlinearityType(Object o) {
-      s = String(o).str();
-    }
-    operator torch::nn::init::NonlinearityType() {
+    NonlinearityType convert(VALUE x)
+    {
+      auto s = String(x).str();
       if (s == "linear") {
         return torch::kLinear;
       } else if (s == "conv1d") {
@@ -119,102 +129,70 @@ class NonlinearityType {
         throw std::runtime_error("Unsupported nonlinearity type: " + s);
       }
     }
-};
+  };
+
+  template<>
+  struct Type<OptionalTensor>
+  {
+    static bool verify()
+    {
+      return true;
+    }
+  };
 
-template<>
-inline
-NonlinearityType from_ruby<NonlinearityType>(Object x)
-{
-  return NonlinearityType(x);
-}
+  template<>
+  class From_Ruby<OptionalTensor>
+  {
+  public:
+    OptionalTensor convert(VALUE x)
+    {
+      return OptionalTensor(x);
+    }
+  };
+
+  template<>
+  struct Type<Scalar>
+  {
+    static bool verify()
+    {
+      return true;
+    }
+  };
 
-class OptionalTensor {
-  torch::Tensor value;
+  template<>
+  class From_Ruby<Scalar>
+  {
   public:
-    OptionalTensor(Object o) {
-      if (o.is_nil()) {
-        value = {};
+    Scalar convert(VALUE x)
+    {
+      if (FIXNUM_P(x)) {
+        return torch::Scalar(From_Ruby<int64_t>().convert(x));
       } else {
-        value = from_ruby<torch::Tensor>(o);
+        return torch::Scalar(From_Ruby<double>().convert(x));
       }
     }
-    OptionalTensor(torch::Tensor o) {
-      value = o;
-    }
-    operator torch::Tensor() const {
-      return value;
+  };
+
+  template<typename T>
+  struct Type<torch::optional<T>>
+  {
+    static bool verify()
+    {
+      return true;
    }
-  };
-
-template<>
-inline
-Scalar from_ruby<Scalar>(Object x)
-{
-  if (x.rb_type() == T_FIXNUM) {
-    return torch::Scalar(from_ruby<int64_t>(x));
-  } else {
-    return torch::Scalar(from_ruby<double>(x));
-  }
-}
-
-template<>
-inline
-OptionalTensor from_ruby<OptionalTensor>(Object x)
-{
-  return OptionalTensor(x);
-}
-
-template<>
-inline
-torch::optional<torch::ScalarType> from_ruby<torch::optional<torch::ScalarType>>(Object x)
-{
-  if (x.is_nil()) {
-    return torch::nullopt;
-  } else {
-    return torch::optional<torch::ScalarType>{from_ruby<torch::ScalarType>(x)};
-  }
-}
-
-template<>
-inline
-torch::optional<int64_t> from_ruby<torch::optional<int64_t>>(Object x)
-{
-  if (x.is_nil()) {
-    return torch::nullopt;
-  } else {
-    return torch::optional<int64_t>{from_ruby<int64_t>(x)};
-  }
-}
+  };
 
-template<>
-inline
-torch::optional<double> from_ruby<torch::optional<double>>(Object x)
-{
-  if (x.is_nil()) {
-    return torch::nullopt;
-  } else {
-    return torch::optional<double>{from_ruby<double>(x)};
-  }
-}
-
-template<>
-inline
-torch::optional<bool> from_ruby<torch::optional<bool>>(Object x)
-{
-  if (x.is_nil()) {
-    return torch::nullopt;
-  } else {
-    return torch::optional<bool>{from_ruby<bool>(x)};
-  }
-}
-
-template<>
-inline
-torch::optional<Scalar> from_ruby<torch::optional<Scalar>>(Object x)
-{
-  if (x.is_nil()) {
-    return torch::nullopt;
-  } else {
-    return torch::optional<Scalar>{from_ruby<Scalar>(x)};
-  }
+  template<typename T>
+  class From_Ruby<torch::optional<T>>
+  {
+  public:
+    torch::optional<T> convert(VALUE x)
+    {
+      if (NIL_P(x)) {
+        return torch::nullopt;
+      } else {
+        return torch::optional<T>{From_Ruby<T>().convert(x)};
+      }
+    }
+  };
 }
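The net effect of the rewrite above: five near-identical nil-checking specializations (`ScalarType`, `int64_t`, `double`, `bool`, `Scalar`) collapse into one generic `From_Ruby<torch::optional<T>>`. A sketch of what that buys, as a hypothetical helper placed where the specializations above are visible (it is not part of the gem):

// Hypothetical helper for illustration: one template now handles
// nil-or-value for every T that already has a From_Ruby<T>.
template<typename T>
torch::optional<T> optional_from(Rice::Object o) {
  // nil maps to torch::nullopt; anything else converts as a plain T
  return Rice::detail::From_Ruby<torch::optional<T>>().convert(o.value());
}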
data/ext/torch/tensor.cpp CHANGED
@@ -1,7 +1,6 @@
 #include <torch/torch.h>
 
-#include <rice/Constructor.hpp>
-#include <rice/Module.hpp>
+#include <rice/rice.hpp>
 
 #include "tensor_functions.h"
 #include "ruby_arg_parser.h"
@@ -11,6 +10,39 @@
 using namespace Rice;
 using torch::indexing::TensorIndex;
 
+namespace Rice::detail
+{
+  template<typename T>
+  struct Type<c10::complex<T>>
+  {
+    static bool verify()
+    {
+      return true;
+    }
+  };
+
+  template<typename T>
+  class To_Ruby<c10::complex<T>>
+  {
+  public:
+    VALUE convert(c10::complex<T> const& x)
+    {
+      return rb_dbl_complex_new(x.real(), x.imag());
+    }
+  };
+}
+
+template<typename T>
+Array flat_data(Tensor& tensor) {
+  Tensor view = tensor.reshape({tensor.numel()});
+
+  Array a;
+  for (int i = 0; i < tensor.numel(); i++) {
+    a.push(view[i].item().to<T>());
+  }
+  return a;
+}
+
 Class rb_cTensor;
 
 std::vector<TensorIndex> index_vector(Array a) {
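Two things land in the hunk above: a `To_Ruby` conversion that turns `c10::complex<T>` into a Ruby `Complex` via `rb_dbl_complex_new`, and a `flat_data<T>` template that replaces the copy-pasted per-dtype loops in `_flat_data` (see the hunk near the end of this file). A sketch of how they combine, assuming it is compiled into this file so the helpers above are in scope; the function name is hypothetical:

// Hypothetical demo, not part of the gem: complex tensors can now
// round-trip to Ruby. Each element goes through the To_Ruby
// specialization above, i.e. rb_dbl_complex_new(real, imag).
Rice::Array demo_complex_flat_data() {
  torch::Tensor t = torch::ones({2}, torch::kComplexFloat);
  return flat_data<c10::complex<float>>(t);  // => [(1.0+0.0i), (1.0+0.0i)]
}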
@@ -23,19 +55,19 @@ std::vector<TensorIndex> index_vector(Array a) {
     obj = a[i];
 
     if (obj.is_instance_of(rb_cInteger)) {
-      indices.push_back(from_ruby<int64_t>(obj));
+      indices.push_back(Rice::detail::From_Ruby<int64_t>().convert(obj.value()));
     } else if (obj.is_instance_of(rb_cRange)) {
       torch::optional<int64_t> start_index = torch::nullopt;
       torch::optional<int64_t> stop_index = torch::nullopt;
 
       Object begin = obj.call("begin");
       if (!begin.is_nil()) {
-        start_index = from_ruby<int64_t>(begin);
+        start_index = Rice::detail::From_Ruby<int64_t>().convert(begin.value());
       }
 
       Object end = obj.call("end");
       if (!end.is_nil()) {
-        stop_index = from_ruby<int64_t>(end);
+        stop_index = Rice::detail::From_Ruby<int64_t>().convert(end.value());
       }
 
       Object exclude_end = obj.call("exclude_end?");
@@ -49,11 +81,11 @@ std::vector<TensorIndex> index_vector(Array a) {
 
       indices.push_back(torch::indexing::Slice(start_index, stop_index));
     } else if (obj.is_instance_of(rb_cTensor)) {
-      indices.push_back(from_ruby<Tensor>(obj));
+      indices.push_back(Rice::detail::From_Ruby<Tensor>().convert(obj.value()));
     } else if (obj.is_nil()) {
       indices.push_back(torch::indexing::None);
     } else if (obj == True || obj == False) {
-      indices.push_back(from_ruby<bool>(obj));
+      indices.push_back(Rice::detail::From_Ruby<bool>().convert(obj.value()));
     } else {
       throw Exception(rb_eArgError, "Unsupported index type: %s", rb_obj_classname(obj));
     }
@@ -68,7 +100,7 @@ std::vector<TensorIndex> index_vector(Array a) {
 static VALUE tensor__backward(int argc, VALUE* argv, VALUE self_)
 {
   HANDLE_TH_ERRORS
-  Tensor& self = from_ruby<Tensor&>(self_);
+  Tensor& self = Rice::detail::From_Ruby<Tensor&>().convert(self_);
   static RubyArgParser parser({
     "_backward(Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False)"
   });
@@ -84,8 +116,8 @@ static VALUE tensor__backward(int argc, VALUE* argv, VALUE self_)
   END_HANDLE_TH_ERRORS
 }
 
-void init_tensor(Rice::Module& m) {
-  rb_cTensor = Rice::define_class_under<torch::Tensor>(m, "Tensor");
+void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions) {
+  rb_cTensor = c;
   rb_cTensor.add_handler<torch::Error>(handle_error);
   add_tensor_functions(rb_cTensor);
   THPVariableClass = rb_cTensor.value();
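`init_tensor` no longer defines `Torch::Tensor` and `Torch::TensorOptions` itself; it now receives the class handles as parameters (note the `define_constructor` call that disappears from the TensorOptions block below). The wiring presumably moves into `ext.cpp` (+22 -5 in the file list, not shown on this page); a hypothetical sketch of what that entry point could look like:

// Hypothetical ext.cpp wiring, consistent with the new signature; the
// real file is not included in the hunks shown here. Assumes the
// init_* declarations are visible.
extern "C" void Init_ext()
{
  auto m = Rice::define_module("Torch");
  Rice::Class rb_cTensor =
      Rice::define_class_under<torch::Tensor>(m, "Tensor");
  Rice::Class rb_cTensorOptions =
      Rice::define_class_under<torch::TensorOptions>(m, "TensorOptions")
          .define_constructor(Rice::Constructor<torch::TensorOptions>());
  init_tensor(m, rb_cTensor, rb_cTensorOptions);
}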
@@ -102,18 +134,18 @@ void init_tensor(Rice::Module& m) {
     .define_method("requires_grad", &torch::Tensor::requires_grad)
     .define_method(
       "_size",
-      *[](Tensor& self, int64_t dim) {
+      [](Tensor& self, int64_t dim) {
        return self.size(dim);
      })
     .define_method(
       "_stride",
-      *[](Tensor& self, int64_t dim) {
+      [](Tensor& self, int64_t dim) {
        return self.stride(dim);
      })
     // in C++ for performance
     .define_method(
       "shape",
-      *[](Tensor& self) {
+      [](Tensor& self) {
        Array a;
        for (auto &size : self.sizes()) {
          a.push(size);
@@ -122,7 +154,7 @@ void init_tensor(Rice::Module& m) {
      })
     .define_method(
       "_strides",
-      *[](Tensor& self) {
+      [](Tensor& self) {
        Array a;
        for (auto &stride : self.strides()) {
          a.push(stride);
@@ -131,65 +163,65 @@ void init_tensor(Rice::Module& m) {
      })
     .define_method(
       "_index",
-      *[](Tensor& self, Array indices) {
+      [](Tensor& self, Array indices) {
        auto vec = index_vector(indices);
        return self.index(vec);
      })
     .define_method(
       "_index_put_custom",
-      *[](Tensor& self, Array indices, torch::Tensor& value) {
+      [](Tensor& self, Array indices, torch::Tensor& value) {
        auto vec = index_vector(indices);
        return self.index_put_(vec, value);
      })
     .define_method(
       "contiguous?",
-      *[](Tensor& self) {
+      [](Tensor& self) {
        return self.is_contiguous();
      })
     .define_method(
       "_requires_grad!",
-      *[](Tensor& self, bool requires_grad) {
+      [](Tensor& self, bool requires_grad) {
        return self.set_requires_grad(requires_grad);
      })
     .define_method(
       "grad",
-      *[](Tensor& self) {
+      [](Tensor& self) {
        auto grad = self.grad();
-        return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
+        return grad.defined() ? Object(Rice::detail::To_Ruby<torch::Tensor>().convert(grad)) : Nil;
      })
     .define_method(
       "grad=",
-      *[](Tensor& self, torch::Tensor& grad) {
+      [](Tensor& self, torch::Tensor& grad) {
        self.mutable_grad() = grad;
      })
     .define_method(
       "_dtype",
-      *[](Tensor& self) {
+      [](Tensor& self) {
        return (int) at::typeMetaToScalarType(self.dtype());
      })
     .define_method(
       "_type",
-      *[](Tensor& self, int dtype) {
+      [](Tensor& self, int dtype) {
        return self.toType((torch::ScalarType) dtype);
      })
     .define_method(
       "_layout",
-      *[](Tensor& self) {
+      [](Tensor& self) {
        std::stringstream s;
        s << self.layout();
        return s.str();
      })
     .define_method(
       "device",
-      *[](Tensor& self) {
+      [](Tensor& self) {
        std::stringstream s;
        s << self.device();
        return s.str();
      })
     .define_method(
       "_data_str",
-      *[](Tensor& self) {
-        Tensor tensor = self;
+      [](Tensor& self) {
+        auto tensor = self;
 
        // move to CPU to get data
        if (tensor.device().type() != torch::kCPU) {
@@ -207,14 +239,14 @@ void init_tensor(Rice::Module& m) {
     // for TorchVision
     .define_method(
       "_data_ptr",
-      *[](Tensor& self) {
+      [](Tensor& self) {
        return reinterpret_cast<uintptr_t>(self.data_ptr());
      })
     // TODO figure out a better way to do this
     .define_method(
       "_flat_data",
-      *[](Tensor& self) {
-        Tensor tensor = self;
+      [](Tensor& self) {
+        auto tensor = self;
 
        // move to CPU to get data
        if (tensor.device().type() != torch::kCPU) {
@@ -222,66 +254,47 @@ void init_tensor(Rice::Module& m) {
          tensor = tensor.to(device);
        }
 
-        Array a;
        auto dtype = tensor.dtype();
-
-        Tensor view = tensor.reshape({tensor.numel()});
-
-        // TODO DRY if someone knows C++
        if (dtype == torch::kByte) {
-          for (int i = 0; i < tensor.numel(); i++) {
-            a.push(view[i].item().to<uint8_t>());
-          }
+          return flat_data<uint8_t>(tensor);
        } else if (dtype == torch::kChar) {
-          for (int i = 0; i < tensor.numel(); i++) {
-            a.push(to_ruby<int>(view[i].item().to<int8_t>()));
-          }
+          return flat_data<int8_t>(tensor);
        } else if (dtype == torch::kShort) {
-          for (int i = 0; i < tensor.numel(); i++) {
-            a.push(view[i].item().to<int16_t>());
-          }
+          return flat_data<int16_t>(tensor);
        } else if (dtype == torch::kInt) {
-          for (int i = 0; i < tensor.numel(); i++) {
-            a.push(view[i].item().to<int32_t>());
-          }
+          return flat_data<int32_t>(tensor);
        } else if (dtype == torch::kLong) {
-          for (int i = 0; i < tensor.numel(); i++) {
-            a.push(view[i].item().to<int64_t>());
-          }
+          return flat_data<int64_t>(tensor);
        } else if (dtype == torch::kFloat) {
-          for (int i = 0; i < tensor.numel(); i++) {
-            a.push(view[i].item().to<float>());
-          }
+          return flat_data<float>(tensor);
        } else if (dtype == torch::kDouble) {
-          for (int i = 0; i < tensor.numel(); i++) {
-            a.push(view[i].item().to<double>());
-          }
+          return flat_data<double>(tensor);
        } else if (dtype == torch::kBool) {
-          for (int i = 0; i < tensor.numel(); i++) {
-            a.push(view[i].item().to<bool>() ? True : False);
-          }
+          return flat_data<bool>(tensor);
+        } else if (dtype == torch::kComplexFloat) {
+          return flat_data<c10::complex<float>>(tensor);
+        } else if (dtype == torch::kComplexDouble) {
+          return flat_data<c10::complex<double>>(tensor);
        } else {
          throw std::runtime_error("Unsupported type");
        }
-        return a;
      })
     .define_method(
       "_to",
-      *[](Tensor& self, torch::Device device, int dtype, bool non_blocking, bool copy) {
+      [](Tensor& self, torch::Device device, int dtype, bool non_blocking, bool copy) {
        return self.to(device, (torch::ScalarType) dtype, non_blocking, copy);
      });
 
-  Rice::define_class_under<torch::TensorOptions>(m, "TensorOptions")
+  rb_cTensorOptions
     .add_handler<torch::Error>(handle_error)
-    .define_constructor(Rice::Constructor<torch::TensorOptions>())
     .define_method(
       "dtype",
-      *[](torch::TensorOptions& self, int dtype) {
+      [](torch::TensorOptions& self, int dtype) {
        return self.dtype((torch::ScalarType) dtype);
      })
     .define_method(
       "layout",
-      *[](torch::TensorOptions& self, const std::string& layout) {
+      [](torch::TensorOptions& self, const std::string& layout) {
        torch::Layout l;
        if (layout == "strided") {
          l = torch::kStrided;
@@ -295,13 +308,13 @@ void init_tensor(Rice::Module& m) {
      })
     .define_method(
       "device",
-      *[](torch::TensorOptions& self, const std::string& device) {
+      [](torch::TensorOptions& self, const std::string& device) {
        torch::Device d(device);
        return self.device(d);
      })
     .define_method(
       "requires_grad",
-      *[](torch::TensorOptions& self, bool requires_grad) {
+      [](torch::TensorOptions& self, bool requires_grad) {
        return self.requires_grad(requires_grad);
      });
 }