torchaudio 0.1.0 → 0.2.1

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: b527976494325cc12e81342c25d318204d2d7c75bfba7036be4296769cdb30a0
4
- data.tar.gz: 2cfde7bd1b0e7a1628818d5bd74657cfbfba6dfa83ef42897f3ad0f98e77f739
3
+ metadata.gz: 9ed4c14921f1eee18f5e08ddabfae51e09a9b5a7ef408f1dd67fdf7bfe9622fe
4
+ data.tar.gz: 1e37d5b9abed9cab7bf56a8c30a769bc8ff8f8a3e15e78bbb772847c444571b2
5
5
  SHA512:
6
- metadata.gz: 8e6f34b014340b5ace3193ab589dae75ed0869ab7606402bd4b09de6042299e6f3a118d439dd381491f489ce9552bca4376a7d5b4693dddc3d1c5f5b26540900
7
- data.tar.gz: d651c46f5185ceb70ae3d9c90154c77afe29a5c35854d1a9d98913096b7ab9ba39a745242dd268548ca87f9e109b56c96dee9dc5539cf066f9ad0f773eddbdcd
6
+ metadata.gz: 9ca5436d7e4309dd9659fdce7ee893b122e9da96e9f7b15bf00de5dea32c635e2828a99939f04ef5bf0d9494ab89957829a65002dc3e855fa8a66f54abbbd181
7
+ data.tar.gz: d62b2a137c19d3b24facb11eda5c1b81be5841120b505877b8617bee2b9f183dbe4b4d42a95af27447346a3d48476d7faec48b57cd89b88c0ddc9709f1b5d51b
data/CHANGELOG.md CHANGED
@@ -1,3 +1,24 @@
1
+ ## 0.2.1 (2021-07-16)
2
+
3
+ - Added `create_dct` method
4
+ - Added `ComputeDeltas`, `Fade`, `MFCC`, and `Vol` transforms
5
+
6
+ ## 0.2.0 (2021-05-23)
7
+
8
+ - Updated to Rice 4
9
+ - Dropped support for Ruby < 2.6
10
+
11
+ ## 0.1.2 (2021-02-06)
12
+
13
+ - Added `amplitude_to_DB` and `DB_to_amplitude` methods
14
+ - Added `AmplitudeToDB` transform
15
+ - Fixed `save` options
16
+
17
+ ## 0.1.1 (2020-08-26)
18
+
19
+ - Added `save` method
20
+ - Added transforms
21
+
1
22
  ## 0.1.0 (2020-08-24)
2
23
 
3
24
  - First release
data/LICENSE.txt CHANGED
@@ -1,7 +1,7 @@
1
1
  BSD 2-Clause License
2
2
 
3
3
  Copyright (c) 2017 Facebook Inc. (Soumith Chintala),
4
- Copyright (c) 2020 Andrew Kane,
4
+ Copyright (c) 2020-2021 Andrew Kane,
5
5
  All rights reserved.
6
6
 
7
7
  Redistribution and use in source and binary forms, with or without
data/README.md CHANGED
@@ -2,6 +2,8 @@
2
2
 
3
3
  :fire: An audio library for Torch.rb
4
4
 
