torch-rb 0.1.0 → 0.1.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/README.md +13 -11
- data/ext/torch/ext.cpp +23 -8
- data/ext/torch/extconf.rb +3 -0
- data/lib/torch/ext.bundle +0 -0
- data/lib/torch/version.rb +1 -1
- metadata +2 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: ff0920ba955063c03309fdb45ecf228b51c556508bea30b510d6bf652c1d0b18
|
4
|
+
data.tar.gz: 481dccf6a8e929230033f74c82bc9d292ef38ea219e2cb2cc61ca0b0c5457403
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: cd6c8fd9db4af15640217c09813c4f86d3f66360202d30711015c4f34552853ef281d5614fd78dc274d405da0d9f46f08a2359475ae1b0721143db49183faf5d
|
7
|
+
data.tar.gz: ee638c08458e0d2a8fac52e29c45d1347a74847ca7d8dab3a9a573afd887814d4c578c4e1a7fb80b204222785b27e21ca3c138f3face681e98e63c1bc02a9a7f
|
data/CHANGELOG.md
CHANGED
data/README.md
CHANGED
@@ -2,14 +2,16 @@
|
|
2
2
|
|
3
3
|
:fire: Deep learning for Ruby, powered by [LibTorch](https://pytorch.org)
|
4
4
|
|
5
|
-
|
5
|
+
This gem is currently experimental. There may be breaking changes between each release.
|
6
|
+
|
7
|
+
[![Build Status](https://travis-ci.org/ankane/torch-rb.svg?branch=master)](https://travis-ci.org/ankane/torch-rb)
|
6
8
|
|
7
9
|
## Installation
|
8
10
|
|
9
11
|
First, [install LibTorch](#libtorch-installation). For Homebrew, use:
|
10
12
|
|
11
13
|
```sh
|
12
|
-
brew install
|
14
|
+
brew install libtorch
|
13
15
|
```
|
14
16
|
|
15
17
|
Add this line to your application’s Gemfile:
|
@@ -171,7 +173,7 @@ out.backward
|
|
171
173
|
Get gradients
|
172
174
|
|
173
175
|
```ruby
|
174
|
-
x.grad
|
176
|
+
x.grad # tensor([[4.5, 4.5], [4.5, 4.5]])
|
175
177
|
```
|
176
178
|
|
177
179
|
Stop autograd from tracking history
|
@@ -242,7 +244,7 @@ Here’s a list of functions to create tensors (descriptions from the [C++ docs]
|
|
242
244
|
- `empty` returns a tensor with uninitialized values
|
243
245
|
|
244
246
|
```ruby
|
245
|
-
Torch.empty(3)
|
247
|
+
Torch.empty(3) # tensor([7.0054e-45, 0.0000e+00, 0.0000e+00])
|
246
248
|
```
|
247
249
|
|
248
250
|
- `eye` returns an identity matrix
|
@@ -278,19 +280,19 @@ Here’s a list of functions to create tensors (descriptions from the [C++ docs]
|
|
278
280
|
- `rand` returns a tensor filled with values drawn from a uniform distribution on [0, 1)
|
279
281
|
|
280
282
|
```ruby
|
281
|
-
Torch.rand(3)
|
283
|
+
Torch.rand(3) # tensor([0.5444, 0.8799, 0.5571])
|
282
284
|
```
|
283
285
|
|
284
286
|
- `randint` returns a tensor with integers randomly drawn from an interval
|
285
287
|
|
286
288
|
```ruby
|
287
|
-
Torch.randint(1, 10, [3])
|
289
|
+
Torch.randint(1, 10, [3]) # tensor([7, 6, 4])
|
288
290
|
```
|
289
291
|
|
290
292
|
- `randn` returns a tensor filled with values drawn from a unit normal distribution
|
291
293
|
|
292
294
|
```ruby
|
293
|
-
Torch.randn(3)
|
295
|
+
Torch.randn(3) # tensor([-0.7147, 0.6614, 1.1453])
|
294
296
|
```
|
295
297
|
|
296
298
|
- `randperm` returns a tensor filled with a random permutation of integers in some interval
|
@@ -307,10 +309,10 @@ Here’s a list of functions to create tensors (descriptions from the [C++ docs]
|
|
307
309
|
|
308
310
|
## LibTorch Installation
|
309
311
|
|
310
|
-
[Download LibTorch](https://pytorch.org/)
|
312
|
+
[Download LibTorch](https://pytorch.org/). For Linux, use the `cxx11 ABI` version. Then run:
|
311
313
|
|
312
314
|
```sh
|
313
|
-
|
315
|
+
bundle config build.torch-rb --with-torch-dir=/path/to/libtorch
|
314
316
|
```
|
315
317
|
|
316
318
|
### Homebrew
|
@@ -318,10 +320,10 @@ gem install torch-rb -- --with-torch-dir=/path/to/libtorch
|
|
318
320
|
For Mac, you can use Homebrew.
|
319
321
|
|
320
322
|
```sh
|
321
|
-
brew install
|
323
|
+
brew install libtorch
|
322
324
|
```
|
323
325
|
|
324
|
-
Then install the gem (no need for
|
326
|
+
Then install the gem (no need for `bundle config`).
|
325
327
|
|
326
328
|
## rbenv
|
327
329
|
|
data/ext/torch/ext.cpp
CHANGED
@@ -441,19 +441,28 @@ void Init_ext()
|
|
441
441
|
auto dtype = self.dtype();
|
442
442
|
|
443
443
|
// TODO DRY if someone knows C++
|
444
|
-
|
445
|
-
|
446
|
-
|
444
|
+
if (dtype == torch::kByte) {
|
445
|
+
uint8_t* data = self.data_ptr<uint8_t>();
|
446
|
+
for (int i = 0; i < self.numel(); i++) {
|
447
|
+
a.push(data[i]);
|
448
|
+
}
|
449
|
+
} else if (dtype == torch::kChar) {
|
450
|
+
int8_t* data = self.data_ptr<int8_t>();
|
451
|
+
for (int i = 0; i < self.numel(); i++) {
|
452
|
+
a.push(to_ruby<int>(data[i]));
|
453
|
+
}
|
454
|
+
} else if (dtype == torch::kShort) {
|
455
|
+
int16_t* data = self.data_ptr<int16_t>();
|
447
456
|
for (int i = 0; i < self.numel(); i++) {
|
448
457
|
a.push(data[i]);
|
449
458
|
}
|
450
459
|
} else if (dtype == torch::kInt) {
|
451
|
-
|
460
|
+
int32_t* data = self.data_ptr<int32_t>();
|
452
461
|
for (int i = 0; i < self.numel(); i++) {
|
453
462
|
a.push(data[i]);
|
454
463
|
}
|
455
464
|
} else if (dtype == torch::kLong) {
|
456
|
-
|
465
|
+
int64_t* data = self.data_ptr<int64_t>();
|
457
466
|
for (int i = 0; i < self.numel(); i++) {
|
458
467
|
a.push(data[i]);
|
459
468
|
}
|
@@ -467,8 +476,11 @@ void Init_ext()
|
|
467
476
|
for (int i = 0; i < self.numel(); i++) {
|
468
477
|
a.push(data[i]);
|
469
478
|
}
|
479
|
+
} else if (dtype == torch::kBool) {
|
480
|
+
// bool
|
481
|
+
throw std::runtime_error("Type not supported yet");
|
470
482
|
} else {
|
471
|
-
throw "Unsupported type";
|
483
|
+
throw std::runtime_error("Unsupported type");
|
472
484
|
}
|
473
485
|
return a;
|
474
486
|
})
|
@@ -499,8 +511,11 @@ void Init_ext()
|
|
499
511
|
torch::Layout l;
|
500
512
|
if (layout == "strided") {
|
501
513
|
l = torch::kStrided;
|
514
|
+
} else if (layout == "sparse") {
|
515
|
+
l = torch::kSparse;
|
516
|
+
throw std::runtime_error("Sparse layout not supported yet");
|
502
517
|
} else {
|
503
|
-
throw "Unsupported layout";
|
518
|
+
throw std::runtime_error("Unsupported layout: " + layout);
|
504
519
|
}
|
505
520
|
return self.layout(l);
|
506
521
|
})
|
@@ -513,7 +528,7 @@ void Init_ext()
|
|
513
528
|
} else if (device == "cuda") {
|
514
529
|
d = torch::kCUDA;
|
515
530
|
} else {
|
516
|
-
throw "Unsupported device";
|
531
|
+
throw std::runtime_error("Unsupported device: " + device);
|
517
532
|
}
|
518
533
|
return self.device(d);
|
519
534
|
})
|
data/ext/torch/extconf.rb
CHANGED
data/lib/torch/ext.bundle
CHANGED
Binary file
|
data/lib/torch/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.1.
|
4
|
+
version: 0.1.1
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2019-11-
|
11
|
+
date: 2019-11-27 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|