torch-rb 0.5.2 → 0.8.1

@@ -0,0 +1,13 @@
+ #include <torch/torch.h>
+
+ #include <rice/rice.hpp>
+
+ #include "special_functions.h"
+ #include "templates.h"
+ #include "utils.h"
+
+ void init_special(Rice::Module& m) {
+   auto rb_mSpecial = Rice::define_module_under(m, "Special");
+   rb_mSpecial.add_handler<torch::Error>(handle_error);
+   add_special_functions(rb_mSpecial);
+ }
@@ -0,0 +1,6 @@
+ // generated by rake generate:functions
+ // do not edit by hand
+
+ #pragma once
+
+ void add_special_functions(Rice::Module& m);
@@ -4,8 +4,7 @@
  #undef isfinite
  #endif

- #include <rice/Array.hpp>
- #include <rice/Object.hpp>
+ #include <rice/rice.hpp>

  using namespace Rice;

@@ -22,6 +21,10 @@ using torch::IntArrayRef;
  using torch::ArrayRef;
  using torch::TensorList;
  using torch::Storage;
+ using ScalarList = ArrayRef<Scalar>;
+
+ using torch::nn::init::FanModeType;
+ using torch::nn::init::NonlinearityType;

  #define HANDLE_TH_ERRORS \
    try {
@@ -38,37 +41,42 @@ using torch::Storage;
  #define RETURN_NIL \
    return Qnil;

- template<>
- inline
- std::vector<int64_t> from_ruby<std::vector<int64_t>>(Object x)
- {
-   Array a = Array(x);
-   std::vector<int64_t> vec(a.size());
-   for (size_t i = 0; i < a.size(); i++) {
-     vec[i] = from_ruby<int64_t>(a[i]);
-   }
-   return vec;
- }
+ class OptionalTensor {
+   torch::Tensor value;
+ public:
+   OptionalTensor(Object o) {
+     if (o.is_nil()) {
+       value = {};
+     } else {
+       value = Rice::detail::From_Ruby<torch::Tensor>().convert(o.value());
+     }
+   }
+   OptionalTensor(torch::Tensor o) {
+     value = o;
+   }
+   operator torch::Tensor() const {
+     return value;
+   }
+ };

- template<>
- inline
- std::vector<Tensor> from_ruby<std::vector<Tensor>>(Object x)
+ namespace Rice::detail
  {
-   Array a = Array(x);
-   std::vector<Tensor> vec(a.size());
-   for (size_t i = 0; i < a.size(); i++) {
-     vec[i] = from_ruby<Tensor>(a[i]);
-   }
-   return vec;
- }
+ template<>
+ struct Type<FanModeType>
+ {
+   static bool verify()
+   {
+     return true;
+   }
+ };

- class FanModeType {
-   std::string s;
+ template<>
+ class From_Ruby<FanModeType>
+ {
  public:
-   FanModeType(Object o) {
-     s = String(o).str();
-   }
-   operator torch::nn::init::FanModeType() {
+   FanModeType convert(VALUE x)
+   {
+     auto s = String(x).str();
      if (s == "fan_in") {
        return torch::kFanIn;
      } else if (s == "fan_out") {
@@ -77,22 +85,24 @@ class FanModeType {
        throw std::runtime_error("Unsupported nonlinearity type: " + s);
      }
    }
- };
-
- template<>
- inline
- FanModeType from_ruby<FanModeType>(Object x)
- {
-   return FanModeType(x);
- }
+ };
+
+ template<>
+ struct Type<NonlinearityType>
+ {
+   static bool verify()
+   {
+     return true;
+   }
+ };

- class NonlinearityType {
-   std::string s;
+ template<>
+ class From_Ruby<NonlinearityType>
+ {
  public:
-   NonlinearityType(Object o) {
-     s = String(o).str();
-   }
-   operator torch::nn::init::NonlinearityType() {
+   NonlinearityType convert(VALUE x)
+   {
+     auto s = String(x).str();
      if (s == "linear") {
        return torch::kLinear;
      } else if (s == "conv1d") {
@@ -119,102 +129,70 @@ class NonlinearityType {
        throw std::runtime_error("Unsupported nonlinearity type: " + s);
      }
    }
- };
+ };
+
+ template<>
+ struct Type<OptionalTensor>
+ {
+   static bool verify()
+   {
+     return true;
+   }
+ };

- template<>
- inline
- NonlinearityType from_ruby<NonlinearityType>(Object x)
- {
-   return NonlinearityType(x);
- }
+ template<>
+ class From_Ruby<OptionalTensor>
+ {
+ public:
+   OptionalTensor convert(VALUE x)
+   {
+     return OptionalTensor(x);
+   }
+ };
+
+ template<>
+ struct Type<Scalar>
+ {
+   static bool verify()
+   {
+     return true;
+   }
+ };

- class OptionalTensor {
-   torch::Tensor value;
+ template<>
+ class From_Ruby<Scalar>
+ {
  public:
-   OptionalTensor(Object o) {
-     if (o.is_nil()) {
-       value = {};
+   Scalar convert(VALUE x)
+   {
+     if (FIXNUM_P(x)) {
+       return torch::Scalar(From_Ruby<int64_t>().convert(x));
      } else {
-       value = from_ruby<torch::Tensor>(o);
+       return torch::Scalar(From_Ruby<double>().convert(x));
      }
    }
-   OptionalTensor(torch::Tensor o) {
-     value = o;
-   }
-   operator torch::Tensor() const {
-     return value;
+ };
+
+ template<typename T>
+ struct Type<torch::optional<T>>
+ {
+   static bool verify()
+   {
+     return true;
    }
- };
-
- template<>
- inline
- Scalar from_ruby<Scalar>(Object x)
- {
-   if (x.rb_type() == T_FIXNUM) {
-     return torch::Scalar(from_ruby<int64_t>(x));
-   } else {
-     return torch::Scalar(from_ruby<double>(x));
-   }
- }
-
- template<>
- inline
- OptionalTensor from_ruby<OptionalTensor>(Object x)
- {
-   return OptionalTensor(x);
- }
-
- template<>
- inline
- torch::optional<torch::ScalarType> from_ruby<torch::optional<torch::ScalarType>>(Object x)
- {
-   if (x.is_nil()) {
-     return torch::nullopt;
-   } else {
-     return torch::optional<torch::ScalarType>{from_ruby<torch::ScalarType>(x)};
-   }
- }
-
- template<>
- inline
- torch::optional<int64_t> from_ruby<torch::optional<int64_t>>(Object x)
- {
-   if (x.is_nil()) {
-     return torch::nullopt;
-   } else {
-     return torch::optional<int64_t>{from_ruby<int64_t>(x)};
-   }
- }
+ };

- template<>
- inline
- torch::optional<double> from_ruby<torch::optional<double>>(Object x)
- {
-   if (x.is_nil()) {
-     return torch::nullopt;
-   } else {
-     return torch::optional<double>{from_ruby<double>(x)};
-   }
- }
-
- template<>
- inline
- torch::optional<bool> from_ruby<torch::optional<bool>>(Object x)
- {
-   if (x.is_nil()) {
-     return torch::nullopt;
-   } else {
-     return torch::optional<bool>{from_ruby<bool>(x)};
-   }
- }
-
- template<>
- inline
- torch::optional<Scalar> from_ruby<torch::optional<Scalar>>(Object x)
- {
-   if (x.is_nil()) {
-     return torch::nullopt;
-   } else {
-     return torch::optional<Scalar>{from_ruby<Scalar>(x)};
-   }
+ template<typename T>
+ class From_Ruby<torch::optional<T>>
+ {
+ public:
+   torch::optional<T> convert(VALUE x)
+   {
+     if (NIL_P(x)) {
+       return torch::nullopt;
+     } else {
+       return torch::optional<T>{From_Ruby<T>().convert(x)};
+     }
+   }
+ };

  }
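
The hunk above is the core of the Rice 4 migration: instead of specializing the free function from_ruby<T>, the extension now defines Type<T> and From_Ruby<T> specializations inside namespace Rice::detail, and Rice consults these whenever a bound method takes one of those types as a parameter. A minimal, hypothetical sketch of a binding that would exercise them (the function name, module function, and lambda below are invented for illustration and are not code from the gem):

// Hypothetical illustration only -- not code from the gem. The lambda's
// Scalar and torch::optional<int64_t> parameters are converted through the
// From_Ruby<Scalar> and From_Ruby<torch::optional<T>> specializations
// declared in templates.h above.
#include <torch/torch.h>
#include <rice/rice.hpp>
#include "templates.h"

void init_conversion_example(Rice::Module& m) {
  m.define_singleton_function(
    "_scale_example",
    [](torch::Tensor& x, Scalar alpha, torch::optional<int64_t> dim) {
      // A Ruby Integer or Float arrives here as torch::Scalar; a Ruby nil
      // arrives as torch::nullopt, so optional arguments need no
      // special-casing on the Ruby side.
      auto y = x.mul(alpha);
      if (dim.has_value()) {
        y = y.unsqueeze(dim.value());
      }
      return y;
    });
}

From Ruby this would be called roughly as Torch._scale_example(x, 2, nil), assuming init_conversion_example were wired into the extension's init code; the generated tensor and special functions rely on the same conversion lookup.
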
@@ -0,0 +1,320 @@
+ #include <torch/torch.h>
+
+ #include <rice/rice.hpp>
+
+ #include "tensor_functions.h"
+ #include "ruby_arg_parser.h"
+ #include "templates.h"
+ #include "utils.h"
+
+ using namespace Rice;
+ using torch::indexing::TensorIndex;
+
+ namespace Rice::detail
+ {
+ template<typename T>
+ struct Type<c10::complex<T>>
+ {
+   static bool verify()
+   {
+     return true;
+   }
+ };
+
+ template<typename T>
+ class To_Ruby<c10::complex<T>>
+ {
+ public:
+   VALUE convert(c10::complex<T> const& x)
+   {
+     return rb_dbl_complex_new(x.real(), x.imag());
+   }
+ };
+ }
+
+ template<typename T>
+ Array flat_data(Tensor& tensor) {
+   Tensor view = tensor.reshape({tensor.numel()});
+
+   Array a;
+   for (int i = 0; i < tensor.numel(); i++) {
+     a.push(view[i].item().to<T>());
+   }
+   return a;
+ }
+
+ Class rb_cTensor;
+
+ std::vector<TensorIndex> index_vector(Array a) {
+   Object obj;
+
+   std::vector<TensorIndex> indices;
+   indices.reserve(a.size());
+
+   for (long i = 0; i < a.size(); i++) {
+     obj = a[i];
+
+     if (obj.is_instance_of(rb_cInteger)) {
+       indices.push_back(Rice::detail::From_Ruby<int64_t>().convert(obj.value()));
+     } else if (obj.is_instance_of(rb_cRange)) {
+       torch::optional<int64_t> start_index = torch::nullopt;
+       torch::optional<int64_t> stop_index = torch::nullopt;
+
+       Object begin = obj.call("begin");
+       if (!begin.is_nil()) {
+         start_index = Rice::detail::From_Ruby<int64_t>().convert(begin.value());
+       }
+
+       Object end = obj.call("end");
+       if (!end.is_nil()) {
+         stop_index = Rice::detail::From_Ruby<int64_t>().convert(end.value());
+       }
+
+       Object exclude_end = obj.call("exclude_end?");
+       if (stop_index.has_value() && !exclude_end) {
+         if (stop_index.value() == -1) {
+           stop_index = torch::nullopt;
+         } else {
+           stop_index = stop_index.value() + 1;
+         }
+       }
+
+       indices.push_back(torch::indexing::Slice(start_index, stop_index));
+     } else if (obj.is_instance_of(rb_cTensor)) {
+       indices.push_back(Rice::detail::From_Ruby<Tensor>().convert(obj.value()));
+     } else if (obj.is_nil()) {
+       indices.push_back(torch::indexing::None);
+     } else if (obj == True || obj == False) {
+       indices.push_back(Rice::detail::From_Ruby<bool>().convert(obj.value()));
+     } else {
+       throw Exception(rb_eArgError, "Unsupported index type: %s", rb_obj_classname(obj));
+     }
+   }
+   return indices;
+ }
+
+ // hack (removes inputs argument)
+ // https://github.com/pytorch/pytorch/commit/2e5bfa9824f549be69a28e4705a72b4cf8a4c519
+ // TODO add support for inputs argument
+ // _backward
+ static VALUE tensor__backward(int argc, VALUE* argv, VALUE self_)
+ {
+   HANDLE_TH_ERRORS
+   Tensor& self = Rice::detail::From_Ruby<Tensor&>().convert(self_);
+   static RubyArgParser parser({
+     "_backward(Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False)"
+   });
+   ParsedArgs<4> parsed_args;
+   auto _r = parser.parse(self_, argc, argv, parsed_args);
+   // _backward(Tensor self, Tensor[] inputs, Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False) -> ()
+   auto dispatch__backward = [](const Tensor & self, TensorList inputs, const OptionalTensor & gradient, c10::optional<bool> retain_graph, bool create_graph) -> void {
+     // in future, release GVL
+     self._backward(inputs, gradient, retain_graph, create_graph);
+   };
+   dispatch__backward(self, {}, _r.optionalTensor(0), _r.toBoolOptional(1), _r.toBool(2));
+   RETURN_NIL
+   END_HANDLE_TH_ERRORS
+ }
+
+ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions) {
+   rb_cTensor = c;
+   rb_cTensor.add_handler<torch::Error>(handle_error);
+   add_tensor_functions(rb_cTensor);
+   THPVariableClass = rb_cTensor.value();
+
+   rb_define_method(rb_cTensor, "backward", (VALUE (*)(...)) tensor__backward, -1);
+
+   rb_cTensor
+     .define_method("cuda?", &torch::Tensor::is_cuda)
+     .define_method("sparse?", &torch::Tensor::is_sparse)
+     .define_method("quantized?", &torch::Tensor::is_quantized)
+     .define_method("dim", &torch::Tensor::dim)
+     .define_method("numel", &torch::Tensor::numel)
+     .define_method("element_size", &torch::Tensor::element_size)
+     .define_method("requires_grad", &torch::Tensor::requires_grad)
+     .define_method(
+       "_size",
+       [](Tensor& self, int64_t dim) {
+         return self.size(dim);
+       })
+     .define_method(
+       "_stride",
+       [](Tensor& self, int64_t dim) {
+         return self.stride(dim);
+       })
+     // in C++ for performance
+     .define_method(
+       "shape",
+       [](Tensor& self) {
+         Array a;
+         for (auto &size : self.sizes()) {
+           a.push(size);
+         }
+         return a;
+       })
+     .define_method(
+       "_strides",
+       [](Tensor& self) {
+         Array a;
+         for (auto &stride : self.strides()) {
+           a.push(stride);
+         }
+         return a;
+       })
+     .define_method(
+       "_index",
+       [](Tensor& self, Array indices) {
+         auto vec = index_vector(indices);
+         return self.index(vec);
+       })
+     .define_method(
+       "_index_put_custom",
+       [](Tensor& self, Array indices, torch::Tensor& value) {
+         auto vec = index_vector(indices);
+         return self.index_put_(vec, value);
+       })
+     .define_method(
+       "contiguous?",
+       [](Tensor& self) {
+         return self.is_contiguous();
+       })
+     .define_method(
+       "_requires_grad!",
+       [](Tensor& self, bool requires_grad) {
+         return self.set_requires_grad(requires_grad);
+       })
+     .define_method(
+       "grad",
+       [](Tensor& self) {
+         auto grad = self.grad();
+         return grad.defined() ? Object(Rice::detail::To_Ruby<torch::Tensor>().convert(grad)) : Nil;
+       })
+     .define_method(
+       "grad=",
+       [](Tensor& self, torch::Tensor& grad) {
+         self.mutable_grad() = grad;
+       })
+     .define_method(
+       "_dtype",
+       [](Tensor& self) {
+         return (int) at::typeMetaToScalarType(self.dtype());
+       })
+     .define_method(
+       "_type",
+       [](Tensor& self, int dtype) {
+         return self.toType((torch::ScalarType) dtype);
+       })
+     .define_method(
+       "_layout",
+       [](Tensor& self) {
+         std::stringstream s;
+         s << self.layout();
+         return s.str();
+       })
+     .define_method(
+       "device",
+       [](Tensor& self) {
+         std::stringstream s;
+         s << self.device();
+         return s.str();
+       })
+     .define_method(
+       "_data_str",
+       [](Tensor& self) {
+         auto tensor = self;
+
+         // move to CPU to get data
+         if (tensor.device().type() != torch::kCPU) {
+           torch::Device device("cpu");
+           tensor = tensor.to(device);
+         }
+
+         if (!tensor.is_contiguous()) {
+           tensor = tensor.contiguous();
+         }
+
+         auto data_ptr = (const char *) tensor.data_ptr();
+         return std::string(data_ptr, tensor.numel() * tensor.element_size());
+       })
+     // for TorchVision
+     .define_method(
+       "_data_ptr",
+       [](Tensor& self) {
+         return reinterpret_cast<uintptr_t>(self.data_ptr());
+       })
+     // TODO figure out a better way to do this
+     .define_method(
+       "_flat_data",
+       [](Tensor& self) {
+         auto tensor = self;
+
+         // move to CPU to get data
+         if (tensor.device().type() != torch::kCPU) {
+           torch::Device device("cpu");
+           tensor = tensor.to(device);
+         }
+
+         auto dtype = tensor.dtype();
+         if (dtype == torch::kByte) {
+           return flat_data<uint8_t>(tensor);
+         } else if (dtype == torch::kChar) {
+           return flat_data<int8_t>(tensor);
+         } else if (dtype == torch::kShort) {
+           return flat_data<int16_t>(tensor);
+         } else if (dtype == torch::kInt) {
+           return flat_data<int32_t>(tensor);
+         } else if (dtype == torch::kLong) {
+           return flat_data<int64_t>(tensor);
+         } else if (dtype == torch::kFloat) {
+           return flat_data<float>(tensor);
+         } else if (dtype == torch::kDouble) {
+           return flat_data<double>(tensor);
+         } else if (dtype == torch::kBool) {
+           return flat_data<bool>(tensor);
+         } else if (dtype == torch::kComplexFloat) {
+           return flat_data<c10::complex<float>>(tensor);
+         } else if (dtype == torch::kComplexDouble) {
+           return flat_data<c10::complex<double>>(tensor);
+         } else {
+           throw std::runtime_error("Unsupported type");
+         }
+       })
+     .define_method(
+       "_to",
+       [](Tensor& self, torch::Device device, int dtype, bool non_blocking, bool copy) {
+         return self.to(device, (torch::ScalarType) dtype, non_blocking, copy);
+       });
+
+   rb_cTensorOptions
+     .add_handler<torch::Error>(handle_error)
+     .define_method(
+       "dtype",
+       [](torch::TensorOptions& self, int dtype) {
+         return self.dtype((torch::ScalarType) dtype);
+       })
+     .define_method(
+       "layout",
+       [](torch::TensorOptions& self, const std::string& layout) {
+         torch::Layout l;
+         if (layout == "strided") {
+           l = torch::kStrided;
+         } else if (layout == "sparse") {
+           l = torch::kSparse;
+           throw std::runtime_error("Sparse layout not supported yet");
+         } else {
+           throw std::runtime_error("Unsupported layout: " + layout);
+         }
+         return self.layout(l);
+       })
+     .define_method(
+       "device",
+       [](torch::TensorOptions& self, const std::string& device) {
+         torch::Device d(device);
+         return self.device(d);
+       })
+     .define_method(
+       "requires_grad",
+       [](torch::TensorOptions& self, bool requires_grad) {
+         return self.requires_grad(requires_grad);
+       });
+ }
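
For reference, the Ruby Range handling in index_vector above maps onto libtorch's torch::indexing::Slice: an inclusive finite end has one added to it, an inclusive -1 (or an endless range) becomes an open slice, and a plain integer stays an integer index. A rough standalone sketch of the equivalent libtorch calls (an illustration under those assumptions, not code from the gem):

// Standalone sketch (not gem code): the TensorIndex values that
// index_vector would build for a few Ruby-style indices.
#include <torch/torch.h>

using torch::indexing::Slice;
using torch::indexing::None;

int main() {
  auto t = torch::arange(20).reshape({4, 5});

  auto a = t.index({0});              // Ruby t[0]      -> integer index
  auto b = t.index({Slice(1, 3)});    // Ruby t[1..2]   -> inclusive end, stop = 2 + 1
  auto c = t.index({Slice(1, 3)});    // Ruby t[1...3]  -> exclusive end used as-is
  auto d = t.index({Slice(1, None)}); // Ruby t[1..-1] or t[1..] -> open-ended slice

  return 0;
}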