torch-rb 0.1.2 → 0.1.7

Sign up to get free protection for your applications and to get access to all the features.
Files changed (142) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +35 -0
  3. data/LICENSE.txt +46 -22
  4. data/README.md +18 -6
  5. data/ext/torch/ext.cpp +148 -369
  6. data/ext/torch/extconf.rb +6 -0
  7. data/ext/torch/nn_functions.cpp +615 -0
  8. data/ext/torch/nn_functions.hpp +6 -0
  9. data/ext/torch/templates.cpp +55 -0
  10. data/ext/torch/templates.hpp +242 -0
  11. data/ext/torch/tensor_functions.cpp +1920 -0
  12. data/ext/torch/tensor_functions.hpp +6 -0
  13. data/ext/torch/torch_functions.cpp +2975 -0
  14. data/ext/torch/torch_functions.hpp +6 -0
  15. data/lib/torch.rb +240 -131
  16. data/lib/torch/ext.bundle +0 -0
  17. data/lib/torch/inspector.rb +27 -22
  18. data/lib/torch/native/dispatcher.rb +48 -0
  19. data/lib/torch/native/function.rb +109 -0
  20. data/lib/torch/native/generator.rb +168 -0
  21. data/lib/torch/native/native_functions.yaml +6837 -0
  22. data/lib/torch/native/parser.rb +134 -0
  23. data/lib/torch/nn/alpha_dropout.rb +9 -0
  24. data/lib/torch/nn/avg_pool1d.rb +18 -0
  25. data/lib/torch/nn/avg_pool2d.rb +19 -0
  26. data/lib/torch/nn/avg_pool3d.rb +19 -0
  27. data/lib/torch/nn/avg_poolnd.rb +9 -0
  28. data/lib/torch/nn/batch_norm.rb +75 -0
  29. data/lib/torch/nn/batch_norm1d.rb +11 -0
  30. data/lib/torch/nn/batch_norm2d.rb +11 -0
  31. data/lib/torch/nn/batch_norm3d.rb +11 -0
  32. data/lib/torch/nn/bce_loss.rb +13 -0
  33. data/lib/torch/nn/bce_with_logits_loss.rb +15 -0
  34. data/lib/torch/nn/bilinear.rb +38 -0
  35. data/lib/torch/nn/constant_pad1d.rb +10 -0
  36. data/lib/torch/nn/constant_pad2d.rb +10 -0
  37. data/lib/torch/nn/constant_pad3d.rb +10 -0
  38. data/lib/torch/nn/constant_padnd.rb +18 -0
  39. data/lib/torch/nn/conv1d.rb +22 -0
  40. data/lib/torch/nn/conv2d.rb +16 -38
  41. data/lib/torch/nn/conv3d.rb +22 -0
  42. data/lib/torch/nn/convnd.rb +41 -0
  43. data/lib/torch/nn/cosine_embedding_loss.rb +14 -0
  44. data/lib/torch/nn/cosine_similarity.rb +15 -0
  45. data/lib/torch/nn/cross_entropy_loss.rb +14 -0
  46. data/lib/torch/nn/ctc_loss.rb +15 -0
  47. data/lib/torch/nn/dropout.rb +9 -0
  48. data/lib/torch/nn/dropout2d.rb +9 -0
  49. data/lib/torch/nn/dropout3d.rb +9 -0
  50. data/lib/torch/nn/dropoutnd.rb +15 -0
  51. data/lib/torch/nn/embedding.rb +52 -0
  52. data/lib/torch/nn/embedding_bag.rb +34 -0
  53. data/lib/torch/nn/feature_alpha_dropout.rb +9 -0
  54. data/lib/torch/nn/fold.rb +20 -0
  55. data/lib/torch/nn/functional.rb +411 -22
  56. data/lib/torch/nn/group_norm.rb +36 -0
  57. data/lib/torch/nn/gru.rb +49 -0
  58. data/lib/torch/nn/hardshrink.rb +18 -0
  59. data/lib/torch/nn/hinge_embedding_loss.rb +14 -0
  60. data/lib/torch/nn/identity.rb +14 -0
  61. data/lib/torch/nn/init.rb +58 -1
  62. data/lib/torch/nn/instance_norm.rb +20 -0
  63. data/lib/torch/nn/instance_norm1d.rb +18 -0
  64. data/lib/torch/nn/instance_norm2d.rb +11 -0
  65. data/lib/torch/nn/instance_norm3d.rb +11 -0
  66. data/lib/torch/nn/kl_div_loss.rb +13 -0
  67. data/lib/torch/nn/l1_loss.rb +13 -0
  68. data/lib/torch/nn/layer_norm.rb +35 -0
  69. data/lib/torch/nn/leaky_relu.rb +20 -0
  70. data/lib/torch/nn/linear.rb +12 -11
  71. data/lib/torch/nn/local_response_norm.rb +21 -0
  72. data/lib/torch/nn/log_sigmoid.rb +9 -0
  73. data/lib/torch/nn/log_softmax.rb +14 -0
  74. data/lib/torch/nn/loss.rb +10 -0
  75. data/lib/torch/nn/lp_pool1d.rb +9 -0
  76. data/lib/torch/nn/lp_pool2d.rb +9 -0
  77. data/lib/torch/nn/lp_poolnd.rb +22 -0
  78. data/lib/torch/nn/lstm.rb +66 -0
  79. data/lib/torch/nn/margin_ranking_loss.rb +14 -0
  80. data/lib/torch/nn/max_pool1d.rb +9 -0
  81. data/lib/torch/nn/max_pool2d.rb +9 -0
  82. data/lib/torch/nn/max_pool3d.rb +9 -0
  83. data/lib/torch/nn/max_poolnd.rb +19 -0
  84. data/lib/torch/nn/max_unpool1d.rb +16 -0
  85. data/lib/torch/nn/max_unpool2d.rb +16 -0
  86. data/lib/torch/nn/max_unpool3d.rb +16 -0
  87. data/lib/torch/nn/max_unpoolnd.rb +9 -0
  88. data/lib/torch/nn/module.rb +201 -20
  89. data/lib/torch/nn/mse_loss.rb +2 -2
  90. data/lib/torch/nn/multi_label_margin_loss.rb +13 -0
  91. data/lib/torch/nn/multi_label_soft_margin_loss.rb +13 -0
  92. data/lib/torch/nn/multi_margin_loss.rb +17 -0
  93. data/lib/torch/nn/nll_loss.rb +14 -0
  94. data/lib/torch/nn/pairwise_distance.rb +16 -0
  95. data/lib/torch/nn/parameter.rb +2 -2
  96. data/lib/torch/nn/poisson_nll_loss.rb +16 -0
  97. data/lib/torch/nn/prelu.rb +19 -0
  98. data/lib/torch/nn/reflection_pad1d.rb +10 -0
  99. data/lib/torch/nn/reflection_pad2d.rb +10 -0
  100. data/lib/torch/nn/reflection_padnd.rb +13 -0
  101. data/lib/torch/nn/relu.rb +8 -3
  102. data/lib/torch/nn/replication_pad1d.rb +10 -0
  103. data/lib/torch/nn/replication_pad2d.rb +10 -0
  104. data/lib/torch/nn/replication_pad3d.rb +10 -0
  105. data/lib/torch/nn/replication_padnd.rb +13 -0
  106. data/lib/torch/nn/rnn.rb +22 -0
  107. data/lib/torch/nn/rnn_base.rb +198 -0
  108. data/lib/torch/nn/sequential.rb +1 -10
  109. data/lib/torch/nn/sigmoid.rb +9 -0
  110. data/lib/torch/nn/smooth_l1_loss.rb +13 -0
  111. data/lib/torch/nn/soft_margin_loss.rb +13 -0
  112. data/lib/torch/nn/softmax.rb +18 -0
  113. data/lib/torch/nn/softmax2d.rb +10 -0
  114. data/lib/torch/nn/softmin.rb +14 -0
  115. data/lib/torch/nn/softplus.rb +19 -0
  116. data/lib/torch/nn/softshrink.rb +18 -0
  117. data/lib/torch/nn/softsign.rb +9 -0
  118. data/lib/torch/nn/tanh.rb +9 -0
  119. data/lib/torch/nn/tanhshrink.rb +9 -0
  120. data/lib/torch/nn/triplet_margin_loss.rb +18 -0
  121. data/lib/torch/nn/unfold.rb +19 -0
  122. data/lib/torch/nn/utils.rb +25 -0
  123. data/lib/torch/nn/weighted_loss.rb +10 -0
  124. data/lib/torch/nn/zero_pad2d.rb +9 -0
  125. data/lib/torch/optim/adadelta.rb +57 -0
  126. data/lib/torch/optim/adagrad.rb +71 -0
  127. data/lib/torch/optim/adam.rb +81 -0
  128. data/lib/torch/optim/adamax.rb +68 -0
  129. data/lib/torch/optim/adamw.rb +82 -0
  130. data/lib/torch/optim/asgd.rb +65 -0
  131. data/lib/torch/optim/lr_scheduler/lr_scheduler.rb +33 -0
  132. data/lib/torch/optim/lr_scheduler/step_lr.rb +17 -0
  133. data/lib/torch/optim/optimizer.rb +56 -0
  134. data/lib/torch/optim/rmsprop.rb +76 -0
  135. data/lib/torch/optim/rprop.rb +68 -0
  136. data/lib/torch/optim/sgd.rb +48 -16
  137. data/lib/torch/random.rb +10 -0
  138. data/lib/torch/tensor.rb +71 -30
  139. data/lib/torch/utils/data/data_loader.rb +10 -4
  140. data/lib/torch/utils/data/tensor_dataset.rb +3 -0
  141. data/lib/torch/version.rb +1 -1
  142. metadata +123 -6
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 30089078de4039df111087e5c27e0cb10d6f36398c0e8d5cc774e9b642a8e133
4
- data.tar.gz: 89eb9e183b395dd67cd9cf228749cf26402993bb561de973f1ba7438bc372b04
3
+ metadata.gz: 51bcc56112e13ba206402857b379aee0df4c7695f75af354e833760adec67756
4
+ data.tar.gz: b2ff24940e4c219d88c5a001d4e8b4e44d0e55a35fc266989f0196e696d15bc8
5
5
  SHA512:
