sonusai 1.0.7__py3-none-any.whl → 1.0.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -8,6 +8,7 @@ from ..datatypes import Source
8
8
  from ..datatypes import SourceFile
9
9
  from ..datatypes import SourcesAudioT
10
10
  from ..datatypes import UniversalSNRGenerator
11
+ from .db import SQLiteDatabase
11
12
  from .mixdb import MixtureDatabase
12
13
 
13
14
 
@@ -18,186 +19,173 @@ def config_file(location: str) -> str:
18
19
 
19
20
 
20
21
  def initialize_db(location: str, test: bool = False, verbose: bool = False) -> None:
21
- from .mixdb import db_connection
22
-
23
- con = db_connection(location=location, create=True, test=test, verbose=verbose)
24
-
25
- con.execute("""
26
- CREATE TABLE truth_config(
27
- id INTEGER PRIMARY KEY NOT NULL,
28
- config TEXT NOT NULL)
29
- """)
30
-
31
- con.execute("""
32
- CREATE TABLE truth_parameters(
33
- id INTEGER PRIMARY KEY NOT NULL,
34
- category TEXT NOT NULL,
35
- name TEXT NOT NULL,
36
- parameters INTEGER)
37
- """)
38
-
39
- con.execute("""
40
- CREATE TABLE source_file (
41
- id INTEGER PRIMARY KEY NOT NULL,
42
- category TEXT NOT NULL,
43
- class_indices TEXT,
44
- level_type TEXT NOT NULL,
45
- name TEXT NOT NULL,
46
- samples INTEGER NOT NULL,
47
- speaker_id INTEGER,
48
- FOREIGN KEY(speaker_id) REFERENCES speaker (id))
49
- """)
50
-
51
- con.execute("""
52
- CREATE TABLE ir_file (
53
- id INTEGER PRIMARY KEY NOT NULL,
54
- delay INTEGER NOT NULL,
55
- name TEXT NOT NULL)
56
- """)
57
-
58
- con.execute("""
59
- CREATE TABLE ir_tag (
60
- id INTEGER PRIMARY KEY NOT NULL,
61
- tag TEXT NOT NULL UNIQUE)
62
- """)
63
-
64
- con.execute("""
65
- CREATE TABLE ir_file_ir_tag (
66
- file_id INTEGER NOT NULL,
67
- tag_id INTEGER NOT NULL,
68
- FOREIGN KEY(file_id) REFERENCES ir_file (id),
69
- FOREIGN KEY(tag_id) REFERENCES ir_tag (id))
70
- """)
71
-
72
- con.execute("""
73
- CREATE TABLE speaker (
74
- id INTEGER PRIMARY KEY NOT NULL,
75
- parent TEXT NOT NULL)
76
- """)
77
-
78
- con.execute("""
79
- CREATE TABLE top (
80
- id INTEGER PRIMARY KEY NOT NULL,
81
- asr_configs TEXT NOT NULL,
82
- class_balancing BOOLEAN NOT NULL,
83
- feature TEXT NOT NULL,
84
- mixid_width INTEGER NOT NULL,
85
- num_classes INTEGER NOT NULL,
86
- seed INTEGER NOT NULL,
87
- speaker_metadata_tiers TEXT NOT NULL,
88
- textgrid_metadata_tiers TEXT NOT NULL,
89
- version INTEGER NOT NULL)
90
- """)
91
-
92
- con.execute("""
93
- CREATE TABLE class_label (
94
- id INTEGER PRIMARY KEY NOT NULL,
95
- label TEXT NOT NULL)
96
- """)
97
-
98
- con.execute("""
99
- CREATE TABLE class_weights_threshold (
100
- id INTEGER PRIMARY KEY NOT NULL,
101
- threshold FLOAT NOT NULL)
102
- """)
103
-
104
- con.execute("""
105
- CREATE TABLE spectral_mask (
106
- id INTEGER PRIMARY KEY NOT NULL,
107
- f_max_width INTEGER NOT NULL,
108
- f_num INTEGER NOT NULL,
109
- t_max_percent INTEGER NOT NULL,
110
- t_max_width INTEGER NOT NULL,
111
- t_num INTEGER NOT NULL)
112
- """)
113
-
114
- con.execute("""
115
- CREATE TABLE source_file_truth_config (
116
- source_file_id INTEGER NOT NULL,
117
- truth_config_id INTEGER NOT NULL,
118
- FOREIGN KEY(source_file_id) REFERENCES source_file (id),
119
- FOREIGN KEY(truth_config_id) REFERENCES truth_config (id))
120
- """)
121
-
122
- con.execute("""
123
- CREATE TABLE source (
124
- id INTEGER PRIMARY KEY NOT NULL,
125
- effects TEXT NOT NULL,
126
- file_id INTEGER NOT NULL,
127
- pre_tempo FLOAT NOT NULL,
128
- repeat BOOLEAN NOT NULL,
129
- snr FLOAT NOT NULL,
130
- snr_gain FLOAT NOT NULL,
131
- snr_random BOOLEAN NOT NULL,
132
- start INTEGER NOT NULL,
133
- UNIQUE(effects, file_id, pre_tempo, repeat, snr, snr_gain, snr_random, start),
134
- FOREIGN KEY(file_id) REFERENCES source_file (id))
135
- """)
136
-
137
- con.execute("""
138
- CREATE TABLE mixture (
139
- id INTEGER PRIMARY KEY NOT NULL,
140
- name TEXT NOT NULL,
141
- samples INTEGER NOT NULL,
142
- spectral_mask_id INTEGER NOT NULL,
143
- spectral_mask_seed INTEGER NOT NULL,
144
- FOREIGN KEY(spectral_mask_id) REFERENCES spectral_mask (id))
145
- """)
146
-
147
- con.execute("""
148
- CREATE TABLE mixture_source (
149
- mixture_id INTEGER NOT NULL,
150
- source_id INTEGER NOT NULL,
151
- FOREIGN KEY(mixture_id) REFERENCES mixture (id),
152
- FOREIGN KEY(source_id) REFERENCES source (id))
153
- """)
154
-
155
- con.commit()
156
- con.close()
22
+ with SQLiteDatabase(location=location, create=True, test=test, verbose=verbose) as c:
23
+ c.execute("""
24
+ CREATE TABLE truth_config(
25
+ id INTEGER PRIMARY KEY NOT NULL,
26
+ config TEXT NOT NULL)
27
+ """)
28
+
29
+ c.execute("""
30
+ CREATE TABLE truth_parameters(
31
+ id INTEGER PRIMARY KEY NOT NULL,
32
+ category TEXT NOT NULL,
33
+ name TEXT NOT NULL,
34
+ parameters INTEGER)
35
+ """)
36
+
37
+ c.execute("""
38
+ CREATE TABLE source_file (
39
+ id INTEGER PRIMARY KEY NOT NULL,
40
+ category TEXT NOT NULL,
41
+ class_indices TEXT,
42
+ level_type TEXT NOT NULL,
43
+ name TEXT NOT NULL,
44
+ samples INTEGER NOT NULL,
45
+ speaker_id INTEGER,
46
+ FOREIGN KEY(speaker_id) REFERENCES speaker (id))
47
+ """)
48
+
49
+ c.execute("""
50
+ CREATE TABLE ir_file (
51
+ id INTEGER PRIMARY KEY NOT NULL,
52
+ delay INTEGER NOT NULL,
53
+ name TEXT NOT NULL)
54
+ """)
55
+
56
+ c.execute("""
57
+ CREATE TABLE ir_tag (
58
+ id INTEGER PRIMARY KEY NOT NULL,
59
+ tag TEXT NOT NULL UNIQUE)
60
+ """)
61
+
62
+ c.execute("""
63
+ CREATE TABLE ir_file_ir_tag (
64
+ file_id INTEGER NOT NULL,
65
+ tag_id INTEGER NOT NULL,
66
+ FOREIGN KEY(file_id) REFERENCES ir_file (id),
67
+ FOREIGN KEY(tag_id) REFERENCES ir_tag (id))
68
+ """)
69
+
70
+ c.execute("""
71
+ CREATE TABLE speaker (
72
+ id INTEGER PRIMARY KEY NOT NULL,
73
+ parent TEXT NOT NULL)
74
+ """)
75
+
76
+ c.execute("""
77
+ CREATE TABLE top (
78
+ id INTEGER PRIMARY KEY NOT NULL,
79
+ asr_configs TEXT NOT NULL,
80
+ class_balancing BOOLEAN NOT NULL,
81
+ feature TEXT NOT NULL,
82
+ mixid_width INTEGER NOT NULL,
83
+ num_classes INTEGER NOT NULL,
84
+ seed INTEGER NOT NULL,
85
+ speaker_metadata_tiers TEXT NOT NULL,
86
+ textgrid_metadata_tiers TEXT NOT NULL,
87
+ version INTEGER NOT NULL)
88
+ """)
89
+
90
+ c.execute("""
91
+ CREATE TABLE class_label (
92
+ id INTEGER PRIMARY KEY NOT NULL,
93
+ label TEXT NOT NULL)
94
+ """)
95
+
96
+ c.execute("""
97
+ CREATE TABLE class_weights_threshold (
98
+ id INTEGER PRIMARY KEY NOT NULL,
99
+ threshold FLOAT NOT NULL)
100
+ """)
101
+
102
+ c.execute("""
103
+ CREATE TABLE spectral_mask (
104
+ id INTEGER PRIMARY KEY NOT NULL,
105
+ f_max_width INTEGER NOT NULL,
106
+ f_num INTEGER NOT NULL,
107
+ t_max_percent INTEGER NOT NULL,
108
+ t_max_width INTEGER NOT NULL,
109
+ t_num INTEGER NOT NULL)
110
+ """)
111
+
112
+ c.execute("""
113
+ CREATE TABLE source_file_truth_config (
114
+ source_file_id INTEGER NOT NULL,
115
+ truth_config_id INTEGER NOT NULL,
116
+ FOREIGN KEY(source_file_id) REFERENCES source_file (id),
117
+ FOREIGN KEY(truth_config_id) REFERENCES truth_config (id))
118
+ """)
119
+
120
+ c.execute("""
121
+ CREATE TABLE source (
122
+ id INTEGER PRIMARY KEY NOT NULL,
123
+ effects TEXT NOT NULL,
124
+ file_id INTEGER NOT NULL,
125
+ pre_tempo FLOAT NOT NULL,
126
+ repeat BOOLEAN NOT NULL,
127
+ snr FLOAT NOT NULL,
128
+ snr_gain FLOAT NOT NULL,
129
+ snr_random BOOLEAN NOT NULL,
130
+ start INTEGER NOT NULL,
131
+ UNIQUE(effects, file_id, pre_tempo, repeat, snr, snr_gain, snr_random, start),
132
+ FOREIGN KEY(file_id) REFERENCES source_file (id))
133
+ """)
134
+
135
+ c.execute("""
136
+ CREATE TABLE mixture (
137
+ id INTEGER PRIMARY KEY NOT NULL,
138
+ name TEXT NOT NULL,
139
+ samples INTEGER NOT NULL,
140
+ spectral_mask_id INTEGER NOT NULL,
141
+ spectral_mask_seed INTEGER NOT NULL,
142
+ FOREIGN KEY(spectral_mask_id) REFERENCES spectral_mask (id))
143
+ """)
144
+
145
+ c.execute("""
146
+ CREATE TABLE mixture_source (
147
+ mixture_id INTEGER NOT NULL,
148
+ source_id INTEGER NOT NULL,
149
+ FOREIGN KEY(mixture_id) REFERENCES mixture (id),
150
+ FOREIGN KEY(source_id) REFERENCES source (id))
151
+ """)
157
152
 
158
153
 
159
154
  def populate_top_table(location: str, config: dict, test: bool = False, verbose: bool = False) -> None:
160
- """Populate top table"""
155
+ """Populate the top table"""
161
156
  import json
162
157
 
163
158
  from .constants import MIXDB_VERSION
164
- from .mixdb import db_connection
165
-
166
- con = db_connection(location=location, readonly=False, test=test, verbose=verbose)
167
- con.execute(
168
- """
169
- INSERT INTO top (id, asr_configs, class_balancing, feature, mixid_width, num_classes,
170
- seed, speaker_metadata_tiers, textgrid_metadata_tiers, version)
171
- VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
172
- """,
173
- (
174
- 1,
175
- json.dumps(config["asr_configs"]),
176
- config["class_balancing"],
177
- config["feature"],
178
- 0,
179
- config["num_classes"],
180
- config["seed"],
181
- "",
182
- "",
183
- MIXDB_VERSION,
184
- ),
185
- )
186
- con.commit()
187
- con.close()
159
+
160
+ with SQLiteDatabase(location=location, readonly=False, test=test, verbose=verbose) as c:
161
+ c.execute(
162
+ """
163
+ INSERT INTO top (id, asr_configs, class_balancing, feature, mixid_width, num_classes,
164
+ seed, speaker_metadata_tiers, textgrid_metadata_tiers, version)
165
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
166
+ """,
167
+ (
168
+ 1,
169
+ json.dumps(config["asr_configs"]),
170
+ config["class_balancing"],
171
+ config["feature"],
172
+ 0,
173
+ config["num_classes"],
174
+ config["seed"],
175
+ "",
176
+ "",
177
+ MIXDB_VERSION,
178
+ ),
179
+ )
188
180
 
189
181
 
190
182
  def populate_class_label_table(location: str, config: dict, test: bool = False, verbose: bool = False) -> None:
191
183
  """Populate class_label table"""
192
- from .mixdb import db_connection
193
-
194
- con = db_connection(location=location, readonly=False, test=test, verbose=verbose)
195
- con.executemany(
196
- "INSERT INTO class_label (label) VALUES (?)",
197
- [(item,) for item in config["class_labels"]],
198
- )
199
- con.commit()
200
- con.close()
184
+ with SQLiteDatabase(location=location, readonly=False, test=test, verbose=verbose) as c:
185
+ c.executemany(
186
+ "INSERT INTO class_label (label) VALUES (?)",
187
+ [(item,) for item in config["class_labels"]],
188
+ )
201
189
 
202
190
 
203
191
  def populate_class_weights_threshold_table(
@@ -207,8 +195,6 @@ def populate_class_weights_threshold_table(
207
195
  verbose: bool = False,
208
196
  ) -> None:
209
197
  """Populate class_weights_threshold table"""
210
- from .mixdb import db_connection
211
-
212
198
  class_weights_threshold = config["class_weights_threshold"]
213
199
  num_classes = config["num_classes"]
214
200
 
@@ -221,61 +207,53 @@ def populate_class_weights_threshold_table(
221
207
  if len(class_weights_threshold) != num_classes:
222
208
  raise ValueError(f"invalid class_weights_threshold length: {len(class_weights_threshold)}")
223
209
 
224
- con = db_connection(location=location, readonly=False, test=test, verbose=verbose)
225
- con.executemany(
226
- "INSERT INTO class_weights_threshold (threshold) VALUES (?)",
227
- [(item,) for item in class_weights_threshold],
228
- )
229
- con.commit()
230
- con.close()
210
+ with SQLiteDatabase(location=location, readonly=False, test=test, verbose=verbose) as c:
211
+ c.executemany(
212
+ "INSERT INTO class_weights_threshold (threshold) VALUES (?)",
213
+ [(item,) for item in class_weights_threshold],
214
+ )
231
215
 
232
216
 
233
217
  def populate_spectral_mask_table(location: str, config: dict, test: bool = False, verbose: bool = False) -> None:
234
218
  """Populate spectral_mask table"""
235
219
  from .config import get_spectral_masks
236
- from .mixdb import db_connection
237
-
238
- con = db_connection(location=location, readonly=False, test=test, verbose=verbose)
239
- con.executemany(
240
- """
241
- INSERT INTO spectral_mask (f_max_width, f_num, t_max_percent, t_max_width, t_num) VALUES (?, ?, ?, ?, ?)
242
- """,
243
- [
244
- (
245
- item.f_max_width,
246
- item.f_num,
247
- item.t_max_percent,
248
- item.t_max_width,
249
- item.t_num,
250
- )
251
- for item in get_spectral_masks(config)
252
- ],
253
- )
254
- con.commit()
255
- con.close()
220
+
221
+ with SQLiteDatabase(location=location, readonly=False, test=test, verbose=verbose) as c:
222
+ c.executemany(
223
+ """
224
+ INSERT INTO spectral_mask (f_max_width, f_num, t_max_percent, t_max_width, t_num) VALUES (?, ?, ?, ?, ?)
225
+ """,
226
+ [
227
+ (
228
+ item.f_max_width,
229
+ item.f_num,
230
+ item.t_max_percent,
231
+ item.t_max_width,
232
+ item.t_num,
233
+ )
234
+ for item in get_spectral_masks(config)
235
+ ],
236
+ )
256
237
 
257
238
 
258
239
  def populate_truth_parameters_table(location: str, config: dict, test: bool = False, verbose: bool = False) -> None:
259
240
  """Populate truth_parameters table"""
260
241
  from .config import get_truth_parameters
261
- from .mixdb import db_connection
262
-
263
- con = db_connection(location=location, readonly=False, test=test, verbose=verbose)
264
- con.executemany(
265
- """
266
- INSERT INTO truth_parameters (category, name, parameters) VALUES (?, ?, ?)
267
- """,
268
- [
269
- (
270
- item.category,
271
- item.name,
272
- item.parameters,
273
- )
274
- for item in get_truth_parameters(config)
275
- ],
276
- )
277
- con.commit()
278
- con.close()
242
+
243
+ with SQLiteDatabase(location=location, readonly=False, test=test, verbose=verbose) as c:
244
+ c.executemany(
245
+ """
246
+ INSERT INTO truth_parameters (category, name, parameters) VALUES (?, ?, ?)
247
+ """,
248
+ [
249
+ (
250
+ item.category,
251
+ item.name,
252
+ item.parameters,
253
+ )
254
+ for item in get_truth_parameters(config)
255
+ ],
256
+ )
279
257
 
280
258
 
281
259
  def populate_source_file_table(
@@ -284,72 +262,65 @@ def populate_source_file_table(
284
262
  test: bool = False,
285
263
  verbose: bool = False,
286
264
  ) -> None:
287
- """Populate source file table"""
265
+ """Populate the source file table"""
288
266
  import json
289
267
  from pathlib import Path
290
268
 
291
- from .mixdb import db_connection
292
-
293
269
  _populate_truth_config_table(location, files, test, verbose)
294
270
  _populate_speaker_table(location, files, test, verbose)
295
271
 
296
- con = db_connection(location=location, readonly=False, test=test, verbose=verbose)
272
+ with SQLiteDatabase(location=location, readonly=False, test=test, verbose=verbose) as c:
273
+ textgrid_metadata_tiers: set[str] = set()
274
+ for file in files:
275
+ # Get TextGrid tiers for source file and add to collection
276
+ tiers = _get_textgrid_tiers_from_source_file(file.name)
277
+ for tier in tiers:
278
+ textgrid_metadata_tiers.add(tier)
279
+
280
+ # Get truth settings for file
281
+ truth_config_ids: list[int] = []
282
+ if file.truth_configs:
283
+ for name, config in file.truth_configs.items():
284
+ ts = json.dumps({"name": name} | config.to_dict())
285
+ c.execute(
286
+ "SELECT truth_config.id FROM truth_config WHERE ? = truth_config.config",
287
+ (ts,),
288
+ )
289
+ truth_config_ids.append(c.fetchone()[0])
297
290
 
298
- cur = con.cursor()
299
- textgrid_metadata_tiers: set[str] = set()
300
- for file in files:
301
- # Get TextGrid tiers for source file and add to collection
302
- tiers = _get_textgrid_tiers_from_source_file(file.name)
303
- for tier in tiers:
304
- textgrid_metadata_tiers.add(tier)
291
+ # Get speaker_id for source file
292
+ c.execute("SELECT speaker.id FROM speaker WHERE ? = speaker.parent", (Path(file.name).parent.as_posix(),))
293
+ result = c.fetchone()
294
+ speaker_id = None
295
+ if result is not None:
296
+ speaker_id = result[0]
305
297
 
306
- # Get truth settings for file
307
- truth_config_ids: list[int] = []
308
- if file.truth_configs:
309
- for name, config in file.truth_configs.items():
310
- ts = json.dumps({"name": name} | config.to_dict())
311
- cur.execute(
312
- "SELECT truth_config.id FROM truth_config WHERE ? = truth_config.config",
313
- (ts,),
298
+ # Add entry
299
+ c.execute(
300
+ """
301
+ INSERT INTO source_file (category, class_indices, level_type, name, samples, speaker_id)
302
+ VALUES (?, ?, ?, ?, ?, ?)
303
+ """,
304
+ (
305
+ file.category,
306
+ json.dumps(file.class_indices),
307
+ file.level_type,
308
+ file.name,
309
+ file.samples,
310
+ speaker_id,
311
+ ),
312
+ )
313
+ source_file_id = c.lastrowid
314
+ for truth_config_id in truth_config_ids:
315
+ c.execute(
316
+ "INSERT INTO source_file_truth_config (source_file_id, truth_config_id) VALUES (?, ?)",
317
+ (source_file_id, truth_config_id),
314
318
  )
315
- truth_config_ids.append(cur.fetchone()[0])
316
319
 
317
- # Get speaker_id for source file
318
- cur.execute("SELECT speaker.id FROM speaker WHERE ? = speaker.parent", (Path(file.name).parent.as_posix(),))
319
- result = cur.fetchone()
320
- speaker_id = None
321
- if result is not None:
322
- speaker_id = result[0]
323
-
324
- # Add entry
325
- cur.execute(
326
- """
327
- INSERT INTO source_file (category, class_indices, level_type, name, samples, speaker_id)
328
- VALUES (?, ?, ?, ?, ?, ?)
329
- """,
330
- (
331
- file.category,
332
- json.dumps(file.class_indices),
333
- file.level_type,
334
- file.name,
335
- file.samples,
336
- speaker_id,
337
- ),
320
+ # Update textgrid_metadata_tiers in the top table
321
+ c.execute(
322
+ "UPDATE top SET textgrid_metadata_tiers=? WHERE ? = id", (json.dumps(sorted(textgrid_metadata_tiers)), 1)
338
323
  )
339
- source_file_id = cur.lastrowid
340
- for truth_config_id in truth_config_ids:
341
- cur.execute(
342
- "INSERT INTO source_file_truth_config (source_file_id, truth_config_id) VALUES (?, ?)",
343
- (source_file_id, truth_config_id),
344
- )
345
-
346
- # Update textgrid_metadata_tiers in the top table
347
- con.execute(
348
- "UPDATE top SET textgrid_metadata_tiers=? WHERE ? = id", (json.dumps(sorted(textgrid_metadata_tiers)), 1)
349
- )
350
-
351
- con.commit()
352
- con.close()
353
324
 
354
325
 
355
326
  def populate_impulse_response_file_table(
@@ -358,40 +329,30 @@ def populate_impulse_response_file_table(
358
329
  test: bool = False,
359
330
  verbose: bool = False,
360
331
  ) -> None:
361
- """Populate impulse response file table"""
362
- from .mixdb import db_connection
363
-
332
+ """Populate the impulse response file table"""
364
333
  _populate_impulse_response_tag_table(location, files, test, verbose)
365
334
 
366
- con = db_connection(location=location, readonly=False, test=test, verbose=verbose)
367
-
368
- cur = con.cursor()
369
- for file in files:
370
- # Get tags for file
371
- tag_ids: list[int] = []
372
- for tag in file.tags:
373
- cur.execute("SELECT id FROM ir_tag WHERE ? = tag", (tag,))
374
- tag_ids.append(cur.fetchone()[0])
335
+ with SQLiteDatabase(location=location, readonly=False, test=test, verbose=verbose) as c:
336
+ for file in files:
337
+ # Get the tags for the file
338
+ tag_ids: list[int] = []
339
+ for tag in file.tags:
340
+ c.execute("SELECT id FROM ir_tag WHERE ? = tag", (tag,))
341
+ tag_ids.append(c.fetchone()[0])
375
342
 
376
- cur.execute("INSERT INTO ir_file (delay, name) VALUES (?, ?)", (file.delay, file.name))
343
+ c.execute("INSERT INTO ir_file (delay, name) VALUES (?, ?)", (file.delay, file.name))
377
344
 
378
- file_id = cur.lastrowid
379
- for tag_id in tag_ids:
380
- cur.execute("INSERT INTO ir_file_ir_tag (file_id, tag_id) VALUES (?, ?)", (file_id, tag_id))
381
-
382
- con.commit()
383
- con.close()
345
+ file_id = c.lastrowid
346
+ for tag_id in tag_ids:
347
+ c.execute("INSERT INTO ir_file_ir_tag (file_id, tag_id) VALUES (?, ?)", (file_id, tag_id))
384
348
 
385
349
 
386
350
  def update_mixid_width(location: str, num_mixtures: int, test: bool = False, verbose: bool = False) -> None:
387
351
  """Update the mixid width"""
388
352
  from ..utils.max_text_width import max_text_width
389
- from .mixdb import db_connection
390
353
 
391
- con = db_connection(location=location, readonly=False, test=test, verbose=verbose)
392
- con.execute("UPDATE top SET mixid_width=? WHERE ? = id", (max_text_width(num_mixtures), 1))
393
- con.commit()
394
- con.close()
354
+ with SQLiteDatabase(location=location, readonly=False, test=test, verbose=verbose) as c:
355
+ c.execute("UPDATE top SET mixid_width=? WHERE ? = id", (max_text_width(num_mixtures), 1))
395
356
 
396
357
 
397
358
  def generate_mixtures(
@@ -447,53 +408,49 @@ def populate_mixture_table(
447
408
  from ..utils.parallel import track
448
409
  from .helpers import from_mixture
449
410
  from .helpers import from_source
450
- from .mixdb import db_connection
451
-
452
- con = db_connection(location=location, readonly=False, test=test, verbose=verbose)
453
411
 
454
- # Populate source table
455
412
  if logging:
456
413
  logger.info("Populating mixture and source tables")
457
- for mixture in track(mixtures, disable=not show_progress):
458
- m_id = int(mixture.name) + 1
459
- con.execute(
460
- """
461
- INSERT INTO mixture (id, name, samples, spectral_mask_id, spectral_mask_seed)
462
- VALUES (?, ?, ?, ?, ?)
463
- """,
464
- (m_id, *from_mixture(mixture)),
465
- )
466
414
 
467
- for source in mixture.all_sources.values():
468
- con.execute(
415
+ with SQLiteDatabase(location=location, readonly=False, test=test, verbose=verbose) as c:
416
+ # Populate source table
417
+ for mixture in track(mixtures, disable=not show_progress):
418
+ m_id = int(mixture.name) + 1
419
+ c.execute(
469
420
  """
470
- INSERT OR IGNORE INTO source (effects, file_id, pre_tempo, repeat, snr, snr_gain, snr_random, start)
471
- VALUES (?, ?, ?, ?, ?, ?, ?, ?)
472
- """,
473
- from_source(source),
421
+ INSERT INTO mixture (id, name, samples, spectral_mask_id, spectral_mask_seed)
422
+ VALUES (?, ?, ?, ?, ?)
423
+ """,
424
+ (m_id, *from_mixture(mixture)),
474
425
  )
475
426
 
476
- source_id = con.execute(
477
- """
478
- SELECT id
479
- FROM source
480
- WHERE ? = effects
481
- AND ? = file_id
482
- AND ? = pre_tempo
483
- AND ? = repeat
484
- AND ? = snr
485
- AND ? = snr_gain
486
- AND ? = snr_random
487
- AND ? = start
488
- """,
489
- from_source(source),
490
- ).fetchone()[0]
491
- con.execute("INSERT INTO mixture_source (mixture_id, source_id) VALUES (?, ?)", (m_id, source_id))
427
+ for source in mixture.all_sources.values():
428
+ c.execute(
429
+ """
430
+ INSERT OR IGNORE INTO source (effects, file_id, pre_tempo, repeat, snr, snr_gain, snr_random, start)
431
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?)
432
+ """,
433
+ from_source(source),
434
+ )
492
435
 
493
- if logging:
494
- logger.info("Closing mixture and source tables")
495
- con.commit()
496
- con.close()
436
+ source_id = c.execute(
437
+ """
438
+ SELECT id
439
+ FROM source
440
+ WHERE ? = effects
441
+ AND ? = file_id
442
+ AND ? = pre_tempo
443
+ AND ? = repeat
444
+ AND ? = snr
445
+ AND ? = snr_gain
446
+ AND ? = snr_random
447
+ AND ? = start
448
+ """,
449
+ from_source(source),
450
+ ).fetchone()[0]
451
+ c.execute("INSERT INTO mixture_source (mixture_id, source_id) VALUES (?, ?)", (m_id, source_id))
452
+ if logging:
453
+ logger.info("Closing mixture and source tables")
497
454
 
498
455
 
499
456
  def update_mixture(mixdb: MixtureDatabase, mixture: Mixture, with_data: bool = False) -> tuple[Mixture, GenMixData]:
@@ -867,14 +824,13 @@ def _populate_speaker_table(
867
824
  test: bool = False,
868
825
  verbose: bool = False,
869
826
  ) -> None:
870
- """Populate speaker table"""
827
+ """Populate the speaker table"""
871
828
  import json
872
829
  from pathlib import Path
873
830
 
874
831
  import yaml
875
832
 
876
833
  from ..utils.tokenized_shell_vars import tokenized_expand
877
- from .mixdb import db_connection
878
834
 
879
835
  # Determine columns for speaker table
880
836
  all_parents = {Path(file.name).parent for file in source_files}
@@ -891,36 +847,32 @@ def _populate_speaker_table(
891
847
  new_columns.append(column)
892
848
  new_columns = sorted(set(new_columns))
893
849
 
894
- con = db_connection(location=location, readonly=False, test=test, verbose=verbose)
895
-
896
- for new_column in new_columns:
897
- con.execute(f"ALTER TABLE speaker ADD COLUMN {new_column} TEXT")
850
+ with SQLiteDatabase(location=location, readonly=False, test=test, verbose=verbose) as c:
851
+ for new_column in new_columns:
852
+ c.execute(f"ALTER TABLE speaker ADD COLUMN {new_column} TEXT")
898
853
 
899
- # Populate speaker table
900
- speaker_rows: list[tuple[str, ...]] = []
901
- for key in speakers:
902
- entry = (speakers[key].get(column, None) for column in new_columns)
903
- speaker_rows.append((key.as_posix(), *entry)) # type: ignore[arg-type]
854
+ # Populate speaker table
855
+ speaker_rows: list[tuple[str, ...]] = []
856
+ for key in speakers:
857
+ entry = (speakers[key].get(column, None) for column in new_columns)
858
+ speaker_rows.append((key.as_posix(), *entry)) # type: ignore[arg-type]
904
859
 
905
- column_ids = ", ".join(["parent", *new_columns])
906
- column_values = ", ".join(["?"] * (len(new_columns) + 1))
907
- con.executemany(f"INSERT INTO speaker ({column_ids}) VALUES ({column_values})", speaker_rows)
860
+ column_ids = ", ".join(["parent", *new_columns])
861
+ column_values = ", ".join(["?"] * (len(new_columns) + 1))
862
+ c.executemany(f"INSERT INTO speaker ({column_ids}) VALUES ({column_values})", speaker_rows)
908
863
 
909
- con.execute("CREATE INDEX speaker_parent_idx ON speaker (parent)")
864
+ c.execute("CREATE INDEX speaker_parent_idx ON speaker (parent)")
910
865
 
911
- # Update speaker_metadata_tiers in the top table
912
- tiers = [
913
- description[0]
914
- for description in con.execute("SELECT * FROM speaker").description
915
- if description[0] not in ("id", "parent")
916
- ]
917
- con.execute("UPDATE top SET speaker_metadata_tiers=? WHERE ? = id", (json.dumps(tiers), 1))
866
+ # Update speaker_metadata_tiers in the top table
867
+ tiers = [
868
+ description[0]
869
+ for description in c.execute("SELECT * FROM speaker").description
870
+ if description[0] not in ("id", "parent")
871
+ ]
872
+ c.execute("UPDATE top SET speaker_metadata_tiers=? WHERE ? = id", (json.dumps(tiers), 1))
918
873
 
919
- if "speaker_id" in tiers:
920
- con.execute("CREATE INDEX speaker_speaker_id_idx ON source_file (speaker_id)")
921
-
922
- con.commit()
923
- con.close()
874
+ if "speaker_id" in tiers:
875
+ c.execute("CREATE INDEX speaker_speaker_id_idx ON source_file (speaker_id)")
924
876
 
925
877
 
926
878
  def _populate_truth_config_table(
@@ -932,24 +884,18 @@ def _populate_truth_config_table(
932
884
  """Populate truth_config table"""
933
885
  import json
934
886
 
935
- from .mixdb import db_connection
936
-
937
- con = db_connection(location=location, readonly=False, test=test, verbose=verbose)
938
-
939
- # Populate truth_config table
940
- truth_configs: list[str] = []
941
- for file in source_files:
942
- for name, config in file.truth_configs.items():
943
- ts = json.dumps({"name": name} | config.to_dict())
944
- if ts not in truth_configs:
945
- truth_configs.append(ts)
946
- con.executemany(
947
- "INSERT INTO truth_config (config) VALUES (?)",
948
- [(item,) for item in truth_configs],
949
- )
950
-
951
- con.commit()
952
- con.close()
887
+ with SQLiteDatabase(location=location, readonly=False, test=test, verbose=verbose) as c:
888
+ # Populate truth_config table
889
+ truth_configs: list[str] = []
890
+ for file in source_files:
891
+ for name, config in file.truth_configs.items():
892
+ ts = json.dumps({"name": name} | config.to_dict())
893
+ if ts not in truth_configs:
894
+ truth_configs.append(ts)
895
+ c.executemany(
896
+ "INSERT INTO truth_config (config) VALUES (?)",
897
+ [(item,) for item in truth_configs],
898
+ )
953
899
 
954
900
 
955
901
  def _populate_impulse_response_tag_table(
@@ -959,14 +905,8 @@ def _populate_impulse_response_tag_table(
959
905
  verbose: bool = False,
960
906
  ) -> None:
961
907
  """Populate ir_tag table"""
962
- from .mixdb import db_connection
963
-
964
- con = db_connection(location=location, readonly=False, test=test, verbose=verbose)
965
-
966
- con.executemany(
967
- "INSERT INTO ir_tag (tag) VALUES (?)",
968
- [(tag,) for tag in {tag for file in files for tag in file.tags}],
969
- )
970
-
971
- con.commit()
972
- con.close()
908
+ with SQLiteDatabase(location=location, readonly=False, test=test, verbose=verbose) as c:
909
+ c.executemany(
910
+ "INSERT INTO ir_tag (tag) VALUES (?)",
911
+ [(tag,) for tag in {tag for file in files for tag in file.tags}],
912
+ )