sonusai 1.0.16__cp311-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +170 -0
- sonusai/aawscd_probwrite.py +148 -0
- sonusai/audiofe.py +481 -0
- sonusai/calc_metric_spenh.py +1136 -0
- sonusai/config/__init__.py +0 -0
- sonusai/config/asr.py +21 -0
- sonusai/config/config.py +65 -0
- sonusai/config/config.yml +49 -0
- sonusai/config/constants.py +53 -0
- sonusai/config/ir.py +124 -0
- sonusai/config/ir_delay.py +62 -0
- sonusai/config/source.py +275 -0
- sonusai/config/spectral_masks.py +15 -0
- sonusai/config/truth.py +64 -0
- sonusai/constants.py +14 -0
- sonusai/data/__init__.py +0 -0
- sonusai/data/silero_vad_v5.1.jit +0 -0
- sonusai/data/silero_vad_v5.1.onnx +0 -0
- sonusai/data/speech_ma01_01.wav +0 -0
- sonusai/data/whitenoise.wav +0 -0
- sonusai/datatypes.py +383 -0
- sonusai/deprecated/gentcst.py +632 -0
- sonusai/deprecated/plot.py +519 -0
- sonusai/deprecated/tplot.py +365 -0
- sonusai/doc.py +52 -0
- sonusai/doc_strings/__init__.py +1 -0
- sonusai/doc_strings/doc_strings.py +531 -0
- sonusai/genft.py +196 -0
- sonusai/genmetrics.py +183 -0
- sonusai/genmix.py +199 -0
- sonusai/genmixdb.py +235 -0
- sonusai/ir_metric.py +551 -0
- sonusai/lsdb.py +141 -0
- sonusai/main.py +134 -0
- sonusai/metrics/__init__.py +43 -0
- sonusai/metrics/calc_audio_stats.py +42 -0
- sonusai/metrics/calc_class_weights.py +90 -0
- sonusai/metrics/calc_optimal_thresholds.py +73 -0
- sonusai/metrics/calc_pcm.py +45 -0
- sonusai/metrics/calc_pesq.py +36 -0
- sonusai/metrics/calc_phase_distance.py +43 -0
- sonusai/metrics/calc_sa_sdr.py +64 -0
- sonusai/metrics/calc_sample_weights.py +25 -0
- sonusai/metrics/calc_segsnr_f.py +82 -0
- sonusai/metrics/calc_speech.py +382 -0
- sonusai/metrics/calc_wer.py +71 -0
- sonusai/metrics/calc_wsdr.py +57 -0
- sonusai/metrics/calculate_metrics.py +395 -0
- sonusai/metrics/class_summary.py +74 -0
- sonusai/metrics/confusion_matrix_summary.py +75 -0
- sonusai/metrics/one_hot.py +283 -0
- sonusai/metrics/snr_summary.py +128 -0
- sonusai/metrics_summary.py +314 -0
- sonusai/mixture/__init__.py +15 -0
- sonusai/mixture/audio.py +187 -0
- sonusai/mixture/class_balancing.py +103 -0
- sonusai/mixture/constants.py +3 -0
- sonusai/mixture/data_io.py +173 -0
- sonusai/mixture/db.py +169 -0
- sonusai/mixture/db_datatypes.py +92 -0
- sonusai/mixture/effects.py +344 -0
- sonusai/mixture/feature.py +78 -0
- sonusai/mixture/generation.py +1116 -0
- sonusai/mixture/helpers.py +351 -0
- sonusai/mixture/ir_effects.py +77 -0
- sonusai/mixture/log_duration_and_sizes.py +23 -0
- sonusai/mixture/mixdb.py +1857 -0
- sonusai/mixture/pad_audio.py +35 -0
- sonusai/mixture/resample.py +7 -0
- sonusai/mixture/sox_effects.py +195 -0
- sonusai/mixture/sox_help.py +650 -0
- sonusai/mixture/spectral_mask.py +51 -0
- sonusai/mixture/truth.py +61 -0
- sonusai/mixture/truth_functions/__init__.py +45 -0
- sonusai/mixture/truth_functions/crm.py +105 -0
- sonusai/mixture/truth_functions/energy.py +222 -0
- sonusai/mixture/truth_functions/file.py +48 -0
- sonusai/mixture/truth_functions/metadata.py +24 -0
- sonusai/mixture/truth_functions/metrics.py +28 -0
- sonusai/mixture/truth_functions/phoneme.py +18 -0
- sonusai/mixture/truth_functions/sed.py +98 -0
- sonusai/mixture/truth_functions/target.py +142 -0
- sonusai/mkwav.py +135 -0
- sonusai/onnx_predict.py +363 -0
- sonusai/parse/__init__.py +0 -0
- sonusai/parse/expand.py +156 -0
- sonusai/parse/parse_source_directive.py +129 -0
- sonusai/parse/rand.py +214 -0
- sonusai/py.typed +0 -0
- sonusai/queries/__init__.py +0 -0
- sonusai/queries/queries.py +239 -0
- sonusai/rs.abi3.so +0 -0
- sonusai/rs.pyi +1 -0
- sonusai/rust/__init__.py +0 -0
- sonusai/speech/__init__.py +0 -0
- sonusai/speech/l2arctic.py +121 -0
- sonusai/speech/librispeech.py +102 -0
- sonusai/speech/mcgill.py +71 -0
- sonusai/speech/textgrid.py +89 -0
- sonusai/speech/timit.py +138 -0
- sonusai/speech/types.py +12 -0
- sonusai/speech/vctk.py +53 -0
- sonusai/speech/voxceleb.py +108 -0
- sonusai/utils/__init__.py +3 -0
- sonusai/utils/asl_p56.py +130 -0
- sonusai/utils/asr.py +91 -0
- sonusai/utils/asr_functions/__init__.py +3 -0
- sonusai/utils/asr_functions/aaware_whisper.py +69 -0
- sonusai/utils/audio_devices.py +50 -0
- sonusai/utils/braced_glob.py +50 -0
- sonusai/utils/calculate_input_shape.py +26 -0
- sonusai/utils/choice.py +51 -0
- sonusai/utils/compress.py +25 -0
- sonusai/utils/convert_string_to_number.py +6 -0
- sonusai/utils/create_timestamp.py +5 -0
- sonusai/utils/create_ts_name.py +14 -0
- sonusai/utils/dataclass_from_dict.py +27 -0
- sonusai/utils/db.py +16 -0
- sonusai/utils/docstring.py +53 -0
- sonusai/utils/energy_f.py +44 -0
- sonusai/utils/engineering_number.py +166 -0
- sonusai/utils/evaluate_random_rule.py +15 -0
- sonusai/utils/get_frames_per_batch.py +2 -0
- sonusai/utils/get_label_names.py +20 -0
- sonusai/utils/grouper.py +6 -0
- sonusai/utils/human_readable_size.py +7 -0
- sonusai/utils/keyboard_interrupt.py +12 -0
- sonusai/utils/load_object.py +21 -0
- sonusai/utils/max_text_width.py +9 -0
- sonusai/utils/model_utils.py +28 -0
- sonusai/utils/numeric_conversion.py +11 -0
- sonusai/utils/onnx_utils.py +155 -0
- sonusai/utils/parallel.py +162 -0
- sonusai/utils/path_info.py +7 -0
- sonusai/utils/print_mixture_details.py +60 -0
- sonusai/utils/rand.py +13 -0
- sonusai/utils/ranges.py +43 -0
- sonusai/utils/read_predict_data.py +32 -0
- sonusai/utils/reshape.py +154 -0
- sonusai/utils/seconds_to_hms.py +7 -0
- sonusai/utils/stacked_complex.py +82 -0
- sonusai/utils/stratified_shuffle_split.py +170 -0
- sonusai/utils/tokenized_shell_vars.py +143 -0
- sonusai/utils/write_audio.py +26 -0
- sonusai/utils/yes_or_no.py +8 -0
- sonusai/vars.py +47 -0
- sonusai-1.0.16.dist-info/METADATA +56 -0
- sonusai-1.0.16.dist-info/RECORD +150 -0
- sonusai-1.0.16.dist-info/WHEEL +4 -0
- sonusai-1.0.16.dist-info/entry_points.txt +3 -0
sonusai/parse/rand.py
ADDED
@@ -0,0 +1,214 @@
|
|
1
|
+
"""
|
2
|
+
Parse 'rand' expressions.
|
3
|
+
|
4
|
+
"""
|
5
|
+
|
6
|
+
import decimal
|
7
|
+
import re
|
8
|
+
from random import uniform
|
9
|
+
|
10
|
+
import pyparsing as pp
|
11
|
+
|
12
|
+
SIGNIFICANT_DIGITS = 6
|
13
|
+
|
14
|
+
|
15
|
+
def rand(directive: str) -> str:
    """Evaluate the 'rand(min, max)' directives in a string and validate their syntax.

    Every well-formed 'rand(min, max)' occurrence is replaced with a uniformly
    distributed random value rounded to SIGNIFICANT_DIGITS significant digits.
    A parameter may itself be a 'rand' expression.

    :param directive: Directive to evaluate
    :return: Text with all 'rand' directives replaced with a random value;
        'directive' itself is returned unchanged when it is empty or None.
    :raises ValueError: If the expression cannot be parsed or is malformed.
    """
    if not directive:
        return directive

    # Create a recursive grammar for correct expressions
    expr = pp.Forward()
    number = pp.pyparsing_common.number

    func_name = pp.Literal("rand")
    left_paren = pp.Literal("(").suppress()
    right_paren = pp.Literal(")").suppress()
    comma = pp.Literal(",").suppress()

    # Allow whitespace around function parameters
    rand_function = (
        func_name
        + left_paren
        + pp.Optional(pp.White()).suppress()
        + (number | expr)("min_val")
        + pp.Optional(pp.White()).suppress()
        + comma
        + pp.Optional(pp.White()).suppress()
        + (number | expr)("max_val")
        + pp.Optional(pp.White()).suppress()
        + right_paren
    )

    # Complete the recursive definition
    expr << rand_function  # pyright: ignore [reportUnusedExpression]

    def replace_with_random(tokens):
        """Parse action: replace one matched 'rand(min, max)' with a random value."""
        # Tokens are either parsed numbers or the string produced by a nested 'rand'
        min_val = float(tokens["min_val"])
        max_val = float(tokens["max_val"])

        if min_val > max_val:
            raise ValueError(f"Min value ({min_val}) cannot be greater than max value ({max_val})")

        value = uniform(min_val, max_val)  # noqa: S311

        # Round inside a local context so the caller's (thread-wide) decimal
        # precision is not permanently changed as a side effect.
        with decimal.localcontext() as ctx:
            ctx.prec = SIGNIFICANT_DIGITS
            return str(decimal.Decimal(value).normalize())

    rand_function.setParseAction(replace_with_random)

    # Validator grammar: same syntax as above, but without the parse action,
    # so expressions can be checked without being evaluated.
    validator = pp.Forward()
    validator_rand = (
        func_name
        + left_paren
        + pp.Optional(pp.White()).suppress()
        + (number | validator)
        + pp.Optional(pp.White()).suppress()
        + comma
        + pp.Optional(pp.White()).suppress()
        + (number | validator)
        + pp.Optional(pp.White()).suppress()
        + right_paren
    )
    validator << validator_rand  # pyright: ignore [reportUnusedExpression]

    try:
        # First, validate all 'rand' expressions without evaluating them.
        # This identifies structural problems before evaluation.
        malformations = []

        # Find all potential 'rand' expressions with or without opening/closing parentheses
        potential_expressions = re.finditer(r"rand\s*(\()?(?:[^()]|\([^()]*\))*\)?", directive)

        for match in potential_expressions:
            expr_text = match.group(0)

            # Check for missing opening parenthesis
            if "rand" in expr_text and "(" not in expr_text:
                malformations.append(f"Missing opening parenthesis in '{expr_text}'")
                continue

            # Check for missing closing parenthesis
            if not expr_text.endswith(")"):
                malformations.append(f"Missing closing parenthesis in '{expr_text}'")
                continue

            # Try to validate the expression structure
            try:
                validator.parseString(expr_text, parseAll=True)
                continue
            except pp.ParseException:
                pass

            # The expression is malformed; work out why so the error is helpful.
            param_text = expr_text[expr_text.find("(") + 1 : expr_text.rfind(")")]

            # Count only top-level commas (those outside nested parentheses)
            nesting_level = 0
            comma_count = 0
            for char in param_text:
                if char == "(":
                    nesting_level += 1
                elif char == ")":
                    nesting_level -= 1
                elif char == "," and nesting_level == 0:
                    comma_count += 1

            if comma_count == 0:
                if not param_text.strip():
                    malformations.append(f"Missing parameters in '{expr_text}' (expected 2)")
                elif re.search(r"\d+\s+[-+]?\d+", param_text):
                    # Two numbers separated only by whitespace
                    malformations.append(f"Missing comma between parameters in '{expr_text}'")
                else:
                    malformations.append(f"Too few parameters in '{expr_text}' (expected 2, got 1)")
            elif comma_count > 1:
                malformations.append(f"Too many parameters in '{expr_text}' (expected 2, got {comma_count + 1})")
            else:
                # There's 1 comma, so we have 2 parameters, but still a parsing error.
                # This is likely a non-numeric parameter or a broken nested expression.
                params = [p.strip() for p in split_params_respecting_nesting(param_text)]

                for i, param in enumerate(params):
                    if "rand" in param:
                        # Validate (without evaluating) the nested expression
                        try:
                            validator.parseString(param, parseAll=True)
                        except pp.ParseException:
                            malformations.append(f"Invalid nested expression '{param}' in '{expr_text}'")
                        # A nested expression is never a plain number, so it must
                        # not fall through to the numeric check below.
                        continue

                    if not is_numeric(param):
                        param_name = "first" if i == 0 else "second"
                        malformations.append(f"Non-numeric {param_name} parameter '{param}' in '{expr_text}'")

        if malformations:
            raise ValueError(f"Malformed rand directive: {'; '.join(malformations)}")

        # If validation passes, try to transform
        result = rand_function.transformString(directive)
    except pp.ParseException as e:
        raise ValueError(f"Invalid rand expression in '{directive}': {e!s}") from e

    return result
|
184
|
+
|
185
|
+
|
186
|
+
def is_numeric(text: str) -> bool:
    """Check if the text is a valid number (including scientific notation).

    Accepts an optional sign, digits with an optional fractional part (or a
    leading-dot fraction), and an optional exponent.

    :param text: Text to check
    :return: True if the whole text is a number, False otherwise
    """
    numeric_pattern = r"[-+]?(\d+(\.\d*)?|\.\d+)([eE][-+]?\d+)?"
    # fullmatch (rather than match with '$') so a trailing newline is rejected
    return re.fullmatch(numeric_pattern, text) is not None
|
190
|
+
|
191
|
+
|
192
|
+
def split_params_respecting_nesting(param_text: str) -> list[str]:
    """Split parameters by comma while respecting nested parentheses.

    Commas inside parentheses do not split. Empty parameters are preserved in
    every position (leading, middle and trailing), so "1," yields ["1", ""].

    :param param_text: Text between the outer parentheses of an expression
    :return: List of parameter strings (not stripped); empty list for empty input
    """
    if not param_text:
        return []

    params: list[str] = []
    current: list[str] = []
    depth = 0

    for char in param_text:
        if char == "," and depth == 0:
            params.append("".join(current))
            current = []
            continue

        if char == "(":
            depth += 1
        elif char == ")" and depth > 0:
            depth -= 1
        current.append(char)

    # Always emit the final parameter, even when empty, so a trailing comma
    # is visible to the caller as a missing parameter.
    params.append("".join(current))

    return params
|
sonusai/py.typed
ADDED
File without changes
|
File without changes
|
@@ -0,0 +1,239 @@
|
|
1
|
+
from collections.abc import Callable
|
2
|
+
from typing import Any
|
3
|
+
|
4
|
+
from ..datatypes import GeneralizedIDs
|
5
|
+
from ..mixture.mixdb import MixtureDatabase
|
6
|
+
|
7
|
+
|
8
|
+
def _true_predicate(_: Any) -> bool:
    """Accept any value; used as the default when no predicate is supplied."""
    return True
|
10
|
+
|
11
|
+
|
12
|
+
def get_mixids_from_mixture_field_predicate(
    mixdb: MixtureDatabase,
    field: str,
    mixids: GeneralizedIDs = "*",
    predicate: Callable[[Any], bool] | None = None,
) -> dict[int, list[int]]:
    """
    Generate mixture IDs based on the mixture field and predicate

    If the field value of a mixture is a dict, each of its values is considered
    individually; otherwise the field value itself is considered.

    Return a dictionary where:
        - keys are the matching field values (in sorted order)
        - values are lists of the mixids that match the criteria
          (in the order given by mixids, each mixid at most once per key)
    """
    mixid_out = mixdb.mixids_to_list(mixids)

    if predicate is None:
        predicate = _true_predicate

    # Single pass: look up each mixture once and record it under every
    # accepted value it carries.
    matches: dict[Any, list[int]] = {}
    for m_id in mixid_out:
        value = getattr(mixdb.mixture(m_id), field)
        candidates = value.values() if isinstance(value, dict) else (value,)

        # Guard against listing the same mixid twice when several dict
        # entries of one mixture share the same value.
        seen = set()
        for candidate in candidates:
            if candidate in seen or not predicate(candidate):
                continue
            seen.add(candidate)
            matches.setdefault(candidate, []).append(m_id)

    return {criterion: matches[criterion] for criterion in sorted(matches)}
|
53
|
+
|
54
|
+
|
55
|
+
def get_mixids_from_truth_configs_field_predicate(
    mixdb: MixtureDatabase,
    field: str,
    mixids: GeneralizedIDs = "*",
    predicate: Callable[[Any], bool] | None = None,
) -> dict[int, list[int]]:
    """
    Generate mixture IDs based on the target truth_configs field and predicate

    Return a dictionary where:
        - keys are the matching field values
        - values are sorted lists of the mixids that match the criteria
          (values with no matching mixid are omitted)
    """
    from ..config.constants import REQUIRED_TRUTH_CONFIG_FIELDS

    mixid_out = mixdb.mixids_to_list(mixids)

    if predicate is None:
        predicate = _true_predicate

    # Get only values of interest
    values = [value for value in get_all_truth_configs_values_from_field(mixdb, field) if predicate(value)]
    if not values:
        return {}

    is_required = field in REQUIRED_TRUTH_CONFIG_FIELDS

    def field_values(truth_config: Any) -> list:
        """Return the field's value(s) for one truth config as a list."""
        if is_required:
            raw = getattr(truth_config, field)
        else:
            # Optional config fields may be absent; treat that as "no value"
            raw = getattr(truth_config.config, field, None)
        if raw is None:
            return []
        # Wrap scalars so membership below is an exact match; a bare string
        # would otherwise be matched by substring.
        return raw if isinstance(raw, list) else [raw]

    # Collect the field values carried by each source file (computed once)
    source_values: dict[Any, list] = {}
    for s_ids in mixdb.source_file_ids.values():
        for s_id in s_ids:
            collected = source_values.setdefault(s_id, [])
            for truth_config in mixdb.source_file(s_id).truth_configs.values():
                collected.extend(field_values(truth_config))

    # Collect the source file IDs used by each requested mixture (computed once)
    mixture_file_ids = {
        m_id: {source.file_id for source in mixdb.mixture(m_id).all_sources.values()} for m_id in mixid_out
    }

    result = {}
    for value in values:
        matching_sources = {s_id for s_id, candidates in source_values.items() if value in candidates}
        matching_mixids = sorted(
            m_id for m_id, file_ids in mixture_file_ids.items() if not file_ids.isdisjoint(matching_sources)
        )
        if matching_mixids:
            result[value] = matching_mixids

    return result
|
107
|
+
|
108
|
+
|
109
|
+
def get_all_truth_configs_values_from_field(mixdb: MixtureDatabase, field: str) -> list:
    """
    Generate a list of all values corresponding to the given field in truth_configs

    Fields listed in REQUIRED_TRUTH_CONFIG_FIELDS are read from the truth config
    itself; any other field is read from the truth config's 'config' attribute.
    List values are flattened. Truth configs that do not define an optional
    field contribute nothing.

    Return the sorted, de-duplicated list of values.
    """
    from ..config.constants import REQUIRED_TRUTH_CONFIG_FIELDS

    is_required = field in REQUIRED_TRUTH_CONFIG_FIELDS

    collected = set()
    for sources in mixdb.source_files.values():
        for source in sources:
            for truth_config in source.truth_configs.values():
                if is_required:
                    value = getattr(truth_config, field)
                else:
                    value = getattr(truth_config.config, field, None)
                    if value is None:
                        # Field not set for this truth config. Adding None here
                        # would make sorted() fail once mixed with real values.
                        continue
                collected.update(value if isinstance(value, list) else [value])

    return sorted(collected)
|
128
|
+
|
129
|
+
|
130
|
+
def get_mixids_from_noise(
    mixdb: MixtureDatabase,
    mixids: GeneralizedIDs = "*",
    predicate: Callable[[Any], bool] | None = None,
) -> dict[int, list[int]]:
    """
    Generate mixids grouped by the 'noise_id' mixture field, filtered by predicate

    Return a dictionary where:
        - keys are the noise indices
        - values are lists of the mixids that match the noise index
    """
    noise_field = "noise_id"
    return get_mixids_from_mixture_field_predicate(mixdb, noise_field, mixids, predicate)
|
142
|
+
|
143
|
+
|
144
|
+
def get_mixids_from_source(
    mixdb: MixtureDatabase,
    mixids: GeneralizedIDs = "*",
    predicate: Callable[[Any], bool] | None = None,
) -> dict[int, list[int]]:
    """
    Generate mixids grouped by the 'source_ids' mixture field, filtered by predicate

    Return a dictionary where:
        - keys are the source indices
        - values are lists of the mixids that match the source index
    """
    source_field = "source_ids"
    return get_mixids_from_mixture_field_predicate(mixdb, source_field, mixids, predicate)
|
156
|
+
|
157
|
+
|
158
|
+
def get_mixids_from_snr(
    mixdb: MixtureDatabase,
    mixids: GeneralizedIDs = "*",
    predicate: Callable[[Any], bool] | None = None,
) -> dict[float, list[int]]:
    """
    Generate mixids based on an SNR predicate

    Only SNRs from mixdb.all_snrs that are not flagged 'is_random' are considered.

    Return a dictionary where:
        - keys are the SNRs
        - values are sorted lists of the mixids that match the SNR
          (possibly empty when no requested mixture uses that SNR)
    """
    # A set makes the per-mixture membership test O(1) instead of O(len(mixids))
    wanted = set(mixdb.mixids_to_list(mixids))

    if predicate is None:
        predicate = _true_predicate

    # Get all the non-random SNRs, then keep only the SNRs of interest
    snrs = [float(snr) for snr in mixdb.all_snrs if not snr.is_random]
    snrs = [snr for snr in snrs if predicate(snr)]
    if not snrs:
        return {}

    # Read each requested mixture's SNR once rather than once per SNR.
    # enumerate() yields ascending indices, so each resulting list is sorted.
    mixture_snrs = [(i, mixture.noise.snr) for i, mixture in enumerate(mixdb.mixtures) if i in wanted]

    return {snr: [i for i, mixture_snr in mixture_snrs if mixture_snr == snr] for snr in snrs}
