torch-rb 0.1.3 → 0.1.8

Sign up to get free protection for your applications and to get access to all the features.
Files changed (115) hide show
  1. checksums.yaml +4 -4
  2. data/CHANGELOG.md +30 -0
  3. data/README.md +5 -2
  4. data/ext/torch/ext.cpp +130 -555
  5. data/ext/torch/extconf.rb +9 -0
  6. data/ext/torch/templates.cpp +55 -0
  7. data/ext/torch/templates.hpp +244 -0
  8. data/lib/torch.rb +209 -171
  9. data/lib/torch/inspector.rb +23 -19
  10. data/lib/torch/native/dispatcher.rb +48 -0
  11. data/lib/torch/native/function.rb +110 -0
  12. data/lib/torch/native/generator.rb +168 -0
  13. data/lib/torch/native/native_functions.yaml +6491 -0
  14. data/lib/torch/native/parser.rb +134 -0
  15. data/lib/torch/nn/avg_pool1d.rb +18 -0
  16. data/lib/torch/nn/avg_pool2d.rb +19 -0
  17. data/lib/torch/nn/avg_pool3d.rb +19 -0
  18. data/lib/torch/nn/avg_poolnd.rb +9 -0
  19. data/lib/torch/nn/batch_norm.rb +75 -0
  20. data/lib/torch/nn/batch_norm1d.rb +11 -0
  21. data/lib/torch/nn/batch_norm2d.rb +11 -0
  22. data/lib/torch/nn/batch_norm3d.rb +11 -0
  23. data/lib/torch/nn/bce_loss.rb +13 -0
  24. data/lib/torch/nn/bce_with_logits_loss.rb +15 -0
  25. data/lib/torch/nn/bilinear.rb +38 -0
  26. data/lib/torch/nn/constant_pad1d.rb +10 -0
  27. data/lib/torch/nn/constant_pad2d.rb +10 -0
  28. data/lib/torch/nn/constant_pad3d.rb +10 -0
  29. data/lib/torch/nn/constant_padnd.rb +18 -0
  30. data/lib/torch/nn/conv1d.rb +22 -0
  31. data/lib/torch/nn/conv2d.rb +10 -20
  32. data/lib/torch/nn/conv3d.rb +22 -0
  33. data/lib/torch/nn/convnd.rb +3 -3
  34. data/lib/torch/nn/cosine_embedding_loss.rb +14 -0
  35. data/lib/torch/nn/cosine_similarity.rb +15 -0
  36. data/lib/torch/nn/cross_entropy_loss.rb +14 -0
  37. data/lib/torch/nn/ctc_loss.rb +15 -0
  38. data/lib/torch/nn/dropoutnd.rb +2 -2
  39. data/lib/torch/nn/embedding_bag.rb +34 -0
  40. data/lib/torch/nn/fold.rb +20 -0
  41. data/lib/torch/nn/functional.rb +379 -32
  42. data/lib/torch/nn/group_norm.rb +36 -0
  43. data/lib/torch/nn/gru.rb +49 -0
  44. data/lib/torch/nn/hardshrink.rb +18 -0
  45. data/lib/torch/nn/hinge_embedding_loss.rb +14 -0
  46. data/lib/torch/nn/identity.rb +14 -0
  47. data/lib/torch/nn/init.rb +58 -1
  48. data/lib/torch/nn/instance_norm.rb +20 -0
  49. data/lib/torch/nn/instance_norm1d.rb +18 -0
  50. data/lib/torch/nn/instance_norm2d.rb +11 -0
  51. data/lib/torch/nn/instance_norm3d.rb +11 -0
  52. data/lib/torch/nn/kl_div_loss.rb +13 -0
  53. data/lib/torch/nn/l1_loss.rb +13 -0
  54. data/lib/torch/nn/layer_norm.rb +35 -0
  55. data/lib/torch/nn/leaky_relu.rb +20 -0
  56. data/lib/torch/nn/linear.rb +12 -11
  57. data/lib/torch/nn/local_response_norm.rb +21 -0
  58. data/lib/torch/nn/log_sigmoid.rb +9 -0
  59. data/lib/torch/nn/log_softmax.rb +14 -0
  60. data/lib/torch/nn/loss.rb +10 -0
  61. data/lib/torch/nn/lp_pool1d.rb +9 -0
  62. data/lib/torch/nn/lp_pool2d.rb +9 -0
  63. data/lib/torch/nn/lp_poolnd.rb +22 -0
  64. data/lib/torch/nn/lstm.rb +66 -0
  65. data/lib/torch/nn/margin_ranking_loss.rb +14 -0
  66. data/lib/torch/nn/max_pool1d.rb +9 -0
  67. data/lib/torch/nn/max_pool2d.rb +9 -0
  68. data/lib/torch/nn/max_pool3d.rb +9 -0
  69. data/lib/torch/nn/max_poolnd.rb +19 -0
  70. data/lib/torch/nn/max_unpool1d.rb +16 -0
  71. data/lib/torch/nn/max_unpool2d.rb +16 -0
  72. data/lib/torch/nn/max_unpool3d.rb +16 -0
  73. data/lib/torch/nn/max_unpoolnd.rb +9 -0
  74. data/lib/torch/nn/module.rb +186 -35
  75. data/lib/torch/nn/mse_loss.rb +2 -2
  76. data/lib/torch/nn/multi_label_margin_loss.rb +13 -0
  77. data/lib/torch/nn/multi_label_soft_margin_loss.rb +13 -0
  78. data/lib/torch/nn/multi_margin_loss.rb +17 -0
  79. data/lib/torch/nn/nll_loss.rb +14 -0
  80. data/lib/torch/nn/pairwise_distance.rb +16 -0
  81. data/lib/torch/nn/parameter.rb +2 -2
  82. data/lib/torch/nn/poisson_nll_loss.rb +16 -0
  83. data/lib/torch/nn/prelu.rb +19 -0
  84. data/lib/torch/nn/reflection_pad1d.rb +10 -0
  85. data/lib/torch/nn/reflection_pad2d.rb +10 -0
  86. data/lib/torch/nn/reflection_padnd.rb +13 -0
  87. data/lib/torch/nn/relu.rb +8 -3
  88. data/lib/torch/nn/replication_pad1d.rb +10 -0
  89. data/lib/torch/nn/replication_pad2d.rb +10 -0
  90. data/lib/torch/nn/replication_pad3d.rb +10 -0
  91. data/lib/torch/nn/replication_padnd.rb +13 -0
  92. data/lib/torch/nn/rnn.rb +22 -0
  93. data/lib/torch/nn/rnn_base.rb +198 -0
  94. data/lib/torch/nn/sequential.rb +1 -10
  95. data/lib/torch/nn/sigmoid.rb +9 -0
  96. data/lib/torch/nn/smooth_l1_loss.rb +13 -0
  97. data/lib/torch/nn/soft_margin_loss.rb +13 -0
  98. data/lib/torch/nn/softmax.rb +18 -0
  99. data/lib/torch/nn/softmax2d.rb +10 -0
  100. data/lib/torch/nn/softmin.rb +14 -0
  101. data/lib/torch/nn/softplus.rb +19 -0
  102. data/lib/torch/nn/softshrink.rb +18 -0
  103. data/lib/torch/nn/softsign.rb +9 -0
  104. data/lib/torch/nn/tanh.rb +9 -0
  105. data/lib/torch/nn/tanhshrink.rb +9 -0
  106. data/lib/torch/nn/triplet_margin_loss.rb +18 -0
  107. data/lib/torch/nn/unfold.rb +19 -0
  108. data/lib/torch/nn/utils.rb +25 -0
  109. data/lib/torch/nn/weighted_loss.rb +10 -0
  110. data/lib/torch/nn/zero_pad2d.rb +9 -0
  111. data/lib/torch/random.rb +10 -0
  112. data/lib/torch/tensor.rb +51 -44
  113. data/lib/torch/version.rb +1 -1
  114. metadata +98 -6
  115. data/lib/torch/ext.bundle +0 -0
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: e7f715179c9a84dc7399b80d93fd61f2bbb58a0156e6084dc4abb23e1d4a1b52
4
- data.tar.gz: 6928379ae7c92a77ad9dde4f4224ec33c6f8575a9b77585c0147e4f5361021de
3
+ metadata.gz: fca87cb9b6d255287e9fafadf786c113798abbe76b36c82b8271b79cfbf3c2b9
4
+ data.tar.gz: 4813c71f5ad6d078e78da03cf59f8036e9e76258ffb67f538899bba146dcba2a
5
5
  SHA512:
6
- metadata.gz: 9911a9e86d93f1e410776c44fdb3cd9aa06c83d1f0e42fdab8530970bea6520aed7906e96fb8243efd6b957453ebc13678b2b92e4c85b54407030a32c6196e08
7
- data.tar.gz: 0d080f5458a5dcf8fee19ce5e2e342bf6269432de6e78d923036232963ebb80daeea993c0bbf4af2d6da46593ac28a72a8232020a9fcb48acc3276c9e1ebebf3
6
+ metadata.gz: 22c7150e6a7d9132c40c67819beecc6b8c69b268bd227a8e4aa324ef5e2707004691d5b65dcd4ba1ac537bfaf783947da7e5a323417cffcbf7d348768c40b7c6
7
+ data.tar.gz: 8a86c6b68efe6ad85a261d7033b87f040c22b2c670a0238accd6246274caed17b86d7b424441bba80c5ea67ec1bf53b05444dfb0c45ea5b8a52806d0ce19ec1e
@@ -1,3 +1,33 @@
1
+ ## 0.1.8 (2020-01-17)
2
+
3
+ - Added support for libtorch 1.4.0
4
+ - Dropped support for libtorch 1.3.1
5
+
6
+ ## 0.1.7 (2020-01-10)
7
+
8
+ - Fixed installation error with Ruby 2.7
9
+
10
+ ## 0.1.6 (2019-12-09)
11
+
12
+ - Added recurrent layers
13
+ - Added more pooling layers
14
+ - Added normalization layers
15
+
16
+ ## 0.1.5 (2019-12-06)
17
+
18
+ - Added many more functions
19
+ - Added tensor classes - `FloatTensor`, `LongTensor`, etc
20
+ - Improved modules
21
+
22
+ ## 0.1.4 (2019-12-01)
23
+
24
+ - Added distance functions
25
+ - Added more activations
26
+ - Added more linear layers
27
+ - Added more loss functions
28
+ - Added more init methods
29
+ - Added support for tensor assignment
30
+
1
31
  ## 0.1.3 (2019-11-30)
2
32
 
3
33
  - Changed to BSD 3-Clause license to match PyTorch
data/README.md CHANGED
@@ -20,6 +20,8 @@ Add this line to your application’s Gemfile:
20
20
  gem 'torch-rb'
21
21
  ```
22
22
 
23
+ It can take a few minutes to compile the extension.
24
+
23
25
  ## Getting Started
24
26
 
25
27
  This library follows the [PyTorch API](https://pytorch.org/docs/stable/torch.html). There are a few changes to make it more Ruby-like:
@@ -28,7 +30,7 @@ This library follows the [PyTorch API](https://pytorch.org/docs/stable/torch.htm
28
30
  - Methods that return booleans use `?` instead of `is_` (`tensor?` instead of `is_tensor`)
29
31
  - Numo is used instead of NumPy (`x.numo` instead of `x.numpy()`)
30
32
 
31
- Many methods and options are missing at the moment. PRs welcome!
33
+ Some methods and options are missing at the moment. PRs welcome!
32
34
 
33
35
  ## Tutorial
34
36
 
@@ -365,8 +367,9 @@ Here’s a list of functions to create tensors (descriptions from the [C++ docs]
365
367
 
366
368
  Here are a few full examples:
367
369
 
368
- - [Image classification with MNIST](examples/mnist)
370
+ - [Image classification with MNIST](examples/mnist) ([日本語版](https://qiita.com/kojix2/items/c19c36dc1bf73ea93409))
369
371
  - [Collaborative filtering with MovieLens](examples/movielens)
372
+ - [Sequence models and word embeddings](examples/nlp)
370
373
 
371
374
  ## LibTorch Installation
372
375
 
@@ -6,137 +6,35 @@
6
6
  #include <rice/Class.hpp>
7
7
  #include <rice/Constructor.hpp>
8
8
 
9
- using namespace Rice;
10
-
11
- template<>
12
- inline
13
- long long from_ruby<long long>(Object x)
14
- {
15
- return NUM2LL(x);
16
- }
17
-
18
- template<>
19
- inline
20
- Object to_ruby<long long>(long long const & x)
21
- {
22
- return LL2NUM(x);
23
- }
24
-
25
- template<>
26
- inline
27
- unsigned long long from_ruby<unsigned long long>(Object x)
28
- {
29
- return NUM2ULL(x);
30
- }
31
-
32
- template<>
33
- inline
34
- Object to_ruby<unsigned long long>(unsigned long long const & x)
35
- {
36
- return ULL2NUM(x);
37
- }
38
-
39
- template<>
40
- inline
41
- short from_ruby<short>(Object x)
42
- {
43
- return NUM2SHORT(x);
44
- }
45
-
46
- template<>
47
- inline
48
- Object to_ruby<short>(short const & x)
49
- {
50
- return INT2NUM(x);
51
- }
52
-
53
- template<>
54
- inline
55
- unsigned short from_ruby<unsigned short>(Object x)
56
- {
57
- return NUM2USHORT(x);
58
- }
9
+ #include "templates.hpp"
59
10
 
60
- template<>
61
- inline
62
- Object to_ruby<unsigned short>(unsigned short const & x)
63
- {
64
- return UINT2NUM(x);
65
- }
11
+ // generated with:
12
+ // rake generate:functions
13
+ #include "torch_functions.hpp"
14
+ #include "tensor_functions.hpp"
15
+ #include "nn_functions.hpp"
66
16
 
67
- // need to wrap torch::IntArrayRef() since
68
- // it doesn't own underlying data
69
- class IntArrayRef {
70
- std::vector<int64_t> vec;
71
- public:
72
- IntArrayRef(Object o) {
73
- Array a = Array(o);
74
- for (size_t i = 0; i < a.size(); i++) {
75
- vec.push_back(from_ruby<int64_t>(a[i]));
76
- }
77
- }
78
- operator torch::IntArrayRef() {
79
- return torch::IntArrayRef(vec);
80
- }
81
- };
82
-
83
- template<>
84
- inline
85
- IntArrayRef from_ruby<IntArrayRef>(Object x)
86
- {
87
- return IntArrayRef(x);
88
- }
17
+ using namespace Rice;
89
18
 
90
- // for now
91
- class Scalar {
92
- torch::Scalar value;
19
+ // need to make a distinction between parameters and tensors
20
+ class Parameter: public torch::autograd::Variable {
93
21
  public:
94
- Scalar(Object o) {
95
- // TODO cast based on Ruby type
96
- if (o.rb_type() == T_FIXNUM) {
97
- value = torch::Scalar(from_ruby<int64_t>(o));
98
- } else {
99
- value = torch::Scalar(from_ruby<float>(o));
100
- }
101
- }
102
- operator torch::Scalar() {
103
- return value;
104
- }
22
+ Parameter(Tensor&& t) : torch::autograd::Variable(t) { }
105
23
  };
106
24
 
107
- template<>
108
- inline
109
- Scalar from_ruby<Scalar>(Object x)
25
+ extern "C"
26
+ void Init_ext()
110
27
  {
111
- return Scalar(x);
112
- }
28
+ Module rb_mTorch = define_module("Torch");
29
+ add_torch_functions(rb_mTorch);
113
30
 
114
- class TensorList {
115
- std::vector<torch::Tensor> vec;
116
- public:
117
- TensorList(Object o) {
118
- Array a = Array(o);
119
- for (size_t i = 0; i < a.size(); i++) {
120
- vec.push_back(from_ruby<torch::Tensor>(a[i]));
121
- }
122
- }
123
- operator torch::TensorList() {
124
- return torch::TensorList(vec);
125
- }
126
- };
31
+ Class rb_cTensor = define_class_under<torch::Tensor>(rb_mTorch, "Tensor");
32
+ add_tensor_functions(rb_cTensor);
127
33
 
128
- template<>
129
- inline
130
- TensorList from_ruby<TensorList>(Object x)
131
- {
132
- return TensorList(x);
133
- }
34
+ Module rb_mNN = define_module_under(rb_mTorch, "NN");
35
+ add_nn_functions(rb_mNN);
134
36
 
135
- extern "C"
136
- void Init_ext()
137
- {
138
- Module rb_mTorch = define_module("Torch")
139
- .define_singleton_method(
37
+ rb_mTorch.define_singleton_method(
140
38
  "grad_enabled?",
141
39
  *[]() {
142
40
  return torch::GradMode::is_enabled();
@@ -146,11 +44,6 @@ void Init_ext()
146
44
  *[](bool enabled) {
147
45
  torch::GradMode::set_enabled(enabled);
148
46
  })
149
- .define_singleton_method(
150
- "floating_point?",
151
- *[](torch::Tensor& input) {
152
- return torch::is_floating_point(input);
153
- })
154
47
  .define_singleton_method(
155
48
  "manual_seed",
156
49
  *[](uint64_t seed) {
@@ -219,277 +112,17 @@ void Init_ext()
219
112
  })
220
113
  // begin operations
221
114
  .define_singleton_method(
222
- "_mean",
223
- *[](torch::Tensor& input) {
224
- return torch::mean(input);
225
- })
226
- .define_singleton_method(
227
- "_mean_dim",
228
- *[](torch::Tensor& input, int64_t dim, bool keepdim) {
229
- return torch::mean(input, dim, keepdim);
230
- })
231
- .define_singleton_method(
232
- "_sum",
233
- *[](torch::Tensor& input) {
234
- return torch::sum(input);
235
- })
236
- .define_singleton_method(
237
- "_sum_dim",
238
- *[](torch::Tensor& input, int64_t dim, bool keepdim) {
239
- return torch::sum(input, dim, keepdim);
240
- })
241
- .define_singleton_method(
242
- "_argmax",
243
- *[](torch::Tensor& input) {
244
- return torch::argmax(input);
245
- })
246
- .define_singleton_method(
247
- "_argmax_dim",
248
- *[](torch::Tensor& input, int64_t dim, bool keepdim) {
249
- return torch::argmax(input, dim, keepdim);
250
- })
251
- .define_singleton_method(
252
- "_cat",
253
- *[](TensorList tensors, int64_t dim) {
254
- return torch::cat(tensors, dim);
255
- })
256
- .define_singleton_method(
257
- "_norm",
258
- *[](torch::Tensor& input) {
259
- return torch::norm(input);
260
- })
261
- .define_singleton_method(
262
- "_min",
263
- *[](torch::Tensor& input) {
264
- return torch::min(input);
265
- })
266
- .define_singleton_method(
267
- "_max",
268
- *[](torch::Tensor& input) {
269
- return torch::max(input);
270
- })
271
- .define_singleton_method(
272
- "_max_out",
273
- *[](torch::Tensor &max, torch::Tensor &max_indices, const torch::Tensor &input, int64_t dim, bool keepdim) {
274
- // TODO add return value
275
- torch::_max_out(max, max_indices, input, dim, keepdim);
276
- })
277
- .define_singleton_method(
278
- "_sqrt",
279
- *[](torch::Tensor& input) {
280
- return torch::sqrt(input);
281
- })
282
- .define_singleton_method(
283
- "_exp",
284
- *[](torch::Tensor& input) {
285
- return torch::exp(input);
286
- })
287
- .define_singleton_method(
288
- "_log",
289
- *[](torch::Tensor& input) {
290
- return torch::log(input);
291
- })
292
- .define_singleton_method(
293
- "_sign",
294
- *[](torch::Tensor& input) {
295
- return torch::sign(input);
296
- })
297
- .define_singleton_method(
298
- "_unsqueeze",
299
- *[](torch::Tensor& input, int64_t dim) {
300
- return torch::unsqueeze(input, dim);
301
- })
302
- .define_singleton_method(
303
- "_dot",
304
- *[](torch::Tensor& input, torch::Tensor& tensor) {
305
- return torch::dot(input, tensor);
306
- })
307
- .define_singleton_method(
308
- "_matmul",
309
- *[](torch::Tensor& input, torch::Tensor& other) {
310
- return torch::matmul(input, other);
311
- })
312
- .define_singleton_method(
313
- "_eq",
314
- *[](torch::Tensor& input, torch::Tensor& other) {
315
- return torch::eq(input, other);
316
- })
317
- .define_singleton_method(
318
- "_gt",
319
- // TODO support tensors
320
- *[](torch::Tensor& input, Scalar other) {
321
- return torch::gt(input, other);
322
- })
323
- .define_singleton_method(
324
- "_lt",
325
- // TODO support tensors
326
- *[](torch::Tensor& input, Scalar other) {
327
- return torch::lt(input, other);
328
- })
329
- .define_singleton_method(
330
- "_add",
331
- *[](torch::Tensor& input, torch::Tensor& other) {
332
- return torch::add(input, other);
333
- })
334
- .define_singleton_method(
335
- "_add_scalar",
336
- *[](torch::Tensor& input, Scalar other) {
337
- return torch::add(input, other);
338
- })
339
- .define_singleton_method(
340
- "_add_out",
341
- *[](torch::Tensor& out, torch::Tensor& input, torch::Tensor& other) {
342
- return torch::add_out(out, input, other);
343
- })
344
- .define_singleton_method(
345
- "_sub",
346
- *[](torch::Tensor& input, torch::Tensor& other) {
347
- return torch::sub(input, other);
348
- })
349
- .define_singleton_method(
350
- "_sub_scalar",
351
- *[](torch::Tensor& input, Scalar other) {
352
- return torch::sub(input, other);
353
- })
354
- .define_singleton_method(
355
- "_mul",
356
- *[](torch::Tensor& input, torch::Tensor& other) {
357
- return torch::mul(input, other);
358
- })
359
- .define_singleton_method(
360
- "_mul_scalar",
361
- *[](torch::Tensor& input, Scalar other) {
362
- return torch::mul(input, other);
363
- })
364
- .define_singleton_method(
365
- "_div",
366
- *[](torch::Tensor& input, torch::Tensor& other) {
367
- return torch::div(input, other);
368
- })
369
- .define_singleton_method(
370
- "_div_scalar",
371
- *[](torch::Tensor& input, Scalar other) {
372
- return torch::div(input, other);
373
- })
374
- .define_singleton_method(
375
- "_remainder",
376
- *[](torch::Tensor& input, torch::Tensor& other) {
377
- return torch::remainder(input, other);
378
- })
379
- .define_singleton_method(
380
- "_remainder_scalar",
381
- *[](torch::Tensor& input, Scalar other) {
382
- return torch::remainder(input, other);
383
- })
384
- .define_singleton_method(
385
- "_pow",
386
- *[](torch::Tensor& input, Scalar exponent) {
387
- return torch::pow(input, exponent);
388
- })
389
- .define_singleton_method(
390
- "_abs",
391
- *[](torch::Tensor& input) {
392
- return torch::abs(input);
393
- })
394
- .define_singleton_method(
395
- "_neg",
396
- *[](torch::Tensor& input) {
397
- return torch::neg(input);
115
+ "_save",
116
+ *[](const Tensor &value) {
117
+ auto v = torch::pickle_save(value);
118
+ std::string str(v.begin(), v.end());
119
+ return str;
398
120
  })
399
121
  .define_singleton_method(
400
- "_reshape",
401
- *[](torch::Tensor& input, IntArrayRef shape) {
402
- return torch::reshape(input, shape);
122
+ "_binary_cross_entropy_with_logits",
123
+ *[](const Tensor &input, const Tensor &target, OptionalTensor weight, OptionalTensor pos_weight, MyReduction reduction) {
124
+ return torch::binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction);
403
125
  })
404
- .define_singleton_method(
405
- "_flatten",
406
- *[](torch::Tensor& input, int64_t start_dim, int64_t end_dim) {
407
- return torch::flatten(input, start_dim, end_dim);
408
- })
409
- .define_singleton_method(
410
- "relu",
411
- *[](torch::Tensor& input) {
412
- return torch::relu(input);
413
- })
414
- .define_singleton_method(
415
- "conv2d",
416
- *[](torch::Tensor& input, torch::Tensor& weight, torch::Tensor& bias, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, int64_t groups) {
417
- return torch::conv2d(input, weight, bias, stride, padding, dilation, groups);
418
- })
419
- .define_singleton_method(
420
- "linear",
421
- *[](torch::Tensor& input, torch::Tensor& weight, torch::Tensor& bias) {
422
- return torch::linear(input, weight, bias);
423
- })
424
- .define_singleton_method(
425
- "max_pool2d",
426
- *[](torch::Tensor& input, IntArrayRef kernel_size) {
427
- return torch::max_pool2d(input, kernel_size);
428
- })
429
- .define_singleton_method(
430
- "avg_pool2d",
431
- *[](torch::Tensor& input, IntArrayRef kernel_size) {
432
- return torch::avg_pool2d(input, kernel_size);
433
- })
434
- .define_singleton_method(
435
- "_dropout",
436
- *[](torch::Tensor& input, float p, bool train) {
437
- return torch::dropout(input, p, train);
438
- })
439
- .define_singleton_method(
440
- "_dropout!",
441
- *[](torch::Tensor& input, float p, bool train) {
442
- return torch::dropout_(input, p, train);
443
- })
444
- .define_singleton_method(
445
- "_feature_dropout",
446
- *[](torch::Tensor& input, float p, bool train) {
447
- return torch::feature_dropout(input, p, train);
448
- })
449
- .define_singleton_method(
450
- "_feature_dropout!",
451
- *[](torch::Tensor& input, float p, bool train) {
452
- return torch::feature_dropout_(input, p, train);
453
- })
454
- .define_singleton_method(
455
- "_alpha_dropout",
456
- *[](torch::Tensor& input, float p, bool train) {
457
- return torch::alpha_dropout(input, p, train);
458
- })
459
- .define_singleton_method(
460
- "_alpha_dropout!",
461
- *[](torch::Tensor& input, float p, bool train) {
462
- return torch::alpha_dropout_(input, p, train);
463
- })
464
- .define_singleton_method(
465
- "_feature_alpha_dropout",
466
- *[](torch::Tensor& input, float p, bool train) {
467
- return torch::feature_alpha_dropout(input, p, train);
468
- })
469
- .define_singleton_method(
470
- "_feature_alpha_dropout!",
471
- *[](torch::Tensor& input, float p, bool train) {
472
- return torch::feature_alpha_dropout_(input, p, train);
473
- })
474
- .define_singleton_method(
475
- "_embedding",
476
- // weight and indices are swapped from Python interface
477
- *[](const torch::Tensor &indices, const torch::Tensor &weight, int64_t padding_idx, bool scale_grad_by_freq, bool sparse) {
478
- return torch::embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse);
479
- })
480
- .define_singleton_method(
481
- "mse_loss",
482
- *[](torch::Tensor& input, torch::Tensor& target, std::string reduction) {
483
- auto red = reduction == "mean" ? Reduction::Mean : Reduction::Sum;
484
- return torch::mse_loss(input, target, red);
485
- })
486
- .define_singleton_method(
487
- "nll_loss",
488
- *[](torch::Tensor& input, torch::Tensor& target, std::string reduction) {
489
- auto red = reduction == "mean" ? Reduction::Mean : Reduction::Sum;
490
- return torch::nll_loss(input, target, {}, red);
491
- })
492
- .define_singleton_method("numel", &torch::numel)
493
126
  .define_singleton_method(
494
127
  "_from_blob",
495
128
  *[](String s, IntArrayRef size, const torch::TensorOptions &options) {
@@ -500,187 +133,86 @@ void Init_ext()
500
133
  "_tensor",
501
134
  *[](Object o, IntArrayRef size, const torch::TensorOptions &options) {
502
135
  Array a = Array(o);
503
- std::vector<float> vec;
504
- for (size_t i = 0; i < a.size(); i++) {
505
- vec.push_back(from_ruby<float>(a[i]));
136
+ auto dtype = options.dtype();
137
+ torch::Tensor t;
138
+ if (dtype == torch::kBool) {
139
+ throw std::runtime_error("Cannot create bool from tensor method yet");
140
+ } else {
141
+ std::vector<float> vec;
142
+ for (size_t i = 0; i < a.size(); i++) {
143
+ vec.push_back(from_ruby<float>(a[i]));
144
+ }
145
+ // hack for requires_grad error
146
+ if (options.requires_grad()) {
147
+ t = torch::tensor(vec, options.requires_grad(c10::nullopt));
148
+ t.set_requires_grad(true);
149
+ } else {
150
+ t = torch::tensor(vec, options);
151
+ }
506
152
  }
507
- return torch::tensor(vec, options).reshape(size);
153
+ return t.reshape(size);
508
154
  });
509
155
 
510
- Class rb_cTensor = define_class_under<torch::Tensor>(rb_mTorch, "Tensor")
156
+ rb_cTensor
511
157
  .define_method("cuda?", &torch::Tensor::is_cuda)
512
- .define_method("distributed?", &torch::Tensor::is_distributed)
513
- .define_method("complex?", &torch::Tensor::is_complex)
514
- .define_method("floating_point?", &torch::Tensor::is_floating_point)
515
- .define_method("signed?", &torch::Tensor::is_signed)
516
158
  .define_method("sparse?", &torch::Tensor::is_sparse)
517
159
  .define_method("quantized?", &torch::Tensor::is_quantized)
518
160
  .define_method("dim", &torch::Tensor::dim)
161
+ .define_method("numel", &torch::Tensor::numel)
519
162
  .define_method("element_size", &torch::Tensor::element_size)
520
163
  .define_method("requires_grad", &torch::Tensor::requires_grad)
521
- .define_method("view_as", &torch::Tensor::view_as)
522
164
  .define_method(
523
165
  "addcmul!",
524
- *[](torch::Tensor& self, Scalar value, const torch::Tensor & tensor1, const torch::Tensor & tensor2) {
166
+ *[](Tensor& self, Scalar value, const Tensor & tensor1, const Tensor & tensor2) {
525
167
  return self.addcmul_(tensor1, tensor2, value);
526
168
  })
527
169
  .define_method(
528
170
  "addcdiv!",
529
- *[](torch::Tensor& self, Scalar value, const torch::Tensor & tensor1, const torch::Tensor & tensor2) {
171
+ *[](Tensor& self, Scalar value, const Tensor & tensor1, const Tensor & tensor2) {
530
172
  return self.addcdiv_(tensor1, tensor2, value);
531
173
  })
532
- .define_method(
533
- "zero!",
534
- *[](torch::Tensor& self) {
535
- return self.zero_();
536
- })
537
- .define_method(
538
- "detach!",
539
- *[](torch::Tensor& self) {
540
- return self.detach_();
541
- })
542
- .define_method(
543
- "_select",
544
- *[](torch::Tensor& self, int64_t dim, int64_t index) {
545
- return self.select(dim, index);
546
- })
547
- .define_method(
548
- "_slice",
549
- *[](torch::Tensor& self, int64_t dim, int64_t start, int64_t end, int64_t step) {
550
- return self.slice(dim, start, end, step);
551
- })
552
174
  .define_method(
553
175
  "_requires_grad!",
554
- *[](torch::Tensor& self, bool requires_grad) {
176
+ *[](Tensor& self, bool requires_grad) {
555
177
  return self.set_requires_grad(requires_grad);
556
178
  })
557
179
  .define_method(
558
180
  "_backward",
559
- *[](torch::Tensor& self) {
560
- return self.backward();
561
- })
562
- .define_method(
563
- "_backward_gradient",
564
- *[](torch::Tensor& self, const torch::Tensor& gradient) {
565
- return self.backward(gradient);
181
+ *[](Tensor& self, Object gradient) {
182
+ return gradient.is_nil() ? self.backward() : self.backward(from_ruby<torch::Tensor>(gradient));
566
183
  })
567
184
  .define_method(
568
185
  "grad",
569
- *[](torch::Tensor& self) {
186
+ *[](Tensor& self) {
570
187
  return self.grad();
571
188
  })
572
189
  .define_method(
573
190
  "_dtype",
574
- *[](torch::Tensor& self) {
191
+ *[](Tensor& self) {
575
192
  return (int) at::typeMetaToScalarType(self.dtype());
576
193
  })
577
194
  .define_method(
578
195
  "_type",
579
- *[](torch::Tensor& self, int dtype) {
196
+ *[](Tensor& self, int dtype) {
580
197
  return self.toType((torch::ScalarType) dtype);
581
198
  })
582
199
  .define_method(
583
200
  "_layout",
584
- *[](torch::Tensor& self) {
201
+ *[](Tensor& self) {
585
202
  std::stringstream s;
586
203
  s << self.layout();
587
204
  return s.str();
588
205
  })
589
206
  .define_method(
590
207
  "device",
591
- *[](torch::Tensor& self) {
208
+ *[](Tensor& self) {
592
209
  std::stringstream s;
593
210
  s << self.device();
594
211
  return s.str();
595
212
  })
596
213
  .define_method(
597
- "_view",
598
- *[](torch::Tensor& self, IntArrayRef size) {
599
- return self.view(size);
600
- })
601
- .define_method(
602
- "resize_as!",
603
- *[](torch::Tensor& self, torch::Tensor& other) {
604
- return self.resize_as_(other);
605
- })
606
- .define_method(
607
- "fill!",
608
- *[](torch::Tensor& self, Scalar value) {
609
- return self.fill_(value);
610
- })
611
- .define_method(
612
- "_add!",
613
- *[](torch::Tensor& self, torch::Tensor& other) {
614
- return self.add_(other);
615
- })
616
- .define_method(
617
- "_add_alpha!",
618
- *[](torch::Tensor& self, torch::Tensor& other, Scalar alpha) {
619
- return self.add_(other, alpha);
620
- })
621
- .define_method(
622
- "_add_scalar!",
623
- *[](torch::Tensor& self, Scalar other) {
624
- return self.add_(other);
625
- })
626
- .define_method(
627
- "normal!",
628
- *[](torch::Tensor& self, double mean, double std) {
629
- return self.normal_(mean, std);
630
- })
631
- .define_method(
632
- "sub!",
633
- *[](torch::Tensor& self, torch::Tensor& other) {
634
- return self.sub_(other);
635
- })
636
- .define_method(
637
- "_mul!",
638
- *[](torch::Tensor& self, torch::Tensor& other) {
639
- return self.mul_(other);
640
- })
641
- .define_method(
642
- "_mul_scalar!",
643
- *[](torch::Tensor& self, Scalar other) {
644
- return self.mul_(other);
645
- })
646
- .define_method(
647
- "div!",
648
- *[](torch::Tensor& self, torch::Tensor& other) {
649
- return self.div_(other);
650
- })
651
- .define_method(
652
- "sqrt!",
653
- *[](torch::Tensor& self) {
654
- return self.sqrt_();
655
- })
656
- .define_method(
657
- "unsqueeze!",
658
- *[](torch::Tensor& self, int64_t dim) {
659
- return self.unsqueeze_(dim);
660
- })
661
- .define_method(
662
- "copy!",
663
- *[](torch::Tensor& self, torch::Tensor& src) {
664
- return self.copy_(src);
665
- })
666
- .define_method(
667
- "clone",
668
- *[](torch::Tensor& self) {
669
- return self.clone();
670
- })
671
- .define_method(
672
- "log_softmax",
673
- *[](torch::Tensor& self, int64_t dim) {
674
- return self.log_softmax(dim);
675
- })
676
- .define_method(
677
- "data",
678
- *[](torch::Tensor& self) {
679
- return self.data();
680
- })
681
- .define_method(
682
- "_data",
683
- *[](torch::Tensor& self) {
214
+ "_flat_data",
215
+ *[](Tensor& self) {
684
216
  Array a;
685
217
  auto dtype = self.dtype();
686
218
 
@@ -730,23 +262,18 @@ void Init_ext()
730
262
  }
731
263
  return a;
732
264
  })
733
- .define_method(
734
- "_size",
735
- *[](torch::Tensor& self, int i) {
736
- return self.size(i);
737
- })
738
265
  .define_method(
739
266
  "_to",
740
- *[](torch::Tensor& self, torch::Device device, int dtype, bool non_blocking, bool copy) {
267
+ *[](Tensor& self, torch::Device device, int dtype, bool non_blocking, bool copy) {
741
268
  return self.to(device, (torch::ScalarType) dtype, non_blocking, copy);
742
269
  })
743
270
  .define_singleton_method(
744
271
  "_make_subclass",
745
- *[](torch::Tensor& rd, bool requires_grad) {
272
+ *[](Tensor& rd, bool requires_grad) {
746
273
  auto data = torch::autograd::as_variable_ref(rd).detach();
747
274
  data.unsafeGetTensorImpl()->set_allow_tensor_metadata_change(true);
748
275
  auto var = data.set_requires_grad(requires_grad);
749
- return torch::autograd::Variable(std::move(var));
276
+ return Parameter(std::move(var));
750
277
  });
751
278
 
752
279
  Class rb_cTensorOptions = define_class_under<torch::TensorOptions>(rb_mTorch, "TensorOptions")
@@ -789,36 +316,84 @@ void Init_ext()
789
316
  return self.requires_grad(requires_grad);
790
317
  });
791
318
 
792
- Module rb_mNN = define_module_under(rb_mTorch, "NN");
793
-
794
319
  Module rb_mInit = define_module_under(rb_mNN, "Init")
795
320
  .define_singleton_method(
796
- "kaiming_uniform!",
797
- *[](torch::Tensor& input, double a) {
798
- return torch::nn::init::kaiming_uniform_(input, a);
321
+ "_calculate_gain",
322
+ *[](NonlinearityType nonlinearity, double param) {
323
+ return torch::nn::init::calculate_gain(nonlinearity, param);
324
+ })
325
+ .define_singleton_method(
326
+ "_uniform!",
327
+ *[](Tensor tensor, double low, double high) {
328
+ return torch::nn::init::uniform_(tensor, low, high);
329
+ })
330
+ .define_singleton_method(
331
+ "_normal!",
332
+ *[](Tensor tensor, double mean, double std) {
333
+ return torch::nn::init::normal_(tensor, mean, std);
334
+ })
335
+ .define_singleton_method(
336
+ "_constant!",
337
+ *[](Tensor tensor, Scalar value) {
338
+ return torch::nn::init::constant_(tensor, value);
799
339
  })
800
340
  .define_singleton_method(
801
- "normal!",
802
- *[](torch::Tensor& input) {
803
- return torch::nn::init::normal_(input);
341
+ "_ones!",
342
+ *[](Tensor tensor) {
343
+ return torch::nn::init::ones_(tensor);
804
344
  })
805
345
  .define_singleton_method(
806
- "uniform!",
807
- *[](torch::Tensor& input, double to, double from) {
808
- return torch::nn::init::uniform_(input, to, from);
346
+ "_zeros!",
347
+ *[](Tensor tensor) {
348
+ return torch::nn::init::zeros_(tensor);
349
+ })
350
+ .define_singleton_method(
351
+ "_eye!",
352
+ *[](Tensor tensor) {
353
+ return torch::nn::init::eye_(tensor);
354
+ })
355
+ .define_singleton_method(
356
+ "_dirac!",
357
+ *[](Tensor tensor) {
358
+ return torch::nn::init::dirac_(tensor);
359
+ })
360
+ .define_singleton_method(
361
+ "_xavier_uniform!",
362
+ *[](Tensor tensor, double gain) {
363
+ return torch::nn::init::xavier_uniform_(tensor, gain);
364
+ })
365
+ .define_singleton_method(
366
+ "_xavier_normal!",
367
+ *[](Tensor tensor, double gain) {
368
+ return torch::nn::init::xavier_normal_(tensor, gain);
369
+ })
370
+ .define_singleton_method(
371
+ "_kaiming_uniform!",
372
+ *[](Tensor tensor, double a, FanModeType mode, NonlinearityType nonlinearity) {
373
+ return torch::nn::init::kaiming_uniform_(tensor, a, mode, nonlinearity);
374
+ })
375
+ .define_singleton_method(
376
+ "_kaiming_normal!",
377
+ *[](Tensor tensor, double a, FanModeType mode, NonlinearityType nonlinearity) {
378
+ return torch::nn::init::kaiming_normal_(tensor, a, mode, nonlinearity);
379
+ })
380
+ .define_singleton_method(
381
+ "_orthogonal!",
382
+ *[](Tensor tensor, double gain) {
383
+ return torch::nn::init::orthogonal_(tensor, gain);
384
+ })
385
+ .define_singleton_method(
386
+ "_sparse!",
387
+ *[](Tensor tensor, double sparsity, double std) {
388
+ return torch::nn::init::sparse_(tensor, sparsity, std);
809
389
  });
810
390
 
811
- Class rb_cParameter = define_class_under<torch::autograd::Variable, torch::Tensor>(rb_mNN, "Parameter")
812
- // TODO return grad or nil to remove need for 2nd function
813
- .define_method(
814
- "_grad",
815
- *[](torch::autograd::Variable& self) {
816
- return self.grad();
817
- })
391
+ Class rb_cParameter = define_class_under<Parameter, torch::Tensor>(rb_mNN, "Parameter")
818
392
  .define_method(
819
- "_grad_defined",
820
- *[](torch::autograd::Variable& self) {
821
- return self.grad().defined();
393
+ "grad",
394
+ *[](Parameter& self) {
395
+ auto grad = self.grad();
396
+ return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
822
397
  });
823
398
 
824
399
  Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")