sonusai 0.19.5__py3-none-any.whl → 0.19.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61) hide show
  1. sonusai/__init__.py +1 -1
  2. sonusai/aawscd_probwrite.py +1 -1
  3. sonusai/calc_metric_spenh.py +1 -1
  4. sonusai/genft.py +38 -49
  5. sonusai/genmetrics.py +65 -70
  6. sonusai/genmix.py +62 -72
  7. sonusai/genmixdb.py +73 -95
  8. sonusai/metrics/calc_class_weights.py +1 -3
  9. sonusai/metrics/calc_optimal_thresholds.py +2 -2
  10. sonusai/metrics/calc_phase_distance.py +1 -1
  11. sonusai/metrics/calc_segsnr_f.py +1 -1
  12. sonusai/metrics/calc_speech.py +6 -6
  13. sonusai/metrics/class_summary.py +6 -15
  14. sonusai/metrics/confusion_matrix_summary.py +11 -27
  15. sonusai/metrics/one_hot.py +3 -3
  16. sonusai/metrics/snr_summary.py +7 -7
  17. sonusai/mixture/__init__.py +3 -17
  18. sonusai/mixture/augmentation.py +5 -6
  19. sonusai/mixture/class_count.py +1 -1
  20. sonusai/mixture/config.py +36 -46
  21. sonusai/mixture/data_io.py +30 -1
  22. sonusai/mixture/datatypes.py +29 -40
  23. sonusai/mixture/db_datatypes.py +1 -1
  24. sonusai/mixture/feature.py +3 -23
  25. sonusai/mixture/generation.py +202 -235
  26. sonusai/mixture/helpers.py +29 -187
  27. sonusai/mixture/mixdb.py +386 -159
  28. sonusai/mixture/soundfile_audio.py +1 -1
  29. sonusai/mixture/sox_audio.py +4 -4
  30. sonusai/mixture/sox_augmentation.py +1 -1
  31. sonusai/mixture/target_class_balancing.py +9 -11
  32. sonusai/mixture/targets.py +23 -20
  33. sonusai/mixture/truth.py +21 -34
  34. sonusai/mixture/truth_functions/__init__.py +6 -0
  35. sonusai/mixture/truth_functions/crm.py +51 -37
  36. sonusai/mixture/truth_functions/energy.py +95 -50
  37. sonusai/mixture/truth_functions/file.py +12 -8
  38. sonusai/mixture/truth_functions/metadata.py +24 -0
  39. sonusai/mixture/truth_functions/metrics.py +28 -0
  40. sonusai/mixture/truth_functions/phoneme.py +4 -5
  41. sonusai/mixture/truth_functions/sed.py +32 -23
  42. sonusai/mixture/truth_functions/target.py +62 -29
  43. sonusai/mkwav.py +34 -43
  44. sonusai/queries/queries.py +9 -15
  45. sonusai/speech/l2arctic.py +6 -2
  46. sonusai/summarize_metric_spenh.py +1 -1
  47. sonusai/utils/__init__.py +1 -0
  48. sonusai/utils/asr_functions/aaware_whisper.py +1 -1
  49. sonusai/utils/audio_devices.py +27 -18
  50. sonusai/utils/docstring.py +6 -3
  51. sonusai/utils/energy_f.py +5 -3
  52. sonusai/utils/human_readable_size.py +6 -6
  53. sonusai/utils/load_object.py +15 -0
  54. sonusai/utils/onnx_utils.py +2 -2
  55. sonusai/utils/parallel.py +3 -5
  56. sonusai/utils/print_mixture_details.py +3 -3
  57. {sonusai-0.19.5.dist-info → sonusai-0.19.8.dist-info}/METADATA +2 -2
  58. {sonusai-0.19.5.dist-info → sonusai-0.19.8.dist-info}/RECORD +60 -58
  59. sonusai/mixture/truth_functions/datatypes.py +0 -37
  60. {sonusai-0.19.5.dist-info → sonusai-0.19.8.dist-info}/WHEEL +0 -0
  61. {sonusai-0.19.5.dist-info → sonusai-0.19.8.dist-info}/entry_points.txt +0 -0
@@ -1,6 +1,5 @@
1
- from sonusai.mixture.datatypes import Truth
2
- from sonusai.mixture.truth_functions.datatypes import TruthFunctionConfig
3
- from sonusai.mixture.truth_functions.datatypes import TruthFunctionData
1
+ from sonusai.mixture import MixtureDatabase
2
+ from sonusai.mixture import Truth
4
3
 
5
4
 
6
5
  def file_validate(config: dict) -> None:
@@ -17,28 +16,33 @@ def file_validate(config: dict) -> None:
17
16
  raise ValueError("Truth file does not contain truth_f dataset")
18
17
 
19
18
 
20
- def file_parameters(config: TruthFunctionConfig) -> int:
19
+ def file_parameters(_feature: str, _num_classes: int, config: dict) -> int:
21
20
  import h5py
22
21
  import numpy as np
23
22
 
24
- with h5py.File(config.config["file"], "r") as f:
23
+ with h5py.File(config["file"], "r") as f:
25
24
  truth = np.array(f["truth_f"])
26
25
 
27
26
  return truth.shape[-1]
28
27
 
29
28
 
30
- def file(data: TruthFunctionData, config: TruthFunctionConfig) -> Truth:
29
+ def file(mixdb: MixtureDatabase, m_id: int, target_index: int, config: dict) -> Truth:
31
30
  """file truth function documentation"""
32
31
  import h5py
33
32
  import numpy as np
33
+ from pyaaware import feature_inverse_transform_config
34
+
35
+ target_audio = mixdb.mixture_targets(m_id)[target_index]
34
36
 
35
- with h5py.File(config.config["file"], "r") as f:
37
+ frame_size = feature_inverse_transform_config(mixdb.feature)["overlap"]
38
+
39
+ with h5py.File(config["file"], "r") as f:
36
40
  truth = np.array(f["truth_f"])
37
41
 
38
42
  if truth.ndim != 2:
39
43
  raise ValueError("Truth file data is not 2 dimensions")
40
44
 
41
- if truth.shape[0] != len(data.target_audio) // config.frame_size:
45
+ if truth.shape[0] != len(target_audio) // frame_size:
42
46
  raise ValueError("Truth file does not contain the right amount of frames")
43
47
 
44
48
  return truth
@@ -0,0 +1,24 @@
1
+ from sonusai.mixture import MixtureDatabase
2
+ from sonusai.mixture import Truth
3
+
4
+
5
+ def metadata_validate(config: dict) -> None:
6
+ if len(config) == 0:
7
+ raise AttributeError("metadata truth function is missing config")
8
+
9
+ parameters = ["tier"]
10
+ for parameter in parameters:
11
+ if parameter not in config:
12
+ raise AttributeError(f"metadata truth function is missing required '{parameter}'")
13
+
14
+
15
+ def metadata_parameters(_feature: str, _num_classes: int, _config: dict) -> int | None:
16
+ return None
17
+
18
+
19
+ def metadata(mixdb: MixtureDatabase, m_id: int, target_index: int, config: dict) -> Truth:
20
+ """Metadata truth generation function
21
+
22
+ Retrieves metadata from target.
23
+ """
24
+ return mixdb.mixture_speech_metadata(m_id, config["tier"])[target_index]
@@ -0,0 +1,28 @@
1
+ from sonusai.mixture import MixtureDatabase
2
+ from sonusai.mixture import Truth
3
+
4
+
5
+ def metrics_validate(config: dict) -> None:
6
+ if len(config) == 0:
7
+ raise AttributeError("metrics truth function is missing config")
8
+
9
+ parameters = ["metric"]
10
+ for parameter in parameters:
11
+ if parameter not in config:
12
+ raise AttributeError(f"metrics truth function is missing required '{parameter}'")
13
+
14
+
15
+ def metrics_parameters(_feature: str, _num_classes: int, _config: dict) -> int | None:
16
+ return None
17
+
18
+
19
+ def metrics(mixdb: MixtureDatabase, m_id: int, target_index: int, config: dict) -> Truth:
20
+ """Metadata truth generation function
21
+
22
+ Retrieves metrics from target.
23
+ """
24
+ if not isinstance(config["metric"], list):
25
+ m = [config["metric"]]
26
+ else:
27
+ m = config["metric"]
28
+ return mixdb.mixture_metrics(m_id, m)[0][target_index]
@@ -1,17 +1,16 @@
1
- from sonusai.mixture.datatypes import Truth
2
- from sonusai.mixture.truth_functions.datatypes import TruthFunctionConfig
3
- from sonusai.mixture.truth_functions.datatypes import TruthFunctionData
1
+ from sonusai.mixture import MixtureDatabase
2
+ from sonusai.mixture import Truth
4
3
 
5
4
 
6
5
  def phoneme_validate(_config: dict) -> None:
7
6
  raise NotImplementedError("Truth function phoneme is not supported yet")
8
7
 
9
8
 
10
- def phoneme_parameters(_config: TruthFunctionConfig) -> int:
9
+ def phoneme_parameters(_feature: str, _num_classes: int, _config: dict) -> int:
11
10
  raise NotImplementedError("Truth function phoneme is not supported yet")
12
11
 
13
12
 
14
- def phoneme(_data: TruthFunctionData, _config: TruthFunctionConfig) -> Truth:
13
+ def phoneme(_mixdb: MixtureDatabase, _m_id: int, _target_index: int, _config: dict) -> Truth:
15
14
  """Read in .txt transcript and run a Python function to generate text grid data
16
15
  (indicating which phonemes are active). Then generate truth based on this data and put
17
16
  in the correct classes based on the index in the config.
@@ -1,12 +1,5 @@
1
- from sonusai.mixture.datatypes import Truth
2
- from sonusai.mixture.truth_functions.datatypes import TruthFunctionConfig
3
- from sonusai.mixture.truth_functions.datatypes import TruthFunctionData
4
-
5
-
6
- def _strictly_decreasing(list_to_check: list) -> bool:
7
- from itertools import pairwise
8
-
9
- return all(x > y for x, y in pairwise(list_to_check))
1
+ from sonusai.mixture import MixtureDatabase
2
+ from sonusai.mixture import Truth
10
3
 
11
4
 
12
5
  def sed_validate(config: dict) -> None:
@@ -23,11 +16,11 @@ def sed_validate(config: dict) -> None:
23
16
  raise ValueError(f"sed truth function 'thresholds' are not strictly decreasing: {thresholds}")
24
17
 
25
18
 
26
- def sed_parameters(config: TruthFunctionConfig) -> int:
27
- return config.num_classes
19
+ def sed_parameters(_feature: str, num_classes: int, _config: dict) -> int:
20
+ return num_classes
28
21
 
29
22
 
30
- def sed(data: TruthFunctionData, config: TruthFunctionConfig) -> Truth:
23
+ def sed(mixdb: MixtureDatabase, m_id: int, target_index: int, config: dict) -> Truth:
31
24
  """Sound energy detection truth generation function
32
25
 
33
26
  Calculates sound energy detection truth using simple 3 threshold
@@ -62,30 +55,46 @@ def sed(data: TruthFunctionData, config: TruthFunctionConfig) -> Truth:
62
55
  import numpy as np
63
56
  import torch
64
57
  from pyaaware import SED
58
+ from pyaaware import ForwardTransform
59
+ from pyaaware import feature_forward_transform_config
60
+ from pyaaware import feature_inverse_transform_config
61
+
62
+ target_audio = torch.from_numpy(mixdb.mixture_targets(m_id)[target_index])
63
+
64
+ frame_size = feature_inverse_transform_config(mixdb.feature)["overlap"]
65
65
 
66
- if len(data.target_audio) % config.frame_size != 0:
67
- raise ValueError(f"Number of samples in audio is not a multiple of {config.frame_size}")
66
+ ft = ForwardTransform(**feature_forward_transform_config(mixdb.feature))
68
67
 
69
- frames = config.target_fft.frames(data.target_audio)
70
- parameters = sed_parameters(config)
71
- if config.target_gain == 0:
68
+ if len(target_audio) % frame_size != 0:
69
+ raise ValueError(f"Number of samples in audio is not a multiple of {frame_size}")
70
+
71
+ frames = ft.frames(target_audio)
72
+ parameters = sed_parameters(mixdb.feature, mixdb.num_classes, config)
73
+ target_gain = mixdb.mixture(m_id).target_gain(target_index)
74
+ if target_gain == 0:
72
75
  return np.zeros((frames, parameters), dtype=np.float32)
73
76
 
74
77
  # SED wants 1-based indices
75
78
  s = SED(
76
- thresholds=config.config["thresholds"],
77
- index=config.class_indices,
78
- frame_size=config.frame_size,
79
- num_classes=config.num_classes,
79
+ thresholds=config["thresholds"],
80
+ index=mixdb.target_file(mixdb.mixture(m_id).targets[target_index].file_id).class_indices,
81
+ frame_size=frame_size,
82
+ num_classes=mixdb.num_classes,
80
83
  )
81
84
 
82
85
  # Back out target gain
83
- target_audio = data.target_audio / config.target_gain
86
+ target_audio = target_audio / target_gain
84
87
 
85
88
  # Compute energy
86
- target_energy = config.target_fft.execute_all(torch.from_numpy(target_audio))[1].numpy()
89
+ target_energy = ft.execute_all(target_audio)[1].numpy()
87
90
 
88
91
  if frames != target_energy.shape[0]:
89
92
  raise ValueError("Incorrect frames calculation in sed truth function")
90
93
 
91
94
  return s.execute_all(target_energy)
95
+
96
+
97
+ def _strictly_decreasing(list_to_check: list) -> bool:
98
+ from itertools import pairwise
99
+
100
+ return all(x > y for x, y in pairwise(list_to_check))
@@ -1,21 +1,24 @@
1
- from sonusai.mixture.datatypes import AudioF
2
- from sonusai.mixture.datatypes import Truth
3
- from sonusai.mixture.truth_functions.datatypes import TruthFunctionConfig
4
- from sonusai.mixture.truth_functions.datatypes import TruthFunctionData
1
+ from sonusai.mixture import MixtureDatabase
2
+ from sonusai.mixture import Truth
5
3
 
6
4
 
7
5
  def target_f_validate(_config: dict) -> None:
8
6
  pass
9
7
 
10
8
 
11
- def target_f_parameters(config: TruthFunctionConfig) -> int:
12
- if config.ttype == "tdac-co":
13
- return config.target_fft.bins
9
+ def target_f_parameters(feature: str, _num_classes: int, _config: dict) -> int:
10
+ from pyaaware import ForwardTransform
11
+ from pyaaware import feature_forward_transform_config
14
12
 
15
- return config.target_fft.bins * 2
13
+ ft = ForwardTransform(**feature_forward_transform_config(feature))
16
14
 
15
+ if ft.ttype == "tdac-co":
16
+ return ft.bins
17
17
 
18
- def target_f(data: TruthFunctionData, config: TruthFunctionConfig) -> Truth:
18
+ return ft.bins * 2
19
+
20
+
21
+ def target_f(mixdb: MixtureDatabase, m_id: int, target_index: int, _config: dict) -> Truth:
19
22
  """Frequency domain target truth function
20
23
 
21
24
  Calculates the true transform of the target using the STFT
@@ -26,23 +29,34 @@ def target_f(data: TruthFunctionData, config: TruthFunctionConfig) -> Truth:
26
29
  [:, bins] (target real only for tdac-co)
27
30
  """
28
31
  import torch
32
+ from pyaaware import ForwardTransform
33
+ from pyaaware import feature_forward_transform_config
34
+
35
+ ft = ForwardTransform(**feature_forward_transform_config(mixdb.feature))
29
36
 
30
- target_freq = config.target_fft.execute_all(torch.from_numpy(data.target_audio))[0].numpy()
31
- return _stack_real_imag(target_freq, config.ttype)
37
+ target_audio = torch.from_numpy(mixdb.mixture_targets(m_id)[target_index])
38
+
39
+ target_freq = ft.execute_all(target_audio)[0].numpy()
40
+ return _stack_real_imag(target_freq, ft.ttype)
32
41
 
33
42
 
34
43
  def target_mixture_f_validate(_config: dict) -> None:
35
44
  pass
36
45
 
37
46
 
38
- def target_mixture_f_parameters(config: TruthFunctionConfig) -> int:
39
- if config.ttype == "tdac-co":
40
- return config.target_fft.bins * 2
47
+ def target_mixture_f_parameters(feature: str, _num_classes: int, _config: dict) -> int:
48
+ from pyaaware import ForwardTransform
49
+ from pyaaware import feature_forward_transform_config
50
+
51
+ ft = ForwardTransform(**feature_forward_transform_config(feature))
52
+
53
+ if ft.ttype == "tdac-co":
54
+ return ft.bins * 2
41
55
 
42
- return config.target_fft.bins * 4
56
+ return ft.bins * 4
43
57
 
44
58
 
45
- def target_mixture_f(data: TruthFunctionData, config: TruthFunctionConfig) -> Truth:
59
+ def target_mixture_f(mixdb: MixtureDatabase, m_id: int, target_index: int, _config: dict) -> Truth:
46
60
  """Frequency domain target and mixture truth function
47
61
 
48
62
  Calculates the true transform of the target and the mixture
@@ -55,14 +69,21 @@ def target_mixture_f(data: TruthFunctionData, config: TruthFunctionConfig) -> Tr
55
69
  """
56
70
  import numpy as np
57
71
  import torch
72
+ from pyaaware import ForwardTransform
73
+ from pyaaware import feature_forward_transform_config
58
74
 
59
- target_freq = config.target_fft.execute_all(torch.from_numpy(data.target_audio))[0].numpy()
60
- mixture_freq = config.mixture_fft.execute_all(torch.from_numpy(data.mixture_audio))[0].numpy()
75
+ ft = ForwardTransform(**feature_forward_transform_config(mixdb.feature))
76
+
77
+ target_audio = torch.from_numpy(mixdb.mixture_targets(m_id)[target_index])
78
+ mixture_audio = torch.from_numpy(mixdb.mixture_mixture(m_id))
79
+
80
+ target_freq = ft.execute_all(torch.from_numpy(target_audio))[0].numpy()
81
+ mixture_freq = ft.execute_all(torch.from_numpy(mixture_audio))[0].numpy()
61
82
 
62
83
  frames, bins = target_freq.shape
63
84
  truth = np.empty((frames, bins * 4), dtype=np.float32)
64
- truth[:, : bins * 2] = _stack_real_imag(target_freq, config.ttype)
65
- truth[:, bins * 2 :] = _stack_real_imag(mixture_freq, config.ttype)
85
+ truth[:, : bins * 2] = _stack_real_imag(target_freq, ft.ttype)
86
+ truth[:, bins * 2 :] = _stack_real_imag(mixture_freq, ft.ttype)
66
87
  return truth
67
88
 
68
89
 
@@ -70,11 +91,14 @@ def target_swin_f_validate(_config: dict) -> None:
70
91
  pass
71
92
 
72
93
 
73
- def target_swin_f_parameters(config: TruthFunctionConfig) -> int:
74
- return config.target_fft.bins * 2
94
+ def target_swin_f_parameters(feature: str, _num_classes: int, _config: dict) -> int:
95
+ from pyaaware import ForwardTransform
96
+ from pyaaware import feature_forward_transform_config
97
+
98
+ return ForwardTransform(**feature_forward_transform_config(feature)).bins * 2
75
99
 
76
100
 
77
- def target_swin_f(data: TruthFunctionData, config: TruthFunctionConfig) -> Truth:
101
+ def target_swin_f(mixdb: MixtureDatabase, m_id: int, target_index: int, _config: dict) -> Truth:
78
102
  """Frequency domain target with synthesis window truth function
79
103
 
80
104
  Calculates the true transform of the target using the STFT
@@ -85,20 +109,29 @@ def target_swin_f(data: TruthFunctionData, config: TruthFunctionConfig) -> Truth
85
109
  Output shape: [:, 2 * bins] (stacked real, imag)
86
110
  """
87
111
  import numpy as np
112
+ import torch
113
+ from pyaaware import ForwardTransform
114
+ from pyaaware import InverseTransform
115
+ from pyaaware import feature_forward_transform_config
116
+ from pyaaware import feature_inverse_transform_config
88
117
 
89
118
  from sonusai.utils import stack_complex
90
119
 
91
- truth = np.empty((len(data.target_audio) // config.frame_size, config.target_fft.bins * 2), dtype=np.float32)
92
- for idx, offset in enumerate(range(0, len(data.target_audio), config.frame_size)):
93
- target_freq = config.target_fft.execute(
94
- np.multiply(data.target_audio[offset : offset + config.frame_size], config.swin)
95
- )[0]
120
+ ft = ForwardTransform(**feature_forward_transform_config(mixdb.feature))
121
+ it = InverseTransform(**feature_inverse_transform_config(mixdb.feature))
122
+
123
+ target_audio = mixdb.mixture_targets(m_id)[target_index]
124
+
125
+ truth = np.empty((len(target_audio) // ft.overlap, ft.bins * 2), dtype=np.float32)
126
+ for idx, offset in enumerate(range(0, len(target_audio), ft.overlap)):
127
+ audio_frame = torch.from_numpy(np.multiply(target_audio[offset : offset + ft.overlap], it.window))
128
+ target_freq = ft.execute(audio_frame)[0].numpy()
96
129
  truth[idx] = stack_complex(target_freq)
97
130
 
98
131
  return truth
99
132
 
100
133
 
101
- def _stack_real_imag(data: AudioF, ttype: str) -> Truth:
134
+ def _stack_real_imag(data: Truth, ttype: str) -> Truth:
102
135
  import numpy as np
103
136
 
104
137
  from sonusai.utils import stack_complex
sonusai/mkwav.py CHANGED
@@ -1,12 +1,13 @@
1
1
  """sonusai mkwav
2
2
 
3
- usage: mkwav [-hvtn] [-i MIXID] LOC
3
+ usage: mkwav [-hvtsn] [-i MIXID] LOC
4
4
 
5
5
  options:
6
6
  -h, --help
7
7
  -v, --verbose Be verbose.
8
8
  -i MIXID, --mixid MIXID Mixture ID(s) to generate. [default: *].
9
9
  -t, --target Write target file.
10
+ -s, --targets Write targets files.
10
11
  -n, --noise Write noise file.
11
12
 
12
13
  The mkwav command creates WAV files from a SonusAI database.
@@ -17,18 +18,16 @@ Inputs:
17
18
 
18
19
  Outputs the following to the mixture database directory:
19
20
  <id>
20
- mixture.wav: mixture
21
- target.wav: target (optional)
22
- noise.wav: noise (optional)
21
+ mixture.wav: mixture
22
+ target.wav: target (optional)
23
+ targets<n>.wav: targets <n> (optional)
24
+ noise.wav: noise (optional)
23
25
  metadata.txt
24
26
  mkwav.log
25
27
 
26
28
  """
27
29
 
28
30
  import signal
29
- from dataclasses import dataclass
30
-
31
- from sonusai.mixture import MixtureDatabase
32
31
 
33
32
 
34
33
  def signal_handler(_sig, _frame):
@@ -43,39 +42,28 @@ def signal_handler(_sig, _frame):
43
42
  signal.signal(signal.SIGINT, signal_handler)
44
43
 
45
44
 
46
- @dataclass
47
- class MPGlobal:
48
- mixdb: MixtureDatabase
49
- write_target: bool
50
- write_noise: bool
51
-
52
-
53
- MP_GLOBAL: MPGlobal
54
-
55
-
56
- def _process_mixture(m_id: int) -> None:
45
+ def _process_mixture(m_id: int, location: str, write_target: bool, write_targets: bool, write_noise: bool) -> None:
57
46
  from os.path import join
58
47
 
48
+ from sonusai.mixture import MixtureDatabase
59
49
  from sonusai.mixture import write_mixture_metadata
60
50
  from sonusai.utils import float_to_int16
61
51
  from sonusai.utils import write_audio
62
52
 
63
- global MP_GLOBAL
64
-
65
- mixdb = MP_GLOBAL.mixdb
66
- write_target = MP_GLOBAL.write_target
67
- write_noise = MP_GLOBAL.write_noise
53
+ mixdb = MixtureDatabase(location)
68
54
 
69
- mixture = mixdb.mixture(m_id)
70
- location = join(mixdb.location, mixture.name)
55
+ location = join(mixdb.location, "mixture", mixdb.mixture(m_id).name)
71
56
 
72
57
  write_audio(name=join(location, "mixture.wav"), audio=float_to_int16(mixdb.mixture_mixture(m_id)))
73
58
  if write_target:
74
59
  write_audio(name=join(location, "target.wav"), audio=float_to_int16(mixdb.mixture_target(m_id)))
60
+ if write_targets:
61
+ for idx, target in enumerate(mixdb.mixture_targets(m_id)):
62
+ write_audio(name=join(location, f"targets{idx}.wav"), audio=float_to_int16(target))
75
63
  if write_noise:
76
64
  write_audio(name=join(location, "noise.wav"), audio=float_to_int16(mixdb.mixture_noise(m_id)))
77
65
 
78
- write_mixture_metadata(mixdb, mixture)
66
+ write_mixture_metadata(mixdb, m_id)
79
67
 
80
68
 
81
69
  def main() -> None:
@@ -88,20 +76,21 @@ def main() -> None:
88
76
 
89
77
  verbose = args["--verbose"]
90
78
  mixid = args["--mixid"]
91
- MP_GLOBAL.write_target = args["--target"]
92
- MP_GLOBAL.write_noise = args["--noise"]
79
+ write_target = args["--target"]
80
+ write_targets = args["--targets"]
81
+ write_noise = args["--noise"]
93
82
  location = args["LOC"]
94
83
 
95
84
  import time
85
+ from functools import partial
96
86
  from os.path import join
97
87
 
98
- import sonusai
99
88
  from sonusai import create_file_handler
100
89
  from sonusai import initial_log_messages
101
90
  from sonusai import logger
102
91
  from sonusai import update_console_handler
92
+ from sonusai.mixture import MixtureDatabase
103
93
  from sonusai.mixture import check_audio_files_exist
104
- from sonusai.utils import human_readable_size
105
94
  from sonusai.utils import par_track
106
95
  from sonusai.utils import seconds_to_hms
107
96
  from sonusai.utils import track
@@ -113,31 +102,33 @@ def main() -> None:
113
102
  initial_log_messages("mkwav")
114
103
 
115
104
  logger.info(f"Load mixture database from {location}")
116
- MP_GLOBAL.mixdb = MixtureDatabase(location)
117
- mixid = MP_GLOBAL.mixdb.mixids_to_list(mixid)
105
+ mixdb = MixtureDatabase(location)
106
+ mixid = mixdb.mixids_to_list(mixid)
118
107
 
119
- total_samples = MP_GLOBAL.mixdb.total_samples(mixid)
120
- duration = total_samples / sonusai.mixture.SAMPLE_RATE
108
+ total_samples = mixdb.total_samples(mixid)
121
109
 
122
110
  logger.info("")
123
111
  logger.info(f"Found {len(mixid):,} mixtures to process")
124
112
  logger.info(f"{total_samples:,} samples")
125
113
 
126
- check_audio_files_exist(MP_GLOBAL.mixdb)
114
+ check_audio_files_exist(mixdb)
127
115
 
128
116
  progress = track(total=len(mixid))
129
- par_track(_process_mixture, mixid, progress=progress)
117
+ par_track(
118
+ partial(
119
+ _process_mixture,
120
+ location=location,
121
+ write_target=write_target,
122
+ write_targets=write_targets,
123
+ write_noise=write_noise,
124
+ ),
125
+ mixid,
126
+ progress=progress,
127
+ )
130
128
  progress.close()
131
129
 
132
130
  logger.info(f"Wrote {len(mixid)} mixtures to {location}")
133
131
  logger.info("")
134
- logger.info(f"Duration: {seconds_to_hms(seconds=duration)}")
135
- logger.info(f"mixture: {human_readable_size(total_samples * 2, 1)}")
136
- if MP_GLOBAL.write_target:
137
- logger.info(f"target: {human_readable_size(total_samples * 2, 1)}")
138
- if MP_GLOBAL.write_noise:
139
- logger.info(f"noise: {human_readable_size(total_samples * 2, 1)}")
140
-
141
132
  end_time = time.monotonic()
142
133
  logger.info(f"Completed in {seconds_to_hms(seconds=end_time - start_time)}")
143
134
  logger.info("")
@@ -5,6 +5,10 @@ from sonusai.mixture.datatypes import GeneralizedIDs
5
5
  from sonusai.mixture.mixdb import MixtureDatabase
6
6
 
7
7
 
8
+ def _true_predicate(_: Any) -> bool:
9
+ return True
10
+
11
+
8
12
  def get_mixids_from_mixture_field_predicate(
9
13
  mixdb: MixtureDatabase,
10
14
  field: str,
@@ -20,9 +24,7 @@ def get_mixids_from_mixture_field_predicate(
20
24
  mixid_out = mixdb.mixids_to_list(mixids)
21
25
 
22
26
  if predicate is None:
23
-
24
- def predicate(_: Any) -> bool:
25
- return True
27
+ predicate = _true_predicate
26
28
 
27
29
  criteria_set = set()
28
30
  for m_id in mixid_out:
@@ -70,9 +72,7 @@ def get_mixids_from_truth_configs_field_predicate(
70
72
  values = get_all_truth_configs_values_from_field(mixdb, field)
71
73
 
72
74
  if predicate is None:
73
-
74
- def predicate(_: Any) -> bool:
75
- return True
75
+ predicate = _true_predicate
76
76
 
77
77
  # Get only values of interest
78
78
  values = [value for value in values if predicate(value)]
@@ -118,7 +118,7 @@ def get_all_truth_configs_values_from_field(mixdb: MixtureDatabase, field: str)
118
118
  value = getattr(truth_config, field)
119
119
  else:
120
120
  value = getattr(truth_config.config, field, None)
121
- if isinstance(value, str):
121
+ if not isinstance(value, list):
122
122
  value = [value]
123
123
  result.extend(value)
124
124
 
@@ -164,17 +164,13 @@ def get_mixids_from_snr(
164
164
  - keys are the SNRs
165
165
  - values are lists of the mixids that match the SNR
166
166
  """
167
- from typing import Any
168
-
169
167
  mixid_out = mixdb.mixids_to_list(mixids)
170
168
 
171
169
  # Get all the SNRs
172
170
  snrs = [float(snr) for snr in mixdb.all_snrs if not snr.is_random]
173
171
 
174
172
  if predicate is None:
175
-
176
- def predicate(_: Any) -> bool:
177
- return True
173
+ predicate = _true_predicate
178
174
 
179
175
  # Get only the SNRs of interest (filter on predicate)
180
176
  snrs = [snr for snr in snrs if predicate(snr)]
@@ -201,9 +197,7 @@ def get_mixids_from_class_indices(
201
197
  mixid_out = mixdb.mixids_to_list(mixids)
202
198
 
203
199
  if predicate is None:
204
-
205
- def predicate(_: Any) -> bool:
206
- return True
200
+ predicate = _true_predicate
207
201
 
208
202
  criteria_set = set()
209
203
  for m_id in mixid_out:
@@ -54,6 +54,7 @@ def load_phonemes(audio: str | os.PathLike[str]) -> list[TimeAlignedType] | None
54
54
 
55
55
  def _load_ta(audio: str | os.PathLike[str], tier: str) -> list[TimeAlignedType] | None:
56
56
  from praatio import textgrid
57
+ from praatio.utilities.constants import Interval
57
58
 
58
59
  file = Path(audio).parent.parent / "textgrid" / (Path(audio).stem + ".TextGrid")
59
60
  if not os.path.exists(file):
@@ -65,7 +66,8 @@ def _load_ta(audio: str | os.PathLike[str], tier: str) -> list[TimeAlignedType]
65
66
 
66
67
  entries: list[TimeAlignedType] = []
67
68
  for entry in tg.getTier(tier).entries:
68
- entries.append(TimeAlignedType(text=entry.label, start=entry.start, end=entry.end))
69
+ if isinstance(entry, Interval):
70
+ entries.append(TimeAlignedType(text=entry.label, start=entry.start, end=entry.end))
69
71
 
70
72
  return entries
71
73
 
@@ -79,6 +81,7 @@ def load_annotations(
79
81
  :return: A dictionary of a list of TimeAlignedType objects.
80
82
  """
81
83
  from praatio import textgrid
84
+ from praatio.utilities.constants import Interval
82
85
 
83
86
  file = Path(audio).parent.parent / "annotation" / (Path(audio).stem + ".TextGrid")
84
87
  if not os.path.exists(file):
@@ -89,7 +92,8 @@ def load_annotations(
89
92
  for tier in tg.tierNames:
90
93
  entries: list[TimeAlignedType] = []
91
94
  for entry in tg.getTier(tier).entries:
92
- entries.append(TimeAlignedType(text=entry.label, start=entry.start, end=entry.end))
95
+ if isinstance(entry, Interval):
96
+ entries.append(TimeAlignedType(text=entry.label, start=entry.start, end=entry.end))
93
97
  result[tier] = entries
94
98
 
95
99
  return result
@@ -48,7 +48,7 @@ def summarize_metric_spenh(location: str, by: str = "MIXID", reverse: bool = Fal
48
48
  data.append(line.strip().split())
49
49
  break
50
50
 
51
- df = pd.DataFrame(data, columns=header)
51
+ df = pd.DataFrame(data, columns=header) # pyright: ignore [reportArgumentType]
52
52
  df[header[0:-2]] = df[header[0:-2]].apply(pd.to_numeric, errors="coerce")
53
53
  return df.sort_values(by=by, ascending=not reverse).to_string(index=False)
54
54
 
sonusai/utils/__init__.py CHANGED
@@ -27,6 +27,7 @@ from .get_frames_per_batch import get_frames_per_batch
27
27
  from .get_label_names import get_label_names
28
28
  from .grouper import grouper
29
29
  from .human_readable_size import human_readable_size
30
+ from .load_object import load_object
30
31
  from .max_text_width import max_text_width
31
32
  from .model_utils import import_module
32
33
  from .numeric_conversion import float_to_int16
@@ -20,7 +20,7 @@ def aaware_whisper(audio: AudioT, **_config) -> ASRResult:
20
20
 
21
21
  url = getenv("AAWARE_WHISPER_URL")
22
22
  if url is None:
23
- raise EnvironmentError("AAWARE_WHISPER_URL environment variable does not exist")
23
+ raise OSError("AAWARE_WHISPER_URL environment variable does not exist")
24
24
  url += "/asr?task=transcribe&language=en&encode=true&output=json"
25
25
 
26
26
  with tempfile.TemporaryDirectory() as tmp: