torch-rb 0.10.1 → 0.11.1

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these versions as they appear in their respective public registries.
data/ext/torch/extconf.rb CHANGED
@@ -7,21 +7,9 @@ $CXXFLAGS += " -D_GLIBCXX_USE_CXX11_ABI=1"
7
7
 
8
8
  apple_clang = RbConfig::CONFIG["CC_VERSION_MESSAGE"] =~ /apple clang/i
9
9
 
10
- # check omp first
11
- if have_library("omp") || have_library("gomp")
12
- $CXXFLAGS += " -Xclang" if apple_clang
13
- $CXXFLAGS += " -fopenmp"
14
- end
15
-
16
10
  if apple_clang
17
- # silence rice warnings
18
- $CXXFLAGS += " -Wno-deprecated-declarations"
19
-
20
- # silence ruby/intern.h warning
21
- $CXXFLAGS += " -Wno-deprecated-register"
22
-
23
11
  # silence torch warnings
24
- $CXXFLAGS += " -Wno-shorten-64-to-32 -Wno-missing-noreturn"
12
+ $CXXFLAGS += " -Wno-deprecated-declarations"
25
13
  else
26
14
  # silence rice warnings
27
15
  $CXXFLAGS += " -Wno-noexcept-type"
@@ -253,6 +253,68 @@ static inline std::vector<int64_t> parse_intlist_args(const std::string& s, int6
253
253
  return args;
254
254
  }
255
255
 
