torch-rb 0.4.2 → 0.6.0
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
- checksums.yaml +4 -4
- data/CHANGELOG.md +26 -0
- data/README.md +13 -3
- data/codegen/generate_functions.rb +20 -13
- data/codegen/native_functions.yaml +4129 -1521
- data/ext/torch/cuda.cpp +14 -0
- data/ext/torch/device.cpp +21 -0
- data/ext/torch/ext.cpp +17 -623
- data/ext/torch/extconf.rb +0 -1
- data/ext/torch/ivalue.cpp +134 -0
- data/ext/torch/nn.cpp +114 -0
- data/ext/torch/nn_functions.h +1 -1
- data/ext/torch/random.cpp +22 -0
- data/ext/torch/ruby_arg_parser.cpp +1 -1
- data/ext/torch/ruby_arg_parser.h +47 -7
- data/ext/torch/templates.h +3 -2
- data/ext/torch/tensor.cpp +307 -0
- data/ext/torch/tensor_functions.h +1 -1
- data/ext/torch/torch.cpp +86 -0
- data/ext/torch/torch_functions.h +1 -1
- data/ext/torch/utils.h +8 -1
- data/ext/torch/wrap_outputs.h +7 -0
- data/lib/torch.rb +14 -17
- data/lib/torch/nn/linear.rb +2 -0
- data/lib/torch/nn/module.rb +107 -21
- data/lib/torch/nn/parameter.rb +1 -1
- data/lib/torch/optim/adadelta.rb +2 -2
- data/lib/torch/optim/adagrad.rb +2 -2
- data/lib/torch/optim/adam.rb +2 -2
- data/lib/torch/optim/adamax.rb +1 -1
- data/lib/torch/optim/adamw.rb +2 -2
- data/lib/torch/optim/rmsprop.rb +3 -3
- data/lib/torch/optim/rprop.rb +1 -1
- data/lib/torch/tensor.rb +9 -0
- data/lib/torch/utils/data/data_loader.rb +1 -1
- data/lib/torch/version.rb +1 -1
- metadata +12 -89
data/ext/torch/cuda.cpp
ADDED
@@ -0,0 +1,14 @@
+#include <torch/torch.h>
+
+#include <rice/Module.hpp>
+
+#include "utils.h"
+
+void init_cuda(Rice::Module& m) {
+  Rice::define_module_under(m, "CUDA")
+    .add_handler<torch::Error>(handle_error)
+    .define_singleton_method("available?", &torch::cuda::is_available)
+    .define_singleton_method("device_count", &torch::cuda::device_count)
+    .define_singleton_method("manual_seed", &torch::cuda::manual_seed)
+    .define_singleton_method("manual_seed_all", &torch::cuda::manual_seed_all);
+}
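The new cuda.cpp isolates the CUDA bindings in their own init function. Once the extension is compiled, the four singleton methods defined above surface in Ruby roughly as follows (a minimal usage sketch; the return values depend on the machine):

    require "torch"

    # query the CUDA runtime through the bindings defined above
    Torch::CUDA.available?    # => true or false
    Torch::CUDA.device_count  # => 0 on a CPU-only machine

    # seed the CUDA RNGs (effectively no-ops without a GPU)
    Torch::CUDA.manual_seed(42)
    Torch::CUDA.manual_seed_all(42)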
data/ext/torch/device.cpp
ADDED
@@ -0,0 +1,21 @@
+#include <torch/torch.h>
+
+#include <rice/Constructor.hpp>
+#include <rice/Module.hpp>
+
+#include "utils.h"
+
+void init_device(Rice::Module& m) {
+  Rice::define_class_under<torch::Device>(m, "Device")
+    .add_handler<torch::Error>(handle_error)
+    .define_constructor(Rice::Constructor<torch::Device, std::string>())
+    .define_method("index", &torch::Device::index)
+    .define_method("index?", &torch::Device::has_index)
+    .define_method(
+      "type",
+      *[](torch::Device& self) {
+        std::stringstream s;
+        s << self.type();
+        return s.str();
+      });
+}
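Correspondingly, the Torch::Device class bound above can be exercised from Ruby like this (a small sketch; `type` returns the device type as a string via the stringstream shown in the hunk):

    require "torch"

    device = Torch::Device.new("cuda:0")
    device.type    # => "cuda"
    device.index   # => 0
    device.index?  # => true

    Torch::Device.new("cpu").index?  # => false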
data/ext/torch/ext.cpp
CHANGED
@@ -1,632 +1,26 @@
-#include <
+#include <rice/Module.hpp>
 
-
+void init_nn(Rice::Module& m);
+void init_tensor(Rice::Module& m);
+void init_torch(Rice::Module& m);
 
-
-
-
-
-
-#include "templates.h"
-#include "utils.h"
-
-// generated with:
-// rake generate:functions
-#include "torch_functions.h"
-#include "tensor_functions.h"
-#include "nn_functions.h"
-
-using namespace Rice;
-using torch::indexing::TensorIndex;
-
-// need to make a distinction between parameters and tensors
-class Parameter: public torch::autograd::Variable {
-  public:
-    Parameter(Tensor&& t) : torch::autograd::Variable(t) { }
-};
-
-void handle_error(torch::Error const & ex)
-{
-  throw Exception(rb_eRuntimeError, ex.what_without_backtrace());
-}
-
-Class rb_cTensor;
-
-std::vector<TensorIndex> index_vector(Array a) {
-  Object obj;
-
-  std::vector<TensorIndex> indices;
-  indices.reserve(a.size());
-
-  for (size_t i = 0; i < a.size(); i++) {
-    obj = a[i];
-
-    if (obj.is_instance_of(rb_cInteger)) {
-      indices.push_back(from_ruby<int64_t>(obj));
-    } else if (obj.is_instance_of(rb_cRange)) {
-      torch::optional<int64_t> start_index = from_ruby<int64_t>(obj.call("begin"));
-      torch::optional<int64_t> stop_index = -1;
-
-      Object end = obj.call("end");
-      if (!end.is_nil()) {
-        stop_index = from_ruby<int64_t>(end);
-      }
-
-      Object exclude_end = obj.call("exclude_end?");
-      if (!exclude_end) {
-        if (stop_index.value() == -1) {
-          stop_index = torch::nullopt;
-        } else {
-          stop_index = stop_index.value() + 1;
-        }
-      }
-
-      indices.push_back(torch::indexing::Slice(start_index, stop_index));
-    } else if (obj.is_instance_of(rb_cTensor)) {
-      indices.push_back(from_ruby<Tensor>(obj));
-    } else if (obj.is_nil()) {
-      indices.push_back(torch::indexing::None);
-    } else if (obj == True || obj == False) {
-      indices.push_back(from_ruby<bool>(obj));
-    } else {
-      throw Exception(rb_eArgError, "Unsupported index type: %s", rb_obj_classname(obj));
-    }
-  }
-  return indices;
-}
+void init_cuda(Rice::Module& m);
+void init_device(Rice::Module& m);
+void init_ivalue(Rice::Module& m);
+void init_random(Rice::Module& m);
 
 extern "C"
 void Init_ext()
 {
-
-  rb_mTorch.add_handler<torch::Error>(handle_error);
-  add_torch_functions(rb_mTorch);
-
-  rb_cTensor = define_class_under<torch::Tensor>(rb_mTorch, "Tensor");
-  rb_cTensor.add_handler<torch::Error>(handle_error);
-  add_tensor_functions(rb_cTensor);
-  THPVariableClass = rb_cTensor.value();
-
-  Module rb_mNN = define_module_under(rb_mTorch, "NN");
-  rb_mNN.add_handler<torch::Error>(handle_error);
-  add_nn_functions(rb_mNN);
-
-  Module rb_mRandom = define_module_under(rb_mTorch, "Random")
-    .add_handler<torch::Error>(handle_error)
-    .define_singleton_method(
-      "initial_seed",
-      *[]() {
-        return at::detail::getDefaultCPUGenerator().current_seed();
-      })
-    .define_singleton_method(
-      "seed",
-      *[]() {
-        // TODO set for CUDA when available
-        auto generator = at::detail::getDefaultCPUGenerator();
-        return generator.seed();
-      });
-
-  // https://pytorch.org/cppdocs/api/structc10_1_1_i_value.html
-  Class rb_cIValue = define_class_under<torch::IValue>(rb_mTorch, "IValue")
-    .add_handler<torch::Error>(handle_error)
-    .define_constructor(Constructor<torch::IValue>())
-    .define_method("bool?", &torch::IValue::isBool)
-    .define_method("bool_list?", &torch::IValue::isBoolList)
-    .define_method("capsule?", &torch::IValue::isCapsule)
-    .define_method("custom_class?", &torch::IValue::isCustomClass)
-    .define_method("device?", &torch::IValue::isDevice)
-    .define_method("double?", &torch::IValue::isDouble)
-    .define_method("double_list?", &torch::IValue::isDoubleList)
-    .define_method("future?", &torch::IValue::isFuture)
-    // .define_method("generator?", &torch::IValue::isGenerator)
-    .define_method("generic_dict?", &torch::IValue::isGenericDict)
-    .define_method("list?", &torch::IValue::isList)
-    .define_method("int?", &torch::IValue::isInt)
-    .define_method("int_list?", &torch::IValue::isIntList)
-    .define_method("module?", &torch::IValue::isModule)
-    .define_method("none?", &torch::IValue::isNone)
-    .define_method("object?", &torch::IValue::isObject)
-    .define_method("ptr_type?", &torch::IValue::isPtrType)
-    .define_method("py_object?", &torch::IValue::isPyObject)
-    .define_method("r_ref?", &torch::IValue::isRRef)
-    .define_method("scalar?", &torch::IValue::isScalar)
-    .define_method("string?", &torch::IValue::isString)
-    .define_method("tensor?", &torch::IValue::isTensor)
-    .define_method("tensor_list?", &torch::IValue::isTensorList)
-    .define_method("tuple?", &torch::IValue::isTuple)
-    .define_method(
-      "to_bool",
-      *[](torch::IValue& self) {
-        return self.toBool();
-      })
-    .define_method(
-      "to_double",
-      *[](torch::IValue& self) {
-        return self.toDouble();
-      })
-    .define_method(
-      "to_int",
-      *[](torch::IValue& self) {
-        return self.toInt();
-      })
-    .define_method(
-      "to_list",
-      *[](torch::IValue& self) {
-        auto list = self.toListRef();
-        Array obj;
-        for (auto& elem : list) {
-          obj.push(to_ruby<torch::IValue>(torch::IValue{elem}));
-        }
-        return obj;
-      })
-    .define_method(
-      "to_string_ref",
-      *[](torch::IValue& self) {
-        return self.toStringRef();
-      })
-    .define_method(
-      "to_tensor",
-      *[](torch::IValue& self) {
-        return self.toTensor();
-      })
-    .define_method(
-      "to_generic_dict",
-      *[](torch::IValue& self) {
-        auto dict = self.toGenericDict();
-        Hash obj;
-        for (auto& pair : dict) {
-          obj[to_ruby<torch::IValue>(torch::IValue{pair.key()})] = to_ruby<torch::IValue>(torch::IValue{pair.value()});
-        }
-        return obj;
-      })
-    .define_singleton_method(
-      "from_tensor",
-      *[](torch::Tensor& v) {
-        return torch::IValue(v);
-      })
-    // TODO create specialized list types?
-    .define_singleton_method(
-      "from_list",
-      *[](Array obj) {
-        c10::impl::GenericList list(c10::AnyType::get());
-        for (auto entry : obj) {
-          list.push_back(from_ruby<torch::IValue>(entry));
-        }
-        return torch::IValue(list);
-      })
-    .define_singleton_method(
-      "from_string",
-      *[](String v) {
-        return torch::IValue(v.str());
-      })
-    .define_singleton_method(
-      "from_int",
-      *[](int64_t v) {
-        return torch::IValue(v);
-      })
-    .define_singleton_method(
-      "from_double",
-      *[](double v) {
-        return torch::IValue(v);
-      })
-    .define_singleton_method(
-      "from_bool",
-      *[](bool v) {
-        return torch::IValue(v);
-      })
-    // see https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/python/pybind_utils.h
-    // createGenericDict and toIValue
-    .define_singleton_method(
-      "from_dict",
-      *[](Hash obj) {
-        auto key_type = c10::AnyType::get();
-        auto value_type = c10::AnyType::get();
-        c10::impl::GenericDict elems(key_type, value_type);
-        elems.reserve(obj.size());
-        for (auto entry : obj) {
-          elems.insert(from_ruby<torch::IValue>(entry.first), from_ruby<torch::IValue>((Object) entry.second));
-        }
-        return torch::IValue(std::move(elems));
-      });
-
-  rb_mTorch.define_singleton_method(
-      "grad_enabled?",
-      *[]() {
-        return torch::GradMode::is_enabled();
-      })
-    .define_singleton_method(
-      "_set_grad_enabled",
-      *[](bool enabled) {
-        torch::GradMode::set_enabled(enabled);
-      })
-    .define_singleton_method(
-      "manual_seed",
-      *[](uint64_t seed) {
-        return torch::manual_seed(seed);
-      })
-    // config
-    .define_singleton_method(
-      "show_config",
-      *[] {
-        return torch::show_config();
-      })
-    .define_singleton_method(
-      "parallel_info",
-      *[] {
-        return torch::get_parallel_info();
-      })
-    // begin operations
-    .define_singleton_method(
-      "_save",
-      *[](const torch::IValue &value) {
-        auto v = torch::pickle_save(value);
-        std::string str(v.begin(), v.end());
-        return str;
-      })
-    .define_singleton_method(
-      "_load",
-      *[](const std::string &s) {
-        std::vector<char> v;
-        std::copy(s.begin(), s.end(), std::back_inserter(v));
-        // https://github.com/pytorch/pytorch/issues/20356#issuecomment-567663701
-        return torch::pickle_load(v);
-      })
-    .define_singleton_method(
-      "_from_blob",
-      *[](String s, std::vector<int64_t> size, const torch::TensorOptions &options) {
-        void *data = const_cast<char *>(s.c_str());
-        return torch::from_blob(data, size, options);
-      })
-    .define_singleton_method(
-      "_tensor",
-      *[](Array a, std::vector<int64_t> size, const torch::TensorOptions &options) {
-        auto dtype = options.dtype();
-        torch::Tensor t;
-        if (dtype == torch::kBool) {
-          std::vector<uint8_t> vec;
-          for (size_t i = 0; i < a.size(); i++) {
-            vec.push_back(from_ruby<bool>(a[i]));
-          }
-          t = torch::tensor(vec, options);
-        } else {
-          std::vector<float> vec;
-          for (size_t i = 0; i < a.size(); i++) {
-            vec.push_back(from_ruby<float>(a[i]));
-          }
-          // hack for requires_grad error
-          if (options.requires_grad()) {
-            t = torch::tensor(vec, options.requires_grad(c10::nullopt));
-            t.set_requires_grad(true);
-          } else {
-            t = torch::tensor(vec, options);
-          }
-        }
-        return t.reshape(size);
-      });
-
-  rb_cTensor
-    .define_method("cuda?", &torch::Tensor::is_cuda)
-    .define_method("sparse?", &torch::Tensor::is_sparse)
-    .define_method("quantized?", &torch::Tensor::is_quantized)
-    .define_method("dim", &torch::Tensor::dim)
-    .define_method("numel", &torch::Tensor::numel)
-    .define_method("element_size", &torch::Tensor::element_size)
-    .define_method("requires_grad", &torch::Tensor::requires_grad)
-    // in C++ for performance
-    .define_method(
-      "shape",
-      *[](Tensor& self) {
-        Array a;
-        for (auto &size : self.sizes()) {
-          a.push(size);
-        }
-        return a;
-      })
-    .define_method(
-      "_strides",
-      *[](Tensor& self) {
-        Array a;
-        for (auto &stride : self.strides()) {
-          a.push(stride);
-        }
-        return a;
-      })
-    .define_method(
-      "_index",
-      *[](Tensor& self, Array indices) {
-        auto vec = index_vector(indices);
-        return self.index(vec);
-      })
-    .define_method(
-      "_index_put_custom",
-      *[](Tensor& self, Array indices, torch::Tensor& value) {
-        auto vec = index_vector(indices);
-        return self.index_put_(vec, value);
-      })
-    .define_method(
-      "contiguous?",
-      *[](Tensor& self) {
-        return self.is_contiguous();
-      })
-    .define_method(
-      "addcmul!",
-      *[](Tensor& self, Scalar value, const Tensor & tensor1, const Tensor & tensor2) {
-        return self.addcmul_(tensor1, tensor2, value);
-      })
-    .define_method(
-      "addcdiv!",
-      *[](Tensor& self, Scalar value, const Tensor & tensor1, const Tensor & tensor2) {
-        return self.addcdiv_(tensor1, tensor2, value);
-      })
-    .define_method(
-      "_requires_grad!",
-      *[](Tensor& self, bool requires_grad) {
-        return self.set_requires_grad(requires_grad);
-      })
-    .define_method(
-      "grad",
-      *[](Tensor& self) {
-        auto grad = self.grad();
-        return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
-      })
-    .define_method(
-      "grad=",
-      *[](Tensor& self, torch::Tensor& grad) {
-        self.grad() = grad;
-      })
-    .define_method(
-      "_dtype",
-      *[](Tensor& self) {
-        return (int) at::typeMetaToScalarType(self.dtype());
-      })
-    .define_method(
-      "_type",
-      *[](Tensor& self, int dtype) {
-        return self.toType((torch::ScalarType) dtype);
-      })
-    .define_method(
-      "_layout",
-      *[](Tensor& self) {
-        std::stringstream s;
-        s << self.layout();
-        return s.str();
-      })
-    .define_method(
-      "device",
-      *[](Tensor& self) {
-        std::stringstream s;
-        s << self.device();
-        return s.str();
-      })
-    .define_method(
-      "_data_str",
-      *[](Tensor& self) {
-        Tensor tensor = self;
-
-        // move to CPU to get data
-        if (tensor.device().type() != torch::kCPU) {
-          torch::Device device("cpu");
-          tensor = tensor.to(device);
-        }
-
-        if (!tensor.is_contiguous()) {
-          tensor = tensor.contiguous();
-        }
-
-        auto data_ptr = (const char *) tensor.data_ptr();
-        return std::string(data_ptr, tensor.numel() * tensor.element_size());
-      })
-    // for TorchVision
-    .define_method(
-      "_data_ptr",
-      *[](Tensor& self) {
-        return reinterpret_cast<uintptr_t>(self.data_ptr());
-      })
-    // TODO figure out a better way to do this
-    .define_method(
-      "_flat_data",
-      *[](Tensor& self) {
-        Tensor tensor = self;
-
-        // move to CPU to get data
-        if (tensor.device().type() != torch::kCPU) {
-          torch::Device device("cpu");
-          tensor = tensor.to(device);
-        }
-
-        Array a;
-        auto dtype = tensor.dtype();
-
-        Tensor view = tensor.reshape({tensor.numel()});
-
-        // TODO DRY if someone knows C++
-        if (dtype == torch::kByte) {
-          for (int i = 0; i < tensor.numel(); i++) {
-            a.push(view[i].item().to<uint8_t>());
-          }
-        } else if (dtype == torch::kChar) {
-          for (int i = 0; i < tensor.numel(); i++) {
-            a.push(to_ruby<int>(view[i].item().to<int8_t>()));
-          }
-        } else if (dtype == torch::kShort) {
-          for (int i = 0; i < tensor.numel(); i++) {
-            a.push(view[i].item().to<int16_t>());
-          }
-        } else if (dtype == torch::kInt) {
-          for (int i = 0; i < tensor.numel(); i++) {
-            a.push(view[i].item().to<int32_t>());
-          }
-        } else if (dtype == torch::kLong) {
-          for (int i = 0; i < tensor.numel(); i++) {
-            a.push(view[i].item().to<int64_t>());
-          }
-        } else if (dtype == torch::kFloat) {
-          for (int i = 0; i < tensor.numel(); i++) {
-            a.push(view[i].item().to<float>());
-          }
-        } else if (dtype == torch::kDouble) {
-          for (int i = 0; i < tensor.numel(); i++) {
-            a.push(view[i].item().to<double>());
-          }
-        } else if (dtype == torch::kBool) {
-          for (int i = 0; i < tensor.numel(); i++) {
-            a.push(view[i].item().to<bool>() ? True : False);
-          }
-        } else {
-          throw std::runtime_error("Unsupported type");
-        }
-        return a;
-      })
-    .define_method(
-      "_to",
-      *[](Tensor& self, torch::Device device, int dtype, bool non_blocking, bool copy) {
-        return self.to(device, (torch::ScalarType) dtype, non_blocking, copy);
-      })
-    .define_singleton_method(
-      "_make_subclass",
-      *[](Tensor& rd, bool requires_grad) {
-        auto data = rd.detach();
-        data.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(true);
-        auto var = data.set_requires_grad(requires_grad);
-        return Parameter(std::move(var));
-      });
-
-  Class rb_cTensorOptions = define_class_under<torch::TensorOptions>(rb_mTorch, "TensorOptions")
-    .add_handler<torch::Error>(handle_error)
-    .define_constructor(Constructor<torch::TensorOptions>())
-    .define_method(
-      "dtype",
-      *[](torch::TensorOptions& self, int dtype) {
-        return self.dtype((torch::ScalarType) dtype);
-      })
-    .define_method(
-      "layout",
-      *[](torch::TensorOptions& self, std::string layout) {
-        torch::Layout l;
-        if (layout == "strided") {
-          l = torch::kStrided;
-        } else if (layout == "sparse") {
-          l = torch::kSparse;
-          throw std::runtime_error("Sparse layout not supported yet");
-        } else {
-          throw std::runtime_error("Unsupported layout: " + layout);
-        }
-        return self.layout(l);
-      })
-    .define_method(
-      "device",
-      *[](torch::TensorOptions& self, std::string device) {
-        torch::Device d(device);
-        return self.device(d);
-      })
-    .define_method(
-      "requires_grad",
-      *[](torch::TensorOptions& self, bool requires_grad) {
-        return self.requires_grad(requires_grad);
-      });
-
-  Module rb_mInit = define_module_under(rb_mNN, "Init")
-    .add_handler<torch::Error>(handle_error)
-    .define_singleton_method(
-      "_calculate_gain",
-      *[](NonlinearityType nonlinearity, double param) {
-        return torch::nn::init::calculate_gain(nonlinearity, param);
-      })
-    .define_singleton_method(
-      "_uniform!",
-      *[](Tensor tensor, double low, double high) {
-        return torch::nn::init::uniform_(tensor, low, high);
-      })
-    .define_singleton_method(
-      "_normal!",
-      *[](Tensor tensor, double mean, double std) {
-        return torch::nn::init::normal_(tensor, mean, std);
-      })
-    .define_singleton_method(
-      "_constant!",
-      *[](Tensor tensor, Scalar value) {
-        return torch::nn::init::constant_(tensor, value);
-      })
-    .define_singleton_method(
-      "_ones!",
-      *[](Tensor tensor) {
-        return torch::nn::init::ones_(tensor);
-      })
-    .define_singleton_method(
-      "_zeros!",
-      *[](Tensor tensor) {
-        return torch::nn::init::zeros_(tensor);
-      })
-    .define_singleton_method(
-      "_eye!",
-      *[](Tensor tensor) {
-        return torch::nn::init::eye_(tensor);
-      })
-    .define_singleton_method(
-      "_dirac!",
-      *[](Tensor tensor) {
-        return torch::nn::init::dirac_(tensor);
-      })
-    .define_singleton_method(
-      "_xavier_uniform!",
-      *[](Tensor tensor, double gain) {
-        return torch::nn::init::xavier_uniform_(tensor, gain);
-      })
-    .define_singleton_method(
-      "_xavier_normal!",
-      *[](Tensor tensor, double gain) {
-        return torch::nn::init::xavier_normal_(tensor, gain);
-      })
-    .define_singleton_method(
-      "_kaiming_uniform!",
-      *[](Tensor tensor, double a, FanModeType mode, NonlinearityType nonlinearity) {
-        return torch::nn::init::kaiming_uniform_(tensor, a, mode, nonlinearity);
-      })
-    .define_singleton_method(
-      "_kaiming_normal!",
-      *[](Tensor tensor, double a, FanModeType mode, NonlinearityType nonlinearity) {
-        return torch::nn::init::kaiming_normal_(tensor, a, mode, nonlinearity);
-      })
-    .define_singleton_method(
-      "_orthogonal!",
-      *[](Tensor tensor, double gain) {
-        return torch::nn::init::orthogonal_(tensor, gain);
-      })
-    .define_singleton_method(
-      "_sparse!",
-      *[](Tensor tensor, double sparsity, double std) {
-        return torch::nn::init::sparse_(tensor, sparsity, std);
-      });
-
-  Class rb_cParameter = define_class_under<Parameter, torch::Tensor>(rb_mNN, "Parameter")
-    .add_handler<torch::Error>(handle_error)
-    .define_method(
-      "grad",
-      *[](Parameter& self) {
-        auto grad = self.grad();
-        return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
-      })
-    .define_method(
-      "grad=",
-      *[](Parameter& self, torch::Tensor& grad) {
-        self.grad() = grad;
-      });
+  auto m = Rice::define_module("Torch");
 
-
-
-
-
-    .define_method("index?", &torch::Device::has_index)
-    .define_method(
-      "type",
-      *[](torch::Device& self) {
-        std::stringstream s;
-        s << self.type();
-        return s.str();
-      });
+  // keep this order
+  init_torch(m);
+  init_tensor(m);
+  init_nn(m);
 
-
-
-
-
+  init_cuda(m);
+  init_device(m);
+  init_ivalue(m);
+  init_random(m);
 }
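The net effect of this change is organizational: ext.cpp shrinks from a 632-line monolith to a 26-line entry point that creates the Torch module once and hands it to per-area init functions, whose bodies moved to the new torch.cpp, tensor.cpp, nn.cpp, ivalue.cpp, random.cpp, cuda.cpp, and device.cpp files listed in the summary above. The Ruby-facing API should be unchanged; for instance, the seeding helpers removed from ext.cpp still behave as before (a sketch, assuming the Random and manual_seed bindings moved over unchanged and a CPU build):

    require "torch"

    Torch.manual_seed(1234)      # now bound in torch.cpp
    Torch::Random.initial_seed   # => 1234 (now bound in random.cpp)

    # seeding still makes tensor creation reproducible
    Torch.manual_seed(0)
    a = Torch.rand(3)
    Torch.manual_seed(0)
    b = Torch.rand(3)
    a.eq(b)                      # => boolean tensor, element-wise true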