torch-rb 0.6.0 → 0.8.2

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +21 -0
  3. data/README.md +23 -41
  4. data/codegen/function.rb +2 -0
  5. data/codegen/generate_functions.rb +43 -6
  6. data/codegen/native_functions.yaml +2007 -1327
  7. data/ext/torch/backends.cpp +17 -0
  8. data/ext/torch/cuda.cpp +5 -5
  9. data/ext/torch/device.cpp +13 -6
  10. data/ext/torch/ext.cpp +22 -5
  11. data/ext/torch/extconf.rb +1 -3
  12. data/ext/torch/fft.cpp +13 -0
  13. data/ext/torch/fft_functions.h +6 -0
  14. data/ext/torch/ivalue.cpp +31 -33
  15. data/ext/torch/linalg.cpp +13 -0
  16. data/ext/torch/linalg_functions.h +6 -0
  17. data/ext/torch/nn.cpp +34 -34
  18. data/ext/torch/random.cpp +5 -5
  19. data/ext/torch/ruby_arg_parser.cpp +2 -2
  20. data/ext/torch/ruby_arg_parser.h +23 -12
  21. data/ext/torch/special.cpp +13 -0
  22. data/ext/torch/special_functions.h +6 -0
  23. data/ext/torch/templates.h +111 -133
  24. data/ext/torch/tensor.cpp +80 -67
  25. data/ext/torch/torch.cpp +30 -21
  26. data/ext/torch/utils.h +3 -4
  27. data/ext/torch/wrap_outputs.h +72 -65
  28. data/lib/torch/inspector.rb +5 -2
  29. data/lib/torch/nn/convnd.rb +2 -0
  30. data/lib/torch/nn/functional_attention.rb +241 -0
  31. data/lib/torch/nn/module.rb +2 -0
  32. data/lib/torch/nn/module_list.rb +49 -0
  33. data/lib/torch/nn/multihead_attention.rb +123 -0
  34. data/lib/torch/nn/transformer.rb +92 -0
  35. data/lib/torch/nn/transformer_decoder.rb +25 -0
  36. data/lib/torch/nn/transformer_decoder_layer.rb +43 -0
  37. data/lib/torch/nn/transformer_encoder.rb +25 -0
  38. data/lib/torch/nn/transformer_encoder_layer.rb +36 -0
  39. data/lib/torch/nn/utils.rb +16 -0
  40. data/lib/torch/tensor.rb +2 -0
  41. data/lib/torch/utils/data/data_loader.rb +2 -0
  42. data/lib/torch/version.rb +1 -1
  43. data/lib/torch.rb +11 -0
  44. metadata +20 -5
@@ -4,8 +4,7 @@
4
4
  #undef isfinite
5
5
  #endif
6
6
 
7
- #include <rice/Array.hpp>
8
- #include <rice/Object.hpp>
7
+ #include <rice/rice.hpp>
9
8
 
10
9
  using namespace Rice;
11
10
 
@@ -22,6 +21,10 @@ using torch::IntArrayRef;
22
21
  using torch::ArrayRef;
23
22
  using torch::TensorList;
24
23
  using torch::Storage;
24
+ using ScalarList = ArrayRef<Scalar>;
25
+
26
+ using torch::nn::init::FanModeType;
27
+ using torch::nn::init::NonlinearityType;
25
28
 
26
29
  #define HANDLE_TH_ERRORS \
