sonusai 1.0.16__cp311-abi3-macosx_10_12_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +170 -0
- sonusai/aawscd_probwrite.py +148 -0
- sonusai/audiofe.py +481 -0
- sonusai/calc_metric_spenh.py +1136 -0
- sonusai/config/__init__.py +0 -0
- sonusai/config/asr.py +21 -0
- sonusai/config/config.py +65 -0
- sonusai/config/config.yml +49 -0
- sonusai/config/constants.py +53 -0
- sonusai/config/ir.py +124 -0
- sonusai/config/ir_delay.py +62 -0
- sonusai/config/source.py +275 -0
- sonusai/config/spectral_masks.py +15 -0
- sonusai/config/truth.py +64 -0
- sonusai/constants.py +14 -0
- sonusai/data/__init__.py +0 -0
- sonusai/data/silero_vad_v5.1.jit +0 -0
- sonusai/data/silero_vad_v5.1.onnx +0 -0
- sonusai/data/speech_ma01_01.wav +0 -0
- sonusai/data/whitenoise.wav +0 -0
- sonusai/datatypes.py +383 -0
- sonusai/deprecated/gentcst.py +632 -0
- sonusai/deprecated/plot.py +519 -0
- sonusai/deprecated/tplot.py +365 -0
- sonusai/doc.py +52 -0
- sonusai/doc_strings/__init__.py +1 -0
- sonusai/doc_strings/doc_strings.py +531 -0
- sonusai/genft.py +196 -0
- sonusai/genmetrics.py +183 -0
- sonusai/genmix.py +199 -0
- sonusai/genmixdb.py +235 -0
- sonusai/ir_metric.py +551 -0
- sonusai/lsdb.py +141 -0
- sonusai/main.py +134 -0
- sonusai/metrics/__init__.py +43 -0
- sonusai/metrics/calc_audio_stats.py +42 -0
- sonusai/metrics/calc_class_weights.py +90 -0
- sonusai/metrics/calc_optimal_thresholds.py +73 -0
- sonusai/metrics/calc_pcm.py +45 -0
- sonusai/metrics/calc_pesq.py +36 -0
- sonusai/metrics/calc_phase_distance.py +43 -0
- sonusai/metrics/calc_sa_sdr.py +64 -0
- sonusai/metrics/calc_sample_weights.py +25 -0
- sonusai/metrics/calc_segsnr_f.py +82 -0
- sonusai/metrics/calc_speech.py +382 -0
- sonusai/metrics/calc_wer.py +71 -0
- sonusai/metrics/calc_wsdr.py +57 -0
- sonusai/metrics/calculate_metrics.py +395 -0
- sonusai/metrics/class_summary.py +74 -0
- sonusai/metrics/confusion_matrix_summary.py +75 -0
- sonusai/metrics/one_hot.py +283 -0
- sonusai/metrics/snr_summary.py +128 -0
- sonusai/metrics_summary.py +314 -0
- sonusai/mixture/__init__.py +15 -0
- sonusai/mixture/audio.py +187 -0
- sonusai/mixture/class_balancing.py +103 -0
- sonusai/mixture/constants.py +3 -0
- sonusai/mixture/data_io.py +173 -0
- sonusai/mixture/db.py +169 -0
- sonusai/mixture/db_datatypes.py +92 -0
- sonusai/mixture/effects.py +344 -0
- sonusai/mixture/feature.py +78 -0
- sonusai/mixture/generation.py +1116 -0
- sonusai/mixture/helpers.py +351 -0
- sonusai/mixture/ir_effects.py +77 -0
- sonusai/mixture/log_duration_and_sizes.py +23 -0
- sonusai/mixture/mixdb.py +1857 -0
- sonusai/mixture/pad_audio.py +35 -0
- sonusai/mixture/resample.py +7 -0
- sonusai/mixture/sox_effects.py +195 -0
- sonusai/mixture/sox_help.py +650 -0
- sonusai/mixture/spectral_mask.py +51 -0
- sonusai/mixture/truth.py +61 -0
- sonusai/mixture/truth_functions/__init__.py +45 -0
- sonusai/mixture/truth_functions/crm.py +105 -0
- sonusai/mixture/truth_functions/energy.py +222 -0
- sonusai/mixture/truth_functions/file.py +48 -0
- sonusai/mixture/truth_functions/metadata.py +24 -0
- sonusai/mixture/truth_functions/metrics.py +28 -0
- sonusai/mixture/truth_functions/phoneme.py +18 -0
- sonusai/mixture/truth_functions/sed.py +98 -0
- sonusai/mixture/truth_functions/target.py +142 -0
- sonusai/mkwav.py +135 -0
- sonusai/onnx_predict.py +363 -0
- sonusai/parse/__init__.py +0 -0
- sonusai/parse/expand.py +156 -0
- sonusai/parse/parse_source_directive.py +129 -0
- sonusai/parse/rand.py +214 -0
- sonusai/py.typed +0 -0
- sonusai/queries/__init__.py +0 -0
- sonusai/queries/queries.py +239 -0
- sonusai/rs.abi3.so +0 -0
- sonusai/rs.pyi +1 -0
- sonusai/rust/__init__.py +0 -0
- sonusai/speech/__init__.py +0 -0
- sonusai/speech/l2arctic.py +121 -0
- sonusai/speech/librispeech.py +102 -0
- sonusai/speech/mcgill.py +71 -0
- sonusai/speech/textgrid.py +89 -0
- sonusai/speech/timit.py +138 -0
- sonusai/speech/types.py +12 -0
- sonusai/speech/vctk.py +53 -0
- sonusai/speech/voxceleb.py +108 -0
- sonusai/utils/__init__.py +3 -0
- sonusai/utils/asl_p56.py +130 -0
- sonusai/utils/asr.py +91 -0
- sonusai/utils/asr_functions/__init__.py +3 -0
- sonusai/utils/asr_functions/aaware_whisper.py +69 -0
- sonusai/utils/audio_devices.py +50 -0
- sonusai/utils/braced_glob.py +50 -0
- sonusai/utils/calculate_input_shape.py +26 -0
- sonusai/utils/choice.py +51 -0
- sonusai/utils/compress.py +25 -0
- sonusai/utils/convert_string_to_number.py +6 -0
- sonusai/utils/create_timestamp.py +5 -0
- sonusai/utils/create_ts_name.py +14 -0
- sonusai/utils/dataclass_from_dict.py +27 -0
- sonusai/utils/db.py +16 -0
- sonusai/utils/docstring.py +53 -0
- sonusai/utils/energy_f.py +44 -0
- sonusai/utils/engineering_number.py +166 -0
- sonusai/utils/evaluate_random_rule.py +15 -0
- sonusai/utils/get_frames_per_batch.py +2 -0
- sonusai/utils/get_label_names.py +20 -0
- sonusai/utils/grouper.py +6 -0
- sonusai/utils/human_readable_size.py +7 -0
- sonusai/utils/keyboard_interrupt.py +12 -0
- sonusai/utils/load_object.py +21 -0
- sonusai/utils/max_text_width.py +9 -0
- sonusai/utils/model_utils.py +28 -0
- sonusai/utils/numeric_conversion.py +11 -0
- sonusai/utils/onnx_utils.py +155 -0
- sonusai/utils/parallel.py +162 -0
- sonusai/utils/path_info.py +7 -0
- sonusai/utils/print_mixture_details.py +60 -0
- sonusai/utils/rand.py +13 -0
- sonusai/utils/ranges.py +43 -0
- sonusai/utils/read_predict_data.py +32 -0
- sonusai/utils/reshape.py +154 -0
- sonusai/utils/seconds_to_hms.py +7 -0
- sonusai/utils/stacked_complex.py +82 -0
- sonusai/utils/stratified_shuffle_split.py +170 -0
- sonusai/utils/tokenized_shell_vars.py +143 -0
- sonusai/utils/write_audio.py +26 -0
- sonusai/utils/yes_or_no.py +8 -0
- sonusai/vars.py +47 -0
- sonusai-1.0.16.dist-info/METADATA +56 -0
- sonusai-1.0.16.dist-info/RECORD +150 -0
- sonusai-1.0.16.dist-info/WHEEL +4 -0
- sonusai-1.0.16.dist-info/entry_points.txt +3 -0
@@ -0,0 +1,12 @@
|
|
1
|
+
def register_keyboard_interrupt() -> None:
    """Install a SIGINT (Ctrl-C) handler that logs a cancel message and exits with status 1.

    After registration a keyboard interrupt no longer raises ``KeyboardInterrupt``; the handler
    logs "Canceled due to keyboard interrupt" through the SonusAI logger and calls ``sys.exit(1)``.
    Like any ``signal.signal`` call, this must run in the main thread.
    """
    import signal

    def _on_sigint(_sig, _frame):
        # Imports are deferred until an interrupt actually arrives.
        import sys

        from sonusai import logger

        logger.info("Canceled due to keyboard interrupt")
        sys.exit(1)

    signal.signal(signal.SIGINT, _on_sigint)
