torchaudio 0.1.0 → 0.1.1
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +5 -0
- data/README.md +36 -0
- data/ext/torchaudio/ext.cpp +54 -12
- data/lib/torchaudio.rb +70 -0
- data/lib/torchaudio/functional.rb +285 -0
- data/lib/torchaudio/transforms/mel_scale.rb +39 -0
- data/lib/torchaudio/transforms/mel_spectrogram.rb +35 -0
- data/lib/torchaudio/transforms/mu_law_decoding.rb +14 -0
- data/lib/torchaudio/transforms/mu_law_encoding.rb +14 -0
- data/lib/torchaudio/transforms/spectrogram.rb +27 -0
- data/lib/torchaudio/version.rb +1 -1
- metadata +10 -4
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: 354a8636b2a04fd0ed60db93e2667f630f1db04ab047a6eb6e0c3607190c58d5
|
4
|
+
data.tar.gz: cde6229f8ba416c5a72bbea6d47e26fd987b6f13073f1ca9a1e6239cd62b79c1
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: 76c9c637e4f12700c16bbf343104bb46f5e4f04161a3556cd0600ea236222902ac0ba3b9504f4bdb28a9919317a1b791c8f7b4dbcb48b35ed4c674de39a83277
|
7
|
+
data.tar.gz: '019d89cd2231c386c3458447358058618a6de78b77ed5b4d6d279ee0e3879946bda03cf339051136069dd552d6eeaed33bf40defa9d6c90c9a07eb27d68fdce2'
|
data/CHANGELOG.md
CHANGED
data/README.md
CHANGED
@@ -2,6 +2,8 @@
|
|
2
2
|
|
3
3
|
:fire: An audio library for Torch.rb
|
4
4
|
|
5
|
+
[![Build Status](https://travis-ci.org/ankane/torchaudio.svg?branch=master)](https://travis-ci.org/ankane/torchaudio)
|
6
|
+
|
5
7
|
## Installation
|
6
8
|
|
7
9
|
First, [install SoX](#sox-installation). For Homebrew, use:
|
@@ -20,6 +22,40 @@ gem 'torchaudio'
|
|
20
22
|
|
21
23
|
This library follows the [Python API](https://pytorch.org/audio/). Many methods and options are missing at the moment. PRs welcome!
|
22
24
|
|
25
|
+
## Tutorial
|
26
|
+
|
27
|
+
- [PyTorch tutorial](https://pytorch.org/tutorials/beginner/audio_preprocessing_tutorial.html)
|
28
|
+
- [Ruby code](examples/tutorial.rb)
|
29
|
+
|
30
|
+
Download the [audio file](https://github.com/pytorch/tutorials/raw/master/_static/img/steam-train-whistle-daniel_simon-converted-from-mp3.wav) and install the [matplotlib](https://github.com/mrkn/matplotlib.rb) gem first.
|
31
|
+
|
32
|
+
## Basics
|
33
|
+
|
34
|
+
Load a file
|
35
|
+
|
36
|
+
```ruby
|
37
|
+
waveform, sample_rate = TorchAudio.load("file.wav")
|
38
|
+
```
|
39
|
+
|
40
|
+
Save a file
|
41
|
+
|
42
|
+
```ruby
|
43
|
+
TorchAudio.save("new.wav", waveform, sample_rate)
|
44
|
+
```
|
45
|
+
|
46
|
+
## Transforms
|
47
|
+
|
48
|
+
```ruby
|
49
|
+
TorchAudio::Transforms::Spectrogram.new.call(waveform)
|
50
|
+
```
|
51
|
+
|
52
|
+
Supported transforms are:
|
53
|
+
|
54
|
+
- MelSpectrogram
|
55
|
+
- MuLawDecoding
|
56
|
+
- MuLawEncoding
|
57
|
+
- Spectrogram
|
58
|
+
|
23
59
|
## Datasets
|
24
60
|
|
25
61
|
Load a dataset
|
data/ext/torchaudio/ext.cpp
CHANGED
@@ -1,33 +1,75 @@
|
|
1
1
|
#include <torchaudio/csrc/sox.h>
|
2
2
|
|
3
|
+
#include <rice/Constructor.hpp>
|
3
4
|
#include <rice/Module.hpp>
|
4
5
|
|
5
6
|
using namespace Rice;
|
6
7
|
|
8
|
+
class SignalInfo {
|
9
|
+
sox_signalinfo_t* value = nullptr;
|
10
|
+
public:
|
11
|
+
SignalInfo(Object o) {
|
12
|
+
if (!o.is_nil()) {
|
13
|
+
value = from_ruby<sox_signalinfo_t*>(o);
|
14
|
+
}
|
15
|
+
}
|
16
|
+
operator sox_signalinfo_t*() {
|
17
|
+
return value;
|
18
|
+
}
|
19
|
+
};
|
20
|
+
|
7
21
|
template<>
|
8
22
|
inline
|
9
|
-
|
23
|
+
SignalInfo from_ruby<SignalInfo>(Object x)
|
10
24
|
{
|
11
|
-
|
12
|
-
return nullptr;
|
13
|
-
}
|
14
|
-
throw std::runtime_error("Unsupported signalinfo");
|
25
|
+
return SignalInfo(x);
|
15
26
|
}
|
16
27
|
|
28
|
+
class EncodingInfo {
|
29
|
+
sox_encodinginfo_t* value = nullptr;
|
30
|
+
public:
|
31
|
+
EncodingInfo(Object o) {
|
32
|
+
if (!o.is_nil()) {
|
33
|
+
value = from_ruby<sox_encodinginfo_t*>(o);
|
34
|
+
}
|
35
|
+
}
|
36
|
+
operator sox_encodinginfo_t*() {
|
37
|
+
return value;
|
38
|
+
}
|
39
|
+
};
|
40
|
+
|
17
41
|
template<>
|
18
42
|
inline
|
19
|
-
|
43
|
+
EncodingInfo from_ruby<EncodingInfo>(Object x)
|
20
44
|
{
|
21
|
-
|
22
|
-
return nullptr;
|
23
|
-
}
|
24
|
-
throw std::runtime_error("Unsupported encodinginfo");
|
45
|
+
return EncodingInfo(x);
|
25
46
|
}
|
26
47
|
|
27
48
|
extern "C"
|
28
49
|
void Init_ext()
|
29
50
|
{
|
30
51
|
Module rb_mTorchAudio = define_module("TorchAudio");
|
31
|
-
|
32
|
-
|
52
|
+
|
53
|
+
Module rb_mExt = define_module_under(rb_mTorchAudio, "Ext")
|
54
|
+
.define_singleton_method(
|
55
|
+
"read_audio_file",
|
56
|
+
*[](const std::string& file_name, at::Tensor output, bool ch_first, int64_t nframes, int64_t offset, SignalInfo si, EncodingInfo ei, const char* ft) {
|
57
|
+
return torch::audio::read_audio_file(file_name, output, ch_first, nframes, offset, si, ei, ft);
|
58
|
+
})
|
59
|
+
.define_singleton_method(
|
60
|
+
"write_audio_file",
|
61
|
+
*[](const std::string& file_name, const at::Tensor& tensor, SignalInfo si, EncodingInfo ei, const char* file_type) {
|
62
|
+
return torch::audio::write_audio_file(file_name, tensor, si, ei, file_type);
|
63
|
+
});
|
64
|
+
|
65
|
+
Class rb_cSignalInfo = define_class_under<sox_signalinfo_t>(rb_mExt, "SignalInfo")
|
66
|
+
.define_constructor(Constructor<sox_signalinfo_t>())
|
67
|
+
.define_method("rate", *[](sox_signalinfo_t self) { return self.rate; })
|
68
|
+
.define_method("channels", *[](sox_signalinfo_t self) { return self.channels; })
|
69
|
+
.define_method("precision", *[](sox_signalinfo_t self) { return self.precision; })
|
70
|
+
.define_method("length", *[](sox_signalinfo_t self) { return self.length; })
|
71
|
+
.define_method("rate=", *[](sox_signalinfo_t self, sox_rate_t rate) { self.rate = rate; })
|
72
|
+
.define_method("channels=", *[](sox_signalinfo_t self, unsigned channels) { self.channels = channels; })
|
73
|
+
.define_method("precision=", *[](sox_signalinfo_t self, unsigned precision) { self.precision = precision; })
|
74
|
+
.define_method("length=", *[](sox_signalinfo_t self, sox_uint64_t length) { self.length = length; });
|
33
75
|
}
|
data/lib/torchaudio.rb
CHANGED
@@ -14,6 +14,12 @@ require "set"
|
|
14
14
|
# modules
|
15
15
|
require "torchaudio/datasets/utils"
|
16
16
|
require "torchaudio/datasets/yesno"
|
17
|
+
require "torchaudio/functional"
|
18
|
+
require "torchaudio/transforms/mel_scale"
|
19
|
+
require "torchaudio/transforms/mel_spectrogram"
|
20
|
+
require "torchaudio/transforms/mu_law_encoding"
|
21
|
+
require "torchaudio/transforms/mu_law_decoding"
|
22
|
+
require "torchaudio/transforms/spectrogram"
|
17
23
|
require "torchaudio/version"
|
18
24
|
|
19
25
|
module TorchAudio
|
@@ -73,6 +79,70 @@ module TorchAudio
|
|
73
79
|
load(filepath, **kwargs)
|
74
80
|
end
|
75
81
|
|
82
|
+
# Writes +src+ to +filepath+ at +sample_rate+ with +precision+ as the bit
# precision. Fills an Ext::SignalInfo from the tensor's shape and delegates
# the actual write to save_encinfo.
def save(filepath, src, sample_rate, precision: 16, channels_first: true)
  channel_dim = channels_first ? 0 : 1
  num_channels = src.dim == 1 ? 1 : src.size(channel_dim)

  signal = Ext::SignalInfo.new
  signal.rate = sample_rate
  signal.channels = num_channels
  signal.length = src.numel
  signal.precision = precision

  save_encinfo(filepath, src, channels_first: channels_first, signalinfo: signal)
end
|
91
|
+
|
92
|
+
# Writes +src+ to +filepath+ through the SoX extension.
#
# filepath       - destination path; its parent directory must already exist
# src            - tensor of samples, 1-D (mono) or 2-D; with channels_first
#                  the layout is C x L, otherwise L x C
# channels_first - whether the channel dimension comes first in +src+
# signalinfo     - optional Ext::SignalInfo; an integral rate is converted to
#                  Float and an integral precision to Integer
# encodinginfo   - optional encoding info, passed through to the extension
# filetype       - file type used only when +filepath+ has no extension
#
# Raises RuntimeError when the directory does not exist, and ArgumentError for
# unsupported shapes (> 2 dims or more than 16 channels) or non-integral
# rate/precision values.
def save_encinfo(filepath, src, channels_first: true, signalinfo: nil, encodinginfo: nil, filetype: nil)
  ch_idx = channels_first ? 0 : 1

  # check if save directory exists
  abs_dirpath = File.dirname(File.expand_path(filepath))
  unless Dir.exist?(abs_dirpath)
    raise "Directory does not exist: #{abs_dirpath}"
  end
  # check that src is a CPU tensor
  check_input(src)
  # Check/Fix shape of source data
  if src.dim == 1
    # 1d tensors are assumed to be mono signals; unsqueeze out of place so the
    # caller's tensor is not reshaped as a side effect
    src = src.unsqueeze(ch_idx)
  elsif src.dim > 2 || src.size(ch_idx) > 16
    # assumes num_channels < 16
    raise ArgumentError, "Expected format where C < 16, but found #{src.size}"
  end
  # sox stores the sample rate as a float, though practically sample rates are almost always integers
  # convert integers to floats
  if signalinfo
    if signalinfo.rate && !signalinfo.rate.is_a?(Float)
      if signalinfo.rate.to_f == signalinfo.rate
        signalinfo.rate = signalinfo.rate.to_f
      else
        raise ArgumentError, "Sample rate should be a float or int"
      end
    end
    # check if the bit precision (i.e. bits per sample) is an integer
    if signalinfo.precision && !signalinfo.precision.is_a?(Integer)
      if signalinfo.precision.to_i == signalinfo.precision
        signalinfo.precision = signalinfo.precision.to_i
      else
        raise ArgumentError, "Bit precision should be an integer"
      end
    end
  end
  # programs such as librosa normalize the signal, unnormalize if detected.
  # min/max return 0-dim tensors and every Tensor object is truthy in Ruby,
  # so compare plain numbers obtained with item.
  if src.min.item >= -1.0 && src.max.item <= 1.0
    src = src * (1 << 31)
    src = src.long
  end
  # set filetype and allow for files with no extensions
  extension = File.extname(filepath)
  filetype = extension.length > 0 ? extension[1..-1] : filetype
  # transpose from C x L -> L x C
  if channels_first
    src = src.transpose(1, 0)
  end
  # save data to file
  src = src.contiguous
  Ext.write_audio_file(filepath, src, signalinfo, encodinginfo, filetype)
end
|
145
|
+
|
76
146
|
private
|
77
147
|
|
78
148
|
def check_input(src)
|
@@ -0,0 +1,285 @@
|
|
1
|
+
module TorchAudio
  # Stateless audio operations ported from Python's torchaudio.functional.
  # Also reachable through the alias TorchAudio::F defined at the bottom.
  module Functional
    class << self
      # Computes a spectrogram with Torch.stft.
      #
      # waveform   - tensor whose last dimension holds the samples; leading
      #              dimensions are flattened into a batch and restored afterwards
      # pad        - zeros added to both ends of the last dimension (skipped when 0)
      # window     - window tensor handed to Torch.stft
      # n_fft      - FFT size
      # hop_length - hop between STFT frames
      # win_length - window size
      # power      - exponent applied to the magnitude; when nil/false the raw
      #              Torch.stft output is returned
      # normalized - when true, divide by the L2 norm of +window+
      def spectrogram(waveform, pad, window, n_fft, hop_length, win_length, power, normalized)
        if pad > 0
          # TODO add "with torch.no_grad():" back when JIT supports it
          # Torch::NN::Functional.pad takes the mode as a keyword argument
          waveform = Torch::NN::Functional.pad(waveform, [pad, pad], mode: "constant")
        end

        # pack batch
        shape = waveform.size
        waveform = waveform.reshape(-1, shape[-1])

        # default values are consistent with librosa.core.spectrum._spectrogram
        spec_f = Torch.stft(
          waveform, n_fft, hop_length: hop_length, win_length: win_length, window: window, center: true, pad_mode: "reflect", normalized: false, onesided: true
        )

        # unpack batch
        spec_f = spec_f.reshape(shape[0..-2] + spec_f.shape[-3..-1])

        if normalized
          spec_f.div!(window.pow(2.0).sum.sqrt)
        end
        if power
          spec_f = complex_norm(spec_f, power: power)
        end

        spec_f
      end

      # Mu-law companding. Returns int64 codes; when +x+ lies in [-1, 1] the
      # codes fall in 0..(quantization_channels - 1). Non-floating input is
      # converted to float first.
      def mu_law_encoding(x, quantization_channels)
        mu = quantization_channels - 1.0
        if !x.floating_point?
          x = x.to(dtype: :float)
        end
        mu = Torch.tensor(mu, dtype: x.dtype)
        x_mu = Torch.sign(x) * Torch.log1p(mu * Torch.abs(x)) / Torch.log1p(mu)
        x_mu = ((x_mu + 1) / 2 * mu + 0.5).to(dtype: :int64)
        x_mu
      end

      # Inverse of mu_law_encoding: maps integer codes back to floating-point
      # samples.
      def mu_law_decoding(x_mu, quantization_channels)
        mu = quantization_channels - 1.0
        if !x_mu.floating_point?
          x_mu = x_mu.to(dtype: :float)
        end
        mu = Torch.tensor(mu, dtype: x_mu.dtype)
        x = ((x_mu) / mu) * 2 - 1.0
        x = Torch.sign(x) * (Torch.exp(Torch.abs(x) * Torch.log1p(mu)) - 1.0) / mu
        x
      end

      # Sums squares over the last dimension and raises the result to
      # 0.5 * +power+ (assumes the last dimension holds real/imaginary parts).
      def complex_norm(complex_tensor, power: 1.0)
        complex_tensor.pow(2.0).sum(-1).pow(0.5 * power)
      end

      # Builds a triangular mel filterbank of shape (n_freqs, n_mels).
      # +norm+ may be nil or "slaney" (area normalization per filter).
      def create_fb_matrix(n_freqs, f_min, f_max, n_mels, sample_rate, norm: nil)
        if norm && norm != "slaney"
          raise ArgumentError, "norm must be one of None or 'slaney'"
        end

        # freq bins
        # Equivalent filterbank construction by Librosa
        all_freqs = Torch.linspace(0, sample_rate.div(2), n_freqs)

        # calculate mel freq bins
        # hertz to mel(f) is 2595. * math.log10(1. + (f / 700.))
        m_min = 2595.0 * Math.log10(1.0 + (f_min / 700.0))
        m_max = 2595.0 * Math.log10(1.0 + (f_max / 700.0))
        m_pts = Torch.linspace(m_min, m_max, n_mels + 2)
        # mel to hertz(mel) is 700. * (10**(mel / 2595.) - 1.)
        f_pts = (Torch.pow(10, m_pts / 2595.0) - 1.0) * 700.0
        # calculate the difference between each mel point and each stft freq point in hertz
        f_diff = f_pts[1..-1] - f_pts[0...-1] # (n_mels + 1)
        slopes = f_pts.unsqueeze(0) - all_freqs.unsqueeze(1) # (n_freqs, n_mels + 2)
        # create overlapping triangles
        zero = Torch.zeros(1)
        down_slopes = (slopes[0..-1, 0...-2] * -1.0) / f_diff[0...-1] # (n_freqs, n_mels)
        up_slopes = slopes[0..-1, 2..-1] / f_diff[1..-1] # (n_freqs, n_mels)
        fb = Torch.max(zero, Torch.min(down_slopes, up_slopes))

        if norm && norm == "slaney"
          # Slaney-style mel is scaled to be approx constant energy per channel
          # (Python's f_pts[:n_mels] is the Ruby range 0...n_mels)
          enorm = 2.0 / (f_pts[2...(n_mels + 2)] - f_pts[0...n_mels])
          fb *= enorm.unsqueeze(0)
        end

        fb
      end

      # Delta coefficients along the last dimension of +specgram+, using a
      # +win_length+-point window (must be >= 3) and the given padding +mode+.
      # The output has the same shape as the input.
      def compute_deltas(specgram, win_length: 5, mode: "replicate")
        device = specgram.device
        dtype = specgram.dtype

        # pack batch
        shape = specgram.size
        specgram = specgram.reshape(1, -1, shape[-1])

        raise ArgumentError, "win_length must be >= 3" unless win_length >= 3

        n = (win_length - 1).div(2)

        # twice sum of integer squared (always divisible by 3, so integer division is exact)
        denom = n * (n + 1) * (2 * n + 1) / 3

        specgram = Torch::NN::Functional.pad(specgram, [n, n], mode: mode)

        kernel = Torch.arange(-n, n + 1, 1, device: device, dtype: dtype).repeat([specgram.shape[1], 1, 1])

        output = Torch::NN::Functional.conv1d(specgram, kernel, groups: specgram.shape[1]) / denom

        # unpack batch
        output = output.reshape(shape)
      end

      # Scales +waveform+ by +gain_db+ decibels, i.e. by 10 ** (gain_db / 20).
      # Float division keeps Integer gains (e.g. gain_db: 6) from truncating.
      def gain(waveform, gain_db: 1.0)
        return waveform if gain_db == 0

        ratio = 10 ** (gain_db / 20.0)

        waveform * ratio
      end

      # Adds dither noise drawn per +density_function+ ("TPDF", "RPDF" or
      # "GPDF") and requantizes. Noise shaping is not implemented and raises.
      def dither(waveform, density_function: "TPDF", noise_shaping: false)
        dithered = _apply_probability_distribution(waveform, density_function: density_function)

        if noise_shaping
          raise "Not implemented yet"
          # _add_noise_shaping(dithered, waveform)
        else
          dithered
        end
      end

      # Second-order IIR filter with numerator (b0, b1, b2) and denominator
      # (a0, a1, a2); delegates to lfilter (which clamps to [-1, 1] by default).
      def biquad(waveform, b0, b1, b2, a0, a1, a2)
        device = waveform.device
        dtype = waveform.dtype

        output_waveform = lfilter(
          waveform,
          Torch.tensor([a0, a1, a2], dtype: dtype, device: device),
          Torch.tensor([b0, b1, b2], dtype: dtype, device: device)
        )
        output_waveform
      end

      # High-pass biquad with cutoff +cutoff_freq+ (same units as +sample_rate+)
      # and quality factor +q+.
      def highpass_biquad(waveform, sample_rate, cutoff_freq, q: 0.707)
        w0 = 2 * Math::PI * cutoff_freq / sample_rate
        alpha = Math.sin(w0) / 2.0 / q

        b0 = (1 + Math.cos(w0)) / 2
        b1 = -1 - Math.cos(w0)
        b2 = b0
        a0 = 1 + alpha
        a1 = -2 * Math.cos(w0)
        a2 = 1 - alpha
        biquad(waveform, b0, b1, b2, a0, a1, a2)
      end

      # Low-pass biquad with cutoff +cutoff_freq+ (same units as +sample_rate+)
      # and quality factor +q+.
      def lowpass_biquad(waveform, sample_rate, cutoff_freq, q: 0.707)
        w0 = 2 * Math::PI * cutoff_freq / sample_rate
        alpha = Math.sin(w0) / 2 / q

        b0 = (1 - Math.cos(w0)) / 2
        b1 = 1 - Math.cos(w0)
        b2 = b0
        a0 = 1 + alpha
        a1 = -2 * Math.cos(w0)
        a2 = 1 - alpha
        biquad(waveform, b0, b1, b2, a0, a1, a2)
      end

      # IIR filter over the last dimension of +waveform+ with denominator
      # +a_coeffs+ and numerator +b_coeffs+ (equal length, same device).
      # Coefficients are normalized by a_coeffs[0]. When +clamp+ is true the
      # output is clamped to [-1, 1]. Raises ArgumentError on mismatched inputs.
      def lfilter(waveform, a_coeffs, b_coeffs, clamp: true)
        # pack batch
        shape = waveform.size
        waveform = waveform.reshape(-1, shape[-1])

        raise ArgumentError unless (a_coeffs.size(0) == b_coeffs.size(0))
        raise ArgumentError unless (waveform.size.length == 2)
        raise ArgumentError unless (waveform.device == a_coeffs.device)
        raise ArgumentError unless (b_coeffs.device == a_coeffs.device)

        device = waveform.device
        dtype = waveform.dtype
        n_channel, n_sample = waveform.size
        n_order = a_coeffs.size(0)
        n_sample_padded = n_sample + n_order - 1
        raise ArgumentError unless (n_order > 0)

        # Pad the input and create output
        padded_waveform = Torch.zeros(n_channel, n_sample_padded, dtype: dtype, device: device)
        padded_waveform[0..-1, (n_order - 1)..-1] = waveform
        padded_output_waveform = Torch.zeros(n_channel, n_sample_padded, dtype: dtype, device: device)

        # Set up the coefficients matrix
        # Flip coefficients' order
        a_coeffs_flipped = a_coeffs.flip([0])
        b_coeffs_flipped = b_coeffs.flip([0])

        # calculate windowed_input_signal in parallel
        # create indices of original with shape (n_channel, n_order, n_sample)
        window_idxs = Torch.arange(n_sample, device: device).unsqueeze(0) + Torch.arange(n_order, device: device).unsqueeze(1)
        window_idxs = window_idxs.repeat([n_channel, 1, 1])
        window_idxs += (Torch.arange(n_channel, device: device).unsqueeze(-1).unsqueeze(-1) * n_sample_padded)
        window_idxs = window_idxs.long
        # (n_order, ) matmul (n_channel, n_order, n_sample) -> (n_channel, n_sample)
        input_signal_windows = Torch.matmul(b_coeffs_flipped, Torch.take(padded_waveform, window_idxs))

        input_signal_windows.div!(a_coeffs[0])
        a_coeffs_flipped.div!(a_coeffs[0])
        # the recursive (feedback) part must run sample by sample
        input_signal_windows.t.each_with_index do |o0, i_sample|
          windowed_output_signal = padded_output_waveform[0..-1, i_sample...(i_sample + n_order)]
          o0.addmv!(windowed_output_signal, a_coeffs_flipped, alpha: -1)
          padded_output_waveform[0..-1, i_sample + n_order - 1] = o0
        end

        output = padded_output_waveform[0..-1, (n_order - 1)..-1]

        if clamp
          output = Torch.clamp(output, -1.0, 1.0)
        end

        # unpack batch
        output = output.reshape(shape[0...-1] + output.shape[-1..-1])

        output
      end

      private

      # Scales to 16-bit range, adds noise per +density_function+, rounds and
      # scales back. Used by dither.
      def _apply_probability_distribution(waveform, density_function: "TPDF")
        # pack batch
        shape = waveform.size
        waveform = waveform.reshape(-1, shape[-1])

        channel_size = waveform.size[0] - 1
        time_size = waveform.size[-1] - 1

        random_channel = channel_size > 0 ? Torch.randint(channel_size, [1]).item.to_i : 0
        random_time = time_size > 0 ? Torch.randint(time_size, [1]).item.to_i : 0

        number_of_bits = 16
        up_scaling = 2 ** (number_of_bits - 1) - 2
        signal_scaled = waveform * up_scaling
        down_scaling = 2 ** (number_of_bits - 1)

        signal_scaled_dis = waveform
        if density_function == "RPDF"
          rpdf = waveform[random_channel][random_time] - 0.5

          signal_scaled_dis = signal_scaled + rpdf
        elsif density_function == "GPDF"
          # TODO Replace by distribution code once
          # https://github.com/pytorch/pytorch/issues/29843 is resolved
          # gaussian = torch.distributions.normal.Normal(torch.mean(waveform, -1), 1).sample()

          num_rand_variables = 6

          gaussian = waveform[random_channel][random_time]
          # Array * Integer repeats the array (Integer * Array raises TypeError)
          ([time_size] * num_rand_variables).each do |ws|
            # randint needs a positive upper bound; mirror the guards used above
            rand_chan = channel_size > 0 ? Torch.randint(channel_size, [1]).item.to_i : 0
            rand_time = ws > 0 ? Torch.randint(ws, [1]).item.to_i : 0
            gaussian += waveform[rand_chan][rand_time]
          end

          signal_scaled_dis = signal_scaled + gaussian
        else
          # dtype needed for https://github.com/pytorch/pytorch/issues/32358
          # TODO add support for dtype and device to bartlett_window
          tpdf = Torch.bartlett_window(time_size + 1).to(signal_scaled.device, dtype: signal_scaled.dtype)
          tpdf = tpdf.repeat([channel_size + 1, 1])
          signal_scaled_dis = signal_scaled + tpdf
        end

        quantised_signal_scaled = Torch.round(signal_scaled_dis)
        quantised_signal = quantised_signal_scaled / down_scaling

        # unpack batch
        quantised_signal.reshape(shape[0...-1] + quantised_signal.shape[-1..-1])
      end
    end
  end

  F = Functional
end
|
@@ -0,0 +1,39 @@
|
|
1
|
+
module TorchAudio
  module Transforms
    # Projects a linear-frequency spectrogram onto the mel scale with the
    # triangular filterbank from F.create_fb_matrix.
    class MelScale < Torch::NN::Module
      # n_mels      - number of mel bins
      # sample_rate - sample rate of the audio
      # f_min/f_max - frequency range; f_max falls back to sample_rate.div(2) as a Float
      # n_stft      - number of STFT bins; when nil the filterbank is built on the
      #               first forward call from the input's frequency dimension
      def initialize(n_mels: 128, sample_rate: 16000, f_min: 0.0, f_max: nil, n_stft: nil)
        super()
        @n_mels = n_mels
        @sample_rate = sample_rate
        @f_max = f_max || sample_rate.div(2).to_f
        @f_min = f_min

        unless f_min <= @f_max
          raise ArgumentError, format("Require f_min: %f < f_max: %f", f_min, @f_max)
        end

        filterbank =
          if n_stft.nil?
            Torch.empty(0)
          else
            F.create_fb_matrix(n_stft, @f_min, @f_max, @n_mels, @sample_rate)
          end
        register_buffer("fb", filterbank)
      end

      def forward(specgram)
        orig_shape = specgram.size
        # pack batch: (..., freq, time) -> (batch, freq, time)
        packed = specgram.reshape(-1, orig_shape[-2], orig_shape[-1])

        if @fb.numel == 0
          built = F.create_fb_matrix(packed.size(1), @f_min, @f_max, @n_mels, @sample_rate)
          # fill the registered buffer in place instead of reassigning it
          @fb.resize!(built.size)
          @fb.copy!(built)
        end

        # (batch, time, freq) x (freq, n_mels) -> (batch, time, n_mels), then back
        mel = Torch.matmul(packed.transpose(1, 2), @fb).transpose(1, 2)

        # unpack batch
        mel.reshape(orig_shape[0...-2] + mel.shape[-2..-1])
      end
    end
  end
end
|
@@ -0,0 +1,35 @@
|
|
1
|
+
module TorchAudio
  module Transforms
    # Waveform -> mel spectrogram: a Spectrogram followed by a MelScale.
    class MelSpectrogram < Torch::NN::Module
      def initialize(
        sample_rate: 16000, n_fft: 400, win_length: nil, hop_length: nil, f_min: 0.0,
        f_max: nil, pad: 0, n_mels: 128, window_fn: Torch.method(:hann_window),
        power: 2.0, normalized: false, wkwargs: nil
      )

        super()
        @sample_rate = sample_rate
        @n_fft = n_fft
        @win_length = win_length || n_fft
        @hop_length = hop_length || @win_length.div(2)
        @pad = pad
        @power = power
        @normalized = normalized
        @n_mels = n_mels # number of mel frequency bins
        @f_max = f_max
        @f_min = f_min

        stft_options = {
          n_fft: @n_fft, win_length: @win_length, hop_length: @hop_length, pad: @pad,
          window_fn: window_fn, power: @power, normalized: @normalized, wkwargs: wkwargs
        }
        @spectrogram = Spectrogram.new(**stft_options)
        # a onesided STFT yields n_fft.div(2) + 1 frequency bins
        @mel_scale = MelScale.new(
          n_mels: @n_mels, sample_rate: @sample_rate, f_min: @f_min, f_max: @f_max,
          n_stft: @n_fft.div(2) + 1
        )
      end

      def forward(waveform)
        @mel_scale.call(@spectrogram.call(waveform))
      end
    end
  end
end
|
@@ -0,0 +1,14 @@
|
|
1
|
+
module TorchAudio
  module Transforms
    # Module form of F.mu_law_decoding: turns mu-law codes back into samples
    # using +quantization_channels+ levels.
    class MuLawDecoding < Torch::NN::Module
      def initialize(quantization_channels: 256)
        super()
        @quantization_channels = quantization_channels
      end

      # encoded - tensor of mu-law codes
      def forward(encoded)
        channels = @quantization_channels
        F.mu_law_decoding(encoded, channels)
      end
    end
  end
end
|
@@ -0,0 +1,14 @@
|
|
1
|
+
module TorchAudio
  module Transforms
    # Module form of F.mu_law_encoding: companding into integer codes using
    # +quantization_channels+ levels.
    class MuLawEncoding < Torch::NN::Module
      def initialize(quantization_channels: 256)
        super()
        @quantization_channels = quantization_channels
      end

      # signal - tensor of samples to encode
      def forward(signal)
        channels = @quantization_channels
        F.mu_law_encoding(signal, channels)
      end
    end
  end
end
|
@@ -0,0 +1,27 @@
|
|
1
|
+
module TorchAudio
  module Transforms
    # Waveform -> spectrogram via F.spectrogram. The analysis window is built
    # once from +window_fn+ and kept as a registered buffer.
    class Spectrogram < Torch::NN::Module
      def initialize(
        n_fft: 400, win_length: nil, hop_length: nil, pad: 0,
        window_fn: Torch.method(:hann_window), power: 2.0, normalized: false, wkwargs: nil
      )

        super()
        # FFT size; with onesided: true the output has n_fft.div(2) + 1 frequency bins
        @n_fft = n_fft
        @win_length = win_length || n_fft
        # hop defaults to half the window (floor division)
        @hop_length = hop_length || @win_length.div(2)
        window =
          if wkwargs.nil?
            window_fn.call(@win_length)
          else
            window_fn.call(@win_length, **wkwargs)
          end
        register_buffer("window", window)
        @pad = pad
        @power = power
        @normalized = normalized
      end

      def forward(waveform)
        F.spectrogram(
          waveform, @pad, @window, @n_fft, @hop_length, @win_length, @power, @normalized
        )
      end
    end
  end
end
|
data/lib/torchaudio/version.rb
CHANGED
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: torchaudio
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.1.
|
4
|
+
version: 0.1.1
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2020-08-
|
11
|
+
date: 2020-08-26 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: torch-rb
|
@@ -16,14 +16,14 @@ dependencies:
|
|
16
16
|
requirements:
|
17
17
|
- - ">="
|
18
18
|
- !ruby/object:Gem::Version
|
19
|
-
version: 0.3.
|
19
|
+
version: 0.3.4
|
20
20
|
type: :runtime
|
21
21
|
prerelease: false
|
22
22
|
version_requirements: !ruby/object:Gem::Requirement
|
23
23
|
requirements:
|
24
24
|
- - ">="
|
25
25
|
- !ruby/object:Gem::Version
|
26
|
-
version: 0.3.
|
26
|
+
version: 0.3.4
|
27
27
|
- !ruby/object:Gem::Dependency
|
28
28
|
name: rice
|
29
29
|
requirement: !ruby/object:Gem::Requirement
|
@@ -118,6 +118,12 @@ files:
|
|
118
118
|
- lib/torchaudio.rb
|
119
119
|
- lib/torchaudio/datasets/utils.rb
|
120
120
|
- lib/torchaudio/datasets/yesno.rb
|
121
|
+
- lib/torchaudio/functional.rb
|
122
|
+
- lib/torchaudio/transforms/mel_scale.rb
|
123
|
+
- lib/torchaudio/transforms/mel_spectrogram.rb
|
124
|
+
- lib/torchaudio/transforms/mu_law_decoding.rb
|
125
|
+
- lib/torchaudio/transforms/mu_law_encoding.rb
|
126
|
+
- lib/torchaudio/transforms/spectrogram.rb
|
121
127
|
- lib/torchaudio/version.rb
|
122
128
|
homepage: https://github.com/ankane/torchaudio
|
123
129
|
licenses:
|