torch-rb 0.1.4 → 0.1.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/README.md +5 -3
- data/ext/torch/ext.cpp +22 -548
- data/ext/torch/extconf.rb +6 -0
- data/ext/torch/nn_functions.cpp +595 -0
- data/ext/torch/nn_functions.hpp +6 -0
- data/ext/torch/templates.hpp +250 -0
- data/ext/torch/tensor_functions.cpp +1860 -0
- data/ext/torch/tensor_functions.hpp +6 -0
- data/ext/torch/torch_functions.cpp +2875 -0
- data/ext/torch/torch_functions.hpp +6 -0
- data/lib/torch.rb +68 -129
- data/lib/torch/ext.bundle +0 -0
- data/lib/torch/native/dispatcher.rb +48 -0
- data/lib/torch/native/function.rb +78 -0
- data/lib/torch/native/generator.rb +149 -0
- data/lib/torch/native/native_functions.yaml +6837 -0
- data/lib/torch/native/parser.rb +97 -0
- data/lib/torch/nn/bce_with_logits_loss.rb +15 -0
- data/lib/torch/nn/conv2d.rb +0 -2
- data/lib/torch/nn/cosine_embedding_loss.rb +14 -0
- data/lib/torch/nn/functional.rb +55 -16
- data/lib/torch/nn/hinge_embedding_loss.rb +14 -0
- data/lib/torch/nn/identity.rb +1 -0
- data/lib/torch/nn/margin_ranking_loss.rb +14 -0
- data/lib/torch/nn/module.rb +59 -12
- data/lib/torch/nn/multi_label_margin_loss.rb +13 -0
- data/lib/torch/nn/multi_label_soft_margin_loss.rb +13 -0
- data/lib/torch/nn/multi_margin_loss.rb +17 -0
- data/lib/torch/nn/parameter.rb +4 -0
- data/lib/torch/nn/rnn.rb +22 -0
- data/lib/torch/nn/rnn_base.rb +154 -0
- data/lib/torch/nn/smooth_l1_loss.rb +13 -0
- data/lib/torch/nn/soft_margin_loss.rb +13 -0
- data/lib/torch/nn/triplet_margin_loss.rb +18 -0
- data/lib/torch/tensor.rb +19 -19
- data/lib/torch/version.rb +1 -1
- metadata +26 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 6b47306ed525e1a20d25cb8324d4658f750c18afa5704c9b7bafc215d8f568c1
|
4
|
+
data.tar.gz: dad6ddf955b111989b061e5af146006a32c83dc1ea1ca5005a6b6e34bc9a4892
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 5d26e3642bf7cd921b9b570052df353d4c32b1bd955a6fbbf5f30249631fa4c0d4624f4fa91a1c06f61b3b0d6461cd117ab4df185cf013e915d2f63e52dbcf7c
|
7
|
+
data.tar.gz: 1728ce9b579f41f7a567e63d7256c82bb352840b67f16d88aac930a99e5abbf5a5f4061c5f9da16fb47d1664567e7956d276a8b2b44f13d2263032486afb53e8
|
data/CHANGELOG.md
CHANGED
data/README.md
CHANGED
@@ -20,6 +20,8 @@ Add this line to your application’s Gemfile:
|
|
20
20
|
gem 'torch-rb'
|
21
21
|
```
|
22
22
|
|
23
|
+
It can take a few minutes to compile the extension.
|
24
|
+
|
23
25
|
## Getting Started
|
24
26
|
|
25
27
|
This library follows the [PyTorch API](https://pytorch.org/docs/stable/torch.html). There are a few changes to make it more Ruby-like:
|
@@ -28,7 +30,7 @@ This library follows the [PyTorch API](https://pytorch.org/docs/stable/torch.htm
|
|
28
30
|
- Methods that return booleans use `?` instead of `is_` (`tensor?` instead of `is_tensor`)
|
29
31
|
- Numo is used instead of NumPy (`x.numo` instead of `x.numpy()`)
|
30
32
|
|
31
|
-
|
33
|
+
Some methods and options are missing at the moment. PRs welcome!
|
32
34
|
|
33
35
|
## Tutorial
|
34
36
|
|
@@ -365,9 +367,9 @@ Here’s a list of functions to create tensors (descriptions from the [C++ docs]
|
|
365
367
|
|
366
368
|
Here are a few full examples:
|
367
369
|
|
368
|
-
- [Image classification with MNIST](examples/mnist)
|
370
|
+
- [Image classification with MNIST](examples/mnist) ([日本語版](https://qiita.com/kojix2/items/c19c36dc1bf73ea93409))
|
369
371
|
- [Collaborative filtering with MovieLens](examples/movielens)
|
370
|
-
- [Word embeddings](examples/nlp)
|
372
|
+
- [Word embeddings](examples/nlp)
|
371
373
|
|
372
374
|
## LibTorch Installation
|
373
375
|
|
data/ext/torch/ext.cpp
CHANGED
@@ -6,230 +6,15 @@
|
|
6
6
|
#include <rice/Class.hpp>
|
7
7
|
#include <rice/Constructor.hpp>
|
8
8
|
|
9
|
-
|
10
|
-
|
11
|
-
template<>
|
12
|
-
inline
|
13
|
-
long long from_ruby<long long>(Object x)
|
14
|
-
{
|
15
|
-
return NUM2LL(x);
|
16
|
-
}
|
17
|
-
|
18
|
-
template<>
|
19
|
-
inline
|
20
|
-
Object to_ruby<long long>(long long const & x)
|
21
|
-
{
|
22
|
-
return LL2NUM(x);
|
23
|
-
}
|
24
|
-
|
25
|
-
template<>
|
26
|
-
inline
|
27
|
-
unsigned long long from_ruby<unsigned long long>(Object x)
|
28
|
-
{
|
29
|
-
return NUM2ULL(x);
|
30
|
-
}
|
31
|
-
|
32
|
-
template<>
|
33
|
-
inline
|
34
|
-
Object to_ruby<unsigned long long>(unsigned long long const & x)
|
35
|
-
{
|
36
|
-
return ULL2NUM(x);
|
37
|
-
}
|
38
|
-
|
39
|
-
template<>
|
40
|
-
inline
|
41
|
-
short from_ruby<short>(Object x)
|
42
|
-
{
|
43
|
-
return NUM2SHORT(x);
|
44
|
-
}
|
45
|
-
|
46
|
-
template<>
|
47
|
-
inline
|
48
|
-
Object to_ruby<short>(short const & x)
|
49
|
-
{
|
50
|
-
return INT2NUM(x);
|
51
|
-
}
|
52
|
-
|
53
|
-
template<>
|
54
|
-
inline
|
55
|
-
unsigned short from_ruby<unsigned short>(Object x)
|
56
|
-
{
|
57
|
-
return NUM2USHORT(x);
|
58
|
-
}
|
59
|
-
|
60
|
-
template<>
|
61
|
-
inline
|
62
|
-
Object to_ruby<unsigned short>(unsigned short const & x)
|
63
|
-
{
|
64
|
-
return UINT2NUM(x);
|
65
|
-
}
|
9
|
+
#include "templates.hpp"
|
66
10
|
|
67
|
-
//
|
68
|
-
//
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
IntArrayRef(Object o) {
|
73
|
-
Array a = Array(o);
|
74
|
-
for (size_t i = 0; i < a.size(); i++) {
|
75
|
-
vec.push_back(from_ruby<int64_t>(a[i]));
|
76
|
-
}
|
77
|
-
}
|
78
|
-
operator torch::IntArrayRef() {
|
79
|
-
return torch::IntArrayRef(vec);
|
80
|
-
}
|
81
|
-
};
|
11
|
+
// generated with:
|
12
|
+
// rake generate:functions
|
13
|
+
#include "torch_functions.hpp"
|
14
|
+
#include "tensor_functions.hpp"
|
15
|
+
#include "nn_functions.hpp"
|
82
16
|
|
83
|
-
|
84
|
-
inline
|
85
|
-
IntArrayRef from_ruby<IntArrayRef>(Object x)
|
86
|
-
{
|
87
|
-
return IntArrayRef(x);
|
88
|
-
}
|
89
|
-
|
90
|
-
// for now
|
91
|
-
class Scalar {
|
92
|
-
torch::Scalar value;
|
93
|
-
public:
|
94
|
-
Scalar(Object o) {
|
95
|
-
// TODO cast based on Ruby type
|
96
|
-
if (o.rb_type() == T_FIXNUM) {
|
97
|
-
value = torch::Scalar(from_ruby<int64_t>(o));
|
98
|
-
} else {
|
99
|
-
value = torch::Scalar(from_ruby<float>(o));
|
100
|
-
}
|
101
|
-
}
|
102
|
-
operator torch::Scalar() {
|
103
|
-
return value;
|
104
|
-
}
|
105
|
-
};
|
106
|
-
|
107
|
-
template<>
|
108
|
-
inline
|
109
|
-
Scalar from_ruby<Scalar>(Object x)
|
110
|
-
{
|
111
|
-
return Scalar(x);
|
112
|
-
}
|
113
|
-
|
114
|
-
class TensorList {
|
115
|
-
std::vector<torch::Tensor> vec;
|
116
|
-
public:
|
117
|
-
TensorList(Object o) {
|
118
|
-
Array a = Array(o);
|
119
|
-
for (size_t i = 0; i < a.size(); i++) {
|
120
|
-
vec.push_back(from_ruby<torch::Tensor>(a[i]));
|
121
|
-
}
|
122
|
-
}
|
123
|
-
operator torch::TensorList() {
|
124
|
-
return torch::TensorList(vec);
|
125
|
-
}
|
126
|
-
};
|
127
|
-
|
128
|
-
template<>
|
129
|
-
inline
|
130
|
-
TensorList from_ruby<TensorList>(Object x)
|
131
|
-
{
|
132
|
-
return TensorList(x);
|
133
|
-
}
|
134
|
-
|
135
|
-
class FanModeType {
|
136
|
-
std::string s;
|
137
|
-
public:
|
138
|
-
FanModeType(Object o) {
|
139
|
-
s = String(o).str();
|
140
|
-
}
|
141
|
-
// TODO switch NonlinearityType after LibTorch 1.4 release
|
142
|
-
operator torch::nn::init::FanMode() {
|
143
|
-
if (s == "fan_in") {
|
144
|
-
return torch::nn::init::FanMode::FanIn;
|
145
|
-
} else if (s == "fan_out") {
|
146
|
-
return torch::nn::init::FanMode::FanOut;
|
147
|
-
} else {
|
148
|
-
throw std::runtime_error("Unsupported nonlinearity type: " + s);
|
149
|
-
}
|
150
|
-
}
|
151
|
-
};
|
152
|
-
|
153
|
-
template<>
|
154
|
-
inline
|
155
|
-
FanModeType from_ruby<FanModeType>(Object x)
|
156
|
-
{
|
157
|
-
return FanModeType(x);
|
158
|
-
}
|
159
|
-
|
160
|
-
class NonlinearityType {
|
161
|
-
std::string s;
|
162
|
-
public:
|
163
|
-
NonlinearityType(Object o) {
|
164
|
-
s = String(o).str();
|
165
|
-
}
|
166
|
-
// TODO switch NonlinearityType after LibTorch 1.4 release
|
167
|
-
operator torch::nn::init::Nonlinearity() {
|
168
|
-
if (s == "linear") {
|
169
|
-
return torch::nn::init::Nonlinearity::Linear;
|
170
|
-
} else if (s == "conv1d") {
|
171
|
-
return torch::nn::init::Nonlinearity::Conv1D;
|
172
|
-
} else if (s == "conv2d") {
|
173
|
-
return torch::nn::init::Nonlinearity::Conv2D;
|
174
|
-
} else if (s == "conv3d") {
|
175
|
-
return torch::nn::init::Nonlinearity::Conv3D;
|
176
|
-
} else if (s == "conv_transpose1d") {
|
177
|
-
return torch::nn::init::Nonlinearity::ConvTranspose1D;
|
178
|
-
} else if (s == "conv_transpose2d") {
|
179
|
-
return torch::nn::init::Nonlinearity::ConvTranspose2D;
|
180
|
-
} else if (s == "conv_transpose3d") {
|
181
|
-
return torch::nn::init::Nonlinearity::ConvTranspose3D;
|
182
|
-
} else if (s == "sigmoid") {
|
183
|
-
return torch::nn::init::Nonlinearity::Sigmoid;
|
184
|
-
} else if (s == "tanh") {
|
185
|
-
return torch::nn::init::Nonlinearity::Tanh;
|
186
|
-
} else if (s == "relu") {
|
187
|
-
return torch::nn::init::Nonlinearity::ReLU;
|
188
|
-
} else if (s == "leaky_relu") {
|
189
|
-
return torch::nn::init::Nonlinearity::LeakyReLU;
|
190
|
-
} else {
|
191
|
-
throw std::runtime_error("Unsupported nonlinearity type: " + s);
|
192
|
-
}
|
193
|
-
}
|
194
|
-
};
|
195
|
-
|
196
|
-
template<>
|
197
|
-
inline
|
198
|
-
NonlinearityType from_ruby<NonlinearityType>(Object x)
|
199
|
-
{
|
200
|
-
return NonlinearityType(x);
|
201
|
-
}
|
202
|
-
|
203
|
-
class MyReduction {
|
204
|
-
Object value;
|
205
|
-
public:
|
206
|
-
MyReduction(Object o) {
|
207
|
-
value = o;
|
208
|
-
}
|
209
|
-
operator int64_t() {
|
210
|
-
if (value.is_nil()) {
|
211
|
-
return Reduction::None;
|
212
|
-
}
|
213
|
-
|
214
|
-
std::string s = String(value).str();
|
215
|
-
if (s == "mean") {
|
216
|
-
return Reduction::Mean;
|
217
|
-
} else if (s == "sum") {
|
218
|
-
return Reduction::Sum;
|
219
|
-
} else {
|
220
|
-
throw std::runtime_error("Unsupported reduction: " + s);
|
221
|
-
}
|
222
|
-
}
|
223
|
-
};
|
224
|
-
|
225
|
-
template<>
|
226
|
-
inline
|
227
|
-
MyReduction from_ruby<MyReduction>(Object x)
|
228
|
-
{
|
229
|
-
return MyReduction(x);
|
230
|
-
}
|
231
|
-
|
232
|
-
typedef torch::Tensor Tensor;
|
17
|
+
using namespace Rice;
|
233
18
|
|
234
19
|
Object tensor_array(std::tuple<torch::Tensor, torch::Tensor> x) {
|
235
20
|
Array a;
|
@@ -241,8 +26,16 @@ Object tensor_array(std::tuple<torch::Tensor, torch::Tensor> x) {
|
|
241
26
|
extern "C"
|
242
27
|
void Init_ext()
|
243
28
|
{
|
244
|
-
Module rb_mTorch = define_module("Torch")
|
245
|
-
|
29
|
+
Module rb_mTorch = define_module("Torch");
|
30
|
+
add_torch_functions(rb_mTorch);
|
31
|
+
|
32
|
+
Class rb_cTensor = define_class_under<torch::Tensor>(rb_mTorch, "Tensor");
|
33
|
+
add_tensor_functions(rb_cTensor);
|
34
|
+
|
35
|
+
Module rb_mNN = define_module_under(rb_mTorch, "NN");
|
36
|
+
add_nn_functions(rb_mNN);
|
37
|
+
|
38
|
+
rb_mTorch.define_singleton_method(
|
246
39
|
"grad_enabled?",
|
247
40
|
*[]() {
|
248
41
|
return torch::GradMode::is_enabled();
|
@@ -252,11 +45,6 @@ void Init_ext()
|
|
252
45
|
*[](bool enabled) {
|
253
46
|
torch::GradMode::set_enabled(enabled);
|
254
47
|
})
|
255
|
-
.define_singleton_method(
|
256
|
-
"floating_point?",
|
257
|
-
*[](Tensor& input) {
|
258
|
-
return torch::is_floating_point(input);
|
259
|
-
})
|
260
48
|
.define_singleton_method(
|
261
49
|
"manual_seed",
|
262
50
|
*[](uint64_t seed) {
|
@@ -344,168 +132,16 @@ void Init_ext()
|
|
344
132
|
*[](Tensor& input, int64_t dim, bool keepdim) {
|
345
133
|
return torch::sum(input, dim, keepdim);
|
346
134
|
})
|
347
|
-
.define_singleton_method(
|
348
|
-
"_argmax",
|
349
|
-
*[](Tensor& input) {
|
350
|
-
return torch::argmax(input);
|
351
|
-
})
|
352
|
-
.define_singleton_method(
|
353
|
-
"_argmax_dim",
|
354
|
-
*[](Tensor& input, int64_t dim, bool keepdim) {
|
355
|
-
return torch::argmax(input, dim, keepdim);
|
356
|
-
})
|
357
|
-
.define_singleton_method(
|
358
|
-
"_cat",
|
359
|
-
*[](TensorList tensors, int64_t dim) {
|
360
|
-
return torch::cat(tensors, dim);
|
361
|
-
})
|
362
|
-
.define_singleton_method(
|
363
|
-
"_norm",
|
364
|
-
*[](Tensor& input) {
|
365
|
-
return torch::norm(input);
|
366
|
-
})
|
367
|
-
.define_singleton_method(
|
368
|
-
"_min",
|
369
|
-
*[](Tensor& input) {
|
370
|
-
return torch::min(input);
|
371
|
-
})
|
372
|
-
.define_singleton_method(
|
373
|
-
"_max",
|
374
|
-
*[](Tensor& input) {
|
375
|
-
return torch::max(input);
|
376
|
-
})
|
377
135
|
.define_singleton_method(
|
378
136
|
"_max_out",
|
379
137
|
*[](Tensor &max, Tensor &max_indices, const Tensor &input, int64_t dim, bool keepdim) {
|
380
138
|
return tensor_array(torch::_max_out(max, max_indices, input, dim, keepdim));
|
381
139
|
})
|
382
|
-
.define_singleton_method(
|
383
|
-
"_sqrt",
|
384
|
-
*[](Tensor& input) {
|
385
|
-
return torch::sqrt(input);
|
386
|
-
})
|
387
|
-
.define_singleton_method(
|
388
|
-
"_exp",
|
389
|
-
*[](Tensor& input) {
|
390
|
-
return torch::exp(input);
|
391
|
-
})
|
392
|
-
.define_singleton_method(
|
393
|
-
"_log",
|
394
|
-
*[](Tensor& input) {
|
395
|
-
return torch::log(input);
|
396
|
-
})
|
397
|
-
.define_singleton_method(
|
398
|
-
"_sign",
|
399
|
-
*[](Tensor& input) {
|
400
|
-
return torch::sign(input);
|
401
|
-
})
|
402
|
-
.define_singleton_method(
|
403
|
-
"_unsqueeze",
|
404
|
-
*[](Tensor& input, int64_t dim) {
|
405
|
-
return torch::unsqueeze(input, dim);
|
406
|
-
})
|
407
|
-
.define_singleton_method(
|
408
|
-
"_dot",
|
409
|
-
*[](Tensor& input, Tensor& tensor) {
|
410
|
-
return torch::dot(input, tensor);
|
411
|
-
})
|
412
|
-
.define_singleton_method(
|
413
|
-
"_matmul",
|
414
|
-
*[](Tensor& input, Tensor& other) {
|
415
|
-
return torch::matmul(input, other);
|
416
|
-
})
|
417
|
-
.define_singleton_method(
|
418
|
-
"_eq",
|
419
|
-
*[](Tensor& input, Tensor& other) {
|
420
|
-
return torch::eq(input, other);
|
421
|
-
})
|
422
|
-
.define_singleton_method(
|
423
|
-
"_gt",
|
424
|
-
// TODO support tensors
|
425
|
-
*[](Tensor& input, Scalar other) {
|
426
|
-
return torch::gt(input, other);
|
427
|
-
})
|
428
|
-
.define_singleton_method(
|
429
|
-
"_lt",
|
430
|
-
// TODO support tensors
|
431
|
-
*[](Tensor& input, Scalar other) {
|
432
|
-
return torch::lt(input, other);
|
433
|
-
})
|
434
|
-
.define_singleton_method(
|
435
|
-
"_add",
|
436
|
-
*[](Tensor& input, Tensor& other) {
|
437
|
-
return torch::add(input, other);
|
438
|
-
})
|
439
|
-
.define_singleton_method(
|
440
|
-
"_add_scalar",
|
441
|
-
*[](Tensor& input, Scalar other) {
|
442
|
-
return torch::add(input, other);
|
443
|
-
})
|
444
|
-
.define_singleton_method(
|
445
|
-
"_add_out",
|
446
|
-
*[](Tensor& out, Tensor& input, Tensor& other) {
|
447
|
-
return torch::add_out(out, input, other);
|
448
|
-
})
|
449
|
-
.define_singleton_method(
|
450
|
-
"_sub",
|
451
|
-
*[](Tensor& input, Tensor& other) {
|
452
|
-
return torch::sub(input, other);
|
453
|
-
})
|
454
|
-
.define_singleton_method(
|
455
|
-
"_sub_scalar",
|
456
|
-
*[](Tensor& input, Scalar other) {
|
457
|
-
return torch::sub(input, other);
|
458
|
-
})
|
459
|
-
.define_singleton_method(
|
460
|
-
"_mul",
|
461
|
-
*[](Tensor& input, Tensor& other) {
|
462
|
-
return torch::mul(input, other);
|
463
|
-
})
|
464
|
-
.define_singleton_method(
|
465
|
-
"_mul_scalar",
|
466
|
-
*[](Tensor& input, Scalar other) {
|
467
|
-
return torch::mul(input, other);
|
468
|
-
})
|
469
|
-
.define_singleton_method(
|
470
|
-
"_div",
|
471
|
-
*[](Tensor& input, Tensor& other) {
|
472
|
-
return torch::div(input, other);
|
473
|
-
})
|
474
|
-
.define_singleton_method(
|
475
|
-
"_div_scalar",
|
476
|
-
*[](Tensor& input, Scalar other) {
|
477
|
-
return torch::div(input, other);
|
478
|
-
})
|
479
|
-
.define_singleton_method(
|
480
|
-
"_remainder",
|
481
|
-
*[](Tensor& input, Tensor& other) {
|
482
|
-
return torch::remainder(input, other);
|
483
|
-
})
|
484
|
-
.define_singleton_method(
|
485
|
-
"_remainder_scalar",
|
486
|
-
*[](Tensor& input, Scalar other) {
|
487
|
-
return torch::remainder(input, other);
|
488
|
-
})
|
489
|
-
.define_singleton_method(
|
490
|
-
"_pow",
|
491
|
-
*[](Tensor& input, Scalar exponent) {
|
492
|
-
return torch::pow(input, exponent);
|
493
|
-
})
|
494
140
|
.define_singleton_method(
|
495
141
|
"_topk",
|
496
142
|
*[](Tensor& input, int64_t k) {
|
497
143
|
return tensor_array(torch::topk(input, k));
|
498
144
|
})
|
499
|
-
.define_singleton_method(
|
500
|
-
"_sigmoid",
|
501
|
-
*[](Tensor& input) {
|
502
|
-
return torch::sigmoid(input);
|
503
|
-
})
|
504
|
-
.define_singleton_method(
|
505
|
-
"_softplus",
|
506
|
-
*[](const Tensor &input, Scalar beta, Scalar threshold) {
|
507
|
-
return torch::softplus(input, beta, threshold);
|
508
|
-
})
|
509
145
|
.define_singleton_method(
|
510
146
|
"_softmax",
|
511
147
|
*[](const Tensor &input, int64_t dim) {
|
@@ -516,26 +152,6 @@ void Init_ext()
|
|
516
152
|
*[](Tensor& input, int64_t dim) {
|
517
153
|
return torch::log_softmax(input, dim);
|
518
154
|
})
|
519
|
-
.define_singleton_method(
|
520
|
-
"_abs",
|
521
|
-
*[](Tensor& input) {
|
522
|
-
return torch::abs(input);
|
523
|
-
})
|
524
|
-
.define_singleton_method(
|
525
|
-
"_neg",
|
526
|
-
*[](Tensor& input) {
|
527
|
-
return torch::neg(input);
|
528
|
-
})
|
529
|
-
.define_singleton_method(
|
530
|
-
"_reshape",
|
531
|
-
*[](Tensor& input, IntArrayRef shape) {
|
532
|
-
return torch::reshape(input, shape);
|
533
|
-
})
|
534
|
-
.define_singleton_method(
|
535
|
-
"_flatten",
|
536
|
-
*[](Tensor& input, int64_t start_dim, int64_t end_dim) {
|
537
|
-
return torch::flatten(input, start_dim, end_dim);
|
538
|
-
})
|
539
155
|
.define_singleton_method(
|
540
156
|
"relu",
|
541
157
|
*[](Tensor& input) {
|
@@ -579,104 +195,9 @@ void Init_ext()
|
|
579
195
|
return torch::avg_pool2d(input, kernel_size);
|
580
196
|
})
|
581
197
|
.define_singleton_method(
|
582
|
-
"
|
583
|
-
*[](Tensor&
|
584
|
-
return torch::
|
585
|
-
})
|
586
|
-
.define_singleton_method(
|
587
|
-
"_dropout!",
|
588
|
-
*[](Tensor& input, float p, bool train) {
|
589
|
-
return torch::dropout_(input, p, train);
|
590
|
-
})
|
591
|
-
.define_singleton_method(
|
592
|
-
"_feature_dropout",
|
593
|
-
*[](Tensor& input, float p, bool train) {
|
594
|
-
return torch::feature_dropout(input, p, train);
|
595
|
-
})
|
596
|
-
.define_singleton_method(
|
597
|
-
"_feature_dropout!",
|
598
|
-
*[](Tensor& input, float p, bool train) {
|
599
|
-
return torch::feature_dropout_(input, p, train);
|
600
|
-
})
|
601
|
-
.define_singleton_method(
|
602
|
-
"_alpha_dropout",
|
603
|
-
*[](Tensor& input, float p, bool train) {
|
604
|
-
return torch::alpha_dropout(input, p, train);
|
605
|
-
})
|
606
|
-
.define_singleton_method(
|
607
|
-
"_alpha_dropout!",
|
608
|
-
*[](Tensor& input, float p, bool train) {
|
609
|
-
return torch::alpha_dropout_(input, p, train);
|
610
|
-
})
|
611
|
-
.define_singleton_method(
|
612
|
-
"_feature_alpha_dropout",
|
613
|
-
*[](Tensor& input, float p, bool train) {
|
614
|
-
return torch::feature_alpha_dropout(input, p, train);
|
615
|
-
})
|
616
|
-
.define_singleton_method(
|
617
|
-
"_feature_alpha_dropout!",
|
618
|
-
*[](Tensor& input, float p, bool train) {
|
619
|
-
return torch::feature_alpha_dropout_(input, p, train);
|
620
|
-
})
|
621
|
-
// sparse layers
|
622
|
-
.define_singleton_method(
|
623
|
-
"_embedding",
|
624
|
-
// weight and indices are swapped from Python interface
|
625
|
-
*[](const Tensor &indices, const Tensor &weight, int64_t padding_idx, bool scale_grad_by_freq, bool sparse) {
|
626
|
-
return torch::embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse);
|
627
|
-
})
|
628
|
-
.define_singleton_method(
|
629
|
-
"_embedding_bag",
|
630
|
-
// weight and indices are swapped from Python interface
|
631
|
-
*[](const Tensor &weight, const Tensor &indices, const Tensor &offsets, bool scale_grad_by_freq, int64_t mode, bool sparse, const Tensor &per_sample_weights) {
|
632
|
-
return torch::embedding_bag(weight, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights);
|
633
|
-
})
|
634
|
-
// distance functions
|
635
|
-
.define_singleton_method(
|
636
|
-
"_cosine_similarity",
|
637
|
-
*[](const Tensor &x1, const Tensor &x2, int64_t dim, double eps) {
|
638
|
-
return torch::cosine_similarity(x1, x2, dim, eps);
|
639
|
-
})
|
640
|
-
.define_singleton_method(
|
641
|
-
"_pairwise_distance",
|
642
|
-
*[](const Tensor &x1, const Tensor &x2, double p, double eps, bool keepdim) {
|
643
|
-
return torch::pairwise_distance(x1, x2, p, eps, keepdim);
|
644
|
-
})
|
645
|
-
// loss functions
|
646
|
-
.define_singleton_method(
|
647
|
-
"binary_cross_entropy",
|
648
|
-
*[](Tensor& input, Tensor& target, MyReduction reduction) {
|
649
|
-
return torch::binary_cross_entropy(input, target, {}, reduction);
|
650
|
-
})
|
651
|
-
.define_singleton_method(
|
652
|
-
"ctc_loss",
|
653
|
-
*[](const Tensor &log_probs, const Tensor &targets, IntArrayRef input_lengths, IntArrayRef target_lengths, int64_t blank, MyReduction reduction, bool zero_infinity) {
|
654
|
-
return torch::ctc_loss(log_probs, targets, input_lengths, target_lengths, blank, reduction, zero_infinity);
|
655
|
-
})
|
656
|
-
.define_singleton_method(
|
657
|
-
"kl_div",
|
658
|
-
*[](Tensor& input, Tensor& target, MyReduction reduction) {
|
659
|
-
return torch::kl_div(input, target, reduction);
|
660
|
-
})
|
661
|
-
.define_singleton_method(
|
662
|
-
"l1_loss",
|
663
|
-
*[](Tensor& input, Tensor& target, MyReduction reduction) {
|
664
|
-
return torch::l1_loss(input, target, reduction);
|
665
|
-
})
|
666
|
-
.define_singleton_method(
|
667
|
-
"mse_loss",
|
668
|
-
*[](Tensor& input, Tensor& target, MyReduction reduction) {
|
669
|
-
return torch::mse_loss(input, target, reduction);
|
670
|
-
})
|
671
|
-
.define_singleton_method(
|
672
|
-
"nll_loss",
|
673
|
-
*[](Tensor& input, Tensor& target, MyReduction reduction, int64_t ignore_index) {
|
674
|
-
return torch::nll_loss(input, target, {}, reduction, ignore_index);
|
675
|
-
})
|
676
|
-
.define_singleton_method(
|
677
|
-
"poisson_nll_loss",
|
678
|
-
*[](const Tensor &input, const Tensor &target, bool log_input, bool full, double eps, MyReduction reduction) {
|
679
|
-
return torch::poisson_nll_loss(input, target, log_input, full, eps, reduction);
|
198
|
+
"_binary_cross_entropy_with_logits",
|
199
|
+
*[](const Tensor &input, const Tensor &target, OptionalTensor weight, OptionalTensor pos_weight, MyReduction reduction) {
|
200
|
+
return torch::binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction);
|
680
201
|
})
|
681
202
|
.define_singleton_method("numel", &torch::numel)
|
682
203
|
.define_singleton_method(
|
@@ -703,7 +224,7 @@ void Init_ext()
|
|
703
224
|
return t.reshape(size);
|
704
225
|
});
|
705
226
|
|
706
|
-
|
227
|
+
rb_cTensor
|
707
228
|
.define_method("cuda?", &torch::Tensor::is_cuda)
|
708
229
|
.define_method("distributed?", &torch::Tensor::is_distributed)
|
709
230
|
.define_method("complex?", &torch::Tensor::is_complex)
|
@@ -740,16 +261,6 @@ void Init_ext()
|
|
740
261
|
*[](Tensor& self) {
|
741
262
|
return self.detach_();
|
742
263
|
})
|
743
|
-
.define_method(
|
744
|
-
"_select",
|
745
|
-
*[](Tensor& self, int64_t dim, int64_t index) {
|
746
|
-
return self.select(dim, index);
|
747
|
-
})
|
748
|
-
.define_method(
|
749
|
-
"_slice",
|
750
|
-
*[](Tensor& self, int64_t dim, int64_t start, int64_t end, int64_t step) {
|
751
|
-
return self.slice(dim, start, end, step);
|
752
|
-
})
|
753
264
|
.define_method(
|
754
265
|
"_requires_grad!",
|
755
266
|
*[](Tensor& self, bool requires_grad) {
|
@@ -789,11 +300,6 @@ void Init_ext()
|
|
789
300
|
s << self.device();
|
790
301
|
return s.str();
|
791
302
|
})
|
792
|
-
.define_method(
|
793
|
-
"_view",
|
794
|
-
*[](Tensor& self, IntArrayRef size) {
|
795
|
-
return self.view(size);
|
796
|
-
})
|
797
303
|
.define_method(
|
798
304
|
"resize_as!",
|
799
305
|
*[](Tensor& self, Tensor& other) {
|
@@ -809,21 +315,6 @@ void Init_ext()
|
|
809
315
|
*[](Tensor& self) {
|
810
316
|
return self.relu_();
|
811
317
|
})
|
812
|
-
.define_method(
|
813
|
-
"_add!",
|
814
|
-
*[](Tensor& self, Tensor& other) {
|
815
|
-
return self.add_(other);
|
816
|
-
})
|
817
|
-
.define_method(
|
818
|
-
"_add_alpha!",
|
819
|
-
*[](Tensor& self, Tensor& other, Scalar alpha) {
|
820
|
-
return self.add_(other, alpha);
|
821
|
-
})
|
822
|
-
.define_method(
|
823
|
-
"_add_scalar!",
|
824
|
-
*[](Tensor& self, Scalar other) {
|
825
|
-
return self.add_(other);
|
826
|
-
})
|
827
318
|
.define_method(
|
828
319
|
"normal!",
|
829
320
|
*[](Tensor& self, double mean, double std) {
|
@@ -839,16 +330,6 @@ void Init_ext()
|
|
839
330
|
*[](Tensor& self, Tensor& other) {
|
840
331
|
return self.sub_(other);
|
841
332
|
})
|
842
|
-
.define_method(
|
843
|
-
"_mul!",
|
844
|
-
*[](Tensor& self, Tensor& other) {
|
845
|
-
return self.mul_(other);
|
846
|
-
})
|
847
|
-
.define_method(
|
848
|
-
"_mul_scalar!",
|
849
|
-
*[](Tensor& self, Scalar other) {
|
850
|
-
return self.mul_(other);
|
851
|
-
})
|
852
333
|
.define_method(
|
853
334
|
"div!",
|
854
335
|
*[](Tensor& self, Tensor& other) {
|
@@ -880,7 +361,7 @@ void Init_ext()
|
|
880
361
|
return self.data();
|
881
362
|
})
|
882
363
|
.define_method(
|
883
|
-
"
|
364
|
+
"_flat_data",
|
884
365
|
*[](Tensor& self) {
|
885
366
|
Array a;
|
886
367
|
auto dtype = self.dtype();
|
@@ -931,11 +412,6 @@ void Init_ext()
|
|
931
412
|
}
|
932
413
|
return a;
|
933
414
|
})
|
934
|
-
.define_method(
|
935
|
-
"_size",
|
936
|
-
*[](Tensor& self, int i) {
|
937
|
-
return self.size(i);
|
938
|
-
})
|
939
415
|
.define_method(
|
940
416
|
"_to",
|
941
417
|
*[](Tensor& self, torch::Device device, int dtype, bool non_blocking, bool copy) {
|
@@ -990,8 +466,6 @@ void Init_ext()
|
|
990
466
|
return self.requires_grad(requires_grad);
|
991
467
|
});
|
992
468
|
|
993
|
-
Module rb_mNN = define_module_under(rb_mTorch, "NN");
|
994
|
-
|
995
469
|
Module rb_mInit = define_module_under(rb_mNN, "Init")
|
996
470
|
.define_singleton_method(
|
997
471
|
"_calculate_gain",
|