256
+ // Parse a string literal to remove quotes and escape sequences
257
+ static std::string parse_string_literal(c10::string_view str) {
258
+ TORCH_CHECK(str.length() >= 2, "String defaults must be quoted");
259
+
260
+ if (str.front() == '"') {
261
+ TORCH_CHECK(
262
+ str.back() == '"', "Mismatched quotes in string default: ", str);
263
+ } else {
264
+ TORCH_CHECK(
265
+ str.front() == '\'' && str.back() == '\'',
266
+ "Invalid quotes in string default: ",
267
+ str)
268
+ }
269
+
270
+ std::string parsed;
271
+ parsed.reserve(str.size());
272
+ for (size_t i = 1; i < str.size() - 1;) {
273
+ if (str[i] != '\\') {
274
+ parsed.push_back(str[i]);
275
+ ++i;
276
+ continue;
277
+ }
278
+
279
+ // Handle escape sequences
280
+ TORCH_CHECK(
281
+ i < str.size() - 2, "String ends with escaped final quote: ", str)
282
+ char c = str[i + 1];
283
+ switch (c) {
284
+ case '\\':
285
+ case '\'':
286
+ case '\"':
287
+ break;
288
+ case 'a':
289
+ c = '\a';
290
+ break;
291
+ case 'b':
292
+ c = '\b';
293
+ break;
294
+ case 'f':
295
+ c = '\f';
296
+ break;
297
+ case 'n':
298
+ c = '\n';
299
+ break;
300
+ case 'v':
301
+ c = '\v';
302
+ break;
303
+ case 't':
304
+ c = '\t';
305
+ break;
306
+ default:
307
+ TORCH_CHECK(
308
+ false,
309
+ "Unsupported escape sequence in string default: \\",
310
+ str[i + 1]);
311
+ }
312
+ parsed.push_back(c);
313
+ i += 2;
314
+ }
315
+ return parsed;
316
+ }
317
+
256
318
  void FunctionParameter::set_default_str(const std::string& str) {
257
319
  if (str == "None") {
258
320
  allow_none = true;
@@ -308,8 +370,8 @@ void FunctionParameter::set_default_str(const std::string& str) {
308
370
  throw std::runtime_error("invalid device: " + str);
309
371
  }
310
372
  } else if (type_ == ParameterType::STRING) {
311
- if (str != "None" && str != "") {
312
- throw std::runtime_error("invalid default string: " + str);
373
+ if (str != "None") {
374
+ default_string = parse_string_literal(str);
313
375
  }
314
376
  }
315
377
  }
@@ -35,6 +35,7 @@ struct FunctionParameter {
35
35
  at::SmallVector<VALUE, 5> numpy_python_names;
36
36
  at::Scalar default_scalar;
37
37
  std::vector<int64_t> default_intlist;
38
+ std::string default_string;
38
39
  union {
39
40
  bool default_bool;
40
41
  int64_t default_int;
@@ -83,8 +84,8 @@ struct RubyArgs {
83
84
  template<int N>
84
85
  inline std::array<at::Tensor, N> tensorlist_n(int i);
85
86
  inline std::vector<int64_t> intlist(int i);
86
- // inline c10::OptionalArray<int64_t> intlistOptional(int i);
87
- // inline std::vector<int64_t> intlistWithDefault(int i, std::vector<int64_t> default_intlist);
87
+ inline c10::OptionalArray<int64_t> intlistOptional(int i);
88
+ inline std::vector<int64_t> intlistWithDefault(int i, std::vector<int64_t> default_intlist);
88
89
  inline c10::optional<at::Generator> generator(int i);
89
90
  inline at::Storage storage(int i);
90
91
  inline at::ScalarType scalartype(int i);
@@ -108,6 +109,7 @@ struct RubyArgs {
108
109
  inline c10::optional<at::MemoryFormat> memoryformatOptional(int i);
109
110
  // inline at::QScheme toQScheme(int i);
110
111
  inline std::string string(int i);
112
+ inline std::string stringWithDefault(int i, const std::string& default_str);
111
113
  inline c10::optional<std::string> stringOptional(int i);
112
114
  inline c10::string_view stringView(int i);
113
115
  // inline c10::string_view stringViewWithDefault(int i, const c10::string_view default_str);
@@ -166,8 +168,11 @@ inline std::array<at::Tensor, N> RubyArgs::tensorlist_n(int i) {
166
168
  }
167
169
 
168
170
  inline std::vector<int64_t> RubyArgs::intlist(int i) {
169
- if (NIL_P(args[i])) return signature.params[i].default_intlist;
171
+ return intlistWithDefault(i, signature.params[i].default_intlist);
172
+ }
170
173
 
174
+ inline std::vector<int64_t> RubyArgs::intlistWithDefault(int i, std::vector<int64_t> default_intlist) {
175
+ if (NIL_P(args[i])) return default_intlist;
171
176
  VALUE arg = args[i];
172
177
  auto size = signature.params[i].size;
173
178
  if (size > 0 && FIXNUM_P(arg)) {
@@ -189,6 +194,11 @@ inline std::vector<int64_t> RubyArgs::intlist(int i) {
189
194
  return res;
190
195
  }
191
196
 
197
+ inline c10::OptionalArray<int64_t> RubyArgs::intlistOptional(int i) {
198
+ if (NIL_P(args[i])) return {};
199
+ return intlist(i);
200
+ }
201
+
192
202
  inline c10::optional<at::Generator> RubyArgs::generator(int i) {
193
203
  if (NIL_P(args[i])) return c10::nullopt;
194
204
  throw std::runtime_error("generator not supported yet");
@@ -337,6 +347,11 @@ inline c10::optional<at::MemoryFormat> RubyArgs::memoryformatOptional(int i) {
337
347
  }
338
348
 
339
349
  inline std::string RubyArgs::string(int i) {
350
+ return stringWithDefault(i, signature.params[i].default_string);
351
+ }
352
+
353
+ inline std::string RubyArgs::stringWithDefault(int i, const std::string& default_str) {
354
+ if (!args[i]) return default_str;
340
355
  return Rice::detail::From_Ruby<std::string>().convert(args[i]);
341
356
  }
342
357
 
data/ext/torch/utils.h CHANGED
@@ -6,7 +6,7 @@
6
6
  #include <rice/stl.hpp>
7
7
 
8
8
  static_assert(
9
- TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR == 11,
9
+ TORCH_VERSION_MAJOR == 1 && TORCH_VERSION_MINOR == 12,
10
10
  "Incompatible LibTorch version"
11
11
  );
12
12
 
data/lib/torch/tensor.rb CHANGED
@@ -177,11 +177,6 @@ module Torch
177
177
  _random!(*args)
178
178
  end
179
179
 
180
- # center option
181
- def stft(*args)
182
- Torch.stft(*args)
183
- end
184
-
185
180
  def dup
186
181
  Torch.no_grad do
187
182
  clone
@@ -199,5 +194,13 @@ module Torch
199
194
  def real
200
195
  Torch.real(self)
201
196
  end
197
+
198
+ def coerce(other)
199
+ if other.is_a?(Numeric)
200
+ [Torch.tensor(other), self]
201
+ else
202
+ raise TypeError, "#{self.class} can't be coerced into #{other.class}"
203
+ end
204
+ end
202
205
  end
203
206
  end
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.10.1"
2
+ VERSION = "0.11.1"
3
3
  end
data/lib/torch.rb CHANGED
@@ -249,6 +249,7 @@ module Torch
249
249
  complex_float: 9,
250
250
  complex64: 9,
251
251
  complex_double: 10,
252
+ cdouble: 10,
252
253
  complex128: 10,
253
254
  bool: 11,
254
255
  qint8: 12,
@@ -421,18 +422,6 @@ module Torch
421
422
  _tensor(data, size, tensor_options(**options))
422
423
  end
423
424
 
424
- # center option
425
- def stft(input, n_fft, hop_length: nil, win_length: nil, window: nil, center: true, pad_mode: "reflect", normalized: false, onesided: true, return_complex: nil)
426
- if center
427
- signal_dim = input.dim
428
- extended_shape = [1] * (3 - signal_dim) + input.size
429
- pad = n_fft.div(2).to_i
430
- input = NN::F.pad(input.view(extended_shape), [pad, pad], mode: pad_mode)
431
- input = input.view(input.shape[-signal_dim..-1])
432
- end
433
- _stft(input, n_fft, hop_length, win_length, window, normalized, onesided, return_complex)
434
- end
435
-
436
425
  private
437
426
 
438
427
  def to_ivalue(obj)
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.10.1
4
+ version: 0.11.1
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2022-04-12 00:00:00.000000000 Z
11
+ date: 2022-07-06 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice
@@ -222,7 +222,7 @@ required_ruby_version: !ruby/object:Gem::Requirement
222
222
  requirements:
223
223
  - - ">="
224
224
  - !ruby/object:Gem::Version
225
- version: '2.6'
225
+ version: '2.7'
226
226
  required_rubygems_version: !ruby/object:Gem::Requirement
227
227
  requirements:
228
228
  - - ">="