sonusai 0.18.9__py3-none-any.whl → 0.19.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (118)
  1. sonusai/__init__.py +20 -29
  2. sonusai/aawscd_probwrite.py +18 -18
  3. sonusai/audiofe.py +93 -80
  4. sonusai/calc_metric_spenh.py +395 -321
  5. sonusai/data/genmixdb.yml +5 -11
  6. sonusai/{gentcst.py → deprecated/gentcst.py} +146 -149
  7. sonusai/{plot.py → deprecated/plot.py} +177 -131
  8. sonusai/{tplot.py → deprecated/tplot.py} +124 -102
  9. sonusai/doc/__init__.py +1 -1
  10. sonusai/doc/doc.py +112 -177
  11. sonusai/doc.py +10 -10
  12. sonusai/genft.py +93 -77
  13. sonusai/genmetrics.py +59 -46
  14. sonusai/genmix.py +116 -104
  15. sonusai/genmixdb.py +194 -153
  16. sonusai/lsdb.py +56 -66
  17. sonusai/main.py +23 -20
  18. sonusai/metrics/__init__.py +2 -0
  19. sonusai/metrics/calc_audio_stats.py +29 -24
  20. sonusai/metrics/calc_class_weights.py +7 -7
  21. sonusai/metrics/calc_optimal_thresholds.py +5 -7
  22. sonusai/metrics/calc_pcm.py +3 -3
  23. sonusai/metrics/calc_pesq.py +10 -7
  24. sonusai/metrics/calc_phase_distance.py +3 -3
  25. sonusai/metrics/calc_sa_sdr.py +10 -8
  26. sonusai/metrics/calc_segsnr_f.py +15 -17
  27. sonusai/metrics/calc_speech.py +105 -47
  28. sonusai/metrics/calc_wer.py +35 -32
  29. sonusai/metrics/calc_wsdr.py +10 -7
  30. sonusai/metrics/class_summary.py +30 -27
  31. sonusai/metrics/confusion_matrix_summary.py +25 -22
  32. sonusai/metrics/one_hot.py +91 -57
  33. sonusai/metrics/snr_summary.py +53 -46
  34. sonusai/mixture/__init__.py +19 -14
  35. sonusai/mixture/audio.py +4 -6
  36. sonusai/mixture/augmentation.py +37 -43
  37. sonusai/mixture/class_count.py +5 -14
  38. sonusai/mixture/config.py +292 -225
  39. sonusai/mixture/constants.py +41 -30
  40. sonusai/mixture/data_io.py +155 -0
  41. sonusai/mixture/datatypes.py +111 -108
  42. sonusai/mixture/db_datatypes.py +54 -70
  43. sonusai/mixture/eq_rule_is_valid.py +6 -9
  44. sonusai/mixture/feature.py +40 -38
  45. sonusai/mixture/generation.py +522 -389
  46. sonusai/mixture/helpers.py +217 -272
  47. sonusai/mixture/log_duration_and_sizes.py +16 -13
  48. sonusai/mixture/mixdb.py +669 -477
  49. sonusai/mixture/soundfile_audio.py +12 -17
  50. sonusai/mixture/sox_audio.py +91 -112
  51. sonusai/mixture/sox_augmentation.py +8 -9
  52. sonusai/mixture/spectral_mask.py +4 -6
  53. sonusai/mixture/target_class_balancing.py +41 -36
  54. sonusai/mixture/targets.py +69 -67
  55. sonusai/mixture/tokenized_shell_vars.py +23 -23
  56. sonusai/mixture/torchaudio_audio.py +14 -15
  57. sonusai/mixture/torchaudio_augmentation.py +23 -27
  58. sonusai/mixture/truth.py +48 -26
  59. sonusai/mixture/truth_functions/__init__.py +26 -0
  60. sonusai/mixture/truth_functions/crm.py +56 -38
  61. sonusai/mixture/truth_functions/datatypes.py +37 -0
  62. sonusai/mixture/truth_functions/energy.py +85 -59
  63. sonusai/mixture/truth_functions/file.py +30 -30
  64. sonusai/mixture/truth_functions/phoneme.py +14 -7
  65. sonusai/mixture/truth_functions/sed.py +71 -45
  66. sonusai/mixture/truth_functions/target.py +69 -106
  67. sonusai/mkwav.py +52 -85
  68. sonusai/onnx_predict.py +46 -43
  69. sonusai/queries/__init__.py +3 -1
  70. sonusai/queries/queries.py +100 -59
  71. sonusai/speech/__init__.py +2 -0
  72. sonusai/speech/l2arctic.py +24 -23
  73. sonusai/speech/librispeech.py +16 -17
  74. sonusai/speech/mcgill.py +22 -21
  75. sonusai/speech/textgrid.py +32 -25
  76. sonusai/speech/timit.py +45 -42
  77. sonusai/speech/vctk.py +14 -13
  78. sonusai/speech/voxceleb.py +26 -20
  79. sonusai/summarize_metric_spenh.py +11 -10
  80. sonusai/utils/__init__.py +4 -3
  81. sonusai/utils/asl_p56.py +1 -1
  82. sonusai/utils/asr.py +37 -17
  83. sonusai/utils/asr_functions/__init__.py +2 -0
  84. sonusai/utils/asr_functions/aaware_whisper.py +18 -12
  85. sonusai/utils/audio_devices.py +12 -12
  86. sonusai/utils/braced_glob.py +6 -8
  87. sonusai/utils/calculate_input_shape.py +1 -4
  88. sonusai/utils/compress.py +2 -2
  89. sonusai/utils/convert_string_to_number.py +1 -3
  90. sonusai/utils/create_timestamp.py +1 -1
  91. sonusai/utils/create_ts_name.py +2 -2
  92. sonusai/utils/dataclass_from_dict.py +1 -1
  93. sonusai/utils/docstring.py +6 -6
  94. sonusai/utils/energy_f.py +9 -7
  95. sonusai/utils/engineering_number.py +56 -54
  96. sonusai/utils/get_label_names.py +8 -10
  97. sonusai/utils/human_readable_size.py +2 -2
  98. sonusai/utils/model_utils.py +3 -5
  99. sonusai/utils/numeric_conversion.py +2 -4
  100. sonusai/utils/onnx_utils.py +43 -32
  101. sonusai/utils/parallel.py +40 -27
  102. sonusai/utils/print_mixture_details.py +25 -22
  103. sonusai/utils/ranges.py +12 -12
  104. sonusai/utils/read_predict_data.py +11 -9
  105. sonusai/utils/reshape.py +19 -26
  106. sonusai/utils/seconds_to_hms.py +1 -1
  107. sonusai/utils/stacked_complex.py +8 -16
  108. sonusai/utils/stratified_shuffle_split.py +29 -27
  109. sonusai/utils/write_audio.py +2 -2
  110. sonusai/utils/yes_or_no.py +3 -3
  111. sonusai/vars.py +14 -14
  112. {sonusai-0.18.9.dist-info → sonusai-0.19.5.dist-info}/METADATA +20 -21
  113. sonusai-0.19.5.dist-info/RECORD +125 -0
  114. {sonusai-0.18.9.dist-info → sonusai-0.19.5.dist-info}/WHEEL +1 -1
  115. sonusai/mixture/truth_functions/data.py +0 -58
  116. sonusai/utils/read_mixture_data.py +0 -14
  117. sonusai-0.18.9.dist-info/RECORD +0 -125
  118. {sonusai-0.18.9.dist-info → sonusai-0.19.5.dist-info}/entry_points.txt +0 -0
sonusai/mixture/truth_functions/sed.py CHANGED
@@ -1,65 +1,91 @@
1
1
  from sonusai.mixture.datatypes import Truth
2
- from sonusai.mixture.truth_functions.data import Data
2
+ from sonusai.mixture.truth_functions.datatypes import TruthFunctionConfig
3
+ from sonusai.mixture.truth_functions.datatypes import TruthFunctionData
3
4
 
4
5
 
5
- def sed(data: Data) -> Truth:
6
+ def _strictly_decreasing(list_to_check: list) -> bool:
7
+ from itertools import pairwise
8
+
9
+ return all(x > y for x, y in pairwise(list_to_check))
10
+
11
+
12
def sed_validate(config: dict) -> None:
    """Check the configuration dictionary for the sed truth function.

    Raises:
        AttributeError: if config is empty or lacks a required key ('thresholds').
        ValueError: if config["thresholds"] is not strictly decreasing.
    """
    if len(config) == 0:
        raise AttributeError("sed truth function is missing config")

    # Collect required keys that are absent, preserving their declared order,
    # and report the first one.
    missing = [key for key in ("thresholds",) if key not in config]
    if missing:
        raise AttributeError(f"sed truth function is missing required '{missing[0]}'")

    thresholds = config["thresholds"]
    if _strictly_decreasing(thresholds):
        return
    raise ValueError(f"sed truth function 'thresholds' are not strictly decreasing: {thresholds}")
def sed_parameters(config: TruthFunctionConfig) -> int:
    """Return the number of truth columns produced by sed: one per class."""
    num_outputs = config.num_classes
    return num_outputs
def sed(data: TruthFunctionData, config: TruthFunctionConfig) -> Truth:
    """Sound energy detection truth generation function

    Produces sound energy detection (SED) truth by running pyaaware's ``SED``
    three-threshold hysteresis detector over the per-frame energy of the
    target. Each output value is one of three presence probabilities:
    1.0 (present), 0.5 (transition/uncertain) or 0.0 (not present). Values are
    placed in the truth columns selected by ``config.class_indices``.

    Output shape: [frames, num_classes]

    ``config.class_indices`` holds 1-based indices (the form ``SED`` expects):
    0 indicates none, 1 is the first element of the truth vector, 2 the
    second, and so on. For example, [5] sets truth column 4 and [1, 5] sets
    columns 0 and 4.

    If ``config.target_gain`` is 0, all-zero truth is returned without running
    the detector (this also avoids dividing by a zero gain below).

    NOTE(review): no mutex flag is passed to ``SED`` any more; confirm how
    mutually-exclusive (single-label, softmax-style) truth is now handled
    versus multi-label (sigmoid-style) truth.

    Raises:
        ValueError: if the target length is not a multiple of config.frame_size,
            or if the computed frame count disagrees with the energy frames.
    """
    import numpy as np
    import torch
    from pyaaware import SED

    frame_size = config.frame_size
    if len(data.target_audio) % frame_size != 0:
        raise ValueError(f"Number of samples in audio is not a multiple of {frame_size}")

    num_frames = config.target_fft.frames(data.target_audio)
    num_outputs = sed_parameters(config)
    if config.target_gain == 0:
        return np.zeros((num_frames, num_outputs), dtype=np.float32)

    # SED wants 1-based indices
    detector = SED(
        thresholds=config.config["thresholds"],
        index=config.class_indices,
        frame_size=frame_size,
        num_classes=config.num_classes,
    )

    # Back out the target gain before measuring energy.
    unscaled_target = data.target_audio / config.target_gain

    # Element [1] of the transform result is used as the per-frame energy.
    energy = config.target_fft.execute_all(torch.from_numpy(unscaled_target))[1].numpy()

    if energy.shape[0] != num_frames:
        raise ValueError("Incorrect frames calculation in sed truth function")

    return detector.execute_all(energy)
sonusai/mixture/truth_functions/target.py CHANGED
@@ -1,146 +1,109 @@
1
- from sonusai import ForwardTransform
2
-
3
1
  from sonusai.mixture.datatypes import AudioF
4
- from sonusai.mixture.datatypes import AudioT
5
2
  from sonusai.mixture.datatypes import Truth
6
- from sonusai.mixture.truth_functions.data import Data
7
-
8
-
9
- def target_f(data: Data) -> Truth:
10
- """Frequency domain target truth function
3
+ from sonusai.mixture.truth_functions.datatypes import TruthFunctionConfig
4
+ from sonusai.mixture.truth_functions.datatypes import TruthFunctionData
11
5
 
12
- Calculates the true transform of the target using the STFT
13
- configuration defined by the feature. This will include a
14
- forward transform window if defined by the feature.
15
6
 
16
- Output shape: [:, num_classes]
17
- (target stacked real, imag; or real only for tdac-co)
18
- """
19
- from sonusai import SonusAIError
7
def target_f_validate(_config: dict) -> None:
    """Validate the target_f config dictionary.

    target_f reads no options from this dictionary, so any config is accepted.
    """
    return None
def target_f_parameters(config: TruthFunctionConfig) -> int:
    """Return the truth width produced by target_f.

    tdac-co keeps only the real part (one value per bin); every other ttype
    stacks real and imaginary parts (two values per bin).
    """
    values_per_bin = 1 if config.ttype == "tdac-co" else 2
    return values_per_bin * config.target_fft.bins
def target_f(data: TruthFunctionData, config: TruthFunctionConfig) -> Truth:
    """Frequency domain target truth function

    Transforms the target audio with ``config.target_fft`` (the STFT
    configuration defined by the feature, including any forward transform
    window it defines) and converts the complex result to real-valued truth.

    Output shape: [:, 2 * bins] (target stacked real, imag) or
                  [:, bins] (target real only for tdac-co)
    """
    import torch

    target_tensor = torch.from_numpy(data.target_audio)
    spectrum = config.target_fft.execute_all(target_tensor)[0].numpy()
    return _stack_real_imag(spectrum, config.ttype)
def target_mixture_f_validate(_config: dict) -> None:
    """Validate the target_mixture_f config dictionary.

    target_mixture_f reads no options from this dictionary, so any config is
    accepted.
    """
    return None
def target_mixture_f_parameters(config: TruthFunctionConfig) -> int:
    """Return the truth width produced by target_mixture_f.

    Two signals (target, then mixture) each contribute one value per bin for
    tdac-co (real only) or two values per bin otherwise (real and imaginary).
    """
    values_per_bin = 1 if config.ttype == "tdac-co" else 2
    return 2 * values_per_bin * config.target_fft.bins
def target_mixture_f(data: TruthFunctionData, config: TruthFunctionConfig) -> Truth:
    """Frequency domain target and mixture truth function

    Calculates the true transform of the target and the mixture
    using the STFT configuration defined by the feature. This
    will include a forward transform window if defined by the
    feature.

    Output shape: [:, 4 * bins] (target stacked real, imag; mixture stacked real, imag) or
                  [:, 2 * bins] (target real; mixture real for tdac-co)

    Target columns come first, followed by mixture columns; the total width
    matches target_mixture_f_parameters().
    """
    import numpy as np
    import torch

    target_freq = config.target_fft.execute_all(torch.from_numpy(data.target_audio))[0].numpy()
    mixture_freq = config.mixture_fft.execute_all(torch.from_numpy(data.mixture_audio))[0].numpy()

    frames, bins = target_freq.shape
    # _stack_real_imag keeps only the real part for tdac-co (bins columns) and
    # stacks real/imag otherwise (2 * bins columns). Each half must be sized to
    # match, otherwise the tdac-co assignment fails with a shape mismatch.
    width = bins if config.ttype == "tdac-co" else bins * 2
    truth = np.empty((frames, width * 2), dtype=np.float32)
    truth[:, :width] = _stack_real_imag(target_freq, config.ttype)
    truth[:, width:] = _stack_real_imag(mixture_freq, config.ttype)
    return truth
def target_swin_f_validate(_config: dict) -> None:
    """Validate the target_swin_f config dictionary.

    target_swin_f reads no options from this dictionary, so any config is
    accepted.
    """
    return None
def target_swin_f_parameters(config: TruthFunctionConfig) -> int:
    """Return the truth width produced by target_swin_f: real and imaginary parts per bin."""
    bins = config.target_fft.bins
    return bins * 2
def target_swin_f(data: TruthFunctionData, config: TruthFunctionConfig) -> Truth:
    """Frequency domain target with synthesis window truth function

    Calculates the true transform of the target using the STFT
    configuration defined by the feature. This will include a
    forward transform window if defined by the feature and also
    the inverse transform (or synthesis) window.

    Each frame of config.frame_size samples is multiplied by config.swin,
    transformed with config.target_fft.execute, and stored as one truth row.

    Output shape: [:, 2 * bins] (stacked real, imag)

    Raises:
        ValueError: if the number of target samples is not a multiple of
            config.frame_size.
    """
    import numpy as np

    from sonusai.utils import stack_complex

    num_samples = len(data.target_audio)
    # Only whole frames can be processed: truth has num_samples // frame_size rows,
    # so a trailing partial frame would otherwise fail later with a less
    # descriptive broadcast or index error.
    if num_samples % config.frame_size != 0:
        raise ValueError(f"Number of samples in audio is not a multiple of {config.frame_size}")

    frames = num_samples // config.frame_size
    truth = np.empty((frames, config.target_fft.bins * 2), dtype=np.float32)
    for idx in range(frames):
        offset = idx * config.frame_size
        target_freq = config.target_fft.execute(
            np.multiply(data.target_audio[offset : offset + config.frame_size], config.swin)
        )[0]
        truth[idx] = stack_complex(target_freq)

    return truth
def _stack_real_imag(data: AudioF, ttype: str) -> Truth:
    """Convert complex frequency data into real-valued truth columns.

    For ttype "tdac-co" only the real part is kept. For every other ttype the
    data is passed through sonusai.utils.stack_complex, which (per the callers'
    documented output shapes) stacks real and imaginary parts.
    """
    import numpy as np

    from sonusai.utils import stack_complex

    if ttype != "tdac-co":
        return stack_complex(data)
    return np.real(data)
sonusai/mkwav.py CHANGED
@@ -16,17 +16,18 @@ Inputs:
16
16
  MIXID A glob of mixture ID(s) to generate.
17
17
 
18
18
  Outputs the following to the mixture database directory:
19
- <id>_mixture.wav: mixture
20
- <id>_target.wav: target (optional)
21
- <id>_noise.wav: noise (optional)
22
- <id>.txt
19
+ <id>
20
+ mixture.wav: mixture
21
+ target.wav: target (optional)
22
+ noise.wav: noise (optional)
23
+ metadata.txt
23
24
  mkwav.log
24
25
 
25
26
  """
27
+
26
28
  import signal
27
29
  from dataclasses import dataclass
28
30
 
29
- from sonusai.mixture import AudioT
30
31
  from sonusai.mixture import MixtureDatabase
31
32
 
32
33
 
@@ -35,7 +36,7 @@ def signal_handler(_sig, _frame):
35
36
 
36
37
  from sonusai import logger
37
38
 
38
- logger.info('Canceled due to keyboard interrupt')
39
+ logger.info("Canceled due to keyboard interrupt")
39
40
  sys.exit(1)
40
41
 
41
42
 
@@ -44,70 +45,37 @@ signal.signal(signal.SIGINT, signal_handler)
@dataclass
class MPGlobal:
    """Per-run settings shared with the _process_mixture workers.

    main() fills these in before dispatching work with par_track.
    """

    # Mixture database being rendered; None until main() loads it.
    mixdb: MixtureDatabase | None = None
    # Whether to additionally write target.wav for each mixture.
    write_target: bool = False
    # Whether to additionally write noise.wav for each mixture.
    write_noise: bool = False


# The global must be bound, not merely annotated: a bare `MP_GLOBAL: MPGlobal`
# creates no name, and main() sets attributes on this object before use.
MP_GLOBAL = MPGlobal()
def _process_mixture(m_id: int) -> None:
    """Write audio files and metadata for a single mixture.

    Files go under <mixdb.location>/<mixture name>/: mixture.wav always,
    target.wav when MP_GLOBAL.write_target is set, and noise.wav when
    MP_GLOBAL.write_noise is set. Metadata is written by write_mixture_metadata.
    main() calls this once per mixture ID via par_track.

    NOTE(review): the per-mixture directory is not created here; confirm it
    already exists or that write_audio creates missing directories.
    """
    from os.path import join

    from sonusai.mixture import write_mixture_metadata
    from sonusai.utils import float_to_int16
    from sonusai.utils import write_audio

    # Read-only access to the module-level state, so no `global` statement is needed.
    settings = MP_GLOBAL
    mixdb = settings.mixdb
    want_target = settings.write_target
    want_noise = settings.write_noise

    mixture = mixdb.mixture(m_id)
    output_dir = join(mixdb.location, mixture.name)

    def _save(filename: str, audio) -> None:
        # Every output is converted with float_to_int16 before being written.
        write_audio(name=join(output_dir, filename), audio=float_to_int16(audio))

    _save("mixture.wav", mixdb.mixture_mixture(m_id))
    if want_target:
        _save("target.wav", mixdb.mixture_target(m_id))
    if want_noise:
        _save("noise.wav", mixdb.mixture_noise(m_id))

    write_mixture_metadata(mixdb, mixture)
111
79
 
112
80
 
113
81
  def main() -> None:
@@ -118,63 +86,62 @@ def main() -> None:
118
86
 
119
87
  args = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)
120
88
 
121
- verbose = args['--verbose']
122
- mixid = args['--mixid']
123
- MP_GLOBAL.write_target = args['--target']
124
- MP_GLOBAL.write_noise = args['--noise']
125
- location = args['LOC']
89
+ verbose = args["--verbose"]
90
+ mixid = args["--mixid"]
91
+ MP_GLOBAL.write_target = args["--target"]
92
+ MP_GLOBAL.write_noise = args["--noise"]
93
+ location = args["LOC"]
126
94
 
127
95
  import time
128
96
  from os.path import join
129
97
 
130
- from tqdm import tqdm
131
-
132
98
  import sonusai
133
99
  from sonusai import create_file_handler
134
100
  from sonusai import initial_log_messages
135
101
  from sonusai import logger
136
102
  from sonusai import update_console_handler
137
103
  from sonusai.mixture import check_audio_files_exist
138
- from sonusai.utils import pp_tqdm_imap
139
104
  from sonusai.utils import human_readable_size
105
+ from sonusai.utils import par_track
140
106
  from sonusai.utils import seconds_to_hms
107
+ from sonusai.utils import track
141
108
 
142
109
  start_time = time.monotonic()
143
110
 
144
- create_file_handler(join(location, 'mkwav.log'))
111
+ create_file_handler(join(location, "mkwav.log"))
145
112
  update_console_handler(verbose)
146
- initial_log_messages('mkwav')
113
+ initial_log_messages("mkwav")
147
114
 
148
- logger.info(f'Load mixture database from {location}')
115
+ logger.info(f"Load mixture database from {location}")
149
116
  MP_GLOBAL.mixdb = MixtureDatabase(location)
150
117
  mixid = MP_GLOBAL.mixdb.mixids_to_list(mixid)
151
118
 
152
119
  total_samples = MP_GLOBAL.mixdb.total_samples(mixid)
153
120
  duration = total_samples / sonusai.mixture.SAMPLE_RATE
154
121
 
155
- logger.info('')
156
- logger.info(f'Found {len(mixid):,} mixtures to process')
157
- logger.info(f'{total_samples:,} samples')
122
+ logger.info("")
123
+ logger.info(f"Found {len(mixid):,} mixtures to process")
124
+ logger.info(f"{total_samples:,} samples")
158
125
 
159
126
  check_audio_files_exist(MP_GLOBAL.mixdb)
160
127
 
161
- progress = tqdm(total=len(mixid))
162
- pp_tqdm_imap(_process_mixture, mixid, progress=progress)
128
+ progress = track(total=len(mixid))
129
+ par_track(_process_mixture, mixid, progress=progress)
163
130
  progress.close()
164
131
 
165
- logger.info(f'Wrote {len(mixid)} mixtures to {location}')
166
- logger.info('')
167
- logger.info(f'Duration: {seconds_to_hms(seconds=duration)}')
168
- logger.info(f'mixture: {human_readable_size(total_samples * 2, 1)}')
132
+ logger.info(f"Wrote {len(mixid)} mixtures to {location}")
133
+ logger.info("")
134
+ logger.info(f"Duration: {seconds_to_hms(seconds=duration)}")
135
+ logger.info(f"mixture: {human_readable_size(total_samples * 2, 1)}")
169
136
  if MP_GLOBAL.write_target:
170
- logger.info(f'target: {human_readable_size(total_samples * 2, 1)}')
137
+ logger.info(f"target: {human_readable_size(total_samples * 2, 1)}")
171
138
  if MP_GLOBAL.write_noise:
172
- logger.info(f'noise: {human_readable_size(total_samples * 2, 1)}')
139
+ logger.info(f"noise: {human_readable_size(total_samples * 2, 1)}")
173
140
 
174
141
  end_time = time.monotonic()
175
- logger.info(f'Completed in {seconds_to_hms(seconds=end_time - start_time)}')
176
- logger.info('')
142
+ logger.info(f"Completed in {seconds_to_hms(seconds=end_time - start_time)}")
143
+ logger.info("")
177
144
 
178
145
 
179
- if __name__ == '__main__':
146
+ if __name__ == "__main__":
180
147
  main()