6
- metadata.gz: 027a069b00ac1329c007ddaf471a21b57a82a823ad974a937f832d17720b8e26474c64c79e9a29ec71bac433abb3d74d6a7cf407f0a983bb3c0cafb5b5c7532f
7
- data.tar.gz: 6d7ef10b53db0df39eda13d07aa9b52b4afac0965674919b5cc517e7b53f59a9010cb647e50d62bc06154f7d8f3ef632d5897e4f7774372d7ab1b44b2cb6ca82
6
+ metadata.gz: 95506016db5598333f0cb99a435d29951342af91f75ae4b1f01ef11df81891738888b90c7d27317071ad00bd9b81714cf41c0ea635c2578fd756c388b5e1da7f
7
+ data.tar.gz: 053c9c75e66fe54902f07413687deb6996afc7ae88217bd5dcc852ca59d535c663bb9fb3aed28b20dba953a42e714410867dbd6ecd747f96fe8e8dfd81da8d6c
data/CHANGELOG.md CHANGED
@@ -1,3 +1,38 @@
1
+ ## 0.1.7 (2020-01-10)
2
+
3
+ - Fixed installation error with Ruby 2.7
4
+
5
+ ## 0.1.6 (2019-12-09)
6
+
7
+ - Added recurrent layers
8
+ - Added more pooling layers
9
+ - Added normalization layers
10
+
11
+ ## 0.1.5 (2019-12-06)
12
+
13
+ - Added many more functions
14
+ - Added tensor classes - `FloatTensor`, `LongTensor`, etc
15
+ - Improved modules
16
+
17
+ ## 0.1.4 (2019-12-01)
18
+
19
+ - Added distance functions
20
+ - Added more activations
21
+ - Added more linear layers
22
+ - Added more loss functions
23
+ - Added more init methods
24
+ - Added support for tensor assignment
25
+
26
+ ## 0.1.3 (2019-11-30)
27
+
28
+ - Changed to BSD 3-Clause license to match PyTorch
29
+ - Added many optimizers
30
+ - Added `StepLR` learning rate scheduler
31
+ - Added dropout
32
+ - Added embedding
33
+ - Added support for `bool` type
34
+ - Improved performance of `from_numo`
35
+
1
36
  ## 0.1.2 (2019-11-27)
2
37
 
3
38
  - Added SGD optimizer
data/LICENSE.txt CHANGED
@@ -1,22 +1,46 @@
1
- Copyright (c) 2019 Andrew Kane
2
-
3
- MIT License
4
-
5
- Permission is hereby granted, free of charge, to any person obtaining
6
- a copy of this software and associated documentation files (the
7
- "Software"), to deal in the Software without restriction, including
8
- without limitation the rights to use, copy, modify, merge, publish,
9
- distribute, sublicense, and/or sell copies of the Software, and to
10
- permit persons to whom the Software is furnished to do so, subject to
11
- the following conditions:
12
-
13
- The above copyright notice and this permission notice shall be
14
- included in all copies or substantial portions of the Software.
15
-
16
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
17
- EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
18
- MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
19
- NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
20
- LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
21
- OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
22
- WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
1
+ BSD 3-Clause License
2
+
3
+ From Torch-rb:
4
+
5
+ Copyright (c) 2019- Andrew Kane
6
+
7
+ From PyTorch (for ported code):
8
+
9
+ Copyright (c) 2016- Facebook, Inc (Adam Paszke)
10
+ Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
11
+ Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
12
+ Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
13
+ Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
14
+ Copyright (c) 2011-2013 NYU (Clement Farabet)
15
+ Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
16
+ Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
17
+ Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
18
+
19
+ All rights reserved.
20
+
21
+ Redistribution and use in source and binary forms, with or without
22
+ modification, are permitted provided that the following conditions are met:
23
+
24
+ 1. Redistributions of source code must retain the above copyright
25
+ notice, this list of conditions and the following disclaimer.
26
+
27
+ 2. Redistributions in binary form must reproduce the above copyright
28
+ notice, this list of conditions and the following disclaimer in the
29
+ documentation and/or other materials provided with the distribution.
30
+
31
+ 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
32
+ and IDIAP Research Institute nor the names of its contributors may be
33
+ used to endorse or promote products derived from this software without
34
+ specific prior written permission.
35
+
36
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
37
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
38
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
39
+ ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
40
+ LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
41
+ CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
42
+ SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
43
+ INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
44
+ CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
45
+ ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
46
+ POSSIBILITY OF SUCH DAMAGE.
data/README.md CHANGED
@@ -20,6 +20,8 @@ Add this line to your application’s Gemfile:
20
20
  gem 'torch-rb'
21
21
  ```
22
22
 
23
+ It can take a few minutes to compile the extension.
24
+
23
25
  ## Getting Started
24
26
 
25
27
  This library follows the [PyTorch API](https://pytorch.org/docs/stable/torch.html). There are a few changes to make it more Ruby-like:
@@ -28,9 +30,11 @@ This library follows the [PyTorch API](https://pytorch.org/docs/stable/torch.htm
28
30
  - Methods that return booleans use `?` instead of `is_` (`tensor?` instead of `is_tensor`)
29
31
  - Numo is used instead of NumPy (`x.numo` instead of `x.numpy()`)
30
32
 
31
- Many methods and options are missing at the moment. PRs welcome!
33
+ Some methods and options are missing at the moment. PRs welcome!
34
+
35
+ ## Tutorial
32
36
 
33
- Some examples below are from [Deep Learning with PyTorch: A 60 Minutes Blitz](https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html).
37
+ Some examples below are from [Deep Learning with PyTorch: A 60 Minutes Blitz](https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html)
34
38
 
35
39
  ### Tensors
36
40
 
@@ -145,7 +149,7 @@ Convert a Numo array to a tensor
145
149
 
146
150
  ```ruby
