sonusai 0.18.9__py3-none-any.whl → 0.19.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (118)
  1. sonusai/__init__.py +20 -29
  2. sonusai/aawscd_probwrite.py +18 -18
  3. sonusai/audiofe.py +93 -80
  4. sonusai/calc_metric_spenh.py +395 -321
  5. sonusai/data/genmixdb.yml +5 -11
  6. sonusai/{gentcst.py → deprecated/gentcst.py} +146 -149
  7. sonusai/{plot.py → deprecated/plot.py} +177 -131
  8. sonusai/{tplot.py → deprecated/tplot.py} +124 -102
  9. sonusai/doc/__init__.py +1 -1
  10. sonusai/doc/doc.py +112 -177
  11. sonusai/doc.py +10 -10
  12. sonusai/genft.py +93 -77
  13. sonusai/genmetrics.py +59 -46
  14. sonusai/genmix.py +116 -104
  15. sonusai/genmixdb.py +194 -153
  16. sonusai/lsdb.py +56 -66
  17. sonusai/main.py +23 -20
  18. sonusai/metrics/__init__.py +2 -0
  19. sonusai/metrics/calc_audio_stats.py +29 -24
  20. sonusai/metrics/calc_class_weights.py +7 -7
  21. sonusai/metrics/calc_optimal_thresholds.py +5 -7
  22. sonusai/metrics/calc_pcm.py +3 -3
  23. sonusai/metrics/calc_pesq.py +10 -7
  24. sonusai/metrics/calc_phase_distance.py +3 -3
  25. sonusai/metrics/calc_sa_sdr.py +10 -8
  26. sonusai/metrics/calc_segsnr_f.py +15 -17
  27. sonusai/metrics/calc_speech.py +105 -47
  28. sonusai/metrics/calc_wer.py +35 -32
  29. sonusai/metrics/calc_wsdr.py +10 -7
  30. sonusai/metrics/class_summary.py +30 -27
  31. sonusai/metrics/confusion_matrix_summary.py +25 -22
  32. sonusai/metrics/one_hot.py +91 -57
  33. sonusai/metrics/snr_summary.py +53 -46
  34. sonusai/mixture/__init__.py +19 -14
  35. sonusai/mixture/audio.py +4 -6
  36. sonusai/mixture/augmentation.py +37 -43
  37. sonusai/mixture/class_count.py +5 -14
  38. sonusai/mixture/config.py +292 -225
  39. sonusai/mixture/constants.py +41 -30
  40. sonusai/mixture/data_io.py +155 -0
  41. sonusai/mixture/datatypes.py +111 -108
  42. sonusai/mixture/db_datatypes.py +54 -70
  43. sonusai/mixture/eq_rule_is_valid.py +6 -9
  44. sonusai/mixture/feature.py +40 -38
  45. sonusai/mixture/generation.py +522 -389
  46. sonusai/mixture/helpers.py +217 -272
  47. sonusai/mixture/log_duration_and_sizes.py +16 -13
  48. sonusai/mixture/mixdb.py +669 -477
  49. sonusai/mixture/soundfile_audio.py +12 -17
  50. sonusai/mixture/sox_audio.py +91 -112
  51. sonusai/mixture/sox_augmentation.py +8 -9
  52. sonusai/mixture/spectral_mask.py +4 -6
  53. sonusai/mixture/target_class_balancing.py +41 -36
  54. sonusai/mixture/targets.py +69 -67
  55. sonusai/mixture/tokenized_shell_vars.py +23 -23
  56. sonusai/mixture/torchaudio_audio.py +14 -15
  57. sonusai/mixture/torchaudio_augmentation.py +23 -27
  58. sonusai/mixture/truth.py +48 -26
  59. sonusai/mixture/truth_functions/__init__.py +26 -0
  60. sonusai/mixture/truth_functions/crm.py +56 -38
  61. sonusai/mixture/truth_functions/datatypes.py +37 -0
  62. sonusai/mixture/truth_functions/energy.py +85 -59
  63. sonusai/mixture/truth_functions/file.py +30 -30
  64. sonusai/mixture/truth_functions/phoneme.py +14 -7
  65. sonusai/mixture/truth_functions/sed.py +71 -45
  66. sonusai/mixture/truth_functions/target.py +69 -106
  67. sonusai/mkwav.py +52 -85
  68. sonusai/onnx_predict.py +46 -43
  69. sonusai/queries/__init__.py +3 -1
  70. sonusai/queries/queries.py +100 -59
  71. sonusai/speech/__init__.py +2 -0
  72. sonusai/speech/l2arctic.py +24 -23
  73. sonusai/speech/librispeech.py +16 -17
  74. sonusai/speech/mcgill.py +22 -21
  75. sonusai/speech/textgrid.py +32 -25
  76. sonusai/speech/timit.py +45 -42
  77. sonusai/speech/vctk.py +14 -13
  78. sonusai/speech/voxceleb.py +26 -20
  79. sonusai/summarize_metric_spenh.py +11 -10
  80. sonusai/utils/__init__.py +4 -3
  81. sonusai/utils/asl_p56.py +1 -1
  82. sonusai/utils/asr.py +37 -17
  83. sonusai/utils/asr_functions/__init__.py +2 -0
  84. sonusai/utils/asr_functions/aaware_whisper.py +18 -12
  85. sonusai/utils/audio_devices.py +12 -12
  86. sonusai/utils/braced_glob.py +6 -8
  87. sonusai/utils/calculate_input_shape.py +1 -4
  88. sonusai/utils/compress.py +2 -2
  89. sonusai/utils/convert_string_to_number.py +1 -3
  90. sonusai/utils/create_timestamp.py +1 -1
  91. sonusai/utils/create_ts_name.py +2 -2
  92. sonusai/utils/dataclass_from_dict.py +1 -1
  93. sonusai/utils/docstring.py +6 -6
  94. sonusai/utils/energy_f.py +9 -7
  95. sonusai/utils/engineering_number.py +56 -54
  96. sonusai/utils/get_label_names.py +8 -10
  97. sonusai/utils/human_readable_size.py +2 -2
  98. sonusai/utils/model_utils.py +3 -5
  99. sonusai/utils/numeric_conversion.py +2 -4
  100. sonusai/utils/onnx_utils.py +43 -32
  101. sonusai/utils/parallel.py +40 -27
  102. sonusai/utils/print_mixture_details.py +25 -22
  103. sonusai/utils/ranges.py +12 -12
  104. sonusai/utils/read_predict_data.py +11 -9
  105. sonusai/utils/reshape.py +19 -26
  106. sonusai/utils/seconds_to_hms.py +1 -1
  107. sonusai/utils/stacked_complex.py +8 -16
  108. sonusai/utils/stratified_shuffle_split.py +29 -27
  109. sonusai/utils/write_audio.py +2 -2
  110. sonusai/utils/yes_or_no.py +3 -3
  111. sonusai/vars.py +14 -14
  112. {sonusai-0.18.9.dist-info → sonusai-0.19.5.dist-info}/METADATA +20 -21
  113. sonusai-0.19.5.dist-info/RECORD +125 -0
  114. {sonusai-0.18.9.dist-info → sonusai-0.19.5.dist-info}/WHEEL +1 -1
  115. sonusai/mixture/truth_functions/data.py +0 -58
  116. sonusai/utils/read_mixture_data.py +0 -14
  117. sonusai-0.18.9.dist-info/RECORD +0 -125
  118. {sonusai-0.18.9.dist-info → sonusai-0.19.5.dist-info}/entry_points.txt +0 -0
sonusai/mixture/mixdb.py CHANGED
@@ -1,46 +1,47 @@
1
+ # ruff: noqa: S608
1
2
  from functools import cached_property
2
3
  from functools import lru_cache
3
4
  from functools import partial
4
5
  from sqlite3 import Connection
5
6
  from sqlite3 import Cursor
6
7
  from typing import Any
7
- from typing import Optional
8
-
9
- from sonusai.mixture.datatypes import ASRConfigs
10
- from sonusai.mixture.datatypes import AudioF
11
- from sonusai.mixture.datatypes import AudioT
12
- from sonusai.mixture.datatypes import AudiosF
13
- from sonusai.mixture.datatypes import AudiosT
14
- from sonusai.mixture.datatypes import ClassCount
15
- from sonusai.mixture.datatypes import Feature
16
- from sonusai.mixture.datatypes import FeatureGeneratorConfig
17
- from sonusai.mixture.datatypes import FeatureGeneratorInfo
18
- from sonusai.mixture.datatypes import GeneralizedIDs
19
- from sonusai.mixture.datatypes import ImpulseResponseFiles
20
- from sonusai.mixture.datatypes import MetricDoc
21
- from sonusai.mixture.datatypes import MetricDocs
22
- from sonusai.mixture.datatypes import Mixture
23
- from sonusai.mixture.datatypes import Mixtures
24
- from sonusai.mixture.datatypes import NoiseFile
25
- from sonusai.mixture.datatypes import NoiseFiles
26
- from sonusai.mixture.datatypes import Segsnr
27
- from sonusai.mixture.datatypes import SpectralMask
28
- from sonusai.mixture.datatypes import SpectralMasks
29
- from sonusai.mixture.datatypes import SpeechMetadata
30
- from sonusai.mixture.datatypes import TargetFile
31
- from sonusai.mixture.datatypes import TargetFiles
32
- from sonusai.mixture.datatypes import TransformConfig
33
- from sonusai.mixture.datatypes import Truth
34
- from sonusai.mixture.datatypes import UniversalSNR
8
+
9
+ from .datatypes import ASRConfigs
10
+ from .datatypes import AudioF
11
+ from .datatypes import AudiosF
12
+ from .datatypes import AudiosT
13
+ from .datatypes import AudioT
14
+ from .datatypes import ClassCount
15
+ from .datatypes import Feature
16
+ from .datatypes import FeatureGeneratorConfig
17
+ from .datatypes import FeatureGeneratorInfo
18
+ from .datatypes import GeneralizedIDs
19
+ from .datatypes import ImpulseResponseFiles
20
+ from .datatypes import MetricDoc
21
+ from .datatypes import MetricDocs
22
+ from .datatypes import Mixture
23
+ from .datatypes import Mixtures
24
+ from .datatypes import NoiseFile
25
+ from .datatypes import NoiseFiles
26
+ from .datatypes import Segsnr
27
+ from .datatypes import SpectralMask
28
+ from .datatypes import SpectralMasks
29
+ from .datatypes import SpeechMetadata
30
+ from .datatypes import TargetFile
31
+ from .datatypes import TargetFiles
32
+ from .datatypes import TransformConfig
33
+ from .datatypes import TruthConfigs
34
+ from .datatypes import TruthDict
35
+ from .datatypes import UniversalSNR
35
36
 
36
37
 
37
38
  def db_file(location: str, test: bool = False) -> str:
38
39
  from os.path import join
39
40
 
40
41
  if test:
41
- name = 'mixdb_test.db'
42
+ name = "mixdb_test.db"
42
43
  else:
43
- name = 'mixdb.db'
44
+ name = "mixdb.db"
44
45
 
45
46
  return join(location, name)
46
47
 
@@ -50,19 +51,17 @@ def db_connection(location: str, create: bool = False, readonly: bool = True, te
50
51
  from os import remove
51
52
  from os.path import exists
52
53
 
53
- from sonusai import SonusAIError
54
-
55
54
  name = db_file(location, test)
56
55
  if create and exists(name):
57
56
  remove(name)
58
57
 
59
58
  if not create and not exists(name):
60
- raise SonusAIError(f'Could not find mixture database in {location}')
59
+ raise OSError(f"Could not find mixture database in {location}")
61
60
 
62
61
  if not create and readonly:
63
- name += '?mode=ro'
62
+ name += "?mode=ro"
64
63
 
65
- connection = sqlite3.connect('file:' + name, uri=True)
64
+ connection = sqlite3.connect("file:" + name, uri=True)
66
65
  # connection.set_trace_callback(print)
67
66
  return connection
68
67
 
@@ -103,25 +102,23 @@ class MixtureDatabase:
103
102
  num_classes=self.num_classes,
104
103
  spectral_masks=self.spectral_masks,
105
104
  target_files=self.target_files,
106
- truth_mutex=self.truth_mutex,
107
- truth_reduction_function=self.truth_reduction_function
108
105
  )
109
106
  return config.to_json(indent=2)
110
107
 
111
108
  def save(self) -> None:
112
- """Save the MixtureDatabase as a JSON file
113
- """
109
+ """Save the MixtureDatabase as a JSON file"""
114
110
  from os.path import join
115
111
 
116
- json_name = join(self.location, 'mixdb.json')
117
- with open(file=json_name, mode='w') as file:
112
+ json_name = join(self.location, "mixdb.json")
113
+ with open(file=json_name, mode="w") as file:
118
114
  file.write(self.json)
119
115
 
120
116
  @cached_property
121
117
  def fg_config(self) -> FeatureGeneratorConfig:
122
- return FeatureGeneratorConfig(feature_mode=self.feature,
123
- num_classes=self.num_classes,
124
- truth_mutex=self.truth_mutex)
118
+ return FeatureGeneratorConfig(
119
+ feature_mode=self.feature,
120
+ truth_parameters=self.truth_parameters,
121
+ )
125
122
 
126
123
  @cached_property
127
124
  def fg_info(self) -> FeatureGeneratorInfo:
@@ -130,19 +127,18 @@ class MixtureDatabase:
130
127
  return get_feature_generator_info(self.fg_config)
131
128
 
132
129
  @cached_property
133
- def num_classes(self) -> int:
134
- with self.db() as c:
135
- return int(c.execute("SELECT top.num_classes FROM top").fetchone()[0])
136
-
137
- @cached_property
138
- def truth_mutex(self) -> bool:
130
+ def truth_parameters(self) -> dict[str, int]:
139
131
  with self.db() as c:
140
- return bool(c.execute("SELECT top.truth_mutex FROM top").fetchone()[0])
132
+ rows = c.execute("SELECT * FROM truth_parameters").fetchall()
133
+ truth_parameters: dict[str, int] = {}
134
+ for row in rows:
135
+ truth_parameters[row[1]] = row[2]
136
+ return truth_parameters
141
137
 
142
138
  @cached_property
143
- def truth_reduction_function(self) -> str:
139
+ def num_classes(self) -> int:
144
140
  with self.db() as c:
145
- return str(c.execute("SELECT top.truth_reduction_function FROM top").fetchone()[0])
141
+ return int(c.execute("SELECT top.num_classes FROM top").fetchone()[0])
146
142
 
147
143
  @cached_property
148
144
  def noise_mix_mode(self) -> str:
@@ -158,70 +154,152 @@ class MixtureDatabase:
158
154
 
159
155
  @cached_property
160
156
  def supported_metrics(self) -> MetricDocs:
161
- metrics = MetricDocs([
162
- MetricDoc('Mixture Metrics', 'mxsnr', 'SNR specification in dB'),
163
- MetricDoc('Mixture Metrics', 'mxssnr_avg', 'Segmental SNR average over all frames'),
164
- MetricDoc('Mixture Metrics', 'mxssnr_std', 'Segmental SNR standard deviation over all frames'),
165
- MetricDoc('Mixture Metrics', 'mxssnrdb_avg',
166
- 'Segmental SNR average of the dB frame values over all frames'),
167
- MetricDoc('Mixture Metrics', 'mxssnrdb_std',
168
- 'Segmental SNR standard deviation of the dB frame values over all frames'),
169
- MetricDoc('Mixture Metrics', 'mxssnrf_avg',
170
- 'Per-bin segmental SNR average over all frames (using feature transform)'),
171
- MetricDoc('Mixture Metrics', 'mxssnrf_std',
172
- 'Per-bin segmental SNR standard deviation over all frames (using feature transform)'),
173
- MetricDoc('Mixture Metrics', 'mxssnrdbf_avg',
174
- 'Per-bin segmental average of the dB frame values over all frames (using feature transform)'),
175
- MetricDoc('Mixture Metrics', 'mxssnrdbf_std',
176
- 'Per-bin segmental standard deviation of the dB frame values over all frames (using feature transform)'),
177
- MetricDoc('Mixture Metrics', 'mxpesq', 'PESQ of mixture versus true target[0]'),
178
- MetricDoc('Mixture Metrics', 'mxwsdr', 'Weighted signal distorion ratio of mixture versus true target[0]'),
179
- MetricDoc('Mixture Metrics', 'mxpd', 'Phase distance between mixture and true target[0]'),
180
- MetricDoc('Mixture Metrics', 'mxstoi',
181
- 'Short term objective intelligibility of mixture versus true target[0]'),
182
- MetricDoc('Mixture Metrics', 'mxcsig',
183
- 'Predicted rating of speech distortion of mixture versus true target[0]'),
184
- MetricDoc('Mixture Metrics', 'mxcbak',
185
- 'Predicted rating of background distortion of mixture versus true target[0]'),
186
- MetricDoc('Mixture Metrics', 'mxcovl',
187
- 'Predicted rating of overall quality of mixture versus true target[0]'),
188
- MetricDoc('Mixture Metrics', 'ssnr', 'Segmental SNR'),
189
- MetricDoc('Target Metrics', 'tdco', 'Target[0] DC offset'),
190
- MetricDoc('Target Metrics', 'tmin', 'Target[0] min level'),
191
- MetricDoc('Target Metrics', 'tmax', 'Target[0] max levl'),
192
- MetricDoc('Target Metrics', 'tpkdb', 'Target[0] Pk lev dB'),
193
- MetricDoc('Target Metrics', 'tlrms', 'Target[0] RMS lev dB'),
194
- MetricDoc('Target Metrics', 'tpkr', 'Target[0] RMS Pk dB'),
195
- MetricDoc('Target Metrics', 'ttr', 'Target[0] RMS Tr dB'),
196
- MetricDoc('Target Metrics', 'tcr', 'Target[0] Crest factor'),
197
- MetricDoc('Target Metrics', 'tfl', 'Target[0] Flat factor'),
198
- MetricDoc('Target Metrics', 'tpkc', 'Target[0] Pk count'),
199
- MetricDoc('Noise Metrics', 'ndco', 'Noise DC offset'),
200
- MetricDoc('Noise Metrics', 'nmin', 'Noise min level'),
201
- MetricDoc('Noise Metrics', 'nmax', 'Noise max levl'),
202
- MetricDoc('Noise Metrics', 'npkdb', 'Noise Pk lev dB'),
203
- MetricDoc('Noise Metrics', 'nlrms', 'Noise RMS lev dB'),
204
- MetricDoc('Noise Metrics', 'npkr', 'Noise RMS Pk dB'),
205
- MetricDoc('Noise Metrics', 'ntr', 'Noise RMS Tr dB'),
206
- MetricDoc('Noise Metrics', 'ncr', 'Noise Crest factor'),
207
- MetricDoc('Noise Metrics', 'nfl', 'Noise Flat factor'),
208
- MetricDoc('Noise Metrics', 'npkc', 'Noise Pk count'),
209
- MetricDoc('Truth Metrics', 'sedavg',
210
- '(not implemented) Average SED activity over all frames [num_classes, 1]'),
211
- MetricDoc('Truth Metrics', 'sedcnt',
212
- '(not implemented) Count in number of frames that SED is active [num_classes, 1]'),
213
- MetricDoc('Truth Metrics', 'sedtop3', '(not implemented) 3 most active by largest sedavg [3, 1]'),
214
- MetricDoc('Truth Metrics', 'sedtopn', '(not implemented) N most active by largest sedavg [N, 1]'),
215
- ])
157
+ metrics = MetricDocs(
158
+ [
159
+ MetricDoc("Mixture Metrics", "mxsnr", "SNR specification in dB"),
160
+ MetricDoc(
161
+ "Mixture Metrics",
162
+ "mxssnr_avg",
163
+ "Segmental SNR average over all frames",
164
+ ),
165
+ MetricDoc(
166
+ "Mixture Metrics",
167
+ "mxssnr_std",
168
+ "Segmental SNR standard deviation over all frames",
169
+ ),
170
+ MetricDoc(
171
+ "Mixture Metrics",
172
+ "mxssnrdb_avg",
173
+ "Segmental SNR average of the dB frame values over all frames",
174
+ ),
175
+ MetricDoc(
176
+ "Mixture Metrics",
177
+ "mxssnrdb_std",
178
+ "Segmental SNR standard deviation of the dB frame values over all frames",
179
+ ),
180
+ MetricDoc(
181
+ "Mixture Metrics",
182
+ "mxssnrf_avg",
183
+ "Per-bin segmental SNR average over all frames (using feature transform)",
184
+ ),
185
+ MetricDoc(
186
+ "Mixture Metrics",
187
+ "mxssnrf_std",
188
+ "Per-bin segmental SNR standard deviation over all frames (using feature transform)",
189
+ ),
190
+ MetricDoc(
191
+ "Mixture Metrics",
192
+ "mxssnrdbf_avg",
193
+ "Per-bin segmental average of the dB frame values over all frames (using feature transform)",
194
+ ),
195
+ MetricDoc(
196
+ "Mixture Metrics",
197
+ "mxssnrdbf_std",
198
+ "Per-bin segmental standard deviation of the dB frame values over all frames (using feature transform)",
199
+ ),
200
+ MetricDoc("Mixture Metrics", "mxpesq", "PESQ of mixture versus true target[0]"),
201
+ MetricDoc(
202
+ "Mixture Metrics",
203
+ "mxwsdr",
204
+ "Weighted signal distorion ratio of mixture versus true target[0]",
205
+ ),
206
+ MetricDoc(
207
+ "Mixture Metrics",
208
+ "mxpd",
209
+ "Phase distance between mixture and true target[0]",
210
+ ),
211
+ MetricDoc(
212
+ "Mixture Metrics",
213
+ "mxstoi",
214
+ "Short term objective intelligibility of mixture versus true target[0]",
215
+ ),
216
+ MetricDoc(
217
+ "Mixture Metrics",
218
+ "mxcsig",
219
+ "Predicted rating of speech distortion of mixture versus true target[0]",
220
+ ),
221
+ MetricDoc(
222
+ "Mixture Metrics",
223
+ "mxcbak",
224
+ "Predicted rating of background distortion of mixture versus true target[0]",
225
+ ),
226
+ MetricDoc(
227
+ "Mixture Metrics",
228
+ "mxcovl",
229
+ "Predicted rating of overall quality of mixture versus true target[0]",
230
+ ),
231
+ MetricDoc("Mixture Metrics", "ssnr", "Segmental SNR"),
232
+ MetricDoc("Target Metrics", "tdco", "Target[0] DC offset"),
233
+ MetricDoc("Target Metrics", "tmin", "Target[0] min level"),
234
+ MetricDoc("Target Metrics", "tmax", "Target[0] max levl"),
235
+ MetricDoc("Target Metrics", "tpkdb", "Target[0] Pk lev dB"),
236
+ MetricDoc("Target Metrics", "tlrms", "Target[0] RMS lev dB"),
237
+ MetricDoc("Target Metrics", "tpkr", "Target[0] RMS Pk dB"),
238
+ MetricDoc("Target Metrics", "ttr", "Target[0] RMS Tr dB"),
239
+ MetricDoc("Target Metrics", "tcr", "Target[0] Crest factor"),
240
+ MetricDoc("Target Metrics", "tfl", "Target[0] Flat factor"),
241
+ MetricDoc("Target Metrics", "tpkc", "Target[0] Pk count"),
242
+ MetricDoc("Noise Metrics", "ndco", "Noise DC offset"),
243
+ MetricDoc("Noise Metrics", "nmin", "Noise min level"),
244
+ MetricDoc("Noise Metrics", "nmax", "Noise max levl"),
245
+ MetricDoc("Noise Metrics", "npkdb", "Noise Pk lev dB"),
246
+ MetricDoc("Noise Metrics", "nlrms", "Noise RMS lev dB"),
247
+ MetricDoc("Noise Metrics", "npkr", "Noise RMS Pk dB"),
248
+ MetricDoc("Noise Metrics", "ntr", "Noise RMS Tr dB"),
249
+ MetricDoc("Noise Metrics", "ncr", "Noise Crest factor"),
250
+ MetricDoc("Noise Metrics", "nfl", "Noise Flat factor"),
251
+ MetricDoc("Noise Metrics", "npkc", "Noise Pk count"),
252
+ MetricDoc(
253
+ "Truth Metrics",
254
+ "sedavg",
255
+ "(not implemented) Average SED activity over all frames [truth_parameters, 1]",
256
+ ),
257
+ MetricDoc(
258
+ "Truth Metrics",
259
+ "sedcnt",
260
+ "(not implemented) Count in number of frames that SED is active [truth_parameters, 1]",
261
+ ),
262
+ MetricDoc(
263
+ "Truth Metrics",
264
+ "sedtop3",
265
+ "(not implemented) 3 most active by largest sedavg [3, 1]",
266
+ ),
267
+ MetricDoc(
268
+ "Truth Metrics",
269
+ "sedtopn",
270
+ "(not implemented) N most active by largest sedavg [N, 1]",
271
+ ),
272
+ ]
273
+ )
216
274
  for name in self.asr_configs:
217
- metrics.append(MetricDoc('Target Metrics', f'tasr.{name}',
218
- f'Target[0] ASR text using {name} ASR as defined in mixdb asr_configs parameter'))
219
- metrics.append(MetricDoc('Mixture Metrics', f'mxasr.{name}',
220
- f'ASR text using {name} ASR as defined in mixdb asr_configs parameter'))
221
- metrics.append(MetricDoc('Target Metrics', f'basewer.{name}',
222
- f'Word error rate of tasr.{name} vs. speech text metadata for the target'))
223
- metrics.append(MetricDoc('Mixture Metrics', f'mxwer.{name}',
224
- f'Word error rate of mxasr.{name} vs. tasr.{name}'))
275
+ metrics.append(
276
+ MetricDoc(
277
+ "Target Metrics",
278
+ f"tasr.{name}",
279
+ f"Target[0] ASR text using {name} ASR as defined in mixdb asr_configs parameter",
280
+ )
281
+ )
282
+ metrics.append(
283
+ MetricDoc(
284
+ "Mixture Metrics",
285
+ f"mxasr.{name}",
286
+ f"ASR text using {name} ASR as defined in mixdb asr_configs parameter",
287
+ )
288
+ )
289
+ metrics.append(
290
+ MetricDoc(
291
+ "Target Metrics",
292
+ f"basewer.{name}",
293
+ f"Word error rate of tasr.{name} vs. speech text metadata for the target",
294
+ )
295
+ )
296
+ metrics.append(
297
+ MetricDoc(
298
+ "Mixture Metrics",
299
+ f"mxwer.{name}",
300
+ f"Word error rate of mxasr.{name} vs. tasr.{name}",
301
+ )
302
+ )
225
303
 
226
304
  return metrics
227
305
 
@@ -267,7 +345,7 @@ class MixtureDatabase:
267
345
  def transform_frame_ms(self) -> float:
268
346
  from .constants import SAMPLE_RATE
269
347
 
270
- return float(self.ft_config.R) / float(SAMPLE_RATE / 1000)
348
+ return float(self.ft_config.overlap) / float(SAMPLE_RATE / 1000)
271
349
 
272
350
  @cached_property
273
351
  def feature_ms(self) -> float:
@@ -275,7 +353,7 @@ class MixtureDatabase:
275
353
 
276
354
  @cached_property
277
355
  def feature_samples(self) -> int:
278
- return self.ft_config.R * self.fg_decimation * self.fg_stride
356
+ return self.ft_config.overlap * self.fg_decimation * self.fg_stride
279
357
 
280
358
  @cached_property
281
359
  def feature_step_ms(self) -> float:
@@ -283,28 +361,33 @@ class MixtureDatabase:
283
361
 
284
362
  @cached_property
285
363
  def feature_step_samples(self) -> int:
286
- return self.ft_config.R * self.fg_decimation * self.fg_step
364
+ return self.ft_config.overlap * self.fg_decimation * self.fg_step
287
365
 
288
- def total_samples(self, m_ids: GeneralizedIDs = '*') -> int:
289
- return sum([self.mixture(m_id).samples for m_id in self.mixids_to_list(m_ids)])
366
+ def total_samples(self, m_ids: GeneralizedIDs = "*") -> int:
367
+ samples = 0
368
+ for m_id in self.mixids_to_list(m_ids):
369
+ s = self.mixture(m_id).samples
370
+ if s is not None:
371
+ samples += s
372
+ return samples
290
373
 
291
- def total_transform_frames(self, m_ids: GeneralizedIDs = '*') -> int:
292
- return self.total_samples(m_ids) // self.ft_config.R
374
+ def total_transform_frames(self, m_ids: GeneralizedIDs = "*") -> int:
375
+ return self.total_samples(m_ids) // self.ft_config.overlap
293
376
 
294
- def total_feature_frames(self, m_ids: GeneralizedIDs = '*') -> int:
377
+ def total_feature_frames(self, m_ids: GeneralizedIDs = "*") -> int:
295
378
  return self.total_samples(m_ids) // self.feature_step_samples
296
379
 
297
380
  def mixture_transform_frames(self, m_id: int) -> int:
298
381
  from .helpers import frames_from_samples
299
382
 
300
- return frames_from_samples(self.mixture(m_id).samples, self.ft_config.R)
383
+ return frames_from_samples(self.mixture(m_id).samples, self.ft_config.overlap)
301
384
 
302
385
  def mixture_feature_frames(self, m_id: int) -> int:
303
386
  from .helpers import frames_from_samples
304
387
 
305
388
  return frames_from_samples(self.mixture(m_id).samples, self.feature_step_samples)
306
389
 
307
- def mixids_to_list(self, m_ids: Optional[GeneralizedIDs] = None) -> list[int]:
390
+ def mixids_to_list(self, m_ids: GeneralizedIDs = "*") -> list[int]:
308
391
  """Resolve generalized mixture IDs to a list of integers
309
392
 
310
393
  :param m_ids: Generalized mixture IDs
@@ -321,8 +404,10 @@ class MixtureDatabase:
321
404
  :return: Class labels
322
405
  """
323
406
  with self.db() as c:
324
- return [str(item[0]) for item in
325
- c.execute("SELECT class_label.label FROM class_label ORDER BY class_label.id").fetchall()]
407
+ return [
408
+ str(item[0])
409
+ for item in c.execute("SELECT class_label.label FROM class_label ORDER BY class_label.id").fetchall()
410
+ ]
326
411
 
327
412
  @cached_property
328
413
  def class_weights_thresholds(self) -> list[float]:
@@ -331,8 +416,37 @@ class MixtureDatabase:
331
416
  :return: Class weights thresholds
332
417
  """
333
418
  with self.db() as c:
334
- return [float(item[0]) for item in
335
- c.execute("SELECT class_weights_threshold.threshold FROM class_weights_threshold").fetchall()]
419
+ return [
420
+ float(item[0])
421
+ for item in c.execute(
422
+ "SELECT class_weights_threshold.threshold FROM class_weights_threshold"
423
+ ).fetchall()
424
+ ]
425
+
426
+ @cached_property
427
+ def truth_configs(self) -> TruthConfigs:
428
+ """Get truth configs from db
429
+
430
+ :return: Truth configs
431
+ """
432
+ import json
433
+
434
+ from .datatypes import TruthConfig
435
+
436
+ with self.db() as c:
437
+ truth_configs: TruthConfigs = {}
438
+ for truth_config_record in c.execute("SELECT truth_config.config FROM truth_config").fetchall():
439
+ truth_config = json.loads(truth_config_record[0])
440
+ if truth_config["name"] not in truth_configs:
441
+ truth_configs[truth_config["name"]] = TruthConfig(
442
+ function=truth_config["function"],
443
+ stride_reduction=truth_config["stride_reduction"],
444
+ config=truth_config["config"],
445
+ )
446
+ return truth_configs
447
+
448
+ def target_truth_configs(self, t_id: int) -> TruthConfigs:
449
+ return _target_truth_configs(self.db, t_id)
336
450
 
337
451
  @cached_property
338
452
  def random_snrs(self) -> list[float]:
@@ -341,8 +455,12 @@ class MixtureDatabase:
341
455
  :return: Random SNRs
342
456
  """
343
457
  with self.db() as c:
344
- return list(set([float(item[0]) for item in
345
- c.execute("SELECT mixture.snr FROM mixture WHERE mixture.random_snr == 1").fetchall()]))
458
+ return list(
459
+ {
460
+ float(item[0])
461
+ for item in c.execute("SELECT mixture.snr FROM mixture WHERE mixture.random_snr == 1").fetchall()
462
+ }
463
+ )
346
464
 
347
465
  @cached_property
348
466
  def snrs(self) -> list[float]:
@@ -351,13 +469,21 @@ class MixtureDatabase:
351
469
  :return: SNRs
352
470
  """
353
471
  with self.db() as c:
354
- return list(set([float(item[0]) for item in
355
- c.execute("SELECT mixture.snr FROM mixture WHERE mixture.random_snr == 0").fetchall()]))
472
+ return list(
473
+ {
474
+ float(item[0])
475
+ for item in c.execute("SELECT mixture.snr FROM mixture WHERE mixture.random_snr == 0").fetchall()
476
+ }
477
+ )
356
478
 
357
479
  @cached_property
358
480
  def all_snrs(self) -> list[UniversalSNR]:
359
- return sorted(list(set([UniversalSNR(is_random=False, value=snr) for snr in self.snrs] +
360
- [UniversalSNR(is_random=True, value=snr) for snr in self.random_snrs])))
481
+ return sorted(
482
+ set(
483
+ [UniversalSNR(is_random=False, value=snr) for snr in self.snrs]
484
+ + [UniversalSNR(is_random=True, value=snr) for snr in self.random_snrs]
485
+ )
486
+ )
361
487
 
362
488
  @cached_property
363
489
  def spectral_masks(self) -> SpectralMasks:
@@ -368,13 +494,19 @@ class MixtureDatabase:
368
494
  from .db_datatypes import SpectralMaskRecord
369
495
 
370
496
  with self.db() as c:
371
- spectral_masks = [SpectralMaskRecord(*result) for result in
372
- c.execute("SELECT * FROM spectral_mask").fetchall()]
373
- return [SpectralMask(f_max_width=spectral_mask.f_max_width,
374
- f_num=spectral_mask.f_num,
375
- t_max_width=spectral_mask.t_max_width,
376
- t_num=spectral_mask.t_num,
377
- t_max_percent=spectral_mask.t_max_percent) for spectral_mask in spectral_masks]
497
+ spectral_masks = [
498
+ SpectralMaskRecord(*result) for result in c.execute("SELECT * FROM spectral_mask").fetchall()
499
+ ]
500
+ return [
501
+ SpectralMask(
502
+ f_max_width=spectral_mask.f_max_width,
503
+ f_num=spectral_mask.f_num,
504
+ t_max_width=spectral_mask.t_max_width,
505
+ t_num=spectral_mask.t_num,
506
+ t_max_percent=spectral_mask.t_max_percent,
507
+ )
508
+ for spectral_mask in spectral_masks
509
+ ]
378
510
 
379
511
  def spectral_mask(self, sm_id: int) -> SpectralMask:
380
512
  """Get spectral mask with ID from db
@@ -392,31 +524,40 @@ class MixtureDatabase:
392
524
  """
393
525
  import json
394
526
 
395
- from .datatypes import TruthSetting
396
- from .datatypes import TruthSettings
527
+ from .datatypes import TruthConfig
528
+ from .datatypes import TruthConfigs
397
529
  from .db_datatypes import TargetFileRecord
398
530
 
399
531
  with self.db() as c:
400
532
  target_files: TargetFiles = []
401
- target_file_records = [TargetFileRecord(*result) for result in
402
- c.execute("SELECT * FROM target_file").fetchall()]
533
+ target_file_records = [
534
+ TargetFileRecord(*result) for result in c.execute("SELECT * FROM target_file").fetchall()
535
+ ]
403
536
  for target_file_record in target_file_records:
404
- truth_settings: TruthSettings = []
405
- for truth_setting_records in c.execute(
406
- "SELECT truth_setting.setting " +
407
- "FROM truth_setting, target_file_truth_setting " +
408
- "WHERE ? = target_file_truth_setting.target_file_id " +
409
- "AND truth_setting.id = target_file_truth_setting.truth_setting_id",
410
- (target_file_record.id,)).fetchall():
411
- truth_setting = json.loads(truth_setting_records[0])
412
- truth_settings.append(TruthSetting(config=truth_setting.get('config', None),
413
- function=truth_setting.get('function', None),
414
- index=truth_setting.get('index', None)))
415
- target_files.append(TargetFile(name=target_file_record.name,
416
- samples=target_file_record.samples,
417
- level_type=target_file_record.level_type,
418
- truth_settings=truth_settings,
419
- speaker_id=target_file_record.speaker_id))
537
+ truth_configs: TruthConfigs = {}
538
+ for truth_config_records in c.execute(
539
+ "SELECT truth_config.config "
540
+ + "FROM truth_config, target_file_truth_config "
541
+ + "WHERE ? = target_file_truth_config.target_file_id "
542
+ + "AND truth_config.id = target_file_truth_config.truth_config_id",
543
+ (target_file_record.id,),
544
+ ).fetchall():
545
+ truth_config = json.loads(truth_config_records[0])
546
+ truth_configs[truth_config["name"]] = TruthConfig(
547
+ function=truth_config["function"],
548
+ stride_reduction=truth_config["stride_reduction"],
549
+ config=truth_config["config"],
550
+ )
551
+ target_files.append(
552
+ TargetFile(
553
+ name=target_file_record.name,
554
+ samples=target_file_record.samples,
555
+ class_indices=json.loads(target_file_record.class_indices),
556
+ level_type=target_file_record.level_type,
557
+ truth_configs=truth_configs,
558
+ speaker_id=target_file_record.speaker_id,
559
+ )
560
+ )
420
561
  return target_files
421
562
 
422
563
  @cached_property
@@ -452,8 +593,10 @@ class MixtureDatabase:
452
593
  :return: Noise files
453
594
  """
454
595
  with self.db() as c:
455
- return [NoiseFile(name=noise[0], samples=noise[1]) for noise in
456
- c.execute("SELECT noise_file.name, samples FROM noise_file").fetchall()]
596
+ return [
597
+ NoiseFile(name=noise[0], samples=noise[1])
598
+ for noise in c.execute("SELECT noise_file.name, samples FROM noise_file").fetchall()
599
+ ]
457
600
 
458
601
  @cached_property
459
602
  def noise_file_ids(self) -> list[int]:
@@ -487,9 +630,21 @@ class MixtureDatabase:
487
630
 
488
631
  :return: Impulse response files
489
632
  """
633
+ import json
634
+
635
+ from .datatypes import ImpulseResponseFile
636
+
490
637
  with self.db() as c:
491
- return [str(impulse_response[0]) for impulse_response in
492
- c.execute("SELECT impulse_response_file.file FROM impulse_response_file").fetchall()]
638
+ # for impulse_response in c.execute(
639
+ # "SELECT impulse_response_file.* FROM impulse_response_file"
640
+ # ).fetchall():
641
+ # print(impulse_response)
642
+ return [
643
+ ImpulseResponseFile(impulse_response[1], json.loads(impulse_response[2]))
644
+ for impulse_response in c.execute(
645
+ "SELECT impulse_response_file.* FROM impulse_response_file"
646
+ ).fetchall()
647
+ ]
493
648
 
494
649
  @cached_property
495
650
  def impulse_response_file_ids(self) -> list[int]:
@@ -498,15 +653,19 @@ class MixtureDatabase:
498
653
  :return: List of impulse response file IDs
499
654
  """
500
655
  with self.db() as c:
501
- return [int(item[0]) for item in
502
- c.execute("SELECT impulse_response_file.id FROM impulse_response_file").fetchall()]
656
+ return [
657
+ int(item[0])
658
+ for item in c.execute("SELECT impulse_response_file.id FROM impulse_response_file").fetchall()
659
+ ]
503
660
 
504
- def impulse_response_file(self, ir_id: int) -> str:
661
+ def impulse_response_file(self, ir_id: int | None) -> str | None:
505
662
  """Get impulse response file with ID from db
506
663
 
507
664
  :param ir_id: Impulse response file ID
508
665
  :return: Noise
509
666
  """
667
+ if ir_id is None:
668
+ return None
510
669
  return _impulse_response_file(self.db, ir_id)
511
670
 
512
671
  @cached_property
@@ -524,18 +683,22 @@ class MixtureDatabase:
524
683
 
525
684
  :return: Mixtures
526
685
  """
527
- from .helpers import to_mixture
528
- from .helpers import to_target
529
686
  from .db_datatypes import MixtureRecord
530
687
  from .db_datatypes import TargetRecord
688
+ from .helpers import to_mixture
689
+ from .helpers import to_target
531
690
 
532
691
  with self.db() as c:
533
692
  mixtures: Mixtures = []
534
693
  for mixture in [MixtureRecord(*record) for record in c.execute("SELECT * FROM mixture").fetchall()]:
535
- targets = [to_target(TargetRecord(*target)) for target in c.execute(
536
- "SELECT target.* FROM target, mixture_target " +
537
- "WHERE ? = mixture_target.mixture_id AND target.id = mixture_target.target_id",
538
- (mixture.id,)).fetchall()]
694
+ targets = [
695
+ to_target(TargetRecord(*target))
696
+ for target in c.execute(
697
+ "SELECT target.* FROM target, mixture_target "
698
+ + "WHERE ? = mixture_target.mixture_id AND target.id = mixture_target.target_id",
699
+ (mixture.id,),
700
+ ).fetchall()
701
+ ]
539
702
  mixtures.append(to_mixture(mixture, targets))
540
703
  return mixtures
541
704
 
@@ -561,23 +724,15 @@ class MixtureDatabase:
561
724
  with self.db() as c:
562
725
  return int(c.execute("SELECT top.mixid_width FROM top").fetchone()[0])
563
726
 
564
- def location_filename(self, name: str) -> str:
565
- """Add the location to the given file name
727
+ def mixture_location(self, m_id: int) -> str:
728
+ """Get the file location for the give mixture ID
566
729
 
567
- :param name: File name
568
- :return: Location added
730
+ :param m_id: Zero-based mixture ID
731
+ :return: File location
569
732
  """
570
733
  from os.path import join
571
734
 
572
- return join(self.location, name)
573
-
574
- def mixture_filename(self, m_id: int) -> str:
575
- """Get the HDF5 file name for the give mixture ID
576
-
577
- :param m_id: Zero-based mixture ID
578
- :return: File name
579
- """
580
- return self.location_filename(self.mixture(m_id).name)
735
+ return join(self.location, self.mixture(m_id).name)
581
736
 
582
737
  @cached_property
583
738
  def num_mixtures(self) -> int:
@@ -595,9 +750,9 @@ class MixtureDatabase:
595
750
  :param items: String(s) of dataset(s) to retrieve
596
751
  :return: Data (or tuple of data)
597
752
  """
598
- from .helpers import read_mixture_data
753
+ from sonusai.mixture import read_cached_data
599
754
 
600
- return read_mixture_data(self.location_filename(self.mixture(m_id).name), items)
755
+ return read_cached_data(self.location, "mixture", self.mixture(m_id).name, items)
601
756
 
602
757
  def read_target_audio(self, t_id: int) -> AudioT:
603
758
  """Read target audio
@@ -624,10 +779,19 @@ class MixtureDatabase:
624
779
  audio = read_audio(noise.name)
625
780
  audio = apply_augmentation(audio, mixture.noise.augmentation)
626
781
  if mixture.noise.augmentation.ir is not None:
627
- audio = apply_impulse_response(audio, read_ir(self.impulse_response_file(mixture.noise.augmentation.ir)))
782
+ audio = apply_impulse_response(
783
+ audio,
784
+ read_ir(self.impulse_response_file(mixture.noise.augmentation.ir)),
785
+ )
628
786
 
629
787
  return audio
630
788
 
789
+ def mixture_class_indices(self, m_id: int) -> list[int]:
790
+ class_indices: list[int] = []
791
+ for t_id in self.mixture(m_id).target_ids:
792
+ class_indices.extend(self.target_file(t_id).class_indices)
793
+ return sorted(set(class_indices))
794
+
631
795
  def mixture_targets(self, m_id: int, force: bool = False) -> AudiosT:
632
796
  """Get the list of augmented target audio data (one per target in the mixup) for the given mixture ID
633
797
 
@@ -635,36 +799,34 @@ class MixtureDatabase:
635
799
  :param force: Force computing data from original sources regardless of whether cached data exists
636
800
  :return: List of augmented target audio data (one per target in the mixup)
637
801
  """
638
- from sonusai import SonusAIError
639
802
  from .augmentation import apply_augmentation
640
803
  from .augmentation import apply_gain
641
804
  from .augmentation import pad_audio_to_length
642
805
 
643
806
  if not force:
644
- targets_audio = self.read_mixture_data(m_id, 'targets')
807
+ targets_audio = self.read_mixture_data(m_id, "targets")
645
808
  if targets_audio is not None:
646
809
  return list(targets_audio)
647
810
 
648
811
  mixture = self.mixture(m_id)
649
812
  if mixture is None:
650
- raise SonusAIError(f'Could not find mixture for m_id: {m_id}')
813
+ raise ValueError(f"Could not find mixture for m_id: {m_id}")
651
814
 
652
815
  targets_audio = []
653
816
  for target in mixture.targets:
654
817
  target_audio = self.read_target_audio(target.file_id)
655
- target_audio = apply_augmentation(audio=target_audio,
656
- augmentation=target.augmentation,
657
- frame_length=self.feature_step_samples)
818
+ target_audio = apply_augmentation(
819
+ audio=target_audio,
820
+ augmentation=target.augmentation,
821
+ frame_length=self.feature_step_samples,
822
+ )
658
823
  target_audio = apply_gain(audio=target_audio, gain=mixture.target_snr_gain)
659
824
  target_audio = pad_audio_to_length(audio=target_audio, length=mixture.samples)
660
825
  targets_audio.append(target_audio)
661
826
 
662
827
  return targets_audio
663
828
 
664
- def mixture_targets_f(self,
665
- m_id: int,
666
- targets: Optional[AudiosT] = None,
667
- force: bool = False) -> AudiosF:
829
+ def mixture_targets_f(self, m_id: int, targets: AudiosT | None = None, force: bool = False) -> AudiosF:
668
830
  """Get the list of augmented target transform data (one per target in the mixup) for the given mixture ID
669
831
 
670
832
  :param m_id: Zero-based mixture ID
@@ -679,10 +841,7 @@ class MixtureDatabase:
679
841
 
680
842
  return [forward_transform(target, self.ft_config) for target in targets]
681
843
 
682
- def mixture_target(self,
683
- m_id: int,
684
- targets: Optional[AudiosT] = None,
685
- force: bool = False) -> AudioT:
844
+ def mixture_target(self, m_id: int, targets: AudiosT | None = None, force: bool = False) -> AudioT:
686
845
  """Get the augmented target audio data for the given mixture ID
687
846
 
688
847
  :param m_id: Zero-based mixture ID
@@ -693,7 +852,7 @@ class MixtureDatabase:
693
852
  from .helpers import get_target
694
853
 
695
854
  if not force:
696
- target = self.read_mixture_data(m_id, 'target')
855
+ target = self.read_mixture_data(m_id, "target")
697
856
  if target is not None:
698
857
  return target
699
858
 
@@ -702,11 +861,13 @@ class MixtureDatabase:
702
861
 
703
862
  return get_target(self, self.mixture(m_id), targets)
704
863
 
705
- def mixture_target_f(self,
706
- m_id: int,
707
- targets: Optional[AudiosT] = None,
708
- target: Optional[AudioT] = None,
709
- force: bool = False) -> AudioF:
864
+ def mixture_target_f(
865
+ self,
866
+ m_id: int,
867
+ targets: AudiosT | None = None,
868
+ target: AudioT | None = None,
869
+ force: bool = False,
870
+ ) -> AudioF:
710
871
  """Get the augmented target transform data for the given mixture ID
711
872
 
712
873
  :param m_id: Zero-based mixture ID
@@ -722,9 +883,7 @@ class MixtureDatabase:
722
883
 
723
884
  return forward_transform(target, self.ft_config)
724
885
 
725
- def mixture_noise(self,
726
- m_id: int,
727
- force: bool = False) -> AudioT:
886
+ def mixture_noise(self, m_id: int, force: bool = False) -> AudioT:
728
887
  """Get the augmented noise audio data for the given mixture ID
729
888
 
730
889
  :param m_id: Zero-based mixture ID
@@ -735,7 +894,7 @@ class MixtureDatabase:
735
894
  from .augmentation import apply_gain
736
895
 
737
896
  if not force:
738
- noise = self.read_mixture_data(m_id, 'noise')
897
+ noise = self.read_mixture_data(m_id, "noise")
739
898
  if noise is not None:
740
899
  return noise
741
900
 
@@ -744,10 +903,7 @@ class MixtureDatabase:
744
903
  noise = get_next_noise(audio=noise, offset=mixture.noise.offset, length=mixture.samples)
745
904
  return apply_gain(audio=noise, gain=mixture.noise_snr_gain)
746
905
 
747
- def mixture_noise_f(self,
748
- m_id: int,
749
- noise: Optional[AudioT] = None,
750
- force: bool = False) -> AudioF:
906
+ def mixture_noise_f(self, m_id: int, noise: AudioT | None = None, force: bool = False) -> AudioF:
751
907
  """Get the augmented noise transform for the given mixture ID
752
908
 
753
909
  :param m_id: Zero-based mixture ID
@@ -762,12 +918,14 @@ class MixtureDatabase:
762
918
 
763
919
  return forward_transform(noise, self.ft_config)
764
920
 
765
- def mixture_mixture(self,
766
- m_id: int,
767
- targets: Optional[AudiosT] = None,
768
- target: Optional[AudioT] = None,
769
- noise: Optional[AudioT] = None,
770
- force: bool = False) -> AudioT:
921
+ def mixture_mixture(
922
+ self,
923
+ m_id: int,
924
+ targets: AudiosT | None = None,
925
+ target: AudioT | None = None,
926
+ noise: AudioT | None = None,
927
+ force: bool = False,
928
+ ) -> AudioT:
771
929
  """Get the mixture audio data for the given mixture ID
772
930
 
773
931
  :param m_id: Zero-based mixture ID
@@ -778,7 +936,7 @@ class MixtureDatabase:
778
936
  :return: Mixture audio data
779
937
  """
780
938
  if not force:
781
- mixture = self.read_mixture_data(m_id, 'mixture')
939
+ mixture = self.read_mixture_data(m_id, "mixture")
782
940
  if mixture is not None:
783
941
  return mixture
784
942
 
@@ -790,13 +948,15 @@ class MixtureDatabase:
790
948
 
791
949
  return target + noise
792
950
 
793
- def mixture_mixture_f(self,
794
- m_id: int,
795
- targets: Optional[AudiosT] = None,
796
- target: Optional[AudioT] = None,
797
- noise: Optional[AudioT] = None,
798
- mixture: Optional[AudioT] = None,
799
- force: bool = False) -> AudioF:
951
+ def mixture_mixture_f(
952
+ self,
953
+ m_id: int,
954
+ targets: AudiosT | None = None,
955
+ target: AudioT | None = None,
956
+ noise: AudioT | None = None,
957
+ mixture: AudioT | None = None,
958
+ force: bool = False,
959
+ ) -> AudioF:
800
960
  """Get the mixture transform for the given mixture ID
801
961
 
802
962
  :param m_id: Zero-based mixture ID
@@ -817,18 +977,22 @@ class MixtureDatabase:
817
977
 
818
978
  m = self.mixture(m_id)
819
979
  if m.spectral_mask_id is not None:
820
- mixture_f = apply_spectral_mask(audio_f=mixture_f,
821
- spectral_mask=self.spectral_mask(int(m.spectral_mask_id)),
822
- seed=m.spectral_mask_seed)
980
+ mixture_f = apply_spectral_mask(
981
+ audio_f=mixture_f,
982
+ spectral_mask=self.spectral_mask(int(m.spectral_mask_id)),
983
+ seed=m.spectral_mask_seed,
984
+ )
823
985
 
824
986
  return mixture_f
825
987
 
826
- def mixture_truth_t(self,
827
- m_id: int,
828
- targets: Optional[AudiosT] = None,
829
- noise: Optional[AudioT] = None,
830
- mixture: Optional[AudioT] = None,
831
- force: bool = False) -> Truth:
988
+ def mixture_truth_t(
989
+ self,
990
+ m_id: int,
991
+ targets: AudiosT | None = None,
992
+ noise: AudioT | None = None,
993
+ mixture: AudioT | None = None,
994
+ force: bool = False,
995
+ ) -> TruthDict:
832
996
  """Get the truth_t data for the given mixture ID
833
997
 
834
998
  :param m_id: Zero-based mixture ID
@@ -838,10 +1002,10 @@ class MixtureDatabase:
838
1002
  :param force: Force computing data from original sources regardless of whether cached data exists
839
1003
  :return: truth_t data
840
1004
  """
841
- from .helpers import get_truth_t
1005
+ from .helpers import get_truth
842
1006
 
843
1007
  if not force:
844
- truth_t = self.read_mixture_data(m_id, 'truth_t')
1008
+ truth_t = self.read_mixture_data(m_id, "truth_t")
845
1009
  if truth_t is not None:
846
1010
  return truth_t
847
1011
 
@@ -852,19 +1016,18 @@ class MixtureDatabase:
852
1016
  noise = self.mixture_noise(m_id, force)
853
1017
 
854
1018
  if force or mixture is None:
855
- noise = self.mixture_mixture(m_id,
856
- targets=targets,
857
- noise=noise,
858
- force=force)
859
-
860
- return get_truth_t(self, self.mixture(m_id), targets, noise, mixture)
861
-
862
- def mixture_segsnr_t(self,
863
- m_id: int,
864
- targets: Optional[AudiosT] = None,
865
- target: Optional[AudioT] = None,
866
- noise: Optional[AudioT] = None,
867
- force: bool = False) -> Segsnr:
1019
+ mixture = self.mixture_mixture(m_id, targets=targets, noise=noise, force=force)
1020
+
1021
+ return get_truth(self, self.mixture(m_id), targets, noise, mixture)
1022
+
1023
+ def mixture_segsnr_t(
1024
+ self,
1025
+ m_id: int,
1026
+ targets: AudiosT | None = None,
1027
+ target: AudioT | None = None,
1028
+ noise: AudioT | None = None,
1029
+ force: bool = False,
1030
+ ) -> Segsnr:
868
1031
  """Get the segsnr_t data for the given mixture ID
869
1032
 
870
1033
  :param m_id: Zero-based mixture ID
@@ -877,7 +1040,7 @@ class MixtureDatabase:
877
1040
  from .helpers import get_segsnr_t
878
1041
 
879
1042
  if not force:
880
- segsnr_t = self.read_mixture_data(m_id, 'segsnr_t')
1043
+ segsnr_t = self.read_mixture_data(m_id, "segsnr_t")
881
1044
  if segsnr_t is not None:
882
1045
  return segsnr_t
883
1046
 
@@ -889,13 +1052,15 @@ class MixtureDatabase:
889
1052
 
890
1053
  return get_segsnr_t(self, self.mixture(m_id), target, noise)
891
1054
 
892
- def mixture_segsnr(self,
893
- m_id: int,
894
- segsnr_t: Optional[Segsnr] = None,
895
- targets: Optional[AudiosT] = None,
896
- target: Optional[AudioT] = None,
897
- noise: Optional[AudioT] = None,
898
- force: bool = False) -> Segsnr:
1055
+ def mixture_segsnr(
1056
+ self,
1057
+ m_id: int,
1058
+ segsnr_t: Segsnr | None = None,
1059
+ targets: AudiosT | None = None,
1060
+ target: AudioT | None = None,
1061
+ noise: AudioT | None = None,
1062
+ force: bool = False,
1063
+ ) -> Segsnr:
899
1064
  """Get the segsnr data for the given mixture ID
900
1065
 
901
1066
  :param m_id: Zero-based mixture ID
@@ -907,28 +1072,30 @@ class MixtureDatabase:
907
1072
  :return: segsnr data
908
1073
  """
909
1074
  if not force:
910
- segsnr = self.read_mixture_data(m_id, 'segsnr')
1075
+ segsnr = self.read_mixture_data(m_id, "segsnr")
911
1076
  if segsnr is not None:
912
1077
  return segsnr
913
1078
 
914
- segsnr_t = self.read_mixture_data(m_id, 'segsnr_t')
1079
+ segsnr_t = self.read_mixture_data(m_id, "segsnr_t")
915
1080
  if segsnr_t is not None:
916
- return segsnr_t[0::self.ft_config.R]
1081
+ return segsnr_t[0 :: self.ft_config.overlap]
917
1082
 
918
1083
  if force or segsnr_t is None:
919
1084
  segsnr_t = self.mixture_segsnr_t(m_id, targets, target, noise, force)
920
1085
 
921
- return segsnr_t[0::self.ft_config.R]
922
-
923
- def mixture_ft(self,
924
- m_id: int,
925
- targets: Optional[AudiosT] = None,
926
- target: Optional[AudioT] = None,
927
- noise: Optional[AudioT] = None,
928
- mixture_f: Optional[AudioF] = None,
929
- mixture: Optional[AudioT] = None,
930
- truth_t: Optional[Truth] = None,
931
- force: bool = False) -> tuple[Feature, Truth]:
1086
+ return segsnr_t[0 :: self.ft_config.overlap]
1087
+
1088
+ def mixture_ft(
1089
+ self,
1090
+ m_id: int,
1091
+ targets: AudiosT | None = None,
1092
+ target: AudioT | None = None,
1093
+ noise: AudioT | None = None,
1094
+ mixture_f: AudioF | None = None,
1095
+ mixture: AudioT | None = None,
1096
+ truth_t: TruthDict | None = None,
1097
+ force: bool = False,
1098
+ ) -> tuple[Feature, TruthDict]:
932
1099
  """Get the feature and truth_f data for the given mixture ID
933
1100
 
934
1101
  :param m_id: Zero-based mixture ID
@@ -941,63 +1108,45 @@ class MixtureDatabase:
941
1108
  :param force: Force computing data from original sources regardless of whether cached data exists
942
1109
  :return: Tuple of (feature, truth_f) data
943
1110
  """
944
- from dataclasses import asdict
945
-
946
- import numpy as np
947
1111
  from pyaaware import FeatureGenerator
948
1112
 
949
- from .truth import truth_reduction
1113
+ from .truth import truth_stride_reduction
950
1114
 
951
1115
  if not force:
952
- feature, truth_f = self.read_mixture_data(m_id, ['feature', 'truth_f'])
1116
+ feature, truth_f = self.read_mixture_data(m_id, ["feature", "truth_f"])
953
1117
  if feature is not None and truth_f is not None:
954
1118
  return feature, truth_f
955
1119
 
956
1120
  if force or mixture_f is None:
957
- mixture_f = self.mixture_mixture_f(m_id=m_id,
958
- targets=targets,
959
- target=target,
960
- noise=noise,
961
- mixture=mixture,
962
- force=force)
1121
+ mixture_f = self.mixture_mixture_f(
1122
+ m_id=m_id,
1123
+ targets=targets,
1124
+ target=target,
1125
+ noise=noise,
1126
+ mixture=mixture,
1127
+ force=force,
1128
+ )
963
1129
 
964
1130
  if force or truth_t is None:
965
1131
  truth_t = self.mixture_truth_t(m_id=m_id, targets=targets, noise=noise, force=force)
966
1132
 
967
- m = self.mixture(m_id)
968
- transform_frames = self.mixture_transform_frames(m_id)
969
- feature_frames = self.mixture_feature_frames(m_id)
970
-
971
- if truth_t is None:
972
- truth_t = np.zeros((m.samples, self.num_classes), dtype=np.float32)
973
-
974
- feature = np.empty((feature_frames, self.fg_stride, self.feature_parameters), dtype=np.float32)
975
- truth_f = np.empty((feature_frames, self.num_classes), dtype=np.complex64)
976
-
977
- fg = FeatureGenerator(**asdict(self.fg_config))
978
- feature_frame = 0
979
- for transform_frame in range(transform_frames):
980
- indices = slice(transform_frame * self.ft_config.R, (transform_frame + 1) * self.ft_config.R)
981
- fg.execute(mixture_f[transform_frame],
982
- truth_reduction(truth_t[indices], self.truth_reduction_function))
1133
+ fg = FeatureGenerator(self.fg_config.feature_mode, self.fg_config.truth_parameters)
983
1134
 
984
- if fg.eof():
985
- feature[feature_frame] = fg.feature()
986
- truth_f[feature_frame] = fg.truth()
987
- feature_frame += 1
988
-
989
- if np.isreal(truth_f).all():
990
- return feature, truth_f.real
1135
+ feature, truth_f = fg.execute_all(mixture_f, truth_t)
1136
+ for key in self.truth_configs:
1137
+ truth_f[key] = truth_stride_reduction(truth_f[key], self.truth_configs[key].stride_reduction)
991
1138
 
992
1139
  return feature, truth_f
993
1140
 
994
- def mixture_feature(self,
995
- m_id: int,
996
- targets: Optional[AudiosT] = None,
997
- noise: Optional[AudioT] = None,
998
- mixture: Optional[AudioT] = None,
999
- truth_t: Optional[Truth] = None,
1000
- force: bool = False) -> Feature:
1141
+ def mixture_feature(
1142
+ self,
1143
+ m_id: int,
1144
+ targets: AudiosT | None = None,
1145
+ noise: AudioT | None = None,
1146
+ mixture: AudioT | None = None,
1147
+ truth_t: TruthDict | None = None,
1148
+ force: bool = False,
1149
+ ) -> Feature:
1001
1150
  """Get the feature data for the given mixture ID
1002
1151
 
1003
1152
  :param m_id: Zero-based mixture ID
@@ -1008,21 +1157,25 @@ class MixtureDatabase:
1008
1157
  :param force: Force computing data from original sources regardless of whether cached data exists
1009
1158
  :return: Feature data
1010
1159
  """
1011
- feature, _ = self.mixture_ft(m_id=m_id,
1012
- targets=targets,
1013
- noise=noise,
1014
- mixture=mixture,
1015
- truth_t=truth_t,
1016
- force=force)
1160
+ feature, _ = self.mixture_ft(
1161
+ m_id=m_id,
1162
+ targets=targets,
1163
+ noise=noise,
1164
+ mixture=mixture,
1165
+ truth_t=truth_t,
1166
+ force=force,
1167
+ )
1017
1168
  return feature
1018
1169
 
1019
- def mixture_truth_f(self,
1020
- m_id: int,
1021
- targets: Optional[AudiosT] = None,
1022
- noise: Optional[AudioT] = None,
1023
- mixture: Optional[AudioT] = None,
1024
- truth_t: Optional[Truth] = None,
1025
- force: bool = False) -> Truth:
1170
+ def mixture_truth_f(
1171
+ self,
1172
+ m_id: int,
1173
+ targets: AudiosT | None = None,
1174
+ noise: AudioT | None = None,
1175
+ mixture: AudioT | None = None,
1176
+ truth_t: TruthDict | None = None,
1177
+ force: bool = False,
1178
+ ) -> TruthDict:
1026
1179
  """Get the truth_f data for the given mixture ID
1027
1180
 
1028
1181
  :param m_id: Zero-based mixture ID
@@ -1033,20 +1186,24 @@ class MixtureDatabase:
1033
1186
  :param force: Force computing data from original sources regardless of whether cached data exists
1034
1187
  :return: truth_f data
1035
1188
  """
1036
- _, truth_f = self.mixture_ft(m_id=m_id,
1037
- targets=targets,
1038
- noise=noise,
1039
- mixture=mixture,
1040
- truth_t=truth_t,
1041
- force=force)
1189
+ _, truth_f = self.mixture_ft(
1190
+ m_id=m_id,
1191
+ targets=targets,
1192
+ noise=noise,
1193
+ mixture=mixture,
1194
+ truth_t=truth_t,
1195
+ force=force,
1196
+ )
1042
1197
  return truth_f
1043
1198
 
1044
- def mixture_class_count(self,
1045
- m_id: int,
1046
- targets: Optional[AudiosT] = None,
1047
- noise: Optional[AudioT] = None,
1048
- truth_t: Optional[Truth] = None) -> ClassCount:
1049
- """Compute the number of samples for which each truth index is active for the given mixture ID
1199
+ def mixture_class_count(
1200
+ self,
1201
+ m_id: int,
1202
+ targets: AudiosT | None = None,
1203
+ noise: AudioT | None = None,
1204
+ truth_t: TruthDict | None = None,
1205
+ ) -> ClassCount:
1206
+ """Compute the number of frames for which each class index is active for the given mixture ID
1050
1207
 
1051
1208
  :param m_id: Zero-based mixture ID
1052
1209
  :param targets: List of augmented target audio (one per target in the mixup)
@@ -1061,10 +1218,9 @@ class MixtureDatabase:
1061
1218
 
1062
1219
  class_count = [0] * self.num_classes
1063
1220
  num_classes = self.num_classes
1064
- if self.truth_mutex:
1065
- num_classes -= 1
1066
- for cl in range(num_classes):
1067
- class_count[cl] = int(np.sum(truth_t[:, cl] >= self.class_weights_thresholds[cl]))
1221
+ if "sed" in self.truth_configs:
1222
+ for cl in range(num_classes):
1223
+ class_count[cl] = int(np.sum(truth_t["sed"][:, cl] >= self.class_weights_thresholds[cl]))
1068
1224
 
1069
1225
  return class_count
1070
1226
 
@@ -1086,7 +1242,7 @@ class MixtureDatabase:
1086
1242
  def speech_metadata_tiers(self) -> list[str]:
1087
1243
  return sorted(set(self.speaker_metadata_tiers + self.textgrid_metadata_tiers))
1088
1244
 
1089
- def speaker(self, s_id: int | None, tier: str) -> Optional[str]:
1245
+ def speaker(self, s_id: int | None, tier: str) -> str | None:
1090
1246
  return _speaker(self.db, s_id, tier)
1091
1247
 
1092
1248
  def speech_metadata(self, tier: str) -> list[str]:
@@ -1126,9 +1282,13 @@ class MixtureDatabase:
1126
1282
  entries = []
1127
1283
  for entry in data:
1128
1284
  if target.augmentation.tempo is not None:
1129
- entries.append(Interval(entry.start / target.augmentation.tempo,
1130
- entry.end / target.augmentation.tempo,
1131
- entry.label))
1285
+ entries.append(
1286
+ Interval(
1287
+ entry.start / target.augmentation.tempo,
1288
+ entry.end / target.augmentation.tempo,
1289
+ entry.label,
1290
+ )
1291
+ )
1132
1292
  else:
1133
1293
  entries.append(entry)
1134
1294
  results.append(entries)
@@ -1138,12 +1298,9 @@ class MixtureDatabase:
1138
1298
  for target in self.mixture(mixid).targets:
1139
1299
  results.append(self.speaker(self.target_file(target.file_id).speaker_id, tier))
1140
1300
 
1141
- return sorted(results)
1301
+ return results
1142
1302
 
1143
- def mixids_for_speech_metadata(self,
1144
- tier: str,
1145
- value: str = None,
1146
- where: str = None) -> list[int]:
1303
+ def mixids_for_speech_metadata(self, tier: str, value: str | None = None, where: str | None = None) -> list[int]:
1147
1304
  """Get a list of mixture IDs for the given speech metadata tier.
1148
1305
 
1149
1306
  If 'where' is None, then include mixture IDs whose tier values are equal to the given 'value'.
@@ -1162,25 +1319,27 @@ class MixtureDatabase:
1162
1319
  >>> mixids = mixdb.mixids_for_speech_metadata('dialect', where="dialect in ('New York City', 'Northern')")
1163
1320
  Get mixture IDs for mixtures with speakers whose dialects are either 'New York City' or 'Northern'.
1164
1321
  """
1165
- from sonusai import SonusAIError
1166
-
1167
1322
  if value is None and where is None:
1168
- raise SonusAIError('Must provide either value or where')
1323
+ raise ValueError("Must provide either value or where")
1169
1324
 
1170
1325
  if where is None:
1171
1326
  where = f"{tier} = '{value}'"
1172
1327
 
1173
1328
  if tier in self.textgrid_metadata_tiers:
1174
- raise SonusAIError(f'TextGrid tier data, "{tier}", is not supported in mixids_for_speech_metadata().')
1329
+ raise ValueError(f"TextGrid tier data, '{tier}', is not supported in mixids_for_speech_metadata().")
1175
1330
 
1176
1331
  with self.db() as c:
1177
- speaker_ids = [speaker_id[0] for speaker_id in
1178
- c.execute(f"SELECT id FROM speaker WHERE {where}").fetchall()]
1179
- results = c.execute(f"SELECT id FROM target_file " +
1180
- f"WHERE speaker_id IN ({','.join(map(str, speaker_ids))})").fetchall()
1332
+ speaker_ids = [
1333
+ speaker_id[0] for speaker_id in c.execute(f"SELECT id FROM speaker WHERE {where}").fetchall()
1334
+ ]
1335
+ results = c.execute(
1336
+ "SELECT id FROM target_file " + f"WHERE speaker_id IN ({','.join(map(str, speaker_ids))})"
1337
+ ).fetchall()
1181
1338
  target_file_ids = [target_file_id[0] for target_file_id in results]
1182
- results = c.execute("SELECT mixture_id FROM mixture_target " +
1183
- f"WHERE mixture_target.target_id IN ({','.join(map(str, target_file_ids))})").fetchall()
1339
+ results = c.execute(
1340
+ "SELECT mixture_id FROM mixture_target "
1341
+ + f"WHERE mixture_target.target_id IN ({','.join(map(str, target_file_ids))})"
1342
+ ).fetchall()
1184
1343
 
1185
1344
  return [mixture_id[0] - 1 for mixture_id in results]
1186
1345
 
@@ -1189,9 +1348,9 @@ class MixtureDatabase:
1189
1348
 
1190
1349
  return mixture_all_speech_metadata(self, self.mixture(m_id))
1191
1350
 
1192
- def mixture_metrics(self, m_id: int,
1193
- metrics: list[str],
1194
- force: bool = False) -> list[float | int | str | Segsnr]:
1351
+ def mixture_metrics(
1352
+ self, m_id: int, metrics: list[str], force: bool = False
1353
+ ) -> list[float | int | str | Segsnr | None]:
1195
1354
  """Get metrics data for the given mixture ID
1196
1355
 
1197
1356
  :param m_id: Zero-based mixture ID
@@ -1199,12 +1358,11 @@ class MixtureDatabase:
1199
1358
  :param force: Force computing data from original sources regardless of whether cached data exists
1200
1359
  :return: List of metric data
1201
1360
  """
1202
- from typing import Callable
1361
+ from collections.abc import Callable
1203
1362
 
1204
1363
  import numpy as np
1205
1364
  from pystoi import stoi
1206
1365
 
1207
- from sonusai import SonusAIError
1208
1366
  from sonusai.metrics import calc_audio_stats
1209
1367
  from sonusai.metrics import calc_phase_distance
1210
1368
  from sonusai.metrics import calc_segsnr_f
@@ -1314,7 +1472,7 @@ class MixtureDatabase:
1314
1472
  def get() -> AudioStatsMetrics:
1315
1473
  nonlocal state
1316
1474
  if state is None:
1317
- state = calc_audio_stats(target_audio(), self.fg_info.ft_config.N / SAMPLE_RATE)
1475
+ state = calc_audio_stats(target_audio(), self.fg_info.ft_config.length / SAMPLE_RATE)
1318
1476
  return state
1319
1477
 
1320
1478
  return get
@@ -1327,7 +1485,7 @@ class MixtureDatabase:
1327
1485
  def get() -> AudioStatsMetrics:
1328
1486
  nonlocal state
1329
1487
  if state is None:
1330
- state = calc_audio_stats(noise_audio(), self.fg_info.ft_config.N / SAMPLE_RATE)
1488
+ state = calc_audio_stats(noise_audio(), self.fg_info.ft_config.length / SAMPLE_RATE)
1331
1489
  return state
1332
1490
 
1333
1491
  return get
@@ -1340,9 +1498,10 @@ class MixtureDatabase:
1340
1498
  def get(asr_name) -> dict:
1341
1499
  nonlocal state
1342
1500
  if asr_name not in state:
1343
- state[asr_name] = self.asr_configs.get(asr_name, None)
1344
- if state[asr_name] is None:
1345
- raise SonusAIError(f"Unrecognized ASR name: '{asr_name}'")
1501
+ value = self.asr_configs.get(asr_name, None)
1502
+ if value is None:
1503
+ raise ValueError(f"Unrecognized ASR name: '{asr_name}'")
1504
+ state[asr_name] = value
1346
1505
  return state[asr_name]
1347
1506
 
1348
1507
  return get
@@ -1376,15 +1535,14 @@ class MixtureDatabase:
1376
1535
  mixture_asr = create_mixture_asr()
1377
1536
 
1378
1537
  def get_asr_name(m: str) -> str:
1379
- parts = m.split('.')
1538
+ parts = m.split(".")
1380
1539
  if len(parts) != 2:
1381
- raise SonusAIError(
1382
- f"Unrecognized format: '{m}'; must be of the form: '<metric>.<name>'")
1540
+ raise ValueError(f"Unrecognized format: '{m}'; must be of the form: '<metric>.<name>'")
1383
1541
  asr_name = parts[1]
1384
1542
  return asr_name
1385
1543
 
1386
- def calc(m: str) -> float | int | str | Segsnr:
1387
- if m == 'mxsnr':
1544
+ def calc(m: str) -> float | int | str | Segsnr | None:
1545
+ if m == "mxsnr":
1388
1546
  return self.mixture(m_id).snr
1389
1547
 
1390
1548
  # Get cached data first, if exists
@@ -1394,12 +1552,12 @@ class MixtureDatabase:
1394
1552
  return value
1395
1553
 
1396
1554
  # Otherwise, generate data as needed
1397
- if m.startswith('mxwer'):
1555
+ if m.startswith("mxwer"):
1398
1556
  asr_name = get_asr_name(m)
1399
1557
 
1400
1558
  if self.mixture(m_id).snr < -96:
1401
1559
  # noise only, ignore/reset target asr
1402
- return float('nan')
1560
+ return float("nan")
1403
1561
 
1404
1562
  if target_asr(asr_name):
1405
1563
  return calc_wer(mixture_asr(asr_name), target_asr(asr_name)).wer * 100
@@ -1407,159 +1565,166 @@ class MixtureDatabase:
1407
1565
  # TODO: should this be NaN like above?
1408
1566
  return float(0)
1409
1567
 
1410
- if m.startswith('basewer'):
1568
+ if m.startswith("basewer"):
1411
1569
  asr_name = get_asr_name(m)
1412
1570
 
1413
- text = self.mixture_speech_metadata(m_id, 'text')[0]
1571
+ text = self.mixture_speech_metadata(m_id, "text")[0]
1414
1572
  if text is not None:
1415
1573
  return calc_wer(target_asr(asr_name), text).wer * 100
1416
1574
 
1417
1575
  # TODO: should this be NaN like above?
1418
1576
  return float(0)
1419
1577
 
1420
- if m.startswith('mxasr'):
1578
+ if m.startswith("mxasr"):
1421
1579
  return mixture_asr(get_asr_name(m))
1422
1580
 
1423
- if m == 'mxssnr_avg':
1581
+ if m == "mxssnr_avg":
1424
1582
  return calc_segsnr_f(segsnr_f()).avg
1425
1583
 
1426
- if m == 'mxssnr_std':
1584
+ if m == "mxssnr_std":
1427
1585
  return calc_segsnr_f(segsnr_f()).std
1428
1586
 
1429
- if m == 'mxssnrdb_avg':
1587
+ if m == "mxssnrdb_avg":
1430
1588
  return calc_segsnr_f(segsnr_f()).db_avg
1431
1589
 
1432
- if m == 'mxssnrdb_std':
1590
+ if m == "mxssnrdb_std":
1433
1591
  return calc_segsnr_f(segsnr_f()).db_std
1434
1592
 
1435
- if m == 'mxssnrf_avg':
1593
+ if m == "mxssnrf_avg":
1436
1594
  return calc_segsnr_f_bin(target_f(), noise_f()).avg
1437
1595
 
1438
- if m == 'mxssnrf_std':
1596
+ if m == "mxssnrf_std":
1439
1597
  return calc_segsnr_f_bin(target_f(), noise_f()).std
1440
1598
 
1441
- if m == 'mxssnrdbf_avg':
1599
+ if m == "mxssnrdbf_avg":
1442
1600
  return calc_segsnr_f_bin(target_f(), noise_f()).db_avg
1443
1601
 
1444
- if m == 'mxssnrdbf_std':
1602
+ if m == "mxssnrdbf_std":
1445
1603
  return calc_segsnr_f_bin(target_f(), noise_f()).db_std
1446
1604
 
1447
- if m == 'mxpesq':
1605
+ if m == "mxpesq":
1448
1606
  if self.mixture(m_id).snr < -96:
1449
1607
  return 0
1450
1608
  return speech().pesq
1451
1609
 
1452
- if m == 'mxcsig':
1610
+ if m == "mxcsig":
1453
1611
  if self.mixture(m_id).snr < -96:
1454
1612
  return 0
1455
1613
  return speech().csig
1456
1614
 
1457
- if m == 'mxcbak':
1615
+ if m == "mxcbak":
1458
1616
  if self.mixture(m_id).snr < -96:
1459
1617
  return 0
1460
1618
  return speech().cbak
1461
1619
 
1462
- if m == 'mxcovl':
1620
+ if m == "mxcovl":
1463
1621
  if self.mixture(m_id).snr < -96:
1464
1622
  return 0
1465
1623
  return speech().covl
1466
1624
 
1467
- if m == 'mxwsdr':
1625
+ if m == "mxwsdr":
1468
1626
  mixture = mixture_audio()[:, np.newaxis]
1469
1627
  target = target_audio()[:, np.newaxis]
1470
1628
  noise = noise_audio()[:, np.newaxis]
1471
- return calc_wsdr(hypothesis=np.concatenate((mixture, noise), axis=1),
1472
- reference=np.concatenate((target, noise), axis=1),
1473
- with_log=True)[0]
1629
+ return calc_wsdr(
1630
+ hypothesis=np.concatenate((mixture, noise), axis=1),
1631
+ reference=np.concatenate((target, noise), axis=1),
1632
+ with_log=True,
1633
+ )[0]
1474
1634
 
1475
- if m == 'mxpd':
1635
+ if m == "mxpd":
1476
1636
  mixture_f = self.mixture_mixture_f(m_id)
1477
1637
  return calc_phase_distance(hypothesis=mixture_f, reference=target_f())[0]
1478
1638
 
1479
- if m == 'mxstoi':
1480
- return stoi(x=target_audio(), y=mixture_audio(), fs_sig=SAMPLE_RATE, extended=False)
1639
+ if m == "mxstoi":
1640
+ return stoi(
1641
+ x=target_audio(),
1642
+ y=mixture_audio(),
1643
+ fs_sig=SAMPLE_RATE,
1644
+ extended=False,
1645
+ )
1481
1646
 
1482
- if m == 'tdco':
1647
+ if m == "tdco":
1483
1648
  return target_stats().dco
1484
1649
 
1485
- if m == 'tmin':
1650
+ if m == "tmin":
1486
1651
  return target_stats().min
1487
1652
 
1488
- if m == 'tmax':
1653
+ if m == "tmax":
1489
1654
  return target_stats().max
1490
1655
 
1491
- if m == 'tpkdb':
1656
+ if m == "tpkdb":
1492
1657
  return target_stats().pkdb
1493
1658
 
1494
- if m == 'tlrms':
1659
+ if m == "tlrms":
1495
1660
  return target_stats().lrms
1496
1661
 
1497
- if m == 'tpkr':
1662
+ if m == "tpkr":
1498
1663
  return target_stats().pkr
1499
1664
 
1500
- if m == 'ttr':
1665
+ if m == "ttr":
1501
1666
  return target_stats().tr
1502
1667
 
1503
- if m == 'tcr':
1668
+ if m == "tcr":
1504
1669
  return target_stats().cr
1505
1670
 
1506
- if m == 'tfl':
1671
+ if m == "tfl":
1507
1672
  return target_stats().fl
1508
1673
 
1509
- if m == 'tpkc':
1674
+ if m == "tpkc":
1510
1675
  return target_stats().pkc
1511
1676
 
1512
- if m.startswith('tasr'):
1677
+ if m.startswith("tasr"):
1513
1678
  return target_asr(get_asr_name(m))
1514
1679
 
1515
- if m == 'ndco':
1680
+ if m == "ndco":
1516
1681
  return noise_stats().dco
1517
1682
 
1518
- if m == 'nmin':
1683
+ if m == "nmin":
1519
1684
  return noise_stats().min
1520
1685
 
1521
- if m == 'nmax':
1686
+ if m == "nmax":
1522
1687
  return noise_stats().max
1523
1688
 
1524
- if m == 'npkdb':
1689
+ if m == "npkdb":
1525
1690
  return noise_stats().pkdb
1526
1691
 
1527
- if m == 'nlrms':
1692
+ if m == "nlrms":
1528
1693
  return noise_stats().lrms
1529
1694
 
1530
- if m == 'npkr':
1695
+ if m == "npkr":
1531
1696
  return noise_stats().pkr
1532
1697
 
1533
- if m == 'ntr':
1698
+ if m == "ntr":
1534
1699
  return noise_stats().tr
1535
1700
 
1536
- if m == 'ncr':
1701
+ if m == "ncr":
1537
1702
  return noise_stats().cr
1538
1703
 
1539
- if m == 'nfl':
1704
+ if m == "nfl":
1540
1705
  return noise_stats().fl
1541
1706
 
1542
- if m == 'npkc':
1707
+ if m == "npkc":
1543
1708
  return noise_stats().pkc
1544
1709
 
1545
- if m == 'sedavg':
1710
+ if m == "sedavg":
1546
1711
  return 0
1547
1712
 
1548
- if m == 'sedcnt':
1713
+ if m == "sedcnt":
1549
1714
  return 0
1550
1715
 
1551
- if m == 'sedtop3':
1716
+ if m == "sedtop3":
1552
1717
  return np.zeros(3, dtype=np.float32)
1553
1718
 
1554
- if m == 'sedtopn':
1719
+ if m == "sedtopn":
1555
1720
  return 0
1556
1721
 
1557
- if m == 'ssnr':
1722
+ if m == "ssnr":
1558
1723
  return segsnr_f()
1559
1724
 
1560
- raise SonusAIError(f"Unrecognized metric: '{m}'")
1725
+ raise AttributeError(f"Unrecognized metric: '{m}'")
1561
1726
 
1562
- result: list[float | int | str | Segsnr] = []
1727
+ result: list[float | int | str | Segsnr | None] = []
1563
1728
  for metric in metrics:
1564
1729
  result.append(calc(metric))
1565
1730
 
@@ -1577,13 +1742,16 @@ def _spectral_mask(db: partial, sm_id: int) -> SpectralMask:
1577
1742
  from .db_datatypes import SpectralMaskRecord
1578
1743
 
1579
1744
  with db() as c:
1580
- spectral_mask = SpectralMaskRecord(*c.execute("SELECT * FROM spectral_mask WHERE ? = spectral_mask.id",
1581
- (sm_id,)).fetchone())
1582
- return SpectralMask(f_max_width=spectral_mask.f_max_width,
1583
- f_num=spectral_mask.f_num,
1584
- t_max_width=spectral_mask.t_max_width,
1585
- t_num=spectral_mask.t_num,
1586
- t_max_percent=spectral_mask.t_max_percent)
1745
+ spectral_mask = SpectralMaskRecord(
1746
+ *c.execute("SELECT * FROM spectral_mask WHERE ? = spectral_mask.id", (sm_id,)).fetchone()
1747
+ )
1748
+ return SpectralMask(
1749
+ f_max_width=spectral_mask.f_max_width,
1750
+ f_num=spectral_mask.f_num,
1751
+ t_max_width=spectral_mask.t_max_width,
1752
+ t_num=spectral_mask.t_num,
1753
+ t_max_percent=spectral_mask.t_max_percent,
1754
+ )
1587
1755
 
1588
1756
 
1589
1757
  @lru_cache
@@ -1596,30 +1764,21 @@ def _target_file(db: partial, t_id: int) -> TargetFile:
1596
1764
  """
1597
1765
  import json
1598
1766
 
1599
- from .datatypes import TruthSetting
1600
- from .datatypes import TruthSettings
1601
1767
  from .db_datatypes import TargetFileRecord
1602
1768
 
1603
1769
  with db() as c:
1604
1770
  target_file = TargetFileRecord(
1605
- *c.execute("SELECT * FROM target_file WHERE ? = target_file.id", (t_id,)).fetchone())
1606
-
1607
- truth_settings: TruthSettings = []
1608
- for ts in c.execute(
1609
- "SELECT truth_setting.setting " +
1610
- "FROM truth_setting, target_file_truth_setting " +
1611
- "WHERE ? = target_file_truth_setting.target_file_id " +
1612
- "AND truth_setting.id = target_file_truth_setting.truth_setting_id",
1613
- (t_id,)).fetchall():
1614
- entry = json.loads(ts[0])
1615
- truth_settings.append(TruthSetting(config=entry.get('config', None),
1616
- function=entry.get('function', None),
1617
- index=entry.get('index', None)))
1618
- return TargetFile(name=target_file.name,
1619
- samples=target_file.samples,
1620
- level_type=target_file.level_type,
1621
- truth_settings=truth_settings,
1622
- speaker_id=target_file.speaker_id)
1771
+ *c.execute("SELECT * FROM target_file WHERE ? = target_file.id", (t_id,)).fetchone()
1772
+ )
1773
+
1774
+ return TargetFile(
1775
+ name=target_file.name,
1776
+ samples=target_file.samples,
1777
+ class_indices=json.loads(target_file.class_indices),
1778
+ level_type=target_file.level_type,
1779
+ truth_configs=_target_truth_configs(db, t_id),
1780
+ speaker_id=target_file.speaker_id,
1781
+ )
1623
1782
 
1624
1783
 
1625
1784
  @lru_cache
@@ -1631,8 +1790,10 @@ def _noise_file(db: partial, n_id: int) -> NoiseFile:
1631
1790
  :return: Noise file
1632
1791
  """
1633
1792
  with db() as c:
1634
- noise = c.execute("SELECT noise_file.name, samples FROM noise_file WHERE ? = noise_file.id",
1635
- (n_id,)).fetchone()
1793
+ noise = c.execute(
1794
+ "SELECT noise_file.name, samples FROM noise_file WHERE ? = noise_file.id",
1795
+ (n_id,),
1796
+ ).fetchone()
1636
1797
  return NoiseFile(name=noise[0], samples=noise[1])
1637
1798
 
1638
1799
 
@@ -1645,9 +1806,12 @@ def _impulse_response_file(db: partial, ir_id: int) -> str:
1645
1806
  :return: Noise
1646
1807
  """
1647
1808
  with db() as c:
1648
- return str(c.execute(
1649
- "SELECT impulse_response_file.file FROM impulse_response_file WHERE ? = impulse_response_file.id",
1650
- (ir_id + 1,)).fetchone()[0])
1809
+ return str(
1810
+ c.execute(
1811
+ "SELECT impulse_response_file.file FROM impulse_response_file WHERE ? = impulse_response_file.id",
1812
+ (ir_id + 1,),
1813
+ ).fetchone()[0]
1814
+ )
1651
1815
 
1652
1816
 
1653
1817
  @lru_cache
@@ -1658,31 +1822,59 @@ def _mixture(db: partial, m_id: int) -> Mixture:
1658
1822
  :param m_id: Zero-based mixture ID
1659
1823
  :return: Mixture record
1660
1824
  """
1661
- from .helpers import to_mixture
1662
- from .helpers import to_target
1663
1825
  from .db_datatypes import MixtureRecord
1664
1826
  from .db_datatypes import TargetRecord
1827
+ from .helpers import to_mixture
1828
+ from .helpers import to_target
1665
1829
 
1666
1830
  with db() as c:
1667
1831
  mixture = MixtureRecord(*c.execute("SELECT * FROM mixture WHERE ? = mixture.id", (m_id + 1,)).fetchone())
1668
- targets = [to_target(TargetRecord(*target)) for target in c.execute(
1669
- "SELECT target.* " +
1670
- "FROM target, mixture_target " +
1671
- "WHERE ? = mixture_target.mixture_id AND target.id = mixture_target.target_id",
1672
- (mixture.id,)).fetchall()]
1832
+ targets = [
1833
+ to_target(TargetRecord(*target))
1834
+ for target in c.execute(
1835
+ "SELECT target.* "
1836
+ + "FROM target, mixture_target "
1837
+ + "WHERE ? = mixture_target.mixture_id AND target.id = mixture_target.target_id",
1838
+ (mixture.id,),
1839
+ ).fetchall()
1840
+ ]
1673
1841
 
1674
1842
  return to_mixture(mixture, targets)
1675
1843
 
1676
1844
 
1677
1845
  @lru_cache
1678
- def _speaker(db: partial, s_id: int | None, tier: str) -> Optional[str]:
1846
+ def _speaker(db: partial, s_id: int | None, tier: str) -> str | None:
1679
1847
  if s_id is None:
1680
1848
  return None
1681
1849
 
1682
1850
  with db() as c:
1683
- data = c.execute(f'SELECT {tier} FROM speaker WHERE ? = id', (s_id,)).fetchone()
1851
+ data = c.execute(f"SELECT {tier} FROM speaker WHERE ? = id", (s_id,)).fetchone()
1684
1852
  if data is None:
1685
1853
  return None
1686
1854
  if data[0] is None:
1687
1855
  return None
1688
1856
  return data[0]
1857
+
1858
+
1859
+ @lru_cache
1860
+ def _target_truth_configs(db: partial, t_id: int) -> TruthConfigs:
1861
+ import json
1862
+
1863
+ from .datatypes import TruthConfig
1864
+
1865
+ truth_configs: TruthConfigs = {}
1866
+ with db() as c:
1867
+ for truth_config_record in c.execute(
1868
+ "SELECT truth_config.config "
1869
+ + "FROM truth_config, target_file_truth_config "
1870
+ + "WHERE ? = target_file_truth_config.target_file_id "
1871
+ + "AND truth_config.id = target_file_truth_config.truth_config_id",
1872
+ (t_id,),
1873
+ ).fetchall():
1874
+ truth_config = json.loads(truth_config_record[0])
1875
+ truth_configs[truth_config["name"]] = TruthConfig(
1876
+ function=truth_config["function"],
1877
+ stride_reduction=truth_config["stride_reduction"],
1878
+ config=truth_config["config"],
1879
+ )
1880
+ return truth_configs