torch-rb 0.1.0 → 0.1.5
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +40 -0
- data/LICENSE.txt +46 -22
- data/README.md +85 -19
- data/ext/torch/ext.cpp +274 -256
- data/ext/torch/extconf.rb +9 -0
- data/ext/torch/nn_functions.cpp +595 -0
- data/ext/torch/nn_functions.hpp +6 -0
- data/ext/torch/templates.hpp +250 -0
- data/ext/torch/tensor_functions.cpp +1860 -0
- data/ext/torch/tensor_functions.hpp +6 -0
- data/ext/torch/torch_functions.cpp +2875 -0
- data/ext/torch/torch_functions.hpp +6 -0
- data/lib/torch.rb +199 -84
- data/lib/torch/ext.bundle +0 -0
- data/lib/torch/inspector.rb +52 -25
- data/lib/torch/native/dispatcher.rb +48 -0
- data/lib/torch/native/function.rb +78 -0
- data/lib/torch/native/generator.rb +149 -0
- data/lib/torch/native/native_functions.yaml +6837 -0
- data/lib/torch/native/parser.rb +97 -0
- data/lib/torch/nn/alpha_dropout.rb +9 -0
- data/lib/torch/nn/avg_pool2d.rb +14 -0
- data/lib/torch/nn/avg_poolnd.rb +9 -0
- data/lib/torch/nn/bce_loss.rb +13 -0
- data/lib/torch/nn/bce_with_logits_loss.rb +15 -0
- data/lib/torch/nn/bilinear.rb +38 -0
- data/lib/torch/nn/conv2d.rb +14 -29
- data/lib/torch/nn/convnd.rb +41 -0
- data/lib/torch/nn/cosine_embedding_loss.rb +14 -0
- data/lib/torch/nn/cosine_similarity.rb +15 -0
- data/lib/torch/nn/cross_entropy_loss.rb +14 -0
- data/lib/torch/nn/ctc_loss.rb +15 -0
- data/lib/torch/nn/dropout.rb +9 -0
- data/lib/torch/nn/dropout2d.rb +9 -0
- data/lib/torch/nn/dropout3d.rb +9 -0
- data/lib/torch/nn/dropoutnd.rb +15 -0
- data/lib/torch/nn/embedding.rb +52 -0
- data/lib/torch/nn/embedding_bag.rb +34 -0
- data/lib/torch/nn/feature_alpha_dropout.rb +9 -0
- data/lib/torch/nn/functional.rb +194 -11
- data/lib/torch/nn/hinge_embedding_loss.rb +14 -0
- data/lib/torch/nn/identity.rb +14 -0
- data/lib/torch/nn/init.rb +58 -1
- data/lib/torch/nn/kl_div_loss.rb +13 -0
- data/lib/torch/nn/l1_loss.rb +13 -0
- data/lib/torch/nn/leaky_relu.rb +20 -0
- data/lib/torch/nn/linear.rb +12 -11
- data/lib/torch/nn/log_softmax.rb +14 -0
- data/lib/torch/nn/loss.rb +10 -0
- data/lib/torch/nn/margin_ranking_loss.rb +14 -0
- data/lib/torch/nn/max_pool2d.rb +9 -0
- data/lib/torch/nn/max_poolnd.rb +19 -0
- data/lib/torch/nn/module.rb +184 -19
- data/lib/torch/nn/mse_loss.rb +2 -2
- data/lib/torch/nn/multi_label_margin_loss.rb +13 -0
- data/lib/torch/nn/multi_label_soft_margin_loss.rb +13 -0
- data/lib/torch/nn/multi_margin_loss.rb +17 -0
- data/lib/torch/nn/nll_loss.rb +14 -0
- data/lib/torch/nn/pairwise_distance.rb +16 -0
- data/lib/torch/nn/parameter.rb +4 -0
- data/lib/torch/nn/poisson_nll_loss.rb +16 -0
- data/lib/torch/nn/prelu.rb +19 -0
- data/lib/torch/nn/relu.rb +8 -3
- data/lib/torch/nn/rnn.rb +22 -0
- data/lib/torch/nn/rnn_base.rb +154 -0
- data/lib/torch/nn/sequential.rb +1 -10
- data/lib/torch/nn/sigmoid.rb +9 -0
- data/lib/torch/nn/smooth_l1_loss.rb +13 -0
- data/lib/torch/nn/soft_margin_loss.rb +13 -0
- data/lib/torch/nn/softmax.rb +18 -0
- data/lib/torch/nn/softmax2d.rb +10 -0
- data/lib/torch/nn/softmin.rb +14 -0
- data/lib/torch/nn/softplus.rb +19 -0
- data/lib/torch/nn/triplet_margin_loss.rb +18 -0
- data/lib/torch/nn/weighted_loss.rb +10 -0
- data/lib/torch/optim/adadelta.rb +57 -0
- data/lib/torch/optim/adagrad.rb +71 -0
- data/lib/torch/optim/adam.rb +81 -0
- data/lib/torch/optim/adamax.rb +68 -0
- data/lib/torch/optim/adamw.rb +82 -0
- data/lib/torch/optim/asgd.rb +65 -0
- data/lib/torch/optim/lr_scheduler/lr_scheduler.rb +33 -0
- data/lib/torch/optim/lr_scheduler/step_lr.rb +17 -0
- data/lib/torch/optim/optimizer.rb +62 -0
- data/lib/torch/optim/rmsprop.rb +76 -0
- data/lib/torch/optim/rprop.rb +68 -0
- data/lib/torch/optim/sgd.rb +60 -0
- data/lib/torch/random.rb +10 -0
- data/lib/torch/tensor.rb +92 -21
- data/lib/torch/utils/data/data_loader.rb +15 -0
- data/lib/torch/utils/data/tensor_dataset.rb +8 -1
- data/lib/torch/version.rb +1 -1
- metadata +74 -3
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 6b47306ed525e1a20d25cb8324d4658f750c18afa5704c9b7bafc215d8f568c1
|
4
|
+
data.tar.gz: dad6ddf955b111989b061e5af146006a32c83dc1ea1ca5005a6b6e34bc9a4892
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 5d26e3642bf7cd921b9b570052df353d4c32b1bd955a6fbbf5f30249631fa4c0d4624f4fa91a1c06f61b3b0d6461cd117ab4df185cf013e915d2f63e52dbcf7c
|
7
|
+
data.tar.gz: 1728ce9b579f41f7a567e63d7256c82bb352840b67f16d88aac930a99e5abbf5a5f4061c5f9da16fb47d1664567e7956d276a8b2b44f13d2263032486afb53e8
|
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,43 @@
|
|
1
|
+
## 0.1.5 (2019-12-06)
|
2
|
+
|
3
|
+
- Added many more functions
|
4
|
+
- Added tensor classes (`FloatTensor`, `LongTensor`, etc.)
|
5
|
+
- Improved modules
|
6
|
+
|
7
|
+
## 0.1.4 (2019-12-01)
|
8
|
+
|
9
|
+
- Added distance functions
|
10
|
+
- Added more activations
|
11
|
+
- Added more linear layers
|
12
|
+
- Added more loss functions
|
13
|
+
- Added more init methods
|
14
|
+
- Added support for tensor assignment
|
15
|
+
|
16
|
+
## 0.1.3 (2019-11-30)
|
17
|
+
|
18
|
+
- Changed to BSD 3-Clause license to match PyTorch
|
19
|
+
- Added many optimizers
|
20
|
+
- Added `StepLR` learning rate scheduler
|
21
|
+
- Added dropout
|
22
|
+
- Added embedding
|
23
|
+
- Added support for `bool` type
|
24
|
+
- Improved performance of `from_numo`
|
25
|
+
|
26
|
+
## 0.1.2 (2019-11-27)
|
27
|
+
|
28
|
+
- Added SGD optimizer
|
29
|
+
- Added support for passing a gradient to the `backward` method
|
30
|
+
- Added `argmax`, `eq`, `leaky_relu`, `prelu`, and `reshape` methods
|
31
|
+
- Improved indexing
|
32
|
+
- Fixed `zero_grad`
|
33
|
+
- Fixed error with infinite values
|
34
|
+
|
35
|
+
## 0.1.1 (2019-11-26)
|
36
|
+
|
37
|
+
- Added support for `uint8` and `int8` types
|
38
|
+
- Fixed `undefined symbol` error on Linux
|
39
|
+
- Fixed C++ error messages
|
40
|
+
|
1
41
|
## 0.1.0 (2019-11-26)
|
2
42
|
|
3
43
|
- First release
|
data/LICENSE.txt
CHANGED
@@ -1,22 +1,46 @@
|
|
1
|
-
|
2
|
-
|
3
|
-
|
4
|
-
|
5
|
-
|
6
|
-
|
7
|
-
|
8
|
-
|
9
|
-
|
10
|
-
|
11
|
-
|
12
|
-
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
1
|
+
BSD 3-Clause License
|
2
|
+
|
3
|
+
From Torch-rb:
|
4
|
+
|
5
|
+
Copyright (c) 2019- Andrew Kane
|
6
|
+
|
7
|
+
From PyTorch (for ported code):
|
8
|
+
|
9
|
+
Copyright (c) 2016- Facebook, Inc (Adam Paszke)
|
10
|
+
Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
|
11
|
+
Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
|
12
|
+
Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
|
13
|
+
Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
|
14
|
+
Copyright (c) 2011-2013 NYU (Clement Farabet)
|
15
|
+
Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
|
16
|
+
Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
|
17
|
+
Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
|
18
|
+
|
19
|
+
All rights reserved.
|
20
|
+
|
21
|
+
Redistribution and use in source and binary forms, with or without
|
22
|
+
modification, are permitted provided that the following conditions are met:
|
23
|
+
|
24
|
+
1. Redistributions of source code must retain the above copyright
|
25
|
+
notice, this list of conditions and the following disclaimer.
|
26
|
+
|
27
|
+
2. Redistributions in binary form must reproduce the above copyright
|
28
|
+
notice, this list of conditions and the following disclaimer in the
|
29
|
+
documentation and/or other materials provided with the distribution.
|
30
|
+
|
31
|
+
3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
|
32
|
+
and IDIAP Research Institute nor the names of its contributors may be
|
33
|
+
used to endorse or promote products derived from this software without
|
34
|
+
specific prior written permission.
|
35
|
+
|
36
|
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
37
|
+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
38
|
+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
39
|
+
ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
|
40
|
+
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
41
|
+
CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
42
|
+
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
43
|
+
INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
44
|
+
CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
45
|
+
ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
46
|
+
POSSIBILITY OF SUCH DAMAGE.
|
data/README.md
CHANGED
@@ -2,14 +2,16 @@
|
|
2
2
|
|
3
3
|
:fire: Deep learning for Ruby, powered by [LibTorch](https://pytorch.org)
|
4
4
|
|
5
|
-
|
5
|
+
This gem is currently experimental. There may be breaking changes between each release. Please report any issues you experience.
|
6
|
+
|
7
|
+
[![Build Status](https://travis-ci.org/ankane/torch-rb.svg?branch=master)](https://travis-ci.org/ankane/torch-rb)
|
6
8
|
|
7
9
|
## Installation
|
8
10
|
|
9
11
|
First, [install LibTorch](#libtorch-installation). For Homebrew, use:
|
10
12
|
|
11
13
|
```sh
|
12
|
-
brew install
|
14
|
+
brew install libtorch
|
13
15
|
```
|
14
16
|
|
15
17
|
Add this line to your application’s Gemfile:
|
@@ -18,6 +20,8 @@ Add this line to your application’s Gemfile:
|
|
18
20
|
gem 'torch-rb'
|
19
21
|
```
|
20
22
|
|
23
|
+
It can take a few minutes to compile the extension.
|
24
|
+
|
21
25
|
## Getting Started
|
22
26
|
|
23
27
|
This library follows the [PyTorch API](https://pytorch.org/docs/stable/torch.html). There are a few changes to make it more Ruby-like:
|
@@ -26,9 +30,11 @@ This library follows the [PyTorch API](https://pytorch.org/docs/stable/torch.htm
|
|
26
30
|
- Methods that return booleans use `?` instead of `is_` (`tensor?` instead of `is_tensor`)
|
27
31
|
- Numo is used instead of NumPy (`x.numo` instead of `x.numpy()`)
|
28
32
|
|
29
|
-
|
33
|
+
Some methods and options are missing at the moment. PRs welcome!
|
30
34
|
|
31
|
-
|
35
|
+
## Tutorial
|
36
|
+
|
37
|
+
Some examples below are from [Deep Learning with PyTorch: A 60 Minute Blitz](https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html).
|
32
38
|
|
33
39
|
### Tensors
|
34
40
|
|
@@ -143,7 +149,7 @@ Convert a Numo array to a tensor
|
|
143
149
|
|
144
150
|
```ruby
|
145
151
|
b = Numo::NArray.cast([1, 2, 3])
|
146
|
-
Torch.
|
152
|
+
Torch.from_numo(b)
|
147
153
|
```
|
148
154
|
|
149
155
|
### Autograd
|
@@ -171,17 +177,17 @@ out.backward
|
|
171
177
|
Get gradients
|
172
178
|
|
173
179
|
```ruby
|
174
|
-
x.grad
|
180
|
+
x.grad # tensor([[4.5, 4.5], [4.5, 4.5]])
|
175
181
|
```
|
176
182
|
|
177
183
|
Stop autograd from tracking history
|
178
184
|
|
179
185
|
```ruby
|
180
186
|
x.requires_grad # true
|
181
|
-
(x
|
187
|
+
(x**2).requires_grad # true
|
182
188
|
|
183
189
|
Torch.no_grad do
|
184
|
-
(x
|
190
|
+
(x**2).requires_grad # false
|
185
191
|
end
|
186
192
|
```
|
187
193
|
|
@@ -221,7 +227,7 @@ class Net < Torch::NN::Module
|
|
221
227
|
end
|
222
228
|
```
|
223
229
|
|
224
|
-
|
230
|
+
Create an instance of it
|
225
231
|
|
226
232
|
```ruby
|
227
233
|
net = Net.new
|
@@ -229,6 +235,58 @@ input = Torch.randn(1, 1, 32, 32)
|
|
229
235
|
net.call(input)
|
230
236
|
```
|
231
237
|
|
238
|
+
Get trainable parameters
|
239
|
+
|
240
|
+
```ruby
|
241
|
+
net.parameters
|
242
|
+
```
|
243
|
+
|
244
|
+
Zero the gradient buffers and backprop with random gradients
|
245
|
+
|
246
|
+
```ruby
|
247
|
+
net.zero_grad
|
248
|
+
out.backward(Torch.randn(1, 10))
|
249
|
+
```
|
250
|
+
|
251
|
+
Define a loss function
|
252
|
+
|
253
|
+
```ruby
|
254
|
+
output = net.call(input)
|
255
|
+
target = Torch.randn(10)
|
256
|
+
target = target.view(1, -1)
|
257
|
+
criterion = Torch::NN::MSELoss.new
|
258
|
+
loss = criterion.call(output, target)
|
259
|
+
```
|
260
|
+
|
261
|
+
Backprop
|
262
|
+
|
263
|
+
```ruby
|
264
|
+
net.zero_grad
|
265
|
+
p net.conv1.bias.grad
|
266
|
+
loss.backward
|
267
|
+
p net.conv1.bias.grad
|
268
|
+
```
|
269
|
+
|
270
|
+
Update the weights
|
271
|
+
|
272
|
+
```ruby
|
273
|
+
learning_rate = 0.01
|
274
|
+
net.parameters.each do |f|
|
275
|
+
f.data.sub!(f.grad.data * learning_rate)
|
276
|
+
end
|
277
|
+
```
|
278
|
+
|
279
|
+
Use an optimizer
|
280
|
+
|
281
|
+
```ruby
|
282
|
+
optimizer = Torch::Optim::SGD.new(net.parameters, lr: 0.01)
|
283
|
+
optimizer.zero_grad
|
284
|
+
output = net.call(input)
|
285
|
+
loss = criterion.call(output, target)
|
286
|
+
loss.backward
|
287
|
+
optimizer.step
|
288
|
+
```
|
289
|
+
|
232
290
|
### Tensor Creation
|
233
291
|
|
234
292
|
Here’s a list of functions to create tensors (descriptions from the [C++ docs](https://pytorch.org/cppdocs/notes/tensor_creation.html)):
|
@@ -242,7 +300,7 @@ Here’s a list of functions to create tensors (descriptions from the [C++ docs]
|
|
242
300
|
- `empty` returns a tensor with uninitialized values
|
243
301
|
|
244
302
|
```ruby
|
245
|
-
Torch.empty(3)
|
303
|
+
Torch.empty(3) # tensor([7.0054e-45, 0.0000e+00, 0.0000e+00])
|
246
304
|
```
|
247
305
|
|
248
306
|
- `eye` returns an identity matrix
|
@@ -278,19 +336,19 @@ Here’s a list of functions to create tensors (descriptions from the [C++ docs]
|
|
278
336
|
- `rand` returns a tensor filled with values drawn from a uniform distribution on [0, 1)
|
279
337
|
|
280
338
|
```ruby
|
281
|
-
Torch.rand(3)
|
339
|
+
Torch.rand(3) # tensor([0.5444, 0.8799, 0.5571])
|
282
340
|
```
|
283
341
|
|
284
342
|
- `randint` returns a tensor with integers randomly drawn from an interval
|
285
343
|
|
286
344
|
```ruby
|
287
|
-
Torch.randint(1, 10, [3])
|
345
|
+
Torch.randint(1, 10, [3]) # tensor([7, 6, 4])
|
288
346
|
```
|
289
347
|
|
290
348
|
- `randn` returns a tensor filled with values drawn from a unit normal distribution
|
291
349
|
|
292
350
|
```ruby
|
293
|
-
Torch.randn(3)
|
351
|
+
Torch.randn(3) # tensor([-0.7147, 0.6614, 1.1453])
|
294
352
|
```
|
295
353
|
|
296
354
|
- `randperm` returns a tensor filled with a random permutation of integers in some interval
|
@@ -305,12 +363,20 @@ Here’s a list of functions to create tensors (descriptions from the [C++ docs]
|
|
305
363
|
Torch.zeros(3) # tensor([0, 0, 0])
|
306
364
|
```
|
307
365
|
|
366
|
+
## Examples
|
367
|
+
|
368
|
+
Here are a few full examples:
|
369
|
+
|
370
|
+
- [Image classification with MNIST](examples/mnist) ([日本語版](https://qiita.com/kojix2/items/c19c36dc1bf73ea93409))
|
371
|
+
- [Collaborative filtering with MovieLens](examples/movielens)
|
372
|
+
- [Word embeddings](examples/nlp)
|
373
|
+
|
308
374
|
## LibTorch Installation
|
309
375
|
|
310
|
-
[Download LibTorch](https://pytorch.org/)
|
376
|
+
[Download LibTorch](https://pytorch.org/). For Linux, use the `cxx11 ABI` version. Then run:
|
311
377
|
|
312
378
|
```sh
|
313
|
-
|
379
|
+
bundle config build.torch-rb --with-torch-dir=/path/to/libtorch
|
314
380
|
```
|
315
381
|
|
316
382
|
### Homebrew
|
@@ -318,10 +384,10 @@ gem install torch-rb -- --with-torch-dir=/path/to/libtorch
|
|
318
384
|
For Mac, you can use Homebrew.
|
319
385
|
|
320
386
|
```sh
|
321
|
-
brew install
|
387
|
+
brew install libtorch
|
322
388
|
```
|
323
389
|
|
324
|
-
Then install the gem (no need for
|
390
|
+
Then install the gem (no need for `bundle config`).
|
325
391
|
|
326
392
|
## rbenv
|
327
393
|
|
@@ -349,9 +415,9 @@ To get started with development:
|
|
349
415
|
|
350
416
|
```sh
|
351
417
|
git clone https://github.com/ankane/torch-rb.git
|
352
|
-
cd torch
|
418
|
+
cd torch-rb
|
353
419
|
bundle install
|
354
|
-
bundle exec rake compile
|
420
|
+
bundle exec rake compile -- --with-torch-dir=/path/to/libtorch
|
355
421
|
bundle exec rake test
|
356
422
|
```
|
357
423
|
|
data/ext/torch/ext.cpp
CHANGED
@@ -6,95 +6,36 @@
|
|
6
6
|
#include <rice/Class.hpp>
|
7
7
|
#include <rice/Constructor.hpp>
|
8
8
|
|
9
|
-
|
9
|
+
#include "templates.hpp"
|
10
10
|
|
11
|
-
|
12
|
-
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
}
|
17
|
-
|
18
|
-
template<>
|
19
|
-
inline
|
20
|
-
Object to_ruby<long long>(long long const & x)
|
21
|
-
{
|
22
|
-
return LL2NUM(x);
|
23
|
-
}
|
24
|
-
|
25
|
-
template<>
|
26
|
-
inline
|
27
|
-
unsigned long long from_ruby<unsigned long long>(Object x)
|
28
|
-
{
|
29
|
-
return NUM2ULL(x);
|
30
|
-
}
|
11
|
+
// generated with:
|
12
|
+
// rake generate:functions
|
13
|
+
#include "torch_functions.hpp"
|
14
|
+
#include "tensor_functions.hpp"
|
15
|
+
#include "nn_functions.hpp"
|
31
16
|
|
32
|
-
|
33
|
-
inline
|
34
|
-
Object to_ruby<unsigned long long>(unsigned long long const & x)
|
35
|
-
{
|
36
|
-
return ULL2NUM(x);
|
37
|
-
}
|
38
|
-
|
39
|
-
template<>
|
40
|
-
inline
|
41
|
-
short from_ruby<short>(Object x)
|
42
|
-
{
|
43
|
-
return NUM2SHORT(x);
|
44
|
-
}
|
45
|
-
|
46
|
-
template<>
|
47
|
-
inline
|
48
|
-
Object to_ruby<short>(short const & x)
|
49
|
-
{
|
50
|
-
return INT2NUM(x);
|
51
|
-
}
|
17
|
+
using namespace Rice;
|
52
18
|
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
return
|
19
|
+
Object tensor_array(std::tuple<torch::Tensor, torch::Tensor> x) {
|
20
|
+
Array a;
|
21
|
+
a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
|
22
|
+
a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
|
23
|
+
return Object(a);
|
58
24
|
}
|
59
25
|
|
60
|
-
|
61
|
-
|
62
|
-
Object to_ruby<unsigned short>(unsigned short const & x)
|
26
|
+
extern "C"
|
27
|
+
void Init_ext()
|
63
28
|
{
|
64
|
-
|
65
|
-
|
29
|
+
Module rb_mTorch = define_module("Torch");
|
30
|
+
add_torch_functions(rb_mTorch);
|
66
31
|
|
67
|
-
|
68
|
-
|
69
|
-
class IntArrayRef {
|
70
|
-
std::vector<int64_t> vec;
|
71
|
-
public:
|
72
|
-
IntArrayRef(Object o) {
|
73
|
-
Array a = Array(o);
|
74
|
-
for (size_t i = 0; i < a.size(); i++) {
|
75
|
-
vec.push_back(from_ruby<int64_t>(a[i]));
|
76
|
-
}
|
77
|
-
}
|
78
|
-
operator torch::IntArrayRef() {
|
79
|
-
return torch::IntArrayRef(vec);
|
80
|
-
}
|
81
|
-
};
|
32
|
+
Class rb_cTensor = define_class_under<torch::Tensor>(rb_mTorch, "Tensor");
|
33
|
+
add_tensor_functions(rb_cTensor);
|
82
34
|
|
83
|
-
|
84
|
-
|
85
|
-
IntArrayRef from_ruby<IntArrayRef>(Object x)
|
86
|
-
{
|
87
|
-
return IntArrayRef(x);
|
88
|
-
}
|
89
|
-
|
90
|
-
// for now
|
91
|
-
typedef float Scalar;
|
35
|
+
Module rb_mNN = define_module_under(rb_mTorch, "NN");
|
36
|
+
add_nn_functions(rb_mNN);
|
92
37
|
|
93
|
-
|
94
|
-
void Init_ext()
|
95
|
-
{
|
96
|
-
Module rb_mTorch = define_module("Torch")
|
97
|
-
.define_singleton_method(
|
38
|
+
rb_mTorch.define_singleton_method(
|
98
39
|
"grad_enabled?",
|
99
40
|
*[]() {
|
100
41
|
return torch::GradMode::is_enabled();
|
@@ -104,11 +45,6 @@ void Init_ext()
|
|
104
45
|
*[](bool enabled) {
|
105
46
|
torch::GradMode::set_enabled(enabled);
|
106
47
|
})
|
107
|
-
.define_singleton_method(
|
108
|
-
"floating_point?",
|
109
|
-
*[](torch::Tensor& input) {
|
110
|
-
return torch::is_floating_point(input);
|
111
|
-
})
|
112
48
|
.define_singleton_method(
|
113
49
|
"manual_seed",
|
114
50
|
*[](uint64_t seed) {
|
@@ -178,172 +114,117 @@ void Init_ext()
|
|
178
114
|
// begin operations
|
179
115
|
.define_singleton_method(
|
180
116
|
"_mean",
|
181
|
-
*[](
|
117
|
+
*[](Tensor& input) {
|
182
118
|
return torch::mean(input);
|
183
119
|
})
|
184
120
|
.define_singleton_method(
|
185
121
|
"_mean_dim",
|
186
|
-
*[](
|
122
|
+
*[](Tensor& input, int64_t dim, bool keepdim) {
|
187
123
|
return torch::mean(input, dim, keepdim);
|
188
124
|
})
|
189
125
|
.define_singleton_method(
|
190
126
|
"_sum",
|
191
|
-
*[](
|
127
|
+
*[](Tensor& input) {
|
192
128
|
return torch::sum(input);
|
193
129
|
})
|
194
130
|
.define_singleton_method(
|
195
131
|
"_sum_dim",
|
196
|
-
*[](
|
132
|
+
*[](Tensor& input, int64_t dim, bool keepdim) {
|
197
133
|
return torch::sum(input, dim, keepdim);
|
198
134
|
})
|
199
135
|
.define_singleton_method(
|
200
|
-
"
|
201
|
-
*[](
|
202
|
-
return torch::
|
203
|
-
})
|
204
|
-
.define_singleton_method(
|
205
|
-
"_min",
|
206
|
-
*[](torch::Tensor& input) {
|
207
|
-
return torch::min(input);
|
136
|
+
"_max_out",
|
137
|
+
*[](Tensor &max, Tensor &max_indices, const Tensor &input, int64_t dim, bool keepdim) {
|
138
|
+
return tensor_array(torch::_max_out(max, max_indices, input, dim, keepdim));
|
208
139
|
})
|
209
140
|
.define_singleton_method(
|
210
|
-
"
|
211
|
-
*[](
|
212
|
-
return torch::
|
141
|
+
"_topk",
|
142
|
+
*[](Tensor& input, int64_t k) {
|
143
|
+
return tensor_array(torch::topk(input, k));
|
213
144
|
})
|
214
145
|
.define_singleton_method(
|
215
|
-
"
|
216
|
-
*[](
|
217
|
-
return torch::
|
146
|
+
"_softmax",
|
147
|
+
*[](const Tensor &input, int64_t dim) {
|
148
|
+
return torch::softmax(input, dim);
|
218
149
|
})
|
219
150
|
.define_singleton_method(
|
220
|
-
"
|
221
|
-
*[](
|
222
|
-
return torch::
|
151
|
+
"_log_softmax",
|
152
|
+
*[](Tensor& input, int64_t dim) {
|
153
|
+
return torch::log_softmax(input, dim);
|
223
154
|
})
|
224
155
|
.define_singleton_method(
|
225
|
-
"
|
226
|
-
*[](
|
227
|
-
return torch::
|
228
|
-
})
|
229
|
-
.define_singleton_method(
|
230
|
-
"_dot",
|
231
|
-
*[](torch::Tensor& input, torch::Tensor& tensor) {
|
232
|
-
return torch::dot(input, tensor);
|
233
|
-
})
|
234
|
-
.define_singleton_method(
|
235
|
-
"_matmul",
|
236
|
-
*[](torch::Tensor& input, torch::Tensor& other) {
|
237
|
-
return torch::matmul(input, other);
|
238
|
-
})
|
239
|
-
.define_singleton_method(
|
240
|
-
"_add",
|
241
|
-
*[](torch::Tensor& input, torch::Tensor& other) {
|
242
|
-
return torch::add(input, other);
|
243
|
-
})
|
244
|
-
.define_singleton_method(
|
245
|
-
"_add_scalar",
|
246
|
-
*[](torch::Tensor& input, float other) {
|
247
|
-
return torch::add(input, other);
|
248
|
-
})
|
249
|
-
.define_singleton_method(
|
250
|
-
"_add_out",
|
251
|
-
*[](torch::Tensor& out, torch::Tensor& input, torch::Tensor& other) {
|
252
|
-
return torch::add_out(out, input, other);
|
253
|
-
})
|
254
|
-
.define_singleton_method(
|
255
|
-
"_sub",
|
256
|
-
*[](torch::Tensor& input, torch::Tensor& other) {
|
257
|
-
return torch::sub(input, other);
|
258
|
-
})
|
259
|
-
.define_singleton_method(
|
260
|
-
"_sub_scalar",
|
261
|
-
*[](torch::Tensor& input, float other) {
|
262
|
-
return torch::sub(input, other);
|
263
|
-
})
|
264
|
-
.define_singleton_method(
|
265
|
-
"_mul",
|
266
|
-
*[](torch::Tensor& input, torch::Tensor& other) {
|
267
|
-
return torch::mul(input, other);
|
268
|
-
})
|
269
|
-
.define_singleton_method(
|
270
|
-
"_mul_scalar",
|
271
|
-
*[](torch::Tensor& input, float other) {
|
272
|
-
return torch::mul(input, other);
|
273
|
-
})
|
274
|
-
.define_singleton_method(
|
275
|
-
"_div",
|
276
|
-
*[](torch::Tensor& input, torch::Tensor& other) {
|
277
|
-
return torch::div(input, other);
|
278
|
-
})
|
279
|
-
.define_singleton_method(
|
280
|
-
"_div_scalar",
|
281
|
-
*[](torch::Tensor& input, float other) {
|
282
|
-
return torch::div(input, other);
|
283
|
-
})
|
284
|
-
.define_singleton_method(
|
285
|
-
"_remainder",
|
286
|
-
*[](torch::Tensor& input, torch::Tensor& other) {
|
287
|
-
return torch::remainder(input, other);
|
288
|
-
})
|
289
|
-
.define_singleton_method(
|
290
|
-
"_remainder_scalar",
|
291
|
-
*[](torch::Tensor& input, float other) {
|
292
|
-
return torch::remainder(input, other);
|
156
|
+
"relu",
|
157
|
+
*[](Tensor& input) {
|
158
|
+
return torch::relu(input);
|
293
159
|
})
|
294
160
|
.define_singleton_method(
|
295
|
-
"
|
296
|
-
*[](torch::Tensor& input,
|
297
|
-
return torch::
|
161
|
+
"prelu",
|
162
|
+
*[](torch::Tensor& input, torch::Tensor& weight) {
|
163
|
+
return torch::prelu(input, weight);
|
298
164
|
})
|
299
165
|
.define_singleton_method(
|
300
|
-
"
|
301
|
-
*[](torch::Tensor& input) {
|
302
|
-
return torch::
|
166
|
+
"leaky_relu",
|
167
|
+
*[](torch::Tensor& input, Scalar negative_slope) {
|
168
|
+
return torch::leaky_relu(input, negative_slope);
|
303
169
|
})
|
304
170
|
.define_singleton_method(
|
305
|
-
"
|
306
|
-
*[](
|
307
|
-
return torch::
|
171
|
+
"conv2d",
|
172
|
+
*[](Tensor& input, Tensor& weight, Tensor& bias, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, int64_t groups) {
|
173
|
+
return torch::conv2d(input, weight, bias, stride, padding, dilation, groups);
|
308
174
|
})
|
175
|
+
// linear layers
|
309
176
|
.define_singleton_method(
|
310
|
-
"
|
311
|
-
*[](
|
312
|
-
return torch::
|
177
|
+
"bilinear",
|
178
|
+
*[](const Tensor &input1, const Tensor &input2, const Tensor &weight, const Tensor &bias) {
|
179
|
+
return torch::bilinear(input1, input2, weight, bias);
|
313
180
|
})
|
314
181
|
.define_singleton_method(
|
315
182
|
"linear",
|
316
|
-
*[](
|
183
|
+
*[](Tensor& input, Tensor& weight, Tensor& bias) {
|
317
184
|
return torch::linear(input, weight, bias);
|
318
185
|
})
|
186
|
+
// pooling layers
|
319
187
|
.define_singleton_method(
|
320
188
|
"max_pool2d",
|
321
|
-
*[](
|
189
|
+
*[](Tensor& input, IntArrayRef kernel_size) {
|
322
190
|
return torch::max_pool2d(input, kernel_size);
|
323
191
|
})
|
324
192
|
.define_singleton_method(
|
325
|
-
"
|
326
|
-
*[](
|
327
|
-
|
328
|
-
|
193
|
+
"avg_pool2d",
|
194
|
+
*[](Tensor& input, IntArrayRef kernel_size) {
|
195
|
+
return torch::avg_pool2d(input, kernel_size);
|
196
|
+
})
|
197
|
+
.define_singleton_method(
|
198
|
+
"_binary_cross_entropy_with_logits",
|
199
|
+
*[](const Tensor &input, const Tensor &target, OptionalTensor weight, OptionalTensor pos_weight, MyReduction reduction) {
|
200
|
+
return torch::binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction);
|
329
201
|
})
|
202
|
+
.define_singleton_method("numel", &torch::numel)
|
330
203
|
.define_singleton_method(
|
331
|
-
"
|
332
|
-
*[](
|
333
|
-
|
204
|
+
"_from_blob",
|
205
|
+
*[](String s, IntArrayRef size, const torch::TensorOptions &options) {
|
206
|
+
void *data = const_cast<char *>(s.c_str());
|
207
|
+
return torch::from_blob(data, size, options);
|
334
208
|
})
|
335
209
|
.define_singleton_method(
|
336
210
|
"_tensor",
|
337
211
|
*[](Object o, IntArrayRef size, const torch::TensorOptions &options) {
|
338
212
|
Array a = Array(o);
|
339
|
-
|
340
|
-
|
341
|
-
|
213
|
+
auto dtype = options.dtype();
|
214
|
+
torch::Tensor t;
|
215
|
+
if (dtype == torch::kBool) {
|
216
|
+
throw std::runtime_error("Cannot create bool from tensor method yet");
|
217
|
+
} else {
|
218
|
+
std::vector<float> vec;
|
219
|
+
for (size_t i = 0; i < a.size(); i++) {
|
220
|
+
vec.push_back(from_ruby<float>(a[i]));
|
221
|
+
}
|
222
|
+
t = torch::tensor(vec, options);
|
342
223
|
}
|
343
|
-
return
|
224
|
+
return t.reshape(size);
|
344
225
|
});
|
345
226
|
|
346
|
-
|
227
|
+
rb_cTensor
|
347
228
|
.define_method("cuda?", &torch::Tensor::is_cuda)
|
348
229
|
.define_method("distributed?", &torch::Tensor::is_distributed)
|
349
230
|
.define_method("complex?", &torch::Tensor::is_complex)
|
@@ -352,108 +233,162 @@ void Init_ext()
|
|
352
233
|
.define_method("sparse?", &torch::Tensor::is_sparse)
|
353
234
|
.define_method("quantized?", &torch::Tensor::is_quantized)
|
354
235
|
.define_method("dim", &torch::Tensor::dim)
|
355
|
-
.define_method("numel", &torch::Tensor::numel)
|
356
236
|
.define_method("element_size", &torch::Tensor::element_size)
|
357
237
|
.define_method("requires_grad", &torch::Tensor::requires_grad)
|
238
|
+
.define_method("view_as", &torch::Tensor::view_as)
|
239
|
+
.define_method(
|
240
|
+
"addcmul!",
|
241
|
+
*[](Tensor& self, Scalar value, const Tensor & tensor1, const Tensor & tensor2) {
|
242
|
+
return self.addcmul_(tensor1, tensor2, value);
|
243
|
+
})
|
244
|
+
.define_method(
|
245
|
+
"addcdiv!",
|
246
|
+
*[](Tensor& self, Scalar value, const Tensor & tensor1, const Tensor & tensor2) {
|
247
|
+
return self.addcdiv_(tensor1, tensor2, value);
|
248
|
+
})
|
358
249
|
.define_method(
|
359
250
|
"zero!",
|
360
|
-
*[](
|
251
|
+
*[](Tensor& self) {
|
361
252
|
return self.zero_();
|
362
253
|
})
|
363
254
|
.define_method(
|
364
|
-
"detach
|
365
|
-
*[](
|
366
|
-
return self.
|
255
|
+
"detach",
|
256
|
+
*[](Tensor& self) {
|
257
|
+
return self.detach();
|
367
258
|
})
|
368
259
|
.define_method(
|
369
|
-
"
|
370
|
-
*[](
|
371
|
-
return self
|
260
|
+
"detach!",
|
261
|
+
*[](Tensor& self) {
|
262
|
+
return self.detach_();
|
372
263
|
})
|
373
264
|
.define_method(
|
374
265
|
"_requires_grad!",
|
375
|
-
*[](
|
266
|
+
*[](Tensor& self, bool requires_grad) {
|
376
267
|
return self.set_requires_grad(requires_grad);
|
377
268
|
})
|
378
269
|
.define_method(
|
379
|
-
"
|
380
|
-
*[](
|
381
|
-
return self.backward();
|
270
|
+
"_backward",
|
271
|
+
*[](Tensor& self, Object gradient) {
|
272
|
+
return gradient.is_nil() ? self.backward() : self.backward(from_ruby<torch::Tensor>(gradient));
|
382
273
|
})
|
383
274
|
.define_method(
|
384
275
|
"grad",
|
385
|
-
*[](
|
276
|
+
*[](Tensor& self) {
|
386
277
|
return self.grad();
|
387
278
|
})
|
388
279
|
.define_method(
|
389
280
|
"_dtype",
|
390
|
-
*[](
|
281
|
+
*[](Tensor& self) {
|
391
282
|
return (int) at::typeMetaToScalarType(self.dtype());
|
392
283
|
})
|
284
|
+
.define_method(
|
285
|
+
"_type",
|
286
|
+
*[](Tensor& self, int dtype) {
|
287
|
+
return self.toType((torch::ScalarType) dtype);
|
288
|
+
})
|
393
289
|
.define_method(
|
394
290
|
"_layout",
|
395
|
-
*[](
|
291
|
+
*[](Tensor& self) {
|
396
292
|
std::stringstream s;
|
397
293
|
s << self.layout();
|
398
294
|
return s.str();
|
399
295
|
})
|
400
296
|
.define_method(
|
401
297
|
"device",
|
402
|
-
*[](
|
298
|
+
*[](Tensor& self) {
|
403
299
|
std::stringstream s;
|
404
300
|
s << self.device();
|
405
301
|
return s.str();
|
406
302
|
})
|
407
303
|
.define_method(
|
408
|
-
"
|
409
|
-
*[](
|
410
|
-
return self.
|
304
|
+
"resize_as!",
|
305
|
+
*[](Tensor& self, Tensor& other) {
|
306
|
+
return self.resize_as_(other);
|
411
307
|
})
|
412
308
|
.define_method(
|
413
|
-
"
|
414
|
-
*[](
|
415
|
-
self.
|
309
|
+
"fill!",
|
310
|
+
*[](Tensor& self, Scalar value) {
|
311
|
+
return self.fill_(value);
|
416
312
|
})
|
417
313
|
.define_method(
|
418
|
-
"
|
419
|
-
*[](
|
420
|
-
self.
|
314
|
+
"relu!",
|
315
|
+
*[](Tensor& self) {
|
316
|
+
return self.relu_();
|
317
|
+
})
|
318
|
+
.define_method(
|
319
|
+
"normal!",
|
320
|
+
*[](Tensor& self, double mean, double std) {
|
321
|
+
return self.normal_(mean, std);
|
322
|
+
})
|
323
|
+
.define_method(
|
324
|
+
"random!",
|
325
|
+
*[](Tensor& self, int64_t to) {
|
326
|
+
return self.random_(to);
|
421
327
|
})
|
422
328
|
.define_method(
|
423
|
-
"
|
424
|
-
*[](
|
425
|
-
self.
|
329
|
+
"sub!",
|
330
|
+
*[](Tensor& self, Tensor& other) {
|
331
|
+
return self.sub_(other);
|
426
332
|
})
|
427
333
|
.define_method(
|
428
334
|
"div!",
|
429
|
-
*[](
|
430
|
-
self.div_(other);
|
335
|
+
*[](Tensor& self, Tensor& other) {
|
336
|
+
return self.div_(other);
|
431
337
|
})
|
432
338
|
.define_method(
|
433
|
-
"
|
434
|
-
*[](
|
435
|
-
return self.
|
339
|
+
"sqrt!",
|
340
|
+
*[](Tensor& self) {
|
341
|
+
return self.sqrt_();
|
436
342
|
})
|
437
343
|
.define_method(
|
438
|
-
"
|
439
|
-
*[](
|
344
|
+
"unsqueeze!",
|
345
|
+
*[](Tensor& self, int64_t dim) {
|
346
|
+
return self.unsqueeze_(dim);
|
347
|
+
})
|
348
|
+
.define_method(
|
349
|
+
"copy!",
|
350
|
+
*[](Tensor& self, Tensor& src) {
|
351
|
+
return self.copy_(src);
|
352
|
+
})
|
353
|
+
.define_method(
|
354
|
+
"clone",
|
355
|
+
*[](Tensor& self) {
|
356
|
+
return self.clone();
|
357
|
+
})
|
358
|
+
.define_method(
|
359
|
+
"data",
|
360
|
+
*[](Tensor& self) {
|
361
|
+
return self.data();
|
362
|
+
})
|
363
|
+
.define_method(
|
364
|
+
"_flat_data",
|
365
|
+
*[](Tensor& self) {
|
440
366
|
Array a;
|
441
367
|
auto dtype = self.dtype();
|
442
368
|
|
443
369
|
// TODO DRY if someone knows C++
|
444
|
-
|
445
|
-
|
446
|
-
|
370
|
+
if (dtype == torch::kByte) {
|
371
|
+
uint8_t* data = self.data_ptr<uint8_t>();
|
372
|
+
for (int i = 0; i < self.numel(); i++) {
|
373
|
+
a.push(data[i]);
|
374
|
+
}
|
375
|
+
} else if (dtype == torch::kChar) {
|
376
|
+
int8_t* data = self.data_ptr<int8_t>();
|
377
|
+
for (int i = 0; i < self.numel(); i++) {
|
378
|
+
a.push(to_ruby<int>(data[i]));
|
379
|
+
}
|
380
|
+
} else if (dtype == torch::kShort) {
|
381
|
+
int16_t* data = self.data_ptr<int16_t>();
|
447
382
|
for (int i = 0; i < self.numel(); i++) {
|
448
383
|
a.push(data[i]);
|
449
384
|
}
|
450
385
|
} else if (dtype == torch::kInt) {
|
451
|
-
|
386
|
+
int32_t* data = self.data_ptr<int32_t>();
|
452
387
|
for (int i = 0; i < self.numel(); i++) {
|
453
388
|
a.push(data[i]);
|
454
389
|
}
|
455
390
|
} else if (dtype == torch::kLong) {
|
456
|
-
|
391
|
+
int64_t* data = self.data_ptr<int64_t>();
|
457
392
|
for (int i = 0; i < self.numel(); i++) {
|
458
393
|
a.push(data[i]);
|
459
394
|
}
|
@@ -467,19 +402,24 @@ void Init_ext()
|
|
467
402
|
for (int i = 0; i < self.numel(); i++) {
|
468
403
|
a.push(data[i]);
|
469
404
|
}
|
405
|
+
} else if (dtype == torch::kBool) {
|
406
|
+
bool* data = self.data_ptr<bool>();
|
407
|
+
for (int i = 0; i < self.numel(); i++) {
|
408
|
+
a.push(data[i] ? True : False);
|
409
|
+
}
|
470
410
|
} else {
|
471
|
-
throw "Unsupported type";
|
411
|
+
throw std::runtime_error("Unsupported type");
|
472
412
|
}
|
473
413
|
return a;
|
474
414
|
})
|
475
415
|
.define_method(
|
476
|
-
"
|
477
|
-
*[](
|
478
|
-
return self.
|
416
|
+
"_to",
|
417
|
+
*[](Tensor& self, torch::Device device, int dtype, bool non_blocking, bool copy) {
|
418
|
+
return self.to(device, (torch::ScalarType) dtype, non_blocking, copy);
|
479
419
|
})
|
480
420
|
.define_singleton_method(
|
481
421
|
"_make_subclass",
|
482
|
-
*[](
|
422
|
+
*[](Tensor& rd, bool requires_grad) {
|
483
423
|
auto data = torch::autograd::as_variable_ref(rd).detach();
|
484
424
|
data.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(true);
|
485
425
|
auto var = data.set_requires_grad(requires_grad);
|
@@ -499,8 +439,11 @@ void Init_ext()
|
|
499
439
|
torch::Layout l;
|
500
440
|
if (layout == "strided") {
|
501
441
|
l = torch::kStrided;
|
442
|
+
} else if (layout == "sparse") {
|
443
|
+
l = torch::kSparse;
|
444
|
+
throw std::runtime_error("Sparse layout not supported yet");
|
502
445
|
} else {
|
503
|
-
throw "Unsupported layout";
|
446
|
+
throw std::runtime_error("Unsupported layout: " + layout);
|
504
447
|
}
|
505
448
|
return self.layout(l);
|
506
449
|
})
|
@@ -513,7 +456,7 @@ void Init_ext()
|
|
513
456
|
} else if (device == "cuda") {
|
514
457
|
d = torch::kCUDA;
|
515
458
|
} else {
|
516
|
-
throw "Unsupported device";
|
459
|
+
throw std::runtime_error("Unsupported device: " + device);
|
517
460
|
}
|
518
461
|
return self.device(d);
|
519
462
|
})
|
@@ -523,24 +466,99 @@ void Init_ext()
|
|
523
466
|
return self.requires_grad(requires_grad);
|
524
467
|
});
|
525
468
|
|
526
|
-
Module rb_mNN = define_module_under(rb_mTorch, "NN");
|
527
|
-
|
528
469
|
Module rb_mInit = define_module_under(rb_mNN, "Init")
|
529
470
|
.define_singleton_method(
|
530
|
-
"
|
531
|
-
*[](
|
532
|
-
return torch::nn::init::
|
471
|
+
"_calculate_gain",
|
472
|
+
*[](NonlinearityType nonlinearity, double param) {
|
473
|
+
return torch::nn::init::calculate_gain(nonlinearity, param);
|
533
474
|
})
|
534
475
|
.define_singleton_method(
|
535
|
-
"
|
536
|
-
*[](
|
537
|
-
return torch::nn::init::uniform_(
|
476
|
+
"_uniform!",
|
477
|
+
*[](Tensor tensor, double low, double high) {
|
478
|
+
return torch::nn::init::uniform_(tensor, low, high);
|
479
|
+
})
|
480
|
+
.define_singleton_method(
|
481
|
+
"_normal!",
|
482
|
+
*[](Tensor tensor, double mean, double std) {
|
483
|
+
return torch::nn::init::normal_(tensor, mean, std);
|
484
|
+
})
|
485
|
+
.define_singleton_method(
|
486
|
+
"_constant!",
|
487
|
+
*[](Tensor tensor, Scalar value) {
|
488
|
+
return torch::nn::init::constant_(tensor, value);
|
489
|
+
})
|
490
|
+
.define_singleton_method(
|
491
|
+
"_ones!",
|
492
|
+
*[](Tensor tensor) {
|
493
|
+
return torch::nn::init::ones_(tensor);
|
494
|
+
})
|
495
|
+
.define_singleton_method(
|
496
|
+
"_zeros!",
|
497
|
+
*[](Tensor tensor) {
|
498
|
+
return torch::nn::init::zeros_(tensor);
|
499
|
+
})
|
500
|
+
.define_singleton_method(
|
501
|
+
"_eye!",
|
502
|
+
*[](Tensor tensor) {
|
503
|
+
return torch::nn::init::eye_(tensor);
|
504
|
+
})
|
505
|
+
.define_singleton_method(
|
506
|
+
"_dirac!",
|
507
|
+
*[](Tensor tensor) {
|
508
|
+
return torch::nn::init::dirac_(tensor);
|
509
|
+
})
|
510
|
+
.define_singleton_method(
|
511
|
+
"_xavier_uniform!",
|
512
|
+
*[](Tensor tensor, double gain) {
|
513
|
+
return torch::nn::init::xavier_uniform_(tensor, gain);
|
514
|
+
})
|
515
|
+
.define_singleton_method(
|
516
|
+
"_xavier_normal!",
|
517
|
+
*[](Tensor tensor, double gain) {
|
518
|
+
return torch::nn::init::xavier_normal_(tensor, gain);
|
519
|
+
})
|
520
|
+
.define_singleton_method(
|
521
|
+
"_kaiming_uniform!",
|
522
|
+
*[](Tensor tensor, double a, FanModeType mode, NonlinearityType nonlinearity) {
|
523
|
+
return torch::nn::init::kaiming_uniform_(tensor, a, mode, nonlinearity);
|
524
|
+
})
|
525
|
+
.define_singleton_method(
|
526
|
+
"_kaiming_normal!",
|
527
|
+
*[](Tensor tensor, double a, FanModeType mode, NonlinearityType nonlinearity) {
|
528
|
+
return torch::nn::init::kaiming_normal_(tensor, a, mode, nonlinearity);
|
529
|
+
})
|
530
|
+
.define_singleton_method(
|
531
|
+
"_orthogonal!",
|
532
|
+
*[](Tensor tensor, double gain) {
|
533
|
+
return torch::nn::init::orthogonal_(tensor, gain);
|
534
|
+
})
|
535
|
+
.define_singleton_method(
|
536
|
+
"_sparse!",
|
537
|
+
*[](Tensor tensor, double sparsity, double std) {
|
538
|
+
return torch::nn::init::sparse_(tensor, sparsity, std);
|
538
539
|
});
|
539
540
|
|
540
541
|
Class rb_cParameter = define_class_under<torch::autograd::Variable, torch::Tensor>(rb_mNN, "Parameter")
|
541
542
|
.define_method(
|
542
543
|
"grad",
|
543
544
|
*[](torch::autograd::Variable& self) {
|
544
|
-
|
545
|
+
auto grad = self.grad();
|
546
|
+
return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
|
547
|
+
});
|
548
|
+
|
549
|
+
Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")
|
550
|
+
.define_constructor(Constructor<torch::Device, std::string>())
|
551
|
+
.define_method("index", &torch::Device::index)
|
552
|
+
.define_method("index?", &torch::Device::has_index)
|
553
|
+
.define_method(
|
554
|
+
"type",
|
555
|
+
*[](torch::Device& self) {
|
556
|
+
std::stringstream s;
|
557
|
+
s << self.type();
|
558
|
+
return s.str();
|
545
559
|
});
|
560
|
+
|
561
|
+
Module rb_mCUDA = define_module_under(rb_mTorch, "CUDA")
|
562
|
+
.define_singleton_method("available?", &torch::cuda::is_available)
|
563
|
+
.define_singleton_method("device_count", &torch::cuda::device_count);
|
546
564
|
}
|