sonusai 0.20.3__py3-none-any.whl → 1.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97) hide show
  1. sonusai/__init__.py +16 -3
  2. sonusai/audiofe.py +241 -77
  3. sonusai/calc_metric_spenh.py +71 -73
  4. sonusai/config/__init__.py +3 -0
  5. sonusai/config/config.py +61 -0
  6. sonusai/config/config.yml +20 -0
  7. sonusai/config/constants.py +8 -0
  8. sonusai/constants.py +11 -0
  9. sonusai/data/genmixdb.yml +21 -36
  10. sonusai/{mixture/datatypes.py → datatypes.py} +91 -130
  11. sonusai/deprecated/plot.py +4 -5
  12. sonusai/doc/doc.py +4 -4
  13. sonusai/doc.py +11 -4
  14. sonusai/genft.py +43 -45
  15. sonusai/genmetrics.py +25 -19
  16. sonusai/genmix.py +54 -82
  17. sonusai/genmixdb.py +88 -264
  18. sonusai/ir_metric.py +30 -34
  19. sonusai/lsdb.py +41 -48
  20. sonusai/main.py +15 -22
  21. sonusai/metrics/calc_audio_stats.py +4 -293
  22. sonusai/metrics/calc_class_weights.py +4 -4
  23. sonusai/metrics/calc_optimal_thresholds.py +8 -5
  24. sonusai/metrics/calc_pesq.py +2 -2
  25. sonusai/metrics/calc_segsnr_f.py +4 -4
  26. sonusai/metrics/calc_speech.py +25 -13
  27. sonusai/metrics/class_summary.py +7 -7
  28. sonusai/metrics/confusion_matrix_summary.py +5 -5
  29. sonusai/metrics/one_hot.py +4 -4
  30. sonusai/metrics/snr_summary.py +7 -7
  31. sonusai/metrics_summary.py +38 -45
  32. sonusai/mixture/__init__.py +4 -104
  33. sonusai/mixture/audio.py +10 -39
  34. sonusai/mixture/class_balancing.py +103 -0
  35. sonusai/mixture/config.py +251 -271
  36. sonusai/mixture/constants.py +35 -39
  37. sonusai/mixture/data_io.py +25 -36
  38. sonusai/mixture/db_datatypes.py +58 -22
  39. sonusai/mixture/effects.py +386 -0
  40. sonusai/mixture/feature.py +7 -11
  41. sonusai/mixture/generation.py +478 -628
  42. sonusai/mixture/helpers.py +82 -184
  43. sonusai/mixture/ir_delay.py +3 -4
  44. sonusai/mixture/ir_effects.py +77 -0
  45. sonusai/mixture/log_duration_and_sizes.py +6 -12
  46. sonusai/mixture/mixdb.py +910 -729
  47. sonusai/mixture/pad_audio.py +35 -0
  48. sonusai/mixture/resample.py +7 -0
  49. sonusai/mixture/sox_effects.py +195 -0
  50. sonusai/mixture/sox_help.py +650 -0
  51. sonusai/mixture/spectral_mask.py +2 -2
  52. sonusai/mixture/truth.py +17 -15
  53. sonusai/mixture/truth_functions/crm.py +12 -12
  54. sonusai/mixture/truth_functions/energy.py +22 -22
  55. sonusai/mixture/truth_functions/file.py +5 -5
  56. sonusai/mixture/truth_functions/metadata.py +4 -4
  57. sonusai/mixture/truth_functions/metrics.py +4 -4
  58. sonusai/mixture/truth_functions/phoneme.py +3 -3
  59. sonusai/mixture/truth_functions/sed.py +11 -13
  60. sonusai/mixture/truth_functions/target.py +10 -10
  61. sonusai/mkwav.py +26 -29
  62. sonusai/onnx_predict.py +240 -88
  63. sonusai/queries/__init__.py +2 -2
  64. sonusai/queries/queries.py +38 -34
  65. sonusai/speech/librispeech.py +1 -1
  66. sonusai/speech/mcgill.py +1 -1
  67. sonusai/speech/timit.py +2 -2
  68. sonusai/summarize_metric_spenh.py +10 -17
  69. sonusai/utils/__init__.py +7 -1
  70. sonusai/utils/asl_p56.py +2 -2
  71. sonusai/utils/asr.py +2 -2
  72. sonusai/utils/asr_functions/aaware_whisper.py +4 -5
  73. sonusai/utils/choice.py +31 -0
  74. sonusai/utils/compress.py +1 -1
  75. sonusai/utils/dataclass_from_dict.py +19 -1
  76. sonusai/utils/energy_f.py +3 -3
  77. sonusai/utils/evaluate_random_rule.py +15 -0
  78. sonusai/utils/keyboard_interrupt.py +12 -0
  79. sonusai/utils/onnx_utils.py +3 -17
  80. sonusai/utils/print_mixture_details.py +21 -19
  81. sonusai/utils/{temp_seed.py → rand.py} +3 -3
  82. sonusai/utils/read_predict_data.py +2 -2
  83. sonusai/utils/reshape.py +3 -3
  84. sonusai/utils/stratified_shuffle_split.py +3 -3
  85. sonusai/{mixture → utils}/tokenized_shell_vars.py +1 -1
  86. sonusai/utils/write_audio.py +2 -2
  87. sonusai/vars.py +11 -4
  88. {sonusai-0.20.3.dist-info → sonusai-1.0.2.dist-info}/METADATA +4 -2
  89. sonusai-1.0.2.dist-info/RECORD +138 -0
  90. sonusai/mixture/augmentation.py +0 -444
  91. sonusai/mixture/class_count.py +0 -15
  92. sonusai/mixture/eq_rule_is_valid.py +0 -45
  93. sonusai/mixture/target_class_balancing.py +0 -107
  94. sonusai/mixture/targets.py +0 -175
  95. sonusai-0.20.3.dist-info/RECORD +0 -128
  96. {sonusai-0.20.3.dist-info → sonusai-1.0.2.dist-info}/WHEEL +0 -0
  97. {sonusai-0.20.3.dist-info → sonusai-1.0.2.dist-info}/entry_points.txt +0 -0
@@ -15,6 +15,7 @@ AudioT: TypeAlias = npt.NDArray[np.float32]
15
15
 
16
16
  Truth: TypeAlias = Any
17
17
  TruthDict: TypeAlias = dict[str, Truth]
18
+ TruthsDict: TypeAlias = dict[str, TruthDict]
18
19
  Segsnr: TypeAlias = npt.NDArray[np.float32]
19
20
 
20
21
  AudioF: TypeAlias = npt.NDArray[np.complex64]
@@ -68,63 +69,34 @@ class TruthConfig(DataClassSonusAIMixin):
68
69
 
69
70
 
70
71
  TruthConfigs: TypeAlias = dict[str, TruthConfig]
72
+ TruthsConfigs: TypeAlias = dict[str, TruthConfigs]
73
+
74
+
71
75
  NumberStr: TypeAlias = float | int | str
72
76
  OptionalNumberStr: TypeAlias = NumberStr | None
73
77
  OptionalListNumberStr: TypeAlias = list[NumberStr] | None
74
- EQ: TypeAlias = tuple[float | int, float | int, float | int]
75
-
76
-
77
- @dataclass
78
- class AugmentationRuleEffects(DataClassSonusAIMixin):
79
- normalize: OptionalNumberStr = None
80
- pitch: OptionalNumberStr = None
81
- tempo: OptionalNumberStr = None
82
- gain: OptionalNumberStr = None
83
- eq1: OptionalListNumberStr = None
84
- eq2: OptionalListNumberStr = None
85
- eq3: OptionalListNumberStr = None
86
- lpf: OptionalNumberStr = None
87
- ir: OptionalNumberStr = None
88
-
89
-
90
- @dataclass
91
- class AugmentationRule(DataClassSonusAIMixin):
92
- pre: AugmentationRuleEffects
93
- post: AugmentationRuleEffects | None = None
94
- mixup: int = 1
95
78
 
96
79
 
97
- @dataclass
98
- class AugmentationEffects(DataClassSonusAIMixin):
99
- normalize: float | None = None
100
- pitch: float | None = None
101
- tempo: float | None = None
102
- gain: float | None = None
103
- eq1: EQ | None = None
104
- eq2: EQ | None = None
105
- eq3: EQ | None = None
106
- lpf: float | None = None
107
- ir: int | None = None
80
+ EffectList: TypeAlias = list[str]
108
81
 
109
82
 
110
83
  @dataclass
111
- class Augmentation(DataClassSonusAIMixin):
112
- pre: AugmentationEffects
113
- post: AugmentationEffects
84
+ class Effects(DataClassSonusAIMixin):
85
+ pre: EffectList
86
+ post: EffectList = field(default_factory=EffectList)
114
87
 
115
88
 
116
- @dataclass(frozen=True)
117
89
  class UniversalSNRGenerator:
118
- is_random: bool
119
- _raw_value: float | str
90
+ def __init__(self, raw_value: float | str) -> None:
91
+ self._raw_value = str(raw_value)
92
+ self.is_random = isinstance(raw_value, str) and raw_value.startswith("sai_rand")
120
93
 
121
94
  @property
122
95
  def value(self) -> float:
123
- if self.is_random:
124
- from .augmentation import evaluate_random_rule
125
-
126
- return float(evaluate_random_rule(str(self._raw_value)))
96
+ from sonusai.mixture.effects import evaluate_sai_random_float
127
97
 
98
+ if self.is_random:
99
+ return float(evaluate_sai_random_float(self._raw_value))
128
100
  return float(self._raw_value)
129
101
 
130
102
 
@@ -145,12 +117,14 @@ Speaker: TypeAlias = dict[str, str]
145
117
 
146
118
 
147
119
  @dataclass
148
- class TargetFile(DataClassSonusAIMixin):
120
+ class SourceFile(DataClassSonusAIMixin):
121
+ category: str
122
+ class_indices: list[int]
149
123
  name: str
150
124
  samples: int
151
- class_indices: list[int]
152
125
  truth_configs: TruthConfigs
153
- class_balancing_augmentation: AugmentationRule | None = None
126
+ class_balancing_effect: EffectList | None = None
127
+ id: int = -1
154
128
  level_type: str | None = None
155
129
  speaker_id: int | None = None
156
130
 
@@ -162,21 +136,9 @@ class TargetFile(DataClassSonusAIMixin):
162
136
 
163
137
 
164
138
  @dataclass
165
- class AugmentedTarget(DataClassSonusAIMixin):
166
- target_id: int
167
- target_augmentation_id: int
168
-
169
-
170
- @dataclass
171
- class NoiseFile(DataClassSonusAIMixin):
172
- name: str
173
- samples: int
174
-
175
- @property
176
- def duration(self) -> float:
177
- from .constants import SAMPLE_RATE
178
-
179
- return self.samples / SAMPLE_RATE
139
+ class EffectedFile(DataClassSonusAIMixin):
140
+ file_id: int
141
+ effect_id: int
180
142
 
181
143
 
182
144
  ClassCount: TypeAlias = list[int]
@@ -184,37 +146,6 @@ ClassCount: TypeAlias = list[int]
184
146
  GeneralizedIDs: TypeAlias = str | int | list[int] | range
185
147
 
186
148
 
187
- @dataclass
188
- class GenMixData:
189
- targets: list[AudioT] | None = None
190
- target: AudioT | None = None
191
- noise: AudioT | None = None
192
- mixture: AudioT | None = None
193
- truth_t: list[TruthDict] | None = None
194
- segsnr_t: Segsnr | None = None
195
-
196
-
197
- @dataclass
198
- class GenFTData:
199
- feature: Feature | None = None
200
- truth_f: TruthDict | None = None
201
- segsnr: Segsnr | None = None
202
-
203
-
204
- @dataclass
205
- class ImpulseResponseData:
206
- data: AudioT
207
- sample_rate: int
208
- delay: int
209
-
210
-
211
- @dataclass
212
- class ImpulseResponseFile:
213
- file: str
214
- tags: list[str]
215
- delay: int
216
-
217
-
218
149
  @dataclass(frozen=True)
219
150
  class SpectralMask(DataClassSonusAIMixin):
