torch-rb 0.5.3 → 0.6.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,14 @@
1
+ #include <torch/torch.h>
2
+
3
+ #include <rice/Module.hpp>
4
+
5
+ #include "utils.h"
6
+
7
+ void init_cuda(Rice::Module& m) {
8
+ Rice::define_module_under(m, "CUDA")
9
+ .add_handler<torch::Error>(handle_error)
10
+ .define_singleton_method("available?", &torch::cuda::is_available)
11
+ .define_singleton_method("device_count", &torch::cuda::device_count)
12
+ .define_singleton_method("manual_seed", &torch::cuda::manual_seed)
13
+ .define_singleton_method("manual_seed_all", &torch::cuda::manual_seed_all);
14
+ }
@@ -0,0 +1,21 @@
1
+ #include <torch/torch.h>
2
+
3
+ #include <rice/Constructor.hpp>
4
+ #include <rice/Module.hpp>
5
+
6
+ #include "utils.h"
7
+
8
+ void init_device(Rice::Module& m) {
9
+ Rice::define_class_under<torch::Device>(m, "Device")
10
+ .add_handler<torch::Error>(handle_error)
11
+ .define_constructor(Rice::Constructor<torch::Device, std::string>())
12
+ .define_method("index", &torch::Device::index)
13
+ .define_method("index?", &torch::Device::has_index)
14
+ .define_method(
15
+ "type",
16
+ *[](torch::Device& self) {
17
+ std::stringstream s;
18
+ s << self.type();
19
+ return s.str();
20
+ });
21
+ }
data/ext/torch/ext.cpp CHANGED
@@ -1,631 +1,26 @@
1
- #include <sstream>
1
+ #include <rice/Module.hpp>
2
2
 
3
- #include <torch/torch.h>
3
+ void init_nn(Rice::Module& m);
4
+ void init_tensor(Rice::Module& m);
5
+ void init_torch(Rice::Module& m);
4
6
 
5
- #include <rice/Array.hpp>
6
- #include <rice/Class.hpp>
7
- #include <rice/Constructor.hpp>
8
- #include <rice/Hash.hpp>
9
-
10
- #include "templates.h"
11
- #include "utils.h"
12
-
13
- // generated with:
14
- // rake generate:functions
15
- #include "torch_functions.h"
16
- #include "tensor_functions.h"
17
- #include "nn_functions.h"
18
-
19
- using namespace Rice;
20
- using torch::indexing::TensorIndex;
21
-
22
- // need to make a distinction between parameters and tensors
23
- class Parameter: public torch::autograd::Variable {
24
- public:
25
- Parameter(Tensor&& t) : torch::autograd::Variable(t) { }
26
- };
27
-
28
- void handle_error(torch::Error const & ex)
29
- {
30
- throw Exception(rb_eRuntimeError, ex.what_without_backtrace());
31
- }
32
-
33
- Class rb_cTensor;
34
-
35
- std::vector<TensorIndex> index_vector(Array a) {
36
- Object obj;
37
-
38
- std::vector<TensorIndex> indices;
39
- indices.reserve(a.size());
40
-
41
- for (size_t i = 0; i < a.size(); i++) {
42
- obj = a[i];
43
-
44
- if (obj.is_instance_of(rb_cInteger)) {
45
- indices.push_back(from_ruby<int64_t>(obj));
46
- } else if (obj.is_instance_of(rb_cRange)) {
47
- torch::optional<int64_t> start_index = torch::nullopt;
48
- torch::optional<int64_t> stop_index = torch::nullopt;
49
-
50
- Object begin = obj.call("begin");
51
- if (!begin.is_nil()) {
52
- start_index = from_ruby<int64_t>(begin);
53
- }
54
-
55
- Object end = obj.call("end");
56
- if (!end.is_nil()) {
57
- stop_index = from_ruby<int64_t>(end);
58
- }
59
-
60
- Object exclude_end = obj.call("exclude_end?");
61
- if (stop_index.has_value() && !exclude_end) {
62
- if (stop_index.value() == -1) {
63
- stop_index = torch::nullopt;
64
- } else {
65
- stop_index = stop_index.value() + 1;
66
- }
67
- } else if (!stop_index.has_value() && exclude_end) {
68
- stop_index = -1;
69
- }
70
-
71
- indices.push_back(torch::indexing::Slice(start_index, stop_index));
72
- } else if (obj.is_instance_of(rb_cTensor)) {
73
- indices.push_back(from_ruby<Tensor>(obj));
74
- } else if (obj.is_nil()) {
75
- indices.push_back(torch::indexing::None);
76
- } else if (obj == True || obj == False) {
77
- indices.push_back(from_ruby<bool>(obj));
78
- } else {
79
- throw Exception(rb_eArgError, "Unsupported index type: %s", rb_obj_classname(obj));
80
- }
81
- }
82
- return indices;
83
- }
7
+ void init_cuda(Rice::Module& m);
8
+ void init_device(Rice::Module& m);
9
+ void init_ivalue(Rice::Module& m);
10
+ void init_random(Rice::Module& m);
84
11
 
85
12
  extern "C"
86
13
  void Init_ext()
87
14
  {
88
- Module rb_mTorch = define_module("Torch");
89
- rb_mTorch.add_handler<torch::Error>(handle_error);
90
- add_torch_functions(rb_mTorch);
91
-
92
- rb_cTensor = define_class_under<torch::Tensor>(rb_mTorch, "Tensor");
93
- rb_cTensor.add_handler<torch::Error>(handle_error);
94
- add_tensor_functions(rb_cTensor);
95
- THPVariableClass = rb_cTensor.value();
96
-
97
- Module rb_mNN = define_module_under(rb_mTorch, "NN");
98
- rb_mNN.add_handler<torch::Error>(handle_error);
99
- add_nn_functions(rb_mNN);
100
-
101
- Module rb_mRandom = define_module_under(rb_mTorch, "Random")
102
- .add_handler<torch::Error>(handle_error)
103
- .define_singleton_method(
104
- "initial_seed",
105
- *[]() {
106
- return at::detail::getDefaultCPUGenerator().current_seed();
107
- })
108
- .define_singleton_method(
109
- "seed",
110
- *[]() {
111
- // TODO set for CUDA when available
112
- auto generator = at::detail::getDefaultCPUGenerator();
113
- return generator.seed();
114
- });
115
-
116
- // https://pytorch.org/cppdocs/api/structc10_1_1_i_value.html
117
- Class rb_cIValue = define_class_under<torch::IValue>(rb_mTorch, "IValue")
118
- .add_handler<torch::Error>(handle_error)
119
- .define_constructor(Constructor<torch::IValue>())
120
- .define_method("bool?", &torch::IValue::isBool)
121
- .define_method("bool_list?", &torch::IValue::isBoolList)
122
- .define_method("capsule?", &torch::IValue::isCapsule)
123
- .define_method("custom_class?", &torch::IValue::isCustomClass)
124
- .define_method("device?", &torch::IValue::isDevice)
125
- .define_method("double?", &torch::IValue::isDouble)
126
- .define_method("double_list?", &torch::IValue::isDoubleList)
127
- .define_method("future?", &torch::IValue::isFuture)
128
- // .define_method("generator?", &torch::IValue::isGenerator)
129
- .define_method("generic_dict?", &torch::IValue::isGenericDict)
130
- .define_method("list?", &torch::IValue::isList)
131
- .define_method("int?", &torch::IValue::isInt)
132
- .define_method("int_list?", &torch::IValue::isIntList)
133
- .define_method("module?", &torch::IValue::isModule)
134
- .define_method("none?", &torch::IValue::isNone)
135
- .define_method("object?", &torch::IValue::isObject)
136
- .define_method("ptr_type?", &torch::IValue::isPtrType)
137
- .define_method("py_object?", &torch::IValue::isPyObject)
138
- .define_method("r_ref?", &torch::IValue::isRRef)
139
- .define_method("scalar?", &torch::IValue::isScalar)
140
- .define_method("string?", &torch::IValue::isString)
141
- .define_method("tensor?", &torch::IValue::isTensor)
142
- .define_method("tensor_list?", &torch::IValue::isTensorList)
143
- .define_method("tuple?", &torch::IValue::isTuple)
144
- .define_method(
145
- "to_bool",
146
- *[](torch::IValue& self) {
147
- return self.toBool();
148
- })
149
- .define_method(
150
- "to_double",
151
- *[](torch::IValue& self) {
152
- return self.toDouble();
153
- })
154
- .define_method(
155
- "to_int",
156
- *[](torch::IValue& self) {
157
- return self.toInt();
158
- })
159
- .define_method(
160
- "to_list",
161
- *[](torch::IValue& self) {
162
- auto list = self.toListRef();
163
- Array obj;
164
- for (auto& elem : list) {
165
- obj.push(to_ruby<torch::IValue>(torch::IValue{elem}));
166
- }
167
- return obj;
168
- })
169
- .define_method(
170
- "to_string_ref",
171
- *[](torch::IValue& self) {
172
- return self.toStringRef();
173
- })
174
- .define_method(
175
- "to_tensor",
176
- *[](torch::IValue& self) {
177
- return self.toTensor();
178
- })
179
- .define_method(
180
- "to_generic_dict",
181
- *[](torch::IValue& self) {
182
- auto dict = self.toGenericDict();
183
- Hash obj;
184
- for (auto& pair : dict) {
185
- obj[to_ruby<torch::IValue>(torch::IValue{pair.key()})] = to_ruby<torch::IValue>(torch::IValue{pair.value()});
186
- }
187
- return obj;
188
- })
189
- .define_singleton_method(
190
- "from_tensor",
191
- *[](torch::Tensor& v) {
192
- return torch::IValue(v);
193
- })
194
- // TODO create specialized list types?
195
- .define_singleton_method(
196
- "from_list",
197
- *[](Array obj) {
198
- c10::impl::GenericList list(c10::AnyType::get());
199
- for (auto entry : obj) {
200
- list.push_back(from_ruby<torch::IValue>(entry));
201
- }
202
- return torch::IValue(list);
203
- })
204
- .define_singleton_method(
205
- "from_string",
206
- *[](String v) {
207
- return torch::IValue(v.str());
208
- })
209
- .define_singleton_method(
210
- "from_int",
211
- *[](int64_t v) {
212
- return torch::IValue(v);
213
- })
214
- .define_singleton_method(
215
- "from_double",
216
- *[](double v) {
217
- return torch::IValue(v);
218
- })
219
- .define_singleton_method(
220
- "from_bool",
221
- *[](bool v) {
222
- return torch::IValue(v);
223
- })
224
- // see https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/python/pybind_utils.h
225
- // createGenericDict and toIValue
226
- .define_singleton_method(
227
- "from_dict",
228
- *[](Hash obj) {
229
- auto key_type = c10::AnyType::get();
230
- auto value_type = c10::AnyType::get();
231
- c10::impl::GenericDict elems(key_type, value_type);
232
- elems.reserve(obj.size());
233
- for (auto entry : obj) {
234
- elems.insert(from_ruby<torch::IValue>(entry.first), from_ruby<torch::IValue>((Object) entry.second));
235
- }
236
- return torch::IValue(std::move(elems));
237
- });
238
-
239
- rb_mTorch.define_singleton_method(
240
- "grad_enabled?",
241
- *[]() {
242
- return torch::GradMode::is_enabled();
243
- })
244
- .define_singleton_method(
245
- "_set_grad_enabled",
246
- *[](bool enabled) {
247
- torch::GradMode::set_enabled(enabled);
248
- })
249
- .define_singleton_method(
250
- "manual_seed",
251
- *[](uint64_t seed) {
252
- return torch::manual_seed(seed);
253
- })
254
- // config
255
- .define_singleton_method(
256
- "show_config",
257
- *[] {
258
- return torch::show_config();
259
- })
260
- .define_singleton_method(
261
- "parallel_info",
262
- *[] {
263
- return torch::get_parallel_info();
264
- })
265
- // begin operations
266
- .define_singleton_method(
267
- "_save",
268
- *[](const torch::IValue &value) {
269
- auto v = torch::pickle_save(value);
270
- std::string str(v.begin(), v.end());
271
- return str;
272
- })
273
- .define_singleton_method(
274
- "_load",
275
- *[](const std::string &s) {
276
- std::vector<char> v;
277
- std::copy(s.begin(), s.end(), std::back_inserter(v));
278
- // https://github.com/pytorch/pytorch/issues/20356#issuecomment-567663701
279
- return torch::pickle_load(v);
280
- })
281
- .define_singleton_method(
282
- "_from_blob",
283
- *[](String s, std::vector<int64_t> size, const torch::TensorOptions &options) {
284
- void *data = const_cast<char *>(s.c_str());
285
- return torch::from_blob(data, size, options);
286
- })
287
- .define_singleton_method(
288
- "_tensor",
289
- *[](Array a, std::vector<int64_t> size, const torch::TensorOptions &options) {
290
- auto dtype = options.dtype();
291
- torch::Tensor t;
292
- if (dtype == torch::kBool) {
293
- std::vector<uint8_t> vec;
294
- for (size_t i = 0; i < a.size(); i++) {
295
- vec.push_back(from_ruby<bool>(a[i]));
296
- }
297
- t = torch::tensor(vec, options);
298
- } else {
299
- std::vector<float> vec;
300
- for (size_t i = 0; i < a.size(); i++) {
301
- vec.push_back(from_ruby<float>(a[i]));
302
- }
303
- // hack for requires_grad error
304
- if (options.requires_grad()) {
305
- t = torch::tensor(vec, options.requires_grad(c10::nullopt));
306
- t.set_requires_grad(true);
307
- } else {
308
- t = torch::tensor(vec, options);
309
- }
310
- }
311
- return t.reshape(size);
312
- });
313
-
314
- rb_cTensor
315
- .define_method("cuda?", &torch::Tensor::is_cuda)
316
- .define_method("sparse?", &torch::Tensor::is_sparse)
317
- .define_method("quantized?", &torch::Tensor::is_quantized)
318
- .define_method("dim", &torch::Tensor::dim)
319
- .define_method("numel", &torch::Tensor::numel)
320
- .define_method("element_size", &torch::Tensor::element_size)
321
- .define_method("requires_grad", &torch::Tensor::requires_grad)
322
- // in C++ for performance
323
- .define_method(
324
- "shape",
325
- *[](Tensor& self) {
326
- Array a;
327
- for (auto &size : self.sizes()) {
328
- a.push(size);
329
- }
330
- return a;
331
- })
332
- .define_method(
333
- "_strides",
334
- *[](Tensor& self) {
335
- Array a;
336
- for (auto &stride : self.strides()) {
337
- a.push(stride);
338
- }
339
- return a;
340
- })
341
- .define_method(
342
- "_index",
343
- *[](Tensor& self, Array indices) {
344
- auto vec = index_vector(indices);
345
- return self.index(vec);
346
- })
347
- .define_method(
348
- "_index_put_custom",
349
- *[](Tensor& self, Array indices, torch::Tensor& value) {
350
- auto vec = index_vector(indices);
351
- return self.index_put_(vec, value);
352
- })
353
- .define_method(
354
- "contiguous?",
355
- *[](Tensor& self) {
356
- return self.is_contiguous();
357
- })
358
- .define_method(
359
- "_requires_grad!",
360
- *[](Tensor& self, bool requires_grad) {
361
- return self.set_requires_grad(requires_grad);
362
- })
363
- .define_method(
364
- "grad",
365
- *[](Tensor& self) {
366
- auto grad = self.grad();
367
- return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
368
- })
369
- .define_method(
370
- "grad=",
371
- *[](Tensor& self, torch::Tensor& grad) {
372
- self.mutable_grad() = grad;
373
- })
374
- .define_method(
375
- "_dtype",
376
- *[](Tensor& self) {
377
- return (int) at::typeMetaToScalarType(self.dtype());
378
- })
379
- .define_method(
380
- "_type",
381
- *[](Tensor& self, int dtype) {
382
- return self.toType((torch::ScalarType) dtype);
383
- })
384
- .define_method(
385
- "_layout",
386
- *[](Tensor& self) {
387
- std::stringstream s;
388
- s << self.layout();
389
- return s.str();
390
- })
391
- .define_method(
392
- "device",
393
- *[](Tensor& self) {
394
- std::stringstream s;
395
- s << self.device();
396
- return s.str();
397
- })
398
- .define_method(
399
- "_data_str",
400
- *[](Tensor& self) {
401
- Tensor tensor = self;
402
-
403
- // move to CPU to get data
404
- if (tensor.device().type() != torch::kCPU) {
405
- torch::Device device("cpu");
406
- tensor = tensor.to(device);
407
- }
408
-
409
- if (!tensor.is_contiguous()) {
410
- tensor = tensor.contiguous();
411
- }
412
-
413
- auto data_ptr = (const char *) tensor.data_ptr();
414
- return std::string(data_ptr, tensor.numel() * tensor.element_size());
415
- })
416
- // for TorchVision
417
- .define_method(
418
- "_data_ptr",
419
- *[](Tensor& self) {
420
- return reinterpret_cast<uintptr_t>(self.data_ptr());
421
- })
422
- // TODO figure out a better way to do this
423
- .define_method(
424
- "_flat_data",
425
- *[](Tensor& self) {
426
- Tensor tensor = self;
427
-
428
- // move to CPU to get data
429
- if (tensor.device().type() != torch::kCPU) {
430
- torch::Device device("cpu");
431
- tensor = tensor.to(device);
432
- }
433
-
434
- Array a;
435
- auto dtype = tensor.dtype();
436
-
437
- Tensor view = tensor.reshape({tensor.numel()});
438
-
439
- // TODO DRY if someone knows C++
440
- if (dtype == torch::kByte) {
441
- for (int i = 0; i < tensor.numel(); i++) {
442
- a.push(view[i].item().to<uint8_t>());
443
- }
444
- } else if (dtype == torch::kChar) {
445
- for (int i = 0; i < tensor.numel(); i++) {
446
- a.push(to_ruby<int>(view[i].item().to<int8_t>()));
447
- }
448
- } else if (dtype == torch::kShort) {
449
- for (int i = 0; i < tensor.numel(); i++) {
450
- a.push(view[i].item().to<int16_t>());
451
- }
452
- } else if (dtype == torch::kInt) {
453
- for (int i = 0; i < tensor.numel(); i++) {
454
- a.push(view[i].item().to<int32_t>());
455
- }
456
- } else if (dtype == torch::kLong) {
457
- for (int i = 0; i < tensor.numel(); i++) {
458
- a.push(view[i].item().to<int64_t>());
459
- }
460
- } else if (dtype == torch::kFloat) {
461
- for (int i = 0; i < tensor.numel(); i++) {
462
- a.push(view[i].item().to<float>());
463
- }
464
- } else if (dtype == torch::kDouble) {
465
- for (int i = 0; i < tensor.numel(); i++) {
466
- a.push(view[i].item().to<double>());
467
- }
468
- } else if (dtype == torch::kBool) {
469
- for (int i = 0; i < tensor.numel(); i++) {
470
- a.push(view[i].item().to<bool>() ? True : False);
471
- }
472
- } else {
473
- throw std::runtime_error("Unsupported type");
474
- }
475
- return a;
476
- })
477
- .define_method(
478
- "_to",
479
- *[](Tensor& self, torch::Device device, int dtype, bool non_blocking, bool copy) {
480
- return self.to(device, (torch::ScalarType) dtype, non_blocking, copy);
481
- })
482
- .define_singleton_method(
483
- "_make_subclass",
484
- *[](Tensor& rd, bool requires_grad) {
485
- auto data = rd.detach();
486
- data.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(true);
487
- auto var = data.set_requires_grad(requires_grad);
488
- return Parameter(std::move(var));
489
- });
490
-
491
- Class rb_cTensorOptions = define_class_under<torch::TensorOptions>(rb_mTorch, "TensorOptions")
492
- .add_handler<torch::Error>(handle_error)
493
- .define_constructor(Constructor<torch::TensorOptions>())
494
- .define_method(
495
- "dtype",
496
- *[](torch::TensorOptions& self, int dtype) {
497
- return self.dtype((torch::ScalarType) dtype);
498
- })
499
- .define_method(
500
- "layout",
501
- *[](torch::TensorOptions& self, std::string layout) {
502
- torch::Layout l;
503
- if (layout == "strided") {
504
- l = torch::kStrided;
505
- } else if (layout == "sparse") {
506
- l = torch::kSparse;
507
- throw std::runtime_error("Sparse layout not supported yet");
508
- } else {
509
- throw std::runtime_error("Unsupported layout: " + layout);
510
- }
511
- return self.layout(l);
512
- })
513
- .define_method(
514
- "device",
515
- *[](torch::TensorOptions& self, std::string device) {
516
- torch::Device d(device);
517
- return self.device(d);
518
- })
519
- .define_method(
520
- "requires_grad",
521
- *[](torch::TensorOptions& self, bool requires_grad) {
522
- return self.requires_grad(requires_grad);
523
- });
524
-
525
- Module rb_mInit = define_module_under(rb_mNN, "Init")
526
- .add_handler<torch::Error>(handle_error)
527
- .define_singleton_method(
528
- "_calculate_gain",
529
- *[](NonlinearityType nonlinearity, double param) {
530
- return torch::nn::init::calculate_gain(nonlinearity, param);
531
- })
532
- .define_singleton_method(
533
- "_uniform!",
534
- *[](Tensor tensor, double low, double high) {
535
- return torch::nn::init::uniform_(tensor, low, high);
536
- })
537
- .define_singleton_method(
538
- "_normal!",
539
- *[](Tensor tensor, double mean, double std) {
540
- return torch::nn::init::normal_(tensor, mean, std);
541
- })
542
- .define_singleton_method(
543
- "_constant!",
544
- *[](Tensor tensor, Scalar value) {
545
- return torch::nn::init::constant_(tensor, value);
546
- })
547
- .define_singleton_method(
548
- "_ones!",
549
- *[](Tensor tensor) {
550
- return torch::nn::init::ones_(tensor);
551
- })
552
- .define_singleton_method(
553
- "_zeros!",
554
- *[](Tensor tensor) {
555
- return torch::nn::init::zeros_(tensor);
556
- })
557
- .define_singleton_method(
558
- "_eye!",
559
- *[](Tensor tensor) {
560
- return torch::nn::init::eye_(tensor);
561
- })
562
- .define_singleton_method(
563
- "_dirac!",
564
- *[](Tensor tensor) {
565
- return torch::nn::init::dirac_(tensor);
566
- })
567
- .define_singleton_method(
568
- "_xavier_uniform!",
569
- *[](Tensor tensor, double gain) {
570
- return torch::nn::init::xavier_uniform_(tensor, gain);
571
- })
572
- .define_singleton_method(
573
- "_xavier_normal!",
574
- *[](Tensor tensor, double gain) {
575
- return torch::nn::init::xavier_normal_(tensor, gain);
576
- })
577
- .define_singleton_method(
578
- "_kaiming_uniform!",
579
- *[](Tensor tensor, double a, FanModeType mode, NonlinearityType nonlinearity) {
580
- return torch::nn::init::kaiming_uniform_(tensor, a, mode, nonlinearity);
581
- })
582
- .define_singleton_method(
583
- "_kaiming_normal!",
584
- *[](Tensor tensor, double a, FanModeType mode, NonlinearityType nonlinearity) {
585
- return torch::nn::init::kaiming_normal_(tensor, a, mode, nonlinearity);
586
- })
587
- .define_singleton_method(
588
- "_orthogonal!",
589
- *[](Tensor tensor, double gain) {
590
- return torch::nn::init::orthogonal_(tensor, gain);
591
- })
592
- .define_singleton_method(
593
- "_sparse!",
594
- *[](Tensor tensor, double sparsity, double std) {
595
- return torch::nn::init::sparse_(tensor, sparsity, std);
596
- });
597
-
598
- Class rb_cParameter = define_class_under<Parameter, torch::Tensor>(rb_mNN, "Parameter")
599
- .add_handler<torch::Error>(handle_error)
600
- .define_method(
601
- "grad",
602
- *[](Parameter& self) {
603
- auto grad = self.grad();
604
- return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
605
- })
606
- .define_method(
607
- "grad=",
608
- *[](Parameter& self, torch::Tensor& grad) {
609
- self.mutable_grad() = grad;
610
- });
15
+ auto m = Rice::define_module("Torch");
611
16
 
612
- Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")
613
- .add_handler<torch::Error>(handle_error)
614
- .define_constructor(Constructor<torch::Device, std::string>())
615
- .define_method("index", &torch::Device::index)
616
- .define_method("index?", &torch::Device::has_index)
617
- .define_method(
618
- "type",
619
- *[](torch::Device& self) {
620
- std::stringstream s;
621
- s << self.type();
622
- return s.str();
623
- });
17
+ // keep this order
18
+ init_torch(m);
19
+ init_tensor(m);
20
+ init_nn(m);
624
21
 
625
- Module rb_mCUDA = define_module_under(rb_mTorch, "CUDA")
626
- .add_handler<torch::Error>(handle_error)
627
- .define_singleton_method("available?", &torch::cuda::is_available)
628
- .define_singleton_method("device_count", &torch::cuda::device_count)
629
- .define_singleton_method("manual_seed", &torch::cuda::manual_seed)
630
- .define_singleton_method("manual_seed_all", &torch::cuda::manual_seed_all);
22
+ init_cuda(m);
23
+ init_device(m);
24
+ init_ivalue(m);
25
+ init_random(m);
631
26
  }