torch-rb 0.11.0 → 0.11.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +4 -0
- data/codegen/generate_functions.rb +7 -4
- data/ext/torch/ruby_arg_parser.cpp +64 -2
- data/ext/torch/ruby_arg_parser.h +7 -0
- data/lib/torch/tensor.rb +0 -5
- data/lib/torch/version.rb +1 -1
- data/lib/torch.rb +0 -12
- metadata +1 -1
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: c4c3936a55fd47a8898d3e678aabc81c158f2f39a89c08af0b36695700bf2043
|
4
|
+
data.tar.gz: a68179c7d7bab7547ac3be1d2369abbd5c3c632954996ca2f7fdd089f815edf7
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 97551aa27f154cade530e58b4ae92286cd18e432acc99646495a197c89fb719cc991b0399aacaf383a2de04e340c04a9e177c44a0169fce91e19435463df4753
|
7
|
+
data.tar.gz: 8a5a5decf900a4d93aa56eb5e5b3848d5386be8bbc3657c10c201ef321cf0082c6521700378a5e96f795c09882b5b72498d8fed3931153b6e1f7557630f55547
|
data/CHANGELOG.md
CHANGED
@@ -128,8 +128,11 @@ def write_body(type, method_defs, attach_defs)
|
|
128
128
|
end
|
129
129
|
|
130
130
|
def write_file(name, contents)
|
131
|
-
path = File.expand_path("../ext/torch", __dir__)
|
132
|
-
|
131
|
+
path = File.join(File.expand_path("../ext/torch", __dir__), name)
|
132
|
+
# only write if changed to improve compile times in development
|
133
|
+
if !File.exist?(path) || File.read(path) != contents
|
134
|
+
File.write(path, contents)
|
135
|
+
end
|
133
136
|
end
|
134
137
|
|
135
138
|
def generate_attach_def(name, type, def_method)
|
@@ -142,14 +145,14 @@ def generate_attach_def(name, type, def_method)
|
|
142
145
|
name
|
143
146
|
end
|
144
147
|
|
145
|
-
ruby_name = "_#{ruby_name}" if ["size", "stride", "random!"
|
148
|
+
ruby_name = "_#{ruby_name}" if ["size", "stride", "random!"].include?(ruby_name)
|
146
149
|
ruby_name = ruby_name.sub(/\Afft_/, "") if type == "fft"
|
147
150
|
ruby_name = ruby_name.sub(/\Alinalg_/, "") if type == "linalg"
|
148
151
|
ruby_name = ruby_name.sub(/\Aspecial_/, "") if type == "special"
|
149
152
|
ruby_name = ruby_name.sub(/\Asparse_/, "") if type == "sparse"
|
150
153
|
ruby_name = name if name.start_with?("__")
|
151
154
|
|
152
|
-
# cast for Ruby <
|
155
|
+
# cast for Ruby < 3.0 https://github.com/thisMagpie/fftw/issues/22#issuecomment-49508900
|
153
156
|
cast = RUBY_VERSION.to_f > 2.7 ? "" : "(VALUE (*)(...)) "
|
154
157
|
|
155
158
|
"rb_#{def_method}(m, \"#{ruby_name}\", #{cast}#{full_name(name, type)}, -1);"
|
@@ -253,6 +253,68 @@ static inline std::vector<int64_t> parse_intlist_args(const std::string& s, int6
|
|
253
253
|
return args;
|
254
254
|
}
|
255
255
|
|
256
|
+
// Parse a string literal to remove quotes and escape sequences
|
257
|
+
static std::string parse_string_literal(c10::string_view str) {
|
258
|
+
TORCH_CHECK(str.length() >= 2, "String defaults must be quoted");
|
259
|
+
|
260
|
+
if (str.front() == '"') {
|
261
|
+
TORCH_CHECK(
|
262
|
+
str.back() == '"', "Mismatched quotes in string default: ", str);
|
263
|
+
} else {
|
264
|
+
TORCH_CHECK(
|
265
|
+
str.front() == '\'' && str.back() == '\'',
|
266
|
+
"Invalid quotes in string default: ",
|
267
|
+
str)
|
268
|
+
}
|
269
|
+
|
270
|
+
std::string parsed;
|
271
|
+
parsed.reserve(str.size());
|
272
|
+
for (size_t i = 1; i < str.size() - 1;) {
|
273
|
+
if (str[i] != '\\') {
|
274
|
+
parsed.push_back(str[i]);
|
275
|
+
++i;
|
276
|
+
continue;
|
277
|
+
}
|
278
|
+
|
279
|
+
// Handle escape sequences
|
280
|
+
TORCH_CHECK(
|
281
|
+
i < str.size() - 2, "String ends with escaped final quote: ", str)
|
282
|
+
char c = str[i + 1];
|
283
|
+
switch (c) {
|
284
|
+
case '\\':
|
285
|
+
case '\'':
|
286
|
+
case '\"':
|
287
|
+
break;
|
288
|
+
case 'a':
|
289
|
+
c = '\a';
|
290
|
+
break;
|
291
|
+
case 'b':
|
292
|
+
c = '\b';
|
293
|
+
break;
|
294
|
+
case 'f':
|
295
|
+
c = '\f';
|
296
|
+
break;
|
297
|
+
case 'n':
|
298
|
+
c = '\n';
|
299
|
+
break;
|
300
|
+
case 'v':
|
301
|
+
c = '\v';
|
302
|
+
break;
|
303
|
+
case 't':
|
304
|
+
c = '\t';
|
305
|
+
break;
|
306
|
+
default:
|
307
|
+
TORCH_CHECK(
|
308
|
+
false,
|
309
|
+
"Unsupported escape sequence in string default: \\",
|
310
|
+
str[i + 1]);
|
311
|
+
}
|
312
|
+
parsed.push_back(c);
|
313
|
+
i += 2;
|
314
|
+
}
|
315
|
+
return parsed;
|
316
|
+
}
|
317
|
+
|
256
318
|
void FunctionParameter::set_default_str(const std::string& str) {
|
257
319
|
if (str == "None") {
|
258
320
|
allow_none = true;
|
@@ -308,8 +370,8 @@ void FunctionParameter::set_default_str(const std::string& str) {
|
|
308
370
|
throw std::runtime_error("invalid device: " + str);
|
309
371
|
}
|
310
372
|
} else if (type_ == ParameterType::STRING) {
|
311
|
-
if (str != "None"
|
312
|
-
|
373
|
+
if (str != "None") {
|
374
|
+
default_string = parse_string_literal(str);
|
313
375
|
}
|
314
376
|
}
|
315
377
|
}
|
data/ext/torch/ruby_arg_parser.h
CHANGED
@@ -35,6 +35,7 @@ struct FunctionParameter {
|
|
35
35
|
at::SmallVector<VALUE, 5> numpy_python_names;
|
36
36
|
at::Scalar default_scalar;
|
37
37
|
std::vector<int64_t> default_intlist;
|
38
|
+
std::string default_string;
|
38
39
|
union {
|
39
40
|
bool default_bool;
|
40
41
|
int64_t default_int;
|
@@ -108,6 +109,7 @@ struct RubyArgs {
|
|
108
109
|
inline c10::optional<at::MemoryFormat> memoryformatOptional(int i);
|
109
110
|
// inline at::QScheme toQScheme(int i);
|
110
111
|
inline std::string string(int i);
|
112
|
+
inline std::string stringWithDefault(int i, const std::string& default_str);
|
111
113
|
inline c10::optional<std::string> stringOptional(int i);
|
112
114
|
inline c10::string_view stringView(int i);
|
113
115
|
// inline c10::string_view stringViewWithDefault(int i, const c10::string_view default_str);
|
@@ -345,6 +347,11 @@ inline c10::optional<at::MemoryFormat> RubyArgs::memoryformatOptional(int i) {
|
|
345
347
|
}
|
346
348
|
|
347
349
|
inline std::string RubyArgs::string(int i) {
|
350
|
+
return stringWithDefault(i, signature.params[i].default_string);
|
351
|
+
}
|
352
|
+
|
353
|
+
inline std::string RubyArgs::stringWithDefault(int i, const std::string& default_str) {
|
354
|
+
if (!args[i]) return default_str;
|
348
355
|
return Rice::detail::From_Ruby<std::string>().convert(args[i]);
|
349
356
|
}
|
350
357
|
|
data/lib/torch/tensor.rb
CHANGED
data/lib/torch/version.rb
CHANGED
data/lib/torch.rb
CHANGED
@@ -422,18 +422,6 @@ module Torch
|
|
422
422
|
_tensor(data, size, tensor_options(**options))
|
423
423
|
end
|
424
424
|
|
425
|
-
# center option
|
426
|
-
def stft(input, n_fft, hop_length: nil, win_length: nil, window: nil, center: true, pad_mode: "reflect", normalized: false, onesided: true, return_complex: nil)
|
427
|
-
if center
|
428
|
-
signal_dim = input.dim
|
429
|
-
extended_shape = [1] * (3 - signal_dim) + input.size
|
430
|
-
pad = n_fft.div(2).to_i
|
431
|
-
input = NN::F.pad(input.view(extended_shape), [pad, pad], mode: pad_mode)
|
432
|
-
input = input.view(input.shape[-signal_dim..-1])
|
433
|
-
end
|
434
|
-
_stft(input, n_fft, hop_length, win_length, window, normalized, onesided, return_complex)
|
435
|
-
end
|
436
|
-
|
437
425
|
private
|
438
426
|
|
439
427
|
def to_ivalue(obj)
|