sonusai 0.19.6__py3-none-any.whl → 0.19.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. sonusai/__init__.py +1 -1
  2. sonusai/aawscd_probwrite.py +1 -1
  3. sonusai/calc_metric_spenh.py +1 -1
  4. sonusai/genft.py +29 -14
  5. sonusai/genmetrics.py +60 -42
  6. sonusai/genmix.py +41 -29
  7. sonusai/genmixdb.py +56 -64
  8. sonusai/metrics/calc_class_weights.py +1 -3
  9. sonusai/metrics/calc_optimal_thresholds.py +2 -2
  10. sonusai/metrics/calc_phase_distance.py +1 -1
  11. sonusai/metrics/calc_speech.py +6 -6
  12. sonusai/metrics/class_summary.py +6 -15
  13. sonusai/metrics/confusion_matrix_summary.py +11 -27
  14. sonusai/metrics/one_hot.py +3 -3
  15. sonusai/metrics/snr_summary.py +7 -7
  16. sonusai/mixture/__init__.py +2 -17
  17. sonusai/mixture/augmentation.py +5 -6
  18. sonusai/mixture/class_count.py +1 -1
  19. sonusai/mixture/config.py +36 -46
  20. sonusai/mixture/data_io.py +30 -1
  21. sonusai/mixture/datatypes.py +29 -40
  22. sonusai/mixture/db_datatypes.py +1 -1
  23. sonusai/mixture/feature.py +3 -23
  24. sonusai/mixture/generation.py +161 -204
  25. sonusai/mixture/helpers.py +29 -187
  26. sonusai/mixture/mixdb.py +386 -159
  27. sonusai/mixture/soundfile_audio.py +1 -1
  28. sonusai/mixture/sox_audio.py +4 -4
  29. sonusai/mixture/sox_augmentation.py +1 -1
  30. sonusai/mixture/target_class_balancing.py +9 -11
  31. sonusai/mixture/targets.py +23 -20
  32. sonusai/mixture/torchaudio_audio.py +18 -7
  33. sonusai/mixture/torchaudio_augmentation.py +3 -4
  34. sonusai/mixture/truth.py +21 -34
  35. sonusai/mixture/truth_functions/__init__.py +6 -0
  36. sonusai/mixture/truth_functions/crm.py +51 -37
  37. sonusai/mixture/truth_functions/energy.py +95 -50
  38. sonusai/mixture/truth_functions/file.py +12 -8
  39. sonusai/mixture/truth_functions/metadata.py +24 -0
  40. sonusai/mixture/truth_functions/metrics.py +28 -0
  41. sonusai/mixture/truth_functions/phoneme.py +4 -5
  42. sonusai/mixture/truth_functions/sed.py +32 -23
  43. sonusai/mixture/truth_functions/target.py +62 -29
  44. sonusai/mkwav.py +20 -19
  45. sonusai/queries/queries.py +9 -15
  46. sonusai/speech/l2arctic.py +6 -2
  47. sonusai/summarize_metric_spenh.py +1 -1
  48. sonusai/utils/__init__.py +1 -0
  49. sonusai/utils/asr_functions/aaware_whisper.py +1 -1
  50. sonusai/utils/audio_devices.py +27 -18
  51. sonusai/utils/docstring.py +6 -3
  52. sonusai/utils/energy_f.py +5 -3
  53. sonusai/utils/human_readable_size.py +6 -6
  54. sonusai/utils/load_object.py +15 -0
  55. sonusai/utils/onnx_utils.py +2 -2
  56. sonusai/utils/print_mixture_details.py +3 -3
  57. {sonusai-0.19.6.dist-info → sonusai-0.19.9.dist-info}/METADATA +2 -2
  58. {sonusai-0.19.6.dist-info → sonusai-0.19.9.dist-info}/RECORD +60 -58
  59. sonusai/mixture/truth_functions/datatypes.py +0 -37
  60. {sonusai-0.19.6.dist-info → sonusai-0.19.9.dist-info}/WHEEL +0 -0
  61. {sonusai-0.19.6.dist-info → sonusai-0.19.9.dist-info}/entry_points.txt +0 -0
sonusai/mixture/soundfile_audio.py CHANGED
@@ -32,7 +32,7 @@ def _raw_read(name: str | Path) -> tuple[AudioT, int]:
         else:
             raise OSError(f"Error reading {name}: {e}") from e
 
-    return np.squeeze(raw[:, 0]), sample_rate
+    return np.squeeze(raw[:, 0].astype(np.float32)), sample_rate
 
 
 def get_sample_rate(name: str | Path) -> int:
sonusai/mixture/sox_audio.py CHANGED
@@ -207,7 +207,7 @@ class Transformer(SoxTransformer):
 
         return self
 
-    def build(
+    def build(  # pyright: ignore [reportIncompatibleMethodOverride]
         self,
         input_filepath: str | Path | None = None,
         output_filepath: str | Path | None = None,
@@ -320,11 +320,11 @@ class Transformer(SoxTransformer):
         logger.info("Created %s with effects: %s", output_filepath, " ".join(self.effects_log))
 
         if return_output:
-            return status, out, err
+            return status, out, err  # pyright: ignore [reportReturnType]
 
         return True, None, None
 
-    def build_array(
+    def build_array(  # pyright: ignore [reportIncompatibleMethodOverride]
         self,
         input_filepath: str | Path | None = None,
         input_array: np.ndarray | None = None,
@@ -465,7 +465,7 @@ class Transformer(SoxTransformer):
         if status != 0:
             raise SoxError(f"Stdout: {out}\nStderr: {err}")
 
-        out = np.frombuffer(out, dtype=encoding_out)
+        out = np.frombuffer(out, dtype=encoding_out)  # pyright: ignore [reportArgumentType, reportCallIssue]
         if output_format["channels"] > 1:
             out = out.reshape(
                 (output_format["channels"], int(len(out) / output_format["channels"])),
sonusai/mixture/sox_augmentation.py CHANGED
@@ -118,7 +118,7 @@ def apply_impulse_response(audio: AudioT, ir: ImpulseResponseData) -> AudioT:
     # Apply IR and convert back to global sample rate
     tfm = Transformer()
     tfm.set_output_format(rate=SAMPLE_RATE)
-    tfm.fir(coefficients=temp.name)
+    tfm.fir(coefficients=temp.name)  # pyright: ignore [reportArgumentType]
     try:
         audio_out = tfm.build_array(input_array=audio_out, sample_rate_in=ir.sample_rate)
     except Exception as e:
sonusai/mixture/target_class_balancing.py CHANGED
@@ -1,19 +1,17 @@
 from sonusai.mixture.datatypes import AugmentationRule
-from sonusai.mixture.datatypes import AugmentationRules
-from sonusai.mixture.datatypes import AugmentedTargets
+from sonusai.mixture.datatypes import AugmentedTarget
 from sonusai.mixture.datatypes import TargetFile
-from sonusai.mixture.datatypes import TargetFiles
 
 
 def balance_targets(
-    augmented_targets: AugmentedTargets,
-    targets: TargetFiles,
-    target_augmentations: AugmentationRules,
+    augmented_targets: list[AugmentedTarget],
+    targets: list[TargetFile],
+    target_augmentations: list[AugmentationRule],
     class_balancing_augmentation: AugmentationRule,
     num_classes: int,
     num_ir: int,
     mixups: list[int] | None = None,
-) -> tuple[AugmentedTargets, AugmentationRules]:
+) -> tuple[list[AugmentedTarget], list[AugmentationRule]]:
     import math
 
     from .augmentation import get_mixups
@@ -64,15 +62,15 @@ def balance_targets(
 
 
 def _get_unused_balancing_augmentation(
-    augmented_targets: AugmentedTargets,
-    targets: TargetFiles,
-    target_augmentations: AugmentationRules,
+    augmented_targets: list[AugmentedTarget],
+    targets: list[TargetFile],
+    target_augmentations: list[AugmentationRule],
     class_balancing_augmentation: AugmentationRule,
     target_id: int,
     mixup: int,
     num_ir: int,
     first_cba_id: int,
-) -> tuple[int, AugmentationRules]:
+) -> tuple[int, list[AugmentationRule]]:
     """Get an unused balancing augmentation for a given target file index"""
     from dataclasses import asdict
 
sonusai/mixture/targets.py CHANGED
@@ -1,21 +1,20 @@
-from sonusai.mixture.datatypes import AugmentationRules
+from sonusai.mixture.datatypes import AugmentationRule
 from sonusai.mixture.datatypes import AugmentedTarget
-from sonusai.mixture.datatypes import AugmentedTargets
-from sonusai.mixture.datatypes import TargetFiles
+from sonusai.mixture.datatypes import TargetFile
 
 
 def get_augmented_targets(
-    target_files: TargetFiles,
-    target_augmentations: AugmentationRules,
+    target_files: list[TargetFile],
+    target_augmentations: list[AugmentationRule],
     mixups: list[int] | None = None,
-) -> AugmentedTargets:
+) -> list[AugmentedTarget]:
     from .augmentation import get_augmentation_indices_for_mixup
     from .augmentation import get_mixups
 
     if mixups is None:
         mixups = get_mixups(target_augmentations)
 
-    augmented_targets: AugmentedTargets = []
+    augmented_targets: list[AugmentedTarget] = []
     for mixup in mixups:
         augmentation_indices = get_augmentation_indices_for_mixup(target_augmentations, mixup)
         for target_index in range(len(target_files)):
@@ -30,15 +29,17 @@ def get_augmented_targets(
     return augmented_targets
 
 
-def get_class_index_for_augmented_target(augmented_target: AugmentedTarget, targets: TargetFiles) -> list[int]:
+def get_class_index_for_augmented_target(augmented_target: AugmentedTarget, targets: list[TargetFile]) -> list[int]:
     return targets[augmented_target.target_id].class_indices
 
 
-def get_mixup_for_augmented_target(augmented_target: AugmentedTarget, augmentations: AugmentationRules) -> int:
+def get_mixup_for_augmented_target(augmented_target: AugmentedTarget, augmentations: list[AugmentationRule]) -> int:
     return augmentations[augmented_target.target_augmentation_id].mixup
 
 
-def get_target_ids_for_class_index(targets: TargetFiles, class_index: int, allow_multiple: bool = False) -> list[int]:
+def get_target_ids_for_class_index(
+    targets: list[TargetFile], class_index: int, allow_multiple: bool = False
+) -> list[int]:
     """Get a list of target indices containing the given class index.
 
     If allow_multiple is True, then include targets that contain multiple class indices.
@@ -55,9 +56,9 @@ def get_target_ids_for_class_index(targets: TargetFiles, class_index: int, allow
 
 
 def get_augmented_target_ids_for_class_index(
-    augmented_targets: AugmentedTargets,
-    targets: TargetFiles,
-    augmentations: AugmentationRules,
+    augmented_targets: list[AugmentedTarget],
+    targets: list[TargetFile],
+    augmentations: list[AugmentationRule],
     class_index: int,
     mixup: int,
     allow_multiple: bool = False,
@@ -79,9 +80,9 @@ def get_augmented_target_ids_for_class_index(
 
 
 def get_augmented_target_ids_by_class(
-    augmented_targets: AugmentedTargets,
-    targets: TargetFiles,
-    target_augmentations: AugmentationRules,
+    augmented_targets: list[AugmentedTarget],
+    targets: list[TargetFile],
+    target_augmentations: list[AugmentationRule],
     mixup: int,
     num_classes: int,
 ) -> list[list[int]]:
@@ -99,7 +100,9 @@ def get_augmented_target_ids_by_class(
     return indices
 
 
-def get_target_augmentations_for_mixup(target_augmentations: AugmentationRules, mixup: int) -> AugmentationRules:
+def get_target_augmentations_for_mixup(
+    target_augmentations: list[AugmentationRule], mixup: int
+) -> list[AugmentationRule]:
     """Get target augmentations for a given mixup value
 
     :param target_augmentations: List of target augmentation rules
@@ -110,9 +113,9 @@ def get_target_augmentations_for_mixup(target_augmentations: AugmentationRules,
 
 
 def get_augmented_target_ids_for_mixup(
-    augmented_targets: AugmentedTargets,
-    targets: TargetFiles,
-    target_augmentations: AugmentationRules,
+    augmented_targets: list[AugmentedTarget],
+    targets: list[TargetFile],
+    target_augmentations: list[AugmentationRule],
     mixup: int,
     num_classes: int,
 ) -> list[list[int]]:
sonusai/mixture/torchaudio_audio.py CHANGED
@@ -4,10 +4,16 @@ from sonusai.mixture.datatypes import AudioT
 from sonusai.mixture.datatypes import ImpulseResponseData
 
 
-def read_impulse_response(name: str | Path) -> ImpulseResponseData:
+def read_impulse_response(
+    name: str | Path,
+    delay_compensation: bool = True,
+    normalize: bool = True,
+) -> ImpulseResponseData:
     """Read impulse response data using torchaudio
 
     :param name: File name
+    :param delay_compensation: Apply delay compensation
+    :param normalize: Apply normalization
     :return: ImpulseResponseData object
     """
     import numpy as np
@@ -28,14 +34,19 @@ def read_impulse_response(name: str | Path) -> ImpulseResponseData:
         raise OSError(f"Error reading {name}: {e}") from e
 
     raw = torch.squeeze(raw[0, :])
-    offset = torch.argmax(raw)
-    raw = raw[offset:]
-    # Inexplicably, torch.linalg.vector_norm() causes multiprocessing contexts to hang.
-    # Use np.linalg.norm() instead.
-    # raw = raw / torch.linalg.vector_norm(raw)
+
+    if delay_compensation:
+        offset = torch.argmax(raw)
+        raw = raw[offset:]
 
     data = np.array(raw).astype(np.float32)
-    data = data / np.linalg.norm(data)
+
+    if normalize:
+        # Inexplicably,
+        #     data = data / torch.linalg.vector_norm(data)
+        # causes multiprocessing contexts to hang.
+        # Use np.linalg.norm() instead.
+        data = data / np.linalg.norm(data)
 
     return ImpulseResponseData(name=str(name), sample_rate=sample_rate, data=data)
 
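The new keyword arguments both default to True, so the previous behavior (trim to the argmax offset, then normalize by the vector norm) is preserved. A minimal usage sketch of the 0.19.9 signature; the file path is a placeholder and the import path assumes the module location listed above:

    from sonusai.mixture.torchaudio_audio import read_impulse_response

    # Default behavior: delay compensation and normalization applied.
    ir = read_impulse_response("impulse.wav")

    # Keep the raw response: no trimming, no normalization.
    ir_raw = read_impulse_response("impulse.wav", delay_compensation=False, normalize=False)

    print(ir.sample_rate, ir.data.dtype)  # data is float32 in either case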
sonusai/mixture/torchaudio_augmentation.py CHANGED
@@ -20,10 +20,9 @@ def apply_augmentation(audio: AudioT, augmentation: Augmentation, frame_length:
 
     effects: list[list[str]] = []
 
-    # TODO
-    # Always normalize and remove normalize from list of available augmentations
-    # Normalize to globally set level (should this be a global config parameter,
-    # or hard-coded into the script?)
+    # TODO: Always normalize and remove normalize from list of available augmentations
+    # Normalize to globally set level (should this be a global config parameter, or hard-coded into the script?)
+    # TODO: Support all sox effects supported by torchaudio (torchaudio.sox_effects.effect_names())
     if augmentation.normalize is not None:
         effects.append(["norm", str(augmentation.normalize)])
 
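The second TODO refers to the effect list torchaudio exposes; a quick way to see what could be supported (assumes a torchaudio build with sox support):

    import torchaudio

    # Names of all sox effects available through torchaudio's sox bindings.
    print(sorted(torchaudio.sox_effects.effect_names()))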
sonusai/mixture/truth.py CHANGED
@@ -1,39 +1,26 @@
-from sonusai.mixture.datatypes import AudioT
-from sonusai.mixture.datatypes import Truth
-from sonusai.mixture.datatypes import TruthConfig
-from sonusai.mixture.mixdb import MixtureDatabase
-
-
-def truth_function(
-    target_audio: AudioT,
-    noise_audio: AudioT,
-    mixture_audio: AudioT,
-    config: TruthConfig,
-    feature: str,
-    num_classes: int,
-    class_indices: list[int],
-    target_gain: float,
-) -> Truth:
+from sonusai.mixture import MixtureDatabase
+from sonusai.mixture import Truth
+
+
+def truth_function(mixdb: MixtureDatabase, m_id: int) -> list[Truth]:
+    from sonusai.mixture import TruthDict
     from sonusai.mixture import truth_functions
 
-    from .truth_functions.datatypes import TruthFunctionConfig
-    from .truth_functions.datatypes import TruthFunctionData
-
-    t_config = TruthFunctionConfig(
-        feature=feature,
-        num_classes=num_classes,
-        class_indices=class_indices,
-        target_gain=target_gain,
-        config=config.config,
-    )
-    t_data = TruthFunctionData(target_audio, noise_audio, mixture_audio)
-
-    try:
-        return getattr(truth_functions, config.function)(t_data, t_config)
-    except AttributeError as e:
-        raise AttributeError(f"Unsupported truth function: {config.function}") from e
-    except Exception as e:
-        raise RuntimeError(f"Error in truth function '{config.function}': {e}") from e
+    result: list[Truth] = []
+    for target_index in range(len(mixdb.mixture(m_id).targets)):
+        truth: TruthDict = {}
+        target_file = mixdb.target_file(mixdb.mixture(m_id).targets[target_index].file_id)
+        for name, config in target_file.truth_configs.items():
+            try:
+                truth[name] = getattr(truth_functions, config.function)(mixdb, m_id, target_index, config.config)
+            except AttributeError as e:
+                raise AttributeError(f"Unsupported truth function: {config.function}") from e
+            except Exception as e:
+                raise RuntimeError(f"Error in truth function '{config.function}': {e}") from e
+
+        result.append(truth)
+
+    return result
 
 
 def get_truth_indices_for_mixid(mixdb: MixtureDatabase, mixid: int) -> list[int]:
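Truth generation is now driven by the mixture database: one call per mixture id returns a truth dictionary per target, and per-function dispatch is still resolved by name via getattr. A hedged usage sketch, assuming a mixture database already generated on disk (the location string is a placeholder):

    from sonusai.mixture.mixdb import MixtureDatabase
    from sonusai.mixture.truth import truth_function

    mixdb = MixtureDatabase("path/to/mixdb")  # placeholder location
    truth_per_target = truth_function(mixdb, m_id=0)
    for target_index, truth in enumerate(truth_per_target):
        # Each entry is keyed by the truth config names defined for that target file.
        print(target_index, list(truth.keys()))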
sonusai/mixture/truth_functions/__init__.py CHANGED
@@ -22,6 +22,12 @@ from .energy import snr_f_validate
 from .file import file
 from .file import file_parameters
 from .file import file_validate
+from .metadata import metadata
+from .metadata import metadata_parameters
+from .metadata import metadata_validate
+from .metrics import metrics
+from .metrics import metrics_parameters
+from .metrics import metrics_validate
 from .phoneme import phoneme
 from .phoneme import phoneme_parameters
 from .phoneme import phoneme_validate
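The new entries follow the package's existing convention of three exports per truth function (<name>, <name>_parameters, <name>_validate), which truth.py resolves by name via getattr. A sketch of the shape such a module takes under the 0.19.9 signatures, using a hypothetical function named example (illustrative only, not the actual metadata or metrics implementations):

    import numpy as np

    from sonusai.mixture import MixtureDatabase
    from sonusai.mixture import Truth


    def example_validate(_config: dict) -> None:
        pass


    def example_parameters(_feature: str, _num_classes: int, _config: dict) -> int:
        return 1


    def example(mixdb: MixtureDatabase, m_id: int, target_index: int, _config: dict) -> Truth:
        # Illustrative body: one parameter per sample of the selected target.
        samples = len(mixdb.mixture_targets(m_id)[target_index])
        return np.zeros((samples, 1), dtype=np.float32)

For the dispatch to find it, the three callables would also be imported here in truth_functions/__init__.py, mirroring the metadata and metrics lines above.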
sonusai/mixture/truth_functions/crm.py CHANGED
@@ -1,22 +1,32 @@
-from sonusai.mixture.datatypes import Truth
-from sonusai.mixture.truth_functions.datatypes import TruthFunctionConfig
-from sonusai.mixture.truth_functions.datatypes import TruthFunctionData
+from sonusai.mixture import MixtureDatabase
+from sonusai.mixture import Truth
 
 
-def _core(data: TruthFunctionData, config: TruthFunctionConfig, polar: bool) -> Truth:
+def _core(mixdb: MixtureDatabase, m_id: int, target_index: int, parameters: int, polar: bool) -> Truth:
     import numpy as np
+    import torch
+    from pyaaware import ForwardTransform
+    from pyaaware import feature_forward_transform_config
+    from pyaaware import feature_inverse_transform_config
 
-    if config.target_fft.bins != config.noise_fft.bins:
-        raise ValueError("Transform size mismatch for crm truth")
+    target_audio = torch.from_numpy(mixdb.mixture_targets(m_id)[target_index])
+    t_ft = ForwardTransform(**feature_forward_transform_config(mixdb.feature))
+    n_ft = ForwardTransform(**feature_forward_transform_config(mixdb.feature))
 
-    frames = len(data.target_audio) // config.frame_size
-    truth = np.empty((frames, config.target_fft.bins * 2), dtype=np.float32)
+    frames = t_ft.frames(target_audio)
+    if mixdb.mixture(m_id).target_gain(target_index) == 0:
+        return np.zeros((frames, parameters), dtype=np.float32)
+
+    noise_audio = torch.from_numpy(mixdb.mixture_noise(m_id))
+
+    frame_size = feature_inverse_transform_config(mixdb.feature)["overlap"]
+
+    frames = len(target_audio) // frame_size
+    truth = np.empty((frames, t_ft.bins * 2), dtype=np.float32)
     for frame in range(frames):
-        offset = frame * config.frame_size
-        target_f = config.target_fft.execute(data.target_audio[offset : offset + config.frame_size]).astype(
-            np.complex64
-        )
-        noise_f = config.noise_fft.execute(data.noise_audio[offset : offset + config.frame_size]).astype(np.complex64)
+        offset = frame * frame_size
+        target_f = t_ft.execute(target_audio[offset : offset + frame_size])[0].numpy().astype(np.complex64)
+        noise_f = n_ft.execute(noise_audio[offset : offset + frame_size])[0].numpy().astype(np.complex64)
         mixture_f = target_f + noise_f
 
         crm_data = np.empty(target_f.shape, dtype=np.complex64)
@@ -31,8 +41,8 @@ def _core(data: TruthFunctionData, config: TruthFunctionConfig, polar: bool) ->
             else:
                 crm_data[it.multi_index] = num / den
 
-        truth[frame, : config.target_fft.bins] = np.absolute(crm_data) if polar else np.real(crm_data)
-        truth[frame, config.target_fft.bins :] = np.angle(crm_data) if polar else np.imag(crm_data)
+        truth[frame, : t_ft.bins] = np.absolute(crm_data) if polar else np.real(crm_data)
+        truth[frame, t_ft.bins :] = np.angle(crm_data) if polar else np.imag(crm_data)
 
     return truth
 
@@ -41,11 +51,14 @@ def crm_validate(_config: dict) -> None:
     pass
 
 
-def crm_parameters(config: TruthFunctionConfig) -> int:
-    return config.target_fft.bins * 2
+def crm_parameters(feature: str, _num_classes: int, _config: dict) -> int:
+    from pyaaware import ForwardTransform
+    from pyaaware import feature_forward_transform_config
 
+    return ForwardTransform(**feature_forward_transform_config(feature)).bins * 2
 
-def crm(data: TruthFunctionData, config: TruthFunctionConfig) -> Truth:
+
+def crm(mixdb: MixtureDatabase, m_id: int, target_index: int, _config: dict) -> Truth:
     """Complex ratio mask truth generation function
 
     Calculates the true complex ratio mask (CRM) truth which is a complex number
@@ -55,25 +68,27 @@ def crm(data: TruthFunctionData, config: TruthFunctionConfig) -> Truth:
 
     Output shape: [:, 2 * bins]
     """
-    import numpy as np
-
-    frames = config.target_fft.frames(data.target_audio)
-    parameters = crm_parameters(config)
-    if config.target_gain == 0:
-        return np.zeros((frames, parameters), dtype=np.float32)
-
-    return _core(data=data, config=config, polar=False)
+    return _core(
+        mixdb=mixdb,
+        m_id=m_id,
+        target_index=target_index,
+        parameters=crm_parameters(mixdb.feature, mixdb.num_classes, _config),
+        polar=False,
+    )
 
 
 def crmp_validate(_config: dict) -> None:
     pass
 
 
-def crmp_parameters(config: TruthFunctionConfig) -> int:
-    return config.target_fft.bins * 2
+def crmp_parameters(feature: str, _num_classes: int, _config: dict) -> int:
+    from pyaaware import ForwardTransform
+    from pyaaware import feature_forward_transform_config
 
+    return ForwardTransform(**feature_forward_transform_config(feature)).bins * 2
 
-def crmp(data: TruthFunctionData, config: TruthFunctionConfig) -> Truth:
+
+def crmp(mixdb: MixtureDatabase, m_id: int, target_index: int, _config: dict) -> Truth:
     """Complex ratio mask polar truth generation function
 
     Same as the crm function except the results are magnitude and phase
@@ -81,11 +96,10 @@ def crmp(data: TruthFunctionData, config: TruthFunctionConfig) -> Truth:
 
     Output shape: [:, bins]
     """
-    import numpy as np
-
-    frames = config.target_fft.frames(data.target_audio)
-    parameters = crmp_parameters(config)
-    if config.target_gain == 0:
-        return np.zeros((frames, parameters), dtype=np.float32)
-
-    return _core(data=data, config=config, polar=True)
+    return _core(
+        mixdb=mixdb,
+        m_id=m_id,
+        target_index=target_index,
+        parameters=crmp_parameters(mixdb.feature, mixdb.num_classes, _config),
+        polar=True,
+    )
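Behaviorally, _core still fills each frame with a complex ratio mask over the transform bins, stored as real/imag for crm or magnitude/phase for crmp. A standalone numpy sketch of that per-frame layout, assuming the conventional definition crm = target_f / mixture_f (the guarded num / den in the loop) and synthetic spectra rather than sonusai transforms:

    import numpy as np

    rng = np.random.default_rng(0)
    bins = 8
    target_f = (rng.normal(size=bins) + 1j * rng.normal(size=bins)).astype(np.complex64)
    noise_f = (rng.normal(size=bins) + 1j * rng.normal(size=bins)).astype(np.complex64)
    mixture_f = target_f + noise_f

    # Guard zero-denominator bins, as the nditer loop in _core does.
    crm_data = np.zeros(bins, dtype=np.complex64)
    nonzero = mixture_f != 0
    crm_data[nonzero] = target_f[nonzero] / mixture_f[nonzero]

    crm_frame = np.concatenate([np.real(crm_data), np.imag(crm_data)])    # crm layout: [2 * bins]
    crmp_frame = np.concatenate([np.abs(crm_data), np.angle(crm_data)])   # crmp layout: magnitude then phase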