torch-rb 0.3.0 → 0.3.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 33636e58063f25c2b9f122d29332e4136bb6a4de0fd227349f75d65a9db94931
4
- data.tar.gz: 9349dd0b050a4c9e0714d92bb451bdd916e55fb47a5c4d90a74720d53564a1d6
3
+ metadata.gz: 06e94b492acbbdb71f9e6a11081fb043a03ae0d5c704cc79faa31dd96bde70ef
4
+ data.tar.gz: 4f38fa52d30ef9bf121204423b4d675f21dbef806b6f137152f2cf9399ddf4bb
5
5
  SHA512:
6
- metadata.gz: 692e8dc3531426377413fc9325c2b03dd7fcbbbce0c05cbd5d7c3182a08bbe733bb9a6b0aa62a56fb27917073e1f1f5859aa0dd3f6f40e74043bba242ede6267
7
- data.tar.gz: f57d0411c18c7c4753f5edc82e48b666ad3154c72bf189a8ec2c4dceb08cd1f37a9421b9a6eae3d20e10c4c9236a1136d76559a31989ce22868a9f44ef3e0e66
6
+ metadata.gz: 2fb2613ca629a70f55009b697b15830d59c0d8fc06c1c5102917b4870cb783427fb56ecc08889c09e15c342381385f258b2a33102dc5adddf2d463d41674994d
7
+ data.tar.gz: f26a6ba91caa57a92b8b047217a35c39d1e9c4c361df77e2182053b4ab490f20792fc88dba169dae87d4a3d4ee4d69e2c779efb1fa6150b4d3f0d93e3762aec9
@@ -1,3 +1,8 @@
1
+ ## 0.3.1 (2020-08-17)
2
+
3
+ - Added `create_graph` and `retain_graph` options to `backward` method
4
+ - Fixed error when `set` not required
5
+
1
6
  ## 0.3.0 (2020-07-29)
2
7
 
3
8
  - Updated LibTorch to 1.6.0
data/README.md CHANGED
@@ -411,7 +411,7 @@ Here’s the list of compatible versions.
411
411
 
412
412
  Torch.rb | LibTorch
413
413
  --- | ---
414
- 0.3.0 | 1.6.0
414
+ 0.3.0-0.3.1 | 1.6.0
415
415
  0.2.0-0.2.7 | 1.5.0-1.5.1
416
416
  0.1.8 | 1.4.0
417
417
  0.1.0-0.1.7 | 1.3.1
@@ -352,8 +352,8 @@ void Init_ext()
352
352
  })
353
353
  .define_method(
354
354
  "_backward",
355
- *[](Tensor& self, Object gradient) {
356
- return gradient.is_nil() ? self.backward() : self.backward(from_ruby<torch::Tensor>(gradient));
355
+ *[](Tensor& self, OptionalTensor gradient, bool create_graph, bool retain_graph) {
356
+ return self.backward(gradient, create_graph, retain_graph);
357
357
  })
358
358
  .define_method(
359
359
  "grad",
@@ -7,17 +7,16 @@ $CXXFLAGS += " -std=c++14"
7
7
  # change to 0 for Linux pre-cxx11 ABI version
8
8
  $CXXFLAGS += " -D_GLIBCXX_USE_CXX11_ABI=1"
9
9
 
10
- # TODO check compiler name
11
- clang = RbConfig::CONFIG["host_os"] =~ /darwin/i
10
+ apple_clang = RbConfig::CONFIG["CC_VERSION_MESSAGE"] =~ /apple clang/i
12
11
 
13
12
  # check omp first
14
13
  if have_library("omp") || have_library("gomp")
15
14
  $CXXFLAGS += " -DAT_PARALLEL_OPENMP=1"
16
- $CXXFLAGS += " -Xclang" if clang
15
+ $CXXFLAGS += " -Xclang" if apple_clang
17
16
  $CXXFLAGS += " -fopenmp"
18
17
  end
19
18
 
20
- if clang
19
+ if apple_clang
21
20
  # silence ruby/intern.h warning
22
21
  $CXXFLAGS += " -Wno-deprecated-register"
23
22
 
@@ -4,6 +4,7 @@ require "torch/ext"
4
4
  # stdlib
5
5
  require "fileutils"
6
6
  require "net/http"
7
+ require "set"
7
8
  require "tmpdir"
8
9
 
9
10
  # native functions
@@ -103,8 +103,9 @@ module Torch
103
103
  Torch.empty(0, dtype: dtype)
104
104
  end
105
105
 
106
- def backward(gradient = nil)
107
- _backward(gradient)
106
+ def backward(gradient = nil, retain_graph: nil, create_graph: false)
107
+ retain_graph = create_graph if retain_graph.nil?
108
+ _backward(gradient, retain_graph, create_graph)
108
109
  end
109
110
 
110
111
  # TODO read directly from memory
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.3.0"
2
+ VERSION = "0.3.1"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.3.0
4
+ version: 0.3.1
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2020-07-29 00:00:00.000000000 Z
11
+ date: 2020-08-17 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice