torch-rb 0.1.3 → 0.1.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/README.md +1 -0
  4. data/ext/torch/ext.cpp +375 -124
  5. data/lib/torch.rb +101 -20
  6. data/lib/torch/ext.bundle +0 -0
  7. data/lib/torch/inspector.rb +23 -19
  8. data/lib/torch/nn/avg_pool2d.rb +14 -0
  9. data/lib/torch/nn/avg_poolnd.rb +9 -0
  10. data/lib/torch/nn/bce_loss.rb +13 -0
  11. data/lib/torch/nn/bilinear.rb +38 -0
  12. data/lib/torch/nn/conv2d.rb +2 -2
  13. data/lib/torch/nn/convnd.rb +3 -3
  14. data/lib/torch/nn/cosine_similarity.rb +15 -0
  15. data/lib/torch/nn/cross_entropy_loss.rb +14 -0
  16. data/lib/torch/nn/ctc_loss.rb +15 -0
  17. data/lib/torch/nn/dropoutnd.rb +2 -2
  18. data/lib/torch/nn/embedding_bag.rb +34 -0
  19. data/lib/torch/nn/functional.rb +101 -13
  20. data/lib/torch/nn/identity.rb +13 -0
  21. data/lib/torch/nn/init.rb +58 -1
  22. data/lib/torch/nn/kl_div_loss.rb +13 -0
  23. data/lib/torch/nn/l1_loss.rb +13 -0
  24. data/lib/torch/nn/leaky_relu.rb +20 -0
  25. data/lib/torch/nn/linear.rb +12 -11
  26. data/lib/torch/nn/log_softmax.rb +14 -0
  27. data/lib/torch/nn/loss.rb +10 -0
  28. data/lib/torch/nn/max_pool2d.rb +9 -0
  29. data/lib/torch/nn/max_poolnd.rb +19 -0
  30. data/lib/torch/nn/module.rb +120 -31
  31. data/lib/torch/nn/mse_loss.rb +2 -2
  32. data/lib/torch/nn/nll_loss.rb +14 -0
  33. data/lib/torch/nn/pairwise_distance.rb +16 -0
  34. data/lib/torch/nn/parameter.rb +0 -4
  35. data/lib/torch/nn/poisson_nll_loss.rb +16 -0
  36. data/lib/torch/nn/prelu.rb +19 -0
  37. data/lib/torch/nn/relu.rb +8 -3
  38. data/lib/torch/nn/sequential.rb +1 -10
  39. data/lib/torch/nn/sigmoid.rb +9 -0
  40. data/lib/torch/nn/softmax.rb +18 -0
  41. data/lib/torch/nn/softmax2d.rb +10 -0
  42. data/lib/torch/nn/softmin.rb +14 -0
  43. data/lib/torch/nn/softplus.rb +19 -0
  44. data/lib/torch/nn/weighted_loss.rb +10 -0
  45. data/lib/torch/random.rb +10 -0
  46. data/lib/torch/tensor.rb +28 -10
  47. data/lib/torch/version.rb +1 -1
  48. metadata +29 -2
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: e7f715179c9a84dc7399b80d93fd61f2bbb58a0156e6084dc4abb23e1d4a1b52
4
- data.tar.gz: 6928379ae7c92a77ad9dde4f4224ec33c6f8575a9b77585c0147e4f5361021de
3
+ metadata.gz: 4faccffe2d2fd29519ad9dcce0560978a07c734831b5f64bb4624a0037f2b08c
4
+ data.tar.gz: 4a8f873a9bb99c2311c856c59e5c43a5dfadd3f4f2460da1370ca1db888b79ad
5
5
  SHA512:
6
- metadata.gz: 9911a9e86d93f1e410776c44fdb3cd9aa06c83d1f0e42fdab8530970bea6520aed7906e96fb8243efd6b957453ebc13678b2b92e4c85b54407030a32c6196e08
7
- data.tar.gz: 0d080f5458a5dcf8fee19ce5e2e342bf6269432de6e78d923036232963ebb80daeea993c0bbf4af2d6da46593ac28a72a8232020a9fcb48acc3276c9e1ebebf3
6
+ metadata.gz: 199b3b47325b72b38786f50c39f0cfb9b11709f02edf4b77e1c4e9198baf5fa2b3924d639d6c0f7e1715193528d4cfab4f53fbba7e2b16e06ac462d37862cf3a
7
+ data.tar.gz: 4514f7aab60d9beabee47db175c9df6e3e1e93080b47eaeacff0f9dd4e8e737b476c8400f84bda4ccdf61aca4f7c81e010bcfbc2fa67f8c629c33cd6dcdcb54c
data/CHANGELOG.md CHANGED
@@ -1,3 +1,12 @@
1
+ ## 0.1.4 (2019-12-01)
2
+
3
+ - Added distance functions
4
+ - Added more activations
5
+ - Added more linear layers
6
+ - Added more loss functions
7
+ - Added more init methods
8
+ - Added support for tensor assignment
9
+
1
10
  ## 0.1.3 (2019-11-30)
2
11
 
3
12
  - Changed to BSD 3-Clause license to match PyTorch
data/README.md CHANGED
@@ -367,6 +367,7 @@ Here are a few full examples:
367
367
 
368
368
  - [Image classification with MNIST](examples/mnist)
369
369
  - [Collaborative filtering with MovieLens](examples/movielens)
370
+ - [Word embeddings](examples/nlp) [master]
370
371
 
371
372
  ## LibTorch Installation
372
373
 
data/ext/torch/ext.cpp CHANGED
@@ -132,6 +132,112 @@ TensorList from_ruby<TensorList>(Object x)
132
132
  return TensorList(x);
133
133
  }
134
134
 
135
+ class FanModeType {
136
+ std::string s;
137
+ public:
138
+ FanModeType(Object o) {
139
+ s = String(o).str();
140
+ }
141
+ // TODO switch NonlinearityType after LibTorch 1.4 release
142
+ operator torch::nn::init::FanMode() {
143
+ if (s == "fan_in") {
144
+ return torch::nn::init::FanMode::FanIn;
145
+ } else if (s == "fan_out") {
146
+ return torch::nn::init::FanMode::FanOut;
147
+ } else {
148
+ throw std::runtime_error("Unsupported nonlinearity type: " + s);
149
+ }
150
+ }
151
+ };
152
+
153
+ template<>
154
+ inline
155
+ FanModeType from_ruby<FanModeType>(Object x)
156
+ {
157
+ return FanModeType(x);
158
+ }
159
+
160
+ class NonlinearityType {
161
+ std::string s;
162
+ public:
163
+ NonlinearityType(Object o) {
164
+ s = String(o).str();
165
+ }
166
+ // TODO switch NonlinearityType after LibTorch 1.4 release
167
+ operator torch::nn::init::Nonlinearity() {
168
+ if (s == "linear") {
169
+ return torch::nn::init::Nonlinearity::Linear;
170
+ } else if (s == "conv1d") {
171
+ return torch::nn::init::Nonlinearity::Conv1D;
172
+ } else if (s == "conv2d") {
173
+ return torch::nn::init::Nonlinearity::Conv2D;
174
+ } else if (s == "conv3d") {
175
+ return torch::nn::init::Nonlinearity::Conv3D;
176
+ } else if (s == "conv_transpose1d") {
177
+ return torch::nn::init::Nonlinearity::ConvTranspose1D;
178
+ } else if (s == "conv_transpose2d") {
179
+ return torch::nn::init::Nonlinearity::ConvTranspose2D;
180
+ } else if (s == "conv_transpose3d") {
181
+ return torch::nn::init::Nonlinearity::ConvTranspose3D;
182
+ } else if (s == "sigmoid") {
183
+ return torch::nn::init::Nonlinearity::Sigmoid;
184
+ } else if (s == "tanh") {
185
+ return torch::nn::init::Nonlinearity::Tanh;
186
+ } else if (s == "relu") {
187
+ return torch::nn::init::Nonlinearity::ReLU;
188
+ } else if (s == "leaky_relu") {
189
+ return torch::nn::init::Nonlinearity::LeakyReLU;
190
+ } else {
191
+ throw std::runtime_error("Unsupported nonlinearity type: " + s);
192
+ }
193
+ }
194
+ };
195
+
196
+ template<>
197
+ inline
198
+ NonlinearityType from_ruby<NonlinearityType>(Object x)
199
+ {
200
+ return NonlinearityType(x);
201
+ }
202
+
203
+ class MyReduction {
204
+ Object value;
205
+ public:
206
+ MyReduction(Object o) {
207
+ value = o;
208
+ }
209
+ operator int64_t() {
210
+ if (value.is_nil()) {
211
+ return Reduction::None;
212
+ }
213
+
214
+ std::string s = String(value).str();
215
+ if (s == "mean") {
216
+ return Reduction::Mean;
217
+ } else if (s == "sum") {
218
+ return Reduction::Sum;
219
+ } else {
220
+ throw std::runtime_error("Unsupported reduction: " + s);
221
+ }
222
+ }
223
+ };
224
+
225
+ template<>
226
+ inline
227
+ MyReduction from_ruby<MyReduction>(Object x)
228
+ {
229
+ return MyReduction(x);
230
+ }
231
+
232
+ typedef torch::Tensor Tensor;
233
+
234
+ Object tensor_array(std::tuple<torch::Tensor, torch::Tensor> x) {
235
+ Array a;
236
+ a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
237
+ a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
238
+ return Object(a);
239
+ }
240
+
135
241
  extern "C"
136
242
  void Init_ext()
137
243
  {
@@ -148,7 +254,7 @@ void Init_ext()
148
254
  })
149
255
  .define_singleton_method(
150
256
  "floating_point?",
151
- *[](torch::Tensor& input) {
257
+ *[](Tensor& input) {
152
258
  return torch::is_floating_point(input);
153
259
  })
154
260
  .define_singleton_method(
@@ -220,32 +326,32 @@ void Init_ext()
220
326
  // begin operations
221
327
  .define_singleton_method(
222
328
  "_mean",
223
- *[](torch::Tensor& input) {
329
+ *[](Tensor& input) {
224
330
  return torch::mean(input);
225
331
  })
226
332
  .define_singleton_method(
227
333
  "_mean_dim",
228
- *[](torch::Tensor& input, int64_t dim, bool keepdim) {
334
+ *[](Tensor& input, int64_t dim, bool keepdim) {
229
335
  return torch::mean(input, dim, keepdim);
230
336
  })
231
337
  .define_singleton_method(
232
338
  "_sum",
233
- *[](torch::Tensor& input) {
339
+ *[](Tensor& input) {
234
340
  return torch::sum(input);
235
341
  })
236
342
  .define_singleton_method(
237
343
  "_sum_dim",
238
- *[](torch::Tensor& input, int64_t dim, bool keepdim) {
344
+ *[](Tensor& input, int64_t dim, bool keepdim) {
239
345
  return torch::sum(input, dim, keepdim);
240
346
  })
241
347
  .define_singleton_method(
242
348
  "_argmax",
243
- *[](torch::Tensor& input) {
349
+ *[](Tensor& input) {
244
350
  return torch::argmax(input);
245
351
  })
246
352
  .define_singleton_method(
247
353
  "_argmax_dim",
248
- *[](torch::Tensor& input, int64_t dim, bool keepdim) {
354
+ *[](Tensor& input, int64_t dim, bool keepdim) {
249
355
  return torch::argmax(input, dim, keepdim);
250
356
  })
251
357
  .define_singleton_method(
@@ -255,239 +361,322 @@ void Init_ext()
255
361
  })
256
362
  .define_singleton_method(
257
363
  "_norm",
258
- *[](torch::Tensor& input) {
364
+ *[](Tensor& input) {
259
365
  return torch::norm(input);
260
366
  })
261
367
  .define_singleton_method(
262
368
  "_min",
263
- *[](torch::Tensor& input) {
369
+ *[](Tensor& input) {
264
370
  return torch::min(input);
265
371
  })
266
372
  .define_singleton_method(
267
373
  "_max",
268
- *[](torch::Tensor& input) {
374
+ *[](Tensor& input) {
269
375
  return torch::max(input);
270
376
  })
271
377
  .define_singleton_method(
272
378
  "_max_out",
273
- *[](torch::Tensor &max, torch::Tensor &max_indices, const torch::Tensor &input, int64_t dim, bool keepdim) {
274
- // TODO add return value
275
- torch::_max_out(max, max_indices, input, dim, keepdim);
379
+ *[](Tensor &max, Tensor &max_indices, const Tensor &input, int64_t dim, bool keepdim) {
380
+ return tensor_array(torch::_max_out(max, max_indices, input, dim, keepdim));
276
381
  })
277
382
  .define_singleton_method(
278
383
  "_sqrt",
279
- *[](torch::Tensor& input) {
384
+ *[](Tensor& input) {
280
385
  return torch::sqrt(input);
281
386
  })
282
387
  .define_singleton_method(
283
388
  "_exp",
284
- *[](torch::Tensor& input) {
389
+ *[](Tensor& input) {
285
390
  return torch::exp(input);
286
391
  })
287
392
  .define_singleton_method(
288
393
  "_log",
289
- *[](torch::Tensor& input) {
394
+ *[](Tensor& input) {
290
395
  return torch::log(input);
291
396
  })
292
397
  .define_singleton_method(
293
398
  "_sign",
294
- *[](torch::Tensor& input) {
399
+ *[](Tensor& input) {
295
400
  return torch::sign(input);
296
401
  })
297
402
  .define_singleton_method(
298
403
  "_unsqueeze",
299
- *[](torch::Tensor& input, int64_t dim) {
404
+ *[](Tensor& input, int64_t dim) {
300
405
  return torch::unsqueeze(input, dim);
301
406
  })
302
407
  .define_singleton_method(
303
408
  "_dot",
304
- *[](torch::Tensor& input, torch::Tensor& tensor) {
409
+ *[](Tensor& input, Tensor& tensor) {
305
410
  return torch::dot(input, tensor);
306
411
  })
307
412
  .define_singleton_method(
308
413
  "_matmul",
309
- *[](torch::Tensor& input, torch::Tensor& other) {
414
+ *[](Tensor& input, Tensor& other) {
310
415
  return torch::matmul(input, other);
311
416
  })
312
417
  .define_singleton_method(
313
418
  "_eq",
314
- *[](torch::Tensor& input, torch::Tensor& other) {
419
+ *[](Tensor& input, Tensor& other) {
315
420
  return torch::eq(input, other);
316
421
  })
317
422
  .define_singleton_method(
318
423
  "_gt",
319
424
  // TODO support tensors
320
- *[](torch::Tensor& input, Scalar other) {
425
+ *[](Tensor& input, Scalar other) {
321
426
  return torch::gt(input, other);
322
427
  })
323
428
  .define_singleton_method(
324
429
  "_lt",
325
430
  // TODO support tensors
326
- *[](torch::Tensor& input, Scalar other) {
431
+ *[](Tensor& input, Scalar other) {
327
432
  return torch::lt(input, other);
328
433
  })
329
434
  .define_singleton_method(
330
435
  "_add",
331
- *[](torch::Tensor& input, torch::Tensor& other) {
436
+ *[](Tensor& input, Tensor& other) {
332
437
  return torch::add(input, other);
333
438
  })
334
439
  .define_singleton_method(
335
440
  "_add_scalar",
336
- *[](torch::Tensor& input, Scalar other) {
441
+ *[](Tensor& input, Scalar other) {
337
442
  return torch::add(input, other);
338
443
  })
339
444
  .define_singleton_method(
340
445
  "_add_out",
341
- *[](torch::Tensor& out, torch::Tensor& input, torch::Tensor& other) {
446
+ *[](Tensor& out, Tensor& input, Tensor& other) {
342
447
  return torch::add_out(out, input, other);
343
448
  })
344
449
  .define_singleton_method(
345
450
  "_sub",
346
- *[](torch::Tensor& input, torch::Tensor& other) {
451
+ *[](Tensor& input, Tensor& other) {
347
452
  return torch::sub(input, other);
348
453
  })
349
454
  .define_singleton_method(
350
455
  "_sub_scalar",
351
- *[](torch::Tensor& input, Scalar other) {
456
+ *[](Tensor& input, Scalar other) {
352
457
  return torch::sub(input, other);
353
458
  })
354
459
  .define_singleton_method(
355
460
  "_mul",
356
- *[](torch::Tensor& input, torch::Tensor& other) {
461
+ *[](Tensor& input, Tensor& other) {
357
462
  return torch::mul(input, other);
358
463
  })
359
464
  .define_singleton_method(
360
465
  "_mul_scalar",
361
- *[](torch::Tensor& input, Scalar other) {
466
+ *[](Tensor& input, Scalar other) {
362
467
  return torch::mul(input, other);
363
468
  })
364
469
  .define_singleton_method(
365
470
  "_div",
366
- *[](torch::Tensor& input, torch::Tensor& other) {
471
+ *[](Tensor& input, Tensor& other) {
367
472
  return torch::div(input, other);
368
473
  })
369
474
  .define_singleton_method(
370
475
  "_div_scalar",
371
- *[](torch::Tensor& input, Scalar other) {
476
+ *[](Tensor& input, Scalar other) {
372
477
  return torch::div(input, other);
373
478
  })
374
479
  .define_singleton_method(
375
480
  "_remainder",
376
- *[](torch::Tensor& input, torch::Tensor& other) {
481
+ *[](Tensor& input, Tensor& other) {
377
482
  return torch::remainder(input, other);
378
483
  })
379
484
  .define_singleton_method(
380
485
  "_remainder_scalar",
381
- *[](torch::Tensor& input, Scalar other) {
486
+ *[](Tensor& input, Scalar other) {
382
487
  return torch::remainder(input, other);
383
488
  })
384
489
  .define_singleton_method(
385
490
  "_pow",
386
- *[](torch::Tensor& input, Scalar exponent) {
491
+ *[](Tensor& input, Scalar exponent) {
387
492
  return torch::pow(input, exponent);
388
493
  })
494
+ .define_singleton_method(
495
+ "_topk",
496
+ *[](Tensor& input, int64_t k) {
497
+ return tensor_array(torch::topk(input, k));
498
+ })
499
+ .define_singleton_method(
500
+ "_sigmoid",
501
+ *[](Tensor& input) {
502
+ return torch::sigmoid(input);
503
+ })
504
+ .define_singleton_method(
505
+ "_softplus",
506
+ *[](const Tensor &input, Scalar beta, Scalar threshold) {
507
+ return torch::softplus(input, beta, threshold);
508
+ })
509
+ .define_singleton_method(
510
+ "_softmax",
511
+ *[](const Tensor &input, int64_t dim) {
512
+ return torch::softmax(input, dim);
513
+ })
514
+ .define_singleton_method(
515
+ "_log_softmax",
516
+ *[](Tensor& input, int64_t dim) {
517
+ return torch::log_softmax(input, dim);
518
+ })
389
519
  .define_singleton_method(
390
520
  "_abs",
391
- *[](torch::Tensor& input) {
521
+ *[](Tensor& input) {
392
522
  return torch::abs(input);
393
523
  })
394
524
  .define_singleton_method(
395
525
  "_neg",
396
- *[](torch::Tensor& input) {
526
+ *[](Tensor& input) {
397
527
  return torch::neg(input);
398
528
  })
399
529
  .define_singleton_method(
400
530
  "_reshape",
401
- *[](torch::Tensor& input, IntArrayRef shape) {
531
+ *[](Tensor& input, IntArrayRef shape) {
402
532
  return torch::reshape(input, shape);
403
533
  })
404
534
  .define_singleton_method(
405
535
  "_flatten",
406
- *[](torch::Tensor& input, int64_t start_dim, int64_t end_dim) {
536
+ *[](Tensor& input, int64_t start_dim, int64_t end_dim) {
407
537
  return torch::flatten(input, start_dim, end_dim);
408
538
  })
409
539
  .define_singleton_method(
410
540
  "relu",
411
- *[](torch::Tensor& input) {
541
+ *[](Tensor& input) {
412
542
  return torch::relu(input);
413
543
  })
544
+ .define_singleton_method(
545
+ "prelu",
546
+ *[](torch::Tensor& input, torch::Tensor& weight) {
547
+ return torch::prelu(input, weight);
548
+ })
549
+ .define_singleton_method(
550
+ "leaky_relu",
551
+ *[](torch::Tensor& input, Scalar negative_slope) {
552
+ return torch::leaky_relu(input, negative_slope);
553
+ })
414
554
  .define_singleton_method(
415
555
  "conv2d",
416
- *[](torch::Tensor& input, torch::Tensor& weight, torch::Tensor& bias, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, int64_t groups) {
556
+ *[](Tensor& input, Tensor& weight, Tensor& bias, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, int64_t groups) {
417
557
  return torch::conv2d(input, weight, bias, stride, padding, dilation, groups);
418
558
  })
559
+ // linear layers
560
+ .define_singleton_method(
561
+ "bilinear",
562
+ *[](const Tensor &input1, const Tensor &input2, const Tensor &weight, const Tensor &bias) {
563
+ return torch::bilinear(input1, input2, weight, bias);
564
+ })
419
565
  .define_singleton_method(
420
566
  "linear",
421
- *[](torch::Tensor& input, torch::Tensor& weight, torch::Tensor& bias) {
567
+ *[](Tensor& input, Tensor& weight, Tensor& bias) {
422
568
  return torch::linear(input, weight, bias);
423
569
  })
570
+ // pooling layers
424
571
  .define_singleton_method(
425
572
  "max_pool2d",
426
- *[](torch::Tensor& input, IntArrayRef kernel_size) {
573
+ *[](Tensor& input, IntArrayRef kernel_size) {
427
574
  return torch::max_pool2d(input, kernel_size);
428
575
  })
429
576
  .define_singleton_method(
430
577
  "avg_pool2d",
431
- *[](torch::Tensor& input, IntArrayRef kernel_size) {
578
+ *[](Tensor& input, IntArrayRef kernel_size) {
432
579
  return torch::avg_pool2d(input, kernel_size);
433
580
  })
434
581
  .define_singleton_method(
435
582
  "_dropout",
436
- *[](torch::Tensor& input, float p, bool train) {
583
+ *[](Tensor& input, float p, bool train) {
437
584
  return torch::dropout(input, p, train);
438
585
  })
439
586
  .define_singleton_method(
440
587
  "_dropout!",
441
- *[](torch::Tensor& input, float p, bool train) {
588
+ *[](Tensor& input, float p, bool train) {
442
589
  return torch::dropout_(input, p, train);
443
590
  })
444
591
  .define_singleton_method(
445
592
  "_feature_dropout",
446
- *[](torch::Tensor& input, float p, bool train) {
593
+ *[](Tensor& input, float p, bool train) {
447
594
  return torch::feature_dropout(input, p, train);
448
595
  })
449
596
  .define_singleton_method(
450
597
  "_feature_dropout!",
451
- *[](torch::Tensor& input, float p, bool train) {
598
+ *[](Tensor& input, float p, bool train) {
452
599
  return torch::feature_dropout_(input, p, train);
453
600
  })
454
601
  .define_singleton_method(
455
602
  "_alpha_dropout",
456
- *[](torch::Tensor& input, float p, bool train) {
603
+ *[](Tensor& input, float p, bool train) {
457
604
  return torch::alpha_dropout(input, p, train);
458
605
  })
459
606
  .define_singleton_method(
460
607
  "_alpha_dropout!",
461
- *[](torch::Tensor& input, float p, bool train) {
608
+ *[](Tensor& input, float p, bool train) {
462
609
  return torch::alpha_dropout_(input, p, train);
463
610
  })
464
611
  .define_singleton_method(
465
612
  "_feature_alpha_dropout",
466
- *[](torch::Tensor& input, float p, bool train) {
613
+ *[](Tensor& input, float p, bool train) {
467
614
  return torch::feature_alpha_dropout(input, p, train);
468
615
  })
469
616
  .define_singleton_method(
470
617
  "_feature_alpha_dropout!",
471
- *[](torch::Tensor& input, float p, bool train) {
618
+ *[](Tensor& input, float p, bool train) {
472
619
  return torch::feature_alpha_dropout_(input, p, train);
473
620
  })
621
+ // sparse layers
474
622
  .define_singleton_method(
475
623
  "_embedding",
476
624
  // weight and indices are swapped from Python interface
477
- *[](const torch::Tensor &indices, const torch::Tensor &weight, int64_t padding_idx, bool scale_grad_by_freq, bool sparse) {
625
+ *[](const Tensor &indices, const Tensor &weight, int64_t padding_idx, bool scale_grad_by_freq, bool sparse) {
478
626
  return torch::embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse);
479
627
  })
628
+ .define_singleton_method(
629
+ "_embedding_bag",
630
+ // weight and indices are swapped from Python interface
631
+ *[](const Tensor &weight, const Tensor &indices, const Tensor &offsets, bool scale_grad_by_freq, int64_t mode, bool sparse, const Tensor &per_sample_weights) {
632
+ return torch::embedding_bag(weight, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights);
633
+ })
634
+ // distance functions
635
+ .define_singleton_method(
636
+ "_cosine_similarity",
637
+ *[](const Tensor &x1, const Tensor &x2, int64_t dim, double eps) {
638
+ return torch::cosine_similarity(x1, x2, dim, eps);
639
+ })
640
+ .define_singleton_method(
641
+ "_pairwise_distance",
642
+ *[](const Tensor &x1, const Tensor &x2, double p, double eps, bool keepdim) {
643
+ return torch::pairwise_distance(x1, x2, p, eps, keepdim);
644
+ })
645
+ // loss functions
646
+ .define_singleton_method(
647
+ "binary_cross_entropy",
648
+ *[](Tensor& input, Tensor& target, MyReduction reduction) {
649
+ return torch::binary_cross_entropy(input, target, {}, reduction);
650
+ })
651
+ .define_singleton_method(
652
+ "ctc_loss",
653
+ *[](const Tensor &log_probs, const Tensor &targets, IntArrayRef input_lengths, IntArrayRef target_lengths, int64_t blank, MyReduction reduction, bool zero_infinity) {
654
+ return torch::ctc_loss(log_probs, targets, input_lengths, target_lengths, blank, reduction, zero_infinity);
655
+ })
656
+ .define_singleton_method(
657
+ "kl_div",
658
+ *[](Tensor& input, Tensor& target, MyReduction reduction) {
659
+ return torch::kl_div(input, target, reduction);
660
+ })
661
+ .define_singleton_method(
662
+ "l1_loss",
663
+ *[](Tensor& input, Tensor& target, MyReduction reduction) {
664
+ return torch::l1_loss(input, target, reduction);
665
+ })
480
666
  .define_singleton_method(
481
667
  "mse_loss",
482
- *[](torch::Tensor& input, torch::Tensor& target, std::string reduction) {
483
- auto red = reduction == "mean" ? Reduction::Mean : Reduction::Sum;
484
- return torch::mse_loss(input, target, red);
668
+ *[](Tensor& input, Tensor& target, MyReduction reduction) {
669
+ return torch::mse_loss(input, target, reduction);
485
670
  })
486
671
  .define_singleton_method(
487
672
  "nll_loss",
488
- *[](torch::Tensor& input, torch::Tensor& target, std::string reduction) {
489
- auto red = reduction == "mean" ? Reduction::Mean : Reduction::Sum;
490
- return torch::nll_loss(input, target, {}, red);
673
+ *[](Tensor& input, Tensor& target, MyReduction reduction, int64_t ignore_index) {
674
+ return torch::nll_loss(input, target, {}, reduction, ignore_index);
675
+ })
676
+ .define_singleton_method(
677
+ "poisson_nll_loss",
678
+ *[](const Tensor &input, const Tensor &target, bool log_input, bool full, double eps, MyReduction reduction) {
679
+ return torch::poisson_nll_loss(input, target, log_input, full, eps, reduction);
491
680
  })
492
681
  .define_singleton_method("numel", &torch::numel)
493
682
  .define_singleton_method(
@@ -500,11 +689,18 @@ void Init_ext()
500
689
  "_tensor",
501
690
  *[](Object o, IntArrayRef size, const torch::TensorOptions &options) {
502
691
  Array a = Array(o);
503
- std::vector<float> vec;
504
- for (size_t i = 0; i < a.size(); i++) {
505
- vec.push_back(from_ruby<float>(a[i]));
692
+ auto dtype = options.dtype();
693
+ torch::Tensor t;
694
+ if (dtype == torch::kBool) {
695
+ throw std::runtime_error("Cannot create bool from tensor method yet");
696
+ } else {
697
+ std::vector<float> vec;
698
+ for (size_t i = 0; i < a.size(); i++) {
699
+ vec.push_back(from_ruby<float>(a[i]));
700
+ }
701
+ t = torch::tensor(vec, options);
506
702
  }
507
- return torch::tensor(vec, options).reshape(size);
703
+ return t.reshape(size);
508
704
  });
509
705
 
510
706
  Class rb_cTensor = define_class_under<torch::Tensor>(rb_mTorch, "Tensor")
@@ -521,166 +717,171 @@ void Init_ext()
521
717
  .define_method("view_as", &torch::Tensor::view_as)
522
718
  .define_method(
523
719
  "addcmul!",
524
- *[](torch::Tensor& self, Scalar value, const torch::Tensor & tensor1, const torch::Tensor & tensor2) {
720
+ *[](Tensor& self, Scalar value, const Tensor & tensor1, const Tensor & tensor2) {
525
721
  return self.addcmul_(tensor1, tensor2, value);
526
722
  })
527
723
  .define_method(
528
724
  "addcdiv!",
529
- *[](torch::Tensor& self, Scalar value, const torch::Tensor & tensor1, const torch::Tensor & tensor2) {
725
+ *[](Tensor& self, Scalar value, const Tensor & tensor1, const Tensor & tensor2) {
530
726
  return self.addcdiv_(tensor1, tensor2, value);
531
727
  })
532
728
  .define_method(
533
729
  "zero!",
534
- *[](torch::Tensor& self) {
730
+ *[](Tensor& self) {
535
731
  return self.zero_();
536
732
  })
733
+ .define_method(
734
+ "detach",
735
+ *[](Tensor& self) {
736
+ return self.detach();
737
+ })
537
738
  .define_method(
538
739
  "detach!",
539
- *[](torch::Tensor& self) {
740
+ *[](Tensor& self) {
540
741
  return self.detach_();
541
742
  })
542
743
  .define_method(
543
744
  "_select",
544
- *[](torch::Tensor& self, int64_t dim, int64_t index) {
745
+ *[](Tensor& self, int64_t dim, int64_t index) {
545
746
  return self.select(dim, index);
546
747
  })
547
748
  .define_method(
548
749
  "_slice",
549
- *[](torch::Tensor& self, int64_t dim, int64_t start, int64_t end, int64_t step) {
750
+ *[](Tensor& self, int64_t dim, int64_t start, int64_t end, int64_t step) {
550
751
  return self.slice(dim, start, end, step);
551
752
  })
552
753
  .define_method(
553
754
  "_requires_grad!",
554
- *[](torch::Tensor& self, bool requires_grad) {
755
+ *[](Tensor& self, bool requires_grad) {
555
756
  return self.set_requires_grad(requires_grad);
556
757
  })
557
758
  .define_method(
558
759
  "_backward",
559
- *[](torch::Tensor& self) {
560
- return self.backward();
561
- })
562
- .define_method(
563
- "_backward_gradient",
564
- *[](torch::Tensor& self, const torch::Tensor& gradient) {
565
- return self.backward(gradient);
760
+ *[](Tensor& self, Object gradient) {
761
+ return gradient.is_nil() ? self.backward() : self.backward(from_ruby<torch::Tensor>(gradient));
566
762
  })
567
763
  .define_method(
568
764
  "grad",
569
- *[](torch::Tensor& self) {
765
+ *[](Tensor& self) {
570
766
  return self.grad();
571
767
  })
572
768
  .define_method(
573
769
  "_dtype",
574
- *[](torch::Tensor& self) {
770
+ *[](Tensor& self) {
575
771
  return (int) at::typeMetaToScalarType(self.dtype());
576
772
  })
577
773
  .define_method(
578
774
  "_type",
579
- *[](torch::Tensor& self, int dtype) {
775
+ *[](Tensor& self, int dtype) {
580
776
  return self.toType((torch::ScalarType) dtype);
581
777
  })
582
778
  .define_method(
583
779
  "_layout",
584
- *[](torch::Tensor& self) {
780
+ *[](Tensor& self) {
585
781
  std::stringstream s;
586
782
  s << self.layout();
587
783
  return s.str();
588
784
  })
589
785
  .define_method(
590
786
  "device",
591
- *[](torch::Tensor& self) {
787
+ *[](Tensor& self) {
592
788
  std::stringstream s;
593
789
  s << self.device();
594
790
  return s.str();
595
791
  })
596
792
  .define_method(
597
793
  "_view",
598
- *[](torch::Tensor& self, IntArrayRef size) {
794
+ *[](Tensor& self, IntArrayRef size) {
599
795
  return self.view(size);
600
796
  })
601
797
  .define_method(
602
798
  "resize_as!",
603
- *[](torch::Tensor& self, torch::Tensor& other) {
799
+ *[](Tensor& self, Tensor& other) {
604
800
  return self.resize_as_(other);
605
801
  })
606
802
  .define_method(
607
803
  "fill!",
608
- *[](torch::Tensor& self, Scalar value) {
804
+ *[](Tensor& self, Scalar value) {
609
805
  return self.fill_(value);
610
806
  })
807
+ .define_method(
808
+ "relu!",
809
+ *[](Tensor& self) {
810
+ return self.relu_();
811
+ })
611
812
  .define_method(
612
813
  "_add!",
613
- *[](torch::Tensor& self, torch::Tensor& other) {
814
+ *[](Tensor& self, Tensor& other) {
614
815
  return self.add_(other);
615
816
  })
616
817
  .define_method(
617
818
  "_add_alpha!",
618
- *[](torch::Tensor& self, torch::Tensor& other, Scalar alpha) {
819
+ *[](Tensor& self, Tensor& other, Scalar alpha) {
619
820
  return self.add_(other, alpha);
620
821
  })
621
822
  .define_method(
622
823
  "_add_scalar!",
623
- *[](torch::Tensor& self, Scalar other) {
824
+ *[](Tensor& self, Scalar other) {
624
825
  return self.add_(other);
625
826
  })
626
827
  .define_method(
627
828
  "normal!",
628
- *[](torch::Tensor& self, double mean, double std) {
829
+ *[](Tensor& self, double mean, double std) {
629
830
  return self.normal_(mean, std);
630
831
  })
832
+ .define_method(
833
+ "random!",
834
+ *[](Tensor& self, int64_t to) {
835
+ return self.random_(to);
836
+ })
631
837
  .define_method(
632
838
  "sub!",
633
- *[](torch::Tensor& self, torch::Tensor& other) {
839
+ *[](Tensor& self, Tensor& other) {
634
840
  return self.sub_(other);
635
841
  })
636
842
  .define_method(
637
843
  "_mul!",
638
- *[](torch::Tensor& self, torch::Tensor& other) {
844
+ *[](Tensor& self, Tensor& other) {
639
845
  return self.mul_(other);
640
846
  })
641
847
  .define_method(
642
848
  "_mul_scalar!",
643
- *[](torch::Tensor& self, Scalar other) {
849
+ *[](Tensor& self, Scalar other) {
644
850
  return self.mul_(other);
645
851
  })
646
852
  .define_method(
647
853
  "div!",
648
- *[](torch::Tensor& self, torch::Tensor& other) {
854
+ *[](Tensor& self, Tensor& other) {
649
855
  return self.div_(other);
650
856
  })
651
857
  .define_method(
652
858
  "sqrt!",
653
- *[](torch::Tensor& self) {
859
+ *[](Tensor& self) {
654
860
  return self.sqrt_();
655
861
  })
656
862
  .define_method(
657
863
  "unsqueeze!",
658
- *[](torch::Tensor& self, int64_t dim) {
864
+ *[](Tensor& self, int64_t dim) {
659
865
  return self.unsqueeze_(dim);
660
866
  })
661
867
  .define_method(
662
868
  "copy!",
663
- *[](torch::Tensor& self, torch::Tensor& src) {
869
+ *[](Tensor& self, Tensor& src) {
664
870
  return self.copy_(src);
665
871
  })
666
872
  .define_method(
667
873
  "clone",
668
- *[](torch::Tensor& self) {
874
+ *[](Tensor& self) {
669
875
  return self.clone();
670
876
  })
671
- .define_method(
672
- "log_softmax",
673
- *[](torch::Tensor& self, int64_t dim) {
674
- return self.log_softmax(dim);
675
- })
676
877
  .define_method(
677
878
  "data",
678
- *[](torch::Tensor& self) {
879
+ *[](Tensor& self) {
679
880
  return self.data();
680
881
  })
681
882
  .define_method(
682
883
  "_data",
683
- *[](torch::Tensor& self) {
884
+ *[](Tensor& self) {
684
885
  Array a;
685
886
  auto dtype = self.dtype();
686
887
 
@@ -732,17 +933,17 @@ void Init_ext()
732
933
  })
733
934
  .define_method(
734
935
  "_size",
735
- *[](torch::Tensor& self, int i) {
936
+ *[](Tensor& self, int i) {
736
937
  return self.size(i);
737
938
  })
738
939
  .define_method(
739
940
  "_to",
740
- *[](torch::Tensor& self, torch::Device device, int dtype, bool non_blocking, bool copy) {
941
+ *[](Tensor& self, torch::Device device, int dtype, bool non_blocking, bool copy) {
741
942
  return self.to(device, (torch::ScalarType) dtype, non_blocking, copy);
742
943
  })
743
944
  .define_singleton_method(
744
945
  "_make_subclass",
745
- *[](torch::Tensor& rd, bool requires_grad) {
946
+ *[](Tensor& rd, bool requires_grad) {
746
947
  auto data = torch::autograd::as_variable_ref(rd).detach();
747
948
  data.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(true);
748
949
  auto var = data.set_requires_grad(requires_grad);
@@ -793,32 +994,82 @@ void Init_ext()
793
994
 
794
995
  Module rb_mInit = define_module_under(rb_mNN, "Init")
795
996
  .define_singleton_method(
796
- "kaiming_uniform!",
797
- *[](torch::Tensor& input, double a) {
798
- return torch::nn::init::kaiming_uniform_(input, a);
997
+ "_calculate_gain",
998
+ *[](NonlinearityType nonlinearity, double param) {
999
+ return torch::nn::init::calculate_gain(nonlinearity, param);
799
1000
  })
800
1001
  .define_singleton_method(
801
- "normal!",
802
- *[](torch::Tensor& input) {
803
- return torch::nn::init::normal_(input);
1002
+ "_uniform!",
1003
+ *[](Tensor tensor, double low, double high) {
1004
+ return torch::nn::init::uniform_(tensor, low, high);
1005
+ })
1006
+ .define_singleton_method(
1007
+ "_normal!",
1008
+ *[](Tensor tensor, double mean, double std) {
1009
+ return torch::nn::init::normal_(tensor, mean, std);
1010
+ })
1011
+ .define_singleton_method(
1012
+ "_constant!",
1013
+ *[](Tensor tensor, Scalar value) {
1014
+ return torch::nn::init::constant_(tensor, value);
1015
+ })
1016
+ .define_singleton_method(
1017
+ "_ones!",
1018
+ *[](Tensor tensor) {
1019
+ return torch::nn::init::ones_(tensor);
1020
+ })
1021
+ .define_singleton_method(
1022
+ "_zeros!",
1023
+ *[](Tensor tensor) {
1024
+ return torch::nn::init::zeros_(tensor);
1025
+ })
1026
+ .define_singleton_method(
1027
+ "_eye!",
1028
+ *[](Tensor tensor) {
1029
+ return torch::nn::init::eye_(tensor);
1030
+ })
1031
+ .define_singleton_method(
1032
+ "_dirac!",
1033
+ *[](Tensor tensor) {
1034
+ return torch::nn::init::dirac_(tensor);
1035
+ })
1036
+ .define_singleton_method(
1037
+ "_xavier_uniform!",
1038
+ *[](Tensor tensor, double gain) {
1039
+ return torch::nn::init::xavier_uniform_(tensor, gain);
1040
+ })
1041
+ .define_singleton_method(
1042
+ "_xavier_normal!",
1043
+ *[](Tensor tensor, double gain) {
1044
+ return torch::nn::init::xavier_normal_(tensor, gain);
1045
+ })
1046
+ .define_singleton_method(
1047
+ "_kaiming_uniform!",
1048
+ *[](Tensor tensor, double a, FanModeType mode, NonlinearityType nonlinearity) {
1049
+ return torch::nn::init::kaiming_uniform_(tensor, a, mode, nonlinearity);
1050
+ })
1051
+ .define_singleton_method(
1052
+ "_kaiming_normal!",
1053
+ *[](Tensor tensor, double a, FanModeType mode, NonlinearityType nonlinearity) {
1054
+ return torch::nn::init::kaiming_normal_(tensor, a, mode, nonlinearity);
804
1055
  })
805
1056
  .define_singleton_method(
806
- "uniform!",
807
- *[](torch::Tensor& input, double to, double from) {
808
- return torch::nn::init::uniform_(input, to, from);
1057
+ "_orthogonal!",
1058
+ *[](Tensor tensor, double gain) {
1059
+ return torch::nn::init::orthogonal_(tensor, gain);
1060
+ })
1061
+ .define_singleton_method(
1062
+ "_sparse!",
1063
+ *[](Tensor tensor, double sparsity, double std) {
1064
+ return torch::nn::init::sparse_(tensor, sparsity, std);
809
1065
  });
810
1066
 
811
1067
  Class rb_cParameter = define_class_under<torch::autograd::Variable, torch::Tensor>(rb_mNN, "Parameter")
812
- // TODO return grad or nil to remove need for 2nd function
813
1068
  .define_method(
814
- "_grad",
815
- *[](torch::autograd::Variable& self) {
816
- return self.grad();
817
- })
818
- .define_method(
819
- "_grad_defined",
1069
+ "grad",
820
1070
  *[](torch::autograd::Variable& self) {
821
- return self.grad().defined();
1071
+ auto grad = self.grad();
1072
+ return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
822
1073
  });
823
1074
 
824
1075
  Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")