sonusai 0.18.8__py3-none-any.whl → 0.19.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (118) hide show
  1. sonusai/__init__.py +20 -29
  2. sonusai/aawscd_probwrite.py +18 -18
  3. sonusai/audiofe.py +93 -80
  4. sonusai/calc_metric_spenh.py +395 -321
  5. sonusai/data/genmixdb.yml +5 -11
  6. sonusai/{gentcst.py → deprecated/gentcst.py} +146 -149
  7. sonusai/{plot.py → deprecated/plot.py} +177 -131
  8. sonusai/{tplot.py → deprecated/tplot.py} +124 -102
  9. sonusai/doc/__init__.py +1 -1
  10. sonusai/doc/doc.py +112 -177
  11. sonusai/doc.py +10 -10
  12. sonusai/genft.py +93 -77
  13. sonusai/genmetrics.py +59 -46
  14. sonusai/genmix.py +116 -104
  15. sonusai/genmixdb.py +194 -153
  16. sonusai/lsdb.py +56 -66
  17. sonusai/main.py +23 -20
  18. sonusai/metrics/__init__.py +2 -0
  19. sonusai/metrics/calc_audio_stats.py +29 -24
  20. sonusai/metrics/calc_class_weights.py +7 -7
  21. sonusai/metrics/calc_optimal_thresholds.py +5 -7
  22. sonusai/metrics/calc_pcm.py +3 -3
  23. sonusai/metrics/calc_pesq.py +10 -7
  24. sonusai/metrics/calc_phase_distance.py +3 -3
  25. sonusai/metrics/calc_sa_sdr.py +10 -8
  26. sonusai/metrics/calc_segsnr_f.py +15 -17
  27. sonusai/metrics/calc_speech.py +105 -47
  28. sonusai/metrics/calc_wer.py +35 -32
  29. sonusai/metrics/calc_wsdr.py +10 -7
  30. sonusai/metrics/class_summary.py +30 -27
  31. sonusai/metrics/confusion_matrix_summary.py +25 -22
  32. sonusai/metrics/one_hot.py +91 -57
  33. sonusai/metrics/snr_summary.py +53 -46
  34. sonusai/mixture/__init__.py +19 -14
  35. sonusai/mixture/audio.py +4 -6
  36. sonusai/mixture/augmentation.py +37 -43
  37. sonusai/mixture/class_count.py +5 -14
  38. sonusai/mixture/config.py +292 -225
  39. sonusai/mixture/constants.py +41 -30
  40. sonusai/mixture/data_io.py +155 -0
  41. sonusai/mixture/datatypes.py +111 -108
  42. sonusai/mixture/db_datatypes.py +54 -70
  43. sonusai/mixture/eq_rule_is_valid.py +6 -9
  44. sonusai/mixture/feature.py +50 -46
  45. sonusai/mixture/generation.py +522 -389
  46. sonusai/mixture/helpers.py +217 -272
  47. sonusai/mixture/log_duration_and_sizes.py +16 -13
  48. sonusai/mixture/mixdb.py +677 -473
  49. sonusai/mixture/soundfile_audio.py +12 -17
  50. sonusai/mixture/sox_audio.py +91 -112
  51. sonusai/mixture/sox_augmentation.py +8 -9
  52. sonusai/mixture/spectral_mask.py +4 -6
  53. sonusai/mixture/target_class_balancing.py +41 -36
  54. sonusai/mixture/targets.py +69 -67
  55. sonusai/mixture/tokenized_shell_vars.py +23 -23
  56. sonusai/mixture/torchaudio_audio.py +14 -15
  57. sonusai/mixture/torchaudio_augmentation.py +23 -27
  58. sonusai/mixture/truth.py +48 -26
  59. sonusai/mixture/truth_functions/__init__.py +26 -0
  60. sonusai/mixture/truth_functions/crm.py +56 -38
  61. sonusai/mixture/truth_functions/datatypes.py +37 -0
  62. sonusai/mixture/truth_functions/energy.py +85 -59
  63. sonusai/mixture/truth_functions/file.py +30 -30
  64. sonusai/mixture/truth_functions/phoneme.py +14 -7
  65. sonusai/mixture/truth_functions/sed.py +71 -45
  66. sonusai/mixture/truth_functions/target.py +69 -106
  67. sonusai/mkwav.py +52 -85
  68. sonusai/onnx_predict.py +46 -43
  69. sonusai/queries/__init__.py +3 -1
  70. sonusai/queries/queries.py +100 -59
  71. sonusai/speech/__init__.py +2 -0
  72. sonusai/speech/l2arctic.py +24 -23
  73. sonusai/speech/librispeech.py +16 -17
  74. sonusai/speech/mcgill.py +22 -21
  75. sonusai/speech/textgrid.py +32 -25
  76. sonusai/speech/timit.py +45 -42
  77. sonusai/speech/vctk.py +14 -13
  78. sonusai/speech/voxceleb.py +26 -20
  79. sonusai/summarize_metric_spenh.py +11 -10
  80. sonusai/utils/__init__.py +4 -3
  81. sonusai/utils/asl_p56.py +1 -1
  82. sonusai/utils/asr.py +37 -17
  83. sonusai/utils/asr_functions/__init__.py +2 -0
  84. sonusai/utils/asr_functions/aaware_whisper.py +18 -12
  85. sonusai/utils/audio_devices.py +12 -12
  86. sonusai/utils/braced_glob.py +6 -8
  87. sonusai/utils/calculate_input_shape.py +1 -4
  88. sonusai/utils/compress.py +2 -2
  89. sonusai/utils/convert_string_to_number.py +1 -3
  90. sonusai/utils/create_timestamp.py +1 -1
  91. sonusai/utils/create_ts_name.py +2 -2
  92. sonusai/utils/dataclass_from_dict.py +1 -1
  93. sonusai/utils/docstring.py +6 -6
  94. sonusai/utils/energy_f.py +9 -7
  95. sonusai/utils/engineering_number.py +56 -54
  96. sonusai/utils/get_label_names.py +8 -10
  97. sonusai/utils/human_readable_size.py +2 -2
  98. sonusai/utils/model_utils.py +3 -5
  99. sonusai/utils/numeric_conversion.py +2 -4
  100. sonusai/utils/onnx_utils.py +43 -32
  101. sonusai/utils/parallel.py +40 -27
  102. sonusai/utils/print_mixture_details.py +25 -22
  103. sonusai/utils/ranges.py +12 -12
  104. sonusai/utils/read_predict_data.py +11 -9
  105. sonusai/utils/reshape.py +19 -26
  106. sonusai/utils/seconds_to_hms.py +1 -1
  107. sonusai/utils/stacked_complex.py +8 -16
  108. sonusai/utils/stratified_shuffle_split.py +29 -27
  109. sonusai/utils/write_audio.py +2 -2
  110. sonusai/utils/yes_or_no.py +3 -3
  111. sonusai/vars.py +14 -14
  112. {sonusai-0.18.8.dist-info → sonusai-0.19.5.dist-info}/METADATA +20 -21
  113. sonusai-0.19.5.dist-info/RECORD +125 -0
  114. {sonusai-0.18.8.dist-info → sonusai-0.19.5.dist-info}/WHEEL +1 -1
  115. sonusai/mixture/truth_functions/data.py +0 -58
  116. sonusai/utils/read_mixture_data.py +0 -14
  117. sonusai-0.18.8.dist-info/RECORD +0 -125
  118. {sonusai-0.18.8.dist-info → sonusai-0.19.5.dist-info}/entry_points.txt +0 -0
sonusai/mixture/mixdb.py CHANGED
@@ -1,46 +1,47 @@
1
+ # ruff: noqa: S608
1
2
  from functools import cached_property
2
3
  from functools import lru_cache
3
4
  from functools import partial
4
5
  from sqlite3 import Connection
5
6
  from sqlite3 import Cursor
6
7
  from typing import Any
7
- from typing import Optional
8
-
9
- from sonusai.mixture.datatypes import ASRConfigs
10
- from sonusai.mixture.datatypes import AudioF
11
- from sonusai.mixture.datatypes import AudioT
12
- from sonusai.mixture.datatypes import AudiosF
13
- from sonusai.mixture.datatypes import AudiosT
14
- from sonusai.mixture.datatypes import ClassCount
15
- from sonusai.mixture.datatypes import Feature
16
- from sonusai.mixture.datatypes import FeatureGeneratorConfig
17
- from sonusai.mixture.datatypes import FeatureGeneratorInfo
18
- from sonusai.mixture.datatypes import GeneralizedIDs
19
- from sonusai.mixture.datatypes import ImpulseResponseFiles
20
- from sonusai.mixture.datatypes import MetricDoc
21
- from sonusai.mixture.datatypes import MetricDocs
22
- from sonusai.mixture.datatypes import Mixture
23
- from sonusai.mixture.datatypes import Mixtures
24
- from sonusai.mixture.datatypes import NoiseFile
25
- from sonusai.mixture.datatypes import NoiseFiles
26
- from sonusai.mixture.datatypes import Segsnr
27
- from sonusai.mixture.datatypes import SpectralMask
28
- from sonusai.mixture.datatypes import SpectralMasks
29
- from sonusai.mixture.datatypes import SpeechMetadata
30
- from sonusai.mixture.datatypes import TargetFile
31
- from sonusai.mixture.datatypes import TargetFiles
32
- from sonusai.mixture.datatypes import TransformConfig
33
- from sonusai.mixture.datatypes import Truth
34
- from sonusai.mixture.datatypes import UniversalSNR
8
+
9
+ from .datatypes import ASRConfigs
10
+ from .datatypes import AudioF
11
+ from .datatypes import AudiosF
12
+ from .datatypes import AudiosT
13
+ from .datatypes import AudioT
14
+ from .datatypes import ClassCount
15
+ from .datatypes import Feature
16
+ from .datatypes import FeatureGeneratorConfig
17
+ from .datatypes import FeatureGeneratorInfo
18
+ from .datatypes import GeneralizedIDs
19
+ from .datatypes import ImpulseResponseFiles
20
+ from .datatypes import MetricDoc
21
+ from .datatypes import MetricDocs
22
+ from .datatypes import Mixture
23
+ from .datatypes import Mixtures
24
+ from .datatypes import NoiseFile
25
+ from .datatypes import NoiseFiles
26
+ from .datatypes import Segsnr
27
+ from .datatypes import SpectralMask
28
+ from .datatypes import SpectralMasks
29
+ from .datatypes import SpeechMetadata
30
+ from .datatypes import TargetFile
31
+ from .datatypes import TargetFiles
32
+ from .datatypes import TransformConfig
33
+ from .datatypes import TruthConfigs
34
+ from .datatypes import TruthDict
35
+ from .datatypes import UniversalSNR
35
36
 
36
37
 
37
38
  def db_file(location: str, test: bool = False) -> str:
38
39
  from os.path import join
39
40
 
40
41
  if test:
41
- name = 'mixdb_test.db'
42
+ name = "mixdb_test.db"
42
43
  else:
43
- name = 'mixdb.db'
44
+ name = "mixdb.db"
44
45
 
45
46
  return join(location, name)
46
47
 
@@ -50,19 +51,17 @@ def db_connection(location: str, create: bool = False, readonly: bool = True, te
50
51
  from os import remove
51
52
  from os.path import exists
52
53
 
53
- from sonusai import SonusAIError
54
-
55
54
  name = db_file(location, test)
56
55
  if create and exists(name):
57
56
  remove(name)
58
57
 
59
58
  if not create and not exists(name):
60
- raise SonusAIError(f'Could not find mixture database in {location}')
59
+ raise OSError(f"Could not find mixture database in {location}")
61
60
 
62
61
  if not create and readonly:
63
- name += '?mode=ro'
62
+ name += "?mode=ro"
64
63
 
65
- connection = sqlite3.connect('file:' + name, uri=True)
64
+ connection = sqlite3.connect("file:" + name, uri=True)
66
65
  # connection.set_trace_callback(print)
67
66
  return connection
68
67
 
@@ -103,25 +102,23 @@ class MixtureDatabase:
103
102
  num_classes=self.num_classes,
104
103
  spectral_masks=self.spectral_masks,
105
104
  target_files=self.target_files,
106
- truth_mutex=self.truth_mutex,
107
- truth_reduction_function=self.truth_reduction_function
108
105
  )
109
106
  return config.to_json(indent=2)
110
107
 
111
108
  def save(self) -> None:
112
- """Save the MixtureDatabase as a JSON file
113
- """
109
+ """Save the MixtureDatabase as a JSON file"""
114
110
  from os.path import join
115
111
 
116
- json_name = join(self.location, 'mixdb.json')
117
- with open(file=json_name, mode='w') as file:
112
+ json_name = join(self.location, "mixdb.json")
113
+ with open(file=json_name, mode="w") as file:
118
114
  file.write(self.json)
119
115
 
120
116
  @cached_property
121
117
  def fg_config(self) -> FeatureGeneratorConfig:
122
- return FeatureGeneratorConfig(feature_mode=self.feature,
123
- num_classes=self.num_classes,
124
- truth_mutex=self.truth_mutex)
118
+ return FeatureGeneratorConfig(
119
+ feature_mode=self.feature,
120
+ truth_parameters=self.truth_parameters,
121
+ )
125
122
 
126
123
  @cached_property
127
124
  def fg_info(self) -> FeatureGeneratorInfo:
@@ -130,19 +127,18 @@ class MixtureDatabase:
130
127
  return get_feature_generator_info(self.fg_config)
131
128
 
132
129
  @cached_property
133
- def num_classes(self) -> int:
134
- with self.db() as c:
135
- return int(c.execute("SELECT top.num_classes FROM top").fetchone()[0])
136
-
137
- @cached_property
138
- def truth_mutex(self) -> bool:
130
+ def truth_parameters(self) -> dict[str, int]:
139
131
  with self.db() as c:
140
- return bool(c.execute("SELECT top.truth_mutex FROM top").fetchone()[0])
132
+ rows = c.execute("SELECT * FROM truth_parameters").fetchall()
133
+ truth_parameters: dict[str, int] = {}
134
+ for row in rows:
135
+ truth_parameters[row[1]] = row[2]
136
+ return truth_parameters
141
137
 
142
138
  @cached_property
143
- def truth_reduction_function(self) -> str:
139
+ def num_classes(self) -> int:
144
140
  with self.db() as c:
145
- return str(c.execute("SELECT top.truth_reduction_function FROM top").fetchone()[0])
141
+ return int(c.execute("SELECT top.num_classes FROM top").fetchone()[0])
146
142
 
147
143
  @cached_property
148
144
  def noise_mix_mode(self) -> str:
@@ -158,68 +154,152 @@ class MixtureDatabase:
158
154
 
159
155
  @cached_property
160
156
  def supported_metrics(self) -> MetricDocs:
161
- metrics = MetricDocs([
162
- MetricDoc('Mixture Metrics', 'mxsnr', 'SNR specification in dB'),
163
- MetricDoc('Mixture Metrics', 'mxssnr_avg', 'Segmental SNR average over all frames'),
164
- MetricDoc('Mixture Metrics', 'mxssnr_std', 'Segmental SNR standard deviation over all frames'),
165
- MetricDoc('Mixture Metrics', 'mxssnrdb_avg',
166
- 'Segmental SNR average of the dB frame values over all frames'),
167
- MetricDoc('Mixture Metrics', 'mxssnrdb_std',
168
- 'Segmental SNR standard deviation of the dB frame values over all frames'),
169
- MetricDoc('Mixture Metrics', 'mxssnrf_avg',
170
- 'Per-bin segmental SNR average over all frames (using feature transform)'),
171
- MetricDoc('Mixture Metrics', 'mxssnrf_std',
172
- 'Per-bin segmental SNR standard deviation over all frames (using feature transform)'),
173
- MetricDoc('Mixture Metrics', 'mxssnrdbf_avg',
174
- 'Per-bin segmental average of the dB frame values over all frames (using feature transform)'),
175
- MetricDoc('Mixture Metrics', 'mxssnrdbf_std',
176
- 'Per-bin segmental standard deviation of the dB frame values over all frames (using feature transform)'),
177
- MetricDoc('Mixture Metrics', 'mxpesq', 'PESQ of mixture versus true target[0]'),
178
- MetricDoc('Mixture Metrics', 'mxwsdr', 'Weighted signal distorion ratio of mixture versus true target[0]'),
179
- MetricDoc('Mixture Metrics', 'mxpd', 'Phase distance between mixture and true target[0]'),
180
- MetricDoc('Mixture Metrics', 'mxstoi',
181
- 'Short term objective intelligibility of mixture versus true target[0]'),
182
- MetricDoc('Mixture Metrics', 'mxcsig',
183
- 'Predicted rating of speech distortion of mixture versus true target[0]'),
184
- MetricDoc('Mixture Metrics', 'mxcbak',
185
- 'Predicted rating of background distortion of mixture versus true target[0]'),
186
- MetricDoc('Mixture Metrics', 'mxcovl',
187
- 'Predicted rating of overall quality of mixture versus true target[0]'),
188
- MetricDoc('Mixture Metrics', 'ssnr', 'Segmental SNR'),
189
- MetricDoc('Target Metrics', 'tdco', 'Target[0] DC offset'),
190
- MetricDoc('Target Metrics', 'tmin', 'Target[0] min level'),
191
- MetricDoc('Target Metrics', 'tmax', 'Target[0] max levl'),
192
- MetricDoc('Target Metrics', 'tpkdb', 'Target[0] Pk lev dB'),
193
- MetricDoc('Target Metrics', 'tlrms', 'Target[0] RMS lev dB'),
194
- MetricDoc('Target Metrics', 'tpkr', 'Target[0] RMS Pk dB'),
195
- MetricDoc('Target Metrics', 'ttr', 'Target[0] RMS Tr dB'),
196
- MetricDoc('Target Metrics', 'tcr', 'Target[0] Crest factor'),
197
- MetricDoc('Target Metrics', 'tfl', 'Target[0] Flat factor'),
198
- MetricDoc('Target Metrics', 'tpkc', 'Target[0] Pk count'),
199
- MetricDoc('Noise Metrics', 'ndco', 'Noise DC offset'),
200
- MetricDoc('Noise Metrics', 'nmin', 'Noise min level'),
201
- MetricDoc('Noise Metrics', 'nmax', 'Noise max levl'),
202
- MetricDoc('Noise Metrics', 'npkdb', 'Noise Pk lev dB'),
203
- MetricDoc('Noise Metrics', 'nlrms', 'Noise RMS lev dB'),
204
- MetricDoc('Noise Metrics', 'npkr', 'Noise RMS Pk dB'),
205
- MetricDoc('Noise Metrics', 'ntr', 'Noise RMS Tr dB'),
206
- MetricDoc('Noise Metrics', 'ncr', 'Noise Crest factor'),
207
- MetricDoc('Noise Metrics', 'nfl', 'Noise Flat factor'),
208
- MetricDoc('Noise Metrics', 'npkc', 'Noise Pk count'),
209
- MetricDoc('Truth Metrics', 'sedavg',
210
- '(not implemented) Average SED activity over all frames [num_classes, 1]'),
211
- MetricDoc('Truth Metrics', 'sedcnt',
212
- '(not implemented) Count in number of frames that SED is active [num_classes, 1]'),
213
- MetricDoc('Truth Metrics', 'sedtop3', '(not implemented) 3 most active by largest sedavg [3, 1]'),
214
- MetricDoc('Truth Metrics', 'sedtopn', '(not implemented) N most active by largest sedavg [N, 1]'),
215
- ])
157
+ metrics = MetricDocs(
158
+ [
159
+ MetricDoc("Mixture Metrics", "mxsnr", "SNR specification in dB"),
160
+ MetricDoc(
161
+ "Mixture Metrics",
162
+ "mxssnr_avg",
163
+ "Segmental SNR average over all frames",
164
+ ),
165
+ MetricDoc(
166
+ "Mixture Metrics",
167
+ "mxssnr_std",
168
+ "Segmental SNR standard deviation over all frames",
169
+ ),
170
+ MetricDoc(
171
+ "Mixture Metrics",
172
+ "mxssnrdb_avg",
173
+ "Segmental SNR average of the dB frame values over all frames",
174
+ ),
175
+ MetricDoc(
176
+ "Mixture Metrics",
177
+ "mxssnrdb_std",
178
+ "Segmental SNR standard deviation of the dB frame values over all frames",
179
+ ),
180
+ MetricDoc(
181
+ "Mixture Metrics",
182
+ "mxssnrf_avg",
183
+ "Per-bin segmental SNR average over all frames (using feature transform)",
184
+ ),
185
+ MetricDoc(
186
+ "Mixture Metrics",
187
+ "mxssnrf_std",
188
+ "Per-bin segmental SNR standard deviation over all frames (using feature transform)",
189
+ ),
190
+ MetricDoc(
191
+ "Mixture Metrics",
192
+ "mxssnrdbf_avg",
193
+ "Per-bin segmental average of the dB frame values over all frames (using feature transform)",
194
+ ),
195
+ MetricDoc(
196
+ "Mixture Metrics",
197
+ "mxssnrdbf_std",
198
+ "Per-bin segmental standard deviation of the dB frame values over all frames (using feature transform)",
199
+ ),
200
+ MetricDoc("Mixture Metrics", "mxpesq", "PESQ of mixture versus true target[0]"),
201
+ MetricDoc(
202
+ "Mixture Metrics",
203
+ "mxwsdr",
204
+ "Weighted signal distorion ratio of mixture versus true target[0]",
205
+ ),
206
+ MetricDoc(
207
+ "Mixture Metrics",
208
+ "mxpd",
209
+ "Phase distance between mixture and true target[0]",
210
+ ),
211
+ MetricDoc(
212
+ "Mixture Metrics",
213
+ "mxstoi",
214
+ "Short term objective intelligibility of mixture versus true target[0]",
215
+ ),
216
+ MetricDoc(
217
+ "Mixture Metrics",
218
+ "mxcsig",
219
+ "Predicted rating of speech distortion of mixture versus true target[0]",
220
+ ),
221
+ MetricDoc(
222
+ "Mixture Metrics",
223
+ "mxcbak",
224
+ "Predicted rating of background distortion of mixture versus true target[0]",
225
+ ),
226
+ MetricDoc(
227
+ "Mixture Metrics",
228
+ "mxcovl",
229
+ "Predicted rating of overall quality of mixture versus true target[0]",
230
+ ),
231
+ MetricDoc("Mixture Metrics", "ssnr", "Segmental SNR"),
232
+ MetricDoc("Target Metrics", "tdco", "Target[0] DC offset"),
233
+ MetricDoc("Target Metrics", "tmin", "Target[0] min level"),
234
+ MetricDoc("Target Metrics", "tmax", "Target[0] max levl"),
235
+ MetricDoc("Target Metrics", "tpkdb", "Target[0] Pk lev dB"),
236
+ MetricDoc("Target Metrics", "tlrms", "Target[0] RMS lev dB"),
237
+ MetricDoc("Target Metrics", "tpkr", "Target[0] RMS Pk dB"),
238
+ MetricDoc("Target Metrics", "ttr", "Target[0] RMS Tr dB"),
239
+ MetricDoc("Target Metrics", "tcr", "Target[0] Crest factor"),
240
+ MetricDoc("Target Metrics", "tfl", "Target[0] Flat factor"),
241
+ MetricDoc("Target Metrics", "tpkc", "Target[0] Pk count"),
242
+ MetricDoc("Noise Metrics", "ndco", "Noise DC offset"),
243
+ MetricDoc("Noise Metrics", "nmin", "Noise min level"),
244
+ MetricDoc("Noise Metrics", "nmax", "Noise max levl"),
245
+ MetricDoc("Noise Metrics", "npkdb", "Noise Pk lev dB"),
246
+ MetricDoc("Noise Metrics", "nlrms", "Noise RMS lev dB"),
247
+ MetricDoc("Noise Metrics", "npkr", "Noise RMS Pk dB"),
248
+ MetricDoc("Noise Metrics", "ntr", "Noise RMS Tr dB"),
249
+ MetricDoc("Noise Metrics", "ncr", "Noise Crest factor"),
250
+ MetricDoc("Noise Metrics", "nfl", "Noise Flat factor"),
251
+ MetricDoc("Noise Metrics", "npkc", "Noise Pk count"),
252
+ MetricDoc(
253
+ "Truth Metrics",
254
+ "sedavg",
255
+ "(not implemented) Average SED activity over all frames [truth_parameters, 1]",
256
+ ),
257
+ MetricDoc(
258
+ "Truth Metrics",
259
+ "sedcnt",
260
+ "(not implemented) Count in number of frames that SED is active [truth_parameters, 1]",
261
+ ),
262
+ MetricDoc(
263
+ "Truth Metrics",
264
+ "sedtop3",
265
+ "(not implemented) 3 most active by largest sedavg [3, 1]",
266
+ ),
267
+ MetricDoc(
268
+ "Truth Metrics",
269
+ "sedtopn",
270
+ "(not implemented) N most active by largest sedavg [N, 1]",
271
+ ),
272
+ ]
273
+ )
216
274
  for name in self.asr_configs:
217
- metrics.append(MetricDoc('Target Metrics', f'tasr.{name}',
218
- f'Target[0] ASR text using {name} ASR as defined in mixdb asr_configs parameter'))
219
- metrics.append(MetricDoc('Mixture Metrics', f'mxasr.{name}',
220
- f'ASR text using {name} ASR as defined in mixdb asr_configs parameter'))
221
- metrics.append(MetricDoc('Mixture Metrics', f'mxwer.{name}',
222
- f'Word error rate using {name} ASR as defined in mixdb asr_configs parameter'))
275
+ metrics.append(
276
+ MetricDoc(
277
+ "Target Metrics",
278
+ f"tasr.{name}",
279
+ f"Target[0] ASR text using {name} ASR as defined in mixdb asr_configs parameter",
280
+ )
281
+ )
282
+ metrics.append(
283
+ MetricDoc(
284
+ "Mixture Metrics",
285
+ f"mxasr.{name}",
286
+ f"ASR text using {name} ASR as defined in mixdb asr_configs parameter",
287
+ )
288
+ )
289
+ metrics.append(
290
+ MetricDoc(
291
+ "Target Metrics",
292
+ f"basewer.{name}",
293
+ f"Word error rate of tasr.{name} vs. speech text metadata for the target",
294
+ )
295
+ )
296
+ metrics.append(
297
+ MetricDoc(
298
+ "Mixture Metrics",
299
+ f"mxwer.{name}",
300
+ f"Word error rate of mxasr.{name} vs. tasr.{name}",
301
+ )
302
+ )
223
303
 
224
304
  return metrics
225
305
 
@@ -265,7 +345,7 @@ class MixtureDatabase:
265
345
  def transform_frame_ms(self) -> float:
266
346
  from .constants import SAMPLE_RATE
267
347
 
268
- return float(self.ft_config.R) / float(SAMPLE_RATE / 1000)
348
+ return float(self.ft_config.overlap) / float(SAMPLE_RATE / 1000)
269
349
 
270
350
  @cached_property
271
351
  def feature_ms(self) -> float:
@@ -273,7 +353,7 @@ class MixtureDatabase:
273
353
 
274
354
  @cached_property
275
355
  def feature_samples(self) -> int:
276
- return self.ft_config.R * self.fg_decimation * self.fg_stride
356
+ return self.ft_config.overlap * self.fg_decimation * self.fg_stride
277
357
 
278
358
  @cached_property
279
359
  def feature_step_ms(self) -> float:
@@ -281,28 +361,33 @@ class MixtureDatabase:
281
361
 
282
362
  @cached_property
283
363
  def feature_step_samples(self) -> int:
284
- return self.ft_config.R * self.fg_decimation * self.fg_step
364
+ return self.ft_config.overlap * self.fg_decimation * self.fg_step
285
365
 
286
- def total_samples(self, m_ids: GeneralizedIDs = '*') -> int:
287
- return sum([self.mixture(m_id).samples for m_id in self.mixids_to_list(m_ids)])
366
+ def total_samples(self, m_ids: GeneralizedIDs = "*") -> int:
367
+ samples = 0
368
+ for m_id in self.mixids_to_list(m_ids):
369
+ s = self.mixture(m_id).samples
370
+ if s is not None:
371
+ samples += s
372
+ return samples
288
373
 
289
- def total_transform_frames(self, m_ids: GeneralizedIDs = '*') -> int:
290
- return self.total_samples(m_ids) // self.ft_config.R
374
+ def total_transform_frames(self, m_ids: GeneralizedIDs = "*") -> int:
375
+ return self.total_samples(m_ids) // self.ft_config.overlap
291
376
 
292
- def total_feature_frames(self, m_ids: GeneralizedIDs = '*') -> int:
377
+ def total_feature_frames(self, m_ids: GeneralizedIDs = "*") -> int:
293
378
  return self.total_samples(m_ids) // self.feature_step_samples
294
379
 
295
380
  def mixture_transform_frames(self, m_id: int) -> int:
296
381
  from .helpers import frames_from_samples
297
382
 
298
- return frames_from_samples(self.mixture(m_id).samples, self.ft_config.R)
383
+ return frames_from_samples(self.mixture(m_id).samples, self.ft_config.overlap)
299
384
 
300
385
  def mixture_feature_frames(self, m_id: int) -> int:
301
386
  from .helpers import frames_from_samples
302
387
 
303
388
  return frames_from_samples(self.mixture(m_id).samples, self.feature_step_samples)
304
389
 
305
- def mixids_to_list(self, m_ids: Optional[GeneralizedIDs] = None) -> list[int]:
390
+ def mixids_to_list(self, m_ids: GeneralizedIDs = "*") -> list[int]:
306
391
  """Resolve generalized mixture IDs to a list of integers
307
392
 
308
393
  :param m_ids: Generalized mixture IDs
@@ -319,8 +404,10 @@ class MixtureDatabase:
319
404
  :return: Class labels
320
405
  """
321
406
  with self.db() as c:
322
- return [str(item[0]) for item in
323
- c.execute("SELECT class_label.label FROM class_label ORDER BY class_label.id").fetchall()]
407
+ return [
408
+ str(item[0])
409
+ for item in c.execute("SELECT class_label.label FROM class_label ORDER BY class_label.id").fetchall()
410
+ ]
324
411
 
325
412
  @cached_property
326
413
  def class_weights_thresholds(self) -> list[float]:
@@ -329,8 +416,37 @@ class MixtureDatabase:
329
416
  :return: Class weights thresholds
330
417
  """
331
418
  with self.db() as c:
332
- return [float(item[0]) for item in
333
- c.execute("SELECT class_weights_threshold.threshold FROM class_weights_threshold").fetchall()]
419
+ return [
420
+ float(item[0])
421
+ for item in c.execute(
422
+ "SELECT class_weights_threshold.threshold FROM class_weights_threshold"
423
+ ).fetchall()
424
+ ]
425
+
426
+ @cached_property
427
+ def truth_configs(self) -> TruthConfigs:
428
+ """Get truth configs from db
429
+
430
+ :return: Truth configs
431
+ """
432
+ import json
433
+
434
+ from .datatypes import TruthConfig
435
+
436
+ with self.db() as c:
437
+ truth_configs: TruthConfigs = {}
438
+ for truth_config_record in c.execute("SELECT truth_config.config FROM truth_config").fetchall():
439
+ truth_config = json.loads(truth_config_record[0])
440
+ if truth_config["name"] not in truth_configs:
441
+ truth_configs[truth_config["name"]] = TruthConfig(
442
+ function=truth_config["function"],
443
+ stride_reduction=truth_config["stride_reduction"],
444
+ config=truth_config["config"],
445
+ )
446
+ return truth_configs
447
+
448
+ def target_truth_configs(self, t_id: int) -> TruthConfigs:
449
+ return _target_truth_configs(self.db, t_id)
334
450
 
335
451
  @cached_property
336
452
  def random_snrs(self) -> list[float]:
@@ -339,8 +455,12 @@ class MixtureDatabase:
339
455
  :return: Random SNRs
340
456
  """
341
457
  with self.db() as c:
342
- return list(set([float(item[0]) for item in
343
- c.execute("SELECT mixture.snr FROM mixture WHERE mixture.random_snr == 1").fetchall()]))
458
+ return list(
459
+ {
460
+ float(item[0])
461
+ for item in c.execute("SELECT mixture.snr FROM mixture WHERE mixture.random_snr == 1").fetchall()
462
+ }
463
+ )
344
464
 
345
465
  @cached_property
346
466
  def snrs(self) -> list[float]:
@@ -349,13 +469,21 @@ class MixtureDatabase:
349
469
  :return: SNRs
350
470
  """
351
471
  with self.db() as c:
352
- return list(set([float(item[0]) for item in
353
- c.execute("SELECT mixture.snr FROM mixture WHERE mixture.random_snr == 0").fetchall()]))
472
+ return list(
473
+ {
474
+ float(item[0])
475
+ for item in c.execute("SELECT mixture.snr FROM mixture WHERE mixture.random_snr == 0").fetchall()
476
+ }
477
+ )
354
478
 
355
479
  @cached_property
356
480
  def all_snrs(self) -> list[UniversalSNR]:
357
- return sorted(list(set([UniversalSNR(is_random=False, value=snr) for snr in self.snrs] +
358
- [UniversalSNR(is_random=True, value=snr) for snr in self.random_snrs])))
481
+ return sorted(
482
+ set(
483
+ [UniversalSNR(is_random=False, value=snr) for snr in self.snrs]
484
+ + [UniversalSNR(is_random=True, value=snr) for snr in self.random_snrs]
485
+ )
486
+ )
359
487
 
360
488
  @cached_property
361
489
  def spectral_masks(self) -> SpectralMasks:
@@ -366,13 +494,19 @@ class MixtureDatabase:
366
494
  from .db_datatypes import SpectralMaskRecord
367
495
 
368
496
  with self.db() as c:
369
- spectral_masks = [SpectralMaskRecord(*result) for result in
370
- c.execute("SELECT * FROM spectral_mask").fetchall()]
371
- return [SpectralMask(f_max_width=spectral_mask.f_max_width,
372
- f_num=spectral_mask.f_num,
373
- t_max_width=spectral_mask.t_max_width,
374
- t_num=spectral_mask.t_num,
375
- t_max_percent=spectral_mask.t_max_percent) for spectral_mask in spectral_masks]
497
+ spectral_masks = [
498
+ SpectralMaskRecord(*result) for result in c.execute("SELECT * FROM spectral_mask").fetchall()
499
+ ]
500
+ return [
501
+ SpectralMask(
502
+ f_max_width=spectral_mask.f_max_width,
503
+ f_num=spectral_mask.f_num,
504
+ t_max_width=spectral_mask.t_max_width,
505
+ t_num=spectral_mask.t_num,
506
+ t_max_percent=spectral_mask.t_max_percent,
507
+ )
508
+ for spectral_mask in spectral_masks
509
+ ]
376
510
 
377
511
  def spectral_mask(self, sm_id: int) -> SpectralMask:
378
512
  """Get spectral mask with ID from db
@@ -390,31 +524,40 @@ class MixtureDatabase:
390
524
  """
391
525
  import json
392
526
 
393
- from .datatypes import TruthSetting
394
- from .datatypes import TruthSettings
527
+ from .datatypes import TruthConfig
528
+ from .datatypes import TruthConfigs
395
529
  from .db_datatypes import TargetFileRecord
396
530
 
397
531
  with self.db() as c:
398
532
  target_files: TargetFiles = []
399
- target_file_records = [TargetFileRecord(*result) for result in
400
- c.execute("SELECT * FROM target_file").fetchall()]
533
+ target_file_records = [
534
+ TargetFileRecord(*result) for result in c.execute("SELECT * FROM target_file").fetchall()
535
+ ]
401
536
  for target_file_record in target_file_records:
402
- truth_settings: TruthSettings = []
403
- for truth_setting_records in c.execute(
404
- "SELECT truth_setting.setting " +
405
- "FROM truth_setting, target_file_truth_setting " +
406
- "WHERE ? = target_file_truth_setting.target_file_id " +
407
- "AND truth_setting.id = target_file_truth_setting.truth_setting_id",
408
- (target_file_record.id,)).fetchall():
409
- truth_setting = json.loads(truth_setting_records[0])
410
- truth_settings.append(TruthSetting(config=truth_setting.get('config', None),
411
- function=truth_setting.get('function', None),
412
- index=truth_setting.get('index', None)))
413
- target_files.append(TargetFile(name=target_file_record.name,
414
- samples=target_file_record.samples,
415
- level_type=target_file_record.level_type,
416
- truth_settings=truth_settings,
417
- speaker_id=target_file_record.speaker_id))
537
+ truth_configs: TruthConfigs = {}
538
+ for truth_config_records in c.execute(
539
+ "SELECT truth_config.config "
540
+ + "FROM truth_config, target_file_truth_config "
541
+ + "WHERE ? = target_file_truth_config.target_file_id "
542
+ + "AND truth_config.id = target_file_truth_config.truth_config_id",
543
+ (target_file_record.id,),
544
+ ).fetchall():
545
+ truth_config = json.loads(truth_config_records[0])
546
+ truth_configs[truth_config["name"]] = TruthConfig(
547
+ function=truth_config["function"],
548
+ stride_reduction=truth_config["stride_reduction"],
549
+ config=truth_config["config"],
550
+ )
551
+ target_files.append(
552
+ TargetFile(
553
+ name=target_file_record.name,
554
+ samples=target_file_record.samples,
555
+ class_indices=json.loads(target_file_record.class_indices),
556
+ level_type=target_file_record.level_type,
557
+ truth_configs=truth_configs,
558
+ speaker_id=target_file_record.speaker_id,
559
+ )
560
+ )
418
561
  return target_files
419
562
 
420
563
  @cached_property
@@ -450,8 +593,10 @@ class MixtureDatabase:
450
593
  :return: Noise files
451
594
  """
452
595
  with self.db() as c:
453
- return [NoiseFile(name=noise[0], samples=noise[1]) for noise in
454
- c.execute("SELECT noise_file.name, samples FROM noise_file").fetchall()]
596
+ return [
597
+ NoiseFile(name=noise[0], samples=noise[1])
598
+ for noise in c.execute("SELECT noise_file.name, samples FROM noise_file").fetchall()
599
+ ]
455
600
 
456
601
  @cached_property
457
602
  def noise_file_ids(self) -> list[int]:
@@ -485,9 +630,21 @@ class MixtureDatabase:
485
630
 
486
631
  :return: Impulse response files
487
632
  """
633
+ import json
634
+
635
+ from .datatypes import ImpulseResponseFile
636
+
488
637
  with self.db() as c:
489
- return [str(impulse_response[0]) for impulse_response in
490
- c.execute("SELECT impulse_response_file.file FROM impulse_response_file").fetchall()]
638
+ # for impulse_response in c.execute(
639
+ # "SELECT impulse_response_file.* FROM impulse_response_file"
640
+ # ).fetchall():
641
+ # print(impulse_response)
642
+ return [
643
+ ImpulseResponseFile(impulse_response[1], json.loads(impulse_response[2]))
644
+ for impulse_response in c.execute(
645
+ "SELECT impulse_response_file.* FROM impulse_response_file"
646
+ ).fetchall()
647
+ ]
491
648
 
492
649
  @cached_property
493
650
  def impulse_response_file_ids(self) -> list[int]:
@@ -496,15 +653,19 @@ class MixtureDatabase:
496
653
  :return: List of impulse response file IDs
497
654
  """
498
655
  with self.db() as c:
499
- return [int(item[0]) for item in
500
- c.execute("SELECT impulse_response_file.id FROM impulse_response_file").fetchall()]
656
+ return [
657
+ int(item[0])
658
+ for item in c.execute("SELECT impulse_response_file.id FROM impulse_response_file").fetchall()
659
+ ]
501
660
 
502
- def impulse_response_file(self, ir_id: int) -> str:
661
+ def impulse_response_file(self, ir_id: int | None) -> str | None:
503
662
  """Get impulse response file with ID from db
504
663
 
505
664
  :param ir_id: Impulse response file ID
506
665
  :return: Noise
507
666
  """
667
+ if ir_id is None:
668
+ return None
508
669
  return _impulse_response_file(self.db, ir_id)
509
670
 
510
671
  @cached_property
@@ -522,18 +683,22 @@ class MixtureDatabase:
522
683
 
523
684
  :return: Mixtures
524
685
  """
525
- from .helpers import to_mixture
526
- from .helpers import to_target
527
686
  from .db_datatypes import MixtureRecord
528
687
  from .db_datatypes import TargetRecord
688
+ from .helpers import to_mixture
689
+ from .helpers import to_target
529
690
 
530
691
  with self.db() as c:
531
692
  mixtures: Mixtures = []
532
693
  for mixture in [MixtureRecord(*record) for record in c.execute("SELECT * FROM mixture").fetchall()]:
533
- targets = [to_target(TargetRecord(*target)) for target in c.execute(
534
- "SELECT target.* FROM target, mixture_target " +
535
- "WHERE ? = mixture_target.mixture_id AND target.id = mixture_target.target_id",
536
- (mixture.id,)).fetchall()]
694
+ targets = [
695
+ to_target(TargetRecord(*target))
696
+ for target in c.execute(
697
+ "SELECT target.* FROM target, mixture_target "
698
+ + "WHERE ? = mixture_target.mixture_id AND target.id = mixture_target.target_id",
699
+ (mixture.id,),
700
+ ).fetchall()
701
+ ]
537
702
  mixtures.append(to_mixture(mixture, targets))
538
703
  return mixtures
539
704
 
@@ -559,23 +724,15 @@ class MixtureDatabase:
559
724
  with self.db() as c:
560
725
  return int(c.execute("SELECT top.mixid_width FROM top").fetchone()[0])
561
726
 
562
- def location_filename(self, name: str) -> str:
563
- """Add the location to the given file name
727
+ def mixture_location(self, m_id: int) -> str:
728
+ """Get the file location for the give mixture ID
564
729
 
565
- :param name: File name
566
- :return: Location added
730
+ :param m_id: Zero-based mixture ID
731
+ :return: File location
567
732
  """
568
733
  from os.path import join
569
734
 
570
- return join(self.location, name)
571
-
572
- def mixture_filename(self, m_id: int) -> str:
573
- """Get the HDF5 file name for the give mixture ID
574
-
575
- :param m_id: Zero-based mixture ID
576
- :return: File name
577
- """
578
- return self.location_filename(self.mixture(m_id).name)
735
+ return join(self.location, self.mixture(m_id).name)
579
736
 
580
737
  @cached_property
581
738
  def num_mixtures(self) -> int:
@@ -593,9 +750,9 @@ class MixtureDatabase:
593
750
  :param items: String(s) of dataset(s) to retrieve
594
751
  :return: Data (or tuple of data)
595
752
  """
596
- from .helpers import read_mixture_data
753
+ from sonusai.mixture import read_cached_data
597
754
 
598
- return read_mixture_data(self.location_filename(self.mixture(m_id).name), items)
755
+ return read_cached_data(self.location, "mixture", self.mixture(m_id).name, items)
599
756
 
600
757
  def read_target_audio(self, t_id: int) -> AudioT:
601
758
  """Read target audio
@@ -622,10 +779,19 @@ class MixtureDatabase:
622
779
  audio = read_audio(noise.name)
623
780
  audio = apply_augmentation(audio, mixture.noise.augmentation)
624
781
  if mixture.noise.augmentation.ir is not None:
625
- audio = apply_impulse_response(audio, read_ir(self.impulse_response_file(mixture.noise.augmentation.ir)))
782
+ audio = apply_impulse_response(
783
+ audio,
784
+ read_ir(self.impulse_response_file(mixture.noise.augmentation.ir)),
785
+ )
626
786
 
627
787
  return audio
628
788
 
789
+ def mixture_class_indices(self, m_id: int) -> list[int]:
790
+ class_indices: list[int] = []
791
+ for t_id in self.mixture(m_id).target_ids:
792
+ class_indices.extend(self.target_file(t_id).class_indices)
793
+ return sorted(set(class_indices))
794
+
629
795
  def mixture_targets(self, m_id: int, force: bool = False) -> AudiosT:
630
796
  """Get the list of augmented target audio data (one per target in the mixup) for the given mixture ID
631
797
 
@@ -633,36 +799,34 @@ class MixtureDatabase:
633
799
  :param force: Force computing data from original sources regardless of whether cached data exists
634
800
  :return: List of augmented target audio data (one per target in the mixup)
635
801
  """
636
- from sonusai import SonusAIError
637
802
  from .augmentation import apply_augmentation
638
803
  from .augmentation import apply_gain
639
804
  from .augmentation import pad_audio_to_length
640
805
 
641
806
  if not force:
642
- targets_audio = self.read_mixture_data(m_id, 'targets')
807
+ targets_audio = self.read_mixture_data(m_id, "targets")
643
808
  if targets_audio is not None:
644
809
  return list(targets_audio)
645
810
 
646
811
  mixture = self.mixture(m_id)
647
812
  if mixture is None:
648
- raise SonusAIError(f'Could not find mixture for m_id: {m_id}')
813
+ raise ValueError(f"Could not find mixture for m_id: {m_id}")
649
814
 
650
815
  targets_audio = []
651
816
  for target in mixture.targets:
652
817
  target_audio = self.read_target_audio(target.file_id)
653
- target_audio = apply_augmentation(audio=target_audio,
654
- augmentation=target.augmentation,
655
- frame_length=self.feature_step_samples)
818
+ target_audio = apply_augmentation(
819
+ audio=target_audio,
820
+ augmentation=target.augmentation,
821
+ frame_length=self.feature_step_samples,
822
+ )
656
823
  target_audio = apply_gain(audio=target_audio, gain=mixture.target_snr_gain)
657
824
  target_audio = pad_audio_to_length(audio=target_audio, length=mixture.samples)
658
825
  targets_audio.append(target_audio)
659
826
 
660
827
  return targets_audio
661
828
 
662
- def mixture_targets_f(self,
663
- m_id: int,
664
- targets: Optional[AudiosT] = None,
665
- force: bool = False) -> AudiosF:
829
+ def mixture_targets_f(self, m_id: int, targets: AudiosT | None = None, force: bool = False) -> AudiosF:
666
830
  """Get the list of augmented target transform data (one per target in the mixup) for the given mixture ID
667
831
 
668
832
  :param m_id: Zero-based mixture ID
@@ -677,10 +841,7 @@ class MixtureDatabase:
677
841
 
678
842
  return [forward_transform(target, self.ft_config) for target in targets]
679
843
 
680
- def mixture_target(self,
681
- m_id: int,
682
- targets: Optional[AudiosT] = None,
683
- force: bool = False) -> AudioT:
844
+ def mixture_target(self, m_id: int, targets: AudiosT | None = None, force: bool = False) -> AudioT:
684
845
  """Get the augmented target audio data for the given mixture ID
685
846
 
686
847
  :param m_id: Zero-based mixture ID
@@ -691,7 +852,7 @@ class MixtureDatabase:
691
852
  from .helpers import get_target
692
853
 
693
854
  if not force:
694
- target = self.read_mixture_data(m_id, 'target')
855
+ target = self.read_mixture_data(m_id, "target")
695
856
  if target is not None:
696
857
  return target
697
858
 
@@ -700,11 +861,13 @@ class MixtureDatabase:
700
861
 
701
862
  return get_target(self, self.mixture(m_id), targets)
702
863
 
703
- def mixture_target_f(self,
704
- m_id: int,
705
- targets: Optional[AudiosT] = None,
706
- target: Optional[AudioT] = None,
707
- force: bool = False) -> AudioF:
864
+ def mixture_target_f(
865
+ self,
866
+ m_id: int,
867
+ targets: AudiosT | None = None,
868
+ target: AudioT | None = None,
869
+ force: bool = False,
870
+ ) -> AudioF:
708
871
  """Get the augmented target transform data for the given mixture ID
709
872
 
710
873
  :param m_id: Zero-based mixture ID
@@ -720,9 +883,7 @@ class MixtureDatabase:
720
883
 
721
884
  return forward_transform(target, self.ft_config)
722
885
 
723
- def mixture_noise(self,
724
- m_id: int,
725
- force: bool = False) -> AudioT:
886
+ def mixture_noise(self, m_id: int, force: bool = False) -> AudioT:
726
887
  """Get the augmented noise audio data for the given mixture ID
727
888
 
728
889
  :param m_id: Zero-based mixture ID
@@ -733,7 +894,7 @@ class MixtureDatabase:
733
894
  from .augmentation import apply_gain
734
895
 
735
896
  if not force:
736
- noise = self.read_mixture_data(m_id, 'noise')
897
+ noise = self.read_mixture_data(m_id, "noise")
737
898
  if noise is not None:
738
899
  return noise
739
900
 
@@ -742,10 +903,7 @@ class MixtureDatabase:
742
903
  noise = get_next_noise(audio=noise, offset=mixture.noise.offset, length=mixture.samples)
743
904
  return apply_gain(audio=noise, gain=mixture.noise_snr_gain)
744
905
 
745
- def mixture_noise_f(self,
746
- m_id: int,
747
- noise: Optional[AudioT] = None,
748
- force: bool = False) -> AudioF:
906
+ def mixture_noise_f(self, m_id: int, noise: AudioT | None = None, force: bool = False) -> AudioF:
749
907
  """Get the augmented noise transform for the given mixture ID
750
908
 
751
909
  :param m_id: Zero-based mixture ID
@@ -760,12 +918,14 @@ class MixtureDatabase:
760
918
 
761
919
  return forward_transform(noise, self.ft_config)
762
920
 
763
- def mixture_mixture(self,
764
- m_id: int,
765
- targets: Optional[AudiosT] = None,
766
- target: Optional[AudioT] = None,
767
- noise: Optional[AudioT] = None,
768
- force: bool = False) -> AudioT:
921
+ def mixture_mixture(
922
+ self,
923
+ m_id: int,
924
+ targets: AudiosT | None = None,
925
+ target: AudioT | None = None,
926
+ noise: AudioT | None = None,
927
+ force: bool = False,
928
+ ) -> AudioT:
769
929
  """Get the mixture audio data for the given mixture ID
770
930
 
771
931
  :param m_id: Zero-based mixture ID
@@ -776,7 +936,7 @@ class MixtureDatabase:
776
936
  :return: Mixture audio data
777
937
  """
778
938
  if not force:
779
- mixture = self.read_mixture_data(m_id, 'mixture')
939
+ mixture = self.read_mixture_data(m_id, "mixture")
780
940
  if mixture is not None:
781
941
  return mixture
782
942
 
@@ -788,13 +948,15 @@ class MixtureDatabase:
788
948
 
789
949
  return target + noise
790
950
 
791
- def mixture_mixture_f(self,
792
- m_id: int,
793
- targets: Optional[AudiosT] = None,
794
- target: Optional[AudioT] = None,
795
- noise: Optional[AudioT] = None,
796
- mixture: Optional[AudioT] = None,
797
- force: bool = False) -> AudioF:
951
+ def mixture_mixture_f(
952
+ self,
953
+ m_id: int,
954
+ targets: AudiosT | None = None,
955
+ target: AudioT | None = None,
956
+ noise: AudioT | None = None,
957
+ mixture: AudioT | None = None,
958
+ force: bool = False,
959
+ ) -> AudioF:
798
960
  """Get the mixture transform for the given mixture ID
799
961
 
800
962
  :param m_id: Zero-based mixture ID
@@ -815,18 +977,22 @@ class MixtureDatabase:
815
977
 
816
978
  m = self.mixture(m_id)
817
979
  if m.spectral_mask_id is not None:
818
- mixture_f = apply_spectral_mask(audio_f=mixture_f,
819
- spectral_mask=self.spectral_mask(int(m.spectral_mask_id)),
820
- seed=m.spectral_mask_seed)
980
+ mixture_f = apply_spectral_mask(
981
+ audio_f=mixture_f,
982
+ spectral_mask=self.spectral_mask(int(m.spectral_mask_id)),
983
+ seed=m.spectral_mask_seed,
984
+ )
821
985
 
822
986
  return mixture_f
823
987
 
824
- def mixture_truth_t(self,
825
- m_id: int,
826
- targets: Optional[AudiosT] = None,
827
- noise: Optional[AudioT] = None,
828
- mixture: Optional[AudioT] = None,
829
- force: bool = False) -> Truth:
988
+ def mixture_truth_t(
989
+ self,
990
+ m_id: int,
991
+ targets: AudiosT | None = None,
992
+ noise: AudioT | None = None,
993
+ mixture: AudioT | None = None,
994
+ force: bool = False,
995
+ ) -> TruthDict:
830
996
  """Get the truth_t data for the given mixture ID
831
997
 
832
998
  :param m_id: Zero-based mixture ID
@@ -836,10 +1002,10 @@ class MixtureDatabase:
836
1002
  :param force: Force computing data from original sources regardless of whether cached data exists
837
1003
  :return: truth_t data
838
1004
  """
839
- from .helpers import get_truth_t
1005
+ from .helpers import get_truth
840
1006
 
841
1007
  if not force:
842
- truth_t = self.read_mixture_data(m_id, 'truth_t')
1008
+ truth_t = self.read_mixture_data(m_id, "truth_t")
843
1009
  if truth_t is not None:
844
1010
  return truth_t
845
1011
 
@@ -850,19 +1016,18 @@ class MixtureDatabase:
850
1016
  noise = self.mixture_noise(m_id, force)
851
1017
 
852
1018
  if force or mixture is None:
853
- noise = self.mixture_mixture(m_id,
854
- targets=targets,
855
- noise=noise,
856
- force=force)
857
-
858
- return get_truth_t(self, self.mixture(m_id), targets, noise, mixture)
859
-
860
- def mixture_segsnr_t(self,
861
- m_id: int,
862
- targets: Optional[AudiosT] = None,
863
- target: Optional[AudioT] = None,
864
- noise: Optional[AudioT] = None,
865
- force: bool = False) -> Segsnr:
1019
+ mixture = self.mixture_mixture(m_id, targets=targets, noise=noise, force=force)
1020
+
1021
+ return get_truth(self, self.mixture(m_id), targets, noise, mixture)
1022
+
1023
+ def mixture_segsnr_t(
1024
+ self,
1025
+ m_id: int,
1026
+ targets: AudiosT | None = None,
1027
+ target: AudioT | None = None,
1028
+ noise: AudioT | None = None,
1029
+ force: bool = False,
1030
+ ) -> Segsnr:
866
1031
  """Get the segsnr_t data for the given mixture ID
867
1032
 
868
1033
  :param m_id: Zero-based mixture ID
@@ -875,7 +1040,7 @@ class MixtureDatabase:
875
1040
  from .helpers import get_segsnr_t
876
1041
 
877
1042
  if not force:
878
- segsnr_t = self.read_mixture_data(m_id, 'segsnr_t')
1043
+ segsnr_t = self.read_mixture_data(m_id, "segsnr_t")
879
1044
  if segsnr_t is not None:
880
1045
  return segsnr_t
881
1046
 
@@ -887,13 +1052,15 @@ class MixtureDatabase:
887
1052
 
888
1053
  return get_segsnr_t(self, self.mixture(m_id), target, noise)
889
1054
 
890
- def mixture_segsnr(self,
891
- m_id: int,
892
- segsnr_t: Optional[Segsnr] = None,
893
- targets: Optional[AudiosT] = None,
894
- target: Optional[AudioT] = None,
895
- noise: Optional[AudioT] = None,
896
- force: bool = False) -> Segsnr:
1055
+ def mixture_segsnr(
1056
+ self,
1057
+ m_id: int,
1058
+ segsnr_t: Segsnr | None = None,
1059
+ targets: AudiosT | None = None,
1060
+ target: AudioT | None = None,
1061
+ noise: AudioT | None = None,
1062
+ force: bool = False,
1063
+ ) -> Segsnr:
897
1064
  """Get the segsnr data for the given mixture ID
898
1065
 
899
1066
  :param m_id: Zero-based mixture ID
@@ -905,28 +1072,30 @@ class MixtureDatabase:
905
1072
  :return: segsnr data
906
1073
  """
907
1074
  if not force:
908
- segsnr = self.read_mixture_data(m_id, 'segsnr')
1075
+ segsnr = self.read_mixture_data(m_id, "segsnr")
909
1076
  if segsnr is not None:
910
1077
  return segsnr
911
1078
 
912
- segsnr_t = self.read_mixture_data(m_id, 'segsnr_t')
1079
+ segsnr_t = self.read_mixture_data(m_id, "segsnr_t")
913
1080
  if segsnr_t is not None:
914
- return segsnr_t[0::self.ft_config.R]
1081
+ return segsnr_t[0 :: self.ft_config.overlap]
915
1082
 
916
1083
  if force or segsnr_t is None:
917
1084
  segsnr_t = self.mixture_segsnr_t(m_id, targets, target, noise, force)
918
1085
 
919
- return segsnr_t[0::self.ft_config.R]
920
-
921
- def mixture_ft(self,
922
- m_id: int,
923
- targets: Optional[AudiosT] = None,
924
- target: Optional[AudioT] = None,
925
- noise: Optional[AudioT] = None,
926
- mixture_f: Optional[AudioF] = None,
927
- mixture: Optional[AudioT] = None,
928
- truth_t: Optional[Truth] = None,
929
- force: bool = False) -> tuple[Feature, Truth]:
1086
+ return segsnr_t[0 :: self.ft_config.overlap]
1087
+
1088
+ def mixture_ft(
1089
+ self,
1090
+ m_id: int,
1091
+ targets: AudiosT | None = None,
1092
+ target: AudioT | None = None,
1093
+ noise: AudioT | None = None,
1094
+ mixture_f: AudioF | None = None,
1095
+ mixture: AudioT | None = None,
1096
+ truth_t: TruthDict | None = None,
1097
+ force: bool = False,
1098
+ ) -> tuple[Feature, TruthDict]:
930
1099
  """Get the feature and truth_f data for the given mixture ID
931
1100
 
932
1101
  :param m_id: Zero-based mixture ID
@@ -939,63 +1108,45 @@ class MixtureDatabase:
939
1108
  :param force: Force computing data from original sources regardless of whether cached data exists
940
1109
  :return: Tuple of (feature, truth_f) data
941
1110
  """
942
- from dataclasses import asdict
943
-
944
- import numpy as np
945
1111
  from pyaaware import FeatureGenerator
946
1112
 
947
- from .truth import truth_reduction
1113
+ from .truth import truth_stride_reduction
948
1114
 
949
1115
  if not force:
950
- feature, truth_f = self.read_mixture_data(m_id, ['feature', 'truth_f'])
1116
+ feature, truth_f = self.read_mixture_data(m_id, ["feature", "truth_f"])
951
1117
  if feature is not None and truth_f is not None:
952
1118
  return feature, truth_f
953
1119
 
954
1120
  if force or mixture_f is None:
955
- mixture_f = self.mixture_mixture_f(m_id=m_id,
956
- targets=targets,
957
- target=target,
958
- noise=noise,
959
- mixture=mixture,
960
- force=force)
1121
+ mixture_f = self.mixture_mixture_f(
1122
+ m_id=m_id,
1123
+ targets=targets,
1124
+ target=target,
1125
+ noise=noise,
1126
+ mixture=mixture,
1127
+ force=force,
1128
+ )
961
1129
 
962
1130
  if force or truth_t is None:
963
1131
  truth_t = self.mixture_truth_t(m_id=m_id, targets=targets, noise=noise, force=force)
964
1132
 
965
- m = self.mixture(m_id)
966
- transform_frames = self.mixture_transform_frames(m_id)
967
- feature_frames = self.mixture_feature_frames(m_id)
968
-
969
- if truth_t is None:
970
- truth_t = np.zeros((m.samples, self.num_classes), dtype=np.float32)
1133
+ fg = FeatureGenerator(self.fg_config.feature_mode, self.fg_config.truth_parameters)
971
1134
 
972
- feature = np.empty((feature_frames, self.fg_stride, self.feature_parameters), dtype=np.float32)
973
- truth_f = np.empty((feature_frames, self.num_classes), dtype=np.complex64)
974
-
975
- fg = FeatureGenerator(**asdict(self.fg_config))
976
- feature_frame = 0
977
- for transform_frame in range(transform_frames):
978
- indices = slice(transform_frame * self.ft_config.R, (transform_frame + 1) * self.ft_config.R)
979
- fg.execute(mixture_f[transform_frame],
980
- truth_reduction(truth_t[indices], self.truth_reduction_function))
981
-
982
- if fg.eof():
983
- feature[feature_frame] = fg.feature()
984
- truth_f[feature_frame] = fg.truth()
985
- feature_frame += 1
986
-
987
- if np.isreal(truth_f).all():
988
- return feature, truth_f.real
1135
+ feature, truth_f = fg.execute_all(mixture_f, truth_t)
1136
+ for key in self.truth_configs:
1137
+ truth_f[key] = truth_stride_reduction(truth_f[key], self.truth_configs[key].stride_reduction)
989
1138
 
990
1139
  return feature, truth_f
991
1140
 
992
- def mixture_feature(self,
993
- m_id: int,
994
- targets: Optional[AudiosT] = None,
995
- noise: Optional[AudioT] = None,
996
- mixture: Optional[AudioT] = None,
997
- truth_t: Optional[Truth] = None,
998
- force: bool = False) -> Feature:
1141
+ def mixture_feature(
1142
+ self,
1143
+ m_id: int,
1144
+ targets: AudiosT | None = None,
1145
+ noise: AudioT | None = None,
1146
+ mixture: AudioT | None = None,
1147
+ truth_t: TruthDict | None = None,
1148
+ force: bool = False,
1149
+ ) -> Feature:
999
1150
  """Get the feature data for the given mixture ID
1000
1151
 
1001
1152
  :param m_id: Zero-based mixture ID
@@ -1006,21 +1157,25 @@ class MixtureDatabase:
1006
1157
  :param force: Force computing data from original sources regardless of whether cached data exists
1007
1158
  :return: Feature data
1008
1159
  """
1009
- feature, _ = self.mixture_ft(m_id=m_id,
1010
- targets=targets,
1011
- noise=noise,
1012
- mixture=mixture,
1013
- truth_t=truth_t,
1014
- force=force)
1160
+ feature, _ = self.mixture_ft(
1161
+ m_id=m_id,
1162
+ targets=targets,
1163
+ noise=noise,
1164
+ mixture=mixture,
1165
+ truth_t=truth_t,
1166
+ force=force,
1167
+ )
1015
1168
  return feature
1016
1169
 
1017
- def mixture_truth_f(self,
1018
- m_id: int,
1019
- targets: Optional[AudiosT] = None,
1020
- noise: Optional[AudioT] = None,
1021
- mixture: Optional[AudioT] = None,
1022
- truth_t: Optional[Truth] = None,
1023
- force: bool = False) -> Truth:
1170
+ def mixture_truth_f(
1171
+ self,
1172
+ m_id: int,
1173
+ targets: AudiosT | None = None,
1174
+ noise: AudioT | None = None,
1175
+ mixture: AudioT | None = None,
1176
+ truth_t: TruthDict | None = None,
1177
+ force: bool = False,
1178
+ ) -> TruthDict:
1024
1179
  """Get the truth_f data for the given mixture ID
1025
1180
 
1026
1181
  :param m_id: Zero-based mixture ID
@@ -1031,20 +1186,24 @@ class MixtureDatabase:
1031
1186
  :param force: Force computing data from original sources regardless of whether cached data exists
1032
1187
  :return: truth_f data
1033
1188
  """
1034
- _, truth_f = self.mixture_ft(m_id=m_id,
1035
- targets=targets,
1036
- noise=noise,
1037
- mixture=mixture,
1038
- truth_t=truth_t,
1039
- force=force)
1189
+ _, truth_f = self.mixture_ft(
1190
+ m_id=m_id,
1191
+ targets=targets,
1192
+ noise=noise,
1193
+ mixture=mixture,
1194
+ truth_t=truth_t,
1195
+ force=force,
1196
+ )
1040
1197
  return truth_f
1041
1198
 
1042
- def mixture_class_count(self,
1043
- m_id: int,
1044
- targets: Optional[AudiosT] = None,
1045
- noise: Optional[AudioT] = None,
1046
- truth_t: Optional[Truth] = None) -> ClassCount:
1047
- """Compute the number of samples for which each truth index is active for the given mixture ID
1199
+ def mixture_class_count(
1200
+ self,
1201
+ m_id: int,
1202
+ targets: AudiosT | None = None,
1203
+ noise: AudioT | None = None,
1204
+ truth_t: TruthDict | None = None,
1205
+ ) -> ClassCount:
1206
+ """Compute the number of frames for which each class index is active for the given mixture ID
1048
1207
 
1049
1208
  :param m_id: Zero-based mixture ID
1050
1209
  :param targets: List of augmented target audio (one per target in the mixup)
@@ -1059,10 +1218,9 @@ class MixtureDatabase:
1059
1218
 
1060
1219
  class_count = [0] * self.num_classes
1061
1220
  num_classes = self.num_classes
1062
- if self.truth_mutex:
1063
- num_classes -= 1
1064
- for cl in range(num_classes):
1065
- class_count[cl] = int(np.sum(truth_t[:, cl] >= self.class_weights_thresholds[cl]))
1221
+ if "sed" in self.truth_configs:
1222
+ for cl in range(num_classes):
1223
+ class_count[cl] = int(np.sum(truth_t["sed"][:, cl] >= self.class_weights_thresholds[cl]))
1066
1224
 
1067
1225
  return class_count
1068
1226
 
@@ -1084,7 +1242,7 @@ class MixtureDatabase:
1084
1242
  def speech_metadata_tiers(self) -> list[str]:
1085
1243
  return sorted(set(self.speaker_metadata_tiers + self.textgrid_metadata_tiers))
1086
1244
 
1087
- def speaker(self, s_id: int | None, tier: str) -> Optional[str]:
1245
+ def speaker(self, s_id: int | None, tier: str) -> str | None:
1088
1246
  return _speaker(self.db, s_id, tier)
1089
1247
 
1090
1248
  def speech_metadata(self, tier: str) -> list[str]:
@@ -1124,9 +1282,13 @@ class MixtureDatabase:
1124
1282
  entries = []
1125
1283
  for entry in data:
1126
1284
  if target.augmentation.tempo is not None:
1127
- entries.append(Interval(entry.start / target.augmentation.tempo,
1128
- entry.end / target.augmentation.tempo,
1129
- entry.label))
1285
+ entries.append(
1286
+ Interval(
1287
+ entry.start / target.augmentation.tempo,
1288
+ entry.end / target.augmentation.tempo,
1289
+ entry.label,
1290
+ )
1291
+ )
1130
1292
  else:
1131
1293
  entries.append(entry)
1132
1294
  results.append(entries)
@@ -1136,12 +1298,9 @@ class MixtureDatabase:
1136
1298
  for target in self.mixture(mixid).targets:
1137
1299
  results.append(self.speaker(self.target_file(target.file_id).speaker_id, tier))
1138
1300
 
1139
- return sorted(results)
1301
+ return results
1140
1302
 
1141
- def mixids_for_speech_metadata(self,
1142
- tier: str,
1143
- value: str = None,
1144
- where: str = None) -> list[int]:
1303
+ def mixids_for_speech_metadata(self, tier: str, value: str | None = None, where: str | None = None) -> list[int]:
1145
1304
  """Get a list of mixture IDs for the given speech metadata tier.
1146
1305
 
1147
1306
  If 'where' is None, then include mixture IDs whose tier values are equal to the given 'value'.
@@ -1160,25 +1319,27 @@ class MixtureDatabase:
1160
1319
  >>> mixids = mixdb.mixids_for_speech_metadata('dialect', where="dialect in ('New York City', 'Northern')")
1161
1320
  Get mixture IDs for mixtures with speakers whose dialects are either 'New York City' or 'Northern'.
1162
1321
  """
1163
- from sonusai import SonusAIError
1164
-
1165
1322
  if value is None and where is None:
1166
- raise SonusAIError('Must provide either value or where')
1323
+ raise ValueError("Must provide either value or where")
1167
1324
 
1168
1325
  if where is None:
1169
1326
  where = f"{tier} = '{value}'"
1170
1327
 
1171
1328
  if tier in self.textgrid_metadata_tiers:
1172
- raise SonusAIError(f'TextGrid tier data, "{tier}", is not supported in mixids_for_speech_metadata().')
1329
+ raise ValueError(f"TextGrid tier data, '{tier}', is not supported in mixids_for_speech_metadata().")
1173
1330
 
1174
1331
  with self.db() as c:
1175
- speaker_ids = [speaker_id[0] for speaker_id in
1176
- c.execute(f"SELECT id FROM speaker WHERE {where}").fetchall()]
1177
- results = c.execute(f"SELECT id FROM target_file " +
1178
- f"WHERE speaker_id IN ({','.join(map(str, speaker_ids))})").fetchall()
1332
+ speaker_ids = [
1333
+ speaker_id[0] for speaker_id in c.execute(f"SELECT id FROM speaker WHERE {where}").fetchall()
1334
+ ]
1335
+ results = c.execute(
1336
+ "SELECT id FROM target_file " + f"WHERE speaker_id IN ({','.join(map(str, speaker_ids))})"
1337
+ ).fetchall()
1179
1338
  target_file_ids = [target_file_id[0] for target_file_id in results]
1180
- results = c.execute("SELECT mixture_id FROM mixture_target " +
1181
- f"WHERE mixture_target.target_id IN ({','.join(map(str, target_file_ids))})").fetchall()
1339
+ results = c.execute(
1340
+ "SELECT mixture_id FROM mixture_target "
1341
+ + f"WHERE mixture_target.target_id IN ({','.join(map(str, target_file_ids))})"
1342
+ ).fetchall()
1182
1343
 
1183
1344
  return [mixture_id[0] - 1 for mixture_id in results]
1184
1345
 
@@ -1187,9 +1348,9 @@ class MixtureDatabase:
1187
1348
 
1188
1349
  return mixture_all_speech_metadata(self, self.mixture(m_id))
1189
1350
 
1190
- def mixture_metrics(self, m_id: int,
1191
- metrics: list[str],
1192
- force: bool = False) -> list[float | int | str | Segsnr]:
1351
+ def mixture_metrics(
1352
+ self, m_id: int, metrics: list[str], force: bool = False
1353
+ ) -> list[float | int | str | Segsnr | None]:
1193
1354
  """Get metrics data for the given mixture ID
1194
1355
 
1195
1356
  :param m_id: Zero-based mixture ID
@@ -1197,12 +1358,11 @@ class MixtureDatabase:
1197
1358
  :param force: Force computing data from original sources regardless of whether cached data exists
1198
1359
  :return: List of metric data
1199
1360
  """
1200
- from typing import Callable
1361
+ from collections.abc import Callable
1201
1362
 
1202
1363
  import numpy as np
1203
1364
  from pystoi import stoi
1204
1365
 
1205
- from sonusai import SonusAIError
1206
1366
  from sonusai.metrics import calc_audio_stats
1207
1367
  from sonusai.metrics import calc_phase_distance
1208
1368
  from sonusai.metrics import calc_segsnr_f
@@ -1312,7 +1472,7 @@ class MixtureDatabase:
1312
1472
  def get() -> AudioStatsMetrics:
1313
1473
  nonlocal state
1314
1474
  if state is None:
1315
- state = calc_audio_stats(target_audio(), self.fg_info.ft_config.N / SAMPLE_RATE)
1475
+ state = calc_audio_stats(target_audio(), self.fg_info.ft_config.length / SAMPLE_RATE)
1316
1476
  return state
1317
1477
 
1318
1478
  return get
@@ -1325,7 +1485,7 @@ class MixtureDatabase:
1325
1485
  def get() -> AudioStatsMetrics:
1326
1486
  nonlocal state
1327
1487
  if state is None:
1328
- state = calc_audio_stats(noise_audio(), self.fg_info.ft_config.N / SAMPLE_RATE)
1488
+ state = calc_audio_stats(noise_audio(), self.fg_info.ft_config.length / SAMPLE_RATE)
1329
1489
  return state
1330
1490
 
1331
1491
  return get
@@ -1338,9 +1498,10 @@ class MixtureDatabase:
1338
1498
  def get(asr_name) -> dict:
1339
1499
  nonlocal state
1340
1500
  if asr_name not in state:
1341
- state[asr_name] = self.asr_configs.get(asr_name, None)
1342
- if state[asr_name] is None:
1343
- raise SonusAIError(f"Unrecognized ASR name: '{asr_name}'")
1501
+ value = self.asr_configs.get(asr_name, None)
1502
+ if value is None:
1503
+ raise ValueError(f"Unrecognized ASR name: '{asr_name}'")
1504
+ state[asr_name] = value
1344
1505
  return state[asr_name]
1345
1506
 
1346
1507
  return get
@@ -1374,15 +1535,14 @@ class MixtureDatabase:
1374
1535
  mixture_asr = create_mixture_asr()
1375
1536
 
1376
1537
  def get_asr_name(m: str) -> str:
1377
- parts = m.split('.')
1538
+ parts = m.split(".")
1378
1539
  if len(parts) != 2:
1379
- raise SonusAIError(
1380
- f"Unrecognized format: '{m}'; must be of the form: '<metric>.<name>'")
1540
+ raise ValueError(f"Unrecognized format: '{m}'; must be of the form: '<metric>.<name>'")
1381
1541
  asr_name = parts[1]
1382
1542
  return asr_name
1383
1543
 
1384
- def calc(m: str) -> float | int | str | Segsnr:
1385
- if m == 'mxsnr':
1544
+ def calc(m: str) -> float | int | str | Segsnr | None:
1545
+ if m == "mxsnr":
1386
1546
  return self.mixture(m_id).snr
1387
1547
 
1388
1548
  # Get cached data first, if exists
@@ -1392,12 +1552,12 @@ class MixtureDatabase:
1392
1552
  return value
1393
1553
 
1394
1554
  # Otherwise, generate data as needed
1395
- if m.startswith('mxwer'):
1555
+ if m.startswith("mxwer"):
1396
1556
  asr_name = get_asr_name(m)
1397
1557
 
1398
1558
  if self.mixture(m_id).snr < -96:
1399
1559
  # noise only, ignore/reset target asr
1400
- return float('nan')
1560
+ return float("nan")
1401
1561
 
1402
1562
  if target_asr(asr_name):
1403
1563
  return calc_wer(mixture_asr(asr_name), target_asr(asr_name)).wer * 100
@@ -1405,149 +1565,166 @@ class MixtureDatabase:
1405
1565
  # TODO: should this be NaN like above?
1406
1566
  return float(0)
1407
1567
 
1408
- if m.startswith('mxasr'):
1568
+ if m.startswith("basewer"):
1569
+ asr_name = get_asr_name(m)
1570
+
1571
+ text = self.mixture_speech_metadata(m_id, "text")[0]
1572
+ if text is not None:
1573
+ return calc_wer(target_asr(asr_name), text).wer * 100
1574
+
1575
+ # TODO: should this be NaN like above?
1576
+ return float(0)
1577
+
1578
+ if m.startswith("mxasr"):
1409
1579
  return mixture_asr(get_asr_name(m))
1410
1580
 
1411
- if m == 'mxssnr_avg':
1581
+ if m == "mxssnr_avg":
1412
1582
  return calc_segsnr_f(segsnr_f()).avg
1413
1583
 
1414
- if m == 'mxssnr_std':
1584
+ if m == "mxssnr_std":
1415
1585
  return calc_segsnr_f(segsnr_f()).std
1416
1586
 
1417
- if m == 'mxssnrdb_avg':
1587
+ if m == "mxssnrdb_avg":
1418
1588
  return calc_segsnr_f(segsnr_f()).db_avg
1419
1589
 
1420
- if m == 'mxssnrdb_std':
1590
+ if m == "mxssnrdb_std":
1421
1591
  return calc_segsnr_f(segsnr_f()).db_std
1422
1592
 
1423
- if m == 'mxssnrf_avg':
1593
+ if m == "mxssnrf_avg":
1424
1594
  return calc_segsnr_f_bin(target_f(), noise_f()).avg
1425
1595
 
1426
- if m == 'mxssnrf_std':
1596
+ if m == "mxssnrf_std":
1427
1597
  return calc_segsnr_f_bin(target_f(), noise_f()).std
1428
1598
 
1429
- if m == 'mxssnrdbf_avg':
1599
+ if m == "mxssnrdbf_avg":
1430
1600
  return calc_segsnr_f_bin(target_f(), noise_f()).db_avg
1431
1601
 
1432
- if m == 'mxssnrdbf_std':
1602
+ if m == "mxssnrdbf_std":
1433
1603
  return calc_segsnr_f_bin(target_f(), noise_f()).db_std
1434
1604
 
1435
- if m == 'mxpesq':
1605
+ if m == "mxpesq":
1436
1606
  if self.mixture(m_id).snr < -96:
1437
1607
  return 0
1438
1608
  return speech().pesq
1439
1609
 
1440
- if m == 'mxcsig':
1610
+ if m == "mxcsig":
1441
1611
  if self.mixture(m_id).snr < -96:
1442
1612
  return 0
1443
1613
  return speech().csig
1444
1614
 
1445
- if m == 'mxcbak':
1615
+ if m == "mxcbak":
1446
1616
  if self.mixture(m_id).snr < -96:
1447
1617
  return 0
1448
1618
  return speech().cbak
1449
1619
 
1450
- if m == 'mxcovl':
1620
+ if m == "mxcovl":
1451
1621
  if self.mixture(m_id).snr < -96:
1452
1622
  return 0
1453
1623
  return speech().covl
1454
1624
 
1455
- if m == 'mxwsdr':
1625
+ if m == "mxwsdr":
1456
1626
  mixture = mixture_audio()[:, np.newaxis]
1457
1627
  target = target_audio()[:, np.newaxis]
1458
1628
  noise = noise_audio()[:, np.newaxis]
1459
- return calc_wsdr(hypothesis=np.concatenate((mixture, noise), axis=1),
1460
- reference=np.concatenate((target, noise), axis=1),
1461
- with_log=True)[0]
1629
+ return calc_wsdr(
1630
+ hypothesis=np.concatenate((mixture, noise), axis=1),
1631
+ reference=np.concatenate((target, noise), axis=1),
1632
+ with_log=True,
1633
+ )[0]
1462
1634
 
1463
- if m == 'mxpd':
1635
+ if m == "mxpd":
1464
1636
  mixture_f = self.mixture_mixture_f(m_id)
1465
1637
  return calc_phase_distance(hypothesis=mixture_f, reference=target_f())[0]
1466
1638
 
1467
- if m == 'mxstoi':
1468
- return stoi(x=target_audio(), y=mixture_audio(), fs_sig=SAMPLE_RATE, extended=False)
1639
+ if m == "mxstoi":
1640
+ return stoi(
1641
+ x=target_audio(),
1642
+ y=mixture_audio(),
1643
+ fs_sig=SAMPLE_RATE,
1644
+ extended=False,
1645
+ )
1469
1646
 
1470
- if m == 'tdco':
1647
+ if m == "tdco":
1471
1648
  return target_stats().dco
1472
1649
 
1473
- if m == 'tmin':
1650
+ if m == "tmin":
1474
1651
  return target_stats().min
1475
1652
 
1476
- if m == 'tmax':
1653
+ if m == "tmax":
1477
1654
  return target_stats().max
1478
1655
 
1479
- if m == 'tpkdb':
1656
+ if m == "tpkdb":
1480
1657
  return target_stats().pkdb
1481
1658
 
1482
- if m == 'tlrms':
1659
+ if m == "tlrms":
1483
1660
  return target_stats().lrms
1484
1661
 
1485
- if m == 'tpkr':
1662
+ if m == "tpkr":
1486
1663
  return target_stats().pkr
1487
1664
 
1488
- if m == 'ttr':
1665
+ if m == "ttr":
1489
1666
  return target_stats().tr
1490
1667
 
1491
- if m == 'tcr':
1668
+ if m == "tcr":
1492
1669
  return target_stats().cr
1493
1670
 
1494
- if m == 'tfl':
1671
+ if m == "tfl":
1495
1672
  return target_stats().fl
1496
1673
 
1497
- if m == 'tpkc':
1674
+ if m == "tpkc":
1498
1675
  return target_stats().pkc
1499
1676
 
1500
- if m.startswith('tasr'):
1677
+ if m.startswith("tasr"):
1501
1678
  return target_asr(get_asr_name(m))
1502
1679
 
1503
- if m == 'ndco':
1680
+ if m == "ndco":
1504
1681
  return noise_stats().dco
1505
1682
 
1506
- if m == 'nmin':
1683
+ if m == "nmin":
1507
1684
  return noise_stats().min
1508
1685
 
1509
- if m == 'nmax':
1686
+ if m == "nmax":
1510
1687
  return noise_stats().max
1511
1688
 
1512
- if m == 'npkdb':
1689
+ if m == "npkdb":
1513
1690
  return noise_stats().pkdb
1514
1691
 
1515
- if m == 'nlrms':
1692
+ if m == "nlrms":
1516
1693
  return noise_stats().lrms
1517
1694
 
1518
- if m == 'npkr':
1695
+ if m == "npkr":
1519
1696
  return noise_stats().pkr
1520
1697
 
1521
- if m == 'ntr':
1698
+ if m == "ntr":
1522
1699
  return noise_stats().tr
1523
1700
 
1524
- if m == 'ncr':
1701
+ if m == "ncr":
1525
1702
  return noise_stats().cr
1526
1703
 
1527
- if m == 'nfl':
1704
+ if m == "nfl":
1528
1705
  return noise_stats().fl
1529
1706
 
1530
- if m == 'npkc':
1707
+ if m == "npkc":
1531
1708
  return noise_stats().pkc
1532
1709
 
1533
- if m == 'sedavg':
1710
+ if m == "sedavg":
1534
1711
  return 0
1535
1712
 
1536
- if m == 'sedcnt':
1713
+ if m == "sedcnt":
1537
1714
  return 0
1538
1715
 
1539
- if m == 'sedtop3':
1716
+ if m == "sedtop3":
1540
1717
  return np.zeros(3, dtype=np.float32)
1541
1718
 
1542
- if m == 'sedtopn':
1719
+ if m == "sedtopn":
1543
1720
  return 0
1544
1721
 
1545
- if m == 'ssnr':
1722
+ if m == "ssnr":
1546
1723
  return segsnr_f()
1547
1724
 
1548
- raise SonusAIError(f"Unrecognized metric: '{m}'")
1725
+ raise AttributeError(f"Unrecognized metric: '{m}'")
1549
1726
 
1550
- result: list[float | int | str | Segsnr] = []
1727
+ result: list[float | int | str | Segsnr | None] = []
1551
1728
  for metric in metrics:
1552
1729
  result.append(calc(metric))
1553
1730
 
@@ -1565,13 +1742,16 @@ def _spectral_mask(db: partial, sm_id: int) -> SpectralMask:
1565
1742
  from .db_datatypes import SpectralMaskRecord
1566
1743
 
1567
1744
  with db() as c:
1568
- spectral_mask = SpectralMaskRecord(*c.execute("SELECT * FROM spectral_mask WHERE ? = spectral_mask.id",
1569
- (sm_id,)).fetchone())
1570
- return SpectralMask(f_max_width=spectral_mask.f_max_width,
1571
- f_num=spectral_mask.f_num,
1572
- t_max_width=spectral_mask.t_max_width,
1573
- t_num=spectral_mask.t_num,
1574
- t_max_percent=spectral_mask.t_max_percent)
1745
+ spectral_mask = SpectralMaskRecord(
1746
+ *c.execute("SELECT * FROM spectral_mask WHERE ? = spectral_mask.id", (sm_id,)).fetchone()
1747
+ )
1748
+ return SpectralMask(
1749
+ f_max_width=spectral_mask.f_max_width,
1750
+ f_num=spectral_mask.f_num,
1751
+ t_max_width=spectral_mask.t_max_width,
1752
+ t_num=spectral_mask.t_num,
1753
+ t_max_percent=spectral_mask.t_max_percent,
1754
+ )
1575
1755
 
1576
1756
 
1577
1757
  @lru_cache
@@ -1584,30 +1764,21 @@ def _target_file(db: partial, t_id: int) -> TargetFile:
1584
1764
  """
1585
1765
  import json
1586
1766
 
1587
- from .datatypes import TruthSetting
1588
- from .datatypes import TruthSettings
1589
1767
  from .db_datatypes import TargetFileRecord
1590
1768
 
1591
1769
  with db() as c:
1592
1770
  target_file = TargetFileRecord(
1593
- *c.execute("SELECT * FROM target_file WHERE ? = target_file.id", (t_id,)).fetchone())
1594
-
1595
- truth_settings: TruthSettings = []
1596
- for ts in c.execute(
1597
- "SELECT truth_setting.setting " +
1598
- "FROM truth_setting, target_file_truth_setting " +
1599
- "WHERE ? = target_file_truth_setting.target_file_id " +
1600
- "AND truth_setting.id = target_file_truth_setting.truth_setting_id",
1601
- (t_id,)).fetchall():
1602
- entry = json.loads(ts[0])
1603
- truth_settings.append(TruthSetting(config=entry.get('config', None),
1604
- function=entry.get('function', None),
1605
- index=entry.get('index', None)))
1606
- return TargetFile(name=target_file.name,
1607
- samples=target_file.samples,
1608
- level_type=target_file.level_type,
1609
- truth_settings=truth_settings,
1610
- speaker_id=target_file.speaker_id)
1771
+ *c.execute("SELECT * FROM target_file WHERE ? = target_file.id", (t_id,)).fetchone()
1772
+ )
1773
+
1774
+ return TargetFile(
1775
+ name=target_file.name,
1776
+ samples=target_file.samples,
1777
+ class_indices=json.loads(target_file.class_indices),
1778
+ level_type=target_file.level_type,
1779
+ truth_configs=_target_truth_configs(db, t_id),
1780
+ speaker_id=target_file.speaker_id,
1781
+ )
1611
1782
 
1612
1783
 
1613
1784
  @lru_cache
@@ -1619,8 +1790,10 @@ def _noise_file(db: partial, n_id: int) -> NoiseFile:
1619
1790
  :return: Noise file
1620
1791
  """
1621
1792
  with db() as c:
1622
- noise = c.execute("SELECT noise_file.name, samples FROM noise_file WHERE ? = noise_file.id",
1623
- (n_id,)).fetchone()
1793
+ noise = c.execute(
1794
+ "SELECT noise_file.name, samples FROM noise_file WHERE ? = noise_file.id",
1795
+ (n_id,),
1796
+ ).fetchone()
1624
1797
  return NoiseFile(name=noise[0], samples=noise[1])
1625
1798
 
1626
1799
 
@@ -1633,9 +1806,12 @@ def _impulse_response_file(db: partial, ir_id: int) -> str:
1633
1806
  :return: Noise
1634
1807
  """
1635
1808
  with db() as c:
1636
- return str(c.execute(
1637
- "SELECT impulse_response_file.file FROM impulse_response_file WHERE ? = impulse_response_file.id",
1638
- (ir_id + 1,)).fetchone()[0])
1809
+ return str(
1810
+ c.execute(
1811
+ "SELECT impulse_response_file.file FROM impulse_response_file WHERE ? = impulse_response_file.id",
1812
+ (ir_id + 1,),
1813
+ ).fetchone()[0]
1814
+ )
1639
1815
 
1640
1816
 
1641
1817
  @lru_cache
@@ -1646,31 +1822,59 @@ def _mixture(db: partial, m_id: int) -> Mixture:
1646
1822
  :param m_id: Zero-based mixture ID
1647
1823
  :return: Mixture record
1648
1824
  """
1649
- from .helpers import to_mixture
1650
- from .helpers import to_target
1651
1825
  from .db_datatypes import MixtureRecord
1652
1826
  from .db_datatypes import TargetRecord
1827
+ from .helpers import to_mixture
1828
+ from .helpers import to_target
1653
1829
 
1654
1830
  with db() as c:
1655
1831
  mixture = MixtureRecord(*c.execute("SELECT * FROM mixture WHERE ? = mixture.id", (m_id + 1,)).fetchone())
1656
- targets = [to_target(TargetRecord(*target)) for target in c.execute(
1657
- "SELECT target.* " +
1658
- "FROM target, mixture_target " +
1659
- "WHERE ? = mixture_target.mixture_id AND target.id = mixture_target.target_id",
1660
- (mixture.id,)).fetchall()]
1832
+ targets = [
1833
+ to_target(TargetRecord(*target))
1834
+ for target in c.execute(
1835
+ "SELECT target.* "
1836
+ + "FROM target, mixture_target "
1837
+ + "WHERE ? = mixture_target.mixture_id AND target.id = mixture_target.target_id",
1838
+ (mixture.id,),
1839
+ ).fetchall()
1840
+ ]
1661
1841
 
1662
1842
  return to_mixture(mixture, targets)
1663
1843
 
1664
1844
 
1665
1845
  @lru_cache
1666
- def _speaker(db: partial, s_id: int | None, tier: str) -> Optional[str]:
1846
+ def _speaker(db: partial, s_id: int | None, tier: str) -> str | None:
1667
1847
  if s_id is None:
1668
1848
  return None
1669
1849
 
1670
1850
  with db() as c:
1671
- data = c.execute(f'SELECT {tier} FROM speaker WHERE ? = id', (s_id,)).fetchone()
1851
+ data = c.execute(f"SELECT {tier} FROM speaker WHERE ? = id", (s_id,)).fetchone()
1672
1852
  if data is None:
1673
1853
  return None
1674
1854
  if data[0] is None:
1675
1855
  return None
1676
1856
  return data[0]
1857
+
1858
+
1859
+ @lru_cache
1860
+ def _target_truth_configs(db: partial, t_id: int) -> TruthConfigs:
1861
+ import json
1862
+
1863
+ from .datatypes import TruthConfig
1864
+
1865
+ truth_configs: TruthConfigs = {}
1866
+ with db() as c:
1867
+ for truth_config_record in c.execute(
1868
+ "SELECT truth_config.config "
1869
+ + "FROM truth_config, target_file_truth_config "
1870
+ + "WHERE ? = target_file_truth_config.target_file_id "
1871
+ + "AND truth_config.id = target_file_truth_config.truth_config_id",
1872
+ (t_id,),
1873
+ ).fetchall():
1874
+ truth_config = json.loads(truth_config_record[0])
1875
+ truth_configs[truth_config["name"]] = TruthConfig(
1876
+ function=truth_config["function"],
1877
+ stride_reduction=truth_config["stride_reduction"],
1878
+ config=truth_config["config"],
1879
+ )
1880
+ return truth_configs