torch-rb 0.4.2 → 0.6.0

@@ -0,0 +1,307 @@
+ #include <torch/torch.h>
+
+ #include <rice/Constructor.hpp>
+ #include <rice/Module.hpp>
+
+ #include "tensor_functions.h"
+ #include "ruby_arg_parser.h"
+ #include "templates.h"
+ #include "utils.h"
+
+ using namespace Rice;
+ using torch::indexing::TensorIndex;
+
+ Class rb_cTensor;
+
+ std::vector<TensorIndex> index_vector(Array a) {
+   Object obj;
+
+   std::vector<TensorIndex> indices;
+   indices.reserve(a.size());
+
+   for (long i = 0; i < a.size(); i++) {
+     obj = a[i];
+
+     if (obj.is_instance_of(rb_cInteger)) {
+       indices.push_back(from_ruby<int64_t>(obj));
+     } else if (obj.is_instance_of(rb_cRange)) {
+       torch::optional<int64_t> start_index = torch::nullopt;
+       torch::optional<int64_t> stop_index = torch::nullopt;
+
+       Object begin = obj.call("begin");
+       if (!begin.is_nil()) {
+         start_index = from_ruby<int64_t>(begin);
+       }
+
+       Object end = obj.call("end");
+       if (!end.is_nil()) {
+         stop_index = from_ruby<int64_t>(end);
+       }
+
+       Object exclude_end = obj.call("exclude_end?");
+       if (stop_index.has_value() && !exclude_end) {
+         if (stop_index.value() == -1) {
+           stop_index = torch::nullopt;
+         } else {
+           stop_index = stop_index.value() + 1;
+         }
+       }
+
+       indices.push_back(torch::indexing::Slice(start_index, stop_index));
+     } else if (obj.is_instance_of(rb_cTensor)) {
+       indices.push_back(from_ruby<Tensor>(obj));
+     } else if (obj.is_nil()) {
+       indices.push_back(torch::indexing::None);
+     } else if (obj == True || obj == False) {
+       indices.push_back(from_ruby<bool>(obj));
+     } else {
+       throw Exception(rb_eArgError, "Unsupported index type: %s", rb_obj_classname(obj));
+     }
+   }
+   return indices;
+ }
+
+ // hack (removes inputs argument)
+ // https://github.com/pytorch/pytorch/commit/2e5bfa9824f549be69a28e4705a72b4cf8a4c519
+ // TODO add support for inputs argument
+ // _backward
+ static VALUE tensor__backward(int argc, VALUE* argv, VALUE self_)
+ {
+   HANDLE_TH_ERRORS
+   Tensor& self = from_ruby<Tensor&>(self_);
+   static RubyArgParser parser({
+     "_backward(Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False)"
+   });
+   ParsedArgs<4> parsed_args;
+   auto _r = parser.parse(self_, argc, argv, parsed_args);
+   // _backward(Tensor self, Tensor[] inputs, Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False) -> ()
+   auto dispatch__backward = [](const Tensor & self, TensorList inputs, const OptionalTensor & gradient, c10::optional<bool> retain_graph, bool create_graph) -> void {
+     // in future, release GVL
+     self._backward(inputs, gradient, retain_graph, create_graph);
+   };
+   dispatch__backward(self, {}, _r.optionalTensor(0), _r.toBoolOptional(1), _r.toBool(2));
+   RETURN_NIL
+   END_HANDLE_TH_ERRORS
+ }
+
+ void init_tensor(Rice::Module& m) {
+   rb_cTensor = Rice::define_class_under<torch::Tensor>(m, "Tensor");
+   rb_cTensor.add_handler<torch::Error>(handle_error);
+   add_tensor_functions(rb_cTensor);
+   THPVariableClass = rb_cTensor.value();
+
+   rb_define_method(rb_cTensor, "backward", (VALUE (*)(...)) tensor__backward, -1);
+
+   rb_cTensor
+     .define_method("cuda?", &torch::Tensor::is_cuda)
+     .define_method("sparse?", &torch::Tensor::is_sparse)
+     .define_method("quantized?", &torch::Tensor::is_quantized)
+     .define_method("dim", &torch::Tensor::dim)
+     .define_method("numel", &torch::Tensor::numel)
+     .define_method("element_size", &torch::Tensor::element_size)
+     .define_method("requires_grad", &torch::Tensor::requires_grad)
+     .define_method(
+       "_size",
+       *[](Tensor& self, int64_t dim) {
+         return self.size(dim);
+       })
+     .define_method(
+       "_stride",
+       *[](Tensor& self, int64_t dim) {
+         return self.stride(dim);
+       })
+     // in C++ for performance
+     .define_method(
+       "shape",
+       *[](Tensor& self) {
+         Array a;
+         for (auto &size : self.sizes()) {
+           a.push(size);
+         }
+         return a;
+       })
+     .define_method(
+       "_strides",
+       *[](Tensor& self) {
+         Array a;
+         for (auto &stride : self.strides()) {
+           a.push(stride);
+         }
+         return a;
+       })
+     .define_method(
+       "_index",
+       *[](Tensor& self, Array indices) {
+         auto vec = index_vector(indices);
+         return self.index(vec);
+       })
+     .define_method(
+       "_index_put_custom",
+       *[](Tensor& self, Array indices, torch::Tensor& value) {
+         auto vec = index_vector(indices);
+         return self.index_put_(vec, value);
+       })
+     .define_method(
+       "contiguous?",
+       *[](Tensor& self) {
+         return self.is_contiguous();
+       })
+     .define_method(
+       "_requires_grad!",
+       *[](Tensor& self, bool requires_grad) {
+         return self.set_requires_grad(requires_grad);
+       })
+     .define_method(
+       "grad",
+       *[](Tensor& self) {
+         auto grad = self.grad();
+         return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
+       })
+     .define_method(
+       "grad=",
+       *[](Tensor& self, torch::Tensor& grad) {
+         self.mutable_grad() = grad;
+       })
+     .define_method(
+       "_dtype",
+       *[](Tensor& self) {
+         return (int) at::typeMetaToScalarType(self.dtype());
+       })
+     .define_method(
+       "_type",
+       *[](Tensor& self, int dtype) {
+         return self.toType((torch::ScalarType) dtype);
+       })
+     .define_method(
+       "_layout",
+       *[](Tensor& self) {
+         std::stringstream s;
+         s << self.layout();
+         return s.str();
+       })
+     .define_method(
+       "device",
+       *[](Tensor& self) {
+         std::stringstream s;
+         s << self.device();
+         return s.str();
+       })
+     .define_method(
+       "_data_str",
+       *[](Tensor& self) {
+         Tensor tensor = self;
+
+         // move to CPU to get data
+         if (tensor.device().type() != torch::kCPU) {
+           torch::Device device("cpu");
+           tensor = tensor.to(device);
+         }
+
+         if (!tensor.is_contiguous()) {
+           tensor = tensor.contiguous();
+         }
+
+         auto data_ptr = (const char *) tensor.data_ptr();
+         return std::string(data_ptr, tensor.numel() * tensor.element_size());
+       })
+     // for TorchVision
+     .define_method(
+       "_data_ptr",
+       *[](Tensor& self) {
+         return reinterpret_cast<uintptr_t>(self.data_ptr());
+       })
+     // TODO figure out a better way to do this
+     .define_method(
+       "_flat_data",
+       *[](Tensor& self) {
+         Tensor tensor = self;
+
+         // move to CPU to get data
+         if (tensor.device().type() != torch::kCPU) {
+           torch::Device device("cpu");
+           tensor = tensor.to(device);
+         }
+
+         Array a;
+         auto dtype = tensor.dtype();
+
+         Tensor view = tensor.reshape({tensor.numel()});
+
+         // TODO DRY if someone knows C++
+         if (dtype == torch::kByte) {
+           for (int i = 0; i < tensor.numel(); i++) {
+             a.push(view[i].item().to<uint8_t>());
+           }
+         } else if (dtype == torch::kChar) {
+           for (int i = 0; i < tensor.numel(); i++) {
+             a.push(to_ruby<int>(view[i].item().to<int8_t>()));
+           }
+         } else if (dtype == torch::kShort) {
+           for (int i = 0; i < tensor.numel(); i++) {
+             a.push(view[i].item().to<int16_t>());
+           }
+         } else if (dtype == torch::kInt) {
+           for (int i = 0; i < tensor.numel(); i++) {
+             a.push(view[i].item().to<int32_t>());
+           }
+         } else if (dtype == torch::kLong) {
+           for (int i = 0; i < tensor.numel(); i++) {
+             a.push(view[i].item().to<int64_t>());
+           }
+         } else if (dtype == torch::kFloat) {
+           for (int i = 0; i < tensor.numel(); i++) {
+             a.push(view[i].item().to<float>());
+           }
+         } else if (dtype == torch::kDouble) {
+           for (int i = 0; i < tensor.numel(); i++) {
+             a.push(view[i].item().to<double>());
+           }
+         } else if (dtype == torch::kBool) {
+           for (int i = 0; i < tensor.numel(); i++) {
+             a.push(view[i].item().to<bool>() ? True : False);
+           }
+         } else {
+           throw std::runtime_error("Unsupported type");
+         }
+         return a;
+       })
+     .define_method(
+       "_to",
+       *[](Tensor& self, torch::Device device, int dtype, bool non_blocking, bool copy) {
+         return self.to(device, (torch::ScalarType) dtype, non_blocking, copy);
+       });
+
+   Rice::define_class_under<torch::TensorOptions>(m, "TensorOptions")
+     .add_handler<torch::Error>(handle_error)
+     .define_constructor(Rice::Constructor<torch::TensorOptions>())
+     .define_method(
+       "dtype",
+       *[](torch::TensorOptions& self, int dtype) {
+         return self.dtype((torch::ScalarType) dtype);
+       })
+     .define_method(
+       "layout",
+       *[](torch::TensorOptions& self, const std::string& layout) {
+         torch::Layout l;
+         if (layout == "strided") {
+           l = torch::kStrided;
+         } else if (layout == "sparse") {
+           l = torch::kSparse;
+           throw std::runtime_error("Sparse layout not supported yet");
+         } else {
+           throw std::runtime_error("Unsupported layout: " + layout);
+         }
+         return self.layout(l);
+       })
+     .define_method(
+       "device",
+       *[](torch::TensorOptions& self, const std::string& device) {
+         torch::Device d(device);
+         return self.device(d);
+       })
+     .define_method(
+       "requires_grad",
+       *[](torch::TensorOptions& self, bool requires_grad) {
+         return self.requires_grad(requires_grad);
+       });
+ }
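
For orientation, `index_vector` is what backs Ruby-side tensor indexing: it converts Integers, Ranges (including beginless/endless ranges and `..-1`), Tensors, `nil`, and booleans into libtorch `TensorIndex` values. A minimal usage sketch in Ruby, assuming the `Tensor#[]` wrapper in lib/torch/tensor.rb forwards its arguments to `_index`:

    x = Torch.tensor([[1, 2, 3], [4, 5, 6]])
    x[0]        # Integer index -> first row
    x[0, 1..2]  # Range maps to an inclusive slice
    x[0, 1..]   # endless range -> open-ended slice
    x[nil]      # nil inserts a new axis (torch::indexing::None)
    x[x > 3]    # boolean mask tensor index
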
@@ -3,4 +3,4 @@

  #pragma once

- void add_tensor_functions(Module m);
+ void add_tensor_functions(Rice::Module& m);
@@ -0,0 +1,86 @@
+ #include <torch/torch.h>
+
+ #include <rice/Module.hpp>
+
+ #include "torch_functions.h"
+ #include "templates.h"
+ #include "utils.h"
+
+ void init_torch(Rice::Module& m) {
+   m.add_handler<torch::Error>(handle_error);
+   add_torch_functions(m);
+   m.define_singleton_method(
+       "grad_enabled?",
+       *[]() {
+         return torch::GradMode::is_enabled();
+       })
+     .define_singleton_method(
+       "_set_grad_enabled",
+       *[](bool enabled) {
+         torch::GradMode::set_enabled(enabled);
+       })
+     .define_singleton_method(
+       "manual_seed",
+       *[](uint64_t seed) {
+         return torch::manual_seed(seed);
+       })
+     // config
+     .define_singleton_method(
+       "show_config",
+       *[] {
+         return torch::show_config();
+       })
+     .define_singleton_method(
+       "parallel_info",
+       *[] {
+         return torch::get_parallel_info();
+       })
+     // begin operations
+     .define_singleton_method(
+       "_save",
+       *[](const torch::IValue &value) {
+         auto v = torch::pickle_save(value);
+         std::string str(v.begin(), v.end());
+         return str;
+       })
+     .define_singleton_method(
+       "_load",
+       *[](const std::string &s) {
+         std::vector<char> v;
+         std::copy(s.begin(), s.end(), std::back_inserter(v));
+         // https://github.com/pytorch/pytorch/issues/20356#issuecomment-567663701
+         return torch::pickle_load(v);
+       })
+     .define_singleton_method(
+       "_from_blob",
+       *[](Rice::String s, std::vector<int64_t> size, const torch::TensorOptions &options) {
+         void *data = const_cast<char *>(s.c_str());
+         return torch::from_blob(data, size, options);
+       })
+     .define_singleton_method(
+       "_tensor",
+       *[](Rice::Array a, std::vector<int64_t> size, const torch::TensorOptions &options) {
+         auto dtype = options.dtype();
+         torch::Tensor t;
+         if (dtype == torch::kBool) {
+           std::vector<uint8_t> vec;
+           for (long i = 0; i < a.size(); i++) {
+             vec.push_back(from_ruby<bool>(a[i]));
+           }
+           t = torch::tensor(vec, options);
+         } else {
+           std::vector<float> vec;
+           for (long i = 0; i < a.size(); i++) {
+             vec.push_back(from_ruby<float>(a[i]));
+           }
+           // hack for requires_grad error
+           if (options.requires_grad()) {
+             t = torch::tensor(vec, options.requires_grad(c10::nullopt));
+             t.set_requires_grad(true);
+           } else {
+             t = torch::tensor(vec, options);
+           }
+         }
+         return t.reshape(size);
+       });
+ }
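
These singleton bindings surface directly as methods on the `Torch` module; a quick sketch of typical calls, using only names defined above:

    Torch.manual_seed(42)      # seed the global RNG
    Torch.grad_enabled?        # => true unless gradients have been disabled
    puts Torch.show_config     # libtorch build configuration
    puts Torch.parallel_info   # threading / parallelism details
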
@@ -3,4 +3,4 @@

  #pragma once

- void add_torch_functions(Module m);
+ void add_torch_functions(Rice::Module& m);
data/ext/torch/utils.h CHANGED
@@ -1,13 +1,20 @@
  #pragma once

+ #include <rice/Exception.hpp>
  #include <rice/Symbol.hpp>

+ // TODO find better place
+ inline void handle_error(torch::Error const & ex)
+ {
+   throw Rice::Exception(rb_eRuntimeError, ex.what_without_backtrace());
+ }
+
  // keep THP prefix for now to make it easier to compare code

  extern VALUE THPVariableClass;

  inline VALUE THPUtils_internSymbol(const std::string& str) {
-   return Symbol(str);
+   return Rice::Symbol(str);
  }

  inline std::string THPUtils_unpackSymbol(VALUE obj) {
@@ -90,3 +90,10 @@ inline Object wrap(torch::TensorList x) {
    }
    return Object(a);
  }
+
+ inline Object wrap(std::tuple<double, double> x) {
+   Array a;
+   a.push(to_ruby<double>(std::get<0>(x)));
+   a.push(to_ruby<double>(std::get<1>(x)));
+   return Object(a);
+ }
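
With this new overload, any generated binding whose C++ return type is `std::tuple<double, double>` comes back to Ruby as a plain two-element Array. Purely illustrative (the method name below is hypothetical, not taken from the diff):

    lo, hi = some_op_returning_two_doubles(x)  # hypothetical binding; the pair is wrapped as [lo, hi]
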
data/lib/torch.rb CHANGED
@@ -261,6 +261,8 @@ module Torch
    Torch._from_blob(bytes, [bytes.bytesize], TensorOptions.new.dtype(DTYPE_TO_ENUM[dtype]))
  elsif args.size == 1 && args.first.is_a?(Array)
    Torch.tensor(args.first, dtype: dtype, device: device)
+ elsif args.size == 0
+   Torch.empty(0, dtype: dtype, device: device)
  else
    Torch.empty(*args, dtype: dtype, device: device)
  end
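
The new `args.size == 0` branch lets the legacy dtype constructors be called with no arguments. A small sketch, assuming the usual `Torch::FloatTensor` class produced by `_make_tensor_class` (that class definition is not shown in this hunk):

    Torch::FloatTensor.new        # presumably an empty (0-element) float tensor after this change
    Torch::FloatTensor.new(2, 3)  # unchanged: uninitialized 2x3 tensor via Torch.empty
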
@@ -335,25 +337,24 @@ module Torch
    }
  end

- def no_grad
-   previous_value = grad_enabled?
-   begin
-     _set_grad_enabled(false)
-     yield
-   ensure
-     _set_grad_enabled(previous_value)
-   end
+ def no_grad(&block)
+   grad_enabled(false, &block)
+ end
+
+ def enable_grad(&block)
+   grad_enabled(true, &block)
  end

- def enable_grad
+ def grad_enabled(value)
    previous_value = grad_enabled?
    begin
-     _set_grad_enabled(true)
+     _set_grad_enabled(value)
      yield
    ensure
      _set_grad_enabled(previous_value)
    end
  end
+ alias_method :set_grad_enabled, :grad_enabled

  def device(str)
    Device.new(str)
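
The refactor folds `no_grad` and `enable_grad` into a single block-based `grad_enabled` helper, with a `set_grad_enabled` alias. A minimal usage sketch:

    x = Torch.ones(2, 2, requires_grad: true)

    Torch.no_grad do
      y = x * 2   # gradient tracking is off inside the block
    end

    Torch.set_grad_enabled(false) do
      # same effect via the new alias; the previous setting is restored afterwards
    end
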
@@ -434,7 +435,8 @@ module Torch
    zeros(input.size, **like_options(input, options))
  end

- def stft(input, n_fft, hop_length: nil, win_length: nil, window: nil, center: true, pad_mode: "reflect", normalized: false, onesided: true)
+ # center option
+ def stft(input, n_fft, hop_length: nil, win_length: nil, window: nil, center: true, pad_mode: "reflect", normalized: false, onesided: true, return_complex: nil)
    if center
      signal_dim = input.dim
      extended_shape = [1] * (3 - signal_dim) + input.size
@@ -442,12 +444,7 @@ module Torch
      input = NN::F.pad(input.view(extended_shape), [pad, pad], mode: pad_mode)
      input = input.view(input.shape[-signal_dim..-1])
    end
-   _stft(input, n_fft, hop_length, win_length, window, normalized, onesided)
- end
-
- def clamp(tensor, min, max)
-   tensor = _clamp_min(tensor, min)
-   _clamp_max(tensor, max)
+   _stft(input, n_fft, hop_length, win_length, window, normalized, onesided, return_complex)
  end

  private
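
`stft` now forwards a `return_complex` argument to `_stft`, and the hand-written `clamp` shim is gone (clamp presumably resolves to the generated native binding in these versions). A hedged sketch:

    signal = Torch.randn(1024)
    spec = Torch.stft(signal, 256, return_complex: false)  # new keyword passed through to _stft

    x = Torch.tensor([-1.5, 0.2, 3.0])
    Torch.clamp(x, 0, 1)  # still available, via the generated bindings rather than the removed Ruby shim
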