sonusai 0.19.6__py3-none-any.whl → 0.19.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61):
  1. sonusai/__init__.py +1 -1
  2. sonusai/aawscd_probwrite.py +1 -1
  3. sonusai/calc_metric_spenh.py +1 -1
  4. sonusai/genft.py +29 -14
  5. sonusai/genmetrics.py +60 -42
  6. sonusai/genmix.py +41 -29
  7. sonusai/genmixdb.py +56 -64
  8. sonusai/metrics/calc_class_weights.py +1 -3
  9. sonusai/metrics/calc_optimal_thresholds.py +2 -2
  10. sonusai/metrics/calc_phase_distance.py +1 -1
  11. sonusai/metrics/calc_speech.py +6 -6
  12. sonusai/metrics/class_summary.py +6 -15
  13. sonusai/metrics/confusion_matrix_summary.py +11 -27
  14. sonusai/metrics/one_hot.py +3 -3
  15. sonusai/metrics/snr_summary.py +7 -7
  16. sonusai/mixture/__init__.py +2 -17
  17. sonusai/mixture/augmentation.py +5 -6
  18. sonusai/mixture/class_count.py +1 -1
  19. sonusai/mixture/config.py +36 -46
  20. sonusai/mixture/data_io.py +30 -1
  21. sonusai/mixture/datatypes.py +29 -40
  22. sonusai/mixture/db_datatypes.py +1 -1
  23. sonusai/mixture/feature.py +3 -23
  24. sonusai/mixture/generation.py +161 -204
  25. sonusai/mixture/helpers.py +29 -187
  26. sonusai/mixture/mixdb.py +386 -159
  27. sonusai/mixture/soundfile_audio.py +1 -1
  28. sonusai/mixture/sox_audio.py +4 -4
  29. sonusai/mixture/sox_augmentation.py +1 -1
  30. sonusai/mixture/target_class_balancing.py +9 -11
  31. sonusai/mixture/targets.py +23 -20
  32. sonusai/mixture/torchaudio_audio.py +18 -7
  33. sonusai/mixture/torchaudio_augmentation.py +3 -4
  34. sonusai/mixture/truth.py +21 -34
  35. sonusai/mixture/truth_functions/__init__.py +6 -0
  36. sonusai/mixture/truth_functions/crm.py +51 -37
  37. sonusai/mixture/truth_functions/energy.py +95 -50
  38. sonusai/mixture/truth_functions/file.py +12 -8
  39. sonusai/mixture/truth_functions/metadata.py +24 -0
  40. sonusai/mixture/truth_functions/metrics.py +28 -0
  41. sonusai/mixture/truth_functions/phoneme.py +4 -5
  42. sonusai/mixture/truth_functions/sed.py +32 -23
  43. sonusai/mixture/truth_functions/target.py +62 -29
  44. sonusai/mkwav.py +20 -19
  45. sonusai/queries/queries.py +9 -15
  46. sonusai/speech/l2arctic.py +6 -2
  47. sonusai/summarize_metric_spenh.py +1 -1
  48. sonusai/utils/__init__.py +1 -0
  49. sonusai/utils/asr_functions/aaware_whisper.py +1 -1
  50. sonusai/utils/audio_devices.py +27 -18
  51. sonusai/utils/docstring.py +6 -3
  52. sonusai/utils/energy_f.py +5 -3
  53. sonusai/utils/human_readable_size.py +6 -6
  54. sonusai/utils/load_object.py +15 -0
  55. sonusai/utils/onnx_utils.py +2 -2
  56. sonusai/utils/print_mixture_details.py +3 -3
  57. {sonusai-0.19.6.dist-info → sonusai-0.19.9.dist-info}/METADATA +2 -2
  58. {sonusai-0.19.6.dist-info → sonusai-0.19.9.dist-info}/RECORD +60 -58
  59. sonusai/mixture/truth_functions/datatypes.py +0 -37
  60. {sonusai-0.19.6.dist-info → sonusai-0.19.9.dist-info}/WHEEL +0 -0
  61. {sonusai-0.19.6.dist-info → sonusai-0.19.9.dist-info}/entry_points.txt +0 -0
@@ -1,20 +1,44 @@
1
1
  import numpy as np
2
2
 
3
- from sonusai.mixture.datatypes import Truth
4
- from sonusai.mixture.truth_functions.datatypes import TruthFunctionConfig
5
- from sonusai.mixture.truth_functions.datatypes import TruthFunctionData
3
+ from sonusai.mixture import MixtureDatabase
4
+ from sonusai.mixture import Truth
5
+ from sonusai.utils import load_object
6
+
7
+
8
+ def _core(
9
+ mixdb: MixtureDatabase,
10
+ m_id: int,
11
+ target_index: int,
12
+ config: dict,
13
+ parameters: int,
14
+ mapped: bool,
15
+ snr: bool,
16
+ ) -> Truth:
17
+ from os.path import join
6
18
 
19
+ import torch
20
+ from pyaaware import ForwardTransform
21
+ from pyaaware import feature_forward_transform_config
7
22
 
8
- def _core(data: TruthFunctionData, config: TruthFunctionConfig, mapped: bool, snr: bool) -> Truth:
9
23
  from sonusai.utils import compute_energy_f
10
24
 
11
- target_energy = compute_energy_f(time_domain=data.target_audio, transform=config.target_fft)
25
+ target_audio = mixdb.mixture_targets(m_id)[target_index]
26
+ ft = ForwardTransform(**feature_forward_transform_config(mixdb.feature))
27
+
28
+ frames = ft.frames(torch.from_numpy(target_audio))
29
+
30
+ if mixdb.mixture(m_id).target_gain(target_index) == 0:
31
+ return np.zeros((frames, parameters), dtype=np.float32)
32
+
33
+ noise_audio = mixdb.mixture_noise(m_id)
34
+
35
+ target_energy = compute_energy_f(time_domain=target_audio, transform=ft)
12
36
  noise_energy = None
13
37
  if snr:
14
- noise_energy = compute_energy_f(time_domain=data.noise_audio, transform=config.noise_fft)
38
+ noise_energy = compute_energy_f(time_domain=noise_audio, transform=ft)
15
39
 
16
40
  frames = len(target_energy)
17
- truth = np.empty((frames, config.target_fft.bins), dtype=np.float32)
41
+ truth = np.empty((frames, ft.bins), dtype=np.float32)
18
42
  for frame in range(frames):
19
43
  tmp = target_energy[frame]
20
44
 
@@ -26,7 +50,9 @@ def _core(data: TruthFunctionData, config: TruthFunctionConfig, mapped: bool, sn
26
50
  tmp = np.nan_to_num(tmp, nan=-np.inf, posinf=np.inf, neginf=-np.inf)
27
51
 
28
52
  if mapped:
29
- tmp = _calculate_mapped_snr_f(tmp, config.config["snr_db_mean"], config.config["snr_db_std"])
53
+ snr_db_mean = load_object(join(mixdb.location, config["snr_db_mean"]))
54
+ snr_db_std = load_object(join(mixdb.location, config["snr_db_std"]))
55
+ tmp = _calculate_mapped_snr_f(tmp, snr_db_mean, snr_db_std)
30
56
 
31
57
  truth[frame] = tmp
32
58
 
@@ -52,11 +78,14 @@ def energy_f_validate(_config: dict) -> None:
52
78
  pass
53
79
 
54
80
 
55
- def energy_f_parameters(config: TruthFunctionConfig) -> int:
56
- return config.target_fft.bins
81
+ def energy_f_parameters(feature: str, _num_classes: int, _config: dict) -> int:
82
+ from pyaaware import ForwardTransform
83
+ from pyaaware import feature_forward_transform_config
84
+
85
+ return ForwardTransform(**feature_forward_transform_config(feature)).bins
57
86
 
58
87
 
59
- def energy_f(data: TruthFunctionData, config: TruthFunctionConfig) -> Truth:
88
+ def energy_f(mixdb: MixtureDatabase, m_id: int, target_index: int, config: dict) -> Truth:
60
89
  """Frequency domain energy truth generation function
61
90
 
62
91
  Calculates the true energy per bin:
@@ -67,23 +96,29 @@ def energy_f(data: TruthFunctionData, config: TruthFunctionConfig) -> Truth:
67
96
 
68
97
  Output shape: [:, bins]
69
98
  """
70
- frames = config.target_fft.frames(data.target_audio)
71
- parameters = energy_f_parameters(config)
72
- if config.target_gain == 0:
73
- return np.zeros((frames, parameters), dtype=np.float32)
74
-
75
- return _core(data=data, config=config, mapped=False, snr=False)
99
+ return _core(
100
+ mixdb=mixdb,
101
+ m_id=m_id,
102
+ target_index=target_index,
103
+ config=config,
104
+ parameters=energy_f_parameters(mixdb.feature, mixdb.num_classes, config),
105
+ mapped=False,
106
+ snr=False,
107
+ )
76
108
 
77
109
 
78
110
  def snr_f_validate(_config: dict) -> None:
79
111
  pass
80
112
 
81
113
 
82
- def snr_f_parameters(config: TruthFunctionConfig) -> int:
83
- return config.target_fft.bins
114
+ def snr_f_parameters(feature: str, _num_classes: int, _config: dict) -> int:
115
+ from pyaaware import ForwardTransform
116
+ from pyaaware import feature_forward_transform_config
84
117
 
118
+ return ForwardTransform(**feature_forward_transform_config(feature)).bins
85
119
 
86
- def snr_f(data: TruthFunctionData, config: TruthFunctionConfig) -> Truth:
120
+
121
+ def snr_f(mixdb: MixtureDatabase, m_id: int, target_index: int, config: dict) -> Truth:
87
122
  """Frequency domain SNR truth function documentation
88
123
 
89
124
  Calculates the true SNR per bin:
@@ -94,54 +129,58 @@ def snr_f(data: TruthFunctionData, config: TruthFunctionConfig) -> Truth:
94
129
 
95
130
  Output shape: [:, bins]
96
131
  """
97
- frames = config.target_fft.frames(data.target_audio)
98
- parameters = snr_f_parameters(config)
99
- if config.target_gain == 0:
100
- return np.zeros((frames, parameters), dtype=np.float32)
101
-
102
- return _core(data=data, config=config, mapped=False, snr=True)
103
-
104
-
105
- def mapped_snr_f_validate(config: TruthFunctionConfig) -> None:
106
- if len(config.config) == 0:
132
+ return _core(
133
+ mixdb=mixdb,
134
+ m_id=m_id,
135
+ target_index=target_index,
136
+ config=config,
137
+ parameters=snr_f_parameters(mixdb.feature, mixdb.num_classes, config),
138
+ mapped=False,
139
+ snr=True,
140
+ )
141
+
142
+
143
+ def mapped_snr_f_validate(config: dict) -> None:
144
+ if len(config) == 0:
107
145
  raise AttributeError("mapped_snr_f truth function is missing config")
108
146
 
109
147
  for parameter in ("snr_db_mean", "snr_db_std"):
110
- if parameter not in config.config:
148
+ if parameter not in config:
111
149
  raise AttributeError(f"mapped_snr_f truth function is missing required '{parameter}'")
112
150
 
113
- if len(config.config[parameter]) != config.target_fft.bins:
114
- raise ValueError(
115
- f"mapped_snr_f truth function '{parameter}' does not have {config.target_fft.bins} elements"
116
- )
117
151
 
152
+ def mapped_snr_f_parameters(feature: str, _num_classes: int, _config: dict) -> int:
153
+ from pyaaware import ForwardTransform
154
+ from pyaaware import feature_forward_transform_config
118
155
 
119
- def mapped_snr_f_parameters(config: TruthFunctionConfig) -> int:
120
- return config.target_fft.bins
156
+ return ForwardTransform(**feature_forward_transform_config(feature)).bins
121
157
 
122
158
 
123
- def mapped_snr_f(data: TruthFunctionData, config: TruthFunctionConfig) -> Truth:
159
+ def mapped_snr_f(mixdb: MixtureDatabase, m_id: int, target_index: int, config: dict) -> Truth:
124
160
  """Frequency domain mapped SNR truth function documentation
125
161
 
126
162
  Output shape: [:, bins]
127
163
  """
128
- frames = config.target_fft.frames(data.target_audio)
129
- parameters = mapped_snr_f_parameters(config)
130
- if config.target_gain == 0:
131
- return np.zeros((frames, parameters), dtype=np.float32)
132
-
133
- return _core(data=data, config=config, mapped=True, snr=True)
164
+ return _core(
165
+ mixdb=mixdb,
166
+ m_id=m_id,
167
+ target_index=target_index,
168
+ config=config,
169
+ parameters=mapped_snr_f_parameters(mixdb.feature, mixdb.num_classes, config),
170
+ mapped=True,
171
+ snr=True,
172
+ )
134
173
 
135
174
 
136
175
  def energy_t_validate(_config: dict) -> None:
137
176
  pass
138
177
 
139
178
 
140
- def energy_t_parameters(_config: TruthFunctionConfig) -> int:
179
+ def energy_t_parameters(_feature: str, _num_classes: int, _config: dict) -> int:
141
180
  return 1
142
181
 
143
182
 
144
- def energy_t(data: TruthFunctionData, config: TruthFunctionConfig) -> Truth:
183
+ def energy_t(mixdb: MixtureDatabase, m_id: int, target_index: int, _config: dict) -> Truth:
145
184
  """Time domain energy truth function documentation
146
185
 
147
186
  Calculates the true time domain energy of each frame:
@@ -164,10 +203,16 @@ def energy_t(data: TruthFunctionData, config: TruthFunctionConfig) -> Truth:
164
203
  transform config.
165
204
  """
166
205
  import torch
206
+ from pyaaware import ForwardTransform
207
+ from pyaaware import feature_forward_transform_config
208
+
209
+ target_audio = torch.from_numpy(mixdb.mixture_targets(m_id)[target_index])
210
+
211
+ ft = ForwardTransform(**feature_forward_transform_config(mixdb.feature))
167
212
 
168
- frames = config.target_fft.frames(data.target_audio)
169
- parameters = energy_t_parameters(config)
170
- if config.target_gain == 0:
213
+ frames = ft.frames(target_audio)
214
+ parameters = energy_f_parameters(mixdb.feature, mixdb.num_classes, _config)
215
+ if mixdb.mixture(m_id).target_gain(target_index) == 0:
171
216
  return np.zeros((frames, parameters), dtype=np.float32)
172
217
 
173
- return config.target_fft.execute_all(torch.from_numpy(data.target_audio))[1].numpy()
218
+ return ft.execute_all(target_audio)[1].numpy()
@@ -1,6 +1,5 @@
1
- from sonusai.mixture.datatypes import Truth
2
- from sonusai.mixture.truth_functions.datatypes import TruthFunctionConfig
3
- from sonusai.mixture.truth_functions.datatypes import TruthFunctionData
1
+ from sonusai.mixture import MixtureDatabase
2
+ from sonusai.mixture import Truth
4
3
 
5
4
 
6
5
  def file_validate(config: dict) -> None:
@@ -17,28 +16,33 @@ def file_validate(config: dict) -> None:
17
16
  raise ValueError("Truth file does not contain truth_f dataset")
18
17
 
19
18
 
20
- def file_parameters(config: TruthFunctionConfig) -> int:
19
+ def file_parameters(_feature: str, _num_classes: int, config: dict) -> int:
21
20
  import h5py
22
21
  import numpy as np
23
22
 
24
- with h5py.File(config.config["file"], "r") as f:
23
+ with h5py.File(config["file"], "r") as f:
25
24
  truth = np.array(f["truth_f"])
26
25
 
27
26
  return truth.shape[-1]
28
27
 
29
28
 
30
- def file(data: TruthFunctionData, config: TruthFunctionConfig) -> Truth:
29
+ def file(mixdb: MixtureDatabase, m_id: int, target_index: int, config: dict) -> Truth:
31
30
  """file truth function documentation"""
32
31
  import h5py
33
32
  import numpy as np
33
+ from pyaaware import feature_inverse_transform_config
34
+
35
+ target_audio = mixdb.mixture_targets(m_id)[target_index]
34
36
 
35
- with h5py.File(config.config["file"], "r") as f:
37
+ frame_size = feature_inverse_transform_config(mixdb.feature)["overlap"]
38
+
39
+ with h5py.File(config["file"], "r") as f:
36
40
  truth = np.array(f["truth_f"])
37
41
 
38
42
  if truth.ndim != 2:
39
43
  raise ValueError("Truth file data is not 2 dimensions")
40
44
 
41
- if truth.shape[0] != len(data.target_audio) // config.frame_size:
45
+ if truth.shape[0] != len(target_audio) // frame_size:
42
46
  raise ValueError("Truth file does not contain the right amount of frames")
43
47
 
44
48
  return truth
@@ -0,0 +1,24 @@
1
+ from sonusai.mixture import MixtureDatabase
2
+ from sonusai.mixture import Truth
3
+
4
+
5
+ def metadata_validate(config: dict) -> None:
6
+ if len(config) == 0:
7
+ raise AttributeError("metadata truth function is missing config")
8
+
9
+ parameters = ["tier"]
10
+ for parameter in parameters:
11
+ if parameter not in config:
12
+ raise AttributeError(f"metadata truth function is missing required '{parameter}'")
13
+
14
+
15
+ def metadata_parameters(_feature: str, _num_classes: int, _config: dict) -> int | None:
16
+ return None
17
+
18
+
19
+ def metadata(mixdb: MixtureDatabase, m_id: int, target_index: int, config: dict) -> Truth:
20
+ """Metadata truth generation function
21
+
22
+ Retrieves metadata from target.
23
+ """
24
+ return mixdb.mixture_speech_metadata(m_id, config["tier"])[target_index]
@@ -0,0 +1,28 @@
1
+ from sonusai.mixture import MixtureDatabase
2
+ from sonusai.mixture import Truth
3
+
4
+
5
+ def metrics_validate(config: dict) -> None:
6
+ if len(config) == 0:
7
+ raise AttributeError("metrics truth function is missing config")
8
+
9
+ parameters = ["metric"]
10
+ for parameter in parameters:
11
+ if parameter not in config:
12
+ raise AttributeError(f"metrics truth function is missing required '{parameter}'")
13
+
14
+
15
+ def metrics_parameters(_feature: str, _num_classes: int, _config: dict) -> int | None:
16
+ return None
17
+
18
+
19
+ def metrics(mixdb: MixtureDatabase, m_id: int, target_index: int, config: dict) -> Truth:
20
+ """Metadata truth generation function
21
+
22
+ Retrieves metrics from target.
23
+ """
24
+ if not isinstance(config["metric"], list):
25
+ m = [config["metric"]]
26
+ else:
27
+ m = config["metric"]
28
+ return mixdb.mixture_metrics(m_id, m)[0][target_index]
@@ -1,17 +1,16 @@
1
- from sonusai.mixture.datatypes import Truth
2
- from sonusai.mixture.truth_functions.datatypes import TruthFunctionConfig
3
- from sonusai.mixture.truth_functions.datatypes import TruthFunctionData
1
+ from sonusai.mixture import MixtureDatabase
2
+ from sonusai.mixture import Truth
4
3
 
5
4
 
6
5
  def phoneme_validate(_config: dict) -> None:
7
6
  raise NotImplementedError("Truth function phoneme is not supported yet")
8
7
 
9
8
 
10
- def phoneme_parameters(_config: TruthFunctionConfig) -> int:
9
+ def phoneme_parameters(_feature: str, _num_classes: int, _config: dict) -> int:
11
10
  raise NotImplementedError("Truth function phoneme is not supported yet")
12
11
 
13
12
 
14
- def phoneme(_data: TruthFunctionData, _config: TruthFunctionConfig) -> Truth:
13
+ def phoneme(_mixdb: MixtureDatabase, _m_id: int, _target_index: int, _config: dict) -> Truth:
15
14
  """Read in .txt transcript and run a Python function to generate text grid data
16
15
  (indicating which phonemes are active). Then generate truth based on this data and put
17
16
  in the correct classes based on the index in the config.
@@ -1,12 +1,5 @@
1
- from sonusai.mixture.datatypes import Truth
2
- from sonusai.mixture.truth_functions.datatypes import TruthFunctionConfig
3
- from sonusai.mixture.truth_functions.datatypes import TruthFunctionData
4
-
5
-
6
- def _strictly_decreasing(list_to_check: list) -> bool:
7
- from itertools import pairwise
8
-
9
- return all(x > y for x, y in pairwise(list_to_check))
1
+ from sonusai.mixture import MixtureDatabase
2
+ from sonusai.mixture import Truth
10
3
 
11
4
 
12
5
  def sed_validate(config: dict) -> None:
@@ -23,11 +16,11 @@ def sed_validate(config: dict) -> None:
23
16
  raise ValueError(f"sed truth function 'thresholds' are not strictly decreasing: {thresholds}")
24
17
 
25
18
 
26
- def sed_parameters(config: TruthFunctionConfig) -> int:
27
- return config.num_classes
19
+ def sed_parameters(_feature: str, num_classes: int, _config: dict) -> int:
20
+ return num_classes
28
21
 
29
22
 
30
- def sed(data: TruthFunctionData, config: TruthFunctionConfig) -> Truth:
23
+ def sed(mixdb: MixtureDatabase, m_id: int, target_index: int, config: dict) -> Truth:
31
24
  """Sound energy detection truth generation function
32
25
 
33
26
  Calculates sound energy detection truth using simple 3 threshold
@@ -62,30 +55,46 @@ def sed(data: TruthFunctionData, config: TruthFunctionConfig) -> Truth:
62
55
  import numpy as np
63
56
  import torch
64
57
  from pyaaware import SED
58
+ from pyaaware import ForwardTransform
59
+ from pyaaware import feature_forward_transform_config
60
+ from pyaaware import feature_inverse_transform_config
61
+
62
+ target_audio = torch.from_numpy(mixdb.mixture_targets(m_id)[target_index])
63
+
64
+ frame_size = feature_inverse_transform_config(mixdb.feature)["overlap"]
65
65
 
66
- if len(data.target_audio) % config.frame_size != 0:
67
- raise ValueError(f"Number of samples in audio is not a multiple of {config.frame_size}")
66
+ ft = ForwardTransform(**feature_forward_transform_config(mixdb.feature))
68
67
 
69
- frames = config.target_fft.frames(data.target_audio)
70
- parameters = sed_parameters(config)
71
- if config.target_gain == 0:
68
+ if len(target_audio) % frame_size != 0:
69
+ raise ValueError(f"Number of samples in audio is not a multiple of {frame_size}")
70
+
71
+ frames = ft.frames(target_audio)
72
+ parameters = sed_parameters(mixdb.feature, mixdb.num_classes, config)
73
+ target_gain = mixdb.mixture(m_id).target_gain(target_index)
74
+ if target_gain == 0:
72
75
  return np.zeros((frames, parameters), dtype=np.float32)
73
76
 
74
77
  # SED wants 1-based indices
75
78
  s = SED(
76
- thresholds=config.config["thresholds"],
77
- index=config.class_indices,
78
- frame_size=config.frame_size,
79
- num_classes=config.num_classes,
79
+ thresholds=config["thresholds"],
80
+ index=mixdb.target_file(mixdb.mixture(m_id).targets[target_index].file_id).class_indices,
81
+ frame_size=frame_size,
82
+ num_classes=mixdb.num_classes,
80
83
  )
81
84
 
82
85
  # Back out target gain
83
- target_audio = data.target_audio / config.target_gain
86
+ target_audio = target_audio / target_gain
84
87
 
85
88
  # Compute energy
86
- target_energy = config.target_fft.execute_all(torch.from_numpy(target_audio))[1].numpy()
89
+ target_energy = ft.execute_all(target_audio)[1].numpy()
87
90
 
88
91
  if frames != target_energy.shape[0]:
89
92
  raise ValueError("Incorrect frames calculation in sed truth function")
90
93
 
91
94
  return s.execute_all(target_energy)
95
+
96
+
97
+ def _strictly_decreasing(list_to_check: list) -> bool:
98
+ from itertools import pairwise
99
+
100
+ return all(x > y for x, y in pairwise(list_to_check))
@@ -1,21 +1,24 @@
1
- from sonusai.mixture.datatypes import AudioF
2
- from sonusai.mixture.datatypes import Truth
3
- from sonusai.mixture.truth_functions.datatypes import TruthFunctionConfig
4
- from sonusai.mixture.truth_functions.datatypes import TruthFunctionData
1
+ from sonusai.mixture import MixtureDatabase
2
+ from sonusai.mixture import Truth
5
3
 
6
4
 
7
5
  def target_f_validate(_config: dict) -> None:
8
6
  pass
9
7
 
10
8
 
11
- def target_f_parameters(config: TruthFunctionConfig) -> int:
12
- if config.ttype == "tdac-co":
13
- return config.target_fft.bins
9
+ def target_f_parameters(feature: str, _num_classes: int, _config: dict) -> int:
10
+ from pyaaware import ForwardTransform
11
+ from pyaaware import feature_forward_transform_config
14
12
 
15
- return config.target_fft.bins * 2
13
+ ft = ForwardTransform(**feature_forward_transform_config(feature))
16
14
 
15
+ if ft.ttype == "tdac-co":
16
+ return ft.bins
17
17
 
18
- def target_f(data: TruthFunctionData, config: TruthFunctionConfig) -> Truth:
18
+ return ft.bins * 2
19
+
20
+
21
+ def target_f(mixdb: MixtureDatabase, m_id: int, target_index: int, _config: dict) -> Truth:
19
22
  """Frequency domain target truth function
20
23
 
21
24
  Calculates the true transform of the target using the STFT
@@ -26,23 +29,34 @@ def target_f(data: TruthFunctionData, config: TruthFunctionConfig) -> Truth:
26
29
  [:, bins] (target real only for tdac-co)
27
30
  """
28
31
  import torch
32
+ from pyaaware import ForwardTransform
33
+ from pyaaware import feature_forward_transform_config
34
+
35
+ ft = ForwardTransform(**feature_forward_transform_config(mixdb.feature))
29
36
 
30
- target_freq = config.target_fft.execute_all(torch.from_numpy(data.target_audio))[0].numpy()
31
- return _stack_real_imag(target_freq, config.ttype)
37
+ target_audio = torch.from_numpy(mixdb.mixture_targets(m_id)[target_index])
38
+
39
+ target_freq = ft.execute_all(target_audio)[0].numpy()
40
+ return _stack_real_imag(target_freq, ft.ttype)
32
41
 
33
42
 
34
43
  def target_mixture_f_validate(_config: dict) -> None:
35
44
  pass
36
45
 
37
46
 
38
- def target_mixture_f_parameters(config: TruthFunctionConfig) -> int:
39
- if config.ttype == "tdac-co":
40
- return config.target_fft.bins * 2
47
+ def target_mixture_f_parameters(feature: str, _num_classes: int, _config: dict) -> int:
48
+ from pyaaware import ForwardTransform
49
+ from pyaaware import feature_forward_transform_config
50
+
51
+ ft = ForwardTransform(**feature_forward_transform_config(feature))
52
+
53
+ if ft.ttype == "tdac-co":
54
+ return ft.bins * 2
41
55
 
42
- return config.target_fft.bins * 4
56
+ return ft.bins * 4
43
57
 
44
58
 
45
- def target_mixture_f(data: TruthFunctionData, config: TruthFunctionConfig) -> Truth:
59
+ def target_mixture_f(mixdb: MixtureDatabase, m_id: int, target_index: int, _config: dict) -> Truth:
46
60
  """Frequency domain target and mixture truth function
47
61
 
48
62
  Calculates the true transform of the target and the mixture
@@ -55,14 +69,21 @@ def target_mixture_f(data: TruthFunctionData, config: TruthFunctionConfig) -> Tr
55
69
  """
56
70
  import numpy as np
57
71
  import torch
72
+ from pyaaware import ForwardTransform
73
+ from pyaaware import feature_forward_transform_config
58
74
 
59
- target_freq = config.target_fft.execute_all(torch.from_numpy(data.target_audio))[0].numpy()
60
- mixture_freq = config.mixture_fft.execute_all(torch.from_numpy(data.mixture_audio))[0].numpy()
75
+ ft = ForwardTransform(**feature_forward_transform_config(mixdb.feature))
76
+
77
+ target_audio = torch.from_numpy(mixdb.mixture_targets(m_id)[target_index])
78
+ mixture_audio = torch.from_numpy(mixdb.mixture_mixture(m_id))
79
+
80
+ target_freq = ft.execute_all(torch.from_numpy(target_audio))[0].numpy()
81
+ mixture_freq = ft.execute_all(torch.from_numpy(mixture_audio))[0].numpy()
61
82
 
62
83
  frames, bins = target_freq.shape
63
84
  truth = np.empty((frames, bins * 4), dtype=np.float32)
64
- truth[:, : bins * 2] = _stack_real_imag(target_freq, config.ttype)
65
- truth[:, bins * 2 :] = _stack_real_imag(mixture_freq, config.ttype)
85
+ truth[:, : bins * 2] = _stack_real_imag(target_freq, ft.ttype)
86
+ truth[:, bins * 2 :] = _stack_real_imag(mixture_freq, ft.ttype)
66
87
  return truth
67
88
 
68
89
 
@@ -70,11 +91,14 @@ def target_swin_f_validate(_config: dict) -> None:
70
91
  pass
71
92
 
72
93
 
73
- def target_swin_f_parameters(config: TruthFunctionConfig) -> int:
74
- return config.target_fft.bins * 2
94
+ def target_swin_f_parameters(feature: str, _num_classes: int, _config: dict) -> int:
95
+ from pyaaware import ForwardTransform
96
+ from pyaaware import feature_forward_transform_config
97
+
98
+ return ForwardTransform(**feature_forward_transform_config(feature)).bins * 2
75
99
 
76
100
 
77
- def target_swin_f(data: TruthFunctionData, config: TruthFunctionConfig) -> Truth:
101
+ def target_swin_f(mixdb: MixtureDatabase, m_id: int, target_index: int, _config: dict) -> Truth:
78
102
  """Frequency domain target with synthesis window truth function
79
103
 
80
104
  Calculates the true transform of the target using the STFT
@@ -85,20 +109,29 @@ def target_swin_f(data: TruthFunctionData, config: TruthFunctionConfig) -> Truth
85
109
  Output shape: [:, 2 * bins] (stacked real, imag)
86
110
  """
87
111
  import numpy as np
112
+ import torch
113
+ from pyaaware import ForwardTransform
114
+ from pyaaware import InverseTransform
115
+ from pyaaware import feature_forward_transform_config
116
+ from pyaaware import feature_inverse_transform_config
88
117
 
89
118
  from sonusai.utils import stack_complex
90
119
 
91
- truth = np.empty((len(data.target_audio) // config.frame_size, config.target_fft.bins * 2), dtype=np.float32)
92
- for idx, offset in enumerate(range(0, len(data.target_audio), config.frame_size)):
93
- target_freq = config.target_fft.execute(
94
- np.multiply(data.target_audio[offset : offset + config.frame_size], config.swin)
95
- )[0]
120
+ ft = ForwardTransform(**feature_forward_transform_config(mixdb.feature))
121
+ it = InverseTransform(**feature_inverse_transform_config(mixdb.feature))
122
+
123
+ target_audio = mixdb.mixture_targets(m_id)[target_index]
124
+
125
+ truth = np.empty((len(target_audio) // ft.overlap, ft.bins * 2), dtype=np.float32)
126
+ for idx, offset in enumerate(range(0, len(target_audio), ft.overlap)):
127
+ audio_frame = torch.from_numpy(np.multiply(target_audio[offset : offset + ft.overlap], it.window))
128
+ target_freq = ft.execute(audio_frame)[0].numpy()
96
129
  truth[idx] = stack_complex(target_freq)
97
130
 
98
131
  return truth
99
132
 
100
133
 
101
- def _stack_real_imag(data: AudioF, ttype: str) -> Truth:
134
+ def _stack_real_imag(data: Truth, ttype: str) -> Truth:
102
135
  import numpy as np
103
136
 
104
137
  from sonusai.utils import stack_complex