torch-rb 0.3.0 → 0.3.1

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 33636e58063f25c2b9f122d29332e4136bb6a4de0fd227349f75d65a9db94931
4
- data.tar.gz: 9349dd0b050a4c9e0714d92bb451bdd916e55fb47a5c4d90a74720d53564a1d6
3
+ metadata.gz: 06e94b492acbbdb71f9e6a11081fb043a03ae0d5c704cc79faa31dd96bde70ef
4
+ data.tar.gz: 4f38fa52d30ef9bf121204423b4d675f21dbef806b6f137152f2cf9399ddf4bb
5
5
  SHA512:
6
- metadata.gz: 692e8dc3531426377413fc9325c2b03dd7fcbbbce0c05cbd5d7c3182a08bbe733bb9a6b0aa62a56fb27917073e1f1f5859aa0dd3f6f40e74043bba242ede6267
7
- data.tar.gz: f57d0411c18c7c4753f5edc82e48b666ad3154c72bf189a8ec2c4dceb08cd1f37a9421b9a6eae3d20e10c4c9236a1136d76559a31989ce22868a9f44ef3e0e66
6
+ metadata.gz: 2fb2613ca629a70f55009b697b15830d59c0d8fc06c1c5102917b4870cb783427fb56ecc08889c09e15c342381385f258b2a33102dc5adddf2d463d41674994d
7
+ data.tar.gz: f26a6ba91caa57a92b8b047217a35c39d1e9c4c361df77e2182053b4ab490f20792fc88dba169dae87d4a3d4ee4d69e2c779efb1fa6150b4d3f0d93e3762aec9
data/CHANGELOG.md CHANGED
@@ -1,3 +1,8 @@
1
+ ## 0.3.1 (2020-08-17)
2
+
3
+ - Added `create_graph` and `retain_graph` options to `backward` method
4
+ - Fixed error when `set` not required
5
+
1
6
  ## 0.3.0 (2020-07-29)
2
7
 
3
8
  - Updated LibTorch to 1.6.0
data/README.md CHANGED
@@ -411,7 +411,7 @@ Here’s the list of compatible versions.
411
411
 
412
412
  Torch.rb | LibTorch
413
413
  --- | ---
414
- 0.3.0 | 1.6.0
414
+ 0.3.0-0.3.1 | 1.6.0
415
415
  0.2.0-0.2.7 | 1.5.0-1.5.1
416
416
  0.1.8 | 1.4.0
417
417
  0.1.0-0.1.7 | 1.3.1
data/ext/torch/ext.cpp CHANGED
@@ -352,8 +352,8 @@ void Init_ext()
352
352
  })
353
353
  .define_method(
354
354
  "_backward",
355
- *[](Tensor& self, Object gradient) {
356
- return gradient.is_nil() ? self.backward() : self.backward(from_ruby<torch::Tensor>(gradient));
355
+ *[](Tensor& self, OptionalTensor gradient, bool create_graph, bool retain_graph) {
356
+ return self.backward(gradient, create_graph, retain_graph);
357
357
  })
358
358
  .define_method(
359
359
  "grad",
data/ext/torch/extconf.rb CHANGED
@@ -7,17 +7,16 @@ $CXXFLAGS += " -std=c++14"
7
7
  # change to 0 for Linux pre-cxx11 ABI version
8
8
  $CXXFLAGS += " -D_GLIBCXX_USE_CXX11_ABI=1"
9
9
 
10
- # TODO check compiler name
11
- clang = RbConfig::CONFIG["host_os"] =~ /darwin/i
10
+ apple_clang = RbConfig::CONFIG["CC_VERSION_MESSAGE"] =~ /apple clang/i
12
11
 
13
12
  # check omp first
14
13
  if have_library("omp") || have_library("gomp")
15
14
  $CXXFLAGS += " -DAT_PARALLEL_OPENMP=1"
16
- $CXXFLAGS += " -Xclang" if clang
15
+ $CXXFLAGS += " -Xclang" if apple_clang
17
16
  $CXXFLAGS += " -fopenmp"
18
17
  end
19
18
 
20
- if clang
19
+ if apple_clang
21
20
  # silence ruby/intern.h warning
22
21
  $CXXFLAGS += " -Wno-deprecated-register"
23
22
 
data/lib/torch.rb CHANGED
@@ -4,6 +4,7 @@ require "torch/ext"
4
4
  # stdlib
5
5
  require "fileutils"
6
6
  require "net/http"
7
+ require "set"
7
8
  require "tmpdir"
8
9
 
9
10
  # native functions
data/lib/torch/tensor.rb CHANGED
@@ -103,8 +103,9 @@ module Torch
103
103
  Torch.empty(0, dtype: dtype)
104
104
  end
105
105
 
106
- def backward(gradient = nil)
107
- _backward(gradient)
106
+ def backward(gradient = nil, retain_graph: nil, create_graph: false)
107
+ retain_graph = create_graph if retain_graph.nil?
108
+ _backward(gradient, retain_graph, create_graph)
108
109
  end
109
110
 
110
111
  # TODO read directly from memory
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.3.0"
2
+ VERSION = "0.3.1"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.3.0
4
+ version: 0.3.1
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2020-07-29 00:00:00.000000000 Z
11
+ date: 2020-08-17 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice