torch-rb 0.1.2 → 0.1.3
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +10 -0
- data/LICENSE.txt +46 -22
- data/README.md +14 -5
- data/ext/torch/ext.cpp +248 -31
- data/lib/torch.rb +80 -9
- data/lib/torch/ext.bundle +0 -0
- data/lib/torch/inspector.rb +4 -3
- data/lib/torch/nn/alpha_dropout.rb +9 -0
- data/lib/torch/nn/conv2d.rb +12 -24
- data/lib/torch/nn/convnd.rb +41 -0
- data/lib/torch/nn/dropout.rb +9 -0
- data/lib/torch/nn/dropout2d.rb +9 -0
- data/lib/torch/nn/dropout3d.rb +9 -0
- data/lib/torch/nn/dropoutnd.rb +15 -0
- data/lib/torch/nn/embedding.rb +52 -0
- data/lib/torch/nn/feature_alpha_dropout.rb +9 -0
- data/lib/torch/nn/functional.rb +54 -12
- data/lib/torch/nn/linear.rb +2 -2
- data/lib/torch/nn/module.rb +30 -0
- data/lib/torch/optim/adadelta.rb +57 -0
- data/lib/torch/optim/adagrad.rb +71 -0
- data/lib/torch/optim/adam.rb +81 -0
- data/lib/torch/optim/adamax.rb +68 -0
- data/lib/torch/optim/adamw.rb +82 -0
- data/lib/torch/optim/asgd.rb +65 -0
- data/lib/torch/optim/lr_scheduler/lr_scheduler.rb +33 -0
- data/lib/torch/optim/lr_scheduler/step_lr.rb +17 -0
- data/lib/torch/optim/optimizer.rb +56 -0
- data/lib/torch/optim/rmsprop.rb +76 -0
- data/lib/torch/optim/rprop.rb +68 -0
- data/lib/torch/optim/sgd.rb +48 -16
- data/lib/torch/tensor.rb +38 -4
- data/lib/torch/utils/data/data_loader.rb +10 -4
- data/lib/torch/utils/data/tensor_dataset.rb +3 -0
- data/lib/torch/version.rb +1 -1
- metadata +21 -3
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: e7f715179c9a84dc7399b80d93fd61f2bbb58a0156e6084dc4abb23e1d4a1b52
|
4
|
+
data.tar.gz: 6928379ae7c92a77ad9dde4f4224ec33c6f8575a9b77585c0147e4f5361021de
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 9911a9e86d93f1e410776c44fdb3cd9aa06c83d1f0e42fdab8530970bea6520aed7906e96fb8243efd6b957453ebc13678b2b92e4c85b54407030a32c6196e08
|
7
|
+
data.tar.gz: 0d080f5458a5dcf8fee19ce5e2e342bf6269432de6e78d923036232963ebb80daeea993c0bbf4af2d6da46593ac28a72a8232020a9fcb48acc3276c9e1ebebf3
|
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,13 @@
|
|
1
|
+
## 0.1.3 (2019-11-30)
|
2
|
+
|
3
|
+
- Changed to BSD 3-Clause license to match PyTorch
|
4
|
+
- Added many optimizers
|
5
|
+
- Added `StepLR` learning rate scheduler
|
6
|
+
- Added dropout
|
7
|
+
- Added embedding
|
8
|
+
- Added support for `bool` type
|
9
|
+
- Improved performance of `from_numo`
|
10
|
+
|
1
11
|
## 0.1.2 (2019-11-27)
|
2
12
|
|
3
13
|
- Added SGD optimizer
|
data/LICENSE.txt
CHANGED
@@ -1,22 +1,46 @@
|
|
1
|
-
|
2
|
-
|
3
|
-
|
4
|
-
|
5
|
-
|
6
|
-
|
7
|
-
|
8
|
-
|
9
|
-
|
10
|
-
|
11
|
-
|
12
|
-
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
1
|
+
BSD 3-Clause License
|
2
|
+
|
3
|
+
From Torch-rb:
|
4
|
+
|
5
|
+
Copyright (c) 2019- Andrew Kane
|
6
|
+
|
7
|
+
From PyTorch (for ported code):
|
8
|
+
|
9
|
+
Copyright (c) 2016- Facebook, Inc (Adam Paszke)
|
10
|
+
Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
|
11
|
+
Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
|
12
|
+
Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
|
13
|
+
Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
|
14
|
+
Copyright (c) 2011-2013 NYU (Clement Farabet)
|
15
|
+
Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
|
16
|
+
Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
|
17
|
+
Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
|
18
|
+
|
19
|
+
All rights reserved.
|
20
|
+
|
21
|
+
Redistribution and use in source and binary forms, with or without
|
22
|
+
modification, are permitted provided that the following conditions are met:
|
23
|
+
|
24
|
+
1. Redistributions of source code must retain the above copyright
|
25
|
+
notice, this list of conditions and the following disclaimer.
|
26
|
+
|
27
|
+
2. Redistributions in binary form must reproduce the above copyright
|
28
|
+
notice, this list of conditions and the following disclaimer in the
|
29
|
+
documentation and/or other materials provided with the distribution.
|
30
|
+
|
31
|
+
3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
|
32
|
+
and IDIAP Research Institute nor the names of its contributors may be
|
33
|
+
used to endorse or promote products derived from this software without
|
34
|
+
specific prior written permission.
|
35
|
+
|
36
|
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
37
|
+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
38
|
+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
39
|
+
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
|
40
|
+
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
41
|
+
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
42
|
+
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
43
|
+
INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
44
|
+
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
45
|
+
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
46
|
+
POSSIBILITY OF SUCH DAMAGE.
|
data/README.md
CHANGED
@@ -30,7 +30,9 @@ This library follows the [PyTorch API](https://pytorch.org/docs/stable/torch.htm
|
|
30
30
|
|
31
31
|
Many methods and options are missing at the moment. PRs welcome!
|
32
32
|
|
33
|
-
|
33
|
+
## Tutorial
|
34
|
+
|
35
|
+
Some examples below are from [Deep Learning with PyTorch: A 60 Minute Blitz](https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html)
|
34
36
|
|
35
37
|
### Tensors
|
36
38
|
|
@@ -145,7 +147,7 @@ Convert a Numo array to a tensor
|
|
145
147
|
|
146
148
|
```ruby
|
147
149
|
b = Numo::NArray.cast([1, 2, 3])
|
148
|
-
Torch.
|
150
|
+
Torch.from_numo(b)
|
149
151
|
```
|
150
152
|
|
151
153
|
### Autograd
|
@@ -180,10 +182,10 @@ Stop autograd from tracking history
|
|
180
182
|
|
181
183
|
```ruby
|
182
184
|
x.requires_grad # true
|
183
|
-
(x
|
185
|
+
(x**2).requires_grad # true
|
184
186
|
|
185
187
|
Torch.no_grad do
|
186
|
-
(x
|
188
|
+
(x**2).requires_grad # false
|
187
189
|
end
|
188
190
|
```
|
189
191
|
|
@@ -359,6 +361,13 @@ Here’s a list of functions to create tensors (descriptions from the [C++ docs]
|
|
359
361
|
Torch.zeros(3) # tensor([0, 0, 0])
|
360
362
|
```
|
361
363
|
|
364
|
+
## Examples
|
365
|
+
|
366
|
+
Here are a few full examples:
|
367
|
+
|
368
|
+
- [Image classification with MNIST](examples/mnist)
|
369
|
+
- [Collaborative filtering with MovieLens](examples/movielens)
|
370
|
+
|
362
371
|
## LibTorch Installation
|
363
372
|
|
364
373
|
[Download LibTorch](https://pytorch.org/). For Linux, use the `cxx11 ABI` version. Then run:
|
@@ -405,7 +414,7 @@ To get started with development:
|
|
405
414
|
git clone https://github.com/ankane/torch-rb.git
|
406
415
|
cd torch-rb
|
407
416
|
bundle install
|
408
|
-
bundle exec rake compile
|
417
|
+
bundle exec rake compile -- --with-torch-dir=/path/to/libtorch
|
409
418
|
bundle exec rake test
|
410
419
|
```
|
411
420
|
|
data/ext/torch/ext.cpp
CHANGED
@@ -88,7 +88,49 @@ IntArrayRef from_ruby<IntArrayRef>(Object x)
|
|
88
88
|
}
|
89
89
|
|
90
90
|
// for now
|
91
|
-
|
91
|
+
class Scalar {
|
92
|
+
torch::Scalar value;
|
93
|
+
public:
|
94
|
+
Scalar(Object o) {
|
95
|
+
// TODO cast based on Ruby type
|
96
|
+
if (o.rb_type() == T_FIXNUM) {
|
97
|
+
value = torch::Scalar(from_ruby<int64_t>(o));
|
98
|
+
} else {
|
99
|
+
value = torch::Scalar(from_ruby<float>(o));
|
100
|
+
}
|
101
|
+
}
|
102
|
+
operator torch::Scalar() {
|
103
|
+
return value;
|
104
|
+
}
|
105
|
+
};
|
106
|
+
|
107
|
+
template<>
|
108
|
+
inline
|
109
|
+
Scalar from_ruby<Scalar>(Object x)
|
110
|
+
{
|
111
|
+
return Scalar(x);
|
112
|
+
}
|
113
|
+
|
114
|
+
class TensorList {
|
115
|
+
std::vector<torch::Tensor> vec;
|
116
|
+
public:
|
117
|
+
TensorList(Object o) {
|
118
|
+
Array a = Array(o);
|
119
|
+
for (size_t i = 0; i < a.size(); i++) {
|
120
|
+
vec.push_back(from_ruby<torch::Tensor>(a[i]));
|
121
|
+
}
|
122
|
+
}
|
123
|
+
operator torch::TensorList() {
|
124
|
+
return torch::TensorList(vec);
|
125
|
+
}
|
126
|
+
};
|
127
|
+
|
128
|
+
template<>
|
129
|
+
inline
|
130
|
+
TensorList from_ruby<TensorList>(Object x)
|
131
|
+
{
|
132
|
+
return TensorList(x);
|
133
|
+
}
|
92
134
|
|
93
135
|
extern "C"
|
94
136
|
void Init_ext()
|
@@ -206,6 +248,11 @@ void Init_ext()
|
|
206
248
|
*[](torch::Tensor& input, int64_t dim, bool keepdim) {
|
207
249
|
return torch::argmax(input, dim, keepdim);
|
208
250
|
})
|
251
|
+
.define_singleton_method(
|
252
|
+
"_cat",
|
253
|
+
*[](TensorList tensors, int64_t dim) {
|
254
|
+
return torch::cat(tensors, dim);
|
255
|
+
})
|
209
256
|
.define_singleton_method(
|
210
257
|
"_norm",
|
211
258
|
*[](torch::Tensor& input) {
|
@@ -221,6 +268,17 @@ void Init_ext()
|
|
221
268
|
*[](torch::Tensor& input) {
|
222
269
|
return torch::max(input);
|
223
270
|
})
|
271
|
+
.define_singleton_method(
|
272
|
+
"_max_out",
|
273
|
+
*[](torch::Tensor &max, torch::Tensor &max_indices, const torch::Tensor &input, int64_t dim, bool keepdim) {
|
274
|
+
// TODO add return value
|
275
|
+
torch::_max_out(max, max_indices, input, dim, keepdim);
|
276
|
+
})
|
277
|
+
.define_singleton_method(
|
278
|
+
"_sqrt",
|
279
|
+
*[](torch::Tensor& input) {
|
280
|
+
return torch::sqrt(input);
|
281
|
+
})
|
224
282
|
.define_singleton_method(
|
225
283
|
"_exp",
|
226
284
|
*[](torch::Tensor& input) {
|
@@ -231,6 +289,11 @@ void Init_ext()
|
|
231
289
|
*[](torch::Tensor& input) {
|
232
290
|
return torch::log(input);
|
233
291
|
})
|
292
|
+
.define_singleton_method(
|
293
|
+
"_sign",
|
294
|
+
*[](torch::Tensor& input) {
|
295
|
+
return torch::sign(input);
|
296
|
+
})
|
234
297
|
.define_singleton_method(
|
235
298
|
"_unsqueeze",
|
236
299
|
*[](torch::Tensor& input, int64_t dim) {
|
@@ -251,6 +314,18 @@ void Init_ext()
|
|
251
314
|
*[](torch::Tensor& input, torch::Tensor& other) {
|
252
315
|
return torch::eq(input, other);
|
253
316
|
})
|
317
|
+
.define_singleton_method(
|
318
|
+
"_gt",
|
319
|
+
// TODO support tensors
|
320
|
+
*[](torch::Tensor& input, Scalar other) {
|
321
|
+
return torch::gt(input, other);
|
322
|
+
})
|
323
|
+
.define_singleton_method(
|
324
|
+
"_lt",
|
325
|
+
// TODO support tensors
|
326
|
+
*[](torch::Tensor& input, Scalar other) {
|
327
|
+
return torch::lt(input, other);
|
328
|
+
})
|
254
329
|
.define_singleton_method(
|
255
330
|
"_add",
|
256
331
|
*[](torch::Tensor& input, torch::Tensor& other) {
|
@@ -258,7 +333,7 @@ void Init_ext()
|
|
258
333
|
})
|
259
334
|
.define_singleton_method(
|
260
335
|
"_add_scalar",
|
261
|
-
*[](torch::Tensor& input,
|
336
|
+
*[](torch::Tensor& input, Scalar other) {
|
262
337
|
return torch::add(input, other);
|
263
338
|
})
|
264
339
|
.define_singleton_method(
|
@@ -273,7 +348,7 @@ void Init_ext()
|
|
273
348
|
})
|
274
349
|
.define_singleton_method(
|
275
350
|
"_sub_scalar",
|
276
|
-
*[](torch::Tensor& input,
|
351
|
+
*[](torch::Tensor& input, Scalar other) {
|
277
352
|
return torch::sub(input, other);
|
278
353
|
})
|
279
354
|
.define_singleton_method(
|
@@ -283,7 +358,7 @@ void Init_ext()
|
|
283
358
|
})
|
284
359
|
.define_singleton_method(
|
285
360
|
"_mul_scalar",
|
286
|
-
*[](torch::Tensor& input,
|
361
|
+
*[](torch::Tensor& input, Scalar other) {
|
287
362
|
return torch::mul(input, other);
|
288
363
|
})
|
289
364
|
.define_singleton_method(
|
@@ -293,7 +368,7 @@ void Init_ext()
|
|
293
368
|
})
|
294
369
|
.define_singleton_method(
|
295
370
|
"_div_scalar",
|
296
|
-
*[](torch::Tensor& input,
|
371
|
+
*[](torch::Tensor& input, Scalar other) {
|
297
372
|
return torch::div(input, other);
|
298
373
|
})
|
299
374
|
.define_singleton_method(
|
@@ -303,7 +378,7 @@ void Init_ext()
|
|
303
378
|
})
|
304
379
|
.define_singleton_method(
|
305
380
|
"_remainder_scalar",
|
306
|
-
*[](torch::Tensor& input,
|
381
|
+
*[](torch::Tensor& input, Scalar other) {
|
307
382
|
return torch::remainder(input, other);
|
308
383
|
})
|
309
384
|
.define_singleton_method(
|
@@ -311,6 +386,11 @@ void Init_ext()
|
|
311
386
|
*[](torch::Tensor& input, Scalar exponent) {
|
312
387
|
return torch::pow(input, exponent);
|
313
388
|
})
|
389
|
+
.define_singleton_method(
|
390
|
+
"_abs",
|
391
|
+
*[](torch::Tensor& input) {
|
392
|
+
return torch::abs(input);
|
393
|
+
})
|
314
394
|
.define_singleton_method(
|
315
395
|
"_neg",
|
316
396
|
*[](torch::Tensor& input) {
|
@@ -321,25 +401,20 @@ void Init_ext()
|
|
321
401
|
*[](torch::Tensor& input, IntArrayRef shape) {
|
322
402
|
return torch::reshape(input, shape);
|
323
403
|
})
|
404
|
+
.define_singleton_method(
|
405
|
+
"_flatten",
|
406
|
+
*[](torch::Tensor& input, int64_t start_dim, int64_t end_dim) {
|
407
|
+
return torch::flatten(input, start_dim, end_dim);
|
408
|
+
})
|
324
409
|
.define_singleton_method(
|
325
410
|
"relu",
|
326
411
|
*[](torch::Tensor& input) {
|
327
412
|
return torch::relu(input);
|
328
413
|
})
|
329
|
-
.define_singleton_method(
|
330
|
-
"prelu",
|
331
|
-
*[](torch::Tensor& input, torch::Tensor& weight) {
|
332
|
-
return torch::prelu(input, weight);
|
333
|
-
})
|
334
|
-
.define_singleton_method(
|
335
|
-
"leaky_relu",
|
336
|
-
*[](torch::Tensor& input, Scalar negative_slope = 0.01) {
|
337
|
-
return torch::leaky_relu(input, negative_slope);
|
338
|
-
})
|
339
414
|
.define_singleton_method(
|
340
415
|
"conv2d",
|
341
|
-
*[](torch::Tensor& input, torch::Tensor& weight, torch::Tensor& bias, IntArrayRef stride, IntArrayRef padding) {
|
342
|
-
return torch::conv2d(input, weight, bias, stride, padding);
|
416
|
+
*[](torch::Tensor& input, torch::Tensor& weight, torch::Tensor& bias, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, int64_t groups) {
|
417
|
+
return torch::conv2d(input, weight, bias, stride, padding, dilation, groups);
|
343
418
|
})
|
344
419
|
.define_singleton_method(
|
345
420
|
"linear",
|
@@ -356,6 +431,52 @@ void Init_ext()
|
|
356
431
|
*[](torch::Tensor& input, IntArrayRef kernel_size) {
|
357
432
|
return torch::avg_pool2d(input, kernel_size);
|
358
433
|
})
|
434
|
+
.define_singleton_method(
|
435
|
+
"_dropout",
|
436
|
+
*[](torch::Tensor& input, float p, bool train) {
|
437
|
+
return torch::dropout(input, p, train);
|
438
|
+
})
|
439
|
+
.define_singleton_method(
|
440
|
+
"_dropout!",
|
441
|
+
*[](torch::Tensor& input, float p, bool train) {
|
442
|
+
return torch::dropout_(input, p, train);
|
443
|
+
})
|
444
|
+
.define_singleton_method(
|
445
|
+
"_feature_dropout",
|
446
|
+
*[](torch::Tensor& input, float p, bool train) {
|
447
|
+
return torch::feature_dropout(input, p, train);
|
448
|
+
})
|
449
|
+
.define_singleton_method(
|
450
|
+
"_feature_dropout!",
|
451
|
+
*[](torch::Tensor& input, float p, bool train) {
|
452
|
+
return torch::feature_dropout_(input, p, train);
|
453
|
+
})
|
454
|
+
.define_singleton_method(
|
455
|
+
"_alpha_dropout",
|
456
|
+
*[](torch::Tensor& input, float p, bool train) {
|
457
|
+
return torch::alpha_dropout(input, p, train);
|
458
|
+
})
|
459
|
+
.define_singleton_method(
|
460
|
+
"_alpha_dropout!",
|
461
|
+
*[](torch::Tensor& input, float p, bool train) {
|
462
|
+
return torch::alpha_dropout_(input, p, train);
|
463
|
+
})
|
464
|
+
.define_singleton_method(
|
465
|
+
"_feature_alpha_dropout",
|
466
|
+
*[](torch::Tensor& input, float p, bool train) {
|
467
|
+
return torch::feature_alpha_dropout(input, p, train);
|
468
|
+
})
|
469
|
+
.define_singleton_method(
|
470
|
+
"_feature_alpha_dropout!",
|
471
|
+
*[](torch::Tensor& input, float p, bool train) {
|
472
|
+
return torch::feature_alpha_dropout_(input, p, train);
|
473
|
+
})
|
474
|
+
.define_singleton_method(
|
475
|
+
"_embedding",
|
476
|
+
// weight and indices are swapped from Python interface
|
477
|
+
*[](const torch::Tensor &indices, const torch::Tensor &weight, int64_t padding_idx, bool scale_grad_by_freq, bool sparse) {
|
478
|
+
return torch::embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse);
|
479
|
+
})
|
359
480
|
.define_singleton_method(
|
360
481
|
"mse_loss",
|
361
482
|
*[](torch::Tensor& input, torch::Tensor& target, std::string reduction) {
|
@@ -364,8 +485,16 @@ void Init_ext()
|
|
364
485
|
})
|
365
486
|
.define_singleton_method(
|
366
487
|
"nll_loss",
|
367
|
-
*[](torch::Tensor& input, torch::Tensor& target) {
|
368
|
-
|
488
|
+
*[](torch::Tensor& input, torch::Tensor& target, std::string reduction) {
|
489
|
+
auto red = reduction == "mean" ? Reduction::Mean : Reduction::Sum;
|
490
|
+
return torch::nll_loss(input, target, {}, red);
|
491
|
+
})
|
492
|
+
.define_singleton_method("numel", &torch::numel)
|
493
|
+
.define_singleton_method(
  "_from_blob",
  // Builds a tensor from the raw bytes of a Ruby string (e.g. Numo's
  // to_string) with the given sizes and options.
  *[](String s, IntArrayRef size, const torch::TensorOptions &options) {
    void *data = const_cast<char *>(s.c_str());
    // torch::from_blob does not take ownership of `data`. The buffer belongs
    // to the Ruby string, which the GC may free or move once this call
    // returns, so clone into storage owned by the tensor. The view is built
    // without requires_grad so the clone is a leaf; the flag is then restored.
    torch::Tensor tensor = torch::from_blob(data, size, options.requires_grad(false)).clone();
    tensor.set_requires_grad(options.requires_grad());
    return tensor;
  })
|
370
499
|
.define_singleton_method(
|
371
500
|
"_tensor",
|
@@ -387,9 +516,19 @@ void Init_ext()
|
|
387
516
|
.define_method("sparse?", &torch::Tensor::is_sparse)
|
388
517
|
.define_method("quantized?", &torch::Tensor::is_quantized)
|
389
518
|
.define_method("dim", &torch::Tensor::dim)
|
390
|
-
.define_method("numel", &torch::Tensor::numel)
|
391
519
|
.define_method("element_size", &torch::Tensor::element_size)
|
392
520
|
.define_method("requires_grad", &torch::Tensor::requires_grad)
|
521
|
+
.define_method("view_as", &torch::Tensor::view_as)
|
522
|
+
.define_method(
|
523
|
+
"addcmul!",
|
524
|
+
*[](torch::Tensor& self, Scalar value, const torch::Tensor & tensor1, const torch::Tensor & tensor2) {
|
525
|
+
return self.addcmul_(tensor1, tensor2, value);
|
526
|
+
})
|
527
|
+
.define_method(
|
528
|
+
"addcdiv!",
|
529
|
+
*[](torch::Tensor& self, Scalar value, const torch::Tensor & tensor1, const torch::Tensor & tensor2) {
|
530
|
+
return self.addcdiv_(tensor1, tensor2, value);
|
531
|
+
})
|
393
532
|
.define_method(
|
394
533
|
"zero!",
|
395
534
|
*[](torch::Tensor& self) {
|
@@ -460,24 +599,74 @@ void Init_ext()
|
|
460
599
|
return self.view(size);
|
461
600
|
})
|
462
601
|
.define_method(
|
463
|
-
"
|
602
|
+
"resize_as!",
|
464
603
|
*[](torch::Tensor& self, torch::Tensor& other) {
|
465
|
-
self.
|
604
|
+
return self.resize_as_(other);
|
605
|
+
})
|
606
|
+
.define_method(
|
607
|
+
"fill!",
|
608
|
+
*[](torch::Tensor& self, Scalar value) {
|
609
|
+
return self.fill_(value);
|
610
|
+
})
|
611
|
+
.define_method(
|
612
|
+
"_add!",
|
613
|
+
*[](torch::Tensor& self, torch::Tensor& other) {
|
614
|
+
return self.add_(other);
|
615
|
+
})
|
616
|
+
.define_method(
|
617
|
+
"_add_alpha!",
|
618
|
+
*[](torch::Tensor& self, torch::Tensor& other, Scalar alpha) {
|
619
|
+
return self.add_(other, alpha);
|
620
|
+
})
|
621
|
+
.define_method(
|
622
|
+
"_add_scalar!",
|
623
|
+
*[](torch::Tensor& self, Scalar other) {
|
624
|
+
return self.add_(other);
|
625
|
+
})
|
626
|
+
.define_method(
|
627
|
+
"normal!",
|
628
|
+
*[](torch::Tensor& self, double mean, double std) {
|
629
|
+
return self.normal_(mean, std);
|
466
630
|
})
|
467
631
|
.define_method(
|
468
632
|
"sub!",
|
469
633
|
*[](torch::Tensor& self, torch::Tensor& other) {
|
470
|
-
self.sub_(other);
|
634
|
+
return self.sub_(other);
|
471
635
|
})
|
472
636
|
.define_method(
|
473
|
-
"
|
637
|
+
"_mul!",
|
474
638
|
*[](torch::Tensor& self, torch::Tensor& other) {
|
475
|
-
self.mul_(other);
|
639
|
+
return self.mul_(other);
|
640
|
+
})
|
641
|
+
.define_method(
|
642
|
+
"_mul_scalar!",
|
643
|
+
*[](torch::Tensor& self, Scalar other) {
|
644
|
+
return self.mul_(other);
|
476
645
|
})
|
477
646
|
.define_method(
|
478
647
|
"div!",
|
479
648
|
*[](torch::Tensor& self, torch::Tensor& other) {
|
480
|
-
self.div_(other);
|
649
|
+
return self.div_(other);
|
650
|
+
})
|
651
|
+
.define_method(
|
652
|
+
"sqrt!",
|
653
|
+
*[](torch::Tensor& self) {
|
654
|
+
return self.sqrt_();
|
655
|
+
})
|
656
|
+
.define_method(
|
657
|
+
"unsqueeze!",
|
658
|
+
*[](torch::Tensor& self, int64_t dim) {
|
659
|
+
return self.unsqueeze_(dim);
|
660
|
+
})
|
661
|
+
.define_method(
|
662
|
+
"copy!",
|
663
|
+
*[](torch::Tensor& self, torch::Tensor& src) {
|
664
|
+
return self.copy_(src);
|
665
|
+
})
|
666
|
+
.define_method(
|
667
|
+
"clone",
|
668
|
+
*[](torch::Tensor& self) {
|
669
|
+
return self.clone();
|
481
670
|
})
|
482
671
|
.define_method(
|
483
672
|
"log_softmax",
|
@@ -532,8 +721,10 @@ void Init_ext()
|
|
532
721
|
a.push(data[i]);
|
533
722
|
}
|
534
723
|
} else if (dtype == torch::kBool) {
|
535
|
-
|
536
|
-
|
724
|
+
bool* data = self.data_ptr<bool>();
|
725
|
+
for (int i = 0; i < self.numel(); i++) {
|
726
|
+
a.push(data[i] ? True : False);
|
727
|
+
}
|
537
728
|
} else {
|
538
729
|
throw std::runtime_error("Unsupported type");
|
539
730
|
}
|
@@ -544,6 +735,11 @@ void Init_ext()
|
|
544
735
|
*[](torch::Tensor& self, int i) {
|
545
736
|
return self.size(i);
|
546
737
|
})
|
738
|
+
.define_method(
|
739
|
+
"_to",
|
740
|
+
*[](torch::Tensor& self, torch::Device device, int dtype, bool non_blocking, bool copy) {
|
741
|
+
return self.to(device, (torch::ScalarType) dtype, non_blocking, copy);
|
742
|
+
})
|
547
743
|
.define_singleton_method(
|
548
744
|
"_make_subclass",
|
549
745
|
*[](torch::Tensor& rd, bool requires_grad) {
|
@@ -597,12 +793,17 @@ void Init_ext()
|
|
597
793
|
|
598
794
|
Module rb_mInit = define_module_under(rb_mNN, "Init")
|
599
795
|
.define_singleton_method(
|
600
|
-
"
|
796
|
+
"kaiming_uniform!",
|
601
797
|
*[](torch::Tensor& input, double a) {
|
602
798
|
return torch::nn::init::kaiming_uniform_(input, a);
|
603
799
|
})
|
604
800
|
.define_singleton_method(
|
605
|
-
"
|
801
|
+
"normal!",
|
802
|
+
*[](torch::Tensor& input) {
|
803
|
+
return torch::nn::init::normal_(input);
|
804
|
+
})
|
805
|
+
.define_singleton_method(
|
806
|
+
"uniform!",
|
606
807
|
*[](torch::Tensor& input, double to, double from) {
|
607
808
|
return torch::nn::init::uniform_(input, to, from);
|
608
809
|
});
|
@@ -619,4 +820,20 @@ void Init_ext()
|
|
619
820
|
*[](torch::autograd::Variable& self) {
|
620
821
|
return self.grad().defined();
|
621
822
|
});
|
823
|
+
|
824
|
+
Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")
|
825
|
+
.define_constructor(Constructor<torch::Device, std::string>())
|
826
|
+
.define_method("index", &torch::Device::index)
|
827
|
+
.define_method("index?", &torch::Device::has_index)
|
828
|
+
.define_method(
|
829
|
+
"type",
|
830
|
+
*[](torch::Device& self) {
|
831
|
+
std::stringstream s;
|
832
|
+
s << self.type();
|
833
|
+
return s.str();
|
834
|
+
});
|
835
|
+
|
836
|
+
Module rb_mCUDA = define_module_under(rb_mTorch, "CUDA")
|
837
|
+
.define_singleton_method("available?", &torch::cuda::is_available)
|
838
|
+
.define_singleton_method("device_count", &torch::cuda::device_count);
|
622
839
|
}
|