sonusai 1.0.16__cp311-abi3-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150) hide show
  1. sonusai/__init__.py +170 -0
  2. sonusai/aawscd_probwrite.py +148 -0
  3. sonusai/audiofe.py +481 -0
  4. sonusai/calc_metric_spenh.py +1136 -0
  5. sonusai/config/__init__.py +0 -0
  6. sonusai/config/asr.py +21 -0
  7. sonusai/config/config.py +65 -0
  8. sonusai/config/config.yml +49 -0
  9. sonusai/config/constants.py +53 -0
  10. sonusai/config/ir.py +124 -0
  11. sonusai/config/ir_delay.py +62 -0
  12. sonusai/config/source.py +275 -0
  13. sonusai/config/spectral_masks.py +15 -0
  14. sonusai/config/truth.py +64 -0
  15. sonusai/constants.py +14 -0
  16. sonusai/data/__init__.py +0 -0
  17. sonusai/data/silero_vad_v5.1.jit +0 -0
  18. sonusai/data/silero_vad_v5.1.onnx +0 -0
  19. sonusai/data/speech_ma01_01.wav +0 -0
  20. sonusai/data/whitenoise.wav +0 -0
  21. sonusai/datatypes.py +383 -0
  22. sonusai/deprecated/gentcst.py +632 -0
  23. sonusai/deprecated/plot.py +519 -0
  24. sonusai/deprecated/tplot.py +365 -0
  25. sonusai/doc.py +52 -0
  26. sonusai/doc_strings/__init__.py +1 -0
  27. sonusai/doc_strings/doc_strings.py +531 -0
  28. sonusai/genft.py +196 -0
  29. sonusai/genmetrics.py +183 -0
  30. sonusai/genmix.py +199 -0
  31. sonusai/genmixdb.py +235 -0
  32. sonusai/ir_metric.py +551 -0
  33. sonusai/lsdb.py +141 -0
  34. sonusai/main.py +134 -0
  35. sonusai/metrics/__init__.py +43 -0
  36. sonusai/metrics/calc_audio_stats.py +42 -0
  37. sonusai/metrics/calc_class_weights.py +90 -0
  38. sonusai/metrics/calc_optimal_thresholds.py +73 -0
  39. sonusai/metrics/calc_pcm.py +45 -0
  40. sonusai/metrics/calc_pesq.py +36 -0
  41. sonusai/metrics/calc_phase_distance.py +43 -0
  42. sonusai/metrics/calc_sa_sdr.py +64 -0
  43. sonusai/metrics/calc_sample_weights.py +25 -0
  44. sonusai/metrics/calc_segsnr_f.py +82 -0
  45. sonusai/metrics/calc_speech.py +382 -0
  46. sonusai/metrics/calc_wer.py +71 -0
  47. sonusai/metrics/calc_wsdr.py +57 -0
  48. sonusai/metrics/calculate_metrics.py +395 -0
  49. sonusai/metrics/class_summary.py +74 -0
  50. sonusai/metrics/confusion_matrix_summary.py +75 -0
  51. sonusai/metrics/one_hot.py +283 -0
  52. sonusai/metrics/snr_summary.py +128 -0
  53. sonusai/metrics_summary.py +314 -0
  54. sonusai/mixture/__init__.py +15 -0
  55. sonusai/mixture/audio.py +187 -0
  56. sonusai/mixture/class_balancing.py +103 -0
  57. sonusai/mixture/constants.py +3 -0
  58. sonusai/mixture/data_io.py +173 -0
  59. sonusai/mixture/db.py +169 -0
  60. sonusai/mixture/db_datatypes.py +92 -0
  61. sonusai/mixture/effects.py +344 -0
  62. sonusai/mixture/feature.py +78 -0
  63. sonusai/mixture/generation.py +1116 -0
  64. sonusai/mixture/helpers.py +351 -0
  65. sonusai/mixture/ir_effects.py +77 -0
  66. sonusai/mixture/log_duration_and_sizes.py +23 -0
  67. sonusai/mixture/mixdb.py +1857 -0
  68. sonusai/mixture/pad_audio.py +35 -0
  69. sonusai/mixture/resample.py +7 -0
  70. sonusai/mixture/sox_effects.py +195 -0
  71. sonusai/mixture/sox_help.py +650 -0
  72. sonusai/mixture/spectral_mask.py +51 -0
  73. sonusai/mixture/truth.py +61 -0
  74. sonusai/mixture/truth_functions/__init__.py +45 -0
  75. sonusai/mixture/truth_functions/crm.py +105 -0
  76. sonusai/mixture/truth_functions/energy.py +222 -0
  77. sonusai/mixture/truth_functions/file.py +48 -0
  78. sonusai/mixture/truth_functions/metadata.py +24 -0
  79. sonusai/mixture/truth_functions/metrics.py +28 -0
  80. sonusai/mixture/truth_functions/phoneme.py +18 -0
  81. sonusai/mixture/truth_functions/sed.py +98 -0
  82. sonusai/mixture/truth_functions/target.py +142 -0
  83. sonusai/mkwav.py +135 -0
  84. sonusai/onnx_predict.py +363 -0
  85. sonusai/parse/__init__.py +0 -0
  86. sonusai/parse/expand.py +156 -0
  87. sonusai/parse/parse_source_directive.py +129 -0
  88. sonusai/parse/rand.py +214 -0
  89. sonusai/py.typed +0 -0
  90. sonusai/queries/__init__.py +0 -0
  91. sonusai/queries/queries.py +239 -0
  92. sonusai/rs.abi3.so +0 -0
  93. sonusai/rs.pyi +1 -0
  94. sonusai/rust/__init__.py +0 -0
  95. sonusai/speech/__init__.py +0 -0
  96. sonusai/speech/l2arctic.py +121 -0
  97. sonusai/speech/librispeech.py +102 -0
  98. sonusai/speech/mcgill.py +71 -0
  99. sonusai/speech/textgrid.py +89 -0
  100. sonusai/speech/timit.py +138 -0
  101. sonusai/speech/types.py +12 -0
  102. sonusai/speech/vctk.py +53 -0
  103. sonusai/speech/voxceleb.py +108 -0
  104. sonusai/utils/__init__.py +3 -0
  105. sonusai/utils/asl_p56.py +130 -0
  106. sonusai/utils/asr.py +91 -0
  107. sonusai/utils/asr_functions/__init__.py +3 -0
  108. sonusai/utils/asr_functions/aaware_whisper.py +69 -0
  109. sonusai/utils/audio_devices.py +50 -0
  110. sonusai/utils/braced_glob.py +50 -0
  111. sonusai/utils/calculate_input_shape.py +26 -0
  112. sonusai/utils/choice.py +51 -0
  113. sonusai/utils/compress.py +25 -0
  114. sonusai/utils/convert_string_to_number.py +6 -0
  115. sonusai/utils/create_timestamp.py +5 -0
  116. sonusai/utils/create_ts_name.py +14 -0
  117. sonusai/utils/dataclass_from_dict.py +27 -0
  118. sonusai/utils/db.py +16 -0
  119. sonusai/utils/docstring.py +53 -0
  120. sonusai/utils/energy_f.py +44 -0
  121. sonusai/utils/engineering_number.py +166 -0
  122. sonusai/utils/evaluate_random_rule.py +15 -0
  123. sonusai/utils/get_frames_per_batch.py +2 -0
  124. sonusai/utils/get_label_names.py +20 -0
  125. sonusai/utils/grouper.py +6 -0
  126. sonusai/utils/human_readable_size.py +7 -0
  127. sonusai/utils/keyboard_interrupt.py +12 -0
  128. sonusai/utils/load_object.py +21 -0
  129. sonusai/utils/max_text_width.py +9 -0
  130. sonusai/utils/model_utils.py +28 -0
  131. sonusai/utils/numeric_conversion.py +11 -0
  132. sonusai/utils/onnx_utils.py +155 -0
  133. sonusai/utils/parallel.py +162 -0
  134. sonusai/utils/path_info.py +7 -0
  135. sonusai/utils/print_mixture_details.py +60 -0
  136. sonusai/utils/rand.py +13 -0
  137. sonusai/utils/ranges.py +43 -0
  138. sonusai/utils/read_predict_data.py +32 -0
  139. sonusai/utils/reshape.py +154 -0
  140. sonusai/utils/seconds_to_hms.py +7 -0
  141. sonusai/utils/stacked_complex.py +82 -0
  142. sonusai/utils/stratified_shuffle_split.py +170 -0
  143. sonusai/utils/tokenized_shell_vars.py +143 -0
  144. sonusai/utils/write_audio.py +26 -0
  145. sonusai/utils/yes_or_no.py +8 -0
  146. sonusai/vars.py +47 -0
  147. sonusai-1.0.16.dist-info/METADATA +56 -0
  148. sonusai-1.0.16.dist-info/RECORD +150 -0
  149. sonusai-1.0.16.dist-info/WHEEL +4 -0
  150. sonusai-1.0.16.dist-info/entry_points.txt +3 -0
@@ -0,0 +1,69 @@
1
+ from ...datatypes import AudioT
2
+ from ..asr import ASRResult
3
+
4
+
5
def aaware_whisper_validate(**_config) -> None:
    """Validate configuration for the Aaware Whisper ASR engine.

    This engine accepts any configuration, so nothing is checked.

    :param _config: Engine configuration keyword arguments (ignored)
    """
    return None
7
+
8
+
9
def aaware_whisper(audio: AudioT, **_config) -> ASRResult:
    """Transcribe audio with the Aaware Whisper ASR web service.

    The service base URL is read from the ``AAWARE_WHISPER_URL`` environment variable.
    The audio is converted with ``float_to_int16``, written to a WAV file in a temporary
    directory, and uploaded to ``<url>/asr`` (English transcription, JSON output).

    :param audio: Audio to transcribe
    :param _config: Engine configuration keyword arguments (unused)
    :return: ASRResult whose ``text`` is the returned transcript and whose ``confidence`` is
        ``exp(avg_logprob)`` of the first returned segment
    :raises OSError: If ``AAWARE_WHISPER_URL`` is not set
    :raises RuntimeError: If the request fails, the service returns a non-200 status, or the
        response does not have the expected structure
    """
    import tempfile
    from math import exp
    from os import getenv
    from os.path import join

    import requests

    from ..numeric_conversion import float_to_int16
    from ..write_audio import write_audio

    url = getenv("AAWARE_WHISPER_URL")
    if url is None:
        raise OSError("AAWARE_WHISPER_URL environment variable does not exist")
    url += "/asr?task=transcribe&language=en&encode=true&output=json"

    with tempfile.TemporaryDirectory() as tmp:
        file = join(tmp, "asr.wav")
        write_audio(name=file, audio=float_to_int16(audio))

        # Open the WAV with a context manager so the handle is closed before the temporary
        # directory is removed; an unclosed handle leaks a file descriptor.
        with open(file, "rb") as audio_file:
            files = {"audio_file": (file, audio_file, "audio/wav")}

            try:
                response = requests.post(url, files=files)  # noqa: S113
                if response.status_code != 200:
                    if response.status_code == 422:
                        raise RuntimeError(f"Validation error: {response.json()}")  # noqa: TRY301
                    raise RuntimeError(f"Invalid response: {response.status_code}")  # noqa: TRY301
                result = response.json()
                return ASRResult(
                    text=result["text"],
                    # Whisper reports a mean log-probability; exp() maps it back to (0, 1]
                    confidence=exp(float(result["segments"][0]["avg_logprob"])),
                )
            except Exception as e:
                raise RuntimeError(f"Aaware Whisper exception: {e.args}") from e
44
+
45
+
46
+ """
47
+ Aaware Whisper Asr Webservice results:
48
+ {
49
+ "text": " The birch canoes slid on the smooth planks.",
50
+ "segments": [
51
+ {
52
+ "id": 0,
53
+ "seek": 0,
54
+ "start": 0.0,
55
+ "end": 2.32,
56
+ "text": " The birch canoes slid on the smooth planks.",
57
+ "tokens": [
58
+ 50364, 440, 1904, 339, 393, 78, 279, 1061, 327, 322, 264, 5508, 499,
59
+ 14592, 13, 50480
60
+ ],
61
+ "temperature": 0.0,
62
+ "avg_logprob": -0.385713913861443,
63
+ "compression_ratio": 0.86,
64
+ "no_speech_prob": 0.006166956853121519
65
+ }
66
+ ],
67
+ "language": "en"
68
+ }
69
+ """
@@ -0,0 +1,50 @@
1
+ import pyaudio
2
+
3
+
4
def get_input_device_index_by_name(p: pyaudio.PyAudio, name: str | None = None) -> int:
    """Find the index of an input-capable audio device on host API 0.

    :param p: PyAudio instance to query
    :param name: Exact device name to match; when None, the first device that reports at
        least one input channel is selected
    :return: Device index within host API 0
    :raises ValueError: If no matching input device is found
    """
    count = p.get_host_api_info_by_index(0).get("deviceCount")
    if not isinstance(count, int):
        raise ValueError(f"Could not find {name}")

    for index in range(count):
        info = p.get_device_info_by_host_api_device_index(0, index)
        candidate = None if name is None else info.get("name")
        # Skip devices whose reported name is neither a string nor missing
        if candidate is not None and not isinstance(candidate, str):
            continue
        channels = info.get("maxInputChannels")
        if candidate == name and isinstance(channels, int) and channels > 0:
            return index

    raise ValueError(f"Could not find {name}")
20
+
21
+
22
def get_input_devices(p: pyaudio.PyAudio) -> list[str]:
    """List the names of all input-capable audio devices on host API 0.

    :param p: PyAudio instance to query
    :return: Names of devices reporting at least one input channel, in device-index order
    """
    count = p.get_host_api_info_by_index(0).get("deviceCount")
    if not isinstance(count, int):
        return []

    input_names: list[str] = []
    for index in range(count):
        info = p.get_device_info_by_host_api_device_index(0, index)
        device_name = info.get("name")
        if not isinstance(device_name, str):
            continue
        channels = info.get("maxInputChannels")
        if isinstance(channels, int) and channels > 0:
            input_names.append(device_name)
    return input_names
36
+
37
+
38
def get_default_input_device(p: pyaudio.PyAudio) -> str:
    """Return the name of the first input-capable audio device on host API 0.

    :param p: PyAudio instance to query
    :return: Name of the lowest-indexed device reporting at least one input channel
    :raises ValueError: If no such device exists
    """
    count = p.get_host_api_info_by_index(0).get("deviceCount")
    indices = range(count) if isinstance(count, int) else range(0)

    for index in indices:
        info = p.get_device_info_by_host_api_device_index(0, index)
        device_name = info.get("name")
        if not isinstance(device_name, str):
            continue
        channels = info.get("maxInputChannels")
        if isinstance(channels, int) and channels > 0:
            return device_name

    raise ValueError("No input audio devices found")
@@ -0,0 +1,50 @@
1
+ from collections.abc import Generator
2
+ from typing import LiteralString
3
+
4
+
5
+ def expand_braces(text: LiteralString | str | bytes, seen: set[str] | None = None) -> Generator[str, None, None]:
6
+ """Brace-expansion pre-processing for glob.
7
+
8
+ Expand all the braces, then run glob on each of the results.
9
+ (Brace-expansion turns one string into a list of strings.)
10
+ https://stackoverflow.com/questions/22996645/brace-expansion-in-python-glob
11
+ """
12
+ import itertools
13
+ import re
14
+
15
+ if seen is None:
16
+ seen = set()
17
+
18
+ if not isinstance(text, str):
19
+ text = str(text)
20
+
21
+ spans = [m.span() for m in re.finditer(r"\{[^{}]*}", text)][::-1]
22
+ alts = [text[start + 1 : stop - 1].split(",") for start, stop in spans]
23
+
24
+ if len(spans) == 0:
25
+ if text not in seen:
26
+ yield text
27
+ seen.add(text)
28
+ else:
29
+ for combo in itertools.product(*alts):
30
+ replaced = list(text)
31
+ for (start, stop), replacement in zip(spans, combo, strict=False):
32
+ replaced[start:stop] = replacement
33
+ yield from expand_braces("".join(replaced), seen)
34
+
35
+
36
def braced_glob(pathname: LiteralString | str | bytes, recursive: bool = False) -> list[str]:
    """Glob with shell-style brace expansion.

    Every brace expansion of ``pathname`` is globbed in turn, and the matches are
    concatenated in expansion order.

    :param pathname: Glob pattern that may contain ``{a,b}`` groups
    :param recursive: Passed through to ``glob.glob`` (enables ``**``)
    :return: All matching paths
    """
    from glob import glob

    return [match for pattern in expand_braces(pathname) for match in glob(pattern, recursive=recursive)]
44
+
45
+
46
def braced_iglob(pathname: LiteralString | str | bytes, recursive: bool = False) -> Generator[str, None, None]:
    """Lazily glob with shell-style brace expansion.

    :param pathname: Glob pattern that may contain ``{a,b}`` groups
    :param recursive: Passed through to ``glob.iglob`` (enables ``**``)
    :return: Generator over matching paths, in expansion order
    """
    from glob import iglob

    for pattern in expand_braces(pathname):
        for match in iglob(pattern, recursive=recursive):
            yield match
@@ -0,0 +1,26 @@
1
def calculate_input_shape(feature: str, flatten: bool = False, timesteps: int = 0, add1ch: bool = False) -> list[int]:
    """
    Calculate input shape given feature and user-specified reshape parameters.

    Inputs:
        feature:   String defining the Aaware feature used in SonusAI, typically mixdb.feature.
        flatten:   If true, flatten the 2D spectrogram from SxB to S*B.
        timesteps: Prepend a timesteps dimension of this size when greater than zero.
        add1ch:    Append a channel dimension of size 1 (channel last).

    Returns:
        Shape as a list: [timesteps?] + ([S*B] or [S, B]) + [1?],
        where S is the feature generator's stride and B its feature_parameters.
    """
    from pyaaware import FeatureGenerator

    generator = FeatureGenerator(feature_mode=feature)
    stride = generator.stride
    params = generator.feature_parameters

    core = [stride * params] if flatten else [stride, params]
    prefix = [timesteps] if timesteps > 0 else []
    suffix = [1] if add1ch else []
    return prefix + core + suffix
@@ -0,0 +1,51 @@
1
+ from abc import ABC
2
+ from abc import abstractmethod
3
+ from typing import Any
4
+
5
+
6
class BaseChoice(ABC):
    """Abstract selector over a list of items; subclasses decide the selection order."""

    def __init__(self, data: list):
        # Items to choose from
        self.data = data
        # Position of the next item to hand out
        self.index = 0
        # Working order of the items (filled in by subclasses that need one)
        self.choices: list = []

    @abstractmethod
    def next(self) -> Any:
        """Return the next selected item."""
15
+
16
+
17
class RandomChoice(BaseChoice):
    """Pick items from ``data`` at random.

    With ``repetition`` enabled, each call is an independent draw, so items may repeat.
    Otherwise items are dealt from a shuffled copy of ``data``; once every item has been
    returned, a new shuffle is made.
    """

    def __init__(self, data: list, repetition: bool = False):
        import random

        super().__init__(data)
        self.repeat = repetition
        self.choices = random.sample(self.data, len(self.data))

    def next(self) -> Any:
        import random

        if self.repeat:
            return random.choice(self.data)  # noqa: S311

        exhausted = self.index >= len(self.data)
        if exhausted:
            # Every item has been dealt; reshuffle and start over
            self.choices = random.sample(self.data, len(self.data))
            self.index = 0

        selected = self.choices[self.index]
        self.index += 1
        return selected
40
+
41
+
42
class SequentialChoice(BaseChoice):
    """Cycle through ``data`` in order, wrapping to the first item after the last."""

    def __init__(self, data: list):
        super().__init__(data)

    def next(self) -> Any:
        # Wrap back to the start once the end of the list is reached
        if not self.index < len(self.data):
            self.index = 0
        item, self.index = self.data[self.index], self.index + 1
        return item
@@ -0,0 +1,25 @@
1
+ from ..datatypes import AudioF
2
+
3
+
4
def power_compress(feature: AudioF, power: float = 0.3) -> AudioF:
    """Apply power-law compression to the magnitude of complex data, keeping its phase.

    Each element ``x`` becomes ``|x| ** power * exp(1j * angle(x))``.

    :param feature: Complex-valued feature data
    :param power: Magnitude compression exponent (default 0.3)
    :return: Compressed complex data with the same shape as ``feature``
    """
    import numpy as np

    mag = np.abs(feature)
    phase = np.angle(feature)
    mag = mag**power
    real_compress = mag * np.cos(phase)
    imag_compress = mag * np.sin(phase)

    return real_compress + 1j * imag_compress
14
+
15
+
16
def power_uncompress(feature: AudioF, power: float = 0.3) -> AudioF:
    """Undo power-law magnitude compression of complex data, keeping its phase.

    Each element ``x`` becomes ``|x| ** (1 / power) * exp(1j * angle(x))``. This inverts
    compression that used the same ``power``.

    :param feature: Compressed complex-valued feature data
    :param power: Compression exponent that was applied (default 0.3)
    :return: Uncompressed complex data with the same shape as ``feature``
    """
    import numpy as np

    mag = np.abs(feature)
    phase = np.angle(feature)
    mag = mag ** (1.0 / power)
    real_uncompress = mag * np.cos(phase)
    imag_uncompress = mag * np.sin(phase)

    return real_uncompress + 1j * imag_uncompress
@@ -0,0 +1,6 @@
1
+ def convert_string_to_number(string: str) -> float | int | str:
2
+ try:
3
+ result = float(string)
4
+ return int(result) if result == int(result) else result
5
+ except ValueError:
6
+ return string
@@ -0,0 +1,5 @@
1
def create_timestamp() -> str:
    """Return the current local time formatted as ``YYYYMMDD-HHMMSS``."""
    import datetime as dt

    return f"{dt.datetime.now():%Y%m%d-%H%M%S}"
@@ -0,0 +1,14 @@
1
def create_ts_name(base: str) -> str:
    """Build a timestamped name from ``base``.

    Returns ``<base>-YYYYMMDD`` unless a path with that name already exists. In that case
    the local time of day is added: ``<base>-YYYYMMDD-HHMMSS``.
    """
    import os
    from datetime import datetime

    stamp = datetime.now()

    date_only = base + "-" + stamp.strftime("%Y%m%d")
    if not os.path.exists(date_only):
        return date_only

    return base + "-" + stamp.strftime("%Y%m%d-%H%M%S")
@@ -0,0 +1,27 @@
1
+ from collections.abc import Sequence
2
+ from typing import Any
3
+
4
+
5
def dataclass_from_dict(klass, dikt: dict) -> Any:
    """Convert dictionary to dataclass.

    Recursively builds ``klass`` from ``dikt``, converting each value with the annotated
    type of its key. When ``klass`` has no ``__annotations__`` (for example a built-in type
    such as ``int``), or construction raises AttributeError, ``dikt`` is returned unchanged.

    :param klass: Target type
    :param dikt: Mapping of field values (or a leaf value during recursion)
    :return: Instance of ``klass``, or ``dikt`` unchanged
    """
    from dataclasses import fields
    from dataclasses import is_dataclass

    try:
        field_types = dict(klass.__annotations__)
        if isinstance(klass, type) and is_dataclass(klass):
            # __annotations__ only holds the class's own annotations; add fields inherited
            # from dataclass base classes so their keys can be converted as well.
            for field in fields(klass):
                field_types.setdefault(field.name, field.type)
        return klass(**{f: dataclass_from_dict(field_types[f], dikt[f]) for f in dikt})
    except AttributeError:
        return dikt
12
+
13
+
14
def list_dataclass_from_dict(klass, dikt: Sequence[dict]) -> list[Any]:
    """Convert a sequence of dictionaries into a list of dataclass instances.

    :param klass: Parameterized list type such as ``list[MyDataclass]``; its first type
        argument (``klass.__args__[0]``) is used as the element type
    :param dikt: Dictionaries to convert
    :return: One converted element per input dictionary
    """
    return [dataclass_from_dict(klass.__args__[0], entry) for entry in dikt]
17
+
18
+
19
def original_dataclass_from_dict(klass, dikt):
    """Convert dictionary to dataclass (variant that also accepts lists).

    If ``klass`` has annotations, it is built from ``dikt`` field by field. Otherwise a
    tuple/list ``dikt`` is converted element-wise using ``klass.__args__[0]`` as the element
    type, and any other value is returned unchanged.
    """
    try:
        annotations = klass.__annotations__
        kwargs = {name: dataclass_from_dict(annotations[name], dikt[name]) for name in dikt}
        return klass(**kwargs)
    except AttributeError:
        if not isinstance(dikt, tuple | list):
            return dikt
        return [dataclass_from_dict(klass.__args__[0], entry) for entry in dikt]
sonusai/utils/db.py ADDED
@@ -0,0 +1,16 @@
1
def linear_to_db(linear: float) -> float:
    """Convert a linear value to decibels as ``20 * log10(|linear|)``.

    :param linear: Linear value (its absolute value is used)
    :return: dB value
    """
    import numpy as np

    magnitude = abs(linear)
    return 20 * np.log10(magnitude)
9
+
10
+
11
def db_to_linear(db: float) -> float:
    """Convert a decibel value to a linear value as ``10 ** (db / 20)``.

    :param db: dB value
    :return: Linear value
    """
    exponent = db / 20
    return 10**exponent
@@ -0,0 +1,53 @@
1
+ def trim_docstring(docstring: str | None) -> str:
2
+ """Trim whitespace from docstring"""
3
+ from sys import maxsize
4
+
5
+ if not docstring:
6
+ return ""
7
+
8
+ # Convert tabs to spaces (following the normal Python rules)
9
+ # and split into a list of lines:
10
+ lines = docstring.expandtabs().splitlines()
11
+
12
+ # Determine minimum indentation (first line doesn't count)
13
+ indent = maxsize
14
+ for line in lines[1:]:
15
+ stripped = line.lstrip()
16
+ if stripped:
17
+ indent = min(indent, len(line) - len(stripped))
18
+
19
+ # Remove indentation (first line is special):
20
+ trimmed = [lines[0].strip()]
21
+ if indent < maxsize:
22
+ for line in lines[1:]:
23
+ trimmed.append(line[indent:].rstrip())
24
+
25
+ # Strip off leading blank lines:
26
+ while trimmed and not trimmed[0]:
27
+ trimmed.pop(0)
28
+
29
+ # Return a single string
30
+ return "\n".join(trimmed)
31
+
32
+
33
def add_commands_to_docstring(docstring: str | None, plugin_docstrings: list[str]) -> str:
    """Add commands to docstring

    The lines after the ``The sonusai commands are:`` header, up to the next blank line, are
    replaced with the sorted, non-empty lines of ``sonusai.commands_doc`` plus the lines of
    every plugin docstring. If no blank line follows the header, the command list is taken
    to run to the end of the docstring.

    :param docstring: Docstring containing the commands header line
    :param plugin_docstrings: Additional command description docstrings
    :return: Docstring with the command list filled in
    :raises ValueError: If the commands header line is not present
    """
    import sonusai

    lines = docstring.splitlines() if docstring else []

    start = lines.index("The sonusai commands are:")
    try:
        end = lines.index("", start)
    except ValueError:
        # No blank line terminates the command section: it is the end of the docstring
        end = len(lines)

    commands = sonusai.commands_doc.splitlines()
    for plugin_docstring in plugin_docstrings:
        commands.extend(plugin_docstring.splitlines())
    # Sort and drop blank entries
    commands = sorted(filter(None, commands))

    lines = lines[: start + 1] + commands + lines[end:]

    return "\n".join(lines)
@@ -0,0 +1,44 @@
1
+ from pyaaware import ForwardTransform
2
+
3
+ from ..datatypes import AudioF
4
+ from ..datatypes import AudioT
5
+ from ..datatypes import EnergyF
6
+
7
+
8
def compute_energy_f(
    frequency_domain: AudioF | None = None,
    time_domain: AudioT | None = None,
    transform: ForwardTransform | None = None,
) -> EnergyF:
    """Compute the energy in each bin

    Must provide either frequency domain or time domain input. If time domain input is provided, must also provide
    ForwardTransform object to use to convert to frequency domain. If both are given, the frequency domain input is
    used.

    :param frequency_domain: Frequency domain data [frames, bins]
    :param time_domain: Time domain data [samples]
    :param transform: ForwardTransform object
    :return: Frequency domain per-bin energy data [frames, bins] as float32
    :raises ValueError: If required inputs are missing or the frequency domain data is not 2-D
    """
    import numpy as np

    if frequency_domain is None:
        if time_domain is None:
            raise ValueError("Must provide time or frequency domain input")
        if transform is None:
            raise ValueError("Must provide ForwardTransform object")

        # torch is only needed to drive the transform, so import it on this path only
        import torch

        _frequency_domain = transform.execute_all(torch.from_numpy(time_domain))[0].numpy()
    else:
        _frequency_domain = np.asarray(frequency_domain)

    if _frequency_domain.ndim != 2:
        raise ValueError(f"Expected frequency domain data of shape [frames, bins], got {_frequency_domain.shape}")

    # Energy = re^2 + im^2, computed element-wise in one vectorized pass and stored as float32
    real = np.real(_frequency_domain)
    imag = np.imag(_frequency_domain)
    return (real * real + imag * imag).astype(np.float32)
@@ -0,0 +1,166 @@
1
+ from __future__ import annotations
2
+
3
+ from decimal import Decimal
4
+ from typing import Any
5
+ from typing import ClassVar
6
+
7
+
8
+ class EngineeringNumber:
9
+ """Easy manipulation of numbers which use engineering notation"""
10
+
11
+ _suffix_lookup: ClassVar = {
12
+ "Y": "e24",
13
+ "Z": "e21",
14
+ "E": "e18",
15
+ "P": "e15",
16
+ "T": "e12",
17
+ "G": "e9",
18
+ "M": "e6",
19
+ "k": "e3",
20
+ "": "e0",
21
+ "m": "e-3",
22
+ "u": "e-6",
23
+ "n": "e-9",
24
+ "p": "e-12",
25
+ "f": "e-15",
26
+ "a": "e-18",
27
+ "z": "e-21",
28
+ "y": "e-24",
29
+ }
30
+
31
+ _exponent_lookup_scaled: ClassVar = {
32
+ "-12": "Y",
33
+ "-15": "Z",
34
+ "-18": "E",
35
+ "-21": "P",
36
+ "-24": "T",
37
+ "-27": "G",
38
+ "-30": "M",
39
+ "-33": "k",
40
+ "-36": "",
41
+ "-39": "m",
42
+ "-42": "u",
43
+ "-45": "n",
44
+ "-48": "p",
45
+ "-51": "f",
46
+ "-54": "a",
47
+ "-57": "z",
48
+ "-60": "y",
49
+ }
50
+
51
+ def __init__(
52
+ self,
53
+ value: str | float | int | EngineeringNumber,
54
+ precision: int = 2,
55
+ significant: int = 0,
56
+ ):
57
+ """
58
+ :param value: string, integer, or float representing the numeric value of the number
59
+ :param precision: the precision past the decimal
60
+ :param significant: the number of significant digits
61
+ if given, significant takes precedence over precision
62
+ """
63
+ self.precision = precision
64
+ self.significant = significant
65
+
66
+ if isinstance(value, str):
67
+ suffix_keys = [key for key in self._suffix_lookup if key != ""]
68
+
69
+ str_value = str(value)
70
+ for suffix in suffix_keys:
71
+ if suffix in str_value:
72
+ str_value = str_value[:-1] + self._suffix_lookup[suffix]
73
+ break
74
+
75
+ self.number = Decimal(str_value)
76
+
77
+ elif isinstance(value, int | float | EngineeringNumber):
78
+ self.number = Decimal(str(value))
79
+
80
+ else:
81
+ raise TypeError("value has unsupported type")
82
+
83
+ def __repr__(self):
84
+ """Returns the string representation"""
85
+ # The Decimal class only really converts numbers that are very small into engineering notation.
86
+ # So we will simply make all numbers small numbers and take advantage of the Decimal class.
87
+ number_str = self.number * Decimal("10e-37")
88
+ number_str = number_str.to_eng_string().lower()
89
+
90
+ base, exponent = number_str.split("e")
91
+
92
+ if self.significant > 0:
93
+ abs_base = abs(Decimal(base))
94
+ num_digits = 1
95
+ num_digits += 1 if abs_base >= 10 else 0
96
+ num_digits += 1 if abs_base >= 100 else 0
97
+ num_digits = self.significant - num_digits
98
+ else:
99
+ num_digits = self.precision
100
+
101
+ base = str(round(Decimal(base), num_digits))
102
+
103
+ if "e" in base.lower():
104
+ base = str(int(Decimal(base)))
105
+
106
+ # Remove trailing decimal
107
+ if "." in base:
108
+ base = base.rstrip(".")
109
+
110
+ return base + self._exponent_lookup_scaled[exponent]
111
+
112
+ def __str__(self) -> str:
113
+ return self.__repr__()
114
+
115
+ def __int__(self) -> int:
116
+ return int(self.number)
117
+
118
+ def __float__(self):
119
+ return float(self.number)
120
+
121
+ @staticmethod
122
+ def _to_decimal(other: str | float | int | EngineeringNumber) -> Decimal:
123
+ if not isinstance(other, EngineeringNumber):
124
+ other = EngineeringNumber(other)
125
+ return other.number
126
+
127
+ def __add__(self, other: str | float | int | EngineeringNumber) -> EngineeringNumber:
128
+ return EngineeringNumber(str(self.number + self._to_decimal(other)))
129
+
130
+ def __radd__(self, other: str | float | int | EngineeringNumber) -> EngineeringNumber:
131
+ return self.__add__(other)
132
+
133
+ def __sub__(self, other: str | float | int | EngineeringNumber) -> EngineeringNumber:
134
+ return EngineeringNumber(str(self.number - self._to_decimal(other)))
135
+
136
+ def __rsub__(self, other: str | float | int | EngineeringNumber) -> EngineeringNumber:
137
+ return EngineeringNumber(str(self._to_decimal(other) - self.number))
138
+
139
+ def __mul__(self, other: str | float | int | EngineeringNumber) -> EngineeringNumber:
140
+ return EngineeringNumber(str(self.number * self._to_decimal(other)))
141
+
142
+ def __rmul__(self, other: str | float | int | EngineeringNumber) -> EngineeringNumber:
143
+ return self.__mul__(other)
144
+
145
+ def __truediv__(self, other: str | float | int | EngineeringNumber) -> EngineeringNumber:
146
+ return EngineeringNumber(str(self.number / self._to_decimal(other)))
147
+
148
+ def __rtruediv__(self, other: str | float | int | EngineeringNumber) -> EngineeringNumber:
149
+ return EngineeringNumber(str(self._to_decimal(other) / self.number))
150
+
151
+ def __lt__(self, other: str | float | int | EngineeringNumber) -> bool:
152
+ return self.number < self._to_decimal(other)
153
+
154
+ def __gt__(self, other: str | float | int | EngineeringNumber) -> bool:
155
+ return self.number > self._to_decimal(other)
156
+
157
+ def __le__(self, other: str | float | int | EngineeringNumber) -> bool:
158
+ return self.number <= self._to_decimal(other)
159
+
160
+ def __ge__(self, other: str | float | int | EngineeringNumber) -> bool:
161
+ return self.number >= self._to_decimal(other)
162
+
163
+ def __eq__(self, other: Any) -> bool:
164
+ if not isinstance(other, str | float | int | EngineeringNumber):
165
+ return NotImplemented
166
+ return self.number == self._to_decimal(other)
@@ -0,0 +1,15 @@
1
def evaluate_random_rule(rule: str) -> str | float:
    """Evaluate 'rand' directive

    Each ``rand(a, b)`` occurrence in ``rule`` is replaced by a value drawn with
    ``random.uniform(a, b)`` and formatted with two decimal places. The resulting string is
    then passed to ``eval()``.

    :param rule: Rule
    :return: Resolved value
    """
    import re
    from random import uniform

    # Matches rand(<num>, <num>); groups 1 and 4 capture the two signed decimal bounds
    rand_pattern = re.compile(r"rand\(([-+]?(\d+(\.\d*)?|\.\d+)),\s*([-+]?(\d+(\.\d*)?|\.\d+))\)")

    def rand_repl(m):
        return f"{uniform(float(m.group(1)), float(m.group(4))):.2f}"  # noqa: S311

    # SECURITY NOTE(review): eval() executes arbitrary Python code and sees this function's
    # local names; rules must come only from trusted configuration.
    return eval(re.sub(rand_pattern, rand_repl, rule))  # noqa: S307
@@ -0,0 +1,2 @@
1
def get_frames_per_batch(batch_size: int, timesteps: int) -> int:
    """Return the number of frames in one batch.

    :param batch_size: Number of batch entries
    :param timesteps: Timesteps per batch entry (0 is treated as 1)
    :return: ``batch_size * timesteps``, or ``batch_size`` when ``timesteps`` is 0
    """
    if timesteps == 0:
        return batch_size
    return batch_size * timesteps
@@ -0,0 +1,20 @@
1
+ def get_label_names(num_labels: int, file: str | None = None) -> list:
2
+ """Return label names in a list. Read from CSV file, if provided."""
3
+ import csv
4
+
5
+ if file is None:
6
+ return [f"Class {val + 1}" for val in range(num_labels)]
7
+
8
+ label_names = [""] * num_labels
9
+ with open(file) as f:
10
+ reader = csv.DictReader(f)
11
+ if reader.fieldnames is None or "index" not in reader.fieldnames or "display_name" not in reader.fieldnames:
12
+ raise ValueError("Missing required fields in labels CSV.")
13
+
14
+ for row in reader:
15
+ index = int(row["index"]) - 1
16
+ if index >= num_labels:
17
+ raise ValueError("The number of given label names does not match the number of labels.")
18
+ label_names[index] = row["display_name"]
19
+
20
+ return label_names
@@ -0,0 +1,6 @@
1
def grouper(iterable, n):
    """Split ``iterable`` into consecutive lists of ``n`` items.

    The last list holds the remaining items and may be shorter than ``n``. Items that are
    ``None`` are kept.

    :param iterable: Items to group
    :param n: Group size
    :return: List of groups, each a list
    """
    from itertools import zip_longest

    # A private sentinel marks padding, so genuine None items are not mistaken for it; this
    # also avoids relying on the (deprecated) truthiness of NotImplemented from None.__ne__.
    fill = object()
    args = [iter(iterable)] * n
    return [[item for item in group if item is not fill] for group in zip_longest(*args, fillvalue=fill)]
@@ -0,0 +1,7 @@
1
def human_readable_size(num: float, decimal_places: int = 3, suffix: str = "B") -> str:
    """Format ``num`` using 1024-based unit prefixes.

    ``num`` is divided by 1024 while its magnitude is not below 1024, stepping through the
    prefixes "", k, M, G, T, P, E, Z and finally Y (the largest, applied without limit).

    :param num: Quantity to format
    :param decimal_places: Number of digits after the decimal point
    :param suffix: Unit symbol placed after the prefix
    :return: Formatted string such as ``"1.500 kB"``
    """
    prefixes = ("", "k", "M", "G", "T", "P", "E", "Z", "Y")
    level = 0
    # "not < 1024.0" (rather than ">= 1024.0") also advances NaN all the way to "Y"
    while level < len(prefixes) - 1 and not abs(num) < 1024.0:
        num /= 1024.0
        level += 1
    return f"{num:.{decimal_places}f} {prefixes[level]}{suffix}"