torch-rb 0.1.4 → 0.1.5

Files changed (39)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +6 -0
  3. data/README.md +5 -3
  4. data/ext/torch/ext.cpp +22 -548
  5. data/ext/torch/extconf.rb +6 -0
  6. data/ext/torch/nn_functions.cpp +595 -0
  7. data/ext/torch/nn_functions.hpp +6 -0
  8. data/ext/torch/templates.hpp +250 -0
  9. data/ext/torch/tensor_functions.cpp +1860 -0
  10. data/ext/torch/tensor_functions.hpp +6 -0
  11. data/ext/torch/torch_functions.cpp +2875 -0
  12. data/ext/torch/torch_functions.hpp +6 -0
  13. data/lib/torch.rb +68 -129
  14. data/lib/torch/ext.bundle +0 -0
  15. data/lib/torch/native/dispatcher.rb +48 -0
  16. data/lib/torch/native/function.rb +78 -0
  17. data/lib/torch/native/generator.rb +149 -0
  18. data/lib/torch/native/native_functions.yaml +6837 -0
  19. data/lib/torch/native/parser.rb +97 -0
  20. data/lib/torch/nn/bce_with_logits_loss.rb +15 -0
  21. data/lib/torch/nn/conv2d.rb +0 -2
  22. data/lib/torch/nn/cosine_embedding_loss.rb +14 -0
  23. data/lib/torch/nn/functional.rb +55 -16
  24. data/lib/torch/nn/hinge_embedding_loss.rb +14 -0
  25. data/lib/torch/nn/identity.rb +1 -0
  26. data/lib/torch/nn/margin_ranking_loss.rb +14 -0
  27. data/lib/torch/nn/module.rb +59 -12
  28. data/lib/torch/nn/multi_label_margin_loss.rb +13 -0
  29. data/lib/torch/nn/multi_label_soft_margin_loss.rb +13 -0
  30. data/lib/torch/nn/multi_margin_loss.rb +17 -0
  31. data/lib/torch/nn/parameter.rb +4 -0
  32. data/lib/torch/nn/rnn.rb +22 -0
  33. data/lib/torch/nn/rnn_base.rb +154 -0
  34. data/lib/torch/nn/smooth_l1_loss.rb +13 -0
  35. data/lib/torch/nn/soft_margin_loss.rb +13 -0
  36. data/lib/torch/nn/triplet_margin_loss.rb +18 -0
  37. data/lib/torch/tensor.rb +19 -19
  38. data/lib/torch/version.rb +1 -1
  39. metadata +26 -2
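The headline change is the binding layer: the hand-written definitions in data/ext/torch/ext.cpp (548 lines removed, 22 added) are replaced by C++ generated from data/lib/torch/native/native_functions.yaml via `rake generate:functions`, and the new parser.rb/dispatcher.rb files route each Ruby call to the matching generated overload. A rough sketch of what that buys callers (illustrative only; the exact `Torch.sum` overloads shown here are an assumption based on the bindings visible in the diff below):

```ruby
require "torch"

x = Torch.rand(2, 3)

# One public method now fronts several generated C++ bindings;
# the dispatcher picks an overload from the arguments it receives.
Torch.sum(x)    # reduce over all elements
Torch.sum(x, 1) # reduce over dim 1
x.sum           # the same function, also exposed on Torch::Tensor
```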
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
- metadata.gz: 4faccffe2d2fd29519ad9dcce0560978a07c734831b5f64bb4624a0037f2b08c
- data.tar.gz: 4a8f873a9bb99c2311c856c59e5c43a5dfadd3f4f2460da1370ca1db888b79ad
+ metadata.gz: 6b47306ed525e1a20d25cb8324d4658f750c18afa5704c9b7bafc215d8f568c1
+ data.tar.gz: dad6ddf955b111989b061e5af146006a32c83dc1ea1ca5005a6b6e34bc9a4892
  SHA512:
- metadata.gz: 199b3b47325b72b38786f50c39f0cfb9b11709f02edf4b77e1c4e9198baf5fa2b3924d639d6c0f7e1715193528d4cfab4f53fbba7e2b16e06ac462d37862cf3a
- data.tar.gz: 4514f7aab60d9beabee47db175c9df6e3e1e93080b47eaeacff0f9dd4e8e737b476c8400f84bda4ccdf61aca4f7c81e010bcfbc2fa67f8c629c33cd6dcdcb54c
+ metadata.gz: 5d26e3642bf7cd921b9b570052df353d4c32b1bd955a6fbbf5f30249631fa4c0d4624f4fa91a1c06f61b3b0d6461cd117ab4df185cf013e915d2f63e52dbcf7c
+ data.tar.gz: 1728ce9b579f41f7a567e63d7256c82bb352840b67f16d88aac930a99e5abbf5a5f4061c5f9da16fb47d1664567e7956d276a8b2b44f13d2263032486afb53e8
data/CHANGELOG.md CHANGED
@@ -1,3 +1,9 @@
+ ## 0.1.5 (2019-12-06)
+
+ - Added many more functions
+ - Added tensor classes - `FloatTensor`, `LongTensor`, etc
+ - Improved modules
+
  ## 0.1.4 (2019-12-01)

  - Added distance functions
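The tensor classes called out above pin the dtype that `Torch.tensor` would otherwise infer from its input. A minimal sketch, assuming they mirror PyTorch's `torch.FloatTensor`/`torch.LongTensor` constructors:

```ruby
require "torch"

Torch.tensor([1, 2, 3]).dtype           # => :int64 (inferred from the values)
Torch::FloatTensor.new([1, 2, 3]).dtype # => :float32 (fixed by the class)
Torch::LongTensor.new([1, 2, 3]).dtype  # => :int64
```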
data/README.md CHANGED
@@ -20,6 +20,8 @@ Add this line to your application’s Gemfile:
  gem 'torch-rb'
  ```

+ It can take a few minutes to compile the extension.
+
  ## Getting Started

  This library follows the [PyTorch API](https://pytorch.org/docs/stable/torch.html). There are a few changes to make it more Ruby-like:
@@ -28,7 +30,7 @@ This library follows the [PyTorch API](https://pytorch.org/docs/stable/torch.htm
  - Methods that return booleans use `?` instead of `is_` (`tensor?` instead of `is_tensor`)
  - Numo is used instead of NumPy (`x.numo` instead of `x.numpy()`)

- Many methods and options are missing at the moment. PRs welcome!
+ Some methods and options are missing at the moment. PRs welcome!

  ## Tutorial

@@ -365,9 +367,9 @@ Here’s a list of functions to create tensors (descriptions from the [C++ docs]

  Here are a few full examples:

- - [Image classification with MNIST](examples/mnist)
+ - [Image classification with MNIST](examples/mnist) ([日本語版](https://qiita.com/kojix2/items/c19c36dc1bf73ea93409))
  - [Collaborative filtering with MovieLens](examples/movielens)
- - [Word embeddings](examples/nlp) [master]
+ - [Word embeddings](examples/nlp)

  ## LibTorch Installation

data/ext/torch/ext.cpp CHANGED
@@ -6,230 +6,15 @@
  #include <rice/Class.hpp>
  #include <rice/Constructor.hpp>

- using namespace Rice;
-
- template<>
- inline
- long long from_ruby<long long>(Object x)
- {
-   return NUM2LL(x);
- }
-
- template<>
- inline
- Object to_ruby<long long>(long long const & x)
- {
-   return LL2NUM(x);
- }
-
- template<>
- inline
- unsigned long long from_ruby<unsigned long long>(Object x)
- {
-   return NUM2ULL(x);
- }
-
- template<>
- inline
- Object to_ruby<unsigned long long>(unsigned long long const & x)
- {
-   return ULL2NUM(x);
- }
-
- template<>
- inline
- short from_ruby<short>(Object x)
- {
-   return NUM2SHORT(x);
- }
-
- template<>
- inline
- Object to_ruby<short>(short const & x)
- {
-   return INT2NUM(x);
- }
-
- template<>
- inline
- unsigned short from_ruby<unsigned short>(Object x)
- {
-   return NUM2USHORT(x);
- }
-
- template<>
- inline
- Object to_ruby<unsigned short>(unsigned short const & x)
- {
-   return UINT2NUM(x);
- }
+ #include "templates.hpp"

- // need to wrap torch::IntArrayRef() since
- // it doesn't own underlying data
- class IntArrayRef {
-   std::vector<int64_t> vec;
-   public:
-     IntArrayRef(Object o) {
-       Array a = Array(o);
-       for (size_t i = 0; i < a.size(); i++) {
-         vec.push_back(from_ruby<int64_t>(a[i]));
-       }
-     }
-     operator torch::IntArrayRef() {
-       return torch::IntArrayRef(vec);
-     }
- };
+ // generated with:
+ // rake generate:functions
+ #include "torch_functions.hpp"
+ #include "tensor_functions.hpp"
+ #include "nn_functions.hpp"

- template<>
- inline
- IntArrayRef from_ruby<IntArrayRef>(Object x)
- {
-   return IntArrayRef(x);
- }
-
- // for now
- class Scalar {
-   torch::Scalar value;
-   public:
-     Scalar(Object o) {
-       // TODO cast based on Ruby type
-       if (o.rb_type() == T_FIXNUM) {
-         value = torch::Scalar(from_ruby<int64_t>(o));
-       } else {
-         value = torch::Scalar(from_ruby<float>(o));
-       }
-     }
-     operator torch::Scalar() {
-       return value;
-     }
- };
-
- template<>
- inline
- Scalar from_ruby<Scalar>(Object x)
- {
-   return Scalar(x);
- }
-
- class TensorList {
-   std::vector<torch::Tensor> vec;
-   public:
-     TensorList(Object o) {
-       Array a = Array(o);
-       for (size_t i = 0; i < a.size(); i++) {
-         vec.push_back(from_ruby<torch::Tensor>(a[i]));
-       }
-     }
-     operator torch::TensorList() {
-       return torch::TensorList(vec);
-     }
- };
-
- template<>
- inline
- TensorList from_ruby<TensorList>(Object x)
- {
-   return TensorList(x);
- }
-
- class FanModeType {
-   std::string s;
-   public:
-     FanModeType(Object o) {
-       s = String(o).str();
-     }
-     // TODO switch NonlinearityType after LibTorch 1.4 release
-     operator torch::nn::init::FanMode() {
-       if (s == "fan_in") {
-         return torch::nn::init::FanMode::FanIn;
-       } else if (s == "fan_out") {
-         return torch::nn::init::FanMode::FanOut;
-       } else {
-         throw std::runtime_error("Unsupported nonlinearity type: " + s);
-       }
-     }
- };
-
- template<>
- inline
- FanModeType from_ruby<FanModeType>(Object x)
- {
-   return FanModeType(x);
- }
-
- class NonlinearityType {
-   std::string s;
-   public:
-     NonlinearityType(Object o) {
-       s = String(o).str();
-     }
-     // TODO switch NonlinearityType after LibTorch 1.4 release
-     operator torch::nn::init::Nonlinearity() {
-       if (s == "linear") {
-         return torch::nn::init::Nonlinearity::Linear;
-       } else if (s == "conv1d") {
-         return torch::nn::init::Nonlinearity::Conv1D;
-       } else if (s == "conv2d") {
-         return torch::nn::init::Nonlinearity::Conv2D;
-       } else if (s == "conv3d") {
-         return torch::nn::init::Nonlinearity::Conv3D;
-       } else if (s == "conv_transpose1d") {
-         return torch::nn::init::Nonlinearity::ConvTranspose1D;
-       } else if (s == "conv_transpose2d") {
-         return torch::nn::init::Nonlinearity::ConvTranspose2D;
-       } else if (s == "conv_transpose3d") {
-         return torch::nn::init::Nonlinearity::ConvTranspose3D;
-       } else if (s == "sigmoid") {
-         return torch::nn::init::Nonlinearity::Sigmoid;
-       } else if (s == "tanh") {
-         return torch::nn::init::Nonlinearity::Tanh;
-       } else if (s == "relu") {
-         return torch::nn::init::Nonlinearity::ReLU;
-       } else if (s == "leaky_relu") {
-         return torch::nn::init::Nonlinearity::LeakyReLU;
-       } else {
-         throw std::runtime_error("Unsupported nonlinearity type: " + s);
-       }
-     }
- };
-
- template<>
- inline
- NonlinearityType from_ruby<NonlinearityType>(Object x)
- {
-   return NonlinearityType(x);
- }
-
- class MyReduction {
-   Object value;
-   public:
-     MyReduction(Object o) {
-       value = o;
-     }
-     operator int64_t() {
-       if (value.is_nil()) {
-         return Reduction::None;
-       }
-
-       std::string s = String(value).str();
-       if (s == "mean") {
-         return Reduction::Mean;
-       } else if (s == "sum") {
-         return Reduction::Sum;
-       } else {
-         throw std::runtime_error("Unsupported reduction: " + s);
-       }
-     }
- };
-
- template<>
- inline
- MyReduction from_ruby<MyReduction>(Object x)
- {
-   return MyReduction(x);
- }
-
- typedef torch::Tensor Tensor;
+ using namespace Rice;

  Object tensor_array(std::tuple<torch::Tensor, torch::Tensor> x) {
    Array a;
@@ -241,8 +26,16 @@ Object tensor_array(std::tuple<torch::Tensor, torch::Tensor> x) {
  extern "C"
  void Init_ext()
  {
-   Module rb_mTorch = define_module("Torch")
-     .define_singleton_method(
+   Module rb_mTorch = define_module("Torch");
+   add_torch_functions(rb_mTorch);
+
+   Class rb_cTensor = define_class_under<torch::Tensor>(rb_mTorch, "Tensor");
+   add_tensor_functions(rb_cTensor);
+
+   Module rb_mNN = define_module_under(rb_mTorch, "NN");
+   add_nn_functions(rb_mNN);
+
+   rb_mTorch.define_singleton_method(
      "grad_enabled?",
      *[]() {
        return torch::GradMode::is_enabled();
@@ -252,11 +45,6 @@ void Init_ext()
      *[](bool enabled) {
        torch::GradMode::set_enabled(enabled);
      })
-     .define_singleton_method(
-       "floating_point?",
-       *[](Tensor& input) {
-         return torch::is_floating_point(input);
-       })
      .define_singleton_method(
        "manual_seed",
        *[](uint64_t seed) {
@@ -344,168 +132,16 @@ void Init_ext()
      *[](Tensor& input, int64_t dim, bool keepdim) {
        return torch::sum(input, dim, keepdim);
      })
-     .define_singleton_method(
-       "_argmax",
-       *[](Tensor& input) {
-         return torch::argmax(input);
-       })
-     .define_singleton_method(
-       "_argmax_dim",
-       *[](Tensor& input, int64_t dim, bool keepdim) {
-         return torch::argmax(input, dim, keepdim);
-       })
-     .define_singleton_method(
-       "_cat",
-       *[](TensorList tensors, int64_t dim) {
-         return torch::cat(tensors, dim);
-       })
-     .define_singleton_method(
-       "_norm",
-       *[](Tensor& input) {
-         return torch::norm(input);
-       })
-     .define_singleton_method(
-       "_min",
-       *[](Tensor& input) {
-         return torch::min(input);
-       })
-     .define_singleton_method(
-       "_max",
-       *[](Tensor& input) {
-         return torch::max(input);
-       })
      .define_singleton_method(
        "_max_out",
        *[](Tensor &max, Tensor &max_indices, const Tensor &input, int64_t dim, bool keepdim) {
          return tensor_array(torch::_max_out(max, max_indices, input, dim, keepdim));
        })
-     .define_singleton_method(
-       "_sqrt",
-       *[](Tensor& input) {
-         return torch::sqrt(input);
-       })
-     .define_singleton_method(
-       "_exp",
-       *[](Tensor& input) {
-         return torch::exp(input);
-       })
-     .define_singleton_method(
-       "_log",
-       *[](Tensor& input) {
-         return torch::log(input);
-       })
-     .define_singleton_method(
-       "_sign",
-       *[](Tensor& input) {
-         return torch::sign(input);
-       })
-     .define_singleton_method(
-       "_unsqueeze",
-       *[](Tensor& input, int64_t dim) {
-         return torch::unsqueeze(input, dim);
-       })
-     .define_singleton_method(
-       "_dot",
-       *[](Tensor& input, Tensor& tensor) {
-         return torch::dot(input, tensor);
-       })
-     .define_singleton_method(
-       "_matmul",
-       *[](Tensor& input, Tensor& other) {
-         return torch::matmul(input, other);
-       })
-     .define_singleton_method(
-       "_eq",
-       *[](Tensor& input, Tensor& other) {
-         return torch::eq(input, other);
-       })
-     .define_singleton_method(
-       "_gt",
-       // TODO support tensors
-       *[](Tensor& input, Scalar other) {
-         return torch::gt(input, other);
-       })
-     .define_singleton_method(
-       "_lt",
-       // TODO support tensors
-       *[](Tensor& input, Scalar other) {
-         return torch::lt(input, other);
-       })
-     .define_singleton_method(
-       "_add",
-       *[](Tensor& input, Tensor& other) {
-         return torch::add(input, other);
-       })
-     .define_singleton_method(
-       "_add_scalar",
-       *[](Tensor& input, Scalar other) {
-         return torch::add(input, other);
-       })
-     .define_singleton_method(
-       "_add_out",
-       *[](Tensor& out, Tensor& input, Tensor& other) {
-         return torch::add_out(out, input, other);
-       })
-     .define_singleton_method(
-       "_sub",
-       *[](Tensor& input, Tensor& other) {
-         return torch::sub(input, other);
-       })
-     .define_singleton_method(
-       "_sub_scalar",
-       *[](Tensor& input, Scalar other) {
-         return torch::sub(input, other);
-       })
-     .define_singleton_method(
-       "_mul",
-       *[](Tensor& input, Tensor& other) {
-         return torch::mul(input, other);
-       })
-     .define_singleton_method(
-       "_mul_scalar",
-       *[](Tensor& input, Scalar other) {
-         return torch::mul(input, other);
-       })
-     .define_singleton_method(
-       "_div",
-       *[](Tensor& input, Tensor& other) {
-         return torch::div(input, other);
-       })
-     .define_singleton_method(
-       "_div_scalar",
-       *[](Tensor& input, Scalar other) {
-         return torch::div(input, other);
-       })
-     .define_singleton_method(
-       "_remainder",
-       *[](Tensor& input, Tensor& other) {
-         return torch::remainder(input, other);
-       })
-     .define_singleton_method(
-       "_remainder_scalar",
-       *[](Tensor& input, Scalar other) {
-         return torch::remainder(input, other);
-       })
-     .define_singleton_method(
-       "_pow",
-       *[](Tensor& input, Scalar exponent) {
-         return torch::pow(input, exponent);
-       })
      .define_singleton_method(
        "_topk",
        *[](Tensor& input, int64_t k) {
          return tensor_array(torch::topk(input, k));
        })
-     .define_singleton_method(
-       "_sigmoid",
-       *[](Tensor& input) {
-         return torch::sigmoid(input);
-       })
-     .define_singleton_method(
-       "_softplus",
-       *[](const Tensor &input, Scalar beta, Scalar threshold) {
-         return torch::softplus(input, beta, threshold);
-       })
      .define_singleton_method(
        "_softmax",
        *[](const Tensor &input, int64_t dim) {
@@ -516,26 +152,6 @@ void Init_ext()
      *[](Tensor& input, int64_t dim) {
        return torch::log_softmax(input, dim);
      })
-     .define_singleton_method(
-       "_abs",
-       *[](Tensor& input) {
-         return torch::abs(input);
-       })
-     .define_singleton_method(
-       "_neg",
-       *[](Tensor& input) {
-         return torch::neg(input);
-       })
-     .define_singleton_method(
-       "_reshape",
-       *[](Tensor& input, IntArrayRef shape) {
-         return torch::reshape(input, shape);
-       })
-     .define_singleton_method(
-       "_flatten",
-       *[](Tensor& input, int64_t start_dim, int64_t end_dim) {
-         return torch::flatten(input, start_dim, end_dim);
-       })
      .define_singleton_method(
        "relu",
        *[](Tensor& input) {
@@ -579,104 +195,9 @@ void Init_ext()
        return torch::avg_pool2d(input, kernel_size);
      })
      .define_singleton_method(
-       "_dropout",
-       *[](Tensor& input, float p, bool train) {
-         return torch::dropout(input, p, train);
-       })
-     .define_singleton_method(
-       "_dropout!",
-       *[](Tensor& input, float p, bool train) {
-         return torch::dropout_(input, p, train);
-       })
-     .define_singleton_method(
-       "_feature_dropout",
-       *[](Tensor& input, float p, bool train) {
-         return torch::feature_dropout(input, p, train);
-       })
-     .define_singleton_method(
-       "_feature_dropout!",
-       *[](Tensor& input, float p, bool train) {
-         return torch::feature_dropout_(input, p, train);
-       })
-     .define_singleton_method(
-       "_alpha_dropout",
-       *[](Tensor& input, float p, bool train) {
-         return torch::alpha_dropout(input, p, train);
-       })
-     .define_singleton_method(
-       "_alpha_dropout!",
-       *[](Tensor& input, float p, bool train) {
-         return torch::alpha_dropout_(input, p, train);
-       })
-     .define_singleton_method(
-       "_feature_alpha_dropout",
-       *[](Tensor& input, float p, bool train) {
-         return torch::feature_alpha_dropout(input, p, train);
-       })
-     .define_singleton_method(
-       "_feature_alpha_dropout!",
-       *[](Tensor& input, float p, bool train) {
-         return torch::feature_alpha_dropout_(input, p, train);
-       })
-     // sparse layers
-     .define_singleton_method(
-       "_embedding",
-       // weight and indices are swapped from Python interface
-       *[](const Tensor &indices, const Tensor &weight, int64_t padding_idx, bool scale_grad_by_freq, bool sparse) {
-         return torch::embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse);
-       })
-     .define_singleton_method(
-       "_embedding_bag",
-       // weight and indices are swapped from Python interface
-       *[](const Tensor &weight, const Tensor &indices, const Tensor &offsets, bool scale_grad_by_freq, int64_t mode, bool sparse, const Tensor &per_sample_weights) {
-         return torch::embedding_bag(weight, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights);
-       })
-     // distance functions
-     .define_singleton_method(
-       "_cosine_similarity",
-       *[](const Tensor &x1, const Tensor &x2, int64_t dim, double eps) {
-         return torch::cosine_similarity(x1, x2, dim, eps);
-       })
-     .define_singleton_method(
-       "_pairwise_distance",
-       *[](const Tensor &x1, const Tensor &x2, double p, double eps, bool keepdim) {
-         return torch::pairwise_distance(x1, x2, p, eps, keepdim);
-       })
-     // loss functions
-     .define_singleton_method(
-       "binary_cross_entropy",
-       *[](Tensor& input, Tensor& target, MyReduction reduction) {
-         return torch::binary_cross_entropy(input, target, {}, reduction);
-       })
-     .define_singleton_method(
-       "ctc_loss",
-       *[](const Tensor &log_probs, const Tensor &targets, IntArrayRef input_lengths, IntArrayRef target_lengths, int64_t blank, MyReduction reduction, bool zero_infinity) {
-         return torch::ctc_loss(log_probs, targets, input_lengths, target_lengths, blank, reduction, zero_infinity);
-       })
-     .define_singleton_method(
-       "kl_div",
-       *[](Tensor& input, Tensor& target, MyReduction reduction) {
-         return torch::kl_div(input, target, reduction);
-       })
-     .define_singleton_method(
-       "l1_loss",
-       *[](Tensor& input, Tensor& target, MyReduction reduction) {
-         return torch::l1_loss(input, target, reduction);
-       })
-     .define_singleton_method(
-       "mse_loss",
-       *[](Tensor& input, Tensor& target, MyReduction reduction) {
-         return torch::mse_loss(input, target, reduction);
-       })
-     .define_singleton_method(
-       "nll_loss",
-       *[](Tensor& input, Tensor& target, MyReduction reduction, int64_t ignore_index) {
-         return torch::nll_loss(input, target, {}, reduction, ignore_index);
-       })
-     .define_singleton_method(
-       "poisson_nll_loss",
-       *[](const Tensor &input, const Tensor &target, bool log_input, bool full, double eps, MyReduction reduction) {
-         return torch::poisson_nll_loss(input, target, log_input, full, eps, reduction);
+       "_binary_cross_entropy_with_logits",
+       *[](const Tensor &input, const Tensor &target, OptionalTensor weight, OptionalTensor pos_weight, MyReduction reduction) {
+         return torch::binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction);
        })
      .define_singleton_method("numel", &torch::numel)
      .define_singleton_method(
@@ -703,7 +224,7 @@ void Init_ext()
        return t.reshape(size);
      });

-   Class rb_cTensor = define_class_under<torch::Tensor>(rb_mTorch, "Tensor")
+   rb_cTensor
      .define_method("cuda?", &torch::Tensor::is_cuda)
      .define_method("distributed?", &torch::Tensor::is_distributed)
      .define_method("complex?", &torch::Tensor::is_complex)
@@ -740,16 +261,6 @@ void Init_ext()
      *[](Tensor& self) {
        return self.detach_();
      })
-     .define_method(
-       "_select",
-       *[](Tensor& self, int64_t dim, int64_t index) {
-         return self.select(dim, index);
-       })
-     .define_method(
-       "_slice",
-       *[](Tensor& self, int64_t dim, int64_t start, int64_t end, int64_t step) {
-         return self.slice(dim, start, end, step);
-       })
      .define_method(
        "_requires_grad!",
        *[](Tensor& self, bool requires_grad) {
@@ -789,11 +300,6 @@ void Init_ext()
        s << self.device();
        return s.str();
      })
-     .define_method(
-       "_view",
-       *[](Tensor& self, IntArrayRef size) {
-         return self.view(size);
-       })
      .define_method(
        "resize_as!",
        *[](Tensor& self, Tensor& other) {
@@ -809,21 +315,6 @@ void Init_ext()
      *[](Tensor& self) {
        return self.relu_();
      })
-     .define_method(
-       "_add!",
-       *[](Tensor& self, Tensor& other) {
-         return self.add_(other);
-       })
-     .define_method(
-       "_add_alpha!",
-       *[](Tensor& self, Tensor& other, Scalar alpha) {
-         return self.add_(other, alpha);
-       })
-     .define_method(
-       "_add_scalar!",
-       *[](Tensor& self, Scalar other) {
-         return self.add_(other);
-       })
      .define_method(
        "normal!",
        *[](Tensor& self, double mean, double std) {
@@ -839,16 +330,6 @@ void Init_ext()
      *[](Tensor& self, Tensor& other) {
        return self.sub_(other);
      })
-     .define_method(
-       "_mul!",
-       *[](Tensor& self, Tensor& other) {
-         return self.mul_(other);
-       })
-     .define_method(
-       "_mul_scalar!",
-       *[](Tensor& self, Scalar other) {
-         return self.mul_(other);
-       })
      .define_method(
        "div!",
        *[](Tensor& self, Tensor& other) {
@@ -880,7 +361,7 @@ void Init_ext()
        return self.data();
      })
      .define_method(
-       "_data",
+       "_flat_data",
        *[](Tensor& self) {
          Array a;
          auto dtype = self.dtype();
@@ -931,11 +412,6 @@ void Init_ext()
        }
        return a;
      })
-     .define_method(
-       "_size",
-       *[](Tensor& self, int i) {
-         return self.size(i);
-       })
      .define_method(
        "_to",
        *[](Tensor& self, torch::Device device, int dtype, bool non_blocking, bool copy) {
@@ -990,8 +466,6 @@ void Init_ext()
        return self.requires_grad(requires_grad);
      });

-   Module rb_mNN = define_module_under(rb_mTorch, "NN");
-
    Module rb_mInit = define_module_under(rb_mNN, "Init")
      .define_singleton_method(
        "_calculate_gain",