sonusai 0.19.6__py3-none-any.whl → 0.19.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (59) hide show
  1. sonusai/__init__.py +1 -1
  2. sonusai/aawscd_probwrite.py +1 -1
  3. sonusai/calc_metric_spenh.py +1 -1
  4. sonusai/genft.py +29 -14
  5. sonusai/genmetrics.py +60 -42
  6. sonusai/genmix.py +41 -29
  7. sonusai/genmixdb.py +54 -62
  8. sonusai/metrics/calc_class_weights.py +1 -3
  9. sonusai/metrics/calc_optimal_thresholds.py +2 -2
  10. sonusai/metrics/calc_phase_distance.py +1 -1
  11. sonusai/metrics/calc_speech.py +6 -6
  12. sonusai/metrics/class_summary.py +6 -15
  13. sonusai/metrics/confusion_matrix_summary.py +11 -27
  14. sonusai/metrics/one_hot.py +3 -3
  15. sonusai/metrics/snr_summary.py +7 -7
  16. sonusai/mixture/__init__.py +2 -17
  17. sonusai/mixture/augmentation.py +5 -6
  18. sonusai/mixture/class_count.py +1 -1
  19. sonusai/mixture/config.py +36 -46
  20. sonusai/mixture/data_io.py +30 -1
  21. sonusai/mixture/datatypes.py +29 -40
  22. sonusai/mixture/db_datatypes.py +1 -1
  23. sonusai/mixture/feature.py +3 -23
  24. sonusai/mixture/generation.py +202 -235
  25. sonusai/mixture/helpers.py +29 -187
  26. sonusai/mixture/mixdb.py +386 -159
  27. sonusai/mixture/soundfile_audio.py +1 -1
  28. sonusai/mixture/sox_audio.py +4 -4
  29. sonusai/mixture/sox_augmentation.py +1 -1
  30. sonusai/mixture/target_class_balancing.py +9 -11
  31. sonusai/mixture/targets.py +23 -20
  32. sonusai/mixture/truth.py +21 -34
  33. sonusai/mixture/truth_functions/__init__.py +6 -0
  34. sonusai/mixture/truth_functions/crm.py +51 -37
  35. sonusai/mixture/truth_functions/energy.py +95 -50
  36. sonusai/mixture/truth_functions/file.py +12 -8
  37. sonusai/mixture/truth_functions/metadata.py +24 -0
  38. sonusai/mixture/truth_functions/metrics.py +28 -0
  39. sonusai/mixture/truth_functions/phoneme.py +4 -5
  40. sonusai/mixture/truth_functions/sed.py +32 -23
  41. sonusai/mixture/truth_functions/target.py +62 -29
  42. sonusai/mkwav.py +20 -19
  43. sonusai/queries/queries.py +9 -15
  44. sonusai/speech/l2arctic.py +6 -2
  45. sonusai/summarize_metric_spenh.py +1 -1
  46. sonusai/utils/__init__.py +1 -0
  47. sonusai/utils/asr_functions/aaware_whisper.py +1 -1
  48. sonusai/utils/audio_devices.py +27 -18
  49. sonusai/utils/docstring.py +6 -3
  50. sonusai/utils/energy_f.py +5 -3
  51. sonusai/utils/human_readable_size.py +6 -6
  52. sonusai/utils/load_object.py +15 -0
  53. sonusai/utils/onnx_utils.py +2 -2
  54. sonusai/utils/print_mixture_details.py +3 -3
  55. {sonusai-0.19.6.dist-info → sonusai-0.19.8.dist-info}/METADATA +2 -2
  56. {sonusai-0.19.6.dist-info → sonusai-0.19.8.dist-info}/RECORD +58 -56
  57. sonusai/mixture/truth_functions/datatypes.py +0 -37
  58. {sonusai-0.19.6.dist-info → sonusai-0.19.8.dist-info}/WHEEL +0 -0
  59. {sonusai-0.19.6.dist-info → sonusai-0.19.8.dist-info}/entry_points.txt +0 -0
sonusai/__init__.py CHANGED
@@ -7,7 +7,7 @@ from rich.traceback import install
7
7
 
8
8
  install(show_locals=True)
9
9
 
10
- __version__ = metadata.version(__package__)
10
+ __version__ = metadata.version(__package__) # pyright: ignore [reportArgumentType]
11
11
  BASEDIR = dirname(__file__)
