sonusai 0.19.9__py3-none-any.whl → 0.19.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/calc_metric_spenh.py +265 -233
- sonusai/data/silero_vad_v5.1.jit +0 -0
- sonusai/data/silero_vad_v5.1.onnx +0 -0
- sonusai/genft.py +1 -1
- sonusai/genmetrics.py +15 -18
- sonusai/genmix.py +1 -1
- sonusai/genmixdb.py +30 -52
- sonusai/metrics_summary.py +320 -0
- sonusai/mixture/__init__.py +2 -1
- sonusai/mixture/audio.py +40 -7
- sonusai/mixture/generation.py +42 -53
- sonusai/mixture/helpers.py +22 -7
- sonusai/mixture/mixdb.py +90 -30
- sonusai/mixture/truth_functions/energy.py +9 -5
- sonusai/mixture/truth_functions/metrics.py +1 -1
- sonusai/mkwav.py +1 -1
- sonusai/onnx_predict.py +1 -1
- sonusai/queries/queries.py +1 -1
- sonusai/utils/asr.py +1 -1
- sonusai/utils/load_object.py +8 -2
- sonusai/utils/stratified_shuffle_split.py +1 -1
- {sonusai-0.19.9.dist-info → sonusai-0.19.10.dist-info}/METADATA +1 -1
- {sonusai-0.19.9.dist-info → sonusai-0.19.10.dist-info}/RECORD +25 -22
- {sonusai-0.19.9.dist-info → sonusai-0.19.10.dist-info}/WHEEL +0 -0
- {sonusai-0.19.9.dist-info → sonusai-0.19.10.dist-info}/entry_points.txt +0 -0
sonusai/data/silero_vad_v5.1.jit
Binary file

sonusai/data/silero_vad_v5.1.onnx
Binary file
sonusai/genft.py
CHANGED
sonusai/genmetrics.py
CHANGED
@@ -1,6 +1,6 @@
 """sonusai genmetrics

-usage: genmetrics [-
+usage: genmetrics [-hvusd] [-i MIXID] [-n INCLUDE] [-x EXCLUDE] LOC

 options:
     -h, --help
@@ -8,6 +8,7 @@ options:
     -i MIXID, --mixid MIXID          Mixture ID(s) to generate. [default: *].
     -n INCLUDE, --include INCLUDE    Metrics to include. [default: all]
     -x EXCLUDE, --exclude EXCLUDE    Metrics to exclude. [default: none]
+    -u, --update                     Update metrics (do not regenerate existing metrics).
     -s, --supported                  Show list of supported metrics.
     -d, --dryrun                     Show list of metrics that will be generated and exit.

@@ -60,16 +61,15 @@ def signal_handler(_sig, _frame):
 signal.signal(signal.SIGINT, signal_handler)


-def _process_mixture(mixid: int, location: str, metrics: list[str]) ->
+def _process_mixture(mixid: int, location: str, metrics: list[str], update: bool = False) -> set[str]:
     from sonusai.mixture import MixtureDatabase
     from sonusai.mixture import write_cached_data

     mixdb = MixtureDatabase(location)
+    results = mixdb.mixture_metrics(m_id=mixid, metrics=metrics, force=not update)
+    write_cached_data(mixdb.location, "mixture", mixdb.mixture(mixid).name, list(results.items()))

-
-    write_data = list(zip(metrics, values, strict=False))
-
-    write_cached_data(mixdb.location, "mixture", mixdb.mixture(mixid).name, write_data)
+    return set(results.keys())


 def main() -> None:
@@ -85,6 +85,7 @@ def main() -> None:
     mixids = args["--mixid"]
     includes = {x.strip() for x in args["--include"].replace(" ", ",").lower().split(",") if x != ""}
     excludes = {x.strip() for x in args["--exclude"].replace(" ", ",").lower().split(",") if x != ""}
+    update = args["--update"]
     show_supported = args["--supported"]
     dryrun = args["--dryrun"]
     location = args["LOC"]
@@ -141,20 +142,14 @@ def main() -> None:

     requested = included_metrics - excluded_metrics

-
-    dependencies: set[str] = set()
-    for metric in requested:
-        if metric.startswith("mxwer"):
-            dependencies.add("mxasr." + metric[6:])
-            dependencies.add("tasr." + metric[6:])
-
-    metrics = sorted(requested | dependencies)
+    metrics = sorted(requested)

     if len(metrics) == 0:
         logger.warning("No metrics were requested")
         sys.exit(1)

-    logger.info(
+    logger.info("Generating metrics:")
+    logger.info(f"{', '.join(metrics)}")
     if dryrun:
         sys.exit(0)

@@ -163,14 +158,16 @@ def main() -> None:
     logger.info(f"Found {len(mixids):,} mixtures to process")

     progress = track(total=len(mixids), desc="genmetrics")
-    par_track(
-        partial(_process_mixture, location=location, metrics=metrics),
+    results = par_track(
+        partial(_process_mixture, location=location, metrics=metrics, update=update),
         mixids,
         progress=progress,
     )
     progress.close()

-
+    written_metrics = sorted(set().union(*results))
+    logger.info(f"Wrote metrics for {len(mixids)} mixtures to {location}:")
+    logger.info(f"{', '.join(written_metrics)}")
     logger.info("")

     end_time = time.monotonic()
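The new -u/--update flag feeds directly into the force argument of mixture_metrics, so previously cached metrics are kept rather than regenerated, and each worker now returns the set of metric names it actually wrote. Below is a minimal sketch of that per-mixture flow, using only the calls that appear in the diff above; the location and metric list are placeholders, and supported metric names can be listed with the -s/--supported option.

from sonusai.mixture import MixtureDatabase
from sonusai.mixture import write_cached_data

location = "./my_mixdb"  # placeholder: directory containing a SonusAI mixture database
metrics = ["mxsnr"]      # placeholder metric list; use `genmetrics -s` to see supported names
update = True            # new 0.19.10 behavior: keep existing cached metrics

mixdb = MixtureDatabase(location)
written: set[str] = set()
for m_id in mixdb.mixids_to_list("*"):
    # force=not update: with --update, metrics already in the cache are not recomputed
    results = mixdb.mixture_metrics(m_id=m_id, metrics=metrics, force=not update)
    write_cached_data(mixdb.location, "mixture", mixdb.mixture(m_id).name, list(results.items()))
    written |= set(results.keys())

print(f"Wrote metrics: {', '.join(sorted(written))}")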
sonusai/genmix.py
CHANGED
@@ -139,7 +139,7 @@ def _genmix_kernel(
     result.mixture = mixture
     if write:
         write_cached_data(mixdb.location, "mixture", mixdb.mixture(m_id).name, [("mixture", mixture)])
-        write_mixture_metadata(mixdb, m_id)
+        write_mixture_metadata(mixdb, m_id=m_id)

     return result

sonusai/genmixdb.py
CHANGED
@@ -1,13 +1,11 @@
 """sonusai genmixdb

-usage: genmixdb [-
+usage: genmixdb [-hvmdjn] LOC

 options:
     -h, --help
     -v, --verbose   Be verbose.
     -m, --mix       Save mixture data. [default: False].
-    -f, --ft        Save feature/truth_f data. [default: False].
-    -s, --segsnr    Save segsnr data. [default: False].
     -d, --dryrun    Perform a dry run showing the processed config. [default: False].
     -j, --json      Save JSON version of database. [default: False].
     -n, --nopar     Do not run in parallel. [default: False].
@@ -116,6 +114,9 @@ will find all .wav files in the specified directories and process them as target

 import signal

+from sonusai.mixture import Mixture
+from sonusai.mixture import MixtureDatabase
+

 def signal_handler(_sig, _frame):
     import sys
@@ -132,8 +133,6 @@ signal.signal(signal.SIGINT, signal_handler)
 def genmixdb(
     location: str,
     save_mix: bool = False,
-    save_ft: bool = False,
-    save_segsnr: bool = False,
     logging: bool = True,
     show_progress: bool = False,
     test: bool = False,
@@ -151,6 +150,7 @@ def genmixdb(
     from sonusai.mixture import AugmentationRule
     from sonusai.mixture import MixtureDatabase
     from sonusai.mixture import balance_targets
+    from sonusai.mixture import generate_mixtures
     from sonusai.mixture import get_all_snrs_from_config
     from sonusai.mixture import get_augmentation_rules
     from sonusai.mixture import get_augmented_targets
@@ -316,8 +316,10 @@ def genmixdb(
             f"{seconds_to_hms(seconds=noise_audio_duration)}"
         )

-
-
+    if logging:
+        logger.info("Generating mixtures")
+
+    used_noise_files, used_noise_samples, mixtures = generate_mixtures(
         noise_mix_mode=mixdb.noise_mix_mode,
         augmented_targets=augmented_targets,
         target_files=target_files,
@@ -330,17 +332,16 @@ def genmixdb(
         num_classes=mixdb.num_classes,
         feature_step_samples=mixdb.feature_step_samples,
         num_ir=mixdb.num_impulse_response_files,
-        test=test,
     )

-    num_mixtures = len(
+    num_mixtures = len(mixtures)
     update_mixid_width(location, num_mixtures, test)

     if logging:
         logger.info("")
         logger.info(f"Found {num_mixtures:,} mixtures to process")

-    total_duration = float(sum([mixture.samples for mixture in
+    total_duration = float(sum([mixture.samples for mixture in mixtures])) / SAMPLE_RATE

     if logging:
         log_duration_and_sizes(
@@ -362,23 +363,29 @@ def genmixdb(

     # Fill in the details
     if logging:
-        logger.info("
+        logger.info("Processing mixtures")
     progress = track(total=num_mixtures, disable=not show_progress)
-    par_track(
+    mixtures = par_track(
         partial(
             _process_mixture,
             location=location,
             save_mix=save_mix,
-            save_ft=save_ft,
-            save_segsnr=save_segsnr,
             test=test,
         ),
-
+        mixtures,
         progress=progress,
         no_par=no_par,
     )
     progress.close()

+    populate_mixture_table(
+        location=location,
+        mixtures=mixtures,
+        test=test,
+        logging=logging,
+        show_progress=show_progress,
+    )
+
     total_noise_files = len(noise_files)

     total_samples = mixdb.total_samples()
@@ -409,32 +416,23 @@ def genmixdb(


 def _process_mixture(
-
+    mixture: Mixture,
     location: str,
     save_mix: bool,
-    save_ft: bool,
-    save_segsnr: bool,
     test: bool,
-) ->
+) -> Mixture:
     from functools import partial

-    from sonusai.mixture import
-    from sonusai.mixture import clear_cached_data
-    from sonusai.mixture import update_mixture_table
+    from sonusai.mixture import update_mixture
     from sonusai.mixture import write_cached_data
     from sonusai.mixture import write_mixture_metadata

-
-
-    genmix_data = update_mixture_table(location, m_id, with_data, test)
-
-    mixdb = MixtureDatabase(location, test)
-    mixture = mixdb.mixture(m_id)
+    mixdb = MixtureDatabase(location, test=test)
+    mixture, genmix_data = update_mixture(mixdb, mixture, save_mix)

     write = partial(write_cached_data, location=location, name="mixture", index=mixture.name)
-    clear = partial(clear_cached_data, location=location, name="mixture", index=mixture.name)

-    if
+    if save_mix:
         write(
             items=[
                 ("targets", genmix_data.targets),
@@ -444,25 +442,9 @@ def _process_mixture(
             ]
         )

-
-        clear(items=["feature", "truth_f"])
-        feature, truth_f = mixdb.mixture_ft(m_id)
-        write(
-            items=[
-                ("feature", feature),
-                ("truth_f", truth_f),
-            ]
-        )
-
-    if save_segsnr:
-        clear(items=["segsnr"])
-        segsnr = mixdb.mixture_segsnr(m_id)
-        write(items=[("segsnr", segsnr)])
-
-    if not save_mix:
-        clear(items=["targets", "target", "noise", "mixture"])
+    write_mixture_metadata(mixdb, mixture=mixture)

-
+    return mixture


 def main() -> None:
@@ -491,8 +473,6 @@ def main() -> None:

     verbose = args["--verbose"]
     save_mix = args["--mix"]
-    save_ft = args["--ft"]
-    save_segsnr = args["--segsnr"]
     dryrun = args["--dryrun"]
     save_json = args["--json"]
     no_par = args["--nopar"]
@@ -522,8 +502,6 @@ def main() -> None:
     genmixdb(
         location=location,
         save_mix=save_mix,
-        save_ft=save_ft,
-        save_segsnr=save_segsnr,
         show_progress=True,
         save_json=save_json,
         no_par=no_par,
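With the -f/--ft and -s/--segsnr options removed, genmixdb now only builds the mixture database (optionally saving the mixture audio); feature/truth and segsnr data are left to downstream tools such as genft and genmetrics. A hedged sketch of the slimmed-down call, assuming the placeholder location contains a valid genmixdb configuration:

from sonusai.genmixdb import genmixdb

genmixdb(
    location="./my_mixdb",  # placeholder directory with a genmixdb config
    save_mix=True,          # still supported: write mixture data while building the database
    show_progress=True,
)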
sonusai/metrics_summary.py
ADDED
@@ -0,0 +1,320 @@
"""sonusai metrics_summary

usage: lsdb [-vlh] [-i MIXID] [-n NCPU] LOCATION

Options:
    -h, --help
    -v, --verbose
    -l, --write-list            Write .csv file list of all mixture metrics
    -i MIXID, --mixid MIXID     Mixture ID(s) to analyze. [default: *].
    -n, --num_process NCPU      Number of parallel processes to use [default: auto]

Summarize mixture metrics across a SonusAI mixture database where metrics have been generated by SonusAI genmetrics.

Inputs:
    LOCATION    A SonusAI mixture database directory with mixdb.db and pre-generated metrics from SonusAI genmetrics.

"""

import signal

import numpy as np
import pandas as pd


def signal_handler(_sig, _frame):
    import sys

    from sonusai import logger

    logger.info("Canceled due to keyboard interrupt")
    sys.exit(1)


signal.signal(signal.SIGINT, signal_handler)

DB_99 = np.power(10, 99 / 10)
DB_N99 = np.power(10, -99 / 10)


def _process_mixture(
    m_id: int,
    location: str,
    all_metric_names: list[str],
    scalar_metric_names: list[str],
    string_metric_names: list[str],
    frame_metric_names: list[str],
    bin_metric_names: list[str],
    ptab_labels: list[str],
) -> tuple[pd.DataFrame, pd.DataFrame]:
    from os.path import basename

    from sonusai.metrics import calc_wer
    from sonusai.mixture import SAMPLE_RATE
    from sonusai.mixture import MixtureDatabase

    mixdb = MixtureDatabase(location)

    # Process mixture
    # for mixid in mixids:
    samples = mixdb.mixture(m_id).samples
    duration = samples / SAMPLE_RATE
    tf_frames = mixdb.mixture_transform_frames(m_id)
    feat_frames = mixdb.mixture_feature_frames(m_id)
    mxsnr = mixdb.mixture(m_id).snr
    ti = mixdb.mixture(m_id).targets[0].file_id
    ni = mixdb.mixture(m_id).noise.file_id
    t0file = basename(mixdb.target_file(ti).name)
    nfile = basename(mixdb.noise_file(ni).name)

    all_metrics = mixdb.mixture_metrics(m_id, all_metric_names)

    # replace lists with first value (ignore mixup)
    scalar_metrics = {
        key: all_metrics[key][0] if isinstance(all_metrics[key], list) else all_metrics[key]
        for key in scalar_metric_names
    }
    string_metrics = {
        key: all_metrics[key][0] if isinstance(all_metrics[key], list) else all_metrics[key]
        for key in string_metric_names
    }

    # Convert strings into word count
    for key in string_metrics:
        string_metrics[key] = calc_wer(string_metrics[key], string_metrics[key]).words

    # Collect pandas table values note: must match given ptab_labels
    ptab_data: list = [
        mxsnr,
        *scalar_metrics.values(),
        *string_metrics.values(),
        tf_frames,
        duration,
        t0file,
        nfile,
    ]

    ptab1 = pd.DataFrame([ptab_data], columns=ptab_labels, index=[m_id])

    # TODO: collect frame metrics and bin metrics

    return ptab1, ptab1


def main() -> None:
    from docopt import docopt

    from sonusai import __version__ as sonusai_ver
    from sonusai.utils import trim_docstring

    args = docopt(trim_docstring(__doc__), version=sonusai_ver, options_first=True)

    verbose = args["--verbose"]
    wrlist = args["--write-list"]
    mixids = args["--mixid"]
    location = args["LOCATION"]
    num_proc = args["--num_process"]

    from functools import partial
    from os.path import basename
    from os.path import join

    import psutil

    from sonusai import create_file_handler
    from sonusai import initial_log_messages
    from sonusai import logger
    from sonusai import update_console_handler
    from sonusai.mixture import MixtureDatabase
    from sonusai.utils import create_timestamp
    from sonusai.utils import par_track
    from sonusai.utils import track

    try:
        mixdb = MixtureDatabase(location)
        print(f"Found SonusAI mixture database with {mixdb.num_mixtures} mixtures.")
    except:
        print(f"Could not open SonusAI mixture database in {location}, exiting ...")
        return

    metrics_present = mixdb.cached_metrics()
    num_metrics_present = len(metrics_present)
    if num_metrics_present < 1:
        print(f"mixdb reports no pre-generated metrics are present. Nothing to summarize in {location}, exiting ...")
        return

    # Setup logging file
    timestamp = create_timestamp()  # string good for embedding into filenames
    mixdb_fname = basename(location)
    if verbose:
        create_file_handler(join(location, "metrics_summary.log"))
        update_console_handler(verbose)
        initial_log_messages("metrics_summary")
        logger.info(f"Logging summary of SonusAI mixture db at {location}")
    else:
        update_console_handler(verbose)

    logger.info("")
    mixids = mixdb.mixids_to_list(mixids)
    if len(mixids) < mixdb.num_mixtures:
        logger.info(
            f"Processing a subset of {len(mixids)} out of total mixdb mixtures of {mixdb.num_mixtures}, "
            f"summary results will not include entire dataset."
        )
        fsuffix = f"_s{len(mixids)}t{mixdb.num_mixtures}"
    else:
        logger.info(
            f"Summarizing SonusAI mixture db with {mixdb.num_mixtures} mixtures "
            f"and {num_metrics_present} pre-generated metrics ..."
        )
        fsuffix = ""

    metric_sup = mixdb.supported_metrics
    ft_bins = mixdb.ft_config.bin_end - mixdb.ft_config.bin_start + 1  # bins of forward transform
    # Pre-process first mixid to gather metrics into 4 types: scalar, str (scalar word cnt), frame-array, bin-array
    # Collect list of indices for each
    scalar_metric_names: list[str] = []
    string_metric_names: list[str] = []
    frame_metric_names: list[str] = []
    bin_metric_names: list[str] = []
    all_metrics = mixdb.mixture_metrics(mixids[0], metrics_present)
    tf_frames = mixdb.mixture_transform_frames(mixids[0])
    for metric in metrics_present:
        metval = all_metrics[metric]  # get metric value
        logger.debug(f"First mixid {mixids[0]} metric {metric} = {metval}")
        if isinstance(metval, list):
            if len(metval) > 1:
                logger.warning(f"Mixid {mixids[0]} metric {metric} has a list with more than 1 element, using first.")
            metval = metval[0]  # remove any list
        if isinstance(metval, float):
            logger.debug("Metric is scalar float, entering in summary table.")
            scalar_metric_names.append(metric)
        elif isinstance(metval, str):
            logger.debug("Metric is string, will summarize with word count.")
            string_metric_names.append(metric)
        elif isinstance(metval, np.ndarray):
            if metval.ndim == 1:
                if metval.size == tf_frames:
                    logger.debug("Metric is frames vector.")
                    frame_metric_names.append(metric)
                elif metval.size == ft_bins:
                    logger.debug("Metric is bins vector.")
                    bin_metric_names.append(metric)
                else:
                    logger.warning(f"Mixid {mixids[0]} metric {metric} is a vector of improper size, ignoring.")

    # Setup pandas table for summarizing scalar metrics
    ptab_labels = [
        "mxsnr",
        *scalar_metric_names,
        *string_metric_names,
        "fcnt",
        "duration",
        "t0file",
        "nfile",
    ]

    num_cpu = psutil.cpu_count()
    cpu_percent = psutil.cpu_percent(interval=1)
    logger.info("")
    logger.info(f"#CPUs: {num_cpu}, current CPU utilization: {cpu_percent}%")
    logger.info(f"Memory utilization: {psutil.virtual_memory().percent}%")
    if num_proc == "auto":
        use_cpu = int(num_cpu * (0.9 - cpu_percent / 100))  # default use 80% of available cpus
    elif num_proc == "None":
        use_cpu = None
    else:
        use_cpu = min(max(int(num_proc), 1), num_cpu)

    logger.info(f"Summarizing metrics for {len(mixids)} mixtures using {use_cpu} parallel processes")

    # progress = tqdm(total=len(mixids), desc='calc_metric_spenh', mininterval=1)
    progress = track(total=len(mixids))
    if use_cpu is None:
        no_par = True
        num_cpus = None
    else:
        no_par = False
        num_cpus = use_cpu

    all_metrics_tables = par_track(
        partial(
            _process_mixture,
            location=location,
            all_metric_names=metrics_present,
            scalar_metric_names=scalar_metric_names,
            string_metric_names=string_metric_names,
            frame_metric_names=frame_metric_names,
            bin_metric_names=bin_metric_names,
            ptab_labels=ptab_labels,
        ),
        mixids,
        progress=progress,
        num_cpus=num_cpus,
        no_par=no_par,
    )
    progress.close()

    # Done with mixtures, write out summary metrics
    header_args = {
        "mode": "a",
        "encoding": "utf-8",
        "index": False,
        "header": False,
    }
    table_args = {
        "mode": "a",
        "encoding": "utf-8",
    }
    ptab1 = pd.concat([item[0] for item in all_metrics_tables])
    if wrlist:
        wlcsv_name = str(join(location, "metric_summary_list" + fsuffix + ".csv"))
        pd.DataFrame([["Timestamp", timestamp]]).to_csv(wlcsv_name, header=False, index=False)
        pd.DataFrame([f"Metric list for {mixdb_fname}:"]).to_csv(wlcsv_name, mode="a", header=False, index=False)
        ptab1.round(2).to_csv(wlcsv_name, **table_args)
    ptab1_sorted = ptab1.sort_values(by=["mxsnr", "t0file"])

    # Create metrics table except except -99 SNR
    ptab1_nom99 = ptab1_sorted[ptab1_sorted.mxsnr != -99]

    # Create summary by SNR for all scalar metrics, taking mean
    mtab_snr_summary = None
    for snri in range(0, len(mixdb.snrs)):
        tmp = ptab1_sorted.query("mxsnr==" + str(mixdb.snrs[snri])).mean(numeric_only=True).to_frame().T
        # avoid nan when subset of mixids specified (i.e. no mixtures exist for an SNR)
        if ~np.isnan(tmp.iloc[0].to_numpy()[0]).any():
            mtab_snr_summary = pd.concat([mtab_snr_summary, tmp])
    mtab_snr_summary = mtab_snr_summary.sort_values(by=["mxsnr"], ascending=False)

    # Write summary to .csv
    snrcsv_name = str(join(location, "metric_summary_snr" + fsuffix + ".csv"))
    nmix = len(mixids)
    nmixtot = mixdb.num_mixtures
    pd.DataFrame([["Timestamp", timestamp]]).to_csv(snrcsv_name, header=False, index=False)
    pd.DataFrame(['"Metrics avg over each SNR:"']).to_csv(snrcsv_name, **header_args)
    mtab_snr_summary.round(2).to_csv(snrcsv_name, index=False, **table_args)
    pd.DataFrame(["--"]).to_csv(snrcsv_name, header=False, index=False, mode="a")
    pd.DataFrame([f'"Metrics stats over {nmix} mixtures out of {nmixtot} total:"']).to_csv(snrcsv_name, **header_args)
    ptab1.describe().round(2).T.to_csv(snrcsv_name, index=True, **table_args)
    pd.DataFrame(["--"]).to_csv(snrcsv_name, header=False, index=False, mode="a")
    pd.DataFrame([f'"Metrics stats over {len(ptab1_nom99)} non -99db mixtures out of {nmixtot} total:"']).to_csv(
        snrcsv_name, **header_args
    )
    ptab1_nom99.describe().round(2).T.to_csv(snrcsv_name, index=True, **table_args)

    # Write summary to .csv
    snrtxt_name = str(join(location, "metric_summary_snr" + fsuffix + ".txt"))
    with open(snrtxt_name, "w") as f:
        print(f"Timestamp: {timestamp}", file=f)
        print("Metrics avg over each SNR:", file=f)
        print(mtab_snr_summary.round(2).to_string(float_format=lambda x: f"{x:.2f}", index=False), file=f)
        print("", file=f)
        print(f"Metrics stats over {len(mixids)} mixtures out of {mixdb.num_mixtures} total:", file=f)
        print(ptab1.describe().round(2).T.to_string(float_format=lambda x: f"{x:.2f}", index=True), file=f)
        print("", file=f)
        print(f"Metrics stats over {len(ptab1_nom99)} non -99db mixtures out of {mixdb.num_mixtures} total:", file=f)
        print(ptab1_nom99.describe().round(2).T.to_string(float_format=lambda x: f"{x:.2f}", index=True), file=f)


if __name__ == "__main__":
    main()
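The summarization itself is plain pandas: _process_mixture returns one row of scalar metrics per mixture, the rows are concatenated into ptab1, then averaged per SNR and described with and without the -99 dB mixtures. A self-contained illustration with synthetic data follows; the column names other than mxsnr are placeholders, not real SonusAI metric names.

import pandas as pd

# Synthetic stand-in for ptab1: one row per mixture, scalar metrics as columns.
ptab1 = pd.DataFrame(
    {
        "mxsnr": [-99, 0, 0, 10, 10],
        "metric_a": [0.10, 0.52, 0.48, 0.81, 0.79],  # placeholder scalar metric
        "duration": [4.0, 4.0, 3.5, 4.0, 3.5],
    }
)

# Per-SNR means, highest SNR first (equivalent to the query()-per-SNR loop above).
mtab_snr_summary = (
    ptab1.groupby("mxsnr", as_index=False)
    .mean(numeric_only=True)
    .sort_values("mxsnr", ascending=False)
)
print(mtab_snr_summary.round(2))

# Overall stats excluding the -99 dB mixtures (mirrors ptab1_nom99.describe()).
print(ptab1[ptab1.mxsnr != -99].describe().round(2).T)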
sonusai/mixture/__init__.py
CHANGED
@@ -87,6 +87,7 @@ from .datatypes import TruthParameter
 from .datatypes import UniversalSNR
 from .feature import get_audio_from_feature
 from .feature import get_feature_from_audio
+from .generation import generate_mixtures
 from .generation import get_all_snrs_from_config
 from .generation import initialize_db
 from .generation import populate_class_label_table
@@ -99,7 +100,7 @@ from .generation import populate_target_file_table
 from .generation import populate_top_table
 from .generation import populate_truth_parameters_table
 from .generation import update_mixid_width
-from .generation import
+from .generation import update_mixture
 from .helpers import augmented_noise_samples
 from .helpers import augmented_target_samples
 from .helpers import check_audio_files_exist
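For downstream code, the visible API change here is in the package-level re-exports: generate_mixtures is newly exported, and update_mixture appears to replace the removed update_mixture_table export used by the old genmixdb. Only importability is shown below; the signatures are as used in the genmixdb diff above.

from sonusai.mixture import generate_mixtures  # new export in 0.19.10
from sonusai.mixture import update_mixture     # appears to replace update_mixture_table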