sonusai 1.0.16__cp311-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. sonusai/__init__.py +170 -0
  2. sonusai/aawscd_probwrite.py +148 -0
  3. sonusai/audiofe.py +481 -0
  4. sonusai/calc_metric_spenh.py +1136 -0
  5. sonusai/config/__init__.py +0 -0
  6. sonusai/config/asr.py +21 -0
  7. sonusai/config/config.py +65 -0
  8. sonusai/config/config.yml +49 -0
  9. sonusai/config/constants.py +53 -0
  10. sonusai/config/ir.py +124 -0
  11. sonusai/config/ir_delay.py +62 -0
  12. sonusai/config/source.py +275 -0
  13. sonusai/config/spectral_masks.py +15 -0
  14. sonusai/config/truth.py +64 -0
  15. sonusai/constants.py +14 -0
  16. sonusai/data/__init__.py +0 -0
  17. sonusai/data/silero_vad_v5.1.jit +0 -0
  18. sonusai/data/silero_vad_v5.1.onnx +0 -0
  19. sonusai/data/speech_ma01_01.wav +0 -0
  20. sonusai/data/whitenoise.wav +0 -0
  21. sonusai/datatypes.py +383 -0
  22. sonusai/deprecated/gentcst.py +632 -0
  23. sonusai/deprecated/plot.py +519 -0
  24. sonusai/deprecated/tplot.py +365 -0
  25. sonusai/doc.py +52 -0
  26. sonusai/doc_strings/__init__.py +1 -0
  27. sonusai/doc_strings/doc_strings.py +531 -0
  28. sonusai/genft.py +196 -0
  29. sonusai/genmetrics.py +183 -0
  30. sonusai/genmix.py +199 -0
  31. sonusai/genmixdb.py +235 -0
  32. sonusai/ir_metric.py +551 -0
  33. sonusai/lsdb.py +141 -0
  34. sonusai/main.py +134 -0
  35. sonusai/metrics/__init__.py +43 -0
  36. sonusai/metrics/calc_audio_stats.py +42 -0
  37. sonusai/metrics/calc_class_weights.py +90 -0
  38. sonusai/metrics/calc_optimal_thresholds.py +73 -0
  39. sonusai/metrics/calc_pcm.py +45 -0
  40. sonusai/metrics/calc_pesq.py +36 -0
  41. sonusai/metrics/calc_phase_distance.py +43 -0
  42. sonusai/metrics/calc_sa_sdr.py +64 -0
  43. sonusai/metrics/calc_sample_weights.py +25 -0
  44. sonusai/metrics/calc_segsnr_f.py +82 -0
  45. sonusai/metrics/calc_speech.py +382 -0
  46. sonusai/metrics/calc_wer.py +71 -0
  47. sonusai/metrics/calc_wsdr.py +57 -0
  48. sonusai/metrics/calculate_metrics.py +395 -0
  49. sonusai/metrics/class_summary.py +74 -0
  50. sonusai/metrics/confusion_matrix_summary.py +75 -0
  51. sonusai/metrics/one_hot.py +283 -0
  52. sonusai/metrics/snr_summary.py +128 -0
  53. sonusai/metrics_summary.py +314 -0
  54. sonusai/mixture/__init__.py +15 -0
  55. sonusai/mixture/audio.py +187 -0
  56. sonusai/mixture/class_balancing.py +103 -0
  57. sonusai/mixture/constants.py +3 -0
  58. sonusai/mixture/data_io.py +173 -0
  59. sonusai/mixture/db.py +169 -0
  60. sonusai/mixture/db_datatypes.py +92 -0
  61. sonusai/mixture/effects.py +344 -0
  62. sonusai/mixture/feature.py +78 -0
  63. sonusai/mixture/generation.py +1116 -0
  64. sonusai/mixture/helpers.py +351 -0
  65. sonusai/mixture/ir_effects.py +77 -0
  66. sonusai/mixture/log_duration_and_sizes.py +23 -0
  67. sonusai/mixture/mixdb.py +1857 -0
  68. sonusai/mixture/pad_audio.py +35 -0
  69. sonusai/mixture/resample.py +7 -0
  70. sonusai/mixture/sox_effects.py +195 -0
  71. sonusai/mixture/sox_help.py +650 -0
  72. sonusai/mixture/spectral_mask.py +51 -0
  73. sonusai/mixture/truth.py +61 -0
  74. sonusai/mixture/truth_functions/__init__.py +45 -0
  75. sonusai/mixture/truth_functions/crm.py +105 -0
  76. sonusai/mixture/truth_functions/energy.py +222 -0
  77. sonusai/mixture/truth_functions/file.py +48 -0
  78. sonusai/mixture/truth_functions/metadata.py +24 -0
  79. sonusai/mixture/truth_functions/metrics.py +28 -0
  80. sonusai/mixture/truth_functions/phoneme.py +18 -0
  81. sonusai/mixture/truth_functions/sed.py +98 -0
  82. sonusai/mixture/truth_functions/target.py +142 -0
  83. sonusai/mkwav.py +135 -0
  84. sonusai/onnx_predict.py +363 -0
  85. sonusai/parse/__init__.py +0 -0
  86. sonusai/parse/expand.py +156 -0
  87. sonusai/parse/parse_source_directive.py +129 -0
  88. sonusai/parse/rand.py +214 -0
  89. sonusai/py.typed +0 -0
  90. sonusai/queries/__init__.py +0 -0
  91. sonusai/queries/queries.py +239 -0
  92. sonusai/rs.abi3.so +0 -0
  93. sonusai/rs.pyi +1 -0
  94. sonusai/rust/__init__.py +0 -0
  95. sonusai/speech/__init__.py +0 -0
  96. sonusai/speech/l2arctic.py +121 -0
  97. sonusai/speech/librispeech.py +102 -0
  98. sonusai/speech/mcgill.py +71 -0
  99. sonusai/speech/textgrid.py +89 -0
  100. sonusai/speech/timit.py +138 -0
  101. sonusai/speech/types.py +12 -0
  102. sonusai/speech/vctk.py +53 -0
  103. sonusai/speech/voxceleb.py +108 -0
  104. sonusai/utils/__init__.py +3 -0
  105. sonusai/utils/asl_p56.py +130 -0
  106. sonusai/utils/asr.py +91 -0
  107. sonusai/utils/asr_functions/__init__.py +3 -0
  108. sonusai/utils/asr_functions/aaware_whisper.py +69 -0
  109. sonusai/utils/audio_devices.py +50 -0
  110. sonusai/utils/braced_glob.py +50 -0
  111. sonusai/utils/calculate_input_shape.py +26 -0
  112. sonusai/utils/choice.py +51 -0
  113. sonusai/utils/compress.py +25 -0
  114. sonusai/utils/convert_string_to_number.py +6 -0
  115. sonusai/utils/create_timestamp.py +5 -0
  116. sonusai/utils/create_ts_name.py +14 -0
  117. sonusai/utils/dataclass_from_dict.py +27 -0
  118. sonusai/utils/db.py +16 -0
  119. sonusai/utils/docstring.py +53 -0
  120. sonusai/utils/energy_f.py +44 -0
  121. sonusai/utils/engineering_number.py +166 -0
  122. sonusai/utils/evaluate_random_rule.py +15 -0
  123. sonusai/utils/get_frames_per_batch.py +2 -0
  124. sonusai/utils/get_label_names.py +20 -0
  125. sonusai/utils/grouper.py +6 -0
  126. sonusai/utils/human_readable_size.py +7 -0
  127. sonusai/utils/keyboard_interrupt.py +12 -0
  128. sonusai/utils/load_object.py +21 -0
  129. sonusai/utils/max_text_width.py +9 -0
  130. sonusai/utils/model_utils.py +28 -0
  131. sonusai/utils/numeric_conversion.py +11 -0
  132. sonusai/utils/onnx_utils.py +155 -0
  133. sonusai/utils/parallel.py +162 -0
  134. sonusai/utils/path_info.py +7 -0
  135. sonusai/utils/print_mixture_details.py +60 -0
  136. sonusai/utils/rand.py +13 -0
  137. sonusai/utils/ranges.py +43 -0
  138. sonusai/utils/read_predict_data.py +32 -0
  139. sonusai/utils/reshape.py +154 -0
  140. sonusai/utils/seconds_to_hms.py +7 -0
  141. sonusai/utils/stacked_complex.py +82 -0
  142. sonusai/utils/stratified_shuffle_split.py +170 -0
  143. sonusai/utils/tokenized_shell_vars.py +143 -0
  144. sonusai/utils/write_audio.py +26 -0
  145. sonusai/utils/yes_or_no.py +8 -0
  146. sonusai/vars.py +47 -0
  147. sonusai-1.0.16.dist-info/METADATA +56 -0
  148. sonusai-1.0.16.dist-info/RECORD +150 -0
  149. sonusai-1.0.16.dist-info/WHEEL +4 -0
  150. sonusai-1.0.16.dist-info/entry_points.txt +3 -0