27
30
  try {
@@ -38,37 +41,42 @@ using torch::Storage;
38
41
  #define RETURN_NIL \
39
42
  return Qnil;
40
43
 
41
- template<>
42
- inline
43
- std::vector<int64_t> from_ruby<std::vector<int64_t>>(Object x)
44
- {
45
- Array a = Array(x);
46
- std::vector<int64_t> vec(a.size());
47
- for (long i = 0; i < a.size(); i++) {
48
- vec[i] = from_ruby<int64_t>(a[i]);
49
- }
50
- return vec;
51
- }
44
+ class OptionalTensor {
45
+ torch::Tensor value;
46
+ public:
47
+ OptionalTensor(Object o) {
48
+ if (o.is_nil()) {
49
+ value = {};
50
+ } else {
51
+ value = Rice::detail::From_Ruby<torch::Tensor>().convert(o.value());
52
+ }
53
+ }
54
+ OptionalTensor(torch::Tensor o) {
55
+ value = o;
56
+ }
57
+ operator torch::Tensor() const {
58
+ return value;
59
+ }
60
+ };
52
61
 
53
- template<>
54
- inline
55
- std::vector<Tensor> from_ruby<std::vector<Tensor>>(Object x)
62
+ namespace Rice::detail
56
63
  {
57
- Array a = Array(x);
58
- std::vector<Tensor> vec(a.size());
59
- for (long i = 0; i < a.size(); i++) {
60
- vec[i] = from_ruby<Tensor>(a[i]);
61
- }
62
- return vec;
63
- }
64
+ template<>
65
+ struct Type<FanModeType>
66
+ {
67
+ static bool verify()
68
+ {
69
+ return true;
70
+ }
71
+ };
64
72
 
65
- class FanModeType {
66
- std::string s;
73
+ template<>
74
+ class From_Ruby<FanModeType>
75
+ {
67
76
  public:
68
- FanModeType(Object o) {
69
- s = String(o).str();
70
- }
71
- operator torch::nn::init::FanModeType() {
77
+ FanModeType convert(VALUE x)
78
+ {
79
+ auto s = String(x).str();
72
80
  if (s == "fan_in") {
73
81
  return torch::kFanIn;
74
82
  } else if (s == "fan_out") {
@@ -77,22 +85,24 @@ class FanModeType {
77
85
  throw std::runtime_error("Unsupported nonlinearity type: " + s);
78
86
  }
79
87
  }
80
- };
81
-
82
- template<>
83
- inline
84
- FanModeType from_ruby<FanModeType>(Object x)
85
- {
86
- return FanModeType(x);
87
- }
88
+ };
89
+
90
+ template<>
91
+ struct Type<NonlinearityType>
92
+ {
93
+ static bool verify()
94
+ {
95
+ return true;
96
+ }
97
+ };
88
98
 
89
- class NonlinearityType {
90
- std::string s;
99
+ template<>
100
+ class From_Ruby<NonlinearityType>
101
+ {
91
102
  public:
92
- NonlinearityType(Object o) {
93
- s = String(o).str();
94
- }
95
- operator torch::nn::init::NonlinearityType() {
103
+ NonlinearityType convert(VALUE x)
104
+ {
105
+ auto s = String(x).str();
96
106
  if (s == "linear") {
97
107
  return torch::kLinear;
98
108
  } else if (s == "conv1d") {
@@ -119,102 +129,70 @@ class NonlinearityType {
119
129
  throw std::runtime_error("Unsupported nonlinearity type: " + s);
120
130
  }
121
131
  }
122
- };
132
+ };
133
+
134
+ template<>
135
+ struct Type<OptionalTensor>
136
+ {
137
+ static bool verify()
138
+ {
139
+ return true;
140
+ }
141
+ };
123
142
 
124
- template<>
125
- inline
126
- NonlinearityType from_ruby<NonlinearityType>(Object x)
127
- {
128
- return NonlinearityType(x);
129
- }
143
+ template<>
144
+ class From_Ruby<OptionalTensor>
145
+ {
146
+ public:
147
+ OptionalTensor convert(VALUE x)
148
+ {
149
+ return OptionalTensor(x);
150
+ }
151
+ };
152
+
153
+ template<>
154
+ struct Type<Scalar>
155
+ {
156
+ static bool verify()
157
+ {
158
+ return true;
159
+ }
160
+ };
130
161
 
131
- class OptionalTensor {
132
- torch::Tensor value;
162
+ template<>
163
+ class From_Ruby<Scalar>
164
+ {
133
165
  public:
134
- OptionalTensor(Object o) {
135
- if (o.is_nil()) {
136
- value = {};
166
+ Scalar convert(VALUE x)
167
+ {
168
+ if (FIXNUM_P(x)) {
169
+ return torch::Scalar(From_Ruby<int64_t>().convert(x));
137
170
  } else {
138
- value = from_ruby<torch::Tensor>(o);
171
+ return torch::Scalar(From_Ruby<double>().convert(x));
139
172
  }
140
173
  }
141
- OptionalTensor(torch::Tensor o) {
142
- value = o;
143
- }
144
- operator torch::Tensor() const {
145
- return value;
174
+ };
175
+
176
+ template<typename T>
177
+ struct Type<torch::optional<T>>
178
+ {
179
+ static bool verify()
180
+ {
181
+ return true;
146
182
  }
147
- };
148
-
149
- template<>
150
- inline
151
- Scalar from_ruby<Scalar>(Object x)
152
- {
153
- if (x.rb_type() == T_FIXNUM) {
154
- return torch::Scalar(from_ruby<int64_t>(x));
155
- } else {
156
- return torch::Scalar(from_ruby<double>(x));
157
- }
158
- }
159
-
160
- template<>
161
- inline
162
- OptionalTensor from_ruby<OptionalTensor>(Object x)
163
- {
164
- return OptionalTensor(x);
165
- }
166
-
167
- template<>
168
- inline
169
- torch::optional<torch::ScalarType> from_ruby<torch::optional<torch::ScalarType>>(Object x)
170
- {
171
- if (x.is_nil()) {
172
- return torch::nullopt;
173
- } else {
174
- return torch::optional<torch::ScalarType>{from_ruby<torch::ScalarType>(x)};
175
- }
176
- }
177
-
178
- template<>
179
- inline
180
- torch::optional<int64_t> from_ruby<torch::optional<int64_t>>(Object x)
181
- {
182
- if (x.is_nil()) {
183
- return torch::nullopt;
184
- } else {
185
- return torch::optional<int64_t>{from_ruby<int64_t>(x)};
186
- }
187
- }
183
+ };
188
184
 
189
- template<>
190
- inline
191
- torch::optional<double> from_ruby<torch::optional<double>>(Object x)
192
- {
193
- if (x.is_nil()) {
194
- return torch::nullopt;
195
- } else {
196
- return torch::optional<double>{from_ruby<double>(x)};
197
- }
198
- }
199
-
200
- template<>
201
- inline
202
- torch::optional<bool> from_ruby<torch::optional<bool>>(Object x)
203
- {
204
- if (x.is_nil()) {
205
- return torch::nullopt;
206
- } else {
207
- return torch::optional<bool>{from_ruby<bool>(x)};
208
- }
209
- }
210
-
211
- template<>
212
- inline
213
- torch::optional<Scalar> from_ruby<torch::optional<Scalar>>(Object x)
214
- {
215
- if (x.is_nil()) {
216
- return torch::nullopt;
217
- } else {
218
- return torch::optional<Scalar>{from_ruby<Scalar>(x)};
219
- }
185
+ template<typename T>
186
+ class From_Ruby<torch::optional<T>>
187
+ {
188
+ public:
189
+ torch::optional<T> convert(VALUE x)
190
+ {
191
+ if (NIL_P(x)) {
192
+ return torch::nullopt;
193
+ } else {
194
+ return torch::optional<T>{From_Ruby<T>().convert(x)};
195
+ }
196
+ }
197
+ };
220
198
  }
data/ext/torch/tensor.cpp CHANGED
@@ -1,7 +1,6 @@
1
1
  #include <torch/torch.h>
2
2
 
3
- #include <rice/Constructor.hpp>
4
- #include <rice/Module.hpp>
3
+ #include <rice/rice.hpp>
5
4
 
6
5
  #include "tensor_functions.h"
7
6
  #include "ruby_arg_parser.h"
@@ -11,6 +10,39 @@
11
10
  using namespace Rice;
12
11
  using torch::indexing::TensorIndex;
13
12
 
13
+ namespace Rice::detail
14
+ {
15
+ template<typename T>
16
+ struct Type<c10::complex<T>>
17
+ {
18
+ static bool verify()
19
+ {
20
+ return true;
21
+ }
22
+ };
23
+
24
+ template<typename T>
25
+ class To_Ruby<c10::complex<T>>
26
+ {
27
+ public:
28
+ VALUE convert(c10::complex<T> const& x)
29
+ {
30
+ return rb_dbl_complex_new(x.real(), x.imag());
31
+ }
32
+ };
33
+ }
34
+
35
+ template<typename T>
36
+ Array flat_data(Tensor& tensor) {
37
+ Tensor view = tensor.reshape({tensor.numel()});
38
+
39
+ Array a;
40
+ for (int i = 0; i < tensor.numel(); i++) {
41
+ a.push(view[i].item().to<T>());
42
+ }
43
+ return a;
44
+ }
45
+
14
46
  Class rb_cTensor;
15
47
 
16
48
  std::vector<TensorIndex> index_vector(Array a) {
@@ -23,19 +55,19 @@ std::vector<TensorIndex> index_vector(Array a) {
23
55
  obj = a[i];
24
56
 
25
57
  if (obj.is_instance_of(rb_cInteger)) {
26
- indices.push_back(from_ruby<int64_t>(obj));
58
+ indices.push_back(Rice::detail::From_Ruby<int64_t>().convert(obj.value()));
27
59
  } else if (obj.is_instance_of(rb_cRange)) {
28
60
  torch::optional<int64_t> start_index = torch::nullopt;
29
61
  torch::optional<int64_t> stop_index = torch::nullopt;
30
62
 
31
63
  Object begin = obj.call("begin");
32
64
  if (!begin.is_nil()) {
33
- start_index = from_ruby<int64_t>(begin);
65
+ start_index = Rice::detail::From_Ruby<int64_t>().convert(begin.value());
34
66
  }
35
67
 
36
68
  Object end = obj.call("end");
37
69
  if (!end.is_nil()) {
38
- stop_index = from_ruby<int64_t>(end);
70
+ stop_index = Rice::detail::From_Ruby<int64_t>().convert(end.value());
39
71
  }
40
72
 
41
73
  Object exclude_end = obj.call("exclude_end?");
@@ -49,11 +81,11 @@ std::vector<TensorIndex> index_vector(Array a) {
49
81
 
50
82
  indices.push_back(torch::indexing::Slice(start_index, stop_index));
51
83
  } else if (obj.is_instance_of(rb_cTensor)) {
52
- indices.push_back(from_ruby<Tensor>(obj));
84
+ indices.push_back(Rice::detail::From_Ruby<Tensor>().convert(obj.value()));
53
85
  } else if (obj.is_nil()) {
54
86
  indices.push_back(torch::indexing::None);
55
87
  } else if (obj == True || obj == False) {
56
- indices.push_back(from_ruby<bool>(obj));
88
+ indices.push_back(Rice::detail::From_Ruby<bool>().convert(obj.value()));
57
89
  } else {
58
90
  throw Exception(rb_eArgError, "Unsupported index type: %s", rb_obj_classname(obj));
59
91
  }
@@ -68,7 +100,7 @@ std::vector<TensorIndex> index_vector(Array a) {
68
100
  static VALUE tensor__backward(int argc, VALUE* argv, VALUE self_)
69
101
  {
70
102
  HANDLE_TH_ERRORS
71
- Tensor& self = from_ruby<Tensor&>(self_);
103
+ Tensor& self = Rice::detail::From_Ruby<Tensor&>().convert(self_);
72
104
  static RubyArgParser parser({
73
105
  "_backward(Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False)"
74
106
  });
@@ -84,8 +116,8 @@ static VALUE tensor__backward(int argc, VALUE* argv, VALUE self_)
84
116
  END_HANDLE_TH_ERRORS
85
117
  }
86
118
 
87
- void init_tensor(Rice::Module& m) {
88
- rb_cTensor = Rice::define_class_under<torch::Tensor>(m, "Tensor");
119
+ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions) {
120
+ rb_cTensor = c;
89
121
  rb_cTensor.add_handler<torch::Error>(handle_error);
90
122
  add_tensor_functions(rb_cTensor);
91
123
  THPVariableClass = rb_cTensor.value();
@@ -102,18 +134,18 @@ void init_tensor(Rice::Module& m) {
102
134
  .define_method("requires_grad", &torch::Tensor::requires_grad)
103
135
  .define_method(
104
136
  "_size",
105
- *[](Tensor& self, int64_t dim) {
137
+ [](Tensor& self, int64_t dim) {
106
138
  return self.size(dim);
107
139
  })
108
140
  .define_method(
109
141
  "_stride",
110
- *[](Tensor& self, int64_t dim) {
142
+ [](Tensor& self, int64_t dim) {
111
143
  return self.stride(dim);
112
144
  })
113
145
  // in C++ for performance
114
146
  .define_method(
115
147
  "shape",
116
- *[](Tensor& self) {
148
+ [](Tensor& self) {
117
149
  Array a;
118
150
  for (auto &size : self.sizes()) {
119
151
  a.push(size);
@@ -122,7 +154,7 @@ void init_tensor(Rice::Module& m) {
122
154
  })
123
155
  .define_method(
124
156
  "_strides",
125
- *[](Tensor& self) {
157
+ [](Tensor& self) {
126
158
  Array a;
127
159
  for (auto &stride : self.strides()) {
128
160
  a.push(stride);
@@ -131,65 +163,65 @@ void init_tensor(Rice::Module& m) {
131
163
  })
132
164
  .define_method(
133
165
  "_index",
134
- *[](Tensor& self, Array indices) {
166
+ [](Tensor& self, Array indices) {
135
167
  auto vec = index_vector(indices);
136
168
  return self.index(vec);
137
169
  })
138
170
  .define_method(
139
171
  "_index_put_custom",
140
- *[](Tensor& self, Array indices, torch::Tensor& value) {
172
+ [](Tensor& self, Array indices, torch::Tensor& value) {
141
173
  auto vec = index_vector(indices);
142
174
  return self.index_put_(vec, value);
143
175
  })
144
176
  .define_method(
145
177
  "contiguous?",
146
- *[](Tensor& self) {
178
+ [](Tensor& self) {
147
179
  return self.is_contiguous();
148
180
  })
149
181
  .define_method(
150
182
  "_requires_grad!",
151
- *[](Tensor& self, bool requires_grad) {
183
+ [](Tensor& self, bool requires_grad) {
152
184
  return self.set_requires_grad(requires_grad);
153
185
  })
154
186
  .define_method(
155
187
  "grad",
156
- *[](Tensor& self) {
188
+ [](Tensor& self) {
157
189
  auto grad = self.grad();
158
- return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
190
+ return grad.defined() ? Object(Rice::detail::To_Ruby<torch::Tensor>().convert(grad)) : Nil;
159
191
  })
160
192
  .define_method(
161
193
  "grad=",
162
- *[](Tensor& self, torch::Tensor& grad) {
194
+ [](Tensor& self, torch::Tensor& grad) {
163
195
  self.mutable_grad() = grad;
164
196
  })
165
197
  .define_method(
166
198
  "_dtype",
167
- *[](Tensor& self) {
199
+ [](Tensor& self) {
168
200
  return (int) at::typeMetaToScalarType(self.dtype());
169
201
  })
170
202
  .define_method(
171
203
  "_type",
172
- *[](Tensor& self, int dtype) {
204
+ [](Tensor& self, int dtype) {
173
205
  return self.toType((torch::ScalarType) dtype);
174
206
  })
175
207
  .define_method(
176
208
  "_layout",
177
- *[](Tensor& self) {
209
+ [](Tensor& self) {
178
210
  std::stringstream s;
179
211
  s << self.layout();
180
212
  return s.str();
181
213
  })
182
214
  .define_method(
183
215
  "device",
184
- *[](Tensor& self) {
216
+ [](Tensor& self) {
185
217
  std::stringstream s;
186
218
  s << self.device();
187
219
  return s.str();
188
220
  })
189
221
  .define_method(
190
222
  "_data_str",
191
- *[](Tensor& self) {
192
- Tensor tensor = self;
223
+ [](Tensor& self) {
224
+ auto tensor = self;
193
225
 
194
226
  // move to CPU to get data
195
227
  if (tensor.device().type() != torch::kCPU) {
@@ -207,14 +239,14 @@ void init_tensor(Rice::Module& m) {
207
239
  // for TorchVision
208
240
  .define_method(
209
241
  "_data_ptr",
210
- *[](Tensor& self) {
242
+ [](Tensor& self) {
211
243
  return reinterpret_cast<uintptr_t>(self.data_ptr());
212
244
  })
213
245
  // TODO figure out a better way to do this
214
246
  .define_method(
215
247
  "_flat_data",
216
- *[](Tensor& self) {
217
- Tensor tensor = self;
248
+ [](Tensor& self) {
249
+ auto tensor = self;
218
250
 
219
251
  // move to CPU to get data
220
252
  if (tensor.device().type() != torch::kCPU) {
@@ -222,66 +254,47 @@ void init_tensor(Rice::Module& m) {
222
254
  tensor = tensor.to(device);
223
255
  }
224
256
 
225
- Array a;
226
257
  auto dtype = tensor.dtype();
227
-
228
- Tensor view = tensor.reshape({tensor.numel()});
229
-
230
- // TODO DRY if someone knows C++
231
258
  if (dtype == torch::kByte) {
232
- for (int i = 0; i < tensor.numel(); i++) {
233
- a.push(view[i].item().to<uint8_t>());
234
- }
259
+ return flat_data<uint8_t>(tensor);
235
260
  } else if (dtype == torch::kChar) {
236
- for (int i = 0; i < tensor.numel(); i++) {
237
- a.push(to_ruby<int>(view[i].item().to<int8_t>()));
238
- }
261
+ return flat_data<int8_t>(tensor);
239
262
  } else if (dtype == torch::kShort) {
240
- for (int i = 0; i < tensor.numel(); i++) {
241
- a.push(view[i].item().to<int16_t>());
242
- }
263
+ return flat_data<int16_t>(tensor);
243
264
  } else if (dtype == torch::kInt) {
244
- for (int i = 0; i < tensor.numel(); i++) {
245
- a.push(view[i].item().to<int32_t>());
246
- }
265
+ return flat_data<int32_t>(tensor);
247
266
  } else if (dtype == torch::kLong) {
248
- for (int i = 0; i < tensor.numel(); i++) {
249
- a.push(view[i].item().to<int64_t>());
250
- }
267
+ return flat_data<int64_t>(tensor);
251
268
  } else if (dtype == torch::kFloat) {
252
- for (int i = 0; i < tensor.numel(); i++) {
253
- a.push(view[i].item().to<float>());
254
- }
269
+ return flat_data<float>(tensor);
255
270
  } else if (dtype == torch::kDouble) {
256
- for (int i = 0; i < tensor.numel(); i++) {
257
- a.push(view[i].item().to<double>());
258
- }
271
+ return flat_data<double>(tensor);
259
272
  } else if (dtype == torch::kBool) {
260
- for (int i = 0; i < tensor.numel(); i++) {
261
- a.push(view[i].item().to<bool>() ? True : False);
262
- }
273
+ return flat_data<bool>(tensor);
274
+ } else if (dtype == torch::kComplexFloat) {
275
+ return flat_data<c10::complex<float>>(tensor);
276
+ } else if (dtype == torch::kComplexDouble) {
277
+ return flat_data<c10::complex<double>>(tensor);
263
278
  } else {
264
279
  throw std::runtime_error("Unsupported type");
265
280
  }
266
- return a;
267
281
  })
268
282
  .define_method(
269
283
  "_to",
270
- *[](Tensor& self, torch::Device device, int dtype, bool non_blocking, bool copy) {
284
+ [](Tensor& self, torch::Device device, int dtype, bool non_blocking, bool copy) {
271
285
  return self.to(device, (torch::ScalarType) dtype, non_blocking, copy);
272
286
  });
273
287
 
274
- Rice::define_class_under<torch::TensorOptions>(m, "TensorOptions")
288
+ rb_cTensorOptions
275
289
  .add_handler<torch::Error>(handle_error)
276
- .define_constructor(Rice::Constructor<torch::TensorOptions>())
277
290
  .define_method(
278
291
  "dtype",
279
- *[](torch::TensorOptions& self, int dtype) {
292
+ [](torch::TensorOptions& self, int dtype) {
280
293
  return self.dtype((torch::ScalarType) dtype);
281
294
  })
282
295
  .define_method(
283
296
  "layout",
284
- *[](torch::TensorOptions& self, const std::string& layout) {
297
+ [](torch::TensorOptions& self, const std::string& layout) {
285
298
  torch::Layout l;
286
299
  if (layout == "strided") {
287
300
  l = torch::kStrided;
@@ -295,13 +308,13 @@ void init_tensor(Rice::Module& m) {
295
308
  })
296
309
  .define_method(
297
310
  "device",
298
- *[](torch::TensorOptions& self, const std::string& device) {
311
+ [](torch::TensorOptions& self, const std::string& device) {
299
312
  torch::Device d(device);
300
313
  return self.device(d);
301
314
  })
302
315
  .define_method(
303
316
  "requires_grad",
304
- *[](torch::TensorOptions& self, bool requires_grad) {
317
+ [](torch::TensorOptions& self, bool requires_grad) {
305
318
  return self.requires_grad(requires_grad);
306
319
  });
307
320
  }