torchaudio 0.1.0 → 0.2.1

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: b527976494325cc12e81342c25d318204d2d7c75bfba7036be4296769cdb30a0
4
- data.tar.gz: 2cfde7bd1b0e7a1628818d5bd74657cfbfba6dfa83ef42897f3ad0f98e77f739
3
+ metadata.gz: 9ed4c14921f1eee18f5e08ddabfae51e09a9b5a7ef408f1dd67fdf7bfe9622fe
4
+ data.tar.gz: 1e37d5b9abed9cab7bf56a8c30a769bc8ff8f8a3e15e78bbb772847c444571b2
5
5
  SHA512:
6
- metadata.gz: 8e6f34b014340b5ace3193ab589dae75ed0869ab7606402bd4b09de6042299e6f3a118d439dd381491f489ce9552bca4376a7d5b4693dddc3d1c5f5b26540900
7
- data.tar.gz: d651c46f5185ceb70ae3d9c90154c77afe29a5c35854d1a9d98913096b7ab9ba39a745242dd268548ca87f9e109b56c96dee9dc5539cf066f9ad0f773eddbdcd
6
+ metadata.gz: 9ca5436d7e4309dd9659fdce7ee893b122e9da96e9f7b15bf00de5dea32c635e2828a99939f04ef5bf0d9494ab89957829a65002dc3e855fa8a66f54abbbd181
7
+ data.tar.gz: d62b2a137c19d3b24facb11eda5c1b81be5841120b505877b8617bee2b9f183dbe4b4d42a95af27447346a3d48476d7faec48b57cd89b88c0ddc9709f1b5d51b
data/CHANGELOG.md CHANGED
@@ -1,3 +1,24 @@
1
+ ## 0.2.1 (2021-07-16)
2
+
3
+ - Added `create_dct` method
4
+ - Added `ComputeDeltas`, `Fade`, `MFCC`, and `Vol` transforms
5
+
6
+ ## 0.2.0 (2021-05-23)
7
+
8
+ - Updated to Rice 4
9
+ - Dropped support for Ruby < 2.6
10
+
11
+ ## 0.1.2 (2021-02-06)
12
+
13
+ - Added `amplitude_to_DB` and `DB_to_amplitude` methods
14
+ - Added `AmplitudeToDB` transform
15
+ - Fixed `save` options
16
+
17
+ ## 0.1.1 (2020-08-26)
18
+
19
+ - Added `save` method
20
+ - Added transforms
21
+
1
22
  ## 0.1.0 (2020-08-24)
2
23
 
3
24
  - First release
data/LICENSE.txt CHANGED
@@ -1,7 +1,7 @@
1
1
  BSD 2-Clause License
2
2
 
3
3
  Copyright (c) 2017 Facebook Inc. (Soumith Chintala),
4
- Copyright (c) 2020 Andrew Kane,
4
+ Copyright (c) 2020-2021 Andrew Kane,
5
5
  All rights reserved.
6
6
 
7
7
  Redistribution and use in source and binary forms, with or without
data/README.md CHANGED
@@ -2,6 +2,8 @@
2
2
 
3
3
  :fire: An audio library for Torch.rb
4
4
 
