torch-rb 0.5.2 → 0.8.1

This diff shows the changes between two publicly released versions of the package, as published to their public registries. It is provided for informational purposes only.
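The headline change across these files is the move from Rice 3 to Rice 4. Rice 4 replaces the old free-function `from_ruby<T>()` template specializations with `Rice::detail::From_Ruby<T>` class specializations (each paired with a `Rice::detail::Type<T>` trait whose `verify()` tells Rice the type is known), and collapses the per-header includes (`rice/Array.hpp`, `rice/Object.hpp`) into a single `rice/rice.hpp`. A minimal before/after sketch, distilled from the `Scalar` conversion in the hunks below:

```cpp
// Rice 3 (torch-rb 0.5.2): free-function specialization over Rice::Object
template<>
inline Scalar from_ruby<Scalar>(Object x)
{
  if (x.rb_type() == T_FIXNUM) {
    return torch::Scalar(from_ruby<int64_t>(x));
  } else {
    return torch::Scalar(from_ruby<double>(x));
  }
}

// Rice 4 (torch-rb 0.8.1): Type<T> trait plus From_Ruby<T> class over raw VALUEs
namespace Rice::detail
{
  template<>
  struct Type<Scalar>
  {
    static bool verify() { return true; } // no extra registration needed
  };

  template<>
  class From_Ruby<Scalar>
  {
  public:
    Scalar convert(VALUE x)
    {
      if (FIXNUM_P(x)) {
        return torch::Scalar(From_Ruby<int64_t>().convert(x));
      } else {
        return torch::Scalar(From_Ruby<double>().convert(x));
      }
    }
  };
}
```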
@@ -0,0 +1,13 @@
+ #include <torch/torch.h>
+
+ #include <rice/rice.hpp>
+
+ #include "special_functions.h"
+ #include "templates.h"
+ #include "utils.h"
+
+ void init_special(Rice::Module& m) {
+   auto rb_mSpecial = Rice::define_module_under(m, "Special");
+   rb_mSpecial.add_handler<torch::Error>(handle_error);
+   add_special_functions(rb_mSpecial);
+ }
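This new file follows the per-module pattern used throughout the extension: a hand-written `init_*` function creates the Ruby module, attaches the libtorch error handler, and hands the module to a code-generated `add_*_functions` registrar. For orientation, the entry point presumably wires it up along these lines (a sketch, not part of this diff; the entry-point file name and its other `init_*` calls are assumptions):

```cpp
// Sketch only: how init_special would be called from the extension's
// entry point (assumed to live in ext.cpp, which this diff does not show).
extern "C" void Init_ext()
{
  auto m = Rice::define_module("Torch");
  // ... other init_* calls (tensor, nn, ...) ...
  init_special(m); // defines Torch::Special and its generated functions
}
```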
@@ -0,0 +1,6 @@
+ // generated by rake generate:functions
+ // do not edit by hand
+
+ #pragma once
+
+ void add_special_functions(Rice::Module& m);
@@ -4,8 +4,7 @@
  #undef isfinite
  #endif

- #include <rice/Array.hpp>
- #include <rice/Object.hpp>
+ #include <rice/rice.hpp>

  using namespace Rice;

@@ -22,6 +21,10 @@ using torch::IntArrayRef;
  using torch::ArrayRef;
  using torch::TensorList;
  using torch::Storage;
+ using ScalarList = ArrayRef<Scalar>;
+
+ using torch::nn::init::FanModeType;
+ using torch::nn::init::NonlinearityType;

  #define HANDLE_TH_ERRORS \
    try {
@@ -38,37 +41,42 @@ using torch::Storage;
  #define RETURN_NIL \
    return Qnil;

- template<>
- inline
- std::vector<int64_t> from_ruby<std::vector<int64_t>>(Object x)
- {
-   Array a = Array(x);
-   std::vector<int64_t> vec(a.size());
-   for (size_t i = 0; i < a.size(); i++) {
-     vec[i] = from_ruby<int64_t>(a[i]);
-   }
-   return vec;
- }
+ class OptionalTensor {
+   torch::Tensor value;
+   public:
+     OptionalTensor(Object o) {
+       if (o.is_nil()) {
+         value = {};
+       } else {
+         value = Rice::detail::From_Ruby<torch::Tensor>().convert(o.value());
+       }
+     }
+     OptionalTensor(torch::Tensor o) {
+       value = o;
+     }
+     operator torch::Tensor() const {
+       return value;
+     }
+ };

- template<>
- inline
- std::vector<Tensor> from_ruby<std::vector<Tensor>>(Object x)
+ namespace Rice::detail
  {
-   Array a = Array(x);
-   std::vector<Tensor> vec(a.size());
-   for (size_t i = 0; i < a.size(); i++) {
-     vec[i] = from_ruby<Tensor>(a[i]);
-   }
-   return vec;
- }
+   template<>
+   struct Type<FanModeType>
+   {
+     static bool verify()
+     {
+       return true;
+     }
+   };

- class FanModeType {
-   std::string s;
+   template<>
+   class From_Ruby<FanModeType>
+   {
    public:
-     FanModeType(Object o) {
-       s = String(o).str();
-     }
-     operator torch::nn::init::FanModeType() {
+     FanModeType convert(VALUE x)
+     {
+       auto s = String(x).str();
        if (s == "fan_in") {
          return torch::kFanIn;
        } else if (s == "fan_out") {
@@ -77,22 +85,24 @@ class FanModeType {
          throw std::runtime_error("Unsupported nonlinearity type: " + s);
        }
      }
- };
-
- template<>
- inline
- FanModeType from_ruby<FanModeType>(Object x)
- {
-   return FanModeType(x);
- }
+   };
+
+   template<>
+   struct Type<NonlinearityType>
+   {
+     static bool verify()
+     {
+       return true;
+     }
+   };

- class NonlinearityType {
-   std::string s;
+   template<>
+   class From_Ruby<NonlinearityType>
+   {
    public:
-     NonlinearityType(Object o) {
-       s = String(o).str();
-     }
-     operator torch::nn::init::NonlinearityType() {
+     NonlinearityType convert(VALUE x)
+     {
+       auto s = String(x).str();
        if (s == "linear") {
          return torch::kLinear;
        } else if (s == "conv1d") {
@@ -119,102 +129,70 @@ class NonlinearityType {
          throw std::runtime_error("Unsupported nonlinearity type: " + s);
        }
      }
- };
+   };
+
+   template<>
+   struct Type<OptionalTensor>
+   {
+     static bool verify()
+     {
+       return true;
+     }
+   };

- template<>
- inline
- NonlinearityType from_ruby<NonlinearityType>(Object x)
- {
-   return NonlinearityType(x);
- }
+   template<>
+   class From_Ruby<OptionalTensor>
+   {
+   public:
+     OptionalTensor convert(VALUE x)
+     {
+       return OptionalTensor(x);
+     }
+   };
+
+   template<>
+   struct Type<Scalar>
+   {
+     static bool verify()
+     {
+       return true;
+     }
+   };

- class OptionalTensor {
-   torch::Tensor value;
+   template<>
+   class From_Ruby<Scalar>
+   {
    public:
-     OptionalTensor(Object o) {
-       if (o.is_nil()) {
-         value = {};
+     Scalar convert(VALUE x)
+     {
+       if (FIXNUM_P(x)) {
+         return torch::Scalar(From_Ruby<int64_t>().convert(x));
        } else {
-         value = from_ruby<torch::Tensor>(o);
+         return torch::Scalar(From_Ruby<double>().convert(x));
        }
      }
-     OptionalTensor(torch::Tensor o) {
-       value = o;
-     }
-     operator torch::Tensor() const {
-       return value;
+   };
+
+   template<typename T>
+   struct Type<torch::optional<T>>
+   {
+     static bool verify()
+     {
+       return true;
      }
- };
-
- template<>
- inline
- Scalar from_ruby<Scalar>(Object x)
- {
-   if (x.rb_type() == T_FIXNUM) {
-     return torch::Scalar(from_ruby<int64_t>(x));
-   } else {
-     return torch::Scalar(from_ruby<double>(x));
-   }
- }
-
- template<>
- inline
- OptionalTensor from_ruby<OptionalTensor>(Object x)
- {
-   return OptionalTensor(x);
- }
-
- template<>
- inline
- torch::optional<torch::ScalarType> from_ruby<torch::optional<torch::ScalarType>>(Object x)
- {
-   if (x.is_nil()) {
-     return torch::nullopt;
-   } else {
-     return torch::optional<torch::ScalarType>{from_ruby<torch::ScalarType>(x)};
-   }
- }
-
- template<>
- inline
- torch::optional<int64_t> from_ruby<torch::optional<int64_t>>(Object x)
- {
-   if (x.is_nil()) {
-     return torch::nullopt;
-   } else {
-     return torch::optional<int64_t>{from_ruby<int64_t>(x)};
-   }
- }
+   };

- template<>
- inline
- torch::optional<double> from_ruby<torch::optional<double>>(Object x)
- {
-   if (x.is_nil()) {
-     return torch::nullopt;
-   } else {
-     return torch::optional<double>{from_ruby<double>(x)};
-   }
- }
-
- template<>
- inline
- torch::optional<bool> from_ruby<torch::optional<bool>>(Object x)
- {
-   if (x.is_nil()) {
-     return torch::nullopt;
-   } else {
-     return torch::optional<bool>{from_ruby<bool>(x)};
-   }
- }
-
- template<>
- inline
- torch::optional<Scalar> from_ruby<torch::optional<Scalar>>(Object x)
- {
-   if (x.is_nil()) {
-     return torch::nullopt;
-   } else {
-     return torch::optional<Scalar>{from_ruby<Scalar>(x)};
-   }
+   template<typename T>
+   class From_Ruby<torch::optional<T>>
+   {
+   public:
+     torch::optional<T> convert(VALUE x)
+     {
+       if (NIL_P(x)) {
+         return torch::nullopt;
+       } else {
+         return torch::optional<T>{From_Ruby<T>().convert(x)};
+       }
+     }
+   };

  }
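Note the consolidation at the end of `templates.h`: the five hand-written `from_ruby` specializations for `torch::optional<torch::ScalarType>`, `optional<int64_t>`, `optional<double>`, `optional<bool>`, and `optional<Scalar>` are replaced by a single generic `From_Ruby<torch::optional<T>>` that maps Ruby `nil` to `torch::nullopt` and otherwise defers to `From_Ruby<T>`. A small usage sketch (the helper name `to_optional_dim` is illustrative, not from this diff):

```cpp
// Sketch: converting a possibly-nil Ruby VALUE with the generic
// From_Ruby<torch::optional<T>> specialization defined above.
torch::optional<int64_t> to_optional_dim(VALUE x)
{
  // nil        -> torch::nullopt
  // Integer 42 -> torch::optional<int64_t>{42}
  return Rice::detail::From_Ruby<torch::optional<int64_t>>().convert(x);
}
```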
@@ -0,0 +1,320 @@
+ #include <torch/torch.h>
+
+ #include <rice/rice.hpp>
+
+ #include "tensor_functions.h"
+ #include "ruby_arg_parser.h"
+ #include "templates.h"
+ #include "utils.h"
+
+ using namespace Rice;
+ using torch::indexing::TensorIndex;
+
+ namespace Rice::detail
+ {
+   template<typename T>
+   struct Type<c10::complex<T>>
+   {
+     static bool verify()
+     {
+       return true;
+     }
+   };
+
+   template<typename T>
+   class To_Ruby<c10::complex<T>>
+   {
+   public:
+     VALUE convert(c10::complex<T> const& x)
+     {
+       return rb_dbl_complex_new(x.real(), x.imag());
+     }
+   };
+ }
+
+ template<typename T>
+ Array flat_data(Tensor& tensor) {
+   Tensor view = tensor.reshape({tensor.numel()});
+
+   Array a;
+   for (int i = 0; i < tensor.numel(); i++) {
+     a.push(view[i].item().to<T>());
+   }
+   return a;
+ }
+
+ Class rb_cTensor;
+
+ std::vector<TensorIndex> index_vector(Array a) {
+   Object obj;
+
+   std::vector<TensorIndex> indices;
+   indices.reserve(a.size());
+
+   for (long i = 0; i < a.size(); i++) {
+     obj = a[i];
+
+     if (obj.is_instance_of(rb_cInteger)) {
+       indices.push_back(Rice::detail::From_Ruby<int64_t>().convert(obj.value()));
+     } else if (obj.is_instance_of(rb_cRange)) {
+       torch::optional<int64_t> start_index = torch::nullopt;
+       torch::optional<int64_t> stop_index = torch::nullopt;
+
+       Object begin = obj.call("begin");
+       if (!begin.is_nil()) {
+         start_index = Rice::detail::From_Ruby<int64_t>().convert(begin.value());
+       }
+
+       Object end = obj.call("end");
+       if (!end.is_nil()) {
+         stop_index = Rice::detail::From_Ruby<int64_t>().convert(end.value());
+       }
+
+       Object exclude_end = obj.call("exclude_end?");
+       if (stop_index.has_value() && !exclude_end) {
+         if (stop_index.value() == -1) {
+           stop_index = torch::nullopt;
+         } else {
+           stop_index = stop_index.value() + 1;
+         }
+       }
+
+       indices.push_back(torch::indexing::Slice(start_index, stop_index));
+     } else if (obj.is_instance_of(rb_cTensor)) {
+       indices.push_back(Rice::detail::From_Ruby<Tensor>().convert(obj.value()));
+     } else if (obj.is_nil()) {
+       indices.push_back(torch::indexing::None);
+     } else if (obj == True || obj == False) {
+       indices.push_back(Rice::detail::From_Ruby<bool>().convert(obj.value()));
+     } else {
+       throw Exception(rb_eArgError, "Unsupported index type: %s", rb_obj_classname(obj));
+     }
+   }
+   return indices;
+ }
+
+ // hack (removes inputs argument)
+ // https://github.com/pytorch/pytorch/commit/2e5bfa9824f549be69a28e4705a72b4cf8a4c519
+ // TODO add support for inputs argument
+ // _backward
+ static VALUE tensor__backward(int argc, VALUE* argv, VALUE self_)
+ {
+   HANDLE_TH_ERRORS
+   Tensor& self = Rice::detail::From_Ruby<Tensor&>().convert(self_);
+   static RubyArgParser parser({
+     "_backward(Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False)"
+   });
+   ParsedArgs<4> parsed_args;
+   auto _r = parser.parse(self_, argc, argv, parsed_args);
+   // _backward(Tensor self, Tensor[] inputs, Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False) -> ()
+   auto dispatch__backward = [](const Tensor & self, TensorList inputs, const OptionalTensor & gradient, c10::optional<bool> retain_graph, bool create_graph) -> void {
+     // in future, release GVL
+     self._backward(inputs, gradient, retain_graph, create_graph);
+   };
+   dispatch__backward(self, {}, _r.optionalTensor(0), _r.toBoolOptional(1), _r.toBool(2));
+   RETURN_NIL
+   END_HANDLE_TH_ERRORS
+ }
+
+ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions) {
+   rb_cTensor = c;
+   rb_cTensor.add_handler<torch::Error>(handle_error);
+   add_tensor_functions(rb_cTensor);
+   THPVariableClass = rb_cTensor.value();
+
+   rb_define_method(rb_cTensor, "backward", (VALUE (*)(...)) tensor__backward, -1);
+
+   rb_cTensor
+     .define_method("cuda?", &torch::Tensor::is_cuda)
+     .define_method("sparse?", &torch::Tensor::is_sparse)
+     .define_method("quantized?", &torch::Tensor::is_quantized)
+     .define_method("dim", &torch::Tensor::dim)
+     .define_method("numel", &torch::Tensor::numel)
+     .define_method("element_size", &torch::Tensor::element_size)
+     .define_method("requires_grad", &torch::Tensor::requires_grad)
+     .define_method(
+       "_size",
+       [](Tensor& self, int64_t dim) {
+         return self.size(dim);
+       })
+     .define_method(
+       "_stride",
+       [](Tensor& self, int64_t dim) {
+         return self.stride(dim);
+       })
+     // in C++ for performance
+     .define_method(
+       "shape",
+       [](Tensor& self) {
+         Array a;
+         for (auto &size : self.sizes()) {
+           a.push(size);
+         }
+         return a;
+       })
+     .define_method(
+       "_strides",
+       [](Tensor& self) {
+         Array a;
+         for (auto &stride : self.strides()) {
+           a.push(stride);
+         }
+         return a;
+       })
+     .define_method(
+       "_index",
+       [](Tensor& self, Array indices) {
+         auto vec = index_vector(indices);
+         return self.index(vec);
+       })
+     .define_method(
+       "_index_put_custom",
+       [](Tensor& self, Array indices, torch::Tensor& value) {
+         auto vec = index_vector(indices);
+         return self.index_put_(vec, value);
+       })
+     .define_method(
+       "contiguous?",
+       [](Tensor& self) {
+         return self.is_contiguous();
+       })
+     .define_method(
+       "_requires_grad!",
+       [](Tensor& self, bool requires_grad) {
+         return self.set_requires_grad(requires_grad);
+       })
+     .define_method(
+       "grad",
+       [](Tensor& self) {
+         auto grad = self.grad();
+         return grad.defined() ? Object(Rice::detail::To_Ruby<torch::Tensor>().convert(grad)) : Nil;
+       })
+     .define_method(
+       "grad=",
+       [](Tensor& self, torch::Tensor& grad) {
+         self.mutable_grad() = grad;
+       })
+     .define_method(
+       "_dtype",
+       [](Tensor& self) {
+         return (int) at::typeMetaToScalarType(self.dtype());
+       })
+     .define_method(
+       "_type",
+       [](Tensor& self, int dtype) {
+         return self.toType((torch::ScalarType) dtype);
+       })
+     .define_method(
+       "_layout",
+       [](Tensor& self) {
+         std::stringstream s;
+         s << self.layout();
+         return s.str();
+       })
+     .define_method(
+       "device",
+       [](Tensor& self) {
+         std::stringstream s;
+         s << self.device();
+         return s.str();
+       })
+     .define_method(
+       "_data_str",
+       [](Tensor& self) {
+         auto tensor = self;
+
+         // move to CPU to get data
+         if (tensor.device().type() != torch::kCPU) {
+           torch::Device device("cpu");
+           tensor = tensor.to(device);
+         }
+
+         if (!tensor.is_contiguous()) {
+           tensor = tensor.contiguous();
+         }
+
+         auto data_ptr = (const char *) tensor.data_ptr();
+         return std::string(data_ptr, tensor.numel() * tensor.element_size());
+       })
+     // for TorchVision
+     .define_method(
+       "_data_ptr",
+       [](Tensor& self) {
+         return reinterpret_cast<uintptr_t>(self.data_ptr());
+       })
+     // TODO figure out a better way to do this
+     .define_method(
+       "_flat_data",
+       [](Tensor& self) {
+         auto tensor = self;
+
+         // move to CPU to get data
+         if (tensor.device().type() != torch::kCPU) {
+           torch::Device device("cpu");
+           tensor = tensor.to(device);
+         }
+
+         auto dtype = tensor.dtype();
+         if (dtype == torch::kByte) {
+           return flat_data<uint8_t>(tensor);
+         } else if (dtype == torch::kChar) {
+           return flat_data<int8_t>(tensor);
+         } else if (dtype == torch::kShort) {
+           return flat_data<int16_t>(tensor);
+         } else if (dtype == torch::kInt) {
+           return flat_data<int32_t>(tensor);
+         } else if (dtype == torch::kLong) {
+           return flat_data<int64_t>(tensor);
+         } else if (dtype == torch::kFloat) {
+           return flat_data<float>(tensor);
+         } else if (dtype == torch::kDouble) {
+           return flat_data<double>(tensor);
+         } else if (dtype == torch::kBool) {
+           return flat_data<bool>(tensor);
+         } else if (dtype == torch::kComplexFloat) {
+           return flat_data<c10::complex<float>>(tensor);
+         } else if (dtype == torch::kComplexDouble) {
+           return flat_data<c10::complex<double>>(tensor);
+         } else {
+           throw std::runtime_error("Unsupported type");
+         }
+       })
+     .define_method(
+       "_to",
+       [](Tensor& self, torch::Device device, int dtype, bool non_blocking, bool copy) {
+         return self.to(device, (torch::ScalarType) dtype, non_blocking, copy);
+       });
+
+   rb_cTensorOptions
+     .add_handler<torch::Error>(handle_error)
+     .define_method(
+       "dtype",
+       [](torch::TensorOptions& self, int dtype) {
+         return self.dtype((torch::ScalarType) dtype);
+       })
+     .define_method(
+       "layout",
+       [](torch::TensorOptions& self, const std::string& layout) {
+         torch::Layout l;
+         if (layout == "strided") {
+           l = torch::kStrided;
+         } else if (layout == "sparse") {
+           l = torch::kSparse;
+           throw std::runtime_error("Sparse layout not supported yet");
+         } else {
+           throw std::runtime_error("Unsupported layout: " + layout);
+         }
+         return self.layout(l);
+       })
+     .define_method(
+       "device",
+       [](torch::TensorOptions& self, const std::string& device) {
+         torch::Device d(device);
+         return self.device(d);
+       })
+     .define_method(
+       "requires_grad",
+       [](torch::TensorOptions& self, bool requires_grad) {
+         return self.requires_grad(requires_grad);
+       });
+ }
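The `index_vector` helper in this last hunk is what backs Ruby-style indexing (`tensor[...]`): it walks a Ruby `Array` of indices and translates each element into a libtorch `TensorIndex`. The Range branch is the subtle part, since Ruby ranges carry inclusive/exclusive ends and can have `nil` bounds; reading the code above, the mapping works out roughly as follows (illustrative comment table, not part of the diff):

```cpp
// How index_vector translates Ruby indices into TensorIndex values:
//
//   Ruby index    resulting TensorIndex
//   ----------    -------------------------------------------------
//   2             2                    (plain integer)
//   1..3          Slice(1, 4)          (inclusive end -> end + 1)
//   1...3         Slice(1, 3)          (exclusive end used as-is)
//   0..-1         Slice(0, nullopt)    (inclusive -1 means "to the end")
//   2..           Slice(2, nullopt)    (endless range: nil end)
//   nil           None                 (inserts a new axis)
//   true/false    boolean mask index
//   Tensor        advanced indexing with a tensor
```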