torch-rb 0.10.2 → 0.11.0

This diff shows the changes between two publicly available versions of this package, as released to one of the supported registries. It is provided for informational purposes only and reflects the package contents exactly as they appear in their respective public registries.
data/ext/torch/extconf.rb CHANGED
@@ -7,21 +7,9 @@ $CXXFLAGS += " -D_GLIBCXX_USE_CXX11_ABI=1"
7
7
 
8
8
  apple_clang = RbConfig::CONFIG["CC_VERSION_MESSAGE"] =~ /apple clang/i
9
9
 
10
- # check omp first
11
- if have_library("omp") || have_library("gomp")
12
- $CXXFLAGS += " -Xclang" if apple_clang
13
- $CXXFLAGS += " -fopenmp"
14
- end
15
-
16
10
  if apple_clang
17
- # silence rice warnings
18
- $CXXFLAGS += " -Wno-deprecated-declarations"
19
-
20
- # silence ruby/intern.h warning
21
- $CXXFLAGS += " -Wno-deprecated-register"
22
-
23
11
  # silence torch warnings
24
- $CXXFLAGS += " -Wno-shorten-64-to-32 -Wno-missing-noreturn"
12
+ $CXXFLAGS += " -Wno-deprecated-declarations"
25
13
  else
26
14
  # silence rice warnings
27
15
  $CXXFLAGS += " -Wno-noexcept-type"
@@ -83,8 +83,8 @@ struct RubyArgs {
83
83
  template<int N>
84
84
  inline std::array<at::Tensor, N> tensorlist_n(int i);
85
85
  inline std::vector<int64_t> intlist(int i);
86
- // inline c10::OptionalArray<int64_t> intlistOptional(int i);
87
- // inline std::vector<int64_t> intlistWithDefault(int i, std::vector<int64_t> default_intlist);
86
+ inline c10::OptionalArray<int64_t> intlistOptional(int i);
87
+ inline std::vector<int64_t> intlistWithDefault(int i, std::vector<int64_t> default_intlist);
88
88
  inline c10::optional<at::Generator> generator(int i);
89
89
  inline at::Storage storage(int i);
90
90
  inline at::ScalarType scalartype(int i);
@@ -166,8 +166,11 @@ inline std::array<at::Tensor, N> RubyArgs::tensorlist_n(int i) {
166
166
  }
167
167
 
168
168
  inline std::vector<int64_t> RubyArgs::intlist(int i) {
169
- if (NIL_P(args[i])) return signature.params[i].default_intlist;
169
+ return intlistWithDefault(i, signature.params[i].default_intlist);
170
+ }
170
171
 
172
+ inline std::vector<int64_t> RubyArgs::intlistWithDefault(int i, std::vector<int64_t> default_intlist) {
173
+ if (NIL_P(args[i])) return default_intlist;
171
174
  VALUE arg = args[i];
172
175
  auto size = signature.params[i].size;
173
176
  if (size > 0 && FIXNUM_P(arg)) {
@@ -189,6 +192,11 @@ inline std::vector<int64_t> RubyArgs::intlist(int i) {
189
192
  return res;
190
193
  }
191
194
 
195
+ inline c10::OptionalArray<int64_t> RubyArgs::intlistOptional(int i) {
196
+ if (NIL_P(args[i])) return {};
197
+ return intlist(i);
198
+ }
199
+
192
200
  inline c10::optional<at::Generator> RubyArgs::generator(int i) {
193
201
  if (NIL_P(args[i])) return c10::nullopt;
194
202
  throw std::runtime_error("generator not supported yet");
data/ext/torch/utils.h CHANGED
@@ -6,7 +6,7 @@
6
6
  #include <rice/stl.hpp>
7
7
 
8
8
  static_assert(
9
- TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR == 11,
9
+ TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR == 12,
10
10
  "Incompatible LibTorch version"
11
11
  );
12
12
 
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.10.2"
2
+ VERSION = "0.11.0"
3
3
  end
data/lib/torch.rb CHANGED
@@ -249,6 +249,7 @@ module Torch
249
249
  complex_float: 9,
250
250
  complex64: 9,
251
251
  complex_double: 10,
252
+ cdouble: 10,
252
253
  complex128: 10,
253
254
  bool: 11,
254
255
  qint8: 12,
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.10.2
4
+ version: 0.11.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2022-06-14 00:00:00.000000000 Z
11
+ date: 2022-07-06 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice
@@ -222,7 +222,7 @@ required_ruby_version: !ruby/object:Gem::Requirement
222
222
  requirements:
223
223
  - - ">="
224
224
  - !ruby/object:Gem::Version
225
- version: '2.6'
225
+ version: '2.7'
226
226
  required_rubygems_version: !ruby/object:Gem::Requirement
227
227
  requirements:
228
228
  - - ">="