5
+ [![Build Status](https://github.com/ankane/torchaudio/workflows/build/badge.svg?branch=master)](https://github.com/ankane/torchaudio/actions)
6
+
5
7
  ## Installation
6
8
 
7
9
  First, [install SoX](#sox-installation). For Homebrew, use:
@@ -20,6 +22,67 @@ gem 'torchaudio'
20
22
 
21
23
  This library follows the [Python API](https://pytorch.org/audio/). Many methods and options are missing at the moment. PRs welcome!
22
24
 
25
+ ## Tutorial
26
+
27
+ - [PyTorch tutorial](https://pytorch.org/tutorials/beginner/audio_preprocessing_tutorial.html)
28
+ - [Ruby code](examples/tutorial.rb)
29
+
30
+ Download the [audio file](https://github.com/pytorch/tutorials/raw/master/_static/img/steam-train-whistle-daniel_simon-converted-from-mp3.wav) and install the [matplotlib](https://github.com/mrkn/matplotlib.rb) gem first.
31
+
32
+ ## Basics
33
+
34
+ Load a file
35
+
36
+ ```ruby
37
+ waveform, sample_rate = TorchAudio.load("file.wav")
38
+ ```
39
+
40
+ Save a file
41
+
42
+ ```ruby
43
+ TorchAudio.save("new.wav", waveform, sample_rate)
44
+ ```
45
+
46
+ ## Transforms
47
+
48
+ ```ruby
49
+ TorchAudio::Transforms::Spectrogram.new.call(waveform)
50
+ ```
51
+
52
+ Supported transforms are:
53
+
54
+ - AmplitudeToDB
55
+ - ComputeDeltas
56
+ - Fade
57
+ - MelScale
58
+ - MelSpectrogram
59
+ - MFCC
60
+ - MuLawDecoding
61
+ - MuLawEncoding
62
+ - Spectrogram
63
+ - Vol
64
+
65
+ ## Functional
66
+
67
+ ```ruby
68
+ TorchAudio::Functional.lowpass_biquad(waveform, sample_rate, cutoff_freq)
69
+ ```
70
+
71
+ Supported functions are:
72
+
73
+ - amplitude_to_DB
74
+ - compute_deltas
75
+ - create_dct
76
+ - create_fb_matrix
77
+ - DB_to_amplitude
78
+ - dither
79
+ - gain
80
+ - highpass_biquad
81
+ - lowpass_biquad
82
+ - mu_law_decoding
83
+ - mu_law_encoding
84
+ - spectrogram
85
+
23
86
  ## Datasets
24
87
 
25
88
  Load a dataset
@@ -1,33 +1,33 @@
1
1
  #include <torchaudio/csrc/sox.h>
2
2
 
3
- #include <rice/Module.hpp>
4
-
5
- using namespace Rice;
6
-
7
- template<>
8
- inline
9
- sox_signalinfo_t* from_ruby<sox_signalinfo_t*>(Object x)
10
- {
11
- if (x.is_nil()) {
12
- return nullptr;
13
- }
14
- throw std::runtime_error("Unsupported signalinfo");
15
- }
16
-
17
- template<>
18
- inline
19
- sox_encodinginfo_t* from_ruby<sox_encodinginfo_t*>(Object x)
20
- {
21
- if (x.is_nil()) {
22
- return nullptr;
23
- }
24
- throw std::runtime_error("Unsupported encodinginfo");
25
- }
3
+ #include <rice/rice.hpp>
4
+ #include <rice/stl.hpp>
26
5
 
27
6
  extern "C"
28
7
  void Init_ext()
29
8
  {
30
- Module rb_mTorchAudio = define_module("TorchAudio");
31
- Module rb_mNN = define_module_under(rb_mTorchAudio, "Ext")
32
- .define_singleton_method("read_audio_file", &torch::audio::read_audio_file);
9
+ auto rb_mTorchAudio = Rice::define_module("TorchAudio");
10
+
11
+ auto rb_mExt = Rice::define_module_under(rb_mTorchAudio, "Ext")
12
+ .define_singleton_function(
13
+ "read_audio_file",
14
+ [](const std::string& file_name, at::Tensor output, bool ch_first, int64_t nframes, int64_t offset, sox_signalinfo_t* si, sox_encodinginfo_t* ei, const char* ft) {
15
+ return torch::audio::read_audio_file(file_name, output, ch_first, nframes, offset, si, ei, ft);
16
+ })
17
+ .define_singleton_function(
18
+ "write_audio_file",
19
+ [](const std::string& file_name, const at::Tensor& tensor, sox_signalinfo_t* si, sox_encodinginfo_t* ei, const char* file_type) {
20
+ return torch::audio::write_audio_file(file_name, tensor, si, ei, file_type);
21
+ });
22
+
23
+ auto rb_cSignalInfo = Rice::define_class_under<sox_signalinfo_t>(rb_mExt, "SignalInfo")
24
+ .define_constructor(Rice::Constructor<sox_signalinfo_t>())
25
+ .define_method("rate", [](sox_signalinfo_t& self) { return self.rate; })
26
+ .define_method("channels", [](sox_signalinfo_t& self) { return self.channels; })
27
+ .define_method("precision", [](sox_signalinfo_t& self) { return self.precision; })
28
+ .define_method("length", [](sox_signalinfo_t& self) { return self.length; })
29
+ .define_method("rate=", [](sox_signalinfo_t& self, sox_rate_t rate) { self.rate = rate; })
30
+ .define_method("channels=", [](sox_signalinfo_t& self, unsigned channels) { self.channels = channels; })
31
+ .define_method("precision=", [](sox_signalinfo_t& self, unsigned precision) { self.precision = precision; })
32
+ .define_method("length=", [](sox_signalinfo_t& self, sox_uint64_t length) { self.length = length; });
33
33
  }
@@ -1,8 +1,6 @@
1
1
  require "mkmf-rice"
2
2
 
3
- abort "Missing stdc++" unless have_library("stdc++")
4
-
5
- $CXXFLAGS += " -std=c++14"
3
+ $CXXFLAGS += " -std=c++17 $(optflags)"
6
4
 
7
5
  abort "SoX not found" unless have_library("sox")
8
6
 
@@ -24,7 +22,6 @@ apple_clang = RbConfig::CONFIG["CC_VERSION_MESSAGE"] =~ /apple clang/i
24
22
 
25
23
  # check omp first
26
24
  if have_library("omp") || have_library("gomp")
27
- $CXXFLAGS += " -DAT_PARALLEL_OPENMP=1"
28
25
  $CXXFLAGS += " -Xclang" if apple_clang
29
26
  $CXXFLAGS += " -fopenmp"
30
27
  end
data/lib/torchaudio.rb CHANGED
@@ -14,6 +14,17 @@ require "set"
14
14
  # modules
15
15
  require "torchaudio/datasets/utils"
16
16
  require "torchaudio/datasets/yesno"
17
+ require "torchaudio/functional"
18
+ require "torchaudio/transforms/compute_deltas"
19
+ require "torchaudio/transforms/fade"
20
+ require "torchaudio/transforms/mel_scale"
21
+ require "torchaudio/transforms/mel_spectrogram"
22
+ require "torchaudio/transforms/mu_law_encoding"
23
+ require "torchaudio/transforms/mu_law_decoding"
24
+ require "torchaudio/transforms/spectrogram"
25
+ require "torchaudio/transforms/amplitude_to_db"
26
+ require "torchaudio/transforms/mfcc"
27
+ require "torchaudio/transforms/vol"
17
28
  require "torchaudio/version"
18
29
 
19
30
  module TorchAudio
@@ -73,6 +84,70 @@ module TorchAudio
73
84
  load(filepath, **kwargs)
74
85
  end
75
86
 
87
+ def save(filepath, src, sample_rate, precision: 16, channels_first: true)
88
+ si = Ext::SignalInfo.new
89
+ ch_idx = channels_first ? 0 : 1
90
+ si.rate = sample_rate
91
+ si.channels = src.dim == 1 ? 1 : src.size(ch_idx)
92
+ si.length = src.numel
93
+ si.precision = precision
94
+ save_encinfo(filepath, src, channels_first: channels_first, signalinfo: si)
95
+ end
96
+
97
+ def save_encinfo(filepath, src, channels_first: true, signalinfo: nil, encodinginfo: nil, filetype: nil)
98
+ ch_idx, len_idx = channels_first ? [0, 1] : [1, 0]
99
+
100
+ # check if save directory exists
101
+ abs_dirpath = File.dirname(File.expand_path(filepath))
102
+ unless Dir.exist?(abs_dirpath)
103
+ raise "Directory does not exist: #{abs_dirpath}"
104
+ end
105
+ # check that src is a CPU tensor
106
+ check_input(src)
107
+ # Check/Fix shape of source data
108
+ if src.dim == 1
109
+ # 1d tensors are assumed to be mono signals
110
+ src.unsqueeze!(ch_idx)
111
+ elsif src.dim > 2 || src.size(ch_idx) > 16
112
+ # assumes num_channels <= 16
113
+ raise ArgumentError, "Expected format where C < 16, but found #{src.size}"
114
+ end
115
+ # sox stores the sample rate as a float, though practically sample rates are almost always integers
116
+ # convert integers to floats
117
+ if signalinfo
118
+ if signalinfo.rate && !signalinfo.rate.is_a?(Float)
119
+ if signalinfo.rate.to_f == signalinfo.rate
120
+ signalinfo.rate = signalinfo.rate.to_f
121
+ else
122
+ raise ArgumentError, "Sample rate should be a float or int"
123
+ end
124
+ end
125
+ # check if the bit precision (i.e. bits per sample) is an integer
126
+ if signalinfo.precision && ! signalinfo.precision.is_a?(Integer)
127
+ if signalinfo.precision.to_i == signalinfo.precision
128
+ signalinfo.precision = signalinfo.precision.to_i
129
+ else
130
+ raise ArgumentError, "Bit precision should be an integer"
131
+ end
132
+ end
133
+ end
134
+ # programs such as librosa normalize the signal, unnormalize if detected
135
+ if src.min >= -1.0 && src.max <= 1.0
136
+ src = src * (1 << 31)
137
+ src = src.long
138
+ end
139
+ # set filetype and allow for files with no extensions
140
+ extension = File.extname(filepath)
141
+ filetype = extension.length > 0 ? extension[1..-1] : filetype
142
+ # transpose from C x L -> L x C
143
+ if channels_first
144
+ src = src.transpose(1, 0)
145
+ end
146
+ # save data to file
147
+ src = src.contiguous
148
+ Ext.write_audio_file(filepath, src, signalinfo, encodinginfo, filetype)
149
+ end
150
+
76
151
  private
77
152
 
78
153
  def check_input(src)
@@ -0,0 +1,324 @@
1
+ module TorchAudio
2
+ module Functional
3
+ class << self
4
+ def spectrogram(waveform, pad, window, n_fft, hop_length, win_length, power, normalized)
5
+ if pad > 0
6
+ # TODO add "with torch.no_grad():" back when JIT supports it
7
+ waveform = Torch::NN::Functional.pad(waveform, [pad, pad], "constant")
8
+ end
9
+
10
+ # pack batch
11
+ shape = waveform.size
12
+ waveform = waveform.reshape(-1, shape[-1])
13
+
14
+ # default values are consistent with librosa.core.spectrum._spectrogram
15
+ spec_f =
16
+ Torch.stft(
17
+ waveform,
18
+ n_fft,
19
+ hop_length: hop_length,
20
+ win_length: win_length,
21
+ window: window,
22
+ center: true,
23
+ pad_mode: "reflect",
24
+ normalized: false,
25
+ onesided: true
26
+ )
27
+
28
+ # unpack batch
29
+ spec_f = spec_f.reshape(shape[0..-2] + spec_f.shape[-3..-1])
30
+
31
+ if normalized
32
+ spec_f.div!(window.pow(2.0).sum.sqrt)
33
+ end
34
+ if power
35
+ spec_f = complex_norm(spec_f, power: power)
36
+ end
37
+
38
+ spec_f
39
+ end
40
+
41
+ def mu_law_encoding(x, quantization_channels)
42
+ mu = quantization_channels - 1.0
43
+ if !x.floating_point?
44
+ x = x.to(dtype: :float)
45
+ end
46
+ mu = Torch.tensor(mu, dtype: x.dtype)
47
+ x_mu = Torch.sign(x) * Torch.log1p(mu * Torch.abs(x)) / Torch.log1p(mu)
48
+ x_mu = ((x_mu + 1) / 2 * mu + 0.5).to(dtype: :int64)
49
+ x_mu
50
+ end
51
+
52
+ def mu_law_decoding(x_mu, quantization_channels)
53
+ mu = quantization_channels - 1.0
54
+ if !x_mu.floating_point?
55
+ x_mu = x_mu.to(dtype: :float)
56
+ end
57
+ mu = Torch.tensor(mu, dtype: x_mu.dtype)
58
+ x = ((x_mu) / mu) * 2 - 1.0
59
+ x = Torch.sign(x) * (Torch.exp(Torch.abs(x) * Torch.log1p(mu)) - 1.0) / mu
60
+ x
61
+ end
62
+
63
+ def complex_norm(complex_tensor, power: 1.0)
64
+ complex_tensor.pow(2.0).sum(-1).pow(0.5 * power)
65
+ end
66
+
67
+ def create_fb_matrix(n_freqs, f_min, f_max, n_mels, sample_rate, norm: nil)
68
+ if norm && norm != "slaney"
69
+ raise ArgumentError, "norm must be one of None or 'slaney'"
70
+ end
71
+
72
+ # freq bins
73
+ # Equivalent filterbank construction by Librosa
74
+ all_freqs = Torch.linspace(0, sample_rate.div(2), n_freqs)
75
+
76
+ # calculate mel freq bins
77
+ # hertz to mel(f) is 2595. * math.log10(1. + (f / 700.))
78
+ m_min = 2595.0 * Math.log10(1.0 + (f_min / 700.0))
79
+ m_max = 2595.0 * Math.log10(1.0 + (f_max / 700.0))
80
+ m_pts = Torch.linspace(m_min, m_max, n_mels + 2)
81
+ # mel to hertz(mel) is 700. * (10**(mel / 2595.) - 1.)
82
+ f_pts = (Torch.pow(10, m_pts / 2595.0) - 1.0) * 700.0
83
+ # calculate the difference between each mel point and each stft freq point in hertz
84
+ f_diff = f_pts[1..-1] - f_pts[0...-1] # (n_mels + 1)
85
+ slopes = f_pts.unsqueeze(0) - all_freqs.unsqueeze(1) # (n_freqs, n_mels + 2)
86
+ # create overlapping triangles
87
+ zero = Torch.zeros(1)
88
+ down_slopes = (slopes[0..-1, 0...-2] * -1.0) / f_diff[0...-1] # (n_freqs, n_mels)
89
+ up_slopes = slopes[0..-1, 2..-1] / f_diff[1..-1] # (n_freqs, n_mels)
90
+ fb = Torch.max(zero, Torch.min(down_slopes, up_slopes))
91
+
92
+ if norm && norm == "slaney"
93
+ # Slaney-style mel is scaled to be approx constant energy per channel
94
+ enorm = 2.0 / (f_pts[2...(n_mels + 2)] - f_pts[:n_mels])
95
+ fb *= enorm.unsqueeze(0)
96
+ end
97
+
98
+ fb
99
+ end
100
+
101
+ def compute_deltas(specgram, win_length: 5, mode: "replicate")
102
+ device = specgram.device
103
+ dtype = specgram.dtype
104
+
105
+ # pack batch
106
+ shape = specgram.size
107
+ specgram = specgram.reshape(1, -1, shape[-1])
108
+
109
+ raise ArgumentError, "win_length must be >= 3" unless win_length >= 3
110
+
111
+ n = (win_length - 1).div(2)
112
+
113
+ # twice the sum of the squared integers 1..n
114
+ denom = n * (n + 1) * (2 * n + 1) / 3
115
+
116
+ specgram = Torch::NN::Functional.pad(specgram, [n, n], mode: mode)
117
+
118
+ kernel = Torch.arange(-n, n + 1, 1, device: device, dtype: dtype).repeat([specgram.shape[1], 1, 1])
119
+
120
+ output = Torch::NN::Functional.conv1d(specgram, kernel, groups: specgram.shape[1]) / denom
121
+
122
+ # unpack batch
123
+ output = output.reshape(shape)
124
+ end
125
+
126
+ def gain(waveform, gain_db: 1.0)
127
+ return waveform if gain_db == 0
128
+
129
+ ratio = 10 ** (gain_db / 20)
130
+
131
+ waveform * ratio
132
+ end
133
+
134
+ def dither(waveform, density_function: "TPDF", noise_shaping: false)
135
+ dithered = _apply_probability_distribution(waveform, density_function: density_function)
136
+
137
+ if noise_shaping
138
+ raise "Not implemented yet"
139
+ # _add_noise_shaping(dithered, waveform)
140
+ else
141
+ dithered
142
+ end
143
+ end
144
+
145
+ def biquad(waveform, b0, b1, b2, a0, a1, a2)
146
+ device = waveform.device
147
+ dtype = waveform.dtype
148
+
149
+ output_waveform = lfilter(
150
+ waveform,
151
+ Torch.tensor([a0, a1, a2], dtype: dtype, device: device),
152
+ Torch.tensor([b0, b1, b2], dtype: dtype, device: device)
153
+ )
154
+ output_waveform
155
+ end
156
+
157
+ def highpass_biquad(waveform, sample_rate, cutoff_freq, q: 0.707)
158
+ w0 = 2 * Math::PI * cutoff_freq / sample_rate
159
+ alpha = Math.sin(w0) / 2.0 / q
160
+
161
+ b0 = (1 + Math.cos(w0)) / 2
162
+ b1 = -1 - Math.cos(w0)
163
+ b2 = b0
164
+ a0 = 1 + alpha
165
+ a1 = -2 * Math.cos(w0)
166
+ a2 = 1 - alpha
167
+ biquad(waveform, b0, b1, b2, a0, a1, a2)
168
+ end
169
+
170
+ def lowpass_biquad(waveform, sample_rate, cutoff_freq, q: 0.707)
171
+ w0 = 2 * Math::PI * cutoff_freq / sample_rate
172
+ alpha = Math.sin(w0) / 2 / q
173
+
174
+ b0 = (1 - Math.cos(w0)) / 2
175
+ b1 = 1 - Math.cos(w0)
176
+ b2 = b0
177
+ a0 = 1 + alpha
178
+ a1 = -2 * Math.cos(w0)
179
+ a2 = 1 - alpha
180
+ biquad(waveform, b0, b1, b2, a0, a1, a2)
181
+ end
182
+
183
+ def lfilter(waveform, a_coeffs, b_coeffs, clamp: true)
184
+ # pack batch
185
+ shape = waveform.size
186
+ waveform = waveform.reshape(-1, shape[-1])
187
+
188
+ raise ArgumentError unless (a_coeffs.size(0) == b_coeffs.size(0))
189
+ raise ArgumentError unless (waveform.size.length == 2)
190
+ raise ArgumentError unless (waveform.device == a_coeffs.device)
191
+ raise ArgumentError unless (b_coeffs.device == a_coeffs.device)
192
+
193
+ device = waveform.device
194
+ dtype = waveform.dtype
195
+ n_channel, n_sample = waveform.size
196
+ n_order = a_coeffs.size(0)
197
+ n_sample_padded = n_sample + n_order - 1
198
+ raise ArgumentError unless (n_order > 0)
199
+
200
+ # Pad the input and create output
201
+ padded_waveform = Torch.zeros(n_channel, n_sample_padded, dtype: dtype, device: device)
202
+ padded_waveform[0..-1, (n_order - 1)..-1] = waveform
203
+ padded_output_waveform = Torch.zeros(n_channel, n_sample_padded, dtype: dtype, device: device)
204
+
205
+ # Set up the coefficients matrix
206
+ # Flip coefficients' order
207
+ a_coeffs_flipped = a_coeffs.flip([0])
208
+ b_coeffs_flipped = b_coeffs.flip([0])
209
+
210
+ # calculate windowed_input_signal in parallel
211
+ # create indices of original with shape (n_channel, n_order, n_sample)
212
+ window_idxs = Torch.arange(n_sample, device: device).unsqueeze(0) + Torch.arange(n_order, device: device).unsqueeze(1)
213
+ window_idxs = window_idxs.repeat([n_channel, 1, 1])
214
+ window_idxs += (Torch.arange(n_channel, device: device).unsqueeze(-1).unsqueeze(-1) * n_sample_padded)
215
+ window_idxs = window_idxs.long
216
+ # (n_order, ) matmul (n_channel, n_order, n_sample) -> (n_channel, n_sample)
217
+ input_signal_windows = Torch.matmul(b_coeffs_flipped, Torch.take(padded_waveform, window_idxs))
218
+
219
+ input_signal_windows.div!(a_coeffs[0])
220
+ a_coeffs_flipped.div!(a_coeffs[0])
221
+ input_signal_windows.t.each_with_index do |o0, i_sample|
222
+ windowed_output_signal = padded_output_waveform[0..-1, i_sample...(i_sample + n_order)]
223
+ o0.addmv!(windowed_output_signal, a_coeffs_flipped, alpha: -1)
224
+ padded_output_waveform[0..-1, i_sample + n_order - 1] = o0
225
+ end
226
+
227
+ output = padded_output_waveform[0..-1, (n_order - 1)..-1]
228
+
229
+ if clamp
230
+ output = Torch.clamp(output, -1.0, 1.0)
231
+ end
232
+
233
+ # unpack batch
234
+ output = output.reshape(shape[0...-1] + output.shape[-1..-1])
235
+
236
+ output
237
+ end
238
+
239
+ def amplitude_to_DB(amp, multiplier, amin, db_multiplier, top_db: nil)
240
+ db = Torch.log10(Torch.clamp(amp, min: amin)) * multiplier
241
+ db -= multiplier * db_multiplier
242
+
243
+ db = db.clamp(min: db.max.item - top_db) if top_db
244
+
245
+ db
246
+ end
247
+
248
+ def DB_to_amplitude(db, ref, power)
249
+ Torch.pow(Torch.pow(10.0, db * 0.1), power) * ref
250
+ end
251
+
252
+ def create_dct(n_mfcc, n_mels, norm: nil)
253
+ n = Torch.arange(n_mels.to_f)
254
+ k = Torch.arange(n_mfcc.to_f).unsqueeze!(1)
255
+ dct = Torch.cos((n + 0.5) * k * Math::PI / n_mels.to_f)
256
+
257
+ if norm.nil?
258
+ dct *= 2.0
259
+ else
260
+ raise ArgumentError, "Invalid DCT norm value" unless norm == :ortho
261
+
262
+ dct[0] *= 1.0 / Math.sqrt(2.0)
263
+ dct *= Math.sqrt(2.0 / n_mels)
264
+ end
265
+
266
+ dct.t
267
+ end
268
+
269
+ private
270
+
271
+ def _apply_probability_distribution(waveform, density_function: "TPDF")
272
+ # pack batch
273
+ shape = waveform.size
274
+ waveform = waveform.reshape(-1, shape[-1])
275
+
276
+ channel_size = waveform.size[0] - 1
277
+ time_size = waveform.size[-1] - 1
278
+
279
+ random_channel = channel_size > 0 ? Torch.randint(channel_size, [1]).item.to_i : 0
280
+ random_time = time_size > 0 ? Torch.randint(time_size, [1]).item.to_i : 0
281
+
282
+ number_of_bits = 16
283
+ up_scaling = 2 ** (number_of_bits - 1) - 2
284
+ signal_scaled = waveform * up_scaling
285
+ down_scaling = 2 ** (number_of_bits - 1)
286
+
287
+ signal_scaled_dis = waveform
288
+ if density_function == "RPDF"
289
+ rpdf = waveform[random_channel][random_time] - 0.5
290
+
291
+ signal_scaled_dis = signal_scaled + rpdf
292
+ elsif density_function == "GPDF"
293
+ # TODO Replace by distribution code once
294
+ # https://github.com/pytorch/pytorch/issues/29843 is resolved
295
+ # gaussian = torch.distributions.normal.Normal(torch.mean(waveform, -1), 1).sample()
296
+
297
+ num_rand_variables = 6
298
+
299
+ gaussian = waveform[random_channel][random_time]
300
+ (num_rand_variables * [time_size]).each do |ws|
301
+ rand_chan = Torch.randint(channel_size, [1]).item.to_i
302
+ gaussian += waveform[rand_chan][Torch.randint(ws, [1]).item.to_i]
303
+ end
304
+
305
+ signal_scaled_dis = signal_scaled + gaussian
306
+ else
307
+ # dtype needed for https://github.com/pytorch/pytorch/issues/32358
308
+ # TODO add support for dtype and device to bartlett_window
309
+ tpdf = Torch.bartlett_window(time_size + 1).to(signal_scaled.device, dtype: signal_scaled.dtype)
310
+ tpdf = tpdf.repeat([channel_size + 1, 1])
311
+ signal_scaled_dis = signal_scaled + tpdf
312
+ end
313
+
314
+ quantised_signal_scaled = Torch.round(signal_scaled_dis)
315
+ quantised_signal = quantised_signal_scaled / down_scaling
316
+
317
+ # unpack batch
318
+ quantised_signal.reshape(shape[0...-1] + quantised_signal.shape[-1..-1])
319
+ end
320
+ end
321
+ end
322
+
323
+ F = Functional
324
+ end
@@ -0,0 +1,27 @@
1
+ module TorchAudio
2
+ module Transforms
3
+ class AmplitudeToDB < Torch::NN::Module
4
+ def initialize(stype: :power, top_db: nil)
5
+ super()
6
+
7
+ @stype = stype
8
+
9
+ raise ArgumentError, 'top_db must be a positive numerical' if top_db && top_db.negative?
10
+
11
+ @top_db = top_db
12
+ @multiplier = stype == :power ? 10.0 : 20.0
13
+ @amin = 1e-10
14
+ @ref_value = 1.0
15
+ @db_multiplier = Math.log10([@amin, @ref_value].max)
16
+ end
17
+
18
+ def forward(amplitude_spectrogram)
19
+ F.amplitude_to_DB(
20
+ amplitude_spectrogram,
21
+ @multiplier, @amin, @db_multiplier,
22
+ top_db: @top_db
23
+ )
24
+ end
25
+ end
26
+ end
27
+ end
@@ -0,0 +1,15 @@
1
+ module TorchAudio
2
+ module Transforms
3
+ class ComputeDeltas < Torch::NN::Module
4
+ def initialize(win_length: 5, mode: "replicate")
5
+ super()
6
+ @win_length = win_length
7
+ @mode = mode
8
+ end
9
+
10
+ def forward(specgram)
11
+ F.compute_deltas(specgram, win_length: @win_length, mode: @mode)
12
+ end
13
+ end
14
+ end
15
+ end
@@ -0,0 +1,74 @@
1
+ module TorchAudio
2
+ module Transforms
3
+ class Fade < Torch::NN::Module
4
+ def initialize(fade_in_len: 0, fade_out_len: 0, fade_shape: "linear")
5
+ super()
6
+ @fade_in_len = fade_in_len
7
+ @fade_out_len = fade_out_len
8
+ @fade_shape = fade_shape
9
+ end
10
+
11
+ def forward(waveform)
12
+ waveform_length = waveform.size[-1]
13
+ device = waveform.device
14
+ fade_in(waveform_length).to(device) * fade_out(waveform_length).to(device) * waveform
15
+ end
16
+
17
+ private
18
+
19
+ def fade_in(waveform_length)
20
+ fade = Torch.linspace(0, 1, @fade_in_len)
21
+ ones = Torch.ones(waveform_length - @fade_in_len)
22
+
23
+ if @fade_shape == "linear"
24
+ fade = fade
25
+ end
26
+
27
+ if @fade_shape == "exponential"
28
+ fade = Torch.pow(2, (fade - 1)) * fade
29
+ end
30
+
31
+ if @fade_shape == "logarithmic"
32
+ fade = Torch.log10(0.1 + fade) + 1
33
+ end
34
+
35
+ if @fade_shape == "quarter_sine"
36
+ fade = Torch.sin(fade * Math::PI / 2)
37
+ end
38
+
39
+ if @fade_shape == "half_sine"
40
+ fade = Torch.sin(fade * Math::PI - Math::PI / 2) / 2 + 0.5
41
+ end
42
+
43
+ Torch.cat([fade, ones]).clamp!(0, 1)
44
+ end
45
+
46
+ def fade_out(waveform_length)
47
+ fade = Torch.linspace(0, 1, @fade_out_len)
48
+ ones = Torch.ones(waveform_length - @fade_out_len)
49
+
50
+ if @fade_shape == "linear"
51
+ fade = - fade + 1
52
+ end
53
+
54
+ if @fade_shape == "exponential"
55
+ fade = Torch.pow(2, - fade) * (1 - fade)
56
+ end
57
+
58
+ if @fade_shape == "logarithmic"
59
+ fade = Torch.log10(1.1 - fade) + 1
60
+ end
61
+
62
+ if @fade_shape == "quarter_sine"
63
+ fade = Torch.sin(fade * Math::PI / 2 + Math::PI / 2)
64
+ end
65
+
66
+ if @fade_shape == "half_sine"
67
+ fade = Torch.sin(fade * Math::PI + Math::PI / 2) / 2 + 0.5
68
+ end
69
+
70
+ Torch.cat([ones, fade]).clamp!(0, 1)
71
+ end
72
+ end
73
+ end
74
+ end
@@ -0,0 +1,39 @@
1
+ module TorchAudio
2
+ module Transforms
3
+ class MelScale < Torch::NN::Module
4
+ def initialize(n_mels: 128, sample_rate: 16000, f_min: 0.0, f_max: nil, n_stft: nil)
5
+ super()
6
+ @n_mels = n_mels
7
+ @sample_rate = sample_rate
8
+ @f_max = f_max || sample_rate.div(2).to_f
9
+ @f_min = f_min
10
+
11
+ raise ArgumentError, "Require f_min: %f < f_max: %f" % [f_min, @f_max] unless f_min <= @f_max
12
+
13
+ fb = n_stft.nil? ? Torch.empty(0) : F.create_fb_matrix(n_stft, @f_min, @f_max, @n_mels, @sample_rate)
14
+ register_buffer("fb", fb)
15
+ end
16
+
17
+ def forward(specgram)
18
+ shape = specgram.size
19
+ specgram = specgram.reshape(-1, shape[-2], shape[-1])
20
+
21
+ if @fb.numel == 0
22
+ tmp_fb = F.create_fb_matrix(specgram.size(1), @f_min, @f_max, @n_mels, @sample_rate)
23
+ # Attributes cannot be reassigned outside __init__ so workaround
24
+ @fb.resize!(tmp_fb.size)
25
+ @fb.copy!(tmp_fb)
26
+ end
27
+
28
+ # (channel, frequency, time).transpose(...) dot (frequency, n_mels)
29
+ # -> (channel, time, n_mels).transpose(...)
30
+ mel_specgram = Torch.matmul(specgram.transpose(1, 2), @fb).transpose(1, 2)
31
+
32
+ # unpack batch
33
+ mel_specgram = mel_specgram.reshape(shape[0...-2] + mel_specgram.shape[-2..-1])
34
+
35
+ mel_specgram
36
+ end
37
+ end
38
+ end
39
+ end
@@ -0,0 +1,37 @@
1
+ module TorchAudio
2
+ module Transforms
3
+ class MelSpectrogram < Torch::NN::Module
4
+ attr_reader :n_mels
5
+
6
+ def initialize(
7
+ sample_rate: 16000, n_fft: 400, win_length: nil, hop_length: nil, f_min: 0.0,
8
+ f_max: nil, pad: 0, n_mels: 128, window_fn: Torch.method(:hann_window),
9
+ power: 2.0, normalized: false, wkwargs: nil
10
+ )
11
+
12
+ super()
13
+ @sample_rate = sample_rate
14
+ @n_fft = n_fft
15
+ @win_length = win_length || n_fft
16
+ @hop_length = hop_length || @win_length.div(2)
17
+ @pad = pad
18
+ @power = power
19
+ @normalized = normalized
20
+ @n_mels = n_mels # number of mel frequency bins
21
+ @f_max = f_max
22
+ @f_min = f_min
23
+ @spectrogram =
24
+ Spectrogram.new(
25
+ n_fft: @n_fft, win_length: @win_length, hop_length: @hop_length, pad: @pad,
26
+ window_fn: window_fn, power: @power, normalized: @normalized, wkwargs: wkwargs
27
+ )
28
+ @mel_scale = MelScale.new(n_mels: @n_mels, sample_rate: @sample_rate, f_min: @f_min, f_max: @f_max, n_stft: @n_fft.div(2) + 1)
29
+ end
30
+
31
+ def forward(waveform)
32
+ specgram = @spectrogram.call(waveform)
33
+ @mel_scale.call(specgram)
34
+ end
35
+ end
36
+ end
37
+ end
@@ -0,0 +1,43 @@
1
+ module TorchAudio
2
+ module Transforms
3
+ class MFCC < Torch::NN::Module
4
+
5
+ SUPPORTED_DCT_TYPES = [2]
6
+
7
+ def initialize(sample_rate: 16000, n_mfcc: 40, dct_type: 2, norm: :ortho, log_mels: false, melkwargs: {})
8
+ super()
9
+
10
+ raise ArgumentError, "DCT type not supported: #{dct_type}" unless SUPPORTED_DCT_TYPES.include?(dct_type)
11
+
12
+ @sample_rate = sample_rate
13
+ @n_mfcc = n_mfcc
14
+ @dct_type = dct_type
15
+ @norm = norm
16
+ @top_db = 80.0
17
+ @amplitude_to_db = TorchAudio::Transforms::AmplitudeToDB.new(stype: :power, top_db: @top_db)
18
+
19
+ @melspectrogram = TorchAudio::Transforms::MelSpectrogram.new(sample_rate: @sample_rate, **melkwargs)
20
+
21
+ raise ArgumentError, "Cannot select more MFCC coefficients than # mel bins" if @n_mfcc > @melspectrogram.n_mels
22
+
23
+ dct_mat = F.create_dct(@n_mfcc, @melspectrogram.n_mels, norm: @norm)
24
+ register_buffer('dct_mat', dct_mat)
25
+
26
+ @log_mels = log_mels
27
+ end
28
+
29
+ def forward(waveform)
30
+ mel_specgram = @melspectrogram.(waveform)
31
+ if @log_mels
32
+ mel_specgram = Torch.log(mel_specgram + 1e-6)
33
+ else
34
+ mel_specgram = @amplitude_to_db.(mel_specgram)
35
+ end
36
+
37
+ Torch
38
+ .matmul(mel_specgram.transpose(-2, -1), @dct_mat)
39
+ .transpose(-2, -1)
40
+ end
41
+ end
42
+ end
43
+ end
@@ -0,0 +1,14 @@
1
+ module TorchAudio
2
+ module Transforms
3
+ class MuLawDecoding < Torch::NN::Module
4
+ def initialize(quantization_channels: 256)
5
+ super()
6
+ @quantization_channels = quantization_channels
7
+ end
8
+
9
+ def forward(x_mu)
10
+ F.mu_law_decoding(x_mu, @quantization_channels)
11
+ end
12
+ end
13
+ end
14
+ end
@@ -0,0 +1,14 @@
1
+ module TorchAudio
2
+ module Transforms
3
+ class MuLawEncoding < Torch::NN::Module
4
+ def initialize(quantization_channels: 256)
5
+ super()
6
+ @quantization_channels = quantization_channels
7
+ end
8
+
9
+ def forward(x)
10
+ F.mu_law_encoding(x, @quantization_channels)
11
+ end
12
+ end
13
+ end
14
+ end
@@ -0,0 +1,27 @@
1
+ module TorchAudio
2
+ module Transforms
3
+ class Spectrogram < Torch::NN::Module
4
+ def initialize(
5
+ n_fft: 400, win_length: nil, hop_length: nil, pad: 0,
6
+ window_fn: Torch.method(:hann_window), power: 2.0, normalized: false, wkwargs: nil
7
+ )
8
+
9
+ super()
10
+ @n_fft = n_fft
11
+ # number of FFT bins. the returned STFT result will have n_fft // 2 + 1
12
+ # number of frequencies due to onesided=True in torch.stft
13
+ @win_length = win_length || n_fft
14
+ @hop_length = hop_length || @win_length.div(2) # floor division
15
+ window = wkwargs.nil? ? window_fn.call(@win_length) : window_fn.call(@win_length, **wkwargs)
16
+ register_buffer("window", window)
17
+ @pad = pad
18
+ @power = power
19
+ @normalized = normalized
20
+ end
21
+
22
+ def forward(waveform)
23
+ F.spectrogram(waveform, @pad, @window, @n_fft, @hop_length, @win_length, @power, @normalized)
24
+ end
25
+ end
26
+ end
27
+ end
@@ -0,0 +1,31 @@
1
+ module TorchAudio
2
+ module Transforms
3
+ class Vol < Torch::NN::Module
4
+ def initialize(gain, gain_type: "amplitude")
5
+ super()
6
+ @gain = gain
7
+ @gain_type = gain_type
8
+
9
+ if ["amplitude", "power"].include?(gain_type) && gain < 0
10
+ raise ArgumentError, "If gain_type = amplitude or power, gain must be positive."
11
+ end
12
+ end
13
+
14
+ def forward(waveform)
15
+ if @gain_type == "amplitude"
16
+ waveform = waveform * @gain
17
+ end
18
+
19
+ if @gain_type == "db"
20
+ waveform = F.gain(waveform, @gain)
21
+ end
22
+
23
+ if @gain_type == "power"
24
+ waveform = F.gain(waveform, 10 * Math.log10(@gain))
25
+ end
26
+
27
+ Torch.clamp(waveform, -1, 1)
28
+ end
29
+ end
30
+ end
31
+ end
@@ -1,3 +1,3 @@
1
1
  module TorchAudio
2
- VERSION = "0.1.0"
2
+ VERSION = "0.2.1"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torchaudio
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.1.0
4
+ version: 0.2.1
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
- autorequire:
8
+ autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2020-08-25 00:00:00.000000000 Z
11
+ date: 2021-07-16 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: torch-rb
@@ -16,86 +16,30 @@ dependencies:
16
16
  requirements:
17
17
  - - ">="
18
18
  - !ruby/object:Gem::Version
19
- version: 0.3.2
19
+ version: 0.7.0
20
20
  type: :runtime
21
21
  prerelease: false
22
22
  version_requirements: !ruby/object:Gem::Requirement
23
23
  requirements:
24
24
  - - ">="
25
25
  - !ruby/object:Gem::Version
26
- version: 0.3.2
26
+ version: 0.7.0
27
27
  - !ruby/object:Gem::Dependency
28
28
  name: rice
29
29
  requirement: !ruby/object:Gem::Requirement
30
30
  requirements:
31
31
  - - ">="
32
32
  - !ruby/object:Gem::Version
33
- version: '2.2'
33
+ version: 4.0.2
34
34
  type: :runtime
35
35
  prerelease: false
36
36
  version_requirements: !ruby/object:Gem::Requirement
37
37
  requirements:
38
38
  - - ">="
39
39
  - !ruby/object:Gem::Version
40
- version: '2.2'
41
- - !ruby/object:Gem::Dependency
42
- name: bundler
43
- requirement: !ruby/object:Gem::Requirement
44
- requirements:
45
- - - ">="
46
- - !ruby/object:Gem::Version
47
- version: '0'
48
- type: :development
49
- prerelease: false
50
- version_requirements: !ruby/object:Gem::Requirement
51
- requirements:
52
- - - ">="
53
- - !ruby/object:Gem::Version
54
- version: '0'
55
- - !ruby/object:Gem::Dependency
56
- name: rake
57
- requirement: !ruby/object:Gem::Requirement
58
- requirements:
59
- - - ">="
60
- - !ruby/object:Gem::Version
61
- version: '0'
62
- type: :development
63
- prerelease: false
64
- version_requirements: !ruby/object:Gem::Requirement
65
- requirements:
66
- - - ">="
67
- - !ruby/object:Gem::Version
68
- version: '0'
69
- - !ruby/object:Gem::Dependency
70
- name: rake-compiler
71
- requirement: !ruby/object:Gem::Requirement
72
- requirements:
73
- - - ">="
74
- - !ruby/object:Gem::Version
75
- version: '0'
76
- type: :development
77
- prerelease: false
78
- version_requirements: !ruby/object:Gem::Requirement
79
- requirements:
80
- - - ">="
81
- - !ruby/object:Gem::Version
82
- version: '0'
83
- - !ruby/object:Gem::Dependency
84
- name: minitest
85
- requirement: !ruby/object:Gem::Requirement
86
- requirements:
87
- - - ">="
88
- - !ruby/object:Gem::Version
89
- version: '5'
90
- type: :development
91
- prerelease: false
92
- version_requirements: !ruby/object:Gem::Requirement
93
- requirements:
94
- - - ">="
95
- - !ruby/object:Gem::Version
96
- version: '5'
97
- description:
98
- email: andrew@chartkick.com
40
+ version: 4.0.2
41
+ description:
42
+ email: andrew@ankane.org
99
43
  executables: []
100
44
  extensions:
101
45
  - ext/torchaudio/extconf.rb
@@ -118,12 +62,23 @@ files:
118
62
  - lib/torchaudio.rb
119
63
  - lib/torchaudio/datasets/utils.rb
120
64
  - lib/torchaudio/datasets/yesno.rb
65
+ - lib/torchaudio/functional.rb
66
+ - lib/torchaudio/transforms/amplitude_to_db.rb
67
+ - lib/torchaudio/transforms/compute_deltas.rb
68
+ - lib/torchaudio/transforms/fade.rb
69
+ - lib/torchaudio/transforms/mel_scale.rb
70
+ - lib/torchaudio/transforms/mel_spectrogram.rb
71
+ - lib/torchaudio/transforms/mfcc.rb
72
+ - lib/torchaudio/transforms/mu_law_decoding.rb
73
+ - lib/torchaudio/transforms/mu_law_encoding.rb
74
+ - lib/torchaudio/transforms/spectrogram.rb
75
+ - lib/torchaudio/transforms/vol.rb
121
76
  - lib/torchaudio/version.rb
122
77
  homepage: https://github.com/ankane/torchaudio
123
78
  licenses:
124
79
  - BSD-2-Clause
125
80
  metadata: {}
126
- post_install_message:
81
+ post_install_message:
127
82
  rdoc_options: []
128
83
  require_paths:
129
84
  - lib
@@ -131,15 +86,15 @@ required_ruby_version: !ruby/object:Gem::Requirement
131
86
  requirements:
132
87
  - - ">="
133
88
  - !ruby/object:Gem::Version
134
- version: '2.5'
89
+ version: '2.6'
135
90
  required_rubygems_version: !ruby/object:Gem::Requirement
136
91
  requirements:
137
92
  - - ">="
138
93
  - !ruby/object:Gem::Version
139
94
  version: '0'
140
95
  requirements: []
141
- rubygems_version: 3.1.2
142
- signing_key:
96
+ rubygems_version: 3.2.22
97
+ signing_key:
143
98
  specification_version: 4
144
99
  summary: Data manipulation and transformation for audio signal processing
145
100
  test_files: []