sonusai 1.0.7__py3-none-any.whl → 1.0.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +8 -7
- sonusai/audiofe.py +1 -1
- sonusai/calc_metric_spenh.py +9 -9
- sonusai/doc/doc.py +64 -61
- sonusai/genft.py +1 -1
- sonusai/genmetrics.py +3 -3
- sonusai/genmix.py +1 -1
- sonusai/genmixdb.py +1 -1
- sonusai/lsdb.py +1 -1
- sonusai/metrics_summary.py +7 -8
- sonusai/mixture/__init__.py +1 -1
- sonusai/mixture/db.py +163 -0
- sonusai/mixture/db_file.py +10 -0
- sonusai/mixture/effects.py +19 -52
- sonusai/mixture/generation.py +331 -391
- sonusai/mixture/mixdb.py +11 -68
- sonusai/mkwav.py +1 -1
- sonusai/onnx_predict.py +1 -1
- sonusai/utils/numeric_conversion.py +2 -2
- sonusai/utils/print_mixture_details.py +24 -28
- {sonusai-1.0.7.dist-info → sonusai-1.0.9.dist-info}/METADATA +2 -1
- {sonusai-1.0.7.dist-info → sonusai-1.0.9.dist-info}/RECORD +24 -22
- {sonusai-1.0.7.dist-info → sonusai-1.0.9.dist-info}/WHEEL +0 -0
- {sonusai-1.0.7.dist-info → sonusai-1.0.9.dist-info}/entry_points.txt +0 -0
sonusai/__init__.py
CHANGED
@@ -36,7 +36,7 @@ logger_db = logging.getLogger("sonusai_db")
|
|
36
36
|
logger_db.setLevel(logging.DEBUG)
|
37
37
|
|
38
38
|
# create file handler
|
39
|
-
def create_file_handler(filename: str) -> None:
|
39
|
+
def create_file_handler(filename: str, verbose: bool = False) -> None:
|
40
40
|
from pathlib import Path
|
41
41
|
|
42
42
|
fh = logging.FileHandler(filename=filename, mode="w")
|
@@ -44,12 +44,13 @@ def create_file_handler(filename: str) -> None:
|
|
44
44
|
fh.setFormatter(formatter)
|
45
45
|
logger.addHandler(fh)
|
46
46
|
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
47
|
+
if verbose:
|
48
|
+
filename_db = Path(filename)
|
49
|
+
filename_db = filename_db.parent / (filename_db.stem + "_dbtrace" + filename_db.suffix)
|
50
|
+
fh = logging.FileHandler(filename=filename_db, mode="w")
|
51
|
+
fh.setLevel(logging.DEBUG)
|
52
|
+
fh.setFormatter(formatter_db)
|
53
|
+
logger_db.addHandler(fh)
|
53
54
|
|
54
55
|
|
55
56
|
# update console handler
|
sonusai/audiofe.py
CHANGED
sonusai/calc_metric_spenh.py
CHANGED
@@ -531,10 +531,10 @@ def _process_mixture(
|
|
531
531
|
pesq_speech = calc_pesq(target_est_wav, target_fi)
|
532
532
|
csig_tg, cbak_tg, covl_tg = calc_speech(target_est_wav, target_fi, pesq=pesq_speech)
|
533
533
|
metrics = mixdb.mixture_metrics(m_id, ["mxpesq", "mxcsig", "mxcbak", "mxcovl"])
|
534
|
-
pesq_mx = metrics["mxpesq"][
|
535
|
-
csig_mx = metrics["mxcsig"][
|
536
|
-
cbak_mx = metrics["mxcbak"][
|
537
|
-
covl_mx = metrics["mxcovl"][
|
534
|
+
pesq_mx = metrics["mxpesq"]["primary"] if isinstance(metrics["mxpesq"], dict) else metrics["mxpesq"]
|
535
|
+
csig_mx = metrics["mxcsig"]["primary"] if isinstance(metrics["mxcsig"], dict) else metrics["mxcsig"]
|
536
|
+
cbak_mx = metrics["mxcbak"]["primary"] if isinstance(metrics["mxcbak"], dict) else metrics["mxcbak"]
|
537
|
+
covl_mx = metrics["mxcovl"]["primary"] if isinstance(metrics["mxcovl"], dict) else metrics["mxcovl"]
|
538
538
|
# pesq_speech_tst = calc_pesq(hypothesis=target_est_wav, reference=target)
|
539
539
|
# pesq_mixture_tst = calc_pesq(hypothesis=mixture, reference=target)
|
540
540
|
# pesq improvement
|
@@ -560,11 +560,11 @@ def _process_mixture(
|
|
560
560
|
if asr_method is not None and mixdb.mixture(m_id).noise.snr >= -96: # noise only, ignore/reset target ASR
|
561
561
|
asr_mx_name = f"mxasr.{asr_method}"
|
562
562
|
wer_mx_name = f"mxwer.{asr_method}"
|
563
|
-
asr_tt_name = f"
|
563
|
+
asr_tt_name = f"sasr.{asr_method}"
|
564
564
|
metrics = mixdb.mixture_metrics(m_id, [asr_mx_name, wer_mx_name, asr_tt_name])
|
565
|
-
asr_mx = metrics[asr_mx_name][
|
566
|
-
wer_mx = metrics[wer_mx_name][
|
567
|
-
asr_tt = metrics[asr_tt_name][
|
565
|
+
asr_mx = metrics[asr_mx_name]["primary"] if isinstance(metrics[asr_mx_name], dict) else metrics[asr_mx_name]
|
566
|
+
wer_mx = metrics[wer_mx_name]["primary"] if isinstance(metrics[wer_mx_name], dict) else metrics[wer_mx_name]
|
567
|
+
asr_tt = metrics[asr_tt_name]["primary"] if isinstance(metrics[asr_tt_name], dict) else metrics[asr_tt_name]
|
568
568
|
|
569
569
|
if asr_tt:
|
570
570
|
noiseadd = None # TBD add as switch, default -30
|
@@ -849,7 +849,7 @@ def main():
|
|
849
849
|
logger.info(f"Found predict log {basename(predict_logfile[0])} in predict location.")
|
850
850
|
|
851
851
|
# Setup logging file
|
852
|
-
create_file_handler(join(predict_location, "calc_metric_spenh.log"))
|
852
|
+
create_file_handler(join(predict_location, "calc_metric_spenh.log"), verbose)
|
853
853
|
update_console_handler(verbose)
|
854
854
|
initial_log_messages("calc_metric_spenh")
|
855
855
|
|
sonusai/doc/doc.py
CHANGED
@@ -21,12 +21,12 @@ to use.
|
|
21
21
|
# fmt: on
|
22
22
|
|
23
23
|
|
24
|
-
def
|
25
|
-
default = f"\nDefault value: {get_default_config()['
|
24
|
+
def doc_level_type() -> str:
|
25
|
+
default = f"\nDefault value: {get_default_config()['level_type']}"
|
26
26
|
# fmt: off
|
27
27
|
return """
|
28
|
-
'
|
29
|
-
algorithm to use to determine
|
28
|
+
'level_type' is a mixture database configuration parameter that sets the
|
29
|
+
algorithm to use to determine energy level for SNR calculations.
|
30
30
|
Supported values are:
|
31
31
|
|
32
32
|
default mean of squares
|
@@ -35,31 +35,31 @@ Supported values are:
|
|
35
35
|
# fmt: on
|
36
36
|
|
37
37
|
|
38
|
-
def
|
39
|
-
default = f"\nDefault value: {get_default_config()['
|
38
|
+
def doc_sources() -> str:
|
39
|
+
default = f"\nDefault value: {get_default_config()['sources']}"
|
40
40
|
# fmt: off
|
41
41
|
return """
|
42
|
-
'
|
43
|
-
|
42
|
+
'sources' is a mixture database configuration parameter that sets the list of
|
43
|
+
sources to use.
|
44
44
|
|
45
|
-
|
46
|
-
|
47
|
-
'name'
|
48
|
-
File name. May be one of the following:
|
45
|
+
Two sources are required: 'primary' and 'noise'. Additional sources may be
|
46
|
+
specified with arbitrary names.
|
49
47
|
|
50
|
-
|
51
|
-
glob Matches file glob patterns
|
52
|
-
.yml The given YAML file is parsed into the list
|
53
|
-
.txt Each line in the given text file indicates an item which
|
54
|
-
may be anything in this list (audio, glob, .yml, or .txt)
|
48
|
+
Each source has the following fields:
|
55
49
|
|
56
|
-
|
50
|
+
'files' Required list of files to use. Sub-fields:
|
51
|
+
'name' File name. May be one of the following:
|
52
|
+
audio Supported formats are .wav, .mp3, .m4a, .aif, .flac, and .ogg
|
53
|
+
glob Matches file glob patterns
|
54
|
+
.yml The given YAML file is parsed into the list
|
55
|
+
.txt Each line in the given text file indicates an item which
|
56
|
+
may be anything in this list (audio, glob, .yml, or .txt)
|
57
|
+
'class_indices' Optional list of class indices
|
57
58
|
|
58
|
-
'truth_configs'
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
'<param2>' Target-specific override for truth configuration parameter
|
59
|
+
'truth_configs' Required list of truth config(s) to use for this source. Sub-fields:
|
60
|
+
'<name>' Name of truth config. Sub-fields:
|
61
|
+
'function' Truth function
|
62
|
+
'stride_reduction' Stride reduction method to use. May be one of: none, max
|
63
63
|
|
64
64
|
'target_level_type'
|
65
65
|
Target-specific override for target_level_type.
|
@@ -72,7 +72,7 @@ targets:
|
|
72
72
|
sed:
|
73
73
|
thresholds: [-38, -41, -48]
|
74
74
|
index: 2
|
75
|
-
|
75
|
+
class_balancing_effect: { }
|
76
76
|
- name: target.mp3
|
77
77
|
truth_configs:
|
78
78
|
sed:
|
@@ -176,27 +176,27 @@ The 'truth_configs' parameter specifies the following:
|
|
176
176
|
Name of stride reduction method to use
|
177
177
|
'<param1>' Function-specific configuration parameter
|
178
178
|
'<paramN>' Function-specific configuration parameter
|
179
|
-
'
|
180
|
-
Class balancing
|
179
|
+
'class_balancing_effect'
|
180
|
+
Class balancing effect.
|
181
181
|
This truth configuration will use this rule for class balancing operations.
|
182
182
|
If this rule is empty or unspecified, then this truth function will not
|
183
183
|
perform class balancing.
|
184
184
|
|
185
185
|
Class balancing ensures that each class in a sound classification dataset
|
186
186
|
is represented equally (i.e., each class has the same number of augmented
|
187
|
-
targets). This is achieved by creating new class balancing
|
187
|
+
targets). This is achieved by creating new class balancing effect
|
188
188
|
rules and applying them to targets in underrepresented classes to create
|
189
|
-
more
|
189
|
+
more effected targets for those classes.
|
190
190
|
|
191
191
|
This rule must contain at least one random entry in order to guarantee
|
192
192
|
unique additional data.
|
193
193
|
|
194
|
-
See '
|
194
|
+
See 'effects' for details on effect rules.
|
195
195
|
""" + get_truth_functions() + default
|
196
196
|
# fmt: on
|
197
197
|
|
198
198
|
|
199
|
-
def
|
199
|
+
def doc_effects() -> str:
|
200
200
|
# fmt: off
|
201
201
|
return """
|
202
202
|
Augmentation Rules
|
@@ -211,7 +211,7 @@ the list.
|
|
211
211
|
If a value is specified using rand, then a randomized rule is generated
|
212
212
|
dynamically per use.
|
213
213
|
|
214
|
-
Rules may specify any or all of the following
|
214
|
+
Rules may specify any or all of the following effects:
|
215
215
|
|
216
216
|
'normalize' Normalize audio file to the specified level (in dBFS).
|
217
217
|
'gain' Apply an amplification or an attenuation to the audio signal.
|
@@ -237,16 +237,19 @@ Rules may specify any or all of the following augmentations:
|
|
237
237
|
'impulse_responses' parameter).
|
238
238
|
For targets, the impulse response is applied AFTER truth generation
|
239
239
|
and the resulting audio is still aligned with the truth. Random
|
240
|
-
syntax for 'ir' is
|
240
|
+
syntax for 'ir' is one of the following:
|
241
|
+
'sai_choose()' chooses a random IR from the entire list
|
242
|
+
'sai_choose(<min>, <max>)' chooses a random IR in the range <min> to <max>
|
243
|
+
'sai_choose(<tag>)' chooses a random IR that matches <tag>
|
241
244
|
|
242
|
-
Only the specified
|
245
|
+
Only the specified effects for a given rule are applied; all others are
|
243
246
|
skipped in the given rule. For example, if a rule only specifies 'tempo',
|
244
|
-
then only a tempo
|
247
|
+
then only a tempo effect is applied and all other possible effects
|
245
248
|
are ignored (e.g., 'gain', 'pitch', etc.).
|
246
249
|
|
247
250
|
Example:
|
248
251
|
|
249
|
-
|
252
|
+
target_effects:
|
250
253
|
- normalize: -3.5
|
251
254
|
- normalize: -3.5
|
252
255
|
pitch: [-300, 300]
|
@@ -264,7 +267,7 @@ There are four rules given in this example.
|
|
264
267
|
The first rule is simple:
|
265
268
|
- normalize: -3.5
|
266
269
|
|
267
|
-
This results in just one
|
270
|
+
This results in just one effect being applied to each target:
|
268
271
|
|
269
272
|
normalize: -3.5
|
270
273
|
|
@@ -275,7 +278,7 @@ The second rule illustrates the use of lists to specify values:
|
|
275
278
|
eq1: [[1000, 0.8, 3], [600, 1.0, -4], [800, 0.6, 0]]
|
276
279
|
|
277
280
|
There are two values given for pitch, two for tempo, and three for EQ. This
|
278
|
-
rule expands to 2 * 2 * 3 = 12 unique
|
281
|
+
rule expands to 2 * 2 * 3 = 12 unique effects being applied to each
|
279
282
|
target:
|
280
283
|
|
281
284
|
normalize: -3.5, pitch: -3, tempo: 0.8, eq1: [1000, 0.8, 3]
|
@@ -297,13 +300,13 @@ The third rule shows the use of rand:
|
|
297
300
|
eq1: ["rand(100, 6000)", "rand(0.6, 1.0)", "rand(-6, 6)"]
|
298
301
|
lpf: "rand(1000, 8000)"
|
299
302
|
|
300
|
-
This rule is used to create randomized
|
303
|
+
This rule is used to create randomized effects per use.
|
301
304
|
|
302
305
|
The fourth rule demonstrates the use of scalars, lists, and rand:
|
303
306
|
- tempo: [0.9, 1, 1.1]
|
304
307
|
eq1: [["rand(100, 7500)", 0.8, -10], ["rand(100, 7500)", 0.8, 10]]
|
305
308
|
|
306
|
-
This rule expands to 6 unique
|
309
|
+
This rule expands to 6 unique effects being applied to each target
|
307
310
|
(list of 3 * list of 2). Here is the expansion:
|
308
311
|
|
309
312
|
tempo: 0.9, eq1: ["rand(100, 7500)", 0.8, -10]
|
@@ -315,16 +318,16 @@ This rule expands to 6 unique augmentations being applied to each target
|
|
315
318
|
# fmt: on
|
316
319
|
|
317
320
|
|
318
|
-
def
|
321
|
+
def doc_target_effects() -> str:
|
319
322
|
import yaml
|
320
323
|
|
321
|
-
default = f"\nDefault value:\n\n{yaml.dump(get_default_config()['
|
324
|
+
default = f"\nDefault value:\n\n{yaml.dump(get_default_config()['target_effects'])}"
|
322
325
|
# fmt: off
|
323
326
|
return """
|
324
|
-
'
|
325
|
-
specifies a list of
|
327
|
+
'target_effects' is a mixture database configuration parameter that
|
328
|
+
specifies a list of effect rules to use for each target.
|
326
329
|
|
327
|
-
See '
|
330
|
+
See 'effects' for details on effect rules.
|
328
331
|
""" + default
|
329
332
|
# fmt: on
|
330
333
|
|
@@ -338,7 +341,7 @@ def doc_target_distortions() -> str:
|
|
338
341
|
'target_distortions' is a mixture database configuration parameter that
|
339
342
|
specifies a list of distortion rules to use for each target.
|
340
343
|
|
341
|
-
See '
|
344
|
+
See 'effects' for details on distortion rules.
|
342
345
|
""" + default
|
343
346
|
# fmt: on
|
344
347
|
|
@@ -364,17 +367,17 @@ Required field:
|
|
364
367
|
# fmt: on
|
365
368
|
|
366
369
|
|
367
|
-
def
|
370
|
+
def doc_noise_effects() -> str:
|
368
371
|
import yaml
|
369
372
|
|
370
|
-
default = f"\nDefault value:\n\n{yaml.dump(get_default_config()['
|
373
|
+
default = f"\nDefault value:\n\n{yaml.dump(get_default_config()['noise_effects'])}"
|
371
374
|
|
372
375
|
# fmt: off
|
373
376
|
return """
|
374
|
-
'
|
375
|
-
specifies a list of
|
377
|
+
'noise_effects' is a mixture database configuration parameter that
|
378
|
+
specifies a list of effect rules to use for each noise.
|
376
379
|
|
377
|
-
See '
|
380
|
+
See 'effects' for details on effect rules.
|
378
381
|
""" + default
|
379
382
|
# fmt: on
|
380
383
|
|
@@ -386,7 +389,7 @@ def doc_snrs() -> str:
|
|
386
389
|
'snrs' is a mixture database configuration parameter that specifies a list
|
387
390
|
of required signal-to-noise ratios (in dB).
|
388
391
|
|
389
|
-
All other
|
392
|
+
All other effects are applied to both target and noise and then the
|
390
393
|
energy levels are measured and the appropriate noise gain calculated to
|
391
394
|
achieve the desired SNR.
|
392
395
|
|
@@ -407,7 +410,7 @@ list of random signal-to-noise ratios. The value(s) must be specified as
|
|
407
410
|
random using the syntax: 'rand(<min>, <max>)'.
|
408
411
|
|
409
412
|
Random SNRs behave slightly differently from regular or ordered SNRs. As with
|
410
|
-
ordered SNRs, all other
|
413
|
+
ordered SNRs, all other effects are applied to both target and noise and
|
411
414
|
then the energy levels are measured and the appropriate noise gain calculated
|
412
415
|
to achieve the desired SNR. However, unlike ordered SNRs, the desired SNR is
|
413
416
|
randomized (per the given rule(s)) for each mixture, i.e., previous random
|
@@ -425,14 +428,14 @@ how to mix noises with targets.
|
|
425
428
|
|
426
429
|
Supported modes:
|
427
430
|
|
428
|
-
exhaustive Use every noise/
|
429
|
-
non-exhaustive Cycle through every
|
430
|
-
using all noise/
|
431
|
-
non-combinatorial Combine a
|
432
|
-
noise/
|
433
|
-
does not use each noise/
|
431
|
+
exhaustive Use every noise/effect with every primary/effect.
|
432
|
+
non-exhaustive Cycle through every primary/effect without necessarily
|
433
|
+
using all noise/effect combinations (reduced data set).
|
434
|
+
non-combinatorial Combine a primary/effect with a single cut of a
|
435
|
+
noise/effect non-exhaustively (each primary/effect
|
436
|
+
does not use each noise/effect). Cut has a random start
|
434
437
|
and loops back to the beginning if the end of a
|
435
|
-
noise/
|
438
|
+
noise/effect is reached.
|
436
439
|
""" + default
|
437
440
|
# fmt: on
|
438
441
|
|
@@ -444,7 +447,7 @@ def doc_impulse_responses() -> str:
|
|
444
447
|
'impulse_responses' is a mixture database configuration parameter that specifies a
|
445
448
|
list of impulse response files to use.
|
446
449
|
|
447
|
-
See '
|
450
|
+
See 'effects' for details.
|
448
451
|
""" + default
|
449
452
|
# fmt: on
|
450
453
|
|
@@ -456,7 +459,7 @@ def doc_spectral_masks() -> str:
|
|
456
459
|
'spectral_masks' is a mixture database configuration parameter that specifies
|
457
460
|
a list of spectral mask rules.
|
458
461
|
|
459
|
-
All other
|
462
|
+
All other effects are applied including SNR and a mixture is generated
|
460
463
|
and then the spectral mask rules are applied to the resulting mixture feature.
|
461
464
|
|
462
465
|
Rules must specify all the following parameters:
|
sonusai/genft.py
CHANGED
sonusai/genmetrics.py
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
"""sonusai genmetrics
|
2
2
|
|
3
|
-
usage: genmetrics [-hvusd] [-i MIXID] [-n INCLUDE] [-
|
3
|
+
usage: genmetrics [-hvusd] [-i MIXID] [-n INCLUDE] [-x EXCLUDE] [-p NUMPROC] LOC
|
4
4
|
|
5
5
|
options:
|
6
6
|
-h, --help
|
7
7
|
-v, --verbose Be verbose.
|
8
8
|
-i MIXID, --mixid MIXID Mixture ID(s) to generate. [default: *].
|
9
9
|
-n INCLUDE, --include INCLUDE Metrics to include. [default: all]
|
10
|
-
-p NUMPROC, --nproc NUMPROC Number of parallel processes to use. Default single thread.
|
11
10
|
-x EXCLUDE, --exclude EXCLUDE Metrics to exclude. [default: none]
|
11
|
+
-p NUMPROC, --nproc NUMPROC Number of parallel processes to use. Default single thread.
|
12
12
|
-u, --update Update metrics (do not regenerate existing metrics).
|
13
13
|
-s, --supported Show list of supported metrics.
|
14
14
|
-d, --dryrun Show list of metrics that will be generated and exit.
|
@@ -97,7 +97,7 @@ def main() -> None:
|
|
97
97
|
start_time = time.monotonic()
|
98
98
|
|
99
99
|
# Setup logging file
|
100
|
-
create_file_handler(join(location, "genmetrics.log"))
|
100
|
+
create_file_handler(join(location, "genmetrics.log"), verbose)
|
101
101
|
update_console_handler(verbose)
|
102
102
|
initial_log_messages("genmetrics")
|
103
103
|
|
sonusai/genmix.py
CHANGED
@@ -144,7 +144,7 @@ def main() -> None:
|
|
144
144
|
|
145
145
|
start_time = time.monotonic()
|
146
146
|
|
147
|
-
create_file_handler(join(location, "genmix.log"))
|
147
|
+
create_file_handler(join(location, "genmix.log"), verbose)
|
148
148
|
update_console_handler(verbose)
|
149
149
|
initial_log_messages("genmix")
|
150
150
|
|
sonusai/genmixdb.py
CHANGED
@@ -314,7 +314,7 @@ def main() -> None:
|
|
314
314
|
|
315
315
|
makedirs(location, exist_ok=True)
|
316
316
|
|
317
|
-
create_file_handler(join(location, "genmixdb.log"))
|
317
|
+
create_file_handler(join(location, "genmixdb.log"), verbose)
|
318
318
|
update_console_handler(verbose)
|
319
319
|
initial_log_messages("genmixdb")
|
320
320
|
|
sonusai/lsdb.py
CHANGED
@@ -81,7 +81,7 @@ def lsdb(
|
|
81
81
|
mixids = mixdb.mixids_to_list(mixids)
|
82
82
|
|
83
83
|
if len(mixids) == 1:
|
84
|
-
print_mixture_details(mixdb=mixdb, mixid=mixids[0],
|
84
|
+
print_mixture_details(mixdb=mixdb, mixid=mixids[0], print_fn=logger.info)
|
85
85
|
if all_class_counts:
|
86
86
|
# TODO: fix class count
|
87
87
|
logger.info("All class count not supported")
|
sonusai/metrics_summary.py
CHANGED
@@ -55,13 +55,13 @@ def _process_mixture(
|
|
55
55
|
|
56
56
|
all_metrics = mixdb.mixture_metrics(m_id, all_metric_names)
|
57
57
|
|
58
|
-
# replace
|
58
|
+
# replace dict with 'primary' value (ignore mixup)
|
59
59
|
scalar_metrics = {
|
60
|
-
key: all_metrics[key][
|
60
|
+
key: all_metrics[key]["primary"] if isinstance(all_metrics[key], dict) else all_metrics[key]
|
61
61
|
for key in scalar_metric_names
|
62
62
|
}
|
63
63
|
string_metrics = {
|
64
|
-
key: all_metrics[key][
|
64
|
+
key: all_metrics[key]["primary"] if isinstance(all_metrics[key], dict) else all_metrics[key]
|
65
65
|
for key in string_metric_names
|
66
66
|
}
|
67
67
|
|
@@ -133,7 +133,7 @@ def main() -> None:
|
|
133
133
|
timestamp = create_timestamp() # string good for embedding into filenames
|
134
134
|
mixdb_fname = basename(location)
|
135
135
|
if verbose:
|
136
|
-
create_file_handler(join(location, "metrics_summary.log"))
|
136
|
+
create_file_handler(join(location, "metrics_summary.log"), verbose)
|
137
137
|
update_console_handler(verbose)
|
138
138
|
initial_log_messages("metrics_summary")
|
139
139
|
logger.info(f"Logging summary of SonusAI mixture database at {location}")
|
@@ -168,10 +168,9 @@ def main() -> None:
|
|
168
168
|
for metric in metrics_present:
|
169
169
|
metval = all_metrics[metric] # get metric value
|
170
170
|
logger.debug(f"First mixid {mixids[0]} metric {metric} = {metval}")
|
171
|
-
if isinstance(metval,
|
172
|
-
|
173
|
-
|
174
|
-
metval = metval[0] # remove any list
|
171
|
+
if isinstance(metval, dict):
|
172
|
+
logger.warning(f"Mixid {mixids[0]} metric {metric} is a dict, using 'primary'.")
|
173
|
+
metval = metval["primary"] # remove any dict
|
175
174
|
if isinstance(metval, float | int):
|
176
175
|
logger.debug(f"Metric is scalar {type(metval)}, entering in summary table.")
|
177
176
|
scalar_metric_names.append(metric)
|
sonusai/mixture/__init__.py
CHANGED
sonusai/mixture/db.py
ADDED
@@ -0,0 +1,163 @@
|
|
1
|
+
import contextlib
|
2
|
+
import sqlite3
|
3
|
+
from os import remove
|
4
|
+
from os.path import exists
|
5
|
+
from sqlite3 import Connection
|
6
|
+
from sqlite3 import Cursor
|
7
|
+
from typing import Any
|
8
|
+
|
9
|
+
from .. import logger_db
|
10
|
+
from .db_file import db_file
|
11
|
+
|
12
|
+
|
13
|
+
class SQLiteDatabase:
|
14
|
+
"""A context manager for SQLite database connections with configurable behavior."""
|
15
|
+
|
16
|
+
# Constants for database configuration
|
17
|
+
READONLY_MODE = "?mode=ro"
|
18
|
+
WRITE_OPTIMIZED_PRAGMAS = (
|
19
|
+
"?_journal_mode=OFF&_synchronous=OFF&_cache_size=10000&_temp_store=MEMORY&_locking_mode=EXCLUSIVE"
|
20
|
+
)
|
21
|
+
CONNECTION_TIMEOUT = 20
|
22
|
+
|
23
|
+
def __init__(
|
24
|
+
self,
|
25
|
+
location: str,
|
26
|
+
create: bool = False,
|
27
|
+
readonly: bool = True,
|
28
|
+
test: bool = False,
|
29
|
+
verbose: bool = False,
|
30
|
+
) -> None:
|
31
|
+
"""Initialize SQLite database connection manager.
|
32
|
+
|
33
|
+
Args:
|
34
|
+
location: Path to the database file.
|
35
|
+
create: If True, create a new database file, overwriting any existing one.
|
36
|
+
readonly: If True, open the database in read-only mode.
|
37
|
+
test: If True, use the test database path.
|
38
|
+
verbose: If True, enable SQL statement logging.
|
39
|
+
"""
|
40
|
+
self.location = location
|
41
|
+
self.create = create
|
42
|
+
self.readonly = readonly
|
43
|
+
self.test = test
|
44
|
+
self.verbose = verbose
|
45
|
+
self.con: Connection | None = None
|
46
|
+
self.cur: Cursor | None = None
|
47
|
+
|
48
|
+
def __enter__(self) -> Cursor:
|
49
|
+
"""Enter the context manager, establishing the database connection.
|
50
|
+
|
51
|
+
Returns:
|
52
|
+
A database cursor for executing queries.
|
53
|
+
|
54
|
+
Raises:
|
55
|
+
sqlite3.Error: If connection fails.
|
56
|
+
"""
|
57
|
+
try:
|
58
|
+
self._establish_connection()
|
59
|
+
except Exception:
|
60
|
+
self._close_resources()
|
61
|
+
raise
|
62
|
+
|
63
|
+
if self.cur:
|
64
|
+
return self.cur
|
65
|
+
raise sqlite3.Error("Failed to connect to database")
|
66
|
+
|
67
|
+
def __exit__(
|
68
|
+
self,
|
69
|
+
exc_type: type[BaseException] | None,
|
70
|
+
exc_val: BaseException | None,
|
71
|
+
exc_tb: Any | None,
|
72
|
+
) -> None:
|
73
|
+
"""Exit the context manager, committing changes if appropriate and closing resources.
|
74
|
+
|
75
|
+
Args:
|
76
|
+
exc_type: The exception type, if any.
|
77
|
+
exc_val: The exception value, if any.
|
78
|
+
exc_tb: The exception traceback, if any.
|
79
|
+
"""
|
80
|
+
if self.con:
|
81
|
+
if exc_type is None and not self.readonly:
|
82
|
+
# Commit only on successful exit if not readonly
|
83
|
+
self.con.commit()
|
84
|
+
self._close_resources()
|
85
|
+
|
86
|
+
def _close_resources(self) -> None:
|
87
|
+
"""Safely close database cursor and connection resources."""
|
88
|
+
if self.cur:
|
89
|
+
with contextlib.suppress(sqlite3.Error):
|
90
|
+
self.cur.close()
|
91
|
+
self.cur = None
|
92
|
+
|
93
|
+
if self.con:
|
94
|
+
with contextlib.suppress(sqlite3.Error):
|
95
|
+
self.con.close()
|
96
|
+
self.con = None
|
97
|
+
|
98
|
+
def _establish_connection(self) -> None:
|
99
|
+
"""Establish a connection to the SQLite database.
|
100
|
+
|
101
|
+
Raises:
|
102
|
+
OSError: If the database file doesn't exist and create=False.
|
103
|
+
sqlite3.Error: If connection to the database fails.
|
104
|
+
"""
|
105
|
+
db_path = self._get_db_path()
|
106
|
+
self._prepare_db_file(db_path)
|
107
|
+
uri = self._build_connection_uri(db_path)
|
108
|
+
|
109
|
+
try:
|
110
|
+
self.con = sqlite3.connect(f"file:{uri}", uri=True, timeout=self.CONNECTION_TIMEOUT)
|
111
|
+
if self.verbose and self.con:
|
112
|
+
self.con.set_trace_callback(logger_db.debug)
|
113
|
+
except sqlite3.Error as e:
|
114
|
+
raise sqlite3.Error(f"Failed to connect to database: {e}") from e
|
115
|
+
|
116
|
+
self.cur = self.con.cursor()
|
117
|
+
|
118
|
+
def _get_db_path(self) -> str:
|
119
|
+
"""Get the database file path based on location and test settings.
|
120
|
+
|
121
|
+
Returns:
|
122
|
+
The path to the database file.
|
123
|
+
"""
|
124
|
+
return db_file(self.location, self.test)
|
125
|
+
|
126
|
+
def _prepare_db_file(self, db_path: str) -> None:
|
127
|
+
"""Prepare the database file based on creation settings.
|
128
|
+
|
129
|
+
Args:
|
130
|
+
db_path: Path to the database file.
|
131
|
+
|
132
|
+
Raises:
|
133
|
+
OSError: If the database file doesn't exist and create=False.
|
134
|
+
"""
|
135
|
+
if self.create and exists(db_path):
|
136
|
+
remove(db_path)
|
137
|
+
|
138
|
+
if not self.create and not exists(db_path):
|
139
|
+
raise OSError(f"Could not find mixture database in {self.location}")
|
140
|
+
|
141
|
+
def _build_connection_uri(self, db_path: str) -> str:
|
142
|
+
"""Build the SQLite connection URI with appropriate options.
|
143
|
+
|
144
|
+
Args:
|
145
|
+
db_path: Path to the database file.
|
146
|
+
|
147
|
+
Returns:
|
148
|
+
A properly formatted SQLite URI with appropriate options.
|
149
|
+
"""
|
150
|
+
uri = db_path
|
151
|
+
|
152
|
+
# Add readonly mode if needed
|
153
|
+
if not self.create and self.readonly:
|
154
|
+
uri += self.READONLY_MODE
|
155
|
+
|
156
|
+
# Add optimized pragmas for write mode
|
157
|
+
if not self.readonly:
|
158
|
+
if "?" in uri:
|
159
|
+
uri = uri.replace("?", f"{self.WRITE_OPTIMIZED_PRAGMAS}&")
|
160
|
+
else:
|
161
|
+
uri += self.WRITE_OPTIMIZED_PRAGMAS
|
162
|
+
|
163
|
+
return uri
|
@@ -0,0 +1,10 @@
|
|
1
|
+
from os.path import join
|
2
|
+
from os.path import normpath
|
3
|
+
|
4
|
+
from .constants import MIXDB_NAME
|
5
|
+
from .constants import TEST_MIXDB_NAME
|
6
|
+
|
7
|
+
|
8
|
+
def db_file(location: str, test: bool = False) -> str:
|
9
|
+
name = TEST_MIXDB_NAME if test else MIXDB_NAME
|
10
|
+
return normpath(join(location, name))
|