torch-rb 0.11.0 → 0.11.1

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: d2f88c938144476a772fd7c606751e0a6e67a338cb1c37ede9c2011db7fc4579
4
- data.tar.gz: f4408457e3bf8c7bf9b42863459248aa592c70e8e3924cdbe72863a979e65106
3
+ metadata.gz: c4c3936a55fd47a8898d3e678aabc81c158f2f39a89c08af0b36695700bf2043
4
+ data.tar.gz: a68179c7d7bab7547ac3be1d2369abbd5c3c632954996ca2f7fdd089f815edf7
5
5
  SHA512:
6
- metadata.gz: 90ebc506942809b02331f7accce6e93e714a57b5d2f58b06ad0d3c60204b947eba61fe0bf43431a19e27890f5c764cb348ccb1785a2ffccfba7969f4726b6f1f
7
- data.tar.gz: 410af7b4934f79aaae94f6c1bde1d17ee8b043f02354b4ea9f139271575f43abb90afeba162badfa01a3ec406612ae6cdcb4a2adfb1067b1634982a590d10cc9
6
+ metadata.gz: 97551aa27f154cade530e58b4ae92286cd18e432acc99646495a197c89fb719cc991b0399aacaf383a2de04e340c04a9e177c44a0169fce91e19435463df4753
7
+ data.tar.gz: 8a5a5decf900a4d93aa56eb5e5b3848d5386be8bbc3657c10c201ef321cf0082c6521700378a5e96f795c09882b5b72498d8fed3931153b6e1f7557630f55547
data/CHANGELOG.md CHANGED
@@ -1,3 +1,7 @@
1
+ ## 0.11.1 (2022-07-06)
2
+
3
+ - Fixed error with `stft` method
4
+
1
5
  ## 0.11.0 (2022-07-06)
2
6
 
3
7
  - Updated LibTorch to 1.12.0
@@ -128,8 +128,11 @@ def write_body(type, method_defs, attach_defs)
128
128
  end
129
129
 
130
130
# Writes generated source `contents` to ext/torch/<name>.
# An existing file whose contents already match is left untouched, so that
# unchanged files do not hurt compile times during development.
def write_file(name, contents)
  target = File.join(File.expand_path("../ext/torch", __dir__), name)
  return if File.exist?(target) && File.read(target) == contents

  File.write(target, contents)
end
134
137
 
135
138
  def generate_attach_def(name, type, def_method)
@@ -142,14 +145,14 @@ def generate_attach_def(name, type, def_method)
142
145
  name
143
146
  end
144
147
 
145
- ruby_name = "_#{ruby_name}" if ["size", "stride", "random!", "stft"].include?(ruby_name)
148
+ ruby_name = "_#{ruby_name}" if ["size", "stride", "random!"].include?(ruby_name)
146
149
  ruby_name = ruby_name.sub(/\Afft_/, "") if type == "fft"
147
150
  ruby_name = ruby_name.sub(/\Alinalg_/, "") if type == "linalg"
148
151
  ruby_name = ruby_name.sub(/\Aspecial_/, "") if type == "special"
149
152
  ruby_name = ruby_name.sub(/\Asparse_/, "") if type == "sparse"
150
153
  ruby_name = name if name.start_with?("__")
151
154
 
152
- # cast for Ruby < 2.7 https://github.com/thisMagpie/fftw/issues/22#issuecomment-49508900
155
+ # cast for Ruby < 3.0 https://github.com/thisMagpie/fftw/issues/22#issuecomment-49508900
153
156
  cast = RUBY_VERSION.to_f > 2.7 ? "" : "(VALUE (*)(...)) "
154
157
 
155
158
  "rb_#{def_method}(m, \"#{ruby_name}\", #{cast}#{full_name(name, type)}, -1);"
@@ -253,6 +253,68 @@ static inline std::vector<int64_t> parse_intlist_args(const std::string& s, int6
253
253
  return args;
254
254
  }
255
255
 
256
+ // Parse a string literal to remove quotes and escape sequences
257
+ static std::string parse_string_literal(c10::string_view str) {
258
+ TORCH_CHECK(str.length() >= 2, "String defaults must be quoted");
259
+
260
+ if (str.front() == '"') {
261
+ TORCH_CHECK(
262
+ str.back() == '"', "Mismatched quotes in string default: ", str);
263
+ } else {
264
+ TORCH_CHECK(
265
+ str.front() == '\'' && str.back() == '\'',
266
+ "Invalid quotes in string default: ",
267
+ str)
268
+ }
269
+
270
+ std::string parsed;
271
+ parsed.reserve(str.size());
272
+ for (size_t i = 1; i < str.size() - 1;) {
273
+ if (str[i] != '\\') {
274
+ parsed.push_back(str[i]);
275
+ ++i;
276
+ continue;
277
+ }
278
+
279
+ // Handle escape sequences
280
+ TORCH_CHECK(
281
+ i < str.size() - 2, "String ends with escaped final quote: ", str)
282
+ char c = str[i + 1];
283
+ switch (c) {
284
+ case '\\':
285
+ case '\'':
286
+ case '\"':
287
+ break;
288
+ case 'a':
289
+ c = '\a';
290
+ break;
291
+ case 'b':
292
+ c = '\b';
293
+ break;
294
+ case 'f':
295
+ c = '\f';
296
+ break;
297
+ case 'n':
298
+ c = '\n';
299
+ break;
300
+ case 'v':
301
+ c = '\v';
302
+ break;
303
+ case 't':
304
+ c = '\t';
305
+ break;
306
+ default:
307
+ TORCH_CHECK(
308
+ false,
309
+ "Unsupported escape sequence in string default: \\",
310
+ str[i + 1]);
311
+ }
312
+ parsed.push_back(c);
313
+ i += 2;
314
+ }
315
+ return parsed;
316
+ }
317
+
256
318
  void FunctionParameter::set_default_str(const std::string& str) {
257
319
  if (str == "None") {
258
320
  allow_none = true;
@@ -308,8 +370,8 @@ void FunctionParameter::set_default_str(const std::string& str) {
308
370
  throw std::runtime_error("invalid device: " + str);
309
371
  }
310
372
  } else if (type_ == ParameterType::STRING) {
311
- if (str != "None" && str != "") {
312
- throw std::runtime_error("invalid default string: " + str);
373
+ if (str != "None") {
374
+ default_string = parse_string_literal(str);
313
375
  }
314
376
  }
315
377
  }
@@ -35,6 +35,7 @@ struct FunctionParameter {
35
35
  at::SmallVector<VALUE, 5> numpy_python_names;
36
36
  at::Scalar default_scalar;
37
37
  std::vector<int64_t> default_intlist;
38
+ std::string default_string;
38
39
  union {
39
40
  bool default_bool;
40
41
  int64_t default_int;
@@ -108,6 +109,7 @@ struct RubyArgs {
108
109
  inline c10::optional<at::MemoryFormat> memoryformatOptional(int i);
109
110
  // inline at::QScheme toQScheme(int i);
110
111
  inline std::string string(int i);
112
+ inline std::string stringWithDefault(int i, const std::string& default_str);
111
113
  inline c10::optional<std::string> stringOptional(int i);
112
114
  inline c10::string_view stringView(int i);
113
115
  // inline c10::string_view stringViewWithDefault(int i, const c10::string_view default_str);
@@ -345,6 +347,11 @@ inline c10::optional<at::MemoryFormat> RubyArgs::memoryformatOptional(int i) {
345
347
  }
346
348
 
347
349
  inline std::string RubyArgs::string(int i) {
350
+ return stringWithDefault(i, signature.params[i].default_string);
351
+ }
352
+
353
+ inline std::string RubyArgs::stringWithDefault(int i, const std::string& default_str) {
354
+ if (!args[i]) return default_str;
348
355
  return Rice::detail::From_Ruby<std::string>().convert(args[i]);
349
356
  }
350
357
 
data/lib/torch/tensor.rb CHANGED
@@ -177,11 +177,6 @@ module Torch
177
177
  _random!(*args)
178
178
  end
179
179
 
180
- # center option
181
- def stft(*args)
182
- Torch.stft(*args)
183
- end
184
-
185
180
  def dup
186
181
  Torch.no_grad do
187
182
  clone
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.11.0"
2
+ VERSION = "0.11.1"
3
3
  end
data/lib/torch.rb CHANGED
@@ -422,18 +422,6 @@ module Torch
422
422
  _tensor(data, size, tensor_options(**options))
423
423
  end
424
424
 
425
- # center option
426
- def stft(input, n_fft, hop_length: nil, win_length: nil, window: nil, center: true, pad_mode: "reflect", normalized: false, onesided: true, return_complex: nil)
427
- if center
428
- signal_dim = input.dim
429
- extended_shape = [1] * (3 - signal_dim) + input.size
430
- pad = n_fft.div(2).to_i
431
- input = NN::F.pad(input.view(extended_shape), [pad, pad], mode: pad_mode)
432
- input = input.view(input.shape[-signal_dim..-1])
433
- end
434
- _stft(input, n_fft, hop_length, win_length, window, normalized, onesided, return_complex)
435
- end
436
-
437
425
  private
438
426
 
439
427
  def to_ivalue(obj)
metadata CHANGED
@@ -1,7 +1,7 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.11.0
4
+ version: 0.11.1
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane