sonusai 0.18.9__py3-none-any.whl → 0.19.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (118)
  1. sonusai/__init__.py +20 -29
  2. sonusai/aawscd_probwrite.py +18 -18
  3. sonusai/audiofe.py +93 -80
  4. sonusai/calc_metric_spenh.py +395 -321
  5. sonusai/data/genmixdb.yml +5 -11
  6. sonusai/{gentcst.py → deprecated/gentcst.py} +146 -149
  7. sonusai/{plot.py → deprecated/plot.py} +177 -131
  8. sonusai/{tplot.py → deprecated/tplot.py} +124 -102
  9. sonusai/doc/__init__.py +1 -1
  10. sonusai/doc/doc.py +112 -177
  11. sonusai/doc.py +10 -10
  12. sonusai/genft.py +93 -77
  13. sonusai/genmetrics.py +59 -46
  14. sonusai/genmix.py +116 -104
  15. sonusai/genmixdb.py +194 -153
  16. sonusai/lsdb.py +56 -66
  17. sonusai/main.py +23 -20
  18. sonusai/metrics/__init__.py +2 -0
  19. sonusai/metrics/calc_audio_stats.py +29 -24
  20. sonusai/metrics/calc_class_weights.py +7 -7
  21. sonusai/metrics/calc_optimal_thresholds.py +5 -7
  22. sonusai/metrics/calc_pcm.py +3 -3
  23. sonusai/metrics/calc_pesq.py +10 -7
  24. sonusai/metrics/calc_phase_distance.py +3 -3
  25. sonusai/metrics/calc_sa_sdr.py +10 -8
  26. sonusai/metrics/calc_segsnr_f.py +15 -17
  27. sonusai/metrics/calc_speech.py +105 -47
  28. sonusai/metrics/calc_wer.py +35 -32
  29. sonusai/metrics/calc_wsdr.py +10 -7
  30. sonusai/metrics/class_summary.py +30 -27
  31. sonusai/metrics/confusion_matrix_summary.py +25 -22
  32. sonusai/metrics/one_hot.py +91 -57
  33. sonusai/metrics/snr_summary.py +53 -46
  34. sonusai/mixture/__init__.py +19 -14
  35. sonusai/mixture/audio.py +4 -6
  36. sonusai/mixture/augmentation.py +37 -43
  37. sonusai/mixture/class_count.py +5 -14
  38. sonusai/mixture/config.py +292 -225
  39. sonusai/mixture/constants.py +41 -30
  40. sonusai/mixture/data_io.py +155 -0
  41. sonusai/mixture/datatypes.py +111 -108
  42. sonusai/mixture/db_datatypes.py +54 -70
  43. sonusai/mixture/eq_rule_is_valid.py +6 -9
  44. sonusai/mixture/feature.py +40 -38
  45. sonusai/mixture/generation.py +522 -389
  46. sonusai/mixture/helpers.py +217 -272
  47. sonusai/mixture/log_duration_and_sizes.py +16 -13
  48. sonusai/mixture/mixdb.py +669 -477
  49. sonusai/mixture/soundfile_audio.py +12 -17
  50. sonusai/mixture/sox_audio.py +91 -112
  51. sonusai/mixture/sox_augmentation.py +8 -9
  52. sonusai/mixture/spectral_mask.py +4 -6
  53. sonusai/mixture/target_class_balancing.py +41 -36
  54. sonusai/mixture/targets.py +69 -67
  55. sonusai/mixture/tokenized_shell_vars.py +23 -23
  56. sonusai/mixture/torchaudio_audio.py +14 -15
  57. sonusai/mixture/torchaudio_augmentation.py +23 -27
  58. sonusai/mixture/truth.py +48 -26
  59. sonusai/mixture/truth_functions/__init__.py +26 -0
  60. sonusai/mixture/truth_functions/crm.py +56 -38
  61. sonusai/mixture/truth_functions/datatypes.py +37 -0
  62. sonusai/mixture/truth_functions/energy.py +85 -59
  63. sonusai/mixture/truth_functions/file.py +30 -30
  64. sonusai/mixture/truth_functions/phoneme.py +14 -7
  65. sonusai/mixture/truth_functions/sed.py +71 -45
  66. sonusai/mixture/truth_functions/target.py +69 -106
  67. sonusai/mkwav.py +52 -85
  68. sonusai/onnx_predict.py +46 -43
  69. sonusai/queries/__init__.py +3 -1
  70. sonusai/queries/queries.py +100 -59
  71. sonusai/speech/__init__.py +2 -0
  72. sonusai/speech/l2arctic.py +24 -23
  73. sonusai/speech/librispeech.py +16 -17
  74. sonusai/speech/mcgill.py +22 -21
  75. sonusai/speech/textgrid.py +32 -25
  76. sonusai/speech/timit.py +45 -42
  77. sonusai/speech/vctk.py +14 -13
  78. sonusai/speech/voxceleb.py +26 -20
  79. sonusai/summarize_metric_spenh.py +11 -10
  80. sonusai/utils/__init__.py +4 -3
  81. sonusai/utils/asl_p56.py +1 -1
  82. sonusai/utils/asr.py +37 -17
  83. sonusai/utils/asr_functions/__init__.py +2 -0
  84. sonusai/utils/asr_functions/aaware_whisper.py +18 -12
  85. sonusai/utils/audio_devices.py +12 -12
  86. sonusai/utils/braced_glob.py +6 -8
  87. sonusai/utils/calculate_input_shape.py +1 -4
  88. sonusai/utils/compress.py +2 -2
  89. sonusai/utils/convert_string_to_number.py +1 -3
  90. sonusai/utils/create_timestamp.py +1 -1
  91. sonusai/utils/create_ts_name.py +2 -2
  92. sonusai/utils/dataclass_from_dict.py +1 -1
  93. sonusai/utils/docstring.py +6 -6
  94. sonusai/utils/energy_f.py +9 -7
  95. sonusai/utils/engineering_number.py +56 -54
  96. sonusai/utils/get_label_names.py +8 -10
  97. sonusai/utils/human_readable_size.py +2 -2
  98. sonusai/utils/model_utils.py +3 -5
  99. sonusai/utils/numeric_conversion.py +2 -4
  100. sonusai/utils/onnx_utils.py +43 -32
  101. sonusai/utils/parallel.py +40 -27
  102. sonusai/utils/print_mixture_details.py +25 -22
  103. sonusai/utils/ranges.py +12 -12
  104. sonusai/utils/read_predict_data.py +11 -9
  105. sonusai/utils/reshape.py +19 -26
  106. sonusai/utils/seconds_to_hms.py +1 -1
  107. sonusai/utils/stacked_complex.py +8 -16
  108. sonusai/utils/stratified_shuffle_split.py +29 -27
  109. sonusai/utils/write_audio.py +2 -2
  110. sonusai/utils/yes_or_no.py +3 -3
  111. sonusai/vars.py +14 -14
  112. {sonusai-0.18.9.dist-info → sonusai-0.19.5.dist-info}/METADATA +20 -21
  113. sonusai-0.19.5.dist-info/RECORD +125 -0
  114. {sonusai-0.18.9.dist-info → sonusai-0.19.5.dist-info}/WHEEL +1 -1
  115. sonusai/mixture/truth_functions/data.py +0 -58
  116. sonusai/utils/read_mixture_data.py +0 -14
  117. sonusai-0.18.9.dist-info/RECORD +0 -125
  118. {sonusai-0.18.9.dist-info → sonusai-0.19.5.dist-info}/entry_points.txt +0 -0
sonusai/mixture/target_class_balancing.py
@@ -5,14 +5,15 @@ from sonusai.mixture.datatypes import TargetFile
 from sonusai.mixture.datatypes import TargetFiles


-def balance_targets(augmented_targets: AugmentedTargets,
-                    targets: TargetFiles,
-                    target_augmentations: AugmentationRules,
-                    class_balancing_augmentation: AugmentationRule,
-                    num_classes: int,
-                    truth_mutex: bool,
-                    num_ir: int,
-                    mixups: list[int] = None) -> tuple[AugmentedTargets, AugmentationRules]:
+def balance_targets(
+    augmented_targets: AugmentedTargets,
+    targets: TargetFiles,
+    target_augmentations: AugmentationRules,
+    class_balancing_augmentation: AugmentationRule,
+    num_classes: int,
+    num_ir: int,
+    mixups: list[int] | None = None,
+) -> tuple[AugmentedTargets, AugmentationRules]:
     import math

     from .augmentation import get_mixups
@@ -34,14 +35,13 @@ def balance_targets(augmented_targets: AugmentedTargets,
         target_augmentations=target_augmentations,
         mixup=mixup,
         num_classes=num_classes,
-        truth_mutex=truth_mutex)
+    )

     largest = max([len(item) for item in augmented_target_indices_by_class])
     largest = math.ceil(largest / mixup) * mixup
     for at_indices in augmented_target_indices_by_class:
         additional_augmentations_needed = largest - len(at_indices)
-        target_ids = sorted(
-            list(set([augmented_targets[at_index].target_id for at_index in at_indices])))
+        target_ids = sorted({augmented_targets[at_index].target_id for at_index in at_indices})

         tfi_idx = 0
         for _ in range(additional_augmentations_needed):
@@ -55,50 +55,55 @@
                 target_id=target_id,
                 mixup=mixup,
                 num_ir=num_ir,
-                first_cba_id=first_cba_id)
-            augmented_target = AugmentedTarget(target_id=target_id,
-                                               target_augmentation_id=augmentation_index)
+                first_cba_id=first_cba_id,
+            )
+            augmented_target = AugmentedTarget(target_id=target_id, target_augmentation_id=augmentation_index)
             augmented_targets.append(augmented_target)

     return augmented_targets, target_augmentations


-def _get_unused_balancing_augmentation(augmented_targets: AugmentedTargets,
-                                       targets: TargetFiles,
-                                       target_augmentations: AugmentationRules,
-                                       class_balancing_augmentation: AugmentationRule,
-                                       target_id: int,
-                                       mixup: int,
-                                       num_ir: int,
-                                       first_cba_id: int) -> tuple[int, AugmentationRules]:
-    """Get an unused balancing augmentation for a given target file index
-    """
+def _get_unused_balancing_augmentation(
+    augmented_targets: AugmentedTargets,
+    targets: TargetFiles,
+    target_augmentations: AugmentationRules,
+    class_balancing_augmentation: AugmentationRule,
+    target_id: int,
+    mixup: int,
+    num_ir: int,
+    first_cba_id: int,
+) -> tuple[int, AugmentationRules]:
+    """Get an unused balancing augmentation for a given target file index"""
     from dataclasses import asdict

     from .augmentation import get_augmentation_rules

     balancing_augmentations = [item for item in range(len(target_augmentations)) if item >= first_cba_id]
-    used_balancing_augmentations = [at.target_augmentation_id for at in augmented_targets if
-                                    at.target_id == target_id and
-                                    at.target_augmentation_id in balancing_augmentations]
-
-    augmentation_indices = [item for item in balancing_augmentations if
-                            item not in used_balancing_augmentations and
-                            target_augmentations[item].mixup == mixup]
+    used_balancing_augmentations = [
+        at.target_augmentation_id
+        for at in augmented_targets
+        if at.target_id == target_id and at.target_augmentation_id in balancing_augmentations
+    ]
+
+    augmentation_indices = [
+        item
+        for item in balancing_augmentations
+        if item not in used_balancing_augmentations and target_augmentations[item].mixup == mixup
+    ]
     if len(augmentation_indices) > 0:
         return augmentation_indices[0], target_augmentations

-    class_balancing_augmentation = get_class_balancing_augmentation(target=targets[target_id],
-                                                                    default_cba=class_balancing_augmentation)
+    class_balancing_augmentation = get_class_balancing_augmentation(
+        target=targets[target_id], default_cba=class_balancing_augmentation
+    )
     new_augmentation = get_augmentation_rules(rules=asdict(class_balancing_augmentation), num_ir=num_ir)[0]
     new_augmentation.mixup = mixup
     target_augmentations.append(new_augmentation)
     return len(target_augmentations) - 1, target_augmentations


-def get_class_balancing_augmentation(target: TargetFile, default_cba: AugmentationRule) -> AugmentationRule | None:
-    """ Get the class balancing augmentation rule for the given target
-    """
+def get_class_balancing_augmentation(target: TargetFile, default_cba: AugmentationRule) -> AugmentationRule:
+    """Get the class balancing augmentation rule for the given target"""
     if target.class_balancing_augmentation is not None:
         return target.class_balancing_augmentation
     return default_cba
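
Aside from dropping truth_mutex and tightening the type hints, the balance_targets hunks replace the old sorted(list(set([...]))) idiom with a set comprehension. A minimal stand-alone illustration that the two forms are equivalent (FakeAugmentedTarget below is a hypothetical stand-in for sonusai's AugmentedTarget, not part of the package):

# Stand-alone illustration, not sonusai code.
from dataclasses import dataclass


@dataclass
class FakeAugmentedTarget:  # hypothetical stand-in for sonusai's AugmentedTarget
    target_id: int


augmented_targets = [FakeAugmentedTarget(3), FakeAugmentedTarget(1), FakeAugmentedTarget(3)]
at_indices = [0, 1, 2]

# Old style: flatten to a list, uniquify via set(), convert back, then sort.
old_style = sorted(list(set([augmented_targets[i].target_id for i in at_indices])))
# New style: the set comprehension deduplicates directly and sorted() returns a list.
new_style = sorted({augmented_targets[i].target_id for i in at_indices})
assert old_style == new_style == [1, 3]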
sonusai/mixture/targets.py
@@ -1,13 +1,14 @@
 from sonusai.mixture.datatypes import AugmentationRules
 from sonusai.mixture.datatypes import AugmentedTarget
 from sonusai.mixture.datatypes import AugmentedTargets
-from sonusai.mixture.datatypes import TargetFile
 from sonusai.mixture.datatypes import TargetFiles


-def get_augmented_targets(target_files: TargetFiles,
-                          target_augmentations: AugmentationRules,
-                          mixups: list[int] = None) -> AugmentedTargets:
+def get_augmented_targets(
+    target_files: TargetFiles,
+    target_augmentations: AugmentationRules,
+    mixups: list[int] | None = None,
+) -> AugmentedTargets:
     from .augmentation import get_augmentation_indices_for_mixup
     from .augmentation import get_mixups

@@ -19,85 +20,82 @@ def get_augmented_targets(target_files: TargetFiles,
         augmentation_indices = get_augmentation_indices_for_mixup(target_augmentations, mixup)
         for target_index in range(len(target_files)):
             for augmentation_index in augmentation_indices:
-                augmented_targets.append(AugmentedTarget(target_id=target_index,
-                                                         target_augmentation_id=augmentation_index))
+                augmented_targets.append(
+                    AugmentedTarget(
+                        target_id=target_index,
+                        target_augmentation_id=augmentation_index,
+                    )
+                )

     return augmented_targets


-def get_truth_indices_for_target(target: TargetFile) -> list[int]:
-    """Get a list of truth indices for a given target."""
-    index = [truth_setting.index for truth_setting in target.truth_settings]
-
-    # flatten, uniquify, and sort
-    return sorted(list(set([item for sublist in index for item in sublist])))
-
-
-def get_truth_indices_for_augmented_target(augmented_target: AugmentedTarget, targets: TargetFiles) -> list[int]:
-    return get_truth_indices_for_target(targets[augmented_target.target_id])
+def get_class_index_for_augmented_target(augmented_target: AugmentedTarget, targets: TargetFiles) -> list[int]:
+    return targets[augmented_target.target_id].class_indices


 def get_mixup_for_augmented_target(augmented_target: AugmentedTarget, augmentations: AugmentationRules) -> int:
     return augmentations[augmented_target.target_augmentation_id].mixup


-def get_target_ids_for_truth_index(targets: TargetFiles,
-                                   truth_index: int,
-                                   allow_multiple: bool = False) -> list[int]:
-    """Get a list of target indices containing the given truth index.
+def get_target_ids_for_class_index(targets: TargetFiles, class_index: int, allow_multiple: bool = False) -> list[int]:
+    """Get a list of target indices containing the given class index.

-    If allow_multiple is True, then include targets that contain multiple truth indices.
+    If allow_multiple is True, then include targets that contain multiple class indices.
     """
     target_indices = set()
     for target_index, target in enumerate(targets):
-        indices = get_truth_indices_for_target(target)
+        indices = target.class_indices
         if len(indices) == 1 or allow_multiple:
             for index in indices:
-                if index == truth_index + 1:
+                if index == class_index + 1:
                     target_indices.add(target_index)

-    return sorted(list(target_indices))
+    return sorted(target_indices)


-def get_augmented_target_ids_for_truth_index(augmented_targets: AugmentedTargets,
-                                             targets: TargetFiles,
-                                             augmentations: AugmentationRules,
-                                             truth_index: int,
-                                             mixup: int,
-                                             allow_multiple: bool = False) -> list[int]:
-    """Get a list of augmented target indices containing the given truth index.
+def get_augmented_target_ids_for_class_index(
+    augmented_targets: AugmentedTargets,
+    targets: TargetFiles,
+    augmentations: AugmentationRules,
+    class_index: int,
+    mixup: int,
+    allow_multiple: bool = False,
+) -> list[int]:
+    """Get a list of augmented target indices containing the given class index.

-    If allow_multiple is True, then include targets that contain multiple truth indices.
+    If allow_multiple is True, then include targets that contain multiple class indices.
     """
     augmented_target_ids = set()
     for augmented_target_id, augmented_target in enumerate(augmented_targets):
         if get_mixup_for_augmented_target(augmented_target=augmented_target, augmentations=augmentations) == mixup:
-            indices = get_truth_indices_for_augmented_target(augmented_target=augmented_target, targets=targets)
+            indices = get_class_index_for_augmented_target(augmented_target=augmented_target, targets=targets)
             if len(indices) == 1 or allow_multiple:
                 for index in indices:
-                    if index == truth_index + 1:
+                    if index == class_index + 1:
                         augmented_target_ids.add(augmented_target_id)

-    return sorted(list(augmented_target_ids))
+    return sorted(augmented_target_ids)


-def get_augmented_target_ids_by_class(augmented_targets: AugmentedTargets,
-                                      targets: TargetFiles,
-                                      target_augmentations: AugmentationRules,
-                                      mixup: int,
-                                      num_classes: int,
-                                      truth_mutex: bool) -> list[list[int]]:
-    if truth_mutex:
-        num_classes -= 1
-
+def get_augmented_target_ids_by_class(
+    augmented_targets: AugmentedTargets,
+    targets: TargetFiles,
+    target_augmentations: AugmentationRules,
+    mixup: int,
+    num_classes: int,
+) -> list[list[int]]:
     indices = []
     for idx in range(num_classes):
         indices.append(
-            get_augmented_target_ids_for_truth_index(augmented_targets=augmented_targets,
-                                                     targets=targets,
-                                                     augmentations=target_augmentations,
-                                                     truth_index=idx,
-                                                     mixup=mixup))
+            get_augmented_target_ids_for_class_index(
+                augmented_targets=augmented_targets,
+                targets=targets,
+                augmentations=target_augmentations,
+                class_index=idx,
+                mixup=mixup,
+            )
+        )
     return indices


@@ -111,36 +109,40 @@ def get_target_augmentations_for_mixup(target_augmentations: AugmentationRules,
     return [target_augmentation for target_augmentation in target_augmentations if target_augmentation.mixup == mixup]


-def get_augmented_target_ids_for_mixup(augmented_targets: AugmentedTargets,
-                                       targets: TargetFiles,
-                                       target_augmentations: AugmentationRules,
-                                       mixup: int,
-                                       num_classes: int,
-                                       truth_mutex: bool) -> list[list[int]]:
+def get_augmented_target_ids_for_mixup(
+    augmented_targets: AugmentedTargets,
+    targets: TargetFiles,
+    target_augmentations: AugmentationRules,
+    mixup: int,
+    num_classes: int,
+) -> list[list[int]]:
     from collections import deque
     from random import shuffle

-    from sonusai import SonusAIError
-
     mixup_indices = []

     if mixup == 1:
         for index, augmented_target in enumerate(augmented_targets):
-            if get_mixup_for_augmented_target(augmented_target=augmented_target,
-                                              augmentations=target_augmentations) == 1:
+            if (
+                get_mixup_for_augmented_target(
+                    augmented_target=augmented_target,
+                    augmentations=target_augmentations,
+                )
+                == 1
+            ):
                 mixup_indices.append([index])
         return mixup_indices

-    augmented_target_ids_by_class = get_augmented_target_ids_by_class(augmented_targets=augmented_targets,
-                                                                      targets=targets,
-                                                                      target_augmentations=target_augmentations,
-                                                                      mixup=mixup,
-                                                                      num_classes=num_classes,
-                                                                      truth_mutex=truth_mutex)
+    augmented_target_ids_by_class = get_augmented_target_ids_by_class(
+        augmented_targets=augmented_targets,
+        targets=targets,
+        target_augmentations=target_augmentations,
+        mixup=mixup,
+        num_classes=num_classes,
+    )

     if mixup > num_classes:
-        raise SonusAIError(
-            f'Specified mixup, {mixup}, is greater than the number of classes, {num_classes}')
+        raise ValueError(f"Specified mixup, {mixup}, is greater than the number of classes, {num_classes}")

     de: deque[int] = deque()

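
The targets.py hunks rename truth_index to class_index and read class_indices directly from each target file instead of deriving them from truth settings. A minimal stand-alone sketch of the 1-based matching rule used above, with a hypothetical FakeTargetFile stand-in rather than sonusai's real TargetFiles type:

# Stand-alone sketch, not the sonusai API; types are hypothetical stand-ins.
from dataclasses import dataclass


@dataclass
class FakeTargetFile:  # hypothetical stand-in for sonusai's TargetFile
    class_indices: list[int]


def target_ids_for_class_index(targets: list[FakeTargetFile], class_index: int, allow_multiple: bool = False) -> list[int]:
    target_ids = set()
    for target_id, target in enumerate(targets):
        indices = target.class_indices
        if len(indices) == 1 or allow_multiple:
            # class_indices are 1-based, while class_index is 0-based
            if class_index + 1 in indices:
                target_ids.add(target_id)
    return sorted(target_ids)


targets = [FakeTargetFile([1]), FakeTargetFile([2]), FakeTargetFile([1, 2])]
assert target_ids_for_class_index(targets, 0) == [0]
assert target_ids_for_class_index(targets, 1, allow_multiple=True) == [1, 2]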
sonusai/mixture/tokenized_shell_vars.py
@@ -23,10 +23,10 @@ def tokenized_expand(name: str | bytes | Path) -> tuple[str, dict[str, str]]:

     from .constants import DEFAULT_NOISE

-    os.environ['default_noise'] = str(DEFAULT_NOISE)
+    os.environ["default_noise"] = str(DEFAULT_NOISE)  # noqa: SIM112

     if isinstance(name, bytes):
-        name = name.decode('utf-8')
+        name = name.decode("utf-8")

     if isinstance(name, Path):
         name = name.as_posix()
@@ -34,37 +34,37 @@ def tokenized_expand(name: str | bytes | Path) -> tuple[str, dict[str, str]]:
     name = os.fspath(name)
     token_map: dict = {}

-    if '$' not in name and '%' not in name:
+    if "$" not in name and "%" not in name:
         return name, token_map

-    var_chars = string.ascii_letters + string.digits + '_-'
-    quote = '\''
-    percent = '%'
-    brace = '{'
-    rbrace = '}'
-    dollar = '$'
+    var_chars = string.ascii_letters + string.digits + "_-"
+    quote = "'"
+    percent = "%"
+    brace = "{"
+    rbrace = "}"
+    dollar = "$"
     environ = os.environ

     result = name[:0]
     index = 0
     path_len = len(name)
     while index < path_len:
-        c = name[index:index + 1]
+        c = name[index : index + 1]
         if c == quote:  # no expansion within single quotes
-            name = name[index + 1:]
+            name = name[index + 1 :]
             path_len = len(name)
             try:
                 index = name.index(c)
-                result += c + name[:index + 1]
+                result += c + name[: index + 1]
             except ValueError:
                 result += c + name
                 index = path_len - 1
         elif c == percent:  # variable or '%'
-            if name[index + 1:index + 2] == percent:
+            if name[index + 1 : index + 2] == percent:
                 result += c
                 index += 1
             else:
-                name = name[index + 1:]
+                name = name[index + 1 :]
                 path_len = len(name)
                 try:
                     index = name.index(percent)
@@ -75,7 +75,7 @@ def tokenized_expand(name: str | bytes | Path) -> tuple[str, dict[str, str]]:
                     var = name[:index]
                     try:
                         if environ is None:
-                            value = os.fsencode(os.environ[os.fsdecode(var)]).decode('utf-8')
+                            value = os.fsencode(os.environ[os.fsdecode(var)]).decode("utf-8")  # type: ignore[unreachable]
                         else:
                             value = environ[var]
                         token_map[var] = value
@@ -83,11 +83,11 @@ def tokenized_expand(name: str | bytes | Path) -> tuple[str, dict[str, str]]:
                         value = percent + var + percent
                     result += value
         elif c == dollar:  # variable or '$$'
-            if name[index + 1:index + 2] == dollar:
+            if name[index + 1 : index + 2] == dollar:
                 result += c
                 index += 1
-            elif name[index + 1:index + 2] == brace:
-                name = name[index + 2:]
+            elif name[index + 1 : index + 2] == brace:
+                name = name[index + 2 :]
                 path_len = len(name)
                 try:
                     index = name.index(rbrace)
@@ -98,7 +98,7 @@ def tokenized_expand(name: str | bytes | Path) -> tuple[str, dict[str, str]]:
                     var = name[:index]
                     try:
                         if environ is None:
-                            value = os.fsencode(os.environ[os.fsdecode(var)]).decode('utf-8')
+                            value = os.fsencode(os.environ[os.fsdecode(var)]).decode("utf-8")  # type: ignore[unreachable]
                         else:
                             value = environ[var]
                         token_map[var] = value
@@ -108,14 +108,14 @@ def tokenized_expand(name: str | bytes | Path) -> tuple[str, dict[str, str]]:
             else:
                 var = name[:0]
                 index += 1
-                c = name[index:index + 1]
+                c = name[index : index + 1]
                 while c and c in var_chars:
                     var += c
                     index += 1
-                    c = name[index:index + 1]
+                    c = name[index : index + 1]
                 try:
                     if environ is None:
-                        value = os.fsencode(os.environ[os.fsdecode(var)]).decode('utf-8')
+                        value = os.fsencode(os.environ[os.fsdecode(var)]).decode("utf-8")  # type: ignore[unreachable]
                     else:
                         value = environ[var]
                     token_map[var] = value
@@ -139,5 +139,5 @@ def tokenized_replace(name: str, tokens: dict[str, str]) -> str:
     :return: replaced string
     """
     for key, value in tokens.items():
-        name = name.replace(value, f'${key}')
+        name = name.replace(value, f"${key}")
     return name
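
For context on the tokenized_shell_vars.py changes above, a usage sketch of tokenized_expand and tokenized_replace. It assumes sonusai 0.19.5 and its dependencies are installed, and $MY_CORPUS is an environment variable invented purely for the example:

# Usage sketch; the import path follows the module layout in the file list above.
import os

from sonusai.mixture.tokenized_shell_vars import tokenized_expand, tokenized_replace

os.environ["MY_CORPUS"] = "/data/corpora"  # made-up variable for illustration

# Expand $-style tokens and remember which variables were substituted.
expanded, tokens = tokenized_expand("${MY_CORPUS}/speech/train.wav")
print(expanded)  # /data/corpora/speech/train.wav
print(tokens)    # {'MY_CORPUS': '/data/corpora'}

# tokenized_replace substitutes expanded values back to their $-token form.
print(tokenized_replace(expanded, tokens))  # $MY_CORPUS/speech/train.wav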
sonusai/mixture/torchaudio_audio.py
@@ -14,19 +14,18 @@ def read_impulse_response(name: str | Path) -> ImpulseResponseData:
     import torch
     import torchaudio

-    from sonusai import SonusAIError
     from .tokenized_shell_vars import tokenized_expand

     expanded_name, _ = tokenized_expand(name)

     # Read impulse response data from audio file
     try:
-        raw, sample_rate = torchaudio.load(expanded_name, backend='soundfile')
+        raw, sample_rate = torchaudio.load(expanded_name, backend="soundfile")
     except Exception as e:
         if name != expanded_name:
-            raise SonusAIError(f'Error reading {name} (expanded: {expanded_name}): {e}')
+            raise OSError(f"Error reading {name} (expanded: {expanded_name}): {e}") from e
         else:
-            raise SonusAIError(f'Error reading {name}: {e}')
+            raise OSError(f"Error reading {name}: {e}") from e

     raw = torch.squeeze(raw[0, :])
     offset = torch.argmax(raw)
@@ -49,7 +48,6 @@ def get_sample_rate(name: str | Path) -> int:
     """
     import torchaudio

-    from sonusai import SonusAIError
     from .tokenized_shell_vars import tokenized_expand

     expanded_name, _ = tokenized_expand(name)
@@ -58,9 +56,9 @@ def get_sample_rate(name: str | Path) -> int:
         return torchaudio.info(expanded_name).sample_rate
     except Exception as e:
         if name != expanded_name:
-            raise SonusAIError(f'Error reading {name} (expanded: {expanded_name}):\n{e}')
+            raise OSError(f"Error reading {name} (expanded: {expanded_name}):\n{e}") from e
         else:
-            raise SonusAIError(f'Error reading {name}:\n{e}')
+            raise OSError(f"Error reading {name}:\n{e}") from e


 def read_audio(name: str | Path) -> AudioT:
@@ -73,24 +71,25 @@ def read_audio(name: str | Path) -> AudioT:
     import torch
     import torchaudio

-    from sonusai import SonusAIError
     from .constants import SAMPLE_RATE
     from .tokenized_shell_vars import tokenized_expand

     expanded_name, _ = tokenized_expand(name)

     try:
-        out, samplerate = torchaudio.load(expanded_name, backend='soundfile')
+        out, samplerate = torchaudio.load(expanded_name, backend="soundfile")
         out = torch.reshape(out[0, :], (1, out.size()[1]))
-        out = torchaudio.functional.resample(out,
-                                             orig_freq=samplerate,
-                                             new_freq=SAMPLE_RATE,
-                                             resampling_method='sinc_interp_hann')
+        out = torchaudio.functional.resample(
+            out,
+            orig_freq=samplerate,
+            new_freq=SAMPLE_RATE,
+            resampling_method="sinc_interp_hann",
+        )
     except Exception as e:
         if name != expanded_name:
-            raise SonusAIError(f'Error reading {name} (expanded: {expanded_name}):\n{e}')
+            raise OSError(f"Error reading {name} (expanded: {expanded_name}):\n{e}") from e
         else:
-            raise SonusAIError(f'Error reading {name}:\n{e}')
+            raise OSError(f"Error reading {name}:\n{e}") from e

     result = np.squeeze(np.array(out))
     return result
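
The read_audio hunk above loads with the soundfile backend and resamples to the package's global rate. A minimal stand-alone sketch of the same load-and-resample pattern using torchaudio directly; "speech.wav" is a placeholder path, and 16000 is assumed to match sonusai's SAMPLE_RATE constant:

# Stand-alone sketch of the pattern above; placeholder path, assumed sample rate.
import numpy as np
import torch
import torchaudio

SAMPLE_RATE = 16000  # assumed global rate

out, orig_rate = torchaudio.load("speech.wav", backend="soundfile")
out = torch.reshape(out[0, :], (1, out.size()[1]))  # keep only the first channel
out = torchaudio.functional.resample(
    out,
    orig_freq=orig_rate,
    new_freq=SAMPLE_RATE,
    resampling_method="sinc_interp_hann",
)
audio = np.squeeze(np.array(out))  # 1-D samples at the target rate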
sonusai/mixture/torchaudio_augmentation.py
@@ -3,9 +3,7 @@ from sonusai.mixture.datatypes import Augmentation
 from sonusai.mixture.datatypes import ImpulseResponseData


-def apply_augmentation(audio: AudioT,
-                       augmentation: Augmentation,
-                       frame_length: int = 1) -> AudioT:
+def apply_augmentation(audio: AudioT, augmentation: Augmentation, frame_length: int = 1) -> AudioT:
     """Apply augmentations to audio data using torchaudio.sox_effects

     :param audio: Audio
@@ -17,7 +15,6 @@ def apply_augmentation(audio: AudioT,
     import torch
     import torchaudio

-    from sonusai import SonusAIError
     from .augmentation import pad_audio_to_frame
     from .constants import SAMPLE_RATE

@@ -28,29 +25,29 @@
     # Normalize to globally set level (should this be a global config parameter,
     # or hard-coded into the script?)
     if augmentation.normalize is not None:
-        effects.append(['norm', str(augmentation.normalize)])
+        effects.append(["norm", str(augmentation.normalize)])

     if augmentation.gain is not None:
-        effects.append(['gain', str(augmentation.gain)])
+        effects.append(["gain", str(augmentation.gain)])

     if augmentation.pitch is not None:
-        effects.append(['pitch', str(augmentation.pitch)])
-        effects.append(['rate', str(SAMPLE_RATE)])
+        effects.append(["pitch", str(augmentation.pitch)])
+        effects.append(["rate", str(SAMPLE_RATE)])

     if augmentation.tempo is not None:
-        effects.append(['tempo', '-s', str(augmentation.tempo)])
+        effects.append(["tempo", "-s", str(augmentation.tempo)])

     if augmentation.eq1 is not None:
-        effects.append(['equalizer', *[str(item) for item in augmentation.eq1]])
+        effects.append(["equalizer", *[str(item) for item in augmentation.eq1]])

     if augmentation.eq2 is not None:
-        effects.append(['equalizer', *[str(item) for item in augmentation.eq2]])
+        effects.append(["equalizer", *[str(item) for item in augmentation.eq2]])

     if augmentation.eq3 is not None:
-        effects.append(['equalizer', *[str(item) for item in augmentation.eq3]])
+        effects.append(["equalizer", *[str(item) for item in augmentation.eq3]])

     if augmentation.lpf is not None:
-        effects.append(['lowpass', '-2', str(augmentation.lpf), '0.707'])
+        effects.append(["lowpass", "-2", str(augmentation.lpf), "0.707"])

     if effects:
         if audio.ndim == 1:
@@ -58,11 +55,9 @@
         out = torch.tensor(audio)

         try:
-            out, _ = torchaudio.sox_effects.apply_effects_tensor(out,
-                                                                 sample_rate=SAMPLE_RATE,
-                                                                 effects=effects)
+            out, _ = torchaudio.sox_effects.apply_effects_tensor(out, sample_rate=SAMPLE_RATE, effects=effects)
         except Exception as e:
-            raise SonusAIError(f'Error applying {augmentation}: {e}')
+            raise RuntimeError(f"Error applying {augmentation}: {e}") from e

         audio_out = np.squeeze(np.array(out))
     else:
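
apply_augmentation builds a list of sox effects and hands it to torchaudio.sox_effects.apply_effects_tensor. A stand-alone sketch of that pattern with made-up effect values; it assumes a torchaudio build with sox support:

# Stand-alone sketch; effect values are illustrative, not sonusai defaults.
import numpy as np
import torch
import torchaudio

SAMPLE_RATE = 16000
audio = np.random.default_rng(0).uniform(-0.1, 0.1, SAMPLE_RATE).astype(np.float32)

effects = [
    ["norm", "-3"],              # normalize to -3 dBFS
    ["pitch", "100"],            # shift pitch up by 100 cents
    ["rate", str(SAMPLE_RATE)],  # pitch shifting changes the rate, so resample back
]

out = torch.reshape(torch.tensor(audio), (1, len(audio)))
out, _ = torchaudio.sox_effects.apply_effects_tensor(out, sample_rate=SAMPLE_RATE, effects=effects)
augmented = np.squeeze(np.array(out))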
@@ -84,6 +79,7 @@ def apply_impulse_response(audio: AudioT, ir: ImpulseResponseData) -> AudioT:
     import torchaudio

     from sonusai.utils import linear_to_db
+
     from .constants import SAMPLE_RATE

     # Early exit if no ir or if all audio is zero
@@ -95,20 +91,20 @@ def apply_impulse_response(audio: AudioT, ir: ImpulseResponseData) -> AudioT:

     # Convert audio to IR sample rate
     audio_in = torch.reshape(torch.tensor(audio), (1, len(audio)))
-    audio_out, sr = torchaudio.sox_effects.apply_effects_tensor(audio_in,
-                                                                sample_rate=SAMPLE_RATE,
-                                                                effects=[['rate', str(ir.sample_rate)]])
+    audio_out, sr = torchaudio.sox_effects.apply_effects_tensor(
+        audio_in, sample_rate=SAMPLE_RATE, effects=[["rate", str(ir.sample_rate)]]
+    )

     # Apply IR and convert back to global sample rate
     rir = torch.reshape(torch.tensor(ir.data), (1, len(ir.data)))
     audio_out = torchaudio.functional.fftconvolve(audio_out, rir)
-    audio_out, sr = torchaudio.sox_effects.apply_effects_tensor(audio_out,
-                                                                sample_rate=ir.sample_rate,
-                                                                effects=[['rate', str(SAMPLE_RATE)]])
+    audio_out, sr = torchaudio.sox_effects.apply_effects_tensor(
+        audio_out, sample_rate=ir.sample_rate, effects=[["rate", str(SAMPLE_RATE)]]
+    )

     # Reset level to previous max value
-    audio_out, sr = torchaudio.sox_effects.apply_effects_tensor(audio_out,
-                                                                sample_rate=SAMPLE_RATE,
-                                                                effects=[['norm', str(max_db)]])
+    audio_out, sr = torchaudio.sox_effects.apply_effects_tensor(
+        audio_out, sample_rate=SAMPLE_RATE, effects=[["norm", str(max_db)]]
+    )

-    return np.squeeze(np.array(audio_out[:, :len(audio)]))
+    return np.squeeze(np.array(audio_out[:, : len(audio)]))
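
apply_impulse_response convolves the resampled audio with the impulse response via torchaudio.functional.fftconvolve and trims the result back to the input length. A minimal stand-alone sketch of that convolution step with a synthetic impulse response; the rate conversion and level-restore steps shown in the diff are omitted here:

# Stand-alone sketch; the impulse response below is synthetic.
import torch
import torchaudio

audio = torch.rand(1, 16000) - 0.5  # 1 s of placeholder audio, shape (channels, samples)
rir = torch.zeros(1, 800)
rir[0, 0] = 1.0                     # unit impulse (direct path)
rir[0, 400] = 0.3                   # plus a single 25 ms echo at 16 kHz

wet = torchaudio.functional.fftconvolve(audio, rir)
wet = wet[:, : audio.shape[1]]      # trim the convolution tail to the input length
print(wet.shape)                    # torch.Size([1, 16000])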