torch-rb 0.1.2 → 0.1.7

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (142)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +35 -0
  3. data/LICENSE.txt +46 -22
  4. data/README.md +18 -6
  5. data/ext/torch/ext.cpp +148 -369
  6. data/ext/torch/extconf.rb +6 -0
  7. data/ext/torch/nn_functions.cpp +615 -0
  8. data/ext/torch/nn_functions.hpp +6 -0
  9. data/ext/torch/templates.cpp +55 -0
  10. data/ext/torch/templates.hpp +242 -0
  11. data/ext/torch/tensor_functions.cpp +1920 -0
  12. data/ext/torch/tensor_functions.hpp +6 -0
  13. data/ext/torch/torch_functions.cpp +2975 -0
  14. data/ext/torch/torch_functions.hpp +6 -0
  15. data/lib/torch.rb +240 -131
  16. data/lib/torch/ext.bundle +0 -0
  17. data/lib/torch/inspector.rb +27 -22
  18. data/lib/torch/native/dispatcher.rb +48 -0
  19. data/lib/torch/native/function.rb +109 -0
  20. data/lib/torch/native/generator.rb +168 -0
  21. data/lib/torch/native/native_functions.yaml +6837 -0
  22. data/lib/torch/native/parser.rb +134 -0
  23. data/lib/torch/nn/alpha_dropout.rb +9 -0
  24. data/lib/torch/nn/avg_pool1d.rb +18 -0
  25. data/lib/torch/nn/avg_pool2d.rb +19 -0
  26. data/lib/torch/nn/avg_pool3d.rb +19 -0
  27. data/lib/torch/nn/avg_poolnd.rb +9 -0
  28. data/lib/torch/nn/batch_norm.rb +75 -0
  29. data/lib/torch/nn/batch_norm1d.rb +11 -0
  30. data/lib/torch/nn/batch_norm2d.rb +11 -0
  31. data/lib/torch/nn/batch_norm3d.rb +11 -0
  32. data/lib/torch/nn/bce_loss.rb +13 -0
  33. data/lib/torch/nn/bce_with_logits_loss.rb +15 -0
  34. data/lib/torch/nn/bilinear.rb +38 -0
  35. data/lib/torch/nn/constant_pad1d.rb +10 -0
  36. data/lib/torch/nn/constant_pad2d.rb +10 -0
  37. data/lib/torch/nn/constant_pad3d.rb +10 -0
  38. data/lib/torch/nn/constant_padnd.rb +18 -0
  39. data/lib/torch/nn/conv1d.rb +22 -0
  40. data/lib/torch/nn/conv2d.rb +16 -38
  41. data/lib/torch/nn/conv3d.rb +22 -0
  42. data/lib/torch/nn/convnd.rb +41 -0
  43. data/lib/torch/nn/cosine_embedding_loss.rb +14 -0
  44. data/lib/torch/nn/cosine_similarity.rb +15 -0
  45. data/lib/torch/nn/cross_entropy_loss.rb +14 -0
  46. data/lib/torch/nn/ctc_loss.rb +15 -0
  47. data/lib/torch/nn/dropout.rb +9 -0
  48. data/lib/torch/nn/dropout2d.rb +9 -0
  49. data/lib/torch/nn/dropout3d.rb +9 -0
  50. data/lib/torch/nn/dropoutnd.rb +15 -0
  51. data/lib/torch/nn/embedding.rb +52 -0
  52. data/lib/torch/nn/embedding_bag.rb +34 -0
  53. data/lib/torch/nn/feature_alpha_dropout.rb +9 -0
  54. data/lib/torch/nn/fold.rb +20 -0
  55. data/lib/torch/nn/functional.rb +411 -22
  56. data/lib/torch/nn/group_norm.rb +36 -0
  57. data/lib/torch/nn/gru.rb +49 -0
  58. data/lib/torch/nn/hardshrink.rb +18 -0
  59. data/lib/torch/nn/hinge_embedding_loss.rb +14 -0
  60. data/lib/torch/nn/identity.rb +14 -0
  61. data/lib/torch/nn/init.rb +58 -1
  62. data/lib/torch/nn/instance_norm.rb +20 -0
  63. data/lib/torch/nn/instance_norm1d.rb +18 -0
  64. data/lib/torch/nn/instance_norm2d.rb +11 -0
  65. data/lib/torch/nn/instance_norm3d.rb +11 -0
  66. data/lib/torch/nn/kl_div_loss.rb +13 -0
  67. data/lib/torch/nn/l1_loss.rb +13 -0
  68. data/lib/torch/nn/layer_norm.rb +35 -0
  69. data/lib/torch/nn/leaky_relu.rb +20 -0
  70. data/lib/torch/nn/linear.rb +12 -11
  71. data/lib/torch/nn/local_response_norm.rb +21 -0
  72. data/lib/torch/nn/log_sigmoid.rb +9 -0
  73. data/lib/torch/nn/log_softmax.rb +14 -0
  74. data/lib/torch/nn/loss.rb +10 -0
  75. data/lib/torch/nn/lp_pool1d.rb +9 -0
  76. data/lib/torch/nn/lp_pool2d.rb +9 -0
  77. data/lib/torch/nn/lp_poolnd.rb +22 -0
  78. data/lib/torch/nn/lstm.rb +66 -0
  79. data/lib/torch/nn/margin_ranking_loss.rb +14 -0
  80. data/lib/torch/nn/max_pool1d.rb +9 -0
  81. data/lib/torch/nn/max_pool2d.rb +9 -0
  82. data/lib/torch/nn/max_pool3d.rb +9 -0
  83. data/lib/torch/nn/max_poolnd.rb +19 -0
  84. data/lib/torch/nn/max_unpool1d.rb +16 -0
  85. data/lib/torch/nn/max_unpool2d.rb +16 -0
  86. data/lib/torch/nn/max_unpool3d.rb +16 -0
  87. data/lib/torch/nn/max_unpoolnd.rb +9 -0
  88. data/lib/torch/nn/module.rb +201 -20
  89. data/lib/torch/nn/mse_loss.rb +2 -2
  90. data/lib/torch/nn/multi_label_margin_loss.rb +13 -0
  91. data/lib/torch/nn/multi_label_soft_margin_loss.rb +13 -0
  92. data/lib/torch/nn/multi_margin_loss.rb +17 -0
  93. data/lib/torch/nn/nll_loss.rb +14 -0
  94. data/lib/torch/nn/pairwise_distance.rb +16 -0
  95. data/lib/torch/nn/parameter.rb +2 -2
  96. data/lib/torch/nn/poisson_nll_loss.rb +16 -0
  97. data/lib/torch/nn/prelu.rb +19 -0
  98. data/lib/torch/nn/reflection_pad1d.rb +10 -0
  99. data/lib/torch/nn/reflection_pad2d.rb +10 -0
  100. data/lib/torch/nn/reflection_padnd.rb +13 -0
  101. data/lib/torch/nn/relu.rb +8 -3
  102. data/lib/torch/nn/replication_pad1d.rb +10 -0
  103. data/lib/torch/nn/replication_pad2d.rb +10 -0
  104. data/lib/torch/nn/replication_pad3d.rb +10 -0
  105. data/lib/torch/nn/replication_padnd.rb +13 -0
  106. data/lib/torch/nn/rnn.rb +22 -0
  107. data/lib/torch/nn/rnn_base.rb +198 -0
  108. data/lib/torch/nn/sequential.rb +1 -10
  109. data/lib/torch/nn/sigmoid.rb +9 -0
  110. data/lib/torch/nn/smooth_l1_loss.rb +13 -0
  111. data/lib/torch/nn/soft_margin_loss.rb +13 -0
  112. data/lib/torch/nn/softmax.rb +18 -0
  113. data/lib/torch/nn/softmax2d.rb +10 -0
  114. data/lib/torch/nn/softmin.rb +14 -0
  115. data/lib/torch/nn/softplus.rb +19 -0
  116. data/lib/torch/nn/softshrink.rb +18 -0
  117. data/lib/torch/nn/softsign.rb +9 -0
  118. data/lib/torch/nn/tanh.rb +9 -0
  119. data/lib/torch/nn/tanhshrink.rb +9 -0
  120. data/lib/torch/nn/triplet_margin_loss.rb +18 -0
  121. data/lib/torch/nn/unfold.rb +19 -0
  122. data/lib/torch/nn/utils.rb +25 -0
  123. data/lib/torch/nn/weighted_loss.rb +10 -0
  124. data/lib/torch/nn/zero_pad2d.rb +9 -0
  125. data/lib/torch/optim/adadelta.rb +57 -0
  126. data/lib/torch/optim/adagrad.rb +71 -0
  127. data/lib/torch/optim/adam.rb +81 -0
  128. data/lib/torch/optim/adamax.rb +68 -0
  129. data/lib/torch/optim/adamw.rb +82 -0
  130. data/lib/torch/optim/asgd.rb +65 -0
  131. data/lib/torch/optim/lr_scheduler/lr_scheduler.rb +33 -0
  132. data/lib/torch/optim/lr_scheduler/step_lr.rb +17 -0
  133. data/lib/torch/optim/optimizer.rb +56 -0
  134. data/lib/torch/optim/rmsprop.rb +76 -0
  135. data/lib/torch/optim/rprop.rb +68 -0
  136. data/lib/torch/optim/sgd.rb +48 -16
  137. data/lib/torch/random.rb +10 -0
  138. data/lib/torch/tensor.rb +71 -30
  139. data/lib/torch/utils/data/data_loader.rb +10 -4
  140. data/lib/torch/utils/data/tensor_dataset.rb +3 -0
  141. data/lib/torch/version.rb +1 -1
  142. metadata +123 -6
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz: 30089078de4039df111087e5c27e0cb10d6f36398c0e8d5cc774e9b642a8e133
-  data.tar.gz: 89eb9e183b395dd67cd9cf228749cf26402993bb561de973f1ba7438bc372b04
+  metadata.gz: 51bcc56112e13ba206402857b379aee0df4c7695f75af354e833760adec67756
+  data.tar.gz: b2ff24940e4c219d88c5a001d4e8b4e44d0e55a35fc266989f0196e696d15bc8
 SHA512:
-  metadata.gz: 027a069b00ac1329c007ddaf471a21b57a82a823ad974a937f832d17720b8e26474c64c79e9a29ec71bac433abb3d74d6a7cf407f0a983bb3c0cafb5b5c7532f
-  data.tar.gz: 6d7ef10b53db0df39eda13d07aa9b52b4afac0965674919b5cc517e7b53f59a9010cb647e50d62bc06154f7d8f3ef632d5897e4f7774372d7ab1b44b2cb6ca82
+  metadata.gz: 95506016db5598333f0cb99a435d29951342af91f75ae4b1f01ef11df81891738888b90c7d27317071ad00bd9b81714cf41c0ea635c2578fd756c388b5e1da7f
+  data.tar.gz: 053c9c75e66fe54902f07413687deb6996afc7ae88217bd5dcc852ca59d535c663bb9fb3aed28b20dba953a42e714410867dbd6ecd747f96fe8e8dfd81da8d6c
data/CHANGELOG.md CHANGED
@@ -1,3 +1,38 @@
+## 0.1.7 (2020-01-10)
+
+- Fixed installation error with Ruby 2.7
+
+## 0.1.6 (2019-12-09)
+
+- Added recurrent layers
+- Added more pooling layers
+- Added normalization layers
+
+## 0.1.5 (2019-12-06)
+
+- Added many more functions
+- Added tensor classes - `FloatTensor`, `LongTensor`, etc
+- Improved modules
+
+## 0.1.4 (2019-12-01)
+
+- Added distance functions
+- Added more activations
+- Added more linear layers
+- Added more loss functions
+- Added more init methods
+- Added support for tensor assignment
+
+## 0.1.3 (2019-11-30)
+
+- Changed to BSD 3-Clause license to match PyTorch
+- Added many optimizers
+- Added `StepLR` learning rate scheduler
+- Added dropout
+- Added embedding
+- Added support for `bool` type
+- Improved performance of `from_numo`
+
 ## 0.1.2 (2019-11-27)
 
 - Added SGD optimizer
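Several of these entries are new API surface. As a rough sketch of how the optimizers and the `StepLR` scheduler added in 0.1.3 fit together — assuming PyTorch-style signatures; the model, data, and hyperparameters below are illustrative, not taken from the package:

```ruby
require "torch"

# Hypothetical training loop; model and data are placeholders.
model = Torch::NN::Linear.new(10, 1)
optimizer = Torch::Optim::SGD.new(model.parameters, lr: 0.1)
# StepLR decays the learning rate by gamma every step_size epochs.
scheduler = Torch::Optim::LRScheduler::StepLR.new(optimizer, step_size: 5, gamma: 0.1)

x = Torch.rand(8, 10)
y = Torch.rand(8, 1)

10.times do
  optimizer.zero_grad
  loss = Torch::NN::Functional.mse_loss(model.call(x), y)
  loss.backward
  optimizer.step   # update parameters from the accumulated gradients
  scheduler.step   # advance the learning-rate schedule once per epoch
end
```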
data/LICENSE.txt CHANGED
@@ -1,22 +1,46 @@
-Copyright (c) 2019 Andrew Kane
-
-MIT License
-
-Permission is hereby granted, free of charge, to any person obtaining
-a copy of this software and associated documentation files (the
-"Software"), to deal in the Software without restriction, including
-without limitation the rights to use, copy, modify, merge, publish,
-distribute, sublicense, and/or sell copies of the Software, and to
-permit persons to whom the Software is furnished to do so, subject to
-the following conditions:
-
-The above copyright notice and this permission notice shall be
-included in all copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
-EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
-MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
-NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
-LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
-OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
-WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+BSD 3-Clause License
+
+From Torch-rb:
+
+Copyright (c) 2019- Andrew Kane
+
+From PyTorch (for ported code):
+
+Copyright (c) 2016- Facebook, Inc (Adam Paszke)
+Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
+Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
+Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
+Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
+Copyright (c) 2011-2013 NYU (Clement Farabet)
+Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
+Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
+Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
+
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+1. Redistributions of source code must retain the above copyright
+   notice, this list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright
+   notice, this list of conditions and the following disclaimer in the
+   documentation and/or other materials provided with the distribution.
+
+3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
+   and IDIAP Research Institute nor the names of its contributors may be
+   used to endorse or promote products derived from this software without
+   specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+POSSIBILITY OF SUCH DAMAGE.
data/README.md CHANGED
@@ -20,6 +20,8 @@ Add this line to your application’s Gemfile:
 gem 'torch-rb'
 ```
 
+It can take a few minutes to compile the extension.
+
 ## Getting Started
 
 This library follows the [PyTorch API](https://pytorch.org/docs/stable/torch.html). There are a few changes to make it more Ruby-like:
@@ -28,9 +30,11 @@ This library follows the [PyTorch API](https://pytorch.org/docs/stable/torch.htm
 - Methods that return booleans use `?` instead of `is_` (`tensor?` instead of `is_tensor`)
 - Numo is used instead of NumPy (`x.numo` instead of `x.numpy()`)
 
-Many methods and options are missing at the moment. PRs welcome!
+Some methods and options are missing at the moment. PRs welcome!
+
+## Tutorial
 
-Some examples below are from [Deep Learning with PyTorch: A 60 Minutes Blitz](https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html).
+Some examples below are from [Deep Learning with PyTorch: A 60 Minutes Blitz](https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html)
 
 ### Tensors
 
@@ -145,7 +149,7 @@ Convert a Numo array to a tensor
 
 ```ruby
 b = Numo::NArray.cast([1, 2, 3])
-Torch.from_numpy(b)
+Torch.from_numo(b)
 ```
 
 ### Autograd
@@ -180,10 +184,10 @@ Stop autograd from tracking history
 
 ```ruby
 x.requires_grad # true
-(x ** 2).requires_grad # true
+(x**2).requires_grad # true
 
 Torch.no_grad do
-  (x ** 2).requires_grad # false
+  (x**2).requires_grad # false
 end
 ```
 
@@ -359,6 +363,14 @@ Here’s a list of functions to create tensors (descriptions from the [C++ docs]
 Torch.zeros(3) # tensor([0, 0, 0])
 ```
 
+## Examples
+
+Here are a few full examples:
+
+- [Image classification with MNIST](examples/mnist) ([日本語版](https://qiita.com/kojix2/items/c19c36dc1bf73ea93409))
+- [Collaborative filtering with MovieLens](examples/movielens)
+- [Sequence models and word embeddings](examples/nlp)
+
 ## LibTorch Installation
 
 [Download LibTorch](https://pytorch.org/). For Linux, use the `cxx11 ABI` version. Then run:
@@ -405,7 +417,7 @@ To get started with development:
 git clone https://github.com/ankane/torch-rb.git
 cd torch-rb
 bundle install
-bundle exec rake compile
+bundle exec rake compile -- --with-torch-dir=/path/to/libtorch
 bundle exec rake test
 ```
 
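Two of the README fixes above are behavioral: `Torch.from_numpy` was corrected to `Torch.from_numo`, and the compile step now takes the LibTorch path. A quick sketch of the Numo round trip, assuming the numo-narray gem is installed alongside torch-rb:

```ruby
require "torch"
require "numo/narray"

b = Numo::NArray.cast([1, 2, 3])
t = Torch.from_numo(b)  # was documented as from_numpy; the actual method is from_numo
t.numo                  # convert the tensor back to a Numo array
```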
data/ext/torch/ext.cpp CHANGED
@@ -6,95 +6,29 @@
 #include <rice/Class.hpp>
 #include <rice/Constructor.hpp>
 
-using namespace Rice;
-
-template<>
-inline
-long long from_ruby<long long>(Object x)
-{
-  return NUM2LL(x);
-}
+#include "templates.hpp"
 
-template<>
-inline
-Object to_ruby<long long>(long long const & x)
-{
-  return LL2NUM(x);
-}
+// generated with:
+// rake generate:functions
+#include "torch_functions.hpp"
+#include "tensor_functions.hpp"
+#include "nn_functions.hpp"
 
-template<>
-inline
-unsigned long long from_ruby<unsigned long long>(Object x)
-{
-  return NUM2ULL(x);
-}
-
-template<>
-inline
-Object to_ruby<unsigned long long>(unsigned long long const & x)
-{
-  return ULL2NUM(x);
-}
-
-template<>
-inline
-short from_ruby<short>(Object x)
-{
-  return NUM2SHORT(x);
-}
-
-template<>
-inline
-Object to_ruby<short>(short const & x)
-{
-  return INT2NUM(x);
-}
-
-template<>
-inline
-unsigned short from_ruby<unsigned short>(Object x)
-{
-  return NUM2USHORT(x);
-}
+using namespace Rice;
 
-template<>
-inline
-Object to_ruby<unsigned short>(unsigned short const & x)
+extern "C"
+void Init_ext()
 {
-  return UINT2NUM(x);
-}
+  Module rb_mTorch = define_module("Torch");
+  add_torch_functions(rb_mTorch);
 
-// need to wrap torch::IntArrayRef() since
-// it doesn't own underlying data
-class IntArrayRef {
-  std::vector<int64_t> vec;
-  public:
-    IntArrayRef(Object o) {
-      Array a = Array(o);
-      for (size_t i = 0; i < a.size(); i++) {
-        vec.push_back(from_ruby<int64_t>(a[i]));
-      }
-    }
-    operator torch::IntArrayRef() {
-      return torch::IntArrayRef(vec);
-    }
-};
+  Class rb_cTensor = define_class_under<torch::Tensor>(rb_mTorch, "Tensor");
+  add_tensor_functions(rb_cTensor);
 
-template<>
-inline
-IntArrayRef from_ruby<IntArrayRef>(Object x)
-{
-  return IntArrayRef(x);
-}
+  Module rb_mNN = define_module_under(rb_mTorch, "NN");
+  add_nn_functions(rb_mNN);
 
-// for now
-typedef float Scalar;
-
-extern "C"
-void Init_ext()
-{
-  Module rb_mTorch = define_module("Torch")
-    .define_singleton_method(
+  rb_mTorch.define_singleton_method(
       "grad_enabled?",
       *[]() {
        return torch::GradMode::is_enabled();
@@ -104,11 +38,6 @@ void Init_ext()
       *[](bool enabled) {
        torch::GradMode::set_enabled(enabled);
       })
-    .define_singleton_method(
-      "floating_point?",
-      *[](torch::Tensor& input) {
-        return torch::is_floating_point(input);
-      })
     .define_singleton_method(
       "manual_seed",
       *[](uint64_t seed) {
@@ -177,321 +106,100 @@ void Init_ext()
       })
     // begin operations
     .define_singleton_method(
-      "_mean",
-      *[](torch::Tensor& input) {
-        return torch::mean(input);
-      })
-    .define_singleton_method(
-      "_mean_dim",
-      *[](torch::Tensor& input, int64_t dim, bool keepdim) {
-        return torch::mean(input, dim, keepdim);
-      })
-    .define_singleton_method(
-      "_sum",
-      *[](torch::Tensor& input) {
-        return torch::sum(input);
-      })
-    .define_singleton_method(
-      "_sum_dim",
-      *[](torch::Tensor& input, int64_t dim, bool keepdim) {
-        return torch::sum(input, dim, keepdim);
-      })
-    .define_singleton_method(
-      "_argmax",
-      *[](torch::Tensor& input) {
-        return torch::argmax(input);
-      })
-    .define_singleton_method(
-      "_argmax_dim",
-      *[](torch::Tensor& input, int64_t dim, bool keepdim) {
-        return torch::argmax(input, dim, keepdim);
-      })
-    .define_singleton_method(
-      "_norm",
-      *[](torch::Tensor& input) {
-        return torch::norm(input);
-      })
-    .define_singleton_method(
-      "_min",
-      *[](torch::Tensor& input) {
-        return torch::min(input);
-      })
-    .define_singleton_method(
-      "_max",
-      *[](torch::Tensor& input) {
-        return torch::max(input);
-      })
-    .define_singleton_method(
-      "_exp",
-      *[](torch::Tensor& input) {
-        return torch::exp(input);
-      })
-    .define_singleton_method(
-      "_log",
-      *[](torch::Tensor& input) {
-        return torch::log(input);
-      })
-    .define_singleton_method(
-      "_unsqueeze",
-      *[](torch::Tensor& input, int64_t dim) {
-        return torch::unsqueeze(input, dim);
-      })
-    .define_singleton_method(
-      "_dot",
-      *[](torch::Tensor& input, torch::Tensor& tensor) {
-        return torch::dot(input, tensor);
-      })
-    .define_singleton_method(
-      "_matmul",
-      *[](torch::Tensor& input, torch::Tensor& other) {
-        return torch::matmul(input, other);
-      })
-    .define_singleton_method(
-      "_eq",
-      *[](torch::Tensor& input, torch::Tensor& other) {
-        return torch::eq(input, other);
-      })
-    .define_singleton_method(
-      "_add",
-      *[](torch::Tensor& input, torch::Tensor& other) {
-        return torch::add(input, other);
-      })
-    .define_singleton_method(
-      "_add_scalar",
-      *[](torch::Tensor& input, float other) {
-        return torch::add(input, other);
-      })
-    .define_singleton_method(
-      "_add_out",
-      *[](torch::Tensor& out, torch::Tensor& input, torch::Tensor& other) {
-        return torch::add_out(out, input, other);
-      })
-    .define_singleton_method(
-      "_sub",
-      *[](torch::Tensor& input, torch::Tensor& other) {
-        return torch::sub(input, other);
-      })
-    .define_singleton_method(
-      "_sub_scalar",
-      *[](torch::Tensor& input, float other) {
-        return torch::sub(input, other);
-      })
-    .define_singleton_method(
-      "_mul",
-      *[](torch::Tensor& input, torch::Tensor& other) {
-        return torch::mul(input, other);
-      })
-    .define_singleton_method(
-      "_mul_scalar",
-      *[](torch::Tensor& input, float other) {
-        return torch::mul(input, other);
-      })
-    .define_singleton_method(
-      "_div",
-      *[](torch::Tensor& input, torch::Tensor& other) {
-        return torch::div(input, other);
-      })
-    .define_singleton_method(
-      "_div_scalar",
-      *[](torch::Tensor& input, float other) {
-        return torch::div(input, other);
-      })
-    .define_singleton_method(
-      "_remainder",
-      *[](torch::Tensor& input, torch::Tensor& other) {
-        return torch::remainder(input, other);
+      "_save",
+      *[](const Tensor &value) {
+        auto v = torch::pickle_save(value);
+        std::string str(v.begin(), v.end());
+        return str;
       })
     .define_singleton_method(
-      "_remainder_scalar",
-      *[](torch::Tensor& input, float other) {
-        return torch::remainder(input, other);
+      "_binary_cross_entropy_with_logits",
+      *[](const Tensor &input, const Tensor &target, OptionalTensor weight, OptionalTensor pos_weight, MyReduction reduction) {
+        return torch::binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction);
       })
     .define_singleton_method(
-      "_pow",
-      *[](torch::Tensor& input, Scalar exponent) {
-        return torch::pow(input, exponent);
-      })
-    .define_singleton_method(
-      "_neg",
-      *[](torch::Tensor& input) {
-        return torch::neg(input);
-      })
-    .define_singleton_method(
-      "_reshape",
-      *[](torch::Tensor& input, IntArrayRef shape) {
-        return torch::reshape(input, shape);
-      })
-    .define_singleton_method(
-      "relu",
-      *[](torch::Tensor& input) {
-        return torch::relu(input);
-      })
-    .define_singleton_method(
-      "prelu",
-      *[](torch::Tensor& input, torch::Tensor& weight) {
-        return torch::prelu(input, weight);
-      })
-    .define_singleton_method(
-      "leaky_relu",
-      *[](torch::Tensor& input, Scalar negative_slope = 0.01) {
-        return torch::leaky_relu(input, negative_slope);
-      })
-    .define_singleton_method(
-      "conv2d",
-      *[](torch::Tensor& input, torch::Tensor& weight, torch::Tensor& bias, IntArrayRef stride, IntArrayRef padding) {
-        return torch::conv2d(input, weight, bias, stride, padding);
-      })
-    .define_singleton_method(
-      "linear",
-      *[](torch::Tensor& input, torch::Tensor& weight, torch::Tensor& bias) {
-        return torch::linear(input, weight, bias);
-      })
-    .define_singleton_method(
-      "max_pool2d",
-      *[](torch::Tensor& input, IntArrayRef kernel_size) {
-        return torch::max_pool2d(input, kernel_size);
-      })
-    .define_singleton_method(
-      "avg_pool2d",
-      *[](torch::Tensor& input, IntArrayRef kernel_size) {
-        return torch::avg_pool2d(input, kernel_size);
-      })
-    .define_singleton_method(
-      "mse_loss",
-      *[](torch::Tensor& input, torch::Tensor& target, std::string reduction) {
-        auto red = reduction == "mean" ? Reduction::Mean : Reduction::Sum;
-        return torch::mse_loss(input, target, red);
-      })
-    .define_singleton_method(
-      "nll_loss",
-      *[](torch::Tensor& input, torch::Tensor& target) {
-        return torch::nll_loss(input, target);
+      "_from_blob",
+      *[](String s, IntArrayRef size, const torch::TensorOptions &options) {
+        void *data = const_cast<char *>(s.c_str());
+        return torch::from_blob(data, size, options);
       })
     .define_singleton_method(
       "_tensor",
       *[](Object o, IntArrayRef size, const torch::TensorOptions &options) {
        Array a = Array(o);
-        std::vector<float> vec;
-        for (size_t i = 0; i < a.size(); i++) {
-          vec.push_back(from_ruby<float>(a[i]));
+        auto dtype = options.dtype();
+        torch::Tensor t;
+        if (dtype == torch::kBool) {
+          throw std::runtime_error("Cannot create bool from tensor method yet");
+        } else {
+          std::vector<float> vec;
+          for (size_t i = 0; i < a.size(); i++) {
+            vec.push_back(from_ruby<float>(a[i]));
+          }
+          t = torch::tensor(vec, options);
        }
-        return torch::tensor(vec, options).reshape(size);
+        return t.reshape(size);
       });
 
-  Class rb_cTensor = define_class_under<torch::Tensor>(rb_mTorch, "Tensor")
+  rb_cTensor
     .define_method("cuda?", &torch::Tensor::is_cuda)
-    .define_method("distributed?", &torch::Tensor::is_distributed)
-    .define_method("complex?", &torch::Tensor::is_complex)
-    .define_method("floating_point?", &torch::Tensor::is_floating_point)
-    .define_method("signed?", &torch::Tensor::is_signed)
     .define_method("sparse?", &torch::Tensor::is_sparse)
     .define_method("quantized?", &torch::Tensor::is_quantized)
     .define_method("dim", &torch::Tensor::dim)
-    .define_method("numel", &torch::Tensor::numel)
     .define_method("element_size", &torch::Tensor::element_size)
     .define_method("requires_grad", &torch::Tensor::requires_grad)
     .define_method(
-      "zero!",
-      *[](torch::Tensor& self) {
-        return self.zero_();
-      })
-    .define_method(
-      "detach!",
-      *[](torch::Tensor& self) {
-        return self.detach_();
+      "addcmul!",
+      *[](Tensor& self, Scalar value, const Tensor & tensor1, const Tensor & tensor2) {
+        return self.addcmul_(tensor1, tensor2, value);
       })
     .define_method(
-      "_select",
-      *[](torch::Tensor& self, int64_t dim, int64_t index) {
-        return self.select(dim, index);
-      })
-    .define_method(
-      "_slice",
-      *[](torch::Tensor& self, int64_t dim, int64_t start, int64_t end, int64_t step) {
-        return self.slice(dim, start, end, step);
+      "addcdiv!",
+      *[](Tensor& self, Scalar value, const Tensor & tensor1, const Tensor & tensor2) {
+        return self.addcdiv_(tensor1, tensor2, value);
       })
     .define_method(
       "_requires_grad!",
-      *[](torch::Tensor& self, bool requires_grad) {
+      *[](Tensor& self, bool requires_grad) {
        return self.set_requires_grad(requires_grad);
       })
     .define_method(
       "_backward",
-      *[](torch::Tensor& self) {
-        return self.backward();
-      })
-    .define_method(
-      "_backward_gradient",
-      *[](torch::Tensor& self, const torch::Tensor& gradient) {
-        return self.backward(gradient);
+      *[](Tensor& self, Object gradient) {
+        return gradient.is_nil() ? self.backward() : self.backward(from_ruby<torch::Tensor>(gradient));
       })
     .define_method(
       "grad",
-      *[](torch::Tensor& self) {
+      *[](Tensor& self) {
        return self.grad();
       })
     .define_method(
       "_dtype",
-      *[](torch::Tensor& self) {
+      *[](Tensor& self) {
        return (int) at::typeMetaToScalarType(self.dtype());
       })
     .define_method(
       "_type",
-      *[](torch::Tensor& self, int dtype) {
+      *[](Tensor& self, int dtype) {
        return self.toType((torch::ScalarType) dtype);
       })
     .define_method(
       "_layout",
-      *[](torch::Tensor& self) {
+      *[](Tensor& self) {
        std::stringstream s;
        s << self.layout();
        return s.str();
       })
     .define_method(
       "device",
-      *[](torch::Tensor& self) {
+      *[](Tensor& self) {
        std::stringstream s;
        s << self.device();
        return s.str();
       })
     .define_method(
-      "_view",
-      *[](torch::Tensor& self, IntArrayRef size) {
-        return self.view(size);
-      })
-    .define_method(
-      "add!",
-      *[](torch::Tensor& self, torch::Tensor& other) {
-        self.add_(other);
-      })
-    .define_method(
-      "sub!",
-      *[](torch::Tensor& self, torch::Tensor& other) {
-        self.sub_(other);
-      })
-    .define_method(
-      "mul!",
-      *[](torch::Tensor& self, torch::Tensor& other) {
-        self.mul_(other);
-      })
-    .define_method(
-      "div!",
-      *[](torch::Tensor& self, torch::Tensor& other) {
-        self.div_(other);
-      })
-    .define_method(
-      "log_softmax",
-      *[](torch::Tensor& self, int64_t dim) {
-        return self.log_softmax(dim);
-      })
-    .define_method(
-      "data",
-      *[](torch::Tensor& self) {
-        return self.data();
-      })
-    .define_method(
-      "_data",
-      *[](torch::Tensor& self) {
+      "_flat_data",
+      *[](Tensor& self) {
        Array a;
        auto dtype = self.dtype();
 
@@ -532,21 +240,23 @@
            a.push(data[i]);
          }
        } else if (dtype == torch::kBool) {
-          // bool
-          throw std::runtime_error("Type not supported yet");
+          bool* data = self.data_ptr<bool>();
+          for (int i = 0; i < self.numel(); i++) {
+            a.push(data[i] ? True : False);
+          }
        } else {
          throw std::runtime_error("Unsupported type");
        }
        return a;
       })
     .define_method(
-      "_size",
-      *[](torch::Tensor& self, int i) {
-        return self.size(i);
+      "_to",
+      *[](Tensor& self, torch::Device device, int dtype, bool non_blocking, bool copy) {
+        return self.to(device, (torch::ScalarType) dtype, non_blocking, copy);
       })
     .define_singleton_method(
       "_make_subclass",
-      *[](torch::Tensor& rd, bool requires_grad) {
+      *[](Tensor& rd, bool requires_grad) {
        auto data = torch::autograd::as_variable_ref(rd).detach();
        data.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(true);
        auto var = data.set_requires_grad(requires_grad);
@@ -593,30 +303,99 @@ void Init_ext()
        return self.requires_grad(requires_grad);
       });
 
-  Module rb_mNN = define_module_under(rb_mTorch, "NN");
-
   Module rb_mInit = define_module_under(rb_mNN, "Init")
     .define_singleton_method(
-      "kaiming_uniform_",
-      *[](torch::Tensor& input, double a) {
-        return torch::nn::init::kaiming_uniform_(input, a);
+      "_calculate_gain",
+      *[](NonlinearityType nonlinearity, double param) {
+        return torch::nn::init::calculate_gain(nonlinearity, param);
+      })
+    .define_singleton_method(
+      "_uniform!",
+      *[](Tensor tensor, double low, double high) {
+        return torch::nn::init::uniform_(tensor, low, high);
+      })
+    .define_singleton_method(
+      "_normal!",
+      *[](Tensor tensor, double mean, double std) {
+        return torch::nn::init::normal_(tensor, mean, std);
+      })
+    .define_singleton_method(
+      "_constant!",
+      *[](Tensor tensor, Scalar value) {
+        return torch::nn::init::constant_(tensor, value);
+      })
+    .define_singleton_method(
+      "_ones!",
+      *[](Tensor tensor) {
+        return torch::nn::init::ones_(tensor);
+      })
+    .define_singleton_method(
+      "_zeros!",
+      *[](Tensor tensor) {
+        return torch::nn::init::zeros_(tensor);
+      })
+    .define_singleton_method(
+      "_eye!",
+      *[](Tensor tensor) {
+        return torch::nn::init::eye_(tensor);
+      })
+    .define_singleton_method(
+      "_dirac!",
+      *[](Tensor tensor) {
+        return torch::nn::init::dirac_(tensor);
+      })
+    .define_singleton_method(
+      "_xavier_uniform!",
+      *[](Tensor tensor, double gain) {
+        return torch::nn::init::xavier_uniform_(tensor, gain);
+      })
+    .define_singleton_method(
+      "_xavier_normal!",
+      *[](Tensor tensor, double gain) {
+        return torch::nn::init::xavier_normal_(tensor, gain);
+      })
+    .define_singleton_method(
+      "_kaiming_uniform!",
+      *[](Tensor tensor, double a, FanModeType mode, NonlinearityType nonlinearity) {
+        return torch::nn::init::kaiming_uniform_(tensor, a, mode, nonlinearity);
+      })
+    .define_singleton_method(
+      "_kaiming_normal!",
+      *[](Tensor tensor, double a, FanModeType mode, NonlinearityType nonlinearity) {
+        return torch::nn::init::kaiming_normal_(tensor, a, mode, nonlinearity);
+      })
+    .define_singleton_method(
+      "_orthogonal!",
+      *[](Tensor tensor, double gain) {
+        return torch::nn::init::orthogonal_(tensor, gain);
       })
     .define_singleton_method(
-      "uniform_",
-      *[](torch::Tensor& input, double to, double from) {
-        return torch::nn::init::uniform_(input, to, from);
+      "_sparse!",
+      *[](Tensor tensor, double sparsity, double std) {
+        return torch::nn::init::sparse_(tensor, sparsity, std);
       });
 
   Class rb_cParameter = define_class_under<torch::autograd::Variable, torch::Tensor>(rb_mNN, "Parameter")
-    // TODO return grad or nil to remove need for 2nd function
     .define_method(
-      "_grad",
+      "grad",
       *[](torch::autograd::Variable& self) {
-        return self.grad();
-      })
+        auto grad = self.grad();
+        return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
+      });
+
+  Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")
+    .define_constructor(Constructor<torch::Device, std::string>())
+    .define_method("index", &torch::Device::index)
+    .define_method("index?", &torch::Device::has_index)
     .define_method(
-      "_grad_defined",
-      *[](torch::autograd::Variable& self) {
-        return self.grad().defined();
+      "type",
+      *[](torch::Device& self) {
+        std::stringstream s;
+        s << self.type();
+        return s.str();
       });
+
+  Module rb_mCUDA = define_module_under(rb_mTorch, "CUDA")
+    .define_singleton_method("available?", &torch::cuda::is_available)
+    .define_singleton_method("device_count", &torch::cuda::device_count);
 }
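The final hunk registers `Torch::Device` and `Torch::CUDA`. Based on the bindings above, the Ruby-side surface looks roughly like this (a sketch; return values depend on the machine and LibTorch build):

```ruby
require "torch"

# Pick CUDA when LibTorch reports a GPU, otherwise fall back to the CPU.
device = Torch::Device.new(Torch::CUDA.available? ? "cuda" : "cpu")
device.type               # => "cuda" or "cpu"
device.index?             # => true only if an index like "cuda:0" was given
Torch::CUDA.device_count  # => number of visible GPUs (0 on a CPU-only build)
```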