sonusai 0.18.9__py3-none-any.whl → 0.19.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (118)
  1. sonusai/__init__.py +20 -29
  2. sonusai/aawscd_probwrite.py +18 -18
  3. sonusai/audiofe.py +93 -80
  4. sonusai/calc_metric_spenh.py +395 -321
  5. sonusai/data/genmixdb.yml +5 -11
  6. sonusai/{gentcst.py → deprecated/gentcst.py} +146 -149
  7. sonusai/{plot.py → deprecated/plot.py} +177 -131
  8. sonusai/{tplot.py → deprecated/tplot.py} +124 -102
  9. sonusai/doc/__init__.py +1 -1
  10. sonusai/doc/doc.py +112 -177
  11. sonusai/doc.py +10 -10
  12. sonusai/genft.py +81 -91
  13. sonusai/genmetrics.py +51 -61
  14. sonusai/genmix.py +105 -115
  15. sonusai/genmixdb.py +201 -174
  16. sonusai/lsdb.py +56 -66
  17. sonusai/main.py +23 -20
  18. sonusai/metrics/__init__.py +2 -0
  19. sonusai/metrics/calc_audio_stats.py +29 -24
  20. sonusai/metrics/calc_class_weights.py +7 -7
  21. sonusai/metrics/calc_optimal_thresholds.py +5 -7
  22. sonusai/metrics/calc_pcm.py +3 -3
  23. sonusai/metrics/calc_pesq.py +10 -7
  24. sonusai/metrics/calc_phase_distance.py +3 -3
  25. sonusai/metrics/calc_sa_sdr.py +10 -8
  26. sonusai/metrics/calc_segsnr_f.py +16 -18
  27. sonusai/metrics/calc_speech.py +105 -47
  28. sonusai/metrics/calc_wer.py +35 -32
  29. sonusai/metrics/calc_wsdr.py +10 -7
  30. sonusai/metrics/class_summary.py +30 -27
  31. sonusai/metrics/confusion_matrix_summary.py +25 -22
  32. sonusai/metrics/one_hot.py +91 -57
  33. sonusai/metrics/snr_summary.py +53 -46
  34. sonusai/mixture/__init__.py +20 -14
  35. sonusai/mixture/audio.py +4 -6
  36. sonusai/mixture/augmentation.py +37 -43
  37. sonusai/mixture/class_count.py +5 -14
  38. sonusai/mixture/config.py +292 -225
  39. sonusai/mixture/constants.py +41 -30
  40. sonusai/mixture/data_io.py +155 -0
  41. sonusai/mixture/datatypes.py +111 -108
  42. sonusai/mixture/db_datatypes.py +54 -70
  43. sonusai/mixture/eq_rule_is_valid.py +6 -9
  44. sonusai/mixture/feature.py +40 -38
  45. sonusai/mixture/generation.py +522 -389
  46. sonusai/mixture/helpers.py +217 -272
  47. sonusai/mixture/log_duration_and_sizes.py +16 -13
  48. sonusai/mixture/mixdb.py +669 -477
  49. sonusai/mixture/soundfile_audio.py +12 -17
  50. sonusai/mixture/sox_audio.py +91 -112
  51. sonusai/mixture/sox_augmentation.py +8 -9
  52. sonusai/mixture/spectral_mask.py +4 -6
  53. sonusai/mixture/target_class_balancing.py +41 -36
  54. sonusai/mixture/targets.py +69 -67
  55. sonusai/mixture/tokenized_shell_vars.py +23 -23
  56. sonusai/mixture/torchaudio_audio.py +14 -15
  57. sonusai/mixture/torchaudio_augmentation.py +23 -27
  58. sonusai/mixture/truth.py +48 -26
  59. sonusai/mixture/truth_functions/__init__.py +26 -0
  60. sonusai/mixture/truth_functions/crm.py +56 -38
  61. sonusai/mixture/truth_functions/datatypes.py +37 -0
  62. sonusai/mixture/truth_functions/energy.py +85 -59
  63. sonusai/mixture/truth_functions/file.py +30 -30
  64. sonusai/mixture/truth_functions/phoneme.py +14 -7
  65. sonusai/mixture/truth_functions/sed.py +71 -45
  66. sonusai/mixture/truth_functions/target.py +69 -106
  67. sonusai/mkwav.py +58 -101
  68. sonusai/onnx_predict.py +46 -43
  69. sonusai/queries/__init__.py +3 -1
  70. sonusai/queries/queries.py +100 -59
  71. sonusai/speech/__init__.py +2 -0
  72. sonusai/speech/l2arctic.py +24 -23
  73. sonusai/speech/librispeech.py +16 -17
  74. sonusai/speech/mcgill.py +22 -21
  75. sonusai/speech/textgrid.py +32 -25
  76. sonusai/speech/timit.py +45 -42
  77. sonusai/speech/vctk.py +14 -13
  78. sonusai/speech/voxceleb.py +26 -20
  79. sonusai/summarize_metric_spenh.py +11 -10
  80. sonusai/utils/__init__.py +4 -3
  81. sonusai/utils/asl_p56.py +1 -1
  82. sonusai/utils/asr.py +37 -17
  83. sonusai/utils/asr_functions/__init__.py +2 -0
  84. sonusai/utils/asr_functions/aaware_whisper.py +18 -12
  85. sonusai/utils/audio_devices.py +12 -12
  86. sonusai/utils/braced_glob.py +6 -8
  87. sonusai/utils/calculate_input_shape.py +1 -4
  88. sonusai/utils/compress.py +2 -2
  89. sonusai/utils/convert_string_to_number.py +1 -3
  90. sonusai/utils/create_timestamp.py +1 -1
  91. sonusai/utils/create_ts_name.py +2 -2
  92. sonusai/utils/dataclass_from_dict.py +1 -1
  93. sonusai/utils/docstring.py +6 -6
  94. sonusai/utils/energy_f.py +9 -7
  95. sonusai/utils/engineering_number.py +56 -54
  96. sonusai/utils/get_label_names.py +8 -10
  97. sonusai/utils/human_readable_size.py +2 -2
  98. sonusai/utils/model_utils.py +3 -5
  99. sonusai/utils/numeric_conversion.py +2 -4
  100. sonusai/utils/onnx_utils.py +43 -32
  101. sonusai/utils/parallel.py +41 -30
  102. sonusai/utils/print_mixture_details.py +25 -22
  103. sonusai/utils/ranges.py +12 -12
  104. sonusai/utils/read_predict_data.py +11 -9
  105. sonusai/utils/reshape.py +19 -26
  106. sonusai/utils/seconds_to_hms.py +1 -1
  107. sonusai/utils/stacked_complex.py +8 -16
  108. sonusai/utils/stratified_shuffle_split.py +29 -27
  109. sonusai/utils/write_audio.py +2 -2
  110. sonusai/utils/yes_or_no.py +3 -3
  111. sonusai/vars.py +14 -14
  112. {sonusai-0.18.9.dist-info → sonusai-0.19.6.dist-info}/METADATA +20 -21
  113. sonusai-0.19.6.dist-info/RECORD +125 -0
  114. {sonusai-0.18.9.dist-info → sonusai-0.19.6.dist-info}/WHEEL +1 -1
  115. sonusai/mixture/truth_functions/data.py +0 -58
  116. sonusai/utils/read_mixture_data.py +0 -14
  117. sonusai-0.18.9.dist-info/RECORD +0 -125
  118. {sonusai-0.18.9.dist-info → sonusai-0.19.6.dist-info}/entry_points.txt +0 -0
@@ -2,90 +2,92 @@ from __future__ import annotations
2
2
 
3
3
  from decimal import Decimal
4
4
  from typing import Any
5
+ from typing import ClassVar
5
6
 
6
7
 
7
8
  class EngineeringNumber:
8
9
  """Easy manipulation of numbers which use engineering notation"""
9
10
 
10
- _suffix_lookup = {
11
- 'Y': 'e24',
12
- 'Z': 'e21',
13
- 'E': 'e18',
14
- 'P': 'e15',
15
- 'T': 'e12',
16
- 'G': 'e9',
17
- 'M': 'e6',
18
- 'k': 'e3',
19
- '': 'e0',
20
- 'm': 'e-3',
21
- 'u': 'e-6',
22
- 'n': 'e-9',
23
- 'p': 'e-12',
24
- 'f': 'e-15',
25
- 'a': 'e-18',
26
- 'z': 'e-21',
27
- 'y': 'e-24',
11
+ _suffix_lookup: ClassVar = {
12
+ "Y": "e24",
13
+ "Z": "e21",
14
+ "E": "e18",
15
+ "P": "e15",
16
+ "T": "e12",
17
+ "G": "e9",
18
+ "M": "e6",
19
+ "k": "e3",
20
+ "": "e0",
21
+ "m": "e-3",
22
+ "u": "e-6",
23
+ "n": "e-9",
24
+ "p": "e-12",
25
+ "f": "e-15",
26
+ "a": "e-18",
27
+ "z": "e-21",
28
+ "y": "e-24",
28
29
  }
29
30
 
30
- _exponent_lookup_scaled = {
31
- '-12': 'Y',
32
- '-15': 'Z',
33
- '-18': 'E',
34
- '-21': 'P',
35
- '-24': 'T',
36
- '-27': 'G',
37
- '-30': 'M',
38
- '-33': 'k',
39
- '-36': '',
40
- '-39': 'm',
41
- '-42': 'u',
42
- '-45': 'n',
43
- '-48': 'p',
44
- '-51': 'f',
45
- '-54': 'a',
46
- '-57': 'z',
47
- '-60': 'y',
31
+ _exponent_lookup_scaled: ClassVar = {
32
+ "-12": "Y",
33
+ "-15": "Z",
34
+ "-18": "E",
35
+ "-21": "P",
36
+ "-24": "T",
37
+ "-27": "G",
38
+ "-30": "M",
39
+ "-33": "k",
40
+ "-36": "",
41
+ "-39": "m",
42
+ "-42": "u",
43
+ "-45": "n",
44
+ "-48": "p",
45
+ "-51": "f",
46
+ "-54": "a",
47
+ "-57": "z",
48
+ "-60": "y",
48
49
  }
49
50
 
50
- def __init__(self,
51
- value: str | float | int | EngineeringNumber,
52
- precision: int = 2,
53
- significant: int = 0):
51
+ def __init__(
52
+ self,
53
+ value: str | float | int | EngineeringNumber,
54
+ precision: int = 2,
55
+ significant: int = 0,
56
+ ):
54
57
  """
55
58
  :param value: string, integer, or float representing the numeric value of the number
56
59
  :param precision: the precision past the decimal
57
60
  :param significant: the number of significant digits
58
61
  if given, significant takes precedence over precision
59
62
  """
60
- from sonusai import SonusAIError
61
-
62
63
  self.precision = precision
63
64
  self.significant = significant
64
65
 
65
66
  if isinstance(value, str):
66
- suffix_keys = [key for key in self._suffix_lookup.keys() if key != '']
67
+ suffix_keys = [key for key in self._suffix_lookup if key != ""]
67
68
 
69
+ str_value = str(value)
68
70
  for suffix in suffix_keys:
69
- if suffix in value:
70
- value = value[:-1] + self._suffix_lookup[suffix]
71
+ if suffix in str_value:
72
+ str_value = str_value[:-1] + self._suffix_lookup[suffix]
71
73
  break
72
74
 
73
- self.number = Decimal(value)
75
+ self.number = Decimal(str_value)
74
76
 
75
- elif isinstance(value, int) or isinstance(value, float) or isinstance(value, EngineeringNumber):
77
+ elif isinstance(value, int | float | EngineeringNumber):
76
78
  self.number = Decimal(str(value))
77
79
 
78
80
  else:
79
- raise SonusAIError('value has unsupported type')
81
+ raise TypeError("value has unsupported type")
80
82
 
81
83
  def __repr__(self):
82
84
  """Returns the string representation"""
83
85
  # The Decimal class only really converts numbers that are very small into engineering notation.
84
86
  # So we will simply make all numbers small numbers and take advantage of the Decimal class.
85
- number_str = self.number * Decimal('10e-37')
87
+ number_str = self.number * Decimal("10e-37")
86
88
  number_str = number_str.to_eng_string().lower()
87
89
 
88
- base, exponent = number_str.split('e')
90
+ base, exponent = number_str.split("e")
89
91
 
90
92
  if self.significant > 0:
91
93
  abs_base = abs(Decimal(base))
@@ -98,12 +100,12 @@ class EngineeringNumber:
98
100
 
99
101
  base = str(round(Decimal(base), num_digits))
100
102
 
101
- if 'e' in base.lower():
103
+ if "e" in base.lower():
102
104
  base = str(int(Decimal(base)))
103
105
 
104
106
  # Remove trailing decimal
105
- if '.' in base:
106
- base = base.rstrip('.')
107
+ if "." in base:
108
+ base = base.rstrip(".")
107
109
 
108
110
  return base + self._exponent_lookup_scaled[exponent]
109
111
 
@@ -159,6 +161,6 @@ class EngineeringNumber:
159
161
  return self.number >= self._to_decimal(other)
160
162
 
161
163
  def __eq__(self, other: Any) -> bool:
162
- if not isinstance(other, (str, float, int, EngineeringNumber)):
164
+ if not isinstance(other, str | float | int | EngineeringNumber):
163
165
  return NotImplemented
164
166
  return self.number == self._to_decimal(other)
@@ -1,22 +1,20 @@
1
- def get_label_names(num_labels: int, file: str = None) -> list:
1
+ def get_label_names(num_labels: int, file: str | None = None) -> list:
2
2
  """Return label names in a list. Read from CSV file, if provided."""
3
3
  import csv
4
4
 
5
- from sonusai import SonusAIError
6
-
7
5
  if file is None:
8
- return [f'Class {val + 1}' for val in range(num_labels)]
6
+ return [f"Class {val + 1}" for val in range(num_labels)]
9
7
 
10
- label_names = [''] * num_labels
8
+ label_names = [""] * num_labels
11
9
  with open(file) as f:
12
10
  reader = csv.DictReader(f)
13
- if 'index' not in reader.fieldnames or 'display_name' not in reader.fieldnames:
14
- raise SonusAIError(f'Missing required fields in labels CSV.')
11
+ if reader.fieldnames is None or "index" not in reader.fieldnames or "display_name" not in reader.fieldnames:
12
+ raise ValueError("Missing required fields in labels CSV.")
15
13
 
16
14
  for row in reader:
17
- index = int(row['index']) - 1
15
+ index = int(row["index"]) - 1
18
16
  if index >= num_labels:
19
- raise SonusAIError(f'The number of given label names does not match the number of labels.')
20
- label_names[index] = row['display_name']
17
+ raise ValueError("The number of given label names does not match the number of labels.")
18
+ label_names[index] = row["display_name"]
21
19
 
22
20
  return label_names
@@ -1,7 +1,7 @@
1
1
  def human_readable_size(size: float, decimal_places: int = 3) -> str:
2
2
  """Convert number into string with units"""
3
- for unit in ['B', 'kB', 'MB', 'GB', 'TB']:
3
+ for unit in ["B", "kB", "MB", "GB", "TB"]: # noqa: B007
4
4
  if size < 1024.0:
5
5
  break
6
6
  size /= 1024.0
7
- return f'{size:.{decimal_places}f} {unit}'
7
+ return f"{size:.{decimal_places}f} {unit}"
@@ -9,12 +9,10 @@ def import_module(name: str) -> Any:
9
9
  import sys
10
10
  from importlib import import_module
11
11
 
12
- from sonusai import SonusAIError
13
-
14
12
  try:
15
13
  path = os.path.dirname(name)
16
14
  if len(path) < 1:
17
- path = './'
15
+ path = "./"
18
16
 
19
17
  # Add model file location to system path
20
18
  sys.path.append(os.path.abspath(path))
@@ -23,8 +21,8 @@ def import_module(name: str) -> Any:
23
21
  root = os.path.splitext(os.path.basename(name))[0]
24
22
  model = import_module(root)
25
23
  except Exception as e:
26
- raise SonusAIError(f'Error: could not import model from {name}: {e}.')
24
+ raise OSError(f"Error: could not import model from {name}: {e}.") from e
27
25
  except Exception as e:
28
- raise SonusAIError(f'Error: could not find {name}: {e}.')
26
+ raise OSError(f"Error: could not find {name}: {e}.") from e
29
27
 
30
28
  return model
@@ -2,12 +2,10 @@ import numpy as np
2
2
 
3
3
 
4
4
  def int16_to_float(x: np.ndarray) -> np.ndarray:
5
- """ Convert int16 array to floating point with range +/- 1
6
- """
5
+ """Convert int16 array to floating point with range +/- 1"""
7
6
  return x.astype(np.float32) / 32768
8
7
 
9
8
 
10
9
  def float_to_int16(x: np.ndarray) -> np.ndarray:
11
- """ Convert float point array with range +/- 1 to int16
12
- """
10
+ """Convert float point array with range +/- 1 to int16"""
13
11
  return (x * 32768).astype(np.int16)
@@ -1,5 +1,4 @@
1
- from typing import Optional
2
- from typing import Sequence
1
+ from collections.abc import Sequence
3
2
 
4
3
  from onnx import ModelProto
5
4
  from onnx import ValueInfoProto
@@ -7,7 +6,14 @@ from onnxruntime import InferenceSession
7
6
  from onnxruntime import NodeArg
8
7
  from onnxruntime import SessionOptions
9
8
 
10
- REQUIRED_HPARAMS = ('feature', 'batch_size', 'timesteps', 'flatten', 'add1ch', 'truth_mutex')
9
+ REQUIRED_HPARAMS = (
10
+ "feature",
11
+ "batch_size",
12
+ "timesteps",
13
+ "flatten",
14
+ "add1ch",
15
+ "truth_mutex",
16
+ )
11
17
 
12
18
 
13
19
  def _extract_shapes(io: list[ValueInfoProto]) -> list[list[int] | str]:
@@ -18,14 +24,14 @@ def _extract_shapes(io: list[ValueInfoProto]) -> list[list[int] | str]:
18
24
  # get tensor type: 0, 1, 2, etc.
19
25
  tensor_type = item.type.tensor_type
20
26
  # check if it has a shape
21
- if tensor_type.HasField('shape'):
27
+ if tensor_type.HasField("shape"):
22
28
  tmp_shape = []
23
29
  # iterate through dimensions of the shape
24
30
  for d in tensor_type.shape.dim:
25
- if d.HasField('dim_value'):
31
+ if d.HasField("dim_value"):
26
32
  # known dimension, int value
27
33
  tmp_shape.append(d.dim_value)
28
- elif d.HasField('dim_param'):
34
+ elif d.HasField("dim_param"):
29
35
  # dynamic dim with symbolic name of d.dim_param; set size to 0
30
36
  tmp_shape.append(0)
31
37
  else:
@@ -34,19 +40,21 @@ def _extract_shapes(io: list[ValueInfoProto]) -> list[list[int] | str]:
34
40
  # add as a list
35
41
  shapes.append(tmp_shape)
36
42
  else:
37
- shapes.append('unknown rank')
43
+ shapes.append("unknown rank")
38
44
 
39
45
  return shapes
40
46
 
41
47
 
42
- def get_and_check_inputs(model: ModelProto) -> tuple[list[ValueInfoProto], list[list[int] | str]]:
48
+ def get_and_check_inputs(
49
+ model: ModelProto,
50
+ ) -> tuple[list[ValueInfoProto], list[list[int] | str]]:
43
51
  from sonusai import logger
44
52
 
45
53
  # ignore initializer inputs (only seen in older ONNX < v1.5)
46
54
  initializer_names = [x.name for x in model.graph.initializer]
47
55
  inputs = [i for i in model.graph.input if i.name not in initializer_names]
48
56
  if len(inputs) != 1:
49
- logger.warning(f'Warning: ONNX model has {len(inputs)} inputs; expected only 1')
57
+ logger.warning(f"Warning: ONNX model has {len(inputs)} inputs; expected only 1")
50
58
 
51
59
  # This one-liner works only if input has type and shape, returns a list
52
60
  # shape0 = [d.dim_value for d in inputs[0].type.tensor_type.shape.dim]
@@ -55,12 +63,14 @@ def get_and_check_inputs(model: ModelProto) -> tuple[list[ValueInfoProto], list[
55
63
  return inputs, shapes
56
64
 
57
65
 
58
- def get_and_check_outputs(model: ModelProto) -> tuple[list[ValueInfoProto], list[list[int] | str]]:
66
+ def get_and_check_outputs(
67
+ model: ModelProto,
68
+ ) -> tuple[list[ValueInfoProto], list[list[int] | str]]:
59
69
  from sonusai import logger
60
70
 
61
- outputs = [o for o in model.graph.output]
71
+ outputs = list(model.graph.output)
62
72
  if len(outputs) != 1:
63
- logger.warning(f'Warning: ONNX model has {len(outputs)} outputs; expected only 1')
73
+ logger.warning(f"Warning: ONNX model has {len(outputs)} outputs; expected only 1")
64
74
 
65
75
  shapes = _extract_shapes(outputs)
66
76
 
@@ -85,38 +95,39 @@ def add_sonusai_metadata(model: ModelProto, hparams: dict) -> ModelProto:
85
95
  from sonusai import logger
86
96
 
87
97
  # Note hparams should be a dict (i.e., extracted from checkpoint)
88
- assert eval(str(hparams)) == hparams
98
+ if eval(str(hparams)) != hparams: # noqa: S307
99
+ raise TypeError("hparams is not a dict")
89
100
  for key in REQUIRED_HPARAMS:
90
- if key not in hparams.keys():
91
- logger.warning(f'Warning: SonusAI hyperparameters are missing: {key}')
101
+ if key not in hparams:
102
+ logger.warning(f"Warning: SonusAI hyperparameters are missing: {key}")
92
103
 
93
104
  meta = model.metadata_props.add()
94
- meta.key = 'hparams'
105
+ meta.key = "hparams"
95
106
  meta.value = str(hparams)
96
107
 
97
108
  return model
98
109
 
99
110
 
100
- def get_sonusai_metadata(session: InferenceSession) -> Optional[dict]:
101
- """Get SonusAI hyperparameter metadata from an ONNX Runtime session.
102
- """
111
+ def get_sonusai_metadata(session: InferenceSession) -> dict | None:
112
+ """Get SonusAI hyperparameter metadata from an ONNX Runtime session."""
103
113
  from sonusai import logger
104
114
 
105
115
  meta = session.get_modelmeta()
106
- if 'hparams' not in meta.custom_metadata_map.keys():
116
+ if "hparams" not in meta.custom_metadata_map:
107
117
  logger.warning("Warning: ONNX model metadata does not contain 'hparams'")
108
118
  return None
109
119
 
110
- hparams = eval(meta.custom_metadata_map['hparams'])
120
+ hparams = eval(meta.custom_metadata_map["hparams"]) # noqa: S307
111
121
  for key in REQUIRED_HPARAMS:
112
- if key not in hparams.keys():
113
- logger.warning(f'Warning: ONNX model does not have required SonusAI hyperparameters: {key}')
122
+ if key not in hparams:
123
+ logger.warning(f"Warning: ONNX model does not have required SonusAI hyperparameters: {key}")
114
124
 
115
125
  return hparams
116
126
 
117
127
 
118
- def load_ort_session(model_path: str, providers: Sequence[str | tuple[str, dict]] = None) -> tuple[
119
- InferenceSession, SessionOptions, str, dict, list[NodeArg], list[NodeArg]]:
128
+ def load_ort_session(
129
+ model_path: str, providers: Sequence[str | tuple[str, dict]] | None = None
130
+ ) -> tuple[InferenceSession, SessionOptions, str, dict | None, list[NodeArg], list[NodeArg]]:
120
131
  from os.path import basename
121
132
  from os.path import exists
122
133
  from os.path import isfile
@@ -127,27 +138,27 @@ def load_ort_session(model_path: str, providers: Sequence[str | tuple[str, dict]
127
138
  from sonusai import logger
128
139
 
129
140
  if providers is None:
130
- providers = ['CPUExecutionProvider']
141
+ providers = ["CPUExecutionProvider"]
131
142
 
132
143
  if exists(model_path) and isfile(model_path):
133
144
  model_basename = basename(model_path)
134
145
  model_root = splitext(model_basename)[0]
135
- logger.info(f'Importing model from {model_basename}')
146
+ logger.info(f"Importing model from {model_basename}")
136
147
  try:
137
148
  session = ort.InferenceSession(model_path, providers=providers)
138
149
  options = ort.SessionOptions()
139
150
  except Exception as e:
140
- logger.exception(f'Error: could not load ONNX model from {model_path}: {e}')
141
- raise SystemExit(1)
151
+ logger.exception(f"Error: could not load ONNX model from {model_path}: {e}")
152
+ raise SystemExit(1) from e
142
153
  else:
143
- logger.exception(f'Error: model file does not exist: {model_path}')
154
+ logger.exception(f"Error: model file does not exist: {model_path}")
144
155
  raise SystemExit(1)
145
156
 
146
- logger.info(f'Opened session with provider options: {session._provider_options}.')
157
+ logger.info(f"Opened session with provider options: {session._provider_options}.")
147
158
  hparams = get_sonusai_metadata(session)
148
159
  if hparams is not None:
149
160
  for key in REQUIRED_HPARAMS:
150
- logger.info(f' {key:12} {hparams[key]}')
161
+ logger.info(f" {key:12} {hparams[key]}")
151
162
 
152
163
  inputs = session.get_inputs()
153
164
  outputs = session.get_outputs()
sonusai/utils/parallel.py CHANGED
@@ -1,32 +1,41 @@
1
+ import warnings
2
+ from collections.abc import Callable
3
+ from collections.abc import Iterable
1
4
  from multiprocessing import current_process
2
5
  from multiprocessing import get_context
3
6
  from typing import Any
4
- from typing import Callable
5
- from typing import Iterable
6
- from typing import Optional
7
7
 
8
- from tqdm import tqdm
8
+ from tqdm import TqdmExperimentalWarning
9
+ from tqdm.rich import tqdm
9
10
 
10
- CONTEXT = 'fork'
11
+ warnings.filterwarnings(action="ignore", category=TqdmExperimentalWarning)
11
12
 
13
+ track = tqdm
12
14
 
13
- def pp_tqdm_imap(func: Callable,
14
- *iterables: Iterable,
15
- initializer: Optional[Callable[..., None]] = None,
16
- initargs: Optional[Iterable[Any]] = None,
17
- progress: Optional[tqdm] = None,
18
- num_cpus: Optional[int] = None,
19
- total: Optional[int] = None,
20
- no_par: bool = False) -> list[Any]:
15
+ CONTEXT = "fork"
16
+
17
+
18
+ def par_track(
19
+ func: Callable,
20
+ *iterables: Iterable,
21
+ initializer: Callable[..., None] | None = None,
22
+ initargs: Iterable[Any] | None = None,
23
+ progress: tqdm | None = None,
24
+ num_cpus: int | float | None = None,
25
+ total: int | None = None,
26
+ no_par: bool = False,
27
+ ) -> list[Any]:
21
28
  """Performs a parallel ordered imap with tqdm progress."""
22
- from os import cpu_count
23
- from typing import Sized
29
+ from collections.abc import Sized
30
+
31
+ from psutil import cpu_count
24
32
 
25
33
  if total is None:
26
- total = min(len(iterable) for iterable in iterables if isinstance(iterable, Sized))
34
+ _total = min(len(iterable) for iterable in iterables if isinstance(iterable, Sized))
35
+ else:
36
+ _total = int(total)
27
37
 
28
- results: list[Any] = [None] * total
29
- n = 0
38
+ results: list[Any] = [None] * _total
30
39
  if no_par or current_process().daemon:
31
40
  if initializer is not None:
32
41
  if initargs is not None:
@@ -34,21 +43,26 @@ def pp_tqdm_imap(func: Callable,
34
43
  else:
35
44
  initializer()
36
45
 
37
- for result in map(func, *iterables):
46
+ for n, result in enumerate(map(func, *iterables)):
38
47
  results[n] = result
39
- n += 1
40
48
  if progress is not None:
41
49
  progress.update()
42
50
  else:
43
51
  if num_cpus is None:
44
- num_cpus = cpu_count()
45
- elif num_cpus is float:
46
- num_cpus = int(round(num_cpus * cpu_count()))
52
+ _num_cpus = max(cpu_count() - 2, 1)
53
+ elif isinstance(num_cpus, float):
54
+ _num_cpus = int(round(num_cpus * cpu_count()))
55
+ else:
56
+ _num_cpus = int(num_cpus)
47
57
 
48
- if total < num_cpus:
49
- num_cpus = total
50
- with get_context(CONTEXT).Pool(processes=num_cpus, initializer=initializer, initargs=initargs) as pool:
51
- for result in pool.imap(func, *iterables): # type: ignore
58
+ _num_cpus = min(_num_cpus, _total)
59
+
60
+ if initargs is None:
61
+ initargs = []
62
+
63
+ with get_context(CONTEXT).Pool(processes=_num_cpus, initializer=initializer, initargs=initargs) as pool:
64
+ n = 0
65
+ for result in pool.imap(func, *iterables): # type: ignore[arg-type]
52
66
  results[n] = result
53
67
  n += 1
54
68
  if progress is not None:
@@ -59,6 +73,3 @@ def pp_tqdm_imap(func: Callable,
59
73
  if progress is not None:
60
74
  progress.close()
61
75
  return results
62
-
63
-
64
- pp_imap = pp_tqdm_imap
@@ -1,37 +1,38 @@
1
- from typing import Callable
1
+ from collections.abc import Callable
2
2
 
3
3
  from sonusai.mixture import ClassCount
4
4
  from sonusai.mixture import MixtureDatabase
5
5
 
6
6
 
7
- def print_mixture_details(mixdb: MixtureDatabase,
8
- mixid: int = None,
9
- desc_len: int = 1,
10
- print_fn: Callable = print) -> None:
7
+ def print_mixture_details(
8
+ mixdb: MixtureDatabase,
9
+ mixid: int | None = None,
10
+ desc_len: int = 1,
11
+ print_fn: Callable = print,
12
+ ) -> None:
11
13
  import numpy as np
12
14
 
13
- from sonusai import SonusAIError
14
15
  from sonusai.mixture import SAMPLE_RATE
15
16
  from sonusai.utils import seconds_to_hms
16
17
 
17
18
  if mixid is not None:
18
19
  if 0 < mixid >= mixdb.num_mixtures:
19
- raise SonusAIError(f'Given mixid is outside valid range of 0:{mixdb.num_mixtures - 1}.')
20
+ raise ValueError(f"Given mixid is outside valid range of 0:{mixdb.num_mixtures - 1}.")
20
21
 
21
- print_fn(f'Mixture {mixid} details')
22
+ print_fn(f"Mixture {mixid} details")
22
23
  mixture = mixdb.mixture(mixid)
23
24
  target_files = [mixdb.target_files[target.file_id] for target in mixture.targets]
24
25
  target_augmentations = [target.augmentation for target in mixture.targets]
25
26
  noise_file = mixdb.noise_file(mixture.noise.file_id)
26
27
  for t_idx, target_file in enumerate(target_files):
27
- print_fn(f' Target {t_idx}')
28
+ print_fn(f" Target {t_idx}")
28
29
  print_fn(f'{" Name":{desc_len}} {target_file.name}')
29
30
  print_fn(f'{" Duration":{desc_len}} {seconds_to_hms(target_file.duration)}')
30
- for ts_idx, truth_setting in enumerate(target_file.truth_settings):
31
- print_fn(f' Truth setting {ts_idx}')
32
- print_fn(f'{" Index":{desc_len}} {truth_setting.index}')
33
- print_fn(f'{" Function":{desc_len}} {truth_setting.function}')
34
- print_fn(f'{" Config":{desc_len}} {truth_setting.config}')
31
+ for truth_name, truth_config in target_file.truth_configs.items():
32
+ print_fn(f" Truth config: '{truth_name}'")
33
+ print_fn(f'{" Function":{desc_len}} {truth_config.function}')
34
+ print_fn(f'{" Stride reduction":{desc_len}} {truth_config.stride_reduction}')
35
+ print_fn(f'{" Config":{desc_len}} {truth_config.config}')
35
36
  print_fn(f'{" Augmentation":{desc_len}} {target_augmentations[t_idx]}')
36
37
  print_fn(f'{" Samples":{desc_len}} {mixture.samples}')
37
38
  print_fn(f'{" Feature frames":{desc_len}} {mixdb.mixture_feature_frames(mixid)}')
@@ -42,18 +43,20 @@ def print_mixture_details(mixdb: MixtureDatabase,
42
43
  print_fn(f'{" Target gain":{desc_len}} {[target.gain for target in mixture.targets]}')
43
44
  print_fn(f'{" Target SNR gain":{desc_len}} {mixture.target_snr_gain}')
44
45
  print_fn(f'{" Noise SNR gain":{desc_len}} {mixture.noise_snr_gain}')
45
- print_fn('')
46
+ print_fn("")
46
47
 
47
48
 
48
- def print_class_count(class_count: ClassCount,
49
- length: int,
50
- print_fn: Callable = print,
51
- all_class_counts: bool = False) -> None:
49
+ def print_class_count(
50
+ class_count: ClassCount,
51
+ length: int,
52
+ print_fn: Callable = print,
53
+ all_class_counts: bool = False,
54
+ ) -> None:
52
55
  from sonusai.utils import max_text_width
53
56
 
54
- print_fn(f'Class count:')
57
+ print_fn("Class count:")
55
58
  idx_len = max_text_width(len(class_count))
56
59
  for idx, count in enumerate(class_count):
57
60
  if all_class_counts or count > 0:
58
- desc = f' class {idx + 1:{idx_len}}'
59
- print_fn(f'{desc:{length}} {count}')
61
+ desc = f" class {idx + 1:{idx_len}}"
62
+ print_fn(f"{desc:{length}} {count}")