torch-rb 0.1.2 → 0.1.3

checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
- metadata.gz: 30089078de4039df111087e5c27e0cb10d6f36398c0e8d5cc774e9b642a8e133
- data.tar.gz: 89eb9e183b395dd67cd9cf228749cf26402993bb561de973f1ba7438bc372b04
+ metadata.gz: e7f715179c9a84dc7399b80d93fd61f2bbb58a0156e6084dc4abb23e1d4a1b52
+ data.tar.gz: 6928379ae7c92a77ad9dde4f4224ec33c6f8575a9b77585c0147e4f5361021de
  SHA512:
- metadata.gz: 027a069b00ac1329c007ddaf471a21b57a82a823ad974a937f832d17720b8e26474c64c79e9a29ec71bac433abb3d74d6a7cf407f0a983bb3c0cafb5b5c7532f
- data.tar.gz: 6d7ef10b53db0df39eda13d07aa9b52b4afac0965674919b5cc517e7b53f59a9010cb647e50d62bc06154f7d8f3ef632d5897e4f7774372d7ab1b44b2cb6ca82
+ metadata.gz: 9911a9e86d93f1e410776c44fdb3cd9aa06c83d1f0e42fdab8530970bea6520aed7906e96fb8243efd6b957453ebc13678b2b92e4c85b54407030a32c6196e08
+ data.tar.gz: 0d080f5458a5dcf8fee19ce5e2e342bf6269432de6e78d923036232963ebb80daeea993c0bbf4af2d6da46593ac28a72a8232020a9fcb48acc3276c9e1ebebf3
data/CHANGELOG.md CHANGED
@@ -1,3 +1,13 @@
+ ## 0.1.3 (2019-11-30)
+
+ - Changed to BSD 3-Clause license to match PyTorch
+ - Added many optimizers
+ - Added `StepLR` learning rate scheduler
+ - Added dropout
+ - Added embedding
+ - Added support for `bool` type
+ - Improved performance of `from_numo`
+
  ## 0.1.2 (2019-11-27)
 
  - Added SGD optimizer
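The optimizer and scheduler entries map onto PyTorch's `torch.optim` API. A minimal hedged sketch of how the new pieces fit together (the `Torch::Optim::SGD` and `Torch::Optim::LRScheduler::StepLR` names are assumptions based on the gem's PyTorch-mirroring convention; they do not appear in this diff):

```ruby
# Hedged sketch: class names assume the Ruby API mirrors torch.optim
w = Torch.rand(2, 3)
optimizer = Torch::Optim::SGD.new([w], lr: 0.1)
# StepLR multiplies the learning rate by gamma every step_size steps
scheduler = Torch::Optim::LRScheduler::StepLR.new(optimizer, step_size: 1, gamma: 0.9)

3.times { scheduler.step } # lr: 0.1 -> 0.09 -> 0.081
```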
data/LICENSE.txt CHANGED
@@ -1,22 +1,46 @@
- Copyright (c) 2019 Andrew Kane
-
- MIT License
-
- Permission is hereby granted, free of charge, to any person obtaining
- a copy of this software and associated documentation files (the
- "Software"), to deal in the Software without restriction, including
- without limitation the rights to use, copy, modify, merge, publish,
- distribute, sublicense, and/or sell copies of the Software, and to
- permit persons to whom the Software is furnished to do so, subject to
- the following conditions:
-
- The above copyright notice and this permission notice shall be
- included in all copies or substantial portions of the Software.
-
- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
- EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
- MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
- NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
- LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
- OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
- WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+ BSD 3-Clause License
+
+ From Torch-rb:
+
+ Copyright (c) 2019- Andrew Kane
+
+ From PyTorch (for ported code):
+
+ Copyright (c) 2016- Facebook, Inc (Adam Paszke)
+ Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
+ Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
+ Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
+ Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
+ Copyright (c) 2011-2013 NYU (Clement Farabet)
+ Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
+ Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
+ Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
+
+ All rights reserved.
+
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions are met:
+
+ 1. Redistributions of source code must retain the above copyright
+ notice, this list of conditions and the following disclaimer.
+
+ 2. Redistributions in binary form must reproduce the above copyright
+ notice, this list of conditions and the following disclaimer in the
+ documentation and/or other materials provided with the distribution.
+
+ 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories America
+ and IDIAP Research Institute nor the names of its contributors may be
+ used to endorse or promote products derived from this software without
+ specific prior written permission.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+ ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE
+ LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+ CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+ SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+ INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+ CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+ ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+ POSSIBILITY OF SUCH DAMAGE.
data/README.md CHANGED
@@ -30,7 +30,9 @@ This library follows the [PyTorch API](https://pytorch.org/docs/stable/torch.htm
 
  Many methods and options are missing at the moment. PRs welcome!
 
- Some examples below are from [Deep Learning with PyTorch: A 60 Minutes Blitz](https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html).
+ ## Tutorial
+
+ Some examples below are from [Deep Learning with PyTorch: A 60 Minutes Blitz](https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html)
 
  ### Tensors
 
@@ -145,7 +147,7 @@ Convert a Numo array to a tensor
 
  ```ruby
  b = Numo::NArray.cast([1, 2, 3])
- Torch.from_numpy(b)
+ Torch.from_numo(b)
  ```
 
  ### Autograd
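Callers upgrading from 0.1.2 need to switch `from_numpy` to `from_numo`. A hedged round-trip sketch (only `Torch.from_numo` itself is confirmed by this diff; the `dtype` reader is an assumption from the gem's Tensor API):

```ruby
require "numo/narray"

a = Numo::SFloat.new(3).seq   # Numo array [0.0, 1.0, 2.0]
t = Torch.from_numo(a)        # build a tensor from the array's data
t.dtype                       # expected to report a float type
```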
@@ -180,10 +182,10 @@ Stop autograd from tracking history
 
  ```ruby
  x.requires_grad # true
- (x ** 2).requires_grad # true
+ (x**2).requires_grad # true
 
  Torch.no_grad do
- (x ** 2).requires_grad # false
+ (x**2).requires_grad # false
  end
  ```
 
@@ -359,6 +361,13 @@ Here’s a list of functions to create tensors (descriptions from the [C++ docs]
  Torch.zeros(3) # tensor([0, 0, 0])
  ```
 
+ ## Examples
+
+ Here are a few full examples:
+
+ - [Image classification with MNIST](examples/mnist)
+ - [Collaborative filtering with MovieLens](examples/movielens)
+
  ## LibTorch Installation
 
  [Download LibTorch](https://pytorch.org/). For Linux, use the `cxx11 ABI` version. Then run:
@@ -405,7 +414,7 @@ To get started with development:
  git clone https://github.com/ankane/torch-rb.git
  cd torch-rb
  bundle install
- bundle exec rake compile
+ bundle exec rake compile -- --with-torch-dir=/path/to/libtorch
  bundle exec rake test
  ```
 
data/ext/torch/ext.cpp CHANGED
@@ -88,7 +88,49 @@ IntArrayRef from_ruby<IntArrayRef>(Object x)
  }
 
  // for now
- typedef float Scalar;
+ class Scalar {
+ torch::Scalar value;
+ public:
+ Scalar(Object o) {
+ // TODO cast based on Ruby type
+ if (o.rb_type() == T_FIXNUM) {
+ value = torch::Scalar(from_ruby<int64_t>(o));
+ } else {
+ value = torch::Scalar(from_ruby<float>(o));
+ }
+ }
+ operator torch::Scalar() {
+ return value;
+ }
+ };
+
+ template<>
+ inline
+ Scalar from_ruby<Scalar>(Object x)
+ {
+ return Scalar(x);
+ }
+
+ class TensorList {
+ std::vector<torch::Tensor> vec;
+ public:
+ TensorList(Object o) {
+ Array a = Array(o);
+ for (size_t i = 0; i < a.size(); i++) {
+ vec.push_back(from_ruby<torch::Tensor>(a[i]));
+ }
+ }
+ operator torch::TensorList() {
+ return torch::TensorList(vec);
+ }
+ };
+
+ template<>
+ inline
+ TensorList from_ruby<TensorList>(Object x)
+ {
+ return TensorList(x);
+ }
 
  extern "C"
  void Init_ext()
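Replacing `typedef float Scalar` with a converting class means integer arguments now reach LibTorch as `int64_t` rather than being forced through `float`. A hedged Ruby-side illustration using the private `_add_scalar` binding defined later in this file (result dtypes assume standard LibTorch type promotion):

```ruby
t = Torch.tensor([1, 2, 3])
Torch._add_scalar(t, 1)    # T_FIXNUM -> torch::Scalar(int64_t), integer dtype kept
Torch._add_scalar(t, 1.5)  # any other Ruby numeric falls back to the float path
```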
@@ -206,6 +248,11 @@ void Init_ext()
  *[](torch::Tensor& input, int64_t dim, bool keepdim) {
  return torch::argmax(input, dim, keepdim);
  })
+ .define_singleton_method(
+ "_cat",
+ *[](TensorList tensors, int64_t dim) {
+ return torch::cat(tensors, dim);
+ })
  .define_singleton_method(
  "_norm",
  *[](torch::Tensor& input) {
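The `TensorList` converter above is what lets `_cat` take a plain Ruby array of tensors. A hedged sketch calling the private binding directly (a public `Torch.cat` wrapper presumably exists in the Ruby layer, but it is not part of this diff):

```ruby
a = Torch.tensor([[1, 2]])
b = Torch.tensor([[3, 4]])
# the Array is converted element by element via from_ruby<torch::Tensor>
Torch._cat([a, b], 0)  # concatenate along dim 0 -> 2x2 tensor
```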
@@ -221,6 +268,17 @@ void Init_ext()
  *[](torch::Tensor& input) {
  return torch::max(input);
  })
+ .define_singleton_method(
+ "_max_out",
+ *[](torch::Tensor &max, torch::Tensor &max_indices, const torch::Tensor &input, int64_t dim, bool keepdim) {
+ // TODO add return value
+ torch::_max_out(max, max_indices, input, dim, keepdim);
+ })
+ .define_singleton_method(
+ "_sqrt",
+ *[](torch::Tensor& input) {
+ return torch::sqrt(input);
+ })
  .define_singleton_method(
  "_exp",
  *[](torch::Tensor& input) {
@@ -231,6 +289,11 @@ void Init_ext()
  *[](torch::Tensor& input) {
  return torch::log(input);
  })
+ .define_singleton_method(
+ "_sign",
+ *[](torch::Tensor& input) {
+ return torch::sign(input);
+ })
  .define_singleton_method(
  "_unsqueeze",
  *[](torch::Tensor& input, int64_t dim) {
@@ -251,6 +314,18 @@ void Init_ext()
  *[](torch::Tensor& input, torch::Tensor& other) {
  return torch::eq(input, other);
  })
+ .define_singleton_method(
+ "_gt",
+ // TODO support tensors
+ *[](torch::Tensor& input, Scalar other) {
+ return torch::gt(input, other);
+ })
+ .define_singleton_method(
+ "_lt",
+ // TODO support tensors
+ *[](torch::Tensor& input, Scalar other) {
+ return torch::lt(input, other);
+ })
  .define_singleton_method(
  "_add",
  *[](torch::Tensor& input, torch::Tensor& other) {
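`_gt` and `_lt` produce boolean tensors, which pairs with this release's new `bool` dtype support. A hedged sketch against the private bindings (per the TODO comments, only scalar comparison is wired up so far):

```ruby
t = Torch.tensor([1, 5, 3])
Torch._gt(t, 2)  # elementwise t > 2 -> bool tensor [false, true, true]
Torch._lt(t, 2)  # elementwise t < 2 -> bool tensor [true, false, false]
```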
@@ -258,7 +333,7 @@ void Init_ext()
  })
  .define_singleton_method(
  "_add_scalar",
- *[](torch::Tensor& input, float other) {
+ *[](torch::Tensor& input, Scalar other) {
  return torch::add(input, other);
  })
  .define_singleton_method(
@@ -273,7 +348,7 @@ void Init_ext()
  })
  .define_singleton_method(
  "_sub_scalar",
- *[](torch::Tensor& input, float other) {
+ *[](torch::Tensor& input, Scalar other) {
  return torch::sub(input, other);
  })
  .define_singleton_method(
@@ -283,7 +358,7 @@ void Init_ext()
  })
  .define_singleton_method(
  "_mul_scalar",
- *[](torch::Tensor& input, float other) {
+ *[](torch::Tensor& input, Scalar other) {
  return torch::mul(input, other);
  })
  .define_singleton_method(
@@ -293,7 +368,7 @@ void Init_ext()
  })
  .define_singleton_method(
  "_div_scalar",
- *[](torch::Tensor& input, float other) {
+ *[](torch::Tensor& input, Scalar other) {
  return torch::div(input, other);
  })
  .define_singleton_method(
@@ -303,7 +378,7 @@ void Init_ext()
  })
  .define_singleton_method(
  "_remainder_scalar",
- *[](torch::Tensor& input, float other) {
+ *[](torch::Tensor& input, Scalar other) {
  return torch::remainder(input, other);
  })
  .define_singleton_method(
@@ -311,6 +386,11 @@ void Init_ext()
  *[](torch::Tensor& input, Scalar exponent) {
  return torch::pow(input, exponent);
  })
+ .define_singleton_method(
+ "_abs",
+ *[](torch::Tensor& input) {
+ return torch::abs(input);
+ })
  .define_singleton_method(
  "_neg",
  *[](torch::Tensor& input) {
@@ -321,25 +401,20 @@ void Init_ext()
  *[](torch::Tensor& input, IntArrayRef shape) {
  return torch::reshape(input, shape);
  })
+ .define_singleton_method(
+ "_flatten",
+ *[](torch::Tensor& input, int64_t start_dim, int64_t end_dim) {
+ return torch::flatten(input, start_dim, end_dim);
+ })
  .define_singleton_method(
  "relu",
  *[](torch::Tensor& input) {
  return torch::relu(input);
  })
- .define_singleton_method(
- "prelu",
- *[](torch::Tensor& input, torch::Tensor& weight) {
- return torch::prelu(input, weight);
- })
- .define_singleton_method(
- "leaky_relu",
- *[](torch::Tensor& input, Scalar negative_slope = 0.01) {
- return torch::leaky_relu(input, negative_slope);
- })
  .define_singleton_method(
  "conv2d",
- *[](torch::Tensor& input, torch::Tensor& weight, torch::Tensor& bias, IntArrayRef stride, IntArrayRef padding) {
- return torch::conv2d(input, weight, bias, stride, padding);
+ *[](torch::Tensor& input, torch::Tensor& weight, torch::Tensor& bias, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, int64_t groups) {
+ return torch::conv2d(input, weight, bias, stride, padding, dilation, groups);
  })
  .define_singleton_method(
  "linear",
@@ -356,6 +431,52 @@ void Init_ext()
  *[](torch::Tensor& input, IntArrayRef kernel_size) {
  return torch::avg_pool2d(input, kernel_size);
  })
+ .define_singleton_method(
+ "_dropout",
+ *[](torch::Tensor& input, float p, bool train) {
+ return torch::dropout(input, p, train);
+ })
+ .define_singleton_method(
+ "_dropout!",
+ *[](torch::Tensor& input, float p, bool train) {
+ return torch::dropout_(input, p, train);
+ })
+ .define_singleton_method(
+ "_feature_dropout",
+ *[](torch::Tensor& input, float p, bool train) {
+ return torch::feature_dropout(input, p, train);
+ })
+ .define_singleton_method(
+ "_feature_dropout!",
+ *[](torch::Tensor& input, float p, bool train) {
+ return torch::feature_dropout_(input, p, train);
+ })
+ .define_singleton_method(
+ "_alpha_dropout",
+ *[](torch::Tensor& input, float p, bool train) {
+ return torch::alpha_dropout(input, p, train);
+ })
+ .define_singleton_method(
+ "_alpha_dropout!",
+ *[](torch::Tensor& input, float p, bool train) {
+ return torch::alpha_dropout_(input, p, train);
+ })
+ .define_singleton_method(
+ "_feature_alpha_dropout",
+ *[](torch::Tensor& input, float p, bool train) {
+ return torch::feature_alpha_dropout(input, p, train);
+ })
+ .define_singleton_method(
+ "_feature_alpha_dropout!",
+ *[](torch::Tensor& input, float p, bool train) {
+ return torch::feature_alpha_dropout_(input, p, train);
+ })
+ .define_singleton_method(
+ "_embedding",
+ // weight and indices are swapped from Python interface
+ *[](const torch::Tensor &indices, const torch::Tensor &weight, int64_t padding_idx, bool scale_grad_by_freq, bool sparse) {
+ return torch::embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse);
+ })
  .define_singleton_method(
  "mse_loss",
  *[](torch::Tensor& input, torch::Tensor& target, std::string reduction) {
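These bindings back the changelog's "Added dropout" and "Added embedding" entries: the `train` flag decides whether dropout actually zeroes elements, and `_embedding` takes indices first (swapped from the Python interface, per the comment). A hedged sketch against the private bindings (public wrappers under `Torch::NN` presumably exist but are outside this file):

```ruby
x = Torch.rand(2, 4)
Torch._dropout(x, 0.5, true)   # train: true -> zero elements with p = 0.5
Torch._dropout(x, 0.5, false)  # eval mode -> input passes through unchanged

weight  = Torch.rand(10, 3)    # 10-entry vocabulary, 3-dim embeddings
indices = Torch.tensor([1, 4, 9])
# indices come first here; torch::embedding itself takes (weight, indices, ...)
Torch._embedding(indices, weight, -1, false, false)
```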
@@ -364,8 +485,16 @@ void Init_ext()
  })
  .define_singleton_method(
  "nll_loss",
- *[](torch::Tensor& input, torch::Tensor& target) {
- return torch::nll_loss(input, target);
+ *[](torch::Tensor& input, torch::Tensor& target, std::string reduction) {
+ auto red = reduction == "mean" ? Reduction::Mean : Reduction::Sum;
+ return torch::nll_loss(input, target, {}, red);
+ })
+ .define_singleton_method("numel", &torch::numel)
+ .define_singleton_method(
+ "_from_blob",
+ *[](String s, IntArrayRef size, const torch::TensorOptions &options) {
+ void *data = const_cast<char *>(s.c_str());
+ return torch::from_blob(data, size, options);
  })
  .define_singleton_method(
  "_tensor",
@@ -387,9 +516,19 @@ void Init_ext()
  .define_method("sparse?", &torch::Tensor::is_sparse)
  .define_method("quantized?", &torch::Tensor::is_quantized)
  .define_method("dim", &torch::Tensor::dim)
- .define_method("numel", &torch::Tensor::numel)
  .define_method("element_size", &torch::Tensor::element_size)
  .define_method("requires_grad", &torch::Tensor::requires_grad)
+ .define_method("view_as", &torch::Tensor::view_as)
+ .define_method(
+ "addcmul!",
+ *[](torch::Tensor& self, Scalar value, const torch::Tensor & tensor1, const torch::Tensor & tensor2) {
+ return self.addcmul_(tensor1, tensor2, value);
+ })
+ .define_method(
+ "addcdiv!",
+ *[](torch::Tensor& self, Scalar value, const torch::Tensor & tensor1, const torch::Tensor & tensor2) {
+ return self.addcdiv_(tensor1, tensor2, value);
+ })
  .define_method(
  "zero!",
  *[](torch::Tensor& self) {
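Note the argument order: the Ruby methods take the scaling value first and reorder it for the underlying `addcmul_`/`addcdiv_` calls, matching the older PyTorch `value=` position. These fused in-place updates are the kind of primitive the new optimizers rely on. A hedged sketch:

```ruby
t  = Torch.ones(3)
t1 = Torch.tensor([1.0, 2.0, 3.0])
t2 = Torch.tensor([4.0, 5.0, 6.0])
t.addcmul!(0.1, t1, t2)  # t += 0.1 * t1 * t2, in place
t.addcdiv!(0.1, t1, t2)  # t += 0.1 * t1 / t2, in place
```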
@@ -460,24 +599,74 @@ void Init_ext()
  return self.view(size);
  })
  .define_method(
- "add!",
+ "resize_as!",
  *[](torch::Tensor& self, torch::Tensor& other) {
- self.add_(other);
+ return self.resize_as_(other);
+ })
+ .define_method(
+ "fill!",
+ *[](torch::Tensor& self, Scalar value) {
+ return self.fill_(value);
+ })
+ .define_method(
+ "_add!",
+ *[](torch::Tensor& self, torch::Tensor& other) {
+ return self.add_(other);
+ })
+ .define_method(
+ "_add_alpha!",
+ *[](torch::Tensor& self, torch::Tensor& other, Scalar alpha) {
+ return self.add_(other, alpha);
+ })
+ .define_method(
+ "_add_scalar!",
+ *[](torch::Tensor& self, Scalar other) {
+ return self.add_(other);
+ })
+ .define_method(
+ "normal!",
+ *[](torch::Tensor& self, double mean, double std) {
+ return self.normal_(mean, std);
  })
  .define_method(
  "sub!",
  *[](torch::Tensor& self, torch::Tensor& other) {
- self.sub_(other);
+ return self.sub_(other);
  })
  .define_method(
- "mul!",
+ "_mul!",
  *[](torch::Tensor& self, torch::Tensor& other) {
- self.mul_(other);
+ return self.mul_(other);
+ })
+ .define_method(
+ "_mul_scalar!",
+ *[](torch::Tensor& self, Scalar other) {
+ return self.mul_(other);
  })
  .define_method(
  "div!",
  *[](torch::Tensor& self, torch::Tensor& other) {
- self.div_(other);
+ return self.div_(other);
+ })
+ .define_method(
+ "sqrt!",
+ *[](torch::Tensor& self) {
+ return self.sqrt_();
+ })
+ .define_method(
+ "unsqueeze!",
+ *[](torch::Tensor& self, int64_t dim) {
+ return self.unsqueeze_(dim);
+ })
+ .define_method(
+ "copy!",
+ *[](torch::Tensor& self, torch::Tensor& src) {
+ return self.copy_(src);
+ })
+ .define_method(
+ "clone",
+ *[](torch::Tensor& self) {
+ return self.clone();
  })
  .define_method(
  "log_softmax",
@@ -532,8 +721,10 @@ void Init_ext()
  a.push(data[i]);
  }
  } else if (dtype == torch::kBool) {
- // bool
- throw std::runtime_error("Type not supported yet");
+ bool* data = self.data_ptr<bool>();
+ for (int i = 0; i < self.numel(); i++) {
+ a.push(data[i] ? True : False);
+ }
  } else {
  throw std::runtime_error("Unsupported type");
  }
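With the `kBool` branch filled in, boolean tensors (such as those returned by `_gt` and `_lt` above) convert back to Ruby `true`/`false` values instead of raising. A hedged sketch (assuming this branch backs the tensor-to-array conversion, e.g. `to_a`):

```ruby
t = Torch.tensor([1, 5, 3])
mask = Torch._gt(t, 2)  # bool tensor
mask.to_a               # => [false, true, true] as Ruby booleans
```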
@@ -544,6 +735,11 @@ void Init_ext()
  *[](torch::Tensor& self, int i) {
  return self.size(i);
  })
+ .define_method(
+ "_to",
+ *[](torch::Tensor& self, torch::Device device, int dtype, bool non_blocking, bool copy) {
+ return self.to(device, (torch::ScalarType) dtype, non_blocking, copy);
+ })
  .define_singleton_method(
  "_make_subclass",
  *[](torch::Tensor& rd, bool requires_grad) {
@@ -597,12 +793,17 @@ void Init_ext()
 
  Module rb_mInit = define_module_under(rb_mNN, "Init")
  .define_singleton_method(
- "kaiming_uniform_",
+ "kaiming_uniform!",
  *[](torch::Tensor& input, double a) {
  return torch::nn::init::kaiming_uniform_(input, a);
  })
  .define_singleton_method(
- "uniform_",
+ "normal!",
+ *[](torch::Tensor& input) {
+ return torch::nn::init::normal_(input);
+ })
+ .define_singleton_method(
+ "uniform!",
  *[](torch::Tensor& input, double to, double from) {
  return torch::nn::init::uniform_(input, to, from);
  });
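The Python trailing-underscore convention for in-place initializers (`uniform_`) becomes Ruby's bang convention (`uniform!`), and a `normal!` initializer is added. A hedged sketch (the `Torch::NN::Init` constant comes from the `define_module_under(rb_mNN, "Init")` line; `Torch.empty` is assumed from the README's tensor-creation list):

```ruby
w = Torch.empty(3, 3)
Torch::NN::Init.kaiming_uniform!(w, Math.sqrt(5))  # in-place, returns w
Torch::NN::Init.normal!(w)                         # standard normal fill
Torch::NN::Init.uniform!(w, 0.0, 1.0)
```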
@@ -619,4 +820,20 @@ void Init_ext()
  *[](torch::autograd::Variable& self) {
  return self.grad().defined();
  });
+
+ Class rb_cDevice = define_class_under<torch::Device>(rb_mTorch, "Device")
+ .define_constructor(Constructor<torch::Device, std::string>())
+ .define_method("index", &torch::Device::index)
+ .define_method("index?", &torch::Device::has_index)
+ .define_method(
+ "type",
+ *[](torch::Device& self) {
+ std::stringstream s;
+ s << self.type();
+ return s.str();
+ });
+
+ Module rb_mCUDA = define_module_under(rb_mTorch, "CUDA")
+ .define_singleton_method("available?", &torch::cuda::is_available)
+ .define_singleton_method("device_count", &torch::cuda::device_count);
  }
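The new `Torch::Device` class and `Torch::CUDA` module round out the release with device handling; `Device` parses the usual LibTorch device strings. A short sketch using only the methods bound above (the commented results assume a CPU-only LibTorch build):

```ruby
device = Torch::Device.new("cuda:0")
device.type    # => "cuda"
device.index   # => 0
device.index?  # => true

Torch::CUDA.available?    # => false without a CUDA build of LibTorch
Torch::CUDA.device_count  # => 0
```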