torch-rb 0.1.1 → 0.1.6

Files changed (142)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +40 -0
  3. data/LICENSE.txt +46 -22
  4. data/README.md +73 -9
  5. data/ext/torch/ext.cpp +148 -315
  6. data/ext/torch/extconf.rb +6 -0
  7. data/ext/torch/nn_functions.cpp +615 -0
  8. data/ext/torch/nn_functions.hpp +6 -0
  9. data/ext/torch/templates.cpp +55 -0
  10. data/ext/torch/templates.hpp +298 -0
  11. data/ext/torch/tensor_functions.cpp +1920 -0
  12. data/ext/torch/tensor_functions.hpp +6 -0
  13. data/ext/torch/torch_functions.cpp +2975 -0
  14. data/ext/torch/torch_functions.hpp +6 -0
  15. data/lib/torch.rb +236 -112
  16. data/lib/torch/ext.bundle +0 -0
  17. data/lib/torch/inspector.rb +52 -25
  18. data/lib/torch/native/dispatcher.rb +48 -0
  19. data/lib/torch/native/function.rb +109 -0
  20. data/lib/torch/native/generator.rb +168 -0
  21. data/lib/torch/native/native_functions.yaml +6837 -0
  22. data/lib/torch/native/parser.rb +134 -0
  23. data/lib/torch/nn/alpha_dropout.rb +9 -0
  24. data/lib/torch/nn/avg_pool1d.rb +18 -0
  25. data/lib/torch/nn/avg_pool2d.rb +19 -0
  26. data/lib/torch/nn/avg_pool3d.rb +19 -0
  27. data/lib/torch/nn/avg_poolnd.rb +9 -0
  28. data/lib/torch/nn/batch_norm.rb +75 -0
  29. data/lib/torch/nn/batch_norm1d.rb +11 -0
  30. data/lib/torch/nn/batch_norm2d.rb +11 -0
  31. data/lib/torch/nn/batch_norm3d.rb +11 -0
  32. data/lib/torch/nn/bce_loss.rb +13 -0
  33. data/lib/torch/nn/bce_with_logits_loss.rb +15 -0
  34. data/lib/torch/nn/bilinear.rb +38 -0
  35. data/lib/torch/nn/constant_pad1d.rb +10 -0
  36. data/lib/torch/nn/constant_pad2d.rb +10 -0
  37. data/lib/torch/nn/constant_pad3d.rb +10 -0
  38. data/lib/torch/nn/constant_padnd.rb +18 -0
  39. data/lib/torch/nn/conv1d.rb +22 -0
  40. data/lib/torch/nn/conv2d.rb +16 -39
  41. data/lib/torch/nn/conv3d.rb +22 -0
  42. data/lib/torch/nn/convnd.rb +41 -0
  43. data/lib/torch/nn/cosine_embedding_loss.rb +14 -0
  44. data/lib/torch/nn/cosine_similarity.rb +15 -0
  45. data/lib/torch/nn/cross_entropy_loss.rb +14 -0
  46. data/lib/torch/nn/ctc_loss.rb +15 -0
  47. data/lib/torch/nn/dropout.rb +9 -0
  48. data/lib/torch/nn/dropout2d.rb +9 -0
  49. data/lib/torch/nn/dropout3d.rb +9 -0
  50. data/lib/torch/nn/dropoutnd.rb +15 -0
  51. data/lib/torch/nn/embedding.rb +52 -0
  52. data/lib/torch/nn/embedding_bag.rb +34 -0
  53. data/lib/torch/nn/feature_alpha_dropout.rb +9 -0
  54. data/lib/torch/nn/fold.rb +20 -0
  55. data/lib/torch/nn/functional.rb +419 -16
  56. data/lib/torch/nn/group_norm.rb +36 -0
  57. data/lib/torch/nn/gru.rb +49 -0
  58. data/lib/torch/nn/hardshrink.rb +18 -0
  59. data/lib/torch/nn/hinge_embedding_loss.rb +14 -0
  60. data/lib/torch/nn/identity.rb +14 -0
  61. data/lib/torch/nn/init.rb +58 -1
  62. data/lib/torch/nn/instance_norm.rb +20 -0
  63. data/lib/torch/nn/instance_norm1d.rb +18 -0
  64. data/lib/torch/nn/instance_norm2d.rb +11 -0
  65. data/lib/torch/nn/instance_norm3d.rb +11 -0
  66. data/lib/torch/nn/kl_div_loss.rb +13 -0
  67. data/lib/torch/nn/l1_loss.rb +13 -0
  68. data/lib/torch/nn/layer_norm.rb +35 -0
  69. data/lib/torch/nn/leaky_relu.rb +20 -0
  70. data/lib/torch/nn/linear.rb +12 -11
  71. data/lib/torch/nn/local_response_norm.rb +21 -0
  72. data/lib/torch/nn/log_sigmoid.rb +9 -0
  73. data/lib/torch/nn/log_softmax.rb +14 -0
  74. data/lib/torch/nn/loss.rb +10 -0
  75. data/lib/torch/nn/lp_pool1d.rb +9 -0
  76. data/lib/torch/nn/lp_pool2d.rb +9 -0
  77. data/lib/torch/nn/lp_poolnd.rb +22 -0
  78. data/lib/torch/nn/lstm.rb +66 -0
  79. data/lib/torch/nn/margin_ranking_loss.rb +14 -0
  80. data/lib/torch/nn/max_pool1d.rb +9 -0
  81. data/lib/torch/nn/max_pool2d.rb +9 -0
  82. data/lib/torch/nn/max_pool3d.rb +9 -0
  83. data/lib/torch/nn/max_poolnd.rb +19 -0
  84. data/lib/torch/nn/max_unpool1d.rb +16 -0
  85. data/lib/torch/nn/max_unpool2d.rb +16 -0
  86. data/lib/torch/nn/max_unpool3d.rb +16 -0
  87. data/lib/torch/nn/max_unpoolnd.rb +9 -0
  88. data/lib/torch/nn/module.rb +191 -19
  89. data/lib/torch/nn/mse_loss.rb +2 -2
  90. data/lib/torch/nn/multi_label_margin_loss.rb +13 -0
  91. data/lib/torch/nn/multi_label_soft_margin_loss.rb +13 -0
  92. data/lib/torch/nn/multi_margin_loss.rb +17 -0
  93. data/lib/torch/nn/nll_loss.rb +14 -0
  94. data/lib/torch/nn/pairwise_distance.rb +16 -0
  95. data/lib/torch/nn/parameter.rb +4 -0
  96. data/lib/torch/nn/poisson_nll_loss.rb +16 -0
  97. data/lib/torch/nn/prelu.rb +19 -0
  98. data/lib/torch/nn/reflection_pad1d.rb +10 -0
  99. data/lib/torch/nn/reflection_pad2d.rb +10 -0
  100. data/lib/torch/nn/reflection_padnd.rb +13 -0
  101. data/lib/torch/nn/relu.rb +8 -3
  102. data/lib/torch/nn/replication_pad1d.rb +10 -0
  103. data/lib/torch/nn/replication_pad2d.rb +10 -0
  104. data/lib/torch/nn/replication_pad3d.rb +10 -0
  105. data/lib/torch/nn/replication_padnd.rb +13 -0
  106. data/lib/torch/nn/rnn.rb +22 -0
  107. data/lib/torch/nn/rnn_base.rb +198 -0
  108. data/lib/torch/nn/sequential.rb +1 -10
  109. data/lib/torch/nn/sigmoid.rb +9 -0
  110. data/lib/torch/nn/smooth_l1_loss.rb +13 -0
  111. data/lib/torch/nn/soft_margin_loss.rb +13 -0
  112. data/lib/torch/nn/softmax.rb +18 -0
  113. data/lib/torch/nn/softmax2d.rb +10 -0
  114. data/lib/torch/nn/softmin.rb +14 -0
  115. data/lib/torch/nn/softplus.rb +19 -0
  116. data/lib/torch/nn/softshrink.rb +18 -0
  117. data/lib/torch/nn/softsign.rb +9 -0
  118. data/lib/torch/nn/tanh.rb +9 -0
  119. data/lib/torch/nn/tanhshrink.rb +9 -0
  120. data/lib/torch/nn/triplet_margin_loss.rb +18 -0
  121. data/lib/torch/nn/unfold.rb +19 -0
  122. data/lib/torch/nn/utils.rb +25 -0
  123. data/lib/torch/nn/weighted_loss.rb +10 -0
  124. data/lib/torch/nn/zero_pad2d.rb +9 -0
  125. data/lib/torch/optim/adadelta.rb +57 -0
  126. data/lib/torch/optim/adagrad.rb +71 -0
  127. data/lib/torch/optim/adam.rb +81 -0
  128. data/lib/torch/optim/adamax.rb +68 -0
  129. data/lib/torch/optim/adamw.rb +82 -0
  130. data/lib/torch/optim/asgd.rb +65 -0
  131. data/lib/torch/optim/lr_scheduler/lr_scheduler.rb +33 -0
  132. data/lib/torch/optim/lr_scheduler/step_lr.rb +17 -0
  133. data/lib/torch/optim/optimizer.rb +62 -0
  134. data/lib/torch/optim/rmsprop.rb +76 -0
  135. data/lib/torch/optim/rprop.rb +68 -0
  136. data/lib/torch/optim/sgd.rb +60 -0
  137. data/lib/torch/random.rb +10 -0
  138. data/lib/torch/tensor.rb +90 -30
  139. data/lib/torch/utils/data/data_loader.rb +15 -0
  140. data/lib/torch/utils/data/tensor_dataset.rb +8 -1
  141. data/lib/torch/version.rb +1 -1
  142. metadata +122 -3
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
- metadata.gz: ff0920ba955063c03309fdb45ecf228b51c556508bea30b510d6bf652c1d0b18
- data.tar.gz: 481dccf6a8e929230033f74c82bc9d292ef38ea219e2cb2cc61ca0b0c5457403
+ metadata.gz: 9667f9d3256f5e2d39937f17ae8eb00449dd14f79bb01cd647800bd7ed972fc6
+ data.tar.gz: 54c23612c79355e09c97da5fcf6b97c183da8316d1c2a53d6f8f0463e98342a2
  SHA512:
- metadata.gz: cd6c8fd9db4af15640217c09813c4f86d3f66360202d30711015c4f34552853ef281d5614fd78dc274d405da0d9f46f08a2359475ae1b0721143db49183faf5d
- data.tar.gz: ee638c08458e0d2a8fac52e29c45d1347a74847ca7d8dab3a9a573afd887814d4c578c4e1a7fb80b204222785b27e21ca3c138f3face681e98e63c1bc02a9a7f
+ metadata.gz: bb2c8e5aae436367aeb871a2d19958e59ed9e9c7601b1b8b4473e33094cadf6d657947582b0ec93a29cb08723f8f7c81178a2d50beb23a125d5a356769d92177
+ data.tar.gz: 62feef39da31a19415e2e6c453aed4972e34db7367161a088944c06a977637a8b25cecc8eb2ad052b3b9deee0707f364e616cc33e7674cf0314899421f18fbee
data/CHANGELOG.md CHANGED
@@ -1,3 +1,43 @@
+ ## 0.1.6 (2019-12-09)
+
+ - Added recurrent layers
+ - Added more pooling layers
+ - Added normalization layers
+
+ ## 0.1.5 (2019-12-06)
+
+ - Added many more functions
+ - Added tensor classes - `FloatTensor`, `LongTensor`, etc
+ - Improved modules
+
+ ## 0.1.4 (2019-12-01)
+
+ - Added distance functions
+ - Added more activations
+ - Added more linear layers
+ - Added more loss functions
+ - Added more init methods
+ - Added support for tensor assignment
+
+ ## 0.1.3 (2019-11-30)
+
+ - Changed to BSD 3-Clause license to match PyTorch
+ - Added many optimizers
+ - Added `StepLR` learning rate scheduler
+ - Added dropout
+ - Added embedding
+ - Added support for `bool` type
+ - Improved performance of `from_numo`
+
+ ## 0.1.2 (2019-11-27)
+
+ - Added SGD optimizer
+ - Added support for gradient to `backward` method
+ - Added `argmax`, `eq`, `leaky_relu`, `prelu`, and `reshape` methods
+ - Improved indexing
+ - Fixed `zero_grad`
+ - Fixed error with infinite values
+
  ## 0.1.1 (2019-11-26)

  - Added support for `uint8` and `int8` types
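
A quick sketch of a few of the additions above in use. Method names come from the changelog; the exact signatures here mirror PyTorch and are assumptions, not taken from this diff:

```ruby
require "torch"

# tensor classes (0.1.5), assuming they accept nested arrays like PyTorch
x = Torch::FloatTensor.new([[1, 2], [3, 4]])

# argmax and reshape (0.1.2)
x.reshape([4]).argmax # => tensor(3)

# gradient argument to backward (0.1.2)
y = Torch.ones(2, 2, requires_grad: true)
(y * 3).backward(Torch.ones(2, 2))
y.grad # => tensor of 3s
```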
data/LICENSE.txt CHANGED
@@ -1,22 +1,46 @@
- Copyright (c) 2019 Andrew Kane
-
- MIT License
-
- Permission is hereby granted, free of charge, to any person obtaining
- a copy of this software and associated documentation files (the
- "Software"), to deal in the Software without restriction, including
- without limitation the rights to use, copy, modify, merge, publish,
- distribute, sublicense, and/or sell copies of the Software, and to
- permit persons to whom the Software is furnished to do so, subject to
- the following conditions:
-
- The above copyright notice and this permission notice shall be
- included in all copies or substantial portions of the Software.
-
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
- EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
- MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
- NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
- LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
- OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
- WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+ BSD 3-Clause License
+
+ From Torch-rb:
+
+ Copyright (c) 2019- Andrew Kane
+
+ From PyTorch (for ported code):
+
+ Copyright (c) 2016- Facebook, Inc (Adam Paszke)
+ Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
+ Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
+ Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
+ Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
+ Copyright (c) 2011-2013 NYU (Clement Farabet)
+ Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
+ Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
+ Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
+
+ All rights reserved.
+
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions are met:
+
+ 1. Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ 2. Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+ 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
+ and IDIAP Research Institute nor the names of its contributors may be
+ used to endorse or promote products derived from this software without
+ specific prior written permission.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+ LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+ CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+ SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+ INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+ CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+ POSSIBILITY OF SUCH DAMAGE.
data/README.md CHANGED
@@ -2,7 +2,7 @@

  :fire: Deep learning for Ruby, powered by [LibTorch](https://pytorch.org)

- This gem is currently experimental. There may be breaking changes between each release.
+ This gem is currently experimental. There may be breaking changes between each release. Please report any issues you experience.

  [![Build Status](https://travis-ci.org/ankane/torch-rb.svg?branch=master)](https://travis-ci.org/ankane/torch-rb)

@@ -20,6 +20,8 @@ Add this line to your application’s Gemfile:
  gem 'torch-rb'
  ```

+ It can take a few minutes to compile the extension.
+
  ## Getting Started

  This library follows the [PyTorch API](https://pytorch.org/docs/stable/torch.html). There are a few changes to make it more Ruby-like:
@@ -28,9 +30,11 @@ This library follows the [PyTorch API](https://pytorch.org/docs/stable/torch.htm
  - Methods that return booleans use `?` instead of `is_` (`tensor?` instead of `is_tensor`)
  - Numo is used instead of NumPy (`x.numo` instead of `x.numpy()`)

- Many methods and options are missing at the moment. PRs welcome!
+ Some methods and options are missing at the moment. PRs welcome!
+
+ ## Tutorial

- Some examples below are from [Deep Learning with PyTorch: A 60 Minutes Blitz](https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html).
+ Some examples below are from [Deep Learning with PyTorch: A 60 Minutes Blitz](https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html)

  ### Tensors

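To make the README's conventions above concrete, a short sketch (the Numo class in the comment is an assumption; it depends on dtype):

```ruby
require "torch"

x = Torch.ones(2, 3)

# boolean methods use ? instead of is_
Torch.tensor?(x) # => true

# Numo replaces NumPy for array interchange
n = x.numo         # => Numo::SFloat (or similar, depending on dtype)
Torch.from_numo(n) # back to a tensor
```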
@@ -145,7 +149,7 @@ Convert a Numo array to a tensor

  ```ruby
  b = Numo::NArray.cast([1, 2, 3])
- Torch.from_numpy(b)
+ Torch.from_numo(b)
  ```

  ### Autograd
@@ -180,10 +184,10 @@ Stop autograd from tracking history

  ```ruby
  x.requires_grad # true
- (x ** 2).requires_grad # true
+ (x**2).requires_grad # true

  Torch.no_grad do
- (x ** 2).requires_grad # false
+ (x**2).requires_grad # false
  end
  ```

@@ -223,7 +227,7 @@ class Net < Torch::NN::Module
  end
  ```

- And run
+ Create an instance of it

  ```ruby
  net = Net.new
@@ -231,6 +235,58 @@ input = Torch.randn(1, 1, 32, 32)
  net.call(input)
  ```

+ Get trainable parameters
+
+ ```ruby
+ net.parameters
+ ```
+
+ Zero the gradient buffers and backprop with random gradients
+
+ ```ruby
+ net.zero_grad
+ out.backward(Torch.randn(1, 10))
+ ```
+
+ Define a loss function
+
+ ```ruby
+ output = net.call(input)
+ target = Torch.randn(10)
+ target = target.view(1, -1)
+ criterion = Torch::NN::MSELoss.new
+ loss = criterion.call(output, target)
+ ```
+
+ Backprop
+
+ ```ruby
+ net.zero_grad
+ p net.conv1.bias.grad
+ loss.backward
+ p net.conv1.bias.grad
+ ```
+
+ Update the weights
+
+ ```ruby
+ learning_rate = 0.01
+ net.parameters.each do |f|
+ f.data.sub!(f.grad.data * learning_rate)
+ end
+ ```
+
+ Use an optimizer
+
+ ```ruby
+ optimizer = Torch::Optim::SGD.new(net.parameters, lr: 0.01)
+ optimizer.zero_grad
+ output = net.call(input)
+ loss = criterion.call(output, target)
+ loss.backward
+ optimizer.step
+ ```
+
  ### Tensor Creation

  Here’s a list of functions to create tensors (descriptions from the [C++ docs](https://pytorch.org/cppdocs/notes/tensor_creation.html)):
@@ -307,6 +363,14 @@ Here’s a list of functions to create tensors (descriptions from the [C++ docs]
  Torch.zeros(3) # tensor([0, 0, 0])
  ```

+ ## Examples
+
+ Here are a few full examples:
+
+ - [Image classification with MNIST](examples/mnist) ([Japanese version](https://qiita.com/kojix2/items/c19c36dc1bf73ea93409))
+ - [Collaborative filtering with MovieLens](examples/movielens)
+ - [Sequence models and word embeddings](examples/nlp)
+
  ## LibTorch Installation

  [Download LibTorch](https://pytorch.org/). For Linux, use the `cxx11 ABI` version. Then run:
@@ -351,9 +415,9 @@ To get started with development:

  ```sh
  git clone https://github.com/ankane/torch-rb.git
- cd torch
+ cd torch-rb
  bundle install
- bundle exec rake compile
+ bundle exec rake compile -- --with-torch-dir=/path/to/libtorch
  bundle exec rake test
  ```

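The tutorial steps added to the README above compose into the usual training loop. A minimal sketch using only calls that appear in this README diff (the epoch count is arbitrary):

```ruby
optimizer = Torch::Optim::SGD.new(net.parameters, lr: 0.01)
criterion = Torch::NN::MSELoss.new

10.times do
  optimizer.zero_grad                   # clear accumulated gradients
  output = net.call(input)              # forward pass
  loss = criterion.call(output, target) # compute loss
  loss.backward                         # backprop
  optimizer.step                        # update weights
end
```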
data/ext/torch/ext.cpp CHANGED
@@ -6,95 +6,29 @@
  #include <rice/Class.hpp>
  #include <rice/Constructor.hpp>

- using namespace Rice;
-
- template<>
- inline
- long long from_ruby<long long>(Object x)
- {
- return NUM2LL(x);
- }
-
- template<>
- inline
- Object to_ruby<long long>(long long const & x)
- {
- return LL2NUM(x);
- }
+ #include "templates.hpp"

- template<>
- inline
- unsigned long long from_ruby<unsigned long long>(Object x)
- {
- return NUM2ULL(x);
- }
+ // generated with:
+ // rake generate:functions
+ #include "torch_functions.hpp"
+ #include "tensor_functions.hpp"
+ #include "nn_functions.hpp"

- template<>
- inline
- Object to_ruby<unsigned long long>(unsigned long long const & x)
- {
- return ULL2NUM(x);
- }
-
- template<>
- inline
- short from_ruby<short>(Object x)
- {
- return NUM2SHORT(x);
- }
-
- template<>
- inline
- Object to_ruby<short>(short const & x)
- {
- return INT2NUM(x);
- }
-
- template<>
- inline
- unsigned short from_ruby<unsigned short>(Object x)
- {
- return NUM2USHORT(x);
- }
+ using namespace Rice;

- template<>
- inline
- Object to_ruby<unsigned short>(unsigned short const & x)
+ extern "C"
+ void Init_ext()
  {
- return UINT2NUM(x);
- }
+ Module rb_mTorch = define_module("Torch");
+ add_torch_functions(rb_mTorch);

- // need to wrap torch::IntArrayRef() since
- // it doesn't own underlying data
- class IntArrayRef {
- std::vector<int64_t> vec;
- public:
- IntArrayRef(Object o) {
- Array a = Array(o);
- for (size_t i = 0; i < a.size(); i++) {
- vec.push_back(from_ruby<int64_t>(a[i]));
- }
- }
- operator torch::IntArrayRef() {
- return torch::IntArrayRef(vec);
- }
- };
+ Class rb_cTensor = define_class_under<torch::Tensor>(rb_mTorch, "Tensor");
+ add_tensor_functions(rb_cTensor);

- template<>
- inline
- IntArrayRef from_ruby<IntArrayRef>(Object x)
- {
- return IntArrayRef(x);
- }
-
- // for now
- typedef float Scalar;
+ Module rb_mNN = define_module_under(rb_mTorch, "NN");
+ add_nn_functions(rb_mNN);

- extern "C"
- void Init_ext()
- {
- Module rb_mTorch = define_module("Torch")
- .define_singleton_method(
+ rb_mTorch.define_singleton_method(
  "grad_enabled?",
  *[]() {
  return torch::GradMode::is_enabled();
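
The hand-written bindings removed above are replaced by generated code: `torch_functions.cpp`, `tensor_functions.cpp`, and `nn_functions.cpp` come from `rake generate:functions`, which works off the PyTorch operator schema now vendored at `lib/torch/native/native_functions.yaml`. A rough sketch of the first step, assuming only that the YAML is the stock PyTorch schema file (an array of entries with a `func` signature string); the real logic lives in `lib/torch/native/generator.rb`:

```ruby
require "yaml"

# each entry has a "func" string such as
# "add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor"
functions = YAML.load_file("lib/torch/native/native_functions.yaml")

# extract the operator name from each schema string
names = functions.map { |f| f["func"].split("(").first }
puts "#{names.uniq.size} operators to bind"
```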
@@ -104,11 +38,6 @@ void Init_ext()
  *[](bool enabled) {
  torch::GradMode::set_enabled(enabled);
  })
- .define_singleton_method(
- "floating_point?",
- *[](torch::Tensor& input) {
- return torch::is_floating_point(input);
- })
  .define_singleton_method(
  "manual_seed",
  *[](uint64_t seed) {
@@ -177,266 +106,93 @@ void Init_ext()
  })
  // begin operations
  .define_singleton_method(
- "_mean",
- *[](torch::Tensor& input) {
- return torch::mean(input);
- })
- .define_singleton_method(
- "_mean_dim",
- *[](torch::Tensor& input, int64_t dim, bool keepdim) {
- return torch::mean(input, dim, keepdim);
+ "_binary_cross_entropy_with_logits",
+ *[](const Tensor &input, const Tensor &target, OptionalTensor weight, OptionalTensor pos_weight, MyReduction reduction) {
+ return torch::binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction);
  })
  .define_singleton_method(
- "_sum",
- *[](torch::Tensor& input) {
- return torch::sum(input);
- })
- .define_singleton_method(
- "_sum_dim",
- *[](torch::Tensor& input, int64_t dim, bool keepdim) {
- return torch::sum(input, dim, keepdim);
- })
- .define_singleton_method(
- "_norm",
- *[](torch::Tensor& input) {
- return torch::norm(input);
- })
- .define_singleton_method(
- "_min",
- *[](torch::Tensor& input) {
- return torch::min(input);
- })
- .define_singleton_method(
- "_max",
- *[](torch::Tensor& input) {
- return torch::max(input);
- })
- .define_singleton_method(
- "_exp",
- *[](torch::Tensor& input) {
- return torch::exp(input);
- })
- .define_singleton_method(
- "_log",
- *[](torch::Tensor& input) {
- return torch::log(input);
- })
- .define_singleton_method(
- "_unsqueeze",
- *[](torch::Tensor& input, int64_t dim) {
- return torch::unsqueeze(input, dim);
- })
- .define_singleton_method(
- "_dot",
- *[](torch::Tensor& input, torch::Tensor& tensor) {
- return torch::dot(input, tensor);
- })
- .define_singleton_method(
- "_matmul",
- *[](torch::Tensor& input, torch::Tensor& other) {
- return torch::matmul(input, other);
- })
- .define_singleton_method(
- "_add",
- *[](torch::Tensor& input, torch::Tensor& other) {
- return torch::add(input, other);
- })
- .define_singleton_method(
- "_add_scalar",
- *[](torch::Tensor& input, float other) {
- return torch::add(input, other);
- })
- .define_singleton_method(
- "_add_out",
- *[](torch::Tensor& out, torch::Tensor& input, torch::Tensor& other) {
- return torch::add_out(out, input, other);
- })
- .define_singleton_method(
- "_sub",
- *[](torch::Tensor& input, torch::Tensor& other) {
- return torch::sub(input, other);
- })
- .define_singleton_method(
- "_sub_scalar",
- *[](torch::Tensor& input, float other) {
- return torch::sub(input, other);
- })
- .define_singleton_method(
- "_mul",
- *[](torch::Tensor& input, torch::Tensor& other) {
- return torch::mul(input, other);
- })
- .define_singleton_method(
- "_mul_scalar",
- *[](torch::Tensor& input, float other) {
- return torch::mul(input, other);
- })
- .define_singleton_method(
- "_div",
- *[](torch::Tensor& input, torch::Tensor& other) {
- return torch::div(input, other);
- })
- .define_singleton_method(
- "_div_scalar",
- *[](torch::Tensor& input, float other) {
- return torch::div(input, other);
- })
- .define_singleton_method(
- "_remainder",
- *[](torch::Tensor& input, torch::Tensor& other) {
- return torch::remainder(input, other);
- })
- .define_singleton_method(
- "_remainder_scalar",
- *[](torch::Tensor& input, float other) {
- return torch::remainder(input, other);
- })
- .define_singleton_method(
- "_pow",
- *[](torch::Tensor& input, Scalar exponent) {
- return torch::pow(input, exponent);
- })
- .define_singleton_method(
- "_neg",
- *[](torch::Tensor& input) {
- return torch::neg(input);
- })
- .define_singleton_method(
- "relu",
- *[](torch::Tensor& input) {
- return torch::relu(input);
- })
- .define_singleton_method(
- "conv2d",
- *[](torch::Tensor& input, torch::Tensor& weight, torch::Tensor& bias) {
- return torch::conv2d(input, weight, bias);
- })
- .define_singleton_method(
- "linear",
- *[](torch::Tensor& input, torch::Tensor& weight, torch::Tensor& bias) {
- return torch::linear(input, weight, bias);
- })
- .define_singleton_method(
- "max_pool2d",
- *[](torch::Tensor& input, IntArrayRef kernel_size) {
- return torch::max_pool2d(input, kernel_size);
- })
- .define_singleton_method(
- "mse_loss",
- *[](torch::Tensor& input, torch::Tensor& target, std::string reduction) {
- auto red = reduction == "mean" ? Reduction::Mean : Reduction::Sum;
- return torch::mse_loss(input, target, red);
- })
- .define_singleton_method(
- "nll_loss",
- *[](torch::Tensor& input, torch::Tensor& target) {
- return torch::nll_loss(input, target);
+ "_from_blob",
+ *[](String s, IntArrayRef size, const torch::TensorOptions &options) {
+ void *data = const_cast<char *>(s.c_str());
+ return torch::from_blob(data, size, options);
  })
  .define_singleton_method(
  "_tensor",
  *[](Object o, IntArrayRef size, const torch::TensorOptions &options) {
  Array a = Array(o);
- std::vector<float> vec;
- for (size_t i = 0; i < a.size(); i++) {
- vec.push_back(from_ruby<float>(a[i]));
+ auto dtype = options.dtype();
+ torch::Tensor t;
+ if (dtype == torch::kBool) {
+ throw std::runtime_error("Cannot create bool from tensor method yet");
+ } else {
+ std::vector<float> vec;
+ for (size_t i = 0; i < a.size(); i++) {
+ vec.push_back(from_ruby<float>(a[i]));
+ }
+ t = torch::tensor(vec, options);
  }
- return torch::tensor(vec, options).reshape(size);
+ return t.reshape(size);
  });

- Class rb_cTensor = define_class_under<torch::Tensor>(rb_mTorch, "Tensor")
+ rb_cTensor
  .define_method("cuda?", &torch::Tensor::is_cuda)
- .define_method("distributed?", &torch::Tensor::is_distributed)
- .define_method("complex?", &torch::Tensor::is_complex)
- .define_method("floating_point?", &torch::Tensor::is_floating_point)
- .define_method("signed?", &torch::Tensor::is_signed)
  .define_method("sparse?", &torch::Tensor::is_sparse)
  .define_method("quantized?", &torch::Tensor::is_quantized)
  .define_method("dim", &torch::Tensor::dim)
- .define_method("numel", &torch::Tensor::numel)
  .define_method("element_size", &torch::Tensor::element_size)
  .define_method("requires_grad", &torch::Tensor::requires_grad)
  .define_method(
- "zero!",
- *[](torch::Tensor& self) {
- return self.zero_();
- })
- .define_method(
- "detach!",
- *[](torch::Tensor& self) {
- return self.detach_();
+ "addcmul!",
+ *[](Tensor& self, Scalar value, const Tensor & tensor1, const Tensor & tensor2) {
+ return self.addcmul_(tensor1, tensor2, value);
  })
  .define_method(
- "_access",
- *[](torch::Tensor& self, int64_t index) {
- return self[index];
+ "addcdiv!",
+ *[](Tensor& self, Scalar value, const Tensor & tensor1, const Tensor & tensor2) {
+ return self.addcdiv_(tensor1, tensor2, value);
  })
  .define_method(
  "_requires_grad!",
- *[](torch::Tensor& self, bool requires_grad) {
+ *[](Tensor& self, bool requires_grad) {
  return self.set_requires_grad(requires_grad);
  })
  .define_method(
- "backward",
- *[](torch::Tensor& self) {
- return self.backward();
+ "_backward",
+ *[](Tensor& self, Object gradient) {
+ return gradient.is_nil() ? self.backward() : self.backward(from_ruby<torch::Tensor>(gradient));
  })
  .define_method(
  "grad",
- *[](torch::Tensor& self) {
+ *[](Tensor& self) {
  return self.grad();
  })
  .define_method(
  "_dtype",
- *[](torch::Tensor& self) {
+ *[](Tensor& self) {
  return (int) at::typeMetaToScalarType(self.dtype());
  })
+ .define_method(
+ "_type",
+ *[](Tensor& self, int dtype) {
+ return self.toType((torch::ScalarType) dtype);
+ })
  .define_method(
  "_layout",
- *[](torch::Tensor& self) {
+ *[](Tensor& self) {
  std::stringstream s;
  s << self.layout();
  return s.str();
  })
  .define_method(
  "device",
- *[](torch::Tensor& self) {
+ *[](Tensor& self) {
  std::stringstream s;
  s << self.device();
  return s.str();
  })
  .define_method(
- "_view",
- *[](torch::Tensor& self, IntArrayRef size) {
- return self.view(size);
- })
- .define_method(
- "add!",
- *[](torch::Tensor& self, torch::Tensor& other) {
- self.add_(other);
- })
- .define_method(
- "sub!",
- *[](torch::Tensor& self, torch::Tensor& other) {
- self.sub_(other);
- })
- .define_method(
- "mul!",
- *[](torch::Tensor& self, torch::Tensor& other) {
- self.mul_(other);
- })
- .define_method(
- "div!",
- *[](torch::Tensor& self, torch::Tensor& other) {
- self.div_(other);
- })
- .define_method(
- "log_softmax",
- *[](torch::Tensor& self, int64_t dim) {
- return self.log_softmax(dim);
- })
- .define_method(
- "_data",
- *[](torch::Tensor& self) {
+ "_flat_data",
+ *[](Tensor& self) {
  Array a;
  auto dtype = self.dtype();

@@ -477,21 +233,23 @@ void Init_ext()
  a.push(data[i]);
  }
  } else if (dtype == torch::kBool) {
- // bool
- throw std::runtime_error("Type not supported yet");
+ bool* data = self.data_ptr<bool>();
+ for (int i = 0; i < self.numel(); i++) {
+ a.push(data[i] ? True : False);
+ }
  } else {
  throw std::runtime_error("Unsupported type");
  }
  return a;
  })
  .define_method(
- "_size",
- *[](torch::Tensor& self, int i) {
- return self.size(i);
+ "_to",
+ *[](Tensor& self, torch::Device device, int dtype, bool non_blocking, bool copy) {
+ return self.to(device, (torch::ScalarType) dtype, non_blocking, copy);
  })
  .define_singleton_method(
  "_make_subclass",
- *[](torch::Tensor& rd, bool requires_grad) {
+ *[](Tensor& rd, bool requires_grad) {
  auto data = torch::autograd::as_variable_ref(rd).detach();
  data.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(true);
  auto var = data.set_requires_grad(requires_grad);
@@ -538,24 +296,99 @@ void Init_ext()
  return self.requires_grad(requires_grad);
  });

- Module rb_mNN = define_module_under(rb_mTorch, "NN");
-
  Module rb_mInit = define_module_under(rb_mNN, "Init")
  .define_singleton_method(
- "kaiming_uniform_",
- *[](torch::Tensor& input, double a) {
- return torch::nn::init::kaiming_uniform_(input, a);
+ "_calculate_gain",
+ *[](NonlinearityType nonlinearity, double param) {
+ return torch::nn::init::calculate_gain(nonlinearity, param);
+ })
+ .define_singleton_method(
+ "_uniform!",
+ *[](Tensor tensor, double low, double high) {
+ return torch::nn::init::uniform_(tensor, low, high);
+ })
+ .define_singleton_method(
+ "_normal!",
+ *[](Tensor tensor, double mean, double std) {
+ return torch::nn::init::normal_(tensor, mean, std);
+ })
+ .define_singleton_method(
+ "_constant!",
+ *[](Tensor tensor, Scalar value) {
+ return torch::nn::init::constant_(tensor, value);
+ })
+ .define_singleton_method(
+ "_ones!",
+ *[](Tensor tensor) {
+ return torch::nn::init::ones_(tensor);
+ })
+ .define_singleton_method(
+ "_zeros!",
+ *[](Tensor tensor) {
+ return torch::nn::init::zeros_(tensor);
+ })
+ .define_singleton_method(
+ "_eye!",
+ *[](Tensor tensor) {
+ return torch::nn::init::eye_(tensor);
+ })
+ .define_singleton_method(
+ "_dirac!",
+ *[](Tensor tensor) {
+ return torch::nn::init::dirac_(tensor);
+ })
+ .define_singleton_method(
+ "_xavier_uniform!",
+ *[](Tensor tensor, double gain) {
+ return torch::nn::init::xavier_uniform_(tensor, gain);
+ })
+ .define_singleton_method(
+ "_xavier_normal!",
+ *[](Tensor tensor, double gain) {
+ return torch::nn::init::xavier_normal_(tensor, gain);
+ })
+ .define_singleton_method(
+ "_kaiming_uniform!",
+ *[](Tensor tensor, double a, FanModeType mode, NonlinearityType nonlinearity) {
+ return torch::nn::init::kaiming_uniform_(tensor, a, mode, nonlinearity);
+ })
+ .define_singleton_method(
+ "_kaiming_normal!",
+ *[](Tensor tensor, double a, FanModeType mode, NonlinearityType nonlinearity) {
+ return torch::nn::init::kaiming_normal_(tensor, a, mode, nonlinearity);
  })
  .define_singleton_method(
- "uniform_",
- *[](torch::Tensor& input, double to, double from) {
- return torch::nn::init::uniform_(input, to, from);
+ "_orthogonal!",
+ *[](Tensor tensor, double gain) {
+ return torch::nn::init::orthogonal_(tensor, gain);
+ })
+ .define_singleton_method(
+ "_sparse!",
+ *[](Tensor tensor, double sparsity, double std) {
+ return torch::nn::init::sparse_(tensor, sparsity, std);
  });

  Class rb_cParameter = define_class_under<torch::autograd::Variable, torch::Tensor>(rb_mNN, "Parameter")
  .define_method(
  "grad",
  *[](torch::autograd::Variable& self) {
- return self.grad();
+ auto grad = self.grad();
+ return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
  });
+
+ Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")
+ .define_constructor(Constructor<torch::Device, std::string>())
+ .define_method("index", &torch::Device::index)
+ .define_method("index?", &torch::Device::has_index)
+ .define_method(
+ "type",
+ *[](torch::Device& self) {
+ std::stringstream s;
+ s << self.type();
+ return s.str();
+ });
+
+ Module rb_mCUDA = define_module_under(rb_mTorch, "CUDA")
+ .define_singleton_method("available?", &torch::cuda::is_available)
+ .define_singleton_method("device_count", &torch::cuda::device_count);
  }
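
The new `Device` and `CUDA` bindings at the end map directly to Ruby; a short usage sketch built only from the methods defined above:

```ruby
device = Torch::Device.new("cpu")
device.type   # => "cpu"
device.index? # => false

if Torch::CUDA.available?
  puts Torch::CUDA.device_count # number of visible GPUs
end
```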