torch-rb 0.11.0 → 0.11.1

This diff shows the changes between two publicly available versions of a package, as released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in the respective public registry.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: d2f88c938144476a772fd7c606751e0a6e67a338cb1c37ede9c2011db7fc4579
4
- data.tar.gz: f4408457e3bf8c7bf9b42863459248aa592c70e8e3924cdbe72863a979e65106
3
+ metadata.gz: c4c3936a55fd47a8898d3e678aabc81c158f2f39a89c08af0b36695700bf2043
4
+ data.tar.gz: a68179c7d7bab7547ac3be1d2369abbd5c3c632954996ca2f7fdd089f815edf7
5
5
  SHA512:
6
- metadata.gz: 90ebc506942809b02331f7accce6e93e714a57b5d2f58b06ad0d3c60204b947eba61fe0bf43431a19e27890f5c764cb348ccb1785a2ffccfba7969f4726b6f1f
7
- data.tar.gz: 410af7b4934f79aaae94f6c1bde1d17ee8b043f02354b4ea9f139271575f43abb90afeba162badfa01a3ec406612ae6cdcb4a2adfb1067b1634982a590d10cc9
6
+ metadata.gz: 97551aa27f154cade530e58b4ae92286cd18e432acc99646495a197c89fb719cc991b0399aacaf383a2de04e340c04a9e177c44a0169fce91e19435463df4753
7
+ data.tar.gz: 8a5a5decf900a4d93aa56eb5e5b3848d5386be8bbc3657c10c201ef321cf0082c6521700378a5e96f795c09882b5b72498d8fed3931153b6e1f7557630f55547
data/CHANGELOG.md CHANGED
@@ -1,3 +1,7 @@
1
+ ## 0.11.1 (2022-07-06)
2
+
3
+ - Fixed error with `stft` method
4
+
1
5
  ## 0.11.0 (2022-07-06)
2
6
 
3
7
  - Updated LibTorch to 1.12.0
@@ -128,8 +128,11 @@ def write_body(type, method_defs, attach_defs)
128
128
  end
129
129
 
130
130
  def write_file(name, contents)
131
- path = File.expand_path("../ext/torch", __dir__)
132
- File.write(File.join(path, name), contents)
131
+ path = File.join(File.expand_path("../ext/torch", __dir__), name)
132
+ # only write if changed to improve compile times in development
133
+ if !File.exist?(path) || File.read(path) != contents
134
+ File.write(path, contents)
135
+ end
133
136
  end
134
137
 
135
138
  def generate_attach_def(name, type, def_method)
@@ -142,14 +145,14 @@ def generate_attach_def(name, type, def_method)
142
145
  name
143
146
  end
144
147
 
145
- ruby_name = "_#{ruby_name}" if ["size", "stride", "random!", "stft"].include?(ruby_name)
148
+ ruby_name = "_#{ruby_name}" if ["size", "stride", "random!"].include?(ruby_name)
146
149
  ruby_name = ruby_name.sub(/\Afft_/, "") if type == "fft"
147
150
  ruby_name = ruby_name.sub(/\Alinalg_/, "") if type == "linalg"
148
151
  ruby_name = ruby_name.sub(/\Aspecial_/, "") if type == "special"
149
152
  ruby_name = ruby_name.sub(/\Asparse_/, "") if type == "sparse"
150
153
  ruby_name = name if name.start_with?("__")
151
154
 
152
- # cast for Ruby < 2.7 https://github.com/thisMagpie/fftw/issues/22#issuecomment-49508900
155
+ # cast for Ruby < 3.0 https://github.com/thisMagpie/fftw/issues/22#issuecomment-49508900
153
156
  cast = RUBY_VERSION.to_f > 2.7 ? "" : "(VALUE (*)(...)) "
154
157
 
155
158
  "rb_#{def_method}(m, \"#{ruby_name}\", #{cast}#{full_name(name, type)}, -1);"
@@ -253,6 +253,68 @@ static inline std::vector<int64_t> parse_intlist_args(const std::string& s, int6
253
253
  return args;
254
254
  }
255
255
 
256
+ // Parse a string literal to remove quotes and escape sequences
257
+ static std::string parse_string_literal(c10::string_view str) {
258
+ TORCH_CHECK(str.length() >= 2, "String defaults must be quoted");
259
+
260
+ if (str.front() == '"') {
261
+ TORCH_CHECK(
262
+ str.back() == '"', "Mismatched quotes in string default: ", str);
263
+ } else {
264
+ TORCH_CHECK(
265
+ str.front() == '\'' && str.back() == '\'',
266
+ "Invalid quotes in string default: ",
267
+ str)
268
+ }
269
+
270
+ std::string parsed;
271
+ parsed.reserve(str.size());
272
+ for (size_t i = 1; i < str.size() - 1;) {
273
+ if (str[i] != '\\') {
274
+ parsed.push_back(str[i]);
275
+ ++i;
276
+ continue;
277
+ }
278
+
279
+ // Handle escape sequences
280
+ TORCH_CHECK(
281
+ i < str.size() - 2, "String ends with escaped final quote: ", str)
282
+ char c = str[i + 1];
283
+ switch (c) {
284
+ case '\\':
285
+ case '\'':
286
+ case '\"':
287
+ break;
288
+ case 'a':
289
+ c = '\a';
290
+ break;
291
+ case 'b':
292
+ c = '\b';
293
+ break;
294
+ case 'f':
295
+ c = '\f';
296
+ break;
297
+ case 'n':
298
+ c = '\n';
299
+ break;
300
+ case 'v':
301
+ c = '\v';
302
+ break;
303
+ case 't':
304
+ c = '\t';
305
+ break;
306
+ default:
307
+ TORCH_CHECK(
308
+ false,
309
+ "Unsupported escape sequence in string default: \\",
310
+ str[i + 1]);
311
+ }
312
+ parsed.push_back(c);
313
+ i += 2;
314
+ }
315
+ return parsed;
316
+ }
317
+
256
318
  void FunctionParameter::set_default_str(const std::string& str) {
257
319
  if (str == "None") {
258
320
  allow_none = true;
@@ -308,8 +370,8 @@ void FunctionParameter::set_default_str(const std::string& str) {
308
370
  throw std::runtime_error("invalid device: " + str);
309
371
  }
310
372
  } else if (type_ == ParameterType::STRING) {
311
- if (str != "None" && str != "") {
312
- throw std::runtime_error("invalid default string: " + str);
373
+ if (str != "None") {
374
+ default_string = parse_string_literal(str);
313
375
  }
314
376
  }
315
377
  }
@@ -35,6 +35,7 @@ struct FunctionParameter {
35
35
  at::SmallVector<VALUE, 5> numpy_python_names;
36
36
  at::Scalar default_scalar;
37
37
  std::vector<int64_t> default_intlist;
38
+ std::string default_string;
38
39
  union {
39
40
  bool default_bool;
40
41
  int64_t default_int;
@@ -108,6 +109,7 @@ struct RubyArgs {
108
109
  inline c10::optional<at::MemoryFormat> memoryformatOptional(int i);
109
110
  // inline at::QScheme toQScheme(int i);
110
111
  inline std::string string(int i);
112
+ inline std::string stringWithDefault(int i, const std::string& default_str);
111
113
  inline c10::optional<std::string> stringOptional(int i);
112
114
  inline c10::string_view stringView(int i);
113
115
  // inline c10::string_view stringViewWithDefault(int i, const c10::string_view default_str);
@@ -345,6 +347,11 @@ inline c10::optional<at::MemoryFormat> RubyArgs::memoryformatOptional(int i) {
345
347
  }
346
348
 
347
349
  inline std::string RubyArgs::string(int i) {
350
+ return stringWithDefault(i, signature.params[i].default_string);
351
+ }
352
+
353
+ inline std::string RubyArgs::stringWithDefault(int i, const std::string& default_str) {
354
+ if (!args[i]) return default_str;
348
355
  return Rice::detail::From_Ruby<std::string>().convert(args[i]);
349
356
  }
350
357
 
data/lib/torch/tensor.rb CHANGED
@@ -177,11 +177,6 @@ module Torch
177
177
  _random!(*args)
178
178
  end
179
179
 
180
- # center option
181
- def stft(*args)
182
- Torch.stft(*args)
183
- end
184
-
185
180
  def dup
186
181
  Torch.no_grad do
187
182
  clone
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.11.0"
2
+ VERSION = "0.11.1"
3
3
  end
data/lib/torch.rb CHANGED
@@ -422,18 +422,6 @@ module Torch
422
422
  _tensor(data, size, tensor_options(**options))
423
423
  end
424
424
 
425
- # center option
426
- def stft(input, n_fft, hop_length: nil, win_length: nil, window: nil, center: true, pad_mode: "reflect", normalized: false, onesided: true, return_complex: nil)
427
- if center
428
- signal_dim = input.dim
429
- extended_shape = [1] * (3 - signal_dim) + input.size
430
- pad = n_fft.div(2).to_i
431
- input = NN::F.pad(input.view(extended_shape), [pad, pad], mode: pad_mode)
432
- input = input.view(input.shape[-signal_dim..-1])
433
- end
434
- _stft(input, n_fft, hop_length, win_length, window, normalized, onesided, return_complex)
435
- end
436
-
437
425
  private
438
426
 
439
427
  def to_ivalue(obj)
metadata CHANGED
@@ -1,7 +1,7 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.11.0
4
+ version: 0.11.1
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane