torch-rb 0.6.0 → 0.7.0

The change that dominates this release is a port of the C++ extension from Rice 3 to Rice 4: the per-feature Rice headers collapse into the single <rice/rice.hpp>, the from_ruby<T>/to_ruby<T> function templates give way to From_Ruby<T>/To_Ruby<T> functors in Rice::detail, and capture-less lambdas passed to define_method no longer need to be dereferenced into function pointers. The first file touched is the argument parser:

data/ext/torch/ruby_arg_parser.cpp CHANGED
@@ -137,7 +137,7 @@ auto FunctionParameter::check(VALUE obj, int argnum) -> bool
  return true;
  }
  if (THPVariable_Check(obj)) {
- auto var = from_ruby<torch::Tensor>(obj);
+ auto var = Rice::detail::From_Ruby<torch::Tensor>().convert(obj);
  return !var.requires_grad() && var.dim() == 0;
  }
  return false;
@@ -147,7 +147,7 @@ auto FunctionParameter::check(VALUE obj, int argnum) -> bool
  return true;
  }
  if (THPVariable_Check(obj)) {
- auto var = from_ruby<torch::Tensor>(obj);
+ auto var = Rice::detail::From_Ruby<torch::Tensor>().convert(obj);
  return at::isIntegralType(var.scalar_type(), /*includeBool=*/false) && !var.requires_grad() && var.dim() == 0;
  }
  return false;
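
This is the pattern repeated throughout the release: Rice 4 replaces the from_ruby<T>(Object) function-template specializations with a Rice::detail::From_Ruby<T> functor whose convert() takes a raw VALUE. A minimal sketch of the new style (the helper name is hypothetical; From_Ruby<torch::Tensor> is generated by Rice for any class registered with define_class_under):

    #include <torch/torch.h>
    #include <rice/rice.hpp>

    // Hypothetical helper: unwrap a Ruby Torch::Tensor into the underlying
    // torch::Tensor. Rice 3 spelled this from_ruby<torch::Tensor>(obj).
    torch::Tensor unwrap_tensor(VALUE obj) {
      return Rice::detail::From_Ruby<torch::Tensor>().convert(obj);
    }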
data/ext/torch/ruby_arg_parser.h CHANGED
@@ -5,7 +5,7 @@
  #include <sstream>

  #include <torch/torch.h>
- #include <rice/Exception.hpp>
+ #include <rice/rice.hpp>

  #include "templates.h"
  #include "utils.h"
@@ -121,7 +121,7 @@ struct RubyArgs {
  };

  inline at::Tensor RubyArgs::tensor(int i) {
- return from_ruby<torch::Tensor>(args[i]);
+ return Rice::detail::From_Ruby<torch::Tensor>().convert(args[i]);
  }

  inline OptionalTensor RubyArgs::optionalTensor(int i) {
@@ -131,12 +131,12 @@ inline OptionalTensor RubyArgs::optionalTensor(int i) {

  inline at::Scalar RubyArgs::scalar(int i) {
  if (NIL_P(args[i])) return signature.params[i].default_scalar;
- return from_ruby<torch::Scalar>(args[i]);
+ return Rice::detail::From_Ruby<torch::Scalar>().convert(args[i]);
  }

  inline std::vector<at::Tensor> RubyArgs::tensorlist(int i) {
  if (NIL_P(args[i])) return std::vector<at::Tensor>();
- return from_ruby<std::vector<Tensor>>(args[i]);
+ return Rice::detail::From_Ruby<std::vector<Tensor>>().convert(args[i]);
  }

  template<int N>
@@ -151,7 +151,7 @@ inline std::array<at::Tensor, N> RubyArgs::tensorlist_n(int i) {
  }
  for (int idx = 0; idx < size; idx++) {
  VALUE obj = rb_ary_entry(arg, idx);
- res[idx] = from_ruby<Tensor>(obj);
+ res[idx] = Rice::detail::From_Ruby<Tensor>().convert(obj);
  }
  return res;
  }
@@ -170,7 +170,7 @@ inline std::vector<int64_t> RubyArgs::intlist(int i) {
  for (idx = 0; idx < size; idx++) {
  VALUE obj = rb_ary_entry(arg, idx);
  if (FIXNUM_P(obj)) {
- res[idx] = from_ruby<int64_t>(obj);
+ res[idx] = Rice::detail::From_Ruby<int64_t>().convert(obj);
  } else {
  rb_raise(rb_eArgError, "%s(): argument '%s' must be %s, but found element of type %s at pos %d",
  signature.name.c_str(), signature.params[i].name.c_str(),
@@ -210,8 +210,13 @@ inline ScalarType RubyArgs::scalartype(int i) {
  {ID2SYM(rb_intern("double")), ScalarType::Double},
  {ID2SYM(rb_intern("float64")), ScalarType::Double},
  {ID2SYM(rb_intern("complex_half")), ScalarType::ComplexHalf},
+ {ID2SYM(rb_intern("complex32")), ScalarType::ComplexHalf},
  {ID2SYM(rb_intern("complex_float")), ScalarType::ComplexFloat},
+ {ID2SYM(rb_intern("cfloat")), ScalarType::ComplexFloat},
+ {ID2SYM(rb_intern("complex64")), ScalarType::ComplexFloat},
  {ID2SYM(rb_intern("complex_double")), ScalarType::ComplexDouble},
+ {ID2SYM(rb_intern("cdouble")), ScalarType::ComplexDouble},
+ {ID2SYM(rb_intern("complex128")), ScalarType::ComplexDouble},
  {ID2SYM(rb_intern("bool")), ScalarType::Bool},
  {ID2SYM(rb_intern("qint8")), ScalarType::QInt8},
  {ID2SYM(rb_intern("quint8")), ScalarType::QUInt8},
@@ -260,7 +265,7 @@ inline c10::OptionalArray<double> RubyArgs::doublelistOptional(int i) {
  for (idx = 0; idx < size; idx++) {
  VALUE obj = rb_ary_entry(arg, idx);
  if (FIXNUM_P(obj) || RB_FLOAT_TYPE_P(obj)) {
- res[idx] = from_ruby<double>(obj);
+ res[idx] = Rice::detail::From_Ruby<double>().convert(obj);
  } else {
  rb_raise(rb_eArgError, "%s(): argument '%s' must be %s, but found element of type %s at pos %d",
  signature.name.c_str(), signature.params[i].name.c_str(),
@@ -303,22 +308,22 @@ inline c10::optional<at::MemoryFormat> RubyArgs::memoryformatOptional(int i) {
  }

  inline std::string RubyArgs::string(int i) {
- return from_ruby<std::string>(args[i]);
+ return Rice::detail::From_Ruby<std::string>().convert(args[i]);
  }

  inline c10::optional<std::string> RubyArgs::stringOptional(int i) {
  if (!args[i]) return c10::nullopt;
- return from_ruby<std::string>(args[i]);
+ return Rice::detail::From_Ruby<std::string>().convert(args[i]);
  }

  inline int64_t RubyArgs::toInt64(int i) {
  if (NIL_P(args[i])) return signature.params[i].default_int;
- return from_ruby<int64_t>(args[i]);
+ return Rice::detail::From_Ruby<int64_t>().convert(args[i]);
  }

  inline double RubyArgs::toDouble(int i) {
  if (NIL_P(args[i])) return signature.params[i].default_double;
- return from_ruby<double>(args[i]);
+ return Rice::detail::From_Ruby<double>().convert(args[i]);
  }

  inline bool RubyArgs::toBool(int i) {
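
All of the RubyArgs accessors now follow the same two-step shape: return the default recorded in the parsed signature when the argument is nil, otherwise convert the raw VALUE through From_Ruby. A hypothetical caller, sketching how the generated bindings consume these typed getters (the function and signature are illustrative, e.g. for add(Tensor other, Scalar alpha=1)):

    // Hypothetical use of the accessors above: a nil alpha falls back to the
    // signature default inside r.scalar() before conversion.
    VALUE add_sketch(torch::Tensor& self, RubyArgs& r) {
      return Rice::detail::To_Ruby<torch::Tensor>().convert(
        self.add(r.tensor(0), r.scalar(1)));
    }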
data/ext/torch/templates.h CHANGED
@@ -4,8 +4,7 @@
  #undef isfinite
  #endif

- #include <rice/Array.hpp>
- #include <rice/Object.hpp>
+ #include <rice/rice.hpp>

  using namespace Rice;

@@ -23,6 +22,9 @@ using torch::ArrayRef;
  using torch::TensorList;
  using torch::Storage;

+ using torch::nn::init::FanModeType;
+ using torch::nn::init::NonlinearityType;
+
  #define HANDLE_TH_ERRORS \
  try {

@@ -38,37 +40,42 @@ using torch::Storage;
  #define RETURN_NIL \
  return Qnil;

- template<>
- inline
- std::vector<int64_t> from_ruby<std::vector<int64_t>>(Object x)
- {
- Array a = Array(x);
- std::vector<int64_t> vec(a.size());
- for (long i = 0; i < a.size(); i++) {
- vec[i] = from_ruby<int64_t>(a[i]);
- }
- return vec;
- }
+ class OptionalTensor {
+ torch::Tensor value;
+ public:
+ OptionalTensor(Object o) {
+ if (o.is_nil()) {
+ value = {};
+ } else {
+ value = Rice::detail::From_Ruby<torch::Tensor>().convert(o.value());
+ }
+ }
+ OptionalTensor(torch::Tensor o) {
+ value = o;
+ }
+ operator torch::Tensor() const {
+ return value;
+ }
+ };

- template<>
- inline
- std::vector<Tensor> from_ruby<std::vector<Tensor>>(Object x)
+ namespace Rice::detail
  {
- Array a = Array(x);
- std::vector<Tensor> vec(a.size());
- for (long i = 0; i < a.size(); i++) {
- vec[i] = from_ruby<Tensor>(a[i]);
- }
- return vec;
- }
+ template<>
+ struct Type<FanModeType>
+ {
+ static bool verify()
+ {
+ return true;
+ }
+ };

- class FanModeType {
- std::string s;
+ template<>
+ class From_Ruby<FanModeType>
+ {
  public:
- FanModeType(Object o) {
- s = String(o).str();
- }
- operator torch::nn::init::FanModeType() {
+ FanModeType convert(VALUE x)
+ {
+ auto s = String(x).str();
  if (s == "fan_in") {
  return torch::kFanIn;
  } else if (s == "fan_out") {
@@ -77,22 +84,24 @@ class FanModeType {
  throw std::runtime_error("Unsupported nonlinearity type: " + s);
  }
  }
- };
-
- template<>
- inline
- FanModeType from_ruby<FanModeType>(Object x)
- {
- return FanModeType(x);
- }
+ };
+
+ template<>
+ struct Type<NonlinearityType>
+ {
+ static bool verify()
+ {
+ return true;
+ }
+ };

- class NonlinearityType {
- std::string s;
+ template<>
+ class From_Ruby<NonlinearityType>
+ {
  public:
- NonlinearityType(Object o) {
- s = String(o).str();
- }
- operator torch::nn::init::NonlinearityType() {
+ NonlinearityType convert(VALUE x)
+ {
+ auto s = String(x).str();
  if (s == "linear") {
  return torch::kLinear;
  } else if (s == "conv1d") {
@@ -119,102 +128,70 @@ class NonlinearityType {
  throw std::runtime_error("Unsupported nonlinearity type: " + s);
  }
  }
- };
+ };
+
+ template<>
+ struct Type<OptionalTensor>
+ {
+ static bool verify()
+ {
+ return true;
+ }
+ };

- template<>
- inline
- NonlinearityType from_ruby<NonlinearityType>(Object x)
- {
- return NonlinearityType(x);
- }
+ template<>
+ class From_Ruby<OptionalTensor>
+ {
+ public:
+ OptionalTensor convert(VALUE x)
+ {
+ return OptionalTensor(x);
+ }
+ };
+
+ template<>
+ struct Type<Scalar>
+ {
+ static bool verify()
+ {
+ return true;
+ }
+ };

- class OptionalTensor {
- torch::Tensor value;
+ template<>
+ class From_Ruby<Scalar>
+ {
  public:
- OptionalTensor(Object o) {
- if (o.is_nil()) {
- value = {};
+ Scalar convert(VALUE x)
+ {
+ if (FIXNUM_P(x)) {
+ return torch::Scalar(From_Ruby<int64_t>().convert(x));
  } else {
- value = from_ruby<torch::Tensor>(o);
+ return torch::Scalar(From_Ruby<double>().convert(x));
  }
  }
- OptionalTensor(torch::Tensor o) {
- value = o;
- }
- operator torch::Tensor() const {
- return value;
+ };
+
+ template<typename T>
+ struct Type<torch::optional<T>>
+ {
+ static bool verify()
+ {
+ return true;
  }
- };
-
- template<>
- inline
- Scalar from_ruby<Scalar>(Object x)
- {
- if (x.rb_type() == T_FIXNUM) {
- return torch::Scalar(from_ruby<int64_t>(x));
- } else {
- return torch::Scalar(from_ruby<double>(x));
- }
- }
-
- template<>
- inline
- OptionalTensor from_ruby<OptionalTensor>(Object x)
- {
- return OptionalTensor(x);
- }
-
- template<>
- inline
- torch::optional<torch::ScalarType> from_ruby<torch::optional<torch::ScalarType>>(Object x)
- {
- if (x.is_nil()) {
- return torch::nullopt;
- } else {
- return torch::optional<torch::ScalarType>{from_ruby<torch::ScalarType>(x)};
- }
- }
-
- template<>
- inline
- torch::optional<int64_t> from_ruby<torch::optional<int64_t>>(Object x)
- {
- if (x.is_nil()) {
- return torch::nullopt;
- } else {
- return torch::optional<int64_t>{from_ruby<int64_t>(x)};
- }
- }
+ };

- template<>
- inline
- torch::optional<double> from_ruby<torch::optional<double>>(Object x)
- {
- if (x.is_nil()) {
- return torch::nullopt;
- } else {
- return torch::optional<double>{from_ruby<double>(x)};
- }
- }
-
- template<>
- inline
- torch::optional<bool> from_ruby<torch::optional<bool>>(Object x)
- {
- if (x.is_nil()) {
- return torch::nullopt;
- } else {
- return torch::optional<bool>{from_ruby<bool>(x)};
- }
- }
-
- template<>
- inline
- torch::optional<Scalar> from_ruby<torch::optional<Scalar>>(Object x)
- {
- if (x.is_nil()) {
- return torch::nullopt;
- } else {
- return torch::optional<Scalar>{from_ruby<Scalar>(x)};
- }
+ template<typename T>
+ class From_Ruby<torch::optional<T>>
+ {
+ public:
+ torch::optional<T> convert(VALUE x)
+ {
+ if (NIL_P(x)) {
+ return torch::nullopt;
+ } else {
+ return torch::optional<T>{From_Ruby<T>().convert(x)};
+ }
+ }
+ };
  }
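
Two things to note in the rewritten templates.h. First, every converter now lives in namespace Rice::detail and is paired with a Type<T> specialization; Rice 4 calls verify() to check that a type used in a binding is registered, and returning true opts these types in unconditionally. Second, the five hand-written torch::optional specializations from 0.6.0 (ScalarType, int64_t, double, bool, Scalar) collapse into one pair of templates that map nil to torch::nullopt and delegate everything else to From_Ruby<T>. A usage sketch under those definitions, with the helper name hypothetical:

    // Hypothetical helper: with the templated From_Ruby<torch::optional<T>>
    // above, any optional unwraps uniformly; nil becomes torch::nullopt.
    torch::optional<int64_t> optional_int(VALUE x) {
      return Rice::detail::From_Ruby<torch::optional<int64_t>>().convert(x);
    }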
data/ext/torch/tensor.cpp CHANGED
@@ -1,7 +1,6 @@
  #include <torch/torch.h>

- #include <rice/Constructor.hpp>
- #include <rice/Module.hpp>
+ #include <rice/rice.hpp>

  #include "tensor_functions.h"
  #include "ruby_arg_parser.h"
@@ -11,6 +10,39 @@
  using namespace Rice;
  using torch::indexing::TensorIndex;

+ namespace Rice::detail
+ {
+ template<typename T>
+ struct Type<c10::complex<T>>
+ {
+ static bool verify()
+ {
+ return true;
+ }
+ };
+
+ template<typename T>
+ class To_Ruby<c10::complex<T>>
+ {
+ public:
+ VALUE convert(c10::complex<T> const& x)
+ {
+ return rb_dbl_complex_new(x.real(), x.imag());
+ }
+ };
+ }
+
+ template<typename T>
+ Array flat_data(Tensor& tensor) {
+ Tensor view = tensor.reshape({tensor.numel()});
+
+ Array a;
+ for (int i = 0; i < tensor.numel(); i++) {
+ a.push(view[i].item().to<T>());
+ }
+ return a;
+ }
+
  Class rb_cTensor;

  std::vector<TensorIndex> index_vector(Array a) {
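
Two additions open tensor.cpp: a To_Ruby specialization that converts c10::complex<T> to a Ruby Complex via the C API's rb_dbl_complex_new(real, imag), and a flat_data<T> template that reshapes a tensor to one dimension and pushes each element onto a Rice::Array (Array#push converts each element through To_Ruby, which is what makes the complex case work). The per-dtype copy loops later in the file are replaced with calls to it. A usage sketch:

    void flat_data_demo() {
      // Flatten a 2x2 double tensor into a Ruby Array: [1.0, 2.0, 3.0, 4.0]
      torch::Tensor t = torch::tensor({{1.0, 2.0}, {3.0, 4.0}});
      Rice::Array a = flat_data<double>(t);
    }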
@@ -23,19 +55,19 @@ std::vector<TensorIndex> index_vector(Array a) {
  obj = a[i];

  if (obj.is_instance_of(rb_cInteger)) {
- indices.push_back(from_ruby<int64_t>(obj));
+ indices.push_back(Rice::detail::From_Ruby<int64_t>().convert(obj.value()));
  } else if (obj.is_instance_of(rb_cRange)) {
  torch::optional<int64_t> start_index = torch::nullopt;
  torch::optional<int64_t> stop_index = torch::nullopt;

  Object begin = obj.call("begin");
  if (!begin.is_nil()) {
- start_index = from_ruby<int64_t>(begin);
+ start_index = Rice::detail::From_Ruby<int64_t>().convert(begin.value());
  }

  Object end = obj.call("end");
  if (!end.is_nil()) {
- stop_index = from_ruby<int64_t>(end);
+ stop_index = Rice::detail::From_Ruby<int64_t>().convert(end.value());
  }

  Object exclude_end = obj.call("exclude_end?");
@@ -49,11 +81,11 @@ std::vector<TensorIndex> index_vector(Array a) {

  indices.push_back(torch::indexing::Slice(start_index, stop_index));
  } else if (obj.is_instance_of(rb_cTensor)) {
- indices.push_back(from_ruby<Tensor>(obj));
+ indices.push_back(Rice::detail::From_Ruby<Tensor>().convert(obj.value()));
  } else if (obj.is_nil()) {
  indices.push_back(torch::indexing::None);
  } else if (obj == True || obj == False) {
- indices.push_back(from_ruby<bool>(obj));
+ indices.push_back(Rice::detail::From_Ruby<bool>().convert(obj.value()));
  } else {
  throw Exception(rb_eArgError, "Unsupported index type: %s", rb_obj_classname(obj));
  }
@@ -68,7 +100,7 @@ std::vector<TensorIndex> index_vector(Array a) {
  static VALUE tensor__backward(int argc, VALUE* argv, VALUE self_)
  {
  HANDLE_TH_ERRORS
- Tensor& self = from_ruby<Tensor&>(self_);
+ Tensor& self = Rice::detail::From_Ruby<Tensor&>().convert(self_);
  static RubyArgParser parser({
  "_backward(Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False)"
  });
@@ -84,8 +116,8 @@ static VALUE tensor__backward(int argc, VALUE* argv, VALUE self_)
  END_HANDLE_TH_ERRORS
  }

- void init_tensor(Rice::Module& m) {
- rb_cTensor = Rice::define_class_under<torch::Tensor>(m, "Tensor");
+ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions) {
+ rb_cTensor = c;
  rb_cTensor.add_handler<torch::Error>(handle_error);
  add_tensor_functions(rb_cTensor);
  THPVariableClass = rb_cTensor.value();
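
init_tensor no longer creates the Tensor class itself: the caller now builds both Tensor and TensorOptions up front and passes them in, and (per the hunk further down) the TensorOptions constructor registration moves out with it. A plausible reading is that Rice 4 requires a class to be registered before any binding mentions it in a signature, so all classes get defined first. A hypothetical sketch of the caller-side wiring implied by the new signature, not shown in this diff:

    // Hypothetical extension entry point: create the classes first, then hand
    // them to the per-file init functions such as init_tensor.
    void init_torch_classes(Rice::Module& m) {
      Rice::Class rb_cTensor =
        Rice::define_class_under<torch::Tensor>(m, "Tensor");
      Rice::Class rb_cTensorOptions =
        Rice::define_class_under<torch::TensorOptions>(m, "TensorOptions")
          .define_constructor(Rice::Constructor<torch::TensorOptions>());
      init_tensor(m, rb_cTensor, rb_cTensorOptions);
    }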
@@ -102,18 +134,18 @@ void init_tensor(Rice::Module& m) {
  .define_method("requires_grad", &torch::Tensor::requires_grad)
  .define_method(
  "_size",
- *[](Tensor& self, int64_t dim) {
+ [](Tensor& self, int64_t dim) {
  return self.size(dim);
  })
  .define_method(
  "_stride",
- *[](Tensor& self, int64_t dim) {
+ [](Tensor& self, int64_t dim) {
  return self.stride(dim);
  })
  // in C++ for performance
  .define_method(
  "shape",
- *[](Tensor& self) {
+ [](Tensor& self) {
  Array a;
  for (auto &size : self.sizes()) {
  a.push(size);
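
The other mechanical change in this file: Rice 4's define_method accepts a capture-less lambda directly, so the 0.6.0 idiom of decaying the lambda to a function pointer with a leading * is dropped everywhere. A hypothetical binding in the new style (illustrative method name; rb_cTensor as above):

    // Rice 3 required *[](torch::Tensor& self) { ... }; Rice 4 takes the
    // lambda as-is.
    rb_cTensor.define_method(
      "numel",
      [](torch::Tensor& self) {
        return self.numel();
      });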
@@ -122,7 +154,7 @@ void init_tensor(Rice::Module& m) {
  })
  .define_method(
  "_strides",
- *[](Tensor& self) {
+ [](Tensor& self) {
  Array a;
  for (auto &stride : self.strides()) {
  a.push(stride);
@@ -131,65 +163,65 @@ void init_tensor(Rice::Module& m) {
  })
  .define_method(
  "_index",
- *[](Tensor& self, Array indices) {
+ [](Tensor& self, Array indices) {
  auto vec = index_vector(indices);
  return self.index(vec);
  })
  .define_method(
  "_index_put_custom",
- *[](Tensor& self, Array indices, torch::Tensor& value) {
+ [](Tensor& self, Array indices, torch::Tensor& value) {
  auto vec = index_vector(indices);
  return self.index_put_(vec, value);
  })
  .define_method(
  "contiguous?",
- *[](Tensor& self) {
+ [](Tensor& self) {
  return self.is_contiguous();
  })
  .define_method(
  "_requires_grad!",
- *[](Tensor& self, bool requires_grad) {
+ [](Tensor& self, bool requires_grad) {
  return self.set_requires_grad(requires_grad);
  })
  .define_method(
  "grad",
- *[](Tensor& self) {
+ [](Tensor& self) {
  auto grad = self.grad();
- return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
+ return grad.defined() ? Object(Rice::detail::To_Ruby<torch::Tensor>().convert(grad)) : Nil;
  })
  .define_method(
  "grad=",
- *[](Tensor& self, torch::Tensor& grad) {
+ [](Tensor& self, torch::Tensor& grad) {
  self.mutable_grad() = grad;
  })
  .define_method(
  "_dtype",
- *[](Tensor& self) {
+ [](Tensor& self) {
  return (int) at::typeMetaToScalarType(self.dtype());
  })
  .define_method(
  "_type",
- *[](Tensor& self, int dtype) {
+ [](Tensor& self, int dtype) {
  return self.toType((torch::ScalarType) dtype);
  })
  .define_method(
  "_layout",
- *[](Tensor& self) {
+ [](Tensor& self) {
  std::stringstream s;
  s << self.layout();
  return s.str();
  })
  .define_method(
  "device",
- *[](Tensor& self) {
+ [](Tensor& self) {
  std::stringstream s;
  s << self.device();
  return s.str();
  })
  .define_method(
  "_data_str",
- *[](Tensor& self) {
- Tensor tensor = self;
+ [](Tensor& self) {
+ auto tensor = self;

  // move to CPU to get data
  if (tensor.device().type() != torch::kCPU) {
@@ -207,14 +239,14 @@ void init_tensor(Rice::Module& m) {
  // for TorchVision
  .define_method(
  "_data_ptr",
- *[](Tensor& self) {
+ [](Tensor& self) {
  return reinterpret_cast<uintptr_t>(self.data_ptr());
  })
  // TODO figure out a better way to do this
  .define_method(
  "_flat_data",
- *[](Tensor& self) {
- Tensor tensor = self;
+ [](Tensor& self) {
+ auto tensor = self;

  // move to CPU to get data
  if (tensor.device().type() != torch::kCPU) {
@@ -222,66 +254,47 @@ void init_tensor(Rice::Module& m) {
  tensor = tensor.to(device);
  }

- Array a;
  auto dtype = tensor.dtype();
-
- Tensor view = tensor.reshape({tensor.numel()});
-
- // TODO DRY if someone knows C++
  if (dtype == torch::kByte) {
- for (int i = 0; i < tensor.numel(); i++) {
- a.push(view[i].item().to<uint8_t>());
- }
+ return flat_data<uint8_t>(tensor);
  } else if (dtype == torch::kChar) {
- for (int i = 0; i < tensor.numel(); i++) {
- a.push(to_ruby<int>(view[i].item().to<int8_t>()));
- }
+ return flat_data<int8_t>(tensor);
  } else if (dtype == torch::kShort) {
- for (int i = 0; i < tensor.numel(); i++) {
- a.push(view[i].item().to<int16_t>());
- }
+ return flat_data<int16_t>(tensor);
  } else if (dtype == torch::kInt) {
- for (int i = 0; i < tensor.numel(); i++) {
- a.push(view[i].item().to<int32_t>());
- }
+ return flat_data<int32_t>(tensor);
  } else if (dtype == torch::kLong) {
- for (int i = 0; i < tensor.numel(); i++) {
- a.push(view[i].item().to<int64_t>());
- }
+ return flat_data<int64_t>(tensor);
  } else if (dtype == torch::kFloat) {
- for (int i = 0; i < tensor.numel(); i++) {
- a.push(view[i].item().to<float>());
- }
+ return flat_data<float>(tensor);
  } else if (dtype == torch::kDouble) {
- for (int i = 0; i < tensor.numel(); i++) {
- a.push(view[i].item().to<double>());
- }
+ return flat_data<double>(tensor);
  } else if (dtype == torch::kBool) {
- for (int i = 0; i < tensor.numel(); i++) {
- a.push(view[i].item().to<bool>() ? True : False);
- }
+ return flat_data<bool>(tensor);
+ } else if (dtype == torch::kComplexFloat) {
+ return flat_data<c10::complex<float>>(tensor);
+ } else if (dtype == torch::kComplexDouble) {
+ return flat_data<c10::complex<double>>(tensor);
  } else {
  throw std::runtime_error("Unsupported type");
  }
- return a;
  })
  .define_method(
  "_to",
- *[](Tensor& self, torch::Device device, int dtype, bool non_blocking, bool copy) {
+ [](Tensor& self, torch::Device device, int dtype, bool non_blocking, bool copy) {
  return self.to(device, (torch::ScalarType) dtype, non_blocking, copy);
  });

- Rice::define_class_under<torch::TensorOptions>(m, "TensorOptions")
+ rb_cTensorOptions
  .add_handler<torch::Error>(handle_error)
- .define_constructor(Rice::Constructor<torch::TensorOptions>())
  .define_method(
  "dtype",
- *[](torch::TensorOptions& self, int dtype) {
+ [](torch::TensorOptions& self, int dtype) {
  return self.dtype((torch::ScalarType) dtype);
  })
  .define_method(
  "layout",
- *[](torch::TensorOptions& self, const std::string& layout) {
+ [](torch::TensorOptions& self, const std::string& layout) {
  torch::Layout l;
  if (layout == "strided") {
  l = torch::kStrided;
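
With flat_data<T> in place, the body of _flat_data shrinks from nine hand-copied loops (the "TODO DRY if someone knows C++" is now resolved) to one branch per dtype, and gains two new branches, kComplexFloat and kComplexDouble, which round-trip through the To_Ruby<c10::complex<T>> specialization at the top of the file. A small sketch of what that enables:

    void complex_flat_data_demo() {
      // A complex tensor can now cross into Ruby: each element is converted
      // with rb_dbl_complex_new, yielding Ruby Complex objects, e.g. (1.0+0.0i).
      torch::Tensor t = torch::ones({2}, torch::kComplexFloat);
      Rice::Array a = flat_data<c10::complex<float>>(t);
    }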
@@ -295,13 +308,13 @@ void init_tensor(Rice::Module& m) {
  })
  .define_method(
  "device",
- *[](torch::TensorOptions& self, const std::string& device) {
+ [](torch::TensorOptions& self, const std::string& device) {
  torch::Device d(device);
  return self.device(d);
  })
  .define_method(
  "requires_grad",
- *[](torch::TensorOptions& self, bool requires_grad) {
+ [](torch::TensorOptions& self, bool requires_grad) {
  return self.requires_grad(requires_grad);
  });
  }