torchaudio 0.1.0 → 0.1.1

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: b527976494325cc12e81342c25d318204d2d7c75bfba7036be4296769cdb30a0
4
- data.tar.gz: 2cfde7bd1b0e7a1628818d5bd74657cfbfba6dfa83ef42897f3ad0f98e77f739
3
+ metadata.gz: 354a8636b2a04fd0ed60db93e2667f630f1db04ab047a6eb6e0c3607190c58d5
4
+ data.tar.gz: cde6229f8ba416c5a72bbea6d47e26fd987b6f13073f1ca9a1e6239cd62b79c1
5
5
  SHA512:
6
- metadata.gz: 8e6f34b014340b5ace3193ab589dae75ed0869ab7606402bd4b09de6042299e6f3a118d439dd381491f489ce9552bca4376a7d5b4693dddc3d1c5f5b26540900
7
- data.tar.gz: d651c46f5185ceb70ae3d9c90154c77afe29a5c35854d1a9d98913096b7ab9ba39a745242dd268548ca87f9e109b56c96dee9dc5539cf066f9ad0f773eddbdcd
6
+ metadata.gz: 76c9c637e4f12700c16bbf343104bb46f5e4f04161a3556cd0600ea236222902ac0ba3b9504f4bdb28a9919317a1b791c8f7b4dbcb48b35ed4c674de39a83277
7
+ data.tar.gz: '019d89cd2231c386c3458447358058618a6de78b77ed5b4d6d279ee0e3879946bda03cf339051136069dd552d6eeaed33bf40defa9d6c90c9a07eb27d68fdce2'
@@ -1,3 +1,8 @@
1
+ ## 0.1.1 (2020-08-26)
2
+
3
+ - Added `save` method
4
+ - Added transforms
5
+
1
6
  ## 0.1.0 (2020-08-24)
2
7
 
3
8
  - First release
data/README.md CHANGED
@@ -2,6 +2,8 @@
2
2
 
3
3
  :fire: An audio library for Torch.rb
4
4
 
5
+ [![Build Status](https://travis-ci.org/ankane/torchaudio.svg?branch=master)](https://travis-ci.org/ankane/torchaudio)
6
+
5
7
  ## Installation
6
8
 
7
9
  First, [install SoX](#sox-installation). For Homebrew, use:
@@ -20,6 +22,40 @@ gem 'torchaudio'
20
22
 
21
23
  This library follows the [Python API](https://pytorch.org/audio/). Many methods and options are missing at the moment. PRs welcome!
22
24
 
25
+ ## Tutorial
26
+
27
+ - [PyTorch tutorial](https://pytorch.org/tutorials/beginner/audio_preprocessing_tutorial.html)
28
+ - [Ruby code](examples/tutorial.rb)
29
+
30
+ Download the [audio file](https://github.com/pytorch/tutorials/raw/master/_static/img/steam-train-whistle-daniel_simon-converted-from-mp3.wav) and install the [matplotlib](https://github.com/mrkn/matplotlib.rb) gem first.
31
+
32
+ ## Basics
33
+
34
+ Load a file
35
+
36
+ ```ruby
37
+ waveform, sample_rate = TorchAudio.load("file.wav")
38
+ ```
39
+
40
+ Save a file
41
+
42
+ ```ruby
43
+ TorchAudio.save("new.wav", waveform, sample_rate)
44
+ ```
45
+
46
+ ## Transforms
47
+
48
+ ```ruby
49
+ TorchAudio::Transforms::Spectrogram.new.call(waveform)
50
+ ```
51
+
52
+ Supported transforms are:
53
+
54
+ - MelSpectrogram
55
+ - MuLawDecoding
56
+ - MuLawEncoding
57
+ - Spectrogram
58
+
23
59
  ## Datasets
24
60
 
25
61
  Load a dataset
@@ -1,33 +1,75 @@
1
1
  #include <torchaudio/csrc/sox.h>
2
2
 
3
+ #include <rice/Constructor.hpp>
3
4
  #include <rice/Module.hpp>
4
5
 
5
6
  using namespace Rice;
6
7
 
8
+ class SignalInfo {
9
+ sox_signalinfo_t* value = nullptr;
10
+ public:
11
+ SignalInfo(Object o) {
12
+ if (!o.is_nil()) {
13
+ value = from_ruby<sox_signalinfo_t*>(o);
14
+ }
15
+ }
16
+ operator sox_signalinfo_t*() {
17
+ return value;
18
+ }
19
+ };
20
+
7
21
  template<>
8
22
  inline
9
- sox_signalinfo_t* from_ruby<sox_signalinfo_t*>(Object x)
23
+ SignalInfo from_ruby<SignalInfo>(Object x)
10
24
  {
11
- if (x.is_nil()) {
12
- return nullptr;
13
- }
14
- throw std::runtime_error("Unsupported signalinfo");
25
+ return SignalInfo(x);
15
26
  }
16
27
 
28
+ class EncodingInfo {
29
+ sox_encodinginfo_t* value = nullptr;
30
+ public:
31
+ EncodingInfo(Object o) {
32
+ if (!o.is_nil()) {
33
+ value = from_ruby<sox_encodinginfo_t*>(o);
34
+ }
35
+ }
36
+ operator sox_encodinginfo_t*() {
37
+ return value;
38
+ }
39
+ };
40
+
17
41
  template<>
18
42
  inline
19
- sox_encodinginfo_t* from_ruby<sox_encodinginfo_t*>(Object x)
43
+ EncodingInfo from_ruby<EncodingInfo>(Object x)
20
44
  {
21
- if (x.is_nil()) {
22
- return nullptr;
23
- }
24
- throw std::runtime_error("Unsupported encodinginfo");
45
+ return EncodingInfo(x);
25
46
  }
26
47
 
27
48
  extern "C"
28
49
  void Init_ext()
29
50
  {
30
51
  Module rb_mTorchAudio = define_module("TorchAudio");
31
- Module rb_mNN = define_module_under(rb_mTorchAudio, "Ext")
32
- .define_singleton_method("read_audio_file", &torch::audio::read_audio_file);
52
+
53
+ Module rb_mExt = define_module_under(rb_mTorchAudio, "Ext")
54
+ .define_singleton_method(
55
+ "read_audio_file",
56
+ *[](const std::string& file_name, at::Tensor output, bool ch_first, int64_t nframes, int64_t offset, SignalInfo si, EncodingInfo ei, const char* ft) {
57
+ return torch::audio::read_audio_file(file_name, output, ch_first, nframes, offset, si, ei, ft);
58
+ })
59
+ .define_singleton_method(
60
+ "write_audio_file",
61
+ *[](const std::string& file_name, const at::Tensor& tensor, SignalInfo si, EncodingInfo ei, const char* file_type) {
62
+ return torch::audio::write_audio_file(file_name, tensor, si, ei, file_type);
63
+ });
64
+
65
+ Class rb_cSignalInfo = define_class_under<sox_signalinfo_t>(rb_mExt, "SignalInfo")
66
+ .define_constructor(Constructor<sox_signalinfo_t>())
67
+ .define_method("rate", *[](sox_signalinfo_t self) { return self.rate; })
68
+ .define_method("channels", *[](sox_signalinfo_t self) { return self.channels; })
69
+ .define_method("precision", *[](sox_signalinfo_t self) { return self.precision; })
70
+ .define_method("length", *[](sox_signalinfo_t self) { return self.length; })
71
+ .define_method("rate=", *[](sox_signalinfo_t self, sox_rate_t rate) { self.rate = rate; })
72
+ .define_method("channels=", *[](sox_signalinfo_t self, unsigned channels) { self.channels = channels; })
73
+ .define_method("precision=", *[](sox_signalinfo_t self, unsigned precision) { self.precision = precision; })
74
+ .define_method("length=", *[](sox_signalinfo_t self, sox_uint64_t length) { self.length = length; });
33
75
  }
@@ -14,6 +14,12 @@ require "set"
14
14
  # modules
15
15
  require "torchaudio/datasets/utils"
16
16
  require "torchaudio/datasets/yesno"
17
+ require "torchaudio/functional"
18
+ require "torchaudio/transforms/mel_scale"
19
+ require "torchaudio/transforms/mel_spectrogram"
20
+ require "torchaudio/transforms/mu_law_encoding"
21
+ require "torchaudio/transforms/mu_law_decoding"
22
+ require "torchaudio/transforms/spectrogram"
17
23
  require "torchaudio/version"
18
24
 
19
25
  module TorchAudio
@@ -73,6 +79,70 @@ module TorchAudio
73
79
  load(filepath, **kwargs)
74
80
  end
75
81
 
82
# Saves +src+ to +filepath+ by building an Ext::SignalInfo and delegating to
# save_encinfo.
#
# filepath        - destination path
# src             - audio tensor; 1-D is treated as a single channel, otherwise
#                   the channel count is read from dim 0 (channels_first) or dim 1
# sample_rate     - stored as the signal's rate
# precision:      - bits per sample (default 16)
# channels_first: - layout of +src+ (default true)
def save(filepath, src, sample_rate, precision: 16, channels_first: true)
  channel_dim = channels_first ? 0 : 1
  num_channels = src.dim == 1 ? 1 : src.size(channel_dim)

  info = Ext::SignalInfo.new
  info.rate = sample_rate
  info.channels = num_channels
  info.length = src.numel
  info.precision = precision

  save_encinfo(filepath, src, channels_first: channels_first, signalinfo: info)
end
91
+
92
# Writes +src+ to +filepath+ through Ext.write_audio_file.
#
# filepath        - destination path; its directory must already exist. A
#                   non-empty extension takes precedence over +filetype+.
# src             - tensor validated by check_input; 1-D tensors are treated as
#                   mono, otherwise at most 2 dims and at most 16 channels
# channels_first: - true if +src+ is (channels, length), false if (length, channels)
# signalinfo:     - optional Ext::SignalInfo; an integral rate is coerced to Float
# encodinginfo:   - optional, forwarded unchanged (nil allowed)
# filetype:       - file type used when +filepath+ has no extension
#
# Raises RuntimeError when the directory is missing, ArgumentError for an
# unsupported shape or a non-integral rate/precision.
def save_encinfo(filepath, src, channels_first: true, signalinfo: nil, encodinginfo: nil, filetype: nil)
  ch_idx = channels_first ? 0 : 1

  # check if save directory exists
  abs_dirpath = File.dirname(File.expand_path(filepath))
  raise "Directory does not exist: #{abs_dirpath}" unless Dir.exist?(abs_dirpath)

  # check that src is a CPU tensor
  check_input(src)

  # Check/fix shape of source data
  if src.dim == 1
    # 1-D tensors are assumed to be mono signals. Use the out-of-place
    # unsqueeze so the caller's tensor is not reshaped as a side effect.
    src = src.unsqueeze(ch_idx)
  elsif src.dim > 2 || src.size(ch_idx) > 16
    # assumes num_channels < 16
    raise ArgumentError, "Expected format where C < 16, but found #{src.size}"
  end

  if signalinfo
    # sox stores the sample rate as a float, though practically sample rates
    # are almost always integers; convert integral values to floats
    if signalinfo.rate && !signalinfo.rate.is_a?(Float)
      if signalinfo.rate.to_f == signalinfo.rate
        signalinfo.rate = signalinfo.rate.to_f
      else
        raise ArgumentError, "Sample rate should be a float or int"
      end
    end

    # check if the bit precision (i.e. bits per sample) is an integer
    if signalinfo.precision && !signalinfo.precision.is_a?(Integer)
      if signalinfo.precision.to_i == signalinfo.precision
        signalinfo.precision = signalinfo.precision.to_i
      else
        raise ArgumentError, "Bit precision should be an integer"
      end
    end
  end

  # programs such as librosa normalize the signal, unnormalize if detected.
  # Compare plain Ruby numbers (via #item) so the condition is a genuine
  # boolean rather than depending on what Tensor comparison operators return.
  if src.min.item >= -1.0 && src.max.item <= 1.0
    src = (src * (1 << 31)).long
  end

  # set filetype and allow for files with no extensions
  extension = File.extname(filepath)
  filetype = extension[1..-1] unless extension.empty?

  # transpose from C x L -> L x C
  src = src.transpose(1, 0) if channels_first

  # save data to file
  Ext.write_audio_file(filepath, src.contiguous, signalinfo, encodinginfo, filetype)
end
145
+
76
146
  private
77
147
 
78
148
  def check_input(src)
@@ -0,0 +1,285 @@
1
module TorchAudio
  # Stateless audio signal-processing functions (also reachable as TorchAudio::F).
  module Functional
    class << self
      # Computes a spectrogram of +waveform+ with Torch.stft.
      #
      # waveform   - Tensor whose last dimension is time; leading dimensions are
      #              flattened into a batch and restored afterwards
      # pad        - samples of constant padding added to both ends of the last dim
      # window     - window Tensor passed to Torch.stft
      # n_fft, hop_length, win_length - STFT parameters
      # power      - exponent for complex_norm; falsy skips complex_norm and
      #              returns the raw stft output
      # normalized - when true, divide by sqrt(sum(window ** 2))
      def spectrogram(waveform, pad, window, n_fft, hop_length, win_length, power, normalized)
        if pad > 0
          # TODO add "with torch.no_grad():" back when JIT supports it
          # The padding mode is a keyword argument (as in compute_deltas below).
          waveform = Torch::NN::Functional.pad(waveform, [pad, pad], mode: "constant")
        end

        # pack batch
        shape = waveform.size
        waveform = waveform.reshape(-1, shape[-1])

        # default values are consistent with librosa.core.spectrum._spectrogram
        spec_f = Torch.stft(
          waveform, n_fft,
          hop_length: hop_length, win_length: win_length, window: window,
          center: true, pad_mode: "reflect", normalized: false, onesided: true
        )

        # unpack batch
        spec_f = spec_f.reshape(shape[0..-2] + spec_f.shape[-3..-1])

        spec_f.div!(window.pow(2.0).sum.sqrt) if normalized
        spec_f = complex_norm(spec_f, power: power) if power

        spec_f
      end

      # Mu-law encodes +x+ into int64 codes. Non-floating input is cast to float
      # first. Codes lie in [0, quantization_channels - 1] when +x+ is within
      # [-1, 1] (assumption; the range is not checked).
      def mu_law_encoding(x, quantization_channels)
        mu = quantization_channels - 1.0
        x = x.to(dtype: :float) unless x.floating_point?
        mu = Torch.tensor(mu, dtype: x.dtype)
        x_mu = Torch.sign(x) * Torch.log1p(mu * Torch.abs(x)) / Torch.log1p(mu)
        ((x_mu + 1) / 2 * mu + 0.5).to(dtype: :int64)
      end

      # Inverse of mu_law_encoding: maps codes back to a floating-point signal.
      def mu_law_decoding(x_mu, quantization_channels)
        mu = quantization_channels - 1.0
        x_mu = x_mu.to(dtype: :float) unless x_mu.floating_point?
        mu = Torch.tensor(mu, dtype: x_mu.dtype)
        x = ((x_mu) / mu) * 2 - 1.0
        Torch.sign(x) * (Torch.exp(Torch.abs(x) * Torch.log1p(mu)) - 1.0) / mu
      end

      # Sums squares over the last dimension and raises the result to power / 2,
      # i.e. |z| ** power when the last dimension holds (real, imag) pairs.
      def complex_norm(complex_tensor, power: 1.0)
        complex_tensor.pow(2.0).sum(-1).pow(0.5 * power)
      end

      # Builds an (n_freqs, n_mels) triangular mel filter bank.
      #
      # norm: - nil or "slaney" (area normalization per filter); anything else
      #         raises ArgumentError
      def create_fb_matrix(n_freqs, f_min, f_max, n_mels, sample_rate, norm: nil)
        if norm && norm != "slaney"
          raise ArgumentError, "norm must be one of None or 'slaney'"
        end

        # freq bins
        # Equivalent filterbank construction by Librosa
        all_freqs = Torch.linspace(0, sample_rate.div(2), n_freqs)

        # calculate mel freq bins
        # hertz to mel(f) is 2595. * math.log10(1. + (f / 700.))
        m_min = 2595.0 * Math.log10(1.0 + (f_min / 700.0))
        m_max = 2595.0 * Math.log10(1.0 + (f_max / 700.0))
        m_pts = Torch.linspace(m_min, m_max, n_mels + 2)
        # mel to hertz(mel) is 700. * (10**(mel / 2595.) - 1.)
        f_pts = (Torch.pow(10, m_pts / 2595.0) - 1.0) * 700.0
        # calculate the difference between each mel point and each stft freq point in hertz
        f_diff = f_pts[1..-1] - f_pts[0...-1] # (n_mels + 1)
        slopes = f_pts.unsqueeze(0) - all_freqs.unsqueeze(1) # (n_freqs, n_mels + 2)
        # create overlapping triangles
        zero = Torch.zeros(1)
        down_slopes = (slopes[0..-1, 0...-2] * -1.0) / f_diff[0...-1] # (n_freqs, n_mels)
        up_slopes = slopes[0..-1, 2..-1] / f_diff[1..-1] # (n_freqs, n_mels)
        fb = Torch.max(zero, Torch.min(down_slopes, up_slopes))

        if norm == "slaney"
          # Slaney-style mel is scaled to be approx constant energy per channel.
          # Ruby ranges replace Python slices here (f_pts[:n_mels] -> 0...n_mels).
          enorm = 2.0 / (f_pts[2...(n_mels + 2)] - f_pts[0...n_mels])
          fb *= enorm.unsqueeze(0)
        end

        fb
      end

      # Computes delta coefficients along the last dimension of +specgram+.
      #
      # win_length: - odd window size >= 3 (ArgumentError otherwise)
      # mode:       - padding mode passed to Torch::NN::Functional.pad
      def compute_deltas(specgram, win_length: 5, mode: "replicate")
        device = specgram.device
        dtype = specgram.dtype

        # pack batch
        shape = specgram.size
        specgram = specgram.reshape(1, -1, shape[-1])

        raise ArgumentError, "win_length must be >= 3" unless win_length >= 3

        n = (win_length - 1).div(2)

        # twice sum of integer squared (n(n+1)(2n+1) is divisible by 6, so this is exact)
        denom = n * (n + 1) * (2 * n + 1) / 3

        specgram = Torch::NN::Functional.pad(specgram, [n, n], mode: mode)

        kernel = Torch.arange(-n, n + 1, 1, device: device, dtype: dtype).repeat([specgram.shape[1], 1, 1])

        output = Torch::NN::Functional.conv1d(specgram, kernel, groups: specgram.shape[1]) / denom

        # unpack batch
        output.reshape(shape)
      end

      # Scales +waveform+ by +gain_db+ decibels (amplitude ratio 10 ** (dB / 20)).
      def gain(waveform, gain_db: 1.0)
        return waveform if gain_db == 0

        # 20.0 (not 20) so integer gains such as 6 are not truncated by
        # integer division
        ratio = 10 ** (gain_db / 20.0)

        waveform * ratio
      end

      # Adds dither noise ("TPDF", "RPDF" or "GPDF") and requantizes to 16 bits.
      # noise_shaping: true is not implemented and raises.
      def dither(waveform, density_function: "TPDF", noise_shaping: false)
        dithered = _apply_probability_distribution(waveform, density_function: density_function)

        # TODO port _add_noise_shaping(dithered, waveform)
        raise "Not implemented yet" if noise_shaping

        dithered
      end

      # Applies a biquad IIR filter: numerator b0..b2, denominator a0..a2.
      def biquad(waveform, b0, b1, b2, a0, a1, a2)
        device = waveform.device
        dtype = waveform.dtype

        lfilter(
          waveform,
          Torch.tensor([a0, a1, a2], dtype: dtype, device: device),
          Torch.tensor([b0, b1, b2], dtype: dtype, device: device)
        )
      end

      # Second-order high-pass filter at +cutoff_freq+ Hz with quality factor +q+.
      def highpass_biquad(waveform, sample_rate, cutoff_freq, q: 0.707)
        w0 = 2 * Math::PI * cutoff_freq / sample_rate
        alpha = Math.sin(w0) / 2.0 / q

        b0 = (1 + Math.cos(w0)) / 2
        b1 = -1 - Math.cos(w0)
        b2 = b0
        a0 = 1 + alpha
        a1 = -2 * Math.cos(w0)
        a2 = 1 - alpha
        biquad(waveform, b0, b1, b2, a0, a1, a2)
      end

      # Second-order low-pass filter at +cutoff_freq+ Hz with quality factor +q+.
      def lowpass_biquad(waveform, sample_rate, cutoff_freq, q: 0.707)
        w0 = 2 * Math::PI * cutoff_freq / sample_rate
        alpha = Math.sin(w0) / 2 / q

        b0 = (1 - Math.cos(w0)) / 2
        b1 = 1 - Math.cos(w0)
        b2 = b0
        a0 = 1 + alpha
        a1 = -2 * Math.cos(w0)
        a2 = 1 - alpha
        biquad(waveform, b0, b1, b2, a0, a1, a2)
      end

      # Applies an IIR filter with feedback coefficients +a_coeffs+ and
      # feed-forward coefficients +b_coeffs+ (equal lengths, same device as
      # +waveform+) along the last dimension. Output is clamped to [-1, 1]
      # when +clamp+ is true.
      def lfilter(waveform, a_coeffs, b_coeffs, clamp: true)
        # pack batch
        shape = waveform.size
        waveform = waveform.reshape(-1, shape[-1])

        raise ArgumentError unless a_coeffs.size(0) == b_coeffs.size(0)
        raise ArgumentError unless waveform.size.length == 2
        raise ArgumentError unless waveform.device == a_coeffs.device
        raise ArgumentError unless b_coeffs.device == a_coeffs.device

        device = waveform.device
        dtype = waveform.dtype
        n_channel, n_sample = waveform.size
        n_order = a_coeffs.size(0)
        n_sample_padded = n_sample + n_order - 1
        raise ArgumentError unless n_order > 0

        # Pad the input and create output
        padded_waveform = Torch.zeros(n_channel, n_sample_padded, dtype: dtype, device: device)
        padded_waveform[0..-1, (n_order - 1)..-1] = waveform
        padded_output_waveform = Torch.zeros(n_channel, n_sample_padded, dtype: dtype, device: device)

        # Set up the coefficients matrix
        # Flip coefficients' order
        a_coeffs_flipped = a_coeffs.flip([0])
        b_coeffs_flipped = b_coeffs.flip([0])

        # calculate windowed_input_signal in parallel
        # create indices of original with shape (n_channel, n_order, n_sample)
        window_idxs = Torch.arange(n_sample, device: device).unsqueeze(0) + Torch.arange(n_order, device: device).unsqueeze(1)
        window_idxs = window_idxs.repeat([n_channel, 1, 1])
        window_idxs += (Torch.arange(n_channel, device: device).unsqueeze(-1).unsqueeze(-1) * n_sample_padded)
        window_idxs = window_idxs.long
        # (n_order, ) matmul (n_channel, n_order, n_sample) -> (n_channel, n_sample)
        input_signal_windows = Torch.matmul(b_coeffs_flipped, Torch.take(padded_waveform, window_idxs))

        input_signal_windows.div!(a_coeffs[0])
        a_coeffs_flipped.div!(a_coeffs[0])
        # The recursive (feedback) part is inherently sequential over samples.
        input_signal_windows.t.each_with_index do |o0, i_sample|
          windowed_output_signal = padded_output_waveform[0..-1, i_sample...(i_sample + n_order)]
          o0.addmv!(windowed_output_signal, a_coeffs_flipped, alpha: -1)
          padded_output_waveform[0..-1, i_sample + n_order - 1] = o0
        end

        output = padded_output_waveform[0..-1, (n_order - 1)..-1]

        output = Torch.clamp(output, -1.0, 1.0) if clamp

        # unpack batch
        output.reshape(shape[0...-1] + output.shape[-1..-1])
      end

      private

      # Returns a random index in [0, upper), or 0 without sampling when
      # +upper+ is not positive (e.g. a single-channel or single-sample input).
      def _random_index(upper)
        upper > 0 ? Torch.randint(upper, [1]).item.to_i : 0
      end

      # Scales +waveform+ to the 16-bit range, adds noise drawn according to
      # +density_function+, rounds, and scales back down.
      def _apply_probability_distribution(waveform, density_function: "TPDF")
        # pack batch
        shape = waveform.size
        waveform = waveform.reshape(-1, shape[-1])

        channel_size = waveform.size[0] - 1
        time_size = waveform.size[-1] - 1

        random_channel = _random_index(channel_size)
        random_time = _random_index(time_size)

        number_of_bits = 16
        up_scaling = 2 ** (number_of_bits - 1) - 2
        signal_scaled = waveform * up_scaling
        down_scaling = 2 ** (number_of_bits - 1)

        signal_scaled_dis =
          case density_function
          when "RPDF"
            rpdf = waveform[random_channel][random_time] - 0.5
            signal_scaled + rpdf
          when "GPDF"
            # TODO Replace by distribution code once
            # https://github.com/pytorch/pytorch/issues/29843 is resolved
            # gaussian = torch.distributions.normal.Normal(torch.mean(waveform, -1), 1).sample()
            num_rand_variables = 6

            gaussian = waveform[random_channel][random_time]
            num_rand_variables.times do
              # guarded like random_channel/random_time so mono input works
              rand_chan = _random_index(channel_size)
              gaussian += waveform[rand_chan][_random_index(time_size)]
            end

            signal_scaled + gaussian
          else
            # dtype needed for https://github.com/pytorch/pytorch/issues/32358
            # TODO add support for dtype and device to bartlett_window
            tpdf = Torch.bartlett_window(time_size + 1).to(signal_scaled.device, dtype: signal_scaled.dtype)
            tpdf = tpdf.repeat([channel_size + 1, 1])
            signal_scaled + tpdf
          end

        quantised_signal_scaled = Torch.round(signal_scaled_dis)
        quantised_signal = quantised_signal_scaled / down_scaling

        # unpack batch
        quantised_signal.reshape(shape[0...-1] + quantised_signal.shape[-1..-1])
      end
    end
  end

  F = Functional
end
@@ -0,0 +1,39 @@
1
module TorchAudio
  module Transforms
    # Projects a linear-frequency spectrogram onto the mel scale using the
    # filter bank from F.create_fb_matrix.
    class MelScale < Torch::NN::Module
      # n_mels:      - number of mel bins (default 128)
      # sample_rate: - sample rate of the source audio (default 16000)
      # f_min:       - lowest frequency (default 0.0)
      # f_max:       - highest frequency (default sample_rate / 2, floored, as Float)
      # n_stft:      - number of STFT frequency bins; when nil the filter bank
      #                is built lazily on the first forward call
      def initialize(n_mels: 128, sample_rate: 16000, f_min: 0.0, f_max: nil, n_stft: nil)
        super()
        @n_mels = n_mels
        @sample_rate = sample_rate
        @f_max = f_max || sample_rate.div(2).to_f
        @f_min = f_min

        unless f_min <= @f_max
          raise ArgumentError, "Require f_min: %f < f_max: %f" % [f_min, @f_max]
        end

        filter_bank =
          if n_stft.nil?
            Torch.empty(0)
          else
            F.create_fb_matrix(n_stft, @f_min, @f_max, @n_mels, @sample_rate)
          end
        register_buffer("fb", filter_bank)
      end

      # specgram - Tensor (..., freq, time); returns (..., n_mels, time)
      def forward(specgram)
        original_shape = specgram.size
        batched = specgram.reshape(-1, original_shape[-2], original_shape[-1])

        if @fb.numel == 0
          built = F.create_fb_matrix(batched.size(1), @f_min, @f_max, @n_mels, @sample_rate)
          # Attributes cannot be reassigned outside __init__ so workaround:
          # resize and fill the registered buffer in place
          @fb.resize!(built.size)
          @fb.copy!(built)
        end

        # (channel, frequency, time).transpose(...) dot (frequency, n_mels)
        # -> (channel, time, n_mels).transpose(...)
        mel = Torch.matmul(batched.transpose(1, 2), @fb).transpose(1, 2)

        # unpack batch
        mel.reshape(original_shape[0...-2] + mel.shape[-2..-1])
      end
    end
  end
end
@@ -0,0 +1,35 @@
1
module TorchAudio
  module Transforms
    # Waveform -> mel spectrogram: a Spectrogram followed by a MelScale.
    class MelSpectrogram < Torch::NN::Module
      # sample_rate:, f_min:, f_max:, n_mels: are forwarded to MelScale;
      # n_fft:, win_length:, hop_length:, pad:, window_fn:, power:,
      # normalized:, wkwargs: are forwarded to Spectrogram.
      # win_length defaults to n_fft and hop_length to win_length / 2 (floored).
      def initialize(
        sample_rate: 16000, n_fft: 400, win_length: nil, hop_length: nil, f_min: 0.0,
        f_max: nil, pad: 0, n_mels: 128, window_fn: Torch.method(:hann_window),
        power: 2.0, normalized: false, wkwargs: nil
      )

        super()
        @sample_rate = sample_rate
        @n_fft = n_fft
        @win_length = win_length || n_fft
        @hop_length = hop_length || @win_length.div(2)
        @pad = pad
        @power = power
        @normalized = normalized
        @n_mels = n_mels # number of mel frequency bins
        @f_max = f_max
        @f_min = f_min

        spectrogram_options = {
          n_fft: @n_fft, win_length: @win_length, hop_length: @hop_length, pad: @pad,
          window_fn: window_fn, power: @power, normalized: @normalized, wkwargs: wkwargs
        }
        @spectrogram = Spectrogram.new(**spectrogram_options)

        # a one-sided STFT yields n_fft / 2 + 1 frequency bins
        stft_bins = @n_fft.div(2) + 1
        @mel_scale = MelScale.new(n_mels: @n_mels, sample_rate: @sample_rate, f_min: @f_min, f_max: @f_max, n_stft: stft_bins)
      end

      def forward(waveform)
        @mel_scale.call(@spectrogram.call(waveform))
      end
    end
  end
end
@@ -0,0 +1,14 @@
1
module TorchAudio
  module Transforms
    # Module wrapper around F.mu_law_decoding.
    class MuLawDecoding < Torch::NN::Module
      # quantization_channels: - number of mu-law codes (default 256)
      def initialize(quantization_channels: 256)
        super()
        @quantization_channels = quantization_channels
      end

      # Decodes mu-law codes +x_mu+ back into a floating-point signal.
      def forward(x_mu)
        channels = @quantization_channels
        F.mu_law_decoding(x_mu, channels)
      end
    end
  end
end
@@ -0,0 +1,14 @@
1
module TorchAudio
  module Transforms
    # Module wrapper around F.mu_law_encoding.
    class MuLawEncoding < Torch::NN::Module
      # quantization_channels: - number of mu-law codes (default 256)
      def initialize(quantization_channels: 256)
        super()
        @quantization_channels = quantization_channels
      end

      # Encodes signal +x+ into int64 mu-law codes.
      def forward(x)
        channels = @quantization_channels
        F.mu_law_encoding(x, channels)
      end
    end
  end
end
@@ -0,0 +1,27 @@
1
module TorchAudio
  module Transforms
    # Waveform -> spectrogram module wrapping F.spectrogram.
    class Spectrogram < Torch::NN::Module
      # n_fft:      - FFT size; the returned STFT has n_fft / 2 + 1 frequencies
      #               because torch.stft is called with onesided: true
      # win_length: - window size (defaults to n_fft)
      # hop_length: - hop between frames (defaults to win_length / 2, floored)
      # pad:        - samples of padding added to both ends before the STFT
      # window_fn:  - callable building the window from win_length (and wkwargs)
      # power:      - exponent forwarded to F.spectrogram
      # normalized: - forwarded to F.spectrogram
      # wkwargs:    - optional keyword hash passed through to window_fn
      def initialize(
        n_fft: 400, win_length: nil, hop_length: nil, pad: 0,
        window_fn: Torch.method(:hann_window), power: 2.0, normalized: false, wkwargs: nil
      )

        super()
        @n_fft = n_fft
        @win_length = win_length || n_fft
        @hop_length = hop_length || @win_length.div(2) # floor division
        window =
          if wkwargs.nil?
            window_fn.call(@win_length)
          else
            window_fn.call(@win_length, **wkwargs)
          end
        register_buffer("window", window)
        @pad = pad
        @power = power
        @normalized = normalized
      end

      def forward(waveform)
        F.spectrogram(
          waveform, @pad, @window, @n_fft, @hop_length, @win_length, @power, @normalized
        )
      end
    end
  end
end
@@ -1,3 +1,3 @@
1
1
  module TorchAudio
2
- VERSION = "0.1.0"
2
+ VERSION = "0.1.1"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torchaudio
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.1.0
4
+ version: 0.1.1
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2020-08-25 00:00:00.000000000 Z
11
+ date: 2020-08-26 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: torch-rb
@@ -16,14 +16,14 @@ dependencies:
16
16
  requirements:
17
17
  - - ">="
18
18
  - !ruby/object:Gem::Version
19
- version: 0.3.2
19
+ version: 0.3.4
20
20
  type: :runtime
21
21
  prerelease: false
22
22
  version_requirements: !ruby/object:Gem::Requirement
23
23
  requirements:
24
24
  - - ">="
25
25
  - !ruby/object:Gem::Version
26
- version: 0.3.2
26
+ version: 0.3.4
27
27
  - !ruby/object:Gem::Dependency
28
28
  name: rice
29
29
  requirement: !ruby/object:Gem::Requirement
@@ -118,6 +118,12 @@ files:
118
118
  - lib/torchaudio.rb
119
119
  - lib/torchaudio/datasets/utils.rb
120
120
  - lib/torchaudio/datasets/yesno.rb
121
+ - lib/torchaudio/functional.rb
122
+ - lib/torchaudio/transforms/mel_scale.rb
123
+ - lib/torchaudio/transforms/mel_spectrogram.rb
124
+ - lib/torchaudio/transforms/mu_law_decoding.rb
125
+ - lib/torchaudio/transforms/mu_law_encoding.rb
126
+ - lib/torchaudio/transforms/spectrogram.rb
121
127
  - lib/torchaudio/version.rb
122
128
  homepage: https://github.com/ankane/torchaudio
123
129
  licenses: