torch-rb 0.5.0 → 0.7.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -4,8 +4,7 @@
 #undef isfinite
 #endif
 
-#include <rice/Array.hpp>
-#include <rice/Object.hpp>
+#include <rice/rice.hpp>
 
 using namespace Rice;
 
@@ -23,6 +22,9 @@ using torch::ArrayRef;
 using torch::TensorList;
 using torch::Storage;
 
+using torch::nn::init::FanModeType;
+using torch::nn::init::NonlinearityType;
+
 #define HANDLE_TH_ERRORS \
   try {
 
@@ -38,37 +40,42 @@ using torch::Storage;
 #define RETURN_NIL \
   return Qnil;
 
-template<>
-inline
-std::vector<int64_t> from_ruby<std::vector<int64_t>>(Object x)
-{
-  Array a = Array(x);
-  std::vector<int64_t> vec(a.size());
-  for (size_t i = 0; i < a.size(); i++) {
-    vec[i] = from_ruby<int64_t>(a[i]);
-  }
-  return vec;
-}
+class OptionalTensor {
+  torch::Tensor value;
+  public:
+    OptionalTensor(Object o) {
+      if (o.is_nil()) {
+        value = {};
+      } else {
+        value = Rice::detail::From_Ruby<torch::Tensor>().convert(o.value());
+      }
+    }
+    OptionalTensor(torch::Tensor o) {
+      value = o;
+    }
+    operator torch::Tensor() const {
+      return value;
+    }
+};
 
-template<>
-inline
-std::vector<Tensor> from_ruby<std::vector<Tensor>>(Object x)
+namespace Rice::detail
 {
-  Array a = Array(x);
-  std::vector<Tensor> vec(a.size());
-  for (size_t i = 0; i < a.size(); i++) {
-    vec[i] = from_ruby<Tensor>(a[i]);
-  }
-  return vec;
-}
+  template<>
+  struct Type<FanModeType>
+  {
+    static bool verify()
+    {
+      return true;
+    }
+  };
 
-class FanModeType {
-  std::string s;
+  template<>
+  class From_Ruby<FanModeType>
+  {
   public:
-    FanModeType(Object o) {
-      s = String(o).str();
-    }
-    operator torch::nn::init::FanModeType() {
+    FanModeType convert(VALUE x)
+    {
+      auto s = String(x).str();
       if (s == "fan_in") {
         return torch::kFanIn;
       } else if (s == "fan_out") {
@@ -77,22 +84,24 @@ class FanModeType {
         throw std::runtime_error("Unsupported nonlinearity type: " + s);
       }
     }
-};
-
-template<>
-inline
-FanModeType from_ruby<FanModeType>(Object x)
-{
-  return FanModeType(x);
-}
+  };
+
+  template<>
+  struct Type<NonlinearityType>
+  {
+    static bool verify()
+    {
+      return true;
+    }
+  };
 
-class NonlinearityType {
-  std::string s;
+  template<>
+  class From_Ruby<NonlinearityType>
+  {
   public:
-    NonlinearityType(Object o) {
-      s = String(o).str();
-    }
-    operator torch::nn::init::NonlinearityType() {
+    NonlinearityType convert(VALUE x)
+    {
+      auto s = String(x).str();
       if (s == "linear") {
         return torch::kLinear;
       } else if (s == "conv1d") {
@@ -119,102 +128,70 @@ class NonlinearityType {
         throw std::runtime_error("Unsupported nonlinearity type: " + s);
       }
     }
-};
+  };
+
+  template<>
+  struct Type<OptionalTensor>
+  {
+    static bool verify()
+    {
+      return true;
+    }
+  };
 
-template<>
-inline
-NonlinearityType from_ruby<NonlinearityType>(Object x)
-{
-  return NonlinearityType(x);
-}
+  template<>
+  class From_Ruby<OptionalTensor>
+  {
+  public:
+    OptionalTensor convert(VALUE x)
+    {
+      return OptionalTensor(x);
+    }
+  };
+
+  template<>
+  struct Type<Scalar>
+  {
+    static bool verify()
+    {
+      return true;
+    }
+  };
 
-class OptionalTensor {
-  torch::Tensor value;
+  template<>
+  class From_Ruby<Scalar>
+  {
   public:
-    OptionalTensor(Object o) {
-      if (o.is_nil()) {
-        value = {};
+    Scalar convert(VALUE x)
+    {
+      if (FIXNUM_P(x)) {
+        return torch::Scalar(From_Ruby<int64_t>().convert(x));
       } else {
-        value = from_ruby<torch::Tensor>(o);
+        return torch::Scalar(From_Ruby<double>().convert(x));
      }
    }
-    OptionalTensor(torch::Tensor o) {
-      value = o;
-    }
-    operator torch::Tensor() const {
-      return value;
+  };
+
+  template<typename T>
+  struct Type<torch::optional<T>>
+  {
+    static bool verify()
+    {
+      return true;
    }
-};
-
-template<>
-inline
-Scalar from_ruby<Scalar>(Object x)
-{
-  if (x.rb_type() == T_FIXNUM) {
-    return torch::Scalar(from_ruby<int64_t>(x));
-  } else {
-    return torch::Scalar(from_ruby<double>(x));
-  }
-}
-
-template<>
-inline
-OptionalTensor from_ruby<OptionalTensor>(Object x)
-{
-  return OptionalTensor(x);
-}
-
-template<>
-inline
-torch::optional<torch::ScalarType> from_ruby<torch::optional<torch::ScalarType>>(Object x)
-{
-  if (x.is_nil()) {
-    return torch::nullopt;
-  } else {
-    return torch::optional<torch::ScalarType>{from_ruby<torch::ScalarType>(x)};
-  }
-}
-
-template<>
-inline
-torch::optional<int64_t> from_ruby<torch::optional<int64_t>>(Object x)
-{
-  if (x.is_nil()) {
-    return torch::nullopt;
-  } else {
-    return torch::optional<int64_t>{from_ruby<int64_t>(x)};
-  }
-}
+  };
 
-template<>
-inline
-torch::optional<double> from_ruby<torch::optional<double>>(Object x)
-{
-  if (x.is_nil()) {
-    return torch::nullopt;
-  } else {
-    return torch::optional<double>{from_ruby<double>(x)};
-  }
-}
-
-template<>
-inline
-torch::optional<bool> from_ruby<torch::optional<bool>>(Object x)
-{
-  if (x.is_nil()) {
-    return torch::nullopt;
-  } else {
-    return torch::optional<bool>{from_ruby<bool>(x)};
-  }
-}
-
-template<>
-inline
-torch::optional<Scalar> from_ruby<torch::optional<Scalar>>(Object x)
-{
-  if (x.is_nil()) {
-    return torch::nullopt;
-  } else {
-    return torch::optional<Scalar>{from_ruby<Scalar>(x)};
-  }
+  template<typename T>
+  class From_Ruby<torch::optional<T>>
+  {
+  public:
+    torch::optional<T> convert(VALUE x)
+    {
+      if (NIL_P(x)) {
+        return torch::nullopt;
+      } else {
+        return torch::optional<T>{From_Ruby<T>().convert(x)};
+      }
+    }
+  };
 }
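
The template changes above track the move to the single <rice/rice.hpp> include and what appears to be the Rice 4 conversion API: the free from_ruby<T> function specializations become Rice::detail::From_Ruby<T> converter classes paired with Rice::detail::Type<T> verifiers. A minimal sketch of that same pattern for a hypothetical Color type (the enum, its string values, and the error message are illustrative only, not part of torch-rb):

#include <rice/rice.hpp>

#include <stdexcept>
#include <string>

// Hypothetical enum exposed to Ruby as the strings "red" and "blue".
enum class Color { Red, Blue };

namespace Rice::detail
{
  // Rice 4 checks Type<T>::verify() before it will use a conversion.
  template<>
  struct Type<Color>
  {
    static bool verify() { return true; }
  };

  // From_Ruby<T>::convert turns a Ruby VALUE into the C++ type,
  // mirroring the FanModeType/NonlinearityType converters above.
  template<>
  class From_Ruby<Color>
  {
  public:
    Color convert(VALUE x)
    {
      auto s = Rice::String(x).str();
      if (s == "red") return Color::Red;
      if (s == "blue") return Color::Blue;
      throw std::runtime_error("Unsupported color: " + s);
    }
  };
}
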
@@ -0,0 +1,320 @@
+#include <torch/torch.h>
+
+#include <rice/rice.hpp>
+
+#include "tensor_functions.h"
+#include "ruby_arg_parser.h"
+#include "templates.h"
+#include "utils.h"
+
+using namespace Rice;
+using torch::indexing::TensorIndex;
+
+namespace Rice::detail
+{
+  template<typename T>
+  struct Type<c10::complex<T>>
+  {
+    static bool verify()
+    {
+      return true;
+    }
+  };
+
+  template<typename T>
+  class To_Ruby<c10::complex<T>>
+  {
+  public:
+    VALUE convert(c10::complex<T> const& x)
+    {
+      return rb_dbl_complex_new(x.real(), x.imag());
+    }
+  };
+}
+
+template<typename T>
+Array flat_data(Tensor& tensor) {
+  Tensor view = tensor.reshape({tensor.numel()});
+
+  Array a;
+  for (int i = 0; i < tensor.numel(); i++) {
+    a.push(view[i].item().to<T>());
+  }
+  return a;
+}
+
+Class rb_cTensor;
+
+std::vector<TensorIndex> index_vector(Array a) {
+  Object obj;
+
+  std::vector<TensorIndex> indices;
+  indices.reserve(a.size());
+
+  for (long i = 0; i < a.size(); i++) {
+    obj = a[i];
+
+    if (obj.is_instance_of(rb_cInteger)) {
+      indices.push_back(Rice::detail::From_Ruby<int64_t>().convert(obj.value()));
+    } else if (obj.is_instance_of(rb_cRange)) {
+      torch::optional<int64_t> start_index = torch::nullopt;
+      torch::optional<int64_t> stop_index = torch::nullopt;
+
+      Object begin = obj.call("begin");
+      if (!begin.is_nil()) {
+        start_index = Rice::detail::From_Ruby<int64_t>().convert(begin.value());
+      }
+
+      Object end = obj.call("end");
+      if (!end.is_nil()) {
+        stop_index = Rice::detail::From_Ruby<int64_t>().convert(end.value());
+      }
+
+      Object exclude_end = obj.call("exclude_end?");
+      if (stop_index.has_value() && !exclude_end) {
+        if (stop_index.value() == -1) {
+          stop_index = torch::nullopt;
+        } else {
+          stop_index = stop_index.value() + 1;
+        }
+      }
+
+      indices.push_back(torch::indexing::Slice(start_index, stop_index));
+    } else if (obj.is_instance_of(rb_cTensor)) {
+      indices.push_back(Rice::detail::From_Ruby<Tensor>().convert(obj.value()));
+    } else if (obj.is_nil()) {
+      indices.push_back(torch::indexing::None);
+    } else if (obj == True || obj == False) {
+      indices.push_back(Rice::detail::From_Ruby<bool>().convert(obj.value()));
+    } else {
+      throw Exception(rb_eArgError, "Unsupported index type: %s", rb_obj_classname(obj));
+    }
+  }
+  return indices;
+}
+
+// hack (removes inputs argument)
+// https://github.com/pytorch/pytorch/commit/2e5bfa9824f549be69a28e4705a72b4cf8a4c519
+// TODO add support for inputs argument
+// _backward
+static VALUE tensor__backward(int argc, VALUE* argv, VALUE self_)
+{
+  HANDLE_TH_ERRORS
+  Tensor& self = Rice::detail::From_Ruby<Tensor&>().convert(self_);
+  static RubyArgParser parser({
+    "_backward(Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False)"
+  });
+  ParsedArgs<4> parsed_args;
+  auto _r = parser.parse(self_, argc, argv, parsed_args);
+  // _backward(Tensor self, Tensor[] inputs, Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False) -> ()
+  auto dispatch__backward = [](const Tensor & self, TensorList inputs, const OptionalTensor & gradient, c10::optional<bool> retain_graph, bool create_graph) -> void {
+    // in future, release GVL
+    self._backward(inputs, gradient, retain_graph, create_graph);
+  };
+  dispatch__backward(self, {}, _r.optionalTensor(0), _r.toBoolOptional(1), _r.toBool(2));
+  RETURN_NIL
+  END_HANDLE_TH_ERRORS
+}
+
+void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions) {
+  rb_cTensor = c;
+  rb_cTensor.add_handler<torch::Error>(handle_error);
+  add_tensor_functions(rb_cTensor);
+  THPVariableClass = rb_cTensor.value();
+
+  rb_define_method(rb_cTensor, "backward", (VALUE (*)(...)) tensor__backward, -1);
+
+  rb_cTensor
+    .define_method("cuda?", &torch::Tensor::is_cuda)
+    .define_method("sparse?", &torch::Tensor::is_sparse)
+    .define_method("quantized?", &torch::Tensor::is_quantized)
+    .define_method("dim", &torch::Tensor::dim)
+    .define_method("numel", &torch::Tensor::numel)
+    .define_method("element_size", &torch::Tensor::element_size)
+    .define_method("requires_grad", &torch::Tensor::requires_grad)
+    .define_method(
+      "_size",
+      [](Tensor& self, int64_t dim) {
+        return self.size(dim);
+      })
+    .define_method(
+      "_stride",
+      [](Tensor& self, int64_t dim) {
+        return self.stride(dim);
+      })
+    // in C++ for performance
+    .define_method(
+      "shape",
+      [](Tensor& self) {
+        Array a;
+        for (auto &size : self.sizes()) {
+          a.push(size);
+        }
+        return a;
+      })
+    .define_method(
+      "_strides",
+      [](Tensor& self) {
+        Array a;
+        for (auto &stride : self.strides()) {
+          a.push(stride);
+        }
+        return a;
+      })
+    .define_method(
+      "_index",
+      [](Tensor& self, Array indices) {
+        auto vec = index_vector(indices);
+        return self.index(vec);
+      })
+    .define_method(
+      "_index_put_custom",
+      [](Tensor& self, Array indices, torch::Tensor& value) {
+        auto vec = index_vector(indices);
+        return self.index_put_(vec, value);
+      })
+    .define_method(
+      "contiguous?",
+      [](Tensor& self) {
+        return self.is_contiguous();
+      })
+    .define_method(
+      "_requires_grad!",
+      [](Tensor& self, bool requires_grad) {
+        return self.set_requires_grad(requires_grad);
+      })
+    .define_method(
+      "grad",
+      [](Tensor& self) {
+        auto grad = self.grad();
+        return grad.defined() ? Object(Rice::detail::To_Ruby<torch::Tensor>().convert(grad)) : Nil;
+      })
+    .define_method(
+      "grad=",
+      [](Tensor& self, torch::Tensor& grad) {
+        self.mutable_grad() = grad;
+      })
+    .define_method(
+      "_dtype",
+      [](Tensor& self) {
+        return (int) at::typeMetaToScalarType(self.dtype());
+      })
+    .define_method(
+      "_type",
+      [](Tensor& self, int dtype) {
+        return self.toType((torch::ScalarType) dtype);
+      })
+    .define_method(
+      "_layout",
+      [](Tensor& self) {
+        std::stringstream s;
+        s << self.layout();
+        return s.str();
+      })
+    .define_method(
+      "device",
+      [](Tensor& self) {
+        std::stringstream s;
+        s << self.device();
+        return s.str();
+      })
+    .define_method(
+      "_data_str",
+      [](Tensor& self) {
+        auto tensor = self;
+
+        // move to CPU to get data
+        if (tensor.device().type() != torch::kCPU) {
+          torch::Device device("cpu");
+          tensor = tensor.to(device);
+        }
+
+        if (!tensor.is_contiguous()) {
+          tensor = tensor.contiguous();
+        }
+
+        auto data_ptr = (const char *) tensor.data_ptr();
+        return std::string(data_ptr, tensor.numel() * tensor.element_size());
+      })
+    // for TorchVision
+    .define_method(
+      "_data_ptr",
+      [](Tensor& self) {
+        return reinterpret_cast<uintptr_t>(self.data_ptr());
+      })
+    // TODO figure out a better way to do this
+    .define_method(
+      "_flat_data",
+      [](Tensor& self) {
+        auto tensor = self;
+
+        // move to CPU to get data
+        if (tensor.device().type() != torch::kCPU) {
+          torch::Device device("cpu");
+          tensor = tensor.to(device);
+        }
+
+        auto dtype = tensor.dtype();
+        if (dtype == torch::kByte) {
+          return flat_data<uint8_t>(tensor);
+        } else if (dtype == torch::kChar) {
+          return flat_data<int8_t>(tensor);
+        } else if (dtype == torch::kShort) {
+          return flat_data<int16_t>(tensor);
+        } else if (dtype == torch::kInt) {
+          return flat_data<int32_t>(tensor);
+        } else if (dtype == torch::kLong) {
+          return flat_data<int64_t>(tensor);
+        } else if (dtype == torch::kFloat) {
+          return flat_data<float>(tensor);
+        } else if (dtype == torch::kDouble) {
+          return flat_data<double>(tensor);
+        } else if (dtype == torch::kBool) {
+          return flat_data<bool>(tensor);
+        } else if (dtype == torch::kComplexFloat) {
+          return flat_data<c10::complex<float>>(tensor);
+        } else if (dtype == torch::kComplexDouble) {
+          return flat_data<c10::complex<double>>(tensor);
+        } else {
+          throw std::runtime_error("Unsupported type");
+        }
+      })
+    .define_method(
+      "_to",
+      [](Tensor& self, torch::Device device, int dtype, bool non_blocking, bool copy) {
+        return self.to(device, (torch::ScalarType) dtype, non_blocking, copy);
+      });
+
+  rb_cTensorOptions
+    .add_handler<torch::Error>(handle_error)
+    .define_method(
+      "dtype",
+      [](torch::TensorOptions& self, int dtype) {
+        return self.dtype((torch::ScalarType) dtype);
+      })
+    .define_method(
+      "layout",
+      [](torch::TensorOptions& self, const std::string& layout) {
+        torch::Layout l;
+        if (layout == "strided") {
+          l = torch::kStrided;
+        } else if (layout == "sparse") {
+          l = torch::kSparse;
+          throw std::runtime_error("Sparse layout not supported yet");
+        } else {
+          throw std::runtime_error("Unsupported layout: " + layout);
+        }
+        return self.layout(l);
+      })
+    .define_method(
+      "device",
+      [](torch::TensorOptions& self, const std::string& device) {
+        torch::Device d(device);
+        return self.device(d);
+      })
+    .define_method(
+      "requires_grad",
+      [](torch::TensorOptions& self, bool requires_grad) {
+        return self.requires_grad(requires_grad);
+      });
+}
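
For orientation, the Ruby index objects handled by index_vector above map directly onto libtorch's C++ indexing API. A standalone sketch of the equivalent direct calls (the tensor and index values are illustrative, not taken from the extension):

#include <torch/torch.h>

#include <iostream>

int main() {
  using namespace torch::indexing;

  auto t = torch::arange(12).reshape({3, 4});

  // A Ruby Integer index plus an inclusive Range 1..2 become an integer
  // index and Slice(1, 3) (stop + 1, as index_vector does for inclusive ranges).
  auto a = t.index({0, Slice(1, 3)});

  // A nil entry becomes torch::indexing::None, which inserts a new axis.
  auto b = t.index({None, 0});

  // A Range ending in -1 (e.g. 0..-1) becomes an open-ended Slice.
  auto c = t.index({Slice(0, torch::nullopt)});

  std::cout << a << "\n" << b.sizes() << "\n" << c.sizes() << std::endl;
  return 0;
}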