torch-rb 0.5.3 → 0.6.0
- checksums.yaml +4 -4
- data/CHANGELOG.md +7 -1
- data/README.md +5 -3
- data/codegen/generate_functions.rb +7 -5
- data/codegen/native_functions.yaml +2355 -1396
- data/ext/torch/cuda.cpp +14 -0
- data/ext/torch/device.cpp +21 -0
- data/ext/torch/ext.cpp +17 -622
- data/ext/torch/extconf.rb +0 -1
- data/ext/torch/ivalue.cpp +134 -0
- data/ext/torch/nn.cpp +114 -0
- data/ext/torch/nn_functions.h +1 -1
- data/ext/torch/random.cpp +22 -0
- data/ext/torch/ruby_arg_parser.cpp +1 -1
- data/ext/torch/ruby_arg_parser.h +21 -5
- data/ext/torch/templates.h +2 -2
- data/ext/torch/tensor.cpp +307 -0
- data/ext/torch/tensor_functions.h +1 -1
- data/ext/torch/torch.cpp +86 -0
- data/ext/torch/torch_functions.h +1 -1
- data/ext/torch/utils.h +8 -1
- data/lib/torch.rb +9 -10
- data/lib/torch/nn/linear.rb +2 -0
- data/lib/torch/nn/module.rb +6 -0
- data/lib/torch/nn/parameter.rb +1 -1
- data/lib/torch/tensor.rb +4 -0
- data/lib/torch/utils/data/data_loader.rb +1 -1
- data/lib/torch/version.rb +1 -1
- metadata +11 -88
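
The headline change in this release is structural: the 600-plus-line ext.cpp is split into per-component C++ sources (cuda.cpp, device.cpp, ivalue.cpp, nn.cpp, random.cpp, tensor.cpp, torch.cpp), each registering its bindings through an init_* entry point, as the diffs below show. The Ruby-facing namespaces should be unchanged; a quick smoke test, a sketch assuming the 0.6.0 layout above:

  require "torch"

  # each namespace is now registered by its own C++ source file
  Torch::CUDA.available?       # bound in cuda.cpp
  Torch::Device.new("cpu")     # bound in device.cpp
  Torch::IValue.new            # bound in ivalue.cpp
  Torch::Random.initial_seed   # presumably moved as-is into random.cpp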
data/ext/torch/cuda.cpp
ADDED
@@ -0,0 +1,14 @@
+#include <torch/torch.h>
+
+#include <rice/Module.hpp>
+
+#include "utils.h"
+
+void init_cuda(Rice::Module& m) {
+  Rice::define_module_under(m, "CUDA")
+    .add_handler<torch::Error>(handle_error)
+    .define_singleton_method("available?", &torch::cuda::is_available)
+    .define_singleton_method("device_count", &torch::cuda::device_count)
+    .define_singleton_method("manual_seed", &torch::cuda::manual_seed)
+    .define_singleton_method("manual_seed_all", &torch::cuda::manual_seed_all);
+}
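
The new Torch::CUDA module is a thin wrapper over the torch::cuda C++ API. A minimal usage sketch from the Ruby side (results depend on the local build; the seeding calls only matter when a CUDA device is present):

  require "torch"

  if Torch::CUDA.available?
    puts Torch::CUDA.device_count
    Torch::CUDA.manual_seed(42)      # seeds the current device
    Torch::CUDA.manual_seed_all(42)  # seeds every visible device
  end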
data/ext/torch/device.cpp
ADDED
@@ -0,0 +1,21 @@
+#include <torch/torch.h>
+
+#include <rice/Constructor.hpp>
+#include <rice/Module.hpp>
+
+#include "utils.h"
+
+void init_device(Rice::Module& m) {
+  Rice::define_class_under<torch::Device>(m, "Device")
+    .add_handler<torch::Error>(handle_error)
+    .define_constructor(Rice::Constructor<torch::Device, std::string>())
+    .define_method("index", &torch::Device::index)
+    .define_method("index?", &torch::Device::has_index)
+    .define_method(
+      "type",
+      *[](torch::Device& self) {
+        std::stringstream s;
+        s << self.type();
+        return s.str();
+      });
+}
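
Torch::Device wraps torch::Device with a string constructor; type streams the C++ device-type enum back as a plain string rather than exposing the enum. A usage sketch (the "cuda:0" example assumes a CUDA build):

  require "torch"

  device = Torch::Device.new("cuda:0")
  device.type    # => "cuda"
  device.index   # => 0
  device.index?  # => true

  Torch::Device.new("cpu").index?  # => false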
data/ext/torch/ext.cpp
CHANGED
@@ -1,631 +1,26 @@
-#include <
+#include <rice/Module.hpp>
 
-
+void init_nn(Rice::Module& m);
+void init_tensor(Rice::Module& m);
+void init_torch(Rice::Module& m);
 
-
-
-
-
-
-#include "templates.h"
-#include "utils.h"
-
-// generated with:
-// rake generate:functions
-#include "torch_functions.h"
-#include "tensor_functions.h"
-#include "nn_functions.h"
-
-using namespace Rice;
-using torch::indexing::TensorIndex;
-
-// need to make a distinction between parameters and tensors
-class Parameter: public torch::autograd::Variable {
-  public:
-    Parameter(Tensor&& t) : torch::autograd::Variable(t) { }
-};
-
-void handle_error(torch::Error const & ex)
-{
-  throw Exception(rb_eRuntimeError, ex.what_without_backtrace());
-}
-
-Class rb_cTensor;
-
-std::vector<TensorIndex> index_vector(Array a) {
-  Object obj;
-
-  std::vector<TensorIndex> indices;
-  indices.reserve(a.size());
-
-  for (size_t i = 0; i < a.size(); i++) {
-    obj = a[i];
-
-    if (obj.is_instance_of(rb_cInteger)) {
-      indices.push_back(from_ruby<int64_t>(obj));
-    } else if (obj.is_instance_of(rb_cRange)) {
-      torch::optional<int64_t> start_index = torch::nullopt;
-      torch::optional<int64_t> stop_index = torch::nullopt;
-
-      Object begin = obj.call("begin");
-      if (!begin.is_nil()) {
-        start_index = from_ruby<int64_t>(begin);
-      }
-
-      Object end = obj.call("end");
-      if (!end.is_nil()) {
-        stop_index = from_ruby<int64_t>(end);
-      }
-
-      Object exclude_end = obj.call("exclude_end?");
-      if (stop_index.has_value() && !exclude_end) {
-        if (stop_index.value() == -1) {
-          stop_index = torch::nullopt;
-        } else {
-          stop_index = stop_index.value() + 1;
-        }
-      } else if (!stop_index.has_value() && exclude_end) {
-        stop_index = -1;
-      }
-
-      indices.push_back(torch::indexing::Slice(start_index, stop_index));
-    } else if (obj.is_instance_of(rb_cTensor)) {
-      indices.push_back(from_ruby<Tensor>(obj));
-    } else if (obj.is_nil()) {
-      indices.push_back(torch::indexing::None);
-    } else if (obj == True || obj == False) {
-      indices.push_back(from_ruby<bool>(obj));
-    } else {
-      throw Exception(rb_eArgError, "Unsupported index type: %s", rb_obj_classname(obj));
-    }
-  }
-  return indices;
-}
+void init_cuda(Rice::Module& m);
+void init_device(Rice::Module& m);
+void init_ivalue(Rice::Module& m);
+void init_random(Rice::Module& m);
 
 extern "C"
 void Init_ext()
 {
-
-  rb_mTorch.add_handler<torch::Error>(handle_error);
-  add_torch_functions(rb_mTorch);
-
-  rb_cTensor = define_class_under<torch::Tensor>(rb_mTorch, "Tensor");
-  rb_cTensor.add_handler<torch::Error>(handle_error);
-  add_tensor_functions(rb_cTensor);
-  THPVariableClass = rb_cTensor.value();
-
-  Module rb_mNN = define_module_under(rb_mTorch, "NN");
-  rb_mNN.add_handler<torch::Error>(handle_error);
-  add_nn_functions(rb_mNN);
-
-  Module rb_mRandom = define_module_under(rb_mTorch, "Random")
-    .add_handler<torch::Error>(handle_error)
-    .define_singleton_method(
-      "initial_seed",
-      *[]() {
-        return at::detail::getDefaultCPUGenerator().current_seed();
-      })
-    .define_singleton_method(
-      "seed",
-      *[]() {
-        // TODO set for CUDA when available
-        auto generator = at::detail::getDefaultCPUGenerator();
-        return generator.seed();
-      });
-
-  // https://pytorch.org/cppdocs/api/structc10_1_1_i_value.html
-  Class rb_cIValue = define_class_under<torch::IValue>(rb_mTorch, "IValue")
-    .add_handler<torch::Error>(handle_error)
-    .define_constructor(Constructor<torch::IValue>())
-    .define_method("bool?", &torch::IValue::isBool)
-    .define_method("bool_list?", &torch::IValue::isBoolList)
-    .define_method("capsule?", &torch::IValue::isCapsule)
-    .define_method("custom_class?", &torch::IValue::isCustomClass)
-    .define_method("device?", &torch::IValue::isDevice)
-    .define_method("double?", &torch::IValue::isDouble)
-    .define_method("double_list?", &torch::IValue::isDoubleList)
-    .define_method("future?", &torch::IValue::isFuture)
-    // .define_method("generator?", &torch::IValue::isGenerator)
-    .define_method("generic_dict?", &torch::IValue::isGenericDict)
-    .define_method("list?", &torch::IValue::isList)
-    .define_method("int?", &torch::IValue::isInt)
-    .define_method("int_list?", &torch::IValue::isIntList)
-    .define_method("module?", &torch::IValue::isModule)
-    .define_method("none?", &torch::IValue::isNone)
-    .define_method("object?", &torch::IValue::isObject)
-    .define_method("ptr_type?", &torch::IValue::isPtrType)
-    .define_method("py_object?", &torch::IValue::isPyObject)
-    .define_method("r_ref?", &torch::IValue::isRRef)
-    .define_method("scalar?", &torch::IValue::isScalar)
-    .define_method("string?", &torch::IValue::isString)
-    .define_method("tensor?", &torch::IValue::isTensor)
-    .define_method("tensor_list?", &torch::IValue::isTensorList)
-    .define_method("tuple?", &torch::IValue::isTuple)
-    .define_method(
-      "to_bool",
-      *[](torch::IValue& self) {
-        return self.toBool();
-      })
-    .define_method(
-      "to_double",
-      *[](torch::IValue& self) {
-        return self.toDouble();
-      })
-    .define_method(
-      "to_int",
-      *[](torch::IValue& self) {
-        return self.toInt();
-      })
-    .define_method(
-      "to_list",
-      *[](torch::IValue& self) {
-        auto list = self.toListRef();
-        Array obj;
-        for (auto& elem : list) {
-          obj.push(to_ruby<torch::IValue>(torch::IValue{elem}));
-        }
-        return obj;
-      })
-    .define_method(
-      "to_string_ref",
-      *[](torch::IValue& self) {
-        return self.toStringRef();
-      })
-    .define_method(
-      "to_tensor",
-      *[](torch::IValue& self) {
-        return self.toTensor();
-      })
-    .define_method(
-      "to_generic_dict",
-      *[](torch::IValue& self) {
-        auto dict = self.toGenericDict();
-        Hash obj;
-        for (auto& pair : dict) {
-          obj[to_ruby<torch::IValue>(torch::IValue{pair.key()})] = to_ruby<torch::IValue>(torch::IValue{pair.value()});
-        }
-        return obj;
-      })
-    .define_singleton_method(
-      "from_tensor",
-      *[](torch::Tensor& v) {
-        return torch::IValue(v);
-      })
-    // TODO create specialized list types?
-    .define_singleton_method(
-      "from_list",
-      *[](Array obj) {
-        c10::impl::GenericList list(c10::AnyType::get());
-        for (auto entry : obj) {
-          list.push_back(from_ruby<torch::IValue>(entry));
-        }
-        return torch::IValue(list);
-      })
-    .define_singleton_method(
-      "from_string",
-      *[](String v) {
-        return torch::IValue(v.str());
-      })
-    .define_singleton_method(
-      "from_int",
-      *[](int64_t v) {
-        return torch::IValue(v);
-      })
-    .define_singleton_method(
-      "from_double",
-      *[](double v) {
-        return torch::IValue(v);
-      })
-    .define_singleton_method(
-      "from_bool",
-      *[](bool v) {
-        return torch::IValue(v);
-      })
-    // see https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/python/pybind_utils.h
-    // createGenericDict and toIValue
-    .define_singleton_method(
-      "from_dict",
-      *[](Hash obj) {
-        auto key_type = c10::AnyType::get();
-        auto value_type = c10::AnyType::get();
-        c10::impl::GenericDict elems(key_type, value_type);
-        elems.reserve(obj.size());
-        for (auto entry : obj) {
-          elems.insert(from_ruby<torch::IValue>(entry.first), from_ruby<torch::IValue>((Object) entry.second));
-        }
-        return torch::IValue(std::move(elems));
-      });
-
-  rb_mTorch.define_singleton_method(
-      "grad_enabled?",
-      *[]() {
-        return torch::GradMode::is_enabled();
-      })
-    .define_singleton_method(
-      "_set_grad_enabled",
-      *[](bool enabled) {
-        torch::GradMode::set_enabled(enabled);
-      })
-    .define_singleton_method(
-      "manual_seed",
-      *[](uint64_t seed) {
-        return torch::manual_seed(seed);
-      })
-    // config
-    .define_singleton_method(
-      "show_config",
-      *[] {
-        return torch::show_config();
-      })
-    .define_singleton_method(
-      "parallel_info",
-      *[] {
-        return torch::get_parallel_info();
-      })
-    // begin operations
-    .define_singleton_method(
-      "_save",
-      *[](const torch::IValue &value) {
-        auto v = torch::pickle_save(value);
-        std::string str(v.begin(), v.end());
-        return str;
-      })
-    .define_singleton_method(
-      "_load",
-      *[](const std::string &s) {
-        std::vector<char> v;
-        std::copy(s.begin(), s.end(), std::back_inserter(v));
-        // https://github.com/pytorch/pytorch/issues/20356#issuecomment-567663701
-        return torch::pickle_load(v);
-      })
-    .define_singleton_method(
-      "_from_blob",
-      *[](String s, std::vector<int64_t> size, const torch::TensorOptions &options) {
-        void *data = const_cast<char *>(s.c_str());
-        return torch::from_blob(data, size, options);
-      })
-    .define_singleton_method(
-      "_tensor",
-      *[](Array a, std::vector<int64_t> size, const torch::TensorOptions &options) {
-        auto dtype = options.dtype();
-        torch::Tensor t;
-        if (dtype == torch::kBool) {
-          std::vector<uint8_t> vec;
-          for (size_t i = 0; i < a.size(); i++) {
-            vec.push_back(from_ruby<bool>(a[i]));
-          }
-          t = torch::tensor(vec, options);
-        } else {
-          std::vector<float> vec;
-          for (size_t i = 0; i < a.size(); i++) {
-            vec.push_back(from_ruby<float>(a[i]));
-          }
-          // hack for requires_grad error
-          if (options.requires_grad()) {
-            t = torch::tensor(vec, options.requires_grad(c10::nullopt));
-            t.set_requires_grad(true);
-          } else {
-            t = torch::tensor(vec, options);
-          }
-        }
-        return t.reshape(size);
-      });
-
-  rb_cTensor
-    .define_method("cuda?", &torch::Tensor::is_cuda)
-    .define_method("sparse?", &torch::Tensor::is_sparse)
-    .define_method("quantized?", &torch::Tensor::is_quantized)
-    .define_method("dim", &torch::Tensor::dim)
-    .define_method("numel", &torch::Tensor::numel)
-    .define_method("element_size", &torch::Tensor::element_size)
-    .define_method("requires_grad", &torch::Tensor::requires_grad)
-    // in C++ for performance
-    .define_method(
-      "shape",
-      *[](Tensor& self) {
-        Array a;
-        for (auto &size : self.sizes()) {
-          a.push(size);
-        }
-        return a;
-      })
-    .define_method(
-      "_strides",
-      *[](Tensor& self) {
-        Array a;
-        for (auto &stride : self.strides()) {
-          a.push(stride);
-        }
-        return a;
-      })
-    .define_method(
-      "_index",
-      *[](Tensor& self, Array indices) {
-        auto vec = index_vector(indices);
-        return self.index(vec);
-      })
-    .define_method(
-      "_index_put_custom",
-      *[](Tensor& self, Array indices, torch::Tensor& value) {
-        auto vec = index_vector(indices);
-        return self.index_put_(vec, value);
-      })
-    .define_method(
-      "contiguous?",
-      *[](Tensor& self) {
-        return self.is_contiguous();
-      })
-    .define_method(
-      "_requires_grad!",
-      *[](Tensor& self, bool requires_grad) {
-        return self.set_requires_grad(requires_grad);
-      })
-    .define_method(
-      "grad",
-      *[](Tensor& self) {
-        auto grad = self.grad();
-        return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
-      })
-    .define_method(
-      "grad=",
-      *[](Tensor& self, torch::Tensor& grad) {
-        self.mutable_grad() = grad;
-      })
-    .define_method(
-      "_dtype",
-      *[](Tensor& self) {
-        return (int) at::typeMetaToScalarType(self.dtype());
-      })
-    .define_method(
-      "_type",
-      *[](Tensor& self, int dtype) {
-        return self.toType((torch::ScalarType) dtype);
-      })
-    .define_method(
-      "_layout",
-      *[](Tensor& self) {
-        std::stringstream s;
-        s << self.layout();
-        return s.str();
-      })
-    .define_method(
-      "device",
-      *[](Tensor& self) {
-        std::stringstream s;
-        s << self.device();
-        return s.str();
-      })
-    .define_method(
-      "_data_str",
-      *[](Tensor& self) {
-        Tensor tensor = self;
-
-        // move to CPU to get data
-        if (tensor.device().type() != torch::kCPU) {
-          torch::Device device("cpu");
-          tensor = tensor.to(device);
-        }
-
-        if (!tensor.is_contiguous()) {
-          tensor = tensor.contiguous();
-        }
-
-        auto data_ptr = (const char *) tensor.data_ptr();
-        return std::string(data_ptr, tensor.numel() * tensor.element_size());
-      })
-    // for TorchVision
-    .define_method(
-      "_data_ptr",
-      *[](Tensor& self) {
-        return reinterpret_cast<uintptr_t>(self.data_ptr());
-      })
-    // TODO figure out a better way to do this
-    .define_method(
-      "_flat_data",
-      *[](Tensor& self) {
-        Tensor tensor = self;
-
-        // move to CPU to get data
-        if (tensor.device().type() != torch::kCPU) {
-          torch::Device device("cpu");
-          tensor = tensor.to(device);
-        }
-
-        Array a;
-        auto dtype = tensor.dtype();
-
-        Tensor view = tensor.reshape({tensor.numel()});
-
-        // TODO DRY if someone knows C++
-        if (dtype == torch::kByte) {
-          for (int i = 0; i < tensor.numel(); i++) {
-            a.push(view[i].item().to<uint8_t>());
-          }
-        } else if (dtype == torch::kChar) {
-          for (int i = 0; i < tensor.numel(); i++) {
-            a.push(to_ruby<int>(view[i].item().to<int8_t>()));
-          }
-        } else if (dtype == torch::kShort) {
-          for (int i = 0; i < tensor.numel(); i++) {
-            a.push(view[i].item().to<int16_t>());
-          }
-        } else if (dtype == torch::kInt) {
-          for (int i = 0; i < tensor.numel(); i++) {
-            a.push(view[i].item().to<int32_t>());
-          }
-        } else if (dtype == torch::kLong) {
-          for (int i = 0; i < tensor.numel(); i++) {
-            a.push(view[i].item().to<int64_t>());
-          }
-        } else if (dtype == torch::kFloat) {
-          for (int i = 0; i < tensor.numel(); i++) {
-            a.push(view[i].item().to<float>());
-          }
-        } else if (dtype == torch::kDouble) {
-          for (int i = 0; i < tensor.numel(); i++) {
-            a.push(view[i].item().to<double>());
-          }
-        } else if (dtype == torch::kBool) {
-          for (int i = 0; i < tensor.numel(); i++) {
-            a.push(view[i].item().to<bool>() ? True : False);
-          }
-        } else {
-          throw std::runtime_error("Unsupported type");
-        }
-        return a;
-      })
-    .define_method(
-      "_to",
-      *[](Tensor& self, torch::Device device, int dtype, bool non_blocking, bool copy) {
-        return self.to(device, (torch::ScalarType) dtype, non_blocking, copy);
-      })
-    .define_singleton_method(
-      "_make_subclass",
-      *[](Tensor& rd, bool requires_grad) {
-        auto data = rd.detach();
-        data.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(true);
-        auto var = data.set_requires_grad(requires_grad);
-        return Parameter(std::move(var));
-      });
-
-  Class rb_cTensorOptions = define_class_under<torch::TensorOptions>(rb_mTorch, "TensorOptions")
-    .add_handler<torch::Error>(handle_error)
-    .define_constructor(Constructor<torch::TensorOptions>())
-    .define_method(
-      "dtype",
-      *[](torch::TensorOptions& self, int dtype) {
-        return self.dtype((torch::ScalarType) dtype);
-      })
-    .define_method(
-      "layout",
-      *[](torch::TensorOptions& self, std::string layout) {
-        torch::Layout l;
-        if (layout == "strided") {
-          l = torch::kStrided;
-        } else if (layout == "sparse") {
-          l = torch::kSparse;
-          throw std::runtime_error("Sparse layout not supported yet");
-        } else {
-          throw std::runtime_error("Unsupported layout: " + layout);
-        }
-        return self.layout(l);
-      })
-    .define_method(
-      "device",
-      *[](torch::TensorOptions& self, std::string device) {
-        torch::Device d(device);
-        return self.device(d);
-      })
-    .define_method(
-      "requires_grad",
-      *[](torch::TensorOptions& self, bool requires_grad) {
-        return self.requires_grad(requires_grad);
-      });
-
-  Module rb_mInit = define_module_under(rb_mNN, "Init")
-    .add_handler<torch::Error>(handle_error)
-    .define_singleton_method(
-      "_calculate_gain",
-      *[](NonlinearityType nonlinearity, double param) {
-        return torch::nn::init::calculate_gain(nonlinearity, param);
-      })
-    .define_singleton_method(
-      "_uniform!",
-      *[](Tensor tensor, double low, double high) {
-        return torch::nn::init::uniform_(tensor, low, high);
-      })
-    .define_singleton_method(
-      "_normal!",
-      *[](Tensor tensor, double mean, double std) {
-        return torch::nn::init::normal_(tensor, mean, std);
-      })
-    .define_singleton_method(
-      "_constant!",
-      *[](Tensor tensor, Scalar value) {
-        return torch::nn::init::constant_(tensor, value);
-      })
-    .define_singleton_method(
-      "_ones!",
-      *[](Tensor tensor) {
-        return torch::nn::init::ones_(tensor);
-      })
-    .define_singleton_method(
-      "_zeros!",
-      *[](Tensor tensor) {
-        return torch::nn::init::zeros_(tensor);
-      })
-    .define_singleton_method(
-      "_eye!",
-      *[](Tensor tensor) {
-        return torch::nn::init::eye_(tensor);
-      })
-    .define_singleton_method(
-      "_dirac!",
-      *[](Tensor tensor) {
-        return torch::nn::init::dirac_(tensor);
-      })
-    .define_singleton_method(
-      "_xavier_uniform!",
-      *[](Tensor tensor, double gain) {
-        return torch::nn::init::xavier_uniform_(tensor, gain);
-      })
-    .define_singleton_method(
-      "_xavier_normal!",
-      *[](Tensor tensor, double gain) {
-        return torch::nn::init::xavier_normal_(tensor, gain);
-      })
-    .define_singleton_method(
-      "_kaiming_uniform!",
-      *[](Tensor tensor, double a, FanModeType mode, NonlinearityType nonlinearity) {
-        return torch::nn::init::kaiming_uniform_(tensor, a, mode, nonlinearity);
-      })
-    .define_singleton_method(
-      "_kaiming_normal!",
-      *[](Tensor tensor, double a, FanModeType mode, NonlinearityType nonlinearity) {
-        return torch::nn::init::kaiming_normal_(tensor, a, mode, nonlinearity);
-      })
-    .define_singleton_method(
-      "_orthogonal!",
-      *[](Tensor tensor, double gain) {
-        return torch::nn::init::orthogonal_(tensor, gain);
-      })
-    .define_singleton_method(
-      "_sparse!",
-      *[](Tensor tensor, double sparsity, double std) {
-        return torch::nn::init::sparse_(tensor, sparsity, std);
-      });
-
-  Class rb_cParameter = define_class_under<Parameter, torch::Tensor>(rb_mNN, "Parameter")
-    .add_handler<torch::Error>(handle_error)
-    .define_method(
-      "grad",
-      *[](Parameter& self) {
-        auto grad = self.grad();
-        return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
-      })
-    .define_method(
-      "grad=",
-      *[](Parameter& self, torch::Tensor& grad) {
-        self.mutable_grad() = grad;
-      });
+  auto m = Rice::define_module("Torch");
 
-
-
-
-
-    .define_method("index?", &torch::Device::has_index)
-    .define_method(
-      "type",
-      *[](torch::Device& self) {
-        std::stringstream s;
-        s << self.type();
-        return s.str();
-      });
+  // keep this order
+  init_torch(m);
+  init_tensor(m);
+  init_nn(m);
 
-
-
-
-
-    .define_singleton_method("manual_seed", &torch::cuda::manual_seed)
-    .define_singleton_method("manual_seed_all", &torch::cuda::manual_seed_all);
+  init_cuda(m);
+  init_device(m);
+  init_ivalue(m);
+  init_random(m);
 }
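
The removed index_vector helper did not disappear; per the summary above, tensor indexing now lives in the new tensor.cpp. Its Ruby-to-TensorIndex mapping is the interesting part: Integers become scalar indices, Ranges become torch::indexing::Slice with Ruby's inclusive ends converted to exclusive stops (an inclusive end of -1 means "through the last element"), nil becomes torch::indexing::None, and booleans and tensors pass through for mask indexing. A sketch of the Ruby indexing behavior this backs:

  require "torch"

  x = Torch.arange(10)
  x[2]       # Integer -> scalar index
  x[1..3]    # 1..3    -> Slice(1, 4), inclusive end
  x[1...3]   # 1...3   -> Slice(1, 3), exclusive end
  x[2..-1]   # 2..-1   -> slice through the last element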