sonusai 0.20.3__py3-none-any.whl → 1.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (97)
  1. sonusai/__init__.py +16 -3
  2. sonusai/audiofe.py +241 -77
  3. sonusai/calc_metric_spenh.py +71 -73
  4. sonusai/config/__init__.py +3 -0
  5. sonusai/config/config.py +61 -0
  6. sonusai/config/config.yml +20 -0
  7. sonusai/config/constants.py +8 -0
  8. sonusai/constants.py +11 -0
  9. sonusai/data/genmixdb.yml +21 -36
  10. sonusai/{mixture/datatypes.py → datatypes.py} +91 -130
  11. sonusai/deprecated/plot.py +4 -5
  12. sonusai/doc/doc.py +4 -4
  13. sonusai/doc.py +11 -4
  14. sonusai/genft.py +43 -45
  15. sonusai/genmetrics.py +25 -19
  16. sonusai/genmix.py +54 -82
  17. sonusai/genmixdb.py +88 -264
  18. sonusai/ir_metric.py +30 -34
  19. sonusai/lsdb.py +41 -48
  20. sonusai/main.py +15 -22
  21. sonusai/metrics/calc_audio_stats.py +4 -293
  22. sonusai/metrics/calc_class_weights.py +4 -4
  23. sonusai/metrics/calc_optimal_thresholds.py +8 -5
  24. sonusai/metrics/calc_pesq.py +2 -2
  25. sonusai/metrics/calc_segsnr_f.py +4 -4
  26. sonusai/metrics/calc_speech.py +25 -13
  27. sonusai/metrics/class_summary.py +7 -7
  28. sonusai/metrics/confusion_matrix_summary.py +5 -5
  29. sonusai/metrics/one_hot.py +4 -4
  30. sonusai/metrics/snr_summary.py +7 -7
  31. sonusai/metrics_summary.py +38 -45
  32. sonusai/mixture/__init__.py +4 -104
  33. sonusai/mixture/audio.py +10 -39
  34. sonusai/mixture/class_balancing.py +103 -0
  35. sonusai/mixture/config.py +251 -271
  36. sonusai/mixture/constants.py +35 -39
  37. sonusai/mixture/data_io.py +25 -36
  38. sonusai/mixture/db_datatypes.py +58 -22
  39. sonusai/mixture/effects.py +386 -0
  40. sonusai/mixture/feature.py +7 -11
  41. sonusai/mixture/generation.py +478 -628
  42. sonusai/mixture/helpers.py +82 -184
  43. sonusai/mixture/ir_delay.py +3 -4
  44. sonusai/mixture/ir_effects.py +77 -0
  45. sonusai/mixture/log_duration_and_sizes.py +6 -12
  46. sonusai/mixture/mixdb.py +910 -729
  47. sonusai/mixture/pad_audio.py +35 -0
  48. sonusai/mixture/resample.py +7 -0
  49. sonusai/mixture/sox_effects.py +195 -0
  50. sonusai/mixture/sox_help.py +650 -0
  51. sonusai/mixture/spectral_mask.py +2 -2
  52. sonusai/mixture/truth.py +17 -15
  53. sonusai/mixture/truth_functions/crm.py +12 -12
  54. sonusai/mixture/truth_functions/energy.py +22 -22
  55. sonusai/mixture/truth_functions/file.py +5 -5
  56. sonusai/mixture/truth_functions/metadata.py +4 -4
  57. sonusai/mixture/truth_functions/metrics.py +4 -4
  58. sonusai/mixture/truth_functions/phoneme.py +3 -3
  59. sonusai/mixture/truth_functions/sed.py +11 -13
  60. sonusai/mixture/truth_functions/target.py +10 -10
  61. sonusai/mkwav.py +26 -29
  62. sonusai/onnx_predict.py +240 -88
  63. sonusai/queries/__init__.py +2 -2
  64. sonusai/queries/queries.py +38 -34
  65. sonusai/speech/librispeech.py +1 -1
  66. sonusai/speech/mcgill.py +1 -1
  67. sonusai/speech/timit.py +2 -2
  68. sonusai/summarize_metric_spenh.py +10 -17
  69. sonusai/utils/__init__.py +7 -1
  70. sonusai/utils/asl_p56.py +2 -2
  71. sonusai/utils/asr.py +2 -2
  72. sonusai/utils/asr_functions/aaware_whisper.py +4 -5
  73. sonusai/utils/choice.py +31 -0
  74. sonusai/utils/compress.py +1 -1
  75. sonusai/utils/dataclass_from_dict.py +19 -1
  76. sonusai/utils/energy_f.py +3 -3
  77. sonusai/utils/evaluate_random_rule.py +15 -0
  78. sonusai/utils/keyboard_interrupt.py +12 -0
  79. sonusai/utils/onnx_utils.py +3 -17
  80. sonusai/utils/print_mixture_details.py +21 -19
  81. sonusai/utils/{temp_seed.py → rand.py} +3 -3
  82. sonusai/utils/read_predict_data.py +2 -2
  83. sonusai/utils/reshape.py +3 -3
  84. sonusai/utils/stratified_shuffle_split.py +3 -3
  85. sonusai/{mixture → utils}/tokenized_shell_vars.py +1 -1
  86. sonusai/utils/write_audio.py +2 -2
  87. sonusai/vars.py +11 -4
  88. {sonusai-0.20.3.dist-info → sonusai-1.0.2.dist-info}/METADATA +4 -2
  89. sonusai-1.0.2.dist-info/RECORD +138 -0
  90. sonusai/mixture/augmentation.py +0 -444
  91. sonusai/mixture/class_count.py +0 -15
  92. sonusai/mixture/eq_rule_is_valid.py +0 -45
  93. sonusai/mixture/target_class_balancing.py +0 -107
  94. sonusai/mixture/targets.py +0 -175
  95. sonusai-0.20.3.dist-info/RECORD +0 -128
  96. {sonusai-0.20.3.dist-info → sonusai-1.0.2.dist-info}/WHEEL +0 -0
  97. {sonusai-0.20.3.dist-info → sonusai-1.0.2.dist-info}/entry_points.txt +0 -0
sonusai/lsdb.py CHANGED
@@ -1,12 +1,12 @@
1
1
  """sonusai lsdb
2
2
 
3
- usage: lsdb [-hta] [-i MIXID] [-c CID] LOC
3
+ usage: lsdb [-hsa] [-i MIXID] [-c CID] LOC
4
4
 
5
5
  Options:
6
6
  -h, --help
7
7
  -i MIXID, --mixid MIXID Mixture ID(s) to analyze. [default: *].
8
8
  -c CID, --class_index CID Analyze mixtures that contain this class index.
9
- -t, --targets List all target files.
9
+ -s, --sources List all source files.
10
10
  -a, --all_class_counts List all class counts.
11
11
 
12
12
  List mixture data information from a SonusAI mixture database.
@@ -16,25 +16,10 @@ Inputs:
16
16
 
17
17
  """
18
18
 
19
- import signal
20
-
21
- from sonusai import logger
22
- from sonusai.mixture import GeneralizedIDs
19
+ from sonusai.datatypes import GeneralizedIDs
23
20
  from sonusai.mixture import MixtureDatabase
24
21
 
25
22
 
26
- def signal_handler(_sig, _frame):
27
- import sys
28
-
29
- from sonusai import logger
30
-
31
- logger.info("Canceled due to keyboard interrupt")
32
- sys.exit(1)
33
-
34
-
35
- signal.signal(signal.SIGINT, signal_handler)
36
-
37
-
38
23
  def lsdb(
39
24
  mixdb: MixtureDatabase,
40
25
  mixids: GeneralizedIDs = "*",
@@ -42,7 +27,8 @@ def lsdb(
42
27
  list_targets: bool = False,
43
28
  all_class_counts: bool = False,
44
29
  ) -> None:
45
- from sonusai.mixture import SAMPLE_RATE
30
+ from sonusai import logger
31
+ from sonusai.constants import SAMPLE_RATE
46
32
  from sonusai.queries import get_mixids_from_class_indices
47
33
  from sonusai.utils import consolidate_range
48
34
  from sonusai.utils import max_text_width
@@ -54,38 +40,36 @@ def lsdb(
54
40
  total_samples = mixdb.total_samples()
55
41
  total_duration = total_samples / SAMPLE_RATE
56
42
 
57
- logger.info(f'{"Mixtures":{desc_len}} {mixdb.num_mixtures}')
58
- logger.info(f'{"Duration":{desc_len}} {seconds_to_hms(seconds=total_duration)}')
59
- logger.info(f'{"Targets":{desc_len}} {mixdb.num_target_files}')
60
- logger.info(f'{"Noises":{desc_len}} {mixdb.num_noise_files}')
61
- logger.info(f'{"Feature":{desc_len}} {mixdb.feature}')
43
+ logger.info(f"{'Mixtures':{desc_len}} {mixdb.num_mixtures}")
44
+ logger.info(f"{'Duration':{desc_len}} {seconds_to_hms(seconds=total_duration)}")
45
+ logger.info(f"{'Sources':{desc_len}} {mixdb.num_source_files}")
46
+ logger.info(f"{'Feature':{desc_len}} {mixdb.feature}")
62
47
  logger.info(
63
- f'{"Feature shape":{desc_len}} {mixdb.fg_stride} x {mixdb.feature_parameters} '
64
- f'({mixdb.fg_stride * mixdb.feature_parameters} total params)'
48
+ f"{'Feature shape':{desc_len}} {mixdb.fg_stride} x {mixdb.feature_parameters} "
49
+ f"({mixdb.fg_stride * mixdb.feature_parameters} total params)"
65
50
  )
66
- logger.info(f'{"Feature samples":{desc_len}} {mixdb.feature_samples} samples ({mixdb.feature_ms} ms)')
51
+ logger.info(f"{'Feature samples':{desc_len}} {mixdb.feature_samples} samples ({mixdb.feature_ms} ms)")
67
52
  logger.info(
68
- f'{"Feature step samples":{desc_len}} {mixdb.feature_step_samples} samples ' f'({mixdb.feature_step_ms} ms)'
53
+ f"{'Feature step samples':{desc_len}} {mixdb.feature_step_samples} samples ({mixdb.feature_step_ms} ms)"
69
54
  )
70
- logger.info(f'{"Feature overlap":{desc_len}} {mixdb.fg_step / mixdb.fg_stride} ({mixdb.feature_step_ms} ms)')
71
- logger.info(f'{"SNRs":{desc_len}} {mixdb.snrs}')
72
- logger.info(f'{"Random SNRs":{desc_len}} {mixdb.random_snrs}')
73
- logger.info(f'{"Classes":{desc_len}} {mixdb.num_classes}')
55
+ logger.info(f"{'Feature overlap':{desc_len}} {mixdb.fg_step / mixdb.fg_stride} ({mixdb.feature_step_ms} ms)")
56
+ logger.info(f"{'SNRs':{desc_len}} {mixdb.snrs}")
57
+ logger.info(f"{'Random SNRs':{desc_len}} {mixdb.random_snrs}")
58
+ logger.info(f"{'Classes':{desc_len}} {mixdb.num_classes}")
74
59
  # TODO: fix class count
75
- logger.info(f'{"Class count":{desc_len}} not supported')
60
+ logger.info(f"{'Class count':{desc_len}} not supported")
76
61
  # print_class_count(class_count=class_count, length=desc_len, print_fn=logger.info)
77
62
  # TODO: add class weight calculations here
78
63
  logger.info("")
79
64
 
80
65
  if list_targets:
81
- logger.info("Target details:")
82
- idx_len = max_text_width(mixdb.num_target_files)
83
- for idx, target in enumerate(mixdb.target_files):
84
- desc = f" {idx:{idx_len}} Name"
85
- logger.info(f"{desc:{desc_len}} {target.name}")
86
- desc = f" {idx:{idx_len}} Truth index"
87
- logger.info(f"{desc:{desc_len}} {target.class_indices}")
88
- logger.info("")
66
+ logger.info("Source details:")
67
+ for category, sources in mixdb.source_files.items():
68
+ print(f" {category}:")
69
+ for source in sources:
70
+ logger.info(f"{' Name':{desc_len}} {source.name}")
71
+ logger.info(f"{' Truth index':{desc_len}} {source.class_indices}")
72
+ logger.info("")
89
73
 
90
74
  if class_index is not None:
91
75
  if 0 <= class_index > mixdb.num_classes:
@@ -104,7 +88,7 @@ def lsdb(
104
88
  # print_class_count(class_count=class_count, length=desc_len, print_fn=logger.info, all_class_counts=True)
105
89
  else:
106
90
  logger.info(
107
- f"Calculating statistics from truth_f files for {len(mixids):,} mixtures" f" ({consolidate_range(mixids)})"
91
+ f"Calculating statistics from truth_f files for {len(mixids):,} mixtures ({consolidate_range(mixids)})"
108
92
  )
109
93
  logger.info("Not supported")
110
94
 
@@ -112,13 +96,10 @@ def lsdb(
112
96
  def main() -> None:
113
97
  from docopt import docopt
114
98
 
115
- import sonusai
116
- from sonusai import create_file_handler
117
- from sonusai import initial_log_messages
118
- from sonusai import update_console_handler
99
+ from sonusai import __version__ as sai_version
119
100
  from sonusai.utils import trim_docstring
120
101
 
121
- args = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)
102
+ args = docopt(trim_docstring(__doc__), version=sai_version, options_first=True)
122
103
 
123
104
  mixid = args["--mixid"]
124
105
  class_index = args["--class_index"]
@@ -126,6 +107,11 @@ def main() -> None:
126
107
  all_class_counts = args["--all_class_counts"]
127
108
  location = args["LOC"]
128
109
 
110
+ from sonusai import create_file_handler
111
+ from sonusai import initial_log_messages
112
+ from sonusai import logger
113
+ from sonusai import update_console_handler
114
+
129
115
  if class_index is not None:
130
116
  class_index = int(class_index)
131
117
 
@@ -146,4 +132,11 @@ def main() -> None:
146
132
 
147
133
 
148
134
  if __name__ == "__main__":
149
- main()
135
+ from sonusai import exception_handler
136
+ from sonusai.utils import register_keyboard_interrupt
137
+
138
+ register_keyboard_interrupt()
139
+ try:
140
+ main()
141
+ except Exception as e:
142
+ exception_handler(e)
sonusai/main.py CHANGED
@@ -10,21 +10,6 @@ for more information on a specific command.
10
10
 
11
11
  """
12
12
 
13
- import signal
14
-
15
-
16
- def signal_handler(_sig, _frame):
17
- import sys
18
-
19
- from sonusai import logger
20
-
21
- logger.info("Canceled due to keyboard interrupt")
22
- sys.exit(1)
23
-
24
-
25
- signal.signal(signal.SIGINT, signal_handler)
26
-
27
-
28
13
  def main() -> None:
29
14
  from importlib import import_module
30
15
  from pkgutil import iter_modules
@@ -44,13 +29,13 @@ def main() -> None:
44
29
 
45
30
  from docopt import docopt
46
31
 
47
- from sonusai import __version__
32
+ from sonusai import __version__ as sai_version
48
33
  from sonusai.utils import add_commands_to_docstring
49
34
  from sonusai.utils import trim_docstring
50
35
 
51
36
  args = docopt(
52
37
  trim_docstring(add_commands_to_docstring(__doc__, plugin_docstrings)),
53
- version=__version__,
38
+ version=sai_version,
54
39
  options_first=True,
55
40
  )
56
41
 
@@ -61,15 +46,16 @@ def main() -> None:
61
46
  from os.path import join
62
47
  from subprocess import call
63
48
 
64
- import sonusai
49
+ from sonusai import BASEDIR
50
+ from sonusai import commands_list
65
51
  from sonusai import logger
66
52
 
67
- base_commands = sonusai.commands_list()
53
+ base_commands = commands_list()
68
54
  if command == "help":
69
55
  if not argv:
70
56
  exit(call(["sonusai", "-h"])) # noqa: S603, S607
71
57
  elif argv[0] in base_commands:
72
- exit(call(["python", f"{join(sonusai.BASEDIR, argv[0])}.py", "-h"])) # noqa: S603, S607
58
+ exit(call(["python", f"{join(BASEDIR, argv[0])}.py", "-h"])) # noqa: S603, S607
73
59
 
74
60
  for data in plugins.values():
75
61
  if argv[0] in data["commands"]:
@@ -79,7 +65,7 @@ def main() -> None:
79
65
  sys.exit(1)
80
66
 
81
67
  if command in base_commands:
82
- exit(call(["python", f"{join(sonusai.BASEDIR, command)}.py", *argv])) # noqa: S603, S607
68
+ exit(call(["python", f"{join(BASEDIR, command)}.py", *argv])) # noqa: S603, S607
83
69
 
84
70
  for data in plugins.values():
85
71
  if command in data["commands"]:
@@ -90,4 +76,11 @@ def main() -> None:
90
76
 
91
77
 
92
78
  if __name__ == "__main__":
93
- main()
79
+ from sonusai import exception_handler
80
+ from sonusai.utils import register_keyboard_interrupt
81
+
82
+ register_keyboard_interrupt()
83
+ try:
84
+ main()
85
+ except Exception as e:
86
+ exception_handler(e)
@@ -1,10 +1,5 @@
1
- from pathlib import Path
2
-
3
- import numpy as np
4
- from sox import Transformer as SoxTransformer
5
-
6
- from sonusai.mixture.datatypes import AudioStatsMetrics
7
- from sonusai.mixture.datatypes import AudioT
1
+ from ..datatypes import AudioStatsMetrics
2
+ from ..datatypes import AudioT
8
3
 
9
4
 
10
5
  def _convert_str_with_factors_to_int(x: str) -> int:
@@ -16,21 +11,9 @@ def _convert_str_with_factors_to_int(x: str) -> int:
16
11
 
17
12
 
18
13
  def calc_audio_stats(audio: AudioT, win_len: float | None = None) -> AudioStatsMetrics:
19
- from sonusai.mixture import SAMPLE_RATE
20
-
21
- args = ["stats"]
22
- if win_len is not None:
23
- args.extend(["-w", str(win_len)])
24
-
25
- tfm = Transformer()
14
+ from ..mixture.sox_effects import sox_stats
26
15
 
27
- _, _, out = tfm.build(
28
- input_array=audio,
29
- sample_rate_in=SAMPLE_RATE,
30
- output_filepath="-n",
31
- extra_args=args,
32
- return_output=True,
33
- )
16
+ out = sox_stats(audio, win_len)
34
17
 
35
18
  if out is None:
36
19
  raise SystemError("Call to sox failed")
@@ -57,275 +40,3 @@ def calc_audio_stats(audio: AudioT, win_len: float | None = None) -> AudioStatsM
57
40
  fl=float(stats["Flat factor"]),
58
41
  pkc=_convert_str_with_factors_to_int(stats["Pk count"]),
59
42
  )
60
-
61
-
62
- class Transformer(SoxTransformer):
63
- """Override certain sox.Transformer methods"""
64
-
65
- def build( # pyright: ignore [reportIncompatibleMethodOverride]
66
- self,
67
- input_filepath: str | Path | None = None,
68
- output_filepath: str | Path | None = None,
69
- input_array: np.ndarray | None = None,
70
- sample_rate_in: float | None = None,
71
- extra_args: list[str] | None = None,
72
- return_output: bool = False,
73
- ) -> tuple[bool, str | None, str | None]:
74
- """Given an input file or array, creates an output_file on disk by
75
- executing the current set of commands. This function returns True on
76
- success. If return_output is True, this function returns a triple of
77
- (status, out, err), giving the success state, along with stdout and
78
- stderr returned by sox.
79
-
80
- Parameters
81
- ----------
82
- input_filepath : str or None
83
- Either path to input audio file or None for array input.
84
- output_filepath : str
85
- Path to desired output file. If a file already exists at
86
- the given path, the file will be overwritten.
87
- If '-n', no file is created.
88
- input_array : np.ndarray or None
89
- An np.ndarray of an waveform with shape (n_samples, n_channels).
90
- sample_rate_in must also be provided.
91
- If None, input_filepath must be specified.
92
- sample_rate_in : int
93
- Sample rate of input_array.
94
- This argument is ignored if input_array is None.
95
- extra_args : list or None, default=None
96
- If a list is given, these additional arguments are passed to SoX
97
- at the end of the list of effects.
98
- Don't use this argument unless you know exactly what you're doing!
99
- return_output : bool, default=False
100
- If True, returns the status and information sent to stderr and
101
- stdout as a tuple (status, stdout, stderr).
102
- If output_filepath is None, return_output=True by default.
103
- If False, returns True on success.
104
-
105
- Returns
106
- -------
107
- status : bool
108
- True on success.
109
- out : str (optional)
110
- This is not returned unless return_output is True.
111
- When returned, captures the stdout produced by sox.
112
- err : str (optional)
113
- This is not returned unless return_output is True.
114
- When returned, captures the stderr produced by sox.
115
-
116
- Examples
117
- --------
118
- > import numpy as np
119
- > import sox
120
- > tfm = sox.Transformer()
121
- > sample_rate = 44100
122
- > y = np.sin(2 * np.pi * 440.0 * np.arange(sample_rate * 1.0) / sample_rate)
123
-
124
- file in, file out - basic usage
125
-
126
- > status = tfm.build('path/to/input.wav', 'path/to/output.mp3')
127
-
128
- file in, file out - equivalent usage
129
-
130
- > status = tfm.build(
131
- input_filepath='path/to/input.wav',
132
- output_filepath='path/to/output.mp3'
133
- )
134
-
135
- array in, file out
136
-
137
- > status = tfm.build(
138
- input_array=y, sample_rate_in=sample_rate,
139
- output_filepath='path/to/output.mp3'
140
- )
141
-
142
- """
143
- from sox import file_info
144
- from sox.core import SoxError
145
- from sox.core import sox
146
- from sox.log import logger
147
-
148
- input_format, input_filepath = self._parse_inputs(input_filepath, input_array, sample_rate_in)
149
-
150
- if output_filepath is None:
151
- raise ValueError("output_filepath is not specified!")
152
-
153
- # set output parameters
154
- if input_filepath == output_filepath:
155
- raise ValueError("input_filepath must be different from output_filepath.")
156
- file_info.validate_output_file(output_filepath)
157
-
158
- args = []
159
- args.extend(self.globals)
160
- args.extend(self._input_format_args(input_format))
161
- args.append(input_filepath)
162
- args.extend(self._output_format_args(self.output_format))
163
- args.append(output_filepath)
164
- args.extend(self.effects)
165
-
166
- if extra_args is not None:
167
- if not isinstance(extra_args, list):
168
- raise ValueError("extra_args must be a list.")
169
- args.extend(extra_args)
170
-
171
- status, out, err = sox(args, input_array, True)
172
- if status != 0:
173
- raise SoxError(f"Stdout: {out}\nStderr: {err}")
174
-
175
- logger.info("Created %s with effects: %s", output_filepath, " ".join(self.effects_log))
176
-
177
- if return_output:
178
- return status, out, err # pyright: ignore [reportReturnType]
179
-
180
- return True, None, None
181
-
182
- def build_array( # pyright: ignore [reportIncompatibleMethodOverride]
183
- self,
184
- input_filepath: str | Path | None = None,
185
- input_array: np.ndarray | None = None,
186
- sample_rate_in: int | None = None,
187
- extra_args: list[str] | None = None,
188
- ) -> np.ndarray:
189
- """Given an input file or array, returns the output as a numpy array
190
- by executing the current set of commands. By default, the array will
191
- have the same sample rate as the input file unless otherwise specified
192
- using set_output_format. Functions such as channels and convert
193
- will be ignored!
194
-
195
- The SonusAI override does not generate a warning for rate transforms.
196
-
197
- Parameters
198
- ----------
199
- input_filepath : str, Path or None
200
- Either path to input audio file or None.
201
- input_array : np.ndarray or None
202
- A np.ndarray of a waveform with shape (n_samples, n_channels).
203
- If this argument is passed, sample_rate_in must also be provided.
204
- If None, input_filepath must be specified.
205
- sample_rate_in : int
206
- Sample rate of input_array.
207
- This argument is ignored if input_array is None.
208
- extra_args : list or None, default=None
209
- If a list is given, these additional arguments are passed to SoX
210
- at the end of the list of effects.
211
- Don't use this argument unless you know exactly what you're doing!
212
-
213
- Returns
214
- -------
215
- output_array : np.ndarray
216
- Output audio as a numpy array
217
-
218
- Examples
219
- --------
220
-
221
- > import numpy as np
222
- > import sox
223
- > tfm = sox.Transformer()
224
- > sample_rate = 44100
225
- > y = np.sin(2 * np.pi * 440.0 * np.arange(sample_rate * 1.0) / sample_rate)
226
-
227
- file in, array out
228
-
229
- > output_array = tfm.build(input_filepath='path/to/input.wav')
230
-
231
- array in, array out
232
-
233
- > output_array = tfm.build(input_array=y, sample_rate_in=sample_rate)
234
-
235
- specifying the output sample rate
236
-
237
- > tfm.set_output_format(rate=8000)
238
- > output_array = tfm.build(input_array=y, sample_rate_in=sample_rate)
239
-
240
- if an effect changes the number of channels, you must explicitly
241
- specify the number of output channels
242
-
243
- > tfm.remix(remix_dictionary={1: [1], 2: [1], 3: [1]})
244
- > tfm.set_output_format(channels=3)
245
- > output_array = tfm.build(input_array=y, sample_rate_in=sample_rate)
246
-
247
-
248
- """
249
- from sox.core import SoxError
250
- from sox.core import sox
251
- from sox.log import logger
252
- from sox.transform import ENCODINGS_MAPPING
253
-
254
- input_format, input_filepath = self._parse_inputs(input_filepath, input_array, sample_rate_in)
255
-
256
- # check if any of the below commands are part of the effects chain
257
- ignored_commands = ["channels", "convert"]
258
- if set(ignored_commands) & set(self.effects_log):
259
- logger.warning(
260
- "When outputting to an array, channels and convert "
261
- + "effects may be ignored. Use set_output_format() to "
262
- + "specify output formats."
263
- )
264
-
265
- output_filepath = "-"
266
-
267
- if input_format.get("file_type") is None:
268
- encoding_out = np.int16
269
- else:
270
- encoding_out = next(k for k, v in ENCODINGS_MAPPING.items() if input_format["file_type"] == v)
271
-
272
- n_bits = np.dtype(encoding_out).itemsize * 8
273
-
274
- output_format = {
275
- "file_type": "raw",
276
- "rate": sample_rate_in,
277
- "bits": n_bits,
278
- "channels": input_format["channels"],
279
- "encoding": None,
280
- "comments": None,
281
- "append_comments": True,
282
- }
283
-
284
- if self.output_format.get("rate") is not None:
285
- output_format["rate"] = self.output_format["rate"]
286
-
287
- if self.output_format.get("channels") is not None:
288
- output_format["channels"] = self.output_format["channels"]
289
-
290
- if self.output_format.get("bits") is not None:
291
- n_bits = self.output_format["bits"]
292
- output_format["bits"] = n_bits
293
-
294
- match n_bits:
295
- case 8:
296
- encoding_out = np.int8 # type: ignore[assignment]
297
- case 16:
298
- encoding_out = np.int16
299
- case 32:
300
- encoding_out = np.float32 # type: ignore[assignment]
301
- case 64:
302
- encoding_out = np.float64 # type: ignore[assignment]
303
- case _:
304
- raise ValueError(f"invalid n_bits {n_bits}")
305
-
306
- args = []
307
- args.extend(self.globals)
308
- args.extend(self._input_format_args(input_format))
309
- args.append(input_filepath)
310
- args.extend(self._output_format_args(output_format))
311
- args.append(output_filepath)
312
- args.extend(self.effects)
313
-
314
- if extra_args is not None:
315
- if not isinstance(extra_args, list):
316
- raise ValueError("extra_args must be a list.")
317
- args.extend(extra_args)
318
-
319
- status, out, err = sox(args, input_array, False)
320
- if status != 0:
321
- raise SoxError(f"Stdout: {out}\nStderr: {err}")
322
-
323
- out = np.frombuffer(out, dtype=encoding_out) # pyright: ignore [reportArgumentType, reportCallIssue]
324
- if output_format["channels"] > 1:
325
- out = out.reshape(
326
- (output_format["channels"], int(len(out) / output_format["channels"])),
327
- order="F",
328
- ).T
329
- logger.info("Created array with effects: %s", " ".join(self.effects_log))
330
-
331
- return out
@@ -1,8 +1,8 @@
1
1
  import numpy as np
2
2
 
3
- from sonusai.mixture.datatypes import GeneralizedIDs
4
- from sonusai.mixture.datatypes import Truth
5
- from sonusai.mixture.mixdb import MixtureDatabase
3
+ from ..datatypes import GeneralizedIDs
4
+ from ..datatypes import Truth
5
+ from ..mixture.mixdb import MixtureDatabase
6
6
 
7
7
 
8
8
  def calc_class_weights_from_truth(truth: Truth, other_weight: float | None = None, other_index: int = -1) -> np.ndarray:
@@ -74,7 +74,7 @@ def calc_class_weights_from_mixdb(
74
74
  weights: Class weights. [num_classes, 1]
75
75
  Note: for Keras use dict(enumerate(weights))
76
76
  """
77
- from sonusai.mixture import get_class_count_from_mixids
77
+ from ..mixture.class_count import get_class_count_from_mixids
78
78
 
79
79
  count = np.ceil(np.array(get_class_count_from_mixids(mixdb=mixdb, mixids=mixids)) / mixdb.feature_step_samples)
80
80
  total_features = sum(count)
@@ -1,11 +1,14 @@
1
1
  import numpy as np
2
2
 
3
- from sonusai.mixture.datatypes import Predict
4
- from sonusai.mixture.datatypes import Truth
3
+ from ..datatypes import Predict
4
+ from ..datatypes import Truth
5
5
 
6
6
 
7
7
  def calc_optimal_thresholds(
8
- truth: Truth, predict: Predict, timesteps: int = 0, truth_thr: float = 0.5
8
+ truth: Truth,
9
+ predict: Predict,
10
+ timesteps: int = 0,
11
+ truth_thr: float = 0.5,
9
12
  ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
10
13
  """Calculates optimal thresholds for each class from one-hot prediction and truth data where both are
11
14
  one-hot probabilities (or quantized decisions) with size [frames, num_classes] or [frames, timesteps, num_classes].
@@ -24,8 +27,8 @@ def calc_optimal_thresholds(
24
27
  from sklearn.metrics import roc_auc_score
25
28
  from sklearn.metrics import roc_curve
26
29
 
27
- from sonusai.utils import get_num_classes_from_predict
28
- from sonusai.utils import reshape_outputs
30
+ from ..utils.reshape import get_num_classes_from_predict
31
+ from ..utils.reshape import reshape_outputs
29
32
 
30
33
  if truth.shape != predict.shape:
31
34
  raise ValueError("truth and predict are not the same shape")
@@ -1,6 +1,6 @@
1
1
  import numpy as np
2
2
 
3
- from sonusai.mixture.constants import SAMPLE_RATE
3
+ from ..constants import SAMPLE_RATE
4
4
 
5
5
 
6
6
  def calc_pesq(
@@ -23,7 +23,7 @@ def calc_pesq(
23
23
 
24
24
  from pesq import pesq
25
25
 
26
- from sonusai import logger
26
+ from .. import logger
27
27
 
28
28
  try:
29
29
  with warnings.catch_warnings():
@@ -1,9 +1,9 @@
1
1
  import numpy as np
2
2
 
3
- from sonusai.mixture.datatypes import AudioF
4
- from sonusai.mixture.datatypes import Segsnr
5
- from sonusai.mixture.datatypes import SnrFBinMetrics
6
- from sonusai.mixture.datatypes import SnrFMetrics
3
+ from ..datatypes import AudioF
4
+ from ..datatypes import Segsnr
5
+ from ..datatypes import SnrFBinMetrics
6
+ from ..datatypes import SnrFMetrics
7
7
 
8
8
 
9
9
  def calc_segsnr_f(segsnr_f: Segsnr) -> SnrFMetrics: