sonusai 1.0.16__cp311-abi3-macosx_10_12_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150):
  1. sonusai/__init__.py +170 -0
  2. sonusai/aawscd_probwrite.py +148 -0
  3. sonusai/audiofe.py +481 -0
  4. sonusai/calc_metric_spenh.py +1136 -0
  5. sonusai/config/__init__.py +0 -0
  6. sonusai/config/asr.py +21 -0
  7. sonusai/config/config.py +65 -0
  8. sonusai/config/config.yml +49 -0
  9. sonusai/config/constants.py +53 -0
  10. sonusai/config/ir.py +124 -0
  11. sonusai/config/ir_delay.py +62 -0
  12. sonusai/config/source.py +275 -0
  13. sonusai/config/spectral_masks.py +15 -0
  14. sonusai/config/truth.py +64 -0
  15. sonusai/constants.py +14 -0
  16. sonusai/data/__init__.py +0 -0
  17. sonusai/data/silero_vad_v5.1.jit +0 -0
  18. sonusai/data/silero_vad_v5.1.onnx +0 -0
  19. sonusai/data/speech_ma01_01.wav +0 -0
  20. sonusai/data/whitenoise.wav +0 -0
  21. sonusai/datatypes.py +383 -0
  22. sonusai/deprecated/gentcst.py +632 -0
  23. sonusai/deprecated/plot.py +519 -0
  24. sonusai/deprecated/tplot.py +365 -0
  25. sonusai/doc.py +52 -0
  26. sonusai/doc_strings/__init__.py +1 -0
  27. sonusai/doc_strings/doc_strings.py +531 -0
  28. sonusai/genft.py +196 -0
  29. sonusai/genmetrics.py +183 -0
  30. sonusai/genmix.py +199 -0
  31. sonusai/genmixdb.py +235 -0
  32. sonusai/ir_metric.py +551 -0
  33. sonusai/lsdb.py +141 -0
  34. sonusai/main.py +134 -0
  35. sonusai/metrics/__init__.py +43 -0
  36. sonusai/metrics/calc_audio_stats.py +42 -0
  37. sonusai/metrics/calc_class_weights.py +90 -0
  38. sonusai/metrics/calc_optimal_thresholds.py +73 -0
  39. sonusai/metrics/calc_pcm.py +45 -0
  40. sonusai/metrics/calc_pesq.py +36 -0
  41. sonusai/metrics/calc_phase_distance.py +43 -0
  42. sonusai/metrics/calc_sa_sdr.py +64 -0
  43. sonusai/metrics/calc_sample_weights.py +25 -0
  44. sonusai/metrics/calc_segsnr_f.py +82 -0
  45. sonusai/metrics/calc_speech.py +382 -0
  46. sonusai/metrics/calc_wer.py +71 -0
  47. sonusai/metrics/calc_wsdr.py +57 -0
  48. sonusai/metrics/calculate_metrics.py +395 -0
  49. sonusai/metrics/class_summary.py +74 -0
  50. sonusai/metrics/confusion_matrix_summary.py +75 -0
  51. sonusai/metrics/one_hot.py +283 -0
  52. sonusai/metrics/snr_summary.py +128 -0
  53. sonusai/metrics_summary.py +314 -0
  54. sonusai/mixture/__init__.py +15 -0
  55. sonusai/mixture/audio.py +187 -0
  56. sonusai/mixture/class_balancing.py +103 -0
  57. sonusai/mixture/constants.py +3 -0
  58. sonusai/mixture/data_io.py +173 -0
  59. sonusai/mixture/db.py +169 -0
  60. sonusai/mixture/db_datatypes.py +92 -0
  61. sonusai/mixture/effects.py +344 -0
  62. sonusai/mixture/feature.py +78 -0
  63. sonusai/mixture/generation.py +1116 -0
  64. sonusai/mixture/helpers.py +351 -0
  65. sonusai/mixture/ir_effects.py +77 -0
  66. sonusai/mixture/log_duration_and_sizes.py +23 -0
  67. sonusai/mixture/mixdb.py +1857 -0
  68. sonusai/mixture/pad_audio.py +35 -0
  69. sonusai/mixture/resample.py +7 -0
  70. sonusai/mixture/sox_effects.py +195 -0
  71. sonusai/mixture/sox_help.py +650 -0
  72. sonusai/mixture/spectral_mask.py +51 -0
  73. sonusai/mixture/truth.py +61 -0
  74. sonusai/mixture/truth_functions/__init__.py +45 -0
  75. sonusai/mixture/truth_functions/crm.py +105 -0
  76. sonusai/mixture/truth_functions/energy.py +222 -0
  77. sonusai/mixture/truth_functions/file.py +48 -0
  78. sonusai/mixture/truth_functions/metadata.py +24 -0
  79. sonusai/mixture/truth_functions/metrics.py +28 -0
  80. sonusai/mixture/truth_functions/phoneme.py +18 -0
  81. sonusai/mixture/truth_functions/sed.py +98 -0
  82. sonusai/mixture/truth_functions/target.py +142 -0
  83. sonusai/mkwav.py +135 -0
  84. sonusai/onnx_predict.py +363 -0
  85. sonusai/parse/__init__.py +0 -0
  86. sonusai/parse/expand.py +156 -0
  87. sonusai/parse/parse_source_directive.py +129 -0
  88. sonusai/parse/rand.py +214 -0
  89. sonusai/py.typed +0 -0
  90. sonusai/queries/__init__.py +0 -0
  91. sonusai/queries/queries.py +239 -0
  92. sonusai/rs.abi3.so +0 -0
  93. sonusai/rs.pyi +1 -0
  94. sonusai/rust/__init__.py +0 -0
  95. sonusai/speech/__init__.py +0 -0
  96. sonusai/speech/l2arctic.py +121 -0
  97. sonusai/speech/librispeech.py +102 -0
  98. sonusai/speech/mcgill.py +71 -0
  99. sonusai/speech/textgrid.py +89 -0
  100. sonusai/speech/timit.py +138 -0
  101. sonusai/speech/types.py +12 -0
  102. sonusai/speech/vctk.py +53 -0
  103. sonusai/speech/voxceleb.py +108 -0
  104. sonusai/utils/__init__.py +3 -0
  105. sonusai/utils/asl_p56.py +130 -0
  106. sonusai/utils/asr.py +91 -0
  107. sonusai/utils/asr_functions/__init__.py +3 -0
  108. sonusai/utils/asr_functions/aaware_whisper.py +69 -0
  109. sonusai/utils/audio_devices.py +50 -0
  110. sonusai/utils/braced_glob.py +50 -0
  111. sonusai/utils/calculate_input_shape.py +26 -0
  112. sonusai/utils/choice.py +51 -0
  113. sonusai/utils/compress.py +25 -0
  114. sonusai/utils/convert_string_to_number.py +6 -0
  115. sonusai/utils/create_timestamp.py +5 -0
  116. sonusai/utils/create_ts_name.py +14 -0
  117. sonusai/utils/dataclass_from_dict.py +27 -0
  118. sonusai/utils/db.py +16 -0
  119. sonusai/utils/docstring.py +53 -0
  120. sonusai/utils/energy_f.py +44 -0
  121. sonusai/utils/engineering_number.py +166 -0
  122. sonusai/utils/evaluate_random_rule.py +15 -0
  123. sonusai/utils/get_frames_per_batch.py +2 -0
  124. sonusai/utils/get_label_names.py +20 -0
  125. sonusai/utils/grouper.py +6 -0
  126. sonusai/utils/human_readable_size.py +7 -0
  127. sonusai/utils/keyboard_interrupt.py +12 -0
  128. sonusai/utils/load_object.py +21 -0
  129. sonusai/utils/max_text_width.py +9 -0
  130. sonusai/utils/model_utils.py +28 -0
  131. sonusai/utils/numeric_conversion.py +11 -0
  132. sonusai/utils/onnx_utils.py +155 -0
  133. sonusai/utils/parallel.py +162 -0
  134. sonusai/utils/path_info.py +7 -0
  135. sonusai/utils/print_mixture_details.py +60 -0
  136. sonusai/utils/rand.py +13 -0
  137. sonusai/utils/ranges.py +43 -0
  138. sonusai/utils/read_predict_data.py +32 -0
  139. sonusai/utils/reshape.py +154 -0
  140. sonusai/utils/seconds_to_hms.py +7 -0
  141. sonusai/utils/stacked_complex.py +82 -0
  142. sonusai/utils/stratified_shuffle_split.py +170 -0
  143. sonusai/utils/tokenized_shell_vars.py +143 -0
  144. sonusai/utils/write_audio.py +26 -0
  145. sonusai/utils/yes_or_no.py +8 -0
  146. sonusai/vars.py +47 -0
  147. sonusai-1.0.16.dist-info/METADATA +56 -0
  148. sonusai-1.0.16.dist-info/RECORD +150 -0
  149. sonusai-1.0.16.dist-info/WHEEL +4 -0
  150. sonusai-1.0.16.dist-info/entry_points.txt +3 -0