12
12
 
13
13
  commands_doc = """
@@ -89,7 +89,7 @@ def on_message(_client, _userdata, message):
89
89
  FRAME_COUNT += 1
90
90
 
91
91
  global PROGRESS
92
- PROGRESS.update()
92
+ PROGRESS.update() # pyright: ignore [reportOptionalMemberAccess]
93
93
 
94
94
  if FRAME_COUNT == FRAMES:
95
95
  global DONE
@@ -582,7 +582,7 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
582
582
  asr_mx = None
583
583
  asr_tge = None
584
584
  asr_engines = list(mixdb.asr_configs.keys())
585
- if len(asr_engines) > 0 and mixdb.mixture(mixid).snr >= -96: # noise only, ignore/reset target asr
585
+ if len(asr_engines) > 0 and not mixdb.mixture(mixid).is_noise_only: # noise only, ignore/reset target asr
586
586
  wer_mx = float(mixdb.mixture_metrics(mixid, [f"mxwer.{asr_engines[0]}"])[0]) * 100
587
587
  asr_tt = MP_GLOBAL.mixdb.mixture_speech_metadata(mixid, "text")[0] # ignore mixup
588
588
  if asr_tt is None:
sonusai/genft.py CHANGED
@@ -1,12 +1,13 @@
1
1
  """sonusai genft
2
2
 
3
- usage: genft [-hvs] [-i MIXID] LOC
3
+ usage: genft [-hvsn] [-i MIXID] LOC
4
4
 
5
5
  options:
6
6
  -h, --help
7
- -v, --verbose Be verbose.
8
- -i MIXID, --mixid MIXID Mixture ID(s) to generate. [default: *].
9
- -s, --segsnr Save segsnr. [default: False].
7
+ -v, --verbose Be verbose.
8
+ -i MIXID, --mixid MIXID Mixture ID(s) to generate. [default: *].
9
+ -s, --segsnr Save segsnr. [default: False].
10
+ -n, --nopar Do not run in parallel. [default: False].
10
11
 
11
12
  Generate SonusAI feature/truth data from a SonusAI mixture database.
12
13
 
@@ -50,6 +51,7 @@ def genft(
50
51
  write: bool = False,
51
52
  show_progress: bool = False,
52
53
  force: bool = True,
54
+ no_par: bool = False,
53
55
  ) -> list[GenFTData]:
54
56
  from functools import partial
55
57
 
@@ -62,18 +64,26 @@ def genft(
62
64
 
63
65
  progress = track(total=len(mixids), disable=not show_progress)
64
66
  results = par_track(
65
- partial(_genft_kernel, location=location, compute_truth=compute_truth, compute_segsnr=compute_segsnr, force=force, write=write),
67
+ partial(
68
+ _genft_kernel,
69
+ location=location,
70
+ compute_truth=compute_truth,
71
+ compute_segsnr=compute_segsnr,
72
+ force=force,
73
+ write=write,
74
+ ),
66
75
  mixids,
67
76
  progress=progress,
77
+ no_par=no_par,
68
78
  )
69
79
  progress.close()
70
80
 
71
81
  return results
72
82
 
73
83
 
74
- def _genft_kernel(m_id: int, location: str, compute_truth: bool, compute_segsnr: bool, force: bool, write: bool) -> GenFTData:
75
- from typing import Any
76
-
84
+ def _genft_kernel(
85
+ m_id: int, location: str, compute_truth: bool, compute_segsnr: bool, force: bool, write: bool
86
+ ) -> GenFTData:
77
87
  from sonusai.mixture import MixtureDatabase
78
88
  from sonusai.mixture import write_cached_data
79
89
  from sonusai.mixture import write_mixture_metadata
@@ -83,21 +93,23 @@ def _genft_kernel(m_id: int, location: str, compute_truth: bool, compute_segsnr:
83
93
  result = GenFTData()
84
94
 
85
95
  feature, truth_f = mixdb.mixture_ft(m_id=m_id, force=force)
86
- write_data: list[tuple[str, Any]] = [("feature", feature)]
87
96
  result.feature = feature
97
+ if write:
98
+ write_cached_data(mixdb.location, "mixture", mixdb.mixture(m_id).name, [("feature", feature)])
88
99
 
89
100
  if compute_truth:
90
- write_data.append(("truth_f", truth_f))
91
101
  result.truth_f = truth_f
102
+ if write:
103
+ write_cached_data(mixdb.location, "mixture", mixdb.mixture(m_id).name, [("truth_f", truth_f)])
92
104
 
93
105
  if compute_segsnr:
94
106
  segsnr = mixdb.mixture_segsnr(m_id=m_id, force=force)
95
- write_data.append(("segsnr", segsnr))
96
107
  result.segsnr = segsnr
108
+ if write:
109
+ write_cached_data(mixdb.location, "mixture", mixdb.mixture(m_id).name, [("segsnr", segsnr)])
97
110
 
98
111
  if write:
99
- write_cached_data(mixdb.location, "mixture", mixdb.mixture(m_id).name, write_data)
100
- write_mixture_metadata(mixdb, mixdb.mixture(m_id))
112
+ write_mixture_metadata(mixdb, m_id)
101
113
 
102
114
  return result
103
115
 
@@ -117,6 +129,7 @@ def main() -> None:
117
129
  from sonusai import initial_log_messages
118
130
  from sonusai import logger
119
131
  from sonusai import update_console_handler
132
+ from sonusai.mixture import SAMPLE_RATE
120
133
  from sonusai.mixture import MixtureDatabase
121
134
  from sonusai.mixture import check_audio_files_exist
122
135
  from sonusai.utils import human_readable_size
@@ -125,6 +138,7 @@ def main() -> None:
125
138
  verbose = args["--verbose"]
126
139
  mixids = args["--mixid"]
127
140
  compute_segsnr = args["--segsnr"]
141
+ no_par = args["--nopar"]
128
142
  location = args["LOC"]
129
143
 
130
144
  start_time = time.monotonic()
@@ -138,7 +152,7 @@ def main() -> None:
138
152
  mixids = mixdb.mixids_to_list(mixids)
139
153
 
140
154
  total_samples = mixdb.total_samples(mixids)
141
- duration = total_samples / sonusai.mixture.SAMPLE_RATE
155
+ duration = total_samples / SAMPLE_RATE
142
156
  total_transform_frames = total_samples // mixdb.ft_config.overlap
143
157
  total_feature_frames = total_samples // mixdb.feature_step_samples
144
158
 
@@ -159,6 +173,7 @@ def main() -> None:
159
173
  compute_segsnr=compute_segsnr,
160
174
  write=True,
161
175
  show_progress=True,
176
+ no_par=no_par,
162
177
  )
163
178
  except Exception as e:
164
179
  logger.debug(e)
sonusai/genmetrics.py CHANGED
@@ -1,6 +1,6 @@
1
1
  """sonusai genmetrics
2
2
 
3
- usage: genmetrics [-hvs] [-i MIXID] [-n INCLUDE] [-x EXCLUDE] LOC
3
+ usage: genmetrics [-hvsd] [-i MIXID] [-n INCLUDE] [-x EXCLUDE] LOC
4
4
 
5
5
  options:
6
6
  -h, --help
@@ -9,27 +9,39 @@ options:
9
9
  -n INCLUDE, --include INCLUDE Metrics to include. [default: all]
10
10
  -x EXCLUDE, --exclude EXCLUDE Metrics to exclude. [default: none]
11
11
  -s, --supported Show list of supported metrics.
12
+ -d, --dryrun Show list of metrics that will be generated and exit.
12
13
 
13
14
  Calculate speech enhancement metrics of SonusAI mixture data in LOC.
14
15
 
15
16
  Inputs:
16
17
  LOC A SonusAI mixture database directory.
17
18
  MIXID A glob of mixture ID(s) to generate.
18
- INCLUDE Comma separated list of metrics to include. Can be 'all' or
19
- any of the supported metrics.
20
- EXCLUDE Comma separated list of metrics to exclude. Can be 'none' or
21
- any of the supported metrics.
19
+ INCLUDE Comma separated list of metrics to include. Can be "all" or
20
+ any of the supported metrics or glob(s).
21
+ EXCLUDE Comma separated list of metrics to exclude. Can be "none" or
22
+ any of the supported metrics or glob(s).
23
+
24
+ Note: The default include of "all" excludes the generation of ASR metrics,
25
+ i.e., "*asr*,*wer*". However, if include is manually specified to something other than "all",
26
+ then this behavior is overridden.
27
+
28
+ Similarly, the default exclude of "none" excludes the generation of ASR metrics,
29
+ i.e., "*asr*,*wer*". However, if exclude is manually specified to something other than "none",
30
+ then this behavior is also overridden.
22
31
 
23
32
  Examples:
24
33
 
25
34
  Generate all available mxwer metrics (as determined by mixdb asr_configs parameter):
26
- > sonusai genmetrics -n"mxwer" mixdb_loc
35
+ > sonusai genmetrics -n"mxwer*" mixdb_loc
27
36
 
28
37
  Generate only mxwer.faster metrics:
29
38
  > sonusai genmetrics -n"mxwer.faster" mixdb_loc
30
39
 
31
- Generate all available metrics except for mxwer.faster:
32
- > sonusai genmetrics -x"mxwer.faster" mixdb_loc
40
+ Generate only faster metrics:
41
+ > sonusai genmetrics -n"*faster" mixdb_loc
42
+
43
+ Generate all available metrics except for mxcovl:
44
+ > sonusai genmetrics -x"mxcovl" mixdb_loc
33
45
 
34
46
  """
35
47
 
@@ -71,11 +83,13 @@ def main() -> None:
71
83
 
72
84
  verbose = args["--verbose"]
73
85
  mixids = args["--mixid"]
74
- includes = [x.strip() for x in args["--include"].lower().split(",")]
75
- excludes = [x.strip() for x in args["--exclude"].lower().split(",")]
86
+ includes = {x.strip() for x in args["--include"].replace(" ", ",").lower().split(",") if x != ""}
87
+ excludes = {x.strip() for x in args["--exclude"].replace(" ", ",").lower().split(",") if x != ""}
76
88
  show_supported = args["--supported"]
89
+ dryrun = args["--dryrun"]
77
90
  location = args["LOC"]
78
91
 
92
+ import fnmatch
79
93
  import sys
80
94
  import time
81
95
  from functools import partial
@@ -89,9 +103,6 @@ def main() -> None:
89
103
  from sonusai.utils import seconds_to_hms
90
104
  from sonusai.utils import track
91
105
 
92
- # TODO: Check config.yml for changes to asr_configs and update mixdb
93
- # TODO: Support globs for metrics (includes and excludes)
94
-
95
106
  start_time = time.monotonic()
96
107
 
97
108
  # Setup logging file
@@ -107,38 +118,45 @@ def main() -> None:
107
118
  logger.info(f"\nSupported metrics:\n\n{supported.pretty}")
108
119
  sys.exit(0)
109
120
 
110
- if includes is None or "all" in includes:
111
- metrics = supported.names
112
- else:
113
- metrics = set(includes)
114
- if "mxwer" in metrics:
115
- metrics.remove("mxwer")
116
- for name in mixdb.asr_configs:
117
- metrics.add(f"mxwer.{name}")
118
-
119
- diff = metrics.difference(supported.names)
120
- if diff:
121
- logger.error(f"Unrecognized metric: {', '.join(diff)}")
122
- sys.exit(1)
121
+ # Handle default excludes
122
+ if "none" in excludes:
123
+ if "all" in includes:
124
+ excludes = {"*asr*", "*wer*"}
125
+ else:
126
+ excludes = set()
123
127
 
124
- if excludes is None or "none" in excludes:
125
- _excludes = set()
126
- else:
127
- _excludes = set(excludes)
128
- if "mxwer" in _excludes:
129
- _excludes.remove("mxwer")
130
- for name in mixdb.asr_configs:
131
- _excludes.add(f"mxwer.{name}")
132
-
133
- diff = _excludes.difference(supported.names)
134
- if diff:
135
- logger.error(f"Unrecognized metric: {', '.join(diff)}")
136
- sys.exit(1)
128
+ # Handle default includes
129
+ if "all" in includes:
130
+ includes = {"*"}
131
+
132
+ included_metrics: set[str] = set()
133
+ for include in includes:
134
+ for m in fnmatch.filter(supported.names, include):
135
+ included_metrics.add(m)
137
136
 
138
- for exclude in _excludes:
139
- metrics.discard(exclude)
137
+ excluded_metrics: set[str] = set()
138
+ for exclude in excludes:
139
+ for m in fnmatch.filter(supported.names, exclude):
140
+ excluded_metrics.add(m)
141
+
142
+ requested = included_metrics - excluded_metrics
143
+
144
+ # Check for metrics dependencies and cache dependencies even if not explicitly requested.
145
+ dependencies: set[str] = set()
146
+ for metric in requested:
147
+ if metric.startswith("mxwer"):
148
+ dependencies.add("mxasr." + metric[6:])
149
+ dependencies.add("tasr." + metric[6:])
150
+
151
+ metrics = sorted(requested | dependencies)
152
+
153
+ if len(metrics) == 0:
154
+ logger.warning("No metrics were requested")
155
+ sys.exit(1)
140
156
 
141
157
  logger.info(f"Generating metrics: {', '.join(metrics)}")
158
+ if dryrun:
159
+ sys.exit(0)
142
160
 
143
161
  mixids = mixdb.mixids_to_list(mixids)
144
162
  logger.info("")
@@ -146,7 +164,7 @@ def main() -> None:
146
164
 
147
165
  progress = track(total=len(mixids), desc="genmetrics")
148
166
  par_track(
149
- partial(_process_mixture, location=location, metrics=list(metrics)),
167
+ partial(_process_mixture, location=location, metrics=metrics),
150
168
  mixids,
151
169
  progress=progress,
152
170
  )
sonusai/genmix.py CHANGED
@@ -1,14 +1,15 @@
1
1
  """sonusai genmix
2
2
 
3
- usage: genmix [-hvgts] [-i MIXID] LOC
3
+ usage: genmix [-hvgtsn] [-i MIXID] LOC
4
4
 
5
5
  options:
6
6
  -h, --help
7
- -v, --verbose Be verbose.
8
- -i MIXID, --mixid MIXID Mixture ID(s) to generate. [default: *].
9
- -g, --target Save target. [default: False].
10
- -t, --truth Save truth_t. [default: False].
11
- -s, --segsnr Save segsnr_t. [default: False].
7
+ -v, --verbose Be verbose.
8
+ -i MIXID, --mixid MIXID Mixture ID(s) to generate. [default: *].
9
+ -g, --target Save target. [default: False].
10
+ -t, --truth Save truth_t. [default: False].
11
+ -s, --segsnr Save segsnr_t. [default: False].
12
+ -n, --nopar Do not run in parallel. [default: False].
12
13
 
13
14
  Generate SonusAI mixture data from a SonusAI mixture database.
14
15
 
@@ -55,6 +56,7 @@ def genmix(
55
56
  write: bool = False,
56
57
  show_progress: bool = False,
57
58
  force: bool = True,
59
+ no_par: bool = False,
58
60
  ) -> list[GenMixData]:
59
61
  from functools import partial
60
62
 
@@ -77,6 +79,7 @@ def genmix(
77
79
  ),
78
80
  mixids,
79
81
  progress=progress,
82
+ no_par=no_par,
80
83
  )
81
84
  progress.close()
82
85
 
@@ -98,41 +101,47 @@ def _genmix_kernel(
98
101
 
99
102
  mixdb = MixtureDatabase(location)
100
103
 
104
+ result = GenMixData()
105
+
101
106
  targets = mixdb.mixture_targets(m_id=m_id, force=force)
107
+ result.targets = targets
102
108
  noise = mixdb.mixture_noise(m_id=m_id, force=force)
103
- write_data = [("targets", targets), ("noise", noise)]
109
+ result.noise = noise
110
+ if write:
111
+ write_cached_data(
112
+ mixdb.location,
113
+ "mixture",
114
+ mixdb.mixture(m_id).name,
115
+ [
116
+ ("targets", targets),
117
+ ("noise", noise),
118
+ ],
119
+ )
104
120
 
105
121
  if compute_truth:
106
122
  truth_t = mixdb.mixture_truth_t(m_id=m_id, targets=targets, noise=noise, force=force)
107
- write_data.append(("truth_t", truth_t))
108
- else:
109
- truth_t = None
123
+ result.truth_t = truth_t
124
+ if write:
125
+ write_cached_data(mixdb.location, "mixture", mixdb.mixture(m_id).name, [("truth_t", truth_t)])
110
126
 
111
127
  target = mixdb.mixture_target(m_id=m_id, targets=targets)
112
- if save_target:
113
- write_data.append(("target", target))
128
+ result.target = target
129
+ if save_target and write:
130
+ write_cached_data(mixdb.location, "mixture", mixdb.mixture(m_id).name, [("target", target)])
114
131
 
115
132
  if compute_segsnr:
116
133
  segsnr_t = mixdb.mixture_segsnr_t(m_id=m_id, targets=targets, target=target, noise=noise, force=force)
117
- write_data.append(("segsnr_t", segsnr_t))
118
- else:
119
- segsnr_t = None
134
+ result.segsnr_t = segsnr_t
135
+ if write:
136
+ write_cached_data(mixdb.location, "mixture", mixdb.mixture(m_id).name, [("segsnr_t", segsnr_t)])
120
137
 
121
138
  mixture = mixdb.mixture_mixture(m_id=m_id, targets=targets, target=target, noise=noise, force=force)
122
- write_data.append(("mixture", mixture))
123
-
139
+ result.mixture = mixture
124
140
  if write:
125
- write_cached_data(mixdb.location, "mixture", mixdb.mixture(m_id).name, write_data)
126
- write_mixture_metadata(mixdb, mixdb.mixture(m_id))
127
-
128
- return GenMixData(
129
- targets=targets,
130
- target=target,
131
- noise=noise,
132
- mixture=mixture,
133
- truth_t=truth_t,
134
- segsnr_t=segsnr_t,
135
- )
141
+ write_cached_data(mixdb.location, "mixture", mixdb.mixture(m_id).name, [("mixture", mixture)])
142
+ write_mixture_metadata(mixdb, m_id)
143
+
144
+ return result
136
145
 
137
146
 
138
147
  def main() -> None:
@@ -150,6 +159,7 @@ def main() -> None:
150
159
  from sonusai import initial_log_messages
151
160
  from sonusai import logger
152
161
  from sonusai import update_console_handler
162
+ from sonusai.mixture import SAMPLE_RATE
153
163
  from sonusai.mixture import MixtureDatabase
154
164
  from sonusai.mixture import check_audio_files_exist
155
165
  from sonusai.utils import human_readable_size
@@ -161,6 +171,7 @@ def main() -> None:
161
171
  save_target = args["--target"]
162
172
  compute_truth = args["--truth"]
163
173
  compute_segsnr = args["--segsnr"]
174
+ no_par = args["--nopar"]
164
175
 
165
176
  start_time = time.monotonic()
166
177
 
@@ -173,7 +184,7 @@ def main() -> None:
173
184
  mixids = mixdb.mixids_to_list(mixids)
174
185
 
175
186
  total_samples = mixdb.total_samples(mixids)
176
- duration = total_samples / sonusai.mixture.SAMPLE_RATE
187
+ duration = total_samples / SAMPLE_RATE
177
188
 
178
189
  logger.info("")
179
190
  logger.info(f"Found {len(mixids):,} mixtures to process")
@@ -190,6 +201,7 @@ def main() -> None:
190
201
  compute_segsnr=compute_segsnr,
191
202
  write=True,
192
203
  show_progress=True,
204
+ no_par=no_par,
193
205
  )
194
206
  except Exception as e:
195
207
  logger.debug(e)
sonusai/genmixdb.py CHANGED
@@ -1,15 +1,16 @@
1
1
  """sonusai genmixdb
2
2
 
3
- usage: genmixdb [-hvmfsdj] LOC
3
+ usage: genmixdb [-hvmfsdjn] LOC
4
4
 
5
5
  options:
6
- -h, --help
7
- -v, --verbose Be verbose.
8
- -m, --mix Save mixture data. [default: False].
9
- -f, --ft Save feature/truth_f data. [default: False].
10
- -s, --segsnr Save segsnr data. [default: False].
11
- -d, --dryrun Perform a dry run showing the processed config. [default: False].
12
- -j, --json Save JSON version of database. [default: False].
6
+ -h, --help
7
+ -v, --verbose Be verbose.
8
+ -m, --mix Save mixture data. [default: False].
9
+ -f, --ft Save feature/truth_f data. [default: False].
10
+ -s, --segsnr Save segsnr data. [default: False].
11
+ -d, --dryrun Perform a dry run showing the processed config. [default: False].
12
+ -j, --json Save JSON version of database. [default: False].
13
+ -n, --nopar Do not run in parallel. [default: False].
13
14
 
14
15
  Create mixture database data for training and evaluation. Optionally, also create mixture audio and feature/truth data.
15
16
 
@@ -115,8 +116,6 @@ will find all .wav files in the specified directories and process them as target
115
116
 
116
117
  import signal
117
118
 
118
- from sonusai.mixture import Mixture
119
-
120
119
 
121
120
  def signal_handler(_sig, _frame):
122
121
  import sys
@@ -139,6 +138,7 @@ def genmixdb(
139
138
  show_progress: bool = False,
140
139
  test: bool = False,
141
140
  save_json: bool = False,
141
+ no_par: bool = False,
142
142
  ) -> None:
143
143
  from functools import partial
144
144
  from random import seed
@@ -151,7 +151,6 @@ def genmixdb(
151
151
  from sonusai.mixture import AugmentationRule
152
152
  from sonusai.mixture import MixtureDatabase
153
153
  from sonusai.mixture import balance_targets
154
- from sonusai.mixture import generate_mixtures
155
154
  from sonusai.mixture import get_all_snrs_from_config
156
155
  from sonusai.mixture import get_augmentation_rules
157
156
  from sonusai.mixture import get_augmented_targets
@@ -317,7 +316,8 @@ def genmixdb(
317
316
  f"{seconds_to_hms(seconds=noise_audio_duration)}"
318
317
  )
319
318
 
320
- used_noise_files, used_noise_samples, mixtures = generate_mixtures(
319
+ used_noise_files, used_noise_samples = populate_mixture_table(
320
+ location=location,
321
321
  noise_mix_mode=mixdb.noise_mix_mode,
322
322
  augmented_targets=augmented_targets,
323
323
  target_files=target_files,
@@ -330,16 +330,17 @@ def genmixdb(
330
330
  num_classes=mixdb.num_classes,
331
331
  feature_step_samples=mixdb.feature_step_samples,
332
332
  num_ir=mixdb.num_impulse_response_files,
333
+ test=test,
333
334
  )
334
335
 
335
- num_mixtures = len(mixtures)
336
+ num_mixtures = len(mixdb.mixtures)
336
337
  update_mixid_width(location, num_mixtures, test)
337
338
 
338
339
  if logging:
339
340
  logger.info("")
340
341
  logger.info(f"Found {num_mixtures:,} mixtures to process")
341
342
 
342
- total_duration = float(sum([mixture.samples for mixture in mixtures])) / SAMPLE_RATE
343
+ total_duration = float(sum([mixture.samples for mixture in mixdb.mixtures])) / SAMPLE_RATE
343
344
 
344
345
  if logging:
345
346
  log_duration_and_sizes(
@@ -363,7 +364,7 @@ def genmixdb(
363
364
  if logging:
364
365
  logger.info("Generating mixtures")
365
366
  progress = track(total=num_mixtures, disable=not show_progress)
366
- mixtures = par_track(
367
+ par_track(
367
368
  partial(
368
369
  _process_mixture,
369
370
  location=location,
@@ -372,13 +373,12 @@ def genmixdb(
372
373
  save_segsnr=save_segsnr,
373
374
  test=test,
374
375
  ),
375
- mixtures,
376
+ range(num_mixtures),
376
377
  progress=progress,
378
+ no_par=no_par,
377
379
  )
378
380
  progress.close()
379
381
 
380
- populate_mixture_table(location, mixtures, test)
381
-
382
382
  total_noise_files = len(noise_files)
383
383
 
384
384
  total_samples = mixdb.total_samples()
@@ -409,70 +409,60 @@ def genmixdb(
409
409
 
410
410
 
411
411
  def _process_mixture(
412
- mixture: Mixture,
412
+ m_id: int,
413
413
  location: str,
414
414
  save_mix: bool,
415
415
  save_ft: bool,
416
416
  save_segsnr: bool,
417
417
  test: bool,
418
- ) -> Mixture:
419
- from typing import Any
418
+ ) -> None:
419
+ from functools import partial
420
420
 
421
421
  from sonusai.mixture import MixtureDatabase
422
- from sonusai.mixture import get_ft
423
- from sonusai.mixture import get_segsnr
424
- from sonusai.mixture import get_truth
425
- from sonusai.mixture import update_mixture
422
+ from sonusai.mixture import clear_cached_data
423
+ from sonusai.mixture import update_mixture_table
426
424
  from sonusai.mixture import write_cached_data
427
425
  from sonusai.mixture import write_mixture_metadata
428
426
 
429
- with_data = save_mix or save_ft
427
+ with_data = save_mix or save_ft or save_segsnr
428
+
429
+ genmix_data = update_mixture_table(location, m_id, with_data, test)
430
+
430
431
  mixdb = MixtureDatabase(location, test)
432
+ mixture = mixdb.mixture(m_id)
431
433
 
432
- mixture, genmix_data = update_mixture(mixdb, mixture, with_data)
434
+ write = partial(write_cached_data, location=location, name="mixture", index=mixture.name)
435
+ clear = partial(clear_cached_data, location=location, name="mixture", index=mixture.name)
433
436
 
434
437
  if with_data:
435
- write_data: list[tuple[str, Any]] = []
436
-
437
- if save_mix:
438
- write_data.append(("targets", genmix_data.targets))
439
- write_data.append(("noise", genmix_data.noise))
440
- write_data.append(("mixture", genmix_data.mixture))
438
+ write(
439
+ items=[
440
+ ("targets", genmix_data.targets),
441
+ ("target", genmix_data.target),
442
+ ("noise", genmix_data.noise),
443
+ ("mixture", genmix_data.mixture),
444
+ ]
445
+ )
441
446
 
442
447
  if save_ft:
443
- if genmix_data.targets is None or genmix_data.noise is None or genmix_data.mixture is None:
444
- raise RuntimeError("Mixture data was not generated properly")
445
- truth_t = get_truth(
446
- mixdb=mixdb,
447
- mixture=mixture,
448
- targets_audio=genmix_data.targets,
449
- noise_audio=genmix_data.noise,
450
- mixture_audio=genmix_data.mixture,
451
- )
452
- feature, truth_f = get_ft(
453
- mixdb=mixdb,
454
- mixture=mixture,
455
- mixture_audio=genmix_data.mixture,
456
- truth_t=truth_t,
448
+ clear(items=["feature", "truth_f"])
449
+ feature, truth_f = mixdb.mixture_ft(m_id)
450
+ write(
451
+ items=[
452
+ ("feature", feature),
453
+ ("truth_f", truth_f),
454
+ ]
457
455
  )
458
- write_data.append(("feature", feature))
459
- write_data.append(("truth_f", truth_f))
460
456
 
461
- if save_segsnr:
462
- if genmix_data.target is None:
463
- raise RuntimeError("Target data was not generated properly")
464
- segsnr = get_segsnr(
465
- mixdb=mixdb,
466
- mixture=mixture,
467
- target_audio=genmix_data.target,
468
- noise=genmix_data.noise,
469
- )
470
- write_data.append(("segsnr", segsnr))
457
+ if save_segsnr:
458
+ clear(items=["segsnr"])
459
+ segsnr = mixdb.mixture_segsnr(m_id)
460
+ write(items=[("segsnr", segsnr)])
471
461
 
472
- write_cached_data(mixdb.location, "mixture", mixture.name, write_data)
473
- write_mixture_metadata(mixdb, mixture)
462
+ if not save_mix:
463
+ clear(items=["targets", "target", "noise", "mixture"])
474
464
 
475
- return mixture
465
+ write_mixture_metadata(mixdb, m_id)
476
466
 
477
467
 
478
468
  def main() -> None:
@@ -505,6 +495,7 @@ def main() -> None:
505
495
  save_segsnr = args["--segsnr"]
506
496
  dryrun = args["--dryrun"]
507
497
  save_json = args["--json"]
498
+ no_par = args["--nopar"]
508
499
  location = args["LOC"]
509
500
 
510
501
  start_time = time.monotonic()
@@ -535,6 +526,7 @@ def main() -> None:
535
526
  save_segsnr=save_segsnr,
536
527
  show_progress=True,
537
528
  save_json=save_json,
529
+ no_par=no_par,
538
530
  )
539
531
  except Exception as e:
540
532
  logger.debug(e)