sonusai 1.0.16__cp311-abi3-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150) hide show
  1. sonusai/__init__.py +170 -0
  2. sonusai/aawscd_probwrite.py +148 -0
  3. sonusai/audiofe.py +481 -0
  4. sonusai/calc_metric_spenh.py +1136 -0
  5. sonusai/config/__init__.py +0 -0
  6. sonusai/config/asr.py +21 -0
  7. sonusai/config/config.py +65 -0
  8. sonusai/config/config.yml +49 -0
  9. sonusai/config/constants.py +53 -0
  10. sonusai/config/ir.py +124 -0
  11. sonusai/config/ir_delay.py +62 -0
  12. sonusai/config/source.py +275 -0
  13. sonusai/config/spectral_masks.py +15 -0
  14. sonusai/config/truth.py +64 -0
  15. sonusai/constants.py +14 -0
  16. sonusai/data/__init__.py +0 -0
  17. sonusai/data/silero_vad_v5.1.jit +0 -0
  18. sonusai/data/silero_vad_v5.1.onnx +0 -0
  19. sonusai/data/speech_ma01_01.wav +0 -0
  20. sonusai/data/whitenoise.wav +0 -0
  21. sonusai/datatypes.py +383 -0
  22. sonusai/deprecated/gentcst.py +632 -0
  23. sonusai/deprecated/plot.py +519 -0
  24. sonusai/deprecated/tplot.py +365 -0
  25. sonusai/doc.py +52 -0
  26. sonusai/doc_strings/__init__.py +1 -0
  27. sonusai/doc_strings/doc_strings.py +531 -0
  28. sonusai/genft.py +196 -0
  29. sonusai/genmetrics.py +183 -0
  30. sonusai/genmix.py +199 -0
  31. sonusai/genmixdb.py +235 -0
  32. sonusai/ir_metric.py +551 -0
  33. sonusai/lsdb.py +141 -0
  34. sonusai/main.py +134 -0
  35. sonusai/metrics/__init__.py +43 -0
  36. sonusai/metrics/calc_audio_stats.py +42 -0
  37. sonusai/metrics/calc_class_weights.py +90 -0
  38. sonusai/metrics/calc_optimal_thresholds.py +73 -0
  39. sonusai/metrics/calc_pcm.py +45 -0
  40. sonusai/metrics/calc_pesq.py +36 -0
  41. sonusai/metrics/calc_phase_distance.py +43 -0
  42. sonusai/metrics/calc_sa_sdr.py +64 -0
  43. sonusai/metrics/calc_sample_weights.py +25 -0
  44. sonusai/metrics/calc_segsnr_f.py +82 -0
  45. sonusai/metrics/calc_speech.py +382 -0
  46. sonusai/metrics/calc_wer.py +71 -0
  47. sonusai/metrics/calc_wsdr.py +57 -0
  48. sonusai/metrics/calculate_metrics.py +395 -0
  49. sonusai/metrics/class_summary.py +74 -0
  50. sonusai/metrics/confusion_matrix_summary.py +75 -0
  51. sonusai/metrics/one_hot.py +283 -0
  52. sonusai/metrics/snr_summary.py +128 -0
  53. sonusai/metrics_summary.py +314 -0
  54. sonusai/mixture/__init__.py +15 -0
  55. sonusai/mixture/audio.py +187 -0
  56. sonusai/mixture/class_balancing.py +103 -0
  57. sonusai/mixture/constants.py +3 -0
  58. sonusai/mixture/data_io.py +173 -0
  59. sonusai/mixture/db.py +169 -0
  60. sonusai/mixture/db_datatypes.py +92 -0
  61. sonusai/mixture/effects.py +344 -0
  62. sonusai/mixture/feature.py +78 -0
  63. sonusai/mixture/generation.py +1116 -0
  64. sonusai/mixture/helpers.py +351 -0
  65. sonusai/mixture/ir_effects.py +77 -0
  66. sonusai/mixture/log_duration_and_sizes.py +23 -0
  67. sonusai/mixture/mixdb.py +1857 -0
  68. sonusai/mixture/pad_audio.py +35 -0
  69. sonusai/mixture/resample.py +7 -0
  70. sonusai/mixture/sox_effects.py +195 -0
  71. sonusai/mixture/sox_help.py +650 -0
  72. sonusai/mixture/spectral_mask.py +51 -0
  73. sonusai/mixture/truth.py +61 -0
  74. sonusai/mixture/truth_functions/__init__.py +45 -0
  75. sonusai/mixture/truth_functions/crm.py +105 -0
  76. sonusai/mixture/truth_functions/energy.py +222 -0
  77. sonusai/mixture/truth_functions/file.py +48 -0
  78. sonusai/mixture/truth_functions/metadata.py +24 -0
  79. sonusai/mixture/truth_functions/metrics.py +28 -0
  80. sonusai/mixture/truth_functions/phoneme.py +18 -0
  81. sonusai/mixture/truth_functions/sed.py +98 -0
  82. sonusai/mixture/truth_functions/target.py +142 -0
  83. sonusai/mkwav.py +135 -0
  84. sonusai/onnx_predict.py +363 -0
  85. sonusai/parse/__init__.py +0 -0
  86. sonusai/parse/expand.py +156 -0
  87. sonusai/parse/parse_source_directive.py +129 -0
  88. sonusai/parse/rand.py +214 -0
  89. sonusai/py.typed +0 -0
  90. sonusai/queries/__init__.py +0 -0
  91. sonusai/queries/queries.py +239 -0
  92. sonusai/rs.abi3.so +0 -0
  93. sonusai/rs.pyi +1 -0
  94. sonusai/rust/__init__.py +0 -0
  95. sonusai/speech/__init__.py +0 -0
  96. sonusai/speech/l2arctic.py +121 -0
  97. sonusai/speech/librispeech.py +102 -0
  98. sonusai/speech/mcgill.py +71 -0
  99. sonusai/speech/textgrid.py +89 -0
  100. sonusai/speech/timit.py +138 -0
  101. sonusai/speech/types.py +12 -0
  102. sonusai/speech/vctk.py +53 -0
  103. sonusai/speech/voxceleb.py +108 -0
  104. sonusai/utils/__init__.py +3 -0
  105. sonusai/utils/asl_p56.py +130 -0
  106. sonusai/utils/asr.py +91 -0
  107. sonusai/utils/asr_functions/__init__.py +3 -0
  108. sonusai/utils/asr_functions/aaware_whisper.py +69 -0
  109. sonusai/utils/audio_devices.py +50 -0
  110. sonusai/utils/braced_glob.py +50 -0
  111. sonusai/utils/calculate_input_shape.py +26 -0
  112. sonusai/utils/choice.py +51 -0
  113. sonusai/utils/compress.py +25 -0
  114. sonusai/utils/convert_string_to_number.py +6 -0
  115. sonusai/utils/create_timestamp.py +5 -0
  116. sonusai/utils/create_ts_name.py +14 -0
  117. sonusai/utils/dataclass_from_dict.py +27 -0
  118. sonusai/utils/db.py +16 -0
  119. sonusai/utils/docstring.py +53 -0
  120. sonusai/utils/energy_f.py +44 -0
  121. sonusai/utils/engineering_number.py +166 -0
  122. sonusai/utils/evaluate_random_rule.py +15 -0
  123. sonusai/utils/get_frames_per_batch.py +2 -0
  124. sonusai/utils/get_label_names.py +20 -0
  125. sonusai/utils/grouper.py +6 -0
  126. sonusai/utils/human_readable_size.py +7 -0
  127. sonusai/utils/keyboard_interrupt.py +12 -0
  128. sonusai/utils/load_object.py +21 -0
  129. sonusai/utils/max_text_width.py +9 -0
  130. sonusai/utils/model_utils.py +28 -0
  131. sonusai/utils/numeric_conversion.py +11 -0
  132. sonusai/utils/onnx_utils.py +155 -0
  133. sonusai/utils/parallel.py +162 -0
  134. sonusai/utils/path_info.py +7 -0
  135. sonusai/utils/print_mixture_details.py +60 -0
  136. sonusai/utils/rand.py +13 -0
  137. sonusai/utils/ranges.py +43 -0
  138. sonusai/utils/read_predict_data.py +32 -0
  139. sonusai/utils/reshape.py +154 -0
  140. sonusai/utils/seconds_to_hms.py +7 -0
  141. sonusai/utils/stacked_complex.py +82 -0
  142. sonusai/utils/stratified_shuffle_split.py +170 -0
  143. sonusai/utils/tokenized_shell_vars.py +143 -0
  144. sonusai/utils/write_audio.py +26 -0
  145. sonusai/utils/yes_or_no.py +8 -0
  146. sonusai/vars.py +47 -0
  147. sonusai-1.0.16.dist-info/METADATA +56 -0
  148. sonusai-1.0.16.dist-info/RECORD +150 -0
  149. sonusai-1.0.16.dist-info/WHEEL +4 -0
  150. sonusai-1.0.16.dist-info/entry_points.txt +3 -0
@@ -0,0 +1,51 @@
1
+ from ..datatypes import AudioF
2
+ from ..datatypes import SpectralMask
3
+
4
+
5
def apply_spectral_mask(audio_f: AudioF, spectral_mask: SpectralMask, seed: int | None = None) -> AudioF:
    """Apply frequency and time masking

    Implementation of SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition

    Ref: https://arxiv.org/pdf/1904.08779.pdf

    f_width consecutive bins [f_start, f_start + f_width) are masked, where f_width is chosen from a uniform
    distribution from 0 to the f_max_width, and f_start is chosen from [0, bins - f_width].

    t_width consecutive frames [t_start, t_start + t_width) are masked, where t_width is chosen from a uniform
    distribution from 0 to the t_max_width, and t_start is chosen from [0, frames - t_width].

    A time mask cannot be wider than t_max_percent times the number of frames (and never wider than the
    number of frames itself).

    Note: masking is performed in place; the returned array is the same object as audio_f.

    :param audio_f: Numpy array of transform audio data [frames, bins]
    :param spectral_mask: Spectral mask parameters
    :param seed: Random number seed
    :return: Augmented feature
    :raises ValueError: If audio_f is not two-dimensional
    """
    import numpy as np

    if audio_f.ndim != 2:
        raise ValueError("feature input must have two dimensions [frames, bins]")

    frames, bins = audio_f.shape

    # An out-of-range (or non-integer) maximum width falls back to the full bin count
    f_max_width = spectral_mask.f_max_width
    if f_max_width not in range(0, bins + 1):
        f_max_width = bins

    rng = np.random.default_rng(seed)

    # apply f_num frequency masks to the feature
    for _ in range(spectral_mask.f_num):
        f_width = int(rng.uniform(0, f_max_width))
        f_start = rng.integers(0, bins - f_width, endpoint=True)
        audio_f[:, f_start : f_start + f_width] = 0

    # apply t_num time masks to the feature
    # Clamp to frames so that t_max_percent > 100 cannot produce a negative start range
    t_upper_bound = min(int(spectral_mask.t_max_percent / 100 * frames), frames)
    for _ in range(spectral_mask.t_num):
        t_width = min(int(rng.uniform(0, spectral_mask.t_max_width)), t_upper_bound)
        t_start = rng.integers(0, frames - t_width, endpoint=True)
        audio_f[t_start : t_start + t_width, :] = 0

    return audio_f
@@ -0,0 +1,61 @@
1
+ from ..datatypes import Truth
2
+ from ..datatypes import TruthsDict
3
+ from .mixdb import MixtureDatabase
4
+
5
+
6
def truth_function(mixdb: MixtureDatabase, m_id: int) -> TruthsDict:
    """Run every configured truth function for each source of a mixture.

    :param mixdb: Mixture database
    :param m_id: Mixture ID
    :return: Dictionary keyed by source category; each value maps truth name to the generated truth data.
             Categories whose source file has no truth configs are omitted.
    :raises AttributeError: If a configured truth function name does not exist in truth_functions
    :raises RuntimeError: If a truth function raises while generating truth
    """
    from ..datatypes import TruthDict
    from . import truth_functions

    result: TruthsDict = {}
    for category, source in mixdb.mixture(m_id).all_sources.items():
        truth: TruthDict = {}
        source_file = mixdb.source_file(source.file_id)
        for name, config in source_file.truth_configs.items():
            # Resolve the function separately from calling it so that an AttributeError raised
            # inside a truth function is not misreported as an unsupported function name.
            try:
                function = getattr(truth_functions, config.function)
            except AttributeError as e:
                raise AttributeError(f"Unsupported truth function: {config.function}") from e

            try:
                truth[name] = function(mixdb, m_id, category, config.config)
            except Exception as e:
                raise RuntimeError(f"Error in truth function '{config.function}': {e}") from e

        if truth:
            result[category] = truth

    return result
26
+
27
+
28
def get_class_indices_for_mixid(mixdb: MixtureDatabase, mixid: int) -> list[int]:
    """Get a list of class indices for a given mixid.

    Collects the class indices of every source file used by the mixture.

    :param mixdb: Mixture database
    :param mixid: Mixture ID
    :return: Sorted list of unique class indices
    """
    indices: set[int] = set()
    for source in mixdb.mixture(mixid).all_sources.values():
        # A source file may have zero, one, or many class indices; update() handles all cases
        indices.update(mixdb.source_file(source.file_id).class_indices)

    return sorted(indices)
35
+
36
+
37
+ def truth_stride_reduction(truth: Truth, function: str) -> Truth:
38
+ """Reduce stride dimension of truth.
39
+
40
+ :param truth: Truth data [frames, stride, truth_parameters]
41
+ :param function: Truth stride reduction function name
42
+ :return: Stride reduced truth data [frames, stride or 1, truth_parameters]
43
+ """
44
+ import numpy as np
45
+
46
+ if truth.ndim != 3:
47
+ raise ValueError("Invalid truth shape")
48
+
49
+ if function == "none":
50
+ return truth
51
+
52
+ if function == "max":
53
+ return np.max(truth, axis=1, keepdims=True)
54
+
55
+ if function == "mean":
56
+ return np.mean(truth, axis=1, keepdims=True)
57
+
58
+ if function == "first":
59
+ return truth[:, 0, :].reshape((truth.shape[0], 1, truth.shape[2]))
60
+
61
+ raise ValueError(f"Invalid truth stride reduction function: {function}")
@@ -0,0 +1,45 @@
1
+ # SonusAI truth functions
2
+ # ruff: noqa: F401
3
+
4
+ from .crm import crm
5
+ from .crm import crm_parameters
6
+ from .crm import crm_validate
7
+ from .crm import crmp
8
+ from .crm import crmp_parameters
9
+ from .crm import crmp_validate
10
+ from .energy import energy_f
11
+ from .energy import energy_f_parameters
12
+ from .energy import energy_f_validate
13
+ from .energy import energy_t
14
+ from .energy import energy_t_parameters
15
+ from .energy import energy_t_validate
16
+ from .energy import mapped_snr_f
17
+ from .energy import mapped_snr_f_parameters
18
+ from .energy import mapped_snr_f_validate
19
+ from .energy import snr_f
20
+ from .energy import snr_f_parameters
21
+ from .energy import snr_f_validate
22
+ from .file import file
23
+ from .file import file_parameters
24
+ from .file import file_validate
25
+ from .metadata import metadata
26
+ from .metadata import metadata_parameters
27
+ from .metadata import metadata_validate
28
+ from .metrics import metrics
29
+ from .metrics import metrics_parameters
30
+ from .metrics import metrics_validate
31
+ from .phoneme import phoneme
32
+ from .phoneme import phoneme_parameters
33
+ from .phoneme import phoneme_validate
34
+ from .sed import sed
35
+ from .sed import sed_parameters
36
+ from .sed import sed_validate
37
+ from .target import target_f
38
+ from .target import target_f_parameters
39
+ from .target import target_f_validate
40
+ from .target import target_mixture_f
41
+ from .target import target_mixture_f_parameters
42
+ from .target import target_mixture_f_validate
43
+ from .target import target_swin_f
44
+ from .target import target_swin_f_parameters
45
+ from .target import target_swin_f_validate
@@ -0,0 +1,105 @@
1
+ from ...datatypes import Truth
2
+ from ..mixdb import MixtureDatabase
3
+
4
+
5
def _core(mixdb: MixtureDatabase, m_id: int, category: str, parameters: int, polar: bool) -> Truth:
    """Shared implementation of the crm and crmp truth functions.

    For each frame, the source and noise audio are transformed separately and the per-bin
    ratio source / (source + noise) is computed. A zero source bin yields 0; a non-zero source
    bin with a zero mixture bin yields inf + j*inf.

    :param mixdb: Mixture database
    :param m_id: Mixture ID
    :param category: Source category to generate truth for
    :param parameters: Number of truth parameters (used to size the all-zero result)
    :param polar: If True, emit magnitude and phase; otherwise emit real and imaginary parts
    :return: Truth data [frames, 2 * bins]; all zeros [frames, parameters] when the source's
             snr_gain is 0
    """
    import numpy as np
    import torch
    from pyaaware import ForwardTransform
    from pyaaware import feature_forward_transform_config
    from pyaaware import feature_inverse_transform_config

    source_audio = torch.from_numpy(mixdb.mixture_sources(m_id)[category])

    # Separate transform instances are used for source and noise
    source_ft = ForwardTransform(**feature_forward_transform_config(mixdb.feature))
    noise_ft = ForwardTransform(**feature_forward_transform_config(mixdb.feature))

    frames = source_ft.frames(source_audio)
    if mixdb.mixture(m_id).all_sources[category].snr_gain == 0:
        return np.zeros((frames, parameters), dtype=np.float32)

    noise_audio = torch.from_numpy(mixdb.mixture_noise(m_id))

    frame_size = feature_inverse_transform_config(mixdb.feature)["overlap"]

    frames = len(source_audio) // frame_size
    truth = np.empty((frames, source_ft.bins * 2), dtype=np.float32)
    for frame in range(frames):
        start = frame * frame_size
        stop = start + frame_size
        target_f = source_ft.execute(source_audio[start:stop])[0].numpy().astype(np.complex64)
        noise_f = noise_ft.execute(noise_audio[start:stop])[0].numpy().astype(np.complex64)
        mixture_f = target_f + noise_f

        crm_data = np.empty(target_f.shape, dtype=np.complex64)
        for index in np.ndindex(target_f.shape):
            num = target_f[index]
            den = mixture_f[index]
            if num == 0:
                crm_data[index] = 0
            elif den == 0:
                crm_data[index] = complex(np.inf, np.inf)
            else:
                crm_data[index] = num / den

        if polar:
            first_half = np.absolute(crm_data)
            second_half = np.angle(crm_data)
        else:
            first_half = np.real(crm_data)
            second_half = np.imag(crm_data)

        truth[frame, : source_ft.bins] = first_half
        truth[frame, source_ft.bins :] = second_half

    return truth
48
+
49
+
50
def crm_validate(_config: dict) -> None:
    """Validate the crm truth configuration.

    No checks are performed; any configuration is accepted.

    :param _config: Truth configuration (unused)
    """
    return None
52
+
53
+
54
def crm_parameters(feature: str, _num_classes: int, _config: dict) -> int:
    """Get the number of truth parameters produced by the crm truth function.

    :param feature: Feature name used to look up the forward transform configuration
    :param _num_classes: Number of classes (unused)
    :param _config: Truth configuration (unused)
    :return: Twice the number of bins of the feature's forward transform
    """
    from pyaaware import ForwardTransform
    from pyaaware import feature_forward_transform_config

    transform_config = feature_forward_transform_config(feature)
    transform = ForwardTransform(**transform_config)
    return transform.bins * 2
59
+
60
+
61
def crm(mixdb: MixtureDatabase, m_id: int, category: str, _config: dict) -> Truth:
    """Complex ratio mask truth generation function

    Calculates the true complex ratio mask (CRM) truth which is a complex number
    per bin = Mr + j*Mi. For a given noisy STFT bin value Y, it is used as

    (Mr*Yr + Mi*Yi) / (Yr^2 + Yi^2) + j*(Mi*Yr - Mr*Yi)/ (Yr^2 + Yi^2)

    The real parts fill the first half of the output columns and the imaginary
    parts fill the second half.

    Output shape: [:, 2 * bins]
    """
    parameters = crm_parameters(mixdb.feature, mixdb.num_classes, _config)
    return _core(mixdb=mixdb, m_id=m_id, category=category, parameters=parameters, polar=False)
78
+
79
+
80
def crmp_validate(_config: dict) -> None:
    """Validate the crmp truth configuration.

    No checks are performed; any configuration is accepted.

    :param _config: Truth configuration (unused)
    """
    return None
82
+
83
+
84
def crmp_parameters(feature: str, _num_classes: int, _config: dict) -> int:
    """Get the number of truth parameters produced by the crmp truth function.

    :param feature: Feature name used to look up the forward transform configuration
    :param _num_classes: Number of classes (unused)
    :param _config: Truth configuration (unused)
    :return: Twice the number of bins of the feature's forward transform
    """
    from pyaaware import ForwardTransform
    from pyaaware import feature_forward_transform_config

    transform_config = feature_forward_transform_config(feature)
    transform = ForwardTransform(**transform_config)
    return transform.bins * 2
89
+
90
+
91
def crmp(mixdb: MixtureDatabase, m_id: int, category: str, _config: dict) -> Truth:
    """Complex ratio mask polar truth generation function

    Same as the crm function except the results are magnitude and phase
    instead of real and imaginary. The magnitudes fill the first half of the
    output columns and the phases fill the second half.

    Output shape: [:, 2 * bins]
    """
    return _core(
        mixdb=mixdb,
        m_id=m_id,
        category=category,
        parameters=crmp_parameters(mixdb.feature, mixdb.num_classes, _config),
        polar=True,
    )
@@ -0,0 +1,222 @@
1
+ import numpy as np
2
+
3
+ from ...datatypes import Truth
4
+ from ...utils.load_object import load_object
5
+ from ..mixdb import MixtureDatabase
6
+
7
+
8
def _core(
    mixdb: MixtureDatabase,
    m_id: int,
    category: str,
    config: dict,
    parameters: int,
    mapped: bool,
    snr: bool,
    use_cache: bool = True,
) -> Truth:
    """Shared implementation of the energy_f, snr_f, and mapped_snr_f truth functions.

    :param mixdb: Mixture database
    :param m_id: Mixture ID
    :param category: Source category to generate truth for
    :param config: Truth configuration; must contain 'snr_db_mean' and 'snr_db_std' (paths relative
                   to mixdb.location) when mapped is True
    :param parameters: Number of truth parameters (used to size the all-zero result)
    :param mapped: If True, map each frame through _calculate_mapped_snr_f
    :param snr: If True, divide the source energy by the noise energy
    :param use_cache: Passed through to load_object when loading the mapping statistics
    :return: Truth data [frames, bins]; all zeros [frames, parameters] when the source's snr_gain is 0
    """
    from os.path import join

    import torch
    from pyaaware import ForwardTransform
    from pyaaware import feature_forward_transform_config

    from ...utils.energy_f import compute_energy_f

    source_audio = mixdb.mixture_sources(m_id)[category]
    ft = ForwardTransform(**feature_forward_transform_config(mixdb.feature))

    frames = ft.frames(torch.from_numpy(source_audio))

    if mixdb.mixture(m_id).all_sources[category].snr_gain == 0:
        return np.zeros((frames, parameters), dtype=np.float32)

    source_energy = compute_energy_f(time_domain=source_audio, transform=ft)
    noise_energy = None
    if snr:
        # Noise audio is only needed (and only fetched) for SNR-based truth
        noise_energy = compute_energy_f(time_domain=mixdb.mixture_noise(m_id), transform=ft)

    frames = len(source_energy)

    # The mapping statistics do not depend on the frame, so load them once up front
    snr_db_mean = None
    snr_db_std = None
    if mapped and frames > 0:
        snr_db_mean = load_object(join(mixdb.location, config["snr_db_mean"]), use_cache)
        snr_db_std = load_object(join(mixdb.location, config["snr_db_std"]), use_cache)

    truth = np.empty((frames, ft.bins), dtype=np.float32)
    for frame in range(frames):
        tmp = source_energy[frame]

        if noise_energy is not None:
            # errstate restores numpy's error handling even if the division raises
            with np.errstate(divide="ignore", invalid="ignore"):
                tmp /= noise_energy[frame]

        # NaN (e.g., 0/0) is treated as the lowest possible value
        tmp = np.nan_to_num(tmp, nan=-np.inf, posinf=np.inf, neginf=-np.inf)

        if mapped:
            tmp = _calculate_mapped_snr_f(tmp, snr_db_mean, snr_db_std)

        truth[frame] = tmp

    return truth
61
+
62
+
63
+ def _calculate_mapped_snr_f(truth_f: np.ndarray, snr_db_mean: np.ndarray, snr_db_std: np.ndarray) -> np.ndarray:
64
+ """Calculate mapped SNR from standard SNR energy per bin/class."""
65
+ import scipy.special as sc
66
+
67
+ old_err = np.seterr(divide="ignore", invalid="ignore")
68
+ num = 10 * np.log10(np.double(truth_f)) - np.double(snr_db_mean)
69
+ den = np.double(snr_db_std) * np.sqrt(2)
70
+ q = num / den
71
+ q = np.nan_to_num(q, nan=-np.inf, posinf=np.inf, neginf=-np.inf)
72
+ result = 0.5 * (1 + sc.erf(q))
73
+ np.seterr(**old_err)
74
+
75
+ return result.astype(np.float32)
76
+
77
+
78
def energy_f_validate(_config: dict) -> None:
    """Validate the energy_f truth configuration.

    No checks are performed; any configuration is accepted.

    :param _config: Truth configuration (unused)
    """
    return None
80
+
81
+
82
def energy_f_parameters(feature: str, _num_classes: int, _config: dict) -> int:
    """Get the number of truth parameters produced by the energy_f truth function.

    :param feature: Feature name used to look up the forward transform configuration
    :param _num_classes: Number of classes (unused)
    :param _config: Truth configuration (unused)
    :return: Number of bins of the feature's forward transform
    """
    from pyaaware import ForwardTransform
    from pyaaware import feature_forward_transform_config

    transform_config = feature_forward_transform_config(feature)
    transform = ForwardTransform(**transform_config)
    return transform.bins
87
+
88
+
89
def energy_f(mixdb: MixtureDatabase, m_id: int, category: str, config: dict, use_cache: bool = True) -> Truth:
    """Frequency domain energy truth generation function

    Calculates the true energy per bin:

    Ti^2 + Tr^2

    where T is the target STFT bin values.

    Output shape: [:, bins]
    """
    parameters = energy_f_parameters(mixdb.feature, mixdb.num_classes, config)
    return _core(
        mixdb=mixdb, m_id=m_id, category=category, config=config,
        parameters=parameters, mapped=False, snr=False, use_cache=use_cache,
    )
110
+
111
+
112
def snr_f_validate(_config: dict) -> None:
    """Validate the snr_f truth configuration.

    No checks are performed; any configuration is accepted.

    :param _config: Truth configuration (unused)
    """
    return None
114
+
115
+
116
def snr_f_parameters(feature: str, _num_classes: int, _config: dict) -> int:
    """Get the number of truth parameters produced by the snr_f truth function.

    :param feature: Feature name used to look up the forward transform configuration
    :param _num_classes: Number of classes (unused)
    :param _config: Truth configuration (unused)
    :return: Number of bins of the feature's forward transform
    """
    from pyaaware import ForwardTransform
    from pyaaware import feature_forward_transform_config

    transform_config = feature_forward_transform_config(feature)
    transform = ForwardTransform(**transform_config)
    return transform.bins
121
+
122
+
123
def snr_f(mixdb: MixtureDatabase, m_id: int, category: str, config: dict, use_cache: bool = True) -> Truth:
    """Frequency domain SNR truth function documentation

    Calculates the true SNR per bin:

    (Ti^2 + Tr^2) / (Ni^2 + Nr^2)

    where T is the target and N is the noise STFT bin values.

    Output shape: [:, bins]
    """
    parameters = snr_f_parameters(mixdb.feature, mixdb.num_classes, config)
    return _core(
        mixdb=mixdb, m_id=m_id, category=category, config=config,
        parameters=parameters, mapped=False, snr=True, use_cache=use_cache,
    )
144
+
145
+
146
def mapped_snr_f_validate(config: dict) -> None:
    """Validate the mapped_snr_f truth configuration.

    :param config: Truth configuration; must contain 'snr_db_mean' and 'snr_db_std'
    :raises AttributeError: If the configuration is empty or a required key is missing
    """
    if not len(config):
        raise AttributeError("mapped_snr_f truth function is missing config")

    # Report the first required key that is absent, in declaration order
    required = ("snr_db_mean", "snr_db_std")
    missing = next((key for key in required if key not in config), None)
    if missing is not None:
        raise AttributeError(f"mapped_snr_f truth function is missing required '{missing}'")
153
+
154
+
155
def mapped_snr_f_parameters(feature: str, _num_classes: int, _config: dict) -> int:
    """Get the number of truth parameters produced by the mapped_snr_f truth function.

    :param feature: Feature name used to look up the forward transform configuration
    :param _num_classes: Number of classes (unused)
    :param _config: Truth configuration (unused)
    :return: Number of bins of the feature's forward transform
    """
    from pyaaware import ForwardTransform
    from pyaaware import feature_forward_transform_config

    transform_config = feature_forward_transform_config(feature)
    transform = ForwardTransform(**transform_config)
    return transform.bins
160
+
161
+
162
def mapped_snr_f(mixdb: MixtureDatabase, m_id: int, category: str, config: dict, use_cache: bool = True) -> Truth:
    """Frequency domain mapped SNR truth function documentation

    Calculates the SNR per bin and maps it using the 'snr_db_mean' and
    'snr_db_std' statistics named in the truth configuration.

    Output shape: [:, bins]
    """
    parameters = mapped_snr_f_parameters(mixdb.feature, mixdb.num_classes, config)
    return _core(
        mixdb=mixdb, m_id=m_id, category=category, config=config,
        parameters=parameters, mapped=True, snr=True, use_cache=use_cache,
    )
177
+
178
+
179
def energy_t_validate(_config: dict) -> None:
    """Validate the energy_t truth configuration.

    No checks are performed; any configuration is accepted.

    :param _config: Truth configuration (unused)
    """
    return None
181
+
182
+
183
def energy_t_parameters(_feature: str, _num_classes: int, _config: dict) -> int:
    """Get the number of truth parameters produced by the energy_t truth function.

    :param _feature: Feature name (unused)
    :param _num_classes: Number of classes (unused)
    :param _config: Truth configuration (unused)
    :return: Always 1
    """
    parameters = 1
    return parameters
185
+
186
+
187
def energy_t(mixdb: MixtureDatabase, m_id: int, category: str, _config: dict) -> Truth:
    """Time domain energy truth function documentation

    Calculates the true time domain energy of each frame:

    For OLS:
        sum(x[0:N-1]^2) / N

    For OLA:
        sum(x[0:R-1]^2) / R

    where x is the target time domain data,
    N is the size of the transform, and
    R is the number of new samples in the frame.

    Output shape: [:, 1]

    Note: feature transforms can be defined to use a subset of all bins,
    i.e., subset of 0:128 for N=256 could be 0:127 or 1:128. energy_t
    will reflect the total energy over all bins regardless of the feature
    transform config.
    """
    import torch
    from pyaaware import ForwardTransform
    from pyaaware import feature_forward_transform_config

    source_audio = torch.from_numpy(mixdb.mixture_sources(m_id)[category])

    ft = ForwardTransform(**feature_forward_transform_config(mixdb.feature))

    frames = ft.frames(source_audio)
    # Size the all-zero result with this function's own parameter count so that it matches
    # the documented output shape
    parameters = energy_t_parameters(mixdb.feature, mixdb.num_classes, _config)
    if mixdb.mixture(m_id).all_sources[category].snr_gain == 0:
        return np.zeros((frames, parameters), dtype=np.float32)

    # The second element returned by execute_all is used as the time domain energy
    return ft.execute_all(source_audio)[1].numpy()
@@ -0,0 +1,48 @@
1
+ from ...datatypes import Truth
2
+ from ..mixdb import MixtureDatabase
3
+
4
+
5
def file_validate(config: dict) -> None:
    """Validate the file truth configuration.

    :param config: Truth configuration; must contain 'file', the path of an HDF5 file holding
                   a 'truth_f' dataset
    :raises AttributeError: If the configuration is empty or 'file' is missing
    :raises ValueError: If the file does not contain 'truth_f'
    """
    import h5py

    if not len(config):
        raise AttributeError("file truth function is missing config")

    if "file" not in config:
        raise AttributeError("file truth function is missing required 'file'")

    with h5py.File(config["file"], "r") as truth_file:
        has_truth = "truth_f" in truth_file
        if not has_truth:
            raise ValueError("Truth file does not contain truth_f dataset")
17
+
18
+
19
def file_parameters(_feature: str, _num_classes: int, config: dict) -> int:
    """Get the number of truth parameters provided by the truth file.

    :param _feature: Feature name (unused)
    :param _num_classes: Number of classes (unused)
    :param config: Truth configuration; 'file' is the path of the HDF5 truth file
    :return: Size of the last dimension of the file's 'truth_f' dataset
    """
    import h5py

    with h5py.File(config["file"], "r") as f:
        # Read the shape from the dataset itself instead of loading all of the data into memory
        return f["truth_f"].shape[-1]
27
+
28
+
29
def file(mixdb: MixtureDatabase, m_id: int, category: str, config: dict) -> Truth:
    """file truth function documentation

    Reads truth data from the 'truth_f' dataset of the HDF5 file named by config['file'].
    The data must be two-dimensional and contain one row per frame of the source audio.

    :raises ValueError: If the data is not 2 dimensions or has the wrong number of frames
    """
    import h5py
    import numpy as np
    from pyaaware import feature_inverse_transform_config

    source_audio = mixdb.mixture_sources(m_id)[category]

    frame_size = feature_inverse_transform_config(mixdb.feature)["overlap"]

    with h5py.File(config["file"], "r") as truth_file:
        truth = np.array(truth_file["truth_f"])

    if truth.ndim != 2:
        raise ValueError("Truth file data is not 2 dimensions")

    expected_frames = len(source_audio) // frame_size
    if truth.shape[0] != expected_frames:
        raise ValueError("Truth file does not contain the right amount of frames")

    return truth
@@ -0,0 +1,24 @@
1
+ from ...datatypes import Truth
2
+ from ..mixdb import MixtureDatabase
3
+
4
+
5
def metadata_validate(config: dict) -> None:
    """Validate the metadata truth configuration.

    :param config: Truth configuration; must contain 'tier'
    :raises AttributeError: If the configuration is empty or a required key is missing
    """
    if not len(config):
        raise AttributeError("metadata truth function is missing config")

    # Report the first required key that is absent
    required = ("tier",)
    missing = next((key for key in required if key not in config), None)
    if missing is not None:
        raise AttributeError(f"metadata truth function is missing required '{missing}'")
13
+
14
+
15
def metadata_parameters(_feature: str, _num_classes: int, _config: dict) -> int | None:
    """Get the number of truth parameters for the metadata truth function.

    :param _feature: Feature name (unused)
    :param _num_classes: Number of classes (unused)
    :param _config: Truth configuration (unused)
    :return: Always None; no parameter count is reported for metadata truth
    """
    parameters = None
    return parameters
17
+
18
+
19
def metadata(mixdb: MixtureDatabase, m_id: int, category: str, config: dict) -> Truth:
    """Metadata truth generation function

    Retrieves metadata from target: looks up the mixture's speech metadata for
    the tier named by config['tier'] and returns the entry for the given category.
    """
    per_category = mixdb.mixture_speech_metadata(m_id, config["tier"])
    return per_category[category]
@@ -0,0 +1,28 @@
1
+ from ...datatypes import Truth
2
+ from ..mixdb import MixtureDatabase
3
+
4
+
5
def metrics_validate(config: dict) -> None:
    """Validate the metrics truth configuration.

    :param config: Truth configuration; must contain 'metric'
    :raises AttributeError: If the configuration is empty or a required key is missing
    """
    if not len(config):
        raise AttributeError("metrics truth function is missing config")

    # Report the first required key that is absent
    required = ("metric",)
    missing = next((key for key in required if key not in config), None)
    if missing is not None:
        raise AttributeError(f"metrics truth function is missing required '{missing}'")
13
+
14
+
15
def metrics_parameters(_feature: str, _num_classes: int, _config: dict) -> int | None:
    """Get the number of truth parameters for the metrics truth function.

    :param _feature: Feature name (unused)
    :param _num_classes: Number of classes (unused)
    :param _config: Truth configuration (unused)
    :return: Always None; no parameter count is reported for metrics truth
    """
    parameters = None
    return parameters
17
+
18
+
19
def metrics(mixdb: MixtureDatabase, m_id: int, category: str, config: dict) -> Truth:
    """Metrics truth generation function

    Retrieves metrics from target.

    config['metric'] may be a single metric name or a list of names. All names are
    passed to the metrics lookup, but only the result for the first name is returned.
    """
    requested = config["metric"]
    names = requested if isinstance(requested, list) else [requested]
    return mixdb.mixture_metrics(m_id, names)[names[0]][category]
@@ -0,0 +1,18 @@
1
+ from ...datatypes import Truth
2
+ from ..mixdb import MixtureDatabase
3
+
4
+
5
def phoneme_validate(_config: dict) -> None:
    """Validate the phoneme truth configuration.

    :param _config: Truth configuration (unused)
    :raises NotImplementedError: Always; the phoneme truth function is not supported yet
    """
    message = "Truth function phoneme is not supported yet"
    raise NotImplementedError(message)
7
+
8
+
9
def phoneme_parameters(_feature: str, _num_classes: int, _config: dict) -> int:
    """Get the number of truth parameters for the phoneme truth function.

    :param _feature: Feature name (unused)
    :param _num_classes: Number of classes (unused)
    :param _config: Truth configuration (unused)
    :raises NotImplementedError: Always; the phoneme truth function is not supported yet
    """
    message = "Truth function phoneme is not supported yet"
    raise NotImplementedError(message)
11
+
12
+
13
def phoneme(_mixdb: MixtureDatabase, _m_id: int, _category: str, _config: dict) -> Truth:
    """Phoneme truth generation function (not supported yet)

    Intended behavior: read in .txt transcript and run a Python function to generate
    text grid data (indicating which phonemes are active), then generate truth based on
    this data and put it in the correct classes based on the index in the config.

    Currently always raises NotImplementedError.
    """
    message = "Truth function phoneme is not supported yet"
    raise NotImplementedError(message)