sonusai 1.0.16__cp311-abi3-macosx_10_12_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. sonusai/__init__.py +170 -0
  2. sonusai/aawscd_probwrite.py +148 -0
  3. sonusai/audiofe.py +481 -0
  4. sonusai/calc_metric_spenh.py +1136 -0
  5. sonusai/config/__init__.py +0 -0
  6. sonusai/config/asr.py +21 -0
  7. sonusai/config/config.py +65 -0
  8. sonusai/config/config.yml +49 -0
  9. sonusai/config/constants.py +53 -0
  10. sonusai/config/ir.py +124 -0
  11. sonusai/config/ir_delay.py +62 -0
  12. sonusai/config/source.py +275 -0
  13. sonusai/config/spectral_masks.py +15 -0
  14. sonusai/config/truth.py +64 -0
  15. sonusai/constants.py +14 -0
  16. sonusai/data/__init__.py +0 -0
  17. sonusai/data/silero_vad_v5.1.jit +0 -0
  18. sonusai/data/silero_vad_v5.1.onnx +0 -0
  19. sonusai/data/speech_ma01_01.wav +0 -0
  20. sonusai/data/whitenoise.wav +0 -0
  21. sonusai/datatypes.py +383 -0
  22. sonusai/deprecated/gentcst.py +632 -0
  23. sonusai/deprecated/plot.py +519 -0
  24. sonusai/deprecated/tplot.py +365 -0
  25. sonusai/doc.py +52 -0
  26. sonusai/doc_strings/__init__.py +1 -0
  27. sonusai/doc_strings/doc_strings.py +531 -0
  28. sonusai/genft.py +196 -0
  29. sonusai/genmetrics.py +183 -0
  30. sonusai/genmix.py +199 -0
  31. sonusai/genmixdb.py +235 -0
  32. sonusai/ir_metric.py +551 -0
  33. sonusai/lsdb.py +141 -0
  34. sonusai/main.py +134 -0
  35. sonusai/metrics/__init__.py +43 -0
  36. sonusai/metrics/calc_audio_stats.py +42 -0
  37. sonusai/metrics/calc_class_weights.py +90 -0
  38. sonusai/metrics/calc_optimal_thresholds.py +73 -0
  39. sonusai/metrics/calc_pcm.py +45 -0
  40. sonusai/metrics/calc_pesq.py +36 -0
  41. sonusai/metrics/calc_phase_distance.py +43 -0
  42. sonusai/metrics/calc_sa_sdr.py +64 -0
  43. sonusai/metrics/calc_sample_weights.py +25 -0
  44. sonusai/metrics/calc_segsnr_f.py +82 -0
  45. sonusai/metrics/calc_speech.py +382 -0
  46. sonusai/metrics/calc_wer.py +71 -0
  47. sonusai/metrics/calc_wsdr.py +57 -0
  48. sonusai/metrics/calculate_metrics.py +395 -0
  49. sonusai/metrics/class_summary.py +74 -0
  50. sonusai/metrics/confusion_matrix_summary.py +75 -0
  51. sonusai/metrics/one_hot.py +283 -0
  52. sonusai/metrics/snr_summary.py +128 -0
  53. sonusai/metrics_summary.py +314 -0
  54. sonusai/mixture/__init__.py +15 -0
  55. sonusai/mixture/audio.py +187 -0
  56. sonusai/mixture/class_balancing.py +103 -0
  57. sonusai/mixture/constants.py +3 -0
  58. sonusai/mixture/data_io.py +173 -0
  59. sonusai/mixture/db.py +169 -0
  60. sonusai/mixture/db_datatypes.py +92 -0
  61. sonusai/mixture/effects.py +344 -0
  62. sonusai/mixture/feature.py +78 -0
  63. sonusai/mixture/generation.py +1116 -0
  64. sonusai/mixture/helpers.py +351 -0
  65. sonusai/mixture/ir_effects.py +77 -0
  66. sonusai/mixture/log_duration_and_sizes.py +23 -0
  67. sonusai/mixture/mixdb.py +1857 -0
  68. sonusai/mixture/pad_audio.py +35 -0
  69. sonusai/mixture/resample.py +7 -0
  70. sonusai/mixture/sox_effects.py +195 -0
  71. sonusai/mixture/sox_help.py +650 -0
  72. sonusai/mixture/spectral_mask.py +51 -0
  73. sonusai/mixture/truth.py +61 -0
  74. sonusai/mixture/truth_functions/__init__.py +45 -0
  75. sonusai/mixture/truth_functions/crm.py +105 -0
  76. sonusai/mixture/truth_functions/energy.py +222 -0
  77. sonusai/mixture/truth_functions/file.py +48 -0
  78. sonusai/mixture/truth_functions/metadata.py +24 -0
  79. sonusai/mixture/truth_functions/metrics.py +28 -0
  80. sonusai/mixture/truth_functions/phoneme.py +18 -0
  81. sonusai/mixture/truth_functions/sed.py +98 -0
  82. sonusai/mixture/truth_functions/target.py +142 -0
  83. sonusai/mkwav.py +135 -0
  84. sonusai/onnx_predict.py +363 -0
  85. sonusai/parse/__init__.py +0 -0
  86. sonusai/parse/expand.py +156 -0
  87. sonusai/parse/parse_source_directive.py +129 -0
  88. sonusai/parse/rand.py +214 -0
  89. sonusai/py.typed +0 -0
  90. sonusai/queries/__init__.py +0 -0
  91. sonusai/queries/queries.py +239 -0
  92. sonusai/rs.abi3.so +0 -0
  93. sonusai/rs.pyi +1 -0
  94. sonusai/rust/__init__.py +0 -0
  95. sonusai/speech/__init__.py +0 -0
  96. sonusai/speech/l2arctic.py +121 -0
  97. sonusai/speech/librispeech.py +102 -0
  98. sonusai/speech/mcgill.py +71 -0
  99. sonusai/speech/textgrid.py +89 -0
  100. sonusai/speech/timit.py +138 -0
  101. sonusai/speech/types.py +12 -0
  102. sonusai/speech/vctk.py +53 -0
  103. sonusai/speech/voxceleb.py +108 -0
  104. sonusai/utils/__init__.py +3 -0
  105. sonusai/utils/asl_p56.py +130 -0
  106. sonusai/utils/asr.py +91 -0
  107. sonusai/utils/asr_functions/__init__.py +3 -0
  108. sonusai/utils/asr_functions/aaware_whisper.py +69 -0
  109. sonusai/utils/audio_devices.py +50 -0
  110. sonusai/utils/braced_glob.py +50 -0
  111. sonusai/utils/calculate_input_shape.py +26 -0
  112. sonusai/utils/choice.py +51 -0
  113. sonusai/utils/compress.py +25 -0
  114. sonusai/utils/convert_string_to_number.py +6 -0
  115. sonusai/utils/create_timestamp.py +5 -0
  116. sonusai/utils/create_ts_name.py +14 -0
  117. sonusai/utils/dataclass_from_dict.py +27 -0
  118. sonusai/utils/db.py +16 -0
  119. sonusai/utils/docstring.py +53 -0
  120. sonusai/utils/energy_f.py +44 -0
  121. sonusai/utils/engineering_number.py +166 -0
  122. sonusai/utils/evaluate_random_rule.py +15 -0
  123. sonusai/utils/get_frames_per_batch.py +2 -0
  124. sonusai/utils/get_label_names.py +20 -0
  125. sonusai/utils/grouper.py +6 -0
  126. sonusai/utils/human_readable_size.py +7 -0
  127. sonusai/utils/keyboard_interrupt.py +12 -0
  128. sonusai/utils/load_object.py +21 -0
  129. sonusai/utils/max_text_width.py +9 -0
  130. sonusai/utils/model_utils.py +28 -0
  131. sonusai/utils/numeric_conversion.py +11 -0
  132. sonusai/utils/onnx_utils.py +155 -0
  133. sonusai/utils/parallel.py +162 -0
  134. sonusai/utils/path_info.py +7 -0
  135. sonusai/utils/print_mixture_details.py +60 -0
  136. sonusai/utils/rand.py +13 -0
  137. sonusai/utils/ranges.py +43 -0
  138. sonusai/utils/read_predict_data.py +32 -0
  139. sonusai/utils/reshape.py +154 -0
  140. sonusai/utils/seconds_to_hms.py +7 -0
  141. sonusai/utils/stacked_complex.py +82 -0
  142. sonusai/utils/stratified_shuffle_split.py +170 -0
  143. sonusai/utils/tokenized_shell_vars.py +143 -0
  144. sonusai/utils/write_audio.py +26 -0
  145. sonusai/utils/yes_or_no.py +8 -0
  146. sonusai/vars.py +47 -0
  147. sonusai-1.0.16.dist-info/METADATA +56 -0
  148. sonusai-1.0.16.dist-info/RECORD +150 -0
  149. sonusai-1.0.16.dist-info/WHEEL +4 -0
  150. sonusai-1.0.16.dist-info/entry_points.txt +3 -0
@@ -0,0 +1,395 @@
1
+ import functools
2
+ from typing import Any
3
+
4
+ import numpy as np
5
+ from pystoi import stoi
6
+
7
+ from ..constants import SAMPLE_RATE
8
+ from ..datatypes import AudioF
9
+ from ..datatypes import AudioStatsMetrics
10
+ from ..datatypes import AudioT
11
+ from ..datatypes import Segsnr
12
+ from ..datatypes import SpeechMetrics
13
+ from ..mixture.mixdb import MixtureDatabase
14
+ from ..utils.asr import calc_asr
15
+ from ..utils.db import linear_to_db
16
+ from .calc_audio_stats import calc_audio_stats
17
+ from .calc_pesq import calc_pesq
18
+ from .calc_phase_distance import calc_phase_distance
19
+ from .calc_segsnr_f import calc_segsnr_f
20
+ from .calc_segsnr_f import calc_segsnr_f_bin
21
+ from .calc_speech import calc_speech
22
+ from .calc_wer import calc_wer
23
+ from .calc_wsdr import calc_wsdr
24
+
25
+
26
def calculate_metrics(mixdb: MixtureDatabase, m_id: int, metrics: list[str], force: bool = False) -> dict[str, Any]:
    """Get metrics data for the given mixture ID

    Every expensive intermediate (mixture audio, spectra, statistics, PESQ, ASR
    transcripts, ...) is produced by an ``lru_cache``-wrapped helper that is local to
    this call. Requesting several metrics that share an input therefore computes that
    input only once.

    Unless ``force`` is set, a value already stored for a metric (as returned by
    ``mixdb.read_mixture_data``) is used in preference to recomputing it. The one
    exception is ``mxsnr``, which is always read from the mixture record.

    Requesting ``mxwer.<asr_name>`` also adds its ``mxasr.<asr_name>`` and
    ``sasr.<asr_name>`` dependencies to the result.

    :param mixdb: Mixture database object
    :param m_id: Zero-based mixture ID
    :param metrics: List of metrics to get
    :param force: Force computing data from original sources regardless of whether cached data exists
    :return: Dictionary of metric data
    :raises ValueError: If an ASR metric is not of the form '<metric>.<name>' or names an unknown ASR config
    :raises AttributeError: If a metric name is not recognized
    """

    # Define cached functions for expensive operations
    @functools.lru_cache(maxsize=1)
    def mixture_sources() -> dict[str, AudioT]:
        return mixdb.mixture_sources(m_id)

    @functools.lru_cache(maxsize=1)
    def mixture_source() -> AudioT:
        return mixdb.mixture_source(m_id)

    @functools.lru_cache(maxsize=1)
    def mixture_source_f() -> AudioF:
        return mixdb.mixture_source_f(m_id)

    @functools.lru_cache(maxsize=1)
    def mixture_noise() -> AudioT:
        return mixdb.mixture_noise(m_id)

    @functools.lru_cache(maxsize=1)
    def mixture_noise_f() -> AudioF:
        return mixdb.mixture_noise_f(m_id)

    @functools.lru_cache(maxsize=1)
    def mixture_mixture() -> AudioT:
        return mixdb.mixture_mixture(m_id)

    @functools.lru_cache(maxsize=1)
    def mixture_mixture_f() -> AudioF:
        return mixdb.mixture_mixture_f(m_id)

    @functools.lru_cache(maxsize=1)
    def mixture_segsnr() -> Segsnr:
        return mixdb.mixture_segsnr(m_id)

    # Segmental-SNR summaries are shared by several metrics; compute each once.
    @functools.lru_cache(maxsize=1)
    def segsnr_f() -> Any:
        return calc_segsnr_f(mixture_segsnr())

    @functools.lru_cache(maxsize=1)
    def segsnr_f_bin() -> Any:
        return calc_segsnr_f_bin(mixture_source_f(), mixture_noise_f())

    @functools.lru_cache(maxsize=1)
    def calculate_pesq() -> dict[str, float]:
        return {category: calc_pesq(mixture_mixture(), audio) for category, audio in mixture_sources().items()}

    @functools.lru_cache(maxsize=1)
    def calculate_speech() -> dict[str, SpeechMetrics]:
        return {
            category: calc_speech(mixture_mixture(), audio, calculate_pesq()[category])
            for category, audio in mixture_sources().items()
        }

    def audio_stats(audio: AudioT) -> AudioStatsMetrics:
        # The second argument is evaluated lazily so mixdb.fg_info is only touched when a
        # statistics metric is actually requested.
        return calc_audio_stats(audio, mixdb.fg_info.ft_config.length / SAMPLE_RATE)

    @functools.lru_cache(maxsize=1)
    def mixture_stats() -> AudioStatsMetrics:
        return audio_stats(mixture_mixture())

    @functools.lru_cache(maxsize=1)
    def sources_stats() -> dict[str, AudioStatsMetrics]:
        return {category: audio_stats(audio) for category, audio in mixture_sources().items()}

    @functools.lru_cache(maxsize=1)
    def source_stats() -> AudioStatsMetrics:
        return audio_stats(mixture_source())

    @functools.lru_cache(maxsize=1)
    def noise_stats() -> AudioStatsMetrics:
        return audio_stats(mixture_noise())

    # Cache ASR configurations
    @functools.lru_cache(maxsize=32)
    def get_asr_config(asr_name: str) -> dict:
        value = mixdb.asr_configs.get(asr_name, None)
        if value is None:
            raise ValueError(f"Unrecognized ASR name: '{asr_name}'")
        return value

    # Cache ASR results for sources, source and mixture
    @functools.lru_cache(maxsize=16)
    def sources_asr(asr_name: str) -> dict[str, str]:
        return {
            category: calc_asr(audio, **get_asr_config(asr_name)).text for category, audio in mixture_sources().items()
        }

    @functools.lru_cache(maxsize=16)
    def source_asr(asr_name: str) -> str:
        return calc_asr(mixture_source(), **get_asr_config(asr_name)).text

    @functools.lru_cache(maxsize=16)
    def mixture_asr(asr_name: str) -> str:
        return calc_asr(mixture_mixture(), **get_asr_config(asr_name)).text

    def get_asr_name(m: str) -> str:
        """Return the '<name>' part of a '<metric>.<name>' ASR metric."""
        parts = m.split(".")
        if len(parts) != 2:
            raise ValueError(f"Unrecognized format: '{m}'; must be of the form: '<metric>.<name>'")
        return parts[1]

    # AudioStatsMetrics attributes exposed as '<prefix><field>' metrics.
    stat_fields = frozenset(("dco", "min", "max", "pkdb", "lrms", "pkr", "tr", "cr", "fl", "pkc"))

    # Per-signal statistics prefixes: 'mxs' mixture source, 'mx' mixture, 'n' noise.
    # No field starts with 's', so 'mxs<field>' and 'mx<field>' can never collide.
    single_stats = (("mxs", source_stats), ("mx", mixture_stats), ("n", noise_stats))

    def calc(m: str) -> Any:
        if m == "mxsnr":
            return {category: source.snr for category, source in mixdb.mixture(m_id).all_sources.items()}

        # Get cached data first, if exists
        if not force:
            value = mixdb.read_mixture_data(m_id, m)[m]
            if value is not None:
                return value

        # Otherwise, generate data as needed
        if m.startswith("mxwer"):
            asr_name = get_asr_name(m)

            if mixdb.mixture(m_id).is_noise_only:
                # noise only, ignore/reset target asr
                return float("nan")

            if source_asr(asr_name):
                return calc_wer(mixture_asr(asr_name), source_asr(asr_name)).wer * 100

            # TODO: should this be NaN like above?
            return float(0)

        if m.startswith("basewer"):
            asr_name = get_asr_name(m)

            text = mixdb.mixture_speech_metadata(m_id, "text")
            return {
                category: calc_wer(source, str(text[category])).wer * 100 if isinstance(text[category], str) else 0
                for category, source in sources_asr(asr_name).items()
            }

        if m.startswith("mxasr"):
            return mixture_asr(get_asr_name(m))

        if m == "mxssnr_avg":
            return segsnr_f().avg

        if m == "mxssnr_std":
            return segsnr_f().std

        if m == "mxssnr_avg_db":
            val = segsnr_f().avg
            return linear_to_db(val) if val is not None else None

        if m == "mxssnr_std_db":
            val = segsnr_f().std
            return linear_to_db(val) if val is not None else None

        if m == "mxssnrdb_avg":
            return segsnr_f().db_avg

        if m == "mxssnrdb_std":
            return segsnr_f().db_std

        if m == "mxssnrf_avg":
            return segsnr_f_bin().avg

        if m == "mxssnrf_std":
            return segsnr_f_bin().std

        if m == "mxssnrdbf_avg":
            return segsnr_f_bin().db_avg

        if m == "mxssnrdbf_std":
            return segsnr_f_bin().db_std

        if m == "mxpesq":
            if mixdb.mixture(m_id).is_noise_only:
                # No target speech: report 0 for every source without running PESQ.
                return dict.fromkeys(mixture_sources(), 0)
            return calculate_pesq()

        if m in ("mxcsig", "mxcbak", "mxcovl"):
            if mixdb.mixture(m_id).is_noise_only:
                # No target speech: report 0 for every source without running the speech metrics.
                return dict.fromkeys(mixture_sources(), 0)
            field = m.removeprefix("mx")  # "csig" / "cbak" / "covl"
            return {category: getattr(s, field) for category, s in calculate_speech().items()}

        if m == "mxwsdr":
            mixture = mixture_mixture()[:, np.newaxis]
            target = mixture_source()[:, np.newaxis]
            noise = mixture_noise()[:, np.newaxis]
            return calc_wsdr(
                hypothesis=np.concatenate((mixture, noise), axis=1),
                reference=np.concatenate((target, noise), axis=1),
                with_log=True,
            )[0]

        if m == "mxpd":
            return calc_phase_distance(hypothesis=mixture_mixture_f(), reference=mixture_source_f())[0]

        if m == "mxstoi":
            return stoi(
                x=mixture_source(),
                y=mixture_mixture(),
                fs_sig=SAMPLE_RATE,
                extended=False,
            )

        # Single-signal audio statistics, e.g. 'mxpkdb', 'mxslrms', 'ncr'
        for prefix, stats in single_stats:
            if m.startswith(prefix) and m[len(prefix) :] in stat_fields:
                return getattr(stats(), m[len(prefix) :])

        # Per-source audio statistics, e.g. 'sdco', 'str' -> {category: value}
        if m.startswith("s") and m[1:] in stat_fields:
            return {category: getattr(s, m[1:]) for category, s in sources_stats().items()}

        if m.startswith("sasr"):
            return sources_asr(get_asr_name(m))

        if m.startswith("mxsasr"):
            return source_asr(get_asr_name(m))

        if m in ("sedavg", "sedcnt", "sedtopn"):
            return 0

        if m == "sedtop3":
            return np.zeros(3, dtype=np.float32)

        if m == "ssnr":
            return mixture_segsnr()

        raise AttributeError(f"Unrecognized metric: '{m}'")

    result: dict[str, Any] = {}
    for metric in metrics:
        result[metric] = calc(metric)

        # Check for metrics dependencies and add them even if not explicitly requested.
        if metric.startswith("mxwer"):
            asr_name = get_asr_name(metric)
            for dependency in (f"mxasr.{asr_name}", f"sasr.{asr_name}"):
                result[dependency] = calc(dependency)

    return result
@@ -0,0 +1,74 @@
1
+ # ruff: noqa: F821
2
+ import numpy as np
3
+ import pandas as pd
4
+
5
+ from ..datatypes import GeneralizedIDs
6
+ from ..datatypes import Predict
7
+ from ..datatypes import Truth
8
+ from ..mixture.mixdb import MixtureDatabase
9
+
10
+
11
def class_summary(
    mixdb: MixtureDatabase,
    mixids: GeneralizedIDs,
    truth_f: Truth,
    predict: Predict,
    predict_thr: float | np.ndarray = 0,
    truth_thr: float = 0.5,
    timesteps: int = 0,
) -> pd.DataFrame:
    """Build a per-class metrics table, plus averages, for a list of mixtures.

    Truth and prediction data are [features, num_classes]. The metrics come from
    ``one_hot``.

    When there is more than one class, ``predict_thr`` is normalized before the call:

    - a non-array threshold becomes a 1-element array, with 0 replaced by 0.5;
    - a 1-D array whose first element is 0 becomes ``[0.5]``;
    - any other array is passed through unchanged.

    Returned table:

    - columns: PPV, TPR, F1, FPR, ACC, AP, AUC, Support (Support is cast to int);
    - one row per class, labelled with ``mixdb.class_labels`` when it has one label
      per class, otherwise "Class 1" .. "Class N";
    - then the rows "Macro-avg", "Micro-avg" and "Weighted-avg".

    NOTE(review): ``get_mixids_data`` is not imported in this module (see the TODO and
    the ruff F821 suppression). Calling this function raises NameError until it is
    reworked for the modern mixdb API.
    """
    from ..metrics.one_hot import one_hot

    n_cls = truth_f.shape[1]

    # TODO: re-work for modern mixdb API
    truth_data, predict_data = get_mixids_data(mixdb, mixids, truth_f, predict)  # type: ignore[name-defined]

    if n_cls > 1:
        if isinstance(predict_thr, np.ndarray):
            # An explicit leading 0 in a 1-D array also means "use the default".
            if predict_thr.ndim == 1 and predict_thr[0] == 0:
                predict_thr = np.atleast_1d(0.5)
        else:
            predict_thr = np.atleast_1d(0.5 if predict_thr == 0 else predict_thr)

    _, per_class, _, _, _, averages = one_hot(truth_data, predict_data, predict_thr, truth_thr, timesteps)

    # Output column -> index into one_hot's per-class metric columns
    # [ACC, TPR, PPV, TNR, FPR, HITFA, F1, MCC, NT, PT, TP, FP, AP, AUC]; Support is PT.
    column_map = {"PPV": 2, "TPR": 1, "F1": 6, "FPR": 4, "ACC": 0, "AP": 12, "AUC": 13, "Support": 9}
    col_names = list(column_map)
    col_idx = np.array(list(column_map.values()))

    row_labels = (
        mixdb.class_labels if len(mixdb.class_labels) == n_cls else [f"Class {k + 1}" for k in range(n_cls)]
    )

    class_table = pd.DataFrame(per_class[:, col_idx], columns=col_names, index=row_labels)  # pyright: ignore [reportArgumentType]
    avg_table = pd.DataFrame(averages, columns=col_names, index=["Macro-avg", "Micro-avg", "Weighted-avg"])  # pyright: ignore [reportArgumentType]

    summary = pd.concat([class_table, avg_table])
    summary["Support"] = summary["Support"].astype(int)
    return summary
@@ -0,0 +1,75 @@
1
+ # ruff: noqa: F821
2
+ import numpy as np
3
+ import pandas as pd
4
+
5
+ from ..datatypes import GeneralizedIDs
6
+ from ..datatypes import Predict
7
+ from ..datatypes import Truth
8
+ from ..mixture.mixdb import MixtureDatabase
9
+
10
+
11
def confusion_matrix_summary(
    mixdb: MixtureDatabase,
    mixids: GeneralizedIDs,
    truth_f: Truth,
    predict: Predict,
    class_idx: int,
    predict_thr: float | np.ndarray = 0,
    truth_thr: float = 0.5,
    timesteps: int = 0,
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """Calculate confusion matrix for specified class, using truth and prediction
    data [features, num_classes].

    predict_thr sets the decision threshold(s) applied to predict data, thus allowing
    predict to be continuous probabilities.

    When there is more than one class, predict_thr is reduced to the single threshold
    used for class_idx:

    - a scalar, or a 1-D array whose length is not num_classes, uses its first value;
    - a 1-D array with one entry per class uses entry class_idx;
    - a selected value of 0 (the default) becomes 0.5;
    - an array with more than one dimension is indexed by class_idx.

    With a single class, predict_thr is passed to one_hot unchanged.

    A "p/t thr:" row holding the prediction and truth thresholds is appended to both
    tables.

    NOTE(review): get_mixids_data is not imported in this module (see the TODO and the
    ruff F821 suppression). This function raises NameError until it is reworked for
    the modern mixdb API.

    :return: (cmdf, cmndf) - confusion matrix and normalized confusion matrix for
        class_idx, with rows TrueN/TrueP (+ threshold row) and columns N-<name>/P-<name>
    """
    from ..metrics.one_hot import one_hot

    num_classes = truth_f.shape[1]
    # TODO: re-work for modern mixdb API
    ytrue, ypred = get_mixids_data(mixdb=mixdb, mixids=mixids, truth_f=truth_f, predict=predict)  # type: ignore[name-defined]

    # Resolve predict_thr to the threshold for class_idx
    if num_classes > 1:
        thr = np.asarray(predict_thr)
        if thr.ndim <= 1:
            flat = thr.ravel()
            # One threshold per class: pick this class's; otherwise the first applies to all.
            value = flat[class_idx] if flat.size == num_classes else flat[0]
            # 0 means "unset": force the 0.5 default
            predict_thr = 0.5 if value == 0 else float(value)
        else:
            predict_thr = thr[class_idx]

    if len(mixdb.class_labels) == num_classes:
        class_names = mixdb.class_labels
    else:
        class_names = [f"Class {i}" for i in range(1, num_classes + 1)]

    _, _, cm, cmn, _, _ = one_hot(ytrue[:, class_idx], ypred[:, class_idx], predict_thr, truth_thr, timesteps)
    cname = class_names[class_idx]
    row_n = ["TrueN", "TrueP"]
    col_n = ["N-" + cname, "P-" + cname]
    cmdf = pd.DataFrame(cm, index=row_n, columns=col_n, dtype=np.int32)  # pyright: ignore [reportArgumentType]
    cmndf = pd.DataFrame(cmn, index=row_n, columns=col_n, dtype=np.float32)  # pyright: ignore [reportArgumentType]

    # Add thresholds in 3rd row. predict_thr may still be an array (single-class or
    # multi-dimensional input); reduce it to a scalar so the row is not a ragged list,
    # which NumPy rejects.
    thr_note = float(np.ravel(predict_thr)[0])
    pdnote = pd.DataFrame(np.atleast_2d([thr_note, truth_thr]), index=["p/t thr:"], columns=col_n)  # pyright: ignore [reportArgumentType, reportCallIssue]
    cmdf = pd.concat([cmdf, pdnote])
    cmndf = pd.concat([cmndf, pdnote])

    return cmdf, cmndf