sonusai 0.18.9__py3-none-any.whl → 0.19.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (118) hide show
  1. sonusai/__init__.py +20 -29
  2. sonusai/aawscd_probwrite.py +18 -18
  3. sonusai/audiofe.py +93 -80
  4. sonusai/calc_metric_spenh.py +395 -321
  5. sonusai/data/genmixdb.yml +5 -11
  6. sonusai/{gentcst.py → deprecated/gentcst.py} +146 -149
  7. sonusai/{plot.py → deprecated/plot.py} +177 -131
  8. sonusai/{tplot.py → deprecated/tplot.py} +124 -102
  9. sonusai/doc/__init__.py +1 -1
  10. sonusai/doc/doc.py +112 -177
  11. sonusai/doc.py +10 -10
  12. sonusai/genft.py +81 -91
  13. sonusai/genmetrics.py +51 -61
  14. sonusai/genmix.py +105 -115
  15. sonusai/genmixdb.py +201 -174
  16. sonusai/lsdb.py +56 -66
  17. sonusai/main.py +23 -20
  18. sonusai/metrics/__init__.py +2 -0
  19. sonusai/metrics/calc_audio_stats.py +29 -24
  20. sonusai/metrics/calc_class_weights.py +7 -7
  21. sonusai/metrics/calc_optimal_thresholds.py +5 -7
  22. sonusai/metrics/calc_pcm.py +3 -3
  23. sonusai/metrics/calc_pesq.py +10 -7
  24. sonusai/metrics/calc_phase_distance.py +3 -3
  25. sonusai/metrics/calc_sa_sdr.py +10 -8
  26. sonusai/metrics/calc_segsnr_f.py +16 -18
  27. sonusai/metrics/calc_speech.py +105 -47
  28. sonusai/metrics/calc_wer.py +35 -32
  29. sonusai/metrics/calc_wsdr.py +10 -7
  30. sonusai/metrics/class_summary.py +30 -27
  31. sonusai/metrics/confusion_matrix_summary.py +25 -22
  32. sonusai/metrics/one_hot.py +91 -57
  33. sonusai/metrics/snr_summary.py +53 -46
  34. sonusai/mixture/__init__.py +20 -14
  35. sonusai/mixture/audio.py +4 -6
  36. sonusai/mixture/augmentation.py +37 -43
  37. sonusai/mixture/class_count.py +5 -14
  38. sonusai/mixture/config.py +292 -225
  39. sonusai/mixture/constants.py +41 -30
  40. sonusai/mixture/data_io.py +155 -0
  41. sonusai/mixture/datatypes.py +111 -108
  42. sonusai/mixture/db_datatypes.py +54 -70
  43. sonusai/mixture/eq_rule_is_valid.py +6 -9
  44. sonusai/mixture/feature.py +40 -38
  45. sonusai/mixture/generation.py +522 -389
  46. sonusai/mixture/helpers.py +217 -272
  47. sonusai/mixture/log_duration_and_sizes.py +16 -13
  48. sonusai/mixture/mixdb.py +669 -477
  49. sonusai/mixture/soundfile_audio.py +12 -17
  50. sonusai/mixture/sox_audio.py +91 -112
  51. sonusai/mixture/sox_augmentation.py +8 -9
  52. sonusai/mixture/spectral_mask.py +4 -6
  53. sonusai/mixture/target_class_balancing.py +41 -36
  54. sonusai/mixture/targets.py +69 -67
  55. sonusai/mixture/tokenized_shell_vars.py +23 -23
  56. sonusai/mixture/torchaudio_audio.py +14 -15
  57. sonusai/mixture/torchaudio_augmentation.py +23 -27
  58. sonusai/mixture/truth.py +48 -26
  59. sonusai/mixture/truth_functions/__init__.py +26 -0
  60. sonusai/mixture/truth_functions/crm.py +56 -38
  61. sonusai/mixture/truth_functions/datatypes.py +37 -0
  62. sonusai/mixture/truth_functions/energy.py +85 -59
  63. sonusai/mixture/truth_functions/file.py +30 -30
  64. sonusai/mixture/truth_functions/phoneme.py +14 -7
  65. sonusai/mixture/truth_functions/sed.py +71 -45
  66. sonusai/mixture/truth_functions/target.py +69 -106
  67. sonusai/mkwav.py +58 -101
  68. sonusai/onnx_predict.py +46 -43
  69. sonusai/queries/__init__.py +3 -1
  70. sonusai/queries/queries.py +100 -59
  71. sonusai/speech/__init__.py +2 -0
  72. sonusai/speech/l2arctic.py +24 -23
  73. sonusai/speech/librispeech.py +16 -17
  74. sonusai/speech/mcgill.py +22 -21
  75. sonusai/speech/textgrid.py +32 -25
  76. sonusai/speech/timit.py +45 -42
  77. sonusai/speech/vctk.py +14 -13
  78. sonusai/speech/voxceleb.py +26 -20
  79. sonusai/summarize_metric_spenh.py +11 -10
  80. sonusai/utils/__init__.py +4 -3
  81. sonusai/utils/asl_p56.py +1 -1
  82. sonusai/utils/asr.py +37 -17
  83. sonusai/utils/asr_functions/__init__.py +2 -0
  84. sonusai/utils/asr_functions/aaware_whisper.py +18 -12
  85. sonusai/utils/audio_devices.py +12 -12
  86. sonusai/utils/braced_glob.py +6 -8
  87. sonusai/utils/calculate_input_shape.py +1 -4
  88. sonusai/utils/compress.py +2 -2
  89. sonusai/utils/convert_string_to_number.py +1 -3
  90. sonusai/utils/create_timestamp.py +1 -1
  91. sonusai/utils/create_ts_name.py +2 -2
  92. sonusai/utils/dataclass_from_dict.py +1 -1
  93. sonusai/utils/docstring.py +6 -6
  94. sonusai/utils/energy_f.py +9 -7
  95. sonusai/utils/engineering_number.py +56 -54
  96. sonusai/utils/get_label_names.py +8 -10
  97. sonusai/utils/human_readable_size.py +2 -2
  98. sonusai/utils/model_utils.py +3 -5
  99. sonusai/utils/numeric_conversion.py +2 -4
  100. sonusai/utils/onnx_utils.py +43 -32
  101. sonusai/utils/parallel.py +41 -30
  102. sonusai/utils/print_mixture_details.py +25 -22
  103. sonusai/utils/ranges.py +12 -12
  104. sonusai/utils/read_predict_data.py +11 -9
  105. sonusai/utils/reshape.py +19 -26
  106. sonusai/utils/seconds_to_hms.py +1 -1
  107. sonusai/utils/stacked_complex.py +8 -16
  108. sonusai/utils/stratified_shuffle_split.py +29 -27
  109. sonusai/utils/write_audio.py +2 -2
  110. sonusai/utils/yes_or_no.py +3 -3
  111. sonusai/vars.py +14 -14
  112. {sonusai-0.18.9.dist-info → sonusai-0.19.6.dist-info}/METADATA +20 -21
  113. sonusai-0.19.6.dist-info/RECORD +125 -0
  114. {sonusai-0.18.9.dist-info → sonusai-0.19.6.dist-info}/WHEEL +1 -1
  115. sonusai/mixture/truth_functions/data.py +0 -58
  116. sonusai/utils/read_mixture_data.py +0 -14
  117. sonusai-0.18.9.dist-info/RECORD +0 -125
  118. {sonusai-0.18.9.dist-info → sonusai-0.19.6.dist-info}/entry_points.txt +0 -0
@@ -13,6 +13,7 @@ Inputs:
13
13
  LOC A SonusAI calc_metric_spenh results directory.
14
14
 
15
15
  """
16
+
16
17
  import signal
17
18
 
18
19
 
@@ -21,24 +22,24 @@ def signal_handler(_sig, _frame):
21
22
 
22
23
  from sonusai import logger
23
24
 
24
- logger.info('Canceled due to keyboard interrupt')
25
+ logger.info("Canceled due to keyboard interrupt")
25
26
  sys.exit(1)
26
27
 
27
28
 
28
29
  signal.signal(signal.SIGINT, signal_handler)
29
30
 
30
31
 
31
- def summarize_metric_spenh(location: str, by: str = 'MIXID', reverse: bool = False) -> str:
32
+ def summarize_metric_spenh(location: str, by: str = "MIXID", reverse: bool = False) -> str:
32
33
  import glob
33
34
 
34
35
  import pandas as pd
35
36
 
36
- files = sorted(glob.glob(location + '/*_metric_spenh.txt'))
37
+ files = sorted(glob.glob(location + "/*_metric_spenh.txt"))
37
38
  need_header = True
38
- header = ['MIXID']
39
+ header = ["MIXID"]
39
40
  data = []
40
41
  for file in files:
41
- with open(file, 'r') as f:
42
+ with open(file) as f:
42
43
  for i, line in enumerate(f):
43
44
  if i == 1 and need_header:
44
45
  need_header = False
@@ -48,7 +49,7 @@ def summarize_metric_spenh(location: str, by: str = 'MIXID', reverse: bool = Fal
48
49
  break
49
50
 
50
51
  df = pd.DataFrame(data, columns=header)
51
- df[header[0:-2]] = df[header[0:-2]].apply(pd.to_numeric, errors='coerce')
52
+ df[header[0:-2]] = df[header[0:-2]].apply(pd.to_numeric, errors="coerce")
52
53
  return df.sort_values(by=by, ascending=not reverse).to_string(index=False)
53
54
 
54
55
 
@@ -60,12 +61,12 @@ def main():
60
61
 
61
62
  args = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)
62
63
 
63
- by = args['--sort']
64
- reverse = args['--reverse']
65
- location = args['LOC']
64
+ by = args["--sort"]
65
+ reverse = args["--reverse"]
66
+ location = args["LOC"]
66
67
 
67
68
  print(summarize_metric_spenh(location, by, reverse))
68
69
 
69
70
 
70
- if __name__ == '__main__':
71
+ if __name__ == "__main__":
71
72
  main()
sonusai/utils/__init__.py CHANGED
@@ -1,8 +1,10 @@
1
1
  # SonusAI general utilities
2
+ # ruff: noqa: F401
2
3
  from .asl_p56 import asl_p56
3
4
  from .asr import ASRResult
4
5
  from .asr import calc_asr
5
6
  from .asr import get_available_engines
7
+ from .asr import validate_asr
6
8
  from .audio_devices import get_default_input_device
7
9
  from .audio_devices import get_input_device_index_by_name
8
10
  from .audio_devices import get_input_devices
@@ -32,14 +34,13 @@ from .numeric_conversion import int16_to_float
32
34
  from .onnx_utils import add_sonusai_metadata
33
35
  from .onnx_utils import get_sonusai_metadata
34
36
  from .onnx_utils import load_ort_session
35
- from .parallel import pp_imap
36
- from .parallel import pp_tqdm_imap
37
+ from .parallel import par_track
38
+ from .parallel import track
37
39
  from .path_info import PathInfo
38
40
  from .print_mixture_details import print_class_count
39
41
  from .print_mixture_details import print_mixture_details
40
42
  from .ranges import consolidate_range
41
43
  from .ranges import expand_range
42
- from .read_mixture_data import read_mixture_data
43
44
  from .read_predict_data import read_predict_data
44
45
  from .reshape import get_input_shape
45
46
  from .reshape import get_num_classes_from_predict
sonusai/utils/asl_p56.py CHANGED
@@ -88,7 +88,7 @@ def asl_p56(audio: AudioT) -> float:
88
88
  # Interpolate to find the asl_ms_log
89
89
  asl_ms_log = _bin_interp(A_db[j], A_db[j - 1], C_db[j], C_db[j - 1], M, 0.5)
90
90
  # This is the mean square value NOT the RMS
91
- asl_msq = 10. ** (asl_ms_log / 10)
91
+ asl_msq = 10.0 ** (asl_ms_log / 10)
92
92
  break
93
93
 
94
94
  return asl_msq
sonusai/utils/asr.py CHANGED
@@ -1,6 +1,5 @@
1
+ from collections.abc import Callable
1
2
  from dataclasses import dataclass
2
- from typing import Callable
3
- from typing import Optional
4
3
 
5
4
  from sonusai.mixture import AudioT
6
5
 
@@ -8,25 +7,25 @@ from sonusai.mixture import AudioT
8
7
  @dataclass(frozen=True)
9
8
  class ASRResult:
10
9
  text: str
11
- confidence: Optional[float] = None
12
- lang: Optional[str] = None
13
- lang_prob: Optional[float] = None
14
- duration: Optional[float] = None
15
- num_segments: Optional[int] = None
16
- asr_cpu_time: Optional[float] = None
10
+ confidence: float | None = None
11
+ lang: str | None = None
12
+ lang_prob: float | None = None
13
+ duration: float | None = None
14
+ num_segments: int | None = None
15
+ asr_cpu_time: float | None = None
17
16
 
18
17
 
19
18
  def get_available_engines() -> list[str]:
20
19
  from importlib import import_module
21
20
  from pkgutil import iter_modules
22
21
 
23
- module = import_module('sonusai.utils.asr_functions')
24
- engines = [method for method in dir(module) if not method.startswith('_')]
22
+ module = import_module("sonusai.utils.asr_functions")
23
+ engines = [method for method in dir(module) if not method.startswith("_")]
25
24
  for _, name, _ in iter_modules():
26
- if name.startswith('sonusai_asr_'):
27
- module = import_module(f'{name}.asr_functions')
25
+ if name.startswith("sonusai_asr_"):
26
+ module = import_module(f"{name}.asr_functions")
28
27
  for method in dir(module):
29
- if not method.startswith('_'):
28
+ if not method.startswith("_"):
30
29
  engines.append(method)
31
30
 
32
31
  return engines
@@ -36,19 +35,19 @@ def _asr_fn(engine: str) -> Callable[..., ASRResult]:
36
35
  from importlib import import_module
37
36
  from pkgutil import iter_modules
38
37
 
39
- module = import_module('sonusai.utils.asr_functions')
38
+ module = import_module("sonusai.utils.asr_functions")
40
39
  for method in dir(module):
41
40
  if method == engine:
42
41
  return getattr(module, method)
43
42
 
44
43
  for _, name, _ in iter_modules():
45
- if name.startswith('sonusai_asr_'):
46
- module = import_module(f'{name}.asr_functions')
44
+ if name.startswith("sonusai_asr_"):
45
+ module = import_module(f"{name}.asr_functions")
47
46
  for method in dir(module):
48
47
  if method == engine:
49
48
  return getattr(module, method)
50
49
 
51
- raise ValueError(f'engine {engine} not supported')
50
+ raise ValueError(f"engine {engine} not supported")
52
51
 
53
52
 
54
53
  def calc_asr(audio: AudioT | str, engine: str, **config) -> ASRResult:
@@ -69,3 +68,24 @@ def calc_asr(audio: AudioT | str, engine: str, **config) -> ASRResult:
69
68
  audio = copy(read_audio(audio))
70
69
 
71
70
  return _asr_fn(engine)(audio, **config)
71
+
72
+
73
+ def validate_asr(engine: str, **config) -> None:
74
+ from importlib import import_module
75
+ from pkgutil import iter_modules
76
+
77
+ module = import_module("sonusai.utils.asr_functions")
78
+ for method in dir(module):
79
+ if method == engine:
80
+ getattr(module, method + "_validate")(**config)
81
+ return
82
+
83
+ for _, name, _ in iter_modules():
84
+ if name.startswith("sonusai_asr_"):
85
+ module = import_module(f"{name}.asr_functions")
86
+ for method in dir(module):
87
+ if method == engine:
88
+ getattr(module, method + "_validate")(**config)
89
+ return
90
+
91
+ raise ValueError(f"engine {engine} not supported")
@@ -1 +1,3 @@
1
+ # ruff: noqa: F401
2
+
1
3
  from .aaware_whisper import aaware_whisper
@@ -2,6 +2,10 @@ from sonusai.mixture import AudioT
2
2
  from sonusai.utils import ASRResult
3
3
 
4
4
 
5
+ def aaware_whisper_validate(**_config) -> None:
6
+ pass
7
+
8
+
5
9
  def aaware_whisper(audio: AudioT, **_config) -> ASRResult:
6
10
  import tempfile
7
11
  from math import exp
@@ -10,32 +14,34 @@ def aaware_whisper(audio: AudioT, **_config) -> ASRResult:
10
14
 
11
15
  import requests
12
16
 
13
- from sonusai import SonusAIError
14
17
  from sonusai.utils import ASRResult
15
18
  from sonusai.utils import float_to_int16
16
19
  from sonusai.utils import write_audio
17
20
 
18
- url = getenv('AAWARE_WHISPER_URL')
21
+ url = getenv("AAWARE_WHISPER_URL")
19
22
  if url is None:
20
- raise SonusAIError(f'AAWARE_WHISPER_URL environment variable does not exist')
21
- url += '/asr?task=transcribe&language=en&encode=true&output=json'
23
+ raise EnvironmentError("AAWARE_WHISPER_URL environment variable does not exist")
24
+ url += "/asr?task=transcribe&language=en&encode=true&output=json"
22
25
 
23
26
  with tempfile.TemporaryDirectory() as tmp:
24
- file = join(tmp, 'asr.wav')
27
+ file = join(tmp, "asr.wav")
25
28
  write_audio(name=file, audio=float_to_int16(audio))
26
29
 
27
- files = {'audio_file': (file, open(file, 'rb'), 'audio/wav')}
30
+ files = {"audio_file": (file, open(file, "rb"), "audio/wav")} # noqa: SIM115
28
31
 
29
32
  try:
30
- response = requests.post(url, files=files)
31
- if not response.status_code == 200:
33
+ response = requests.post(url, files=files) # noqa: S113
34
+ if response.status_code != 200:
32
35
  if response.status_code == 422:
33
- raise SonusAIError(f'Validation error: {response.json()}')
34
- raise SonusAIError(f'Invalid response: {response.status_code}')
36
+ raise RuntimeError(f"Validation error: {response.json()}") # noqa: TRY301
37
+ raise RuntimeError(f"Invalid response: {response.status_code}") # noqa: TRY301
35
38
  result = response.json()
36
- return ASRResult(text=result['text'], confidence=exp(float(result['segments'][0]['avg_logprob'])))
39
+ return ASRResult(
40
+ text=result["text"],
41
+ confidence=exp(float(result["segments"][0]["avg_logprob"])),
42
+ )
37
43
  except Exception as e:
38
- raise SonusAIError(f'Aaware Whisper exception: {e.args}')
44
+ raise RuntimeError(f"Aaware Whisper exception: {e.args}") from e
39
45
 
40
46
 
41
47
  """
@@ -1,29 +1,29 @@
1
1
  import pyaudio
2
2
 
3
3
 
4
- def get_input_device_index_by_name(p: pyaudio.PyAudio, name: str = None) -> int:
4
+ def get_input_device_index_by_name(p: pyaudio.PyAudio, name: str | None = None) -> int:
5
5
  info = p.get_host_api_info_by_index(0)
6
- device_count = info.get('deviceCount')
6
+ device_count = info.get("deviceCount")
7
7
  for i in range(0, device_count):
8
8
  device_info = p.get_device_info_by_host_api_device_index(0, i)
9
9
  if name is None:
10
10
  device_name = None
11
11
  else:
12
- device_name = device_info.get('name')
13
- if name == device_name and device_info.get('maxInputChannels') > 0:
12
+ device_name = device_info.get("name")
13
+ if name == device_name and device_info.get("maxInputChannels") > 0:
14
14
  return i
15
15
 
16
- raise ValueError(f'Could not find {name}')
16
+ raise ValueError(f"Could not find {name}")
17
17
 
18
18
 
19
19
  def get_input_devices(p: pyaudio.PyAudio) -> list[str]:
20
20
  names = []
21
21
  info = p.get_host_api_info_by_index(0)
22
- device_count = info.get('deviceCount')
22
+ device_count = info.get("deviceCount")
23
23
  for i in range(0, device_count):
24
24
  device_info = p.get_device_info_by_host_api_device_index(0, i)
25
- device_name = device_info.get('name')
26
- if device_info.get('maxInputChannels') > 0:
25
+ device_name = device_info.get("name")
26
+ if device_info.get("maxInputChannels") > 0:
27
27
  names.append(device_name)
28
28
 
29
29
  return names
@@ -31,11 +31,11 @@ def get_input_devices(p: pyaudio.PyAudio) -> list[str]:
31
31
 
32
32
  def get_default_input_device(p: pyaudio.PyAudio) -> str:
33
33
  info = p.get_host_api_info_by_index(0)
34
- device_count = info.get('deviceCount')
34
+ device_count = info.get("deviceCount")
35
35
  for i in range(0, device_count):
36
36
  device_info = p.get_device_info_by_host_api_device_index(0, i)
37
- device_name = device_info.get('name')
38
- if device_info.get('maxInputChannels') > 0:
37
+ device_name = device_info.get("name")
38
+ if device_info.get("maxInputChannels") > 0:
39
39
  return device_name
40
40
 
41
- raise ValueError('No input audio devices found')
41
+ raise ValueError("No input audio devices found")
@@ -1,10 +1,8 @@
1
- from typing import Generator
1
+ from collections.abc import Generator
2
2
  from typing import LiteralString
3
- from typing import Optional
4
- from typing import Set
5
3
 
6
4
 
7
- def expand_braces(text: LiteralString | str | bytes, seen: Optional[Set[str]] = None) -> Generator[str, None, None]:
5
+ def expand_braces(text: LiteralString | str | bytes, seen: set[str] | None = None) -> Generator[str, None, None]:
8
6
  """Brace-expansion pre-processing for glob.
9
7
 
10
8
  Expand all the braces, then run glob on each of the results.
@@ -20,8 +18,8 @@ def expand_braces(text: LiteralString | str | bytes, seen: Optional[Set[str]] =
20
18
  if not isinstance(text, str):
21
19
  text = str(text)
22
20
 
23
- spans = [m.span() for m in re.finditer(r'\{[^{}]*}', text)][::-1]
24
- alts = [text[start + 1: stop - 1].split(',') for start, stop in spans]
21
+ spans = [m.span() for m in re.finditer(r"\{[^{}]*}", text)][::-1]
22
+ alts = [text[start + 1 : stop - 1].split(",") for start, stop in spans]
25
23
 
26
24
  if len(spans) == 0:
27
25
  if text not in seen:
@@ -30,9 +28,9 @@ def expand_braces(text: LiteralString | str | bytes, seen: Optional[Set[str]] =
30
28
  else:
31
29
  for combo in itertools.product(*alts):
32
30
  replaced = list(text)
33
- for (start, stop), replacement in zip(spans, combo):
31
+ for (start, stop), replacement in zip(spans, combo, strict=False):
34
32
  replaced[start:stop] = replacement
35
- yield from expand_braces(''.join(replaced), seen)
33
+ yield from expand_braces("".join(replaced), seen)
36
34
 
37
35
 
38
36
  def braced_glob(pathname: LiteralString | str | bytes, recursive: bool = False) -> list[str]:
@@ -1,7 +1,4 @@
1
- def calculate_input_shape(feature: str,
2
- flatten: bool = False,
3
- timesteps: int = 0,
4
- add1ch: bool = False) -> list[int]:
1
+ def calculate_input_shape(feature: str, flatten: bool = False, timesteps: int = 0, add1ch: bool = False) -> list[int]:
5
2
  """
6
3
  Calculate input shape given feature and user-specified reshape parameters.
7
4
 
sonusai/utils/compress.py CHANGED
@@ -6,7 +6,7 @@ def power_compress(feature: AudioF) -> AudioF:
6
6
 
7
7
  mag = np.abs(feature)
8
8
  phase = np.angle(feature)
9
- mag = mag ** 0.3
9
+ mag = mag**0.3
10
10
  real_compress = mag * np.cos(phase)
11
11
  imag_compress = mag * np.sin(phase)
12
12
 
@@ -18,7 +18,7 @@ def power_uncompress(feature: AudioF) -> AudioF:
18
18
 
19
19
  mag = np.abs(feature)
20
20
  phase = np.angle(feature)
21
- mag = mag ** (1. / 0.3)
21
+ mag = mag ** (1.0 / 0.3)
22
22
  real_uncompress = mag * np.cos(phase)
23
23
  imag_uncompress = mag * np.sin(phase)
24
24
 
@@ -1,8 +1,6 @@
1
1
  def convert_string_to_number(string: str) -> float | int | str:
2
2
  try:
3
3
  result = float(string)
4
- if result == int(result):
5
- return int(result)
6
- return result
4
+ return int(result) if result == int(result) else result
7
5
  except ValueError:
8
6
  return string
@@ -2,4 +2,4 @@ def create_timestamp() -> str:
2
2
  """Create a timestamp."""
3
3
  from datetime import datetime
4
4
 
5
- return datetime.now().strftime('%Y%m%d-%H%M%S')
5
+ return datetime.now().strftime("%Y%m%d-%H%M%S")
@@ -6,9 +6,9 @@ def create_ts_name(base: str) -> str:
6
6
  ts = datetime.now()
7
7
 
8
8
  # First try just date
9
- dir_name = base + '-' + ts.strftime('%Y%m%d')
9
+ dir_name = base + "-" + ts.strftime("%Y%m%d")
10
10
  if exists(dir_name):
11
11
  # add hour-min-sec if necessary
12
- dir_name = base + '-' + ts.strftime('%Y%m%d-%H%M%S')
12
+ dir_name = base + "-" + ts.strftime("%Y%m%d-%H%M%S")
13
13
 
14
14
  return dir_name
@@ -4,6 +4,6 @@ def dataclass_from_dict(klass, dikt):
4
4
  field_types = klass.__annotations__
5
5
  return klass(**{f: dataclass_from_dict(field_types[f], dikt[f]) for f in dikt})
6
6
  except AttributeError:
7
- if isinstance(dikt, (tuple, list)):
7
+ if isinstance(dikt, tuple | list):
8
8
  return [dataclass_from_dict(klass.__args__[0], f) for f in dikt]
9
9
  return dikt
@@ -3,7 +3,7 @@ def trim_docstring(docstring: str) -> str:
3
3
  from sys import maxsize
4
4
 
5
5
  if not docstring:
6
- return ''
6
+ return ""
7
7
 
8
8
  # Convert tabs to spaces (following the normal Python rules)
9
9
  # and split into a list of lines:
@@ -27,7 +27,7 @@ def trim_docstring(docstring: str) -> str:
27
27
  trimmed.pop(0)
28
28
 
29
29
  # Return a single string
30
- return '\n'.join(trimmed)
30
+ return "\n".join(trimmed)
31
31
 
32
32
 
33
33
  def add_commands_to_docstring(docstring: str, plugin_docstrings: list[str]) -> str:
@@ -36,8 +36,8 @@ def add_commands_to_docstring(docstring: str, plugin_docstrings: list[str]) -> s
36
36
 
37
37
  lines = docstring.splitlines()
38
38
 
39
- start = lines.index('The sonusai commands are:')
40
- end = lines.index('', start)
39
+ start = lines.index("The sonusai commands are:")
40
+ end = lines.index("", start)
41
41
 
42
42
  commands = sonusai.commands_doc.splitlines()
43
43
  for plugin_docstring in plugin_docstrings:
@@ -45,6 +45,6 @@ def add_commands_to_docstring(docstring: str, plugin_docstrings: list[str]) -> s
45
45
  commands.sort()
46
46
  commands = list(filter(None, commands))
47
47
 
48
- lines = lines[:start + 1] + commands + lines[end:]
48
+ lines = lines[: start + 1] + commands + lines[end:]
49
49
 
50
- return '\n'.join(lines)
50
+ return "\n".join(lines)
sonusai/utils/energy_f.py CHANGED
@@ -1,12 +1,15 @@
1
- from sonusai import ForwardTransform
1
+ from pyaaware import ForwardTransform
2
+
2
3
  from sonusai.mixture import AudioF
3
4
  from sonusai.mixture import AudioT
4
5
  from sonusai.mixture import EnergyF
5
6
 
6
7
 
7
- def compute_energy_f(frequency_domain: AudioF = None,
8
- time_domain: AudioT = None,
9
- transform: ForwardTransform = None) -> EnergyF:
8
+ def compute_energy_f(
9
+ frequency_domain: AudioF | None = None,
10
+ time_domain: AudioT | None = None,
11
+ transform: ForwardTransform | None = None,
12
+ ) -> EnergyF:
10
13
  """Compute the energy in each bin
11
14
 
12
15
  Must provide either frequency domain or time domain input. If time domain input is provided, must also provide
@@ -19,13 +22,12 @@ def compute_energy_f(frequency_domain: AudioF = None,
19
22
  """
20
23
  import numpy as np
21
24
  import torch
22
- from sonusai import SonusAIError
23
25
 
24
26
  if frequency_domain is None:
25
27
  if time_domain is None:
26
- raise SonusAIError('Must provide time or frequency domain input')
28
+ raise ValueError("Must provide time or frequency domain input")
27
29
  if transform is None:
28
- raise SonusAIError('Must provide ForwardTransform object')
30
+ raise ValueError("Must provide ForwardTransform object")
29
31
 
30
32
  frequency_domain = transform.execute_all(torch.from_numpy(time_domain))[0].numpy()
31
33