torch-rb 0.1.3

Sign up to get free protection for your applications and to get access to all the features.
Files changed (44) hide show
  1. checksums.yaml +7 -0
  2. data/CHANGELOG.md +28 -0
  3. data/LICENSE.txt +46 -0
  4. data/README.md +426 -0
  5. data/ext/torch/ext.cpp +839 -0
  6. data/ext/torch/extconf.rb +25 -0
  7. data/lib/torch-rb.rb +1 -0
  8. data/lib/torch.rb +422 -0
  9. data/lib/torch/ext.bundle +0 -0
  10. data/lib/torch/inspector.rb +85 -0
  11. data/lib/torch/nn/alpha_dropout.rb +9 -0
  12. data/lib/torch/nn/conv2d.rb +37 -0
  13. data/lib/torch/nn/convnd.rb +41 -0
  14. data/lib/torch/nn/dropout.rb +9 -0
  15. data/lib/torch/nn/dropout2d.rb +9 -0
  16. data/lib/torch/nn/dropout3d.rb +9 -0
  17. data/lib/torch/nn/dropoutnd.rb +15 -0
  18. data/lib/torch/nn/embedding.rb +52 -0
  19. data/lib/torch/nn/feature_alpha_dropout.rb +9 -0
  20. data/lib/torch/nn/functional.rb +100 -0
  21. data/lib/torch/nn/init.rb +30 -0
  22. data/lib/torch/nn/linear.rb +36 -0
  23. data/lib/torch/nn/module.rb +85 -0
  24. data/lib/torch/nn/mse_loss.rb +13 -0
  25. data/lib/torch/nn/parameter.rb +14 -0
  26. data/lib/torch/nn/relu.rb +13 -0
  27. data/lib/torch/nn/sequential.rb +29 -0
  28. data/lib/torch/optim/adadelta.rb +57 -0
  29. data/lib/torch/optim/adagrad.rb +71 -0
  30. data/lib/torch/optim/adam.rb +81 -0
  31. data/lib/torch/optim/adamax.rb +68 -0
  32. data/lib/torch/optim/adamw.rb +82 -0
  33. data/lib/torch/optim/asgd.rb +65 -0
  34. data/lib/torch/optim/lr_scheduler/lr_scheduler.rb +33 -0
  35. data/lib/torch/optim/lr_scheduler/step_lr.rb +17 -0
  36. data/lib/torch/optim/optimizer.rb +62 -0
  37. data/lib/torch/optim/rmsprop.rb +76 -0
  38. data/lib/torch/optim/rprop.rb +68 -0
  39. data/lib/torch/optim/sgd.rb +60 -0
  40. data/lib/torch/tensor.rb +196 -0
  41. data/lib/torch/utils/data/data_loader.rb +27 -0
  42. data/lib/torch/utils/data/tensor_dataset.rb +22 -0
  43. data/lib/torch/version.rb +3 -0
  44. metadata +169 -0
