sonusai 0.19.5__py3-none-any.whl → 0.19.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61) hide show
  1. sonusai/__init__.py +1 -1
  2. sonusai/aawscd_probwrite.py +1 -1
  3. sonusai/calc_metric_spenh.py +1 -1
  4. sonusai/genft.py +38 -49
  5. sonusai/genmetrics.py +65 -70
  6. sonusai/genmix.py +62 -72
  7. sonusai/genmixdb.py +73 -95
  8. sonusai/metrics/calc_class_weights.py +1 -3
  9. sonusai/metrics/calc_optimal_thresholds.py +2 -2
  10. sonusai/metrics/calc_phase_distance.py +1 -1
  11. sonusai/metrics/calc_segsnr_f.py +1 -1
  12. sonusai/metrics/calc_speech.py +6 -6
  13. sonusai/metrics/class_summary.py +6 -15
  14. sonusai/metrics/confusion_matrix_summary.py +11 -27
  15. sonusai/metrics/one_hot.py +3 -3
  16. sonusai/metrics/snr_summary.py +7 -7
  17. sonusai/mixture/__init__.py +3 -17
  18. sonusai/mixture/augmentation.py +5 -6
  19. sonusai/mixture/class_count.py +1 -1
  20. sonusai/mixture/config.py +36 -46
  21. sonusai/mixture/data_io.py +30 -1
  22. sonusai/mixture/datatypes.py +29 -40
  23. sonusai/mixture/db_datatypes.py +1 -1
  24. sonusai/mixture/feature.py +3 -23
  25. sonusai/mixture/generation.py +202 -235
  26. sonusai/mixture/helpers.py +29 -187
  27. sonusai/mixture/mixdb.py +386 -159
  28. sonusai/mixture/soundfile_audio.py +1 -1
  29. sonusai/mixture/sox_audio.py +4 -4
  30. sonusai/mixture/sox_augmentation.py +1 -1
  31. sonusai/mixture/target_class_balancing.py +9 -11
  32. sonusai/mixture/targets.py +23 -20
  33. sonusai/mixture/truth.py +21 -34
  34. sonusai/mixture/truth_functions/__init__.py +6 -0
  35. sonusai/mixture/truth_functions/crm.py +51 -37
  36. sonusai/mixture/truth_functions/energy.py +95 -50
  37. sonusai/mixture/truth_functions/file.py +12 -8
  38. sonusai/mixture/truth_functions/metadata.py +24 -0
  39. sonusai/mixture/truth_functions/metrics.py +28 -0
  40. sonusai/mixture/truth_functions/phoneme.py +4 -5
  41. sonusai/mixture/truth_functions/sed.py +32 -23
  42. sonusai/mixture/truth_functions/target.py +62 -29
  43. sonusai/mkwav.py +34 -43
  44. sonusai/queries/queries.py +9 -15
  45. sonusai/speech/l2arctic.py +6 -2
  46. sonusai/summarize_metric_spenh.py +1 -1
  47. sonusai/utils/__init__.py +1 -0
  48. sonusai/utils/asr_functions/aaware_whisper.py +1 -1
  49. sonusai/utils/audio_devices.py +27 -18
  50. sonusai/utils/docstring.py +6 -3
  51. sonusai/utils/energy_f.py +5 -3
  52. sonusai/utils/human_readable_size.py +6 -6
  53. sonusai/utils/load_object.py +15 -0
  54. sonusai/utils/onnx_utils.py +2 -2
  55. sonusai/utils/parallel.py +3 -5
  56. sonusai/utils/print_mixture_details.py +3 -3
  57. {sonusai-0.19.5.dist-info → sonusai-0.19.8.dist-info}/METADATA +2 -2
  58. {sonusai-0.19.5.dist-info → sonusai-0.19.8.dist-info}/RECORD +60 -58
  59. sonusai/mixture/truth_functions/datatypes.py +0 -37
  60. {sonusai-0.19.5.dist-info → sonusai-0.19.8.dist-info}/WHEEL +0 -0
  61. {sonusai-0.19.5.dist-info → sonusai-0.19.8.dist-info}/entry_points.txt +0 -0
sonusai/__init__.py CHANGED
@@ -7,7 +7,7 @@ from rich.traceback import install
7
7
 
8
8
  install(show_locals=True)
9
9
 
10
- __version__ = metadata.version(__package__)
10
+ __version__ = metadata.version(__package__) # pyright: ignore [reportArgumentType]
11
11
  BASEDIR = dirname(__file__)
12
12
 
13
13
  commands_doc = """
@@ -89,7 +89,7 @@ def on_message(_client, _userdata, message):
89
89
  FRAME_COUNT += 1
90
90
 
91
91
  global PROGRESS
92
- PROGRESS.update()
92
+ PROGRESS.update() # pyright: ignore [reportOptionalMemberAccess]
93
93
 
94
94
  if FRAME_COUNT == FRAMES:
95
95
  global DONE
@@ -582,7 +582,7 @@ def _process_mixture(mixid: int) -> tuple[pd.DataFrame, pd.DataFrame]:
582
582
  asr_mx = None
583
583
  asr_tge = None
584
584
  asr_engines = list(mixdb.asr_configs.keys())
585
- if len(asr_engines) > 0 and mixdb.mixture(mixid).snr >= -96: # noise only, ignore/reset target asr
585
+ if len(asr_engines) > 0 and not mixdb.mixture(mixid).is_noise_only: # noise only, ignore/reset target asr
586
586
  wer_mx = float(mixdb.mixture_metrics(mixid, [f"mxwer.{asr_engines[0]}"])[0]) * 100
587
587
  asr_tt = MP_GLOBAL.mixdb.mixture_speech_metadata(mixid, "text")[0] # ignore mixup
588
588
  if asr_tt is None:
sonusai/genft.py CHANGED
@@ -1,12 +1,13 @@
1
1
  """sonusai genft
2
2
 
3
- usage: genft [-hvs] [-i MIXID] LOC
3
+ usage: genft [-hvsn] [-i MIXID] LOC
4
4
 
5
5
  options:
6
6
  -h, --help
7
- -v, --verbose Be verbose.
8
- -i MIXID, --mixid MIXID Mixture ID(s) to generate. [default: *].
9
- -s, --segsnr Save segsnr. [default: False].
7
+ -v, --verbose Be verbose.
8
+ -i MIXID, --mixid MIXID Mixture ID(s) to generate. [default: *].
9
+ -s, --segsnr Save segsnr. [default: False].
10
+ -n, --nopar Do not run in parallel. [default: False].
10
11
 
11
12
  Generate SonusAI feature/truth data from a SonusAI mixture database.
12
13
 
