sonusai 1.0.11__py3-none-any.whl → 1.0.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sonusai/genmetrics.py CHANGED
@@ -147,12 +147,10 @@ def main() -> None:
147
147
  logger.info("")
148
148
  logger.info(f"Found {len(mixids):,} mixtures to process")
149
149
 
150
- if num_proc is None or len(mixids) == 1:
151
- no_par = True
152
- num_proc = None
153
- else:
154
- no_par = False
155
- num_proc = int(num_proc) # TBD add support for 'auto'
150
+ no_par = num_proc == 1 or len(mixids) == 1
151
+
152
+ if num_proc is not None:
153
+ num_proc = int(num_proc)
156
154
 
157
155
  progress = track(total=len(mixids), desc="genmetrics")
158
156
  results = par_track(
@@ -15,6 +15,7 @@ from .calc_segsnr_f import calc_segsnr_f_bin
15
15
  from .calc_speech import calc_speech
16
16
  from .calc_wer import calc_wer
17
17
  from .calc_wsdr import calc_wsdr
18
+ from .calculate_metrics import calculate_metrics
18
19
  from .class_summary import class_summary
19
20
  from .confusion_matrix_summary import confusion_matrix_summary
20
21
  from .one_hot import one_hot
@@ -0,0 +1,395 @@
1
+ import functools
2
+ from typing import Any
3
+
4
+ import numpy as np
5
+ from pystoi import stoi
6
+
7
+ from ..constants import SAMPLE_RATE
8
+ from ..datatypes import AudioF
9
+ from ..datatypes import AudioStatsMetrics
10
+ from ..datatypes import AudioT
11
+ from ..datatypes import Segsnr
12
+ from ..datatypes import SpeechMetrics
13
+ from ..mixture.mixdb import MixtureDatabase
14
+ from ..utils.asr import calc_asr
15
+ from ..utils.db import linear_to_db
16
+ from .calc_audio_stats import calc_audio_stats
17
+ from .calc_pesq import calc_pesq
18
+ from .calc_phase_distance import calc_phase_distance
19
+ from .calc_segsnr_f import calc_segsnr_f
20
+ from .calc_segsnr_f import calc_segsnr_f_bin
21
+ from .calc_speech import calc_speech
22
+ from .calc_wer import calc_wer
23
+ from .calc_wsdr import calc_wsdr
24
+
25
+
26
def calculate_metrics(mixdb: MixtureDatabase, m_id: int, metrics: list[str], force: bool = False) -> dict[str, Any]:
    """Get metrics data for the given mixture ID

    Unless ``force`` is set, each metric (other than ``mxsnr``, which is always taken from the mixture
    record) is first read from the mixture's cached data and is only computed when no cached value exists.
    Expensive intermediates (audio, transforms, statistics, segmental SNR summaries, ASR text, PESQ and
    speech metrics) are memoized by helper closures created for this call only, so requesting several
    related metrics computes each intermediate at most once.

    Requesting ``mxwer.<name>`` also adds ``mxasr.<name>`` and ``sasr.<name>`` to the result.

    :param mixdb: Mixture database object
    :param m_id: Zero-based mixture ID
    :param metrics: List of metrics to get
    :param force: Force computing data from original sources regardless of whether cached data exists
    :return: Dictionary of metric data
    :raises ValueError: If an ASR metric is not of the form '<metric>.<name>' or names an unknown ASR
    :raises AttributeError: If a metric name is not recognized
    """

    # Define cached functions for expensive operations. They are nested, so their caches live only
    # as long as this call.
    @functools.lru_cache(maxsize=1)
    def is_noise_only() -> bool:
        return mixdb.mixture(m_id).is_noise_only

    @functools.lru_cache(maxsize=1)
    def mixture_sources() -> dict[str, AudioT]:
        return mixdb.mixture_sources(m_id)

    @functools.lru_cache(maxsize=1)
    def mixture_source() -> AudioT:
        return mixdb.mixture_source(m_id)

    @functools.lru_cache(maxsize=1)
    def mixture_source_f() -> AudioF:
        return mixdb.mixture_source_f(m_id)

    @functools.lru_cache(maxsize=1)
    def mixture_noise() -> AudioT:
        return mixdb.mixture_noise(m_id)

    @functools.lru_cache(maxsize=1)
    def mixture_noise_f() -> AudioF:
        return mixdb.mixture_noise_f(m_id)

    @functools.lru_cache(maxsize=1)
    def mixture_mixture() -> AudioT:
        return mixdb.mixture_mixture(m_id)

    @functools.lru_cache(maxsize=1)
    def mixture_mixture_f() -> AudioF:
        return mixdb.mixture_mixture_f(m_id)

    @functools.lru_cache(maxsize=1)
    def mixture_segsnr() -> Segsnr:
        return mixdb.mixture_segsnr(m_id)

    @functools.lru_cache(maxsize=1)
    def segsnr_summary() -> Any:
        # Shared by all mxssnr_* / mxssnrdb_* metrics instead of being recomputed per metric
        return calc_segsnr_f(mixture_segsnr())

    @functools.lru_cache(maxsize=1)
    def segsnr_bin_summary() -> Any:
        # Shared by all mxssnrf_* / mxssnrdbf_* metrics instead of being recomputed per metric
        return calc_segsnr_f_bin(mixture_source_f(), mixture_noise_f())

    @functools.lru_cache(maxsize=1)
    def calculate_pesq() -> dict[str, float]:
        return {category: calc_pesq(mixture_mixture(), audio) for category, audio in mixture_sources().items()}

    @functools.lru_cache(maxsize=1)
    def calculate_speech() -> dict[str, SpeechMetrics]:
        return {
            category: calc_speech(mixture_mixture(), audio, calculate_pesq()[category])
            for category, audio in mixture_sources().items()
        }

    @functools.lru_cache(maxsize=1)
    def mixture_stats() -> AudioStatsMetrics:
        return calc_audio_stats(mixture_mixture(), mixdb.fg_info.ft_config.length / SAMPLE_RATE)

    @functools.lru_cache(maxsize=1)
    def sources_stats() -> dict[str, AudioStatsMetrics]:
        return {
            category: calc_audio_stats(audio, mixdb.fg_info.ft_config.length / SAMPLE_RATE)
            for category, audio in mixture_sources().items()
        }

    @functools.lru_cache(maxsize=1)
    def source_stats() -> AudioStatsMetrics:
        return calc_audio_stats(mixture_source(), mixdb.fg_info.ft_config.length / SAMPLE_RATE)

    @functools.lru_cache(maxsize=1)
    def noise_stats() -> AudioStatsMetrics:
        return calc_audio_stats(mixture_noise(), mixdb.fg_info.ft_config.length / SAMPLE_RATE)

    # Cache ASR configurations
    @functools.lru_cache(maxsize=32)
    def get_asr_config(asr_name: str) -> dict:
        value = mixdb.asr_configs.get(asr_name, None)
        if value is None:
            raise ValueError(f"Unrecognized ASR name: '{asr_name}'")
        return value

    # Cache ASR results for sources, source and mixture
    @functools.lru_cache(maxsize=16)
    def sources_asr(asr_name: str) -> dict[str, str]:
        return {
            category: calc_asr(audio, **get_asr_config(asr_name)).text for category, audio in mixture_sources().items()
        }

    @functools.lru_cache(maxsize=16)
    def source_asr(asr_name: str) -> str:
        return calc_asr(mixture_source(), **get_asr_config(asr_name)).text

    @functools.lru_cache(maxsize=16)
    def mixture_asr(asr_name: str) -> str:
        return calc_asr(mixture_mixture(), **get_asr_config(asr_name)).text

    def get_asr_name(m: str) -> str:
        parts = m.split(".")
        if len(parts) != 2:
            raise ValueError(f"Unrecognized format: '{m}'; must be of the form: '<metric>.<name>'")
        return parts[1]

    # Lookup tables for the families of metrics that differ only in which field they read.
    stats_fields = ("dco", "min", "max", "pkdb", "lrms", "pkr", "tr", "cr", "fl", "pkc")
    # Scalar audio stats: "mx" = mixture, "mxs" = source, "n" = noise; e.g. "mxdco", "mxsdco", "ndco"
    scalar_stats_metrics = {
        prefix + field: (getter, field)
        for prefix, getter in (("mx", mixture_stats), ("mxs", source_stats), ("n", noise_stats))
        for field in stats_fields
    }
    # Per-category audio stats ("s" prefix, e.g. "sdco"); the result is a dict keyed by source category
    sources_stats_metrics = {"s" + field: field for field in stats_fields}
    # Segmental SNR summary fields
    segsnr_metrics = {"mxssnr_avg": "avg", "mxssnr_std": "std", "mxssnrdb_avg": "db_avg", "mxssnrdb_std": "db_std"}
    segsnr_db_metrics = {"mxssnr_avg_db": "avg", "mxssnr_std_db": "std"}
    segsnr_bin_metrics = {
        "mxssnrf_avg": "avg",
        "mxssnrf_std": "std",
        "mxssnrdbf_avg": "db_avg",
        "mxssnrdbf_std": "db_std",
    }
    # Per-category speech quality metrics -> SpeechMetrics field
    speech_metrics = {"mxcsig": "csig", "mxcbak": "cbak", "mxcovl": "covl"}
    # Placeholder values for sound event detection metrics
    sed_metrics = {"sedavg": 0, "sedcnt": 0, "sedtopn": 0}

    def calc(m: str) -> Any:
        if m == "mxsnr":
            return {category: source.snr for category, source in mixdb.mixture(m_id).all_sources.items()}

        # Get cached data first, if exists
        if not force:
            value = mixdb.read_mixture_data(m_id, m)[m]
            if value is not None:
                return value

        # Otherwise, generate data as needed
        if m.startswith("mxwer"):
            asr_name = get_asr_name(m)

            if is_noise_only():
                # noise only, ignore/reset target asr
                return float("nan")

            reference = source_asr(asr_name)
            if reference:
                return calc_wer(mixture_asr(asr_name), reference).wer * 100

            # TODO: should this be NaN like above?
            return float(0)

        if m.startswith("basewer"):
            asr_name = get_asr_name(m)

            text = mixdb.mixture_speech_metadata(m_id, "text")
            return {
                category: calc_wer(source, text[category]).wer * 100 if isinstance(text[category], str) else 0
                for category, source in sources_asr(asr_name).items()
            }

        if m.startswith("mxasr"):
            return mixture_asr(get_asr_name(m))

        if m.startswith("sasr"):
            return sources_asr(get_asr_name(m))

        if m.startswith("mxsasr"):
            return source_asr(get_asr_name(m))

        if m in segsnr_metrics:
            return getattr(segsnr_summary(), segsnr_metrics[m])

        if m in segsnr_db_metrics:
            val = getattr(segsnr_summary(), segsnr_db_metrics[m])
            if val is not None:
                return linear_to_db(val)
            return None

        if m in segsnr_bin_metrics:
            return getattr(segsnr_bin_summary(), segsnr_bin_metrics[m])

        if m == "mxpesq":
            if is_noise_only():
                # Only the category keys are needed; avoid running PESQ on a silent reference
                return dict.fromkeys(mixture_sources(), 0)
            return calculate_pesq()

        if m in speech_metrics:
            if is_noise_only():
                # Only the category keys are needed; avoid computing the full speech metrics
                return dict.fromkeys(mixture_sources(), 0)
            field = speech_metrics[m]
            return {category: getattr(s, field) for category, s in calculate_speech().items()}

        if m == "mxwsdr":
            mixture = mixture_mixture()[:, np.newaxis]
            target = mixture_source()[:, np.newaxis]
            noise = mixture_noise()[:, np.newaxis]
            return calc_wsdr(
                hypothesis=np.concatenate((mixture, noise), axis=1),
                reference=np.concatenate((target, noise), axis=1),
                with_log=True,
            )[0]

        if m == "mxpd":
            return calc_phase_distance(hypothesis=mixture_mixture_f(), reference=mixture_source_f())[0]

        if m == "mxstoi":
            return stoi(
                x=mixture_source(),
                y=mixture_mixture(),
                fs_sig=SAMPLE_RATE,
                extended=False,
            )

        if m in scalar_stats_metrics:
            getter, field = scalar_stats_metrics[m]
            return getattr(getter(), field)

        if m in sources_stats_metrics:
            field = sources_stats_metrics[m]
            return {category: getattr(s, field) for category, s in sources_stats().items()}

        if m in sed_metrics:
            return sed_metrics[m]

        if m == "sedtop3":
            return np.zeros(3, dtype=np.float32)

        if m == "ssnr":
            return mixture_segsnr()

        raise AttributeError(f"Unrecognized metric: '{m}'")

    result: dict[str, Any] = {}
    for metric in metrics:
        result[metric] = calc(metric)

        # Check for metrics dependencies and add them even if not explicitly requested.
        if metric.startswith("mxwer"):
            # Derive the ASR name the same way calc() does rather than slicing by position
            asr_name = get_asr_name(metric)
            for dependency in (f"mxasr.{asr_name}", f"sasr.{asr_name}"):
                result[dependency] = calc(dependency)

    return result
sonusai/mixture/mixdb.py CHANGED
@@ -215,16 +215,6 @@ class MixtureDatabase:
215
215
  MetricDoc("Mixture Metrics", "mxcr", "Mixture Crest factor"),
216
216
  MetricDoc("Mixture Metrics", "mxfl", "Mixture Flat factor"),
217
217
  MetricDoc("Mixture Metrics", "mxpkc", "Mixture Pk count"),
218
- MetricDoc("Mixture Metrics", "mxtdco", "Mixture source DC offset"),
219
- MetricDoc("Mixture Metrics", "mxtmin", "Mixture source min level"),
220
- MetricDoc("Mixture Metrics", "mxtmax", "Mixture source max levl"),
221
- MetricDoc("Mixture Metrics", "mxtpkdb", "Mixture source Pk lev dB"),
222
- MetricDoc("Mixture Metrics", "mxtlrms", "Mixture source RMS lev dB"),
223
- MetricDoc("Mixture Metrics", "mxtpkr", "Mixture source RMS Pk dB"),
224
- MetricDoc("Mixture Metrics", "mxttr", "Mixture source RMS Tr dB"),
225
- MetricDoc("Mixture Metrics", "mxtcr", "Mixture source Crest factor"),
226
- MetricDoc("Mixture Metrics", "mxtfl", "Mixture source Flat factor"),
227
- MetricDoc("Mixture Metrics", "mxtpkc", "Mixture source Pk count"),
228
218
  MetricDoc("Sources Metrics", "sdco", "Sources DC offset"),
229
219
  MetricDoc("Sources Metrics", "smin", "Sources min level"),
230
220
  MetricDoc("Sources Metrics", "smax", "Sources max levl"),
@@ -235,6 +225,16 @@ class MixtureDatabase:
235
225
  MetricDoc("Sources Metrics", "scr", "Sources Crest factor"),
236
226
  MetricDoc("Sources Metrics", "sfl", "Sources Flat factor"),
237
227
  MetricDoc("Sources Metrics", "spkc", "Sources Pk count"),
228
+ MetricDoc("Source Metrics", "mxsdco", "Source DC offset"),
229
+ MetricDoc("Source Metrics", "mxsmin", "Source min level"),
230
+ MetricDoc("Source Metrics", "mxsmax", "Source max levl"),
231
+ MetricDoc("Source Metrics", "mxspkdb", "Source Pk lev dB"),
232
+ MetricDoc("Source Metrics", "mxslrms", "Source RMS lev dB"),
233
+ MetricDoc("Source Metrics", "mxspkr", "Source RMS Pk dB"),
234
+ MetricDoc("Source Metrics", "mxstr", "Source RMS Tr dB"),
235
+ MetricDoc("Source Metrics", "mxscr", "Source Crest factor"),
236
+ MetricDoc("Source Metrics", "mxsfl", "Source Flat factor"),
237
+ MetricDoc("Source Metrics", "mxspkc", "Source Pk count"),
238
238
  MetricDoc("Noise Metrics", "ndco", "Noise DC offset"),
239
239
  MetricDoc("Noise Metrics", "nmin", "Noise min level"),
240
240
  MetricDoc("Noise Metrics", "nmax", "Noise max levl"),
@@ -272,12 +272,12 @@ class MixtureDatabase:
272
272
  MetricDoc(
273
273
  "Source Metrics",
274
274
  f"mxsasr.{name}",
275
- f"Mixture Source ASR text using {name} ASR as defined in mixdb asr_configs parameter",
275
+ f"Source ASR text using {name} ASR as defined in mixdb asr_configs parameter",
276
276
  )
277
277
  )
278
278
  metrics.append(
279
279
  MetricDoc(
280
- "Source Metrics",
280
+ "Sources Metrics",
281
281
  f"sasr.{name}",
282
282
  f"Sources ASR text using {name} ASR as defined in mixdb asr_configs parameter",
283
283
  )
@@ -291,7 +291,7 @@ class MixtureDatabase:
291
291
  )
292
292
  metrics.append(
293
293
  MetricDoc(
294
- "Source Metrics",
294
+ "Sources Metrics",
295
295
  f"basewer.{name}",
296
296
  f"Word error rate of sasr.{name} vs. speech text metadata for the source",
297
297
  )
@@ -1296,17 +1296,15 @@ class MixtureDatabase:
1296
1296
  fg = FeatureGenerator(self.fg_config.feature_mode, self.fg_config.truth_parameters)
1297
1297
 
1298
1298
  feature, truth_f = fg.execute_all(mixture_f, truth_t)
1299
- if truth_f is not None:
1300
- truth_configs = self.mixture_truth_configs(m_id)
1301
- for category, configs in truth_configs.items():
1302
- for name, config in configs.items():
1303
- if self.truth_parameters[category][name] is not None:
1304
- truth_f[category][name] = truth_stride_reduction(
1305
- truth_f[category][name], config.stride_reduction
1306
- )
1307
- else:
1299
+ if truth_f is None:
1308
1300
  raise TypeError("Unexpected truth of None from feature generator")
1309
1301
 
1302
+ truth_configs = self.mixture_truth_configs(m_id)
1303
+ for category, configs in truth_configs.items():
1304
+ for name, config in configs.items():
1305
+ if self.truth_parameters[category][name] is not None:
1306
+ truth_f[category][name] = truth_stride_reduction(truth_f[category][name], config.stride_reduction)
1307
+
1310
1308
  if cache:
1311
1309
  write_cached_data(
1312
1310
  location=self.location,
@@ -1598,536 +1596,9 @@ class MixtureDatabase:
1598
1596
  :param force: Force computing data from original sources regardless of whether cached data exists
1599
1597
  :return: Dictionary of metric data
1600
1598
  """
1601
- from collections.abc import Callable
1602
-
1603
- import numpy as np
1604
- from pystoi import stoi
1605
-
1606
- from ..constants import SAMPLE_RATE
1607
- from ..datatypes import AudioStatsMetrics
1608
- from ..datatypes import SpeechMetrics
1609
- from ..metrics.calc_audio_stats import calc_audio_stats
1610
- from ..metrics.calc_pesq import calc_pesq
1611
- from ..metrics.calc_phase_distance import calc_phase_distance
1612
- from ..metrics.calc_segsnr_f import calc_segsnr_f
1613
- from ..metrics.calc_segsnr_f import calc_segsnr_f_bin
1614
- from ..metrics.calc_speech import calc_speech
1615
- from ..metrics.calc_wer import calc_wer
1616
- from ..metrics.calc_wsdr import calc_wsdr
1617
- from ..utils.asr import calc_asr
1618
- from ..utils.db import linear_to_db
1619
-
1620
- def create_sources_audio() -> Callable[[], dict[str, AudioT]]:
1621
- state: dict[str, AudioT] | None = None
1622
-
1623
- def get() -> dict[str, AudioT]:
1624
- nonlocal state
1625
- if state is None:
1626
- state = self.mixture_sources(m_id)
1627
- return state
1628
-
1629
- return get
1630
-
1631
- sources_audio = create_sources_audio()
1632
-
1633
- def create_source_audio() -> Callable[[], AudioT]:
1634
- state: AudioT | None = None
1635
-
1636
- def get() -> AudioT:
1637
- nonlocal state
1638
- if state is None:
1639
- state = self.mixture_source(m_id)
1640
- return state
1641
-
1642
- return get
1643
-
1644
- source_audio = create_source_audio()
1645
-
1646
- def create_source_f() -> Callable[[], AudioF]:
1647
- state: AudioF | None = None
1648
-
1649
- def get() -> AudioF:
1650
- nonlocal state
1651
- if state is None:
1652
- state = self.mixture_source_f(m_id)
1653
- return state
1654
-
1655
- return get
1656
-
1657
- source_f = create_source_f()
1658
-
1659
- def create_noise_audio() -> Callable[[], AudioT]:
1660
- state: AudioT | None = None
1661
-
1662
- def get() -> AudioT:
1663
- nonlocal state
1664
- if state is None:
1665
- state = self.mixture_noise(m_id)
1666
- return state
1667
-
1668
- return get
1669
-
1670
- noise_audio = create_noise_audio()
1671
-
1672
- def create_noise_f() -> Callable[[], AudioF]:
1673
- state: AudioF | None = None
1674
-
1675
- def get() -> AudioF:
1676
- nonlocal state
1677
- if state is None:
1678
- state = self.mixture_noise_f(m_id)
1679
- return state
1680
-
1681
- return get
1682
-
1683
- noise_f = create_noise_f()
1684
-
1685
- def create_mixture_audio() -> Callable[[], AudioT]:
1686
- state: AudioT | None = None
1687
-
1688
- def get() -> AudioT:
1689
- nonlocal state
1690
- if state is None:
1691
- state = self.mixture_mixture(m_id)
1692
- return state
1693
-
1694
- return get
1695
-
1696
- mixture_audio = create_mixture_audio()
1697
-
1698
- def create_segsnr_f() -> Callable[[], Segsnr]:
1699
- state: Segsnr | None = None
1700
-
1701
- def get() -> Segsnr:
1702
- nonlocal state
1703
- if state is None:
1704
- state = self.mixture_segsnr(m_id)
1705
- return state
1706
-
1707
- return get
1708
-
1709
- segsnr_f = create_segsnr_f()
1710
-
1711
- def create_pesq() -> Callable[[], dict[str, float]]:
1712
- state: dict[str, float] | None = None
1713
-
1714
- def get() -> dict[str, float]:
1715
- nonlocal state
1716
- if state is None:
1717
- state = {category: calc_pesq(mixture_audio(), audio) for category, audio in sources_audio().items()}
1718
- return state
1719
-
1720
- return get
1721
-
1722
- pesq = create_pesq()
1723
-
1724
- def create_speech() -> Callable[[], dict[str, SpeechMetrics]]:
1725
- state: dict[str, SpeechMetrics] | None = None
1726
-
1727
- def get() -> dict[str, SpeechMetrics]:
1728
- nonlocal state
1729
- if state is None:
1730
- state = {
1731
- category: calc_speech(mixture_audio(), audio, pesq()[category])
1732
- for category, audio in sources_audio().items()
1733
- }
1734
- return state
1735
-
1736
- return get
1737
-
1738
- speech = create_speech()
1739
-
1740
- def create_mixture_stats() -> Callable[[], AudioStatsMetrics]:
1741
- state: AudioStatsMetrics | None = None
1742
-
1743
- def get() -> AudioStatsMetrics:
1744
- nonlocal state
1745
- if state is None:
1746
- state = calc_audio_stats(mixture_audio(), self.fg_info.ft_config.length / SAMPLE_RATE)
1747
- return state
1748
-
1749
- return get
1750
-
1751
- mixture_stats = create_mixture_stats()
1752
-
1753
- def create_sources_stats() -> Callable[[], dict[str, AudioStatsMetrics]]:
1754
- state: dict[str, AudioStatsMetrics] | None = None
1755
-
1756
- def get() -> dict[str, AudioStatsMetrics]:
1757
- nonlocal state
1758
- if state is None:
1759
- state = {
1760
- category: calc_audio_stats(audio, self.fg_info.ft_config.length / SAMPLE_RATE)
1761
- for category, audio in sources_audio().items()
1762
- }
1763
- return state
1764
-
1765
- return get
1766
-
1767
- sources_stats = create_sources_stats()
1768
-
1769
- def create_source_stats() -> Callable[[], AudioStatsMetrics]:
1770
- state: AudioStatsMetrics | None = None
1771
-
1772
- def get() -> AudioStatsMetrics:
1773
- nonlocal state
1774
- if state is None:
1775
- state = calc_audio_stats(source_audio(), self.fg_info.ft_config.length / SAMPLE_RATE)
1776
- return state
1777
-
1778
- return get
1779
-
1780
- source_stats = create_source_stats()
1781
-
1782
- def create_noise_stats() -> Callable[[], AudioStatsMetrics]:
1783
- state: AudioStatsMetrics | None = None
1784
-
1785
- def get() -> AudioStatsMetrics:
1786
- nonlocal state
1787
- if state is None:
1788
- state = calc_audio_stats(noise_audio(), self.fg_info.ft_config.length / SAMPLE_RATE)
1789
- return state
1790
-
1791
- return get
1792
-
1793
- noise_stats = create_noise_stats()
1794
-
1795
- def create_asr_config() -> Callable[[str], dict]:
1796
- state: dict[str, dict] = {}
1797
-
1798
- def get(asr_name) -> dict:
1799
- nonlocal state
1800
- if asr_name not in state:
1801
- value = self.asr_configs.get(asr_name, None)
1802
- if value is None:
1803
- raise ValueError(f"Unrecognized ASR name: '{asr_name}'")
1804
- state[asr_name] = value
1805
- return state[asr_name]
1806
-
1807
- return get
1808
-
1809
- asr_config = create_asr_config()
1810
-
1811
- def create_sources_asr() -> Callable[[str], dict[str, str]]:
1812
- state: dict[str, dict[str, str]] = {}
1813
-
1814
- def get(asr_name) -> dict[str, str]:
1815
- nonlocal state
1816
- if asr_name not in state:
1817
- state[asr_name] = {
1818
- category: calc_asr(audio, **asr_config(asr_name)).text
1819
- for category, audio in sources_audio().items()
1820
- }
1821
- return state[asr_name]
1822
-
1823
- return get
1824
-
1825
- sources_asr = create_sources_asr()
1826
-
1827
- def create_source_asr() -> Callable[[str], str]:
1828
- state: dict[str, str] = {}
1829
-
1830
- def get(asr_name) -> str:
1831
- nonlocal state
1832
- if asr_name not in state:
1833
- state[asr_name] = calc_asr(source_audio(), **asr_config(asr_name)).text
1834
- return state[asr_name]
1835
-
1836
- return get
1837
-
1838
- source_asr = create_source_asr()
1839
-
1840
- def create_mixture_asr() -> Callable[[str], str]:
1841
- state: dict[str, str] = {}
1842
-
1843
- def get(asr_name) -> str:
1844
- nonlocal state
1845
- if asr_name not in state:
1846
- state[asr_name] = calc_asr(mixture_audio(), **asr_config(asr_name)).text
1847
- return state[asr_name]
1848
-
1849
- return get
1850
-
1851
- mixture_asr = create_mixture_asr()
1852
-
1853
- def get_asr_name(m: str) -> str:
1854
- parts = m.split(".")
1855
- if len(parts) != 2:
1856
- raise ValueError(f"Unrecognized format: '{m}'; must be of the form: '<metric>.<name>'")
1857
- asr_name = parts[1]
1858
- return asr_name
1859
-
1860
- def calc(m: str) -> Any:
1861
- if m == "mxsnr":
1862
- return {category: source.snr for category, source in self.mixture(m_id).all_sources.items()}
1863
-
1864
- # Get cached data first, if exists
1865
- if not force:
1866
- value = self.read_mixture_data(m_id, m)[m]
1867
- if value is not None:
1868
- return value
1869
-
1870
- # Otherwise, generate data as needed
1871
- if m.startswith("mxwer"):
1872
- asr_name = get_asr_name(m)
1873
-
1874
- if self.mixture(m_id).is_noise_only:
1875
- # noise only, ignore/reset target asr
1876
- return float("nan")
1877
-
1878
- if source_asr(asr_name):
1879
- return calc_wer(mixture_asr(asr_name), source_asr(asr_name)).wer * 100
1880
-
1881
- # TODO: should this be NaN like above?
1882
- return float(0)
1883
-
1884
- if m.startswith("basewer"):
1885
- asr_name = get_asr_name(m)
1886
-
1887
- text = self.mixture_speech_metadata(m_id, "text")
1888
- base_wer: dict[str, float] = {}
1889
- for category, source in sources_asr(asr_name).items():
1890
- if isinstance(text[category], str):
1891
- base_wer[category] = calc_wer(source, str(text[category])).wer * 100
1892
- else:
1893
- base_wer[category] = 0
1894
- return base_wer
1895
-
1896
- if m.startswith("mxasr"):
1897
- return mixture_asr(get_asr_name(m))
1898
-
1899
- if m == "mxssnr_avg":
1900
- return calc_segsnr_f(segsnr_f()).avg
1901
-
1902
- if m == "mxssnr_std":
1903
- return calc_segsnr_f(segsnr_f()).std
1904
-
1905
- if m == "mxssnr_avg_db":
1906
- val = calc_segsnr_f(segsnr_f()).avg
1907
- if val is not None:
1908
- return linear_to_db(val)
1909
- return None
1910
-
1911
- if m == "mxssnr_std_db":
1912
- val = calc_segsnr_f(segsnr_f()).std
1913
- if val is not None:
1914
- return linear_to_db(val)
1915
- return None
1916
-
1917
- if m == "mxssnrdb_avg":
1918
- return calc_segsnr_f(segsnr_f()).db_avg
1919
-
1920
- if m == "mxssnrdb_std":
1921
- return calc_segsnr_f(segsnr_f()).db_std
1922
-
1923
- if m == "mxssnrf_avg":
1924
- return calc_segsnr_f_bin(source_f(), noise_f()).avg
1925
-
1926
- if m == "mxssnrf_std":
1927
- return calc_segsnr_f_bin(source_f(), noise_f()).std
1928
-
1929
- if m == "mxssnrdbf_avg":
1930
- return calc_segsnr_f_bin(source_f(), noise_f()).db_avg
1931
-
1932
- if m == "mxssnrdbf_std":
1933
- return calc_segsnr_f_bin(source_f(), noise_f()).db_std
1934
-
1935
- if m == "mxpesq":
1936
- if self.mixture(m_id).is_noise_only:
1937
- return dict.fromkeys(pesq(), 0)
1938
- return pesq()
1939
-
1940
- if m == "mxcsig":
1941
- if self.mixture(m_id).is_noise_only:
1942
- return dict.fromkeys(speech(), 0)
1943
- return {category: s.csig for category, s in speech().items()}
1944
-
1945
- if m == "mxcbak":
1946
- if self.mixture(m_id).is_noise_only:
1947
- return dict.fromkeys(speech(), 0)
1948
- return {category: s.cbak for category, s in speech().items()}
1949
-
1950
- if m == "mxcovl":
1951
- if self.mixture(m_id).is_noise_only:
1952
- return dict.fromkeys(speech(), 0)
1953
- return {category: s.covl for category, s in speech().items()}
1954
-
1955
- if m == "mxwsdr":
1956
- mixture = mixture_audio()[:, np.newaxis]
1957
- target = source_audio()[:, np.newaxis]
1958
- noise = noise_audio()[:, np.newaxis]
1959
- return calc_wsdr(
1960
- hypothesis=np.concatenate((mixture, noise), axis=1),
1961
- reference=np.concatenate((target, noise), axis=1),
1962
- with_log=True,
1963
- )[0]
1964
-
1965
- if m == "mxpd":
1966
- mixture_f = self.mixture_mixture_f(m_id)
1967
- return calc_phase_distance(hypothesis=mixture_f, reference=source_f())[0]
1968
-
1969
- if m == "mxstoi":
1970
- return stoi(
1971
- x=source_audio(),
1972
- y=mixture_audio(),
1973
- fs_sig=SAMPLE_RATE,
1974
- extended=False,
1975
- )
1976
-
1977
- if m == "mxdco":
1978
- return mixture_stats().dco
1979
-
1980
- if m == "mxmin":
1981
- return mixture_stats().min
1982
-
1983
- if m == "mxmax":
1984
- return mixture_stats().max
1985
-
1986
- if m == "mxpkdb":
1987
- return mixture_stats().pkdb
1988
-
1989
- if m == "mxlrms":
1990
- return mixture_stats().lrms
1991
-
1992
- if m == "mxpkr":
1993
- return mixture_stats().pkr
1994
-
1995
- if m == "mxtr":
1996
- return mixture_stats().tr
1997
-
1998
- if m == "mxcr":
1999
- return mixture_stats().cr
2000
-
2001
- if m == "mxfl":
2002
- return mixture_stats().fl
2003
-
2004
- if m == "mxpkc":
2005
- return mixture_stats().pkc
2006
-
2007
- if m == "mxtdco":
2008
- return source_stats().dco
2009
-
2010
- if m == "mxtmin":
2011
- return source_stats().min
2012
-
2013
- if m == "mxtmax":
2014
- return source_stats().max
2015
-
2016
- if m == "mxtpkdb":
2017
- return source_stats().pkdb
2018
-
2019
- if m == "mxtlrms":
2020
- return source_stats().lrms
2021
-
2022
- if m == "mxtpkr":
2023
- return source_stats().pkr
2024
-
2025
- if m == "mxttr":
2026
- return source_stats().tr
2027
-
2028
- if m == "mxtcr":
2029
- return source_stats().cr
2030
-
2031
- if m == "mxtfl":
2032
- return source_stats().fl
2033
-
2034
- if m == "mxtpkc":
2035
- return source_stats().pkc
2036
-
2037
- if m == "sdco":
2038
- return {category: s.dco for category, s in sources_stats().items()}
2039
-
2040
- if m == "smin":
2041
- return {category: s.min for category, s in sources_stats().items()}
2042
-
2043
- if m == "smax":
2044
- return {category: s.max for category, s in sources_stats().items()}
2045
-
2046
- if m == "spkdb":
2047
- return {category: s.pkdb for category, s in sources_stats().items()}
2048
-
2049
- if m == "slrms":
2050
- return {category: s.lrms for category, s in sources_stats().items()}
2051
-
2052
- if m == "spkr":
2053
- return {category: s.pkr for category, s in sources_stats().items()}
2054
-
2055
- if m == "str":
2056
- return {category: s.tr for category, s in sources_stats().items()}
2057
-
2058
- if m == "scr":
2059
- return {category: s.cr for category, s in sources_stats().items()}
2060
-
2061
- if m == "sfl":
2062
- return {category: s.fl for category, s in sources_stats().items()}
2063
-
2064
- if m == "spkc":
2065
- return {category: s.pkc for category, s in sources_stats().items()}
2066
-
2067
- if m.startswith("sasr"):
2068
- return sources_asr(get_asr_name(m))
2069
-
2070
- if m.startswith("mxsasr"):
2071
- return source_asr(get_asr_name(m))
2072
-
2073
- if m == "ndco":
2074
- return noise_stats().dco
2075
-
2076
- if m == "nmin":
2077
- return noise_stats().min
2078
-
2079
- if m == "nmax":
2080
- return noise_stats().max
2081
-
2082
- if m == "npkdb":
2083
- return noise_stats().pkdb
2084
-
2085
- if m == "nlrms":
2086
- return noise_stats().lrms
2087
-
2088
- if m == "npkr":
2089
- return noise_stats().pkr
2090
-
2091
- if m == "ntr":
2092
- return noise_stats().tr
2093
-
2094
- if m == "ncr":
2095
- return noise_stats().cr
2096
-
2097
- if m == "nfl":
2098
- return noise_stats().fl
2099
-
2100
- if m == "npkc":
2101
- return noise_stats().pkc
2102
-
2103
- if m == "sedavg":
2104
- return 0
2105
-
2106
- if m == "sedcnt":
2107
- return 0
2108
-
2109
- if m == "sedtop3":
2110
- return np.zeros(3, dtype=np.float32)
2111
-
2112
- if m == "sedtopn":
2113
- return 0
2114
-
2115
- if m == "ssnr":
2116
- return segsnr_f()
2117
-
2118
- raise AttributeError(f"Unrecognized metric: '{m}'")
2119
-
2120
- result: dict[str, Any] = {}
2121
- for metric in metrics:
2122
- result[metric] = calc(metric)
2123
-
2124
- # Check for metrics dependencies and add them even if not explicitly requested.
2125
- if metric.startswith("mxwer"):
2126
- dependencies = ("mxasr." + metric[6:], "sasr." + metric[6:])
2127
- for dependency in dependencies:
2128
- result[dependency] = calc(dependency)
1599
+ from ..metrics import calculate_metrics
2129
1600
 
2130
- return result
1601
+ return calculate_metrics(self, m_id, metrics, force)
2131
1602
 
2132
1603
 
2133
1604
  def _spectral_mask(db: partial, sm_id: int, use_cache: bool = True) -> SpectralMask:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: sonusai
3
- Version: 1.0.11
3
+ Version: 1.0.13
4
4
  Summary: Framework for building deep neural network models for sound, speech, and voice AI
5
5
  Home-page: https://aaware.com
6
6
  License: GPL-3.0-only
@@ -21,13 +21,13 @@ sonusai/doc/__init__.py,sha256=KyQ26Um0RM8A3GYsb_tbFH64RwpoAw6lja2f_moUWas,33
21
21
  sonusai/doc/doc.py,sha256=FURO3pvGKrUCHs5iHf0L2zeNofdePW_jiEwtKQX4pJw,19520
22
22
  sonusai/doc.py,sha256=ZgFSSI56oNDb-yC3xi-RHMClMjryR2VrgGyi3ggX8gM,1098
23
23
  sonusai/genft.py,sha256=yiADvi0J-Fy4kNpNOEB3wVvU9RZowGvOsCTJndQYXFw,5580
24
- sonusai/genmetrics.py,sha256=9l7g_DAKa126RGq23-Wilzdh1M3QHCCeNfUVYaQS1mU,6193
24
+ sonusai/genmetrics.py,sha256=C-rp_axsxeKvdXtrtExMGqGnNFJgXHq_7EoKeamUkWA,6116
25
25
  sonusai/genmix.py,sha256=gcmqcPqZ1Vz_TtZMp29L8cGnqTK5jcw0cAOc16NOR9A,5753
26
26
  sonusai/genmixdb.py,sha256=VDQMF6JHcHc-yJAZ1Se3CM3ac8fFKIgnaxv4e5jdE1I,11281
27
27
  sonusai/ir_metric.py,sha256=nxS_mARPSZG5Y0G3L8HysOnkPj4v-RGxAxAVBYe-gJI,19600
28
28
  sonusai/lsdb.py,sha256=-Fhwd7YuL-OIymFqaNcBHtOq8l_8LxzoEE6ztduQCpY,5059
29
29
  sonusai/main.py,sha256=72feJv5XEVJE_CQatmNIL1VD9ca-Mo0QNDbXxLrHrbQ,2619
30
- sonusai/metrics/__init__.py,sha256=ssV6JEK_oklRSocsp6HMcG-GtJvV8IkRQtdKhHHmwU8,878
30
+ sonusai/metrics/__init__.py,sha256=0Y0xFHiO3TrH4DRt-htCXEXsc8TLGNRWfD16q16yWEs,927
31
31
  sonusai/metrics/calc_audio_stats.py,sha256=tIfTa40UdYCkj999kUghWafwnFBqFtJxB5yZhVp1YpA,1244
32
32
  sonusai/metrics/calc_class_weights.py,sha256=uF1jeFz73l5nSk6SQ-xkBGbrgvAvX_MKUA_Det2KAEM,3609
33
33
  sonusai/metrics/calc_optimal_thresholds.py,sha256=1bKPoqUYyHpq7lrx7hPnVXrJ5xWIewQjNG632GzKNNU,3502
@@ -40,6 +40,7 @@ sonusai/metrics/calc_segsnr_f.py,sha256=yLqUt--8osVgCNAkopbDZsldlVJ6a5AZEggarN8d
40
40
  sonusai/metrics/calc_speech.py,sha256=bFiWtKz_Fuu4F1kdWGmZ3qZ_LdoSI3pj0ziXZKxXE3U,14828
41
41
  sonusai/metrics/calc_wer.py,sha256=1MQYMx8ldHeodtJEtGibvDKhvSaGe6DBmZV4L8qOMgg,2362
42
42
  sonusai/metrics/calc_wsdr.py,sha256=vcALY-zuhyThRa1QMz2qW8L9kSBc2v32gV9u8bV7VaM,2556
43
+ sonusai/metrics/calculate_metrics.py,sha256=jcAyEV6loenu4fU_EvwEkpKxOrP8-K9O3rwQGlE48IU,12475
43
44
  sonusai/metrics/class_summary.py,sha256=mQbMxQ8EtFIN7S2h7A4Dk0X4XF_CIxKk3W8zZMmpfcw,2801
44
45
  sonusai/metrics/confusion_matrix_summary.py,sha256=lhd8TyHVMC03khX85h_D75XElmawx56KkqpX3X2O2gQ,3133
45
46
  sonusai/metrics/one_hot.py,sha256=aKc-xYd4zWIjbmoQikIcQ6BJB1k-68XKTg8eJCacHTU,13906
@@ -60,7 +61,7 @@ sonusai/mixture/helpers.py,sha256=dmyHwf1C5dZjYOd11kVV16KI33CaM-dU_fyaxOrrKt8,11
60
61
  sonusai/mixture/ir_delay.py,sha256=aiC23HMWQ08-v5wORgMx1_DOJSdh4kunULqiQ-SGuMo,2026
61
62
  sonusai/mixture/ir_effects.py,sha256=PqiqD4PS42-7kD6ESnsZi2a3tnKCFa4E0xqUujRBvGg,2152
62
63
  sonusai/mixture/log_duration_and_sizes.py,sha256=3ekS27IMKlnxIkQAmprzmBnzHOpRjZh3d7maL2VqWQU,927
63
- sonusai/mixture/mixdb.py,sha256=h8ZBARqKyFvMhAlXhXGSIwQ_zMAVtHok4wOfM4USVdQ,84900
64
+ sonusai/mixture/mixdb.py,sha256=BzFzVON6ZupJcZ9Bx-OXOirck5szLrRY92bSr3042S8,67874
64
65
  sonusai/mixture/pad_audio.py,sha256=KNxVQAejA0hblLOnMJgLS6lFaeE0n3tWQ5rclaHBnIY,1015
65
66
  sonusai/mixture/parse.py,sha256=nqhjuR-J7_3wlGhVitYFvQwLJ1sclU8WZrVF0SyW2Cw,3700
66
67
  sonusai/mixture/resample.py,sha256=jXqH6FrZ0mlhQ07XqPx88TT9elu3HHVLw7Q0a7Lh5M4,221
@@ -133,7 +134,7 @@ sonusai/utils/tokenized_shell_vars.py,sha256=EDrrAgz5lJ0RBAjLcTJt1MeyjhbNZiqXkym
133
134
  sonusai/utils/write_audio.py,sha256=IHzrJoFtFcea_J6wo6QSiojRkgnNOzAEcg-z0rFV7nU,810
134
135
  sonusai/utils/yes_or_no.py,sha256=0h1okjXmDNbJp7rZJFR2V-HFU1GJDm3YFTUVmYExkOU,263
135
136
  sonusai/vars.py,sha256=m8pdgfR4A6A9TCGf_rok6jPAT5BgrEsYXTSISIh1nrI,1163
136
- sonusai-1.0.11.dist-info/METADATA,sha256=mW6s4W6kXhofeFWjfMaQZAAisbdCYoc0bWPlD53JvFA,2695
137
- sonusai-1.0.11.dist-info/WHEEL,sha256=RaoafKOydTQ7I_I3JTrPCg6kUmTgtm4BornzOqyEfJ8,88
138
- sonusai-1.0.11.dist-info/entry_points.txt,sha256=zMNjEphEPO6B3cD1GNpit7z-yA9tUU5-j3W2v-UWstU,92
139
- sonusai-1.0.11.dist-info/RECORD,,
137
+ sonusai-1.0.13.dist-info/METADATA,sha256=l-tODfpKcDr2Xfqriw3VDvw52-k7YR9S5fEXtutpS1k,2695
138
+ sonusai-1.0.13.dist-info/WHEEL,sha256=RaoafKOydTQ7I_I3JTrPCg6kUmTgtm4BornzOqyEfJ8,88
139
+ sonusai-1.0.13.dist-info/entry_points.txt,sha256=zMNjEphEPO6B3cD1GNpit7z-yA9tUU5-j3W2v-UWstU,92
140
+ sonusai-1.0.13.dist-info/RECORD,,