sonusai 0.18.9__py3-none-any.whl → 0.19.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +20 -29
- sonusai/aawscd_probwrite.py +18 -18
- sonusai/audiofe.py +93 -80
- sonusai/calc_metric_spenh.py +395 -321
- sonusai/data/genmixdb.yml +5 -11
- sonusai/{gentcst.py → deprecated/gentcst.py} +146 -149
- sonusai/{plot.py → deprecated/plot.py} +177 -131
- sonusai/{tplot.py → deprecated/tplot.py} +124 -102
- sonusai/doc/__init__.py +1 -1
- sonusai/doc/doc.py +112 -177
- sonusai/doc.py +10 -10
- sonusai/genft.py +81 -91
- sonusai/genmetrics.py +51 -61
- sonusai/genmix.py +105 -115
- sonusai/genmixdb.py +201 -174
- sonusai/lsdb.py +56 -66
- sonusai/main.py +23 -20
- sonusai/metrics/__init__.py +2 -0
- sonusai/metrics/calc_audio_stats.py +29 -24
- sonusai/metrics/calc_class_weights.py +7 -7
- sonusai/metrics/calc_optimal_thresholds.py +5 -7
- sonusai/metrics/calc_pcm.py +3 -3
- sonusai/metrics/calc_pesq.py +10 -7
- sonusai/metrics/calc_phase_distance.py +3 -3
- sonusai/metrics/calc_sa_sdr.py +10 -8
- sonusai/metrics/calc_segsnr_f.py +16 -18
- sonusai/metrics/calc_speech.py +105 -47
- sonusai/metrics/calc_wer.py +35 -32
- sonusai/metrics/calc_wsdr.py +10 -7
- sonusai/metrics/class_summary.py +30 -27
- sonusai/metrics/confusion_matrix_summary.py +25 -22
- sonusai/metrics/one_hot.py +91 -57
- sonusai/metrics/snr_summary.py +53 -46
- sonusai/mixture/__init__.py +20 -14
- sonusai/mixture/audio.py +4 -6
- sonusai/mixture/augmentation.py +37 -43
- sonusai/mixture/class_count.py +5 -14
- sonusai/mixture/config.py +292 -225
- sonusai/mixture/constants.py +41 -30
- sonusai/mixture/data_io.py +155 -0
- sonusai/mixture/datatypes.py +111 -108
- sonusai/mixture/db_datatypes.py +54 -70
- sonusai/mixture/eq_rule_is_valid.py +6 -9
- sonusai/mixture/feature.py +40 -38
- sonusai/mixture/generation.py +522 -389
- sonusai/mixture/helpers.py +217 -272
- sonusai/mixture/log_duration_and_sizes.py +16 -13
- sonusai/mixture/mixdb.py +669 -477
- sonusai/mixture/soundfile_audio.py +12 -17
- sonusai/mixture/sox_audio.py +91 -112
- sonusai/mixture/sox_augmentation.py +8 -9
- sonusai/mixture/spectral_mask.py +4 -6
- sonusai/mixture/target_class_balancing.py +41 -36
- sonusai/mixture/targets.py +69 -67
- sonusai/mixture/tokenized_shell_vars.py +23 -23
- sonusai/mixture/torchaudio_audio.py +14 -15
- sonusai/mixture/torchaudio_augmentation.py +23 -27
- sonusai/mixture/truth.py +48 -26
- sonusai/mixture/truth_functions/__init__.py +26 -0
- sonusai/mixture/truth_functions/crm.py +56 -38
- sonusai/mixture/truth_functions/datatypes.py +37 -0
- sonusai/mixture/truth_functions/energy.py +85 -59
- sonusai/mixture/truth_functions/file.py +30 -30
- sonusai/mixture/truth_functions/phoneme.py +14 -7
- sonusai/mixture/truth_functions/sed.py +71 -45
- sonusai/mixture/truth_functions/target.py +69 -106
- sonusai/mkwav.py +58 -101
- sonusai/onnx_predict.py +46 -43
- sonusai/queries/__init__.py +3 -1
- sonusai/queries/queries.py +100 -59
- sonusai/speech/__init__.py +2 -0
- sonusai/speech/l2arctic.py +24 -23
- sonusai/speech/librispeech.py +16 -17
- sonusai/speech/mcgill.py +22 -21
- sonusai/speech/textgrid.py +32 -25
- sonusai/speech/timit.py +45 -42
- sonusai/speech/vctk.py +14 -13
- sonusai/speech/voxceleb.py +26 -20
- sonusai/summarize_metric_spenh.py +11 -10
- sonusai/utils/__init__.py +4 -3
- sonusai/utils/asl_p56.py +1 -1
- sonusai/utils/asr.py +37 -17
- sonusai/utils/asr_functions/__init__.py +2 -0
- sonusai/utils/asr_functions/aaware_whisper.py +18 -12
- sonusai/utils/audio_devices.py +12 -12
- sonusai/utils/braced_glob.py +6 -8
- sonusai/utils/calculate_input_shape.py +1 -4
- sonusai/utils/compress.py +2 -2
- sonusai/utils/convert_string_to_number.py +1 -3
- sonusai/utils/create_timestamp.py +1 -1
- sonusai/utils/create_ts_name.py +2 -2
- sonusai/utils/dataclass_from_dict.py +1 -1
- sonusai/utils/docstring.py +6 -6
- sonusai/utils/energy_f.py +9 -7
- sonusai/utils/engineering_number.py +56 -54
- sonusai/utils/get_label_names.py +8 -10
- sonusai/utils/human_readable_size.py +2 -2
- sonusai/utils/model_utils.py +3 -5
- sonusai/utils/numeric_conversion.py +2 -4
- sonusai/utils/onnx_utils.py +43 -32
- sonusai/utils/parallel.py +41 -30
- sonusai/utils/print_mixture_details.py +25 -22
- sonusai/utils/ranges.py +12 -12
- sonusai/utils/read_predict_data.py +11 -9
- sonusai/utils/reshape.py +19 -26
- sonusai/utils/seconds_to_hms.py +1 -1
- sonusai/utils/stacked_complex.py +8 -16
- sonusai/utils/stratified_shuffle_split.py +29 -27
- sonusai/utils/write_audio.py +2 -2
- sonusai/utils/yes_or_no.py +3 -3
- sonusai/vars.py +14 -14
- {sonusai-0.18.9.dist-info → sonusai-0.19.6.dist-info}/METADATA +20 -21
- sonusai-0.19.6.dist-info/RECORD +125 -0
- {sonusai-0.18.9.dist-info → sonusai-0.19.6.dist-info}/WHEEL +1 -1
- sonusai/mixture/truth_functions/data.py +0 -58
- sonusai/utils/read_mixture_data.py +0 -14
- sonusai-0.18.9.dist-info/RECORD +0 -125
- {sonusai-0.18.9.dist-info → sonusai-0.19.6.dist-info}/entry_points.txt +0 -0
@@ -2,90 +2,92 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
from decimal import Decimal
|
4
4
|
from typing import Any
|
5
|
+
from typing import ClassVar
|
5
6
|
|
6
7
|
|
7
8
|
class EngineeringNumber:
|
8
9
|
"""Easy manipulation of numbers which use engineering notation"""
|
9
10
|
|
10
|
-
_suffix_lookup = {
|
11
|
-
|
12
|
-
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
11
|
+
_suffix_lookup: ClassVar = {
|
12
|
+
"Y": "e24",
|
13
|
+
"Z": "e21",
|
14
|
+
"E": "e18",
|
15
|
+
"P": "e15",
|
16
|
+
"T": "e12",
|
17
|
+
"G": "e9",
|
18
|
+
"M": "e6",
|
19
|
+
"k": "e3",
|
20
|
+
"": "e0",
|
21
|
+
"m": "e-3",
|
22
|
+
"u": "e-6",
|
23
|
+
"n": "e-9",
|
24
|
+
"p": "e-12",
|
25
|
+
"f": "e-15",
|
26
|
+
"a": "e-18",
|
27
|
+
"z": "e-21",
|
28
|
+
"y": "e-24",
|
28
29
|
}
|
29
30
|
|
30
|
-
_exponent_lookup_scaled = {
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
31
|
+
_exponent_lookup_scaled: ClassVar = {
|
32
|
+
"-12": "Y",
|
33
|
+
"-15": "Z",
|
34
|
+
"-18": "E",
|
35
|
+
"-21": "P",
|
36
|
+
"-24": "T",
|
37
|
+
"-27": "G",
|
38
|
+
"-30": "M",
|
39
|
+
"-33": "k",
|
40
|
+
"-36": "",
|
41
|
+
"-39": "m",
|
42
|
+
"-42": "u",
|
43
|
+
"-45": "n",
|
44
|
+
"-48": "p",
|
45
|
+
"-51": "f",
|
46
|
+
"-54": "a",
|
47
|
+
"-57": "z",
|
48
|
+
"-60": "y",
|
48
49
|
}
|
49
50
|
|
50
|
-
def __init__(
|
51
|
-
|
52
|
-
|
53
|
-
|
51
|
+
def __init__(
|
52
|
+
self,
|
53
|
+
value: str | float | int | EngineeringNumber,
|
54
|
+
precision: int = 2,
|
55
|
+
significant: int = 0,
|
56
|
+
):
|
54
57
|
"""
|
55
58
|
:param value: string, integer, or float representing the numeric value of the number
|
56
59
|
:param precision: the precision past the decimal
|
57
60
|
:param significant: the number of significant digits
|
58
61
|
if given, significant takes precedence over precision
|
59
62
|
"""
|
60
|
-
from sonusai import SonusAIError
|
61
|
-
|
62
63
|
self.precision = precision
|
63
64
|
self.significant = significant
|
64
65
|
|
65
66
|
if isinstance(value, str):
|
66
|
-
suffix_keys = [key for key in self._suffix_lookup
|
67
|
+
suffix_keys = [key for key in self._suffix_lookup if key != ""]
|
67
68
|
|
69
|
+
str_value = str(value)
|
68
70
|
for suffix in suffix_keys:
|
69
|
-
if suffix in
|
70
|
-
|
71
|
+
if suffix in str_value:
|
72
|
+
str_value = str_value[:-1] + self._suffix_lookup[suffix]
|
71
73
|
break
|
72
74
|
|
73
|
-
self.number = Decimal(
|
75
|
+
self.number = Decimal(str_value)
|
74
76
|
|
75
|
-
elif isinstance(value, int
|
77
|
+
elif isinstance(value, int | float | EngineeringNumber):
|
76
78
|
self.number = Decimal(str(value))
|
77
79
|
|
78
80
|
else:
|
79
|
-
raise
|
81
|
+
raise TypeError("value has unsupported type")
|
80
82
|
|
81
83
|
def __repr__(self):
|
82
84
|
"""Returns the string representation"""
|
83
85
|
# The Decimal class only really converts numbers that are very small into engineering notation.
|
84
86
|
# So we will simply make all numbers small numbers and take advantage of the Decimal class.
|
85
|
-
number_str = self.number * Decimal(
|
87
|
+
number_str = self.number * Decimal("10e-37")
|
86
88
|
number_str = number_str.to_eng_string().lower()
|
87
89
|
|
88
|
-
base, exponent = number_str.split(
|
90
|
+
base, exponent = number_str.split("e")
|
89
91
|
|
90
92
|
if self.significant > 0:
|
91
93
|
abs_base = abs(Decimal(base))
|
@@ -98,12 +100,12 @@ class EngineeringNumber:
|
|
98
100
|
|
99
101
|
base = str(round(Decimal(base), num_digits))
|
100
102
|
|
101
|
-
if
|
103
|
+
if "e" in base.lower():
|
102
104
|
base = str(int(Decimal(base)))
|
103
105
|
|
104
106
|
# Remove trailing decimal
|
105
|
-
if
|
106
|
-
base = base.rstrip(
|
107
|
+
if "." in base:
|
108
|
+
base = base.rstrip(".")
|
107
109
|
|
108
110
|
return base + self._exponent_lookup_scaled[exponent]
|
109
111
|
|
@@ -159,6 +161,6 @@ class EngineeringNumber:
|
|
159
161
|
return self.number >= self._to_decimal(other)
|
160
162
|
|
161
163
|
def __eq__(self, other: Any) -> bool:
|
162
|
-
if not isinstance(other,
|
164
|
+
if not isinstance(other, str | float | int | EngineeringNumber):
|
163
165
|
return NotImplemented
|
164
166
|
return self.number == self._to_decimal(other)
|
sonusai/utils/get_label_names.py
CHANGED
@@ -1,22 +1,20 @@
|
|
1
|
-
def get_label_names(num_labels: int, file: str = None) -> list:
|
1
|
+
def get_label_names(num_labels: int, file: str | None = None) -> list:
|
2
2
|
"""Return label names in a list. Read from CSV file, if provided."""
|
3
3
|
import csv
|
4
4
|
|
5
|
-
from sonusai import SonusAIError
|
6
|
-
|
7
5
|
if file is None:
|
8
|
-
return [f
|
6
|
+
return [f"Class {val + 1}" for val in range(num_labels)]
|
9
7
|
|
10
|
-
label_names = [
|
8
|
+
label_names = [""] * num_labels
|
11
9
|
with open(file) as f:
|
12
10
|
reader = csv.DictReader(f)
|
13
|
-
if
|
14
|
-
raise
|
11
|
+
if reader.fieldnames is None or "index" not in reader.fieldnames or "display_name" not in reader.fieldnames:
|
12
|
+
raise ValueError("Missing required fields in labels CSV.")
|
15
13
|
|
16
14
|
for row in reader:
|
17
|
-
index = int(row[
|
15
|
+
index = int(row["index"]) - 1
|
18
16
|
if index >= num_labels:
|
19
|
-
raise
|
20
|
-
label_names[index] = row[
|
17
|
+
raise ValueError("The number of given label names does not match the number of labels.")
|
18
|
+
label_names[index] = row["display_name"]
|
21
19
|
|
22
20
|
return label_names
|
@@ -1,7 +1,7 @@
|
|
1
1
|
def human_readable_size(size: float, decimal_places: int = 3) -> str:
|
2
2
|
"""Convert number into string with units"""
|
3
|
-
for unit in [
|
3
|
+
for unit in ["B", "kB", "MB", "GB", "TB"]: # noqa: B007
|
4
4
|
if size < 1024.0:
|
5
5
|
break
|
6
6
|
size /= 1024.0
|
7
|
-
return f
|
7
|
+
return f"{size:.{decimal_places}f} {unit}"
|
sonusai/utils/model_utils.py
CHANGED
@@ -9,12 +9,10 @@ def import_module(name: str) -> Any:
|
|
9
9
|
import sys
|
10
10
|
from importlib import import_module
|
11
11
|
|
12
|
-
from sonusai import SonusAIError
|
13
|
-
|
14
12
|
try:
|
15
13
|
path = os.path.dirname(name)
|
16
14
|
if len(path) < 1:
|
17
|
-
path =
|
15
|
+
path = "./"
|
18
16
|
|
19
17
|
# Add model file location to system path
|
20
18
|
sys.path.append(os.path.abspath(path))
|
@@ -23,8 +21,8 @@ def import_module(name: str) -> Any:
|
|
23
21
|
root = os.path.splitext(os.path.basename(name))[0]
|
24
22
|
model = import_module(root)
|
25
23
|
except Exception as e:
|
26
|
-
raise
|
24
|
+
raise OSError(f"Error: could not import model from {name}: {e}.") from e
|
27
25
|
except Exception as e:
|
28
|
-
raise
|
26
|
+
raise OSError(f"Error: could not find {name}: {e}.") from e
|
29
27
|
|
30
28
|
return model
|
@@ -2,12 +2,10 @@ import numpy as np
|
|
2
2
|
|
3
3
|
|
4
4
|
def int16_to_float(x: np.ndarray) -> np.ndarray:
|
5
|
-
"""
|
6
|
-
"""
|
5
|
+
"""Convert int16 array to floating point with range +/- 1"""
|
7
6
|
return x.astype(np.float32) / 32768
|
8
7
|
|
9
8
|
|
10
9
|
def float_to_int16(x: np.ndarray) -> np.ndarray:
|
11
|
-
"""
|
12
|
-
"""
|
10
|
+
"""Convert float point array with range +/- 1 to int16"""
|
13
11
|
return (x * 32768).astype(np.int16)
|
sonusai/utils/onnx_utils.py
CHANGED
@@ -1,5 +1,4 @@
|
|
1
|
-
from
|
2
|
-
from typing import Sequence
|
1
|
+
from collections.abc import Sequence
|
3
2
|
|
4
3
|
from onnx import ModelProto
|
5
4
|
from onnx import ValueInfoProto
|
@@ -7,7 +6,14 @@ from onnxruntime import InferenceSession
|
|
7
6
|
from onnxruntime import NodeArg
|
8
7
|
from onnxruntime import SessionOptions
|
9
8
|
|
10
|
-
REQUIRED_HPARAMS = (
|
9
|
+
REQUIRED_HPARAMS = (
|
10
|
+
"feature",
|
11
|
+
"batch_size",
|
12
|
+
"timesteps",
|
13
|
+
"flatten",
|
14
|
+
"add1ch",
|
15
|
+
"truth_mutex",
|
16
|
+
)
|
11
17
|
|
12
18
|
|
13
19
|
def _extract_shapes(io: list[ValueInfoProto]) -> list[list[int] | str]:
|
@@ -18,14 +24,14 @@ def _extract_shapes(io: list[ValueInfoProto]) -> list[list[int] | str]:
|
|
18
24
|
# get tensor type: 0, 1, 2, etc.
|
19
25
|
tensor_type = item.type.tensor_type
|
20
26
|
# check if it has a shape
|
21
|
-
if tensor_type.HasField(
|
27
|
+
if tensor_type.HasField("shape"):
|
22
28
|
tmp_shape = []
|
23
29
|
# iterate through dimensions of the shape
|
24
30
|
for d in tensor_type.shape.dim:
|
25
|
-
if d.HasField(
|
31
|
+
if d.HasField("dim_value"):
|
26
32
|
# known dimension, int value
|
27
33
|
tmp_shape.append(d.dim_value)
|
28
|
-
elif d.HasField(
|
34
|
+
elif d.HasField("dim_param"):
|
29
35
|
# dynamic dim with symbolic name of d.dim_param; set size to 0
|
30
36
|
tmp_shape.append(0)
|
31
37
|
else:
|
@@ -34,19 +40,21 @@ def _extract_shapes(io: list[ValueInfoProto]) -> list[list[int] | str]:
|
|
34
40
|
# add as a list
|
35
41
|
shapes.append(tmp_shape)
|
36
42
|
else:
|
37
|
-
shapes.append(
|
43
|
+
shapes.append("unknown rank")
|
38
44
|
|
39
45
|
return shapes
|
40
46
|
|
41
47
|
|
42
|
-
def get_and_check_inputs(
|
48
|
+
def get_and_check_inputs(
|
49
|
+
model: ModelProto,
|
50
|
+
) -> tuple[list[ValueInfoProto], list[list[int] | str]]:
|
43
51
|
from sonusai import logger
|
44
52
|
|
45
53
|
# ignore initializer inputs (only seen in older ONNX < v1.5)
|
46
54
|
initializer_names = [x.name for x in model.graph.initializer]
|
47
55
|
inputs = [i for i in model.graph.input if i.name not in initializer_names]
|
48
56
|
if len(inputs) != 1:
|
49
|
-
logger.warning(f
|
57
|
+
logger.warning(f"Warning: ONNX model has {len(inputs)} inputs; expected only 1")
|
50
58
|
|
51
59
|
# This one-liner works only if input has type and shape, returns a list
|
52
60
|
# shape0 = [d.dim_value for d in inputs[0].type.tensor_type.shape.dim]
|
@@ -55,12 +63,14 @@ def get_and_check_inputs(model: ModelProto) -> tuple[list[ValueInfoProto], list[
|
|
55
63
|
return inputs, shapes
|
56
64
|
|
57
65
|
|
58
|
-
def get_and_check_outputs(
|
66
|
+
def get_and_check_outputs(
|
67
|
+
model: ModelProto,
|
68
|
+
) -> tuple[list[ValueInfoProto], list[list[int] | str]]:
|
59
69
|
from sonusai import logger
|
60
70
|
|
61
|
-
outputs =
|
71
|
+
outputs = list(model.graph.output)
|
62
72
|
if len(outputs) != 1:
|
63
|
-
logger.warning(f
|
73
|
+
logger.warning(f"Warning: ONNX model has {len(outputs)} outputs; expected only 1")
|
64
74
|
|
65
75
|
shapes = _extract_shapes(outputs)
|
66
76
|
|
@@ -85,38 +95,39 @@ def add_sonusai_metadata(model: ModelProto, hparams: dict) -> ModelProto:
|
|
85
95
|
from sonusai import logger
|
86
96
|
|
87
97
|
# Note hparams should be a dict (i.e., extracted from checkpoint)
|
88
|
-
|
98
|
+
if eval(str(hparams)) != hparams: # noqa: S307
|
99
|
+
raise TypeError("hparams is not a dict")
|
89
100
|
for key in REQUIRED_HPARAMS:
|
90
|
-
if key not in hparams
|
91
|
-
logger.warning(f
|
101
|
+
if key not in hparams:
|
102
|
+
logger.warning(f"Warning: SonusAI hyperparameters are missing: {key}")
|
92
103
|
|
93
104
|
meta = model.metadata_props.add()
|
94
|
-
meta.key =
|
105
|
+
meta.key = "hparams"
|
95
106
|
meta.value = str(hparams)
|
96
107
|
|
97
108
|
return model
|
98
109
|
|
99
110
|
|
100
|
-
def get_sonusai_metadata(session: InferenceSession) ->
|
101
|
-
"""Get SonusAI hyperparameter metadata from an ONNX Runtime session.
|
102
|
-
"""
|
111
|
+
def get_sonusai_metadata(session: InferenceSession) -> dict | None:
|
112
|
+
"""Get SonusAI hyperparameter metadata from an ONNX Runtime session."""
|
103
113
|
from sonusai import logger
|
104
114
|
|
105
115
|
meta = session.get_modelmeta()
|
106
|
-
if
|
116
|
+
if "hparams" not in meta.custom_metadata_map:
|
107
117
|
logger.warning("Warning: ONNX model metadata does not contain 'hparams'")
|
108
118
|
return None
|
109
119
|
|
110
|
-
hparams = eval(meta.custom_metadata_map[
|
120
|
+
hparams = eval(meta.custom_metadata_map["hparams"]) # noqa: S307
|
111
121
|
for key in REQUIRED_HPARAMS:
|
112
|
-
if key not in hparams
|
113
|
-
logger.warning(f
|
122
|
+
if key not in hparams:
|
123
|
+
logger.warning(f"Warning: ONNX model does not have required SonusAI hyperparameters: {key}")
|
114
124
|
|
115
125
|
return hparams
|
116
126
|
|
117
127
|
|
118
|
-
def load_ort_session(
|
119
|
-
|
128
|
+
def load_ort_session(
|
129
|
+
model_path: str, providers: Sequence[str | tuple[str, dict]] | None = None
|
130
|
+
) -> tuple[InferenceSession, SessionOptions, str, dict | None, list[NodeArg], list[NodeArg]]:
|
120
131
|
from os.path import basename
|
121
132
|
from os.path import exists
|
122
133
|
from os.path import isfile
|
@@ -127,27 +138,27 @@ def load_ort_session(model_path: str, providers: Sequence[str | tuple[str, dict]
|
|
127
138
|
from sonusai import logger
|
128
139
|
|
129
140
|
if providers is None:
|
130
|
-
providers = [
|
141
|
+
providers = ["CPUExecutionProvider"]
|
131
142
|
|
132
143
|
if exists(model_path) and isfile(model_path):
|
133
144
|
model_basename = basename(model_path)
|
134
145
|
model_root = splitext(model_basename)[0]
|
135
|
-
logger.info(f
|
146
|
+
logger.info(f"Importing model from {model_basename}")
|
136
147
|
try:
|
137
148
|
session = ort.InferenceSession(model_path, providers=providers)
|
138
149
|
options = ort.SessionOptions()
|
139
150
|
except Exception as e:
|
140
|
-
logger.exception(f
|
141
|
-
raise SystemExit(1)
|
151
|
+
logger.exception(f"Error: could not load ONNX model from {model_path}: {e}")
|
152
|
+
raise SystemExit(1) from e
|
142
153
|
else:
|
143
|
-
logger.exception(f
|
154
|
+
logger.exception(f"Error: model file does not exist: {model_path}")
|
144
155
|
raise SystemExit(1)
|
145
156
|
|
146
|
-
logger.info(f
|
157
|
+
logger.info(f"Opened session with provider options: {session._provider_options}.")
|
147
158
|
hparams = get_sonusai_metadata(session)
|
148
159
|
if hparams is not None:
|
149
160
|
for key in REQUIRED_HPARAMS:
|
150
|
-
logger.info(f
|
161
|
+
logger.info(f" {key:12} {hparams[key]}")
|
151
162
|
|
152
163
|
inputs = session.get_inputs()
|
153
164
|
outputs = session.get_outputs()
|
sonusai/utils/parallel.py
CHANGED
@@ -1,32 +1,41 @@
|
|
1
|
+
import warnings
|
2
|
+
from collections.abc import Callable
|
3
|
+
from collections.abc import Iterable
|
1
4
|
from multiprocessing import current_process
|
2
5
|
from multiprocessing import get_context
|
3
6
|
from typing import Any
|
4
|
-
from typing import Callable
|
5
|
-
from typing import Iterable
|
6
|
-
from typing import Optional
|
7
7
|
|
8
|
-
from tqdm import
|
8
|
+
from tqdm import TqdmExperimentalWarning
|
9
|
+
from tqdm.rich import tqdm
|
9
10
|
|
10
|
-
|
11
|
+
warnings.filterwarnings(action="ignore", category=TqdmExperimentalWarning)
|
11
12
|
|
13
|
+
track = tqdm
|
12
14
|
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
15
|
+
CONTEXT = "fork"
|
16
|
+
|
17
|
+
|
18
|
+
def par_track(
|
19
|
+
func: Callable,
|
20
|
+
*iterables: Iterable,
|
21
|
+
initializer: Callable[..., None] | None = None,
|
22
|
+
initargs: Iterable[Any] | None = None,
|
23
|
+
progress: tqdm | None = None,
|
24
|
+
num_cpus: int | float | None = None,
|
25
|
+
total: int | None = None,
|
26
|
+
no_par: bool = False,
|
27
|
+
) -> list[Any]:
|
21
28
|
"""Performs a parallel ordered imap with tqdm progress."""
|
22
|
-
from
|
23
|
-
|
29
|
+
from collections.abc import Sized
|
30
|
+
|
31
|
+
from psutil import cpu_count
|
24
32
|
|
25
33
|
if total is None:
|
26
|
-
|
34
|
+
_total = min(len(iterable) for iterable in iterables if isinstance(iterable, Sized))
|
35
|
+
else:
|
36
|
+
_total = int(total)
|
27
37
|
|
28
|
-
results: list[Any] = [None] *
|
29
|
-
n = 0
|
38
|
+
results: list[Any] = [None] * _total
|
30
39
|
if no_par or current_process().daemon:
|
31
40
|
if initializer is not None:
|
32
41
|
if initargs is not None:
|
@@ -34,21 +43,26 @@ def pp_tqdm_imap(func: Callable,
|
|
34
43
|
else:
|
35
44
|
initializer()
|
36
45
|
|
37
|
-
for result in map(func, *iterables):
|
46
|
+
for n, result in enumerate(map(func, *iterables)):
|
38
47
|
results[n] = result
|
39
|
-
n += 1
|
40
48
|
if progress is not None:
|
41
49
|
progress.update()
|
42
50
|
else:
|
43
51
|
if num_cpus is None:
|
44
|
-
|
45
|
-
elif num_cpus
|
46
|
-
|
52
|
+
_num_cpus = max(cpu_count() - 2, 1)
|
53
|
+
elif isinstance(num_cpus, float):
|
54
|
+
_num_cpus = int(round(num_cpus * cpu_count()))
|
55
|
+
else:
|
56
|
+
_num_cpus = int(num_cpus)
|
47
57
|
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
58
|
+
_num_cpus = min(_num_cpus, _total)
|
59
|
+
|
60
|
+
if initargs is None:
|
61
|
+
initargs = []
|
62
|
+
|
63
|
+
with get_context(CONTEXT).Pool(processes=_num_cpus, initializer=initializer, initargs=initargs) as pool:
|
64
|
+
n = 0
|
65
|
+
for result in pool.imap(func, *iterables): # type: ignore[arg-type]
|
52
66
|
results[n] = result
|
53
67
|
n += 1
|
54
68
|
if progress is not None:
|
@@ -59,6 +73,3 @@ def pp_tqdm_imap(func: Callable,
|
|
59
73
|
if progress is not None:
|
60
74
|
progress.close()
|
61
75
|
return results
|
62
|
-
|
63
|
-
|
64
|
-
pp_imap = pp_tqdm_imap
|
@@ -1,37 +1,38 @@
|
|
1
|
-
from
|
1
|
+
from collections.abc import Callable
|
2
2
|
|
3
3
|
from sonusai.mixture import ClassCount
|
4
4
|
from sonusai.mixture import MixtureDatabase
|
5
5
|
|
6
6
|
|
7
|
-
def print_mixture_details(
|
8
|
-
|
9
|
-
|
10
|
-
|
7
|
+
def print_mixture_details(
|
8
|
+
mixdb: MixtureDatabase,
|
9
|
+
mixid: int | None = None,
|
10
|
+
desc_len: int = 1,
|
11
|
+
print_fn: Callable = print,
|
12
|
+
) -> None:
|
11
13
|
import numpy as np
|
12
14
|
|
13
|
-
from sonusai import SonusAIError
|
14
15
|
from sonusai.mixture import SAMPLE_RATE
|
15
16
|
from sonusai.utils import seconds_to_hms
|
16
17
|
|
17
18
|
if mixid is not None:
|
18
19
|
if 0 < mixid >= mixdb.num_mixtures:
|
19
|
-
raise
|
20
|
+
raise ValueError(f"Given mixid is outside valid range of 0:{mixdb.num_mixtures - 1}.")
|
20
21
|
|
21
|
-
print_fn(f
|
22
|
+
print_fn(f"Mixture {mixid} details")
|
22
23
|
mixture = mixdb.mixture(mixid)
|
23
24
|
target_files = [mixdb.target_files[target.file_id] for target in mixture.targets]
|
24
25
|
target_augmentations = [target.augmentation for target in mixture.targets]
|
25
26
|
noise_file = mixdb.noise_file(mixture.noise.file_id)
|
26
27
|
for t_idx, target_file in enumerate(target_files):
|
27
|
-
print_fn(f
|
28
|
+
print_fn(f" Target {t_idx}")
|
28
29
|
print_fn(f'{" Name":{desc_len}} {target_file.name}')
|
29
30
|
print_fn(f'{" Duration":{desc_len}} {seconds_to_hms(target_file.duration)}')
|
30
|
-
for
|
31
|
-
print_fn(f
|
32
|
-
print_fn(f'{"
|
33
|
-
print_fn(f'{"
|
34
|
-
print_fn(f'{" Config":{desc_len}} {
|
31
|
+
for truth_name, truth_config in target_file.truth_configs.items():
|
32
|
+
print_fn(f" Truth config: '{truth_name}'")
|
33
|
+
print_fn(f'{" Function":{desc_len}} {truth_config.function}')
|
34
|
+
print_fn(f'{" Stride reduction":{desc_len}} {truth_config.stride_reduction}')
|
35
|
+
print_fn(f'{" Config":{desc_len}} {truth_config.config}')
|
35
36
|
print_fn(f'{" Augmentation":{desc_len}} {target_augmentations[t_idx]}')
|
36
37
|
print_fn(f'{" Samples":{desc_len}} {mixture.samples}')
|
37
38
|
print_fn(f'{" Feature frames":{desc_len}} {mixdb.mixture_feature_frames(mixid)}')
|
@@ -42,18 +43,20 @@ def print_mixture_details(mixdb: MixtureDatabase,
|
|
42
43
|
print_fn(f'{" Target gain":{desc_len}} {[target.gain for target in mixture.targets]}')
|
43
44
|
print_fn(f'{" Target SNR gain":{desc_len}} {mixture.target_snr_gain}')
|
44
45
|
print_fn(f'{" Noise SNR gain":{desc_len}} {mixture.noise_snr_gain}')
|
45
|
-
print_fn(
|
46
|
+
print_fn("")
|
46
47
|
|
47
48
|
|
48
|
-
def print_class_count(
|
49
|
-
|
50
|
-
|
51
|
-
|
49
|
+
def print_class_count(
|
50
|
+
class_count: ClassCount,
|
51
|
+
length: int,
|
52
|
+
print_fn: Callable = print,
|
53
|
+
all_class_counts: bool = False,
|
54
|
+
) -> None:
|
52
55
|
from sonusai.utils import max_text_width
|
53
56
|
|
54
|
-
print_fn(
|
57
|
+
print_fn("Class count:")
|
55
58
|
idx_len = max_text_width(len(class_count))
|
56
59
|
for idx, count in enumerate(class_count):
|
57
60
|
if all_class_counts or count > 0:
|
58
|
-
desc = f
|
59
|
-
print_fn(f
|
61
|
+
desc = f" class {idx + 1:{idx_len}}"
|
62
|
+
print_fn(f"{desc:{length}} {count}")
|