torch-rb 0.10.1 → 0.11.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +14 -0
- data/README.md +16 -3
- data/codegen/function.rb +1 -1
- data/codegen/generate_functions.rb +31 -11
- data/codegen/native_functions.yaml +1362 -199
- data/ext/torch/extconf.rb +1 -13
- data/ext/torch/ruby_arg_parser.cpp +64 -2
- data/ext/torch/ruby_arg_parser.h +18 -3
- data/ext/torch/utils.h +1 -1
- data/lib/torch/tensor.rb +8 -5
- data/lib/torch/version.rb +1 -1
- data/lib/torch.rb +1 -12
- metadata +3 -3
data/ext/torch/extconf.rb
CHANGED
@@ -7,21 +7,9 @@ $CXXFLAGS += " -D_GLIBCXX_USE_CXX11_ABI=1"
|
|
7
7
|
|
8
8
|
apple_clang = RbConfig::CONFIG["CC_VERSION_MESSAGE"] =~ /apple clang/i
|
9
9
|
|
10
|
-
# check omp first
|
11
|
-
if have_library("omp") || have_library("gomp")
|
12
|
-
$CXXFLAGS += " -Xclang" if apple_clang
|
13
|
-
$CXXFLAGS += " -fopenmp"
|
14
|
-
end
|
15
|
-
|
16
10
|
if apple_clang
|
17
|
-
# silence rice warnings
|
18
|
-
$CXXFLAGS += " -Wno-deprecated-declarations"
|
19
|
-
|
20
|
-
# silence ruby/intern.h warning
|
21
|
-
$CXXFLAGS += " -Wno-deprecated-register"
|
22
|
-
|
23
11
|
# silence torch warnings
|
24
|
-
$CXXFLAGS += " -Wno-
|
12
|
+
$CXXFLAGS += " -Wno-deprecated-declarations"
|
25
13
|
else
|
26
14
|
# silence rice warnings
|
27
15
|
$CXXFLAGS += " -Wno-noexcept-type"
|
@@ -253,6 +253,68 @@ static inline std::vector<int64_t> parse_intlist_args(const std::string& s, int6
|
|
253
253
|
return args;
|
254
254
|
}
|
255
255
|
|
256
|
+
// Parse a string literal to remove quotes and escape sequences
|
257
|
+
static std::string parse_string_literal(c10::string_view str) {
|
258
|
+
TORCH_CHECK(str.length() >= 2, "String defaults must be quoted");
|
259
|
+
|
260
|
+
if (str.front() == '"') {
|
261
|
+
TORCH_CHECK(
|
262
|
+
str.back() == '"', "Mismatched quotes in string default: ", str);
|
263
|
+
} else {
|
264
|
+
TORCH_CHECK(
|
265
|
+
str.front() == '\'' && str.back() == '\'',
|
266
|
+
"Invalid quotes in string default: ",
|
267
|
+
str)
|
268
|
+
}
|
269
|
+
|
270
|
+
std::string parsed;
|
271
|
+
parsed.reserve(str.size());
|
272
|
+
for (size_t i = 1; i < str.size() - 1;) {
|
273
|
+
if (str[i] != '\\') {
|
274
|
+
parsed.push_back(str[i]);
|
275
|
+
++i;
|
276
|
+
continue;
|
277
|
+
}
|
278
|
+
|
279
|
+
// Handle escape sequences
|
280
|
+
TORCH_CHECK(
|
281
|
+
i < str.size() - 2, "String ends with escaped final quote: ", str)
|
282
|
+
char c = str[i + 1];
|
283
|
+
switch (c) {
|
284
|
+
case '\\':
|
285
|
+
case '\'':
|
286
|
+
case '\"':
|
287
|
+
break;
|
288
|
+
case 'a':
|
289
|
+
c = '\a';
|
290
|
+
break;
|
291
|
+
case 'b':
|
292
|
+
c = '\b';
|
293
|
+
break;
|
294
|
+
case 'f':
|
295
|
+
c = '\f';
|
296
|
+
break;
|
297
|
+
case 'n':
|
298
|
+
c = '\n';
|
299
|
+
break;
|
300
|
+
case 'v':
|
301
|
+
c = '\v';
|
302
|
+
break;
|
303
|
+
case 't':
|
304
|
+
c = '\t';
|
305
|
+
break;
|
306
|
+
default:
|
307
|
+
TORCH_CHECK(
|
308
|
+
false,
|
309
|
+
"Unsupported escape sequence in string default: \\",
|
310
|
+
str[i + 1]);
|
311
|
+
}
|
312
|
+
parsed.push_back(c);
|
313
|
+
i += 2;
|
314
|
+
}
|
315
|
+
return parsed;
|
316
|
+
}
|
317
|
+
|
256
318
|
void FunctionParameter::set_default_str(const std::string& str) {
|
257
319
|
if (str == "None") {
|
258
320
|
allow_none = true;
|
@@ -308,8 +370,8 @@ void FunctionParameter::set_default_str(const std::string& str) {
|
|
308
370
|
throw std::runtime_error("invalid device: " + str);
|
309
371
|
}
|
310
372
|
} else if (type_ == ParameterType::STRING) {
|
311
|
-
if (str != "None"
|
312
|
-
|
373
|
+
if (str != "None") {
|
374
|
+
default_string = parse_string_literal(str);
|
313
375
|
}
|
314
376
|
}
|
315
377
|
}
|
data/ext/torch/ruby_arg_parser.h
CHANGED
@@ -35,6 +35,7 @@ struct FunctionParameter {
|
|
35
35
|
at::SmallVector<VALUE, 5> numpy_python_names;
|
36
36
|
at::Scalar default_scalar;
|
37
37
|
std::vector<int64_t> default_intlist;
|
38
|
+
std::string default_string;
|
38
39
|
union {
|
39
40
|
bool default_bool;
|
40
41
|
int64_t default_int;
|
@@ -83,8 +84,8 @@ struct RubyArgs {
|
|
83
84
|
template<int N>
|
84
85
|
inline std::array<at::Tensor, N> tensorlist_n(int i);
|
85
86
|
inline std::vector<int64_t> intlist(int i);
|
86
|
-
|
87
|
-
|
87
|
+
inline c10::OptionalArray<int64_t> intlistOptional(int i);
|
88
|
+
inline std::vector<int64_t> intlistWithDefault(int i, std::vector<int64_t> default_intlist);
|
88
89
|
inline c10::optional<at::Generator> generator(int i);
|
89
90
|
inline at::Storage storage(int i);
|
90
91
|
inline at::ScalarType scalartype(int i);
|
@@ -108,6 +109,7 @@ struct RubyArgs {
|
|
108
109
|
inline c10::optional<at::MemoryFormat> memoryformatOptional(int i);
|
109
110
|
// inline at::QScheme toQScheme(int i);
|
110
111
|
inline std::string string(int i);
|
112
|
+
inline std::string stringWithDefault(int i, const std::string& default_str);
|
111
113
|
inline c10::optional<std::string> stringOptional(int i);
|
112
114
|
inline c10::string_view stringView(int i);
|
113
115
|
// inline c10::string_view stringViewWithDefault(int i, const c10::string_view default_str);
|
@@ -166,8 +168,11 @@ inline std::array<at::Tensor, N> RubyArgs::tensorlist_n(int i) {
|
|
166
168
|
}
|
167
169
|
|
168
170
|
inline std::vector<int64_t> RubyArgs::intlist(int i) {
|
169
|
-
|
171
|
+
return intlistWithDefault(i, signature.params[i].default_intlist);
|
172
|
+
}
|
170
173
|
|
174
|
+
inline std::vector<int64_t> RubyArgs::intlistWithDefault(int i, std::vector<int64_t> default_intlist) {
|
175
|
+
if (NIL_P(args[i])) return default_intlist;
|
171
176
|
VALUE arg = args[i];
|
172
177
|
auto size = signature.params[i].size;
|
173
178
|
if (size > 0 && FIXNUM_P(arg)) {
|
@@ -189,6 +194,11 @@ inline std::vector<int64_t> RubyArgs::intlist(int i) {
|
|
189
194
|
return res;
|
190
195
|
}
|
191
196
|
|
197
|
+
// Optional int-list argument at position `i`.
// A nil argument yields an empty (disengaged) OptionalArray; anything else is
// converted through intlist().
inline c10::OptionalArray<int64_t> RubyArgs::intlistOptional(int i) {
  if (!NIL_P(args[i])) {
    return intlist(i);
  }
  return {};
}
|
201
|
+
|
192
202
|
inline c10::optional<at::Generator> RubyArgs::generator(int i) {
|
193
203
|
if (NIL_P(args[i])) return c10::nullopt;
|
194
204
|
throw std::runtime_error("generator not supported yet");
|
@@ -337,6 +347,11 @@ inline c10::optional<at::MemoryFormat> RubyArgs::memoryformatOptional(int i) {
|
|
337
347
|
}
|
338
348
|
|
339
349
|
inline std::string RubyArgs::string(int i) {
|
350
|
+
return stringWithDefault(i, signature.params[i].default_string);
|
351
|
+
}
|
352
|
+
|
353
|
+
inline std::string RubyArgs::stringWithDefault(int i, const std::string& default_str) {
|
354
|
+
if (!args[i]) return default_str;
|
340
355
|
return Rice::detail::From_Ruby<std::string>().convert(args[i]);
|
341
356
|
}
|
342
357
|
|
data/ext/torch/utils.h
CHANGED
data/lib/torch/tensor.rb
CHANGED
@@ -177,11 +177,6 @@ module Torch
|
|
177
177
|
_random!(*args)
|
178
178
|
end
|
179
179
|
|
180
|
-
# center option
|
181
|
-
def stft(*args)
|
182
|
-
Torch.stft(*args)
|
183
|
-
end
|
184
|
-
|
185
180
|
def dup
|
186
181
|
Torch.no_grad do
|
187
182
|
clone
|
@@ -199,5 +194,13 @@ module Torch
|
|
199
194
|
def real
|
200
195
|
Torch.real(self)
|
201
196
|
end
|
197
|
+
|
198
|
+
# Ruby numeric coercion hook, so a plain number can sit on the left-hand side
# of an arithmetic operator with a tensor. A Numeric operand is wrapped with
# Torch.tensor and returned paired with self. Anything else raises TypeError.
def coerce(other)
  raise TypeError, "#{self.class} can't be coerced into #{other.class}" unless other.is_a?(Numeric)

  [Torch.tensor(other), self]
end
|
202
205
|
end
|
203
206
|
end
|
data/lib/torch/version.rb
CHANGED
data/lib/torch.rb
CHANGED
@@ -249,6 +249,7 @@ module Torch
|
|
249
249
|
complex_float: 9,
|
250
250
|
complex64: 9,
|
251
251
|
complex_double: 10,
|
252
|
+
cdouble: 10,
|
252
253
|
complex128: 10,
|
253
254
|
bool: 11,
|
254
255
|
qint8: 12,
|
@@ -421,18 +422,6 @@ module Torch
|
|
421
422
|
_tensor(data, size, tensor_options(**options))
|
422
423
|
end
|
423
424
|
|
424
|
-
# center option
|
425
|
-
def stft(input, n_fft, hop_length: nil, win_length: nil, window: nil, center: true, pad_mode: "reflect", normalized: false, onesided: true, return_complex: nil)
|
426
|
-
if center
|
427
|
-
signal_dim = input.dim
|
428
|
-
extended_shape = [1] * (3 - signal_dim) + input.size
|
429
|
-
pad = n_fft.div(2).to_i
|
430
|
-
input = NN::F.pad(input.view(extended_shape), [pad, pad], mode: pad_mode)
|
431
|
-
input = input.view(input.shape[-signal_dim..-1])
|
432
|
-
end
|
433
|
-
_stft(input, n_fft, hop_length, win_length, window, normalized, onesided, return_complex)
|
434
|
-
end
|
435
|
-
|
436
425
|
private
|
437
426
|
|
438
427
|
def to_ivalue(obj)
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torch-rb
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.
|
4
|
+
version: 0.11.1
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2022-
|
11
|
+
date: 2022-07-06 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: rice
|
@@ -222,7 +222,7 @@ required_ruby_version: !ruby/object:Gem::Requirement
|
|
222
222
|
requirements:
|
223
223
|
- - ">="
|
224
224
|
- !ruby/object:Gem::Version
|
225
|
-
version: '2.
|
225
|
+
version: '2.7'
|
226
226
|
required_rubygems_version: !ruby/object:Gem::Requirement
|
227
227
|
requirements:
|
228
228
|
- - ">="
|