@@ -0,0 +1,1857 @@
1
+ # ruff: noqa: S608
2
+ from functools import cached_property
3
+ from functools import lru_cache
4
+ from functools import partial
5
+ from typing import Any
6
+
7
+ from ..datatypes import ASRConfigs
8
+ from ..datatypes import AudioF
9
+ from ..datatypes import AudioT
10
+ from ..datatypes import ClassCount
11
+ from ..datatypes import Feature
12
+ from ..datatypes import FeatureGeneratorConfig
13
+ from ..datatypes import FeatureGeneratorInfo
14
+ from ..datatypes import GeneralizedIDs
15
+ from ..datatypes import ImpulseResponseFile
16
+ from ..datatypes import MetricDoc
17
+ from ..datatypes import MetricDocs
18
+ from ..datatypes import Mixture
19
+ from ..datatypes import Segsnr
20
+ from ..datatypes import SourceFile
21
+ from ..datatypes import Sources
22
+ from ..datatypes import SourcesAudioF
23
+ from ..datatypes import SourcesAudioT
24
+ from ..datatypes import SpectralMask
25
+ from ..datatypes import SpeechMetadata
26
+ from ..datatypes import TransformConfig
27
+ from ..datatypes import TruthConfigs
28
+ from ..datatypes import TruthDict
29
+ from ..datatypes import TruthsConfigs
30
+ from ..datatypes import TruthsDict
31
+ from ..datatypes import UniversalSNR
32
+ from .db import SQLiteDatabase
33
+ from .db import db_file
34
+
35
+
36
+ class MixtureDatabase:
37
def __init__(
    self,
    location: str,
    test: bool = False,
    verbose: bool = False,
    use_cache: bool = True,
) -> None:
    """Open a handle on the mixture database stored at *location*.

    :param location: Directory containing the mixture database
    :param test: Forwarded to ``db_file`` and ``SQLiteDatabase`` (presumably selects a test database)
    :param verbose: Forwarded to ``SQLiteDatabase``
    :param use_cache: Forwarded to lookup helpers such as ``_mixture`` and ``read_audio``
    """
    self.location = location
    self.test = test
    self.db_path = db_file(location=location, test=test)
    self.verbose = verbose
    self.use_cache = use_cache

    # Factory: calling self.db() opens a SQLiteDatabase context bound to this database
    self.db = partial(SQLiteDatabase, location=location, test=test, verbose=verbose)

    # Keep the asr_configs stored in the database in sync with config.yml
    self.update_asr_configs()
54
+
55
def update_asr_configs(self) -> None:
    """Update the asr_configs column in the top table with the current asr_configs in the config.yml file.

    Loads ``config.yml`` from ``self.location``, JSON-encodes its ``asr_configs`` entry and,
    if that differs from the value stored in the ``top`` table, writes the new value back.
    Nothing is written when the ``top`` table has no row.
    """
    import json

    from ..config.config import load_config

    config = load_config(self.location)
    new_asr_configs = json.dumps(config["asr_configs"])
    with SQLiteDatabase(
        location=self.location,
        readonly=False,
        test=self.test,
        verbose=self.verbose,
    ) as c:
        # Read the row id together with the stored value so the UPDATE targets exactly that row.
        row = c.execute("SELECT id, asr_configs FROM top").fetchone()

        if row is not None and new_asr_configs != row[1]:
            # Bind one value per placeholder. The previous statement had two placeholders but
            # supplied a single value, which sqlite3 rejects ("Incorrect number of bindings").
            c.execute("UPDATE top SET asr_configs = ? WHERE id = ?", (new_asr_configs, row[0]))
74
+
75
@cached_property
def json(self) -> str:
    """JSON rendering of the whole database configuration (computed once per instance).

    :return: ``MixtureDatabaseConfig`` serialized with ``to_json(indent=2)``
    """
    from ..datatypes import MixtureDatabaseConfig

    fields = {
        "asr_configs": self.asr_configs,
        "class_balancing": self.class_balancing,
        "class_labels": self.class_labels,
        "class_weights_threshold": self.class_weights_thresholds,
        "feature": self.feature,
        "ir_files": self.ir_files,
        "mixtures": self.mixtures,
        "num_classes": self.num_classes,
        "spectral_masks": self.spectral_masks,
        "source_files": self.source_files,
    }
    return MixtureDatabaseConfig(**fields).to_json(indent=2)

def save(self) -> None:
    """Write ``self.json`` to ``mixdb.json`` inside ``self.location`` (overwriting it)."""
    from os.path import join

    target = join(self.location, "mixdb.json")
    with open(file=target, mode="w") as handle:
        handle.write(self.json)
100
+
101
@cached_property
def fg_config(self) -> FeatureGeneratorConfig:
    """Feature generator configuration assembled from the feature mode and truth parameters."""
    mode = self.feature
    params = self.truth_parameters
    return FeatureGeneratorConfig(feature_mode=mode, truth_parameters=params)

@cached_property
def fg_info(self) -> FeatureGeneratorInfo:
    """Feature generator info derived from ``fg_config`` via ``get_feature_generator_info``."""
    from .helpers import get_feature_generator_info

    config = self.fg_config
    return get_feature_generator_info(config)
113
+
114
+ @cached_property
115
+ def truth_parameters(self) -> dict[str, dict[str, int | None]]:
116
+ with self.db() as c:
117
+ rows = c.execute("SELECT category, name, parameters FROM truth_parameters").fetchall()
118
+ truth_parameters: dict[str, dict[str, int | None]] = {}
119
+ for row in rows:
120
+ category, name, parameters = row
121
+ if category not in truth_parameters:
122
+ truth_parameters[category] = {}
123
+ truth_parameters[category][name] = parameters
124
+ return truth_parameters
125
+
126
@cached_property
def num_classes(self) -> int:
    """Number of classes, read from the ``top`` table."""
    with self.db() as c:
        row = c.execute("SELECT num_classes FROM top").fetchone()
    return int(row[0])
130
+
131
@cached_property
def asr_configs(self) -> ASRConfigs:
    """ASR configurations, decoded from the JSON text stored in the ``top`` table."""
    import json

    with self.db() as c:
        encoded = c.execute("SELECT asr_configs FROM top").fetchone()[0]
    return json.loads(encoded)
137
+
138
@cached_property
def supported_metrics(self) -> MetricDocs:
    """Catalogue of metric names that can be requested for this database.

    Holds a fixed set of mixture, sources, source, noise and truth metrics. For every ASR
    configuration name in ``asr_configs`` it adds five entries: ``mxsasr``, ``sasr``,
    ``mxasr``, ``basewer`` and ``mxwer``.

    :return: MetricDocs of (category, name, description) entries
    """
    mixture = "Mixture Metrics"
    fixed_rows = [
        (mixture, "mxsnr", "SNR specification in dB"),
        (mixture, "mxssnr_avg", "Segmental SNR average over all frames"),
        (mixture, "mxssnr_std", "Segmental SNR standard deviation over all frames"),
        (mixture, "mxssnrdb_avg", "Segmental SNR average of the dB frame values over all frames"),
        (mixture, "mxssnrdb_std", "Segmental SNR standard deviation of the dB frame values over all frames"),
        (mixture, "mxssnrf_avg", "Per-bin segmental SNR average over all frames (using feature transform)"),
        (
            mixture,
            "mxssnrf_std",
            "Per-bin segmental SNR standard deviation over all frames (using feature transform)",
        ),
        (
            mixture,
            "mxssnrdbf_avg",
            "Per-bin segmental average of the dB frame values over all frames (using feature transform)",
        ),
        (
            mixture,
            "mxssnrdbf_std",
            "Per-bin segmental standard deviation of the dB frame values over all frames (using feature transform)",
        ),
        (mixture, "mxpesq", "PESQ of mixture versus true sources"),
        (mixture, "mxwsdr", "Weighted signal distortion ratio of mixture versus true sources"),
        (mixture, "mxpd", "Phase distance between mixture and true sources"),
        (mixture, "mxstoi", "Short term objective intelligibility of mixture versus true sources"),
        (mixture, "mxcsig", "Predicted rating of speech distortion of mixture versus true sources"),
        (mixture, "mxcbak", "Predicted rating of background distortion of mixture versus true sources"),
        (mixture, "mxcovl", "Predicted rating of overall quality of mixture versus true sources"),
        (mixture, "ssnr", "Segmental SNR"),
    ]

    # Level statistics repeated for the mixture, sources, source and noise groups as
    # (metric-name suffix, description suffix). The "levl" spelling is reproduced
    # verbatim because these descriptions are runtime strings.
    level_stats = [
        ("dco", "DC offset"),
        ("min", "min level"),
        ("max", "max levl"),
        ("pkdb", "Pk lev dB"),
        ("lrms", "RMS lev dB"),
        ("pkr", "RMS Pk dB"),
        ("tr", "RMS Tr dB"),
        ("cr", "Crest factor"),
        ("fl", "Flat factor"),
        ("pkc", "Pk count"),
    ]
    for prefix, label in (("mx", "Mixture"), ("s", "Sources"), ("mxs", "Source"), ("n", "Noise")):
        fixed_rows.extend(
            (f"{label} Metrics", f"{prefix}{suffix}", f"{label} {text}") for suffix, text in level_stats
        )

    truth = "Truth Metrics"
    fixed_rows += [
        (truth, "sedavg", "(not implemented) Average SED activity over all frames [truth_parameters, 1]"),
        (
            truth,
            "sedcnt",
            "(not implemented) Count in number of frames that SED is active [truth_parameters, 1]",
        ),
        (truth, "sedtop3", "(not implemented) 3 most active by largest sedavg [3, 1]"),
        (truth, "sedtopn", "(not implemented) N most active by largest sedavg [N, 1]"),
    ]

    metrics = MetricDocs([MetricDoc(*row) for row in fixed_rows])

    for name in self.asr_configs:
        asr_rows = (
            (
                "Source Metrics",
                f"mxsasr.{name}",
                f"Source ASR text using {name} ASR as defined in mixdb asr_configs parameter",
            ),
            (
                "Sources Metrics",
                f"sasr.{name}",
                f"Sources ASR text using {name} ASR as defined in mixdb asr_configs parameter",
            ),
            (
                "Mixture Metrics",
                f"mxasr.{name}",
                f"ASR text using {name} ASR as defined in mixdb asr_configs parameter",
            ),
            (
                "Sources Metrics",
                f"basewer.{name}",
                f"Word error rate of sasr.{name} vs. speech text metadata for the source",
            ),
            (
                "Mixture Metrics",
                f"mxwer.{name}",
                f"Word error rate of mxasr.{name} vs. sasr.{name}",
            ),
        )
        for category, metric, description in asr_rows:
            metrics.append(MetricDoc(category, metric, description))

    return metrics
315
+
316
@cached_property
def class_balancing(self) -> bool:
    """Whether class balancing is enabled, as stored in the ``top`` table."""
    with self.db() as c:
        value = c.execute("SELECT class_balancing FROM top").fetchone()[0]
    return bool(value)

@cached_property
def feature(self) -> str:
    """Feature mode string stored in the ``top`` table."""
    with self.db() as c:
        value = c.execute("SELECT feature FROM top").fetchone()[0]
    return str(value)
325
+
326
@cached_property
def fg_decimation(self) -> int:
    """Decimation reported by ``fg_info``."""
    info = self.fg_info
    return info.decimation

@cached_property
def fg_stride(self) -> int:
    """Stride reported by ``fg_info``."""
    info = self.fg_info
    return info.stride

@cached_property
def fg_step(self) -> int:
    """Step reported by ``fg_info``."""
    info = self.fg_info
    return info.step

@cached_property
def feature_parameters(self) -> int:
    """Feature parameter count reported by ``fg_info``."""
    info = self.fg_info
    return info.feature_parameters
341
+
342
@cached_property
def ft_config(self) -> TransformConfig:
    """``ft_config`` transform configuration taken from ``fg_info``."""
    info = self.fg_info
    return info.ft_config

@cached_property
def eft_config(self) -> TransformConfig:
    """``eft_config`` transform configuration taken from ``fg_info``."""
    info = self.fg_info
    return info.eft_config

@cached_property
def it_config(self) -> TransformConfig:
    """``it_config`` transform configuration taken from ``fg_info``."""
    info = self.fg_info
    return info.it_config
353
+
354
@cached_property
def transform_frame_ms(self) -> float:
    """Duration in milliseconds of ``ft_config.overlap`` samples at ``SAMPLE_RATE``."""
    from ..constants import SAMPLE_RATE

    overlap = float(self.ft_config.overlap)
    samples_per_ms = float(SAMPLE_RATE / 1000)
    return overlap / samples_per_ms
359
+
360
@cached_property
def feature_ms(self) -> float:
    """Feature frame duration: transform frame ms scaled by decimation and stride."""
    frame_ms = self.transform_frame_ms
    return frame_ms * self.fg_decimation * self.fg_stride

@cached_property
def feature_samples(self) -> int:
    """Feature frame length in samples: transform overlap scaled by decimation and stride."""
    overlap = self.ft_config.overlap
    return overlap * self.fg_decimation * self.fg_stride

@cached_property
def feature_step_ms(self) -> float:
    """Feature step duration: transform frame ms scaled by decimation and step."""
    frame_ms = self.transform_frame_ms
    return frame_ms * self.fg_decimation * self.fg_step

@cached_property
def feature_step_samples(self) -> int:
    """Feature step length in samples: transform overlap scaled by decimation and step."""
    overlap = self.ft_config.overlap
    return overlap * self.fg_decimation * self.fg_step
375
+
376
def total_samples(self, m_ids: GeneralizedIDs = "*") -> int:
    """Sum of ``samples`` over the mixtures selected by *m_ids*."""
    selected = self.mixids_to_list(m_ids)
    return sum(self.mixture(m_id).samples for m_id in selected)

def total_transform_frames(self, m_ids: GeneralizedIDs = "*") -> int:
    """Total samples of the selected mixtures floor-divided by ``ft_config.overlap``."""
    samples = self.total_samples(m_ids)
    return samples // self.ft_config.overlap

def total_feature_frames(self, m_ids: GeneralizedIDs = "*") -> int:
    """Total samples of the selected mixtures floor-divided by ``feature_step_samples``."""
    samples = self.total_samples(m_ids)
    return samples // self.feature_step_samples
384
+
385
def mixture_transform_frames(self, m_id: int) -> int:
    """Number of transform frames in a mixture.

    :param m_id: Zero-based mixture ID
    :return: ``frames_from_samples`` of the mixture length with ``ft_config.overlap``
    """
    from .helpers import frames_from_samples

    samples = self.mixture(m_id).samples
    return frames_from_samples(samples, self.ft_config.overlap)

def mixture_feature_frames(self, m_id: int) -> int:
    """Number of feature frames in a mixture.

    :param m_id: Zero-based mixture ID
    :return: ``frames_from_samples`` of the mixture length with ``feature_step_samples``
    """
    from .helpers import frames_from_samples

    samples = self.mixture(m_id).samples
    return frames_from_samples(samples, self.feature_step_samples)

def mixids_to_list(self, m_ids: GeneralizedIDs = "*") -> list[int]:
    """Expand generalized mixture IDs into a concrete list of integers.

    :param m_ids: Generalized mixture IDs (default ``"*"``)
    :return: List of mixture ID integers, resolved against ``num_mixtures``
    """
    from .helpers import generic_ids_to_list

    count = self.num_mixtures
    return generic_ids_to_list(count, m_ids)
404
+
405
@cached_property
def class_labels(self) -> list[str]:
    """Class labels from the ``class_label`` table, ordered by id.

    :return: Class labels
    """
    with self.db() as c:
        rows = c.execute("SELECT label FROM class_label ORDER BY id").fetchall()
    return [str(label) for (label,) in rows]

@cached_property
def class_weights_thresholds(self) -> list[float]:
    """Thresholds from the ``class_weights_threshold`` table (query has no ORDER BY).

    :return: Class weights thresholds
    """
    with self.db() as c:
        rows = c.execute("SELECT threshold FROM class_weights_threshold").fetchall()
    return [float(threshold) for (threshold,) in rows]
422
+
423
def category_truth_configs(self, category: str) -> dict[str, str]:
    """Truth configs for a source category, looked up via ``_category_truth_configs``."""
    db, use_cache = self.db, self.use_cache
    return _category_truth_configs(db, category, use_cache)

def source_truth_configs(self, s_id: int) -> TruthConfigs:
    """Truth configs of the source file *s_id*, looked up via ``_source_truth_configs``."""
    db, use_cache = self.db, self.use_cache
    return _source_truth_configs(db, s_id, use_cache)

def mixture_truth_configs(self, m_id: int) -> TruthsConfigs:
    """Truth configs for every source in a mixture, keyed by source category."""
    all_sources = self.mixture(m_id).all_sources
    configs: TruthsConfigs = {}
    for category in all_sources:
        configs[category] = self.source_truth_configs(all_sources[category].file_id)
    return configs
435
+
436
@cached_property
def random_snrs(self) -> list[float]:
    """Distinct SNR values of sources flagged ``snr_random == 1``.

    :return: Random SNRs (order follows set iteration, not sorted)
    """
    with self.db() as c:
        rows = c.execute("SELECT snr FROM source WHERE snr_random == 1").fetchall()
    unique = {float(row[0]) for row in rows}
    return list(unique)

@cached_property
def snrs(self) -> list[float]:
    """Distinct SNR values of sources flagged ``snr_random == 0``.

    :return: SNRs (order follows set iteration, not sorted)
    """
    with self.db() as c:
        rows = c.execute("SELECT snr FROM source WHERE snr_random == 0").fetchall()
    unique = {float(row[0]) for row in rows}
    return list(unique)
457
+
458
@cached_property
def all_snrs(self) -> list[UniversalSNR]:
    """Sorted, de-duplicated union of fixed and random SNRs as ``UniversalSNR`` values."""
    fixed = [UniversalSNR(is_random=False, value=snr) for snr in self.snrs]
    randomized = [UniversalSNR(is_random=True, value=snr) for snr in self.random_snrs]
    return sorted({*fixed, *randomized})
466
+
467
@cached_property
def spectral_masks(self) -> list[SpectralMask]:
    """All rows of the ``spectral_mask`` table converted to ``SpectralMask`` objects.

    :return: Spectral masks
    """
    from .db_datatypes import SpectralMaskRecord

    with self.db() as c:
        rows = c.execute("SELECT * FROM spectral_mask").fetchall()

    masks: list[SpectralMask] = []
    for row in rows:
        record = SpectralMaskRecord(*row)
        masks.append(
            SpectralMask(
                f_max_width=record.f_max_width,
                f_num=record.f_num,
                t_max_width=record.t_max_width,
                t_num=record.t_num,
                t_max_percent=record.t_max_percent,
            )
        )
    return masks

def spectral_mask(self, sm_id: int) -> SpectralMask:
    """Spectral mask with the given ID, looked up via ``_spectral_mask``.

    :param sm_id: Spectral mask ID
    :return: Spectral mask
    """
    db, use_cache = self.db, self.use_cache
    return _spectral_mask(db, sm_id, use_cache)
497
+
498
@cached_property
def source_files(self) -> dict[str, list[SourceFile]]:
    """All source files grouped by category, each with its decoded truth configs.

    :return: Mapping of category -> list of ``SourceFile``
    """
    import json

    from ..datatypes import TruthConfig
    from .db_datatypes import SourceFileRecord

    truth_query = """
        SELECT truth_config.config
        FROM truth_config, source_file_truth_config
        WHERE ? = source_file_truth_config.source_file_id
        AND truth_config.id = source_file_truth_config.truth_config_id
        """

    with self.db() as c:
        grouped: dict[str, list[SourceFile]] = {}
        for (category,) in c.execute("SELECT DISTINCT category FROM source_file").fetchall():
            grouped[category] = []
            records = [
                SourceFileRecord(*row)
                for row in c.execute("SELECT * FROM source_file WHERE ? = category", (category,)).fetchall()
            ]
            for record in records:
                # truth_config.config holds a JSON object with name/function/stride_reduction/config keys
                truth_configs: TruthConfigs = {}
                for (encoded,) in c.execute(truth_query, (record.id,)).fetchall():
                    entry = json.loads(encoded)
                    truth_configs[entry["name"]] = TruthConfig(
                        function=entry["function"],
                        stride_reduction=entry["stride_reduction"],
                        config=entry["config"],
                    )
                grouped[record.category].append(
                    SourceFile(
                        id=record.id,
                        category=record.category,
                        name=record.name,
                        samples=record.samples,
                        class_indices=json.loads(record.class_indices),
                        level_type=record.level_type,
                        truth_configs=truth_configs,
                        speaker_id=record.speaker_id,
                    )
                )
    return grouped
549
+
550
@cached_property
def source_file_ids(self) -> dict[str, list[int]]:
    """Source file IDs grouped by category.

    :return: Mapping of category -> list of source file IDs
    """
    with self.db() as c:
        categories = [row[0] for row in c.execute("SELECT DISTINCT category FROM source_file").fetchall()]
        return {
            category: [
                int(row[0])
                for row in c.execute("SELECT id FROM source_file WHERE ? = category", (category,)).fetchall()
            ]
            for category in categories
        }
565
+
566
def source_file(self, s_id: int) -> SourceFile:
    """Source file with the given ID, looked up via ``_source_file``.

    :param s_id: Source file ID
    :return: Source file
    """
    db, use_cache = self.db, self.use_cache
    return _source_file(db, s_id, use_cache)

def num_source_files(self, category: str) -> int:
    """Count of source files in *category*, looked up via ``_num_source_files``.

    :param category: Source category
    :return: Number of source files
    """
    db, use_cache = self.db, self.use_cache
    return _num_source_files(db, category, use_cache)
581
+
582
@cached_property
def ir_files(self) -> list[ImpulseResponseFile]:
    """All impulse response files with their tags.

    :return: Impulse response files
    """
    from .db_datatypes import ImpulseResponseFileRecord

    tag_query = """
        SELECT ir_tag.tag
        FROM ir_tag, ir_file_ir_tag
        WHERE ? = ir_file_ir_tag.file_id
        AND ir_tag.id = ir_file_ir_tag.tag_id
        """

    with self.db() as c:
        result: list[ImpulseResponseFile] = []
        for row in c.execute("SELECT * FROM ir_file").fetchall():
            record = ImpulseResponseFileRecord(*row)
            tags = [tag for (tag,) in c.execute(tag_query, (record.id,)).fetchall()]
            result.append(ImpulseResponseFile(delay=record.delay, name=record.name, tags=tags))
    return result
618
+
619
@cached_property
def ir_file_ids(self) -> list[int]:
    """IDs of all rows in the ``ir_file`` table, exactly as stored (no offset applied).

    :return: List of impulse response file IDs
    """
    with self.db() as c:
        rows = c.execute("SELECT id FROM ir_file").fetchall()
    return [int(file_id) for (file_id,) in rows]

def ir_file_ids_for_tag(self, tag: str) -> list[int]:
    """Impulse response file IDs linked to *tag*, each reduced by one.

    NOTE(review): unlike ``ir_file_ids``, these IDs are shifted down by 1 (apparently
    zero-based). Confirm which convention ``_ir_file``/``_ir_delay`` expect.

    :param tag: Tag text to look up in ``ir_tag``
    :return: List of impulse response file IDs for the given tag; empty if the tag is unknown
    """
    with self.db() as c:
        found = c.execute("SELECT id FROM ir_tag WHERE ? = tag", (tag,)).fetchone()
        if not found:
            return []

        rows = c.execute("SELECT file_id FROM ir_file_ir_tag WHERE ? = tag_id", (found[0],)).fetchall()
        return [int(file_id - 1) for (file_id,) in rows]
642
+
643
def ir_file(self, ir_id: int) -> str:
    """Impulse response file name for *ir_id*, looked up via ``_ir_file``.

    :param ir_id: Impulse response file ID
    :return: Impulse response file name
    """
    db, use_cache = self.db, self.use_cache
    return _ir_file(db, ir_id, use_cache)

def ir_delay(self, ir_id: int) -> int:
    """Impulse response delay for *ir_id*, looked up via ``_ir_delay``.

    :param ir_id: Impulse response file ID
    :return: Impulse response delay
    """
    db, use_cache = self.db, self.use_cache
    return _ir_delay(db, ir_id, use_cache)
658
+
659
@cached_property
def num_ir_files(self) -> int:
    """Number of rows in the ``ir_file`` table.

    :return: Number of impulse response files
    """
    with self.db() as c:
        (count,) = c.execute("SELECT count(id) FROM ir_file").fetchone()
    return int(count)

@cached_property
def ir_tags(self) -> list[str]:
    """All tags in the ``ir_tag`` table.

    :return: Tags of impulse response files
    """
    with self.db() as c:
        rows = c.execute("SELECT tag FROM ir_tag").fetchall()
    return [row[0] for row in rows]
676
+
677
@property
def mixtures(self) -> list[Mixture]:
    """All mixtures, each with its sources keyed by source-file category (not cached).

    :return: Mixtures
    """
    from .db_datatypes import MixtureRecord
    from .db_datatypes import SourceRecord
    from .helpers import to_mixture
    from .helpers import to_source

    source_query = """
        SELECT source.*
        FROM source, mixture_source
        WHERE ? = mixture_source.mixture_id AND source.id = mixture_source.source_id
        """

    with self.db() as c:
        result: list[Mixture] = []
        mixture_records = [MixtureRecord(*row) for row in c.execute("SELECT * FROM mixture").fetchall()]
        for record in mixture_records:
            converted = [
                to_source(SourceRecord(*row)) for row in c.execute(source_query, (record.id,)).fetchall()
            ]
            # A later source with the same category overwrites an earlier one
            sources: Sources = {self.source_file(src.file_id).category: src for src in converted}
            result.append(to_mixture(record, sources))
    return result
710
+
711
@cached_property
def mixture_ids(self) -> list[int]:
    """Mixture IDs converted from the stored (one-based) ids to zero-based.

    :return: List of zero-based mixture IDs
    """
    with self.db() as c:
        rows = c.execute("SELECT id FROM mixture").fetchall()
    return [int(row_id) - 1 for (row_id,) in rows]
719
+
720
def mixture(self, m_id: int) -> Mixture:
    """Mixture record for a zero-based ID, looked up via ``_mixture``.

    :param m_id: Zero-based mixture ID
    :return: Mixture record
    """
    db, use_cache = self.db, self.use_cache
    return _mixture(db, m_id, use_cache)
727
+
728
@cached_property
def mixid_width(self) -> int:
    """``mixid_width`` value stored in the ``top`` table."""
    with self.db() as c:
        (width,) = c.execute("SELECT mixid_width FROM top").fetchone()
    return int(width)

def mixture_location(self, m_id: int) -> str:
    """Path of the given mixture: ``self.location`` joined with the mixture's name.

    :param m_id: Zero-based mixture ID
    :return: File location
    """
    from os.path import join

    name = self.mixture(m_id).name
    return join(self.location, name)

@cached_property
def num_mixtures(self) -> int:
    """Number of rows in the ``mixture`` table.

    :return: Number of mixtures
    """
    with self.db() as c:
        (count,) = c.execute("SELECT count(id) FROM mixture").fetchone()
    return int(count)
751
+
752
def read_mixture_data(self, m_id: int, items: list[str] | str) -> dict[str, Any]:
    """Load cached dataset(s) stored for one mixture.

    Looks under the ``"mixture"`` cache area, keyed by the mixture record's ``name``.

    :param m_id: Zero-based mixture ID
    :param items: Name, or list of names, of the dataset(s) to load
    :return: Dictionary of name: data
    """
    from .data_io import read_cached_data

    mixture_name = self.mixture(m_id).name
    return read_cached_data(self.location, "mixture", mixture_name, items)
762
+
763
def read_source_audio(self, s_id: int) -> AudioT:
    """Read the audio of a source file.

    :param s_id: Source file ID (callers such as ``mixture_sources`` pass ``source.file_id``)
    :return: Source audio
    """
    from .audio import read_audio

    file_name = self.source_file(s_id).name
    return read_audio(file_name, self.use_cache)
772
+
773
def mixture_class_indices(self, m_id: int) -> list[int]:
    """Collect the class indices used by a mixture.

    The result is the union of ``class_indices`` over the source files referenced by the
    mixture's ``source_ids``, de-duplicated and sorted.

    :param m_id: Zero-based mixture ID
    :return: Sorted list of unique class indices
    """
    indices: set[int] = set()
    for file_id in self.mixture(m_id).source_ids.values():
        indices.update(self.source_file(file_id).class_indices)
    return sorted(indices)
778
+
779
def mixture_sources(self, m_id: int, force: bool = False, cache: bool = False) -> SourcesAudioT:
    """Get the pre-truth source audio data (one per source in the mixture) for the given mixture ID

    Unless ``force`` is set, cached ``"sources"`` data is returned when present. Otherwise each
    source's file audio is read, has its pre-truth effects applied, and is conformed to the
    mixture length (honouring the source's ``loop`` and ``start`` settings).

    :param m_id: Zero-based mixture ID
    :param force: Force computing data from original sources regardless of whether cached data exists
    :param cache: Cache result
    :return: Dictionary of pre-truth source audio data (one per source in the mixture)
    :raises ValueError: If no mixture record is found for ``m_id``
    """
    from .data_io import write_cached_data
    from .effects import apply_effects
    from .effects import conform_audio_to_length

    if not force:
        sources = self.read_mixture_data(m_id, "sources")["sources"]
        if sources is not None:
            return sources

    mixture = self.mixture(m_id)
    if mixture is None:
        raise ValueError(f"Could not find mixture for m_id: {m_id}")

    sources = {}
    # items() already yields each source record; no need to look it up again by category.
    for category, source in mixture.all_sources.items():
        audio = self.read_source_audio(source.file_id)
        audio = apply_effects(self, audio, source.effects, pre=True, post=False)
        audio = conform_audio_to_length(audio, mixture.samples, source.loop, source.start)
        sources[category] = audio

    if cache:
        write_cached_data(
            location=self.location,
            name="mixture",
            index=mixture.name,
            items={"sources": sources},
        )

    return sources