5
+ [![Build Status](https://github.com/ankane/torchaudio/workflows/build/badge.svg?branch=master)](https://github.com/ankane/torchaudio/actions)
6
+
5
7
  ## Installation
6
8
 
7
9
  First, [install SoX](#sox-installation). For Homebrew, use:
@@ -20,6 +22,67 @@ gem 'torchaudio'
20
22
 
21
23
  This library follows the [Python API](https://pytorch.org/audio/). Many methods and options are missing at the moment. PRs welcome!
22
24
 
25
+ ## Tutorial
26
+
27
+ - [PyTorch tutorial](https://pytorch.org/tutorials/beginner/audio_preprocessing_tutorial.html)
28
+ - [Ruby code](examples/tutorial.rb)
29
+
30
+ Download the [audio file](https://github.com/pytorch/tutorials/raw/master/_static/img/steam-train-whistle-daniel_simon-converted-from-mp3.wav) and install the [matplotlib](https://github.com/mrkn/matplotlib.rb) gem first.
31
+
32
+ ## Basics
33
+
34
+ Load a file
35
+
36
+ ```ruby
37
+ waveform, sample_rate = TorchAudio.load("file.wav")
38
+ ```
39
+
40
+ Save a file
41
+
42
+ ```ruby
43
+ TorchAudio.save("new.wav", waveform, sample_rate)
44
+ ```
45
+
46
+ ## Transforms
47
+
48
+ ```ruby
49
+ TorchAudio::Transforms::Spectrogram.new.call(waveform)
50
+ ```
51
+
52
+ Supported transforms are:
53
+
54
+ - AmplitudeToDB
55
+ - ComputeDeltas
56
+ - Fade
57
+ - MelScale
58
+ - MelSpectrogram
59
+ - MFCC
60
+ - MuLawDecoding
61
+ - MuLawEncoding
62
+ - Spectrogram
63
+ - Vol
64
+
65
+ ## Functional
66
+
67
+ ```ruby
68
+ TorchAudio::Functional.lowpass_biquad(waveform, sample_rate, cutoff_freq)
69
+ ```
70
+
71
+ Supported functions are:
72
+
73
+ - amplitude_to_DB
74
+ - compute_deltas
75
+ - create_dct
76
+ - create_fb_matrix
77
+ - DB_to_amplitude
78
+ - dither
79
+ - gain
80
+ - highpass_biquad
81
+ - lowpass_biquad
82
+ - mu_law_decoding
83
+ - mu_law_encoding
84
+ - spectrogram
85
+
23
86
  ## Datasets
24
87
 
25
88
  Load a dataset
@@ -1,33 +1,33 @@
1
1
  #include <torchaudio/csrc/sox.h>
2
2
 
3
- #include <rice/Module.hpp>
4
-
5
- using namespace Rice;
6
-
7
- template<>
8
- inline
9
- sox_signalinfo_t* from_ruby<sox_signalinfo_t*>(Object x)
10
- {
11
- if (x.is_nil()) {
12
- return nullptr;
13
- }
14
- throw std::runtime_error("Unsupported signalinfo");
15
- }
16
-
17
- template<>
18
- inline
19
- sox_encodinginfo_t* from_ruby<sox_encodinginfo_t*>(Object x)
20
- {
21
- if (x.is_nil()) {
22
- return nullptr;
23
- }
24
- throw std::runtime_error("Unsupported encodinginfo");
25
- }
3
+ #include <rice/rice.hpp>
4
+ #include <rice/stl.hpp>
26
5
 
27
6
  extern "C"
28
7
  void Init_ext()
29
8
  {
30
- Module rb_mTorchAudio = define_module("TorchAudio");
31
- Module rb_mNN = define_module_under(rb_mTorchAudio, "Ext")
32
- .define_singleton_method("read_audio_file", &torch::audio::read_audio_file);
9
+ auto rb_mTorchAudio = Rice::define_module("TorchAudio");
10
+
11
+ auto rb_mExt = Rice::define_module_under(rb_mTorchAudio, "Ext")
12
+ .define_singleton_function(
13
+ "read_audio_file",
14
+ [](const std::string& file_name, at::Tensor output, bool ch_first, int64_t nframes, int64_t offset, sox_signalinfo_t* si, sox_encodinginfo_t* ei, const char* ft) {
15
+ return torch::audio::read_audio_file(file_name, output, ch_first, nframes, offset, si, ei, ft);
16
+ })
17
+ .define_singleton_function(
18
+ "write_audio_file",
19
+ [](const std::string& file_name, const at::Tensor& tensor, sox_signalinfo_t* si, sox_encodinginfo_t* ei, const char* file_type) {
20
+ return torch::audio::write_audio_file(file_name, tensor, si, ei, file_type);
21
+ });
22
+
23
+ auto rb_cSignalInfo = Rice::define_class_under<sox_signalinfo_t>(rb_mExt, "SignalInfo")
24
+ .define_constructor(Rice::Constructor<sox_signalinfo_t>())
25
+ .define_method("rate", [](sox_signalinfo_t& self) { return self.rate; })
26
+ .define_method("channels", [](sox_signalinfo_t& self) { return self.channels; })
27
+ .define_method("precision", [](sox_signalinfo_t& self) { return self.precision; })
28
+ .define_method("length", [](sox_signalinfo_t& self) { return self.length; })
29
+ .define_method("rate=", [](sox_signalinfo_t& self, sox_rate_t rate) { self.rate = rate; })
30
+ .define_method("channels=", [](sox_signalinfo_t& self, unsigned channels) { self.channels = channels; })
31
+ .define_method("precision=", [](sox_signalinfo_t& self, unsigned precision) { self.precision = precision; })
32
+ .define_method("length=", [](sox_signalinfo_t& self, sox_uint64_t length) { self.length = length; });
33
33
  }
@@ -1,8 +1,6 @@
1
1
  require "mkmf-rice"
2
2
 
3
- abort "Missing stdc++" unless have_library("stdc++")
4
-
5
- $CXXFLAGS += " -std=c++14"
3
+ $CXXFLAGS += " -std=c++17 $(optflags)"
6
4
 
7
5
  abort "SoX not found" unless have_library("sox")
8
6
 
@@ -24,7 +22,6 @@ apple_clang = RbConfig::CONFIG["CC_VERSION_MESSAGE"] =~ /apple clang/i
24
22
 
25
23
  # check omp first
26
24
  if have_library("omp") || have_library("gomp")
27
- $CXXFLAGS += " -DAT_PARALLEL_OPENMP=1"
28
25
  $CXXFLAGS += " -Xclang" if apple_clang
29
26
  $CXXFLAGS += " -fopenmp"
30
27
  end
data/lib/torchaudio.rb CHANGED
@@ -14,6 +14,17 @@ require "set"
14
14
  # modules
15
15
  require "torchaudio/datasets/utils"
16
16
  require "torchaudio/datasets/yesno"
17
+ require "torchaudio/functional"
18
+ require "torchaudio/transforms/compute_deltas"
19
+ require "torchaudio/transforms/fade"
20
+ require "torchaudio/transforms/mel_scale"
21
+ require "torchaudio/transforms/mel_spectrogram"
22
+ require "torchaudio/transforms/mu_law_encoding"
23
+ require "torchaudio/transforms/mu_law_decoding"
24
+ require "torchaudio/transforms/spectrogram"
25
+ require "torchaudio/transforms/amplitude_to_db"
26
+ require "torchaudio/transforms/mfcc"
27
+ require "torchaudio/transforms/vol"
17
28
  require "torchaudio/version"
18
29
 
19
30
  module TorchAudio
@@ -73,6 +84,70 @@ module TorchAudio
73
84
  load(filepath, **kwargs)
74
85
  end
75
86
 
87
# Saves +src+ to +filepath+ at +sample_rate+ with the given bit +precision+.
#
# Describes the tensor with an Ext::SignalInfo:
# - the rate
# - the channel count (1 for a 1-D tensor)
# - the total element count as length
# - the precision
#
# Validation and the actual write are delegated to save_encinfo.
#
# channels_first - true when +src+ is laid out channels x samples,
#                  false for samples x channels
def save(filepath, src, sample_rate, precision: 16, channels_first: true)
  channel_dim = channels_first ? 0 : 1

  signal = Ext::SignalInfo.new
  signal.rate = sample_rate
  signal.channels = src.dim == 1 ? 1 : src.size(channel_dim)
  signal.length = src.numel
  signal.precision = precision

  save_encinfo(filepath, src, channels_first: channels_first, signalinfo: signal)
end
96
+
97
# Writes +src+ to +filepath+ through the native SoX writer (Ext.write_audio_file).
#
# filepath       - destination path; its extension (without the dot) selects the
#                  file type, falling back to +filetype+ when there is no extension
# src            - tensor of samples, channels x samples when channels_first,
#                  otherwise samples x channels; a 1-D tensor is treated as mono
# channels_first - layout of +src+ (see above)
# signalinfo     - optional Ext::SignalInfo; integral non-Float rates and
#                  integral non-Integer precisions are converted in place
# encodinginfo   - forwarded unchanged to the native writer
# filetype       - file type used when +filepath+ has no extension
#
# Raises RuntimeError when the destination directory does not exist.
# Raises ArgumentError for tensors with more than 2 dims or more than 16
# channels, and for non-integral sample rates or precisions.
#
# A 1-D +src+ is reshaped into a new tensor, so the caller's tensor keeps its shape.
def save_encinfo(filepath, src, channels_first: true, signalinfo: nil, encodinginfo: nil, filetype: nil)
  ch_idx = channels_first ? 0 : 1

  # check if save directory exists
  abs_dirpath = File.dirname(File.expand_path(filepath))
  raise "Directory does not exist: #{abs_dirpath}" unless Dir.exist?(abs_dirpath)

  # check that src is a CPU tensor
  check_input(src)

  # Check/Fix shape of source data
  if src.dim == 1
    # 1d tensors are assumed to be mono signals; use a non-mutating unsqueeze so
    # the caller's tensor is not reshaped as a side effect
    src = src.unsqueeze(ch_idx)
  elsif src.dim > 2 || src.size(ch_idx) > 16
    # assumes at most 16 channels
    raise ArgumentError, "Expected format where C < 16, but found #{src.size}"
  end

  if signalinfo
    # sox stores the sample rate as a float, though practically sample rates are almost always integers
    # convert integers to floats
    rate = signalinfo.rate
    if rate && !rate.is_a?(Float)
      raise ArgumentError, "Sample rate should be a float or int" unless rate.to_f == rate
      signalinfo.rate = rate.to_f
    end

    # check if the bit precision (i.e. bits per sample) is an integer
    precision = signalinfo.precision
    if precision && !precision.is_a?(Integer)
      raise ArgumentError, "Bit precision should be an integer" unless precision.to_i == precision
      signalinfo.precision = precision.to_i
    end
  end

  # programs such as librosa normalize the signal, unnormalize if detected
  if src.min >= -1.0 && src.max <= 1.0
    src = src * (1 << 31)
    src = src.long
  end

  # set filetype and allow for files with no extensions
  extension = File.extname(filepath)
  filetype = extension.length > 0 ? extension[1..-1] : filetype

  # transpose from C x L -> L x C
  src = src.transpose(1, 0) if channels_first

  # save data to file
  src = src.contiguous
  Ext.write_audio_file(filepath, src, signalinfo, encodinginfo, filetype)
end
150
+
76
151
  private
77
152
 
78
153
  def check_input(src)
@@ -0,0 +1,324 @@
1
module TorchAudio
  # Stateless audio functions ported from Python's torchaudio.functional.
  # Also reachable as TorchAudio::F (alias defined at the bottom).
  module Functional
    class << self
      # Spectrogram of +waveform+ computed with Torch.stft.
      #
      # waveform   - tensor whose last dimension holds the samples; leading
      #              dimensions are flattened into a batch and restored afterwards
      # pad        - zero padding added to both ends of the last dimension
      # window     - window tensor handed to Torch.stft
      # n_fft, hop_length, win_length - STFT parameters (center: true, reflect padding, onesided)
      # power      - exponent for complex_norm; nil/false returns the raw Torch.stft output
      # normalized - when true, the result is divided by the L2 norm of +window+
      def spectrogram(waveform, pad, window, n_fft, hop_length, win_length, power, normalized)
        if pad > 0
          # TODO add "with torch.no_grad():" back when JIT supports it
          # Torch::NN::Functional.pad takes the mode as a keyword (as in compute_deltas below)
          waveform = Torch::NN::Functional.pad(waveform, [pad, pad], mode: "constant")
        end

        # pack batch
        shape = waveform.size
        waveform = waveform.reshape(-1, shape[-1])

        # default values are consistent with librosa.core.spectrum._spectrogram
        spec_f =
          Torch.stft(
            waveform,
            n_fft,
            hop_length: hop_length,
            win_length: win_length,
            window: window,
            center: true,
            pad_mode: "reflect",
            normalized: false,
            onesided: true
          )

        # unpack batch
        spec_f = spec_f.reshape(shape[0..-2] + spec_f.shape[-3..-1])

        spec_f.div!(window.pow(2.0).sum.sqrt) if normalized
        spec_f = complex_norm(spec_f, power: power) if power

        spec_f
      end

      # Mu-law companding of +x+ (expected within [-1, 1]) into int64 codes.
      # Non-floating input is converted to float first.
      def mu_law_encoding(x, quantization_channels)
        mu = quantization_channels - 1.0
        x = x.to(dtype: :float) unless x.floating_point?
        mu = Torch.tensor(mu, dtype: x.dtype)
        x_mu = Torch.sign(x) * Torch.log1p(mu * Torch.abs(x)) / Torch.log1p(mu)
        ((x_mu + 1) / 2 * mu + 0.5).to(dtype: :int64)
      end

      # Inverse of mu_law_encoding: maps codes back to the [-1, 1] range.
      def mu_law_decoding(x_mu, quantization_channels)
        mu = quantization_channels - 1.0
        x_mu = x_mu.to(dtype: :float) unless x_mu.floating_point?
        mu = Torch.tensor(mu, dtype: x_mu.dtype)
        x = (x_mu / mu) * 2 - 1.0
        Torch.sign(x) * (Torch.exp(Torch.abs(x) * Torch.log1p(mu)) - 1.0) / mu
      end

      # (sum of squares over the last dimension) ** (0.5 * power); the last
      # dimension presumably holds (real, imaginary) pairs.
      def complex_norm(complex_tensor, power: 1.0)
        complex_tensor.pow(2.0).sum(-1).pow(0.5 * power)
      end

      # Triangular mel filterbank of shape (n_freqs, n_mels).
      #
      # n_freqs      - number of linear frequency bins, spread over 0..sample_rate.div(2)
      # f_min, f_max - frequency range (Hz) covered by the mel bands
      # n_mels       - number of mel bands
      # norm         - nil, or "slaney" to scale each band to roughly constant energy
      #
      # Raises ArgumentError for any other +norm+.
      def create_fb_matrix(n_freqs, f_min, f_max, n_mels, sample_rate, norm: nil)
        if norm && norm != "slaney"
          raise ArgumentError, "norm must be one of None or 'slaney'"
        end

        # freq bins
        # Equivalent filterbank construction by Librosa
        all_freqs = Torch.linspace(0, sample_rate.div(2), n_freqs)

        # calculate mel freq bins
        # hertz to mel(f) is 2595. * math.log10(1. + (f / 700.))
        m_min = 2595.0 * Math.log10(1.0 + (f_min / 700.0))
        m_max = 2595.0 * Math.log10(1.0 + (f_max / 700.0))
        m_pts = Torch.linspace(m_min, m_max, n_mels + 2)
        # mel to hertz(mel) is 700. * (10**(mel / 2595.) - 1.)
        f_pts = (Torch.pow(10, m_pts / 2595.0) - 1.0) * 700.0
        # calculate the difference between each mel point and each stft freq point in hertz
        f_diff = f_pts[1..-1] - f_pts[0...-1] # (n_mels + 1)
        slopes = f_pts.unsqueeze(0) - all_freqs.unsqueeze(1) # (n_freqs, n_mels + 2)
        # create overlapping triangles
        zero = Torch.zeros(1)
        down_slopes = (slopes[0..-1, 0...-2] * -1.0) / f_diff[0...-1] # (n_freqs, n_mels)
        up_slopes = slopes[0..-1, 2..-1] / f_diff[1..-1] # (n_freqs, n_mels)
        fb = Torch.max(zero, Torch.min(down_slopes, up_slopes))

        if norm == "slaney"
          # Slaney-style mel is scaled to be approx constant energy per channel.
          # Band i spans f_pts[i]..f_pts[i + 2]. The lower edges need a Range
          # index (0...n_mels); a Symbol is not a valid tensor index.
          enorm = 2.0 / (f_pts[2...(n_mels + 2)] - f_pts[0...n_mels])
          fb *= enorm.unsqueeze(0)
        end

        fb
      end

      # Delta coefficients of +specgram+ along its last dimension.
      # The kernel is -n..n divided by 2 * sum(i**2), where n = (win_length - 1).div(2).
      #
      # mode - padding mode forwarded to Torch::NN::Functional.pad
      #
      # Raises ArgumentError when win_length < 3.
      def compute_deltas(specgram, win_length: 5, mode: "replicate")
        device = specgram.device
        dtype = specgram.dtype

        # pack batch
        shape = specgram.size
        specgram = specgram.reshape(1, -1, shape[-1])

        raise ArgumentError, "win_length must be >= 3" unless win_length >= 3

        n = (win_length - 1).div(2)

        # twice sum of integer squared (always divisible by 3, so Integer division is exact)
        denom = n * (n + 1) * (2 * n + 1) / 3

        specgram = Torch::NN::Functional.pad(specgram, [n, n], mode: mode)

        kernel = Torch.arange(-n, n + 1, 1, device: device, dtype: dtype).repeat([specgram.shape[1], 1, 1])

        output = Torch::NN::Functional.conv1d(specgram, kernel, groups: specgram.shape[1]) / denom

        # unpack batch
        output.reshape(shape)
      end

      # Scales +waveform+ by +gain_db+ decibels (amplitude ratio 10 ** (gain_db / 20)).
      def gain(waveform, gain_db: 1.0)
        return waveform if gain_db == 0

        # Divide by 20.0, not 20: with an Integer gain_db, Integer division would
        # truncate the exponent (6 / 20 == 0), and a negative one would give a Rational.
        ratio = 10 ** (gain_db / 20.0)

        waveform * ratio
      end

      # Adds dither noise ("TPDF" by default, "RPDF" or "GPDF") before
      # re-quantising the waveform. noise_shaping is not implemented yet.
      def dither(waveform, density_function: "TPDF", noise_shaping: false)
        dithered = _apply_probability_distribution(waveform, density_function: density_function)

        if noise_shaping
          raise "Not implemented yet"
          # _add_noise_shaping(dithered, waveform)
        else
          dithered
        end
      end

      # Second-order IIR filter with coefficients b0..b2 (numerator) and a0..a2
      # (denominator), run through lfilter on waveform's device and dtype.
      def biquad(waveform, b0, b1, b2, a0, a1, a2)
        device = waveform.device
        dtype = waveform.dtype

        lfilter(
          waveform,
          Torch.tensor([a0, a1, a2], dtype: dtype, device: device),
          Torch.tensor([b0, b1, b2], dtype: dtype, device: device)
        )
      end

      # Two-pole high-pass biquad at +cutoff_freq+ Hz with quality factor +q+.
      def highpass_biquad(waveform, sample_rate, cutoff_freq, q: 0.707)
        w0 = 2 * Math::PI * cutoff_freq / sample_rate
        alpha = Math.sin(w0) / 2.0 / q

        b0 = (1 + Math.cos(w0)) / 2
        b1 = -1 - Math.cos(w0)
        b2 = b0
        a0 = 1 + alpha
        a1 = -2 * Math.cos(w0)
        a2 = 1 - alpha
        biquad(waveform, b0, b1, b2, a0, a1, a2)
      end

      # Two-pole low-pass biquad at +cutoff_freq+ Hz with quality factor +q+.
      def lowpass_biquad(waveform, sample_rate, cutoff_freq, q: 0.707)
        w0 = 2 * Math::PI * cutoff_freq / sample_rate
        alpha = Math.sin(w0) / 2 / q

        b0 = (1 - Math.cos(w0)) / 2
        b1 = 1 - Math.cos(w0)
        b2 = b0
        a0 = 1 + alpha
        a1 = -2 * Math.cos(w0)
        a2 = 1 - alpha
        biquad(waveform, b0, b1, b2, a0, a1, a2)
      end

      # Linear IIR filter with denominator +a_coeffs+ and numerator +b_coeffs+.
      # Both are normalised by a_coeffs[0].
      #
      # waveform - last dimension is time; leading dimensions are batched
      # clamp    - clamp the output to [-1, 1]
      #
      # Raises ArgumentError for mismatched coefficient lengths or devices.
      def lfilter(waveform, a_coeffs, b_coeffs, clamp: true)
        # pack batch
        shape = waveform.size
        waveform = waveform.reshape(-1, shape[-1])

        raise ArgumentError unless a_coeffs.size(0) == b_coeffs.size(0)
        raise ArgumentError unless waveform.size.length == 2
        raise ArgumentError unless waveform.device == a_coeffs.device
        raise ArgumentError unless b_coeffs.device == a_coeffs.device

        device = waveform.device
        dtype = waveform.dtype
        n_channel, n_sample = waveform.size
        n_order = a_coeffs.size(0)
        n_sample_padded = n_sample + n_order - 1
        raise ArgumentError unless n_order > 0

        # Pad the input and create output
        padded_waveform = Torch.zeros(n_channel, n_sample_padded, dtype: dtype, device: device)
        padded_waveform[0..-1, (n_order - 1)..-1] = waveform
        padded_output_waveform = Torch.zeros(n_channel, n_sample_padded, dtype: dtype, device: device)

        # Set up the coefficients matrix
        # Flip coefficients' order
        a_coeffs_flipped = a_coeffs.flip([0])
        b_coeffs_flipped = b_coeffs.flip([0])

        # calculate windowed_input_signal in parallel
        # create indices of original with shape (n_channel, n_order, n_sample)
        window_idxs = Torch.arange(n_sample, device: device).unsqueeze(0) + Torch.arange(n_order, device: device).unsqueeze(1)
        window_idxs = window_idxs.repeat([n_channel, 1, 1])
        window_idxs += (Torch.arange(n_channel, device: device).unsqueeze(-1).unsqueeze(-1) * n_sample_padded)
        window_idxs = window_idxs.long
        # (n_order, ) matmul (n_channel, n_order, n_sample) -> (n_channel, n_sample)
        input_signal_windows = Torch.matmul(b_coeffs_flipped, Torch.take(padded_waveform, window_idxs))

        input_signal_windows.div!(a_coeffs[0])
        a_coeffs_flipped.div!(a_coeffs[0])
        # Sequential recursion: each output sample depends on the previous n_order - 1 outputs
        input_signal_windows.t.each_with_index do |o0, i_sample|
          windowed_output_signal = padded_output_waveform[0..-1, i_sample...(i_sample + n_order)]
          o0.addmv!(windowed_output_signal, a_coeffs_flipped, alpha: -1)
          padded_output_waveform[0..-1, i_sample + n_order - 1] = o0
        end

        output = padded_output_waveform[0..-1, (n_order - 1)..-1]
        output = Torch.clamp(output, -1.0, 1.0) if clamp

        # unpack batch
        output.reshape(shape[0...-1] + output.shape[-1..-1])
      end

      # Converts an amplitude/power tensor to decibels:
      #   multiplier * log10(max(amp, amin)) - multiplier * db_multiplier
      # When +top_db+ is given, values are clamped to at least (max - top_db).
      def amplitude_to_DB(amp, multiplier, amin, db_multiplier, top_db: nil)
        db = Torch.log10(Torch.clamp(amp, min: amin)) * multiplier
        db -= multiplier * db_multiplier

        db = db.clamp(min: db.max.item - top_db) if top_db

        db
      end

      # Inverse of amplitude_to_DB: ref * (10 ** (db / 10)) ** power.
      def DB_to_amplitude(db, ref, power)
        Torch.pow(Torch.pow(10.0, db * 0.1), power) * ref
      end

      # DCT-II matrix of shape (n_mels, n_mfcc).
      #
      # norm - nil (unnormalised, scaled by 2) or :ortho / "ortho"
      #
      # Raises ArgumentError for any other +norm+.
      def create_dct(n_mfcc, n_mels, norm: nil)
        n = Torch.arange(n_mels.to_f)
        k = Torch.arange(n_mfcc.to_f).unsqueeze!(1)
        dct = Torch.cos((n + 0.5) * k * Math::PI / n_mels.to_f)

        if norm.nil?
          dct *= 2.0
        else
          # accept the string form too, matching the string options used elsewhere (e.g. "slaney")
          raise ArgumentError, "Invalid DCT norm value" unless norm.to_s == "ortho"

          dct[0] *= 1.0 / Math.sqrt(2.0)
          dct *= Math.sqrt(2.0 / n_mels)
        end

        dct.t
      end

      private

      # Random index in 0...upper via Torch.randint, or 0 when upper <= 0.
      # Torch.randint rejects an empty range, e.g. mono or length-1 input.
      def _random_index(upper)
        upper > 0 ? Torch.randint(upper, [1]).item.to_i : 0
      end

      # Scales +waveform+ to the 16-bit range, adds noise drawn by
      # +density_function+ ("RPDF", "GPDF", anything else = TPDF), rounds, and
      # scales back down.
      def _apply_probability_distribution(waveform, density_function: "TPDF")
        # pack batch
        shape = waveform.size
        waveform = waveform.reshape(-1, shape[-1])

        channel_size = waveform.size[0] - 1
        time_size = waveform.size[-1] - 1

        random_channel = _random_index(channel_size)
        random_time = _random_index(time_size)

        number_of_bits = 16
        up_scaling = 2 ** (number_of_bits - 1) - 2
        signal_scaled = waveform * up_scaling
        down_scaling = 2 ** (number_of_bits - 1)

        signal_scaled_dis =
          if density_function == "RPDF"
            rpdf = waveform[random_channel][random_time] - 0.5

            signal_scaled + rpdf
          elsif density_function == "GPDF"
            # TODO Replace by distribution code once
            # https://github.com/pytorch/pytorch/issues/29843 is resolved
            # gaussian = torch.distributions.normal.Normal(torch.mean(waveform, -1), 1).sample()

            num_rand_variables = 6

            gaussian = waveform[random_channel][random_time]
            num_rand_variables.times do
              rand_chan = _random_index(channel_size)
              gaussian += waveform[rand_chan][_random_index(time_size)]
            end

            signal_scaled + gaussian
          else
            # dtype needed for https://github.com/pytorch/pytorch/issues/32358
            # TODO add support for dtype and device to bartlett_window
            tpdf = Torch.bartlett_window(time_size + 1).to(signal_scaled.device, dtype: signal_scaled.dtype)
            tpdf = tpdf.repeat([channel_size + 1, 1])
            signal_scaled + tpdf
          end

        quantised_signal_scaled = Torch.round(signal_scaled_dis)
        quantised_signal = quantised_signal_scaled / down_scaling

        # unpack batch
        quantised_signal.reshape(shape[0...-1] + quantised_signal.shape[-1..-1])
      end
    end
  end

  F = Functional
end
@@ -0,0 +1,27 @@
1
module TorchAudio
  module Transforms
    # Converts an amplitude or power spectrogram to the decibel scale via F.amplitude_to_DB.
    class AmplitudeToDB < Torch::NN::Module
      # stype  - :power / "power" selects 10 * log10; any other value (e.g.
      #          :magnitude) selects 20 * log10
      # top_db - optional non-negative cut-off below the maximum dB value
      #
      # Raises ArgumentError when top_db is negative.
      def initialize(stype: :power, top_db: nil)
        super()

        @stype = stype

        raise ArgumentError, 'top_db must be a positive numerical' if top_db && top_db.negative?

        @top_db = top_db
        # Compare as a string so "power" (string options are used by other
        # transforms, e.g. Fade and Vol) is not silently treated as magnitude.
        @multiplier = stype.to_s == "power" ? 10.0 : 20.0
        @amin = 1e-10
        @ref_value = 1.0
        @db_multiplier = Math.log10([@amin, @ref_value].max)
      end

      def forward(amplitude_spectrogram)
        F.amplitude_to_DB(
          amplitude_spectrogram,
          @multiplier, @amin, @db_multiplier,
          top_db: @top_db
        )
      end
    end
  end
end
@@ -0,0 +1,15 @@
1
module TorchAudio
  module Transforms
    # Transform wrapper around F.compute_deltas: delta coefficients of a
    # spectrogram along its last dimension.
    class ComputeDeltas < Torch::NN::Module
      # win_length - delta window size (F.compute_deltas requires >= 3)
      # mode       - padding mode forwarded to Torch::NN::Functional.pad
      def initialize(win_length: 5, mode: "replicate")
        super()
        @win_length, @mode = win_length, mode
      end

      def forward(specgram)
        F.compute_deltas(specgram, mode: @mode, win_length: @win_length)
      end
    end
  end
end
@@ -0,0 +1,74 @@
1
module TorchAudio
  module Transforms
    # Multiplies a waveform by a fade-in envelope over its first fade_in_len
    # samples and a fade-out envelope over its last fade_out_len samples.
    #
    # fade_shape: "linear", "exponential", "logarithmic", "quarter_sine" or
    # "half_sine". Any other value leaves both envelopes as the raw
    # linspace(0, 1) ramp.
    class Fade < Torch::NN::Module
      def initialize(fade_in_len: 0, fade_out_len: 0, fade_shape: "linear")
        super()
        @fade_in_len = fade_in_len
        @fade_out_len = fade_out_len
        @fade_shape = fade_shape
      end

      # Envelopes are built on the CPU, then moved to waveform's device.
      def forward(waveform)
        length = waveform.size[-1]
        device = waveform.device
        envelope_in = fade_in(length).to(device)
        envelope_out = fade_out(length).to(device)
        envelope_in * envelope_out * waveform
      end

      private

      # Rising ramp over the first @fade_in_len samples followed by ones, clamped to [0, 1].
      def fade_in(waveform_length)
        ramp = Torch.linspace(0, 1, @fade_in_len)
        ones = Torch.ones(waveform_length - @fade_in_len)

        ramp =
          case @fade_shape
          when "exponential" then Torch.pow(2, (ramp - 1)) * ramp
          when "logarithmic" then Torch.log10(ramp + 0.1) + 1
          when "quarter_sine" then Torch.sin(ramp * Math::PI / 2)
          when "half_sine" then Torch.sin(ramp * Math::PI - Math::PI / 2) / 2 + 0.5
          else ramp # "linear" (and unrecognised shapes) keep the raw ramp
          end

        Torch.cat([ramp, ones]).clamp!(0, 1)
      end

      # Ones followed by a falling ramp over the last @fade_out_len samples, clamped to [0, 1].
      def fade_out(waveform_length)
        ramp = Torch.linspace(0, 1, @fade_out_len)
        ones = Torch.ones(waveform_length - @fade_out_len)

        ramp =
          case @fade_shape
          when "linear" then -ramp + 1
          when "exponential" then Torch.pow(2, -ramp) * (-ramp + 1)
          when "logarithmic" then Torch.log10(-ramp + 1.1) + 1
          when "quarter_sine" then Torch.sin(ramp * Math::PI / 2 + Math::PI / 2)
          when "half_sine" then Torch.sin(ramp * Math::PI + Math::PI / 2) / 2 + 0.5
          else ramp
          end

        Torch.cat([ones, ramp]).clamp!(0, 1)
      end
    end
  end
end
@@ -0,0 +1,39 @@
1
module TorchAudio
  module Transforms
    # Projects a linear-frequency spectrogram onto n_mels mel bands using the
    # triangular filterbank from F.create_fb_matrix.
    class MelScale < Torch::NN::Module
      # n_stft - frequency bins of the input. When nil, the filterbank is built
      #          lazily on the first forward call.
      # f_max  - defaults to sample_rate.div(2) as a Float
      #
      # Raises ArgumentError when f_min > f_max.
      def initialize(n_mels: 128, sample_rate: 16000, f_min: 0.0, f_max: nil, n_stft: nil)
        super()
        @n_mels = n_mels
        @sample_rate = sample_rate
        @f_max = f_max || sample_rate.div(2).to_f
        @f_min = f_min

        unless f_min <= @f_max
          raise ArgumentError, "Require f_min: %f < f_max: %f" % [f_min, @f_max]
        end

        filterbank =
          if n_stft.nil?
            Torch.empty(0)
          else
            F.create_fb_matrix(n_stft, @f_min, @f_max, @n_mels, @sample_rate)
          end
        register_buffer("fb", filterbank)
      end

      # specgram - (..., freq, time); returns (..., n_mels, time)
      def forward(specgram)
        original_shape = specgram.size
        batched = specgram.reshape(-1, original_shape[-2], original_shape[-1])

        if @fb.numel == 0
          # Now that the bin count is known, fill the registered buffer in place
          # (resize + copy) instead of reassigning it.
          lazy_fb = F.create_fb_matrix(batched.size(1), @f_min, @f_max, @n_mels, @sample_rate)
          @fb.resize!(lazy_fb.size)
          @fb.copy!(lazy_fb)
        end

        # (batch, freq, time) -> (batch, time, freq) x (freq, n_mels) -> (batch, n_mels, time)
        mel = Torch.matmul(batched.transpose(1, 2), @fb).transpose(1, 2)

        # unpack batch
        mel.reshape(original_shape[0...-2] + mel.shape[-2..-1])
      end
    end
  end
end
@@ -0,0 +1,37 @@
1
module TorchAudio
  module Transforms
    # Mel-frequency spectrogram of a waveform: Spectrogram followed by MelScale.
    class MelSpectrogram < Torch::NN::Module
      attr_reader :n_mels

      def initialize(
        sample_rate: 16000, n_fft: 400, win_length: nil, hop_length: nil, f_min: 0.0,
        f_max: nil, pad: 0, n_mels: 128, window_fn: Torch.method(:hann_window),
        power: 2.0, normalized: false, wkwargs: nil
      )

        super()
        @sample_rate = sample_rate
        @n_fft = n_fft
        @win_length = win_length || n_fft
        @hop_length = hop_length || @win_length.div(2)
        @pad = pad
        @power = power
        @normalized = normalized
        @n_mels = n_mels # number of mel frequency bins
        @f_max = f_max
        @f_min = f_min

        spectrogram_options = {
          n_fft: @n_fft, win_length: @win_length, hop_length: @hop_length, pad: @pad,
          window_fn: window_fn, power: @power, normalized: @normalized, wkwargs: wkwargs
        }
        @spectrogram = Spectrogram.new(**spectrogram_options)

        # the one-sided STFT yields n_fft.div(2) + 1 frequency bins
        @mel_scale = MelScale.new(
          n_mels: @n_mels, sample_rate: @sample_rate,
          f_min: @f_min, f_max: @f_max,
          n_stft: @n_fft.div(2) + 1
        )
      end

      def forward(waveform)
        @mel_scale.call(@spectrogram.call(waveform))
      end
    end
  end
end
@@ -0,0 +1,43 @@
1
module TorchAudio
  module Transforms
    # Mel-frequency cepstral coefficients of a waveform.
    #
    # Steps:
    # 1. Compute a MelSpectrogram.
    # 2. Convert it to dB, or to a natural log when log_mels is set.
    # 3. Project it with a DCT-II matrix from F.create_dct.
    class MFCC < Torch::NN::Module

      SUPPORTED_DCT_TYPES = [2]

      # melkwargs - extra keyword arguments for MelSpectrogram; nil is treated
      #             like an empty hash (as with melkwargs=None in the Python API)
      #
      # Raises ArgumentError for an unsupported dct_type or when n_mfcc exceeds
      # the number of mel bins.
      def initialize(sample_rate: 16000, n_mfcc: 40, dct_type: 2, norm: :ortho, log_mels: false, melkwargs: {})
        super()

        raise ArgumentError, "DCT type not supported: #{dct_type}" unless SUPPORTED_DCT_TYPES.include?(dct_type)

        @sample_rate = sample_rate
        @n_mfcc = n_mfcc
        @dct_type = dct_type
        @norm = norm
        @top_db = 80.0
        @amplitude_to_db = TorchAudio::Transforms::AmplitudeToDB.new(stype: :power, top_db: @top_db)

        # `**nil` raises TypeError, so fall back to an empty hash
        @melspectrogram = TorchAudio::Transforms::MelSpectrogram.new(sample_rate: @sample_rate, **(melkwargs || {}))

        raise ArgumentError, "Cannot select more MFCC coefficients than # mel bins" if @n_mfcc > @melspectrogram.n_mels

        dct_mat = F.create_dct(@n_mfcc, @melspectrogram.n_mels, norm: @norm)
        register_buffer('dct_mat', dct_mat)

        @log_mels = log_mels
      end

      def forward(waveform)
        mel_specgram = @melspectrogram.(waveform)
        mel_specgram =
          if @log_mels
            Torch.log(mel_specgram + 1e-6)
          else
            @amplitude_to_db.(mel_specgram)
          end

        # (..., n_mels, time) -> (..., time, n_mels) x (n_mels, n_mfcc) -> (..., n_mfcc, time)
        Torch
          .matmul(mel_specgram.transpose(-2, -1), @dct_mat)
          .transpose(-2, -1)
      end
    end
  end
end
@@ -0,0 +1,14 @@
1
module TorchAudio
  module Transforms
    # Transform wrapper around F.mu_law_decoding.
    class MuLawDecoding < Torch::NN::Module
      # quantization_channels - number of mu-law levels used when encoding
      def initialize(quantization_channels: 256)
        super()
        @quantization_channels = quantization_channels
      end

      # x_mu - mu-law codes; returns the decoded signal
      def forward(x_mu)
        channels = @quantization_channels
        F.mu_law_decoding(x_mu, channels)
      end
    end
  end
end
@@ -0,0 +1,14 @@
1
module TorchAudio
  module Transforms
    # Transform wrapper around F.mu_law_encoding.
    class MuLawEncoding < Torch::NN::Module
      # quantization_channels - number of mu-law levels to encode into
      def initialize(quantization_channels: 256)
        super()
        @quantization_channels = quantization_channels
      end

      # x - signal, expected within [-1, 1]; returns int64 mu-law codes
      def forward(x)
        channels = @quantization_channels
        F.mu_law_encoding(x, channels)
      end
    end
  end
end
@@ -0,0 +1,27 @@
1
module TorchAudio
  module Transforms
    # Spectrogram of a waveform, computed by F.spectrogram with a window
    # tensor built once at construction and registered as a buffer.
    class Spectrogram < Torch::NN::Module
      def initialize(
        n_fft: 400, win_length: nil, hop_length: nil, pad: 0,
        window_fn: Torch.method(:hann_window), power: 2.0, normalized: false, wkwargs: nil
      )

        super()
        # Number of FFT bins. The STFT result has n_fft.div(2) + 1 frequencies
        # because F.spectrogram calls Torch.stft with onesided: true.
        @n_fft = n_fft
        @win_length = win_length || n_fft
        @hop_length = hop_length || @win_length.div(2) # floor division

        window =
          if wkwargs.nil?
            window_fn.call(@win_length)
          else
            window_fn.call(@win_length, **wkwargs)
          end
        register_buffer("window", window)

        @pad = pad
        @power = power
        @normalized = normalized
      end

      def forward(waveform)
        F.spectrogram(waveform, @pad, @window, @n_fft, @hop_length, @win_length, @power, @normalized)
      end
    end
  end
end
@@ -0,0 +1,31 @@
1
module TorchAudio
  module Transforms
    # Changes the volume of a waveform and clamps the result to [-1, 1].
    #
    # gain_type:
    #   "amplitude" - multiply by +gain+ (must be >= 0)
    #   "power"     - +gain+ is a power ratio (must be >= 0), applied as 10 * log10(gain) dB
    #   "db"        - +gain+ is in decibels
    # Any other gain_type leaves the samples unscaled; they are still clamped.
    class Vol < Torch::NN::Module
      # Raises ArgumentError for a negative gain with "amplitude" or "power".
      def initialize(gain, gain_type: "amplitude")
        super()
        @gain = gain
        @gain_type = gain_type

        if ["amplitude", "power"].include?(gain_type) && gain < 0
          raise ArgumentError, "If gain_type = amplitude or power, gain must be positive."
        end
      end

      def forward(waveform)
        waveform =
          case @gain_type
          when "amplitude"
            waveform * @gain
          when "db"
            # F.gain only accepts the dB value through the gain_db: keyword.
            # to_f keeps an Integer dB value from being truncated by Integer
            # division in F.gain's gain_db / 20.
            F.gain(waveform, gain_db: @gain.to_f)
          when "power"
            F.gain(waveform, gain_db: 10 * Math.log10(@gain))
          else
            waveform
          end

        Torch.clamp(waveform, -1, 1)
      end
    end
  end
end
@@ -1,3 +1,3 @@
1
1
  module TorchAudio
2
- VERSION = "0.1.0"
2
+ VERSION = "0.2.1"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: torchaudio
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.1.0
4
+ version: 0.2.1
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
- autorequire:
8
+ autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2020-08-25 00:00:00.000000000 Z
11
+ date: 2021-07-16 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: torch-rb
@@ -16,86 +16,30 @@ dependencies:
16
16
  requirements:
17
17
  - - ">="
18
18
  - !ruby/object:Gem::Version
19
- version: 0.3.2
19
+ version: 0.7.0
20
20
  type: :runtime
21
21
  prerelease: false
22
22
  version_requirements: !ruby/object:Gem::Requirement
23
23
  requirements:
24
24
  - - ">="
25
25
  - !ruby/object:Gem::Version
26
- version: 0.3.2
26
+ version: 0.7.0
27
27
  - !ruby/object:Gem::Dependency
28
28
  name: rice
29
29
  requirement: !ruby/object:Gem::Requirement
30
30
  requirements:
31
31
  - - ">="
32
32
  - !ruby/object:Gem::Version
33
- version: '2.2'
33
+ version: 4.0.2
34
34
  type: :runtime
35
35
  prerelease: false
36
36
  version_requirements: !ruby/object:Gem::Requirement
37
37
  requirements:
38
38
  - - ">="
39
39
  - !ruby/object:Gem::Version
40
- version: '2.2'
41
- - !ruby/object:Gem::Dependency
42
- name: bundler
43
- requirement: !ruby/object:Gem::Requirement
44
- requirements:
45
- - - ">="
46
- - !ruby/object:Gem::Version
47
- version: '0'
48
- type: :development
49
- prerelease: false
50
- version_requirements: !ruby/object:Gem::Requirement
51
- requirements:
52
- - - ">="
53
- - !ruby/object:Gem::Version
54
- version: '0'
55
- - !ruby/object:Gem::Dependency
56
- name: rake
57
- requirement: !ruby/object:Gem::Requirement
58
- requirements:
59
- - - ">="
60
- - !ruby/object:Gem::Version
61
- version: '0'
62
- type: :development
63
- prerelease: false
64
- version_requirements: !ruby/object:Gem::Requirement
65
- requirements:
66
- - - ">="
67
- - !ruby/object:Gem::Version
68
- version: '0'
69
- - !ruby/object:Gem::Dependency
70
- name: rake-compiler
71
- requirement: !ruby/object:Gem::Requirement
72
- requirements:
73
- - - ">="
74
- - !ruby/object:Gem::Version
75
- version: '0'
76
- type: :development
77
- prerelease: false
78
- version_requirements: !ruby/object:Gem::Requirement
79
- requirements:
80
- - - ">="
81
- - !ruby/object:Gem::Version
82
- version: '0'
83
- - !ruby/object:Gem::Dependency
84
- name: minitest
85
- requirement: !ruby/object:Gem::Requirement
86
- requirements:
87
- - - ">="
88
- - !ruby/object:Gem::Version
89
- version: '5'
90
- type: :development
91
- prerelease: false
92
- version_requirements: !ruby/object:Gem::Requirement
93
- requirements:
94
- - - ">="
95
- - !ruby/object:Gem::Version
96
- version: '5'
97
- description:
98
- email: andrew@chartkick.com
40
+ version: 4.0.2
41
+ description:
42
+ email: andrew@ankane.org
99
43
  executables: []
100
44
  extensions:
101
45
  - ext/torchaudio/extconf.rb
@@ -118,12 +62,23 @@ files:
118
62
  - lib/torchaudio.rb
119
63
  - lib/torchaudio/datasets/utils.rb
120
64
  - lib/torchaudio/datasets/yesno.rb
65
+ - lib/torchaudio/functional.rb
66
+ - lib/torchaudio/transforms/amplitude_to_db.rb
67
+ - lib/torchaudio/transforms/compute_deltas.rb
68
+ - lib/torchaudio/transforms/fade.rb
69
+ - lib/torchaudio/transforms/mel_scale.rb
70
+ - lib/torchaudio/transforms/mel_spectrogram.rb
71
+ - lib/torchaudio/transforms/mfcc.rb
72
+ - lib/torchaudio/transforms/mu_law_decoding.rb
73
+ - lib/torchaudio/transforms/mu_law_encoding.rb
74
+ - lib/torchaudio/transforms/spectrogram.rb
75
+ - lib/torchaudio/transforms/vol.rb
121
76
  - lib/torchaudio/version.rb
122
77
  homepage: https://github.com/ankane/torchaudio
123
78
  licenses:
124
79
  - BSD-2-Clause
125
80
  metadata: {}
126
- post_install_message:
81
+ post_install_message:
127
82
  rdoc_options: []
128
83
  require_paths:
129
84
  - lib
@@ -131,15 +86,15 @@ required_ruby_version: !ruby/object:Gem::Requirement
131
86
  requirements:
132
87
  - - ">="
133
88
  - !ruby/object:Gem::Version
134
- version: '2.5'
89
+ version: '2.6'
135
90
  required_rubygems_version: !ruby/object:Gem::Requirement
136
91
  requirements:
137
92
  - - ">="
138
93
  - !ruby/object:Gem::Version
139
94
  version: '0'
140
95
  requirements: []
141
- rubygems_version: 3.1.2
142
- signing_key:
96
+ rubygems_version: 3.2.22
97
+ signing_key:
143
98
  specification_version: 4
144
99
  summary: Data manipulation and transformation for audio signal processing
145
100
  test_files: []