torch-rb 0.10.0 → 0.11.0

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
data/ext/torch/extconf.rb CHANGED
@@ -7,21 +7,9 @@ $CXXFLAGS += " -D_GLIBCXX_USE_CXX11_ABI=1"
7
7
 
8
8
  apple_clang = RbConfig::CONFIG["CC_VERSION_MESSAGE"] =~ /apple clang/i
9
9
 
10
- # check omp first
11
- if have_library("omp") || have_library("gomp")
12
- $CXXFLAGS += " -Xclang" if apple_clang
13
- $CXXFLAGS += " -fopenmp"
14
- end
15
-
16
10
  if apple_clang
17
- # silence rice warnings
18
- $CXXFLAGS += " -Wno-deprecated-declarations"
19
-
20
- # silence ruby/intern.h warning
21
- $CXXFLAGS += " -Wno-deprecated-register"
22
-
23
11
  # silence torch warnings
24
- $CXXFLAGS += " -Wno-shorten-64-to-32 -Wno-missing-noreturn"
12
+ $CXXFLAGS += " -Wno-deprecated-declarations"
25
13
  else
26
14
  # silence rice warnings
27
15
  $CXXFLAGS += " -Wno-noexcept-type"
@@ -83,23 +83,23 @@ struct RubyArgs {
83
83
  template<int N>
84
84
  inline std::array<at::Tensor, N> tensorlist_n(int i);
85
85
  inline std::vector<int64_t> intlist(int i);
86
- // inline c10::OptionalArray<int64_t> intlistOptional(int i);
87
- // inline std::vector<int64_t> intlistWithDefault(int i, std::vector<int64_t> default_intlist);
86
+ inline c10::OptionalArray<int64_t> intlistOptional(int i);
87
+ inline std::vector<int64_t> intlistWithDefault(int i, std::vector<int64_t> default_intlist);
88
88
  inline c10::optional<at::Generator> generator(int i);
89
89
  inline at::Storage storage(int i);
90
90
  inline at::ScalarType scalartype(int i);
91
- // inline at::ScalarType scalartypeWithDefault(int i, at::ScalarType default_scalartype);
91
+ inline at::ScalarType scalartypeWithDefault(int i, at::ScalarType default_scalartype);
92
92
  inline c10::optional<at::ScalarType> scalartypeOptional(int i);
93
93
  inline c10::optional<at::Scalar> scalarOptional(int i);
94
94
  inline c10::optional<int64_t> toInt64Optional(int i);
95
95
  inline c10::optional<bool> toBoolOptional(int i);
96
96
  inline c10::optional<double> toDoubleOptional(int i);
97
97
  inline c10::OptionalArray<double> doublelistOptional(int i);
98
- // inline at::Layout layout(int i);
99
- // inline at::Layout layoutWithDefault(int i, at::Layout default_layout);
98
+ inline at::Layout layout(int i);
99
+ inline at::Layout layoutWithDefault(int i, at::Layout default_layout);
100
100
  inline c10::optional<at::Layout> layoutOptional(int i);
101
101
  inline at::Device device(int i);
102
- // inline at::Device deviceWithDefault(int i, const at::Device& default_device);
102
+ inline at::Device deviceWithDefault(int i, const at::Device& default_device);
103
103
  // inline c10::optional<at::Device> deviceOptional(int i);
104
104
  // inline at::Dimname dimname(int i);
105
105
  // inline std::vector<at::Dimname> dimnamelist(int i);
@@ -166,8 +166,11 @@ inline std::array<at::Tensor, N> RubyArgs::tensorlist_n(int i) {
166
166
  }
167
167
 
168
168
  inline std::vector<int64_t> RubyArgs::intlist(int i) {
169
- if (NIL_P(args[i])) return signature.params[i].default_intlist;
169
+ return intlistWithDefault(i, signature.params[i].default_intlist);
170
+ }
170
171
 
172
+ inline std::vector<int64_t> RubyArgs::intlistWithDefault(int i, std::vector<int64_t> default_intlist) {
173
+ if (NIL_P(args[i])) return default_intlist;
171
174
  VALUE arg = args[i];
172
175
  auto size = signature.params[i].size;
173
176
  if (size > 0 && FIXNUM_P(arg)) {
@@ -189,6 +192,11 @@ inline std::vector<int64_t> RubyArgs::intlist(int i) {
189
192
  return res;
190
193
  }
191
194
 
195
+ inline c10::OptionalArray<int64_t> RubyArgs::intlistOptional(int i) {
196
+ if (NIL_P(args[i])) return {};
197
+ return intlist(i);
198
+ }
199
+
192
200
  inline c10::optional<at::Generator> RubyArgs::generator(int i) {
193
201
  if (NIL_P(args[i])) return c10::nullopt;
194
202
  throw std::runtime_error("generator not supported yet");
@@ -240,6 +248,11 @@ inline ScalarType RubyArgs::scalartype(int i) {
240
248
  return it->second;
241
249
  }
242
250
 
251
+ inline at::ScalarType RubyArgs::scalartypeWithDefault(int i, at::ScalarType default_scalartype) {
252
+ if (NIL_P(args[i])) return default_scalartype;
253
+ return scalartype(i);
254
+ }
255
+
243
256
  inline c10::optional<ScalarType> RubyArgs::scalartypeOptional(int i) {
244
257
  if (NIL_P(args[i])) return c10::nullopt;
245
258
  return scalartype(i);
@@ -284,8 +297,8 @@ inline c10::OptionalArray<double> RubyArgs::doublelistOptional(int i) {
284
297
  return res;
285
298
  }
286
299
 
287
- inline c10::optional<at::Layout> RubyArgs::layoutOptional(int i) {
288
- if (NIL_P(args[i])) return c10::nullopt;
300
+ inline at::Layout RubyArgs::layout(int i) {
301
+ if (NIL_P(args[i])) return signature.params[i].default_layout;
289
302
 
290
303
  static std::unordered_map<VALUE, Layout> layout_map = {
291
304
  {ID2SYM(rb_intern("strided")), Layout::Strided},
@@ -298,6 +311,16 @@ inline c10::optional<at::Layout> RubyArgs::layoutOptional(int i) {
298
311
  return it->second;
299
312
  }
300
313
 
314
+ inline at::Layout RubyArgs::layoutWithDefault(int i, at::Layout default_layout) {
315
+ if (NIL_P(args[i])) return default_layout;
316
+ return layout(i);
317
+ }
318
+
319
+ inline c10::optional<at::Layout> RubyArgs::layoutOptional(int i) {
320
+ if (NIL_P(args[i])) return c10::nullopt;
321
+ return layout(i);
322
+ }
323
+
301
324
  inline at::Device RubyArgs::device(int i) {
302
325
  if (NIL_P(args[i])) {
303
326
  return at::Device("cpu");
@@ -306,6 +329,11 @@ inline at::Device RubyArgs::device(int i) {
306
329
  return at::Device(device_str);
307
330
  }
308
331
 
332
+ inline at::Device RubyArgs::deviceWithDefault(int i, const at::Device& default_device) {
333
+ if (NIL_P(args[i])) return default_device;
334
+ return device(i);
335
+ }
336
+
309
337
  inline at::MemoryFormat RubyArgs::memoryformat(int i) {
310
338
  if (NIL_P(args[i])) return at::MemoryFormat::Contiguous;
311
339
  throw std::runtime_error("memoryformat not supported yet");
data/ext/torch/utils.h CHANGED
@@ -1,8 +1,15 @@
1
1
  #pragma once
2
2
 
3
+ #include <torch/torch.h>
4
+
3
5
  #include <rice/rice.hpp>
4
6
  #include <rice/stl.hpp>
5
7
 
8
+ static_assert(
9
+ TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR == 12,
10
+ "Incompatible LibTorch version"
11
+ );
12
+
6
13
  // TODO find better place
7
14
  inline void handle_error(torch::Error const & ex) {
8
15
  throw Rice::Exception(rb_eRuntimeError, ex.what_without_backtrace());
data/lib/torch/tensor.rb CHANGED
@@ -199,5 +199,13 @@ module Torch
199
199
  def real
200
200
  Torch.real(self)
201
201
  end
202
+
203
+ def coerce(other)
204
+ if other.is_a?(Numeric)
205
+ [Torch.tensor(other), self]
206
+ else
207
+ raise TypeError, "#{self.class} can't be coerced into #{other.class}"
208
+ end
209
+ end
202
210
  end
203
211
  end
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.10.0"
2
+ VERSION = "0.11.0"
3
3
  end
data/lib/torch.rb CHANGED
@@ -249,6 +249,7 @@ module Torch
249
249
  complex_float: 9,
250
250
  complex64: 9,
251
251
  complex_double: 10,
252
+ cdouble: 10,
252
253
  complex128: 10,
253
254
  bool: 11,
254
255
  qint8: 12,
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.10.0
4
+ version: 0.11.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2022-03-13 00:00:00.000000000 Z
11
+ date: 2022-07-06 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice
@@ -222,7 +222,7 @@ required_ruby_version: !ruby/object:Gem::Requirement
222
222
  requirements:
223
223
  - - ">="
224
224
  - !ruby/object:Gem::Version
225
- version: '2.6'
225
+ version: '2.7'
226
226
  required_rubygems_version: !ruby/object:Gem::Requirement
227
227
  requirements:
228
228
  - - ">="