sonusai 1.0.16__cp311-abi3-macosx_10_12_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +170 -0
- sonusai/aawscd_probwrite.py +148 -0
- sonusai/audiofe.py +481 -0
- sonusai/calc_metric_spenh.py +1136 -0
- sonusai/config/__init__.py +0 -0
- sonusai/config/asr.py +21 -0
- sonusai/config/config.py +65 -0
- sonusai/config/config.yml +49 -0
- sonusai/config/constants.py +53 -0
- sonusai/config/ir.py +124 -0
- sonusai/config/ir_delay.py +62 -0
- sonusai/config/source.py +275 -0
- sonusai/config/spectral_masks.py +15 -0
- sonusai/config/truth.py +64 -0
- sonusai/constants.py +14 -0
- sonusai/data/__init__.py +0 -0
- sonusai/data/silero_vad_v5.1.jit +0 -0
- sonusai/data/silero_vad_v5.1.onnx +0 -0
- sonusai/data/speech_ma01_01.wav +0 -0
- sonusai/data/whitenoise.wav +0 -0
- sonusai/datatypes.py +383 -0
- sonusai/deprecated/gentcst.py +632 -0
- sonusai/deprecated/plot.py +519 -0
- sonusai/deprecated/tplot.py +365 -0
- sonusai/doc.py +52 -0
- sonusai/doc_strings/__init__.py +1 -0
- sonusai/doc_strings/doc_strings.py +531 -0
- sonusai/genft.py +196 -0
- sonusai/genmetrics.py +183 -0
- sonusai/genmix.py +199 -0
- sonusai/genmixdb.py +235 -0
- sonusai/ir_metric.py +551 -0
- sonusai/lsdb.py +141 -0
- sonusai/main.py +134 -0
- sonusai/metrics/__init__.py +43 -0
- sonusai/metrics/calc_audio_stats.py +42 -0
- sonusai/metrics/calc_class_weights.py +90 -0
- sonusai/metrics/calc_optimal_thresholds.py +73 -0
- sonusai/metrics/calc_pcm.py +45 -0
- sonusai/metrics/calc_pesq.py +36 -0
- sonusai/metrics/calc_phase_distance.py +43 -0
- sonusai/metrics/calc_sa_sdr.py +64 -0
- sonusai/metrics/calc_sample_weights.py +25 -0
- sonusai/metrics/calc_segsnr_f.py +82 -0
- sonusai/metrics/calc_speech.py +382 -0
- sonusai/metrics/calc_wer.py +71 -0
- sonusai/metrics/calc_wsdr.py +57 -0
- sonusai/metrics/calculate_metrics.py +395 -0
- sonusai/metrics/class_summary.py +74 -0
- sonusai/metrics/confusion_matrix_summary.py +75 -0
- sonusai/metrics/one_hot.py +283 -0
- sonusai/metrics/snr_summary.py +128 -0
- sonusai/metrics_summary.py +314 -0
- sonusai/mixture/__init__.py +15 -0
- sonusai/mixture/audio.py +187 -0
- sonusai/mixture/class_balancing.py +103 -0
- sonusai/mixture/constants.py +3 -0
- sonusai/mixture/data_io.py +173 -0
- sonusai/mixture/db.py +169 -0
- sonusai/mixture/db_datatypes.py +92 -0
- sonusai/mixture/effects.py +344 -0
- sonusai/mixture/feature.py +78 -0
- sonusai/mixture/generation.py +1116 -0
- sonusai/mixture/helpers.py +351 -0
- sonusai/mixture/ir_effects.py +77 -0
- sonusai/mixture/log_duration_and_sizes.py +23 -0
- sonusai/mixture/mixdb.py +1857 -0
- sonusai/mixture/pad_audio.py +35 -0
- sonusai/mixture/resample.py +7 -0
- sonusai/mixture/sox_effects.py +195 -0
- sonusai/mixture/sox_help.py +650 -0
- sonusai/mixture/spectral_mask.py +51 -0
- sonusai/mixture/truth.py +61 -0
- sonusai/mixture/truth_functions/__init__.py +45 -0
- sonusai/mixture/truth_functions/crm.py +105 -0
- sonusai/mixture/truth_functions/energy.py +222 -0
- sonusai/mixture/truth_functions/file.py +48 -0
- sonusai/mixture/truth_functions/metadata.py +24 -0
- sonusai/mixture/truth_functions/metrics.py +28 -0
- sonusai/mixture/truth_functions/phoneme.py +18 -0
- sonusai/mixture/truth_functions/sed.py +98 -0
- sonusai/mixture/truth_functions/target.py +142 -0
- sonusai/mkwav.py +135 -0
- sonusai/onnx_predict.py +363 -0
- sonusai/parse/__init__.py +0 -0
- sonusai/parse/expand.py +156 -0
- sonusai/parse/parse_source_directive.py +129 -0
- sonusai/parse/rand.py +214 -0
- sonusai/py.typed +0 -0
- sonusai/queries/__init__.py +0 -0
- sonusai/queries/queries.py +239 -0
- sonusai/rs.abi3.so +0 -0
- sonusai/rs.pyi +1 -0
- sonusai/rust/__init__.py +0 -0
- sonusai/speech/__init__.py +0 -0
- sonusai/speech/l2arctic.py +121 -0
- sonusai/speech/librispeech.py +102 -0
- sonusai/speech/mcgill.py +71 -0
- sonusai/speech/textgrid.py +89 -0
- sonusai/speech/timit.py +138 -0
- sonusai/speech/types.py +12 -0
- sonusai/speech/vctk.py +53 -0
- sonusai/speech/voxceleb.py +108 -0
- sonusai/utils/__init__.py +3 -0
- sonusai/utils/asl_p56.py +130 -0
- sonusai/utils/asr.py +91 -0
- sonusai/utils/asr_functions/__init__.py +3 -0
- sonusai/utils/asr_functions/aaware_whisper.py +69 -0
- sonusai/utils/audio_devices.py +50 -0
- sonusai/utils/braced_glob.py +50 -0
- sonusai/utils/calculate_input_shape.py +26 -0
- sonusai/utils/choice.py +51 -0
- sonusai/utils/compress.py +25 -0
- sonusai/utils/convert_string_to_number.py +6 -0
- sonusai/utils/create_timestamp.py +5 -0
- sonusai/utils/create_ts_name.py +14 -0
- sonusai/utils/dataclass_from_dict.py +27 -0
- sonusai/utils/db.py +16 -0
- sonusai/utils/docstring.py +53 -0
- sonusai/utils/energy_f.py +44 -0
- sonusai/utils/engineering_number.py +166 -0
- sonusai/utils/evaluate_random_rule.py +15 -0
- sonusai/utils/get_frames_per_batch.py +2 -0
- sonusai/utils/get_label_names.py +20 -0
- sonusai/utils/grouper.py +6 -0
- sonusai/utils/human_readable_size.py +7 -0
- sonusai/utils/keyboard_interrupt.py +12 -0
- sonusai/utils/load_object.py +21 -0
- sonusai/utils/max_text_width.py +9 -0
- sonusai/utils/model_utils.py +28 -0
- sonusai/utils/numeric_conversion.py +11 -0
- sonusai/utils/onnx_utils.py +155 -0
- sonusai/utils/parallel.py +162 -0
- sonusai/utils/path_info.py +7 -0
- sonusai/utils/print_mixture_details.py +60 -0
- sonusai/utils/rand.py +13 -0
- sonusai/utils/ranges.py +43 -0
- sonusai/utils/read_predict_data.py +32 -0
- sonusai/utils/reshape.py +154 -0
- sonusai/utils/seconds_to_hms.py +7 -0
- sonusai/utils/stacked_complex.py +82 -0
- sonusai/utils/stratified_shuffle_split.py +170 -0
- sonusai/utils/tokenized_shell_vars.py +143 -0
- sonusai/utils/write_audio.py +26 -0
- sonusai/utils/yes_or_no.py +8 -0
- sonusai/vars.py +47 -0
- sonusai-1.0.16.dist-info/METADATA +56 -0
- sonusai-1.0.16.dist-info/RECORD +150 -0
- sonusai-1.0.16.dist-info/WHEEL +4 -0
- sonusai-1.0.16.dist-info/entry_points.txt +3 -0
@@ -0,0 +1,632 @@
|
|
1
|
+
"""sonusai gentcst
|
2
|
+
|
3
|
+
usage: gentcst [-hvru] [-o OUTPUT] [-f FOLD] [-y ONTOLOGY]
|
4
|
+
|
5
|
+
options:
|
6
|
+
-h, --help
|
7
|
+
-v, --verbose Be verbose.
|
8
|
+
-o OUTPUT, --output OUTPUT Output file name.
|
9
|
+
-f FOLD, --fold FOLD Fold(s) to include. [default: *].
|
10
|
+
-y ONTOLOGY, --ontology ONTOLOGY Reference ontology JSON file for cross-check and adding metadata.
|
11
|
+
[default: ontology.json].
|
12
|
+
-r, --hierarchical Generate hierarchical multiclass truth for non-leaf nodes.
|
13
|
+
-u, --update Write JSON metadata into the tree.
|
14
|
+
|
15
|
+
Generate a target SonusAI configuration file from a hierarchical subdirectory tree
|
16
|
+
under the local directory. Leaves in the subdirectory tree define classes by providing SonusAI list
|
17
|
+
.txt files for class data and additional truth config definitions. This offers a way to simplify and
|
18
|
+
more clearly manage hierarchical class structure similar to the Audioset ontology.
|
19
|
+
|
20
|
+
Features:
|
21
|
+
- automatically compiles the target configuration with the number of classes and each truth
|
22
|
+
index and the associated files (note: files should use absolute path with env variable)
|
23
|
+
- supports hierarchical multilabel class truth generation, where higher nodes in the tree
|
24
|
+
are included with leaf truth index, and different leaves with the same name will have the
|
25
|
+
same index (i.e. door/squeak, chair/squeak, and brief-tone/squeak)
|
26
|
+
- support config overrides for class config.yml or filename.yml for specific file config
|
27
|
+
- manages folds via file prefix naming convention, i.e. 1-*.txt, 2-*.txt, etc.
|
28
|
+
specifies data for fold 1, 2, etc.
|
29
|
+
- cross-check node definitions with the reference ontology and auto copy class metadata into
|
30
|
+
the tree to help define/collect data
|
31
|
+
- class label is also generated (in the .yml or referencing a .csv)
|
32
|
+
|
33
|
+
Inputs:
|
34
|
+
The local subdirectory tree with expected top-level configuration file config.yml. This
|
35
|
+
file must at least define the feature and can have additional parameters treated as default
|
36
|
+
(for truth, augmentation, etc.) which will be included in the generated output yaml file.
|
37
|
+
An entire class dataset config parameter can be overridden by including it in a config.yml
|
38
|
+
file in the class subdirectory. Individual file config parameters can be further overridden
|
39
|
+
or specified in a file of the same name but with .yml (i.e. 1-data.yml with the 1-data.txt)
|
40
|
+
|
41
|
+
Outputs:
|
42
|
+
output.yml
|
43
|
+
output.csv
|
44
|
+
gentcst.log
|
45
|
+
|
46
|
+
"""
|
47
|
+
|
48
|
+
import signal
|
49
|
+
from dataclasses import dataclass
|
50
|
+
|
51
|
+
|
52
|
+
def signal_handler(_sig, _frame):
|
53
|
+
import sys
|
54
|
+
|
55
|
+
from sonusai import logger
|
56
|
+
|
57
|
+
logger.info("Canceled due to keyboard interrupt")
|
58
|
+
sys.exit(1)
|
59
|
+
|
60
|
+
|
61
|
+
signal.signal(signal.SIGINT, signal_handler)
|
62
|
+
|
63
|
+
CONFIG_FILE = "config.yml"
|
64
|
+
|
65
|
+
|
66
|
+
@dataclass
|
67
|
+
class FileInfo:
|
68
|
+
file: str
|
69
|
+
classes: list[str]
|
70
|
+
leaf: str
|
71
|
+
labels: list[str]
|
72
|
+
fold: int
|
73
|
+
truth_index: list[int] | None = None
|
74
|
+
|
75
|
+
|
76
|
+
@dataclass(frozen=True)
|
77
|
+
class DirInfo:
|
78
|
+
file: str
|
79
|
+
classes: list[str]
|
80
|
+
|
81
|
+
|
82
|
+
@dataclass(frozen=True)
|
83
|
+
class LabelInfo:
|
84
|
+
index: int
|
85
|
+
display_name: str
|
86
|
+
|
87
|
+
|
88
|
+
@dataclass(frozen=True)
|
89
|
+
class OntologyInfo:
|
90
|
+
id: str
|
91
|
+
name: list[str]
|
92
|
+
description: str
|
93
|
+
citation_uri: str
|
94
|
+
positive_examples: list[str]
|
95
|
+
child_ids: list[str]
|
96
|
+
restrictions: list[str]
|
97
|
+
|
98
|
+
|
99
|
+
def gentcst(
|
100
|
+
fold: str = "*",
|
101
|
+
ontology: str = "ontology.json",
|
102
|
+
hierarchical: bool = False,
|
103
|
+
update: bool = False,
|
104
|
+
verbose: bool = False,
|
105
|
+
) -> tuple[dict, list[LabelInfo]]:
|
106
|
+
from copy import deepcopy
|
107
|
+
from os import getcwd
|
108
|
+
from os.path import dirname
|
109
|
+
from os.path import exists
|
110
|
+
from os.path import splitext
|
111
|
+
|
112
|
+
from sonusai import SonusAIError
|
113
|
+
from sonusai import initial_log_messages
|
114
|
+
from sonusai import logger
|
115
|
+
from sonusai import update_console_handler
|
116
|
+
from sonusai.mixture.config import get_default_config
|
117
|
+
from sonusai.mixture.config import raw_load_config
|
118
|
+
from sonusai.mixture.config import update_config_from_hierarchy
|
119
|
+
from sonusai.config.truth import validate_truth_configs
|
120
|
+
|
121
|
+
update_console_handler(verbose)
|
122
|
+
initial_log_messages("gentcst")
|
123
|
+
|
124
|
+
if update:
|
125
|
+
logger.info("Updating tree with JSON metadata")
|
126
|
+
logger.info("")
|
127
|
+
|
128
|
+
logger.debug(f"fold: {fold}")
|
129
|
+
logger.debug(f"ontology: {ontology}")
|
130
|
+
logger.debug(f"hierarchical: {hierarchical}")
|
131
|
+
logger.debug(f"update: {update}")
|
132
|
+
logger.debug("")
|
133
|
+
|
134
|
+
config = get_config()
|
135
|
+
|
136
|
+
if config["truth_mode"] == "mutex" and hierarchical:
|
137
|
+
raise SonusAIError("Multi-class truth is incompatible with truth_mode mutex")
|
138
|
+
|
139
|
+
all_files = get_all_files(hierarchical)
|
140
|
+
all_folds = get_folds_from_files(all_files)
|
141
|
+
use_folds = get_use_folds(all_folds, fold)
|
142
|
+
use_files = get_files_from_folds(all_files, use_folds)
|
143
|
+
report_leaf_fold_data_usage(all_files, use_files)
|
144
|
+
|
145
|
+
ontology_data = validate_ontology(ontology, use_files)
|
146
|
+
|
147
|
+
labels = get_labels_from_files(all_files)
|
148
|
+
logger.debug("Truth indices:")
|
149
|
+
for item in labels:
|
150
|
+
logger.debug(f" {item.index:3} {item.display_name}")
|
151
|
+
logger.debug("")
|
152
|
+
|
153
|
+
gen_truth_indices(use_files, labels)
|
154
|
+
|
155
|
+
config["num_classes"] = len(labels)
|
156
|
+
if config["truth_mode"] == "mutex":
|
157
|
+
config["num_classes"] = config["num_classes"] + 1
|
158
|
+
|
159
|
+
config["targets"] = []
|
160
|
+
|
161
|
+
logger.info(f"gentcst {len(use_files)} entries in tree")
|
162
|
+
root = getcwd()
|
163
|
+
for file in use_files:
|
164
|
+
leaf = dirname(file.file)
|
165
|
+
local_config = get_default_config()
|
166
|
+
local_config = update_config_from_hierarchy(root=root, leaf=leaf, config=local_config)
|
167
|
+
if not isinstance(local_config["truth_settings"], list):
|
168
|
+
local_config["truth_settings"] = [local_config["truth_settings"]]
|
169
|
+
|
170
|
+
specific_name = splitext(file.file)[0] + ".yml"
|
171
|
+
if exists(specific_name):
|
172
|
+
specific_config = raw_load_config(specific_name)
|
173
|
+
|
174
|
+
for key in specific_config:
|
175
|
+
if key != "truth_settings":
|
176
|
+
local_config[key] = specific_config[key]
|
177
|
+
else:
|
178
|
+
local_config["truth_settings"] = validate_truth_configs(
|
179
|
+
given=specific_config["truth_settings"],
|
180
|
+
default=local_config["truth_settings"],
|
181
|
+
)
|
182
|
+
|
183
|
+
truth_settings = deepcopy(local_config["truth_settings"])
|
184
|
+
for idx, val in enumerate(truth_settings):
|
185
|
+
val["index"] = file.truth_index
|
186
|
+
for key in ["function", "config"]:
|
187
|
+
if key in val and val[key] == config["truth_settings"][idx][key]:
|
188
|
+
del val[key]
|
189
|
+
|
190
|
+
target = {"name": file.file, "truth_settings": truth_settings}
|
191
|
+
|
192
|
+
config["targets"].append(target)
|
193
|
+
|
194
|
+
if update:
|
195
|
+
write_metadata_to_tree(ontology_data, file)
|
196
|
+
|
197
|
+
return config, labels
|
198
|
+
|
199
|
+
|
200
|
+
def get_ontology(ontology: str) -> list[OntologyInfo]:
|
201
|
+
import json
|
202
|
+
from os.path import exists
|
203
|
+
|
204
|
+
raw_ontology_data = []
|
205
|
+
if exists(ontology):
|
206
|
+
with open(file=ontology, encoding="utf-8") as f:
|
207
|
+
raw_ontology_data = json.load(f)
|
208
|
+
|
209
|
+
return [
|
210
|
+
OntologyInfo(
|
211
|
+
id=item["id"],
|
212
|
+
name=convert_ontology_name(item["name"]),
|
213
|
+
description=item["description"],
|
214
|
+
citation_uri=item["citation_uri"],
|
215
|
+
positive_examples=item["positive_examples"],
|
216
|
+
child_ids=item["child_ids"],
|
217
|
+
restrictions=item["restrictions"],
|
218
|
+
)
|
219
|
+
for item in raw_ontology_data
|
220
|
+
]
|
221
|
+
|
222
|
+
|
223
|
+
def convert_ontology_name(name: str) -> list[str]:
|
224
|
+
"""Convert names to lowercase, convert spaces to '-', split into lists on ','."""
|
225
|
+
text = name.lower()
|
226
|
+
# Need to find ', ' before converting spaces so that we don't convert ', ' to ',-'
|
227
|
+
text = text.replace(", ", ",")
|
228
|
+
text = text.replace(" ", "-")
|
229
|
+
return text.split(",")
|
230
|
+
|
231
|
+
|
232
|
+
def get_dirs() -> list[DirInfo]:
|
233
|
+
from os import getcwd
|
234
|
+
from os import walk
|
235
|
+
from os.path import join
|
236
|
+
|
237
|
+
match: list[str] = []
|
238
|
+
cwd = getcwd()
|
239
|
+
for root, ds, _ in walk(cwd):
|
240
|
+
for d in ds:
|
241
|
+
match.append(join(root, d).replace(cwd, "."))
|
242
|
+
match = sorted(match)
|
243
|
+
|
244
|
+
dirs: list[DirInfo] = []
|
245
|
+
for m in match:
|
246
|
+
if not m.startswith("./gentcst") and m != ".":
|
247
|
+
classes = m.split("/")
|
248
|
+
# Remove first element because it is '.'
|
249
|
+
classes.pop(0)
|
250
|
+
dirs.append(DirInfo(file=m, classes=classes))
|
251
|
+
|
252
|
+
return dirs
|
253
|
+
|
254
|
+
|
255
|
+
def get_leaf_from_classes(classes: list[str]) -> str:
|
256
|
+
return "/".join(classes)
|
257
|
+
|
258
|
+
|
259
|
+
def get_all_files(hierarchical: bool) -> list[FileInfo]:
|
260
|
+
import re
|
261
|
+
from os import getcwd
|
262
|
+
from os import walk
|
263
|
+
from os.path import join
|
264
|
+
|
265
|
+
file_list: list[str] = []
|
266
|
+
cwd = getcwd()
|
267
|
+
regex = re.compile(r"^\d+-.*\.txt$")
|
268
|
+
for root, _, fs in walk(cwd):
|
269
|
+
for f in fs:
|
270
|
+
if regex.match(f):
|
271
|
+
file_list.append(join(root, f).replace(cwd, "."))
|
272
|
+
|
273
|
+
files: list[FileInfo] = []
|
274
|
+
fold_pattern = re.compile(r".*/(\d+)-.*\.txt")
|
275
|
+
for file in file_list:
|
276
|
+
fold_match = fold_pattern.match(file)
|
277
|
+
if fold_match:
|
278
|
+
fold = int(fold_match.group(1))
|
279
|
+
|
280
|
+
classes = file.split("/")
|
281
|
+
# Remove first element because it is '.'
|
282
|
+
classes.pop(0)
|
283
|
+
# Remove last element because it is the file name
|
284
|
+
classes.pop()
|
285
|
+
|
286
|
+
leaf = get_leaf_from_classes(classes)
|
287
|
+
|
288
|
+
if hierarchical:
|
289
|
+
labels = classes
|
290
|
+
else:
|
291
|
+
labels = [leaf]
|
292
|
+
|
293
|
+
files.append(FileInfo(file=file, classes=classes, leaf=leaf, labels=labels, fold=fold))
|
294
|
+
|
295
|
+
return files
|
296
|
+
|
297
|
+
|
298
|
+
def get_use_folds(all_folds: list[int], fold: str) -> list[int]:
|
299
|
+
import numpy as np
|
300
|
+
|
301
|
+
from sonusai import logger
|
302
|
+
|
303
|
+
req_folds: list[int] = get_req_folds(all_folds, fold)
|
304
|
+
use_folds: list[int] = list(set(req_folds).intersection(all_folds))
|
305
|
+
|
306
|
+
logger.debug("Fold information")
|
307
|
+
logger.debug(f" Available: {all_folds}")
|
308
|
+
logger.debug(f" Requested: {req_folds}")
|
309
|
+
logger.debug(f" Used: {use_folds}")
|
310
|
+
logger.debug(f" Unused: {list(set(use_folds).symmetric_difference(all_folds))}")
|
311
|
+
logger.debug(f" Missing: {list(np.setdiff1d(req_folds, all_folds))}")
|
312
|
+
logger.debug("")
|
313
|
+
|
314
|
+
return use_folds
|
315
|
+
|
316
|
+
|
317
|
+
def get_files_from_folds(files: list[FileInfo], folds: list[int]) -> list[FileInfo]:
|
318
|
+
return [file for file in files if file.fold in folds]
|
319
|
+
|
320
|
+
|
321
|
+
def get_item_from_name(ontology_data: list[OntologyInfo], name: str) -> OntologyInfo | None:
|
322
|
+
return next((item for item in ontology_data if name in item.name), None)
|
323
|
+
|
324
|
+
|
325
|
+
def get_item_from_id(ontology_data: list[OntologyInfo], identity: str) -> OntologyInfo | None:
|
326
|
+
return next((item for item in ontology_data if item.id == identity), None)
|
327
|
+
|
328
|
+
|
329
|
+
def get_name_from_id(ontology_data: list[OntologyInfo], identity: str) -> list[str] | None:
|
330
|
+
from sonusai import SonusAIError
|
331
|
+
|
332
|
+
name = None
|
333
|
+
found = False
|
334
|
+
|
335
|
+
for item in ontology_data:
|
336
|
+
if item.id == identity:
|
337
|
+
if found:
|
338
|
+
raise SonusAIError(f"id {identity} appears multiple times in ontology")
|
339
|
+
|
340
|
+
name = item.name
|
341
|
+
found = True
|
342
|
+
|
343
|
+
return name
|
344
|
+
|
345
|
+
|
346
|
+
def get_id_from_name(ontology_data: list[OntologyInfo], name: str) -> str | None:
|
347
|
+
from sonusai import SonusAIError
|
348
|
+
|
349
|
+
identity = None
|
350
|
+
found = False
|
351
|
+
|
352
|
+
for item in ontology_data:
|
353
|
+
if name in item.name:
|
354
|
+
if found:
|
355
|
+
raise SonusAIError(f"name {name} appears multiple times in ontology")
|
356
|
+
|
357
|
+
identity = item.id
|
358
|
+
found = True
|
359
|
+
|
360
|
+
return identity
|
361
|
+
|
362
|
+
|
363
|
+
def is_valid_name(ontology_data: list[OntologyInfo], name: str) -> bool:
|
364
|
+
return any(name in item.name for item in ontology_data)
|
365
|
+
|
366
|
+
|
367
|
+
def is_valid_child(ontology_data: list[OntologyInfo], parent: str, child: str) -> bool:
|
368
|
+
valid = False
|
369
|
+
parent_item = get_item_from_name(ontology_data, parent)
|
370
|
+
child_id = get_id_from_name(ontology_data, child)
|
371
|
+
|
372
|
+
if child_id is not None and parent_item is not None and child_id in parent_item.child_ids:
|
373
|
+
valid = True
|
374
|
+
|
375
|
+
return valid
|
376
|
+
|
377
|
+
|
378
|
+
def is_valid_hierarchy(ontology_data: list[OntologyInfo], classes: list[str]) -> bool:
|
379
|
+
from itertools import pairwise
|
380
|
+
valid = True
|
381
|
+
|
382
|
+
for parent, child in pairwise(classes):
|
383
|
+
if not is_valid_child(ontology_data, parent, child):
|
384
|
+
valid = False
|
385
|
+
|
386
|
+
return valid
|
387
|
+
|
388
|
+
|
389
|
+
def validate_class(ontology_data: list[OntologyInfo], item: FileInfo | DirInfo) -> bool:
|
390
|
+
from sonusai import logger
|
391
|
+
|
392
|
+
valid = True
|
393
|
+
|
394
|
+
for c in item.classes:
|
395
|
+
if not is_valid_name(ontology_data, c):
|
396
|
+
logger.warning(f" Could not find {c} in ontology for {item.file}")
|
397
|
+
valid = False
|
398
|
+
|
399
|
+
if valid and not is_valid_hierarchy(ontology_data, item.classes):
|
400
|
+
logger.warning(f" Invalid parent/child relationship for {item.file}")
|
401
|
+
valid = False
|
402
|
+
|
403
|
+
return valid
|
404
|
+
|
405
|
+
|
406
|
+
def get_req_folds(folds: list[int], fold: str) -> list[int]:
|
407
|
+
from sonusai.utils.ranges import expand_range
|
408
|
+
|
409
|
+
if fold == "*":
|
410
|
+
return folds
|
411
|
+
|
412
|
+
return expand_range(fold)
|
413
|
+
|
414
|
+
|
415
|
+
def get_folds_from_files(files: list[FileInfo]) -> list[int]:
|
416
|
+
result: list[int] = [file.fold for file in files]
|
417
|
+
# Converting to set and back to list ensures uniqueness
|
418
|
+
return sorted(set(result))
|
419
|
+
|
420
|
+
|
421
|
+
def get_leaves_from_files(files: list[FileInfo]) -> list[str]:
|
422
|
+
result: list[str] = [file.leaf for file in files]
|
423
|
+
# Converting to set and back to list ensures uniqueness
|
424
|
+
return sorted(set(result))
|
425
|
+
|
426
|
+
|
427
|
+
def get_labels_from_files(files: list[FileInfo]) -> list[LabelInfo]:
|
428
|
+
all_labels = [file.labels for file in files]
|
429
|
+
|
430
|
+
# Get labels by depth
|
431
|
+
labels_by_depth: list[list[str]] = [[] for _ in range(max([len(label) for label in all_labels]))]
|
432
|
+
for label in all_labels:
|
433
|
+
for index, name in enumerate(label):
|
434
|
+
labels_by_depth[index].append(name)
|
435
|
+
|
436
|
+
for n in range(len(labels_by_depth)):
|
437
|
+
# Converting to set and back to list ensures uniqueness
|
438
|
+
labels_by_depth[n] = sorted(set(labels_by_depth[n]))
|
439
|
+
|
440
|
+
# We want the deepest leaves first
|
441
|
+
labels_by_depth.reverse()
|
442
|
+
|
443
|
+
# Now flatten the list
|
444
|
+
flattened_labels_by_depth = [item for sublist in labels_by_depth for item in sublist]
|
445
|
+
|
446
|
+
# Generate index, name pairs
|
447
|
+
# labels = []
|
448
|
+
# for index, file in enumerate(flattened_labels_by_depth):
|
449
|
+
# labels.append({'index': index + 1, 'display_name': file})
|
450
|
+
return [LabelInfo(index=index + 1, display_name=file) for index, file in enumerate(flattened_labels_by_depth)]
|
451
|
+
|
452
|
+
|
453
|
+
def gen_truth_indices(files: list[FileInfo], labels: list[LabelInfo]):
|
454
|
+
for file in files:
|
455
|
+
file.truth_index = sorted([get_index_from_label(labels, label) for label in file.labels])
|
456
|
+
|
457
|
+
|
458
|
+
def get_index_from_label(labels: list[LabelInfo], label: str) -> int:
|
459
|
+
return next((item for item in labels if item.display_name == label), None).index
|
460
|
+
|
461
|
+
|
462
|
+
def get_config() -> dict:
|
463
|
+
from os.path import exists
|
464
|
+
|
465
|
+
from sonusai import SonusAIError
|
466
|
+
from sonusai.mixture import raw_load_config
|
467
|
+
|
468
|
+
if not exists(CONFIG_FILE):
|
469
|
+
raise SonusAIError(f"No {CONFIG_FILE} at top level")
|
470
|
+
|
471
|
+
config = raw_load_config(CONFIG_FILE)
|
472
|
+
|
473
|
+
if "feature" not in config:
|
474
|
+
raise SonusAIError("feature not in top level config")
|
475
|
+
|
476
|
+
if "truth_mode" not in config:
|
477
|
+
raise SonusAIError("truth_mode not in top level config")
|
478
|
+
|
479
|
+
if config["truth_mode"] not in ["normal", "mutex"]:
|
480
|
+
raise SonusAIError("Invalid truth_mode in top level config")
|
481
|
+
|
482
|
+
if "truth_settings" in config and not isinstance(config["truth_settings"], list):
|
483
|
+
config["truth_settings"] = [config["truth_settings"]]
|
484
|
+
|
485
|
+
return config
|
486
|
+
|
487
|
+
|
488
|
+
def validate_ontology(ontology: str, items: list[FileInfo] | list[DirInfo]) -> list[OntologyInfo]:
|
489
|
+
from os.path import exists
|
490
|
+
|
491
|
+
from sonusai import logger
|
492
|
+
|
493
|
+
if exists(ontology):
|
494
|
+
ontology_data = get_ontology(ontology)
|
495
|
+
logger.debug(f"Reference ontology in {ontology} has {len(ontology_data)} classes")
|
496
|
+
logger.debug("")
|
497
|
+
|
498
|
+
logger.info("Checking tree against reference ontology")
|
499
|
+
all_dirs = get_dirs()
|
500
|
+
valid = True
|
501
|
+
for file in all_dirs:
|
502
|
+
if not validate_class(ontology_data, file):
|
503
|
+
valid = False
|
504
|
+
if valid:
|
505
|
+
logger.info("PASS")
|
506
|
+
logger.info("")
|
507
|
+
|
508
|
+
logger.info("Checking files against reference ontology")
|
509
|
+
valid = True
|
510
|
+
for item in items:
|
511
|
+
if not validate_class(ontology_data, item):
|
512
|
+
valid = False
|
513
|
+
if valid:
|
514
|
+
logger.info("PASS")
|
515
|
+
logger.info("")
|
516
|
+
|
517
|
+
return ontology_data
|
518
|
+
|
519
|
+
return []
|
520
|
+
|
521
|
+
|
522
|
+
def get_node_from_name(ontology_data: list[OntologyInfo], name: str) -> OntologyInfo | None:
|
523
|
+
from sonusai import logger
|
524
|
+
|
525
|
+
nodes = [item for item in ontology_data if name in item.name]
|
526
|
+
if len(nodes) == 1:
|
527
|
+
return nodes[0]
|
528
|
+
|
529
|
+
if nodes:
|
530
|
+
logger.warning(f"Found multiple entries in reference ontology that match {name}")
|
531
|
+
else:
|
532
|
+
logger.warning(f"Could not find entry for {name} in reference ontology")
|
533
|
+
|
534
|
+
return None
|
535
|
+
|
536
|
+
|
537
|
+
def write_metadata_to_tree(ontology_data: list[OntologyInfo], file: FileInfo):
|
538
|
+
import json
|
539
|
+
from dataclasses import asdict
|
540
|
+
from os.path import basename
|
541
|
+
from os.path import dirname
|
542
|
+
|
543
|
+
if ontology_data:
|
544
|
+
node = get_node_from_name(ontology_data, file.classes[-1])
|
545
|
+
if node is not None:
|
546
|
+
dir_name = dirname(file.file)
|
547
|
+
json_name = dir_name + "/" + basename(dir_name) + ".json"
|
548
|
+
with open(file=json_name, mode="w") as f:
|
549
|
+
json.dump(asdict(node), f)
|
550
|
+
|
551
|
+
|
552
|
+
def get_folds_from_leaf(all_files: list[FileInfo], leaf: str) -> list[int]:
|
553
|
+
files = [item for item in all_files if item.leaf == leaf]
|
554
|
+
return get_folds_from_files(files)
|
555
|
+
|
556
|
+
|
557
|
+
def report_leaf_fold_data_usage(all_files: list[FileInfo], use_files: list[FileInfo]):
|
558
|
+
from sonusai import logger
|
559
|
+
|
560
|
+
use_leaves = get_leaves_from_files(use_files)
|
561
|
+
all_leaves = get_leaves_from_files(all_files)
|
562
|
+
|
563
|
+
logger.debug("Data folds present in each leaf")
|
564
|
+
leaf_len = len(max(all_leaves, key=len))
|
565
|
+
for leaf in all_leaves:
|
566
|
+
folds = get_folds_from_leaf(all_files, leaf)
|
567
|
+
logger.debug(f" {leaf:{leaf_len}} {folds}")
|
568
|
+
logger.debug("")
|
569
|
+
|
570
|
+
dif_leaves = set(all_leaves).symmetric_difference(use_leaves)
|
571
|
+
if dif_leaves:
|
572
|
+
logger.warning("This fold selection is missing data from the following leaves")
|
573
|
+
for c in dif_leaves:
|
574
|
+
logger.warning(f" {c}")
|
575
|
+
logger.warning("")
|
576
|
+
|
577
|
+
|
578
|
+
def main() -> None:
|
579
|
+
from docopt import docopt
|
580
|
+
|
581
|
+
import sonusai
|
582
|
+
from sonusai.utils.docstring import trim_docstring
|
583
|
+
|
584
|
+
args = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)
|
585
|
+
|
586
|
+
verbose = args["--verbose"]
|
587
|
+
output_name = args["--output"]
|
588
|
+
fold = args["--fold"]
|
589
|
+
ontology = args["--ontology"]
|
590
|
+
hierarchical = args["--hierarchical"]
|
591
|
+
update = args["--update"]
|
592
|
+
|
593
|
+
import csv
|
594
|
+
from dataclasses import asdict
|
595
|
+
from os import getcwd
|
596
|
+
from os.path import basename
|
597
|
+
from os.path import splitext
|
598
|
+
|
599
|
+
import yaml
|
600
|
+
|
601
|
+
from sonusai import create_file_handler
|
602
|
+
from sonusai import logger
|
603
|
+
|
604
|
+
if not output_name:
|
605
|
+
output_name = basename(getcwd()) + ".yml"
|
606
|
+
|
607
|
+
create_file_handler("gentcst.log")
|
608
|
+
|
609
|
+
config, labels = gentcst(
|
610
|
+
fold=fold,
|
611
|
+
ontology=ontology,
|
612
|
+
hierarchical=hierarchical,
|
613
|
+
update=update,
|
614
|
+
verbose=verbose,
|
615
|
+
)
|
616
|
+
|
617
|
+
with open(file=output_name, mode="w") as f:
|
618
|
+
yaml.dump(config, f)
|
619
|
+
logger.info(f"Wrote config to {output_name}")
|
620
|
+
|
621
|
+
csv_fields = ["index", "display_name"]
|
622
|
+
csv_name = splitext(output_name)[0] + ".csv"
|
623
|
+
|
624
|
+
with open(file=csv_name, mode="w") as f:
|
625
|
+
writer = csv.DictWriter(f, fieldnames=csv_fields)
|
626
|
+
writer.writeheader()
|
627
|
+
writer.writerows([asdict(label) for label in labels])
|
628
|
+
logger.info(f"Wrote labels to {csv_name}")
|
629
|
+
|
630
|
+
|
631
|
+
if __name__ == "__main__":
|
632
|
+
main()
|