sonusai 0.20.2__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff compares the contents of two publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
Files changed (97)
  1. sonusai/__init__.py +16 -3
  2. sonusai/audiofe.py +240 -76
  3. sonusai/calc_metric_spenh.py +71 -73
  4. sonusai/config/__init__.py +3 -0
  5. sonusai/config/config.py +61 -0
  6. sonusai/config/config.yml +20 -0
  7. sonusai/config/constants.py +8 -0
  8. sonusai/constants.py +11 -0
  9. sonusai/data/genmixdb.yml +21 -36
  10. sonusai/{mixture/datatypes.py → datatypes.py} +91 -130
  11. sonusai/deprecated/plot.py +4 -5
  12. sonusai/doc/doc.py +4 -4
  13. sonusai/doc.py +11 -4
  14. sonusai/genft.py +43 -45
  15. sonusai/genmetrics.py +23 -19
  16. sonusai/genmix.py +54 -82
  17. sonusai/genmixdb.py +88 -264
  18. sonusai/ir_metric.py +30 -34
  19. sonusai/lsdb.py +41 -48
  20. sonusai/main.py +15 -22
  21. sonusai/metrics/calc_audio_stats.py +4 -17
  22. sonusai/metrics/calc_class_weights.py +4 -4
  23. sonusai/metrics/calc_optimal_thresholds.py +8 -5
  24. sonusai/metrics/calc_pesq.py +2 -2
  25. sonusai/metrics/calc_segsnr_f.py +4 -4
  26. sonusai/metrics/calc_speech.py +25 -13
  27. sonusai/metrics/class_summary.py +7 -7
  28. sonusai/metrics/confusion_matrix_summary.py +5 -5
  29. sonusai/metrics/one_hot.py +4 -4
  30. sonusai/metrics/snr_summary.py +7 -7
  31. sonusai/metrics_summary.py +38 -45
  32. sonusai/mixture/__init__.py +5 -104
  33. sonusai/mixture/audio.py +10 -39
  34. sonusai/mixture/class_balancing.py +103 -0
  35. sonusai/mixture/config.py +251 -271
  36. sonusai/mixture/constants.py +35 -39
  37. sonusai/mixture/data_io.py +25 -36
  38. sonusai/mixture/db_datatypes.py +58 -22
  39. sonusai/mixture/effects.py +386 -0
  40. sonusai/mixture/feature.py +7 -11
  41. sonusai/mixture/generation.py +484 -611
  42. sonusai/mixture/helpers.py +82 -184
  43. sonusai/mixture/ir_delay.py +3 -4
  44. sonusai/mixture/ir_effects.py +77 -0
  45. sonusai/mixture/log_duration_and_sizes.py +6 -12
  46. sonusai/mixture/mixdb.py +931 -669
  47. sonusai/mixture/pad_audio.py +35 -0
  48. sonusai/mixture/resample.py +7 -0
  49. sonusai/mixture/sox_effects.py +195 -0
  50. sonusai/mixture/sox_help.py +650 -0
  51. sonusai/mixture/spectral_mask.py +2 -2
  52. sonusai/mixture/truth.py +17 -15
  53. sonusai/mixture/truth_functions/crm.py +12 -12
  54. sonusai/mixture/truth_functions/energy.py +22 -22
  55. sonusai/mixture/truth_functions/file.py +5 -5
  56. sonusai/mixture/truth_functions/metadata.py +4 -4
  57. sonusai/mixture/truth_functions/metrics.py +4 -4
  58. sonusai/mixture/truth_functions/phoneme.py +3 -3
  59. sonusai/mixture/truth_functions/sed.py +11 -13
  60. sonusai/mixture/truth_functions/target.py +10 -10
  61. sonusai/mkwav.py +26 -29
  62. sonusai/onnx_predict.py +240 -88
  63. sonusai/queries/__init__.py +2 -2
  64. sonusai/queries/queries.py +38 -34
  65. sonusai/speech/librispeech.py +1 -1
  66. sonusai/speech/mcgill.py +1 -1
  67. sonusai/speech/timit.py +2 -2
  68. sonusai/summarize_metric_spenh.py +10 -17
  69. sonusai/utils/__init__.py +7 -1
  70. sonusai/utils/asl_p56.py +2 -2
  71. sonusai/utils/asr.py +2 -2
  72. sonusai/utils/asr_functions/aaware_whisper.py +4 -5
  73. sonusai/utils/choice.py +31 -0
  74. sonusai/utils/compress.py +1 -1
  75. sonusai/utils/dataclass_from_dict.py +19 -1
  76. sonusai/utils/energy_f.py +3 -3
  77. sonusai/utils/evaluate_random_rule.py +15 -0
  78. sonusai/utils/keyboard_interrupt.py +12 -0
  79. sonusai/utils/onnx_utils.py +3 -17
  80. sonusai/utils/print_mixture_details.py +21 -19
  81. sonusai/utils/{temp_seed.py → rand.py} +3 -3
  82. sonusai/utils/read_predict_data.py +2 -2
  83. sonusai/utils/reshape.py +3 -3
  84. sonusai/utils/stratified_shuffle_split.py +3 -3
  85. sonusai/{mixture → utils}/tokenized_shell_vars.py +1 -1
  86. sonusai/utils/write_audio.py +2 -2
  87. sonusai/vars.py +11 -4
  88. {sonusai-0.20.2.dist-info → sonusai-1.0.1.dist-info}/METADATA +4 -2
  89. sonusai-1.0.1.dist-info/RECORD +138 -0
  90. sonusai/mixture/augmentation.py +0 -444
  91. sonusai/mixture/class_count.py +0 -15
  92. sonusai/mixture/eq_rule_is_valid.py +0 -45
  93. sonusai/mixture/target_class_balancing.py +0 -107
  94. sonusai/mixture/targets.py +0 -175
  95. sonusai-0.20.2.dist-info/RECORD +0 -128
  96. {sonusai-0.20.2.dist-info → sonusai-1.0.1.dist-info}/WHEEL +0 -0
  97. {sonusai-0.20.2.dist-info → sonusai-1.0.1.dist-info}/entry_points.txt +0 -0
sonusai/mixture/mixdb.py CHANGED
@@ -6,36 +6,43 @@ from sqlite3 import Connection
6
6
  from sqlite3 import Cursor
7
7
  from typing import Any
8
8
 
9
- from .datatypes import ASRConfigs
10
- from .datatypes import AudioF
11
- from .datatypes import AudioT
12
- from .datatypes import ClassCount
13
- from .datatypes import Feature
14
- from .datatypes import FeatureGeneratorConfig
15
- from .datatypes import FeatureGeneratorInfo
16
- from .datatypes import GeneralizedIDs
17
- from .datatypes import ImpulseResponseFile
18
- from .datatypes import MetricDoc
19
- from .datatypes import MetricDocs
20
- from .datatypes import Mixture
21
- from .datatypes import NoiseFile
22
- from .datatypes import Segsnr
23
- from .datatypes import SpectralMask
24
- from .datatypes import SpeechMetadata
25
- from .datatypes import TargetFile
26
- from .datatypes import TransformConfig
27
- from .datatypes import TruthConfigs
28
- from .datatypes import TruthDict
29
- from .datatypes import UniversalSNR
9
+ from ..datatypes import ASRConfigs
10
+ from ..datatypes import AudioF
11
+ from ..datatypes import AudioT
12
+ from ..datatypes import ClassCount
13
+ from ..datatypes import Feature
14
+ from ..datatypes import FeatureGeneratorConfig
15
+ from ..datatypes import FeatureGeneratorInfo
16
+ from ..datatypes import GeneralizedIDs
17
+ from ..datatypes import ImpulseResponseFile
18
+ from ..datatypes import MetricDoc
19
+ from ..datatypes import MetricDocs
20
+ from ..datatypes import Mixture
21
+ from ..datatypes import Segsnr
22
+ from ..datatypes import SourceFile
23
+ from ..datatypes import Sources
24
+ from ..datatypes import SourcesAudioF
25
+ from ..datatypes import SourcesAudioT
26
+ from ..datatypes import SpectralMask
27
+ from ..datatypes import SpeechMetadata
28
+ from ..datatypes import TransformConfig
29
+ from ..datatypes import TruthConfigs
30
+ from ..datatypes import TruthDict
31
+ from ..datatypes import TruthsConfigs
32
+ from ..datatypes import TruthsDict
33
+ from ..datatypes import UniversalSNR
30
34
 
31
35
 
32
36
  def db_file(location: str, test: bool = False) -> str:
33
37
  from os.path import join
34
38
 
39
+ from .constants import MIXDB_NAME
40
+ from .constants import TEST_MIXDB_NAME
41
+
35
42
  if test:
36
- name = "mixdb_test.db"
43
+ name = TEST_MIXDB_NAME
37
44
  else:
38
- name = "mixdb.db"
45
+ name = MIXDB_NAME
39
46
 
40
47
  return join(location, name)
41
48
 
@@ -113,7 +120,7 @@ class MixtureDatabase:
113
120
 
114
121
  @cached_property
115
122
  def json(self) -> str:
116
- from .datatypes import MixtureDatabaseConfig
123
+ from ..datatypes import MixtureDatabaseConfig
117
124
 
118
125
  config = MixtureDatabaseConfig(
119
126
  asr_configs=self.asr_configs,
@@ -121,13 +128,11 @@ class MixtureDatabase:
121
128
  class_labels=self.class_labels,
122
129
  class_weights_threshold=self.class_weights_thresholds,
123
130
  feature=self.feature,
124
- impulse_response_files=self.impulse_response_files,
125
- mixtures=self.mixtures(),
126
- noise_mix_mode=self.noise_mix_mode,
127
- noise_files=self.noise_files,
131
+ ir_files=self.ir_files,
132
+ mixtures=self.mixtures,
128
133
  num_classes=self.num_classes,
129
134
  spectral_masks=self.spectral_masks,
130
- target_files=self.target_files,
135
+ source_files=self.source_files,
131
136
  )
132
137
  return config.to_json(indent=2)
133
138
 
@@ -153,12 +158,15 @@ class MixtureDatabase:
153
158
  return get_feature_generator_info(self.fg_config)
154
159
 
155
160
  @cached_property
156
- def truth_parameters(self) -> dict[str, int | None]:
161
+ def truth_parameters(self) -> dict[str, dict[str, int | None]]:
157
162
  with self.db() as c:
158
- rows = c.execute("SELECT * FROM truth_parameters").fetchall()
159
- truth_parameters: dict[str, int | None] = {}
163
+ rows = c.execute("SELECT category, name, parameters FROM truth_parameters").fetchall()
164
+ truth_parameters: dict[str, dict[str, int | None]] = {}
160
165
  for row in rows:
161
- truth_parameters[row[1]] = row[2]
166
+ category, name, parameters = row
167
+ if category not in truth_parameters:
168
+ truth_parameters[category] = {}
169
+ truth_parameters[category][name] = parameters
162
170
  return truth_parameters
163
171
 
164
172
  @cached_property
@@ -166,11 +174,6 @@ class MixtureDatabase:
166
174
  with self.db() as c:
167
175
  return int(c.execute("SELECT top.num_classes FROM top").fetchone()[0])
168
176
 
169
- @cached_property
170
- def noise_mix_mode(self) -> str:
171
- with self.db() as c:
172
- return str(c.execute("SELECT top.noise_mix_mode FROM top").fetchone()[0])
173
-
174
177
  @cached_property
175
178
  def asr_configs(self) -> ASRConfigs:
176
179
  import json
@@ -223,36 +226,36 @@ class MixtureDatabase:
223
226
  "mxssnrdbf_std",
224
227
  "Per-bin segmental standard deviation of the dB frame values over all frames (using feature transform)",
225
228
  ),
226
- MetricDoc("Mixture Metrics", "mxpesq", "PESQ of mixture versus true targets"),
229
+ MetricDoc("Mixture Metrics", "mxpesq", "PESQ of mixture versus true sources"),
227
230
  MetricDoc(
228
231
  "Mixture Metrics",
229
232
  "mxwsdr",
230
- "Weighted signal distortion ratio of mixture versus true targets",
233
+ "Weighted signal distortion ratio of mixture versus true sources",
231
234
  ),
232
235
  MetricDoc(
233
236
  "Mixture Metrics",
234
237
  "mxpd",
235
- "Phase distance between mixture and true targets",
238
+ "Phase distance between mixture and true sources",
236
239
  ),
237
240
  MetricDoc(
238
241
  "Mixture Metrics",
239
242
  "mxstoi",
240
- "Short term objective intelligibility of mixture versus true targets",
243
+ "Short term objective intelligibility of mixture versus true sources",
241
244
  ),
242
245
  MetricDoc(
243
246
  "Mixture Metrics",
244
247
  "mxcsig",
245
- "Predicted rating of speech distortion of mixture versus true targets",
248
+ "Predicted rating of speech distortion of mixture versus true sources",
246
249
  ),
247
250
  MetricDoc(
248
251
  "Mixture Metrics",
249
252
  "mxcbak",
250
- "Predicted rating of background distortion of mixture versus true targets",
253
+ "Predicted rating of background distortion of mixture versus true sources",
251
254
  ),
252
255
  MetricDoc(
253
256
  "Mixture Metrics",
254
257
  "mxcovl",
255
- "Predicted rating of overall quality of mixture versus true targets",
258
+ "Predicted rating of overall quality of mixture versus true sources",
256
259
  ),
257
260
  MetricDoc("Mixture Metrics", "ssnr", "Segmental SNR"),
258
261
  MetricDoc("Mixture Metrics", "mxdco", "Mixture DC offset"),
@@ -265,26 +268,26 @@ class MixtureDatabase:
265
268
  MetricDoc("Mixture Metrics", "mxcr", "Mixture Crest factor"),
266
269
  MetricDoc("Mixture Metrics", "mxfl", "Mixture Flat factor"),
267
270
  MetricDoc("Mixture Metrics", "mxpkc", "Mixture Pk count"),
268
- MetricDoc("Mixture Metrics", "mxtdco", "Mixture target DC offset"),
269
- MetricDoc("Mixture Metrics", "mxtmin", "Mixture target min level"),
270
- MetricDoc("Mixture Metrics", "mxtmax", "Mixture target max levl"),
271
- MetricDoc("Mixture Metrics", "mxtpkdb", "Mixture target Pk lev dB"),
272
- MetricDoc("Mixture Metrics", "mxtlrms", "Mixture target RMS lev dB"),
273
- MetricDoc("Mixture Metrics", "mxtpkr", "Mixture target RMS Pk dB"),
274
- MetricDoc("Mixture Metrics", "mxttr", "Mixture target RMS Tr dB"),
275
- MetricDoc("Mixture Metrics", "mxtcr", "Mixture target Crest factor"),
276
- MetricDoc("Mixture Metrics", "mxtfl", "Mixture target Flat factor"),
277
- MetricDoc("Mixture Metrics", "mxtpkc", "Mixture target Pk count"),
278
- MetricDoc("Targets Metrics", "tdco", "Targets DC offset"),
279
- MetricDoc("Targets Metrics", "tmin", "Targets min level"),
280
- MetricDoc("Targets Metrics", "tmax", "Targets max levl"),
281
- MetricDoc("Targets Metrics", "tpkdb", "Targets Pk lev dB"),
282
- MetricDoc("Targets Metrics", "tlrms", "Targets RMS lev dB"),
283
- MetricDoc("Targets Metrics", "tpkr", "Targets RMS Pk dB"),
284
- MetricDoc("Targets Metrics", "ttr", "Targets RMS Tr dB"),
285
- MetricDoc("Targets Metrics", "tcr", "Targets Crest factor"),
286
- MetricDoc("Targets Metrics", "tfl", "Targets Flat factor"),
287
- MetricDoc("Targets Metrics", "tpkc", "Targets Pk count"),
271
+ MetricDoc("Mixture Metrics", "mxtdco", "Mixture source DC offset"),
272
+ MetricDoc("Mixture Metrics", "mxtmin", "Mixture source min level"),
273
+ MetricDoc("Mixture Metrics", "mxtmax", "Mixture source max levl"),
274
+ MetricDoc("Mixture Metrics", "mxtpkdb", "Mixture source Pk lev dB"),
275
+ MetricDoc("Mixture Metrics", "mxtlrms", "Mixture source RMS lev dB"),
276
+ MetricDoc("Mixture Metrics", "mxtpkr", "Mixture source RMS Pk dB"),
277
+ MetricDoc("Mixture Metrics", "mxttr", "Mixture source RMS Tr dB"),
278
+ MetricDoc("Mixture Metrics", "mxtcr", "Mixture source Crest factor"),
279
+ MetricDoc("Mixture Metrics", "mxtfl", "Mixture source Flat factor"),
280
+ MetricDoc("Mixture Metrics", "mxtpkc", "Mixture source Pk count"),
281
+ MetricDoc("Sources Metrics", "sdco", "Sources DC offset"),
282
+ MetricDoc("Sources Metrics", "smin", "Sources min level"),
283
+ MetricDoc("Sources Metrics", "smax", "Sources max levl"),
284
+ MetricDoc("Sources Metrics", "spkdb", "Sources Pk lev dB"),
285
+ MetricDoc("Sources Metrics", "slrms", "Sources RMS lev dB"),
286
+ MetricDoc("Sources Metrics", "spkr", "Sources RMS Pk dB"),
287
+ MetricDoc("Sources Metrics", "str", "Sources RMS Tr dB"),
288
+ MetricDoc("Sources Metrics", "scr", "Sources Crest factor"),
289
+ MetricDoc("Sources Metrics", "sfl", "Sources Flat factor"),
290
+ MetricDoc("Sources Metrics", "spkc", "Sources Pk count"),
288
291
  MetricDoc("Noise Metrics", "ndco", "Noise DC offset"),
289
292
  MetricDoc("Noise Metrics", "nmin", "Noise min level"),
290
293
  MetricDoc("Noise Metrics", "nmax", "Noise max levl"),
@@ -320,16 +323,16 @@ class MixtureDatabase:
320
323
  for name in self.asr_configs:
321
324
  metrics.append(
322
325
  MetricDoc(
323
- "Target Metrics",
324
- f"mxtasr.{name}",
325
- f"Mixture Target ASR text using {name} ASR as defined in mixdb asr_configs parameter",
326
+ "Source Metrics",
327
+ f"mxsasr.{name}",
328
+ f"Mixture Source ASR text using {name} ASR as defined in mixdb asr_configs parameter",
326
329
  )
327
330
  )
328
331
  metrics.append(
329
332
  MetricDoc(
330
- "Target Metrics",
331
- f"tasr.{name}",
332
- f"Targets ASR text using {name} ASR as defined in mixdb asr_configs parameter",
333
+ "Source Metrics",
334
+ f"sasr.{name}",
335
+ f"Sources ASR text using {name} ASR as defined in mixdb asr_configs parameter",
333
336
  )
334
337
  )
335
338
  metrics.append(
@@ -341,16 +344,16 @@ class MixtureDatabase:
341
344
  )
342
345
  metrics.append(
343
346
  MetricDoc(
344
- "Target Metrics",
347
+ "Source Metrics",
345
348
  f"basewer.{name}",
346
- f"Word error rate of tasr.{name} vs. speech text metadata for the target",
349
+ f"Word error rate of sasr.{name} vs. speech text metadata for the source",
347
350
  )
348
351
  )
349
352
  metrics.append(
350
353
  MetricDoc(
351
354
  "Mixture Metrics",
352
355
  f"mxwer.{name}",
353
- f"Word error rate of mxasr.{name} vs. tasr.{name}",
356
+ f"Word error rate of mxasr.{name} vs. sasr.{name}",
354
357
  )
355
358
  )
356
359
 
@@ -396,7 +399,7 @@ class MixtureDatabase:
396
399
 
397
400
  @cached_property
398
401
  def transform_frame_ms(self) -> float:
399
- from .constants import SAMPLE_RATE
402
+ from ..constants import SAMPLE_RATE
400
403
 
401
404
  return float(self.ft_config.overlap) / float(SAMPLE_RATE / 1000)
402
405
 
@@ -417,12 +420,7 @@ class MixtureDatabase:
417
420
  return self.ft_config.overlap * self.fg_decimation * self.fg_step
418
421
 
419
422
  def total_samples(self, m_ids: GeneralizedIDs = "*") -> int:
420
- samples = 0
421
- for m_id in self.mixids_to_list(m_ids):
422
- s = self.mixture(m_id).samples
423
- if s is not None:
424
- samples += s
425
- return samples
423
+ return sum([self.mixture(m_id).samples for m_id in self.mixids_to_list(m_ids)])
426
424
 
427
425
  def total_transform_frames(self, m_ids: GeneralizedIDs = "*") -> int:
428
426
  return self.total_samples(m_ids) // self.ft_config.overlap
@@ -476,30 +474,18 @@ class MixtureDatabase:
476
474
  ).fetchall()
477
475
  ]
478
476
 
479
- @cached_property
480
- def truth_configs(self) -> TruthConfigs:
481
- """Get truth configs from db
482
-
483
- :return: Truth configs
484
- """
485
- import json
477
+ def category_truth_configs(self, category: str) -> dict[str, str]:
478
+ return _category_truth_configs(self.db, category, self.use_cache)
486
479
 
487
- from .datatypes import TruthConfig
480
+ def source_truth_configs(self, s_id: int) -> TruthConfigs:
481
+ return _source_truth_configs(self.db, s_id, self.use_cache)
488
482
 
489
- with self.db() as c:
490
- truth_configs: TruthConfigs = {}
491
- for truth_config_record in c.execute("SELECT truth_config.config FROM truth_config").fetchall():
492
- truth_config = json.loads(truth_config_record[0])
493
- if truth_config["name"] not in truth_configs:
494
- truth_configs[truth_config["name"]] = TruthConfig(
495
- function=truth_config["function"],
496
- stride_reduction=truth_config["stride_reduction"],
497
- config=truth_config["config"],
498
- )
499
- return truth_configs
500
-
501
- def target_truth_configs(self, t_id: int) -> TruthConfigs:
502
- return _target_truth_configs(self.db, t_id, self.use_cache)
483
+ def mixture_truth_configs(self, m_id: int) -> TruthsConfigs:
484
+ mixture = self.mixture(m_id)
485
+ return {
486
+ category: self.source_truth_configs(mixture.all_sources[category].file_id)
487
+ for category in mixture.all_sources
488
+ }
503
489
 
504
490
  @cached_property
505
491
  def random_snrs(self) -> list[float]:
@@ -509,10 +495,7 @@ class MixtureDatabase:
509
495
  """
510
496
  with self.db() as c:
511
497
  return list(
512
- {
513
- float(item[0])
514
- for item in c.execute("SELECT mixture.snr FROM mixture WHERE mixture.random_snr == 1").fetchall()
515
- }
498
+ {float(item[0]) for item in c.execute("SELECT snr FROM source WHERE snr_random == 1").fetchall()}
516
499
  )
517
500
 
518
501
  @cached_property
@@ -523,10 +506,7 @@ class MixtureDatabase:
523
506
  """
524
507
  with self.db() as c:
525
508
  return list(
526
- {
527
- float(item[0])
528
- for item in c.execute("SELECT mixture.snr FROM mixture WHERE mixture.random_snr == 0").fetchall()
529
- }
509
+ {float(item[0]) for item in c.execute("SELECT snr FROM source WHERE snr_random == 0").fetchall()}
530
510
  )
531
511
 
532
512
  @cached_property
@@ -570,199 +550,246 @@ class MixtureDatabase:
570
550
  return _spectral_mask(self.db, sm_id, self.use_cache)
571
551
 
572
552
  @cached_property
573
- def target_files(self) -> list[TargetFile]:
574
- """Get target files from db
553
+ def source_files(self) -> dict[str, list[SourceFile]]:
554
+ """Get source files from db
575
555
 
576
- :return: Target files
556
+ :return: Source files
577
557
  """
578
558
  import json
579
559
 
580
- from .datatypes import TruthConfig
581
- from .datatypes import TruthConfigs
582
- from .db_datatypes import TargetFileRecord
560
+ from ..datatypes import TruthConfig
561
+ from ..datatypes import TruthConfigs
562
+ from .db_datatypes import SourceFileRecord
583
563
 
584
564
  with self.db() as c:
585
- target_files: list[TargetFile] = []
586
- target_file_records = [
587
- TargetFileRecord(*result) for result in c.execute("SELECT * FROM target_file").fetchall()
588
- ]
589
- for target_file_record in target_file_records:
590
- truth_configs: TruthConfigs = {}
591
- for truth_config_records in c.execute(
592
- """
593
- SELECT truth_config.config
594
- FROM truth_config, target_file_truth_config
595
- WHERE ? = target_file_truth_config.target_file_id
596
- AND truth_config.id = target_file_truth_config.truth_config_id
597
- """,
598
- (target_file_record.id,),
599
- ).fetchall():
600
- truth_config = json.loads(truth_config_records[0])
601
- truth_configs[truth_config["name"]] = TruthConfig(
602
- function=truth_config["function"],
603
- stride_reduction=truth_config["stride_reduction"],
604
- config=truth_config["config"],
605
- )
606
- target_files.append(
607
- TargetFile(
608
- name=target_file_record.name,
609
- samples=target_file_record.samples,
610
- class_indices=json.loads(target_file_record.class_indices),
611
- level_type=target_file_record.level_type,
612
- truth_configs=truth_configs,
613
- speaker_id=target_file_record.speaker_id,
565
+ source_files: dict[str, list[SourceFile]] = {}
566
+ categories = c.execute("SELECT DISTINCT category FROM source_file").fetchall()
567
+ for category in categories:
568
+ source_files[category[0]] = []
569
+ source_file_records = [
570
+ SourceFileRecord(*result)
571
+ for result in c.execute(
572
+ """
573
+ SELECT *
574
+ FROM source_file
575
+ WHERE ? = source_file.category
576
+ """,
577
+ (category[0],),
578
+ ).fetchall()
579
+ ]
580
+ for source_file_record in source_file_records:
581
+ truth_configs: TruthConfigs = {}
582
+ for truth_config_records in c.execute(
583
+ """
584
+ SELECT truth_config.config
585
+ FROM truth_config, source_file_truth_config
586
+ WHERE ? = source_file_truth_config.source_file_id
587
+ AND truth_config.id = source_file_truth_config.truth_config_id
588
+ """,
589
+ (source_file_record.id,),
590
+ ).fetchall():
591
+ truth_config = json.loads(truth_config_records[0])
592
+ truth_configs[truth_config["name"]] = TruthConfig(
593
+ function=truth_config["function"],
594
+ stride_reduction=truth_config["stride_reduction"],
595
+ config=truth_config["config"],
596
+ )
597
+ source_files[source_file_record.category].append(
598
+ SourceFile(
599
+ id=source_file_record.id,
600
+ category=source_file_record.category,
601
+ name=source_file_record.name,
602
+ samples=source_file_record.samples,
603
+ class_indices=json.loads(source_file_record.class_indices),
604
+ level_type=source_file_record.level_type,
605
+ truth_configs=truth_configs,
606
+ speaker_id=source_file_record.speaker_id,
607
+ )
614
608
  )
615
- )
616
- return target_files
609
+ return source_files
617
610
 
618
611
  @cached_property
619
- def target_file_ids(self) -> list[int]:
620
- """Get target file IDs from db
612
+ def source_file_ids(self) -> dict[str, list[int]]:
613
+ """Get source file IDs from db
621
614
 
622
- :return: List of target file IDs
615
+ :return: Dictionary of list of source file IDs
623
616
  """
624
617
  with self.db() as c:
625
- return [int(item[0]) for item in c.execute("SELECT target_file.id FROM target_file").fetchall()]
618
+ source_file_ids: dict[str, list[int]] = {}
619
+ categories = c.execute("SELECT DISTINCT category FROM source_file").fetchall()
620
+ for category in categories:
621
+ # items = c.execute(
622
+ # """
623
+ # SELECT source_file.id
624
+ # FROM source_file
625
+ # WHERE ? = source_file.category
626
+ # """,
627
+ # (category[0],),
628
+ # ).fetchall()
629
+ # source_file_ids[category[0]] = [int(item[0]) for item in items]
630
+ source_file_ids[category[0]] = [
631
+ int(item[0])
632
+ for item in c.execute(
633
+ """
634
+ SELECT source_file.id
635
+ FROM source_file
636
+ WHERE ? = source_file.category
637
+ """,
638
+ (category[0],),
639
+ ).fetchall()
640
+ ]
641
+ return source_file_ids
626
642
 
627
- def target_file(self, t_id: int) -> TargetFile:
628
- """Get target file with ID from db
643
+ def source_file(self, s_id: int) -> SourceFile:
644
+ """Get source file with ID from db
629
645
 
630
- :param t_id: Target file ID
631
- :return: Target file
646
+ :param s_id: Source file ID
647
+ :return: Source file
632
648
  """
633
- return _target_file(self.db, t_id, self.use_cache)
649
+ return _source_file(self.db, s_id, self.use_cache)
634
650
 
635
- @cached_property
636
- def num_target_files(self) -> int:
637
- """Get number of target files from db
651
+ def num_source_files(self, category: str) -> int:
652
+ """Get number of source files from category from db
638
653
 
639
- :return: Number of target files
654
+ :param category: Source category
655
+ :return: Number of source files
640
656
  """
641
- with self.db() as c:
642
- return int(c.execute("SELECT count(target_file.id) FROM target_file").fetchone()[0])
657
+ return _num_source_files(self.db, category, self.use_cache)
643
658
 
644
659
  @cached_property
645
- def noise_files(self) -> list[NoiseFile]:
646
- """Get noise files from db
660
+ def ir_files(self) -> list[ImpulseResponseFile]:
661
+ """Get impulse response files from db
647
662
 
648
- :return: Noise files
663
+ :return: Impulse response files
649
664
  """
650
- with self.db() as c:
651
- return [
652
- NoiseFile(name=noise[0], samples=noise[1])
653
- for noise in c.execute("SELECT noise_file.name, samples FROM noise_file").fetchall()
654
- ]
655
-
656
- @cached_property
657
- def noise_file_ids(self) -> list[int]:
658
- """Get noise file IDs from db
665
+ from .db_datatypes import ImpulseResponseFileRecord
659
666
 
660
- :return: List of noise file IDs
661
- """
662
667
  with self.db() as c:
663
- return [int(item[0]) for item in c.execute("SELECT noise_file.id FROM noise_file").fetchall()]
668
+ files: list[ImpulseResponseFile] = []
669
+ entries = c.execute("SELECT * FROM ir_file").fetchall()
670
+ for entry in entries:
671
+ file = ImpulseResponseFileRecord(*entry)
672
+
673
+ tags = [
674
+ tag[0]
675
+ for tag in c.execute(
676
+ """
677
+ SELECT ir_tag.tag
678
+ FROM ir_tag, ir_file_ir_tag
679
+ WHERE ? = ir_file_ir_tag.file_id
680
+ AND ir_tag.id = ir_file_ir_tag.tag_id
681
+ """,
682
+ (file.id,),
683
+ ).fetchall()
684
+ ]
664
685
 
665
- def noise_file(self, n_id: int) -> NoiseFile:
666
- """Get noise file with ID from db
686
+ files.append(
687
+ ImpulseResponseFile(
688
+ delay=file.delay,
689
+ name=file.name,
690
+ tags=tags,
691
+ )
692
+ )
667
693
 
668
- :param n_id: Noise file ID
669
- :return: Noise file
670
- """
671
- return _noise_file(self.db, n_id, self.use_cache)
694
+ return files
672
695
 
673
696
  @cached_property
674
- def num_noise_files(self) -> int:
675
- """Get number of noise files from db
697
+ def ir_file_ids(self) -> list[int]:
698
+ """Get impulse response file IDs from db
676
699
 
677
- :return: Number of noise files
700
+ :return: List of impulse response file IDs
678
701
  """
679
702
  with self.db() as c:
680
- return int(c.execute("SELECT count(noise_file.id) FROM noise_file").fetchone()[0])
703
+ return [int(item[0]) for item in c.execute("SELECT ir_file.id FROM ir_file").fetchall()]
681
704
 
682
- @cached_property
683
- def impulse_response_files(self) -> list[ImpulseResponseFile]:
684
- """Get impulse response files from db
705
+ def ir_file_ids_for_tag(self, tag: str) -> list[int]:
706
+ """Get impulse response file IDs for given tag from db
685
707
 
686
- :return: Impulse response files
708
+ :return: List of impulse response file IDs for given tag
687
709
  """
688
- import json
689
-
690
- from .datatypes import ImpulseResponseFile
691
-
692
710
  with self.db() as c:
693
- return [
694
- ImpulseResponseFile(impulse_response[1], json.loads(impulse_response[2]), impulse_response[3])
695
- for impulse_response in c.execute(
696
- "SELECT impulse_response_file.* FROM impulse_response_file"
697
- ).fetchall()
698
- ]
699
-
700
- @cached_property
701
- def impulse_response_file_ids(self) -> list[int]:
702
- """Get impulse response file IDs from db
711
+ tag_id = c.execute("SELECT ir_tag.id FROM ir_tag WHERE ? = ir_tag.tag", (tag,)).fetchone()
712
+ if not tag_id:
713
+ return []
703
714
 
704
- :return: List of impulse response file IDs
705
- """
706
- with self.db() as c:
707
715
  return [
708
- int(item[0])
709
- for item in c.execute("SELECT impulse_response_file.id FROM impulse_response_file").fetchall()
716
+ int(item[0] - 1)
717
+ for item in c.execute(
718
+ """
719
+ SELECT ir_file_ir_tag.file_id
720
+ FROM ir_file_ir_tag
721
+ WHERE ? = ir_file_ir_tag.tag_id
722
+ """,
723
+ (tag_id[0],),
724
+ ).fetchall()
710
725
  ]
711
726
 
712
- def impulse_response_file(self, ir_id: int | None) -> str | None:
727
+ def ir_file(self, ir_id: int) -> str:
713
728
  """Get impulse response file name with ID from db
714
729
 
715
730
  :param ir_id: Impulse response file ID
716
731
  :return: Impulse response file name
717
732
  """
718
- if ir_id is None:
719
- return None
720
- return _impulse_response_file(self.db, ir_id, self.use_cache)
733
+ return _ir_file(self.db, ir_id, self.use_cache)
721
734
 
722
- def impulse_response_delay(self, ir_id: int | None) -> int | None:
735
+ def ir_delay(self, ir_id: int) -> int:
723
736
  """Get impulse response delay with ID from db
724
737
 
725
738
  :param ir_id: Impulse response file ID
726
739
  :return: Impulse response delay
727
740
  """
728
- if ir_id is None:
729
- return None
730
- return _impulse_response_delay(self.db, ir_id, self.use_cache)
741
+ return _ir_delay(self.db, ir_id, self.use_cache)
731
742
 
732
743
  @cached_property
733
- def num_impulse_response_files(self) -> int:
744
+ def num_ir_files(self) -> int:
734
745
  """Get number of impulse response files from db
735
746
 
736
747
  :return: Number of impulse response files
737
748
  """
738
749
  with self.db() as c:
739
- return int(c.execute("SELECT count(impulse_response_file.id) FROM impulse_response_file").fetchone()[0])
750
+ return int(c.execute("SELECT count(ir_file.id) FROM ir_file").fetchone()[0])
751
+
752
+ @cached_property
753
+ def ir_tags(self) -> list[str]:
754
+ """Get tags of impulse response files from db
755
+
756
+ :return: Tags of impulse response files
757
+ """
758
+ with self.db() as c:
759
+ return [tag[0] for tag in c.execute("SELECT ir_tag.tag FROM ir_tag").fetchall()]
740
760
 
761
+ @property
741
762
  def mixtures(self) -> list[Mixture]:
742
763
  """Get mixtures from db
743
764
 
744
765
  :return: Mixtures
745
766
  """
746
767
  from .db_datatypes import MixtureRecord
747
- from .db_datatypes import TargetRecord
768
+ from .db_datatypes import SourceRecord
748
769
  from .helpers import to_mixture
749
- from .helpers import to_target
770
+ from .helpers import to_source
750
771
 
751
772
  with self.db() as c:
752
773
  mixtures: list[Mixture] = []
753
774
  for mixture in [MixtureRecord(*record) for record in c.execute("SELECT * FROM mixture").fetchall()]:
754
- targets = [
755
- to_target(TargetRecord(*target))
756
- for target in c.execute(
775
+ sources_list = [
776
+ to_source(SourceRecord(*source))
777
+ for source in c.execute(
757
778
  """
758
- SELECT target.*
759
- FROM target, mixture_target
760
- WHERE ? = mixture_target.mixture_id AND target.id = mixture_target.target_id
779
+ SELECT source.*
780
+ FROM source, mixture_source
781
+ WHERE ? = mixture_source.mixture_id AND source.id = mixture_source.source_id
761
782
  """,
762
783
  (mixture.id,),
763
784
  ).fetchall()
764
785
  ]
765
- mixtures.append(to_mixture(mixture, targets))
786
+
787
+ sources: Sources = {}
788
+ for source in sources_list:
789
+ sources[self.source_file(source.file_id).category] = source
790
+
791
+ mixtures.append(to_mixture(mixture, sources))
792
+
766
793
  return mixtures
767
794
 
768
795
  @cached_property
@@ -806,229 +833,340 @@ class MixtureDatabase:
806
833
  with self.db() as c:
807
834
  return int(c.execute("SELECT count(mixture.id) FROM mixture").fetchone()[0])
808
835
 
809
- def read_mixture_data(self, m_id: int, items: list[str] | str) -> Any:
836
+ def read_mixture_data(self, m_id: int, items: list[str] | str) -> dict[str, Any]:
810
837
  """Read mixture data
811
838
 
812
839
  :param m_id: Zero-based mixture ID
813
840
  :param items: String(s) of dataset(s) to retrieve
814
- :return: Data (or tuple of data)
841
+ :return: Dictionary of name: data
815
842
  """
816
- from sonusai.mixture import read_cached_data
843
+ from .data_io import read_cached_data
817
844
 
818
845
  return read_cached_data(self.location, "mixture", self.mixture(m_id).name, items)
819
846
 
820
- def read_target_audio(self, t_id: int) -> AudioT:
821
- """Read target audio
847
+ def read_source_audio(self, s_id: int) -> AudioT:
848
+ """Read source audio
822
849
 
823
- :param t_id: Target ID
824
- :return: Target audio
850
+ :param s_id: Source ID
851
+ :return: Source audio
825
852
  """
826
853
  from .audio import read_audio
827
854
 
828
- return read_audio(self.target_file(t_id).name, self.use_cache)
829
-
830
- def augmented_noise_audio(self, mixture: Mixture) -> AudioT:
831
- """Get augmented noise audio
832
-
833
- :param mixture: Mixture
834
- :return: Augmented noise audio
835
- """
836
- from .audio import read_audio
837
- from .augmentation import apply_augmentation
838
-
839
- noise = self.noise_file(mixture.noise.file_id)
840
- audio = read_audio(noise.name, self.use_cache)
841
- audio = apply_augmentation(self, audio, mixture.noise.augmentation.pre)
842
-
843
- return audio
855
+ return read_audio(self.source_file(s_id).name, self.use_cache)
844
856
 
845
857
  def mixture_class_indices(self, m_id: int) -> list[int]:
846
858
  class_indices: list[int] = []
847
- for t_id in self.mixture(m_id).target_ids:
848
- class_indices.extend(self.target_file(t_id).class_indices)
859
+ for s_id in self.mixture(m_id).source_ids.values():
860
+ class_indices.extend(self.source_file(s_id).class_indices)
849
861
  return sorted(set(class_indices))
850
862
 
851
- def mixture_targets(self, m_id: int, force: bool = False) -> list[AudioT]:
852
- """Get the list of augmented target audio data (one per target in the mixup) for the given mixture ID
863
+ def mixture_sources(self, m_id: int, force: bool = False, cache: bool = False) -> SourcesAudioT:
864
+ """Get the pre-truth source audio data (one per source in the mixture) for the given mixture ID
853
865
 
854
866
  :param m_id: Zero-based mixture ID
855
867
  :param force: Force computing data from original sources regardless of whether cached data exists
856
- :return: List of augmented target audio data (one per target in the mixup)
868
+ :param cache: Cache result
869
+ :return: Dictionary of pre-truth source audio data (one per source in the mixture)
857
870
  """
858
- from .augmentation import apply_augmentation
859
- from .augmentation import apply_gain
860
- from .augmentation import pad_audio_to_length
871
+ from .data_io import write_cached_data
872
+ from .effects import apply_effects
873
+ from .effects import conform_audio_to_length
861
874
 
862
875
  if not force:
863
- targets_audio = self.read_mixture_data(m_id, "targets")
864
- if targets_audio is not None:
865
- return list(targets_audio)
876
+ sources = self.read_mixture_data(m_id, "sources")["sources"]
877
+ if sources is not None:
878
+ return sources
866
879
 
867
880
  mixture = self.mixture(m_id)
868
881
  if mixture is None:
869
882
  raise ValueError(f"Could not find mixture for m_id: {m_id}")
870
883
 
871
- targets_audio = []
872
- for target in mixture.targets:
873
- target_audio = self.read_target_audio(target.file_id)
874
- target_audio = apply_augmentation(
875
- mixdb=self,
876
- audio=target_audio,
877
- augmentation=target.augmentation.pre,
878
- frame_length=self.feature_step_samples,
884
+ sources = {}
885
+ for category, source in mixture.all_sources.items():
886
+ source = mixture.all_sources[category]
887
+ audio = self.read_source_audio(source.file_id)
888
+ audio = apply_effects(self, audio, source.effects, pre=True, post=False)
889
+ audio = conform_audio_to_length(audio, mixture.samples, source.repeat, source.start)
890
+ sources[category] = audio
891
+
892
+ if cache:
893
+ write_cached_data(
894
+ location=self.location,
895
+ name="mixture",
896
+ index=mixture.name,
897
+ items={"sources": sources},
879
898
  )
880
- target_audio = apply_gain(audio=target_audio, gain=mixture.target_snr_gain)
881
- target_audio = pad_audio_to_length(audio=target_audio, length=mixture.samples)
882
- targets_audio.append(target_audio)
883
899
 
884
- return targets_audio
900
+ return sources
885
901
 
886
- def mixture_targets_f(self, m_id: int, targets: list[AudioT] | None = None, force: bool = False) -> list[AudioF]:
887
- """Get the list of augmented target transform data (one per target in the mixup) for the given mixture ID
902
+ def mixture_sources_f(
903
+ self,
904
+ m_id: int,
905
+ sources: SourcesAudioT | None = None,
906
+ force: bool = False,
907
+ cache: bool = False,
908
+ ) -> SourcesAudioF:
909
+ """Get the pre-truth source transform data (one per source in the mixture) for the given mixture ID
888
910
 
889
911
  :param m_id: Zero-based mixture ID
890
- :param targets: List of augmented target audio data (one per target in the mixup)
912
+ :param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
891
913
  :param force: Force computing data from original sources regardless of whether cached data exists
892
- :return: List of augmented target transform data (one per target in the mixup)
914
+ :param cache: Cache result
915
+ :return: Dictionary of pre-truth source transform data (one per source in the mixture)
893
916
  """
917
+ from .data_io import write_cached_data
894
918
  from .helpers import forward_transform
895
919
 
896
- if force or targets is None:
897
- targets = self.mixture_targets(m_id, force)
920
+ if sources is None:
921
+ sources = self.mixture_sources(m_id, force)
898
922
 
899
- return [forward_transform(target, self.ft_config) for target in targets]
923
+ sources_f = {category: forward_transform(sources[category], self.ft_config) for category in sources}
924
+
925
+ if cache:
926
+ write_cached_data(
927
+ location=self.location,
928
+ name="mixture",
929
+ index=self.mixture(m_id).name,
930
+ items={"sources_f": sources_f},
931
+ )
900
932
 
901
- def mixture_target(self, m_id: int, targets: list[AudioT] | None = None, force: bool = False) -> AudioT:
902
- """Get the augmented target audio data for the given mixture ID
933
+ return sources_f
934
+
935
+ def mixture_source(
936
+ self,
937
+ m_id: int,
938
+ sources: SourcesAudioT | None = None,
939
+ force: bool = False,
940
+ cache: bool = False,
941
+ ) -> AudioT:
942
+ """Get the post-truth, summed, and gained source audio data for the given mixture ID
903
943
 
904
944
  :param m_id: Zero-based mixture ID
905
- :param targets: List of augmented target audio data (one per target in the mixup)
945
+ :param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
906
946
  :param force: Force computing data from original sources regardless of whether cached data exists
907
- :return: Augmented target audio data
947
+ :param cache: Cache result
948
+ :return: Post-truth, gained, and summed source audio data
908
949
  """
909
- from .helpers import get_target
950
+ import numpy as np
951
+
952
+ from .data_io import write_cached_data
953
+ from .effects import apply_effects
910
954
 
911
955
  if not force:
912
- target = self.read_mixture_data(m_id, "target")
913
- if target is not None:
914
- return target
956
+ source = self.read_mixture_data(m_id, "source")["source"]
957
+ if source is not None:
958
+ return source
959
+
960
+ if sources is None:
961
+ sources = self.mixture_sources(m_id, force)
962
+
963
+ mixture = self.mixture(m_id)
915
964
 
916
- if force or targets is None:
917
- targets = self.mixture_targets(m_id, force)
965
+ source = np.sum(
966
+ [
967
+ apply_effects(
968
+ self,
969
+ audio=sources[category],
970
+ effects=mixture.all_sources[category].effects,
971
+ pre=False,
972
+ post=True,
973
+ )
974
+ * mixture.all_sources[category].snr_gain
975
+ for category in sources
976
+ if category != "noise"
977
+ ],
978
+ axis=0,
979
+ )
918
980
 
919
- return get_target(self, self.mixture(m_id), targets)
981
+ if cache:
982
+ write_cached_data(
983
+ location=self.location,
984
+ name="mixture",
985
+ index=mixture.name,
986
+ items={"source": source},
987
+ )
920
988
 
921
- def mixture_target_f(
989
+ return source
990
+
991
+ def mixture_source_f(
922
992
  self,
923
993
  m_id: int,
924
- targets: list[AudioT] | None = None,
925
- target: AudioT | None = None,
994
+ sources: SourcesAudioT | None = None,
995
+ source: AudioT | None = None,
926
996
  force: bool = False,
997
+ cache: bool = False,
927
998
  ) -> AudioF:
928
- """Get the augmented target transform data for the given mixture ID
999
+ """Get the post-truth, summed, and gained source transform data for the given mixture ID
929
1000
 
930
1001
  :param m_id: Zero-based mixture ID
931
- :param targets: List of augmented target audio data (one per target in the mixup)
932
- :param target: Augmented target audio for the given m_id
1002
+ :param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
1003
+ :param source: Post-truth, gained, and summed source audio for the given m_id
933
1004
  :param force: Force computing data from original sources regardless of whether cached data exists
934
- :return: Augmented target transform data
1005
+ :param cache: Cache result
1006
+ :return: Post-truth, gained, and summed source transform data
935
1007
  """
1008
+ from .data_io import write_cached_data
936
1009
  from .helpers import forward_transform
937
1010
 
938
- if force or target is None:
939
- target = self.mixture_target(m_id, targets, force)
1011
+ if source is None:
1012
+ source = self.mixture_source(m_id, sources, force)
1013
+
1014
+ source_f = forward_transform(source, self.ft_config)
1015
+
1016
+ if cache:
1017
+ write_cached_data(
1018
+ location=self.location,
1019
+ name="mixture",
1020
+ index=self.mixture(m_id).name,
1021
+ items={"source_f": source_f},
1022
+ )
940
1023
 
941
- return forward_transform(target, self.ft_config)
1024
+ return source_f
942
1025
 
943
- def mixture_noise(self, m_id: int, force: bool = False) -> AudioT:
944
- """Get the augmented noise audio data for the given mixture ID
1026
+ def mixture_noise(
1027
+ self,
1028
+ m_id: int,
1029
+ sources: SourcesAudioT | None = None,
1030
+ force: bool = False,
1031
+ cache: bool = False,
1032
+ ) -> AudioT:
1033
+ """Get the post-truth and gained noise audio data for the given mixture ID
945
1034
 
946
1035
  :param m_id: Zero-based mixture ID
1036
+ :param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
947
1037
  :param force: Force computing data from original sources regardless of whether cached data exists
948
- :return: Augmented noise audio data
1038
+ :param cache: Cache result
1039
+ :return: Post-truth and gained noise audio data
949
1040
  """
950
- from .audio import get_next_noise
951
- from .augmentation import apply_gain
1041
+ from .data_io import write_cached_data
1042
+ from .effects import apply_effects
952
1043
 
953
1044
  if not force:
954
- noise = self.read_mixture_data(m_id, "noise")
1045
+ noise = self.read_mixture_data(m_id, "noise")["noise"]
955
1046
  if noise is not None:
956
1047
  return noise
957
1048
 
958
- mixture = self.mixture(m_id)
959
- noise = self.augmented_noise_audio(mixture)
960
- noise = get_next_noise(audio=noise, offset=mixture.noise_offset, length=mixture.samples)
961
- return apply_gain(audio=noise, gain=mixture.noise_snr_gain)
1049
+ if sources is None:
1050
+ sources = self.mixture_sources(m_id, force)
1051
+
1052
+ noise = self.mixture(m_id).noise
1053
+ noise = apply_effects(self, sources["noise"], noise.effects, pre=False, post=True) * noise.snr_gain
962
1054
 
963
- def mixture_noise_f(self, m_id: int, noise: AudioT | None = None, force: bool = False) -> AudioF:
964
- """Get the augmented noise transform for the given mixture ID
1055
+ if cache:
1056
+ write_cached_data(
1057
+ location=self.location,
1058
+ name="mixture",
1059
+ index=self.mixture(m_id).name,
1060
+ items={"noise": noise},
1061
+ )
1062
+
1063
+ return noise
1064
+
1065
+ def mixture_noise_f(
1066
+ self,
1067
+ m_id: int,
1068
+ sources: SourcesAudioT | None = None,
1069
+ noise: AudioT | None = None,
1070
+ force: bool = False,
1071
+ cache: bool = False,
1072
+ ) -> AudioF:
1073
+ """Get the post-truth and gained noise transform for the given mixture ID
965
1074
 
966
1075
  :param m_id: Zero-based mixture ID
967
- :param noise: Augmented noise audio data
1076
+ :param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
1077
+ :param noise: Post-truth and gained noise audio data
968
1078
  :param force: Force computing data from original sources regardless of whether cached data exists
969
- :return: Augmented noise transform data
1079
+ :param cache: Cache result
1080
+ :return: Post-truth and gained noise transform data
970
1081
  """
1082
+ from .data_io import write_cached_data
971
1083
  from .helpers import forward_transform
972
1084
 
973
1085
  if force or noise is None:
974
- noise = self.mixture_noise(m_id, force)
1086
+ noise = self.mixture_noise(m_id, sources, force)
1087
+
1088
+ noise_f = forward_transform(noise, self.ft_config)
1089
+ if cache:
1090
+ write_cached_data(
1091
+ location=self.location,
1092
+ name="mixture",
1093
+ index=self.mixture(m_id).name,
1094
+ items={"noise_f": noise_f},
1095
+ )
975
1096
 
976
- return forward_transform(noise, self.ft_config)
1097
+ return noise_f
977
1098
 
978
1099
  def mixture_mixture(
979
1100
  self,
980
1101
  m_id: int,
981
- targets: list[AudioT] | None = None,
982
- target: AudioT | None = None,
1102
+ sources: SourcesAudioT | None = None,
1103
+ source: AudioT | None = None,
983
1104
  noise: AudioT | None = None,
984
1105
  force: bool = False,
1106
+ cache: bool = False,
985
1107
  ) -> AudioT:
986
1108
  """Get the mixture audio data for the given mixture ID
987
1109
 
988
1110
  :param m_id: Zero-based mixture ID
989
- :param targets: List of augmented target audio data (one per target in the mixup)
990
- :param target: Augmented target audio data
991
- :param noise: Augmented noise audio data
1111
+ :param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
1112
+ :param source: Post-truth, gained, and summed source audio data
1113
+ :param noise: Post-truth and gained noise audio data
992
1114
  :param force: Force computing data from original sources regardless of whether cached data exists
1115
+ :param cache: Cache result
993
1116
  :return: Mixture audio data
994
1117
  """
1118
+ from .data_io import write_cached_data
1119
+
995
1120
  if not force:
996
- mixture = self.read_mixture_data(m_id, "mixture")
1121
+ mixture = self.read_mixture_data(m_id, "mixture")["mixture"]
997
1122
  if mixture is not None:
998
1123
  return mixture
999
1124
 
1000
- if force or target is None:
1001
- target = self.mixture_target(m_id, targets, force)
1125
+ if source is None:
1126
+ source = self.mixture_source(m_id, sources, force)
1002
1127
 
1003
- if force or noise is None:
1004
- noise = self.mixture_noise(m_id, force)
1128
+ if noise is None:
1129
+ noise = self.mixture_noise(m_id, sources, force)
1130
+
1131
+ mixture = source + noise
1132
+
1133
+ if cache:
1134
+ write_cached_data(
1135
+ location=self.location,
1136
+ name="mixture",
1137
+ index=self.mixture(m_id).name,
1138
+ items={"mixture": mixture},
1139
+ )
1005
1140
 
1006
- return target + noise
1141
+ return mixture
1007
1142
 
1008
1143
  def mixture_mixture_f(
1009
1144
  self,
1010
1145
  m_id: int,
1011
- targets: list[AudioT] | None = None,
1012
- target: AudioT | None = None,
1146
+ sources: SourcesAudioT | None = None,
1147
+ source: AudioT | None = None,
1013
1148
  noise: AudioT | None = None,
1014
1149
  mixture: AudioT | None = None,
1015
1150
  force: bool = False,
1151
+ cache: bool = False,
1016
1152
  ) -> AudioF:
1017
1153
  """Get the mixture transform for the given mixture ID
1018
1154
 
1019
1155
  :param m_id: Zero-based mixture ID
1020
- :param targets: List of augmented target audio data (one per target in the mixup)
1021
- :param target: Augmented target audio data
1022
- :param noise: Augmented noise audio data
1156
+ :param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
1157
+ :param source: Post-truth, gained, and summed source audio data
1158
+ :param noise: Post-truth and gained noise audio data
1023
1159
  :param mixture: Mixture audio data
1024
1160
  :param force: Force computing data from original sources regardless of whether cached data exists
1161
+ :param cache: Cache result
1025
1162
  :return: Mixture transform data
1026
1163
  """
1164
+ from .data_io import write_cached_data
1027
1165
  from .helpers import forward_transform
1028
1166
  from .spectral_mask import apply_spectral_mask
1029
1167
 
1030
- if force or mixture is None:
1031
- mixture = self.mixture_mixture(m_id, targets, target, noise, force)
1168
+ if mixture is None:
1169
+ mixture = self.mixture_mixture(m_id, sources, source, noise, force)
1032
1170
 
1033
1171
  mixture_f = forward_transform(mixture, self.ft_config)
1034
1172
 
@@ -1040,80 +1178,79 @@ class MixtureDatabase:
1040
1178
  seed=m.spectral_mask_seed,
1041
1179
  )
1042
1180
 
1181
+ if cache:
1182
+ write_cached_data(
1183
+ location=self.location,
1184
+ name="mixture",
1185
+ index=self.mixture(m_id).name,
1186
+ items={"mixture_f": mixture_f},
1187
+ )
1188
+
1043
1189
  return mixture_f
1044
1190
 
1045
- def mixture_truth_t(
1046
- self,
1047
- m_id: int,
1048
- targets: list[AudioT] | None = None,
1049
- noise: AudioT | None = None,
1050
- mixture: AudioT | None = None,
1051
- force: bool = False,
1052
- ) -> list[TruthDict]:
1191
+ def mixture_truth_t(self, m_id: int, force: bool = False, cache: bool = False) -> TruthsDict:
1053
1192
  """Get the truth_t data for the given mixture ID
1054
1193
 
1055
1194
  :param m_id: Zero-based mixture ID
1056
- :param targets: List of augmented target audio data (one per target in the mixup) for the given mixture ID
1057
- :param noise: Augmented noise audio data for the given mixture ID
1058
- :param mixture: Mixture audio data for the given mixture ID
1059
1195
  :param force: Force computing data from original sources regardless of whether cached data exists
1196
+ :param cache: Cache result
1060
1197
  :return: list of truth_t data
1061
1198
  """
1199
+ from .data_io import write_cached_data
1062
1200
  from .truth import truth_function
1063
1201
 
1064
1202
  if not force:
1065
- truth_t = self.read_mixture_data(m_id, "truth_t")
1203
+ truth_t = self.read_mixture_data(m_id, "truth_t")["truth_t"]
1066
1204
  if truth_t is not None:
1067
1205
  return truth_t
1068
1206
 
1069
- if force or targets is None:
1070
- targets = self.mixture_targets(m_id, force)
1071
-
1072
- if force or noise is None:
1073
- noise = self.mixture_noise(m_id, force)
1074
-
1075
- if force or mixture is None:
1076
- mixture = self.mixture_mixture(m_id, targets=targets, noise=noise, force=force)
1207
+ truth_t = truth_function(self, m_id)
1077
1208
 
1078
- if not all(len(target) == self.mixture(m_id).samples for target in targets):
1079
- raise ValueError("Lengths of targets do not match length of mixture")
1080
-
1081
- if len(noise) != self.mixture(m_id).samples:
1082
- raise ValueError("Length of noise does not match length of mixture")
1209
+ if cache:
1210
+ write_cached_data(
1211
+ location=self.location,
1212
+ name="mixture",
1213
+ index=self.mixture(m_id).name,
1214
+ items={"truth_t": truth_t},
1215
+ )
1083
1216
 
1084
- return truth_function(self, m_id)
1217
+ return truth_t
1085
1218
 
1086
1219
  def mixture_segsnr_t(
1087
1220
  self,
1088
1221
  m_id: int,
1089
- targets: list[AudioT] | None = None,
1090
- target: AudioT | None = None,
1222
+ sources: SourcesAudioT | None = None,
1223
+ source: AudioT | None = None,
1091
1224
  noise: AudioT | None = None,
1092
1225
  force: bool = False,
1226
+ cache: bool = False,
1093
1227
  ) -> Segsnr:
1094
1228
  """Get the segsnr_t data for the given mixture ID
1095
1229
 
1096
1230
  :param m_id: Zero-based mixture ID
1097
- :param targets: List of augmented target audio data (one per target in the mixup)
1098
- :param target: Augmented target audio data
1099
- :param noise: Augmented noise audio data
1231
+ :param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
1232
+ :param source: Post-truth, gained, and summed source audio data
1233
+ :param noise: Post-truth and gained noise audio data
1100
1234
  :param force: Force computing data from original sources regardless of whether cached data exists
1235
+ :param cache: Cache result
1101
1236
  :return: segsnr_t data
1102
1237
  """
1103
1238
  import numpy as np
1104
1239
  import torch
1105
1240
  from pyaaware import ForwardTransform
1106
1241
 
1242
+ from .data_io import write_cached_data
1243
+
1107
1244
  if not force:
1108
- segsnr_t = self.read_mixture_data(m_id, "segsnr_t")
1245
+ segsnr_t = self.read_mixture_data(m_id, "segsnr_t")["segsnr_t"]
1109
1246
  if segsnr_t is not None:
1110
1247
  return segsnr_t
1111
1248
 
1112
- if force or target is None:
1113
- target = self.mixture_target(m_id, targets, force)
1249
+ if source is None:
1250
+ source = self.mixture_source(m_id, sources, force)
1114
1251
 
1115
- if force or noise is None:
1116
- noise = self.mixture_noise(m_id, force)
1252
+ if noise is None:
1253
+ noise = self.mixture_noise(m_id, sources, force)
1117
1254
 
1118
1255
  ft = ForwardTransform(
1119
1256
  length=self.ft_config.length,
@@ -1127,13 +1264,13 @@ class MixtureDatabase:
1127
1264
 
1128
1265
  segsnr_t = np.empty(mixture.samples, dtype=np.float32)
1129
1266
 
1130
- target_energy = ft.execute_all(torch.from_numpy(target))[1].numpy()
1267
+ source_energy = ft.execute_all(torch.from_numpy(source))[1].numpy()
1131
1268
  noise_energy = ft.execute_all(torch.from_numpy(noise))[1].numpy()
1132
1269
 
1133
1270
  offsets = range(0, mixture.samples, self.ft_config.overlap)
1134
- if len(target_energy) != len(offsets):
1271
+ if len(source_energy) != len(offsets):
1135
1272
  raise ValueError(
1136
- f"Number of frames in energy, {len(target_energy)}, is not number of frames in mixture, {len(offsets)}"
1273
+ f"Number of frames in energy, {len(source_energy)}, is not number of frames in mixture, {len(offsets)}"
1137
1274
  )
1138
1275
 
1139
1276
  for idx, offset in enumerate(offsets):
@@ -1142,187 +1279,242 @@ class MixtureDatabase:
1142
1279
  if noise_energy[idx] == 0:
1143
1280
  snr = np.float32(np.inf)
1144
1281
  else:
1145
- snr = np.float32(target_energy[idx] / noise_energy[idx])
1282
+ snr = np.float32(source_energy[idx] / noise_energy[idx])
1146
1283
 
1147
1284
  segsnr_t[indices] = snr
1148
1285
 
1286
+ if cache:
1287
+ write_cached_data(
1288
+ location=self.location,
1289
+ name="mixture",
1290
+ index=mixture.name,
1291
+ items={"segsnr_t": segsnr_t},
1292
+ )
1293
+
1149
1294
  return segsnr_t
1150
1295
 
1151
1296
  def mixture_segsnr(
1152
1297
  self,
1153
1298
  m_id: int,
1154
1299
  segsnr_t: Segsnr | None = None,
1155
- targets: list[AudioT] | None = None,
1156
- target: AudioT | None = None,
1300
+ sources: SourcesAudioT | None = None,
1301
+ source: AudioT | None = None,
1157
1302
  noise: AudioT | None = None,
1158
1303
  force: bool = False,
1304
+ cache: bool = False,
1159
1305
  ) -> Segsnr:
1160
1306
  """Get the segsnr data for the given mixture ID
1161
1307
 
1162
1308
  :param m_id: Zero-based mixture ID
1163
1309
  :param segsnr_t: segsnr_t data
1164
- :param targets: List of augmented target audio data (one per target in the mixup)
1165
- :param target: Augmented target audio data
1166
- :param noise: Augmented noise audio data
1310
+ :param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
1311
+ :param source: Post-truth, gained, and summed source audio data
1312
+ :param noise: Post-truth and gained noise audio data
1167
1313
  :param force: Force computing data from original sources regardless of whether cached data exists
1314
+ :param cache: Cache result
1168
1315
  :return: segsnr data
1169
1316
  """
1317
+ from .data_io import write_cached_data
1318
+
1170
1319
  if not force:
1171
- segsnr = self.read_mixture_data(m_id, "segsnr")
1320
+ segsnr = self.read_mixture_data(m_id, "segsnr")["segsnr"]
1172
1321
  if segsnr is not None:
1173
1322
  return segsnr
1174
1323
 
1175
- segsnr_t = self.read_mixture_data(m_id, "segsnr_t")
1176
- if segsnr_t is not None:
1177
- return segsnr_t[0 :: self.ft_config.overlap]
1324
+ if segsnr_t is None:
1325
+ segsnr_t = self.mixture_segsnr_t(m_id, sources, source, noise, force)
1326
+
1327
+ segsnr = segsnr_t[0 :: self.ft_config.overlap]
1178
1328
 
1179
- if force or segsnr_t is None:
1180
- segsnr_t = self.mixture_segsnr_t(m_id, targets, target, noise, force)
1329
+ if cache:
1330
+ write_cached_data(
1331
+ location=self.location,
1332
+ name="mixture",
1333
+ index=self.mixture(m_id).name,
1334
+ items={"segsnr": segsnr},
1335
+ )
1181
1336
 
1182
- return segsnr_t[0 :: self.ft_config.overlap]
1337
+ return segsnr
1183
1338
 
1184
1339
  def mixture_ft(
1185
1340
  self,
1186
1341
  m_id: int,
1187
- targets: list[AudioT] | None = None,
1188
- target: AudioT | None = None,
1342
+ sources: SourcesAudioT | None = None,
1343
+ source: AudioT | None = None,
1189
1344
  noise: AudioT | None = None,
1190
1345
  mixture_f: AudioF | None = None,
1191
1346
  mixture: AudioT | None = None,
1192
- truth_t: list[TruthDict] | None = None,
1347
+ truth_t: TruthsDict | None = None,
1193
1348
  force: bool = False,
1194
- ) -> tuple[Feature, TruthDict]:
1349
+ cache: bool = False,
1350
+ ) -> tuple[Feature, TruthsDict]:
1195
1351
  """Get the feature and truth_f data for the given mixture ID
1196
1352
 
1197
1353
  :param m_id: Zero-based mixture ID
1198
- :param targets: List of augmented target audio data (one per target in the mixup)
1199
- :param target: Augmented target audio data
1200
- :param noise: Augmented noise audio data
1354
+ :param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
1355
+ :param source: Post-truth, gained, and summed source audio data
1356
+ :param noise: Post-truth and gained noise audio data
1201
1357
  :param mixture_f: Mixture transform data
1202
1358
  :param mixture: Mixture audio data
1203
1359
  :param truth_t: truth_t
1204
1360
  :param force: Force computing data from original sources regardless of whether cached data exists
1361
+ :param cache: Cache result
1205
1362
  :return: Tuple of (feature, truth_f) data
1206
1363
  """
1207
1364
  from pyaaware import FeatureGenerator
1208
1365
 
1366
+ from .data_io import write_cached_data
1209
1367
  from .truth import truth_stride_reduction
1210
1368
 
1211
1369
  if not force:
1212
- feature, truth_f = self.read_mixture_data(m_id, ["feature", "truth_f"])
1213
- if feature is not None and truth_f is not None:
1214
- return feature, truth_f
1370
+ ft = self.read_mixture_data(m_id, ["feature", "truth_f"])
1371
+ if ft["feature"] is not None and ft["truth_f"] is not None:
1372
+ return ft["feature"], ft["truth_f"]
1215
1373
 
1216
- if force or mixture_f is None:
1374
+ if mixture_f is None:
1217
1375
  mixture_f = self.mixture_mixture_f(
1218
1376
  m_id=m_id,
1219
- targets=targets,
1220
- target=target,
1377
+ sources=sources,
1378
+ source=source,
1221
1379
  noise=noise,
1222
1380
  mixture=mixture,
1223
1381
  force=force,
1224
1382
  )
1225
1383
 
1226
- if force or truth_t is None:
1227
- truth_t = self.mixture_truth_t(m_id=m_id, targets=targets, noise=noise, force=force)
1384
+ if truth_t is None:
1385
+ truth_t = self.mixture_truth_t(m_id, force)
1228
1386
 
1229
1387
  fg = FeatureGenerator(self.fg_config.feature_mode, self.fg_config.truth_parameters)
1230
1388
 
1231
- # TODO: handle mixup in truth_t
1232
- feature, truth_f = fg.execute_all(mixture_f, truth_t[0])
1389
+ feature, truth_f = fg.execute_all(mixture_f, truth_t)
1233
1390
  if truth_f is not None:
1234
- for key in self.truth_configs:
1235
- if self.truth_parameters[key] is not None:
1236
- truth_f[key] = truth_stride_reduction(truth_f[key], self.truth_configs[key].stride_reduction)
1391
+ truth_configs = self.mixture_truth_configs(m_id)
1392
+ for category, configs in truth_configs.items():
1393
+ for name, config in configs.items():
1394
+ if self.truth_parameters[category][name] is not None:
1395
+ truth_f[category][name] = truth_stride_reduction(
1396
+ truth_f[category][name], config.stride_reduction
1397
+ )
1237
1398
  else:
1238
1399
  raise TypeError("Unexpected truth of None from feature generator")
1239
1400
 
1401
+ if cache:
1402
+ write_cached_data(
1403
+ location=self.location,
1404
+ name="mixture",
1405
+ index=self.mixture(m_id).name,
1406
+ items={"feature": truth_f, "truth_f": truth_f},
1407
+ )
1408
+
1240
1409
  return feature, truth_f
1241
1410
 
1242
1411
  def mixture_feature(
1243
1412
  self,
1244
1413
  m_id: int,
1245
- targets: list[AudioT] | None = None,
1414
+ sources: SourcesAudioT | None = None,
1246
1415
  noise: AudioT | None = None,
1247
1416
  mixture: AudioT | None = None,
1248
- truth_t: list[TruthDict] | None = None,
1417
+ truth_t: TruthsDict | None = None,
1249
1418
  force: bool = False,
1419
+ cache: bool = False,
1250
1420
  ) -> Feature:
1251
1421
  """Get the feature data for the given mixture ID
1252
1422
 
1253
1423
  :param m_id: Zero-based mixture ID
1254
- :param targets: List of augmented target audio data (one per target in the mixup)
1255
- :param noise: Augmented noise audio data
1424
+ :param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
1425
+ :param noise: Post-truth and gained noise audio data
1256
1426
  :param mixture: Mixture audio data
1257
1427
  :param truth_t: truth_t
1258
1428
  :param force: Force computing data from original sources regardless of whether cached data exists
1429
+ :param cache: Cache result
1259
1430
  :return: Feature data
1260
1431
  """
1261
- feature, _ = self.mixture_ft(
1432
+ from .data_io import write_cached_data
1433
+
1434
+ feature = self.mixture_ft(
1262
1435
  m_id=m_id,
1263
- targets=targets,
1436
+ sources=sources,
1264
1437
  noise=noise,
1265
1438
  mixture=mixture,
1266
1439
  truth_t=truth_t,
1267
1440
  force=force,
1268
- )
1441
+ )[0]
1442
+
1443
+ if cache:
1444
+ write_cached_data(
1445
+ location=self.location,
1446
+ name="mixture",
1447
+ index=self.mixture(m_id).name,
1448
+ items={"feature": feature},
1449
+ )
1450
+
1269
1451
  return feature
1270
1452
 
1271
1453
  def mixture_truth_f(
1272
1454
  self,
1273
1455
  m_id: int,
1274
- targets: list[AudioT] | None = None,
1456
+ sources: SourcesAudioT | None = None,
1275
1457
  noise: AudioT | None = None,
1276
1458
  mixture: AudioT | None = None,
1277
- truth_t: list[TruthDict] | None = None,
1459
+ truth_t: TruthsDict | None = None,
1278
1460
  force: bool = False,
1461
+ cache: bool = False,
1279
1462
  ) -> TruthDict:
1280
1463
  """Get the truth_f data for the given mixture ID
1281
1464
 
1282
1465
  :param m_id: Zero-based mixture ID
1283
- :param targets: List of augmented target audio data (one per target in the mixup)
1284
- :param noise: Augmented noise audio data
1466
+ :param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
1467
+ :param noise: Post-truth and gained noise audio data
1285
1468
  :param mixture: Mixture audio data
1286
1469
  :param truth_t: truth_t
1287
1470
  :param force: Force computing data from original sources regardless of whether cached data exists
1471
+ :param cache: Cache result
1288
1472
  :return: truth_f data
1289
1473
  """
1290
- _, truth_f = self.mixture_ft(
1474
+ from .data_io import write_cached_data
1475
+
1476
+ truth_f = self.mixture_ft(
1291
1477
  m_id=m_id,
1292
- targets=targets,
1478
+ sources=sources,
1293
1479
  noise=noise,
1294
1480
  mixture=mixture,
1295
1481
  truth_t=truth_t,
1296
1482
  force=force,
1297
- )
1483
+ )[1]
1484
+
1485
+ if cache:
1486
+ write_cached_data(
1487
+ location=self.location,
1488
+ name="mixture",
1489
+ index=self.mixture(m_id).name,
1490
+ items={"truth_f": truth_f},
1491
+ )
1492
+
1298
1493
  return truth_f
1299
1494
 
1300
- def mixture_class_count(
1301
- self,
1302
- m_id: int,
1303
- targets: list[AudioT] | None = None,
1304
- noise: AudioT | None = None,
1305
- truth_t: list[TruthDict] | None = None,
1306
- ) -> ClassCount:
1495
+ def mixture_class_count(self, m_id: int, truth_t: TruthsDict | None = None) -> dict[str, ClassCount]:
1307
1496
  """Compute the number of frames for which each class index is active for the given mixture ID
1308
1497
 
1309
1498
  :param m_id: Zero-based mixture ID
1310
- :param targets: List of augmented target audio (one per target in the mixup)
1311
- :param noise: Augmented noise audio
1312
1499
  :param truth_t: truth_t
1313
- :return: List of class counts
1500
+ :return: Dictionary of class counts
1314
1501
  """
1315
1502
  import numpy as np
1316
1503
 
1317
1504
  if truth_t is None:
1318
- truth_t = self.mixture_truth_t(m_id, targets, noise)
1505
+ truth_t = self.mixture_truth_t(m_id)
1319
1506
 
1320
- class_count = [0] * self.num_classes
1321
- num_classes = self.num_classes
1322
- if "sed" in self.truth_configs:
1323
- for cl in range(num_classes):
1324
- # TODO: handle mixup in truth_t
1325
- class_count[cl] = int(np.sum(truth_t[0]["sed"][:, cl] >= self.class_weights_thresholds[cl]))
1507
+ class_count: dict[str, ClassCount] = {}
1508
+
1509
+ truth_configs = self.mixture_truth_configs(m_id)
1510
+ for category in truth_configs:
1511
+ class_count[category] = [0] * self.num_classes
1512
+ for configs in truth_configs[category]:
1513
+ if "sed" in configs:
1514
+ for cl in range(self.num_classes):
1515
+ class_count[category][cl] = int(
1516
+ np.sum(truth_t[category]["sed"][:, cl] >= self.class_weights_thresholds[cl])
1517
+ )
1326
1518
 
1327
1519
  return class_count
1328
1520
 
@@ -1348,57 +1540,56 @@ class MixtureDatabase:
1348
1540
  return _speaker(self.db, s_id, tier, self.use_cache)
1349
1541
 
1350
1542
  def speech_metadata(self, tier: str) -> list[str]:
1351
- from .helpers import get_textgrid_tier_from_target_file
1543
+ from .helpers import get_textgrid_tier_from_source_file
1352
1544
 
1353
1545
  results: set[str] = set()
1354
1546
  if tier in self.textgrid_metadata_tiers:
1355
- for target_file in self.target_files:
1356
- data = get_textgrid_tier_from_target_file(target_file.name, tier)
1357
- if data is None:
1358
- continue
1359
- if isinstance(data, list):
1360
- for item in data:
1361
- results.add(item.label)
1362
- else:
1363
- results.add(data)
1547
+ for source_files in self.source_files.values():
1548
+ for source_file in source_files:
1549
+ data = get_textgrid_tier_from_source_file(source_file.name, tier)
1550
+ if data is None:
1551
+ continue
1552
+ if isinstance(data, list):
1553
+ for item in data:
1554
+ results.add(item.label)
1555
+ else:
1556
+ results.add(data)
1364
1557
  elif tier in self.speaker_metadata_tiers:
1365
- for target_file in self.target_files:
1366
- data = self.speaker(target_file.speaker_id, tier)
1367
- if data is not None:
1368
- results.add(data)
1558
+ for source_files in self.source_files.values():
1559
+ for source_file in source_files:
1560
+ data = self.speaker(source_file.speaker_id, tier)
1561
+ if data is not None:
1562
+ results.add(data)
1369
1563
 
1370
1564
  return sorted(results)
1371
1565
 
1372
- def mixture_speech_metadata(self, mixid: int, tier: str) -> list[SpeechMetadata]:
1566
+ def mixture_speech_metadata(self, mixid: int, tier: str) -> dict[str, SpeechMetadata]:
1373
1567
  from praatio.utilities.constants import Interval
1374
1568
 
1375
- from .helpers import get_textgrid_tier_from_target_file
1569
+ from .helpers import get_textgrid_tier_from_source_file
1376
1570
 
1377
- results: list[SpeechMetadata] = []
1571
+ results: dict[str, SpeechMetadata] = {}
1378
1572
  is_textgrid = tier in self.textgrid_metadata_tiers
1379
1573
  if is_textgrid:
1380
- for target in self.mixture(mixid).targets:
1381
- data = get_textgrid_tier_from_target_file(self.target_file(target.file_id).name, tier)
1574
+ for category, source in self.mixture(mixid).all_sources.items():
1575
+ data = get_textgrid_tier_from_source_file(self.source_file(source.file_id).name, tier)
1382
1576
  if isinstance(data, list):
1383
- # Check for tempo augmentation and adjust Interval start and end data as needed
1577
+ # Check for tempo effect and adjust Interval start and end data as needed
1384
1578
  entries = []
1385
1579
  for entry in data:
1386
- if target.augmentation.pre.tempo is not None:
1387
- entries.append(
1388
- Interval(
1389
- entry.start / target.augmentation.pre.tempo,
1390
- entry.end / target.augmentation.pre.tempo,
1391
- entry.label,
1392
- )
1580
+ entries.append(
1581
+ Interval(
1582
+ entry.start / source.pre_tempo,
1583
+ entry.end / source.pre_tempo,
1584
+ entry.label,
1393
1585
  )
1394
- else:
1395
- entries.append(entry)
1396
- results.append(entries)
1586
+ )
1587
+ results[category] = entries
1397
1588
  else:
1398
- results.append(data)
1589
+ results[category] = data
1399
1590
  else:
1400
- for target in self.mixture(mixid).targets:
1401
- results.append(self.speaker(self.target_file(target.file_id).speaker_id, tier))
1591
+ for category, source in self.mixture(mixid).all_sources.items():
1592
+ results[category] = self.speaker(self.source_file(source.file_id).speaker_id, tier)
1402
1593
 
1403
1594
  return results
1404
1595
 
@@ -1450,7 +1641,7 @@ class MixtureDatabase:
1450
1641
 
1451
1642
  return [mixture_id[0] - 1 for mixture_id in results]
1452
1643
 
1453
- def mixture_all_speech_metadata(self, m_id: int) -> list[dict[str, SpeechMetadata]]:
1644
+ def mixture_all_speech_metadata(self, m_id: int) -> dict[str, dict[str, SpeechMetadata]]:
1454
1645
  from .helpers import mixture_all_speech_metadata
1455
1646
 
1456
1647
  return mixture_all_speech_metadata(self, self.mixture(m_id))
@@ -1483,63 +1674,65 @@ class MixtureDatabase:
1483
1674
  :param m_id: Zero-based mixture ID
1484
1675
  :param metrics: List of metrics to get
1485
1676
  :param force: Force computing data from original sources regardless of whether cached data exists
1486
- :return: List of metric data
1677
+ :return: Dictionary of metric data
1487
1678
  """
1488
1679
  from collections.abc import Callable
1489
1680
 
1490
1681
  import numpy as np
1491
1682
  from pystoi import stoi
1492
1683
 
1493
- from sonusai.metrics import calc_audio_stats
1494
- from sonusai.metrics import calc_phase_distance
1495
- from sonusai.metrics import calc_segsnr_f
1496
- from sonusai.metrics import calc_segsnr_f_bin
1497
- from sonusai.metrics import calc_speech
1498
- from sonusai.metrics import calc_wer
1499
- from sonusai.metrics import calc_wsdr
1500
- from sonusai.mixture import SAMPLE_RATE
1501
- from sonusai.mixture import AudioStatsMetrics
1502
- from sonusai.mixture import SpeechMetrics
1503
- from sonusai.utils import calc_asr
1504
-
1505
- def create_targets_audio() -> Callable[[], list[AudioT]]:
1506
- state: list[AudioT] | None = None
1507
-
1508
- def get() -> list[AudioT]:
1684
+ from ..constants import SAMPLE_RATE
1685
+ from ..datatypes import AudioStatsMetrics
1686
+ from ..datatypes import SpeechMetrics
1687
+ from ..metrics.calc_audio_stats import calc_audio_stats
1688
+ from ..metrics.calc_pesq import calc_pesq
1689
+ from ..metrics.calc_phase_distance import calc_phase_distance
1690
+ from ..metrics.calc_segsnr_f import calc_segsnr_f
1691
+ from ..metrics.calc_segsnr_f import calc_segsnr_f_bin
1692
+ from ..metrics.calc_speech import calc_speech
1693
+ from ..metrics.calc_wer import calc_wer
1694
+ from ..metrics.calc_wsdr import calc_wsdr
1695
+ from ..utils.asr import calc_asr
1696
+ from ..utils.db import linear_to_db
1697
+
1698
+ def create_sources_audio() -> Callable[[], dict[str, AudioT]]:
1699
+ state: dict[str, AudioT] | None = None
1700
+
1701
+ def get() -> dict[str, AudioT]:
1509
1702
  nonlocal state
1510
1703
  if state is None:
1511
- state = self.mixture_targets(m_id)
1704
+ state = self.mixture_sources(m_id)
1512
1705
  return state
1513
1706
 
1514
1707
  return get
1515
1708
 
1516
- targets_audio = create_targets_audio()
1709
+ sources_audio = create_sources_audio()
1517
1710
 
1518
- def create_target_audio() -> Callable[[], AudioT]:
1711
+ def create_source_audio() -> Callable[[], AudioT]:
1519
1712
  state: AudioT | None = None
1520
1713
 
1521
1714
  def get() -> AudioT:
1522
1715
  nonlocal state
1523
1716
  if state is None:
1524
- state = self.mixture_target(m_id)
1717
+ state = self.mixture_source(m_id)
1525
1718
  return state
1526
1719
 
1527
1720
  return get
1528
1721
 
1529
- target_audio = create_target_audio()
1722
+ source_audio = create_source_audio()
1530
1723
 
1531
- def create_target_f() -> Callable[[], AudioF]:
1724
+ def create_source_f() -> Callable[[], AudioF]:
1532
1725
  state: AudioF | None = None
1533
1726
 
1534
1727
  def get() -> AudioF:
1535
1728
  nonlocal state
1536
1729
  if state is None:
1537
- state = self.mixture_targets_f(m_id)[0]
1730
+ state = self.mixture_source_f(m_id)
1538
1731
  return state
1539
1732
 
1540
1733
  return get
1541
1734
 
1542
- target_f = create_target_f()
1735
+ source_f = create_source_f()
1543
1736
 
1544
1737
  def create_noise_audio() -> Callable[[], AudioT]:
1545
1738
  state: AudioT | None = None
@@ -1593,15 +1786,29 @@ class MixtureDatabase:
1593
1786
 
1594
1787
  segsnr_f = create_segsnr_f()
1595
1788
 
1596
- def create_speech() -> Callable[[], list[SpeechMetrics]]:
1597
- state: list[SpeechMetrics] | None = None
1789
+ def create_pesq() -> Callable[[], dict[str, float]]:
1790
+ state: dict[str, float] | None = None
1791
+
1792
+ def get() -> dict[str, float]:
1793
+ nonlocal state
1794
+ if state is None:
1795
+ state = {category: calc_pesq(mixture_audio(), audio) for category, audio in sources_audio().items()}
1796
+ return state
1797
+
1798
+ return get
1799
+
1800
+ pesq = create_pesq()
1801
+
1802
+ def create_speech() -> Callable[[], dict[str, SpeechMetrics]]:
1803
+ state: dict[str, SpeechMetrics] | None = None
1598
1804
 
1599
- def get() -> list[SpeechMetrics]:
1805
+ def get() -> dict[str, SpeechMetrics]:
1600
1806
  nonlocal state
1601
1807
  if state is None:
1602
- state = []
1603
- for audio in targets_audio():
1604
- state.append(calc_speech(hypothesis=mixture_audio(), reference=audio))
1808
+ state = {
1809
+ category: calc_speech(mixture_audio(), audio, pesq()[category])
1810
+ for category, audio in sources_audio().items()
1811
+ }
1605
1812
  return state
1606
1813
 
1607
1814
  return get
@@ -1621,33 +1828,34 @@ class MixtureDatabase:
1621
1828
 
1622
1829
  mixture_stats = create_mixture_stats()
1623
1830
 
1624
- def create_targets_stats() -> Callable[[], list[AudioStatsMetrics]]:
1625
- state: list[AudioStatsMetrics] | None = None
1831
+ def create_sources_stats() -> Callable[[], dict[str, AudioStatsMetrics]]:
1832
+ state: dict[str, AudioStatsMetrics] | None = None
1626
1833
 
1627
- def get() -> list[AudioStatsMetrics]:
1834
+ def get() -> dict[str, AudioStatsMetrics]:
1628
1835
  nonlocal state
1629
1836
  if state is None:
1630
- state = []
1631
- for audio in targets_audio():
1632
- state.append(calc_audio_stats(audio, self.fg_info.ft_config.length / SAMPLE_RATE))
1837
+ state = {
1838
+ category: calc_audio_stats(audio, self.fg_info.ft_config.length / SAMPLE_RATE)
1839
+ for category, audio in sources_audio().items()
1840
+ }
1633
1841
  return state
1634
1842
 
1635
1843
  return get
1636
1844
 
1637
- targets_stats = create_targets_stats()
1845
+ sources_stats = create_sources_stats()
1638
1846
 
1639
- def create_target_stats() -> Callable[[], AudioStatsMetrics]:
1847
+ def create_source_stats() -> Callable[[], AudioStatsMetrics]:
1640
1848
  state: AudioStatsMetrics | None = None
1641
1849
 
1642
1850
  def get() -> AudioStatsMetrics:
1643
1851
  nonlocal state
1644
1852
  if state is None:
1645
- state = calc_audio_stats(target_audio(), self.fg_info.ft_config.length / SAMPLE_RATE)
1853
+ state = calc_audio_stats(source_audio(), self.fg_info.ft_config.length / SAMPLE_RATE)
1646
1854
  return state
1647
1855
 
1648
1856
  return get
1649
1857
 
1650
- target_stats = create_target_stats()
1858
+ source_stats = create_source_stats()
1651
1859
 
1652
1860
  def create_noise_stats() -> Callable[[], AudioStatsMetrics]:
1653
1861
  state: AudioStatsMetrics | None = None
@@ -1678,33 +1886,34 @@ class MixtureDatabase:
1678
1886
 
1679
1887
  asr_config = create_asr_config()
1680
1888
 
1681
- def create_targets_asr() -> Callable[[str], list[str]]:
1682
- state: dict[str, list[str]] = {}
1889
+ def create_sources_asr() -> Callable[[str], dict[str, str]]:
1890
+ state: dict[str, dict[str, str]] = {}
1683
1891
 
1684
- def get(asr_name) -> list[str]:
1892
+ def get(asr_name) -> dict[str, str]:
1685
1893
  nonlocal state
1686
1894
  if asr_name not in state:
1687
- state[asr_name] = []
1688
- for audio in targets_audio():
1689
- state[asr_name].append(calc_asr(audio, **asr_config(asr_name)).text)
1895
+ state[asr_name] = {
1896
+ category: calc_asr(audio, **asr_config(asr_name)).text
1897
+ for category, audio in sources_audio().items()
1898
+ }
1690
1899
  return state[asr_name]
1691
1900
 
1692
1901
  return get
1693
1902
 
1694
- targets_asr = create_targets_asr()
1903
+ sources_asr = create_sources_asr()
1695
1904
 
1696
- def create_target_asr() -> Callable[[str], str]:
1905
+ def create_source_asr() -> Callable[[str], str]:
1697
1906
  state: dict[str, str] = {}
1698
1907
 
1699
1908
  def get(asr_name) -> str:
1700
1909
  nonlocal state
1701
1910
  if asr_name not in state:
1702
- state[asr_name] = calc_asr(target_audio(), **asr_config(asr_name)).text
1911
+ state[asr_name] = calc_asr(source_audio(), **asr_config(asr_name)).text
1703
1912
  return state[asr_name]
1704
1913
 
1705
1914
  return get
1706
1915
 
1707
- target_asr = create_target_asr()
1916
+ source_asr = create_source_asr()
1708
1917
 
1709
1918
  def create_mixture_asr() -> Callable[[str], str]:
1710
1919
  state: dict[str, str] = {}
@@ -1728,11 +1937,11 @@ class MixtureDatabase:
1728
1937
 
1729
1938
  def calc(m: str) -> Any:
1730
1939
  if m == "mxsnr":
1731
- return self.mixture(m_id).snr
1940
+ return {category: source.snr for category, source in self.mixture(m_id).all_sources.items()}
1732
1941
 
1733
1942
  # Get cached data first, if exists
1734
1943
  if not force:
1735
- value = self.read_mixture_data(m_id, m)
1944
+ value = self.read_mixture_data(m_id, m)[m]
1736
1945
  if value is not None:
1737
1946
  return value
1738
1947
 
@@ -1744,8 +1953,8 @@ class MixtureDatabase:
1744
1953
  # noise only, ignore/reset target asr
1745
1954
  return float("nan")
1746
1955
 
1747
- if target_asr(asr_name):
1748
- return calc_wer(mixture_asr(asr_name), target_asr(asr_name)).wer * 100
1956
+ if source_asr(asr_name):
1957
+ return calc_wer(mixture_asr(asr_name), source_asr(asr_name)).wer * 100
1749
1958
 
1750
1959
  # TODO: should this be NaN like above?
1751
1960
  return float(0)
@@ -1753,12 +1962,14 @@ class MixtureDatabase:
1753
1962
  if m.startswith("basewer"):
1754
1963
  asr_name = get_asr_name(m)
1755
1964
 
1756
- text = self.mixture_speech_metadata(m_id, "text")[0]
1757
- if not isinstance(text, str):
1758
- # TODO: should this be NaN like above?
1759
- return [float(0)] * len(targets_audio())
1760
-
1761
- return [calc_wer(t, text).wer * 100 for t in targets_asr(asr_name)]
1965
+ text = self.mixture_speech_metadata(m_id, "text")
1966
+ base_wer: dict[str, float] = {}
1967
+ for category, source in sources_asr(asr_name).items():
1968
+ if isinstance(text[category], str):
1969
+ base_wer[category] = calc_wer(source, str(text[category])).wer * 100
1970
+ else:
1971
+ base_wer[category] = 0
1972
+ return base_wer
1762
1973
 
1763
1974
  if m.startswith("mxasr"):
1764
1975
  return mixture_asr(get_asr_name(m))
@@ -1769,6 +1980,18 @@ class MixtureDatabase:
1769
1980
  if m == "mxssnr_std":
1770
1981
  return calc_segsnr_f(segsnr_f()).std
1771
1982
 
1983
+ if m == "mxssnr_avg_db":
1984
+ val = calc_segsnr_f(segsnr_f()).avg
1985
+ if val is not None:
1986
+ return linear_to_db(val)
1987
+ return None
1988
+
1989
+ if m == "mxssnr_std_db":
1990
+ val = calc_segsnr_f(segsnr_f()).std
1991
+ if val is not None:
1992
+ return linear_to_db(val)
1993
+ return None
1994
+
1772
1995
  if m == "mxssnrdb_avg":
1773
1996
  return calc_segsnr_f(segsnr_f()).db_avg
1774
1997
 
@@ -1776,40 +1999,40 @@ class MixtureDatabase:
1776
1999
  return calc_segsnr_f(segsnr_f()).db_std
1777
2000
 
1778
2001
  if m == "mxssnrf_avg":
1779
- return calc_segsnr_f_bin(target_f(), noise_f()).avg
2002
+ return calc_segsnr_f_bin(source_f(), noise_f()).avg
1780
2003
 
1781
2004
  if m == "mxssnrf_std":
1782
- return calc_segsnr_f_bin(target_f(), noise_f()).std
2005
+ return calc_segsnr_f_bin(source_f(), noise_f()).std
1783
2006
 
1784
2007
  if m == "mxssnrdbf_avg":
1785
- return calc_segsnr_f_bin(target_f(), noise_f()).db_avg
2008
+ return calc_segsnr_f_bin(source_f(), noise_f()).db_avg
1786
2009
 
1787
2010
  if m == "mxssnrdbf_std":
1788
- return calc_segsnr_f_bin(target_f(), noise_f()).db_std
2011
+ return calc_segsnr_f_bin(source_f(), noise_f()).db_std
1789
2012
 
1790
2013
  if m == "mxpesq":
1791
2014
  if self.mixture(m_id).is_noise_only:
1792
- return [0] * len(speech())
1793
- return [s.pesq for s in speech()]
2015
+ return dict.fromkeys(pesq(), 0)
2016
+ return pesq()
1794
2017
 
1795
2018
  if m == "mxcsig":
1796
2019
  if self.mixture(m_id).is_noise_only:
1797
- return [0] * len(speech())
1798
- return [s.csig for s in speech()]
2020
+ return dict.fromkeys(speech(), 0)
2021
+ return {category: s.csig for category, s in speech().items()}
1799
2022
 
1800
2023
  if m == "mxcbak":
1801
2024
  if self.mixture(m_id).is_noise_only:
1802
- return [0] * len(speech())
1803
- return [s.cbak for s in speech()]
2025
+ return dict.fromkeys(speech(), 0)
2026
+ return {category: s.cbak for category, s in speech().items()}
1804
2027
 
1805
2028
  if m == "mxcovl":
1806
2029
  if self.mixture(m_id).is_noise_only:
1807
- return [0] * len(speech())
1808
- return [s.covl for s in speech()]
2030
+ return dict.fromkeys(speech(), 0)
2031
+ return {category: s.covl for category, s in speech().items()}
1809
2032
 
1810
2033
  if m == "mxwsdr":
1811
2034
  mixture = mixture_audio()[:, np.newaxis]
1812
- target = target_audio()[:, np.newaxis]
2035
+ target = source_audio()[:, np.newaxis]
1813
2036
  noise = noise_audio()[:, np.newaxis]
1814
2037
  return calc_wsdr(
1815
2038
  hypothesis=np.concatenate((mixture, noise), axis=1),
@@ -1819,11 +2042,11 @@ class MixtureDatabase:
1819
2042
 
1820
2043
  if m == "mxpd":
1821
2044
  mixture_f = self.mixture_mixture_f(m_id)
1822
- return calc_phase_distance(hypothesis=mixture_f, reference=target_f())[0]
2045
+ return calc_phase_distance(hypothesis=mixture_f, reference=source_f())[0]
1823
2046
 
1824
2047
  if m == "mxstoi":
1825
2048
  return stoi(
1826
- x=target_audio(),
2049
+ x=source_audio(),
1827
2050
  y=mixture_audio(),
1828
2051
  fs_sig=SAMPLE_RATE,
1829
2052
  extended=False,
@@ -1860,70 +2083,70 @@ class MixtureDatabase:
1860
2083
  return mixture_stats().pkc
1861
2084
 
1862
2085
  if m == "mxtdco":
1863
- return target_stats().dco
2086
+ return source_stats().dco
1864
2087
 
1865
2088
  if m == "mxtmin":
1866
- return target_stats().min
2089
+ return source_stats().min
1867
2090
 
1868
2091
  if m == "mxtmax":
1869
- return target_stats().max
2092
+ return source_stats().max
1870
2093
 
1871
2094
  if m == "mxtpkdb":
1872
- return target_stats().pkdb
2095
+ return source_stats().pkdb
1873
2096
 
1874
2097
  if m == "mxtlrms":
1875
- return target_stats().lrms
2098
+ return source_stats().lrms
1876
2099
 
1877
2100
  if m == "mxtpkr":
1878
- return target_stats().pkr
2101
+ return source_stats().pkr
1879
2102
 
1880
2103
  if m == "mxttr":
1881
- return target_stats().tr
2104
+ return source_stats().tr
1882
2105
 
1883
2106
  if m == "mxtcr":
1884
- return target_stats().cr
2107
+ return source_stats().cr
1885
2108
 
1886
2109
  if m == "mxtfl":
1887
- return target_stats().fl
2110
+ return source_stats().fl
1888
2111
 
1889
2112
  if m == "mxtpkc":
1890
- return target_stats().pkc
2113
+ return source_stats().pkc
1891
2114
 
1892
- if m == "tdco":
1893
- return [t.dco for t in targets_stats()]
2115
+ if m == "sdco":
2116
+ return {category: s.dco for category, s in sources_stats().items()}
1894
2117
 
1895
- if m == "tmin":
1896
- return [t.min for t in targets_stats()]
2118
+ if m == "smin":
2119
+ return {category: s.min for category, s in sources_stats().items()}
1897
2120
 
1898
- if m == "tmax":
1899
- return [t.max for t in targets_stats()]
2121
+ if m == "smax":
2122
+ return {category: s.max for category, s in sources_stats().items()}
1900
2123
 
1901
- if m == "tpkdb":
1902
- return [t.pkdb for t in targets_stats()]
2124
+ if m == "spkdb":
2125
+ return {category: s.pkdb for category, s in sources_stats().items()}
1903
2126
 
1904
- if m == "tlrms":
1905
- return [t.lrms for t in targets_stats()]
2127
+ if m == "slrms":
2128
+ return {category: s.lrms for category, s in sources_stats().items()}
1906
2129
 
1907
- if m == "tpkr":
1908
- return [t.pkr for t in targets_stats()]
2130
+ if m == "spkr":
2131
+ return {category: s.pkr for category, s in sources_stats().items()}
1909
2132
 
1910
- if m == "ttr":
1911
- return [t.tr for t in targets_stats()]
2133
+ if m == "str":
2134
+ return {category: s.tr for category, s in sources_stats().items()}
1912
2135
 
1913
- if m == "tcr":
1914
- return [t.cr for t in targets_stats()]
2136
+ if m == "scr":
2137
+ return {category: s.cr for category, s in sources_stats().items()}
1915
2138
 
1916
- if m == "tfl":
1917
- return [t.fl for t in targets_stats()]
2139
+ if m == "sfl":
2140
+ return {category: s.fl for category, s in sources_stats().items()}
1918
2141
 
1919
- if m == "tpkc":
1920
- return [t.pkc for t in targets_stats()]
2142
+ if m == "spkc":
2143
+ return {category: s.pkc for category, s in sources_stats().items()}
1921
2144
 
1922
- if m.startswith("tasr"):
1923
- return targets_asr(get_asr_name(m))
2145
+ if m.startswith("sasr"):
2146
+ return sources_asr(get_asr_name(m))
1924
2147
 
1925
- if m.startswith("mxtasr"):
1926
- return target_asr(get_asr_name(m))
2148
+ if m.startswith("mxsasr"):
2149
+ return source_asr(get_asr_name(m))
1927
2150
 
1928
2151
  if m == "ndco":
1929
2152
  return noise_stats().dco
@@ -2022,82 +2245,85 @@ def __spectral_mask(db: partial, sm_id: int) -> SpectralMask:
2022
2245
  )
2023
2246
 
2024
2247
 
2025
- def _target_file(db: partial, t_id: int, use_cache: bool = True) -> TargetFile:
2026
- """Get target file with ID from db
2248
+ def _num_source_files(db: partial, category: str, use_cache: bool = True) -> int:
2249
+ """Get number of source files from category from db
2027
2250
 
2028
2251
  :param db: Database context
2029
- :param t_id: Target file ID
2252
+ :param category: Source category
2030
2253
  :param use_cache: If true, use LRU caching
2031
- :return: Target file
2254
+ :return: Number of source files
2032
2255
  """
2033
2256
  if use_cache:
2034
- return __target_file(db, t_id, use_cache)
2035
- return __target_file.__wrapped__(db, t_id, use_cache)
2257
+ return __num_source_files(db, category)
2258
+ return __num_source_files.__wrapped__(db, category)
2036
2259
 
2037
2260
 
2038
2261
  @lru_cache
2039
- def __target_file(db: partial, t_id: int, use_cache: bool = True) -> TargetFile:
2040
- """Get target file with ID from db
2262
+ def __num_source_files(db: partial, category: str) -> int:
2263
+ """Get number of source files from category from db
2041
2264
 
2042
2265
  :param db: Database context
2043
- :param t_id: Target file ID
2044
- :param use_cache: If true, use LRU caching
2045
- :return: Target file
2266
+ :param category: Source category
2267
+ :return: Number of source files
2046
2268
  """
2047
- import json
2048
-
2049
- from .db_datatypes import TargetFileRecord
2050
-
2051
2269
  with db() as c:
2052
- target_file = TargetFileRecord(
2053
- *c.execute(
2054
- """
2055
- SELECT *
2056
- FROM target_file
2057
- WHERE ? = target_file.id
2058
- """,
2059
- (t_id,),
2060
- ).fetchone()
2061
- )
2062
-
2063
- return TargetFile(
2064
- name=target_file.name,
2065
- samples=target_file.samples,
2066
- class_indices=json.loads(target_file.class_indices),
2067
- level_type=target_file.level_type,
2068
- truth_configs=_target_truth_configs(db, t_id, use_cache),
2069
- speaker_id=target_file.speaker_id,
2270
+ return int(
2271
+ c.execute(
2272
+ "SELECT count(source_file.id) FROM source_file WHERE ? = source_file.category", (category,)
2273
+ ).fetchone()[0]
2070
2274
  )
2071
2275
 
2072
2276
 
2073
- def _noise_file(db: partial, n_id: int, use_cache: bool = True) -> NoiseFile:
2074
- """Get noise file with ID from db
2277
+ def _source_file(db: partial, s_id: int, use_cache: bool = True) -> SourceFile:
2278
+ """Get source file with ID from db
2075
2279
 
2076
2280
  :param db: Database context
2077
- :param n_id: Noise file ID
2281
+ :param s_id: Source file ID
2078
2282
  :param use_cache: If true, use LRU caching
2079
- :return: Noise file
2283
+ :return: Source file
2080
2284
  """
2081
2285
  if use_cache:
2082
- return __noise_file(db, n_id)
2083
- return __noise_file.__wrapped__(db, n_id)
2286
+ return __source_file(db, s_id, use_cache)
2287
+ return __source_file.__wrapped__(db, s_id, use_cache)
2084
2288
 
2085
2289
 
2086
2290
  @lru_cache
2087
- def __noise_file(db: partial, n_id: int) -> NoiseFile:
2291
+ def __source_file(db: partial, s_id: int, use_cache: bool = True) -> SourceFile:
2292
+ """Get source file with ID from db
2293
+
2294
+ :param db: Database context
2295
+ :param s_id: Source file ID
2296
+ :param use_cache: If true, use LRU caching
2297
+ :return: Source file
2298
+ """
2299
+ import json
2300
+
2301
+ from .db_datatypes import SourceFileRecord
2302
+
2088
2303
  with db() as c:
2089
- noise = c.execute(
2090
- """
2091
- SELECT noise_file.name, samples
2092
- FROM noise_file
2093
- WHERE ? = noise_file.id
2094
- """,
2095
- (n_id,),
2096
- ).fetchone()
2097
- return NoiseFile(name=noise[0], samples=noise[1])
2304
+ source_file = SourceFileRecord(
2305
+ *c.execute(
2306
+ """
2307
+ SELECT *
2308
+ FROM source_file
2309
+ WHERE ? = source_file.id
2310
+ """,
2311
+ (s_id,),
2312
+ ).fetchone()
2313
+ )
2314
+
2315
+ return SourceFile(
2316
+ category=source_file.category,
2317
+ name=source_file.name,
2318
+ samples=source_file.samples,
2319
+ class_indices=json.loads(source_file.class_indices),
2320
+ level_type=source_file.level_type,
2321
+ truth_configs=_source_truth_configs(db, s_id, use_cache),
2322
+ speaker_id=source_file.speaker_id,
2323
+ )
2098
2324
 
2099
2325
 
2100
- def _impulse_response_file(db: partial, ir_id: int, use_cache: bool = True) -> str:
2326
+ def _ir_file(db: partial, ir_id: int, use_cache: bool = True) -> str:
2101
2327
  """Get impulse response file name with ID from db
2102
2328
 
2103
2329
  :param db: Database context
@@ -2106,26 +2332,26 @@ def _impulse_response_file(db: partial, ir_id: int, use_cache: bool = True) -> s
2106
2332
  :return: Impulse response file name
2107
2333
  """
2108
2334
  if use_cache:
2109
- return __impulse_response_file(db, ir_id)
2110
- return __impulse_response_file.__wrapped__(db, ir_id)
2335
+ return __ir_file(db, ir_id)
2336
+ return __ir_file.__wrapped__(db, ir_id)
2111
2337
 
2112
2338
 
2113
2339
  @lru_cache
2114
- def __impulse_response_file(db: partial, ir_id: int) -> str:
2340
+ def __ir_file(db: partial, ir_id: int) -> str:
2115
2341
  with db() as c:
2116
2342
  return str(
2117
2343
  c.execute(
2118
2344
  """
2119
- SELECT impulse_response_file.file
2120
- FROM impulse_response_file
2121
- WHERE ? = impulse_response_file.id
2345
+ SELECT ir_file.name
2346
+ FROM ir_file
2347
+ WHERE ? = ir_file.id
2122
2348
  """,
2123
2349
  (ir_id + 1,),
2124
2350
  ).fetchone()[0]
2125
2351
  )
2126
2352
 
2127
2353
 
2128
- def _impulse_response_delay(db: partial, ir_id: int, use_cache: bool = True) -> int:
2354
+ def _ir_delay(db: partial, ir_id: int, use_cache: bool = True) -> int:
2129
2355
  """Get impulse response delay with ID from db
2130
2356
 
2131
2357
  :param db: Database context
@@ -2134,19 +2360,19 @@ def _impulse_response_delay(db: partial, ir_id: int, use_cache: bool = True) ->
2134
2360
  :return: Impulse response delay
2135
2361
  """
2136
2362
  if use_cache:
2137
- return __impulse_response_delay(db, ir_id)
2138
- return __impulse_response_delay.__wrapped__(db, ir_id)
2363
+ return __ir_delay(db, ir_id)
2364
+ return __ir_delay.__wrapped__(db, ir_id)
2139
2365
 
2140
2366
 
2141
2367
  @lru_cache
2142
- def __impulse_response_delay(db: partial, ir_id: int) -> int:
2368
+ def __ir_delay(db: partial, ir_id: int) -> int:
2143
2369
  with db() as c:
2144
2370
  return int(
2145
2371
  c.execute(
2146
2372
  """
2147
- SELECT impulse_response_file.delay
2148
- FROM impulse_response_file
2149
- WHERE ? = impulse_response_file.id
2373
+ SELECT ir_file.delay
2374
+ FROM ir_file
2375
+ WHERE ? = ir_file.id
2150
2376
  """,
2151
2377
  (ir_id + 1,),
2152
2378
  ).fetchone()[0]
@@ -2169,9 +2395,9 @@ def _mixture(db: partial, m_id: int, use_cache: bool = True) -> Mixture:
2169
2395
  @lru_cache
2170
2396
  def __mixture(db: partial, m_id: int) -> Mixture:
2171
2397
  from .db_datatypes import MixtureRecord
2172
- from .db_datatypes import TargetRecord
2398
+ from .db_datatypes import SourceRecord
2173
2399
  from .helpers import to_mixture
2174
- from .helpers import to_target
2400
+ from .helpers import to_source
2175
2401
 
2176
2402
  with db() as c:
2177
2403
  mixture = MixtureRecord(
@@ -2185,19 +2411,20 @@ def __mixture(db: partial, m_id: int) -> Mixture:
2185
2411
  ).fetchone()
2186
2412
  )
2187
2413
 
2188
- targets = [
2189
- to_target(TargetRecord(*target))
2190
- for target in c.execute(
2191
- """
2192
- SELECT target.*
2193
- FROM target, mixture_target
2194
- WHERE ? = mixture_target.mixture_id AND target.id = mixture_target.target_id
2414
+ sources: Sources = {}
2415
+ for source in c.execute(
2416
+ """
2417
+ SELECT source.*
2418
+ FROM source, mixture_source
2419
+ WHERE ? = mixture_source.mixture_id AND source.id = mixture_source.source_id
2195
2420
  """,
2196
- (mixture.id,),
2197
- ).fetchall()
2198
- ]
2421
+ (mixture.id,),
2422
+ ).fetchall():
2423
+ s = SourceRecord(*source)
2424
+ category = c.execute("SELECT category FROM source_file WHERE ? = id", (s.file_id,)).fetchone()[0]
2425
+ sources[category] = to_source(s)
2199
2426
 
2200
- return to_mixture(mixture, targets)
2427
+ return to_mixture(mixture, sources)
2201
2428
 
2202
2429
 
2203
2430
  def _speaker(db: partial, s_id: int | None, tier: str, use_cache: bool = True) -> str | None:
@@ -2220,27 +2447,62 @@ def __speaker(db: partial, s_id: int | None, tier: str) -> str | None:
2220
2447
  return data[0]
2221
2448
 
2222
2449
 
2223
- def _target_truth_configs(db: partial, t_id: int, use_cache: bool = True) -> TruthConfigs:
2450
+ def _category_truth_configs(db: partial, category: str, use_cache: bool = True) -> dict[str, str]:
2451
+ if use_cache:
2452
+ return __category_truth_configs(db, category)
2453
+ return __category_truth_configs.__wrapped__(db, category)
2454
+
2455
+
2456
+ @lru_cache
2457
+ def __category_truth_configs(db: partial, category: str) -> dict[str, str]:
2458
+ import json
2459
+
2460
+ truth_configs: dict[str, str] = {}
2461
+ with db() as c:
2462
+ s_ids = c.execute(
2463
+ """
2464
+ SELECT id
2465
+ FROM source_file
2466
+ WHERE ? = category
2467
+ """,
2468
+ (category,),
2469
+ ).fetchall()
2470
+
2471
+ for s_id in s_ids:
2472
+ for truth_config_record in c.execute(
2473
+ """
2474
+ SELECT truth_config.config
2475
+ FROM truth_config, source_file_truth_config
2476
+ WHERE ? = source_file_truth_config.source_file_id AND truth_config.id = source_file_truth_config.truth_config_id
2477
+ """,
2478
+ (s_id[0],),
2479
+ ).fetchall():
2480
+ truth_config = json.loads(truth_config_record[0])
2481
+ truth_configs[truth_config["name"]] = truth_config["function"]
2482
+ return truth_configs
2483
+
2484
+
2485
+ def _source_truth_configs(db: partial, s_id: int, use_cache: bool = True) -> TruthConfigs:
2224
2486
  if use_cache:
2225
- return __target_truth_configs(db, t_id)
2226
- return __target_truth_configs.__wrapped__(db, t_id)
2487
+ return __source_truth_configs(db, s_id)
2488
+ return __source_truth_configs.__wrapped__(db, s_id)
2227
2489
 
2228
2490
 
2229
2491
  @lru_cache
2230
- def __target_truth_configs(db: partial, t_id: int) -> TruthConfigs:
2492
+ def __source_truth_configs(db: partial, s_id: int) -> TruthConfigs:
2231
2493
  import json
2232
2494
 
2233
- from .datatypes import TruthConfig
2495
+ from ..datatypes import TruthConfig
2234
2496
 
2235
2497
  truth_configs: TruthConfigs = {}
2236
2498
  with db() as c:
2237
2499
  for truth_config_record in c.execute(
2238
2500
  """
2239
2501
  SELECT truth_config.config
2240
- FROM truth_config, target_file_truth_config
2241
- WHERE ? = target_file_truth_config.target_file_id AND truth_config.id = target_file_truth_config.truth_config_id
2502
+ FROM truth_config, source_file_truth_config
2503
+ WHERE ? = source_file_truth_config.source_file_id AND truth_config.id = source_file_truth_config.truth_config_id
2242
2504
  """,
2243
- (t_id,),
2505
+ (s_id,),
2244
2506
  ).fetchall():
2245
2507
  truth_config = json.loads(truth_config_record[0])
2246
2508
  truth_configs[truth_config["name"]] = TruthConfig(