sonusai/parse/rand.py ADDED
@@ -0,0 +1,214 @@
1
+ """
2
+ Parse 'rand' expressions.
3
+
4
+ """
5
+
6
+ import decimal
7
+ import re
8
+ from random import uniform
9
+
10
+ import pyparsing as pp
11
+
12
# Precision (number of significant digits) used when formatting each generated random value.
SIGNIFICANT_DIGITS: int = 6
13
+
14
+
15
def _replace_with_random(tokens: "pp.ParseResults") -> str:
    """Parse action: replace one parsed ``rand(min, max)`` with a uniformly drawn value.

    ``tokens["min_val"]`` / ``tokens["max_val"]`` are either numbers (from
    ``pyparsing_common.number``) or strings already produced by this action for nested
    ``rand`` expressions; both are converted with ``float``.

    :raises ValueError: If the minimum is greater than the maximum.
    """
    min_val = float(tokens["min_val"])
    max_val = float(tokens["max_val"])

    if min_val > max_val:
        raise ValueError(f"Min value ({min_val}) cannot be greater than max value ({max_val})")

    value = uniform(min_val, max_val)  # noqa: S311
    # Use a private context so the caller's (thread-local) decimal context is not modified.
    context = decimal.Context(prec=SIGNIFICANT_DIGITS)
    return str(decimal.Decimal(value).normalize(context))


def _rand_grammar(parse_action=None) -> "pp.ParserElement":
    """Build the recursive ``rand(min, max)`` grammar and return its ``rand(...)`` element.

    Each argument is a number or another ``rand(...)`` expression, with optional whitespace
    around it. When ``parse_action`` is given it is attached to the ``rand(...)`` element;
    inner expressions complete (and so are evaluated) before the enclosing one.
    """

    def optional_space() -> "pp.ParserElement":
        return pp.Optional(pp.White()).suppress()

    nested = pp.Forward()
    number = pp.pyparsing_common.number

    rand_function = (
        pp.Literal("rand")
        + pp.Literal("(").suppress()
        + optional_space()
        + (number | nested)("min_val")
        + optional_space()
        + pp.Literal(",").suppress()
        + optional_space()
        + (number | nested)("max_val")
        + optional_space()
        + pp.Literal(")").suppress()
    )
    if parse_action is not None:
        rand_function.setParseAction(parse_action)

    # Close the recursion: nested arguments are themselves rand(...) expressions.
    nested <<= rand_function
    return rand_function


def _matching_paren(text: str, open_pos: int) -> int | None:
    """Return the index of the ")" closing the "(" at ``open_pos``, or None if it is never closed."""
    depth = 0
    for pos in range(open_pos, len(text)):
        char = text[pos]
        if char == "(":
            depth += 1
        elif char == ")":
            depth -= 1
            if depth == 0:
                return pos
    return None


def _rand_candidates(directive: str) -> list[str]:
    """Return every substring of ``directive`` that looks like a (possibly malformed) 'rand' expression.

    Scanning left to right, a candidate is "rand", optional whitespace, an optional "(", a run of
    text whose parenthesis groups are balanced (to any depth), and an optional closing ")".
    Candidates do not overlap. Missing parentheses are left out of the text so they can be reported.

    NOTE(review): like the grammar itself, this matches "rand" anywhere, including inside
    words such as "random".
    """
    candidates = []
    length = len(directive)
    start = directive.find("rand")
    while start != -1:
        end = start + len("rand")
        while end < length and directive[end].isspace():
            end += 1
        if end < length and directive[end] == "(":
            end += 1
        # Consume text up to an unmatched ")"; stop before a "(" that is never closed.
        while end < length and directive[end] != ")":
            if directive[end] == "(":
                close = _matching_paren(directive, end)
                if close is None:
                    break
                end = close + 1
            else:
                end += 1
        if end < length and directive[end] == ")":
            end += 1
        candidates.append(directive[start:end])
        start = directive.find("rand", end)
    return candidates


def _diagnose_rand_candidate(expr_text: str, validator: "pp.ParserElement") -> list[str]:
    """Return the problems found in one candidate 'rand' expression (empty if it is well-formed)."""
    # Every candidate starts with "rand", so no "(" at all means the argument list never opened.
    if "(" not in expr_text:
        return [f"Missing opening parenthesis in '{expr_text}'"]

    if not expr_text.endswith(")"):
        return [f"Missing closing parenthesis in '{expr_text}'"]

    try:
        validator.parseString(expr_text, parseAll=True)
    except pp.ParseException:
        pass
    else:
        return []

    # The structure is wrong; work out why from the text between the outer parentheses.
    param_text = expr_text[expr_text.find("(") + 1 : expr_text.rfind(")")]

    # Count only top-level commas (those not inside nested parentheses).
    nesting_level = 0
    comma_count = 0
    for char in param_text:
        if char == "(":
            nesting_level += 1
        elif char == ")":
            nesting_level -= 1
        elif char == "," and nesting_level == 0:
            comma_count += 1

    if comma_count == 0:
        if not param_text.strip():
            return [f"Missing parameters in '{expr_text}' (expected 2)"]
        # Two numbers separated only by whitespace suggests a forgotten comma.
        if re.search(r"\d+\s+[-+]?\d+", param_text):
            return [f"Missing comma between parameters in '{expr_text}'"]
        return [f"Too few parameters in '{expr_text}' (expected 2, got 1)"]

    if comma_count > 1:
        return [f"Too many parameters in '{expr_text}' (expected 2, got {comma_count + 1})"]

    # Exactly two parameters, yet the parse failed: report each parameter that is neither a
    # number nor a valid nested 'rand' expression.
    problems = []
    params = [p.strip() for p in split_params_respecting_nesting(param_text)]
    for i, param in enumerate(params):
        if "rand" in param:
            try:
                validator.parseString(param, parseAll=True)
            except pp.ParseException:
                problems.append(f"Invalid nested expression '{param}' in '{expr_text}'")
            # A valid nested expression is an acceptable parameter; it is not expected to be numeric.
            continue

        if not is_numeric(param):
            param_name = "first" if i == 0 else "second"
            problems.append(f"Non-numeric {param_name} parameter '{param}' in '{expr_text}'")

    return problems


def rand(directive: str) -> str:
    """Evaluate every 'rand(min, max)' directive in a string after validating its syntax.

    Each ``rand(min, max)`` is replaced by a value drawn with ``random.uniform(min, max)``,
    rounded to at most SIGNIFICANT_DIGITS significant digits. Arguments may be numbers or
    nested ``rand(...)`` expressions (to any depth), which are evaluated first.

    :param directive: Text that may contain 'rand' directives
    :return: The text with all 'rand' directives replaced by random values. An empty or None
             ``directive`` is returned unchanged.
    :raises ValueError: If an expression is malformed, cannot be parsed, or has min > max.
    """
    if not directive:
        return directive

    rand_function = _rand_grammar(_replace_with_random)
    # Same grammar without the parse action: validates structure without drawing values.
    validator = _rand_grammar()

    # Validate every candidate first so structural problems are reported together,
    # before any value is generated.
    malformations: list[str] = []
    for expr_text in _rand_candidates(directive):
        malformations.extend(_diagnose_rand_candidate(expr_text, validator))

    if malformations:
        raise ValueError(f"Malformed rand directive: {'; '.join(malformations)}")

    try:
        return rand_function.transformString(directive)
    except pp.ParseException as e:
        raise ValueError(f"Invalid rand expression in '{directive}': {e!s}") from e
184
+
185
+
186
def is_numeric(text: str) -> bool:
    """Return True when *text* is a plain or scientific-notation decimal number.

    Accepts an optional sign, digits with an optional fractional part (or a leading-dot
    fraction), and an optional exponent such as ``e-3``.
    """
    match = re.match(r"^[-+]?(\d+(\.\d*)?|\.\d+)([eE][-+]?\d+)?$", text)
    return match is not None
190
+
191
+
192
def split_params_respecting_nesting(param_text: str) -> list:
    """Split *param_text* on top-level commas, leaving commas inside parentheses alone.

    An unmatched ")" is kept as ordinary text and does not lower the nesting depth.
    A trailing empty segment (after a final comma, or for empty input) is not returned.
    """
    pieces = []
    depth = 0
    start = 0

    for pos, char in enumerate(param_text):
        if char == "(":
            depth += 1
        elif char == ")":
            if depth > 0:
                depth -= 1
        elif char == "," and depth == 0:
            pieces.append(param_text[start:pos])
            start = pos + 1

    tail = param_text[start:]
    if tail:
        pieces.append(tail)

    return pieces
sonusai/py.typed ADDED
File without changes
File without changes
@@ -0,0 +1,239 @@
1
+ from collections.abc import Callable
2
+ from typing import Any
3
+
4
+ from ..datatypes import GeneralizedIDs
5
+ from ..mixture.mixdb import MixtureDatabase
6
+
7
+
8
+ def _true_predicate(_: Any) -> bool:
9
+ return True
10
+
11
+
12
def get_mixids_from_mixture_field_predicate(
    mixdb: MixtureDatabase,
    field: str,
    mixids: GeneralizedIDs = "*",
    predicate: Callable[[Any], bool] | None = None,
) -> dict[int, list[int]]:
    """Group mixture IDs by the value of a mixture field.

    ``getattr(mixture, field)`` is read for every selected mixture. When that value is a dict,
    each of its values is considered separately; otherwise the value itself is used. Only
    values for which ``predicate`` returns True are kept (all values when it is None).

    :param mixdb: Mixture database
    :param field: Name of the mixture attribute to group by
    :param mixids: Mixture IDs to consider
    :param predicate: Optional filter applied to each field value
    :return: Dictionary where:
             - keys are the matching field values (in sorted order)
             - values are lists of the mixids that match the criteria (in ``mixids`` order)
    """
    mixid_out = mixdb.mixids_to_list(mixids)

    if predicate is None:
        predicate = _true_predicate

    # Single pass: each mixture is fetched once (not once per distinct value), and it is
    # listed at most once per value even if a dict field holds that value under several keys.
    groups: dict[Any, list[int]] = {}
    for m_id in mixid_out:
        value = getattr(mixdb.mixture(m_id), field)
        candidates = value.values() if isinstance(value, dict) else (value,)
        matched = set()
        for v in candidates:
            if predicate(v) and v not in matched:
                matched.add(v)
                groups.setdefault(v, []).append(m_id)

    return {key: groups[key] for key in sorted(groups)}
53
+
54
+
55
def get_mixids_from_truth_configs_field_predicate(
    mixdb: MixtureDatabase,
    field: str,
    mixids: GeneralizedIDs = "*",
    predicate: Callable[[Any], bool] | None = None,
) -> dict[int, list[int]]:
    """Group mixture IDs by a value of a field in the sources' truth_configs.

    Fields listed in REQUIRED_TRUTH_CONFIG_FIELDS are read from the truth config itself;
    any other field is read from ``truth_config.config``. A mixture matches a value when any
    of its sources has a truth config whose field equals that value (or, for list-valued
    fields, contains it).

    :param mixdb: Mixture database
    :param field: Truth config field to group by
    :param mixids: Mixture IDs to consider
    :param predicate: Optional filter applied to each field value
    :return: Dictionary where:
             - keys are the matching field values
             - values are sorted lists of the mixids that match the criteria
             Values with no matching mixture are omitted.
    """
    from ..config.constants import REQUIRED_TRUTH_CONFIG_FIELDS

    mixid_out = mixdb.mixids_to_list(mixids)

    # Get all field values
    all_values = get_all_truth_configs_values_from_field(mixdb, field)

    if predicate is None:
        predicate = _true_predicate

    # Get only values of interest
    values = [value for value in all_values if predicate(value)]
    if not values:
        return {}

    def field_values(truth_config: Any) -> list:
        """Return the field's value(s) for one truth config, always as a list."""
        if field in REQUIRED_TRUTH_CONFIG_FIELDS:
            raw = getattr(truth_config, field)
        else:
            raw = getattr(truth_config.config, field, None)
        # Wrap scalars (including strings) so membership below is an exact match,
        # not a substring test such as "snr_f" in "mapped_snr_f".
        return raw if isinstance(raw, list) else [raw]

    # Source file ID -> every value of the field across that source's truth configs
    # (each source file is fetched once).
    source_values: dict[int, list] = {}
    for s_ids in mixdb.source_file_ids.values():
        for s_id in s_ids:
            found = source_values.setdefault(s_id, [])
            for truth_config in mixdb.source_file(s_id).truth_configs.values():
                found.extend(field_values(truth_config))

    # Mixture ID -> file IDs of all sources in that mixture (each mixture is fetched once).
    mixture_file_ids = {
        m_id: {source.file_id for source in mixdb.mixture(m_id).all_sources.values()} for m_id in mixid_out
    }

    result = {}
    for value in values:
        indices = {s_id for s_id, found in source_values.items() if value in found}
        matching = sorted(m_id for m_id, file_ids in mixture_file_ids.items() if not file_ids.isdisjoint(indices))
        if matching:
            result[value] = matching

    return result
107
+
108
+
109
def get_all_truth_configs_values_from_field(mixdb: MixtureDatabase, field: str) -> list:
    """Generate a sorted list of all distinct values of the given field in truth_configs.

    Fields listed in REQUIRED_TRUTH_CONFIG_FIELDS are read from the truth config itself;
    any other field is read from ``truth_config.config``. Truth configs whose ``config``
    lacks the field are skipped rather than contributing ``None``. List-valued fields
    contribute each of their elements.

    :param mixdb: Mixture database
    :param field: Truth config field to collect
    :return: Sorted list of distinct values; an explicit ``None`` value, if present, sorts first
    """
    from ..config.constants import REQUIRED_TRUTH_CONFIG_FIELDS

    missing = object()
    result = []
    for sources in mixdb.source_files.values():
        for source in sources:
            for truth_config in source.truth_configs.values():
                if field in REQUIRED_TRUTH_CONFIG_FIELDS:
                    value = getattr(truth_config, field)
                else:
                    value = getattr(truth_config.config, field, missing)
                    if value is missing:
                        # This truth config does not use the field, so it contributes no value.
                        continue
                if not isinstance(value, list):
                    value = [value]
                result.extend(value)

    # None-safe ordering: a None value sorts first instead of making the comparison raise TypeError.
    return sorted(set(result), key=lambda v: (v is not None, v))
128
+
129
+
130
def get_mixids_from_noise(
    mixdb: MixtureDatabase,
    mixids: GeneralizedIDs = "*",
    predicate: Callable[[Any], bool] | None = None,
) -> dict[int, list[int]]:
    """Group mixture IDs by noise index.

    Delegates to ``get_mixids_from_mixture_field_predicate`` using the mixture field ``noise_id``.

    :param mixdb: Mixture database
    :param mixids: Mixture IDs to consider
    :param predicate: Optional filter applied to each noise index
    :return: Dictionary keyed by noise index whose values list the matching mixids
    """
    noise_field = "noise_id"
    return get_mixids_from_mixture_field_predicate(mixdb, noise_field, mixids, predicate)
142
+
143
+
144
def get_mixids_from_source(
    mixdb: MixtureDatabase,
    mixids: GeneralizedIDs = "*",
    predicate: Callable[[Any], bool] | None = None,
) -> dict[int, list[int]]:
    """Group mixture IDs by source index.

    Delegates to ``get_mixids_from_mixture_field_predicate`` using the mixture field ``source_ids``.

    :param mixdb: Mixture database
    :param mixids: Mixture IDs to consider
    :param predicate: Optional filter applied to each source index
    :return: Dictionary keyed by source index whose values list the matching mixids
    """
    source_field = "source_ids"
    return get_mixids_from_mixture_field_predicate(mixdb, source_field, mixids, predicate)
156
+
157
+
158
def get_mixids_from_snr(
    mixdb: MixtureDatabase,
    mixids: GeneralizedIDs = "*",
    predicate: Callable[[Any], bool] | None = None,
) -> dict[float, list[int]]:
    """Group mixture IDs by (non-random) SNR.

    :param mixdb: Mixture database
    :param mixids: Mixture IDs to consider
    :param predicate: Optional filter applied to each SNR (as a float)
    :return: Dictionary where:
             - keys are the SNRs, in ``mixdb.all_snrs`` order
             - values are ascending lists of the mixids whose noise SNR equals the key
    """
    mixid_out = mixdb.mixids_to_list(mixids)

    # Get all the SNRs, skipping random ones
    snrs = [float(snr) for snr in mixdb.all_snrs if not snr.is_random]

    if predicate is None:
        predicate = _true_predicate

    # Get only the SNRs of interest (filter on predicate)
    snrs = [snr for snr in snrs if predicate(snr)]
    if not snrs:
        return {}

    # Use a set for membership and read the mixtures once, instead of rescanning them
    # (with list membership tests) for every SNR.
    wanted = set(mixid_out)
    selected = [(m_id, mixture) for m_id, mixture in enumerate(mixdb.mixtures) if m_id in wanted]

    return {snr: [m_id for m_id, mixture in selected if mixture.noise.snr == snr] for snr in snrs}
188
+
189
+
190
def get_mixids_from_class_indices(
    mixdb: MixtureDatabase,
    mixids: GeneralizedIDs = "*",
    predicate: Callable[[Any], bool] | None = None,
) -> dict[int, list[int]]:
    """Group mixture IDs by class index.

    :param mixdb: Mixture database
    :param mixids: Mixture IDs to consider
    :param predicate: Optional filter applied to each class index
    :return: Dictionary where:
             - keys are the class indices (in sorted order)
             - values are lists of the mixids that match the class index (in ``mixids`` order)
    """
    mixid_out = mixdb.mixids_to_list(mixids)

    if predicate is None:
        predicate = _true_predicate

    # Single pass: class indices are looked up once per mixture (not once per class index),
    # and a mixture is listed at most once per class index.
    groups: dict[int, list[int]] = {}
    for m_id in mixid_out:
        matched = set()
        for class_index in mixdb.mixture_class_indices(m_id):
            if predicate(class_index) and class_index not in matched:
                matched.add(class_index)
                groups.setdefault(class_index, []).append(m_id)

    return {key: groups[key] for key in sorted(groups)}
224
+
225
+
226
def get_mixids_from_truth_function(
    mixdb: MixtureDatabase,
    mixids: GeneralizedIDs = "*",
    predicate: Callable[[Any], bool] | None = None,
) -> dict[int, list[int]]:
    """Group mixture IDs by truth function.

    Delegates to ``get_mixids_from_truth_configs_field_predicate`` using the truth config
    field ``function``.

    :param mixdb: Mixture database
    :param mixids: Mixture IDs to consider
    :param predicate: Optional filter applied to each truth function value
    :return: Dictionary keyed by truth function whose values list the matching mixids
    """
    function_field = "function"
    return get_mixids_from_truth_configs_field_predicate(mixdb, function_field, mixids, predicate)
sonusai/rs.abi3.so ADDED
Binary file
sonusai/rs.pyi ADDED
@@ -0,0 +1 @@
1
+ __version__: str
File without changes
File without changes
@@ -0,0 +1,121 @@
1
+ import os
2
+ import string
3
+ from pathlib import Path
4
+
5
+ from .types import TimeAlignedType
6
+
7
+
8
+ def _get_duration(name: str) -> float:
9
+ import soundfile
10
+
11
+ try:
12
+ return soundfile.info(name).duration
13
+ except Exception as e:
14
+ raise OSError(f"Error reading {name}: {e}") from e
15
+
16
+
17
def load_text(audio: str | os.PathLike[str]) -> TimeAlignedType | None:
    """Load the transcript of an L2-ARCTIC utterance as a single time-aligned span.

    The transcript is looked up at ``<audio dir>/../transcript/<audio stem>.txt``.

    :param audio: Path to the L2-ARCTIC audio file.
    :return: A TimeAlignedType covering 0 to the file's duration, whose text is the transcript
             lower-cased with punctuation removed, or None if no transcript file exists.
    """
    audio_path = Path(audio)
    transcript = audio_path.parent.parent / "transcript" / f"{audio_path.stem}.txt"
    if not transcript.exists():
        return None

    raw = transcript.read_text(encoding="utf-8")
    strip_punctuation = str.maketrans("", "", string.punctuation)
    text = raw.strip().lower().translate(strip_punctuation)

    return TimeAlignedType(0, _get_duration(str(audio)), text)
35
+
36
+
37
def load_words(audio: str | os.PathLike[str]) -> list[TimeAlignedType] | None:
    """Return the word-level alignment of an L2-ARCTIC audio file.

    :param audio: Path to the L2-ARCTIC audio file.
    :return: TimeAlignedType entries from the ``words`` tier of the matching TextGrid,
             or None when the TextGrid or the tier is missing.
    """
    tier_name = "words"
    return _load_ta(audio, tier_name)
44
+
45
+
46
def load_phonemes(audio: str | os.PathLike[str]) -> list[TimeAlignedType] | None:
    """Return the phoneme-level alignment of an L2-ARCTIC audio file.

    :param audio: Path to the L2-ARCTIC audio file.
    :return: TimeAlignedType entries from the ``phones`` tier of the matching TextGrid,
             or None when the TextGrid or the tier is missing.
    """
    tier_name = "phones"
    return _load_ta(audio, tier_name)
53
+
54
+
55
def _load_ta(audio: str | os.PathLike[str], tier: str) -> list[TimeAlignedType] | None:
    """Read one tier of the L2-ARCTIC TextGrid that accompanies ``audio``.

    The TextGrid is expected at ``<audio dir>/../textgrid/<audio stem>.TextGrid``.
    Only interval entries are returned; other entry kinds are skipped.

    :param audio: Path to the L2-ARCTIC audio file.
    :param tier: Name of the TextGrid tier to read.
    :return: A list of TimeAlignedType objects, or None if the file or tier is missing.
    """
    from praatio import textgrid
    from praatio.utilities.constants import Interval

    audio_path = Path(audio)
    grid_file = audio_path.parent.parent / "textgrid" / f"{audio_path.stem}.TextGrid"
    if not grid_file.exists():
        return None

    grid = textgrid.openTextgrid(str(grid_file), includeEmptyIntervals=False)
    if tier not in grid.tierNames:
        return None

    return [
        TimeAlignedType(text=item.label, start=item.start, end=item.end)
        for item in grid.getTier(tier).entries
        if isinstance(item, Interval)
    ]
73
+
74
+
75
def load_annotations(
    audio: str | os.PathLike[str],
) -> dict[str, list[TimeAlignedType]] | None:
    """Load every tier of the L2-ARCTIC annotation TextGrid for an audio file.

    The TextGrid is expected at ``<audio dir>/../annotation/<audio stem>.TextGrid``.
    Only interval entries are kept in each tier.

    :param audio: Path to the L2-ARCTIC audio file.
    :return: A dictionary mapping each tier name to its list of TimeAlignedType objects,
             or None if the annotation file does not exist.
    """
    from praatio import textgrid
    from praatio.utilities.constants import Interval

    audio_path = Path(audio)
    grid_file = audio_path.parent.parent / "annotation" / f"{audio_path.stem}.TextGrid"
    if not grid_file.exists():
        return None

    grid = textgrid.openTextgrid(str(grid_file), includeEmptyIntervals=False)
    return {
        name: [
            TimeAlignedType(text=item.label, start=item.start, end=item.end)
            for item in grid.getTier(name).entries
            if isinstance(item, Interval)
        ]
        for name in grid.tierNames
    }
100
+
101
+
102
def load_speakers(input_dir: Path) -> dict:
    """Load L2-ARCTIC speaker metadata from the markdown table in ``readme-download.txt``.

    Table rows are read after the header separator line (starting with "|---|") up to the
    "|**Total**|" row. Each row's cells are whitespace-stripped. The 1st, 2nd and 3rd cells
    are taken as speaker ID, gender and dialect. Blank or truncated rows are skipped.

    :param input_dir: Directory containing ``readme-download.txt``
    :return: Mapping of speaker ID to ``{"gender": ..., "dialect": ...}``
    """
    speakers = {}
    with open(input_dir / "readme-download.txt", encoding="utf-8") as file:
        processing = False
        for line in file:
            if not processing:
                # Data rows begin after the table's header separator line.
                if line.startswith("|---|"):
                    processing = True
                continue

            if line.startswith("|**Total**|"):
                break

            # A row "|ID|G|Dialect|...|" splits into ["", "ID", "G", "Dialect", ...].
            fields = [field.strip() for field in line.strip().split("|")]
            if len(fields) < 4:
                # Blank or truncated row: nothing usable (previously raised IndexError).
                continue
            speakers[fields[1]] = {"gender": fields[2], "dialect": fields[3]}

    return speakers
@@ -0,0 +1,102 @@
1
+ import os
2
+ from pathlib import Path
3
+
4
+ from .types import TimeAlignedType
5
+
6
+
7
def _get_num_samples(audio: str | os.PathLike[str]) -> int:
    """Get the number of samples (frames) in an audio file.

    ``.mp3`` and ``.m4a`` files are decoded with pydub; every other format is queried with
    soundfile. pydub is imported only when one of those formats is actually requested.

    :param audio: Audio file name
    :return: Number of samples
    """
    suffix = Path(audio).suffix

    if suffix in (".mp3", ".m4a"):
        from pydub import AudioSegment

        segment = AudioSegment.from_mp3(audio) if suffix == ".mp3" else AudioSegment.from_file(audio)
        # AudioSegment.frame_count() returns a float; convert to honor the int return type.
        return int(segment.frame_count())

    import soundfile

    return soundfile.info(audio).frames
23
+
24
+
25
def load_text(audio: str | os.PathLike[str]) -> TimeAlignedType | None:
    """Load text data from a LibriSpeech transcription file given a LibriSpeech audio filename.

    The transcript is ``<speaker>-<chapter>.trans.txt`` in the audio file's directory, where
    speaker and chapter are the names of the grandparent and parent directories. Each line is
    "<utterance id> <words...>"; blank lines are ignored.

    :param audio: Path to the LibriSpeech audio file.
    :return: A TimeAlignedType spanning 0 to the file's duration in seconds, whose text is the
             lower-cased transcript without punctuation. Returns None if the transcript file or
             the utterance's line is missing.
    """
    import string

    from ..mixture.audio import get_sample_rate

    path = Path(audio)
    name = path.stem
    transcript_filename = path.parent / f"{path.parent.parent.name}-{path.parent.name}.trans.txt"

    if not os.path.exists(transcript_filename):
        return None

    with open(transcript_filename, encoding="utf-8") as f:
        for line in f:
            fields = line.split()
            if not fields:
                # Blank line: nothing to match (previously raised IndexError).
                continue
            if fields[0] == name:
                text = " ".join(fields[1:]).lower().translate(str.maketrans("", "", string.punctuation))
                return TimeAlignedType(0, _get_num_samples(audio) / get_sample_rate(str(audio)), text)

    return None
51
+
52
+
53
def load_words(audio: str | os.PathLike[str]) -> list[TimeAlignedType] | None:
    """Return the word-level alignment of a LibriSpeech audio file.

    :param audio: Path to the LibriSpeech audio file.
    :return: TimeAlignedType entries from the ``words`` tier of the TextGrid stored next to
             the audio, or None when the TextGrid or the tier is missing.
    """
    tier_name = "words"
    return _load_ta(audio, tier_name)
60
+
61
+
62
def load_phonemes(audio: str | os.PathLike[str]) -> list[TimeAlignedType] | None:
    """Return the phoneme-level alignment of a LibriSpeech audio file.

    :param audio: Path to the LibriSpeech audio file.
    :return: TimeAlignedType entries from the ``phones`` tier of the TextGrid stored next to
             the audio, or None when the TextGrid or the tier is missing.
    """
    tier_name = "phones"
    return _load_ta(audio, tier_name)
69
+
70
+
71
def _load_ta(audio: str | os.PathLike[str], tier: str) -> list[TimeAlignedType] | None:
    """Read one tier of the TextGrid stored alongside a LibriSpeech audio file.

    The TextGrid has the audio's path with its suffix replaced by ``.TextGrid``.

    :param audio: Path to the LibriSpeech audio file.
    :param tier: Name of the TextGrid tier to read.
    :return: A list of TimeAlignedType objects, or None if the file or tier is missing.
    """
    from praatio import textgrid
    from praatio.utilities.constants import Interval

    grid_file = Path(audio).with_suffix(".TextGrid")
    if not grid_file.exists():
        return None

    grid = textgrid.openTextgrid(str(grid_file), includeEmptyIntervals=False)
    if tier not in grid.tierNames:
        return None

    def to_aligned(item) -> TimeAlignedType:
        if isinstance(item, Interval):
            return TimeAlignedType(text=item.label, start=item.start, end=item.end)
        # Non-interval entries (presumably praatio points) carry a single ``time``;
        # it is used for both ends.
        return TimeAlignedType(text=item.label, start=item.time, end=item.time)

    return [to_aligned(item) for item in grid.getTier(tier).entries]
91
+
92
+
93
def load_speakers(input_dir: Path) -> dict:
    """Load speaker metadata from a LibriSpeech ``SPEAKERS.TXT`` file.

    Lines starting with ";" are comments. Data lines are "|"-separated: the first field is
    the speaker ID and the second the gender (both whitespace-stripped). Blank or truncated
    lines are skipped.

    :param input_dir: Directory containing ``SPEAKERS.TXT``
    :return: Mapping of speaker ID to ``{"gender": ...}``
    """
    speakers = {}
    with open(input_dir / "SPEAKERS.TXT", encoding="utf-8") as file:
        for line in file:
            if line.startswith(";"):
                continue
            fields = line.strip().split("|")
            if len(fields) < 2:
                # Blank or truncated line (previously raised IndexError).
                continue
            speaker_id = fields[0].strip()
            gender = fields[1].strip()
            speakers[speaker_id] = {"gender": gender}
    return speakers