torch-rb 0.1.1 → 0.1.6

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (142)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +40 -0
  3. data/LICENSE.txt +46 -22
  4. data/README.md +73 -9
  5. data/ext/torch/ext.cpp +148 -315
  6. data/ext/torch/extconf.rb +6 -0
  7. data/ext/torch/nn_functions.cpp +615 -0
  8. data/ext/torch/nn_functions.hpp +6 -0
  9. data/ext/torch/templates.cpp +55 -0
  10. data/ext/torch/templates.hpp +298 -0
  11. data/ext/torch/tensor_functions.cpp +1920 -0
  12. data/ext/torch/tensor_functions.hpp +6 -0
  13. data/ext/torch/torch_functions.cpp +2975 -0
  14. data/ext/torch/torch_functions.hpp +6 -0
  15. data/lib/torch.rb +236 -112
  16. data/lib/torch/ext.bundle +0 -0
  17. data/lib/torch/inspector.rb +52 -25
  18. data/lib/torch/native/dispatcher.rb +48 -0
  19. data/lib/torch/native/function.rb +109 -0
  20. data/lib/torch/native/generator.rb +168 -0
  21. data/lib/torch/native/native_functions.yaml +6837 -0
  22. data/lib/torch/native/parser.rb +134 -0
  23. data/lib/torch/nn/alpha_dropout.rb +9 -0
  24. data/lib/torch/nn/avg_pool1d.rb +18 -0
  25. data/lib/torch/nn/avg_pool2d.rb +19 -0
  26. data/lib/torch/nn/avg_pool3d.rb +19 -0
  27. data/lib/torch/nn/avg_poolnd.rb +9 -0
  28. data/lib/torch/nn/batch_norm.rb +75 -0
  29. data/lib/torch/nn/batch_norm1d.rb +11 -0
  30. data/lib/torch/nn/batch_norm2d.rb +11 -0
  31. data/lib/torch/nn/batch_norm3d.rb +11 -0
  32. data/lib/torch/nn/bce_loss.rb +13 -0
  33. data/lib/torch/nn/bce_with_logits_loss.rb +15 -0
  34. data/lib/torch/nn/bilinear.rb +38 -0
  35. data/lib/torch/nn/constant_pad1d.rb +10 -0
  36. data/lib/torch/nn/constant_pad2d.rb +10 -0
  37. data/lib/torch/nn/constant_pad3d.rb +10 -0
  38. data/lib/torch/nn/constant_padnd.rb +18 -0
  39. data/lib/torch/nn/conv1d.rb +22 -0
  40. data/lib/torch/nn/conv2d.rb +16 -39
  41. data/lib/torch/nn/conv3d.rb +22 -0
  42. data/lib/torch/nn/convnd.rb +41 -0
  43. data/lib/torch/nn/cosine_embedding_loss.rb +14 -0
  44. data/lib/torch/nn/cosine_similarity.rb +15 -0
  45. data/lib/torch/nn/cross_entropy_loss.rb +14 -0
  46. data/lib/torch/nn/ctc_loss.rb +15 -0
  47. data/lib/torch/nn/dropout.rb +9 -0
  48. data/lib/torch/nn/dropout2d.rb +9 -0
  49. data/lib/torch/nn/dropout3d.rb +9 -0
  50. data/lib/torch/nn/dropoutnd.rb +15 -0
  51. data/lib/torch/nn/embedding.rb +52 -0
  52. data/lib/torch/nn/embedding_bag.rb +34 -0
  53. data/lib/torch/nn/feature_alpha_dropout.rb +9 -0
  54. data/lib/torch/nn/fold.rb +20 -0
  55. data/lib/torch/nn/functional.rb +419 -16
  56. data/lib/torch/nn/group_norm.rb +36 -0
  57. data/lib/torch/nn/gru.rb +49 -0
  58. data/lib/torch/nn/hardshrink.rb +18 -0
  59. data/lib/torch/nn/hinge_embedding_loss.rb +14 -0
  60. data/lib/torch/nn/identity.rb +14 -0
  61. data/lib/torch/nn/init.rb +58 -1
  62. data/lib/torch/nn/instance_norm.rb +20 -0
  63. data/lib/torch/nn/instance_norm1d.rb +18 -0
  64. data/lib/torch/nn/instance_norm2d.rb +11 -0
  65. data/lib/torch/nn/instance_norm3d.rb +11 -0
  66. data/lib/torch/nn/kl_div_loss.rb +13 -0
  67. data/lib/torch/nn/l1_loss.rb +13 -0
  68. data/lib/torch/nn/layer_norm.rb +35 -0
  69. data/lib/torch/nn/leaky_relu.rb +20 -0
  70. data/lib/torch/nn/linear.rb +12 -11
  71. data/lib/torch/nn/local_response_norm.rb +21 -0
  72. data/lib/torch/nn/log_sigmoid.rb +9 -0
  73. data/lib/torch/nn/log_softmax.rb +14 -0
  74. data/lib/torch/nn/loss.rb +10 -0
  75. data/lib/torch/nn/lp_pool1d.rb +9 -0
  76. data/lib/torch/nn/lp_pool2d.rb +9 -0
  77. data/lib/torch/nn/lp_poolnd.rb +22 -0
  78. data/lib/torch/nn/lstm.rb +66 -0
  79. data/lib/torch/nn/margin_ranking_loss.rb +14 -0
  80. data/lib/torch/nn/max_pool1d.rb +9 -0
  81. data/lib/torch/nn/max_pool2d.rb +9 -0
  82. data/lib/torch/nn/max_pool3d.rb +9 -0
  83. data/lib/torch/nn/max_poolnd.rb +19 -0
  84. data/lib/torch/nn/max_unpool1d.rb +16 -0
  85. data/lib/torch/nn/max_unpool2d.rb +16 -0
  86. data/lib/torch/nn/max_unpool3d.rb +16 -0
  87. data/lib/torch/nn/max_unpoolnd.rb +9 -0
  88. data/lib/torch/nn/module.rb +191 -19
  89. data/lib/torch/nn/mse_loss.rb +2 -2
  90. data/lib/torch/nn/multi_label_margin_loss.rb +13 -0
  91. data/lib/torch/nn/multi_label_soft_margin_loss.rb +13 -0
  92. data/lib/torch/nn/multi_margin_loss.rb +17 -0
  93. data/lib/torch/nn/nll_loss.rb +14 -0
  94. data/lib/torch/nn/pairwise_distance.rb +16 -0
  95. data/lib/torch/nn/parameter.rb +4 -0
  96. data/lib/torch/nn/poisson_nll_loss.rb +16 -0
  97. data/lib/torch/nn/prelu.rb +19 -0
  98. data/lib/torch/nn/reflection_pad1d.rb +10 -0
  99. data/lib/torch/nn/reflection_pad2d.rb +10 -0
  100. data/lib/torch/nn/reflection_padnd.rb +13 -0
  101. data/lib/torch/nn/relu.rb +8 -3
  102. data/lib/torch/nn/replication_pad1d.rb +10 -0
  103. data/lib/torch/nn/replication_pad2d.rb +10 -0
  104. data/lib/torch/nn/replication_pad3d.rb +10 -0
  105. data/lib/torch/nn/replication_padnd.rb +13 -0
  106. data/lib/torch/nn/rnn.rb +22 -0
  107. data/lib/torch/nn/rnn_base.rb +198 -0
  108. data/lib/torch/nn/sequential.rb +1 -10
  109. data/lib/torch/nn/sigmoid.rb +9 -0
  110. data/lib/torch/nn/smooth_l1_loss.rb +13 -0
  111. data/lib/torch/nn/soft_margin_loss.rb +13 -0
  112. data/lib/torch/nn/softmax.rb +18 -0
  113. data/lib/torch/nn/softmax2d.rb +10 -0
  114. data/lib/torch/nn/softmin.rb +14 -0
  115. data/lib/torch/nn/softplus.rb +19 -0
  116. data/lib/torch/nn/softshrink.rb +18 -0
  117. data/lib/torch/nn/softsign.rb +9 -0
  118. data/lib/torch/nn/tanh.rb +9 -0
  119. data/lib/torch/nn/tanhshrink.rb +9 -0
  120. data/lib/torch/nn/triplet_margin_loss.rb +18 -0
  121. data/lib/torch/nn/unfold.rb +19 -0
  122. data/lib/torch/nn/utils.rb +25 -0
  123. data/lib/torch/nn/weighted_loss.rb +10 -0
  124. data/lib/torch/nn/zero_pad2d.rb +9 -0
  125. data/lib/torch/optim/adadelta.rb +57 -0
  126. data/lib/torch/optim/adagrad.rb +71 -0
  127. data/lib/torch/optim/adam.rb +81 -0
  128. data/lib/torch/optim/adamax.rb +68 -0
  129. data/lib/torch/optim/adamw.rb +82 -0
  130. data/lib/torch/optim/asgd.rb +65 -0
  131. data/lib/torch/optim/lr_scheduler/lr_scheduler.rb +33 -0
  132. data/lib/torch/optim/lr_scheduler/step_lr.rb +17 -0
  133. data/lib/torch/optim/optimizer.rb +62 -0
  134. data/lib/torch/optim/rmsprop.rb +76 -0
  135. data/lib/torch/optim/rprop.rb +68 -0
  136. data/lib/torch/optim/sgd.rb +60 -0
  137. data/lib/torch/random.rb +10 -0
  138. data/lib/torch/tensor.rb +90 -30
  139. data/lib/torch/utils/data/data_loader.rb +15 -0
  140. data/lib/torch/utils/data/tensor_dataset.rb +8 -1
  141. data/lib/torch/version.rb +1 -1
  142. metadata +122 -3
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
- metadata.gz: ff0920ba955063c03309fdb45ecf228b51c556508bea30b510d6bf652c1d0b18
- data.tar.gz: 481dccf6a8e929230033f74c82bc9d292ef38ea219e2cb2cc61ca0b0c5457403
+ metadata.gz: 9667f9d3256f5e2d39937f17ae8eb00449dd14f79bb01cd647800bd7ed972fc6
+ data.tar.gz: 54c23612c79355e09c97da5fcf6b97c183da8316d1c2a53d6f8f0463e98342a2
  SHA512:
- metadata.gz: cd6c8fd9db4af15640217c09813c4f86d3f66360202d30711015c4f34552853ef281d5614fd78dc274d405da0d9f46f08a2359475ae1b0721143db49183faf5d
- data.tar.gz: ee638c08458e0d2a8fac52e29c45d1347a74847ca7d8dab3a9a573afd887814d4c578c4e1a7fb80b204222785b27e21ca3c138f3face681e98e63c1bc02a9a7f
+ metadata.gz: bb2c8e5aae436367aeb871a2d19958e59ed9e9c7601b1b8b4473e33094cadf6d657947582b0ec93a29cb08723f8f7c81178a2d50beb23a125d5a356769d92177
+ data.tar.gz: 62feef39da31a19415e2e6c453aed4972e34db7367161a088944c06a977637a8b25cecc8eb2ad052b3b9deee0707f364e616cc33e7674cf0314899421f18fbee
data/CHANGELOG.md CHANGED
@@ -1,3 +1,43 @@
+ ## 0.1.6 (2019-12-09)
+
+ - Added recurrent layers
+ - Added more pooling layers
+ - Added normalization layers
+
+ ## 0.1.5 (2019-12-06)
+
+ - Added many more functions
+ - Added tensor classes - `FloatTensor`, `LongTensor`, etc.
+ - Improved modules
+
+ ## 0.1.4 (2019-12-01)
+
+ - Added distance functions
+ - Added more activations
+ - Added more linear layers
+ - Added more loss functions
+ - Added more init methods
+ - Added support for tensor assignment
+
+ ## 0.1.3 (2019-11-30)
+
+ - Changed to BSD 3-Clause license to match PyTorch
+ - Added many optimizers
+ - Added `StepLR` learning rate scheduler
+ - Added dropout
+ - Added embedding
+ - Added support for `bool` type
+ - Improved performance of `from_numo`
+
+ ## 0.1.2 (2019-11-27)
+
+ - Added SGD optimizer
+ - Added support for passing a gradient to the `backward` method
+ - Added `argmax`, `eq`, `leaky_relu`, `prelu`, and `reshape` methods
+ - Improved indexing
+ - Fixed `zero_grad`
+ - Fixed error with infinite values
+
  ## 0.1.1 (2019-11-26)

  - Added support for `uint8` and `int8` types
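As a quick illustration of the 0.1.2 entry above, non-scalar outputs can now take an explicit gradient argument to `backward` — a hedged sketch, assuming the semantics mirror PyTorch's and using the `requires_grad:` option from the README:

```ruby
# Sketch: backward on a non-scalar output needs an explicit gradient
x = Torch.ones(3, requires_grad: true)
y = x * 2
y.backward(Torch.tensor([0.1, 1.0, 0.0001])) # gradient passed to backward
x.grad # populated by the call above
```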
data/LICENSE.txt CHANGED
@@ -1,22 +1,46 @@
- Copyright (c) 2019 Andrew Kane
-
- MIT License
-
- Permission is hereby granted, free of charge, to any person obtaining
- a copy of this software and associated documentation files (the
- "Software"), to deal in the Software without restriction, including
- without limitation the rights to use, copy, modify, merge, publish,
- distribute, sublicense, and/or sell copies of the Software, and to
- permit persons to whom the Software is furnished to do so, subject to
- the following conditions:
-
- The above copyright notice and this permission notice shall be
- included in all copies or substantial portions of the Software.
-
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
- EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
- MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
- NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
- LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
- OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
- WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+ BSD 3-Clause License
+
+ From Torch-rb:
+
+ Copyright (c) 2019- Andrew Kane
+
+ From PyTorch (for ported code):
+
+ Copyright (c) 2016- Facebook, Inc (Adam Paszke)
+ Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
+ Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
+ Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
+ Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
+ Copyright (c) 2011-2013 NYU (Clement Farabet)
+ Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
+ Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
+ Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
+
+ All rights reserved.
+
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions are met:
+
+ 1. Redistributions of source code must retain the above copyright
+    notice, this list of conditions and the following disclaimer.
+
+ 2. Redistributions in binary form must reproduce the above copyright
+    notice, this list of conditions and the following disclaimer in the
+    documentation and/or other materials provided with the distribution.
+
+ 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
+    and IDIAP Research Institute nor the names of its contributors may be
+    used to endorse or promote products derived from this software without
+    specific prior written permission.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+ LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+ CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+ SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+ INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+ CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+ POSSIBILITY OF SUCH DAMAGE.
data/README.md CHANGED
@@ -2,7 +2,7 @@

  :fire: Deep learning for Ruby, powered by [LibTorch](https://pytorch.org)

- This gem is currently experimental. There may be breaking changes between each release.
+ This gem is currently experimental. There may be breaking changes between each release. Please report any issues you experience.

  [![Build Status](https://travis-ci.org/ankane/torch-rb.svg?branch=master)](https://travis-ci.org/ankane/torch-rb)

@@ -20,6 +20,8 @@ Add this line to your application’s Gemfile:
  gem 'torch-rb'
  ```

+ It can take a few minutes to compile the extension.
+
  ## Getting Started

  This library follows the [PyTorch API](https://pytorch.org/docs/stable/torch.html). There are a few changes to make it more Ruby-like:
@@ -28,9 +30,11 @@ This library follows the [PyTorch API](https://pytorch.org/docs/stable/torch.htm
  - Methods that return booleans use `?` instead of `is_` (`tensor?` instead of `is_tensor`)
  - Numo is used instead of NumPy (`x.numo` instead of `x.numpy()`)

- Many methods and options are missing at the moment. PRs welcome!
+ Some methods and options are missing at the moment. PRs welcome!
+
+ ## Tutorial

- Some examples below are from [Deep Learning with PyTorch: A 60 Minutes Blitz](https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html).
+ Some examples below are from [Deep Learning with PyTorch: A 60 Minutes Blitz](https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html)

  ### Tensors

@@ -145,7 +149,7 @@ Convert a Numo array to a tensor

  ```ruby
  b = Numo::NArray.cast([1, 2, 3])
- Torch.from_numpy(b)
+ Torch.from_numo(b)
  ```

  ### Autograd
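The Numo hunk above shows only one direction; the round trip presumably looks like this (a hedged sketch, assuming `numo` mirrors `numpy()` as stated in Getting Started):

```ruby
# Round trip between Numo and Torch (sketch)
b = Numo::NArray.cast([1, 2, 3])
t = Torch.from_numo(b) # Numo -> tensor
t.numo                 # tensor -> Numo
```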
@@ -180,10 +184,10 @@ Stop autograd from tracking history

  ```ruby
  x.requires_grad # true
- (x ** 2).requires_grad # true
+ (x**2).requires_grad # true

  Torch.no_grad do
-   (x ** 2).requires_grad # false
+   (x**2).requires_grad # false
  end
  ```

@@ -223,7 +227,7 @@ class Net < Torch::NN::Module
  end
  ```

- And run
+ Create an instance of it

  ```ruby
  net = Net.new
@@ -231,6 +235,58 @@ input = Torch.randn(1, 1, 32, 32)
  net.call(input)
  ```

+ Get trainable parameters
+
+ ```ruby
+ net.parameters
+ ```
+
+ Zero the gradient buffers and backprop with random gradients
+
+ ```ruby
+ net.zero_grad
+ out.backward(Torch.randn(1, 10))
+ ```
+
+ Define a loss function
+
+ ```ruby
+ output = net.call(input)
+ target = Torch.randn(10)
+ target = target.view(1, -1)
+ criterion = Torch::NN::MSELoss.new
+ loss = criterion.call(output, target)
+ ```
+
+ Backprop
+
+ ```ruby
+ net.zero_grad
+ p net.conv1.bias.grad
+ loss.backward
+ p net.conv1.bias.grad
+ ```
+
+ Update the weights
+
+ ```ruby
+ learning_rate = 0.01
+ net.parameters.each do |f|
+   f.data.sub!(f.grad.data * learning_rate)
+ end
+ ```
+
+ Use an optimizer
+
+ ```ruby
+ optimizer = Torch::Optim::SGD.new(net.parameters, lr: 0.01)
+ optimizer.zero_grad
+ output = net.call(input)
+ loss = criterion.call(output, target)
+ loss.backward
+ optimizer.step
+ ```
+
  ### Tensor Creation

  Here’s a list of functions to create tensors (descriptions from the [C++ docs](https://pytorch.org/cppdocs/notes/tensor_creation.html)):
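Strung together, the steps added in the hunk above form a standard training loop. A minimal hedged sketch using only the API shown (note the `out.backward` fragment assumes an earlier `out = net.call(input)`):

```ruby
# Sketch: one complete training loop built from the README fragments above
net = Net.new
criterion = Torch::NN::MSELoss.new
optimizer = Torch::Optim::SGD.new(net.parameters, lr: 0.01)

10.times do
  optimizer.zero_grad                   # clear accumulated gradients
  output = net.call(input)              # forward pass
  loss = criterion.call(output, target) # compute loss
  loss.backward                         # backprop
  optimizer.step                        # update weights
end
```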
@@ -307,6 +363,14 @@ Here’s a list of functions to create tensors (descriptions from the [C++ docs]
  Torch.zeros(3) # tensor([0, 0, 0])
  ```

+ ## Examples
+
+ Here are a few full examples:
+
+ - [Image classification with MNIST](examples/mnist) ([Japanese version](https://qiita.com/kojix2/items/c19c36dc1bf73ea93409))
+ - [Collaborative filtering with MovieLens](examples/movielens)
+ - [Sequence models and word embeddings](examples/nlp)
+
  ## LibTorch Installation

  [Download LibTorch](https://pytorch.org/). For Linux, use the `cxx11 ABI` version. Then run:
@@ -351,9 +415,9 @@ To get started with development:

  ```sh
  git clone https://github.com/ankane/torch-rb.git
- cd torch
+ cd torch-rb
  bundle install
- bundle exec rake compile
+ bundle exec rake compile -- --with-torch-dir=/path/to/libtorch
  bundle exec rake test
  ```

data/ext/torch/ext.cpp CHANGED
@@ -6,95 +6,29 @@
  #include <rice/Class.hpp>
  #include <rice/Constructor.hpp>

- using namespace Rice;
-
- template<>
- inline
- long long from_ruby<long long>(Object x)
- {
-   return NUM2LL(x);
- }
-
- template<>
- inline
- Object to_ruby<long long>(long long const & x)
- {
-   return LL2NUM(x);
- }
+ #include "templates.hpp"

- template<>
- inline
- unsigned long long from_ruby<unsigned long long>(Object x)
- {
-   return NUM2ULL(x);
- }
+ // generated with:
+ // rake generate:functions
+ #include "torch_functions.hpp"
+ #include "tensor_functions.hpp"
+ #include "nn_functions.hpp"

- template<>
- inline
- Object to_ruby<unsigned long long>(unsigned long long const & x)
- {
-   return ULL2NUM(x);
- }
-
- template<>
- inline
- short from_ruby<short>(Object x)
- {
-   return NUM2SHORT(x);
- }
-
- template<>
- inline
- Object to_ruby<short>(short const & x)
- {
-   return INT2NUM(x);
- }
-
- template<>
- inline
- unsigned short from_ruby<unsigned short>(Object x)
- {
-   return NUM2USHORT(x);
- }
+ using namespace Rice;

- template<>
- inline
- Object to_ruby<unsigned short>(unsigned short const & x)
+ extern "C"
+ void Init_ext()
  {
-   return UINT2NUM(x);
- }
+   Module rb_mTorch = define_module("Torch");
+   add_torch_functions(rb_mTorch);

- // need to wrap torch::IntArrayRef() since
- // it doesn't own underlying data
- class IntArrayRef {
-   std::vector<int64_t> vec;
-   public:
-     IntArrayRef(Object o) {
-       Array a = Array(o);
-       for (size_t i = 0; i < a.size(); i++) {
-         vec.push_back(from_ruby<int64_t>(a[i]));
-       }
-     }
-     operator torch::IntArrayRef() {
-       return torch::IntArrayRef(vec);
-     }
- };
+   Class rb_cTensor = define_class_under<torch::Tensor>(rb_mTorch, "Tensor");
+   add_tensor_functions(rb_cTensor);

- template<>
- inline
- IntArrayRef from_ruby<IntArrayRef>(Object x)
- {
-   return IntArrayRef(x);
- }
-
- // for now
- typedef float Scalar;
+   Module rb_mNN = define_module_under(rb_mTorch, "NN");
+   add_nn_functions(rb_mNN);

- extern "C"
- void Init_ext()
- {
-   Module rb_mTorch = define_module("Torch")
-     .define_singleton_method(
+   rb_mTorch.define_singleton_method(
      "grad_enabled?",
      *[]() {
        return torch::GradMode::is_enabled();
@@ -104,11 +38,6 @@ void Init_ext()
      *[](bool enabled) {
        torch::GradMode::set_enabled(enabled);
      })
-   .define_singleton_method(
-     "floating_point?",
-     *[](torch::Tensor& input) {
-       return torch::is_floating_point(input);
-     })
    .define_singleton_method(
      "manual_seed",
      *[](uint64_t seed) {
@@ -177,266 +106,93 @@ void Init_ext()
      })
    // begin operations
    .define_singleton_method(
-     "_mean",
-     *[](torch::Tensor& input) {
-       return torch::mean(input);
-     })
-   .define_singleton_method(
-     "_mean_dim",
-     *[](torch::Tensor& input, int64_t dim, bool keepdim) {
-       return torch::mean(input, dim, keepdim);
+     "_binary_cross_entropy_with_logits",
+     *[](const Tensor &input, const Tensor &target, OptionalTensor weight, OptionalTensor pos_weight, MyReduction reduction) {
+       return torch::binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction);
      })
    .define_singleton_method(
-     "_sum",
-     *[](torch::Tensor& input) {
-       return torch::sum(input);
-     })
-   .define_singleton_method(
-     "_sum_dim",
-     *[](torch::Tensor& input, int64_t dim, bool keepdim) {
-       return torch::sum(input, dim, keepdim);
-     })
-   .define_singleton_method(
-     "_norm",
-     *[](torch::Tensor& input) {
-       return torch::norm(input);
-     })
-   .define_singleton_method(
-     "_min",
-     *[](torch::Tensor& input) {
-       return torch::min(input);
-     })
-   .define_singleton_method(
-     "_max",
-     *[](torch::Tensor& input) {
-       return torch::max(input);
-     })
-   .define_singleton_method(
-     "_exp",
-     *[](torch::Tensor& input) {
-       return torch::exp(input);
-     })
-   .define_singleton_method(
-     "_log",
-     *[](torch::Tensor& input) {
-       return torch::log(input);
-     })
-   .define_singleton_method(
-     "_unsqueeze",
-     *[](torch::Tensor& input, int64_t dim) {
-       return torch::unsqueeze(input, dim);
-     })
-   .define_singleton_method(
-     "_dot",
-     *[](torch::Tensor& input, torch::Tensor& tensor) {
-       return torch::dot(input, tensor);
-     })
-   .define_singleton_method(
-     "_matmul",
-     *[](torch::Tensor& input, torch::Tensor& other) {
-       return torch::matmul(input, other);
-     })
-   .define_singleton_method(
-     "_add",
-     *[](torch::Tensor& input, torch::Tensor& other) {
-       return torch::add(input, other);
-     })
-   .define_singleton_method(
-     "_add_scalar",
-     *[](torch::Tensor& input, float other) {
-       return torch::add(input, other);
-     })
-   .define_singleton_method(
-     "_add_out",
-     *[](torch::Tensor& out, torch::Tensor& input, torch::Tensor& other) {
-       return torch::add_out(out, input, other);
-     })
-   .define_singleton_method(
-     "_sub",
-     *[](torch::Tensor& input, torch::Tensor& other) {
-       return torch::sub(input, other);
-     })
-   .define_singleton_method(
-     "_sub_scalar",
-     *[](torch::Tensor& input, float other) {
-       return torch::sub(input, other);
-     })
-   .define_singleton_method(
-     "_mul",
-     *[](torch::Tensor& input, torch::Tensor& other) {
-       return torch::mul(input, other);
-     })
-   .define_singleton_method(
-     "_mul_scalar",
-     *[](torch::Tensor& input, float other) {
-       return torch::mul(input, other);
-     })
-   .define_singleton_method(
-     "_div",
-     *[](torch::Tensor& input, torch::Tensor& other) {
-       return torch::div(input, other);
-     })
-   .define_singleton_method(
-     "_div_scalar",
-     *[](torch::Tensor& input, float other) {
-       return torch::div(input, other);
-     })
-   .define_singleton_method(
-     "_remainder",
-     *[](torch::Tensor& input, torch::Tensor& other) {
-       return torch::remainder(input, other);
-     })
-   .define_singleton_method(
-     "_remainder_scalar",
-     *[](torch::Tensor& input, float other) {
-       return torch::remainder(input, other);
-     })
-   .define_singleton_method(
-     "_pow",
-     *[](torch::Tensor& input, Scalar exponent) {
-       return torch::pow(input, exponent);
-     })
-   .define_singleton_method(
-     "_neg",
-     *[](torch::Tensor& input) {
-       return torch::neg(input);
-     })
-   .define_singleton_method(
-     "relu",
-     *[](torch::Tensor& input) {
-       return torch::relu(input);
-     })
-   .define_singleton_method(
-     "conv2d",
-     *[](torch::Tensor& input, torch::Tensor& weight, torch::Tensor& bias) {
-       return torch::conv2d(input, weight, bias);
-     })
-   .define_singleton_method(
-     "linear",
-     *[](torch::Tensor& input, torch::Tensor& weight, torch::Tensor& bias) {
-       return torch::linear(input, weight, bias);
-     })
-   .define_singleton_method(
-     "max_pool2d",
-     *[](torch::Tensor& input, IntArrayRef kernel_size) {
-       return torch::max_pool2d(input, kernel_size);
-     })
-   .define_singleton_method(
-     "mse_loss",
-     *[](torch::Tensor& input, torch::Tensor& target, std::string reduction) {
-       auto red = reduction == "mean" ? Reduction::Mean : Reduction::Sum;
-       return torch::mse_loss(input, target, red);
-     })
-   .define_singleton_method(
-     "nll_loss",
-     *[](torch::Tensor& input, torch::Tensor& target) {
-       return torch::nll_loss(input, target);
+     "_from_blob",
+     *[](String s, IntArrayRef size, const torch::TensorOptions &options) {
+       void *data = const_cast<char *>(s.c_str());
+       return torch::from_blob(data, size, options);
      })
    .define_singleton_method(
      "_tensor",
      *[](Object o, IntArrayRef size, const torch::TensorOptions &options) {
        Array a = Array(o);
-       std::vector<float> vec;
-       for (size_t i = 0; i < a.size(); i++) {
-         vec.push_back(from_ruby<float>(a[i]));
+       auto dtype = options.dtype();
+       torch::Tensor t;
+       if (dtype == torch::kBool) {
+         throw std::runtime_error("Cannot create bool from tensor method yet");
+       } else {
+         std::vector<float> vec;
+         for (size_t i = 0; i < a.size(); i++) {
+           vec.push_back(from_ruby<float>(a[i]));
+         }
+         t = torch::tensor(vec, options);
        }
-       return torch::tensor(vec, options).reshape(size);
+       return t.reshape(size);
      });

-   Class rb_cTensor = define_class_under<torch::Tensor>(rb_mTorch, "Tensor")
+   rb_cTensor
      .define_method("cuda?", &torch::Tensor::is_cuda)
-     .define_method("distributed?", &torch::Tensor::is_distributed)
-     .define_method("complex?", &torch::Tensor::is_complex)
-     .define_method("floating_point?", &torch::Tensor::is_floating_point)
-     .define_method("signed?", &torch::Tensor::is_signed)
      .define_method("sparse?", &torch::Tensor::is_sparse)
      .define_method("quantized?", &torch::Tensor::is_quantized)
      .define_method("dim", &torch::Tensor::dim)
-     .define_method("numel", &torch::Tensor::numel)
      .define_method("element_size", &torch::Tensor::element_size)
      .define_method("requires_grad", &torch::Tensor::requires_grad)
      .define_method(
-       "zero!",
-       *[](torch::Tensor& self) {
-         return self.zero_();
-       })
-     .define_method(
-       "detach!",
-       *[](torch::Tensor& self) {
-         return self.detach_();
+       "addcmul!",
+       *[](Tensor& self, Scalar value, const Tensor & tensor1, const Tensor & tensor2) {
+         return self.addcmul_(tensor1, tensor2, value);
        })
      .define_method(
-       "_access",
-       *[](torch::Tensor& self, int64_t index) {
-         return self[index];
+       "addcdiv!",
+       *[](Tensor& self, Scalar value, const Tensor & tensor1, const Tensor & tensor2) {
+         return self.addcdiv_(tensor1, tensor2, value);
        })
      .define_method(
        "_requires_grad!",
-       *[](torch::Tensor& self, bool requires_grad) {
+       *[](Tensor& self, bool requires_grad) {
          return self.set_requires_grad(requires_grad);
        })
      .define_method(
-       "backward",
-       *[](torch::Tensor& self) {
-         return self.backward();
+       "_backward",
+       *[](Tensor& self, Object gradient) {
+         return gradient.is_nil() ? self.backward() : self.backward(from_ruby<torch::Tensor>(gradient));
        })
      .define_method(
        "grad",
-       *[](torch::Tensor& self) {
+       *[](Tensor& self) {
          return self.grad();
        })
      .define_method(
        "_dtype",
-       *[](torch::Tensor& self) {
+       *[](Tensor& self) {
          return (int) at::typeMetaToScalarType(self.dtype());
        })
+     .define_method(
+       "_type",
+       *[](Tensor& self, int dtype) {
+         return self.toType((torch::ScalarType) dtype);
+       })
      .define_method(
        "_layout",
-       *[](torch::Tensor& self) {
+       *[](Tensor& self) {
          std::stringstream s;
          s << self.layout();
          return s.str();
        })
      .define_method(
        "device",
-       *[](torch::Tensor& self) {
+       *[](Tensor& self) {
          std::stringstream s;
          s << self.device();
          return s.str();
        })
      .define_method(
-       "_view",
-       *[](torch::Tensor& self, IntArrayRef size) {
-         return self.view(size);
-       })
-     .define_method(
-       "add!",
-       *[](torch::Tensor& self, torch::Tensor& other) {
-         self.add_(other);
-       })
-     .define_method(
-       "sub!",
-       *[](torch::Tensor& self, torch::Tensor& other) {
-         self.sub_(other);
-       })
-     .define_method(
-       "mul!",
-       *[](torch::Tensor& self, torch::Tensor& other) {
-         self.mul_(other);
-       })
-     .define_method(
-       "div!",
-       *[](torch::Tensor& self, torch::Tensor& other) {
-         self.div_(other);
-       })
-     .define_method(
-       "log_softmax",
-       *[](torch::Tensor& self, int64_t dim) {
-         return self.log_softmax(dim);
-       })
-     .define_method(
-       "_data",
-       *[](torch::Tensor& self) {
+       "_flat_data",
+       *[](Tensor& self) {
          Array a;
          auto dtype = self.dtype();

@@ -477,21 +233,23 @@ void Init_ext()
            a.push(data[i]);
          }
        } else if (dtype == torch::kBool) {
-         // bool
-         throw std::runtime_error("Type not supported yet");
+         bool* data = self.data_ptr<bool>();
+         for (int i = 0; i < self.numel(); i++) {
+           a.push(data[i] ? True : False);
+         }
        } else {
          throw std::runtime_error("Unsupported type");
        }
        return a;
      })
      .define_method(
-       "_size",
-       *[](torch::Tensor& self, int i) {
-         return self.size(i);
+       "_to",
+       *[](Tensor& self, torch::Device device, int dtype, bool non_blocking, bool copy) {
+         return self.to(device, (torch::ScalarType) dtype, non_blocking, copy);
        })
      .define_singleton_method(
        "_make_subclass",
-       *[](torch::Tensor& rd, bool requires_grad) {
+       *[](Tensor& rd, bool requires_grad) {
          auto data = torch::autograd::as_variable_ref(rd).detach();
          data.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(true);
          auto var = data.set_requires_grad(requires_grad);
@@ -538,24 +296,99 @@ void Init_ext()
          return self.requires_grad(requires_grad);
        });

-   Module rb_mNN = define_module_under(rb_mTorch, "NN");
-
    Module rb_mInit = define_module_under(rb_mNN, "Init")
      .define_singleton_method(
-       "kaiming_uniform_",
-       *[](torch::Tensor& input, double a) {
-         return torch::nn::init::kaiming_uniform_(input, a);
+       "_calculate_gain",
+       *[](NonlinearityType nonlinearity, double param) {
+         return torch::nn::init::calculate_gain(nonlinearity, param);
+       })
+     .define_singleton_method(
+       "_uniform!",
+       *[](Tensor tensor, double low, double high) {
+         return torch::nn::init::uniform_(tensor, low, high);
+       })
+     .define_singleton_method(
+       "_normal!",
+       *[](Tensor tensor, double mean, double std) {
+         return torch::nn::init::normal_(tensor, mean, std);
+       })
+     .define_singleton_method(
+       "_constant!",
+       *[](Tensor tensor, Scalar value) {
+         return torch::nn::init::constant_(tensor, value);
+       })
+     .define_singleton_method(
+       "_ones!",
+       *[](Tensor tensor) {
+         return torch::nn::init::ones_(tensor);
+       })
+     .define_singleton_method(
+       "_zeros!",
+       *[](Tensor tensor) {
+         return torch::nn::init::zeros_(tensor);
+       })
+     .define_singleton_method(
+       "_eye!",
+       *[](Tensor tensor) {
+         return torch::nn::init::eye_(tensor);
+       })
+     .define_singleton_method(
+       "_dirac!",
+       *[](Tensor tensor) {
+         return torch::nn::init::dirac_(tensor);
+       })
+     .define_singleton_method(
+       "_xavier_uniform!",
+       *[](Tensor tensor, double gain) {
+         return torch::nn::init::xavier_uniform_(tensor, gain);
+       })
+     .define_singleton_method(
+       "_xavier_normal!",
+       *[](Tensor tensor, double gain) {
+         return torch::nn::init::xavier_normal_(tensor, gain);
+       })
+     .define_singleton_method(
+       "_kaiming_uniform!",
+       *[](Tensor tensor, double a, FanModeType mode, NonlinearityType nonlinearity) {
+         return torch::nn::init::kaiming_uniform_(tensor, a, mode, nonlinearity);
+       })
+     .define_singleton_method(
+       "_kaiming_normal!",
+       *[](Tensor tensor, double a, FanModeType mode, NonlinearityType nonlinearity) {
+         return torch::nn::init::kaiming_normal_(tensor, a, mode, nonlinearity);
        })
      .define_singleton_method(
-       "uniform_",
-       *[](torch::Tensor& input, double to, double from) {
-         return torch::nn::init::uniform_(input, to, from);
+       "_orthogonal!",
+       *[](Tensor tensor, double gain) {
+         return torch::nn::init::orthogonal_(tensor, gain);
+       })
+     .define_singleton_method(
+       "_sparse!",
+       *[](Tensor tensor, double sparsity, double std) {
+         return torch::nn::init::sparse_(tensor, sparsity, std);
        });

    Class rb_cParameter = define_class_under<torch::autograd::Variable, torch::Tensor>(rb_mNN, "Parameter")
      .define_method(
        "grad",
        *[](torch::autograd::Variable& self) {
-         return self.grad();
+         auto grad = self.grad();
+         return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
        });
+
+   Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")
+     .define_constructor(Constructor<torch::Device, std::string>())
+     .define_method("index", &torch::Device::index)
+     .define_method("index?", &torch::Device::has_index)
+     .define_method(
+       "type",
+       *[](torch::Device& self) {
+         std::stringstream s;
+         s << self.type();
+         return s.str();
+       });
+
+   Module rb_mCUDA = define_module_under(rb_mTorch, "CUDA")
+     .define_singleton_method("available?", &torch::cuda::is_available)
+     .define_singleton_method("device_count", &torch::cuda::device_count);
  }
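From the Ruby side, the new `Device` and `CUDA` bindings registered above are presumably used like this (a hedged sketch based only on the methods defined in the diff):

```ruby
# Sketch: exercising the Device and CUDA bindings registered in ext.cpp
device = Torch::Device.new("cpu")
device.type   # => "cpu"
device.index? # => false

Torch::CUDA.available?   # => true only with a CUDA-enabled LibTorch build
Torch::CUDA.device_count # => 0 on a CPU-only machine
```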