sonusai 0.17.3__py3-none-any.whl → 0.18.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +0 -1
- sonusai/calc_metric_spenh.py +74 -45
- sonusai/doc/doc.py +0 -24
- sonusai/genmetrics.py +146 -0
- sonusai/genmixdb.py +0 -2
- sonusai/mixture/__init__.py +0 -1
- sonusai/mixture/constants.py +0 -1
- sonusai/mixture/datatypes.py +2 -9
- sonusai/mixture/db_datatypes.py +72 -0
- sonusai/mixture/generation.py +139 -38
- sonusai/mixture/helpers.py +75 -16
- sonusai/mixture/mapped_snr_f.py +56 -9
- sonusai/mixture/mixdb.py +347 -226
- sonusai/mixture/tokenized_shell_vars.py +8 -1
- sonusai/speech/textgrid.py +6 -24
- {sonusai-0.17.3.dist-info → sonusai-0.18.1.dist-info}/METADATA +3 -1
- {sonusai-0.17.3.dist-info → sonusai-0.18.1.dist-info}/RECORD +19 -24
- sonusai/mixture/speaker_metadata.py +0 -35
- sonusai/mkmanifest.py +0 -209
- sonusai/utils/asr_manifest_functions/__init__.py +0 -6
- sonusai/utils/asr_manifest_functions/data.py +0 -1
- sonusai/utils/asr_manifest_functions/librispeech.py +0 -46
- sonusai/utils/asr_manifest_functions/mcgill_speech.py +0 -29
- sonusai/utils/asr_manifest_functions/vctk_noisy_speech.py +0 -66
- {sonusai-0.17.3.dist-info → sonusai-0.18.1.dist-info}/WHEEL +0 -0
- {sonusai-0.17.3.dist-info → sonusai-0.18.1.dist-info}/entry_points.txt +0 -0
sonusai/mixture/generation.py
CHANGED
@@ -37,7 +37,15 @@ def initialize_db(location: str, test: bool = False) -> None:
         id INTEGER PRIMARY KEY NOT NULL,
         name TEXT NOT NULL,
         samples INTEGER NOT NULL,
-        level_type TEXT NOT NULL)
+        level_type TEXT NOT NULL,
+        speaker_id INTEGER,
+        FOREIGN KEY(speaker_id) REFERENCES speaker (id))
+        """)
+
+    con.execute("""
+        CREATE TABLE speaker (
+        id INTEGER PRIMARY KEY NOT NULL,
+        parent TEXT NOT NULL)
         """)
 
     con.execute("""
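For orientation, the schema fragment above can be exercised against a throwaway in-memory SQLite database. The CREATE TABLE bodies are lifted from the hunk; the table name target_file is taken from the INSERT statements later in this diff, and the sample row values are made up for illustration only.

    import sqlite3

    con = sqlite3.connect(':memory:')

    # New in 0.18.x: a speaker table plus a nullable speaker_id foreign key on target_file.
    con.execute("""
        CREATE TABLE speaker (
        id INTEGER PRIMARY KEY NOT NULL,
        parent TEXT NOT NULL)
        """)
    con.execute("""
        CREATE TABLE target_file (
        id INTEGER PRIMARY KEY NOT NULL,
        name TEXT NOT NULL,
        samples INTEGER NOT NULL,
        level_type TEXT NOT NULL,
        speaker_id INTEGER,
        FOREIGN KEY(speaker_id) REFERENCES speaker (id))
        """)

    # A target file whose directory has no speaker.yml simply keeps speaker_id as NULL.
    con.execute("INSERT INTO speaker (parent) VALUES (?)", ('speakers/s001',))
    con.execute("INSERT INTO target_file (name, samples, level_type, speaker_id) VALUES (?, ?, ?, ?)",
                ('speakers/s001/utt01.wav', 16000, 'default', 1))
    con.commit()
    con.close()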
@@ -58,13 +66,9 @@ def initialize_db(location: str, test: bool = False) -> None:
         seed INTEGER NOT NULL,
         truth_mutex BOOLEAN NOT NULL,
         truth_reduction_function TEXT NOT NULL,
-        mixid_width INTEGER NOT NULL)
-        """)
-
-    con.execute("""
-        CREATE TABLE asr_manifest (
-        id INTEGER PRIMARY KEY NOT NULL,
-        manifest TEXT NOT NULL)
+        mixid_width INTEGER NOT NULL,
+        speaker_metadata_tiers TEXT NOT NULL,
+        textgrid_metadata_tiers TEXT NOT NULL)
         """)
 
     con.execute("""
@@ -155,8 +159,8 @@ def populate_top_table(location: str, config: dict, test: bool = False) -> None:
     con = db_connection(location=location, readonly=False, test=test)
     con.execute("""
        INSERT INTO top (version, class_balancing, feature, noise_mix_mode, num_classes,
-                        seed, truth_mutex, truth_reduction_function, mixid_width)
-       VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
+                        seed, truth_mutex, truth_reduction_function, mixid_width, speaker_metadata_tiers, textgrid_metadata_tiers)
+       VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
     """, (
         1,
         config['class_balancing'],
@@ -166,19 +170,9 @@ def populate_top_table(location: str, config: dict, test: bool = False) -> None:
         config['seed'],
         truth_mutex,
         config['truth_reduction_function'],
-        0))
-
-
-
-
-def populate_asr_manifest_table(location: str, config: dict, test: bool = False) -> None:
-    """Populate asr_manifest table
-    """
-    from .mixdb import db_connection
-
-    con = db_connection(location=location, readonly=False, test=test)
-    con.executemany("INSERT INTO asr_manifest (manifest) VALUES (?)",
-                    [(item,) for item in config['asr_manifest']])
+        0,
+        '',
+        ''))
     con.commit()
     con.close()
 
@@ -242,36 +236,51 @@ def populate_spectral_mask_table(location: str, config: dict, test: bool = False
 def populate_target_file_table(location: str, target_files: TargetFiles, test: bool = False) -> None:
     """Populate target file table
     """
+    import json
+    from pathlib import Path
+
     from .mixdb import db_connection
 
-
+    _populate_truth_setting_table(location, target_files, test)
+    _populate_speaker_table(location, target_files, test)
 
-
-    truth_settings: list[str] = []
-    for truth_setting in [truth_setting for target_file in target_files
-                          for truth_setting in target_file.truth_settings]:
-        ts = truth_setting.to_json()
-        if ts not in truth_settings:
-            truth_settings.append(ts)
-    con.executemany("INSERT INTO truth_setting (setting) VALUES (?)",
-                    [(item,) for item in truth_settings])
+    con = db_connection(location=location, readonly=False, test=test)
 
-    # Populate target_file table
     cur = con.cursor()
+    textgrid_metadata_tiers: set[str] = set()
     for target_file in target_files:
+        # Get TextGrid tiers for target file and add to collection
+        tiers = _get_textgrid_tiers_from_target_file(target_file.name)
+        for tier in tiers:
+            textgrid_metadata_tiers.add(tier)
+
+        # Get truth settings for target file
         truth_setting_ids: list[int] = []
         for truth_setting in target_file.truth_settings:
            cur.execute("SELECT truth_setting.id FROM truth_setting WHERE ? = truth_setting.setting",
                        (truth_setting.to_json(),))
            truth_setting_ids.append(cur.fetchone()[0])
 
-
-
+        # Get speaker_id for target file
+        cur.execute("SELECT speaker.id FROM speaker WHERE ? = speaker.parent",
+                    (Path(target_file.name).parent.as_posix(),))
+        result = cur.fetchone()
+        speaker_id = None
+        if result is not None:
+            speaker_id = result[0]
+
+        # Add entry
+        cur.execute("INSERT INTO target_file (name, samples, level_type, speaker_id) VALUES (?, ?, ?, ?)",
+                    (target_file.name, target_file.samples, target_file.level_type, speaker_id))
         target_file_id = cur.lastrowid
         for truth_setting_id in truth_setting_ids:
            cur.execute("INSERT INTO target_file_truth_setting (target_file_id, truth_setting_id) VALUES (?, ?)",
                        (target_file_id, truth_setting_id))
 
+    # Update textgrid_metadata_tiers in the top table
+    con.execute("UPDATE top SET textgrid_metadata_tiers=? WHERE top.id = ?",
+                (json.dumps(sorted(textgrid_metadata_tiers)), 1))
+
     con.commit()
     con.close()
 
@@ -304,8 +313,8 @@ def populate_impulse_response_file_table(location: str, impulse_response_files:
 def update_mixid_width(location: str, num_mixtures: int, test: bool = False) -> None:
     """Update the mixid width
     """
-    from sonusai.utils import max_text_width
     from .mixdb import db_connection
+    from sonusai.utils import max_text_width
 
     con = db_connection(location=location, readonly=False, test=test)
     con.execute("UPDATE top SET mixid_width=? WHERE top.id = ?", (max_text_width(num_mixtures), 1))
@@ -367,8 +376,8 @@ def update_mixture(mixdb: MixtureDatabase,
     """
     from .audio import get_next_noise
     from .augmentation import apply_gain
-    from .helpers import get_target
     from .datatypes import GenMixData
+    from .helpers import get_target
 
     mixture, targets_audio = _initialize_targets_audio(mixdb, mixture)
 
@@ -917,3 +926,95 @@ def get_all_snrs_from_config(config: dict) -> list[UniversalSNRGenerator]:
 
     return ([UniversalSNRGenerator(is_random=False, _raw_value=snr) for snr in config['snrs']] +
             [UniversalSNRGenerator(is_random=True, _raw_value=snr) for snr in config['random_snrs']])
+
+
+def _get_textgrid_tiers_from_target_file(target_file: str) -> list[str]:
+    from pathlib import Path
+
+    from praatio import textgrid
+
+    from sonusai.mixture import tokenized_expand
+
+    textgrid_file = Path(tokenized_expand(target_file)[0]).with_suffix('.TextGrid')
+    if not textgrid_file.exists():
+        return []
+
+    tg = textgrid.openTextgrid(str(textgrid_file), includeEmptyIntervals=False)
+
+    return sorted(tg.tierNames)
+
+
+def _populate_speaker_table(location: str, target_files: TargetFiles, test: bool = False) -> None:
+    """Populate speaker table
+    """
+    import json
+    from pathlib import Path
+
+    import yaml
+
+    from .mixdb import db_connection
+    from .tokenized_shell_vars import tokenized_expand
+
+    # Determine columns for speaker table
+    all_parents = set([Path(target_file.name).parent for target_file in target_files])
+    speaker_parents = (parent for parent in all_parents if Path(tokenized_expand(parent / 'speaker.yml')[0]).exists())
+
+    speakers: dict[Path, dict[str, str]] = {}
+    for parent in sorted(speaker_parents):
+        with open(tokenized_expand(parent / 'speaker.yml')[0], 'r') as f:
+            speakers[parent] = yaml.safe_load(f)
+
+    new_columns: list[str] = []
+    for keys in speakers.keys():
+        for column in speakers[keys].keys():
+            new_columns.append(column)
+    new_columns = sorted(set(new_columns))
+
+    con = db_connection(location=location, readonly=False, test=test)
+
+    for new_column in new_columns:
+        con.execute(f'ALTER TABLE speaker ADD COLUMN {new_column} TEXT')
+
+    # Populate speaker table
+    speaker_rows: list[tuple[str, ...]] = []
+    for key in speakers.keys():
+        entry = (speakers[key].get(column, None) for column in new_columns)
+        speaker_rows.append((key.as_posix(), *entry))
+
+    column_ids = ', '.join(['parent', *new_columns])
+    column_values = ', '.join(['?'] * (len(new_columns) + 1))
+    con.executemany(f'INSERT INTO speaker ({column_ids}) VALUES ({column_values})', speaker_rows)
+
+    con.execute("CREATE INDEX speaker_parent_idx ON speaker (parent)")
+
+    # Update speaker_metadata_tiers in the top table
+    tiers = [description[0] for description in con.execute("SELECT * FROM speaker").description if
+             description[0] not in ('id', 'parent')]
+    con.execute("UPDATE top SET speaker_metadata_tiers=? WHERE top.id = ?", (json.dumps(tiers), 1))
+
+    if 'speaker_id' in tiers:
+        con.execute("CREATE INDEX speaker_speaker_id_idx ON speaker (speaker_id)")
+
+    con.commit()
+    con.close()
+
+
+def _populate_truth_setting_table(location: str, target_files: TargetFiles, test: bool = False) -> None:
+    """Populate truth_setting table
+    """
+    from .mixdb import db_connection
+
+    con = db_connection(location=location, readonly=False, test=test)
+
+    # Populate truth_setting table
+    truth_settings: list[str] = []
+    for truth_setting in [truth_setting for target_file in target_files
+                          for truth_setting in target_file.truth_settings]:
+        ts = truth_setting.to_json()
+        if ts not in truth_settings:
+            truth_settings.append(ts)
+    con.executemany("INSERT INTO truth_setting (setting) VALUES (?)",
+                    [(item,) for item in truth_settings])
+
+    con.commit()
+    con.close()
sonusai/mixture/helpers.py
CHANGED
@@ -1,5 +1,7 @@
 from typing import Any
+from typing import Optional
 
+from praatio.utilities.constants import Interval
 from pyaaware import ForwardTransform
 from pyaaware import InverseTransform
 
@@ -18,11 +20,14 @@ from sonusai.mixture.datatypes import Mixture
 from sonusai.mixture.datatypes import NoiseFile
 from sonusai.mixture.datatypes import NoiseFiles
 from sonusai.mixture.datatypes import Segsnr
+from sonusai.mixture.datatypes import SpeechMetadata
 from sonusai.mixture.datatypes import Target
 from sonusai.mixture.datatypes import TargetFiles
 from sonusai.mixture.datatypes import Targets
 from sonusai.mixture.datatypes import TransformConfig
 from sonusai.mixture.datatypes import Truth
+from sonusai.mixture.db_datatypes import MixtureRecord
+from sonusai.mixture.db_datatypes import TargetRecord
 from sonusai.mixture.mixdb import MixtureDatabase
 
 
@@ -123,6 +128,35 @@ def write_mixture_data(mixdb: MixtureDatabase,
             f.create_dataset(name=item[0], data=item[1])
 
 
+def mixture_all_speech_metadata(mixdb: MixtureDatabase, mixture: Mixture) -> list[dict[str, SpeechMetadata]]:
+    """Get a list of all speech metadata for the given mixture
+    """
+    results: list[dict[str, SpeechMetadata]] = []
+    for target in mixture.targets:
+        data: dict[str, SpeechMetadata] = {}
+        for tier in mixdb.speaker_metadata_tiers:
+            data[tier] = mixdb.speaker(mixdb.target_file(target.file_id).speaker_id, tier)
+
+        for tier in mixdb.textgrid_metadata_tiers:
+            item = get_textgrid_tier_from_target_file(mixdb.target_file(target.file_id).name, tier)
+            if isinstance(item, list):
+                # Check for tempo augmentation and adjust Interval start and end data as needed
+                entries = []
+                for entry in item:
+                    if target.augmentation.tempo is not None:
+                        entries.append(Interval(entry.start / target.augmentation.tempo,
+                                                entry.end / target.augmentation.tempo,
+                                                entry.label))
+                    else:
+                        entries.append(entry)
+                data[tier] = entries
+            else:
+                data[tier] = item
+        results.append(data)
+
+    return results
+
+
 def mixture_metadata(mixdb: MixtureDatabase, mixture: Mixture) -> str:
     """Create a string of metadata for a Mixture
 
@@ -131,6 +165,7 @@ def mixture_metadata(mixdb: MixtureDatabase, mixture: Mixture) -> str:
     :return: String of metadata
     """
     metadata = ''
+    speech_metadata = mixture_all_speech_metadata(mixdb, mixture)
     for mi, target in enumerate(mixture.targets):
         target_file = mixdb.target_file(target.file_id)
         target_augmentation = target.augmentation
@@ -147,7 +182,8 @@ def mixture_metadata(mixdb: MixtureDatabase, mixture: Mixture) -> str:
             metadata += f'target {mi} truth index {tsi}: {truth_settings[tsi].index}\n'
             metadata += f'target {mi} truth function {tsi}: {truth_settings[tsi].function}\n'
             metadata += f'target {mi} truth config {tsi}: {truth_settings[tsi].config}\n'
-
+        for key in speech_metadata[mi].keys():
+            metadata += f'target {mi} speech {key}: {speech_metadata[mi][key]}\n'
     noise = mixdb.noise_file(mixture.noise.file_id)
     noise_augmentation = mixture.noise.augmentation
     metadata += f'noise name: {noise.name}\n'
@@ -194,7 +230,7 @@ def from_mixture(mixture: Mixture) -> tuple[str, int, str, int, float, bool, flo
                 mixture.target_snr_gain)
 
 
-def to_mixture(entry:
+def to_mixture(entry: MixtureRecord, targets: Targets) -> Mixture:
     import json
 
     from sonusai.utils import dataclass_from_dict
@@ -204,32 +240,32 @@ def to_mixture(entry: tuple[str, int, str, int, float, bool, float, int, int, in
     from .datatypes import UniversalSNR
 
     return Mixture(targets=targets,
-                   name=entry
-                   noise=Noise(file_id=entry
-                               augmentation=dataclass_from_dict(Augmentation, json.loads(entry
-                               offset=entry
-                   noise_snr_gain=entry
-                   snr=UniversalSNR(is_random=entry
-                   samples=entry
-                   spectral_mask_id=entry
-                   spectral_mask_seed=entry
-                   target_snr_gain=entry
+                   name=entry.name,
+                   noise=Noise(file_id=entry.noise_file_id,
+                               augmentation=dataclass_from_dict(Augmentation, json.loads(entry.noise_augmentation)),
+                               offset=entry.noise_offset),
+                   noise_snr_gain=entry.noise_snr_gain,
+                   snr=UniversalSNR(is_random=entry.random_snr, value=entry.snr),
+                   samples=entry.samples,
+                   spectral_mask_id=entry.spectral_mask_id,
+                   spectral_mask_seed=entry.spectral_mask_seed,
+                   target_snr_gain=entry.target_snr_gain)
 
 
 def from_target(target: Target) -> tuple[int, str, float]:
     return target.file_id, target.augmentation.to_json(), target.gain
 
 
-def to_target(entry:
+def to_target(entry: TargetRecord) -> Target:
     import json
 
     from sonusai.utils import dataclass_from_dict
     from .datatypes import Augmentation
     from .datatypes import Target
 
-    return Target(file_id=entry
-                  augmentation=dataclass_from_dict(Augmentation, json.loads(entry
-                  gain=entry
+    return Target(file_id=entry.file_id,
+                  augmentation=dataclass_from_dict(Augmentation, json.loads(entry.augmentation)),
+                  gain=entry.gain)
 
 
 def read_mixture_data(name: str, items: list[str] | str) -> Any:
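The positional tuple unpacking in to_mixture/to_target is replaced by attribute access on record types from the new db_datatypes.py, whose contents this diff does not show. As a rough mental model only, assuming TargetRecord is a simple named record carrying file_id, augmentation, and gain (the field names come from the attribute accesses above; the record shape itself is an assumption):

    import json
    from typing import NamedTuple


    class TargetRecord(NamedTuple):
        # Hypothetical stand-in for sonusai.mixture.db_datatypes.TargetRecord.
        file_id: int
        augmentation: str  # JSON-encoded Augmentation, as consumed by to_target() above
        gain: float


    row = TargetRecord(file_id=3, augmentation=json.dumps({'tempo': 1.25}), gain=0.5)
    print(row.file_id, json.loads(row.augmentation), row.gain)  # 3 {'tempo': 1.25} 0.5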
@@ -582,3 +618,26 @@ def augmented_noise_length(noise_file: NoiseFile, noise_augmentation: Augmentati
 
     return estimate_augmented_length_from_length(length=noise_file.samples,
                                                  tempo=noise_augmentation.tempo)
+
+
+def get_textgrid_tier_from_target_file(target_file: str, tier: str) -> Optional[SpeechMetadata]:
+    from pathlib import Path
+
+    from praatio import textgrid
+
+    from .tokenized_shell_vars import tokenized_expand
+
+    textgrid_file = Path(tokenized_expand(target_file)[0]).with_suffix('.TextGrid')
+    if not textgrid_file.exists():
+        return None
+
+    tg = textgrid.openTextgrid(str(textgrid_file), includeEmptyIntervals=False)
+
+    if tier not in tg.tierNames:
+        return None
+
+    entries = tg.getTier(tier).entries
+    if len(entries) > 1:
+        return list(entries)
+    else:
+        return entries[0].label
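A small sketch of the tempo handling in mixture_all_speech_metadata: praatio Interval entries carry start and end times in seconds of the original recording, so when a target is tempo-augmented the boundaries are divided by the tempo factor. The interval values below are invented for illustration.

    from praatio.utilities.constants import Interval

    # Hypothetical word-tier entries read from a .TextGrid sitting next to a target file.
    entries = [Interval(0.00, 0.40, 'hello'), Interval(0.45, 0.90, 'world')]

    tempo = 1.25  # from the target's augmentation rule; 25 % faster playback
    adjusted = [Interval(e.start / tempo, e.end / tempo, e.label) for e in entries]
    # The intervals now span roughly 0.00-0.32 s and 0.36-0.72 s.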
sonusai/mixture/mapped_snr_f.py
CHANGED
@@ -7,35 +7,82 @@ def calculate_snr_f_statistics(truth_f: np.ndarray) -> tuple[np.ndarray, np.ndar
     For now, includes mean and standard deviation of the raw values (usually energy)
     and mean and standard deviation of the dB values (10 * log10).
     """
-
+    return (
+        calculate_snr_mean(truth_f),
+        calculate_snr_std(truth_f),
+        calculate_snr_db_mean(truth_f),
+        calculate_snr_db_std(truth_f),
+    )
 
-    snr_mean = np.zeros(classes, dtype=np.float32)
-    snr_std = np.zeros(classes, dtype=np.float32)
-    snr_db_mean = np.zeros(classes, dtype=np.float32)
-    snr_db_std = np.zeros(classes, dtype=np.float32)
 
-
+def calculate_snr_mean(truth_f: np.ndarray) -> np.ndarray:
+    """Calculate mean of snr_f truth data."""
+    snr_mean = np.zeros(truth_f.shape[1], dtype=np.float32)
+
+    for c in range(truth_f.shape[1]):
         tmp_truth = truth_f[:, c]
         tmp = tmp_truth[np.isfinite(tmp_truth)].astype(np.double)
 
         if len(tmp) == 0:
             snr_mean[c] = -np.inf
-            snr_std[c] = -np.inf
         else:
             snr_mean[c] = np.mean(tmp)
+
+    return snr_mean
+
+
+def calculate_snr_std(truth_f: np.ndarray) -> np.ndarray:
+    """Calculate standard deviation of snr_f truth data."""
+    snr_std = np.zeros(truth_f.shape[1], dtype=np.float32)
+
+    for c in range(truth_f.shape[1]):
+        tmp_truth = truth_f[:, c]
+        tmp = tmp_truth[np.isfinite(tmp_truth)].astype(np.double)
+
+        if len(tmp) == 0:
+            snr_std[c] = -np.inf
+        else:
             snr_std[c] = np.std(tmp, ddof=1)
 
+    return snr_std
+
+
+def calculate_snr_db_mean(truth_f: np.ndarray) -> np.ndarray:
+    """Calculate dB mean of snr_f truth data."""
+    snr_db_mean = np.zeros(truth_f.shape[1], dtype=np.float32)
+
+    for c in range(truth_f.shape[1]):
+        tmp_truth = truth_f[:, c]
+        tmp = tmp_truth[np.isfinite(tmp_truth)].astype(np.double)
+
         tmp2 = 10 * np.ma.log10(tmp).filled(-np.inf)
         tmp2 = tmp2[np.isfinite(tmp2)]
 
         if len(tmp2) == 0:
             snr_db_mean[c] = -np.inf
-            snr_db_std[c] = -np.inf
         else:
             snr_db_mean[c] = np.mean(tmp2)
+
+    return snr_db_mean
+
+
+def calculate_snr_db_std(truth_f: np.ndarray) -> np.ndarray:
+    """Calculate dB standard deviation of snr_f truth data."""
+    snr_db_std = np.zeros(truth_f.shape[1], dtype=np.float32)
+
+    for c in range(truth_f.shape[1]):
+        tmp_truth = truth_f[:, c]
+        tmp = tmp_truth[np.isfinite(tmp_truth)].astype(np.double)
+
+        tmp2 = 10 * np.ma.log10(tmp).filled(-np.inf)
+        tmp2 = tmp2[np.isfinite(tmp2)]
+
+        if len(tmp2) == 0:
+            snr_db_std[c] = -np.inf
+        else:
             snr_db_std[c] = np.std(tmp2, ddof=1)
 
-    return
+    return snr_db_std
 
 
 def calculate_mapped_snr_f(truth_f: np.ndarray, snr_db_mean: np.ndarray, snr_db_std: np.ndarray) -> np.ndarray:
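The mapped_snr_f refactor splits one loop that filled four arrays into four single-purpose helpers sharing a convention: truth_f has shape (frames, classes) and each helper returns one value per class. A toy call, assuming the functions are importable from sonusai.mixture.mapped_snr_f as laid out above:

    import numpy as np

    from sonusai.mixture.mapped_snr_f import calculate_snr_f_statistics

    # Four frames of made-up energy values for two classes; non-finite entries are ignored.
    truth_f = np.array([[1.0, 0.1],
                        [2.0, np.inf],
                        [4.0, 0.4],
                        [8.0, 0.2]], dtype=np.float32)

    snr_mean, snr_std, snr_db_mean, snr_db_std = calculate_snr_f_statistics(truth_f)
    print(snr_mean.shape)  # (2,) - one statistic per class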
|