torch-rb 0.1.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. checksums.yaml +7 -0
  2. data/CHANGELOG.md +28 -0
  3. data/LICENSE.txt +46 -0
  4. data/README.md +426 -0
  5. data/ext/torch/ext.cpp +839 -0
  6. data/ext/torch/extconf.rb +25 -0
  7. data/lib/torch-rb.rb +1 -0
  8. data/lib/torch.rb +422 -0
  9. data/lib/torch/ext.bundle +0 -0
  10. data/lib/torch/inspector.rb +85 -0
  11. data/lib/torch/nn/alpha_dropout.rb +9 -0
  12. data/lib/torch/nn/conv2d.rb +37 -0
  13. data/lib/torch/nn/convnd.rb +41 -0
  14. data/lib/torch/nn/dropout.rb +9 -0
  15. data/lib/torch/nn/dropout2d.rb +9 -0
  16. data/lib/torch/nn/dropout3d.rb +9 -0
  17. data/lib/torch/nn/dropoutnd.rb +15 -0
  18. data/lib/torch/nn/embedding.rb +52 -0
  19. data/lib/torch/nn/feature_alpha_dropout.rb +9 -0
  20. data/lib/torch/nn/functional.rb +100 -0
  21. data/lib/torch/nn/init.rb +30 -0
  22. data/lib/torch/nn/linear.rb +36 -0
  23. data/lib/torch/nn/module.rb +85 -0
  24. data/lib/torch/nn/mse_loss.rb +13 -0
  25. data/lib/torch/nn/parameter.rb +14 -0
  26. data/lib/torch/nn/relu.rb +13 -0
  27. data/lib/torch/nn/sequential.rb +29 -0
  28. data/lib/torch/optim/adadelta.rb +57 -0
  29. data/lib/torch/optim/adagrad.rb +71 -0
  30. data/lib/torch/optim/adam.rb +81 -0
  31. data/lib/torch/optim/adamax.rb +68 -0
  32. data/lib/torch/optim/adamw.rb +82 -0
  33. data/lib/torch/optim/asgd.rb +65 -0
  34. data/lib/torch/optim/lr_scheduler/lr_scheduler.rb +33 -0
  35. data/lib/torch/optim/lr_scheduler/step_lr.rb +17 -0
  36. data/lib/torch/optim/optimizer.rb +62 -0
  37. data/lib/torch/optim/rmsprop.rb +76 -0
  38. data/lib/torch/optim/rprop.rb +68 -0
  39. data/lib/torch/optim/sgd.rb +60 -0
  40. data/lib/torch/tensor.rb +196 -0
  41. data/lib/torch/utils/data/data_loader.rb +27 -0
  42. data/lib/torch/utils/data/tensor_dataset.rb +22 -0
  43. data/lib/torch/version.rb +3 -0
  44. metadata +169 -0
data/ext/torch/ext.cpp
@@ -0,0 +1,839 @@
#include <sstream>

#include <torch/torch.h>

#include <rice/Array.hpp>
#include <rice/Class.hpp>
#include <rice/Constructor.hpp>

using namespace Rice;

template<>
inline
long long from_ruby<long long>(Object x)
{
  return NUM2LL(x);
}

template<>
inline
Object to_ruby<long long>(long long const & x)
{
  return LL2NUM(x);
}

template<>
inline
unsigned long long from_ruby<unsigned long long>(Object x)
{
  return NUM2ULL(x);
}

template<>
inline
Object to_ruby<unsigned long long>(unsigned long long const & x)
{
  return ULL2NUM(x);
}

template<>
inline
short from_ruby<short>(Object x)
{
  return NUM2SHORT(x);
}

template<>
inline
Object to_ruby<short>(short const & x)
{
  return INT2NUM(x);
}

template<>
inline
unsigned short from_ruby<unsigned short>(Object x)
{
  return NUM2USHORT(x);
}

template<>
inline
Object to_ruby<unsigned short>(unsigned short const & x)
{
  return UINT2NUM(x);
}
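
// For illustration (a sketch, not part of the gem): Rice consults these
// from_ruby/to_ruby specializations whenever a bound method takes or returns
// one of these integer types, so a manual roundtrip looks like:
//
//   Object o = to_ruby<long long>(1LL << 40);  // Ruby Integer
//   long long v = from_ruby<long long>(o);     // back to C++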

// need to wrap torch::IntArrayRef() since
// it doesn't own underlying data
class IntArrayRef {
  std::vector<int64_t> vec;
  public:
    IntArrayRef(Object o) {
      Array a = Array(o);
      for (size_t i = 0; i < a.size(); i++) {
        vec.push_back(from_ruby<int64_t>(a[i]));
      }
    }
    operator torch::IntArrayRef() {
      return torch::IntArrayRef(vec);
    }
};

template<>
inline
IntArrayRef from_ruby<IntArrayRef>(Object x)
{
  return IntArrayRef(x);
}
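
// For illustration: torch::IntArrayRef is a non-owning view, so the wrapper
// above keeps the converted values alive in its own std::vector for the
// duration of the call. A sketch of the flow for a Ruby array such as [2, 3]
// (obj here is a hypothetical stand-in for the received Ruby Object):
//
//   IntArrayRef size(obj);                 // copies elements into vec
//   torch::Tensor t = torch::ones(size);   // implicit conversion to the view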

// for now
class Scalar {
  torch::Scalar value;
  public:
    Scalar(Object o) {
      // TODO cast based on Ruby type
      if (o.rb_type() == T_FIXNUM) {
        value = torch::Scalar(from_ruby<int64_t>(o));
      } else {
        value = torch::Scalar(from_ruby<float>(o));
      }
    }
    operator torch::Scalar() {
      return value;
    }
};

template<>
inline
Scalar from_ruby<Scalar>(Object x)
{
  return Scalar(x);
}
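
// Note (observation on the dispatch above): a Ruby Integer (T_FIXNUM) becomes
// an integral torch::Scalar, while any other type, including Float, is
// coerced through from_ruby<float> -- hence the TODO about casting based on
// the actual Ruby type.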

class TensorList {
  std::vector<torch::Tensor> vec;
  public:
    TensorList(Object o) {
      Array a = Array(o);
      for (size_t i = 0; i < a.size(); i++) {
        vec.push_back(from_ruby<torch::Tensor>(a[i]));
      }
    }
    operator torch::TensorList() {
      return torch::TensorList(vec);
    }
};

template<>
inline
TensorList from_ruby<TensorList>(Object x)
{
  return TensorList(x);
}

extern "C"
void Init_ext()
{
  Module rb_mTorch = define_module("Torch")
    .define_singleton_method(
      "grad_enabled?",
      *[]() {
        return torch::GradMode::is_enabled();
      })
    .define_singleton_method(
      "_set_grad_enabled",
      *[](bool enabled) {
        torch::GradMode::set_enabled(enabled);
      })
    .define_singleton_method(
      "floating_point?",
      *[](torch::Tensor& input) {
        return torch::is_floating_point(input);
      })
    .define_singleton_method(
      "manual_seed",
      *[](uint64_t seed) {
        return torch::manual_seed(seed);
      })
    // begin tensor creation
    .define_singleton_method(
      "_arange",
      *[](Scalar start, Scalar end, Scalar step, const torch::TensorOptions &options) {
        return torch::arange(start, end, step, options);
      })
    .define_singleton_method(
      "_empty",
      *[](IntArrayRef size, const torch::TensorOptions &options) {
        return torch::empty(size, options);
      })
    .define_singleton_method(
      "_eye",
      *[](int64_t m, int64_t n, const torch::TensorOptions &options) {
        return torch::eye(m, n, options);
      })
    .define_singleton_method(
      "_full",
      *[](IntArrayRef size, Scalar fill_value, const torch::TensorOptions& options) {
        return torch::full(size, fill_value, options);
      })
    .define_singleton_method(
      "_linspace",
      *[](Scalar start, Scalar end, int64_t steps, const torch::TensorOptions& options) {
        return torch::linspace(start, end, steps, options);
      })
    .define_singleton_method(
      "_logspace",
      *[](Scalar start, Scalar end, int64_t steps, double base, const torch::TensorOptions& options) {
        return torch::logspace(start, end, steps, base, options);
      })
    .define_singleton_method(
      "_ones",
      *[](IntArrayRef size, const torch::TensorOptions &options) {
        return torch::ones(size, options);
      })
    .define_singleton_method(
      "_rand",
      *[](IntArrayRef size, const torch::TensorOptions &options) {
        return torch::rand(size, options);
      })
    .define_singleton_method(
      "_randint",
      *[](int64_t low, int64_t high, IntArrayRef size, const torch::TensorOptions &options) {
        return torch::randint(low, high, size, options);
      })
    .define_singleton_method(
      "_randn",
      *[](IntArrayRef size, const torch::TensorOptions &options) {
        return torch::randn(size, options);
      })
    .define_singleton_method(
      "_randperm",
      *[](int64_t n, const torch::TensorOptions &options) {
        return torch::randperm(n, options);
      })
    .define_singleton_method(
      "_zeros",
      *[](IntArrayRef size, const torch::TensorOptions &options) {
        return torch::zeros(size, options);
      })
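    // Apparent convention (an inference, worth verifying against lib/torch.rb):
    // underscore-prefixed singletons are low-level bindings, and the Ruby layer
    // presumably wraps them with friendlier signatures, e.g. Torch.ones(2, 3)
    // building a TensorOptions and calling Torch._ones([2, 3], options).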
    // begin operations
    .define_singleton_method(
      "_mean",
      *[](torch::Tensor& input) {
        return torch::mean(input);
      })
    .define_singleton_method(
      "_mean_dim",
      *[](torch::Tensor& input, int64_t dim, bool keepdim) {
        return torch::mean(input, dim, keepdim);
      })
    .define_singleton_method(
      "_sum",
      *[](torch::Tensor& input) {
        return torch::sum(input);
      })
    .define_singleton_method(
      "_sum_dim",
      *[](torch::Tensor& input, int64_t dim, bool keepdim) {
        return torch::sum(input, dim, keepdim);
      })
    .define_singleton_method(
      "_argmax",
      *[](torch::Tensor& input) {
        return torch::argmax(input);
      })
    .define_singleton_method(
      "_argmax_dim",
      *[](torch::Tensor& input, int64_t dim, bool keepdim) {
        return torch::argmax(input, dim, keepdim);
      })
    .define_singleton_method(
      "_cat",
      *[](TensorList tensors, int64_t dim) {
        return torch::cat(tensors, dim);
      })
    .define_singleton_method(
      "_norm",
      *[](torch::Tensor& input) {
        return torch::norm(input);
      })
    .define_singleton_method(
      "_min",
      *[](torch::Tensor& input) {
        return torch::min(input);
      })
    .define_singleton_method(
      "_max",
      *[](torch::Tensor& input) {
        return torch::max(input);
      })
    .define_singleton_method(
      "_max_out",
      *[](torch::Tensor &max, torch::Tensor &max_indices, const torch::Tensor &input, int64_t dim, bool keepdim) {
        // TODO add return value
        torch::_max_out(max, max_indices, input, dim, keepdim);
      })
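    // A sketch of the return value the TODO above asks for (assuming Rice's
    // to_ruby works for the Tensor class registered below):
    //
    //   Array result;
    //   result.push(to_ruby<torch::Tensor>(max));
    //   result.push(to_ruby<torch::Tensor>(max_indices));
    //   return result;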
    .define_singleton_method(
      "_sqrt",
      *[](torch::Tensor& input) {
        return torch::sqrt(input);
      })
    .define_singleton_method(
      "_exp",
      *[](torch::Tensor& input) {
        return torch::exp(input);
      })
    .define_singleton_method(
      "_log",
      *[](torch::Tensor& input) {
        return torch::log(input);
      })
    .define_singleton_method(
      "_sign",
      *[](torch::Tensor& input) {
        return torch::sign(input);
      })
    .define_singleton_method(
      "_unsqueeze",
      *[](torch::Tensor& input, int64_t dim) {
        return torch::unsqueeze(input, dim);
      })
    .define_singleton_method(
      "_dot",
      *[](torch::Tensor& input, torch::Tensor& tensor) {
        return torch::dot(input, tensor);
      })
    .define_singleton_method(
      "_matmul",
      *[](torch::Tensor& input, torch::Tensor& other) {
        return torch::matmul(input, other);
      })
    .define_singleton_method(
      "_eq",
      *[](torch::Tensor& input, torch::Tensor& other) {
        return torch::eq(input, other);
      })
    .define_singleton_method(
      "_gt",
      // TODO support tensors
      *[](torch::Tensor& input, Scalar other) {
        return torch::gt(input, other);
      })
    .define_singleton_method(
      "_lt",
      // TODO support tensors
      *[](torch::Tensor& input, Scalar other) {
        return torch::lt(input, other);
      })
    .define_singleton_method(
      "_add",
      *[](torch::Tensor& input, torch::Tensor& other) {
        return torch::add(input, other);
      })
    .define_singleton_method(
      "_add_scalar",
      *[](torch::Tensor& input, Scalar other) {
        return torch::add(input, other);
      })
    .define_singleton_method(
      "_add_out",
      *[](torch::Tensor& out, torch::Tensor& input, torch::Tensor& other) {
        return torch::add_out(out, input, other);
      })
    .define_singleton_method(
      "_sub",
      *[](torch::Tensor& input, torch::Tensor& other) {
        return torch::sub(input, other);
      })
    .define_singleton_method(
      "_sub_scalar",
      *[](torch::Tensor& input, Scalar other) {
        return torch::sub(input, other);
      })
    .define_singleton_method(
      "_mul",
      *[](torch::Tensor& input, torch::Tensor& other) {
        return torch::mul(input, other);
      })
    .define_singleton_method(
      "_mul_scalar",
      *[](torch::Tensor& input, Scalar other) {
        return torch::mul(input, other);
      })
    .define_singleton_method(
      "_div",
      *[](torch::Tensor& input, torch::Tensor& other) {
        return torch::div(input, other);
      })
    .define_singleton_method(
      "_div_scalar",
      *[](torch::Tensor& input, Scalar other) {
        return torch::div(input, other);
      })
    .define_singleton_method(
      "_remainder",
      *[](torch::Tensor& input, torch::Tensor& other) {
        return torch::remainder(input, other);
      })
    .define_singleton_method(
      "_remainder_scalar",
      *[](torch::Tensor& input, Scalar other) {
        return torch::remainder(input, other);
      })
    .define_singleton_method(
      "_pow",
      *[](torch::Tensor& input, Scalar exponent) {
        return torch::pow(input, exponent);
      })
    .define_singleton_method(
      "_abs",
      *[](torch::Tensor& input) {
        return torch::abs(input);
      })
    .define_singleton_method(
      "_neg",
      *[](torch::Tensor& input) {
        return torch::neg(input);
      })
    .define_singleton_method(
      "_reshape",
      *[](torch::Tensor& input, IntArrayRef shape) {
        return torch::reshape(input, shape);
      })
    .define_singleton_method(
      "_flatten",
      *[](torch::Tensor& input, int64_t start_dim, int64_t end_dim) {
        return torch::flatten(input, start_dim, end_dim);
      })
    .define_singleton_method(
      "relu",
      *[](torch::Tensor& input) {
        return torch::relu(input);
      })
    .define_singleton_method(
      "conv2d",
      *[](torch::Tensor& input, torch::Tensor& weight, torch::Tensor& bias, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, int64_t groups) {
        return torch::conv2d(input, weight, bias, stride, padding, dilation, groups);
      })
    .define_singleton_method(
      "linear",
      *[](torch::Tensor& input, torch::Tensor& weight, torch::Tensor& bias) {
        return torch::linear(input, weight, bias);
      })
    .define_singleton_method(
      "max_pool2d",
      *[](torch::Tensor& input, IntArrayRef kernel_size) {
        return torch::max_pool2d(input, kernel_size);
      })
    .define_singleton_method(
      "avg_pool2d",
      *[](torch::Tensor& input, IntArrayRef kernel_size) {
        return torch::avg_pool2d(input, kernel_size);
      })
    .define_singleton_method(
      "_dropout",
      *[](torch::Tensor& input, float p, bool train) {
        return torch::dropout(input, p, train);
      })
    .define_singleton_method(
      "_dropout!",
      *[](torch::Tensor& input, float p, bool train) {
        return torch::dropout_(input, p, train);
      })
    .define_singleton_method(
      "_feature_dropout",
      *[](torch::Tensor& input, float p, bool train) {
        return torch::feature_dropout(input, p, train);
      })
    .define_singleton_method(
      "_feature_dropout!",
      *[](torch::Tensor& input, float p, bool train) {
        return torch::feature_dropout_(input, p, train);
      })
    .define_singleton_method(
      "_alpha_dropout",
      *[](torch::Tensor& input, float p, bool train) {
        return torch::alpha_dropout(input, p, train);
      })
    .define_singleton_method(
      "_alpha_dropout!",
      *[](torch::Tensor& input, float p, bool train) {
        return torch::alpha_dropout_(input, p, train);
      })
    .define_singleton_method(
      "_feature_alpha_dropout",
      *[](torch::Tensor& input, float p, bool train) {
        return torch::feature_alpha_dropout(input, p, train);
      })
    .define_singleton_method(
      "_feature_alpha_dropout!",
      *[](torch::Tensor& input, float p, bool train) {
        return torch::feature_alpha_dropout_(input, p, train);
      })
    .define_singleton_method(
      "_embedding",
      // weight and indices are swapped from Python interface
      *[](const torch::Tensor &indices, const torch::Tensor &weight, int64_t padding_idx, bool scale_grad_by_freq, bool sparse) {
        return torch::embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse);
      })
    .define_singleton_method(
      "mse_loss",
      *[](torch::Tensor& input, torch::Tensor& target, std::string reduction) {
        auto red = reduction == "mean" ? Reduction::Mean : Reduction::Sum;
        return torch::mse_loss(input, target, red);
      })
    .define_singleton_method(
      "nll_loss",
      *[](torch::Tensor& input, torch::Tensor& target, std::string reduction) {
        auto red = reduction == "mean" ? Reduction::Mean : Reduction::Sum;
        return torch::nll_loss(input, target, {}, red);
      })
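    // Note (observation): the reduction mapping above only distinguishes
    // "mean"; any other string, including "none", silently falls through to
    // Reduction::Sum. A stricter sketch:
    //
    //   auto red = reduction == "mean" ? Reduction::Mean
    //            : reduction == "sum"  ? Reduction::Sum
    //            : throw std::runtime_error("Unsupported reduction: " + reduction);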
    .define_singleton_method("numel", &torch::numel)
    .define_singleton_method(
      "_from_blob",
      *[](String s, IntArrayRef size, const torch::TensorOptions &options) {
        void *data = const_cast<char *>(s.c_str());
        return torch::from_blob(data, size, options);
      })
    .define_singleton_method(
      "_tensor",
      *[](Object o, IntArrayRef size, const torch::TensorOptions &options) {
        Array a = Array(o);
        std::vector<float> vec;
        for (size_t i = 0; i < a.size(); i++) {
          vec.push_back(from_ruby<float>(a[i]));
        }
        return torch::tensor(vec, options).reshape(size);
      });
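
  // Caution (a reading of the code above, worth verifying): torch::from_blob
  // neither copies nor owns the buffer, so the tensor returned by _from_blob
  // aliases the Ruby String's bytes and is only safe while that String stays
  // alive and unmodified; _tensor, by contrast, copies element by element
  // into a std::vector<float> before reshaping.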

  Class rb_cTensor = define_class_under<torch::Tensor>(rb_mTorch, "Tensor")
    .define_method("cuda?", &torch::Tensor::is_cuda)
    .define_method("distributed?", &torch::Tensor::is_distributed)
    .define_method("complex?", &torch::Tensor::is_complex)
    .define_method("floating_point?", &torch::Tensor::is_floating_point)
    .define_method("signed?", &torch::Tensor::is_signed)
    .define_method("sparse?", &torch::Tensor::is_sparse)
    .define_method("quantized?", &torch::Tensor::is_quantized)
    .define_method("dim", &torch::Tensor::dim)
    .define_method("element_size", &torch::Tensor::element_size)
    .define_method("requires_grad", &torch::Tensor::requires_grad)
    .define_method("view_as", &torch::Tensor::view_as)
    .define_method(
      "addcmul!",
      *[](torch::Tensor& self, Scalar value, const torch::Tensor & tensor1, const torch::Tensor & tensor2) {
        return self.addcmul_(tensor1, tensor2, value);
      })
    .define_method(
      "addcdiv!",
      *[](torch::Tensor& self, Scalar value, const torch::Tensor & tensor1, const torch::Tensor & tensor2) {
        return self.addcdiv_(tensor1, tensor2, value);
      })
    .define_method(
      "zero!",
      *[](torch::Tensor& self) {
        return self.zero_();
      })
    .define_method(
      "detach!",
      *[](torch::Tensor& self) {
        return self.detach_();
      })
    .define_method(
      "_select",
      *[](torch::Tensor& self, int64_t dim, int64_t index) {
        return self.select(dim, index);
      })
    .define_method(
      "_slice",
      *[](torch::Tensor& self, int64_t dim, int64_t start, int64_t end, int64_t step) {
        return self.slice(dim, start, end, step);
      })
    .define_method(
      "_requires_grad!",
      *[](torch::Tensor& self, bool requires_grad) {
        return self.set_requires_grad(requires_grad);
      })
    .define_method(
      "_backward",
      *[](torch::Tensor& self) {
        return self.backward();
      })
    .define_method(
      "_backward_gradient",
      *[](torch::Tensor& self, const torch::Tensor& gradient) {
        return self.backward(gradient);
      })
    .define_method(
      "grad",
      *[](torch::Tensor& self) {
        return self.grad();
      })
    .define_method(
      "_dtype",
      *[](torch::Tensor& self) {
        return (int) at::typeMetaToScalarType(self.dtype());
      })
    .define_method(
      "_type",
      *[](torch::Tensor& self, int dtype) {
        return self.toType((torch::ScalarType) dtype);
      })
    .define_method(
      "_layout",
      *[](torch::Tensor& self) {
        std::stringstream s;
        s << self.layout();
        return s.str();
      })
    .define_method(
      "device",
      *[](torch::Tensor& self) {
        std::stringstream s;
        s << self.device();
        return s.str();
      })
    .define_method(
      "_view",
      *[](torch::Tensor& self, IntArrayRef size) {
        return self.view(size);
      })
    .define_method(
      "resize_as!",
      *[](torch::Tensor& self, torch::Tensor& other) {
        return self.resize_as_(other);
      })
    .define_method(
      "fill!",
      *[](torch::Tensor& self, Scalar value) {
        return self.fill_(value);
      })
    .define_method(
      "_add!",
      *[](torch::Tensor& self, torch::Tensor& other) {
        return self.add_(other);
      })
    .define_method(
      "_add_alpha!",
      *[](torch::Tensor& self, torch::Tensor& other, Scalar alpha) {
        return self.add_(other, alpha);
      })
    .define_method(
      "_add_scalar!",
      *[](torch::Tensor& self, Scalar other) {
        return self.add_(other);
      })
    .define_method(
      "normal!",
      *[](torch::Tensor& self, double mean, double std) {
        return self.normal_(mean, std);
      })
    .define_method(
      "sub!",
      *[](torch::Tensor& self, torch::Tensor& other) {
        return self.sub_(other);
      })
    .define_method(
      "_mul!",
      *[](torch::Tensor& self, torch::Tensor& other) {
        return self.mul_(other);
      })
    .define_method(
      "_mul_scalar!",
      *[](torch::Tensor& self, Scalar other) {
        return self.mul_(other);
      })
    .define_method(
      "div!",
      *[](torch::Tensor& self, torch::Tensor& other) {
        return self.div_(other);
      })
    .define_method(
      "sqrt!",
      *[](torch::Tensor& self) {
        return self.sqrt_();
      })
    .define_method(
      "unsqueeze!",
      *[](torch::Tensor& self, int64_t dim) {
        return self.unsqueeze_(dim);
      })
    .define_method(
      "copy!",
      *[](torch::Tensor& self, torch::Tensor& src) {
        return self.copy_(src);
      })
    .define_method(
      "clone",
      *[](torch::Tensor& self) {
        return self.clone();
      })
    .define_method(
      "log_softmax",
      *[](torch::Tensor& self, int64_t dim) {
        return self.log_softmax(dim);
      })
    .define_method(
      "data",
      *[](torch::Tensor& self) {
        return self.data();
      })
    .define_method(
      "_data",
      *[](torch::Tensor& self) {
        Array a;
        auto dtype = self.dtype();

        // TODO DRY if someone knows C++
        if (dtype == torch::kByte) {
          uint8_t* data = self.data_ptr<uint8_t>();
          for (int i = 0; i < self.numel(); i++) {
            a.push(data[i]);
          }
        } else if (dtype == torch::kChar) {
          int8_t* data = self.data_ptr<int8_t>();
          for (int i = 0; i < self.numel(); i++) {
            a.push(to_ruby<int>(data[i]));
          }
        } else if (dtype == torch::kShort) {
          int16_t* data = self.data_ptr<int16_t>();
          for (int i = 0; i < self.numel(); i++) {
            a.push(data[i]);
          }
        } else if (dtype == torch::kInt) {
          int32_t* data = self.data_ptr<int32_t>();
          for (int i = 0; i < self.numel(); i++) {
            a.push(data[i]);
          }
        } else if (dtype == torch::kLong) {
          int64_t* data = self.data_ptr<int64_t>();
          for (int i = 0; i < self.numel(); i++) {
            a.push(data[i]);
          }
        } else if (dtype == torch::kFloat) {
          float* data = self.data_ptr<float>();
          for (int i = 0; i < self.numel(); i++) {
            a.push(data[i]);
          }
        } else if (dtype == torch::kDouble) {
          double* data = self.data_ptr<double>();
          for (int i = 0; i < self.numel(); i++) {
            a.push(data[i]);
          }
        } else if (dtype == torch::kBool) {
          bool* data = self.data_ptr<bool>();
          for (int i = 0; i < self.numel(); i++) {
            a.push(data[i] ? True : False);
          }
        } else {
          throw std::runtime_error("Unsupported type");
        }
        return a;
      })
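    // One way to DRY the branches above (a sketch; kChar and kBool would still
    // need their own branches for the to_ruby<int> and True/False handling):
    //
    //   template <typename T>
    //   Array flat_data(const torch::Tensor& self) {
    //     Array a;
    //     T* data = self.data_ptr<T>();
    //     for (int64_t i = 0; i < self.numel(); i++) {
    //       a.push(data[i]);
    //     }
    //     return a;
    //   }
    //
    // with the dispatch reduced to calls like flat_data<float>(self).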
    .define_method(
      "_size",
      *[](torch::Tensor& self, int i) {
        return self.size(i);
      })
    .define_method(
      "_to",
      *[](torch::Tensor& self, torch::Device device, int dtype, bool non_blocking, bool copy) {
        return self.to(device, (torch::ScalarType) dtype, non_blocking, copy);
      })
    .define_singleton_method(
      "_make_subclass",
      *[](torch::Tensor& rd, bool requires_grad) {
        auto data = torch::autograd::as_variable_ref(rd).detach();
        data.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(true);
        auto var = data.set_requires_grad(requires_grad);
        return torch::autograd::Variable(std::move(var));
      });
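
  // Context (an inference from the name): _make_subclass appears to mirror
  // Python's torch.Tensor._make_subclass, used when constructing Parameter --
  // it detaches the data, allows metadata changes, and re-flags requires_grad.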

  Class rb_cTensorOptions = define_class_under<torch::TensorOptions>(rb_mTorch, "TensorOptions")
    .define_constructor(Constructor<torch::TensorOptions>())
    .define_method(
      "dtype",
      *[](torch::TensorOptions& self, int dtype) {
        return self.dtype((torch::ScalarType) dtype);
      })
    .define_method(
      "layout",
      *[](torch::TensorOptions& self, std::string layout) {
        torch::Layout l;
        if (layout == "strided") {
          l = torch::kStrided;
        } else if (layout == "sparse") {
          l = torch::kSparse;
          throw std::runtime_error("Sparse layout not supported yet");
        } else {
          throw std::runtime_error("Unsupported layout: " + layout);
        }
        return self.layout(l);
      })
    .define_method(
      "device",
      *[](torch::TensorOptions& self, std::string device) {
        torch::DeviceType d;
        if (device == "cpu") {
          d = torch::kCPU;
        } else if (device == "cuda") {
          d = torch::kCUDA;
        } else {
          throw std::runtime_error("Unsupported device: " + device);
        }
        return self.device(d);
      })
    .define_method(
      "requires_grad",
      *[](torch::TensorOptions& self, bool requires_grad) {
        return self.requires_grad(requires_grad);
      });
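
  // Usage sketch (illustrative, using the C++ API these methods wrap): options
  // are built by chaining and then passed to the creation methods above, e.g.
  //
  //   auto options = torch::TensorOptions()
  //                      .dtype(torch::kFloat)
  //                      .device(torch::kCPU)
  //                      .requires_grad(true);
  //   torch::Tensor t = torch::zeros({2, 3}, options);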

  Module rb_mNN = define_module_under(rb_mTorch, "NN");

  Module rb_mInit = define_module_under(rb_mNN, "Init")
    .define_singleton_method(
      "kaiming_uniform!",
      *[](torch::Tensor& input, double a) {
        return torch::nn::init::kaiming_uniform_(input, a);
      })
    .define_singleton_method(
      "normal!",
      *[](torch::Tensor& input) {
        return torch::nn::init::normal_(input);
      })
    .define_singleton_method(
      "uniform!",
      *[](torch::Tensor& input, double to, double from) {
        return torch::nn::init::uniform_(input, to, from);
      });
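
  // Note (observation): torch::nn::init::uniform_ takes (tensor, low, high),
  // so the parameter names above look swapped -- `to` is forwarded as the
  // lower bound and `from` as the upper. The call still behaves correctly for
  // callers passing (low, high) positionally, since both are plain doubles.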

  Class rb_cParameter = define_class_under<torch::autograd::Variable, torch::Tensor>(rb_mNN, "Parameter")
    // TODO return grad or nil to remove need for 2nd function
    .define_method(
      "_grad",
      *[](torch::autograd::Variable& self) {
        return self.grad();
      })
    .define_method(
      "_grad_defined",
      *[](torch::autograd::Variable& self) {
        return self.grad().defined();
      });

  Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")
    .define_constructor(Constructor<torch::Device, std::string>())
    .define_method("index", &torch::Device::index)
    .define_method("index?", &torch::Device::has_index)
    .define_method(
      "type",
      *[](torch::Device& self) {
        std::stringstream s;
        s << self.type();
        return s.str();
      });

  Module rb_mCUDA = define_module_under(rb_mTorch, "CUDA")
    .define_singleton_method("available?", &torch::cuda::is_available)
    .define_singleton_method("device_count", &torch::cuda::device_count);
}