sonusai-1.0.16-cp311-abi3-macosx_10_12_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. sonusai/__init__.py +170 -0
  2. sonusai/aawscd_probwrite.py +148 -0
  3. sonusai/audiofe.py +481 -0
  4. sonusai/calc_metric_spenh.py +1136 -0
  5. sonusai/config/__init__.py +0 -0
  6. sonusai/config/asr.py +21 -0
  7. sonusai/config/config.py +65 -0
  8. sonusai/config/config.yml +49 -0
  9. sonusai/config/constants.py +53 -0
  10. sonusai/config/ir.py +124 -0
  11. sonusai/config/ir_delay.py +62 -0
  12. sonusai/config/source.py +275 -0
  13. sonusai/config/spectral_masks.py +15 -0
  14. sonusai/config/truth.py +64 -0
  15. sonusai/constants.py +14 -0
  16. sonusai/data/__init__.py +0 -0
  17. sonusai/data/silero_vad_v5.1.jit +0 -0
  18. sonusai/data/silero_vad_v5.1.onnx +0 -0
  19. sonusai/data/speech_ma01_01.wav +0 -0
  20. sonusai/data/whitenoise.wav +0 -0
  21. sonusai/datatypes.py +383 -0
  22. sonusai/deprecated/gentcst.py +632 -0
  23. sonusai/deprecated/plot.py +519 -0
  24. sonusai/deprecated/tplot.py +365 -0
  25. sonusai/doc.py +52 -0
  26. sonusai/doc_strings/__init__.py +1 -0
  27. sonusai/doc_strings/doc_strings.py +531 -0
  28. sonusai/genft.py +196 -0
  29. sonusai/genmetrics.py +183 -0
  30. sonusai/genmix.py +199 -0
  31. sonusai/genmixdb.py +235 -0
  32. sonusai/ir_metric.py +551 -0
  33. sonusai/lsdb.py +141 -0
  34. sonusai/main.py +134 -0
  35. sonusai/metrics/__init__.py +43 -0
  36. sonusai/metrics/calc_audio_stats.py +42 -0
  37. sonusai/metrics/calc_class_weights.py +90 -0
  38. sonusai/metrics/calc_optimal_thresholds.py +73 -0
  39. sonusai/metrics/calc_pcm.py +45 -0
  40. sonusai/metrics/calc_pesq.py +36 -0
  41. sonusai/metrics/calc_phase_distance.py +43 -0
  42. sonusai/metrics/calc_sa_sdr.py +64 -0
  43. sonusai/metrics/calc_sample_weights.py +25 -0
  44. sonusai/metrics/calc_segsnr_f.py +82 -0
  45. sonusai/metrics/calc_speech.py +382 -0
  46. sonusai/metrics/calc_wer.py +71 -0
  47. sonusai/metrics/calc_wsdr.py +57 -0
  48. sonusai/metrics/calculate_metrics.py +395 -0
  49. sonusai/metrics/class_summary.py +74 -0
  50. sonusai/metrics/confusion_matrix_summary.py +75 -0
  51. sonusai/metrics/one_hot.py +283 -0
  52. sonusai/metrics/snr_summary.py +128 -0
  53. sonusai/metrics_summary.py +314 -0
  54. sonusai/mixture/__init__.py +15 -0
  55. sonusai/mixture/audio.py +187 -0
  56. sonusai/mixture/class_balancing.py +103 -0
  57. sonusai/mixture/constants.py +3 -0
  58. sonusai/mixture/data_io.py +173 -0
  59. sonusai/mixture/db.py +169 -0
  60. sonusai/mixture/db_datatypes.py +92 -0
  61. sonusai/mixture/effects.py +344 -0
  62. sonusai/mixture/feature.py +78 -0
  63. sonusai/mixture/generation.py +1116 -0
  64. sonusai/mixture/helpers.py +351 -0
  65. sonusai/mixture/ir_effects.py +77 -0
  66. sonusai/mixture/log_duration_and_sizes.py +23 -0
  67. sonusai/mixture/mixdb.py +1857 -0
  68. sonusai/mixture/pad_audio.py +35 -0
  69. sonusai/mixture/resample.py +7 -0
  70. sonusai/mixture/sox_effects.py +195 -0
  71. sonusai/mixture/sox_help.py +650 -0
  72. sonusai/mixture/spectral_mask.py +51 -0
  73. sonusai/mixture/truth.py +61 -0
  74. sonusai/mixture/truth_functions/__init__.py +45 -0
  75. sonusai/mixture/truth_functions/crm.py +105 -0
  76. sonusai/mixture/truth_functions/energy.py +222 -0
  77. sonusai/mixture/truth_functions/file.py +48 -0
  78. sonusai/mixture/truth_functions/metadata.py +24 -0
  79. sonusai/mixture/truth_functions/metrics.py +28 -0
  80. sonusai/mixture/truth_functions/phoneme.py +18 -0
  81. sonusai/mixture/truth_functions/sed.py +98 -0
  82. sonusai/mixture/truth_functions/target.py +142 -0
  83. sonusai/mkwav.py +135 -0
  84. sonusai/onnx_predict.py +363 -0
  85. sonusai/parse/__init__.py +0 -0
  86. sonusai/parse/expand.py +156 -0
  87. sonusai/parse/parse_source_directive.py +129 -0
  88. sonusai/parse/rand.py +214 -0
  89. sonusai/py.typed +0 -0
  90. sonusai/queries/__init__.py +0 -0
  91. sonusai/queries/queries.py +239 -0
  92. sonusai/rs.abi3.so +0 -0
  93. sonusai/rs.pyi +1 -0
  94. sonusai/rust/__init__.py +0 -0
  95. sonusai/speech/__init__.py +0 -0
  96. sonusai/speech/l2arctic.py +121 -0
  97. sonusai/speech/librispeech.py +102 -0
  98. sonusai/speech/mcgill.py +71 -0
  99. sonusai/speech/textgrid.py +89 -0
  100. sonusai/speech/timit.py +138 -0
  101. sonusai/speech/types.py +12 -0
  102. sonusai/speech/vctk.py +53 -0
  103. sonusai/speech/voxceleb.py +108 -0
  104. sonusai/utils/__init__.py +3 -0
  105. sonusai/utils/asl_p56.py +130 -0
  106. sonusai/utils/asr.py +91 -0
  107. sonusai/utils/asr_functions/__init__.py +3 -0
  108. sonusai/utils/asr_functions/aaware_whisper.py +69 -0
  109. sonusai/utils/audio_devices.py +50 -0
  110. sonusai/utils/braced_glob.py +50 -0
  111. sonusai/utils/calculate_input_shape.py +26 -0
  112. sonusai/utils/choice.py +51 -0
  113. sonusai/utils/compress.py +25 -0
  114. sonusai/utils/convert_string_to_number.py +6 -0
  115. sonusai/utils/create_timestamp.py +5 -0
  116. sonusai/utils/create_ts_name.py +14 -0
  117. sonusai/utils/dataclass_from_dict.py +27 -0
  118. sonusai/utils/db.py +16 -0
  119. sonusai/utils/docstring.py +53 -0
  120. sonusai/utils/energy_f.py +44 -0
  121. sonusai/utils/engineering_number.py +166 -0
  122. sonusai/utils/evaluate_random_rule.py +15 -0
  123. sonusai/utils/get_frames_per_batch.py +2 -0
  124. sonusai/utils/get_label_names.py +20 -0
  125. sonusai/utils/grouper.py +6 -0
  126. sonusai/utils/human_readable_size.py +7 -0
  127. sonusai/utils/keyboard_interrupt.py +12 -0
  128. sonusai/utils/load_object.py +21 -0
  129. sonusai/utils/max_text_width.py +9 -0
  130. sonusai/utils/model_utils.py +28 -0
  131. sonusai/utils/numeric_conversion.py +11 -0
  132. sonusai/utils/onnx_utils.py +155 -0
  133. sonusai/utils/parallel.py +162 -0
  134. sonusai/utils/path_info.py +7 -0
  135. sonusai/utils/print_mixture_details.py +60 -0
  136. sonusai/utils/rand.py +13 -0
  137. sonusai/utils/ranges.py +43 -0
  138. sonusai/utils/read_predict_data.py +32 -0
  139. sonusai/utils/reshape.py +154 -0
  140. sonusai/utils/seconds_to_hms.py +7 -0
  141. sonusai/utils/stacked_complex.py +82 -0
  142. sonusai/utils/stratified_shuffle_split.py +170 -0
  143. sonusai/utils/tokenized_shell_vars.py +143 -0
  144. sonusai/utils/write_audio.py +26 -0
  145. sonusai/utils/yes_or_no.py +8 -0
  146. sonusai/vars.py +47 -0
  147. sonusai-1.0.16.dist-info/METADATA +56 -0
  148. sonusai-1.0.16.dist-info/RECORD +150 -0
  149. sonusai-1.0.16.dist-info/WHEEL +4 -0
  150. sonusai-1.0.16.dist-info/entry_points.txt +3 -0
sonusai/utils/keyboard_interrupt.py ADDED
@@ -0,0 +1,12 @@
+ def register_keyboard_interrupt() -> None:
+     import signal
+
+     def signal_handler(_sig, _frame):
+         import sys
+
+         from sonusai import logger
+
+         logger.info("Canceled due to keyboard interrupt")
+         sys.exit(1)
+
+     signal.signal(signal.SIGINT, signal_handler)
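A minimal usage sketch for the handler above; the entry point and the work loop are hypothetical, only register_keyboard_interrupt comes from the file shown:

    # Hypothetical script; only register_keyboard_interrupt is from sonusai.
    from sonusai.utils.keyboard_interrupt import register_keyboard_interrupt

    def main() -> None:
        register_keyboard_interrupt()  # Ctrl-C now logs a message and exits with code 1
        for _ in range(100_000_000):
            pass  # stand-in for long-running work

    if __name__ == "__main__":
        main()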
sonusai/utils/load_object.py ADDED
@@ -0,0 +1,21 @@
+ from functools import lru_cache
+ from typing import Any
+
+
+ def load_object(name: str, use_cache: bool = True) -> Any:
+     """Load an object from a pickle file"""
+     if use_cache:
+         return _load_object(name)
+     return _load_object.__wrapped__(name)
+
+
+ @lru_cache
+ def _load_object(name: str) -> Any:
+     import pickle
+     from os.path import exists
+
+     if exists(name):
+         with open(name, "rb") as f:
+             return pickle.load(f)  # noqa: S301
+
+     raise FileNotFoundError(name)
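A small usage sketch, assuming a pickle file exists at the (hypothetical) path shown:

    from sonusai.utils.load_object import load_object

    # First call unpickles the file; repeated calls with the same name hit the lru_cache.
    stats = load_object("results/audio_stats.pkl")
    # use_cache=False calls the undecorated function via __wrapped__ and re-reads the file.
    fresh = load_object("results/audio_stats.pkl", use_cache=False)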
sonusai/utils/max_text_width.py ADDED
@@ -0,0 +1,9 @@
+ def max_text_width(number_of_items: int) -> int:
+     """Compute maximum text width for the indices of a sequence of items.
+
+     :param number_of_items: Total number of items in sequence
+     :return: Text width of largest item index
+     """
+     import numpy as np
+
+     return int(np.ceil(np.log10(number_of_items)))
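A quick check of the helper above:

    from sonusai.utils.max_text_width import max_text_width

    # ceil(log10(1000)) == 3: wide enough to align the zero-based indices 0..999
    assert max_text_width(1000) == 3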
sonusai/utils/model_utils.py ADDED
@@ -0,0 +1,28 @@
+ from typing import Any
+
+
+ def import_module(name: str) -> Any:
+     """Import a Python module adding the module file's directory to the Python system path so that relative package
+     imports are found correctly.
+     """
+     import os
+     import sys
+     from importlib import import_module
+
+     try:
+         path = os.path.dirname(name)
+         if len(path) < 1:
+             path = "./"
+
+         # Add model file location to system path
+         sys.path.append(os.path.abspath(path))
+
+         try:
+             root = os.path.splitext(os.path.basename(name))[0]
+             model = import_module(root)
+         except Exception as e:
+             raise OSError(f"Error: could not import model from {name}: {e}.") from e
+     except Exception as e:
+         raise OSError(f"Error: could not find {name}: {e}.") from e
+
+     return model
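A usage sketch, assuming a hypothetical model definition file; the helper puts the file's directory on sys.path and imports the module by its basename:

    from sonusai.utils.model_utils import import_module

    # "my_model.py" and the MyModel attribute are hypothetical examples.
    model_module = import_module("/path/to/models/my_model.py")
    MyModel = model_module.MyModel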
sonusai/utils/numeric_conversion.py ADDED
@@ -0,0 +1,11 @@
+ import numpy as np
+
+
+ def int16_to_float(x: np.ndarray) -> np.ndarray:
+     """Convert an int16 array to a floating point array with range +/- 1"""
+     return x.astype(np.float32) / 32768
+
+
+ def float_to_int16(x: np.ndarray) -> np.ndarray:
+     """Convert a floating point array with range +/- 1 to an int16 array"""
+     return (x * 32768).astype(np.int16)
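A round-trip sketch for the two converters above:

    import numpy as np

    from sonusai.utils.numeric_conversion import float_to_int16, int16_to_float

    pcm = np.array([-32768, 0, 16384], dtype=np.int16)
    x = int16_to_float(pcm)                        # [-1.0, 0.0, 0.5] as float32
    assert np.array_equal(float_to_int16(x), pcm)  # lossless round trip for these values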
sonusai/utils/onnx_utils.py ADDED
@@ -0,0 +1,155 @@
+ from collections.abc import Sequence
+
+ from onnx import ModelProto
+ from onnx import ValueInfoProto
+ from onnxruntime import InferenceSession
+ from onnxruntime import NodeArg  # pyright: ignore [reportAttributeAccessIssue]
+ from onnxruntime import SessionOptions  # pyright: ignore [reportAttributeAccessIssue]
+
+ REQUIRED_HPARAMS = ("feature", "batch_size", "timesteps")
+
+
+ def _extract_shapes(io: list[ValueInfoProto]) -> list[list[int] | str]:
+     shapes: list[list[int] | str] = []
+
+     # iterate through inputs of the graph to find shapes
+     for item in io:
+         # get tensor type: 0, 1, 2, etc.
+         tensor_type = item.type.tensor_type
+         # check if it has a shape
+         if tensor_type.HasField("shape"):
+             tmp_shape = []
+             # iterate through dimensions of the shape
+             for d in tensor_type.shape.dim:
+                 if d.HasField("dim_value"):
+                     # known dimension, int value
+                     tmp_shape.append(d.dim_value)
+                 elif d.HasField("dim_param"):
+                     # dynamic dim with symbolic name of d.dim_param; set size to 0
+                     tmp_shape.append(0)
+                 else:
+                     # unknown dimension with no name; also set to 0
+                     tmp_shape.append(0)
+             # add as a list
+             shapes.append(tmp_shape)
+         else:
+             shapes.append("unknown rank")
+
+     return shapes
+
+
+ def get_and_check_inputs(model: ModelProto) -> tuple[list[ValueInfoProto], list[list[int] | str]]:
+     from sonusai import logger
+
+     # ignore initializer inputs (only seen in older ONNX < v1.5)
+     initializer_names = [x.name for x in model.graph.initializer]
+     inputs = [i for i in model.graph.input if i.name not in initializer_names]
+     if len(inputs) != 1:
+         logger.warning(f"Warning: ONNX model has {len(inputs)} inputs; expected only 1")
+
+     # This one-liner works only if input has type and shape, returns a list
+     # shape0 = [d.dim_value for d in inputs[0].type.tensor_type.shape.dim]
+     shapes = _extract_shapes(inputs)
+
+     return inputs, shapes
+
+
+ def get_and_check_outputs(model: ModelProto) -> tuple[list[ValueInfoProto], list[list[int] | str]]:
+     from sonusai import logger
+
+     outputs = list(model.graph.output)
+     if len(outputs) != 1:
+         logger.warning(f"Warning: ONNX model has {len(outputs)} outputs; expected only 1")
+
+     shapes = _extract_shapes(outputs)
+
+     return outputs, shapes
+
+
+ def add_sonusai_metadata(model: ModelProto, hparams: dict) -> ModelProto:
+     """Add SonusAI hyperparameters as metadata to an ONNX model using 'hparams' key
+
+     :param model: ONNX model
+     :param hparams: dictionary of hyperparameters to add
+     :return: ONNX model
+
+     Note SonusAI conventions require models to have:
+         feature: Model feature type
+         batch_size: Model batch size
+         timesteps: Size of timestep dimension (0 for no dimension)
+     """
+     from sonusai import logger
+
+     # Note hparams should be a dict (i.e., extracted from checkpoint)
+     if eval(str(hparams)) != hparams:  # noqa: S307
+         raise TypeError("hparams is not a dict")
+     for key in REQUIRED_HPARAMS:
+         if key not in hparams:
+             logger.warning(f"Warning: SonusAI hyperparameters are missing: {key}")
+
+     meta = model.metadata_props.add()
+     meta.key = "hparams"
+     meta.value = str(hparams)
+
+     return model
+
+
+ def get_sonusai_metadata(session: InferenceSession) -> dict | None:
+     """Get SonusAI hyperparameter metadata from an ONNX Runtime session."""
+     from sonusai import logger
+
+     meta = session.get_modelmeta()
+     if "hparams" not in meta.custom_metadata_map:
+         logger.warning("Warning: ONNX model metadata does not contain 'hparams'")
+         return None
+
+     hparams = eval(meta.custom_metadata_map["hparams"])  # noqa: S307
+     for key in REQUIRED_HPARAMS:
+         if key not in hparams:
+             logger.warning(f"Warning: ONNX model does not have required SonusAI hyperparameters: {key}")
+
+     return hparams
+
+
+ def load_ort_session(
+     model_path: str, providers: Sequence[str | tuple[str, dict]] | None = None
+ ) -> tuple[InferenceSession, SessionOptions, str, dict | None, list[NodeArg], list[NodeArg]]:
+     from os.path import basename
+     from os.path import exists
+     from os.path import isfile
+     from os.path import splitext
+
+     import onnxruntime as ort
+
+     from sonusai import logger
+
+     if providers is None:
+         providers = ["CPUExecutionProvider"]
+
+     if exists(model_path) and isfile(model_path):
+         model_basename = basename(model_path)
+         model_root = splitext(model_basename)[0]
+         logger.info(f"Importing model from {model_basename}")
+         try:
+             session = ort.InferenceSession(model_path, providers=providers)
+             options = ort.SessionOptions()
+         except Exception as e:
+             logger.exception(f"Error: could not load ONNX model from {model_path}: {e}")
+             raise SystemExit(1) from e
+     else:
+         logger.exception(f"Error: model file does not exist: {model_path}")
+         raise SystemExit(1)
+
+     logger.info(f"Opened session with provider options: {session._provider_options}.")
+     hparams = get_sonusai_metadata(session)
+     if hparams is not None:
+         for key in REQUIRED_HPARAMS:
+             logger.info(f"  {key:12} {hparams[key]}")
+
+     inputs = session.get_inputs()
+     outputs = session.get_outputs()
+
+     # in_names = [n.name for n in session.get_inputs()]
+     # out_names = [n.name for n in session.get_outputs()]
+
+     return session, options, model_root, hparams, inputs, outputs
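A minimal sketch of loading a model with load_ort_session; the model path is hypothetical and providers defaults to CPU:

    from sonusai.utils.onnx_utils import load_ort_session

    # "model.onnx" is a hypothetical file exported with add_sonusai_metadata().
    session, options, model_root, hparams, inputs, outputs = load_ort_session("model.onnx")
    if hparams is not None:
        print(hparams["feature"], hparams["batch_size"], hparams["timesteps"])
    print([(node.name, node.shape) for node in inputs])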
sonusai/utils/parallel.py ADDED
@@ -0,0 +1,162 @@
+ import warnings
+ from collections.abc import Callable
+ from collections.abc import Iterable
+ from multiprocessing import current_process
+ from multiprocessing import get_context
+ from typing import Any
+
+ from tqdm import TqdmExperimentalWarning
+ from tqdm.rich import tqdm
+
+ warnings.filterwarnings(action="ignore", category=TqdmExperimentalWarning)
+
+ track = tqdm
+
+ CONTEXT = "fork"
+
+
+ def par_track(
+     func: Callable,
+     *iterables: Iterable,
+     initializer: Callable[..., None] | None = None,
+     initargs: Iterable[Any] | None = None,
+     progress: tqdm | None = None,
+     num_cpus: int | float | None = None,
+     total: int | None = None,
+     no_par: bool = False,
+     pass_index: bool = False,
+ ) -> list[Any]:
+     """Performs a parallel ordered imap with tqdm progress."""
+     total_items = _calculate_total_items(iterables, total)
+     results: list[Any] = [None] * total_items
+
+     if no_par or current_process().daemon:
+         _execute_sequential(func, iterables, initializer, initargs, results, progress, pass_index)
+     else:
+         cpu_count = _determine_cpu_count(num_cpus, total_items)
+         _execute_parallel(func, iterables, initializer, initargs, results, progress, pass_index, cpu_count)
+
+     if progress is not None:
+         progress.close()
+     return results
+
+
+ def _calculate_total_items(iterables: tuple[Iterable, ...], total: int | None) -> int:
+     """Calculate the total number of items to process."""
+     from collections.abc import Sized
+
+     if total is None:
+         return min(len(iterable) for iterable in iterables if isinstance(iterable, Sized))
+     return int(total)
+
+
+ def _cpu_count() -> int:
+     """Get the number of CPUs available."""
+     from psutil import cpu_count
+
+     count = cpu_count()
+     if count is None:
+         return 1
+     return count
+
+
+ def _determine_cpu_count(num_cpus: int | float | None, total_items: int) -> int:
+     """Determine the optimal number of CPUs to use."""
+     if num_cpus is None:
+         # Reserve 2 CPUs for system, minimum 1
+         optimal_cpus = max(_cpu_count() - 2, 1)
+     elif isinstance(num_cpus, float):
+         optimal_cpus = int(round(num_cpus * _cpu_count()))
+     else:
+         optimal_cpus = int(num_cpus)
+
+     return min(optimal_cpus, total_items)
+
+
+ def _create_indexed_iterables(iterables: tuple[Iterable, ...]) -> tuple[Iterable, ...]:
+     """Create iterables that include the index as the first argument."""
+     # Get the first iterable to enumerate over
+     first_iterable = iterables[0]
+     remaining_iterables = iterables[1:]
+
+     # Create an enumerated version: (index, first_item), second_item, third_item, ...
+     indexed_first = enumerate(first_iterable)
+
+     if remaining_iterables:
+         return (indexed_first,) + remaining_iterables
+     else:
+         return (indexed_first,)
+
+
+ class _IndexedFunctionWrapper:
+     """Pickle-able wrapper class for functions that need an index as the first argument."""
+
+     def __init__(self, func: Callable):
+         self.func = func
+
+     def __call__(self, indexed_first_arg, *remaining_args):
+         index, first_arg = indexed_first_arg
+         return self.func(index, first_arg, *remaining_args)
+
+
+ def _wrap_function_with_index(func: Callable) -> Callable:
+     """Wrap a function to handle indexed arguments."""
+     return _IndexedFunctionWrapper(func)
+
+
+ def _execute_sequential(
+     func: Callable,
+     iterables: tuple[Iterable, ...],
+     initializer: Callable[..., None] | None,
+     initargs: Iterable[Any] | None,
+     results: list[Any],
+     progress: tqdm | None,
+     pass_index: bool,
+ ) -> None:
+     """Execute a function sequentially without using multiprocessing."""
+     if initializer is not None:
+         if initargs is not None:
+             initializer(*initargs)
+         else:
+             initializer()
+
+     if pass_index:
+         mapped_iterables = _create_indexed_iterables(iterables)
+         wrapped_func = _wrap_function_with_index(func)
+         iterator = map(wrapped_func, *mapped_iterables)
+     else:
+         iterator = map(func, *iterables)
+
+     for index, result in enumerate(iterator):
+         results[index] = result
+         if progress is not None:
+             progress.update()
+
+
+ def _execute_parallel(
+     func: Callable,
+     iterables: tuple[Iterable, ...],
+     initializer: Callable[..., None] | None,
+     initargs: Iterable[Any] | None,
+     results: list[Any],
+     progress: tqdm | None,
+     pass_index: bool,
+     cpu_count: int,
+ ) -> None:
+     """Execute a function in parallel using multiprocessing."""
+     init_args = initargs if initargs is not None else []
+
+     if pass_index:
+         mapped_iterables = _create_indexed_iterables(iterables)
+         wrapped_func = _wrap_function_with_index(func)
+     else:
+         mapped_iterables = iterables
+         wrapped_func = func
+
+     with get_context(CONTEXT).Pool(processes=cpu_count, initializer=initializer, initargs=init_args) as pool:
+         for index, result in enumerate(pool.imap(wrapped_func, *mapped_iterables, chunksize=1)):
+             results[index] = result
+             if progress is not None:
+                 progress.update()
+         pool.close()
+         pool.join()
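A usage sketch for par_track with the worker function as a hypothetical example; results come back in input order and par_track closes the progress bar when done:

    from sonusai.utils.parallel import par_track, track

    def square(x: int) -> int:
        # hypothetical worker; any picklable top-level function works
        return x * x

    if __name__ == "__main__":
        items = list(range(100))
        progress = track(total=len(items))
        results = par_track(square, items, progress=progress)
        assert results == [x * x for x in items]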
sonusai/utils/path_info.py ADDED
@@ -0,0 +1,7 @@
+ from dataclasses import dataclass
+
+
+ @dataclass(frozen=True)
+ class PathInfo:
+     abs_path: str
+     audio_filepath: str
sonusai/utils/print_mixture_details.py ADDED
@@ -0,0 +1,60 @@
+ from collections.abc import Callable
+
+ from ..datatypes import ClassCount
+ from ..mixture.helpers import mixture_all_speech_metadata
+ from ..mixture.mixdb import MixtureDatabase
+
+
+ def print_mixture_details(
+     mixdb: MixtureDatabase,
+     mixid: int | None = None,
+     print_fn: Callable = print,
+ ) -> None:
+     from ..utils.seconds_to_hms import seconds_to_hms
+
+     if mixid is not None:
+         if mixid < 0 or mixid >= mixdb.num_mixtures:
+             raise ValueError(f"Given mixid is outside valid range of 0:{mixdb.num_mixtures - 1}.")
+
+         print_fn(f"Mixture {mixid} details")
+         mixture = mixdb.mixture(mixid)
+         speech_metadata = mixture_all_speech_metadata(mixdb, mixture)
+         for category, source in mixture.all_sources.items():
+             source_file = mixdb.source_file(source.file_id)
+             print_fn(f"  {category}")
+             print_fn(f"    name: {source_file.name}")
+             print_fn(f"    effects: {source.effects.to_dict()}")
+             print_fn(f"    pre_tempo: {source.pre_tempo}")
+             print_fn(f"    duration: {seconds_to_hms(source_file.duration)}")
+             print_fn(f"    start: {source.start}")
+             print_fn(f"    repeat: {source.loop}")
+             print_fn(f"    snr: {source.snr}")
+             print_fn(f"    random_snr: {source.snr.is_random}")
+             print_fn(f"    snr_gain: {source.snr_gain}")
+             for key in source_file.truth_configs:
+                 print_fn(f"    truth '{key}' function: {source_file.truth_configs[key].function}")
+                 print_fn(f"    truth '{key}' config: {source_file.truth_configs[key].config}")
+                 print_fn(
+                     f"    truth '{key}' stride_reduction: {source_file.truth_configs[key].stride_reduction}"
+                 )
+             for key in speech_metadata[category]:
+                 print_fn(f"{category} speech {key}: {speech_metadata[category][key]}")
+         print_fn(f"  samples: {mixture.samples}")
+         print_fn(f"  feature frames: {mixdb.mixture_feature_frames(mixid)}")
+         print_fn("")
+
+
+ def print_class_count(
+     class_count: ClassCount,
+     length: int,
+     print_fn: Callable = print,
+     all_class_counts: bool = False,
+ ) -> None:
+     from ..utils.max_text_width import max_text_width
+
+     print_fn("Class count:")
+     idx_len = max_text_width(len(class_count))
+     for idx, count in enumerate(class_count):
+         if all_class_counts or count > 0:
+             desc = f"  class {idx + 1:{idx_len}}"
+             print_fn(f"{desc:{length}} {count}")
sonusai/utils/rand.py ADDED
@@ -0,0 +1,13 @@
+ import contextlib
+
+
+ @contextlib.contextmanager
+ def seed_context(seed):
+     import numpy as np
+
+     state = np.random.get_state()
+     np.random.seed(seed)
+     try:
+         yield
+     finally:
+         np.random.set_state(state)
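A short sketch of seed_context: draws inside the block are reproducible, and the surrounding global NumPy random state is restored afterwards:

    import numpy as np

    from sonusai.utils.rand import seed_context

    with seed_context(1234):
        a = np.random.uniform(size=3)
    with seed_context(1234):
        b = np.random.uniform(size=3)

    assert np.array_equal(a, b)  # same seed, same draws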
sonusai/utils/ranges.py ADDED
@@ -0,0 +1,43 @@
+ def expand_range(s: str, sort: bool = True) -> list[int]:
+     """Returns a list of integers from a string input representing a range."""
+     import re
+
+     clean_s = s.replace(":", "-")
+     clean_s = clean_s.replace(";", ",")
+     clean_s = re.sub(r" +", ",", clean_s)
+     clean_s = re.sub(r",+", ",", clean_s)
+
+     r: list[int] = []
+     for i in clean_s.split(","):
+         if "-" not in i:
+             r.append(int(i))
+         else:
+             lo, hi = map(int, i.split("-"))
+             r += range(lo, hi + 1)
+
+     if sort:
+         r = sorted(r)
+
+     return r
+
+
+ def consolidate_range(r: list[int]) -> str:
+     """Returns a string representing a range from an input list of integers."""
+     from collections.abc import Generator
+
+     def ranges(i: list[int]) -> Generator[tuple[int, int], None, None]:
+         import itertools
+
+         for _, b in itertools.groupby(enumerate(i), lambda pair: pair[1] - pair[0]):
+             b_list = list(b)
+             yield b_list[0][1], b_list[-1][1]
+
+     ls: list[tuple[int, int]] = list(ranges(r))
+     result: list[str] = []
+     for val in ls:
+         entry = str(val[0])
+         if val[0] != val[1]:
+             entry += f"-{val[1]}"
+         result.append(entry)
+
+     return ", ".join(result)
sonusai/utils/read_predict_data.py ADDED
@@ -0,0 +1,32 @@
+ import numpy as np
+
+ from ..datatypes import Predict
+
+
+ def read_predict_data(filename: str) -> Predict:
+     """Read predict data from given HDF5 file and return it."""
+     import h5py
+
+     from .. import logger
+
+     logger.debug(f"Reading prediction data from {filename}")
+     with h5py.File(filename, "r") as f:
+         # prediction data is either [frames, num_classes], or [frames, timesteps, num_classes]
+         predict = np.array(f["predict"])
+
+     if predict.ndim == 2:
+         return predict
+
+     if predict.ndim == 3:
+         frames, timesteps, num_classes = predict.shape
+
+         logger.debug(
+             f"Reshaping prediction data in {filename} "
+             f""
+             f"from [{frames}, {timesteps}, {num_classes}] "
+             f"to [{frames * timesteps}, {num_classes}]"
+         )
+         predict = np.reshape(predict, [frames * timesteps, num_classes], order="F")
+         return predict
+
+     raise RuntimeError(f"Invalid prediction data dimensions in {filename}")