torch-rb 0.1.0 → 0.1.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +6 -0
- data/README.md +13 -11
- data/ext/torch/ext.cpp +23 -8
- data/ext/torch/extconf.rb +3 -0
- data/lib/torch/ext.bundle +0 -0
- data/lib/torch/version.rb +1 -1
- metadata +2 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
 ---
 SHA256:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: ff0920ba955063c03309fdb45ecf228b51c556508bea30b510d6bf652c1d0b18
+  data.tar.gz: 481dccf6a8e929230033f74c82bc9d292ef38ea219e2cb2cc61ca0b0c5457403
 SHA512:
-  metadata.gz:
-  data.tar.gz:
+  metadata.gz: cd6c8fd9db4af15640217c09813c4f86d3f66360202d30711015c4f34552853ef281d5614fd78dc274d405da0d9f46f08a2359475ae1b0721143db49183faf5d
+  data.tar.gz: ee638c08458e0d2a8fac52e29c45d1347a74847ca7d8dab3a9a573afd887814d4c578c4e1a7fb80b204222785b27e21ca3c138f3face681e98e63c1bc02a9a7f
data/CHANGELOG.md
CHANGED
data/README.md
CHANGED
@@ -2,14 +2,16 @@
 
 :fire: Deep learning for Ruby, powered by [LibTorch](https://pytorch.org)
 
-
+This gem is currently experimental. There may be breaking changes between each release.
+
+[](https://travis-ci.org/ankane/torch-rb)
 
 ## Installation
 
 First, [install LibTorch](#libtorch-installation). For Homebrew, use:
 
 ```sh
-brew install
+brew install libtorch
 ```
 
 Add this line to your application’s Gemfile:
@@ -171,7 +173,7 @@ out.backward
 Get gradients
 
 ```ruby
-x.grad
+x.grad # tensor([[4.5, 4.5], [4.5, 4.5]])
 ```
 
 Stop autograd from tracking history
@@ -242,7 +244,7 @@ Here’s a list of functions to create tensors (descriptions from the [C++ docs]
 - `empty` returns a tensor with uninitialized values
 
 ```ruby
-Torch.empty(3)
+Torch.empty(3) # tensor([7.0054e-45, 0.0000e+00, 0.0000e+00])
 ```
 
 - `eye` returns an identity matrix
@@ -278,19 +280,19 @@ Here’s a list of functions to create tensors (descriptions from the [C++ docs]
 - `rand` returns a tensor filled with values drawn from a uniform distribution on [0, 1)
 
 ```ruby
-Torch.rand(3)
+Torch.rand(3) # tensor([0.5444, 0.8799, 0.5571])
 ```
 
 - `randint` returns a tensor with integers randomly drawn from an interval
 
 ```ruby
-Torch.randint(1, 10, [3])
+Torch.randint(1, 10, [3]) # tensor([7, 6, 4])
 ```
 
 - `randn` returns a tensor filled with values drawn from a unit normal distribution
 
 ```ruby
-Torch.randn(3)
+Torch.randn(3) # tensor([-0.7147, 0.6614, 1.1453])
 ```
 
 - `randperm` returns a tensor filled with a random permutation of integers in some interval
@@ -307,10 +309,10 @@ Here’s a list of functions to create tensors (descriptions from the [C++ docs]
 
 ## LibTorch Installation
 
-[Download LibTorch](https://pytorch.org/)
+[Download LibTorch](https://pytorch.org/). For Linux, use the `cxx11 ABI` version. Then run:
 
 ```sh
-
+bundle config build.torch-rb --with-torch-dir=/path/to/libtorch
 ```
 
 ### Homebrew
@@ -318,10 +320,10 @@ gem install torch-rb -- --with-torch-dir=/path/to/libtorch
 For Mac, you can use Homebrew.
 
 ```sh
-brew install
+brew install libtorch
 ```
 
-Then install the gem (no need for
+Then install the gem (no need for `bundle config`).
 
 ## rbenv
 
data/ext/torch/ext.cpp
CHANGED
@@ -441,19 +441,28 @@ void Init_ext()
         auto dtype = self.dtype();
 
         // TODO DRY if someone knows C++
-
-
-
+        if (dtype == torch::kByte) {
+          uint8_t* data = self.data_ptr<uint8_t>();
+          for (int i = 0; i < self.numel(); i++) {
+            a.push(data[i]);
+          }
+        } else if (dtype == torch::kChar) {
+          int8_t* data = self.data_ptr<int8_t>();
+          for (int i = 0; i < self.numel(); i++) {
+            a.push(to_ruby<int>(data[i]));
+          }
+        } else if (dtype == torch::kShort) {
+          int16_t* data = self.data_ptr<int16_t>();
           for (int i = 0; i < self.numel(); i++) {
             a.push(data[i]);
           }
         } else if (dtype == torch::kInt) {
-
+          int32_t* data = self.data_ptr<int32_t>();
           for (int i = 0; i < self.numel(); i++) {
             a.push(data[i]);
           }
         } else if (dtype == torch::kLong) {
-
+          int64_t* data = self.data_ptr<int64_t>();
           for (int i = 0; i < self.numel(); i++) {
             a.push(data[i]);
           }
@@ -467,8 +476,11 @@ void Init_ext()
           for (int i = 0; i < self.numel(); i++) {
             a.push(data[i]);
           }
+        } else if (dtype == torch::kBool) {
+          // bool
+          throw std::runtime_error("Type not supported yet");
         } else {
-          throw "Unsupported type";
+          throw std::runtime_error("Unsupported type");
         }
         return a;
       })
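The dtype chain in the two hunks above repeats the same copy loop once per element type, which the `// TODO DRY if someone knows C++` comment acknowledges. Below is a minimal sketch of how that dispatch could be collapsed with a function template; the helper name `push_all` and the header paths are assumptions for illustration, not code from the gem:

```cpp
#include <torch/torch.h>   // LibTorch
#include <rice/Array.hpp>  // Rice 2.x header path; an assumption

// Hypothetical helper (not part of torch-rb): copies a tensor's flat
// elements into a Rice Array with a single templated loop.
template <typename T>
void push_all(Rice::Array& a, const torch::Tensor& self) {
  T* data = self.data_ptr<T>();
  for (int64_t i = 0; i < self.numel(); i++) {
    a.push(data[i]);
  }
}

// The if/else chain would then only select the element type, e.g.:
//   if (dtype == torch::kByte)       push_all<uint8_t>(a, self);
//   else if (dtype == torch::kShort) push_all<int16_t>(a, self);
//   else if (dtype == torch::kInt)   push_all<int32_t>(a, self);
//   else if (dtype == torch::kLong)  push_all<int64_t>(a, self);
```

Whether this is a drop-in refactor depends on Rice's `to_ruby` conversions for each element type (the existing code already special-cases `int8_t` via `to_ruby<int>`), so treat it as a sketch rather than the gem's method.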
@@ -499,8 +511,11 @@ void Init_ext()
         torch::Layout l;
         if (layout == "strided") {
           l = torch::kStrided;
+        } else if (layout == "sparse") {
+          l = torch::kSparse;
+          throw std::runtime_error("Sparse layout not supported yet");
         } else {
-          throw "Unsupported layout";
+          throw std::runtime_error("Unsupported layout: " + layout);
         }
         return self.layout(l);
       })
@@ -513,7 +528,7 @@ void Init_ext()
         } else if (device == "cuda") {
           d = torch::kCUDA;
         } else {
-          throw "Unsupported device";
+          throw std::runtime_error("Unsupported device: " + device);
         }
         return self.device(d);
       })
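A recurring change in these hunks is replacing `throw "...";` with `throw std::runtime_error(...)`. A thrown string literal is a `const char*`, not a `std::exception`, so a generic `catch (const std::exception&)` handler in the binding layer never sees it and cannot report a message, while `std::runtime_error` is caught and carries one. A small standalone illustration of that difference (not code from the gem):

```cpp
#include <iostream>
#include <stdexcept>

int main() {
  try {
    throw "Unsupported type";  // thrown as const char*, not a std::exception
  } catch (const std::exception& e) {
    std::cout << "caught: " << e.what() << "\n";  // never reached
  } catch (...) {
    std::cout << "caught something that is not a std::exception\n";
  }

  try {
    throw std::runtime_error("Unsupported type");
  } catch (const std::exception& e) {
    std::cout << "caught: " << e.what() << "\n";  // prints the message
  }
  return 0;
}
```

Run on its own, the first `try` falls through to `catch (...)`, while the second reports the message via `what()`.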
data/ext/torch/extconf.rb
CHANGED
data/lib/torch/ext.bundle
CHANGED
Binary file
data/lib/torch/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
 --- !ruby/object:Gem::Specification
 name: torch-rb
 version: !ruby/object:Gem::Version
-  version: 0.1.0
+  version: 0.1.1
 platform: ruby
 authors:
 - Andrew Kane
 autorequire:
 bindir: bin
 cert_chain: []
-date: 2019-11-
+date: 2019-11-27 00:00:00.000000000 Z
 dependencies:
 - !ruby/object:Gem::Dependency
   name: rice