|
188
|
+
|
189
|
+
|
190
|
+
def get_mixids_from_class_indices(
    mixdb: MixtureDatabase,
    mixids: GeneralizedIDs = "*",
    predicate: Callable[[Any], bool] | None = None,
) -> dict[int, list[int]]:
    """
    Generate mixids based on a class index predicate

    Return a dictionary where:
        - keys are the class indices (in sorted order)
        - values are lists of the mixids that match the class index
          (in the order given by mixids, each mixid at most once per key)
    """
    mixid_out = mixdb.mixids_to_list(mixids)

    if predicate is None:
        predicate = _true_predicate

    # Single pass: query each mixture's class indices once and record the
    # mixture under every accepted index.
    matches: dict[int, list[int]] = {}
    for m_id in mixid_out:
        seen = set()
        for class_index in mixdb.mixture_class_indices(m_id):
            if class_index in seen or not predicate(class_index):
                continue
            seen.add(class_index)
            matches.setdefault(class_index, []).append(m_id)

    return {class_index: matches[class_index] for class_index in sorted(matches)}
|
224
|
+
|
225
|
+
|
226
|
+
def get_mixids_from_truth_function(
    mixdb: MixtureDatabase,
    mixids: GeneralizedIDs = "*",
    predicate: Callable[[Any], bool] | None = None,
) -> dict[int, list[int]]:
    """
    Generate mixids grouped by the 'function' truth_configs field, filtered by predicate

    Return a dictionary where:
        - keys are the truth functions
        - values are lists of the mixids that match the truth function
    """
    truth_field = "function"
    return get_mixids_from_truth_configs_field_predicate(mixdb, truth_field, mixids, predicate)
|
sonusai/rs.abi3.so
ADDED
Binary file
|
sonusai/rs.pyi
ADDED
@@ -0,0 +1 @@
|
|
1
|
+
__version__: str
|
sonusai/rust/__init__.py
ADDED
File without changes
|
File without changes
|
@@ -0,0 +1,121 @@
|
|
1
|
+
import os
|
2
|
+
import string
|
3
|
+
from pathlib import Path
|
4
|
+
|
5
|
+
from .types import TimeAlignedType
|
6
|
+
|
7
|
+
|
8
|
+
def _get_duration(name: str) -> float:
    """Return the duration that soundfile reports for an audio file.

    :param name: Audio file name
    :return: The 'duration' attribute of soundfile.info(name)
    :raises OSError: If soundfile fails to read the file for any reason
    """
    import soundfile

    try:
        info = soundfile.info(name)
        duration = info.duration
    except Exception as err:
        raise OSError(f"Error reading {name}: {err}") from err

    return duration
|
15
|
+
|
16
|
+
|
17
|
+
def load_text(audio: str | os.PathLike[str]) -> TimeAlignedType | None:
    """Load time-aligned text data given a L2-ARCTIC audio file.

    The transcript is read from '../transcript/<stem>.txt' relative to the
    audio file's directory, lower-cased and stripped of punctuation. It spans
    from 0 to the duration of the audio file.

    :param audio: Path to the L2-ARCTIC audio file.
    :return: A TimeAlignedType object, or None if the transcript file does not exist.
    """
    audio_path = Path(audio)
    transcript = audio_path.parent.parent / "transcript" / f"{audio_path.stem}.txt"
    if not os.path.exists(transcript):
        return None

    with open(transcript, encoding="utf-8") as handle:
        raw_text = handle.read()

    remove_punctuation = str.maketrans("", "", string.punctuation)
    duration = _get_duration(str(audio))
    text = raw_text.strip().lower().translate(remove_punctuation)

    return TimeAlignedType(0, duration, text)
|
35
|
+
|
36
|
+
|
37
|
+
def load_words(audio: str | os.PathLike[str]) -> list[TimeAlignedType] | None:
    """Load time-aligned word data given a L2-ARCTIC audio file.

    Reads the "words" tier of the TextGrid associated with the audio file.

    :param audio: Path to the L2-ARCTIC audio file.
    :return: A list of TimeAlignedType objects, or None if unavailable.
    """
    tier_name = "words"
    return _load_ta(audio, tier_name)
|
44
|
+
|
45
|
+
|
46
|
+
def load_phonemes(audio: str | os.PathLike[str]) -> list[TimeAlignedType] | None:
    """Load time-aligned phonemes data given a L2-ARCTIC audio file.

    Reads the "phones" tier of the TextGrid associated with the audio file.

    :param audio: Path to the L2-ARCTIC audio file.
    :return: A list of TimeAlignedType objects, or None if unavailable.
    """
    tier_name = "phones"
    return _load_ta(audio, tier_name)
