torch-rb 0.1.3 → 0.1.8

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115)
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +30 -0
  3. data/README.md +5 -2
  4. data/ext/torch/ext.cpp +130 -555
  5. data/ext/torch/extconf.rb +9 -0
  6. data/ext/torch/templates.cpp +55 -0
  7. data/ext/torch/templates.hpp +244 -0
  8. data/lib/torch.rb +209 -171
  9. data/lib/torch/inspector.rb +23 -19
  10. data/lib/torch/native/dispatcher.rb +48 -0
  11. data/lib/torch/native/function.rb +110 -0
  12. data/lib/torch/native/generator.rb +168 -0
  13. data/lib/torch/native/native_functions.yaml +6491 -0
  14. data/lib/torch/native/parser.rb +134 -0
  15. data/lib/torch/nn/avg_pool1d.rb +18 -0
  16. data/lib/torch/nn/avg_pool2d.rb +19 -0
  17. data/lib/torch/nn/avg_pool3d.rb +19 -0
  18. data/lib/torch/nn/avg_poolnd.rb +9 -0
  19. data/lib/torch/nn/batch_norm.rb +75 -0
  20. data/lib/torch/nn/batch_norm1d.rb +11 -0
  21. data/lib/torch/nn/batch_norm2d.rb +11 -0
  22. data/lib/torch/nn/batch_norm3d.rb +11 -0
  23. data/lib/torch/nn/bce_loss.rb +13 -0
  24. data/lib/torch/nn/bce_with_logits_loss.rb +15 -0
  25. data/lib/torch/nn/bilinear.rb +38 -0
  26. data/lib/torch/nn/constant_pad1d.rb +10 -0
  27. data/lib/torch/nn/constant_pad2d.rb +10 -0
  28. data/lib/torch/nn/constant_pad3d.rb +10 -0
  29. data/lib/torch/nn/constant_padnd.rb +18 -0
  30. data/lib/torch/nn/conv1d.rb +22 -0
  31. data/lib/torch/nn/conv2d.rb +10 -20
  32. data/lib/torch/nn/conv3d.rb +22 -0
  33. data/lib/torch/nn/convnd.rb +3 -3
  34. data/lib/torch/nn/cosine_embedding_loss.rb +14 -0
  35. data/lib/torch/nn/cosine_similarity.rb +15 -0
  36. data/lib/torch/nn/cross_entropy_loss.rb +14 -0
  37. data/lib/torch/nn/ctc_loss.rb +15 -0
  38. data/lib/torch/nn/dropoutnd.rb +2 -2
  39. data/lib/torch/nn/embedding_bag.rb +34 -0
  40. data/lib/torch/nn/fold.rb +20 -0
  41. data/lib/torch/nn/functional.rb +379 -32
  42. data/lib/torch/nn/group_norm.rb +36 -0
  43. data/lib/torch/nn/gru.rb +49 -0
  44. data/lib/torch/nn/hardshrink.rb +18 -0
  45. data/lib/torch/nn/hinge_embedding_loss.rb +14 -0
  46. data/lib/torch/nn/identity.rb +14 -0
  47. data/lib/torch/nn/init.rb +58 -1
  48. data/lib/torch/nn/instance_norm.rb +20 -0
  49. data/lib/torch/nn/instance_norm1d.rb +18 -0
  50. data/lib/torch/nn/instance_norm2d.rb +11 -0
  51. data/lib/torch/nn/instance_norm3d.rb +11 -0
  52. data/lib/torch/nn/kl_div_loss.rb +13 -0
  53. data/lib/torch/nn/l1_loss.rb +13 -0
  54. data/lib/torch/nn/layer_norm.rb +35 -0
  55. data/lib/torch/nn/leaky_relu.rb +20 -0
  56. data/lib/torch/nn/linear.rb +12 -11
  57. data/lib/torch/nn/local_response_norm.rb +21 -0
  58. data/lib/torch/nn/log_sigmoid.rb +9 -0
  59. data/lib/torch/nn/log_softmax.rb +14 -0
  60. data/lib/torch/nn/loss.rb +10 -0
  61. data/lib/torch/nn/lp_pool1d.rb +9 -0
  62. data/lib/torch/nn/lp_pool2d.rb +9 -0
  63. data/lib/torch/nn/lp_poolnd.rb +22 -0
  64. data/lib/torch/nn/lstm.rb +66 -0
  65. data/lib/torch/nn/margin_ranking_loss.rb +14 -0
  66. data/lib/torch/nn/max_pool1d.rb +9 -0
  67. data/lib/torch/nn/max_pool2d.rb +9 -0
  68. data/lib/torch/nn/max_pool3d.rb +9 -0
  69. data/lib/torch/nn/max_poolnd.rb +19 -0
  70. data/lib/torch/nn/max_unpool1d.rb +16 -0
  71. data/lib/torch/nn/max_unpool2d.rb +16 -0
  72. data/lib/torch/nn/max_unpool3d.rb +16 -0
  73. data/lib/torch/nn/max_unpoolnd.rb +9 -0
  74. data/lib/torch/nn/module.rb +186 -35
  75. data/lib/torch/nn/mse_loss.rb +2 -2
  76. data/lib/torch/nn/multi_label_margin_loss.rb +13 -0
  77. data/lib/torch/nn/multi_label_soft_margin_loss.rb +13 -0
  78. data/lib/torch/nn/multi_margin_loss.rb +17 -0
  79. data/lib/torch/nn/nll_loss.rb +14 -0
  80. data/lib/torch/nn/pairwise_distance.rb +16 -0
  81. data/lib/torch/nn/parameter.rb +2 -2
  82. data/lib/torch/nn/poisson_nll_loss.rb +16 -0
  83. data/lib/torch/nn/prelu.rb +19 -0
  84. data/lib/torch/nn/reflection_pad1d.rb +10 -0
  85. data/lib/torch/nn/reflection_pad2d.rb +10 -0
  86. data/lib/torch/nn/reflection_padnd.rb +13 -0
  87. data/lib/torch/nn/relu.rb +8 -3
  88. data/lib/torch/nn/replication_pad1d.rb +10 -0
  89. data/lib/torch/nn/replication_pad2d.rb +10 -0
  90. data/lib/torch/nn/replication_pad3d.rb +10 -0
  91. data/lib/torch/nn/replication_padnd.rb +13 -0
  92. data/lib/torch/nn/rnn.rb +22 -0
  93. data/lib/torch/nn/rnn_base.rb +198 -0
  94. data/lib/torch/nn/sequential.rb +1 -10
  95. data/lib/torch/nn/sigmoid.rb +9 -0
  96. data/lib/torch/nn/smooth_l1_loss.rb +13 -0
  97. data/lib/torch/nn/soft_margin_loss.rb +13 -0
  98. data/lib/torch/nn/softmax.rb +18 -0
  99. data/lib/torch/nn/softmax2d.rb +10 -0
  100. data/lib/torch/nn/softmin.rb +14 -0
  101. data/lib/torch/nn/softplus.rb +19 -0
  102. data/lib/torch/nn/softshrink.rb +18 -0
  103. data/lib/torch/nn/softsign.rb +9 -0
  104. data/lib/torch/nn/tanh.rb +9 -0
  105. data/lib/torch/nn/tanhshrink.rb +9 -0
  106. data/lib/torch/nn/triplet_margin_loss.rb +18 -0
  107. data/lib/torch/nn/unfold.rb +19 -0
  108. data/lib/torch/nn/utils.rb +25 -0
  109. data/lib/torch/nn/weighted_loss.rb +10 -0
  110. data/lib/torch/nn/zero_pad2d.rb +9 -0
  111. data/lib/torch/random.rb +10 -0
  112. data/lib/torch/tensor.rb +51 -44
  113. data/lib/torch/version.rb +1 -1
  114. metadata +98 -6
  115. data/lib/torch/ext.bundle +0 -0
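Most of the new files are Ruby wrappers for nn modules (pooling, normalization, recurrent, and loss layers). For orientation, a hedged sketch of what the new recurrent layers enable; the constructor options and return shape here are assumptions modeled on the PyTorch API the gem follows, not confirmed signatures:

```ruby
require "torch"

# Recurrent layers were added in 0.1.6 (see lstm.rb / gru.rb / rnn.rb above).
lstm = Torch::NN::LSTM.new(10, 20, num_layers: 2)   # input size 10, hidden size 20
input = Torch.randn(5, 3, 10)                        # (seq_len, batch, input_size)
output, (h_n, c_n) = lstm.call(input)                # output plus final hidden/cell state
```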
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz: e7f715179c9a84dc7399b80d93fd61f2bbb58a0156e6084dc4abb23e1d4a1b52
-  data.tar.gz: 6928379ae7c92a77ad9dde4f4224ec33c6f8575a9b77585c0147e4f5361021de
+  metadata.gz: fca87cb9b6d255287e9fafadf786c113798abbe76b36c82b8271b79cfbf3c2b9
+  data.tar.gz: 4813c71f5ad6d078e78da03cf59f8036e9e76258ffb67f538899bba146dcba2a
 SHA512:
-  metadata.gz: 9911a9e86d93f1e410776c44fdb3cd9aa06c83d1f0e42fdab8530970bea6520aed7906e96fb8243efd6b957453ebc13678b2b92e4c85b54407030a32c6196e08
-  data.tar.gz: 0d080f5458a5dcf8fee19ce5e2e342bf6269432de6e78d923036232963ebb80daeea993c0bbf4af2d6da46593ac28a72a8232020a9fcb48acc3276c9e1ebebf3
+  metadata.gz: 22c7150e6a7d9132c40c67819beecc6b8c69b268bd227a8e4aa324ef5e2707004691d5b65dcd4ba1ac537bfaf783947da7e5a323417cffcbf7d348768c40b7c6
+  data.tar.gz: 8a86c6b68efe6ad85a261d7033b87f040c22b2c670a0238accd6246274caed17b86d7b424441bba80c5ea67ec1bf53b05444dfb0c45ea5b8a52806d0ce19ec1e
data/CHANGELOG.md CHANGED
@@ -1,3 +1,33 @@
+## 0.1.8 (2020-01-17)
+
+- Added support for libtorch 1.4.0
+- Dropped support for libtorch 1.3.1
+
+## 0.1.7 (2020-01-10)
+
+- Fixed installation error with Ruby 2.7
+
+## 0.1.6 (2019-12-09)
+
+- Added recurrent layers
+- Added more pooling layers
+- Added normalization layers
+
+## 0.1.5 (2019-12-06)
+
+- Added many more functions
+- Added tensor classes - `FloatTensor`, `LongTensor`, etc
+- Improved modules
+
+## 0.1.4 (2019-12-01)
+
+- Added distance functions
+- Added more activations
+- Added more linear layers
+- Added more loss functions
+- Added more init methods
+- Added support for tensor assignment
+
 ## 0.1.3 (2019-11-30)
 
 - Changed to BSD 3-Clause license to match PyTorch
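A minimal sketch of the autograd and tensor-assignment features listed above, based on the public torch-rb README; exact method behavior is hedged to this release:

```ruby
require "torch"

# Autograd: build a computation and backpropagate
x = Torch.ones(2, 2, requires_grad: true)
y = x + 2
out = (y * y * 3).mean
out.backward
p x.grad # gradient of out with respect to x

# Tensor assignment (added in 0.1.4)
t = Torch.zeros(3)
t[0] = 1
```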
data/README.md CHANGED
@@ -20,6 +20,8 @@ Add this line to your application’s Gemfile:
 gem 'torch-rb'
 ```
 
+It can take a few minutes to compile the extension.
+
 ## Getting Started
 
 This library follows the [PyTorch API](https://pytorch.org/docs/stable/torch.html). There are a few changes to make it more Ruby-like:
@@ -28,7 +30,7 @@ This library follows the [PyTorch API](https://pytorch.org/docs/stable/torch.htm
 - Methods that return booleans use `?` instead of `is_` (`tensor?` instead of `is_tensor`)
 - Numo is used instead of NumPy (`x.numo` instead of `x.numpy()`)
 
-Many methods and options are missing at the moment. PRs welcome!
+Some methods and options are missing at the moment. PRs welcome!
 
 ## Tutorial
 
@@ -365,8 +367,9 @@ Here’s a list of functions to create tensors (descriptions from the [C++ docs]
 
 Here are a few full examples:
 
-- [Image classification with MNIST](examples/mnist)
+- [Image classification with MNIST](examples/mnist) ([日本語版](https://qiita.com/kojix2/items/c19c36dc1bf73ea93409))
 - [Collaborative filtering with MovieLens](examples/movielens)
+- [Sequence models and word embeddings](examples/nlp)
 
 ## LibTorch Installation
data/ext/torch/ext.cpp CHANGED
@@ -6,137 +6,35 @@
 #include <rice/Class.hpp>
 #include <rice/Constructor.hpp>
 
-using namespace Rice;
-
-template<>
-inline
-long long from_ruby<long long>(Object x)
-{
-  return NUM2LL(x);
-}
-
-template<>
-inline
-Object to_ruby<long long>(long long const & x)
-{
-  return LL2NUM(x);
-}
-
-template<>
-inline
-unsigned long long from_ruby<unsigned long long>(Object x)
-{
-  return NUM2ULL(x);
-}
-
-template<>
-inline
-Object to_ruby<unsigned long long>(unsigned long long const & x)
-{
-  return ULL2NUM(x);
-}
-
-template<>
-inline
-short from_ruby<short>(Object x)
-{
-  return NUM2SHORT(x);
-}
-
-template<>
-inline
-Object to_ruby<short>(short const & x)
-{
-  return INT2NUM(x);
-}
-
-template<>
-inline
-unsigned short from_ruby<unsigned short>(Object x)
-{
-  return NUM2USHORT(x);
-}
+#include "templates.hpp"
 
-template<>
-inline
-Object to_ruby<unsigned short>(unsigned short const & x)
-{
-  return UINT2NUM(x);
-}
+// generated with:
+// rake generate:functions
+#include "torch_functions.hpp"
+#include "tensor_functions.hpp"
+#include "nn_functions.hpp"
 
-// need to wrap torch::IntArrayRef() since
-// it doesn't own underlying data
-class IntArrayRef {
-  std::vector<int64_t> vec;
-  public:
-    IntArrayRef(Object o) {
-      Array a = Array(o);
-      for (size_t i = 0; i < a.size(); i++) {
-        vec.push_back(from_ruby<int64_t>(a[i]));
-      }
-    }
-    operator torch::IntArrayRef() {
-      return torch::IntArrayRef(vec);
-    }
-};
-
-template<>
-inline
-IntArrayRef from_ruby<IntArrayRef>(Object x)
-{
-  return IntArrayRef(x);
-}
+using namespace Rice;
 
-// for now
-class Scalar {
-  torch::Scalar value;
+// need to make a distinction between parameters and tensors
+class Parameter: public torch::autograd::Variable {
   public:
-    Scalar(Object o) {
-      // TODO cast based on Ruby type
-      if (o.rb_type() == T_FIXNUM) {
-        value = torch::Scalar(from_ruby<int64_t>(o));
-      } else {
-        value = torch::Scalar(from_ruby<float>(o));
-      }
-    }
-    operator torch::Scalar() {
-      return value;
-    }
+    Parameter(Tensor&& t) : torch::autograd::Variable(t) { }
 };
 
-template<>
-inline
-Scalar from_ruby<Scalar>(Object x)
+extern "C"
+void Init_ext()
 {
-  return Scalar(x);
-}
+  Module rb_mTorch = define_module("Torch");
+  add_torch_functions(rb_mTorch);
 
-class TensorList {
-  std::vector<torch::Tensor> vec;
-  public:
-    TensorList(Object o) {
-      Array a = Array(o);
-      for (size_t i = 0; i < a.size(); i++) {
-        vec.push_back(from_ruby<torch::Tensor>(a[i]));
-      }
-    }
-    operator torch::TensorList() {
-      return torch::TensorList(vec);
-    }
-};
+  Class rb_cTensor = define_class_under<torch::Tensor>(rb_mTorch, "Tensor");
+  add_tensor_functions(rb_cTensor);
 
-template<>
-inline
-TensorList from_ruby<TensorList>(Object x)
-{
-  return TensorList(x);
-}
+  Module rb_mNN = define_module_under(rb_mTorch, "NN");
+  add_nn_functions(rb_mNN);
 
-extern "C"
-void Init_ext()
-{
-  Module rb_mTorch = define_module("Torch")
-    .define_singleton_method(
+  rb_mTorch.define_singleton_method(
       "grad_enabled?",
       *[]() {
        return torch::GradMode::is_enabled();
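The hunk above replaces the hand-written Rice conversions and bindings with code generated by `rake generate:functions` from `data/lib/torch/native/native_functions.yaml` (see the new `dispatcher.rb`, `function.rb`, `generator.rb`, and `parser.rb` files in the list), exposed via `add_torch_functions`, `add_tensor_functions`, and `add_nn_functions`. A rough Ruby sketch of the dispatch idea only; the method and binding names below are hypothetical, not the gem's actual generated code:

```ruby
# Hypothetical sketch: a generated dispatcher routes a public method to
# whichever native binding matches the argument types.
module Torch
  def self.add(input, other)
    if other.is_a?(Tensor)
      _add_tensor(input, other)  # hypothetical generated binding
    else
      _add_scalar(input, other)  # hypothetical generated binding
    end
  end
end
```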
@@ -146,11 +44,6 @@ void Init_ext()
       *[](bool enabled) {
         torch::GradMode::set_enabled(enabled);
       })
-    .define_singleton_method(
-      "floating_point?",
-      *[](torch::Tensor& input) {
-        return torch::is_floating_point(input);
-      })
     .define_singleton_method(
       "manual_seed",
       *[](uint64_t seed) {
@@ -219,277 +112,17 @@ void Init_ext()
       })
     // begin operations
     .define_singleton_method(
-      "_mean",
-      *[](torch::Tensor& input) {
-        return torch::mean(input);
-      })
-    .define_singleton_method(
-      "_mean_dim",
-      *[](torch::Tensor& input, int64_t dim, bool keepdim) {
-        return torch::mean(input, dim, keepdim);
-      })
-    .define_singleton_method(
-      "_sum",
-      *[](torch::Tensor& input) {
-        return torch::sum(input);
-      })
-    .define_singleton_method(
-      "_sum_dim",
-      *[](torch::Tensor& input, int64_t dim, bool keepdim) {
-        return torch::sum(input, dim, keepdim);
-      })
-    .define_singleton_method(
-      "_argmax",
-      *[](torch::Tensor& input) {
-        return torch::argmax(input);
-      })
-    .define_singleton_method(
-      "_argmax_dim",
-      *[](torch::Tensor& input, int64_t dim, bool keepdim) {
-        return torch::argmax(input, dim, keepdim);
-      })
-    .define_singleton_method(
-      "_cat",
-      *[](TensorList tensors, int64_t dim) {
-        return torch::cat(tensors, dim);
-      })
-    .define_singleton_method(
-      "_norm",
-      *[](torch::Tensor& input) {
-        return torch::norm(input);
-      })
-    .define_singleton_method(
-      "_min",
-      *[](torch::Tensor& input) {
-        return torch::min(input);
-      })
-    .define_singleton_method(
-      "_max",
-      *[](torch::Tensor& input) {
-        return torch::max(input);
-      })
-    .define_singleton_method(
-      "_max_out",
-      *[](torch::Tensor &max, torch::Tensor &max_indices, const torch::Tensor &input, int64_t dim, bool keepdim) {
-        // TODO add return value
-        torch::_max_out(max, max_indices, input, dim, keepdim);
-      })
-    .define_singleton_method(
-      "_sqrt",
-      *[](torch::Tensor& input) {
-        return torch::sqrt(input);
-      })
-    .define_singleton_method(
-      "_exp",
-      *[](torch::Tensor& input) {
-        return torch::exp(input);
-      })
-    .define_singleton_method(
-      "_log",
-      *[](torch::Tensor& input) {
-        return torch::log(input);
-      })
-    .define_singleton_method(
-      "_sign",
-      *[](torch::Tensor& input) {
-        return torch::sign(input);
-      })
-    .define_singleton_method(
-      "_unsqueeze",
-      *[](torch::Tensor& input, int64_t dim) {
-        return torch::unsqueeze(input, dim);
-      })
-    .define_singleton_method(
-      "_dot",
-      *[](torch::Tensor& input, torch::Tensor& tensor) {
-        return torch::dot(input, tensor);
-      })
-    .define_singleton_method(
-      "_matmul",
-      *[](torch::Tensor& input, torch::Tensor& other) {
-        return torch::matmul(input, other);
-      })
-    .define_singleton_method(
-      "_eq",
-      *[](torch::Tensor& input, torch::Tensor& other) {
-        return torch::eq(input, other);
-      })
-    .define_singleton_method(
-      "_gt",
-      // TODO support tensors
-      *[](torch::Tensor& input, Scalar other) {
-        return torch::gt(input, other);
-      })
-    .define_singleton_method(
-      "_lt",
-      // TODO support tensors
-      *[](torch::Tensor& input, Scalar other) {
-        return torch::lt(input, other);
-      })
-    .define_singleton_method(
-      "_add",
-      *[](torch::Tensor& input, torch::Tensor& other) {
-        return torch::add(input, other);
-      })
-    .define_singleton_method(
-      "_add_scalar",
-      *[](torch::Tensor& input, Scalar other) {
-        return torch::add(input, other);
-      })
-    .define_singleton_method(
-      "_add_out",
-      *[](torch::Tensor& out, torch::Tensor& input, torch::Tensor& other) {
-        return torch::add_out(out, input, other);
-      })
-    .define_singleton_method(
-      "_sub",
-      *[](torch::Tensor& input, torch::Tensor& other) {
-        return torch::sub(input, other);
-      })
-    .define_singleton_method(
-      "_sub_scalar",
-      *[](torch::Tensor& input, Scalar other) {
-        return torch::sub(input, other);
-      })
-    .define_singleton_method(
-      "_mul",
-      *[](torch::Tensor& input, torch::Tensor& other) {
-        return torch::mul(input, other);
-      })
-    .define_singleton_method(
-      "_mul_scalar",
-      *[](torch::Tensor& input, Scalar other) {
-        return torch::mul(input, other);
-      })
-    .define_singleton_method(
-      "_div",
-      *[](torch::Tensor& input, torch::Tensor& other) {
-        return torch::div(input, other);
-      })
-    .define_singleton_method(
-      "_div_scalar",
-      *[](torch::Tensor& input, Scalar other) {
-        return torch::div(input, other);
-      })
-    .define_singleton_method(
-      "_remainder",
-      *[](torch::Tensor& input, torch::Tensor& other) {
-        return torch::remainder(input, other);
-      })
-    .define_singleton_method(
-      "_remainder_scalar",
-      *[](torch::Tensor& input, Scalar other) {
-        return torch::remainder(input, other);
-      })
-    .define_singleton_method(
-      "_pow",
-      *[](torch::Tensor& input, Scalar exponent) {
-        return torch::pow(input, exponent);
-      })
-    .define_singleton_method(
-      "_abs",
-      *[](torch::Tensor& input) {
-        return torch::abs(input);
-      })
-    .define_singleton_method(
-      "_neg",
-      *[](torch::Tensor& input) {
-        return torch::neg(input);
+      "_save",
+      *[](const Tensor &value) {
+        auto v = torch::pickle_save(value);
+        std::string str(v.begin(), v.end());
+        return str;
       })
     .define_singleton_method(
-      "_reshape",
-      *[](torch::Tensor& input, IntArrayRef shape) {
-        return torch::reshape(input, shape);
+      "_binary_cross_entropy_with_logits",
+      *[](const Tensor &input, const Tensor &target, OptionalTensor weight, OptionalTensor pos_weight, MyReduction reduction) {
+        return torch::binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction);
       })
-    .define_singleton_method(
-      "_flatten",
-      *[](torch::Tensor& input, int64_t start_dim, int64_t end_dim) {
-        return torch::flatten(input, start_dim, end_dim);
-      })
-    .define_singleton_method(
-      "relu",
-      *[](torch::Tensor& input) {
-        return torch::relu(input);
-      })
-    .define_singleton_method(
-      "conv2d",
-      *[](torch::Tensor& input, torch::Tensor& weight, torch::Tensor& bias, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, int64_t groups) {
-        return torch::conv2d(input, weight, bias, stride, padding, dilation, groups);
-      })
-    .define_singleton_method(
-      "linear",
-      *[](torch::Tensor& input, torch::Tensor& weight, torch::Tensor& bias) {
-        return torch::linear(input, weight, bias);
-      })
-    .define_singleton_method(
-      "max_pool2d",
-      *[](torch::Tensor& input, IntArrayRef kernel_size) {
-        return torch::max_pool2d(input, kernel_size);
-      })
-    .define_singleton_method(
-      "avg_pool2d",
-      *[](torch::Tensor& input, IntArrayRef kernel_size) {
-        return torch::avg_pool2d(input, kernel_size);
-      })
-    .define_singleton_method(
-      "_dropout",
-      *[](torch::Tensor& input, float p, bool train) {
-        return torch::dropout(input, p, train);
-      })
-    .define_singleton_method(
-      "_dropout!",
-      *[](torch::Tensor& input, float p, bool train) {
-        return torch::dropout_(input, p, train);
-      })
-    .define_singleton_method(
-      "_feature_dropout",
-      *[](torch::Tensor& input, float p, bool train) {
-        return torch::feature_dropout(input, p, train);
-      })
-    .define_singleton_method(
-      "_feature_dropout!",
-      *[](torch::Tensor& input, float p, bool train) {
-        return torch::feature_dropout_(input, p, train);
-      })
-    .define_singleton_method(
-      "_alpha_dropout",
-      *[](torch::Tensor& input, float p, bool train) {
-        return torch::alpha_dropout(input, p, train);
-      })
-    .define_singleton_method(
-      "_alpha_dropout!",
-      *[](torch::Tensor& input, float p, bool train) {
-        return torch::alpha_dropout_(input, p, train);
-      })
-    .define_singleton_method(
-      "_feature_alpha_dropout",
-      *[](torch::Tensor& input, float p, bool train) {
-        return torch::feature_alpha_dropout(input, p, train);
-      })
-    .define_singleton_method(
-      "_feature_alpha_dropout!",
-      *[](torch::Tensor& input, float p, bool train) {
-        return torch::feature_alpha_dropout_(input, p, train);
-      })
-    .define_singleton_method(
-      "_embedding",
-      // weight and indices are swapped from Python interface
-      *[](const torch::Tensor &indices, const torch::Tensor &weight, int64_t padding_idx, bool scale_grad_by_freq, bool sparse) {
-        return torch::embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse);
-      })
-    .define_singleton_method(
-      "mse_loss",
-      *[](torch::Tensor& input, torch::Tensor& target, std::string reduction) {
-        auto red = reduction == "mean" ? Reduction::Mean : Reduction::Sum;
-        return torch::mse_loss(input, target, red);
-      })
-    .define_singleton_method(
-      "nll_loss",
-      *[](torch::Tensor& input, torch::Tensor& target, std::string reduction) {
-        auto red = reduction == "mean" ? Reduction::Mean : Reduction::Sum;
-        return torch::nll_loss(input, target, {}, red);
-      })
-    .define_singleton_method("numel", &torch::numel)
     .define_singleton_method(
       "_from_blob",
       *[](String s, IntArrayRef size, const torch::TensorOptions &options) {
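Of the two bindings that survive in `ext.cpp`, `_save` serializes a tensor with `torch::pickle_save` and returns the bytes as a Ruby string, and `_binary_cross_entropy_with_logits` backs the new `BCEWithLogitsLoss` module (file 24 in the list). A usage sketch; `call` invoking `forward` matches the README's module examples, while the sample values and exact keyword support are assumptions:

```ruby
loss_fn = Torch::NN::BCEWithLogitsLoss.new
output = Torch.randn(3, requires_grad: true)  # raw logits
target = Torch.tensor([0.0, 1.0, 1.0])        # labels in [0, 1]
loss = loss_fn.call(output, target)
loss.backward                                  # gradients flow back to output
```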
@@ -500,187 +133,86 @@ void Init_ext()
       "_tensor",
       *[](Object o, IntArrayRef size, const torch::TensorOptions &options) {
         Array a = Array(o);
-        std::vector<float> vec;
-        for (size_t i = 0; i < a.size(); i++) {
-          vec.push_back(from_ruby<float>(a[i]));
+        auto dtype = options.dtype();
+        torch::Tensor t;
+        if (dtype == torch::kBool) {
+          throw std::runtime_error("Cannot create bool from tensor method yet");
+        } else {
+          std::vector<float> vec;
+          for (size_t i = 0; i < a.size(); i++) {
+            vec.push_back(from_ruby<float>(a[i]));
+          }
+          // hack for requires_grad error
+          if (options.requires_grad()) {
+            t = torch::tensor(vec, options.requires_grad(c10::nullopt));
+            t.set_requires_grad(true);
+          } else {
+            t = torch::tensor(vec, options);
+          }
         }
-        return torch::tensor(vec, options).reshape(size);
+        return t.reshape(size);
       });
 
-  Class rb_cTensor = define_class_under<torch::Tensor>(rb_mTorch, "Tensor")
+  rb_cTensor
     .define_method("cuda?", &torch::Tensor::is_cuda)
-    .define_method("distributed?", &torch::Tensor::is_distributed)
-    .define_method("complex?", &torch::Tensor::is_complex)
-    .define_method("floating_point?", &torch::Tensor::is_floating_point)
-    .define_method("signed?", &torch::Tensor::is_signed)
     .define_method("sparse?", &torch::Tensor::is_sparse)
     .define_method("quantized?", &torch::Tensor::is_quantized)
     .define_method("dim", &torch::Tensor::dim)
+    .define_method("numel", &torch::Tensor::numel)
     .define_method("element_size", &torch::Tensor::element_size)
     .define_method("requires_grad", &torch::Tensor::requires_grad)
-    .define_method("view_as", &torch::Tensor::view_as)
     .define_method(
       "addcmul!",
-      *[](torch::Tensor& self, Scalar value, const torch::Tensor & tensor1, const torch::Tensor & tensor2) {
+      *[](Tensor& self, Scalar value, const Tensor & tensor1, const Tensor & tensor2) {
         return self.addcmul_(tensor1, tensor2, value);
       })
     .define_method(
       "addcdiv!",
-      *[](torch::Tensor& self, Scalar value, const torch::Tensor & tensor1, const torch::Tensor & tensor2) {
+      *[](Tensor& self, Scalar value, const Tensor & tensor1, const Tensor & tensor2) {
        return self.addcdiv_(tensor1, tensor2, value);
       })
-    .define_method(
-      "zero!",
-      *[](torch::Tensor& self) {
-        return self.zero_();
-      })
-    .define_method(
-      "detach!",
-      *[](torch::Tensor& self) {
-        return self.detach_();
-      })
-    .define_method(
-      "_select",
-      *[](torch::Tensor& self, int64_t dim, int64_t index) {
-        return self.select(dim, index);
-      })
-    .define_method(
-      "_slice",
-      *[](torch::Tensor& self, int64_t dim, int64_t start, int64_t end, int64_t step) {
-        return self.slice(dim, start, end, step);
-      })
     .define_method(
       "_requires_grad!",
-      *[](torch::Tensor& self, bool requires_grad) {
+      *[](Tensor& self, bool requires_grad) {
         return self.set_requires_grad(requires_grad);
       })
     .define_method(
       "_backward",
-      *[](torch::Tensor& self) {
-        return self.backward();
-      })
-    .define_method(
-      "_backward_gradient",
-      *[](torch::Tensor& self, const torch::Tensor& gradient) {
-        return self.backward(gradient);
+      *[](Tensor& self, Object gradient) {
+        return gradient.is_nil() ? self.backward() : self.backward(from_ruby<torch::Tensor>(gradient));
       })
     .define_method(
       "grad",
-      *[](torch::Tensor& self) {
+      *[](Tensor& self) {
         return self.grad();
       })
     .define_method(
       "_dtype",
-      *[](torch::Tensor& self) {
+      *[](Tensor& self) {
         return (int) at::typeMetaToScalarType(self.dtype());
       })
     .define_method(
       "_type",
-      *[](torch::Tensor& self, int dtype) {
+      *[](Tensor& self, int dtype) {
         return self.toType((torch::ScalarType) dtype);
       })
     .define_method(
       "_layout",
-      *[](torch::Tensor& self) {
+      *[](Tensor& self) {
         std::stringstream s;
         s << self.layout();
         return s.str();
       })
     .define_method(
       "device",
-      *[](torch::Tensor& self) {
+      *[](Tensor& self) {
         std::stringstream s;
         s << self.device();
         return s.str();
       })
     .define_method(
-      "_view",
-      *[](torch::Tensor& self, IntArrayRef size) {
-        return self.view(size);
-      })
-    .define_method(
-      "resize_as!",
-      *[](torch::Tensor& self, torch::Tensor& other) {
-        return self.resize_as_(other);
-      })
-    .define_method(
-      "fill!",
-      *[](torch::Tensor& self, Scalar value) {
-        return self.fill_(value);
-      })
-    .define_method(
-      "_add!",
-      *[](torch::Tensor& self, torch::Tensor& other) {
-        return self.add_(other);
-      })
-    .define_method(
-      "_add_alpha!",
-      *[](torch::Tensor& self, torch::Tensor& other, Scalar alpha) {
-        return self.add_(other, alpha);
-      })
-    .define_method(
-      "_add_scalar!",
-      *[](torch::Tensor& self, Scalar other) {
-        return self.add_(other);
-      })
-    .define_method(
-      "normal!",
-      *[](torch::Tensor& self, double mean, double std) {
-        return self.normal_(mean, std);
-      })
-    .define_method(
-      "sub!",
-      *[](torch::Tensor& self, torch::Tensor& other) {
-        return self.sub_(other);
-      })
-    .define_method(
-      "_mul!",
-      *[](torch::Tensor& self, torch::Tensor& other) {
-        return self.mul_(other);
-      })
-    .define_method(
-      "_mul_scalar!",
-      *[](torch::Tensor& self, Scalar other) {
-        return self.mul_(other);
-      })
-    .define_method(
-      "div!",
-      *[](torch::Tensor& self, torch::Tensor& other) {
-        return self.div_(other);
-      })
-    .define_method(
-      "sqrt!",
-      *[](torch::Tensor& self) {
-        return self.sqrt_();
-      })
-    .define_method(
-      "unsqueeze!",
-      *[](torch::Tensor& self, int64_t dim) {
-        return self.unsqueeze_(dim);
-      })
-    .define_method(
-      "copy!",
-      *[](torch::Tensor& self, torch::Tensor& src) {
-        return self.copy_(src);
-      })
-    .define_method(
-      "clone",
-      *[](torch::Tensor& self) {
-        return self.clone();
-      })
-    .define_method(
-      "log_softmax",
-      *[](torch::Tensor& self, int64_t dim) {
-        return self.log_softmax(dim);
-      })
-    .define_method(
-      "data",
-      *[](torch::Tensor& self) {
-        return self.data();
-      })
-    .define_method(
-      "_data",
-      *[](torch::Tensor& self) {
+      "_flat_data",
+      *[](Tensor& self) {
         Array a;
         auto dtype = self.dtype();
 
@@ -730,23 +262,18 @@ void Init_ext()
         }
         return a;
       })
-    .define_method(
-      "_size",
-      *[](torch::Tensor& self, int i) {
-        return self.size(i);
-      })
     .define_method(
       "_to",
-      *[](torch::Tensor& self, torch::Device device, int dtype, bool non_blocking, bool copy) {
+      *[](Tensor& self, torch::Device device, int dtype, bool non_blocking, bool copy) {
         return self.to(device, (torch::ScalarType) dtype, non_blocking, copy);
       })
     .define_singleton_method(
       "_make_subclass",
-      *[](torch::Tensor& rd, bool requires_grad) {
+      *[](Tensor& rd, bool requires_grad) {
         auto data = torch::autograd::as_variable_ref(rd).detach();
         data.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(true);
         auto var = data.set_requires_grad(requires_grad);
-        return torch::autograd::Variable(std::move(var));
+        return Parameter(std::move(var));
       });
 
   Class rb_cTensorOptions = define_class_under<torch::TensorOptions>(rb_mTorch, "TensorOptions")
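`_make_subclass` now returns the new C++ `Parameter` class, so parameters become a genuine subtype of `Tensor` on the Ruby side as well. A sketch of the effect, assuming `Torch::NN::Parameter.new` wraps `_make_subclass` as `data/lib/torch/nn/parameter.rb` suggests; the default for `requires_grad` is an assumption mirroring PyTorch:

```ruby
w = Torch::NN::Parameter.new(Torch.randn(5, 3))
w.is_a?(Torch::NN::Parameter) # => true
w.is_a?(Torch::Tensor)        # => true, Parameter subclasses Tensor
w.requires_grad               # parameters track gradients by default
```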
@@ -789,36 +316,84 @@ void Init_ext()
         return self.requires_grad(requires_grad);
       });
 
-  Module rb_mNN = define_module_under(rb_mTorch, "NN");
-
   Module rb_mInit = define_module_under(rb_mNN, "Init")
     .define_singleton_method(
-      "kaiming_uniform!",
-      *[](torch::Tensor& input, double a) {
-        return torch::nn::init::kaiming_uniform_(input, a);
+      "_calculate_gain",
+      *[](NonlinearityType nonlinearity, double param) {
+        return torch::nn::init::calculate_gain(nonlinearity, param);
+      })
+    .define_singleton_method(
+      "_uniform!",
+      *[](Tensor tensor, double low, double high) {
+        return torch::nn::init::uniform_(tensor, low, high);
+      })
+    .define_singleton_method(
+      "_normal!",
+      *[](Tensor tensor, double mean, double std) {
+        return torch::nn::init::normal_(tensor, mean, std);
+      })
+    .define_singleton_method(
+      "_constant!",
+      *[](Tensor tensor, Scalar value) {
+        return torch::nn::init::constant_(tensor, value);
       })
     .define_singleton_method(
-      "normal!",
-      *[](torch::Tensor& input) {
-        return torch::nn::init::normal_(input);
+      "_ones!",
+      *[](Tensor tensor) {
+        return torch::nn::init::ones_(tensor);
       })
     .define_singleton_method(
-      "uniform!",
-      *[](torch::Tensor& input, double to, double from) {
-        return torch::nn::init::uniform_(input, to, from);
+      "_zeros!",
+      *[](Tensor tensor) {
+        return torch::nn::init::zeros_(tensor);
+      })
+    .define_singleton_method(
+      "_eye!",
+      *[](Tensor tensor) {
+        return torch::nn::init::eye_(tensor);
+      })
+    .define_singleton_method(
+      "_dirac!",
+      *[](Tensor tensor) {
+        return torch::nn::init::dirac_(tensor);
+      })
+    .define_singleton_method(
+      "_xavier_uniform!",
+      *[](Tensor tensor, double gain) {
+        return torch::nn::init::xavier_uniform_(tensor, gain);
+      })
+    .define_singleton_method(
+      "_xavier_normal!",
+      *[](Tensor tensor, double gain) {
+        return torch::nn::init::xavier_normal_(tensor, gain);
+      })
+    .define_singleton_method(
+      "_kaiming_uniform!",
+      *[](Tensor tensor, double a, FanModeType mode, NonlinearityType nonlinearity) {
+        return torch::nn::init::kaiming_uniform_(tensor, a, mode, nonlinearity);
+      })
+    .define_singleton_method(
+      "_kaiming_normal!",
+      *[](Tensor tensor, double a, FanModeType mode, NonlinearityType nonlinearity) {
+        return torch::nn::init::kaiming_normal_(tensor, a, mode, nonlinearity);
+      })
+    .define_singleton_method(
+      "_orthogonal!",
+      *[](Tensor tensor, double gain) {
+        return torch::nn::init::orthogonal_(tensor, gain);
+      })
+    .define_singleton_method(
+      "_sparse!",
+      *[](Tensor tensor, double sparsity, double std) {
+        return torch::nn::init::sparse_(tensor, sparsity, std);
       });
 
-  Class rb_cParameter = define_class_under<torch::autograd::Variable, torch::Tensor>(rb_mNN, "Parameter")
-    // TODO return grad or nil to remove need for 2nd function
-    .define_method(
-      "_grad",
-      *[](torch::autograd::Variable& self) {
-        return self.grad();
-      })
+  Class rb_cParameter = define_class_under<Parameter, torch::Tensor>(rb_mNN, "Parameter")
     .define_method(
-      "_grad_defined",
-      *[](torch::autograd::Variable& self) {
-        return self.grad().defined();
+      "grad",
+      *[](Parameter& self) {
+        auto grad = self.grad();
+        return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
       });
 
   Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")
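The `Init` bindings above are now underscore-prefixed and take explicit mode/nonlinearity arguments, with friendlier wrappers in `data/lib/torch/nn/init.rb`; likewise, `Parameter#grad` returns `nil` when no gradient is defined, removing the old `_grad_defined` round trip. A usage sketch where the wrapper name and keyword defaults are assumptions based on the file list, not confirmed signatures:

```ruby
layer = Torch::NN::Linear.new(10, 1)
Torch::NN::Init.kaiming_uniform!(layer.weight, a: Math.sqrt(5))

layer.weight.grad # => nil until backward runs; no separate _grad_defined check
```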