sonusai 0.15.6__py3-none-any.whl → 0.15.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sonusai/evaluate.py DELETED
@@ -1,245 +0,0 @@
1
- """sonusai evaluate
2
-
3
- usage: evaluate [-hv] [-i MIXID] (-f FEATURE) (-p PREDICT) [-t PTHR] LOC
4
-
5
- options:
6
- -h, --help
7
- -v, --verbose Be verbose.
8
- -i MIXID, --mixid MIXID Mixture ID(s) to generate. [default: *].
9
- -p PREDICT, --predict PREDICT A directory containing prediction data.
10
- -t PTHR, --thr PTHR Optional prediction decision threshold(s). [default: 0].
11
-
12
- Evaluate calculates performance metrics of neural-network models from model prediction data and genft data.
13
-
14
- Inputs:
15
- LOC A SonusAI mixture database directory.
16
- MIXID A glob of mixture ID(s) to generate.
17
- PREDICT A directory containing SonusAI predict HDF5 files. Contains:
18
- dataset: predict (either [frames, num_classes] or [frames, timesteps, num_classes])
19
- PTHR Scalar or array of thresholds. Default 0 will select values:
20
- argmax() if mixdb indicates single-label mode (truth_mutex = true)
21
- 0.5 if mixdb indicates multi-label mode (truth_mutex = false)
22
- If PTHR = -1, optimal thresholds are calculated using precision_recall_curve() which
23
- optimizes F1 score.
24
- """
25
- import numpy as np
26
-
27
- from sonusai import logger
28
- from sonusai.mixture import Feature
29
- from sonusai.mixture import MixtureDatabase
30
- from sonusai.mixture import Predict
31
- from sonusai.mixture import Segsnr
32
- from sonusai.mixture import Truth
33
-
34
-
35
def evaluate(mixdb: MixtureDatabase,
             truth: Truth,
             predict: Predict = None,
             segsnr: Segsnr = None,
             output_dir: str = None,
             predict_thr: float | np.ndarray = 0,
             feature: Feature = None,
             verbose: bool = False) -> None:
    """Log classification metrics of predict versus truth and optionally write CSV reports.

    Args:
        mixdb: Mixture database. This function reads ``truth_mutex``, ``mixtures``,
            ``num_mixtures``, ``feature_step_samples`` and ``ft_config.R`` from it.
        truth: Truth data, either [frames, num_classes] or [frames, timesteps, num_classes].
        predict: Prediction data in the same layout as truth. Although it defaults to
            None, it is dereferenced unconditionally, so it must be supplied.
        segsnr: Optional segmental SNR in transform frames. It is ignored when it is
            None, empty, or when its length does not correspond to the truth frames.
        output_dir: Directory that receives ``snr.csv``, ``snrwo.csv`` and
            ``class_snr<SNR>.csv``. Nothing is written when it is falsy.
        predict_thr: Scalar or array decision threshold. 0 selects the mode default and
            -1 requests thresholds from ``calc_optimal_thresholds``.
        feature: Optional feature data; only its size in bytes is logged.
        verbose: Forwarded to ``update_console_handler``.

    Raises:
        SystemExit: If truth and predict disagree on the size of their last dimension.
    """
    from os.path import join

    from sonusai import initial_log_messages
    from sonusai import update_console_handler
    from sonusai.metrics import calc_optimal_thresholds
    from sonusai.metrics import class_summary
    from sonusai.metrics import snr_summary
    from sonusai.mixture import SAMPLE_RATE
    from sonusai.queries import get_mixids_from_snr
    from sonusai.utils import get_num_classes_from_predict
    from sonusai.utils import human_readable_size
    from sonusai.utils import reshape_outputs
    from sonusai.utils import seconds_to_hms

    update_console_handler(verbose)
    initial_log_messages('evaluate')

    if truth.shape[-1] != predict.shape[-1]:
        # No exception is being handled here, so error() is used instead of exception(),
        # which would append a meaningless "NoneType: None" traceback.
        logger.error('Number of classes in truth and predict are not equal. Exiting.')
        raise SystemExit(1)

    # truth, predict can be either [frames, num_classes] or [frames, timesteps, num_classes]
    # in binary case dim may not exist, detect this and set num_classes == 1
    timesteps = -1
    predict, truth = reshape_outputs(predict=predict, truth=truth, timesteps=timesteps)
    num_classes = get_num_classes_from_predict(predict=predict, timesteps=timesteps)

    # Make predict the same number of frames as truth.
    fdiff = truth.shape[0] - predict.shape[0]
    if fdiff > 0:
        predict = np.concatenate((predict, np.zeros((fdiff, num_classes), dtype=np.float32)))
        logger.info('Truth has more feature-frames than predict, padding predict with zeros to match.')

    if fdiff < 0:
        # fdiff is negative here, so this drops the surplus trailing frames.
        predict = predict[0:fdiff, :]
        logger.info('Predict has more feature-frames than truth, trimming predict to match.')

    # Check segsnr, input is always in transform frames
    compute_segsnr = False
    # segsnr defaults to None; guard before calling len() on it.
    if segsnr is not None and len(segsnr) > 0:
        segsnr_feature_frames = segsnr.shape[0] / (mixdb.feature_step_samples / mixdb.ft_config.R)
        if segsnr_feature_frames == truth.shape[0]:
            compute_segsnr = True
        else:
            logger.warning('segsnr length does not match truth, ignoring.')

    # Check predict_thr array or scalar and return final scalar predict_thr value
    if not mixdb.truth_mutex:
        if num_classes > 1:
            if not isinstance(predict_thr, np.ndarray):
                if predict_thr == 0:
                    # multi-label predict_thr scalar 0 force to 0.5 default
                    predict_thr = np.atleast_1d(0.5)
                else:
                    predict_thr = np.atleast_1d(predict_thr)
            else:
                # NOTE(review): an ndarray whose ndim is not 1 (e.g. a 0-dim array built
                # from a scalar) skips this branch and is passed through unchanged -
                # confirm that is intended.
                if predict_thr.ndim == 1:
                    if predict_thr[0] == 0:
                        # multi-label predict_thr array scalar 0 force to 0.5 default
                        predict_thr = np.atleast_1d(0.5)
                    else:
                        # multi-label predict_thr array set to scalar = array[0]
                        predict_thr = predict_thr[0]
    else:
        # single-label mode, force argmax mode
        predict_thr = np.atleast_1d(0)

    if predict_thr == -1:
        thrpr, _, _, _ = calc_optimal_thresholds(truth, predict, timesteps)
        predict_thr = np.atleast_1d(thrpr)
        predict_thr = np.maximum(predict_thr, 0.001)  # enforce lower limit
        predict_thr = np.minimum(predict_thr, 0.999)  # enforce upper limit
        predict_thr = predict_thr.round(2)

    # Summarize the mixture data
    num_mixtures = mixdb.num_mixtures
    total_samples = sum(mixture.samples for mixture in mixdb.mixtures)
    duration = total_samples / SAMPLE_RATE

    logger.info('')
    logger.info(f'Mixtures: {num_mixtures}')
    logger.info(f'Duration: {seconds_to_hms(seconds=duration)}')
    logger.info(f'truth: {human_readable_size(truth.nbytes, 1)}')
    logger.info(f'predict: {human_readable_size(predict.nbytes, 1)}')
    if compute_segsnr:
        logger.info(f'segsnr: {human_readable_size(segsnr.nbytes, 1)}')
    # Identity test: the truth value of a multi-element array is ambiguous and raises.
    if feature is not None:
        logger.info(f'feature: {human_readable_size(feature.nbytes, 1)}')

    logger.info(f'Classes: {num_classes}')
    if mixdb.truth_mutex:
        logger.info('Mode: Single-label / truth_mutex / softmax')
    else:
        logger.info('Mode: Multi-label / Binary')

    mxid_snro = get_mixids_from_snr(mixdb=mixdb)
    snrlist = list(mxid_snro.keys())
    snrlist.sort(reverse=True)
    logger.info(f'Ordered SNRs: {snrlist}\n')
    predict_thr_info = predict_thr.transpose() if isinstance(predict_thr, np.ndarray) else predict_thr
    logger.info(f'Prediction Threshold(s): {predict_thr_info}\n')

    # Top-level report over all mixtures
    _, microdf, _, mxid_snro = snr_summary(mixdb=mixdb,
                                           mixid=':',
                                           truth_f=truth,
                                           predict=predict,
                                           segsnr=segsnr if compute_segsnr else None,
                                           predict_thr=predict_thr)

    if num_classes > 1:
        logger.info(f'Metrics micro-avg per SNR over all {num_mixtures} mixtures:')
    else:
        logger.info(f'Metrics per SNR over all {num_mixtures} mixtures:')
    logger.info(microdf.round(3).to_string())
    logger.info('')
    if output_dir:
        microdf.round(3).to_csv(join(output_dir, 'snr.csv'))

    if mixdb.truth_mutex:
        # Repeat the report with the last class column dropped (reported as the "Other" class).
        _, microdf, _, mxid_snro = snr_summary(mixdb=mixdb,
                                               mixid=':',
                                               truth_f=truth[:, 0:-1],
                                               predict=predict[:, 0:-1],
                                               segsnr=segsnr if compute_segsnr else None,
                                               predict_thr=predict_thr)

        logger.info(f'Metrics micro-avg without "Other" class per SNR over all {num_mixtures} mixtures:')
        logger.info(microdf.round(3).to_string())
        logger.info('')
        if output_dir:
            microdf.round(3).to_csv(join(output_dir, 'snrwo.csv'))

    # Per-class report for each SNR, highest SNR first
    for snri in snrlist:
        mxids = mxid_snro[snri]
        classdf = class_summary(mixdb, mxids, truth, predict, predict_thr)
        logger.info(f'Metrics per class for SNR {snri} over {len(mxids)} mixtures:')
        logger.info(classdf.round(3).to_string())
        logger.info('')
        if output_dir:
            classdf.round(3).to_csv(join(output_dir, f'class_snr{snri}.csv'))
184
-
185
-
186
def main() -> None:
    """Command-line entry point for ``sonusai evaluate``.

    Parses the command line against the module docstring, creates a dated output
    directory with an ``evaluate.log`` file handler, loads ``truth_f`` and ``segsnr``
    from the feature HDF5 file, reads the prediction data and runs ``evaluate``.

    Raises:
        SonusAIError: If neither candidate output directory can be created.
    """
    from datetime import datetime
    from os import mkdir
    from os.path import join

    import h5py
    from docopt import docopt

    import sonusai
    from sonusai import SonusAIError
    from sonusai import create_file_handler
    from sonusai.utils import read_predict_data
    from sonusai.utils import trim_docstring

    cli = docopt(trim_docstring(__doc__), version=sonusai.__version__, options_first=True)

    verbose = cli['--verbose']
    feature_name = cli['--feature']
    predict_name = cli['--predict']
    # The threshold is parsed as a single float and wrapped in a float32 array.
    predict_threshold = np.array(float(cli['--thr']), dtype=np.float32)
    location = cli['LOC']

    mixdb = MixtureDatabase(location)

    # Create the output directory: try the date-only name first and fall back to a
    # date-and-time name if that directory cannot be created.
    output_dir = 'evaluate-' + datetime.now().strftime('%Y%m%d')
    try:
        mkdir(output_dir)
    except OSError:
        output_dir = 'evaluate-' + datetime.now().strftime('%Y%m%d-%H%M%S')
        try:
            mkdir(output_dir)
        except OSError as error:
            raise SonusAIError(f'Could not create directory, {output_dir}: {error}')

    log_path = join(output_dir, 'evaluate.log')
    create_file_handler(log_path)

    with h5py.File(feature_name, 'r') as h5:
        truth_f = np.array(h5['truth_f'])
        segsnr = np.array(h5['segsnr'])

    predict = read_predict_data(predict_name)

    evaluate(mixdb=mixdb,
             truth=truth_f,
             segsnr=segsnr,
             output_dir=output_dir,
             predict=predict,
             predict_thr=predict_threshold,
             verbose=verbose)

    logger.info(f'Wrote results to {output_dir}')
238
-
239
-
240
if __name__ == '__main__':
    # Script entry point: a Ctrl-C is reported as a cancellation and the process
    # exits with status 0.
    try:
        main()
    except KeyboardInterrupt:
        logger.info('Canceled due to keyboard interrupt')
        raise SystemExit(0)