torch-rb 0.4.2 → 0.6.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +26 -0
- data/README.md +13 -3
- data/codegen/generate_functions.rb +20 -13
- data/codegen/native_functions.yaml +4129 -1521
- data/ext/torch/cuda.cpp +14 -0
- data/ext/torch/device.cpp +21 -0
- data/ext/torch/ext.cpp +17 -623
- data/ext/torch/extconf.rb +0 -1
- data/ext/torch/ivalue.cpp +134 -0
- data/ext/torch/nn.cpp +114 -0
- data/ext/torch/nn_functions.h +1 -1
- data/ext/torch/random.cpp +22 -0
- data/ext/torch/ruby_arg_parser.cpp +1 -1
- data/ext/torch/ruby_arg_parser.h +47 -7
- data/ext/torch/templates.h +3 -2
- data/ext/torch/tensor.cpp +307 -0
- data/ext/torch/tensor_functions.h +1 -1
- data/ext/torch/torch.cpp +86 -0
- data/ext/torch/torch_functions.h +1 -1
- data/ext/torch/utils.h +8 -1
- data/ext/torch/wrap_outputs.h +7 -0
- data/lib/torch.rb +14 -17
- data/lib/torch/nn/linear.rb +2 -0
- data/lib/torch/nn/module.rb +107 -21
- data/lib/torch/nn/parameter.rb +1 -1
- data/lib/torch/optim/adadelta.rb +2 -2
- data/lib/torch/optim/adagrad.rb +2 -2
- data/lib/torch/optim/adam.rb +2 -2
- data/lib/torch/optim/adamax.rb +1 -1
- data/lib/torch/optim/adamw.rb +2 -2
- data/lib/torch/optim/rmsprop.rb +3 -3
- data/lib/torch/optim/rprop.rb +1 -1
- data/lib/torch/tensor.rb +9 -0
- data/lib/torch/utils/data/data_loader.rb +1 -1
- data/lib/torch/version.rb +1 -1
- metadata +12 -89
@@ -0,0 +1,307 @@
|
|
1
|
+
#include <torch/torch.h>
|
2
|
+
|
3
|
+
#include <rice/Constructor.hpp>
|
4
|
+
#include <rice/Module.hpp>
|
5
|
+
|
6
|
+
#include "tensor_functions.h"
|
7
|
+
#include "ruby_arg_parser.h"
|
8
|
+
#include "templates.h"
|
9
|
+
#include "utils.h"
|
10
|
+
|
11
|
+
using namespace Rice;
|
12
|
+
using torch::indexing::TensorIndex;
|
13
|
+
|
14
|
+
Class rb_cTensor;
|
15
|
+
|
16
|
+
std::vector<TensorIndex> index_vector(Array a) {
|
17
|
+
Object obj;
|
18
|
+
|
19
|
+
std::vector<TensorIndex> indices;
|
20
|
+
indices.reserve(a.size());
|
21
|
+
|
22
|
+
for (long i = 0; i < a.size(); i++) {
|
23
|
+
obj = a[i];
|
24
|
+
|
25
|
+
if (obj.is_instance_of(rb_cInteger)) {
|
26
|
+
indices.push_back(from_ruby<int64_t>(obj));
|
27
|
+
} else if (obj.is_instance_of(rb_cRange)) {
|
28
|
+
torch::optional<int64_t> start_index = torch::nullopt;
|
29
|
+
torch::optional<int64_t> stop_index = torch::nullopt;
|
30
|
+
|
31
|
+
Object begin = obj.call("begin");
|
32
|
+
if (!begin.is_nil()) {
|
33
|
+
start_index = from_ruby<int64_t>(begin);
|
34
|
+
}
|
35
|
+
|
36
|
+
Object end = obj.call("end");
|
37
|
+
if (!end.is_nil()) {
|
38
|
+
stop_index = from_ruby<int64_t>(end);
|
39
|
+
}
|
40
|
+
|
41
|
+
Object exclude_end = obj.call("exclude_end?");
|
42
|
+
if (stop_index.has_value() && !exclude_end) {
|
43
|
+
if (stop_index.value() == -1) {
|
44
|
+
stop_index = torch::nullopt;
|
45
|
+
} else {
|
46
|
+
stop_index = stop_index.value() + 1;
|
47
|
+
}
|
48
|
+
}
|
49
|
+
|
50
|
+
indices.push_back(torch::indexing::Slice(start_index, stop_index));
|
51
|
+
} else if (obj.is_instance_of(rb_cTensor)) {
|
52
|
+
indices.push_back(from_ruby<Tensor>(obj));
|
53
|
+
} else if (obj.is_nil()) {
|
54
|
+
indices.push_back(torch::indexing::None);
|
55
|
+
} else if (obj == True || obj == False) {
|
56
|
+
indices.push_back(from_ruby<bool>(obj));
|
57
|
+
} else {
|
58
|
+
throw Exception(rb_eArgError, "Unsupported index type: %s", rb_obj_classname(obj));
|
59
|
+
}
|
60
|
+
}
|
61
|
+
return indices;
|
62
|
+
}
|
63
|
+
|
64
|
+
// hack (removes inputs argument)
|
65
|
+
// https://github.com/pytorch/pytorch/commit/2e5bfa9824f549be69a28e4705a72b4cf8a4c519
|
66
|
+
// TODO add support for inputs argument
|
67
|
+
// _backward
|
68
|
+
static VALUE tensor__backward(int argc, VALUE* argv, VALUE self_)
|
69
|
+
{
|
70
|
+
HANDLE_TH_ERRORS
|
71
|
+
Tensor& self = from_ruby<Tensor&>(self_);
|
72
|
+
static RubyArgParser parser({
|
73
|
+
"_backward(Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False)"
|
74
|
+
});
|
75
|
+
ParsedArgs<4> parsed_args;
|
76
|
+
auto _r = parser.parse(self_, argc, argv, parsed_args);
|
77
|
+
// _backward(Tensor self, Tensor[] inputs, Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False) -> ()
|
78
|
+
auto dispatch__backward = [](const Tensor & self, TensorList inputs, const OptionalTensor & gradient, c10::optional<bool> retain_graph, bool create_graph) -> void {
|
79
|
+
// in future, release GVL
|
80
|
+
self._backward(inputs, gradient, retain_graph, create_graph);
|
81
|
+
};
|
82
|
+
dispatch__backward(self, {}, _r.optionalTensor(0), _r.toBoolOptional(1), _r.toBool(2));
|
83
|
+
RETURN_NIL
|
84
|
+
END_HANDLE_TH_ERRORS
|
85
|
+
}
|
86
|
+
|
87
|
+
void init_tensor(Rice::Module& m) {
|
88
|
+
rb_cTensor = Rice::define_class_under<torch::Tensor>(m, "Tensor");
|
89
|
+
rb_cTensor.add_handler<torch::Error>(handle_error);
|
90
|
+
add_tensor_functions(rb_cTensor);
|
91
|
+
THPVariableClass = rb_cTensor.value();
|
92
|
+
|
93
|
+
rb_define_method(rb_cTensor, "backward", (VALUE (*)(...)) tensor__backward, -1);
|
94
|
+
|
95
|
+
rb_cTensor
|
96
|
+
.define_method("cuda?", &torch::Tensor::is_cuda)
|
97
|
+
.define_method("sparse?", &torch::Tensor::is_sparse)
|
98
|
+
.define_method("quantized?", &torch::Tensor::is_quantized)
|
99
|
+
.define_method("dim", &torch::Tensor::dim)
|
100
|
+
.define_method("numel", &torch::Tensor::numel)
|
101
|
+
.define_method("element_size", &torch::Tensor::element_size)
|
102
|
+
.define_method("requires_grad", &torch::Tensor::requires_grad)
|
103
|
+
.define_method(
|
104
|
+
"_size",
|
105
|
+
*[](Tensor& self, int64_t dim) {
|
106
|
+
return self.size(dim);
|
107
|
+
})
|
108
|
+
.define_method(
|
109
|
+
"_stride",
|
110
|
+
*[](Tensor& self, int64_t dim) {
|
111
|
+
return self.stride(dim);
|
112
|
+
})
|
113
|
+
// in C++ for performance
|
114
|
+
.define_method(
|
115
|
+
"shape",
|
116
|
+
*[](Tensor& self) {
|
117
|
+
Array a;
|
118
|
+
for (auto &size : self.sizes()) {
|
119
|
+
a.push(size);
|
120
|
+
}
|
121
|
+
return a;
|
122
|
+
})
|
123
|
+
.define_method(
|
124
|
+
"_strides",
|
125
|
+
*[](Tensor& self) {
|
126
|
+
Array a;
|
127
|
+
for (auto &stride : self.strides()) {
|
128
|
+
a.push(stride);
|
129
|
+
}
|
130
|
+
return a;
|
131
|
+
})
|
132
|
+
.define_method(
|
133
|
+
"_index",
|
134
|
+
*[](Tensor& self, Array indices) {
|
135
|
+
auto vec = index_vector(indices);
|
136
|
+
return self.index(vec);
|
137
|
+
})
|
138
|
+
.define_method(
|
139
|
+
"_index_put_custom",
|
140
|
+
*[](Tensor& self, Array indices, torch::Tensor& value) {
|
141
|
+
auto vec = index_vector(indices);
|
142
|
+
return self.index_put_(vec, value);
|
143
|
+
})
|
144
|
+
.define_method(
|
145
|
+
"contiguous?",
|
146
|
+
*[](Tensor& self) {
|
147
|
+
return self.is_contiguous();
|
148
|
+
})
|
149
|
+
.define_method(
|
150
|
+
"_requires_grad!",
|
151
|
+
*[](Tensor& self, bool requires_grad) {
|
152
|
+
return self.set_requires_grad(requires_grad);
|
153
|
+
})
|
154
|
+
.define_method(
|
155
|
+
"grad",
|
156
|
+
*[](Tensor& self) {
|
157
|
+
auto grad = self.grad();
|
158
|
+
return grad.defined() ? to_ruby<torch::Tensor>(grad) : Nil;
|
159
|
+
})
|
160
|
+
.define_method(
|
161
|
+
"grad=",
|
162
|
+
*[](Tensor& self, torch::Tensor& grad) {
|
163
|
+
self.mutable_grad() = grad;
|
164
|
+
})
|
165
|
+
.define_method(
|
166
|
+
"_dtype",
|
167
|
+
*[](Tensor& self) {
|
168
|
+
return (int) at::typeMetaToScalarType(self.dtype());
|
169
|
+
})
|
170
|
+
.define_method(
|
171
|
+
"_type",
|
172
|
+
*[](Tensor& self, int dtype) {
|
173
|
+
return self.toType((torch::ScalarType) dtype);
|
174
|
+
})
|
175
|
+
.define_method(
|
176
|
+
"_layout",
|
177
|
+
*[](Tensor& self) {
|
178
|
+
std::stringstream s;
|
179
|
+
s << self.layout();
|
180
|
+
return s.str();
|
181
|
+
})
|
182
|
+
.define_method(
|
183
|
+
"device",
|
184
|
+
*[](Tensor& self) {
|
185
|
+
std::stringstream s;
|
186
|
+
s << self.device();
|
187
|
+
return s.str();
|
188
|
+
})
|
189
|
+
.define_method(
|
190
|
+
"_data_str",
|
191
|
+
*[](Tensor& self) {
|
192
|
+
Tensor tensor = self;
|
193
|
+
|
194
|
+
// move to CPU to get data
|
195
|
+
if (tensor.device().type() != torch::kCPU) {
|
196
|
+
torch::Device device("cpu");
|
197
|
+
tensor = tensor.to(device);
|
198
|
+
}
|
199
|
+
|
200
|
+
if (!tensor.is_contiguous()) {
|
201
|
+
tensor = tensor.contiguous();
|
202
|
+
}
|
203
|
+
|
204
|
+
auto data_ptr = (const char *) tensor.data_ptr();
|
205
|
+
return std::string(data_ptr, tensor.numel() * tensor.element_size());
|
206
|
+
})
|
207
|
+
// for TorchVision
|
208
|
+
.define_method(
|
209
|
+
"_data_ptr",
|
210
|
+
*[](Tensor& self) {
|
211
|
+
return reinterpret_cast<uintptr_t>(self.data_ptr());
|
212
|
+
})
|
213
|
+
// TODO figure out a better way to do this
|
214
|
+
.define_method(
|
215
|
+
"_flat_data",
|
216
|
+
*[](Tensor& self) {
|
217
|
+
Tensor tensor = self;
|
218
|
+
|
219
|
+
// move to CPU to get data
|
220
|
+
if (tensor.device().type() != torch::kCPU) {
|
221
|
+
torch::Device device("cpu");
|
222
|
+
tensor = tensor.to(device);
|
223
|
+
}
|
224
|
+
|
225
|
+
Array a;
|
226
|
+
auto dtype = tensor.dtype();
|
227
|
+
|
228
|
+
Tensor view = tensor.reshape({tensor.numel()});
|
229
|
+
|
230
|
+
// TODO DRY if someone knows C++
|
231
|
+
if (dtype == torch::kByte) {
|
232
|
+
for (int i = 0; i < tensor.numel(); i++) {
|
233
|
+
a.push(view[i].item().to<uint8_t>());
|
234
|
+
}
|
235
|
+
} else if (dtype == torch::kChar) {
|
236
|
+
for (int i = 0; i < tensor.numel(); i++) {
|
237
|
+
a.push(to_ruby<int>(view[i].item().to<int8_t>()));
|
238
|
+
}
|
239
|
+
} else if (dtype == torch::kShort) {
|
240
|
+
for (int i = 0; i < tensor.numel(); i++) {
|
241
|
+
a.push(view[i].item().to<int16_t>());
|
242
|
+
}
|
243
|
+
} else if (dtype == torch::kInt) {
|
244
|
+
for (int i = 0; i < tensor.numel(); i++) {
|
245
|
+
a.push(view[i].item().to<int32_t>());
|
246
|
+
}
|
247
|
+
} else if (dtype == torch::kLong) {
|
248
|
+
for (int i = 0; i < tensor.numel(); i++) {
|
249
|
+
a.push(view[i].item().to<int64_t>());
|
250
|
+
}
|
251
|
+
} else if (dtype == torch::kFloat) {
|
252
|
+
for (int i = 0; i < tensor.numel(); i++) {
|
253
|
+
a.push(view[i].item().to<float>());
|
254
|
+
}
|
255
|
+
} else if (dtype == torch::kDouble) {
|
256
|
+
for (int i = 0; i < tensor.numel(); i++) {
|
257
|
+
a.push(view[i].item().to<double>());
|
258
|
+
}
|
259
|
+
} else if (dtype == torch::kBool) {
|
260
|
+
for (int i = 0; i < tensor.numel(); i++) {
|
261
|
+
a.push(view[i].item().to<bool>() ? True : False);
|
262
|
+
}
|
263
|
+
} else {
|
264
|
+
throw std::runtime_error("Unsupported type");
|
265
|
+
}
|
266
|
+
return a;
|
267
|
+
})
|
268
|
+
.define_method(
|
269
|
+
"_to",
|
270
|
+
*[](Tensor& self, torch::Device device, int dtype, bool non_blocking, bool copy) {
|
271
|
+
return self.to(device, (torch::ScalarType) dtype, non_blocking, copy);
|
272
|
+
});
|
273
|
+
|
274
|
+
Rice::define_class_under<torch::TensorOptions>(m, "TensorOptions")
|
275
|
+
.add_handler<torch::Error>(handle_error)
|
276
|
+
.define_constructor(Rice::Constructor<torch::TensorOptions>())
|
277
|
+
.define_method(
|
278
|
+
"dtype",
|
279
|
+
*[](torch::TensorOptions& self, int dtype) {
|
280
|
+
return self.dtype((torch::ScalarType) dtype);
|
281
|
+
})
|
282
|
+
.define_method(
|
283
|
+
"layout",
|
284
|
+
*[](torch::TensorOptions& self, const std::string& layout) {
|
285
|
+
torch::Layout l;
|
286
|
+
if (layout == "strided") {
|
287
|
+
l = torch::kStrided;
|
288
|
+
} else if (layout == "sparse") {
|
289
|
+
l = torch::kSparse;
|
290
|
+
throw std::runtime_error("Sparse layout not supported yet");
|
291
|
+
} else {
|
292
|
+
throw std::runtime_error("Unsupported layout: " + layout);
|
293
|
+
}
|
294
|
+
return self.layout(l);
|
295
|
+
})
|
296
|
+
.define_method(
|
297
|
+
"device",
|
298
|
+
*[](torch::TensorOptions& self, const std::string& device) {
|
299
|
+
torch::Device d(device);
|
300
|
+
return self.device(d);
|
301
|
+
})
|
302
|
+
.define_method(
|
303
|
+
"requires_grad",
|
304
|
+
*[](torch::TensorOptions& self, bool requires_grad) {
|
305
|
+
return self.requires_grad(requires_grad);
|
306
|
+
});
|
307
|
+
}
|
data/ext/torch/torch.cpp
ADDED
@@ -0,0 +1,86 @@
|
|
1
|
+
#include <torch/torch.h>
|
2
|
+
|
3
|
+
#include <rice/Module.hpp>
|
4
|
+
|
5
|
+
#include "torch_functions.h"
|
6
|
+
#include "templates.h"
|
7
|
+
#include "utils.h"
|
8
|
+
|
9
|
+
void init_torch(Rice::Module& m) {
|
10
|
+
m.add_handler<torch::Error>(handle_error);
|
11
|
+
add_torch_functions(m);
|
12
|
+
m.define_singleton_method(
|
13
|
+
"grad_enabled?",
|
14
|
+
*[]() {
|
15
|
+
return torch::GradMode::is_enabled();
|
16
|
+
})
|
17
|
+
.define_singleton_method(
|
18
|
+
"_set_grad_enabled",
|
19
|
+
*[](bool enabled) {
|
20
|
+
torch::GradMode::set_enabled(enabled);
|
21
|
+
})
|
22
|
+
.define_singleton_method(
|
23
|
+
"manual_seed",
|
24
|
+
*[](uint64_t seed) {
|
25
|
+
return torch::manual_seed(seed);
|
26
|
+
})
|
27
|
+
// config
|
28
|
+
.define_singleton_method(
|
29
|
+
"show_config",
|
30
|
+
*[] {
|
31
|
+
return torch::show_config();
|
32
|
+
})
|
33
|
+
.define_singleton_method(
|
34
|
+
"parallel_info",
|
35
|
+
*[] {
|
36
|
+
return torch::get_parallel_info();
|
37
|
+
})
|
38
|
+
// begin operations
|
39
|
+
.define_singleton_method(
|
40
|
+
"_save",
|
41
|
+
*[](const torch::IValue &value) {
|
42
|
+
auto v = torch::pickle_save(value);
|
43
|
+
std::string str(v.begin(), v.end());
|
44
|
+
return str;
|
45
|
+
})
|
46
|
+
.define_singleton_method(
|
47
|
+
"_load",
|
48
|
+
*[](const std::string &s) {
|
49
|
+
std::vector<char> v;
|
50
|
+
std::copy(s.begin(), s.end(), std::back_inserter(v));
|
51
|
+
// https://github.com/pytorch/pytorch/issues/20356#issuecomment-567663701
|
52
|
+
return torch::pickle_load(v);
|
53
|
+
})
|
54
|
+
.define_singleton_method(
|
55
|
+
"_from_blob",
|
56
|
+
*[](Rice::String s, std::vector<int64_t> size, const torch::TensorOptions &options) {
|
57
|
+
void *data = const_cast<char *>(s.c_str());
|
58
|
+
return torch::from_blob(data, size, options);
|
59
|
+
})
|
60
|
+
.define_singleton_method(
|
61
|
+
"_tensor",
|
62
|
+
*[](Rice::Array a, std::vector<int64_t> size, const torch::TensorOptions &options) {
|
63
|
+
auto dtype = options.dtype();
|
64
|
+
torch::Tensor t;
|
65
|
+
if (dtype == torch::kBool) {
|
66
|
+
std::vector<uint8_t> vec;
|
67
|
+
for (long i = 0; i < a.size(); i++) {
|
68
|
+
vec.push_back(from_ruby<bool>(a[i]));
|
69
|
+
}
|
70
|
+
t = torch::tensor(vec, options);
|
71
|
+
} else {
|
72
|
+
std::vector<float> vec;
|
73
|
+
for (long i = 0; i < a.size(); i++) {
|
74
|
+
vec.push_back(from_ruby<float>(a[i]));
|
75
|
+
}
|
76
|
+
// hack for requires_grad error
|
77
|
+
if (options.requires_grad()) {
|
78
|
+
t = torch::tensor(vec, options.requires_grad(c10::nullopt));
|
79
|
+
t.set_requires_grad(true);
|
80
|
+
} else {
|
81
|
+
t = torch::tensor(vec, options);
|
82
|
+
}
|
83
|
+
}
|
84
|
+
return t.reshape(size);
|
85
|
+
});
|
86
|
+
}
|
data/ext/torch/torch_functions.h
CHANGED
data/ext/torch/utils.h
CHANGED
@@ -1,13 +1,20 @@
|
|
1
1
|
#pragma once
|
2
2
|
|
3
|
+
#include <rice/Exception.hpp>
|
3
4
|
#include <rice/Symbol.hpp>
|
4
5
|
|
6
|
+
// TODO find better place
|
7
|
+
inline void handle_error(torch::Error const & ex)
|
8
|
+
{
|
9
|
+
throw Rice::Exception(rb_eRuntimeError, ex.what_without_backtrace());
|
10
|
+
}
|
11
|
+
|
5
12
|
// keep THP prefix for now to make it easier to compare code
|
6
13
|
|
7
14
|
extern VALUE THPVariableClass;
|
8
15
|
|
9
16
|
inline VALUE THPUtils_internSymbol(const std::string& str) {
|
10
|
-
return Symbol(str);
|
17
|
+
return Rice::Symbol(str);
|
11
18
|
}
|
12
19
|
|
13
20
|
inline std::string THPUtils_unpackSymbol(VALUE obj) {
|
data/ext/torch/wrap_outputs.h
CHANGED
data/lib/torch.rb
CHANGED
@@ -261,6 +261,8 @@ module Torch
|
|
261
261
|
Torch._from_blob(bytes, [bytes.bytesize], TensorOptions.new.dtype(DTYPE_TO_ENUM[dtype]))
|
262
262
|
elsif args.size == 1 && args.first.is_a?(Array)
|
263
263
|
Torch.tensor(args.first, dtype: dtype, device: device)
|
264
|
+
elsif args.size == 0
|
265
|
+
Torch.empty(0, dtype: dtype, device: device)
|
264
266
|
else
|
265
267
|
Torch.empty(*args, dtype: dtype, device: device)
|
266
268
|
end
|
@@ -335,25 +337,24 @@ module Torch
|
|
335
337
|
}
|
336
338
|
end
|
337
339
|
|
338
|
-
def no_grad
|
339
|
-
|
340
|
-
|
341
|
-
|
342
|
-
|
343
|
-
|
344
|
-
_set_grad_enabled(previous_value)
|
345
|
-
end
|
340
|
+
def no_grad(&block)
|
341
|
+
grad_enabled(false, &block)
|
342
|
+
end
|
343
|
+
|
344
|
+
def enable_grad(&block)
|
345
|
+
grad_enabled(true, &block)
|
346
346
|
end
|
347
347
|
|
348
|
-
def
|
348
|
+
def grad_enabled(value)
|
349
349
|
previous_value = grad_enabled?
|
350
350
|
begin
|
351
|
-
_set_grad_enabled(
|
351
|
+
_set_grad_enabled(value)
|
352
352
|
yield
|
353
353
|
ensure
|
354
354
|
_set_grad_enabled(previous_value)
|
355
355
|
end
|
356
356
|
end
|
357
|
+
alias_method :set_grad_enabled, :grad_enabled
|
357
358
|
|
358
359
|
def device(str)
|
359
360
|
Device.new(str)
|
@@ -434,7 +435,8 @@ module Torch
|
|
434
435
|
zeros(input.size, **like_options(input, options))
|
435
436
|
end
|
436
437
|
|
437
|
-
|
438
|
+
# center option
|
439
|
+
def stft(input, n_fft, hop_length: nil, win_length: nil, window: nil, center: true, pad_mode: "reflect", normalized: false, onesided: true, return_complex: nil)
|
438
440
|
if center
|
439
441
|
signal_dim = input.dim
|
440
442
|
extended_shape = [1] * (3 - signal_dim) + input.size
|
@@ -442,12 +444,7 @@ module Torch
|
|
442
444
|
input = NN::F.pad(input.view(extended_shape), [pad, pad], mode: pad_mode)
|
443
445
|
input = input.view(input.shape[-signal_dim..-1])
|
444
446
|
end
|
445
|
-
_stft(input, n_fft, hop_length, win_length, window, normalized, onesided)
|
446
|
-
end
|
447
|
-
|
448
|
-
def clamp(tensor, min, max)
|
449
|
-
tensor = _clamp_min(tensor, min)
|
450
|
-
_clamp_max(tensor, max)
|
447
|
+
_stft(input, n_fft, hop_length, win_length, window, normalized, onesided, return_complex)
|
451
448
|
end
|
452
449
|
|
453
450
|
private
|