sonusai 0.18.9__py3-none-any.whl → 0.19.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (118)
  1. sonusai/__init__.py +20 -29
  2. sonusai/aawscd_probwrite.py +18 -18
  3. sonusai/audiofe.py +93 -80
  4. sonusai/calc_metric_spenh.py +395 -321
  5. sonusai/data/genmixdb.yml +5 -11
  6. sonusai/{gentcst.py → deprecated/gentcst.py} +146 -149
  7. sonusai/{plot.py → deprecated/plot.py} +177 -131
  8. sonusai/{tplot.py → deprecated/tplot.py} +124 -102
  9. sonusai/doc/__init__.py +1 -1
  10. sonusai/doc/doc.py +112 -177
  11. sonusai/doc.py +10 -10
  12. sonusai/genft.py +81 -91
  13. sonusai/genmetrics.py +51 -61
  14. sonusai/genmix.py +105 -115
  15. sonusai/genmixdb.py +201 -174
  16. sonusai/lsdb.py +56 -66
  17. sonusai/main.py +23 -20
  18. sonusai/metrics/__init__.py +2 -0
  19. sonusai/metrics/calc_audio_stats.py +29 -24
  20. sonusai/metrics/calc_class_weights.py +7 -7
  21. sonusai/metrics/calc_optimal_thresholds.py +5 -7
  22. sonusai/metrics/calc_pcm.py +3 -3
  23. sonusai/metrics/calc_pesq.py +10 -7
  24. sonusai/metrics/calc_phase_distance.py +3 -3
  25. sonusai/metrics/calc_sa_sdr.py +10 -8
  26. sonusai/metrics/calc_segsnr_f.py +16 -18
  27. sonusai/metrics/calc_speech.py +105 -47
  28. sonusai/metrics/calc_wer.py +35 -32
  29. sonusai/metrics/calc_wsdr.py +10 -7
  30. sonusai/metrics/class_summary.py +30 -27
  31. sonusai/metrics/confusion_matrix_summary.py +25 -22
  32. sonusai/metrics/one_hot.py +91 -57
  33. sonusai/metrics/snr_summary.py +53 -46
  34. sonusai/mixture/__init__.py +20 -14
  35. sonusai/mixture/audio.py +4 -6
  36. sonusai/mixture/augmentation.py +37 -43
  37. sonusai/mixture/class_count.py +5 -14
  38. sonusai/mixture/config.py +292 -225
  39. sonusai/mixture/constants.py +41 -30
  40. sonusai/mixture/data_io.py +155 -0
  41. sonusai/mixture/datatypes.py +111 -108
  42. sonusai/mixture/db_datatypes.py +54 -70
  43. sonusai/mixture/eq_rule_is_valid.py +6 -9
  44. sonusai/mixture/feature.py +40 -38
  45. sonusai/mixture/generation.py +522 -389
  46. sonusai/mixture/helpers.py +217 -272
  47. sonusai/mixture/log_duration_and_sizes.py +16 -13
  48. sonusai/mixture/mixdb.py +669 -477
  49. sonusai/mixture/soundfile_audio.py +12 -17
  50. sonusai/mixture/sox_audio.py +91 -112
  51. sonusai/mixture/sox_augmentation.py +8 -9
  52. sonusai/mixture/spectral_mask.py +4 -6
  53. sonusai/mixture/target_class_balancing.py +41 -36
  54. sonusai/mixture/targets.py +69 -67
  55. sonusai/mixture/tokenized_shell_vars.py +23 -23
  56. sonusai/mixture/torchaudio_audio.py +14 -15
  57. sonusai/mixture/torchaudio_augmentation.py +23 -27
  58. sonusai/mixture/truth.py +48 -26
  59. sonusai/mixture/truth_functions/__init__.py +26 -0
  60. sonusai/mixture/truth_functions/crm.py +56 -38
  61. sonusai/mixture/truth_functions/datatypes.py +37 -0
  62. sonusai/mixture/truth_functions/energy.py +85 -59
  63. sonusai/mixture/truth_functions/file.py +30 -30
  64. sonusai/mixture/truth_functions/phoneme.py +14 -7
  65. sonusai/mixture/truth_functions/sed.py +71 -45
  66. sonusai/mixture/truth_functions/target.py +69 -106
  67. sonusai/mkwav.py +58 -101
  68. sonusai/onnx_predict.py +46 -43
  69. sonusai/queries/__init__.py +3 -1
  70. sonusai/queries/queries.py +100 -59
  71. sonusai/speech/__init__.py +2 -0
  72. sonusai/speech/l2arctic.py +24 -23
  73. sonusai/speech/librispeech.py +16 -17
  74. sonusai/speech/mcgill.py +22 -21
  75. sonusai/speech/textgrid.py +32 -25
  76. sonusai/speech/timit.py +45 -42
  77. sonusai/speech/vctk.py +14 -13
  78. sonusai/speech/voxceleb.py +26 -20
  79. sonusai/summarize_metric_spenh.py +11 -10
  80. sonusai/utils/__init__.py +4 -3
  81. sonusai/utils/asl_p56.py +1 -1
  82. sonusai/utils/asr.py +37 -17
  83. sonusai/utils/asr_functions/__init__.py +2 -0
  84. sonusai/utils/asr_functions/aaware_whisper.py +18 -12
  85. sonusai/utils/audio_devices.py +12 -12
  86. sonusai/utils/braced_glob.py +6 -8
  87. sonusai/utils/calculate_input_shape.py +1 -4
  88. sonusai/utils/compress.py +2 -2
  89. sonusai/utils/convert_string_to_number.py +1 -3
  90. sonusai/utils/create_timestamp.py +1 -1
  91. sonusai/utils/create_ts_name.py +2 -2
  92. sonusai/utils/dataclass_from_dict.py +1 -1
  93. sonusai/utils/docstring.py +6 -6
  94. sonusai/utils/energy_f.py +9 -7
  95. sonusai/utils/engineering_number.py +56 -54
  96. sonusai/utils/get_label_names.py +8 -10
  97. sonusai/utils/human_readable_size.py +2 -2
  98. sonusai/utils/model_utils.py +3 -5
  99. sonusai/utils/numeric_conversion.py +2 -4
  100. sonusai/utils/onnx_utils.py +43 -32
  101. sonusai/utils/parallel.py +41 -30
  102. sonusai/utils/print_mixture_details.py +25 -22
  103. sonusai/utils/ranges.py +12 -12
  104. sonusai/utils/read_predict_data.py +11 -9
  105. sonusai/utils/reshape.py +19 -26
  106. sonusai/utils/seconds_to_hms.py +1 -1
  107. sonusai/utils/stacked_complex.py +8 -16
  108. sonusai/utils/stratified_shuffle_split.py +29 -27
  109. sonusai/utils/write_audio.py +2 -2
  110. sonusai/utils/yes_or_no.py +3 -3
  111. sonusai/vars.py +14 -14
  112. {sonusai-0.18.9.dist-info → sonusai-0.19.6.dist-info}/METADATA +20 -21
  113. sonusai-0.19.6.dist-info/RECORD +125 -0
  114. {sonusai-0.18.9.dist-info → sonusai-0.19.6.dist-info}/WHEEL +1 -1
  115. sonusai/mixture/truth_functions/data.py +0 -58
  116. sonusai/utils/read_mixture_data.py +0 -14
  117. sonusai-0.18.9.dist-info/RECORD +0 -125
  118. {sonusai-0.18.9.dist-info → sonusai-0.19.6.dist-info}/entry_points.txt +0 -0
sonusai/__init__.py CHANGED
@@ -2,11 +2,10 @@ import logging
2
2
  from importlib import metadata
3
3
  from os.path import dirname
4
4
 
5
- from pyaaware import TorchForwardTransform
6
- from pyaaware import TorchInverseTransform
5
+ from rich.logging import RichHandler
6
+ from rich.traceback import install
7
7
 
8
- ForwardTransform = TorchForwardTransform
9
- InverseTransform = TorchInverseTransform
8
+ install(show_locals=True)
10
9
 
11
10
  __version__ = metadata.version(__package__)
12
11
  BASEDIR = dirname(__file__)
@@ -19,34 +18,26 @@ commands_doc = """
19
18
  genmetrics Generate mixture metrics data
20
19
  genmix Generate mixture and truth data
21
20
  genmixdb Generate a mixture database
22
- gentcst Generate target configuration from a subdirectory tree
23
21
  lsdb List information about a mixture database
24
22
  mkwav Make WAV files from a mixture database
25
23
  onnx_predict Run ONNX predict on a trained model
26
- plot Plot mixture data
27
24
  summarize_metric_spenh Summarize speech enhancement and analysis results
28
- tplot Plot truth data
29
25
  vars List custom SonusAI variables
30
26
  """
31
27
 
32
28
  # create logger
33
- logger = logging.getLogger('sonusai')
29
+ logger = logging.getLogger("sonusai")
34
30
  logger.setLevel(logging.DEBUG)
35
- formatter = logging.Formatter('%(message)s')
36
- console_handler = logging.StreamHandler()
31
+ formatter = logging.Formatter("%(message)s")
32
+ console_handler = RichHandler(show_level=False, show_path=False, show_time=False)
37
33
  console_handler.setLevel(logging.DEBUG)
38
34
  console_handler.setFormatter(formatter)
39
35
  logger.addHandler(console_handler)
40
36
 
41
37
 
42
- class SonusAIError(Exception):
43
- def __init__(self, value):
44
- logger.error(value)
45
-
46
-
47
38
  # create file handler
48
39
  def create_file_handler(filename: str) -> None:
49
- fh = logging.FileHandler(filename=filename, mode='w')
40
+ fh = logging.FileHandler(filename=filename, mode="w")
50
41
  fh.setLevel(logging.DEBUG)
51
42
  fh.setFormatter(formatter)
52
43
  logger.addHandler(fh)
@@ -61,7 +52,7 @@ def update_console_handler(verbose: bool) -> None:
61
52
 
62
53
 
63
54
  # write initial log message
64
- def initial_log_messages(name: str, subprocess: str = None) -> None:
55
+ def initial_log_messages(name: str, subprocess: str | None = None) -> None:
65
56
  from datetime import datetime
66
57
  from getpass import getuser
67
58
  from os import getcwd
@@ -69,24 +60,24 @@ def initial_log_messages(name: str, subprocess: str = None) -> None:
69
60
  from sys import argv
70
61
 
71
62
  if subprocess is None:
72
- logger.info(f'SonusAI {__version__}')
63
+ logger.info(f"SonusAI {__version__}")
73
64
  else:
74
- logger.info(f'SonusAI {subprocess}')
75
- logger.info(f'{name}')
76
- logger.info('')
77
- logger.debug(f'Host: {gethostname()}')
78
- logger.debug(f'User: {getuser()}')
79
- logger.debug(f'Directory: {getcwd()}')
80
- logger.debug(f'Date: {datetime.now()}')
81
- logger.debug(f'Command: {" ".join(argv)}')
82
- logger.debug('')
65
+ logger.info(f"SonusAI {subprocess}")
66
+ logger.info(f"{name}")
67
+ logger.info("")
68
+ logger.debug(f"Host: {gethostname()}")
69
+ logger.debug(f"User: {getuser()}")
70
+ logger.debug(f"Directory: {getcwd()}")
71
+ logger.debug(f"Date: {datetime.now()}")
72
+ logger.debug(f"Command: {' '.join(argv)}")
73
+ logger.debug("")
83
74
 
84
75
 
85
76
  def commands_list(doc: str = commands_doc) -> list[str]:
86
- lines = doc.split('\n')
77
+ lines = doc.split("\n")
87
78
  commands = []
88
79
  for line in lines:
89
- command = line.strip().split(' ').pop(0)
80
+ command = line.strip().split(" ").pop(0)
90
81
  if command:
91
82
  commands.append(command)
92
83
  return commands
sonusai/aawscd_probwrite.py CHANGED
@@ -12,19 +12,20 @@ aawscd_probwrite connects to an Aaware platform running aawscd and writes the so
12
12
  probability output to an HDF5 file.
13
13
 
14
14
  """
15
+
15
16
  from threading import Condition
16
- from typing import Optional
17
17
 
18
18
  import numpy as np
19
- from tqdm import tqdm
20
19
 
21
- CLIENT: str = 'aawscd_probwrite'
22
- TOPIC: str = 'aawscd/sc/prob'
20
+ from sonusai.utils import track
21
+
22
+ CLIENT: str = "aawscd_probwrite"
23
+ TOPIC: str = "aawscd/sc/prob"
23
24
  DONE: Condition = Condition()
24
25
  FRAMES: int = 10
25
26
  FRAME_COUNT: int = 0
26
- DATA: Optional[np.ndarray] = None
27
- PROGRESS: Optional[tqdm] = None
27
+ DATA: np.ndarray | None = None
28
+ PROGRESS: track | None = None
28
29
 
29
30
 
30
31
  def shutdown(_signum, _frame) -> None:
@@ -75,8 +76,8 @@ def on_message(_client, _userdata, message):
75
76
 
76
77
  global TOPIC
77
78
  if mqtt.topic_matches_sub(TOPIC, message.topic):
78
- payload = yaml.safe_load(str(message.payload.decode('utf-8')))
79
- prob = parse_prob(payload['prob'])
79
+ payload = yaml.safe_load(str(message.payload.decode("utf-8")))
80
+ prob = parse_prob(payload["prob"])
80
81
 
81
82
  global DATA
82
83
  global FRAMES
@@ -99,20 +100,19 @@ def on_message(_client, _userdata, message):
99
100
  def main() -> None:
100
101
  from docopt import docopt
101
102
 
102
- args = docopt(__doc__, version='1.0.0', options_first=True)
103
+ args = docopt(__doc__, version="1.0.0", options_first=True)
103
104
 
104
105
  import signal
105
106
 
106
107
  import h5py
107
108
  import paho.mqtt.client as mqtt
108
- from tqdm import tqdm
109
109
 
110
- machine = args['--machine']
110
+ machine = args["--machine"]
111
111
 
112
112
  global FRAMES
113
- FRAMES = int(args['--frames'])
113
+ FRAMES = int(args["--frames"])
114
114
 
115
- file = args['FILE']
115
+ file = args["FILE"]
116
116
 
117
117
  signal.signal(signal.SIGINT, shutdown)
118
118
  signal.signal(signal.SIGTERM, shutdown)
@@ -126,7 +126,7 @@ def main() -> None:
126
126
  client.subscribe(topic=TOPIC)
127
127
 
128
128
  global PROGRESS
129
- PROGRESS = tqdm(total=FRAMES, desc=file)
129
+ PROGRESS = track(total=FRAMES, desc=file)
130
130
 
131
131
  with DONE:
132
132
  DONE.wait()
@@ -138,11 +138,11 @@ def main() -> None:
138
138
  client.disconnect()
139
139
 
140
140
  global DATA
141
- with h5py.File(file, 'w') as f:
142
- f.create_dataset(name='prob', data=DATA)
141
+ with h5py.File(file, "w") as f:
142
+ f.create_dataset(name="prob", data=DATA)
143
143
 
144
- print(f'Wrote {file}')
144
+ print(f"Wrote {file}")
145
145
 
146
146
 
147
- if __name__ == '__main__':
147
+ if __name__ == "__main__":
148
148
  main()
sonusai/audiofe.py CHANGED
@@ -34,6 +34,7 @@ If the debug option is enabled, write capture audio, feature, reconstruct audio,
34
34
  audiofe_<TIMESTAMP>.h5.
35
35
 
36
36
  """
37
+
37
38
  import signal
38
39
 
39
40
  import numpy as np
@@ -46,7 +47,7 @@ def signal_handler(_sig, _frame):
46
47
 
47
48
  from sonusai import logger
48
49
 
49
- logger.info('Canceled due to keyboard interrupt')
50
+ logger.info("Canceled due to keyboard interrupt")
50
51
  sys.exit(1)
51
52
 
52
53
 
@@ -61,14 +62,14 @@ def main() -> None:
61
62
 
62
63
  args = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)
63
64
 
64
- verbose = args['--verbose']
65
- length = float(args['--length'])
66
- input_name = args['--input']
67
- model_name = args['--model']
68
- asr_name = args['--asr']
69
- whisper_name = args['--whisper']
70
- debug = args['--debug']
71
- show = args['--show']
65
+ verbose = args["--verbose"]
66
+ length = float(args["--length"])
67
+ input_name = args["--input"]
68
+ model_name = args["--model"]
69
+ asr_name = args["--asr"]
70
+ whisper_name = args["--whisper"]
71
+ debug = args["--debug"]
72
+ show = args["--show"]
72
73
 
73
74
  from os.path import exists
74
75
 
@@ -89,26 +90,26 @@ def main() -> None:
89
90
  from sonusai.utils import write_audio
90
91
 
91
92
  ts = create_timestamp()
92
- capture_name = f'audiofe_capture_{ts}'
93
- capture_wav = capture_name + '.wav'
94
- capture_png = capture_name + '.png'
95
- predict_name = f'audiofe_predict_{ts}'
96
- predict_wav = predict_name + '.wav'
97
- predict_png = predict_name + '.png'
98
- h5_name = f'audiofe_{ts}.h5'
93
+ capture_name = f"audiofe_capture_{ts}"
94
+ capture_wav = capture_name + ".wav"
95
+ capture_png = capture_name + ".png"
96
+ predict_name = f"audiofe_predict_{ts}"
97
+ predict_wav = predict_name + ".wav"
98
+ predict_png = predict_name + ".png"
99
+ h5_name = f"audiofe_{ts}.h5"
99
100
 
100
101
  # Setup logging file
101
- create_file_handler('audiofe.log')
102
+ create_file_handler("audiofe.log")
102
103
  update_console_handler(verbose)
103
- initial_log_messages('audiofe')
104
+ initial_log_messages("audiofe")
104
105
 
105
106
  if show:
106
- logger.info('List of available audio inputs:')
107
- logger.info('')
107
+ logger.info("List of available audio inputs:")
108
+ logger.info("")
108
109
  p = pyaudio.PyAudio()
109
110
  for name in get_input_devices(p):
110
- logger.info(f'{name}')
111
- logger.info('')
111
+ logger.info(f"{name}")
112
+ logger.info("")
112
113
  p.terminate()
113
114
  return
114
115
 
@@ -122,27 +123,27 @@ def main() -> None:
122
123
  return
123
124
  # Only write if capture from device, not for file input
124
125
  write_audio(capture_wav, capture_audio, SAMPLE_RATE)
125
- logger.info('')
126
- logger.info(f'Wrote capture audio with shape {capture_audio.shape} to {capture_wav}')
126
+ logger.info("")
127
+ logger.info(f"Wrote capture audio with shape {capture_audio.shape} to {capture_wav}")
127
128
 
128
129
  if debug:
129
- with h5py.File(h5_name, 'a') as f:
130
- if 'capture_audio' in f:
131
- del f['capture_audio']
132
- f.create_dataset('capture_audio', data=capture_audio)
133
- logger.info(f'Wrote capture feature data with shape {capture_audio.shape} to {h5_name}')
130
+ with h5py.File(h5_name, "a") as f:
131
+ if "capture_audio" in f:
132
+ del f["capture_audio"]
133
+ f.create_dataset("capture_audio", data=capture_audio)
134
+ logger.info(f"Wrote capture feature data with shape {capture_audio.shape} to {h5_name}")
134
135
 
135
136
  if asr_name is not None:
136
- logger.info(f'Running ASR on captured audio with {asr_name} ...')
137
+ logger.info(f"Running ASR on captured audio with {asr_name} ...")
137
138
  capture_asr = calc_asr(capture_audio, engine=asr_name, whisper_model_name=whisper_name).text
138
- logger.info(f'Capture audio ASR: {capture_asr}')
139
+ logger.info(f"Capture audio ASR: {capture_asr}")
139
140
 
140
141
  if model_name is not None:
141
142
  session, options, model_root, hparams, sess_inputs, sess_outputs = load_ort_session(model_name)
142
143
  if hparams is None:
143
- logger.error(f'Error: ONNX model does not have required SonusAI hyperparameters, cannot proceed.')
144
+ logger.error("Error: ONNX model does not have required SonusAI hyperparameters, cannot proceed.")
144
145
  raise SystemExit(1)
145
- feature_mode = hparams['feature']
146
+ feature_mode = hparams["feature"]
146
147
  in0name = sess_inputs[0].name
147
148
  in0type = sess_inputs[0].type
148
149
  out_names = [n.name for n in session.get_outputs()]
@@ -150,47 +151,50 @@ def main() -> None:
150
151
  # frames x stride x feat_params
151
152
  feature = get_feature_from_audio(audio=capture_audio, feature_mode=feature_mode)
152
153
  save_figure(capture_png, capture_audio, feature)
153
- logger.info(f'Wrote capture plots to {capture_png}')
154
+ logger.info(f"Wrote capture plots to {capture_png}")
154
155
 
155
156
  if debug:
156
- with h5py.File(h5_name, 'a') as f:
157
- if 'feature' in f:
158
- del f['feature']
159
- f.create_dataset('feature', data=feature)
160
- logger.info(f'Wrote feature with shape {feature.shape} to {h5_name}')
157
+ with h5py.File(h5_name, "a") as f:
158
+ if "feature" in f:
159
+ del f["feature"]
160
+ f.create_dataset("feature", data=feature)
161
+ logger.info(f"Wrote feature with shape {feature.shape} to {h5_name}")
161
162
 
162
- if in0type.find('float16') != -1:
163
- logger.info(f'Detected input of float16, converting all feature inputs to that type.')
164
- feature = np.float16(feature) # type: ignore
163
+ if in0type.find("float16") != -1:
164
+ logger.info("Detected input of float16, converting all feature inputs to that type.")
165
+ feature = np.float16(feature) # type: ignore[assignment]
165
166
 
166
167
  # Run inference, ort session wants batch x timesteps x feat_params, outputs numpy BxTxFP or BxFP
167
168
  # Note full reshape not needed here since we assume speech enhancement type model, so a transpose suffices
168
- predict = np.transpose(session.run(out_names, {in0name: np.transpose(feature, (1, 0, 2))})[0], (1, 0, 2))
169
+ predict = np.transpose(
170
+ session.run(out_names, {in0name: np.transpose(feature, (1, 0, 2))})[0],
171
+ (1, 0, 2),
172
+ )
169
173
 
170
174
  if debug:
171
- with h5py.File(h5_name, 'a') as f:
172
- if 'predict' in f:
173
- del f['predict']
174
- f.create_dataset('predict', data=predict)
175
- logger.info(f'Wrote predict with shape {predict.shape} to {h5_name}')
175
+ with h5py.File(h5_name, "a") as f:
176
+ if "predict" in f:
177
+ del f["predict"]
178
+ f.create_dataset("predict", data=predict)
179
+ logger.info(f"Wrote predict with shape {predict.shape} to {h5_name}")
176
180
 
177
181
  predict_audio = get_audio_from_feature(feature=predict, feature_mode=feature_mode)
178
182
  write_audio(predict_wav, predict_audio, SAMPLE_RATE)
179
- logger.info(f'Wrote predict audio with shape {predict_audio.shape} to {predict_wav}')
183
+ logger.info(f"Wrote predict audio with shape {predict_audio.shape} to {predict_wav}")
180
184
  if debug:
181
- with h5py.File(h5_name, 'a') as f:
182
- if 'predict_audio' in f:
183
- del f['predict_audio']
184
- f.create_dataset('predict_audio', data=predict_audio)
185
- logger.info(f'Wrote predict audio with shape {predict_audio.shape} to {h5_name}')
185
+ with h5py.File(h5_name, "a") as f:
186
+ if "predict_audio" in f:
187
+ del f["predict_audio"]
188
+ f.create_dataset("predict_audio", data=predict_audio)
189
+ logger.info(f"Wrote predict audio with shape {predict_audio.shape} to {h5_name}")
186
190
 
187
191
  save_figure(predict_png, predict_audio, predict)
188
- logger.info(f'Wrote predict plots to {predict_png}')
192
+ logger.info(f"Wrote predict plots to {predict_png}")
189
193
 
190
194
  if asr_name is not None:
191
- logger.info(f'Running ASR on model-enhanced audio with {asr_name} ...')
195
+ logger.info(f"Running ASR on model-enhanced audio with {asr_name} ...")
192
196
  predict_asr = calc_asr(predict_audio, engine=asr_name, whisper_model_name=whisper_name).text
193
- logger.info(f'Predict audio ASR: {predict_asr}')
197
+ logger.info(f"Predict audio ASR: {predict_asr}")
194
198
 
195
199
 
196
200
  def get_frames_from_device(input_name: str | None, length: float, chunk: int = 1024) -> AudioT:
@@ -209,32 +213,34 @@ def get_frames_from_device(input_name: str | None, length: float, chunk: int = 1
209
213
 
210
214
  input_devices = get_input_devices(p)
211
215
  if not input_devices:
212
- raise ValueError('No input audio devices found')
216
+ raise ValueError("No input audio devices found")
213
217
 
214
218
  if input_name is None:
215
219
  input_name = input_devices[0]
216
220
 
217
221
  try:
218
222
  device_index = get_input_device_index_by_name(p, input_name)
219
- except ValueError:
220
- msg = f'Could not find {input_name}\n'
221
- msg += f'Available devices:\n'
223
+ except ValueError as e:
224
+ msg = f"Could not find {input_name}\n"
225
+ msg += "Available devices:\n"
222
226
  for input_device in input_devices:
223
- msg += f' {input_device}\n'
224
- raise ValueError(msg)
225
-
226
- logger.info(f'Capturing from {p.get_device_info_by_index(device_index).get("name")}')
227
- stream = p.open(format=pyaudio.paFloat32,
228
- channels=CHANNEL_COUNT,
229
- rate=SAMPLE_RATE,
230
- input=True,
231
- input_device_index=device_index)
227
+ msg += f" {input_device}\n"
228
+ raise ValueError(msg) from e
229
+
230
+ logger.info(f"Capturing from {p.get_device_info_by_index(device_index).get('name')}")
231
+ stream = p.open(
232
+ format=pyaudio.paFloat32,
233
+ channels=CHANNEL_COUNT,
234
+ rate=SAMPLE_RATE,
235
+ input=True,
236
+ input_device_index=device_index,
237
+ )
232
238
  stream.start_stream()
233
239
 
234
240
  print()
235
- print('+---------------------------------+')
236
- print('| Press Enter to stop |')
237
- print('+---------------------------------+')
241
+ print("+---------------------------------+")
242
+ print("| Press Enter to stop |")
243
+ print("+---------------------------------+")
238
244
  print()
239
245
 
240
246
  elapsed = 0.0
@@ -243,14 +249,21 @@ def get_frames_from_device(input_name: str | None, length: float, chunk: int = 1
243
249
  while elapsed < length or length == -1:
244
250
  raw_frames.append(stream.read(num_frames=chunk, exception_on_overflow=False))
245
251
  elapsed += seconds_per_chunk
246
- if select([stdin, ], [], [], 0)[0]:
252
+ if select(
253
+ [
254
+ stdin,
255
+ ],
256
+ [],
257
+ [],
258
+ 0,
259
+ )[0]:
247
260
  stdin.read(1)
248
261
  length = elapsed
249
262
 
250
263
  stream.stop_stream()
251
264
  stream.close()
252
265
  p.terminate()
253
- frames = np.frombuffer(b''.join(raw_frames), dtype=np.float32)
266
+ frames = np.frombuffer(b"".join(raw_frames), dtype=np.float32)
254
267
  return frames
255
268
 
256
269
 
@@ -259,7 +272,7 @@ def get_frames_from_file(input_name: str, length: float) -> AudioT:
259
272
  from sonusai.mixture import SAMPLE_RATE
260
273
  from sonusai.mixture import read_audio
261
274
 
262
- logger.info(f'Capturing from {input_name}')
275
+ logger.info(f"Capturing from {input_name}")
263
276
  frames = read_audio(input_name)
264
277
  if length != -1:
265
278
  num_frames = int(length * SAMPLE_RATE)
@@ -289,16 +302,16 @@ def save_figure(name: str, audio: np.ndarray, feature: np.ndarray) -> None:
289
302
  fig, (ax1, ax2) = plt.subplots(nrows=2)
290
303
  ax1.set_title(name)
291
304
  ax1.plot(t, audio[:samples])
292
- ax1.set_ylabel('Signal')
305
+ ax1.set_ylabel("Signal")
293
306
  ax1.set_xlim(0, length_in_s)
294
307
  ax1.set_ylim(-1, 1)
295
308
 
296
- ax2.imshow(spectrum, origin='lower', aspect='auto')
309
+ ax2.imshow(spectrum, origin="lower", aspect="auto")
297
310
  ax2.set_xticks([])
298
- ax2.set_ylabel('Feature')
311
+ ax2.set_ylabel("Feature")
299
312
 
300
313
  plt.savefig(name, dpi=300)
301
314
 
302
315
 
303
- if __name__ == '__main__':
316
+ if __name__ == "__main__":
304
317
  main()