sonusai 1.0.7__py3-none-any.whl → 1.0.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sonusai/__init__.py CHANGED
@@ -36,7 +36,7 @@ logger_db = logging.getLogger("sonusai_db")
36
36
  logger_db.setLevel(logging.DEBUG)
37
37
 
38
38
  # create file handler
39
- def create_file_handler(filename: str) -> None:
39
+ def create_file_handler(filename: str, verbose: bool = False) -> None:
40
40
  from pathlib import Path
41
41
 
42
42
  fh = logging.FileHandler(filename=filename, mode="w")
@@ -44,12 +44,13 @@ def create_file_handler(filename: str) -> None:
44
44
  fh.setFormatter(formatter)
45
45
  logger.addHandler(fh)
46
46
 
47
- filename_db = Path(filename)
48
- filename_db = filename_db.parent / (filename_db.stem + "_dbtrace" + filename_db.suffix)
49
- fh = logging.FileHandler(filename=filename_db, mode="w")
50
- fh.setLevel(logging.DEBUG)
51
- fh.setFormatter(formatter_db)
52
- logger_db.addHandler(fh)
47
+ if verbose:
48
+ filename_db = Path(filename)
49
+ filename_db = filename_db.parent / (filename_db.stem + "_dbtrace" + filename_db.suffix)
50
+ fh = logging.FileHandler(filename=filename_db, mode="w")
51
+ fh.setLevel(logging.DEBUG)
52
+ fh.setFormatter(formatter_db)
53
+ logger_db.addHandler(fh)
53
54
 
54
55
 
55
56
  # update console handler
sonusai/audiofe.py CHANGED
@@ -77,7 +77,7 @@ def main() -> None:
77
77
  from sonusai.utils import load_ort_session
78
78
 
79
79
  # Setup logging file
80
- create_file_handler("audiofe.log")
80
+ create_file_handler("audiofe.log", verbose)
81
81
  update_console_handler(verbose)
82
82
  initial_log_messages("audiofe")
83
83
 
@@ -531,10 +531,10 @@ def _process_mixture(
531
531
  pesq_speech = calc_pesq(target_est_wav, target_fi)
532
532
  csig_tg, cbak_tg, covl_tg = calc_speech(target_est_wav, target_fi, pesq=pesq_speech)
533
533
  metrics = mixdb.mixture_metrics(m_id, ["mxpesq", "mxcsig", "mxcbak", "mxcovl"])
534
- pesq_mx = metrics["mxpesq"][0] if isinstance(metrics["mxpesq"], list) else metrics["mxpesq"]
535
- csig_mx = metrics["mxcsig"][0] if isinstance(metrics["mxcsig"], list) else metrics["mxcsig"]
536
- cbak_mx = metrics["mxcbak"][0] if isinstance(metrics["mxcbak"], list) else metrics["mxcbak"]
537
- covl_mx = metrics["mxcovl"][0] if isinstance(metrics["mxcovl"], list) else metrics["mxcovl"]
534
+ pesq_mx = metrics["mxpesq"]["primary"] if isinstance(metrics["mxpesq"], dict) else metrics["mxpesq"]
535
+ csig_mx = metrics["mxcsig"]["primary"] if isinstance(metrics["mxcsig"], dict) else metrics["mxcsig"]
536
+ cbak_mx = metrics["mxcbak"]["primary"] if isinstance(metrics["mxcbak"], dict) else metrics["mxcbak"]
537
+ covl_mx = metrics["mxcovl"]["primary"] if isinstance(metrics["mxcovl"], dict) else metrics["mxcovl"]
538
538
  # pesq_speech_tst = calc_pesq(hypothesis=target_est_wav, reference=target)
539
539
  # pesq_mixture_tst = calc_pesq(hypothesis=mixture, reference=target)
540
540
  # pesq improvement
@@ -560,11 +560,11 @@ def _process_mixture(
560
560
  if asr_method is not None and mixdb.mixture(m_id).noise.snr >= -96: # noise only, ignore/reset target ASR
561
561
  asr_mx_name = f"mxasr.{asr_method}"
562
562
  wer_mx_name = f"mxwer.{asr_method}"
563
- asr_tt_name = f"tasr.{asr_method}"
563
+ asr_tt_name = f"sasr.{asr_method}"
564
564
  metrics = mixdb.mixture_metrics(m_id, [asr_mx_name, wer_mx_name, asr_tt_name])
565
- asr_mx = metrics[asr_mx_name][0] if isinstance(metrics[asr_mx_name], list) else metrics[asr_mx_name]
566
- wer_mx = metrics[wer_mx_name][0] if isinstance(metrics[wer_mx_name], list) else metrics[wer_mx_name]
567
- asr_tt = metrics[asr_tt_name][0] if isinstance(metrics[asr_tt_name], list) else metrics[asr_tt_name]
565
+ asr_mx = metrics[asr_mx_name]["primary"] if isinstance(metrics[asr_mx_name], dict) else metrics[asr_mx_name]
566
+ wer_mx = metrics[wer_mx_name]["primary"] if isinstance(metrics[wer_mx_name], dict) else metrics[wer_mx_name]
567
+ asr_tt = metrics[asr_tt_name]["primary"] if isinstance(metrics[asr_tt_name], dict) else metrics[asr_tt_name]
568
568
 
569
569
  if asr_tt:
570
570
  noiseadd = None # TBD add as switch, default -30
@@ -849,7 +849,7 @@ def main():
849
849
  logger.info(f"Found predict log {basename(predict_logfile[0])} in predict location.")
850
850
 
851
851
  # Setup logging file
852
- create_file_handler(join(predict_location, "calc_metric_spenh.log"))
852
+ create_file_handler(join(predict_location, "calc_metric_spenh.log"), verbose)
853
853
  update_console_handler(verbose)
854
854
  initial_log_messages("calc_metric_spenh")
855
855
 
sonusai/doc/doc.py CHANGED
@@ -21,12 +21,12 @@ to use.
21
21
  # fmt: on
22
22
 
23
23
 
24
- def doc_target_level_type() -> str:
25
- default = f"\nDefault value: {get_default_config()['target_level_type']}"
24
+ def doc_level_type() -> str:
25
+ default = f"\nDefault value: {get_default_config()['level_type']}"
26
26
  # fmt: off
27
27
  return """
28
- 'target_level_type' is a mixture database configuration parameter that sets the
29
- algorithm to use to determine target energy level for SNR calculations.
28
+ 'level_type' is a mixture database configuration parameter that sets the
29
+ algorithm to use to determine energy level for SNR calculations.
30
30
  Supported values are:
31
31
 
32
32
  default mean of squares
@@ -35,31 +35,31 @@ Supported values are:
35
35
  # fmt: on
36
36
 
37
37
 
38
- def doc_targets() -> str:
39
- default = f"\nDefault value: {get_default_config()['targets']}"
38
+ def doc_sources() -> str:
39
+ default = f"\nDefault value: {get_default_config()['sources']}"
40
40
  # fmt: off
41
41
  return """
42
- 'targets' is a mixture database configuration parameter that sets the list of
43
- targets to use.
42
+ 'sources' is a mixture database configuration parameter that sets the list of
43
+ sources to use.
44
44
 
45
- Required field:
46
-
47
- 'name'
48
- File name. May be one of the following:
45
+ Two sources are required: 'primary' and 'noise'. Additional sources may be
46
+ specified with arbitrary names.
49
47
 
50
- audio Supported formats are .wav, .mp3, .m4a, .aif, .flac, and .ogg
51
- glob Matches file glob patterns
52
- .yml The given YAML file is parsed into the list
53
- .txt Each line in the given text file indicates an item which
54
- may be anything in this list (audio, glob, .yml, or .txt)
48
+ Each source has the following fields:
55
49
 
56
- Optional fields:
50
+ 'files' Required list of files to use. Sub-fields:
51
+ 'name' File name. May be one of the following:
52
+ audio Supported formats are .wav, .mp3, .m4a, .aif, .flac, and .ogg
53
+ glob Matches file glob patterns
54
+ .yml The given YAML file is parsed into the list
55
+ .txt Each line in the given text file indicates an item which
56
+ may be anything in this list (audio, glob, .yml, or .txt)
57
+ 'class_indices' Optional list of class indices
57
58
 
58
- 'truth_configs'
59
- Local overrides for truth configs. Contains the following:
60
- 'name' Name of truth config
61
- '<param1>' Target-specific override for truth configuration parameter
62
- '<param2>' Target-specific override for truth configuration parameter
59
+ 'truth_configs' Required list of truth config(s) to use for this source. Sub-fields:
60
+ '<name>' Name of truth config. Sub-fields:
61
+ 'function' Truth function
62
+ 'stride_reduction' Stride reduction method to use. May be one of: none, max
63
63
 
64
64
  'target_level_type'
65
65
  Target-specific override for target_level_type.
@@ -72,7 +72,7 @@ targets:
72
72
  sed:
73
73
  thresholds: [-38, -41, -48]
74
74
  index: 2
75
- class_balancing_augmentation: { }
75
+ class_balancing_effect: { }
76
76
  - name: target.mp3
77
77
  truth_configs:
78
78
  sed:
@@ -176,27 +176,27 @@ The 'truth_configs' parameter specifies the following:
176
176
  Name of stride reduction method to use
177
177
  '<param1>' Function-specific configuration parameter
178
178
  '<paramN>' Function-specific configuration parameter
179
- 'class_balancing_augmentation'
180
- Class balancing augmentation.
179
+ 'class_balancing_effect'
180
+ Class balancing effect.
181
181
  This truth configuration will use this rule for class balancing operations.
182
182
  If this rule is empty or unspecified, then this truth function will not
183
183
  perform class balancing.
184
184
 
185
185
  Class balancing ensures that each class in a sound classification dataset
186
186
  is represented equally (i.e., each class has the same number of augmented
187
- targets). This is achieved by creating new class balancing augmentation
187
+ targets). This is achieved by creating new class balancing effect
188
188
  rules and applying them to targets in underrepresented classes to create
189
- more augmented targets for those classes.
189
+ more effected targets for those classes.
190
190
 
191
191
  This rule must contain at least one random entry in order to guarantee
192
192
  unique additional data.
193
193
 
194
- See 'augmentations' for details on augmentation rules.
194
+ See 'effects' for details on effect rules.
195
195
  """ + get_truth_functions() + default
196
196
  # fmt: on
197
197
 
198
198
 
199
- def doc_augmentations() -> str:
199
+ def doc_effects() -> str:
200
200
  # fmt: off
201
201
  return """
202
202
  Augmentation Rules
@@ -211,7 +211,7 @@ the list.
211
211
  If a value is specified using rand, then a randomized rule is generated
212
212
  dynamically per use.
213
213
 
214
- Rules may specify any or all of the following augmentations:
214
+ Rules may specify any or all of the following effects:
215
215
 
216
216
  'normalize' Normalize audio file to the specified level (in dBFS).
217
217
  'gain' Apply an amplification or an attenuation to the audio signal.
@@ -237,16 +237,19 @@ Rules may specify any or all of the following augmentations:
237
237
  'impulse_responses' parameter).
238
238
  For targets, the impulse response is applied AFTER truth generation
239
239
  and the resulting audio is still aligned with the truth. Random
240
- syntax for 'ir' is simply 'rand' (i.e., do not specify <min> and <max>).
240
+ syntax for 'ir' is one of the following:
241
+ 'sai_choose()' chooses a random IR from the entire list
242
+ 'sai_choose(<min>, <max>)' chooses a random IR in the range <min> to <max>
243
+ 'sai_choose(<tag>) chooses a random IR that matches <tag>
241
244
 
242
- Only the specified augmentations for a given rule are applied; all others are
245
+ Only the specified effects for a given rule are applied; all others are
243
246
  skipped in the given rule. For example, if a rule only specifies 'tempo',
244
- then only a tempo augmentation is applied and all other possible augmentations
247
+ then only a tempo effect is applied and all other possible effects
245
248
  are ignored (e.g., 'gain', 'pitch', etc.).
246
249
 
247
250
  Example:
248
251
 
249
- target_augmentations:
252
+ target_effects:
250
253
  - normalize: -3.5
251
254
  - normalize: -3.5
252
255
  pitch: [-300, 300]
@@ -264,7 +267,7 @@ There are four rules given in this example.
264
267
  The first rule is simple:
265
268
  - normalize: -3.5
266
269
 
267
- This results in just one augmentation being applied to each target:
270
+ This results in just one effect being applied to each target:
268
271
 
269
272
  normalize: -3.5
270
273
 
@@ -275,7 +278,7 @@ The second rule illustrates the use of lists to specify values:
275
278
  eq1: [[1000, 0.8, 3], [600, 1.0, -4], [800, 0.6, 0]]
276
279
 
277
280
  There are two values given for pitch, two for tempo, and three for EQ. This
278
- rule expands to 2 * 2 * 3 = 12 unique augmentations being applied to each
281
+ rule expands to 2 * 2 * 3 = 12 unique effects being applied to each
279
282
  target:
280
283
 
281
284
  normalize: -3.5, pitch: -3, tempo: 0.8, eq1: [1000, 0.8, 3]
@@ -297,13 +300,13 @@ The third rule shows the use of rand:
297
300
  eq1: ["rand(100, 6000)", "rand(0.6, 1.0)", "rand(-6, 6)"]
298
301
  lpf: "rand(1000, 8000)"
299
302
 
300
- This rule is used to create randomized augmentations per use.
303
+ This rule is used to create randomized effects per use.
301
304
 
302
305
  The fourth rule demonstrates the use of scalars, lists, and rand:
303
306
  - tempo: [0.9, 1, 1.1]
304
307
  eq1: [["rand(100, 7500)", 0.8, -10], ["rand(100, 7500)", 0.8, 10]]
305
308
 
306
- This rule expands to 6 unique augmentations being applied to each target
309
+ This rule expands to 6 unique effects being applied to each target
307
310
  (list of 3 * list of 2). Here is the expansion:
308
311
 
309
312
  tempo: 0.9, eq1: ["rand(100, 7500)", 0.8, -10]
@@ -315,16 +318,16 @@ This rule expands to 6 unique augmentations being applied to each target
315
318
  # fmt: on
316
319
 
317
320
 
318
- def doc_target_augmentations() -> str:
321
+ def doc_target_effects() -> str:
319
322
  import yaml
320
323
 
321
- default = f"\nDefault value:\n\n{yaml.dump(get_default_config()['target_augmentations'])}"
324
+ default = f"\nDefault value:\n\n{yaml.dump(get_default_config()['target_effects'])}"
322
325
  # fmt: off
323
326
  return """
324
- 'target_augmentations' is a mixture database configuration parameter that
325
- specifies a list of augmentation rules to use for each target.
327
+ 'target_effects' is a mixture database configuration parameter that
328
+ specifies a list of effect rules to use for each target.
326
329
 
327
- See 'augmentations' for details on augmentation rules.
330
+ See 'effects' for details on effect rules.
328
331
  """ + default
329
332
  # fmt: on
330
333
 
@@ -338,7 +341,7 @@ def doc_target_distortions() -> str:
338
341
  'target_distortions' is a mixture database configuration parameter that
339
342
  specifies a list of distortion rules to use for each target.
340
343
 
341
- See 'augmentations' for details on distortion rules.
344
+ See 'effects' for details on distortion rules.
342
345
  """ + default
343
346
  # fmt: on
344
347
 
@@ -364,17 +367,17 @@ Required field:
364
367
  # fmt: on
365
368
 
366
369
 
367
- def doc_noise_augmentations() -> str:
370
+ def doc_noise_effects() -> str:
368
371
  import yaml
369
372
 
370
- default = f"\nDefault value:\n\n{yaml.dump(get_default_config()['noise_augmentations'])}"
373
+ default = f"\nDefault value:\n\n{yaml.dump(get_default_config()['noise_effects'])}"
371
374
 
372
375
  # fmt: off
373
376
  return """
374
- 'noise_augmentations' is a mixture database configuration parameter that
375
- specifies a list of augmentation rules to use for each noise.
377
+ 'noise_effects' is a mixture database configuration parameter that
378
+ specifies a list of effect rules to use for each noise.
376
379
 
377
- See 'augmentations' for details on augmentation rules.
380
+ See 'effects' for details on effect rules.
378
381
  """ + default
379
382
  # fmt: on
380
383
 
@@ -386,7 +389,7 @@ def doc_snrs() -> str:
386
389
  'snrs' is a mixture database configuration parameter that specifies a list
387
390
  of required signal-to-noise ratios (in dB).
388
391
 
389
- All other augmentations are applied to both target and noise and then the
392
+ All other effects are applied to both target and noise and then the
390
393
  energy levels are measured and the appropriate noise gain calculated to
391
394
  achieve the desired SNR.
392
395
 
@@ -407,7 +410,7 @@ list of random signal-to-noise ratios. The value(s) must be specified as
407
410
  random using the syntax: 'rand(<min>, <max>)'.
408
411
 
409
412
  Random SNRs behave slightly differently from regular or ordered SNRs. As with
410
- ordered SNRs, all other augmentations are applied to both target and noise and
413
+ ordered SNRs, all other effects are applied to both target and noise and
411
414
  then the energy levels are measured and the appropriate noise gain calculated
412
415
  to achieve the desired SNR. However, unlike ordered SNRs, the desired SNR is
413
416
  randomized (per the given rule(s)) for each mixture, i.e., previous random
@@ -425,14 +428,14 @@ how to mix noises with targets.
425
428
 
426
429
  Supported modes:
427
430
 
428
- exhaustive Use every noise/augmentation with every target/augmentation.
429
- non-exhaustive Cycle through every target/augmentation without necessarily
430
- using all noise/augmentation combinations (reduced data set).
431
- non-combinatorial Combine a target/augmentation with a single cut of a
432
- noise/augmentation non-exhaustively (each target/augmentation
433
- does not use each noise/augmentation). Cut has a random start
431
+ exhaustive Use every noise/effect with every primary/effect.
432
+ non-exhaustive Cycle through every primary/effect without necessarily
433
+ using all noise/effect combinations (reduced data set).
434
+ non-combinatorial Combine a primary/effect with a single cut of a
435
+ noise/effect non-exhaustively (each primary/effect
436
+ does not use each noise/effect). Cut has a random start
434
437
  and loops back to the beginning if the end of a
435
- noise/augmentation is reached.
438
+ noise/effect is reached.
436
439
  """ + default
437
440
  # fmt: on
438
441
 
@@ -444,7 +447,7 @@ def doc_impulse_responses() -> str:
444
447
  'impulse_responses' is a mixture database configuration parameter that specifies a
445
448
  list of impulse response files to use.
446
449
 
447
- See 'augmentations' for details.
450
+ See 'effects' for details.
448
451
  """ + default
449
452
  # fmt: on
450
453
 
@@ -456,7 +459,7 @@ def doc_spectral_masks() -> str:
456
459
  'spectral_masks' is a mixture database configuration parameter that specifies
457
460
  a list of spectral mask rules.
458
461
 
459
- All other augmentations are applied including SNR and a mixture is generated
462
+ All other effects are applied including SNR and a mixture is generated
460
463
  and then the spectral mask rules are applied to the resulting mixture feature.
461
464
 
462
465
  Rules must specify all the following parameters:
sonusai/genft.py CHANGED
@@ -138,7 +138,7 @@ def main() -> None:
138
138
 
139
139
  start_time = time.monotonic()
140
140
 
141
- create_file_handler(join(location, "genft.log"))
141
+ create_file_handler(join(location, "genft.log"), verbose)
142
142
  update_console_handler(verbose)
143
143
  initial_log_messages("genft")
144
144
 
sonusai/genmetrics.py CHANGED
@@ -1,14 +1,14 @@
1
1
  """sonusai genmetrics
2
2
 
3
- usage: genmetrics [-hvusd] [-i MIXID] [-n INCLUDE] [-p NUMPROC] [-x EXCLUDE] LOC
3
+ usage: genmetrics [-hvusd] [-i MIXID] [-n INCLUDE] [-x EXCLUDE] [-p NUMPROC] LOC
4
4
 
5
5
  options:
6
6
  -h, --help
7
7
  -v, --verbose Be verbose.
8
8
  -i MIXID, --mixid MIXID Mixture ID(s) to generate. [default: *].
9
9
  -n INCLUDE, --include INCLUDE Metrics to include. [default: all]
10
- -p NUMPROC, --nproc NUMPROC Number of parallel processes to use. Default single thread.
11
10
  -x EXCLUDE, --exclude EXCLUDE Metrics to exclude. [default: none]
11
+ -p NUMPROC, --nproc NUMPROC Number of parallel processes to use. Default single thread.
12
12
  -u, --update Update metrics (do not regenerate existing metrics).
13
13
  -s, --supported Show list of supported metrics.
14
14
  -d, --dryrun Show list of metrics that will be generated and exit.
@@ -97,7 +97,7 @@ def main() -> None:
97
97
  start_time = time.monotonic()
98
98
 
99
99
  # Setup logging file
100
- create_file_handler(join(location, "genmetrics.log"))
100
+ create_file_handler(join(location, "genmetrics.log"), verbose)
101
101
  update_console_handler(verbose)
102
102
  initial_log_messages("genmetrics")
103
103
 
sonusai/genmix.py CHANGED
@@ -144,7 +144,7 @@ def main() -> None:
144
144
 
145
145
  start_time = time.monotonic()
146
146
 
147
- create_file_handler(join(location, "genmix.log"))
147
+ create_file_handler(join(location, "genmix.log"), verbose)
148
148
  update_console_handler(verbose)
149
149
  initial_log_messages("genmix")
150
150
 
sonusai/genmixdb.py CHANGED
@@ -314,7 +314,7 @@ def main() -> None:
314
314
 
315
315
  makedirs(location, exist_ok=True)
316
316
 
317
- create_file_handler(join(location, "genmixdb.log"))
317
+ create_file_handler(join(location, "genmixdb.log"), verbose)
318
318
  update_console_handler(verbose)
319
319
  initial_log_messages("genmixdb")
320
320
 
sonusai/lsdb.py CHANGED
@@ -81,7 +81,7 @@ def lsdb(
81
81
  mixids = mixdb.mixids_to_list(mixids)
82
82
 
83
83
  if len(mixids) == 1:
84
- print_mixture_details(mixdb=mixdb, mixid=mixids[0], desc_len=desc_len, print_fn=logger.info)
84
+ print_mixture_details(mixdb=mixdb, mixid=mixids[0], print_fn=logger.info)
85
85
  if all_class_counts:
86
86
  # TODO: fix class count
87
87
  logger.info("All class count not supported")
@@ -55,13 +55,13 @@ def _process_mixture(
55
55
 
56
56
  all_metrics = mixdb.mixture_metrics(m_id, all_metric_names)
57
57
 
58
- # replace lists with first value (ignore mixup)
58
+ # replace dict with 'primary' value (ignore mixup)
59
59
  scalar_metrics = {
60
- key: all_metrics[key][0] if isinstance(all_metrics[key], list) else all_metrics[key]
60
+ key: all_metrics[key]["primary"] if isinstance(all_metrics[key], dict) else all_metrics[key]
61
61
  for key in scalar_metric_names
62
62
  }
63
63
  string_metrics = {
64
- key: all_metrics[key][0] if isinstance(all_metrics[key], list) else all_metrics[key]
64
+ key: all_metrics[key]["primary"] if isinstance(all_metrics[key], dict) else all_metrics[key]
65
65
  for key in string_metric_names
66
66
  }
67
67
 
@@ -133,7 +133,7 @@ def main() -> None:
133
133
  timestamp = create_timestamp() # string good for embedding into filenames
134
134
  mixdb_fname = basename(location)
135
135
  if verbose:
136
- create_file_handler(join(location, "metrics_summary.log"))
136
+ create_file_handler(join(location, "metrics_summary.log"), verbose)
137
137
  update_console_handler(verbose)
138
138
  initial_log_messages("metrics_summary")
139
139
  logger.info(f"Logging summary of SonusAI mixture database at {location}")
@@ -168,10 +168,9 @@ def main() -> None:
168
168
  for metric in metrics_present:
169
169
  metval = all_metrics[metric] # get metric value
170
170
  logger.debug(f"First mixid {mixids[0]} metric {metric} = {metval}")
171
- if isinstance(metval, list):
172
- if len(metval) > 1:
173
- logger.warning(f"Mixid {mixids[0]} metric {metric} has a list with more than 1 element, using first.")
174
- metval = metval[0] # remove any list
171
+ if isinstance(metval, dict):
172
+ logger.warning(f"Mixid {mixids[0]} metric {metric} is a dict, using 'primary'.")
173
+ metval = metval["primary"] # remove any dict
175
174
  if isinstance(metval, float | int):
176
175
  logger.debug(f"Metric is scalar {type(metval)}, entering in summary table.")
177
176
  scalar_metric_names.append(metric)
@@ -28,4 +28,4 @@ from .helpers import inverse_transform
28
28
  from .helpers import write_mixture_metadata
29
29
  from .log_duration_and_sizes import log_duration_and_sizes
30
30
  from .mixdb import MixtureDatabase
31
- from .mixdb import db_file
31
+ from .db_file import db_file
sonusai/mixture/db.py ADDED
@@ -0,0 +1,163 @@
1
+ import contextlib
2
+ import sqlite3
3
+ from os import remove
4
+ from os.path import exists
5
+ from sqlite3 import Connection
6
+ from sqlite3 import Cursor
7
+ from typing import Any
8
+
9
+ from .. import logger_db
10
+ from .db_file import db_file
11
+
12
+
13
+ class SQLiteDatabase:
14
+ """A context manager for SQLite database connections with configurable behavior."""
15
+
16
+ # Constants for database configuration
17
+ READONLY_MODE = "?mode=ro"
18
+ WRITE_OPTIMIZED_PRAGMAS = (
19
+ "?_journal_mode=OFF&_synchronous=OFF&_cache_size=10000&_temp_store=MEMORY&_locking_mode=EXCLUSIVE"
20
+ )
21
+ CONNECTION_TIMEOUT = 20
22
+
23
+ def __init__(
24
+ self,
25
+ location: str,
26
+ create: bool = False,
27
+ readonly: bool = True,
28
+ test: bool = False,
29
+ verbose: bool = False,
30
+ ) -> None:
31
+ """Initialize SQLite database connection manager.
32
+
33
+ Args:
34
+ location: Path to the database file.
35
+ create: If True, create a new database file, overwriting any existing one.
36
+ readonly: If True, open the database in read-only mode.
37
+ test: If True, use the test database path.
38
+ verbose: If True, enable SQL statement logging.
39
+ """
40
+ self.location = location
41
+ self.create = create
42
+ self.readonly = readonly
43
+ self.test = test
44
+ self.verbose = verbose
45
+ self.con: Connection | None = None
46
+ self.cur: Cursor | None = None
47
+
48
+ def __enter__(self) -> Cursor:
49
+ """Enter the context manager, establishing the database connection.
50
+
51
+ Returns:
52
+ A database cursor for executing queries.
53
+
54
+ Raises:
55
+ sqlite3.Error: If connection fails.
56
+ """
57
+ try:
58
+ self._establish_connection()
59
+ except Exception:
60
+ self._close_resources()
61
+ raise
62
+
63
+ if self.cur:
64
+ return self.cur
65
+ raise sqlite3.Error("Failed to connect to database")
66
+
67
+ def __exit__(
68
+ self,
69
+ exc_type: type[BaseException] | None,
70
+ exc_val: BaseException | None,
71
+ exc_tb: Any | None,
72
+ ) -> None:
73
+ """Exit the context manager, committing changes if appropriate and closing resources.
74
+
75
+ Args:
76
+ exc_type: The exception type, if any.
77
+ exc_val: The exception value, if any.
78
+ exc_tb: The exception traceback, if any.
79
+ """
80
+ if self.con:
81
+ if exc_type is None and not self.readonly:
82
+ # Commit only on successful exit if not readonly
83
+ self.con.commit()
84
+ self._close_resources()
85
+
86
+ def _close_resources(self) -> None:
87
+ """Safely close database cursor and connection resources."""
88
+ if self.cur:
89
+ with contextlib.suppress(sqlite3.Error):
90
+ self.cur.close()
91
+ self.cur = None
92
+
93
+ if self.con:
94
+ with contextlib.suppress(sqlite3.Error):
95
+ self.con.close()
96
+ self.con = None
97
+
98
+ def _establish_connection(self) -> None:
99
+ """Establish a connection to the SQLite database.
100
+
101
+ Raises:
102
+ OSError: If the database file doesn't exist and create=False.
103
+ sqlite3.Error: If connection to the database fails.
104
+ """
105
+ db_path = self._get_db_path()
106
+ self._prepare_db_file(db_path)
107
+ uri = self._build_connection_uri(db_path)
108
+
109
+ try:
110
+ self.con = sqlite3.connect(f"file:{uri}", uri=True, timeout=self.CONNECTION_TIMEOUT)
111
+ if self.verbose and self.con:
112
+ self.con.set_trace_callback(logger_db.debug)
113
+ except sqlite3.Error as e:
114
+ raise sqlite3.Error(f"Failed to connect to database: {e}") from e
115
+
116
+ self.cur = self.con.cursor()
117
+
118
+ def _get_db_path(self) -> str:
119
+ """Get the database file path based on location and test settings.
120
+
121
+ Returns:
122
+ The path to the database file.
123
+ """
124
+ return db_file(self.location, self.test)
125
+
126
+ def _prepare_db_file(self, db_path: str) -> None:
127
+ """Prepare the database file based on creation settings.
128
+
129
+ Args:
130
+ db_path: Path to the database file.
131
+
132
+ Raises:
133
+ OSError: If the database file doesn't exist and create=False.
134
+ """
135
+ if self.create and exists(db_path):
136
+ remove(db_path)
137
+
138
+ if not self.create and not exists(db_path):
139
+ raise OSError(f"Could not find mixture database in {self.location}")
140
+
141
+ def _build_connection_uri(self, db_path: str) -> str:
142
+ """Build the SQLite connection URI with appropriate options.
143
+
144
+ Args:
145
+ db_path: Path to the database file.
146
+
147
+ Returns:
148
+ A properly formatted SQLite URI with appropriate options.
149
+ """
150
+ uri = db_path
151
+
152
+ # Add readonly mode if needed
153
+ if not self.create and self.readonly:
154
+ uri += self.READONLY_MODE
155
+
156
+ # Add optimized pragmas for write mode
157
+ if not self.readonly:
158
+ if "?" in uri:
159
+ uri = uri.replace("?", f"{self.WRITE_OPTIMIZED_PRAGMAS}&")
160
+ else:
161
+ uri += self.WRITE_OPTIMIZED_PRAGMAS
162
+
163
+ return uri
@@ -0,0 +1,10 @@
1
+ from os.path import join
2
+ from os.path import normpath
3
+
4
+ from .constants import MIXDB_NAME
5
+ from .constants import TEST_MIXDB_NAME
6
+
7
+
8
+ def db_file(location: str, test: bool = False) -> str:
9
+ name = TEST_MIXDB_NAME if test else MIXDB_NAME
10
+ return normpath(join(location, name))