torch-rb 0.5.2 → 0.8.1

data/ext/torch/backends.cpp ADDED
@@ -0,0 +1,17 @@
+ #include <torch/torch.h>
+
+ #include <rice/rice.hpp>
+
+ #include "utils.h"
+
+ void init_backends(Rice::Module& m) {
+   auto rb_mBackends = Rice::define_module_under(m, "Backends");
+
+   Rice::define_module_under(rb_mBackends, "OpenMP")
+     .add_handler<torch::Error>(handle_error)
+     .define_singleton_function("available?", &torch::hasOpenMP);
+
+   Rice::define_module_under(rb_mBackends, "MKL")
+     .add_handler<torch::Error>(handle_error)
+     .define_singleton_function("available?", &torch::hasMKL);
+ }
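
For context: the two functions bound above are ATen's capability queries, exposed to Ruby as Torch::Backends::OpenMP.available? and Torch::Backends::MKL.available?. A minimal standalone C++ sketch (not part of this diff) that prints the same information:

// standalone sketch, not part of the gem
#include <torch/torch.h>
#include <iostream>

int main() {
  // same calls the bindings above wrap
  std::cout << "OpenMP: " << (torch::hasOpenMP() ? "yes" : "no") << "\n";
  std::cout << "MKL: " << (torch::hasMKL() ? "yes" : "no") << "\n";
  return 0;
}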
data/ext/torch/cuda.cpp ADDED
@@ -0,0 +1,14 @@
+ #include <torch/torch.h>
+
+ #include <rice/rice.hpp>
+
+ #include "utils.h"
+
+ void init_cuda(Rice::Module& m) {
+   Rice::define_module_under(m, "CUDA")
+     .add_handler<torch::Error>(handle_error)
+     .define_singleton_function("available?", &torch::cuda::is_available)
+     .define_singleton_function("device_count", &torch::cuda::device_count)
+     .define_singleton_function("manual_seed", &torch::cuda::manual_seed)
+     .define_singleton_function("manual_seed_all", &torch::cuda::manual_seed_all);
+ }
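
The CUDA module binds four libtorch calls directly. A standalone sketch (not part of this diff) showing the underlying C++ API; the seed value 42 is an arbitrary example:

// standalone sketch, not part of the gem
#include <torch/torch.h>
#include <iostream>

int main() {
  if (torch::cuda::is_available()) {
    std::cout << torch::cuda::device_count() << " CUDA device(s)\n";
    torch::cuda::manual_seed(42);      // seeds the current device
    torch::cuda::manual_seed_all(42);  // seeds every visible device
  } else {
    std::cout << "CUDA not available\n";
  }
  return 0;
}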
data/ext/torch/device.cpp ADDED
@@ -0,0 +1,28 @@
+ #include <torch/torch.h>
+
+ #include <rice/rice.hpp>
+
+ #include "utils.h"
+
+ void init_device(Rice::Module& m) {
+   Rice::define_class_under<torch::Device>(m, "Device")
+     .add_handler<torch::Error>(handle_error)
+     .define_constructor(Rice::Constructor<torch::Device, const std::string&>())
+     .define_method(
+       "index",
+       [](torch::Device& self) {
+         return self.index();
+       })
+     .define_method(
+       "index?",
+       [](torch::Device& self) {
+         return self.has_index();
+       })
+     .define_method(
+       "type",
+       [](torch::Device& self) {
+         std::stringstream s;
+         s << self.type();
+         return s.str();
+       });
+ }
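
These lambdas wrap torch::Device, which parses a device string into a type and an optional index. A standalone sketch (not part of this diff); "cuda:1" is an arbitrary example string:

// standalone sketch, not part of the gem
#include <torch/torch.h>
#include <iostream>
#include <sstream>

int main() {
  torch::Device d("cuda:1");                          // same string constructor the binding exposes
  std::cout << d.has_index() << "\n";                 // 1 -> Ruby index? returns true
  std::cout << static_cast<int>(d.index()) << "\n";   // 1 -> Ruby index (DeviceIndex is int8_t)
  std::stringstream s;
  s << d.type();                                      // streamed exactly as in the "type" lambda
  std::cout << s.str() << "\n";                       // "cuda"
  return 0;
}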
data/ext/torch/ext.cpp CHANGED
@@ -1,622 +1,43 @@
- #include <sstream>
-
  #include <torch/torch.h>

- #include <rice/Array.hpp>
- #include <rice/Class.hpp>
- #include <rice/Constructor.hpp>
- #include <rice/Hash.hpp>
-
- #include "templates.h"
- #include "utils.h"
-
- // generated with:
- // rake generate:functions
- #include "torch_functions.h"
- #include "tensor_functions.h"
- #include "nn_functions.h"
-
- using namespace Rice;
- using torch::indexing::TensorIndex;
-
- // need to make a distinction between parameters and tensors
- class Parameter: public torch::autograd::Variable {
-   public:
-     Parameter(Tensor&& t) : torch::autograd::Variable(t) { }
- };
-
- void handle_error(torch::Error const & ex)
- {
-   throw Exception(rb_eRuntimeError, ex.what_without_backtrace());
- }
-
- Class rb_cTensor;
+ #include <rice/rice.hpp>

- std::vector<TensorIndex> index_vector(Array a) {
-   Object obj;
+ void init_fft(Rice::Module& m);
+ void init_linalg(Rice::Module& m);
+ void init_nn(Rice::Module& m);
+ void init_special(Rice::Module& m);
+ void init_tensor(Rice::Module& m, Rice::Class& c, Rice::Class& rb_cTensorOptions);
+ void init_torch(Rice::Module& m);

-   std::vector<TensorIndex> indices;
-   indices.reserve(a.size());
-
-   for (size_t i = 0; i < a.size(); i++) {
-     obj = a[i];
-
-     if (obj.is_instance_of(rb_cInteger)) {
-       indices.push_back(from_ruby<int64_t>(obj));
-     } else if (obj.is_instance_of(rb_cRange)) {
-       torch::optional<int64_t> start_index = from_ruby<int64_t>(obj.call("begin"));
-       torch::optional<int64_t> stop_index = -1;
-
-       Object end = obj.call("end");
-       if (!end.is_nil()) {
-         stop_index = from_ruby<int64_t>(end);
-       }
-
-       Object exclude_end = obj.call("exclude_end?");
-       if (!exclude_end) {
-         if (stop_index.value() == -1) {
-           stop_index = torch::nullopt;
-         } else {
-           stop_index = stop_index.value() + 1;
-         }
-       }
-
-       indices.push_back(torch::indexing::Slice(start_index, stop_index));
-     } else if (obj.is_instance_of(rb_cTensor)) {
-       indices.push_back(from_ruby<Tensor>(obj));
-     } else if (obj.is_nil()) {
-       indices.push_back(torch::indexing::None);
-     } else if (obj == True || obj == False) {
-       indices.push_back(from_ruby<bool>(obj));
-     } else {
-       throw Exception(rb_eArgError, "Unsupported index type: %s", rb_obj_classname(obj));
-     }
-   }
-   return indices;
- }
+ void init_backends(Rice::Module& m);
+ void init_cuda(Rice::Module& m);
+ void init_device(Rice::Module& m);
+ void init_ivalue(Rice::Module& m, Rice::Class& rb_cIValue);
+ void init_random(Rice::Module& m);

  extern "C"
  void Init_ext()
  {
-   Module rb_mTorch = define_module("Torch");
-   rb_mTorch.add_handler<torch::Error>(handle_error);
-   add_torch_functions(rb_mTorch);
-
-   rb_cTensor = define_class_under<torch::Tensor>(rb_mTorch, "Tensor");
-   rb_cTensor.add_handler<torch::Error>(handle_error);
-   add_tensor_functions(rb_cTensor);
-   THPVariableClass = rb_cTensor.value();
-
-   Module rb_mNN = define_module_under(rb_mTorch, "NN");
-   rb_mNN.add_handler<torch::Error>(handle_error);
-   add_nn_functions(rb_mNN);
-
-   Module rb_mRandom = define_module_under(rb_mTorch, "Random")
-     .add_handler<torch::Error>(handle_error)
-     .define_singleton_method(
-       "initial_seed",
-       *[]() {
-         return at::detail::getDefaultCPUGenerator().current_seed();
-       })
-     .define_singleton_method(
-       "seed",
-       *[]() {
-         // TODO set for CUDA when available
-         auto generator = at::detail::getDefaultCPUGenerator();
-         return generator.seed();
-       });
-
-   // https://pytorch.org/cppdocs/api/structc10_1_1_i_value.html
-   Class rb_cIValue = define_class_under<torch::IValue>(rb_mTorch, "IValue")
-     .add_handler<torch::Error>(handle_error)
-     .define_constructor(Constructor<torch::IValue>())
-     .define_method("bool?", &torch::IValue::isBool)
-     .define_method("bool_list?", &torch::IValue::isBoolList)
-     .define_method("capsule?", &torch::IValue::isCapsule)
-     .define_method("custom_class?", &torch::IValue::isCustomClass)
-     .define_method("device?", &torch::IValue::isDevice)
-     .define_method("double?", &torch::IValue::isDouble)
-     .define_method("double_list?", &torch::IValue::isDoubleList)
-     .define_method("future?", &torch::IValue::isFuture)
-     // .define_method("generator?", &torch::IValue::isGenerator)
-     .define_method("generic_dict?", &torch::IValue::isGenericDict)
-     .define_method("list?", &torch::IValue::isList)
-     .define_method("int?", &torch::IValue::isInt)
-     .define_method("int_list?", &torch::IValue::isIntList)
-     .define_method("module?", &torch::IValue::isModule)
-     .define_method("none?", &torch::IValue::isNone)
-     .define_method("object?", &torch::IValue::isObject)
-     .define_method("ptr_type?", &torch::IValue::isPtrType)
-     .define_method("py_object?", &torch::IValue::isPyObject)
-     .define_method("r_ref?", &torch::IValue::isRRef)
-     .define_method("scalar?", &torch::IValue::isScalar)
-     .define_method("string?", &torch::IValue::isString)
-     .define_method("tensor?", &torch::IValue::isTensor)
-     .define_method("tensor_list?", &torch::IValue::isTensorList)
-     .define_method("tuple?", &torch::IValue::isTuple)
-     .define_method(
-       "to_bool",
-       *[](torch::IValue& self) {
-         return self.toBool();
-       })
-     .define_method(
-       "to_double",
-       *[](torch::IValue& self) {
-         return self.toDouble();
-       })
-     .define_method(
-       "to_int",
-       *[](torch::IValue& self) {
-         return self.toInt();
-       })
-     .define_method(
-       "to_list",
-       *[](torch::IValue& self) {
-         auto list = self.toListRef();
-         Array obj;
-         for (auto& elem : list) {
-           obj.push(to_ruby<torch::IValue>(torch::IValue{elem}));
-         }
-         return obj;
-       })
-     .define_method(
-       "to_string_ref",
-       *[](torch::IValue& self) {
-         return self.toStringRef();
-       })
-     .define_method(
-       "to_tensor",
-       *[](torch::IValue& self) {
-         return self.toTensor();
-       })
-     .define_method(
-       "to_generic_dict",
-       *[](torch::IValue& self) {
-         auto dict = self.toGenericDict();
-         Hash obj;
-         for (auto& pair : dict) {
-           obj[to_ruby<torch::IValue>(torch::IValue{pair.key()})] = to_ruby<torch::IValue>(torch::IValue{pair.value()});
-         }
-         return obj;
-       })
-     .define_singleton_method(
-       "from_tensor",
-       *[](torch::Tensor& v) {
-         return torch::IValue(v);
-       })
-     // TODO create specialized list types?
-     .define_singleton_method(
-       "from_list",
-       *[](Array obj) {
-         c10::impl::GenericList list(c10::AnyType::get());
-         for (auto entry : obj) {
-           list.push_back(from_ruby<torch::IValue>(entry));
-         }
-         return torch::IValue(list);
-       })
-     .define_singleton_method(
-       "from_string",
-       *[](String v) {
-         return torch::IValue(v.str());
-       })
-     .define_singleton_method(
-       "from_int",
-       *[](int64_t v) {
-         return torch::IValue(v);
-       })
-     .define_singleton_method(
-       "from_double",
-       *[](double v) {
-         return torch::IValue(v);
-       })
-     .define_singleton_method(
-       "from_bool",
-       *[](bool v) {
-         return torch::IValue(v);
-       })
-     // see https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/python/pybind_utils.h
-     // createGenericDict and toIValue
-     .define_singleton_method(
-       "from_dict",
-       *[](Hash obj) {
-         auto key_type = c10::AnyType::get();
-         auto value_type = c10::AnyType::get();
-         c10::impl::GenericDict elems(key_type, value_type);
-         elems.reserve(obj.size());
-         for (auto entry : obj) {
-           elems.insert(from_ruby<torch::IValue>(entry.first), from_ruby<torch::IValue>((Object) entry.second));
-         }
-         return torch::IValue(std::move(elems));
-       });
-
-   rb_mTorch.define_singleton_method(
-     "grad_enabled?",
-     *[]() {
-       return torch::GradMode::is_enabled();
-     })
-     .define_singleton_method(
-       "_set_grad_enabled",
-       *[](bool enabled) {
-         torch::GradMode::set_enabled(enabled);
-       })
-     .define_singleton_method(
-       "manual_seed",
-       *[](uint64_t seed) {
-         return torch::manual_seed(seed);
-       })
-     // config
-     .define_singleton_method(
-       "show_config",
-       *[] {
-         return torch::show_config();
-       })
-     .define_singleton_method(
-       "parallel_info",
-       *[] {
-         return torch::get_parallel_info();
-       })
-     // begin operations
-     .define_singleton_method(
-       "_save",
-       *[](const torch::IValue &value) {
-         auto v = torch::pickle_save(value);
-         std::string str(v.begin(), v.end());
-         return str;
-       })
-     .define_singleton_method(
-       "_load",
-       *[](const std::string &s) {
-         std::vector<char> v;
-         std::copy(s.begin(), s.end(), std::back_inserter(v));
-         // https://github.com/pytorch/pytorch/issues/20356#issuecomment-567663701
-         return torch::pickle_load(v);
-       })
-     .define_singleton_method(
-       "_from_blob",
-       *[](String s, std::vector<int64_t> size, const torch::TensorOptions &options) {
-         void *data = const_cast<char *>(s.c_str());
-         return torch::from_blob(data, size, options);
-       })
-     .define_singleton_method(
-       "_tensor",
-       *[](Array a, std::vector<int64_t> size, const torch::TensorOptions &options) {
-         auto dtype = options.dtype();
-         torch::Tensor t;
-         if (dtype == torch::kBool) {
-           std::vector<uint8_t> vec;
-           for (size_t i = 0; i < a.size(); i++) {
-             vec.push_back(from_ruby<bool>(a[i]));
-           }
-           t = torch::tensor(vec, options);
-         } else {
-           std::vector<float> vec;
-           for (size_t i = 0; i < a.size(); i++) {
-             vec.push_back(from_ruby<float>(a[i]));
-           }
-           // hack for requires_grad error
-           if (options.requires_grad()) {
-             t = torch::tensor(vec, options.requires_grad(c10::nullopt));
-             t.set_requires_grad(true);
-           } else {
-             t = torch::tensor(vec, options);
-           }
-         }
-         return t.reshape(size);
-       });
-
-   rb_cTensor
-     .define_method("cuda?", &torch::Tensor::is_cuda)
-     .define_method("sparse?", &torch::Tensor::is_sparse)
-     .define_method("quantized?", &torch::Tensor::is_quantized)
-     .define_method("dim", &torch::Tensor::dim)
-     .define_method("numel", &torch::Tensor::numel)
-     .define_method("element_size", &torch::Tensor::element_size)
-     .define_method("requires_grad", &torch::Tensor::requires_grad)
-     // in C++ for performance
-     .define_method(
-       "shape",
-       *[](Tensor& self) {
-         Array a;
-         for (auto &size : self.sizes()) {
-           a.push(size);
-         }
-         return a;
-       })
-     .define_method(
-       "_strides",
-       *[](Tensor& self) {
-         Array a;
-         for (auto &stride : self.strides()) {
-           a.push(stride);
-         }
-         return a;
-       })
-     .define_method(
-       "_index",
-       *[](Tensor& self, Array indices) {
-         auto vec = index_vector(indices);
-         return self.index(vec);
-       })
-     .define_method(
-       "_index_put_custom",
-       *[](Tensor& self, Array indices, torch::Tensor& value) {
-         auto vec = index_vector(indices);
-         return self.index_put_(vec, value);
-       })
-     .define_method(
-       "contiguous?",
-       *[](Tensor& self) {
-         return self.is_contiguous();
-       })
-     .define_method(
-       "_requires_grad!",
-       *[](Tensor& self, bool requires_grad) {
-         return self.set_requires_grad(requires_grad);
-       })
-     .define_method(
-       "grad",
-       *[](Tensor& self) {
-         auto grad = self.grad();
-         return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
-       })
-     .define_method(
-       "grad=",
-       *[](Tensor& self, torch::Tensor& grad) {
-         self.mutable_grad() = grad;
-       })
-     .define_method(
-       "_dtype",
-       *[](Tensor& self) {
-         return (int) at::typeMetaToScalarType(self.dtype());
-       })
-     .define_method(
-       "_type",
-       *[](Tensor& self, int dtype) {
-         return self.toType((torch::ScalarType) dtype);
-       })
-     .define_method(
-       "_layout",
-       *[](Tensor& self) {
-         std::stringstream s;
-         s << self.layout();
-         return s.str();
-       })
-     .define_method(
-       "device",
-       *[](Tensor& self) {
-         std::stringstream s;
-         s << self.device();
-         return s.str();
-       })
-     .define_method(
-       "_data_str",
-       *[](Tensor& self) {
-         Tensor tensor = self;
-
-         // move to CPU to get data
-         if (tensor.device().type() != torch::kCPU) {
-           torch::Device device("cpu");
-           tensor = tensor.to(device);
-         }
-
-         if (!tensor.is_contiguous()) {
-           tensor = tensor.contiguous();
-         }
-
-         auto data_ptr = (const char *) tensor.data_ptr();
-         return std::string(data_ptr, tensor.numel() * tensor.element_size());
-       })
-     // for TorchVision
-     .define_method(
-       "_data_ptr",
-       *[](Tensor& self) {
-         return reinterpret_cast<uintptr_t>(self.data_ptr());
-       })
-     // TODO figure out a better way to do this
-     .define_method(
-       "_flat_data",
-       *[](Tensor& self) {
-         Tensor tensor = self;
-
-         // move to CPU to get data
-         if (tensor.device().type() != torch::kCPU) {
-           torch::Device device("cpu");
-           tensor = tensor.to(device);
-         }
-
-         Array a;
-         auto dtype = tensor.dtype();
-
-         Tensor view = tensor.reshape({tensor.numel()});
-
-         // TODO DRY if someone knows C++
-         if (dtype == torch::kByte) {
-           for (int i = 0; i < tensor.numel(); i++) {
-             a.push(view[i].item().to<uint8_t>());
-           }
-         } else if (dtype == torch::kChar) {
-           for (int i = 0; i < tensor.numel(); i++) {
-             a.push(to_ruby<int>(view[i].item().to<int8_t>()));
-           }
-         } else if (dtype == torch::kShort) {
-           for (int i = 0; i < tensor.numel(); i++) {
-             a.push(view[i].item().to<int16_t>());
-           }
-         } else if (dtype == torch::kInt) {
-           for (int i = 0; i < tensor.numel(); i++) {
-             a.push(view[i].item().to<int32_t>());
-           }
-         } else if (dtype == torch::kLong) {
-           for (int i = 0; i < tensor.numel(); i++) {
-             a.push(view[i].item().to<int64_t>());
-           }
-         } else if (dtype == torch::kFloat) {
-           for (int i = 0; i < tensor.numel(); i++) {
-             a.push(view[i].item().to<float>());
-           }
-         } else if (dtype == torch::kDouble) {
-           for (int i = 0; i < tensor.numel(); i++) {
-             a.push(view[i].item().to<double>());
-           }
-         } else if (dtype == torch::kBool) {
-           for (int i = 0; i < tensor.numel(); i++) {
-             a.push(view[i].item().to<bool>() ? True : False);
-           }
-         } else {
-           throw std::runtime_error("Unsupported type");
-         }
-         return a;
-       })
-     .define_method(
-       "_to",
-       *[](Tensor& self, torch::Device device, int dtype, bool non_blocking, bool copy) {
-         return self.to(device, (torch::ScalarType) dtype, non_blocking, copy);
-       })
-     .define_singleton_method(
-       "_make_subclass",
-       *[](Tensor& rd, bool requires_grad) {
-         auto data = rd.detach();
-         data.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(true);
-         auto var = data.set_requires_grad(requires_grad);
-         return Parameter(std::move(var));
-       });
-
-   Class rb_cTensorOptions = define_class_under<torch::TensorOptions>(rb_mTorch, "TensorOptions")
-     .add_handler<torch::Error>(handle_error)
-     .define_constructor(Constructor<torch::TensorOptions>())
-     .define_method(
-       "dtype",
-       *[](torch::TensorOptions& self, int dtype) {
-         return self.dtype((torch::ScalarType) dtype);
-       })
-     .define_method(
-       "layout",
-       *[](torch::TensorOptions& self, std::string layout) {
-         torch::Layout l;
-         if (layout == "strided") {
-           l = torch::kStrided;
-         } else if (layout == "sparse") {
-           l = torch::kSparse;
-           throw std::runtime_error("Sparse layout not supported yet");
-         } else {
-           throw std::runtime_error("Unsupported layout: " + layout);
-         }
-         return self.layout(l);
-       })
-     .define_method(
-       "device",
-       *[](torch::TensorOptions& self, std::string device) {
-         torch::Device d(device);
-         return self.device(d);
-       })
-     .define_method(
-       "requires_grad",
-       *[](torch::TensorOptions& self, bool requires_grad) {
-         return self.requires_grad(requires_grad);
-       });
-
-   Module rb_mInit = define_module_under(rb_mNN, "Init")
-     .add_handler<torch::Error>(handle_error)
-     .define_singleton_method(
-       "_calculate_gain",
-       *[](NonlinearityType nonlinearity, double param) {
-         return torch::nn::init::calculate_gain(nonlinearity, param);
-       })
-     .define_singleton_method(
-       "_uniform!",
-       *[](Tensor tensor, double low, double high) {
-         return torch::nn::init::uniform_(tensor, low, high);
-       })
-     .define_singleton_method(
-       "_normal!",
-       *[](Tensor tensor, double mean, double std) {
-         return torch::nn::init::normal_(tensor, mean, std);
-       })
-     .define_singleton_method(
-       "_constant!",
-       *[](Tensor tensor, Scalar value) {
-         return torch::nn::init::constant_(tensor, value);
-       })
-     .define_singleton_method(
-       "_ones!",
-       *[](Tensor tensor) {
-         return torch::nn::init::ones_(tensor);
-       })
-     .define_singleton_method(
-       "_zeros!",
-       *[](Tensor tensor) {
-         return torch::nn::init::zeros_(tensor);
-       })
-     .define_singleton_method(
-       "_eye!",
-       *[](Tensor tensor) {
-         return torch::nn::init::eye_(tensor);
-       })
-     .define_singleton_method(
-       "_dirac!",
-       *[](Tensor tensor) {
-         return torch::nn::init::dirac_(tensor);
-       })
-     .define_singleton_method(
-       "_xavier_uniform!",
-       *[](Tensor tensor, double gain) {
-         return torch::nn::init::xavier_uniform_(tensor, gain);
-       })
-     .define_singleton_method(
-       "_xavier_normal!",
-       *[](Tensor tensor, double gain) {
-         return torch::nn::init::xavier_normal_(tensor, gain);
-       })
-     .define_singleton_method(
-       "_kaiming_uniform!",
-       *[](Tensor tensor, double a, FanModeType mode, NonlinearityType nonlinearity) {
-         return torch::nn::init::kaiming_uniform_(tensor, a, mode, nonlinearity);
-       })
-     .define_singleton_method(
-       "_kaiming_normal!",
-       *[](Tensor tensor, double a, FanModeType mode, NonlinearityType nonlinearity) {
-         return torch::nn::init::kaiming_normal_(tensor, a, mode, nonlinearity);
-       })
-     .define_singleton_method(
-       "_orthogonal!",
-       *[](Tensor tensor, double gain) {
-         return torch::nn::init::orthogonal_(tensor, gain);
-       })
-     .define_singleton_method(
-       "_sparse!",
-       *[](Tensor tensor, double sparsity, double std) {
-         return torch::nn::init::sparse_(tensor, sparsity, std);
-       });
-
-   Class rb_cParameter = define_class_under<Parameter, torch::Tensor>(rb_mNN, "Parameter")
-     .add_handler<torch::Error>(handle_error)
-     .define_method(
-       "grad",
-       *[](Parameter& self) {
-         auto grad = self.grad();
-         return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
-       })
-     .define_method(
-       "grad=",
-       *[](Parameter& self, torch::Tensor& grad) {
-         self.mutable_grad() = grad;
-       });
-
-   Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")
-     .add_handler<torch::Error>(handle_error)
-     .define_constructor(Constructor<torch::Device, std::string>())
-     .define_method("index", &torch::Device::index)
-     .define_method("index?", &torch::Device::has_index)
-     .define_method(
-       "type",
-       *[](torch::Device& self) {
-         std::stringstream s;
-         s << self.type();
-         return s.str();
-       });
-
-   Module rb_mCUDA = define_module_under(rb_mTorch, "CUDA")
-     .add_handler<torch::Error>(handle_error)
-     .define_singleton_method("available?", &torch::cuda::is_available)
-     .define_singleton_method("device_count", &torch::cuda::device_count);
+   auto m = Rice::define_module("Torch");
+
+   // need to define certain classes up front to keep Rice happy
+   auto rb_cIValue = Rice::define_class_under<torch::IValue>(m, "IValue")
+     .define_constructor(Rice::Constructor<torch::IValue>());
+   auto rb_cTensor = Rice::define_class_under<torch::Tensor>(m, "Tensor");
+   auto rb_cTensorOptions = Rice::define_class_under<torch::TensorOptions>(m, "TensorOptions")
+     .define_constructor(Rice::Constructor<torch::TensorOptions>());
+
+   // keep this order
+   init_torch(m);
+   init_tensor(m, rb_cTensor, rb_cTensorOptions);
+   init_nn(m);
+   init_fft(m);
+   init_linalg(m);
+   init_special(m);
+
+   init_backends(m);
+   init_cuda(m);
+   init_device(m);
+   init_ivalue(m, rb_cIValue);
+   init_random(m);
  }
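
The other initializers declared above (init_torch, init_tensor, init_nn, init_fft, init_linalg, init_special, init_ivalue, init_random) live in files outside this section. As a hypothetical sketch of the pattern, the Torch::Random block removed from ext.cpp would move into its own file in the style of the ADDED files, with plain lambdas and define_singleton_function replacing the old *[] / define_singleton_method idiom; the file name and exact contents here are assumptions:

// hypothetical ext/torch/random.cpp, reconstructed from the removed block
#include <torch/torch.h>

#include <rice/rice.hpp>

#include "utils.h"

void init_random(Rice::Module& m) {
  Rice::define_module_under(m, "Random")
    .add_handler<torch::Error>(handle_error)
    .define_singleton_function(
      "initial_seed",
      []() {
        return at::detail::getDefaultCPUGenerator().current_seed();
      })
    .define_singleton_function(
      "seed",
      []() {
        // TODO set for CUDA when available
        auto generator = at::detail::getDefaultCPUGenerator();
        return generator.seed();
      });
}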