|
53
|
+
|
54
|
+
|
55
|
+
def _load_ta(audio: str | os.PathLike[str], tier: str) -> list[TimeAlignedType] | None:
    """Load the interval entries of one TextGrid tier for a L2-ARCTIC audio file.

    The TextGrid is read from '../textgrid/<stem>.TextGrid' relative to the
    audio file's directory. Entries that are not Interval instances are skipped.

    :param audio: Path to the L2-ARCTIC audio file.
    :param tier: Name of the TextGrid tier to load.
    :return: A list of TimeAlignedType objects, or None if the TextGrid file
        or the requested tier does not exist.
    """
    from praatio import textgrid
    from praatio.utilities.constants import Interval

    audio_path = Path(audio)
    grid_file = audio_path.parent.parent / "textgrid" / f"{audio_path.stem}.TextGrid"
    if not os.path.exists(grid_file):
        return None

    tg = textgrid.openTextgrid(str(grid_file), includeEmptyIntervals=False)
    if tier not in tg.tierNames:
        return None

    return [
        TimeAlignedType(text=item.label, start=item.start, end=item.end)
        for item in tg.getTier(tier).entries
        if isinstance(item, Interval)
    ]
|
73
|
+
|
74
|
+
|
75
|
+
def load_annotations(
    audio: str | os.PathLike[str],
) -> dict[str, list[TimeAlignedType]] | None:
    """Load time-aligned annotation data given a L2-ARCTIC audio file.

    The TextGrid is read from '../annotation/<stem>.TextGrid' relative to the
    audio file's directory. Every tier is returned; entries that are not
    Interval instances are skipped.

    :param audio: Path to the L2-ARCTIC audio file.
    :return: A dictionary mapping tier name to a list of TimeAlignedType
        objects, or None if the annotation file does not exist.
    """
    from praatio import textgrid
    from praatio.utilities.constants import Interval

    audio_path = Path(audio)
    annotation_file = audio_path.parent.parent / "annotation" / f"{audio_path.stem}.TextGrid"
    if not os.path.exists(annotation_file):
        return None

    tg = textgrid.openTextgrid(str(annotation_file), includeEmptyIntervals=False)

    return {
        tier_name: [
            TimeAlignedType(text=item.label, start=item.start, end=item.end)
            for item in tg.getTier(tier_name).entries
            if isinstance(item, Interval)
        ]
        for tier_name in tg.tierNames
    }
|
100
|
+
|
101
|
+
|
102
|
+
def load_speakers(input_dir: Path) -> dict:
    """Load speaker metadata from the 'readme-download.txt' table in a directory.

    Rows are read from the line after the first one starting with "|---|" up
    to (not including) the line starting with "|**Total**|". Each row is split
    on "|"; the first three cells are taken as speaker ID, gender and dialect.
    Blank rows and rows with too few cells are skipped.

    :param input_dir: Directory containing 'readme-download.txt'
    :return: Dictionary mapping speaker ID to {"gender": ..., "dialect": ...}
    """
    speakers = {}
    # Explicit encoding so the result does not depend on the platform default;
    # undecodable bytes are replaced because only the table cells are used.
    with open(Path(input_dir) / "readme-download.txt", encoding="utf-8", errors="replace") as file:
        in_table = False
        for line in file:
            if not in_table:
                # The separator row marks the start of the table body
                in_table = line.startswith("|---|")
                continue

            if line.startswith("|**Total**|"):
                break

            # A row looks like "|id|gender|dialect|..."; cell 0 is the empty
            # text before the leading "|".
            fields = [cell.strip() for cell in line.strip().split("|")]
            if len(fields) < 4 or not fields[1]:
                continue

            speakers[fields[1]] = {"gender": fields[2], "dialect": fields[3]}

    return speakers
|
@@ -0,0 +1,102 @@
|
|
1
|
+
import os
|
2
|
+
from pathlib import Path
|
3
|
+
|
4
|
+
from .types import TimeAlignedType
|
5
|
+
|
6
|
+
|
7
|
+
def _get_num_samples(audio: str | os.PathLike[str]) -> int:
    """Get number of samples from audio file

    Files with a '.mp3' or '.m4a' suffix are decoded with pydub; every other
    file is inspected with soundfile.

    :param audio: Audio file name
    :return: Number of samples
    """
    suffix = Path(audio).suffix

    if suffix in (".mp3", ".m4a"):
        # Imported lazily so pydub is only needed for formats that use it
        from pydub import AudioSegment

        segment = AudioSegment.from_mp3(audio) if suffix == ".mp3" else AudioSegment.from_file(audio)
        # frame_count() returns a float; convert to honor the declared int return
        return int(segment.frame_count())

    import soundfile

    return soundfile.info(audio).frames
|
23
|
+
|
24
|
+
|
25
|
+
def load_text(audio: str | os.PathLike[str]) -> TimeAlignedType | None:
    """Load text data from a LibriSpeech transcription file given a LibriSpeech audio filename.

    The transcription file is expected next to the audio file and named
    '<grandparent dir>-<parent dir>.trans.txt'. Each line holds an utterance
    key followed by its text; the line whose key equals the audio file's stem
    is used. The text is lower-cased and stripped of punctuation and spans
    from 0 to the audio length (number of samples / sample rate).

    :param audio: Path to the LibriSpeech audio file.
    :return: A TimeAlignedType object, or None if the transcription file does
        not exist or has no entry for this audio file.
    """
    import string

    from ..mixture.audio import get_sample_rate

    path = Path(audio)
    name = path.stem
    transcript_filename = path.parent / f"{path.parent.parent.name}-{path.parent.name}.trans.txt"

    if not os.path.exists(transcript_filename):
        return None

    with open(transcript_filename, encoding="utf-8") as f:
        # Iterate lazily; stop at the first matching key
        for line in f:
            fields = line.split()
            # Skip blank lines (no key to compare) and non-matching entries
            if not fields or fields[0] != name:
                continue

            text = " ".join(fields[1:]).lower().translate(str.maketrans("", "", string.punctuation))
            return TimeAlignedType(0, _get_num_samples(audio) / get_sample_rate(str(audio)), text)

    return None
|
51
|
+
|
52
|
+
|
53
|
+
def load_words(audio: str | os.PathLike[str]) -> list[TimeAlignedType] | None:
    """Load time-aligned word data given a LibriSpeech audio file.

    Reads the "words" tier of the TextGrid associated with the audio file.

    :param audio: Path to the LibriSpeech audio file.
    :return: A list of TimeAlignedType objects, or None if unavailable.
    """
    tier_name = "words"
    return _load_ta(audio, tier_name)
|
60
|
+
|
61
|
+
|
62
|
+
def load_phonemes(audio: str | os.PathLike[str]) -> list[TimeAlignedType] | None:
    """Load time-aligned phonemes data given a LibriSpeech audio file.

    Reads the "phones" tier of the TextGrid associated with the audio file.

    :param audio: Path to the LibriSpeech audio file.
    :return: A list of TimeAlignedType objects, or None if unavailable.
    """
    tier_name = "phones"
    return _load_ta(audio, tier_name)
|
69
|
+
|
70
|
+
|
71
|
+
def _load_ta(audio: str | os.PathLike[str], tier: str) -> list[TimeAlignedType] | None:
    """Load the entries of one TextGrid tier for a LibriSpeech audio file.

    The TextGrid is expected next to the audio file with the suffix replaced
    by '.TextGrid'. Interval entries keep their start and end; any other entry
    uses its 'time' for both start and end.

    :param audio: Path to the LibriSpeech audio file.
    :param tier: Name of the TextGrid tier to load.
    :return: A list of TimeAlignedType objects, or None if the TextGrid file
        or the requested tier does not exist.
    """
    from praatio import textgrid
    from praatio.utilities.constants import Interval

    grid_file = Path(audio).with_suffix(".TextGrid")
    if not os.path.exists(grid_file):
        return None

    tg = textgrid.openTextgrid(str(grid_file), includeEmptyIntervals=False)
    if tier not in tg.tierNames:
        return None

    def to_time_aligned(item) -> TimeAlignedType:
        # Intervals span a range; other entries are a single instant
        if isinstance(item, Interval):
            return TimeAlignedType(text=item.label, start=item.start, end=item.end)
        return TimeAlignedType(text=item.label, start=item.time, end=item.time)

    return [to_time_aligned(item) for item in tg.getTier(tier).entries]
|
91
|
+
|
92
|
+
|
93
|
+
def load_speakers(input_dir: Path) -> dict:
    """Load speaker metadata from the 'SPEAKERS.TXT' file in a directory.

    Lines starting with ";" are treated as comments. Every other line is split
    on "|"; the first two cells are taken as speaker ID and gender. Blank lines
    and lines with fewer than two cells are skipped.

    :param input_dir: Directory containing 'SPEAKERS.TXT'
    :return: Dictionary mapping speaker ID to {"gender": ...}
    """
    speakers = {}
    # Explicit encoding so the result does not depend on the platform default;
    # undecodable bytes are replaced because only the first two cells are used.
    with open(Path(input_dir) / "SPEAKERS.TXT", encoding="utf-8", errors="replace") as file:
        for line in file:
            if line.startswith(";"):
                continue

            fields = [cell.strip() for cell in line.split("|")]
            if len(fields) < 2 or not fields[0]:
                continue

            speakers[fields[0]] = {"gender": fields[1]}

    return speakers
|