torchaudio 0.1.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,71 @@
1
+ #include <sox.h>
2
+
3
+ #include <string>
4
+ #include <tuple>
5
+ #include <vector>
6
+ #include <unistd.h>
7
+
8
+ // same as <torch/extension.h> without <torch/python.h>
9
+ #include <torch/all.h>
10
+
11
// NOTE: no forward declaration of `at::Tensor` is needed in this header: the
// complete definition is already provided by <torch/all.h> included above.
// A forward declaration using a class-key that differs from ATen's own
// definition (e.g. `struct` vs `class`) only produces mismatched-tag warnings
// (clang -Wmismatched-tags, MSVC C4099) without adding anything.
14
+
15
+ namespace torch { namespace audio {
16
+
17
+ /// Reads an audio file from the given `path` into the `output` `Tensor` and
18
+ /// returns the sample rate of the audio file.
19
+ /// Throws `std::runtime_error` if the audio file could not be opened, or an
20
+ /// error occurred during reading of the audio data.
21
+ int read_audio_file(
22
+ const std::string& file_name,
23
+ at::Tensor output,
24
+ bool ch_first,
25
+ int64_t nframes,
26
+ int64_t offset,
27
+ sox_signalinfo_t* si,
28
+ sox_encodinginfo_t* ei,
29
+ const char* ft);
30
+
31
+ /// Writes the data of a `Tensor` into an audio file at the given `path`, with
32
+ /// a certain extension (e.g. `wav`or `mp3`) and sample rate.
33
+ /// Throws `std::runtime_error` when the audio file could not be opened for
34
+ /// writing, or an error occurred during writing of the audio data.
35
+ void write_audio_file(
36
+ const std::string& file_name,
37
+ const at::Tensor& tensor,
38
+ sox_signalinfo_t* si,
39
+ sox_encodinginfo_t* ei,
40
+ const char* file_type);
41
+
42
+ /// Reads an audio file from the given `path` and returns a tuple of
43
+ /// sox_signalinfo_t and sox_encodinginfo_t, which contain information about
44
+ /// the audio file such as sample rate, length, bit precision, encoding and more.
45
+ /// Throws `std::runtime_error` if the audio file could not be opened, or an
46
+ /// error occurred during reading of the audio data.
47
+ std::tuple<sox_signalinfo_t, sox_encodinginfo_t> get_info(
48
+ const std::string& file_name);
49
+
50
+ // Struct for build_flow_effects function
51
+ struct SoxEffect {
52
+ SoxEffect() : ename(""), eopts({""}) { }
53
+ std::string ename;
54
+ std::vector<std::string> eopts;
55
+ };
56
+
57
+ /// Build a SoX chain, flow the effects, and capture the results in a tensor.
58
+ /// An audio file from the given `path` flows through an effects chain given
59
+ /// by a list of effects and effect options to an output buffer which is encoded
60
+ /// into memory to a target signal type and target signal encoding. The resulting
61
+ /// buffer is then placed into a tensor. This function returns the output tensor
62
+ /// and the sample rate of the output tensor.
63
+ int build_flow_effects(const std::string& file_name,
64
+ at::Tensor otensor,
65
+ bool ch_first,
66
+ sox_signalinfo_t* target_signal,
67
+ sox_encodinginfo_t* target_encoding,
68
+ const char* file_type,
69
+ std::vector<SoxEffect> pyeffs,
70
+ int max_num_eopts);
71
+ }} // namespace torch::audio
@@ -0,0 +1,54 @@
1
+ #include <sox.h>
2
+ #include <torchaudio/csrc/sox_effects.h>
3
+
4
+ using namespace torch::indexing;
5
+
6
+ namespace torchaudio {
7
+ namespace sox_effects {
8
+
9
+ namespace {
10
+
11
+ enum SoxEffectsResourceState { NotInitialized, Initialized, ShutDown };
12
+ SoxEffectsResourceState SOX_RESOURCE_STATE = NotInitialized;
13
+
14
+ } // namespace
15
+
16
+ void initialize_sox_effects() {
17
+ if (SOX_RESOURCE_STATE == ShutDown) {
18
+ throw std::runtime_error(
19
+ "SoX Effects has been shut down. Cannot initialize again.");
20
+ }
21
+ if (SOX_RESOURCE_STATE == NotInitialized) {
22
+ if (sox_init() != SOX_SUCCESS) {
23
+ throw std::runtime_error("Failed to initialize sox effects.");
24
+ };
25
+ SOX_RESOURCE_STATE = Initialized;
26
+ }
27
+ };
28
+
29
+ void shutdown_sox_effects() {
30
+ if (SOX_RESOURCE_STATE == NotInitialized) {
31
+ throw std::runtime_error(
32
+ "SoX Effects is not initialized. Cannot shutdown.");
33
+ }
34
+ if (SOX_RESOURCE_STATE == Initialized) {
35
+ if (sox_quit() != SOX_SUCCESS) {
36
+ throw std::runtime_error("Failed to initialize sox effects.");
37
+ };
38
+ SOX_RESOURCE_STATE = ShutDown;
39
+ }
40
+ }
41
+
42
+ std::vector<std::string> list_effects() {
43
+ std::vector<std::string> names;
44
+ const sox_effect_fn_t* fns = sox_get_effect_fns();
45
+ for (int i = 0; fns[i]; ++i) {
46
+ const sox_effect_handler_t* handler = fns[i]();
47
+ if (handler && handler->name)
48
+ names.push_back(handler->name);
49
+ }
50
+ return names;
51
+ }
52
+
53
+ } // namespace sox_effects
54
+ } // namespace torchaudio
@@ -0,0 +1,18 @@
1
+ #ifndef TORCHAUDIO_SOX_EFFECTS_H
2
+ #define TORCHAUDIO_SOX_EFFECTS_H
3
+
4
+ #include <torch/script.h>
5
+
6
namespace torchaudio {
namespace sox_effects {

/// Initializes the SoX library for effect processing. Idempotent while
/// initialized.
/// @throws std::runtime_error if SoX was already shut down (it cannot be
///         re-initialized) or if SoX initialization fails.
void initialize_sox_effects();

/// Shuts down the SoX library. This is terminal: after it succeeds,
/// `initialize_sox_effects()` will throw. Repeated calls are no-ops.
/// @throws std::runtime_error if SoX was never initialized or if the SoX
///         shutdown call fails.
void shutdown_sox_effects();

/// Returns the names of the effects available in the linked SoX library.
std::vector<std::string> list_effects();

} // namespace sox_effects
} // namespace torchaudio
17
+
18
+ #endif
@@ -0,0 +1,170 @@
1
+ #include <sox.h>
2
+ #include <torchaudio/csrc/sox_io.h>
3
+ #include <torchaudio/csrc/sox_utils.h>
4
+
5
+ using namespace torch::indexing;
6
+ using namespace torchaudio::sox_utils;
7
+
8
+ namespace torchaudio {
9
+ namespace sox_io {
10
+
11
+ SignalInfo::SignalInfo(
12
+ const int64_t sample_rate_,
13
+ const int64_t num_channels_,
14
+ const int64_t num_frames_)
15
+ : sample_rate(sample_rate_),
16
+ num_channels(num_channels_),
17
+ num_frames(num_frames_){};
18
+
19
+ int64_t SignalInfo::getSampleRate() const {
20
+ return sample_rate;
21
+ }
22
+
23
+ int64_t SignalInfo::getNumChannels() const {
24
+ return num_channels;
25
+ }
26
+
27
+ int64_t SignalInfo::getNumFrames() const {
28
+ return num_frames;
29
+ }
30
+
31
+ c10::intrusive_ptr<SignalInfo> get_info(const std::string& path) {
32
+ SoxFormat sf(sox_open_read(
33
+ path.c_str(),
34
+ /*signal=*/nullptr,
35
+ /*encoding=*/nullptr,
36
+ /*filetype=*/nullptr));
37
+
38
+ if (static_cast<sox_format_t*>(sf) == nullptr) {
39
+ throw std::runtime_error("Error opening audio file");
40
+ }
41
+
42
+ return c10::make_intrusive<SignalInfo>(
43
+ static_cast<int64_t>(sf->signal.rate),
44
+ static_cast<int64_t>(sf->signal.channels),
45
+ static_cast<int64_t>(sf->signal.length / sf->signal.channels));
46
+ }
47
+
48
+ c10::intrusive_ptr<TensorSignal> load_audio_file(
49
+ const std::string& path,
50
+ const int64_t frame_offset,
51
+ const int64_t num_frames,
52
+ const bool normalize,
53
+ const bool channels_first) {
54
+ if (frame_offset < 0) {
55
+ throw std::runtime_error(
56
+ "Invalid argument: frame_offset must be non-negative.");
57
+ }
58
+ if (num_frames == 0 || num_frames < -1) {
59
+ throw std::runtime_error(
60
+ "Invalid argument: num_frames must be -1 or greater than 0.");
61
+ }
62
+
63
+ SoxFormat sf(sox_open_read(
64
+ path.c_str(),
65
+ /*signal=*/nullptr,
66
+ /*encoding=*/nullptr,
67
+ /*filetype=*/nullptr));
68
+
69
+ validate_input_file(sf);
70
+
71
+ const int64_t num_channels = sf->signal.channels;
72
+ const int64_t num_total_samples = sf->signal.length;
73
+ const int64_t sample_start = sf->signal.channels * frame_offset;
74
+
75
+ if (sox_seek(sf, sample_start, 0) == SOX_EOF) {
76
+ throw std::runtime_error("Error reading audio file: offset past EOF.");
77
+ }
78
+
79
+ const int64_t sample_end = [&]() {
80
+ if (num_frames == -1)
81
+ return num_total_samples;
82
+ const int64_t sample_end_ = num_channels * num_frames + sample_start;
83
+ if (num_total_samples < sample_end_) {
84
+ // For lossy encoding, it is difficult to predict exact size of buffer for
85
+ // reading the number of samples required.
86
+ // So we allocate buffer size of given `num_frames` and ask sox to read as
87
+ // much as possible. For lossless format, sox reads exact number of
88
+ // samples, but for lossy encoding, sox can end up reading less. (i.e.
89
+ // mp3) For the consistent behavior specification between lossy/lossless
90
+ // format, we allow users to provide `num_frames` value that exceeds #of
91
+ // available samples, and we adjust it here.
92
+ return num_total_samples;
93
+ }
94
+ return sample_end_;
95
+ }();
96
+
97
+ const int64_t max_samples = sample_end - sample_start;
98
+
99
+ // Read samples into buffer
100
+ std::vector<sox_sample_t> buffer;
101
+ buffer.reserve(max_samples);
102
+ const int64_t num_samples = sox_read(sf, buffer.data(), max_samples);
103
+ if (num_samples == 0) {
104
+ throw std::runtime_error(
105
+ "Error reading audio file: empty file or read operation failed.");
106
+ }
107
+ // NOTE: num_samples may be smaller than max_samples if the input
108
+ // format is compressed (i.e. mp3).
109
+
110
+ // Convert to Tensor
111
+ auto tensor = convert_to_tensor(
112
+ buffer.data(),
113
+ num_samples,
114
+ num_channels,
115
+ get_dtype(sf->encoding.encoding, sf->signal.precision),
116
+ normalize,
117
+ channels_first);
118
+
119
+ return c10::make_intrusive<TensorSignal>(
120
+ tensor, static_cast<int64_t>(sf->signal.rate), channels_first);
121
+ }
122
+
123
+ void save_audio_file(
124
+ const std::string& file_name,
125
+ const c10::intrusive_ptr<TensorSignal>& signal,
126
+ const double compression) {
127
+ const auto tensor = signal->getTensor();
128
+ const auto sample_rate = signal->getSampleRate();
129
+ const auto channels_first = signal->getChannelsFirst();
130
+
131
+ validate_input_tensor(tensor);
132
+
133
+ const auto filetype = get_filetype(file_name);
134
+ const auto signal_info =
135
+ get_signalinfo(tensor, sample_rate, channels_first, filetype);
136
+ const auto encoding_info =
137
+ get_encodinginfo(filetype, tensor.dtype(), compression);
138
+
139
+ SoxFormat sf(sox_open_write(
140
+ file_name.c_str(),
141
+ &signal_info,
142
+ &encoding_info,
143
+ /*filetype=*/filetype.c_str(),
144
+ /*oob=*/nullptr,
145
+ /*overwrite_permitted=*/nullptr));
146
+
147
+ if (static_cast<sox_format_t*>(sf) == nullptr) {
148
+ throw std::runtime_error("Error saving audio file: failed to open file.");
149
+ }
150
+
151
+ auto tensor_ = tensor;
152
+ if (channels_first) {
153
+ tensor_ = tensor_.t();
154
+ }
155
+
156
+ const int64_t frames_per_chunk = 65536;
157
+ for (int64_t i = 0; i < tensor_.size(0); i += frames_per_chunk) {
158
+ auto chunk = tensor_.index({Slice(i, i + frames_per_chunk), Slice()});
159
+ chunk = unnormalize_wav(chunk).contiguous();
160
+
161
+ const size_t numel = chunk.numel();
162
+ if (sox_write(sf, chunk.data_ptr<int32_t>(), numel) != numel) {
163
+ throw std::runtime_error(
164
+ "Error saving audio file: failed to write the entier buffer.");
165
+ }
166
+ }
167
+ }
168
+
169
+ } // namespace sox_io
170
+ } // namespace torchaudio
@@ -0,0 +1,41 @@
1
+ #ifndef TORCHAUDIO_SOX_IO_H
2
+ #define TORCHAUDIO_SOX_IO_H
3
+
4
+ #include <torch/script.h>
5
+ #include <torchaudio/csrc/sox_utils.h>
6
+
7
+ namespace torchaudio {
8
+ namespace sox_io {
9
+
10
+ struct SignalInfo : torch::CustomClassHolder {
11
+ int64_t sample_rate;
12
+ int64_t num_channels;
13
+ int64_t num_frames;
14
+
15
+ SignalInfo(
16
+ const int64_t sample_rate_,
17
+ const int64_t num_channels_,
18
+ const int64_t num_frames_);
19
+ int64_t getSampleRate() const;
20
+ int64_t getNumChannels() const;
21
+ int64_t getNumFrames() const;
22
+ };
23
+
24
+ c10::intrusive_ptr<SignalInfo> get_info(const std::string& path);
25
+
26
+ c10::intrusive_ptr<torchaudio::sox_utils::TensorSignal> load_audio_file(
27
+ const std::string& path,
28
+ const int64_t frame_offset = 0,
29
+ const int64_t num_frames = -1,
30
+ const bool normalize = true,
31
+ const bool channels_first = true);
32
+
33
+ void save_audio_file(
34
+ const std::string& file_name,
35
+ const c10::intrusive_ptr<torchaudio::sox_utils::TensorSignal>& signal,
36
+ const double compression = 0.);
37
+
38
+ } // namespace sox_io
39
+ } // namespace torchaudio
40
+
41
+ #endif
@@ -0,0 +1,245 @@
1
+ #include <c10/core/ScalarType.h>
2
+ #include <sox.h>
3
+ #include <torchaudio/csrc/sox_utils.h>
4
+
5
+ namespace torchaudio {
6
+ namespace sox_utils {
7
+
8
+ TensorSignal::TensorSignal(
9
+ torch::Tensor tensor_,
10
+ int64_t sample_rate_,
11
+ bool channels_first_)
12
+ : tensor(tensor_),
13
+ sample_rate(sample_rate_),
14
+ channels_first(channels_first_){};
15
+
16
+ torch::Tensor TensorSignal::getTensor() const {
17
+ return tensor;
18
+ }
19
+ int64_t TensorSignal::getSampleRate() const {
20
+ return sample_rate;
21
+ }
22
+ bool TensorSignal::getChannelsFirst() const {
23
+ return channels_first;
24
+ }
25
+
26
+ SoxFormat::SoxFormat(sox_format_t* fd) noexcept : fd_(fd) {}
27
+ SoxFormat::~SoxFormat() {
28
+ if (fd_ != nullptr) {
29
+ sox_close(fd_);
30
+ }
31
+ }
32
+ sox_format_t* SoxFormat::operator->() const noexcept {
33
+ return fd_;
34
+ }
35
+ SoxFormat::operator sox_format_t*() const noexcept {
36
+ return fd_;
37
+ }
38
+
39
+ void validate_input_file(const SoxFormat& sf) {
40
+ if (static_cast<sox_format_t*>(sf) == nullptr) {
41
+ throw std::runtime_error("Error loading audio file: failed to open file.");
42
+ }
43
+ if (sf->encoding.encoding == SOX_ENCODING_UNKNOWN) {
44
+ throw std::runtime_error("Error loading audio file: unknown encoding.");
45
+ }
46
+ if (sf->signal.length == 0) {
47
+ throw std::runtime_error("Error reading audio file: unkown length.");
48
+ }
49
+ }
50
+
51
+ void validate_input_tensor(const torch::Tensor tensor) {
52
+ if (!tensor.device().is_cpu()) {
53
+ throw std::runtime_error("Input tensor has to be on CPU.");
54
+ }
55
+
56
+ if (tensor.ndimension() != 2) {
57
+ throw std::runtime_error("Input tensor has to be 2D.");
58
+ }
59
+
60
+ const auto dtype = tensor.dtype();
61
+ if (!(dtype == torch::kFloat32 || dtype == torch::kInt32 ||
62
+ dtype == torch::kInt16 || dtype == torch::kUInt8)) {
63
+ throw std::runtime_error(
64
+ "Input tensor has to be one of float32, int32, int16 or uint8 type.");
65
+ }
66
+ }
67
+
68
+ caffe2::TypeMeta get_dtype(
69
+ const sox_encoding_t encoding,
70
+ const unsigned precision) {
71
+ const auto dtype = [&]() {
72
+ switch (encoding) {
73
+ case SOX_ENCODING_UNSIGNED: // 8-bit PCM WAV
74
+ return torch::kUInt8;
75
+ case SOX_ENCODING_SIGN2: // 16-bit or 32-bit PCM WAV
76
+ switch (precision) {
77
+ case 16:
78
+ return torch::kInt16;
79
+ case 32:
80
+ return torch::kInt32;
81
+ default:
82
+ throw std::runtime_error(
83
+ "Only 16 and 32 bits are supported for signed PCM.");
84
+ }
85
+ default:
86
+ // default to float32 for the other formats, including
87
+ // 32-bit flaoting-point WAV,
88
+ // MP3,
89
+ // FLAC,
90
+ // VORBIS etc...
91
+ return torch::kFloat32;
92
+ }
93
+ }();
94
+ return c10::scalarTypeToTypeMeta(dtype);
95
+ }
96
+
97
+ torch::Tensor convert_to_tensor(
98
+ sox_sample_t* buffer,
99
+ const int32_t num_samples,
100
+ const int32_t num_channels,
101
+ const caffe2::TypeMeta dtype,
102
+ const bool normalize,
103
+ const bool channels_first) {
104
+ auto t = torch::from_blob(
105
+ buffer, {num_samples / num_channels, num_channels}, torch::kInt32);
106
+ // Note: Tensor created from_blob does not own data but borrwos
107
+ // So make sure to create a new copy after processing samples.
108
+ if (normalize || dtype == torch::kFloat32) {
109
+ t = t.to(torch::kFloat32);
110
+ t *= (t > 0) / 2147483647. + (t < 0) / 2147483648.;
111
+ } else if (dtype == torch::kInt32) {
112
+ t = t.clone();
113
+ } else if (dtype == torch::kInt16) {
114
+ t.floor_divide_(1 << 16);
115
+ t = t.to(torch::kInt16);
116
+ } else if (dtype == torch::kUInt8) {
117
+ t.floor_divide_(1 << 24);
118
+ t += 128;
119
+ t = t.to(torch::kUInt8);
120
+ } else {
121
+ throw std::runtime_error("Unsupported dtype.");
122
+ }
123
+ if (channels_first) {
124
+ t = t.transpose(1, 0);
125
+ }
126
+ return t.contiguous();
127
+ }
128
+
129
+ torch::Tensor unnormalize_wav(const torch::Tensor input_tensor) {
130
+ const auto dtype = input_tensor.dtype();
131
+ auto tensor = input_tensor;
132
+ if (dtype == torch::kFloat32) {
133
+ double multi_pos = 2147483647.;
134
+ double multi_neg = -2147483648.;
135
+ auto mult = (tensor > 0) * multi_pos - (tensor < 0) * multi_neg;
136
+ tensor = tensor.to(torch::dtype(torch::kFloat64));
137
+ tensor *= mult;
138
+ tensor.clamp_(multi_neg, multi_pos);
139
+ tensor = tensor.to(torch::dtype(torch::kInt32));
140
+ } else if (dtype == torch::kInt32) {
141
+ // already denormalized
142
+ } else if (dtype == torch::kInt16) {
143
+ tensor = tensor.to(torch::dtype(torch::kInt32));
144
+ tensor *= ((tensor != 0) * 65536);
145
+ } else if (dtype == torch::kUInt8) {
146
+ tensor = tensor.to(torch::dtype(torch::kInt32));
147
+ tensor -= 128;
148
+ tensor *= 16777216;
149
+ } else {
150
+ throw std::runtime_error("Unexpected dtype.");
151
+ }
152
+ return tensor;
153
+ }
154
+
155
+ const std::string get_filetype(const std::string path) {
156
+ std::string ext = path.substr(path.find_last_of(".") + 1);
157
+ std::transform(ext.begin(), ext.end(), ext.begin(), ::tolower);
158
+ return ext;
159
+ }
160
+
161
+ sox_encoding_t get_encoding(
162
+ const std::string filetype,
163
+ const caffe2::TypeMeta dtype) {
164
+ if (filetype == "mp3")
165
+ return SOX_ENCODING_MP3;
166
+ if (filetype == "flac")
167
+ return SOX_ENCODING_FLAC;
168
+ if (filetype == "ogg" || filetype == "vorbis")
169
+ return SOX_ENCODING_VORBIS;
170
+ if (filetype == "wav") {
171
+ if (dtype == torch::kUInt8)
172
+ return SOX_ENCODING_UNSIGNED;
173
+ if (dtype == torch::kInt16)
174
+ return SOX_ENCODING_SIGN2;
175
+ if (dtype == torch::kInt32)
176
+ return SOX_ENCODING_SIGN2;
177
+ if (dtype == torch::kFloat32)
178
+ return SOX_ENCODING_FLOAT;
179
+ throw std::runtime_error("Unsupported dtype.");
180
+ }
181
+ throw std::runtime_error("Unsupported file type.");
182
+ }
183
+
184
+ unsigned get_precision(
185
+ const std::string filetype,
186
+ const caffe2::TypeMeta dtype) {
187
+ if (filetype == "mp3")
188
+ return SOX_UNSPEC;
189
+ if (filetype == "flac")
190
+ return 24;
191
+ if (filetype == "ogg" || filetype == "vorbis")
192
+ return SOX_UNSPEC;
193
+ if (filetype == "wav") {
194
+ if (dtype == torch::kUInt8)
195
+ return 8;
196
+ if (dtype == torch::kInt16)
197
+ return 16;
198
+ if (dtype == torch::kInt32)
199
+ return 32;
200
+ if (dtype == torch::kFloat32)
201
+ return 32;
202
+ throw std::runtime_error("Unsupported dtype.");
203
+ }
204
+ throw std::runtime_error("Unsupported file type.");
205
+ }
206
+
207
+ sox_signalinfo_t get_signalinfo(
208
+ const torch::Tensor& tensor,
209
+ const int64_t sample_rate,
210
+ const bool channels_first,
211
+ const std::string filetype) {
212
+ return sox_signalinfo_t{
213
+ /*rate=*/static_cast<sox_rate_t>(sample_rate),
214
+ /*channels=*/static_cast<unsigned>(tensor.size(channels_first ? 0 : 1)),
215
+ /*precision=*/get_precision(filetype, tensor.dtype()),
216
+ /*length=*/static_cast<uint64_t>(tensor.numel())};
217
+ }
218
+
219
+ sox_encodinginfo_t get_encodinginfo(
220
+ const std::string filetype,
221
+ const caffe2::TypeMeta dtype,
222
+ const double compression) {
223
+ const double compression_ = [&]() {
224
+ if (filetype == "mp3")
225
+ return compression;
226
+ if (filetype == "flac")
227
+ return compression;
228
+ if (filetype == "ogg" || filetype == "vorbis")
229
+ return compression;
230
+ if (filetype == "wav")
231
+ return 0.;
232
+ throw std::runtime_error("Unsupported file type.");
233
+ }();
234
+
235
+ return sox_encodinginfo_t{/*encoding=*/get_encoding(filetype, dtype),
236
+ /*bits_per_sample=*/get_precision(filetype, dtype),
237
+ /*compression=*/compression_,
238
+ /*reverse_bytes=*/sox_option_default,
239
+ /*reverse_nibbles=*/sox_option_default,
240
+ /*reverse_bits=*/sox_option_default,
241
+ /*opposite_endian=*/sox_false};
242
+ }
243
+
244
+ } // namespace sox_utils
245
+ } // namespace torchaudio