sonusai 0.20.3__py3-none-any.whl → 1.0.2__py3-none-any.whl

This diff compares the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (97)
  1. sonusai/__init__.py +16 -3
  2. sonusai/audiofe.py +241 -77
  3. sonusai/calc_metric_spenh.py +71 -73
  4. sonusai/config/__init__.py +3 -0
  5. sonusai/config/config.py +61 -0
  6. sonusai/config/config.yml +20 -0
  7. sonusai/config/constants.py +8 -0
  8. sonusai/constants.py +11 -0
  9. sonusai/data/genmixdb.yml +21 -36
  10. sonusai/{mixture/datatypes.py → datatypes.py} +91 -130
  11. sonusai/deprecated/plot.py +4 -5
  12. sonusai/doc/doc.py +4 -4
  13. sonusai/doc.py +11 -4
  14. sonusai/genft.py +43 -45
  15. sonusai/genmetrics.py +25 -19
  16. sonusai/genmix.py +54 -82
  17. sonusai/genmixdb.py +88 -264
  18. sonusai/ir_metric.py +30 -34
  19. sonusai/lsdb.py +41 -48
  20. sonusai/main.py +15 -22
  21. sonusai/metrics/calc_audio_stats.py +4 -293
  22. sonusai/metrics/calc_class_weights.py +4 -4
  23. sonusai/metrics/calc_optimal_thresholds.py +8 -5
  24. sonusai/metrics/calc_pesq.py +2 -2
  25. sonusai/metrics/calc_segsnr_f.py +4 -4
  26. sonusai/metrics/calc_speech.py +25 -13
  27. sonusai/metrics/class_summary.py +7 -7
  28. sonusai/metrics/confusion_matrix_summary.py +5 -5
  29. sonusai/metrics/one_hot.py +4 -4
  30. sonusai/metrics/snr_summary.py +7 -7
  31. sonusai/metrics_summary.py +38 -45
  32. sonusai/mixture/__init__.py +4 -104
  33. sonusai/mixture/audio.py +10 -39
  34. sonusai/mixture/class_balancing.py +103 -0
  35. sonusai/mixture/config.py +251 -271
  36. sonusai/mixture/constants.py +35 -39
  37. sonusai/mixture/data_io.py +25 -36
  38. sonusai/mixture/db_datatypes.py +58 -22
  39. sonusai/mixture/effects.py +386 -0
  40. sonusai/mixture/feature.py +7 -11
  41. sonusai/mixture/generation.py +478 -628
  42. sonusai/mixture/helpers.py +82 -184
  43. sonusai/mixture/ir_delay.py +3 -4
  44. sonusai/mixture/ir_effects.py +77 -0
  45. sonusai/mixture/log_duration_and_sizes.py +6 -12
  46. sonusai/mixture/mixdb.py +910 -729
  47. sonusai/mixture/pad_audio.py +35 -0
  48. sonusai/mixture/resample.py +7 -0
  49. sonusai/mixture/sox_effects.py +195 -0
  50. sonusai/mixture/sox_help.py +650 -0
  51. sonusai/mixture/spectral_mask.py +2 -2
  52. sonusai/mixture/truth.py +17 -15
  53. sonusai/mixture/truth_functions/crm.py +12 -12
  54. sonusai/mixture/truth_functions/energy.py +22 -22
  55. sonusai/mixture/truth_functions/file.py +5 -5
  56. sonusai/mixture/truth_functions/metadata.py +4 -4
  57. sonusai/mixture/truth_functions/metrics.py +4 -4
  58. sonusai/mixture/truth_functions/phoneme.py +3 -3
  59. sonusai/mixture/truth_functions/sed.py +11 -13
  60. sonusai/mixture/truth_functions/target.py +10 -10
  61. sonusai/mkwav.py +26 -29
  62. sonusai/onnx_predict.py +240 -88
  63. sonusai/queries/__init__.py +2 -2
  64. sonusai/queries/queries.py +38 -34
  65. sonusai/speech/librispeech.py +1 -1
  66. sonusai/speech/mcgill.py +1 -1
  67. sonusai/speech/timit.py +2 -2
  68. sonusai/summarize_metric_spenh.py +10 -17
  69. sonusai/utils/__init__.py +7 -1
  70. sonusai/utils/asl_p56.py +2 -2
  71. sonusai/utils/asr.py +2 -2
  72. sonusai/utils/asr_functions/aaware_whisper.py +4 -5
  73. sonusai/utils/choice.py +31 -0
  74. sonusai/utils/compress.py +1 -1
  75. sonusai/utils/dataclass_from_dict.py +19 -1
  76. sonusai/utils/energy_f.py +3 -3
  77. sonusai/utils/evaluate_random_rule.py +15 -0
  78. sonusai/utils/keyboard_interrupt.py +12 -0
  79. sonusai/utils/onnx_utils.py +3 -17
  80. sonusai/utils/print_mixture_details.py +21 -19
  81. sonusai/utils/{temp_seed.py → rand.py} +3 -3
  82. sonusai/utils/read_predict_data.py +2 -2
  83. sonusai/utils/reshape.py +3 -3
  84. sonusai/utils/stratified_shuffle_split.py +3 -3
  85. sonusai/{mixture → utils}/tokenized_shell_vars.py +1 -1
  86. sonusai/utils/write_audio.py +2 -2
  87. sonusai/vars.py +11 -4
  88. {sonusai-0.20.3.dist-info → sonusai-1.0.2.dist-info}/METADATA +4 -2
  89. sonusai-1.0.2.dist-info/RECORD +138 -0
  90. sonusai/mixture/augmentation.py +0 -444
  91. sonusai/mixture/class_count.py +0 -15
  92. sonusai/mixture/eq_rule_is_valid.py +0 -45
  93. sonusai/mixture/target_class_balancing.py +0 -107
  94. sonusai/mixture/targets.py +0 -175
  95. sonusai-0.20.3.dist-info/RECORD +0 -128
  96. {sonusai-0.20.3.dist-info → sonusai-1.0.2.dist-info}/WHEEL +0 -0
  97. {sonusai-0.20.3.dist-info → sonusai-1.0.2.dist-info}/entry_points.txt +0 -0
sonusai/mixture/mixdb.py CHANGED
@@ -6,36 +6,43 @@ from sqlite3 import Connection
6
6
  from sqlite3 import Cursor
7
7
  from typing import Any
8
8
 
9
- from .datatypes import ASRConfigs
10
- from .datatypes import AudioF
11
- from .datatypes import AudioT
12
- from .datatypes import ClassCount
13
- from .datatypes import Feature
14
- from .datatypes import FeatureGeneratorConfig
15
- from .datatypes import FeatureGeneratorInfo
16
- from .datatypes import GeneralizedIDs
17
- from .datatypes import ImpulseResponseFile
18
- from .datatypes import MetricDoc
19
- from .datatypes import MetricDocs
20
- from .datatypes import Mixture
21
- from .datatypes import NoiseFile
22
- from .datatypes import Segsnr
23
- from .datatypes import SpectralMask
24
- from .datatypes import SpeechMetadata
25
- from .datatypes import TargetFile
26
- from .datatypes import TransformConfig
27
- from .datatypes import TruthConfigs
28
- from .datatypes import TruthDict
29
- from .datatypes import UniversalSNR
9
+ from ..datatypes import ASRConfigs
10
+ from ..datatypes import AudioF
11
+ from ..datatypes import AudioT
12
+ from ..datatypes import ClassCount
13
+ from ..datatypes import Feature
14
+ from ..datatypes import FeatureGeneratorConfig
15
+ from ..datatypes import FeatureGeneratorInfo
16
+ from ..datatypes import GeneralizedIDs
17
+ from ..datatypes import ImpulseResponseFile
18
+ from ..datatypes import MetricDoc
19
+ from ..datatypes import MetricDocs
20
+ from ..datatypes import Mixture
21
+ from ..datatypes import Segsnr
22
+ from ..datatypes import SourceFile
23
+ from ..datatypes import Sources
24
+ from ..datatypes import SourcesAudioF
25
+ from ..datatypes import SourcesAudioT
26
+ from ..datatypes import SpectralMask
27
+ from ..datatypes import SpeechMetadata
28
+ from ..datatypes import TransformConfig
29
+ from ..datatypes import TruthConfigs
30
+ from ..datatypes import TruthDict
31
+ from ..datatypes import TruthsConfigs
32
+ from ..datatypes import TruthsDict
33
+ from ..datatypes import UniversalSNR
30
34
 
31
35
 
32
36
  def db_file(location: str, test: bool = False) -> str:
33
37
  from os.path import join
34
38
 
39
+ from .constants import MIXDB_NAME
40
+ from .constants import TEST_MIXDB_NAME
41
+
35
42
  if test:
36
- name = "mixdb_test.db"
43
+ name = TEST_MIXDB_NAME
37
44
  else:
38
- name = "mixdb.db"
45
+ name = MIXDB_NAME
39
46
 
40
47
  return join(location, name)
41
48
 
@@ -103,7 +110,7 @@ class MixtureDatabase:
103
110
  config = load_config(self.location)
104
111
  new_asr_configs = json.dumps(config["asr_configs"])
105
112
  with self.db() as c:
106
- old_asr_configs = c.execute("SELECT top.asr_configs FROM top").fetchone()
113
+ old_asr_configs = c.execute("SELECT asr_configs FROM top").fetchone()
107
114
 
108
115
  if old_asr_configs is not None and new_asr_configs != old_asr_configs[0]:
109
116
  con = db_connection(location=self.location, readonly=False, test=self.test)
@@ -113,7 +120,7 @@ class MixtureDatabase:
113
120
 
114
121
  @cached_property
115
122
  def json(self) -> str:
116
- from .datatypes import MixtureDatabaseConfig
123
+ from ..datatypes import MixtureDatabaseConfig
117
124
 
118
125
  config = MixtureDatabaseConfig(
119
126
  asr_configs=self.asr_configs,
@@ -121,13 +128,11 @@ class MixtureDatabase:
121
128
  class_labels=self.class_labels,
122
129
  class_weights_threshold=self.class_weights_thresholds,
123
130
  feature=self.feature,
124
- impulse_response_files=self.impulse_response_files,
125
- mixtures=self.mixtures(),
126
- noise_mix_mode=self.noise_mix_mode,
127
- noise_files=self.noise_files,
131
+ ir_files=self.ir_files,
132
+ mixtures=self.mixtures,
128
133
  num_classes=self.num_classes,
129
134
  spectral_masks=self.spectral_masks,
130
- target_files=self.target_files,
135
+ source_files=self.source_files,
131
136
  )
132
137
  return config.to_json(indent=2)
133
138
 
@@ -153,30 +158,28 @@ class MixtureDatabase:
153
158
  return get_feature_generator_info(self.fg_config)
154
159
 
155
160
  @cached_property
156
- def truth_parameters(self) -> dict[str, int | None]:
161
+ def truth_parameters(self) -> dict[str, dict[str, int | None]]:
157
162
  with self.db() as c:
158
- rows = c.execute("SELECT * FROM truth_parameters").fetchall()
159
- truth_parameters: dict[str, int | None] = {}
163
+ rows = c.execute("SELECT category, name, parameters FROM truth_parameters").fetchall()
164
+ truth_parameters: dict[str, dict[str, int | None]] = {}
160
165
  for row in rows:
161
- truth_parameters[row[1]] = row[2]
166
+ category, name, parameters = row
167
+ if category not in truth_parameters:
168
+ truth_parameters[category] = {}
169
+ truth_parameters[category][name] = parameters
162
170
  return truth_parameters
163
171
 
164
172
  @cached_property
165
173
  def num_classes(self) -> int:
166
174
  with self.db() as c:
167
- return int(c.execute("SELECT top.num_classes FROM top").fetchone()[0])
168
-
169
- @cached_property
170
- def noise_mix_mode(self) -> str:
171
- with self.db() as c:
172
- return str(c.execute("SELECT top.noise_mix_mode FROM top").fetchone()[0])
175
+ return int(c.execute("SELECT num_classes FROM top").fetchone()[0])
173
176
 
174
177
  @cached_property
175
178
  def asr_configs(self) -> ASRConfigs:
176
179
  import json
177
180
 
178
181
  with self.db() as c:
179
- return json.loads(c.execute("SELECT top.asr_configs FROM top").fetchone()[0])
182
+ return json.loads(c.execute("SELECT asr_configs FROM top").fetchone()[0])
180
183
 
181
184
  @cached_property
182
185
  def supported_metrics(self) -> MetricDocs:
@@ -223,36 +226,36 @@ class MixtureDatabase:
223
226
  "mxssnrdbf_std",
224
227
  "Per-bin segmental standard deviation of the dB frame values over all frames (using feature transform)",
225
228
  ),
226
- MetricDoc("Mixture Metrics", "mxpesq", "PESQ of mixture versus true targets"),
229
+ MetricDoc("Mixture Metrics", "mxpesq", "PESQ of mixture versus true sources"),
227
230
  MetricDoc(
228
231
  "Mixture Metrics",
229
232
  "mxwsdr",
230
- "Weighted signal distortion ratio of mixture versus true targets",
233
+ "Weighted signal distortion ratio of mixture versus true sources",
231
234
  ),
232
235
  MetricDoc(
233
236
  "Mixture Metrics",
234
237
  "mxpd",
235
- "Phase distance between mixture and true targets",
238
+ "Phase distance between mixture and true sources",
236
239
  ),
237
240
  MetricDoc(
238
241
  "Mixture Metrics",
239
242
  "mxstoi",
240
- "Short term objective intelligibility of mixture versus true targets",
243
+ "Short term objective intelligibility of mixture versus true sources",
241
244
  ),
242
245
  MetricDoc(
243
246
  "Mixture Metrics",
244
247
  "mxcsig",
245
- "Predicted rating of speech distortion of mixture versus true targets",
248
+ "Predicted rating of speech distortion of mixture versus true sources",
246
249
  ),
247
250
  MetricDoc(
248
251
  "Mixture Metrics",
249
252
  "mxcbak",
250
- "Predicted rating of background distortion of mixture versus true targets",
253
+ "Predicted rating of background distortion of mixture versus true sources",
251
254
  ),
252
255
  MetricDoc(
253
256
  "Mixture Metrics",
254
257
  "mxcovl",
255
- "Predicted rating of overall quality of mixture versus true targets",
258
+ "Predicted rating of overall quality of mixture versus true sources",
256
259
  ),
257
260
  MetricDoc("Mixture Metrics", "ssnr", "Segmental SNR"),
258
261
  MetricDoc("Mixture Metrics", "mxdco", "Mixture DC offset"),
@@ -265,26 +268,26 @@ class MixtureDatabase:
265
268
  MetricDoc("Mixture Metrics", "mxcr", "Mixture Crest factor"),
266
269
  MetricDoc("Mixture Metrics", "mxfl", "Mixture Flat factor"),
267
270
  MetricDoc("Mixture Metrics", "mxpkc", "Mixture Pk count"),
268
- MetricDoc("Mixture Metrics", "mxtdco", "Mixture target DC offset"),
269
- MetricDoc("Mixture Metrics", "mxtmin", "Mixture target min level"),
270
- MetricDoc("Mixture Metrics", "mxtmax", "Mixture target max levl"),
271
- MetricDoc("Mixture Metrics", "mxtpkdb", "Mixture target Pk lev dB"),
272
- MetricDoc("Mixture Metrics", "mxtlrms", "Mixture target RMS lev dB"),
273
- MetricDoc("Mixture Metrics", "mxtpkr", "Mixture target RMS Pk dB"),
274
- MetricDoc("Mixture Metrics", "mxttr", "Mixture target RMS Tr dB"),
275
- MetricDoc("Mixture Metrics", "mxtcr", "Mixture target Crest factor"),
276
- MetricDoc("Mixture Metrics", "mxtfl", "Mixture target Flat factor"),
277
- MetricDoc("Mixture Metrics", "mxtpkc", "Mixture target Pk count"),
278
- MetricDoc("Targets Metrics", "tdco", "Targets DC offset"),
279
- MetricDoc("Targets Metrics", "tmin", "Targets min level"),
280
- MetricDoc("Targets Metrics", "tmax", "Targets max levl"),
281
- MetricDoc("Targets Metrics", "tpkdb", "Targets Pk lev dB"),
282
- MetricDoc("Targets Metrics", "tlrms", "Targets RMS lev dB"),
283
- MetricDoc("Targets Metrics", "tpkr", "Targets RMS Pk dB"),
284
- MetricDoc("Targets Metrics", "ttr", "Targets RMS Tr dB"),
285
- MetricDoc("Targets Metrics", "tcr", "Targets Crest factor"),
286
- MetricDoc("Targets Metrics", "tfl", "Targets Flat factor"),
287
- MetricDoc("Targets Metrics", "tpkc", "Targets Pk count"),
271
+ MetricDoc("Mixture Metrics", "mxtdco", "Mixture source DC offset"),
272
+ MetricDoc("Mixture Metrics", "mxtmin", "Mixture source min level"),
273
+ MetricDoc("Mixture Metrics", "mxtmax", "Mixture source max levl"),
274
+ MetricDoc("Mixture Metrics", "mxtpkdb", "Mixture source Pk lev dB"),
275
+ MetricDoc("Mixture Metrics", "mxtlrms", "Mixture source RMS lev dB"),
276
+ MetricDoc("Mixture Metrics", "mxtpkr", "Mixture source RMS Pk dB"),
277
+ MetricDoc("Mixture Metrics", "mxttr", "Mixture source RMS Tr dB"),
278
+ MetricDoc("Mixture Metrics", "mxtcr", "Mixture source Crest factor"),
279
+ MetricDoc("Mixture Metrics", "mxtfl", "Mixture source Flat factor"),
280
+ MetricDoc("Mixture Metrics", "mxtpkc", "Mixture source Pk count"),
281
+ MetricDoc("Sources Metrics", "sdco", "Sources DC offset"),
282
+ MetricDoc("Sources Metrics", "smin", "Sources min level"),
283
+ MetricDoc("Sources Metrics", "smax", "Sources max levl"),
284
+ MetricDoc("Sources Metrics", "spkdb", "Sources Pk lev dB"),
285
+ MetricDoc("Sources Metrics", "slrms", "Sources RMS lev dB"),
286
+ MetricDoc("Sources Metrics", "spkr", "Sources RMS Pk dB"),
287
+ MetricDoc("Sources Metrics", "str", "Sources RMS Tr dB"),
288
+ MetricDoc("Sources Metrics", "scr", "Sources Crest factor"),
289
+ MetricDoc("Sources Metrics", "sfl", "Sources Flat factor"),
290
+ MetricDoc("Sources Metrics", "spkc", "Sources Pk count"),
288
291
  MetricDoc("Noise Metrics", "ndco", "Noise DC offset"),
289
292
  MetricDoc("Noise Metrics", "nmin", "Noise min level"),
290
293
  MetricDoc("Noise Metrics", "nmax", "Noise max levl"),
@@ -320,16 +323,16 @@ class MixtureDatabase:
320
323
  for name in self.asr_configs:
321
324
  metrics.append(
322
325
  MetricDoc(
323
- "Target Metrics",
324
- f"mxtasr.{name}",
325
- f"Mixture Target ASR text using {name} ASR as defined in mixdb asr_configs parameter",
326
+ "Source Metrics",
327
+ f"mxsasr.{name}",
328
+ f"Mixture Source ASR text using {name} ASR as defined in mixdb asr_configs parameter",
326
329
  )
327
330
  )
328
331
  metrics.append(
329
332
  MetricDoc(
330
- "Target Metrics",
331
- f"tasr.{name}",
332
- f"Targets ASR text using {name} ASR as defined in mixdb asr_configs parameter",
333
+ "Source Metrics",
334
+ f"sasr.{name}",
335
+ f"Sources ASR text using {name} ASR as defined in mixdb asr_configs parameter",
333
336
  )
334
337
  )
335
338
  metrics.append(
@@ -341,16 +344,16 @@ class MixtureDatabase:
341
344
  )
342
345
  metrics.append(
343
346
  MetricDoc(
344
- "Target Metrics",
347
+ "Source Metrics",
345
348
  f"basewer.{name}",
346
- f"Word error rate of tasr.{name} vs. speech text metadata for the target",
349
+ f"Word error rate of sasr.{name} vs. speech text metadata for the source",
347
350
  )
348
351
  )
349
352
  metrics.append(
350
353
  MetricDoc(
351
354
  "Mixture Metrics",
352
355
  f"mxwer.{name}",
353
- f"Word error rate of mxasr.{name} vs. tasr.{name}",
356
+ f"Word error rate of mxasr.{name} vs. sasr.{name}",
354
357
  )
355
358
  )
356
359
 
@@ -359,12 +362,12 @@ class MixtureDatabase:
359
362
  @cached_property
360
363
  def class_balancing(self) -> bool:
361
364
  with self.db() as c:
362
- return bool(c.execute("SELECT top.class_balancing FROM top").fetchone()[0])
365
+ return bool(c.execute("SELECT class_balancing FROM top").fetchone()[0])
363
366
 
364
367
  @cached_property
365
368
  def feature(self) -> str:
366
369
  with self.db() as c:
367
- return str(c.execute("SELECT top.feature FROM top").fetchone()[0])
370
+ return str(c.execute("SELECT feature FROM top").fetchone()[0])
368
371
 
369
372
  @cached_property
370
373
  def fg_decimation(self) -> int:
@@ -396,7 +399,7 @@ class MixtureDatabase:
396
399
 
397
400
  @cached_property
398
401
  def transform_frame_ms(self) -> float:
399
- from .constants import SAMPLE_RATE
402
+ from ..constants import SAMPLE_RATE
400
403
 
401
404
  return float(self.ft_config.overlap) / float(SAMPLE_RATE / 1000)
402
405
 
@@ -417,12 +420,7 @@ class MixtureDatabase:
417
420
  return self.ft_config.overlap * self.fg_decimation * self.fg_step
418
421
 
419
422
  def total_samples(self, m_ids: GeneralizedIDs = "*") -> int:
420
- samples = 0
421
- for m_id in self.mixids_to_list(m_ids):
422
- s = self.mixture(m_id).samples
423
- if s is not None:
424
- samples += s
425
- return samples
423
+ return sum([self.mixture(m_id).samples for m_id in self.mixids_to_list(m_ids)])
426
424
 
427
425
  def total_transform_frames(self, m_ids: GeneralizedIDs = "*") -> int:
428
426
  return self.total_samples(m_ids) // self.ft_config.overlap
@@ -457,10 +455,7 @@ class MixtureDatabase:
457
455
  :return: Class labels
458
456
  """
459
457
  with self.db() as c:
460
- return [
461
- str(item[0])
462
- for item in c.execute("SELECT class_label.label FROM class_label ORDER BY class_label.id").fetchall()
463
- ]
458
+ return [str(item[0]) for item in c.execute("SELECT label FROM class_label ORDER BY id").fetchall()]
464
459
 
465
460
  @cached_property
466
461
  def class_weights_thresholds(self) -> list[float]:
@@ -469,37 +464,20 @@ class MixtureDatabase:
469
464
  :return: Class weights thresholds
470
465
  """
471
466
  with self.db() as c:
472
- return [
473
- float(item[0])
474
- for item in c.execute(
475
- "SELECT class_weights_threshold.threshold FROM class_weights_threshold"
476
- ).fetchall()
477
- ]
478
-
479
- @cached_property
480
- def truth_configs(self) -> TruthConfigs:
481
- """Get truth configs from db
482
-
483
- :return: Truth configs
484
- """
485
- import json
467
+ return [float(item[0]) for item in c.execute("SELECT threshold FROM class_weights_threshold").fetchall()]
486
468
 
487
- from .datatypes import TruthConfig
469
+ def category_truth_configs(self, category: str) -> dict[str, str]:
470
+ return _category_truth_configs(self.db, category, self.use_cache)
488
471
 
489
- with self.db() as c:
490
- truth_configs: TruthConfigs = {}
491
- for truth_config_record in c.execute("SELECT truth_config.config FROM truth_config").fetchall():
492
- truth_config = json.loads(truth_config_record[0])
493
- if truth_config["name"] not in truth_configs:
494
- truth_configs[truth_config["name"]] = TruthConfig(
495
- function=truth_config["function"],
496
- stride_reduction=truth_config["stride_reduction"],
497
- config=truth_config["config"],
498
- )
499
- return truth_configs
472
+ def source_truth_configs(self, s_id: int) -> TruthConfigs:
473
+ return _source_truth_configs(self.db, s_id, self.use_cache)
500
474
 
501
- def target_truth_configs(self, t_id: int) -> TruthConfigs:
502
- return _target_truth_configs(self.db, t_id, self.use_cache)
475
+ def mixture_truth_configs(self, m_id: int) -> TruthsConfigs:
476
+ mixture = self.mixture(m_id)
477
+ return {
478
+ category: self.source_truth_configs(mixture.all_sources[category].file_id)
479
+ for category in mixture.all_sources
480
+ }
503
481
 
504
482
  @cached_property
505
483
  def random_snrs(self) -> list[float]:
@@ -509,10 +487,7 @@ class MixtureDatabase:
509
487
  """
510
488
  with self.db() as c:
511
489
  return list(
512
- {
513
- float(item[0])
514
- for item in c.execute("SELECT mixture.snr FROM mixture WHERE mixture.random_snr == 1").fetchall()
515
- }
490
+ {float(item[0]) for item in c.execute("SELECT snr FROM source WHERE snr_random == 1").fetchall()}
516
491
  )
517
492
 
518
493
  @cached_property
@@ -523,10 +498,7 @@ class MixtureDatabase:
523
498
  """
524
499
  with self.db() as c:
525
500
  return list(
526
- {
527
- float(item[0])
528
- for item in c.execute("SELECT mixture.snr FROM mixture WHERE mixture.random_snr == 0").fetchall()
529
- }
501
+ {float(item[0]) for item in c.execute("SELECT snr FROM source WHERE snr_random == 0").fetchall()}
530
502
  )
531
503
 
532
504
  @cached_property
@@ -570,199 +542,216 @@ class MixtureDatabase:
570
542
  return _spectral_mask(self.db, sm_id, self.use_cache)
571
543
 
572
544
  @cached_property
573
- def target_files(self) -> list[TargetFile]:
574
- """Get target files from db
545
+ def source_files(self) -> dict[str, list[SourceFile]]:
546
+ """Get source files from db
575
547
 
576
- :return: Target files
548
+ :return: Source files
577
549
  """
578
550
  import json
579
551
 
580
- from .datatypes import TruthConfig
581
- from .datatypes import TruthConfigs
582
- from .db_datatypes import TargetFileRecord
552
+ from ..datatypes import TruthConfig
553
+ from ..datatypes import TruthConfigs
554
+ from .db_datatypes import SourceFileRecord
583
555
 
584
556
  with self.db() as c:
585
- target_files: list[TargetFile] = []
586
- target_file_records = [
587
- TargetFileRecord(*result) for result in c.execute("SELECT * FROM target_file").fetchall()
588
- ]
589
- for target_file_record in target_file_records:
590
- truth_configs: TruthConfigs = {}
591
- for truth_config_records in c.execute(
592
- """
593
- SELECT truth_config.config
594
- FROM truth_config, target_file_truth_config
595
- WHERE ? = target_file_truth_config.target_file_id
596
- AND truth_config.id = target_file_truth_config.truth_config_id
597
- """,
598
- (target_file_record.id,),
599
- ).fetchall():
600
- truth_config = json.loads(truth_config_records[0])
601
- truth_configs[truth_config["name"]] = TruthConfig(
602
- function=truth_config["function"],
603
- stride_reduction=truth_config["stride_reduction"],
604
- config=truth_config["config"],
605
- )
606
- target_files.append(
607
- TargetFile(
608
- name=target_file_record.name,
609
- samples=target_file_record.samples,
610
- class_indices=json.loads(target_file_record.class_indices),
611
- level_type=target_file_record.level_type,
612
- truth_configs=truth_configs,
613
- speaker_id=target_file_record.speaker_id,
557
+ source_files: dict[str, list[SourceFile]] = {}
558
+ categories = c.execute("SELECT DISTINCT category FROM source_file").fetchall()
559
+ for category in categories:
560
+ source_files[category[0]] = []
561
+ source_file_records = [
562
+ SourceFileRecord(*result)
563
+ for result in c.execute("SELECT * FROM source_file WHERE ? = category", (category[0],)).fetchall()
564
+ ]
565
+ for source_file_record in source_file_records:
566
+ truth_configs: TruthConfigs = {}
567
+ for truth_config_records in c.execute(
568
+ """
569
+ SELECT truth_config.config
570
+ FROM truth_config, source_file_truth_config
571
+ WHERE ? = source_file_truth_config.source_file_id
572
+ AND truth_config.id = source_file_truth_config.truth_config_id
573
+ """,
574
+ (source_file_record.id,),
575
+ ).fetchall():
576
+ truth_config = json.loads(truth_config_records[0])
577
+ truth_configs[truth_config["name"]] = TruthConfig(
578
+ function=truth_config["function"],
579
+ stride_reduction=truth_config["stride_reduction"],
580
+ config=truth_config["config"],
581
+ )
582
+ source_files[source_file_record.category].append(
583
+ SourceFile(
584
+ id=source_file_record.id,
585
+ category=source_file_record.category,
586
+ name=source_file_record.name,
587
+ samples=source_file_record.samples,
588
+ class_indices=json.loads(source_file_record.class_indices),
589
+ level_type=source_file_record.level_type,
590
+ truth_configs=truth_configs,
591
+ speaker_id=source_file_record.speaker_id,
592
+ )
614
593
  )
615
- )
616
- return target_files
594
+ return source_files
617
595
 
618
596
  @cached_property
619
- def target_file_ids(self) -> list[int]:
620
- """Get target file IDs from db
597
+ def source_file_ids(self) -> dict[str, list[int]]:
598
+ """Get source file IDs from db
621
599
 
622
- :return: List of target file IDs
600
+ :return: Dictionary of list of source file IDs
623
601
  """
624
602
  with self.db() as c:
625
- return [int(item[0]) for item in c.execute("SELECT target_file.id FROM target_file").fetchall()]
603
+ source_file_ids: dict[str, list[int]] = {}
604
+ categories = c.execute("SELECT DISTINCT category FROM source_file").fetchall()
605
+ for category in categories:
606
+ source_file_ids[category[0]] = [
607
+ int(item[0])
608
+ for item in c.execute("SELECT id FROM source_file WHERE ? = category", (category[0],)).fetchall()
609
+ ]
610
+ return source_file_ids
626
611
 
627
- def target_file(self, t_id: int) -> TargetFile:
628
- """Get target file with ID from db
612
+ def source_file(self, s_id: int) -> SourceFile:
613
+ """Get source file with ID from db
629
614
 
630
- :param t_id: Target file ID
631
- :return: Target file
615
+ :param s_id: Source file ID
616
+ :return: Source file
632
617
  """
633
- return _target_file(self.db, t_id, self.use_cache)
618
+ return _source_file(self.db, s_id, self.use_cache)
634
619
 
635
- @cached_property
636
- def num_target_files(self) -> int:
637
- """Get number of target files from db
620
+ def num_source_files(self, category: str) -> int:
621
+ """Get number of source files from category from db
638
622
 
639
- :return: Number of target files
623
+ :param category: Source category
624
+ :return: Number of source files
640
625
  """
641
- with self.db() as c:
642
- return int(c.execute("SELECT count(target_file.id) FROM target_file").fetchone()[0])
626
+ return _num_source_files(self.db, category, self.use_cache)
643
627
 
644
628
  @cached_property
645
- def noise_files(self) -> list[NoiseFile]:
646
- """Get noise files from db
629
+ def ir_files(self) -> list[ImpulseResponseFile]:
630
+ """Get impulse response files from db
647
631
 
648
- :return: Noise files
632
+ :return: Impulse response files
649
633
  """
650
- with self.db() as c:
651
- return [
652
- NoiseFile(name=noise[0], samples=noise[1])
653
- for noise in c.execute("SELECT noise_file.name, samples FROM noise_file").fetchall()
654
- ]
655
-
656
- @cached_property
657
- def noise_file_ids(self) -> list[int]:
658
- """Get noise file IDs from db
634
+ from .db_datatypes import ImpulseResponseFileRecord
659
635
 
660
- :return: List of noise file IDs
661
- """
662
636
  with self.db() as c:
663
- return [int(item[0]) for item in c.execute("SELECT noise_file.id FROM noise_file").fetchall()]
637
+ files: list[ImpulseResponseFile] = []
638
+ entries = c.execute("SELECT * FROM ir_file").fetchall()
639
+ for entry in entries:
640
+ file = ImpulseResponseFileRecord(*entry)
641
+
642
+ tags = [
643
+ tag[0]
644
+ for tag in c.execute(
645
+ """
646
+ SELECT ir_tag.tag
647
+ FROM ir_tag, ir_file_ir_tag
648
+ WHERE ? = ir_file_ir_tag.file_id
649
+ AND ir_tag.id = ir_file_ir_tag.tag_id
650
+ """,
651
+ (file.id,),
652
+ ).fetchall()
653
+ ]
664
654
 
665
- def noise_file(self, n_id: int) -> NoiseFile:
666
- """Get noise file with ID from db
655
+ files.append(
656
+ ImpulseResponseFile(
657
+ delay=file.delay,
658
+ name=file.name,
659
+ tags=tags,
660
+ )
661
+ )
667
662
 
668
- :param n_id: Noise file ID
669
- :return: Noise file
670
- """
671
- return _noise_file(self.db, n_id, self.use_cache)
663
+ return files
672
664
 
673
665
  @cached_property
674
- def num_noise_files(self) -> int:
675
- """Get number of noise files from db
666
+ def ir_file_ids(self) -> list[int]:
667
+ """Get impulse response file IDs from db
676
668
 
677
- :return: Number of noise files
669
+ :return: List of impulse response file IDs
678
670
  """
679
671
  with self.db() as c:
680
- return int(c.execute("SELECT count(noise_file.id) FROM noise_file").fetchone()[0])
672
+ return [int(item[0]) for item in c.execute("SELECT id FROM ir_file").fetchall()]
681
673
 
682
- @cached_property
683
- def impulse_response_files(self) -> list[ImpulseResponseFile]:
684
- """Get impulse response files from db
674
+ def ir_file_ids_for_tag(self, tag: str) -> list[int]:
675
+ """Get impulse response file IDs for given tag from db
685
676
 
686
- :return: Impulse response files
677
+ :return: List of impulse response file IDs for given tag
687
678
  """
688
- import json
689
-
690
- from .datatypes import ImpulseResponseFile
691
-
692
679
  with self.db() as c:
693
- return [
694
- ImpulseResponseFile(impulse_response[1], json.loads(impulse_response[2]), impulse_response[3])
695
- for impulse_response in c.execute(
696
- "SELECT impulse_response_file.* FROM impulse_response_file"
697
- ).fetchall()
698
- ]
699
-
700
- @cached_property
701
- def impulse_response_file_ids(self) -> list[int]:
702
- """Get impulse response file IDs from db
680
+ tag_id = c.execute("SELECT id FROM ir_tag WHERE ? = tag", (tag,)).fetchone()
681
+ if not tag_id:
682
+ return []
703
683
 
704
- :return: List of impulse response file IDs
705
- """
706
- with self.db() as c:
707
684
  return [
708
- int(item[0])
709
- for item in c.execute("SELECT impulse_response_file.id FROM impulse_response_file").fetchall()
685
+ int(item[0] - 1)
686
+ for item in c.execute("SELECT file_id FROM ir_file_ir_tag WHERE ? = tag_id", (tag_id[0],)).fetchall()
710
687
  ]
711
688
 
712
- def impulse_response_file(self, ir_id: int | None) -> str | None:
689
+ def ir_file(self, ir_id: int) -> str:
713
690
  """Get impulse response file name with ID from db
714
691
 
715
692
  :param ir_id: Impulse response file ID
716
693
  :return: Impulse response file name
717
694
  """
718
- if ir_id is None:
719
- return None
720
- return _impulse_response_file(self.db, ir_id, self.use_cache)
695
+ return _ir_file(self.db, ir_id, self.use_cache)
721
696
 
722
- def impulse_response_delay(self, ir_id: int | None) -> int | None:
697
+ def ir_delay(self, ir_id: int) -> int:
723
698
  """Get impulse response delay with ID from db
724
699
 
725
700
  :param ir_id: Impulse response file ID
726
701
  :return: Impulse response delay
727
702
  """
728
- if ir_id is None:
729
- return None
730
- return _impulse_response_delay(self.db, ir_id, self.use_cache)
703
+ return _ir_delay(self.db, ir_id, self.use_cache)
731
704
 
732
705
  @cached_property
733
- def num_impulse_response_files(self) -> int:
706
+ def num_ir_files(self) -> int:
734
707
  """Get number of impulse response files from db
735
708
 
736
709
  :return: Number of impulse response files
737
710
  """
738
711
  with self.db() as c:
739
- return int(c.execute("SELECT count(impulse_response_file.id) FROM impulse_response_file").fetchone()[0])
712
+ return int(c.execute("SELECT count(id) FROM ir_file").fetchone()[0])
740
713
 
714
+ @cached_property
715
+ def ir_tags(self) -> list[str]:
716
+ """Get tags of impulse response files from db
717
+
718
+ :return: Tags of impulse response files
719
+ """
720
+ with self.db() as c:
721
+ return [tag[0] for tag in c.execute("SELECT tag FROM ir_tag").fetchall()]
722
+
723
+ @property
741
724
  def mixtures(self) -> list[Mixture]:
742
725
  """Get mixtures from db
743
726
 
744
727
  :return: Mixtures
745
728
  """
746
729
  from .db_datatypes import MixtureRecord
747
- from .db_datatypes import TargetRecord
730
+ from .db_datatypes import SourceRecord
748
731
  from .helpers import to_mixture
749
- from .helpers import to_target
732
+ from .helpers import to_source
750
733
 
751
734
  with self.db() as c:
752
735
  mixtures: list[Mixture] = []
753
736
  for mixture in [MixtureRecord(*record) for record in c.execute("SELECT * FROM mixture").fetchall()]:
754
- targets = [
755
- to_target(TargetRecord(*target))
756
- for target in c.execute(
737
+ sources_list = [
738
+ to_source(SourceRecord(*source))
739
+ for source in c.execute(
757
740
  """
758
- SELECT target.*
759
- FROM target, mixture_target
760
- WHERE ? = mixture_target.mixture_id AND target.id = mixture_target.target_id
741
+ SELECT source.*
742
+ FROM source, mixture_source
743
+ WHERE ? = mixture_source.mixture_id AND source.id = mixture_source.source_id
761
744
  """,
762
745
  (mixture.id,),
763
746
  ).fetchall()
764
747
  ]
765
- mixtures.append(to_mixture(mixture, targets))
748
+
749
+ sources: Sources = {}
750
+ for source in sources_list:
751
+ sources[self.source_file(source.file_id).category] = source
752
+
753
+ mixtures.append(to_mixture(mixture, sources))
754
+
766
755
  return mixtures
767
756
 
768
757
  @cached_property
@@ -772,7 +761,7 @@ class MixtureDatabase:
772
761
  :return: List of zero-based mixture IDs
773
762
  """
774
763
  with self.db() as c:
775
- return [int(item[0]) - 1 for item in c.execute("SELECT mixture.id FROM mixture").fetchall()]
764
+ return [int(item[0]) - 1 for item in c.execute("SELECT id FROM mixture").fetchall()]
776
765
 
777
766
  def mixture(self, m_id: int) -> Mixture:
778
767
  """Get mixture record with ID from db
@@ -785,7 +774,7 @@ class MixtureDatabase:
785
774
  @cached_property
786
775
  def mixid_width(self) -> int:
787
776
  with self.db() as c:
788
- return int(c.execute("SELECT top.mixid_width FROM top").fetchone()[0])
777
+ return int(c.execute("SELECT mixid_width FROM top").fetchone()[0])
789
778
 
790
779
  def mixture_location(self, m_id: int) -> str:
791
780
  """Get the file location for the give mixture ID
@@ -804,231 +793,342 @@ class MixtureDatabase:
804
793
  :return: Number of mixtures
805
794
  """
806
795
  with self.db() as c:
807
- return int(c.execute("SELECT count(mixture.id) FROM mixture").fetchone()[0])
796
+ return int(c.execute("SELECT count(id) FROM mixture").fetchone()[0])
808
797
 
809
- def read_mixture_data(self, m_id: int, items: list[str] | str) -> Any:
798
+ def read_mixture_data(self, m_id: int, items: list[str] | str) -> dict[str, Any]:
810
799
  """Read mixture data
811
800
 
812
801
  :param m_id: Zero-based mixture ID
813
802
  :param items: String(s) of dataset(s) to retrieve
814
- :return: Data (or tuple of data)
803
+ :return: Dictionary of name: data
815
804
  """
816
- from sonusai.mixture import read_cached_data
805
+ from .data_io import read_cached_data
817
806
 
818
807
  return read_cached_data(self.location, "mixture", self.mixture(m_id).name, items)
819
808
 
820
- def read_target_audio(self, t_id: int) -> AudioT:
821
- """Read target audio
809
+ def read_source_audio(self, s_id: int) -> AudioT:
810
+ """Read source audio
822
811
 
823
- :param t_id: Target ID
824
- :return: Target audio
812
+ :param s_id: Source ID
813
+ :return: Source audio
825
814
  """
826
815
  from .audio import read_audio
827
816
 
828
- return read_audio(self.target_file(t_id).name, self.use_cache)
829
-
830
- def augmented_noise_audio(self, mixture: Mixture) -> AudioT:
831
- """Get augmented noise audio
832
-
833
- :param mixture: Mixture
834
- :return: Augmented noise audio
835
- """
836
- from .audio import read_audio
837
- from .augmentation import apply_augmentation
838
-
839
- noise = self.noise_file(mixture.noise.file_id)
840
- audio = read_audio(noise.name, self.use_cache)
841
- audio = apply_augmentation(self, audio, mixture.noise.augmentation.pre)
842
-
843
- return audio
817
+ return read_audio(self.source_file(s_id).name, self.use_cache)
844
818
 
845
819
  def mixture_class_indices(self, m_id: int) -> list[int]:
846
820
  class_indices: list[int] = []
847
- for t_id in self.mixture(m_id).target_ids:
848
- class_indices.extend(self.target_file(t_id).class_indices)
821
+ for s_id in self.mixture(m_id).source_ids.values():
822
+ class_indices.extend(self.source_file(s_id).class_indices)
849
823
  return sorted(set(class_indices))
850
824
 
851
- def mixture_targets(self, m_id: int, force: bool = False) -> list[AudioT]:
852
- """Get the list of augmented target audio data (one per target in the mixup) for the given mixture ID
825
+ def mixture_sources(self, m_id: int, force: bool = False, cache: bool = False) -> SourcesAudioT:
826
+ """Get the pre-truth source audio data (one per source in the mixture) for the given mixture ID
853
827
 
854
828
  :param m_id: Zero-based mixture ID
855
829
  :param force: Force computing data from original sources regardless of whether cached data exists
856
- :return: List of augmented target audio data (one per target in the mixup)
830
+ :param cache: Cache result
831
+ :return: Dictionary of pre-truth source audio data (one per source in the mixture)
857
832
  """
858
- from .augmentation import apply_augmentation
859
- from .augmentation import apply_gain
860
- from .augmentation import pad_audio_to_length
833
+ from .data_io import write_cached_data
834
+ from .effects import apply_effects
835
+ from .effects import conform_audio_to_length
861
836
 
862
837
  if not force:
863
- targets_audio = self.read_mixture_data(m_id, "targets")
864
- if targets_audio is not None:
865
- return list(targets_audio)
838
+ sources = self.read_mixture_data(m_id, "sources")["sources"]
839
+ if sources is not None:
840
+ return sources
866
841
 
867
842
  mixture = self.mixture(m_id)
868
843
  if mixture is None:
869
844
  raise ValueError(f"Could not find mixture for m_id: {m_id}")
870
845
 
871
- targets_audio = []
872
- for target in mixture.targets:
873
- target_audio = self.read_target_audio(target.file_id)
874
- target_audio = apply_augmentation(
875
- mixdb=self,
876
- audio=target_audio,
877
- augmentation=target.augmentation.pre,
878
- frame_length=self.feature_step_samples,
846
+ sources = {}
847
+ for category, source in mixture.all_sources.items():
848
+ source = mixture.all_sources[category]
849
+ audio = self.read_source_audio(source.file_id)
850
+ audio = apply_effects(self, audio, source.effects, pre=True, post=False)
851
+ audio = conform_audio_to_length(audio, mixture.samples, source.repeat, source.start)
852
+ sources[category] = audio
853
+
854
+ if cache:
855
+ write_cached_data(
856
+ location=self.location,
857
+ name="mixture",
858
+ index=mixture.name,
859
+ items={"sources": sources},
879
860
  )
880
- target_audio = apply_gain(audio=target_audio, gain=mixture.target_snr_gain)
881
- target_audio = pad_audio_to_length(audio=target_audio, length=mixture.samples)
882
- targets_audio.append(target_audio)
883
861
 
884
- return targets_audio
862
+ return sources
885
863
 
886
- def mixture_targets_f(self, m_id: int, targets: list[AudioT] | None = None, force: bool = False) -> list[AudioF]:
887
- """Get the list of augmented target transform data (one per target in the mixup) for the given mixture ID
864
+ def mixture_sources_f(
865
+ self,
866
+ m_id: int,
867
+ sources: SourcesAudioT | None = None,
868
+ force: bool = False,
869
+ cache: bool = False,
870
+ ) -> SourcesAudioF:
871
+ """Get the pre-truth source transform data (one per source in the mixture) for the given mixture ID
888
872
 
889
873
  :param m_id: Zero-based mixture ID
890
- :param targets: List of augmented target audio data (one per target in the mixup)
874
+ :param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
891
875
  :param force: Force computing data from original sources regardless of whether cached data exists
892
- :return: List of augmented target transform data (one per target in the mixup)
876
+ :param cache: Cache result
877
+ :return: Dictionary of pre-truth source transform data (one per source in the mixture)
893
878
  """
879
+ from .data_io import write_cached_data
894
880
  from .helpers import forward_transform
895
881
 
896
- if force or targets is None:
897
- targets = self.mixture_targets(m_id, force)
882
+ if sources is None:
883
+ sources = self.mixture_sources(m_id, force)
898
884
 
899
- return [forward_transform(target, self.ft_config) for target in targets]
885
+ sources_f = {category: forward_transform(sources[category], self.ft_config) for category in sources}
900
886
 
901
- def mixture_target(self, m_id: int, targets: list[AudioT] | None = None, force: bool = False) -> AudioT:
902
- """Get the augmented target audio data for the given mixture ID
887
+ if cache:
888
+ write_cached_data(
889
+ location=self.location,
890
+ name="mixture",
891
+ index=self.mixture(m_id).name,
892
+ items={"sources_f": sources_f},
893
+ )
894
+
895
+ return sources_f
896
+
897
+ def mixture_source(
898
+ self,
899
+ m_id: int,
900
+ sources: SourcesAudioT | None = None,
901
+ force: bool = False,
902
+ cache: bool = False,
903
+ ) -> AudioT:
904
+ """Get the post-truth, summed, and gained source audio data for the given mixture ID
903
905
 
904
906
  :param m_id: Zero-based mixture ID
905
- :param targets: List of augmented target audio data (one per target in the mixup)
907
+ :param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
906
908
  :param force: Force computing data from original sources regardless of whether cached data exists
907
- :return: Augmented target audio data
909
+ :param cache: Cache result
910
+ :return: Post-truth, gained, and summed source audio data
908
911
  """
909
- from .helpers import get_target
912
+ import numpy as np
913
+
914
+ from .data_io import write_cached_data
915
+ from .effects import apply_effects
910
916
 
911
917
  if not force:
912
- target = self.read_mixture_data(m_id, "target")
913
- if target is not None:
914
- return target
918
+ source = self.read_mixture_data(m_id, "source")["source"]
919
+ if source is not None:
920
+ return source
921
+
922
+ if sources is None:
923
+ sources = self.mixture_sources(m_id, force)
924
+
925
+ mixture = self.mixture(m_id)
915
926
 
916
- if force or targets is None:
917
- targets = self.mixture_targets(m_id, force)
927
+ source = np.sum(
928
+ [
929
+ apply_effects(
930
+ self,
931
+ audio=sources[category],
932
+ effects=mixture.all_sources[category].effects,
933
+ pre=False,
934
+ post=True,
935
+ )
936
+ * mixture.all_sources[category].snr_gain
937
+ for category in sources
938
+ if category != "noise"
939
+ ],
940
+ axis=0,
941
+ )
918
942
 
919
- return get_target(self, self.mixture(m_id), targets)
943
+ if cache:
944
+ write_cached_data(
945
+ location=self.location,
946
+ name="mixture",
947
+ index=mixture.name,
948
+ items={"source": source},
949
+ )
950
+
951
+ return source
920
952
 
921
- def mixture_target_f(
953
+ def mixture_source_f(
922
954
  self,
923
955
  m_id: int,
924
- targets: list[AudioT] | None = None,
925
- target: AudioT | None = None,
956
+ sources: SourcesAudioT | None = None,
957
+ source: AudioT | None = None,
926
958
  force: bool = False,
959
+ cache: bool = False,
927
960
  ) -> AudioF:
928
- """Get the augmented target transform data for the given mixture ID
961
+ """Get the post-truth, summed, and gained source transform data for the given mixture ID
929
962
 
930
963
  :param m_id: Zero-based mixture ID
931
- :param targets: List of augmented target audio data (one per target in the mixup)
932
- :param target: Augmented target audio for the given m_id
964
+ :param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
965
+ :param source: Post-truth, gained, and summed source audio for the given m_id
933
966
  :param force: Force computing data from original sources regardless of whether cached data exists
934
- :return: Augmented target transform data
967
+ :param cache: Cache result
968
+ :return: Post-truth, gained, and summed source transform data
935
969
  """
970
+ from .data_io import write_cached_data
936
971
  from .helpers import forward_transform
937
972
 
938
- if force or target is None:
939
- target = self.mixture_target(m_id, targets, force)
973
+ if source is None:
974
+ source = self.mixture_source(m_id, sources, force)
940
975
 
941
- return forward_transform(target, self.ft_config)
976
+ source_f = forward_transform(source, self.ft_config)
942
977
 
943
- def mixture_noise(self, m_id: int, force: bool = False) -> AudioT:
944
- """Get the augmented noise audio data for the given mixture ID
978
+ if cache:
979
+ write_cached_data(
980
+ location=self.location,
981
+ name="mixture",
982
+ index=self.mixture(m_id).name,
983
+ items={"source_f": source_f},
984
+ )
985
+
986
+ return source_f
987
+
988
+ def mixture_noise(
989
+ self,
990
+ m_id: int,
991
+ sources: SourcesAudioT | None = None,
992
+ force: bool = False,
993
+ cache: bool = False,
994
+ ) -> AudioT:
995
+ """Get the post-truth and gained noise audio data for the given mixture ID
945
996
 
946
997
  :param m_id: Zero-based mixture ID
998
+ :param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
947
999
  :param force: Force computing data from original sources regardless of whether cached data exists
948
- :return: Augmented noise audio data
1000
+ :param cache: Cache result
1001
+ :return: Post-truth and gained noise audio data
949
1002
  """
950
- from .audio import get_next_noise
951
- from .augmentation import apply_gain
1003
+ from .data_io import write_cached_data
1004
+ from .effects import apply_effects
952
1005
 
953
1006
  if not force:
954
- noise = self.read_mixture_data(m_id, "noise")
1007
+ noise = self.read_mixture_data(m_id, "noise")["noise"]
955
1008
  if noise is not None:
956
1009
  return noise
957
1010
 
958
- mixture = self.mixture(m_id)
959
- noise = self.augmented_noise_audio(mixture)
960
- noise = get_next_noise(audio=noise, offset=mixture.noise_offset, length=mixture.samples)
961
- return apply_gain(audio=noise, gain=mixture.noise_snr_gain)
1011
+ if sources is None:
1012
+ sources = self.mixture_sources(m_id, force)
962
1013
 
963
- def mixture_noise_f(self, m_id: int, noise: AudioT | None = None, force: bool = False) -> AudioF:
964
- """Get the augmented noise transform for the given mixture ID
1014
+ noise = self.mixture(m_id).noise
1015
+ noise = apply_effects(self, sources["noise"], noise.effects, pre=False, post=True) * noise.snr_gain
1016
+
1017
+ if cache:
1018
+ write_cached_data(
1019
+ location=self.location,
1020
+ name="mixture",
1021
+ index=self.mixture(m_id).name,
1022
+ items={"noise": noise},
1023
+ )
1024
+
1025
+ return noise
1026
+
1027
+ def mixture_noise_f(
1028
+ self,
1029
+ m_id: int,
1030
+ sources: SourcesAudioT | None = None,
1031
+ noise: AudioT | None = None,
1032
+ force: bool = False,
1033
+ cache: bool = False,
1034
+ ) -> AudioF:
1035
+ """Get the post-truth and gained noise transform for the given mixture ID
965
1036
 
966
1037
  :param m_id: Zero-based mixture ID
967
- :param noise: Augmented noise audio data
1038
+ :param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
1039
+ :param noise: Post-truth and gained noise audio data
968
1040
  :param force: Force computing data from original sources regardless of whether cached data exists
969
- :return: Augmented noise transform data
1041
+ :param cache: Cache result
1042
+ :return: Post-truth and gained noise transform data
970
1043
  """
1044
+ from .data_io import write_cached_data
971
1045
  from .helpers import forward_transform
972
1046
 
973
1047
  if force or noise is None:
974
- noise = self.mixture_noise(m_id, force)
1048
+ noise = self.mixture_noise(m_id, sources, force)
1049
+
1050
+ noise_f = forward_transform(noise, self.ft_config)
1051
+ if cache:
1052
+ write_cached_data(
1053
+ location=self.location,
1054
+ name="mixture",
1055
+ index=self.mixture(m_id).name,
1056
+ items={"noise_f": noise_f},
1057
+ )
975
1058
 
976
- return forward_transform(noise, self.ft_config)
1059
+ return noise_f
977
1060
 
978
1061
  def mixture_mixture(
979
1062
  self,
980
1063
  m_id: int,
981
- targets: list[AudioT] | None = None,
982
- target: AudioT | None = None,
1064
+ sources: SourcesAudioT | None = None,
1065
+ source: AudioT | None = None,
983
1066
  noise: AudioT | None = None,
984
1067
  force: bool = False,
1068
+ cache: bool = False,
985
1069
  ) -> AudioT:
986
1070
  """Get the mixture audio data for the given mixture ID
987
1071
 
988
1072
  :param m_id: Zero-based mixture ID
989
- :param targets: List of augmented target audio data (one per target in the mixup)
990
- :param target: Augmented target audio data
991
- :param noise: Augmented noise audio data
1073
+ :param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
1074
+ :param source: Post-truth, gained, and summed source audio data
1075
+ :param noise: Post-truth and gained noise audio data
992
1076
  :param force: Force computing data from original sources regardless of whether cached data exists
1077
+ :param cache: Cache result
993
1078
  :return: Mixture audio data
994
1079
  """
1080
+ from .data_io import write_cached_data
1081
+
995
1082
  if not force:
996
- mixture = self.read_mixture_data(m_id, "mixture")
1083
+ mixture = self.read_mixture_data(m_id, "mixture")["mixture"]
997
1084
  if mixture is not None:
998
1085
  return mixture
999
1086
 
1000
- if force or target is None:
1001
- target = self.mixture_target(m_id, targets, force)
1087
+ if source is None:
1088
+ source = self.mixture_source(m_id, sources, force)
1002
1089
 
1003
- if force or noise is None:
1004
- noise = self.mixture_noise(m_id, force)
1090
+ if noise is None:
1091
+ noise = self.mixture_noise(m_id, sources, force)
1092
+
1093
+ mixture = source + noise
1005
1094
 
1006
- return target + noise
1095
+ if cache:
1096
+ write_cached_data(
1097
+ location=self.location,
1098
+ name="mixture",
1099
+ index=self.mixture(m_id).name,
1100
+ items={"mixture": mixture},
1101
+ )
1102
+
1103
+ return mixture
1007
1104
 
1008
1105
  def mixture_mixture_f(
1009
1106
  self,
1010
1107
  m_id: int,
1011
- targets: list[AudioT] | None = None,
1012
- target: AudioT | None = None,
1108
+ sources: SourcesAudioT | None = None,
1109
+ source: AudioT | None = None,
1013
1110
  noise: AudioT | None = None,
1014
1111
  mixture: AudioT | None = None,
1015
1112
  force: bool = False,
1113
+ cache: bool = False,
1016
1114
  ) -> AudioF:
1017
1115
  """Get the mixture transform for the given mixture ID
1018
1116
 
1019
1117
  :param m_id: Zero-based mixture ID
1020
- :param targets: List of augmented target audio data (one per target in the mixup)
1021
- :param target: Augmented target audio data
1022
- :param noise: Augmented noise audio data
1118
+ :param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
1119
+ :param source: Post-truth, gained, and summed source audio data
1120
+ :param noise: Post-truth and gained noise audio data
1023
1121
  :param mixture: Mixture audio data
1024
1122
  :param force: Force computing data from original sources regardless of whether cached data exists
1123
+ :param cache: Cache result
1025
1124
  :return: Mixture transform data
1026
1125
  """
1126
+ from .data_io import write_cached_data
1027
1127
  from .helpers import forward_transform
1028
1128
  from .spectral_mask import apply_spectral_mask
1029
1129
 
1030
- if force or mixture is None:
1031
- mixture = self.mixture_mixture(m_id, targets, target, noise, force)
1130
+ if mixture is None:
1131
+ mixture = self.mixture_mixture(m_id, sources, source, noise, force)
1032
1132
 
1033
1133
  mixture_f = forward_transform(mixture, self.ft_config)
1034
1134
 
@@ -1040,80 +1140,79 @@ class MixtureDatabase:
1040
1140
  seed=m.spectral_mask_seed,
1041
1141
  )
1042
1142
 
1143
+ if cache:
1144
+ write_cached_data(
1145
+ location=self.location,
1146
+ name="mixture",
1147
+ index=self.mixture(m_id).name,
1148
+ items={"mixture_f": mixture_f},
1149
+ )
1150
+
1043
1151
  return mixture_f
1044
1152
 
1045
- def mixture_truth_t(
1046
- self,
1047
- m_id: int,
1048
- targets: list[AudioT] | None = None,
1049
- noise: AudioT | None = None,
1050
- mixture: AudioT | None = None,
1051
- force: bool = False,
1052
- ) -> list[TruthDict]:
1153
+ def mixture_truth_t(self, m_id: int, force: bool = False, cache: bool = False) -> TruthsDict:
1053
1154
  """Get the truth_t data for the given mixture ID
1054
1155
 
1055
1156
  :param m_id: Zero-based mixture ID
1056
- :param targets: List of augmented target audio data (one per target in the mixup) for the given mixture ID
1057
- :param noise: Augmented noise audio data for the given mixture ID
1058
- :param mixture: Mixture audio data for the given mixture ID
1059
1157
  :param force: Force computing data from original sources regardless of whether cached data exists
1158
+ :param cache: Cache result
1060
1159
  :return: Dictionary of truth_t data
1061
1160
  """
1161
+ from .data_io import write_cached_data
1062
1162
  from .truth import truth_function
1063
1163
 
1064
1164
  if not force:
1065
- truth_t = self.read_mixture_data(m_id, "truth_t")
1165
+ truth_t = self.read_mixture_data(m_id, "truth_t")["truth_t"]
1066
1166
  if truth_t is not None:
1067
1167
  return truth_t
1068
1168
 
1069
- if force or targets is None:
1070
- targets = self.mixture_targets(m_id, force)
1071
-
1072
- if force or noise is None:
1073
- noise = self.mixture_noise(m_id, force)
1074
-
1075
- if force or mixture is None:
1076
- mixture = self.mixture_mixture(m_id, targets=targets, noise=noise, force=force)
1077
-
1078
- if not all(len(target) == self.mixture(m_id).samples for target in targets):
1079
- raise ValueError("Lengths of targets do not match length of mixture")
1169
+ truth_t = truth_function(self, m_id)
1080
1170
 
1081
- if len(noise) != self.mixture(m_id).samples:
1082
- raise ValueError("Length of noise does not match length of mixture")
1171
+ if cache:
1172
+ write_cached_data(
1173
+ location=self.location,
1174
+ name="mixture",
1175
+ index=self.mixture(m_id).name,
1176
+ items={"truth_t": truth_t},
1177
+ )
1083
1178
 
1084
- return truth_function(self, m_id)
1179
+ return truth_t
1085
1180
 
1086
1181
  def mixture_segsnr_t(
1087
1182
  self,
1088
1183
  m_id: int,
1089
- targets: list[AudioT] | None = None,
1090
- target: AudioT | None = None,
1184
+ sources: SourcesAudioT | None = None,
1185
+ source: AudioT | None = None,
1091
1186
  noise: AudioT | None = None,
1092
1187
  force: bool = False,
1188
+ cache: bool = False,
1093
1189
  ) -> Segsnr:
1094
1190
  """Get the segsnr_t data for the given mixture ID
1095
1191
 
1096
1192
  :param m_id: Zero-based mixture ID
1097
- :param targets: List of augmented target audio data (one per target in the mixup)
1098
- :param target: Augmented target audio data
1099
- :param noise: Augmented noise audio data
1193
+ :param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
1194
+ :param source: Post-truth, gained, and summed source audio data
1195
+ :param noise: Post-truth and gained noise audio data
1100
1196
  :param force: Force computing data from original sources regardless of whether cached data exists
1197
+ :param cache: Cache result
1101
1198
  :return: segsnr_t data
1102
1199
  """
1103
1200
  import numpy as np
1104
1201
  import torch
1105
1202
  from pyaaware import ForwardTransform
1106
1203
 
1204
+ from .data_io import write_cached_data
1205
+
1107
1206
  if not force:
1108
- segsnr_t = self.read_mixture_data(m_id, "segsnr_t")
1207
+ segsnr_t = self.read_mixture_data(m_id, "segsnr_t")["segsnr_t"]
1109
1208
  if segsnr_t is not None:
1110
1209
  return segsnr_t
1111
1210
 
1112
- if force or target is None:
1113
- target = self.mixture_target(m_id, targets, force)
1211
+ if source is None:
1212
+ source = self.mixture_source(m_id, sources, force)
1114
1213
 
1115
- if force or noise is None:
1116
- noise = self.mixture_noise(m_id, force)
1214
+ if noise is None:
1215
+ noise = self.mixture_noise(m_id, sources, force)
1117
1216
 
1118
1217
  ft = ForwardTransform(
1119
1218
  length=self.ft_config.length,
@@ -1127,13 +1226,13 @@ class MixtureDatabase:
1127
1226
 
1128
1227
  segsnr_t = np.empty(mixture.samples, dtype=np.float32)
1129
1228
 
1130
- target_energy = ft.execute_all(torch.from_numpy(target))[1].numpy()
1229
+ source_energy = ft.execute_all(torch.from_numpy(source))[1].numpy()
1131
1230
  noise_energy = ft.execute_all(torch.from_numpy(noise))[1].numpy()
1132
1231
 
1133
1232
  offsets = range(0, mixture.samples, self.ft_config.overlap)
1134
- if len(target_energy) != len(offsets):
1233
+ if len(source_energy) != len(offsets):
1135
1234
  raise ValueError(
1136
- f"Number of frames in energy, {len(target_energy)}, is not number of frames in mixture, {len(offsets)}"
1235
+ f"Number of frames in energy, {len(source_energy)}, is not number of frames in mixture, {len(offsets)}"
1137
1236
  )
1138
1237
 
1139
1238
  for idx, offset in enumerate(offsets):
@@ -1142,187 +1241,242 @@ class MixtureDatabase:
1142
1241
  if noise_energy[idx] == 0:
1143
1242
  snr = np.float32(np.inf)
1144
1243
  else:
1145
- snr = np.float32(target_energy[idx] / noise_energy[idx])
1244
+ snr = np.float32(source_energy[idx] / noise_energy[idx])
1146
1245
 
1147
1246
  segsnr_t[indices] = snr
1148
1247
 
1248
+ if cache:
1249
+ write_cached_data(
1250
+ location=self.location,
1251
+ name="mixture",
1252
+ index=mixture.name,
1253
+ items={"segsnr_t": segsnr_t},
1254
+ )
1255
+
1149
1256
  return segsnr_t
1150
1257
 
1151
1258
  def mixture_segsnr(
1152
1259
  self,
1153
1260
  m_id: int,
1154
1261
  segsnr_t: Segsnr | None = None,
1155
- targets: list[AudioT] | None = None,
1156
- target: AudioT | None = None,
1262
+ sources: SourcesAudioT | None = None,
1263
+ source: AudioT | None = None,
1157
1264
  noise: AudioT | None = None,
1158
1265
  force: bool = False,
1266
+ cache: bool = False,
1159
1267
  ) -> Segsnr:
1160
1268
  """Get the segsnr data for the given mixture ID
1161
1269
 
1162
1270
  :param m_id: Zero-based mixture ID
1163
1271
  :param segsnr_t: segsnr_t data
1164
- :param targets: List of augmented target audio data (one per target in the mixup)
1165
- :param target: Augmented target audio data
1166
- :param noise: Augmented noise audio data
1272
+ :param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
1273
+ :param source: Post-truth, gained, and summed source audio data
1274
+ :param noise: Post-truth and gained noise audio data
1167
1275
  :param force: Force computing data from original sources regardless of whether cached data exists
1276
+ :param cache: Cache result
1168
1277
  :return: segsnr data
1169
1278
  """
1279
+ from .data_io import write_cached_data
1280
+
1170
1281
  if not force:
1171
- segsnr = self.read_mixture_data(m_id, "segsnr")
1282
+ segsnr = self.read_mixture_data(m_id, "segsnr")["segsnr"]
1172
1283
  if segsnr is not None:
1173
1284
  return segsnr
1174
1285
 
1175
- segsnr_t = self.read_mixture_data(m_id, "segsnr_t")
1176
- if segsnr_t is not None:
1177
- return segsnr_t[0 :: self.ft_config.overlap]
1286
+ if segsnr_t is None:
1287
+ segsnr_t = self.mixture_segsnr_t(m_id, sources, source, noise, force)
1178
1288
 
1179
- if force or segsnr_t is None:
1180
- segsnr_t = self.mixture_segsnr_t(m_id, targets, target, noise, force)
1289
+ segsnr = segsnr_t[0 :: self.ft_config.overlap]
1181
1290
 
1182
- return segsnr_t[0 :: self.ft_config.overlap]
1291
+ if cache:
1292
+ write_cached_data(
1293
+ location=self.location,
1294
+ name="mixture",
1295
+ index=self.mixture(m_id).name,
1296
+ items={"segsnr": segsnr},
1297
+ )
1298
+
1299
+ return segsnr
1183
1300
 
1184
1301
  def mixture_ft(
1185
1302
  self,
1186
1303
  m_id: int,
1187
- targets: list[AudioT] | None = None,
1188
- target: AudioT | None = None,
1304
+ sources: SourcesAudioT | None = None,
1305
+ source: AudioT | None = None,
1189
1306
  noise: AudioT | None = None,
1190
1307
  mixture_f: AudioF | None = None,
1191
1308
  mixture: AudioT | None = None,
1192
- truth_t: list[TruthDict] | None = None,
1309
+ truth_t: TruthsDict | None = None,
1193
1310
  force: bool = False,
1194
- ) -> tuple[Feature, TruthDict]:
1311
+ cache: bool = False,
1312
+ ) -> tuple[Feature, TruthsDict]:
1195
1313
  """Get the feature and truth_f data for the given mixture ID
1196
1314
 
1197
1315
  :param m_id: Zero-based mixture ID
1198
- :param targets: List of augmented target audio data (one per target in the mixup)
1199
- :param target: Augmented target audio data
1200
- :param noise: Augmented noise audio data
1316
+ :param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
1317
+ :param source: Post-truth, gained, and summed source audio data
1318
+ :param noise: Post-truth and gained noise audio data
1201
1319
  :param mixture_f: Mixture transform data
1202
1320
  :param mixture: Mixture audio data
1203
1321
  :param truth_t: truth_t
1204
1322
  :param force: Force computing data from original sources regardless of whether cached data exists
1323
+ :param cache: Cache result
1205
1324
  :return: Tuple of (feature, truth_f) data
1206
1325
  """
1207
1326
  from pyaaware import FeatureGenerator
1208
1327
 
1328
+ from .data_io import write_cached_data
1209
1329
  from .truth import truth_stride_reduction
1210
1330
 
1211
1331
  if not force:
1212
- feature, truth_f = self.read_mixture_data(m_id, ["feature", "truth_f"])
1213
- if feature is not None and truth_f is not None:
1214
- return feature, truth_f
1332
+ ft = self.read_mixture_data(m_id, ["feature", "truth_f"])
1333
+ if ft["feature"] is not None and ft["truth_f"] is not None:
1334
+ return ft["feature"], ft["truth_f"]
1215
1335
 
1216
- if force or mixture_f is None:
1336
+ if mixture_f is None:
1217
1337
  mixture_f = self.mixture_mixture_f(
1218
1338
  m_id=m_id,
1219
- targets=targets,
1220
- target=target,
1339
+ sources=sources,
1340
+ source=source,
1221
1341
  noise=noise,
1222
1342
  mixture=mixture,
1223
1343
  force=force,
1224
1344
  )
1225
1345
 
1226
- if force or truth_t is None:
1227
- truth_t = self.mixture_truth_t(m_id=m_id, targets=targets, noise=noise, force=force)
1346
+ if truth_t is None:
1347
+ truth_t = self.mixture_truth_t(m_id, force)
1228
1348
 
1229
1349
  fg = FeatureGenerator(self.fg_config.feature_mode, self.fg_config.truth_parameters)
1230
1350
 
1231
- # TODO: handle mixup in truth_t
1232
- feature, truth_f = fg.execute_all(mixture_f, truth_t[0])
1351
+ feature, truth_f = fg.execute_all(mixture_f, truth_t)
1233
1352
  if truth_f is not None:
1234
- for key in self.truth_configs:
1235
- if self.truth_parameters[key] is not None:
1236
- truth_f[key] = truth_stride_reduction(truth_f[key], self.truth_configs[key].stride_reduction)
1353
+ truth_configs = self.mixture_truth_configs(m_id)
1354
+ for category, configs in truth_configs.items():
1355
+ for name, config in configs.items():
1356
+ if self.truth_parameters[category][name] is not None:
1357
+ truth_f[category][name] = truth_stride_reduction(
1358
+ truth_f[category][name], config.stride_reduction
1359
+ )
1237
1360
  else:
1238
1361
  raise TypeError("Unexpected truth of None from feature generator")
1239
1362
 
1363
+ if cache:
1364
+ write_cached_data(
1365
+ location=self.location,
1366
+ name="mixture",
1367
+ index=self.mixture(m_id).name,
1368
+ items={"feature": truth_f, "truth_f": truth_f},
1369
+ )
1370
+
1240
1371
  return feature, truth_f
1241
1372
 
1242
1373
  def mixture_feature(
1243
1374
  self,
1244
1375
  m_id: int,
1245
- targets: list[AudioT] | None = None,
1376
+ sources: SourcesAudioT | None = None,
1246
1377
  noise: AudioT | None = None,
1247
1378
  mixture: AudioT | None = None,
1248
- truth_t: list[TruthDict] | None = None,
1379
+ truth_t: TruthsDict | None = None,
1249
1380
  force: bool = False,
1381
+ cache: bool = False,
1250
1382
  ) -> Feature:
1251
1383
  """Get the feature data for the given mixture ID
1252
1384
 
1253
1385
  :param m_id: Zero-based mixture ID
1254
- :param targets: List of augmented target audio data (one per target in the mixup)
1255
- :param noise: Augmented noise audio data
1386
+ :param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
1387
+ :param noise: Post-truth and gained noise audio data
1256
1388
  :param mixture: Mixture audio data
1257
1389
  :param truth_t: truth_t
1258
1390
  :param force: Force computing data from original sources regardless of whether cached data exists
1391
+ :param cache: Cache result
1259
1392
  :return: Feature data
1260
1393
  """
1261
- feature, _ = self.mixture_ft(
1394
+ from .data_io import write_cached_data
1395
+
1396
+ feature = self.mixture_ft(
1262
1397
  m_id=m_id,
1263
- targets=targets,
1398
+ sources=sources,
1264
1399
  noise=noise,
1265
1400
  mixture=mixture,
1266
1401
  truth_t=truth_t,
1267
1402
  force=force,
1268
- )
1403
+ )[0]
1404
+
1405
+ if cache:
1406
+ write_cached_data(
1407
+ location=self.location,
1408
+ name="mixture",
1409
+ index=self.mixture(m_id).name,
1410
+ items={"feature": feature},
1411
+ )
1412
+
1269
1413
  return feature
1270
1414
 
1271
1415
  def mixture_truth_f(
1272
1416
  self,
1273
1417
  m_id: int,
1274
- targets: list[AudioT] | None = None,
1418
+ sources: SourcesAudioT | None = None,
1275
1419
  noise: AudioT | None = None,
1276
1420
  mixture: AudioT | None = None,
1277
- truth_t: list[TruthDict] | None = None,
1421
+ truth_t: TruthsDict | None = None,
1278
1422
  force: bool = False,
1423
+ cache: bool = False,
1279
1424
  ) -> TruthDict:
1280
1425
  """Get the truth_f data for the given mixture ID
1281
1426
 
1282
1427
  :param m_id: Zero-based mixture ID
1283
- :param targets: List of augmented target audio data (one per target in the mixup)
1284
- :param noise: Augmented noise audio data
1428
+ :param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
1429
+ :param noise: Post-truth and gained noise audio data
1285
1430
  :param mixture: Mixture audio data
1286
1431
  :param truth_t: truth_t
1287
1432
  :param force: Force computing data from original sources regardless of whether cached data exists
1433
+ :param cache: Cache result
1288
1434
  :return: truth_f data
1289
1435
  """
1290
- _, truth_f = self.mixture_ft(
1436
+ from .data_io import write_cached_data
1437
+
1438
+ truth_f = self.mixture_ft(
1291
1439
  m_id=m_id,
1292
- targets=targets,
1440
+ sources=sources,
1293
1441
  noise=noise,
1294
1442
  mixture=mixture,
1295
1443
  truth_t=truth_t,
1296
1444
  force=force,
1297
- )
1445
+ )[1]
1446
+
1447
+ if cache:
1448
+ write_cached_data(
1449
+ location=self.location,
1450
+ name="mixture",
1451
+ index=self.mixture(m_id).name,
1452
+ items={"truth_f": truth_f},
1453
+ )
1454
+
1298
1455
  return truth_f
1299
1456
 
1300
- def mixture_class_count(
1301
- self,
1302
- m_id: int,
1303
- targets: list[AudioT] | None = None,
1304
- noise: AudioT | None = None,
1305
- truth_t: list[TruthDict] | None = None,
1306
- ) -> ClassCount:
1457
+ def mixture_class_count(self, m_id: int, truth_t: TruthsDict | None = None) -> dict[str, ClassCount]:
1307
1458
  """Compute the number of frames for which each class index is active for the given mixture ID
1308
1459
 
1309
1460
  :param m_id: Zero-based mixture ID
1310
- :param targets: List of augmented target audio (one per target in the mixup)
1311
- :param noise: Augmented noise audio
1312
1461
  :param truth_t: truth_t
1313
- :return: List of class counts
1462
+ :return: Dictionary of class counts
1314
1463
  """
1315
1464
  import numpy as np
1316
1465
 
1317
1466
  if truth_t is None:
1318
- truth_t = self.mixture_truth_t(m_id, targets, noise)
1467
+ truth_t = self.mixture_truth_t(m_id)
1319
1468
 
1320
- class_count = [0] * self.num_classes
1321
- num_classes = self.num_classes
1322
- if "sed" in self.truth_configs:
1323
- for cl in range(num_classes):
1324
- # TODO: handle mixup in truth_t
1325
- class_count[cl] = int(np.sum(truth_t[0]["sed"][:, cl] >= self.class_weights_thresholds[cl]))
1469
+ class_count: dict[str, ClassCount] = {}
1470
+
1471
+ truth_configs = self.mixture_truth_configs(m_id)
1472
+ for category in truth_configs:
1473
+ class_count[category] = [0] * self.num_classes
1474
+ for configs in truth_configs[category]:
1475
+ if "sed" in configs:
1476
+ for cl in range(self.num_classes):
1477
+ class_count[category][cl] = int(
1478
+ np.sum(truth_t[category]["sed"][:, cl] >= self.class_weights_thresholds[cl])
1479
+ )
1326
1480
 
1327
1481
  return class_count
1328
1482
 
@@ -1348,57 +1502,56 @@ class MixtureDatabase:
1348
1502
  return _speaker(self.db, s_id, tier, self.use_cache)
1349
1503
 
1350
1504
  def speech_metadata(self, tier: str) -> list[str]:
1351
- from .helpers import get_textgrid_tier_from_target_file
1505
+ from .helpers import get_textgrid_tier_from_source_file
1352
1506
 
1353
1507
  results: set[str] = set()
1354
1508
  if tier in self.textgrid_metadata_tiers:
1355
- for target_file in self.target_files:
1356
- data = get_textgrid_tier_from_target_file(target_file.name, tier)
1357
- if data is None:
1358
- continue
1359
- if isinstance(data, list):
1360
- for item in data:
1361
- results.add(item.label)
1362
- else:
1363
- results.add(data)
1509
+ for source_files in self.source_files.values():
1510
+ for source_file in source_files:
1511
+ data = get_textgrid_tier_from_source_file(source_file.name, tier)
1512
+ if data is None:
1513
+ continue
1514
+ if isinstance(data, list):
1515
+ for item in data:
1516
+ results.add(item.label)
1517
+ else:
1518
+ results.add(data)
1364
1519
  elif tier in self.speaker_metadata_tiers:
1365
- for target_file in self.target_files:
1366
- data = self.speaker(target_file.speaker_id, tier)
1367
- if data is not None:
1368
- results.add(data)
1520
+ for source_files in self.source_files.values():
1521
+ for source_file in source_files:
1522
+ data = self.speaker(source_file.speaker_id, tier)
1523
+ if data is not None:
1524
+ results.add(data)
1369
1525
 
1370
1526
  return sorted(results)
1371
1527
 
1372
- def mixture_speech_metadata(self, mixid: int, tier: str) -> list[SpeechMetadata]:
1528
+ def mixture_speech_metadata(self, mixid: int, tier: str) -> dict[str, SpeechMetadata]:
1373
1529
  from praatio.utilities.constants import Interval
1374
1530
 
1375
- from .helpers import get_textgrid_tier_from_target_file
1531
+ from .helpers import get_textgrid_tier_from_source_file
1376
1532
 
1377
- results: list[SpeechMetadata] = []
1533
+ results: dict[str, SpeechMetadata] = {}
1378
1534
  is_textgrid = tier in self.textgrid_metadata_tiers
1379
1535
  if is_textgrid:
1380
- for target in self.mixture(mixid).targets:
1381
- data = get_textgrid_tier_from_target_file(self.target_file(target.file_id).name, tier)
1536
+ for category, source in self.mixture(mixid).all_sources.items():
1537
+ data = get_textgrid_tier_from_source_file(self.source_file(source.file_id).name, tier)
1382
1538
  if isinstance(data, list):
1383
- # Check for tempo augmentation and adjust Interval start and end data as needed
1539
+ # Check for tempo effect and adjust Interval start and end data as needed
1384
1540
  entries = []
1385
1541
  for entry in data:
1386
- if target.augmentation.pre.tempo is not None:
1387
- entries.append(
1388
- Interval(
1389
- entry.start / target.augmentation.pre.tempo,
1390
- entry.end / target.augmentation.pre.tempo,
1391
- entry.label,
1392
- )
1542
+ entries.append(
1543
+ Interval(
1544
+ entry.start / source.pre_tempo,
1545
+ entry.end / source.pre_tempo,
1546
+ entry.label,
1393
1547
  )
1394
- else:
1395
- entries.append(entry)
1396
- results.append(entries)
1548
+ )
1549
+ results[category] = entries
1397
1550
  else:
1398
- results.append(data)
1551
+ results[category] = data
1399
1552
  else:
1400
- for target in self.mixture(mixid).targets:
1401
- results.append(self.speaker(self.target_file(target.file_id).speaker_id, tier))
1553
+ for category, source in self.mixture(mixid).all_sources.items():
1554
+ results[category] = self.speaker(self.source_file(source.file_id).speaker_id, tier)
1402
1555
 
1403
1556
  return results
1404
1557
 
@@ -1407,7 +1560,7 @@ class MixtureDatabase:
1407
1560
  tier: str | None = None,
1408
1561
  value: str | None = None,
1409
1562
  where: str | None = None,
1410
- ) -> list[int]:
1563
+ ) -> dict[str, list[int]]:
1411
1564
  """Get a list of mixture IDs for the given speech metadata tier.
1412
1565
 
1413
1566
  If 'where' is None, then include mixture IDs whose tier values are equal to the given 'value'.
@@ -1441,16 +1594,29 @@ class MixtureDatabase:
1441
1594
  results = c.execute(f"SELECT id FROM speaker WHERE {where}").fetchall()
1442
1595
  speaker_ids = ",".join(map(str, [i[0] for i in results]))
1443
1596
 
1444
- results = c.execute(f"SELECT id FROM target_file WHERE speaker_id IN ({speaker_ids})").fetchall()
1445
- target_file_ids = ",".join(map(str, [i[0] for i in results]))
1597
+ results = c.execute(f"SELECT id, category FROM source_file WHERE speaker_id IN ({speaker_ids})").fetchall()
1598
+ source_file_ids: dict[str, list[int]] = {}
1599
+ for result in results:
1600
+ source_file_id, category = result
1601
+ if category not in source_file_ids:
1602
+ source_file_ids[category] = [source_file_id]
1603
+ else:
1604
+ source_file_ids[category].append(source_file_id)
1446
1605
 
1447
- results = c.execute(
1448
- f"SELECT mixture_id FROM mixture_target WHERE mixture_target.target_id IN ({target_file_ids})"
1449
- ).fetchall()
1606
+ mixids: dict[str, list[int]] = {}
1607
+ for category in source_file_ids:
1608
+ id_str = ",".join(map(str, source_file_ids[category]))
1609
+ results = c.execute(f"SELECT id FROM source WHERE file_id IN ({id_str})").fetchall()
1610
+ source_ids = ",".join(map(str, [i[0] for i in results]))
1450
1611
 
1451
- return [mixture_id[0] - 1 for mixture_id in results]
1612
+ results = c.execute(
1613
+ f"SELECT mixture_id FROM mixture_source WHERE source_id IN ({source_ids})"
1614
+ ).fetchall()
1615
+ mixids[category] = [mixture_id[0] - 1 for mixture_id in results]
1452
1616
 
1453
- def mixture_all_speech_metadata(self, m_id: int) -> list[dict[str, SpeechMetadata]]:
1617
+ return mixids
1618
+
1619
+ def mixture_all_speech_metadata(self, m_id: int) -> dict[str, dict[str, SpeechMetadata]]:
1454
1620
  from .helpers import mixture_all_speech_metadata
1455
1621
 
1456
1622
  return mixture_all_speech_metadata(self, self.mixture(m_id))
@@ -1483,63 +1649,65 @@ class MixtureDatabase:
1483
1649
  :param m_id: Zero-based mixture ID
1484
1650
  :param metrics: List of metrics to get
1485
1651
  :param force: Force computing data from original sources regardless of whether cached data exists
1486
- :return: List of metric data
1652
+ :return: Dictionary of metric data
1487
1653
  """
1488
1654
  from collections.abc import Callable
1489
1655
 
1490
1656
  import numpy as np
1491
1657
  from pystoi import stoi
1492
1658
 
1493
- from sonusai.metrics import calc_audio_stats
1494
- from sonusai.metrics import calc_phase_distance
1495
- from sonusai.metrics import calc_segsnr_f
1496
- from sonusai.metrics import calc_segsnr_f_bin
1497
- from sonusai.metrics import calc_speech
1498
- from sonusai.metrics import calc_wer
1499
- from sonusai.metrics import calc_wsdr
1500
- from sonusai.mixture import SAMPLE_RATE
1501
- from sonusai.mixture import AudioStatsMetrics
1502
- from sonusai.mixture import SpeechMetrics
1503
- from sonusai.utils import calc_asr
1504
-
1505
- def create_targets_audio() -> Callable[[], list[AudioT]]:
1506
- state: list[AudioT] | None = None
1507
-
1508
- def get() -> list[AudioT]:
1659
+ from ..constants import SAMPLE_RATE
1660
+ from ..datatypes import AudioStatsMetrics
1661
+ from ..datatypes import SpeechMetrics
1662
+ from ..metrics.calc_audio_stats import calc_audio_stats
1663
+ from ..metrics.calc_pesq import calc_pesq
1664
+ from ..metrics.calc_phase_distance import calc_phase_distance
1665
+ from ..metrics.calc_segsnr_f import calc_segsnr_f
1666
+ from ..metrics.calc_segsnr_f import calc_segsnr_f_bin
1667
+ from ..metrics.calc_speech import calc_speech
1668
+ from ..metrics.calc_wer import calc_wer
1669
+ from ..metrics.calc_wsdr import calc_wsdr
1670
+ from ..utils.asr import calc_asr
1671
+ from ..utils.db import linear_to_db
1672
+
1673
+ def create_sources_audio() -> Callable[[], dict[str, AudioT]]:
1674
+ state: dict[str, AudioT] | None = None
1675
+
1676
+ def get() -> dict[str, AudioT]:
1509
1677
  nonlocal state
1510
1678
  if state is None:
1511
- state = self.mixture_targets(m_id)
1679
+ state = self.mixture_sources(m_id)
1512
1680
  return state
1513
1681
 
1514
1682
  return get
1515
1683
 
1516
- targets_audio = create_targets_audio()
1684
+ sources_audio = create_sources_audio()
1517
1685
 
1518
- def create_target_audio() -> Callable[[], AudioT]:
1686
+ def create_source_audio() -> Callable[[], AudioT]:
1519
1687
  state: AudioT | None = None
1520
1688
 
1521
1689
  def get() -> AudioT:
1522
1690
  nonlocal state
1523
1691
  if state is None:
1524
- state = self.mixture_target(m_id)
1692
+ state = self.mixture_source(m_id)
1525
1693
  return state
1526
1694
 
1527
1695
  return get
1528
1696
 
1529
- target_audio = create_target_audio()
1697
+ source_audio = create_source_audio()
1530
1698
 
1531
- def create_target_f() -> Callable[[], AudioF]:
1699
+ def create_source_f() -> Callable[[], AudioF]:
1532
1700
  state: AudioF | None = None
1533
1701
 
1534
1702
  def get() -> AudioF:
1535
1703
  nonlocal state
1536
1704
  if state is None:
1537
- state = self.mixture_targets_f(m_id)[0]
1705
+ state = self.mixture_source_f(m_id)
1538
1706
  return state
1539
1707
 
1540
1708
  return get
1541
1709
 
1542
- target_f = create_target_f()
1710
+ source_f = create_source_f()
1543
1711
 
1544
1712
  def create_noise_audio() -> Callable[[], AudioT]:
1545
1713
  state: AudioT | None = None
@@ -1593,15 +1761,29 @@ class MixtureDatabase:
1593
1761
 
1594
1762
  segsnr_f = create_segsnr_f()
1595
1763
 
1596
- def create_speech() -> Callable[[], list[SpeechMetrics]]:
1597
- state: list[SpeechMetrics] | None = None
1764
+ def create_pesq() -> Callable[[], dict[str, float]]:
1765
+ state: dict[str, float] | None = None
1598
1766
 
1599
- def get() -> list[SpeechMetrics]:
1767
+ def get() -> dict[str, float]:
1600
1768
  nonlocal state
1601
1769
  if state is None:
1602
- state = []
1603
- for audio in targets_audio():
1604
- state.append(calc_speech(hypothesis=mixture_audio(), reference=audio))
1770
+ state = {category: calc_pesq(mixture_audio(), audio) for category, audio in sources_audio().items()}
1771
+ return state
1772
+
1773
+ return get
1774
+
1775
+ pesq = create_pesq()
1776
+
1777
+ def create_speech() -> Callable[[], dict[str, SpeechMetrics]]:
1778
+ state: dict[str, SpeechMetrics] | None = None
1779
+
1780
+ def get() -> dict[str, SpeechMetrics]:
1781
+ nonlocal state
1782
+ if state is None:
1783
+ state = {
1784
+ category: calc_speech(mixture_audio(), audio, pesq()[category])
1785
+ for category, audio in sources_audio().items()
1786
+ }
1605
1787
  return state
1606
1788
 
1607
1789
  return get
@@ -1621,33 +1803,34 @@ class MixtureDatabase:
1621
1803
 
1622
1804
  mixture_stats = create_mixture_stats()
1623
1805
 
1624
- def create_targets_stats() -> Callable[[], list[AudioStatsMetrics]]:
1625
- state: list[AudioStatsMetrics] | None = None
1806
+ def create_sources_stats() -> Callable[[], dict[str, AudioStatsMetrics]]:
1807
+ state: dict[str, AudioStatsMetrics] | None = None
1626
1808
 
1627
- def get() -> list[AudioStatsMetrics]:
1809
+ def get() -> dict[str, AudioStatsMetrics]:
1628
1810
  nonlocal state
1629
1811
  if state is None:
1630
- state = []
1631
- for audio in targets_audio():
1632
- state.append(calc_audio_stats(audio, self.fg_info.ft_config.length / SAMPLE_RATE))
1812
+ state = {
1813
+ category: calc_audio_stats(audio, self.fg_info.ft_config.length / SAMPLE_RATE)
1814
+ for category, audio in sources_audio().items()
1815
+ }
1633
1816
  return state
1634
1817
 
1635
1818
  return get
1636
1819
 
1637
- targets_stats = create_targets_stats()
1820
+ sources_stats = create_sources_stats()
1638
1821
 
1639
- def create_target_stats() -> Callable[[], AudioStatsMetrics]:
1822
+ def create_source_stats() -> Callable[[], AudioStatsMetrics]:
1640
1823
  state: AudioStatsMetrics | None = None
1641
1824
 
1642
1825
  def get() -> AudioStatsMetrics:
1643
1826
  nonlocal state
1644
1827
  if state is None:
1645
- state = calc_audio_stats(target_audio(), self.fg_info.ft_config.length / SAMPLE_RATE)
1828
+ state = calc_audio_stats(source_audio(), self.fg_info.ft_config.length / SAMPLE_RATE)
1646
1829
  return state
1647
1830
 
1648
1831
  return get
1649
1832
 
1650
- target_stats = create_target_stats()
1833
+ source_stats = create_source_stats()
1651
1834
 
1652
1835
  def create_noise_stats() -> Callable[[], AudioStatsMetrics]:
1653
1836
  state: AudioStatsMetrics | None = None
@@ -1678,33 +1861,34 @@ class MixtureDatabase:
1678
1861
 
1679
1862
  asr_config = create_asr_config()
1680
1863
 
1681
- def create_targets_asr() -> Callable[[str], list[str]]:
1682
- state: dict[str, list[str]] = {}
1864
+ def create_sources_asr() -> Callable[[str], dict[str, str]]:
1865
+ state: dict[str, dict[str, str]] = {}
1683
1866
 
1684
- def get(asr_name) -> list[str]:
1867
+ def get(asr_name) -> dict[str, str]:
1685
1868
  nonlocal state
1686
1869
  if asr_name not in state:
1687
- state[asr_name] = []
1688
- for audio in targets_audio():
1689
- state[asr_name].append(calc_asr(audio, **asr_config(asr_name)).text)
1870
+ state[asr_name] = {
1871
+ category: calc_asr(audio, **asr_config(asr_name)).text
1872
+ for category, audio in sources_audio().items()
1873
+ }
1690
1874
  return state[asr_name]
1691
1875
 
1692
1876
  return get
1693
1877
 
1694
- targets_asr = create_targets_asr()
1878
+ sources_asr = create_sources_asr()
1695
1879
 
1696
- def create_target_asr() -> Callable[[str], str]:
1880
+ def create_source_asr() -> Callable[[str], str]:
1697
1881
  state: dict[str, str] = {}
1698
1882
 
1699
1883
  def get(asr_name) -> str:
1700
1884
  nonlocal state
1701
1885
  if asr_name not in state:
1702
- state[asr_name] = calc_asr(target_audio(), **asr_config(asr_name)).text
1886
+ state[asr_name] = calc_asr(source_audio(), **asr_config(asr_name)).text
1703
1887
  return state[asr_name]
1704
1888
 
1705
1889
  return get
1706
1890
 
1707
- target_asr = create_target_asr()
1891
+ source_asr = create_source_asr()
1708
1892
 
1709
1893
  def create_mixture_asr() -> Callable[[str], str]:
1710
1894
  state: dict[str, str] = {}
@@ -1728,11 +1912,11 @@ class MixtureDatabase:
1728
1912
 
1729
1913
  def calc(m: str) -> Any:
1730
1914
  if m == "mxsnr":
1731
- return self.mixture(m_id).snr
1915
+ return {category: source.snr for category, source in self.mixture(m_id).all_sources.items()}
1732
1916
 
1733
1917
  # Get cached data first, if exists
1734
1918
  if not force:
1735
- value = self.read_mixture_data(m_id, m)
1919
+ value = self.read_mixture_data(m_id, m)[m]
1736
1920
  if value is not None:
1737
1921
  return value
1738
1922
 
@@ -1744,8 +1928,8 @@ class MixtureDatabase:
1744
1928
  # noise only, ignore/reset target asr
1745
1929
  return float("nan")
1746
1930
 
1747
- if target_asr(asr_name):
1748
- return calc_wer(mixture_asr(asr_name), target_asr(asr_name)).wer * 100
1931
+ if source_asr(asr_name):
1932
+ return calc_wer(mixture_asr(asr_name), source_asr(asr_name)).wer * 100
1749
1933
 
1750
1934
  # TODO: should this be NaN like above?
1751
1935
  return float(0)
@@ -1753,12 +1937,14 @@ class MixtureDatabase:
1753
1937
  if m.startswith("basewer"):
1754
1938
  asr_name = get_asr_name(m)
1755
1939
 
1756
- text = self.mixture_speech_metadata(m_id, "text")[0]
1757
- if not isinstance(text, str):
1758
- # TODO: should this be NaN like above?
1759
- return [float(0)] * len(targets_audio())
1760
-
1761
- return [calc_wer(t, text).wer * 100 for t in targets_asr(asr_name)]
1940
+ text = self.mixture_speech_metadata(m_id, "text")
1941
+ base_wer: dict[str, float] = {}
1942
+ for category, source in sources_asr(asr_name).items():
1943
+ if isinstance(text[category], str):
1944
+ base_wer[category] = calc_wer(source, str(text[category])).wer * 100
1945
+ else:
1946
+ base_wer[category] = 0
1947
+ return base_wer
1762
1948
 
1763
1949
  if m.startswith("mxasr"):
1764
1950
  return mixture_asr(get_asr_name(m))
@@ -1769,6 +1955,18 @@ class MixtureDatabase:
1769
1955
  if m == "mxssnr_std":
1770
1956
  return calc_segsnr_f(segsnr_f()).std
1771
1957
 
1958
+ if m == "mxssnr_avg_db":
1959
+ val = calc_segsnr_f(segsnr_f()).avg
1960
+ if val is not None:
1961
+ return linear_to_db(val)
1962
+ return None
1963
+
1964
+ if m == "mxssnr_std_db":
1965
+ val = calc_segsnr_f(segsnr_f()).std
1966
+ if val is not None:
1967
+ return linear_to_db(val)
1968
+ return None
1969
+
1772
1970
  if m == "mxssnrdb_avg":
1773
1971
  return calc_segsnr_f(segsnr_f()).db_avg
1774
1972
 
@@ -1776,40 +1974,40 @@ class MixtureDatabase:
1776
1974
  return calc_segsnr_f(segsnr_f()).db_std
1777
1975
 
1778
1976
  if m == "mxssnrf_avg":
1779
- return calc_segsnr_f_bin(target_f(), noise_f()).avg
1977
+ return calc_segsnr_f_bin(source_f(), noise_f()).avg
1780
1978
 
1781
1979
  if m == "mxssnrf_std":
1782
- return calc_segsnr_f_bin(target_f(), noise_f()).std
1980
+ return calc_segsnr_f_bin(source_f(), noise_f()).std
1783
1981
 
1784
1982
  if m == "mxssnrdbf_avg":
1785
- return calc_segsnr_f_bin(target_f(), noise_f()).db_avg
1983
+ return calc_segsnr_f_bin(source_f(), noise_f()).db_avg
1786
1984
 
1787
1985
  if m == "mxssnrdbf_std":
1788
- return calc_segsnr_f_bin(target_f(), noise_f()).db_std
1986
+ return calc_segsnr_f_bin(source_f(), noise_f()).db_std
1789
1987
 
1790
1988
  if m == "mxpesq":
1791
1989
  if self.mixture(m_id).is_noise_only:
1792
- return [0] * len(speech())
1793
- return [s.pesq for s in speech()]
1990
+ return dict.fromkeys(pesq(), 0)
1991
+ return pesq()
1794
1992
 
1795
1993
  if m == "mxcsig":
1796
1994
  if self.mixture(m_id).is_noise_only:
1797
- return [0] * len(speech())
1798
- return [s.csig for s in speech()]
1995
+ return dict.fromkeys(speech(), 0)
1996
+ return {category: s.csig for category, s in speech().items()}
1799
1997
 
1800
1998
  if m == "mxcbak":
1801
1999
  if self.mixture(m_id).is_noise_only:
1802
- return [0] * len(speech())
1803
- return [s.cbak for s in speech()]
2000
+ return dict.fromkeys(speech(), 0)
2001
+ return {category: s.cbak for category, s in speech().items()}
1804
2002
 
1805
2003
  if m == "mxcovl":
1806
2004
  if self.mixture(m_id).is_noise_only:
1807
- return [0] * len(speech())
1808
- return [s.covl for s in speech()]
2005
+ return dict.fromkeys(speech(), 0)
2006
+ return {category: s.covl for category, s in speech().items()}
1809
2007
 
1810
2008
  if m == "mxwsdr":
1811
2009
  mixture = mixture_audio()[:, np.newaxis]
1812
- target = target_audio()[:, np.newaxis]
2010
+ target = source_audio()[:, np.newaxis]
1813
2011
  noise = noise_audio()[:, np.newaxis]
1814
2012
  return calc_wsdr(
1815
2013
  hypothesis=np.concatenate((mixture, noise), axis=1),
@@ -1819,11 +2017,11 @@ class MixtureDatabase:
1819
2017
 
1820
2018
  if m == "mxpd":
1821
2019
  mixture_f = self.mixture_mixture_f(m_id)
1822
- return calc_phase_distance(hypothesis=mixture_f, reference=target_f())[0]
2020
+ return calc_phase_distance(hypothesis=mixture_f, reference=source_f())[0]
1823
2021
 
1824
2022
  if m == "mxstoi":
1825
2023
  return stoi(
1826
- x=target_audio(),
2024
+ x=source_audio(),
1827
2025
  y=mixture_audio(),
1828
2026
  fs_sig=SAMPLE_RATE,
1829
2027
  extended=False,
@@ -1860,70 +2058,70 @@ class MixtureDatabase:
1860
2058
  return mixture_stats().pkc
1861
2059
 
1862
2060
  if m == "mxtdco":
1863
- return target_stats().dco
2061
+ return source_stats().dco
1864
2062
 
1865
2063
  if m == "mxtmin":
1866
- return target_stats().min
2064
+ return source_stats().min
1867
2065
 
1868
2066
  if m == "mxtmax":
1869
- return target_stats().max
2067
+ return source_stats().max
1870
2068
 
1871
2069
  if m == "mxtpkdb":
1872
- return target_stats().pkdb
2070
+ return source_stats().pkdb
1873
2071
 
1874
2072
  if m == "mxtlrms":
1875
- return target_stats().lrms
2073
+ return source_stats().lrms
1876
2074
 
1877
2075
  if m == "mxtpkr":
1878
- return target_stats().pkr
2076
+ return source_stats().pkr
1879
2077
 
1880
2078
  if m == "mxttr":
1881
- return target_stats().tr
2079
+ return source_stats().tr
1882
2080
 
1883
2081
  if m == "mxtcr":
1884
- return target_stats().cr
2082
+ return source_stats().cr
1885
2083
 
1886
2084
  if m == "mxtfl":
1887
- return target_stats().fl
2085
+ return source_stats().fl
1888
2086
 
1889
2087
  if m == "mxtpkc":
1890
- return target_stats().pkc
2088
+ return source_stats().pkc
1891
2089
 
1892
- if m == "tdco":
1893
- return [t.dco for t in targets_stats()]
2090
+ if m == "sdco":
2091
+ return {category: s.dco for category, s in sources_stats().items()}
1894
2092
 
1895
- if m == "tmin":
1896
- return [t.min for t in targets_stats()]
2093
+ if m == "smin":
2094
+ return {category: s.min for category, s in sources_stats().items()}
1897
2095
 
1898
- if m == "tmax":
1899
- return [t.max for t in targets_stats()]
2096
+ if m == "smax":
2097
+ return {category: s.max for category, s in sources_stats().items()}
1900
2098
 
1901
- if m == "tpkdb":
1902
- return [t.pkdb for t in targets_stats()]
2099
+ if m == "spkdb":
2100
+ return {category: s.pkdb for category, s in sources_stats().items()}
1903
2101
 
1904
- if m == "tlrms":
1905
- return [t.lrms for t in targets_stats()]
2102
+ if m == "slrms":
2103
+ return {category: s.lrms for category, s in sources_stats().items()}
1906
2104
 
1907
- if m == "tpkr":
1908
- return [t.pkr for t in targets_stats()]
2105
+ if m == "spkr":
2106
+ return {category: s.pkr for category, s in sources_stats().items()}
1909
2107
 
1910
- if m == "ttr":
1911
- return [t.tr for t in targets_stats()]
2108
+ if m == "str":
2109
+ return {category: s.tr for category, s in sources_stats().items()}
1912
2110
 
1913
- if m == "tcr":
1914
- return [t.cr for t in targets_stats()]
2111
+ if m == "scr":
2112
+ return {category: s.cr for category, s in sources_stats().items()}
1915
2113
 
1916
- if m == "tfl":
1917
- return [t.fl for t in targets_stats()]
2114
+ if m == "sfl":
2115
+ return {category: s.fl for category, s in sources_stats().items()}
1918
2116
 
1919
- if m == "tpkc":
1920
- return [t.pkc for t in targets_stats()]
2117
+ if m == "spkc":
2118
+ return {category: s.pkc for category, s in sources_stats().items()}
1921
2119
 
1922
- if m.startswith("tasr"):
1923
- return targets_asr(get_asr_name(m))
2120
+ if m.startswith("sasr"):
2121
+ return sources_asr(get_asr_name(m))
1924
2122
 
1925
- if m.startswith("mxtasr"):
1926
- return target_asr(get_asr_name(m))
2123
+ if m.startswith("mxsasr"):
2124
+ return source_asr(get_asr_name(m))
1927
2125
 
1928
2126
  if m == "ndco":
1929
2127
  return noise_stats().dco
@@ -2003,16 +2201,7 @@ def __spectral_mask(db: partial, sm_id: int) -> SpectralMask:
2003
2201
  from .db_datatypes import SpectralMaskRecord
2004
2202
 
2005
2203
  with db() as c:
2006
- spectral_mask = SpectralMaskRecord(
2007
- *c.execute(
2008
- """
2009
- SELECT *
2010
- FROM spectral_mask
2011
- WHERE ? = spectral_mask.id
2012
- """,
2013
- (sm_id,),
2014
- ).fetchone()
2015
- )
2204
+ spectral_mask = SpectralMaskRecord(*c.execute("SELECT * FROM spectral_mask WHERE ? = id", (sm_id,)).fetchone())
2016
2205
  return SpectralMask(
2017
2206
  f_max_width=spectral_mask.f_max_width,
2018
2207
  f_num=spectral_mask.f_num,
@@ -2022,82 +2211,72 @@ def __spectral_mask(db: partial, sm_id: int) -> SpectralMask:
2022
2211
  )
2023
2212
 
2024
2213
 
2025
- def _target_file(db: partial, t_id: int, use_cache: bool = True) -> TargetFile:
2026
- """Get target file with ID from db
2214
+ def _num_source_files(db: partial, category: str, use_cache: bool = True) -> int:
2215
+ """Get number of source files from category from db
2027
2216
 
2028
2217
  :param db: Database context
2029
- :param t_id: Target file ID
2218
+ :param category: Source category
2030
2219
  :param use_cache: If true, use LRU caching
2031
- :return: Target file
2220
+ :return: Number of source files
2032
2221
  """
2033
2222
  if use_cache:
2034
- return __target_file(db, t_id, use_cache)
2035
- return __target_file.__wrapped__(db, t_id, use_cache)
2223
+ return __num_source_files(db, category)
2224
+ return __num_source_files.__wrapped__(db, category)
2036
2225
 
2037
2226
 
2038
2227
  @lru_cache
2039
- def __target_file(db: partial, t_id: int, use_cache: bool = True) -> TargetFile:
2040
- """Get target file with ID from db
2228
+ def __num_source_files(db: partial, category: str) -> int:
2229
+ """Get number of source files from category from db
2041
2230
 
2042
2231
  :param db: Database context
2043
- :param t_id: Target file ID
2044
- :param use_cache: If true, use LRU caching
2045
- :return: Target file
2232
+ :param category: Source category
2233
+ :return: Number of source files
2046
2234
  """
2047
- import json
2048
-
2049
- from .db_datatypes import TargetFileRecord
2050
-
2051
2235
  with db() as c:
2052
- target_file = TargetFileRecord(
2053
- *c.execute(
2054
- """
2055
- SELECT *
2056
- FROM target_file
2057
- WHERE ? = target_file.id
2058
- """,
2059
- (t_id,),
2060
- ).fetchone()
2061
- )
2062
-
2063
- return TargetFile(
2064
- name=target_file.name,
2065
- samples=target_file.samples,
2066
- class_indices=json.loads(target_file.class_indices),
2067
- level_type=target_file.level_type,
2068
- truth_configs=_target_truth_configs(db, t_id, use_cache),
2069
- speaker_id=target_file.speaker_id,
2070
- )
2236
+ return int(c.execute("SELECT count(id) FROM source_file WHERE ? = category", (category,)).fetchone()[0])
2071
2237
 
2072
2238
 
2073
- def _noise_file(db: partial, n_id: int, use_cache: bool = True) -> NoiseFile:
2074
- """Get noise file with ID from db
2239
+ def _source_file(db: partial, s_id: int, use_cache: bool = True) -> SourceFile:
2240
+ """Get source file with ID from db
2075
2241
 
2076
2242
  :param db: Database context
2077
- :param n_id: Noise file ID
2243
+ :param s_id: Source file ID
2078
2244
  :param use_cache: If true, use LRU caching
2079
- :return: Noise file
2245
+ :return: Source file
2080
2246
  """
2081
2247
  if use_cache:
2082
- return __noise_file(db, n_id)
2083
- return __noise_file.__wrapped__(db, n_id)
2248
+ return __source_file(db, s_id, use_cache)
2249
+ return __source_file.__wrapped__(db, s_id, use_cache)
2084
2250
 
2085
2251
 
2086
2252
  @lru_cache
2087
- def __noise_file(db: partial, n_id: int) -> NoiseFile:
2253
+ def __source_file(db: partial, s_id: int, use_cache: bool = True) -> SourceFile:
2254
+ """Get source file with ID from db
2255
+
2256
+ :param db: Database context
2257
+ :param s_id: Source file ID
2258
+ :param use_cache: If true, use LRU caching
2259
+ :return: Source file
2260
+ """
2261
+ import json
2262
+
2263
+ from .db_datatypes import SourceFileRecord
2264
+
2088
2265
  with db() as c:
2089
- noise = c.execute(
2090
- """
2091
- SELECT noise_file.name, samples
2092
- FROM noise_file
2093
- WHERE ? = noise_file.id
2094
- """,
2095
- (n_id,),
2096
- ).fetchone()
2097
- return NoiseFile(name=noise[0], samples=noise[1])
2266
+ source_file = SourceFileRecord(*c.execute("SELECT * FROM source_file WHERE ? = id", (s_id,)).fetchone())
2267
+
2268
+ return SourceFile(
2269
+ category=source_file.category,
2270
+ name=source_file.name,
2271
+ samples=source_file.samples,
2272
+ class_indices=json.loads(source_file.class_indices),
2273
+ level_type=source_file.level_type,
2274
+ truth_configs=_source_truth_configs(db, s_id, use_cache),
2275
+ speaker_id=source_file.speaker_id,
2276
+ )
2098
2277
 
2099
2278
 
2100
- def _impulse_response_file(db: partial, ir_id: int, use_cache: bool = True) -> str:
2279
+ def _ir_file(db: partial, ir_id: int, use_cache: bool = True) -> str:
2101
2280
  """Get impulse response file name with ID from db
2102
2281
 
2103
2282
  :param db: Database context
@@ -2106,26 +2285,17 @@ def _impulse_response_file(db: partial, ir_id: int, use_cache: bool = True) -> s
2106
2285
  :return: Impulse response file name
2107
2286
  """
2108
2287
  if use_cache:
2109
- return __impulse_response_file(db, ir_id)
2110
- return __impulse_response_file.__wrapped__(db, ir_id)
2288
+ return __ir_file(db, ir_id)
2289
+ return __ir_file.__wrapped__(db, ir_id)
2111
2290
 
2112
2291
 
2113
2292
  @lru_cache
2114
- def __impulse_response_file(db: partial, ir_id: int) -> str:
2293
+ def __ir_file(db: partial, ir_id: int) -> str:
2115
2294
  with db() as c:
2116
- return str(
2117
- c.execute(
2118
- """
2119
- SELECT impulse_response_file.file
2120
- FROM impulse_response_file
2121
- WHERE ? = impulse_response_file.id
2122
- """,
2123
- (ir_id + 1,),
2124
- ).fetchone()[0]
2125
- )
2295
+ return str(c.execute("SELECT name FROM ir_file WHERE ? = id ", (ir_id + 1,)).fetchone()[0])
2126
2296
 
2127
2297
 
2128
- def _impulse_response_delay(db: partial, ir_id: int, use_cache: bool = True) -> int:
2298
+ def _ir_delay(db: partial, ir_id: int, use_cache: bool = True) -> int:
2129
2299
  """Get impulse response delay with ID from db
2130
2300
 
2131
2301
  :param db: Database context
@@ -2134,23 +2304,14 @@ def _impulse_response_delay(db: partial, ir_id: int, use_cache: bool = True) ->
2134
2304
  :return: Impulse response delay
2135
2305
  """
2136
2306
  if use_cache:
2137
- return __impulse_response_delay(db, ir_id)
2138
- return __impulse_response_delay.__wrapped__(db, ir_id)
2307
+ return __ir_delay(db, ir_id)
2308
+ return __ir_delay.__wrapped__(db, ir_id)
2139
2309
 
2140
2310
 
2141
2311
  @lru_cache
2142
- def __impulse_response_delay(db: partial, ir_id: int) -> int:
2312
+ def __ir_delay(db: partial, ir_id: int) -> int:
2143
2313
  with db() as c:
2144
- return int(
2145
- c.execute(
2146
- """
2147
- SELECT impulse_response_file.delay
2148
- FROM impulse_response_file
2149
- WHERE ? = impulse_response_file.id
2150
- """,
2151
- (ir_id + 1,),
2152
- ).fetchone()[0]
2153
- )
2314
+ return int(c.execute("SELECT delay FROM ir_file WHERE ? = id", (ir_id + 1,)).fetchone()[0])
2154
2315
 
2155
2316
 
2156
2317
  def _mixture(db: partial, m_id: int, use_cache: bool = True) -> Mixture:
@@ -2169,35 +2330,27 @@ def _mixture(db: partial, m_id: int, use_cache: bool = True) -> Mixture:
2169
2330
  @lru_cache
2170
2331
  def __mixture(db: partial, m_id: int) -> Mixture:
2171
2332
  from .db_datatypes import MixtureRecord
2172
- from .db_datatypes import TargetRecord
2333
+ from .db_datatypes import SourceRecord
2173
2334
  from .helpers import to_mixture
2174
- from .helpers import to_target
2335
+ from .helpers import to_source
2175
2336
 
2176
2337
  with db() as c:
2177
- mixture = MixtureRecord(
2178
- *c.execute(
2179
- """
2180
- SELECT *
2181
- FROM mixture
2182
- WHERE ? = mixture.id
2183
- """,
2184
- (m_id + 1,),
2185
- ).fetchone()
2186
- )
2338
+ mixture = MixtureRecord(*c.execute("SELECT * FROM mixture WHERE ? = id", (m_id + 1,)).fetchone())
2187
2339
 
2188
- targets = [
2189
- to_target(TargetRecord(*target))
2190
- for target in c.execute(
2191
- """
2192
- SELECT target.*
2193
- FROM target, mixture_target
2194
- WHERE ? = mixture_target.mixture_id AND target.id = mixture_target.target_id
2340
+ sources: Sources = {}
2341
+ for source in c.execute(
2342
+ """
2343
+ SELECT source.*
2344
+ FROM source, mixture_source
2345
+ WHERE ? = mixture_source.mixture_id AND source.id = mixture_source.source_id
2195
2346
  """,
2196
- (mixture.id,),
2197
- ).fetchall()
2198
- ]
2347
+ (mixture.id,),
2348
+ ).fetchall():
2349
+ s = SourceRecord(*source)
2350
+ category = c.execute("SELECT category FROM source_file WHERE ? = id", (s.file_id,)).fetchone()[0]
2351
+ sources[category] = to_source(s)
2199
2352
 
2200
- return to_mixture(mixture, targets)
2353
+ return to_mixture(mixture, sources)
2201
2354
 
2202
2355
 
2203
2356
  def _speaker(db: partial, s_id: int | None, tier: str, use_cache: bool = True) -> str | None:
@@ -2220,27 +2373,55 @@ def __speaker(db: partial, s_id: int | None, tier: str) -> str | None:
2220
2373
  return data[0]
2221
2374
 
2222
2375
 
2223
- def _target_truth_configs(db: partial, t_id: int, use_cache: bool = True) -> TruthConfigs:
2376
+ def _category_truth_configs(db: partial, category: str, use_cache: bool = True) -> dict[str, str]:
2377
+ if use_cache:
2378
+ return __category_truth_configs(db, category)
2379
+ return __category_truth_configs.__wrapped__(db, category)
2380
+
2381
+
2382
+ @lru_cache
2383
+ def __category_truth_configs(db: partial, category: str) -> dict[str, str]:
2384
+ import json
2385
+
2386
+ truth_configs: dict[str, str] = {}
2387
+ with db() as c:
2388
+ s_ids = c.execute("SELECT id FROM source_file WHERE ? = category", (category,)).fetchall()
2389
+
2390
+ for s_id in s_ids:
2391
+ for truth_config_record in c.execute(
2392
+ """
2393
+ SELECT truth_config.config
2394
+ FROM truth_config, source_file_truth_config
2395
+ WHERE ? = source_file_truth_config.source_file_id AND truth_config.id = source_file_truth_config.truth_config_id
2396
+ """,
2397
+ (s_id[0],),
2398
+ ).fetchall():
2399
+ truth_config = json.loads(truth_config_record[0])
2400
+ truth_configs[truth_config["name"]] = truth_config["function"]
2401
+ return truth_configs
2402
+
2403
+
2404
+ def _source_truth_configs(db: partial, s_id: int, use_cache: bool = True) -> TruthConfigs:
2224
2405
  if use_cache:
2225
- return __target_truth_configs(db, t_id)
2226
- return __target_truth_configs.__wrapped__(db, t_id)
2406
+ return __source_truth_configs(db, s_id)
2407
+ return __source_truth_configs.__wrapped__(db, s_id)
2227
2408
 
2228
2409
 
2229
2410
  @lru_cache
2230
- def __target_truth_configs(db: partial, t_id: int) -> TruthConfigs:
2411
+ def __source_truth_configs(db: partial, s_id: int) -> TruthConfigs:
2231
2412
  import json
2232
2413
 
2233
- from .datatypes import TruthConfig
2414
+ from ..datatypes import TruthConfig
2234
2415
 
2235
2416
  truth_configs: TruthConfigs = {}
2236
2417
  with db() as c:
2237
2418
  for truth_config_record in c.execute(
2238
2419
  """
2239
2420
  SELECT truth_config.config
2240
- FROM truth_config, target_file_truth_config
2241
- WHERE ? = target_file_truth_config.target_file_id AND truth_config.id = target_file_truth_config.truth_config_id
2421
+ FROM truth_config, source_file_truth_config
2422
+ WHERE ? = source_file_truth_config.source_file_id AND truth_config.id = source_file_truth_config.truth_config_id
2242
2423
  """,
2243
- (t_id,),
2424
+ (s_id,),
2244
2425
  ).fetchall():
2245
2426
  truth_config = json.loads(truth_config_record[0])
2246
2427
  truth_configs[truth_config["name"]] = TruthConfig(