220
151
  f_max_width: int
@@ -226,68 +157,70 @@ class SpectralMask(DataClassSonusAIMixin):
226
157
 
227
158
  @dataclass(frozen=True)
228
159
  class TruthParameter(DataClassSonusAIMixin):
160
+ category: str
229
161
  name: str
230
162
  parameters: int | None
231
163
 
232
164
 
233
165
  @dataclass
234
- class Target(DataClassSonusAIMixin):
166
+ class Source(DataClassSonusAIMixin):
167
+ effects: Effects
235
168
  file_id: int
236
- augmentation: Augmentation
237
-
238
- @property
239
- def gain(self) -> float:
240
- # gain is used to back out the gain augmentation in order to return the target audio
241
- # to its normalized level when calculating truth (if needed).
242
- if self.augmentation.pre.gain is None:
243
- return 1.0
244
- return round(10 ** (self.augmentation.pre.gain / 20), ndigits=5)
169
+ pre_tempo: float = 1
170
+ repeat: bool = False
171
+ snr: UniversalSNR = field(default_factory=lambda: UniversalSNR(0))
172
+ snr_gain: float = 0
173
+ start: int = 0
245
174
 
246
175
 
247
- Targets: TypeAlias = list[Target]
248
-
249
-
250
- @dataclass
251
- class Noise(DataClassSonusAIMixin):
252
- file_id: int
253
- augmentation: Augmentation
176
+ Sources: TypeAlias = dict[str, Source]
177
+ SourcesAudioT: TypeAlias = dict[str, AudioT]
178
+ SourcesAudioF: TypeAlias = dict[str, AudioF]
254
179
 
255
180
 
256
181
  @dataclass
257
182
  class Mixture(DataClassSonusAIMixin):
258
183
  name: str
259
- targets: list[Target]
260
- noise: Noise
261
- noise_offset: int
262
184
  samples: int
263
- snr: UniversalSNR
185
+ all_sources: Sources
264
186
  spectral_mask_id: int
265
187
  spectral_mask_seed: int
266
- target_snr_gain: float = 1.0
267
- noise_snr_gain: float = 1.0
188
+
189
+ @property
190
+ def all_source_ids(self) -> dict[str, int]:
191
+ return {category: source.file_id for category, source in self.all_sources.items()}
192
+
193
+ @property
194
+ def sources(self) -> Sources:
195
+ return {category: source for category, source in self.all_sources.items() if category != "noise"}
196
+
197
+ @property
198
+ def source_ids(self) -> dict[str, int]:
199
+ return {category: source.file_id for category, source in self.sources.items()}
200
+
201
+ @property
202
+ def noise(self) -> Source:
203
+ return self.all_sources["noise"]
268
204
 
269
205
  @property
270
206
  def noise_id(self) -> int:
271
207
  return self.noise.file_id
272
208
 
273
209
  @property
274
- def target_ids(self) -> list[int]:
275
- return [target.file_id for target in self.targets]
210
+ def source_effects(self) -> dict[str, Effects]:
211
+ return {category: source.effects for category, source in self.sources.items()}
276
212
 
277
213
  @property
278
- def target_augmentations(self) -> list[Augmentation]:
279
- return [target.augmentation for target in self.targets]
214
+ def noise_effects(self) -> Effects:
215
+ return self.noise.effects
280
216
 
281
217
  @property
282
218
  def is_noise_only(self) -> bool:
283
- return self.snr < -96
219
+ return self.noise.snr < -96
284
220
 
285
221
  @property
286
- def is_target_only(self) -> bool:
287
- return self.snr > 96
288
-
289
- def target_gain(self, target_index: int) -> float:
290
- return (self.targets[target_index].gain if not self.is_noise_only else 0) * self.target_snr_gain
222
+ def is_source_only(self) -> bool:
223
+ return self.noise.snr > 96
291
224
 
292
225
 
293
226
  @dataclass(frozen=True)
@@ -302,7 +235,7 @@ class TransformConfig:
302
235
  @dataclass(frozen=True)
303
236
  class FeatureGeneratorConfig:
304
237
  feature_mode: str
305
- truth_parameters: dict[str, int | None]
238
+ truth_parameters: dict[str, dict[str, int | None]]
306
239
 
307
240
 
308
241
  @dataclass(frozen=True)
@@ -319,6 +252,37 @@ class FeatureGeneratorInfo:
319
252
  ASRConfigs: TypeAlias = dict[str, dict[str, Any]]
320
253
 
321
254
 
255
+ @dataclass
256
+ class GenMixData:
257
+ mixture: AudioT | None = None
258
+ truth_t: TruthsDict | None = None
259
+ segsnr_t: Segsnr | None = None
260
+ sources: SourcesAudioT | None = None
261
+ source: AudioT | None = None
262
+ noise: AudioT | None = None
263
+
264
+
265
+ @dataclass
266
+ class GenFTData:
267
+ feature: Feature | None = None
268
+ truth_f: TruthsDict | None = None
269
+ segsnr: Segsnr | None = None
270
+
271
+
272
+ @dataclass
273
+ class ImpulseResponseData:
274
+ data: AudioT
275
+ sample_rate: int
276
+ delay: int
277
+
278
+
279
+ @dataclass
280
+ class ImpulseResponseFile(DataClassSonusAIMixin):
281
+ name: str
282
+ tags: list[str]
283
+ delay: str | int = "auto"
284
+
285
+
322
286
  @dataclass
323
287
  class MixtureDatabaseConfig(DataClassSonusAIMixin):
324
288
  asr_configs: ASRConfigs
@@ -326,13 +290,11 @@ class MixtureDatabaseConfig(DataClassSonusAIMixin):
326
290
  class_labels: list[str]
327
291
  class_weights_threshold: list[float]
328
292
  feature: str
329
- impulse_response_files: list[ImpulseResponseFile]
293
+ ir_files: list[ImpulseResponseFile]
330
294
  mixtures: list[Mixture]
331
- noise_mix_mode: str
332
- noise_files: list[NoiseFile]
333
295
  num_classes: int
296
+ source_files: dict[str, list[SourceFile]]
334
297
  spectral_masks: list[SpectralMask]
335
- target_files: list[TargetFile]
336
298
 
337
299
 
338
300
  SpeechMetadata: TypeAlias = str | list[Interval] | None
@@ -353,7 +315,6 @@ class SnrFBinMetrics(NamedTuple):
353
315
 
354
316
 
355
317
  class SpeechMetrics(NamedTuple):
356
- pesq: float | None = None
357
318
  csig: float | None = None
358
319
  cbak: float | None = None
359
320
  covl: float | None = None
@@ -45,11 +45,10 @@ import signal
45
45
 
46
46
  import numpy as np
47
47
  from matplotlib import pyplot as plt
48
-
49
- from sonusai.mixture import AudioT
50
- from sonusai.mixture import Feature
51
- from sonusai.mixture import Predict
52
- from sonusai.mixture import Truth
48
+ from sonusai.datatypes import AudioT
49
+ from sonusai.datatypes import Feature
50
+ from sonusai.datatypes import Predict
51
+ from sonusai.datatypes import Truth
53
52
 
54
53
 
55
54
  def signal_handler(_sig, _frame):
sonusai/doc/doc.py CHANGED
@@ -1,4 +1,4 @@
1
- from sonusai.mixture import get_default_config
1
+ from ..mixture.config import get_default_config
2
2
 
3
3
 
4
4
  def doc_seed() -> str:
@@ -124,7 +124,7 @@ in this dataset.
124
124
 
125
125
 
126
126
  def get_truth_functions() -> str:
127
- from sonusai.mixture import truth_functions
127
+ from ..mixture import truth_functions
128
128
 
129
129
  functions = [function for function in dir(truth_functions) if not function.startswith("__")]
130
130
  text = "\nSupported truth functions:\n\n"
@@ -471,7 +471,7 @@ Rules must specify all the following parameters:
471
471
 
472
472
 
473
473
  def doc_config() -> str:
474
- from sonusai.mixture import VALID_CONFIGS
474
+ from ..mixture.constants import VALID_CONFIGS
475
475
 
476
476
  text = "\n"
477
477
  text += "The SonusAI database is defined using a config.yml file.\n\n"
@@ -482,7 +482,7 @@ def doc_config() -> str:
482
482
 
483
483
 
484
484
  def doc_asr_configs() -> str:
485
- from sonusai.utils import get_available_engines
485
+ from ..utils.asr import get_available_engines
486
486
 
487
487
  default = f"\nDefault value: {get_default_config()['asr_configs']}"
488
488
  engines = get_available_engines()
sonusai/doc.py CHANGED
@@ -13,16 +13,16 @@ Show SonusAI documentation.
13
13
  def main() -> None:
14
14
  from docopt import docopt
15
15
 
16
- import sonusai
16
+ from sonusai import __version__ as sai_version
17
17
  from sonusai.utils import trim_docstring
18
18
 
19
- args = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)
19
+ args = docopt(trim_docstring(__doc__), version=sai_version, options_first=True)
20
20
 
21
21
  from sonusai import doc
22
22
 
23
23
  topic = args["TOPIC"]
24
24
 
25
- print(f"SonusAI {sonusai.__version__} Documentation")
25
+ print(f"SonusAI {sai_version} Documentation")
26
26
  print("")
27
27
 
28
28
  topics = sorted([item[4:] for item in dir(doc) if item.startswith("doc_")])
@@ -42,4 +42,11 @@ def main() -> None:
42
42
 
43
43
 
44
44
  if __name__ == "__main__":
45
- main()
45
+ from sonusai import exception_handler
46
+ from sonusai.utils import register_keyboard_interrupt
47
+
48
+ register_keyboard_interrupt()
49
+ try:
50
+ main()
51
+ except Exception as e:
52
+ exception_handler(e)
sonusai/genft.py CHANGED
@@ -25,22 +25,8 @@ Outputs the following to the mixture database directory:
25
25
 
26
26
  """
27
27
 
28
- import signal
29
-
30
- from sonusai.mixture import GeneralizedIDs
31
- from sonusai.mixture import GenFTData
32
-
33
-
34
- def signal_handler(_sig, _frame):
35
- import sys
36
-
37
- from sonusai import logger
38
-
39
- logger.info("Canceled due to keyboard interrupt")
40
- sys.exit(1)
41
-
42
-
43
- signal.signal(signal.SIGINT, signal_handler)
28
+ from sonusai.datatypes import GeneralizedIDs
29
+ from sonusai.datatypes import GenFTData
44
30
 
45
31
 
46
32
  def genft(
@@ -48,7 +34,7 @@ def genft(
48
34
  mixids: GeneralizedIDs = "*",
49
35
  compute_truth: bool = True,
50
36
  compute_segsnr: bool = False,
51
- write: bool = False,
37
+ cache: bool = False,
52
38
  show_progress: bool = False,
53
39
  force: bool = True,
54
40
  no_par: bool = False,
@@ -69,8 +55,8 @@ def genft(
69
55
  location=location,
70
56
  compute_truth=compute_truth,
71
57
  compute_segsnr=compute_segsnr,
58
+ cache=cache,
72
59
  force=force,
73
- write=write,
74
60
  ),
75
61
  mixids,
76
62
  progress=progress,
@@ -82,33 +68,42 @@ def genft(
82
68
 
83
69
 
84
70
  def _genft_kernel(
85
- m_id: int, location: str, compute_truth: bool, compute_segsnr: bool, force: bool, write: bool
71
+ m_id: int,
72
+ location: str,
73
+ compute_truth: bool,
74
+ compute_segsnr: bool,
75
+ cache: bool,
76
+ force: bool,
86
77
  ) -> GenFTData:
78
+ from functools import partial
79
+ from typing import Any
80
+
87
81
  from sonusai.mixture import MixtureDatabase
88
82
  from sonusai.mixture import write_cached_data
89
83
  from sonusai.mixture import write_mixture_metadata
90
84
 
91
85
  mixdb = MixtureDatabase(location)
92
86
 
87
+ write_func = partial(write_cached_data, location=mixdb.location, name="mixture", index=mixdb.mixture(m_id).name)
88
+
93
89
  result = GenFTData()
94
90
 
95
- feature, truth_f = mixdb.mixture_ft(m_id=m_id, force=force)
91
+ mixture = mixdb.mixture_mixture(m_id)
92
+
93
+ feature, truth_f = mixdb.mixture_ft(m_id=m_id, mixture=mixture, force=force)
96
94
  result.feature = feature
97
- if write:
98
- write_cached_data(mixdb.location, "mixture", mixdb.mixture(m_id).name, [("feature", feature)])
95
+ items: dict[str, Any] = {"feature": feature}
99
96
 
100
97
  if compute_truth:
101
98
  result.truth_f = truth_f
102
- if write:
103
- write_cached_data(mixdb.location, "mixture", mixdb.mixture(m_id).name, [("truth_f", truth_f)])
99
+ items["truth_f"] = truth_f
104
100
 
105
101
  if compute_segsnr:
106
- segsnr = mixdb.mixture_segsnr(m_id=m_id, force=force)
107
- result.segsnr = segsnr
108
- if write:
109
- write_cached_data(mixdb.location, "mixture", mixdb.mixture(m_id).name, [("segsnr", segsnr)])
102
+ segsnr_t = mixdb.mixture_segsnr_t(m_id)
103
+ result.segsnr = mixdb.mixture_segsnr(m_id=m_id, segsnr_t=segsnr_t, force=force, cache=cache)
110
104
 
111
- if write:
105
+ if cache:
106
+ write_func(items=items)
112
107
  write_mixture_metadata(mixdb, m_id=m_id)
113
108
 
114
109
  return result
@@ -117,10 +112,10 @@ def _genft_kernel(
117
112
  def main() -> None:
118
113
  from docopt import docopt
119
114
 
120
- import sonusai
115
+ from sonusai import __version__ as sai_version
121
116
  from sonusai.utils import trim_docstring
122
117
 
123
- args = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)
118
+ args = docopt(trim_docstring(__doc__), version=sai_version, options_first=True)
124
119
 
125
120
  import time
126
121
  from os.path import join
@@ -129,7 +124,7 @@ def main() -> None:
129
124
  from sonusai import initial_log_messages
130
125
  from sonusai import logger
131
126
  from sonusai import update_console_handler
132
- from sonusai.mixture import SAMPLE_RATE
127
+ from sonusai.constants import SAMPLE_RATE
133
128
  from sonusai.mixture import MixtureDatabase
134
129
  from sonusai.mixture import check_audio_files_exist
135
130
  from sonusai.utils import human_readable_size
@@ -166,18 +161,14 @@ def main() -> None:
166
161
 
167
162
  check_audio_files_exist(mixdb)
168
163
 
169
- try:
170
- genft(
171
- location=location,
172
- mixids=mixids,
173
- compute_segsnr=compute_segsnr,
174
- write=True,
175
- show_progress=True,
176
- no_par=no_par,
177
- )
178
- except Exception as e:
179
- logger.debug(e)
180
- raise
164
+ genft(
165
+ location=location,
166
+ mixids=mixids,
167
+ compute_segsnr=compute_segsnr,
168
+ cache=True,
169
+ show_progress=True,
170
+ no_par=no_par,
171
+ )
181
172
 
182
173
  logger.info(f"Wrote {len(mixids)} mixtures to {location}")
183
174
  logger.info("")
@@ -195,4 +186,11 @@ def main() -> None:
195
186
 
196
187
 
197
188
  if __name__ == "__main__":
198
- main()
189
+ from sonusai import exception_handler
190
+ from sonusai.utils import register_keyboard_interrupt
191
+
192
+ register_keyboard_interrupt()
193
+ try:
194
+ main()
195
+ except Exception as e:
196
+ exception_handler(e)
sonusai/genmetrics.py CHANGED
@@ -1,12 +1,13 @@
1
1
  """sonusai genmetrics
2
2
 
3
- usage: genmetrics [-hvusd] [-i MIXID] [-n INCLUDE] [-x EXCLUDE] LOC
3
+ usage: genmetrics [-hvusd] [-i MIXID] [-n INCLUDE] [-p NUMPROC] [-x EXCLUDE] LOC
4
4
 
5
5
  options:
6
6
  -h, --help
7
7
  -v, --verbose Be verbose.
8
8
  -i MIXID, --mixid MIXID Mixture ID(s) to generate. [default: *].
9
9
  -n INCLUDE, --include INCLUDE Metrics to include. [default: all]
10
+ -p NUMPROC, --nproc NUMPROC Number of parallel processes to use. Default single thread.
10
11
  -x EXCLUDE, --exclude EXCLUDE Metrics to exclude. [default: none]
11
12
  -u, --update Update metrics (do not regenerate existing metrics).
12
13
  -s, --supported Show list of supported metrics.
@@ -46,28 +47,16 @@ Generate all available metrics except for mxcovl
46
47
 
47
48
  """
48
49
 
49
- import signal
50
-
51
-
52
- def signal_handler(_sig, _frame):
53
- import sys
54
-
55
- from sonusai import logger
56
-
57
- logger.info("Canceled due to keyboard interrupt")
58
- sys.exit(1)
59
-
60
-
61
- signal.signal(signal.SIGINT, signal_handler)
62
-
63
50
 
64
51
  def _process_mixture(mixid: int, location: str, metrics: list[str], update: bool = False) -> set[str]:
65
52
  from sonusai.mixture import MixtureDatabase
66
53
  from sonusai.mixture import write_cached_data
54
+ from sonusai.mixture import write_mixture_metadata
67
55
 
68
56
  mixdb = MixtureDatabase(location)
69
57
  results = mixdb.mixture_metrics(m_id=mixid, metrics=metrics, force=not update)
70
- write_cached_data(mixdb.location, "mixture", mixdb.mixture(mixid).name, list(results.items()))
58
+ write_cached_data(mixdb.location, "mixture", mixdb.mixture(mixid).name, results)
59
+ write_mixture_metadata(mixdb, mixture=mixdb.mixture(mixid))
71
60
 
72
61
  return set(results.keys())
73
62
 
@@ -75,11 +64,11 @@ def _process_mixture(mixid: int, location: str, metrics: list[str], update: bool
75
64
  def main() -> None:
76
65
  from docopt import docopt
77
66
 
78
- import sonusai
67
+ from sonusai import __version__ as sai_version
79
68
  from sonusai.mixture import MixtureDatabase
80
69
  from sonusai.utils import trim_docstring
81
70
 
82
- args = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)
71
+ args = docopt(trim_docstring(__doc__), version=sai_version, options_first=True)
83
72
 
84
73
  verbose = args["--verbose"]
85
74
  mixids = args["--mixid"]
@@ -88,6 +77,7 @@ def main() -> None:
88
77
  update = args["--update"]
89
78
  show_supported = args["--supported"]
90
79
  dryrun = args["--dryrun"]
80
+ num_proc = args["--nproc"]
91
81
  location = args["LOC"]
92
82
 
93
83
  import fnmatch
@@ -157,11 +147,20 @@ def main() -> None:
157
147
  logger.info("")
158
148
  logger.info(f"Found {len(mixids):,} mixtures to process")
159
149
 
150
+ if num_proc is None or len(mixids) == 1:
151
+ no_par = True
152
+ num_proc = None
153
+ else:
154
+ no_par = False
155
+ num_proc = int(num_proc) # TBD add support for 'auto'
156
+
160
157
  progress = track(total=len(mixids), desc="genmetrics")
161
158
  results = par_track(
162
159
  partial(_process_mixture, location=location, metrics=metrics, update=update),
163
160
  mixids,
164
161
  progress=progress,
162
+ num_cpus=num_proc,
163
+ no_par=no_par,
165
164
  )
166
165
  progress.close()
167
166
 
@@ -176,4 +175,11 @@ def main() -> None:
176
175
 
177
176
 
178
177
  if __name__ == "__main__":
179
- main()
178
+ from sonusai import exception_handler
179
+ from sonusai.utils import register_keyboard_interrupt
180
+
181
+ register_keyboard_interrupt()
182
+ try:
183
+ main()
184
+ except Exception as e:
185
+ exception_handler(e)