sonusai 1.0.16__cp311-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150) hide show
  1. sonusai/__init__.py +170 -0
  2. sonusai/aawscd_probwrite.py +148 -0
  3. sonusai/audiofe.py +481 -0
  4. sonusai/calc_metric_spenh.py +1136 -0
  5. sonusai/config/__init__.py +0 -0
  6. sonusai/config/asr.py +21 -0
  7. sonusai/config/config.py +65 -0
  8. sonusai/config/config.yml +49 -0
  9. sonusai/config/constants.py +53 -0
  10. sonusai/config/ir.py +124 -0
  11. sonusai/config/ir_delay.py +62 -0
  12. sonusai/config/source.py +275 -0
  13. sonusai/config/spectral_masks.py +15 -0
  14. sonusai/config/truth.py +64 -0
  15. sonusai/constants.py +14 -0
  16. sonusai/data/__init__.py +0 -0
  17. sonusai/data/silero_vad_v5.1.jit +0 -0
  18. sonusai/data/silero_vad_v5.1.onnx +0 -0
  19. sonusai/data/speech_ma01_01.wav +0 -0
  20. sonusai/data/whitenoise.wav +0 -0
  21. sonusai/datatypes.py +383 -0
  22. sonusai/deprecated/gentcst.py +632 -0
  23. sonusai/deprecated/plot.py +519 -0
  24. sonusai/deprecated/tplot.py +365 -0
  25. sonusai/doc.py +52 -0
  26. sonusai/doc_strings/__init__.py +1 -0
  27. sonusai/doc_strings/doc_strings.py +531 -0
  28. sonusai/genft.py +196 -0
  29. sonusai/genmetrics.py +183 -0
  30. sonusai/genmix.py +199 -0
  31. sonusai/genmixdb.py +235 -0
  32. sonusai/ir_metric.py +551 -0
  33. sonusai/lsdb.py +141 -0
  34. sonusai/main.py +134 -0
  35. sonusai/metrics/__init__.py +43 -0
  36. sonusai/metrics/calc_audio_stats.py +42 -0
  37. sonusai/metrics/calc_class_weights.py +90 -0
  38. sonusai/metrics/calc_optimal_thresholds.py +73 -0
  39. sonusai/metrics/calc_pcm.py +45 -0
  40. sonusai/metrics/calc_pesq.py +36 -0
  41. sonusai/metrics/calc_phase_distance.py +43 -0
  42. sonusai/metrics/calc_sa_sdr.py +64 -0
  43. sonusai/metrics/calc_sample_weights.py +25 -0
  44. sonusai/metrics/calc_segsnr_f.py +82 -0
  45. sonusai/metrics/calc_speech.py +382 -0
  46. sonusai/metrics/calc_wer.py +71 -0
  47. sonusai/metrics/calc_wsdr.py +57 -0
  48. sonusai/metrics/calculate_metrics.py +395 -0
  49. sonusai/metrics/class_summary.py +74 -0
  50. sonusai/metrics/confusion_matrix_summary.py +75 -0
  51. sonusai/metrics/one_hot.py +283 -0
  52. sonusai/metrics/snr_summary.py +128 -0
  53. sonusai/metrics_summary.py +314 -0
  54. sonusai/mixture/__init__.py +15 -0
  55. sonusai/mixture/audio.py +187 -0
  56. sonusai/mixture/class_balancing.py +103 -0
  57. sonusai/mixture/constants.py +3 -0
  58. sonusai/mixture/data_io.py +173 -0
  59. sonusai/mixture/db.py +169 -0
  60. sonusai/mixture/db_datatypes.py +92 -0
  61. sonusai/mixture/effects.py +344 -0
  62. sonusai/mixture/feature.py +78 -0
  63. sonusai/mixture/generation.py +1116 -0
  64. sonusai/mixture/helpers.py +351 -0
  65. sonusai/mixture/ir_effects.py +77 -0
  66. sonusai/mixture/log_duration_and_sizes.py +23 -0
  67. sonusai/mixture/mixdb.py +1857 -0
  68. sonusai/mixture/pad_audio.py +35 -0
  69. sonusai/mixture/resample.py +7 -0
  70. sonusai/mixture/sox_effects.py +195 -0
  71. sonusai/mixture/sox_help.py +650 -0
  72. sonusai/mixture/spectral_mask.py +51 -0
  73. sonusai/mixture/truth.py +61 -0
  74. sonusai/mixture/truth_functions/__init__.py +45 -0
  75. sonusai/mixture/truth_functions/crm.py +105 -0
  76. sonusai/mixture/truth_functions/energy.py +222 -0
  77. sonusai/mixture/truth_functions/file.py +48 -0
  78. sonusai/mixture/truth_functions/metadata.py +24 -0
  79. sonusai/mixture/truth_functions/metrics.py +28 -0
  80. sonusai/mixture/truth_functions/phoneme.py +18 -0
  81. sonusai/mixture/truth_functions/sed.py +98 -0
  82. sonusai/mixture/truth_functions/target.py +142 -0
  83. sonusai/mkwav.py +135 -0
  84. sonusai/onnx_predict.py +363 -0
  85. sonusai/parse/__init__.py +0 -0
  86. sonusai/parse/expand.py +156 -0
  87. sonusai/parse/parse_source_directive.py +129 -0
  88. sonusai/parse/rand.py +214 -0
  89. sonusai/py.typed +0 -0
  90. sonusai/queries/__init__.py +0 -0
  91. sonusai/queries/queries.py +239 -0
  92. sonusai/rs.abi3.so +0 -0
  93. sonusai/rs.pyi +1 -0
  94. sonusai/rust/__init__.py +0 -0
  95. sonusai/speech/__init__.py +0 -0
  96. sonusai/speech/l2arctic.py +121 -0
  97. sonusai/speech/librispeech.py +102 -0
  98. sonusai/speech/mcgill.py +71 -0
  99. sonusai/speech/textgrid.py +89 -0
  100. sonusai/speech/timit.py +138 -0
  101. sonusai/speech/types.py +12 -0
  102. sonusai/speech/vctk.py +53 -0
  103. sonusai/speech/voxceleb.py +108 -0
  104. sonusai/utils/__init__.py +3 -0
  105. sonusai/utils/asl_p56.py +130 -0
  106. sonusai/utils/asr.py +91 -0
  107. sonusai/utils/asr_functions/__init__.py +3 -0
  108. sonusai/utils/asr_functions/aaware_whisper.py +69 -0
  109. sonusai/utils/audio_devices.py +50 -0
  110. sonusai/utils/braced_glob.py +50 -0
  111. sonusai/utils/calculate_input_shape.py +26 -0
  112. sonusai/utils/choice.py +51 -0
  113. sonusai/utils/compress.py +25 -0
  114. sonusai/utils/convert_string_to_number.py +6 -0
  115. sonusai/utils/create_timestamp.py +5 -0
  116. sonusai/utils/create_ts_name.py +14 -0
  117. sonusai/utils/dataclass_from_dict.py +27 -0
  118. sonusai/utils/db.py +16 -0
  119. sonusai/utils/docstring.py +53 -0
  120. sonusai/utils/energy_f.py +44 -0
  121. sonusai/utils/engineering_number.py +166 -0
  122. sonusai/utils/evaluate_random_rule.py +15 -0
  123. sonusai/utils/get_frames_per_batch.py +2 -0
  124. sonusai/utils/get_label_names.py +20 -0
  125. sonusai/utils/grouper.py +6 -0
  126. sonusai/utils/human_readable_size.py +7 -0
  127. sonusai/utils/keyboard_interrupt.py +12 -0
  128. sonusai/utils/load_object.py +21 -0
  129. sonusai/utils/max_text_width.py +9 -0
  130. sonusai/utils/model_utils.py +28 -0
  131. sonusai/utils/numeric_conversion.py +11 -0
  132. sonusai/utils/onnx_utils.py +155 -0
  133. sonusai/utils/parallel.py +162 -0
  134. sonusai/utils/path_info.py +7 -0
  135. sonusai/utils/print_mixture_details.py +60 -0
  136. sonusai/utils/rand.py +13 -0
  137. sonusai/utils/ranges.py +43 -0
  138. sonusai/utils/read_predict_data.py +32 -0
  139. sonusai/utils/reshape.py +154 -0
  140. sonusai/utils/seconds_to_hms.py +7 -0
  141. sonusai/utils/stacked_complex.py +82 -0
  142. sonusai/utils/stratified_shuffle_split.py +170 -0
  143. sonusai/utils/tokenized_shell_vars.py +143 -0
  144. sonusai/utils/write_audio.py +26 -0
  145. sonusai/utils/yes_or_no.py +8 -0
  146. sonusai/vars.py +47 -0
  147. sonusai-1.0.16.dist-info/METADATA +56 -0
  148. sonusai-1.0.16.dist-info/RECORD +150 -0
  149. sonusai-1.0.16.dist-info/WHEEL +4 -0
  150. sonusai-1.0.16.dist-info/entry_points.txt +3 -0
@@ -0,0 +1,98 @@
1
+ from numpy.lib.utils import source
2
+
3
+ from ...datatypes import Truth
4
+ from ..mixdb import MixtureDatabase
5
+
6
+
7
def sed_validate(config: dict) -> None:
    """Check the configuration of the sed truth function.

    :param config: Truth function configuration; must contain 'thresholds'
    :raises AttributeError: If config is empty or a required key is missing
    :raises ValueError: If 'thresholds' is not strictly decreasing
    """
    if len(config) == 0:
        raise AttributeError("sed truth function is missing config")

    # Report the first required key that is absent from the config.
    required = ("thresholds",)
    missing = next((name for name in required if name not in config), None)
    if missing is not None:
        raise AttributeError(f"sed truth function is missing required '{missing}'")

    thresholds = config["thresholds"]
    if _strictly_decreasing(thresholds):
        return

    raise ValueError(f"sed truth function 'thresholds' are not strictly decreasing: {thresholds}")
19
+
20
+
21
def sed_parameters(_feature: str, num_classes: int, _config: dict) -> int:
    """Return the number of truth parameters produced by sed.

    The feature and config arguments are unused; the result is always
    num_classes.
    """
    parameters = num_classes
    return parameters
23
+
24
+
25
def sed(mixdb: MixtureDatabase, m_id: int, category: str, config: dict) -> Truth:
    """Sound energy detection truth generation function

    Calculates sound energy detection truth using a simple 3 threshold
    hysteresis algorithm. SED outputs 3 possible probabilities of sound
    presence: 1.0 present, 0.5 (transition/uncertain), 0 not present.

    Output shape: [:, num_classes]

    The thresholds are read from config["thresholds"]. The truth fields
    that are set are the class indices of the source file used for
    `category` in mixture `m_id`; these are passed to SED as-is (SED wants
    1-based indices).

    If the source's snr_gain is 0, all-zero truth is returned without
    running SED.

    In mutually-exclusive mode, a frame is expected to only belong to one
    class and thus all probabilities must sum to 1. This is effectively
    truth for a classifier with multichannel softmax output.

    For multi-label classification each class is an individual probability
    for that class and any given frame can be assigned to multiple
    classes/labels, i.e., the classes are not mutually-exclusive. For
    example, a NN classifier with multichannel sigmoid output.

    :raises ValueError: If the audio length is not a multiple of the frame
        size, or the transform returns an unexpected number of frames
    """
    import numpy as np
    import torch
    from pyaaware import SED
    from pyaaware import ForwardTransform
    from pyaaware import feature_forward_transform_config
    from pyaaware import feature_inverse_transform_config

    audio = torch.from_numpy(mixdb.mixture_sources(m_id)[category])
    frame_size = feature_inverse_transform_config(mixdb.feature)["overlap"]
    transform = ForwardTransform(**feature_forward_transform_config(mixdb.feature))

    if len(audio) % frame_size:
        raise ValueError(f"Number of samples in audio is not a multiple of {frame_size}")

    expected_frames = transform.frames(audio)
    width = sed_parameters(mixdb.feature, mixdb.num_classes, config)

    # A source with zero gain contributes nothing, so its truth is all zeros.
    if mixdb.mixture(m_id).all_sources[category].snr_gain == 0:
        return np.zeros((expected_frames, width), dtype=np.float32)

    # SED wants 1-based indices
    thresholds = config["thresholds"]
    class_indices = mixdb.source_file(mixdb.mixture(m_id).all_sources[category].file_id).class_indices
    detector = SED(
        thresholds=thresholds,
        index=class_indices,
        frame_size=frame_size,
        num_classes=mixdb.num_classes,
    )

    # Second output of the forward transform is the energy.
    energy = transform.execute_all(audio)[1].numpy()

    if energy.shape[0] != expected_frames:
        raise ValueError("Incorrect frames calculation in sed truth function")

    return detector.execute_all(energy)
93
+
94
+
95
+ def _strictly_decreasing(list_to_check: list) -> bool:
96
+ from itertools import pairwise
97
+
98
+ return all(x > y for x, y in pairwise(list_to_check))
@@ -0,0 +1,142 @@
1
+ from ...datatypes import Truth
2
+ from ..mixdb import MixtureDatabase
3
+
4
+
5
def target_f_validate(_config: dict) -> None:
    """Validate the target_f configuration; there is nothing to check."""
    return None
7
+
8
+
9
def target_f_parameters(feature: str, _num_classes: int, _config: dict) -> int:
    """Return the number of truth parameters produced by target_f.

    The width is `bins` for a 'tdac-co' transform and `2 * bins` for any
    other transform type, where `bins` comes from the forward transform
    configured by the feature.
    """
    from pyaaware import ForwardTransform
    from pyaaware import feature_forward_transform_config

    transform = ForwardTransform(**feature_forward_transform_config(feature))
    values_per_bin = 1 if transform.ttype == "tdac-co" else 2
    return transform.bins * values_per_bin
19
+
20
+
21
def target_f(mixdb: MixtureDatabase, m_id: int, category: str, _config: dict) -> Truth:
    """Frequency domain target truth function

    Runs the `category` source audio of mixture `m_id` through the forward
    transform (STFT) configured by the feature, including the forward
    transform window if the feature defines one.

    Output shape: [:, 2 * bins] (target stacked real, imag) or
                  [:, bins] (target real only for tdac-co)
    """
    import torch
    from pyaaware import ForwardTransform
    from pyaaware import feature_forward_transform_config

    transform = ForwardTransform(**feature_forward_transform_config(mixdb.feature))
    audio = torch.from_numpy(mixdb.mixture_sources(m_id)[category])

    # First output of the forward transform is the frequency-domain data.
    spectrum = transform.execute_all(audio)[0].numpy()
    return _stack_real_imag(spectrum, transform.ttype)
41
+
42
+
43
def target_mixture_f_validate(_config: dict) -> None:
    """Validate the target_mixture_f configuration; there is nothing to check."""
    return None
45
+
46
+
47
def target_mixture_f_parameters(feature: str, _num_classes: int, _config: dict) -> int:
    """Return the number of truth parameters produced by target_mixture_f.

    The width is `2 * bins` for a 'tdac-co' transform and `4 * bins` for
    any other transform type, where `bins` comes from the forward transform
    configured by the feature.
    """
    from pyaaware import ForwardTransform
    from pyaaware import feature_forward_transform_config

    transform = ForwardTransform(**feature_forward_transform_config(feature))
    values_per_bin = 2 if transform.ttype == "tdac-co" else 4
    return transform.bins * values_per_bin
57
+
58
+
59
def target_mixture_f(mixdb: MixtureDatabase, m_id: int, category: str, _config: dict) -> Truth:
    """Frequency domain target and mixture truth function

    Calculates the true transform of the target and the mixture
    using the STFT configuration defined by the feature. This
    will include a forward transform window if defined by the
    feature.

    Output shape: [:, 4 * bins] (target stacked real, imag; mixture stacked real, imag) or
                  [:, 2 * bins] (target real; mixture real for tdac-co)

    :param mixdb: Mixture database providing the feature and the audio
    :param m_id: Mixture ID
    :param category: Source category whose audio is the target
    :param _config: Unused
    :return: float32 truth with the target part followed by the mixture part
    """
    import numpy as np
    import torch
    from pyaaware import ForwardTransform
    from pyaaware import feature_forward_transform_config

    ft = ForwardTransform(**feature_forward_transform_config(mixdb.feature))

    # Convert each NumPy signal to a tensor exactly once: torch.from_numpy()
    # only accepts ndarrays and raises TypeError when handed a Tensor.
    target_audio = torch.from_numpy(mixdb.mixture_sources(m_id)[category])
    mixture_audio = torch.from_numpy(mixdb.mixture_mixture(m_id))

    target_freq = ft.execute_all(target_audio)[0].numpy()
    mixture_freq = ft.execute_all(mixture_audio)[0].numpy()

    # Concatenate the two halves instead of writing into a fixed 4 * bins
    # buffer, so the width follows what _stack_real_imag returns: 2 * bins per
    # half normally, bins per half for tdac-co (matching
    # target_mixture_f_parameters).
    target_truth = _stack_real_imag(target_freq, ft.ttype)
    mixture_truth = _stack_real_imag(mixture_freq, ft.ttype)
    return np.hstack((target_truth, mixture_truth)).astype(np.float32)
88
+
89
+
90
def target_swin_f_validate(_config: dict) -> None:
    """Validate the target_swin_f configuration; there is nothing to check."""
    return None
92
+
93
+
94
def target_swin_f_parameters(feature: str, _num_classes: int, _config: dict) -> int:
    """Return the number of truth parameters produced by target_swin_f.

    This is always `2 * bins` of the forward transform configured by the
    feature.
    """
    from pyaaware import ForwardTransform
    from pyaaware import feature_forward_transform_config

    transform_config = feature_forward_transform_config(feature)
    transform = ForwardTransform(**transform_config)
    return transform.bins * 2
99
+
100
+
101
def target_swin_f(mixdb: MixtureDatabase, m_id: int, category: str, _config: dict) -> Truth:
    """Frequency domain target with synthesis window truth function

    Processes the `category` source audio of mixture `m_id` one frame
    (`overlap` samples) at a time: each frame is multiplied by the inverse
    transform (synthesis) window and then run through the forward transform
    configured by the feature, which includes the forward transform window
    if the feature defines one.

    Output shape: [:, 2 * bins] (stacked real, imag)
    """
    import numpy as np
    import torch
    from pyaaware import ForwardTransform
    from pyaaware import InverseTransform
    from pyaaware import feature_forward_transform_config
    from pyaaware import feature_inverse_transform_config

    from ...utils.stacked_complex import stack_complex

    forward = ForwardTransform(**feature_forward_transform_config(mixdb.feature))
    inverse = InverseTransform(**feature_inverse_transform_config(mixdb.feature))

    samples = mixdb.mixture_sources(m_id)[category]

    # One truth row per complete frame of `overlap` samples.
    num_frames = len(samples) // forward.overlap
    truth = np.empty((num_frames, forward.bins * 2), dtype=np.float32)

    for row, start in enumerate(range(0, len(samples), forward.overlap)):
        windowed = np.multiply(samples[start : start + forward.overlap], inverse.window)
        frame_freq = forward.execute(torch.from_numpy(windowed))[0].numpy()
        truth[row] = stack_complex(frame_freq)

    return truth
132
+
133
+
134
def _stack_real_imag(data: Truth, ttype: str) -> Truth:
    """Convert frequency-domain data into truth layout.

    For the 'tdac-co' transform type only the real part of `data` is
    returned; for every other type the result of stack_complex(data) is
    returned.
    """
    import numpy as np

    from ...utils.stacked_complex import stack_complex

    return np.real(data) if ttype == "tdac-co" else stack_complex(data)
sonusai/mkwav.py ADDED
@@ -0,0 +1,135 @@
1
+ """sonusai mkwav
2
+
3
+ usage: mkwav [-hvtsn] [-i MIXID] LOC
4
+
5
+ options:
6
+ -h, --help
7
+ -v, --verbose Be verbose.
8
+ -i MIXID, --mixid MIXID Mixture ID(s) to generate. [default: *].
9
+ -t, --source Write source file.
10
+ -s, --sources Write sources files.
11
+ -n, --noise Write noise file.
12
+
13
+ The mkwav command creates WAV files from a SonusAI database.
14
+
15
+ Inputs:
16
+ LOC A SonusAI mixture database directory.
17
+ MIXID A glob of mixture ID(s) to generate.
18
+
19
+ Outputs the following to the mixture database directory:
20
+ <id>
21
+ mixture.wav: mixture
22
+ source.wav: source (optional)
23
+ sources_<c>.wav: source <category> (optional)
24
+ noise.wav: noise (optional)
25
+ metadata.txt
26
+ mkwav.log
27
+
28
+ """
29
+
30
+
31
def _process_mixture(m_id: int, location: str, write_target: bool, write_targets: bool, write_noise: bool) -> None:
    """Write the WAV files and metadata for a single mixture.

    Files are written to <mixdb.location>/mixture/<mixture name>/.

    :param m_id: Mixture ID
    :param location: Mixture database directory
    :param write_target: Also write source.wav
    :param write_targets: Also write one sources_<category>.wav per source category
    :param write_noise: Also write noise.wav
    """
    from os import makedirs
    from os.path import join

    from sonusai.mixture import MixtureDatabase
    from sonusai.mixture.helpers import write_mixture_metadata
    from sonusai.utils.numeric_conversion import float_to_int16
    from sonusai.utils.write_audio import write_audio

    mixdb = MixtureDatabase(location)

    out_dir = join(mixdb.location, "mixture", mixdb.mixture(m_id).name)
    makedirs(out_dir, exist_ok=True)

    def _save(filename: str, samples) -> None:
        # All audio is converted from float to int16 before it is written.
        write_audio(name=join(out_dir, filename), audio=float_to_int16(samples))

    _save("mixture.wav", mixdb.mixture_mixture(m_id))

    if write_target:
        _save("source.wav", mixdb.mixture_source(m_id))

    if write_targets:
        for category, source in mixdb.mixture_sources(m_id).items():
            _save(f"sources_{category}.wav", source)

    if write_noise:
        _save("noise.wav", mixdb.mixture_noise(m_id))

    write_mixture_metadata(mixdb, m_id=m_id)
56
+
57
+
58
def main() -> None:
    """Command-line entry point for mkwav.

    Parses the command line against the module docstring, sets up logging
    to <LOC>/mkwav.log, resolves the requested mixture IDs, and writes the
    WAV files for each of them via _process_mixture.
    """
    from docopt import docopt

    from sonusai import __version__ as sai_version
    from sonusai.utils.docstring import trim_docstring

    # Parse the command line before the remaining imports.
    cli = docopt(trim_docstring(__doc__), version=sai_version, options_first=True)

    verbose = cli["--verbose"]
    requested_mixid = cli["--mixid"]
    write_source = cli["--source"]
    write_sources = cli["--sources"]
    write_noise = cli["--noise"]
    location = cli["LOC"]

    import time
    from functools import partial
    from os.path import join

    from sonusai import create_file_handler
    from sonusai import initial_log_messages
    from sonusai import logger
    from sonusai import update_console_handler
    from sonusai.mixture import MixtureDatabase
    from sonusai.mixture.helpers import check_audio_files_exist
    from sonusai.utils.parallel import par_track
    from sonusai.utils.parallel import track
    from sonusai.utils.seconds_to_hms import seconds_to_hms

    start_time = time.monotonic()

    log_file = join(location, "mkwav.log")
    create_file_handler(log_file, verbose)
    update_console_handler(verbose)
    initial_log_messages("mkwav")

    logger.info(f"Load mixture database from {location}")
    mixdb = MixtureDatabase(location)
    mixids = mixdb.mixids_to_list(requested_mixid)

    total_samples = mixdb.total_samples(mixids)

    logger.info("")
    logger.info(f"Found {len(mixids):,} mixtures to process")
    logger.info(f"{total_samples:,} samples")

    check_audio_files_exist(mixdb)

    # The CLI's --source/--sources options map onto the worker's
    # write_target/write_targets parameters.
    worker = partial(
        _process_mixture,
        location=location,
        write_target=write_source,
        write_targets=write_sources,
        write_noise=write_noise,
    )

    progress = track(total=len(mixids))
    par_track(worker, mixids, progress=progress)
    progress.close()

    logger.info(f"Wrote {len(mixids)} mixtures to {location}")
    logger.info("")
    elapsed = time.monotonic() - start_time
    logger.info(f"Completed in {seconds_to_hms(seconds=elapsed)}")
    logger.info("")
125
+
126
+
127
if __name__ == "__main__":
    from sonusai import exception_handler
    from sonusai.utils.keyboard_interrupt import register_keyboard_interrupt

    # Register keyboard interrupt handling before running; any other
    # exception raised by main() is passed to the package exception handler.
    register_keyboard_interrupt()
    try:
        main()
    except Exception as err:
        exception_handler(err)