sonusai 0.18.9__py3-none-any.whl → 0.19.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (118):
  1. sonusai/__init__.py +20 -29
  2. sonusai/aawscd_probwrite.py +18 -18
  3. sonusai/audiofe.py +93 -80
  4. sonusai/calc_metric_spenh.py +395 -321
  5. sonusai/data/genmixdb.yml +5 -11
  6. sonusai/{gentcst.py → deprecated/gentcst.py} +146 -149
  7. sonusai/{plot.py → deprecated/plot.py} +177 -131
  8. sonusai/{tplot.py → deprecated/tplot.py} +124 -102
  9. sonusai/doc/__init__.py +1 -1
  10. sonusai/doc/doc.py +112 -177
  11. sonusai/doc.py +10 -10
  12. sonusai/genft.py +93 -77
  13. sonusai/genmetrics.py +59 -46
  14. sonusai/genmix.py +116 -104
  15. sonusai/genmixdb.py +194 -153
  16. sonusai/lsdb.py +56 -66
  17. sonusai/main.py +23 -20
  18. sonusai/metrics/__init__.py +2 -0
  19. sonusai/metrics/calc_audio_stats.py +29 -24
  20. sonusai/metrics/calc_class_weights.py +7 -7
  21. sonusai/metrics/calc_optimal_thresholds.py +5 -7
  22. sonusai/metrics/calc_pcm.py +3 -3
  23. sonusai/metrics/calc_pesq.py +10 -7
  24. sonusai/metrics/calc_phase_distance.py +3 -3
  25. sonusai/metrics/calc_sa_sdr.py +10 -8
  26. sonusai/metrics/calc_segsnr_f.py +15 -17
  27. sonusai/metrics/calc_speech.py +105 -47
  28. sonusai/metrics/calc_wer.py +35 -32
  29. sonusai/metrics/calc_wsdr.py +10 -7
  30. sonusai/metrics/class_summary.py +30 -27
  31. sonusai/metrics/confusion_matrix_summary.py +25 -22
  32. sonusai/metrics/one_hot.py +91 -57
  33. sonusai/metrics/snr_summary.py +53 -46
  34. sonusai/mixture/__init__.py +19 -14
  35. sonusai/mixture/audio.py +4 -6
  36. sonusai/mixture/augmentation.py +37 -43
  37. sonusai/mixture/class_count.py +5 -14
  38. sonusai/mixture/config.py +292 -225
  39. sonusai/mixture/constants.py +41 -30
  40. sonusai/mixture/data_io.py +155 -0
  41. sonusai/mixture/datatypes.py +111 -108
  42. sonusai/mixture/db_datatypes.py +54 -70
  43. sonusai/mixture/eq_rule_is_valid.py +6 -9
  44. sonusai/mixture/feature.py +40 -38
  45. sonusai/mixture/generation.py +522 -389
  46. sonusai/mixture/helpers.py +217 -272
  47. sonusai/mixture/log_duration_and_sizes.py +16 -13
  48. sonusai/mixture/mixdb.py +669 -477
  49. sonusai/mixture/soundfile_audio.py +12 -17
  50. sonusai/mixture/sox_audio.py +91 -112
  51. sonusai/mixture/sox_augmentation.py +8 -9
  52. sonusai/mixture/spectral_mask.py +4 -6
  53. sonusai/mixture/target_class_balancing.py +41 -36
  54. sonusai/mixture/targets.py +69 -67
  55. sonusai/mixture/tokenized_shell_vars.py +23 -23
  56. sonusai/mixture/torchaudio_audio.py +14 -15
  57. sonusai/mixture/torchaudio_augmentation.py +23 -27
  58. sonusai/mixture/truth.py +48 -26
  59. sonusai/mixture/truth_functions/__init__.py +26 -0
  60. sonusai/mixture/truth_functions/crm.py +56 -38
  61. sonusai/mixture/truth_functions/datatypes.py +37 -0
  62. sonusai/mixture/truth_functions/energy.py +85 -59
  63. sonusai/mixture/truth_functions/file.py +30 -30
  64. sonusai/mixture/truth_functions/phoneme.py +14 -7
  65. sonusai/mixture/truth_functions/sed.py +71 -45
  66. sonusai/mixture/truth_functions/target.py +69 -106
  67. sonusai/mkwav.py +52 -85
  68. sonusai/onnx_predict.py +46 -43
  69. sonusai/queries/__init__.py +3 -1
  70. sonusai/queries/queries.py +100 -59
  71. sonusai/speech/__init__.py +2 -0
  72. sonusai/speech/l2arctic.py +24 -23
  73. sonusai/speech/librispeech.py +16 -17
  74. sonusai/speech/mcgill.py +22 -21
  75. sonusai/speech/textgrid.py +32 -25
  76. sonusai/speech/timit.py +45 -42
  77. sonusai/speech/vctk.py +14 -13
  78. sonusai/speech/voxceleb.py +26 -20
  79. sonusai/summarize_metric_spenh.py +11 -10
  80. sonusai/utils/__init__.py +4 -3
  81. sonusai/utils/asl_p56.py +1 -1
  82. sonusai/utils/asr.py +37 -17
  83. sonusai/utils/asr_functions/__init__.py +2 -0
  84. sonusai/utils/asr_functions/aaware_whisper.py +18 -12
  85. sonusai/utils/audio_devices.py +12 -12
  86. sonusai/utils/braced_glob.py +6 -8
  87. sonusai/utils/calculate_input_shape.py +1 -4
  88. sonusai/utils/compress.py +2 -2
  89. sonusai/utils/convert_string_to_number.py +1 -3
  90. sonusai/utils/create_timestamp.py +1 -1
  91. sonusai/utils/create_ts_name.py +2 -2
  92. sonusai/utils/dataclass_from_dict.py +1 -1
  93. sonusai/utils/docstring.py +6 -6
  94. sonusai/utils/energy_f.py +9 -7
  95. sonusai/utils/engineering_number.py +56 -54
  96. sonusai/utils/get_label_names.py +8 -10
  97. sonusai/utils/human_readable_size.py +2 -2
  98. sonusai/utils/model_utils.py +3 -5
  99. sonusai/utils/numeric_conversion.py +2 -4
  100. sonusai/utils/onnx_utils.py +43 -32
  101. sonusai/utils/parallel.py +40 -27
  102. sonusai/utils/print_mixture_details.py +25 -22
  103. sonusai/utils/ranges.py +12 -12
  104. sonusai/utils/read_predict_data.py +11 -9
  105. sonusai/utils/reshape.py +19 -26
  106. sonusai/utils/seconds_to_hms.py +1 -1
  107. sonusai/utils/stacked_complex.py +8 -16
  108. sonusai/utils/stratified_shuffle_split.py +29 -27
  109. sonusai/utils/write_audio.py +2 -2
  110. sonusai/utils/yes_or_no.py +3 -3
  111. sonusai/vars.py +14 -14
  112. {sonusai-0.18.9.dist-info → sonusai-0.19.5.dist-info}/METADATA +20 -21
  113. sonusai-0.19.5.dist-info/RECORD +125 -0
  114. {sonusai-0.18.9.dist-info → sonusai-0.19.5.dist-info}/WHEEL +1 -1
  115. sonusai/mixture/truth_functions/data.py +0 -58
  116. sonusai/utils/read_mixture_data.py +0 -14
  117. sonusai-0.18.9.dist-info/RECORD +0 -125
  118. {sonusai-0.18.9.dist-info → sonusai-0.19.5.dist-info}/entry_points.txt +0 -0
sonusai/mixture/truth.py CHANGED
@@ -1,50 +1,72 @@
1
1
  from sonusai.mixture.datatypes import AudioT
2
2
  from sonusai.mixture.datatypes import Truth
3
- from sonusai.mixture.datatypes import TruthFunctionConfig
3
+ from sonusai.mixture.datatypes import TruthConfig
4
4
  from sonusai.mixture.mixdb import MixtureDatabase
5
5
 
6
6
 
7
- def truth_function(target_audio: AudioT,
8
- noise_audio: AudioT,
9
- mixture_audio: AudioT,
10
- config: TruthFunctionConfig) -> Truth:
11
- from sonusai import SonusAIError
7
+ def truth_function(
8
+ target_audio: AudioT,
9
+ noise_audio: AudioT,
10
+ mixture_audio: AudioT,
11
+ config: TruthConfig,
12
+ feature: str,
13
+ num_classes: int,
14
+ class_indices: list[int],
15
+ target_gain: float,
16
+ ) -> Truth:
12
17
  from sonusai.mixture import truth_functions
13
- from .truth_functions.data import Data
14
18
 
15
- data = Data(target_audio, noise_audio, mixture_audio, config)
16
- if data.config.target_gain == 0:
17
- return data.truth
19
+ from .truth_functions.datatypes import TruthFunctionConfig
20
+ from .truth_functions.datatypes import TruthFunctionData
21
+
22
+ t_config = TruthFunctionConfig(
23
+ feature=feature,
24
+ num_classes=num_classes,
25
+ class_indices=class_indices,
26
+ target_gain=target_gain,
27
+ config=config.config,
28
+ )
29
+ t_data = TruthFunctionData(target_audio, noise_audio, mixture_audio)
18
30
 
19
31
  try:
20
- return getattr(truth_functions, data.config.function)(data)
21
- except AttributeError:
22
- raise SonusAIError(f'Unsupported truth function: {data.config.function}')
32
+ return getattr(truth_functions, config.function)(t_data, t_config)
33
+ except AttributeError as e:
34
+ raise AttributeError(f"Unsupported truth function: {config.function}") from e
35
+ except Exception as e:
36
+ raise RuntimeError(f"Error in truth function '{config.function}': {e}") from e
23
37
 
24
38
 
25
39
  def get_truth_indices_for_mixid(mixdb: MixtureDatabase, mixid: int) -> list[int]:
26
40
  """Get a list of truth indices for a given mixid."""
27
- from .targets import get_truth_indices_for_target
28
-
29
41
  indices: list[int] = []
30
42
  for target_id in [target.file_id for target in mixdb.mixture(mixid).targets]:
31
- indices.append(*get_truth_indices_for_target(mixdb.target_file(target_id)))
43
+ indices.append(*mixdb.target_file(target_id).class_indices)
32
44
 
33
- return sorted(list(set(indices)))
45
+ return sorted(set(indices))
34
46
 
35
47
 
36
- def truth_reduction(x: Truth, func: str) -> Truth:
48
+ def truth_stride_reduction(truth: Truth, function: str) -> Truth:
49
+ """Reduce stride dimension of truth.
50
+
51
+ :param truth: Truth data [frames, stride, truth_parameters]
52
+ :param function: Truth stride reduction function name
53
+ :return: Stride reduced truth data [frames, stride or 1, truth_parameters]
54
+ """
37
55
  import numpy as np
38
56
 
39
- from sonusai import SonusAIError
57
+ if truth.ndim != 3:
58
+ raise ValueError("Invalid truth shape")
59
+
60
+ if function == "none":
61
+ return truth
40
62
 
41
- if func == 'max':
42
- return np.max(x, axis=0)
63
+ if function == "max":
64
+ return np.max(truth, axis=1, keepdims=True)
43
65
 
44
- if func == 'mean':
45
- return np.mean(x, axis=0)
66
+ if function == "mean":
67
+ return np.mean(truth, axis=1, keepdims=True)
46
68
 
47
- if func == 'index0':
48
- return np.squeeze(x[0, :])
69
+ if function == "first":
70
+ return truth[:, 0, :].reshape((truth.shape[0], 1, truth.shape[2]))
49
71
 
50
- raise SonusAIError(f'Invalid truth reduction function: {func}')
72
+ raise ValueError(f"Invalid truth stride reduction function: {function}")
@@ -1,13 +1,39 @@
1
1
  # SonusAI truth functions
2
+ # ruff: noqa: F401
3
+
2
4
  from .crm import crm
5
+ from .crm import crm_parameters
6
+ from .crm import crm_validate
3
7
  from .crm import crmp
8
+ from .crm import crmp_parameters
9
+ from .crm import crmp_validate
4
10
  from .energy import energy_f
11
+ from .energy import energy_f_parameters
12
+ from .energy import energy_f_validate
5
13
  from .energy import energy_t
14
+ from .energy import energy_t_parameters
15
+ from .energy import energy_t_validate
6
16
  from .energy import mapped_snr_f
17
+ from .energy import mapped_snr_f_parameters
18
+ from .energy import mapped_snr_f_validate
7
19
  from .energy import snr_f
20
+ from .energy import snr_f_parameters
21
+ from .energy import snr_f_validate
8
22
  from .file import file
23
+ from .file import file_parameters
24
+ from .file import file_validate
9
25
  from .phoneme import phoneme
26
+ from .phoneme import phoneme_parameters
27
+ from .phoneme import phoneme_validate
10
28
  from .sed import sed
29
+ from .sed import sed_parameters
30
+ from .sed import sed_validate
11
31
  from .target import target_f
32
+ from .target import target_f_parameters
33
+ from .target import target_f_validate
12
34
  from .target import target_mixture_f
35
+ from .target import target_mixture_f_parameters
36
+ from .target import target_mixture_f_validate
13
37
  from .target import target_swin_f
38
+ from .target import target_swin_f_parameters
39
+ from .target import target_swin_f_validate
@@ -1,25 +1,26 @@
1
1
  from sonusai.mixture.datatypes import Truth
2
- from sonusai.mixture.truth_functions.data import Data
2
+ from sonusai.mixture.truth_functions.datatypes import TruthFunctionConfig
3
+ from sonusai.mixture.truth_functions.datatypes import TruthFunctionData
3
4
 
4
5
 
5
- def _core(data: Data, polar: bool) -> Truth:
6
+ def _core(data: TruthFunctionData, config: TruthFunctionConfig, polar: bool) -> Truth:
6
7
  import numpy as np
7
8
 
8
- from sonusai import SonusAIError
9
-
10
- if data.config.num_classes != data.target_fft.bins:
11
- raise SonusAIError(f'Invalid num_classes for crm truth: {data.config.num_classes}')
12
-
13
- if data.target_fft.bins != data.noise_fft.bins:
14
- raise SonusAIError('Transform size mismatch for crm truth')
15
-
16
- for offset in data.offsets:
17
- target_f = data.target_fft.execute(data.target_audio[offset:offset + data.frame_size]).astype(np.complex64)
18
- noise_f = data.noise_fft.execute(data.noise_audio[offset:offset + data.frame_size]).astype(np.complex64)
9
+ if config.target_fft.bins != config.noise_fft.bins:
10
+ raise ValueError("Transform size mismatch for crm truth")
11
+
12
+ frames = len(data.target_audio) // config.frame_size
13
+ truth = np.empty((frames, config.target_fft.bins * 2), dtype=np.float32)
14
+ for frame in range(frames):
15
+ offset = frame * config.frame_size
16
+ target_f = config.target_fft.execute(data.target_audio[offset : offset + config.frame_size]).astype(
17
+ np.complex64
18
+ )
19
+ noise_f = config.noise_fft.execute(data.noise_audio[offset : offset + config.frame_size]).astype(np.complex64)
19
20
  mixture_f = target_f + noise_f
20
21
 
21
22
  crm_data = np.empty(target_f.shape, dtype=np.complex64)
22
- with np.nditer(target_f, flags=['multi_index'], op_flags=[['readwrite']]) as it:
23
+ with np.nditer(target_f, flags=["multi_index"], op_flags=[["readwrite"]]) as it:
23
24
  for _ in it:
24
25
  num = target_f[it.multi_index]
25
26
  den = mixture_f[it.multi_index]
@@ -30,44 +31,61 @@ def _core(data: Data, polar: bool) -> Truth:
30
31
  else:
31
32
  crm_data[it.multi_index] = num / den
32
33
 
33
- indices = slice(offset, offset + data.frame_size)
34
+ truth[frame, : config.target_fft.bins] = np.absolute(crm_data) if polar else np.real(crm_data)
35
+ truth[frame, config.target_fft.bins :] = np.angle(crm_data) if polar else np.imag(crm_data)
36
+
37
+ return truth
34
38
 
35
- def c1(c_data: np.ndarray, is_polar: bool) -> np.ndarray:
36
- if is_polar:
37
- return np.absolute(c_data)
38
- return np.real(c_data)
39
39
 
40
- def c2(c_data: np.ndarray, is_polar: bool) -> np.ndarray:
41
- if is_polar:
42
- return np.angle(c_data)
43
- return np.imag(c_data)
40
+ def crm_validate(_config: dict) -> None:
41
+ pass
44
42
 
45
- for index in data.zero_based_indices:
46
- data.truth[indices, index:index + data.target_fft.bins] = c1(crm_data, polar)
47
- data.truth[indices, (index + data.target_fft.bins):(index + 2 * data.target_fft.bins)] = c2(crm_data, polar)
48
43
 
49
- return data.truth
44
+ def crm_parameters(config: TruthFunctionConfig) -> int:
45
+ return config.target_fft.bins * 2
50
46
 
51
47
 
52
- def crm(data: Data) -> Truth:
48
+ def crm(data: TruthFunctionData, config: TruthFunctionConfig) -> Truth:
53
49
  """Complex ratio mask truth generation function
54
50
 
55
- Calculates the true complex ratio mask (CRM) truth which is a complex number
56
- per bin = Mr + j*Mi. For a given noisy STFT bin value Y, it is used as
51
+ Calculates the true complex ratio mask (CRM) truth which is a complex number
52
+ per bin = Mr + j*Mi. For a given noisy STFT bin value Y, it is used as
57
53
 
58
- (Mr*Yr + Mi*Yi) / (Yr^2 + Yi^2) + j*(Mi*Yr - Mr*Yi)/ (Yr^2 + Yi^2)
54
+ (Mr*Yr + Mi*Yi) / (Yr^2 + Yi^2) + j*(Mi*Yr - Mr*Yi)/ (Yr^2 + Yi^2)
59
55
 
60
- Output shape: [:, bins]
56
+ Output shape: [:, 2 * bins]
61
57
  """
62
- return _core(data=data, polar=False)
58
+ import numpy as np
59
+
60
+ frames = config.target_fft.frames(data.target_audio)
61
+ parameters = crm_parameters(config)
62
+ if config.target_gain == 0:
63
+ return np.zeros((frames, parameters), dtype=np.float32)
64
+
65
+ return _core(data=data, config=config, polar=False)
63
66
 
64
67
 
65
- def crmp(data: Data) -> Truth:
68
+ def crmp_validate(_config: dict) -> None:
69
+ pass
70
+
71
+
72
+ def crmp_parameters(config: TruthFunctionConfig) -> int:
73
+ return config.target_fft.bins * 2
74
+
75
+
76
+ def crmp(data: TruthFunctionData, config: TruthFunctionConfig) -> Truth:
66
77
  """Complex ratio mask polar truth generation function
67
78
 
68
- Same as the crm function except the results are magnitude and phase
69
- instead of real and imaginary.
79
+ Same as the crm function except the results are magnitude and phase
80
+ instead of real and imaginary.
70
81
 
71
- Output shape: [:, bins]
82
+ Output shape: [:, bins]
72
83
  """
73
- return _core(data=data, polar=True)
84
+ import numpy as np
85
+
86
+ frames = config.target_fft.frames(data.target_audio)
87
+ parameters = crmp_parameters(config)
88
+ if config.target_gain == 0:
89
+ return np.zeros((frames, parameters), dtype=np.float32)
90
+
91
+ return _core(data=data, config=config, polar=True)
@@ -0,0 +1,37 @@
1
+ from dataclasses import dataclass
2
+
3
+ from sonusai.mixture.datatypes import AudioT
4
+
5
+
6
+ class TruthFunctionConfig:
7
+ def __init__(self, feature: str, num_classes: int, class_indices: list[int], target_gain: float, config: dict):
8
+ from pyaaware import ForwardTransform
9
+ from pyaaware import InverseTransform
10
+ from pyaaware import feature_forward_transform_config
11
+ from pyaaware import feature_inverse_transform_config
12
+ from pyaaware import feature_parameters
13
+
14
+ self.feature = feature
15
+ self.num_classes = num_classes
16
+ self.class_indices = class_indices
17
+ self.target_gain = target_gain
18
+ self.config = config
19
+
20
+ self.feature_parameters = feature_parameters(feature)
21
+ ft_config = feature_forward_transform_config(feature)
22
+ it_config = feature_inverse_transform_config(feature)
23
+
24
+ self.ttype = it_config["ttype"]
25
+ self.frame_size = it_config["overlap"]
26
+
27
+ self.target_fft = ForwardTransform(**ft_config)
28
+ self.noise_fft = ForwardTransform(**ft_config)
29
+ self.mixture_fft = ForwardTransform(**ft_config)
30
+ self.swin = InverseTransform(**it_config).window
31
+
32
+
33
+ @dataclass
34
+ class TruthFunctionData:
35
+ target_audio: AudioT
36
+ noise_audio: AudioT
37
+ mixture_audio: AudioT
@@ -1,69 +1,43 @@
1
1
  import numpy as np
2
2
 
3
3
  from sonusai.mixture.datatypes import Truth
4
- from sonusai.mixture.truth_functions.data import Data
4
+ from sonusai.mixture.truth_functions.datatypes import TruthFunctionConfig
5
+ from sonusai.mixture.truth_functions.datatypes import TruthFunctionData
5
6
 
6
7
 
7
- def _core(data: Data, mapped: bool, snr: bool) -> Truth:
8
- from sonusai import SonusAIError
8
+ def _core(data: TruthFunctionData, config: TruthFunctionConfig, mapped: bool, snr: bool) -> Truth:
9
9
  from sonusai.utils import compute_energy_f
10
10
 
11
- snr_db_mean = None
12
- snr_db_std = None
13
- if mapped:
14
- if data.config.config is None:
15
- raise SonusAIError('Truth function mapped SNR missing config')
16
-
17
- parameters = ['snr_db_mean', 'snr_db_std']
18
- for parameter in parameters:
19
- if parameter not in data.config.config:
20
- raise SonusAIError(f'Truth function mapped_snr_f config missing required parameter: {parameter}')
21
-
22
- snr_db_mean = data.config.config['snr_db_mean']
23
- if len(snr_db_mean) != data.target_fft.bins:
24
- raise SonusAIError(f'Truth function mapped_snr_f snr_db_mean does not have {data.target_fft.bins} elements')
25
-
26
- snr_db_std = data.config.config['snr_db_std']
27
- if len(snr_db_std) != data.target_fft.bins:
28
- raise SonusAIError(f'Truth function mapped_snr_f snr_db_std does not have {data.target_fft.bins} elements')
29
-
30
- for index in data.zero_based_indices:
31
- if index + data.target_fft.bins > data.config.num_classes:
32
- raise SonusAIError('Truth index exceeds the number of classes')
33
-
34
- target_energy = compute_energy_f(time_domain=data.target_audio, transform=data.target_fft)
11
+ target_energy = compute_energy_f(time_domain=data.target_audio, transform=config.target_fft)
35
12
  noise_energy = None
36
13
  if snr:
37
- noise_energy = compute_energy_f(time_domain=data.noise_audio, transform=data.noise_fft)
38
-
39
- if len(target_energy) != len(data.offsets):
40
- raise SonusAIError(f'Number of frames in target_energy, {len(target_energy)},'
41
- f' is not number of frames in truth, {len(data.offsets)}')
14
+ noise_energy = compute_energy_f(time_domain=data.noise_audio, transform=config.noise_fft)
42
15
 
43
- for idx, offset in enumerate(data.offsets):
44
- tmp = target_energy[idx]
16
+ frames = len(target_energy)
17
+ truth = np.empty((frames, config.target_fft.bins), dtype=np.float32)
18
+ for frame in range(frames):
19
+ tmp = target_energy[frame]
45
20
 
46
- if snr:
47
- old_err = np.seterr(divide='ignore', invalid='ignore')
48
- tmp /= noise_energy[idx]
21
+ if noise_energy is not None:
22
+ old_err = np.seterr(divide="ignore", invalid="ignore")
23
+ tmp /= noise_energy[frame]
49
24
  np.seterr(**old_err)
50
25
 
51
26
  tmp = np.nan_to_num(tmp, nan=-np.inf, posinf=np.inf, neginf=-np.inf)
52
27
 
53
28
  if mapped:
54
- tmp = _calculate_mapped_snr_f(tmp, snr_db_mean, snr_db_std)
29
+ tmp = _calculate_mapped_snr_f(tmp, config.config["snr_db_mean"], config.config["snr_db_std"])
55
30
 
56
- for index in data.zero_based_indices:
57
- data.truth[offset:offset + data.frame_size, index:index + data.target_fft.bins] = tmp
31
+ truth[frame] = tmp
58
32
 
59
- return data.truth
33
+ return truth
60
34
 
61
35
 
62
36
  def _calculate_mapped_snr_f(truth_f: np.ndarray, snr_db_mean: np.ndarray, snr_db_std: np.ndarray) -> np.ndarray:
63
37
  """Calculate mapped SNR from standard SNR energy per bin/class."""
64
38
  import scipy.special as sc
65
39
 
66
- old_err = np.seterr(divide='ignore', invalid='ignore')
40
+ old_err = np.seterr(divide="ignore", invalid="ignore")
67
41
  num = 10 * np.log10(np.double(truth_f)) - np.double(snr_db_mean)
68
42
  den = np.double(snr_db_std) * np.sqrt(2)
69
43
  q = num / den
@@ -74,7 +48,15 @@ def _calculate_mapped_snr_f(truth_f: np.ndarray, snr_db_mean: np.ndarray, snr_db
74
48
  return result.astype(np.float32)
75
49
 
76
50
 
77
- def energy_f(data: Data) -> Truth:
51
+ def energy_f_validate(_config: dict) -> None:
52
+ pass
53
+
54
+
55
+ def energy_f_parameters(config: TruthFunctionConfig) -> int:
56
+ return config.target_fft.bins
57
+
58
+
59
+ def energy_f(data: TruthFunctionData, config: TruthFunctionConfig) -> Truth:
78
60
  """Frequency domain energy truth generation function
79
61
 
80
62
  Calculates the true energy per bin:
@@ -85,10 +67,23 @@ def energy_f(data: Data) -> Truth:
85
67
 
86
68
  Output shape: [:, bins]
87
69
  """
88
- return _core(data=data, mapped=False, snr=False)
70
+ frames = config.target_fft.frames(data.target_audio)
71
+ parameters = energy_f_parameters(config)
72
+ if config.target_gain == 0:
73
+ return np.zeros((frames, parameters), dtype=np.float32)
74
+
75
+ return _core(data=data, config=config, mapped=False, snr=False)
76
+
77
+
78
+ def snr_f_validate(_config: dict) -> None:
79
+ pass
80
+
81
+
82
+ def snr_f_parameters(config: TruthFunctionConfig) -> int:
83
+ return config.target_fft.bins
89
84
 
90
85
 
91
- def snr_f(data: Data) -> Truth:
86
+ def snr_f(data: TruthFunctionData, config: TruthFunctionConfig) -> Truth:
92
87
  """Frequency domain SNR truth function documentation
93
88
 
94
89
  Calculates the true SNR per bin:
@@ -99,18 +94,54 @@ def snr_f(data: Data) -> Truth:
99
94
 
100
95
  Output shape: [:, bins]
101
96
  """
102
- return _core(data=data, mapped=False, snr=True)
97
+ frames = config.target_fft.frames(data.target_audio)
98
+ parameters = snr_f_parameters(config)
99
+ if config.target_gain == 0:
100
+ return np.zeros((frames, parameters), dtype=np.float32)
103
101
 
102
+ return _core(data=data, config=config, mapped=False, snr=True)
104
103
 
105
- def mapped_snr_f(data: Data) -> Truth:
104
+
105
+ def mapped_snr_f_validate(config: TruthFunctionConfig) -> None:
106
+ if len(config.config) == 0:
107
+ raise AttributeError("mapped_snr_f truth function is missing config")
108
+
109
+ for parameter in ("snr_db_mean", "snr_db_std"):
110
+ if parameter not in config.config:
111
+ raise AttributeError(f"mapped_snr_f truth function is missing required '{parameter}'")
112
+
113
+ if len(config.config[parameter]) != config.target_fft.bins:
114
+ raise ValueError(
115
+ f"mapped_snr_f truth function '{parameter}' does not have {config.target_fft.bins} elements"
116
+ )
117
+
118
+
119
+ def mapped_snr_f_parameters(config: TruthFunctionConfig) -> int:
120
+ return config.target_fft.bins
121
+
122
+
123
+ def mapped_snr_f(data: TruthFunctionData, config: TruthFunctionConfig) -> Truth:
106
124
  """Frequency domain mapped SNR truth function documentation
107
125
 
108
126
  Output shape: [:, bins]
109
127
  """
110
- return _core(data=data, mapped=True, snr=True)
128
+ frames = config.target_fft.frames(data.target_audio)
129
+ parameters = mapped_snr_f_parameters(config)
130
+ if config.target_gain == 0:
131
+ return np.zeros((frames, parameters), dtype=np.float32)
111
132
 
133
+ return _core(data=data, config=config, mapped=True, snr=True)
112
134
 
113
- def energy_t(data: Data) -> Truth:
135
+
136
+ def energy_t_validate(_config: dict) -> None:
137
+ pass
138
+
139
+
140
+ def energy_t_parameters(_config: TruthFunctionConfig) -> int:
141
+ return 1
142
+
143
+
144
+ def energy_t(data: TruthFunctionData, config: TruthFunctionConfig) -> Truth:
114
145
  """Time domain energy truth function documentation
115
146
 
116
147
  Calculates the true time domain energy of each frame:
@@ -134,14 +165,9 @@ def energy_t(data: Data) -> Truth:
134
165
  """
135
166
  import torch
136
167
 
137
- from sonusai import SonusAIError
138
-
139
- target_energy = data.target_fft.execute_all(torch.from_numpy(data.target_audio))[1].numpy()
140
- if len(target_energy) != len(data.offsets):
141
- raise SonusAIError(f'Number of frames in target_energy, {len(target_energy)},'
142
- f' is not number of frames in truth, {len(data.offsets)}')
143
-
144
- for offset in data.offsets:
145
- data.truth[offset:offset + data.frame_size, data.zero_based_indices] = np.float32(target_energy)
168
+ frames = config.target_fft.frames(data.target_audio)
169
+ parameters = energy_t_parameters(config)
170
+ if config.target_gain == 0:
171
+ return np.zeros((frames, parameters), dtype=np.float32)
146
172
 
147
- return data.truth
173
+ return config.target_fft.execute_all(torch.from_numpy(data.target_audio))[1].numpy()
@@ -1,44 +1,44 @@
1
1
  from sonusai.mixture.datatypes import Truth
2
- from sonusai.mixture.truth_functions.data import Data
2
+ from sonusai.mixture.truth_functions.datatypes import TruthFunctionConfig
3
+ from sonusai.mixture.truth_functions.datatypes import TruthFunctionData
3
4
 
4
5
 
5
- def file(data: Data) -> Truth:
6
- """file truth function documentation
7
- """
6
+ def file_validate(config: dict) -> None:
8
7
  import h5py
9
- import numpy as np
10
8
 
11
- from sonusai import SonusAIError
9
+ if len(config) == 0:
10
+ raise AttributeError("file truth function is missing config")
11
+
12
+ if "file" not in config:
13
+ raise AttributeError("file truth function is missing required 'file'")
14
+
15
+ with h5py.File(config["file"], "r") as f:
16
+ if "truth_f" not in f:
17
+ raise ValueError("Truth file does not contain truth_f dataset")
18
+
12
19
 
13
- if data.config.config is None:
14
- raise SonusAIError('Truth function file missing config')
20
+ def file_parameters(config: TruthFunctionConfig) -> int:
21
+ import h5py
22
+ import numpy as np
15
23
 
16
- parameters = ['file']
17
- for parameter in parameters:
18
- if 'file' not in data.config.config:
19
- raise SonusAIError(f'Truth function file config missing required parameter: {parameter}')
24
+ with h5py.File(config.config["file"], "r") as f:
25
+ truth = np.array(f["truth_f"])
20
26
 
21
- with h5py.File(data.config.config['file'], 'r') as f:
22
- if 'truth_t' not in f:
23
- raise SonusAIError('Truth file does not contain truth_t dataset')
24
- truth_in = np.array(f['truth_t'])
27
+ return truth.shape[-1]
25
28
 
26
- if truth_in.ndim != 2:
27
- raise SonusAIError('Truth file data is not 2 dimensions')
28
29
 
29
- if truth_in.shape[0] != len(data.target_audio):
30
- raise SonusAIError('Truth file does not contain the right amount of samples')
30
+ def file(data: TruthFunctionData, config: TruthFunctionConfig) -> Truth:
31
+ """file truth function documentation"""
32
+ import h5py
33
+ import numpy as np
31
34
 
32
- if len(data.zero_based_indices) > 1:
33
- if len(data.zero_based_indices) != truth_in.shape[1]:
34
- raise SonusAIError('Truth file does not contain the right amount of classes')
35
+ with h5py.File(config.config["file"], "r") as f:
36
+ truth = np.array(f["truth_f"])
35
37
 
36
- data.truth[:, data.zero_based_indices] = truth_in
37
- else:
38
- index = data.zero_based_indices[0]
39
- if index + truth_in.shape[1] > data.config.num_classes:
40
- raise SonusAIError('Truth file contains too many classes')
38
+ if truth.ndim != 2:
39
+ raise ValueError("Truth file data is not 2 dimensions")
41
40
 
42
- data.truth[:, index:index + truth_in.shape[1]] = truth_in
41
+ if truth.shape[0] != len(data.target_audio) // config.frame_size:
42
+ raise ValueError("Truth file does not contain the right amount of frames")
43
43
 
44
- return data.truth
44
+ return truth
@@ -1,12 +1,19 @@
1
1
  from sonusai.mixture.datatypes import Truth
2
- from sonusai.mixture.truth_functions.data import Data
2
+ from sonusai.mixture.truth_functions.datatypes import TruthFunctionConfig
3
+ from sonusai.mixture.truth_functions.datatypes import TruthFunctionData
3
4
 
4
5
 
5
- def phoneme(_data: Data) -> Truth:
6
+ def phoneme_validate(_config: dict) -> None:
7
+ raise NotImplementedError("Truth function phoneme is not supported yet")
8
+
9
+
10
+ def phoneme_parameters(_config: TruthFunctionConfig) -> int:
11
+ raise NotImplementedError("Truth function phoneme is not supported yet")
12
+
13
+
14
+ def phoneme(_data: TruthFunctionData, _config: TruthFunctionConfig) -> Truth:
6
15
  """Read in .txt transcript and run a Python function to generate text grid data
7
- (indicating which phonemes are active). Then generate truth based on this data and put
8
- in the correct classes based on the index in the config.
16
+ (indicating which phonemes are active). Then generate truth based on this data and put
17
+ in the correct classes based on the index in the config.
9
18
  """
10
- from sonusai import SonusAIError
11
-
12
- raise SonusAIError('Truth function phoneme is not supported yet')
19
+ raise NotImplementedError("Truth function phoneme is not supported yet")