torchaudio 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/CHANGELOG.md +3 -0
- data/LICENSE.txt +26 -0
- data/README.md +93 -0
- data/ext/torchaudio/csrc/register.cpp +65 -0
- data/ext/torchaudio/csrc/sox.cpp +361 -0
- data/ext/torchaudio/csrc/sox.h +71 -0
- data/ext/torchaudio/csrc/sox_effects.cpp +54 -0
- data/ext/torchaudio/csrc/sox_effects.h +18 -0
- data/ext/torchaudio/csrc/sox_io.cpp +170 -0
- data/ext/torchaudio/csrc/sox_io.h +41 -0
- data/ext/torchaudio/csrc/sox_utils.cpp +245 -0
- data/ext/torchaudio/csrc/sox_utils.h +100 -0
- data/ext/torchaudio/ext.cpp +33 -0
- data/ext/torchaudio/extconf.rb +81 -0
- data/lib/torchaudio.rb +95 -0
- data/lib/torchaudio/datasets/utils.rb +92 -0
- data/lib/torchaudio/datasets/yesno.rb +59 -0
- data/lib/torchaudio/version.rb +3 -0
- metadata +145 -0
@@ -0,0 +1,100 @@
|
|
1
|
+
#ifndef TORCHAUDIO_SOX_UTILS_H
|
2
|
+
#define TORCHAUDIO_SOX_UTILS_H
|
3
|
+
|
4
|
+
#include <sox.h>
|
5
|
+
#include <torch/script.h>
|
6
|
+
|
7
|
+
namespace torchaudio {
|
8
|
+
namespace sox_utils {
|
9
|
+
|
10
|
+
/// Value bundle pairing a tensor with a sample rate and a channel-layout flag.
/// Only declarations live here; the constructor and accessors are defined
/// outside this header.
/// NOTE(review): deriving from torch::CustomClassHolder presumably allows
/// registration as a TorchScript custom class — confirm in register.cpp.
struct TensorSignal : torch::CustomClassHolder {
|
11
|
+
/// Tensor payload (presumably the waveform — confirm against callers).
torch::Tensor tensor;
|
12
|
+
/// Sample rate associated with `tensor` (unit not stated here; presumably Hz).
int64_t sample_rate;
|
13
|
+
/// Layout flag. NOTE(review): presumably true means [channels, frames] —
/// confirm against convert_to_tensor's `channels_first` parameter.
bool channels_first;
|
14
|
+
|
15
|
+
/// Constructor taking one argument per data member, in member order.
TensorSignal(
|
16
|
+
torch::Tensor tensor_,
|
17
|
+
int64_t sample_rate_,
|
18
|
+
bool channels_first_);
|
19
|
+
|
20
|
+
/// Const accessors; presumably return the corresponding members (definitions
/// are not visible in this header).
torch::Tensor getTensor() const;
|
21
|
+
int64_t getSampleRate() const;
|
22
|
+
bool getChannelsFirst() const;
|
23
|
+
};
|
24
|
+
|
25
|
+
/// helper class to automatically close sox_format_t*
|
26
|
+
/// Wrapper around a raw `sox_format_t*` held in `fd_`.
/// Copy and move construction/assignment are all deleted, so an instance can
/// be neither copied nor moved.
/// NOTE(review): the destructor is only declared here; it presumably closes
/// `fd_` — confirm in sox_utils.cpp.
struct SoxFormat {
|
27
|
+
/// Wraps `fd`; explicit so a raw pointer never converts implicitly.
explicit SoxFormat(sox_format_t* fd) noexcept;
|
28
|
+
SoxFormat(const SoxFormat& other) = delete;
|
29
|
+
SoxFormat(SoxFormat&& other) = delete;
|
30
|
+
SoxFormat& operator=(const SoxFormat& other) = delete;
|
31
|
+
SoxFormat& operator=(SoxFormat&& other) = delete;
|
32
|
+
~SoxFormat();
|
33
|
+
/// Member access through the wrapper, e.g. `sf->field`.
sox_format_t* operator->() const noexcept;
|
34
|
+
/// Implicit conversion so the wrapper can be passed where a raw
/// `sox_format_t*` is expected.
operator sox_format_t*() const noexcept;
|
35
|
+
|
36
|
+
private:
|
37
|
+
/// The wrapped handle.
sox_format_t* fd_;
|
38
|
+
};
|
39
|
+
|
40
|
+
///
|
41
|
+
/// Verify that input file is found, has known encoding, and is not empty
|
42
|
+
void validate_input_file(const SoxFormat& sf);
|
43
|
+
|
44
|
+
///
|
45
|
+
/// Verify that input Tensor is 2D, CPU and either uint8, int16, int32 or float32
|
46
|
+
void validate_input_tensor(const torch::Tensor);
|
47
|
+
|
48
|
+
///
|
49
|
+
/// Get target dtype for the given encoding and precision.
|
50
|
+
caffe2::TypeMeta get_dtype(
|
51
|
+
const sox_encoding_t encoding,
|
52
|
+
const unsigned precision);
|
53
|
+
|
54
|
+
///
|
55
|
+
/// Convert sox_sample_t buffer to uint8/int16/int32/float32 Tensor
|
56
|
+
/// NOTE: This function might modify the values in the input buffer to
|
57
|
+
/// reduce the number of memory copies.
|
58
|
+
/// @param buffer Pointer to buffer that contains audio data.
|
59
|
+
/// @param num_samples The number of samples to read.
|
60
|
+
/// @param num_channels The number of channels. Used to reshape the resulting
|
61
|
+
/// Tensor.
|
62
|
+
/// @param dtype Target dtype. Determines the output dtype and value range in
|
63
|
+
/// conjunction with normalization.
|
64
|
+
/// @param normalize Perform normalization. Only effective when dtype is not
|
65
|
+
/// kFloat32. When effective, the output tensor is kFloat32 type and value range
|
66
|
+
/// is [-1.0, 1.0]
|
67
|
+
/// @param channels_first When True, output Tensor has shape of [num_channels,
|
68
|
+
/// num_frames].
|
69
|
+
torch::Tensor convert_to_tensor(
|
70
|
+
sox_sample_t* buffer,
|
71
|
+
const int32_t num_samples,
|
72
|
+
const int32_t num_channels,
|
73
|
+
const caffe2::TypeMeta dtype,
|
74
|
+
const bool normalize,
|
75
|
+
const bool channels_first);
|
76
|
+
|
77
|
+
///
|
78
|
+
/// Convert float32/int32/int16/uint8 Tensor to int32 for Torch -> Sox
|
79
|
+
/// conversion.
|
80
|
+
torch::Tensor unnormalize_wav(const torch::Tensor);
|
81
|
+
|
82
|
+
/// Extract extension from file path
|
83
|
+
const std::string get_filetype(const std::string path);
|
84
|
+
|
85
|
+
/// Get sox_signalinfo_t for passing a torch::Tensor object.
|
86
|
+
sox_signalinfo_t get_signalinfo(
|
87
|
+
const torch::Tensor& tensor,
|
88
|
+
const int64_t sample_rate,
|
89
|
+
const bool channels_first,
|
90
|
+
const std::string filetype);
|
91
|
+
|
92
|
+
/// Get sox_encodinginfo_t for saving audio file
|
93
|
+
sox_encodinginfo_t get_encodinginfo(
|
94
|
+
const std::string filetype,
|
95
|
+
const caffe2::TypeMeta dtype,
|
96
|
+
const double compression);
|
97
|
+
|
98
|
+
} // namespace sox_utils
|
99
|
+
} // namespace torchaudio
|
100
|
+
#endif
|
@@ -0,0 +1,33 @@
|
|
1
|
+
#include <torchaudio/csrc/sox.h>
|
2
|
+
|
3
|
+
#include <rice/Module.hpp>
|
4
|
+
|
5
|
+
using namespace Rice;
|
6
|
+
|
7
|
+
template<>
|
8
|
+
inline
|
9
|
+
sox_signalinfo_t* from_ruby<sox_signalinfo_t*>(Object x)
|
10
|
+
{
|
11
|
+
if (x.is_nil()) {
|
12
|
+
return nullptr;
|
13
|
+
}
|
14
|
+
throw std::runtime_error("Unsupported signalinfo");
|
15
|
+
}
|
16
|
+
|
17
|
+
template<>
|
18
|
+
inline
|
19
|
+
sox_encodinginfo_t* from_ruby<sox_encodinginfo_t*>(Object x)
|
20
|
+
{
|
21
|
+
if (x.is_nil()) {
|
22
|
+
return nullptr;
|
23
|
+
}
|
24
|
+
throw std::runtime_error("Unsupported encodinginfo");
|
25
|
+
}
|
26
|
+
|
27
|
+
extern "C"
|
28
|
+
void Init_ext()
|
29
|
+
{
|
30
|
+
Module rb_mTorchAudio = define_module("TorchAudio");
|
31
|
+
Module rb_mNN = define_module_under(rb_mTorchAudio, "Ext")
|
32
|
+
.define_singleton_method("read_audio_file", &torch::audio::read_audio_file);
|
33
|
+
}
|
@@ -0,0 +1,81 @@
|
|
1
|
+
require "mkmf-rice"
|
2
|
+
|
3
|
+
abort "Missing stdc++" unless have_library("stdc++")
|
4
|
+
|
5
|
+
$CXXFLAGS += " -std=c++14"
|
6
|
+
|
7
|
+
abort "SoX not found" unless have_library("sox")
|
8
|
+
|
9
|
+
ext = File.expand_path(".", __dir__)
|
10
|
+
csrc = File.expand_path("csrc", __dir__)
|
11
|
+
|
12
|
+
$srcs = Dir["{#{ext},#{csrc}}/*.cpp"]
|
13
|
+
$INCFLAGS << " -I#{File.expand_path("..", __dir__)}"
|
14
|
+
$VPATH << csrc
|
15
|
+
|
16
|
+
#
|
17
|
+
# keep rest synced with Torch
|
18
|
+
#
|
19
|
+
|
20
|
+
# change to 0 for Linux pre-cxx11 ABI version
|
21
|
+
$CXXFLAGS += " -D_GLIBCXX_USE_CXX11_ABI=1"
|
22
|
+
|
23
|
+
apple_clang = RbConfig::CONFIG["CC_VERSION_MESSAGE"] =~ /apple clang/i
|
24
|
+
|
25
|
+
# check omp first
|
26
|
+
if have_library("omp") || have_library("gomp")
|
27
|
+
$CXXFLAGS += " -DAT_PARALLEL_OPENMP=1"
|
28
|
+
$CXXFLAGS += " -Xclang" if apple_clang
|
29
|
+
$CXXFLAGS += " -fopenmp"
|
30
|
+
end
|
31
|
+
|
32
|
+
if apple_clang
|
33
|
+
# silence ruby/intern.h warning
|
34
|
+
$CXXFLAGS += " -Wno-deprecated-register"
|
35
|
+
|
36
|
+
# silence torch warnings
|
37
|
+
$CXXFLAGS += " -Wno-shorten-64-to-32 -Wno-missing-noreturn"
|
38
|
+
else
|
39
|
+
# silence rice warnings
|
40
|
+
$CXXFLAGS += " -Wno-noexcept-type"
|
41
|
+
|
42
|
+
# silence torch warnings
|
43
|
+
$CXXFLAGS += " -Wno-duplicated-cond -Wno-suggest-attribute=noreturn"
|
44
|
+
end
|
45
|
+
|
46
|
+
inc, lib = dir_config("torch")
|
47
|
+
inc ||= "/usr/local/include"
|
48
|
+
lib ||= "/usr/local/lib"
|
49
|
+
|
50
|
+
cuda_inc, cuda_lib = dir_config("cuda")
|
51
|
+
cuda_inc ||= "/usr/local/cuda/include"
|
52
|
+
cuda_lib ||= "/usr/local/cuda/lib64"
|
53
|
+
|
54
|
+
$LDFLAGS += " -L#{lib}" if Dir.exist?(lib)
|
55
|
+
abort "LibTorch not found" unless have_library("torch")
|
56
|
+
|
57
|
+
have_library("mkldnn")
|
58
|
+
have_library("nnpack")
|
59
|
+
|
60
|
+
with_cuda = false
|
61
|
+
if Dir["#{lib}/*torch_cuda*"].any?
|
62
|
+
$LDFLAGS += " -L#{cuda_lib}" if Dir.exist?(cuda_lib)
|
63
|
+
with_cuda = have_library("cuda") && have_library("cudnn")
|
64
|
+
end
|
65
|
+
|
66
|
+
$INCFLAGS += " -I#{inc}"
|
67
|
+
$INCFLAGS += " -I#{inc}/torch/csrc/api/include"
|
68
|
+
|
69
|
+
$LDFLAGS += " -Wl,-rpath,#{lib}"
|
70
|
+
$LDFLAGS += ":#{cuda_lib}/stubs:#{cuda_lib}" if with_cuda
|
71
|
+
|
72
|
+
# https://github.com/pytorch/pytorch/blob/v1.5.0/torch/utils/cpp_extension.py#L1232-L1238
|
73
|
+
$LDFLAGS += " -lc10 -ltorch_cpu -ltorch"
|
74
|
+
if with_cuda
|
75
|
+
$LDFLAGS += " -lcuda -lnvrtc -lnvToolsExt -lcudart -lc10_cuda -ltorch_cuda -lcufft -lcurand -lcublas -lcudnn"
|
76
|
+
# TODO figure out why this is needed
|
77
|
+
$LDFLAGS += " -Wl,--no-as-needed,#{lib}/libtorch.so"
|
78
|
+
end
|
79
|
+
|
80
|
+
# create makefile
|
81
|
+
create_makefile("torchaudio/ext")
|
data/lib/torchaudio.rb
ADDED
@@ -0,0 +1,95 @@
|
|
1
|
+
# dependencies
|
2
|
+
require "torch"
|
3
|
+
|
4
|
+
# ext
|
5
|
+
require "torchaudio/ext"
|
6
|
+
|
7
|
+
# stdlib
|
8
|
+
require "csv"
|
9
|
+
require "digest"
|
10
|
+
require "fileutils"
|
11
|
+
require "rubygems/package"
|
12
|
+
require "set"
|
13
|
+
|
14
|
+
# modules
|
15
|
+
require "torchaudio/datasets/utils"
|
16
|
+
require "torchaudio/datasets/yesno"
|
17
|
+
require "torchaudio/version"
|
18
|
+
|
19
|
+
module TorchAudio
  # Base error class for the library.
  class Error < StandardError; end

  class << self
    # Reads an audio file into a tensor.
    #
    # filepath       - path to the audio file (converted with #to_s)
    # out            - optional CPU tensor to fill; a new Torch::FloatTensor
    #                  is created when nil
    # normalization  - see normalize_audio: true, a Numeric divisor, a
    #                  callable, or false/nil to skip
    # channels_first, num_frames, offset, signalinfo, encodinginfo - passed
    #                  through unchanged to Ext.read_audio_file
    # filetype       - defaults to the file extension without the leading dot
    #
    # Returns [out, sample_rate], where sample_rate is the value returned by
    # Ext.read_audio_file.
    # Raises ArgumentError when the path is not a regular file, when `out` is
    # not a CPU tensor, when num_frames < -1, or when offset < 0.
    def load(
      filepath, out: nil, normalization: true, channels_first: true, num_frames: 0,
      offset: 0, signalinfo: nil, encodinginfo: nil, filetype: nil
    )

      filepath = filepath.to_s

      # File.file? (not File.exist?) so that directories are rejected too,
      # which is what the error message promises
      unless File.file?(filepath)
        raise ArgumentError, "#{filepath} not found or is a directory"
      end

      # initialize output tensor
      if !out.nil?
        check_input(out)
      else
        out = Torch::FloatTensor.new
      end

      if num_frames < -1
        raise ArgumentError, "Expected value for num_frames -1 (entire file) or >=0"
      end
      if offset < 0
        raise ArgumentError, "Expected positive offset value"
      end

      # same logic as C++
      # could also make read_audio_file work with nil
      filetype ||= File.extname(filepath)[1..-1]

      sample_rate =
        Ext.read_audio_file(
          filepath,
          out,
          channels_first,
          num_frames,
          offset,
          signalinfo,
          encodinginfo,
          filetype
        )

      # normalize if needed (divides `out` in place)
      normalize_audio(out, normalization)

      [out, sample_rate]
    end

    # Same as load, but always normalizes by 1 << 16. Any :normalization
    # supplied by the caller is overwritten.
    def load_wav(filepath, **kwargs)
      kwargs[:normalization] = 1 << 16
      load(filepath, **kwargs)
    end

    private

    # Raises ArgumentError unless src is a tensor that is not on CUDA.
    def check_input(src)
      raise ArgumentError, "Expected a tensor, got #{src.class.name}" unless Torch.tensor?(src)
      raise ArgumentError, "Expected a CPU based tensor, got #{src.class.name}" if src.cuda?
    end

    # Divides signal in place according to normalization:
    #   false/nil -> no-op
    #   true      -> divide by 1 << 31
    #   Numeric   -> divide by that number
    #   callable  -> divide by normalization.call(signal)
    # Any other truthy value is silently ignored.
    def normalize_audio(signal, normalization)
      return unless normalization

      normalization = 1 << 31 if normalization == true

      if normalization.is_a?(Numeric)
        signal.div!(normalization)
      elsif normalization.respond_to?(:call)
        signal.div!(normalization.call(signal))
      end
    end
  end
end
|
@@ -0,0 +1,92 @@
|
|
1
|
+
# stdlib used by this module; required here so the helpers do not rely on
# another file having loaded them first
require "digest"
require "fileutils"
require "net/http"
require "tmpdir"

module TorchAudio
  module Datasets
    # Download, archive and directory-walking helpers for the datasets.
    module Utils
      class << self
        # Downloads url into download_folder and returns the destination path.
        #
        # filename   - name of the saved file (defaults to the basename of url)
        # hash_value - expected hex digest; verification is skipped when nil
        # hash_type  - digest algorithm name ("md5", "sha1", "sha256", "sha512")
        #
        # Raises RuntimeError when the destination file already exists.
        def download_url(url, download_folder, filename: nil, hash_value: nil, hash_type: "sha256")
          filename ||= File.basename(url)
          filepath = File.join(download_folder, filename)

          if File.exist?(filepath)
            raise "#{filepath} already exists. Delete the file manually and retry."
          end

          puts "Downloading #{url}..."
          download_url_to_file(url, filepath, hash_value, hash_type)
        end

        # Streams url to a temporary file, verifies its digest and moves it to
        # dst. Follows redirects. Returns dst.
        #
        # Raises Error on a response that is neither a success nor a usable
        # redirect, ArgumentError on an unknown hash_type, and RuntimeError
        # when the digest does not match hash_value.
        def download_url_to_file(url, dst, hash_value, hash_type)
          uri = URI(url)
          tmp = nil
          location = nil

          Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http|
            request = Net::HTTP::Get.new(uri)

            http.request(request) do |response|
              case response
              when Net::HTTPRedirection
                location = response["location"]
              when Net::HTTPSuccess
                tmp = "#{Dir.tmpdir}/#{Time.now.to_f}" # TODO better name
                File.open(tmp, "wb") do |f|
                  response.read_body do |chunk|
                    f.write(chunk)
                  end
                end
              else
                raise Error, "Bad response"
              end
            end
          end

          # forward the hash arguments: the method takes four parameters, so
          # a two-argument recursive call would raise ArgumentError
          return download_url_to_file(location, dst, hash_value, hash_type) if location

          # e.g. a redirect status without a Location header
          raise Error, "Bad response" unless tmp

          # nil hash_value (the download_url default) means "do not verify"
          if hash_value && digest_class(hash_type).file(tmp).hexdigest != hash_value
            FileUtils.rm_f(tmp) # do not leave the rejected download behind
            raise "The hash of #{dst} does not match. Delete the file manually and retry."
          end

          FileUtils.mv(tmp, dst)
          dst
        end

        # Extracts a .tar.gz / .tgz archive into to_path (defaults to the
        # directory containing from_path) and returns to_path.
        # extract_tar_gz doesn't list files, so just return to_path.
        # `overwrite` is accepted but currently unused.
        def extract_archive(from_path, to_path: nil, overwrite: nil)
          to_path ||= File.dirname(from_path)

          if from_path.end_with?(".tar.gz") || from_path.end_with?(".tgz")
            File.open(from_path, "rb") do |io|
              Gem::Package.new("").extract_tar_gz(io, to_path)
            end
            return to_path
          end

          raise "We currently only support tar.gz and tgz archives."
        end

        # Yields, in sorted order, every path under root (relative to root)
        # that ends with suffix. Returns an Enumerator when no block is given.
        #
        # remove_suffix - strip suffix from each yielded path
        # prefix        - not implemented; raises when true and a match is found
        def walk_files(root, suffix, prefix: false, remove_suffix: false)
          return enum_for(:walk_files, root, suffix, prefix: prefix, remove_suffix: remove_suffix) unless block_given?

          Dir.glob("**/*", base: root).sort.each do |f|
            if f.end_with?(suffix)
              if remove_suffix
                f = f[0..(-suffix.length - 1)]
              end

              if prefix
                raise "Not implemented yet"
                # f = File.join(dirpath, f)
              end

              yield f
            end
          end
        end

        private

        # Maps a hash_type name to the matching Digest class.
        def digest_class(hash_type)
          case hash_type.to_s.downcase
          when "md5" then Digest::MD5
          when "sha1" then Digest::SHA1
          when "sha256" then Digest::SHA256
          when "sha512" then Digest::SHA512
          else
            raise ArgumentError, "Unsupported hash type: #{hash_type}"
          end
        end
      end
    end
  end
end
|
@@ -0,0 +1,59 @@
|
|
1
|
+
module TorchAudio
  module Datasets
    # Dataset over the .wav files found under <root>/<folder_in_archive>.
    # Each item is [waveform, sample_rate, labels], where labels are the
    # integers obtained by splitting the file id on "_".
    class YESNO < Torch::Utils::Data::Dataset
      URL = "http://www.openslr.org/resources/1/waves_yesno.tar.gz"
      FOLDER_IN_ARCHIVE = "waves_yesno"
      CHECKSUMS = {
        "http://www.openslr.org/resources/1/waves_yesno.tar.gz" => "962ff6e904d2df1126132ecec6978786"
      }

      # root              - directory holding (or receiving) the archive and data
      # url               - archive location; must be a key of CHECKSUMS when a
      #                     download is actually performed
      # folder_in_archive - data folder name below root
      # download          - fetch and extract the archive when the data folder
      #                     is missing
      #
      # Raises RuntimeError when the data folder does not exist afterwards.
      def initialize(root, url: URL, folder_in_archive: FOLDER_IN_ARCHIVE, download: false)
        archive_path = File.join(root, File.basename(url))
        @path = File.join(root, folder_in_archive)

        fetch_and_extract(url, root, archive_path) if download && !Dir.exist?(@path)

        raise "Dataset not found. Please use `download: true` to download it." unless Dir.exist?(@path)

        file_ids = Utils.walk_files(@path, ext_audio, prefix: false, remove_suffix: true)
        @walker = file_ids.to_a
      end

      # Loads the n-th item as [waveform, sample_rate, labels].
      def [](n)
        load_yesno_item(@walker[n], @path, ext_audio)
      end

      # Number of audio files discovered at construction time.
      def length
        @walker.length
      end
      alias_method :size, :length

      private

      # Downloads the archive unless it is already on disk, then extracts it
      # next to itself.
      def fetch_and_extract(url, root, archive_path)
        unless File.exist?(archive_path)
          expected_md5 = CHECKSUMS.fetch(url)
          Utils.download_url(url, root, hash_value: expected_md5, hash_type: "md5")
        end
        Utils.extract_archive(archive_path)
      end

      # Builds one item: the labels come from the file id, the audio from
      # TorchAudio.load with default options.
      def load_yesno_item(fileid, path, ext_audio)
        labels = fileid.split("_").map { |part| part.to_i }
        audio_file = File.join(path, "#{fileid}#{ext_audio}")
        waveform, sample_rate = TorchAudio.load(audio_file)

        [waveform, sample_rate, labels]
      end

      # Extension of the dataset's audio files.
      def ext_audio
        ".wav"
      end
    end
  end
end
|