sonusai 0.19.9__py3-none-any.whl → 0.20.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/calc_metric_spenh.py +265 -233
- sonusai/data/genmixdb.yml +4 -2
- sonusai/data/silero_vad_v5.1.jit +0 -0
- sonusai/data/silero_vad_v5.1.onnx +0 -0
- sonusai/doc/doc.py +14 -0
- sonusai/genft.py +1 -1
- sonusai/genmetrics.py +15 -18
- sonusai/genmix.py +1 -1
- sonusai/genmixdb.py +30 -52
- sonusai/ir_metric.py +555 -0
- sonusai/metrics_summary.py +322 -0
- sonusai/mixture/__init__.py +6 -2
- sonusai/mixture/audio.py +139 -15
- sonusai/mixture/augmentation.py +199 -84
- sonusai/mixture/config.py +9 -4
- sonusai/mixture/constants.py +0 -1
- sonusai/mixture/datatypes.py +19 -10
- sonusai/mixture/generation.py +52 -64
- sonusai/mixture/helpers.py +38 -26
- sonusai/mixture/ir_delay.py +63 -0
- sonusai/mixture/mixdb.py +190 -46
- sonusai/mixture/targets.py +3 -6
- sonusai/mixture/truth_functions/energy.py +9 -5
- sonusai/mixture/truth_functions/metrics.py +1 -1
- sonusai/mkwav.py +1 -1
- sonusai/onnx_predict.py +1 -1
- sonusai/queries/queries.py +1 -1
- sonusai/utils/__init__.py +2 -0
- sonusai/utils/asr.py +1 -1
- sonusai/utils/load_object.py +8 -2
- sonusai/utils/stratified_shuffle_split.py +1 -1
- sonusai/utils/temp_seed.py +13 -0
- {sonusai-0.19.9.dist-info → sonusai-0.20.2.dist-info}/METADATA +2 -2
- {sonusai-0.19.9.dist-info → sonusai-0.20.2.dist-info}/RECORD +36 -35
- {sonusai-0.19.9.dist-info → sonusai-0.20.2.dist-info}/WHEEL +1 -1
- sonusai/mixture/soundfile_audio.py +0 -130
- sonusai/mixture/sox_audio.py +0 -476
- sonusai/mixture/sox_augmentation.py +0 -136
- sonusai/mixture/torchaudio_audio.py +0 -106
- sonusai/mixture/torchaudio_augmentation.py +0 -109
- {sonusai-0.19.9.dist-info → sonusai-0.20.2.dist-info}/entry_points.txt +0 -0
sonusai/mixture/generation.py
CHANGED
@@ -93,7 +93,8 @@ def initialize_db(location: str, test: bool = False) -> None:
|
|
93
93
|
CREATE TABLE impulse_response_file (
|
94
94
|
id INTEGER PRIMARY KEY NOT NULL,
|
95
95
|
file TEXT NOT NULL,
|
96
|
-
tags TEXT NOT NULL
|
96
|
+
tags TEXT NOT NULL,
|
97
|
+
delay INTEGER NOT NULL)
|
97
98
|
""")
|
98
99
|
|
99
100
|
con.execute("""
|
@@ -119,8 +120,7 @@ def initialize_db(location: str, test: bool = False) -> None:
|
|
119
120
|
id INTEGER PRIMARY KEY NOT NULL,
|
120
121
|
file_id INTEGER NOT NULL,
|
121
122
|
augmentation TEXT NOT NULL,
|
122
|
-
FOREIGN KEY(file_id) REFERENCES target_file (id)
|
123
|
-
UNIQUE(file_id, augmentation))
|
123
|
+
FOREIGN KEY(file_id) REFERENCES target_file (id))
|
124
124
|
""")
|
125
125
|
|
126
126
|
con.execute("""
|
@@ -361,11 +361,12 @@ def populate_impulse_response_file_table(
|
|
361
361
|
|
362
362
|
con = db_connection(location=location, readonly=False, test=test)
|
363
363
|
con.executemany(
|
364
|
-
"INSERT INTO impulse_response_file (file, tags) VALUES (?, ?)",
|
364
|
+
"INSERT INTO impulse_response_file (file, tags, delay) VALUES (?, ?, ?)",
|
365
365
|
[
|
366
366
|
(
|
367
367
|
impulse_response_file.file,
|
368
368
|
json.dumps(impulse_response_file.tags),
|
369
|
+
impulse_response_file.delay,
|
369
370
|
)
|
370
371
|
for impulse_response_file in impulse_response_files
|
371
372
|
],
|
@@ -389,8 +390,7 @@ def update_mixid_width(location: str, num_mixtures: int, test: bool = False) ->
|
|
389
390
|
con.close()
|
390
391
|
|
391
392
|
|
392
|
-
def
|
393
|
-
location: str,
|
393
|
+
def generate_mixtures(
|
394
394
|
noise_mix_mode: str,
|
395
395
|
augmented_targets: list[AugmentedTarget],
|
396
396
|
target_files: list[TargetFile],
|
@@ -403,13 +403,8 @@ def populate_mixture_table(
|
|
403
403
|
num_classes: int,
|
404
404
|
feature_step_samples: int,
|
405
405
|
num_ir: int,
|
406
|
-
|
407
|
-
|
408
|
-
"""Generate mixtures and populate mixture table"""
|
409
|
-
from .helpers import from_mixture
|
410
|
-
from .helpers import from_target
|
411
|
-
from .mixdb import db_connection
|
412
|
-
|
406
|
+
) -> tuple[int, int, list[Mixture]]:
|
407
|
+
"""Generate mixtures"""
|
413
408
|
if noise_mix_mode == "exhaustive":
|
414
409
|
func = _exhaustive_noise_mix
|
415
410
|
elif noise_mix_mode == "non-exhaustive":
|
@@ -419,7 +414,7 @@ def populate_mixture_table(
|
|
419
414
|
else:
|
420
415
|
raise ValueError(f"invalid noise_mix_mode: {noise_mix_mode}")
|
421
416
|
|
422
|
-
|
417
|
+
return func(
|
423
418
|
augmented_targets=augmented_targets,
|
424
419
|
target_files=target_files,
|
425
420
|
target_augmentations=target_augmentations,
|
@@ -433,20 +428,41 @@ def populate_mixture_table(
|
|
433
428
|
num_ir=num_ir,
|
434
429
|
)
|
435
430
|
|
431
|
+
|
432
|
+
def populate_mixture_table(
|
433
|
+
location: str,
|
434
|
+
mixtures: list[Mixture],
|
435
|
+
test: bool = False,
|
436
|
+
logging: bool = False,
|
437
|
+
show_progress: bool = False,
|
438
|
+
) -> None:
|
439
|
+
"""Populate mixture table"""
|
440
|
+
from sonusai import logger
|
441
|
+
from sonusai.utils import track
|
442
|
+
|
443
|
+
from .helpers import from_mixture
|
444
|
+
from .helpers import from_target
|
445
|
+
from .mixdb import db_connection
|
446
|
+
|
436
447
|
con = db_connection(location=location, readonly=False, test=test)
|
448
|
+
|
437
449
|
# Populate target table
|
450
|
+
if logging:
|
451
|
+
logger.info("Populating target table")
|
452
|
+
targets: list[tuple[int, str]] = []
|
438
453
|
for mixture in mixtures:
|
439
454
|
for target in mixture.targets:
|
440
|
-
|
441
|
-
|
442
|
-
|
443
|
-
|
444
|
-
|
445
|
-
from_target(target),
|
446
|
-
)
|
455
|
+
entry = from_target(target)
|
456
|
+
if entry not in targets:
|
457
|
+
targets.append(entry)
|
458
|
+
for target in track(targets, disable=not show_progress):
|
459
|
+
con.execute("INSERT INTO target (file_id, augmentation) VALUES (?, ?)", target)
|
447
460
|
|
448
461
|
# Populate mixture table
|
449
|
-
|
462
|
+
if logging:
|
463
|
+
logger.info("Populating mixture table")
|
464
|
+
for mixture in track(mixtures, disable=not show_progress):
|
465
|
+
m_id = int(mixture.name)
|
450
466
|
con.execute(
|
451
467
|
"""
|
452
468
|
INSERT INTO mixture (id, name, noise_file_id, noise_augmentation, noise_offset, noise_snr_gain, random_snr,
|
@@ -473,20 +489,13 @@ def populate_mixture_table(
|
|
473
489
|
con.commit()
|
474
490
|
con.close()
|
475
491
|
|
476
|
-
return used_noise_files, used_noise_samples
|
477
492
|
|
478
|
-
|
479
|
-
def update_mixture_table(location: str, m_id: int, with_data: bool = False, test: bool = False) -> GenMixData:
|
493
|
+
def update_mixture(mixdb: MixtureDatabase, mixture: Mixture, with_data: bool = False) -> tuple[Mixture, GenMixData]:
|
480
494
|
"""Update mixture record with name and gains"""
|
481
495
|
from .audio import get_next_noise
|
482
496
|
from .augmentation import apply_gain
|
483
|
-
from .datatypes import GenMixData
|
484
|
-
from .helpers import from_mixture
|
485
497
|
from .helpers import get_target
|
486
|
-
from .mixdb import db_connection
|
487
498
|
|
488
|
-
mixdb = MixtureDatabase(location, test)
|
489
|
-
mixture = mixdb.mixture(m_id)
|
490
499
|
mixture, targets_audio = _initialize_targets_audio(mixdb, mixture)
|
491
500
|
|
492
501
|
noise_audio = _augmented_noise_audio(mixdb, mixture)
|
@@ -501,29 +510,8 @@ def update_mixture_table(location: str, m_id: int, with_data: bool = False, test
|
|
501
510
|
|
502
511
|
mixture.name = f"{int(mixture.name):0{mixdb.mixid_width}}"
|
503
512
|
|
504
|
-
con = db_connection(location=location, readonly=False, test=test)
|
505
|
-
con.execute(
|
506
|
-
"""
|
507
|
-
UPDATE mixture SET name=?,
|
508
|
-
noise_file_id=?,
|
509
|
-
noise_augmentation=?,
|
510
|
-
noise_offset=?,
|
511
|
-
noise_snr_gain=?,
|
512
|
-
random_snr=?,
|
513
|
-
snr=?,
|
514
|
-
samples=?,
|
515
|
-
spectral_mask_id=?,
|
516
|
-
spectral_mask_seed=?,
|
517
|
-
target_snr_gain=?
|
518
|
-
WHERE ? = mixture.id
|
519
|
-
""",
|
520
|
-
(*from_mixture(mixture), m_id + 1),
|
521
|
-
)
|
522
|
-
con.commit()
|
523
|
-
con.close()
|
524
|
-
|
525
513
|
if not with_data:
|
526
|
-
return GenMixData()
|
514
|
+
return mixture, GenMixData()
|
527
515
|
|
528
516
|
# Apply SNR gains
|
529
517
|
targets_audio = [apply_gain(audio=target_audio, gain=mixture.target_snr_gain) for target_audio in targets_audio]
|
@@ -533,7 +521,7 @@ def update_mixture_table(location: str, m_id: int, with_data: bool = False, test
|
|
533
521
|
target_audio = get_target(mixdb, mixture, targets_audio)
|
534
522
|
mixture_audio = target_audio + noise_audio
|
535
523
|
|
536
|
-
return GenMixData(
|
524
|
+
return mixture, GenMixData(
|
537
525
|
mixture=mixture_audio,
|
538
526
|
targets=targets_audio,
|
539
527
|
target=target_audio,
|
@@ -543,17 +531,13 @@ def update_mixture_table(location: str, m_id: int, with_data: bool = False, test
|
|
543
531
|
|
544
532
|
def _augmented_noise_audio(mixdb: MixtureDatabase, mixture: Mixture) -> AudioT:
|
545
533
|
from .audio import read_audio
|
546
|
-
from .audio import read_ir
|
547
534
|
from .augmentation import apply_augmentation
|
548
|
-
from .augmentation import apply_impulse_response
|
549
535
|
|
550
536
|
noise = mixdb.noise_file(mixture.noise.file_id)
|
551
537
|
noise_augmentation = mixture.noise.augmentation
|
552
538
|
|
553
539
|
audio = read_audio(noise.name)
|
554
|
-
audio = apply_augmentation(audio, noise_augmentation)
|
555
|
-
if noise_augmentation.ir is not None:
|
556
|
-
audio = apply_impulse_response(audio, read_ir(mixdb.impulse_response_file(noise_augmentation.ir)))
|
540
|
+
audio = apply_augmentation(mixdb, audio, noise_augmentation.pre)
|
557
541
|
|
558
542
|
return audio
|
559
543
|
|
@@ -567,8 +551,9 @@ def _initialize_targets_audio(mixdb: MixtureDatabase, mixture: Mixture) -> tuple
|
|
567
551
|
target_audio = mixdb.read_target_audio(target.file_id)
|
568
552
|
targets_audio.append(
|
569
553
|
apply_augmentation(
|
554
|
+
mixdb=mixdb,
|
570
555
|
audio=target_audio,
|
571
|
-
augmentation=target.augmentation,
|
556
|
+
augmentation=target.augmentation.pre,
|
572
557
|
frame_length=mixdb.feature_step_samples,
|
573
558
|
)
|
574
559
|
)
|
@@ -582,7 +567,10 @@ def _initialize_targets_audio(mixdb: MixtureDatabase, mixture: Mixture) -> tuple
|
|
582
567
|
|
583
568
|
|
584
569
|
def _initialize_mixture_gains(
|
585
|
-
mixdb: MixtureDatabase,
|
570
|
+
mixdb: MixtureDatabase,
|
571
|
+
mixture: Mixture,
|
572
|
+
target_audio: AudioT,
|
573
|
+
noise_audio: AudioT,
|
586
574
|
) -> Mixture:
|
587
575
|
import numpy as np
|
588
576
|
|
@@ -691,7 +679,7 @@ def _exhaustive_noise_mix(
|
|
691
679
|
noise_offset = 0
|
692
680
|
noise_length = estimate_augmented_length_from_length(
|
693
681
|
length=noise_files[noise_file_id].samples,
|
694
|
-
tempo=noise_augmentation.tempo,
|
682
|
+
tempo=noise_augmentation.pre.tempo,
|
695
683
|
)
|
696
684
|
|
697
685
|
for augmented_target_ids_for_mixup in augmented_target_ids_for_mixups:
|
@@ -933,7 +921,7 @@ def _get_next_noise_indices(
|
|
933
921
|
|
934
922
|
noise_augmentation = augmentation_from_rule(noise_augmentations[noise_augmentation_id], num_ir)
|
935
923
|
noise_length = estimate_augmented_length_from_length(
|
936
|
-
length=noise_files[noise_file_id].samples, tempo=noise_augmentation.tempo
|
924
|
+
length=noise_files[noise_file_id].samples, tempo=noise_augmentation.pre.tempo
|
937
925
|
)
|
938
926
|
return noise_file_id, noise_augmentation_id, noise_augmentation, noise_length
|
939
927
|
|
@@ -957,7 +945,7 @@ def _get_next_noise_offset(
|
|
957
945
|
|
958
946
|
noise_augmentation = augmentation_from_rule(noise_augmentations[noise_file_id], num_ir)
|
959
947
|
noise_length = estimate_augmented_length_from_length(
|
960
|
-
length=noise_files[noise_file_id].samples, tempo=noise_augmentation.tempo
|
948
|
+
length=noise_files[noise_file_id].samples, tempo=noise_augmentation.pre.tempo
|
961
949
|
)
|
962
950
|
if noise_offset + target_length >= noise_length:
|
963
951
|
if noise_offset == 0:
|
@@ -998,7 +986,7 @@ def _get_target_info(
|
|
998
986
|
target_length = max(
|
999
987
|
estimate_augmented_length_from_length(
|
1000
988
|
length=target_files[tfi].samples,
|
1001
|
-
tempo=target_augmentation.tempo,
|
989
|
+
tempo=target_augmentation.pre.tempo,
|
1002
990
|
frame_length=feature_step_samples,
|
1003
991
|
),
|
1004
992
|
target_length,
|
sonusai/mixture/helpers.py
CHANGED
@@ -117,11 +117,11 @@ def mixture_all_speech_metadata(mixdb: MixtureDatabase, mixture: Mixture) -> lis
|
|
117
117
|
# Check for tempo augmentation and adjust Interval start and end data as needed
|
118
118
|
entries = []
|
119
119
|
for entry in item:
|
120
|
-
if target.augmentation.tempo is not None:
|
120
|
+
if target.augmentation.pre.tempo is not None:
|
121
121
|
entries.append(
|
122
122
|
Interval(
|
123
|
-
entry.start / target.augmentation.tempo,
|
124
|
-
entry.end / target.augmentation.tempo,
|
123
|
+
entry.start / target.augmentation.pre.tempo,
|
124
|
+
entry.end / target.augmentation.pre.tempo,
|
125
125
|
entry.label,
|
126
126
|
)
|
127
127
|
)
|
@@ -135,22 +135,26 @@ def mixture_all_speech_metadata(mixdb: MixtureDatabase, mixture: Mixture) -> lis
|
|
135
135
|
return results
|
136
136
|
|
137
137
|
|
138
|
-
def mixture_metadata(mixdb: MixtureDatabase, m_id: int) -> str:
|
138
|
+
def mixture_metadata(mixdb: MixtureDatabase, m_id: int | None = None, mixture: Mixture | None = None) -> str:
|
139
139
|
"""Create a string of metadata for a Mixture
|
140
140
|
|
141
141
|
:param mixdb: Mixture database
|
142
142
|
:param m_id: Mixture ID
|
143
|
+
:param mixture: Mixture record
|
143
144
|
:return: String of metadata
|
144
145
|
"""
|
145
|
-
|
146
|
+
if m_id is not None:
|
147
|
+
mixture = mixdb.mixture(m_id)
|
148
|
+
|
149
|
+
if mixture is None:
|
150
|
+
raise ValueError("No mixture specified.")
|
151
|
+
|
146
152
|
metadata = ""
|
147
153
|
speech_metadata = mixture_all_speech_metadata(mixdb, mixture)
|
148
154
|
for mi, target in enumerate(mixture.targets):
|
149
155
|
target_file = mixdb.target_file(target.file_id)
|
150
|
-
target_augmentation = target.augmentation
|
151
156
|
metadata += f"target {mi} name: {target_file.name}\n"
|
152
157
|
metadata += f"target {mi} augmentation: {target.augmentation.to_dict()}\n"
|
153
|
-
metadata += f"target {mi} ir: {mixdb.impulse_response_file(target_augmentation.ir)}\n"
|
154
158
|
metadata += f"target {mi} target_gain: {target.gain if not mixture.is_noise_only else 0}\n"
|
155
159
|
metadata += f"target {mi} class indices: {target_file.class_indices}\n"
|
156
160
|
for key in target_file.truth_configs:
|
@@ -162,7 +166,6 @@ def mixture_metadata(mixdb: MixtureDatabase, m_id: int) -> str:
|
|
162
166
|
noise_augmentation = mixture.noise.augmentation
|
163
167
|
metadata += f"noise name: {noise.name}\n"
|
164
168
|
metadata += f"noise augmentation: {noise_augmentation.to_dict()}\n"
|
165
|
-
metadata += f"noise ir: {mixdb.impulse_response_file(noise_augmentation.ir)}\n"
|
166
169
|
metadata += f"noise offset: {mixture.noise_offset}\n"
|
167
170
|
metadata += f"snr: {mixture.snr}\n"
|
168
171
|
metadata += f"random_snr: {mixture.snr.is_random}\n"
|
@@ -173,17 +176,25 @@ def mixture_metadata(mixdb: MixtureDatabase, m_id: int) -> str:
|
|
173
176
|
return metadata
|
174
177
|
|
175
178
|
|
176
|
-
def write_mixture_metadata(mixdb: MixtureDatabase, m_id: int) -> None:
|
179
|
+
def write_mixture_metadata(mixdb: MixtureDatabase, m_id: int | None = None, mixture: Mixture | None = None) -> None:
|
177
180
|
"""Write mixture metadata to a text file
|
178
181
|
|
179
182
|
:param mixdb: Mixture database
|
180
183
|
:param m_id: Mixture ID
|
184
|
+
:param mixture: Mixture record
|
181
185
|
"""
|
182
186
|
from os.path import join
|
183
187
|
|
184
|
-
|
188
|
+
if m_id is not None:
|
189
|
+
name = mixdb.mixture(m_id).name
|
190
|
+
elif mixture is not None:
|
191
|
+
name = mixture.name
|
192
|
+
else:
|
193
|
+
raise ValueError("No mixture specified.")
|
194
|
+
|
195
|
+
name = join(mixdb.location, "mixture", name, "metadata.txt")
|
185
196
|
with open(file=name, mode="w") as f:
|
186
|
-
f.write(mixture_metadata(mixdb, m_id))
|
197
|
+
f.write(mixture_metadata(mixdb, m_id, mixture))
|
187
198
|
|
188
199
|
|
189
200
|
def from_mixture(
|
@@ -254,24 +265,25 @@ def get_target(mixdb: MixtureDatabase, mixture: Mixture, targets_audio: list[Aud
|
|
254
265
|
:param targets_audio: List of augmented target audio data (one per target in the mixup)
|
255
266
|
:return: Sum of augmented target audio data
|
256
267
|
"""
|
257
|
-
# Apply
|
268
|
+
# Apply post-truth augmentation effects to targets and sum
|
258
269
|
import numpy as np
|
259
270
|
|
260
|
-
from .
|
261
|
-
|
262
|
-
|
263
|
-
|
264
|
-
|
265
|
-
|
266
|
-
|
267
|
-
|
268
|
-
|
271
|
+
from .augmentation import apply_augmentation
|
272
|
+
|
273
|
+
targets_post = []
|
274
|
+
for idx, target_audio in enumerate(targets_audio):
|
275
|
+
target = mixture.targets[idx]
|
276
|
+
targets_post.append(
|
277
|
+
apply_augmentation(
|
278
|
+
mixdb=mixdb,
|
279
|
+
audio=target_audio,
|
280
|
+
augmentation=target.augmentation.post,
|
281
|
+
frame_length=mixdb.feature_step_samples,
|
269
282
|
)
|
270
|
-
|
271
|
-
targets_ir.append(target)
|
283
|
+
)
|
272
284
|
|
273
285
|
# Return sum of targets
|
274
|
-
return np.sum(
|
286
|
+
return np.sum(targets_post, axis=0)
|
275
287
|
|
276
288
|
|
277
289
|
def get_transform_from_audio(audio: AudioT, transform: ForwardTransform) -> tuple[AudioF, EnergyT]:
|
@@ -385,7 +397,7 @@ def augmented_target_samples(
|
|
385
397
|
[
|
386
398
|
estimate_augmented_length_from_length(
|
387
399
|
length=target_files[fi].samples,
|
388
|
-
tempo=target_augmentations[ai].tempo,
|
400
|
+
tempo=target_augmentations[ai].pre.tempo,
|
389
401
|
frame_length=feature_step_samples,
|
390
402
|
)
|
391
403
|
for fi, ai in it
|
@@ -405,7 +417,7 @@ def augmented_noise_samples(noise_files: list[NoiseFile], noise_augmentations: l
|
|
405
417
|
def augmented_noise_length(noise_file: NoiseFile, noise_augmentation: Augmentation) -> int:
|
406
418
|
from .augmentation import estimate_augmented_length_from_length
|
407
419
|
|
408
|
-
return estimate_augmented_length_from_length(length=noise_file.samples, tempo=noise_augmentation.tempo)
|
420
|
+
return estimate_augmented_length_from_length(length=noise_file.samples, tempo=noise_augmentation.pre.tempo)
|
409
421
|
|
410
422
|
|
411
423
|
def get_textgrid_tier_from_target_file(target_file: str, tier: str) -> SpeechMetadata | None:
|
@@ -0,0 +1,63 @@
|
|
1
|
+
import numpy as np
|
2
|
+
|
3
|
+
|
4
|
+
def get_impulse_response_delay(file: str) -> int:
    """Estimate the bulk delay of an impulse response file, in samples.

    A fixed noise probe (Gaussian, scale 0.2, drawn under ``temp_seed(42)``)
    is convolved with the impulse response. The lag of the convolved output
    relative to the probe is then measured with PHAT-weighted generalized
    cross-correlation, using 16x interpolation.

    :param file: Path to the impulse response audio file
    :return: Estimated delay in samples, rounded to the nearest integer
    """
    from sonusai.utils import temp_seed

    from .audio import raw_read_audio

    # NOTE(review): np.convolve needs 1-D input; this assumes raw_read_audio
    # returns a mono impulse response for IR files — confirm.
    impulse_response, fs = raw_read_audio(file)

    # Probe length is 0.05 * sample rate (50 ms if the rate is in Hz).
    probe_length = int(np.ceil(0.05 * fs))

    # temp_seed(42) is presumably a scoped NumPy RNG seed, so the probe is the
    # same on every call and the estimate is repeatable — confirm in sonusai.utils.
    with temp_seed(42):
        probe = np.random.normal(loc=0, scale=0.2, size=probe_length).astype(np.float32)

    response = np.convolve(impulse_response, probe)

    # tdoa's fs defaults to 1, so the lag comes back in samples.
    lag = tdoa(response, probe, interp=16, phat=True)
    return int(np.round(lag))
|
17
|
+
|
18
|
+
|
19
|
+
def tdoa(signal: np.ndarray, reference: np.ndarray, interp: int = 1, phat: bool = False, fs: int | float = 1) -> float:
    """Estimate how far ``signal`` is shifted relative to ``reference``.

    Uses generalized cross-correlation (optionally PHAT-weighted) and picks
    the lag with the largest absolute correlation.

    :param signal: The array whose time difference of arrival is measured
    :param reference: The reference array
    :param interp: Interpolation factor for the correlation (sub-sample resolution of 1/interp)
    :param phat: Apply the PHAT weighting
    :param fs: Sampling frequency of the inputs; with the default of 1 the result is in samples
    :return: Estimated delay of signal relative to reference (positive when signal lags)
    """
    cross_corr = correlate(signal, reference, interp=interp, phat=phat)

    # correlate() lays lags out starting at -(len(reference) - 1), with interp
    # samples per lag step, so subtract that offset to get a signed lag.
    peak_index = np.argmax(np.abs(cross_corr))
    lag_offset = reference.shape[0] - 1

    return float((peak_index / interp - lag_offset) / fs)
|
36
|
+
|
37
|
+
|
38
|
+
def correlate(x1: np.ndarray, x2: np.ndarray, interp: int = 1, phat: bool = False) -> np.ndarray:
    """Compute the linear cross-correlation between x1 and x2 via the FFT.

    The value at lag k is sum_m x1[m + k] * x2[m]. The result is ordered by
    increasing lag, from -(len(x2) - 1) to len(x1) - 1, with ``interp`` output
    samples per lag step, so its length is interp * (len(x1) + len(x2) - 1).

    :param x1: Input array 1 (treated as 1-D; its length is read from shape[0])
    :param x2: Input array 2, the reference (treated as 1-D)
    :param interp: Integer interpolation factor for the output array. Values are
        scaled by 1/interp because irfft normalizes by its output length.
    :param phat: Apply the PHAT weighting (divide each spectrum by its magnitude)
    :return: The cross-correlation between the two arrays
    """
    n_x1 = x1.shape[0]
    n_x2 = x2.shape[0]

    # Zero-pad to the full linear-correlation length so the circular
    # correlation computed by the FFT does not wrap around.
    n = n_x1 + n_x2 - 1

    fft1 = np.fft.rfft(x1, n=n)
    fft2 = np.fft.rfft(x2, n=n)

    if phat:
        # eps is relative to the mean magnitude to avoid dividing by zero in
        # empty bins. NOTE: an all-zero input makes eps 0 and yields NaNs.
        eps1 = np.mean(np.abs(fft1)) * 1e-10
        fft1 /= np.abs(fft1) + eps1
        eps2 = np.mean(np.abs(fft2)) * 1e-10
        fft2 /= np.abs(fft2) + eps2

    # Asking irfft for more output points zero-pads the spectrum, which
    # band-limit-interpolates the correlation by a factor of interp.
    out = np.fft.irfft(fft1 * np.conj(fft2), n=int(n * interp))

    # Circular layout: non-negative lags come first, negative lags wrap to the end.
    # Compute the start of the negative-lag tail explicitly. With len(x2) == 1
    # the original slice out[-0:] selected the whole array and duplicated it.
    negative_start = out.shape[0] - interp * (n_x2 - 1)
    return np.concatenate([out[negative_start:], out[: (interp * n_x1)]])
|