gtcrn 0.0.3 → 0.0.4

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: e838b8e452d988facb9cf6cfaa99c8a1b4e9c48073d7a035a07d3b3e42461ec8
4
- data.tar.gz: 433c73cbe706d29786d62499ad17eb8b629ab5f5fc640aa3c8e999a56aa46397
3
+ metadata.gz: 996bfc8a312128c55b4eee2b91f908a612ccbcc619711168bd0468d92018a79b
4
+ data.tar.gz: 63831dc111ee4d3016c99915aa4dc9cf81b440bb04374c0f64f36402ce8ea1c8
5
5
  SHA512:
6
- metadata.gz: 21fd283b28ca5b35b7f9f02dab22cef8d4e342af47e7556594b97e70643351e18ff77a9da542e6b5d3963fbb812922740f00f6067b780e95409eb081111b8806
7
- data.tar.gz: bc8d8aa599c3b53ef4801bebcb1ac5ef49e11349c71753292d4eeeecfc65bfb495d40a65df5cbe4ba8d14e6db17fe7e16ff25ff01417484a8db0f981ab8dac36
6
+ metadata.gz: 55836e22a48b395dc403b0340d14d3a7e8ac107902756a6b685bc09514290b8571c2b50e37ac6dcb67cfcc6c716ad3476be847e1fbff9648f324a0071da701ff
7
+ data.tar.gz: 38558c34bc1ef02473b84c6d74a0397030153aa6cd09393c408f704eebb22e91898e920e6a72617b7904e840489d8d94b27216418e13ffe04b54b332f2a1ee8c
data/README.md CHANGED
@@ -23,7 +23,8 @@ INSTALLATION
23
23
 
24
24
  This gem depends on [Torch.rb][], [TorchAudio Ruby][] and [TorchCodec Ruby][], which require a precompiled libtorch and must be built against it.
25
25
 
26
- % wget https://download.pytorch.org/libtorch/cpu/libtorch-macos-arm64-2.10.0.zip # See https://pytorch.org/get-started/locally/ for download URI for your environment
26
+ # See https://pytorch.org/get-started/locally/ for download URI for your environment
27
+ % wget https://download.pytorch.org/libtorch/cpu/libtorch-macos-arm64-2.10.0.zip
27
28
  % unzip -d path/to/libtorch libtorch-macos-arm64-2.10.0.zip
28
29
  % gem install torch-rb -- --with-torch-dir=path/to/libtorch
29
30
  % gem install torchaudio -- --with-torch-dir=path/to/libtorch
data/bin/gtcrn CHANGED
@@ -1,4 +1,5 @@
1
1
  require "optparse"
2
+ require "optparse/pathname"
2
3
  require "gtcrn"
3
4
 
4
5
  def main(argv)
@@ -18,7 +19,7 @@ def parse_options(argv)
18
19
 
19
20
  Usage: gtcrn [options] INPUT
20
21
  EOB
21
- opt.on "-o", "--output=PATH", "Specify output file or directory" do |path|
22
+ opt.on "-o", "--output=PATH", "Specify output file or directory", Pathname do |path|
22
23
  options[:output] = path
23
24
  end
24
25
  }.parse!(argv)
data/gtcrn.gemspec CHANGED
@@ -1,6 +1,6 @@
1
1
  Gem::Specification.new do |s|
2
2
  s.name = "gtcrn"
3
- s.version = "0.0.3"
3
+ s.version = "0.0.4"
4
4
  s.authors = ["Kitaiti Makoto"]
5
5
  s.email = ["KitaitiMakoto@gmail.com"]
6
6
  s.summary = "Denoises audio"
@@ -18,7 +18,8 @@ Gem::Specification.new do |s|
18
18
  s.add_runtime_dependency "torch-rb"
19
19
  s.add_runtime_dependency "torchaudio", ">= 0.5.0"
20
20
  s.add_runtime_dependency "torchcodec"
21
- s.add_runtime_dependency "numo-narray-alt"
21
+ s.add_runtime_dependency "optparse-pathname"
22
+ s.add_runtime_dependency "onnxruntime-torch-tensor"
22
23
 
23
24
  s.add_development_dependency "rake"
24
25
  s.add_development_dependency "rubygems-tasks"
data/lib/gtcrn.rb CHANGED
@@ -2,9 +2,11 @@ require "pathname"
2
2
  require "torchaudio"
3
3
  require "torch"
4
4
  require "onnxruntime"
5
- require "numo/narray/alt"
5
+ require "onnxruntime/torch/tensor"
6
6
 
7
7
  class GTCRN
8
+ include NDAV::Converter
9
+
8
10
  MODEL_PATH = File.join(__dir__, "../vendor/gtcrn/stream/onnx_models/gtcrn_simple.onnx").freeze
9
11
  ISTFT_OPTS = {
10
12
  n_fft: 512,
@@ -19,6 +21,7 @@ class GTCRN
19
21
 
20
22
  def initialize
21
23
  @session = OnnxRuntime::InferenceSession.new(MODEL_PATH)
24
+ @cache_shapes = @session.inputs[1..].collect {|input| input[:shape]}
22
25
  @output_names = @session.outputs.collect {|output| output[:name]}
23
26
  end
24
27
 
@@ -44,7 +47,7 @@ class GTCRN
44
47
  def enhance_speech_waveform(waveform)
45
48
  ndim = waveform.ndim
46
49
  unless ndim == 1 or ndim == 2
47
- raise ArgumentError, "wrong dimension of argment (given #{ndim}, expected 1D or 2D"
50
+ raise ArgumentError, "wrong dimension of argment (given #{ndim}, expected 1D or 2D)"
48
51
  end
49
52
  waveform = [waveform] if ndim == 1
50
53
  channels = waveform.collect {|channel| enhance_speech_waveform_channel(channel)}
@@ -52,33 +55,25 @@ class GTCRN
52
55
  end
53
56
 
54
57
  def enhance_speech_waveform_channel(channel)
55
- conv_cache, tra_cache, inter_cache = 1.upto(3).collect {|i|
56
- OnnxRuntime::OrtValue.from_numo(
57
- Numo::SFloat.zeros(*@session.inputs[i][:shape])
58
- )
58
+ conv_cache, tra_cache, inter_cache = @cache_shapes.collect {|shape|
59
+ OrtValue(Torch.zeros(*shape, dtype: :float32))
59
60
  }
60
61
  inputs = Torch.view_as_real(
61
62
  Torch.stft(channel, **STFT_OPTS)[nil]
62
- ).numo
63
+ )
63
64
  outputs = []
64
65
  inputs.shape[-2].times do |i|
66
+ input = inputs[0.., 0.., i..i, 0..]
65
67
  enh, conv_cache, tra_cache, inter_cache = @session.run(
66
68
  @output_names,
67
- {
68
- mix: OnnxRuntime::OrtValue.from_numo(inputs[0.., 0.., i..i, 0..]),
69
- conv_cache:, tra_cache:, inter_cache:,
70
- },
69
+ {mix: OrtValue(input), conv_cache:, tra_cache:, inter_cache:},
71
70
  output_type: :ort_value
72
71
  )
73
- outputs << enh.numo
72
+ outputs << TorchTensor(enh)
74
73
  end
75
- concated = Numo::NArray.concatenate(outputs, axis: 2)
76
- real = concated[0.., 0.., 0.., 0]
77
- imag = concated[0.., 0.., 0.., 1]
78
- enhanced = Torch.istft(
79
- Torch.complex(Torch.from_numo(real), Torch.from_numo(imag)),
80
- **ISTFT_OPTS
81
- )
82
- enhanced.squeeze(0)
74
+ concated = Torch.cat(outputs, dim: 2)
75
+ Torch
76
+ .istft(Torch.view_as_complex(concated), **ISTFT_OPTS)
77
+ .squeeze(0)
83
78
  end
84
79
  end
data/test/test_gtcrn.rb CHANGED
@@ -26,7 +26,7 @@ class TestGTCRN < Test::Unit::TestCase
26
26
  assert_equal waveform.ndim, enhanced.ndim
27
27
  end
28
28
 
29
- def test_enhance_speech_waveform_two_dim
29
+ def test_enhance_speech_waveform_multi_dim
30
30
  channels = 5
31
31
  gtcrn = GTCRN.new
32
32
  waveform = Torch.rand(channels, 16000)
@@ -36,7 +36,7 @@ class TestGTCRN < Test::Unit::TestCase
36
36
 
37
37
  0.upto(channels - 1) do |i|
38
38
  enh = gtcrn.enhance_speech_waveform(waveform[i])
39
- assert enh.equal(enhanced[i])
39
+ assert enh.equal(enhanced[i]), "channel #{i}"
40
40
  end
41
41
  end
42
42
 
metadata CHANGED
@@ -1,7 +1,7 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: gtcrn
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.0.3
4
+ version: 0.0.4
5
5
  platform: ruby
6
6
  authors:
7
7
  - Kitaiti Makoto
@@ -66,7 +66,21 @@ dependencies:
66
66
  - !ruby/object:Gem::Version
67
67
  version: '0'
68
68
  - !ruby/object:Gem::Dependency
69
- name: numo-narray-alt
69
+ name: optparse-pathname
70
+ requirement: !ruby/object:Gem::Requirement
71
+ requirements:
72
+ - - ">="
73
+ - !ruby/object:Gem::Version
74
+ version: '0'
75
+ type: :runtime
76
+ prerelease: false
77
+ version_requirements: !ruby/object:Gem::Requirement
78
+ requirements:
79
+ - - ">="
80
+ - !ruby/object:Gem::Version
81
+ version: '0'
82
+ - !ruby/object:Gem::Dependency
83
+ name: onnxruntime-torch-tensor
70
84
  requirement: !ruby/object:Gem::Requirement
71
85
  requirements:
72
86
  - - ">="