@@ -0,0 +1,839 @@
1
+ #include <sstream>
2
+
3
+ #include <torch/torch.h>
4
+
5
+ #include <rice/Array.hpp>
6
+ #include <rice/Class.hpp>
7
+ #include <rice/Constructor.hpp>
8
+
9
+ using namespace Rice;
10
+
11
+ template<>
12
+ inline
13
+ long long from_ruby<long long>(Object x)
14
+ {
15
+ return NUM2LL(x);
16
+ }
17
+
18
+ template<>
19
+ inline
20
+ Object to_ruby<long long>(long long const & x)
21
+ {
22
+ return LL2NUM(x);
23
+ }
24
+
25
+ template<>
26
+ inline
27
+ unsigned long long from_ruby<unsigned long long>(Object x)
28
+ {
29
+ return NUM2ULL(x);
30
+ }
31
+
32
+ template<>
33
+ inline
34
+ Object to_ruby<unsigned long long>(unsigned long long const & x)
35
+ {
36
+ return ULL2NUM(x);
37
+ }
38
+
39
+ template<>
40
+ inline
41
+ short from_ruby<short>(Object x)
42
+ {
43
+ return NUM2SHORT(x);
44
+ }
45
+
46
+ template<>
47
+ inline
48
+ Object to_ruby<short>(short const & x)
49
+ {
50
+ return INT2NUM(x);
51
+ }
52
+
53
+ template<>
54
+ inline
55
+ unsigned short from_ruby<unsigned short>(Object x)
56
+ {
57
+ return NUM2USHORT(x);
58
+ }
59
+
60
+ template<>
61
+ inline
62
+ Object to_ruby<unsigned short>(unsigned short const & x)
63
+ {
64
+ return UINT2NUM(x);
65
+ }
66
+
67
+ // need to wrap torch::IntArrayRef() since
68
+ // it doesn't own underlying data
69
+ class IntArrayRef {
70
+ std::vector<int64_t> vec;
71
+ public:
72
+ IntArrayRef(Object o) {
73
+ Array a = Array(o);
74
+ for (size_t i = 0; i < a.size(); i++) {
75
+ vec.push_back(from_ruby<int64_t>(a[i]));
76
+ }
77
+ }
78
+ operator torch::IntArrayRef() {
79
+ return torch::IntArrayRef(vec);
80
+ }
81
+ };
82
+
83
+ template<>
84
+ inline
85
+ IntArrayRef from_ruby<IntArrayRef>(Object x)
86
+ {
87
+ return IntArrayRef(x);
88
+ }
89
+
90
+ // for now
91
+ class Scalar {
92
+ torch::Scalar value;
93
+ public:
94
+ Scalar(Object o) {
95
+ // TODO cast based on Ruby type
96
+ if (o.rb_type() == T_FIXNUM) {
97
+ value = torch::Scalar(from_ruby<int64_t>(o));
98
+ } else {
99
+ value = torch::Scalar(from_ruby<float>(o));
100
+ }
101
+ }
102
+ operator torch::Scalar() {
103
+ return value;
104
+ }
105
+ };
106
+
107
+ template<>
108
+ inline
109
+ Scalar from_ruby<Scalar>(Object x)
110
+ {
111
+ return Scalar(x);
112
+ }
113
+
114
+ class TensorList {
115
+ std::vector<torch::Tensor> vec;
116
+ public:
117
+ TensorList(Object o) {
118
+ Array a = Array(o);
119
+ for (size_t i = 0; i < a.size(); i++) {
120
+ vec.push_back(from_ruby<torch::Tensor>(a[i]));
121
+ }
122
+ }
123
+ operator torch::TensorList() {
124
+ return torch::TensorList(vec);
125
+ }
126
+ };
127
+
128
+ template<>
129
+ inline
130
+ TensorList from_ruby<TensorList>(Object x)
131
+ {
132
+ return TensorList(x);
133
+ }
134
+
135
+ extern "C"
136
+ void Init_ext()
137
+ {
138
+ Module rb_mTorch = define_module("Torch")
139
+ .define_singleton_method(
140
+ "grad_enabled?",
141
+ *[]() {
142
+ return torch::GradMode::is_enabled();
143
+ })
144
+ .define_singleton_method(
145
+ "_set_grad_enabled",
146
+ *[](bool enabled) {
147
+ torch::GradMode::set_enabled(enabled);
148
+ })
149
+ .define_singleton_method(
150
+ "floating_point?",
151
+ *[](torch::Tensor& input) {
152
+ return torch::is_floating_point(input);
153
+ })
154
+ .define_singleton_method(
155
+ "manual_seed",
156
+ *[](uint64_t seed) {
157
+ return torch::manual_seed(seed);
158
+ })
159
+ // begin tensor creation
160
+ .define_singleton_method(
161
+ "_arange",
162
+ *[](Scalar start, Scalar end, Scalar step, const torch::TensorOptions &options) {
163
+ return torch::arange(start, end, step, options);
164
+ })
165
+ .define_singleton_method(
166
+ "_empty",
167
+ *[](IntArrayRef size, const torch::TensorOptions &options) {
168
+ return torch::empty(size, options);
169
+ })
170
+ .define_singleton_method(
171
+ "_eye",
172
+ *[](int64_t m, int64_t n, const torch::TensorOptions &options) {
173
+ return torch::eye(m, n, options);
174
+ })
175
+ .define_singleton_method(
176
+ "_full",
177
+ *[](IntArrayRef size, Scalar fill_value, const torch::TensorOptions& options) {
178
+ return torch::full(size, fill_value, options);
179
+ })
180
+ .define_singleton_method(
181
+ "_linspace",
182
+ *[](Scalar start, Scalar end, int64_t steps, const torch::TensorOptions& options) {
183
+ return torch::linspace(start, end, steps, options);
184
+ })
185
+ .define_singleton_method(
186
+ "_logspace",
187
+ *[](Scalar start, Scalar end, int64_t steps, double base, const torch::TensorOptions& options) {
188
+ return torch::logspace(start, end, steps, base, options);
189
+ })
190
+ .define_singleton_method(
191
+ "_ones",
192
+ *[](IntArrayRef size, const torch::TensorOptions &options) {
193
+ return torch::ones(size, options);
194
+ })
195
+ .define_singleton_method(
196
+ "_rand",
197
+ *[](IntArrayRef size, const torch::TensorOptions &options) {
198
+ return torch::rand(size, options);
199
+ })
200
+ .define_singleton_method(
201
+ "_randint",
202
+ *[](int64_t low, int64_t high, IntArrayRef size, const torch::TensorOptions &options) {
203
+ return torch::randint(low, high, size, options);
204
+ })
205
+ .define_singleton_method(
206
+ "_randn",
207
+ *[](IntArrayRef size, const torch::TensorOptions &options) {
208
+ return torch::randn(size, options);
209
+ })
210
+ .define_singleton_method(
211
+ "_randperm",
212
+ *[](int64_t n, const torch::TensorOptions &options) {
213
+ return torch::randperm(n, options);
214
+ })
215
+ .define_singleton_method(
216
+ "_zeros",
217
+ *[](IntArrayRef size, const torch::TensorOptions &options) {
218
+ return torch::zeros(size, options);
219
+ })
220
+ // begin operations
221
+ .define_singleton_method(
222
+ "_mean",
223
+ *[](torch::Tensor& input) {
224
+ return torch::mean(input);
225
+ })
226
+ .define_singleton_method(
227
+ "_mean_dim",
228
+ *[](torch::Tensor& input, int64_t dim, bool keepdim) {
229
+ return torch::mean(input, dim, keepdim);
230
+ })
231
+ .define_singleton_method(
232
+ "_sum",
233
+ *[](torch::Tensor& input) {
234
+ return torch::sum(input);
235
+ })
236
+ .define_singleton_method(
237
+ "_sum_dim",
238
+ *[](torch::Tensor& input, int64_t dim, bool keepdim) {
239
+ return torch::sum(input, dim, keepdim);
240
+ })
241
+ .define_singleton_method(
242
+ "_argmax",
243
+ *[](torch::Tensor& input) {
244
+ return torch::argmax(input);
245
+ })
246
+ .define_singleton_method(
247
+ "_argmax_dim",
248
+ *[](torch::Tensor& input, int64_t dim, bool keepdim) {
249
+ return torch::argmax(input, dim, keepdim);
250
+ })
251
+ .define_singleton_method(
252
+ "_cat",
253
+ *[](TensorList tensors, int64_t dim) {
254
+ return torch::cat(tensors, dim);
255
+ })
256
+ .define_singleton_method(
257
+ "_norm",
258
+ *[](torch::Tensor& input) {
259
+ return torch::norm(input);
260
+ })
261
+ .define_singleton_method(
262
+ "_min",
263
+ *[](torch::Tensor& input) {
264
+ return torch::min(input);
265
+ })
266
+ .define_singleton_method(
267
+ "_max",
268
+ *[](torch::Tensor& input) {
269
+ return torch::max(input);
270
+ })
271
+ .define_singleton_method(
272
+ "_max_out",
273
+ *[](torch::Tensor &max, torch::Tensor &max_indices, const torch::Tensor &input, int64_t dim, bool keepdim) {
274
+ // TODO add return value
275
+ torch::_max_out(max, max_indices, input, dim, keepdim);
276
+ })
277
+ .define_singleton_method(
278
+ "_sqrt",
279
+ *[](torch::Tensor& input) {
280
+ return torch::sqrt(input);
281
+ })
282
+ .define_singleton_method(
283
+ "_exp",
284
+ *[](torch::Tensor& input) {
285
+ return torch::exp(input);
286
+ })
287
+ .define_singleton_method(
288
+ "_log",
289
+ *[](torch::Tensor& input) {
290
+ return torch::log(input);
291
+ })
292
+ .define_singleton_method(
293
+ "_sign",
294
+ *[](torch::Tensor& input) {
295
+ return torch::sign(input);
296
+ })
297
+ .define_singleton_method(
298
+ "_unsqueeze",
299
+ *[](torch::Tensor& input, int64_t dim) {
300
+ return torch::unsqueeze(input, dim);
301
+ })
302
+ .define_singleton_method(
303
+ "_dot",
304
+ *[](torch::Tensor& input, torch::Tensor& tensor) {
305
+ return torch::dot(input, tensor);
306
+ })
307
+ .define_singleton_method(
308
+ "_matmul",
309
+ *[](torch::Tensor& input, torch::Tensor& other) {
310
+ return torch::matmul(input, other);
311
+ })
312
+ .define_singleton_method(
313
+ "_eq",
314
+ *[](torch::Tensor& input, torch::Tensor& other) {
315
+ return torch::eq(input, other);
316
+ })
317
+ .define_singleton_method(
318
+ "_gt",
319
+ // TODO support tensors
320
+ *[](torch::Tensor& input, Scalar other) {
321
+ return torch::gt(input, other);
322
+ })
323
+ .define_singleton_method(
324
+ "_lt",
325
+ // TODO support tensors
326
+ *[](torch::Tensor& input, Scalar other) {
327
+ return torch::lt(input, other);
328
+ })
329
+ .define_singleton_method(
330
+ "_add",
331
+ *[](torch::Tensor& input, torch::Tensor& other) {
332
+ return torch::add(input, other);
333
+ })
334
+ .define_singleton_method(
335
+ "_add_scalar",
336
+ *[](torch::Tensor& input, Scalar other) {
337
+ return torch::add(input, other);
338
+ })
339
+ .define_singleton_method(
340
+ "_add_out",
341
+ *[](torch::Tensor& out, torch::Tensor& input, torch::Tensor& other) {
342
+ return torch::add_out(out, input, other);
343
+ })
344
+ .define_singleton_method(
345
+ "_sub",
346
+ *[](torch::Tensor& input, torch::Tensor& other) {
347
+ return torch::sub(input, other);
348
+ })
349
+ .define_singleton_method(
350
+ "_sub_scalar",
351
+ *[](torch::Tensor& input, Scalar other) {
352
+ return torch::sub(input, other);
353
+ })
354
+ .define_singleton_method(
355
+ "_mul",
356
+ *[](torch::Tensor& input, torch::Tensor& other) {
357
+ return torch::mul(input, other);
358
+ })
359
+ .define_singleton_method(
360
+ "_mul_scalar",
361
+ *[](torch::Tensor& input, Scalar other) {
362
+ return torch::mul(input, other);
363
+ })
364
+ .define_singleton_method(
365
+ "_div",
366
+ *[](torch::Tensor& input, torch::Tensor& other) {
367
+ return torch::div(input, other);
368
+ })
369
+ .define_singleton_method(
370
+ "_div_scalar",
371
+ *[](torch::Tensor& input, Scalar other) {
372
+ return torch::div(input, other);
373
+ })
374
+ .define_singleton_method(
375
+ "_remainder",
376
+ *[](torch::Tensor& input, torch::Tensor& other) {
377
+ return torch::remainder(input, other);
378
+ })
379
+ .define_singleton_method(
380
+ "_remainder_scalar",
381
+ *[](torch::Tensor& input, Scalar other) {
382
+ return torch::remainder(input, other);
383
+ })
384
+ .define_singleton_method(
385
+ "_pow",
386
+ *[](torch::Tensor& input, Scalar exponent) {
387
+ return torch::pow(input, exponent);
388
+ })
389
+ .define_singleton_method(
390
+ "_abs",
391
+ *[](torch::Tensor& input) {
392
+ return torch::abs(input);
393
+ })
394
+ .define_singleton_method(
395
+ "_neg",
396
+ *[](torch::Tensor& input) {
397
+ return torch::neg(input);
398
+ })
399
+ .define_singleton_method(
400
+ "_reshape",
401
+ *[](torch::Tensor& input, IntArrayRef shape) {
402
+ return torch::reshape(input, shape);
403
+ })
404
+ .define_singleton_method(
405
+ "_flatten",
406
+ *[](torch::Tensor& input, int64_t start_dim, int64_t end_dim) {
407
+ return torch::flatten(input, start_dim, end_dim);
408
+ })
409
+ .define_singleton_method(
410
+ "relu",
411
+ *[](torch::Tensor& input) {
412
+ return torch::relu(input);
413
+ })
414
+ .define_singleton_method(
415
+ "conv2d",
416
+ *[](torch::Tensor& input, torch::Tensor& weight, torch::Tensor& bias, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, int64_t groups) {
417
+ return torch::conv2d(input, weight, bias, stride, padding, dilation, groups);
418
+ })
419
+ .define_singleton_method(
420
+ "linear",
421
+ *[](torch::Tensor& input, torch::Tensor& weight, torch::Tensor& bias) {
422
+ return torch::linear(input, weight, bias);
423
+ })
424
+ .define_singleton_method(
425
+ "max_pool2d",
426
+ *[](torch::Tensor& input, IntArrayRef kernel_size) {
427
+ return torch::max_pool2d(input, kernel_size);
428
+ })
429
+ .define_singleton_method(
430
+ "avg_pool2d",
431
+ *[](torch::Tensor& input, IntArrayRef kernel_size) {
432
+ return torch::avg_pool2d(input, kernel_size);
433
+ })
434
+ .define_singleton_method(
435
+ "_dropout",
436
+ *[](torch::Tensor& input, float p, bool train) {
437
+ return torch::dropout(input, p, train);
438
+ })
439
+ .define_singleton_method(
440
+ "_dropout!",
441
+ *[](torch::Tensor& input, float p, bool train) {
442
+ return torch::dropout_(input, p, train);
443
+ })
444
+ .define_singleton_method(
445
+ "_feature_dropout",
446
+ *[](torch::Tensor& input, float p, bool train) {
447
+ return torch::feature_dropout(input, p, train);
448
+ })
449
+ .define_singleton_method(
450
+ "_feature_dropout!",
451
+ *[](torch::Tensor& input, float p, bool train) {
452
+ return torch::feature_dropout_(input, p, train);
453
+ })
454
+ .define_singleton_method(
455
+ "_alpha_dropout",
456
+ *[](torch::Tensor& input, float p, bool train) {
457
+ return torch::alpha_dropout(input, p, train);
458
+ })
459
+ .define_singleton_method(
460
+ "_alpha_dropout!",
461
+ *[](torch::Tensor& input, float p, bool train) {
462
+ return torch::alpha_dropout_(input, p, train);
463
+ })
464
+ .define_singleton_method(
465
+ "_feature_alpha_dropout",
466
+ *[](torch::Tensor& input, float p, bool train) {
467
+ return torch::feature_alpha_dropout(input, p, train);
468
+ })
469
+ .define_singleton_method(
470
+ "_feature_alpha_dropout!",
471
+ *[](torch::Tensor& input, float p, bool train) {
472
+ return torch::feature_alpha_dropout_(input, p, train);
473
+ })
474
+ .define_singleton_method(
475
+ "_embedding",
476
+ // weight and indices are swapped from Python interface
477
+ *[](const torch::Tensor &indices, const torch::Tensor &weight, int64_t padding_idx, bool scale_grad_by_freq, bool sparse) {
478
+ return torch::embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse);
479
+ })
480
+ .define_singleton_method(
481
+ "mse_loss",
482
+ *[](torch::Tensor& input, torch::Tensor& target, std::string reduction) {
483
+ auto red = reduction == "mean" ? Reduction::Mean : Reduction::Sum;
484
+ return torch::mse_loss(input, target, red);
485
+ })
486
+ .define_singleton_method(
487
+ "nll_loss",
488
+ *[](torch::Tensor& input, torch::Tensor& target, std::string reduction) {
489
+ auto red = reduction == "mean" ? Reduction::Mean : Reduction::Sum;
490
+ return torch::nll_loss(input, target, {}, red);
491
+ })
492
+ .define_singleton_method("numel", &torch::numel)
493
+ .define_singleton_method(
494
+ "_from_blob",
495
+ *[](String s, IntArrayRef size, const torch::TensorOptions &options) {
496
+ void *data = const_cast<char *>(s.c_str());
497
+ return torch::from_blob(data, size, options);
498
+ })
499
+ .define_singleton_method(
500
+ "_tensor",
501
+ *[](Object o, IntArrayRef size, const torch::TensorOptions &options) {
502
+ Array a = Array(o);
503
+ std::vector<float> vec;
504
+ for (size_t i = 0; i < a.size(); i++) {
505
+ vec.push_back(from_ruby<float>(a[i]));
506
+ }
507
+ return torch::tensor(vec, options).reshape(size);
508
+ });
509
+
510
+ Class rb_cTensor = define_class_under<torch::Tensor>(rb_mTorch, "Tensor")
511
+ .define_method("cuda?", &torch::Tensor::is_cuda)
512
+ .define_method("distributed?", &torch::Tensor::is_distributed)
513
+ .define_method("complex?", &torch::Tensor::is_complex)
514
+ .define_method("floating_point?", &torch::Tensor::is_floating_point)
515
+ .define_method("signed?", &torch::Tensor::is_signed)
516
+ .define_method("sparse?", &torch::Tensor::is_sparse)
517
+ .define_method("quantized?", &torch::Tensor::is_quantized)
518
+ .define_method("dim", &torch::Tensor::dim)
519
+ .define_method("element_size", &torch::Tensor::element_size)
520
+ .define_method("requires_grad", &torch::Tensor::requires_grad)
521
+ .define_method("view_as", &torch::Tensor::view_as)
522
+ .define_method(
523
+ "addcmul!",
524
+ *[](torch::Tensor& self, Scalar value, const torch::Tensor & tensor1, const torch::Tensor & tensor2) {
525
+ return self.addcmul_(tensor1, tensor2, value);
526
+ })
527
+ .define_method(
528
+ "addcdiv!",
529
+ *[](torch::Tensor& self, Scalar value, const torch::Tensor & tensor1, const torch::Tensor & tensor2) {
530
+ return self.addcdiv_(tensor1, tensor2, value);
531
+ })
532
+ .define_method(
533
+ "zero!",
534
+ *[](torch::Tensor& self) {
535
+ return self.zero_();
536
+ })
537
+ .define_method(
538
+ "detach!",
539
+ *[](torch::Tensor& self) {
540
+ return self.detach_();
541
+ })
542
+ .define_method(
543
+ "_select",
544
+ *[](torch::Tensor& self, int64_t dim, int64_t index) {
545
+ return self.select(dim, index);
546
+ })
547
+ .define_method(
548
+ "_slice",
549
+ *[](torch::Tensor& self, int64_t dim, int64_t start, int64_t end, int64_t step) {
550
+ return self.slice(dim, start, end, step);
551
+ })
552
+ .define_method(
553
+ "_requires_grad!",
554
+ *[](torch::Tensor& self, bool requires_grad) {
555
+ return self.set_requires_grad(requires_grad);
556
+ })
557
+ .define_method(
558
+ "_backward",
559
+ *[](torch::Tensor& self) {
560
+ return self.backward();
561
+ })
562
+ .define_method(
563
+ "_backward_gradient",
564
+ *[](torch::Tensor& self, const torch::Tensor& gradient) {
565
+ return self.backward(gradient);
566
+ })
567
+ .define_method(
568
+ "grad",
569
+ *[](torch::Tensor& self) {
570
+ return self.grad();
571
+ })
572
+ .define_method(
573
+ "_dtype",
574
+ *[](torch::Tensor& self) {
575
+ return (int) at::typeMetaToScalarType(self.dtype());
576
+ })
577
+ .define_method(
578
+ "_type",
579
+ *[](torch::Tensor& self, int dtype) {
580
+ return self.toType((torch::ScalarType) dtype);
581
+ })
582
+ .define_method(
583
+ "_layout",
584
+ *[](torch::Tensor& self) {
585
+ std::stringstream s;
586
+ s << self.layout();
587
+ return s.str();
588
+ })
589
+ .define_method(
590
+ "device",
591
+ *[](torch::Tensor& self) {
592
+ std::stringstream s;
593
+ s << self.device();
594
+ return s.str();
595
+ })
596
+ .define_method(
597
+ "_view",
598
+ *[](torch::Tensor& self, IntArrayRef size) {
599
+ return self.view(size);
600
+ })
601
+ .define_method(
602
+ "resize_as!",
603
+ *[](torch::Tensor& self, torch::Tensor& other) {
604
+ return self.resize_as_(other);
605
+ })
606
+ .define_method(
607
+ "fill!",
608
+ *[](torch::Tensor& self, Scalar value) {
609
+ return self.fill_(value);
610
+ })
611
+ .define_method(
612
+ "_add!",
613
+ *[](torch::Tensor& self, torch::Tensor& other) {
614
+ return self.add_(other);
615
+ })
616
+ .define_method(
617
+ "_add_alpha!",
618
+ *[](torch::Tensor& self, torch::Tensor& other, Scalar alpha) {
619
+ return self.add_(other, alpha);
620
+ })
621
+ .define_method(
622
+ "_add_scalar!",
623
+ *[](torch::Tensor& self, Scalar other) {
624
+ return self.add_(other);
625
+ })
626
+ .define_method(
627
+ "normal!",
628
+ *[](torch::Tensor& self, double mean, double std) {
629
+ return self.normal_(mean, std);
630
+ })
631
+ .define_method(
632
+ "sub!",
633
+ *[](torch::Tensor& self, torch::Tensor& other) {
634
+ return self.sub_(other);
635
+ })
636
+ .define_method(
637
+ "_mul!",
638
+ *[](torch::Tensor& self, torch::Tensor& other) {
639
+ return self.mul_(other);
640
+ })
641
+ .define_method(
642
+ "_mul_scalar!",
643
+ *[](torch::Tensor& self, Scalar other) {
644
+ return self.mul_(other);
645
+ })
646
+ .define_method(
647
+ "div!",
648
+ *[](torch::Tensor& self, torch::Tensor& other) {
649
+ return self.div_(other);
650
+ })
651
+ .define_method(
652
+ "sqrt!",
653
+ *[](torch::Tensor& self) {
654
+ return self.sqrt_();
655
+ })
656
+ .define_method(
657
+ "unsqueeze!",
658
+ *[](torch::Tensor& self, int64_t dim) {
659
+ return self.unsqueeze_(dim);
660
+ })
661
+ .define_method(
662
+ "copy!",
663
+ *[](torch::Tensor& self, torch::Tensor& src) {
664
+ return self.copy_(src);
665
+ })
666
+ .define_method(
667
+ "clone",
668
+ *[](torch::Tensor& self) {
669
+ return self.clone();
670
+ })
671
+ .define_method(
672
+ "log_softmax",
673
+ *[](torch::Tensor& self, int64_t dim) {
674
+ return self.log_softmax(dim);
675
+ })
676
+ .define_method(
677
+ "data",
678
+ *[](torch::Tensor& self) {
679
+ return self.data();
680
+ })
681
+ .define_method(
682
+ "_data",
683
+ *[](torch::Tensor& self) {
684
+ Array a;
685
+ auto dtype = self.dtype();
686
+
687
+ // TODO DRY if someone knows C++
688
+ if (dtype == torch::kByte) {
689
+ uint8_t* data = self.data_ptr<uint8_t>();
690
+ for (int i = 0; i < self.numel(); i++) {
691
+ a.push(data[i]);
692
+ }
693
+ } else if (dtype == torch::kChar) {
694
+ int8_t* data = self.data_ptr<int8_t>();
695
+ for (int i = 0; i < self.numel(); i++) {
696
+ a.push(to_ruby<int>(data[i]));
697
+ }
698
+ } else if (dtype == torch::kShort) {
699
+ int16_t* data = self.data_ptr<int16_t>();
700
+ for (int i = 0; i < self.numel(); i++) {
701
+ a.push(data[i]);
702
+ }
703
+ } else if (dtype == torch::kInt) {
704
+ int32_t* data = self.data_ptr<int32_t>();
705
+ for (int i = 0; i < self.numel(); i++) {
706
+ a.push(data[i]);
707
+ }
708
+ } else if (dtype == torch::kLong) {
709
+ int64_t* data = self.data_ptr<int64_t>();
710
+ for (int i = 0; i < self.numel(); i++) {
711
+ a.push(data[i]);
712
+ }
713
+ } else if (dtype == torch::kFloat) {
714
+ float* data = self.data_ptr<float>();
715
+ for (int i = 0; i < self.numel(); i++) {
716
+ a.push(data[i]);
717
+ }
718
+ } else if (dtype == torch::kDouble) {
719
+ double* data = self.data_ptr<double>();
720
+ for (int i = 0; i < self.numel(); i++) {
721
+ a.push(data[i]);
722
+ }
723
+ } else if (dtype == torch::kBool) {
724
+ bool* data = self.data_ptr<bool>();
725
+ for (int i = 0; i < self.numel(); i++) {
726
+ a.push(data[i] ? True : False);
727
+ }
728
+ } else {
729
+ throw std::runtime_error("Unsupported type");
730
+ }
731
+ return a;
732
+ })
733
+ .define_method(
734
+ "_size",
735
+ *[](torch::Tensor& self, int i) {
736
+ return self.size(i);
737
+ })
738
+ .define_method(
739
+ "_to",
740
+ *[](torch::Tensor& self, torch::Device device, int dtype, bool non_blocking, bool copy) {
741
+ return self.to(device, (torch::ScalarType) dtype, non_blocking, copy);
742
+ })
743
+ .define_singleton_method(
744
+ "_make_subclass",
745
+ *[](torch::Tensor& rd, bool requires_grad) {
746
+ auto data = torch::autograd::as_variable_ref(rd).detach();
747
+ data.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(true);
748
+ auto var = data.set_requires_grad(requires_grad);
749
+ return torch::autograd::Variable(std::move(var));
750
+ });
751
+
752
+ Class rb_cTensorOptions = define_class_under<torch::TensorOptions>(rb_mTorch, "TensorOptions")
753
+ .define_constructor(Constructor<torch::TensorOptions>())
754
+ .define_method(
755
+ "dtype",
756
+ *[](torch::TensorOptions& self, int dtype) {
757
+ return self.dtype((torch::ScalarType) dtype);
758
+ })
759
+ .define_method(
760
+ "layout",
761
+ *[](torch::TensorOptions& self, std::string layout) {
762
+ torch::Layout l;
763
+ if (layout == "strided") {
764
+ l = torch::kStrided;
765
+ } else if (layout == "sparse") {
766
+ l = torch::kSparse;
767
+ throw std::runtime_error("Sparse layout not supported yet");
768
+ } else {
769
+ throw std::runtime_error("Unsupported layout: " + layout);
770
+ }
771
+ return self.layout(l);
772
+ })
773
+ .define_method(
774
+ "device",
775
+ *[](torch::TensorOptions& self, std::string device) {
776
+ torch::DeviceType d;
777
+ if (device == "cpu") {
778
+ d = torch::kCPU;
779
+ } else if (device == "cuda") {
780
+ d = torch::kCUDA;
781
+ } else {
782
+ throw std::runtime_error("Unsupported device: " + device);
783
+ }
784
+ return self.device(d);
785
+ })
786
+ .define_method(
787
+ "requires_grad",
788
+ *[](torch::TensorOptions& self, bool requires_grad) {
789
+ return self.requires_grad(requires_grad);
790
+ });
791
+
792
+ Module rb_mNN = define_module_under(rb_mTorch, "NN");
793
+
794
+ Module rb_mInit = define_module_under(rb_mNN, "Init")
795
+ .define_singleton_method(
796
+ "kaiming_uniform!",
797
+ *[](torch::Tensor& input, double a) {
798
+ return torch::nn::init::kaiming_uniform_(input, a);
799
+ })
800
+ .define_singleton_method(
801
+ "normal!",
802
+ *[](torch::Tensor& input) {
803
+ return torch::nn::init::normal_(input);
804
+ })
805
+ .define_singleton_method(
806
+ "uniform!",
807
+ *[](torch::Tensor& input, double to, double from) {
808
+ return torch::nn::init::uniform_(input, to, from);
809
+ });
810
+
811
+ Class rb_cParameter = define_class_under<torch::autograd::Variable, torch::Tensor>(rb_mNN, "Parameter")
812
+ // TODO return grad or nil to remove need for 2nd function
813
+ .define_method(
814
+ "_grad",
815
+ *[](torch::autograd::Variable& self) {
816
+ return self.grad();
817
+ })
818
+ .define_method(
819
+ "_grad_defined",
820
+ *[](torch::autograd::Variable& self) {
821
+ return self.grad().defined();
822
+ });
823
+
824
+ Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")
825
+ .define_constructor(Constructor<torch::Device, std::string>())
826
+ .define_method("index", &torch::Device::index)
827
+ .define_method("index?", &torch::Device::has_index)
828
+ .define_method(
829
+ "type",
830
+ *[](torch::Device& self) {
831
+ std::stringstream s;
832
+ s << self.type();
833
+ return s.str();
834
+ });
835
+
836
+ Module rb_mCUDA = define_module_under(rb_mTorch, "CUDA")
837
+ .define_singleton_method("available?", &torch::cuda::is_available)
838
+ .define_singleton_method("device_count", &torch::cuda::device_count);
839
+ }