sonusai 0.20.3__py3-none-any.whl → 1.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97)
  1. sonusai/__init__.py +16 -3
  2. sonusai/audiofe.py +241 -77
  3. sonusai/calc_metric_spenh.py +71 -73
  4. sonusai/config/__init__.py +3 -0
  5. sonusai/config/config.py +61 -0
  6. sonusai/config/config.yml +20 -0
  7. sonusai/config/constants.py +8 -0
  8. sonusai/constants.py +11 -0
  9. sonusai/data/genmixdb.yml +21 -36
  10. sonusai/{mixture/datatypes.py → datatypes.py} +91 -130
  11. sonusai/deprecated/plot.py +4 -5
  12. sonusai/doc/doc.py +4 -4
  13. sonusai/doc.py +11 -4
  14. sonusai/genft.py +43 -45
  15. sonusai/genmetrics.py +25 -19
  16. sonusai/genmix.py +54 -82
  17. sonusai/genmixdb.py +88 -264
  18. sonusai/ir_metric.py +30 -34
  19. sonusai/lsdb.py +41 -48
  20. sonusai/main.py +15 -22
  21. sonusai/metrics/calc_audio_stats.py +4 -293
  22. sonusai/metrics/calc_class_weights.py +4 -4
  23. sonusai/metrics/calc_optimal_thresholds.py +8 -5
  24. sonusai/metrics/calc_pesq.py +2 -2
  25. sonusai/metrics/calc_segsnr_f.py +4 -4
  26. sonusai/metrics/calc_speech.py +25 -13
  27. sonusai/metrics/class_summary.py +7 -7
  28. sonusai/metrics/confusion_matrix_summary.py +5 -5
  29. sonusai/metrics/one_hot.py +4 -4
  30. sonusai/metrics/snr_summary.py +7 -7
  31. sonusai/metrics_summary.py +38 -45
  32. sonusai/mixture/__init__.py +4 -104
  33. sonusai/mixture/audio.py +10 -39
  34. sonusai/mixture/class_balancing.py +103 -0
  35. sonusai/mixture/config.py +251 -271
  36. sonusai/mixture/constants.py +35 -39
  37. sonusai/mixture/data_io.py +25 -36
  38. sonusai/mixture/db_datatypes.py +58 -22
  39. sonusai/mixture/effects.py +386 -0
  40. sonusai/mixture/feature.py +7 -11
  41. sonusai/mixture/generation.py +478 -628
  42. sonusai/mixture/helpers.py +82 -184
  43. sonusai/mixture/ir_delay.py +3 -4
  44. sonusai/mixture/ir_effects.py +77 -0
  45. sonusai/mixture/log_duration_and_sizes.py +6 -12
  46. sonusai/mixture/mixdb.py +910 -729
  47. sonusai/mixture/pad_audio.py +35 -0
  48. sonusai/mixture/resample.py +7 -0
  49. sonusai/mixture/sox_effects.py +195 -0
  50. sonusai/mixture/sox_help.py +650 -0
  51. sonusai/mixture/spectral_mask.py +2 -2
  52. sonusai/mixture/truth.py +17 -15
  53. sonusai/mixture/truth_functions/crm.py +12 -12
  54. sonusai/mixture/truth_functions/energy.py +22 -22
  55. sonusai/mixture/truth_functions/file.py +5 -5
  56. sonusai/mixture/truth_functions/metadata.py +4 -4
  57. sonusai/mixture/truth_functions/metrics.py +4 -4
  58. sonusai/mixture/truth_functions/phoneme.py +3 -3
  59. sonusai/mixture/truth_functions/sed.py +11 -13
  60. sonusai/mixture/truth_functions/target.py +10 -10
  61. sonusai/mkwav.py +26 -29
  62. sonusai/onnx_predict.py +240 -88
  63. sonusai/queries/__init__.py +2 -2
  64. sonusai/queries/queries.py +38 -34
  65. sonusai/speech/librispeech.py +1 -1
  66. sonusai/speech/mcgill.py +1 -1
  67. sonusai/speech/timit.py +2 -2
  68. sonusai/summarize_metric_spenh.py +10 -17
  69. sonusai/utils/__init__.py +7 -1
  70. sonusai/utils/asl_p56.py +2 -2
  71. sonusai/utils/asr.py +2 -2
  72. sonusai/utils/asr_functions/aaware_whisper.py +4 -5
  73. sonusai/utils/choice.py +31 -0
  74. sonusai/utils/compress.py +1 -1
  75. sonusai/utils/dataclass_from_dict.py +19 -1
  76. sonusai/utils/energy_f.py +3 -3
  77. sonusai/utils/evaluate_random_rule.py +15 -0
  78. sonusai/utils/keyboard_interrupt.py +12 -0
  79. sonusai/utils/onnx_utils.py +3 -17
  80. sonusai/utils/print_mixture_details.py +21 -19
  81. sonusai/utils/{temp_seed.py → rand.py} +3 -3
  82. sonusai/utils/read_predict_data.py +2 -2
  83. sonusai/utils/reshape.py +3 -3
  84. sonusai/utils/stratified_shuffle_split.py +3 -3
  85. sonusai/{mixture → utils}/tokenized_shell_vars.py +1 -1
  86. sonusai/utils/write_audio.py +2 -2
  87. sonusai/vars.py +11 -4
  88. {sonusai-0.20.3.dist-info → sonusai-1.0.2.dist-info}/METADATA +4 -2
  89. sonusai-1.0.2.dist-info/RECORD +138 -0
  90. sonusai/mixture/augmentation.py +0 -444
  91. sonusai/mixture/class_count.py +0 -15
  92. sonusai/mixture/eq_rule_is_valid.py +0 -45
  93. sonusai/mixture/target_class_balancing.py +0 -107
  94. sonusai/mixture/targets.py +0 -175
  95. sonusai-0.20.3.dist-info/RECORD +0 -128
  96. {sonusai-0.20.3.dist-info → sonusai-1.0.2.dist-info}/WHEEL +0 -0
  97. {sonusai-0.20.3.dist-info → sonusai-1.0.2.dist-info}/entry_points.txt +0 -0
sonusai/mixture/truth.py CHANGED
@@ -1,33 +1,35 @@
-from sonusai.mixture import MixtureDatabase
-from sonusai.mixture import Truth
+from ..datatypes import Truth
+from ..datatypes import TruthsDict
+from .mixdb import MixtureDatabase
 
 
-def truth_function(mixdb: MixtureDatabase, m_id: int) -> list[Truth]:
-    from sonusai.mixture import TruthDict
-    from sonusai.mixture import truth_functions
+def truth_function(mixdb: MixtureDatabase, m_id: int) -> TruthsDict:
+    from ..datatypes import TruthDict
+    from . import truth_functions
 
-    result: list[Truth] = []
-    for target_index in range(len(mixdb.mixture(m_id).targets)):
+    result: TruthsDict = {}
+    for category, source in mixdb.mixture(m_id).all_sources.items():
         truth: TruthDict = {}
-        target_file = mixdb.target_file(mixdb.mixture(m_id).targets[target_index].file_id)
-        for name, config in target_file.truth_configs.items():
+        source_file = mixdb.source_file(source.file_id)
+        for name, config in source_file.truth_configs.items():
             try:
-                truth[name] = getattr(truth_functions, config.function)(mixdb, m_id, target_index, config.config)
+                truth[name] = getattr(truth_functions, config.function)(mixdb, m_id, category, config.config)
             except AttributeError as e:
                 raise AttributeError(f"Unsupported truth function: {config.function}") from e
             except Exception as e:
                 raise RuntimeError(f"Error in truth function '{config.function}': {e}") from e
 
-        result.append(truth)
+        if truth:
+            result[category] = truth
 
     return result
 
 
-def get_truth_indices_for_mixid(mixdb: MixtureDatabase, mixid: int) -> list[int]:
-    """Get a list of truth indices for a given mixid."""
+def get_class_indices_for_mixid(mixdb: MixtureDatabase, mixid: int) -> list[int]:
+    """Get a list of class indices for a given mixid."""
     indices: list[int] = []
-    for target_id in [target.file_id for target in mixdb.mixture(mixid).targets]:
-        indices.append(*mixdb.target_file(target_id).class_indices)
+    for source_id in [source.file_id for source in mixdb.mixture(mixid).all_sources.values()]:
+        indices.append(*mixdb.source_file(source_id).class_indices)
 
     return sorted(set(indices))
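A minimal usage sketch of the new return type (not taken from the diff itself): it assumes a mixture database has already been generated at ./mixdb, that truth_function is imported from sonusai.mixture.truth, and the category names are illustrative. The practical change for callers is that truth is now keyed by source category instead of indexed by target position, and categories that produce no truth are simply absent.

    # Hedged adaptation sketch: 0.20.x returned list[TruthDict] indexed by target
    # position; 1.0.x returns a TruthsDict keyed by source category.
    from sonusai.mixture import MixtureDatabase
    from sonusai.mixture.truth import truth_function

    mixdb = MixtureDatabase("./mixdb")  # assumed location of a generated database
    truths = truth_function(mixdb, 0)   # dict: category -> {truth name -> value}
    for category, truth_dict in truths.items():
        for name, value in truth_dict.items():
            print(category, name, getattr(value, "shape", value))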
 
sonusai/mixture/truth_functions/crm.py CHANGED
@@ -1,31 +1,31 @@
-from sonusai.mixture import MixtureDatabase
-from sonusai.mixture import Truth
+from ...datatypes import Truth
+from ..mixdb import MixtureDatabase
 
 
-def _core(mixdb: MixtureDatabase, m_id: int, target_index: int, parameters: int, polar: bool) -> Truth:
+def _core(mixdb: MixtureDatabase, m_id: int, category: str, parameters: int, polar: bool) -> Truth:
     import numpy as np
     import torch
     from pyaaware import ForwardTransform
     from pyaaware import feature_forward_transform_config
     from pyaaware import feature_inverse_transform_config
 
-    target_audio = torch.from_numpy(mixdb.mixture_targets(m_id)[target_index])
+    source_audio = torch.from_numpy(mixdb.mixture_sources(m_id)[category])
     t_ft = ForwardTransform(**feature_forward_transform_config(mixdb.feature))
     n_ft = ForwardTransform(**feature_forward_transform_config(mixdb.feature))
 
-    frames = t_ft.frames(target_audio)
-    if mixdb.mixture(m_id).target_gain(target_index) == 0:
+    frames = t_ft.frames(source_audio)
+    if mixdb.mixture(m_id).all_sources[category].snr_gain == 0:
         return np.zeros((frames, parameters), dtype=np.float32)
 
     noise_audio = torch.from_numpy(mixdb.mixture_noise(m_id))
 
     frame_size = feature_inverse_transform_config(mixdb.feature)["overlap"]
 
-    frames = len(target_audio) // frame_size
+    frames = len(source_audio) // frame_size
     truth = np.empty((frames, t_ft.bins * 2), dtype=np.float32)
     for frame in range(frames):
         offset = frame * frame_size
-        target_f = t_ft.execute(target_audio[offset : offset + frame_size])[0].numpy().astype(np.complex64)
+        target_f = t_ft.execute(source_audio[offset : offset + frame_size])[0].numpy().astype(np.complex64)
         noise_f = n_ft.execute(noise_audio[offset : offset + frame_size])[0].numpy().astype(np.complex64)
         mixture_f = target_f + noise_f
@@ -58,7 +58,7 @@ def crm_parameters(feature: str, _num_classes: int, _config: dict) -> int:
     return ForwardTransform(**feature_forward_transform_config(feature)).bins * 2
 
 
-def crm(mixdb: MixtureDatabase, m_id: int, target_index: int, _config: dict) -> Truth:
+def crm(mixdb: MixtureDatabase, m_id: int, category: str, _config: dict) -> Truth:
     """Complex ratio mask truth generation function
 
     Calculates the true complex ratio mask (CRM) truth which is a complex number
@@ -71,7 +71,7 @@ def crm(mixdb: MixtureDatabase, m_id: int, target_index: int, _config: dict) ->
     return _core(
         mixdb=mixdb,
         m_id=m_id,
-        target_index=target_index,
+        category=category,
         parameters=crm_parameters(mixdb.feature, mixdb.num_classes, _config),
         polar=False,
     )
@@ -88,7 +88,7 @@ def crmp_parameters(feature: str, _num_classes: int, _config: dict) -> int:
     return ForwardTransform(**feature_forward_transform_config(feature)).bins * 2
 
 
-def crmp(mixdb: MixtureDatabase, m_id: int, target_index: int, _config: dict) -> Truth:
+def crmp(mixdb: MixtureDatabase, m_id: int, category: str, _config: dict) -> Truth:
     """Complex ratio mask polar truth generation function
 
     Same as the crm function except the results are magnitude and phase
@@ -99,7 +99,7 @@ def crmp(mixdb: MixtureDatabase, m_id: int, target_index: int, _config: dict) ->
     return _core(
         mixdb=mixdb,
         m_id=m_id,
-        target_index=target_index,
+        category=category,
         parameters=crmp_parameters(mixdb.feature, mixdb.num_classes, _config),
         polar=True,
     )
sonusai/mixture/truth_functions/energy.py CHANGED
@@ -1,14 +1,14 @@
 import numpy as np
 
-from sonusai.mixture import MixtureDatabase
-from sonusai.mixture import Truth
-from sonusai.utils import load_object
+from ...datatypes import Truth
+from ...utils.load_object import load_object
+from ..mixdb import MixtureDatabase
 
 
 def _core(
     mixdb: MixtureDatabase,
     m_id: int,
-    target_index: int,
+    category: str,
     config: dict,
     parameters: int,
     mapped: bool,
@@ -21,27 +21,27 @@ def _core(
     from pyaaware import ForwardTransform
     from pyaaware import feature_forward_transform_config
 
-    from sonusai.utils import compute_energy_f
+    from ...utils.energy_f import compute_energy_f
 
-    target_audio = mixdb.mixture_targets(m_id)[target_index]
+    source_audio = mixdb.mixture_sources(m_id)[category]
     ft = ForwardTransform(**feature_forward_transform_config(mixdb.feature))
 
-    frames = ft.frames(torch.from_numpy(target_audio))
+    frames = ft.frames(torch.from_numpy(source_audio))
 
-    if mixdb.mixture(m_id).target_gain(target_index) == 0:
+    if mixdb.mixture(m_id).all_sources[category].snr_gain == 0:
         return np.zeros((frames, parameters), dtype=np.float32)
 
     noise_audio = mixdb.mixture_noise(m_id)
 
-    target_energy = compute_energy_f(time_domain=target_audio, transform=ft)
+    source_energy = compute_energy_f(time_domain=source_audio, transform=ft)
     noise_energy = None
     if snr:
         noise_energy = compute_energy_f(time_domain=noise_audio, transform=ft)
 
-    frames = len(target_energy)
+    frames = len(source_energy)
     truth = np.empty((frames, ft.bins), dtype=np.float32)
     for frame in range(frames):
-        tmp = target_energy[frame]
+        tmp = source_energy[frame]
 
         if noise_energy is not None:
             old_err = np.seterr(divide="ignore", invalid="ignore")
@@ -86,7 +86,7 @@ def energy_f_parameters(feature: str, _num_classes: int, _config: dict) -> int:
     return ForwardTransform(**feature_forward_transform_config(feature)).bins
 
 
-def energy_f(mixdb: MixtureDatabase, m_id: int, target_index: int, config: dict, use_cache: bool = True) -> Truth:
+def energy_f(mixdb: MixtureDatabase, m_id: int, category: str, config: dict, use_cache: bool = True) -> Truth:
     """Frequency domain energy truth generation function
 
     Calculates the true energy per bin:
@@ -100,7 +100,7 @@ def energy_f(mixdb: MixtureDatabase, m_id: int, target_index: int, config: dict,
     return _core(
         mixdb=mixdb,
         m_id=m_id,
-        target_index=target_index,
+        category=category,
         config=config,
         parameters=energy_f_parameters(mixdb.feature, mixdb.num_classes, config),
         mapped=False,
@@ -120,7 +120,7 @@ def snr_f_parameters(feature: str, _num_classes: int, _config: dict) -> int:
     return ForwardTransform(**feature_forward_transform_config(feature)).bins
 
 
-def snr_f(mixdb: MixtureDatabase, m_id: int, target_index: int, config: dict, use_cache: bool = True) -> Truth:
+def snr_f(mixdb: MixtureDatabase, m_id: int, category: str, config: dict, use_cache: bool = True) -> Truth:
     """Frequency domain SNR truth function documentation
 
     Calculates the true SNR per bin:
@@ -134,7 +134,7 @@ def snr_f(mixdb: MixtureDatabase, m_id: int, target_index: int, config: dict, us
     return _core(
         mixdb=mixdb,
         m_id=m_id,
-        target_index=target_index,
+        category=category,
         config=config,
         parameters=snr_f_parameters(mixdb.feature, mixdb.num_classes, config),
         mapped=False,
@@ -159,7 +159,7 @@ def mapped_snr_f_parameters(feature: str, _num_classes: int, _config: dict) -> i
     return ForwardTransform(**feature_forward_transform_config(feature)).bins
 
 
-def mapped_snr_f(mixdb: MixtureDatabase, m_id: int, target_index: int, config: dict, use_cache: bool = True) -> Truth:
+def mapped_snr_f(mixdb: MixtureDatabase, m_id: int, category: str, config: dict, use_cache: bool = True) -> Truth:
     """Frequency domain mapped SNR truth function documentation
 
     Output shape: [:, bins]
@@ -167,7 +167,7 @@ def mapped_snr_f(mixdb: MixtureDatabase, m_id: int, target_index: int, config: d
     return _core(
         mixdb=mixdb,
         m_id=m_id,
-        target_index=target_index,
+        category=category,
         config=config,
         parameters=mapped_snr_f_parameters(mixdb.feature, mixdb.num_classes, config),
         mapped=True,
@@ -184,7 +184,7 @@ def energy_t_parameters(_feature: str, _num_classes: int, _config: dict) -> int:
     return 1
 
 
-def energy_t(mixdb: MixtureDatabase, m_id: int, target_index: int, _config: dict) -> Truth:
+def energy_t(mixdb: MixtureDatabase, m_id: int, category: str, _config: dict) -> Truth:
     """Time domain energy truth function documentation
 
     Calculates the true time domain energy of each frame:
@@ -210,13 +210,13 @@ def energy_t(mixdb: MixtureDatabase, m_id: int, target_index: int, _config: dict
     from pyaaware import ForwardTransform
     from pyaaware import feature_forward_transform_config
 
-    target_audio = torch.from_numpy(mixdb.mixture_targets(m_id)[target_index])
+    source_audio = torch.from_numpy(mixdb.mixture_sources(m_id)[category])
 
     ft = ForwardTransform(**feature_forward_transform_config(mixdb.feature))
 
-    frames = ft.frames(target_audio)
+    frames = ft.frames(source_audio)
     parameters = energy_f_parameters(mixdb.feature, mixdb.num_classes, _config)
-    if mixdb.mixture(m_id).target_gain(target_index) == 0:
+    if mixdb.mixture(m_id).all_sources[category].snr_gain == 0:
         return np.zeros((frames, parameters), dtype=np.float32)
 
-    return ft.execute_all(target_audio)[1].numpy()
+    return ft.execute_all(source_audio)[1].numpy()
sonusai/mixture/truth_functions/file.py CHANGED
@@ -1,5 +1,5 @@
-from sonusai.mixture import MixtureDatabase
-from sonusai.mixture import Truth
+from ...datatypes import Truth
+from ..mixdb import MixtureDatabase
 
 
 def file_validate(config: dict) -> None:
@@ -26,13 +26,13 @@ def file_parameters(_feature: str, _num_classes: int, config: dict) -> int:
     return truth.shape[-1]
 
 
-def file(mixdb: MixtureDatabase, m_id: int, target_index: int, config: dict) -> Truth:
+def file(mixdb: MixtureDatabase, m_id: int, category: str, config: dict) -> Truth:
     """file truth function documentation"""
     import h5py
     import numpy as np
     from pyaaware import feature_inverse_transform_config
 
-    target_audio = mixdb.mixture_targets(m_id)[target_index]
+    source_audio = mixdb.mixture_sources(m_id)[category]
 
     frame_size = feature_inverse_transform_config(mixdb.feature)["overlap"]
 
@@ -42,7 +42,7 @@ def file(mixdb: MixtureDatabase, m_id: int, target_index: int, config: dict) ->
     if truth.ndim != 2:
         raise ValueError("Truth file data is not 2 dimensions")
 
-    if truth.shape[0] != len(target_audio) // frame_size:
+    if truth.shape[0] != len(source_audio) // frame_size:
         raise ValueError("Truth file does not contain the right amount of frames")
 
     return truth
sonusai/mixture/truth_functions/metadata.py CHANGED
@@ -1,5 +1,5 @@
-from sonusai.mixture import MixtureDatabase
-from sonusai.mixture import Truth
+from ...datatypes import Truth
+from ..mixdb import MixtureDatabase
 
 
 def metadata_validate(config: dict) -> None:
@@ -16,9 +16,9 @@ def metadata_parameters(_feature: str, _num_classes: int, _config: dict) -> int
     return None
 
 
-def metadata(mixdb: MixtureDatabase, m_id: int, target_index: int, config: dict) -> Truth:
+def metadata(mixdb: MixtureDatabase, m_id: int, category: str, config: dict) -> Truth:
     """Metadata truth generation function
 
     Retrieves metadata from target.
     """
-    return mixdb.mixture_speech_metadata(m_id, config["tier"])[target_index]
+    return mixdb.mixture_speech_metadata(m_id, config["tier"])[category]
sonusai/mixture/truth_functions/metrics.py CHANGED
@@ -1,5 +1,5 @@
-from sonusai.mixture import MixtureDatabase
-from sonusai.mixture import Truth
+from ...datatypes import Truth
+from ..mixdb import MixtureDatabase
 
 
 def metrics_validate(config: dict) -> None:
@@ -16,7 +16,7 @@ def metrics_parameters(_feature: str, _num_classes: int, _config: dict) -> int |
     return None
 
 
-def metrics(mixdb: MixtureDatabase, m_id: int, target_index: int, config: dict) -> Truth:
+def metrics(mixdb: MixtureDatabase, m_id: int, category: str, config: dict) -> Truth:
     """Metadata truth generation function
 
     Retrieves metrics from target.
@@ -25,4 +25,4 @@ def metrics(mixdb: MixtureDatabase, m_id: int, target_index: int, config: dict)
         m = [config["metric"]]
     else:
         m = config["metric"]
-    return mixdb.mixture_metrics(m_id, m)[m[0]][target_index]
+    return mixdb.mixture_metrics(m_id, m)[m[0]][category]
sonusai/mixture/truth_functions/phoneme.py CHANGED
@@ -1,5 +1,5 @@
-from sonusai.mixture import MixtureDatabase
-from sonusai.mixture import Truth
+from ...datatypes import Truth
+from ..mixdb import MixtureDatabase
 
 
 def phoneme_validate(_config: dict) -> None:
@@ -10,7 +10,7 @@ def phoneme_parameters(_feature: str, _num_classes: int, _config: dict) -> int:
     raise NotImplementedError("Truth function phoneme is not supported yet")
 
 
-def phoneme(_mixdb: MixtureDatabase, _m_id: int, _target_index: int, _config: dict) -> Truth:
+def phoneme(_mixdb: MixtureDatabase, _m_id: int, _category: str, _config: dict) -> Truth:
     """Read in .txt transcript and run a Python function to generate text grid data
     (indicating which phonemes are active). Then generate truth based on this data and put
     in the correct classes based on the index in the config.
sonusai/mixture/truth_functions/sed.py CHANGED
@@ -1,5 +1,7 @@
-from sonusai.mixture import MixtureDatabase
-from sonusai.mixture import Truth
+from numpy.lib.utils import source
+
+from ...datatypes import Truth
+from ..mixdb import MixtureDatabase
 
 
 def sed_validate(config: dict) -> None:
@@ -20,7 +22,7 @@ def sed_parameters(_feature: str, num_classes: int, _config: dict) -> int:
     return num_classes
 
 
-def sed(mixdb: MixtureDatabase, m_id: int, target_index: int, config: dict) -> Truth:
+def sed(mixdb: MixtureDatabase, m_id: int, category: str, config: dict) -> Truth:
     """Sound energy detection truth generation function
 
     Calculates sound energy detection truth using simple 3 threshold
@@ -59,34 +61,30 @@ def sed(mixdb: MixtureDatabase, m_id: int, target_index: int, config: dict) -> T
     from pyaaware import feature_forward_transform_config
     from pyaaware import feature_inverse_transform_config
 
-    target_audio = torch.from_numpy(mixdb.mixture_targets(m_id)[target_index])
+    source_audio = torch.from_numpy(mixdb.mixture_sources(m_id)[category])
 
     frame_size = feature_inverse_transform_config(mixdb.feature)["overlap"]
 
     ft = ForwardTransform(**feature_forward_transform_config(mixdb.feature))
 
-    if len(target_audio) % frame_size != 0:
+    if len(source_audio) % frame_size != 0:
         raise ValueError(f"Number of samples in audio is not a multiple of {frame_size}")
 
-    frames = ft.frames(target_audio)
+    frames = ft.frames(source_audio)
     parameters = sed_parameters(mixdb.feature, mixdb.num_classes, config)
-    target_gain = mixdb.mixture(m_id).target_gain(target_index)
-    if target_gain == 0:
+    if mixdb.mixture(m_id).all_sources[category].snr_gain == 0:
         return np.zeros((frames, parameters), dtype=np.float32)
 
     # SED wants 1-based indices
     s = SED(
         thresholds=config["thresholds"],
-        index=mixdb.target_file(mixdb.mixture(m_id).targets[target_index].file_id).class_indices,
+        index=mixdb.source_file(mixdb.mixture(m_id).all_sources[category].file_id).class_indices,
         frame_size=frame_size,
         num_classes=mixdb.num_classes,
    )
 
-    # Back out target gain
-    target_audio = target_audio / target_gain
-
     # Compute energy
-    target_energy = ft.execute_all(target_audio)[1].numpy()
+    target_energy = ft.execute_all(source_audio)[1].numpy()
 
     if frames != target_energy.shape[0]:
         raise ValueError("Incorrect frames calculation in sed truth function")
sonusai/mixture/truth_functions/target.py CHANGED
@@ -1,5 +1,5 @@
-from sonusai.mixture import MixtureDatabase
-from sonusai.mixture import Truth
+from ...datatypes import Truth
+from ..mixdb import MixtureDatabase
 
 
 def target_f_validate(_config: dict) -> None:
@@ -18,7 +18,7 @@ def target_f_parameters(feature: str, _num_classes: int, _config: dict) -> int:
     return ft.bins * 2
 
 
-def target_f(mixdb: MixtureDatabase, m_id: int, target_index: int, _config: dict) -> Truth:
+def target_f(mixdb: MixtureDatabase, m_id: int, category: str, _config: dict) -> Truth:
     """Frequency domain target truth function
 
     Calculates the true transform of the target using the STFT
@@ -34,7 +34,7 @@ def target_f(mixdb: MixtureDatabase, m_id: int, target_index: int, _config: dict
 
     ft = ForwardTransform(**feature_forward_transform_config(mixdb.feature))
 
-    target_audio = torch.from_numpy(mixdb.mixture_targets(m_id)[target_index])
+    target_audio = torch.from_numpy(mixdb.mixture_sources(m_id)[category])
 
     target_freq = ft.execute_all(target_audio)[0].numpy()
     return _stack_real_imag(target_freq, ft.ttype)
@@ -56,7 +56,7 @@ def target_mixture_f_parameters(feature: str, _num_classes: int, _config: dict)
     return ft.bins * 4
 
 
-def target_mixture_f(mixdb: MixtureDatabase, m_id: int, target_index: int, _config: dict) -> Truth:
+def target_mixture_f(mixdb: MixtureDatabase, m_id: int, category: str, _config: dict) -> Truth:
     """Frequency domain target and mixture truth function
 
     Calculates the true transform of the target and the mixture
@@ -74,7 +74,7 @@ def target_mixture_f(mixdb: MixtureDatabase, m_id: int, target_index: int, _conf
 
     ft = ForwardTransform(**feature_forward_transform_config(mixdb.feature))
 
-    target_audio = torch.from_numpy(mixdb.mixture_targets(m_id)[target_index])
+    target_audio = torch.from_numpy(mixdb.mixture_sources(m_id)[category])
     mixture_audio = torch.from_numpy(mixdb.mixture_mixture(m_id))
 
     target_freq = ft.execute_all(torch.from_numpy(target_audio))[0].numpy()
@@ -98,7 +98,7 @@ def target_swin_f_parameters(feature: str, _num_classes: int, _config: dict) ->
     return ForwardTransform(**feature_forward_transform_config(feature)).bins * 2
 
 
-def target_swin_f(mixdb: MixtureDatabase, m_id: int, target_index: int, _config: dict) -> Truth:
+def target_swin_f(mixdb: MixtureDatabase, m_id: int, category: str, _config: dict) -> Truth:
     """Frequency domain target with synthesis window truth function
 
     Calculates the true transform of the target using the STFT
@@ -115,12 +115,12 @@ def target_swin_f(mixdb: MixtureDatabase, m_id: int, target_index: int, _config:
     from pyaaware import feature_forward_transform_config
     from pyaaware import feature_inverse_transform_config
 
-    from sonusai.utils import stack_complex
+    from ...utils.stacked_complex import stack_complex
 
     ft = ForwardTransform(**feature_forward_transform_config(mixdb.feature))
     it = InverseTransform(**feature_inverse_transform_config(mixdb.feature))
 
-    target_audio = mixdb.mixture_targets(m_id)[target_index]
+    target_audio = mixdb.mixture_sources(m_id)[category]
 
     truth = np.empty((len(target_audio) // ft.overlap, ft.bins * 2), dtype=np.float32)
     for idx, offset in enumerate(range(0, len(target_audio), ft.overlap)):
@@ -134,7 +134,7 @@ def target_swin_f(mixdb: MixtureDatabase, m_id: int, target_index: int, _config:
 def _stack_real_imag(data: Truth, ttype: str) -> Truth:
     import numpy as np
 
-    from sonusai.utils import stack_complex
+    from ...utils.stacked_complex import stack_complex
 
     if ttype == "tdac-co":
         return np.real(data)
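Across the truth_functions diffs above, the shared plugin signature changes from (mixdb, m_id, target_index: int, config) to (mixdb, m_id, category: str, config), and per-source audio is fetched through mixture_sources(m_id)[category]. A hedged sketch of a truth function written against the new convention (my_constant_truth is hypothetical; truth.py resolves functions by name via getattr on the truth_functions package, and real plugins also ship the <name>_validate and <name>_parameters helpers seen in the diffs):

    import numpy as np

    from sonusai.datatypes import Truth
    from sonusai.mixture.mixdb import MixtureDatabase


    def my_constant_truth(mixdb: MixtureDatabase, m_id: int, category: str, config: dict) -> Truth:
        """Hypothetical truth function following the 1.0.x per-category signature."""
        # Per-source audio is looked up by category rather than by target index.
        source_audio = mixdb.mixture_sources(m_id)[category]
        # Skip sources mixed in with zero gain, mirroring the built-in functions.
        if mixdb.mixture(m_id).all_sources[category].snr_gain == 0:
            return np.zeros((len(source_audio), 1), dtype=np.float32)
        return np.full((len(source_audio), 1), config.get("value", 1.0), dtype=np.float32)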
sonusai/mkwav.py CHANGED
@@ -6,8 +6,8 @@ options:
     -h, --help
     -v, --verbose                Be verbose.
     -i MIXID, --mixid MIXID      Mixture ID(s) to generate. [default: *].
-    -t, --target                 Write target file.
-    -s, --targets                Write targets files.
+    -t, --source                 Write source file.
+    -s, --sources                Write sources files.
     -n, --noise                  Write noise file.
 
 The mkwav command creates WAV files from a SonusAI database.
@@ -19,30 +19,17 @@ Inputs:
 Outputs the following to the mixture database directory:
     <id>
         mixture.wav: mixture
-        target.wav: target (optional)
-        targets<n>.wav: targets <n> (optional)
+        source.wav: source (optional)
+        source_<c>.wav: source <category> (optional)
         noise.wav: noise (optional)
         metadata.txt
     mkwav.log
 
 """
 
-import signal
-
-
-def signal_handler(_sig, _frame):
-    import sys
-
-    from sonusai import logger
-
-    logger.info("Canceled due to keyboard interrupt")
-    sys.exit(1)
-
-
-signal.signal(signal.SIGINT, signal_handler)
-
 
 def _process_mixture(m_id: int, location: str, write_target: bool, write_targets: bool, write_noise: bool) -> None:
+    from os import makedirs
     from os.path import join
 
     from sonusai.mixture import MixtureDatabase
@@ -52,14 +39,16 @@ def _process_mixture(m_id: int, location: str, write_target: bool, write_targets
 
     mixdb = MixtureDatabase(location)
 
-    location = join(mixdb.location, "mixture", mixdb.mixture(m_id).name)
+    index = mixdb.mixture(m_id).name
+    location = join(mixdb.location, "mixture", index)
+    makedirs(location, exist_ok=True)
 
     write_audio(name=join(location, "mixture.wav"), audio=float_to_int16(mixdb.mixture_mixture(m_id)))
     if write_target:
-        write_audio(name=join(location, "target.wav"), audio=float_to_int16(mixdb.mixture_target(m_id)))
+        write_audio(name=join(location, "source.wav"), audio=float_to_int16(mixdb.mixture_source(m_id)))
     if write_targets:
-        for idx, target in enumerate(mixdb.mixture_targets(m_id)):
-            write_audio(name=join(location, f"targets{idx}.wav"), audio=float_to_int16(target))
+        for category, source in mixdb.mixture_sources(m_id).items():
+            write_audio(name=join(location, f"sources_{category}.wav"), audio=float_to_int16(source))
     if write_noise:
         write_audio(name=join(location, "noise.wav"), audio=float_to_int16(mixdb.mixture_noise(m_id)))
 
@@ -69,15 +58,15 @@ def _process_mixture(m_id: int, location: str, write_target: bool, write_targets
 def main() -> None:
     from docopt import docopt
 
-    import sonusai
+    from sonusai import __version__ as sai_version
     from sonusai.utils import trim_docstring
 
-    args = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)
+    args = docopt(trim_docstring(__doc__), version=sai_version, options_first=True)
 
     verbose = args["--verbose"]
     mixid = args["--mixid"]
-    write_target = args["--target"]
-    write_targets = args["--targets"]
+    write_source = args["--source"]
+    write_sources = args["--sources"]
     write_noise = args["--noise"]
     location = args["LOC"]
 
@@ -118,12 +107,13 @@ def main() -> None:
         partial(
             _process_mixture,
             location=location,
-            write_target=write_target,
-            write_targets=write_targets,
+            write_target=write_source,
+            write_targets=write_sources,
            write_noise=write_noise,
         ),
         mixid,
         progress=progress,
+        # no_par=True,
     )
     progress.close()
 
@@ -135,4 +125,11 @@ def main() -> None:
 
 
 if __name__ == "__main__":
-    main()
+    from sonusai import exception_handler
+    from sonusai.utils import register_keyboard_interrupt
+
+    register_keyboard_interrupt()
+    try:
+        main()
+    except Exception as e:
+        exception_handler(e)
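The same rename reaches the library API used above: mixture_targets(m_id) (a list) becomes mixture_sources(m_id) (a dict keyed by category), with mixture_source(m_id) returning the summed source. A hedged sketch of writing per-category WAVs outside of mkwav, assuming a database generated at ./mixdb and that write_audio and float_to_int16 are exported from sonusai.utils as mkwav uses them:

    from os.path import join

    from sonusai.mixture import MixtureDatabase
    from sonusai.utils import float_to_int16, write_audio

    mixdb = MixtureDatabase("./mixdb")  # assumed database location
    m_id = 0
    for category, source in mixdb.mixture_sources(m_id).items():
        # One WAV per source category, mirroring mkwav's sources_<category>.wav naming
        write_audio(name=join(".", f"{category}.wav"), audio=float_to_int16(source))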