sonusai 0.19.9__py3-none-any.whl → 0.19.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Binary files changed
sonusai/genft.py CHANGED
@@ -109,7 +109,7 @@ def _genft_kernel(
             write_cached_data(mixdb.location, "mixture", mixdb.mixture(m_id).name, [("segsnr", segsnr)])

     if write:
-        write_mixture_metadata(mixdb, m_id)
+        write_mixture_metadata(mixdb, m_id=m_id)

     return result
sonusai/genmetrics.py CHANGED
@@ -1,6 +1,6 @@
 """sonusai genmetrics

-usage: genmetrics [-hvsd] [-i MIXID] [-n INCLUDE] [-x EXCLUDE] LOC
+usage: genmetrics [-hvusd] [-i MIXID] [-n INCLUDE] [-x EXCLUDE] LOC

 options:
     -h, --help
@@ -8,6 +8,7 @@ options:
     -i MIXID, --mixid MIXID         Mixture ID(s) to generate. [default: *].
     -n INCLUDE, --include INCLUDE   Metrics to include. [default: all]
     -x EXCLUDE, --exclude EXCLUDE   Metrics to exclude. [default: none]
+    -u, --update                    Update metrics (do not regenerate existing metrics).
     -s, --supported                 Show list of supported metrics.
     -d, --dryrun                    Show list of metrics that will be generated and exit.
@@ -60,16 +61,15 @@ def signal_handler(_sig, _frame):
 signal.signal(signal.SIGINT, signal_handler)


-def _process_mixture(mixid: int, location: str, metrics: list[str]) -> None:
+def _process_mixture(mixid: int, location: str, metrics: list[str], update: bool = False) -> set[str]:
     from sonusai.mixture import MixtureDatabase
     from sonusai.mixture import write_cached_data

     mixdb = MixtureDatabase(location)
+    results = mixdb.mixture_metrics(m_id=mixid, metrics=metrics, force=not update)
+    write_cached_data(mixdb.location, "mixture", mixdb.mixture(mixid).name, list(results.items()))

-    values = mixdb.mixture_metrics(m_id=mixid, metrics=metrics, force=True)
-    write_data = list(zip(metrics, values, strict=False))
-
-    write_cached_data(mixdb.location, "mixture", mixdb.mixture(mixid).name, write_data)
+    return set(results.keys())


 def main() -> None:
@@ -85,6 +85,7 @@ def main() -> None:
     mixids = args["--mixid"]
     includes = {x.strip() for x in args["--include"].replace(" ", ",").lower().split(",") if x != ""}
     excludes = {x.strip() for x in args["--exclude"].replace(" ", ",").lower().split(",") if x != ""}
+    update = args["--update"]
     show_supported = args["--supported"]
     dryrun = args["--dryrun"]
     location = args["LOC"]
@@ -141,20 +142,14 @@ def main() -> None:

     requested = included_metrics - excluded_metrics

-    # Check for metrics dependencies and cache dependencies even if not explicitly requested.
-    dependencies: set[str] = set()
-    for metric in requested:
-        if metric.startswith("mxwer"):
-            dependencies.add("mxasr." + metric[6:])
-            dependencies.add("tasr." + metric[6:])
-
-    metrics = sorted(requested | dependencies)
+    metrics = sorted(requested)

     if len(metrics) == 0:
         logger.warning("No metrics were requested")
         sys.exit(1)

-    logger.info(f"Generating metrics: {', '.join(metrics)}")
+    logger.info("Generating metrics:")
+    logger.info(f"{', '.join(metrics)}")
     if dryrun:
         sys.exit(0)
@@ -163,14 +158,16 @@ def main() -> None:
     logger.info(f"Found {len(mixids):,} mixtures to process")

     progress = track(total=len(mixids), desc="genmetrics")
-    par_track(
-        partial(_process_mixture, location=location, metrics=metrics),
+    results = par_track(
+        partial(_process_mixture, location=location, metrics=metrics, update=update),
         mixids,
         progress=progress,
     )
     progress.close()

-    logger.info(f"Wrote metrics for {len(mixids)} mixtures to {location}")
+    written_metrics = sorted(set().union(*results))
+    logger.info(f"Wrote metrics for {len(mixids)} mixtures to {location}:")
+    logger.info(f"{', '.join(written_metrics)}")
     logger.info("")

     end_time = time.monotonic()
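
Taken together, these changes make MixtureDatabase.mixture_metrics return a dictionary keyed by metric name, and the new -u/--update flag maps to force=False, so metrics already cached on disk are reused rather than regenerated. A minimal sketch of the per-mixture flow, assuming an existing database directory (the ./mixdb path and the use of cached_metrics() to pick the metric list are illustrative):

    from sonusai.mixture import MixtureDatabase
    from sonusai.mixture import write_cached_data

    mixdb = MixtureDatabase("./mixdb")  # illustrative location

    # force=False corresponds to --update: previously cached metrics are not recomputed
    results = mixdb.mixture_metrics(m_id=0, metrics=mixdb.cached_metrics(), force=False)

    # results maps metric name -> value, so it can be written to the cache directly
    write_cached_data(mixdb.location, "mixture", mixdb.mixture(0).name, list(results.items()))
    print(sorted(results.keys()))  # the per-mixture set genmetrics now reports as written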
sonusai/genmix.py CHANGED
@@ -139,7 +139,7 @@ def _genmix_kernel(
     result.mixture = mixture
     if write:
         write_cached_data(mixdb.location, "mixture", mixdb.mixture(m_id).name, [("mixture", mixture)])
-        write_mixture_metadata(mixdb, m_id)
+        write_mixture_metadata(mixdb, m_id=m_id)

     return result
sonusai/genmixdb.py CHANGED
@@ -1,13 +1,11 @@
 """sonusai genmixdb

-usage: genmixdb [-hvmfsdjn] LOC
+usage: genmixdb [-hvmdjn] LOC

 options:
     -h, --help
     -v, --verbose   Be verbose.
     -m, --mix       Save mixture data. [default: False].
-    -f, --ft        Save feature/truth_f data. [default: False].
-    -s, --segsnr    Save segsnr data. [default: False].
     -d, --dryrun    Perform a dry run showing the processed config. [default: False].
     -j, --json      Save JSON version of database. [default: False].
     -n, --nopar     Do not run in parallel. [default: False].
@@ -116,6 +114,9 @@ will find all .wav files in the specified directories and process them as target

 import signal

+from sonusai.mixture import Mixture
+from sonusai.mixture import MixtureDatabase
+

 def signal_handler(_sig, _frame):
     import sys
@@ -132,8 +133,6 @@ signal.signal(signal.SIGINT, signal_handler)
 def genmixdb(
     location: str,
     save_mix: bool = False,
-    save_ft: bool = False,
-    save_segsnr: bool = False,
     logging: bool = True,
     show_progress: bool = False,
     test: bool = False,
@@ -151,6 +150,7 @@ def genmixdb(
     from sonusai.mixture import AugmentationRule
     from sonusai.mixture import MixtureDatabase
     from sonusai.mixture import balance_targets
+    from sonusai.mixture import generate_mixtures
     from sonusai.mixture import get_all_snrs_from_config
     from sonusai.mixture import get_augmentation_rules
     from sonusai.mixture import get_augmented_targets
@@ -316,8 +316,10 @@ def genmixdb(
             f"{seconds_to_hms(seconds=noise_audio_duration)}"
         )

-    used_noise_files, used_noise_samples = populate_mixture_table(
-        location=location,
+    if logging:
+        logger.info("Generating mixtures")
+
+    used_noise_files, used_noise_samples, mixtures = generate_mixtures(
         noise_mix_mode=mixdb.noise_mix_mode,
         augmented_targets=augmented_targets,
         target_files=target_files,
@@ -330,17 +332,16 @@ def genmixdb(
         num_classes=mixdb.num_classes,
         feature_step_samples=mixdb.feature_step_samples,
         num_ir=mixdb.num_impulse_response_files,
-        test=test,
     )

-    num_mixtures = len(mixdb.mixtures)
+    num_mixtures = len(mixtures)
     update_mixid_width(location, num_mixtures, test)

     if logging:
         logger.info("")
         logger.info(f"Found {num_mixtures:,} mixtures to process")

-    total_duration = float(sum([mixture.samples for mixture in mixdb.mixtures])) / SAMPLE_RATE
+    total_duration = float(sum([mixture.samples for mixture in mixtures])) / SAMPLE_RATE

     if logging:
         log_duration_and_sizes(
@@ -362,23 +363,29 @@ def genmixdb(

     # Fill in the details
     if logging:
-        logger.info("Generating mixtures")
+        logger.info("Processing mixtures")
     progress = track(total=num_mixtures, disable=not show_progress)
-    par_track(
+    mixtures = par_track(
         partial(
             _process_mixture,
             location=location,
             save_mix=save_mix,
-            save_ft=save_ft,
-            save_segsnr=save_segsnr,
             test=test,
         ),
-        range(num_mixtures),
+        mixtures,
         progress=progress,
         no_par=no_par,
     )
     progress.close()

+    populate_mixture_table(
+        location=location,
+        mixtures=mixtures,
+        test=test,
+        logging=logging,
+        show_progress=show_progress,
+    )
+
     total_noise_files = len(noise_files)

     total_samples = mixdb.total_samples()
@@ -409,32 +416,23 @@ def genmixdb(


 def _process_mixture(
-    m_id: int,
+    mixture: Mixture,
     location: str,
     save_mix: bool,
-    save_ft: bool,
-    save_segsnr: bool,
     test: bool,
-) -> None:
+) -> Mixture:
     from functools import partial

-    from sonusai.mixture import MixtureDatabase
-    from sonusai.mixture import clear_cached_data
-    from sonusai.mixture import update_mixture_table
+    from sonusai.mixture import update_mixture
     from sonusai.mixture import write_cached_data
     from sonusai.mixture import write_mixture_metadata

-    with_data = save_mix or save_ft or save_segsnr
-
-    genmix_data = update_mixture_table(location, m_id, with_data, test)
-
-    mixdb = MixtureDatabase(location, test)
-    mixture = mixdb.mixture(m_id)
+    mixdb = MixtureDatabase(location, test=test)
+    mixture, genmix_data = update_mixture(mixdb, mixture, save_mix)

     write = partial(write_cached_data, location=location, name="mixture", index=mixture.name)
-    clear = partial(clear_cached_data, location=location, name="mixture", index=mixture.name)

-    if with_data:
+    if save_mix:
         write(
             items=[
                 ("targets", genmix_data.targets),
@@ -444,25 +442,9 @@ def _process_mixture(
             ]
         )

-    if save_ft:
-        clear(items=["feature", "truth_f"])
-        feature, truth_f = mixdb.mixture_ft(m_id)
-        write(
-            items=[
-                ("feature", feature),
-                ("truth_f", truth_f),
-            ]
-        )
-
-    if save_segsnr:
-        clear(items=["segsnr"])
-        segsnr = mixdb.mixture_segsnr(m_id)
-        write(items=[("segsnr", segsnr)])
-
-    if not save_mix:
-        clear(items=["targets", "target", "noise", "mixture"])
+    write_mixture_metadata(mixdb, mixture=mixture)

-    write_mixture_metadata(mixdb, m_id)
+    return mixture


 def main() -> None:
@@ -491,8 +473,6 @@ def main() -> None:

     verbose = args["--verbose"]
     save_mix = args["--mix"]
-    save_ft = args["--ft"]
-    save_segsnr = args["--segsnr"]
     dryrun = args["--dryrun"]
     save_json = args["--json"]
     no_par = args["--nopar"]
@@ -522,8 +502,6 @@ def main() -> None:
     genmixdb(
         location=location,
         save_mix=save_mix,
-        save_ft=save_ft,
-        save_segsnr=save_segsnr,
         show_progress=True,
         save_json=save_json,
         no_par=no_par,
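
The net effect in genmixdb is a three-stage flow: generate_mixtures builds the list of Mixture records up front, the parallel workers fill in each record (writing audio data only when --mix is given), and populate_mixture_table commits the completed records to the database afterwards. A condensed sketch of the per-mixture step under the new API, assuming an already-initialized database at ./mixdb (the location and the loop over existing mixture IDs are illustrative):

    from sonusai.mixture import MixtureDatabase
    from sonusai.mixture import update_mixture
    from sonusai.mixture import write_mixture_metadata

    mixdb = MixtureDatabase("./mixdb")  # illustrative location

    for m_id in mixdb.mixids_to_list("*"):
        mixture = mixdb.mixture(m_id)
        # update_mixture fills in the Mixture record and returns it together with
        # the generated audio bundle; the third argument mirrors save_mix above.
        mixture, _genmix_data = update_mixture(mixdb, mixture, False)
        # Metadata is now written from the Mixture object rather than a database id.
        write_mixture_metadata(mixdb, mixture=mixture)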
sonusai/metrics_summary.py ADDED
@@ -0,0 +1,320 @@
+"""sonusai metrics_summary
+
+usage: metrics_summary [-vlh] [-i MIXID] [-n NCPU] LOCATION
+
+Options:
+    -h, --help
+    -v, --verbose
+    -l, --write-list                Write .csv file list of all mixture metrics
+    -i MIXID, --mixid MIXID         Mixture ID(s) to analyze. [default: *].
+    -n, --num_process NCPU          Number of parallel processes to use [default: auto]
+
+Summarize mixture metrics across a SonusAI mixture database where metrics have been generated by SonusAI genmetrics.
+
+Inputs:
+    LOCATION    A SonusAI mixture database directory with mixdb.db and pre-generated metrics from SonusAI genmetrics.
+
+"""
+
+import signal
+
+import numpy as np
+import pandas as pd
+
+
+def signal_handler(_sig, _frame):
+    import sys
+
+    from sonusai import logger
+
+    logger.info("Canceled due to keyboard interrupt")
+    sys.exit(1)
+
+
+signal.signal(signal.SIGINT, signal_handler)
+
+DB_99 = np.power(10, 99 / 10)
+DB_N99 = np.power(10, -99 / 10)
+
+
+def _process_mixture(
+    m_id: int,
+    location: str,
+    all_metric_names: list[str],
+    scalar_metric_names: list[str],
+    string_metric_names: list[str],
+    frame_metric_names: list[str],
+    bin_metric_names: list[str],
+    ptab_labels: list[str],
+) -> tuple[pd.DataFrame, pd.DataFrame]:
+    from os.path import basename
+
+    from sonusai.metrics import calc_wer
+    from sonusai.mixture import SAMPLE_RATE
+    from sonusai.mixture import MixtureDatabase
+
+    mixdb = MixtureDatabase(location)
+
+    # Process mixture
+    # for mixid in mixids:
+    samples = mixdb.mixture(m_id).samples
+    duration = samples / SAMPLE_RATE
+    tf_frames = mixdb.mixture_transform_frames(m_id)
+    feat_frames = mixdb.mixture_feature_frames(m_id)
+    mxsnr = mixdb.mixture(m_id).snr
+    ti = mixdb.mixture(m_id).targets[0].file_id
+    ni = mixdb.mixture(m_id).noise.file_id
+    t0file = basename(mixdb.target_file(ti).name)
+    nfile = basename(mixdb.noise_file(ni).name)
+
+    all_metrics = mixdb.mixture_metrics(m_id, all_metric_names)
+
+    # replace lists with first value (ignore mixup)
+    scalar_metrics = {
+        key: all_metrics[key][0] if isinstance(all_metrics[key], list) else all_metrics[key]
+        for key in scalar_metric_names
+    }
+    string_metrics = {
+        key: all_metrics[key][0] if isinstance(all_metrics[key], list) else all_metrics[key]
+        for key in string_metric_names
+    }
+
+    # Convert strings into word count
+    for key in string_metrics:
+        string_metrics[key] = calc_wer(string_metrics[key], string_metrics[key]).words
+
+    # Collect pandas table values note: must match given ptab_labels
+    ptab_data: list = [
+        mxsnr,
+        *scalar_metrics.values(),
+        *string_metrics.values(),
+        tf_frames,
+        duration,
+        t0file,
+        nfile,
+    ]
+
+    ptab1 = pd.DataFrame([ptab_data], columns=ptab_labels, index=[m_id])
+
+    # TODO: collect frame metrics and bin metrics
+
+    return ptab1, ptab1
+
+
+def main() -> None:
+    from docopt import docopt
+
+    from sonusai import __version__ as sonusai_ver
+    from sonusai.utils import trim_docstring
+
+    args = docopt(trim_docstring(__doc__), version=sonusai_ver, options_first=True)
+
+    verbose = args["--verbose"]
+    wrlist = args["--write-list"]
+    mixids = args["--mixid"]
+    location = args["LOCATION"]
+    num_proc = args["--num_process"]
+
+    from functools import partial
+    from os.path import basename
+    from os.path import join
+
+    import psutil
+
+    from sonusai import create_file_handler
+    from sonusai import initial_log_messages
+    from sonusai import logger
+    from sonusai import update_console_handler
+    from sonusai.mixture import MixtureDatabase
+    from sonusai.utils import create_timestamp
+    from sonusai.utils import par_track
+    from sonusai.utils import track
+
+    try:
+        mixdb = MixtureDatabase(location)
+        print(f"Found SonusAI mixture database with {mixdb.num_mixtures} mixtures.")
+    except:
+        print(f"Could not open SonusAI mixture database in {location}, exiting ...")
+        return
+
+    metrics_present = mixdb.cached_metrics()
+    num_metrics_present = len(metrics_present)
+    if num_metrics_present < 1:
+        print(f"mixdb reports no pre-generated metrics are present. Nothing to summarize in {location}, exiting ...")
+        return
+
+    # Setup logging file
+    timestamp = create_timestamp()  # string good for embedding into filenames
+    mixdb_fname = basename(location)
+    if verbose:
+        create_file_handler(join(location, "metrics_summary.log"))
+        update_console_handler(verbose)
+        initial_log_messages("metrics_summary")
+        logger.info(f"Logging summary of SonusAI mixture db at {location}")
+    else:
+        update_console_handler(verbose)
+
+    logger.info("")
+    mixids = mixdb.mixids_to_list(mixids)
+    if len(mixids) < mixdb.num_mixtures:
+        logger.info(
+            f"Processing a subset of {len(mixids)} out of total mixdb mixtures of {mixdb.num_mixtures}, "
+            f"summary results will not include entire dataset."
+        )
+        fsuffix = f"_s{len(mixids)}t{mixdb.num_mixtures}"
+    else:
+        logger.info(
+            f"Summarizing SonusAI mixture db with {mixdb.num_mixtures} mixtures "
+            f"and {num_metrics_present} pre-generated metrics ..."
+        )
+        fsuffix = ""
+
+    metric_sup = mixdb.supported_metrics
+    ft_bins = mixdb.ft_config.bin_end - mixdb.ft_config.bin_start + 1  # bins of forward transform
+    # Pre-process first mixid to gather metrics into 4 types: scalar, str (scalar word cnt), frame-array, bin-array
+    # Collect list of indices for each
+    scalar_metric_names: list[str] = []
+    string_metric_names: list[str] = []
+    frame_metric_names: list[str] = []
+    bin_metric_names: list[str] = []
+    all_metrics = mixdb.mixture_metrics(mixids[0], metrics_present)
+    tf_frames = mixdb.mixture_transform_frames(mixids[0])
+    for metric in metrics_present:
+        metval = all_metrics[metric]  # get metric value
+        logger.debug(f"First mixid {mixids[0]} metric {metric} = {metval}")
+        if isinstance(metval, list):
+            if len(metval) > 1:
+                logger.warning(f"Mixid {mixids[0]} metric {metric} has a list with more than 1 element, using first.")
+            metval = metval[0]  # remove any list
+        if isinstance(metval, float):
+            logger.debug("Metric is scalar float, entering in summary table.")
+            scalar_metric_names.append(metric)
+        elif isinstance(metval, str):
+            logger.debug("Metric is string, will summarize with word count.")
+            string_metric_names.append(metric)
+        elif isinstance(metval, np.ndarray):
+            if metval.ndim == 1:
+                if metval.size == tf_frames:
+                    logger.debug("Metric is frames vector.")
+                    frame_metric_names.append(metric)
+                elif metval.size == ft_bins:
+                    logger.debug("Metric is bins vector.")
+                    bin_metric_names.append(metric)
+                else:
+                    logger.warning(f"Mixid {mixids[0]} metric {metric} is a vector of improper size, ignoring.")
+
+    # Setup pandas table for summarizing scalar metrics
+    ptab_labels = [
+        "mxsnr",
+        *scalar_metric_names,
+        *string_metric_names,
+        "fcnt",
+        "duration",
+        "t0file",
+        "nfile",
+    ]
+
+    num_cpu = psutil.cpu_count()
+    cpu_percent = psutil.cpu_percent(interval=1)
+    logger.info("")
+    logger.info(f"#CPUs: {num_cpu}, current CPU utilization: {cpu_percent}%")
+    logger.info(f"Memory utilization: {psutil.virtual_memory().percent}%")
+    if num_proc == "auto":
+        use_cpu = int(num_cpu * (0.9 - cpu_percent / 100))  # default use 80% of available cpus
+    elif num_proc == "None":
+        use_cpu = None
+    else:
+        use_cpu = min(max(int(num_proc), 1), num_cpu)
+
+    logger.info(f"Summarizing metrics for {len(mixids)} mixtures using {use_cpu} parallel processes")
+
+    # progress = tqdm(total=len(mixids), desc='calc_metric_spenh', mininterval=1)
+    progress = track(total=len(mixids))
+    if use_cpu is None:
+        no_par = True
+        num_cpus = None
+    else:
+        no_par = False
+        num_cpus = use_cpu
+
+    all_metrics_tables = par_track(
+        partial(
+            _process_mixture,
+            location=location,
+            all_metric_names=metrics_present,
+            scalar_metric_names=scalar_metric_names,
+            string_metric_names=string_metric_names,
+            frame_metric_names=frame_metric_names,
+            bin_metric_names=bin_metric_names,
+            ptab_labels=ptab_labels,
+        ),
+        mixids,
+        progress=progress,
+        num_cpus=num_cpus,
+        no_par=no_par,
+    )
+    progress.close()
+
+    # Done with mixtures, write out summary metrics
+    header_args = {
+        "mode": "a",
+        "encoding": "utf-8",
+        "index": False,
+        "header": False,
+    }
+    table_args = {
+        "mode": "a",
+        "encoding": "utf-8",
+    }
+    ptab1 = pd.concat([item[0] for item in all_metrics_tables])
+    if wrlist:
+        wlcsv_name = str(join(location, "metric_summary_list" + fsuffix + ".csv"))
+        pd.DataFrame([["Timestamp", timestamp]]).to_csv(wlcsv_name, header=False, index=False)
+        pd.DataFrame([f"Metric list for {mixdb_fname}:"]).to_csv(wlcsv_name, mode="a", header=False, index=False)
+        ptab1.round(2).to_csv(wlcsv_name, **table_args)
+    ptab1_sorted = ptab1.sort_values(by=["mxsnr", "t0file"])
+
+    # Create metrics table except -99 SNR
+    ptab1_nom99 = ptab1_sorted[ptab1_sorted.mxsnr != -99]
+
+    # Create summary by SNR for all scalar metrics, taking mean
+    mtab_snr_summary = None
+    for snri in range(0, len(mixdb.snrs)):
+        tmp = ptab1_sorted.query("mxsnr==" + str(mixdb.snrs[snri])).mean(numeric_only=True).to_frame().T
+        # avoid nan when subset of mixids specified (i.e. no mixtures exist for an SNR)
+        if ~np.isnan(tmp.iloc[0].to_numpy()[0]).any():
+            mtab_snr_summary = pd.concat([mtab_snr_summary, tmp])
+    mtab_snr_summary = mtab_snr_summary.sort_values(by=["mxsnr"], ascending=False)
+
+    # Write summary to .csv
+    snrcsv_name = str(join(location, "metric_summary_snr" + fsuffix + ".csv"))
+    nmix = len(mixids)
+    nmixtot = mixdb.num_mixtures
+    pd.DataFrame([["Timestamp", timestamp]]).to_csv(snrcsv_name, header=False, index=False)
+    pd.DataFrame(['"Metrics avg over each SNR:"']).to_csv(snrcsv_name, **header_args)
+    mtab_snr_summary.round(2).to_csv(snrcsv_name, index=False, **table_args)
+    pd.DataFrame(["--"]).to_csv(snrcsv_name, header=False, index=False, mode="a")
+    pd.DataFrame([f'"Metrics stats over {nmix} mixtures out of {nmixtot} total:"']).to_csv(snrcsv_name, **header_args)
+    ptab1.describe().round(2).T.to_csv(snrcsv_name, index=True, **table_args)
+    pd.DataFrame(["--"]).to_csv(snrcsv_name, header=False, index=False, mode="a")
+    pd.DataFrame([f'"Metrics stats over {len(ptab1_nom99)} non -99db mixtures out of {nmixtot} total:"']).to_csv(
+        snrcsv_name, **header_args
+    )
+    ptab1_nom99.describe().round(2).T.to_csv(snrcsv_name, index=True, **table_args)
+
+    # Write summary to .txt
+    snrtxt_name = str(join(location, "metric_summary_snr" + fsuffix + ".txt"))
+    with open(snrtxt_name, "w") as f:
+        print(f"Timestamp: {timestamp}", file=f)
+        print("Metrics avg over each SNR:", file=f)
+        print(mtab_snr_summary.round(2).to_string(float_format=lambda x: f"{x:.2f}", index=False), file=f)
+        print("", file=f)
+        print(f"Metrics stats over {len(mixids)} mixtures out of {mixdb.num_mixtures} total:", file=f)
+        print(ptab1.describe().round(2).T.to_string(float_format=lambda x: f"{x:.2f}", index=True), file=f)
+        print("", file=f)
+        print(f"Metrics stats over {len(ptab1_nom99)} non -99db mixtures out of {mixdb.num_mixtures} total:", file=f)
+        print(ptab1_nom99.describe().round(2).T.to_string(float_format=lambda x: f"{x:.2f}", index=True), file=f)
+
+
+if __name__ == "__main__":
+    main()
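
The new tool only summarizes metrics that genmetrics has already cached, so it exits early when none are present. A small pre-flight check along the same lines, assuming a database at ./mixdb (the path is illustrative):

    from sonusai.mixture import MixtureDatabase

    mixdb = MixtureDatabase("./mixdb")  # illustrative location

    # metrics_summary bails out when cached_metrics() comes back empty
    cached = mixdb.cached_metrics()
    if not cached:
        print("No pre-generated metrics found; run genmetrics on this location first.")
    else:
        print(f"{len(cached)} cached metrics over {mixdb.num_mixtures} mixtures: {', '.join(sorted(cached))}")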
sonusai/mixture/__init__.py CHANGED
@@ -87,6 +87,7 @@ from .datatypes import TruthParameter
 from .datatypes import UniversalSNR
 from .feature import get_audio_from_feature
 from .feature import get_feature_from_audio
+from .generation import generate_mixtures
 from .generation import get_all_snrs_from_config
 from .generation import initialize_db
 from .generation import populate_class_label_table
@@ -99,7 +100,7 @@ from .generation import populate_target_file_table
 from .generation import populate_top_table
 from .generation import populate_truth_parameters_table
 from .generation import update_mixid_width
-from .generation import update_mixture_table
+from .generation import update_mixture
 from .helpers import augmented_noise_samples
 from .helpers import augmented_target_samples
 from .helpers import check_audio_files_exist