147
151
  b = Numo::NArray.cast([1, 2, 3])
148
- Torch.from_numpy(b)
152
+ Torch.from_numo(b)
149
153
  ```
150
154
 
151
155
  ### Autograd
@@ -180,10 +184,10 @@ Stop autograd from tracking history
180
184
 
181
185
  ```ruby
182
186
  x.requires_grad # true
183
- (x ** 2).requires_grad # true
187
+ (x**2).requires_grad # true
184
188
 
185
189
  Torch.no_grad do
186
- (x ** 2).requires_grad # false
190
+ (x**2).requires_grad # false
187
191
  end
188
192
  ```
189
193
 
@@ -359,6 +363,14 @@ Here’s a list of functions to create tensors (descriptions from the [C++ docs]
359
363
  Torch.zeros(3) # tensor([0, 0, 0])
360
364
  ```
361
365
 
366
+ ## Examples
367
+
368
+ Here are a few full examples:
369
+
370
+ - [Image classification with MNIST](examples/mnist) ([日本語版](https://qiita.com/kojix2/items/c19c36dc1bf73ea93409))
371
+ - [Collaborative filtering with MovieLens](examples/movielens)
372
+ - [Sequence models and word embeddings](examples/nlp)
373
+
362
374
  ## LibTorch Installation
363
375
 
364
376
  [Download LibTorch](https://pytorch.org/). For Linux, use the `cxx11 ABI` version. Then run:
@@ -405,7 +417,7 @@ To get started with development:
405
417
  git clone https://github.com/ankane/torch-rb.git
406
418
  cd torch-rb
407
419
  bundle install
408
- bundle exec rake compile
420
+ bundle exec rake compile -- --with-torch-dir=/path/to/libtorch
409
421
  bundle exec rake test
410
422
  ```
411
423
 
data/ext/torch/ext.cpp CHANGED
@@ -6,95 +6,29 @@
6
6
  #include <rice/Class.hpp>
7
7
  #include <rice/Constructor.hpp>
8
8
 
9
- using namespace Rice;
10
-
11
- template<>
12
- inline
13
- long long from_ruby<long long>(Object x)
14
- {
15
- return NUM2LL(x);
16
- }
9
+ #include "templates.hpp"
17
10
 
18
- template<>
19
- inline
20
- Object to_ruby<long long>(long long const & x)
21
- {
22
- return LL2NUM(x);
23
- }
11
+ // generated with:
12
+ // rake generate:functions
13
+ #include "torch_functions.hpp"
14
+ #include "tensor_functions.hpp"
15
+ #include "nn_functions.hpp"
24
16
 
25
- template<>
26
- inline
27
- unsigned long long from_ruby<unsigned long long>(Object x)
28
- {
29
- return NUM2ULL(x);
30
- }
31
-
32
- template<>
33
- inline
34
- Object to_ruby<unsigned long long>(unsigned long long const & x)
35
- {
36
- return ULL2NUM(x);
37
- }
38
-
39
- template<>
40
- inline
41
- short from_ruby<short>(Object x)
42
- {
43
- return NUM2SHORT(x);
44
- }
45
-
46
- template<>
47
- inline
48
- Object to_ruby<short>(short const & x)
49
- {
50
- return INT2NUM(x);
51
- }
52
-
53
- template<>
54
- inline
55
- unsigned short from_ruby<unsigned short>(Object x)
56
- {
57
- return NUM2USHORT(x);
58
- }
17
+ using namespace Rice;
59
18
 
60
- template<>
61
- inline
62
- Object to_ruby<unsigned short>(unsigned short const & x)
19
+ extern "C"
20
+ void Init_ext()
63
21
  {
64
- return UINT2NUM(x);
65
- }
22
+ Module rb_mTorch = define_module("Torch");
23
+ add_torch_functions(rb_mTorch);
66
24
 
67
- // need to wrap torch::IntArrayRef() since
68
- // it doesn't own underlying data
69
- class IntArrayRef {
70
- std::vector<int64_t> vec;
71
- public:
72
- IntArrayRef(Object o) {
73
- Array a = Array(o);
74
- for (size_t i = 0; i < a.size(); i++) {
75
- vec.push_back(from_ruby<int64_t>(a[i]));
76
- }
77
- }
78
- operator torch::IntArrayRef() {
79
- return torch::IntArrayRef(vec);
80
- }
81
- };
82
-
83
- template<>
84
- inline
85
- IntArrayRef from_ruby<IntArrayRef>(Object x)
86
- {
87
- return IntArrayRef(x);
88
- }
25
+ Class rb_cTensor = define_class_under<torch::Tensor>(rb_mTorch, "Tensor");
26
+ add_tensor_functions(rb_cTensor);
89
27
 
90
- // for now
91
- typedef float Scalar;
28
+ Module rb_mNN = define_module_under(rb_mTorch, "NN");
29
+ add_nn_functions(rb_mNN);
92
30
 
93
- extern "C"
94
- void Init_ext()
95
- {
96
- Module rb_mTorch = define_module("Torch")
97
- .define_singleton_method(
31
+ rb_mTorch.define_singleton_method(
98
32
  "grad_enabled?",
99
33
  *[]() {
100
34
  return torch::GradMode::is_enabled();
@@ -104,11 +38,6 @@ void Init_ext()
104
38
  *[](bool enabled) {
105
39
  torch::GradMode::set_enabled(enabled);
106
40
  })
107
- .define_singleton_method(
108
- "floating_point?",
109
- *[](torch::Tensor& input) {
110
- return torch::is_floating_point(input);
111
- })
112
41
  .define_singleton_method(
113
42
  "manual_seed",
114
43
  *[](uint64_t seed) {
@@ -177,321 +106,100 @@ void Init_ext()
177
106
  })
178
107
  // begin operations
179
108
  .define_singleton_method(
180
- "_mean",
181
- *[](torch::Tensor& input) {
182
- return torch::mean(input);
183
- })
184
- .define_singleton_method(
185
- "_mean_dim",
186
- *[](torch::Tensor& input, int64_t dim, bool keepdim) {
187
- return torch::mean(input, dim, keepdim);
188
- })
189
- .define_singleton_method(
190
- "_sum",
191
- *[](torch::Tensor& input) {
192
- return torch::sum(input);
193
- })
194
- .define_singleton_method(
195
- "_sum_dim",
196
- *[](torch::Tensor& input, int64_t dim, bool keepdim) {
197
- return torch::sum(input, dim, keepdim);
198
- })
199
- .define_singleton_method(
200
- "_argmax",
201
- *[](torch::Tensor& input) {
202
- return torch::argmax(input);
203
- })
204
- .define_singleton_method(
205
- "_argmax_dim",
206
- *[](torch::Tensor& input, int64_t dim, bool keepdim) {
207
- return torch::argmax(input, dim, keepdim);
208
- })
209
- .define_singleton_method(
210
- "_norm",
211
- *[](torch::Tensor& input) {
212
- return torch::norm(input);
213
- })
214
- .define_singleton_method(
215
- "_min",
216
- *[](torch::Tensor& input) {
217
- return torch::min(input);
218
- })
219
- .define_singleton_method(
220
- "_max",
221
- *[](torch::Tensor& input) {
222
- return torch::max(input);
223
- })
224
- .define_singleton_method(
225
- "_exp",
226
- *[](torch::Tensor& input) {
227
- return torch::exp(input);
228
- })
229
- .define_singleton_method(
230
- "_log",
231
- *[](torch::Tensor& input) {
232
- return torch::log(input);
233
- })
234
- .define_singleton_method(
235
- "_unsqueeze",
236
- *[](torch::Tensor& input, int64_t dim) {
237
- return torch::unsqueeze(input, dim);
238
- })
239
- .define_singleton_method(
240
- "_dot",
241
- *[](torch::Tensor& input, torch::Tensor& tensor) {
242
- return torch::dot(input, tensor);
243
- })
244
- .define_singleton_method(
245
- "_matmul",
246
- *[](torch::Tensor& input, torch::Tensor& other) {
247
- return torch::matmul(input, other);
248
- })
249
- .define_singleton_method(
250
- "_eq",
251
- *[](torch::Tensor& input, torch::Tensor& other) {
252
- return torch::eq(input, other);
253
- })
254
- .define_singleton_method(
255
- "_add",
256
- *[](torch::Tensor& input, torch::Tensor& other) {
257
- return torch::add(input, other);
258
- })
259
- .define_singleton_method(
260
- "_add_scalar",
261
- *[](torch::Tensor& input, float other) {
262
- return torch::add(input, other);
263
- })
264
- .define_singleton_method(
265
- "_add_out",
266
- *[](torch::Tensor& out, torch::Tensor& input, torch::Tensor& other) {
267
- return torch::add_out(out, input, other);
268
- })
269
- .define_singleton_method(
270
- "_sub",
271
- *[](torch::Tensor& input, torch::Tensor& other) {
272
- return torch::sub(input, other);
273
- })
274
- .define_singleton_method(
275
- "_sub_scalar",
276
- *[](torch::Tensor& input, float other) {
277
- return torch::sub(input, other);
278
- })
279
- .define_singleton_method(
280
- "_mul",
281
- *[](torch::Tensor& input, torch::Tensor& other) {
282
- return torch::mul(input, other);
283
- })
284
- .define_singleton_method(
285
- "_mul_scalar",
286
- *[](torch::Tensor& input, float other) {
287
- return torch::mul(input, other);
288
- })
289
- .define_singleton_method(
290
- "_div",
291
- *[](torch::Tensor& input, torch::Tensor& other) {
292
- return torch::div(input, other);
293
- })
294
- .define_singleton_method(
295
- "_div_scalar",
296
- *[](torch::Tensor& input, float other) {
297
- return torch::div(input, other);
298
- })
299
- .define_singleton_method(
300
- "_remainder",
301
- *[](torch::Tensor& input, torch::Tensor& other) {
302
- return torch::remainder(input, other);
109
+ "_save",
110
+ *[](const Tensor &value) {
111
+ auto v = torch::pickle_save(value);
112
+ std::string str(v.begin(), v.end());
113
+ return str;
303
114
  })
304
115
  .define_singleton_method(
305
- "_remainder_scalar",
306
- *[](torch::Tensor& input, float other) {
307
- return torch::remainder(input, other);
116
+ "_binary_cross_entropy_with_logits",
117
+ *[](const Tensor &input, const Tensor &target, OptionalTensor weight, OptionalTensor pos_weight, MyReduction reduction) {
118
+ return torch::binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction);
308
119
  })
309
120
  .define_singleton_method(
310
- "_pow",
311
- *[](torch::Tensor& input, Scalar exponent) {
312
- return torch::pow(input, exponent);
313
- })
314
- .define_singleton_method(
315
- "_neg",
316
- *[](torch::Tensor& input) {
317
- return torch::neg(input);
318
- })
319
- .define_singleton_method(
320
- "_reshape",
321
- *[](torch::Tensor& input, IntArrayRef shape) {
322
- return torch::reshape(input, shape);
323
- })
324
- .define_singleton_method(
325
- "relu",
326
- *[](torch::Tensor& input) {
327
- return torch::relu(input);
328
- })
329
- .define_singleton_method(
330
- "prelu",
331
- *[](torch::Tensor& input, torch::Tensor& weight) {
332
- return torch::prelu(input, weight);
333
- })
334
- .define_singleton_method(
335
- "leaky_relu",
336
- *[](torch::Tensor& input, Scalar negative_slope = 0.01) {
337
- return torch::leaky_relu(input, negative_slope);
338
- })
339
- .define_singleton_method(
340
- "conv2d",
341
- *[](torch::Tensor& input, torch::Tensor& weight, torch::Tensor& bias, IntArrayRef stride, IntArrayRef padding) {
342
- return torch::conv2d(input, weight, bias, stride, padding);
343
- })
344
- .define_singleton_method(
345
- "linear",
346
- *[](torch::Tensor& input, torch::Tensor& weight, torch::Tensor& bias) {
347
- return torch::linear(input, weight, bias);
348
- })
349
- .define_singleton_method(
350
- "max_pool2d",
351
- *[](torch::Tensor& input, IntArrayRef kernel_size) {
352
- return torch::max_pool2d(input, kernel_size);
353
- })
354
- .define_singleton_method(
355
- "avg_pool2d",
356
- *[](torch::Tensor& input, IntArrayRef kernel_size) {
357
- return torch::avg_pool2d(input, kernel_size);
358
- })
359
- .define_singleton_method(
360
- "mse_loss",
361
- *[](torch::Tensor& input, torch::Tensor& target, std::string reduction) {
362
- auto red = reduction == "mean" ? Reduction::Mean : Reduction::Sum;
363
- return torch::mse_loss(input, target, red);
364
- })
365
- .define_singleton_method(
366
- "nll_loss",
367
- *[](torch::Tensor& input, torch::Tensor& target) {
368
- return torch::nll_loss(input, target);
121
+ "_from_blob",
122
+ *[](String s, IntArrayRef size, const torch::TensorOptions &options) {
123
+ void *data = const_cast<char *>(s.c_str());
124
+ return torch::from_blob(data, size, options);
369
125
  })
370
126
  .define_singleton_method(
371
127
  "_tensor",
372
128
  *[](Object o, IntArrayRef size, const torch::TensorOptions &options) {
373
129
  Array a = Array(o);
374
- std::vector<float> vec;
375
- for (size_t i = 0; i < a.size(); i++) {
376
- vec.push_back(from_ruby<float>(a[i]));
130
+ auto dtype = options.dtype();
131
+ torch::Tensor t;
132
+ if (dtype == torch::kBool) {
133
+ throw std::runtime_error("Cannot create bool from tensor method yet");
134
+ } else {
135
+ std::vector<float> vec;
136
+ for (size_t i = 0; i < a.size(); i++) {
137
+ vec.push_back(from_ruby<float>(a[i]));
138
+ }
139
+ t = torch::tensor(vec, options);
377
140
  }
378
- return torch::tensor(vec, options).reshape(size);
141
+ return t.reshape(size);
379
142
  });
380
143
 
381
- Class rb_cTensor = define_class_under<torch::Tensor>(rb_mTorch, "Tensor")
144
+ rb_cTensor
382
145
  .define_method("cuda?", &torch::Tensor::is_cuda)
383
- .define_method("distributed?", &torch::Tensor::is_distributed)
384
- .define_method("complex?", &torch::Tensor::is_complex)
385
- .define_method("floating_point?", &torch::Tensor::is_floating_point)
386
- .define_method("signed?", &torch::Tensor::is_signed)
387
146
  .define_method("sparse?", &torch::Tensor::is_sparse)
388
147
  .define_method("quantized?", &torch::Tensor::is_quantized)
389
148
  .define_method("dim", &torch::Tensor::dim)
390
- .define_method("numel", &torch::Tensor::numel)
391
149
  .define_method("element_size", &torch::Tensor::element_size)
392
150
  .define_method("requires_grad", &torch::Tensor::requires_grad)
393
151
  .define_method(
394
- "zero!",
395
- *[](torch::Tensor& self) {
396
- return self.zero_();
397
- })
398
- .define_method(
399
- "detach!",
400
- *[](torch::Tensor& self) {
401
- return self.detach_();
152
+ "addcmul!",
153
+ *[](Tensor& self, Scalar value, const Tensor & tensor1, const Tensor & tensor2) {
154
+ return self.addcmul_(tensor1, tensor2, value);
402
155
  })
403
156
  .define_method(
404
- "_select",
405
- *[](torch::Tensor& self, int64_t dim, int64_t index) {
406
- return self.select(dim, index);
407
- })
408
- .define_method(
409
- "_slice",
410
- *[](torch::Tensor& self, int64_t dim, int64_t start, int64_t end, int64_t step) {
411
- return self.slice(dim, start, end, step);
157
+ "addcdiv!",
158
+ *[](Tensor& self, Scalar value, const Tensor & tensor1, const Tensor & tensor2) {
159
+ return self.addcdiv_(tensor1, tensor2, value);
412
160
  })
413
161
  .define_method(
414
162
  "_requires_grad!",
415
- *[](torch::Tensor& self, bool requires_grad) {
163
+ *[](Tensor& self, bool requires_grad) {
416
164
  return self.set_requires_grad(requires_grad);
417
165
  })
418
166
  .define_method(
419
167
  "_backward",
420
- *[](torch::Tensor& self) {
421
- return self.backward();
422
- })
423
- .define_method(
424
- "_backward_gradient",
425
- *[](torch::Tensor& self, const torch::Tensor& gradient) {
426
- return self.backward(gradient);
168
+ *[](Tensor& self, Object gradient) {
169
+ return gradient.is_nil() ? self.backward() : self.backward(from_ruby<torch::Tensor>(gradient));
427
170
  })
428
171
  .define_method(
429
172
  "grad",
430
- *[](torch::Tensor& self) {
173
+ *[](Tensor& self) {
431
174
  return self.grad();
432
175
  })
433
176
  .define_method(
434
177
  "_dtype",
435
- *[](torch::Tensor& self) {
178
+ *[](Tensor& self) {
436
179
  return (int) at::typeMetaToScalarType(self.dtype());
437
180
  })
438
181
  .define_method(
439
182
  "_type",
440
- *[](torch::Tensor& self, int dtype) {
183
+ *[](Tensor& self, int dtype) {
441
184
  return self.toType((torch::ScalarType) dtype);
442
185
  })
443
186
  .define_method(
444
187
  "_layout",
445
- *[](torch::Tensor& self) {
188
+ *[](Tensor& self) {
446
189
  std::stringstream s;
447
190
  s << self.layout();
448
191
  return s.str();
449
192
  })
450
193
  .define_method(
451
194
  "device",
452
- *[](torch::Tensor& self) {
195
+ *[](Tensor& self) {
453
196
  std::stringstream s;
454
197
  s << self.device();
455
198
  return s.str();
456
199
  })
457
200
  .define_method(
458
- "_view",
459
- *[](torch::Tensor& self, IntArrayRef size) {
460
- return self.view(size);
461
- })
462
- .define_method(
463
- "add!",
464
- *[](torch::Tensor& self, torch::Tensor& other) {
465
- self.add_(other);
466
- })
467
- .define_method(
468
- "sub!",
469
- *[](torch::Tensor& self, torch::Tensor& other) {
470
- self.sub_(other);
471
- })
472
- .define_method(
473
- "mul!",
474
- *[](torch::Tensor& self, torch::Tensor& other) {
475
- self.mul_(other);
476
- })
477
- .define_method(
478
- "div!",
479
- *[](torch::Tensor& self, torch::Tensor& other) {
480
- self.div_(other);
481
- })
482
- .define_method(
483
- "log_softmax",
484
- *[](torch::Tensor& self, int64_t dim) {
485
- return self.log_softmax(dim);
486
- })
487
- .define_method(
488
- "data",
489
- *[](torch::Tensor& self) {
490
- return self.data();
491
- })
492
- .define_method(
493
- "_data",
494
- *[](torch::Tensor& self) {
201
+ "_flat_data",
202
+ *[](Tensor& self) {
495
203
  Array a;
496
204
  auto dtype = self.dtype();
497
205
 
@@ -532,21 +240,23 @@ void Init_ext()
532
240
  a.push(data[i]);
533
241
  }
534
242
  } else if (dtype == torch::kBool) {
535
- // bool
536
- throw std::runtime_error("Type not supported yet");
243
+ bool* data = self.data_ptr<bool>();
244
+ for (int i = 0; i < self.numel(); i++) {
245
+ a.push(data[i] ? True : False);
246
+ }
537
247
  } else {
538
248
  throw std::runtime_error("Unsupported type");
539
249
  }
540
250
  return a;
541
251
  })
542
252
  .define_method(
543
- "_size",
544
- *[](torch::Tensor& self, int i) {
545
- return self.size(i);
253
+ "_to",
254
+ *[](Tensor& self, torch::Device device, int dtype, bool non_blocking, bool copy) {
255
+ return self.to(device, (torch::ScalarType) dtype, non_blocking, copy);
546
256
  })
547
257
  .define_singleton_method(
548
258
  "_make_subclass",
549
- *[](torch::Tensor& rd, bool requires_grad) {
259
+ *[](Tensor& rd, bool requires_grad) {
550
260
  auto data = torch::autograd::as_variable_ref(rd).detach();
551
261
  data.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(true);
552
262
  auto var = data.set_requires_grad(requires_grad);
@@ -593,30 +303,99 @@ void Init_ext()
593
303
  return self.requires_grad(requires_grad);
594
304
  });
595
305
 
596
- Module rb_mNN = define_module_under(rb_mTorch, "NN");
597
-
598
306
  Module rb_mInit = define_module_under(rb_mNN, "Init")
599
307
  .define_singleton_method(
600
- "kaiming_uniform_",
601
- *[](torch::Tensor& input, double a) {
602
- return torch::nn::init::kaiming_uniform_(input, a);
308
+ "_calculate_gain",
309
+ *[](NonlinearityType nonlinearity, double param) {
310
+ return torch::nn::init::calculate_gain(nonlinearity, param);
311
+ })
312
+ .define_singleton_method(
313
+ "_uniform!",
314
+ *[](Tensor tensor, double low, double high) {
315
+ return torch::nn::init::uniform_(tensor, low, high);
316
+ })
317
+ .define_singleton_method(
318
+ "_normal!",
319
+ *[](Tensor tensor, double mean, double std) {
320
+ return torch::nn::init::normal_(tensor, mean, std);
321
+ })
322
+ .define_singleton_method(
323
+ "_constant!",
324
+ *[](Tensor tensor, Scalar value) {
325
+ return torch::nn::init::constant_(tensor, value);
326
+ })
327
+ .define_singleton_method(
328
+ "_ones!",
329
+ *[](Tensor tensor) {
330
+ return torch::nn::init::ones_(tensor);
331
+ })
332
+ .define_singleton_method(
333
+ "_zeros!",
334
+ *[](Tensor tensor) {
335
+ return torch::nn::init::zeros_(tensor);
336
+ })
337
+ .define_singleton_method(
338
+ "_eye!",
339
+ *[](Tensor tensor) {
340
+ return torch::nn::init::eye_(tensor);
341
+ })
342
+ .define_singleton_method(
343
+ "_dirac!",
344
+ *[](Tensor tensor) {
345
+ return torch::nn::init::dirac_(tensor);
346
+ })
347
+ .define_singleton_method(
348
+ "_xavier_uniform!",
349
+ *[](Tensor tensor, double gain) {
350
+ return torch::nn::init::xavier_uniform_(tensor, gain);
351
+ })
352
+ .define_singleton_method(
353
+ "_xavier_normal!",
354
+ *[](Tensor tensor, double gain) {
355
+ return torch::nn::init::xavier_normal_(tensor, gain);
356
+ })
357
+ .define_singleton_method(
358
+ "_kaiming_uniform!",
359
+ *[](Tensor tensor, double a, FanModeType mode, NonlinearityType nonlinearity) {
360
+ return torch::nn::init::kaiming_uniform_(tensor, a, mode, nonlinearity);
361
+ })
362
+ .define_singleton_method(
363
+ "_kaiming_normal!",
364
+ *[](Tensor tensor, double a, FanModeType mode, NonlinearityType nonlinearity) {
365
+ return torch::nn::init::kaiming_normal_(tensor, a, mode, nonlinearity);
366
+ })
367
+ .define_singleton_method(
368
+ "_orthogonal!",
369
+ *[](Tensor tensor, double gain) {
370
+ return torch::nn::init::orthogonal_(tensor, gain);
603
371
  })
604
372
  .define_singleton_method(
605
- "uniform_",
606
- *[](torch::Tensor& input, double to, double from) {
607
- return torch::nn::init::uniform_(input, to, from);
373
+ "_sparse!",
374
+ *[](Tensor tensor, double sparsity, double std) {
375
+ return torch::nn::init::sparse_(tensor, sparsity, std);
608
376
  });
609
377
 
610
378
  Class rb_cParameter = define_class_under<torch::autograd::Variable, torch::Tensor>(rb_mNN, "Parameter")
611
- // TODO return grad or nil to remove need for 2nd function
612
379
  .define_method(
613
- "_grad",
380
+ "grad",
614
381
  *[](torch::autograd::Variable& self) {
615
- return self.grad();
616
- })
382
+ auto grad = self.grad();
383
+ return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
384
+ });
385
+
386
+ Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")
387
+ .define_constructor(Constructor<torch::Device, std::string>())
388
+ .define_method("index", &torch::Device::index)
389
+ .define_method("index?", &torch::Device::has_index)
617
390
  .define_method(
618
- "_grad_defined",
619
- *[](torch::autograd::Variable& self) {
620
- return self.grad().defined();
391
+ "type",
392
+ *[](torch::Device& self) {
393
+ std::stringstream s;
394
+ s << self.type();
395
+ return s.str();
621
396
  });
397
+
398
+ Module rb_mCUDA = define_module_under(rb_mTorch, "CUDA")
399
+ .define_singleton_method("available?", &torch::cuda::is_available)
400
+ .define_singleton_method("device_count", &torch::cuda::device_count);
622
401
  }