torch-rb 0.1.2 → 0.1.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 30089078de4039df111087e5c27e0cb10d6f36398c0e8d5cc774e9b642a8e133
4
- data.tar.gz: 89eb9e183b395dd67cd9cf228749cf26402993bb561de973f1ba7438bc372b04
3
+ metadata.gz: e7f715179c9a84dc7399b80d93fd61f2bbb58a0156e6084dc4abb23e1d4a1b52
4
+ data.tar.gz: 6928379ae7c92a77ad9dde4f4224ec33c6f8575a9b77585c0147e4f5361021de
5
5
  SHA512:
6
- metadata.gz: 027a069b00ac1329c007ddaf471a21b57a82a823ad974a937f832d17720b8e26474c64c79e9a29ec71bac433abb3d74d6a7cf407f0a983bb3c0cafb5b5c7532f
7
- data.tar.gz: 6d7ef10b53db0df39eda13d07aa9b52b4afac0965674919b5cc517e7b53f59a9010cb647e50d62bc06154f7d8f3ef632d5897e4f7774372d7ab1b44b2cb6ca82
6
+ metadata.gz: 9911a9e86d93f1e410776c44fdb3cd9aa06c83d1f0e42fdab8530970bea6520aed7906e96fb8243efd6b957453ebc13678b2b92e4c85b54407030a32c6196e08
7
+ data.tar.gz: 0d080f5458a5dcf8fee19ce5e2e342bf6269432de6e78d923036232963ebb80daeea993c0bbf4af2d6da46593ac28a72a8232020a9fcb48acc3276c9e1ebebf3
data/CHANGELOG.md CHANGED
@@ -1,3 +1,13 @@
1
+ ## 0.1.3 (2019-11-30)
2
+
3
+ - Changed to BSD 3-Clause license to match PyTorch
4
+ - Added many optimizers
5
+ - Added `StepLR` learning rate scheduler
6
+ - Added dropout
7
+ - Added embedding
8
+ - Added support for `bool` type
9
+ - Improved performance of `from_numo`
10
+
1
11
  ## 0.1.2 (2019-11-27)
2
12
 
3
13
  - Added SGD optimizer
data/LICENSE.txt CHANGED
@@ -1,22 +1,46 @@
1
- Copyright (c) 2019 Andrew Kane
2
-
3
- MIT License
4
-
5
- Permission is hereby granted, free of charge, to any person obtaining
6
- a copy of this software and associated documentation files (the
7
- "Software"), to deal in the Software without restriction, including
8
- without limitation the rights to use, copy, modify, merge, publish,
9
- distribute, sublicense, and/or sell copies of the Software, and to
10
- permit persons to whom the Software is furnished to do so, subject to
11
- the following conditions:
12
-
13
- The above copyright notice and this permission notice shall be
14
- included in all copies or substantial portions of the Software.
15
-
16
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
17
- EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
18
- MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
19
- NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
20
- LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
21
- OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
22
- WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
1
+ BSD 3-Clause License
2
+
3
+ From Torch-rb:
4
+
5
+ Copyright (c) 2019- Andrew Kane
6
+
7
+ From PyTorch (for ported code):
8
+
9
+ Copyright (c) 2016- Facebook, Inc (Adam Paszke)
10
+ Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
11
+ Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
12
+ Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
13
+ Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
14
+ Copyright (c) 2011-2013 NYU (Clement Farabet)
15
+ Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
16
+ Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
17
+ Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
18
+
19
+ All rights reserved.
20
+
21
+ Redistribution and use in source and binary forms, with or without
22
+ modification, are permitted provided that the following conditions are met:
23
+
24
+ 1. Redistributions of source code must retain the above copyright
25
+ notice, this list of conditions and the following disclaimer.
26
+
27
+ 2. Redistributions in binary form must reproduce the above copyright
28
+ notice, this list of conditions and the following disclaimer in the
29
+ documentation and/or other materials provided with the distribution.
30
+
31
+ 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
32
+ and IDIAP Research Institute nor the names of its contributors may be
33
+ used to endorse or promote products derived from this software without
34
+ specific prior written permission.
35
+
36
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
37
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
38
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
39
+ ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
40
+ LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
41
+ CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
42
+ SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
43
+ INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
44
+ CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
45
+ ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
46
+ POSSIBILITY OF SUCH DAMAGE.
data/README.md CHANGED
@@ -30,7 +30,9 @@ This library follows the [PyTorch API](https://pytorch.org/docs/stable/torch.htm
30
30
 
31
31
  Many methods and options are missing at the moment. PRs welcome!
32
32
 
33
- Some examples below are from [Deep Learning with PyTorch: A 60 Minutes Blitz](https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html).
33
+ ## Tutorial
34
+
35
+ Some examples below are from [Deep Learning with PyTorch: A 60 Minute Blitz](https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html).
34
36
 
35
37
  ### Tensors
36
38
 
@@ -145,7 +147,7 @@ Convert a Numo array to a tensor
145
147
 
146
148
  ```ruby
147
149
  b = Numo::NArray.cast([1, 2, 3])
148
- Torch.from_numpy(b)
150
+ Torch.from_numo(b)
149
151
  ```
150
152
 
151
153
  ### Autograd
@@ -180,10 +182,10 @@ Stop autograd from tracking history
180
182
 
181
183
  ```ruby
182
184
  x.requires_grad # true
183
- (x ** 2).requires_grad # true
185
+ (x**2).requires_grad # true
184
186
 
185
187
  Torch.no_grad do
186
- (x ** 2).requires_grad # false
188
+ (x**2).requires_grad # false
187
189
  end
188
190
  ```
189
191
 
@@ -359,6 +361,13 @@ Here’s a list of functions to create tensors (descriptions from the [C++ docs]
359
361
  Torch.zeros(3) # tensor([0, 0, 0])
360
362
  ```
361
363
 
364
+ ## Examples
365
+
366
+ Here are a few full examples:
367
+
368
+ - [Image classification with MNIST](examples/mnist)
369
+ - [Collaborative filtering with MovieLens](examples/movielens)
370
+
362
371
  ## LibTorch Installation
363
372
 
364
373
  [Download LibTorch](https://pytorch.org/). For Linux, use the `cxx11 ABI` version. Then run:
@@ -405,7 +414,7 @@ To get started with development:
405
414
  git clone https://github.com/ankane/torch-rb.git
406
415
  cd torch-rb
407
416
  bundle install
408
- bundle exec rake compile
417
+ bundle exec rake compile -- --with-torch-dir=/path/to/libtorch
409
418
  bundle exec rake test
410
419
  ```
411
420
 
data/ext/torch/ext.cpp CHANGED
@@ -88,7 +88,49 @@ IntArrayRef from_ruby<IntArrayRef>(Object x)
88
88
  }
89
89
 
90
90
  // for now
91
- typedef float Scalar;
91
+ class Scalar {
92
+ torch::Scalar value;
93
+ public:
94
+ Scalar(Object o) {
95
+ // TODO cast based on Ruby type
96
+ if (o.rb_type() == T_FIXNUM) {
97
+ value = torch::Scalar(from_ruby<int64_t>(o));
98
+ } else {
99
+ value = torch::Scalar(from_ruby<float>(o));
100
+ }
101
+ }
102
+ operator torch::Scalar() {
103
+ return value;
104
+ }
105
+ };
106
+
107
+ template<>
108
+ inline
109
+ Scalar from_ruby<Scalar>(Object x)
110
+ {
111
+ return Scalar(x);
112
+ }
113
+
114
+ class TensorList {
115
+ std::vector<torch::Tensor> vec;
116
+ public:
117
+ TensorList(Object o) {
118
+ Array a = Array(o);
119
+ for (size_t i = 0; i < a.size(); i++) {
120
+ vec.push_back(from_ruby<torch::Tensor>(a[i]));
121
+ }
122
+ }
123
+ operator torch::TensorList() {
124
+ return torch::TensorList(vec);
125
+ }
126
+ };
127
+
128
+ template<>
129
+ inline
130
+ TensorList from_ruby<TensorList>(Object x)
131
+ {
132
+ return TensorList(x);
133
+ }
92
134
 
93
135
  extern "C"
94
136
  void Init_ext()
@@ -206,6 +248,11 @@ void Init_ext()
206
248
  *[](torch::Tensor& input, int64_t dim, bool keepdim) {
207
249
  return torch::argmax(input, dim, keepdim);
208
250
  })
251
+ .define_singleton_method(
252
+ "_cat",
253
+ *[](TensorList tensors, int64_t dim) {
254
+ return torch::cat(tensors, dim);
255
+ })
209
256
  .define_singleton_method(
210
257
  "_norm",
211
258
  *[](torch::Tensor& input) {
@@ -221,6 +268,17 @@ void Init_ext()
221
268
  *[](torch::Tensor& input) {
222
269
  return torch::max(input);
223
270
  })
271
+ .define_singleton_method(
272
+ "_max_out",
273
+ *[](torch::Tensor &max, torch::Tensor &max_indices, const torch::Tensor &input, int64_t dim, bool keepdim) {
274
+ // TODO add return value
275
+ torch::_max_out(max, max_indices, input, dim, keepdim);
276
+ })
277
+ .define_singleton_method(
278
+ "_sqrt",
279
+ *[](torch::Tensor& input) {
280
+ return torch::sqrt(input);
281
+ })
224
282
  .define_singleton_method(
225
283
  "_exp",
226
284
  *[](torch::Tensor& input) {
@@ -231,6 +289,11 @@ void Init_ext()
231
289
  *[](torch::Tensor& input) {
232
290
  return torch::log(input);
233
291
  })
292
+ .define_singleton_method(
293
+ "_sign",
294
+ *[](torch::Tensor& input) {
295
+ return torch::sign(input);
296
+ })
234
297
  .define_singleton_method(
235
298
  "_unsqueeze",
236
299
  *[](torch::Tensor& input, int64_t dim) {
@@ -251,6 +314,18 @@ void Init_ext()
251
314
  *[](torch::Tensor& input, torch::Tensor& other) {
252
315
  return torch::eq(input, other);
253
316
  })
317
+ .define_singleton_method(
318
+ "_gt",
319
+ // TODO support tensors
320
+ *[](torch::Tensor& input, Scalar other) {
321
+ return torch::gt(input, other);
322
+ })
323
+ .define_singleton_method(
324
+ "_lt",
325
+ // TODO support tensors
326
+ *[](torch::Tensor& input, Scalar other) {
327
+ return torch::lt(input, other);
328
+ })
254
329
  .define_singleton_method(
255
330
  "_add",
256
331
  *[](torch::Tensor& input, torch::Tensor& other) {
@@ -258,7 +333,7 @@ void Init_ext()
258
333
  })
259
334
  .define_singleton_method(
260
335
  "_add_scalar",
261
- *[](torch::Tensor& input, float other) {
336
+ *[](torch::Tensor& input, Scalar other) {
262
337
  return torch::add(input, other);
263
338
  })
264
339
  .define_singleton_method(
@@ -273,7 +348,7 @@ void Init_ext()
273
348
  })
274
349
  .define_singleton_method(
275
350
  "_sub_scalar",
276
- *[](torch::Tensor& input, float other) {
351
+ *[](torch::Tensor& input, Scalar other) {
277
352
  return torch::sub(input, other);
278
353
  })
279
354
  .define_singleton_method(
@@ -283,7 +358,7 @@ void Init_ext()
283
358
  })
284
359
  .define_singleton_method(
285
360
  "_mul_scalar",
286
- *[](torch::Tensor& input, float other) {
361
+ *[](torch::Tensor& input, Scalar other) {
287
362
  return torch::mul(input, other);
288
363
  })
289
364
  .define_singleton_method(
@@ -293,7 +368,7 @@ void Init_ext()
293
368
  })
294
369
  .define_singleton_method(
295
370
  "_div_scalar",
296
- *[](torch::Tensor& input, float other) {
371
+ *[](torch::Tensor& input, Scalar other) {
297
372
  return torch::div(input, other);
298
373
  })
299
374
  .define_singleton_method(
@@ -303,7 +378,7 @@ void Init_ext()
303
378
  })
304
379
  .define_singleton_method(
305
380
  "_remainder_scalar",
306
- *[](torch::Tensor& input, float other) {
381
+ *[](torch::Tensor& input, Scalar other) {
307
382
  return torch::remainder(input, other);
308
383
  })
309
384
  .define_singleton_method(
@@ -311,6 +386,11 @@ void Init_ext()
311
386
  *[](torch::Tensor& input, Scalar exponent) {
312
387
  return torch::pow(input, exponent);
313
388
  })
389
+ .define_singleton_method(
390
+ "_abs",
391
+ *[](torch::Tensor& input) {
392
+ return torch::abs(input);
393
+ })
314
394
  .define_singleton_method(
315
395
  "_neg",
316
396
  *[](torch::Tensor& input) {
@@ -321,25 +401,20 @@ void Init_ext()
321
401
  *[](torch::Tensor& input, IntArrayRef shape) {
322
402
  return torch::reshape(input, shape);
323
403
  })
404
+ .define_singleton_method(
405
+ "_flatten",
406
+ *[](torch::Tensor& input, int64_t start_dim, int64_t end_dim) {
407
+ return torch::flatten(input, start_dim, end_dim);
408
+ })
324
409
  .define_singleton_method(
325
410
  "relu",
326
411
  *[](torch::Tensor& input) {
327
412
  return torch::relu(input);
328
413
  })
329
- .define_singleton_method(
330
- "prelu",
331
- *[](torch::Tensor& input, torch::Tensor& weight) {
332
- return torch::prelu(input, weight);
333
- })
334
- .define_singleton_method(
335
- "leaky_relu",
336
- *[](torch::Tensor& input, Scalar negative_slope = 0.01) {
337
- return torch::leaky_relu(input, negative_slope);
338
- })
339
414
  .define_singleton_method(
340
415
  "conv2d",
341
- *[](torch::Tensor& input, torch::Tensor& weight, torch::Tensor& bias, IntArrayRef stride, IntArrayRef padding) {
342
- return torch::conv2d(input, weight, bias, stride, padding);
416
+ *[](torch::Tensor& input, torch::Tensor& weight, torch::Tensor& bias, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, int64_t groups) {
417
+ return torch::conv2d(input, weight, bias, stride, padding, dilation, groups);
343
418
  })
344
419
  .define_singleton_method(
345
420
  "linear",
@@ -356,6 +431,52 @@ void Init_ext()
356
431
  *[](torch::Tensor& input, IntArrayRef kernel_size) {
357
432
  return torch::avg_pool2d(input, kernel_size);
358
433
  })
434
+ .define_singleton_method(
435
+ "_dropout",
436
+ *[](torch::Tensor& input, float p, bool train) {
437
+ return torch::dropout(input, p, train);
438
+ })
439
+ .define_singleton_method(
440
+ "_dropout!",
441
+ *[](torch::Tensor& input, float p, bool train) {
442
+ return torch::dropout_(input, p, train);
443
+ })
444
+ .define_singleton_method(
445
+ "_feature_dropout",
446
+ *[](torch::Tensor& input, float p, bool train) {
447
+ return torch::feature_dropout(input, p, train);
448
+ })
449
+ .define_singleton_method(
450
+ "_feature_dropout!",
451
+ *[](torch::Tensor& input, float p, bool train) {
452
+ return torch::feature_dropout_(input, p, train);
453
+ })
454
+ .define_singleton_method(
455
+ "_alpha_dropout",
456
+ *[](torch::Tensor& input, float p, bool train) {
457
+ return torch::alpha_dropout(input, p, train);
458
+ })
459
+ .define_singleton_method(
460
+ "_alpha_dropout!",
461
+ *[](torch::Tensor& input, float p, bool train) {
462
+ return torch::alpha_dropout_(input, p, train);
463
+ })
464
+ .define_singleton_method(
465
+ "_feature_alpha_dropout",
466
+ *[](torch::Tensor& input, float p, bool train) {
467
+ return torch::feature_alpha_dropout(input, p, train);
468
+ })
469
+ .define_singleton_method(
470
+ "_feature_alpha_dropout!",
471
+ *[](torch::Tensor& input, float p, bool train) {
472
+ return torch::feature_alpha_dropout_(input, p, train);
473
+ })
474
+ .define_singleton_method(
475
+ "_embedding",
476
+ // weight and indices are swapped from Python interface
477
+ *[](const torch::Tensor &indices, const torch::Tensor &weight, int64_t padding_idx, bool scale_grad_by_freq, bool sparse) {
478
+ return torch::embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse);
479
+ })
359
480
  .define_singleton_method(
360
481
  "mse_loss",
361
482
  *[](torch::Tensor& input, torch::Tensor& target, std::string reduction) {
@@ -364,8 +485,16 @@ void Init_ext()
364
485
  })
365
486
  .define_singleton_method(
366
487
  "nll_loss",
367
- *[](torch::Tensor& input, torch::Tensor& target) {
368
- return torch::nll_loss(input, target);
488
+ *[](torch::Tensor& input, torch::Tensor& target, std::string reduction) {
489
+ auto red = reduction == "mean" ? Reduction::Mean : Reduction::Sum;
490
+ return torch::nll_loss(input, target, {}, red);
491
+ })
492
+ .define_singleton_method("numel", &torch::numel)
493
+ .define_singleton_method(
494
+ "_from_blob",
495
+ *[](String s, IntArrayRef size, const torch::TensorOptions &options) {
496
+ void *data = const_cast<char *>(s.c_str());
497
+ return torch::from_blob(data, size, options);
369
498
  })
370
499
  .define_singleton_method(
371
500
  "_tensor",
@@ -387,9 +516,19 @@ void Init_ext()
387
516
  .define_method("sparse?", &torch::Tensor::is_sparse)
388
517
  .define_method("quantized?", &torch::Tensor::is_quantized)
389
518
  .define_method("dim", &torch::Tensor::dim)
390
- .define_method("numel", &torch::Tensor::numel)
391
519
  .define_method("element_size", &torch::Tensor::element_size)
392
520
  .define_method("requires_grad", &torch::Tensor::requires_grad)
521
+ .define_method("view_as", &torch::Tensor::view_as)
522
+ .define_method(
523
+ "addcmul!",
524
+ *[](torch::Tensor& self, Scalar value, const torch::Tensor & tensor1, const torch::Tensor & tensor2) {
525
+ return self.addcmul_(tensor1, tensor2, value);
526
+ })
527
+ .define_method(
528
+ "addcdiv!",
529
+ *[](torch::Tensor& self, Scalar value, const torch::Tensor & tensor1, const torch::Tensor & tensor2) {
530
+ return self.addcdiv_(tensor1, tensor2, value);
531
+ })
393
532
  .define_method(
394
533
  "zero!",
395
534
  *[](torch::Tensor& self) {
@@ -460,24 +599,74 @@ void Init_ext()
460
599
  return self.view(size);
461
600
  })
462
601
  .define_method(
463
- "add!",
602
+ "resize_as!",
464
603
  *[](torch::Tensor& self, torch::Tensor& other) {
465
- self.add_(other);
604
+ return self.resize_as_(other);
605
+ })
606
+ .define_method(
607
+ "fill!",
608
+ *[](torch::Tensor& self, Scalar value) {
609
+ return self.fill_(value);
610
+ })
611
+ .define_method(
612
+ "_add!",
613
+ *[](torch::Tensor& self, torch::Tensor& other) {
614
+ return self.add_(other);
615
+ })
616
+ .define_method(
617
+ "_add_alpha!",
618
+ *[](torch::Tensor& self, torch::Tensor& other, Scalar alpha) {
619
+ return self.add_(other, alpha);
620
+ })
621
+ .define_method(
622
+ "_add_scalar!",
623
+ *[](torch::Tensor& self, Scalar other) {
624
+ return self.add_(other);
625
+ })
626
+ .define_method(
627
+ "normal!",
628
+ *[](torch::Tensor& self, double mean, double std) {
629
+ return self.normal_(mean, std);
466
630
  })
467
631
  .define_method(
468
632
  "sub!",
469
633
  *[](torch::Tensor& self, torch::Tensor& other) {
470
- self.sub_(other);
634
+ return self.sub_(other);
471
635
  })
472
636
  .define_method(
473
- "mul!",
637
+ "_mul!",
474
638
  *[](torch::Tensor& self, torch::Tensor& other) {
475
- self.mul_(other);
639
+ return self.mul_(other);
640
+ })
641
+ .define_method(
642
+ "_mul_scalar!",
643
+ *[](torch::Tensor& self, Scalar other) {
644
+ return self.mul_(other);
476
645
  })
477
646
  .define_method(
478
647
  "div!",
479
648
  *[](torch::Tensor& self, torch::Tensor& other) {
480
- self.div_(other);
649
+ return self.div_(other);
650
+ })
651
+ .define_method(
652
+ "sqrt!",
653
+ *[](torch::Tensor& self) {
654
+ return self.sqrt_();
655
+ })
656
+ .define_method(
657
+ "unsqueeze!",
658
+ *[](torch::Tensor& self, int64_t dim) {
659
+ return self.unsqueeze_(dim);
660
+ })
661
+ .define_method(
662
+ "copy!",
663
+ *[](torch::Tensor& self, torch::Tensor& src) {
664
+ return self.copy_(src);
665
+ })
666
+ .define_method(
667
+ "clone",
668
+ *[](torch::Tensor& self) {
669
+ return self.clone();
481
670
  })
482
671
  .define_method(
483
672
  "log_softmax",
@@ -532,8 +721,10 @@ void Init_ext()
532
721
  a.push(data[i]);
533
722
  }
534
723
  } else if (dtype == torch::kBool) {
535
- // bool
536
- throw std::runtime_error("Type not supported yet");
724
+ bool* data = self.data_ptr<bool>();
725
+ for (int i = 0; i < self.numel(); i++) {
726
+ a.push(data[i] ? True : False);
727
+ }
537
728
  } else {
538
729
  throw std::runtime_error("Unsupported type");
539
730
  }
@@ -544,6 +735,11 @@ void Init_ext()
544
735
  *[](torch::Tensor& self, int i) {
545
736
  return self.size(i);
546
737
  })
738
+ .define_method(
739
+ "_to",
740
+ *[](torch::Tensor& self, torch::Device device, int dtype, bool non_blocking, bool copy) {
741
+ return self.to(device, (torch::ScalarType) dtype, non_blocking, copy);
742
+ })
547
743
  .define_singleton_method(
548
744
  "_make_subclass",
549
745
  *[](torch::Tensor& rd, bool requires_grad) {
@@ -597,12 +793,17 @@ void Init_ext()
597
793
 
598
794
  Module rb_mInit = define_module_under(rb_mNN, "Init")
599
795
  .define_singleton_method(
600
- "kaiming_uniform_",
796
+ "kaiming_uniform!",
601
797
  *[](torch::Tensor& input, double a) {
602
798
  return torch::nn::init::kaiming_uniform_(input, a);
603
799
  })
604
800
  .define_singleton_method(
605
- "uniform_",
801
+ "normal!",
802
+ *[](torch::Tensor& input) {
803
+ return torch::nn::init::normal_(input);
804
+ })
805
+ .define_singleton_method(
806
+ "uniform!",
606
807
  *[](torch::Tensor& input, double to, double from) {
607
808
  return torch::nn::init::uniform_(input, to, from);
608
809
  });
@@ -619,4 +820,20 @@ void Init_ext()
619
820
  *[](torch::autograd::Variable& self) {
620
821
  return self.grad().defined();
621
822
  });
823
+
824
+ Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")
825
+ .define_constructor(Constructor<torch::Device, std::string>())
826
+ .define_method("index", &torch::Device::index)
827
+ .define_method("index?", &torch::Device::has_index)
828
+ .define_method(
829
+ "type",
830
+ *[](torch::Device& self) {
831
+ std::stringstream s;
832
+ s << self.type();
833
+ return s.str();
834
+ });
835
+
836
+ Module rb_mCUDA = define_module_under(rb_mTorch, "CUDA")
837
+ .define_singleton_method("available?", &torch::cuda::is_available)
838
+ .define_singleton_method("device_count", &torch::cuda::device_count);
622
839
  }