sonusai 0.19.6__py3-none-any.whl → 0.19.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. sonusai/__init__.py +1 -1
  2. sonusai/aawscd_probwrite.py +1 -1
  3. sonusai/calc_metric_spenh.py +1 -1
  4. sonusai/genft.py +29 -14
  5. sonusai/genmetrics.py +60 -42
  6. sonusai/genmix.py +41 -29
  7. sonusai/genmixdb.py +56 -64
  8. sonusai/metrics/calc_class_weights.py +1 -3
  9. sonusai/metrics/calc_optimal_thresholds.py +2 -2
  10. sonusai/metrics/calc_phase_distance.py +1 -1
  11. sonusai/metrics/calc_speech.py +6 -6
  12. sonusai/metrics/class_summary.py +6 -15
  13. sonusai/metrics/confusion_matrix_summary.py +11 -27
  14. sonusai/metrics/one_hot.py +3 -3
  15. sonusai/metrics/snr_summary.py +7 -7
  16. sonusai/mixture/__init__.py +2 -17
  17. sonusai/mixture/augmentation.py +5 -6
  18. sonusai/mixture/class_count.py +1 -1
  19. sonusai/mixture/config.py +36 -46
  20. sonusai/mixture/data_io.py +30 -1
  21. sonusai/mixture/datatypes.py +29 -40
  22. sonusai/mixture/db_datatypes.py +1 -1
  23. sonusai/mixture/feature.py +3 -23
  24. sonusai/mixture/generation.py +161 -204
  25. sonusai/mixture/helpers.py +29 -187
  26. sonusai/mixture/mixdb.py +386 -159
  27. sonusai/mixture/soundfile_audio.py +1 -1
  28. sonusai/mixture/sox_audio.py +4 -4
  29. sonusai/mixture/sox_augmentation.py +1 -1
  30. sonusai/mixture/target_class_balancing.py +9 -11
  31. sonusai/mixture/targets.py +23 -20
  32. sonusai/mixture/torchaudio_audio.py +18 -7
  33. sonusai/mixture/torchaudio_augmentation.py +3 -4
  34. sonusai/mixture/truth.py +21 -34
  35. sonusai/mixture/truth_functions/__init__.py +6 -0
  36. sonusai/mixture/truth_functions/crm.py +51 -37
  37. sonusai/mixture/truth_functions/energy.py +95 -50
  38. sonusai/mixture/truth_functions/file.py +12 -8
  39. sonusai/mixture/truth_functions/metadata.py +24 -0
  40. sonusai/mixture/truth_functions/metrics.py +28 -0
  41. sonusai/mixture/truth_functions/phoneme.py +4 -5
  42. sonusai/mixture/truth_functions/sed.py +32 -23
  43. sonusai/mixture/truth_functions/target.py +62 -29
  44. sonusai/mkwav.py +20 -19
  45. sonusai/queries/queries.py +9 -15
  46. sonusai/speech/l2arctic.py +6 -2
  47. sonusai/summarize_metric_spenh.py +1 -1
  48. sonusai/utils/__init__.py +1 -0
  49. sonusai/utils/asr_functions/aaware_whisper.py +1 -1
  50. sonusai/utils/audio_devices.py +27 -18
  51. sonusai/utils/docstring.py +6 -3
  52. sonusai/utils/energy_f.py +5 -3
  53. sonusai/utils/human_readable_size.py +6 -6
  54. sonusai/utils/load_object.py +15 -0
  55. sonusai/utils/onnx_utils.py +2 -2
  56. sonusai/utils/print_mixture_details.py +3 -3
  57. {sonusai-0.19.6.dist-info → sonusai-0.19.9.dist-info}/METADATA +2 -2
  58. {sonusai-0.19.6.dist-info → sonusai-0.19.9.dist-info}/RECORD +60 -58
  59. sonusai/mixture/truth_functions/datatypes.py +0 -37
  60. {sonusai-0.19.6.dist-info → sonusai-0.19.9.dist-info}/WHEEL +0 -0
  61. {sonusai-0.19.6.dist-info → sonusai-0.19.9.dist-info}/entry_points.txt +0 -0
sonusai/__init__.py CHANGED
@@ -7,7 +7,7 @@ from rich.traceback import install
7
7
 
8
8
  install(show_locals=True)
9
9
 
10
- __version__ = metadata.version(__package__)
10
+ __version__ = metadata.version(__package__) # pyright: ignore [reportArgumentType]
11
11
  BASEDIR = dirname(__file__)
12
12
 
13
13
  commands_doc = """
@@ -89,7 +89,7 @@ def on_message(_client, _userdata, message):
89
89
  FRAME_COUNT += 1
90
90
 
91
91
  global PROGRESS
92
- PROGRESS.update()
92
+ PROGRESS.update() # pyright: ignore [reportOptionalMemberAccess]
93
93
 
94
94
  if FRAME_COUNT == FRAMES:
95
95
  global DONE
@@ -582,7 +582,7 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
582
582
  asr_mx = None
583
583
  asr_tge = None
584
584
  asr_engines = list(mixdb.asr_configs.keys())
585
- if len(asr_engines) > 0 and mixdb.mixture(mixid).snr >= -96: # noise only, ignore/reset target asr
585
+ if len(asr_engines) > 0 and not mixdb.mixture(mixid).is_noise_only: # noise only, ignore/reset target asr
586
586
  wer_mx = float(mixdb.mixture_metrics(mixid, [f"mxwer.{asr_engines[0]}"])[0]) * 100
587
587
  asr_tt = MP_GLOBAL.mixdb.mixture_speech_metadata(mixid, "text")[0] # ignore mixup
588
588
  if asr_tt is None:
sonusai/genft.py CHANGED
@@ -1,12 +1,13 @@
1
1
  """sonusai genft
2
2
 
3
- usage: genft [-hvs] [-i MIXID] LOC
3
+ usage: genft [-hvsn] [-i MIXID] LOC
4
4
 
5
5
  options:
6
6
  -h, --help
7
- -v, --verbose Be verbose.
8
- -i MIXID, --mixid MIXID Mixture ID(s) to generate. [default: *].
9
- -s, --segsnr Save segsnr. [default: False].
7
+ -v, --verbose Be verbose.
8
+ -i MIXID, --mixid MIXID Mixture ID(s) to generate. [default: *].
9
+ -s, --segsnr Save segsnr. [default: False].
10
+ -n, --nopar Do not run in parallel. [default: False].
10
11
 
11
12
  Generate SonusAI feature/truth data from a SonusAI mixture database.
12
13
 
@@ -50,6 +51,7 @@ def genft(
50
51
  write: bool = False,
51
52
  show_progress: bool = False,
52
53
  force: bool = True,
54
+ no_par: bool = False,
53
55
  ) -> list[GenFTData]:
54
56
  from functools import partial
55
57
 
@@ -62,18 +64,26 @@ def genft(
62
64
 
63
65
  progress = track(total=len(mixids), disable=not show_progress)
64
66
  results = par_track(
65
- partial(_genft_kernel, location=location, compute_truth=compute_truth, compute_segsnr=compute_segsnr, force=force, write=write),
67
+ partial(
68
+ _genft_kernel,
69
+ location=location,
70
+ compute_truth=compute_truth,
71
+ compute_segsnr=compute_segsnr,
72
+ force=force,
73
+ write=write,
74
+ ),
66
75
  mixids,
67
76
  progress=progress,
77
+ no_par=no_par,
68
78
  )
69
79
  progress.close()
70
80
 
71
81
  return results
72
82
 
73
83
 
74
- def _genft_kernel(m_id: int, location: str, compute_truth: bool, compute_segsnr: bool, force: bool, write: bool) -> GenFTData:
75
- from typing import Any
76
-
84
+ def _genft_kernel(
85
+ m_id: int, location: str, compute_truth: bool, compute_segsnr: bool, force: bool, write: bool
86
+ ) -> GenFTData:
77
87
  from sonusai.mixture import MixtureDatabase
78
88
  from sonusai.mixture import write_cached_data
79
89
  from sonusai.mixture import write_mixture_metadata
@@ -83,21 +93,23 @@ def _genft_kernel(m_id: int, location: str, compute_truth: bool, compute_segsnr:
83
93
  result = GenFTData()
84
94
 
85
95
  feature, truth_f = mixdb.mixture_ft(m_id=m_id, force=force)
86
- write_data: list[tuple[str, Any]] = [("feature", feature)]
87
96
  result.feature = feature
97
+ if write:
98
+ write_cached_data(mixdb.location, "mixture", mixdb.mixture(m_id).name, [("feature", feature)])
88
99
 
89
100
  if compute_truth:
90
- write_data.append(("truth_f", truth_f))
91
101
  result.truth_f = truth_f
102
+ if write:
103
+ write_cached_data(mixdb.location, "mixture", mixdb.mixture(m_id).name, [("truth_f", truth_f)])
92
104
 
93
105
  if compute_segsnr:
94
106
  segsnr = mixdb.mixture_segsnr(m_id=m_id, force=force)
95
- write_data.append(("segsnr", segsnr))
96
107
  result.segsnr = segsnr
108
+ if write:
109
+ write_cached_data(mixdb.location, "mixture", mixdb.mixture(m_id).name, [("segsnr", segsnr)])
97
110
 
98
111
  if write:
99
- write_cached_data(mixdb.location, "mixture", mixdb.mixture(m_id).name, write_data)
100
- write_mixture_metadata(mixdb, mixdb.mixture(m_id))
112
+ write_mixture_metadata(mixdb, m_id)
101
113
 
102
114
  return result
103
115
 
@@ -117,6 +129,7 @@ def main() -> None:
117
129
  from sonusai import initial_log_messages
118
130
  from sonusai import logger
119
131
  from sonusai import update_console_handler
132
+ from sonusai.mixture import SAMPLE_RATE
120
133
  from sonusai.mixture import MixtureDatabase
121
134
  from sonusai.mixture import check_audio_files_exist
122
135
  from sonusai.utils import human_readable_size
@@ -125,6 +138,7 @@ def main() -> None:
125
138
  verbose = args["--verbose"]
126
139
  mixids = args["--mixid"]
127
140
  compute_segsnr = args["--segsnr"]
141
+ no_par = args["--nopar"]
128
142
  location = args["LOC"]
129
143
 
130
144
  start_time = time.monotonic()
@@ -138,7 +152,7 @@ def main() -> None:
138
152
  mixids = mixdb.mixids_to_list(mixids)
139
153
 
140
154
  total_samples = mixdb.total_samples(mixids)
141
- duration = total_samples / sonusai.mixture.SAMPLE_RATE
155
+ duration = total_samples / SAMPLE_RATE
142
156
  total_transform_frames = total_samples // mixdb.ft_config.overlap
143
157
  total_feature_frames = total_samples // mixdb.feature_step_samples
144
158
 
@@ -159,6 +173,7 @@ def main() -> None:
159
173
  compute_segsnr=compute_segsnr,
160
174
  write=True,
161
175
  show_progress=True,
176
+ no_par=no_par,
162
177
  )
163
178
  except Exception as e:
164
179
  logger.debug(e)
sonusai/genmetrics.py CHANGED
@@ -1,6 +1,6 @@
1
1
  """sonusai genmetrics
2
2
 
3
- usage: genmetrics [-hvs] [-i MIXID] [-n INCLUDE] [-x EXCLUDE] LOC
3
+ usage: genmetrics [-hvsd] [-i MIXID] [-n INCLUDE] [-x EXCLUDE] LOC
4
4
 
5
5
  options:
6
6
  -h, --help
@@ -9,27 +9,39 @@ options:
9
9
  -n INCLUDE, --include INCLUDE Metrics to include. [default: all]
10
10
  -x EXCLUDE, --exclude EXCLUDE Metrics to exclude. [default: none]
11
11
  -s, --supported Show list of supported metrics.
12
+ -d, --dryrun Show list of metrics that will be generated and exit.
12
13
 
13
14
  Calculate speech enhancement metrics of SonusAI mixture data in LOC.
14
15
 
15
16
  Inputs:
16
17
  LOC A SonusAI mixture database directory.
17
18
  MIXID A glob of mixture ID(s) to generate.
18
- INCLUDE Comma separated list of metrics to include. Can be 'all' or
19
- any of the supported metrics.
20
- EXCLUDE Comma separated list of metrics to exclude. Can be 'none' or
21
- any of the supported metrics.
19
+ INCLUDE Comma separated list of metrics to include. Can be "all" or
20
+ any of the supported metrics or glob(s).
21
+ EXCLUDE Comma separated list of metrics to exclude. Can be "none" or
22
+ any of the supported metrics or glob(s)
23
+
24
+ Note: The default include of "all" excludes the generation of ASR metrics,
25
+ i.e., "*asr*,*wer*". However, if include is manually specified to something other than "all",
26
+ then this behavior is overridden.
27
+
28
+ Similarly, the default exclude of "none" excludes the generation of ASR metrics,
29
+ i.e., "*asr*,*wer*". However, if exclude is manually specified to something other than "none",
30
+ then this behavior is also overridden.
22
31
 
23
32
  Examples:
24
33
 
25
34
  Generate all available mxwer metrics (as determined by mixdb asr_configs parameter):
26
- > sonusai genmetrics -n"mxwer" mixdb_loc
35
+ > sonusai genmetrics -n"mxwer*" mixdb_loc
27
36
 
28
37
  Generate only mxwer.faster metrics:
29
38
  > sonusai genmetrics -n"mxwer.faster" mixdb_loc
30
39
 
31
- Generate all available metrics except for mxwer.faster:
32
- > sonusai genmetrics -x"mxwer.faster" mixdb_loc
40
+ Generate only faster metrics:
41
+ > sonusai genmetrics -n"*faster" mixdb_loc
42
+
43
+ Generate all available metrics except for mxcovl
44
+ > sonusai genmetrics -x"mxcovl" mixdb_loc
33
45
 
34
46
  """
35
47
 
@@ -71,11 +83,13 @@ def main() -> None:
71
83
 
72
84
  verbose = args["--verbose"]
73
85
  mixids = args["--mixid"]
74
- includes = [x.strip() for x in args["--include"].lower().split(",")]
75
- excludes = [x.strip() for x in args["--exclude"].lower().split(",")]
86
+ includes = {x.strip() for x in args["--include"].replace(" ", ",").lower().split(",") if x != ""}
87
+ excludes = {x.strip() for x in args["--exclude"].replace(" ", ",").lower().split(",") if x != ""}
76
88
  show_supported = args["--supported"]
89
+ dryrun = args["--dryrun"]
77
90
  location = args["LOC"]
78
91
 
92
+ import fnmatch
79
93
  import sys
80
94
  import time
81
95
  from functools import partial
@@ -89,9 +103,6 @@ def main() -> None:
89
103
  from sonusai.utils import seconds_to_hms
90
104
  from sonusai.utils import track
91
105
 
92
- # TODO: Check config.yml for changes to asr_configs and update mixdb
93
- # TODO: Support globs for metrics (includes and excludes)
94
-
95
106
  start_time = time.monotonic()
96
107
 
97
108
  # Setup logging file
@@ -107,38 +118,45 @@ def main() -> None:
107
118
  logger.info(f"\nSupported metrics:\n\n{supported.pretty}")
108
119
  sys.exit(0)
109
120
 
110
- if includes is None or "all" in includes:
111
- metrics = supported.names
112
- else:
113
- metrics = set(includes)
114
- if "mxwer" in metrics:
115
- metrics.remove("mxwer")
116
- for name in mixdb.asr_configs:
117
- metrics.add(f"mxwer.{name}")
118
-
119
- diff = metrics.difference(supported.names)
120
- if diff:
121
- logger.error(f"Unrecognized metric: {', '.join(diff)}")
122
- sys.exit(1)
121
+ # Handle default excludes
122
+ if "none" in excludes:
123
+ if "all" in includes:
124
+ excludes = {"*asr*", "*wer*"}
125
+ else:
126
+ excludes = set()
123
127
 
124
- if excludes is None or "none" in excludes:
125
- _excludes = set()
126
- else:
127
- _excludes = set(excludes)
128
- if "mxwer" in _excludes:
129
- _excludes.remove("mxwer")
130
- for name in mixdb.asr_configs:
131
- _excludes.add(f"mxwer.{name}")
132
-
133
- diff = _excludes.difference(supported.names)
134
- if diff:
135
- logger.error(f"Unrecognized metric: {', '.join(diff)}")
136
- sys.exit(1)
128
+ # Handle default includes
129
+ if "all" in includes:
130
+ includes = {"*"}
131
+
132
+ included_metrics: set[str] = set()
133
+ for include in includes:
134
+ for m in fnmatch.filter(supported.names, include):
135
+ included_metrics.add(m)
137
136
 
138
- for exclude in _excludes:
139
- metrics.discard(exclude)
137
+ excluded_metrics: set[str] = set()
138
+ for exclude in excludes:
139
+ for m in fnmatch.filter(supported.names, exclude):
140
+ excluded_metrics.add(m)
141
+
142
+ requested = included_metrics - excluded_metrics
143
+
144
+ # Check for metrics dependencies and cache dependencies even if not explicitly requested.
145
+ dependencies: set[str] = set()
146
+ for metric in requested:
147
+ if metric.startswith("mxwer"):
148
+ dependencies.add("mxasr." + metric[6:])
149
+ dependencies.add("tasr." + metric[6:])
150
+
151
+ metrics = sorted(requested | dependencies)
152
+
153
+ if len(metrics) == 0:
154
+ logger.warning("No metrics were requested")
155
+ sys.exit(1)
140
156
 
141
157
  logger.info(f"Generating metrics: {', '.join(metrics)}")
158
+ if dryrun:
159
+ sys.exit(0)
142
160
 
143
161
  mixids = mixdb.mixids_to_list(mixids)
144
162
  logger.info("")
@@ -146,7 +164,7 @@ def main() -> None:
146
164
 
147
165
  progress = track(total=len(mixids), desc="genmetrics")
148
166
  par_track(
149
- partial(_process_mixture, location=location, metrics=list(metrics)),
167
+ partial(_process_mixture, location=location, metrics=metrics),
150
168
  mixids,
151
169
  progress=progress,
152
170
  )
sonusai/genmix.py CHANGED
@@ -1,14 +1,15 @@
1
1
  """sonusai genmix
2
2
 
3
- usage: genmix [-hvgts] [-i MIXID] LOC
3
+ usage: genmix [-hvgtsn] [-i MIXID] LOC
4
4
 
5
5
  options:
6
6
  -h, --help
7
- -v, --verbose Be verbose.
8
- -i MIXID, --mixid MIXID Mixture ID(s) to generate. [default: *].
9
- -g, --target Save target. [default: False].
10
- -t, --truth Save truth_t. [default: False].
11
- -s, --segsnr Save segsnr_t. [default: False].
7
+ -v, --verbose Be verbose.
8
+ -i MIXID, --mixid MIXID Mixture ID(s) to generate. [default: *].
9
+ -g, --target Save target. [default: False].
10
+ -t, --truth Save truth_t. [default: False].
11
+ -s, --segsnr Save segsnr_t. [default: False].
12
+ -n, --nopar Do not run in parallel. [default: False].
12
13
 
13
14
  Generate SonusAI mixture data from a SonusAI mixture database.
14
15
 
@@ -55,6 +56,7 @@ def genmix(
55
56
  write: bool = False,
56
57
  show_progress: bool = False,
57
58
  force: bool = True,
59
+ no_par: bool = False,
58
60
  ) -> list[GenMixData]:
59
61
  from functools import partial
60
62
 
@@ -77,6 +79,7 @@ def genmix(
77
79
  ),
78
80
  mixids,
79
81
  progress=progress,
82
+ no_par=no_par,
80
83
  )
81
84
  progress.close()
82
85
 
@@ -98,41 +101,47 @@ def _genmix_kernel(
98
101
 
99
102
  mixdb = MixtureDatabase(location)
100
103
 
104
+ result = GenMixData()
105
+
101
106
  targets = mixdb.mixture_targets(m_id=m_id, force=force)
107
+ result.targets = targets
102
108
  noise = mixdb.mixture_noise(m_id=m_id, force=force)
103
- write_data = [("targets", targets), ("noise", noise)]
109
+ result.noise = noise
110
+ if write:
111
+ write_cached_data(
112
+ mixdb.location,
113
+ "mixture",
114
+ mixdb.mixture(m_id).name,
115
+ [
116
+ ("targets", targets),
117
+ ("noise", noise),
118
+ ],
119
+ )
104
120
 
105
121
  if compute_truth:
106
122
  truth_t = mixdb.mixture_truth_t(m_id=m_id, targets=targets, noise=noise, force=force)
107
- write_data.append(("truth_t", truth_t))
108
- else:
109
- truth_t = None
123
+ result.truth_t = truth_t
124
+ if write:
125
+ write_cached_data(mixdb.location, "mixture", mixdb.mixture(m_id).name, [("truth_t", truth_t)])
110
126
 
111
127
  target = mixdb.mixture_target(m_id=m_id, targets=targets)
112
- if save_target:
113
- write_data.append(("target", target))
128
+ result.target = target
129
+ if save_target and write:
130
+ write_cached_data(mixdb.location, "mixture", mixdb.mixture(m_id).name, [("target", target)])
114
131
 
115
132
  if compute_segsnr:
116
133
  segsnr_t = mixdb.mixture_segsnr_t(m_id=m_id, targets=targets, target=target, noise=noise, force=force)
117
- write_data.append(("segsnr_t", segsnr_t))
118
- else:
119
- segsnr_t = None
134
+ result.segsnr_t = segsnr_t
135
+ if write:
136
+ write_cached_data(mixdb.location, "mixture", mixdb.mixture(m_id).name, [("segsnr_t", segsnr_t)])
120
137
 
121
138
  mixture = mixdb.mixture_mixture(m_id=m_id, targets=targets, target=target, noise=noise, force=force)
122
- write_data.append(("mixture", mixture))
123
-
139
+ result.mixture = mixture
124
140
  if write:
125
- write_cached_data(mixdb.location, "mixture", mixdb.mixture(m_id).name, write_data)
126
- write_mixture_metadata(mixdb, mixdb.mixture(m_id))
127
-
128
- return GenMixData(
129
- targets=targets,
130
- target=target,
131
- noise=noise,
132
- mixture=mixture,
133
- truth_t=truth_t,
134
- segsnr_t=segsnr_t,
135
- )
141
+ write_cached_data(mixdb.location, "mixture", mixdb.mixture(m_id).name, [("mixture", mixture)])
142
+ write_mixture_metadata(mixdb, m_id)
143
+
144
+ return result
136
145
 
137
146
 
138
147
  def main() -> None:
@@ -150,6 +159,7 @@ def main() -> None:
150
159
  from sonusai import initial_log_messages
151
160
  from sonusai import logger
152
161
  from sonusai import update_console_handler
162
+ from sonusai.mixture import SAMPLE_RATE
153
163
  from sonusai.mixture import MixtureDatabase
154
164
  from sonusai.mixture import check_audio_files_exist
155
165
  from sonusai.utils import human_readable_size
@@ -161,6 +171,7 @@ def main() -> None:
161
171
  save_target = args["--target"]
162
172
  compute_truth = args["--truth"]
163
173
  compute_segsnr = args["--segsnr"]
174
+ no_par = args["--nopar"]
164
175
 
165
176
  start_time = time.monotonic()
166
177
 
@@ -173,7 +184,7 @@ def main() -> None:
173
184
  mixids = mixdb.mixids_to_list(mixids)
174
185
 
175
186
  total_samples = mixdb.total_samples(mixids)
176
- duration = total_samples / sonusai.mixture.SAMPLE_RATE
187
+ duration = total_samples / SAMPLE_RATE
177
188
 
178
189
  logger.info("")
179
190
  logger.info(f"Found {len(mixids):,} mixtures to process")
@@ -190,6 +201,7 @@ def main() -> None:
190
201
  compute_segsnr=compute_segsnr,
191
202
  write=True,
192
203
  show_progress=True,
204
+ no_par=no_par,
193
205
  )
194
206
  except Exception as e:
195
207
  logger.debug(e)