sonusai 0.18.9__py3-none-any.whl → 0.19.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (118) hide show
  1. sonusai/__init__.py +20 -29
  2. sonusai/aawscd_probwrite.py +18 -18
  3. sonusai/audiofe.py +93 -80
  4. sonusai/calc_metric_spenh.py +395 -321
  5. sonusai/data/genmixdb.yml +5 -11
  6. sonusai/{gentcst.py → deprecated/gentcst.py} +146 -149
  7. sonusai/{plot.py → deprecated/plot.py} +177 -131
  8. sonusai/{tplot.py → deprecated/tplot.py} +124 -102
  9. sonusai/doc/__init__.py +1 -1
  10. sonusai/doc/doc.py +112 -177
  11. sonusai/doc.py +10 -10
  12. sonusai/genft.py +81 -91
  13. sonusai/genmetrics.py +51 -61
  14. sonusai/genmix.py +105 -115
  15. sonusai/genmixdb.py +201 -174
  16. sonusai/lsdb.py +56 -66
  17. sonusai/main.py +23 -20
  18. sonusai/metrics/__init__.py +2 -0
  19. sonusai/metrics/calc_audio_stats.py +29 -24
  20. sonusai/metrics/calc_class_weights.py +7 -7
  21. sonusai/metrics/calc_optimal_thresholds.py +5 -7
  22. sonusai/metrics/calc_pcm.py +3 -3
  23. sonusai/metrics/calc_pesq.py +10 -7
  24. sonusai/metrics/calc_phase_distance.py +3 -3
  25. sonusai/metrics/calc_sa_sdr.py +10 -8
  26. sonusai/metrics/calc_segsnr_f.py +16 -18
  27. sonusai/metrics/calc_speech.py +105 -47
  28. sonusai/metrics/calc_wer.py +35 -32
  29. sonusai/metrics/calc_wsdr.py +10 -7
  30. sonusai/metrics/class_summary.py +30 -27
  31. sonusai/metrics/confusion_matrix_summary.py +25 -22
  32. sonusai/metrics/one_hot.py +91 -57
  33. sonusai/metrics/snr_summary.py +53 -46
  34. sonusai/mixture/__init__.py +20 -14
  35. sonusai/mixture/audio.py +4 -6
  36. sonusai/mixture/augmentation.py +37 -43
  37. sonusai/mixture/class_count.py +5 -14
  38. sonusai/mixture/config.py +292 -225
  39. sonusai/mixture/constants.py +41 -30
  40. sonusai/mixture/data_io.py +155 -0
  41. sonusai/mixture/datatypes.py +111 -108
  42. sonusai/mixture/db_datatypes.py +54 -70
  43. sonusai/mixture/eq_rule_is_valid.py +6 -9
  44. sonusai/mixture/feature.py +40 -38
  45. sonusai/mixture/generation.py +522 -389
  46. sonusai/mixture/helpers.py +217 -272
  47. sonusai/mixture/log_duration_and_sizes.py +16 -13
  48. sonusai/mixture/mixdb.py +669 -477
  49. sonusai/mixture/soundfile_audio.py +12 -17
  50. sonusai/mixture/sox_audio.py +91 -112
  51. sonusai/mixture/sox_augmentation.py +8 -9
  52. sonusai/mixture/spectral_mask.py +4 -6
  53. sonusai/mixture/target_class_balancing.py +41 -36
  54. sonusai/mixture/targets.py +69 -67
  55. sonusai/mixture/tokenized_shell_vars.py +23 -23
  56. sonusai/mixture/torchaudio_audio.py +14 -15
  57. sonusai/mixture/torchaudio_augmentation.py +23 -27
  58. sonusai/mixture/truth.py +48 -26
  59. sonusai/mixture/truth_functions/__init__.py +26 -0
  60. sonusai/mixture/truth_functions/crm.py +56 -38
  61. sonusai/mixture/truth_functions/datatypes.py +37 -0
  62. sonusai/mixture/truth_functions/energy.py +85 -59
  63. sonusai/mixture/truth_functions/file.py +30 -30
  64. sonusai/mixture/truth_functions/phoneme.py +14 -7
  65. sonusai/mixture/truth_functions/sed.py +71 -45
  66. sonusai/mixture/truth_functions/target.py +69 -106
  67. sonusai/mkwav.py +58 -101
  68. sonusai/onnx_predict.py +46 -43
  69. sonusai/queries/__init__.py +3 -1
  70. sonusai/queries/queries.py +100 -59
  71. sonusai/speech/__init__.py +2 -0
  72. sonusai/speech/l2arctic.py +24 -23
  73. sonusai/speech/librispeech.py +16 -17
  74. sonusai/speech/mcgill.py +22 -21
  75. sonusai/speech/textgrid.py +32 -25
  76. sonusai/speech/timit.py +45 -42
  77. sonusai/speech/vctk.py +14 -13
  78. sonusai/speech/voxceleb.py +26 -20
  79. sonusai/summarize_metric_spenh.py +11 -10
  80. sonusai/utils/__init__.py +4 -3
  81. sonusai/utils/asl_p56.py +1 -1
  82. sonusai/utils/asr.py +37 -17
  83. sonusai/utils/asr_functions/__init__.py +2 -0
  84. sonusai/utils/asr_functions/aaware_whisper.py +18 -12
  85. sonusai/utils/audio_devices.py +12 -12
  86. sonusai/utils/braced_glob.py +6 -8
  87. sonusai/utils/calculate_input_shape.py +1 -4
  88. sonusai/utils/compress.py +2 -2
  89. sonusai/utils/convert_string_to_number.py +1 -3
  90. sonusai/utils/create_timestamp.py +1 -1
  91. sonusai/utils/create_ts_name.py +2 -2
  92. sonusai/utils/dataclass_from_dict.py +1 -1
  93. sonusai/utils/docstring.py +6 -6
  94. sonusai/utils/energy_f.py +9 -7
  95. sonusai/utils/engineering_number.py +56 -54
  96. sonusai/utils/get_label_names.py +8 -10
  97. sonusai/utils/human_readable_size.py +2 -2
  98. sonusai/utils/model_utils.py +3 -5
  99. sonusai/utils/numeric_conversion.py +2 -4
  100. sonusai/utils/onnx_utils.py +43 -32
  101. sonusai/utils/parallel.py +41 -30
  102. sonusai/utils/print_mixture_details.py +25 -22
  103. sonusai/utils/ranges.py +12 -12
  104. sonusai/utils/read_predict_data.py +11 -9
  105. sonusai/utils/reshape.py +19 -26
  106. sonusai/utils/seconds_to_hms.py +1 -1
  107. sonusai/utils/stacked_complex.py +8 -16
  108. sonusai/utils/stratified_shuffle_split.py +29 -27
  109. sonusai/utils/write_audio.py +2 -2
  110. sonusai/utils/yes_or_no.py +3 -3
  111. sonusai/vars.py +14 -14
  112. {sonusai-0.18.9.dist-info → sonusai-0.19.6.dist-info}/METADATA +20 -21
  113. sonusai-0.19.6.dist-info/RECORD +125 -0
  114. {sonusai-0.18.9.dist-info → sonusai-0.19.6.dist-info}/WHEEL +1 -1
  115. sonusai/mixture/truth_functions/data.py +0 -58
  116. sonusai/utils/read_mixture_data.py +0 -14
  117. sonusai-0.18.9.dist-info/RECORD +0 -125
  118. {sonusai-0.18.9.dist-info → sonusai-0.19.6.dist-info}/entry_points.txt +0 -0
sonusai/data/genmixdb.yml CHANGED
@@ -7,6 +7,8 @@ feature: ""
7
7
 
8
8
  target_level_type: default
9
9
 
10
+ class_indices: 1
11
+
10
12
  targets: [ ]
11
13
 
12
14
  num_classes: 1
@@ -17,15 +19,7 @@ seed: 0
17
19
 
18
20
  class_weights_threshold: 0.5
19
21
 
20
- truth_mode: normal
21
-
22
- truth_reduction_function: max
23
-
24
- truth_settings:
25
- function: sed
26
- config:
27
- thresholds: [ -38, -41, -48 ]
28
- index: 1
22
+ truth_configs: { }
29
23
 
30
24
  asr_manifest: [ ]
31
25
 
@@ -52,7 +46,7 @@ snrs:
52
46
 
53
47
  random_snrs: [ ]
54
48
 
55
- noise_mix_mode: 'exhaustive'
49
+ noise_mix_mode: exhaustive
56
50
 
57
51
  impulse_responses: [ ]
58
52
 
@@ -63,4 +57,4 @@ spectral_masks:
63
57
  t_num: 0
64
58
  t_max_percent: 100
65
59
 
66
- asr_configs: [ ]
60
+ asr_configs: { }
@@ -44,9 +44,9 @@ Outputs:
44
44
  gentcst.log
45
45
 
46
46
  """
47
+
47
48
  import signal
48
49
  from dataclasses import dataclass
49
- from typing import Optional
50
50
 
51
51
 
52
52
  def signal_handler(_sig, _frame):
@@ -54,13 +54,13 @@ def signal_handler(_sig, _frame):
54
54
 
55
55
  from sonusai import logger
56
56
 
57
- logger.info('Canceled due to keyboard interrupt')
57
+ logger.info("Canceled due to keyboard interrupt")
58
58
  sys.exit(1)
59
59
 
60
60
 
61
61
  signal.signal(signal.SIGINT, signal_handler)
62
62
 
63
- CONFIG_FILE = 'config.yml'
63
+ CONFIG_FILE = "config.yml"
64
64
 
65
65
 
66
66
  @dataclass
@@ -70,7 +70,7 @@ class FileInfo:
70
70
  leaf: str
71
71
  labels: list[str]
72
72
  fold: int
73
- truth_index: Optional[list[int]] = None
73
+ truth_index: list[int] | None = None
74
74
 
75
75
 
76
76
  @dataclass(frozen=True)
@@ -96,11 +96,13 @@ class OntologyInfo:
96
96
  restrictions: list[str]
97
97
 
98
98
 
99
- def gentcst(fold: str = '*',
100
- ontology: str = 'ontology.json',
101
- hierarchical: bool = False,
102
- update: bool = False,
103
- verbose: bool = False) -> tuple[dict, list[LabelInfo]]:
99
+ def gentcst(
100
+ fold: str = "*",
101
+ ontology: str = "ontology.json",
102
+ hierarchical: bool = False,
103
+ update: bool = False,
104
+ verbose: bool = False,
105
+ ) -> tuple[dict, list[LabelInfo]]:
104
106
  from copy import deepcopy
105
107
  from os import getcwd
106
108
  from os.path import dirname
@@ -114,25 +116,25 @@ def gentcst(fold: str = '*',
114
116
  from sonusai.mixture import get_default_config
115
117
  from sonusai.mixture import raw_load_config
116
118
  from sonusai.mixture import update_config_from_hierarchy
117
- from sonusai.mixture import update_truth_settings
119
+ from sonusai.mixture import validate_truth_configs
118
120
 
119
121
  update_console_handler(verbose)
120
- initial_log_messages('gentcst')
122
+ initial_log_messages("gentcst")
121
123
 
122
124
  if update:
123
- logger.info('Updating tree with JSON metadata')
124
- logger.info('')
125
+ logger.info("Updating tree with JSON metadata")
126
+ logger.info("")
125
127
 
126
- logger.debug(f'fold: {fold}')
127
- logger.debug(f'ontology: {ontology}')
128
- logger.debug(f'hierarchical: {hierarchical}')
129
- logger.debug(f'update: {update}')
130
- logger.debug('')
128
+ logger.debug(f"fold: {fold}")
129
+ logger.debug(f"ontology: {ontology}")
130
+ logger.debug(f"hierarchical: {hierarchical}")
131
+ logger.debug(f"update: {update}")
132
+ logger.debug("")
131
133
 
132
134
  config = get_config()
133
135
 
134
- if config['truth_mode'] == 'mutex' and hierarchical:
135
- raise SonusAIError('Multi-class truth is incompatible with truth_mode mutex')
136
+ if config["truth_mode"] == "mutex" and hierarchical:
137
+ raise SonusAIError("Multi-class truth is incompatible with truth_mode mutex")
136
138
 
137
139
  all_files = get_all_files(hierarchical)
138
140
  all_folds = get_folds_from_files(all_files)
@@ -143,52 +145,51 @@ def gentcst(fold: str = '*',
143
145
  ontology_data = validate_ontology(ontology, use_files)
144
146
 
145
147
  labels = get_labels_from_files(all_files)
146
- logger.debug('Truth indices:')
148
+ logger.debug("Truth indices:")
147
149
  for item in labels:
148
- logger.debug(f' {item.index:3} {item.display_name}')
149
- logger.debug('')
150
+ logger.debug(f" {item.index:3} {item.display_name}")
151
+ logger.debug("")
150
152
 
151
153
  gen_truth_indices(use_files, labels)
152
154
 
153
- config['num_classes'] = len(labels)
154
- if config['truth_mode'] == 'mutex':
155
- config['num_classes'] = config['num_classes'] + 1
155
+ config["num_classes"] = len(labels)
156
+ if config["truth_mode"] == "mutex":
157
+ config["num_classes"] = config["num_classes"] + 1
156
158
 
157
- config['targets'] = []
159
+ config["targets"] = []
158
160
 
159
- logger.info(f'gentcst {len(use_files)} entries in tree')
161
+ logger.info(f"gentcst {len(use_files)} entries in tree")
160
162
  root = getcwd()
161
163
  for file in use_files:
162
164
  leaf = dirname(file.file)
163
165
  local_config = get_default_config()
164
166
  local_config = update_config_from_hierarchy(root=root, leaf=leaf, config=local_config)
165
- if not isinstance(local_config['truth_settings'], list):
166
- local_config['truth_settings'] = [local_config['truth_settings']]
167
+ if not isinstance(local_config["truth_settings"], list):
168
+ local_config["truth_settings"] = [local_config["truth_settings"]]
167
169
 
168
- specific_name = splitext(file.file)[0] + '.yml'
170
+ specific_name = splitext(file.file)[0] + ".yml"
169
171
  if exists(specific_name):
170
172
  specific_config = raw_load_config(specific_name)
171
173
 
172
174
  for key in specific_config:
173
- if key != 'truth_settings':
175
+ if key != "truth_settings":
174
176
  local_config[key] = specific_config[key]
175
177
  else:
176
- local_config['truth_settings'] = update_truth_settings(given=specific_config['truth_settings'],
177
- default=local_config['truth_settings'])
178
+ local_config["truth_settings"] = validate_truth_configs(
179
+ given=specific_config["truth_settings"],
180
+ default=local_config["truth_settings"],
181
+ )
178
182
 
179
- truth_settings = deepcopy(local_config['truth_settings'])
183
+ truth_settings = deepcopy(local_config["truth_settings"])
180
184
  for idx, val in enumerate(truth_settings):
181
- val['index'] = file.truth_index
182
- for key in ['function', 'config']:
183
- if key in val and val[key] == config['truth_settings'][idx][key]:
185
+ val["index"] = file.truth_index
186
+ for key in ["function", "config"]:
187
+ if key in val and val[key] == config["truth_settings"][idx][key]:
184
188
  del val[key]
185
189
 
186
- target = {
187
- 'name': file.file,
188
- 'truth_settings': truth_settings
189
- }
190
+ target = {"name": file.file, "truth_settings": truth_settings}
190
191
 
191
- config['targets'].append(target)
192
+ config["targets"].append(target)
192
193
 
193
194
  if update:
194
195
  write_metadata_to_tree(ontology_data, file)
@@ -202,28 +203,30 @@ def get_ontology(ontology: str) -> list[OntologyInfo]:
202
203
 
203
204
  raw_ontology_data = []
204
205
  if exists(ontology):
205
- with open(file=ontology, mode='r', encoding='utf-8') as f:
206
+ with open(file=ontology, encoding="utf-8") as f:
206
207
  raw_ontology_data = json.load(f)
207
208
 
208
- return [OntologyInfo(
209
- id=item['id'],
210
- name=convert_ontology_name(item['name']),
211
- description=item['description'],
212
- citation_uri=item['citation_uri'],
213
- positive_examples=item['positive_examples'],
214
- child_ids=item['child_ids'],
215
- restrictions=item['restrictions']
216
- ) for item in raw_ontology_data]
209
+ return [
210
+ OntologyInfo(
211
+ id=item["id"],
212
+ name=convert_ontology_name(item["name"]),
213
+ description=item["description"],
214
+ citation_uri=item["citation_uri"],
215
+ positive_examples=item["positive_examples"],
216
+ child_ids=item["child_ids"],
217
+ restrictions=item["restrictions"],
218
+ )
219
+ for item in raw_ontology_data
220
+ ]
217
221
 
218
222
 
219
223
  def convert_ontology_name(name: str) -> list[str]:
220
- """Convert names to lowercase, convert spaces to '-', split into lists on ','.
221
- """
224
+ """Convert names to lowercase, convert spaces to '-', split into lists on ','."""
222
225
  text = name.lower()
223
226
  # Need to find ', ' before converting spaces so that we don't convert ', ' to ',-'
224
- text = text.replace(', ', ',')
225
- text = text.replace(' ', '-')
226
- return text.split(',')
227
+ text = text.replace(", ", ",")
228
+ text = text.replace(" ", "-")
229
+ return text.split(",")
227
230
 
228
231
 
229
232
  def get_dirs() -> list[DirInfo]:
@@ -235,13 +238,13 @@ def get_dirs() -> list[DirInfo]:
235
238
  cwd = getcwd()
236
239
  for root, ds, _ in walk(cwd):
237
240
  for d in ds:
238
- match.append(join(root, d).replace(cwd, '.'))
241
+ match.append(join(root, d).replace(cwd, "."))
239
242
  match = sorted(match)
240
243
 
241
244
  dirs: list[DirInfo] = []
242
245
  for m in match:
243
- if not m.startswith('./gentcst') and m != '.':
244
- classes = m.split('/')
246
+ if not m.startswith("./gentcst") and m != ".":
247
+ classes = m.split("/")
245
248
  # Remove first element because it is '.'
246
249
  classes.pop(0)
247
250
  dirs.append(DirInfo(file=m, classes=classes))
@@ -250,7 +253,7 @@ def get_dirs() -> list[DirInfo]:
250
253
 
251
254
 
252
255
  def get_leaf_from_classes(classes: list[str]) -> str:
253
- return '/'.join(classes)
256
+ return "/".join(classes)
254
257
 
255
258
 
256
259
  def get_all_files(hierarchical: bool) -> list[FileInfo]:
@@ -261,20 +264,20 @@ def get_all_files(hierarchical: bool) -> list[FileInfo]:
261
264
 
262
265
  file_list: list[str] = []
263
266
  cwd = getcwd()
264
- regex = re.compile(r'^\d+-.*\.txt$')
267
+ regex = re.compile(r"^\d+-.*\.txt$")
265
268
  for root, _, fs in walk(cwd):
266
269
  for f in fs:
267
270
  if regex.match(f):
268
- file_list.append(join(root, f).replace(cwd, '.'))
271
+ file_list.append(join(root, f).replace(cwd, "."))
269
272
 
270
273
  files: list[FileInfo] = []
271
- fold_pattern = re.compile(r'.*/(\d+)-.*\.txt')
274
+ fold_pattern = re.compile(r".*/(\d+)-.*\.txt")
272
275
  for file in file_list:
273
276
  fold_match = fold_pattern.match(file)
274
277
  if fold_match:
275
278
  fold = int(fold_match.group(1))
276
279
 
277
- classes = file.split('/')
280
+ classes = file.split("/")
278
281
  # Remove first element because it is '.'
279
282
  classes.pop(0)
280
283
  # Remove last element because it is the file name
@@ -287,11 +290,7 @@ def get_all_files(hierarchical: bool) -> list[FileInfo]:
287
290
  else:
288
291
  labels = [leaf]
289
292
 
290
- files.append(FileInfo(file=file,
291
- classes=classes,
292
- leaf=leaf,
293
- labels=labels,
294
- fold=fold))
293
+ files.append(FileInfo(file=file, classes=classes, leaf=leaf, labels=labels, fold=fold))
295
294
 
296
295
  return files
297
296
 
@@ -304,13 +303,13 @@ def get_use_folds(all_folds: list[int], fold: str) -> list[int]:
304
303
  req_folds: list[int] = get_req_folds(all_folds, fold)
305
304
  use_folds: list[int] = list(set(req_folds).intersection(all_folds))
306
305
 
307
- logger.debug('Fold information')
308
- logger.debug(f' Available: {all_folds}')
309
- logger.debug(f' Requested: {req_folds}')
310
- logger.debug(f' Used: {use_folds}')
311
- logger.debug(f' Unused: {list(set(use_folds).symmetric_difference(all_folds))}')
312
- logger.debug(f' Missing: {list(np.setdiff1d(req_folds, all_folds))}')
313
- logger.debug('')
306
+ logger.debug("Fold information")
307
+ logger.debug(f" Available: {all_folds}")
308
+ logger.debug(f" Requested: {req_folds}")
309
+ logger.debug(f" Used: {use_folds}")
310
+ logger.debug(f" Unused: {list(set(use_folds).symmetric_difference(all_folds))}")
311
+ logger.debug(f" Missing: {list(np.setdiff1d(req_folds, all_folds))}")
312
+ logger.debug("")
314
313
 
315
314
  return use_folds
316
315
 
@@ -319,15 +318,15 @@ def get_files_from_folds(files: list[FileInfo], folds: list[int]) -> list[FileIn
319
318
  return [file for file in files if file.fold in folds]
320
319
 
321
320
 
322
- def get_item_from_name(ontology_data: list[OntologyInfo], name: str) -> OntologyInfo:
321
+ def get_item_from_name(ontology_data: list[OntologyInfo], name: str) -> OntologyInfo | None:
323
322
  return next((item for item in ontology_data if name in item.name), None)
324
323
 
325
324
 
326
- def get_item_from_id(ontology_data: list[OntologyInfo], identity: str) -> OntologyInfo:
325
+ def get_item_from_id(ontology_data: list[OntologyInfo], identity: str) -> OntologyInfo | None:
327
326
  return next((item for item in ontology_data if item.id == identity), None)
328
327
 
329
328
 
330
- def get_name_from_id(ontology_data: list[OntologyInfo], identity: str) -> list[str]:
329
+ def get_name_from_id(ontology_data: list[OntologyInfo], identity: str) -> list[str] | None:
331
330
  from sonusai import SonusAIError
332
331
 
333
332
  name = None
@@ -336,7 +335,7 @@ def get_name_from_id(ontology_data: list[OntologyInfo], identity: str) -> list[s
336
335
  for item in ontology_data:
337
336
  if item.id == identity:
338
337
  if found:
339
- raise SonusAIError(f'id {identity} appears multiple times in ontology')
338
+ raise SonusAIError(f"id {identity} appears multiple times in ontology")
340
339
 
341
340
  name = item.name
342
341
  found = True
@@ -344,7 +343,7 @@ def get_name_from_id(ontology_data: list[OntologyInfo], identity: str) -> list[s
344
343
  return name
345
344
 
346
345
 
347
- def get_id_from_name(ontology_data: list[OntologyInfo], name: str) -> str:
346
+ def get_id_from_name(ontology_data: list[OntologyInfo], name: str) -> str | None:
348
347
  from sonusai import SonusAIError
349
348
 
350
349
  identity = None
@@ -353,7 +352,7 @@ def get_id_from_name(ontology_data: list[OntologyInfo], name: str) -> str:
353
352
  for item in ontology_data:
354
353
  if name in item.name:
355
354
  if found:
356
- raise SonusAIError(f'name {name} appears multiple times in ontology')
355
+ raise SonusAIError(f"name {name} appears multiple times in ontology")
357
356
 
358
357
  identity = item.id
359
358
  found = True
@@ -362,10 +361,7 @@ def get_id_from_name(ontology_data: list[OntologyInfo], name: str) -> str:
362
361
 
363
362
 
364
363
  def is_valid_name(ontology_data: list[OntologyInfo], name: str) -> bool:
365
- for item in ontology_data:
366
- if name in item.name:
367
- return True
368
- return False
364
+ return any(name in item.name for item in ontology_data)
369
365
 
370
366
 
371
367
  def is_valid_child(ontology_data: list[OntologyInfo], parent: str, child: str) -> bool:
@@ -373,17 +369,17 @@ def is_valid_child(ontology_data: list[OntologyInfo], parent: str, child: str) -
373
369
  parent_item = get_item_from_name(ontology_data, parent)
374
370
  child_id = get_id_from_name(ontology_data, child)
375
371
 
376
- if child_id is not None and parent_item is not None:
377
- if child_id in parent_item.child_ids:
378
- valid = True
372
+ if child_id is not None and parent_item is not None and child_id in parent_item.child_ids:
373
+ valid = True
379
374
 
380
375
  return valid
381
376
 
382
377
 
383
378
  def is_valid_hierarchy(ontology_data: list[OntologyInfo], classes: list[str]) -> bool:
379
+ from itertools import pairwise
384
380
  valid = True
385
381
 
386
- for parent, child in zip(classes, classes[1:]):
382
+ for parent, child in pairwise(classes):
387
383
  if not is_valid_child(ontology_data, parent, child):
388
384
  valid = False
389
385
 
@@ -397,13 +393,12 @@ def validate_class(ontology_data: list[OntologyInfo], item: FileInfo | DirInfo)
397
393
 
398
394
  for c in item.classes:
399
395
  if not is_valid_name(ontology_data, c):
400
- logger.warning(f' Could not find {c} in ontology for {item.file}')
396
+ logger.warning(f" Could not find {c} in ontology for {item.file}")
401
397
  valid = False
402
398
 
403
- if valid:
404
- if not is_valid_hierarchy(ontology_data, item.classes):
405
- logger.warning(f' Invalid parent/child relationship for {item.file}')
406
- valid = False
399
+ if valid and not is_valid_hierarchy(ontology_data, item.classes):
400
+ logger.warning(f" Invalid parent/child relationship for {item.file}")
401
+ valid = False
407
402
 
408
403
  return valid
409
404
 
@@ -411,7 +406,7 @@ def validate_class(ontology_data: list[OntologyInfo], item: FileInfo | DirInfo)
411
406
  def get_req_folds(folds: list[int], fold: str) -> list[int]:
412
407
  from sonusai.utils import expand_range
413
408
 
414
- if fold == '*':
409
+ if fold == "*":
415
410
  return folds
416
411
 
417
412
  return expand_range(fold)
@@ -420,13 +415,13 @@ def get_req_folds(folds: list[int], fold: str) -> list[int]:
420
415
  def get_folds_from_files(files: list[FileInfo]) -> list[int]:
421
416
  result: list[int] = [file.fold for file in files]
422
417
  # Converting to set and back to list ensures uniqueness
423
- return sorted(list(set(result)))
418
+ return sorted(set(result))
424
419
 
425
420
 
426
421
  def get_leaves_from_files(files: list[FileInfo]) -> list[str]:
427
422
  result: list[str] = [file.leaf for file in files]
428
423
  # Converting to set and back to list ensures uniqueness
429
- return sorted(list(set(result)))
424
+ return sorted(set(result))
430
425
 
431
426
 
432
427
  def get_labels_from_files(files: list[FileInfo]) -> list[LabelInfo]:
@@ -440,7 +435,7 @@ def get_labels_from_files(files: list[FileInfo]) -> list[LabelInfo]:
440
435
 
441
436
  for n in range(len(labels_by_depth)):
442
437
  # Converting to set and back to list ensures uniqueness
443
- labels_by_depth[n] = sorted(list(set(labels_by_depth[n])))
438
+ labels_by_depth[n] = sorted(set(labels_by_depth[n]))
444
439
 
445
440
  # We want the deepest leaves first
446
441
  labels_by_depth.reverse()
@@ -471,21 +466,21 @@ def get_config() -> dict:
471
466
  from sonusai.mixture import raw_load_config
472
467
 
473
468
  if not exists(CONFIG_FILE):
474
- raise SonusAIError(f'No {CONFIG_FILE} at top level')
469
+ raise SonusAIError(f"No {CONFIG_FILE} at top level")
475
470
 
476
471
  config = raw_load_config(CONFIG_FILE)
477
472
 
478
- if 'feature' not in config:
479
- raise SonusAIError('feature not in top level config')
473
+ if "feature" not in config:
474
+ raise SonusAIError("feature not in top level config")
480
475
 
481
- if 'truth_mode' not in config:
482
- raise SonusAIError('truth_mode not in top level config')
476
+ if "truth_mode" not in config:
477
+ raise SonusAIError("truth_mode not in top level config")
483
478
 
484
- if config['truth_mode'] not in ['normal', 'mutex']:
485
- raise SonusAIError('Invalid truth_mode in top level config')
479
+ if config["truth_mode"] not in ["normal", "mutex"]:
480
+ raise SonusAIError("Invalid truth_mode in top level config")
486
481
 
487
- if 'truth_settings' in config and not isinstance(config['truth_settings'], list):
488
- config['truth_settings'] = [config['truth_settings']]
482
+ if "truth_settings" in config and not isinstance(config["truth_settings"], list):
483
+ config["truth_settings"] = [config["truth_settings"]]
489
484
 
490
485
  return config
491
486
 
@@ -497,34 +492,34 @@ def validate_ontology(ontology: str, items: list[FileInfo] | list[DirInfo]) -> l
497
492
 
498
493
  if exists(ontology):
499
494
  ontology_data = get_ontology(ontology)
500
- logger.debug(f'Reference ontology in {ontology} has {len(ontology_data)} classes')
501
- logger.debug('')
495
+ logger.debug(f"Reference ontology in {ontology} has {len(ontology_data)} classes")
496
+ logger.debug("")
502
497
 
503
- logger.info('Checking tree against reference ontology')
498
+ logger.info("Checking tree against reference ontology")
504
499
  all_dirs = get_dirs()
505
500
  valid = True
506
501
  for file in all_dirs:
507
502
  if not validate_class(ontology_data, file):
508
503
  valid = False
509
504
  if valid:
510
- logger.info('PASS')
511
- logger.info('')
505
+ logger.info("PASS")
506
+ logger.info("")
512
507
 
513
- logger.info('Checking files against reference ontology')
508
+ logger.info("Checking files against reference ontology")
514
509
  valid = True
515
510
  for item in items:
516
511
  if not validate_class(ontology_data, item):
517
512
  valid = False
518
513
  if valid:
519
- logger.info('PASS')
520
- logger.info('')
514
+ logger.info("PASS")
515
+ logger.info("")
521
516
 
522
517
  return ontology_data
523
518
 
524
519
  return []
525
520
 
526
521
 
527
- def get_node_from_name(ontology_data: list[OntologyInfo], name: str) -> Optional[OntologyInfo]:
522
+ def get_node_from_name(ontology_data: list[OntologyInfo], name: str) -> OntologyInfo | None:
528
523
  from sonusai import logger
529
524
 
530
525
  nodes = [item for item in ontology_data if name in item.name]
@@ -532,9 +527,9 @@ def get_node_from_name(ontology_data: list[OntologyInfo], name: str) -> Optional
532
527
  return nodes[0]
533
528
 
534
529
  if nodes:
535
- logger.warning(f'Found multiple entries in reference ontology that match {name}')
530
+ logger.warning(f"Found multiple entries in reference ontology that match {name}")
536
531
  else:
537
- logger.warning(f'Could not find entry for {name} in reference ontology')
532
+ logger.warning(f"Could not find entry for {name} in reference ontology")
538
533
 
539
534
  return None
540
535
 
@@ -549,8 +544,8 @@ def write_metadata_to_tree(ontology_data: list[OntologyInfo], file: FileInfo):
549
544
  node = get_node_from_name(ontology_data, file.classes[-1])
550
545
  if node is not None:
551
546
  dir_name = dirname(file.file)
552
- json_name = dir_name + '/' + basename(dir_name) + '.json'
553
- with open(file=json_name, mode='w') as f:
547
+ json_name = dir_name + "/" + basename(dir_name) + ".json"
548
+ with open(file=json_name, mode="w") as f:
554
549
  json.dump(asdict(node), f)
555
550
 
556
551
 
@@ -565,19 +560,19 @@ def report_leaf_fold_data_usage(all_files: list[FileInfo], use_files: list[FileI
565
560
  use_leaves = get_leaves_from_files(use_files)
566
561
  all_leaves = get_leaves_from_files(all_files)
567
562
 
568
- logger.debug('Data folds present in each leaf')
563
+ logger.debug("Data folds present in each leaf")
569
564
  leaf_len = len(max(all_leaves, key=len))
570
565
  for leaf in all_leaves:
571
566
  folds = get_folds_from_leaf(all_files, leaf)
572
- logger.debug(f' {leaf:{leaf_len}} {folds}')
573
- logger.debug('')
567
+ logger.debug(f" {leaf:{leaf_len}} {folds}")
568
+ logger.debug("")
574
569
 
575
570
  dif_leaves = set(all_leaves).symmetric_difference(use_leaves)
576
571
  if dif_leaves:
577
- logger.warning('This fold selection is missing data from the following leaves')
572
+ logger.warning("This fold selection is missing data from the following leaves")
578
573
  for c in dif_leaves:
579
- logger.warning(f' {c}')
580
- logger.warning('')
574
+ logger.warning(f" {c}")
575
+ logger.warning("")
581
576
 
582
577
 
583
578
  def main() -> None:
@@ -588,12 +583,12 @@ def main() -> None:
588
583
 
589
584
  args = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)
590
585
 
591
- verbose = args['--verbose']
592
- output_name = args['--output']
593
- fold = args['--fold']
594
- ontology = args['--ontology']
595
- hierarchical = args['--hierarchical']
596
- update = args['--update']
586
+ verbose = args["--verbose"]
587
+ output_name = args["--output"]
588
+ fold = args["--fold"]
589
+ ontology = args["--ontology"]
590
+ hierarchical = args["--hierarchical"]
591
+ update = args["--update"]
597
592
 
598
593
  import csv
599
594
  from dataclasses import asdict
@@ -607,29 +602,31 @@ def main() -> None:
607
602
  from sonusai import logger
608
603
 
609
604
  if not output_name:
610
- output_name = basename(getcwd()) + '.yml'
605
+ output_name = basename(getcwd()) + ".yml"
611
606
 
612
- create_file_handler('gentcst.log')
607
+ create_file_handler("gentcst.log")
613
608
 
614
- config, labels = gentcst(fold=fold,
615
- ontology=ontology,
616
- hierarchical=hierarchical,
617
- update=update,
618
- verbose=verbose)
609
+ config, labels = gentcst(
610
+ fold=fold,
611
+ ontology=ontology,
612
+ hierarchical=hierarchical,
613
+ update=update,
614
+ verbose=verbose,
615
+ )
619
616
 
620
- with open(file=output_name, mode='w') as f:
617
+ with open(file=output_name, mode="w") as f:
621
618
  yaml.dump(config, f)
622
- logger.info(f'Wrote config to {output_name}')
619
+ logger.info(f"Wrote config to {output_name}")
623
620
 
624
- csv_fields = ['index', 'display_name']
625
- csv_name = splitext(output_name)[0] + '.csv'
621
+ csv_fields = ["index", "display_name"]
622
+ csv_name = splitext(output_name)[0] + ".csv"
626
623
 
627
- with open(file=csv_name, mode='w') as f:
624
+ with open(file=csv_name, mode="w") as f:
628
625
  writer = csv.DictWriter(f, fieldnames=csv_fields)
629
626
  writer.writeheader()
630
627
  writer.writerows([asdict(label) for label in labels])
631
- logger.info(f'Wrote labels to {csv_name}')
628
+ logger.info(f"Wrote labels to {csv_name}")
632
629
 
633
630
 
634
- if __name__ == '__main__':
631
+ if __name__ == "__main__":
635
632
  main()