torch-rb 0.1.3 → 0.1.4

Files changed (48)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +9 -0
  3. data/README.md +1 -0
  4. data/ext/torch/ext.cpp +375 -124
  5. data/lib/torch.rb +101 -20
  6. data/lib/torch/ext.bundle +0 -0
  7. data/lib/torch/inspector.rb +23 -19
  8. data/lib/torch/nn/avg_pool2d.rb +14 -0
  9. data/lib/torch/nn/avg_poolnd.rb +9 -0
  10. data/lib/torch/nn/bce_loss.rb +13 -0
  11. data/lib/torch/nn/bilinear.rb +38 -0
  12. data/lib/torch/nn/conv2d.rb +2 -2
  13. data/lib/torch/nn/convnd.rb +3 -3
  14. data/lib/torch/nn/cosine_similarity.rb +15 -0
  15. data/lib/torch/nn/cross_entropy_loss.rb +14 -0
  16. data/lib/torch/nn/ctc_loss.rb +15 -0
  17. data/lib/torch/nn/dropoutnd.rb +2 -2
  18. data/lib/torch/nn/embedding_bag.rb +34 -0
  19. data/lib/torch/nn/functional.rb +101 -13
  20. data/lib/torch/nn/identity.rb +13 -0
  21. data/lib/torch/nn/init.rb +58 -1
  22. data/lib/torch/nn/kl_div_loss.rb +13 -0
  23. data/lib/torch/nn/l1_loss.rb +13 -0
  24. data/lib/torch/nn/leaky_relu.rb +20 -0
  25. data/lib/torch/nn/linear.rb +12 -11
  26. data/lib/torch/nn/log_softmax.rb +14 -0
  27. data/lib/torch/nn/loss.rb +10 -0
  28. data/lib/torch/nn/max_pool2d.rb +9 -0
  29. data/lib/torch/nn/max_poolnd.rb +19 -0
  30. data/lib/torch/nn/module.rb +120 -31
  31. data/lib/torch/nn/mse_loss.rb +2 -2
  32. data/lib/torch/nn/nll_loss.rb +14 -0
  33. data/lib/torch/nn/pairwise_distance.rb +16 -0
  34. data/lib/torch/nn/parameter.rb +0 -4
  35. data/lib/torch/nn/poisson_nll_loss.rb +16 -0
  36. data/lib/torch/nn/prelu.rb +19 -0
  37. data/lib/torch/nn/relu.rb +8 -3
  38. data/lib/torch/nn/sequential.rb +1 -10
  39. data/lib/torch/nn/sigmoid.rb +9 -0
  40. data/lib/torch/nn/softmax.rb +18 -0
  41. data/lib/torch/nn/softmax2d.rb +10 -0
  42. data/lib/torch/nn/softmin.rb +14 -0
  43. data/lib/torch/nn/softplus.rb +19 -0
  44. data/lib/torch/nn/weighted_loss.rb +10 -0
  45. data/lib/torch/random.rb +10 -0
  46. data/lib/torch/tensor.rb +28 -10
  47. data/lib/torch/version.rb +1 -1
  48. metadata +29 -2
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
- metadata.gz: e7f715179c9a84dc7399b80d93fd61f2bbb58a0156e6084dc4abb23e1d4a1b52
- data.tar.gz: 6928379ae7c92a77ad9dde4f4224ec33c6f8575a9b77585c0147e4f5361021de
+ metadata.gz: 4faccffe2d2fd29519ad9dcce0560978a07c734831b5f64bb4624a0037f2b08c
+ data.tar.gz: 4a8f873a9bb99c2311c856c59e5c43a5dfadd3f4f2460da1370ca1db888b79ad
  SHA512:
- metadata.gz: 9911a9e86d93f1e410776c44fdb3cd9aa06c83d1f0e42fdab8530970bea6520aed7906e96fb8243efd6b957453ebc13678b2b92e4c85b54407030a32c6196e08
- data.tar.gz: 0d080f5458a5dcf8fee19ce5e2e342bf6269432de6e78d923036232963ebb80daeea993c0bbf4af2d6da46593ac28a72a8232020a9fcb48acc3276c9e1ebebf3
+ metadata.gz: 199b3b47325b72b38786f50c39f0cfb9b11709f02edf4b77e1c4e9198baf5fa2b3924d639d6c0f7e1715193528d4cfab4f53fbba7e2b16e06ac462d37862cf3a
+ data.tar.gz: 4514f7aab60d9beabee47db175c9df6e3e1e93080b47eaeacff0f9dd4e8e737b476c8400f84bda4ccdf61aca4f7c81e010bcfbc2fa67f8c629c33cd6dcdcb54c
data/CHANGELOG.md CHANGED
@@ -1,3 +1,12 @@
+ ## 0.1.4 (2019-12-01)
+
+ - Added distance functions
+ - Added more activations
+ - Added more linear layers
+ - Added more loss functions
+ - Added more init methods
+ - Added support for tensor assignment
+
  ## 0.1.3 (2019-11-30)
 
  - Changed to BSD 3-Clause license to match PyTorch
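
Taken together, the 0.1.4 entries above map onto a small amount of user-facing Ruby. A minimal usage sketch (hypothetical: the class names follow the files added in this release, while the exact constructor arguments and the []= assignment syntax are assumptions based on the PyTorch API that torch-rb mirrors):

require "torch"

x = Torch.tensor([[0.5, -1.0], [2.0, 0.1]])

# new activations (leaky_relu.rb, softmax.rb, softplus.rb, ...)
y = Torch::NN::LeakyReLU.new.call(x)

# new loss functions (l1_loss.rb, kl_div_loss.rb, bce_loss.rb, ...)
loss = Torch::NN::L1Loss.new.call(y, Torch.zeros(2, 2))

# new distance functions (cosine_similarity.rb, pairwise_distance.rb)
sim = Torch::NN::CosineSimilarity.new.call(x, y)

# tensor assignment, new in this release (index syntax assumed)
x[0, 1] = 0.0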
data/README.md CHANGED
@@ -367,6 +367,7 @@ Here are a few full examples:
 
  - [Image classification with MNIST](examples/mnist)
  - [Collaborative filtering with MovieLens](examples/movielens)
+ - [Word embeddings](examples/nlp) [master]
 
  ## LibTorch Installation
 
data/ext/torch/ext.cpp CHANGED
@@ -132,6 +132,112 @@ TensorList from_ruby<TensorList>(Object x)
  return TensorList(x);
  }
 
+ class FanModeType {
+ std::string s;
+ public:
+ FanModeType(Object o) {
+ s = String(o).str();
+ }
+ // TODO switch NonlinearityType after LibTorch 1.4 release
+ operator torch::nn::init::FanMode() {
+ if (s == "fan_in") {
+ return torch::nn::init::FanMode::FanIn;
+ } else if (s == "fan_out") {
+ return torch::nn::init::FanMode::FanOut;
+ } else {
+ throw std::runtime_error("Unsupported nonlinearity type: " + s);
+ }
+ }
+ };
+
+ template<>
+ inline
+ FanModeType from_ruby<FanModeType>(Object x)
+ {
+ return FanModeType(x);
+ }
+
+ class NonlinearityType {
+ std::string s;
+ public:
+ NonlinearityType(Object o) {
+ s = String(o).str();
+ }
+ // TODO switch NonlinearityType after LibTorch 1.4 release
+ operator torch::nn::init::Nonlinearity() {
+ if (s == "linear") {
+ return torch::nn::init::Nonlinearity::Linear;
+ } else if (s == "conv1d") {
+ return torch::nn::init::Nonlinearity::Conv1D;
+ } else if (s == "conv2d") {
+ return torch::nn::init::Nonlinearity::Conv2D;
+ } else if (s == "conv3d") {
+ return torch::nn::init::Nonlinearity::Conv3D;
+ } else if (s == "conv_transpose1d") {
+ return torch::nn::init::Nonlinearity::ConvTranspose1D;
+ } else if (s == "conv_transpose2d") {
+ return torch::nn::init::Nonlinearity::ConvTranspose2D;
+ } else if (s == "conv_transpose3d") {
+ return torch::nn::init::Nonlinearity::ConvTranspose3D;
+ } else if (s == "sigmoid") {
+ return torch::nn::init::Nonlinearity::Sigmoid;
+ } else if (s == "tanh") {
+ return torch::nn::init::Nonlinearity::Tanh;
+ } else if (s == "relu") {
+ return torch::nn::init::Nonlinearity::ReLU;
+ } else if (s == "leaky_relu") {
+ return torch::nn::init::Nonlinearity::LeakyReLU;
+ } else {
+ throw std::runtime_error("Unsupported nonlinearity type: " + s);
+ }
+ }
+ };
+
+ template<>
+ inline
+ NonlinearityType from_ruby<NonlinearityType>(Object x)
+ {
+ return NonlinearityType(x);
+ }
+
+ class MyReduction {
+ Object value;
+ public:
+ MyReduction(Object o) {
+ value = o;
+ }
+ operator int64_t() {
+ if (value.is_nil()) {
+ return Reduction::None;
+ }
+
+ std::string s = String(value).str();
+ if (s == "mean") {
+ return Reduction::Mean;
+ } else if (s == "sum") {
+ return Reduction::Sum;
+ } else {
+ throw std::runtime_error("Unsupported reduction: " + s);
+ }
+ }
+ };
+
+ template<>
+ inline
+ MyReduction from_ruby<MyReduction>(Object x)
+ {
+ return MyReduction(x);
+ }
+
+ typedef torch::Tensor Tensor;
+
+ Object tensor_array(std::tuple<torch::Tensor, torch::Tensor> x) {
+ Array a;
+ a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
+ a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
+ return Object(a);
+ }
+
  extern "C"
  void Init_ext()
  {
@@ -148,7 +254,7 @@ void Init_ext()
  })
  .define_singleton_method(
  "floating_point?",
- *[](torch::Tensor& input) {
+ *[](Tensor& input) {
  return torch::is_floating_point(input);
  })
  .define_singleton_method(
@@ -220,32 +326,32 @@ void Init_ext()
  // begin operations
  .define_singleton_method(
  "_mean",
- *[](torch::Tensor& input) {
+ *[](Tensor& input) {
  return torch::mean(input);
  })
  .define_singleton_method(
  "_mean_dim",
- *[](torch::Tensor& input, int64_t dim, bool keepdim) {
+ *[](Tensor& input, int64_t dim, bool keepdim) {
  return torch::mean(input, dim, keepdim);
  })
  .define_singleton_method(
  "_sum",
- *[](torch::Tensor& input) {
+ *[](Tensor& input) {
  return torch::sum(input);
  })
  .define_singleton_method(
  "_sum_dim",
- *[](torch::Tensor& input, int64_t dim, bool keepdim) {
+ *[](Tensor& input, int64_t dim, bool keepdim) {
  return torch::sum(input, dim, keepdim);
  })
  .define_singleton_method(
  "_argmax",
- *[](torch::Tensor& input) {
+ *[](Tensor& input) {
  return torch::argmax(input);
  })
  .define_singleton_method(
  "_argmax_dim",
- *[](torch::Tensor& input, int64_t dim, bool keepdim) {
+ *[](Tensor& input, int64_t dim, bool keepdim) {
  return torch::argmax(input, dim, keepdim);
  })
  .define_singleton_method(
@@ -255,239 +361,322 @@ void Init_ext()
  })
  .define_singleton_method(
  "_norm",
- *[](torch::Tensor& input) {
+ *[](Tensor& input) {
  return torch::norm(input);
  })
  .define_singleton_method(
  "_min",
- *[](torch::Tensor& input) {
+ *[](Tensor& input) {
  return torch::min(input);
  })
  .define_singleton_method(
  "_max",
- *[](torch::Tensor& input) {
+ *[](Tensor& input) {
  return torch::max(input);
  })
  .define_singleton_method(
  "_max_out",
- *[](torch::Tensor &max, torch::Tensor &max_indices, const torch::Tensor &input, int64_t dim, bool keepdim) {
- // TODO add return value
- torch::_max_out(max, max_indices, input, dim, keepdim);
+ *[](Tensor &max, Tensor &max_indices, const Tensor &input, int64_t dim, bool keepdim) {
+ return tensor_array(torch::_max_out(max, max_indices, input, dim, keepdim));
  })
  .define_singleton_method(
  "_sqrt",
- *[](torch::Tensor& input) {
+ *[](Tensor& input) {
  return torch::sqrt(input);
  })
  .define_singleton_method(
  "_exp",
- *[](torch::Tensor& input) {
+ *[](Tensor& input) {
  return torch::exp(input);
  })
  .define_singleton_method(
  "_log",
- *[](torch::Tensor& input) {
+ *[](Tensor& input) {
  return torch::log(input);
  })
  .define_singleton_method(
  "_sign",
- *[](torch::Tensor& input) {
+ *[](Tensor& input) {
  return torch::sign(input);
  })
  .define_singleton_method(
  "_unsqueeze",
- *[](torch::Tensor& input, int64_t dim) {
+ *[](Tensor& input, int64_t dim) {
  return torch::unsqueeze(input, dim);
  })
  .define_singleton_method(
  "_dot",
- *[](torch::Tensor& input, torch::Tensor& tensor) {
+ *[](Tensor& input, Tensor& tensor) {
  return torch::dot(input, tensor);
  })
  .define_singleton_method(
  "_matmul",
- *[](torch::Tensor& input, torch::Tensor& other) {
+ *[](Tensor& input, Tensor& other) {
  return torch::matmul(input, other);
  })
  .define_singleton_method(
  "_eq",
- *[](torch::Tensor& input, torch::Tensor& other) {
+ *[](Tensor& input, Tensor& other) {
  return torch::eq(input, other);
  })
  .define_singleton_method(
  "_gt",
  // TODO support tensors
- *[](torch::Tensor& input, Scalar other) {
+ *[](Tensor& input, Scalar other) {
  return torch::gt(input, other);
  })
  .define_singleton_method(
  "_lt",
  // TODO support tensors
- *[](torch::Tensor& input, Scalar other) {
+ *[](Tensor& input, Scalar other) {
  return torch::lt(input, other);
  })
  .define_singleton_method(
  "_add",
- *[](torch::Tensor& input, torch::Tensor& other) {
+ *[](Tensor& input, Tensor& other) {
  return torch::add(input, other);
  })
  .define_singleton_method(
  "_add_scalar",
- *[](torch::Tensor& input, Scalar other) {
+ *[](Tensor& input, Scalar other) {
  return torch::add(input, other);
  })
  .define_singleton_method(
  "_add_out",
- *[](torch::Tensor& out, torch::Tensor& input, torch::Tensor& other) {
+ *[](Tensor& out, Tensor& input, Tensor& other) {
  return torch::add_out(out, input, other);
  })
  .define_singleton_method(
  "_sub",
- *[](torch::Tensor& input, torch::Tensor& other) {
+ *[](Tensor& input, Tensor& other) {
  return torch::sub(input, other);
  })
  .define_singleton_method(
  "_sub_scalar",
- *[](torch::Tensor& input, Scalar other) {
+ *[](Tensor& input, Scalar other) {
  return torch::sub(input, other);
  })
  .define_singleton_method(
  "_mul",
- *[](torch::Tensor& input, torch::Tensor& other) {
+ *[](Tensor& input, Tensor& other) {
  return torch::mul(input, other);
  })
  .define_singleton_method(
  "_mul_scalar",
- *[](torch::Tensor& input, Scalar other) {
+ *[](Tensor& input, Scalar other) {
  return torch::mul(input, other);
  })
  .define_singleton_method(
  "_div",
- *[](torch::Tensor& input, torch::Tensor& other) {
+ *[](Tensor& input, Tensor& other) {
  return torch::div(input, other);
  })
  .define_singleton_method(
  "_div_scalar",
- *[](torch::Tensor& input, Scalar other) {
+ *[](Tensor& input, Scalar other) {
  return torch::div(input, other);
  })
  .define_singleton_method(
  "_remainder",
- *[](torch::Tensor& input, torch::Tensor& other) {
+ *[](Tensor& input, Tensor& other) {
  return torch::remainder(input, other);
  })
  .define_singleton_method(
  "_remainder_scalar",
- *[](torch::Tensor& input, Scalar other) {
+ *[](Tensor& input, Scalar other) {
  return torch::remainder(input, other);
  })
  .define_singleton_method(
  "_pow",
- *[](torch::Tensor& input, Scalar exponent) {
+ *[](Tensor& input, Scalar exponent) {
  return torch::pow(input, exponent);
  })
+ .define_singleton_method(
+ "_topk",
+ *[](Tensor& input, int64_t k) {
+ return tensor_array(torch::topk(input, k));
+ })
+ .define_singleton_method(
+ "_sigmoid",
+ *[](Tensor& input) {
+ return torch::sigmoid(input);
+ })
+ .define_singleton_method(
+ "_softplus",
+ *[](const Tensor &input, Scalar beta, Scalar threshold) {
+ return torch::softplus(input, beta, threshold);
+ })
+ .define_singleton_method(
+ "_softmax",
+ *[](const Tensor &input, int64_t dim) {
+ return torch::softmax(input, dim);
+ })
+ .define_singleton_method(
+ "_log_softmax",
+ *[](Tensor& input, int64_t dim) {
+ return torch::log_softmax(input, dim);
+ })
  .define_singleton_method(
  "_abs",
- *[](torch::Tensor& input) {
+ *[](Tensor& input) {
  return torch::abs(input);
  })
  .define_singleton_method(
  "_neg",
- *[](torch::Tensor& input) {
+ *[](Tensor& input) {
  return torch::neg(input);
  })
  .define_singleton_method(
  "_reshape",
- *[](torch::Tensor& input, IntArrayRef shape) {
+ *[](Tensor& input, IntArrayRef shape) {
  return torch::reshape(input, shape);
  })
  .define_singleton_method(
  "_flatten",
- *[](torch::Tensor& input, int64_t start_dim, int64_t end_dim) {
+ *[](Tensor& input, int64_t start_dim, int64_t end_dim) {
  return torch::flatten(input, start_dim, end_dim);
  })
  .define_singleton_method(
  "relu",
- *[](torch::Tensor& input) {
+ *[](Tensor& input) {
  return torch::relu(input);
  })
+ .define_singleton_method(
+ "prelu",
+ *[](torch::Tensor& input, torch::Tensor& weight) {
+ return torch::prelu(input, weight);
+ })
+ .define_singleton_method(
+ "leaky_relu",
+ *[](torch::Tensor& input, Scalar negative_slope) {
+ return torch::leaky_relu(input, negative_slope);
+ })
  .define_singleton_method(
  "conv2d",
- *[](torch::Tensor& input, torch::Tensor& weight, torch::Tensor& bias, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, int64_t groups) {
+ *[](Tensor& input, Tensor& weight, Tensor& bias, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, int64_t groups) {
  return torch::conv2d(input, weight, bias, stride, padding, dilation, groups);
  })
+ // linear layers
+ .define_singleton_method(
+ "bilinear",
+ *[](const Tensor &input1, const Tensor &input2, const Tensor &weight, const Tensor &bias) {
+ return torch::bilinear(input1, input2, weight, bias);
+ })
  .define_singleton_method(
  "linear",
- *[](torch::Tensor& input, torch::Tensor& weight, torch::Tensor& bias) {
+ *[](Tensor& input, Tensor& weight, Tensor& bias) {
  return torch::linear(input, weight, bias);
  })
+ // pooling layers
  .define_singleton_method(
  "max_pool2d",
- *[](torch::Tensor& input, IntArrayRef kernel_size) {
+ *[](Tensor& input, IntArrayRef kernel_size) {
  return torch::max_pool2d(input, kernel_size);
  })
  .define_singleton_method(
  "avg_pool2d",
- *[](torch::Tensor& input, IntArrayRef kernel_size) {
+ *[](Tensor& input, IntArrayRef kernel_size) {
  return torch::avg_pool2d(input, kernel_size);
  })
  .define_singleton_method(
  "_dropout",
- *[](torch::Tensor& input, float p, bool train) {
+ *[](Tensor& input, float p, bool train) {
  return torch::dropout(input, p, train);
  })
  .define_singleton_method(
  "_dropout!",
- *[](torch::Tensor& input, float p, bool train) {
+ *[](Tensor& input, float p, bool train) {
  return torch::dropout_(input, p, train);
  })
  .define_singleton_method(
  "_feature_dropout",
- *[](torch::Tensor& input, float p, bool train) {
+ *[](Tensor& input, float p, bool train) {
  return torch::feature_dropout(input, p, train);
  })
  .define_singleton_method(
  "_feature_dropout!",
- *[](torch::Tensor& input, float p, bool train) {
+ *[](Tensor& input, float p, bool train) {
  return torch::feature_dropout_(input, p, train);
  })
  .define_singleton_method(
  "_alpha_dropout",
- *[](torch::Tensor& input, float p, bool train) {
+ *[](Tensor& input, float p, bool train) {
  return torch::alpha_dropout(input, p, train);
  })
  .define_singleton_method(
  "_alpha_dropout!",
- *[](torch::Tensor& input, float p, bool train) {
+ *[](Tensor& input, float p, bool train) {
  return torch::alpha_dropout_(input, p, train);
  })
  .define_singleton_method(
  "_feature_alpha_dropout",
- *[](torch::Tensor& input, float p, bool train) {
+ *[](Tensor& input, float p, bool train) {
  return torch::feature_alpha_dropout(input, p, train);
  })
  .define_singleton_method(
  "_feature_alpha_dropout!",
- *[](torch::Tensor& input, float p, bool train) {
+ *[](Tensor& input, float p, bool train) {
  return torch::feature_alpha_dropout_(input, p, train);
  })
+ // sparse layers
  .define_singleton_method(
  "_embedding",
  // weight and indices are swapped from Python interface
- *[](const torch::Tensor &indices, const torch::Tensor &weight, int64_t padding_idx, bool scale_grad_by_freq, bool sparse) {
+ *[](const Tensor &indices, const Tensor &weight, int64_t padding_idx, bool scale_grad_by_freq, bool sparse) {
  return torch::embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse);
  })
+ .define_singleton_method(
+ "_embedding_bag",
+ // weight and indices are swapped from Python interface
+ *[](const Tensor &weight, const Tensor &indices, const Tensor &offsets, bool scale_grad_by_freq, int64_t mode, bool sparse, const Tensor &per_sample_weights) {
+ return torch::embedding_bag(weight, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights);
+ })
+ // distance functions
+ .define_singleton_method(
+ "_cosine_similarity",
+ *[](const Tensor &x1, const Tensor &x2, int64_t dim, double eps) {
+ return torch::cosine_similarity(x1, x2, dim, eps);
+ })
+ .define_singleton_method(
+ "_pairwise_distance",
+ *[](const Tensor &x1, const Tensor &x2, double p, double eps, bool keepdim) {
+ return torch::pairwise_distance(x1, x2, p, eps, keepdim);
+ })
+ // loss functions
+ .define_singleton_method(
+ "binary_cross_entropy",
+ *[](Tensor& input, Tensor& target, MyReduction reduction) {
+ return torch::binary_cross_entropy(input, target, {}, reduction);
+ })
+ .define_singleton_method(
+ "ctc_loss",
+ *[](const Tensor &log_probs, const Tensor &targets, IntArrayRef input_lengths, IntArrayRef target_lengths, int64_t blank, MyReduction reduction, bool zero_infinity) {
+ return torch::ctc_loss(log_probs, targets, input_lengths, target_lengths, blank, reduction, zero_infinity);
+ })
+ .define_singleton_method(
+ "kl_div",
+ *[](Tensor& input, Tensor& target, MyReduction reduction) {
+ return torch::kl_div(input, target, reduction);
+ })
+ .define_singleton_method(
+ "l1_loss",
+ *[](Tensor& input, Tensor& target, MyReduction reduction) {
+ return torch::l1_loss(input, target, reduction);
+ })
  .define_singleton_method(
  "mse_loss",
- *[](torch::Tensor& input, torch::Tensor& target, std::string reduction) {
- auto red = reduction == "mean" ? Reduction::Mean : Reduction::Sum;
- return torch::mse_loss(input, target, red);
+ *[](Tensor& input, Tensor& target, MyReduction reduction) {
+ return torch::mse_loss(input, target, reduction);
  })
  .define_singleton_method(
  "nll_loss",
- *[](torch::Tensor& input, torch::Tensor& target, std::string reduction) {
- auto red = reduction == "mean" ? Reduction::Mean : Reduction::Sum;
- return torch::nll_loss(input, target, {}, red);
+ *[](Tensor& input, Tensor& target, MyReduction reduction, int64_t ignore_index) {
+ return torch::nll_loss(input, target, {}, reduction, ignore_index);
+ })
+ .define_singleton_method(
+ "poisson_nll_loss",
+ *[](const Tensor &input, const Tensor &target, bool log_input, bool full, double eps, MyReduction reduction) {
+ return torch::poisson_nll_loss(input, target, log_input, full, eps, reduction);
  })
  .define_singleton_method("numel", &torch::numel)
  .define_singleton_method(
@@ -500,11 +689,18 @@ void Init_ext()
  "_tensor",
  *[](Object o, IntArrayRef size, const torch::TensorOptions &options) {
  Array a = Array(o);
- std::vector<float> vec;
- for (size_t i = 0; i < a.size(); i++) {
- vec.push_back(from_ruby<float>(a[i]));
+ auto dtype = options.dtype();
+ torch::Tensor t;
+ if (dtype == torch::kBool) {
+ throw std::runtime_error("Cannot create bool from tensor method yet");
+ } else {
+ std::vector<float> vec;
+ for (size_t i = 0; i < a.size(); i++) {
+ vec.push_back(from_ruby<float>(a[i]));
+ }
+ t = torch::tensor(vec, options);
  }
- return torch::tensor(vec, options).reshape(size);
+ return t.reshape(size);
  });
 
  Class rb_cTensor = define_class_under<torch::Tensor>(rb_mTorch, "Tensor")
@@ -521,166 +717,171 @@ void Init_ext()
  .define_method("view_as", &torch::Tensor::view_as)
  .define_method(
  "addcmul!",
- *[](torch::Tensor& self, Scalar value, const torch::Tensor & tensor1, const torch::Tensor & tensor2) {
+ *[](Tensor& self, Scalar value, const Tensor & tensor1, const Tensor & tensor2) {
  return self.addcmul_(tensor1, tensor2, value);
  })
  .define_method(
  "addcdiv!",
- *[](torch::Tensor& self, Scalar value, const torch::Tensor & tensor1, const torch::Tensor & tensor2) {
+ *[](Tensor& self, Scalar value, const Tensor & tensor1, const Tensor & tensor2) {
  return self.addcdiv_(tensor1, tensor2, value);
  })
  .define_method(
  "zero!",
- *[](torch::Tensor& self) {
+ *[](Tensor& self) {
  return self.zero_();
  })
+ .define_method(
+ "detach",
+ *[](Tensor& self) {
+ return self.detach();
+ })
  .define_method(
  "detach!",
- *[](torch::Tensor& self) {
+ *[](Tensor& self) {
  return self.detach_();
  })
  .define_method(
  "_select",
- *[](torch::Tensor& self, int64_t dim, int64_t index) {
+ *[](Tensor& self, int64_t dim, int64_t index) {
  return self.select(dim, index);
  })
  .define_method(
  "_slice",
- *[](torch::Tensor& self, int64_t dim, int64_t start, int64_t end, int64_t step) {
+ *[](Tensor& self, int64_t dim, int64_t start, int64_t end, int64_t step) {
  return self.slice(dim, start, end, step);
  })
  .define_method(
  "_requires_grad!",
- *[](torch::Tensor& self, bool requires_grad) {
+ *[](Tensor& self, bool requires_grad) {
  return self.set_requires_grad(requires_grad);
  })
  .define_method(
  "_backward",
- *[](torch::Tensor& self) {
- return self.backward();
- })
- .define_method(
- "_backward_gradient",
- *[](torch::Tensor& self, const torch::Tensor& gradient) {
- return self.backward(gradient);
+ *[](Tensor& self, Object gradient) {
+ return gradient.is_nil() ? self.backward() : self.backward(from_ruby<torch::Tensor>(gradient));
  })
  .define_method(
  "grad",
- *[](torch::Tensor& self) {
+ *[](Tensor& self) {
  return self.grad();
  })
  .define_method(
  "_dtype",
- *[](torch::Tensor& self) {
+ *[](Tensor& self) {
  return (int) at::typeMetaToScalarType(self.dtype());
  })
  .define_method(
  "_type",
- *[](torch::Tensor& self, int dtype) {
+ *[](Tensor& self, int dtype) {
  return self.toType((torch::ScalarType) dtype);
  })
  .define_method(
  "_layout",
- *[](torch::Tensor& self) {
+ *[](Tensor& self) {
  std::stringstream s;
  s << self.layout();
  return s.str();
  })
  .define_method(
  "device",
- *[](torch::Tensor& self) {
+ *[](Tensor& self) {
  std::stringstream s;
  s << self.device();
  return s.str();
  })
  .define_method(
  "_view",
- *[](torch::Tensor& self, IntArrayRef size) {
+ *[](Tensor& self, IntArrayRef size) {
  return self.view(size);
  })
  .define_method(
  "resize_as!",
- *[](torch::Tensor& self, torch::Tensor& other) {
+ *[](Tensor& self, Tensor& other) {
  return self.resize_as_(other);
  })
  .define_method(
  "fill!",
- *[](torch::Tensor& self, Scalar value) {
+ *[](Tensor& self, Scalar value) {
  return self.fill_(value);
  })
+ .define_method(
+ "relu!",
+ *[](Tensor& self) {
+ return self.relu_();
+ })
  .define_method(
  "_add!",
- *[](torch::Tensor& self, torch::Tensor& other) {
+ *[](Tensor& self, Tensor& other) {
  return self.add_(other);
  })
  .define_method(
  "_add_alpha!",
- *[](torch::Tensor& self, torch::Tensor& other, Scalar alpha) {
+ *[](Tensor& self, Tensor& other, Scalar alpha) {
  return self.add_(other, alpha);
  })
  .define_method(
  "_add_scalar!",
- *[](torch::Tensor& self, Scalar other) {
+ *[](Tensor& self, Scalar other) {
  return self.add_(other);
  })
  .define_method(
  "normal!",
- *[](torch::Tensor& self, double mean, double std) {
+ *[](Tensor& self, double mean, double std) {
  return self.normal_(mean, std);
  })
+ .define_method(
+ "random!",
+ *[](Tensor& self, int64_t to) {
+ return self.random_(to);
+ })
  .define_method(
  "sub!",
- *[](torch::Tensor& self, torch::Tensor& other) {
+ *[](Tensor& self, Tensor& other) {
  return self.sub_(other);
  })
  .define_method(
  "_mul!",
- *[](torch::Tensor& self, torch::Tensor& other) {
+ *[](Tensor& self, Tensor& other) {
  return self.mul_(other);
  })
  .define_method(
  "_mul_scalar!",
- *[](torch::Tensor& self, Scalar other) {
+ *[](Tensor& self, Scalar other) {
  return self.mul_(other);
  })
  .define_method(
  "div!",
- *[](torch::Tensor& self, torch::Tensor& other) {
+ *[](Tensor& self, Tensor& other) {
  return self.div_(other);
  })
  .define_method(
  "sqrt!",
- *[](torch::Tensor& self) {
+ *[](Tensor& self) {
  return self.sqrt_();
  })
  .define_method(
  "unsqueeze!",
- *[](torch::Tensor& self, int64_t dim) {
+ *[](Tensor& self, int64_t dim) {
  return self.unsqueeze_(dim);
  })
  .define_method(
  "copy!",
- *[](torch::Tensor& self, torch::Tensor& src) {
+ *[](Tensor& self, Tensor& src) {
  return self.copy_(src);
  })
  .define_method(
  "clone",
- *[](torch::Tensor& self) {
+ *[](Tensor& self) {
  return self.clone();
  })
- .define_method(
- "log_softmax",
- *[](torch::Tensor& self, int64_t dim) {
- return self.log_softmax(dim);
- })
  .define_method(
  "data",
- *[](torch::Tensor& self) {
+ *[](Tensor& self) {
  return self.data();
  })
  .define_method(
  "_data",
- *[](torch::Tensor& self) {
+ *[](Tensor& self) {
  Array a;
  auto dtype = self.dtype();
 
@@ -732,17 +933,17 @@ void Init_ext()
  })
  .define_method(
  "_size",
- *[](torch::Tensor& self, int i) {
+ *[](Tensor& self, int i) {
  return self.size(i);
  })
  .define_method(
  "_to",
- *[](torch::Tensor& self, torch::Device device, int dtype, bool non_blocking, bool copy) {
+ *[](Tensor& self, torch::Device device, int dtype, bool non_blocking, bool copy) {
  return self.to(device, (torch::ScalarType) dtype, non_blocking, copy);
  })
  .define_singleton_method(
  "_make_subclass",
- *[](torch::Tensor& rd, bool requires_grad) {
+ *[](Tensor& rd, bool requires_grad) {
  auto data = torch::autograd::as_variable_ref(rd).detach();
  data.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(true);
  auto var = data.set_requires_grad(requires_grad);
@@ -793,32 +994,82 @@ void Init_ext()
 
  Module rb_mInit = define_module_under(rb_mNN, "Init")
  .define_singleton_method(
- "kaiming_uniform!",
- *[](torch::Tensor& input, double a) {
- return torch::nn::init::kaiming_uniform_(input, a);
+ "_calculate_gain",
+ *[](NonlinearityType nonlinearity, double param) {
+ return torch::nn::init::calculate_gain(nonlinearity, param);
  })
  .define_singleton_method(
- "normal!",
- *[](torch::Tensor& input) {
- return torch::nn::init::normal_(input);
+ "_uniform!",
+ *[](Tensor tensor, double low, double high) {
+ return torch::nn::init::uniform_(tensor, low, high);
+ })
+ .define_singleton_method(
+ "_normal!",
+ *[](Tensor tensor, double mean, double std) {
+ return torch::nn::init::normal_(tensor, mean, std);
+ })
+ .define_singleton_method(
+ "_constant!",
+ *[](Tensor tensor, Scalar value) {
+ return torch::nn::init::constant_(tensor, value);
+ })
+ .define_singleton_method(
+ "_ones!",
+ *[](Tensor tensor) {
+ return torch::nn::init::ones_(tensor);
+ })
+ .define_singleton_method(
+ "_zeros!",
+ *[](Tensor tensor) {
+ return torch::nn::init::zeros_(tensor);
+ })
+ .define_singleton_method(
+ "_eye!",
+ *[](Tensor tensor) {
+ return torch::nn::init::eye_(tensor);
+ })
+ .define_singleton_method(
+ "_dirac!",
+ *[](Tensor tensor) {
+ return torch::nn::init::dirac_(tensor);
+ })
+ .define_singleton_method(
+ "_xavier_uniform!",
+ *[](Tensor tensor, double gain) {
+ return torch::nn::init::xavier_uniform_(tensor, gain);
+ })
+ .define_singleton_method(
+ "_xavier_normal!",
+ *[](Tensor tensor, double gain) {
+ return torch::nn::init::xavier_normal_(tensor, gain);
+ })
+ .define_singleton_method(
+ "_kaiming_uniform!",
+ *[](Tensor tensor, double a, FanModeType mode, NonlinearityType nonlinearity) {
+ return torch::nn::init::kaiming_uniform_(tensor, a, mode, nonlinearity);
+ })
+ .define_singleton_method(
+ "_kaiming_normal!",
+ *[](Tensor tensor, double a, FanModeType mode, NonlinearityType nonlinearity) {
+ return torch::nn::init::kaiming_normal_(tensor, a, mode, nonlinearity);
  })
  .define_singleton_method(
- "uniform!",
- *[](torch::Tensor& input, double to, double from) {
- return torch::nn::init::uniform_(input, to, from);
+ "_orthogonal!",
+ *[](Tensor tensor, double gain) {
+ return torch::nn::init::orthogonal_(tensor, gain);
+ })
+ .define_singleton_method(
+ "_sparse!",
+ *[](Tensor tensor, double sparsity, double std) {
+ return torch::nn::init::sparse_(tensor, sparsity, std);
  });
 
  Class rb_cParameter = define_class_under<torch::autograd::Variable, torch::Tensor>(rb_mNN, "Parameter")
- // TODO return grad or nil to remove need for 2nd function
  .define_method(
- "_grad",
- *[](torch::autograd::Variable& self) {
- return self.grad();
- })
- .define_method(
- "_grad_defined",
+ "grad",
  *[](torch::autograd::Variable& self) {
- return self.grad().defined();
+ auto grad = self.grad();
+ return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
  });
 
  Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")
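
The Init and Parameter changes in this last hunk are worth a note: every init binding gained an underscore prefix plus string-driven fan-mode/nonlinearity conversion, and Parameter's grad now returns nil instead of requiring a separate _grad_defined check. A hypothetical sketch of the corresponding Ruby surface (method and keyword names assumed to mirror torch.nn.init, minus the underscore prefix used by the C++ bindings):

w = Torch.empty(3, 5)
Torch::NN::Init.kaiming_uniform!(w, a: Math.sqrt(5), mode: "fan_in", nonlinearity: "leaky_relu")
Torch::NN::Init.xavier_uniform!(w, gain: Torch::NN::Init.calculate_gain("relu"))

# Parameter#grad is nil until a backward pass defines it
p = Torch::NN::Parameter.new(Torch.zeros(2, 2))
p.grad # => nil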