torch-rb 0.11.0 → 0.11.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/codegen/generate_functions.rb +7 -4
- data/ext/torch/ruby_arg_parser.cpp +64 -2
- data/ext/torch/ruby_arg_parser.h +7 -0
- data/lib/torch/tensor.rb +0 -5
- data/lib/torch/version.rb +1 -1
- data/lib/torch.rb +0 -12
- metadata +1 -1
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: c4c3936a55fd47a8898d3e678aabc81c158f2f39a89c08af0b36695700bf2043
|
4
|
+
data.tar.gz: a68179c7d7bab7547ac3be1d2369abbd5c3c632954996ca2f7fdd089f815edf7
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 97551aa27f154cade530e58b4ae92286cd18e432acc99646495a197c89fb719cc991b0399aacaf383a2de04e340c04a9e177c44a0169fce91e19435463df4753
|
7
|
+
data.tar.gz: 8a5a5decf900a4d93aa56eb5e5b3848d5386be8bbc3657c10c201ef321cf0082c6521700378a5e96f795c09882b5b72498d8fed3931153b6e1f7557630f55547
|
data/CHANGELOG.md
CHANGED
@@ -128,8 +128,11 @@ def write_body(type, method_defs, attach_defs)
|
|
128
128
|
end
|
129
129
|
|
130
130
|
def write_file(name, contents)
|
131
|
-
path = File.expand_path("../ext/torch", __dir__)
|
132
|
-
|
131
|
+
path = File.join(File.expand_path("../ext/torch", __dir__), name)
|
132
|
+
# only write if changed to improve compile times in development
|
133
|
+
if !File.exist?(path) || File.read(path) != contents
|
134
|
+
File.write(path, contents)
|
135
|
+
end
|
133
136
|
end
|
134
137
|
|
135
138
|
def generate_attach_def(name, type, def_method)
|
@@ -142,14 +145,14 @@ def generate_attach_def(name, type, def_method)
|
|
142
145
|
name
|
143
146
|
end
|
144
147
|
|
145
|
-
ruby_name = "_#{ruby_name}" if ["size", "stride", "random!"
|
148
|
+
ruby_name = "_#{ruby_name}" if ["size", "stride", "random!"].include?(ruby_name)
|
146
149
|
ruby_name = ruby_name.sub(/\Afft_/, "") if type == "fft"
|
147
150
|
ruby_name = ruby_name.sub(/\Alinalg_/, "") if type == "linalg"
|
148
151
|
ruby_name = ruby_name.sub(/\Aspecial_/, "") if type == "special"
|
149
152
|
ruby_name = ruby_name.sub(/\Asparse_/, "") if type == "sparse"
|
150
153
|
ruby_name = name if name.start_with?("__")
|
151
154
|
|
152
|
-
# cast for Ruby <
|
155
|
+
# cast for Ruby < 3.0 https://github.com/thisMagpie/fftw/issues/22#issuecomment-49508900
|
153
156
|
cast = RUBY_VERSION.to_f > 2.7 ? "" : "(VALUE (*)(...)) "
|
154
157
|
|
155
158
|
"rb_#{def_method}(m, \"#{ruby_name}\", #{cast}#{full_name(name, type)}, -1);"
|
@@ -253,6 +253,68 @@ static inline std::vector<int64_t> parse_intlist_args(const std::string& s, int6
|
|
253
253
|
return args;
|
254
254
|
}
|
255
255
|
|
256
|
+
// Parse a string literal to remove quotes and escape sequences
|
257
|
+
static std::string parse_string_literal(c10::string_view str) {
|
258
|
+
TORCH_CHECK(str.length() >= 2, "String defaults must be quoted");
|
259
|
+
|
260
|
+
if (str.front() == '"') {
|
261
|
+
TORCH_CHECK(
|
262
|
+
str.back() == '"', "Mismatched quotes in string default: ", str);
|
263
|
+
} else {
|
264
|
+
TORCH_CHECK(
|
265
|
+
str.front() == '\'' && str.back() == '\'',
|
266
|
+
"Invalid quotes in string default: ",
|
267
|
+
str)
|
268
|
+
}
|
269
|
+
|
270
|
+
std::string parsed;
|
271
|
+
parsed.reserve(str.size());
|
272
|
+
for (size_t i = 1; i < str.size() - 1;) {
|
273
|
+
if (str[i] != '\\') {
|
274
|
+
parsed.push_back(str[i]);
|
275
|
+
++i;
|
276
|
+
continue;
|
277
|
+
}
|
278
|
+
|
279
|
+
// Handle escape sequences
|
280
|
+
TORCH_CHECK(
|
281
|
+
i < str.size() - 2, "String ends with escaped final quote: ", str)
|
282
|
+
char c = str[i + 1];
|
283
|
+
switch (c) {
|
284
|
+
case '\\':
|
285
|
+
case '\'':
|
286
|
+
case '\"':
|
287
|
+
break;
|
288
|
+
case 'a':
|
289
|
+
c = '\a';
|
290
|
+
break;
|
291
|
+
case 'b':
|
292
|
+
c = '\b';
|
293
|
+
break;
|
294
|
+
case 'f':
|
295
|
+
c = '\f';
|
296
|
+
break;
|
297
|
+
case 'n':
|
298
|
+
c = '\n';
|
299
|
+
break;
|
300
|
+
case 'v':
|
301
|
+
c = '\v';
|
302
|
+
break;
|
303
|
+
case 't':
|
304
|
+
c = '\t';
|
305
|
+
break;
|
306
|
+
default:
|
307
|
+
TORCH_CHECK(
|
308
|
+
false,
|
309
|
+
"Unsupported escape sequence in string default: \\",
|
310
|
+
str[i + 1]);
|
311
|
+
}
|
312
|
+
parsed.push_back(c);
|
313
|
+
i += 2;
|
314
|
+
}
|
315
|
+
return parsed;
|
316
|
+
}
|
317
|
+
|
256
318
|
void FunctionParameter::set_default_str(const std::string& str) {
|
257
319
|
if (str == "None") {
|
258
320
|
allow_none = true;
|
@@ -308,8 +370,8 @@ void FunctionParameter::set_default_str(const std::string& str) {
|
|
308
370
|
throw std::runtime_error("invalid device: " + str);
|
309
371
|
}
|
310
372
|
} else if (type_ == ParameterType::STRING) {
|
311
|
-
if (str != "None"
|
312
|
-
|
373
|
+
if (str != "None") {
|
374
|
+
default_string = parse_string_literal(str);
|
313
375
|
}
|
314
376
|
}
|
315
377
|
}
|
data/ext/torch/ruby_arg_parser.h
CHANGED
@@ -35,6 +35,7 @@ struct FunctionParameter {
|
|
35
35
|
at::SmallVector<VALUE, 5> numpy_python_names;
|
36
36
|
at::Scalar default_scalar;
|
37
37
|
std::vector<int64_t> default_intlist;
|
38
|
+
std::string default_string;
|
38
39
|
union {
|
39
40
|
bool default_bool;
|
40
41
|
int64_t default_int;
|
@@ -108,6 +109,7 @@ struct RubyArgs {
|
|
108
109
|
inline c10::optional<at::MemoryFormat> memoryformatOptional(int i);
|
109
110
|
// inline at::QScheme toQScheme(int i);
|
110
111
|
inline std::string string(int i);
|
112
|
+
inline std::string stringWithDefault(int i, const std::string& default_str);
|
111
113
|
inline c10::optional<std::string> stringOptional(int i);
|
112
114
|
inline c10::string_view stringView(int i);
|
113
115
|
// inline c10::string_view stringViewWithDefault(int i, const c10::string_view default_str);
|
@@ -345,6 +347,11 @@ inline c10::optional<at::MemoryFormat> RubyArgs::memoryformatOptional(int i) {
|
|
345
347
|
}
|
346
348
|
|
347
349
|
inline std::string RubyArgs::string(int i) {
|
350
|
+
return stringWithDefault(i, signature.params[i].default_string);
|
351
|
+
}
|
352
|
+
|
353
|
+
inline std::string RubyArgs::stringWithDefault(int i, const std::string& default_str) {
|
354
|
+
if (!args[i]) return default_str;
|
348
355
|
return Rice::detail::From_Ruby<std::string>().convert(args[i]);
|
349
356
|
}
|
350
357
|
|
data/lib/torch/tensor.rb
CHANGED
data/lib/torch/version.rb
CHANGED
data/lib/torch.rb
CHANGED
@@ -422,18 +422,6 @@ module Torch
|
|
422
422
|
_tensor(data, size, tensor_options(**options))
|
423
423
|
end
|
424
424
|
|
425
|
-
# center option
|
426
|
-
def stft(input, n_fft, hop_length: nil, win_length: nil, window: nil, center: true, pad_mode: "reflect", normalized: false, onesided: true, return_complex: nil)
|
427
|
-
if center
|
428
|
-
signal_dim = input.dim
|
429
|
-
extended_shape = [1] * (3 - signal_dim) + input.size
|
430
|
-
pad = n_fft.div(2).to_i
|
431
|
-
input = NN::F.pad(input.view(extended_shape), [pad, pad], mode: pad_mode)
|
432
|
-
input = input.view(input.shape[-signal_dim..-1])
|
433
|
-
end
|
434
|
-
_stft(input, n_fft, hop_length, win_length, window, normalized, onesided, return_complex)
|
435
|
-
end
|
436
|
-
|
437
425
|
private
|
438
426
|
|
439
427
|
def to_ivalue(obj)
|