torch-rb 0.5.3 → 0.6.0

data/ext/torch/extconf.rb CHANGED
@@ -11,7 +11,6 @@ apple_clang = RbConfig::CONFIG["CC_VERSION_MESSAGE"] =~ /apple clang/i
 
  # check omp first
  if have_library("omp") || have_library("gomp")
- $CXXFLAGS += " -DAT_PARALLEL_OPENMP=1"
  $CXXFLAGS += " -Xclang" if apple_clang
  $CXXFLAGS += " -fopenmp"
  end
data/ext/torch/ivalue.cpp ADDED
@@ -0,0 +1,134 @@
+ #include <torch/torch.h>
+
+ #include <rice/Array.hpp>
+ #include <rice/Constructor.hpp>
+ #include <rice/Hash.hpp>
+ #include <rice/Module.hpp>
+ #include <rice/String.hpp>
+
+ #include "utils.h"
+
+ void init_ivalue(Rice::Module& m) {
+ // https://pytorch.org/cppdocs/api/structc10_1_1_i_value.html
+ Rice::define_class_under<torch::IValue>(m, "IValue")
+ .add_handler<torch::Error>(handle_error)
+ .define_constructor(Rice::Constructor<torch::IValue>())
+ .define_method("bool?", &torch::IValue::isBool)
+ .define_method("bool_list?", &torch::IValue::isBoolList)
+ .define_method("capsule?", &torch::IValue::isCapsule)
+ .define_method("custom_class?", &torch::IValue::isCustomClass)
+ .define_method("device?", &torch::IValue::isDevice)
+ .define_method("double?", &torch::IValue::isDouble)
+ .define_method("double_list?", &torch::IValue::isDoubleList)
+ .define_method("future?", &torch::IValue::isFuture)
+ // .define_method("generator?", &torch::IValue::isGenerator)
+ .define_method("generic_dict?", &torch::IValue::isGenericDict)
+ .define_method("list?", &torch::IValue::isList)
+ .define_method("int?", &torch::IValue::isInt)
+ .define_method("int_list?", &torch::IValue::isIntList)
+ .define_method("module?", &torch::IValue::isModule)
+ .define_method("none?", &torch::IValue::isNone)
+ .define_method("object?", &torch::IValue::isObject)
+ .define_method("ptr_type?", &torch::IValue::isPtrType)
+ .define_method("py_object?", &torch::IValue::isPyObject)
+ .define_method("r_ref?", &torch::IValue::isRRef)
+ .define_method("scalar?", &torch::IValue::isScalar)
+ .define_method("string?", &torch::IValue::isString)
+ .define_method("tensor?", &torch::IValue::isTensor)
+ .define_method("tensor_list?", &torch::IValue::isTensorList)
+ .define_method("tuple?", &torch::IValue::isTuple)
+ .define_method(
+ "to_bool",
+ *[](torch::IValue& self) {
+ return self.toBool();
+ })
+ .define_method(
+ "to_double",
+ *[](torch::IValue& self) {
+ return self.toDouble();
+ })
+ .define_method(
+ "to_int",
+ *[](torch::IValue& self) {
+ return self.toInt();
+ })
+ .define_method(
+ "to_list",
+ *[](torch::IValue& self) {
+ auto list = self.toListRef();
+ Rice::Array obj;
+ for (auto& elem : list) {
+ obj.push(to_ruby<torch::IValue>(torch::IValue{elem}));
+ }
+ return obj;
+ })
+ .define_method(
+ "to_string_ref",
+ *[](torch::IValue& self) {
+ return self.toStringRef();
+ })
+ .define_method(
+ "to_tensor",
+ *[](torch::IValue& self) {
+ return self.toTensor();
+ })
+ .define_method(
+ "to_generic_dict",
+ *[](torch::IValue& self) {
+ auto dict = self.toGenericDict();
+ Rice::Hash obj;
+ for (auto& pair : dict) {
+ obj[to_ruby<torch::IValue>(torch::IValue{pair.key()})] = to_ruby<torch::IValue>(torch::IValue{pair.value()});
+ }
+ return obj;
+ })
+ .define_singleton_method(
+ "from_tensor",
+ *[](torch::Tensor& v) {
+ return torch::IValue(v);
+ })
+ // TODO create specialized list types?
+ .define_singleton_method(
+ "from_list",
+ *[](Rice::Array obj) {
+ c10::impl::GenericList list(c10::AnyType::get());
+ for (auto entry : obj) {
+ list.push_back(from_ruby<torch::IValue>(entry));
+ }
+ return torch::IValue(list);
+ })
+ .define_singleton_method(
+ "from_string",
+ *[](Rice::String v) {
+ return torch::IValue(v.str());
+ })
+ .define_singleton_method(
+ "from_int",
+ *[](int64_t v) {
+ return torch::IValue(v);
+ })
+ .define_singleton_method(
+ "from_double",
+ *[](double v) {
+ return torch::IValue(v);
+ })
+ .define_singleton_method(
+ "from_bool",
+ *[](bool v) {
+ return torch::IValue(v);
+ })
+ // see https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/python/pybind_utils.h
+ // createGenericDict and toIValue
+ .define_singleton_method(
+ "from_dict",
+ *[](Rice::Hash obj) {
+ auto key_type = c10::AnyType::get();
+ auto value_type = c10::AnyType::get();
+ c10::impl::GenericDict elems(key_type, value_type);
+ elems.reserve(obj.size());
+ for (auto entry : obj) {
+ elems.insert(from_ruby<torch::IValue>(entry.first), from_ruby<torch::IValue>((Rice::Object) entry.second));
+ }
+ return torch::IValue(std::move(elems));
+ });
+ }
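
For reference, a minimal Ruby sketch of how the IValue bindings above can be exercised once the extension is loaded. It assumes the bindings land under the top-level Torch module; the method names come directly from the definitions above.

  # hypothetical usage sketch, not part of the diff
  iv = Torch::IValue.from_int(42)
  iv.int?          # => true
  iv.to_int        # => 42
  s = Torch::IValue.from_string("hello")
  s.string?        # => true
  s.to_string_ref  # => "hello"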
data/ext/torch/nn.cpp ADDED
@@ -0,0 +1,114 @@
+ #include <torch/torch.h>
+
+ #include <rice/Module.hpp>
+
+ #include "nn_functions.h"
+ #include "templates.h"
+ #include "utils.h"
+
+ // need to make a distinction between parameters and tensors
+ class Parameter: public torch::autograd::Variable {
+ public:
+ Parameter(Tensor&& t) : torch::autograd::Variable(t) { }
+ };
+
+ void init_nn(Rice::Module& m) {
+ auto rb_mNN = Rice::define_module_under(m, "NN");
+ rb_mNN.add_handler<torch::Error>(handle_error);
+ add_nn_functions(rb_mNN);
+
+ Rice::define_module_under(rb_mNN, "Init")
+ .add_handler<torch::Error>(handle_error)
+ .define_singleton_method(
+ "_calculate_gain",
+ *[](NonlinearityType nonlinearity, double param) {
+ return torch::nn::init::calculate_gain(nonlinearity, param);
+ })
+ .define_singleton_method(
+ "_uniform!",
+ *[](Tensor tensor, double low, double high) {
+ return torch::nn::init::uniform_(tensor, low, high);
+ })
+ .define_singleton_method(
+ "_normal!",
+ *[](Tensor tensor, double mean, double std) {
+ return torch::nn::init::normal_(tensor, mean, std);
+ })
+ .define_singleton_method(
+ "_constant!",
+ *[](Tensor tensor, Scalar value) {
+ return torch::nn::init::constant_(tensor, value);
+ })
+ .define_singleton_method(
+ "_ones!",
+ *[](Tensor tensor) {
+ return torch::nn::init::ones_(tensor);
+ })
+ .define_singleton_method(
+ "_zeros!",
+ *[](Tensor tensor) {
+ return torch::nn::init::zeros_(tensor);
+ })
+ .define_singleton_method(
+ "_eye!",
+ *[](Tensor tensor) {
+ return torch::nn::init::eye_(tensor);
+ })
+ .define_singleton_method(
+ "_dirac!",
+ *[](Tensor tensor) {
+ return torch::nn::init::dirac_(tensor);
+ })
+ .define_singleton_method(
+ "_xavier_uniform!",
+ *[](Tensor tensor, double gain) {
+ return torch::nn::init::xavier_uniform_(tensor, gain);
+ })
+ .define_singleton_method(
+ "_xavier_normal!",
+ *[](Tensor tensor, double gain) {
+ return torch::nn::init::xavier_normal_(tensor, gain);
+ })
+ .define_singleton_method(
+ "_kaiming_uniform!",
+ *[](Tensor tensor, double a, FanModeType mode, NonlinearityType nonlinearity) {
+ return torch::nn::init::kaiming_uniform_(tensor, a, mode, nonlinearity);
+ })
+ .define_singleton_method(
+ "_kaiming_normal!",
+ *[](Tensor tensor, double a, FanModeType mode, NonlinearityType nonlinearity) {
+ return torch::nn::init::kaiming_normal_(tensor, a, mode, nonlinearity);
+ })
+ .define_singleton_method(
+ "_orthogonal!",
+ *[](Tensor tensor, double gain) {
+ return torch::nn::init::orthogonal_(tensor, gain);
+ })
+ .define_singleton_method(
+ "_sparse!",
+ *[](Tensor tensor, double sparsity, double std) {
+ return torch::nn::init::sparse_(tensor, sparsity, std);
+ });
+
+ Rice::define_class_under<Parameter, torch::Tensor>(rb_mNN, "Parameter")
+ .add_handler<torch::Error>(handle_error)
+ .define_method(
+ "grad",
+ *[](Parameter& self) {
+ auto grad = self.grad();
+ return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
+ })
+ .define_method(
+ "grad=",
+ *[](Parameter& self, torch::Tensor& grad) {
+ self.mutable_grad() = grad;
+ })
+ .define_singleton_method(
+ "_make_subclass",
+ *[](Tensor& rd, bool requires_grad) {
+ auto data = rd.detach();
+ data.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(true);
+ auto var = data.set_requires_grad(requires_grad);
+ return Parameter(std::move(var));
+ });
+ }
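
A short, hedged Ruby sketch of the NN bindings above. The underscore-prefixed Init methods are the raw bindings (the gem's Ruby layer normally wraps them with keyword arguments), and Torch.rand is assumed here as a convenient tensor factory.

  # hypothetical usage sketch, not part of the diff
  w = Torch.rand(3, 5)                              # assumed factory for a plain tensor
  Torch::NN::Init._uniform!(w, 0.0, 1.0)            # in-place uniform init, positional bounds
  Torch::NN::Init._calculate_gain("relu", 0.01)
  p = Torch::NN::Parameter._make_subclass(w, true)  # detached copy with requires_grad set
  p.requires_grad                                   # => true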
data/ext/torch/nn_functions.h CHANGED
@@ -3,4 +3,4 @@
 
  #pragma once
 
- void add_nn_functions(Module m);
+ void add_nn_functions(Rice::Module& m);
data/ext/torch/random.cpp ADDED
@@ -0,0 +1,22 @@
+ #include <torch/torch.h>
+
+ #include <rice/Module.hpp>
+
+ #include "utils.h"
+
+ void init_random(Rice::Module& m) {
+ Rice::define_module_under(m, "Random")
+ .add_handler<torch::Error>(handle_error)
+ .define_singleton_method(
+ "initial_seed",
+ *[]() {
+ return at::detail::getDefaultCPUGenerator().current_seed();
+ })
+ .define_singleton_method(
+ "seed",
+ *[]() {
+ // TODO set for CUDA when available
+ auto generator = at::detail::getDefaultCPUGenerator();
+ return generator.seed();
+ });
+ }
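
The new Random module only exposes seed handling for the default CPU generator; a minimal sketch, assuming the module lands under Torch:

  # hypothetical usage sketch, not part of the diff
  Torch::Random.initial_seed  # current seed of the default CPU generator
  Torch::Random.seed          # reseed non-deterministically and return the new seed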
data/ext/torch/ruby_arg_parser.cpp CHANGED
@@ -487,7 +487,7 @@ static void extra_kwargs(FunctionSignature& signature, VALUE kwargs, ssize_t num
 
  VALUE missing = Qundef;
 
- bool FunctionSignature::parse(VALUE self, VALUE args, VALUE kwargs, std::vector<VALUE> &dst, // NOLINT
+ bool FunctionSignature::parse(VALUE self, VALUE args, VALUE kwargs, VALUE dst[], // NOLINT
  bool raise_exception) {
  auto nargs = NIL_P(args) ? 0 : RARRAY_LEN(args);
  ssize_t remaining_kwargs = NIL_P(kwargs) ? 0 : RHASH_SIZE(kwargs);
data/ext/torch/ruby_arg_parser.h CHANGED
@@ -2,6 +2,8 @@
 
  #pragma once
 
+ #include <sstream>
+
  #include <torch/torch.h>
  #include <rice/Exception.hpp>
 
@@ -46,7 +48,7 @@ struct FunctionParameter {
  struct FunctionSignature {
  explicit FunctionSignature(const std::string& fmt, int index);
 
- bool parse(VALUE self, VALUE args, VALUE kwargs, std::vector<VALUE>& dst, bool raise_exception);
+ bool parse(VALUE self, VALUE args, VALUE kwargs, VALUE dst[], bool raise_exception);
 
  std::string toString() const;
 
@@ -63,13 +65,13 @@ struct FunctionSignature {
  };
 
  struct RubyArgs {
- RubyArgs(const FunctionSignature& signature, std::vector<VALUE> &args)
+ RubyArgs(const FunctionSignature& signature, VALUE* args)
  : signature(signature)
  , args(args)
  , idx(signature.index) {}
 
  const FunctionSignature& signature;
- std::vector<VALUE> args;
+ VALUE* args;
  int idx;
 
  inline at::Tensor tensor(int i);
@@ -328,6 +330,12 @@ inline bool RubyArgs::isNone(int i) {
  return NIL_P(args[i]);
  }
 
+ template<int N>
+ struct ParsedArgs {
+ ParsedArgs() : args() { }
+ VALUE args[N];
+ };
+
  struct RubyArgParser {
  std::vector<FunctionSignature> signatures_;
  std::string function_name;
@@ -356,7 +364,15 @@ struct RubyArgParser {
  });
  }
 
- RubyArgs parse(VALUE self, int argc, VALUE* argv, std::vector<VALUE> &parsed_args) {
+ template<int N>
+ inline RubyArgs parse(VALUE self, int argc, VALUE* argv, ParsedArgs<N> &dst) {
+ if (N < max_args) {
+ rb_raise(rb_eArgError, "RubyArgParser: dst ParsedArgs buffer does not have enough capacity, expected %d (got %d)", (int)max_args, N);
+ }
+ return raw_parse(self, argc, argv, dst.args);
+ }
+
+ inline RubyArgs raw_parse(VALUE self, int argc, VALUE* argv, VALUE parsed_args[]) {
  VALUE args, kwargs;
  rb_scan_args(argc, argv, "*:", &args, &kwargs);
 
@@ -378,7 +394,7 @@ struct RubyArgParser {
  rb_raise(rb_eArgError, "No matching signatures");
  }
 
- void print_error(VALUE self, VALUE args, VALUE kwargs, std::vector<VALUE>& parsed_args) {
+ void print_error(VALUE self, VALUE args, VALUE kwargs, VALUE parsed_args[]) {
  ssize_t num_args = (NIL_P(args) ? 0 : RARRAY_LEN(args)) + (NIL_P(kwargs) ? 0 : RHASH_SIZE(kwargs));
  std::vector<int> plausible_idxs;
  ssize_t i = 0;
data/ext/torch/templates.cpp CHANGED
@@ -44,7 +44,7 @@ std::vector<int64_t> from_ruby<std::vector<int64_t>>(Object x)
  {
  Array a = Array(x);
  std::vector<int64_t> vec(a.size());
- for (size_t i = 0; i < a.size(); i++) {
+ for (long i = 0; i < a.size(); i++) {
  vec[i] = from_ruby<int64_t>(a[i]);
  }
  return vec;
@@ -56,7 +56,7 @@ std::vector<Tensor> from_ruby<std::vector<Tensor>>(Object x)
  {
  Array a = Array(x);
  std::vector<Tensor> vec(a.size());
- for (size_t i = 0; i < a.size(); i++) {
+ for (long i = 0; i < a.size(); i++) {
  vec[i] = from_ruby<Tensor>(a[i]);
  }
  return vec;
data/ext/torch/tensor.cpp ADDED
@@ -0,0 +1,307 @@
+ #include <torch/torch.h>
+
+ #include <rice/Constructor.hpp>
+ #include <rice/Module.hpp>
+
+ #include "tensor_functions.h"
+ #include "ruby_arg_parser.h"
+ #include "templates.h"
+ #include "utils.h"
+
+ using namespace Rice;
+ using torch::indexing::TensorIndex;
+
+ Class rb_cTensor;
+
+ std::vector<TensorIndex> index_vector(Array a) {
+ Object obj;
+
+ std::vector<TensorIndex> indices;
+ indices.reserve(a.size());
+
+ for (long i = 0; i < a.size(); i++) {
+ obj = a[i];
+
+ if (obj.is_instance_of(rb_cInteger)) {
+ indices.push_back(from_ruby<int64_t>(obj));
+ } else if (obj.is_instance_of(rb_cRange)) {
+ torch::optional<int64_t> start_index = torch::nullopt;
+ torch::optional<int64_t> stop_index = torch::nullopt;
+
+ Object begin = obj.call("begin");
+ if (!begin.is_nil()) {
+ start_index = from_ruby<int64_t>(begin);
+ }
+
+ Object end = obj.call("end");
+ if (!end.is_nil()) {
+ stop_index = from_ruby<int64_t>(end);
+ }
+
+ Object exclude_end = obj.call("exclude_end?");
+ if (stop_index.has_value() && !exclude_end) {
+ if (stop_index.value() == -1) {
+ stop_index = torch::nullopt;
+ } else {
+ stop_index = stop_index.value() + 1;
+ }
+ }
+
+ indices.push_back(torch::indexing::Slice(start_index, stop_index));
+ } else if (obj.is_instance_of(rb_cTensor)) {
+ indices.push_back(from_ruby<Tensor>(obj));
+ } else if (obj.is_nil()) {
+ indices.push_back(torch::indexing::None);
+ } else if (obj == True || obj == False) {
+ indices.push_back(from_ruby<bool>(obj));
+ } else {
+ throw Exception(rb_eArgError, "Unsupported index type: %s", rb_obj_classname(obj));
+ }
+ }
+ return indices;
+ }
+
+ // hack (removes inputs argument)
+ // https://github.com/pytorch/pytorch/commit/2e5bfa9824f549be69a28e4705a72b4cf8a4c519
+ // TODO add support for inputs argument
+ // _backward
+ static VALUE tensor__backward(int argc, VALUE* argv, VALUE self_)
+ {
+ HANDLE_TH_ERRORS
+ Tensor& self = from_ruby<Tensor&>(self_);
+ static RubyArgParser parser({
+ "_backward(Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False)"
+ });
+ ParsedArgs<4> parsed_args;
+ auto _r = parser.parse(self_, argc, argv, parsed_args);
+ // _backward(Tensor self, Tensor[] inputs, Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False) -> ()
+ auto dispatch__backward = [](const Tensor & self, TensorList inputs, const OptionalTensor & gradient, c10::optional<bool> retain_graph, bool create_graph) -> void {
+ // in future, release GVL
+ self._backward(inputs, gradient, retain_graph, create_graph);
+ };
+ dispatch__backward(self, {}, _r.optionalTensor(0), _r.toBoolOptional(1), _r.toBool(2));
+ RETURN_NIL
+ END_HANDLE_TH_ERRORS
+ }
+
+ void init_tensor(Rice::Module& m) {
+ rb_cTensor = Rice::define_class_under<torch::Tensor>(m, "Tensor");
+ rb_cTensor.add_handler<torch::Error>(handle_error);
+ add_tensor_functions(rb_cTensor);
+ THPVariableClass = rb_cTensor.value();
+
+ rb_define_method(rb_cTensor, "backward", (VALUE (*)(...)) tensor__backward, -1);
+
+ rb_cTensor
+ .define_method("cuda?", &torch::Tensor::is_cuda)
+ .define_method("sparse?", &torch::Tensor::is_sparse)
+ .define_method("quantized?", &torch::Tensor::is_quantized)
+ .define_method("dim", &torch::Tensor::dim)
+ .define_method("numel", &torch::Tensor::numel)
+ .define_method("element_size", &torch::Tensor::element_size)
+ .define_method("requires_grad", &torch::Tensor::requires_grad)
+ .define_method(
+ "_size",
+ *[](Tensor& self, int64_t dim) {
+ return self.size(dim);
+ })
+ .define_method(
+ "_stride",
+ *[](Tensor& self, int64_t dim) {
+ return self.stride(dim);
+ })
+ // in C++ for performance
+ .define_method(
+ "shape",
+ *[](Tensor& self) {
+ Array a;
+ for (auto &size : self.sizes()) {
+ a.push(size);
+ }
+ return a;
+ })
+ .define_method(
+ "_strides",
+ *[](Tensor& self) {
+ Array a;
+ for (auto &stride : self.strides()) {
+ a.push(stride);
+ }
+ return a;
+ })
+ .define_method(
+ "_index",
+ *[](Tensor& self, Array indices) {
+ auto vec = index_vector(indices);
+ return self.index(vec);
+ })
+ .define_method(
+ "_index_put_custom",
+ *[](Tensor& self, Array indices, torch::Tensor& value) {
+ auto vec = index_vector(indices);
+ return self.index_put_(vec, value);
+ })
+ .define_method(
+ "contiguous?",
+ *[](Tensor& self) {
+ return self.is_contiguous();
+ })
+ .define_method(
+ "_requires_grad!",
+ *[](Tensor& self, bool requires_grad) {
+ return self.set_requires_grad(requires_grad);
+ })
+ .define_method(
+ "grad",
+ *[](Tensor& self) {
+ auto grad = self.grad();
+ return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
+ })
+ .define_method(
+ "grad=",
+ *[](Tensor& self, torch::Tensor& grad) {
+ self.mutable_grad() = grad;
+ })
+ .define_method(
+ "_dtype",
+ *[](Tensor& self) {
+ return (int) at::typeMetaToScalarType(self.dtype());
+ })
+ .define_method(
+ "_type",
+ *[](Tensor& self, int dtype) {
+ return self.toType((torch::ScalarType) dtype);
+ })
+ .define_method(
+ "_layout",
+ *[](Tensor& self) {
+ std::stringstream s;
+ s << self.layout();
+ return s.str();
+ })
+ .define_method(
+ "device",
+ *[](Tensor& self) {
+ std::stringstream s;
+ s << self.device();
+ return s.str();
+ })
+ .define_method(
+ "_data_str",
+ *[](Tensor& self) {
+ Tensor tensor = self;
+
+ // move to CPU to get data
+ if (tensor.device().type() != torch::kCPU) {
+ torch::Device device("cpu");
+ tensor = tensor.to(device);
+ }
+
+ if (!tensor.is_contiguous()) {
+ tensor = tensor.contiguous();
+ }
+
+ auto data_ptr = (const char *) tensor.data_ptr();
+ return std::string(data_ptr, tensor.numel() * tensor.element_size());
+ })
+ // for TorchVision
+ .define_method(
+ "_data_ptr",
+ *[](Tensor& self) {
+ return reinterpret_cast<uintptr_t>(self.data_ptr());
+ })
+ // TODO figure out a better way to do this
+ .define_method(
+ "_flat_data",
+ *[](Tensor& self) {
+ Tensor tensor = self;
+
+ // move to CPU to get data
+ if (tensor.device().type() != torch::kCPU) {
+ torch::Device device("cpu");
+ tensor = tensor.to(device);
+ }
+
+ Array a;
+ auto dtype = tensor.dtype();
+
+ Tensor view = tensor.reshape({tensor.numel()});
+
+ // TODO DRY if someone knows C++
+ if (dtype == torch::kByte) {
+ for (int i = 0; i < tensor.numel(); i++) {
+ a.push(view[i].item().to<uint8_t>());
+ }
+ } else if (dtype == torch::kChar) {
+ for (int i = 0; i < tensor.numel(); i++) {
+ a.push(to_ruby<int>(view[i].item().to<int8_t>()));
+ }
+ } else if (dtype == torch::kShort) {
+ for (int i = 0; i < tensor.numel(); i++) {
+ a.push(view[i].item().to<int16_t>());
+ }
+ } else if (dtype == torch::kInt) {
+ for (int i = 0; i < tensor.numel(); i++) {
+ a.push(view[i].item().to<int32_t>());
+ }
+ } else if (dtype == torch::kLong) {
+ for (int i = 0; i < tensor.numel(); i++) {
+ a.push(view[i].item().to<int64_t>());
+ }
+ } else if (dtype == torch::kFloat) {
+ for (int i = 0; i < tensor.numel(); i++) {
+ a.push(view[i].item().to<float>());
+ }
+ } else if (dtype == torch::kDouble) {
+ for (int i = 0; i < tensor.numel(); i++) {
+ a.push(view[i].item().to<double>());
+ }
+ } else if (dtype == torch::kBool) {
+ for (int i = 0; i < tensor.numel(); i++) {
+ a.push(view[i].item().to<bool>() ? True : False);
+ }
+ } else {
+ throw std::runtime_error("Unsupported type");
+ }
+ return a;
+ })
+ .define_method(
+ "_to",
+ *[](Tensor& self, torch::Device device, int dtype, bool non_blocking, bool copy) {
+ return self.to(device, (torch::ScalarType) dtype, non_blocking, copy);
+ });
+
+ Rice::define_class_under<torch::TensorOptions>(m, "TensorOptions")
+ .add_handler<torch::Error>(handle_error)
+ .define_constructor(Rice::Constructor<torch::TensorOptions>())
+ .define_method(
+ "dtype",
+ *[](torch::TensorOptions& self, int dtype) {
+ return self.dtype((torch::ScalarType) dtype);
+ })
+ .define_method(
+ "layout",
+ *[](torch::TensorOptions& self, const std::string& layout) {
+ torch::Layout l;
+ if (layout == "strided") {
+ l = torch::kStrided;
+ } else if (layout == "sparse") {
+ l = torch::kSparse;
+ throw std::runtime_error("Sparse layout not supported yet");
+ } else {
+ throw std::runtime_error("Unsupported layout: " + layout);
+ }
+ return self.layout(l);
+ })
+ .define_method(
+ "device",
+ *[](torch::TensorOptions& self, const std::string& device) {
+ torch::Device d(device);
+ return self.device(d);
+ })
+ .define_method(
+ "requires_grad",
+ *[](torch::TensorOptions& self, bool requires_grad) {
+ return self.requires_grad(requires_grad);
+ });
+ }
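
Finally, a hedged Ruby sketch of the Tensor and TensorOptions bindings above. Torch.tensor is assumed as the usual factory; the underscore-prefixed methods are the raw bindings that the gem's Ruby layer wraps (for example, Tensor#[] builds the index array passed to _index).

  # hypothetical usage sketch, not part of the diff
  t = Torch.tensor([[1, 2, 3], [4, 5, 6]])
  t.shape               # => [2, 3]
  t.dim                 # => 2
  t.contiguous?         # => true
  t._index([0, 1..2])   # raw form of t[0, 1..2]; the Range becomes Slice(1, 3)
  opts = Torch::TensorOptions.new.dtype(6).device("cpu").requires_grad(true)
  # 6 is assumed here to be the integer ScalarType code for float32 in LibTorch's enum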