torchaudio 0.1.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +7 -0
- data/CHANGELOG.md +3 -0
- data/LICENSE.txt +26 -0
- data/README.md +93 -0
- data/ext/torchaudio/csrc/register.cpp +65 -0
- data/ext/torchaudio/csrc/sox.cpp +361 -0
- data/ext/torchaudio/csrc/sox.h +71 -0
- data/ext/torchaudio/csrc/sox_effects.cpp +54 -0
- data/ext/torchaudio/csrc/sox_effects.h +18 -0
- data/ext/torchaudio/csrc/sox_io.cpp +170 -0
- data/ext/torchaudio/csrc/sox_io.h +41 -0
- data/ext/torchaudio/csrc/sox_utils.cpp +245 -0
- data/ext/torchaudio/csrc/sox_utils.h +100 -0
- data/ext/torchaudio/ext.cpp +33 -0
- data/ext/torchaudio/extconf.rb +81 -0
- data/lib/torchaudio.rb +95 -0
- data/lib/torchaudio/datasets/utils.rb +92 -0
- data/lib/torchaudio/datasets/yesno.rb +59 -0
- data/lib/torchaudio/version.rb +3 -0
- metadata +145 -0
@@ -0,0 +1,100 @@
|
|
1
|
+
#ifndef TORCHAUDIO_SOX_UTILS_H
#define TORCHAUDIO_SOX_UTILS_H

#include <sox.h>
#include <torch/script.h>

// std::string is used in several declarations below; include it directly
// instead of relying on a transitive include from torch/script.h.
#include <string>

namespace torchaudio {
namespace sox_utils {

/// Custom-class holder bundling a Tensor with its sample rate and a layout
/// flag.
struct TensorSignal : torch::CustomClassHolder {
  torch::Tensor tensor;
  int64_t sample_rate;
  /// Layout flag; presumably the same meaning as `channels_first` in
  /// convert_to_tensor ([num_channels, num_frames] when true) — confirm.
  bool channels_first;

  TensorSignal(
      torch::Tensor tensor_,
      int64_t sample_rate_,
      bool channels_first_);

  torch::Tensor getTensor() const;
  int64_t getSampleRate() const;
  bool getChannelsFirst() const;
};

/// Helper class to automatically close sox_format_t*.
/// Copy and move are deleted, so a single SoxFormat object is responsible for
/// the wrapped handle.
struct SoxFormat {
  explicit SoxFormat(sox_format_t* fd) noexcept;
  SoxFormat(const SoxFormat& other) = delete;
  SoxFormat(SoxFormat&& other) = delete;
  SoxFormat& operator=(const SoxFormat& other) = delete;
  SoxFormat& operator=(SoxFormat&& other) = delete;
  ~SoxFormat();
  /// Member access on the wrapped handle.
  sox_format_t* operator->() const noexcept;
  /// Implicit conversion so a SoxFormat can be passed where a raw
  /// sox_format_t* is expected.
  operator sox_format_t*() const noexcept;

 private:
  sox_format_t* fd_;
};

///
/// Verify that input file is found, has known encoding, and not empty
void validate_input_file(const SoxFormat& sf);

///
/// Verify that input Tensor is 2D, CPU and either uint8, int16, int32 or float32
void validate_input_tensor(const torch::Tensor);

///
/// Get target dtype for the given encoding and precision.
caffe2::TypeMeta get_dtype(
    const sox_encoding_t encoding,
    const unsigned precision);

///
/// Convert sox_sample_t buffer to uint8/int16/int32/float32 Tensor
/// NOTE: This function might modify the values in the input buffer to
/// reduce the number of memory copy.
/// @param buffer Pointer to buffer that contains audio data.
/// @param num_samples The number of samples to read.
/// @param num_channels The number of channels. Used to reshape the resulting
/// Tensor.
/// @param dtype Target dtype. Determines the output dtype and value range in
/// conjunction with normalization.
/// @param normalize Perform normalization. Only effective when dtype is not
/// kFloat32. When effective, the output tensor is kFloat32 type and value range
/// is [-1.0, 1.0]
/// @param channels_first When True, output Tensor has shape of [num_channels,
/// num_frames].
torch::Tensor convert_to_tensor(
    sox_sample_t* buffer,
    const int32_t num_samples,
    const int32_t num_channels,
    const caffe2::TypeMeta dtype,
    const bool normalize,
    const bool channels_first);

///
/// Convert float32/int32/int16/uint8 Tensor to int32 for Torch -> Sox
/// conversion.
torch::Tensor unnormalize_wav(const torch::Tensor);

/// Extract extension from file path
const std::string get_filetype(const std::string path);

/// Get sox_signalinfo_t for passing a torch::Tensor object.
sox_signalinfo_t get_signalinfo(
    const torch::Tensor& tensor,
    const int64_t sample_rate,
    const bool channels_first,
    const std::string filetype);

/// Get sox_encodinginfo_t for saving audio file
sox_encodinginfo_t get_encodinginfo(
    const std::string filetype,
    const caffe2::TypeMeta dtype,
    const double compression);

} // namespace sox_utils
} // namespace torchaudio
#endif
|
@@ -0,0 +1,33 @@
|
|
1
|
+
#include <torchaudio/csrc/sox.h>
|
2
|
+
|
3
|
+
#include <rice/Module.hpp>
|
4
|
+
|
5
|
+
using namespace Rice;
|
6
|
+
|
7
|
+
template<>
|
8
|
+
inline
|
9
|
+
sox_signalinfo_t* from_ruby<sox_signalinfo_t*>(Object x)
|
10
|
+
{
|
11
|
+
if (x.is_nil()) {
|
12
|
+
return nullptr;
|
13
|
+
}
|
14
|
+
throw std::runtime_error("Unsupported signalinfo");
|
15
|
+
}
|
16
|
+
|
17
|
+
template<>
|
18
|
+
inline
|
19
|
+
sox_encodinginfo_t* from_ruby<sox_encodinginfo_t*>(Object x)
|
20
|
+
{
|
21
|
+
if (x.is_nil()) {
|
22
|
+
return nullptr;
|
23
|
+
}
|
24
|
+
throw std::runtime_error("Unsupported encodinginfo");
|
25
|
+
}
|
26
|
+
|
27
|
+
extern "C"
|
28
|
+
void Init_ext()
|
29
|
+
{
|
30
|
+
Module rb_mTorchAudio = define_module("TorchAudio");
|
31
|
+
Module rb_mNN = define_module_under(rb_mTorchAudio, "Ext")
|
32
|
+
.define_singleton_method("read_audio_file", &torch::audio::read_audio_file);
|
33
|
+
}
|
@@ -0,0 +1,81 @@
|
|
1
|
+
# Build configuration for the torchaudio native extension (mkmf + Rice).
require "mkmf-rice"

# the extension is C++, so the C++ standard library must be linkable
abort "Missing stdc++" unless have_library("stdc++")

$CXXFLAGS += " -std=c++14"

# the csrc/ sources wrap libsox, so the build cannot proceed without it
abort "SoX not found" unless have_library("sox")

# this directory (ext.cpp) and its csrc/ subdirectory (the SoX bindings)
ext = File.expand_path(".", __dir__)
csrc = File.expand_path("csrc", __dir__)

# compile every .cpp file found in either directory
$srcs = Dir["{#{ext},#{csrc}}/*.cpp"]
# put the parent directory on the include path so sources can write
# `#include <torchaudio/csrc/...>`
$INCFLAGS << " -I#{File.expand_path("..", __dir__)}"
# let make locate the csrc/ sources
$VPATH << csrc
|
15
|
+
|
16
|
+
#
# keep rest synced with Torch
#
# NOTE(review): "Torch" presumably means the torch-rb gem's extconf.rb; when
# updating, diff against the torch-rb version this gem depends on.

# change to 0 for Linux pre-cxx11 ABI version
# (presumably must match the ABI of the LibTorch build being linked)
$CXXFLAGS += " -D_GLIBCXX_USE_CXX11_ABI=1"

# detected from the compiler's version banner; selects clang-specific flags below
apple_clang = RbConfig::CONFIG["CC_VERSION_MESSAGE"] =~ /apple clang/i

# check omp first
# OpenMP flags are only added when an OpenMP runtime library can be linked
if have_library("omp") || have_library("gomp")
  $CXXFLAGS += " -DAT_PARALLEL_OPENMP=1"
  # -Xclang forwards the following -fopenmp to the clang frontend
  $CXXFLAGS += " -Xclang" if apple_clang
  $CXXFLAGS += " -fopenmp"
end

