torchaudio 0.1.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,100 @@
1
+ #ifndef TORCHAUDIO_SOX_UTILS_H
2
+ #define TORCHAUDIO_SOX_UTILS_H
3
+
4
+ #include <sox.h>
5
+ #include <torch/script.h>
6
+
7
+ namespace torchaudio {
8
+ namespace sox_utils {
9
+
10
+ struct TensorSignal : torch::CustomClassHolder {
11
+ torch::Tensor tensor;
12
+ int64_t sample_rate;
13
+ bool channels_first;
14
+
15
+ TensorSignal(
16
+ torch::Tensor tensor_,
17
+ int64_t sample_rate_,
18
+ bool channels_first_);
19
+
20
+ torch::Tensor getTensor() const;
21
+ int64_t getSampleRate() const;
22
+ bool getChannelsFirst() const;
23
+ };
24
+
25
+ /// helper class to automatically close sox_format_t*
26
+ struct SoxFormat {
27
+ explicit SoxFormat(sox_format_t* fd) noexcept;
28
+ SoxFormat(const SoxFormat& other) = delete;
29
+ SoxFormat(SoxFormat&& other) = delete;
30
+ SoxFormat& operator=(const SoxFormat& other) = delete;
31
+ SoxFormat& operator=(SoxFormat&& other) = delete;
32
+ ~SoxFormat();
33
+ sox_format_t* operator->() const noexcept;
34
+ operator sox_format_t*() const noexcept;
35
+
36
+ private:
37
+ sox_format_t* fd_;
38
+ };
39
+
40
+ ///
41
+ /// Verify that input file is found, has known encoding, and not empty
42
+ void validate_input_file(const SoxFormat& sf);
43
+
44
+ ///
45
+ /// Verify that input Tensor is 2D, CPU and either uin8, int16, int32 or float32
46
+ void validate_input_tensor(const torch::Tensor);
47
+
48
+ ///
49
+ /// Get target dtype for the given encoding and precision.
50
+ caffe2::TypeMeta get_dtype(
51
+ const sox_encoding_t encoding,
52
+ const unsigned precision);
53
+
54
+ ///
55
+ /// Convert sox_sample_t buffer to uint8/int16/int32/float32 Tensor
56
+ /// NOTE: This function might modify the values in the input buffer to
57
+ /// reduce the number of memory copy.
58
+ /// @param buffer Pointer to buffer that contains audio data.
59
+ /// @param num_samples The number of samples to read.
60
+ /// @param num_channels The number of channels. Used to reshape the resulting
61
+ /// Tensor.
62
+ /// @param dtype Target dtype. Determines the output dtype and value range in
63
+ /// conjunction with normalization.
64
+ /// @param noramlize Perform normalization. Only effective when dtype is not
65
+ /// kFloat32. When effective, the output tensor is kFloat32 type and value range
66
+ /// is [-1.0, 1.0]
67
+ /// @param channels_first When True, output Tensor has shape of [num_channels,
68
+ /// num_frames].
69
+ torch::Tensor convert_to_tensor(
70
+ sox_sample_t* buffer,
71
+ const int32_t num_samples,
72
+ const int32_t num_channels,
73
+ const caffe2::TypeMeta dtype,
74
+ const bool normalize,
75
+ const bool channels_first);
76
+
77
+ ///
78
+ /// Convert float32/int32/int16/uint8 Tensor to int32 for Torch -> Sox
79
+ /// conversion.
80
+ torch::Tensor unnormalize_wav(const torch::Tensor);
81
+
82
+ /// Extract extension from file path
83
+ const std::string get_filetype(const std::string path);
84
+
85
+ /// Get sox_signalinfo_t for passing a torch::Tensor object.
86
+ sox_signalinfo_t get_signalinfo(
87
+ const torch::Tensor& tensor,
88
+ const int64_t sample_rate,
89
+ const bool channels_first,
90
+ const std::string filetype);
91
+
92
+ /// Get sox_encofinginfo_t for saving audoi file
93
+ sox_encodinginfo_t get_encodinginfo(
94
+ const std::string filetype,
95
+ const caffe2::TypeMeta dtype,
96
+ const double compression);
97
+
98
+ } // namespace sox_utils
99
+ } // namespace torchaudio
100
+ #endif
@@ -0,0 +1,33 @@
1
+ #include <torchaudio/csrc/sox.h>
2
+
3
+ #include <rice/Module.hpp>
4
+
5
+ using namespace Rice;
6
+
7
+ template<>
8
+ inline
9
+ sox_signalinfo_t* from_ruby<sox_signalinfo_t*>(Object x)
10
+ {
11
+ if (x.is_nil()) {
12
+ return nullptr;
13
+ }
14
+ throw std::runtime_error("Unsupported signalinfo");
15
+ }
16
+
17
+ template<>
18
+ inline
19
+ sox_encodinginfo_t* from_ruby<sox_encodinginfo_t*>(Object x)
20
+ {
21
+ if (x.is_nil()) {
22
+ return nullptr;
23
+ }
24
+ throw std::runtime_error("Unsupported encodinginfo");
25
+ }
26
+
27
+ extern "C"
28
+ void Init_ext()
29
+ {
30
+ Module rb_mTorchAudio = define_module("TorchAudio");
31
+ Module rb_mNN = define_module_under(rb_mTorchAudio, "Ext")
32
+ .define_singleton_method("read_audio_file", &torch::audio::read_audio_file);
33
+ }
@@ -0,0 +1,81 @@
1
require "mkmf-rice"

# The C++ standard library and SoX are hard requirements.
have_library("stdc++") || abort("Missing stdc++")

$CXXFLAGS += " -std=c++14"

have_library("sox") || abort("SoX not found")

# Compile the sources in this directory and in csrc/.
ext = File.expand_path(".", __dir__)
csrc = File.expand_path("csrc", __dir__)

$srcs = Dir["{#{ext},#{csrc}}/*.cpp"]
$INCFLAGS << " -I#{File.expand_path("..", __dir__)}"
$VPATH << csrc

#
# keep rest synced with Torch
#

# change to 0 for Linux pre-cxx11 ABI version
$CXXFLAGS += " -D_GLIBCXX_USE_CXX11_ABI=1"

apple_clang = /apple clang/i =~ RbConfig::CONFIG["CC_VERSION_MESSAGE"]

# OpenMP has to be probed before anything else.
if have_library("omp") || have_library("gomp")
  openmp_flags = ["-DAT_PARALLEL_OPENMP=1"]
  openmp_flags << "-Xclang" if apple_clang
  openmp_flags << "-fopenmp"
  $CXXFLAGS += " " + openmp_flags.join(" ")
end

$CXXFLAGS +=
  if apple_clang
    # first flag silences a ruby/intern.h warning, the rest silence torch warnings
    " -Wno-deprecated-register -Wno-shorten-64-to-32 -Wno-missing-noreturn"
  else
    # first flag silences rice warnings, the rest silence torch warnings
    " -Wno-noexcept-type -Wno-duplicated-cond -Wno-suggest-attribute=noreturn"
  end

torch_inc, torch_lib = dir_config("torch")
torch_inc ||= "/usr/local/include"
torch_lib ||= "/usr/local/lib"

cuda_inc, cuda_lib = dir_config("cuda")
cuda_inc ||= "/usr/local/cuda/include"
cuda_lib ||= "/usr/local/cuda/lib64"

$LDFLAGS += " -L#{torch_lib}" if Dir.exist?(torch_lib)
have_library("torch") || abort("LibTorch not found")

have_library("mkldnn")
have_library("nnpack")

# CUDA is only probed when the LibTorch build ships torch_cuda libraries.
with_cuda =
  if Dir["#{torch_lib}/*torch_cuda*"].any?
    $LDFLAGS += " -L#{cuda_lib}" if Dir.exist?(cuda_lib)
    have_library("cuda") && have_library("cudnn")
  else
    false
  end

$INCFLAGS += " -I#{torch_inc} -I#{torch_inc}/torch/csrc/api/include"

rpath_entries = [torch_lib]
rpath_entries.push("#{cuda_lib}/stubs", cuda_lib) if with_cuda
$LDFLAGS += " -Wl,-rpath,#{rpath_entries.join(":")}"

# https://github.com/pytorch/pytorch/blob/v1.5.0/torch/utils/cpp_extension.py#L1232-L1238
$LDFLAGS += " -lc10 -ltorch_cpu -ltorch"
if with_cuda
  $LDFLAGS += " -lcuda -lnvrtc -lnvToolsExt -lcudart -lc10_cuda -ltorch_cuda -lcufft -lcurand -lcublas -lcudnn"
  # TODO figure out why this is needed
  $LDFLAGS += " -Wl,--no-as-needed,#{torch_lib}/libtorch.so"
end

# create makefile
create_makefile("torchaudio/ext")
@@ -0,0 +1,95 @@
1
+ # dependencies
2
+ require "torch"
3
+
4
+ # ext
5
+ require "torchaudio/ext"
6
+
7
+ # stdlib
8
+ require "csv"
9
+ require "digest"
10
+ require "fileutils"
11
+ require "rubygems/package"
12
+ require "set"
13
+
14
+ # modules
15
+ require "torchaudio/datasets/utils"
16
+ require "torchaudio/datasets/yesno"
17
+ require "torchaudio/version"
18
+
19
+ module TorchAudio
20
+ class Error < StandardError; end
21
+
22
+ class << self
23
# Loads an audio file into a tensor.
#
# filepath       - path of the audio file (converted with #to_s)
# out            - optional tensor passed to the extension as the output
#                  buffer; must pass check_input. Defaults to a new
#                  Torch::FloatTensor.
# normalization  - true, a Numeric divisor, a callable, or a falsy value;
#                  forwarded to normalize_audio
# channels_first - forwarded to Ext.read_audio_file
# num_frames     - must be >= -1
# offset         - must be >= 0
# signalinfo     - forwarded to Ext.read_audio_file
# encodinginfo   - forwarded to Ext.read_audio_file
# filetype       - defaults to the file extension without the leading dot
#
# Returns [out, sample_rate].
# Raises ArgumentError when filepath is not a regular file, or when
# num_frames / offset are out of range.
def load(
  filepath, out: nil, normalization: true, channels_first: true, num_frames: 0,
  offset: 0, signalinfo: nil, encodinginfo: nil, filetype: nil
)
  filepath = filepath.to_s

  # File.exist? is also true for directories, which would let a directory
  # slip past this check; File.file? matches what the message promises.
  unless File.file?(filepath)
    raise ArgumentError, "#{filepath} not found or is a directory"
  end

  # initialize output tensor
  if !out.nil?
    check_input(out)
  else
    out = Torch::FloatTensor.new
  end

  if num_frames < -1
    raise ArgumentError, "Expected value for num_samples -1 (entire file) or >=0"
  end
  if offset < 0
    raise ArgumentError, "Expected positive offset value"
  end

  # same logic as C++
  # could also make read_audio_file work with nil
  filetype ||= File.extname(filepath)[1..-1]

  sample_rate =
    Ext.read_audio_file(
      filepath,
      out,
      channels_first,
      num_frames,
      offset,
      signalinfo,
      encodinginfo,
      filetype
    )

  # normalize if needed
  normalize_audio(out, normalization)

  [out, sample_rate]
end
70
+
71
# Same as +load+, but always normalizes by 1 << 16; a caller-supplied
# :normalization option is overridden.
def load_wav(filepath, **kwargs)
  load(filepath, **kwargs.merge(normalization: 1 << 16))
end
75
+
76
+ private
77
+
78
# Validates a user-supplied output tensor.
# Raises ArgumentError unless +src+ is a tensor that is not on CUDA.
def check_input(src)
  unless Torch.tensor?(src)
    raise ArgumentError, "Expected a tensor, got #{src.class.name}"
  end

  if src.cuda?
    raise ArgumentError, "Expected a CPU based tensor, got #{src.class.name}"
  end
end
82
+
83
# Divides +signal+ in place according to +normalization+:
#   falsy    - nothing happens
#   true     - divide by 1 << 31
#   Numeric  - divide by that number
#   callable - divide by the value returned from calling it with the signal
# Any other value leaves the signal untouched.
def normalize_audio(signal, normalization)
  return unless normalization

  scale = normalization == true ? 1 << 31 : normalization

  return signal.div!(scale) if scale.is_a?(Numeric)
  return signal.div!(scale.call(signal)) if scale.respond_to?(:call)

  nil
end
94
+ end
95
+ end
@@ -0,0 +1,92 @@
1
+ module TorchAudio
2
+ module Datasets
3
+ module Utils
4
+ class << self
5
# Downloads +url+ into +download_folder+.
#
# filename   - name of the target file; defaults to the basename of the URL
# hash_value - expected digest, forwarded to download_url_to_file
# hash_type  - digest algorithm name, forwarded to download_url_to_file
#
# Refuses to overwrite: raises RuntimeError when the target already exists.
# Returns whatever download_url_to_file returns.
def download_url(url, download_folder, filename: nil, hash_value: nil, hash_type: "sha256")
  target = File.join(download_folder, filename || File.basename(url))

  raise "#{target} already exists. Delete the file manually and retry." if File.exist?(target)

  puts "Downloading #{url}..."
  download_url_to_file(url, target, hash_value, hash_type)
end
16
+
17
# Streams +url+ into a temporary file, verifies its digest and then moves
# it to +dst+. HTTP redirects are followed by recursing on the Location
# header.
#
# url        - URL to fetch (http or https)
# dst        - final path of the downloaded file
# hash_value - expected hex digest; verification is skipped when nil
# hash_type  - digest algorithm name: "md5", "sha1" or "sha256"
#
# Returns dst.
# Raises Error on a response that is neither a success nor a redirect,
# ArgumentError on an unsupported hash_type, and RuntimeError when the
# digest does not match.
def download_url_to_file(url, dst, hash_value, hash_type)
  # Loaded here so the method does not depend on what the caller required.
  require "digest"
  require "fileutils"
  require "net/http"
  require "tmpdir"

  # Resolve the digest up front so a bad hash_type fails before any download.
  digest_class =
    unless hash_value.nil?
      case hash_type.to_s.downcase
      when "md5" then Digest::MD5
      when "sha1" then Digest::SHA1
      when "sha256" then Digest::SHA256
      else raise ArgumentError, "Unsupported hash type: #{hash_type}"
      end
    end

  uri = URI(url)
  tmp = nil
  location = nil

  Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http|
    request = Net::HTTP::Get.new(uri)

    http.request(request) do |response|
      case response
      when Net::HTTPRedirection
        location = response["location"]
      when Net::HTTPSuccess
        tmp = "#{Dir.tmpdir}/#{Time.now.to_f}" # TODO better name
        File.open(tmp, "wb") do |f|
          response.read_body do |chunk|
            f.write(chunk)
          end
        end
      else
        raise Error, "Bad response"
      end
    end
  end

  # The redirect target is verified against the same expected digest.
  return download_url_to_file(location, dst, hash_value, hash_type) if location

  if digest_class && digest_class.file(tmp).hexdigest != hash_value
    # Do not leave a corrupt download behind in the temp directory.
    FileUtils.rm_f(tmp)
    raise "The hash of #{dst} does not match. Delete the file manually and retry."
  end

  FileUtils.mv(tmp, dst)
  dst
end
56
+
57
# Extracts a .tar.gz / .tgz archive.
#
# from_path - path of the archive
# to_path   - destination directory; defaults to the archive's directory
# overwrite - accepted but currently unused
#
# extract_tar_gz does not report the extracted files, so the destination
# directory is returned instead.
# Raises RuntimeError for any other archive type.
def extract_archive(from_path, to_path: nil, overwrite: nil)
  destination = to_path || File.dirname(from_path)

  unless from_path.end_with?(".tar.gz", ".tgz")
    raise "We currently only support tar.gz and tgz archives."
  end

  File.open(from_path, "rb") do |archive_io|
    Gem::Package.new("").extract_tar_gz(archive_io, destination)
  end
  destination
end
70
+
71
# Yields, in sorted order, every path below +root+ (relative to +root+)
# that ends with +suffix+.
#
# prefix        - not implemented; raises RuntimeError on the first match
#                 when truthy
# remove_suffix - strip +suffix+ from each yielded path
#
# Returns an Enumerator when no block is given.
def walk_files(root, suffix, prefix: false, remove_suffix: false)
  unless block_given?
    return enum_for(:walk_files, root, suffix, prefix: prefix, remove_suffix: remove_suffix)
  end

  Dir.glob("**/*", base: root).sort.each do |entry|
    next unless entry.end_with?(suffix)

    entry = entry.delete_suffix(suffix) if remove_suffix

    # Former intent, kept for reference: File.join(dirpath, entry)
    raise "Not implemented yet" if prefix

    yield entry
  end
end
89
+ end
90
+ end
91
+ end
92
+ end
@@ -0,0 +1,59 @@
1
+ module TorchAudio
2
+ module Datasets
3
# YESNO dataset: items are [waveform, sample_rate, labels], where the
# labels are the integers encoded in the file name, separated by "_".
class YESNO < Torch::Utils::Data::Dataset
  URL = "http://www.openslr.org/resources/1/waves_yesno.tar.gz"
  FOLDER_IN_ARCHIVE = "waves_yesno"
  CHECKSUMS = {
    "http://www.openslr.org/resources/1/waves_yesno.tar.gz" => "962ff6e904d2df1126132ecec6978786"
  }

  # root              - directory holding (or receiving) the archive and data
  # url               - archive URL; must be a key of CHECKSUMS when a
  #                     download is actually performed
  # folder_in_archive - name of the data folder below root
  # download          - download and extract when the data folder is missing
  #
  # Raises RuntimeError when the data folder does not exist afterwards.
  def initialize(root, url: URL, folder_in_archive: FOLDER_IN_ARCHIVE, download: false)
    archive = File.join(root, File.basename(url))
    @path = File.join(root, folder_in_archive)

    if download && !Dir.exist?(@path)
      # Reuse an archive that is already present instead of fetching it again.
      unless File.exist?(archive)
        Utils.download_url(url, root, hash_value: CHECKSUMS.fetch(url), hash_type: "md5")
      end
      Utils.extract_archive(archive)
    end

    unless Dir.exist?(@path)
      raise "Dataset not found. Please use `download: true` to download it."
    end

    # File ids are the .wav paths with the extension stripped.
    @walker = Utils.walk_files(@path, ext_audio, prefix: false, remove_suffix: true).to_a
  end

  # Returns [waveform, sample_rate, labels] for the n-th file.
  def [](n)
    load_yesno_item(@walker[n], @path, ext_audio)
  end

  # Number of audio files found.
  def length
    @walker.length
  end
  alias_method :size, :length

  private

  # Loads one item; the labels come from splitting the file id on "_".
  def load_yesno_item(fileid, path, ext_audio)
    labels = fileid.split("_").map(&:to_i)
    waveform, sample_rate = TorchAudio.load(File.join(path, fileid + ext_audio))

    [waveform, sample_rate, labels]
  end

  # Extension of the audio files in the dataset.
  def ext_audio
    ".wav"
  end
end
58
+ end
59
+ end