torch-rb 0.1.0 → 0.1.1

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 38e16e7f07d004fd9625f168694356d551c79cbc62b0131fe1403e4c0995f296
4
- data.tar.gz: 66bf6ae0e4dd373a7542fbfb1cfb9dbd89fc455e16166e6a76d0945b32fecf38
3
+ metadata.gz: ff0920ba955063c03309fdb45ecf228b51c556508bea30b510d6bf652c1d0b18
4
+ data.tar.gz: 481dccf6a8e929230033f74c82bc9d292ef38ea219e2cb2cc61ca0b0c5457403
5
5
  SHA512:
6
- metadata.gz: d100e3a21ac877fe93ac61e9b5e0d8a5e61126684fc037dda3e9f703b040188b1e1523aa4111dff4aaf92ada1001597c5f60674b9583b14d31afd18dbf1ff18d
7
- data.tar.gz: c234dee79e26d3ee25ade2aaddd75f155dea6d59d8b9c5af2c571423a7aaa8a6489f5cfce89f09f390468a951b1644a4212c19525a79816be09214f0938860a8
6
+ metadata.gz: cd6c8fd9db4af15640217c09813c4f86d3f66360202d30711015c4f34552853ef281d5614fd78dc274d405da0d9f46f08a2359475ae1b0721143db49183faf5d
7
+ data.tar.gz: ee638c08458e0d2a8fac52e29c45d1347a74847ca7d8dab3a9a573afd887814d4c578c4e1a7fb80b204222785b27e21ca3c138f3face681e98e63c1bc02a9a7f
data/CHANGELOG.md CHANGED
@@ -1,3 +1,9 @@
1
+ ## 0.1.1 (2019-11-26)
2
+
3
+ - Added support for `uint8` and `int8` types
4
+ - Fixed `undefined symbol` error on Linux
5
+ - Fixed C++ error messages
6
+
1
7
  ## 0.1.0 (2019-11-26)
2
8
 
3
9
  - First release
data/README.md CHANGED
@@ -2,14 +2,16 @@
2
2
 
3
3
  :fire: Deep learning for Ruby, powered by [LibTorch](https://pytorch.org)
4
4
 
5
- **Note:** This gem is currently experimental. There may be breaking changes between each release.
5
+ This gem is currently experimental. There may be breaking changes between each release.
6
+
7
+ [![Build Status](https://travis-ci.org/ankane/torch-rb.svg?branch=master)](https://travis-ci.org/ankane/torch-rb)
6
8
 
7
9
  ## Installation
8
10
 
9
11
  First, [install LibTorch](#libtorch-installation). For Homebrew, use:
10
12
 
11
13
  ```sh
12
- brew install ankane/brew/libtorch
14
+ brew install libtorch
13
15
  ```
14
16
 
15
17
  Add this line to your application’s Gemfile:
@@ -171,7 +173,7 @@ out.backward
171
173
  Get gradients
172
174
 
173
175
  ```ruby
174
- x.grad
176
+ x.grad # tensor([[4.5, 4.5], [4.5, 4.5]])
175
177
  ```
176
178
 
177
179
  Stop autograd from tracking history
@@ -242,7 +244,7 @@ Here’s a list of functions to create tensors (descriptions from the [C++ docs]
242
244
  - `empty` returns a tensor with uninitialized values
243
245
 
244
246
  ```ruby
245
- Torch.empty(3)
247
+ Torch.empty(3) # tensor([7.0054e-45, 0.0000e+00, 0.0000e+00])
246
248
  ```
247
249
 
248
250
  - `eye` returns an identity matrix
@@ -278,19 +280,19 @@ Here’s a list of functions to create tensors (descriptions from the [C++ docs]
278
280
  - `rand` returns a tensor filled with values drawn from a uniform distribution on [0, 1)
279
281
 
280
282
  ```ruby
281
- Torch.rand(3)
283
+ Torch.rand(3) # tensor([0.5444, 0.8799, 0.5571])
282
284
  ```
283
285
 
284
286
  - `randint` returns a tensor with integers randomly drawn from an interval
285
287
 
286
288
  ```ruby
287
- Torch.randint(1, 10, [3])
289
+ Torch.randint(1, 10, [3]) # tensor([7, 6, 4])
288
290
  ```
289
291
 
290
292
  - `randn` returns a tensor filled with values drawn from a unit normal distribution
291
293
 
292
294
  ```ruby
293
- Torch.randn(3)
295
+ Torch.randn(3) # tensor([-0.7147, 0.6614, 1.1453])
294
296
  ```
295
297
 
296
298
  - `randperm` returns a tensor filled with a random permutation of integers in some interval
@@ -307,10 +309,10 @@ Here’s a list of functions to create tensors (descriptions from the [C++ docs]
307
309
 
308
310
  ## LibTorch Installation
309
311
 
310
- [Download LibTorch](https://pytorch.org/) and run:
312
+ [Download LibTorch](https://pytorch.org/). For Linux, use the `cxx11 ABI` version. Then run:
311
313
 
312
314
  ```sh
313
- gem install torch-rb -- --with-torch-dir=/path/to/libtorch
315
+ bundle config build.torch-rb --with-torch-dir=/path/to/libtorch
314
316
  ```
315
317
 
316
318
  ### Homebrew
@@ -318,10 +320,10 @@ gem install torch-rb -- --with-torch-dir=/path/to/libtorch
318
320
  For Mac, you can use Homebrew.
319
321
 
320
322
  ```sh
321
- brew install ankane/brew/libtorch
323
+ brew install libtorch
322
324
  ```
323
325
 
324
- Then install the gem (no need for `--with-torch-dir`).
326
+ Then install the gem (no need for `bundle config`).
325
327
 
326
328
  ## rbenv
327
329
 
data/ext/torch/ext.cpp CHANGED
@@ -441,19 +441,28 @@ void Init_ext()
441
441
  auto dtype = self.dtype();
442
442
 
443
443
  // TODO DRY if someone knows C++
444
- // TODO kByte (uint8), kChar (int8), kBool (bool)
445
- if (dtype == torch::kShort) {
446
- short* data = self.data_ptr<short>();
444
+ if (dtype == torch::kByte) {
445
+ uint8_t* data = self.data_ptr<uint8_t>();
446
+ for (int i = 0; i < self.numel(); i++) {
447
+ a.push(data[i]);
448
+ }
449
+ } else if (dtype == torch::kChar) {
450
+ int8_t* data = self.data_ptr<int8_t>();
451
+ for (int i = 0; i < self.numel(); i++) {
452
+ a.push(to_ruby<int>(data[i]));
453
+ }
454
+ } else if (dtype == torch::kShort) {
455
+ int16_t* data = self.data_ptr<int16_t>();
447
456
  for (int i = 0; i < self.numel(); i++) {
448
457
  a.push(data[i]);
449
458
  }
450
459
  } else if (dtype == torch::kInt) {
451
- int* data = self.data_ptr<int>();
460
+ int32_t* data = self.data_ptr<int32_t>();
452
461
  for (int i = 0; i < self.numel(); i++) {
453
462
  a.push(data[i]);
454
463
  }
455
464
  } else if (dtype == torch::kLong) {
456
- long long* data = self.data_ptr<long long>();
465
+ int64_t* data = self.data_ptr<int64_t>();
457
466
  for (int i = 0; i < self.numel(); i++) {
458
467
  a.push(data[i]);
459
468
  }
@@ -467,8 +476,11 @@ void Init_ext()
467
476
  for (int i = 0; i < self.numel(); i++) {
468
477
  a.push(data[i]);
469
478
  }
479
+ } else if (dtype == torch::kBool) {
480
+ // bool
481
+ throw std::runtime_error("Type not supported yet");
470
482
  } else {
471
- throw "Unsupported type";
483
+ throw std::runtime_error("Unsupported type");
472
484
  }
473
485
  return a;
474
486
  })
@@ -499,8 +511,11 @@ void Init_ext()
499
511
  torch::Layout l;
500
512
  if (layout == "strided") {
501
513
  l = torch::kStrided;
514
+ } else if (layout == "sparse") {
515
+ l = torch::kSparse;
516
+ throw std::runtime_error("Sparse layout not supported yet");
502
517
  } else {
503
- throw "Unsupported layout";
518
+ throw std::runtime_error("Unsupported layout: " + layout);
504
519
  }
505
520
  return self.layout(l);
506
521
  })
@@ -513,7 +528,7 @@ void Init_ext()
513
528
  } else if (device == "cuda") {
514
529
  d = torch::kCUDA;
515
530
  } else {
516
- throw "Unsupported device";
531
+ throw std::runtime_error("Unsupported device: " + device);
517
532
  }
518
533
  return self.device(d);
519
534
  })
data/ext/torch/extconf.rb CHANGED
@@ -4,6 +4,9 @@ abort "Missing stdc++" unless have_library("stdc++")
4
4
 
5
5
  $CXXFLAGS << " -std=c++11"
6
6
 
7
+ # needed for Linux pre-cxx11 ABI version
8
+ # $CXXFLAGS << " -D_GLIBCXX_USE_CXX11_ABI=0"
9
+
7
10
  # silence ruby/intern.h warning
8
11
  $CXXFLAGS << " -Wno-deprecated-register"
9
12
 
data/lib/torch/ext.bundle CHANGED
Binary file
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.1.0"
2
+ VERSION = "0.1.1"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.1.0
4
+ version: 0.1.1
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2019-11-26 00:00:00.000000000 Z
11
+ date: 2019-11-27 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice