torch-rb 0.1.0 → 0.1.5

Files changed (94)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +40 -0
  3. data/LICENSE.txt +46 -22
  4. data/README.md +85 -19
  5. data/ext/torch/ext.cpp +274 -256
  6. data/ext/torch/extconf.rb +9 -0
  7. data/ext/torch/nn_functions.cpp +595 -0
  8. data/ext/torch/nn_functions.hpp +6 -0
  9. data/ext/torch/templates.hpp +250 -0
  10. data/ext/torch/tensor_functions.cpp +1860 -0
  11. data/ext/torch/tensor_functions.hpp +6 -0
  12. data/ext/torch/torch_functions.cpp +2875 -0
  13. data/ext/torch/torch_functions.hpp +6 -0
  14. data/lib/torch.rb +199 -84
  15. data/lib/torch/ext.bundle +0 -0
  16. data/lib/torch/inspector.rb +52 -25
  17. data/lib/torch/native/dispatcher.rb +48 -0
  18. data/lib/torch/native/function.rb +78 -0
  19. data/lib/torch/native/generator.rb +149 -0
  20. data/lib/torch/native/native_functions.yaml +6837 -0
  21. data/lib/torch/native/parser.rb +97 -0
  22. data/lib/torch/nn/alpha_dropout.rb +9 -0
  23. data/lib/torch/nn/avg_pool2d.rb +14 -0
  24. data/lib/torch/nn/avg_poolnd.rb +9 -0
  25. data/lib/torch/nn/bce_loss.rb +13 -0
  26. data/lib/torch/nn/bce_with_logits_loss.rb +15 -0
  27. data/lib/torch/nn/bilinear.rb +38 -0
  28. data/lib/torch/nn/conv2d.rb +14 -29
  29. data/lib/torch/nn/convnd.rb +41 -0
  30. data/lib/torch/nn/cosine_embedding_loss.rb +14 -0
  31. data/lib/torch/nn/cosine_similarity.rb +15 -0
  32. data/lib/torch/nn/cross_entropy_loss.rb +14 -0
  33. data/lib/torch/nn/ctc_loss.rb +15 -0
  34. data/lib/torch/nn/dropout.rb +9 -0
  35. data/lib/torch/nn/dropout2d.rb +9 -0
  36. data/lib/torch/nn/dropout3d.rb +9 -0
  37. data/lib/torch/nn/dropoutnd.rb +15 -0
  38. data/lib/torch/nn/embedding.rb +52 -0
  39. data/lib/torch/nn/embedding_bag.rb +34 -0
  40. data/lib/torch/nn/feature_alpha_dropout.rb +9 -0
  41. data/lib/torch/nn/functional.rb +194 -11
  42. data/lib/torch/nn/hinge_embedding_loss.rb +14 -0
  43. data/lib/torch/nn/identity.rb +14 -0
  44. data/lib/torch/nn/init.rb +58 -1
  45. data/lib/torch/nn/kl_div_loss.rb +13 -0
  46. data/lib/torch/nn/l1_loss.rb +13 -0
  47. data/lib/torch/nn/leaky_relu.rb +20 -0
  48. data/lib/torch/nn/linear.rb +12 -11
  49. data/lib/torch/nn/log_softmax.rb +14 -0
  50. data/lib/torch/nn/loss.rb +10 -0
  51. data/lib/torch/nn/margin_ranking_loss.rb +14 -0
  52. data/lib/torch/nn/max_pool2d.rb +9 -0
  53. data/lib/torch/nn/max_poolnd.rb +19 -0
  54. data/lib/torch/nn/module.rb +184 -19
  55. data/lib/torch/nn/mse_loss.rb +2 -2
  56. data/lib/torch/nn/multi_label_margin_loss.rb +13 -0
  57. data/lib/torch/nn/multi_label_soft_margin_loss.rb +13 -0
  58. data/lib/torch/nn/multi_margin_loss.rb +17 -0
  59. data/lib/torch/nn/nll_loss.rb +14 -0
  60. data/lib/torch/nn/pairwise_distance.rb +16 -0
  61. data/lib/torch/nn/parameter.rb +4 -0
  62. data/lib/torch/nn/poisson_nll_loss.rb +16 -0
  63. data/lib/torch/nn/prelu.rb +19 -0
  64. data/lib/torch/nn/relu.rb +8 -3
  65. data/lib/torch/nn/rnn.rb +22 -0
  66. data/lib/torch/nn/rnn_base.rb +154 -0
  67. data/lib/torch/nn/sequential.rb +1 -10
  68. data/lib/torch/nn/sigmoid.rb +9 -0
  69. data/lib/torch/nn/smooth_l1_loss.rb +13 -0
  70. data/lib/torch/nn/soft_margin_loss.rb +13 -0
  71. data/lib/torch/nn/softmax.rb +18 -0
  72. data/lib/torch/nn/softmax2d.rb +10 -0
  73. data/lib/torch/nn/softmin.rb +14 -0
  74. data/lib/torch/nn/softplus.rb +19 -0
  75. data/lib/torch/nn/triplet_margin_loss.rb +18 -0
  76. data/lib/torch/nn/weighted_loss.rb +10 -0
  77. data/lib/torch/optim/adadelta.rb +57 -0
  78. data/lib/torch/optim/adagrad.rb +71 -0
  79. data/lib/torch/optim/adam.rb +81 -0
  80. data/lib/torch/optim/adamax.rb +68 -0
  81. data/lib/torch/optim/adamw.rb +82 -0
  82. data/lib/torch/optim/asgd.rb +65 -0
  83. data/lib/torch/optim/lr_scheduler/lr_scheduler.rb +33 -0
  84. data/lib/torch/optim/lr_scheduler/step_lr.rb +17 -0
  85. data/lib/torch/optim/optimizer.rb +62 -0
  86. data/lib/torch/optim/rmsprop.rb +76 -0
  87. data/lib/torch/optim/rprop.rb +68 -0
  88. data/lib/torch/optim/sgd.rb +60 -0
  89. data/lib/torch/random.rb +10 -0
  90. data/lib/torch/tensor.rb +92 -21
  91. data/lib/torch/utils/data/data_loader.rb +15 -0
  92. data/lib/torch/utils/data/tensor_dataset.rb +8 -1
  93. data/lib/torch/version.rb +1 -1
  94. metadata +74 -3
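
Two of the new files add data utilities (`data_loader.rb` and `tensor_dataset.rb`). A minimal sketch of how they might be combined, assuming constructors that mirror PyTorch's (the `batch_size:` keyword and block-iteration style are assumptions):

```ruby
require "torch"

x = Torch.randn(100, 10)
y = Torch.randn(100, 1)

# TensorDataset pairs inputs with targets; DataLoader yields batches
dataset = Torch::Utils::Data::TensorDataset.new(x, y)
loader = Torch::Utils::Data::DataLoader.new(dataset, batch_size: 32)

loader.each do |xb, yb|
  # xb is a batch of inputs, yb the matching targets
end
```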
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
- metadata.gz: 38e16e7f07d004fd9625f168694356d551c79cbc62b0131fe1403e4c0995f296
- data.tar.gz: 66bf6ae0e4dd373a7542fbfb1cfb9dbd89fc455e16166e6a76d0945b32fecf38
+ metadata.gz: 6b47306ed525e1a20d25cb8324d4658f750c18afa5704c9b7bafc215d8f568c1
+ data.tar.gz: dad6ddf955b111989b061e5af146006a32c83dc1ea1ca5005a6b6e34bc9a4892
  SHA512:
- metadata.gz: d100e3a21ac877fe93ac61e9b5e0d8a5e61126684fc037dda3e9f703b040188b1e1523aa4111dff4aaf92ada1001597c5f60674b9583b14d31afd18dbf1ff18d
- data.tar.gz: c234dee79e26d3ee25ade2aaddd75f155dea6d59d8b9c5af2c571423a7aaa8a6489f5cfce89f09f390468a951b1644a4212c19525a79816be09214f0938860a8
+ metadata.gz: 5d26e3642bf7cd921b9b570052df353d4c32b1bd955a6fbbf5f30249631fa4c0d4624f4fa91a1c06f61b3b0d6461cd117ab4df185cf013e915d2f63e52dbcf7c
+ data.tar.gz: 1728ce9b579f41f7a567e63d7256c82bb352840b67f16d88aac930a99e5abbf5a5f4061c5f9da16fb47d1664567e7956d276a8b2b44f13d2263032486afb53e8
data/CHANGELOG.md CHANGED
@@ -1,3 +1,43 @@
+ ## 0.1.5 (2019-12-06)
+
+ - Added many more functions
+ - Added tensor classes - `FloatTensor`, `LongTensor`, etc
+ - Improved modules
+
+ ## 0.1.4 (2019-12-01)
+
+ - Added distance functions
+ - Added more activations
+ - Added more linear layers
+ - Added more loss functions
+ - Added more init methods
+ - Added support for tensor assignment
+
+ ## 0.1.3 (2019-11-30)
+
+ - Changed to BSD 3-Clause license to match PyTorch
+ - Added many optimizers
+ - Added `StepLR` learning rate scheduler
+ - Added dropout
+ - Added embedding
+ - Added support for `bool` type
+ - Improved performance of `from_numo`
+
+ ## 0.1.2 (2019-11-27)
+
+ - Added SGD optimizer
+ - Added support for gradient to `backward` method
+ - Added `argmax`, `eq`, `leaky_relu`, `prelu`, and `reshape` methods
+ - Improved indexing
+ - Fixed `zero_grad`
+ - Fixed error with infinite values
+
+ ## 0.1.1 (2019-11-26)
+
+ - Added support for `uint8` and `int8` types
+ - Fixed `undefined symbol` error on Linux
+ - Fixed C++ error messages
+
  ## 0.1.0 (2019-11-26)

  - First release
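
The optimizers and scheduler added in 0.1.3 combine with the module and loss classes listed above. A minimal training sketch, assuming a PyTorch-style `StepLR` constructor (the `step_size:` and `gamma:` keywords are assumptions; the `SGD` call matches the README diff below):

```ruby
require "torch"

net = Torch::NN::Linear.new(10, 1)
criterion = Torch::NN::MSELoss.new
optimizer = Torch::Optim::SGD.new(net.parameters, lr: 0.1)
scheduler = Torch::Optim::LRScheduler::StepLR.new(optimizer, step_size: 30, gamma: 0.1)

x = Torch.randn(8, 10)
y = Torch.randn(8, 1)

5.times do
  optimizer.zero_grad                    # clear old gradients
  loss = criterion.call(net.call(x), y)  # forward pass + loss
  loss.backward                          # backprop
  optimizer.step                         # update weights
  scheduler.step                         # decay the learning rate on schedule
end
```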
data/LICENSE.txt CHANGED
@@ -1,22 +1,46 @@
- Copyright (c) 2019 Andrew Kane
-
- MIT License
-
- Permission is hereby granted, free of charge, to any person obtaining
- a copy of this software and associated documentation files (the
- "Software"), to deal in the Software without restriction, including
- without limitation the rights to use, copy, modify, merge, publish,
- distribute, sublicense, and/or sell copies of the Software, and to
- permit persons to whom the Software is furnished to do so, subject to
- the following conditions:
-
- The above copyright notice and this permission notice shall be
- included in all copies or substantial portions of the Software.
-
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
- EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
- MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
- NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
- LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
- OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
- WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+ BSD 3-Clause License
+
+ From Torch-rb:
+
+ Copyright (c) 2019- Andrew Kane
+
+ From PyTorch (for ported code):
+
+ Copyright (c) 2016- Facebook, Inc (Adam Paszke)
+ Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
+ Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
+ Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
+ Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
+ Copyright (c) 2011-2013 NYU (Clement Farabet)
+ Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
+ Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
+ Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
+
+ All rights reserved.
+
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions are met:
+
+ 1. Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ 2. Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+ 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
+ and IDIAP Research Institute nor the names of its contributors may be
+ used to endorse or promote products derived from this software without
+ specific prior written permission.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+ LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+ CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+ SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+ INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+ CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+ POSSIBILITY OF SUCH DAMAGE.
data/README.md CHANGED
@@ -2,14 +2,16 @@

  :fire: Deep learning for Ruby, powered by [LibTorch](https://pytorch.org)

- **Note:** This gem is currently experimental. There may be breaking changes between each release.
+ This gem is currently experimental. There may be breaking changes between each release. Please report any issues you experience.
+
+ [![Build Status](https://travis-ci.org/ankane/torch-rb.svg?branch=master)](https://travis-ci.org/ankane/torch-rb)

  ## Installation

  First, [install LibTorch](#libtorch-installation). For Homebrew, use:

  ```sh
- brew install ankane/brew/libtorch
+ brew install libtorch
  ```

  Add this line to your application’s Gemfile:
@@ -18,6 +20,8 @@ Add this line to your application’s Gemfile:
  gem 'torch-rb'
  ```

+ It can take a few minutes to compile the extension.
+
  ## Getting Started

  This library follows the [PyTorch API](https://pytorch.org/docs/stable/torch.html). There are a few changes to make it more Ruby-like:
@@ -26,9 +30,11 @@ This library follows the [PyTorch API](https://pytorch.org/docs/stable/torch.htm
  - Methods that return booleans use `?` instead of `is_` (`tensor?` instead of `is_tensor`)
  - Numo is used instead of NumPy (`x.numo` instead of `x.numpy()`)

- Many methods and options are missing at the moment. PRs welcome!
+ Some methods and options are missing at the moment. PRs welcome!

- Some examples below are from [Deep Learning with PyTorch: A 60 Minutes Blitz](https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html).
+ ## Tutorial
+
+ Some examples below are from [Deep Learning with PyTorch: A 60 Minutes Blitz](https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html)

  ### Tensors

@@ -143,7 +149,7 @@ Convert a Numo array to a tensor

  ```ruby
  b = Numo::NArray.cast([1, 2, 3])
- Torch.from_numpy(b)
+ Torch.from_numo(b)
  ```

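This rename pairs with the Numo-for-NumPy substitution noted in the API changes above. A quick round-trip sketch (assumes the numo-narray gem is installed; `numo` is the conversion method from the same list):

```ruby
require "torch"
require "numo/narray"

a = Numo::NArray.cast([1, 2, 3])  # the array from the example above
t = Torch.from_numo(a)            # Numo array -> tensor
t.numo                            # tensor -> Numo array
```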
  ### Autograd
@@ -171,17 +177,17 @@ out.backward
  Get gradients

  ```ruby
- x.grad
+ x.grad # tensor([[4.5, 4.5], [4.5, 4.5]])
  ```

  Stop autograd from tracking history

  ```ruby
  x.requires_grad # true
- (x ** 2).requires_grad # true
+ (x**2).requires_grad # true

  Torch.no_grad do
- (x ** 2).requires_grad # false
+ (x**2).requires_grad # false
  end
  ```

@@ -221,7 +227,7 @@ class Net < Torch::NN::Module
  end
  ```

- And run
+ Create an instance of it

  ```ruby
  net = Net.new
@@ -229,6 +235,58 @@ input = Torch.randn(1, 1, 32, 32)
  net.call(input)
  ```

+ Get trainable parameters
+
+ ```ruby
+ net.parameters
+ ```
+
+ Zero the gradient buffers and backprop with random gradients
+
+ ```ruby
+ net.zero_grad
+ out.backward(Torch.randn(1, 10))
+ ```
+
+ Define a loss function
+
+ ```ruby
+ output = net.call(input)
+ target = Torch.randn(10)
+ target = target.view(1, -1)
+ criterion = Torch::NN::MSELoss.new
+ loss = criterion.call(output, target)
+ ```
+
+ Backprop
+
+ ```ruby
+ net.zero_grad
+ p net.conv1.bias.grad
+ loss.backward
+ p net.conv1.bias.grad
+ ```
+
+ Update the weights
+
+ ```ruby
+ learning_rate = 0.01
+ net.parameters.each do |f|
+ f.data.sub!(f.grad.data * learning_rate)
+ end
+ ```
+
+ Use an optimizer
+
+ ```ruby
+ optimizer = Torch::Optim::SGD.new(net.parameters, lr: 0.01)
+ optimizer.zero_grad
+ output = net.call(input)
+ loss = criterion.call(output, target)
+ loss.backward
+ optimizer.step
+ ```
+
  ### Tensor Creation

  Here’s a list of functions to create tensors (descriptions from the [C++ docs](https://pytorch.org/cppdocs/notes/tensor_creation.html)):
@@ -242,7 +300,7 @@ Here’s a list of functions to create tensors (descriptions from the [C++ docs]
  - `empty` returns a tensor with uninitialized values

  ```ruby
- Torch.empty(3)
+ Torch.empty(3) # tensor([7.0054e-45, 0.0000e+00, 0.0000e+00])
  ```

  - `eye` returns an identity matrix
@@ -278,19 +336,19 @@ Here’s a list of functions to create tensors (descriptions from the [C++ docs]
  - `rand` returns a tensor filled with values drawn from a uniform distribution on [0, 1)

  ```ruby
- Torch.rand(3)
+ Torch.rand(3) # tensor([0.5444, 0.8799, 0.5571])
  ```

  - `randint` returns a tensor with integers randomly drawn from an interval

  ```ruby
- Torch.randint(1, 10, [3])
+ Torch.randint(1, 10, [3]) # tensor([7, 6, 4])
  ```

  - `randn` returns a tensor filled with values drawn from a unit normal distribution

  ```ruby
- Torch.randn(3)
+ Torch.randn(3) # tensor([-0.7147, 0.6614, 1.1453])
  ```

  - `randperm` returns a tensor filled with a random permutation of integers in some interval
@@ -305,12 +363,20 @@ Here’s a list of functions to create tensors (descriptions from the [C++ docs]
  Torch.zeros(3) # tensor([0, 0, 0])
  ```

+ ## Examples
+
+ Here are a few full examples:
+
+ - [Image classification with MNIST](examples/mnist) ([日本語版](https://qiita.com/kojix2/items/c19c36dc1bf73ea93409))
+ - [Collaborative filtering with MovieLens](examples/movielens)
+ - [Word embeddings](examples/nlp)
+
  ## LibTorch Installation

- [Download LibTorch](https://pytorch.org/) and run:
+ [Download LibTorch](https://pytorch.org/). For Linux, use the `cxx11 ABI` version. Then run:

  ```sh
- gem install torch-rb -- --with-torch-dir=/path/to/libtorch
+ bundle config build.torch-rb --with-torch-dir=/path/to/libtorch
  ```

  ### Homebrew
@@ -318,10 +384,10 @@ gem install torch-rb -- --with-torch-dir=/path/to/libtorch
  For Mac, you can use Homebrew.

  ```sh
- brew install ankane/brew/libtorch
+ brew install libtorch
  ```

- Then install the gem (no need for `--with-torch-dir`).
+ Then install the gem (no need for `bundle config`).

  ## rbenv

@@ -349,9 +415,9 @@ To get started with development:

  ```sh
  git clone https://github.com/ankane/torch-rb.git
- cd torch
+ cd torch-rb
  bundle install
- bundle exec rake compile
+ bundle exec rake compile -- --with-torch-dir=/path/to/libtorch
  bundle exec rake test
  ```

data/ext/torch/ext.cpp CHANGED
@@ -6,95 +6,36 @@
  #include <rice/Class.hpp>
  #include <rice/Constructor.hpp>

- using namespace Rice;
+ #include "templates.hpp"

- template<>
- inline
- long long from_ruby<long long>(Object x)
- {
- return NUM2LL(x);
- }
-
- template<>
- inline
- Object to_ruby<long long>(long long const & x)
- {
- return LL2NUM(x);
- }
-
- template<>
- inline
- unsigned long long from_ruby<unsigned long long>(Object x)
- {
- return NUM2ULL(x);
- }
+ // generated with:
+ // rake generate:functions
+ #include "torch_functions.hpp"
+ #include "tensor_functions.hpp"
+ #include "nn_functions.hpp"

- template<>
- inline
- Object to_ruby<unsigned long long>(unsigned long long const & x)
- {
- return ULL2NUM(x);
- }
-
- template<>
- inline
- short from_ruby<short>(Object x)
- {
- return NUM2SHORT(x);
- }
-
- template<>
- inline
- Object to_ruby<short>(short const & x)
- {
- return INT2NUM(x);
- }
+ using namespace Rice;

- template<>
- inline
- unsigned short from_ruby<unsigned short>(Object x)
- {
- return NUM2USHORT(x);
+ Object tensor_array(std::tuple<torch::Tensor, torch::Tensor> x) {
+ Array a;
+ a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
+ a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
+ return Object(a);
  }

- template<>
- inline
- Object to_ruby<unsigned short>(unsigned short const & x)
+ extern "C"
+ void Init_ext()
  {
- return UINT2NUM(x);
- }
+ Module rb_mTorch = define_module("Torch");
+ add_torch_functions(rb_mTorch);

- // need to wrap torch::IntArrayRef() since
- // it doesn't own underlying data
- class IntArrayRef {
- std::vector<int64_t> vec;
- public:
- IntArrayRef(Object o) {
- Array a = Array(o);
- for (size_t i = 0; i < a.size(); i++) {
- vec.push_back(from_ruby<int64_t>(a[i]));
- }
- }
- operator torch::IntArrayRef() {
- return torch::IntArrayRef(vec);
- }
- };
+ Class rb_cTensor = define_class_under<torch::Tensor>(rb_mTorch, "Tensor");
+ add_tensor_functions(rb_cTensor);

- template<>
- inline
- IntArrayRef from_ruby<IntArrayRef>(Object x)
- {
- return IntArrayRef(x);
- }
-
- // for now
- typedef float Scalar;
+ Module rb_mNN = define_module_under(rb_mTorch, "NN");
+ add_nn_functions(rb_mNN);

- extern "C"
- void Init_ext()
- {
- Module rb_mTorch = define_module("Torch")
- .define_singleton_method(
+ rb_mTorch.define_singleton_method(
  "grad_enabled?",
  *[]() {
  return torch::GradMode::is_enabled();
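
The new `tensor_array` helper in the hunk above converts a C++ `(Tensor, Tensor)` tuple into a two-element Ruby array, which lets bindings like `_topk` (later in this file) return values and indices together. A hypothetical Ruby-side view, assuming `lib/torch.rb` exposes `_topk` as `Torch.topk`:

```ruby
values, indices = Torch.topk(Torch.tensor([1.0, 3.0, 2.0]), 2)
values  # tensor([3.0, 2.0]) - the two largest entries
indices # tensor([1, 2])    - their positions in the input
```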
@@ -104,11 +45,6 @@ void Init_ext()
  *[](bool enabled) {
  torch::GradMode::set_enabled(enabled);
  })
- .define_singleton_method(
- "floating_point?",
- *[](torch::Tensor& input) {
- return torch::is_floating_point(input);
- })
  .define_singleton_method(
  "manual_seed",
  *[](uint64_t seed) {
@@ -178,172 +114,117 @@ void Init_ext()
  // begin operations
  .define_singleton_method(
  "_mean",
- *[](torch::Tensor& input) {
+ *[](Tensor& input) {
  return torch::mean(input);
  })
  .define_singleton_method(
  "_mean_dim",
- *[](torch::Tensor& input, int64_t dim, bool keepdim) {
+ *[](Tensor& input, int64_t dim, bool keepdim) {
  return torch::mean(input, dim, keepdim);
  })
  .define_singleton_method(
  "_sum",
- *[](torch::Tensor& input) {
+ *[](Tensor& input) {
  return torch::sum(input);
  })
  .define_singleton_method(
  "_sum_dim",
- *[](torch::Tensor& input, int64_t dim, bool keepdim) {
+ *[](Tensor& input, int64_t dim, bool keepdim) {
  return torch::sum(input, dim, keepdim);
  })
  .define_singleton_method(
- "_norm",
- *[](torch::Tensor& input) {
- return torch::norm(input);
- })
- .define_singleton_method(
- "_min",
- *[](torch::Tensor& input) {
- return torch::min(input);
+ "_max_out",
+ *[](Tensor &max, Tensor &max_indices, const Tensor &input, int64_t dim, bool keepdim) {
+ return tensor_array(torch::_max_out(max, max_indices, input, dim, keepdim));
  })
  .define_singleton_method(
- "_max",
- *[](torch::Tensor& input) {
- return torch::max(input);
+ "_topk",
+ *[](Tensor& input, int64_t k) {
+ return tensor_array(torch::topk(input, k));
  })
  .define_singleton_method(
- "_exp",
- *[](torch::Tensor& input) {
- return torch::exp(input);
+ "_softmax",
+ *[](const Tensor &input, int64_t dim) {
+ return torch::softmax(input, dim);
  })
  .define_singleton_method(
- "_log",
- *[](torch::Tensor& input) {
- return torch::log(input);
+ "_log_softmax",
+ *[](Tensor& input, int64_t dim) {
+ return torch::log_softmax(input, dim);
  })
  .define_singleton_method(
- "_unsqueeze",
- *[](torch::Tensor& input, int64_t dim) {
- return torch::unsqueeze(input, dim);
- })
- .define_singleton_method(
- "_dot",
- *[](torch::Tensor& input, torch::Tensor& tensor) {
- return torch::dot(input, tensor);
- })
- .define_singleton_method(
- "_matmul",
- *[](torch::Tensor& input, torch::Tensor& other) {
- return torch::matmul(input, other);
- })
- .define_singleton_method(
- "_add",
- *[](torch::Tensor& input, torch::Tensor& other) {
- return torch::add(input, other);
- })
- .define_singleton_method(
- "_add_scalar",
- *[](torch::Tensor& input, float other) {
- return torch::add(input, other);
- })
- .define_singleton_method(
- "_add_out",
- *[](torch::Tensor& out, torch::Tensor& input, torch::Tensor& other) {
- return torch::add_out(out, input, other);
- })
- .define_singleton_method(
- "_sub",
- *[](torch::Tensor& input, torch::Tensor& other) {
- return torch::sub(input, other);
- })
- .define_singleton_method(
- "_sub_scalar",
- *[](torch::Tensor& input, float other) {
- return torch::sub(input, other);
- })
- .define_singleton_method(
- "_mul",
- *[](torch::Tensor& input, torch::Tensor& other) {
- return torch::mul(input, other);
- })
- .define_singleton_method(
- "_mul_scalar",
- *[](torch::Tensor& input, float other) {
- return torch::mul(input, other);
- })
- .define_singleton_method(
- "_div",
- *[](torch::Tensor& input, torch::Tensor& other) {
- return torch::div(input, other);
- })
- .define_singleton_method(
- "_div_scalar",
- *[](torch::Tensor& input, float other) {
- return torch::div(input, other);
- })
- .define_singleton_method(
- "_remainder",
- *[](torch::Tensor& input, torch::Tensor& other) {
- return torch::remainder(input, other);
- })
- .define_singleton_method(
- "_remainder_scalar",
- *[](torch::Tensor& input, float other) {
- return torch::remainder(input, other);
+ "relu",
+ *[](Tensor& input) {
+ return torch::relu(input);
  })
  .define_singleton_method(
- "_pow",
- *[](torch::Tensor& input, Scalar exponent) {
- return torch::pow(input, exponent);
+ "prelu",
+ *[](torch::Tensor& input, torch::Tensor& weight) {
+ return torch::prelu(input, weight);
  })
  .define_singleton_method(
- "_neg",
- *[](torch::Tensor& input) {
- return torch::neg(input);
+ "leaky_relu",
+ *[](torch::Tensor& input, Scalar negative_slope) {
+ return torch::leaky_relu(input, negative_slope);
  })
  .define_singleton_method(
- "relu",
- *[](torch::Tensor& input) {
- return torch::relu(input);
+ "conv2d",
+ *[](Tensor& input, Tensor& weight, Tensor& bias, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, int64_t groups) {
+ return torch::conv2d(input, weight, bias, stride, padding, dilation, groups);
  })
+ // linear layers
  .define_singleton_method(
- "conv2d",
- *[](torch::Tensor& input, torch::Tensor& weight, torch::Tensor& bias) {
- return torch::conv2d(input, weight, bias);
+ "bilinear",
+ *[](const Tensor &input1, const Tensor &input2, const Tensor &weight, const Tensor &bias) {
+ return torch::bilinear(input1, input2, weight, bias);
  })
  .define_singleton_method(
  "linear",
- *[](torch::Tensor& input, torch::Tensor& weight, torch::Tensor& bias) {
+ *[](Tensor& input, Tensor& weight, Tensor& bias) {
  return torch::linear(input, weight, bias);
  })
+ // pooling layers
  .define_singleton_method(
  "max_pool2d",
- *[](torch::Tensor& input, IntArrayRef kernel_size) {
+ *[](Tensor& input, IntArrayRef kernel_size) {
  return torch::max_pool2d(input, kernel_size);
  })
  .define_singleton_method(
- "mse_loss",
- *[](torch::Tensor& input, torch::Tensor& target, std::string reduction) {
- auto red = reduction == "mean" ? Reduction::Mean : Reduction::Sum;
- return torch::mse_loss(input, target, red);
+ "avg_pool2d",
+ *[](Tensor& input, IntArrayRef kernel_size) {
+ return torch::avg_pool2d(input, kernel_size);
+ })
+ .define_singleton_method(
+ "_binary_cross_entropy_with_logits",
+ *[](const Tensor &input, const Tensor &target, OptionalTensor weight, OptionalTensor pos_weight, MyReduction reduction) {
+ return torch::binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction);
  })
+ .define_singleton_method("numel", &torch::numel)
  .define_singleton_method(
- "nll_loss",
- *[](torch::Tensor& input, torch::Tensor& target) {
- return torch::nll_loss(input, target);
+ "_from_blob",
+ *[](String s, IntArrayRef size, const torch::TensorOptions &options) {
+ void *data = const_cast<char *>(s.c_str());
+ return torch::from_blob(data, size, options);
  })
  .define_singleton_method(
  "_tensor",
  *[](Object o, IntArrayRef size, const torch::TensorOptions &options) {
  Array a = Array(o);
- std::vector<float> vec;
- for (size_t i = 0; i < a.size(); i++) {
- vec.push_back(from_ruby<float>(a[i]));
+ auto dtype = options.dtype();
+ torch::Tensor t;
+ if (dtype == torch::kBool) {
+ throw std::runtime_error("Cannot create bool from tensor method yet");
+ } else {
+ std::vector<float> vec;
+ for (size_t i = 0; i < a.size(); i++) {
+ vec.push_back(from_ruby<float>(a[i]));
+ }
+ t = torch::tensor(vec, options);
  }
- return torch::tensor(vec, options).reshape(size);
+ return t.reshape(size);
  });

- Class rb_cTensor = define_class_under<torch::Tensor>(rb_mTorch, "Tensor")
+ rb_cTensor
  .define_method("cuda?", &torch::Tensor::is_cuda)
  .define_method("distributed?", &torch::Tensor::is_distributed)
  .define_method("complex?", &torch::Tensor::is_complex)
@@ -352,108 +233,162 @@ void Init_ext()
  .define_method("sparse?", &torch::Tensor::is_sparse)
  .define_method("quantized?", &torch::Tensor::is_quantized)
  .define_method("dim", &torch::Tensor::dim)
- .define_method("numel", &torch::Tensor::numel)
  .define_method("element_size", &torch::Tensor::element_size)
  .define_method("requires_grad", &torch::Tensor::requires_grad)
+ .define_method("view_as", &torch::Tensor::view_as)
+ .define_method(
+ "addcmul!",
+ *[](Tensor& self, Scalar value, const Tensor & tensor1, const Tensor & tensor2) {
+ return self.addcmul_(tensor1, tensor2, value);
+ })
+ .define_method(
+ "addcdiv!",
+ *[](Tensor& self, Scalar value, const Tensor & tensor1, const Tensor & tensor2) {
+ return self.addcdiv_(tensor1, tensor2, value);
+ })
  .define_method(
  "zero!",
- *[](torch::Tensor& self) {
+ *[](Tensor& self) {
  return self.zero_();
  })
  .define_method(
- "detach!",
- *[](torch::Tensor& self) {
- return self.detach_();
+ "detach",
+ *[](Tensor& self) {
+ return self.detach();
  })
  .define_method(
- "_access",
- *[](torch::Tensor& self, int64_t index) {
- return self[index];
+ "detach!",
+ *[](Tensor& self) {
+ return self.detach_();
  })
  .define_method(
  "_requires_grad!",
- *[](torch::Tensor& self, bool requires_grad) {
+ *[](Tensor& self, bool requires_grad) {
  return self.set_requires_grad(requires_grad);
  })
  .define_method(
- "backward",
- *[](torch::Tensor& self) {
- return self.backward();
+ "_backward",
+ *[](Tensor& self, Object gradient) {
+ return gradient.is_nil() ? self.backward() : self.backward(from_ruby<torch::Tensor>(gradient));
  })
  .define_method(
  "grad",
- *[](torch::Tensor& self) {
+ *[](Tensor& self) {
  return self.grad();
  })
  .define_method(
  "_dtype",
- *[](torch::Tensor& self) {
+ *[](Tensor& self) {
  return (int) at::typeMetaToScalarType(self.dtype());
  })
+ .define_method(
+ "_type",
+ *[](Tensor& self, int dtype) {
+ return self.toType((torch::ScalarType) dtype);
+ })
  .define_method(
  "_layout",
- *[](torch::Tensor& self) {
+ *[](Tensor& self) {
  std::stringstream s;
  s << self.layout();
  return s.str();
  })
  .define_method(
  "device",
- *[](torch::Tensor& self) {
+ *[](Tensor& self) {
  std::stringstream s;
  s << self.device();
  return s.str();
  })
  .define_method(
- "_view",
- *[](torch::Tensor& self, IntArrayRef size) {
- return self.view(size);
+ "resize_as!",
+ *[](Tensor& self, Tensor& other) {
+ return self.resize_as_(other);
  })
  .define_method(
- "add!",
- *[](torch::Tensor& self, torch::Tensor& other) {
- self.add_(other);
+ "fill!",
+ *[](Tensor& self, Scalar value) {
+ return self.fill_(value);
  })
  .define_method(
- "sub!",
- *[](torch::Tensor& self, torch::Tensor& other) {
- self.sub_(other);
+ "relu!",
+ *[](Tensor& self) {
+ return self.relu_();
+ })
+ .define_method(
+ "normal!",
+ *[](Tensor& self, double mean, double std) {
+ return self.normal_(mean, std);
+ })
+ .define_method(
+ "random!",
+ *[](Tensor& self, int64_t to) {
+ return self.random_(to);
  })
  .define_method(
- "mul!",
- *[](torch::Tensor& self, torch::Tensor& other) {
- self.mul_(other);
+ "sub!",
+ *[](Tensor& self, Tensor& other) {
+ return self.sub_(other);
  })
  .define_method(
  "div!",
- *[](torch::Tensor& self, torch::Tensor& other) {
- self.div_(other);
+ *[](Tensor& self, Tensor& other) {
+ return self.div_(other);
  })
  .define_method(
- "log_softmax",
- *[](torch::Tensor& self, int64_t dim) {
- return self.log_softmax(dim);
+ "sqrt!",
+ *[](Tensor& self) {
+ return self.sqrt_();
  })
  .define_method(
- "_data",
- *[](torch::Tensor& self) {
+ "unsqueeze!",
+ *[](Tensor& self, int64_t dim) {
+ return self.unsqueeze_(dim);
+ })
+ .define_method(
+ "copy!",
+ *[](Tensor& self, Tensor& src) {
+ return self.copy_(src);
+ })
+ .define_method(
+ "clone",
+ *[](Tensor& self) {
+ return self.clone();
+ })
+ .define_method(
+ "data",
+ *[](Tensor& self) {
+ return self.data();
+ })
+ .define_method(
+ "_flat_data",
+ *[](Tensor& self) {
  Array a;
  auto dtype = self.dtype();

  // TODO DRY if someone knows C++
- // TODO kByte (uint8), kChar (int8), kBool (bool)
- if (dtype == torch::kShort) {
- short* data = self.data_ptr<short>();
+ if (dtype == torch::kByte) {
+ uint8_t* data = self.data_ptr<uint8_t>();
+ for (int i = 0; i < self.numel(); i++) {
+ a.push(data[i]);
+ }
+ } else if (dtype == torch::kChar) {
+ int8_t* data = self.data_ptr<int8_t>();
+ for (int i = 0; i < self.numel(); i++) {
+ a.push(to_ruby<int>(data[i]));
+ }
+ } else if (dtype == torch::kShort) {
+ int16_t* data = self.data_ptr<int16_t>();
  for (int i = 0; i < self.numel(); i++) {
  a.push(data[i]);
  }
  } else if (dtype == torch::kInt) {
- int* data = self.data_ptr<int>();
+ int32_t* data = self.data_ptr<int32_t>();
  for (int i = 0; i < self.numel(); i++) {
  a.push(data[i]);
  }
  } else if (dtype == torch::kLong) {
- long long* data = self.data_ptr<long long>();
+ int64_t* data = self.data_ptr<int64_t>();
  for (int i = 0; i < self.numel(); i++) {
  a.push(data[i]);
  }
@@ -467,19 +402,24 @@ void Init_ext()
  for (int i = 0; i < self.numel(); i++) {
  a.push(data[i]);
  }
+ } else if (dtype == torch::kBool) {
+ bool* data = self.data_ptr<bool>();
+ for (int i = 0; i < self.numel(); i++) {
+ a.push(data[i] ? True : False);
+ }
  } else {
- throw "Unsupported type";
+ throw std::runtime_error("Unsupported type");
  }
  return a;
  })
  .define_method(
- "_size",
- *[](torch::Tensor& self, int i) {
- return self.size(i);
+ "_to",
+ *[](Tensor& self, torch::Device device, int dtype, bool non_blocking, bool copy) {
+ return self.to(device, (torch::ScalarType) dtype, non_blocking, copy);
  })
  .define_singleton_method(
  "_make_subclass",
- *[](torch::Tensor& rd, bool requires_grad) {
+ *[](Tensor& rd, bool requires_grad) {
  auto data = torch::autograd::as_variable_ref(rd).detach();
  data.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(true);
  auto var = data.set_requires_grad(requires_grad);
@@ -499,8 +439,11 @@ void Init_ext()
  torch::Layout l;
  if (layout == "strided") {
  l = torch::kStrided;
+ } else if (layout == "sparse") {
+ l = torch::kSparse;
+ throw std::runtime_error("Sparse layout not supported yet");
  } else {
- throw "Unsupported layout";
+ throw std::runtime_error("Unsupported layout: " + layout);
  }
  return self.layout(l);
  })
@@ -513,7 +456,7 @@ void Init_ext()
  } else if (device == "cuda") {
  d = torch::kCUDA;
  } else {
- throw "Unsupported device";
+ throw std::runtime_error("Unsupported device: " + device);
  }
  return self.device(d);
  })
@@ -523,24 +466,99 @@ void Init_ext()
  return self.requires_grad(requires_grad);
  });

- Module rb_mNN = define_module_under(rb_mTorch, "NN");
-
  Module rb_mInit = define_module_under(rb_mNN, "Init")
  .define_singleton_method(
- "kaiming_uniform_",
- *[](torch::Tensor& input, double a) {
- return torch::nn::init::kaiming_uniform_(input, a);
+ "_calculate_gain",
+ *[](NonlinearityType nonlinearity, double param) {
+ return torch::nn::init::calculate_gain(nonlinearity, param);
  })
  .define_singleton_method(
- "uniform_",
- *[](torch::Tensor& input, double to, double from) {
- return torch::nn::init::uniform_(input, to, from);
+ "_uniform!",
+ *[](Tensor tensor, double low, double high) {
+ return torch::nn::init::uniform_(tensor, low, high);
+ })
+ .define_singleton_method(
+ "_normal!",
+ *[](Tensor tensor, double mean, double std) {
+ return torch::nn::init::normal_(tensor, mean, std);
+ })
+ .define_singleton_method(
+ "_constant!",
+ *[](Tensor tensor, Scalar value) {
+ return torch::nn::init::constant_(tensor, value);
+ })
+ .define_singleton_method(
+ "_ones!",
+ *[](Tensor tensor) {
+ return torch::nn::init::ones_(tensor);
+ })
+ .define_singleton_method(
+ "_zeros!",
+ *[](Tensor tensor) {
+ return torch::nn::init::zeros_(tensor);
+ })
+ .define_singleton_method(
+ "_eye!",
+ *[](Tensor tensor) {
+ return torch::nn::init::eye_(tensor);
+ })
+ .define_singleton_method(
+ "_dirac!",
+ *[](Tensor tensor) {
+ return torch::nn::init::dirac_(tensor);
+ })
+ .define_singleton_method(
+ "_xavier_uniform!",
+ *[](Tensor tensor, double gain) {
+ return torch::nn::init::xavier_uniform_(tensor, gain);
+ })
+ .define_singleton_method(
+ "_xavier_normal!",
+ *[](Tensor tensor, double gain) {
+ return torch::nn::init::xavier_normal_(tensor, gain);
+ })
+ .define_singleton_method(
+ "_kaiming_uniform!",
+ *[](Tensor tensor, double a, FanModeType mode, NonlinearityType nonlinearity) {
+ return torch::nn::init::kaiming_uniform_(tensor, a, mode, nonlinearity);
+ })
+ .define_singleton_method(
+ "_kaiming_normal!",
+ *[](Tensor tensor, double a, FanModeType mode, NonlinearityType nonlinearity) {
+ return torch::nn::init::kaiming_normal_(tensor, a, mode, nonlinearity);
+ })
+ .define_singleton_method(
+ "_orthogonal!",
+ *[](Tensor tensor, double gain) {
+ return torch::nn::init::orthogonal_(tensor, gain);
+ })
+ .define_singleton_method(
+ "_sparse!",
+ *[](Tensor tensor, double sparsity, double std) {
+ return torch::nn::init::sparse_(tensor, sparsity, std);
  });

  Class rb_cParameter = define_class_under<torch::autograd::Variable, torch::Tensor>(rb_mNN, "Parameter")
  .define_method(
  "grad",
  *[](torch::autograd::Variable& self) {
- return self.grad();
+ auto grad = self.grad();
+ return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
+ });
+
+ Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")
+ .define_constructor(Constructor<torch::Device, std::string>())
+ .define_method("index", &torch::Device::index)
+ .define_method("index?", &torch::Device::has_index)
+ .define_method(
+ "type",
+ *[](torch::Device& self) {
+ std::stringstream s;
+ s << self.type();
+ return s.str();
  });
+
+ Module rb_mCUDA = define_module_under(rb_mTorch, "CUDA")
+ .define_singleton_method("available?", &torch::cuda::is_available)
+ .define_singleton_method("device_count", &torch::cuda::device_count);
  }
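
The final additions expose device and CUDA introspection to Ruby. A short sketch of the resulting API, using only the methods registered above:

```ruby
require "torch"

Torch::CUDA.available?    # => false on a CPU-only LibTorch build
Torch::CUDA.device_count  # => 0 without a GPU

device = Torch::Device.new("cpu")  # the constructor takes a device string
device.type    # => "cpu"
device.index?  # => false, since no index was given
```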