@@ -25,11 +26,9 @@ Outputs the following to the mixture database directory:
25
26
  """
26
27
 
27
28
  import signal
28
- from dataclasses import dataclass
29
29
 
30
30
  from sonusai.mixture import GeneralizedIDs
31
31
  from sonusai.mixture import GenFTData
32
- from sonusai.mixture import MixtureDatabase
33
32
 
34
33
 
35
34
  def signal_handler(_sig, _frame):
@@ -44,87 +43,73 @@ def signal_handler(_sig, _frame):
44
43
  signal.signal(signal.SIGINT, signal_handler)
45
44
 
46
45
 
47
- @dataclass
48
- class MPGlobal:
49
- mixdb: MixtureDatabase
50
- compute_truth: bool
51
- compute_segsnr: bool
52
- force: bool
53
- write: bool
54
-
55
-
56
- MP_GLOBAL: MPGlobal
57
-
58
-
59
46
  def genft(
60
- mixdb: MixtureDatabase,
47
+ location: str,
61
48
  mixids: GeneralizedIDs = "*",
62
49
  compute_truth: bool = True,
63
50
  compute_segsnr: bool = False,
64
51
  write: bool = False,
65
52
  show_progress: bool = False,
66
53
  force: bool = True,
54
+ no_par: bool = False,
67
55
  ) -> list[GenFTData]:
56
+ from functools import partial
57
+
58
+ from sonusai.mixture import MixtureDatabase
68
59
  from sonusai.utils import par_track
69
60
  from sonusai.utils import track
70
61
 
62
+ mixdb = MixtureDatabase(location)
71
63
  mixids = mixdb.mixids_to_list(mixids)
72
64
 
73
65
  progress = track(total=len(mixids), disable=not show_progress)
74
66
  results = par_track(
75
- _genft_kernel,
67
+ partial(
68
+ _genft_kernel,
69
+ location=location,
70
+ compute_truth=compute_truth,
71
+ compute_segsnr=compute_segsnr,
72
+ force=force,
73
+ write=write,
74
+ ),
76
75
  mixids,
77
- initializer=_genft_initializer,
78
- initargs=(mixdb.location, compute_truth, compute_segsnr, force, write),
79
76
  progress=progress,
77
+ no_par=no_par,
80
78
  )
81
79
  progress.close()
82
80
 
83
81
  return results
84
82
 
85
83
 
86
- def _genft_initializer(location: str, compute_truth: bool, compute_segsnr: bool, force: bool, write: bool) -> None:
87
- global MP_GLOBAL
88
-
89
- MP_GLOBAL = MPGlobal(
90
- mixdb=MixtureDatabase(location),
91
- compute_truth=compute_truth,
92
- compute_segsnr=compute_segsnr,
93
- force=force,
94
- write=write,
95
- )
96
-
97
-
98
- def _genft_kernel(m_id: int) -> GenFTData:
84
+ def _genft_kernel(
85
+ m_id: int, location: str, compute_truth: bool, compute_segsnr: bool, force: bool, write: bool
86
+ ) -> GenFTData:
87
+ from sonusai.mixture import MixtureDatabase
99
88
  from sonusai.mixture import write_cached_data
100
89
  from sonusai.mixture import write_mixture_metadata
101
90
 
102
- global MP_GLOBAL
103
-
104
- mixdb = MP_GLOBAL.mixdb
105
- compute_truth = MP_GLOBAL.compute_truth
106
- compute_segsnr = MP_GLOBAL.compute_segsnr
107
- force = MP_GLOBAL.force
108
- write = MP_GLOBAL.write
91
+ mixdb = MixtureDatabase(location)
109
92
 
110
93
  result = GenFTData()
111
94
 
112
95
  feature, truth_f = mixdb.mixture_ft(m_id=m_id, force=force)
113
- write_data = [("feature", feature)]
114
96
  result.feature = feature
97
+ if write:
98
+ write_cached_data(mixdb.location, "mixture", mixdb.mixture(m_id).name, [("feature", feature)])
115
99
 
116
100
  if compute_truth:
117
- write_data.append(("truth_f", truth_f))
118
101
  result.truth_f = truth_f
102
+ if write:
103
+ write_cached_data(mixdb.location, "mixture", mixdb.mixture(m_id).name, [("truth_f", truth_f)])
119
104
 
120
105
  if compute_segsnr:
121
106
  segsnr = mixdb.mixture_segsnr(m_id=m_id, force=force)
122
- write_data.append(("segsnr", segsnr))
123
107
  result.segsnr = segsnr
108
+ if write:
109
+ write_cached_data(mixdb.location, "mixture", mixdb.mixture(m_id).name, [("segsnr", segsnr)])
124
110
 
125
111
  if write:
126
- write_cached_data(mixdb.location, "mixture", mixdb.mixture(m_id).name, write_data)
127
- write_mixture_metadata(mixdb, mixdb.mixture(m_id))
112
+ write_mixture_metadata(mixdb, m_id)
128
113
 
129
114
  return result
130
115
 
@@ -144,6 +129,8 @@ def main() -> None:
144
129
  from sonusai import initial_log_messages
145
130
  from sonusai import logger
146
131
  from sonusai import update_console_handler
132
+ from sonusai.mixture import SAMPLE_RATE
133
+ from sonusai.mixture import MixtureDatabase
147
134
  from sonusai.mixture import check_audio_files_exist
148
135
  from sonusai.utils import human_readable_size
149
136
  from sonusai.utils import seconds_to_hms
@@ -151,6 +138,7 @@ def main() -> None:
151
138
  verbose = args["--verbose"]
152
139
  mixids = args["--mixid"]
153
140
  compute_segsnr = args["--segsnr"]
141
+ no_par = args["--nopar"]
154
142
  location = args["LOC"]
155
143
 
156
144
  start_time = time.monotonic()
@@ -164,7 +152,7 @@ def main() -> None:
164
152
  mixids = mixdb.mixids_to_list(mixids)
165
153
 
166
154
  total_samples = mixdb.total_samples(mixids)
167
- duration = total_samples / sonusai.mixture.SAMPLE_RATE
155
+ duration = total_samples / SAMPLE_RATE
168
156
  total_transform_frames = total_samples // mixdb.ft_config.overlap
169
157
  total_feature_frames = total_samples // mixdb.feature_step_samples
170
158
 
@@ -180,11 +168,12 @@ def main() -> None:
180
168
 
181
169
  try:
182
170
  genft(
183
- mixdb=mixdb,
171
+ location=location,
184
172
  mixids=mixids,
185
173
  compute_segsnr=compute_segsnr,
186
174
  write=True,
187
175
  show_progress=True,
176
+ no_par=no_par,
188
177
  )
189
178
  except Exception as e:
190
179
  logger.debug(e)
sonusai/genmetrics.py CHANGED
@@ -1,6 +1,6 @@
1
1
  """sonusai genmetrics
2
2
 
3
- usage: genmetrics [-hvs] [-i MIXID] [-n INCLUDE] [-x EXCLUDE] LOC
3
+ usage: genmetrics [-hvsd] [-i MIXID] [-n INCLUDE] [-x EXCLUDE] LOC
4
4
 
5
5
  options:
6
6
  -h, --help
@@ -9,34 +9,43 @@ options:
9
9
  -n INCLUDE, --include INCLUDE Metrics to include. [default: all]
10
10
  -x EXCLUDE, --exclude EXCLUDE Metrics to exclude. [default: none]
11
11
  -s, --supported Show list of supported metrics.
12
+ -d, --dryrun Show list of metrics that will be generated and exit.
12
13
 
13
14
  Calculate speech enhancement metrics of SonusAI mixture data in LOC.
14
15
 
15
16
  Inputs:
16
17
  LOC A SonusAI mixture database directory.
17
18
  MIXID A glob of mixture ID(s) to generate.
18
- INCLUDE Comma separated list of metrics to include. Can be 'all' or
19
- any of the supported metrics.
20
- EXCLUDE Comma separated list of metrics to exclude. Can be 'none' or
21
- any of the supported metrics.
19
+ INCLUDE Comma separated list of metrics to include. Can be "all" or
20
+ any of the supported metrics or glob(s).
21
+ EXCLUDE Comma separated list of metrics to exclude. Can be "none" or
22
+ any of the supported metrics or glob(s)
23
+
24
+ Note: The default include of "all" excludes the generation of ASR metrics,
25
+ i.e., "*asr*,*wer*". However, if include is manually specified to something other than "all",
26
+ then this behavior is overridden.
27
+
28
+ Similarly, the default exclude of "none" excludes the generation of ASR metrics,
29
+ i.e., "*asr*,*wer*". However, if exclude is manually specified to something other than "none",
30
+ then this behavior is also overridden.
22
31
 
23
32
  Examples:
24
33
 
25
34
  Generate all available mxwer metrics (as determined by mixdb asr_configs parameter):
26
- > sonusai genmetrics -n"mxwer" mixdb_loc
35
+ > sonusai genmetrics -n"mxwer*" mixdb_loc
27
36
 
28
37
  Generate only mxwer.faster metrics:
29
38
  > sonusai genmetrics -n"mxwer.faster" mixdb_loc
30
39
 
31
- Generate all available metrics except for mxwer.faster:
32
- > sonusai genmetrics -x"mxwer.faster" mixdb_loc
40
+ Generate only faster metrics:
41
+ > sonusai genmetrics -n"*faster" mixdb_loc
42
+
43
+ Generate all available metrics except for mxcovl
44
+ > sonusai genmetrics -x"mxcovl" mixdb_loc
33
45
 
34
46
  """
35
47
 
36
48
  import signal
37
- from dataclasses import dataclass
38
-
39
- from sonusai.mixture import MixtureDatabase
40
49
 
41
50
 
42
51
  def signal_handler(_sig, _frame):
@@ -51,31 +60,11 @@ def signal_handler(_sig, _frame):
51
60
  signal.signal(signal.SIGINT, signal_handler)
52
61
 
53
62
 
54
- @dataclass
55
- class MPGlobal:
56
- mixdb: MixtureDatabase
57
- metrics: set[str]
58
-
59
-
60
- MP_GLOBAL: MPGlobal
61
-
62
-
63
- def _initializer(location: str, metrics: set[str]) -> None:
64
- global MP_GLOBAL
65
-
66
- MP_GLOBAL = MPGlobal(
67
- mixdb=MixtureDatabase(location),
68
- metrics=metrics,
69
- )
70
-
71
-
72
- def _process_mixture(mixid: int) -> None:
63
+ def _process_mixture(mixid: int, location: str, metrics: list[str]) -> None:
64
+ from sonusai.mixture import MixtureDatabase
73
65
  from sonusai.mixture import write_cached_data
74
66
 
75
- global MP_GLOBAL
76
-
77
- mixdb = MP_GLOBAL.mixdb
78
- metrics = list(MP_GLOBAL.metrics)
67
+ mixdb = MixtureDatabase(location)
79
68
 
80
69
  values = mixdb.mixture_metrics(m_id=mixid, metrics=metrics, force=True)
81
70
  write_data = list(zip(metrics, values, strict=False))
@@ -87,19 +76,23 @@ def main() -> None:
87
76
  from docopt import docopt
88
77
 
89
78
  import sonusai
79
+ from sonusai.mixture import MixtureDatabase
90
80
  from sonusai.utils import trim_docstring
91
81
 
92
82
  args = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)
93
83
 
94
84
  verbose = args["--verbose"]
95
85
  mixids = args["--mixid"]
96
- includes = [x.strip() for x in args["--include"].lower().split(",")]
97
- excludes = [x.strip() for x in args["--exclude"].lower().split(",")]
86
+ includes = {x.strip() for x in args["--include"].replace(" ", ",").lower().split(",") if x != ""}
87
+ excludes = {x.strip() for x in args["--exclude"].replace(" ", ",").lower().split(",") if x != ""}
98
88
  show_supported = args["--supported"]
89
+ dryrun = args["--dryrun"]
99
90
  location = args["LOC"]
100
91
 
92
+ import fnmatch
101
93
  import sys
102
94
  import time
95
+ from functools import partial
103
96
  from os.path import join
104
97
 
105
98
  from sonusai import create_file_handler
@@ -110,9 +103,6 @@ def main() -> None:
110
103
  from sonusai.utils import seconds_to_hms
111
104
  from sonusai.utils import track
112
105
 
113
- # TODO: Check config.yml for changes to asr_configs and update mixdb
114
- # TODO: Support globs for metrics (includes and excludes)
115
-
116
106
  start_time = time.monotonic()
117
107
 
118
108
  # Setup logging file
@@ -128,38 +118,45 @@ def main() -> None:
128
118
  logger.info(f"\nSupported metrics:\n\n{supported.pretty}")
129
119
  sys.exit(0)
130
120
 
131
- if includes is None or "all" in includes:
132
- metrics = supported.names
133
- else:
134
- metrics = set(includes)
135
- if "mxwer" in metrics:
136
- metrics.remove("mxwer")
137
- for name in mixdb.asr_configs:
138
- metrics.add(f"mxwer.{name}")
139
-
140
- diff = metrics.difference(supported.names)
141
- if diff:
142
- logger.error(f"Unrecognized metric: {', '.join(diff)}")
143
- sys.exit(1)
121
+ # Handle default excludes
122
+ if "none" in excludes:
123
+ if "all" in includes:
124
+ excludes = {"*asr*", "*wer*"}
125
+ else:
126
+ excludes = set()
144
127
 
145
- if excludes is None or "none" in excludes:
146
- _excludes = set()
147
- else:
148
- _excludes = set(excludes)
149
- if "mxwer" in _excludes:
150
- _excludes.remove("mxwer")
151
- for name in mixdb.asr_configs:
152
- _excludes.add(f"mxwer.{name}")
153
-
154
- diff = _excludes.difference(supported.names)
155
- if diff:
156
- logger.error(f"Unrecognized metric: {', '.join(diff)}")
157
- sys.exit(1)
128
+ # Handle default includes
129
+ if "all" in includes:
130
+ includes = {"*"}
131
+
132
+ included_metrics: set[str] = set()
133
+ for include in includes:
134
+ for m in fnmatch.filter(supported.names, include):
135
+ included_metrics.add(m)
158
136
 
159
- for exclude in _excludes:
160
- metrics.discard(exclude)
137
+ excluded_metrics: set[str] = set()
138
+ for exclude in excludes:
139
+ for m in fnmatch.filter(supported.names, exclude):
140
+ excluded_metrics.add(m)
141
+
142
+ requested = included_metrics - excluded_metrics
143
+
144
+ # Check for metrics dependencies and cache dependencies even if not explicitly requested.
145
+ dependencies: set[str] = set()
146
+ for metric in requested:
147
+ if metric.startswith("mxwer"):
148
+ dependencies.add("mxasr." + metric[6:])
149
+ dependencies.add("tasr." + metric[6:])
150
+
151
+ metrics = sorted(requested | dependencies)
152
+
153
+ if len(metrics) == 0:
154
+ logger.warning("No metrics were requested")
155
+ sys.exit(1)
161
156
 
162
157
  logger.info(f"Generating metrics: {', '.join(metrics)}")
158
+ if dryrun:
159
+ sys.exit(0)
163
160
 
164
161
  mixids = mixdb.mixids_to_list(mixids)
165
162
  logger.info("")
@@ -167,11 +164,9 @@ def main() -> None:
167
164
 
168
165
  progress = track(total=len(mixids), desc="genmetrics")
169
166
  par_track(
170
- _process_mixture,
167
+ partial(_process_mixture, location=location, metrics=metrics),
171
168
  mixids,
172
169
  progress=progress,
173
- initializer=_initializer,
174
- initargs=(location, metrics),
175
170
  )
176
171
  progress.close()
177
172
 
sonusai/genmix.py CHANGED
@@ -1,14 +1,15 @@
1
1
  """sonusai genmix
2
2
 
3
- usage: genmix [-hvgts] [-i MIXID] LOC
3
+ usage: genmix [-hvgtsn] [-i MIXID] LOC
4
4
 
5
5
  options:
6
6
  -h, --help
7
- -v, --verbose Be verbose.
8
- -i MIXID, --mixid MIXID Mixture ID(s) to generate. [default: *].
9
- -g, --target Save target. [default: False].
10
- -t, --truth Save truth_t. [default: False].
11
- -s, --segsnr Save segsnr_t. [default: False].
7
+ -v, --verbose Be verbose.
8
+ -i MIXID, --mixid MIXID Mixture ID(s) to generate. [default: *].
9
+ -g, --target Save target. [default: False].
10
+ -t, --truth Save truth_t. [default: False].
11
+ -s, --segsnr Save segsnr_t. [default: False].
12
+ -n, --nopar Do not run in parallel. [default: False].
12
13
 
13
14
  Generate SonusAI mixture data from a SonusAI mixture database.
14
15
 
@@ -29,11 +30,9 @@ Outputs the following to the mixture database directory:
29
30
  """
30
31
 
31
32
  import signal
32
- from dataclasses import dataclass
33
33
 
34
34
  from sonusai.mixture import GeneralizedIDs
35
35
  from sonusai.mixture import GenMixData
36
- from sonusai.mixture import MixtureDatabase
37
36
 
38
37
 
39
38
  def signal_handler(_sig, _frame):
@@ -48,21 +47,8 @@ def signal_handler(_sig, _frame):
48
47
  signal.signal(signal.SIGINT, signal_handler)
49
48
 
50
49
 
51
- @dataclass
52
- class MPGlobal:
53
- mixdb: MixtureDatabase
54
- save_target: bool
55
- compute_truth: bool
56
- compute_segsnr: bool
57
- force: bool
58
- write: bool
59
-
60
-
61
- MP_GLOBAL: MPGlobal
62
-
63
-
64
50
  def genmix(
65
- mixdb: MixtureDatabase,
51
+ location: str,
66
52
  mixids: GeneralizedIDs = "*",
67
53
  save_target: bool = False,
68
54
  compute_truth: bool = False,
@@ -70,92 +56,92 @@ def genmix(
70
56
  write: bool = False,
71
57
  show_progress: bool = False,
72
58
  force: bool = True,
59
+ no_par: bool = False,
73
60
  ) -> list[GenMixData]:
61
+ from functools import partial
62
+
63
+ from sonusai.mixture import MixtureDatabase
74
64
  from sonusai.utils import par_track
75
65
  from sonusai.utils import track
76
66
 
67
+ mixdb = MixtureDatabase(location)
77
68
  mixids = mixdb.mixids_to_list(mixids)
78
69
  progress = track(total=len(mixids), disable=not show_progress)
79
70
  results = par_track(
80
- _genmix_kernel,
71
+ partial(
72
+ _genmix_kernel,
73
+ location=location,
74
+ save_target=save_target,
75
+ compute_truth=compute_truth,
76
+ compute_segsnr=compute_segsnr,
77
+ force=force,
78
+ write=write,
79
+ ),
81
80
  mixids,
82
- initializer=_genmix_initializer,
83
- initargs=(mixdb, save_target, compute_truth, compute_segsnr, force, write),
84
81
  progress=progress,
82
+ no_par=no_par,
85
83
  )
86
84
  progress.close()
87
85
 
88
86
  return results
89
87
 
90
88
 
91
- def _genmix_initializer(
92
- mixdb: MixtureDatabase,
89
+ def _genmix_kernel(
90
+ m_id: int,
91
+ location: str,
93
92
  save_target: bool,
94
93
  compute_truth: bool,
95
94
  compute_segsnr: bool,
96
95
  force: bool,
97
96
  write: bool,
98
- ) -> None:
99
- global MP_GLOBAL
100
-
101
- MP_GLOBAL = MPGlobal(
102
- mixdb=mixdb,
103
- save_target=save_target,
104
- compute_truth=compute_truth,
105
- compute_segsnr=compute_segsnr,
106
- force=force,
107
- write=write,
108
- )
109
-
110
-
111
- def _genmix_kernel(m_id: int) -> GenMixData:
97
+ ) -> GenMixData:
98
+ from sonusai.mixture import MixtureDatabase
112
99
  from sonusai.mixture import write_cached_data
113
100
  from sonusai.mixture import write_mixture_metadata
114
101
 
115
- global MP_GLOBAL
102
+ mixdb = MixtureDatabase(location)
116
103
 
117
- mixdb = MP_GLOBAL.mixdb
118
- save_target = MP_GLOBAL.save_target
119
- compute_truth = MP_GLOBAL.compute_truth
120
- compute_segsnr = MP_GLOBAL.compute_segsnr
121
- force = MP_GLOBAL.force
122
- write = MP_GLOBAL.write
104
+ result = GenMixData()
123
105
 
124
106
  targets = mixdb.mixture_targets(m_id=m_id, force=force)
107
+ result.targets = targets
125
108
  noise = mixdb.mixture_noise(m_id=m_id, force=force)
126
- write_data = [("targets", targets), ("noise", noise)]
109
+ result.noise = noise
110
+ if write:
111
+ write_cached_data(
112
+ mixdb.location,
113
+ "mixture",
114
+ mixdb.mixture(m_id).name,
115
+ [
116
+ ("targets", targets),
117
+ ("noise", noise),
118
+ ],
119
+ )
127
120
 
128
121
  if compute_truth:
129
122
  truth_t = mixdb.mixture_truth_t(m_id=m_id, targets=targets, noise=noise, force=force)
130
- write_data.append(("truth_t", truth_t))
131
- else:
132
- truth_t = None
123
+ result.truth_t = truth_t
124
+ if write:
125
+ write_cached_data(mixdb.location, "mixture", mixdb.mixture(m_id).name, [("truth_t", truth_t)])
133
126
 
134
127
  target = mixdb.mixture_target(m_id=m_id, targets=targets)
135
- if save_target:
136
- write_data.append(("target", target))
128
+ result.target = target
129
+ if save_target and write:
130
+ write_cached_data(mixdb.location, "mixture", mixdb.mixture(m_id).name, [("target", target)])
137
131
 
138
132
  if compute_segsnr:
139
133
  segsnr_t = mixdb.mixture_segsnr_t(m_id=m_id, targets=targets, target=target, noise=noise, force=force)
140
- write_data.append(("segsnr_t", segsnr_t))
141
- else:
142
- segsnr_t = None
134
+ result.segsnr_t = segsnr_t
135
+ if write:
136
+ write_cached_data(mixdb.location, "mixture", mixdb.mixture(m_id).name, [("segsnr_t", segsnr_t)])
143
137
 
144
138
  mixture = mixdb.mixture_mixture(m_id=m_id, targets=targets, target=target, noise=noise, force=force)
145
- write_data.append(("mixture", mixture))
146
-
139
+ result.mixture = mixture
147
140
  if write:
148
- write_cached_data(mixdb.location, "mixture", mixdb.mixture(m_id).name, write_data)
149
- write_mixture_metadata(mixdb, mixdb.mixture(m_id))
150
-
151
- return GenMixData(
152
- targets=targets,
153
- target=target,
154
- noise=noise,
155
- mixture=mixture,
156
- truth_t=truth_t,
157
- segsnr_t=segsnr_t,
158
- )
141
+ write_cached_data(mixdb.location, "mixture", mixdb.mixture(m_id).name, [("mixture", mixture)])
142
+ write_mixture_metadata(mixdb, m_id)
143
+
144
+ return result
159
145
 
160
146
 
161
147
  def main() -> None:
@@ -173,6 +159,8 @@ def main() -> None:
173
159
  from sonusai import initial_log_messages
174
160
  from sonusai import logger
175
161
  from sonusai import update_console_handler
162
+ from sonusai.mixture import SAMPLE_RATE
163
+ from sonusai.mixture import MixtureDatabase
176
164
  from sonusai.mixture import check_audio_files_exist
177
165
  from sonusai.utils import human_readable_size
178
166
  from sonusai.utils import seconds_to_hms
@@ -183,6 +171,7 @@ def main() -> None:
183
171
  save_target = args["--target"]
184
172
  compute_truth = args["--truth"]
185
173
  compute_segsnr = args["--segsnr"]
174
+ no_par = args["--nopar"]
186
175
 
187
176
  start_time = time.monotonic()
188
177
 
@@ -195,7 +184,7 @@ def main() -> None:
195
184
  mixids = mixdb.mixids_to_list(mixids)
196
185
 
197
186
  total_samples = mixdb.total_samples(mixids)
198
- duration = total_samples / sonusai.mixture.SAMPLE_RATE
187
+ duration = total_samples / SAMPLE_RATE
199
188
 
200
189
  logger.info("")
201
190
  logger.info(f"Found {len(mixids):,} mixtures to process")
@@ -205,13 +194,14 @@ def main() -> None:
205
194
 
206
195
  try:
207
196
  genmix(
208
- mixdb=mixdb,
197
+ location=location,
209
198
  mixids=mixids,
210
199
  save_target=save_target,
211
200
  compute_truth=compute_truth,
212
201
  compute_segsnr=compute_segsnr,
213
202
  write=True,
214
203
  show_progress=True,
204
+ no_par=no_par,
215
205
  )
216
206
  except Exception as e:
217
207
  logger.debug(e)