sonusai 0.18.9__py3-none-any.whl → 0.19.6__py3-none-any.whl
- sonusai/__init__.py +20 -29
- sonusai/aawscd_probwrite.py +18 -18
- sonusai/audiofe.py +93 -80
- sonusai/calc_metric_spenh.py +395 -321
- sonusai/data/genmixdb.yml +5 -11
- sonusai/{gentcst.py → deprecated/gentcst.py} +146 -149
- sonusai/{plot.py → deprecated/plot.py} +177 -131
- sonusai/{tplot.py → deprecated/tplot.py} +124 -102
- sonusai/doc/__init__.py +1 -1
- sonusai/doc/doc.py +112 -177
- sonusai/doc.py +10 -10
- sonusai/genft.py +81 -91
- sonusai/genmetrics.py +51 -61
- sonusai/genmix.py +105 -115
- sonusai/genmixdb.py +201 -174
- sonusai/lsdb.py +56 -66
- sonusai/main.py +23 -20
- sonusai/metrics/__init__.py +2 -0
- sonusai/metrics/calc_audio_stats.py +29 -24
- sonusai/metrics/calc_class_weights.py +7 -7
- sonusai/metrics/calc_optimal_thresholds.py +5 -7
- sonusai/metrics/calc_pcm.py +3 -3
- sonusai/metrics/calc_pesq.py +10 -7
- sonusai/metrics/calc_phase_distance.py +3 -3
- sonusai/metrics/calc_sa_sdr.py +10 -8
- sonusai/metrics/calc_segsnr_f.py +16 -18
- sonusai/metrics/calc_speech.py +105 -47
- sonusai/metrics/calc_wer.py +35 -32
- sonusai/metrics/calc_wsdr.py +10 -7
- sonusai/metrics/class_summary.py +30 -27
- sonusai/metrics/confusion_matrix_summary.py +25 -22
- sonusai/metrics/one_hot.py +91 -57
- sonusai/metrics/snr_summary.py +53 -46
- sonusai/mixture/__init__.py +20 -14
- sonusai/mixture/audio.py +4 -6
- sonusai/mixture/augmentation.py +37 -43
- sonusai/mixture/class_count.py +5 -14
- sonusai/mixture/config.py +292 -225
- sonusai/mixture/constants.py +41 -30
- sonusai/mixture/data_io.py +155 -0
- sonusai/mixture/datatypes.py +111 -108
- sonusai/mixture/db_datatypes.py +54 -70
- sonusai/mixture/eq_rule_is_valid.py +6 -9
- sonusai/mixture/feature.py +40 -38
- sonusai/mixture/generation.py +522 -389
- sonusai/mixture/helpers.py +217 -272
- sonusai/mixture/log_duration_and_sizes.py +16 -13
- sonusai/mixture/mixdb.py +669 -477
- sonusai/mixture/soundfile_audio.py +12 -17
- sonusai/mixture/sox_audio.py +91 -112
- sonusai/mixture/sox_augmentation.py +8 -9
- sonusai/mixture/spectral_mask.py +4 -6
- sonusai/mixture/target_class_balancing.py +41 -36
- sonusai/mixture/targets.py +69 -67
- sonusai/mixture/tokenized_shell_vars.py +23 -23
- sonusai/mixture/torchaudio_audio.py +14 -15
- sonusai/mixture/torchaudio_augmentation.py +23 -27
- sonusai/mixture/truth.py +48 -26
- sonusai/mixture/truth_functions/__init__.py +26 -0
- sonusai/mixture/truth_functions/crm.py +56 -38
- sonusai/mixture/truth_functions/datatypes.py +37 -0
- sonusai/mixture/truth_functions/energy.py +85 -59
- sonusai/mixture/truth_functions/file.py +30 -30
- sonusai/mixture/truth_functions/phoneme.py +14 -7
- sonusai/mixture/truth_functions/sed.py +71 -45
- sonusai/mixture/truth_functions/target.py +69 -106
- sonusai/mkwav.py +58 -101
- sonusai/onnx_predict.py +46 -43
- sonusai/queries/__init__.py +3 -1
- sonusai/queries/queries.py +100 -59
- sonusai/speech/__init__.py +2 -0
- sonusai/speech/l2arctic.py +24 -23
- sonusai/speech/librispeech.py +16 -17
- sonusai/speech/mcgill.py +22 -21
- sonusai/speech/textgrid.py +32 -25
- sonusai/speech/timit.py +45 -42
- sonusai/speech/vctk.py +14 -13
- sonusai/speech/voxceleb.py +26 -20
- sonusai/summarize_metric_spenh.py +11 -10
- sonusai/utils/__init__.py +4 -3
- sonusai/utils/asl_p56.py +1 -1
- sonusai/utils/asr.py +37 -17
- sonusai/utils/asr_functions/__init__.py +2 -0
- sonusai/utils/asr_functions/aaware_whisper.py +18 -12
- sonusai/utils/audio_devices.py +12 -12
- sonusai/utils/braced_glob.py +6 -8
- sonusai/utils/calculate_input_shape.py +1 -4
- sonusai/utils/compress.py +2 -2
- sonusai/utils/convert_string_to_number.py +1 -3
- sonusai/utils/create_timestamp.py +1 -1
- sonusai/utils/create_ts_name.py +2 -2
- sonusai/utils/dataclass_from_dict.py +1 -1
- sonusai/utils/docstring.py +6 -6
- sonusai/utils/energy_f.py +9 -7
- sonusai/utils/engineering_number.py +56 -54
- sonusai/utils/get_label_names.py +8 -10
- sonusai/utils/human_readable_size.py +2 -2
- sonusai/utils/model_utils.py +3 -5
- sonusai/utils/numeric_conversion.py +2 -4
- sonusai/utils/onnx_utils.py +43 -32
- sonusai/utils/parallel.py +41 -30
- sonusai/utils/print_mixture_details.py +25 -22
- sonusai/utils/ranges.py +12 -12
- sonusai/utils/read_predict_data.py +11 -9
- sonusai/utils/reshape.py +19 -26
- sonusai/utils/seconds_to_hms.py +1 -1
- sonusai/utils/stacked_complex.py +8 -16
- sonusai/utils/stratified_shuffle_split.py +29 -27
- sonusai/utils/write_audio.py +2 -2
- sonusai/utils/yes_or_no.py +3 -3
- sonusai/vars.py +14 -14
- {sonusai-0.18.9.dist-info → sonusai-0.19.6.dist-info}/METADATA +20 -21
- sonusai-0.19.6.dist-info/RECORD +125 -0
- {sonusai-0.18.9.dist-info → sonusai-0.19.6.dist-info}/WHEEL +1 -1
- sonusai/mixture/truth_functions/data.py +0 -58
- sonusai/utils/read_mixture_data.py +0 -14
- sonusai-0.18.9.dist-info/RECORD +0 -125
- {sonusai-0.18.9.dist-info → sonusai-0.19.6.dist-info}/entry_points.txt +0 -0

@@ -13,6 +13,7 @@ Inputs:
     LOC     A SonusAI calc_metric_spenh results directory.
 
 """
+
 import signal
 
 
@@ -21,24 +22,24 @@ def signal_handler(_sig, _frame):
 
     from sonusai import logger
 
-    logger.info('Canceled due to keyboard interrupt')
+    logger.info("Canceled due to keyboard interrupt")
     sys.exit(1)
 
 
 signal.signal(signal.SIGINT, signal_handler)
 
 
-def summarize_metric_spenh(location: str, by: str = 'MIXID', reverse: bool = False) -> str:
+def summarize_metric_spenh(location: str, by: str = "MIXID", reverse: bool = False) -> str:
     import glob
 
     import pandas as pd
 
-    files = sorted(glob.glob(location + '/*_metric_spenh.txt'))
+    files = sorted(glob.glob(location + "/*_metric_spenh.txt"))
     need_header = True
-    header = ['MIXID']
+    header = ["MIXID"]
     data = []
     for file in files:
-        with open(file
+        with open(file) as f:
             for i, line in enumerate(f):
                 if i == 1 and need_header:
                     need_header = False
@@ -48,7 +49,7 @@ def summarize_metric_spenh(location: str, by: str = 'MIXID', reverse: bool = False) -> str:
                     break
 
     df = pd.DataFrame(data, columns=header)
-    df[header[0:-2]] = df[header[0:-2]].apply(pd.to_numeric, errors='coerce')
+    df[header[0:-2]] = df[header[0:-2]].apply(pd.to_numeric, errors="coerce")
     return df.sort_values(by=by, ascending=not reverse).to_string(index=False)
 
 
@@ -60,12 +61,12 @@ def main():
 
     args = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)
 
-    by = args['--sort']
-    reverse = args['--reverse']
-    location = args['LOC']
+    by = args["--sort"]
+    reverse = args["--reverse"]
+    location = args["LOC"]
 
     print(summarize_metric_spenh(location, by, reverse))
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     main()
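For orientation, a minimal usage sketch of the function as it stands after this change (the results path is a placeholder, not a value from the diff):

    from sonusai.summarize_metric_spenh import summarize_metric_spenh

    # Collect every *_metric_spenh.txt report under a calc_metric_spenh results
    # directory and print one table sorted by mixture ID.
    print(summarize_metric_spenh("path/to/spenh_results", by="MIXID", reverse=False))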

sonusai/utils/__init__.py CHANGED
@@ -1,8 +1,10 @@
 # SonusAI general utilities
+# ruff: noqa: F401
 from .asl_p56 import asl_p56
 from .asr import ASRResult
 from .asr import calc_asr
 from .asr import get_available_engines
+from .asr import validate_asr
 from .audio_devices import get_default_input_device
 from .audio_devices import get_input_device_index_by_name
 from .audio_devices import get_input_devices
@@ -32,14 +34,13 @@ from .numeric_conversion import int16_to_float
 from .onnx_utils import add_sonusai_metadata
 from .onnx_utils import get_sonusai_metadata
 from .onnx_utils import load_ort_session
-from .parallel import
-from .parallel import
+from .parallel import par_track
+from .parallel import track
 from .path_info import PathInfo
 from .print_mixture_details import print_class_count
 from .print_mixture_details import print_mixture_details
 from .ranges import consolidate_range
 from .ranges import expand_range
-from .read_mixture_data import read_mixture_data
 from .read_predict_data import read_predict_data
 from .reshape import get_input_shape
 from .reshape import get_num_classes_from_predict

sonusai/utils/asl_p56.py CHANGED
@@ -88,7 +88,7 @@ def asl_p56(audio: AudioT) -> float:
             # Interpolate to find the asl_ms_log
             asl_ms_log = _bin_interp(A_db[j], A_db[j - 1], C_db[j], C_db[j - 1], M, 0.5)
             # This is the mean square value NOT the RMS
-            asl_msq = 10. ** (asl_ms_log / 10)
+            asl_msq = 10.0 ** (asl_ms_log / 10)
             break
 
     return asl_msq
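The only change above is the numeric literal (10. becomes 10.0); the dB conversion itself is untouched. For reference, with a hypothetical level of -26 dB:

    asl_ms_log = -26.0                   # active speech level in dB (hypothetical value)
    asl_msq = 10.0 ** (asl_ms_log / 10)  # mean-square value, NOT the RMS
    asl_rms = asl_msq ** 0.5             # RMS is the square root of the mean square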

sonusai/utils/asr.py CHANGED
@@ -1,6 +1,5 @@
+from collections.abc import Callable
 from dataclasses import dataclass
-from typing import Callable
-from typing import Optional
 
 from sonusai.mixture import AudioT
 
@@ -8,25 +7,25 @@ from sonusai.mixture import AudioT
 @dataclass(frozen=True)
 class ASRResult:
     text: str
-    confidence: Optional[float] = None
-    lang: Optional[str] = None
-    lang_prob: Optional[float] = None
-    duration: Optional[float] = None
-    num_segments: Optional[int] = None
-    asr_cpu_time: Optional[float] = None
+    confidence: float | None = None
+    lang: str | None = None
+    lang_prob: float | None = None
+    duration: float | None = None
+    num_segments: int | None = None
+    asr_cpu_time: float | None = None
 
 
 def get_available_engines() -> list[str]:
     from importlib import import_module
     from pkgutil import iter_modules
 
-    module = import_module('sonusai.utils.asr_functions')
-    engines = [method for method in dir(module) if not method.startswith('_')]
+    module = import_module("sonusai.utils.asr_functions")
+    engines = [method for method in dir(module) if not method.startswith("_")]
     for _, name, _ in iter_modules():
-        if name.startswith('sonusai_asr_'):
-            module = import_module(f'{name}.asr_functions')
+        if name.startswith("sonusai_asr_"):
+            module = import_module(f"{name}.asr_functions")
             for method in dir(module):
-                if not method.startswith('_'):
+                if not method.startswith("_"):
                     engines.append(method)
 
     return engines
@@ -36,19 +35,19 @@ def _asr_fn(engine: str) -> Callable[..., ASRResult]:
     from importlib import import_module
     from pkgutil import iter_modules
 
-    module = import_module('sonusai.utils.asr_functions')
+    module = import_module("sonusai.utils.asr_functions")
     for method in dir(module):
         if method == engine:
             return getattr(module, method)
 
     for _, name, _ in iter_modules():
-        if name.startswith('sonusai_asr_'):
-            module = import_module(f'{name}.asr_functions')
+        if name.startswith("sonusai_asr_"):
+            module = import_module(f"{name}.asr_functions")
             for method in dir(module):
                 if method == engine:
                     return getattr(module, method)
 
-    raise ValueError(f'engine {engine} not supported')
+    raise ValueError(f"engine {engine} not supported")
 
 
 def calc_asr(audio: AudioT | str, engine: str, **config) -> ASRResult:
@@ -69,3 +68,24 @@ def calc_asr(audio: AudioT | str, engine: str, **config) -> ASRResult:
         audio = copy(read_audio(audio))
 
     return _asr_fn(engine)(audio, **config)
+
+
+def validate_asr(engine: str, **config) -> None:
+    from importlib import import_module
+    from pkgutil import iter_modules
+
+    module = import_module("sonusai.utils.asr_functions")
+    for method in dir(module):
+        if method == engine:
+            getattr(module, method + "_validate")(**config)
+            return
+
+    for _, name, _ in iter_modules():
+        if name.startswith("sonusai_asr_"):
+            module = import_module(f"{name}.asr_functions")
+            for method in dir(module):
+                if method == engine:
+                    getattr(module, method + "_validate")(**config)
+                    return
+
+    raise ValueError(f"engine {engine} not supported")
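A short sketch of how the new validate_asr pairs with the existing helpers; which engines appear depends on the sonusai_asr_* plugins installed in the environment:

    from sonusai.utils import calc_asr, get_available_engines, validate_asr

    print(get_available_engines())  # built-in asr_functions plus any sonusai_asr_* plugins

    engine = "aaware_whisper"       # the built-in engine shown in the next file
    validate_asr(engine)            # raises ValueError if the engine name is not found
    # calc_asr(audio_or_path, engine) would then run recognition with the same engine name.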

sonusai/utils/asr_functions/aaware_whisper.py CHANGED
@@ -2,6 +2,10 @@ from sonusai.mixture import AudioT
 from sonusai.utils import ASRResult
 
 
+def aaware_whisper_validate(**_config) -> None:
+    pass
+
+
 def aaware_whisper(audio: AudioT, **_config) -> ASRResult:
     import tempfile
     from math import exp
@@ -10,32 +14,34 @@ def aaware_whisper(audio: AudioT, **_config) -> ASRResult:
 
     import requests
 
-    from sonusai import SonusAIError
     from sonusai.utils import ASRResult
     from sonusai.utils import float_to_int16
     from sonusai.utils import write_audio
 
-    url = getenv('AAWARE_WHISPER_URL')
+    url = getenv("AAWARE_WHISPER_URL")
     if url is None:
-        raise
-    url += '/asr?task=transcribe&language=en&encode=true&output=json'
+        raise EnvironmentError("AAWARE_WHISPER_URL environment variable does not exist")
+    url += "/asr?task=transcribe&language=en&encode=true&output=json"
 
     with tempfile.TemporaryDirectory() as tmp:
-        file = join(tmp, 'asr.wav')
+        file = join(tmp, "asr.wav")
         write_audio(name=file, audio=float_to_int16(audio))
 
-        files = {
+        files = {"audio_file": (file, open(file, "rb"), "audio/wav")}  # noqa: SIM115
 
         try:
-            response = requests.post(url, files=files)
-            if
+            response = requests.post(url, files=files)  # noqa: S113
+            if response.status_code != 200:
                 if response.status_code == 422:
-                    raise
-                raise
+                    raise RuntimeError(f"Validation error: {response.json()}")  # noqa: TRY301
+                raise RuntimeError(f"Invalid response: {response.status_code}")  # noqa: TRY301
             result = response.json()
-            return ASRResult(
+            return ASRResult(
+                text=result["text"],
+                confidence=exp(float(result["segments"][0]["avg_logprob"])),
+            )
         except Exception as e:
-            raise
+            raise RuntimeError(f"Aaware Whisper exception: {e.args}") from e
 
 
 """
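As a usage sketch (the endpoint and file name are placeholders, not values from this diff), the engine is reached through calc_asr once AAWARE_WHISPER_URL points at a running service:

    import os

    from sonusai.utils import calc_asr

    os.environ["AAWARE_WHISPER_URL"] = "http://localhost:9000"  # placeholder endpoint
    result = calc_asr("speech.wav", engine="aaware_whisper")    # a path is loaded via read_audio
    print(result.text, result.confidence)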

sonusai/utils/audio_devices.py CHANGED
@@ -1,29 +1,29 @@
 import pyaudio
 
 
-def get_input_device_index_by_name(p: pyaudio.PyAudio, name: str = None) -> int:
+def get_input_device_index_by_name(p: pyaudio.PyAudio, name: str | None = None) -> int:
     info = p.get_host_api_info_by_index(0)
-    device_count = info.get('deviceCount')
+    device_count = info.get("deviceCount")
     for i in range(0, device_count):
         device_info = p.get_device_info_by_host_api_device_index(0, i)
         if name is None:
             device_name = None
         else:
-            device_name = device_info.get('name')
-        if name == device_name and device_info.get('maxInputChannels') > 0:
+            device_name = device_info.get("name")
+        if name == device_name and device_info.get("maxInputChannels") > 0:
             return i
 
-    raise ValueError(f'Could not find {name}')
+    raise ValueError(f"Could not find {name}")
 
 
 def get_input_devices(p: pyaudio.PyAudio) -> list[str]:
     names = []
     info = p.get_host_api_info_by_index(0)
-    device_count = info.get('deviceCount')
+    device_count = info.get("deviceCount")
     for i in range(0, device_count):
         device_info = p.get_device_info_by_host_api_device_index(0, i)
-        device_name = device_info.get('name')
-        if device_info.get('maxInputChannels') > 0:
+        device_name = device_info.get("name")
+        if device_info.get("maxInputChannels") > 0:
            names.append(device_name)
 
     return names
@@ -31,11 +31,11 @@ def get_input_devices(p: pyaudio.PyAudio) -> list[str]:
 
 def get_default_input_device(p: pyaudio.PyAudio) -> str:
     info = p.get_host_api_info_by_index(0)
-    device_count = info.get('deviceCount')
+    device_count = info.get("deviceCount")
     for i in range(0, device_count):
         device_info = p.get_device_info_by_host_api_device_index(0, i)
-        device_name = device_info.get('name')
-        if device_info.get('maxInputChannels') > 0:
+        device_name = device_info.get("name")
+        if device_info.get("maxInputChannels") > 0:
             return device_name
 
-    raise ValueError('No input audio devices found')
+    raise ValueError("No input audio devices found")
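These helpers wrap PyAudio's host-API queries; a minimal sketch, assuming a working PortAudio installation with at least one capture device:

    import pyaudio

    from sonusai.utils import get_default_input_device, get_input_device_index_by_name, get_input_devices

    p = pyaudio.PyAudio()
    try:
        print(get_input_devices(p))              # names of all devices with input channels
        name = get_default_input_device(p)       # raises ValueError if none are found
        index = get_input_device_index_by_name(p, name)
        print(name, index)
    finally:
        p.terminate()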

sonusai/utils/braced_glob.py CHANGED
@@ -1,10 +1,8 @@
-from typing import Generator
+from collections.abc import Generator
 from typing import LiteralString
-from typing import Optional
-from typing import Set
 
 
-def expand_braces(text: LiteralString | str | bytes, seen: Optional[Set[str]] = None) -> Generator[str, None, None]:
+def expand_braces(text: LiteralString | str | bytes, seen: set[str] | None = None) -> Generator[str, None, None]:
     """Brace-expansion pre-processing for glob.
 
     Expand all the braces, then run glob on each of the results.
@@ -20,8 +18,8 @@ def expand_braces(text: LiteralString | str | bytes, seen: Optional[Set[str]] = None) -> Generator[str, None, None]:
     if not isinstance(text, str):
         text = str(text)
 
-    spans = [m.span() for m in re.finditer(r'\{[^{}]*}', text)][::-1]
-    alts = [text[start + 1: stop - 1].split(',') for start, stop in spans]
+    spans = [m.span() for m in re.finditer(r"\{[^{}]*}", text)][::-1]
+    alts = [text[start + 1 : stop - 1].split(",") for start, stop in spans]
 
     if len(spans) == 0:
         if text not in seen:
@@ -30,9 +28,9 @@ def expand_braces(text: LiteralString | str | bytes, seen: Optional[Set[str]] = None) -> Generator[str, None, None]:
     else:
         for combo in itertools.product(*alts):
             replaced = list(text)
-            for (start, stop), replacement in zip(spans, combo):
+            for (start, stop), replacement in zip(spans, combo, strict=False):
                 replaced[start:stop] = replacement
-            yield from expand_braces(''.join(replaced), seen)
+            yield from expand_braces("".join(replaced), seen)
 
 
 def braced_glob(pathname: LiteralString | str | bytes, recursive: bool = False) -> list[str]:
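To make the behavior concrete, a small sketch of the brace expansion these functions provide (paths are placeholders):

    from sonusai.utils.braced_glob import braced_glob, expand_braces

    # One pattern covers several directories; expansion happens before globbing.
    print(sorted(expand_braces("data/{train,test}/{clean,noisy}")))
    # ['data/test/clean', 'data/test/noisy', 'data/train/clean', 'data/train/noisy']

    wavs = braced_glob("data/{train,test}/**/*.wav", recursive=True)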

sonusai/utils/calculate_input_shape.py CHANGED
@@ -1,7 +1,4 @@
-def calculate_input_shape(feature: str,
-                          flatten: bool = False,
-                          timesteps: int = 0,
-                          add1ch: bool = False) -> list[int]:
+def calculate_input_shape(feature: str, flatten: bool = False, timesteps: int = 0, add1ch: bool = False) -> list[int]:
     """
     Calculate input shape given feature and user-specified reshape parameters.
 

sonusai/utils/compress.py CHANGED
@@ -6,7 +6,7 @@ def power_compress(feature: AudioF) -> AudioF:
 
     mag = np.abs(feature)
     phase = np.angle(feature)
-    mag = mag ** 0.3
+    mag = mag**0.3
     real_compress = mag * np.cos(phase)
     imag_compress = mag * np.sin(phase)
 
@@ -18,7 +18,7 @@ def power_uncompress(feature: AudioF) -> AudioF:
 
     mag = np.abs(feature)
     phase = np.angle(feature)
-    mag = mag ** (1. / 0.3)
+    mag = mag ** (1.0 / 0.3)
     real_uncompress = mag * np.cos(phase)
     imag_uncompress = mag * np.sin(phase)
 
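Both edits above are formatting only; the power-law compression itself is unchanged. A standalone NumPy illustration of the relationship (not calling the sonusai functions, whose return layout is outside this hunk):

    import numpy as np

    x = np.array([1.0 + 1.0j, 0.5 - 0.25j])        # toy complex spectral bins
    mag, phase = np.abs(x), np.angle(x)
    compressed = (mag**0.3) * np.exp(1j * phase)   # magnitude compressed, phase preserved
    restored = np.abs(compressed) ** (1.0 / 0.3)   # the exponents invert each other
    print(np.allclose(restored, mag))              # True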

sonusai/utils/create_ts_name.py CHANGED
@@ -6,9 +6,9 @@ def create_ts_name(base: str) -> str:
     ts = datetime.now()
 
     # First try just date
-    dir_name = base + '-' + ts.strftime('%Y%m%d')
+    dir_name = base + "-" + ts.strftime("%Y%m%d")
     if exists(dir_name):
         # add hour-min-sec if necessary
-        dir_name = base + '-' + ts.strftime('%Y%m%d-%H%M%S')
+        dir_name = base + "-" + ts.strftime("%Y%m%d-%H%M%S")
 
     return dir_name

sonusai/utils/dataclass_from_dict.py CHANGED
@@ -4,6 +4,6 @@ def dataclass_from_dict(klass, dikt):
         field_types = klass.__annotations__
         return klass(**{f: dataclass_from_dict(field_types[f], dikt[f]) for f in dikt})
     except AttributeError:
-        if isinstance(dikt, (tuple, list)):
+        if isinstance(dikt, tuple | list):
            return [dataclass_from_dict(klass.__args__[0], f) for f in dikt]
        return dikt
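A small sketch of what this helper does with nested dataclasses (the types here are illustrative, not from sonusai):

    from dataclasses import dataclass

    from sonusai.utils.dataclass_from_dict import dataclass_from_dict

    @dataclass
    class Point:
        x: float
        y: float

    @dataclass
    class Segment:
        name: str
        points: list[Point]

    seg = dataclass_from_dict(Segment, {"name": "a", "points": [{"x": 0.0, "y": 1.0}]})
    print(seg.points[0].y)  # 1.0 -- nested dicts and lists become nested dataclasses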

sonusai/utils/docstring.py CHANGED
@@ -3,7 +3,7 @@ def trim_docstring(docstring: str) -> str:
     from sys import maxsize
 
     if not docstring:
-        return ''
+        return ""
 
     # Convert tabs to spaces (following the normal Python rules)
     # and split into a list of lines:
@@ -27,7 +27,7 @@ def trim_docstring(docstring: str) -> str:
         trimmed.pop(0)
 
     # Return a single string
-    return '\n'.join(trimmed)
+    return "\n".join(trimmed)
 
 
 def add_commands_to_docstring(docstring: str, plugin_docstrings: list[str]) -> str:
@@ -36,8 +36,8 @@ def add_commands_to_docstring(docstring: str, plugin_docstrings: list[str]) -> str:
 
     lines = docstring.splitlines()
 
-    start = lines.index('The sonusai commands are:')
-    end = lines.index('', start)
+    start = lines.index("The sonusai commands are:")
+    end = lines.index("", start)
 
     commands = sonusai.commands_doc.splitlines()
     for plugin_docstring in plugin_docstrings:
@@ -45,6 +45,6 @@ def add_commands_to_docstring(docstring: str, plugin_docstrings: list[str]) -> str:
     commands.sort()
     commands = list(filter(None, commands))
 
-    lines = lines[:start + 1] + commands + lines[end:]
+    lines = lines[: start + 1] + commands + lines[end:]
 
-    return '\n'.join(lines)
+    return "\n".join(lines)

sonusai/utils/energy_f.py CHANGED
@@ -1,12 +1,15 @@
-from
+from pyaaware import ForwardTransform
+
 from sonusai.mixture import AudioF
 from sonusai.mixture import AudioT
 from sonusai.mixture import EnergyF
 
 
-def compute_energy_f(frequency_domain: AudioF = None,
-
-
+def compute_energy_f(
+    frequency_domain: AudioF | None = None,
+    time_domain: AudioT | None = None,
+    transform: ForwardTransform | None = None,
+) -> EnergyF:
     """Compute the energy in each bin
 
     Must provide either frequency domain or time domain input. If time domain input is provided, must also provide
@@ -19,13 +22,12 @@ def compute_energy_f(frequency_domain: AudioF = None,
     """
     import numpy as np
     import torch
-    from sonusai import SonusAIError
 
     if frequency_domain is None:
         if time_domain is None:
-            raise
+            raise ValueError("Must provide time or frequency domain input")
         if transform is None:
-            raise
+            raise ValueError("Must provide ForwardTransform object")
 
         frequency_domain = transform.execute_all(torch.from_numpy(time_domain))[0].numpy()
 
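A minimal sketch of the frequency-domain path, which needs no transform; the [frames, bins] shape is an assumption about AudioF, not something stated in this hunk:

    import numpy as np

    from sonusai.utils.energy_f import compute_energy_f

    rng = np.random.default_rng(0)
    f = (rng.standard_normal((10, 257)) + 1j * rng.standard_normal((10, 257))).astype(np.complex64)
    energy = compute_energy_f(frequency_domain=f)  # per-bin energy; no ForwardTransform required
    print(np.shape(energy))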