torchaudio 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,71 @@
1
+ #include <sox.h>
2
+
3
+ #include <string>
4
+ #include <tuple>
5
+ #include <vector>
6
+ #include <unistd.h>
7
+
8
+ // same as <torch/extension.h> without <torch/python.h>
9
+ #include <torch/all.h>
10
+
11
+ namespace at {
12
+ struct Tensor;
13
+ } // namespace at
14
+
15
+ namespace torch { namespace audio {
16
+
17
+ /// Reads an audio file from the given `file_name` into the `output` `Tensor` and
18
+ /// returns the sample rate of the audio file.
19
+ /// Throws `std::runtime_error` if the audio file could not be opened, or an
20
+ /// error occurred during reading of the audio data.
21
+ int read_audio_file(
22
+ const std::string& file_name,
23
+ at::Tensor output,
24
+ bool ch_first,
25
+ int64_t nframes,
26
+ int64_t offset,
27
+ sox_signalinfo_t* si,
28
+ sox_encodinginfo_t* ei,
29
+ const char* ft);
30
+
31
+ /// Writes the data of a `Tensor` into an audio file at the given `file_name`, with
32
+ /// a certain extension (e.g. `wav` or `mp3`) and sample rate.
33
+ /// Throws `std::runtime_error` when the audio file could not be opened for
34
+ /// writing, or an error occurred during writing of the audio data.
35
+ void write_audio_file(
36
+ const std::string& file_name,
37
+ const at::Tensor& tensor,
38
+ sox_signalinfo_t* si,
39
+ sox_encodinginfo_t* ei,
40
+ const char* file_type);
41
+
42
+ /// Reads an audio file from the given `file_name` and returns a tuple of
43
+ /// sox_signalinfo_t and sox_encodinginfo_t, which contain information about
44
+ /// the audio file such as sample rate, length, bit precision, encoding and more.
45
+ /// Throws `std::runtime_error` if the audio file could not be opened, or an
46
+ /// error occurred during reading of the audio data.
47
+ std::tuple<sox_signalinfo_t, sox_encodinginfo_t> get_info(
48
+ const std::string& file_name);
49
+
50
/// Describes one effect handed to `build_flow_effects`: a name plus its
/// option strings.
struct SoxEffect {
  /// Starts with an empty name and exactly one empty option string.
  SoxEffect() : ename(), eopts(1, std::string()) {}

  std::string ename;
  std::vector<std::string> eopts;
};
56
+
57
+ /// Build a SoX chain, flow the effects, and capture the results in a tensor.
58
+ /// An audio file from the given `file_name` flows through an effects chain given
59
+ /// by a list of effects and effect options to an output buffer which is encoded
60
+ /// into memory to a target signal type and target signal encoding. The resulting
61
+ /// buffer is then placed into `otensor`. This function returns the sample
62
+ /// rate of the output.
63
+ int build_flow_effects(const std::string& file_name,
64
+ at::Tensor otensor,
65
+ bool ch_first,
66
+ sox_signalinfo_t* target_signal,
67
+ sox_encodinginfo_t* target_encoding,
68
+ const char* file_type,
69
+ std::vector<SoxEffect> pyeffs,
70
+ int max_num_eopts);
71
+ }} // namespace torch::audio
@@ -0,0 +1,54 @@
1
+ #include <sox.h>
2
+ #include <torchaudio/csrc/sox_effects.h>
3
+
4
+ using namespace torch::indexing;
5
+
6
+ namespace torchaudio {
7
+ namespace sox_effects {
8
+
9
namespace {

/// Lifecycle states of the SoX effects resource, consulted by
/// `initialize_sox_effects()` and `shutdown_sox_effects()`.
enum SoxEffectsResourceState {
  NotInitialized,
  Initialized,
  ShutDown,
};

/// Current lifecycle state; starts out as `NotInitialized`.
SoxEffectsResourceState SOX_RESOURCE_STATE = NotInitialized;

} // namespace
15
+
16
+ void initialize_sox_effects() {
17
+ if (SOX_RESOURCE_STATE == ShutDown) {
18
+ throw std::runtime_error(
19
+ "SoX Effects has been shut down. Cannot initialize again.");
20
+ }
21
+ if (SOX_RESOURCE_STATE == NotInitialized) {
22
+ if (sox_init() != SOX_SUCCESS) {
23
+ throw std::runtime_error("Failed to initialize sox effects.");
24
+ };
25
+ SOX_RESOURCE_STATE = Initialized;
26
+ }
27
+ };
28
+
29
+ void shutdown_sox_effects() {
30
+ if (SOX_RESOURCE_STATE == NotInitialized) {
31
+ throw std::runtime_error(
32
+ "SoX Effects is not initialized. Cannot shutdown.");
33
+ }
34
+ if (SOX_RESOURCE_STATE == Initialized) {
35
+ if (sox_quit() != SOX_SUCCESS) {
36
+ throw std::runtime_error("Failed to initialize sox effects.");
37
+ };
38
+ SOX_RESOURCE_STATE = ShutDown;
39
+ }
40
+ }
41
+
42
+ std::vector<std::string> list_effects() {
43
+ std::vector<std::string> names;
44
+ const sox_effect_fn_t* fns = sox_get_effect_fns();
45
+ for (int i = 0; fns[i]; ++i) {
46
+ const sox_effect_handler_t* handler = fns[i]();
47
+ if (handler && handler->name)
48
+ names.push_back(handler->name);
49
+ }
50
+ return names;
51
+ }
52
+
53
+ } // namespace sox_effects
54
+ } // namespace torchaudio
@@ -0,0 +1,18 @@
1
+ #ifndef TORCHAUDIO_SOX_EFFECTS_H
2
+ #define TORCHAUDIO_SOX_EFFECTS_H
3
+
4
+ #include <torch/script.h>
5
+
6
+ namespace torchaudio {
7
+ namespace sox_effects {
8
+
9
+ void initialize_sox_effects();
10
+
11
+ void shutdown_sox_effects();
12
+
13
+ std::vector<std::string> list_effects();
14
+
15
+ } // namespace sox_effects
16
+ } // namespace torchaudio
17
+
18
+ #endif
@@ -0,0 +1,170 @@
1
+ #include <sox.h>
2
+ #include <torchaudio/csrc/sox_io.h>
3
+ #include <torchaudio/csrc/sox_utils.h>
4
+
5
+ using namespace torch::indexing;
6
+ using namespace torchaudio::sox_utils;
7
+
8
+ namespace torchaudio {
9
+ namespace sox_io {
10
+
11
+ SignalInfo::SignalInfo(
12
+ const int64_t sample_rate_,
13
+ const int64_t num_channels_,
14
+ const int64_t num_frames_)
15
+ : sample_rate(sample_rate_),
16
+ num_channels(num_channels_),
17
+ num_frames(num_frames_){};
18
+
19
+ int64_t SignalInfo::getSampleRate() const {
20
+ return sample_rate;
21
+ }
22
+
23
+ int64_t SignalInfo::getNumChannels() const {
24
+ return num_channels;
25
+ }
26
+
27
+ int64_t SignalInfo::getNumFrames() const {
28
+ return num_frames;
29
+ }
30
+
31
+ c10::intrusive_ptr<SignalInfo> get_info(const std::string& path) {
32
+ SoxFormat sf(sox_open_read(
33
+ path.c_str(),
34
+ /*signal=*/nullptr,
35
+ /*encoding=*/nullptr,
36
+ /*filetype=*/nullptr));
37
+
38
+ if (static_cast<sox_format_t*>(sf) == nullptr) {
39
+ throw std::runtime_error("Error opening audio file");
40
+ }
41
+
42
+ return c10::make_intrusive<SignalInfo>(
43
+ static_cast<int64_t>(sf->signal.rate),
44
+ static_cast<int64_t>(sf->signal.channels),
45
+ static_cast<int64_t>(sf->signal.length / sf->signal.channels));
46
+ }
47
+
48
+ c10::intrusive_ptr<TensorSignal> load_audio_file(
49
+ const std::string& path,
50
+ const int64_t frame_offset,
51
+ const int64_t num_frames,
52
+ const bool normalize,
53
+ const bool channels_first) {
54
+ if (frame_offset < 0) {
55
+ throw std::runtime_error(
56
+ "Invalid argument: frame_offset must be non-negative.");
57
+ }
58
+ if (num_frames == 0 || num_frames < -1) {
59
+ throw std::runtime_error(
60
+ "Invalid argument: num_frames must be -1 or greater than 0.");
61
+ }
62
+
63
+ SoxFormat sf(sox_open_read(
64
+ path.c_str(),
65
+ /*signal=*/nullptr,
66
+ /*encoding=*/nullptr,
67
+ /*filetype=*/nullptr));
68
+
69
+ validate_input_file(sf);
70
+
71
+ const int64_t num_channels = sf->signal.channels;
72
+ const int64_t num_total_samples = sf->signal.length;
73
+ const int64_t sample_start = sf->signal.channels * frame_offset;
74
+
75
+ if (sox_seek(sf, sample_start, 0) == SOX_EOF) {
76
+ throw std::runtime_error("Error reading audio file: offset past EOF.");
77
+ }
78
+
79
+ const int64_t sample_end = [&]() {
80
+ if (num_frames == -1)
81
+ return num_total_samples;
82
+ const int64_t sample_end_ = num_channels * num_frames + sample_start;
83
+ if (num_total_samples < sample_end_) {
84
+ // For lossy encoding, it is difficult to predict exact size of buffer for
85
+ // reading the number of samples required.
86
+ // So we allocate buffer size of given `num_frames` and ask sox to read as
87
+ // much as possible. For lossless format, sox reads exact number of
88
+ // samples, but for lossy encoding, sox can end up reading less. (i.e.
89
+ // mp3) For the consistent behavior specification between lossy/lossless
90
+ // format, we allow users to provide `num_frames` value that exceeds #of
91
+ // available samples, and we adjust it here.
92
+ return num_total_samples;
93
+ }
94
+ return sample_end_;
95
+ }();
96
+
97
+ const int64_t max_samples = sample_end - sample_start;
98
+
99
+ // Read samples into buffer
100
+ std::vector<sox_sample_t> buffer;
101
+ buffer.reserve(max_samples);
102
+ const int64_t num_samples = sox_read(sf, buffer.data(), max_samples);
103
+ if (num_samples == 0) {
104
+ throw std::runtime_error(
105
+ "Error reading audio file: empty file or read operation failed.");
106
+ }
107
+ // NOTE: num_samples may be smaller than max_samples if the input
108
+ // format is compressed (i.e. mp3).
109
+
110
+ // Convert to Tensor
111
+ auto tensor = convert_to_tensor(
112
+ buffer.data(),
113
+ num_samples,
114
+ num_channels,
115
+ get_dtype(sf->encoding.encoding, sf->signal.precision),
116
+ normalize,
117
+ channels_first);
118
+
119
+ return c10::make_intrusive<TensorSignal>(
120
+ tensor, static_cast<int64_t>(sf->signal.rate), channels_first);
121
+ }
122
+
123
+ void save_audio_file(
124
+ const std::string& file_name,
125
+ const c10::intrusive_ptr<TensorSignal>& signal,
126
+ const double compression) {
127
+ const auto tensor = signal->getTensor();
128
+ const auto sample_rate = signal->getSampleRate();
129
+ const auto channels_first = signal->getChannelsFirst();
130
+
131
+ validate_input_tensor(tensor);
132
+
133
+ const auto filetype = get_filetype(file_name);
134
+ const auto signal_info =
135
+ get_signalinfo(tensor, sample_rate, channels_first, filetype);
136
+ const auto encoding_info =
137
+ get_encodinginfo(filetype, tensor.dtype(), compression);
138
+
139
+ SoxFormat sf(sox_open_write(
140
+ file_name.c_str(),
141
+ &signal_info,
142
+ &encoding_info,
143
+ /*filetype=*/filetype.c_str(),
144
+ /*oob=*/nullptr,
145
+ /*overwrite_permitted=*/nullptr));
146
+
147
+ if (static_cast<sox_format_t*>(sf) == nullptr) {
148
+ throw std::runtime_error("Error saving audio file: failed to open file.");
149
+ }
150
+
151
+ auto tensor_ = tensor;
152
+ if (channels_first) {
153
+ tensor_ = tensor_.t();
154
+ }
155
+
156
+ const int64_t frames_per_chunk = 65536;
157
+ for (int64_t i = 0; i < tensor_.size(0); i += frames_per_chunk) {
158
+ auto chunk = tensor_.index({Slice(i, i + frames_per_chunk), Slice()});
159
+ chunk = unnormalize_wav(chunk).contiguous();
160
+
161
+ const size_t numel = chunk.numel();
162
+ if (sox_write(sf, chunk.data_ptr<int32_t>(), numel) != numel) {
163
+ throw std::runtime_error(
164
+ "Error saving audio file: failed to write the entier buffer.");
165
+ }
166
+ }
167
+ }
168
+
169
+ } // namespace sox_io
170
+ } // namespace torchaudio
@@ -0,0 +1,41 @@
1
+ #ifndef TORCHAUDIO_SOX_IO_H
2
+ #define TORCHAUDIO_SOX_IO_H
3
+
4
+ #include <torch/script.h>
5
+ #include <torchaudio/csrc/sox_utils.h>
6
+
7
+ namespace torchaudio {
8
+ namespace sox_io {
9
+
10
+ struct SignalInfo : torch::CustomClassHolder {
11
+ int64_t sample_rate;
12
+ int64_t num_channels;
13
+ int64_t num_frames;
14
+
15
+ SignalInfo(
16
+ const int64_t sample_rate_,
17
+ const int64_t num_channels_,
18
+ const int64_t num_frames_);
19
+ int64_t getSampleRate() const;
20
+ int64_t getNumChannels() const;
21
+ int64_t getNumFrames() const;
22
+ };
23
+
24
+ c10::intrusive_ptr<SignalInfo> get_info(const std::string& path);
25
+
26
+ c10::intrusive_ptr<torchaudio::sox_utils::TensorSignal> load_audio_file(
27
+ const std::string& path,
28
+ const int64_t frame_offset = 0,
29
+ const int64_t num_frames = -1,
30
+ const bool normalize = true,
31
+ const bool channels_first = true);
32
+
33
+ void save_audio_file(
34
+ const std::string& file_name,
35
+ const c10::intrusive_ptr<torchaudio::sox_utils::TensorSignal>& signal,
36
+ const double compression = 0.);
37
+
38
+ } // namespace sox_io
39
+ } // namespace torchaudio
40
+
41
+ #endif
@@ -0,0 +1,245 @@
1
+ #include <c10/core/ScalarType.h>
2
+ #include <sox.h>
3
+ #include <torchaudio/csrc/sox_utils.h>
4
+
5
+ namespace torchaudio {
6
+ namespace sox_utils {
7
+
8
+ TensorSignal::TensorSignal(
9
+ torch::Tensor tensor_,
10
+ int64_t sample_rate_,
11
+ bool channels_first_)
12
+ : tensor(tensor_),
13
+ sample_rate(sample_rate_),
14
+ channels_first(channels_first_){};
15
+
16
+ torch::Tensor TensorSignal::getTensor() const {
17
+ return tensor;
18
+ }
19
+ int64_t TensorSignal::getSampleRate() const {
20
+ return sample_rate;
21
+ }
22
+ bool TensorSignal::getChannelsFirst() const {
23
+ return channels_first;
24
+ }
25
+
26
+ SoxFormat::SoxFormat(sox_format_t* fd) noexcept : fd_(fd) {}
27
+ SoxFormat::~SoxFormat() {
28
+ if (fd_ != nullptr) {
29
+ sox_close(fd_);
30
+ }
31
+ }
32
+ sox_format_t* SoxFormat::operator->() const noexcept {
33
+ return fd_;
34
+ }
35
+ SoxFormat::operator sox_format_t*() const noexcept {
36
+ return fd_;
37
+ }
38
+
39
+ void validate_input_file(const SoxFormat& sf) {
40
+ if (static_cast<sox_format_t*>(sf) == nullptr) {
41
+ throw std::runtime_error("Error loading audio file: failed to open file.");
42
+ }
43
+ if (sf->encoding.encoding == SOX_ENCODING_UNKNOWN) {
44
+ throw std::runtime_error("Error loading audio file: unknown encoding.");
45
+ }
46
+ if (sf->signal.length == 0) {
47
+ throw std::runtime_error("Error reading audio file: unkown length.");
48
+ }
49
+ }
50
+
51
+ void validate_input_tensor(const torch::Tensor tensor) {
52
+ if (!tensor.device().is_cpu()) {
53
+ throw std::runtime_error("Input tensor has to be on CPU.");
54
+ }
55
+
56
+ if (tensor.ndimension() != 2) {
57
+ throw std::runtime_error("Input tensor has to be 2D.");
58
+ }
59
+
60
+ const auto dtype = tensor.dtype();
61
+ if (!(dtype == torch::kFloat32 || dtype == torch::kInt32 ||
62
+ dtype == torch::kInt16 || dtype == torch::kUInt8)) {
63
+ throw std::runtime_error(
64
+ "Input tensor has to be one of float32, int32, int16 or uint8 type.");
65
+ }
66
+ }
67
+
68
+ caffe2::TypeMeta get_dtype(
69
+ const sox_encoding_t encoding,
70
+ const unsigned precision) {
71
+ const auto dtype = [&]() {
72
+ switch (encoding) {
73
+ case SOX_ENCODING_UNSIGNED: // 8-bit PCM WAV
74
+ return torch::kUInt8;
75
+ case SOX_ENCODING_SIGN2: // 16-bit or 32-bit PCM WAV
76
+ switch (precision) {
77
+ case 16:
78
+ return torch::kInt16;
79
+ case 32:
80
+ return torch::kInt32;
81
+ default:
82
+ throw std::runtime_error(
83
+ "Only 16 and 32 bits are supported for signed PCM.");
84
+ }
85
+ default:
86
+ // default to float32 for the other formats, including
87
+ // 32-bit flaoting-point WAV,
88
+ // MP3,
89
+ // FLAC,
90
+ // VORBIS etc...
91
+ return torch::kFloat32;
92
+ }
93
+ }();
94
+ return c10::scalarTypeToTypeMeta(dtype);
95
+ }
96
+
97
+ torch::Tensor convert_to_tensor(
98
+ sox_sample_t* buffer,
99
+ const int32_t num_samples,
100
+ const int32_t num_channels,
101
+ const caffe2::TypeMeta dtype,
102
+ const bool normalize,
103
+ const bool channels_first) {
104
+ auto t = torch::from_blob(
105
+ buffer, {num_samples / num_channels, num_channels}, torch::kInt32);
106
+ // Note: Tensor created from_blob does not own data but borrwos
107
+ // So make sure to create a new copy after processing samples.
108
+ if (normalize || dtype == torch::kFloat32) {
109
+ t = t.to(torch::kFloat32);
110
+ t *= (t > 0) / 2147483647. + (t < 0) / 2147483648.;
111
+ } else if (dtype == torch::kInt32) {
112
+ t = t.clone();
113
+ } else if (dtype == torch::kInt16) {
114
+ t.floor_divide_(1 << 16);
115
+ t = t.to(torch::kInt16);
116
+ } else if (dtype == torch::kUInt8) {
117
+ t.floor_divide_(1 << 24);
118
+ t += 128;
119
+ t = t.to(torch::kUInt8);
120
+ } else {
121
+ throw std::runtime_error("Unsupported dtype.");
122
+ }
123
+ if (channels_first) {
124
+ t = t.transpose(1, 0);
125
+ }
126
+ return t.contiguous();
127
+ }
128
+
129
+ torch::Tensor unnormalize_wav(const torch::Tensor input_tensor) {
130
+ const auto dtype = input_tensor.dtype();
131
+ auto tensor = input_tensor;
132
+ if (dtype == torch::kFloat32) {
133
+ double multi_pos = 2147483647.;
134
+ double multi_neg = -2147483648.;
135
+ auto mult = (tensor > 0) * multi_pos - (tensor < 0) * multi_neg;
136
+ tensor = tensor.to(torch::dtype(torch::kFloat64));
137
+ tensor *= mult;
138
+ tensor.clamp_(multi_neg, multi_pos);
139
+ tensor = tensor.to(torch::dtype(torch::kInt32));
140
+ } else if (dtype == torch::kInt32) {
141
+ // already denormalized
142
+ } else if (dtype == torch::kInt16) {
143
+ tensor = tensor.to(torch::dtype(torch::kInt32));
144
+ tensor *= ((tensor != 0) * 65536);
145
+ } else if (dtype == torch::kUInt8) {
146
+ tensor = tensor.to(torch::dtype(torch::kInt32));
147
+ tensor -= 128;
148
+ tensor *= 16777216;
149
+ } else {
150
+ throw std::runtime_error("Unexpected dtype.");
151
+ }
152
+ return tensor;
153
+ }
154
+
155
/// Returns the lower-cased extension of `path` (the text after the last `.`
/// in the final path component), without the dot.
///
/// Returns an empty string when the final component has no `.`; previously a
/// path without any dot was returned whole, and a dot inside a directory name
/// produced a bogus "extension" containing path separators.
const std::string get_filetype(const std::string path) {
  const auto dot = path.find_last_of('.');
  if (dot == std::string::npos) {
    return std::string();
  }
  // A dot that precedes the last path separator belongs to a directory name.
  const auto sep = path.find_last_of("/\\");
  if (sep != std::string::npos && dot < sep) {
    return std::string();
  }

  std::string ext = path.substr(dot + 1);
  // tolower must be given a value representable as unsigned char.
  std::transform(ext.begin(), ext.end(), ext.begin(), [](unsigned char c) {
    return static_cast<char>(::tolower(c));
  });
  return ext;
}
160
+
161
+ sox_encoding_t get_encoding(
162
+ const std::string filetype,
163
+ const caffe2::TypeMeta dtype) {
164
+ if (filetype == "mp3")
165
+ return SOX_ENCODING_MP3;
166
+ if (filetype == "flac")
167
+ return SOX_ENCODING_FLAC;
168
+ if (filetype == "ogg" || filetype == "vorbis")
169
+ return SOX_ENCODING_VORBIS;
170
+ if (filetype == "wav") {
171
+ if (dtype == torch::kUInt8)
172
+ return SOX_ENCODING_UNSIGNED;
173
+ if (dtype == torch::kInt16)
174
+ return SOX_ENCODING_SIGN2;
175
+ if (dtype == torch::kInt32)
176
+ return SOX_ENCODING_SIGN2;
177
+ if (dtype == torch::kFloat32)
178
+ return SOX_ENCODING_FLOAT;
179
+ throw std::runtime_error("Unsupported dtype.");
180
+ }
181
+ throw std::runtime_error("Unsupported file type.");
182
+ }
183
+
184
+ unsigned get_precision(
185
+ const std::string filetype,
186
+ const caffe2::TypeMeta dtype) {
187
+ if (filetype == "mp3")
188
+ return SOX_UNSPEC;
189
+ if (filetype == "flac")
190
+ return 24;
191
+ if (filetype == "ogg" || filetype == "vorbis")
192
+ return SOX_UNSPEC;
193
+ if (filetype == "wav") {
194
+ if (dtype == torch::kUInt8)
195
+ return 8;
196
+ if (dtype == torch::kInt16)
197
+ return 16;
198
+ if (dtype == torch::kInt32)
199
+ return 32;
200
+ if (dtype == torch::kFloat32)
201
+ return 32;
202
+ throw std::runtime_error("Unsupported dtype.");
203
+ }
204
+ throw std::runtime_error("Unsupported file type.");
205
+ }
206
+
207
+ sox_signalinfo_t get_signalinfo(
208
+ const torch::Tensor& tensor,
209
+ const int64_t sample_rate,
210
+ const bool channels_first,
211
+ const std::string filetype) {
212
+ return sox_signalinfo_t{
213
+ /*rate=*/static_cast<sox_rate_t>(sample_rate),
214
+ /*channels=*/static_cast<unsigned>(tensor.size(channels_first ? 0 : 1)),
215
+ /*precision=*/get_precision(filetype, tensor.dtype()),
216
+ /*length=*/static_cast<uint64_t>(tensor.numel())};
217
+ }
218
+
219
+ sox_encodinginfo_t get_encodinginfo(
220
+ const std::string filetype,
221
+ const caffe2::TypeMeta dtype,
222
+ const double compression) {
223
+ const double compression_ = [&]() {
224
+ if (filetype == "mp3")
225
+ return compression;
226
+ if (filetype == "flac")
227
+ return compression;
228
+ if (filetype == "ogg" || filetype == "vorbis")
229
+ return compression;
230
+ if (filetype == "wav")
231
+ return 0.;
232
+ throw std::runtime_error("Unsupported file type.");
233
+ }();
234
+
235
+ return sox_encodinginfo_t{/*encoding=*/get_encoding(filetype, dtype),
236
+ /*bits_per_sample=*/get_precision(filetype, dtype),
237
+ /*compression=*/compression_,
238
+ /*reverse_bytes=*/sox_option_default,
239
+ /*reverse_nibbles=*/sox_option_default,
240
+ /*reverse_bits=*/sox_option_default,
241
+ /*opposite_endian=*/sox_false};
242
+ }
243
+
244
+ } // namespace sox_utils
245
+ } // namespace torchaudio