sonusai 1.0.16__cp311-abi3-macosx_10_12_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150) hide show
  1. sonusai/__init__.py +170 -0
  2. sonusai/aawscd_probwrite.py +148 -0
  3. sonusai/audiofe.py +481 -0
  4. sonusai/calc_metric_spenh.py +1136 -0
  5. sonusai/config/__init__.py +0 -0
  6. sonusai/config/asr.py +21 -0
  7. sonusai/config/config.py +65 -0
  8. sonusai/config/config.yml +49 -0
  9. sonusai/config/constants.py +53 -0
  10. sonusai/config/ir.py +124 -0
  11. sonusai/config/ir_delay.py +62 -0
  12. sonusai/config/source.py +275 -0
  13. sonusai/config/spectral_masks.py +15 -0
  14. sonusai/config/truth.py +64 -0
  15. sonusai/constants.py +14 -0
  16. sonusai/data/__init__.py +0 -0
  17. sonusai/data/silero_vad_v5.1.jit +0 -0
  18. sonusai/data/silero_vad_v5.1.onnx +0 -0
  19. sonusai/data/speech_ma01_01.wav +0 -0
  20. sonusai/data/whitenoise.wav +0 -0
  21. sonusai/datatypes.py +383 -0
  22. sonusai/deprecated/gentcst.py +632 -0
  23. sonusai/deprecated/plot.py +519 -0
  24. sonusai/deprecated/tplot.py +365 -0
  25. sonusai/doc.py +52 -0
  26. sonusai/doc_strings/__init__.py +1 -0
  27. sonusai/doc_strings/doc_strings.py +531 -0
  28. sonusai/genft.py +196 -0
  29. sonusai/genmetrics.py +183 -0
  30. sonusai/genmix.py +199 -0
  31. sonusai/genmixdb.py +235 -0
  32. sonusai/ir_metric.py +551 -0
  33. sonusai/lsdb.py +141 -0
  34. sonusai/main.py +134 -0
  35. sonusai/metrics/__init__.py +43 -0
  36. sonusai/metrics/calc_audio_stats.py +42 -0
  37. sonusai/metrics/calc_class_weights.py +90 -0
  38. sonusai/metrics/calc_optimal_thresholds.py +73 -0
  39. sonusai/metrics/calc_pcm.py +45 -0
  40. sonusai/metrics/calc_pesq.py +36 -0
  41. sonusai/metrics/calc_phase_distance.py +43 -0
  42. sonusai/metrics/calc_sa_sdr.py +64 -0
  43. sonusai/metrics/calc_sample_weights.py +25 -0
  44. sonusai/metrics/calc_segsnr_f.py +82 -0
  45. sonusai/metrics/calc_speech.py +382 -0
  46. sonusai/metrics/calc_wer.py +71 -0
  47. sonusai/metrics/calc_wsdr.py +57 -0
  48. sonusai/metrics/calculate_metrics.py +395 -0
  49. sonusai/metrics/class_summary.py +74 -0
  50. sonusai/metrics/confusion_matrix_summary.py +75 -0
  51. sonusai/metrics/one_hot.py +283 -0
  52. sonusai/metrics/snr_summary.py +128 -0
  53. sonusai/metrics_summary.py +314 -0
  54. sonusai/mixture/__init__.py +15 -0
  55. sonusai/mixture/audio.py +187 -0
  56. sonusai/mixture/class_balancing.py +103 -0
  57. sonusai/mixture/constants.py +3 -0
  58. sonusai/mixture/data_io.py +173 -0
  59. sonusai/mixture/db.py +169 -0
  60. sonusai/mixture/db_datatypes.py +92 -0
  61. sonusai/mixture/effects.py +344 -0
  62. sonusai/mixture/feature.py +78 -0
  63. sonusai/mixture/generation.py +1116 -0
  64. sonusai/mixture/helpers.py +351 -0
  65. sonusai/mixture/ir_effects.py +77 -0
  66. sonusai/mixture/log_duration_and_sizes.py +23 -0
  67. sonusai/mixture/mixdb.py +1857 -0
  68. sonusai/mixture/pad_audio.py +35 -0
  69. sonusai/mixture/resample.py +7 -0
  70. sonusai/mixture/sox_effects.py +195 -0
  71. sonusai/mixture/sox_help.py +650 -0
  72. sonusai/mixture/spectral_mask.py +51 -0
  73. sonusai/mixture/truth.py +61 -0
  74. sonusai/mixture/truth_functions/__init__.py +45 -0
  75. sonusai/mixture/truth_functions/crm.py +105 -0
  76. sonusai/mixture/truth_functions/energy.py +222 -0
  77. sonusai/mixture/truth_functions/file.py +48 -0
  78. sonusai/mixture/truth_functions/metadata.py +24 -0
  79. sonusai/mixture/truth_functions/metrics.py +28 -0
  80. sonusai/mixture/truth_functions/phoneme.py +18 -0
  81. sonusai/mixture/truth_functions/sed.py +98 -0
  82. sonusai/mixture/truth_functions/target.py +142 -0
  83. sonusai/mkwav.py +135 -0
  84. sonusai/onnx_predict.py +363 -0
  85. sonusai/parse/__init__.py +0 -0
  86. sonusai/parse/expand.py +156 -0
  87. sonusai/parse/parse_source_directive.py +129 -0
  88. sonusai/parse/rand.py +214 -0
  89. sonusai/py.typed +0 -0
  90. sonusai/queries/__init__.py +0 -0
  91. sonusai/queries/queries.py +239 -0
  92. sonusai/rs.abi3.so +0 -0
  93. sonusai/rs.pyi +1 -0
  94. sonusai/rust/__init__.py +0 -0
  95. sonusai/speech/__init__.py +0 -0
  96. sonusai/speech/l2arctic.py +121 -0
  97. sonusai/speech/librispeech.py +102 -0
  98. sonusai/speech/mcgill.py +71 -0
  99. sonusai/speech/textgrid.py +89 -0
  100. sonusai/speech/timit.py +138 -0
  101. sonusai/speech/types.py +12 -0
  102. sonusai/speech/vctk.py +53 -0
  103. sonusai/speech/voxceleb.py +108 -0
  104. sonusai/utils/__init__.py +3 -0
  105. sonusai/utils/asl_p56.py +130 -0
  106. sonusai/utils/asr.py +91 -0
  107. sonusai/utils/asr_functions/__init__.py +3 -0
  108. sonusai/utils/asr_functions/aaware_whisper.py +69 -0
  109. sonusai/utils/audio_devices.py +50 -0
  110. sonusai/utils/braced_glob.py +50 -0
  111. sonusai/utils/calculate_input_shape.py +26 -0
  112. sonusai/utils/choice.py +51 -0
  113. sonusai/utils/compress.py +25 -0
  114. sonusai/utils/convert_string_to_number.py +6 -0
  115. sonusai/utils/create_timestamp.py +5 -0
  116. sonusai/utils/create_ts_name.py +14 -0
  117. sonusai/utils/dataclass_from_dict.py +27 -0
  118. sonusai/utils/db.py +16 -0
  119. sonusai/utils/docstring.py +53 -0
  120. sonusai/utils/energy_f.py +44 -0
  121. sonusai/utils/engineering_number.py +166 -0
  122. sonusai/utils/evaluate_random_rule.py +15 -0
  123. sonusai/utils/get_frames_per_batch.py +2 -0
  124. sonusai/utils/get_label_names.py +20 -0
  125. sonusai/utils/grouper.py +6 -0
  126. sonusai/utils/human_readable_size.py +7 -0
  127. sonusai/utils/keyboard_interrupt.py +12 -0
  128. sonusai/utils/load_object.py +21 -0
  129. sonusai/utils/max_text_width.py +9 -0
  130. sonusai/utils/model_utils.py +28 -0
  131. sonusai/utils/numeric_conversion.py +11 -0
  132. sonusai/utils/onnx_utils.py +155 -0
  133. sonusai/utils/parallel.py +162 -0
  134. sonusai/utils/path_info.py +7 -0
  135. sonusai/utils/print_mixture_details.py +60 -0
  136. sonusai/utils/rand.py +13 -0
  137. sonusai/utils/ranges.py +43 -0
  138. sonusai/utils/read_predict_data.py +32 -0
  139. sonusai/utils/reshape.py +154 -0
  140. sonusai/utils/seconds_to_hms.py +7 -0
  141. sonusai/utils/stacked_complex.py +82 -0
  142. sonusai/utils/stratified_shuffle_split.py +170 -0
  143. sonusai/utils/tokenized_shell_vars.py +143 -0
  144. sonusai/utils/write_audio.py +26 -0
  145. sonusai/utils/yes_or_no.py +8 -0
  146. sonusai/vars.py +47 -0
  147. sonusai-1.0.16.dist-info/METADATA +56 -0
  148. sonusai-1.0.16.dist-info/RECORD +150 -0
  149. sonusai-1.0.16.dist-info/WHEEL +4 -0
  150. sonusai-1.0.16.dist-info/entry_points.txt +3 -0
@@ -0,0 +1,632 @@
1
+ """sonusai gentcst
2
+
3
+ usage: gentcst [-hvru] [-o OUTPUT] [-f FOLD] [-y ONTOLOGY]
4
+
5
+ options:
6
+ -h, --help
7
+ -v, --verbose Be verbose.
8
+ -o OUTPUT, --output OUTPUT Output file name.
9
+ -f FOLD, --fold FOLD Fold(s) to include. [default: *].
10
+ -y ONTOLOGY, --ontology ONTOLOGY Reference ontology JSON file for cross-check and adding metadata.
11
+ [default: ontology.json].
12
+ -r, --hierarchical Generate hierarchical multiclass truth for non-leaf nodes.
13
+ -u, --update Write JSON metadata into the tree.
14
+
15
+ Generate a target SonusAI configuration file from a hierarchical subdirectory tree
16
+ under the local directory. Leaves in the subdirectory tree define classes by providing SonusAI list
17
+ .txt files for class data and additional truth config definitions. This offers a way to simplify and
18
+ more clearly manage hierarchical class structure similar to the Audioset ontology.
19
+
20
+ Features:
21
+ - automatically compiles the target configuration with the number of classes and each truth
22
+ index and the associated files (note: files should use absolute path with env variable)
23
+ - supports hierarchical multilabel class truth generation, where higher nodes in the tree
24
+ are included with leaf truth index, and different leaves with the same name will have the
25
+ same index (i.e. door/squeak, chair/squeak, and brief-tone/squeak)
26
+ - supports config overrides for class config.yml or filename.yml for specific file config
27
+ - manages folds via file prefix naming convention, i.e. 1-*.txt, 2-*.txt, etc.
28
+ specifies data for fold 1, 2, etc.
29
+ - cross-check node definitions with the reference ontology and auto copy class metadata into
30
+ the tree to help define/collect data
31
+ - class label is also generated (in the .yml or referencing a .csv)
32
+
33
+ Inputs:
34
+ The local subdirectory tree with expected top-level configuration file config.yml. This
35
+ file must at least define the feature and can have additional parameters treated as default
36
+ (for truth, augmentation, etc.) which will be included in the generated output yaml file.
37
+ An entire class dataset config parameter can be overridden by including it in a config.yml
38
+ file in the class subdirectory. Individual file config parameters can be further overridden
39
+ or specified in a file of the same name but with .yml (i.e. 1-data.yml with the 1-data.txt)
40
+
41
+ Outputs:
42
+ output.yml
43
+ output.csv
44
+ gentcst.log
45
+
46
+ """
47
+
48
+ import signal
49
+ from dataclasses import dataclass
50
+
51
+
52
def signal_handler(_sig, _frame):
    """Log a cancellation notice and terminate the process with exit status 1.

    Both handler arguments (signal number and stack frame) are ignored.
    """
    from sys import exit as terminate

    from sonusai import logger

    logger.info("Canceled due to keyboard interrupt")
    terminate(1)
59
+
60
+
61
+ signal.signal(signal.SIGINT, signal_handler)
62
+
63
+ CONFIG_FILE = "config.yml"
64
+
65
+
66
+ @dataclass
67
+ class FileInfo:
68
+ file: str
69
+ classes: list[str]
70
+ leaf: str
71
+ labels: list[str]
72
+ fold: int
73
+ truth_index: list[int] | None = None
74
+
75
+
76
@dataclass(frozen=True)
class DirInfo:
    """Immutable record describing one directory of the tree."""

    # Path of the directory, written relative to the working directory.
    file: str
    # The path components below the working directory.
    classes: list[str]
80
+
81
+
82
@dataclass(frozen=True)
class LabelInfo:
    """Immutable pairing of a truth index with the label name shown for it."""

    # Truth index assigned to the label.
    index: int
    # Name of the label.
    display_name: str
86
+
87
+
88
@dataclass(frozen=True)
class OntologyInfo:
    """Immutable record holding one entry of the reference ontology."""

    # Identifier of the entry.
    id: str
    # Names of the entry (a list, since one entry may carry several names).
    name: list[str]
    # Description text of the entry.
    description: str
    # Citation URI of the entry.
    citation_uri: str
    # Positive examples listed for the entry.
    positive_examples: list[str]
    # Identifiers of the entry's children.
    child_ids: list[str]
    # Restrictions listed for the entry.
    restrictions: list[str]
97
+
98
+
99
def gentcst(
    fold: str = "*",
    ontology: str = "ontology.json",
    hierarchical: bool = False,
    update: bool = False,
    verbose: bool = False,
) -> tuple[dict, list[LabelInfo]]:
    """Build a target configuration from the class tree under the working directory.

    Args:
        fold: Fold selection; "*" selects every fold present in the tree, any
            other value is expanded by get_req_folds().
        ontology: Path of the reference ontology JSON file.
        hierarchical: When True every directory level of a file becomes a label,
            otherwise only the joined leaf path does.
        update: When True the matching ontology entry is written as JSON next to
            each used file.
        verbose: Forwarded to update_console_handler().

    Returns:
        The generated configuration dictionary and the list of labels.

    Raises:
        SonusAIError: If hierarchical truth is requested together with
            truth_mode "mutex", or if get_config() rejects the top-level config.
    """
    from copy import deepcopy
    from os import getcwd
    from os.path import dirname
    from os.path import exists
    from os.path import splitext

    from sonusai import SonusAIError
    from sonusai import initial_log_messages
    from sonusai import logger
    from sonusai import update_console_handler

    # NOTE(review): the package file listing that accompanies this source shows
    # sonusai/config/config.py but no sonusai/mixture/config.py -- confirm that
    # these three imports still resolve.
    from sonusai.mixture.config import get_default_config
    from sonusai.mixture.config import raw_load_config
    from sonusai.mixture.config import update_config_from_hierarchy
    from sonusai.config.truth import validate_truth_configs

    update_console_handler(verbose)
    initial_log_messages("gentcst")

    if update:
        logger.info("Updating tree with JSON metadata")
        logger.info("")

    logger.debug(f"fold: {fold}")
    logger.debug(f"ontology: {ontology}")
    logger.debug(f"hierarchical: {hierarchical}")
    logger.debug(f"update: {update}")
    logger.debug("")

    config = get_config()

    if config["truth_mode"] == "mutex" and hierarchical:
        raise SonusAIError("Multi-class truth is incompatible with truth_mode mutex")

    all_files = get_all_files(hierarchical)
    all_folds = get_folds_from_files(all_files)
    use_folds = get_use_folds(all_folds, fold)
    use_files = get_files_from_folds(all_files, use_folds)
    report_leaf_fold_data_usage(all_files, use_files)

    ontology_data = validate_ontology(ontology, use_files)

    # Labels come from every file (not only the selected folds).
    labels = get_labels_from_files(all_files)
    logger.debug("Truth indices:")
    for item in labels:
        logger.debug(f" {item.index:3} {item.display_name}")
    logger.debug("")

    gen_truth_indices(use_files, labels)

    config["num_classes"] = len(labels)
    if config["truth_mode"] == "mutex":
        # mutex mode reserves one additional class.
        config["num_classes"] = config["num_classes"] + 1

    config["targets"] = []

    # The top-level truth_settings entry is optional (get_config() only
    # normalizes it when present), so it must not be indexed blindly.
    top_truth_settings = config.get("truth_settings", [])

    logger.info(f"gentcst {len(use_files)} entries in tree")
    root = getcwd()
    for file in use_files:
        leaf = dirname(file.file)
        local_config = get_default_config()
        local_config = update_config_from_hierarchy(root=root, leaf=leaf, config=local_config)
        if not isinstance(local_config["truth_settings"], list):
            local_config["truth_settings"] = [local_config["truth_settings"]]

        # A .yml file with the same stem as the list file overrides the class config.
        specific_name = splitext(file.file)[0] + ".yml"
        if exists(specific_name):
            specific_config = raw_load_config(specific_name)

            for key in specific_config:
                if key != "truth_settings":
                    local_config[key] = specific_config[key]
                else:
                    local_config["truth_settings"] = validate_truth_configs(
                        given=specific_config["truth_settings"],
                        default=local_config["truth_settings"],
                    )

        truth_settings = deepcopy(local_config["truth_settings"])
        for idx, val in enumerate(truth_settings):
            val["index"] = file.truth_index
            top_val = top_truth_settings[idx] if idx < len(top_truth_settings) else {}
            for key in ("function", "config"):
                # Drop entries that merely repeat the top-level setting; keep
                # them when the top level has nothing to compare against.
                if key in val and key in top_val and val[key] == top_val[key]:
                    del val[key]

        target = {"name": file.file, "truth_settings": truth_settings}

        config["targets"].append(target)

        if update:
            write_metadata_to_tree(ontology_data, file)

    return config, labels
198
+
199
+
200
def get_ontology(ontology: str) -> list[OntologyInfo]:
    """Load the ontology JSON file into a list of OntologyInfo records.

    Returns an empty list when the file does not exist.
    """
    import json
    from os.path import exists

    if exists(ontology):
        with open(file=ontology, encoding="utf-8") as handle:
            entries = json.load(handle)
    else:
        entries = []

    records: list[OntologyInfo] = []
    for entry in entries:
        record = OntologyInfo(
            id=entry["id"],
            name=convert_ontology_name(entry["name"]),
            description=entry["description"],
            citation_uri=entry["citation_uri"],
            positive_examples=entry["positive_examples"],
            child_ids=entry["child_ids"],
            restrictions=entry["restrictions"],
        )
        records.append(record)

    return records
221
+
222
+
223
def convert_ontology_name(name: str) -> list[str]:
    """Lowercase the name, split it on ',' and turn the remaining spaces into '-'."""
    # Collapse ', ' to ',' before touching spaces so the separator's own space
    # never turns into a leading '-'.
    pieces = name.lower().replace(", ", ",").split(",")
    return [piece.replace(" ", "-") for piece in pieces]
230
+
231
+
232
def get_dirs() -> list[DirInfo]:
    """Return a DirInfo for every directory below the working directory.

    Paths are reported as './a/b' and the result is sorted by that path.
    Directories whose relative path starts with './gentcst' are skipped.
    """
    from os import getcwd
    from os import sep
    from os import walk
    from os.path import join
    from os.path import relpath

    cwd = getcwd()

    # relpath() removes only the leading working directory. A plain
    # str.replace(cwd, ".") would also rewrite later occurrences of the same
    # text inside the path.
    found: list[str] = []
    for root, names, _ in walk(cwd):
        for name in names:
            relative = relpath(join(root, name), cwd)
            found.append("./" + relative.replace(sep, "/"))
    found.sort()

    dirs: list[DirInfo] = []
    for path in found:
        if path.startswith("./gentcst"):
            continue
        # Everything after the leading '.' is the class hierarchy.
        dirs.append(DirInfo(file=path, classes=path.split("/")[1:]))

    return dirs
253
+
254
+
255
def get_leaf_from_classes(classes: list[str]) -> str:
    """Join the class names with '/' to form the leaf path."""
    separator = "/"
    return separator.join(classes)
257
+
258
+
259
def get_all_files(hierarchical: bool) -> list[FileInfo]:
    """Collect every '<fold>-*.txt' list file below the working directory.

    Args:
        hierarchical: When True each directory level of a file is a label,
            otherwise the joined leaf path is the file's only label.

    Returns:
        One FileInfo per matching file, with its path written as './a/b/1-x.txt'.
    """
    import re
    from os import getcwd
    from os import sep
    from os import walk
    from os.path import relpath

    cwd = getcwd()
    # The numeric prefix of the file name is the fold number.
    fold_file = re.compile(r"^(\d+)-.*\.txt$")

    files: list[FileInfo] = []
    for root, _, names in walk(cwd):
        # relpath() removes only the leading working directory. A plain
        # str.replace(cwd, ".") would also rewrite later occurrences of the
        # same text inside the path.
        relative = relpath(root, cwd)
        parts = [] if relative == "." else relative.split(sep)

        for name in names:
            found = fold_file.match(name)
            if found is None:
                continue

            # Each file gets its own list so records never share state.
            classes = list(parts)
            leaf = get_leaf_from_classes(classes)
            labels = classes if hierarchical else [leaf]

            files.append(
                FileInfo(
                    file="/".join([".", *classes, name]),
                    classes=classes,
                    leaf=leaf,
                    labels=labels,
                    fold=int(found.group(1)),
                )
            )

    return files
296
+
297
+
298
def get_use_folds(all_folds: list[int], fold: str) -> list[int]:
    """Return the requested folds that are actually present, in ascending order.

    Also logs the available, requested, used, unused and missing folds at
    debug level.

    Args:
        all_folds: Folds present in the tree.
        fold: Fold selection passed on to get_req_folds().
    """
    from sonusai import logger

    req_folds: list[int] = get_req_folds(all_folds, fold)

    # Plain set arithmetic keeps the logged values ordinary ints and needs no
    # numpy import.
    available = set(all_folds)
    requested = set(req_folds)
    use_folds: list[int] = sorted(requested & available)

    logger.debug("Fold information")
    logger.debug(f" Available: {all_folds}")
    logger.debug(f" Requested: {req_folds}")
    logger.debug(f" Used: {use_folds}")
    logger.debug(f" Unused: {sorted(available - requested)}")
    logger.debug(f" Missing: {sorted(requested - available)}")
    logger.debug("")

    return use_folds
315
+
316
+
317
def get_files_from_folds(files: list[FileInfo], folds: list[int]) -> list[FileInfo]:
    """Keep the files whose fold is one of the given folds, preserving order."""
    selected: list[FileInfo] = []
    for candidate in files:
        if candidate.fold in folds:
            selected.append(candidate)
    return selected
319
+
320
+
321
def get_item_from_name(ontology_data: list[OntologyInfo], name: str) -> OntologyInfo | None:
    """Return the first ontology entry whose names include the given name, else None."""
    for entry in ontology_data:
        if name in entry.name:
            return entry
    return None
323
+
324
+
325
def get_item_from_id(ontology_data: list[OntologyInfo], identity: str) -> OntologyInfo | None:
    """Return the first ontology entry with the given id, else None."""
    for entry in ontology_data:
        if entry.id == identity:
            return entry
    return None
327
+
328
+
329
def get_name_from_id(ontology_data: list[OntologyInfo], identity: str) -> list[str] | None:
    """Return the names of the ontology entry with the given id, or None if absent.

    Raises:
        SonusAIError: If more than one entry carries the id.
    """
    from sonusai import SonusAIError

    matching_names = [entry.name for entry in ontology_data if entry.id == identity]

    if len(matching_names) > 1:
        raise SonusAIError(f"id {identity} appears multiple times in ontology")

    return matching_names[0] if matching_names else None
344
+
345
+
346
def get_id_from_name(ontology_data: list[OntologyInfo], name: str) -> str | None:
    """Return the id of the ontology entry whose names include the name, or None.

    Raises:
        SonusAIError: If more than one entry includes the name.
    """
    from sonusai import SonusAIError

    matching_ids = [entry.id for entry in ontology_data if name in entry.name]

    if len(matching_ids) > 1:
        raise SonusAIError(f"name {name} appears multiple times in ontology")

    return matching_ids[0] if matching_ids else None
361
+
362
+
363
def is_valid_name(ontology_data: list[OntologyInfo], name: str) -> bool:
    """Tell whether any ontology entry lists the given name."""
    for entry in ontology_data:
        if name in entry.name:
            return True
    return False
365
+
366
+
367
def is_valid_child(ontology_data: list[OntologyInfo], parent: str, child: str) -> bool:
    """Tell whether the child's id is listed among the parent entry's child ids.

    False when either name cannot be resolved in the ontology.
    """
    parent_entry = get_item_from_name(ontology_data, parent)
    child_identity = get_id_from_name(ontology_data, child)

    if child_identity is None or parent_entry is None:
        return False

    return child_identity in parent_entry.child_ids
376
+
377
+
378
def is_valid_hierarchy(ontology_data: list[OntologyInfo], classes: list[str]) -> bool:
    """Tell whether every consecutive (parent, child) pair of classes is valid."""
    from itertools import pairwise

    # Build the full list first so that every pair is checked, with no
    # short-circuit on the first failure.
    outcomes = [is_valid_child(ontology_data, upper, lower) for upper, lower in pairwise(classes)]
    return all(outcomes)
387
+
388
+
389
def validate_class(ontology_data: list[OntologyInfo], item: FileInfo | DirInfo) -> bool:
    """Check an item's classes against the ontology, logging a warning per problem.

    Every class name must exist in the ontology; only when all do is the
    parent/child chain checked as well.
    """
    from sonusai import logger

    unknown = [c for c in item.classes if not is_valid_name(ontology_data, c)]
    for c in unknown:
        logger.warning(f" Could not find {c} in ontology for {item.file}")
    if unknown:
        return False

    if not is_valid_hierarchy(ontology_data, item.classes):
        logger.warning(f" Invalid parent/child relationship for {item.file}")
        return False

    return True
404
+
405
+
406
def get_req_folds(folds: list[int], fold: str) -> list[int]:
    """Return the given folds for '*', otherwise the expansion of the fold string."""
    from sonusai.utils.ranges import expand_range

    return folds if fold == "*" else expand_range(fold)
413
+
414
+
415
def get_folds_from_files(files: list[FileInfo]) -> list[int]:
    """Return the distinct fold numbers of the files in ascending order."""
    # A set comprehension removes the duplicates before sorting.
    return sorted({entry.fold for entry in files})
419
+
420
+
421
def get_leaves_from_files(files: list[FileInfo]) -> list[str]:
    """Return the distinct leaf paths of the files in sorted order."""
    # A set comprehension removes the duplicates before sorting.
    return sorted({entry.leaf for entry in files})
425
+
426
+
427
def get_labels_from_files(files: list[FileInfo]) -> list[LabelInfo]:
    """Assign a 1-based index to every distinct label, deepest level first.

    Label names are grouped by their position in each file's label list,
    de-duplicated and sorted within each position, and then numbered starting
    with the deepest position.

    Returns:
        The labels in index order; an empty list when there are no files or
        no labels.
    """
    # default=0 keeps an empty tree from raising ValueError in max().
    depth = max((len(file.labels) for file in files), default=0)

    names_by_depth: list[set[str]] = [set() for _ in range(depth)]
    for file in files:
        for level, name in enumerate(file.labels):
            names_by_depth[level].add(name)

    # Deepest level first, names sorted within a level.
    ordered = [name for names in reversed(names_by_depth) for name in sorted(names)]

    return [LabelInfo(index=index, display_name=name) for index, name in enumerate(ordered, start=1)]
451
+
452
+
453
def gen_truth_indices(files: list[FileInfo], labels: list[LabelInfo]):
    """Set each file's truth_index to the sorted indices of its labels (in place)."""
    for entry in files:
        indices = [get_index_from_label(labels, name) for name in entry.labels]
        indices.sort()
        entry.truth_index = indices
456
+
457
+
458
def get_index_from_label(labels: list[LabelInfo], label: str) -> int:
    """Return the index of the first label whose display name equals the given label.

    Raises:
        SonusAIError: If no label has that display name.
    """
    from sonusai import SonusAIError

    for item in labels:
        if item.display_name == label:
            return item.index

    # Fail with an explicit message rather than an attribute error on None.
    raise SonusAIError(f"Could not find label {label} in the list of labels")
460
+
461
+
462
def get_config() -> dict:
    """Load and validate the top-level configuration file.

    A truth_settings entry that is not a list is wrapped in one.

    Raises:
        SonusAIError: If the file is missing, lacks 'feature' or 'truth_mode',
            or has a truth_mode other than 'normal' or 'mutex'.
    """
    from os.path import exists

    from sonusai import SonusAIError
    from sonusai.mixture import raw_load_config

    if not exists(CONFIG_FILE):
        raise SonusAIError(f"No {CONFIG_FILE} at top level")

    config = raw_load_config(CONFIG_FILE)

    for required in ("feature", "truth_mode"):
        if required not in config:
            raise SonusAIError(f"{required} not in top level config")

    if config["truth_mode"] not in ("normal", "mutex"):
        raise SonusAIError("Invalid truth_mode in top level config")

    if "truth_settings" in config:
        settings = config["truth_settings"]
        if not isinstance(settings, list):
            config["truth_settings"] = [settings]

    return config
486
+
487
+
488
def validate_ontology(ontology: str, items: list[FileInfo] | list[DirInfo]) -> list[OntologyInfo]:
    """Check the directory tree and the given items against the reference ontology.

    Logs 'PASS' after each of the two checks when nothing was flagged.

    Returns:
        The loaded ontology entries, or an empty list when the ontology file
        does not exist (in which case nothing is checked).
    """
    from os.path import exists

    from sonusai import logger

    if not exists(ontology):
        return []

    ontology_data = get_ontology(ontology)
    logger.debug(f"Reference ontology in {ontology} has {len(ontology_data)} classes")
    logger.debug("")

    logger.info("Checking tree against reference ontology")
    # Lists (not generators) so every entry is validated and reported.
    tree_results = [validate_class(ontology_data, entry) for entry in get_dirs()]
    if all(tree_results):
        logger.info("PASS")
    logger.info("")

    logger.info("Checking files against reference ontology")
    item_results = [validate_class(ontology_data, entry) for entry in items]
    if all(item_results):
        logger.info("PASS")
    logger.info("")

    return ontology_data
520
+
521
+
522
def get_node_from_name(ontology_data: list[OntologyInfo], name: str) -> OntologyInfo | None:
    """Return the single ontology entry that lists the name.

    Logs a warning and returns None when there is no match or more than one.
    """
    from sonusai import logger

    matches = [entry for entry in ontology_data if name in entry.name]
    count = len(matches)

    if count == 1:
        return matches[0]

    if count:
        message = f"Found multiple entries in reference ontology that match {name}"
    else:
        message = f"Could not find entry for {name} in reference ontology"
    logger.warning(message)

    return None
535
+
536
+
537
def write_metadata_to_tree(ontology_data: list[OntologyInfo], file: FileInfo):
    """Write the ontology entry for the file's innermost class as JSON beside the file.

    The output is '<dir>/<dir base name>.json', where <dir> is the directory
    of the file. Nothing is written when there is no ontology data, when the
    file has no classes, or when the class has no unique ontology entry.
    """
    import json
    from dataclasses import asdict
    from os.path import basename
    from os.path import dirname

    # A file directly in the top-level directory has no classes to look up.
    if not ontology_data or not file.classes:
        return

    node = get_node_from_name(ontology_data, file.classes[-1])
    if node is None:
        return

    dir_name = dirname(file.file)
    json_name = dir_name + "/" + basename(dir_name) + ".json"
    # utf-8 matches the encoding used when the ontology is read.
    with open(file=json_name, mode="w", encoding="utf-8") as f:
        json.dump(asdict(node), f)
550
+
551
+
552
def get_folds_from_leaf(all_files: list[FileInfo], leaf: str) -> list[int]:
    """Return the folds present among the files that belong to the given leaf."""
    leaf_files = []
    for entry in all_files:
        if entry.leaf == leaf:
            leaf_files.append(entry)
    return get_folds_from_files(leaf_files)
555
+
556
+
557
def report_leaf_fold_data_usage(all_files: list[FileInfo], use_files: list[FileInfo]):
    """Log the folds present in each leaf and warn about leaves left without data.

    Args:
        all_files: Every list file found in the tree.
        use_files: The files selected by the fold selection.
    """
    from sonusai import logger

    use_leaves = get_leaves_from_files(use_files)
    all_leaves = get_leaves_from_files(all_files)

    logger.debug("Data folds present in each leaf")
    # default=0 keeps an empty tree from raising ValueError in max().
    leaf_len = max((len(leaf) for leaf in all_leaves), default=0)
    for leaf in all_leaves:
        folds = get_folds_from_leaf(all_files, leaf)
        logger.debug(f" {leaf:{leaf_len}} {folds}")
    logger.debug("")

    # Sorted so the warnings come out in a stable order.
    dif_leaves = sorted(set(all_leaves).symmetric_difference(use_leaves))
    if dif_leaves:
        logger.warning("This fold selection is missing data from the following leaves")
        for c in dif_leaves:
            logger.warning(f" {c}")
        logger.warning("")
576
+
577
+
578
def main() -> None:
    """Command-line entry point: parse the arguments, run gentcst() and write the results.

    The configuration is written as YAML to the output name (by default the
    base name of the working directory plus '.yml'), the labels are written
    as CSV to the same name with a '.csv' extension, and a log file handler
    for gentcst.log is created.
    """
    from docopt import docopt

    import sonusai
    from sonusai.utils.docstring import trim_docstring

    args = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)

    verbose = args["--verbose"]
    output_name = args["--output"]
    fold = args["--fold"]
    ontology = args["--ontology"]
    hierarchical = args["--hierarchical"]
    update = args["--update"]

    import csv
    from dataclasses import asdict
    from os import getcwd
    from os.path import basename
    from os.path import splitext

    import yaml

    from sonusai import create_file_handler
    from sonusai import logger

    if not output_name:
        output_name = basename(getcwd()) + ".yml"

    create_file_handler("gentcst.log")

    config, labels = gentcst(
        fold=fold,
        ontology=ontology,
        hierarchical=hierarchical,
        update=update,
        verbose=verbose,
    )

    with open(file=output_name, mode="w", encoding="utf-8") as f:
        yaml.dump(config, f)
    logger.info(f"Wrote config to {output_name}")

    csv_fields = ["index", "display_name"]
    csv_name = splitext(output_name)[0] + ".csv"

    # The csv module requires newline="" so that it controls the line endings
    # itself; otherwise extra blank rows can appear on some platforms.
    with open(file=csv_name, mode="w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=csv_fields)
        writer.writeheader()
        writer.writerows(asdict(label) for label in labels)
    logger.info(f"Wrote labels to {csv_name}")
629
+
630
+
631
+ if __name__ == "__main__":
632
+ main()