torchaudio 0.1.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +7 -0
- data/CHANGELOG.md +3 -0
- data/LICENSE.txt +26 -0
- data/README.md +93 -0
- data/ext/torchaudio/csrc/register.cpp +65 -0
- data/ext/torchaudio/csrc/sox.cpp +361 -0
- data/ext/torchaudio/csrc/sox.h +71 -0
- data/ext/torchaudio/csrc/sox_effects.cpp +54 -0
- data/ext/torchaudio/csrc/sox_effects.h +18 -0
- data/ext/torchaudio/csrc/sox_io.cpp +170 -0
- data/ext/torchaudio/csrc/sox_io.h +41 -0
- data/ext/torchaudio/csrc/sox_utils.cpp +245 -0
- data/ext/torchaudio/csrc/sox_utils.h +100 -0
- data/ext/torchaudio/ext.cpp +33 -0
- data/ext/torchaudio/extconf.rb +81 -0
- data/lib/torchaudio.rb +95 -0
- data/lib/torchaudio/datasets/utils.rb +92 -0
- data/lib/torchaudio/datasets/yesno.rb +59 -0
- data/lib/torchaudio/version.rb +3 -0
- metadata +145 -0
@@ -0,0 +1,71 @@
|
|
1
|
+
#include <sox.h>
|
2
|
+
|
3
|
+
#include <string>
|
4
|
+
#include <tuple>
|
5
|
+
#include <vector>
|
6
|
+
#include <unistd.h>
|
7
|
+
|
8
|
+
// same as <torch/extension.h> without <torch/python.h>
|
9
|
+
#include <torch/all.h>
|
10
|
+
|
11
|
+
namespace at {
|
12
|
+
struct Tensor;
|
13
|
+
} // namespace at
|
14
|
+
|
15
|
+
namespace torch { namespace audio {
|
16
|
+
|
17
|
+
/// Reads the audio file named by `file_name` into the `output` `Tensor` and
|
18
|
+
/// returns the sample rate of the audio file.
|
19
|
+
/// Throws `std::runtime_error` if the audio file could not be opened, or an
|
20
|
+
/// error occurred during reading of the audio data.
|
21
|
+
int read_audio_file(
|
22
|
+
const std::string& file_name,
|
23
|
+
at::Tensor output,
|
24
|
+
bool ch_first,
|
25
|
+
int64_t nframes,
|
26
|
+
int64_t offset,
|
27
|
+
sox_signalinfo_t* si,
|
28
|
+
sox_encodinginfo_t* ei,
|
29
|
+
const char* ft);
|
30
|
+
|
31
|
+
/// Writes the data of a `Tensor` into an audio file at the given `file_name`, with
|
32
|
+
/// a certain extension (e.g. `wav` or `mp3`) and sample rate.
|
33
|
+
/// Throws `std::runtime_error` when the audio file could not be opened for
|
34
|
+
/// writing, or an error occurred during writing of the audio data.
|
35
|
+
void write_audio_file(
|
36
|
+
const std::string& file_name,
|
37
|
+
const at::Tensor& tensor,
|
38
|
+
sox_signalinfo_t* si,
|
39
|
+
sox_encodinginfo_t* ei,
|
40
|
+
const char* file_type);
|
41
|
+
|
42
|
+
/// Reads the audio file named by `file_name` and returns a tuple of
|
43
|
+
/// sox_signalinfo_t and sox_encodinginfo_t, which contain information about
|
44
|
+
/// the audio file such as sample rate, length, bit precision, encoding and more.
|
45
|
+
/// Throws `std::runtime_error` if the audio file could not be opened, or an
|
46
|
+
/// error occurred during reading of the audio data.
|
47
|
+
std::tuple<sox_signalinfo_t, sox_encodinginfo_t> get_info(
|
48
|
+
const std::string& file_name);
|
49
|
+
|
50
|
+
// Struct for build_flow_effects function
|
51
|
+
/// Effect name plus its option strings, as passed to `build_flow_effects`.
struct SoxEffect {
  /// Name of the effect; empty after default construction.
  std::string ename;
  /// Option strings; default construction yields exactly one empty string.
  std::vector<std::string> eopts;

  SoxEffect() : ename(), eopts(1, std::string()) {}
};
|
56
|
+
|
57
|
+
/// Build a SoX chain, flow the effects, and capture the results in a tensor.
|
58
|
+
/// An audio file named by `file_name` flows through an effects chain given
|
59
|
+
/// by a list of effects and effect options to an output buffer which is encoded
|
60
|
+
/// into memory to a target signal type and target signal encoding. The resulting
|
61
|
+
/// buffer is then placed into a tensor. This function returns the output tensor
|
62
|
+
/// and the sample rate of the output tensor.
|
63
|
+
int build_flow_effects(const std::string& file_name,
|
64
|
+
at::Tensor otensor,
|
65
|
+
bool ch_first,
|
66
|
+
sox_signalinfo_t* target_signal,
|
67
|
+
sox_encodinginfo_t* target_encoding,
|
68
|
+
const char* file_type,
|
69
|
+
std::vector<SoxEffect> pyeffs,
|
70
|
+
int max_num_eopts);
|
71
|
+
}} // namespace torch::audio
|
@@ -0,0 +1,54 @@
|
|
1
|
+
#include <sox.h>
|
2
|
+
#include <torchaudio/csrc/sox_effects.h>
|
3
|
+
|
4
|
+
using namespace torch::indexing;
|
5
|
+
|
6
|
+
namespace torchaudio {
|
7
|
+
namespace sox_effects {
|
8
|
+
|
9
|
+
namespace {
|
10
|
+
|
11
|
+
// States tracked for the SoX effects resource: never initialized, initialized,
// or already shut down.
enum SoxEffectsResourceState { NotInitialized, Initialized, ShutDown };
// Current state; updated by initialize_sox_effects() / shutdown_sox_effects().
SoxEffectsResourceState SOX_RESOURCE_STATE = NotInitialized;
|
13
|
+
|
14
|
+
} // namespace
|
15
|
+
|
16
|
+
void initialize_sox_effects() {
|
17
|
+
if (SOX_RESOURCE_STATE == ShutDown) {
|
18
|
+
throw std::runtime_error(
|
19
|
+
"SoX Effects has been shut down. Cannot initialize again.");
|
20
|
+
}
|
21
|
+
if (SOX_RESOURCE_STATE == NotInitialized) {
|
22
|
+
if (sox_init() != SOX_SUCCESS) {
|
23
|
+
throw std::runtime_error("Failed to initialize sox effects.");
|
24
|
+
};
|
25
|
+
SOX_RESOURCE_STATE = Initialized;
|
26
|
+
}
|
27
|
+
};
|
28
|
+
|
29
|
+
void shutdown_sox_effects() {
|
30
|
+
if (SOX_RESOURCE_STATE == NotInitialized) {
|
31
|
+
throw std::runtime_error(
|
32
|
+
"SoX Effects is not initialized. Cannot shutdown.");
|
33
|
+
}
|
34
|
+
if (SOX_RESOURCE_STATE == Initialized) {
|
35
|
+
if (sox_quit() != SOX_SUCCESS) {
|
36
|
+
throw std::runtime_error("Failed to initialize sox effects.");
|
37
|
+
};
|
38
|
+
SOX_RESOURCE_STATE = ShutDown;
|
39
|
+
}
|
40
|
+
}
|
41
|
+
|
42
|
+
std::vector<std::string> list_effects() {
|
43
|
+
std::vector<std::string> names;
|
44
|
+
const sox_effect_fn_t* fns = sox_get_effect_fns();
|
45
|
+
for (int i = 0; fns[i]; ++i) {
|
46
|
+
const sox_effect_handler_t* handler = fns[i]();
|
47
|
+
if (handler && handler->name)
|
48
|
+
names.push_back(handler->name);
|
49
|
+
}
|
50
|
+
return names;
|
51
|
+
}
|
52
|
+
|
53
|
+
} // namespace sox_effects
|
54
|
+
} // namespace torchaudio
|
@@ -0,0 +1,18 @@
|
|
1
|
+
#ifndef TORCHAUDIO_SOX_EFFECTS_H
|
2
|
+
#define TORCHAUDIO_SOX_EFFECTS_H
|
3
|
+
|
4
|
+
#include <torch/script.h>
|
5
|
+
|
6
|
+
namespace torchaudio {
|
7
|
+
namespace sox_effects {
|
8
|
+
|
9
|
+
/// Initializes SoX effects. Throws `std::runtime_error` on failure or when
/// effects have already been shut down.
void initialize_sox_effects();

/// Shuts down SoX effects. Throws `std::runtime_error` on failure or when
/// effects were never initialized.
void shutdown_sox_effects();

/// Returns the names of the available SoX effects.
std::vector<std::string> list_effects();
|
14
|
+
|
15
|
+
} // namespace sox_effects
|
16
|
+
} // namespace torchaudio
|
17
|
+
|
18
|
+
#endif
|
@@ -0,0 +1,170 @@
|
|
1
|
+
#include <sox.h>
|
2
|
+
#include <torchaudio/csrc/sox_io.h>
|
3
|
+
#include <torchaudio/csrc/sox_utils.h>
|
4
|
+
|
5
|
+
using namespace torch::indexing;
|
6
|
+
using namespace torchaudio::sox_utils;
|
7
|
+
|
8
|
+
namespace torchaudio {
|
9
|
+
namespace sox_io {
|
10
|
+
|
11
|
+
SignalInfo::SignalInfo(
|
12
|
+
const int64_t sample_rate_,
|
13
|
+
const int64_t num_channels_,
|
14
|
+
const int64_t num_frames_)
|
15
|
+
: sample_rate(sample_rate_),
|
16
|
+
num_channels(num_channels_),
|
17
|
+
num_frames(num_frames_){};
|
18
|
+
|
19
|
+
int64_t SignalInfo::getSampleRate() const {
|
20
|
+
return sample_rate;
|
21
|
+
}
|
22
|
+
|
23
|
+
int64_t SignalInfo::getNumChannels() const {
|
24
|
+
return num_channels;
|
25
|
+
}
|
26
|
+
|
27
|
+
int64_t SignalInfo::getNumFrames() const {
|
28
|
+
return num_frames;
|
29
|
+
}
|
30
|
+
|
31
|
+
/// Opens the file at `path` with `sox_open_read` and reports its sample rate,
/// channel count and frame count (signal length divided by channel count).
/// Throws `std::runtime_error` if the file cannot be opened.
c10::intrusive_ptr<SignalInfo> get_info(const std::string& path) {
  SoxFormat sf(sox_open_read(
      path.c_str(),
      /*signal=*/nullptr,
      /*encoding=*/nullptr,
      /*filetype=*/nullptr));

  sox_format_t* const handle = static_cast<sox_format_t*>(sf);
  if (handle == nullptr) {
    throw std::runtime_error("Error opening audio file");
  }

  const auto rate = static_cast<int64_t>(sf->signal.rate);
  const auto channels = static_cast<int64_t>(sf->signal.channels);
  const auto frames =
      static_cast<int64_t>(sf->signal.length / sf->signal.channels);
  return c10::make_intrusive<SignalInfo>(rate, channels, frames);
}
|
47
|
+
|
48
|
+
/// Loads audio samples from the file at `path` into a `TensorSignal`.
///
/// `frame_offset` is the number of frames to skip from the start (must be
/// non-negative). `num_frames` is the number of frames to read: -1 reads to
/// the end, otherwise it must be greater than 0. Requests beyond the end of
/// the file are clamped to the available samples. `normalize` and
/// `channels_first` are forwarded to `convert_to_tensor`.
///
/// Throws `std::runtime_error` on invalid arguments, when the file fails
/// validation, when seeking fails, or when no sample could be read.
c10::intrusive_ptr<TensorSignal> load_audio_file(
    const std::string& path,
    const int64_t frame_offset,
    const int64_t num_frames,
    const bool normalize,
    const bool channels_first) {
  if (frame_offset < 0) {
    throw std::runtime_error(
        "Invalid argument: frame_offset must be non-negative.");
  }
  if (num_frames == 0 || num_frames < -1) {
    throw std::runtime_error(
        "Invalid argument: num_frames must be -1 or greater than 0.");
  }

  SoxFormat sf(sox_open_read(
      path.c_str(),
      /*signal=*/nullptr,
      /*encoding=*/nullptr,
      /*filetype=*/nullptr));

  validate_input_file(sf);

  const int64_t num_channels = sf->signal.channels;
  const int64_t num_total_samples = sf->signal.length;
  const int64_t sample_start = sf->signal.channels * frame_offset;

  if (sox_seek(sf, sample_start, 0) == SOX_EOF) {
    throw std::runtime_error("Error reading audio file: offset past EOF.");
  }

  const int64_t sample_end = [&]() {
    if (num_frames == -1)
      return num_total_samples;
    const int64_t sample_end_ = num_channels * num_frames + sample_start;
    if (num_total_samples < sample_end_) {
      // For lossy encoding, it is difficult to predict exact size of buffer for
      // reading the number of samples required.
      // So we allocate buffer size of given `num_frames` and ask sox to read as
      // much as possible. For lossless format, sox reads exact number of
      // samples, but for lossy encoding, sox can end up reading less. (i.e.
      // mp3) For the consistent behavior specification between lossy/lossless
      // format, we allow users to provide `num_frames` value that exceeds #of
      // available samples, and we adjust it here.
      return num_total_samples;
    }
    return sample_end_;
  }();

  const int64_t max_samples = sample_end - sample_start;
  if (max_samples < 0) {
    // The start lies beyond the reported length; a negative count must never
    // be used as a buffer size.
    throw std::runtime_error("Error reading audio file: offset past EOF.");
  }

  // Read samples into buffer.
  // The vector is sized (not merely reserved) so that the elements sox_read
  // writes through data() actually exist.
  std::vector<sox_sample_t> buffer(static_cast<size_t>(max_samples));
  const int64_t num_samples = sox_read(sf, buffer.data(), max_samples);
  if (num_samples == 0) {
    throw std::runtime_error(
        "Error reading audio file: empty file or read operation failed.");
  }
  // NOTE: num_samples may be smaller than max_samples if the input
  // format is compressed (i.e. mp3).

  // Convert to Tensor
  auto tensor = convert_to_tensor(
      buffer.data(),
      num_samples,
      num_channels,
      get_dtype(sf->encoding.encoding, sf->signal.precision),
      normalize,
      channels_first);

  return c10::make_intrusive<TensorSignal>(
      tensor, static_cast<int64_t>(sf->signal.rate), channels_first);
}
|
122
|
+
|
123
|
+
void save_audio_file(
|
124
|
+
const std::string& file_name,
|
125
|
+
const c10::intrusive_ptr<TensorSignal>& signal,
|
126
|
+
const double compression) {
|
127
|
+
const auto tensor = signal->getTensor();
|
128
|
+
const auto sample_rate = signal->getSampleRate();
|
129
|
+
const auto channels_first = signal->getChannelsFirst();
|
130
|
+
|
131
|
+
validate_input_tensor(tensor);
|
132
|
+
|
133
|
+
const auto filetype = get_filetype(file_name);
|
134
|
+
const auto signal_info =
|
135
|
+
get_signalinfo(tensor, sample_rate, channels_first, filetype);
|
136
|
+
const auto encoding_info =
|
137
|
+
get_encodinginfo(filetype, tensor.dtype(), compression);
|
138
|
+
|
139
|
+
SoxFormat sf(sox_open_write(
|
140
|
+
file_name.c_str(),
|
141
|
+
&signal_info,
|
142
|
+
&encoding_info,
|
143
|
+
/*filetype=*/filetype.c_str(),
|
144
|
+
/*oob=*/nullptr,
|
145
|
+
/*overwrite_permitted=*/nullptr));
|
146
|
+
|
147
|
+
if (static_cast<sox_format_t*>(sf) == nullptr) {
|
148
|
+
throw std::runtime_error("Error saving audio file: failed to open file.");
|
149
|
+
}
|
150
|
+
|
151
|
+
auto tensor_ = tensor;
|
152
|
+
if (channels_first) {
|
153
|
+
tensor_ = tensor_.t();
|
154
|
+
}
|
155
|
+
|
156
|
+
const int64_t frames_per_chunk = 65536;
|
157
|
+
for (int64_t i = 0; i < tensor_.size(0); i += frames_per_chunk) {
|
158
|
+
auto chunk = tensor_.index({Slice(i, i + frames_per_chunk), Slice()});
|
159
|
+
chunk = unnormalize_wav(chunk).contiguous();
|
160
|
+
|
161
|
+
const size_t numel = chunk.numel();
|
162
|
+
if (sox_write(sf, chunk.data_ptr<int32_t>(), numel) != numel) {
|
163
|
+
throw std::runtime_error(
|
164
|
+
"Error saving audio file: failed to write the entier buffer.");
|
165
|
+
}
|
166
|
+
}
|
167
|
+
}
|
168
|
+
|
169
|
+
} // namespace sox_io
|
170
|
+
} // namespace torchaudio
|
@@ -0,0 +1,41 @@
|
|
1
|
+
#ifndef TORCHAUDIO_SOX_IO_H
|
2
|
+
#define TORCHAUDIO_SOX_IO_H
|
3
|
+
|
4
|
+
#include <torch/script.h>
|
5
|
+
#include <torchaudio/csrc/sox_utils.h>
|
6
|
+
|
7
|
+
namespace torchaudio {
|
8
|
+
namespace sox_io {
|
9
|
+
|
10
|
+
/// Basic properties of an audio signal: sample rate, channel count and frame
/// count. Derives from `torch::CustomClassHolder`.
struct SignalInfo : torch::CustomClassHolder {
  int64_t sample_rate;
  int64_t num_channels;
  int64_t num_frames;

  /// Stores the three values as given.
  SignalInfo(
      const int64_t sample_rate_,
      const int64_t num_channels_,
      const int64_t num_frames_);
  /// Accessors for the stored values.
  int64_t getSampleRate() const;
  int64_t getNumChannels() const;
  int64_t getNumFrames() const;
};
|
23
|
+
|
24
|
+
/// Returns the signal information of the audio file at `path`.
c10::intrusive_ptr<SignalInfo> get_info(const std::string& path);

/// Loads the audio file at `path`. By default reads from the first frame
/// (`frame_offset = 0`) to the end (`num_frames = -1`), normalized and in
/// channels-first layout.
c10::intrusive_ptr<torchaudio::sox_utils::TensorSignal> load_audio_file(
    const std::string& path,
    const int64_t frame_offset = 0,
    const int64_t num_frames = -1,
    const bool normalize = true,
    const bool channels_first = true);

/// Saves `signal` to `file_name` with the given `compression` (default 0).
void save_audio_file(
    const std::string& file_name,
    const c10::intrusive_ptr<torchaudio::sox_utils::TensorSignal>& signal,
    const double compression = 0.);
|
37
|
+
|
38
|
+
} // namespace sox_io
|
39
|
+
} // namespace torchaudio
|
40
|
+
|
41
|
+
#endif
|
@@ -0,0 +1,245 @@
|
|
1
|
+
#include <c10/core/ScalarType.h>
|
2
|
+
#include <sox.h>
|
3
|
+
#include <torchaudio/csrc/sox_utils.h>
|
4
|
+
|
5
|
+
namespace torchaudio {
|
6
|
+
namespace sox_utils {
|
7
|
+
|
8
|
+
TensorSignal::TensorSignal(
|
9
|
+
torch::Tensor tensor_,
|
10
|
+
int64_t sample_rate_,
|
11
|
+
bool channels_first_)
|
12
|
+
: tensor(tensor_),
|
13
|
+
sample_rate(sample_rate_),
|
14
|
+
channels_first(channels_first_){};
|
15
|
+
|
16
|
+
torch::Tensor TensorSignal::getTensor() const {
|
17
|
+
return tensor;
|
18
|
+
}
|
19
|
+
int64_t TensorSignal::getSampleRate() const {
|
20
|
+
return sample_rate;
|
21
|
+
}
|
22
|
+
bool TensorSignal::getChannelsFirst() const {
|
23
|
+
return channels_first;
|
24
|
+
}
|
25
|
+
|
26
|
+
SoxFormat::SoxFormat(sox_format_t* fd) noexcept : fd_(fd) {}
|
27
|
+
SoxFormat::~SoxFormat() {
|
28
|
+
if (fd_ != nullptr) {
|
29
|
+
sox_close(fd_);
|
30
|
+
}
|
31
|
+
}
|
32
|
+
sox_format_t* SoxFormat::operator->() const noexcept {
|
33
|
+
return fd_;
|
34
|
+
}
|
35
|
+
SoxFormat::operator sox_format_t*() const noexcept {
|
36
|
+
return fd_;
|
37
|
+
}
|
38
|
+
|
39
|
+
void validate_input_file(const SoxFormat& sf) {
|
40
|
+
if (static_cast<sox_format_t*>(sf) == nullptr) {
|
41
|
+
throw std::runtime_error("Error loading audio file: failed to open file.");
|
42
|
+
}
|
43
|
+
if (sf->encoding.encoding == SOX_ENCODING_UNKNOWN) {
|
44
|
+
throw std::runtime_error("Error loading audio file: unknown encoding.");
|
45
|
+
}
|
46
|
+
if (sf->signal.length == 0) {
|
47
|
+
throw std::runtime_error("Error reading audio file: unkown length.");
|
48
|
+
}
|
49
|
+
}
|
50
|
+
|
51
|
+
/// Checks that `tensor` lives on CPU, is two-dimensional, and has dtype
/// float32, int32, int16 or uint8.
/// Throws `std::runtime_error` describing the first check that fails.
void validate_input_tensor(const torch::Tensor tensor) {
  const bool on_cpu = tensor.device().is_cpu();
  if (!on_cpu) {
    throw std::runtime_error("Input tensor has to be on CPU.");
  }

  const bool is_2d = tensor.ndimension() == 2;
  if (!is_2d) {
    throw std::runtime_error("Input tensor has to be 2D.");
  }

  const auto dtype = tensor.dtype();
  const bool supported = dtype == torch::kFloat32 || dtype == torch::kInt32 ||
      dtype == torch::kInt16 || dtype == torch::kUInt8;
  if (!supported) {
    throw std::runtime_error(
        "Input tensor has to be one of float32, int32, int16 or uint8 type.");
  }
}
|
67
|
+
|
68
|
+
/// Maps a SoX encoding and bit precision to the tensor dtype used for it:
/// unsigned PCM -> uint8, signed PCM -> int16 or int32 depending on
/// `precision`, everything else -> float32.
/// Throws `std::runtime_error` for signed PCM with a precision other than
/// 16 or 32.
caffe2::TypeMeta get_dtype(
    const sox_encoding_t encoding,
    const unsigned precision) {
  // 8-bit PCM WAV
  if (encoding == SOX_ENCODING_UNSIGNED) {
    return c10::scalarTypeToTypeMeta(torch::kUInt8);
  }

  // 16-bit or 32-bit PCM WAV
  if (encoding == SOX_ENCODING_SIGN2) {
    if (precision == 16) {
      return c10::scalarTypeToTypeMeta(torch::kInt16);
    }
    if (precision == 32) {
      return c10::scalarTypeToTypeMeta(torch::kInt32);
    }
    throw std::runtime_error(
        "Only 16 and 32 bits are supported for signed PCM.");
  }

  // Every other format (32-bit floating-point WAV, MP3, FLAC, VORBIS, ...)
  // falls back to float32.
  return c10::scalarTypeToTypeMeta(torch::kFloat32);
}
|
96
|
+
|
97
|
+
/// Builds a tensor from `num_samples` 32-bit samples in `buffer`.
///
/// The buffer is first viewed as an int32 tensor of shape
/// {num_samples / num_channels, num_channels}. If `normalize` is set or
/// `dtype` is float32 the result is float32, with positive values divided by
/// 2147483647 and negative values by 2147483648. Otherwise the samples are
/// converted to `dtype` (int32, int16 or uint8). With `channels_first` the
/// two dimensions are swapped. Throws `std::runtime_error` for any other dtype.
torch::Tensor convert_to_tensor(
    sox_sample_t* buffer,
    const int32_t num_samples,
    const int32_t num_channels,
    const caffe2::TypeMeta dtype,
    const bool normalize,
    const bool channels_first) {
  auto t = torch::from_blob(
      buffer, {num_samples / num_channels, num_channels}, torch::kInt32);
  // Note: Tensor created from_blob does not own data but borrows
  // So make sure to create a new copy after processing samples.
  if (normalize || dtype == torch::kFloat32) {
    t = t.to(torch::kFloat32);
    // Per-element scale factor: 1/2147483647 for positives, 1/2147483648 for
    // negatives, 0 for zero.
    t *= (t > 0) / 2147483647. + (t < 0) / 2147483648.;
  } else if (dtype == torch::kInt32) {
    t = t.clone();
  } else if (dtype == torch::kInt16) {
    // In-place on the tensor that still borrows `buffer`.
    t.floor_divide_(1 << 16);
    t = t.to(torch::kInt16);
  } else if (dtype == torch::kUInt8) {
    // In-place on the tensor that still borrows `buffer`.
    t.floor_divide_(1 << 24);
    t += 128;
    t = t.to(torch::kUInt8);
  } else {
    throw std::runtime_error("Unsupported dtype.");
  }
  if (channels_first) {
    t = t.transpose(1, 0);
  }
  return t.contiguous();
}
|
128
|
+
|
129
|
+
/// Converts `input_tensor` to int32 samples.
///
/// float32 input is scaled (positives by 2147483647, negatives by
/// 2147483648), clamped and converted; int32 input is returned as is; int16
/// input is multiplied by 65536; uint8 input has 128 subtracted and is
/// multiplied by 16777216. Throws `std::runtime_error` for any other dtype.
torch::Tensor unnormalize_wav(const torch::Tensor input_tensor) {
  const auto dtype = input_tensor.dtype();
  auto tensor = input_tensor;
  if (dtype == torch::kFloat32) {
    double multi_pos = 2147483647.;
    double multi_neg = -2147483648.;
    // multi_neg is negative, so subtracting yields a positive factor for
    // negative samples; zero samples get a factor of 0.
    auto mult = (tensor > 0) * multi_pos - (tensor < 0) * multi_neg;
    tensor = tensor.to(torch::dtype(torch::kFloat64));
    tensor *= mult;
    tensor.clamp_(multi_neg, multi_pos);
    tensor = tensor.to(torch::dtype(torch::kInt32));
  } else if (dtype == torch::kInt32) {
    // already denormalized
  } else if (dtype == torch::kInt16) {
    tensor = tensor.to(torch::dtype(torch::kInt32));
    tensor *= ((tensor != 0) * 65536);
  } else if (dtype == torch::kUInt8) {
    tensor = tensor.to(torch::dtype(torch::kInt32));
    tensor -= 128;
    tensor *= 16777216;
  } else {
    throw std::runtime_error("Unexpected dtype.");
  }
  return tensor;
}
|
154
|
+
|
155
|
+
/// Returns the lower-cased extension of `path` (the text after the last '.'
/// of the final path component), or an empty string when the final component
/// has no '.'.
///
/// Examples: "a/b.WAV" -> "wav", "dir.d/file" -> "", "noext" -> "".
const std::string get_filetype(const std::string path) {
  const std::string::size_type dot = path.find_last_of('.');
  const std::string::size_type sep = path.find_last_of("/\\");
  // A dot that precedes the last separator belongs to a directory name, not
  // to the file's extension.
  const bool dot_in_directory = sep != std::string::npos && dot < sep;
  if (dot == std::string::npos || dot_in_directory) {
    return std::string();
  }

  std::string ext = path.substr(dot + 1);
  // tolower must not be given a negative value, so go through unsigned char.
  std::transform(ext.begin(), ext.end(), ext.begin(), [](unsigned char c) {
    return static_cast<char>(::tolower(c));
  });
  return ext;
}
|
160
|
+
|
161
|
+
/// Chooses the SoX encoding for writing `filetype` with samples of `dtype`.
/// For "wav" the encoding depends on `dtype`; for "mp3", "flac", "ogg" and
/// "vorbis" it is fixed by the file type.
/// Throws `std::runtime_error` for an unsupported file type, or for "wav"
/// with an unsupported dtype.
sox_encoding_t get_encoding(
    const std::string filetype,
    const caffe2::TypeMeta dtype) {
  if (filetype == "mp3") {
    return SOX_ENCODING_MP3;
  } else if (filetype == "flac") {
    return SOX_ENCODING_FLAC;
  } else if (filetype == "ogg" || filetype == "vorbis") {
    return SOX_ENCODING_VORBIS;
  } else if (filetype != "wav") {
    throw std::runtime_error("Unsupported file type.");
  }

  // WAV: the encoding follows the tensor dtype.
  if (dtype == torch::kUInt8) {
    return SOX_ENCODING_UNSIGNED;
  }
  if (dtype == torch::kInt16 || dtype == torch::kInt32) {
    return SOX_ENCODING_SIGN2;
  }
  if (dtype == torch::kFloat32) {
    return SOX_ENCODING_FLOAT;
  }
  throw std::runtime_error("Unsupported dtype.");
}
|
183
|
+
|
184
|
+
unsigned get_precision(
|
185
|
+
const std::string filetype,
|
186
|
+
const caffe2::TypeMeta dtype) {
|
187
|
+
if (filetype == "mp3")
|
188
|
+
return SOX_UNSPEC;
|
189
|
+
if (filetype == "flac")
|
190
|
+
return 24;
|
191
|
+
if (filetype == "ogg" || filetype == "vorbis")
|
192
|
+
return SOX_UNSPEC;
|
193
|
+
if (filetype == "wav") {
|
194
|
+
if (dtype == torch::kUInt8)
|
195
|
+
return 8;
|
196
|
+
if (dtype == torch::kInt16)
|
197
|
+
return 16;
|
198
|
+
if (dtype == torch::kInt32)
|
199
|
+
return 32;
|
200
|
+
if (dtype == torch::kFloat32)
|
201
|
+
return 32;
|
202
|
+
throw std::runtime_error("Unsupported dtype.");
|
203
|
+
}
|
204
|
+
throw std::runtime_error("Unsupported file type.");
|
205
|
+
}
|
206
|
+
|
207
|
+
/// Builds the `sox_signalinfo_t` describing `tensor` for writing as
/// `filetype`: rate from `sample_rate`, channel count from the tensor
/// dimension selected by `channels_first`, precision from `get_precision`,
/// and length from the tensor's element count.
sox_signalinfo_t get_signalinfo(
    const torch::Tensor& tensor,
    const int64_t sample_rate,
    const bool channels_first,
    const std::string filetype) {
  const int64_t channel_dim = channels_first ? 0 : 1;

  const auto rate = static_cast<sox_rate_t>(sample_rate);
  const auto channels = static_cast<unsigned>(tensor.size(channel_dim));
  const unsigned precision = get_precision(filetype, tensor.dtype());
  const auto length = static_cast<uint64_t>(tensor.numel());

  return sox_signalinfo_t{
      /*rate=*/rate,
      /*channels=*/channels,
      /*precision=*/precision,
      /*length=*/length};
}
|
218
|
+
|
219
|
+
/// Builds the `sox_encodinginfo_t` for writing `filetype` with samples of
/// `dtype`. `compression` is passed through for "mp3", "flac", "ogg" and
/// "vorbis" and replaced by 0 for "wav".
/// Throws `std::runtime_error` for any other file type.
sox_encodinginfo_t get_encodinginfo(
    const std::string filetype,
    const caffe2::TypeMeta dtype,
    const double compression) {
  const bool honours_compression = filetype == "mp3" || filetype == "flac" ||
      filetype == "ogg" || filetype == "vorbis";
  if (!honours_compression && filetype != "wav") {
    throw std::runtime_error("Unsupported file type.");
  }
  const double compression_ = honours_compression ? compression : 0.;

  return sox_encodinginfo_t{/*encoding=*/get_encoding(filetype, dtype),
                            /*bits_per_sample=*/get_precision(filetype, dtype),
                            /*compression=*/compression_,
                            /*reverse_bytes=*/sox_option_default,
                            /*reverse_nibbles=*/sox_option_default,
                            /*reverse_bits=*/sox_option_default,
                            /*opposite_endian=*/sox_false};
}
|
243
|
+
|
244
|
+
} // namespace sox_utils
|
245
|
+
} // namespace torchaudio
|