torch-rb 0.11.0 → 0.11.2
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +8 -0
- data/README.md +0 -14
- data/codegen/generate_functions.rb +7 -4
- data/ext/torch/extconf.rb +12 -2
- data/ext/torch/ruby_arg_parser.cpp +64 -2
- data/ext/torch/ruby_arg_parser.h +7 -0
- data/lib/torch/tensor.rb +0 -5
- data/lib/torch/version.rb +1 -1
- data/lib/torch.rb +0 -12
- metadata +2 -2
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 618b3c09305777402177e8bc3d536053434b1c94e5294035a8b4a78e645b3873
|
4
|
+
data.tar.gz: 910ac373619cb43887826d6277bc075b88311abd5dcc2be92bc6e9e15a35971d
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: af100f347769a268f7b45779fd9531f631e341bb0af4c999b1fba075b5d1aedae2d736eff048d9926c005b56e1ae35582a946b71db0cba5e62d83e938c315814
|
7
|
+
data.tar.gz: 67803295ac4642c4cd32e66e9f3fcdeaaa33a37ae675693236c0561c4dfb4d6d405db6277ab156b0f31415058333ecd2aa06b3788f895a7836fc277b4d3fc084
|
data/CHANGELOG.md
CHANGED
data/README.md
CHANGED
@@ -428,20 +428,6 @@ You can also use Homebrew.
|
|
428
428
|
brew install libtorch
|
429
429
|
```
|
430
430
|
|
431
|
-
For Mac ARM, run:
|
432
|
-
|
433
|
-
```sh
|
434
|
-
bundle config build.torch-rb --with-torch-dir=/opt/homebrew
|
435
|
-
```
|
436
|
-
|
437
|
-
And for Linux, run:
|
438
|
-
|
439
|
-
```sh
|
440
|
-
bundle config build.torch-rb --with-torch-dir=/home/linuxbrew/.linuxbrew
|
441
|
-
```
|
442
|
-
|
443
|
-
Then install the gem.
|
444
|
-
|
445
431
|
## Performance
|
446
432
|
|
447
433
|
Deep learning is significantly faster on a GPU. With Linux, install [CUDA](https://developer.nvidia.com/cuda-downloads) and [cuDNN](https://developer.nvidia.com/cudnn) and reinstall the gem.
|
@@ -128,8 +128,11 @@ def write_body(type, method_defs, attach_defs)
|
|
128
128
|
end
|
129
129
|
|
130
130
|
def write_file(name, contents)
|
131
|
-
path = File.expand_path("../ext/torch", __dir__)
|
132
|
-
|
131
|
+
path = File.join(File.expand_path("../ext/torch", __dir__), name)
|
132
|
+
# only write if changed to improve compile times in development
|
133
|
+
if !File.exist?(path) || File.read(path) != contents
|
134
|
+
File.write(path, contents)
|
135
|
+
end
|
133
136
|
end
|
134
137
|
|
135
138
|
def generate_attach_def(name, type, def_method)
|
@@ -142,14 +145,14 @@ def generate_attach_def(name, type, def_method)
|
|
142
145
|
name
|
143
146
|
end
|
144
147
|
|
145
|
-
ruby_name = "_#{ruby_name}" if ["size", "stride", "random!"
|
148
|
+
ruby_name = "_#{ruby_name}" if ["size", "stride", "random!"].include?(ruby_name)
|
146
149
|
ruby_name = ruby_name.sub(/\Afft_/, "") if type == "fft"
|
147
150
|
ruby_name = ruby_name.sub(/\Alinalg_/, "") if type == "linalg"
|
148
151
|
ruby_name = ruby_name.sub(/\Aspecial_/, "") if type == "special"
|
149
152
|
ruby_name = ruby_name.sub(/\Asparse_/, "") if type == "sparse"
|
150
153
|
ruby_name = name if name.start_with?("__")
|
151
154
|
|
152
|
-
# cast for Ruby <
|
155
|
+
# cast for Ruby < 3.0 https://github.com/thisMagpie/fftw/issues/22#issuecomment-49508900
|
153
156
|
cast = RUBY_VERSION.to_f > 2.7 ? "" : "(VALUE (*)(...)) "
|
154
157
|
|
155
158
|
"rb_#{def_method}(m, \"#{ruby_name}\", #{cast}#{full_name(name, type)}, -1);"
|
data/ext/torch/extconf.rb
CHANGED
@@ -18,9 +18,19 @@ else
|
|
18
18
|
$CXXFLAGS += " -Wno-duplicated-cond -Wno-suggest-attribute=noreturn"
|
19
19
|
end
|
20
20
|
|
21
|
+
paths = [
|
22
|
+
"/usr/local",
|
23
|
+
"/opt/homebrew",
|
24
|
+
"/home/linuxbrew/.linuxbrew"
|
25
|
+
]
|
26
|
+
|
21
27
|
inc, lib = dir_config("torch")
|
22
|
-
inc ||= "/
|
23
|
-
lib ||= "/
|
28
|
+
inc ||= paths.map { |v| "#{v}/include" }.find { |v| Dir.exist?("#{v}/torch") }
|
29
|
+
lib ||= paths.map { |v| "#{v}/lib" }.find { |v| Dir["#{v}/*torch_cpu*"].any? }
|
30
|
+
|
31
|
+
unless inc && lib
|
32
|
+
abort "LibTorch not found"
|
33
|
+
end
|
24
34
|
|
25
35
|
cuda_inc, cuda_lib = dir_config("cuda")
|
26
36
|
cuda_inc ||= "/usr/local/cuda/include"
|
@@ -253,6 +253,68 @@ static inline std::vector<int64_t> parse_intlist_args(const std::string& s, int6
|
|
253
253
|
return args;
|
254
254
|
}
|
255
255
|
|
256
|
+
// Parse a string literal to remove quotes and escape sequences
|
257
|
+
static std::string parse_string_literal(c10::string_view str) {
|
258
|
+
TORCH_CHECK(str.length() >= 2, "String defaults must be quoted");
|
259
|
+
|
260
|
+
if (str.front() == '"') {
|
261
|
+
TORCH_CHECK(
|
262
|
+
str.back() == '"', "Mismatched quotes in string default: ", str);
|
263
|
+
} else {
|
264
|
+
TORCH_CHECK(
|
265
|
+
str.front() == '\'' && str.back() == '\'',
|
266
|
+
"Invalid quotes in string default: ",
|
267
|
+
str)
|
268
|
+
}
|
269
|
+
|
270
|
+
std::string parsed;
|
271
|
+
parsed.reserve(str.size());
|
272
|
+
for (size_t i = 1; i < str.size() - 1;) {
|
273
|
+
if (str[i] != '\\') {
|
274
|
+
parsed.push_back(str[i]);
|
275
|
+
++i;
|
276
|
+
continue;
|
277
|
+
}
|
278
|
+
|
279
|
+
// Handle escape sequences
|
280
|
+
TORCH_CHECK(
|
281
|
+
i < str.size() - 2, "String ends with escaped final quote: ", str)
|
282
|
+
char c = str[i + 1];
|
283
|
+
switch (c) {
|
284
|
+
case '\\':
|
285
|
+
case '\'':
|
286
|
+
case '\"':
|
287
|
+
break;
|
288
|
+
case 'a':
|
289
|
+
c = '\a';
|
290
|
+
break;
|
291
|
+
case 'b':
|
292
|
+
c = '\b';
|
293
|
+
break;
|
294
|
+
case 'f':
|
295
|
+
c = '\f';
|
296
|
+
break;
|
297
|
+
case 'n':
|
298
|
+
c = '\n';
|
299
|
+
break;
|
300
|
+
case 'v':
|
301
|
+
c = '\v';
|
302
|
+
break;
|
303
|
+
case 't':
|
304
|
+
c = '\t';
|
305
|
+
break;
|
306
|
+
default:
|
307
|
+
TORCH_CHECK(
|
308
|
+
false,
|
309
|
+
"Unsupported escape sequence in string default: \\",
|
310
|
+
str[i + 1]);
|
311
|
+
}
|
312
|
+
parsed.push_back(c);
|
313
|
+
i += 2;
|
314
|
+
}
|
315
|
+
return parsed;
|
316
|
+
}
|
317
|
+
|
256
318
|
void FunctionParameter::set_default_str(const std::string& str) {
|
257
319
|
if (str == "None") {
|
258
320
|
allow_none = true;
|
@@ -308,8 +370,8 @@ void FunctionParameter::set_default_str(const std::string& str) {
|
|
308
370
|
throw std::runtime_error("invalid device: " + str);
|
309
371
|
}
|
310
372
|
} else if (type_ == ParameterType::STRING) {
|
311
|
-
if (str != "None"
|
312
|
-
|
373
|
+
if (str != "None") {
|
374
|
+
default_string = parse_string_literal(str);
|
313
375
|
}
|
314
376
|
}
|
315
377
|
}
|
data/ext/torch/ruby_arg_parser.h
CHANGED
@@ -35,6 +35,7 @@ struct FunctionParameter {
|
|
35
35
|
at::SmallVector<VALUE, 5> numpy_python_names;
|
36
36
|
at::Scalar default_scalar;
|
37
37
|
std::vector<int64_t> default_intlist;
|
38
|
+
std::string default_string;
|
38
39
|
union {
|
39
40
|
bool default_bool;
|
40
41
|
int64_t default_int;
|
@@ -108,6 +109,7 @@ struct RubyArgs {
|
|
108
109
|
inline c10::optional<at::MemoryFormat> memoryformatOptional(int i);
|
109
110
|
// inline at::QScheme toQScheme(int i);
|
110
111
|
inline std::string string(int i);
|
112
|
+
inline std::string stringWithDefault(int i, const std::string& default_str);
|
111
113
|
inline c10::optional<std::string> stringOptional(int i);
|
112
114
|
inline c10::string_view stringView(int i);
|
113
115
|
// inline c10::string_view stringViewWithDefault(int i, const c10::string_view default_str);
|
@@ -345,6 +347,11 @@ inline c10::optional<at::MemoryFormat> RubyArgs::memoryformatOptional(int i) {
|
|
345
347
|
}
|
346
348
|
|
347
349
|
inline std::string RubyArgs::string(int i) {
|
350
|
+
return stringWithDefault(i, signature.params[i].default_string);
|
351
|
+
}
|
352
|
+
|
353
|
+
inline std::string RubyArgs::stringWithDefault(int i, const std::string& default_str) {
|
354
|
+
if (!args[i]) return default_str;
|
348
355
|
return Rice::detail::From_Ruby<std::string>().convert(args[i]);
|
349
356
|
}
|
350
357
|
|
data/lib/torch/tensor.rb
CHANGED
data/lib/torch/version.rb
CHANGED
data/lib/torch.rb
CHANGED
@@ -422,18 +422,6 @@ module Torch
|
|
422
422
|
_tensor(data, size, tensor_options(**options))
|
423
423
|
end
|
424
424
|
|
425
|
-
# center option
|
426
|
-
def stft(input, n_fft, hop_length: nil, win_length: nil, window: nil, center: true, pad_mode: "reflect", normalized: false, onesided: true, return_complex: nil)
|
427
|
-
if center
|
428
|
-
signal_dim = input.dim
|
429
|
-
extended_shape = [1] * (3 - signal_dim) + input.size
|
430
|
-
pad = n_fft.div(2).to_i
|
431
|
-
input = NN::F.pad(input.view(extended_shape), [pad, pad], mode: pad_mode)
|
432
|
-
input = input.view(input.shape[-signal_dim..-1])
|
433
|
-
end
|
434
|
-
_stft(input, n_fft, hop_length, win_length, window, normalized, onesided, return_complex)
|
435
|
-
end
|
436
|
-
|
437
425
|
private
|
438
426
|
|
439
427
|
def to_ivalue(obj)
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.11.
|
4
|
+
version: 0.11.2
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2022-
|
11
|
+
date: 2022-09-25 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|