torch-rb 0.1.0 → 0.1.5
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +40 -0
- data/LICENSE.txt +46 -22
- data/README.md +85 -19
- data/ext/torch/ext.cpp +274 -256
- data/ext/torch/extconf.rb +9 -0
- data/ext/torch/nn_functions.cpp +595 -0
- data/ext/torch/nn_functions.hpp +6 -0
- data/ext/torch/templates.hpp +250 -0
- data/ext/torch/tensor_functions.cpp +1860 -0
- data/ext/torch/tensor_functions.hpp +6 -0
- data/ext/torch/torch_functions.cpp +2875 -0
- data/ext/torch/torch_functions.hpp +6 -0
- data/lib/torch.rb +199 -84
- data/lib/torch/ext.bundle +0 -0
- data/lib/torch/inspector.rb +52 -25
- data/lib/torch/native/dispatcher.rb +48 -0
- data/lib/torch/native/function.rb +78 -0
- data/lib/torch/native/generator.rb +149 -0
- data/lib/torch/native/native_functions.yaml +6837 -0
- data/lib/torch/native/parser.rb +97 -0
- data/lib/torch/nn/alpha_dropout.rb +9 -0
- data/lib/torch/nn/avg_pool2d.rb +14 -0
- data/lib/torch/nn/avg_poolnd.rb +9 -0
- data/lib/torch/nn/bce_loss.rb +13 -0
- data/lib/torch/nn/bce_with_logits_loss.rb +15 -0
- data/lib/torch/nn/bilinear.rb +38 -0
- data/lib/torch/nn/conv2d.rb +14 -29
- data/lib/torch/nn/convnd.rb +41 -0
- data/lib/torch/nn/cosine_embedding_loss.rb +14 -0
- data/lib/torch/nn/cosine_similarity.rb +15 -0
- data/lib/torch/nn/cross_entropy_loss.rb +14 -0
- data/lib/torch/nn/ctc_loss.rb +15 -0
- data/lib/torch/nn/dropout.rb +9 -0
- data/lib/torch/nn/dropout2d.rb +9 -0
- data/lib/torch/nn/dropout3d.rb +9 -0
- data/lib/torch/nn/dropoutnd.rb +15 -0
- data/lib/torch/nn/embedding.rb +52 -0
- data/lib/torch/nn/embedding_bag.rb +34 -0
- data/lib/torch/nn/feature_alpha_dropout.rb +9 -0
- data/lib/torch/nn/functional.rb +194 -11
- data/lib/torch/nn/hinge_embedding_loss.rb +14 -0
- data/lib/torch/nn/identity.rb +14 -0
- data/lib/torch/nn/init.rb +58 -1
- data/lib/torch/nn/kl_div_loss.rb +13 -0
- data/lib/torch/nn/l1_loss.rb +13 -0
- data/lib/torch/nn/leaky_relu.rb +20 -0
- data/lib/torch/nn/linear.rb +12 -11
- data/lib/torch/nn/log_softmax.rb +14 -0
- data/lib/torch/nn/loss.rb +10 -0
- data/lib/torch/nn/margin_ranking_loss.rb +14 -0
- data/lib/torch/nn/max_pool2d.rb +9 -0
- data/lib/torch/nn/max_poolnd.rb +19 -0
- data/lib/torch/nn/module.rb +184 -19
- data/lib/torch/nn/mse_loss.rb +2 -2
- data/lib/torch/nn/multi_label_margin_loss.rb +13 -0
- data/lib/torch/nn/multi_label_soft_margin_loss.rb +13 -0
- data/lib/torch/nn/multi_margin_loss.rb +17 -0
- data/lib/torch/nn/nll_loss.rb +14 -0
- data/lib/torch/nn/pairwise_distance.rb +16 -0
- data/lib/torch/nn/parameter.rb +4 -0
- data/lib/torch/nn/poisson_nll_loss.rb +16 -0
- data/lib/torch/nn/prelu.rb +19 -0
- data/lib/torch/nn/relu.rb +8 -3
- data/lib/torch/nn/rnn.rb +22 -0
- data/lib/torch/nn/rnn_base.rb +154 -0
- data/lib/torch/nn/sequential.rb +1 -10
- data/lib/torch/nn/sigmoid.rb +9 -0
- data/lib/torch/nn/smooth_l1_loss.rb +13 -0
- data/lib/torch/nn/soft_margin_loss.rb +13 -0
- data/lib/torch/nn/softmax.rb +18 -0
- data/lib/torch/nn/softmax2d.rb +10 -0
- data/lib/torch/nn/softmin.rb +14 -0
- data/lib/torch/nn/softplus.rb +19 -0
- data/lib/torch/nn/triplet_margin_loss.rb +18 -0
- data/lib/torch/nn/weighted_loss.rb +10 -0
- data/lib/torch/optim/adadelta.rb +57 -0
- data/lib/torch/optim/adagrad.rb +71 -0
- data/lib/torch/optim/adam.rb +81 -0
- data/lib/torch/optim/adamax.rb +68 -0
- data/lib/torch/optim/adamw.rb +82 -0
- data/lib/torch/optim/asgd.rb +65 -0
- data/lib/torch/optim/lr_scheduler/lr_scheduler.rb +33 -0
- data/lib/torch/optim/lr_scheduler/step_lr.rb +17 -0
- data/lib/torch/optim/optimizer.rb +62 -0
- data/lib/torch/optim/rmsprop.rb +76 -0
- data/lib/torch/optim/rprop.rb +68 -0
- data/lib/torch/optim/sgd.rb +60 -0
- data/lib/torch/random.rb +10 -0
- data/lib/torch/tensor.rb +92 -21
- data/lib/torch/utils/data/data_loader.rb +15 -0
- data/lib/torch/utils/data/tensor_dataset.rb +8 -1
- data/lib/torch/version.rb +1 -1
- metadata +74 -3
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 6b47306ed525e1a20d25cb8324d4658f750c18afa5704c9b7bafc215d8f568c1
|
4
|
+
data.tar.gz: dad6ddf955b111989b061e5af146006a32c83dc1ea1ca5005a6b6e34bc9a4892
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 5d26e3642bf7cd921b9b570052df353d4c32b1bd955a6fbbf5f30249631fa4c0d4624f4fa91a1c06f61b3b0d6461cd117ab4df185cf013e915d2f63e52dbcf7c
|
7
|
+
data.tar.gz: 1728ce9b579f41f7a567e63d7256c82bb352840b67f16d88aac930a99e5abbf5a5f4061c5f9da16fb47d1664567e7956d276a8b2b44f13d2263032486afb53e8
|
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,43 @@
|
|
1
|
+
## 0.1.5 (2019-12-06)
|
2
|
+
|
3
|
+
- Added many more functions
|
4
|
+
- Added tensor classes - `FloatTensor`, `LongTensor`, etc
|
5
|
+
- Improved modules
|
6
|
+
|
7
|
+
## 0.1.4 (2019-12-01)
|
8
|
+
|
9
|
+
- Added distance functions
|
10
|
+
- Added more activations
|
11
|
+
- Added more linear layers
|
12
|
+
- Added more loss functions
|
13
|
+
- Added more init methods
|
14
|
+
- Added support for tensor assignment
|
15
|
+
|
16
|
+
## 0.1.3 (2019-11-30)
|
17
|
+
|
18
|
+
- Changed to BSD 3-Clause license to match PyTorch
|
19
|
+
- Added many optimizers
|
20
|
+
- Added `StepLR` learning rate scheduler
|
21
|
+
- Added dropout
|
22
|
+
- Added embedding
|
23
|
+
- Added support for `bool` type
|
24
|
+
- Improved performance of `from_numo`
|
25
|
+
|
26
|
+
## 0.1.2 (2019-11-27)
|
27
|
+
|
28
|
+
- Added SGD optimizer
|
29
|
+
- Added support for gradient to `backward` method
|
30
|
+
- Added `argmax`, `eq`, `leaky_relu`, `prelu`, and `reshape` methods
|
31
|
+
- Improved indexing
|
32
|
+
- Fixed `zero_grad`
|
33
|
+
- Fixed error with infinite values
|
34
|
+
|
35
|
+
## 0.1.1 (2019-11-26)
|
36
|
+
|
37
|
+
- Added support for `uint8` and `int8` types
|
38
|
+
- Fixed `undefined symbol` error on Linux
|
39
|
+
- Fixed C++ error messages
|
40
|
+
|
1
41
|
## 0.1.0 (2019-11-26)
|
2
42
|
|
3
43
|
- First release
|
data/LICENSE.txt
CHANGED
@@ -1,22 +1,46 @@
|
|
1
|
-
|
2
|
-
|
3
|
-
|
4
|
-
|
5
|
-
|
6
|
-
|
7
|
-
|
8
|
-
|
9
|
-
|
10
|
-
|
11
|
-
|
12
|
-
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
1
|
+
BSD 3-Clause License
|
2
|
+
|
3
|
+
From Torch-rb:
|
4
|
+
|
5
|
+
Copyright (c) 2019- Andrew Kane
|
6
|
+
|
7
|
+
From PyTorch (for ported code):
|
8
|
+
|
9
|
+
Copyright (c) 2016- Facebook, Inc (Adam Paszke)
|
10
|
+
Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
|
11
|
+
Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
|
12
|
+
Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
|
13
|
+
Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
|
14
|
+
Copyright (c) 2011-2013 NYU (Clement Farabet)
|
15
|
+
Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
|
16
|
+
Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
|
17
|
+
Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
|
18
|
+
|
19
|
+
All rights reserved.
|
20
|
+
|
21
|
+
Redistribution and use in source and binary forms, with or without
|
22
|
+
modification, are permitted provided that the following conditions are met:
|
23
|
+
|
24
|
+
1. Redistributions of source code must retain the above copyright
|
25
|
+
notice, this list of conditions and the following disclaimer.
|
26
|
+
|
27
|
+
2. Redistributions in binary form must reproduce the above copyright
|
28
|
+
notice, this list of conditions and the following disclaimer in the
|
29
|
+
documentation and/or other materials provided with the distribution.
|
30
|
+
|
31
|
+
3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
|
32
|
+
and IDIAP Research Institute nor the names of its contributors may be
|
33
|
+
used to endorse or promote products derived from this software without
|
34
|
+
specific prior written permission.
|
35
|
+
|
36
|
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
37
|
+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
38
|
+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
39
|
+
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
|
40
|
+
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
41
|
+
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
42
|
+
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
43
|
+
INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
44
|
+
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
45
|
+
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
46
|
+
POSSIBILITY OF SUCH DAMAGE.
|
data/README.md
CHANGED
@@ -2,14 +2,16 @@
|
|
2
2
|
|
3
3
|
:fire: Deep learning for Ruby, powered by [LibTorch](https://pytorch.org)
|
4
4
|
|
5
|
-
|
5
|
+
This gem is currently experimental. There may be breaking changes between each release. Please report any issues you experience.
|
6
|
+
|
7
|
+
[![Build Status](https://travis-ci.org/ankane/torch-rb.svg?branch=master)](https://travis-ci.org/ankane/torch-rb)
|
6
8
|
|
7
9
|
## Installation
|
8
10
|
|
9
11
|
First, [install LibTorch](#libtorch-installation). For Homebrew, use:
|
10
12
|
|
11
13
|
```sh
|
12
|
-
brew install
|
14
|
+
brew install libtorch
|
13
15
|
```
|
14
16
|
|
15
17
|
Add this line to your application’s Gemfile:
|
@@ -18,6 +20,8 @@ Add this line to your application’s Gemfile:
|
|
18
20
|
gem 'torch-rb'
|
19
21
|
```
|
20
22
|
|
23
|
+
It can take a few minutes to compile the extension.
|
24
|
+
|
21
25
|
## Getting Started
|
22
26
|
|
23
27
|
This library follows the [PyTorch API](https://pytorch.org/docs/stable/torch.html). There are a few changes to make it more Ruby-like:
|
@@ -26,9 +30,11 @@ This library follows the [PyTorch API](https://pytorch.org/docs/stable/torch.htm
|
|
26
30
|
- Methods that return booleans use `?` instead of `is_` (`tensor?` instead of `is_tensor`)
|
27
31
|
- Numo is used instead of NumPy (`x.numo` instead of `x.numpy()`)
|
28
32
|
|
29
|
-
|
33
|
+
Some methods and options are missing at the moment. PRs welcome!
|
30
34
|
|
31
|
-
|
35
|
+
## Tutorial
|
36
|
+
|
37
|
+
Some examples below are from [Deep Learning with PyTorch: A 60 Minute Blitz](https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html)
|
32
38
|
|
33
39
|
### Tensors
|
34
40
|
|
@@ -143,7 +149,7 @@ Convert a Numo array to a tensor
|
|
143
149
|
|
144
150
|
```ruby
|
145
151
|
b = Numo::NArray.cast([1, 2, 3])
|
146
|
-
Torch.
|
152
|
+
Torch.from_numo(b)
|
147
153
|
```
|
148
154
|
|
149
155
|
### Autograd
|
@@ -171,17 +177,17 @@ out.backward
|
|
171
177
|
Get gradients
|
172
178
|
|
173
179
|
```ruby
|
174
|
-
x.grad
|
180
|
+
x.grad # tensor([[4.5, 4.5], [4.5, 4.5]])
|
175
181
|
```
|
176
182
|
|
177
183
|
Stop autograd from tracking history
|
178
184
|
|
179
185
|
```ruby
|
180
186
|
x.requires_grad # true
|
181
|
-
(x
|
187
|
+
(x**2).requires_grad # true
|
182
188
|
|
183
189
|
Torch.no_grad do
|
184
|
-
(x
|
190
|
+
(x**2).requires_grad # false
|
185
191
|
end
|
186
192
|
```
|
187
193
|
|
@@ -221,7 +227,7 @@ class Net < Torch::NN::Module
|
|
221
227
|
end
|
222
228
|
```
|
223
229
|
|
224
|
-
|
230
|
+
Create an instance of it
|
225
231
|
|
226
232
|
```ruby
|
227
233
|
net = Net.new
|
@@ -229,6 +235,58 @@ input = Torch.randn(1, 1, 32, 32)
|
|
229
235
|
net.call(input)
|
230
236
|
```
|
231
237
|
|
238
|
+
Get trainable parameters
|
239
|
+
|
240
|
+
```ruby
|
241
|
+
net.parameters
|
242
|
+
```
|
243
|
+
|
244
|
+
Zero the gradient buffers and backprop with random gradients
|
245
|
+
|
246
|
+
```ruby
|
247
|
+
net.zero_grad
|
248
|
+
out.backward(Torch.randn(1, 10))
|
249
|
+
```
|
250
|
+
|
251
|
+
Define a loss function
|
252
|
+
|
253
|
+
```ruby
|
254
|
+
output = net.call(input)
|
255
|
+
target = Torch.randn(10)
|
256
|
+
target = target.view(1, -1)
|
257
|
+
criterion = Torch::NN::MSELoss.new
|
258
|
+
loss = criterion.call(output, target)
|
259
|
+
```
|
260
|
+
|
261
|
+
Backprop
|
262
|
+
|
263
|
+
```ruby
|
264
|
+
net.zero_grad
|
265
|
+
p net.conv1.bias.grad
|
266
|
+
loss.backward
|
267
|
+
p net.conv1.bias.grad
|
268
|
+
```
|
269
|
+
|
270
|
+
Update the weights
|
271
|
+
|
272
|
+
```ruby
|
273
|
+
learning_rate = 0.01
|
274
|
+
net.parameters.each do |f|
|
275
|
+
f.data.sub!(f.grad.data * learning_rate)
|
276
|
+
end
|
277
|
+
```
|
278
|
+
|
279
|
+
Use an optimizer
|
280
|
+
|
281
|
+
```ruby
|
282
|
+
optimizer = Torch::Optim::SGD.new(net.parameters, lr: 0.01)
|
283
|
+
optimizer.zero_grad
|
284
|
+
output = net.call(input)
|
285
|
+
loss = criterion.call(output, target)
|
286
|
+
loss.backward
|
287
|
+
optimizer.step
|
288
|
+
```
|
289
|
+
|
232
290
|
### Tensor Creation
|
233
291
|
|
234
292
|
Here’s a list of functions to create tensors (descriptions from the [C++ docs](https://pytorch.org/cppdocs/notes/tensor_creation.html)):
|
@@ -242,7 +300,7 @@ Here’s a list of functions to create tensors (descriptions from the [C++ docs]
|
|
242
300
|
- `empty` returns a tensor with uninitialized values
|
243
301
|
|
244
302
|
```ruby
|
245
|
-
Torch.empty(3)
|
303
|
+
Torch.empty(3) # tensor([7.0054e-45, 0.0000e+00, 0.0000e+00])
|
246
304
|
```
|
247
305
|
|
248
306
|
- `eye` returns an identity matrix
|
@@ -278,19 +336,19 @@ Here’s a list of functions to create tensors (descriptions from the [C++ docs]
|
|
278
336
|
- `rand` returns a tensor filled with values drawn from a uniform distribution on [0, 1)
|
279
337
|
|
280
338
|
```ruby
|
281
|
-
Torch.rand(3)
|
339
|
+
Torch.rand(3) # tensor([0.5444, 0.8799, 0.5571])
|
282
340
|
```
|
283
341
|
|
284
342
|
- `randint` returns a tensor with integers randomly drawn from an interval
|
285
343
|
|
286
344
|
```ruby
|
287
|
-
Torch.randint(1, 10, [3])
|
345
|
+
Torch.randint(1, 10, [3]) # tensor([7, 6, 4])
|
288
346
|
```
|
289
347
|
|
290
348
|
- `randn` returns a tensor filled with values drawn from a unit normal distribution
|
291
349
|
|
292
350
|
```ruby
|
293
|
-
Torch.randn(3)
|
351
|
+
Torch.randn(3) # tensor([-0.7147, 0.6614, 1.1453])
|
294
352
|
```
|
295
353
|
|
296
354
|
- `randperm` returns a tensor filled with a random permutation of integers in some interval
|
@@ -305,12 +363,20 @@ Here’s a list of functions to create tensors (descriptions from the [C++ docs]
|
|
305
363
|
Torch.zeros(3) # tensor([0, 0, 0])
|
306
364
|
```
|
307
365
|
|
366
|
+
## Examples
|
367
|
+
|
368
|
+
Here are a few full examples:
|
369
|
+
|
370
|
+
- [Image classification with MNIST](examples/mnist) ([日本語版](https://qiita.com/kojix2/items/c19c36dc1bf73ea93409))
|
371
|
+
- [Collaborative filtering with MovieLens](examples/movielens)
|
372
|
+
- [Word embeddings](examples/nlp)
|
373
|
+
|
308
374
|
## LibTorch Installation
|
309
375
|
|
310
|
-
[Download LibTorch](https://pytorch.org/)
|
376
|
+
[Download LibTorch](https://pytorch.org/). For Linux, use the `cxx11 ABI` version. Then run:
|
311
377
|
|
312
378
|
```sh
|
313
|
-
|
379
|
+
bundle config build.torch-rb --with-torch-dir=/path/to/libtorch
|
314
380
|
```
|
315
381
|
|
316
382
|
### Homebrew
|
@@ -318,10 +384,10 @@ gem install torch-rb -- --with-torch-dir=/path/to/libtorch
|
|
318
384
|
For Mac, you can use Homebrew.
|
319
385
|
|
320
386
|
```sh
|
321
|
-
brew install
|
387
|
+
brew install libtorch
|
322
388
|
```
|
323
389
|
|
324
|
-
Then install the gem (no need for
|
390
|
+
Then install the gem (no need for `bundle config`).
|
325
391
|
|
326
392
|
## rbenv
|
327
393
|
|
@@ -349,9 +415,9 @@ To get started with development:
|
|
349
415
|
|
350
416
|
```sh
|
351
417
|
git clone https://github.com/ankane/torch-rb.git
|
352
|
-
cd torch
|
418
|
+
cd torch-rb
|
353
419
|
bundle install
|
354
|
-
bundle exec rake compile
|
420
|
+
bundle exec rake compile -- --with-torch-dir=/path/to/libtorch
|
355
421
|
bundle exec rake test
|
356
422
|
```
|
357
423
|
|
data/ext/torch/ext.cpp
CHANGED
@@ -6,95 +6,36 @@
|
|
6
6
|
#include <rice/Class.hpp>
|
7
7
|
#include <rice/Constructor.hpp>
|
8
8
|
|
9
|
-
|
9
|
+
#include "templates.hpp"
|
10
10
|
|
11
|
-
|
12
|
-
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
}
|
17
|
-
|
18
|
-
template<>
|
19
|
-
inline
|
20
|
-
Object to_ruby<long long>(long long const & x)
|
21
|
-
{
|
22
|
-
return LL2NUM(x);
|
23
|
-
}
|
24
|
-
|
25
|
-
template<>
|
26
|
-
inline
|
27
|
-
unsigned long long from_ruby<unsigned long long>(Object x)
|
28
|
-
{
|
29
|
-
return NUM2ULL(x);
|
30
|
-
}
|
11
|
+
// generated with:
|
12
|
+
// rake generate:functions
|
13
|
+
#include "torch_functions.hpp"
|
14
|
+
#include "tensor_functions.hpp"
|
15
|
+
#include "nn_functions.hpp"
|
31
16
|
|
32
|
-
|
33
|
-
inline
|
34
|
-
Object to_ruby<unsigned long long>(unsigned long long const & x)
|
35
|
-
{
|
36
|
-
return ULL2NUM(x);
|
37
|
-
}
|
38
|
-
|
39
|
-
template<>
|
40
|
-
inline
|
41
|
-
short from_ruby<short>(Object x)
|
42
|
-
{
|
43
|
-
return NUM2SHORT(x);
|
44
|
-
}
|
45
|
-
|
46
|
-
template<>
|
47
|
-
inline
|
48
|
-
Object to_ruby<short>(short const & x)
|
49
|
-
{
|
50
|
-
return INT2NUM(x);
|
51
|
-
}
|
17
|
+
using namespace Rice;
|
52
18
|
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
return
|
19
|
+
Object tensor_array(std::tuple<torch::Tensor, torch::Tensor> x) {
|
20
|
+
Array a;
|
21
|
+
a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
|
22
|
+
a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
|
23
|
+
return Object(a);
|
58
24
|
}
|
59
25
|
|
60
|
-
|
61
|
-
|
62
|
-
Object to_ruby<unsigned short>(unsigned short const & x)
|
26
|
+
extern "C"
|
27
|
+
void Init_ext()
|
63
28
|
{
|
64
|
-
|
65
|
-
|
29
|
+
Module rb_mTorch = define_module("Torch");
|
30
|
+
add_torch_functions(rb_mTorch);
|
66
31
|
|
67
|
-
|
68
|
-
|
69
|
-
class IntArrayRef {
|
70
|
-
std::vector<int64_t> vec;
|
71
|
-
public:
|
72
|
-
IntArrayRef(Object o) {
|
73
|
-
Array a = Array(o);
|
74
|
-
for (size_t i = 0; i < a.size(); i++) {
|
75
|
-
vec.push_back(from_ruby<int64_t>(a[i]));
|
76
|
-
}
|
77
|
-
}
|
78
|
-
operator torch::IntArrayRef() {
|
79
|
-
return torch::IntArrayRef(vec);
|
80
|
-
}
|
81
|
-
};
|
32
|
+
Class rb_cTensor = define_class_under<torch::Tensor>(rb_mTorch, "Tensor");
|
33
|
+
add_tensor_functions(rb_cTensor);
|
82
34
|
|
83
|
-
|
84
|
-
|
85
|
-
IntArrayRef from_ruby<IntArrayRef>(Object x)
|
86
|
-
{
|
87
|
-
return IntArrayRef(x);
|
88
|
-
}
|
89
|
-
|
90
|
-
// for now
|
91
|
-
typedef float Scalar;
|
35
|
+
Module rb_mNN = define_module_under(rb_mTorch, "NN");
|
36
|
+
add_nn_functions(rb_mNN);
|
92
37
|
|
93
|
-
|
94
|
-
void Init_ext()
|
95
|
-
{
|
96
|
-
Module rb_mTorch = define_module("Torch")
|
97
|
-
.define_singleton_method(
|
38
|
+
rb_mTorch.define_singleton_method(
|
98
39
|
"grad_enabled?",
|
99
40
|
*[]() {
|
100
41
|
return torch::GradMode::is_enabled();
|
@@ -104,11 +45,6 @@ void Init_ext()
|
|
104
45
|
*[](bool enabled) {
|
105
46
|
torch::GradMode::set_enabled(enabled);
|
106
47
|
})
|
107
|
-
.define_singleton_method(
|
108
|
-
"floating_point?",
|
109
|
-
*[](torch::Tensor& input) {
|
110
|
-
return torch::is_floating_point(input);
|
111
|
-
})
|
112
48
|
.define_singleton_method(
|
113
49
|
"manual_seed",
|
114
50
|
*[](uint64_t seed) {
|
@@ -178,172 +114,117 @@ void Init_ext()
|
|
178
114
|
// begin operations
|
179
115
|
.define_singleton_method(
|
180
116
|
"_mean",
|
181
|
-
*[](
|
117
|
+
*[](Tensor& input) {
|
182
118
|
return torch::mean(input);
|
183
119
|
})
|
184
120
|
.define_singleton_method(
|
185
121
|
"_mean_dim",
|
186
|
-
*[](
|
122
|
+
*[](Tensor& input, int64_t dim, bool keepdim) {
|
187
123
|
return torch::mean(input, dim, keepdim);
|
188
124
|
})
|
189
125
|
.define_singleton_method(
|
190
126
|
"_sum",
|
191
|
-
*[](
|
127
|
+
*[](Tensor& input) {
|
192
128
|
return torch::sum(input);
|
193
129
|
})
|
194
130
|
.define_singleton_method(
|
195
131
|
"_sum_dim",
|
196
|
-
*[](
|
132
|
+
*[](Tensor& input, int64_t dim, bool keepdim) {
|
197
133
|
return torch::sum(input, dim, keepdim);
|
198
134
|
})
|
199
135
|
.define_singleton_method(
|
200
|
-
"
|
201
|
-
*[](
|
202
|
-
return torch::
|
203
|
-
})
|
204
|
-
.define_singleton_method(
|
205
|
-
"_min",
|
206
|
-
*[](torch::Tensor& input) {
|
207
|
-
return torch::min(input);
|
136
|
+
"_max_out",
|
137
|
+
*[](Tensor &max, Tensor &max_indices, const Tensor &input, int64_t dim, bool keepdim) {
|
138
|
+
return tensor_array(torch::_max_out(max, max_indices, input, dim, keepdim));
|
208
139
|
})
|
209
140
|
.define_singleton_method(
|
210
|
-
"
|
211
|
-
*[](
|
212
|
-
return torch::
|
141
|
+
"_topk",
|
142
|
+
*[](Tensor& input, int64_t k) {
|
143
|
+
return tensor_array(torch::topk(input, k));
|
213
144
|
})
|
214
145
|
.define_singleton_method(
|
215
|
-
"
|
216
|
-
*[](
|
217
|
-
return torch::
|
146
|
+
"_softmax",
|
147
|
+
*[](const Tensor &input, int64_t dim) {
|
148
|
+
return torch::softmax(input, dim);
|
218
149
|
})
|
219
150
|
.define_singleton_method(
|
220
|
-
"
|
221
|
-
*[](
|
222
|
-
return torch::
|
151
|
+
"_log_softmax",
|
152
|
+
*[](Tensor& input, int64_t dim) {
|
153
|
+
return torch::log_softmax(input, dim);
|
223
154
|
})
|
224
155
|
.define_singleton_method(
|
225
|
-
"
|
226
|
-
*[](
|
227
|
-
return torch::
|
228
|
-
})
|
229
|
-
.define_singleton_method(
|
230
|
-
"_dot",
|
231
|
-
*[](torch::Tensor& input, torch::Tensor& tensor) {
|
232
|
-
return torch::dot(input, tensor);
|
233
|
-
})
|
234
|
-
.define_singleton_method(
|
235
|
-
"_matmul",
|
236
|
-
*[](torch::Tensor& input, torch::Tensor& other) {
|
237
|
-
return torch::matmul(input, other);
|
238
|
-
})
|
239
|
-
.define_singleton_method(
|
240
|
-
"_add",
|
241
|
-
*[](torch::Tensor& input, torch::Tensor& other) {
|
242
|
-
return torch::add(input, other);
|
243
|
-
})
|
244
|
-
.define_singleton_method(
|
245
|
-
"_add_scalar",
|
246
|
-
*[](torch::Tensor& input, float other) {
|
247
|
-
return torch::add(input, other);
|
248
|
-
})
|
249
|
-
.define_singleton_method(
|
250
|
-
"_add_out",
|
251
|
-
*[](torch::Tensor& out, torch::Tensor& input, torch::Tensor& other) {
|
252
|
-
return torch::add_out(out, input, other);
|
253
|
-
})
|
254
|
-
.define_singleton_method(
|
255
|
-
"_sub",
|
256
|
-
*[](torch::Tensor& input, torch::Tensor& other) {
|
257
|
-
return torch::sub(input, other);
|
258
|
-
})
|
259
|
-
.define_singleton_method(
|
260
|
-
"_sub_scalar",
|
261
|
-
*[](torch::Tensor& input, float other) {
|
262
|
-
return torch::sub(input, other);
|
263
|
-
})
|
264
|
-
.define_singleton_method(
|
265
|
-
"_mul",
|
266
|
-
*[](torch::Tensor& input, torch::Tensor& other) {
|
267
|
-
return torch::mul(input, other);
|
268
|
-
})
|
269
|
-
.define_singleton_method(
|
270
|
-
"_mul_scalar",
|
271
|
-
*[](torch::Tensor& input, float other) {
|
272
|
-
return torch::mul(input, other);
|
273
|
-
})
|
274
|
-
.define_singleton_method(
|
275
|
-
"_div",
|
276
|
-
*[](torch::Tensor& input, torch::Tensor& other) {
|
277
|
-
return torch::div(input, other);
|
278
|
-
})
|
279
|
-
.define_singleton_method(
|
280
|
-
"_div_scalar",
|
281
|
-
*[](torch::Tensor& input, float other) {
|
282
|
-
return torch::div(input, other);
|
283
|
-
})
|
284
|
-
.define_singleton_method(
|
285
|
-
"_remainder",
|
286
|
-
*[](torch::Tensor& input, torch::Tensor& other) {
|
287
|
-
return torch::remainder(input, other);
|
288
|
-
})
|
289
|
-
.define_singleton_method(
|
290
|
-
"_remainder_scalar",
|
291
|
-
*[](torch::Tensor& input, float other) {
|
292
|
-
return torch::remainder(input, other);
|
156
|
+
"relu",
|
157
|
+
*[](Tensor& input) {
|
158
|
+
return torch::relu(input);
|
293
159
|
})
|
294
160
|
.define_singleton_method(
|
295
|
-
"
|
296
|
-
*[](torch::Tensor& input,
|
297
|
-
return torch::
|
161
|
+
"prelu",
|
162
|
+
*[](torch::Tensor& input, torch::Tensor& weight) {
|
163
|
+
return torch::prelu(input, weight);
|
298
164
|
})
|
299
165
|
.define_singleton_method(
|
300
|
-
"
|
301
|
-
*[](torch::Tensor& input) {
|
302
|
-
return torch::
|
166
|
+
"leaky_relu",
|
167
|
+
*[](torch::Tensor& input, Scalar negative_slope) {
|
168
|
+
return torch::leaky_relu(input, negative_slope);
|
303
169
|
})
|
304
170
|
.define_singleton_method(
|
305
|
-
"
|
306
|
-
*[](
|
307
|
-
return torch::
|
171
|
+
"conv2d",
|
172
|
+
*[](Tensor& input, Tensor& weight, Tensor& bias, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, int64_t groups) {
|
173
|
+
return torch::conv2d(input, weight, bias, stride, padding, dilation, groups);
|
308
174
|
})
|
175
|
+
// linear layers
|
309
176
|
.define_singleton_method(
|
310
|
-
"
|
311
|
-
*[](
|
312
|
-
return torch::
|
177
|
+
"bilinear",
|
178
|
+
*[](const Tensor &input1, const Tensor &input2, const Tensor &weight, const Tensor &bias) {
|
179
|
+
return torch::bilinear(input1, input2, weight, bias);
|
313
180
|
})
|
314
181
|
.define_singleton_method(
|
315
182
|
"linear",
|
316
|
-
*[](
|
183
|
+
*[](Tensor& input, Tensor& weight, Tensor& bias) {
|
317
184
|
return torch::linear(input, weight, bias);
|
318
185
|
})
|
186
|
+
// pooling layers
|
319
187
|
.define_singleton_method(
|
320
188
|
"max_pool2d",
|
321
|
-
*[](
|
189
|
+
*[](Tensor& input, IntArrayRef kernel_size) {
|
322
190
|
return torch::max_pool2d(input, kernel_size);
|
323
191
|
})
|
324
192
|
.define_singleton_method(
|
325
|
-
"
|
326
|
-
*[](
|
327
|
-
|
328
|
-
|
193
|
+
"avg_pool2d",
|
194
|
+
*[](Tensor& input, IntArrayRef kernel_size) {
|
195
|
+
return torch::avg_pool2d(input, kernel_size);
|
196
|
+
})
|
197
|
+
.define_singleton_method(
|
198
|
+
"_binary_cross_entropy_with_logits",
|
199
|
+
*[](const Tensor &input, const Tensor &target, OptionalTensor weight, OptionalTensor pos_weight, MyReduction reduction) {
|
200
|
+
return torch::binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction);
|
329
201
|
})
|
202
|
+
.define_singleton_method("numel", &torch::numel)
|
330
203
|
.define_singleton_method(
|
331
|
-
"
|
332
|
-
*[](
|
333
|
-
|
204
|
+
"_from_blob",
|
205
|
+
*[](String s, IntArrayRef size, const torch::TensorOptions &options) {
|
206
|
+
void *data = const_cast<char *>(s.c_str());
|
207
|
+
return torch::from_blob(data, size, options);
|
334
208
|
})
|
335
209
|
.define_singleton_method(
|
336
210
|
"_tensor",
|
337
211
|
*[](Object o, IntArrayRef size, const torch::TensorOptions &options) {
|
338
212
|
Array a = Array(o);
|
339
|
-
|
340
|
-
|
341
|
-
|
213
|
+
auto dtype = options.dtype();
|
214
|
+
torch::Tensor t;
|
215
|
+
if (dtype == torch::kBool) {
|
216
|
+
throw std::runtime_error("Cannot create bool from tensor method yet");
|
217
|
+
} else {
|
218
|
+
std::vector<float> vec;
|
219
|
+
for (size_t i = 0; i < a.size(); i++) {
|
220
|
+
vec.push_back(from_ruby<float>(a[i]));
|
221
|
+
}
|
222
|
+
t = torch::tensor(vec, options);
|
342
223
|
}
|
343
|
-
return
|
224
|
+
return t.reshape(size);
|
344
225
|
});
|
345
226
|
|
346
|
-
|
227
|
+
rb_cTensor
|
347
228
|
.define_method("cuda?", &torch::Tensor::is_cuda)
|
348
229
|
.define_method("distributed?", &torch::Tensor::is_distributed)
|
349
230
|
.define_method("complex?", &torch::Tensor::is_complex)
|
@@ -352,108 +233,162 @@ void Init_ext()
|
|
352
233
|
.define_method("sparse?", &torch::Tensor::is_sparse)
|
353
234
|
.define_method("quantized?", &torch::Tensor::is_quantized)
|
354
235
|
.define_method("dim", &torch::Tensor::dim)
|
355
|
-
.define_method("numel", &torch::Tensor::numel)
|
356
236
|
.define_method("element_size", &torch::Tensor::element_size)
|
357
237
|
.define_method("requires_grad", &torch::Tensor::requires_grad)
|
238
|
+
.define_method("view_as", &torch::Tensor::view_as)
|
239
|
+
.define_method(
|
240
|
+
"addcmul!",
|
241
|
+
*[](Tensor& self, Scalar value, const Tensor & tensor1, const Tensor & tensor2) {
|
242
|
+
return self.addcmul_(tensor1, tensor2, value);
|
243
|
+
})
|
244
|
+
.define_method(
|
245
|
+
"addcdiv!",
|
246
|
+
*[](Tensor& self, Scalar value, const Tensor & tensor1, const Tensor & tensor2) {
|
247
|
+
return self.addcdiv_(tensor1, tensor2, value);
|
248
|
+
})
|
358
249
|
.define_method(
|
359
250
|
"zero!",
|
360
|
-
*[](
|
251
|
+
*[](Tensor& self) {
|
361
252
|
return self.zero_();
|
362
253
|
})
|
363
254
|
.define_method(
|
364
|
-
"detach
|
365
|
-
*[](
|
366
|
-
return self.
|
255
|
+
"detach",
|
256
|
+
*[](Tensor& self) {
|
257
|
+
return self.detach();
|
367
258
|
})
|
368
259
|
.define_method(
|
369
|
-
"
|
370
|
-
*[](
|
371
|
-
return self
|
260
|
+
"detach!",
|
261
|
+
*[](Tensor& self) {
|
262
|
+
return self.detach_();
|
372
263
|
})
|
373
264
|
.define_method(
|
374
265
|
"_requires_grad!",
|
375
|
-
*[](
|
266
|
+
*[](Tensor& self, bool requires_grad) {
|
376
267
|
return self.set_requires_grad(requires_grad);
|
377
268
|
})
|
378
269
|
.define_method(
|
379
|
-
"
|
380
|
-
*[](
|
381
|
-
return self.backward();
|
270
|
+
"_backward",
|
271
|
+
*[](Tensor& self, Object gradient) {
|
272
|
+
return gradient.is_nil() ? self.backward() : self.backward(from_ruby<torch::Tensor>(gradient));
|
382
273
|
})
|
383
274
|
.define_method(
|
384
275
|
"grad",
|
385
|
-
*[](
|
276
|
+
*[](Tensor& self) {
|
386
277
|
return self.grad();
|
387
278
|
})
|
388
279
|
.define_method(
|
389
280
|
"_dtype",
|
390
|
-
*[](
|
281
|
+
*[](Tensor& self) {
|
391
282
|
return (int) at::typeMetaToScalarType(self.dtype());
|
392
283
|
})
|
284
|
+
.define_method(
|
285
|
+
"_type",
|
286
|
+
*[](Tensor& self, int dtype) {
|
287
|
+
return self.toType((torch::ScalarType) dtype);
|
288
|
+
})
|
393
289
|
.define_method(
|
394
290
|
"_layout",
|
395
|
-
*[](
|
291
|
+
*[](Tensor& self) {
|
396
292
|
std::stringstream s;
|
397
293
|
s << self.layout();
|
398
294
|
return s.str();
|
399
295
|
})
|
400
296
|
.define_method(
|
401
297
|
"device",
|
402
|
-
*[](
|
298
|
+
*[](Tensor& self) {
|
403
299
|
std::stringstream s;
|
404
300
|
s << self.device();
|
405
301
|
return s.str();
|
406
302
|
})
|
407
303
|
.define_method(
|
408
|
-
"
|
409
|
-
*[](
|
410
|
-
return self.
|
304
|
+
"resize_as!",
|
305
|
+
*[](Tensor& self, Tensor& other) {
|
306
|
+
return self.resize_as_(other);
|
411
307
|
})
|
412
308
|
.define_method(
|
413
|
-
"
|
414
|
-
*[](
|
415
|
-
self.
|
309
|
+
"fill!",
|
310
|
+
*[](Tensor& self, Scalar value) {
|
311
|
+
return self.fill_(value);
|
416
312
|
})
|
417
313
|
.define_method(
|
418
|
-
"
|
419
|
-
*[](
|
420
|
-
self.
|
314
|
+
"relu!",
|
315
|
+
*[](Tensor& self) {
|
316
|
+
return self.relu_();
|
317
|
+
})
|
318
|
+
.define_method(
|
319
|
+
"normal!",
|
320
|
+
*[](Tensor& self, double mean, double std) {
|
321
|
+
return self.normal_(mean, std);
|
322
|
+
})
|
323
|
+
.define_method(
|
324
|
+
"random!",
|
325
|
+
*[](Tensor& self, int64_t to) {
|
326
|
+
return self.random_(to);
|
421
327
|
})
|
422
328
|
.define_method(
|
423
|
-
"
|
424
|
-
*[](
|
425
|
-
self.
|
329
|
+
"sub!",
|
330
|
+
*[](Tensor& self, Tensor& other) {
|
331
|
+
return self.sub_(other);
|
426
332
|
})
|
427
333
|
.define_method(
|
428
334
|
"div!",
|
429
|
-
*[](
|
430
|
-
self.div_(other);
|
335
|
+
*[](Tensor& self, Tensor& other) {
|
336
|
+
return self.div_(other);
|
431
337
|
})
|
432
338
|
.define_method(
|
433
|
-
"
|
434
|
-
*[](
|
435
|
-
return self.
|
339
|
+
"sqrt!",
|
340
|
+
*[](Tensor& self) {
|
341
|
+
return self.sqrt_();
|
436
342
|
})
|
437
343
|
.define_method(
|
438
|
-
"
|
439
|
-
*[](
|
344
|
+
"unsqueeze!",
|
345
|
+
*[](Tensor& self, int64_t dim) {
|
346
|
+
return self.unsqueeze_(dim);
|
347
|
+
})
|
348
|
+
.define_method(
|
349
|
+
"copy!",
|
350
|
+
*[](Tensor& self, Tensor& src) {
|
351
|
+
return self.copy_(src);
|
352
|
+
})
|
353
|
+
.define_method(
|
354
|
+
"clone",
|
355
|
+
*[](Tensor& self) {
|
356
|
+
return self.clone();
|
357
|
+
})
|
358
|
+
.define_method(
|
359
|
+
"data",
|
360
|
+
*[](Tensor& self) {
|
361
|
+
return self.data();
|
362
|
+
})
|
363
|
+
.define_method(
|
364
|
+
"_flat_data",
|
365
|
+
*[](Tensor& self) {
|
440
366
|
Array a;
|
441
367
|
auto dtype = self.dtype();
|
442
368
|
|
443
369
|
// TODO DRY if someone knows C++
|
444
|
-
|
445
|
-
|
446
|
-
|
370
|
+
if (dtype == torch::kByte) {
|
371
|
+
uint8_t* data = self.data_ptr<uint8_t>();
|
372
|
+
for (int i = 0; i < self.numel(); i++) {
|
373
|
+
a.push(data[i]);
|
374
|
+
}
|
375
|
+
} else if (dtype == torch::kChar) {
|
376
|
+
int8_t* data = self.data_ptr<int8_t>();
|
377
|
+
for (int i = 0; i < self.numel(); i++) {
|
378
|
+
a.push(to_ruby<int>(data[i]));
|
379
|
+
}
|
380
|
+
} else if (dtype == torch::kShort) {
|
381
|
+
int16_t* data = self.data_ptr<int16_t>();
|
447
382
|
for (int i = 0; i < self.numel(); i++) {
|
448
383
|
a.push(data[i]);
|
449
384
|
}
|
450
385
|
} else if (dtype == torch::kInt) {
|
451
|
-
|
386
|
+
int32_t* data = self.data_ptr<int32_t>();
|
452
387
|
for (int i = 0; i < self.numel(); i++) {
|
453
388
|
a.push(data[i]);
|
454
389
|
}
|
455
390
|
} else if (dtype == torch::kLong) {
|
456
|
-
|
391
|
+
int64_t* data = self.data_ptr<int64_t>();
|
457
392
|
for (int i = 0; i < self.numel(); i++) {
|
458
393
|
a.push(data[i]);
|
459
394
|
}
|
@@ -467,19 +402,24 @@ void Init_ext()
|
|
467
402
|
for (int i = 0; i < self.numel(); i++) {
|
468
403
|
a.push(data[i]);
|
469
404
|
}
|
405
|
+
} else if (dtype == torch::kBool) {
|
406
|
+
bool* data = self.data_ptr<bool>();
|
407
|
+
for (int i = 0; i < self.numel(); i++) {
|
408
|
+
a.push(data[i] ? True : False);
|
409
|
+
}
|
470
410
|
} else {
|
471
|
-
throw "Unsupported type";
|
411
|
+
throw std::runtime_error("Unsupported type");
|
472
412
|
}
|
473
413
|
return a;
|
474
414
|
})
|
475
415
|
.define_method(
|
476
|
-
"
|
477
|
-
*[](
|
478
|
-
return self.
|
416
|
+
"_to",
|
417
|
+
*[](Tensor& self, torch::Device device, int dtype, bool non_blocking, bool copy) {
|
418
|
+
return self.to(device, (torch::ScalarType) dtype, non_blocking, copy);
|
479
419
|
})
|
480
420
|
.define_singleton_method(
|
481
421
|
"_make_subclass",
|
482
|
-
*[](
|
422
|
+
*[](Tensor& rd, bool requires_grad) {
|
483
423
|
auto data = torch::autograd::as_variable_ref(rd).detach();
|
484
424
|
data.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(true);
|
485
425
|
auto var = data.set_requires_grad(requires_grad);
|
@@ -499,8 +439,11 @@ void Init_ext()
|
|
499
439
|
torch::Layout l;
|
500
440
|
if (layout == "strided") {
|
501
441
|
l = torch::kStrided;
|
442
|
+
} else if (layout == "sparse") {
|
443
|
+
l = torch::kSparse;
|
444
|
+
throw std::runtime_error("Sparse layout not supported yet");
|
502
445
|
} else {
|
503
|
-
throw "Unsupported layout";
|
446
|
+
throw std::runtime_error("Unsupported layout: " + layout);
|
504
447
|
}
|
505
448
|
return self.layout(l);
|
506
449
|
})
|
@@ -513,7 +456,7 @@ void Init_ext()
|
|
513
456
|
} else if (device == "cuda") {
|
514
457
|
d = torch::kCUDA;
|
515
458
|
} else {
|
516
|
-
throw "Unsupported device";
|
459
|
+
throw std::runtime_error("Unsupported device: " + device);
|
517
460
|
}
|
518
461
|
return self.device(d);
|
519
462
|
})
|
@@ -523,24 +466,99 @@ void Init_ext()
|
|
523
466
|
return self.requires_grad(requires_grad);
|
524
467
|
});
|
525
468
|
|
526
|
-
Module rb_mNN = define_module_under(rb_mTorch, "NN");
|
527
|
-
|
528
469
|
Module rb_mInit = define_module_under(rb_mNN, "Init")
|
529
470
|
.define_singleton_method(
|
530
|
-
"
|
531
|
-
*[](
|
532
|
-
return torch::nn::init::
|
471
|
+
"_calculate_gain",
|
472
|
+
*[](NonlinearityType nonlinearity, double param) {
|
473
|
+
return torch::nn::init::calculate_gain(nonlinearity, param);
|
533
474
|
})
|
534
475
|
.define_singleton_method(
|
535
|
-
"
|
536
|
-
*[](
|
537
|
-
return torch::nn::init::uniform_(
|
476
|
+
"_uniform!",
|
477
|
+
*[](Tensor tensor, double low, double high) {
|
478
|
+
return torch::nn::init::uniform_(tensor, low, high);
|
479
|
+
})
|
480
|
+
.define_singleton_method(
|
481
|
+
"_normal!",
|
482
|
+
*[](Tensor tensor, double mean, double std) {
|
483
|
+
return torch::nn::init::normal_(tensor, mean, std);
|
484
|
+
})
|
485
|
+
.define_singleton_method(
|
486
|
+
"_constant!",
|
487
|
+
*[](Tensor tensor, Scalar value) {
|
488
|
+
return torch::nn::init::constant_(tensor, value);
|
489
|
+
})
|
490
|
+
.define_singleton_method(
|
491
|
+
"_ones!",
|
492
|
+
*[](Tensor tensor) {
|
493
|
+
return torch::nn::init::ones_(tensor);
|
494
|
+
})
|
495
|
+
.define_singleton_method(
|
496
|
+
"_zeros!",
|
497
|
+
*[](Tensor tensor) {
|
498
|
+
return torch::nn::init::zeros_(tensor);
|
499
|
+
})
|
500
|
+
.define_singleton_method(
|
501
|
+
"_eye!",
|
502
|
+
*[](Tensor tensor) {
|
503
|
+
return torch::nn::init::eye_(tensor);
|
504
|
+
})
|
505
|
+
.define_singleton_method(
|
506
|
+
"_dirac!",
|
507
|
+
*[](Tensor tensor) {
|
508
|
+
return torch::nn::init::dirac_(tensor);
|
509
|
+
})
|
510
|
+
.define_singleton_method(
|
511
|
+
"_xavier_uniform!",
|
512
|
+
*[](Tensor tensor, double gain) {
|
513
|
+
return torch::nn::init::xavier_uniform_(tensor, gain);
|
514
|
+
})
|
515
|
+
.define_singleton_method(
|
516
|
+
"_xavier_normal!",
|
517
|
+
*[](Tensor tensor, double gain) {
|
518
|
+
return torch::nn::init::xavier_normal_(tensor, gain);
|
519
|
+
})
|
520
|
+
.define_singleton_method(
|
521
|
+
"_kaiming_uniform!",
|
522
|
+
*[](Tensor tensor, double a, FanModeType mode, NonlinearityType nonlinearity) {
|
523
|
+
return torch::nn::init::kaiming_uniform_(tensor, a, mode, nonlinearity);
|
524
|
+
})
|
525
|
+
.define_singleton_method(
|
526
|
+
"_kaiming_normal!",
|
527
|
+
*[](Tensor tensor, double a, FanModeType mode, NonlinearityType nonlinearity) {
|
528
|
+
return torch::nn::init::kaiming_normal_(tensor, a, mode, nonlinearity);
|
529
|
+
})
|
530
|
+
.define_singleton_method(
|
531
|
+
"_orthogonal!",
|
532
|
+
*[](Tensor tensor, double gain) {
|
533
|
+
return torch::nn::init::orthogonal_(tensor, gain);
|
534
|
+
})
|
535
|
+
.define_singleton_method(
|
536
|
+
"_sparse!",
|
537
|
+
*[](Tensor tensor, double sparsity, double std) {
|
538
|
+
return torch::nn::init::sparse_(tensor, sparsity, std);
|
538
539
|
});
|
539
540
|
|
540
541
|
Class rb_cParameter = define_class_under<torch::autograd::Variable, torch::Tensor>(rb_mNN, "Parameter")
|
541
542
|
.define_method(
|
542
543
|
"grad",
|
543
544
|
*[](torch::autograd::Variable& self) {
|
544
|
-
|
545
|
+
auto grad = self.grad();
|
546
|
+
return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
|
547
|
+
});
|
548
|
+
|
549
|
+
Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")
|
550
|
+
.define_constructor(Constructor<torch::Device, std::string>())
|
551
|
+
.define_method("index", &torch::Device::index)
|
552
|
+
.define_method("index?", &torch::Device::has_index)
|
553
|
+
.define_method(
|
554
|
+
"type",
|
555
|
+
*[](torch::Device& self) {
|
556
|
+
std::stringstream s;
|
557
|
+
s << self.type();
|
558
|
+
return s.str();
|
545
559
|
});
|
560
|
+
|
561
|
+
Module rb_mCUDA = define_module_under(rb_mTorch, "CUDA")
|
562
|
+
.define_singleton_method("available?", &torch::cuda::is_available)
|
563
|
+
.define_singleton_method("device_count", &torch::cuda::device_count);
|
546
564
|
}
|