torch-rb 0.1.4 → 0.1.5

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +6 -0
  3. data/README.md +5 -3
  4. data/ext/torch/ext.cpp +22 -548
  5. data/ext/torch/extconf.rb +6 -0
  6. data/ext/torch/nn_functions.cpp +595 -0
  7. data/ext/torch/nn_functions.hpp +6 -0
  8. data/ext/torch/templates.hpp +250 -0
  9. data/ext/torch/tensor_functions.cpp +1860 -0
  10. data/ext/torch/tensor_functions.hpp +6 -0
  11. data/ext/torch/torch_functions.cpp +2875 -0
  12. data/ext/torch/torch_functions.hpp +6 -0
  13. data/lib/torch.rb +68 -129
  14. data/lib/torch/ext.bundle +0 -0
  15. data/lib/torch/native/dispatcher.rb +48 -0
  16. data/lib/torch/native/function.rb +78 -0
  17. data/lib/torch/native/generator.rb +149 -0
  18. data/lib/torch/native/native_functions.yaml +6837 -0
  19. data/lib/torch/native/parser.rb +97 -0
  20. data/lib/torch/nn/bce_with_logits_loss.rb +15 -0
  21. data/lib/torch/nn/conv2d.rb +0 -2
  22. data/lib/torch/nn/cosine_embedding_loss.rb +14 -0
  23. data/lib/torch/nn/functional.rb +55 -16
  24. data/lib/torch/nn/hinge_embedding_loss.rb +14 -0
  25. data/lib/torch/nn/identity.rb +1 -0
  26. data/lib/torch/nn/margin_ranking_loss.rb +14 -0
  27. data/lib/torch/nn/module.rb +59 -12
  28. data/lib/torch/nn/multi_label_margin_loss.rb +13 -0
  29. data/lib/torch/nn/multi_label_soft_margin_loss.rb +13 -0
  30. data/lib/torch/nn/multi_margin_loss.rb +17 -0
  31. data/lib/torch/nn/parameter.rb +4 -0
  32. data/lib/torch/nn/rnn.rb +22 -0
  33. data/lib/torch/nn/rnn_base.rb +154 -0
  34. data/lib/torch/nn/smooth_l1_loss.rb +13 -0
  35. data/lib/torch/nn/soft_margin_loss.rb +13 -0
  36. data/lib/torch/nn/triplet_margin_loss.rb +18 -0
  37. data/lib/torch/tensor.rb +19 -19
  38. data/lib/torch/version.rb +1 -1
  39. metadata +26 -2
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 4faccffe2d2fd29519ad9dcce0560978a07c734831b5f64bb4624a0037f2b08c
4
- data.tar.gz: 4a8f873a9bb99c2311c856c59e5c43a5dfadd3f4f2460da1370ca1db888b79ad
3
+ metadata.gz: 6b47306ed525e1a20d25cb8324d4658f750c18afa5704c9b7bafc215d8f568c1
4
+ data.tar.gz: dad6ddf955b111989b061e5af146006a32c83dc1ea1ca5005a6b6e34bc9a4892
5
5
  SHA512:
6
- metadata.gz: 199b3b47325b72b38786f50c39f0cfb9b11709f02edf4b77e1c4e9198baf5fa2b3924d639d6c0f7e1715193528d4cfab4f53fbba7e2b16e06ac462d37862cf3a
7
- data.tar.gz: 4514f7aab60d9beabee47db175c9df6e3e1e93080b47eaeacff0f9dd4e8e737b476c8400f84bda4ccdf61aca4f7c81e010bcfbc2fa67f8c629c33cd6dcdcb54c
6
+ metadata.gz: 5d26e3642bf7cd921b9b570052df353d4c32b1bd955a6fbbf5f30249631fa4c0d4624f4fa91a1c06f61b3b0d6461cd117ab4df185cf013e915d2f63e52dbcf7c
7
+ data.tar.gz: 1728ce9b579f41f7a567e63d7256c82bb352840b67f16d88aac930a99e5abbf5a5f4061c5f9da16fb47d1664567e7956d276a8b2b44f13d2263032486afb53e8
data/CHANGELOG.md CHANGED
@@ -1,3 +1,9 @@
1
+ ## 0.1.5 (2019-12-06)
2
+
3
+ - Added many more functions
4
+ - Added tensor classes - `FloatTensor`, `LongTensor`, etc
5
+ - Improved modules
6
+
1
7
  ## 0.1.4 (2019-12-01)
2
8
 
3
9
  - Added distance functions
data/README.md CHANGED
@@ -20,6 +20,8 @@ Add this line to your application’s Gemfile:
20
20
  gem 'torch-rb'
21
21
  ```
22
22
 
23
+ It can take a few minutes to compile the extension.
24
+
23
25
  ## Getting Started
24
26
 
25
27
  This library follows the [PyTorch API](https://pytorch.org/docs/stable/torch.html). There are a few changes to make it more Ruby-like:
@@ -28,7 +30,7 @@ This library follows the [PyTorch API](https://pytorch.org/docs/stable/torch.htm
28
30
  - Methods that return booleans use `?` instead of `is_` (`tensor?` instead of `is_tensor`)
29
31
  - Numo is used instead of NumPy (`x.numo` instead of `x.numpy()`)
30
32
 
31
- Many methods and options are missing at the moment. PRs welcome!
33
+ Some methods and options are missing at the moment. PRs welcome!
32
34
 
33
35
  ## Tutorial
34
36
 
@@ -365,9 +367,9 @@ Here’s a list of functions to create tensors (descriptions from the [C++ docs]
365
367
 
366
368
  Here are a few full examples:
367
369
 
368
- - [Image classification with MNIST](examples/mnist)
370
+ - [Image classification with MNIST](examples/mnist) ([日本語版](https://qiita.com/kojix2/items/c19c36dc1bf73ea93409))
369
371
  - [Collaborative filtering with MovieLens](examples/movielens)
370
- - [Word embeddings](examples/nlp) [master]
372
+ - [Word embeddings](examples/nlp)
371
373
 
372
374
  ## LibTorch Installation
373
375
 
data/ext/torch/ext.cpp CHANGED
@@ -6,230 +6,15 @@
6
6
  #include <rice/Class.hpp>
7
7
  #include <rice/Constructor.hpp>
8
8
 
9
- using namespace Rice;
10
-
11
- template<>
12
- inline
13
- long long from_ruby<long long>(Object x)
14
- {
15
- return NUM2LL(x);
16
- }
17
-
18
- template<>
19
- inline
20
- Object to_ruby<long long>(long long const & x)
21
- {
22
- return LL2NUM(x);
23
- }
24
-
25
- template<>
26
- inline
27
- unsigned long long from_ruby<unsigned long long>(Object x)
28
- {
29
- return NUM2ULL(x);
30
- }
31
-
32
- template<>
33
- inline
34
- Object to_ruby<unsigned long long>(unsigned long long const & x)
35
- {
36
- return ULL2NUM(x);
37
- }
38
-
39
- template<>
40
- inline
41
- short from_ruby<short>(Object x)
42
- {
43
- return NUM2SHORT(x);
44
- }
45
-
46
- template<>
47
- inline
48
- Object to_ruby<short>(short const & x)
49
- {
50
- return INT2NUM(x);
51
- }
52
-
53
- template<>
54
- inline
55
- unsigned short from_ruby<unsigned short>(Object x)
56
- {
57
- return NUM2USHORT(x);
58
- }
59
-
60
- template<>
61
- inline
62
- Object to_ruby<unsigned short>(unsigned short const & x)
63
- {
64
- return UINT2NUM(x);
65
- }
9
+ #include "templates.hpp"
66
10
 
67
- // need to wrap torch::IntArrayRef() since
68
- // it doesn't own underlying data
69
- class IntArrayRef {
70
- std::vector<int64_t> vec;
71
- public:
72
- IntArrayRef(Object o) {
73
- Array a = Array(o);
74
- for (size_t i = 0; i < a.size(); i++) {
75
- vec.push_back(from_ruby<int64_t>(a[i]));
76
- }
77
- }
78
- operator torch::IntArrayRef() {
79
- return torch::IntArrayRef(vec);
80
- }
81
- };
11
+ // generated with:
12
+ // rake generate:functions
13
+ #include "torch_functions.hpp"
14
+ #include "tensor_functions.hpp"
15
+ #include "nn_functions.hpp"
82
16
 
83
- template<>
84
- inline
85
- IntArrayRef from_ruby<IntArrayRef>(Object x)
86
- {
87
- return IntArrayRef(x);
88
- }
89
-
90
- // for now
91
- class Scalar {
92
- torch::Scalar value;
93
- public:
94
- Scalar(Object o) {
95
- // TODO cast based on Ruby type
96
- if (o.rb_type() == T_FIXNUM) {
97
- value = torch::Scalar(from_ruby<int64_t>(o));
98
- } else {
99
- value = torch::Scalar(from_ruby<float>(o));
100
- }
101
- }
102
- operator torch::Scalar() {
103
- return value;
104
- }
105
- };
106
-
107
- template<>
108
- inline
109
- Scalar from_ruby<Scalar>(Object x)
110
- {
111
- return Scalar(x);
112
- }
113
-
114
- class TensorList {
115
- std::vector<torch::Tensor> vec;
116
- public:
117
- TensorList(Object o) {
118
- Array a = Array(o);
119
- for (size_t i = 0; i < a.size(); i++) {
120
- vec.push_back(from_ruby<torch::Tensor>(a[i]));
121
- }
122
- }
123
- operator torch::TensorList() {
124
- return torch::TensorList(vec);
125
- }
126
- };
127
-
128
- template<>
129
- inline
130
- TensorList from_ruby<TensorList>(Object x)
131
- {
132
- return TensorList(x);
133
- }
134
-
135
- class FanModeType {
136
- std::string s;
137
- public:
138
- FanModeType(Object o) {
139
- s = String(o).str();
140
- }
141
- // TODO switch NonlinearityType after LibTorch 1.4 release
142
- operator torch::nn::init::FanMode() {
143
- if (s == "fan_in") {
144
- return torch::nn::init::FanMode::FanIn;
145
- } else if (s == "fan_out") {
146
- return torch::nn::init::FanMode::FanOut;
147
- } else {
148
- throw std::runtime_error("Unsupported nonlinearity type: " + s);
149
- }
150
- }
151
- };
152
-
153
- template<>
154
- inline
155
- FanModeType from_ruby<FanModeType>(Object x)
156
- {
157
- return FanModeType(x);
158
- }
159
-
160
- class NonlinearityType {
161
- std::string s;
162
- public:
163
- NonlinearityType(Object o) {
164
- s = String(o).str();
165
- }
166
- // TODO switch NonlinearityType after LibTorch 1.4 release
167
- operator torch::nn::init::Nonlinearity() {
168
- if (s == "linear") {
169
- return torch::nn::init::Nonlinearity::Linear;
170
- } else if (s == "conv1d") {
171
- return torch::nn::init::Nonlinearity::Conv1D;
172
- } else if (s == "conv2d") {
173
- return torch::nn::init::Nonlinearity::Conv2D;
174
- } else if (s == "conv3d") {
175
- return torch::nn::init::Nonlinearity::Conv3D;
176
- } else if (s == "conv_transpose1d") {
177
- return torch::nn::init::Nonlinearity::ConvTranspose1D;
178
- } else if (s == "conv_transpose2d") {
179
- return torch::nn::init::Nonlinearity::ConvTranspose2D;
180
- } else if (s == "conv_transpose3d") {
181
- return torch::nn::init::Nonlinearity::ConvTranspose3D;
182
- } else if (s == "sigmoid") {
183
- return torch::nn::init::Nonlinearity::Sigmoid;
184
- } else if (s == "tanh") {
185
- return torch::nn::init::Nonlinearity::Tanh;
186
- } else if (s == "relu") {
187
- return torch::nn::init::Nonlinearity::ReLU;
188
- } else if (s == "leaky_relu") {
189
- return torch::nn::init::Nonlinearity::LeakyReLU;
190
- } else {
191
- throw std::runtime_error("Unsupported nonlinearity type: " + s);
192
- }
193
- }
194
- };
195
-
196
- template<>
197
- inline
198
- NonlinearityType from_ruby<NonlinearityType>(Object x)
199
- {
200
- return NonlinearityType(x);
201
- }
202
-
203
- class MyReduction {
204
- Object value;
205
- public:
206
- MyReduction(Object o) {
207
- value = o;
208
- }
209
- operator int64_t() {
210
- if (value.is_nil()) {
211
- return Reduction::None;
212
- }
213
-
214
- std::string s = String(value).str();
215
- if (s == "mean") {
216
- return Reduction::Mean;
217
- } else if (s == "sum") {
218
- return Reduction::Sum;
219
- } else {
220
- throw std::runtime_error("Unsupported reduction: " + s);
221
- }
222
- }
223
- };
224
-
225
- template<>
226
- inline
227
- MyReduction from_ruby<MyReduction>(Object x)
228
- {
229
- return MyReduction(x);
230
- }
231
-
232
- typedef torch::Tensor Tensor;
17
+ using namespace Rice;
233
18
 
234
19
  Object tensor_array(std::tuple<torch::Tensor, torch::Tensor> x) {
235
20
  Array a;
@@ -241,8 +26,16 @@ Object tensor_array(std::tuple<torch::Tensor, torch::Tensor> x) {
241
26
  extern "C"
242
27
  void Init_ext()
243
28
  {
244
- Module rb_mTorch = define_module("Torch")
245
- .define_singleton_method(
29
+ Module rb_mTorch = define_module("Torch");
30
+ add_torch_functions(rb_mTorch);
31
+
32
+ Class rb_cTensor = define_class_under<torch::Tensor>(rb_mTorch, "Tensor");
33
+ add_tensor_functions(rb_cTensor);
34
+
35
+ Module rb_mNN = define_module_under(rb_mTorch, "NN");
36
+ add_nn_functions(rb_mNN);
37
+
38
+ rb_mTorch.define_singleton_method(
246
39
  "grad_enabled?",
247
40
  *[]() {
248
41
  return torch::GradMode::is_enabled();
@@ -252,11 +45,6 @@ void Init_ext()
252
45
  *[](bool enabled) {
253
46
  torch::GradMode::set_enabled(enabled);
254
47
  })
255
- .define_singleton_method(
256
- "floating_point?",
257
- *[](Tensor& input) {
258
- return torch::is_floating_point(input);
259
- })
260
48
  .define_singleton_method(
261
49
  "manual_seed",
262
50
  *[](uint64_t seed) {
@@ -344,168 +132,16 @@ void Init_ext()
344
132
  *[](Tensor& input, int64_t dim, bool keepdim) {
345
133
  return torch::sum(input, dim, keepdim);
346
134
  })
347
- .define_singleton_method(
348
- "_argmax",
349
- *[](Tensor& input) {
350
- return torch::argmax(input);
351
- })
352
- .define_singleton_method(
353
- "_argmax_dim",
354
- *[](Tensor& input, int64_t dim, bool keepdim) {
355
- return torch::argmax(input, dim, keepdim);
356
- })
357
- .define_singleton_method(
358
- "_cat",
359
- *[](TensorList tensors, int64_t dim) {
360
- return torch::cat(tensors, dim);
361
- })
362
- .define_singleton_method(
363
- "_norm",
364
- *[](Tensor& input) {
365
- return torch::norm(input);
366
- })
367
- .define_singleton_method(
368
- "_min",
369
- *[](Tensor& input) {
370
- return torch::min(input);
371
- })
372
- .define_singleton_method(
373
- "_max",
374
- *[](Tensor& input) {
375
- return torch::max(input);
376
- })
377
135
  .define_singleton_method(
378
136
  "_max_out",
379
137
  *[](Tensor &max, Tensor &max_indices, const Tensor &input, int64_t dim, bool keepdim) {
380
138
  return tensor_array(torch::_max_out(max, max_indices, input, dim, keepdim));
381
139
  })
382
- .define_singleton_method(
383
- "_sqrt",
384
- *[](Tensor& input) {
385
- return torch::sqrt(input);
386
- })
387
- .define_singleton_method(
388
- "_exp",
389
- *[](Tensor& input) {
390
- return torch::exp(input);
391
- })
392
- .define_singleton_method(
393
- "_log",
394
- *[](Tensor& input) {
395
- return torch::log(input);
396
- })
397
- .define_singleton_method(
398
- "_sign",
399
- *[](Tensor& input) {
400
- return torch::sign(input);
401
- })
402
- .define_singleton_method(
403
- "_unsqueeze",
404
- *[](Tensor& input, int64_t dim) {
405
- return torch::unsqueeze(input, dim);
406
- })
407
- .define_singleton_method(
408
- "_dot",
409
- *[](Tensor& input, Tensor& tensor) {
410
- return torch::dot(input, tensor);
411
- })
412
- .define_singleton_method(
413
- "_matmul",
414
- *[](Tensor& input, Tensor& other) {
415
- return torch::matmul(input, other);
416
- })
417
- .define_singleton_method(
418
- "_eq",
419
- *[](Tensor& input, Tensor& other) {
420
- return torch::eq(input, other);
421
- })
422
- .define_singleton_method(
423
- "_gt",
424
- // TODO support tensors
425
- *[](Tensor& input, Scalar other) {
426
- return torch::gt(input, other);
427
- })
428
- .define_singleton_method(
429
- "_lt",
430
- // TODO support tensors
431
- *[](Tensor& input, Scalar other) {
432
- return torch::lt(input, other);
433
- })
434
- .define_singleton_method(
435
- "_add",
436
- *[](Tensor& input, Tensor& other) {
437
- return torch::add(input, other);
438
- })
439
- .define_singleton_method(
440
- "_add_scalar",
441
- *[](Tensor& input, Scalar other) {
442
- return torch::add(input, other);
443
- })
444
- .define_singleton_method(
445
- "_add_out",
446
- *[](Tensor& out, Tensor& input, Tensor& other) {
447
- return torch::add_out(out, input, other);
448
- })
449
- .define_singleton_method(
450
- "_sub",
451
- *[](Tensor& input, Tensor& other) {
452
- return torch::sub(input, other);
453
- })
454
- .define_singleton_method(
455
- "_sub_scalar",
456
- *[](Tensor& input, Scalar other) {
457
- return torch::sub(input, other);
458
- })
459
- .define_singleton_method(
460
- "_mul",
461
- *[](Tensor& input, Tensor& other) {
462
- return torch::mul(input, other);
463
- })
464
- .define_singleton_method(
465
- "_mul_scalar",
466
- *[](Tensor& input, Scalar other) {
467
- return torch::mul(input, other);
468
- })
469
- .define_singleton_method(
470
- "_div",
471
- *[](Tensor& input, Tensor& other) {
472
- return torch::div(input, other);
473
- })
474
- .define_singleton_method(
475
- "_div_scalar",
476
- *[](Tensor& input, Scalar other) {
477
- return torch::div(input, other);
478
- })
479
- .define_singleton_method(
480
- "_remainder",
481
- *[](Tensor& input, Tensor& other) {
482
- return torch::remainder(input, other);
483
- })
484
- .define_singleton_method(
485
- "_remainder_scalar",
486
- *[](Tensor& input, Scalar other) {
487
- return torch::remainder(input, other);
488
- })
489
- .define_singleton_method(
490
- "_pow",
491
- *[](Tensor& input, Scalar exponent) {
492
- return torch::pow(input, exponent);
493
- })
494
140
  .define_singleton_method(
495
141
  "_topk",
496
142
  *[](Tensor& input, int64_t k) {
497
143
  return tensor_array(torch::topk(input, k));
498
144
  })
499
- .define_singleton_method(
500
- "_sigmoid",
501
- *[](Tensor& input) {
502
- return torch::sigmoid(input);
503
- })
504
- .define_singleton_method(
505
- "_softplus",
506
- *[](const Tensor &input, Scalar beta, Scalar threshold) {
507
- return torch::softplus(input, beta, threshold);
508
- })
509
145
  .define_singleton_method(
510
146
  "_softmax",
511
147
  *[](const Tensor &input, int64_t dim) {
@@ -516,26 +152,6 @@ void Init_ext()
516
152
  *[](Tensor& input, int64_t dim) {
517
153
  return torch::log_softmax(input, dim);
518
154
  })
519
- .define_singleton_method(
520
- "_abs",
521
- *[](Tensor& input) {
522
- return torch::abs(input);
523
- })
524
- .define_singleton_method(
525
- "_neg",
526
- *[](Tensor& input) {
527
- return torch::neg(input);
528
- })
529
- .define_singleton_method(
530
- "_reshape",
531
- *[](Tensor& input, IntArrayRef shape) {
532
- return torch::reshape(input, shape);
533
- })
534
- .define_singleton_method(
535
- "_flatten",
536
- *[](Tensor& input, int64_t start_dim, int64_t end_dim) {
537
- return torch::flatten(input, start_dim, end_dim);
538
- })
539
155
  .define_singleton_method(
540
156
  "relu",
541
157
  *[](Tensor& input) {
@@ -579,104 +195,9 @@ void Init_ext()
579
195
  return torch::avg_pool2d(input, kernel_size);
580
196
  })
581
197
  .define_singleton_method(
582
- "_dropout",
583
- *[](Tensor& input, float p, bool train) {
584
- return torch::dropout(input, p, train);
585
- })
586
- .define_singleton_method(
587
- "_dropout!",
588
- *[](Tensor& input, float p, bool train) {
589
- return torch::dropout_(input, p, train);
590
- })
591
- .define_singleton_method(
592
- "_feature_dropout",
593
- *[](Tensor& input, float p, bool train) {
594
- return torch::feature_dropout(input, p, train);
595
- })
596
- .define_singleton_method(
597
- "_feature_dropout!",
598
- *[](Tensor& input, float p, bool train) {
599
- return torch::feature_dropout_(input, p, train);
600
- })
601
- .define_singleton_method(
602
- "_alpha_dropout",
603
- *[](Tensor& input, float p, bool train) {
604
- return torch::alpha_dropout(input, p, train);
605
- })
606
- .define_singleton_method(
607
- "_alpha_dropout!",
608
- *[](Tensor& input, float p, bool train) {
609
- return torch::alpha_dropout_(input, p, train);
610
- })
611
- .define_singleton_method(
612
- "_feature_alpha_dropout",
613
- *[](Tensor& input, float p, bool train) {
614
- return torch::feature_alpha_dropout(input, p, train);
615
- })
616
- .define_singleton_method(
617
- "_feature_alpha_dropout!",
618
- *[](Tensor& input, float p, bool train) {
619
- return torch::feature_alpha_dropout_(input, p, train);
620
- })
621
- // sparse layers
622
- .define_singleton_method(
623
- "_embedding",
624
- // weight and indices are swapped from Python interface
625
- *[](const Tensor &indices, const Tensor &weight, int64_t padding_idx, bool scale_grad_by_freq, bool sparse) {
626
- return torch::embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse);
627
- })
628
- .define_singleton_method(
629
- "_embedding_bag",
630
- // weight and indices are swapped from Python interface
631
- *[](const Tensor &weight, const Tensor &indices, const Tensor &offsets, bool scale_grad_by_freq, int64_t mode, bool sparse, const Tensor &per_sample_weights) {
632
- return torch::embedding_bag(weight, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights);
633
- })
634
- // distance functions
635
- .define_singleton_method(
636
- "_cosine_similarity",
637
- *[](const Tensor &x1, const Tensor &x2, int64_t dim, double eps) {
638
- return torch::cosine_similarity(x1, x2, dim, eps);
639
- })
640
- .define_singleton_method(
641
- "_pairwise_distance",
642
- *[](const Tensor &x1, const Tensor &x2, double p, double eps, bool keepdim) {
643
- return torch::pairwise_distance(x1, x2, p, eps, keepdim);
644
- })
645
- // loss functions
646
- .define_singleton_method(
647
- "binary_cross_entropy",
648
- *[](Tensor& input, Tensor& target, MyReduction reduction) {
649
- return torch::binary_cross_entropy(input, target, {}, reduction);
650
- })
651
- .define_singleton_method(
652
- "ctc_loss",
653
- *[](const Tensor &log_probs, const Tensor &targets, IntArrayRef input_lengths, IntArrayRef target_lengths, int64_t blank, MyReduction reduction, bool zero_infinity) {
654
- return torch::ctc_loss(log_probs, targets, input_lengths, target_lengths, blank, reduction, zero_infinity);
655
- })
656
- .define_singleton_method(
657
- "kl_div",
658
- *[](Tensor& input, Tensor& target, MyReduction reduction) {
659
- return torch::kl_div(input, target, reduction);
660
- })
661
- .define_singleton_method(
662
- "l1_loss",
663
- *[](Tensor& input, Tensor& target, MyReduction reduction) {
664
- return torch::l1_loss(input, target, reduction);
665
- })
666
- .define_singleton_method(
667
- "mse_loss",
668
- *[](Tensor& input, Tensor& target, MyReduction reduction) {
669
- return torch::mse_loss(input, target, reduction);
670
- })
671
- .define_singleton_method(
672
- "nll_loss",
673
- *[](Tensor& input, Tensor& target, MyReduction reduction, int64_t ignore_index) {
674
- return torch::nll_loss(input, target, {}, reduction, ignore_index);
675
- })
676
- .define_singleton_method(
677
- "poisson_nll_loss",
678
- *[](const Tensor &input, const Tensor &target, bool log_input, bool full, double eps, MyReduction reduction) {
679
- return torch::poisson_nll_loss(input, target, log_input, full, eps, reduction);
198
+ "_binary_cross_entropy_with_logits",
199
+ *[](const Tensor &input, const Tensor &target, OptionalTensor weight, OptionalTensor pos_weight, MyReduction reduction) {
200
+ return torch::binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction);
680
201
  })
681
202
  .define_singleton_method("numel", &torch::numel)
682
203
  .define_singleton_method(
@@ -703,7 +224,7 @@ void Init_ext()
703
224
  return t.reshape(size);
704
225
  });
705
226
 
706
- Class rb_cTensor = define_class_under<torch::Tensor>(rb_mTorch, "Tensor")
227
+ rb_cTensor
707
228
  .define_method("cuda?", &torch::Tensor::is_cuda)
708
229
  .define_method("distributed?", &torch::Tensor::is_distributed)
709
230
  .define_method("complex?", &torch::Tensor::is_complex)
@@ -740,16 +261,6 @@ void Init_ext()
740
261
  *[](Tensor& self) {
741
262
  return self.detach_();
742
263
  })
743
- .define_method(
744
- "_select",
745
- *[](Tensor& self, int64_t dim, int64_t index) {
746
- return self.select(dim, index);
747
- })
748
- .define_method(
749
- "_slice",
750
- *[](Tensor& self, int64_t dim, int64_t start, int64_t end, int64_t step) {
751
- return self.slice(dim, start, end, step);
752
- })
753
264
  .define_method(
754
265
  "_requires_grad!",
755
266
  *[](Tensor& self, bool requires_grad) {
@@ -789,11 +300,6 @@ void Init_ext()
789
300
  s << self.device();
790
301
  return s.str();
791
302
  })
792
- .define_method(
793
- "_view",
794
- *[](Tensor& self, IntArrayRef size) {
795
- return self.view(size);
796
- })
797
303
  .define_method(
798
304
  "resize_as!",
799
305
  *[](Tensor& self, Tensor& other) {
@@ -809,21 +315,6 @@ void Init_ext()
809
315
  *[](Tensor& self) {
810
316
  return self.relu_();
811
317
  })
812
- .define_method(
813
- "_add!",
814
- *[](Tensor& self, Tensor& other) {
815
- return self.add_(other);
816
- })
817
- .define_method(
818
- "_add_alpha!",
819
- *[](Tensor& self, Tensor& other, Scalar alpha) {
820
- return self.add_(other, alpha);
821
- })
822
- .define_method(
823
- "_add_scalar!",
824
- *[](Tensor& self, Scalar other) {
825
- return self.add_(other);
826
- })
827
318
  .define_method(
828
319
  "normal!",
829
320
  *[](Tensor& self, double mean, double std) {
@@ -839,16 +330,6 @@ void Init_ext()
839
330
  *[](Tensor& self, Tensor& other) {
840
331
  return self.sub_(other);
841
332
  })
842
- .define_method(
843
- "_mul!",
844
- *[](Tensor& self, Tensor& other) {
845
- return self.mul_(other);
846
- })
847
- .define_method(
848
- "_mul_scalar!",
849
- *[](Tensor& self, Scalar other) {
850
- return self.mul_(other);
851
- })
852
333
  .define_method(
853
334
  "div!",
854
335
  *[](Tensor& self, Tensor& other) {
@@ -880,7 +361,7 @@ void Init_ext()
880
361
  return self.data();
881
362
  })
882
363
  .define_method(
883
- "_data",
364
+ "_flat_data",
884
365
  *[](Tensor& self) {
885
366
  Array a;
886
367
  auto dtype = self.dtype();
@@ -931,11 +412,6 @@ void Init_ext()
931
412
  }
932
413
  return a;
933
414
  })
934
- .define_method(
935
- "_size",
936
- *[](Tensor& self, int i) {
937
- return self.size(i);
938
- })
939
415
  .define_method(
940
416
  "_to",
941
417
  *[](Tensor& self, torch::Device device, int dtype, bool non_blocking, bool copy) {
@@ -990,8 +466,6 @@ void Init_ext()
990
466
  return self.requires_grad(requires_grad);
991
467
  });
992
468
 
993
- Module rb_mNN = define_module_under(rb_mTorch, "NN");
994
-
995
469
  Module rb_mInit = define_module_under(rb_mNN, "Init")
996
470
  .define_singleton_method(
997
471
  "_calculate_gain",