torch-rb 0.11.0 → 0.11.2

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: d2f88c938144476a772fd7c606751e0a6e67a338cb1c37ede9c2011db7fc4579
4
- data.tar.gz: f4408457e3bf8c7bf9b42863459248aa592c70e8e3924cdbe72863a979e65106
3
+ metadata.gz: 618b3c09305777402177e8bc3d536053434b1c94e5294035a8b4a78e645b3873
4
+ data.tar.gz: 910ac373619cb43887826d6277bc075b88311abd5dcc2be92bc6e9e15a35971d
5
5
  SHA512:
6
- metadata.gz: 90ebc506942809b02331f7accce6e93e714a57b5d2f58b06ad0d3c60204b947eba61fe0bf43431a19e27890f5c764cb348ccb1785a2ffccfba7969f4726b6f1f
7
- data.tar.gz: 410af7b4934f79aaae94f6c1bde1d17ee8b043f02354b4ea9f139271575f43abb90afeba162badfa01a3ec406612ae6cdcb4a2adfb1067b1634982a590d10cc9
6
+ metadata.gz: af100f347769a268f7b45779fd9531f631e341bb0af4c999b1fba075b5d1aedae2d736eff048d9926c005b56e1ae35582a946b71db0cba5e62d83e938c315814
7
+ data.tar.gz: 67803295ac4642c4cd32e66e9f3fcdeaaa33a37ae675693236c0561c4dfb4d6d405db6277ab156b0f31415058333ecd2aa06b3788f895a7836fc277b4d3fc084
data/CHANGELOG.md CHANGED
@@ -1,3 +1,11 @@
1
+ ## 0.11.2 (2022-09-25)
2
+
3
+ - Improved LibTorch detection for Homebrew on Mac ARM and Linux
4
+
5
+ ## 0.11.1 (2022-07-06)
6
+
7
+ - Fixed error with `stft` method
8
+
1
9
  ## 0.11.0 (2022-07-06)
2
10
 
3
11
  - Updated LibTorch to 1.12.0
data/README.md CHANGED
@@ -428,20 +428,6 @@ You can also use Homebrew.
428
428
  brew install libtorch
429
429
  ```
430
430
 
431
- For Mac ARM, run:
432
-
433
- ```sh
434
- bundle config build.torch-rb --with-torch-dir=/opt/homebrew
435
- ```
436
-
437
- And for Linux, run:
438
-
439
- ```sh
440
- bundle config build.torch-rb --with-torch-dir=/home/linuxbrew/.linuxbrew
441
- ```
442
-
443
- Then install the gem.
444
-
445
431
  ## Performance
446
432
 
447
433
  Deep learning is significantly faster on a GPU. With Linux, install [CUDA](https://developer.nvidia.com/cuda-downloads) and [cuDNN](https://developer.nvidia.com/cudnn) and reinstall the gem.
@@ -128,8 +128,11 @@ def write_body(type, method_defs, attach_defs)
128
128
  end
129
129
 
130
130
  def write_file(name, contents)
131
- path = File.expand_path("../ext/torch", __dir__)
132
- File.write(File.join(path, name), contents)
131
+ path = File.join(File.expand_path("../ext/torch", __dir__), name)
132
+ # only write if changed to improve compile times in development
133
+ if !File.exist?(path) || File.read(path) != contents
134
+ File.write(path, contents)
135
+ end
133
136
  end
134
137
 
135
138
  def generate_attach_def(name, type, def_method)
@@ -142,14 +145,14 @@ def generate_attach_def(name, type, def_method)
142
145
  name
143
146
  end
144
147
 
145
- ruby_name = "_#{ruby_name}" if ["size", "stride", "random!", "stft"].include?(ruby_name)
148
+ ruby_name = "_#{ruby_name}" if ["size", "stride", "random!"].include?(ruby_name)
146
149
  ruby_name = ruby_name.sub(/\Afft_/, "") if type == "fft"
147
150
  ruby_name = ruby_name.sub(/\Alinalg_/, "") if type == "linalg"
148
151
  ruby_name = ruby_name.sub(/\Aspecial_/, "") if type == "special"
149
152
  ruby_name = ruby_name.sub(/\Asparse_/, "") if type == "sparse"
150
153
  ruby_name = name if name.start_with?("__")
151
154
 
152
- # cast for Ruby < 2.7 https://github.com/thisMagpie/fftw/issues/22#issuecomment-49508900
155
+ # cast for Ruby < 3.0 https://github.com/thisMagpie/fftw/issues/22#issuecomment-49508900
153
156
  cast = RUBY_VERSION.to_f > 2.7 ? "" : "(VALUE (*)(...)) "
154
157
 
155
158
  "rb_#{def_method}(m, \"#{ruby_name}\", #{cast}#{full_name(name, type)}, -1);"
data/ext/torch/extconf.rb CHANGED
@@ -18,9 +18,19 @@ else
18
18
  $CXXFLAGS += " -Wno-duplicated-cond -Wno-suggest-attribute=noreturn"
19
19
  end
20
20
 
21
+ paths = [
22
+ "/usr/local",
23
+ "/opt/homebrew",
24
+ "/home/linuxbrew/.linuxbrew"
25
+ ]
26
+
21
27
  inc, lib = dir_config("torch")
22
- inc ||= "/usr/local/include"
23
- lib ||= "/usr/local/lib"
28
+ inc ||= paths.map { |v| "#{v}/include" }.find { |v| Dir.exist?("#{v}/torch") }
29
+ lib ||= paths.map { |v| "#{v}/lib" }.find { |v| Dir["#{v}/*torch_cpu*"].any? }
30
+
31
+ unless inc && lib
32
+ abort "LibTorch not found"
33
+ end
24
34
 
25
35
  cuda_inc, cuda_lib = dir_config("cuda")
26
36
  cuda_inc ||= "/usr/local/cuda/include"
@@ -253,6 +253,68 @@ static inline std::vector<int64_t> parse_intlist_args(const std::string& s, int6
253
253
  return args;
254
254
  }
255
255
 
256
+ // Parse a string literal to remove quotes and escape sequences
257
+ static std::string parse_string_literal(c10::string_view str) {
258
+ TORCH_CHECK(str.length() >= 2, "String defaults must be quoted");
259
+
260
+ if (str.front() == '"') {
261
+ TORCH_CHECK(
262
+ str.back() == '"', "Mismatched quotes in string default: ", str);
263
+ } else {
264
+ TORCH_CHECK(
265
+ str.front() == '\'' && str.back() == '\'',
266
+ "Invalid quotes in string default: ",
267
+ str)
268
+ }
269
+
270
+ std::string parsed;
271
+ parsed.reserve(str.size());
272
+ for (size_t i = 1; i < str.size() - 1;) {
273
+ if (str[i] != '\\') {
274
+ parsed.push_back(str[i]);
275
+ ++i;
276
+ continue;
277
+ }
278
+
279
+ // Handle escape sequences
280
+ TORCH_CHECK(
281
+ i < str.size() - 2, "String ends with escaped final quote: ", str)
282
+ char c = str[i + 1];
283
+ switch (c) {
284
+ case '\\':
285
+ case '\'':
286
+ case '\"':
287
+ break;
288
+ case 'a':
289
+ c = '\a';
290
+ break;
291
+ case 'b':
292
+ c = '\b';
293
+ break;
294
+ case 'f':
295
+ c = '\f';
296
+ break;
297
+ case 'n':
298
+ c = '\n';
299
+ break;
300
+ case 'v':
301
+ c = '\v';
302
+ break;
303
+ case 't':
304
+ c = '\t';
305
+ break;
306
+ default:
307
+ TORCH_CHECK(
308
+ false,
309
+ "Unsupported escape sequence in string default: \\",
310
+ str[i + 1]);
311
+ }
312
+ parsed.push_back(c);
313
+ i += 2;
314
+ }
315
+ return parsed;
316
+ }
317
+
256
318
  void FunctionParameter::set_default_str(const std::string& str) {
257
319
  if (str == "None") {
258
320
  allow_none = true;
@@ -308,8 +370,8 @@ void FunctionParameter::set_default_str(const std::string& str) {
308
370
  throw std::runtime_error("invalid device: " + str);
309
371
  }
310
372
  } else if (type_ == ParameterType::STRING) {
311
- if (str != "None" && str != "") {
312
- throw std::runtime_error("invalid default string: " + str);
373
+ if (str != "None") {
374
+ default_string = parse_string_literal(str);
313
375
  }
314
376
  }
315
377
  }
@@ -35,6 +35,7 @@ struct FunctionParameter {
35
35
  at::SmallVector<VALUE, 5> numpy_python_names;
36
36
  at::Scalar default_scalar;
37
37
  std::vector<int64_t> default_intlist;
38
+ std::string default_string;
38
39
  union {
39
40
  bool default_bool;
40
41
  int64_t default_int;
@@ -108,6 +109,7 @@ struct RubyArgs {
108
109
  inline c10::optional<at::MemoryFormat> memoryformatOptional(int i);
109
110
  // inline at::QScheme toQScheme(int i);
110
111
  inline std::string string(int i);
112
+ inline std::string stringWithDefault(int i, const std::string& default_str);
111
113
  inline c10::optional<std::string> stringOptional(int i);
112
114
  inline c10::string_view stringView(int i);
113
115
  // inline c10::string_view stringViewWithDefault(int i, const c10::string_view default_str);
@@ -345,6 +347,11 @@ inline c10::optional<at::MemoryFormat> RubyArgs::memoryformatOptional(int i) {
345
347
  }
346
348
 
347
349
  inline std::string RubyArgs::string(int i) {
350
+ return stringWithDefault(i, signature.params[i].default_string);
351
+ }
352
+
353
+ inline std::string RubyArgs::stringWithDefault(int i, const std::string& default_str) {
354
+ if (!args[i]) return default_str;
348
355
  return Rice::detail::From_Ruby<std::string>().convert(args[i]);
349
356
  }
350
357
 
data/lib/torch/tensor.rb CHANGED
@@ -177,11 +177,6 @@ module Torch
177
177
  _random!(*args)
178
178
  end
179
179
 
180
- # center option
181
- def stft(*args)
182
- Torch.stft(*args)
183
- end
184
-
185
180
  def dup
186
181
  Torch.no_grad do
187
182
  clone
data/lib/torch/version.rb CHANGED
@@ -1,3 +1,3 @@
1
1
  module Torch
2
- VERSION = "0.11.0"
2
+ VERSION = "0.11.2"
3
3
  end
data/lib/torch.rb CHANGED
@@ -422,18 +422,6 @@ module Torch
422
422
  _tensor(data, size, tensor_options(**options))
423
423
  end
424
424
 
425
- # center option
426
- def stft(input, n_fft, hop_length: nil, win_length: nil, window: nil, center: true, pad_mode: "reflect", normalized: false, onesided: true, return_complex: nil)
427
- if center
428
- signal_dim = input.dim
429
- extended_shape = [1] * (3 - signal_dim) + input.size
430
- pad = n_fft.div(2).to_i
431
- input = NN::F.pad(input.view(extended_shape), [pad, pad], mode: pad_mode)
432
- input = input.view(input.shape[-signal_dim..-1])
433
- end
434
- _stft(input, n_fft, hop_length, win_length, window, normalized, onesided, return_complex)
435
- end
436
-
437
425
  private
438
426
 
439
427
  def to_ivalue(obj)
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torch-rb
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.11.0
4
+ version: 0.11.2
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2022-07-06 00:00:00.000000000 Z
11
+ date: 2022-09-25 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: rice