sonusai 0.19.6__py3-none-any.whl → 0.19.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. sonusai/__init__.py +1 -1
  2. sonusai/aawscd_probwrite.py +1 -1
  3. sonusai/calc_metric_spenh.py +1 -1
  4. sonusai/genft.py +29 -14
  5. sonusai/genmetrics.py +60 -42
  6. sonusai/genmix.py +41 -29
  7. sonusai/genmixdb.py +56 -64
  8. sonusai/metrics/calc_class_weights.py +1 -3
  9. sonusai/metrics/calc_optimal_thresholds.py +2 -2
  10. sonusai/metrics/calc_phase_distance.py +1 -1
  11. sonusai/metrics/calc_speech.py +6 -6
  12. sonusai/metrics/class_summary.py +6 -15
  13. sonusai/metrics/confusion_matrix_summary.py +11 -27
  14. sonusai/metrics/one_hot.py +3 -3
  15. sonusai/metrics/snr_summary.py +7 -7
  16. sonusai/mixture/__init__.py +2 -17
  17. sonusai/mixture/augmentation.py +5 -6
  18. sonusai/mixture/class_count.py +1 -1
  19. sonusai/mixture/config.py +36 -46
  20. sonusai/mixture/data_io.py +30 -1
  21. sonusai/mixture/datatypes.py +29 -40
  22. sonusai/mixture/db_datatypes.py +1 -1
  23. sonusai/mixture/feature.py +3 -23
  24. sonusai/mixture/generation.py +161 -204
  25. sonusai/mixture/helpers.py +29 -187
  26. sonusai/mixture/mixdb.py +386 -159
  27. sonusai/mixture/soundfile_audio.py +1 -1
  28. sonusai/mixture/sox_audio.py +4 -4
  29. sonusai/mixture/sox_augmentation.py +1 -1
  30. sonusai/mixture/target_class_balancing.py +9 -11
  31. sonusai/mixture/targets.py +23 -20
  32. sonusai/mixture/torchaudio_audio.py +18 -7
  33. sonusai/mixture/torchaudio_augmentation.py +3 -4
  34. sonusai/mixture/truth.py +21 -34
  35. sonusai/mixture/truth_functions/__init__.py +6 -0
  36. sonusai/mixture/truth_functions/crm.py +51 -37
  37. sonusai/mixture/truth_functions/energy.py +95 -50
  38. sonusai/mixture/truth_functions/file.py +12 -8
  39. sonusai/mixture/truth_functions/metadata.py +24 -0
  40. sonusai/mixture/truth_functions/metrics.py +28 -0
  41. sonusai/mixture/truth_functions/phoneme.py +4 -5
  42. sonusai/mixture/truth_functions/sed.py +32 -23
  43. sonusai/mixture/truth_functions/target.py +62 -29
  44. sonusai/mkwav.py +20 -19
  45. sonusai/queries/queries.py +9 -15
  46. sonusai/speech/l2arctic.py +6 -2
  47. sonusai/summarize_metric_spenh.py +1 -1
  48. sonusai/utils/__init__.py +1 -0
  49. sonusai/utils/asr_functions/aaware_whisper.py +1 -1
  50. sonusai/utils/audio_devices.py +27 -18
  51. sonusai/utils/docstring.py +6 -3
  52. sonusai/utils/energy_f.py +5 -3
  53. sonusai/utils/human_readable_size.py +6 -6
  54. sonusai/utils/load_object.py +15 -0
  55. sonusai/utils/onnx_utils.py +2 -2
  56. sonusai/utils/print_mixture_details.py +3 -3
  57. {sonusai-0.19.6.dist-info → sonusai-0.19.9.dist-info}/METADATA +2 -2
  58. {sonusai-0.19.6.dist-info → sonusai-0.19.9.dist-info}/RECORD +60 -58
  59. sonusai/mixture/truth_functions/datatypes.py +0 -37
  60. {sonusai-0.19.6.dist-info → sonusai-0.19.9.dist-info}/WHEEL +0 -0
  61. {sonusai-0.19.6.dist-info → sonusai-0.19.9.dist-info}/entry_points.txt +0 -0
sonusai/mixture/mixdb.py CHANGED
@@ -8,27 +8,21 @@ from typing import Any
8
8
 
9
9
  from .datatypes import ASRConfigs
10
10
  from .datatypes import AudioF
11
- from .datatypes import AudiosF
12
- from .datatypes import AudiosT
13
11
  from .datatypes import AudioT
14
12
  from .datatypes import ClassCount
15
13
  from .datatypes import Feature
16
14
  from .datatypes import FeatureGeneratorConfig
17
15
  from .datatypes import FeatureGeneratorInfo
18
16
  from .datatypes import GeneralizedIDs
19
- from .datatypes import ImpulseResponseFiles
17
+ from .datatypes import ImpulseResponseFile
20
18
  from .datatypes import MetricDoc
21
19
  from .datatypes import MetricDocs
22
20
  from .datatypes import Mixture
23
- from .datatypes import Mixtures
24
21
  from .datatypes import NoiseFile
25
- from .datatypes import NoiseFiles
26
22
  from .datatypes import Segsnr
27
23
  from .datatypes import SpectralMask
28
- from .datatypes import SpectralMasks
29
24
  from .datatypes import SpeechMetadata
30
25
  from .datatypes import TargetFile
31
- from .datatypes import TargetFiles
32
26
  from .datatypes import TransformConfig
33
27
  from .datatypes import TruthConfigs
34
28
  from .datatypes import TruthDict
@@ -46,7 +40,13 @@ def db_file(location: str, test: bool = False) -> str:
46
40
  return join(location, name)
47
41
 
48
42
 
49
- def db_connection(location: str, create: bool = False, readonly: bool = True, test: bool = False) -> Connection:
43
+ def db_connection(
44
+ location: str,
45
+ create: bool = False,
46
+ readonly: bool = True,
47
+ test: bool = False,
48
+ verbose: bool = False,
49
+ ) -> Connection:
50
50
  import sqlite3
51
51
  from os import remove
52
52
  from os.path import exists
@@ -62,7 +62,10 @@ def db_connection(location: str, create: bool = False, readonly: bool = True, te
62
62
  name += "?mode=ro"
63
63
 
64
64
  connection = sqlite3.connect("file:" + name, uri=True)
65
- # connection.set_trace_callback(print)
65
+
66
+ if verbose:
67
+ connection.set_trace_callback(print)
68
+
66
69
  return connection
67
70
 
68
71
 
@@ -82,8 +85,30 @@ class SQLiteContextManager:
82
85
 
83
86
  class MixtureDatabase:
84
87
  def __init__(self, location: str, test: bool = False) -> None:
88
+ import json
89
+ from os.path import exists
90
+
91
+ from .config import load_config
92
+
85
93
  self.location = location
86
- self.db = partial(SQLiteContextManager, self.location, test)
94
+ self.test = test
95
+
96
+ if not exists(db_file(self.location, self.test)):
97
+ raise OSError(f"Could not find mixture database in {self.location}")
98
+
99
+ self.db = partial(SQLiteContextManager, self.location, self.test)
100
+
101
+ # Check config.yml to see if asr_configs has changed and update database if needed
102
+ config = load_config(self.location)
103
+ new_asr_configs = json.dumps(config["asr_configs"])
104
+ with self.db() as c:
105
+ old_asr_configs = c.execute("SELECT top.asr_configs FROM top").fetchone()
106
+
107
+ if old_asr_configs is not None and new_asr_configs != old_asr_configs[0]:
108
+ con = db_connection(location=self.location, readonly=False, test=self.test)
109
+ con.execute("UPDATE top SET asr_configs = ? WHERE ? = id", (new_asr_configs,))
110
+ con.commit()
111
+ con.close()
87
112
 
88
113
  @cached_property
89
114
  def json(self) -> str:
@@ -127,10 +152,10 @@ class MixtureDatabase:
127
152
  return get_feature_generator_info(self.fg_config)
128
153
 
129
154
  @cached_property
130
- def truth_parameters(self) -> dict[str, int]:
155
+ def truth_parameters(self) -> dict[str, int | None]:
131
156
  with self.db() as c:
132
157
  rows = c.execute("SELECT * FROM truth_parameters").fetchall()
133
- truth_parameters: dict[str, int] = {}
158
+ truth_parameters: dict[str, int | None] = {}
134
159
  for row in rows:
135
160
  truth_parameters[row[1]] = row[2]
136
161
  return truth_parameters
@@ -197,48 +222,58 @@ class MixtureDatabase:
197
222
  "mxssnrdbf_std",
198
223
  "Per-bin segmental standard deviation of the dB frame values over all frames (using feature transform)",
199
224
  ),
200
- MetricDoc("Mixture Metrics", "mxpesq", "PESQ of mixture versus true target[0]"),
225
+ MetricDoc("Mixture Metrics", "mxpesq", "PESQ of mixture versus true targets"),
201
226
  MetricDoc(
202
227
  "Mixture Metrics",
203
228
  "mxwsdr",
204
- "Weighted signal distorion ratio of mixture versus true target[0]",
229
+ "Weighted signal distortion ratio of mixture versus true targets",
205
230
  ),
206
231
  MetricDoc(
207
232
  "Mixture Metrics",
208
233
  "mxpd",
209
- "Phase distance between mixture and true target[0]",
234
+ "Phase distance between mixture and true targets",
210
235
  ),
211
236
  MetricDoc(
212
237
  "Mixture Metrics",
213
238
  "mxstoi",
214
- "Short term objective intelligibility of mixture versus true target[0]",
239
+ "Short term objective intelligibility of mixture versus true targets",
215
240
  ),
216
241
  MetricDoc(
217
242
  "Mixture Metrics",
218
243
  "mxcsig",
219
- "Predicted rating of speech distortion of mixture versus true target[0]",
244
+ "Predicted rating of speech distortion of mixture versus true targets",
220
245
  ),
221
246
  MetricDoc(
222
247
  "Mixture Metrics",
223
248
  "mxcbak",
224
- "Predicted rating of background distortion of mixture versus true target[0]",
249
+ "Predicted rating of background distortion of mixture versus true targets",
225
250
  ),
226
251
  MetricDoc(
227
252
  "Mixture Metrics",
228
253
  "mxcovl",
229
- "Predicted rating of overall quality of mixture versus true target[0]",
254
+ "Predicted rating of overall quality of mixture versus true targets",
230
255
  ),
231
256
  MetricDoc("Mixture Metrics", "ssnr", "Segmental SNR"),
232
- MetricDoc("Target Metrics", "tdco", "Target[0] DC offset"),
233
- MetricDoc("Target Metrics", "tmin", "Target[0] min level"),
234
- MetricDoc("Target Metrics", "tmax", "Target[0] max levl"),
235
- MetricDoc("Target Metrics", "tpkdb", "Target[0] Pk lev dB"),
236
- MetricDoc("Target Metrics", "tlrms", "Target[0] RMS lev dB"),
237
- MetricDoc("Target Metrics", "tpkr", "Target[0] RMS Pk dB"),
238
- MetricDoc("Target Metrics", "ttr", "Target[0] RMS Tr dB"),
239
- MetricDoc("Target Metrics", "tcr", "Target[0] Crest factor"),
240
- MetricDoc("Target Metrics", "tfl", "Target[0] Flat factor"),
241
- MetricDoc("Target Metrics", "tpkc", "Target[0] Pk count"),
257
+ MetricDoc("Mixture Metrics", "mxtdco", "Mixture target DC offset"),
258
+ MetricDoc("Mixture Metrics", "mxtmin", "Mixture target min level"),
259
+ MetricDoc("Mixture Metrics", "mxtmax", "Mixture target max levl"),
260
+ MetricDoc("Mixture Metrics", "mxtpkdb", "Mixture target Pk lev dB"),
261
+ MetricDoc("Mixture Metrics", "mxtlrms", "Mixture target RMS lev dB"),
262
+ MetricDoc("Mixture Metrics", "mxtpkr", "Mixture target RMS Pk dB"),
263
+ MetricDoc("Mixture Metrics", "mxttr", "Mixture target RMS Tr dB"),
264
+ MetricDoc("Mixture Metrics", "mxtcr", "Mixture target Crest factor"),
265
+ MetricDoc("Mixture Metrics", "mxtfl", "Mixture target Flat factor"),
266
+ MetricDoc("Mixture Metrics", "mxtpkc", "Mixture target Pk count"),
267
+ MetricDoc("Targets Metrics", "tdco", "Targets DC offset"),
268
+ MetricDoc("Targets Metrics", "tmin", "Targets min level"),
269
+ MetricDoc("Targets Metrics", "tmax", "Targets max levl"),
270
+ MetricDoc("Targets Metrics", "tpkdb", "Targets Pk lev dB"),
271
+ MetricDoc("Targets Metrics", "tlrms", "Targets RMS lev dB"),
272
+ MetricDoc("Targets Metrics", "tpkr", "Targets RMS Pk dB"),
273
+ MetricDoc("Targets Metrics", "ttr", "Targets RMS Tr dB"),
274
+ MetricDoc("Targets Metrics", "tcr", "Targets Crest factor"),
275
+ MetricDoc("Targets Metrics", "tfl", "Targets Flat factor"),
276
+ MetricDoc("Targets Metrics", "tpkc", "Targets Pk count"),
242
277
  MetricDoc("Noise Metrics", "ndco", "Noise DC offset"),
243
278
  MetricDoc("Noise Metrics", "nmin", "Noise min level"),
244
279
  MetricDoc("Noise Metrics", "nmax", "Noise max levl"),
@@ -272,11 +307,18 @@ class MixtureDatabase:
272
307
  ]
273
308
  )
274
309
  for name in self.asr_configs:
310
+ metrics.append(
311
+ MetricDoc(
312
+ "Target Metrics",
313
+ f"mxtasr.{name}",
314
+ f"Mixture Target ASR text using {name} ASR as defined in mixdb asr_configs parameter",
315
+ )
316
+ )
275
317
  metrics.append(
276
318
  MetricDoc(
277
319
  "Target Metrics",
278
320
  f"tasr.{name}",
279
- f"Target[0] ASR text using {name} ASR as defined in mixdb asr_configs parameter",
321
+ f"Targets ASR text using {name} ASR as defined in mixdb asr_configs parameter",
280
322
  )
281
323
  )
282
324
  metrics.append(
@@ -486,7 +528,7 @@ class MixtureDatabase:
486
528
  )
487
529
 
488
530
  @cached_property
489
- def spectral_masks(self) -> SpectralMasks:
531
+ def spectral_masks(self) -> list[SpectralMask]:
490
532
  """Get spectral masks from db
491
533
 
492
534
  :return: Spectral masks
@@ -517,7 +559,7 @@ class MixtureDatabase:
517
559
  return _spectral_mask(self.db, sm_id)
518
560
 
519
561
  @cached_property
520
- def target_files(self) -> TargetFiles:
562
+ def target_files(self) -> list[TargetFile]:
521
563
  """Get target files from db
522
564
 
523
565
  :return: Target files
@@ -529,17 +571,19 @@ class MixtureDatabase:
529
571
  from .db_datatypes import TargetFileRecord
530
572
 
531
573
  with self.db() as c:
532
- target_files: TargetFiles = []
574
+ target_files: list[TargetFile] = []
533
575
  target_file_records = [
534
576
  TargetFileRecord(*result) for result in c.execute("SELECT * FROM target_file").fetchall()
535
577
  ]
536
578
  for target_file_record in target_file_records:
537
579
  truth_configs: TruthConfigs = {}
538
580
  for truth_config_records in c.execute(
539
- "SELECT truth_config.config "
540
- + "FROM truth_config, target_file_truth_config "
541
- + "WHERE ? = target_file_truth_config.target_file_id "
542
- + "AND truth_config.id = target_file_truth_config.truth_config_id",
581
+ """
582
+ SELECT truth_config.config
583
+ FROM truth_config, target_file_truth_config
584
+ WHERE ? = target_file_truth_config.target_file_id
585
+ AND truth_config.id = target_file_truth_config.truth_config_id
586
+ """,
543
587
  (target_file_record.id,),
544
588
  ).fetchall():
545
589
  truth_config = json.loads(truth_config_records[0])
@@ -587,7 +631,7 @@ class MixtureDatabase:
587
631
  return int(c.execute("SELECT count(target_file.id) FROM target_file").fetchone()[0])
588
632
 
589
633
  @cached_property
590
- def noise_files(self) -> NoiseFiles:
634
+ def noise_files(self) -> list[NoiseFile]:
591
635
  """Get noise files from db
592
636
 
593
637
  :return: Noise files
@@ -625,7 +669,7 @@ class MixtureDatabase:
625
669
  return int(c.execute("SELECT count(noise_file.id) FROM noise_file").fetchone()[0])
626
670
 
627
671
  @cached_property
628
- def impulse_response_files(self) -> ImpulseResponseFiles:
672
+ def impulse_response_files(self) -> list[ImpulseResponseFile]:
629
673
  """Get impulse response files from db
630
674
 
631
675
  :return: Impulse response files
@@ -635,10 +679,6 @@ class MixtureDatabase:
635
679
  from .datatypes import ImpulseResponseFile
636
680
 
637
681
  with self.db() as c:
638
- # for impulse_response in c.execute(
639
- # "SELECT impulse_response_file.* FROM impulse_response_file"
640
- # ).fetchall():
641
- # print(impulse_response)
642
682
  return [
643
683
  ImpulseResponseFile(impulse_response[1], json.loads(impulse_response[2]))
644
684
  for impulse_response in c.execute(
@@ -678,7 +718,7 @@ class MixtureDatabase:
678
718
  return int(c.execute("SELECT count(impulse_response_file.id) FROM impulse_response_file").fetchone()[0])
679
719
 
680
720
  @cached_property
681
- def mixtures(self) -> Mixtures:
721
+ def mixtures(self) -> list[Mixture]:
682
722
  """Get mixtures from db
683
723
 
684
724
  :return: Mixtures
@@ -689,13 +729,16 @@ class MixtureDatabase:
689
729
  from .helpers import to_target
690
730
 
691
731
  with self.db() as c:
692
- mixtures: Mixtures = []
732
+ mixtures: list[Mixture] = []
693
733
  for mixture in [MixtureRecord(*record) for record in c.execute("SELECT * FROM mixture").fetchall()]:
694
734
  targets = [
695
735
  to_target(TargetRecord(*target))
696
736
  for target in c.execute(
697
- "SELECT target.* FROM target, mixture_target "
698
- + "WHERE ? = mixture_target.mixture_id AND target.id = mixture_target.target_id",
737
+ """
738
+ SELECT target.*
739
+ FROM target, mixture_target
740
+ WHERE ? = mixture_target.mixture_id AND target.id = mixture_target.target_id
741
+ """,
699
742
  (mixture.id,),
700
743
  ).fetchall()
701
744
  ]
@@ -744,7 +787,7 @@ class MixtureDatabase:
744
787
  return int(c.execute("SELECT count(mixture.id) FROM mixture").fetchone()[0])
745
788
 
746
789
  def read_mixture_data(self, m_id: int, items: list[str] | str) -> Any:
747
- """Read mixture data from a mixture HDF5 file
790
+ """Read mixture data
748
791
 
749
792
  :param m_id: Zero-based mixture ID
750
793
  :param items: String(s) of dataset(s) to retrieve
@@ -792,7 +835,7 @@ class MixtureDatabase:
792
835
  class_indices.extend(self.target_file(t_id).class_indices)
793
836
  return sorted(set(class_indices))
794
837
 
795
- def mixture_targets(self, m_id: int, force: bool = False) -> AudiosT:
838
+ def mixture_targets(self, m_id: int, force: bool = False) -> list[AudioT]:
796
839
  """Get the list of augmented target audio data (one per target in the mixup) for the given mixture ID
797
840
 
798
841
  :param m_id: Zero-based mixture ID
@@ -826,7 +869,7 @@ class MixtureDatabase:
826
869
 
827
870
  return targets_audio
828
871
 
829
- def mixture_targets_f(self, m_id: int, targets: AudiosT | None = None, force: bool = False) -> AudiosF:
872
+ def mixture_targets_f(self, m_id: int, targets: list[AudioT] | None = None, force: bool = False) -> list[AudioF]:
830
873
  """Get the list of augmented target transform data (one per target in the mixup) for the given mixture ID
831
874
 
832
875
  :param m_id: Zero-based mixture ID
@@ -841,7 +884,7 @@ class MixtureDatabase:
841
884
 
842
885
  return [forward_transform(target, self.ft_config) for target in targets]
843
886
 
844
- def mixture_target(self, m_id: int, targets: AudiosT | None = None, force: bool = False) -> AudioT:
887
+ def mixture_target(self, m_id: int, targets: list[AudioT] | None = None, force: bool = False) -> AudioT:
845
888
  """Get the augmented target audio data for the given mixture ID
846
889
 
847
890
  :param m_id: Zero-based mixture ID
@@ -864,7 +907,7 @@ class MixtureDatabase:
864
907
  def mixture_target_f(
865
908
  self,
866
909
  m_id: int,
867
- targets: AudiosT | None = None,
910
+ targets: list[AudioT] | None = None,
868
911
  target: AudioT | None = None,
869
912
  force: bool = False,
870
913
  ) -> AudioF:
@@ -900,7 +943,7 @@ class MixtureDatabase:
900
943
 
901
944
  mixture = self.mixture(m_id)
902
945
  noise = self.augmented_noise_audio(mixture)
903
- noise = get_next_noise(audio=noise, offset=mixture.noise.offset, length=mixture.samples)
946
+ noise = get_next_noise(audio=noise, offset=mixture.noise_offset, length=mixture.samples)
904
947
  return apply_gain(audio=noise, gain=mixture.noise_snr_gain)
905
948
 
906
949
  def mixture_noise_f(self, m_id: int, noise: AudioT | None = None, force: bool = False) -> AudioF:
@@ -921,7 +964,7 @@ class MixtureDatabase:
921
964
  def mixture_mixture(
922
965
  self,
923
966
  m_id: int,
924
- targets: AudiosT | None = None,
967
+ targets: list[AudioT] | None = None,
925
968
  target: AudioT | None = None,
926
969
  noise: AudioT | None = None,
927
970
  force: bool = False,
@@ -951,7 +994,7 @@ class MixtureDatabase:
951
994
  def mixture_mixture_f(
952
995
  self,
953
996
  m_id: int,
954
- targets: AudiosT | None = None,
997
+ targets: list[AudioT] | None = None,
955
998
  target: AudioT | None = None,
956
999
  noise: AudioT | None = None,
957
1000
  mixture: AudioT | None = None,
@@ -988,11 +1031,11 @@ class MixtureDatabase:
988
1031
  def mixture_truth_t(
989
1032
  self,
990
1033
  m_id: int,
991
- targets: AudiosT | None = None,
1034
+ targets: list[AudioT] | None = None,
992
1035
  noise: AudioT | None = None,
993
1036
  mixture: AudioT | None = None,
994
1037
  force: bool = False,
995
- ) -> TruthDict:
1038
+ ) -> list[TruthDict]:
996
1039
  """Get the truth_t data for the given mixture ID
997
1040
 
998
1041
  :param m_id: Zero-based mixture ID
@@ -1000,9 +1043,9 @@ class MixtureDatabase:
1000
1043
  :param noise: Augmented noise audio data for the given mixture ID
1001
1044
  :param mixture: Mixture audio data for the given mixture ID
1002
1045
  :param force: Force computing data from original sources regardless of whether cached data exists
1003
- :return: truth_t data
1046
+ :return: list of truth_t data
1004
1047
  """
1005
- from .helpers import get_truth
1048
+ from .truth import truth_function
1006
1049
 
1007
1050
  if not force:
1008
1051
  truth_t = self.read_mixture_data(m_id, "truth_t")
@@ -1018,12 +1061,18 @@ class MixtureDatabase:
1018
1061
  if force or mixture is None:
1019
1062
  mixture = self.mixture_mixture(m_id, targets=targets, noise=noise, force=force)
1020
1063
 
1021
- return get_truth(self, self.mixture(m_id), targets, noise, mixture)
1064
+ if not all(len(target) == self.mixture(m_id).samples for target in targets):
1065
+ raise ValueError("Lengths of targets do not match length of mixture")
1066
+
1067
+ if len(noise) != self.mixture(m_id).samples:
1068
+ raise ValueError("Length of noise does not match length of mixture")
1069
+
1070
+ return truth_function(self, m_id)
1022
1071
 
1023
1072
  def mixture_segsnr_t(
1024
1073
  self,
1025
1074
  m_id: int,
1026
- targets: AudiosT | None = None,
1075
+ targets: list[AudioT] | None = None,
1027
1076
  target: AudioT | None = None,
1028
1077
  noise: AudioT | None = None,
1029
1078
  force: bool = False,
@@ -1037,7 +1086,9 @@ class MixtureDatabase:
1037
1086
  :param force: Force computing data from original sources regardless of whether cached data exists
1038
1087
  :return: segsnr_t data
1039
1088
  """
1040
- from .helpers import get_segsnr_t
1089
+ import numpy as np
1090
+ import torch
1091
+ from pyaaware import ForwardTransform
1041
1092
 
1042
1093
  if not force:
1043
1094
  segsnr_t = self.read_mixture_data(m_id, "segsnr_t")
@@ -1050,13 +1101,45 @@ class MixtureDatabase:
1050
1101
  if force or noise is None:
1051
1102
  noise = self.mixture_noise(m_id, force)
1052
1103
 
1053
- return get_segsnr_t(self, self.mixture(m_id), target, noise)
1104
+ ft = ForwardTransform(
1105
+ length=self.ft_config.length,
1106
+ overlap=self.ft_config.overlap,
1107
+ bin_start=self.ft_config.bin_start,
1108
+ bin_end=self.ft_config.bin_end,
1109
+ ttype=self.ft_config.ttype,
1110
+ )
1111
+
1112
+ mixture = self.mixture(m_id)
1113
+
1114
+ segsnr_t = np.empty(mixture.samples, dtype=np.float32)
1115
+
1116
+ target_energy = ft.execute_all(torch.from_numpy(target))[1].numpy()
1117
+ noise_energy = ft.execute_all(torch.from_numpy(noise))[1].numpy()
1118
+
1119
+ offsets = range(0, mixture.samples, self.ft_config.overlap)
1120
+ if len(target_energy) != len(offsets):
1121
+ raise ValueError(
1122
+ f"Number of frames in energy, {len(target_energy)},"
1123
+ f" is not number of frames in mixture, {len(offsets)}"
1124
+ )
1125
+
1126
+ for idx, offset in enumerate(offsets):
1127
+ indices = slice(offset, offset + self.ft_config.overlap)
1128
+
1129
+ if noise_energy[idx] == 0:
1130
+ snr = np.float32(np.inf)
1131
+ else:
1132
+ snr = np.float32(target_energy[idx] / noise_energy[idx])
1133
+
1134
+ segsnr_t[indices] = snr
1135
+
1136
+ return segsnr_t
1054
1137
 
1055
1138
  def mixture_segsnr(
1056
1139
  self,
1057
1140
  m_id: int,
1058
1141
  segsnr_t: Segsnr | None = None,
1059
- targets: AudiosT | None = None,
1142
+ targets: list[AudioT] | None = None,
1060
1143
  target: AudioT | None = None,
1061
1144
  noise: AudioT | None = None,
1062
1145
  force: bool = False,
@@ -1088,12 +1171,12 @@ class MixtureDatabase:
1088
1171
  def mixture_ft(
1089
1172
  self,
1090
1173
  m_id: int,
1091
- targets: AudiosT | None = None,
1174
+ targets: list[AudioT] | None = None,
1092
1175
  target: AudioT | None = None,
1093
1176
  noise: AudioT | None = None,
1094
1177
  mixture_f: AudioF | None = None,
1095
1178
  mixture: AudioT | None = None,
1096
- truth_t: TruthDict | None = None,
1179
+ truth_t: list[TruthDict] | None = None,
1097
1180
  force: bool = False,
1098
1181
  ) -> tuple[Feature, TruthDict]:
1099
1182
  """Get the feature and truth_f data for the given mixture ID
@@ -1132,19 +1215,24 @@ class MixtureDatabase:
1132
1215
 
1133
1216
  fg = FeatureGenerator(self.fg_config.feature_mode, self.fg_config.truth_parameters)
1134
1217
 
1135
- feature, truth_f = fg.execute_all(mixture_f, truth_t)
1136
- for key in self.truth_configs:
1137
- truth_f[key] = truth_stride_reduction(truth_f[key], self.truth_configs[key].stride_reduction)
1218
+ # TODO: handle mixup in truth_t
1219
+ feature, truth_f = fg.execute_all(mixture_f, truth_t[0])
1220
+ if truth_f is not None:
1221
+ for key in self.truth_configs:
1222
+ if self.truth_parameters[key] is not None:
1223
+ truth_f[key] = truth_stride_reduction(truth_f[key], self.truth_configs[key].stride_reduction)
1224
+ else:
1225
+ raise TypeError("Unexpected truth of None from feature generator")
1138
1226
 
1139
1227
  return feature, truth_f
1140
1228
 
1141
1229
  def mixture_feature(
1142
1230
  self,
1143
1231
  m_id: int,
1144
- targets: AudiosT | None = None,
1232
+ targets: list[AudioT] | None = None,
1145
1233
  noise: AudioT | None = None,
1146
1234
  mixture: AudioT | None = None,
1147
- truth_t: TruthDict | None = None,
1235
+ truth_t: list[TruthDict] | None = None,
1148
1236
  force: bool = False,
1149
1237
  ) -> Feature:
1150
1238
  """Get the feature data for the given mixture ID
@@ -1170,10 +1258,10 @@ class MixtureDatabase:
1170
1258
  def mixture_truth_f(
1171
1259
  self,
1172
1260
  m_id: int,
1173
- targets: AudiosT | None = None,
1261
+ targets: list[AudioT] | None = None,
1174
1262
  noise: AudioT | None = None,
1175
1263
  mixture: AudioT | None = None,
1176
- truth_t: TruthDict | None = None,
1264
+ truth_t: list[TruthDict] | None = None,
1177
1265
  force: bool = False,
1178
1266
  ) -> TruthDict:
1179
1267
  """Get the truth_f data for the given mixture ID
@@ -1199,9 +1287,9 @@ class MixtureDatabase:
1199
1287
  def mixture_class_count(
1200
1288
  self,
1201
1289
  m_id: int,
1202
- targets: AudiosT | None = None,
1290
+ targets: list[AudioT] | None = None,
1203
1291
  noise: AudioT | None = None,
1204
- truth_t: TruthDict | None = None,
1292
+ truth_t: list[TruthDict] | None = None,
1205
1293
  ) -> ClassCount:
1206
1294
  """Compute the number of frames for which each class index is active for the given mixture ID
1207
1295
 
@@ -1220,7 +1308,8 @@ class MixtureDatabase:
1220
1308
  num_classes = self.num_classes
1221
1309
  if "sed" in self.truth_configs:
1222
1310
  for cl in range(num_classes):
1223
- class_count[cl] = int(np.sum(truth_t["sed"][:, cl] >= self.class_weights_thresholds[cl]))
1311
+ # TODO: handle mixup in truth_t
1312
+ class_count[cl] = int(np.sum(truth_t[0]["sed"][:, cl] >= self.class_weights_thresholds[cl]))
1224
1313
 
1225
1314
  return class_count
1226
1315
 
@@ -1300,7 +1389,12 @@ class MixtureDatabase:
1300
1389
 
1301
1390
  return results
1302
1391
 
1303
- def mixids_for_speech_metadata(self, tier: str, value: str | None = None, where: str | None = None) -> list[int]:
1392
+ def mixids_for_speech_metadata(
1393
+ self,
1394
+ tier: str | None = None,
1395
+ value: str | None = None,
1396
+ where: str | None = None,
1397
+ ) -> list[int]:
1304
1398
  """Get a list of mixture IDs for the given speech metadata tier.
1305
1399
 
1306
1400
  If 'where' is None, then include mixture IDs whose tier values are equal to the given 'value'.
@@ -1310,35 +1404,35 @@ class MixtureDatabase:
1310
1404
  Examples:
1311
1405
  >>> mixdb = MixtureDatabase('/mixdb_location')
1312
1406
 
1313
- >>> mixids = mixdb.mixids_for_speech_metadata('speaker_id', 'TIMIT_ARC0')
1314
- Get mixutre IDs for mixtures with speakers whose speaker_ids are 'TIMIT_ARC0'.
1407
+ >>> mixids = mixdb.mixids_for_speech_metadata('speaker_id', 'TIMIT_ABW0')
1408
+ Get mixture IDs for mixtures with speakers whose speaker_ids are 'TIMIT_ABW0'.
1315
1409
 
1316
- >>> mixids = mixdb.mixids_for_speech_metadata('age', where='age < 25')
1317
- Get mixture IDs for mixtures with speakers whose ages are less than 25.
1410
+ >>> mixids = mixdb.mixids_for_speech_metadata(where='age >= 27')
1411
+ Get mixture IDs for mixtures with speakers whose ages are greater than or equal to 27.
1318
1412
 
1319
- >>> mixids = mixdb.mixids_for_speech_metadata('dialect', where="dialect in ('New York City', 'Northern')")
1413
+ >>> mixids = mixdb.mixids_for_speech_metadata(where="dialect in ('New York City', 'Northern')")
1320
1414
  Get mixture IDs for mixtures with speakers whose dialects are either 'New York City' or 'Northern'.
1321
1415
  """
1322
1416
  if value is None and where is None:
1323
1417
  raise ValueError("Must provide either value or where")
1324
1418
 
1325
1419
  if where is None:
1420
+ if tier is None:
1421
+ raise ValueError("Must provide tier")
1326
1422
  where = f"{tier} = '{value}'"
1327
1423
 
1328
- if tier in self.textgrid_metadata_tiers:
1424
+ if tier is not None and tier in self.textgrid_metadata_tiers:
1329
1425
  raise ValueError(f"TextGrid tier data, '{tier}', is not supported in mixids_for_speech_metadata().")
1330
1426
 
1331
1427
  with self.db() as c:
1332
- speaker_ids = [
1333
- speaker_id[0] for speaker_id in c.execute(f"SELECT id FROM speaker WHERE {where}").fetchall()
1334
- ]
1335
- results = c.execute(
1336
- "SELECT id FROM target_file " + f"WHERE speaker_id IN ({','.join(map(str, speaker_ids))})"
1337
- ).fetchall()
1338
- target_file_ids = [target_file_id[0] for target_file_id in results]
1428
+ results = c.execute(f"SELECT id FROM speaker WHERE {where}").fetchall()
1429
+ speaker_ids = ",".join(map(str, [i[0] for i in results]))
1430
+
1431
+ results = c.execute(f"SELECT id FROM target_file WHERE speaker_id IN ({speaker_ids})").fetchall()
1432
+ target_file_ids = ",".join(map(str, [i[0] for i in results]))
1433
+
1339
1434
  results = c.execute(
1340
- "SELECT mixture_id FROM mixture_target "
1341
- + f"WHERE mixture_target.target_id IN ({','.join(map(str, target_file_ids))})"
1435
+ f"SELECT mixture_id FROM mixture_target WHERE mixture_target.target_id IN ({target_file_ids})"
1342
1436
  ).fetchall()
1343
1437
 
1344
1438
  return [mixture_id[0] - 1 for mixture_id in results]
@@ -1348,9 +1442,29 @@ class MixtureDatabase:
1348
1442
 
1349
1443
  return mixture_all_speech_metadata(self, self.mixture(m_id))
1350
1444
 
1351
- def mixture_metrics(
1352
- self, m_id: int, metrics: list[str], force: bool = False
1353
- ) -> list[float | int | str | Segsnr | None]:
1445
+ def cached_metrics(self, m_ids: GeneralizedIDs = "*") -> list[str]:
1446
+ """Get list of cached metrics for all mixtures."""
1447
+ from glob import glob
1448
+ from os.path import join
1449
+ from pathlib import Path
1450
+
1451
+ supported_metrics = self.supported_metrics.names
1452
+ first = True
1453
+ result: set[str] = set()
1454
+ for m_id in self.mixids_to_list(m_ids):
1455
+ mixture_dir = join(self.location, "mixture", self.mixture(m_id).name)
1456
+ found = {Path(f).stem for f in glob(join(mixture_dir, "*.pkl"))}
1457
+ if first:
1458
+ first = False
1459
+ for f in found:
1460
+ if f in supported_metrics:
1461
+ result.add(f)
1462
+ else:
1463
+ result = result & found
1464
+
1465
+ return sorted(result)
1466
+
1467
+ def mixture_metrics(self, m_id: int, metrics: list[str], force: bool = False) -> list[Any]:
1354
1468
  """Get metrics data for the given mixture ID
1355
1469
 
1356
1470
  :param m_id: Zero-based mixture ID
@@ -1375,10 +1489,23 @@ class MixtureDatabase:
1375
1489
  from sonusai.mixture import SpeechMetrics
1376
1490
  from sonusai.utils import calc_asr
1377
1491
 
1378
- def create_target_audio() -> Callable[[], np.ndarray]:
1379
- state = None
1492
+ def create_targets_audio() -> Callable[[], list[AudioT]]:
1493
+ state: list[AudioT] | None = None
1494
+
1495
+ def get() -> list[AudioT]:
1496
+ nonlocal state
1497
+ if state is None:
1498
+ state = self.mixture_targets(m_id)
1499
+ return state
1500
+
1501
+ return get
1502
+
1503
+ targets_audio = create_targets_audio()
1504
+
1505
+ def create_target_audio() -> Callable[[], AudioT]:
1506
+ state: AudioT | None = None
1380
1507
 
1381
- def get() -> np.ndarray:
1508
+ def get() -> AudioT:
1382
1509
  nonlocal state
1383
1510
  if state is None:
1384
1511
  state = self.mixture_target(m_id)
@@ -1388,10 +1515,10 @@ class MixtureDatabase:
1388
1515
 
1389
1516
  target_audio = create_target_audio()
1390
1517
 
1391
- def create_target_f() -> Callable[[], np.ndarray]:
1392
- state = None
1518
+ def create_target_f() -> Callable[[], AudioF]:
1519
+ state: AudioF | None = None
1393
1520
 
1394
- def get() -> np.ndarray:
1521
+ def get() -> AudioF:
1395
1522
  nonlocal state
1396
1523
  if state is None:
1397
1524
  state = self.mixture_targets_f(m_id)[0]
@@ -1401,10 +1528,10 @@ class MixtureDatabase:
1401
1528
 
1402
1529
  target_f = create_target_f()
1403
1530
 
1404
- def create_noise_audio() -> Callable[[], np.ndarray]:
1405
- state = None
1531
+ def create_noise_audio() -> Callable[[], AudioT]:
1532
+ state: AudioT | None = None
1406
1533
 
1407
- def get() -> np.ndarray:
1534
+ def get() -> AudioT:
1408
1535
  nonlocal state
1409
1536
  if state is None:
1410
1537
  state = self.mixture_noise(m_id)
@@ -1414,10 +1541,10 @@ class MixtureDatabase:
1414
1541
 
1415
1542
  noise_audio = create_noise_audio()
1416
1543
 
1417
- def create_noise_f() -> Callable[[], np.ndarray]:
1418
- state = None
1544
+ def create_noise_f() -> Callable[[], AudioF]:
1545
+ state: AudioF | None = None
1419
1546
 
1420
- def get() -> np.ndarray:
1547
+ def get() -> AudioF:
1421
1548
  nonlocal state
1422
1549
  if state is None:
1423
1550
  state = self.mixture_noise_f(m_id)
@@ -1427,10 +1554,10 @@ class MixtureDatabase:
1427
1554
 
1428
1555
  noise_f = create_noise_f()
1429
1556
 
1430
- def create_mixture_audio() -> Callable[[], np.ndarray]:
1431
- state = None
1557
+ def create_mixture_audio() -> Callable[[], AudioT]:
1558
+ state: AudioT | None = None
1432
1559
 
1433
- def get() -> np.ndarray:
1560
+ def get() -> AudioT:
1434
1561
  nonlocal state
1435
1562
  if state is None:
1436
1563
  state = self.mixture_mixture(m_id)
@@ -1440,10 +1567,10 @@ class MixtureDatabase:
1440
1567
 
1441
1568
  mixture_audio = create_mixture_audio()
1442
1569
 
1443
- def create_segsnr_f() -> Callable[[], np.ndarray]:
1444
- state = None
1570
+ def create_segsnr_f() -> Callable[[], Segsnr]:
1571
+ state: Segsnr | None = None
1445
1572
 
1446
- def get() -> np.ndarray:
1573
+ def get() -> Segsnr:
1447
1574
  nonlocal state
1448
1575
  if state is None:
1449
1576
  state = self.mixture_segsnr(m_id)
@@ -1453,21 +1580,38 @@ class MixtureDatabase:
1453
1580
 
1454
1581
  segsnr_f = create_segsnr_f()
1455
1582
 
1456
- def create_speech() -> Callable[[], SpeechMetrics]:
1457
- state = None
1583
+ def create_speech() -> Callable[[], list[SpeechMetrics]]:
1584
+ state: list[SpeechMetrics] | None = None
1458
1585
 
1459
- def get() -> SpeechMetrics:
1586
+ def get() -> list[SpeechMetrics]:
1460
1587
  nonlocal state
1461
1588
  if state is None:
1462
- state = calc_speech(hypothesis=mixture_audio(), reference=target_audio())
1589
+ state = []
1590
+ for audio in targets_audio():
1591
+ state.append(calc_speech(hypothesis=mixture_audio(), reference=audio))
1463
1592
  return state
1464
1593
 
1465
1594
  return get
1466
1595
 
1467
1596
  speech = create_speech()
1468
1597
 
1598
+ def create_targets_stats() -> Callable[[], list[AudioStatsMetrics]]:
1599
+ state: list[AudioStatsMetrics] | None = None
1600
+
1601
+ def get() -> list[AudioStatsMetrics]:
1602
+ nonlocal state
1603
+ if state is None:
1604
+ state = []
1605
+ for audio in targets_audio():
1606
+ state.append(calc_audio_stats(audio, self.fg_info.ft_config.length / SAMPLE_RATE))
1607
+ return state
1608
+
1609
+ return get
1610
+
1611
+ targets_stats = create_targets_stats()
1612
+
1469
1613
  def create_target_stats() -> Callable[[], AudioStatsMetrics]:
1470
- state = None
1614
+ state: AudioStatsMetrics | None = None
1471
1615
 
1472
1616
  def get() -> AudioStatsMetrics:
1473
1617
  nonlocal state
@@ -1480,7 +1624,7 @@ class MixtureDatabase:
1480
1624
  target_stats = create_target_stats()
1481
1625
 
1482
1626
  def create_noise_stats() -> Callable[[], AudioStatsMetrics]:
1483
- state = None
1627
+ state: AudioStatsMetrics | None = None
1484
1628
 
1485
1629
  def get() -> AudioStatsMetrics:
1486
1630
  nonlocal state
@@ -1508,6 +1652,21 @@ class MixtureDatabase:
1508
1652
 
1509
1653
  asr_config = create_asr_config()
1510
1654
 
1655
+ def create_targets_asr() -> Callable[[str], list[str]]:
1656
+ state: dict[str, list[str]] = {}
1657
+
1658
+ def get(asr_name) -> list[str]:
1659
+ nonlocal state
1660
+ if asr_name not in state:
1661
+ state[asr_name] = []
1662
+ for audio in targets_audio():
1663
+ state[asr_name].append(calc_asr(audio, **asr_config(asr_name)).text)
1664
+ return state[asr_name]
1665
+
1666
+ return get
1667
+
1668
+ targets_asr = create_targets_asr()
1669
+
1511
1670
  def create_target_asr() -> Callable[[str], str]:
1512
1671
  state: dict[str, str] = {}
1513
1672
 
@@ -1541,7 +1700,7 @@ class MixtureDatabase:
1541
1700
  asr_name = parts[1]
1542
1701
  return asr_name
1543
1702
 
1544
- def calc(m: str) -> float | int | str | Segsnr | None:
1703
+ def calc(m: str) -> Any:
1545
1704
  if m == "mxsnr":
1546
1705
  return self.mixture(m_id).snr
1547
1706
 
@@ -1555,7 +1714,7 @@ class MixtureDatabase:
1555
1714
  if m.startswith("mxwer"):
1556
1715
  asr_name = get_asr_name(m)
1557
1716
 
1558
- if self.mixture(m_id).snr < -96:
1717
+ if self.mixture(m_id).is_noise_only:
1559
1718
  # noise only, ignore/reset target asr
1560
1719
  return float("nan")
1561
1720
 
@@ -1569,11 +1728,11 @@ class MixtureDatabase:
1569
1728
  asr_name = get_asr_name(m)
1570
1729
 
1571
1730
  text = self.mixture_speech_metadata(m_id, "text")[0]
1572
- if text is not None:
1573
- return calc_wer(target_asr(asr_name), text).wer * 100
1731
+ if not isinstance(text, str):
1732
+ # TODO: should this be NaN like above?
1733
+ return [float(0)] * len(targets_audio())
1574
1734
 
1575
- # TODO: should this be NaN like above?
1576
- return float(0)
1735
+ return [calc_wer(t, text).wer * 100 for t in targets_asr(asr_name)]
1577
1736
 
1578
1737
  if m.startswith("mxasr"):
1579
1738
  return mixture_asr(get_asr_name(m))
@@ -1603,24 +1762,24 @@ class MixtureDatabase:
1603
1762
  return calc_segsnr_f_bin(target_f(), noise_f()).db_std
1604
1763
 
1605
1764
  if m == "mxpesq":
1606
- if self.mixture(m_id).snr < -96:
1607
- return 0
1608
- return speech().pesq
1765
+ if self.mixture(m_id).is_noise_only:
1766
+ return [0] * len(speech())
1767
+ return [s.pesq for s in speech()]
1609
1768
 
1610
1769
  if m == "mxcsig":
1611
- if self.mixture(m_id).snr < -96:
1612
- return 0
1613
- return speech().csig
1770
+ if self.mixture(m_id).is_noise_only:
1771
+ return [0] * len(speech())
1772
+ return [s.csig for s in speech()]
1614
1773
 
1615
1774
  if m == "mxcbak":
1616
- if self.mixture(m_id).snr < -96:
1617
- return 0
1618
- return speech().cbak
1775
+ if self.mixture(m_id).is_noise_only:
1776
+ return [0] * len(speech())
1777
+ return [s.cbak for s in speech()]
1619
1778
 
1620
1779
  if m == "mxcovl":
1621
- if self.mixture(m_id).snr < -96:
1622
- return 0
1623
- return speech().covl
1780
+ if self.mixture(m_id).is_noise_only:
1781
+ return [0] * len(speech())
1782
+ return [s.covl for s in speech()]
1624
1783
 
1625
1784
  if m == "mxwsdr":
1626
1785
  mixture = mixture_audio()[:, np.newaxis]
@@ -1644,37 +1803,70 @@ class MixtureDatabase:
1644
1803
  extended=False,
1645
1804
  )
1646
1805
 
1647
- if m == "tdco":
1806
+ if m == "mxtdco":
1648
1807
  return target_stats().dco
1649
1808
 
1650
- if m == "tmin":
1809
+ if m == "mxtmin":
1651
1810
  return target_stats().min
1652
1811
 
1653
- if m == "tmax":
1812
+ if m == "mxtmax":
1654
1813
  return target_stats().max
1655
1814
 
1656
- if m == "tpkdb":
1815
+ if m == "mxtpkdb":
1657
1816
  return target_stats().pkdb
1658
1817
 
1659
- if m == "tlrms":
1818
+ if m == "mxtlrms":
1660
1819
  return target_stats().lrms
1661
1820
 
1662
- if m == "tpkr":
1821
+ if m == "mxtpkr":
1663
1822
  return target_stats().pkr
1664
1823
 
1665
- if m == "ttr":
1824
+ if m == "mxttr":
1666
1825
  return target_stats().tr
1667
1826
 
1668
- if m == "tcr":
1827
+ if m == "mxtcr":
1669
1828
  return target_stats().cr
1670
1829
 
1671
- if m == "tfl":
1830
+ if m == "mxtfl":
1672
1831
  return target_stats().fl
1673
1832
 
1674
- if m == "tpkc":
1833
+ if m == "mxtpkc":
1675
1834
  return target_stats().pkc
1676
1835
 
1836
+ if m == "tdco":
1837
+ return [t.dco for t in targets_stats()]
1838
+
1839
+ if m == "tmin":
1840
+ return [t.min for t in targets_stats()]
1841
+
1842
+ if m == "tmax":
1843
+ return [t.max for t in targets_stats()]
1844
+
1845
+ if m == "tpkdb":
1846
+ return [t.pkdb for t in targets_stats()]
1847
+
1848
+ if m == "tlrms":
1849
+ return [t.lrms for t in targets_stats()]
1850
+
1851
+ if m == "tpkr":
1852
+ return [t.pkr for t in targets_stats()]
1853
+
1854
+ if m == "ttr":
1855
+ return [t.tr for t in targets_stats()]
1856
+
1857
+ if m == "tcr":
1858
+ return [t.cr for t in targets_stats()]
1859
+
1860
+ if m == "tfl":
1861
+ return [t.fl for t in targets_stats()]
1862
+
1863
+ if m == "tpkc":
1864
+ return [t.pkc for t in targets_stats()]
1865
+
1677
1866
  if m.startswith("tasr"):
1867
+ return targets_asr(get_asr_name(m))
1868
+
1869
+ if m.startswith("mxtasr"):
1678
1870
  return target_asr(get_asr_name(m))
1679
1871
 
1680
1872
  if m == "ndco":
@@ -1743,7 +1935,14 @@ def _spectral_mask(db: partial, sm_id: int) -> SpectralMask:
1743
1935
 
1744
1936
  with db() as c:
1745
1937
  spectral_mask = SpectralMaskRecord(
1746
- *c.execute("SELECT * FROM spectral_mask WHERE ? = spectral_mask.id", (sm_id,)).fetchone()
1938
+ *c.execute(
1939
+ """
1940
+ SELECT *
1941
+ FROM spectral_mask
1942
+ WHERE ? = spectral_mask.id
1943
+ """,
1944
+ (sm_id,),
1945
+ ).fetchone()
1747
1946
  )
1748
1947
  return SpectralMask(
1749
1948
  f_max_width=spectral_mask.f_max_width,
@@ -1768,7 +1967,14 @@ def _target_file(db: partial, t_id: int) -> TargetFile:
1768
1967
 
1769
1968
  with db() as c:
1770
1969
  target_file = TargetFileRecord(
1771
- *c.execute("SELECT * FROM target_file WHERE ? = target_file.id", (t_id,)).fetchone()
1970
+ *c.execute(
1971
+ """
1972
+ SELECT *
1973
+ FROM target_file
1974
+ WHERE ? = target_file.id
1975
+ """,
1976
+ (t_id,),
1977
+ ).fetchone()
1772
1978
  )
1773
1979
 
1774
1980
  return TargetFile(
@@ -1791,7 +1997,11 @@ def _noise_file(db: partial, n_id: int) -> NoiseFile:
1791
1997
  """
1792
1998
  with db() as c:
1793
1999
  noise = c.execute(
1794
- "SELECT noise_file.name, samples FROM noise_file WHERE ? = noise_file.id",
2000
+ """
2001
+ SELECT noise_file.name, samples
2002
+ FROM noise_file
2003
+ WHERE ? = noise_file.id
2004
+ """,
1795
2005
  (n_id,),
1796
2006
  ).fetchone()
1797
2007
  return NoiseFile(name=noise[0], samples=noise[1])
@@ -1808,7 +2018,11 @@ def _impulse_response_file(db: partial, ir_id: int) -> str:
1808
2018
  with db() as c:
1809
2019
  return str(
1810
2020
  c.execute(
1811
- "SELECT impulse_response_file.file FROM impulse_response_file WHERE ? = impulse_response_file.id",
2021
+ """
2022
+ SELECT impulse_response_file.file
2023
+ FROM impulse_response_file
2024
+ WHERE ? = impulse_response_file.id
2025
+ """,
1812
2026
  (ir_id + 1,),
1813
2027
  ).fetchone()[0]
1814
2028
  )
@@ -1828,13 +2042,25 @@ def _mixture(db: partial, m_id: int) -> Mixture:
1828
2042
  from .helpers import to_target
1829
2043
 
1830
2044
  with db() as c:
1831
- mixture = MixtureRecord(*c.execute("SELECT * FROM mixture WHERE ? = mixture.id", (m_id + 1,)).fetchone())
2045
+ mixture = MixtureRecord(
2046
+ *c.execute(
2047
+ """
2048
+ SELECT *
2049
+ FROM mixture
2050
+ WHERE ? = mixture.id
2051
+ """,
2052
+ (m_id + 1,),
2053
+ ).fetchone()
2054
+ )
2055
+
1832
2056
  targets = [
1833
2057
  to_target(TargetRecord(*target))
1834
2058
  for target in c.execute(
1835
- "SELECT target.* "
1836
- + "FROM target, mixture_target "
1837
- + "WHERE ? = mixture_target.mixture_id AND target.id = mixture_target.target_id",
2059
+ """
2060
+ SELECT target.*
2061
+ FROM target, mixture_target
2062
+ WHERE ? = mixture_target.mixture_id AND target.id = mixture_target.target_id
2063
+ """,
1838
2064
  (mixture.id,),
1839
2065
  ).fetchall()
1840
2066
  ]
@@ -1865,10 +2091,11 @@ def _target_truth_configs(db: partial, t_id: int) -> TruthConfigs:
1865
2091
  truth_configs: TruthConfigs = {}
1866
2092
  with db() as c:
1867
2093
  for truth_config_record in c.execute(
1868
- "SELECT truth_config.config "
1869
- + "FROM truth_config, target_file_truth_config "
1870
- + "WHERE ? = target_file_truth_config.target_file_id "
1871
- + "AND truth_config.id = target_file_truth_config.truth_config_id",
2094
+ """
2095
+ SELECT truth_config.config
2096
+ FROM truth_config, target_file_truth_config
2097
+ WHERE ? = target_file_truth_config.target_file_id AND truth_config.id = target_file_truth_config.truth_config_id
2098
+ """,
1872
2099
  (t_id,),
1873
2100
  ).fetchall():
1874
2101
  truth_config = json.loads(truth_config_record[0])