sonusai 1.0.10__py3-none-any.whl → 1.0.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -15,6 +15,7 @@ from .calc_segsnr_f import calc_segsnr_f_bin
15
15
  from .calc_speech import calc_speech
16
16
  from .calc_wer import calc_wer
17
17
  from .calc_wsdr import calc_wsdr
18
+ from .calculate_metrics import calculate_metrics
18
19
  from .class_summary import class_summary
19
20
  from .confusion_matrix_summary import confusion_matrix_summary
20
21
  from .one_hot import one_hot
@@ -0,0 +1,395 @@
1
+ import functools
2
+ from typing import Any
3
+
4
+ import numpy as np
5
+ from pystoi import stoi
6
+
7
+ from ..constants import SAMPLE_RATE
8
+ from ..datatypes import AudioF
9
+ from ..datatypes import AudioStatsMetrics
10
+ from ..datatypes import AudioT
11
+ from ..datatypes import Segsnr
12
+ from ..datatypes import SpeechMetrics
13
+ from ..mixture.mixdb import MixtureDatabase
14
+ from ..utils.asr import calc_asr
15
+ from ..utils.db import linear_to_db
16
+ from .calc_audio_stats import calc_audio_stats
17
+ from .calc_pesq import calc_pesq
18
+ from .calc_phase_distance import calc_phase_distance
19
+ from .calc_segsnr_f import calc_segsnr_f
20
+ from .calc_segsnr_f import calc_segsnr_f_bin
21
+ from .calc_speech import calc_speech
22
+ from .calc_wer import calc_wer
23
+ from .calc_wsdr import calc_wsdr
24
+
25
+
26
def calculate_metrics(mixdb: MixtureDatabase, m_id: int, metrics: list[str], force: bool = False) -> dict[str, Any]:
    """Get metrics data for the given mixture ID

    Unless ``force`` is set, each metric is first read with ``mixdb.read_mixture_data`` and is only
    computed when that returns ``None`` for it. Intermediate data (audio, statistics, ASR text) is
    memoized for the duration of this call so that several metrics can share it.

    :param mixdb: Mixture database object
    :param m_id: Zero-based mixture ID
    :param metrics: List of metrics to get
    :param force: Force computing data from original sources regardless of whether cached data exists
    :return: Dictionary of metric data; for every requested 'mxwer.<name>' metric the entries
        'mxasr.<name>' and 'sasr.<name>' are added as well
    :raises ValueError: If an ASR metric is not of the form '<metric>.<name>' or names an unknown ASR
    :raises AttributeError: If a metric name is not recognized
    """

    # Define cached functions for expensive operations.
    # The closures are created per call, so every cache below lives only as long as this call.
    @functools.lru_cache(maxsize=1)
    def mixture_sources() -> dict[str, AudioT]:
        return mixdb.mixture_sources(m_id)

    @functools.lru_cache(maxsize=1)
    def mixture_source() -> AudioT:
        return mixdb.mixture_source(m_id)

    @functools.lru_cache(maxsize=1)
    def mixture_source_f() -> AudioF:
        return mixdb.mixture_source_f(m_id)

    @functools.lru_cache(maxsize=1)
    def mixture_noise() -> AudioT:
        return mixdb.mixture_noise(m_id)

    @functools.lru_cache(maxsize=1)
    def mixture_noise_f() -> AudioF:
        return mixdb.mixture_noise_f(m_id)

    @functools.lru_cache(maxsize=1)
    def mixture_mixture() -> AudioT:
        return mixdb.mixture_mixture(m_id)

    @functools.lru_cache(maxsize=1)
    def mixture_mixture_f() -> AudioF:
        return mixdb.mixture_mixture_f(m_id)

    @functools.lru_cache(maxsize=1)
    def mixture_segsnr() -> Segsnr:
        return mixdb.mixture_segsnr(m_id)

    @functools.lru_cache(maxsize=1)
    def calculate_pesq() -> dict[str, float]:
        return {category: calc_pesq(mixture_mixture(), audio) for category, audio in mixture_sources().items()}

    @functools.lru_cache(maxsize=1)
    def calculate_speech() -> dict[str, SpeechMetrics]:
        return {
            category: calc_speech(mixture_mixture(), audio, calculate_pesq()[category])
            for category, audio in mixture_sources().items()
        }

    @functools.lru_cache(maxsize=1)
    def mixture_stats() -> AudioStatsMetrics:
        return calc_audio_stats(mixture_mixture(), mixdb.fg_info.ft_config.length / SAMPLE_RATE)

    @functools.lru_cache(maxsize=1)
    def sources_stats() -> dict[str, AudioStatsMetrics]:
        return {
            category: calc_audio_stats(audio, mixdb.fg_info.ft_config.length / SAMPLE_RATE)
            for category, audio in mixture_sources().items()
        }

    @functools.lru_cache(maxsize=1)
    def source_stats() -> AudioStatsMetrics:
        return calc_audio_stats(mixture_source(), mixdb.fg_info.ft_config.length / SAMPLE_RATE)

    @functools.lru_cache(maxsize=1)
    def noise_stats() -> AudioStatsMetrics:
        return calc_audio_stats(mixture_noise(), mixdb.fg_info.ft_config.length / SAMPLE_RATE)

    # Cache ASR configurations
    @functools.lru_cache(maxsize=32)
    def get_asr_config(asr_name: str) -> dict:
        value = mixdb.asr_configs.get(asr_name, None)
        if value is None:
            raise ValueError(f"Unrecognized ASR name: '{asr_name}'")
        return value

    # Cache ASR results for sources, source and mixture
    @functools.lru_cache(maxsize=16)
    def sources_asr(asr_name: str) -> dict[str, str]:
        return {
            category: calc_asr(audio, **get_asr_config(asr_name)).text for category, audio in mixture_sources().items()
        }

    @functools.lru_cache(maxsize=16)
    def source_asr(asr_name: str) -> str:
        return calc_asr(mixture_source(), **get_asr_config(asr_name)).text

    @functools.lru_cache(maxsize=16)
    def mixture_asr(asr_name: str) -> str:
        return calc_asr(mixture_mixture(), **get_asr_config(asr_name)).text

    def get_asr_name(m: str) -> str:
        # ASR metrics are spelled '<metric>.<name>'; anything else is rejected
        parts = m.split(".")
        if len(parts) != 2:
            raise ValueError(f"Unrecognized format: '{m}'; must be of the form: '<metric>.<name>'")
        return parts[1]

    def zeros_per_category() -> dict[str, int]:
        # Result used for noise-only mixtures: one zero per source category.
        # The keys come from the sources themselves, so the PESQ/speech computations
        # (whose values would be discarded anyway) are not run.
        return dict.fromkeys(mixture_sources(), 0)

    # Audio statistics metrics: '<prefix><field>' maps to attribute <field> of the matching stats object
    stats_fields = ("dco", "min", "max", "pkdb", "lrms", "pkr", "tr", "cr", "fl", "pkc")
    mixture_stats_metrics = {f"mx{field}": field for field in stats_fields}
    sources_stats_metrics = {f"s{field}": field for field in stats_fields}
    source_stats_metrics = {f"mxs{field}": field for field in stats_fields}
    noise_stats_metrics = {f"n{field}": field for field in stats_fields}

    def calc(m: str) -> Any:
        # SNR is taken from the mixture record; stored metric data is never consulted for it
        if m == "mxsnr":
            return {category: source.snr for category, source in mixdb.mixture(m_id).all_sources.items()}

        # Get cached data first, if exists
        if not force:
            value = mixdb.read_mixture_data(m_id, m)[m]
            if value is not None:
                return value

        # Otherwise, generate data as needed
        if m.startswith("mxwer"):
            asr_name = get_asr_name(m)

            if mixdb.mixture(m_id).is_noise_only:
                # noise only, ignore/reset target asr
                return float("nan")

            if source_asr(asr_name):
                return calc_wer(mixture_asr(asr_name), source_asr(asr_name)).wer * 100

            # TODO: should this be NaN like above?
            return float(0)

        if m.startswith("basewer"):
            asr_name = get_asr_name(m)

            text = mixdb.mixture_speech_metadata(m_id, "text")
            # Categories whose text metadata is not a string get a WER of 0
            return {
                category: calc_wer(source, str(text[category])).wer * 100 if isinstance(text[category], str) else 0
                for category, source in sources_asr(asr_name).items()
            }

        if m.startswith("mxasr"):
            return mixture_asr(get_asr_name(m))

        if m.startswith("sasr"):
            return sources_asr(get_asr_name(m))

        if m.startswith("mxsasr"):
            return source_asr(get_asr_name(m))

        if m == "mxssnr_avg":
            return calc_segsnr_f(mixture_segsnr()).avg

        if m == "mxssnr_std":
            return calc_segsnr_f(mixture_segsnr()).std

        if m == "mxssnr_avg_db":
            val = calc_segsnr_f(mixture_segsnr()).avg
            if val is not None:
                return linear_to_db(val)
            return None

        if m == "mxssnr_std_db":
            val = calc_segsnr_f(mixture_segsnr()).std
            if val is not None:
                return linear_to_db(val)
            return None

        if m == "mxssnrdb_avg":
            return calc_segsnr_f(mixture_segsnr()).db_avg

        if m == "mxssnrdb_std":
            return calc_segsnr_f(mixture_segsnr()).db_std

        if m == "mxssnrf_avg":
            return calc_segsnr_f_bin(mixture_source_f(), mixture_noise_f()).avg

        if m == "mxssnrf_std":
            return calc_segsnr_f_bin(mixture_source_f(), mixture_noise_f()).std

        if m == "mxssnrdbf_avg":
            return calc_segsnr_f_bin(mixture_source_f(), mixture_noise_f()).db_avg

        if m == "mxssnrdbf_std":
            return calc_segsnr_f_bin(mixture_source_f(), mixture_noise_f()).db_std

        if m == "mxpesq":
            if mixdb.mixture(m_id).is_noise_only:
                return zeros_per_category()
            return calculate_pesq()

        if m == "mxcsig":
            if mixdb.mixture(m_id).is_noise_only:
                return zeros_per_category()
            return {category: s.csig for category, s in calculate_speech().items()}

        if m == "mxcbak":
            if mixdb.mixture(m_id).is_noise_only:
                return zeros_per_category()
            return {category: s.cbak for category, s in calculate_speech().items()}

        if m == "mxcovl":
            if mixdb.mixture(m_id).is_noise_only:
                return zeros_per_category()
            return {category: s.covl for category, s in calculate_speech().items()}

        if m == "mxwsdr":
            # Each signal becomes one column; noise is the second column of both hypothesis and reference
            mixture = mixture_mixture()[:, np.newaxis]
            target = mixture_source()[:, np.newaxis]
            noise = mixture_noise()[:, np.newaxis]
            return calc_wsdr(
                hypothesis=np.concatenate((mixture, noise), axis=1),
                reference=np.concatenate((target, noise), axis=1),
                with_log=True,
            )[0]

        if m == "mxpd":
            return calc_phase_distance(hypothesis=mixture_mixture_f(), reference=mixture_source_f())[0]

        if m == "mxstoi":
            return stoi(
                x=mixture_source(),
                y=mixture_mixture(),
                fs_sig=SAMPLE_RATE,
                extended=False,
            )

        if m in mixture_stats_metrics:
            return getattr(mixture_stats(), mixture_stats_metrics[m])

        if m in sources_stats_metrics:
            field = sources_stats_metrics[m]
            return {category: getattr(s, field) for category, s in sources_stats().items()}

        if m in source_stats_metrics:
            return getattr(source_stats(), source_stats_metrics[m])

        if m in noise_stats_metrics:
            return getattr(noise_stats(), noise_stats_metrics[m])

        # The 'sed*' metrics are not computed here; fixed placeholder values are returned
        if m == "sedavg":
            return 0

        if m == "sedcnt":
            return 0

        if m == "sedtop3":
            return np.zeros(3, dtype=np.float32)

        if m == "sedtopn":
            return 0

        if m == "ssnr":
            return mixture_segsnr()

        raise AttributeError(f"Unrecognized metric: '{m}'")

    result: dict[str, Any] = {}
    for metric in metrics:
        result[metric] = calc(metric)

        # Check for metrics dependencies and add them even if not explicitly requested.
        # NOTE(review): 'mxwer' is computed from source_asr (the 'mxsasr' data), yet the dependency
        # added here is 'sasr' (sources_asr) - confirm this is the intended companion metric.
        if metric.startswith("mxwer"):
            # metric[6:] is the ASR name following the 'mxwer.' prefix
            dependencies = ("mxasr." + metric[6:], "sasr." + metric[6:])
            for dependency in dependencies:
                result[dependency] = calc(dependency)

    return result
@@ -28,4 +28,4 @@ from .helpers import inverse_transform
28
28
  from .helpers import write_mixture_metadata
29
29
  from .log_duration_and_sizes import log_duration_and_sizes
30
30
  from .mixdb import MixtureDatabase
31
- from .db_file import db_file
31
+ from .db import db_file
sonusai/mixture/db.py CHANGED
@@ -2,12 +2,21 @@ import contextlib
2
2
  import sqlite3
3
3
  from os import remove
4
4
  from os.path import exists
5
+ from os.path import join
6
+ from os.path import normpath
5
7
  from sqlite3 import Connection
6
8
  from sqlite3 import Cursor
7
9
  from typing import Any
8
10
 
9
- from .. import logger_db
10
- from .db_file import db_file
11
+ from sonusai import logger_db
12
+
13
+
14
def db_file(location: str, test: bool = False) -> str:
    """Build the normalized path of the mixture database file inside a location

    :param location: Directory that holds the mixture database
    :param test: Use the test database file name instead of the regular one
    :return: Normalized path of 'location' joined with the selected database file name
    """
    # The file name constants are imported at call time, not at module import time
    from .constants import MIXDB_NAME
    from .constants import TEST_MIXDB_NAME

    if test:
        return normpath(join(location, TEST_MIXDB_NAME))
    return normpath(join(location, MIXDB_NAME))
11
20
 
12
21
 
13
22
  class SQLiteDatabase:
@@ -15,9 +24,6 @@ class SQLiteDatabase:
15
24
 
16
25
  # Constants for database configuration
17
26
  READONLY_MODE = "?mode=ro"
18
- WRITE_OPTIMIZED_PRAGMAS = (
19
- "?_journal_mode=OFF&_synchronous=OFF&_cache_size=10000&_temp_store=MEMORY&_locking_mode=EXCLUSIVE"
20
- )
21
27
  CONNECTION_TIMEOUT = 20
22
28
 
23
29
  def __init__(
@@ -39,7 +45,7 @@ class SQLiteDatabase:
39
45
  """
40
46
  self.location = location
41
47
  self.create = create
42
- self.readonly = readonly
48
+ self.readonly = readonly and not create
43
49
  self.test = test
44
50
  self.verbose = verbose
45
51
  self.con: Connection | None = None
@@ -61,6 +67,7 @@ class SQLiteDatabase:
61
67
  raise
62
68
 
63
69
  if self.cur:
70
+ self.cur.execute("BEGIN TRANSACTION")
64
71
  return self.cur
65
72
  raise sqlite3.Error("Failed to connect to database")
66
73
 
@@ -78,9 +85,13 @@ class SQLiteDatabase:
78
85
  exc_tb: The exception traceback, if any.
79
86
  """
80
87
  if self.con:
81
- if exc_type is None and not self.readonly:
82
- # Commit only on successful exit if not readonly
83
- self.con.commit()
88
+ if not self.readonly:
89
+ if exc_type is None:
90
+ # Commit only on successful exit
91
+ self.con.commit()
92
+ else:
93
+ # Rollback on exception
94
+ self.con.rollback()
84
95
  self._close_resources()
85
96
 
86
97
  def _close_resources(self) -> None:
@@ -107,9 +118,21 @@ class SQLiteDatabase:
107
118
  uri = self._build_connection_uri(db_path)
108
119
 
109
120
  try:
110
- self.con = sqlite3.connect(f"file:{uri}", uri=True, timeout=self.CONNECTION_TIMEOUT)
121
+ self.con = sqlite3.connect(
122
+ f"file:{uri}",
123
+ uri=True,
124
+ timeout=self.CONNECTION_TIMEOUT,
125
+ isolation_level=None,
126
+ )
111
127
  if self.verbose and self.con:
112
128
  self.con.set_trace_callback(logger_db.debug)
129
+ if self.create or not self.readonly:
130
+ self.con.execute("PRAGMA journal_mode=wal")
131
+ self.con.execute("PRAGMA synchronous=0") # off
132
+ self.con.execute("PRAGMA cache_size=10000")
133
+ self.con.execute("PRAGMA temp_store=2") # memory
134
+ self.con.execute("PRAGMA locking_mode=exclusive")
135
+ self.con.commit()
113
136
  except sqlite3.Error as e:
114
137
  raise sqlite3.Error(f"Failed to connect to database: {e}") from e
115
138
 
@@ -153,11 +176,4 @@ class SQLiteDatabase:
153
176
  if not self.create and self.readonly:
154
177
  uri += self.READONLY_MODE
155
178
 
156
- # Add optimized pragmas for write mode
157
- if not self.readonly:
158
- if "?" in uri:
159
- uri = uri.replace("?", f"{self.WRITE_OPTIMIZED_PRAGMAS}&")
160
- else:
161
- uri += self.WRITE_OPTIMIZED_PRAGMAS
162
-
163
179
  return uri
sonusai/mixture/mixdb.py CHANGED
@@ -30,7 +30,7 @@ from ..datatypes import TruthsConfigs
30
30
  from ..datatypes import TruthsDict
31
31
  from ..datatypes import UniversalSNR
32
32
  from .db import SQLiteDatabase
33
- from .db_file import db_file
33
+ from .db import db_file
34
34
 
35
35
 
36
36
  class MixtureDatabase:
@@ -215,16 +215,6 @@ class MixtureDatabase:
215
215
  MetricDoc("Mixture Metrics", "mxcr", "Mixture Crest factor"),
216
216
  MetricDoc("Mixture Metrics", "mxfl", "Mixture Flat factor"),
217
217
  MetricDoc("Mixture Metrics", "mxpkc", "Mixture Pk count"),
218
- MetricDoc("Mixture Metrics", "mxtdco", "Mixture source DC offset"),
219
- MetricDoc("Mixture Metrics", "mxtmin", "Mixture source min level"),
220
- MetricDoc("Mixture Metrics", "mxtmax", "Mixture source max levl"),
221
- MetricDoc("Mixture Metrics", "mxtpkdb", "Mixture source Pk lev dB"),
222
- MetricDoc("Mixture Metrics", "mxtlrms", "Mixture source RMS lev dB"),
223
- MetricDoc("Mixture Metrics", "mxtpkr", "Mixture source RMS Pk dB"),
224
- MetricDoc("Mixture Metrics", "mxttr", "Mixture source RMS Tr dB"),
225
- MetricDoc("Mixture Metrics", "mxtcr", "Mixture source Crest factor"),
226
- MetricDoc("Mixture Metrics", "mxtfl", "Mixture source Flat factor"),
227
- MetricDoc("Mixture Metrics", "mxtpkc", "Mixture source Pk count"),
228
218
  MetricDoc("Sources Metrics", "sdco", "Sources DC offset"),
229
219
  MetricDoc("Sources Metrics", "smin", "Sources min level"),
230
220
  MetricDoc("Sources Metrics", "smax", "Sources max levl"),
@@ -235,6 +225,16 @@ class MixtureDatabase:
235
225
  MetricDoc("Sources Metrics", "scr", "Sources Crest factor"),
236
226
  MetricDoc("Sources Metrics", "sfl", "Sources Flat factor"),
237
227
  MetricDoc("Sources Metrics", "spkc", "Sources Pk count"),
228
+ MetricDoc("Source Metrics", "mxsdco", "Source DC offset"),
229
+ MetricDoc("Source Metrics", "mxsmin", "Source min level"),
230
+ MetricDoc("Source Metrics", "mxsmax", "Source max levl"),
231
+ MetricDoc("Source Metrics", "mxspkdb", "Source Pk lev dB"),
232
+ MetricDoc("Source Metrics", "mxslrms", "Source RMS lev dB"),
233
+ MetricDoc("Source Metrics", "mxspkr", "Source RMS Pk dB"),
234
+ MetricDoc("Source Metrics", "mxstr", "Source RMS Tr dB"),
235
+ MetricDoc("Source Metrics", "mxscr", "Source Crest factor"),
236
+ MetricDoc("Source Metrics", "mxsfl", "Source Flat factor"),
237
+ MetricDoc("Source Metrics", "mxspkc", "Source Pk count"),
238
238
  MetricDoc("Noise Metrics", "ndco", "Noise DC offset"),
239
239
  MetricDoc("Noise Metrics", "nmin", "Noise min level"),
240
240
  MetricDoc("Noise Metrics", "nmax", "Noise max levl"),
@@ -277,7 +277,7 @@ class MixtureDatabase:
277
277
  )
278
278
  metrics.append(
279
279
  MetricDoc(
280
- "Source Metrics",
280
+ "Sources Metrics",
281
281
  f"sasr.{name}",
282
282
  f"Sources ASR text using {name} ASR as defined in mixdb asr_configs parameter",
283
283
  )
@@ -291,7 +291,7 @@ class MixtureDatabase:
291
291
  )
292
292
  metrics.append(
293
293
  MetricDoc(
294
- "Source Metrics",
294
+ "Sources Metrics",
295
295
  f"basewer.{name}",
296
296
  f"Word error rate of sasr.{name} vs. speech text metadata for the source",
297
297
  )
@@ -1296,17 +1296,15 @@ class MixtureDatabase:
1296
1296
  fg = FeatureGenerator(self.fg_config.feature_mode, self.fg_config.truth_parameters)
1297
1297
 
1298
1298
  feature, truth_f = fg.execute_all(mixture_f, truth_t)
1299
- if truth_f is not None:
1300
- truth_configs = self.mixture_truth_configs(m_id)
1301
- for category, configs in truth_configs.items():
1302
- for name, config in configs.items():
1303
- if self.truth_parameters[category][name] is not None:
1304
- truth_f[category][name] = truth_stride_reduction(
1305
- truth_f[category][name], config.stride_reduction
1306
- )
1307
- else:
1299
+ if truth_f is None:
1308
1300
  raise TypeError("Unexpected truth of None from feature generator")
1309
1301
 
1302
+ truth_configs = self.mixture_truth_configs(m_id)
1303
+ for category, configs in truth_configs.items():
1304
+ for name, config in configs.items():
1305
+ if self.truth_parameters[category][name] is not None:
1306
+ truth_f[category][name] = truth_stride_reduction(truth_f[category][name], config.stride_reduction)
1307
+
1310
1308
  if cache:
1311
1309
  write_cached_data(
1312
1310
  location=self.location,
@@ -1598,536 +1596,9 @@ class MixtureDatabase:
1598
1596
  :param force: Force computing data from original sources regardless of whether cached data exists
1599
1597
  :return: Dictionary of metric data
1600
1598
  """
1601
- from collections.abc import Callable
1602
-
1603
- import numpy as np
1604
- from pystoi import stoi
1605
-
1606
- from ..constants import SAMPLE_RATE
1607
- from ..datatypes import AudioStatsMetrics
1608
- from ..datatypes import SpeechMetrics
1609
- from ..metrics.calc_audio_stats import calc_audio_stats
1610
- from ..metrics.calc_pesq import calc_pesq
1611
- from ..metrics.calc_phase_distance import calc_phase_distance
1612
- from ..metrics.calc_segsnr_f import calc_segsnr_f
1613
- from ..metrics.calc_segsnr_f import calc_segsnr_f_bin
1614
- from ..metrics.calc_speech import calc_speech
1615
- from ..metrics.calc_wer import calc_wer
1616
- from ..metrics.calc_wsdr import calc_wsdr
1617
- from ..utils.asr import calc_asr
1618
- from ..utils.db import linear_to_db
1619
-
1620
- def create_sources_audio() -> Callable[[], dict[str, AudioT]]:
1621
- state: dict[str, AudioT] | None = None
1622
-
1623
- def get() -> dict[str, AudioT]:
1624
- nonlocal state
1625
- if state is None:
1626
- state = self.mixture_sources(m_id)
1627
- return state
1628
-
1629
- return get
1630
-
1631
- sources_audio = create_sources_audio()
1632
-
1633
- def create_source_audio() -> Callable[[], AudioT]:
1634
- state: AudioT | None = None
1635
-
1636
- def get() -> AudioT:
1637
- nonlocal state
1638
- if state is None:
1639
- state = self.mixture_source(m_id)
1640
- return state
1641
-
1642
- return get
1643
-
1644
- source_audio = create_source_audio()
1645
-
1646
- def create_source_f() -> Callable[[], AudioF]:
1647
- state: AudioF | None = None
1648
-
1649
- def get() -> AudioF:
1650
- nonlocal state
1651
- if state is None:
1652
- state = self.mixture_source_f(m_id)
1653
- return state
1654
-
1655
- return get
1656
-
1657
- source_f = create_source_f()
1658
-
1659
- def create_noise_audio() -> Callable[[], AudioT]:
1660
- state: AudioT | None = None
1661
-
1662
- def get() -> AudioT:
1663
- nonlocal state
1664
- if state is None:
1665
- state = self.mixture_noise(m_id)
1666
- return state
1667
-
1668
- return get
1669
-
1670
- noise_audio = create_noise_audio()
1671
-
1672
- def create_noise_f() -> Callable[[], AudioF]:
1673
- state: AudioF | None = None
1674
-
1675
- def get() -> AudioF:
1676
- nonlocal state
1677
- if state is None:
1678
- state = self.mixture_noise_f(m_id)
1679
- return state
1680
-
1681
- return get
1682
-
1683
- noise_f = create_noise_f()
1684
-
1685
- def create_mixture_audio() -> Callable[[], AudioT]:
1686
- state: AudioT | None = None
1687
-
1688
- def get() -> AudioT:
1689
- nonlocal state
1690
- if state is None:
1691
- state = self.mixture_mixture(m_id)
1692
- return state
1693
-
1694
- return get
1695
-
1696
- mixture_audio = create_mixture_audio()
1697
-
1698
- def create_segsnr_f() -> Callable[[], Segsnr]:
1699
- state: Segsnr | None = None
1700
-
1701
- def get() -> Segsnr:
1702
- nonlocal state
1703
- if state is None:
1704
- state = self.mixture_segsnr(m_id)
1705
- return state
1706
-
1707
- return get
1708
-
1709
- segsnr_f = create_segsnr_f()
1710
-
1711
- def create_pesq() -> Callable[[], dict[str, float]]:
1712
- state: dict[str, float] | None = None
1713
-
1714
- def get() -> dict[str, float]:
1715
- nonlocal state
1716
- if state is None:
1717
- state = {category: calc_pesq(mixture_audio(), audio) for category, audio in sources_audio().items()}
1718
- return state
1719
-
1720
- return get
1721
-
1722
- pesq = create_pesq()
1723
-
1724
- def create_speech() -> Callable[[], dict[str, SpeechMetrics]]:
1725
- state: dict[str, SpeechMetrics] | None = None
1726
-
1727
- def get() -> dict[str, SpeechMetrics]:
1728
- nonlocal state
1729
- if state is None:
1730
- state = {
1731
- category: calc_speech(mixture_audio(), audio, pesq()[category])
1732
- for category, audio in sources_audio().items()
1733
- }
1734
- return state
1735
-
1736
- return get
1737
-
1738
- speech = create_speech()
1739
-
1740
- def create_mixture_stats() -> Callable[[], AudioStatsMetrics]:
1741
- state: AudioStatsMetrics | None = None
1742
-
1743
- def get() -> AudioStatsMetrics:
1744
- nonlocal state
1745
- if state is None:
1746
- state = calc_audio_stats(mixture_audio(), self.fg_info.ft_config.length / SAMPLE_RATE)
1747
- return state
1748
-
1749
- return get
1750
-
1751
- mixture_stats = create_mixture_stats()
1752
-
1753
- def create_sources_stats() -> Callable[[], dict[str, AudioStatsMetrics]]:
1754
- state: dict[str, AudioStatsMetrics] | None = None
1755
-
1756
- def get() -> dict[str, AudioStatsMetrics]:
1757
- nonlocal state
1758
- if state is None:
1759
- state = {
1760
- category: calc_audio_stats(audio, self.fg_info.ft_config.length / SAMPLE_RATE)
1761
- for category, audio in sources_audio().items()
1762
- }
1763
- return state
1764
-
1765
- return get
1766
-
1767
- sources_stats = create_sources_stats()
1768
-
1769
- def create_source_stats() -> Callable[[], AudioStatsMetrics]:
1770
- state: AudioStatsMetrics | None = None
1771
-
1772
- def get() -> AudioStatsMetrics:
1773
- nonlocal state
1774
- if state is None:
1775
- state = calc_audio_stats(source_audio(), self.fg_info.ft_config.length / SAMPLE_RATE)
1776
- return state
1777
-
1778
- return get
1779
-
1780
- source_stats = create_source_stats()
1781
-
1782
- def create_noise_stats() -> Callable[[], AudioStatsMetrics]:
1783
- state: AudioStatsMetrics | None = None
1784
-
1785
- def get() -> AudioStatsMetrics:
1786
- nonlocal state
1787
- if state is None:
1788
- state = calc_audio_stats(noise_audio(), self.fg_info.ft_config.length / SAMPLE_RATE)
1789
- return state
1790
-
1791
- return get
1792
-
1793
- noise_stats = create_noise_stats()
1794
-
1795
- def create_asr_config() -> Callable[[str], dict]:
1796
- state: dict[str, dict] = {}
1797
-
1798
- def get(asr_name) -> dict:
1799
- nonlocal state
1800
- if asr_name not in state:
1801
- value = self.asr_configs.get(asr_name, None)
1802
- if value is None:
1803
- raise ValueError(f"Unrecognized ASR name: '{asr_name}'")
1804
- state[asr_name] = value
1805
- return state[asr_name]
1806
-
1807
- return get
1808
-
1809
- asr_config = create_asr_config()
1810
-
1811
- def create_sources_asr() -> Callable[[str], dict[str, str]]:
1812
- state: dict[str, dict[str, str]] = {}
1813
-
1814
- def get(asr_name) -> dict[str, str]:
1815
- nonlocal state
1816
- if asr_name not in state:
1817
- state[asr_name] = {
1818
- category: calc_asr(audio, **asr_config(asr_name)).text
1819
- for category, audio in sources_audio().items()
1820
- }
1821
- return state[asr_name]
1822
-
1823
- return get
1824
-
1825
- sources_asr = create_sources_asr()
1826
-
1827
- def create_source_asr() -> Callable[[str], str]:
1828
- state: dict[str, str] = {}
1829
-
1830
- def get(asr_name) -> str:
1831
- nonlocal state
1832
- if asr_name not in state:
1833
- state[asr_name] = calc_asr(source_audio(), **asr_config(asr_name)).text
1834
- return state[asr_name]
1835
-
1836
- return get
1837
-
1838
- source_asr = create_source_asr()
1839
-
1840
- def create_mixture_asr() -> Callable[[str], str]:
1841
- state: dict[str, str] = {}
1842
-
1843
- def get(asr_name) -> str:
1844
- nonlocal state
1845
- if asr_name not in state:
1846
- state[asr_name] = calc_asr(mixture_audio(), **asr_config(asr_name)).text
1847
- return state[asr_name]
1848
-
1849
- return get
1850
-
1851
- mixture_asr = create_mixture_asr()
1852
-
1853
- def get_asr_name(m: str) -> str:
1854
- parts = m.split(".")
1855
- if len(parts) != 2:
1856
- raise ValueError(f"Unrecognized format: '{m}'; must be of the form: '<metric>.<name>'")
1857
- asr_name = parts[1]
1858
- return asr_name
1859
-
1860
- def calc(m: str) -> Any:
1861
- if m == "mxsnr":
1862
- return {category: source.snr for category, source in self.mixture(m_id).all_sources.items()}
1863
-
1864
- # Get cached data first, if exists
1865
- if not force:
1866
- value = self.read_mixture_data(m_id, m)[m]
1867
- if value is not None:
1868
- return value
1869
-
1870
- # Otherwise, generate data as needed
1871
- if m.startswith("mxwer"):
1872
- asr_name = get_asr_name(m)
1873
-
1874
- if self.mixture(m_id).is_noise_only:
1875
- # noise only, ignore/reset target asr
1876
- return float("nan")
1877
-
1878
- if source_asr(asr_name):
1879
- return calc_wer(mixture_asr(asr_name), source_asr(asr_name)).wer * 100
1880
-
1881
- # TODO: should this be NaN like above?
1882
- return float(0)
1883
-
1884
- if m.startswith("basewer"):
1885
- asr_name = get_asr_name(m)
1886
-
1887
- text = self.mixture_speech_metadata(m_id, "text")
1888
- base_wer: dict[str, float] = {}
1889
- for category, source in sources_asr(asr_name).items():
1890
- if isinstance(text[category], str):
1891
- base_wer[category] = calc_wer(source, str(text[category])).wer * 100
1892
- else:
1893
- base_wer[category] = 0
1894
- return base_wer
1895
-
1896
- if m.startswith("mxasr"):
1897
- return mixture_asr(get_asr_name(m))
1898
-
1899
- if m == "mxssnr_avg":
1900
- return calc_segsnr_f(segsnr_f()).avg
1901
-
1902
- if m == "mxssnr_std":
1903
- return calc_segsnr_f(segsnr_f()).std
1904
-
1905
- if m == "mxssnr_avg_db":
1906
- val = calc_segsnr_f(segsnr_f()).avg
1907
- if val is not None:
1908
- return linear_to_db(val)
1909
- return None
1910
-
1911
- if m == "mxssnr_std_db":
1912
- val = calc_segsnr_f(segsnr_f()).std
1913
- if val is not None:
1914
- return linear_to_db(val)
1915
- return None
1916
-
1917
- if m == "mxssnrdb_avg":
1918
- return calc_segsnr_f(segsnr_f()).db_avg
1919
-
1920
- if m == "mxssnrdb_std":
1921
- return calc_segsnr_f(segsnr_f()).db_std
1922
-
1923
- if m == "mxssnrf_avg":
1924
- return calc_segsnr_f_bin(source_f(), noise_f()).avg
1925
-
1926
- if m == "mxssnrf_std":
1927
- return calc_segsnr_f_bin(source_f(), noise_f()).std
1928
-
1929
- if m == "mxssnrdbf_avg":
1930
- return calc_segsnr_f_bin(source_f(), noise_f()).db_avg
1931
-
1932
- if m == "mxssnrdbf_std":
1933
- return calc_segsnr_f_bin(source_f(), noise_f()).db_std
1934
-
1935
- if m == "mxpesq":
1936
- if self.mixture(m_id).is_noise_only:
1937
- return dict.fromkeys(pesq(), 0)
1938
- return pesq()
1939
-
1940
- if m == "mxcsig":
1941
- if self.mixture(m_id).is_noise_only:
1942
- return dict.fromkeys(speech(), 0)
1943
- return {category: s.csig for category, s in speech().items()}
1944
-
1945
- if m == "mxcbak":
1946
- if self.mixture(m_id).is_noise_only:
1947
- return dict.fromkeys(speech(), 0)
1948
- return {category: s.cbak for category, s in speech().items()}
1949
-
1950
- if m == "mxcovl":
1951
- if self.mixture(m_id).is_noise_only:
1952
- return dict.fromkeys(speech(), 0)
1953
- return {category: s.covl for category, s in speech().items()}
1954
-
1955
- if m == "mxwsdr":
1956
- mixture = mixture_audio()[:, np.newaxis]
1957
- target = source_audio()[:, np.newaxis]
1958
- noise = noise_audio()[:, np.newaxis]
1959
- return calc_wsdr(
1960
- hypothesis=np.concatenate((mixture, noise), axis=1),
1961
- reference=np.concatenate((target, noise), axis=1),
1962
- with_log=True,
1963
- )[0]
1964
-
1965
- if m == "mxpd":
1966
- mixture_f = self.mixture_mixture_f(m_id)
1967
- return calc_phase_distance(hypothesis=mixture_f, reference=source_f())[0]
1968
-
1969
- if m == "mxstoi":
1970
- return stoi(
1971
- x=source_audio(),
1972
- y=mixture_audio(),
1973
- fs_sig=SAMPLE_RATE,
1974
- extended=False,
1975
- )
1976
-
1977
- if m == "mxdco":
1978
- return mixture_stats().dco
1979
-
1980
- if m == "mxmin":
1981
- return mixture_stats().min
1982
-
1983
- if m == "mxmax":
1984
- return mixture_stats().max
1985
-
1986
- if m == "mxpkdb":
1987
- return mixture_stats().pkdb
1988
-
1989
- if m == "mxlrms":
1990
- return mixture_stats().lrms
1991
-
1992
- if m == "mxpkr":
1993
- return mixture_stats().pkr
1994
-
1995
- if m == "mxtr":
1996
- return mixture_stats().tr
1997
-
1998
- if m == "mxcr":
1999
- return mixture_stats().cr
2000
-
2001
- if m == "mxfl":
2002
- return mixture_stats().fl
2003
-
2004
- if m == "mxpkc":
2005
- return mixture_stats().pkc
2006
-
2007
- if m == "mxtdco":
2008
- return source_stats().dco
2009
-
2010
- if m == "mxtmin":
2011
- return source_stats().min
2012
-
2013
- if m == "mxtmax":
2014
- return source_stats().max
2015
-
2016
- if m == "mxtpkdb":
2017
- return source_stats().pkdb
2018
-
2019
- if m == "mxtlrms":
2020
- return source_stats().lrms
2021
-
2022
- if m == "mxtpkr":
2023
- return source_stats().pkr
2024
-
2025
- if m == "mxttr":
2026
- return source_stats().tr
2027
-
2028
- if m == "mxtcr":
2029
- return source_stats().cr
2030
-
2031
- if m == "mxtfl":
2032
- return source_stats().fl
2033
-
2034
- if m == "mxtpkc":
2035
- return source_stats().pkc
2036
-
2037
- if m == "sdco":
2038
- return {category: s.dco for category, s in sources_stats().items()}
2039
-
2040
- if m == "smin":
2041
- return {category: s.min for category, s in sources_stats().items()}
2042
-
2043
- if m == "smax":
2044
- return {category: s.max for category, s in sources_stats().items()}
2045
-
2046
- if m == "spkdb":
2047
- return {category: s.pkdb for category, s in sources_stats().items()}
2048
-
2049
- if m == "slrms":
2050
- return {category: s.lrms for category, s in sources_stats().items()}
2051
-
2052
- if m == "spkr":
2053
- return {category: s.pkr for category, s in sources_stats().items()}
2054
-
2055
- if m == "str":
2056
- return {category: s.tr for category, s in sources_stats().items()}
2057
-
2058
- if m == "scr":
2059
- return {category: s.cr for category, s in sources_stats().items()}
2060
-
2061
- if m == "sfl":
2062
- return {category: s.fl for category, s in sources_stats().items()}
2063
-
2064
- if m == "spkc":
2065
- return {category: s.pkc for category, s in sources_stats().items()}
2066
-
2067
- if m.startswith("sasr"):
2068
- return sources_asr(get_asr_name(m))
2069
-
2070
- if m.startswith("mxsasr"):
2071
- return source_asr(get_asr_name(m))
2072
-
2073
- if m == "ndco":
2074
- return noise_stats().dco
2075
-
2076
- if m == "nmin":
2077
- return noise_stats().min
2078
-
2079
- if m == "nmax":
2080
- return noise_stats().max
2081
-
2082
- if m == "npkdb":
2083
- return noise_stats().pkdb
2084
-
2085
- if m == "nlrms":
2086
- return noise_stats().lrms
2087
-
2088
- if m == "npkr":
2089
- return noise_stats().pkr
2090
-
2091
- if m == "ntr":
2092
- return noise_stats().tr
2093
-
2094
- if m == "ncr":
2095
- return noise_stats().cr
2096
-
2097
- if m == "nfl":
2098
- return noise_stats().fl
2099
-
2100
- if m == "npkc":
2101
- return noise_stats().pkc
2102
-
2103
- if m == "sedavg":
2104
- return 0
2105
-
2106
- if m == "sedcnt":
2107
- return 0
2108
-
2109
- if m == "sedtop3":
2110
- return np.zeros(3, dtype=np.float32)
2111
-
2112
- if m == "sedtopn":
2113
- return 0
2114
-
2115
- if m == "ssnr":
2116
- return segsnr_f()
2117
-
2118
- raise AttributeError(f"Unrecognized metric: '{m}'")
2119
-
2120
- result: dict[str, Any] = {}
2121
- for metric in metrics:
2122
- result[metric] = calc(metric)
2123
-
2124
- # Check for metrics dependencies and add them even if not explicitly requested.
2125
- if metric.startswith("mxwer"):
2126
- dependencies = ("mxasr." + metric[6:], "sasr." + metric[6:])
2127
- for dependency in dependencies:
2128
- result[dependency] = calc(dependency)
1599
+ from ..metrics import calculate_metrics
2129
1600
 
2130
- return result
1601
+ return calculate_metrics(self, m_id, metrics, force)
2131
1602
 
2132
1603
 
2133
1604
  def _spectral_mask(db: partial, sm_id: int, use_cache: bool = True) -> SpectralMask:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: sonusai
3
- Version: 1.0.10
3
+ Version: 1.0.12
4
4
  Summary: Framework for building deep neural network models for sound, speech, and voice AI
5
5
  Home-page: https://aaware.com
6
6
  License: GPL-3.0-only
@@ -27,7 +27,7 @@ sonusai/genmixdb.py,sha256=VDQMF6JHcHc-yJAZ1Se3CM3ac8fFKIgnaxv4e5jdE1I,11281
27
27
  sonusai/ir_metric.py,sha256=nxS_mARPSZG5Y0G3L8HysOnkPj4v-RGxAxAVBYe-gJI,19600
28
28
  sonusai/lsdb.py,sha256=-Fhwd7YuL-OIymFqaNcBHtOq8l_8LxzoEE6ztduQCpY,5059
29
29
  sonusai/main.py,sha256=72feJv5XEVJE_CQatmNIL1VD9ca-Mo0QNDbXxLrHrbQ,2619
30
- sonusai/metrics/__init__.py,sha256=ssV6JEK_oklRSocsp6HMcG-GtJvV8IkRQtdKhHHmwU8,878
30
+ sonusai/metrics/__init__.py,sha256=0Y0xFHiO3TrH4DRt-htCXEXsc8TLGNRWfD16q16yWEs,927
31
31
  sonusai/metrics/calc_audio_stats.py,sha256=tIfTa40UdYCkj999kUghWafwnFBqFtJxB5yZhVp1YpA,1244
32
32
  sonusai/metrics/calc_class_weights.py,sha256=uF1jeFz73l5nSk6SQ-xkBGbrgvAvX_MKUA_Det2KAEM,3609
33
33
  sonusai/metrics/calc_optimal_thresholds.py,sha256=1bKPoqUYyHpq7lrx7hPnVXrJ5xWIewQjNG632GzKNNU,3502
@@ -40,20 +40,20 @@ sonusai/metrics/calc_segsnr_f.py,sha256=yLqUt--8osVgCNAkopbDZsldlVJ6a5AZEggarN8d
40
40
  sonusai/metrics/calc_speech.py,sha256=bFiWtKz_Fuu4F1kdWGmZ3qZ_LdoSI3pj0ziXZKxXE3U,14828
41
41
  sonusai/metrics/calc_wer.py,sha256=1MQYMx8ldHeodtJEtGibvDKhvSaGe6DBmZV4L8qOMgg,2362
42
42
  sonusai/metrics/calc_wsdr.py,sha256=vcALY-zuhyThRa1QMz2qW8L9kSBc2v32gV9u8bV7VaM,2556
43
+ sonusai/metrics/calculate_metrics.py,sha256=jcAyEV6loenu4fU_EvwEkpKxOrP8-K9O3rwQGlE48IU,12475
43
44
  sonusai/metrics/class_summary.py,sha256=mQbMxQ8EtFIN7S2h7A4Dk0X4XF_CIxKk3W8zZMmpfcw,2801
44
45
  sonusai/metrics/confusion_matrix_summary.py,sha256=lhd8TyHVMC03khX85h_D75XElmawx56KkqpX3X2O2gQ,3133
45
46
  sonusai/metrics/one_hot.py,sha256=aKc-xYd4zWIjbmoQikIcQ6BJB1k-68XKTg8eJCacHTU,13906
46
47
  sonusai/metrics/snr_summary.py,sha256=qKHctpmvGeu2cmjTG7iQPX1lvVUEtEnCIKwUGu6VrEQ,5773
47
48
  sonusai/metrics_summary.py,sha256=jtSwHomw23qwTYfzjFo_JmqzrkZcts1CMFFzTmJCmWk,12189
48
- sonusai/mixture/__init__.py,sha256=l4CgJN0gH4Z19jcQvXJbR8KSZ5f_ysnoAGi93LQaTjM,1260
49
+ sonusai/mixture/__init__.py,sha256=GGx8WG0pZwKmlXiWVBrtQXVY0dKW4yqDxSBgv7BI2Xc,1255
49
50
  sonusai/mixture/audio.py,sha256=JyrVtVPLH3aTXFgyl446f5uVHxlFRa4aBaSPYaMdg80,5814
50
51
  sonusai/mixture/class_balancing.py,sha256=lubicVCzxs4TMh2dZSsuIffkLkk1gmwjmwtrtQ27BVQ,3638
51
52
  sonusai/mixture/config.py,sha256=2_hEndyRXxyBpGzyBFaDT9REYGoK9Q7HQy8vDqPozus,23320
52
53
  sonusai/mixture/constants.py,sha256=Kklzhf5DL30yb3TpqRbvRUhcFrEXJ4s2S3D_nw4ARxM,1498
53
54
  sonusai/mixture/data_io.py,sha256=DV48sFcP2Qp3NBzvcnlptQOXU3aUEcAeLuh3XOtC5jI,5341
54
- sonusai/mixture/db.py,sha256=yd0bCiihuUAw3IgRlLqcshXB2QHep837O3TwjPyo-LM,5132
55
+ sonusai/mixture/db.py,sha256=zZnMFdW30leMCT1nX1Ml57ByLkqYEcm4VlekELvCFyc,5678
55
56
  sonusai/mixture/db_datatypes.py,sha256=VvNtbOgt5WSeSnBoVcNGC5gs_7hX_38pDUPjy5KRbG4,1471
56
- sonusai/mixture/db_file.py,sha256=P48TWYNyqchycENIqBu1QqhfsRDP6WK2VanPgxN1Imk,278
57
57
  sonusai/mixture/effects.py,sha256=zIb6ir0WSdKQJo7uJ3QQnV52RA6lJaqgQqvQh-s0dhc,11038
58
58
  sonusai/mixture/feature.py,sha256=7GJvFhfqeqerfjy9Vq9aKt-cecgYblK0IypNNo5hgwY,2285
59
59
  sonusai/mixture/generation.py,sha256=_vGTyqo0ocyOK84rTj_1QXciq1Tmxxl5XhwaXPWIEL0,33105
@@ -61,7 +61,7 @@ sonusai/mixture/helpers.py,sha256=dmyHwf1C5dZjYOd11kVV16KI33CaM-dU_fyaxOrrKt8,11
61
61
  sonusai/mixture/ir_delay.py,sha256=aiC23HMWQ08-v5wORgMx1_DOJSdh4kunULqiQ-SGuMo,2026
62
62
  sonusai/mixture/ir_effects.py,sha256=PqiqD4PS42-7kD6ESnsZi2a3tnKCFa4E0xqUujRBvGg,2152
63
63
  sonusai/mixture/log_duration_and_sizes.py,sha256=3ekS27IMKlnxIkQAmprzmBnzHOpRjZh3d7maL2VqWQU,927
64
- sonusai/mixture/mixdb.py,sha256=5YI0zKisFw_B-jKpB-Y1EYlJ8pHQDvOQLs9LEe0gT1w,84905
64
+ sonusai/mixture/mixdb.py,sha256=0smihhsBjENymN5iuNoaj5FIXSfSzMkpXN29QuAhIiE,67882
65
65
  sonusai/mixture/pad_audio.py,sha256=KNxVQAejA0hblLOnMJgLS6lFaeE0n3tWQ5rclaHBnIY,1015
66
66
  sonusai/mixture/parse.py,sha256=nqhjuR-J7_3wlGhVitYFvQwLJ1sclU8WZrVF0SyW2Cw,3700
67
67
  sonusai/mixture/resample.py,sha256=jXqH6FrZ0mlhQ07XqPx88TT9elu3HHVLw7Q0a7Lh5M4,221
@@ -134,7 +134,7 @@ sonusai/utils/tokenized_shell_vars.py,sha256=EDrrAgz5lJ0RBAjLcTJt1MeyjhbNZiqXkym
134
134
  sonusai/utils/write_audio.py,sha256=IHzrJoFtFcea_J6wo6QSiojRkgnNOzAEcg-z0rFV7nU,810
135
135
  sonusai/utils/yes_or_no.py,sha256=0h1okjXmDNbJp7rZJFR2V-HFU1GJDm3YFTUVmYExkOU,263
136
136
  sonusai/vars.py,sha256=m8pdgfR4A6A9TCGf_rok6jPAT5BgrEsYXTSISIh1nrI,1163
137
- sonusai-1.0.10.dist-info/METADATA,sha256=kliBuHLQIEAUTsv9Hav0VWo1IGQxpTao5bl233yOnaQ,2695
138
- sonusai-1.0.10.dist-info/WHEEL,sha256=RaoafKOydTQ7I_I3JTrPCg6kUmTgtm4BornzOqyEfJ8,88
139
- sonusai-1.0.10.dist-info/entry_points.txt,sha256=zMNjEphEPO6B3cD1GNpit7z-yA9tUU5-j3W2v-UWstU,92
140
- sonusai-1.0.10.dist-info/RECORD,,
137
+ sonusai-1.0.12.dist-info/METADATA,sha256=8GtmHLwVNnw6fFJLJrzqXFjpLW3eMABQK4aN_R0j0Is,2695
138
+ sonusai-1.0.12.dist-info/WHEEL,sha256=RaoafKOydTQ7I_I3JTrPCg6kUmTgtm4BornzOqyEfJ8,88
139
+ sonusai-1.0.12.dist-info/entry_points.txt,sha256=zMNjEphEPO6B3cD1GNpit7z-yA9tUU5-j3W2v-UWstU,92
140
+ sonusai-1.0.12.dist-info/RECORD,,
@@ -1,10 +0,0 @@
1
- from os.path import join
2
- from os.path import normpath
3
-
4
- from .constants import MIXDB_NAME
5
- from .constants import TEST_MIXDB_NAME
6
-
7
-
8
- def db_file(location: str, test: bool = False) -> str:
9
- name = TEST_MIXDB_NAME if test else MIXDB_NAME
10
- return normpath(join(location, name))