torchaudio 0.1.0 → 0.1.1

checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
- metadata.gz: b527976494325cc12e81342c25d318204d2d7c75bfba7036be4296769cdb30a0
- data.tar.gz: 2cfde7bd1b0e7a1628818d5bd74657cfbfba6dfa83ef42897f3ad0f98e77f739
+ metadata.gz: 354a8636b2a04fd0ed60db93e2667f630f1db04ab047a6eb6e0c3607190c58d5
+ data.tar.gz: cde6229f8ba416c5a72bbea6d47e26fd987b6f13073f1ca9a1e6239cd62b79c1
  SHA512:
- metadata.gz: 8e6f34b014340b5ace3193ab589dae75ed0869ab7606402bd4b09de6042299e6f3a118d439dd381491f489ce9552bca4376a7d5b4693dddc3d1c5f5b26540900
- data.tar.gz: d651c46f5185ceb70ae3d9c90154c77afe29a5c35854d1a9d98913096b7ab9ba39a745242dd268548ca87f9e109b56c96dee9dc5539cf066f9ad0f773eddbdcd
+ metadata.gz: 76c9c637e4f12700c16bbf343104bb46f5e4f04161a3556cd0600ea236222902ac0ba3b9504f4bdb28a9919317a1b791c8f7b4dbcb48b35ed4c674de39a83277
+ data.tar.gz: '019d89cd2231c386c3458447358058618a6de78b77ed5b4d6d279ee0e3879946bda03cf339051136069dd552d6eeaed33bf40defa9d6c90c9a07eb27d68fdce2'
data/CHANGELOG.md CHANGED
@@ -1,3 +1,8 @@
+ ## 0.1.1 (2020-08-26)
+
+ - Added `save` method
+ - Added transforms
+
  ## 0.1.0 (2020-08-24)

  - First release
data/README.md CHANGED
@@ -2,6 +2,8 @@

  :fire: An audio library for Torch.rb

+ [![Build Status](https://travis-ci.org/ankane/torchaudio.svg?branch=master)](https://travis-ci.org/ankane/torchaudio)
+
  ## Installation

  First, [install SoX](#sox-installation). For Homebrew, use:
@@ -20,6 +22,40 @@ gem 'torchaudio'

  This library follows the [Python API](https://pytorch.org/audio/). Many methods and options are missing at the moment. PRs welcome!

+ ## Tutorial
+
+ - [PyTorch tutorial](https://pytorch.org/tutorials/beginner/audio_preprocessing_tutorial.html)
+ - [Ruby code](examples/tutorial.rb)
+
+ Download the [audio file](https://github.com/pytorch/tutorials/raw/master/_static/img/steam-train-whistle-daniel_simon-converted-from-mp3.wav) and install the [matplotlib](https://github.com/mrkn/matplotlib.rb) gem first.
+
+ ## Basics
+
+ Load a file
+
+ ```ruby
+ waveform, sample_rate = TorchAudio.load("file.wav")
+ ```
+
+ Save a file
+
+ ```ruby
+ TorchAudio.save("new.wav", waveform, sample_rate)
+ ```
+
+ ## Transforms
+
+ ```ruby
+ TorchAudio::Transforms::Spectrogram.new.call(waveform)
+ ```
+
+ Supported transforms are:
+
+ - MelSpectrogram
+ - MuLawDecoding
+ - MuLawEncoding
+ - Spectrogram
+
  ## Datasets

  Load a dataset
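
Taken together, the new README sections describe a load → transform → save round trip. A minimal sketch (assuming `file.wav` is any readable WAV file on disk):

```ruby
require "torchaudio"

# load returns the samples as a Torch tensor plus the sample rate
waveform, sample_rate = TorchAudio.load("file.wav")

# compute a spectrogram with the new transforms API
spec = TorchAudio::Transforms::Spectrogram.new.call(waveform)

# write the (unmodified) samples back out with the new save method
TorchAudio.save("new.wav", waveform, sample_rate)
```
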
data/ext/torchaudio/ext.cpp CHANGED
@@ -1,33 +1,75 @@
  #include <torchaudio/csrc/sox.h>

+ #include <rice/Constructor.hpp>
  #include <rice/Module.hpp>

  using namespace Rice;

+ class SignalInfo {
+   sox_signalinfo_t* value = nullptr;
+
+  public:
+   SignalInfo(Object o) {
+     if (!o.is_nil()) {
+       value = from_ruby<sox_signalinfo_t*>(o);
+     }
+   }
+
+   operator sox_signalinfo_t*() {
+     return value;
+   }
+ };
+
  template<>
  inline
- sox_signalinfo_t* from_ruby<sox_signalinfo_t*>(Object x)
+ SignalInfo from_ruby<SignalInfo>(Object x)
  {
-   if (x.is_nil()) {
-     return nullptr;
-   }
-   throw std::runtime_error("Unsupported signalinfo");
+   return SignalInfo(x);
  }

+ class EncodingInfo {
+   sox_encodinginfo_t* value = nullptr;
+
+  public:
+   EncodingInfo(Object o) {
+     if (!o.is_nil()) {
+       value = from_ruby<sox_encodinginfo_t*>(o);
+     }
+   }
+
+   operator sox_encodinginfo_t*() {
+     return value;
+   }
+ };
+
  template<>
  inline
- sox_encodinginfo_t* from_ruby<sox_encodinginfo_t*>(Object x)
+ EncodingInfo from_ruby<EncodingInfo>(Object x)
  {
-   if (x.is_nil()) {
-     return nullptr;
-   }
-   throw std::runtime_error("Unsupported encodinginfo");
+   return EncodingInfo(x);
  }

  extern "C"
  void Init_ext()
  {
    Module rb_mTorchAudio = define_module("TorchAudio");
-   Module rb_mNN = define_module_under(rb_mTorchAudio, "Ext")
-     .define_singleton_method("read_audio_file", &torch::audio::read_audio_file);
+
+   Module rb_mExt = define_module_under(rb_mTorchAudio, "Ext")
+     .define_singleton_method(
+       "read_audio_file",
+       *[](const std::string& file_name, at::Tensor output, bool ch_first, int64_t nframes, int64_t offset, SignalInfo si, EncodingInfo ei, const char* ft) {
+         return torch::audio::read_audio_file(file_name, output, ch_first, nframes, offset, si, ei, ft);
+       })
+     .define_singleton_method(
+       "write_audio_file",
+       *[](const std::string& file_name, const at::Tensor& tensor, SignalInfo si, EncodingInfo ei, const char* file_type) {
+         return torch::audio::write_audio_file(file_name, tensor, si, ei, file_type);
+       });
+
+   Class rb_cSignalInfo = define_class_under<sox_signalinfo_t>(rb_mExt, "SignalInfo")
+     .define_constructor(Constructor<sox_signalinfo_t>())
+     .define_method("rate", *[](sox_signalinfo_t& self) { return self.rate; })
+     .define_method("channels", *[](sox_signalinfo_t& self) { return self.channels; })
+     .define_method("precision", *[](sox_signalinfo_t& self) { return self.precision; })
+     .define_method("length", *[](sox_signalinfo_t& self) { return self.length; })
+     .define_method("rate=", *[](sox_signalinfo_t& self, sox_rate_t rate) { self.rate = rate; })
+     .define_method("channels=", *[](sox_signalinfo_t& self, unsigned channels) { self.channels = channels; })
+     .define_method("precision=", *[](sox_signalinfo_t& self, unsigned precision) { self.precision = precision; })
+     .define_method("length=", *[](sox_signalinfo_t& self, sox_uint64_t length) { self.length = length; });
  }
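
The `SignalInfo`/`EncodingInfo` wrappers exist so Ruby callers can pass `nil` where the C++ signatures expect a nullable `sox_signalinfo_t*`/`sox_encodinginfo_t*`. From Ruby, the newly exposed `Ext::SignalInfo` behaves like a plain struct; a sketch of this internal API (values are placeholders), mirroring how `save` fills one in below:

```ruby
si = TorchAudio::Ext::SignalInfo.new
si.rate = 44100.0   # sox stores the sample rate as a float
si.channels = 1
si.precision = 16   # bits per sample
si.length = 44100   # total number of samples
```
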
data/lib/torchaudio.rb CHANGED
@@ -14,6 +14,12 @@ require "set"
  # modules
  require "torchaudio/datasets/utils"
  require "torchaudio/datasets/yesno"
+ require "torchaudio/functional"
+ require "torchaudio/transforms/mel_scale"
+ require "torchaudio/transforms/mel_spectrogram"
+ require "torchaudio/transforms/mu_law_encoding"
+ require "torchaudio/transforms/mu_law_decoding"
+ require "torchaudio/transforms/spectrogram"
  require "torchaudio/version"

  module TorchAudio
@@ -73,6 +79,70 @@ module TorchAudio
      load(filepath, **kwargs)
    end

+   def save(filepath, src, sample_rate, precision: 16, channels_first: true)
+     si = Ext::SignalInfo.new
+     ch_idx = channels_first ? 0 : 1
+     si.rate = sample_rate
+     si.channels = src.dim == 1 ? 1 : src.size(ch_idx)
+     si.length = src.numel
+     si.precision = precision
+     save_encinfo(filepath, src, channels_first: channels_first, signalinfo: si)
+   end
+
+   def save_encinfo(filepath, src, channels_first: true, signalinfo: nil, encodinginfo: nil, filetype: nil)
+     ch_idx, len_idx = channels_first ? [0, 1] : [1, 0]
+
+     # check if save directory exists
+     abs_dirpath = File.dirname(File.expand_path(filepath))
+     unless Dir.exist?(abs_dirpath)
+       raise "Directory does not exist: #{abs_dirpath}"
+     end
+     # check that src is a CPU tensor
+     check_input(src)
+     # check/fix shape of source data
+     if src.dim == 1
+       # 1d tensors are assumed to be mono signals
+       src.unsqueeze!(ch_idx)
+     elsif src.dim > 2 || src.size(ch_idx) > 16
+       # assumes num_channels < 16
+       raise ArgumentError, "Expected format where C < 16, but found #{src.size}"
+     end
+     # sox stores the sample rate as a float, though practically sample rates are almost always integers
+     # convert integers to floats
+     if signalinfo
+       if signalinfo.rate && !signalinfo.rate.is_a?(Float)
+         if signalinfo.rate.to_f == signalinfo.rate
+           signalinfo.rate = signalinfo.rate.to_f
+         else
+           raise ArgumentError, "Sample rate should be a float or int"
+         end
+       end
+       # check if the bit precision (i.e. bits per sample) is an integer
+       if signalinfo.precision && !signalinfo.precision.is_a?(Integer)
+         if signalinfo.precision.to_i == signalinfo.precision
+           signalinfo.precision = signalinfo.precision.to_i
+         else
+           raise ArgumentError, "Bit precision should be an integer"
+         end
+       end
+     end
+     # programs such as librosa normalize the signal, unnormalize if detected
+     if src.min >= -1.0 && src.max <= 1.0
+       src = src * (1 << 31)
+       src = src.long
+     end
+     # set filetype and allow for files with no extensions
+     extension = File.extname(filepath)
+     filetype = extension.length > 0 ? extension[1..-1] : filetype
+     # transpose from C x L -> L x C
+     if channels_first
+       src = src.transpose(1, 0)
+     end
+     # save data to file
+     src = src.contiguous
+     Ext.write_audio_file(filepath, src, signalinfo, encodinginfo, filetype)
+   end
+
    private

    def check_input(src)
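
A usage sketch of the new `save` (filenames are placeholders): it fills in an `Ext::SignalInfo` from the tensor's shape and delegates to `save_encinfo`, which validates the input, un-normalizes librosa-style [-1, 1] signals, and infers the file type from the extension:

```ruby
waveform, sample_rate = TorchAudio.load("input.wav")

# precision is bits per sample; channels_first matches the load default
TorchAudio.save("output.wav", waveform, sample_rate, precision: 16, channels_first: true)
```
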
data/lib/torchaudio/functional.rb ADDED
@@ -0,0 +1,285 @@
+ module TorchAudio
+   module Functional
+     class << self
+       def spectrogram(waveform, pad, window, n_fft, hop_length, win_length, power, normalized)
+         if pad > 0
+           # TODO add "with torch.no_grad():" back when JIT supports it
+           waveform = Torch::NN::Functional.pad(waveform, [pad, pad], "constant")
+         end
+
+         # pack batch
+         shape = waveform.size
+         waveform = waveform.reshape(-1, shape[-1])
+
+         # default values are consistent with librosa.core.spectrum._spectrogram
+         spec_f = Torch.stft(
+           waveform, n_fft, hop_length: hop_length, win_length: win_length, window: window,
+           center: true, pad_mode: "reflect", normalized: false, onesided: true
+         )
+
+         # unpack batch
+         spec_f = spec_f.reshape(shape[0..-2] + spec_f.shape[-3..-1])
+
+         if normalized
+           spec_f.div!(window.pow(2.0).sum.sqrt)
+         end
+         if power
+           spec_f = complex_norm(spec_f, power: power)
+         end
+
+         spec_f
+       end
+
+       def mu_law_encoding(x, quantization_channels)
+         mu = quantization_channels - 1.0
+         if !x.floating_point?
+           x = x.to(dtype: :float)
+         end
+         mu = Torch.tensor(mu, dtype: x.dtype)
+         x_mu = Torch.sign(x) * Torch.log1p(mu * Torch.abs(x)) / Torch.log1p(mu)
+         x_mu = ((x_mu + 1) / 2 * mu + 0.5).to(dtype: :int64)
+         x_mu
+       end
+
+       def mu_law_decoding(x_mu, quantization_channels)
+         mu = quantization_channels - 1.0
+         if !x_mu.floating_point?
+           x_mu = x_mu.to(dtype: :float)
+         end
+         mu = Torch.tensor(mu, dtype: x_mu.dtype)
+         x = (x_mu / mu) * 2 - 1.0
+         x = Torch.sign(x) * (Torch.exp(Torch.abs(x) * Torch.log1p(mu)) - 1.0) / mu
+         x
+       end
+
+       def complex_norm(complex_tensor, power: 1.0)
+         complex_tensor.pow(2.0).sum(-1).pow(0.5 * power)
+       end
+
+       def create_fb_matrix(n_freqs, f_min, f_max, n_mels, sample_rate, norm: nil)
+         if norm && norm != "slaney"
+           raise ArgumentError, "norm must be one of None or 'slaney'"
+         end
+
+         # freq bins
+         # equivalent filterbank construction to librosa
+         all_freqs = Torch.linspace(0, sample_rate.div(2), n_freqs)
+
+         # calculate mel freq bins
+         # hertz to mel(f) is 2595. * math.log10(1. + (f / 700.))
+         m_min = 2595.0 * Math.log10(1.0 + (f_min / 700.0))
+         m_max = 2595.0 * Math.log10(1.0 + (f_max / 700.0))
+         m_pts = Torch.linspace(m_min, m_max, n_mels + 2)
+         # mel to hertz(mel) is 700. * (10**(mel / 2595.) - 1.)
+         f_pts = (Torch.pow(10, m_pts / 2595.0) - 1.0) * 700.0
+         # calculate the difference between each mel point and each stft freq point in hertz
+         f_diff = f_pts[1..-1] - f_pts[0...-1] # (n_mels + 1)
+         slopes = f_pts.unsqueeze(0) - all_freqs.unsqueeze(1) # (n_freqs, n_mels + 2)
+         # create overlapping triangles
+         zero = Torch.zeros(1)
+         down_slopes = (slopes[0..-1, 0...-2] * -1.0) / f_diff[0...-1] # (n_freqs, n_mels)
+         up_slopes = slopes[0..-1, 2..-1] / f_diff[1..-1] # (n_freqs, n_mels)
+         fb = Torch.max(zero, Torch.min(down_slopes, up_slopes))
+
+         if norm && norm == "slaney"
+           # Slaney-style mel is scaled to be approx constant energy per channel
+           enorm = 2.0 / (f_pts[2...(n_mels + 2)] - f_pts[0...n_mels])
+           fb *= enorm.unsqueeze(0)
+         end
+
+         fb
+       end
+
+       def compute_deltas(specgram, win_length: 5, mode: "replicate")
+         device = specgram.device
+         dtype = specgram.dtype
+
+         # pack batch
+         shape = specgram.size
+         specgram = specgram.reshape(1, -1, shape[-1])
+
+         raise ArgumentError, "win_length must be >= 3" unless win_length >= 3
+
+         n = (win_length - 1).div(2)
+
+         # twice the sum of the squared integers 1..n
+         denom = n * (n + 1) * (2 * n + 1) / 3
+
+         specgram = Torch::NN::Functional.pad(specgram, [n, n], mode: mode)
+
+         kernel = Torch.arange(-n, n + 1, 1, device: device, dtype: dtype).repeat([specgram.shape[1], 1, 1])
+
+         output = Torch::NN::Functional.conv1d(specgram, kernel, groups: specgram.shape[1]) / denom
+
+         # unpack batch
+         output.reshape(shape)
+       end
+
+       def gain(waveform, gain_db: 1.0)
+         return waveform if gain_db == 0
+
+         ratio = 10 ** (gain_db / 20.0)
+
+         waveform * ratio
+       end
+
+       def dither(waveform, density_function: "TPDF", noise_shaping: false)
+         dithered = _apply_probability_distribution(waveform, density_function: density_function)
+
+         if noise_shaping
+           raise "Not implemented yet"
+           # _add_noise_shaping(dithered, waveform)
+         else
+           dithered
+         end
+       end
+
+       def biquad(waveform, b0, b1, b2, a0, a1, a2)
+         device = waveform.device
+         dtype = waveform.dtype
+
+         output_waveform = lfilter(
+           waveform,
+           Torch.tensor([a0, a1, a2], dtype: dtype, device: device),
+           Torch.tensor([b0, b1, b2], dtype: dtype, device: device)
+         )
+         output_waveform
+       end
+
+       def highpass_biquad(waveform, sample_rate, cutoff_freq, q: 0.707)
+         w0 = 2 * Math::PI * cutoff_freq / sample_rate
+         alpha = Math.sin(w0) / 2.0 / q
+
+         b0 = (1 + Math.cos(w0)) / 2
+         b1 = -1 - Math.cos(w0)
+         b2 = b0
+         a0 = 1 + alpha
+         a1 = -2 * Math.cos(w0)
+         a2 = 1 - alpha
+         biquad(waveform, b0, b1, b2, a0, a1, a2)
+       end
+
+       def lowpass_biquad(waveform, sample_rate, cutoff_freq, q: 0.707)
+         w0 = 2 * Math::PI * cutoff_freq / sample_rate
+         alpha = Math.sin(w0) / 2 / q
+
+         b0 = (1 - Math.cos(w0)) / 2
+         b1 = 1 - Math.cos(w0)
+         b2 = b0
+         a0 = 1 + alpha
+         a1 = -2 * Math.cos(w0)
+         a2 = 1 - alpha
+         biquad(waveform, b0, b1, b2, a0, a1, a2)
+       end
+
+       def lfilter(waveform, a_coeffs, b_coeffs, clamp: true)
+         # pack batch
+         shape = waveform.size
+         waveform = waveform.reshape(-1, shape[-1])
+
+         raise ArgumentError unless a_coeffs.size(0) == b_coeffs.size(0)
+         raise ArgumentError unless waveform.size.length == 2
+         raise ArgumentError unless waveform.device == a_coeffs.device
+         raise ArgumentError unless b_coeffs.device == a_coeffs.device
+
+         device = waveform.device
+         dtype = waveform.dtype
+         n_channel, n_sample = waveform.size
+         n_order = a_coeffs.size(0)
+         n_sample_padded = n_sample + n_order - 1
+         raise ArgumentError unless n_order > 0
+
+         # pad the input and create output
+         padded_waveform = Torch.zeros(n_channel, n_sample_padded, dtype: dtype, device: device)
+         padded_waveform[0..-1, (n_order - 1)..-1] = waveform
+         padded_output_waveform = Torch.zeros(n_channel, n_sample_padded, dtype: dtype, device: device)
+
+         # set up the coefficients matrix
+         # flip coefficients' order
+         a_coeffs_flipped = a_coeffs.flip([0])
+         b_coeffs_flipped = b_coeffs.flip([0])
+
+         # calculate windowed_input_signal in parallel
+         # create indices of original with shape (n_channel, n_order, n_sample)
+         window_idxs = Torch.arange(n_sample, device: device).unsqueeze(0) + Torch.arange(n_order, device: device).unsqueeze(1)
+         window_idxs = window_idxs.repeat([n_channel, 1, 1])
+         window_idxs += (Torch.arange(n_channel, device: device).unsqueeze(-1).unsqueeze(-1) * n_sample_padded)
+         window_idxs = window_idxs.long
+         # (n_order, ) matmul (n_channel, n_order, n_sample) -> (n_channel, n_sample)
+         input_signal_windows = Torch.matmul(b_coeffs_flipped, Torch.take(padded_waveform, window_idxs))
+
+         input_signal_windows.div!(a_coeffs[0])
+         a_coeffs_flipped.div!(a_coeffs[0])
+         input_signal_windows.t.each_with_index do |o0, i_sample|
+           windowed_output_signal = padded_output_waveform[0..-1, i_sample...(i_sample + n_order)]
+           o0.addmv!(windowed_output_signal, a_coeffs_flipped, alpha: -1)
+           padded_output_waveform[0..-1, i_sample + n_order - 1] = o0
+         end
+
+         output = padded_output_waveform[0..-1, (n_order - 1)..-1]
+
+         if clamp
+           output = Torch.clamp(output, -1.0, 1.0)
+         end
+
+         # unpack batch
+         output = output.reshape(shape[0...-1] + output.shape[-1..-1])
+
+         output
+       end
+
+       private
+
+       def _apply_probability_distribution(waveform, density_function: "TPDF")
+         # pack batch
+         shape = waveform.size
+         waveform = waveform.reshape(-1, shape[-1])
+
+         channel_size = waveform.size[0] - 1
+         time_size = waveform.size[-1] - 1
+
+         random_channel = channel_size > 0 ? Torch.randint(channel_size, [1]).item.to_i : 0
+         random_time = time_size > 0 ? Torch.randint(time_size, [1]).item.to_i : 0
+
+         number_of_bits = 16
+         up_scaling = 2 ** (number_of_bits - 1) - 2
+         signal_scaled = waveform * up_scaling
+         down_scaling = 2 ** (number_of_bits - 1)
+
+         signal_scaled_dis = waveform
+         if density_function == "RPDF"
+           rpdf = waveform[random_channel][random_time] - 0.5
+
+           signal_scaled_dis = signal_scaled + rpdf
+         elsif density_function == "GPDF"
+           # TODO Replace by distribution code once
+           # https://github.com/pytorch/pytorch/issues/29843 is resolved
+           # gaussian = torch.distributions.normal.Normal(torch.mean(waveform, -1), 1).sample()
+
+           num_rand_variables = 6
+
+           gaussian = waveform[random_channel][random_time]
+           (num_rand_variables * [time_size]).each do |ws|
+             rand_chan = Torch.randint(channel_size, [1]).item.to_i
+             gaussian += waveform[rand_chan][Torch.randint(ws, [1]).item.to_i]
+           end
+
+           signal_scaled_dis = signal_scaled + gaussian
+         else
+           # dtype needed for https://github.com/pytorch/pytorch/issues/32358
+           # TODO add support for dtype and device to bartlett_window
+           tpdf = Torch.bartlett_window(time_size + 1).to(signal_scaled.device, dtype: signal_scaled.dtype)
+           tpdf = tpdf.repeat([channel_size + 1, 1])
+           signal_scaled_dis = signal_scaled + tpdf
+         end
+
+         quantised_signal_scaled = Torch.round(signal_scaled_dis)
+         quantised_signal = quantised_signal_scaled / down_scaling
+
+         # unpack batch
+         quantised_signal.reshape(shape[0...-1] + quantised_signal.shape[-1..-1])
+       end
+     end
+   end
+
+   F = Functional
+ end
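
A sketch of the functional API (using the `F` alias defined above). Mu-law companding round-trips a [-1, 1] signal through 256 discrete levels, and the biquad helpers build standard second-order filters; the inputs here are toy tensors:

```ruby
x = Torch.linspace(-1.0, 1.0, 5)

x_mu = TorchAudio::F.mu_law_encoding(x, 256)     # int64 codes in [0, 255]
x_hat = TorchAudio::F.mu_law_decoding(x_mu, 256) # floats close to x

# 4 kHz low-pass on a 16 kHz signal (q defaults to 0.707)
filtered = TorchAudio::F.lowpass_biquad(x.unsqueeze(0), 16000, 4000)
```
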
data/lib/torchaudio/transforms/mel_scale.rb ADDED
@@ -0,0 +1,39 @@
+ module TorchAudio
+   module Transforms
+     class MelScale < Torch::NN::Module
+       def initialize(n_mels: 128, sample_rate: 16000, f_min: 0.0, f_max: nil, n_stft: nil)
+         super()
+         @n_mels = n_mels
+         @sample_rate = sample_rate
+         @f_max = f_max || sample_rate.div(2).to_f
+         @f_min = f_min
+
+         raise ArgumentError, "Require f_min: %f < f_max: %f" % [f_min, @f_max] unless f_min <= @f_max
+
+         fb = n_stft.nil? ? Torch.empty(0) : F.create_fb_matrix(n_stft, @f_min, @f_max, @n_mels, @sample_rate)
+         register_buffer("fb", fb)
+       end
+
+       def forward(specgram)
+         shape = specgram.size
+         specgram = specgram.reshape(-1, shape[-2], shape[-1])
+
+         if @fb.numel == 0
+           tmp_fb = F.create_fb_matrix(specgram.size(1), @f_min, @f_max, @n_mels, @sample_rate)
+           # Attributes cannot be reassigned outside __init__ so workaround
+           @fb.resize!(tmp_fb.size)
+           @fb.copy!(tmp_fb)
+         end
+
+         # (channel, frequency, time).transpose(...) dot (frequency, n_mels)
+         # -> (channel, time, n_mels).transpose(...)
+         mel_specgram = Torch.matmul(specgram.transpose(1, 2), @fb).transpose(1, 2)
+
+         # unpack batch
+         mel_specgram = mel_specgram.reshape(shape[0...-2] + mel_specgram.shape[-2..-1])
+
+         mel_specgram
+       end
+     end
+   end
+ end
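
`MelScale` maps an existing spectrogram's linear frequency bins onto mel bins; when `n_stft` is not given, the filterbank is created lazily from the first input's shape. A sketch (reusing the `waveform` assumed in the earlier examples):

```ruby
spec = TorchAudio::Transforms::Spectrogram.new(n_fft: 400).call(waveform)
# spec has n_fft / 2 + 1 = 201 frequency bins
mel = TorchAudio::Transforms::MelScale.new(n_mels: 128, sample_rate: 16000).call(spec)
```
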
data/lib/torchaudio/transforms/mel_spectrogram.rb ADDED
@@ -0,0 +1,35 @@
+ module TorchAudio
+   module Transforms
+     class MelSpectrogram < Torch::NN::Module
+       def initialize(
+         sample_rate: 16000, n_fft: 400, win_length: nil, hop_length: nil, f_min: 0.0,
+         f_max: nil, pad: 0, n_mels: 128, window_fn: Torch.method(:hann_window),
+         power: 2.0, normalized: false, wkwargs: nil
+       )
+         super()
+         @sample_rate = sample_rate
+         @n_fft = n_fft
+         @win_length = win_length || n_fft
+         @hop_length = hop_length || @win_length.div(2)
+         @pad = pad
+         @power = power
+         @normalized = normalized
+         @n_mels = n_mels # number of mel frequency bins
+         @f_max = f_max
+         @f_min = f_min
+         @spectrogram =
+           Spectrogram.new(
+             n_fft: @n_fft, win_length: @win_length, hop_length: @hop_length, pad: @pad,
+             window_fn: window_fn, power: @power, normalized: @normalized, wkwargs: wkwargs
+           )
+         @mel_scale = MelScale.new(n_mels: @n_mels, sample_rate: @sample_rate, f_min: @f_min, f_max: @f_max, n_stft: @n_fft.div(2) + 1)
+       end
+
+       def forward(waveform)
+         specgram = @spectrogram.call(waveform)
+         @mel_scale.call(specgram)
+       end
+     end
+   end
+ end
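
`MelSpectrogram` simply composes the two modules above. A sketch (same assumed `waveform` and `sample_rate` as earlier):

```ruby
transform = TorchAudio::Transforms::MelSpectrogram.new(sample_rate: sample_rate, n_fft: 400, n_mels: 64)
mel_specgram = transform.call(waveform) # (channel, n_mels, time)
```
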
data/lib/torchaudio/transforms/mu_law_decoding.rb ADDED
@@ -0,0 +1,14 @@
+ module TorchAudio
+   module Transforms
+     class MuLawDecoding < Torch::NN::Module
+       def initialize(quantization_channels: 256)
+         super()
+         @quantization_channels = quantization_channels
+       end
+
+       def forward(x_mu)
+         F.mu_law_decoding(x_mu, @quantization_channels)
+       end
+     end
+   end
+ end
data/lib/torchaudio/transforms/mu_law_encoding.rb ADDED
@@ -0,0 +1,14 @@
+ module TorchAudio
+   module Transforms
+     class MuLawEncoding < Torch::NN::Module
+       def initialize(quantization_channels: 256)
+         super()
+         @quantization_channels = quantization_channels
+       end
+
+       def forward(x)
+         F.mu_law_encoding(x, @quantization_channels)
+       end
+     end
+   end
+ end
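
The two mu-law modules are thin stateful wrappers over `F.mu_law_encoding`/`F.mu_law_decoding`; pairing them round-trips a normalized waveform (same assumed `waveform` as earlier):

```ruby
encode = TorchAudio::Transforms::MuLawEncoding.new(quantization_channels: 256)
decode = TorchAudio::Transforms::MuLawDecoding.new(quantization_channels: 256)

quantized = encode.call(waveform) # int64 codes in [0, 255]
restored = decode.call(quantized) # approximately the original samples
```
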
data/lib/torchaudio/transforms/spectrogram.rb ADDED
@@ -0,0 +1,27 @@
+ module TorchAudio
+   module Transforms
+     class Spectrogram < Torch::NN::Module
+       def initialize(
+         n_fft: 400, win_length: nil, hop_length: nil, pad: 0,
+         window_fn: Torch.method(:hann_window), power: 2.0, normalized: false, wkwargs: nil
+       )
+         super()
+         @n_fft = n_fft
+         # number of FFT bins. the returned STFT result will have n_fft // 2 + 1
+         # number of frequencies due to onesided=True in torch.stft
+         @win_length = win_length || n_fft
+         @hop_length = hop_length || @win_length.div(2) # floor division
+         window = wkwargs.nil? ? window_fn.call(@win_length) : window_fn.call(@win_length, **wkwargs)
+         register_buffer("window", window)
+         @pad = pad
+         @power = power
+         @normalized = normalized
+       end
+
+       def forward(waveform)
+         F.spectrogram(waveform, @pad, @window, @n_fft, @hop_length, @win_length, @power, @normalized)
+       end
+     end
+   end
+ end
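
`Spectrogram` builds its window once at construction time and registers it as a buffer, so the same instance can be reused across calls. A sketch with non-default parameters (same assumed `waveform`):

```ruby
spec_transform = TorchAudio::Transforms::Spectrogram.new(
  n_fft: 512,     # 512 / 2 + 1 = 257 frequency bins
  win_length: 400,
  hop_length: 200,
  power: 2.0      # power spectrogram; nil would keep the complex STFT values
)
spec = spec_transform.call(waveform)
```
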
data/lib/torchaudio/version.rb CHANGED
@@ -1,3 +1,3 @@
  module TorchAudio
-   VERSION = "0.1.0"
+   VERSION = "0.1.1"
  end
metadata CHANGED
@@ -1,14 +1,14 @@
  --- !ruby/object:Gem::Specification
  name: torchaudio
  version: !ruby/object:Gem::Version
-   version: 0.1.0
+   version: 0.1.1
  platform: ruby
  authors:
  - Andrew Kane
  autorequire:
  bindir: bin
  cert_chain: []
- date: 2020-08-25 00:00:00.000000000 Z
+ date: 2020-08-26 00:00:00.000000000 Z
  dependencies:
  - !ruby/object:Gem::Dependency
    name: torch-rb
@@ -16,14 +16,14 @@ dependencies:
  requirements:
  - - ">="
    - !ruby/object:Gem::Version
-     version: 0.3.2
+     version: 0.3.4
  type: :runtime
  prerelease: false
  version_requirements: !ruby/object:Gem::Requirement
  requirements:
  - - ">="
    - !ruby/object:Gem::Version
-     version: 0.3.2
+     version: 0.3.4
  - !ruby/object:Gem::Dependency
    name: rice
    requirement: !ruby/object:Gem::Requirement
@@ -118,6 +118,12 @@ files:
  - lib/torchaudio.rb
  - lib/torchaudio/datasets/utils.rb
  - lib/torchaudio/datasets/yesno.rb
+ - lib/torchaudio/functional.rb
+ - lib/torchaudio/transforms/mel_scale.rb
+ - lib/torchaudio/transforms/mel_spectrogram.rb
+ - lib/torchaudio/transforms/mu_law_decoding.rb
+ - lib/torchaudio/transforms/mu_law_encoding.rb
+ - lib/torchaudio/transforms/spectrogram.rb
  - lib/torchaudio/version.rb
  homepage: https://github.com/ankane/torchaudio
  licenses: