torch-rb 0.1.0 → 0.1.5

This diff shows the contents of publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (94)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +40 -0
  3. data/LICENSE.txt +46 -22
  4. data/README.md +85 -19
  5. data/ext/torch/ext.cpp +274 -256
  6. data/ext/torch/extconf.rb +9 -0
  7. data/ext/torch/nn_functions.cpp +595 -0
  8. data/ext/torch/nn_functions.hpp +6 -0
  9. data/ext/torch/templates.hpp +250 -0
  10. data/ext/torch/tensor_functions.cpp +1860 -0
  11. data/ext/torch/tensor_functions.hpp +6 -0
  12. data/ext/torch/torch_functions.cpp +2875 -0
  13. data/ext/torch/torch_functions.hpp +6 -0
  14. data/lib/torch.rb +199 -84
  15. data/lib/torch/ext.bundle +0 -0
  16. data/lib/torch/inspector.rb +52 -25
  17. data/lib/torch/native/dispatcher.rb +48 -0
  18. data/lib/torch/native/function.rb +78 -0
  19. data/lib/torch/native/generator.rb +149 -0
  20. data/lib/torch/native/native_functions.yaml +6837 -0
  21. data/lib/torch/native/parser.rb +97 -0
  22. data/lib/torch/nn/alpha_dropout.rb +9 -0
  23. data/lib/torch/nn/avg_pool2d.rb +14 -0
  24. data/lib/torch/nn/avg_poolnd.rb +9 -0
  25. data/lib/torch/nn/bce_loss.rb +13 -0
  26. data/lib/torch/nn/bce_with_logits_loss.rb +15 -0
  27. data/lib/torch/nn/bilinear.rb +38 -0
  28. data/lib/torch/nn/conv2d.rb +14 -29
  29. data/lib/torch/nn/convnd.rb +41 -0
  30. data/lib/torch/nn/cosine_embedding_loss.rb +14 -0
  31. data/lib/torch/nn/cosine_similarity.rb +15 -0
  32. data/lib/torch/nn/cross_entropy_loss.rb +14 -0
  33. data/lib/torch/nn/ctc_loss.rb +15 -0
  34. data/lib/torch/nn/dropout.rb +9 -0
  35. data/lib/torch/nn/dropout2d.rb +9 -0
  36. data/lib/torch/nn/dropout3d.rb +9 -0
  37. data/lib/torch/nn/dropoutnd.rb +15 -0
  38. data/lib/torch/nn/embedding.rb +52 -0
  39. data/lib/torch/nn/embedding_bag.rb +34 -0
  40. data/lib/torch/nn/feature_alpha_dropout.rb +9 -0
  41. data/lib/torch/nn/functional.rb +194 -11
  42. data/lib/torch/nn/hinge_embedding_loss.rb +14 -0
  43. data/lib/torch/nn/identity.rb +14 -0
  44. data/lib/torch/nn/init.rb +58 -1
  45. data/lib/torch/nn/kl_div_loss.rb +13 -0
  46. data/lib/torch/nn/l1_loss.rb +13 -0
  47. data/lib/torch/nn/leaky_relu.rb +20 -0
  48. data/lib/torch/nn/linear.rb +12 -11
  49. data/lib/torch/nn/log_softmax.rb +14 -0
  50. data/lib/torch/nn/loss.rb +10 -0
  51. data/lib/torch/nn/margin_ranking_loss.rb +14 -0
  52. data/lib/torch/nn/max_pool2d.rb +9 -0
  53. data/lib/torch/nn/max_poolnd.rb +19 -0
  54. data/lib/torch/nn/module.rb +184 -19
  55. data/lib/torch/nn/mse_loss.rb +2 -2
  56. data/lib/torch/nn/multi_label_margin_loss.rb +13 -0
  57. data/lib/torch/nn/multi_label_soft_margin_loss.rb +13 -0
  58. data/lib/torch/nn/multi_margin_loss.rb +17 -0
  59. data/lib/torch/nn/nll_loss.rb +14 -0
  60. data/lib/torch/nn/pairwise_distance.rb +16 -0
  61. data/lib/torch/nn/parameter.rb +4 -0
  62. data/lib/torch/nn/poisson_nll_loss.rb +16 -0
  63. data/lib/torch/nn/prelu.rb +19 -0
  64. data/lib/torch/nn/relu.rb +8 -3
  65. data/lib/torch/nn/rnn.rb +22 -0
  66. data/lib/torch/nn/rnn_base.rb +154 -0
  67. data/lib/torch/nn/sequential.rb +1 -10
  68. data/lib/torch/nn/sigmoid.rb +9 -0
  69. data/lib/torch/nn/smooth_l1_loss.rb +13 -0
  70. data/lib/torch/nn/soft_margin_loss.rb +13 -0
  71. data/lib/torch/nn/softmax.rb +18 -0
  72. data/lib/torch/nn/softmax2d.rb +10 -0
  73. data/lib/torch/nn/softmin.rb +14 -0
  74. data/lib/torch/nn/softplus.rb +19 -0
  75. data/lib/torch/nn/triplet_margin_loss.rb +18 -0
  76. data/lib/torch/nn/weighted_loss.rb +10 -0
  77. data/lib/torch/optim/adadelta.rb +57 -0
  78. data/lib/torch/optim/adagrad.rb +71 -0
  79. data/lib/torch/optim/adam.rb +81 -0
  80. data/lib/torch/optim/adamax.rb +68 -0
  81. data/lib/torch/optim/adamw.rb +82 -0
  82. data/lib/torch/optim/asgd.rb +65 -0
  83. data/lib/torch/optim/lr_scheduler/lr_scheduler.rb +33 -0
  84. data/lib/torch/optim/lr_scheduler/step_lr.rb +17 -0
  85. data/lib/torch/optim/optimizer.rb +62 -0
  86. data/lib/torch/optim/rmsprop.rb +76 -0
  87. data/lib/torch/optim/rprop.rb +68 -0
  88. data/lib/torch/optim/sgd.rb +60 -0
  89. data/lib/torch/random.rb +10 -0
  90. data/lib/torch/tensor.rb +92 -21
  91. data/lib/torch/utils/data/data_loader.rb +15 -0
  92. data/lib/torch/utils/data/tensor_dataset.rb +8 -1
  93. data/lib/torch/version.rb +1 -1
  94. metadata +74 -3
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
- metadata.gz: 38e16e7f07d004fd9625f168694356d551c79cbc62b0131fe1403e4c0995f296
- data.tar.gz: 66bf6ae0e4dd373a7542fbfb1cfb9dbd89fc455e16166e6a76d0945b32fecf38
+ metadata.gz: 6b47306ed525e1a20d25cb8324d4658f750c18afa5704c9b7bafc215d8f568c1
+ data.tar.gz: dad6ddf955b111989b061e5af146006a32c83dc1ea1ca5005a6b6e34bc9a4892
  SHA512:
- metadata.gz: d100e3a21ac877fe93ac61e9b5e0d8a5e61126684fc037dda3e9f703b040188b1e1523aa4111dff4aaf92ada1001597c5f60674b9583b14d31afd18dbf1ff18d
- data.tar.gz: c234dee79e26d3ee25ade2aaddd75f155dea6d59d8b9c5af2c571423a7aaa8a6489f5cfce89f09f390468a951b1644a4212c19525a79816be09214f0938860a8
+ metadata.gz: 5d26e3642bf7cd921b9b570052df353d4c32b1bd955a6fbbf5f30249631fa4c0d4624f4fa91a1c06f61b3b0d6461cd117ab4df185cf013e915d2f63e52dbcf7c
+ data.tar.gz: 1728ce9b579f41f7a567e63d7256c82bb352840b67f16d88aac930a99e5abbf5a5f4061c5f9da16fb47d1664567e7956d276a8b2b44f13d2263032486afb53e8
data/CHANGELOG.md CHANGED
@@ -1,3 +1,43 @@
+ ## 0.1.5 (2019-12-06)
+
+ - Added many more functions
+ - Added tensor classes - `FloatTensor`, `LongTensor`, etc
+ - Improved modules
+
+ ## 0.1.4 (2019-12-01)
+
+ - Added distance functions
+ - Added more activations
+ - Added more linear layers
+ - Added more loss functions
+ - Added more init methods
+ - Added support for tensor assignment
+
+ ## 0.1.3 (2019-11-30)
+
+ - Changed to BSD 3-Clause license to match PyTorch
+ - Added many optimizers
+ - Added `StepLR` learning rate scheduler
+ - Added dropout
+ - Added embedding
+ - Added support for `bool` type
+ - Improved performance of `from_numo`
+
+ ## 0.1.2 (2019-11-27)
+
+ - Added SGD optimizer
+ - Added support for gradient to `backward` method
+ - Added `argmax`, `eq`, `leaky_relu`, `prelu`, and `reshape` methods
+ - Improved indexing
+ - Fixed `zero_grad`
+ - Fixed error with infinite values
+
+ ## 0.1.1 (2019-11-26)
+
+ - Added support for `uint8` and `int8` types
+ - Fixed `undefined symbol` error on Linux
+ - Fixed C++ error messages
+
  ## 0.1.0 (2019-11-26)

  - First release
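
To ground the entries above, here is a minimal Ruby sketch of the 0.1.5 tensor classes and the 0.1.2 tensor methods. The typed constructors are an assumption modeled on PyTorch's `torch.FloatTensor([...])` convention and are not verified against this release:

```ruby
require "torch"

# Hypothetical constructors for the tensor classes noted in 0.1.5;
# assumes they mirror PyTorch's FloatTensor/LongTensor behavior.
x = Torch::FloatTensor.new([1.0, 2.0, 3.0])
y = Torch::LongTensor.new([1, 2, 3])

# `argmax`, `eq`, and `reshape` were added in 0.1.2.
x.argmax          # index of the largest element
x.eq(x)           # elementwise equality
y.reshape([3, 1]) # new shape, same data
```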
data/LICENSE.txt CHANGED
@@ -1,22 +1,46 @@
- Copyright (c) 2019 Andrew Kane
-
- MIT License
-
- Permission is hereby granted, free of charge, to any person obtaining
- a copy of this software and associated documentation files (the
- "Software"), to deal in the Software without restriction, including
- without limitation the rights to use, copy, modify, merge, publish,
- distribute, sublicense, and/or sell copies of the Software, and to
- permit persons to whom the Software is furnished to do so, subject to
- the following conditions:
-
- The above copyright notice and this permission notice shall be
- included in all copies or substantial portions of the Software.
-
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
- EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
- MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
- NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
- LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
- OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
- WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+ BSD 3-Clause License
+
+ From Torch-rb:
+
+ Copyright (c) 2019- Andrew Kane
+
+ From PyTorch (for ported code):
+
+ Copyright (c) 2016- Facebook, Inc (Adam Paszke)
+ Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
+ Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
+ Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
+ Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
+ Copyright (c) 2011-2013 NYU (Clement Farabet)
+ Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
+ Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
+ Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
+
+ All rights reserved.
+
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions are met:
+
+ 1. Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ 2. Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+ 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
+ and IDIAP Research Institute nor the names of its contributors may be
+ used to endorse or promote products derived from this software without
+ specific prior written permission.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+ LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+ CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+ SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+ INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+ CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+ POSSIBILITY OF SUCH DAMAGE.
data/README.md CHANGED
@@ -2,14 +2,16 @@

  :fire: Deep learning for Ruby, powered by [LibTorch](https://pytorch.org)

- **Note:** This gem is currently experimental. There may be breaking changes between each release.
+ This gem is currently experimental. There may be breaking changes between each release. Please report any issues you experience.
+
+ [![Build Status](https://travis-ci.org/ankane/torch-rb.svg?branch=master)](https://travis-ci.org/ankane/torch-rb)

  ## Installation

  First, [install LibTorch](#libtorch-installation). For Homebrew, use:

  ```sh
- brew install ankane/brew/libtorch
+ brew install libtorch
  ```

  Add this line to your application’s Gemfile:
@@ -18,6 +20,8 @@ Add this line to your application’s Gemfile:
  gem 'torch-rb'
  ```

+ It can take a few minutes to compile the extension.
+
  ## Getting Started

  This library follows the [PyTorch API](https://pytorch.org/docs/stable/torch.html). There are a few changes to make it more Ruby-like:
@@ -26,9 +30,11 @@ This library follows the [PyTorch API](https://pytorch.org/docs/stable/torch.htm
  - Methods that return booleans use `?` instead of `is_` (`tensor?` instead of `is_tensor`)
  - Numo is used instead of NumPy (`x.numo` instead of `x.numpy()`)

- Many methods and options are missing at the moment. PRs welcome!
+ Some methods and options are missing at the moment. PRs welcome!

- Some examples below are from [Deep Learning with PyTorch: A 60 Minutes Blitz](https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html).
+ ## Tutorial
+
+ Some examples below are from [Deep Learning with PyTorch: A 60 Minutes Blitz](https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html)

  ### Tensors

@@ -143,7 +149,7 @@ Convert a Numo array to a tensor

  ```ruby
  b = Numo::NArray.cast([1, 2, 3])
- Torch.from_numpy(b)
+ Torch.from_numo(b)
  ```

  ### Autograd
@@ -171,17 +177,17 @@ out.backward
  Get gradients

  ```ruby
- x.grad
+ x.grad # tensor([[4.5, 4.5], [4.5, 4.5]])
  ```

  Stop autograd from tracking history

  ```ruby
  x.requires_grad # true
- (x ** 2).requires_grad # true
+ (x**2).requires_grad # true

  Torch.no_grad do
- (x ** 2).requires_grad # false
+ (x**2).requires_grad # false
  end
  ```

@@ -221,7 +227,7 @@ class Net < Torch::NN::Module
  end
  ```

- And run
+ Create an instance of it

  ```ruby
  net = Net.new
@@ -229,6 +235,58 @@ input = Torch.randn(1, 1, 32, 32)
  net.call(input)
  ```

+ Get trainable parameters
+
+ ```ruby
+ net.parameters
+ ```
+
+ Zero the gradient buffers and backprop with random gradients
+
+ ```ruby
+ net.zero_grad
+ out.backward(Torch.randn(1, 10))
+ ```
+
+ Define a loss function
+
+ ```ruby
+ output = net.call(input)
+ target = Torch.randn(10)
+ target = target.view(1, -1)
+ criterion = Torch::NN::MSELoss.new
+ loss = criterion.call(output, target)
+ ```
+
+ Backprop
+
+ ```ruby
+ net.zero_grad
+ p net.conv1.bias.grad
+ loss.backward
+ p net.conv1.bias.grad
+ ```
+
+ Update the weights
+
+ ```ruby
+ learning_rate = 0.01
+ net.parameters.each do |f|
+   f.data.sub!(f.grad.data * learning_rate)
+ end
+ ```
+
+ Use an optimizer
+
+ ```ruby
+ optimizer = Torch::Optim::SGD.new(net.parameters, lr: 0.01)
+ optimizer.zero_grad
+ output = net.call(input)
+ loss = criterion.call(output, target)
+ loss.backward
+ optimizer.step
+ ```
+
  ### Tensor Creation

  Here’s a list of functions to create tensors (descriptions from the [C++ docs](https://pytorch.org/cppdocs/notes/tensor_creation.html)):
@@ -242,7 +300,7 @@ Here’s a list of functions to create tensors (descriptions from the [C++ docs]
  - `empty` returns a tensor with uninitialized values

  ```ruby
- Torch.empty(3)
+ Torch.empty(3) # tensor([7.0054e-45, 0.0000e+00, 0.0000e+00])
  ```

  - `eye` returns an identity matrix
@@ -278,19 +336,19 @@ Here’s a list of functions to create tensors (descriptions from the [C++ docs]
  - `rand` returns a tensor filled with values drawn from a uniform distribution on [0, 1)

  ```ruby
- Torch.rand(3)
+ Torch.rand(3) # tensor([0.5444, 0.8799, 0.5571])
  ```

  - `randint` returns a tensor with integers randomly drawn from an interval

  ```ruby
- Torch.randint(1, 10, [3])
+ Torch.randint(1, 10, [3]) # tensor([7, 6, 4])
  ```

  - `randn` returns a tensor filled with values drawn from a unit normal distribution

  ```ruby
- Torch.randn(3)
+ Torch.randn(3) # tensor([-0.7147, 0.6614, 1.1453])
  ```

  - `randperm` returns a tensor filled with a random permutation of integers in some interval
@@ -305,12 +363,20 @@ Here’s a list of functions to create tensors (descriptions from the [C++ docs]
  Torch.zeros(3) # tensor([0, 0, 0])
  ```

+ ## Examples
+
+ Here are a few full examples:
+
+ - [Image classification with MNIST](examples/mnist) ([日本語版](https://qiita.com/kojix2/items/c19c36dc1bf73ea93409))
+ - [Collaborative filtering with MovieLens](examples/movielens)
+ - [Word embeddings](examples/nlp)
+
  ## LibTorch Installation

- [Download LibTorch](https://pytorch.org/) and run:
+ [Download LibTorch](https://pytorch.org/). For Linux, use the `cxx11 ABI` version. Then run:

  ```sh
- gem install torch-rb -- --with-torch-dir=/path/to/libtorch
+ bundle config build.torch-rb --with-torch-dir=/path/to/libtorch
  ```

@@ -318,10 +384,10 @@ gem install torch-rb -- --with-torch-dir=/path/to/libtorch
  For Mac, you can use Homebrew.

  ```sh
- brew install ankane/brew/libtorch
+ brew install libtorch
  ```

- Then install the gem (no need for `--with-torch-dir`).
+ Then install the gem (no need for `bundle config`).

  ## rbenv

@@ -349,9 +415,9 @@ To get started with development:

  ```sh
  git clone https://github.com/ankane/torch-rb.git
- cd torch
+ cd torch-rb
  bundle install
- bundle exec rake compile
+ bundle exec rake compile -- --with-torch-dir=/path/to/libtorch
  bundle exec rake test
  ```

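To make the Ruby-like conventions described in the README concrete, here is a short sketch. It is inferred from the conventions the README lists (`?` predicates, `numo`/`from_numo`); the printed values are illustrative, not captured output:

```ruby
require "torch"

x = Torch.tensor([[1, 2], [3, 4]])

# PyTorch: torch.is_tensor(x) -- boolean predicates drop `is_` and end in `?`
Torch.tensor?(x) # true

# PyTorch: x.numpy() -- torch-rb converts to a Numo array instead
n = x.numo       # Numo::NArray with shape [2, 2]

# PyTorch: torch.from_numpy(n) -- round trip back to a tensor
Torch.from_numo(n)
```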
data/ext/torch/ext.cpp CHANGED
@@ -6,95 +6,36 @@
  #include <rice/Class.hpp>
  #include <rice/Constructor.hpp>

- using namespace Rice;
+ #include "templates.hpp"

- template<>
- inline
- long long from_ruby<long long>(Object x)
- {
-   return NUM2LL(x);
- }
-
- template<>
- inline
- Object to_ruby<long long>(long long const & x)
- {
-   return LL2NUM(x);
- }
-
- template<>
- inline
- unsigned long long from_ruby<unsigned long long>(Object x)
- {
-   return NUM2ULL(x);
- }
+ // generated with:
+ // rake generate:functions
+ #include "torch_functions.hpp"
+ #include "tensor_functions.hpp"
+ #include "nn_functions.hpp"

- template<>
- inline
- Object to_ruby<unsigned long long>(unsigned long long const & x)
- {
-   return ULL2NUM(x);
- }
-
- template<>
- inline
- short from_ruby<short>(Object x)
- {
-   return NUM2SHORT(x);
- }
-
- template<>
- inline
- Object to_ruby<short>(short const & x)
- {
-   return INT2NUM(x);
- }
+ using namespace Rice;

- template<>
- inline
- unsigned short from_ruby<unsigned short>(Object x)
- {
-   return NUM2USHORT(x);
+ Object tensor_array(std::tuple<torch::Tensor, torch::Tensor> x) {
+   Array a;
+   a.push(to_ruby<torch::Tensor>(std::get<0>(x)));
+   a.push(to_ruby<torch::Tensor>(std::get<1>(x)));
+   return Object(a);
  }

- template<>
- inline
- Object to_ruby<unsigned short>(unsigned short const & x)
+ extern "C"
+ void Init_ext()
  {
-   return UINT2NUM(x);
- }
+   Module rb_mTorch = define_module("Torch");
+   add_torch_functions(rb_mTorch);

- // need to wrap torch::IntArrayRef() since
- // it doesn't own underlying data
- class IntArrayRef {
-   std::vector<int64_t> vec;
- public:
-   IntArrayRef(Object o) {
-     Array a = Array(o);
-     for (size_t i = 0; i < a.size(); i++) {
-       vec.push_back(from_ruby<int64_t>(a[i]));
-     }
-   }
-   operator torch::IntArrayRef() {
-     return torch::IntArrayRef(vec);
-   }
- };
+   Class rb_cTensor = define_class_under<torch::Tensor>(rb_mTorch, "Tensor");
+   add_tensor_functions(rb_cTensor);

- template<>
- inline
- IntArrayRef from_ruby<IntArrayRef>(Object x)
- {
-   return IntArrayRef(x);
- }
-
- // for now
- typedef float Scalar;
+   Module rb_mNN = define_module_under(rb_mTorch, "NN");
+   add_nn_functions(rb_mNN);

- extern "C"
- void Init_ext()
- {
-   Module rb_mTorch = define_module("Torch")
-     .define_singleton_method(
+   rb_mTorch.define_singleton_method(
      "grad_enabled?",
      *[]() {
        return torch::GradMode::is_enabled();
@@ -104,11 +45,6 @@ void Init_ext()
      *[](bool enabled) {
        torch::GradMode::set_enabled(enabled);
      })
-     .define_singleton_method(
-       "floating_point?",
-       *[](torch::Tensor& input) {
-         return torch::is_floating_point(input);
-       })
    .define_singleton_method(
      "manual_seed",
      *[](uint64_t seed) {
@@ -178,172 +114,117 @@ void Init_ext()
    // begin operations
    .define_singleton_method(
      "_mean",
-     *[](torch::Tensor& input) {
+     *[](Tensor& input) {
        return torch::mean(input);
      })
    .define_singleton_method(
      "_mean_dim",
-     *[](torch::Tensor& input, int64_t dim, bool keepdim) {
+     *[](Tensor& input, int64_t dim, bool keepdim) {
        return torch::mean(input, dim, keepdim);
      })
    .define_singleton_method(
      "_sum",
-     *[](torch::Tensor& input) {
+     *[](Tensor& input) {
        return torch::sum(input);
      })
    .define_singleton_method(
      "_sum_dim",
-     *[](torch::Tensor& input, int64_t dim, bool keepdim) {
+     *[](Tensor& input, int64_t dim, bool keepdim) {
        return torch::sum(input, dim, keepdim);
      })
    .define_singleton_method(
-     "_norm",
-     *[](torch::Tensor& input) {
-       return torch::norm(input);
-     })
-   .define_singleton_method(
-     "_min",
-     *[](torch::Tensor& input) {
-       return torch::min(input);
+     "_max_out",
+     *[](Tensor &max, Tensor &max_indices, const Tensor &input, int64_t dim, bool keepdim) {
+       return tensor_array(torch::_max_out(max, max_indices, input, dim, keepdim));
      })
    .define_singleton_method(
-     "_max",
-     *[](torch::Tensor& input) {
-       return torch::max(input);
+     "_topk",
+     *[](Tensor& input, int64_t k) {
+       return tensor_array(torch::topk(input, k));
      })
    .define_singleton_method(
-     "_exp",
-     *[](torch::Tensor& input) {
-       return torch::exp(input);
+     "_softmax",
+     *[](const Tensor &input, int64_t dim) {
+       return torch::softmax(input, dim);
      })
    .define_singleton_method(
-     "_log",
-     *[](torch::Tensor& input) {
-       return torch::log(input);
+     "_log_softmax",
+     *[](Tensor& input, int64_t dim) {
+       return torch::log_softmax(input, dim);
      })
    .define_singleton_method(
-     "_unsqueeze",
-     *[](torch::Tensor& input, int64_t dim) {
-       return torch::unsqueeze(input, dim);
-     })
-   .define_singleton_method(
-     "_dot",
-     *[](torch::Tensor& input, torch::Tensor& tensor) {
-       return torch::dot(input, tensor);
-     })
-   .define_singleton_method(
-     "_matmul",
-     *[](torch::Tensor& input, torch::Tensor& other) {
-       return torch::matmul(input, other);
-     })
-   .define_singleton_method(
-     "_add",
-     *[](torch::Tensor& input, torch::Tensor& other) {
-       return torch::add(input, other);
-     })
-   .define_singleton_method(
-     "_add_scalar",
-     *[](torch::Tensor& input, float other) {
-       return torch::add(input, other);
-     })
-   .define_singleton_method(
-     "_add_out",
-     *[](torch::Tensor& out, torch::Tensor& input, torch::Tensor& other) {
-       return torch::add_out(out, input, other);
-     })
-   .define_singleton_method(
-     "_sub",
-     *[](torch::Tensor& input, torch::Tensor& other) {
-       return torch::sub(input, other);
-     })
-   .define_singleton_method(
-     "_sub_scalar",
-     *[](torch::Tensor& input, float other) {
-       return torch::sub(input, other);
-     })
-   .define_singleton_method(
-     "_mul",
-     *[](torch::Tensor& input, torch::Tensor& other) {
-       return torch::mul(input, other);
-     })
-   .define_singleton_method(
-     "_mul_scalar",
-     *[](torch::Tensor& input, float other) {
-       return torch::mul(input, other);
-     })
-   .define_singleton_method(
-     "_div",
-     *[](torch::Tensor& input, torch::Tensor& other) {
-       return torch::div(input, other);
-     })
-   .define_singleton_method(
-     "_div_scalar",
-     *[](torch::Tensor& input, float other) {
-       return torch::div(input, other);
-     })
-   .define_singleton_method(
-     "_remainder",
-     *[](torch::Tensor& input, torch::Tensor& other) {
-       return torch::remainder(input, other);
-     })
-   .define_singleton_method(
-     "_remainder_scalar",
-     *[](torch::Tensor& input, float other) {
-       return torch::remainder(input, other);
+     "relu",
+     *[](Tensor& input) {
+       return torch::relu(input);
      })
    .define_singleton_method(
-     "_pow",
-     *[](torch::Tensor& input, Scalar exponent) {
-       return torch::pow(input, exponent);
+     "prelu",
+     *[](torch::Tensor& input, torch::Tensor& weight) {
+       return torch::prelu(input, weight);
      })
    .define_singleton_method(
-     "_neg",
-     *[](torch::Tensor& input) {
-       return torch::neg(input);
+     "leaky_relu",
+     *[](torch::Tensor& input, Scalar negative_slope) {
+       return torch::leaky_relu(input, negative_slope);
      })
    .define_singleton_method(
-     "relu",
-     *[](torch::Tensor& input) {
-       return torch::relu(input);
+     "conv2d",
+     *[](Tensor& input, Tensor& weight, Tensor& bias, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, int64_t groups) {
+       return torch::conv2d(input, weight, bias, stride, padding, dilation, groups);
      })
+   // linear layers
    .define_singleton_method(
-     "conv2d",
-     *[](torch::Tensor& input, torch::Tensor& weight, torch::Tensor& bias) {
-       return torch::conv2d(input, weight, bias);
+     "bilinear",
+     *[](const Tensor &input1, const Tensor &input2, const Tensor &weight, const Tensor &bias) {
+       return torch::bilinear(input1, input2, weight, bias);
      })
    .define_singleton_method(
      "linear",
-     *[](torch::Tensor& input, torch::Tensor& weight, torch::Tensor& bias) {
+     *[](Tensor& input, Tensor& weight, Tensor& bias) {
        return torch::linear(input, weight, bias);
      })
+   // pooling layers
    .define_singleton_method(
      "max_pool2d",
-     *[](torch::Tensor& input, IntArrayRef kernel_size) {
+     *[](Tensor& input, IntArrayRef kernel_size) {
        return torch::max_pool2d(input, kernel_size);
      })
    .define_singleton_method(
-     "mse_loss",
-     *[](torch::Tensor& input, torch::Tensor& target, std::string reduction) {
-       auto red = reduction == "mean" ? Reduction::Mean : Reduction::Sum;
-       return torch::mse_loss(input, target, red);
+     "avg_pool2d",
+     *[](Tensor& input, IntArrayRef kernel_size) {
+       return torch::avg_pool2d(input, kernel_size);
+     })
+   .define_singleton_method(
+     "_binary_cross_entropy_with_logits",
+     *[](const Tensor &input, const Tensor &target, OptionalTensor weight, OptionalTensor pos_weight, MyReduction reduction) {
+       return torch::binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction);
      })
+   .define_singleton_method("numel", &torch::numel)
    .define_singleton_method(
-     "nll_loss",
-     *[](torch::Tensor& input, torch::Tensor& target) {
-       return torch::nll_loss(input, target);
+     "_from_blob",
+     *[](String s, IntArrayRef size, const torch::TensorOptions &options) {
+       void *data = const_cast<char *>(s.c_str());
+       return torch::from_blob(data, size, options);
      })
    .define_singleton_method(
      "_tensor",
      *[](Object o, IntArrayRef size, const torch::TensorOptions &options) {
        Array a = Array(o);
-       std::vector<float> vec;
-       for (size_t i = 0; i < a.size(); i++) {
-         vec.push_back(from_ruby<float>(a[i]));
+       auto dtype = options.dtype();
+       torch::Tensor t;
+       if (dtype == torch::kBool) {
+         throw std::runtime_error("Cannot create bool from tensor method yet");
+       } else {
+         std::vector<float> vec;
+         for (size_t i = 0; i < a.size(); i++) {
+           vec.push_back(from_ruby<float>(a[i]));
+         }
+         t = torch::tensor(vec, options);
        }
-       return torch::tensor(vec, options).reshape(size);
+       return t.reshape(size);
      });

-   Class rb_cTensor = define_class_under<torch::Tensor>(rb_mTorch, "Tensor")
+   rb_cTensor
      .define_method("cuda?", &torch::Tensor::is_cuda)
      .define_method("distributed?", &torch::Tensor::is_distributed)
      .define_method("complex?", &torch::Tensor::is_complex)
@@ -352,108 +233,162 @@ void Init_ext()
      .define_method("sparse?", &torch::Tensor::is_sparse)
      .define_method("quantized?", &torch::Tensor::is_quantized)
      .define_method("dim", &torch::Tensor::dim)
-     .define_method("numel", &torch::Tensor::numel)
      .define_method("element_size", &torch::Tensor::element_size)
      .define_method("requires_grad", &torch::Tensor::requires_grad)
+     .define_method("view_as", &torch::Tensor::view_as)
+     .define_method(
+       "addcmul!",
+       *[](Tensor& self, Scalar value, const Tensor & tensor1, const Tensor & tensor2) {
+         return self.addcmul_(tensor1, tensor2, value);
+       })
+     .define_method(
+       "addcdiv!",
+       *[](Tensor& self, Scalar value, const Tensor & tensor1, const Tensor & tensor2) {
+         return self.addcdiv_(tensor1, tensor2, value);
+       })
      .define_method(
        "zero!",
-       *[](torch::Tensor& self) {
+       *[](Tensor& self) {
          return self.zero_();
        })
      .define_method(
-       "detach!",
-       *[](torch::Tensor& self) {
-         return self.detach_();
+       "detach",
+       *[](Tensor& self) {
+         return self.detach();
        })
      .define_method(
-       "_access",
-       *[](torch::Tensor& self, int64_t index) {
-         return self[index];
+       "detach!",
+       *[](Tensor& self) {
+         return self.detach_();
        })
      .define_method(
        "_requires_grad!",
-       *[](torch::Tensor& self, bool requires_grad) {
+       *[](Tensor& self, bool requires_grad) {
          return self.set_requires_grad(requires_grad);
        })
      .define_method(
-       "backward",
-       *[](torch::Tensor& self) {
-         return self.backward();
+       "_backward",
+       *[](Tensor& self, Object gradient) {
+         return gradient.is_nil() ? self.backward() : self.backward(from_ruby<torch::Tensor>(gradient));
        })
      .define_method(
        "grad",
-       *[](torch::Tensor& self) {
+       *[](Tensor& self) {
          return self.grad();
        })
      .define_method(
        "_dtype",
-       *[](torch::Tensor& self) {
+       *[](Tensor& self) {
          return (int) at::typeMetaToScalarType(self.dtype());
        })
+     .define_method(
+       "_type",
+       *[](Tensor& self, int dtype) {
+         return self.toType((torch::ScalarType) dtype);
+       })
      .define_method(
        "_layout",
-       *[](torch::Tensor& self) {
+       *[](Tensor& self) {
          std::stringstream s;
          s << self.layout();
          return s.str();
        })
      .define_method(
        "device",
-       *[](torch::Tensor& self) {
+       *[](Tensor& self) {
          std::stringstream s;
          s << self.device();
          return s.str();
        })
      .define_method(
-       "_view",
-       *[](torch::Tensor& self, IntArrayRef size) {
-         return self.view(size);
+       "resize_as!",
+       *[](Tensor& self, Tensor& other) {
+         return self.resize_as_(other);
        })
      .define_method(
-       "add!",
-       *[](torch::Tensor& self, torch::Tensor& other) {
-         self.add_(other);
+       "fill!",
+       *[](Tensor& self, Scalar value) {
+         return self.fill_(value);
        })
      .define_method(
-       "sub!",
-       *[](torch::Tensor& self, torch::Tensor& other) {
-         self.sub_(other);
+       "relu!",
+       *[](Tensor& self) {
+         return self.relu_();
+       })
+     .define_method(
+       "normal!",
+       *[](Tensor& self, double mean, double std) {
+         return self.normal_(mean, std);
+       })
+     .define_method(
+       "random!",
+       *[](Tensor& self, int64_t to) {
+         return self.random_(to);
        })
      .define_method(
-       "mul!",
-       *[](torch::Tensor& self, torch::Tensor& other) {
-         self.mul_(other);
+       "sub!",
+       *[](Tensor& self, Tensor& other) {
+         return self.sub_(other);
        })
      .define_method(
        "div!",
-       *[](torch::Tensor& self, torch::Tensor& other) {
-         self.div_(other);
+       *[](Tensor& self, Tensor& other) {
+         return self.div_(other);
        })
      .define_method(
-       "log_softmax",
-       *[](torch::Tensor& self, int64_t dim) {
-         return self.log_softmax(dim);
+       "sqrt!",
+       *[](Tensor& self) {
+         return self.sqrt_();
        })
      .define_method(
-       "_data",
-       *[](torch::Tensor& self) {
+       "unsqueeze!",
+       *[](Tensor& self, int64_t dim) {
+         return self.unsqueeze_(dim);
+       })
+     .define_method(
+       "copy!",
+       *[](Tensor& self, Tensor& src) {
+         return self.copy_(src);
+       })
+     .define_method(
+       "clone",
+       *[](Tensor& self) {
+         return self.clone();
+       })
+     .define_method(
+       "data",
+       *[](Tensor& self) {
+         return self.data();
+       })
+     .define_method(
+       "_flat_data",
+       *[](Tensor& self) {
         Array a;
         auto dtype = self.dtype();

         // TODO DRY if someone knows C++
-        // TODO kByte (uint8), kChar (int8), kBool (bool)
-        if (dtype == torch::kShort) {
-          short* data = self.data_ptr<short>();
+        if (dtype == torch::kByte) {
+          uint8_t* data = self.data_ptr<uint8_t>();
+          for (int i = 0; i < self.numel(); i++) {
+            a.push(data[i]);
+          }
+        } else if (dtype == torch::kChar) {
+          int8_t* data = self.data_ptr<int8_t>();
+          for (int i = 0; i < self.numel(); i++) {
+            a.push(to_ruby<int>(data[i]));
+          }
+        } else if (dtype == torch::kShort) {
+          int16_t* data = self.data_ptr<int16_t>();
          for (int i = 0; i < self.numel(); i++) {
            a.push(data[i]);
          }
        } else if (dtype == torch::kInt) {
-          int* data = self.data_ptr<int>();
+          int32_t* data = self.data_ptr<int32_t>();
          for (int i = 0; i < self.numel(); i++) {
            a.push(data[i]);
          }
        } else if (dtype == torch::kLong) {
-          long long* data = self.data_ptr<long long>();
+          int64_t* data = self.data_ptr<int64_t>();
          for (int i = 0; i < self.numel(); i++) {
            a.push(data[i]);
          }
@@ -467,19 +402,24 @@ void Init_ext()
          for (int i = 0; i < self.numel(); i++) {
            a.push(data[i]);
          }
+        } else if (dtype == torch::kBool) {
+          bool* data = self.data_ptr<bool>();
+          for (int i = 0; i < self.numel(); i++) {
+            a.push(data[i] ? True : False);
+          }
        } else {
-          throw "Unsupported type";
+          throw std::runtime_error("Unsupported type");
        }
        return a;
      })
      .define_method(
-       "_size",
-       *[](torch::Tensor& self, int i) {
-         return self.size(i);
+       "_to",
+       *[](Tensor& self, torch::Device device, int dtype, bool non_blocking, bool copy) {
+         return self.to(device, (torch::ScalarType) dtype, non_blocking, copy);
        })
      .define_singleton_method(
        "_make_subclass",
-       *[](torch::Tensor& rd, bool requires_grad) {
+       *[](Tensor& rd, bool requires_grad) {
         auto data = torch::autograd::as_variable_ref(rd).detach();
         data.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(true);
         auto var = data.set_requires_grad(requires_grad);
@@ -499,8 +439,11 @@ void Init_ext()
        torch::Layout l;
        if (layout == "strided") {
          l = torch::kStrided;
+       } else if (layout == "sparse") {
+         l = torch::kSparse;
+         throw std::runtime_error("Sparse layout not supported yet");
        } else {
-         throw "Unsupported layout";
+         throw std::runtime_error("Unsupported layout: " + layout);
        }
        return self.layout(l);
      })
@@ -513,7 +456,7 @@ void Init_ext()
        } else if (device == "cuda") {
          d = torch::kCUDA;
        } else {
-         throw "Unsupported device";
+         throw std::runtime_error("Unsupported device: " + device);
        }
        return self.device(d);
      })
@@ -523,24 +466,99 @@ void Init_ext()
        return self.requires_grad(requires_grad);
      });

-   Module rb_mNN = define_module_under(rb_mTorch, "NN");
-
    Module rb_mInit = define_module_under(rb_mNN, "Init")
      .define_singleton_method(
-       "kaiming_uniform_",
-       *[](torch::Tensor& input, double a) {
-         return torch::nn::init::kaiming_uniform_(input, a);
+       "_calculate_gain",
+       *[](NonlinearityType nonlinearity, double param) {
+         return torch::nn::init::calculate_gain(nonlinearity, param);
        })
      .define_singleton_method(
-       "uniform_",
-       *[](torch::Tensor& input, double to, double from) {
-         return torch::nn::init::uniform_(input, to, from);
+       "_uniform!",
+       *[](Tensor tensor, double low, double high) {
+         return torch::nn::init::uniform_(tensor, low, high);
+       })
+     .define_singleton_method(
+       "_normal!",
+       *[](Tensor tensor, double mean, double std) {
+         return torch::nn::init::normal_(tensor, mean, std);
+       })
+     .define_singleton_method(
+       "_constant!",
+       *[](Tensor tensor, Scalar value) {
+         return torch::nn::init::constant_(tensor, value);
+       })
+     .define_singleton_method(
+       "_ones!",
+       *[](Tensor tensor) {
+         return torch::nn::init::ones_(tensor);
+       })
+     .define_singleton_method(
+       "_zeros!",
+       *[](Tensor tensor) {
+         return torch::nn::init::zeros_(tensor);
+       })
+     .define_singleton_method(
+       "_eye!",
+       *[](Tensor tensor) {
+         return torch::nn::init::eye_(tensor);
+       })
+     .define_singleton_method(
+       "_dirac!",
+       *[](Tensor tensor) {
+         return torch::nn::init::dirac_(tensor);
+       })
+     .define_singleton_method(
+       "_xavier_uniform!",
+       *[](Tensor tensor, double gain) {
+         return torch::nn::init::xavier_uniform_(tensor, gain);
+       })
+     .define_singleton_method(
+       "_xavier_normal!",
+       *[](Tensor tensor, double gain) {
+         return torch::nn::init::xavier_normal_(tensor, gain);
+       })
+     .define_singleton_method(
+       "_kaiming_uniform!",
+       *[](Tensor tensor, double a, FanModeType mode, NonlinearityType nonlinearity) {
+         return torch::nn::init::kaiming_uniform_(tensor, a, mode, nonlinearity);
+       })
+     .define_singleton_method(
+       "_kaiming_normal!",
+       *[](Tensor tensor, double a, FanModeType mode, NonlinearityType nonlinearity) {
+         return torch::nn::init::kaiming_normal_(tensor, a, mode, nonlinearity);
+       })
+     .define_singleton_method(
+       "_orthogonal!",
+       *[](Tensor tensor, double gain) {
+         return torch::nn::init::orthogonal_(tensor, gain);
+       })
+     .define_singleton_method(
+       "_sparse!",
+       *[](Tensor tensor, double sparsity, double std) {
+         return torch::nn::init::sparse_(tensor, sparsity, std);
        });

    Class rb_cParameter = define_class_under<torch::autograd::Variable, torch::Tensor>(rb_mNN, "Parameter")
      .define_method(
        "grad",
        *[](torch::autograd::Variable& self) {
-         return self.grad();
+         auto grad = self.grad();
+         return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
+       });
+
+   Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")
+     .define_constructor(Constructor<torch::Device, std::string>())
+     .define_method("index", &torch::Device::index)
+     .define_method("index?", &torch::Device::has_index)
+     .define_method(
+       "type",
+       *[](torch::Device& self) {
+         std::stringstream s;
+         s << self.type();
+         return s.str();
        });
+
+   Module rb_mCUDA = define_module_under(rb_mTorch, "CUDA")
+     .define_singleton_method("available?", &torch::cuda::is_available)
+     .define_singleton_method("device_count", &torch::cuda::device_count);
  }
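
From Ruby, the new `Device` and `CUDA` bindings defined at the end of this diff can be exercised roughly as follows. This is a sketch based only on the methods bound above; the return values are illustrative:

```ruby
require "torch"

# Torch::Device gains a string constructor plus index/type accessors.
device = Torch::Device.new("cpu")
device.type    # "cpu"
device.index?  # false, since no index was given

# Torch::CUDA wraps torch::cuda::is_available and device_count.
Torch::CUDA.available?    # false without a CUDA-enabled LibTorch
Torch::CUDA.device_count  # 0 on a CPU-only machine
```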