torchaudio 0.1.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,100 @@
1
+ #ifndef TORCHAUDIO_SOX_UTILS_H
2
+ #define TORCHAUDIO_SOX_UTILS_H
3
+
4
+ #include <sox.h>
5
+ #include <torch/script.h>
6
+
7
+ namespace torchaudio {
8
+ namespace sox_utils {
9
+ /// Bundles a tensor with a sample rate and a channels-first flag; derives from torch::CustomClassHolder.
10
+ struct TensorSignal : torch::CustomClassHolder {
11
+ torch::Tensor tensor; ///< the audio data
12
+ int64_t sample_rate; ///< sample rate associated with `tensor`
13
+ bool channels_first; ///< presumably true when `tensor` is [num_channels, num_frames] -- TODO confirm
14
+ /// Constructor; parameter names mirror the members (the definition is not in this header).
15
+ TensorSignal(
16
+ torch::Tensor tensor_,
17
+ int64_t sample_rate_,
18
+ bool channels_first_);
19
+ /// Const accessors; judging by their names they return the matching members (defined elsewhere).
20
+ torch::Tensor getTensor() const;
21
+ int64_t getSampleRate() const;
22
+ bool getChannelsFirst() const;
23
+ };
24
+
25
+ /// Helper class to automatically close sox_format_t*. Copying and moving are both disabled.
26
+ struct SoxFormat {
27
+ explicit SoxFormat(sox_format_t* fd) noexcept; ///< wraps `fd`
28
+ SoxFormat(const SoxFormat& other) = delete; ///< not copy-constructible
29
+ SoxFormat(SoxFormat&& other) = delete; ///< not move-constructible
30
+ SoxFormat& operator=(const SoxFormat& other) = delete; ///< not copy-assignable
31
+ SoxFormat& operator=(SoxFormat&& other) = delete; ///< not move-assignable
32
+ ~SoxFormat(); ///< presumably where the handle gets closed (definition not in this header)
33
+ sox_format_t* operator->() const noexcept; ///< lets `sf->member` reach the sox_format_t
34
+ operator sox_format_t*() const noexcept; ///< implicit conversion, usable where a raw sox_format_t* is expected
35
+
36
+ private:
37
+ sox_format_t* fd_; ///< presumably the handle passed to the constructor -- TODO confirm
38
+ };
39
+
40
+ ///
41
+ /// Verify that input file is found, has known encoding, and is not empty
42
+ void validate_input_file(const SoxFormat& sf);
43
+
44
+ ///
45
+ /// Verify that input Tensor is 2D, CPU and either uint8, int16, int32 or float32
46
+ void validate_input_tensor(const torch::Tensor);
47
+
48
+ ///
49
+ /// Get target dtype for the given encoding and precision.
50
+ caffe2::TypeMeta get_dtype(
51
+ const sox_encoding_t encoding,
52
+ const unsigned precision);
53
+
54
+ ///
55
+ /// Convert sox_sample_t buffer to uint8/int16/int32/float32 Tensor
56
+ /// NOTE: This function might modify the values in the input buffer to
57
+ /// reduce the number of memory copy.
58
+ /// @param buffer Pointer to buffer that contains audio data.
59
+ /// @param num_samples The number of samples to read.
60
+ /// @param num_channels The number of channels. Used to reshape the resulting
61
+ /// Tensor.
62
+ /// @param dtype Target dtype. Determines the output dtype and value range in
63
+ /// conjunction with normalization.
64
+ /// @param normalize Perform normalization. Only effective when dtype is not
65
+ /// kFloat32. When effective, the output tensor is kFloat32 type and value range
66
+ /// is [-1.0, 1.0]
67
+ /// @param channels_first When True, output Tensor has shape of [num_channels,
68
+ /// num_frames].
69
+ torch::Tensor convert_to_tensor(
70
+ sox_sample_t* buffer,
71
+ const int32_t num_samples,
72
+ const int32_t num_channels,
73
+ const caffe2::TypeMeta dtype,
74
+ const bool normalize,
75
+ const bool channels_first);
76
+
77
+ ///
78
+ /// Convert float32/int32/int16/uint8 Tensor to int32 for Torch -> Sox
79
+ /// conversion.
80
+ torch::Tensor unnormalize_wav(const torch::Tensor);
81
+
82
+ /// Extract extension from file path
83
+ const std::string get_filetype(const std::string path);
84
+
85
+ /// Get sox_signalinfo_t for passing a torch::Tensor object.
86
+ sox_signalinfo_t get_signalinfo(
87
+ const torch::Tensor& tensor,
88
+ const int64_t sample_rate,
89
+ const bool channels_first,
90
+ const std::string filetype);
91
+
92
+ /// Get sox_encodinginfo_t for saving audio file
93
+ sox_encodinginfo_t get_encodinginfo(
94
+ const std::string filetype,
95
+ const caffe2::TypeMeta dtype,
96
+ const double compression);
97
+
98
+ } // namespace sox_utils
99
+ } // namespace torchaudio
100
+ #endif
@@ -0,0 +1,33 @@
1
+ #include <torchaudio/csrc/sox.h>
2
+
3
+ #include <rice/Module.hpp>
4
+
5
+ using namespace Rice;
6
+ // Converter for signalinfo arguments: Ruby nil becomes nullptr, anything else throws std::runtime_error.
7
+ template<>
8
+ inline
9
+ sox_signalinfo_t* from_ruby<sox_signalinfo_t*>(Object x)
10
+ {
11
+ if (!x.is_nil()) {
12
+ throw std::runtime_error("Unsupported signalinfo");
13
+ }
14
+ return nullptr;
15
+ }
16
+ // Converter for encodinginfo arguments: Ruby nil becomes nullptr, anything else throws std::runtime_error.
17
+ template<>
18
+ inline
19
+ sox_encodinginfo_t* from_ruby<sox_encodinginfo_t*>(Object x)
20
+ {
21
+ if (!x.is_nil()) {
22
+ throw std::runtime_error("Unsupported encodinginfo");
23
+ }
24
+ return nullptr;
25
+ }
26
+ // Defines the TorchAudio module with a nested Ext module exposing read_audio_file as a singleton method.
27
+ extern "C"
28
+ void Init_ext()
29
+ {
30
+ Module torchaudio_module = define_module("TorchAudio");
31
+ Module ext_module = define_module_under(torchaudio_module, "Ext")
32
+ .define_singleton_method("read_audio_file", &torch::audio::read_audio_file);
33
+ }
@@ -0,0 +1,81 @@
1
+ require "mkmf-rice"
2
+
3
+ abort "Missing stdc++" unless have_library("stdc++")
4
+
5
+ $CXXFLAGS += " -std=c++14"
6
+
7
+ abort "SoX not found" unless have_library("sox")
8
+
9
+ ext = File.expand_path(".", __dir__)
10
+ csrc = File.expand_path("csrc", __dir__)
11
+
12
+ $srcs = Dir["{#{ext},#{csrc}}/*.cpp"]
13
+ $INCFLAGS << " -I#{File.expand_path("..", __dir__)}"
14
+ $VPATH << csrc
15
+
16
+ #
17
+ # keep rest synced with Torch
18
+ #
19
+
20
+ # change to 0 for Linux pre-cxx11 ABI version
21
+ $CXXFLAGS += " -D_GLIBCXX_USE_CXX11_ABI=1"
22
+
23
+ apple_clang = RbConfig::CONFIG["CC_VERSION_MESSAGE"] =~ /apple clang/i
24
+
25
+ # check omp first
26
+ if have_library("omp") || have_library("gomp")
27
+ $CXXFLAGS += " -DAT_PARALLEL_OPENMP=1"
28
+ $CXXFLAGS += " -Xclang" if apple_clang
29
+ $CXXFLAGS += " -fopenmp"
30
+ end
31
+
32
+ if apple_clang
33
+ # silence ruby/intern.h warning
34
+ $CXXFLAGS += " -Wno-deprecated-register"
35
+
36
+ # silence torch warnings
37
+ $CXXFLAGS += " -Wno-shorten-64-to-32 -Wno-missing-noreturn"
38
+ else
39
+ # silence rice warnings
40
+ $CXXFLAGS += " -Wno-noexcept-type"
41
+
42
+ # silence torch warnings
43
+ $CXXFLAGS += " -Wno-duplicated-cond -Wno-suggest-attribute=noreturn"
44
+ end
45
+ # LibTorch location: values from dir_config("torch") win, otherwise /usr/local is assumed
46
+ inc, lib = dir_config("torch")
47
+ inc ||= "/usr/local/include"
48
+ lib ||= "/usr/local/lib"
49
+ # CUDA location, defaulting to /usr/local/cuda (cuda_inc is not referenced again below)
50
+ cuda_inc, cuda_lib = dir_config("cuda")
51
+ cuda_inc ||= "/usr/local/cuda/include"
52
+ cuda_lib ||= "/usr/local/cuda/lib64"
53
+ # add the LibTorch lib dir to the link search path before probing for the library
54
+ $LDFLAGS += " -L#{lib}" if Dir.exist?(lib)
55
+ abort "LibTorch not found" unless have_library("torch")
56
+ # optional libraries: probed, but a missing one does not abort the build
57
+ have_library("mkldnn")
58
+ have_library("nnpack")
59
+ # CUDA is only probed when the LibTorch lib dir contains a torch_cuda library
60
+ with_cuda = false
61
+ if Dir["#{lib}/*torch_cuda*"].any?
62
+ $LDFLAGS += " -L#{cuda_lib}" if Dir.exist?(cuda_lib)
63
+ with_cuda = have_library("cuda") && have_library("cudnn")
64
+ end
65
+ # header search paths: the LibTorch include dir and its torch/csrc/api/include subdirectory
66
+ $INCFLAGS += " -I#{inc}"
67
+ $INCFLAGS += " -I#{inc}/torch/csrc/api/include"
68
+ # runtime library search path; with CUDA the extra dirs are appended to the same -rpath argument
69
+ $LDFLAGS += " -Wl,-rpath,#{lib}"
70
+ $LDFLAGS += ":#{cuda_lib}/stubs:#{cuda_lib}" if with_cuda
71
+
72
+ # https://github.com/pytorch/pytorch/blob/v1.5.0/torch/utils/cpp_extension.py#L1232-L1238
73
+ $LDFLAGS += " -lc10 -ltorch_cpu -ltorch"
74
+ if with_cuda
75
+ $LDFLAGS += " -lcuda -lnvrtc -lnvToolsExt -lcudart -lc10_cuda -ltorch_cuda -lcufft -lcurand -lcublas -lcudnn"
76
+ # TODO figure out why this is needed
77
+ $LDFLAGS += " -Wl,--no-as-needed,#{lib}/libtorch.so"
78
+ end
79
+
80
+ # create makefile
81
+ create_makefile("torchaudio/ext")
@@ -0,0 +1,95 @@
1
+ # dependencies
2
+ require "torch"
3
+
4
+ # ext
5
+ require "torchaudio/ext"
6
+
7
+ # stdlib
8
+ require "csv"
9
+ require "digest"
10
+ require "fileutils"
11
+ require "rubygems/package"
12
+ require "set"
13
+
14
+ # modules
15
+ require "torchaudio/datasets/utils"
16
+ require "torchaudio/datasets/yesno"
17
+ require "torchaudio/version"
18
+
19
+ module TorchAudio
20
+ class Error < StandardError; end
21
+
22
+ class << self
23
+ def load(
24
+ filepath, out: nil, normalization: true, channels_first: true, num_frames: 0,
25
+ offset: 0, signalinfo: nil, encodinginfo: nil, filetype: nil
26
+ )
27
+
28
+ filepath = filepath.to_s
29
+
30
+ # check if valid file
31
+ unless File.exist?(filepath)
32
+ raise ArgumentError, "#{filepath} not found or is a directory"
33
+ end
34
+
35
+ # initialize output tensor
36
+ if !out.nil?
37
+ check_input(out)
38
+ else
39
+ out = Torch::FloatTensor.new
40
+ end
41
+
42
+ if num_frames < -1
43
+ raise ArgumentError, "Expected value for num_samples -1 (entire file) or >=0"
44
+ end
45
+ if offset < 0
46
+ raise ArgumentError, "Expected positive offset value"
47
+ end
48
+
49
+ # same logic as C++
50
+ # could also make read_audio_file work with nil
51
+ filetype ||= File.extname(filepath)[1..-1]
52
+
53
+ sample_rate =
54
+ Ext.read_audio_file(
55
+ filepath,
56
+ out,
57
+ channels_first,
58
+ num_frames,
59
+ offset,
60
+ signalinfo,
61
+ encodinginfo,
62
+ filetype
63
+ )
64
+
65
+ # normalize if needed
66
+ normalize_audio(out, normalization)
67
+
68
+ [out, sample_rate]
69
+ end
70
+
71
+ def load_wav(filepath, **kwargs)
72
+ kwargs[:normalization] = 1 << 16
73
+ load(filepath, **kwargs)
74
+ end
75
+
76
+ private
77
+ # Validates a caller-supplied output tensor: it must be a Torch tensor and must not be a CUDA tensor.
78
+ def check_input(src)
79
+ raise ArgumentError, "Expected a tensor, got #{src.class.name}" unless Torch.tensor?(src)
80
+ raise ArgumentError, "Expected a CPU based tensor, got #{src.class.name}" if src.cuda?
81
+ end
82
+
83
+ def normalize_audio(signal, normalization)
84
+ return unless normalization
85
+
86
+ normalization = 1 << 31 if normalization == true
87
+
88
+ if normalization.is_a?(Numeric)
89
+ signal.div!(normalization)
90
+ elsif normalization.respond_to?(:call)
91
+ signal.div!(normalization.call(signal))
92
+ end
93
+ end
94
+ end
95
+ end
@@ -0,0 +1,92 @@
1
+ module TorchAudio
2
+ module Datasets
3
+ module Utils
4
+ class << self
5
+ def download_url(url, download_folder, filename: nil, hash_value: nil, hash_type: "sha256")
6
+ filename ||= File.basename(url)
7
+ filepath = File.join(download_folder, filename)
8
+
9
+ if File.exist?(filepath)
10
+ raise "#{filepath} already exists. Delete the file manually and retry."
11
+ end
12
+
13
+ puts "Downloading #{url}..."
14
+ download_url_to_file(url, filepath, hash_value, hash_type)
15
+ end
16
+
17
+ # Streams `url` to a temp file (following redirects), verifies its MD5 digest, moves it to `dst` and returns dst.
18
+ def download_url_to_file(url, dst, hash_value, hash_type)
19
+ uri = URI(url)
20
+ tmp = nil
21
+ location = nil
22
+ # a redirect sets `location`; a success writes the body to `tmp`; anything else raises
23
+ Net::HTTP.start(uri.host, uri.port, use_ssl: uri.scheme == "https") do |http|
24
+ request = Net::HTTP::Get.new(uri)
25
+
26
+ http.request(request) do |response|
27
+ case response
28
+ when Net::HTTPRedirection
29
+ location = response["location"]
30
+ when Net::HTTPSuccess
31
+ tmp = "#{Dir.tmpdir}/#{Time.now.to_f}" # TODO better name
32
+ File.open(tmp, "wb") do |f|
33
+ response.read_body do |chunk|
34
+ f.write(chunk)
35
+ end
36
+ end
37
+ else
38
+ raise Error, "Bad response"
39
+ end
40
+ end
41
+ end
42
+
43
+ if location
44
+ download_url_to_file(location, dst, hash_value, hash_type) # all four arguments are required
45
+ else
46
+ # check hash, unless the caller supplied none (hash_value is nil)
47
+ # TODO use hash_type (the digest is always MD5 for now)
48
+ if hash_value && Digest::MD5.file(tmp).hexdigest != hash_value
49
+ raise "The hash of #{dst} does not match. Delete the file manually and retry."
50
+ end
51
+
52
+ FileUtils.mv(tmp, dst)
53
+ dst
54
+ end
55
+ end
56
+
57
+ # Extracts a .tar.gz/.tgz archive into to_path; extract_tar_gz doesn't list files, so just return to_path
58
+ def extract_archive(from_path, to_path: nil, overwrite: nil)
59
+ to_path ||= File.dirname(from_path) # default: the directory containing the archive
60
+ # NOTE: the `overwrite` keyword is accepted but never read in this method
61
+ if from_path.end_with?(".tar.gz") || from_path.end_with?(".tgz")
62
+ File.open(from_path, "rb") do |io|
63
+ Gem::Package.new("").extract_tar_gz(io, to_path)
64
+ end
65
+ return to_path
66
+ end
67
+ # any other file name is rejected
68
+ raise "We currently only support tar.gz and tgz archives."
69
+ end
70
+
71
+ def walk_files(root, suffix, prefix: false, remove_suffix: false)
72
+ return enum_for(:walk_files, root, suffix, prefix: prefix, remove_suffix: remove_suffix) unless block_given?
73
+
74
+ Dir.glob("**/*", base: root).sort.each do |f|
75
+ if f.end_with?(suffix)
76
+ if remove_suffix
77
+ f = f[0..(-suffix.length - 1)]
78
+ end
79
+
80
+ if prefix
81
+ raise "Not implemented yet"
82
+ # f = File.join(dirpath, f)
83
+ end
84
+
85
+ yield f
86
+ end
87
+ end
88
+ end
89
+ end
90
+ end
91
+ end
92
+ end
@@ -0,0 +1,59 @@
1
+ module TorchAudio
2
+ module Datasets
3
+ class YESNO < Torch::Utils::Data::Dataset # dataset over the .wav files of the waves_yesno archive
4
+ URL = "http://www.openslr.org/resources/1/waves_yesno.tar.gz"
5
+ FOLDER_IN_ARCHIVE = "waves_yesno"
6
+ CHECKSUMS = {
7
+ "http://www.openslr.org/resources/1/waves_yesno.tar.gz" => "962ff6e904d2df1126132ecec6978786"
8
+ }
9
+ # Prepares the dataset under `root`; with download: true the archive is fetched and extracted when the folder is missing.
10
+ def initialize(root, url: URL, folder_in_archive: FOLDER_IN_ARCHIVE, download: false)
11
+ archive = File.basename(url)
12
+ archive = File.join(root, archive)
13
+ @path = File.join(root, folder_in_archive)
14
+ # download only if the extracted folder is absent; an archive already on disk is reused
15
+ if download
16
+ unless Dir.exist?(@path)
17
+ unless File.exist?(archive)
18
+ checksum = CHECKSUMS.fetch(url) # raises KeyError for a url that is not listed in CHECKSUMS
19
+ Utils.download_url(url, root, hash_value: checksum, hash_type: "md5")
20
+ end
21
+ Utils.extract_archive(archive)
22
+ end
23
+ end
24
+ # from here on the extracted folder must exist
25
+ unless Dir.exist?(@path)
26
+ raise "Dataset not found. Please use `download: true` to download it."
27
+ end
28
+ # materialize the list of file ids (paths relative to @path with ".wav" removed)
29
+ walker = Utils.walk_files(@path, ext_audio, prefix: false, remove_suffix: true)
30
+ @walker = walker.to_a
31
+ end
32
+ # Returns [waveform, sample_rate, labels] for the n-th file id.
33
+ def [](n)
34
+ fileid = @walker[n]
35
+ load_yesno_item(fileid, @path, ext_audio)
36
+ end
37
+ # Number of file ids collected at construction time.
38
+ def length
39
+ @walker.length
40
+ end
41
+ alias_method :size, :length
42
+
43
+ private
44
+ # Loads one item; labels are the "_"-separated parts of the file id converted with to_i.
45
+ def load_yesno_item(fileid, path, ext_audio)
46
+ labels = fileid.split("_").map(&:to_i)
47
+
48
+ file_audio = File.join(path, fileid + ext_audio)
49
+ waveform, sample_rate = TorchAudio.load(file_audio)
50
+
51
+ [waveform, sample_rate, labels]
52
+ end
53
+ # File extension used both for discovering files and for building the path to load.
54
+ def ext_audio
55
+ ".wav"
56
+ end
57
+ end
58
+ end
59
+ end