if apple_clang
  # silence ruby/intern.h warning
  $CXXFLAGS += " -Wno-deprecated-register"

  # silence torch warnings
  $CXXFLAGS += " -Wno-shorten-64-to-32 -Wno-missing-noreturn"
else
  # silence rice warnings
  $CXXFLAGS += " -Wno-noexcept-type"

  # silence torch warnings
  $CXXFLAGS += " -Wno-duplicated-cond -Wno-suggest-attribute=noreturn"
end

# LibTorch location: overridable through mkmf's dir_config
# (--with-torch-dir / --with-torch-include / --with-torch-lib), else /usr/local
inc, lib = dir_config("torch")
inc ||= "/usr/local/include"
lib ||= "/usr/local/lib"

# CUDA location, same mechanism (--with-cuda-*), else /usr/local/cuda
cuda_inc, cuda_lib = dir_config("cuda")
cuda_inc ||= "/usr/local/cuda/include"
cuda_lib ||= "/usr/local/cuda/lib64"

$LDFLAGS += " -L#{lib}" if Dir.exist?(lib)
abort "LibTorch not found" unless have_library("torch")

# optional libraries: linked when found, the result is deliberately ignored
have_library("mkldnn")
have_library("nnpack")

# CUDA is used only when the LibTorch lib directory contains torch_cuda
# libraries AND both libcuda and libcudnn can be linked
with_cuda = false
if Dir["#{lib}/*torch_cuda*"].any?
  $LDFLAGS += " -L#{cuda_lib}" if Dir.exist?(cuda_lib)
  with_cuda = have_library("cuda") && have_library("cudnn")
end

$INCFLAGS += " -I#{inc}"
$INCFLAGS += " -I#{inc}/torch/csrc/api/include"

# the CUDA directories are appended to the same colon-separated -rpath value,
# so these two lines must stay adjacent and in this order
$LDFLAGS += " -Wl,-rpath,#{lib}"
$LDFLAGS += ":#{cuda_lib}/stubs:#{cuda_lib}" if with_cuda

# https://github.com/pytorch/pytorch/blob/v1.5.0/torch/utils/cpp_extension.py#L1232-L1238
$LDFLAGS += " -lc10 -ltorch_cpu -ltorch"
if with_cuda
  $LDFLAGS += " -lcuda -lnvrtc -lnvToolsExt -lcudart -lc10_cuda -ltorch_cuda -lcufft -lcurand -lcublas -lcudnn"
  # TODO figure out why this is needed
  $LDFLAGS += " -Wl,--no-as-needed,#{lib}/libtorch.so"
end

# create makefile
# (target "torchaudio/ext" is what `require "torchaudio/ext"` loads)
create_makefile("torchaudio/ext")
|
data/lib/torchaudio.rb
ADDED
@@ -0,0 +1,95 @@
|
|
1
|
+
# dependencies
|
2
|
+
require "torch"
|
3
|
+
|
4
|
+
# ext
|
5
|
+
require "torchaudio/ext"
|
6
|
+
|
7
|
+
# stdlib
|
8
|
+
require "csv"
|
9
|
+
require "digest"
|
10
|
+
require "fileutils"
|
11
|
+
require "rubygems/package"
|
12
|
+
require "set"
|
13
|
+
|
14
|
+
# modules
|
15
|
+
require "torchaudio/datasets/utils"
|
16
|
+
require "torchaudio/datasets/yesno"
|
17
|
+
require "torchaudio/version"
|
18
|
+
|
19
|
+
module TorchAudio
  # Base class for errors raised by this library.
  class Error < StandardError; end

  class << self
    # Reads an audio file into a tensor.
    #
    # filepath       - path of the file to read (converted with #to_s)
    # out            - optional CPU tensor to read into; when nil a new
    #                  Torch::FloatTensor is allocated
    # normalization  - true to divide the samples by 2**31, a Numeric divisor,
    #                  an object responding to #call that returns the divisor
    #                  for the signal, or false/nil to leave samples unscaled
    # channels_first - forwarded to the native reader
    # num_frames     - -1 (entire file) or >= 0; forwarded to the native reader
    # offset         - must be >= 0; forwarded to the native reader
    # signalinfo     - forwarded to the native reader
    # encodinginfo   - forwarded to the native reader
    # filetype       - defaults to the file extension without its leading dot
    #
    # Returns [tensor, sample_rate].
    # Raises ArgumentError when filepath is not a regular file, when +out+ is
    # not a CPU tensor, or when num_frames/offset are out of range.
    def load(
      filepath, out: nil, normalization: true, channels_first: true, num_frames: 0,
      offset: 0, signalinfo: nil, encodinginfo: nil, filetype: nil
    )

      filepath = filepath.to_s

      # check if valid file; File.exist? would also accept directories,
      # which the native reader cannot load
      unless File.file?(filepath)
        raise ArgumentError, "#{filepath} not found or is a directory"
      end

      # initialize output tensor
      if !out.nil?
        check_input(out)
      else
        out = Torch::FloatTensor.new
      end

      if num_frames < -1
        raise ArgumentError, "Expected value for num_frames -1 (entire file) or >=0"
      end
      if offset < 0
        raise ArgumentError, "Expected non-negative offset value"
      end

      # same logic as C++
      # could also make read_audio_file work with nil
      filetype ||= File.extname(filepath)[1..-1]

      sample_rate =
        Ext.read_audio_file(
          filepath,
          out,
          channels_first,
          num_frames,
          offset,
          signalinfo,
          encodinginfo,
          filetype
        )

      # normalize if needed
      normalize_audio(out, normalization)

      [out, sample_rate]
    end

    # Same as #load, but always divides the samples by 2**16.
    # Any :normalization passed in kwargs is overridden.
    def load_wav(filepath, **kwargs)
      kwargs[:normalization] = 1 << 16
      load(filepath, **kwargs)
    end

    private

    # Raises ArgumentError unless +src+ is a tensor that is not on CUDA.
    def check_input(src)
      raise ArgumentError, "Expected a tensor, got #{src.class.name}" unless Torch.tensor?(src)
      raise ArgumentError, "Expected a CPU based tensor, got #{src.class.name}" if src.cuda?
    end

    # Divides +signal+ in place (div!) according to +normalization+; see #load.
    # Values that are neither Numeric nor callable are silently ignored.
    def normalize_audio(signal, normalization)
      return unless normalization

      normalization = 1 << 31 if normalization == true

      if normalization.is_a?(Numeric)
        signal.div!(normalization)
      elsif normalization.respond_to?(:call)
        signal.div!(normalization.call(signal))
      end
    end
  end
end
|
@@ -0,0 +1,92 @@
|
|
1
|
+
# stdlib used below (Net::HTTP/URI for downloads, Dir.tmpdir, digests,
# FileUtils.mv, Gem::Package for tar.gz extraction)
require "digest"
require "fileutils"
require "net/http"
require "rubygems/package"
require "tmpdir"
require "uri"

module TorchAudio
  module Datasets
    module Utils
      class << self
        # Downloads +url+ into +download_folder+ (file name defaults to the URL's
        # basename) and returns the destination path.
        # When +hash_value+ is given, the download is verified with +hash_type+
        # ("sha256" or "md5").
        # Raises if the destination file already exists.
        def download_url(url, download_folder, filename: nil, hash_value: nil, hash_type: "sha256")
          filename ||= File.basename(url)
          filepath = File.join(download_folder, filename)

          if File.exist?(filepath)
            raise "#{filepath} already exists. Delete the file manually and retry."
          end

          puts "Downloading #{url}..."
          download_url_to_file(url, filepath, hash_value, hash_type)
        end

        # follows redirects
        # Streams the response body to a temp file, verifies it when +hash_value+
        # is given, then moves it to +dst+. Returns +dst+.
        def download_url_to_file(url, dst, hash_value, hash_type)
          uri = URI(url)
          tmp = nil
          location = nil

          Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http|
            request = Net::HTTP::Get.new(uri)

            http.request(request) do |response|
              case response
              when Net::HTTPRedirection
                location = response["location"]
              when Net::HTTPSuccess
                tmp = "#{Dir.tmpdir}/#{Time.now.to_f}" # TODO better name
                File.open(tmp, "wb") do |f|
                  response.read_body do |chunk|
                    f.write(chunk)
                  end
                end
              else
                raise Error, "Bad response"
              end
            end
          end

          if location
            # Location may be relative, so resolve it against the current URL.
            # Forward the hash arguments so the final file is still verified.
            download_url_to_file(URI.join(url, location).to_s, dst, hash_value, hash_type)
          else
            # check hash (skipped when no expected value was supplied)
            if hash_value && file_hexdigest(tmp, hash_type) != hash_value
              raise "The hash of #{dst} does not match. Delete the file manually and retry."
            end

            FileUtils.mv(tmp, dst)
            dst
          end
        end

        # extract_tar_gz doesn't list files, so just return to_path
        def extract_archive(from_path, to_path: nil, overwrite: nil)
          to_path ||= File.dirname(from_path)

          if from_path.end_with?(".tar.gz") || from_path.end_with?(".tgz")
            File.open(from_path, "rb") do |io|
              Gem::Package.new("").extract_tar_gz(io, to_path)
            end
            return to_path
          end

          raise "We currently only support tar.gz and tgz archives."
        end

        # Yields (or enumerates) regular files under +root+ whose relative path
        # ends with +suffix+, in sorted order.
        # Paths are relative to +root+, or joined onto +root+ when +prefix+ is
        # true. The suffix is stripped first when +remove_suffix+ is true.
        def walk_files(root, suffix, prefix: false, remove_suffix: false)
          return enum_for(:walk_files, root, suffix, prefix: prefix, remove_suffix: remove_suffix) unless block_given?

          Dir.glob("**/*", base: root).sort.each do |f|
            next unless f.end_with?(suffix)
            # glob also returns directories; only files are wanted here
            next unless File.file?(File.join(root, f))

            f = f[0..(-suffix.length - 1)] if remove_suffix
            f = File.join(root, f) if prefix

            yield f
          end
        end

        private

        # Hex digest of the file at +path+ using the algorithm named by +hash_type+.
        def file_hexdigest(path, hash_type)
          digest =
            case hash_type.to_s.downcase
            when "md5" then Digest::MD5
            when "sha256" then Digest::SHA256
            else raise ArgumentError, "Unsupported hash_type: #{hash_type}"
            end
          digest.file(path).hexdigest
        end
      end
    end
  end
end
|
@@ -0,0 +1,59 @@
|
|
1
|
+
module TorchAudio
  module Datasets
    # Dataset over the .wav files of the OpenSLR "waves_yesno" archive.
    class YESNO < Torch::Utils::Data::Dataset
      URL = "http://www.openslr.org/resources/1/waves_yesno.tar.gz"
      FOLDER_IN_ARCHIVE = "waves_yesno"
      CHECKSUMS = {
        "http://www.openslr.org/resources/1/waves_yesno.tar.gz" => "962ff6e904d2df1126132ecec6978786"
      }

      # root              - directory holding the archive / extracted folder
      # url               - archive URL; must be a key of CHECKSUMS when downloading
      # folder_in_archive - name of the extracted folder under +root+
      # download          - download and extract when the folder is missing
      #
      # Raises if the extracted folder still does not exist.
      def initialize(root, url: URL, folder_in_archive: FOLDER_IN_ARCHIVE, download: false)
        archive = File.basename(url)
        archive = File.join(root, archive)
        @path = File.join(root, folder_in_archive)

        if download
          unless Dir.exist?(@path)
            unless File.exist?(archive)
              checksum = CHECKSUMS.fetch(url)
              Utils.download_url(url, root, hash_value: checksum, hash_type: "md5")
            end
            Utils.extract_archive(archive)
          end
        end

        unless Dir.exist?(@path)
          raise "Dataset not found. Please use `download: true` to download it."
        end

        # file ids (paths relative to @path, without the .wav extension)
        walker = Utils.walk_files(@path, ext_audio, prefix: false, remove_suffix: true)
        @walker = walker.to_a
      end

      # Returns [waveform, sample_rate, labels] for the n-th file.
      # Negative n counts from the end. An out-of-range n raises IndexError,
      # rather than failing later on a nil file id.
      def [](n)
        fileid = @walker.fetch(n)
        load_yesno_item(fileid, @path, ext_audio)
      end

      def length
        @walker.length
      end
      alias_method :size, :length

      private

      # Labels are the integers parsed from the underscore-separated file id.
      def load_yesno_item(fileid, path, ext_audio)
        labels = fileid.split("_").map(&:to_i)

        file_audio = File.join(path, fileid + ext_audio)
        waveform, sample_rate = TorchAudio.load(file_audio)

        [waveform, sample_rate, labels]
      end

      def ext_audio
        ".wav"
      end
    end
  end
end
|