sonusai 0.18.9__py3-none-any.whl → 0.19.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sonusai/__init__.py +20 -29
- sonusai/aawscd_probwrite.py +18 -18
- sonusai/audiofe.py +93 -80
- sonusai/calc_metric_spenh.py +395 -321
- sonusai/data/genmixdb.yml +5 -11
- sonusai/{gentcst.py → deprecated/gentcst.py} +146 -149
- sonusai/{plot.py → deprecated/plot.py} +177 -131
- sonusai/{tplot.py → deprecated/tplot.py} +124 -102
- sonusai/doc/__init__.py +1 -1
- sonusai/doc/doc.py +112 -177
- sonusai/doc.py +10 -10
- sonusai/genft.py +93 -77
- sonusai/genmetrics.py +59 -46
- sonusai/genmix.py +116 -104
- sonusai/genmixdb.py +194 -153
- sonusai/lsdb.py +56 -66
- sonusai/main.py +23 -20
- sonusai/metrics/__init__.py +2 -0
- sonusai/metrics/calc_audio_stats.py +29 -24
- sonusai/metrics/calc_class_weights.py +7 -7
- sonusai/metrics/calc_optimal_thresholds.py +5 -7
- sonusai/metrics/calc_pcm.py +3 -3
- sonusai/metrics/calc_pesq.py +10 -7
- sonusai/metrics/calc_phase_distance.py +3 -3
- sonusai/metrics/calc_sa_sdr.py +10 -8
- sonusai/metrics/calc_segsnr_f.py +15 -17
- sonusai/metrics/calc_speech.py +105 -47
- sonusai/metrics/calc_wer.py +35 -32
- sonusai/metrics/calc_wsdr.py +10 -7
- sonusai/metrics/class_summary.py +30 -27
- sonusai/metrics/confusion_matrix_summary.py +25 -22
- sonusai/metrics/one_hot.py +91 -57
- sonusai/metrics/snr_summary.py +53 -46
- sonusai/mixture/__init__.py +19 -14
- sonusai/mixture/audio.py +4 -6
- sonusai/mixture/augmentation.py +37 -43
- sonusai/mixture/class_count.py +5 -14
- sonusai/mixture/config.py +292 -225
- sonusai/mixture/constants.py +41 -30
- sonusai/mixture/data_io.py +155 -0
- sonusai/mixture/datatypes.py +111 -108
- sonusai/mixture/db_datatypes.py +54 -70
- sonusai/mixture/eq_rule_is_valid.py +6 -9
- sonusai/mixture/feature.py +40 -38
- sonusai/mixture/generation.py +522 -389
- sonusai/mixture/helpers.py +217 -272
- sonusai/mixture/log_duration_and_sizes.py +16 -13
- sonusai/mixture/mixdb.py +669 -477
- sonusai/mixture/soundfile_audio.py +12 -17
- sonusai/mixture/sox_audio.py +91 -112
- sonusai/mixture/sox_augmentation.py +8 -9
- sonusai/mixture/spectral_mask.py +4 -6
- sonusai/mixture/target_class_balancing.py +41 -36
- sonusai/mixture/targets.py +69 -67
- sonusai/mixture/tokenized_shell_vars.py +23 -23
- sonusai/mixture/torchaudio_audio.py +14 -15
- sonusai/mixture/torchaudio_augmentation.py +23 -27
- sonusai/mixture/truth.py +48 -26
- sonusai/mixture/truth_functions/__init__.py +26 -0
- sonusai/mixture/truth_functions/crm.py +56 -38
- sonusai/mixture/truth_functions/datatypes.py +37 -0
- sonusai/mixture/truth_functions/energy.py +85 -59
- sonusai/mixture/truth_functions/file.py +30 -30
- sonusai/mixture/truth_functions/phoneme.py +14 -7
- sonusai/mixture/truth_functions/sed.py +71 -45
- sonusai/mixture/truth_functions/target.py +69 -106
- sonusai/mkwav.py +52 -85
- sonusai/onnx_predict.py +46 -43
- sonusai/queries/__init__.py +3 -1
- sonusai/queries/queries.py +100 -59
- sonusai/speech/__init__.py +2 -0
- sonusai/speech/l2arctic.py +24 -23
- sonusai/speech/librispeech.py +16 -17
- sonusai/speech/mcgill.py +22 -21
- sonusai/speech/textgrid.py +32 -25
- sonusai/speech/timit.py +45 -42
- sonusai/speech/vctk.py +14 -13
- sonusai/speech/voxceleb.py +26 -20
- sonusai/summarize_metric_spenh.py +11 -10
- sonusai/utils/__init__.py +4 -3
- sonusai/utils/asl_p56.py +1 -1
- sonusai/utils/asr.py +37 -17
- sonusai/utils/asr_functions/__init__.py +2 -0
- sonusai/utils/asr_functions/aaware_whisper.py +18 -12
- sonusai/utils/audio_devices.py +12 -12
- sonusai/utils/braced_glob.py +6 -8
- sonusai/utils/calculate_input_shape.py +1 -4
- sonusai/utils/compress.py +2 -2
- sonusai/utils/convert_string_to_number.py +1 -3
- sonusai/utils/create_timestamp.py +1 -1
- sonusai/utils/create_ts_name.py +2 -2
- sonusai/utils/dataclass_from_dict.py +1 -1
- sonusai/utils/docstring.py +6 -6
- sonusai/utils/energy_f.py +9 -7
- sonusai/utils/engineering_number.py +56 -54
- sonusai/utils/get_label_names.py +8 -10
- sonusai/utils/human_readable_size.py +2 -2
- sonusai/utils/model_utils.py +3 -5
- sonusai/utils/numeric_conversion.py +2 -4
- sonusai/utils/onnx_utils.py +43 -32
- sonusai/utils/parallel.py +40 -27
- sonusai/utils/print_mixture_details.py +25 -22
- sonusai/utils/ranges.py +12 -12
- sonusai/utils/read_predict_data.py +11 -9
- sonusai/utils/reshape.py +19 -26
- sonusai/utils/seconds_to_hms.py +1 -1
- sonusai/utils/stacked_complex.py +8 -16
- sonusai/utils/stratified_shuffle_split.py +29 -27
- sonusai/utils/write_audio.py +2 -2
- sonusai/utils/yes_or_no.py +3 -3
- sonusai/vars.py +14 -14
- {sonusai-0.18.9.dist-info → sonusai-0.19.5.dist-info}/METADATA +20 -21
- sonusai-0.19.5.dist-info/RECORD +125 -0
- {sonusai-0.18.9.dist-info → sonusai-0.19.5.dist-info}/WHEEL +1 -1
- sonusai/mixture/truth_functions/data.py +0 -58
- sonusai/utils/read_mixture_data.py +0 -14
- sonusai-0.18.9.dist-info/RECORD +0 -125
- {sonusai-0.18.9.dist-info → sonusai-0.19.5.dist-info}/entry_points.txt +0 -0
sonusai/data/genmixdb.yml
CHANGED
@@ -7,6 +7,8 @@ feature: ""
|
|
7
7
|
|
8
8
|
target_level_type: default
|
9
9
|
|
10
|
+
class_indices: 1
|
11
|
+
|
10
12
|
targets: [ ]
|
11
13
|
|
12
14
|
num_classes: 1
|
@@ -17,15 +19,7 @@ seed: 0
|
|
17
19
|
|
18
20
|
class_weights_threshold: 0.5
|
19
21
|
|
20
|
-
|
21
|
-
|
22
|
-
truth_reduction_function: max
|
23
|
-
|
24
|
-
truth_settings:
|
25
|
-
function: sed
|
26
|
-
config:
|
27
|
-
thresholds: [ -38, -41, -48 ]
|
28
|
-
index: 1
|
22
|
+
truth_configs: { }
|
29
23
|
|
30
24
|
asr_manifest: [ ]
|
31
25
|
|
@@ -52,7 +46,7 @@ snrs:
|
|
52
46
|
|
53
47
|
random_snrs: [ ]
|
54
48
|
|
55
|
-
noise_mix_mode:
|
49
|
+
noise_mix_mode: exhaustive
|
56
50
|
|
57
51
|
impulse_responses: [ ]
|
58
52
|
|
@@ -63,4 +57,4 @@ spectral_masks:
|
|
63
57
|
t_num: 0
|
64
58
|
t_max_percent: 100
|
65
59
|
|
66
|
-
asr_configs:
|
60
|
+
asr_configs: { }
|
@@ -44,9 +44,9 @@ Outputs:
|
|
44
44
|
gentcst.log
|
45
45
|
|
46
46
|
"""
|
47
|
+
|
47
48
|
import signal
|
48
49
|
from dataclasses import dataclass
|
49
|
-
from typing import Optional
|
50
50
|
|
51
51
|
|
52
52
|
def signal_handler(_sig, _frame):
|
@@ -54,13 +54,13 @@ def signal_handler(_sig, _frame):
|
|
54
54
|
|
55
55
|
from sonusai import logger
|
56
56
|
|
57
|
-
logger.info(
|
57
|
+
logger.info("Canceled due to keyboard interrupt")
|
58
58
|
sys.exit(1)
|
59
59
|
|
60
60
|
|
61
61
|
signal.signal(signal.SIGINT, signal_handler)
|
62
62
|
|
63
|
-
CONFIG_FILE =
|
63
|
+
CONFIG_FILE = "config.yml"
|
64
64
|
|
65
65
|
|
66
66
|
@dataclass
|
@@ -70,7 +70,7 @@ class FileInfo:
|
|
70
70
|
leaf: str
|
71
71
|
labels: list[str]
|
72
72
|
fold: int
|
73
|
-
truth_index:
|
73
|
+
truth_index: list[int] | None = None
|
74
74
|
|
75
75
|
|
76
76
|
@dataclass(frozen=True)
|
@@ -96,11 +96,13 @@ class OntologyInfo:
|
|
96
96
|
restrictions: list[str]
|
97
97
|
|
98
98
|
|
99
|
-
def gentcst(
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
99
|
+
def gentcst(
|
100
|
+
fold: str = "*",
|
101
|
+
ontology: str = "ontology.json",
|
102
|
+
hierarchical: bool = False,
|
103
|
+
update: bool = False,
|
104
|
+
verbose: bool = False,
|
105
|
+
) -> tuple[dict, list[LabelInfo]]:
|
104
106
|
from copy import deepcopy
|
105
107
|
from os import getcwd
|
106
108
|
from os.path import dirname
|
@@ -114,25 +116,25 @@ def gentcst(fold: str = '*',
|
|
114
116
|
from sonusai.mixture import get_default_config
|
115
117
|
from sonusai.mixture import raw_load_config
|
116
118
|
from sonusai.mixture import update_config_from_hierarchy
|
117
|
-
from sonusai.mixture import
|
119
|
+
from sonusai.mixture import validate_truth_configs
|
118
120
|
|
119
121
|
update_console_handler(verbose)
|
120
|
-
initial_log_messages(
|
122
|
+
initial_log_messages("gentcst")
|
121
123
|
|
122
124
|
if update:
|
123
|
-
logger.info(
|
124
|
-
logger.info(
|
125
|
+
logger.info("Updating tree with JSON metadata")
|
126
|
+
logger.info("")
|
125
127
|
|
126
|
-
logger.debug(f
|
127
|
-
logger.debug(f
|
128
|
-
logger.debug(f
|
129
|
-
logger.debug(f
|
130
|
-
logger.debug(
|
128
|
+
logger.debug(f"fold: {fold}")
|
129
|
+
logger.debug(f"ontology: {ontology}")
|
130
|
+
logger.debug(f"hierarchical: {hierarchical}")
|
131
|
+
logger.debug(f"update: {update}")
|
132
|
+
logger.debug("")
|
131
133
|
|
132
134
|
config = get_config()
|
133
135
|
|
134
|
-
if config[
|
135
|
-
raise SonusAIError(
|
136
|
+
if config["truth_mode"] == "mutex" and hierarchical:
|
137
|
+
raise SonusAIError("Multi-class truth is incompatible with truth_mode mutex")
|
136
138
|
|
137
139
|
all_files = get_all_files(hierarchical)
|
138
140
|
all_folds = get_folds_from_files(all_files)
|
@@ -143,52 +145,51 @@ def gentcst(fold: str = '*',
|
|
143
145
|
ontology_data = validate_ontology(ontology, use_files)
|
144
146
|
|
145
147
|
labels = get_labels_from_files(all_files)
|
146
|
-
logger.debug(
|
148
|
+
logger.debug("Truth indices:")
|
147
149
|
for item in labels:
|
148
|
-
logger.debug(f
|
149
|
-
logger.debug(
|
150
|
+
logger.debug(f" {item.index:3} {item.display_name}")
|
151
|
+
logger.debug("")
|
150
152
|
|
151
153
|
gen_truth_indices(use_files, labels)
|
152
154
|
|
153
|
-
config[
|
154
|
-
if config[
|
155
|
-
config[
|
155
|
+
config["num_classes"] = len(labels)
|
156
|
+
if config["truth_mode"] == "mutex":
|
157
|
+
config["num_classes"] = config["num_classes"] + 1
|
156
158
|
|
157
|
-
config[
|
159
|
+
config["targets"] = []
|
158
160
|
|
159
|
-
logger.info(f
|
161
|
+
logger.info(f"gentcst {len(use_files)} entries in tree")
|
160
162
|
root = getcwd()
|
161
163
|
for file in use_files:
|
162
164
|
leaf = dirname(file.file)
|
163
165
|
local_config = get_default_config()
|
164
166
|
local_config = update_config_from_hierarchy(root=root, leaf=leaf, config=local_config)
|
165
|
-
if not isinstance(local_config[
|
166
|
-
local_config[
|
167
|
+
if not isinstance(local_config["truth_settings"], list):
|
168
|
+
local_config["truth_settings"] = [local_config["truth_settings"]]
|
167
169
|
|
168
|
-
specific_name = splitext(file.file)[0] +
|
170
|
+
specific_name = splitext(file.file)[0] + ".yml"
|
169
171
|
if exists(specific_name):
|
170
172
|
specific_config = raw_load_config(specific_name)
|
171
173
|
|
172
174
|
for key in specific_config:
|
173
|
-
if key !=
|
175
|
+
if key != "truth_settings":
|
174
176
|
local_config[key] = specific_config[key]
|
175
177
|
else:
|
176
|
-
local_config[
|
177
|
-
|
178
|
+
local_config["truth_settings"] = validate_truth_configs(
|
179
|
+
given=specific_config["truth_settings"],
|
180
|
+
default=local_config["truth_settings"],
|
181
|
+
)
|
178
182
|
|
179
|
-
truth_settings = deepcopy(local_config[
|
183
|
+
truth_settings = deepcopy(local_config["truth_settings"])
|
180
184
|
for idx, val in enumerate(truth_settings):
|
181
|
-
val[
|
182
|
-
for key in [
|
183
|
-
if key in val and val[key] == config[
|
185
|
+
val["index"] = file.truth_index
|
186
|
+
for key in ["function", "config"]:
|
187
|
+
if key in val and val[key] == config["truth_settings"][idx][key]:
|
184
188
|
del val[key]
|
185
189
|
|
186
|
-
target = {
|
187
|
-
'name': file.file,
|
188
|
-
'truth_settings': truth_settings
|
189
|
-
}
|
190
|
+
target = {"name": file.file, "truth_settings": truth_settings}
|
190
191
|
|
191
|
-
config[
|
192
|
+
config["targets"].append(target)
|
192
193
|
|
193
194
|
if update:
|
194
195
|
write_metadata_to_tree(ontology_data, file)
|
@@ -202,28 +203,30 @@ def get_ontology(ontology: str) -> list[OntologyInfo]:
|
|
202
203
|
|
203
204
|
raw_ontology_data = []
|
204
205
|
if exists(ontology):
|
205
|
-
with open(file=ontology,
|
206
|
+
with open(file=ontology, encoding="utf-8") as f:
|
206
207
|
raw_ontology_data = json.load(f)
|
207
208
|
|
208
|
-
return [
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
|
214
|
-
|
215
|
-
|
216
|
-
|
209
|
+
return [
|
210
|
+
OntologyInfo(
|
211
|
+
id=item["id"],
|
212
|
+
name=convert_ontology_name(item["name"]),
|
213
|
+
description=item["description"],
|
214
|
+
citation_uri=item["citation_uri"],
|
215
|
+
positive_examples=item["positive_examples"],
|
216
|
+
child_ids=item["child_ids"],
|
217
|
+
restrictions=item["restrictions"],
|
218
|
+
)
|
219
|
+
for item in raw_ontology_data
|
220
|
+
]
|
217
221
|
|
218
222
|
|
219
223
|
def convert_ontology_name(name: str) -> list[str]:
|
220
|
-
"""Convert names to lowercase, convert spaces to '-', split into lists on ','.
|
221
|
-
"""
|
224
|
+
"""Convert names to lowercase, convert spaces to '-', split into lists on ','."""
|
222
225
|
text = name.lower()
|
223
226
|
# Need to find ', ' before converting spaces so that we don't convert ', ' to ',-'
|
224
|
-
text = text.replace(
|
225
|
-
text = text.replace(
|
226
|
-
return text.split(
|
227
|
+
text = text.replace(", ", ",")
|
228
|
+
text = text.replace(" ", "-")
|
229
|
+
return text.split(",")
|
227
230
|
|
228
231
|
|
229
232
|
def get_dirs() -> list[DirInfo]:
|
@@ -235,13 +238,13 @@ def get_dirs() -> list[DirInfo]:
|
|
235
238
|
cwd = getcwd()
|
236
239
|
for root, ds, _ in walk(cwd):
|
237
240
|
for d in ds:
|
238
|
-
match.append(join(root, d).replace(cwd,
|
241
|
+
match.append(join(root, d).replace(cwd, "."))
|
239
242
|
match = sorted(match)
|
240
243
|
|
241
244
|
dirs: list[DirInfo] = []
|
242
245
|
for m in match:
|
243
|
-
if not m.startswith(
|
244
|
-
classes = m.split(
|
246
|
+
if not m.startswith("./gentcst") and m != ".":
|
247
|
+
classes = m.split("/")
|
245
248
|
# Remove first element because it is '.'
|
246
249
|
classes.pop(0)
|
247
250
|
dirs.append(DirInfo(file=m, classes=classes))
|
@@ -250,7 +253,7 @@ def get_dirs() -> list[DirInfo]:
|
|
250
253
|
|
251
254
|
|
252
255
|
def get_leaf_from_classes(classes: list[str]) -> str:
|
253
|
-
return
|
256
|
+
return "/".join(classes)
|
254
257
|
|
255
258
|
|
256
259
|
def get_all_files(hierarchical: bool) -> list[FileInfo]:
|
@@ -261,20 +264,20 @@ def get_all_files(hierarchical: bool) -> list[FileInfo]:
|
|
261
264
|
|
262
265
|
file_list: list[str] = []
|
263
266
|
cwd = getcwd()
|
264
|
-
regex = re.compile(r
|
267
|
+
regex = re.compile(r"^\d+-.*\.txt$")
|
265
268
|
for root, _, fs in walk(cwd):
|
266
269
|
for f in fs:
|
267
270
|
if regex.match(f):
|
268
|
-
file_list.append(join(root, f).replace(cwd,
|
271
|
+
file_list.append(join(root, f).replace(cwd, "."))
|
269
272
|
|
270
273
|
files: list[FileInfo] = []
|
271
|
-
fold_pattern = re.compile(r
|
274
|
+
fold_pattern = re.compile(r".*/(\d+)-.*\.txt")
|
272
275
|
for file in file_list:
|
273
276
|
fold_match = fold_pattern.match(file)
|
274
277
|
if fold_match:
|
275
278
|
fold = int(fold_match.group(1))
|
276
279
|
|
277
|
-
classes = file.split(
|
280
|
+
classes = file.split("/")
|
278
281
|
# Remove first element because it is '.'
|
279
282
|
classes.pop(0)
|
280
283
|
# Remove last element because it is the file name
|
@@ -287,11 +290,7 @@ def get_all_files(hierarchical: bool) -> list[FileInfo]:
|
|
287
290
|
else:
|
288
291
|
labels = [leaf]
|
289
292
|
|
290
|
-
files.append(FileInfo(file=file,
|
291
|
-
classes=classes,
|
292
|
-
leaf=leaf,
|
293
|
-
labels=labels,
|
294
|
-
fold=fold))
|
293
|
+
files.append(FileInfo(file=file, classes=classes, leaf=leaf, labels=labels, fold=fold))
|
295
294
|
|
296
295
|
return files
|
297
296
|
|
@@ -304,13 +303,13 @@ def get_use_folds(all_folds: list[int], fold: str) -> list[int]:
|
|
304
303
|
req_folds: list[int] = get_req_folds(all_folds, fold)
|
305
304
|
use_folds: list[int] = list(set(req_folds).intersection(all_folds))
|
306
305
|
|
307
|
-
logger.debug(
|
308
|
-
logger.debug(f
|
309
|
-
logger.debug(f
|
310
|
-
logger.debug(f
|
311
|
-
logger.debug(f
|
312
|
-
logger.debug(f
|
313
|
-
logger.debug(
|
306
|
+
logger.debug("Fold information")
|
307
|
+
logger.debug(f" Available: {all_folds}")
|
308
|
+
logger.debug(f" Requested: {req_folds}")
|
309
|
+
logger.debug(f" Used: {use_folds}")
|
310
|
+
logger.debug(f" Unused: {list(set(use_folds).symmetric_difference(all_folds))}")
|
311
|
+
logger.debug(f" Missing: {list(np.setdiff1d(req_folds, all_folds))}")
|
312
|
+
logger.debug("")
|
314
313
|
|
315
314
|
return use_folds
|
316
315
|
|
@@ -319,15 +318,15 @@ def get_files_from_folds(files: list[FileInfo], folds: list[int]) -> list[FileIn
|
|
319
318
|
return [file for file in files if file.fold in folds]
|
320
319
|
|
321
320
|
|
322
|
-
def get_item_from_name(ontology_data: list[OntologyInfo], name: str) -> OntologyInfo:
|
321
|
+
def get_item_from_name(ontology_data: list[OntologyInfo], name: str) -> OntologyInfo | None:
|
323
322
|
return next((item for item in ontology_data if name in item.name), None)
|
324
323
|
|
325
324
|
|
326
|
-
def get_item_from_id(ontology_data: list[OntologyInfo], identity: str) -> OntologyInfo:
|
325
|
+
def get_item_from_id(ontology_data: list[OntologyInfo], identity: str) -> OntologyInfo | None:
|
327
326
|
return next((item for item in ontology_data if item.id == identity), None)
|
328
327
|
|
329
328
|
|
330
|
-
def get_name_from_id(ontology_data: list[OntologyInfo], identity: str) -> list[str]:
|
329
|
+
def get_name_from_id(ontology_data: list[OntologyInfo], identity: str) -> list[str] | None:
|
331
330
|
from sonusai import SonusAIError
|
332
331
|
|
333
332
|
name = None
|
@@ -336,7 +335,7 @@ def get_name_from_id(ontology_data: list[OntologyInfo], identity: str) -> list[s
|
|
336
335
|
for item in ontology_data:
|
337
336
|
if item.id == identity:
|
338
337
|
if found:
|
339
|
-
raise SonusAIError(f
|
338
|
+
raise SonusAIError(f"id {identity} appears multiple times in ontology")
|
340
339
|
|
341
340
|
name = item.name
|
342
341
|
found = True
|
@@ -344,7 +343,7 @@ def get_name_from_id(ontology_data: list[OntologyInfo], identity: str) -> list[s
|
|
344
343
|
return name
|
345
344
|
|
346
345
|
|
347
|
-
def get_id_from_name(ontology_data: list[OntologyInfo], name: str) -> str:
|
346
|
+
def get_id_from_name(ontology_data: list[OntologyInfo], name: str) -> str | None:
|
348
347
|
from sonusai import SonusAIError
|
349
348
|
|
350
349
|
identity = None
|
@@ -353,7 +352,7 @@ def get_id_from_name(ontology_data: list[OntologyInfo], name: str) -> str:
|
|
353
352
|
for item in ontology_data:
|
354
353
|
if name in item.name:
|
355
354
|
if found:
|
356
|
-
raise SonusAIError(f
|
355
|
+
raise SonusAIError(f"name {name} appears multiple times in ontology")
|
357
356
|
|
358
357
|
identity = item.id
|
359
358
|
found = True
|
@@ -362,10 +361,7 @@ def get_id_from_name(ontology_data: list[OntologyInfo], name: str) -> str:
|
|
362
361
|
|
363
362
|
|
364
363
|
def is_valid_name(ontology_data: list[OntologyInfo], name: str) -> bool:
|
365
|
-
for item in ontology_data
|
366
|
-
if name in item.name:
|
367
|
-
return True
|
368
|
-
return False
|
364
|
+
return any(name in item.name for item in ontology_data)
|
369
365
|
|
370
366
|
|
371
367
|
def is_valid_child(ontology_data: list[OntologyInfo], parent: str, child: str) -> bool:
|
@@ -373,17 +369,17 @@ def is_valid_child(ontology_data: list[OntologyInfo], parent: str, child: str) -
|
|
373
369
|
parent_item = get_item_from_name(ontology_data, parent)
|
374
370
|
child_id = get_id_from_name(ontology_data, child)
|
375
371
|
|
376
|
-
if child_id is not None and parent_item is not None:
|
377
|
-
|
378
|
-
valid = True
|
372
|
+
if child_id is not None and parent_item is not None and child_id in parent_item.child_ids:
|
373
|
+
valid = True
|
379
374
|
|
380
375
|
return valid
|
381
376
|
|
382
377
|
|
383
378
|
def is_valid_hierarchy(ontology_data: list[OntologyInfo], classes: list[str]) -> bool:
|
379
|
+
from itertools import pairwise
|
384
380
|
valid = True
|
385
381
|
|
386
|
-
for parent, child in
|
382
|
+
for parent, child in pairwise(classes):
|
387
383
|
if not is_valid_child(ontology_data, parent, child):
|
388
384
|
valid = False
|
389
385
|
|
@@ -397,13 +393,12 @@ def validate_class(ontology_data: list[OntologyInfo], item: FileInfo | DirInfo)
|
|
397
393
|
|
398
394
|
for c in item.classes:
|
399
395
|
if not is_valid_name(ontology_data, c):
|
400
|
-
logger.warning(f
|
396
|
+
logger.warning(f" Could not find {c} in ontology for {item.file}")
|
401
397
|
valid = False
|
402
398
|
|
403
|
-
if valid:
|
404
|
-
|
405
|
-
|
406
|
-
valid = False
|
399
|
+
if valid and not is_valid_hierarchy(ontology_data, item.classes):
|
400
|
+
logger.warning(f" Invalid parent/child relationship for {item.file}")
|
401
|
+
valid = False
|
407
402
|
|
408
403
|
return valid
|
409
404
|
|
@@ -411,7 +406,7 @@ def validate_class(ontology_data: list[OntologyInfo], item: FileInfo | DirInfo)
|
|
411
406
|
def get_req_folds(folds: list[int], fold: str) -> list[int]:
|
412
407
|
from sonusai.utils import expand_range
|
413
408
|
|
414
|
-
if fold ==
|
409
|
+
if fold == "*":
|
415
410
|
return folds
|
416
411
|
|
417
412
|
return expand_range(fold)
|
@@ -420,13 +415,13 @@ def get_req_folds(folds: list[int], fold: str) -> list[int]:
|
|
420
415
|
def get_folds_from_files(files: list[FileInfo]) -> list[int]:
|
421
416
|
result: list[int] = [file.fold for file in files]
|
422
417
|
# Converting to set and back to list ensures uniqueness
|
423
|
-
return sorted(
|
418
|
+
return sorted(set(result))
|
424
419
|
|
425
420
|
|
426
421
|
def get_leaves_from_files(files: list[FileInfo]) -> list[str]:
|
427
422
|
result: list[str] = [file.leaf for file in files]
|
428
423
|
# Converting to set and back to list ensures uniqueness
|
429
|
-
return sorted(
|
424
|
+
return sorted(set(result))
|
430
425
|
|
431
426
|
|
432
427
|
def get_labels_from_files(files: list[FileInfo]) -> list[LabelInfo]:
|
@@ -440,7 +435,7 @@ def get_labels_from_files(files: list[FileInfo]) -> list[LabelInfo]:
|
|
440
435
|
|
441
436
|
for n in range(len(labels_by_depth)):
|
442
437
|
# Converting to set and back to list ensures uniqueness
|
443
|
-
labels_by_depth[n] = sorted(
|
438
|
+
labels_by_depth[n] = sorted(set(labels_by_depth[n]))
|
444
439
|
|
445
440
|
# We want the deepest leaves first
|
446
441
|
labels_by_depth.reverse()
|
@@ -471,21 +466,21 @@ def get_config() -> dict:
|
|
471
466
|
from sonusai.mixture import raw_load_config
|
472
467
|
|
473
468
|
if not exists(CONFIG_FILE):
|
474
|
-
raise SonusAIError(f
|
469
|
+
raise SonusAIError(f"No {CONFIG_FILE} at top level")
|
475
470
|
|
476
471
|
config = raw_load_config(CONFIG_FILE)
|
477
472
|
|
478
|
-
if
|
479
|
-
raise SonusAIError(
|
473
|
+
if "feature" not in config:
|
474
|
+
raise SonusAIError("feature not in top level config")
|
480
475
|
|
481
|
-
if
|
482
|
-
raise SonusAIError(
|
476
|
+
if "truth_mode" not in config:
|
477
|
+
raise SonusAIError("truth_mode not in top level config")
|
483
478
|
|
484
|
-
if config[
|
485
|
-
raise SonusAIError(
|
479
|
+
if config["truth_mode"] not in ["normal", "mutex"]:
|
480
|
+
raise SonusAIError("Invalid truth_mode in top level config")
|
486
481
|
|
487
|
-
if
|
488
|
-
config[
|
482
|
+
if "truth_settings" in config and not isinstance(config["truth_settings"], list):
|
483
|
+
config["truth_settings"] = [config["truth_settings"]]
|
489
484
|
|
490
485
|
return config
|
491
486
|
|
@@ -497,34 +492,34 @@ def validate_ontology(ontology: str, items: list[FileInfo] | list[DirInfo]) -> l
|
|
497
492
|
|
498
493
|
if exists(ontology):
|
499
494
|
ontology_data = get_ontology(ontology)
|
500
|
-
logger.debug(f
|
501
|
-
logger.debug(
|
495
|
+
logger.debug(f"Reference ontology in {ontology} has {len(ontology_data)} classes")
|
496
|
+
logger.debug("")
|
502
497
|
|
503
|
-
logger.info(
|
498
|
+
logger.info("Checking tree against reference ontology")
|
504
499
|
all_dirs = get_dirs()
|
505
500
|
valid = True
|
506
501
|
for file in all_dirs:
|
507
502
|
if not validate_class(ontology_data, file):
|
508
503
|
valid = False
|
509
504
|
if valid:
|
510
|
-
logger.info(
|
511
|
-
logger.info(
|
505
|
+
logger.info("PASS")
|
506
|
+
logger.info("")
|
512
507
|
|
513
|
-
logger.info(
|
508
|
+
logger.info("Checking files against reference ontology")
|
514
509
|
valid = True
|
515
510
|
for item in items:
|
516
511
|
if not validate_class(ontology_data, item):
|
517
512
|
valid = False
|
518
513
|
if valid:
|
519
|
-
logger.info(
|
520
|
-
logger.info(
|
514
|
+
logger.info("PASS")
|
515
|
+
logger.info("")
|
521
516
|
|
522
517
|
return ontology_data
|
523
518
|
|
524
519
|
return []
|
525
520
|
|
526
521
|
|
527
|
-
def get_node_from_name(ontology_data: list[OntologyInfo], name: str) ->
|
522
|
+
def get_node_from_name(ontology_data: list[OntologyInfo], name: str) -> OntologyInfo | None:
|
528
523
|
from sonusai import logger
|
529
524
|
|
530
525
|
nodes = [item for item in ontology_data if name in item.name]
|
@@ -532,9 +527,9 @@ def get_node_from_name(ontology_data: list[OntologyInfo], name: str) -> Optional
|
|
532
527
|
return nodes[0]
|
533
528
|
|
534
529
|
if nodes:
|
535
|
-
logger.warning(f
|
530
|
+
logger.warning(f"Found multiple entries in reference ontology that match {name}")
|
536
531
|
else:
|
537
|
-
logger.warning(f
|
532
|
+
logger.warning(f"Could not find entry for {name} in reference ontology")
|
538
533
|
|
539
534
|
return None
|
540
535
|
|
@@ -549,8 +544,8 @@ def write_metadata_to_tree(ontology_data: list[OntologyInfo], file: FileInfo):
|
|
549
544
|
node = get_node_from_name(ontology_data, file.classes[-1])
|
550
545
|
if node is not None:
|
551
546
|
dir_name = dirname(file.file)
|
552
|
-
json_name = dir_name +
|
553
|
-
with open(file=json_name, mode=
|
547
|
+
json_name = dir_name + "/" + basename(dir_name) + ".json"
|
548
|
+
with open(file=json_name, mode="w") as f:
|
554
549
|
json.dump(asdict(node), f)
|
555
550
|
|
556
551
|
|
@@ -565,19 +560,19 @@ def report_leaf_fold_data_usage(all_files: list[FileInfo], use_files: list[FileI
|
|
565
560
|
use_leaves = get_leaves_from_files(use_files)
|
566
561
|
all_leaves = get_leaves_from_files(all_files)
|
567
562
|
|
568
|
-
logger.debug(
|
563
|
+
logger.debug("Data folds present in each leaf")
|
569
564
|
leaf_len = len(max(all_leaves, key=len))
|
570
565
|
for leaf in all_leaves:
|
571
566
|
folds = get_folds_from_leaf(all_files, leaf)
|
572
|
-
logger.debug(f
|
573
|
-
logger.debug(
|
567
|
+
logger.debug(f" {leaf:{leaf_len}} {folds}")
|
568
|
+
logger.debug("")
|
574
569
|
|
575
570
|
dif_leaves = set(all_leaves).symmetric_difference(use_leaves)
|
576
571
|
if dif_leaves:
|
577
|
-
logger.warning(
|
572
|
+
logger.warning("This fold selection is missing data from the following leaves")
|
578
573
|
for c in dif_leaves:
|
579
|
-
logger.warning(f
|
580
|
-
logger.warning(
|
574
|
+
logger.warning(f" {c}")
|
575
|
+
logger.warning("")
|
581
576
|
|
582
577
|
|
583
578
|
def main() -> None:
|
@@ -588,12 +583,12 @@ def main() -> None:
|
|
588
583
|
|
589
584
|
args = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)
|
590
585
|
|
591
|
-
verbose = args[
|
592
|
-
output_name = args[
|
593
|
-
fold = args[
|
594
|
-
ontology = args[
|
595
|
-
hierarchical = args[
|
596
|
-
update = args[
|
586
|
+
verbose = args["--verbose"]
|
587
|
+
output_name = args["--output"]
|
588
|
+
fold = args["--fold"]
|
589
|
+
ontology = args["--ontology"]
|
590
|
+
hierarchical = args["--hierarchical"]
|
591
|
+
update = args["--update"]
|
597
592
|
|
598
593
|
import csv
|
599
594
|
from dataclasses import asdict
|
@@ -607,29 +602,31 @@ def main() -> None:
|
|
607
602
|
from sonusai import logger
|
608
603
|
|
609
604
|
if not output_name:
|
610
|
-
output_name = basename(getcwd()) +
|
605
|
+
output_name = basename(getcwd()) + ".yml"
|
611
606
|
|
612
|
-
create_file_handler(
|
607
|
+
create_file_handler("gentcst.log")
|
613
608
|
|
614
|
-
config, labels = gentcst(
|
615
|
-
|
616
|
-
|
617
|
-
|
618
|
-
|
609
|
+
config, labels = gentcst(
|
610
|
+
fold=fold,
|
611
|
+
ontology=ontology,
|
612
|
+
hierarchical=hierarchical,
|
613
|
+
update=update,
|
614
|
+
verbose=verbose,
|
615
|
+
)
|
619
616
|
|
620
|
-
with open(file=output_name, mode=
|
617
|
+
with open(file=output_name, mode="w") as f:
|
621
618
|
yaml.dump(config, f)
|
622
|
-
logger.info(f
|
619
|
+
logger.info(f"Wrote config to {output_name}")
|
623
620
|
|
624
|
-
csv_fields = [
|
625
|
-
csv_name = splitext(output_name)[0] +
|
621
|
+
csv_fields = ["index", "display_name"]
|
622
|
+
csv_name = splitext(output_name)[0] + ".csv"
|
626
623
|
|
627
|
-
with open(file=csv_name, mode=
|
624
|
+
with open(file=csv_name, mode="w") as f:
|
628
625
|
writer = csv.DictWriter(f, fieldnames=csv_fields)
|
629
626
|
writer.writeheader()
|
630
627
|
writer.writerows([asdict(label) for label in labels])
|
631
|
-
logger.info(f
|
628
|
+
logger.info(f"Wrote labels to {csv_name}")
|
632
629
|
|
633
630
|
|
634
|
-
if __name__ ==
|
631
|
+
if __name__ == "__main__":
|
635
632
|
main()
|