|
@@ -0,0 +1,21 @@
|
|
1
|
+
from functools import lru_cache
|
2
|
+
from typing import Any
|
3
|
+
|
4
|
+
|
5
|
+
def load_object(name: str, use_cache: bool = True) -> Any:
    """Load an object from a pickle file.

    :param name: Path to the pickle file
    :param use_cache: When True (default), return the memoized result of an earlier load of the same path;
                      when False, re-read the file, bypassing the cache without consulting or updating it
    :return: The unpickled object
    :raises FileNotFoundError: If ``name`` does not exist
    """
    loader = _load_object if use_cache else _load_object.__wrapped__
    return loader(name)


@lru_cache
def _load_object(name: str) -> Any:
    """Unpickle the file at ``name``; memoized per path by ``functools.lru_cache`` (default maxsize 128).

    Cached results are shared objects, so mutating a returned value affects later cached loads.
    """
    import pickle
    from os.path import exists

    if not exists(name):
        raise FileNotFoundError(name)

    # SECURITY: unpickling can execute arbitrary code; only load files from trusted sources.
    with open(name, "rb") as f:
        return pickle.load(f)  # noqa: S301
|
@@ -0,0 +1,9 @@
|
|
1
|
+
def max_text_width(number_of_items: int) -> int:
    """Compute maximum text width for the indices of a sequence of items.

    Indices are 0-based, so the largest index is ``number_of_items - 1``. The width is computed
    with integer arithmetic (no floating-point logarithm), so it is exact for arbitrarily large
    counts. An empty or single-item sequence yields the minimum width of 1.

    :param number_of_items: Total number of items in sequence
    :return: Text width of largest item index (at least 1)
    :raises ValueError: If number_of_items is negative
    """
    if number_of_items < 0:
        raise ValueError(f"number_of_items must be non-negative, got {number_of_items}")

    largest_index = max(int(number_of_items) - 1, 0)
    return len(str(largest_index))
|
@@ -0,0 +1,28 @@
|
|
1
|
+
from typing import Any
|
2
|
+
|
3
|
+
|
4
|
+
def import_module(name: str) -> Any:
    """Import a Python module adding the module file's directory to the Python system path so that relative package
    imports are found correctly.

    :param name: Path to the module file (e.g. ``path/to/model.py``); the module is imported by its base
                 name without extension
    :return: The imported module object
    :raises OSError: If the path cannot be processed or the module cannot be imported; the original
                     exception is chained as ``__cause__``
    """
    import importlib
    import os
    import sys

    try:
        module_dir = os.path.abspath(os.path.dirname(name) or "./")
        root = os.path.splitext(os.path.basename(name))[0]
    except Exception as e:
        raise OSError(f"Error: could not find {name}: {e}.") from e

    # Add model file location to system path (only once, so repeated imports don't grow sys.path)
    if module_dir not in sys.path:
        sys.path.append(module_dir)

    # Kept outside the path-handling try so this error is not re-wrapped as "could not find".
    try:
        return importlib.import_module(root)
    except Exception as e:
        raise OSError(f"Error: could not import model from {name}: {e}.") from e
|
@@ -0,0 +1,11 @@
|
|
1
|
+
import numpy as np
|
2
|
+
|
3
|
+
|
4
|
+
def int16_to_float(x: np.ndarray) -> np.ndarray:
    """Scale int16 samples to float32 values in the range [-1, 1).

    Multiplying by 2**-15 is an exact power-of-two scaling, so the result is identical to
    dividing by 32768.
    """
    scale = np.float32(1.0 / 32768)
    return x.astype(np.float32) * scale
|
7
|
+
|
8
|
+
|
9
|
+
def float_to_int16(x: np.ndarray) -> np.ndarray:
    """Convert a floating point array with range +/- 1 to an int16 array.

    Values are scaled by 32768 and saturated to the int16 range [-32768, 32767] before the cast.
    As a result, full-scale positive input (e.g. 1.0) maps to 32767 instead of wrapping around to -32768.
    The integer cast truncates fractional results toward zero, so int16 samples scaled by 1/32768
    convert back exactly. NaN inputs are not handled specially.
    """
    scaled = np.asarray(x) * 32768
    # Saturate first: casting out-of-range floats to int16 does not clip in NumPy.
    return np.clip(scaled, -32768, 32767).astype(np.int16)
|
@@ -0,0 +1,155 @@
|
|
1
|
+
from collections.abc import Sequence
|
2
|
+
|
3
|
+
from onnx import ModelProto
|
4
|
+
from onnx import ValueInfoProto
|
5
|
+
from onnxruntime import InferenceSession
|
6
|
+
from onnxruntime import NodeArg # pyright: ignore [reportAttributeAccessIssue]
|
7
|
+
from onnxruntime import SessionOptions # pyright: ignore [reportAttributeAccessIssue]
|
8
|
+
|
9
|
+
# Hyperparameter keys SonusAI conventions require in a model's 'hparams' metadata.
# The functions in this module warn (rather than fail) when one of these keys is missing.
REQUIRED_HPARAMS = ("feature", "batch_size", "timesteps")
|
10
|
+
|
11
|
+
|
12
|
+
def _extract_shapes(io: list[ValueInfoProto]) -> list[list[int] | str]:
    """Return the shape of each ONNX value-info entry.

    If an entry's tensor type has a shape, the result is a list of its dimension sizes. Any
    dimension without a concrete ``dim_value`` (a symbolic ``dim_param`` or a completely unknown
    dimension) is reported as 0. If the tensor type has no shape, the result is the string "unknown rank".
    """

    def dims_of(tensor_shape) -> list[int]:
        # Concrete sizes pass through; dynamic or unnamed dimensions collapse to 0.
        return [d.dim_value if d.HasField("dim_value") else 0 for d in tensor_shape.dim]

    shapes: list[list[int] | str] = []
    for item in io:
        tensor_type = item.type.tensor_type
        if not tensor_type.HasField("shape"):
            shapes.append("unknown rank")
            continue
        shapes.append(dims_of(tensor_type.shape))

    return shapes
|
39
|
+
|
40
|
+
|
41
|
+
def get_and_check_inputs(model: ModelProto) -> tuple[list[ValueInfoProto], list[list[int] | str]]:
    """Return the graph inputs of an ONNX model (excluding initializers) and their shapes.

    Logs a warning when the model does not have exactly one such input.

    :param model: ONNX model
    :return: Tuple of (input value-infos, their shapes as produced by ``_extract_shapes``)
    """
    from sonusai import logger

    # Older ONNX (< v1.5) also lists initializers as graph inputs; drop them.
    initializer_names = {init.name for init in model.graph.initializer}
    inputs = [vi for vi in model.graph.input if vi.name not in initializer_names]

    count = len(inputs)
    if count != 1:
        logger.warning(f"Warning: ONNX model has {count} inputs; expected only 1")

    return inputs, _extract_shapes(inputs)
|
55
|
+
|
56
|
+
|
57
|
+
def get_and_check_outputs(model: ModelProto) -> tuple[list[ValueInfoProto], list[list[int] | str]]:
    """Return the graph outputs of an ONNX model and their shapes.

    Logs a warning when the model does not have exactly one output.

    :param model: ONNX model
    :return: Tuple of (output value-infos, their shapes as produced by ``_extract_shapes``)
    """
    from sonusai import logger

    graph_outputs = [*model.graph.output]
    count = len(graph_outputs)
    if count != 1:
        logger.warning(f"Warning: ONNX model has {count} outputs; expected only 1")

    return graph_outputs, _extract_shapes(graph_outputs)
|
67
|
+
|
68
|
+
|
69
|
+
def add_sonusai_metadata(model: ModelProto, hparams: dict) -> ModelProto:
    """Add SonusAI hyperparameters as metadata to an ONNX model using 'hparams' key

    :param model: ONNX model
    :param hparams: dictionary of hyperparameters to add
    :return: ONNX model
    :raises TypeError: If hparams is not a dict, or if its string form does not parse back to an
                       equal dict

    Note SonusAI conventions require models to have:
        feature: Model feature type
        batch_size: Model batch size
        timesteps: Size of timestep dimension (0 for no dimension)

    The dictionary is stored as ``str(hparams)``. It must therefore contain only Python literals,
    so that it can be parsed back with ``ast.literal_eval``. An existing 'hparams' entry is
    overwritten rather than duplicated.
    """
    import ast

    from sonusai import logger

    # Note hparams should be a dict (i.e., extracted from checkpoint). Check the type directly
    # instead of eval()-ing str(hparams), which would execute code if hparams were a string.
    if not isinstance(hparams, dict):
        raise TypeError("hparams is not a dict")

    value = str(hparams)
    try:
        round_trips = ast.literal_eval(value) == hparams
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as e:
        raise TypeError(f"hparams cannot be stored as a literal dict: {e}") from e
    if not round_trips:
        raise TypeError("hparams does not survive a round trip through str()")

    for key in REQUIRED_HPARAMS:
        if key not in hparams:
            logger.warning(f"Warning: SonusAI hyperparameters are missing: {key}")

    # Update an existing 'hparams' entry instead of appending a duplicate key
    # (onnx.checker flags duplicate metadata_props keys).
    for prop in model.metadata_props:
        if prop.key == "hparams":
            prop.value = value
            break
    else:
        meta = model.metadata_props.add()
        meta.key = "hparams"
        meta.value = value

    return model
|
95
|
+
|
96
|
+
|
97
|
+
def get_sonusai_metadata(session: InferenceSession) -> dict | None:
    """Get SonusAI hyperparameter metadata from an ONNX Runtime session.

    The 'hparams' metadata value is parsed with ``ast.literal_eval`` instead of ``eval``. A model
    file from an untrusted source therefore cannot execute code when its metadata is read.

    :param session: ONNX Runtime inference session
    :return: Hyperparameter dict, or None if the metadata is missing, unparsable, or not a dict
    """
    import ast

    from sonusai import logger

    meta = session.get_modelmeta()
    if "hparams" not in meta.custom_metadata_map:
        logger.warning("Warning: ONNX model metadata does not contain 'hparams'")
        return None

    try:
        hparams = ast.literal_eval(meta.custom_metadata_map["hparams"])
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError) as e:
        logger.warning(f"Warning: ONNX model 'hparams' metadata could not be parsed: {e}")
        return None

    if not isinstance(hparams, dict):
        logger.warning("Warning: ONNX model 'hparams' metadata is not a dict")
        return None

    for key in REQUIRED_HPARAMS:
        if key not in hparams:
            logger.warning(f"Warning: ONNX model does not have required SonusAI hyperparameters: {key}")

    return hparams
|
112
|
+
|
113
|
+
|
114
|
+
def load_ort_session(
    model_path: str, providers: Sequence[str | tuple[str, dict]] | None = None
) -> tuple[InferenceSession, SessionOptions, str, dict | None, list[NodeArg], list[NodeArg]]:
    """Load an ONNX model into an ONNX Runtime inference session.

    :param model_path: Path to the ONNX model file
    :param providers: Execution providers passed to ``InferenceSession`` (default: CPUExecutionProvider only)
    :return: Tuple of (session, the session's options, model file name without extension,
             SonusAI hparams or None, session inputs, session outputs)
    :raises SystemExit: If the file does not exist or the model cannot be loaded
    """
    from os.path import basename
    from os.path import isfile
    from os.path import splitext

    import onnxruntime as ort

    from sonusai import logger

    if providers is None:
        providers = ["CPUExecutionProvider"]

    # isfile() is False for missing paths, so it also covers the existence check.
    if not isfile(model_path):
        # No exception is active here; logger.exception() would log a bogus "NoneType: None" traceback.
        logger.error(f"Error: model file does not exist: {model_path}")
        raise SystemExit(1)

    model_basename = basename(model_path)
    model_root = splitext(model_basename)[0]
    logger.info(f"Importing model from {model_basename}")
    try:
        # Create the options first and pass them in, so the returned object is the one the session uses.
        options = ort.SessionOptions()
        session = ort.InferenceSession(model_path, sess_options=options, providers=providers)
    except Exception as e:
        logger.exception(f"Error: could not load ONNX model from {model_path}: {e}")
        raise SystemExit(1) from e

    logger.info(f"Opened session with provider options: {session.get_provider_options()}.")
    hparams = get_sonusai_metadata(session)
    if hparams is not None:
        for key in REQUIRED_HPARAMS:
            # Missing keys are only warned about upstream, so don't index unconditionally.
            logger.info(f" {key:12} {hparams.get(key)}")

    inputs = session.get_inputs()
    outputs = session.get_outputs()

    return session, options, model_root, hparams, inputs, outputs
|
@@ -0,0 +1,162 @@
|
|
1
|
+
import warnings
|
2
|
+
from collections.abc import Callable
|
3
|
+
from collections.abc import Iterable
|
4
|
+
from multiprocessing import current_process
|
5
|
+
from multiprocessing import get_context
|
6
|
+
from typing import Any
|
7
|
+
|
8
|
+
from tqdm import TqdmExperimentalWarning
|
9
|
+
from tqdm.rich import tqdm
|
10
|
+
|
11
|
+
# tqdm's rich-based progress bar (tqdm.rich, imported above) is marked experimental and emits
# TqdmExperimentalWarning; silence that warning category process-wide.
warnings.filterwarnings(action="ignore", category=TqdmExperimentalWarning)

# Module-level alias for the rich-based tqdm progress-bar class.
track = tqdm

# multiprocessing start method for the worker pool created in _execute_parallel.
# NOTE(review): "fork" does not exist on Windows, and Python considers it unsafe on macOS
# (where the default start method has been "spawn" since 3.8). Confirm this is intended for all target platforms.
CONTEXT = "fork"
|
16
|
+
|
17
|
+
|
18
|
+
def par_track(
    func: Callable,
    *iterables: Iterable,
    initializer: Callable[..., None] | None = None,
    initargs: Iterable[Any] | None = None,
    progress: tqdm | None = None,
    num_cpus: int | float | None = None,
    total: int | None = None,
    no_par: bool = False,
    pass_index: bool = False,
) -> list[Any]:
    """Performs a parallel ordered imap with tqdm progress.

    :param func: Called with one item from each iterable (preceded by the item's index when pass_index is True)
    :param iterables: Input iterables, consumed in lockstep
    :param initializer: Optional callable run once per worker (once in-process when running sequentially)
    :param initargs: Arguments for initializer
    :param progress: Optional progress bar, advanced once per result and always closed before returning
    :param num_cpus: None for all CPUs but two, a float for that fraction of the CPUs, or an int count;
                     never more workers than items
    :param total: Number of results to collect (default: length of the shortest sized iterable)
    :param no_par: Run sequentially in the current process
    :param pass_index: Call func(index, item, ...) where index is the position in the first iterable
    :return: Results in input order
    """
    try:
        total_items = _calculate_total_items(iterables, total)
        results: list[Any] = [None] * total_items

        if no_par or current_process().daemon:
            # Daemonic processes may not create child processes, so nested calls run sequentially.
            _execute_sequential(func, iterables, initializer, initargs, results, progress, pass_index)
        else:
            cpu_count = _determine_cpu_count(num_cpus, total_items)
            _execute_parallel(func, iterables, initializer, initargs, results, progress, pass_index, cpu_count)
    finally:
        # Close the bar even if func raised, so the rich live display is torn down.
        if progress is not None:
            progress.close()

    return results
|
42
|
+
|
43
|
+
|
44
|
+
def _calculate_total_items(iterables: tuple[Iterable, ...], total: int | None) -> int:
|
45
|
+
"""Calculate the total number of items to process."""
|
46
|
+
from collections.abc import Sized
|
47
|
+
|
48
|
+
if total is None:
|
49
|
+
return min(len(iterable) for iterable in iterables if isinstance(iterable, Sized))
|
50
|
+
return int(total)
|
51
|
+
|
52
|
+
|
53
|
+
def _cpu_count() -> int:
    """Return psutil's logical CPU count, or 1 when psutil cannot determine it."""
    import psutil

    detected = psutil.cpu_count()
    return 1 if detected is None else detected
|
61
|
+
|
62
|
+
|
63
|
+
def _determine_cpu_count(num_cpus: int | float | None, total_items: int) -> int:
|
64
|
+
"""Determine the optimal number of CPUs to use."""
|
65
|
+
if num_cpus is None:
|
66
|
+
# Reserve 2 CPUs for system, minimum 1
|
67
|
+
optimal_cpus = max(_cpu_count() - 2, 1)
|
68
|
+
elif isinstance(num_cpus, float):
|
69
|
+
optimal_cpus = int(round(num_cpus * _cpu_count()))
|
70
|
+
else:
|
71
|
+
optimal_cpus = int(num_cpus)
|
72
|
+
|
73
|
+
return min(optimal_cpus, total_items)
|
74
|
+
|
75
|
+
|
76
|
+
def _create_indexed_iterables(iterables: tuple[Iterable, ...]) -> tuple[Iterable, ...]:
    """Pair each item of the first iterable with its position.

    Returns a tuple whose first element is ``enumerate(iterables[0])``, which yields ``(index, item)``.
    The remaining iterables follow unchanged, ready to be splatted into map/imap.
    """
    # Index rather than unpack, so an empty tuple still raises IndexError.
    return (enumerate(iterables[0]), *iterables[1:])
|
89
|
+
|
90
|
+
|
91
|
+
class _IndexedFunctionWrapper:
    """Picklable callable that expands an ``(index, item)`` pair into two leading arguments.

    It is a module-level class rather than a closure so that multiprocessing can pickle it.
    """

    def __init__(self, func: Callable):
        # User function, later invoked as func(index, item, *rest).
        self.func = func

    def __call__(self, indexed_first_arg, *remaining_args):
        position, item = indexed_first_arg
        target = self.func
        return target(position, item, *remaining_args)


def _wrap_function_with_index(func: Callable) -> Callable:
    """Return ``func`` wrapped so it accepts an ``(index, item)`` pair as its first argument."""
    wrapper = _IndexedFunctionWrapper(func)
    return wrapper
|
105
|
+
|
106
|
+
|
107
|
+
def _execute_sequential(
    func: Callable,
    iterables: tuple[Iterable, ...],
    initializer: Callable[..., None] | None,
    initargs: Iterable[Any] | None,
    results: list[Any],
    progress: tqdm | None,
    pass_index: bool,
) -> None:
    """Run ``func`` over ``iterables`` in the current process, filling ``results`` in input order.

    The optional initializer is called once up front, mirroring a pool worker's initializer.
    ``progress`` is advanced by one after each result is stored.
    """
    if initializer is not None:
        args = () if initargs is None else initargs
        initializer(*args)

    if pass_index:
        arg_streams = _create_indexed_iterables(iterables)
        call = _wrap_function_with_index(func)
    else:
        arg_streams = iterables
        call = func

    position = 0
    for value in map(call, *arg_streams):
        results[position] = value
        position += 1
        if progress is not None:
            progress.update()
|
134
|
+
|
135
|
+
|
136
|
+
class _StarCall:
    """Picklable adapter that calls ``func(*args)`` on one zipped argument tuple.

    ``Pool.imap`` accepts a single iterable, so multiple input iterables are zipped into
    argument tuples, which this adapter unpacks inside the worker.
    """

    def __init__(self, func: Callable):
        self.func = func

    def __call__(self, args: tuple) -> Any:
        return self.func(*args)


def _execute_parallel(
    func: Callable,
    iterables: tuple[Iterable, ...],
    initializer: Callable[..., None] | None,
    initargs: Iterable[Any] | None,
    results: list[Any],
    progress: tqdm | None,
    pass_index: bool,
    cpu_count: int,
) -> None:
    """Execute a function in parallel using multiprocessing.

    Uses a pool of ``cpu_count`` processes started with the ``CONTEXT`` start method. ``imap``
    yields results in submission order, so ``results`` is filled in input order, and ``progress``
    advances once per result.
    """
    init_args = initargs if initargs is not None else []

    if pass_index:
        mapped_iterables = _create_indexed_iterables(iterables)
        wrapped_func = _wrap_function_with_index(func)
    else:
        mapped_iterables = iterables
        wrapped_func = func

    # Pool.imap takes exactly one iterable (a second positional argument would collide with
    # chunksize), so zip the inputs and unpack each tuple in the worker. zip stops at the
    # shortest input, matching the builtin map() used by the sequential path.
    star_func = _StarCall(wrapped_func)
    arg_tuples = zip(*mapped_iterables)

    with get_context(CONTEXT).Pool(processes=cpu_count, initializer=initializer, initargs=init_args) as pool:
        for index, result in enumerate(pool.imap(star_func, arg_tuples, chunksize=1)):
            results[index] = result
            if progress is not None:
                progress.update()
        pool.close()
        pool.join()
|
@@ -0,0 +1,60 @@
|
|
1
|
+
from collections.abc import Callable
|
2
|
+
|
3
|
+
from ..datatypes import ClassCount
|
4
|
+
from ..mixture.helpers import mixture_all_speech_metadata
|
5
|
+
from ..mixture.mixdb import MixtureDatabase
|
6
|
+
|
7
|
+
|
8
|
+
def print_mixture_details(
    mixdb: MixtureDatabase,
    mixid: int | None = None,
    print_fn: Callable = print,
) -> None:
    """Print the sources, truth configs, speech metadata, and sizes of one mixture.

    :param mixdb: Mixture database
    :param mixid: Index of the mixture to describe; nothing is printed when None
    :param print_fn: Function used to emit each line (default: print)
    :raises ValueError: If mixid is outside the range 0 to mixdb.num_mixtures - 1
    """
    from ..utils.seconds_to_hms import seconds_to_hms

    if mixid is None:
        return

    # Reject negative ids as well as ids past the end.
    if not 0 <= mixid < mixdb.num_mixtures:
        raise ValueError(f"Given mixid is outside valid range of 0:{mixdb.num_mixtures - 1}.")

    print_fn(f"Mixture {mixid} details")
    mixture = mixdb.mixture(mixid)
    speech_metadata = mixture_all_speech_metadata(mixdb, mixture)
    for category, source in mixture.all_sources.items():
        source_file = mixdb.source_file(source.file_id)
        print_fn(f" {category}")
        print_fn(f" name: {source_file.name}")
        print_fn(f" effects: {source.effects.to_dict()}")
        print_fn(f" pre_tempo: {source.pre_tempo}")
        print_fn(f" duration: {seconds_to_hms(source_file.duration)}")
        print_fn(f" start: {source.start}")
        print_fn(f" repeat: {source.loop}")
        print_fn(f" snr: {source.snr}")
        print_fn(f" random_snr: {source.snr.is_random}")
        print_fn(f" snr_gain: {source.snr_gain}")
        for key in source_file.truth_configs:
            print_fn(f" truth '{key}' function: {source_file.truth_configs[key].function}")
            print_fn(f" truth '{key}' config: {source_file.truth_configs[key].config}")
            print_fn(
                f" truth '{key}' stride_reduction: {source_file.truth_configs[key].stride_reduction}"
            )
        for key in speech_metadata[category]:
            print_fn(f"{category} speech {key}: {speech_metadata[category][key]}")
    print_fn(f" samples: {mixture.samples}")
    print_fn(f" feature frames: {mixdb.mixture_feature_frames(mixid)}")
    print_fn("")
|
45
|
+
|
46
|
+
|
47
|
+
def print_class_count(
    class_count: ClassCount,
    length: int,
    print_fn: Callable = print,
    all_class_counts: bool = False,
) -> None:
    """Print per-class counts using 1-based class labels.

    :param class_count: Count for each class, in class-index order
    :param length: Width to which each " class N" label is padded
    :param print_fn: Function used to emit each line (default: print)
    :param all_class_counts: When True, also print classes whose count is 0
    """
    from ..utils.max_text_width import max_text_width

    print_fn("Class count:")
    # Labels are 1-based (idx + 1), so size the field for the largest label, len(class_count),
    # rather than the largest 0-based index; otherwise e.g. "10" overflows a 1-character field.
    idx_len = max_text_width(len(class_count) + 1)
    for idx, count in enumerate(class_count):
        if all_class_counts or count > 0:
            desc = f" class {idx + 1:{idx_len}}"
            print_fn(f"{desc:{length}} {count}")
|
sonusai/utils/rand.py
ADDED
sonusai/utils/ranges.py
ADDED
@@ -0,0 +1,43 @@
|
|
1
|
+
def expand_range(s: str, sort: bool = True) -> list[int]:
    """Returns a list of integers from a string input representing a range.

    Items are separated by commas, semicolons, or whitespace. A range is written ``lo-hi`` or
    ``lo:hi`` and is inclusive; whitespace around the dash is allowed. Empty items, such as those
    produced by a trailing comma, are ignored. Negative numbers are not supported because '-' denotes a range.

    Example: ``expand_range("1-3, 5")`` returns ``[1, 2, 3, 5]``.

    :param s: Range specification
    :param sort: Sort the result (default True); otherwise keep input order
    :return: Expanded list of integers
    :raises ValueError: If an item is neither an integer nor an ``lo-hi`` range
    """
    import re

    clean_s = s.replace(":", "-").replace(";", ",")
    # Join "1 - 3" into "1-3" before whitespace is treated as a separator.
    clean_s = re.sub(r"\s*-\s*", "-", clean_s)

    r: list[int] = []
    for item in re.split(r"[,\s]+", clean_s):
        if not item:
            # Leading/trailing separators yield empty items; skip them instead of failing on int("").
            continue
        if "-" not in item:
            r.append(int(item))
        else:
            lo, hi = map(int, item.split("-"))
            r.extend(range(lo, hi + 1))

    if sort:
        r.sort()

    return r
|
22
|
+
|
23
|
+
|
24
|
+
def consolidate_range(r: list[int]) -> str:
    """Returns a string representing a range from an input list of integers.

    Each run of consecutive values (each one greater than the previous by 1) becomes ``lo-hi``,
    and a lone value is written as-is. Runs are joined with ", " in input order; the input is not sorted.
    """
    spans: list[tuple[int, int]] = []
    for value in r:
        if spans and value == spans[-1][1] + 1:
            # Extend the current run.
            spans[-1] = (spans[-1][0], value)
        else:
            spans.append((value, value))

    return ", ".join(str(lo) + ("" if lo == hi else f"-{hi}") for lo, hi in spans)
|
@@ -0,0 +1,32 @@
|
|
1
|
+
import numpy as np
|
2
|
+
|
3
|
+
from ..datatypes import Predict
|
4
|
+
|
5
|
+
|
6
|
+
def read_predict_data(filename: str) -> Predict:
|
7
|
+
"""Read predict data from given HDF5 file and return it."""
|
8
|
+
import h5py
|
9
|
+
|
10
|
+
from .. import logger
|
11
|
+
|
12
|
+
logger.debug(f"Reading prediction data from {filename}")
|
13
|
+
with h5py.File(filename, "r") as f:
|
14
|
+
# prediction data is either [frames, num_classes], or [frames, timesteps, num_classes]
|
15
|
+
predict = np.array(f["predict"])
|
16
|
+
|
17
|
+
if predict.ndim == 2:
|
18
|
+
return predict
|
19
|
+
|
20
|
+
if predict.ndim == 3:
|
21
|
+
frames, timesteps, num_classes = predict.shape
|
22
|
+
|
23
|
+
logger.debug(
|
24
|
+
f"Reshaping prediction data in {filename} "
|
25
|
+
f""
|
26
|
+
f"from [{frames}, {timesteps}, {num_classes}] "
|
27
|
+
f"to [{frames * timesteps}, {num_classes}]"
|
28
|
+
)
|
29
|
+
predict = np.reshape(predict, [frames * timesteps, num_classes], order="F")
|
30
|
+
return predict
|
31
|
+
|
32
|
+
raise RuntimeError(f"Invalid prediction data dimensions in {filename}")
|