817
+
818
def mixture_sources_f(
    self,
    m_id: int,
    sources: SourcesAudioT | None = None,
    force: bool = False,
    cache: bool = False,
) -> SourcesAudioF:
    """Forward-transform each pre-truth source of a mixture.

    When ``sources`` is not supplied it is obtained from ``mixture_sources``. Each entry is
    transformed with ``self.ft_config``.

    :param m_id: Zero-based mixture ID
    :param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
    :param force: Force computing data from original sources regardless of whether cached data exists
    :param cache: Cache result
    :return: Dictionary of pre-truth source transform data (one per source in the mixture)
    """
    from .data_io import write_cached_data
    from .helpers import forward_transform

    audio_by_category = self.mixture_sources(m_id, force) if sources is None else sources

    sources_f = {}
    for category, audio in audio_by_category.items():
        sources_f[category] = forward_transform(audio, self.ft_config)

    if cache:
        mixture_name = self.mixture(m_id).name
        write_cached_data(location=self.location, name="mixture", index=mixture_name, items={"sources_f": sources_f})

    return sources_f
850
+
851
def mixture_source(
    self,
    m_id: int,
    sources: SourcesAudioT | None = None,
    force: bool = False,
    cache: bool = False,
) -> AudioT:
    """Combine all non-noise sources of a mixture into one post-truth signal.

    Every category except ``"noise"`` gets its post-truth effects applied and is scaled by its
    ``snr_gain``. The results are then summed sample-wise.

    :param m_id: Zero-based mixture ID
    :param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
    :param force: Force computing data from original sources regardless of whether cached data exists
    :param cache: Cache result
    :return: Post-truth, gained, and summed source audio data
    """
    import numpy as np

    from .data_io import write_cached_data
    from .effects import apply_effects

    if not force:
        cached = self.read_mixture_data(m_id, "source")["source"]
        if cached is not None:
            return cached

    if sources is None:
        sources = self.mixture_sources(m_id, force)

    mixture = self.mixture(m_id)

    gained = []
    for category, audio in sources.items():
        if category == "noise":
            continue
        entry = mixture.all_sources[category]
        processed = apply_effects(self, audio=audio, effects=entry.effects, pre=False, post=True)
        gained.append(processed * entry.snr_gain)

    source = np.sum(gained, axis=0)

    if cache:
        write_cached_data(location=self.location, name="mixture", index=mixture.name, items={"source": source})

    return source
906
+
907
def mixture_source_f(
    self,
    m_id: int,
    sources: SourcesAudioT | None = None,
    source: AudioT | None = None,
    force: bool = False,
    cache: bool = False,
) -> AudioF:
    """Forward-transform the summed post-truth source of a mixture.

    When ``source`` is not supplied it is obtained from ``mixture_source``.

    :param m_id: Zero-based mixture ID
    :param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
    :param source: Post-truth, gained, and summed source audio for the given m_id
    :param force: Force computing data from original sources regardless of whether cached data exists
    :param cache: Cache result
    :return: Post-truth, gained, and summed source transform data
    """
    from .data_io import write_cached_data
    from .helpers import forward_transform

    audio = self.mixture_source(m_id, sources, force) if source is None else source
    source_f = forward_transform(audio, self.ft_config)

    if cache:
        mixture_name = self.mixture(m_id).name
        write_cached_data(location=self.location, name="mixture", index=mixture_name, items={"source_f": source_f})

    return source_f
941
+
942
def mixture_noise(
    self,
    m_id: int,
    sources: SourcesAudioT | None = None,
    force: bool = False,
    cache: bool = False,
) -> AudioT:
    """Produce the post-truth, gained noise signal of a mixture.

    The ``"noise"`` entry of the pre-truth sources gets the noise record's post-truth effects
    applied and is scaled by its ``snr_gain``.

    :param m_id: Zero-based mixture ID
    :param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
    :param force: Force computing data from original sources regardless of whether cached data exists
    :param cache: Cache result
    :return: Post-truth and gained noise audio data
    """
    from .data_io import write_cached_data
    from .effects import apply_effects

    if not force:
        cached = self.read_mixture_data(m_id, "noise")["noise"]
        if cached is not None:
            return cached

    if sources is None:
        sources = self.mixture_sources(m_id, force)

    mixture = self.mixture(m_id)
    noise_record = mixture.noise
    processed = apply_effects(self, sources["noise"], noise_record.effects, pre=False, post=True)
    noise = processed * noise_record.snr_gain

    if cache:
        write_cached_data(location=self.location, name="mixture", index=mixture.name, items={"noise": noise})

    return noise
980
+
981
def mixture_noise_f(
    self,
    m_id: int,
    sources: SourcesAudioT | None = None,
    noise: AudioT | None = None,
    force: bool = False,
    cache: bool = False,
) -> AudioF:
    """Get the post-truth and gained noise transform for the given mixture ID

    :param m_id: Zero-based mixture ID
    :param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
    :param noise: Post-truth and gained noise audio data; used as-is when supplied
    :param force: Force computing data from original sources regardless of whether cached data exists
    :param cache: Cache result
    :return: Post-truth and gained noise transform data
    """
    from .data_io import write_cached_data
    from .helpers import forward_transform

    # Only derive the noise audio when the caller did not supply it. ``force`` controls whether
    # cached data may be reused and is forwarded. It must not discard caller-provided audio,
    # matching mixture_source_f and mixture_mixture_f.
    if noise is None:
        noise = self.mixture_noise(m_id, sources, force)

    noise_f = forward_transform(noise, self.ft_config)

    if cache:
        write_cached_data(
            location=self.location,
            name="mixture",
            index=self.mixture(m_id).name,
            items={"noise_f": noise_f},
        )

    return noise_f
1014
+
1015
def mixture_mixture(
    self,
    m_id: int,
    sources: SourcesAudioT | None = None,
    source: AudioT | None = None,
    noise: AudioT | None = None,
    force: bool = False,
    cache: bool = False,
) -> AudioT:
    """Produce the mixture signal as the sum of the summed source and the noise.

    Missing ``source``/``noise`` arguments are derived via ``mixture_source``/``mixture_noise``.

    :param m_id: Zero-based mixture ID
    :param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
    :param source: Post-truth, gained, and summed source audio data
    :param noise: Post-truth and gained noise audio data
    :param force: Force computing data from original sources regardless of whether cached data exists
    :param cache: Cache result
    :return: Mixture audio data
    """
    from .data_io import write_cached_data

    if not force:
        cached = self.read_mixture_data(m_id, "mixture")["mixture"]
        if cached is not None:
            return cached

    source_audio = source if source is not None else self.mixture_source(m_id, sources, force)
    noise_audio = noise if noise is not None else self.mixture_noise(m_id, sources, force)
    mixture = source_audio + noise_audio

    if cache:
        mixture_name = self.mixture(m_id).name
        write_cached_data(location=self.location, name="mixture", index=mixture_name, items={"mixture": mixture})

    return mixture
1058
+
1059
def mixture_mixture_f(
    self,
    m_id: int,
    sources: SourcesAudioT | None = None,
    source: AudioT | None = None,
    noise: AudioT | None = None,
    mixture: AudioT | None = None,
    force: bool = False,
    cache: bool = False,
) -> AudioF:
    """Forward-transform the mixture signal, applying its spectral mask if one is assigned.

    :param m_id: Zero-based mixture ID
    :param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
    :param source: Post-truth, gained, and summed source audio data
    :param noise: Post-truth and gained noise audio data
    :param mixture: Mixture audio data
    :param force: Force computing data from original sources regardless of whether cached data exists
    :param cache: Cache result
    :return: Mixture transform data
    """
    from .data_io import write_cached_data
    from .helpers import forward_transform
    from .spectral_mask import apply_spectral_mask

    audio = mixture if mixture is not None else self.mixture_mixture(m_id, sources, source, noise, force)
    mixture_f = forward_transform(audio, self.ft_config)

    record = self.mixture(m_id)
    mask_id = record.spectral_mask_id
    if mask_id is not None:
        # The mask is applied with the mixture's stored seed.
        mixture_f = apply_spectral_mask(
            audio_f=mixture_f,
            spectral_mask=self.spectral_mask(int(mask_id)),
            seed=record.spectral_mask_seed,
        )

    if cache:
        write_cached_data(location=self.location, name="mixture", index=record.name, items={"mixture_f": mixture_f})

    return mixture_f
1106
+
1107
def mixture_truth_t(self, m_id: int, force: bool = False, cache: bool = False) -> TruthsDict:
    """Return the time-domain truth (truth_t) of a mixture.

    Cached ``"truth_t"`` data is reused unless ``force`` is set; otherwise it is produced by
    ``truth_function``.

    :param m_id: Zero-based mixture ID
    :param force: Force computing data from original sources regardless of whether cached data exists
    :param cache: Cache result
    :return: truth_t data
    """
    from .data_io import write_cached_data
    from .truth import truth_function

    if not force:
        cached = self.read_mixture_data(m_id, "truth_t")["truth_t"]
        if cached is not None:
            return cached

    truth_t = truth_function(self, m_id)

    if cache:
        mixture_name = self.mixture(m_id).name
        write_cached_data(location=self.location, name="mixture", index=mixture_name, items={"truth_t": truth_t})

    return truth_t
1134
+
1135
def mixture_segsnr_t(
    self,
    m_id: int,
    sources: SourcesAudioT | None = None,
    source: AudioT | None = None,
    noise: AudioT | None = None,
    force: bool = False,
    cache: bool = False,
) -> Segsnr:
    """Compute a per-sample segmental SNR (segsnr_t) for a mixture.

    Source and noise audio are each run through a ``ForwardTransform`` built from
    ``self.ft_config``. The second element returned by ``execute_all`` is used as the per-frame
    energy. For each frame, source energy / noise energy (or +inf when the noise energy is 0) is
    written to the ``ft_config.overlap`` samples starting at that frame's offset.

    :param m_id: Zero-based mixture ID
    :param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
    :param source: Post-truth, gained, and summed source audio data
    :param noise: Post-truth and gained noise audio data
    :param force: Force computing data from original sources regardless of whether cached data exists
    :param cache: Cache result
    :return: segsnr_t data (float32, one value per mixture sample)
    :raises ValueError: If the number of energy frames differs from the number of mixture frames
    """
    import numpy as np
    import torch
    from pyaaware import ForwardTransform

    from .data_io import write_cached_data

    if not force:
        cached = self.read_mixture_data(m_id, "segsnr_t")["segsnr_t"]
        if cached is not None:
            return cached

    if source is None:
        source = self.mixture_source(m_id, sources, force)

    if noise is None:
        noise = self.mixture_noise(m_id, sources, force)

    cfg = self.ft_config
    hop = cfg.overlap
    ft = ForwardTransform(
        length=cfg.length,
        overlap=hop,
        bin_start=cfg.bin_start,
        bin_end=cfg.bin_end,
        ttype=cfg.ttype,
    )

    mixture = self.mixture(m_id)
    segsnr_t = np.empty(mixture.samples, dtype=np.float32)

    source_energy = ft.execute_all(torch.from_numpy(source))[1].numpy()
    noise_energy = ft.execute_all(torch.from_numpy(noise))[1].numpy()

    frame_starts = range(0, mixture.samples, hop)
    if len(source_energy) != len(frame_starts):
        raise ValueError(
            f"Number of frames in energy, {len(source_energy)}, is not number of frames in mixture, {len(frame_starts)}"
        )

    for frame, start in enumerate(frame_starts):
        denominator = noise_energy[frame]
        if denominator == 0:
            segsnr_t[start : start + hop] = np.float32(np.inf)
        else:
            segsnr_t[start : start + hop] = np.float32(source_energy[frame] / denominator)

    if cache:
        write_cached_data(location=self.location, name="mixture", index=mixture.name, items={"segsnr_t": segsnr_t})

    return segsnr_t
1211
+
1212
def mixture_segsnr(
    self,
    m_id: int,
    segsnr_t: Segsnr | None = None,
    sources: SourcesAudioT | None = None,
    source: AudioT | None = None,
    noise: AudioT | None = None,
    force: bool = False,
    cache: bool = False,
) -> Segsnr:
    """Reduce the per-sample segsnr_t to one value per frame.

    ``mixture_segsnr_t`` writes each frame's value across ``ft_config.overlap`` samples. This
    keeps every ``overlap``-th sample, starting at 0.

    :param m_id: Zero-based mixture ID
    :param segsnr_t: segsnr_t data
    :param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
    :param source: Post-truth, gained, and summed source audio data
    :param noise: Post-truth and gained noise audio data
    :param force: Force computing data from original sources regardless of whether cached data exists
    :param cache: Cache result
    :return: segsnr data
    """
    from .data_io import write_cached_data

    if not force:
        cached = self.read_mixture_data(m_id, "segsnr")["segsnr"]
        if cached is not None:
            return cached

    if segsnr_t is None:
        segsnr_t = self.mixture_segsnr_t(m_id, sources, source, noise, force)

    hop = self.ft_config.overlap
    segsnr = segsnr_t[::hop]

    if cache:
        mixture_name = self.mixture(m_id).name
        write_cached_data(location=self.location, name="mixture", index=mixture_name, items={"segsnr": segsnr})

    return segsnr
1254
+
1255
def mixture_ft(
    self,
    m_id: int,
    sources: SourcesAudioT | None = None,
    source: AudioT | None = None,
    noise: AudioT | None = None,
    mixture_f: AudioF | None = None,
    mixture: AudioT | None = None,
    truth_t: TruthsDict | None = None,
    force: bool = False,
    cache: bool = False,
) -> tuple[Feature, TruthsDict]:
    """Get the feature and truth_f data for the given mixture ID

    Cached feature/truth_f are returned when both exist and ``force`` is false. Otherwise the
    mixture transform and truth_t are obtained (derived when not supplied) and run through the
    feature generator. Each truth whose ``truth_parameters`` entry is not None is then
    stride-reduced per its config.

    :param m_id: Zero-based mixture ID
    :param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
    :param source: Post-truth, gained, and summed source audio data
    :param noise: Post-truth and gained noise audio data
    :param mixture_f: Mixture transform data
    :param mixture: Mixture audio data
    :param truth_t: truth_t
    :param force: Force computing data from original sources regardless of whether cached data exists
    :param cache: Cache result
    :return: Tuple of (feature, truth_f) data
    :raises TypeError: If the feature generator returns no truth data
    """
    from pyaaware import FeatureGenerator

    from .data_io import write_cached_data
    from .truth import truth_stride_reduction

    if not force:
        ft = self.read_mixture_data(m_id, ["feature", "truth_f"])
        if ft["feature"] is not None and ft["truth_f"] is not None:
            return ft["feature"], ft["truth_f"]

    if mixture_f is None:
        mixture_f = self.mixture_mixture_f(
            m_id=m_id,
            sources=sources,
            source=source,
            noise=noise,
            mixture=mixture,
            force=force,
        )

    if truth_t is None:
        truth_t = self.mixture_truth_t(m_id, force)

    fg = FeatureGenerator(self.fg_config.feature_mode, self.fg_config.truth_parameters)

    feature, truth_f = fg.execute_all(mixture_f, truth_t)
    if truth_f is None:
        raise TypeError("Unexpected truth of None from feature generator")

    truth_configs = self.mixture_truth_configs(m_id)
    for category, configs in truth_configs.items():
        for name, config in configs.items():
            # Only truths with a non-None truth_parameters entry are stride-reduced
            if self.truth_parameters[category][name] is not None:
                truth_f[category][name] = truth_stride_reduction(truth_f[category][name], config.stride_reduction)

    if cache:
        write_cached_data(
            location=self.location,
            name="mixture",
            index=self.mixture(m_id).name,
            # Store the feature itself under "feature" (truth_f was previously written here by mistake)
            items={"feature": feature, "truth_f": truth_f},
        )

    return feature, truth_f
1324
+
1325
def mixture_feature(
    self,
    m_id: int,
    sources: SourcesAudioT | None = None,
    noise: AudioT | None = None,
    mixture: AudioT | None = None,
    truth_t: TruthsDict | None = None,
    force: bool = False,
    cache: bool = False,
) -> Feature:
    """Return only the feature half of ``mixture_ft`` for a mixture.

    :param m_id: Zero-based mixture ID
    :param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
    :param noise: Post-truth and gained noise audio data
    :param mixture: Mixture audio data
    :param truth_t: truth_t
    :param force: Force computing data from original sources regardless of whether cached data exists
    :param cache: Cache result
    :return: Feature data
    """
    from .data_io import write_cached_data

    feature, _truth_f = self.mixture_ft(
        m_id=m_id,
        sources=sources,
        noise=noise,
        mixture=mixture,
        truth_t=truth_t,
        force=force,
    )

    if cache:
        mixture_name = self.mixture(m_id).name
        write_cached_data(location=self.location, name="mixture", index=mixture_name, items={"feature": feature})

    return feature
1366
+
1367
def mixture_truth_f(
    self,
    m_id: int,
    sources: SourcesAudioT | None = None,
    noise: AudioT | None = None,
    mixture: AudioT | None = None,
    truth_t: TruthsDict | None = None,
    force: bool = False,
    cache: bool = False,
) -> TruthsDict:
    """Get the truth_f data for the given mixture ID

    This is element [1] of ``mixture_ft``, so it has the same category -> name nesting as
    truth_t. The return annotation was corrected from ``TruthDict`` to ``TruthsDict`` to match
    ``mixture_ft``.

    :param m_id: Zero-based mixture ID
    :param sources: Dictionary of pre-truth source audio data (one per source in the mixture)
    :param noise: Post-truth and gained noise audio data
    :param mixture: Mixture audio data
    :param truth_t: truth_t
    :param force: Force computing data from original sources regardless of whether cached data exists
    :param cache: Cache result
    :return: truth_f data
    """
    from .data_io import write_cached_data

    truth_f = self.mixture_ft(
        m_id=m_id,
        sources=sources,
        noise=noise,
        mixture=mixture,
        truth_t=truth_t,
        force=force,
    )[1]

    if cache:
        write_cached_data(
            location=self.location,
            name="mixture",
            index=self.mixture(m_id).name,
            items={"truth_f": truth_f},
        )

    return truth_f
1408
+
1409
def mixture_class_count(self, m_id: int, truth_t: TruthsDict | None = None) -> dict[str, ClassCount]:
    """Compute the number of frames for which each class index is active for the given mixture ID

    A category is counted only if it has a truth config named exactly ``"sed"``. For each class,
    the count is the number of rows of ``truth_t[category]["sed"]`` whose value meets that class's
    ``class_weights_thresholds`` entry. Other categories get all-zero counts.

    :param m_id: Zero-based mixture ID
    :param truth_t: truth_t
    :return: Dictionary of class counts
    """
    import numpy as np

    if truth_t is None:
        truth_t = self.mixture_truth_t(m_id)

    class_count: dict[str, ClassCount] = {}

    truth_configs = self.mixture_truth_configs(m_id)
    for category, configs in truth_configs.items():
        # Test membership of the config *name*. The previous per-name substring test
        # ("sed" in name) also matched names like "used_energy" and then raised KeyError
        # on truth_t[category]["sed"]; it also recounted once per matching name.
        if "sed" in configs:
            sed = truth_t[category]["sed"]
            class_count[category] = [
                int(np.sum(sed[:, cl] >= self.class_weights_thresholds[cl])) for cl in range(self.num_classes)
            ]
        else:
            class_count[category] = [0] * self.num_classes

    return class_count
1434
+
1435
@cached_property
def speaker_metadata_tiers(self) -> list[str]:
    """Speaker metadata tier names, decoded from the JSON in ``top.speaker_metadata_tiers`` (row id 1).

    :return: List of speaker metadata tier names
    """
    import json

    with self.db() as c:
        (raw,) = c.execute("SELECT speaker_metadata_tiers FROM top WHERE 1 = id").fetchone()
    return json.loads(raw)
1441
+
1442
@cached_property
def textgrid_metadata_tiers(self) -> list[str]:
    """TextGrid metadata tier names, decoded from the JSON in ``top.textgrid_metadata_tiers`` (row id 1).

    :return: List of TextGrid metadata tier names
    """
    import json

    with self.db() as c:
        (raw,) = c.execute("SELECT textgrid_metadata_tiers FROM top WHERE 1 = id").fetchone()
    return json.loads(raw)
1448
+
1449
@cached_property
def speech_metadata_tiers(self) -> list[str]:
    """Sorted union of the speaker and TextGrid metadata tier names.

    :return: Sorted list of unique tier names
    """
    combined = {*self.speaker_metadata_tiers, *self.textgrid_metadata_tiers}
    return sorted(combined)
1452
+
1453
def speaker(self, s_id: int | None, tier: str) -> str | None:
    """Look up one speaker metadata value.

    Thin wrapper over the module-level ``_speaker`` helper that forwards this database's
    ``use_cache`` setting.

    :param s_id: Speaker ID (may be None per the annotation)
    :param tier: Speaker metadata tier name
    :return: Tier value, or None
    """
    args = (self.db, s_id, tier, self.use_cache)
    return _speaker(*args)
1455
+
1456
def speech_metadata(self, tier: str) -> list[str]:
    """Return every distinct value of ``tier`` found across all source files.

    TextGrid tiers are read from each source file. A list result contributes each interval's
    ``label``, and None results are skipped. Speaker tiers are looked up via ``self.speaker``,
    with None values skipped. A tier that is in neither list yields an empty list.

    :param tier: Speech metadata tier name
    :return: Sorted list of unique tier values
    """
    from .helpers import get_textgrid_tier_from_source_file

    def every_source_file():
        # Flatten the per-category lists in self.source_files lazily
        for files in self.source_files.values():
            yield from files

    found: set[str] = set()
    if tier in self.textgrid_metadata_tiers:
        for source_file in every_source_file():
            data = get_textgrid_tier_from_source_file(source_file.name, tier)
            if data is None:
                continue
            if isinstance(data, list):
                found.update(item.label for item in data)
            else:
                found.add(data)
    elif tier in self.speaker_metadata_tiers:
        for source_file in every_source_file():
            value = self.speaker(source_file.speaker_id, tier)
            if value is not None:
                found.add(value)

    return sorted(found)
1479
+
1480
def mixture_speech_metadata(self, mixid: int, tier: str) -> dict[str, SpeechMetadata]:
    """Get the value of one speech metadata tier for every source of a mixture.

    For TextGrid tiers, list results are returned as new ``Interval`` objects whose start/end
    are divided by the source's ``pre_tempo``, to account for tempo effects. Non-list results
    are returned unchanged. For other tiers, the speaker metadata of each source file's speaker
    is returned.

    :param mixid: Zero-based mixture ID
    :param tier: Speech metadata tier name
    :return: Dictionary of category: tier data
    """
    from praatio.utilities.constants import Interval

    from .helpers import get_textgrid_tier_from_source_file

    is_textgrid = tier in self.textgrid_metadata_tiers
    sources = self.mixture(mixid).all_sources

    results: dict[str, SpeechMetadata] = {}
    if not is_textgrid:
        for category, source in sources.items():
            results[category] = self.speaker(self.source_file(source.file_id).speaker_id, tier)
        return results

    for category, source in sources.items():
        data = get_textgrid_tier_from_source_file(self.source_file(source.file_id).name, tier)
        if isinstance(data, list):
            tempo = source.pre_tempo
            results[category] = [Interval(entry.start / tempo, entry.end / tempo, entry.label) for entry in data]
        else:
            results[category] = data

    return results
1509
+
1510
def mixids_for_speech_metadata(
    self,
    tier: str | None = None,
    value: str | None = None,
    where: str | None = None,
) -> dict[str, list[int]]:
    """Get a list of mixture IDs for the given speech metadata tier.

    If 'where' is None, then include mixture IDs whose tier values are equal to the given 'value'.
    If 'where' is not None, then ignore 'value' and use the given SQL where clause to determine
    which entries to include.

    'value' is bound as an SQL parameter, so it may contain quotes (e.g. "O'Brien"). 'tier' and
    'where' are inserted into the SQL text verbatim, so they must come from trusted code.

    Examples:
    >>> mixdb = MixtureDatabase('/mixdb_location')

    >>> mixids = mixdb.mixids_for_speech_metadata('speaker_id', 'TIMIT_ABW0')
    Get mixture IDs for mixtures with speakers whose speaker_ids are 'TIMIT_ABW0'.

    >>> mixids = mixdb.mixids_for_speech_metadata(where='age >= 27')
    Get mixture IDs for mixtures with speakers whose ages are greater than or equal to 27.

    >>> mixids = mixdb.mixids_for_speech_metadata(where="dialect in ('New York City', 'Northern')")
    Get mixture IDs for mixtures with speakers whose dialects are either 'New York City' or 'Northern'.

    :param tier: Speaker table column compared against 'value'
    :param value: Value the tier must equal (ignored when 'where' is given)
    :param where: Raw SQL WHERE clause evaluated against the speaker table
    :return: Dictionary of source category: zero-based mixture IDs
    :raises ValueError: If neither 'value' nor 'where' is given, if 'tier' is missing when
        'where' is not given, or if 'tier' is a TextGrid tier
    """
    if value is None and where is None:
        raise ValueError("Must provide either value or where")

    params: tuple[str | None, ...] = ()
    if where is None:
        if tier is None:
            raise ValueError("Must provide tier")
        # Bind the value instead of quoting it into the SQL text
        where = f"{tier} = ?"
        params = (value,)

    if tier is not None and tier in self.textgrid_metadata_tiers:
        raise ValueError(f"TextGrid tier data, '{tier}', is not supported in mixids_for_speech_metadata().")

    with self.db() as c:
        results = c.execute(f"SELECT id FROM speaker WHERE {where}", params).fetchall()
        speaker_ids = ",".join(str(row[0]) for row in results)

        # IDs below come straight from the database, so joining them into IN (...) is safe
        results = c.execute(f"SELECT id, category FROM source_file WHERE speaker_id IN ({speaker_ids})").fetchall()
        source_file_ids: dict[str, list[int]] = {}
        for source_file_id, category in results:
            source_file_ids.setdefault(category, []).append(source_file_id)

        mixids: dict[str, list[int]] = {}
        for category, file_ids in source_file_ids.items():
            id_str = ",".join(map(str, file_ids))
            results = c.execute(f"SELECT id FROM source WHERE file_id IN ({id_str})").fetchall()
            source_ids = ",".join(str(row[0]) for row in results)

            results = c.execute(
                f"SELECT mixture_id FROM mixture_source WHERE source_id IN ({source_ids})"
            ).fetchall()
            # Database mixture IDs are one-based; convert to zero-based
            mixids[category] = [row[0] - 1 for row in results]

    return mixids
1570
+
1571
def mixture_all_speech_metadata(self, m_id: int) -> dict[str, dict[str, SpeechMetadata]]:
    """Return speech metadata for one mixture, as computed by ``helpers.mixture_all_speech_metadata``.

    :param m_id: Zero-based mixture ID
    :return: Nested dictionary of speech metadata
    """
    from .helpers import mixture_all_speech_metadata as _all_metadata

    mixture = self.mixture(m_id)
    return _all_metadata(self, mixture)
1575
+
1576
def cached_metrics(self, m_ids: GeneralizedIDs = "*") -> list[str]:
    """List the metrics cached (as ``*.pkl`` files) for every one of the selected mixtures.

    The first mixture's cached file stems are filtered to supported metric names. Each later
    mixture's stems are intersected with that set.

    :param m_ids: Mixture IDs to inspect (anything accepted by ``mixids_to_list``)
    :return: Sorted metric names present for all selected mixtures
    """
    from glob import glob
    from os.path import join
    from pathlib import Path

    supported = self.supported_metrics.names
    common: set[str] | None = None
    for m_id in self.mixids_to_list(m_ids):
        mixture_dir = join(self.location, "mixture", self.mixture(m_id).name)
        stems = {Path(path).stem for path in glob(join(mixture_dir, "*.pkl"))}
        if common is None:
            common = {stem for stem in stems if stem in supported}
        else:
            common &= stems

    return sorted(common or ())
1597
+
1598
def mixture_metrics(self, m_id: int, metrics: list[str], force: bool = False) -> dict[str, Any]:
    """Obtain the requested metrics for one mixture via ``metrics.calculate_metrics``.

    :param m_id: Zero-based mixture ID
    :param metrics: Names of the metrics to get
    :param force: Force computing data from original sources regardless of whether cached data exists
    :return: Dictionary of metric data
    """
    from ..metrics import calculate_metrics as _calculate

    return _calculate(self, m_id, metrics, force)
1609
+
1610
+
1611
def _spectral_mask(db: partial, sm_id: int, use_cache: bool = True) -> SpectralMask:
    """Load a spectral mask, optionally through the LRU cache.

    :param db: Database context
    :param sm_id: Spectral mask ID (passed to the query unchanged)
    :param use_cache: If true, use the ``lru_cache``-wrapped loader; otherwise call it uncached
    :return: Spectral mask
    """
    loader = __spectral_mask if use_cache else __spectral_mask.__wrapped__
    return loader(db, sm_id)
1622
+
1623
+
1624
@lru_cache
def __spectral_mask(db: partial, sm_id: int) -> SpectralMask:
    """Read the ``spectral_mask`` row whose id equals ``sm_id`` and build a SpectralMask from it.

    :param db: Database context
    :param sm_id: Spectral mask ID
    :return: Spectral mask
    """
    from .db_datatypes import SpectralMaskRecord

    with db() as c:
        row = c.execute("SELECT * FROM spectral_mask WHERE ? = id", (sm_id,)).fetchone()
    record = SpectralMaskRecord(*row)

    return SpectralMask(
        f_max_width=record.f_max_width,
        f_num=record.f_num,
        t_max_width=record.t_max_width,
        t_num=record.t_num,
        t_max_percent=record.t_max_percent,
    )
1637
+
1638
+
1639
def _num_source_files(db: partial, category: str, use_cache: bool = True) -> int:
    """Count the source files in one category, optionally through the LRU cache.

    :param db: Database context
    :param category: Source category
    :param use_cache: If true, use the ``lru_cache``-wrapped counter; otherwise call it uncached
    :return: Number of source files
    """
    counter = __num_source_files if use_cache else __num_source_files.__wrapped__
    return counter(db, category)
1650
+
1651
+
1652
@lru_cache
def __num_source_files(db: partial, category: str) -> int:
    """Count ``source_file`` rows whose category equals ``category``.

    :param db: Database context
    :param category: Source category
    :return: Number of source files
    """
    with db() as c:
        (count,) = c.execute("SELECT count(id) FROM source_file WHERE ? = category", (category,)).fetchone()
    return int(count)
1662
+
1663
+
1664
def _source_file(db: partial, s_id: int, use_cache: bool = True) -> SourceFile:
    """Load a source file record, optionally through the LRU cache.

    ``use_cache`` is also forwarded to the loader, which passes it on when fetching the
    file's truth configs.

    :param db: Database context
    :param s_id: Source file ID (passed to the query unchanged)
    :param use_cache: If true, use the ``lru_cache``-wrapped loader; otherwise call it uncached
    :return: Source file
    """
    loader = __source_file if use_cache else __source_file.__wrapped__
    return loader(db, s_id, use_cache)
1675
+
1676
+
1677
@lru_cache
def __source_file(db: partial, s_id: int, use_cache: bool = True) -> SourceFile:
    """Build a SourceFile from its ``source_file`` row (cached worker behind ``_source_file``).

    :param db: Database context
    :param s_id: Value matched against ``source_file.id``
    :param use_cache: Forwarded to ``_source_truth_configs`` to pick cached or uncached truth-config lookup
    :return: Source file
    """
    import json

    from .db_datatypes import SourceFileRecord

    query = "SELECT * FROM source_file WHERE ? = id"
    with db() as c:
        record = SourceFileRecord(*c.execute(query, (s_id,)).fetchone())

    # Values are evaluated in the same order as the original keyword arguments.
    fields = {
        "category": record.category,
        "name": record.name,
        "samples": record.samples,
        # class_indices is stored as JSON text in the database row.
        "class_indices": json.loads(record.class_indices),
        "level_type": record.level_type,
        "truth_configs": _source_truth_configs(db, s_id, use_cache),
        "speaker_id": record.speaker_id,
    }
    return SourceFile(**fields)
1702
+
1703
+
1704
def _ir_file(db: partial, ir_id: int, use_cache: bool = True) -> str:
    """Fetch the name of an impulse response file.

    :param db: Database context
    :param ir_id: Impulse response file ID
    :param use_cache: When false, bypass the LRU cache
    :return: Impulse response file name
    """
    if not use_cache:
        return __ir_file.__wrapped__(db, ir_id)
    return __ir_file(db, ir_id)
1715
+
1716
+
1717
@lru_cache
def __ir_file(db: partial, ir_id: int) -> str:
    """Look up an impulse response file name (cached worker behind ``_ir_file``).

    The caller's ID is shifted by one before matching ``ir_file.id``.

    :param db: Database context
    :param ir_id: Impulse response file ID as used by callers
    :return: Value of ``ir_file.name`` converted to str
    """
    row_id = ir_id + 1
    with db() as c:
        name = c.execute("SELECT name FROM ir_file WHERE ? = id ", (row_id,)).fetchone()[0]
    return str(name)
1721
+
1722
+
1723
def _ir_delay(db: partial, ir_id: int, use_cache: bool = True) -> int:
    """Fetch the delay associated with an impulse response file.

    :param db: Database context
    :param ir_id: Impulse response file ID
    :param use_cache: When false, bypass the LRU cache
    :return: Impulse response delay
    """
    if not use_cache:
        return __ir_delay.__wrapped__(db, ir_id)
    return __ir_delay(db, ir_id)
1734
+
1735
+
1736
@lru_cache
def __ir_delay(db: partial, ir_id: int) -> int:
    """Look up an impulse response delay (cached worker behind ``_ir_delay``).

    The caller's ID is shifted by one before matching ``ir_file.id``.

    :param db: Database context
    :param ir_id: Impulse response file ID as used by callers
    :return: Value of ``ir_file.delay`` converted to int
    """
    row_id = ir_id + 1
    with db() as c:
        delay = c.execute("SELECT delay FROM ir_file WHERE ? = id", (row_id,)).fetchone()[0]
    return int(delay)
1740
+
1741
+
1742
def _mixture(db: partial, m_id: int, use_cache: bool = True) -> Mixture:
    """Fetch a mixture record, including its sources.

    :param db: Database context
    :param m_id: Zero-based mixture ID
    :param use_cache: When true, use the LRU-cached worker; otherwise bypass the cache
    :return: Mixture record
    """
    builder = __mixture if use_cache else __mixture.__wrapped__
    return builder(db, m_id)
1753
+
1754
+
1755
@lru_cache
def __mixture(db: partial, m_id: int) -> Mixture:
    """Assemble a mixture and its sources from the database (cached worker behind ``_mixture``).

    :param db: Database context
    :param m_id: Zero-based mixture ID; the ``mixture`` row matched has ``id == m_id + 1``
    :return: Mixture built by ``to_mixture`` with sources keyed by their source file's category
    """
    from .db_datatypes import MixtureRecord
    from .db_datatypes import SourceRecord
    from .helpers import to_mixture
    from .helpers import to_source

    sources_query = """
        SELECT source.*
        FROM source, mixture_source
        WHERE ? = mixture_source.mixture_id AND source.id = mixture_source.source_id
        """
    category_query = "SELECT category FROM source_file WHERE ? = id"

    with db() as c:
        record = MixtureRecord(*c.execute("SELECT * FROM mixture WHERE ? = id", (m_id + 1,)).fetchone())

        sources: Sources = {}
        rows = c.execute(sources_query, (record.id,)).fetchall()
        for row in rows:
            source_record = SourceRecord(*row)
            # Each source is filed under the category of the source file it references.
            category = c.execute(category_query, (source_record.file_id,)).fetchone()[0]
            sources[category] = to_source(source_record)

    return to_mixture(record, sources)
1779
+
1780
+
1781
def _speaker(db: partial, s_id: int | None, tier: str, use_cache: bool = True) -> str | None:
    """Look up one speaker metadata tier value.

    :param db: Database context
    :param s_id: Speaker ID, or None
    :param tier: Name of the ``speaker`` table column to read
    :param use_cache: When true, use the LRU-cached worker; otherwise bypass the cache
    :return: Tier value, or None
    """
    lookup = __speaker if use_cache else __speaker.__wrapped__
    return lookup(db, s_id, tier)
1785
+
1786
+
1787
+ @lru_cache
1788
+ def __speaker(db: partial, s_id: int | None, tier: str) -> str | None:
1789
+ if s_id is None:
1790
+ return None
1791
+
1792
+ with db() as c:
1793
+ data = c.execute(f"SELECT {tier} FROM speaker WHERE ? = id", (s_id,)).fetchone()
1794
+ if data is None:
1795
+ return None
1796
+ if data[0] is None:
1797
+ return None
1798
+ return data[0]
1799
+
1800
+
1801
def _category_truth_configs(db: partial, category: str, use_cache: bool = True) -> dict[str, str]:
    """Collect truth config names and their functions for a source category.

    :param db: Database context
    :param category: Source category
    :param use_cache: When false, bypass the LRU cache
    :return: Mapping of truth config name to truth function name
    """
    if not use_cache:
        return __category_truth_configs.__wrapped__(db, category)
    return __category_truth_configs(db, category)
1805
+
1806
+
1807
@lru_cache
def __category_truth_configs(db: partial, category: str) -> dict[str, str]:
    """Map truth config names to function names across a category (cached worker).

    Source files are visited in the order the query returns them. When two files
    define the same truth config name, the later one wins.

    :param db: Database context
    :param category: Source category
    :return: Mapping of truth config name to truth function name
    """
    import json

    config_query = """
        SELECT truth_config.config
        FROM truth_config, source_file_truth_config
        WHERE ? = source_file_truth_config.source_file_id AND truth_config.id = source_file_truth_config.truth_config_id
        """

    result: dict[str, str] = {}
    with db() as c:
        file_rows = c.execute("SELECT id FROM source_file WHERE ? = category", (category,)).fetchall()
        for file_row in file_rows:
            config_rows = c.execute(config_query, (file_row[0],)).fetchall()
            for config_row in config_rows:
                # truth_config.config holds JSON text with at least "name" and "function".
                parsed = json.loads(config_row[0])
                result[parsed["name"]] = parsed["function"]
    return result
1827
+
1828
+
1829
def _source_truth_configs(db: partial, s_id: int, use_cache: bool = True) -> TruthConfigs:
    """Fetch the truth configs attached to one source file.

    :param db: Database context
    :param s_id: Source file ID
    :param use_cache: When true, use the LRU-cached worker; otherwise bypass the cache
    :return: Truth configs keyed by name
    """
    fetch = __source_truth_configs if use_cache else __source_truth_configs.__wrapped__
    return fetch(db, s_id)
1833
+
1834
+
1835
@lru_cache
def __source_truth_configs(db: partial, s_id: int) -> TruthConfigs:
    """Build TruthConfig objects for one source file (cached worker behind ``_source_truth_configs``).

    :param db: Database context
    :param s_id: Value matched against ``source_file_truth_config.source_file_id``
    :return: Truth configs keyed by the "name" field of each stored JSON config
    """
    import json

    from ..datatypes import TruthConfig

    query = """
        SELECT truth_config.config
        FROM truth_config, source_file_truth_config
        WHERE ? = source_file_truth_config.source_file_id AND truth_config.id = source_file_truth_config.truth_config_id
        """

    configs: TruthConfigs = {}
    with db() as c:
        rows = c.execute(query, (s_id,)).fetchall()
        for row in rows:
            # Each stored config is JSON text carrying name, function, stride_reduction and config.
            entry = json.loads(row[0])
            configs[entry["name"]] = TruthConfig(
                function=entry["function"],
                stride_reduction=entry["stride_reduction"],
                config=entry["config"],
            )
    return configs