sonusai 0.15.8__py3-none-any.whl → 0.15.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/audiofe.py +293 -0
- sonusai/calc_metric_spenh.py +3 -3
- sonusai/data_generator/dataset_from_mixdb.py +1 -1
- sonusai/data_generator/keras_from_mixdb.py +1 -1
- sonusai/genft.py +2 -1
- sonusai/genmixdb.py +4 -4
- sonusai/keras_predict.py +1 -1
- sonusai/lsdb.py +2 -2
- sonusai/main.py +2 -2
- sonusai/mixture/__init__.py +3 -2
- sonusai/mixture/audio.py +0 -34
- sonusai/mixture/datatypes.py +1 -1
- sonusai/mixture/feature.py +75 -21
- sonusai/mixture/helpers.py +60 -30
- sonusai/mixture/log_duration_and_sizes.py +2 -2
- sonusai/mixture/mixdb.py +13 -10
- sonusai/mixture/spectral_mask.py +14 -14
- sonusai/mixture/truth_functions/data.py +1 -1
- sonusai/mixture/truth_functions/target.py +2 -2
- sonusai/onnx_predict.py +1 -1
- sonusai/plot.py +4 -4
- sonusai/post_spenh_targetf.py +8 -8
- sonusai/torchl_predict.py +71 -76
- sonusai/utils/__init__.py +4 -0
- sonusai/utils/audio_devices.py +41 -0
- sonusai/utils/calculate_input_shape.py +3 -4
- sonusai/utils/create_timestamp.py +5 -0
- sonusai/utils/reshape.py +11 -11
- sonusai/utils/wave.py +12 -5
- {sonusai-0.15.8.dist-info → sonusai-0.15.9.dist-info}/METADATA +8 -1
- {sonusai-0.15.8.dist-info → sonusai-0.15.9.dist-info}/RECORD +33 -31
- {sonusai-0.15.8.dist-info → sonusai-0.15.9.dist-info}/WHEEL +1 -1
- sonusai/evaluate.py +0 -245
- {sonusai-0.15.8.dist-info → sonusai-0.15.9.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,41 @@
|
|
1
|
+
import pyaudio
|
2
|
+
|
3
|
+
|
4
|
+
def get_input_device_index_by_name(p: pyaudio.PyAudio, name: str = None) -> int:
|
5
|
+
info = p.get_host_api_info_by_index(0)
|
6
|
+
device_count = info.get('deviceCount')
|
7
|
+
for i in range(0, device_count):
|
8
|
+
device_info = p.get_device_info_by_host_api_device_index(0, i)
|
9
|
+
if name is None:
|
10
|
+
device_name = None
|
11
|
+
else:
|
12
|
+
device_name = device_info.get('name')
|
13
|
+
if name == device_name and device_info.get('maxInputChannels') > 0:
|
14
|
+
return i
|
15
|
+
|
16
|
+
raise ValueError(f'Could not find {name}')
|
17
|
+
|
18
|
+
|
19
|
+
def get_input_devices(p: pyaudio.PyAudio) -> list[str]:
|
20
|
+
names = []
|
21
|
+
info = p.get_host_api_info_by_index(0)
|
22
|
+
device_count = info.get('deviceCount')
|
23
|
+
for i in range(0, device_count):
|
24
|
+
device_info = p.get_device_info_by_host_api_device_index(0, i)
|
25
|
+
device_name = device_info.get('name')
|
26
|
+
if device_info.get('maxInputChannels') > 0:
|
27
|
+
names.append(device_name)
|
28
|
+
|
29
|
+
return names
|
30
|
+
|
31
|
+
|
32
|
+
def get_default_input_device(p: pyaudio.PyAudio) -> str:
|
33
|
+
info = p.get_host_api_info_by_index(0)
|
34
|
+
device_count = info.get('deviceCount')
|
35
|
+
for i in range(0, device_count):
|
36
|
+
device_info = p.get_device_info_by_host_api_device_index(0, i)
|
37
|
+
device_name = device_info.get('name')
|
38
|
+
if device_info.get('maxInputChannels') > 0:
|
39
|
+
return device_name
|
40
|
+
|
41
|
+
raise ValueError('No input audio devices found')
|
@@ -13,13 +13,12 @@ def calculate_input_shape(feature: str,
|
|
13
13
|
"""
|
14
14
|
from pyaaware import FeatureGenerator
|
15
15
|
|
16
|
-
|
17
|
-
fg = FeatureGenerator(feature_mode=feature, num_classes=2)
|
16
|
+
fg = FeatureGenerator(feature_mode=feature)
|
18
17
|
|
19
18
|
if flatten:
|
20
|
-
in_shape = [fg.stride * fg.
|
19
|
+
in_shape = [fg.stride * fg.feature_parameters]
|
21
20
|
else:
|
22
|
-
in_shape = [fg.stride, fg.
|
21
|
+
in_shape = [fg.stride, fg.feature_parameters]
|
23
22
|
|
24
23
|
if timesteps > 0:
|
25
24
|
in_shape.insert(0, timesteps)
|
sonusai/utils/reshape.py
CHANGED
@@ -17,14 +17,14 @@ def reshape_inputs(feature: Feature,
|
|
17
17
|
timesteps: int = 0,
|
18
18
|
flatten: bool = False,
|
19
19
|
add1ch: bool = False) -> tuple[Feature, Optional[Truth]]:
|
20
|
-
"""Check SonusAI feature and truth data and reshape feature of size [frames, strides,
|
20
|
+
"""Check SonusAI feature and truth data and reshape feature of size [frames, strides, feature_parameters] into
|
21
21
|
one of several options:
|
22
22
|
|
23
23
|
If timesteps > 0: (i.e., for recurrent NNs):
|
24
|
-
no-flatten, no-channel: [sequences, timesteps, strides,
|
25
|
-
flatten, no-channel: [sequences, timesteps, strides*
|
26
|
-
no-flatten, add-1channel: [sequences, timesteps, strides,
|
27
|
-
flatten, add-1channel: [sequences, timesteps, strides*
|
24
|
+
no-flatten, no-channel: [sequences, timesteps, strides, feature_parameters] (4-dim)
|
25
|
+
flatten, no-channel: [sequences, timesteps, strides*feature_parameters] (3-dim)
|
26
|
+
no-flatten, add-1channel: [sequences, timesteps, strides, feature_parameters, 1] (5-dim)
|
27
|
+
flatten, add-1channel: [sequences, timesteps, strides*feature_parameters, 1] (4-dim)
|
28
28
|
|
29
29
|
If batch_size is None, then do not reshape; just calculate new input shape and return.
|
30
30
|
|
@@ -40,7 +40,7 @@ def reshape_inputs(feature: Feature,
|
|
40
40
|
"""
|
41
41
|
from sonusai import SonusAIError
|
42
42
|
|
43
|
-
frames, strides,
|
43
|
+
frames, strides, feature_parameters = feature.shape
|
44
44
|
if truth is not None:
|
45
45
|
truth_frames, num_classes = truth.shape
|
46
46
|
# Double-check correctness of inputs
|
@@ -50,7 +50,7 @@ def reshape_inputs(feature: Feature,
|
|
50
50
|
num_classes = None
|
51
51
|
|
52
52
|
if flatten:
|
53
|
-
feature = np.reshape(feature, (frames, strides *
|
53
|
+
feature = np.reshape(feature, (frames, strides * feature_parameters))
|
54
54
|
|
55
55
|
# Reshape for Keras/TF recurrent models that require timesteps/sequence length dimension
|
56
56
|
if timesteps > 0:
|
@@ -73,14 +73,14 @@ def reshape_inputs(feature: Feature,
|
|
73
73
|
|
74
74
|
# Reshape
|
75
75
|
if feature.ndim == 2: # flattened input
|
76
|
-
# was [frames,
|
77
|
-
feature = np.reshape(feature, (sequences, timesteps, strides *
|
76
|
+
# was [frames, feature_parameters*timesteps]
|
77
|
+
feature = np.reshape(feature, (sequences, timesteps, strides * feature_parameters))
|
78
78
|
if truth is not None:
|
79
79
|
# was [frames, num_classes]
|
80
80
|
truth = np.reshape(truth, (sequences, timesteps, num_classes))
|
81
81
|
elif feature.ndim == 3: # un-flattened input
|
82
|
-
# was [frames,
|
83
|
-
feature = np.reshape(feature, (sequences, timesteps, strides,
|
82
|
+
# was [frames, feature_parameters, timesteps]
|
83
|
+
feature = np.reshape(feature, (sequences, timesteps, strides, feature_parameters))
|
84
84
|
if truth is not None:
|
85
85
|
# was [frames, num_classes]
|
86
86
|
truth = np.reshape(truth, (sequences, timesteps, num_classes))
|
sonusai/utils/wave.py
CHANGED
@@ -5,15 +5,22 @@ from sonusai.mixture.datatypes import AudioT
|
|
5
5
|
def write_wav(name: str, audio: AudioT, sample_rate: int = SAMPLE_RATE) -> None:
|
6
6
|
""" Write a simple, uncompressed WAV file.
|
7
7
|
|
8
|
-
To write multiple channels, use a 2D array of shape [
|
8
|
+
To write multiple channels, use a 2D array of shape [channels, samples].
|
9
9
|
The bits per sample and PCM/float are determined by the data type.
|
10
10
|
|
11
11
|
"""
|
12
|
-
import numpy as np
|
13
12
|
import torch
|
14
13
|
import torchaudio
|
15
14
|
|
16
|
-
|
17
|
-
audio = np.reshape(audio, (1, audio.shape[0]))
|
15
|
+
data = torch.tensor(audio)
|
18
16
|
|
19
|
-
|
17
|
+
if data.dim() == 1:
|
18
|
+
data = torch.reshape(data, (1, data.shape[0]))
|
19
|
+
if data.dim() != 2:
|
20
|
+
raise ValueError(f'audio must be a 1D or 2D array')
|
21
|
+
|
22
|
+
# Assuming data has more samples than channels, check if array needs to be transposed
|
23
|
+
if data.shape[1] < data.shape[0]:
|
24
|
+
data = torch.transpose(data, 0, 1)
|
25
|
+
|
26
|
+
torchaudio.save(uri=name, src=data, sample_rate=sample_rate)
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: sonusai
|
3
|
-
Version: 0.15.
|
3
|
+
Version: 0.15.9
|
4
4
|
Summary: Framework for building deep neural network models for sound, speech, and voice AI
|
5
5
|
Home-page: https://aaware.com
|
6
6
|
License: GPL-3.0-only
|
@@ -16,28 +16,35 @@ Classifier: Programming Language :: Python :: 3.10
|
|
16
16
|
Classifier: Programming Language :: Python :: 3.11
|
17
17
|
Requires-Dist: PyYAML (>=6.0.1,<7.0.0)
|
18
18
|
Requires-Dist: aixplain (>=0.2.6,<0.3.0)
|
19
|
+
Requires-Dist: bitarray (>=2.9.2,<3.0.0)
|
19
20
|
Requires-Dist: ctranslate2 (==4.1.0)
|
20
21
|
Requires-Dist: dataclasses-json (>=0.6.1,<0.7.0)
|
21
22
|
Requires-Dist: deepgram-sdk (>=3.0.0,<4.0.0)
|
22
23
|
Requires-Dist: docopt (>=0.6.2,<0.7.0)
|
24
|
+
Requires-Dist: einops (>=0.7.0,<0.8.0)
|
23
25
|
Requires-Dist: faster-whisper (>=1.0.1,<2.0.0)
|
26
|
+
Requires-Dist: geomloss (>=0.2.6,<0.3.0)
|
24
27
|
Requires-Dist: h5py (>=3.11.0,<4.0.0)
|
28
|
+
Requires-Dist: hydra-core (>=1.3.2,<2.0.0)
|
25
29
|
Requires-Dist: jiwer (>=3.0.3,<4.0.0)
|
26
30
|
Requires-Dist: keras (>=3.1.1,<4.0.0)
|
27
31
|
Requires-Dist: keras-tuner (>=1.4.7,<2.0.0)
|
28
32
|
Requires-Dist: librosa (>=0.10.1,<0.11.0)
|
29
33
|
Requires-Dist: lightning (>=2.2,<2.3)
|
30
34
|
Requires-Dist: matplotlib (>=3.8.0,<4.0.0)
|
35
|
+
Requires-Dist: omegaconf (>=2.3.0,<3.0.0)
|
31
36
|
Requires-Dist: onnx (>=1.14.1,<2.0.0)
|
32
37
|
Requires-Dist: onnxruntime (>=1.16.1,<2.0.0)
|
33
38
|
Requires-Dist: paho-mqtt (>=2.0.0,<3.0.0)
|
34
39
|
Requires-Dist: pandas (>=2.1.1,<3.0.0)
|
35
40
|
Requires-Dist: pesq (>=0.0.4,<0.0.5)
|
36
41
|
Requires-Dist: pyaaware (>=1.5.3,<2.0.0)
|
42
|
+
Requires-Dist: pyaudio (>=0.2.14,<0.3.0)
|
37
43
|
Requires-Dist: pydub (>=0.25.1,<0.26.0)
|
38
44
|
Requires-Dist: pystoi (>=0.4.0,<0.5.0)
|
39
45
|
Requires-Dist: python-magic (>=0.4.27,<0.5.0)
|
40
46
|
Requires-Dist: requests (>=2.31.0,<3.0.0)
|
47
|
+
Requires-Dist: sacrebleu (>=2.4.2,<3.0.0)
|
41
48
|
Requires-Dist: samplerate (>=0.2.1,<0.3.0)
|
42
49
|
Requires-Dist: soundfile (>=0.12.1,<0.13.0)
|
43
50
|
Requires-Dist: sox (>=1.4.1,<2.0.0)
|
@@ -1,27 +1,27 @@
|
|
1
1
|
sonusai/__init__.py,sha256=KmIJ9wni9d9v5pyu0pUxbacZIHGkAywB9CJwl7JME28,1526
|
2
2
|
sonusai/aawscd_probwrite.py,sha256=GukR5owp_0A3DrqSl9fHWULYgclNft4D5OkHIwfxxkc,3698
|
3
|
-
sonusai/
|
3
|
+
sonusai/audiofe.py,sha256=XE_cgOhhTryjPUePxW_8NY1TwrnRZ6BHCsH-gp8PmYw,11471
|
4
|
+
sonusai/calc_metric_spenh.py,sha256=D8iQVSIhFhrsUwKuIP-S38NBnyfAOZlsOIIgOZwGOOI,60852
|
4
5
|
sonusai/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
5
6
|
sonusai/data/genmixdb.yml,sha256=-XSs_hUR6wHJVoTPmSewzXL7u61X-xmHY46lNPatxSE,1025
|
6
7
|
sonusai/data/speech_ma01_01.wav,sha256=PK0vMKg-NR6rPE3KouxHGF6PKXnJCr7AwjMqfu98LUA,76644
|
7
8
|
sonusai/data/whitenoise.wav,sha256=I2umov0m34y56F9IsIBi1XtE76ZeZaSKDf70cJRe3pI,1920044
|
8
9
|
sonusai/data_generator/__init__.py,sha256=ouCpY5EDV35fKFeKGQfIcU8uE-c3QcuNerTxUA1X5L8,232
|
9
|
-
sonusai/data_generator/dataset_from_mixdb.py,sha256=
|
10
|
-
sonusai/data_generator/keras_from_mixdb.py,sha256=
|
10
|
+
sonusai/data_generator/dataset_from_mixdb.py,sha256=D14L8BL7a0WgkF8a8eogQ9Hk9ow4_RK3QBGsZ-HDAog,5493
|
11
|
+
sonusai/data_generator/keras_from_mixdb.py,sha256=14r89aX6Dr9ZKsmMRC7HDXbJrPrCZC1liwwLmZUKj0w,6182
|
11
12
|
sonusai/data_generator/torch_from_mixdb.py,sha256=lvEe9DDu_rIaoyhv9PW4UAnAWp5N74L8kRfxUhsh7oo,4279
|
12
13
|
sonusai/doc/__init__.py,sha256=rP5Hgn0Iys_xkuv4caxngdqehuU4zLZsiKuv8Nde67M,19
|
13
14
|
sonusai/doc/doc.py,sha256=3z210v6ZckuOlsGZ3ySQBdlCNmBp2M1ahqhqG_eUN58,22664
|
14
15
|
sonusai/doc.py,sha256=l8CaFgLI8mqx4tn0aXfxKqa2dy9GgC0zjYxZAkpmi1E,878
|
15
|
-
sonusai/
|
16
|
-
sonusai/genft.py,sha256=CeQN8Sxb_NKeXWJxN9HtzUu687eXl97nHBxzzCzQdLg,5557
|
16
|
+
sonusai/genft.py,sha256=6hOds8d-pYRawesLYh7XLrQh4PweWUj8U5Cbzx45bNQ,5572
|
17
17
|
sonusai/genmix.py,sha256=0AiUfF7n0CGOp5v-woNfeP3-QaVQUb0CJZc0oXkvqpk,7016
|
18
|
-
sonusai/genmixdb.py,sha256=
|
18
|
+
sonusai/genmixdb.py,sha256=rAxCKNPkOXaAugEfp9pTcHCQONapdTnxMlBsIPIoizE,19639
|
19
19
|
sonusai/gentcst.py,sha256=8jeXirgJe0OGgknC8A-rIudjHeH8UTYPpuh71Ha-I3w,20165
|
20
20
|
sonusai/keras_onnx.py,sha256=WHcPIcff4VPdiXqGX-TU-_x-UuKUD3nNpQtCX-2NEVQ,2658
|
21
|
-
sonusai/keras_predict.py,sha256=
|
21
|
+
sonusai/keras_predict.py,sha256=_83EtPtnfrqwUzC2H2tk4LI90RiQdyEEBxFGTgFPl3M,9090
|
22
22
|
sonusai/keras_train.py,sha256=8_M5vY-CkonPzbOtOF3Vk-wox-42o8fkaOKLjk7Oc2k,13226
|
23
|
-
sonusai/lsdb.py,sha256=
|
24
|
-
sonusai/main.py,sha256=
|
23
|
+
sonusai/lsdb.py,sha256=TTMQ-0H8fFzUSczt6yjy-9xUjZSdIGQzTVH5Xr6XPSA,5941
|
24
|
+
sonusai/main.py,sha256=KjN0dCI6rWare4wo_ACzTlURW7pvTw03n51pH7EyLAU,3108
|
25
25
|
sonusai/metrics/__init__.py,sha256=56itZW3S1I7ZYvbxPmFIVPAh1AIJZdljByz1uCrHqFE,635
|
26
26
|
sonusai/metrics/calc_class_weights.py,sha256=dyY7daEIf5Ms5tfTf6wF0fkx_GnMADHOZR_rtsfGoVM,3933
|
27
27
|
sonusai/metrics/calc_optimal_thresholds.py,sha256=9fRfwl-aKAbzHJyqGHv4o8BpZXG9HHB7zUJObHXfYM4,3522
|
@@ -35,24 +35,24 @@ sonusai/metrics/class_summary.py,sha256=4Mb25nuk6eqotnQSFMuOQL3zofGcpNXDfDlPa513
|
|
35
35
|
sonusai/metrics/confusion_matrix_summary.py,sha256=3qg6TMKjJeHtNjj2YnNjPFSlMrQXt0Zcu1dLkGB_aPU,4001
|
36
36
|
sonusai/metrics/one_hot.py,sha256=QSeH_GdqBpOAKLrNnQ8gjcPC-vSdUqC0yPEQueTA6VI,13548
|
37
37
|
sonusai/metrics/snr_summary.py,sha256=P4U5_Xr7v9F8kF-rZBnpsVNt3p42rIVS6zmch8yfVfg,5575
|
38
|
-
sonusai/mixture/__init__.py,sha256=
|
39
|
-
sonusai/mixture/audio.py,sha256=
|
38
|
+
sonusai/mixture/__init__.py,sha256=fCVSlizYxUUQQD9nSZ8bEbfc_TB2yiOC14HPOB4KFz4,5287
|
39
|
+
sonusai/mixture/audio.py,sha256=S-ZROf5rVvwv1TCEuwJHz1FfX4oVubb4QhbybUMMqtM,2150
|
40
40
|
sonusai/mixture/augmentation.py,sha256=Blb90tdTwBOj5w9tRcYyS5H67YJuFiXsGqwZWd7ON4g,10468
|
41
41
|
sonusai/mixture/class_count.py,sha256=_wFnVl2yEOnbor7pLg7cYOUeX6nioov-03Cv3SEbh2k,996
|
42
42
|
sonusai/mixture/config.py,sha256=CXIkVRJmaW2QW_sGl0aIqPf7I_TesyGhUYzxouw5UX4,22266
|
43
43
|
sonusai/mixture/constants.py,sha256=xjCskcQi6khqYZDf7j6z1OkeN1C6wE06kBBapcJiNI4,1428
|
44
|
-
sonusai/mixture/datatypes.py,sha256=
|
44
|
+
sonusai/mixture/datatypes.py,sha256=zaxfOHw8ddt-i8JPYOPnlqWz_EHBEDoO4q2VAqJViHM,8173
|
45
45
|
sonusai/mixture/eq_rule_is_valid.py,sha256=MpQwRA5M76wSiQWEI1lW2cLFdPaMttBLcQp3tWD8efM,1243
|
46
|
-
sonusai/mixture/feature.py,sha256=
|
46
|
+
sonusai/mixture/feature.py,sha256=Rwuf82IoXzhHPGbKYVGcatImF_ssBf_FfvbqghVPXtg,4116
|
47
47
|
sonusai/mixture/generation.py,sha256=miUrc3QOSUNIG6mDkiMCZ6M2ulivUZxlYUAJUOVomWc,39039
|
48
|
-
sonusai/mixture/helpers.py,sha256=
|
49
|
-
sonusai/mixture/log_duration_and_sizes.py,sha256=
|
48
|
+
sonusai/mixture/helpers.py,sha256=GSGSD2KnvOeEIB6IwNTxyaQNjghTSBMB729kUEd_RiM,22403
|
49
|
+
sonusai/mixture/log_duration_and_sizes.py,sha256=baTUpqyM15wA125jo9E3posmVJUe3WlpksyO6v9Jul0,1347
|
50
50
|
sonusai/mixture/mapped_snr_f.py,sha256=mlbYM1t14OXe_Zg4CjpWTuA_Zun4W0O3bSUXeodRBQs,1845
|
51
|
-
sonusai/mixture/mixdb.py,sha256=
|
51
|
+
sonusai/mixture/mixdb.py,sha256=9Pe0mEG8pnEf9NZynTIldc05GfdOrgmcVoIt63RG5DA,45279
|
52
52
|
sonusai/mixture/soundfile_audio.py,sha256=Ow_IWIMz4pMsLxMP_JsQ8AuHLCWlYQinLa58CFW97f8,2804
|
53
53
|
sonusai/mixture/sox_audio.py,sha256=HT3kYA9TP5QPCuoOJdUMnGVN-qY6q96DGL8zxuog76o,12277
|
54
54
|
sonusai/mixture/sox_augmentation.py,sha256=F9tBdNvX2guCn7gRppAFrxRnBtjw9q6qAq2_v_A4hh0,4490
|
55
|
-
sonusai/mixture/spectral_mask.py,sha256=
|
55
|
+
sonusai/mixture/spectral_mask.py,sha256=8AkCwhy-PSdP1Uri9miKZP-bXFYnFcH_c9xZCGrHavU,2071
|
56
56
|
sonusai/mixture/target_class_balancing.py,sha256=NTNiKZH0_PWLooeow0l41CjJKK8ZTMVbUqz9ZkaNtWk,4900
|
57
57
|
sonusai/mixture/targets.py,sha256=wyy5vhLhuN-hqBMBGoziVvEJg3FKFvJFgmEE7_LaV2M,7908
|
58
58
|
sonusai/mixture/tokenized_shell_vars.py,sha256=gCxw8SQUcal6mqWKF7hOBTgSQmbJUk1nT0Gn3H8GA0U,4705
|
@@ -61,24 +61,24 @@ sonusai/mixture/torchaudio_augmentation.py,sha256=1vEDHI0caL1vrgoY2lAWe4CiHE2jKR
|
|
61
61
|
sonusai/mixture/truth.py,sha256=Y41pZ52Xkols9LUler0NlgnilUOscBIucmw4GcxXNzU,1612
|
62
62
|
sonusai/mixture/truth_functions/__init__.py,sha256=82lKYHhLy8KW3gHngrocoqwupGVLVsWdIXdYs3vhjOc,359
|
63
63
|
sonusai/mixture/truth_functions/crm.py,sha256=_Vy8UMrOUQXsrM3nutvUMWCpvI8GePr01QFlyqLFd4k,2626
|
64
|
-
sonusai/mixture/truth_functions/data.py,sha256=
|
64
|
+
sonusai/mixture/truth_functions/data.py,sha256=okFJeOf43NxfdLqWFCBA2pOGqujRlNDYdAcwwR_m8z8,2875
|
65
65
|
sonusai/mixture/truth_functions/energy.py,sha256=ydMtMLjMloG76DB30ZHQ5tkBVh4dkMJ82XEhKBokmIk,4281
|
66
66
|
sonusai/mixture/truth_functions/file.py,sha256=jOJuC_3y9BH6GGOp9eKcbVrHLVRzUA80BJq59LhcBUM,1539
|
67
67
|
sonusai/mixture/truth_functions/phoneme.py,sha256=stYdlPuNytQK_LLT61OJLfYSqKd-sDjQZdtJKGzt5wA,479
|
68
68
|
sonusai/mixture/truth_functions/sed.py,sha256=8cHjEFjZaH_0hIOHhPmj4AJz2GpEADM6Ys2x4NoiWSY,2469
|
69
|
-
sonusai/mixture/truth_functions/target.py,sha256=
|
69
|
+
sonusai/mixture/truth_functions/target.py,sha256=KAsjugDRooOA5BRcHVAbZRgV7l8S5CFg7CZ0XtKZaQ0,5764
|
70
70
|
sonusai/mkmanifest.py,sha256=dIPVFKKhnhHdq63OGr6p__pK7fyx3OdKVtbmGUJxsR8,7078
|
71
71
|
sonusai/mkwav.py,sha256=LZNyhq4gJEs_NtGvRsYHA2qfgkkODpt6HoH1b-Tjjuw,5266
|
72
|
-
sonusai/onnx_predict.py,sha256=
|
73
|
-
sonusai/plot.py,sha256=
|
74
|
-
sonusai/post_spenh_targetf.py,sha256=
|
72
|
+
sonusai/onnx_predict.py,sha256=Bz_pR28oAZBarNajlKwyzBxmW7ktum77SmxDN2onKPM,9060
|
73
|
+
sonusai/plot.py,sha256=u-PvF8guNcm0b-GN99xfEkrcAAtidAEY3RLDzNvcyYk,17014
|
74
|
+
sonusai/post_spenh_targetf.py,sha256=NIMhDXeDuUqeWukNaAUMvDw9JpEVCauwjrL2F4M9nrI,4927
|
75
75
|
sonusai/queries/__init__.py,sha256=oKY5JeqZ4Cz7DwCwPc1_ydB8bUs6KaMcWFp_w02TjOs,255
|
76
76
|
sonusai/queries/queries.py,sha256=FNMUKnoY_Ya9S5sNhsB8ppwy0B7V55ilbbjhQRv_UN8,7552
|
77
77
|
sonusai/torchl_onnx.py,sha256=5JYow3XpBaUdtuyAW0mOZyCKL_4FrHvEekYBRdDT6KA,8967
|
78
|
-
sonusai/torchl_predict.py,sha256
|
78
|
+
sonusai/torchl_predict.py,sha256=P1ySDH_ITOPefZ2xZqyxyIrsNDqblKTBLZqFApgo5EU,26238
|
79
79
|
sonusai/torchl_train.py,sha256=NPCRB0gwTvabivmOz78gjUreDeO1z16PYuw7L1-pIRQ,9680
|
80
80
|
sonusai/tplot.py,sha256=yFyyyg9ymp2Eh-64Muu0EFFEY61MoJSV0a_fy9OWaCk,14485
|
81
|
-
sonusai/utils/__init__.py,sha256=
|
81
|
+
sonusai/utils/__init__.py,sha256=tVSmxinSo0Enexpol6wCzz6tU7WrueC-YslFgQr-o7M,2382
|
82
82
|
sonusai/utils/asl_p56.py,sha256=GCKlz-NLInQ0z41XBi0mOvGdSfRZf3WI53necVNDo80,3837
|
83
83
|
sonusai/utils/asr.py,sha256=QN1wdO9-EqD72-ixr4lnzsPfT8i0syhTGj1evKNJWe4,2021
|
84
84
|
sonusai/utils/asr_functions/__init__.py,sha256=4boXXOXlQHTt8K2DWOwFXSlc8D2NLFd8QTc68yL2ejU,214
|
@@ -93,9 +93,11 @@ sonusai/utils/asr_manifest_functions/__init__.py,sha256=Lz12aCGvfngZkLoUxHSqFjHc
|
|
93
93
|
sonusai/utils/asr_manifest_functions/data.py,sha256=mJsaHccBReguOJu9qsshRhL-3GbeyqM0-PXMseFnZbE,151
|
94
94
|
sonusai/utils/asr_manifest_functions/librispeech.py,sha256=HIaytcYmjRUkuR6fCQlv3Jh3IDWSox_A6WFcFFAHN9M,1635
|
95
95
|
sonusai/utils/asr_manifest_functions/vctk_noisy_speech.py,sha256=-69lM0dz18KbU5_-dmSeqDoNNwgJj4UlxgGkNBEi3wM,2169
|
96
|
+
sonusai/utils/audio_devices.py,sha256=LgaXTln1oRArBzaet3rZiIO2plgtaThuGBc3sJ_sLlo,1414
|
96
97
|
sonusai/utils/braced_glob.py,sha256=h4hab7YDbM4CjLg9iSzyHZrkd22IPUOY5zZqHdifkh8,1510
|
97
|
-
sonusai/utils/calculate_input_shape.py,sha256=
|
98
|
+
sonusai/utils/calculate_input_shape.py,sha256=63ILxibYKuTQozY83QN8Y2OOhBEbW_1X47Q0askcHDM,984
|
98
99
|
sonusai/utils/convert_string_to_number.py,sha256=i17yIxurp8Iz6NPE-imTRlARrXWqadwm8qbOTuzHZvE,236
|
100
|
+
sonusai/utils/create_timestamp.py,sha256=TxoQXWZ3SFdBEHLOv-ujeIsTEJuiFnKOGRy-FQq45YU,148
|
99
101
|
sonusai/utils/create_ts_name.py,sha256=8RLKmgXwuGcbDMGgtTuc0MvGFfA7IOVqfjkE2T18GOo,405
|
100
102
|
sonusai/utils/dataclass_from_dict.py,sha256=vAGnuMjhy0W9bxZ5usrH7mbQsFog3n0__IC4xyJyVUc,390
|
101
103
|
sonusai/utils/db.py,sha256=lI77MJJLs4CTYxhjFUvBom2Kk2imAP34okOeO4irbDc,371
|
@@ -114,15 +116,15 @@ sonusai/utils/print_mixture_details.py,sha256=BzYM4-wHHNa6zxPzBMUJxwKt0gKHmvbwdd
|
|
114
116
|
sonusai/utils/ranges.py,sha256=NPBZOVzMb95GTOIxltVO-wSzgcXqZ14wbdV46JDLKrw,1222
|
115
117
|
sonusai/utils/read_mixture_data.py,sha256=Sb30RgSpw6DnH_iD81O7G_KOsdfjQWWLk3euEkxfMa8,453
|
116
118
|
sonusai/utils/read_predict_data.py,sha256=5rR_ijrrcS2cKO1Sea2M2QEicokTtW5XtAo6jT5YSX8,1064
|
117
|
-
sonusai/utils/reshape.py,sha256=
|
119
|
+
sonusai/utils/reshape.py,sha256=E8Eu6grynaeWwVO6peIR0BF22SrVaJSa1Rkl109lq6Y,5997
|
118
120
|
sonusai/utils/seconds_to_hms.py,sha256=oxLuZhTJJr9swj-fOSOrZJ5vBNM7_BrOMQhX1pYpiv0,260
|
119
121
|
sonusai/utils/stacked_complex.py,sha256=feLhz3GC1ILxBGMHOj3sJK--sidsXKbfwkalwAVwizc,2950
|
120
122
|
sonusai/utils/stratified_shuffle_split.py,sha256=rJNXvBp-GxoKzH3OpL7k0ANSu5xMP2zJ7K1fm_33UzE,7022
|
121
123
|
sonusai/utils/trim_docstring.py,sha256=dSrtiRsEN4wkkvKBp6WDr13RUypfqZzgH_jOBLs1ouY,881
|
122
|
-
sonusai/utils/wave.py,sha256=
|
124
|
+
sonusai/utils/wave.py,sha256=O4ZXkZ6wjrKGa99wBCdFd8G6bp91MXXDnmGihpaEMh0,856
|
123
125
|
sonusai/utils/yes_or_no.py,sha256=eMLXBVH0cEahiXY4W2KNORmwNQ-ba10eRtldh0y4NYg,263
|
124
126
|
sonusai/vars.py,sha256=m2AefF0m5bXWGXpJj8Pi42zWL2ydeEj7bkak3GrtMyM,940
|
125
|
-
sonusai-0.15.
|
126
|
-
sonusai-0.15.
|
127
|
-
sonusai-0.15.
|
128
|
-
sonusai-0.15.
|
127
|
+
sonusai-0.15.9.dist-info/METADATA,sha256=DudNQlTEQpWpzqyzyowz_V-J9epd7mrKgAYM6rFxaPo,3209
|
128
|
+
sonusai-0.15.9.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
|
129
|
+
sonusai-0.15.9.dist-info/entry_points.txt,sha256=zMNjEphEPO6B3cD1GNpit7z-yA9tUU5-j3W2v-UWstU,92
|
130
|
+
sonusai-0.15.9.dist-info/RECORD,,
|
sonusai/evaluate.py
DELETED
@@ -1,245 +0,0 @@
|
|
1
|
-
"""sonusai evaluate
|
2
|
-
|
3
|
-
usage: evaluate [-hv] [-i MIXID] (-f FEATURE) (-p PREDICT) [-t PTHR] LOC
|
4
|
-
|
5
|
-
options:
|
6
|
-
-h, --help
|
7
|
-
-v, --verbose Be verbose.
|
8
|
-
-i MIXID, --mixid MIXID Mixture ID(s) to generate. [default: *].
|
9
|
-
-p PREDICT, --predict PREDICT A directory containing prediction data.
|
10
|
-
-t PTHR, --thr PTHR Optional prediction decision threshold(s). [default: 0].
|
11
|
-
|
12
|
-
Evaluate calculates performance metrics of neural-network models from model prediction data and genft data.
|
13
|
-
|
14
|
-
Inputs:
|
15
|
-
LOC A SonusAI mixture database directory.
|
16
|
-
MIXID A glob of mixture ID(s) to generate.
|
17
|
-
PREDICT A directory containing SonusAI predict HDF5 files. Contains:
|
18
|
-
dataset: predict (either [frames, num_classes] or [frames, timesteps, num_classes])
|
19
|
-
PTHR Scalar or array of thresholds. Default 0 will select values:
|
20
|
-
argmax() if mixdb indicates single-label mode (truth_mutex = true)
|
21
|
-
0.5 if mixdb indicates multi-label mode (truth_mutex = false)
|
22
|
-
If PTHR = -1, optimal thresholds are calculated using precision_recall_curve() which
|
23
|
-
optimizes F1 score.
|
24
|
-
"""
|
25
|
-
import numpy as np
|
26
|
-
|
27
|
-
from sonusai import logger
|
28
|
-
from sonusai.mixture import Feature
|
29
|
-
from sonusai.mixture import MixtureDatabase
|
30
|
-
from sonusai.mixture import Predict
|
31
|
-
from sonusai.mixture import Segsnr
|
32
|
-
from sonusai.mixture import Truth
|
33
|
-
|
34
|
-
|
35
|
-
def evaluate(mixdb: MixtureDatabase,
|
36
|
-
truth: Truth,
|
37
|
-
predict: Predict = None,
|
38
|
-
segsnr: Segsnr = None,
|
39
|
-
output_dir: str = None,
|
40
|
-
predict_thr: float | np.ndarray = 0,
|
41
|
-
feature: Feature = None,
|
42
|
-
verbose: bool = False) -> None:
|
43
|
-
from os.path import join
|
44
|
-
|
45
|
-
from sonusai import initial_log_messages
|
46
|
-
from sonusai import update_console_handler
|
47
|
-
from sonusai.metrics import calc_optimal_thresholds
|
48
|
-
from sonusai.metrics import class_summary
|
49
|
-
from sonusai.metrics import snr_summary
|
50
|
-
from sonusai.mixture import SAMPLE_RATE
|
51
|
-
from sonusai.queries import get_mixids_from_snr
|
52
|
-
from sonusai.utils import get_num_classes_from_predict
|
53
|
-
from sonusai.utils import human_readable_size
|
54
|
-
from sonusai.utils import reshape_outputs
|
55
|
-
from sonusai.utils import seconds_to_hms
|
56
|
-
|
57
|
-
update_console_handler(verbose)
|
58
|
-
initial_log_messages('evaluate')
|
59
|
-
|
60
|
-
if truth.shape[-1] != predict.shape[-1]:
|
61
|
-
logger.exception(f'Number of classes in truth and predict are not equal. Exiting.')
|
62
|
-
raise SystemExit(1)
|
63
|
-
|
64
|
-
# truth, predict can be either [frames, num_classes] or [frames, timesteps, num_classes]
|
65
|
-
# in binary case dim may not exist, detect this and set num_classes == 1
|
66
|
-
timesteps = -1
|
67
|
-
predict, truth = reshape_outputs(predict=predict, truth=truth, timesteps=timesteps)
|
68
|
-
num_classes = get_num_classes_from_predict(predict=predict, timesteps=timesteps)
|
69
|
-
|
70
|
-
fdiff = truth.shape[0] - predict.shape[0]
|
71
|
-
if fdiff > 0:
|
72
|
-
# truth = truth[0:-fdiff,:]
|
73
|
-
predict = np.concatenate((predict, np.zeros((fdiff, num_classes), dtype=np.float32)))
|
74
|
-
logger.info(f'Truth has more feature-frames than predict, padding predict with zeros to match.')
|
75
|
-
|
76
|
-
if fdiff < 0:
|
77
|
-
predict = predict[0:fdiff, :]
|
78
|
-
logger.info(f'Predict has more feature-frames than truth, trimming predict to match.')
|
79
|
-
|
80
|
-
# Check segsnr, input is always in transform frames
|
81
|
-
compute_segsnr = False
|
82
|
-
if len(segsnr) > 0:
|
83
|
-
segsnr_feature_frames = segsnr.shape[0] / (mixdb.feature_step_samples / mixdb.ft_config.R)
|
84
|
-
if segsnr_feature_frames == truth.shape[0]:
|
85
|
-
compute_segsnr = True
|
86
|
-
else:
|
87
|
-
logger.warning('segsnr length does not match truth, ignoring.')
|
88
|
-
|
89
|
-
# Check predict_thr array or scalar and return final scalar predict_thr value
|
90
|
-
if not mixdb.truth_mutex:
|
91
|
-
if num_classes > 1:
|
92
|
-
if not isinstance(predict_thr, np.ndarray):
|
93
|
-
if predict_thr == 0:
|
94
|
-
# multi-label predict_thr scalar 0 force to 0.5 default
|
95
|
-
predict_thr = np.atleast_1d(0.5)
|
96
|
-
else:
|
97
|
-
predict_thr = np.atleast_1d(predict_thr)
|
98
|
-
else:
|
99
|
-
if predict_thr.ndim == 1:
|
100
|
-
if predict_thr[0] == 0:
|
101
|
-
# multi-label predict_thr array scalar 0 force to 0.5 default
|
102
|
-
predict_thr = np.atleast_1d(0.5)
|
103
|
-
else:
|
104
|
-
# multi-label predict_thr array set to scalar = array[0]
|
105
|
-
predict_thr = predict_thr[0]
|
106
|
-
else:
|
107
|
-
# single-label mode, force argmax mode
|
108
|
-
predict_thr = np.atleast_1d(0)
|
109
|
-
|
110
|
-
if predict_thr == -1:
|
111
|
-
thrpr, thrroc, _, _ = calc_optimal_thresholds(truth, predict, timesteps)
|
112
|
-
predict_thr = np.atleast_1d(thrpr)
|
113
|
-
predict_thr = np.maximum(predict_thr, 0.001) # enforce lower limit
|
114
|
-
predict_thr = np.minimum(predict_thr, 0.999) # enforce upper limit
|
115
|
-
predict_thr = predict_thr.round(2)
|
116
|
-
|
117
|
-
# Summarize the mixture data
|
118
|
-
num_mixtures = mixdb.num_mixtures
|
119
|
-
total_samples = sum([mixture.samples for mixture in mixdb.mixtures])
|
120
|
-
duration = total_samples / SAMPLE_RATE
|
121
|
-
|
122
|
-
logger.info('')
|
123
|
-
logger.info(f'Mixtures: {num_mixtures}')
|
124
|
-
logger.info(f'Duration: {seconds_to_hms(seconds=duration)}')
|
125
|
-
logger.info(f'truth: {human_readable_size(truth.nbytes, 1)}')
|
126
|
-
logger.info(f'predict: {human_readable_size(predict.nbytes, 1)}')
|
127
|
-
if compute_segsnr:
|
128
|
-
logger.info(f'segsnr: {human_readable_size(segsnr.nbytes, 1)}')
|
129
|
-
if feature:
|
130
|
-
logger.info(f'feature: {human_readable_size(feature.nbytes, 1)}')
|
131
|
-
|
132
|
-
logger.info(f'Classes: {num_classes}')
|
133
|
-
if mixdb.truth_mutex:
|
134
|
-
logger.info(f'Mode: Single-label / truth_mutex / softmax')
|
135
|
-
else:
|
136
|
-
logger.info(f'Mode: Multi-label / Binary')
|
137
|
-
|
138
|
-
mxid_snro = get_mixids_from_snr(mixdb=mixdb)
|
139
|
-
snrlist = list(mxid_snro.keys())
|
140
|
-
snrlist.sort(reverse=True)
|
141
|
-
logger.info(f'Ordered SNRs: {snrlist}\n')
|
142
|
-
predict_thr_info = predict_thr.transpose() if isinstance(predict_thr, np.ndarray) else predict_thr
|
143
|
-
logger.info(f'Prediction Threshold(s): {predict_thr_info}\n')
|
144
|
-
|
145
|
-
# Top-level report over all mixtures
|
146
|
-
macrodf, microdf, wghtdf, mxid_snro = snr_summary(mixdb=mixdb,
|
147
|
-
mixid=':',
|
148
|
-
truth_f=truth,
|
149
|
-
predict=predict,
|
150
|
-
segsnr=segsnr if compute_segsnr else None,
|
151
|
-
predict_thr=predict_thr)
|
152
|
-
|
153
|
-
if num_classes > 1:
|
154
|
-
logger.info(f'Metrics micro-avg per SNR over all {num_mixtures} mixtures:')
|
155
|
-
else:
|
156
|
-
logger.info(f'Metrics per SNR over all {num_mixtures} mixtures:')
|
157
|
-
logger.info(microdf.round(3).to_string())
|
158
|
-
logger.info('')
|
159
|
-
if output_dir:
|
160
|
-
microdf.round(3).to_csv(join(output_dir, 'snr.csv'))
|
161
|
-
|
162
|
-
if mixdb.truth_mutex:
|
163
|
-
macrodf, microdf, wghtdf, mxid_snro = snr_summary(mixdb=mixdb,
|
164
|
-
mixid=':',
|
165
|
-
truth_f=truth[:, 0:-1],
|
166
|
-
predict=predict[:, 0:-1],
|
167
|
-
segsnr=segsnr if compute_segsnr else None,
|
168
|
-
predict_thr=predict_thr)
|
169
|
-
|
170
|
-
logger.info(f'Metrics micro-avg without "Other" class per SNR over all {num_mixtures} mixtures:')
|
171
|
-
logger.info(microdf.round(3).to_string())
|
172
|
-
logger.info('')
|
173
|
-
if output_dir:
|
174
|
-
microdf.round(3).to_csv(join(output_dir, 'snrwo.csv'))
|
175
|
-
|
176
|
-
for snri in snrlist:
|
177
|
-
mxids = mxid_snro[snri]
|
178
|
-
classdf = class_summary(mixdb, mxids, truth, predict, predict_thr)
|
179
|
-
logger.info(f'Metrics per class for SNR {snri} over {len(mxids)} mixtures:')
|
180
|
-
logger.info(classdf.round(3).to_string())
|
181
|
-
logger.info('')
|
182
|
-
if output_dir:
|
183
|
-
classdf.round(3).to_csv(join(output_dir, f'class_snr{snri}.csv'))
|
184
|
-
|
185
|
-
|
186
|
-
def main() -> None:
|
187
|
-
from datetime import datetime
|
188
|
-
from os import mkdir
|
189
|
-
from os.path import join
|
190
|
-
|
191
|
-
import h5py
|
192
|
-
from docopt import docopt
|
193
|
-
|
194
|
-
import sonusai
|
195
|
-
from sonusai import SonusAIError
|
196
|
-
from sonusai import create_file_handler
|
197
|
-
from sonusai.utils import read_predict_data
|
198
|
-
from sonusai.utils import trim_docstring
|
199
|
-
|
200
|
-
args = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)
|
201
|
-
|
202
|
-
verbose = args['--verbose']
|
203
|
-
feature_name = args['--feature']
|
204
|
-
predict_name = args['--predict']
|
205
|
-
predict_threshold = np.array(float(args['--thr']), dtype=np.float32)
|
206
|
-
location = args['LOC']
|
207
|
-
|
208
|
-
mixdb = MixtureDatabase(location)
|
209
|
-
|
210
|
-
# create output directory
|
211
|
-
output_dir = f'evaluate-{datetime.now():%Y%m%d}'
|
212
|
-
try:
|
213
|
-
mkdir(output_dir)
|
214
|
-
except OSError as _:
|
215
|
-
output_dir = f'evaluate-{datetime.now():%Y%m%d-%H%M%S}'
|
216
|
-
try:
|
217
|
-
mkdir(output_dir)
|
218
|
-
except OSError as error:
|
219
|
-
raise SonusAIError(f'Could not create directory, {output_dir}: {error}')
|
220
|
-
|
221
|
-
create_file_handler(join(output_dir, 'evaluate.log'))
|
222
|
-
|
223
|
-
with h5py.File(feature_name, 'r') as f:
|
224
|
-
truth_f = np.array(f['truth_f'])
|
225
|
-
segsnr = np.array(f['segsnr'])
|
226
|
-
|
227
|
-
predict = read_predict_data(predict_name)
|
228
|
-
|
229
|
-
evaluate(mixdb=mixdb,
|
230
|
-
truth=truth_f,
|
231
|
-
segsnr=segsnr,
|
232
|
-
output_dir=output_dir,
|
233
|
-
predict=predict,
|
234
|
-
predict_thr=predict_threshold,
|
235
|
-
verbose=verbose)
|
236
|
-
|
237
|
-
logger.info(f'Wrote results to {output_dir}')
|
238
|
-
|
239
|
-
|
240
|
-
if __name__ == '__main__':
|
241
|
-
try:
|
242
|
-
main()
|
243
|
-
except KeyboardInterrupt:
|
244
|
-
logger.info('Canceled due to keyboard interrupt')
|
245
|
-
raise SystemExit(0)
|
File without changes
|