ChessAnalysisPipeline 0.0.17.dev3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- CHAP/TaskManager.py +216 -0
- CHAP/__init__.py +27 -0
- CHAP/common/__init__.py +57 -0
- CHAP/common/models/__init__.py +8 -0
- CHAP/common/models/common.py +124 -0
- CHAP/common/models/integration.py +659 -0
- CHAP/common/models/map.py +1291 -0
- CHAP/common/processor.py +2869 -0
- CHAP/common/reader.py +658 -0
- CHAP/common/utils.py +110 -0
- CHAP/common/writer.py +730 -0
- CHAP/edd/__init__.py +23 -0
- CHAP/edd/models.py +876 -0
- CHAP/edd/processor.py +3069 -0
- CHAP/edd/reader.py +1023 -0
- CHAP/edd/select_material_params_gui.py +348 -0
- CHAP/edd/utils.py +1572 -0
- CHAP/edd/writer.py +26 -0
- CHAP/foxden/__init__.py +19 -0
- CHAP/foxden/models.py +71 -0
- CHAP/foxden/processor.py +124 -0
- CHAP/foxden/reader.py +224 -0
- CHAP/foxden/utils.py +80 -0
- CHAP/foxden/writer.py +168 -0
- CHAP/giwaxs/__init__.py +11 -0
- CHAP/giwaxs/models.py +491 -0
- CHAP/giwaxs/processor.py +776 -0
- CHAP/giwaxs/reader.py +8 -0
- CHAP/giwaxs/writer.py +8 -0
- CHAP/inference/__init__.py +7 -0
- CHAP/inference/processor.py +69 -0
- CHAP/inference/reader.py +8 -0
- CHAP/inference/writer.py +8 -0
- CHAP/models.py +227 -0
- CHAP/pipeline.py +479 -0
- CHAP/processor.py +125 -0
- CHAP/reader.py +124 -0
- CHAP/runner.py +277 -0
- CHAP/saxswaxs/__init__.py +7 -0
- CHAP/saxswaxs/processor.py +8 -0
- CHAP/saxswaxs/reader.py +8 -0
- CHAP/saxswaxs/writer.py +8 -0
- CHAP/server.py +125 -0
- CHAP/sin2psi/__init__.py +7 -0
- CHAP/sin2psi/processor.py +8 -0
- CHAP/sin2psi/reader.py +8 -0
- CHAP/sin2psi/writer.py +8 -0
- CHAP/tomo/__init__.py +15 -0
- CHAP/tomo/models.py +210 -0
- CHAP/tomo/processor.py +3862 -0
- CHAP/tomo/reader.py +9 -0
- CHAP/tomo/writer.py +59 -0
- CHAP/utils/__init__.py +6 -0
- CHAP/utils/converters.py +188 -0
- CHAP/utils/fit.py +2947 -0
- CHAP/utils/general.py +2655 -0
- CHAP/utils/material.py +274 -0
- CHAP/utils/models.py +595 -0
- CHAP/utils/parfile.py +224 -0
- CHAP/writer.py +122 -0
- MLaaS/__init__.py +0 -0
- MLaaS/ktrain.py +205 -0
- MLaaS/mnist_img.py +83 -0
- MLaaS/tfaas_client.py +371 -0
- chessanalysispipeline-0.0.17.dev3.dist-info/LICENSE +60 -0
- chessanalysispipeline-0.0.17.dev3.dist-info/METADATA +29 -0
- chessanalysispipeline-0.0.17.dev3.dist-info/RECORD +70 -0
- chessanalysispipeline-0.0.17.dev3.dist-info/WHEEL +5 -0
- chessanalysispipeline-0.0.17.dev3.dist-info/entry_points.txt +2 -0
- chessanalysispipeline-0.0.17.dev3.dist-info/top_level.txt +2 -0
CHAP/utils/general.py
ADDED
|
@@ -0,0 +1,2655 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
# -*- coding: utf-8 -*-
|
|
3
|
+
"""
|
|
4
|
+
File : general.py
|
|
5
|
+
Author : Rolf Verberg <rolfverberg AT gmail dot com>
|
|
6
|
+
Description: A collection of general modules
|
|
7
|
+
"""
|
|
8
|
+
# RV write function that returns a list of peak indices for a given plot
|
|
9
|
+
# RV use raise_error concept on more functions
|
|
10
|
+
|
|
11
|
+
# System modules
|
|
12
|
+
from ast import literal_eval
|
|
13
|
+
import collections.abc
|
|
14
|
+
from logging import getLogger
|
|
15
|
+
import os
|
|
16
|
+
import re
|
|
17
|
+
import sys
|
|
18
|
+
|
|
19
|
+
# Third party modules
|
|
20
|
+
import numpy as np
|
|
21
|
+
try:
|
|
22
|
+
import matplotlib.pyplot as plt
|
|
23
|
+
except ImportError:
|
|
24
|
+
pass
|
|
25
|
+
|
|
26
|
+
# Module-level logger used by the validation helpers below to report
# illegal values and combinations.
logger = getLogger(__name__)

# Small positive constant: numpy's approximate decimal resolution of
# float64 (10**-precision). Used as the minimal magnitude in not_zero().
# pylint: disable=no-member
tiny = np.finfo(np.float64).resolution
# pylint: enable=no-member
|
|
31
|
+
|
|
32
|
+
def gformat(val, length=11):
    """
    Format a number with '%g'-like format, while:
      - the length of the output string will be of the requested length
      - positive numbers will have a leading blank
      - the precision will be as high as possible
      - trailing zeros will not be trimmed

    Adapted from the lmfit library.

    :param val: Value to format. `None`, booleans and values that do
        not support `abs()` are returned as their right-aligned `repr`.
    :type val: Any
    :param length: Requested field width (at least 7 is used),
        defaults to `11`.
    :type length: int, optional
    :return: The formatted value.
    :rtype: str
    """
    if val is None or isinstance(val, bool):
        return f'{repr(val):>{length}s}'
    try:
        # Silence numpy's divide-by-zero warning for val == 0; the
        # resulting int(-inf) OverflowError is handled below.
        with np.errstate(divide='ignore', invalid='ignore'):
            expon = int(np.log10(abs(val)))
    except (OverflowError, ValueError):
        # Zero, infinite or NaN values
        expon = 0
    except TypeError:
        return f'{repr(val):>{length}s}'

    length = max(length, 7)
    form = 'e'
    prec = length - 7
    if abs(expon) > 99:
        # Three-digit exponents take one extra character
        prec -= 1
    elif (0 < expon < prec+4) or (expon <= 0 and -expon < prec-1):
        # Fixed-point notation fits in the field with more precision
        form = 'f'
        prec += 4
        if expon > 0:
            prec -= expon
    return f'{val:{length}.{prec}{form}}'
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def getfloat_attr(obj, attr, length=11):
    """Return a printable string for attribute `attr` of `obj`.

    A missing or `None` attribute gives 'unknown', integers (including
    booleans) are formatted as-is, floats go through `gformat` (with
    surrounding blanks stripped) and anything else falls back to its
    `repr`. Adapted from the lmfit library.
    """
    value = getattr(obj, attr, None)
    if value is None:
        return 'unknown'
    if isinstance(value, int):
        return f'{value}'
    if not isinstance(value, float):
        return repr(value)
    return gformat(value, length=length).strip()
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def depth_list(_list):
    """Return the nesting depth of a list.

    A non-list returns `False` (which compares equal to 0), a flat or
    empty list returns 1, and every further level of nested lists adds
    one.
    """
    # default=0 keeps empty (sub)lists from crashing max()
    return isinstance(_list, list) and 1+max(
        map(depth_list, _list), default=0)
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def depth_tuple(_tuple):
    """Return the nesting depth of a tuple.

    A non-tuple returns `False` (which compares equal to 0), a flat or
    empty tuple returns 1, and every further level of nested tuples
    adds one.
    """
    # default=0 keeps empty (sub)tuples from crashing max()
    return isinstance(_tuple, tuple) and 1+max(
        map(depth_tuple, _tuple), default=0)
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def unwrap_tuple(_tuple):
    """Strip redundant single-element tuple wrappers.

    While the argument is a nested tuple (depth_tuple > 1) holding
    exactly one element, it is replaced by that element, e.g.
    `((1, 2),)` becomes `(1, 2)`. Anything else is returned unchanged.
    """
    while depth_tuple(_tuple) > 1 and len(_tuple) == 1:
        (_tuple,) = _tuple
    return _tuple
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def all_any(l, key):
    """Check for a common key in a list of dictionaries, looping
    at maximum only once over the entire list.

    :param l: Input list.
    :type l: list[dict]
    :param key: The common dictionary key.
    :type key: Any
    :return: `1` if every dictionary has `key`, `0` if none has it,
        or `-1` for a mix. Return `None` for a zero length input list.
    :rtype: Union[None, int]
    """
    state = None
    for entry in l:
        present = int(key in entry)
        if state is None:
            # The first dictionary fixes the expected outcome
            state = present
        elif present != state:
            # Mixed presence: no need to look any further
            return -1
    return state
|
|
120
|
+
|
|
121
|
+
def illegal_value(value, name, location=None, raise_error=False, log=True):
    """Report an illegal value by logging and/or raising an error.

    :param value: The offending value.
    :param name: Name of the offending variable; only used when it is
        a string.
    :param location: Name of the reporting function; only used when it
        is a string, defaults to `None`.
    :param raise_error: Raise a `ValueError` carrying the message,
        defaults to `False`.
    :param log: Log the message at error level, defaults to `True`.
    :raises ValueError: When `raise_error` is `True`.
    """
    location = f'in {location} ' if isinstance(location, str) else ''
    if not isinstance(name, str):
        error_msg = f'Illegal value {location}({value}, {type(value)})'
    else:
        error_msg = \
            f'Illegal value for {name} {location}({value}, {type(value)})'
    if log:
        logger.error(error_msg)
    if raise_error:
        raise ValueError(error_msg)
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def illegal_combination(
        value1, name1, value2, name2, location=None, raise_error=False,
        log=True):
    """Report an illegal combination of two values by logging and/or
    raising an error.

    :param value1: First offending value.
    :param name1: Name of the first variable; the names are only used
        when `name1` is a string.
    :param value2: Second offending value.
    :param name2: Name of the second variable.
    :param location: Name of the reporting function; only used when it
        is a string, defaults to `None`.
    :param raise_error: Raise a `ValueError` carrying the message,
        defaults to `False`.
    :param log: Log the message at error level, defaults to `True`.
    :raises ValueError: When `raise_error` is `True`.
    """
    location = f'in {location} ' if isinstance(location, str) else ''
    if not isinstance(name1, str):
        error_msg = f'Illegal combination {location}' \
            f'({value1}, {type(value1)} and {value2}, {type(value2)})'
    else:
        error_msg = \
            f'Illegal combination for {name1} and {name2} {location}' \
            f'({value1}, {type(value1)} and {value2}, {type(value2)})'
    if log:
        logger.error(error_msg)
    if raise_error:
        raise ValueError(error_msg)
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
def not_zero(value):
    """Return `value` as a float whose magnitude is at least the module
    constant `tiny`, keeping the sign of `value`.
    """
    magnitude = abs(value)
    if not magnitude > tiny:
        magnitude = tiny
    return float(np.copysign(magnitude, value))
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
def test_ge_gt_le_lt(
        ge, gt, le, lt, func, location=None, raise_error=False, log=True):
    """Check individual and mutual validity of ge, gt, le, lt
    qualifiers.

    Each given qualifier must pass `func`. At most one lower bound (ge
    or gt) and one upper bound (le or lt) may be given, and the lower
    bound must lie below the upper bound.

    :param ge: Inclusive lower bound.
    :param gt: Exclusive lower bound.
    :param le: Inclusive upper bound.
    :param lt: Exclusive upper bound.
    :param func: Test for integers or numbers; called with a `log`
        keyword argument.
    :type func: callable: is_int, is_num
    :param location: Name of the calling function used in error
        messages, defaults to `None`.
    :type location: str, optional
    :param raise_error: Raise a `ValueError` upon an invalid
        qualifier, defaults to `False`.
    :type raise_error: bool, optional
    :param log: Log an error message upon an invalid qualifier,
        defaults to `True`.
    :type log: bool, optional
    :return: True upon success or False when invalid or mutually
        exclusive.
    :rtype: bool
    """
    if ge is None and gt is None and le is None and lt is None:
        return True
    # func only probes the type here. Its own logging is silenced so
    # the caller's log flag is honored and the error is reported once,
    # by illegal_value, under the qualifier's actual name.
    if ge is not None:
        if not func(ge, log=False):
            illegal_value(ge, 'ge', location, raise_error, log)
            return False
        if gt is not None:
            illegal_combination(ge, 'ge', gt, 'gt', location, raise_error, log)
            return False
    elif gt is not None and not func(gt, log=False):
        illegal_value(gt, 'gt', location, raise_error, log)
        return False
    if le is not None:
        if not func(le, log=False):
            illegal_value(le, 'le', location, raise_error, log)
            return False
        if lt is not None:
            illegal_combination(le, 'le', lt, 'lt', location, raise_error, log)
            return False
    elif lt is not None and not func(lt, log=False):
        illegal_value(lt, 'lt', location, raise_error, log)
        return False
    # Mutual consistency of the (single) lower and upper bound
    if ge is not None:
        if le is not None and ge > le:
            illegal_combination(ge, 'ge', le, 'le', location, raise_error, log)
            return False
        if lt is not None and ge >= lt:
            illegal_combination(ge, 'ge', lt, 'lt', location, raise_error, log)
            return False
    elif gt is not None:
        if le is not None and gt >= le:
            illegal_combination(gt, 'gt', le, 'le', location, raise_error, log)
            return False
        if lt is not None and gt >= lt:
            illegal_combination(gt, 'gt', lt, 'lt', location, raise_error, log)
            return False
    return True
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
def range_string_ge_gt_le_lt(ge=None, gt=None, le=None, lt=None):
    """Return a range string representation matching the ge, gt, le, lt
    qualifiers, e.g. '[1, 5)' or '>= 0'. `ge` takes precedence over
    `gt`, and `le` over `lt`. Does not validate the inputs, do that as
    needed before calling.
    """
    has_lower = ge is not None or gt is not None
    has_upper = le is not None or lt is not None
    if ge is not None:
        lower = f'[{ge}, ' if has_upper else f'>= {ge}'
    elif gt is not None:
        lower = f'({gt}, ' if has_upper else f'> {gt}'
    else:
        lower = ''
    if le is not None:
        upper = f'{le}]' if has_lower else f'<= {le}'
    elif lt is not None:
        upper = f'{lt})' if has_lower else f'< {lt}'
    else:
        upper = ''
    return lower + upper
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
def is_int(v, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True):
    """Check that `v` is an integer within the optional bounds
    ge <= v <= le and/or gt < v < lt (any combination).

    :return: True if yes or False if no.
    :rtype: bool
    """
    return _is_int_or_num(
        v, 'int', ge=ge, gt=gt, le=le, lt=lt, raise_error=raise_error,
        log=log)
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
def is_num(v, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True):
    """Check that `v` is a number (int or float) within the optional
    bounds ge <= v <= le and/or gt < v < lt (any combination).

    :return: True if yes or False if no.
    :rtype: bool
    """
    return _is_int_or_num(
        v, 'num', ge=ge, gt=gt, le=le, lt=lt, raise_error=raise_error,
        log=log)
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
def _is_int_or_num(
        v, type_str, ge=None, gt=None, le=None, lt=None, raise_error=False,
        log=True):
    """Shared implementation of is_int (`type_str='int'`) and is_num
    (`type_str='num'`).

    Checks the type of `v`, then the validity of the bounds, and
    finally whether `v` lies within them. The first failure is logged
    and/or raised according to `log` and `raise_error`.

    :return: True if valid, False otherwise.
    :rtype: bool
    """
    location = '_is_int_or_num'
    if type_str == 'int':
        allowed_types, bound_check = int, is_int
    elif type_str == 'num':
        allowed_types, bound_check = (int, float), is_num
    else:
        illegal_value(type_str, 'type_str', location, raise_error, log)
        return False
    if not isinstance(v, allowed_types):
        illegal_value(v, 'v', location, raise_error, log)
        return False
    if not test_ge_gt_le_lt(
            ge, gt, le, lt, bound_check, location, raise_error, log):
        return False
    # Report only the first violated bound, in the order ge, gt, le, lt
    if ge is not None and v < ge:
        error_msg = f'Value {v} out of range: {v} !>= {ge}'
    elif gt is not None and v <= gt:
        error_msg = f'Value {v} out of range: {v} !> {gt}'
    elif le is not None and v > le:
        error_msg = f'Value {v} out of range: {v} !<= {le}'
    elif lt is not None and v >= lt:
        error_msg = f'Value {v} out of range: {v} !< {lt}'
    else:
        return True
    if log:
        logger.error(error_msg)
    if raise_error:
        raise ValueError(error_msg)
    return False
|
|
306
|
+
|
|
307
|
+
|
|
308
|
+
def is_int_pair(
        v, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True):
    """Check that `v` is a pair of integers. Each element must satisfy
    ge <= v[i] <= le and/or gt < v[i] < lt, where every bound is either
    a single value shared by both elements or a per-element pair.

    :return: True if yes or False if no.
    :rtype: bool
    """
    return _is_int_or_num_pair(
        v, 'int', ge=ge, gt=gt, le=le, lt=lt, raise_error=raise_error,
        log=log)
|
|
318
|
+
|
|
319
|
+
|
|
320
|
+
def is_num_pair(
        v, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True):
    """Check that `v` is a pair of numbers. Each element must satisfy
    ge <= v[i] <= le and/or gt < v[i] < lt, where every bound is either
    a single value shared by both elements or a per-element pair.

    :return: True if yes or False if no.
    :rtype: bool
    """
    return _is_int_or_num_pair(
        v, 'num', ge=ge, gt=gt, le=le, lt=lt, raise_error=raise_error,
        log=log)
|
|
330
|
+
|
|
331
|
+
|
|
332
|
+
def _is_int_or_num_pair(
        v, type_str, ge=None, gt=None, le=None, lt=None, raise_error=False,
        log=True):
    """Shared implementation of is_int_pair (`type_str='int'`) and
    is_num_pair (`type_str='num'`).

    Each bound may be a single value applied to both elements or a
    pair of per-element values.

    :return: True if valid, False otherwise.
    :rtype: bool
    """
    if type_str == 'int':
        if not (isinstance(v, (tuple, list)) and len(v) == 2
                and isinstance(v[0], int) and isinstance(v[1], int)):
            illegal_value(v, 'v', '_is_int_or_num_pair', raise_error, log)
            return False
        func = is_int
    elif type_str == 'num':
        if not (isinstance(v, (tuple, list)) and len(v) == 2
                and isinstance(v[0], (int, float))
                and isinstance(v[1], (int, float))):
            illegal_value(v, 'v', '_is_int_or_num_pair', raise_error, log)
            return False
        func = is_num
    else:
        illegal_value(
            type_str, 'type_str', '_is_int_or_num_pair', raise_error, log)
        return False
    if ge is None and gt is None and le is None and lt is None:
        return True
    # Expand every bound to a per-element pair, in the order ge, gt,
    # le, lt. The scalar probe must be silent: a pair bound is legal
    # and must not produce an error message.
    bounds = []
    for bound in (ge, gt, le, lt):
        if bound is None or func(bound, log=False):
            bounds.append(2*[bound])
        elif _is_int_or_num_pair(
                bound, type_str, raise_error=raise_error, log=log):
            bounds.append(bound)
        else:
            return False
    ge, gt, le, lt = bounds
    return (func(v[0], ge[0], gt[0], le[0], lt[0], raise_error, log)
            and func(v[1], ge[1], gt[1], le[1], lt[1], raise_error, log))
|
|
378
|
+
|
|
379
|
+
|
|
380
|
+
def is_int_series(
        t_or_l, ge=None, gt=None, le=None, lt=None, raise_error=False,
        log=True):
    """Check that `t_or_l` is a tuple or list of integers, each in range
    ge <= l[i] <= le and/or gt < l[i] < lt (any combination).

    :return: True if yes or False if no.
    :rtype: bool
    """
    bounds_ok = test_ge_gt_le_lt(
        ge, gt, le, lt, is_int, 'is_int_series', raise_error, log)
    if not bounds_ok:
        return False
    if isinstance(t_or_l, (tuple, list)):
        return all(
            is_int(item, ge, gt, le, lt, raise_error, log) for item in t_or_l)
    illegal_value(t_or_l, 't_or_l', 'is_int_series', raise_error, log)
    return False
|
|
395
|
+
|
|
396
|
+
|
|
397
|
+
def is_num_series(
        t_or_l, ge=None, gt=None, le=None, lt=None, raise_error=False,
        log=True):
    """Value is a tuple or list of numbers, each in range
    ge <= l[i] <= le or gt < l[i] < lt or some combination.

    :return: True if yes or False if no.
    :rtype: bool
    """
    # The bounds of a number series may themselves be any number, so
    # they are validated with is_num (not is_int) and errors are
    # reported under this function's own name.
    if not test_ge_gt_le_lt(
            ge, gt, le, lt, is_num, 'is_num_series', raise_error, log):
        return False
    if not isinstance(t_or_l, (tuple, list)):
        illegal_value(t_or_l, 't_or_l', 'is_num_series', raise_error, log)
        return False
    if any(not is_num(v, ge, gt, le, lt, raise_error, log) for v in t_or_l):
        return False
    return True
|
|
412
|
+
|
|
413
|
+
|
|
414
|
+
def is_str_series(t_or_l, raise_error=False, log=True):
    """Return True if `t_or_l` is a tuple or list whose items are all
    strings; otherwise report it through illegal_value and return False.
    """
    if isinstance(t_or_l, (tuple, list)) and all(
            isinstance(item, str) for item in t_or_l):
        return True
    illegal_value(t_or_l, 't_or_l', 'is_str_series', raise_error, log)
    return False
|
|
421
|
+
|
|
422
|
+
def is_str_or_str_series(t_or_l, raise_error=False, log=True):
    """Return True if `t_or_l` is a string or a tuple or list of
    strings; otherwise report it through illegal_value and return False.
    """
    valid = isinstance(t_or_l, str) or (
        isinstance(t_or_l, (tuple, list))
        and all(isinstance(item, str) for item in t_or_l))
    if not valid:
        illegal_value(
            t_or_l, 't_or_l', 'is_str_or_str_series', raise_error, log)
    return valid
|
|
432
|
+
|
|
433
|
+
|
|
434
|
+
def is_dict_series(t_or_l, raise_error=False, log=True):
    """Return True if `t_or_l` is a tuple or list whose items are all
    dictionaries; otherwise report it through illegal_value and return
    False.
    """
    if isinstance(t_or_l, (tuple, list)) and all(
            isinstance(item, dict) for item in t_or_l):
        return True
    illegal_value(t_or_l, 't_or_l', 'is_dict_series', raise_error, log)
    return False
|
|
441
|
+
|
|
442
|
+
|
|
443
|
+
def is_dict_nums(d, raise_error=False, log=True):
    """Return True if `d` is a dictionary whose values are all single
    numbers (checked silently with is_num); otherwise report `d`
    through illegal_value and return False.
    """
    if isinstance(d, dict) and all(
            is_num(item, log=False) for item in d.values()):
        return True
    illegal_value(d, 'd', 'is_dict_nums', raise_error, log)
    return False
|
|
450
|
+
|
|
451
|
+
|
|
452
|
+
def is_dict_strings(d, raise_error=False, log=True):
    """Return True if `d` is a dictionary whose values are all strings;
    otherwise report `d` through illegal_value and return False.
    """
    if isinstance(d, dict) and all(
            isinstance(item, str) for item in d.values()):
        return True
    illegal_value(d, 'd', 'is_dict_strings', raise_error, log)
    return False
|
|
459
|
+
|
|
460
|
+
|
|
461
|
+
def is_index(v, ge=0, lt=None, raise_error=False, log=True):
    """Check that `v` is an array index in range ge <= v < lt.
    NOTE: the upper bound lt is NOT included!
    """
    if isinstance(lt, int) and lt <= ge:
        # An empty index range can never be satisfied
        illegal_combination(ge, 'ge', lt, 'lt', 'is_index', raise_error, log)
        return False
    return is_int(v, ge=ge, lt=lt, raise_error=raise_error, log=log)
|
|
471
|
+
|
|
472
|
+
|
|
473
|
+
def is_index_range(v, ge=0, le=None, lt=None, raise_error=False, log=True):
    """Value is an array index range in range ge <= v[0] <= v[1] <= le
    or ge <= v[0] <= v[1] < lt. NOTE le IS included!

    :param v: Candidate index range.
    :type v: tuple[int, int] or list[int]
    :param ge: Inclusive lower bound, `None` for no lower bound,
        defaults to `0`.
    :type ge: int, optional
    :param le: Inclusive upper bound (takes precedence over `lt` in
        the error message).
    :type le: int, optional
    :param lt: Exclusive upper bound.
    :type lt: int, optional
    :param raise_error: Raise a `ValueError` upon an invalid range,
        defaults to `False`.
    :type raise_error: bool, optional
    :param log: Log an error message upon an invalid range,
        defaults to `True`.
    :type log: bool, optional
    :return: True if yes or False if no.
    :rtype: bool
    """
    if not is_int_pair(v, raise_error=raise_error, log=log):
        return False
    if not test_ge_gt_le_lt(
            ge, None, le, lt, is_int, 'is_index_range', raise_error, log):
        return False
    if ((ge is not None and v[0] < ge) or v[0] > v[1]
            or (le is not None and v[1] > le)
            or (lt is not None and v[1] >= lt)):
        # Only mention the bounds that were actually given
        lower = '' if ge is None else f'{ge} <= '
        if le is not None:
            upper = f' <= {le}'
        elif lt is not None:
            upper = f' < {lt}'
        else:
            upper = ''
        error_msg = (f'Value {v} out of range: '
                     f'!({lower}{v[0]} <= {v[1]}{upper})')
        if log:
            logger.error(error_msg)
        if raise_error:
            raise ValueError(error_msg)
        return False
    return True
|
|
496
|
+
|
|
497
|
+
|
|
498
|
+
def index_nearest(a, value):
    """Return the index of the array value nearest to `value`.

    Ties are broken towards the larger array value ("round half up")
    for negative as well as positive targets, by nudging `value` up by
    a relative machine epsilon before the search. A target of exactly
    zero is not nudged.

    :param a: One-dimensional array-like of values.
    :type a: array-like
    :param value: Target value.
    :type value: int or float
    :raises ValueError: If `a` has more than one dimension.
    :return: Index into `a` of the nearest value.
    :rtype: int
    """
    a = np.asarray(a)
    if a.ndim > 1:
        raise ValueError(
            f'Invalid array dimension for parameter a ({a.ndim}, {a})')
    # Round up for .5: shift towards +inf whatever the sign of value (a
    # plain multiplication would round negative ties down). Rebind
    # instead of using *= so a mutable argument is never modified.
    value = value + abs(value)*sys.float_info.epsilon
    return int(np.argmin(np.abs(a-value)))
|
|
507
|
+
|
|
508
|
+
|
|
509
|
+
def index_nearest_down(a, value):
    """Return the index of the array value nearest to `value`, rounded
    down: when the nearest value lies above `value`, the previous index
    is returned, unless already at the first element. This presumably
    expects `a` to be sorted in ascending order.
    """
    values = np.asarray(a)
    if values.ndim > 1:
        raise ValueError(
            f'Invalid array dimension for parameter a ({values.ndim}, {values})')
    nearest = int(np.abs(values - value).argmin())
    step_down = value < values[nearest] and nearest > 0
    return nearest - 1 if step_down else nearest
|
|
519
|
+
|
|
520
|
+
|
|
521
|
+
def index_nearest_up(a, value):
    """Return the index of the array value nearest to `value`, rounded
    up: when the nearest value lies below `value`, the next index is
    returned, unless already at the last element. This presumably
    expects `a` to be sorted in ascending order.
    """
    values = np.asarray(a)
    if values.ndim > 1:
        raise ValueError(
            f'Invalid array dimension for parameter a ({values.ndim}, {values})')
    nearest = int(np.abs(values - value).argmin())
    step_up = value > values[nearest] and nearest < values.size-1
    return nearest + 1 if step_up else nearest
|
|
531
|
+
|
|
532
|
+
|
|
533
|
+
def get_consecutive_int_range(a):
    """Return a list of pairs of integers marking consecutive ranges
    of integers.

    e.g.: [5, 1, 2, 3, 7] -> [[1, 3], [5, 5], [7, 7]]

    :param a: Integers in any order; duplicates are allowed. The input
        is not modified.
    :type a: Iterable[int]
    :return: Sorted `[first, last]` pairs of the consecutive runs.
    :rtype: list[list[int]]
    """
    # Sort a copy: sorting in place would reorder the caller's list
    # (and fails for tuples).
    int_ranges = []
    for value in sorted(a):
        if int_ranges and value <= int_ranges[-1][1] + 1:
            # Extends (or duplicates the end of) the current run
            int_ranges[-1][1] = value
        else:
            int_ranges.append([value, value])
    return int_ranges
|
|
549
|
+
|
|
550
|
+
|
|
551
|
+
def round_to_n(x, n=1):
    """Round to a specific number of sig figs.

    :param x: Value to round.
    :type x: int or float
    :param n: Number of significant figures, defaults to `1`.
    :type n: int, optional
    :return: `x` rounded to `n` significant figures, of the same type
        as `x`.
    :rtype: type(x)
    """
    if x == 0:
        # log10(0) is undefined and zero needs no rounding; keep the
        # input type, as the non-zero branch does.
        return type(x)(0)
    return type(x)(round(x, n-1-int(np.floor(np.log10(abs(x))))))
|
|
556
|
+
|
|
557
|
+
|
|
558
|
+
def round_up_to_n(x, n=1):
    """Round up (away from zero) to a specific number of sig figs.

    :param x: Value to round.
    :type x: int or float
    :param n: Number of significant figures, defaults to `1`.
    :type n: int, optional
    :return: `x` rounded to `n` significant figures such that the
        magnitude is never decreased, of the same type as `x`.
    :rtype: type(x)
    """
    if x == 0:
        # Zero needs no rounding; it would also make x/x_round divide
        # by zero below.
        return type(x)(0)
    x_round = round_to_n(x, n)
    if abs(x/x_round) > 1.0:
        # round_to_n decreased the magnitude: add one unit in the last
        # significant place, in the direction of x's sign.
        x_round += np.sign(x) * 10**(np.floor(np.log10(abs(x)))+1-n)
    return type(x)(x_round)
|
|
564
|
+
|
|
565
|
+
|
|
566
|
+
def trunc_to_n(x, n=1):
    """Truncate (towards zero) to a specific number of sig figs.

    :param x: Value to truncate.
    :type x: int or float
    :param n: Number of significant figures, defaults to `1`.
    :type n: int, optional
    :return: `x` reduced to `n` significant figures such that the
        magnitude is never increased, of the same type as `x`.
    :rtype: type(x)
    """
    if x == 0:
        # Zero needs no truncation; it would also make x_round/x
        # divide by zero below.
        return type(x)(0)
    x_round = round_to_n(x, n)
    if abs(x_round/x) > 1.0:
        # round_to_n increased the magnitude: remove one unit in the
        # last significant place, towards zero.
        x_round -= np.sign(x) * 10**(np.floor(np.log10(abs(x)))+1-n)
    return type(x)(x_round)
|
|
572
|
+
|
|
573
|
+
|
|
574
|
+
def almost_equal(a, b, sig_figs):
    """Check if equal to within a certain number of significant digits.

    NOTE(review): the test is on the absolute difference: a - b,
    rounded to `sig_figs` significant figures, must be smaller than
    10**(1 - sig_figs) in magnitude. It is not scaled by the magnitude
    of `a` or `b`.

    :raises ValueError: If `a` or `b` is not a number.
    """
    if not (is_num(a) and is_num(b)):
        raise ValueError(
            f'Invalid value for a or b in almost_equal (a: {a}, {type(a)}, '
            f'b: {b}, {type(b)})')
    difference = round_to_n(a-b, sig_figs)
    return abs(difference) < pow(10, 1-sig_figs)
|
|
582
|
+
|
|
583
|
+
|
|
584
|
+
def string_to_list(
        s, split_on_dash=True, remove_duplicates=True, sort=True,
        raise_error=False):
    """Return a list of numbers by splitting/expanding a string on any
    combination of commas, whitespaces, or dashes (when
    split_on_dash=True).
    e.g: '1, 3, 5-8, 12 ' -> [1, 3, 5, 6, 7, 8, 12]

    :param s: Input string.
    :type s: str
    :param split_on_dash: Allow dashes in input string,
        defaults to `True`.
    :type split_on_dash: bool, optional
    :param remove_duplicates: Removes duplicates (may also change the
        order), defaults to `True`.
    :type remove_duplicates: bool, optional
    :param sort: Sort in ascending order, defaults to `True`.
    :type sort: bool, optional
    :param raise_error: Raise an exception upon any error,
        defaults to `False`.
    :type raise_error: bool, optional
    :raises ValueError: Upon an illegal input when `raise_error` is
        `True` (other parsing exceptions are re-raised as well).
    :return: Input list or none upon an illegal input.
    :rtype: list
    """
    if not isinstance(s, str):
        illegal_value(s, 's', 'string_to_list', raise_error)
        return None
    if not s.strip():
        return []
    # All parsing, including deduplication and sorting (which can fail
    # on unhashable or incomparable literals), shares one error policy.
    try:
        items = re.split(r'\s+,\s+|\s+,|,\s+|\s+|,', s.strip())
        if split_on_dash:
            l_of_i = []
            for item in items:
                bounds = [
                    literal_eval(x)
                    for x in re.split(r'\s+-\s+|\s+-|-\s+|\s+|-', item)]
                if len(bounds) == 1:
                    l_of_i += bounds
                elif len(bounds) == 2 and bounds[1] > bounds[0]:
                    l_of_i += list(range(bounds[0], 1+bounds[1]))
                else:
                    raise ValueError(f'Invalid range: {item}')
        else:
            l_of_i = [literal_eval(x) for x in items]
        if remove_duplicates:
            l_of_i = list(dict.fromkeys(l_of_i))
        if sort:
            l_of_i = sorted(l_of_i)
    except (ValueError, TypeError, SyntaxError, MemoryError,
            RecursionError):
        if raise_error:
            raise
        return None
    return l_of_i
|
|
645
|
+
|
|
646
|
+
|
|
647
|
+
def list_to_string(a):
    """Return a string describing the consecutive integer ranges in
    `a`, e.g. [1, 2, 3, 5] -> '1-3, 5'. Single values are written on
    their own, and an empty input gives ''.
    """
    pieces = []
    for first, last in get_consecutive_int_range(a):
        pieces.append(f'{first}' if first == last else f'{first}-{last}')
    return ', '.join(pieces)
|
|
663
|
+
|
|
664
|
+
|
|
665
|
+
def get_trailing_int(string):
    """Return the integer formed by the digits at the end of `string`
    (a final newline is allowed after them), or `None` when there are
    none.
    """
    trailing_digits = re.search(r'\d+$', string)
    if trailing_digits is None:
        return None
    return int(trailing_digits.group())
|
|
672
|
+
|
|
673
|
+
|
|
674
|
+
def input_int(
        s=None, ge=None, gt=None, le=None, lt=None, default=None, inset=None,
        raise_error=False, log=True):
    """Interactively prompt the user to enter an integer, optionally
    within bounds, from an allowed set, or with a default for an empty
    entry.
    """
    value = _input_int_or_num(
        'int', s, ge, gt, le, lt, default, inset, raise_error, log)
    return value
|
|
680
|
+
|
|
681
|
+
|
|
682
|
+
def input_num(
        s=None, ge=None, gt=None, le=None, lt=None, default=None,
        raise_error=False, log=True):
    """Interactively prompt the user to enter a number, optionally
    within bounds or with a default for an empty entry.
    """
    # Numbers have no allowed-set option, hence inset=None
    value = _input_int_or_num(
        'num', s, ge, gt, le, lt, default, None, raise_error, log)
    return value
|
|
688
|
+
|
|
689
|
+
|
|
690
|
+
def _input_int_or_num(
        type_str, s=None, ge=None, gt=None, le=None, lt=None, default=None,
        inset=None, raise_error=False, log=True):
    """Interactively prompt the user to enter an integer or number.

    The arguments are validated once. The user is then prompted until
    a valid value is entered; an empty entry selects `default`.

    :param type_str: `'int'` for an integer, `'num'` for a number.
    :type type_str: str
    :param s: Prompt text, defaults to a generic prompt.
    :type s: str, optional
    :param ge: Inclusive lower bound.
    :param gt: Exclusive lower bound.
    :param le: Inclusive upper bound.
    :param lt: Exclusive upper bound.
    :param default: Value used for an empty entry; must satisfy the
        bounds.
    :param inset: Allowed integer values (entries outside it are
        rejected).
    :type inset: tuple or list, optional
    :param raise_error: Raise a `ValueError` upon invalid arguments,
        defaults to `False`.
    :type raise_error: bool, optional
    :param log: Log an error message upon invalid arguments,
        defaults to `True`.
    :type log: bool, optional
    :return: The entered value, or `None` upon invalid arguments.
    """
    location = '_input_int_or_num'
    if type_str == 'int':
        func = is_int
    elif type_str == 'num':
        func = is_num
    else:
        illegal_value(type_str, 'type_str', location, raise_error, log)
        return None
    if not test_ge_gt_le_lt(ge, gt, le, lt, func, location, raise_error, log):
        return None
    if default is not None:
        if not _is_int_or_num(
                default, type_str, raise_error=raise_error, log=log):
            return None
        if ge is not None and default < ge:
            illegal_combination(
                ge, 'ge', default, 'default', location, raise_error, log)
            return None
        if gt is not None and default <= gt:
            illegal_combination(
                gt, 'gt', default, 'default', location, raise_error, log)
            return None
        if le is not None and default > le:
            illegal_combination(
                le, 'le', default, 'default', location, raise_error, log)
            return None
        if lt is not None and default >= lt:
            illegal_combination(
                lt, 'lt', default, 'default', location, raise_error, log)
            return None
        default_string = f' [{default}]'
    else:
        default_string = ''
    if inset is not None:
        if (not isinstance(inset, (tuple, list))
                or any(not isinstance(i, int) for i in inset)):
            illegal_value(inset, 'inset', location, raise_error, log)
            return None
    v_range = range_string_ge_gt_le_lt(ge, gt, le, lt)
    if v_range:
        v_range = f' {v_range}'
    if s is not None:
        prompt = f'{s}{v_range}{default_string}: '
    elif type_str == 'int':
        prompt = f'Enter an integer{v_range}{default_string}: '
    else:
        prompt = f'Enter a number{v_range}{default_string}: '
    # Loop rather than recurse: repeated invalid entries can no longer
    # exhaust the recursion limit or re-run the argument validation.
    while True:
        print(prompt)
        try:
            i = input()
            if not i:
                v = default
                print(f'{v}')
            else:
                # literal_eval only evaluates Python literals
                v = literal_eval(i)
            if inset and v not in inset:
                raise ValueError(f'{v} not part of the set {inset}')
        except (ValueError, TypeError, SyntaxError, MemoryError,
                RecursionError):
            v = None
        if _is_int_or_num(v, type_str, ge, gt, le, lt):
            return v
|
|
764
|
+
|
|
765
|
+
|
|
766
|
+
def input_int_list(
        s=None, num_max=None, ge=None, le=None, split_on_dash=True,
        remove_duplicates=True, sort=True, raise_error=False, log=True):
    """Prompt the user to input a list of integers and split the
    entered string on any combination of commas, whitespaces, or
    dashes (when split_on_dash is True).
    e.g: '1 3,5-8 , 12 ' -> [1, 3, 5, 6, 7, 8, 12]

    :param s: Interactive user prompt, defaults to `None`.
    :type s: str, optional
    :param num_max: Maximum number of inputs in list.
    :type num_max: int, optional
    :param ge: Minimum value of inputs in list.
    :type ge: int, optional
    :param le: Maximum value of inputs in list.
    :type le: int, optional
    :param split_on_dash: Allow dashes in input string,
        defaults to `True`.
    :type split_on_dash: bool, optional
    :param remove_duplicates: Removes duplicates (may also change the
        order), defaults to `True`.
    :type remove_duplicates: bool, optional
    :param sort: Sort in ascending order, defaults to `True`.
    :type sort: bool, optional
    :param raise_error: Raise an exception upon any error,
        defaults to `False`.
    :type raise_error: bool, optional
    :param log: Print an error message upon any error,
        defaults to `True`.
    :type log: bool, optional
    :return: The entered list, or `None` if any of the arguments is
        invalid (an invalid user input triggers a new prompt).
    :rtype: list
    """
    return _input_int_or_num_list(
        'int', s, num_max, ge, le, split_on_dash, remove_duplicates, sort,
        raise_error, log)
|
|
802
|
+
|
|
803
|
+
|
|
804
|
+
def input_num_list(
        s=None, num_max=None, ge=None, le=None, remove_duplicates=True,
        sort=True, raise_error=False, log=True):
    """Prompt the user to input a list of numbers and split the entered
    string on any combination of commas or whitespaces.
    e.g: '1.0, 3, 5.8, 12 ' -> [1.0, 3.0, 5.8, 12.0]

    :param s: Interactive user prompt.
    :type s: str, optional
    :param num_max: Maximum number of inputs in list.
    :type num_max: int, optional
    :param ge: Minimum value of inputs in list.
    :type ge: float, optional
    :param le: Maximum value of inputs in list.
    :type le: float, optional
    :param remove_duplicates: Removes duplicates (may also change the
        order), defaults to `True`.
    :type remove_duplicates: bool, optional
    :param sort: Sort in ascending order, defaults to `True`.
    :type sort: bool, optional
    :param raise_error: Raise an exception upon any error,
        defaults to `False`.
    :type raise_error: bool, optional
    :param log: Print an error message upon any error,
        defaults to `True`.
    :type log: bool, optional
    :return: The entered list, or `None` if any of the arguments is
        invalid (an invalid user input triggers a new prompt).
    :rtype: list
    """
    # Dashes are never split on for numbers: '-' is a sign here
    return _input_int_or_num_list(
        'num', s, num_max, ge, le, False, remove_duplicates, sort, raise_error,
        log)
|
|
836
|
+
|
|
837
|
+
|
|
838
|
+
def _input_int_or_num_list(
        type_str, s=None, num_max=None, ge=None, le=None, split_on_dash=True,
        remove_duplicates=True, sort=True, raise_error=False, log=True):
    """Prompt the user for a list of integers or numbers, re-prompting
    until a valid list is entered.

    :param type_str: Type of the list items, `'int'` or `'num'`.
    :type type_str: str
    :param s: Interactive user prompt, defaults to
        `'Enter a series of integers'` (or `'numbers'`).
    :type s: str, optional
    :param num_max: Maximum number of inputs in list.
    :type num_max: int, optional
    :param ge: Minimum value of inputs in list.
    :type ge: int, float, optional
    :param le: Maximum value of inputs in list.
    :type le: int, float, optional
    :param split_on_dash: Allow dashes in input string,
        defaults to `True`.
    :type split_on_dash: bool, optional
    :param remove_duplicates: Removes duplicates (may also change the
        order), defaults to `True`.
    :type remove_duplicates: bool, optional
    :param sort: Sort in ascending order, defaults to `True`.
    :type sort: bool, optional
    :param raise_error: Raise an exception upon any argument error,
        defaults to `False`.
    :type raise_error: bool, optional
    :param log: Print an error message upon any argument error,
        defaults to `True`.
    :type log: bool, optional
    :return: The entered list, or `None` if any of the arguments is
        invalid.
    :rtype: list
    """
    # RV do we want a limit on max dimension?
    if type_str == 'int':
        if not test_ge_gt_le_lt(
                ge, None, le, None, is_int, 'input_int_or_num_list',
                raise_error, log):
            return None
    elif type_str == 'num':
        if not test_ge_gt_le_lt(
                ge, None, le, None, is_num, 'input_int_or_num_list',
                raise_error, log):
            return None
    else:
        illegal_value(
            type_str, 'type_str', '_input_int_or_num_list', raise_error, log)
        return None
    if (num_max is not None
            and not is_int(num_max, gt=0, raise_error=raise_error, log=log)):
        return None

    # Build the prompt and the error message once; they are reused on
    # every retry so the user always sees the same constraints
    v_range = f'{range_string_ge_gt_le_lt(ge=ge, le=le)}'
    if v_range:
        v_range = f' (each value in {v_range})'
    if s is None:
        kind = 'integers' if type_str == 'int' else 'numbers'
        prompt = f'Enter a series of {kind}{v_range}: '
    else:
        prompt = f'{s}{v_range}: '
    num = '' if num_max is None else f'up to {num_max} '
    if split_on_dash:
        error_msg = (f'Invalid input: enter a valid set of {num}dash/comma/'
                     'whitespace separated numbers e.g. 1 3,5-8 , 12')
    else:
        error_msg = (f'Invalid input: enter a valid set of {num}comma/'
                     'whitespace separated numbers e.g. 1 3,5 8 , 12')

    # Loop rather than recurse: keeps every argument in place on retry
    # and cannot exhaust the recursion limit
    while True:
        print(prompt)
        try:
            _list = string_to_list(
                input(), split_on_dash, remove_duplicates, sort)
        except (ValueError, TypeError, SyntaxError, MemoryError,
                RecursionError):
            _list = None
        except Exception as exc:
            raise Exception('Unexpected error') from exc
        if (isinstance(_list, list)
                and (num_max is None or len(_list) <= num_max)
                and all(_is_int_or_num(v, type_str, ge=ge, le=le)
                        for v in _list)):
            return _list
        print(error_msg)
|
|
886
|
+
|
|
887
|
+
|
|
888
|
+
def input_yesno(s=None, default=None):
    """Interactively prompt the user to answer a yes/no question.

    Answers are matched case-insensitively (surrounding whitespace is
    ignored) against `'y'`/`'yes'` and `'n'`/`'no'`. An empty answer
    selects `default` when one is given. The user is re-prompted until
    a valid answer is entered.

    :param s: Interactive user prompt, defaults to
        `'Enter yes or no'`.
    :type s: str, optional
    :param default: Default answer, one of `'y'`, `'yes'`, `'n'` or
        `'no'` (case-insensitive).
    :type default: str, optional
    :return: `True` for yes, `False` for no, or `None` if `default` is
        invalid.
    :rtype: bool
    """
    # Exact matches only: a substring test (`x in 'yes'`) would also
    # accept answers such as 'e', 's', 'es' or 'o'
    yes_answers = ('y', 'yes')
    no_answers = ('n', 'no')
    if default is not None:
        if not isinstance(default, str):
            illegal_value(default, 'default', 'input_yesno')
            return None
        if default.strip().lower() in yes_answers:
            default = 'y'
        elif default.strip().lower() in no_answers:
            default = 'n'
        else:
            illegal_value(default, 'default', 'input_yesno')
            return None
        default_string = f' [{default}]'
    else:
        default_string = ''
    if s is None:
        prompt = f'Enter yes or no{default_string}: '
    else:
        prompt = f'{s}{default_string}: '
    while True:
        print(prompt)
        answer = input().strip().lower()
        if not answer and default is not None:
            # Echo the default so the transcript shows what was chosen
            answer = default
            print(f'{answer}')
        if answer in yes_answers:
            return True
        if answer in no_answers:
            return False
        print('Invalid input, enter yes or no')
|
|
920
|
+
|
|
921
|
+
|
|
922
|
+
def input_menu(items, default=None, header=None):
    """Interactively prompt the user to select one item from a menu.

    The items are listed with 1-based numbers and the user enters the
    number of the chosen item. An empty answer selects `default` when
    one is given. The user is re-prompted until a valid choice is made.

    :param items: The (non-empty) menu items.
    :type items: list[str] or tuple[str]
    :param default: Default item, must be one of `items`.
    :type default: str, optional
    :param header: Header printed above the menu, defaults to
        `'Choose one of the following items'`.
    :type header: str, optional
    :return: The 0-based index of the selected item, or `None` if
        `items` or `default` is invalid.
    :rtype: int
    """
    # An empty menu can never be answered validly (would loop forever)
    if (not isinstance(items, (tuple, list)) or not items
            or any(not isinstance(i, str) for i in items)):
        illegal_value(items, 'items', 'input_menu')
        return None
    if default is not None:
        if not (isinstance(default, str) and default in items):
            logger.error(
                f'Invalid value for default ({default}), must be in {items}')
            return None
        default_index = items.index(default)
        default_string = f' [{1+default_index}]'
    else:
        default_index = None
        default_string = ''
    if header is None:
        header = 'Choose one of the following items'
    num_items = len(items)
    # Loop rather than recurse so the header is kept on every retry
    while True:
        print(f'{header} (1, {num_items}){default_string}:')
        for i, item in enumerate(items):
            print(f' {i+1}: {item}')
        answer = input().strip()
        if not answer:
            if default_index is not None:
                print(f'{1+default_index}')
                return default_index
            choice = None
        else:
            try:
                choice = literal_eval(answer)
            except (ValueError, TypeError, SyntaxError, MemoryError,
                    RecursionError):
                choice = None
            except Exception as exc:
                raise Exception('Unexpected error') from exc
        # bool is a subclass of int: reject True/False explicitly
        if (isinstance(choice, int) and not isinstance(choice, bool)
                and 1 <= choice <= num_items):
            return choice - 1
        print(f'Invalid choice, enter a number between 1 and {num_items}')
|
|
962
|
+
|
|
963
|
+
|
|
964
|
+
def assert_no_duplicates_in_list_of_dicts(_list, raise_error=False):
    """Assert that there are no duplicates in a list of dictionaries.

    Two dictionaries are duplicates when they compare equal, i.e. they
    map the same keys to equal values.

    :param _list: The list of dictionaries to check.
    :type _list: list[dict]
    :param raise_error: Raise a `ValueError` on duplicates instead of
        logging an error (also passed on to `illegal_value` for invalid
        input), defaults to `False`.
    :type raise_error: bool, optional
    :raises ValueError: If duplicates are found and `raise_error` is
        `True`.
    :return: The input list, or `None` if it is invalid or contains
        duplicates.
    :rtype: list[dict]
    """
    if not isinstance(_list, list):
        illegal_value(
            _list, '_list', 'assert_no_duplicates_in_list_of_dicts',
            raise_error)
        return None
    if any(not isinstance(d, dict) for d in _list):
        illegal_value(
            _list, '_list', 'assert_no_duplicates_in_list_of_dicts',
            raise_error)
        return None
    try:
        # Fast path: all keys mutually orderable and all items hashable
        num_unique = len({tuple(sorted(d.items())) for d in _list})
    except TypeError:
        # Unhashable values (e.g. lists) or keys of mixed, unorderable
        # types: fall back on (quadratic) dictionary equality
        unique = []
        for d in _list:
            if d not in unique:
                unique.append(d)
        num_unique = len(unique)
    if num_unique != len(_list):
        if raise_error:
            raise ValueError(f'Duplicate items found in {_list}')
        logger.error(f'Duplicate items found in {_list}')
        return None
    return _list
|
|
984
|
+
|
|
985
|
+
|
|
986
|
+
def assert_no_duplicate_key_in_list_of_dicts(_list, key, raise_error=False):
    """Assert that every dictionary in a list has a value for `key` and
    that these values are all distinct.

    :param _list: The list of dictionaries to check.
    :type _list: list[dict]
    :param key: The dictionary key whose values must be unique.
    :type key: str
    :param raise_error: Raise a `ValueError` on duplicate or missing
        keys instead of logging an error (also passed on to
        `illegal_value` for invalid input), defaults to `False`.
    :type raise_error: bool, optional
    :raises ValueError: If a key value is duplicated or missing and
        `raise_error` is `True`.
    :return: The input list, or `None` if it is invalid or fails the
        check.
    :rtype: list[dict]
    """
    if not isinstance(key, str):
        illegal_value(
            key, 'key', 'assert_no_duplicate_key_in_list_of_dicts',
            raise_error)
        return None
    if not isinstance(_list, list):
        illegal_value(
            _list, '_list', 'assert_no_duplicate_key_in_list_of_dicts',
            raise_error)
        return None
    # Every item must be a dictionary
    if any(not isinstance(d, dict) for d in _list):
        illegal_value(
            _list, '_list', 'assert_no_duplicate_key_in_list_of_dicts',
            raise_error)
        return None
    # A missing key and an explicit None value are both treated as missing
    keys = [d.get(key, None) for d in _list]
    if None in keys or len(set(keys)) != len(_list):
        if raise_error:
            raise ValueError(
                f'Duplicate or missing key ({key}) found in {_list}')
        logger.error(f'Duplicate or missing key ({key}) found in {_list}')
        return None
    return _list
|
|
1013
|
+
|
|
1014
|
+
|
|
1015
|
+
def assert_no_duplicate_attr_in_list_of_objs(_list, attr, raise_error=False):
    """Assert that every object in a list has a value for attribute
    `attr` and that these values are all distinct.

    :param _list: The list of objects to check.
    :type _list: list
    :param attr: The attribute name whose values must be unique.
    :type attr: str
    :param raise_error: Raise a `ValueError` on duplicate or missing
        attributes instead of logging an error (also passed on to
        `illegal_value` for invalid input), defaults to `False`.
    :type raise_error: bool, optional
    :raises ValueError: If an attribute value is duplicated or missing
        and `raise_error` is `True`.
    :return: The input list, or `None` if it is invalid or fails the
        check.
    :rtype: list
    """
    if not isinstance(attr, str):
        illegal_value(
            attr, 'attr', 'assert_no_duplicate_attr_in_list_of_objs',
            raise_error)
        return None
    if not isinstance(_list, list):
        illegal_value(
            _list, '_list', 'assert_no_duplicate_attr_in_list_of_objs',
            raise_error)
        return None
    # A missing attribute and an attribute set to None are both
    # treated as missing
    attrs = [getattr(obj, attr, None) for obj in _list]
    if None in attrs or len(set(attrs)) != len(_list):
        if raise_error:
            raise ValueError(
                f'Duplicate or missing attr ({attr}) found in {_list}')
        logger.error(f'Duplicate or missing attr ({attr}) found in {_list}')
        return None
    return _list
|
|
1037
|
+
|
|
1038
|
+
|
|
1039
|
+
def file_exists_and_readable(f):
    """Validate that `f` names an existing, readable regular file.

    :param f: Path of the file to check.
    :type f: str
    :raises ValueError: If `f` is not a regular file or the current
        process lacks read permission on it.
    :return: `f`, unchanged.
    :rtype: str
    """
    if os.path.isfile(f):
        if os.access(f, os.R_OK):
            return f
        raise ValueError(f'{f} is not accessible for reading')
    raise ValueError(f'{f} is not a valid file')
|
|
1046
|
+
|
|
1047
|
+
|
|
1048
|
+
def rolling_average(
        y, x=None, dtype=None, start=0, end=None, width=None,
        stride=None, num=None, average=True, mode='valid',
        use_convolve=None):
    """Returns the rolling sum or average of an array over the last
    dimension.

    Any leading dimensions of `y` are flattened for the computation
    and restored in the result.

    :param y: Input data; the rolling window runs over its last
        dimension, which must have at least 2 elements.
    :type y: array-like
    :param x: Coordinates along the last dimension of `y` (same
        length); when given, they are rolled with the same windows and
        always averaged, independent of `average`.
    :type x: array-like, optional
    :param dtype: Output data type, defaults to the dtype of `y` when
        `average` is `True` and to `numpy.float32` otherwise. It is
        only applied when `average` is `True` or, on the convolution
        path, when `mode` is `'full'`.
    :type dtype: numpy dtype, optional
    :param start: First index along the last dimension, defaults to
        `0`.
    :type start: int, optional
    :param end: End index (exclusive) along the last dimension,
        defaults to its full size.
    :type end: int, optional
    :param width: Window width.
    :type width: int, optional
    :param stride: Step between successive windows.
    :type stride: int, optional
    :param num: Number of windows, used to derive `width` and/or
        `stride` when those are not given; the number of windows
        actually returned is recomputed from `width`, `stride` and
        `mode`. At least one of `width`, `stride` or `num` is required.
    :type num: int, optional
    :param average: Return the rolling average instead of the rolling
        sum, defaults to `True`.
    :type average: bool, optional
    :param mode: `'valid'` returns only complete windows; `'full'`
        also returns trailing partial windows, whose sums are scaled
        up to the full window width (averages are over the available
        samples). Defaults to `'valid'`.
    :type mode: str, optional
    :param use_convolve: Compute with `numpy.convolve` instead of an
        explicit loop over windows, defaults to `True` for a 1D `y`
        and `False` otherwise.
    :type use_convolve: bool, optional
    :raises ValueError: For invalid input parameters.
    :return: The rolling sum/average, followed by the rolled `x` when
        `x` is given.
    :rtype: numpy.ndarray or tuple(numpy.ndarray, numpy.ndarray)
    """
    y = np.asarray(y)
    y_shape = y.shape
    # Work on a 2D (rows, last dimension) view; y_shape restores it
    if y.ndim == 1:
        y = np.expand_dims(y, 0)
    else:
        y = y.reshape((np.prod(y.shape[0:-1]), y.shape[-1]))
    if x is not None:
        x = np.asarray(x)
        if x.ndim != 1:
            raise ValueError('Parameter "x" must be a 1D array-like')
        if x.size != y.shape[1]:
            raise ValueError(f'Dimensions of "x" and "y[1]" do not '
                             f'match ({x.size} vs {y.shape[1]})')
    # NOTE(review): averages default to the input dtype (integer input
    # truncates the average), sums to float32 -- confirm intended
    if dtype is None:
        if average:
            dtype = y.dtype
        else:
            dtype = np.float32
    if width is None and stride is None and num is None:
        raise ValueError('Invalid input parameters, specify at least one of '
                         '"width", "stride" or "num"')
    if width is not None and not is_int(width, ge=1):
        raise ValueError(f'Invalid "width" parameter ({width})')
    if stride is not None and not is_int(stride, ge=1):
        raise ValueError(f'Invalid "stride" parameter ({stride})')
    if num is not None and not is_int(num, ge=1):
        raise ValueError(f'Invalid "num" parameter ({num})')
    if not isinstance(average, bool):
        raise ValueError(f'Invalid "average" parameter ({average})')
    if mode not in ('valid', 'full'):
        raise ValueError(f'Invalid "mode" parameter ({mode})')
    size = y.shape[1]
    if size < 2:
        raise ValueError(f'Invalid y[1] dimension ({size})')
    if not is_int(start, ge=0, lt=size):
        raise ValueError(f'Invalid "start" parameter ({start})')
    if end is None:
        end = size
    elif not is_int(end, gt=start, le=size):
        raise ValueError(f'Invalid "end" parameter ({end})')
    if use_convolve is None:
        use_convolve = bool(len(y_shape) == 1)
    # The convolution path slices the [start, end) range up front; the
    # loop path keeps y whole and offsets its window indices by start
    if use_convolve and (start or end < size):
        y = np.take(y, range(start, end), axis=1)
        if x is not None:
            x = x[start:end]
        size = y.shape[1]
    else:
        size = end-start

    # Resolve width and stride from whichever of width/stride/num given
    if stride is None:
        if width is None:
            width = max(1, int(size/num))
            stride = width
        else:
            width = min(width, size)
            if num is None:
                stride = width
            else:
                # NOTE(review): num == 1 passes the check above but
                # raises ZeroDivisionError here -- confirm handling
                stride = max(1, int((size-width) / (num-1)))
    else:
        # NOTE(review): for stride >= size this yields stride <= 0,
        # which breaks the divisions below; possibly min(stride, size)
        # was intended -- confirm
        stride = min(stride, size-stride)
        if width is None:
            width = stride

    # Number of output windows: only complete ones for 'valid', also
    # the trailing partial ones (ceil(size/stride)) for 'full'
    if mode == 'valid':
        num = 1 + max(0, int((size-width) / stride))
    else:
        num = int(size/stride)
        if num*stride < size:
            num += 1

    if use_convolve:
        # weight[n]: number of samples in window n (smaller than width
        # only for trailing partial windows in 'full' mode)
        n_start = 0
        n_end = width
        weight = np.empty((num))
        for n in range(num):
            n_num = n_end-n_start
            weight[n] = n_num
            n_start += stride
            n_end = min(size, n_end+stride)

        # Index width-1 of the full convolution is the sum over the
        # first window; the 'valid' stop index 1-width drops the
        # windows that run past the end of the data.
        # NOTE(review): for width == 1 that stop index is 0, so the
        # 'valid' slices below are empty -- confirm
        window = np.ones((width))
        if x is not None:
            if mode == 'valid':
                rx = np.convolve(x, window)[width-1:1-width:stride]
            else:
                rx = np.convolve(x, window)[width-1::stride]
            rx /= weight

        ry = []
        if mode == 'valid':
            for i in range(y.shape[0]):
                ry.append(np.convolve(y[i], window)[width-1:1-width:stride])
        else:
            for i in range(y.shape[0]):
                ry.append(np.convolve(y[i], window)[width-1::stride])
        ry = np.reshape(ry, (*y_shape[0:-1], num))
        if len(y_shape) == 1:
            ry = np.squeeze(ry)
        if average:
            ry = (np.asarray(ry).astype(np.float32)/weight).astype(dtype)
        elif mode != 'valid':
            # Scale partial-window sums up to the full window width
            weight = np.where(weight < width, width/weight, 1.0)
            ry = (np.asarray(ry).astype(np.float32)*weight).astype(dtype)
    else:
        # Explicit loop over windows; ry is (window, row) and is
        # transposed back below
        ry = np.zeros((num, y.shape[0]), dtype=y.dtype)
        if x is not None:
            rx = np.zeros(num, dtype=x.dtype)
        n_start = start
        n_end = n_start+width
        for n in range(num):
            y_sum = np.sum(y[:,n_start:n_end], 1)
            n_num = n_end-n_start
            if n_num < width:
                # Scale a partial-window sum up to the full width.
                # NOTE(review): with an integer-dtype y this in-place
                # multiply by a float is expected to fail NumPy's
                # same_kind casting rule -- verify
                y_sum *= width/n_num
            ry[n] = y_sum
            if x is not None:
                rx[n] = np.sum(x[n_start:n_end])/n_num
            n_start += stride
            n_end = min(start+size, n_end+stride)
        ry = np.reshape(ry.T, (*y_shape[0:-1], num))
        if len(y_shape) == 1:
            ry = np.squeeze(ry)
        if average:
            # Partial sums were scaled to width above, so dividing by
            # width gives the mean over the available samples
            ry = (ry.astype(np.float32)/width).astype(dtype)

    if x is None:
        return ry
    return ry, rx
|
|
1185
|
+
|
|
1186
|
+
|
|
1187
|
+
def baseline_arPLS(
        y, mask=None, w=None, tol=1.e-8, lam=1.e6, max_iter=20,
        full_output=False):
    """Returns the smoothed baseline estimate of a spectrum.

    Based on S.-J. Baek, A. Park, Y.-J Ahn, and J. Choo,
    "Baseline correction using asymmetrically reweighted penalized
    least squares smoothing", Analyst, 2015,140, 250-257

    :param y: The spectrum.
    :type y: array-like
    :param mask: A boolean mask to apply to the spectrum before
        baseline construction; the returned baseline is zero where the
        mask is `False`.
    :type mask: array-like, optional
    :param w: The initial weights (allows a restart for additional
        iterations), one per (masked) data point, defaults to all ones.
    :type w: numpy.ndarray, optional
    :param tol: The convergence tolerance, defaults to `1.e-8`.
    :type tol: float, optional
    :param lam: The &lambda (smoothness) parameter (the balance
        between the residual of the data and the baseline and the
        smoothness of the baseline). The suggested range is between
        100 and 10^8, defaults to `10^6`.
    :type lam: float, optional
    :param max_iter: The maximum number of iterations,
        defaults to `20`.
    :type max_iter: int, optional
    :param full_output: Whether or not to also output the baseline
        corrected spectrum, the weights, the number of iterations and
        the error in the returned result, defaults to `False`.
    :type full_output: bool, optional
    :raises ValueError: For invalid input parameters.
    :return: The smoothed baseline, with optionally the baseline
        corrected spectrum, the weights (for the masked data points
        only when `mask` is given), the number of iterations and the
        error in the returned result.
    :rtype: numpy.ndarray [, numpy.ndarray, numpy.ndarray, int, float]
    """
    # With credit to: Daniel Casas-Orozco
    # https://stackoverflow.com/questions/29156532/python-baseline-correction-library
    # Third party modules
    from scipy.sparse import (
        spdiags,
        linalg,
    )

    if not is_num(tol, gt=0):
        raise ValueError(f'Invalid tol parameter ({tol})')
    if not is_num(lam, gt=0):
        raise ValueError(f'Invalid lam parameter ({lam})')
    if not is_int(max_iter, gt=0):
        raise ValueError(f'Invalid max_iter parameter ({max_iter})')
    if not isinstance(full_output, bool):
        raise ValueError(f'Invalid full_output parameter ({full_output})')
    y = np.asarray(y)
    if mask is not None:
        # np.asarray also accepts plain lists/tuples of flags
        mask = np.asarray(mask).astype(bool)
        y_org = y
        y = y[mask]
    num = y.size

    # Second-order difference operator and smoothness penalty matrix
    diag = np.ones((num-2))
    D = spdiags([diag, -2*diag, diag], [0, -1, -2], num, num-2)

    H = lam * D.dot(D.T)

    if w is None:
        w = np.ones(num)
    W = spdiags(w, 0, num, num)

    # Start above any tol so at least one iteration always runs (with
    # an initial error of 1, tol >= 1 left z undefined)
    error = np.inf
    num_iter = 0

    # Clip the exponent to avoid overflow in np.exp
    exp_max = int(np.log(sys.float_info.max))
    while error > tol and num_iter < max_iter:
        z = linalg.spsolve(W + H, W * y)
        d = y - z
        # Statistics of the negative residuals drive the reweighting
        dn = d[d < 0]

        m = np.mean(dn)
        s = np.std(dn)

        w_new = 1.0 / (1.0 + np.exp(
            np.clip(2.0 * (d - (2.0*s - m))/s, None, exp_max)))
        error = np.linalg.norm(w_new - w) / np.linalg.norm(w)
        num_iter += 1
        w = w_new
        W.setdiag(w)

    if mask is not None:
        # Scatter the masked baseline back to the full spectrum
        zz = np.zeros(y_org.size)
        zz[mask] = z
        z = zz
        if full_output:
            d = y_org - z
    if full_output:
        return z, d, w, num_iter, float(error)
    return z
|
|
1282
|
+
|
|
1283
|
+
|
|
1284
|
+
def fig_to_iobuf(fig, fileformat=None):
    """Return an in-memory object as a byte stream representation of
    a Matplotlib figure.

    :param fig: Matplotlib figure object.
    :type fig: matplotlib.figure.Figure
    :param fileformat: Valid Matplotlib saved figure file format,
        defaults to `'png'`. A format not supported by the figure's
        canvas silently falls back to `'png'`.
    :type fileformat: str, optional
    :return: Byte stream representation of the Matplotlib figure and
        the associated file format.
    :rtype: tuple[_io.BytesIO, str]
    """
    # System modules
    from io import BytesIO

    # Check against the canvas of the figure being saved: plt.gcf()
    # may be a different figure, or create a new one as a side effect
    if (fileformat is None
            or fileformat not in fig.canvas.get_supported_filetypes()):
        fileformat = 'png'

    buf = BytesIO()
    fig.savefig(buf, format=fileformat)
    return buf, fileformat
|
|
1309
|
+
|
|
1310
|
+
|
|
1311
|
+
def save_iobuf_fig(buf, filename, force_overwrite=False):
    """Save a byte stream representation of a Matplotlib figure to file.

    Missing parent directories of `filename` are created.

    :param buf: Byte stream representation of the Matplotlib figure.
    :type buf: _io.BytesIO
    :param filename: Filename (with an extension PIL can save to,
        case-insensitive).
    :type filename: str
    :param force_overwrite: Flag to allow `filename` to be overwritten
        if it already exists, defaults to `False`.
    :type force_overwrite: bool, optional
    :raises ValueError: If `filename` has no extension or one PIL
        cannot save to.
    :raises FileExistsError: If a file already exists and
        `force_overwrite` is `False`.
    """
    # Third party modules
    from PIL import Image

    # File extensions PIL is able to write
    exts = {ext for ext, fmt in Image.registered_extensions().items()
            if fmt in Image.SAVE}

    # Validate filename and extension
    _, ext = os.path.splitext(filename)
    if not ext or ext.lower() not in exts:
        raise ValueError(f'Invalid file format {ext[1:]}')
    filedir = os.path.dirname(filename)
    # dirname is '' for a bare filename: nothing to create then
    if filedir and not os.path.isdir(filedir):
        os.makedirs(filedir)
    if os.path.isfile(filename) and not force_overwrite:
        raise FileExistsError(f'{filename} already exists')

    # Write image to file, releasing the PIL image afterwards
    buf.seek(0)
    with Image.open(buf) as img:
        img.save(filename)
|
|
1344
|
+
|
|
1345
|
+
|
|
1346
|
+
def select_mask_1d(
|
|
1347
|
+
y, x=None, preselected_index_ranges=None, preselected_mask=None,
|
|
1348
|
+
title=None, xlabel=None, ylabel=None, min_num_index_ranges=None,
|
|
1349
|
+
max_num_index_ranges=None, interactive=True, filename=None,
|
|
1350
|
+
return_buf=False):
|
|
1351
|
+
"""Display a lineplot and have the user select a mask.
|
|
1352
|
+
|
|
1353
|
+
:param y: One-dimensional data array for which a mask will be
|
|
1354
|
+
constructed.
|
|
1355
|
+
:type y: numpy.ndarray
|
|
1356
|
+
:param x: x-coordinates of the reference data.
|
|
1357
|
+
:type x: numpy.ndarray, optional
|
|
1358
|
+
:param preselected_index_ranges: List of preselected index ranges
|
|
1359
|
+
to mask (bounds are inclusive).
|
|
1360
|
+
:type preselected_index_ranges: Union(list[tuple(int, int)],
|
|
1361
|
+
list[list[int]], list[tuple(float, float)], list[list[float]]),
|
|
1362
|
+
optional
|
|
1363
|
+
:param preselected_mask: Preselected boolean mask array.
|
|
1364
|
+
:type preselected_mask: numpy.ndarray, optional
|
|
1365
|
+
:param title: Title for the displayed figure.
|
|
1366
|
+
:type title: str, optional
|
|
1367
|
+
:param xlabel: Label for the x-axis of the displayed figure.
|
|
1368
|
+
:type xlabel: str, optional
|
|
1369
|
+
:param ylabel: Label for the y-axis of the displayed figure.
|
|
1370
|
+
:type ylabel: str, optional
|
|
1371
|
+
:param min_num_index_ranges: The minimum number of selected index
|
|
1372
|
+
ranges.
|
|
1373
|
+
:type min_num_index_ranges: int, optional
|
|
1374
|
+
:param max_num_index_ranges: The maximum number of selected index
|
|
1375
|
+
ranges.
|
|
1376
|
+
:type max_num_index_ranges: int, optional
|
|
1377
|
+
:param interactive: Show the plot and allow user interactions with
|
|
1378
|
+
the matplotlib figure, defaults to `True`.
|
|
1379
|
+
:type interactive: bool, optional
|
|
1380
|
+
:param filename: Save a .png of the plot to filename, defaults to
|
|
1381
|
+
`None`, in which case the plot is not saved.
|
|
1382
|
+
:type filename: str, optional
|
|
1383
|
+
:param return_buf: Return an in-memory object as a byte stream
|
|
1384
|
+
represention of the Matplotlib figure, defaults to `False`.
|
|
1385
|
+
:type return_buf: bool, optional
|
|
1386
|
+
:return: A byte stream represention of the Matplotlib figure if
|
|
1387
|
+
return_buf is `True` (`None` otherwise), a boolean mask array,
|
|
1388
|
+
and the list of selected index ranges.
|
|
1389
|
+
:rtype: Union[io.BytesIO, None], numpy.ndarray, list[list[int, int]]
|
|
1390
|
+
"""
|
|
1391
|
+
# Third party modules
|
|
1392
|
+
# pylint: disable=possibly-used-before-assignment
|
|
1393
|
+
if interactive or filename is not None or return_buf:
|
|
1394
|
+
from matplotlib.patches import Patch
|
|
1395
|
+
from matplotlib.widgets import Button, SpanSelector
|
|
1396
|
+
|
|
1397
|
+
def change_fig_title(title):
|
|
1398
|
+
if fig_title:
|
|
1399
|
+
fig_title[0].remove()
|
|
1400
|
+
fig_title.pop()
|
|
1401
|
+
fig_title.append(plt.figtext(*title_pos, title, **title_props))
|
|
1402
|
+
|
|
1403
|
+
def change_error_text(error):
|
|
1404
|
+
if error_texts:
|
|
1405
|
+
error_texts[0].remove()
|
|
1406
|
+
error_texts.pop()
|
|
1407
|
+
error_texts.append(plt.figtext(*error_pos, error, **error_props))
|
|
1408
|
+
|
|
1409
|
+
def get_selected_index_ranges(change_fnc=None, title=''):
|
|
1410
|
+
selected_index_ranges = sorted(
|
|
1411
|
+
[[index_nearest(x, span.extents[0]),
|
|
1412
|
+
index_nearest(x, span.extents[1])+1]
|
|
1413
|
+
for span in spans])
|
|
1414
|
+
if change_fnc is not None:
|
|
1415
|
+
if len(selected_index_ranges) > 1:
|
|
1416
|
+
change_fnc(
|
|
1417
|
+
f'{title}Selected ROIs: {selected_index_ranges}')
|
|
1418
|
+
elif selected_index_ranges:
|
|
1419
|
+
change_fnc(
|
|
1420
|
+
f'{title}Selected ROI: {tuple(selected_index_ranges[0])}')
|
|
1421
|
+
else:
|
|
1422
|
+
change_fnc(f'{title}Selected ROI: None')
|
|
1423
|
+
return selected_index_ranges
|
|
1424
|
+
|
|
1425
|
+
def add_span(event, xrange_init=None):
|
|
1426
|
+
"""Callback function for the "Add span" button."""
|
|
1427
|
+
if (max_num_index_ranges is not None
|
|
1428
|
+
and len(spans) >= max_num_index_ranges):
|
|
1429
|
+
change_error_text(
|
|
1430
|
+
'Exceeding max number of ranges, adjust an existing '
|
|
1431
|
+
'range or click "Reset"/"Confirm"')
|
|
1432
|
+
else:
|
|
1433
|
+
spans.append(
|
|
1434
|
+
SpanSelector(
|
|
1435
|
+
ax, select_span, 'horizontal', props=included_props,
|
|
1436
|
+
useblit=True, interactive=interactive,
|
|
1437
|
+
drag_from_anywhere=True, ignore_event_outside=True,
|
|
1438
|
+
grab_range=5))
|
|
1439
|
+
if xrange_init is None:
|
|
1440
|
+
xmin_init, xmax_init = min(x), 0.05*(max(x)-min(x))
|
|
1441
|
+
else:
|
|
1442
|
+
xmin_init, xmax_init = xrange_init
|
|
1443
|
+
spans[-1]._selection_completed = True
|
|
1444
|
+
spans[-1].extents = (xmin_init, xmax_init)
|
|
1445
|
+
spans[-1].onselect(xmin_init, xmax_init)
|
|
1446
|
+
plt.draw()
|
|
1447
|
+
|
|
1448
|
+
def select_span(xmin, xmax):
|
|
1449
|
+
"""Callback function for the SpanSelector widget."""
|
|
1450
|
+
combined_spans = True
|
|
1451
|
+
while combined_spans:
|
|
1452
|
+
combined_spans = False
|
|
1453
|
+
for i, span1 in enumerate(spans):
|
|
1454
|
+
for span2 in spans[i+1:]:
|
|
1455
|
+
if (span1.extents[1] >= span2.extents[0]
|
|
1456
|
+
and span1.extents[0] <= span2.extents[1]):
|
|
1457
|
+
change_error_text(
|
|
1458
|
+
'Combined overlapping spans in currently '
|
|
1459
|
+
'selected mask')
|
|
1460
|
+
span2.extents = (
|
|
1461
|
+
min(span1.extents[0], span2.extents[0]),
|
|
1462
|
+
max(span1.extents[1], span2.extents[1]))
|
|
1463
|
+
span1.set_visible(False)
|
|
1464
|
+
spans.remove(span1)
|
|
1465
|
+
combined_spans = True
|
|
1466
|
+
break
|
|
1467
|
+
if combined_spans:
|
|
1468
|
+
break
|
|
1469
|
+
get_selected_index_ranges(change_error_text)
|
|
1470
|
+
plt.draw()
|
|
1471
|
+
|
|
1472
|
+
def reset(event):
|
|
1473
|
+
"""Callback function for the "Reset" button."""
|
|
1474
|
+
if error_texts:
|
|
1475
|
+
error_texts[0].remove()
|
|
1476
|
+
error_texts.pop()
|
|
1477
|
+
for span in reversed(spans):
|
|
1478
|
+
span.set_visible(False)
|
|
1479
|
+
spans.remove(span)
|
|
1480
|
+
get_selected_index_ranges(change_error_text)
|
|
1481
|
+
plt.draw()
|
|
1482
|
+
|
|
1483
|
+
def confirm(event):
|
|
1484
|
+
"""Callback function for the "Confirm" button."""
|
|
1485
|
+
if (min_num_index_ranges is not None
|
|
1486
|
+
and len(spans) < min_num_index_ranges):
|
|
1487
|
+
change_error_text(
|
|
1488
|
+
f'Select at least {min_num_index_ranges} unique index ranges')
|
|
1489
|
+
plt.draw()
|
|
1490
|
+
else:
|
|
1491
|
+
if error_texts:
|
|
1492
|
+
error_texts[0].remove()
|
|
1493
|
+
error_texts.pop()
|
|
1494
|
+
get_selected_index_ranges(change_fig_title, title)
|
|
1495
|
+
plt.close()
|
|
1496
|
+
|
|
1497
|
+
def update_mask(mask, selected_index_ranges):
|
|
1498
|
+
"""Update the mask with the selected index ranges."""
|
|
1499
|
+
for min_, max_ in selected_index_ranges:
|
|
1500
|
+
mask = np.logical_or(
|
|
1501
|
+
mask,
|
|
1502
|
+
np.logical_and(x >= x[min_], x <= x[min(max_, num_data-1)]))
|
|
1503
|
+
return mask
|
|
1504
|
+
|
|
1505
|
+
def update_index_ranges(mask):
|
|
1506
|
+
"""Update the selected index ranges (where mask = True)."""
|
|
1507
|
+
selected_index_ranges = []
|
|
1508
|
+
for i, m in enumerate(mask):
|
|
1509
|
+
if m:
|
|
1510
|
+
if (not selected_index_ranges
|
|
1511
|
+
or isinstance(selected_index_ranges[-1], tuple)):
|
|
1512
|
+
selected_index_ranges.append(i)
|
|
1513
|
+
else:
|
|
1514
|
+
if (selected_index_ranges
|
|
1515
|
+
and isinstance(selected_index_ranges[-1], int)):
|
|
1516
|
+
selected_index_ranges[-1] = \
|
|
1517
|
+
(selected_index_ranges[-1], i-1)
|
|
1518
|
+
if (selected_index_ranges
|
|
1519
|
+
and isinstance(selected_index_ranges[-1], int)):
|
|
1520
|
+
selected_index_ranges[-1] = (selected_index_ranges[-1], num_data-1)
|
|
1521
|
+
return selected_index_ranges
|
|
1522
|
+
|
|
1523
|
+
# Check inputs
|
|
1524
|
+
y = np.asarray(y)
|
|
1525
|
+
if y.ndim > 1:
|
|
1526
|
+
raise ValueError(f'Invalid y dimension ({y.ndim})')
|
|
1527
|
+
num_data = y.size
|
|
1528
|
+
if x is None:
|
|
1529
|
+
x = np.arange(num_data)+0.5
|
|
1530
|
+
else:
|
|
1531
|
+
x = np.asarray(x, dtype=np.float64)
|
|
1532
|
+
if x.ndim > 1 or x.size != num_data:
|
|
1533
|
+
raise ValueError(f'Invalid x shape ({x.shape})')
|
|
1534
|
+
if not np.all(x[:-1] < x[1:]):
|
|
1535
|
+
raise ValueError('Invalid x: must be monotonically increasing')
|
|
1536
|
+
if title is None:
|
|
1537
|
+
title = ''
|
|
1538
|
+
else:
|
|
1539
|
+
title = f'{title}: '
|
|
1540
|
+
if preselected_index_ranges is None:
|
|
1541
|
+
preselected_index_ranges = []
|
|
1542
|
+
else:
|
|
1543
|
+
if not isinstance(preselected_index_ranges, list):
|
|
1544
|
+
raise ValueError('Invalid parameter preselected_index_ranges '
|
|
1545
|
+
f'({preselected_index_ranges})')
|
|
1546
|
+
if interactive or filename is not None or return_buf:
|
|
1547
|
+
index_ranges = []
|
|
1548
|
+
for v in preselected_index_ranges:
|
|
1549
|
+
if not is_num_pair(v):
|
|
1550
|
+
raise ValueError(
|
|
1551
|
+
'Invalid parameter preselected_index_ranges '
|
|
1552
|
+
f'({preselected_index_ranges})')
|
|
1553
|
+
index_ranges.append(
|
|
1554
|
+
(max(0, int(v[0])), min(num_data, int(v[1])-1)))
|
|
1555
|
+
preselected_index_ranges = index_ranges
|
|
1556
|
+
|
|
1557
|
+
# Setup the preselected mask and index ranges if provided
|
|
1558
|
+
if preselected_mask is not None:
|
|
1559
|
+
preselected_index_ranges = update_index_ranges(
|
|
1560
|
+
update_mask(
|
|
1561
|
+
np.copy(np.asarray(preselected_mask, dtype=bool)),
|
|
1562
|
+
preselected_index_ranges))
|
|
1563
|
+
|
|
1564
|
+
if not interactive and filename is None and not return_buf:
|
|
1565
|
+
|
|
1566
|
+
# Update the mask with the preselected index ranges
|
|
1567
|
+
selected_mask = update_mask(len(x)*[False], preselected_index_ranges)
|
|
1568
|
+
|
|
1569
|
+
return None, selected_mask, preselected_index_ranges
|
|
1570
|
+
|
|
1571
|
+
spans = []
|
|
1572
|
+
fig_title = []
|
|
1573
|
+
error_texts = []
|
|
1574
|
+
|
|
1575
|
+
# Setup the Matplotlib figure
|
|
1576
|
+
title_pos = (0.5, 0.95)
|
|
1577
|
+
title_props = {'fontsize': 'xx-large', 'horizontalalignment': 'center',
|
|
1578
|
+
'verticalalignment': 'bottom'}
|
|
1579
|
+
error_pos = (0.5, 0.90)
|
|
1580
|
+
error_props = {'fontsize': 'x-large', 'horizontalalignment': 'center',
|
|
1581
|
+
'verticalalignment': 'bottom'}
|
|
1582
|
+
excluded_props = {
|
|
1583
|
+
'facecolor': 'white', 'edgecolor': 'gray', 'linestyle': ':'}
|
|
1584
|
+
included_props = {
|
|
1585
|
+
'alpha': 0.5, 'facecolor': 'tab:blue', 'edgecolor': 'blue'}
|
|
1586
|
+
|
|
1587
|
+
fig, ax = plt.subplots(figsize=(11, 8.5))
|
|
1588
|
+
handles = ax.plot(x, y, color='k', label='Reference Data')
|
|
1589
|
+
handles.append(Patch(
|
|
1590
|
+
label='Excluded / unselected ranges', **excluded_props))
|
|
1591
|
+
handles.append(Patch(
|
|
1592
|
+
label='Included / selected ranges', **included_props))
|
|
1593
|
+
ax.legend(handles=handles)
|
|
1594
|
+
ax.set_xlabel(xlabel, fontsize='x-large')
|
|
1595
|
+
ax.set_ylabel(ylabel, fontsize='x-large')
|
|
1596
|
+
ax.set_xlim(x[0], x[-1])
|
|
1597
|
+
fig.subplots_adjust(bottom=0.0, top=0.85)
|
|
1598
|
+
|
|
1599
|
+
# Add the preselected index ranges
|
|
1600
|
+
for min_, max_ in preselected_index_ranges:
|
|
1601
|
+
add_span(None, xrange_init=(x[min_], x[min(max_, num_data-1)]))
|
|
1602
|
+
|
|
1603
|
+
if not interactive:
|
|
1604
|
+
|
|
1605
|
+
get_selected_index_ranges(change_fig_title, title)
|
|
1606
|
+
if error_texts:
|
|
1607
|
+
error_texts[0].remove()
|
|
1608
|
+
error_texts.pop()
|
|
1609
|
+
|
|
1610
|
+
else:
|
|
1611
|
+
|
|
1612
|
+
change_fig_title(f'{title}Click and drag to select ranges')
|
|
1613
|
+
get_selected_index_ranges(change_error_text)
|
|
1614
|
+
fig.subplots_adjust(bottom=0.2)
|
|
1615
|
+
|
|
1616
|
+
# Setup "Add span" button
|
|
1617
|
+
add_span_btn = Button(
|
|
1618
|
+
plt.axes([0.15, 0.05, 0.15, 0.075]), 'Add span')
|
|
1619
|
+
add_span_cid = add_span_btn.on_clicked(add_span)
|
|
1620
|
+
|
|
1621
|
+
# Setup "Reset" button
|
|
1622
|
+
reset_btn = Button(plt.axes([0.45, 0.05, 0.15, 0.075]), 'Reset')
|
|
1623
|
+
reset_cid = reset_btn.on_clicked(reset)
|
|
1624
|
+
|
|
1625
|
+
# Setup "Confirm" button
|
|
1626
|
+
confirm_btn = Button(plt.axes([0.75, 0.05, 0.15, 0.075]), 'Confirm')
|
|
1627
|
+
confirm_cid = confirm_btn.on_clicked(confirm)
|
|
1628
|
+
|
|
1629
|
+
# Show figure for user interaction
|
|
1630
|
+
plt.show()
|
|
1631
|
+
|
|
1632
|
+
# Disconnect all widget callbacks when figure is closed
|
|
1633
|
+
add_span_btn.disconnect(add_span_cid)
|
|
1634
|
+
reset_btn.disconnect(reset_cid)
|
|
1635
|
+
confirm_btn.disconnect(confirm_cid)
|
|
1636
|
+
|
|
1637
|
+
# ...and remove the buttons before returning the figure
|
|
1638
|
+
add_span_btn.ax.remove()
|
|
1639
|
+
reset_btn.ax.remove()
|
|
1640
|
+
confirm_btn.ax.remove()
|
|
1641
|
+
plt.subplots_adjust(bottom=0.0)
|
|
1642
|
+
|
|
1643
|
+
selected_index_ranges = get_selected_index_ranges()
|
|
1644
|
+
|
|
1645
|
+
# Update the mask with the currently selected index ranges
|
|
1646
|
+
selected_mask = update_mask(len(x)*[False], selected_index_ranges)
|
|
1647
|
+
|
|
1648
|
+
buf = None
|
|
1649
|
+
if filename is not None or return_buf:
|
|
1650
|
+
if interactive:
|
|
1651
|
+
if len(selected_index_ranges) > 1:
|
|
1652
|
+
title += f'Selected ROIs: {selected_index_ranges}'
|
|
1653
|
+
else:
|
|
1654
|
+
title += f'Selected ROI: {tuple(selected_index_ranges[0])}'
|
|
1655
|
+
fig_title[0]._text = title
|
|
1656
|
+
fig_title[0].set_in_layout(True)
|
|
1657
|
+
fig.tight_layout(rect=(0, 0, 1, 0.95))
|
|
1658
|
+
if filename is not None:
|
|
1659
|
+
fig.savefig(filename)
|
|
1660
|
+
if return_buf:
|
|
1661
|
+
buf = fig_to_iobuf(fig)
|
|
1662
|
+
plt.close()
|
|
1663
|
+
return buf, selected_mask, selected_index_ranges
|
|
1664
|
+
|
|
1665
|
+
|
|
1666
|
+
def select_roi_1d(
        y, x=None, preselected_roi=None, title=None, xlabel=None, ylabel=None,
        interactive=True, filename=None, return_buf=False):
    """Display a 1D plot and have the user select a single region
    of interest.

    Thin wrapper around `select_mask_1d` that restricts the selection
    to exactly one index range.

    :param y: One-dimensional data array for which a region of
        interest will be selected.
    :type y: numpy.ndarray
    :param x: x-coordinates of the data
    :type x: numpy.ndarray, optional
    :param preselected_roi: Preselected region of interest as a pair
        of integers, each between 0 and the size of `y`.
    :type preselected_roi: tuple(int, int), optional
    :param title: Title for the displayed figure.
    :type title: str, optional
    :param xlabel: Label for the x-axis of the displayed figure.
    :type xlabel: str, optional
    :param ylabel: Label for the y-axis of the displayed figure.
    :type ylabel: str, optional
    :param interactive: Show the plot and allow user interactions with
        the matplotlib figure, defaults to `True`.
    :type interactive: bool, optional
    :param filename: Save a .png of the plot to filename, defaults to
        `None`, in which case the plot is not saved.
    :type filename: str, optional
    :param return_buf: Return an in-memory object as a byte stream
        represention of the Matplotlib figure, defaults to `False`.
    :type return_buf: bool, optional
    :raises ValueError: If `y` is not one-dimensional or
        `preselected_roi` is not a valid index pair.
    :return: A byte stream represention of the Matplotlib figure if
        return_buf is `True` (`None` otherwise), and the selected
        region of interest (`None` if no region is selected).
    :rtype: Union[io.BytesIO, None], Union[tuple(int, int), None]
    """
    # Check inputs
    y = np.asarray(y)
    if y.ndim != 1:
        raise ValueError(f'Invalid image dimension ({y.ndim})')
    if preselected_roi is not None:
        if not is_int_pair(preselected_roi, ge=0, le=y.size, log=False):
            raise ValueError('Invalid parameter preselected_roi '
                             f'({preselected_roi})')
        # select_mask_1d takes a list of index ranges
        preselected_roi = [preselected_roi]

    buf, _, roi = select_mask_1d(
        y, x=x, preselected_index_ranges=preselected_roi, title=title,
        xlabel=xlabel, ylabel=ylabel, min_num_index_ranges=1,
        max_num_index_ranges=1, interactive=interactive, filename=filename,
        return_buf=return_buf)

    # select_mask_1d returns an empty list when no range is selected
    # (e.g. non-interactive without a preselection); report that as
    # None instead of failing on roi[0]
    if not roi:
        return buf, None
    return buf, tuple(roi[0])
|
|
1716
|
+
|
|
1717
|
+
def select_roi_2d(
        a, preselected_roi=None, title=None, title_a=None,
        row_label='row index', column_label='column index', interactive=True,
        filename=None, return_buf=False):
    """Display a 2D image and have the user select a single rectangular
    region of interest.

    :param a: Two-dimensional image data array for which a region of
        interest will be selected.
    :type a: numpy.ndarray
    :param preselected_roi: Preselected region of interest, in the
        same order as the returned ROI.
    :type preselected_roi: tuple(int, int, int, int), optional
    :param title: Title for the displayed figure.
    :type title: str, optional
    :param title_a: Title for the image of a.
    :type title_a: str, optional
    :param row_label: Label for the y-axis of the displayed figure,
        defaults to `row index`.
    :type row_label: str, optional
    :param column_label: Label for the x-axis of the displayed figure,
        defaults to `column index`.
    :type column_label: str, optional
    :param interactive: Show the plot and allow user interactions with
        the matplotlib figure, defaults to `True`.
    :type interactive: bool, optional
    :param filename: Save a .png of the plot to filename, defaults to
        `None`, in which case the plot is not saved.
    :type filename: str, optional
    :param return_buf: Return an in-memory object as a byte stream
        represention of the Matplotlib figure, defaults to `False`.
    :type return_buf: bool, optional
    :raises ValueError: If `a` is not two-dimensional or
        `preselected_roi` is not a series of four non-negative
        integers.
    :return: A byte stream represention of the Matplotlib figure if
        return_buf is `True` (`None` otherwise), and the selected
        region of interest as integer RectangleSelector extents
        `(xmin, xmax, ymin, ymax)`, i.e. column bounds followed by row
        bounds, or `None` when nothing is selected or the region is
        less than one pixel wide or high.
    :rtype: Union[io.BytesIO, None],
        Union[tuple(int, int, int, int), None]
    """
    # Third party modules
    # pylint: disable=possibly-used-before-assignment
    if interactive or filename is not None or return_buf:
        from matplotlib.widgets import Button, RectangleSelector

    def extents_to_roi(extents):
        """Convert selector extents to an integer ROI, or return None
        when the ROI is less than one pixel wide or high."""
        roi = tuple(int(v) for v in extents)
        if roi[1]-roi[0] < 1 or roi[3]-roi[2] < 1:
            return None
        return roi

    def change_fig_title(title):
        """Replace the figure title text."""
        if fig_title:
            fig_title[0].remove()
            fig_title.pop()
        fig_title.append(plt.figtext(*title_pos, title, **title_props))

    def change_subfig_title(error):
        """Replace the status text shown below the figure title."""
        if subfig_title:
            subfig_title[0].remove()
            subfig_title.pop()
        subfig_title.append(plt.figtext(*error_pos, error, **error_props))

    def clear_selection():
        """Hide the current selector and replace it with a fresh one."""
        rects[0].set_visible(False)
        rects.pop()
        # NOTE(review): the replacement selector uses
        # ignore_event_outside=False while the initial one uses True;
        # confirm this asymmetry is intended
        rects.append(
            RectangleSelector(
                ax, on_rect_select, props=rect_props,
                useblit=True, interactive=interactive, drag_from_anywhere=True,
                ignore_event_outside=False))

    def on_rect_select(eclick, erelease):
        """Callback function for the RectangleSelector widget."""
        if (not int(rects[0].extents[1]) - int(rects[0].extents[0])
                or not int(rects[0].extents[3]) - int(rects[0].extents[2])):
            clear_selection()
            change_subfig_title(
                'Selected ROI too small, try again')
        else:
            change_subfig_title(
                f'Selected ROI: {tuple(int(v) for v in rects[0].extents)}')
        plt.draw()

    def reset(event):
        """Callback function for the "Reset" button."""
        if subfig_title:
            subfig_title[0].remove()
            subfig_title.pop()
        clear_selection()
        plt.draw()

    def confirm(event):
        """Callback function for the "Confirm" button."""
        if subfig_title:
            subfig_title[0].remove()
            subfig_title.pop()
        change_fig_title(f'Selected ROI: {extents_to_roi(rects[0].extents)}')
        plt.close()

    # Check inputs
    a = np.asarray(a)
    if a.ndim != 2:
        raise ValueError(f'Invalid image dimension ({a.ndim})')
    if preselected_roi is not None:
        if (not is_int_series(preselected_roi, ge=0, log=False)
                or len(preselected_roi) != 4):
            raise ValueError('Invalid parameter preselected_roi '
                             f'({preselected_roi})')
    if title is None:
        title = 'Click and drag to select or adjust a region of interest (ROI)'

    if not interactive and filename is None and not return_buf:
        # Nothing to show or save: return the preselection in the same
        # normalized form (int tuple or None) as the other code paths
        if preselected_roi is None:
            return None, None
        return None, extents_to_roi(preselected_roi)

    fig_title = []
    subfig_title = []

    title_pos = (0.5, 0.95)
    title_props = {'fontsize': 'xx-large', 'horizontalalignment': 'center',
                   'verticalalignment': 'bottom'}
    error_pos = (0.5, 0.90)
    error_props = {'fontsize': 'xx-large', 'horizontalalignment': 'center',
                   'verticalalignment': 'bottom'}
    rect_props = {
        'alpha': 0.5, 'facecolor': 'tab:blue', 'edgecolor': 'blue'}

    fig, ax = plt.subplots(figsize=(11, 8.5))
    ax.imshow(a)
    ax.set_title(title_a, fontsize='xx-large')
    ax.set_xlabel(column_label, fontsize='x-large')
    ax.set_ylabel(row_label, fontsize='x-large')
    ax.set_xlim(0, a.shape[1])
    ax.set_ylim(a.shape[0], 0)
    fig.subplots_adjust(bottom=0.0, top=0.85)

    # Setup the preselected range of interest if provided; the selector
    # is kept in a list so the callbacks can swap it out on reset
    rects = [RectangleSelector(
        ax, on_rect_select, props=rect_props, useblit=True,
        interactive=interactive, drag_from_anywhere=True,
        ignore_event_outside=True)]
    if preselected_roi is not None:
        rects[0].extents = preselected_roi

    if not interactive:

        if preselected_roi is not None:
            change_fig_title(
                f'Selected ROI: {tuple(int(v) for v in preselected_roi)}')

    else:

        change_fig_title(title)
        if preselected_roi is not None:
            change_subfig_title(
                f'Preselected ROI: {tuple(int(v) for v in preselected_roi)}')
        fig.subplots_adjust(bottom=0.2)

        # Setup "Reset" button
        reset_btn = Button(plt.axes([0.125, 0.05, 0.15, 0.075]), 'Reset')
        reset_cid = reset_btn.on_clicked(reset)

        # Setup "Confirm" button
        confirm_btn = Button(plt.axes([0.75, 0.05, 0.15, 0.075]), 'Confirm')
        confirm_cid = confirm_btn.on_clicked(confirm)

        # Show figure for user interaction
        plt.show()

        # Disconnect all widget callbacks when figure is closed
        reset_btn.disconnect(reset_cid)
        confirm_btn.disconnect(confirm_cid)

        # ... and remove the buttons before returning the figure
        reset_btn.ax.remove()
        confirm_btn.ax.remove()

    buf = None
    if filename is not None or return_buf:
        if fig_title:
            fig_title[0].set_in_layout(True)
            fig.tight_layout(rect=(0, 0, 1, 0.95))
        else:
            fig.tight_layout(rect=(0, 0, 1, 1))

        # Remove the handles (private RectangleSelector attributes, may
        # change between matplotlib versions)
        if interactive:
            rects[0]._center_handle.set_visible(False)
            rects[0]._corner_handles.set_visible(False)
            rects[0]._edge_handles.set_visible(False)
        if filename is not None:
            fig.savefig(filename)
        if return_buf:
            buf = fig_to_iobuf(fig)
    plt.close()

    return buf, extents_to_roi(rects[0].extents)
|
|
1911
|
+
|
|
1912
|
+
|
|
1913
|
+
def select_image_indices(
        a, axis, b=None, preselected_indices=None, axis_index_offset=0,
        min_range=None, min_num_indices=2, max_num_indices=2, title=None,
        title_a=None, title_b=None, row_label='row index',
        column_label='column index', interactive=True, return_buf=False):
    """Display a 2D image and have the user select a set of image
    indices in either row or column direction.

    Selected indices are marked with red horizontal (row) or vertical
    (column) lines. In interactive mode the indices are typed into a
    text box, and the selection is finished with the "Confirm" button.

    :param a: Two-dimensional image data array for which a set of
        indices will be selected.
    :type a: numpy.ndarray
    :param axis: The selection direction (0: row, 1: column)
    :type axis: int
    :param b: A secondary two-dimensional image data array, displayed
        alongside `a`, for which the same indices will be selected.
    :type b: numpy.ndarray, optional
    :param preselected_indices: Preselected image indices.
    :type preselected_indices: tuple(int), list(int), optional
    :param axis_index_offset: Offset in axis index range and
        preselected indices, defaults to `0`.
    :type axis_index_offset: int, optional
    :param min_range: The minimal range spanned by the selected
        indices.
    :type min_range: int, optional
    :param min_num_indices: The minimum number of selected indices,
        defaults to `2`.
    :type min_num_indices: int, optional
    :param max_num_indices: The maximum number of selected indices,
        defaults to `2`.
    :type max_num_indices: int, optional
    :param title: Title for the displayed figure.
    :type title: str, optional
    :param title_a: Title for the image of a.
    :type title_a: str, optional
    :param title_b: Title for the image of b.
    :type title_b: str, optional
    :param row_label: Label for the y-axis of the displayed figure,
        defaults to `row index`.
    :type row_label: str, optional
    :param column_label: Label for the x-axis of the displayed figure,
        defaults to `column index`.
    :type column_label: str, optional
    :param interactive: Show the plot and allow user interactions with
        the matplotlib figure, defaults to `True`.
    :type interactive: bool, optional
    :param return_buf: Return an in-memory object as a byte stream
        represention of the Matplotlib figure, defaults to `False`.
    :type return_buf: bool, optional
    :raises ValueError: For an invalid `a`, `axis`,
        `axis_index_offset`, `min_range` or `b`, and (when not
        interactive) for invalid `preselected_indices`.
    :return: A byte stream represention of the Matplotlib figure if
        return_buf is `True` (`None` otherwise), and the sorted
        selected indices (`None` if no index is selected).
    :rtype: Union[io.BytesIO, None], Union[tuple(int), None]
    """
    # Third party modules
    from matplotlib.widgets import TextBox, Button

    # Assigned once the figure exists; the callbacks read it lazily
    # through the closure
    index_input = None

    def change_fig_title(title):
        """Replace the figure title text."""
        if fig_title:
            fig_title[0].remove()
            fig_title.pop()
        fig_title.append(plt.figtext(*title_pos, title, **title_props))

    def change_error_text(error):
        """Replace the status/error text shown below the title."""
        if error_texts:
            error_texts[0].remove()
            error_texts.pop()
        error_texts.append(plt.figtext(*error_pos, error, **error_props))

    def get_selected_indices(change_fnc=None):
        """Return the sorted selected indices and, if change_fnc is
        given, report them (and any shortfall) through it."""
        selected_indices = tuple(sorted(indices))
        if change_fnc is not None:
            num_indices = len(indices)
            if len(selected_indices) > 1:
                text = f'Selected {row_column} indices: {selected_indices}'
            elif selected_indices:
                text = f'Selected {row_column} index: {selected_indices[0]}'
            else:
                text = f'Selected {row_column} indices: None'
            if min_num_indices is not None and num_indices < min_num_indices:
                if min_num_indices == max_num_indices:
                    text += \
                        f', select another {max_num_indices-num_indices}'
                else:
                    # The shortfall is relative to the minimum; the
                    # maximum may be larger or None
                    text += \
                        f', select at least {min_num_indices-num_indices} more'
            change_fnc(text)
        return selected_indices

    def add_index(index):
        """Add index to the selection and mark it on every axes.

        Raises ValueError if the index is a duplicate, exceeds the
        maximum number of indices, or violates min_range.
        """
        if index in indices:
            raise ValueError(f'Ignoring duplicate of selected {row_column}s')
        if max_num_indices is not None and len(indices) >= max_num_indices:
            raise ValueError(
                f'Exceeding maximum number of selected {row_column}s, click '
                'either "Reset" or "Confirm"')
        if (indices and min_range is not None
                and abs(max(index, *indices) - min(index, *indices))
                < min_range):
            raise ValueError(
                f'Selected {row_column} range is smaller than required '
                f'minimal range of {min_range}: ignoring last selection')
        indices.append(index)
        if not axis:
            for ax in axs:
                lines.append(ax.axhline(indices[-1], c='r', lw=2))
        else:
            for ax in axs:
                lines.append(ax.axvline(indices[-1], c='r', lw=2))

    def select_index(expression):
        """Callback function for the "Select row/column index" TextBox.
        """
        if not expression:
            return
        if error_texts:
            error_texts[0].remove()
            error_texts.pop()
        try:
            index = int(expression)
            # NOTE(review): offset+shape itself is accepted here (and by
            # the preselected_indices check), but the message below
            # advertises offset+shape-1 as the upper bound; confirm
            # which is intended
            if (index < axis_index_offset
                    or index > axis_index_offset+a.shape[axis]):
                raise ValueError
        except ValueError:
            change_error_text(
                f'Invalid {row_column} index ({expression}), enter an integer '
                f'between {axis_index_offset} and '
                f'{axis_index_offset+a.shape[axis]-1}')
        else:
            try:
                add_index(index)
                get_selected_indices(change_error_text)
            except ValueError as exc:
                change_error_text(str(exc))
        index_input.set_val('')
        for ax in axs:
            ax.get_figure().canvas.draw()

    def reset(event):
        """Callback function for the "Reset" button."""
        if error_texts:
            error_texts[0].remove()
            error_texts.pop()
        for line in reversed(lines):
            line.remove()
        indices.clear()
        lines.clear()
        get_selected_indices(change_error_text)
        for ax in axs:
            ax.get_figure().canvas.draw()

    def confirm(event):
        """Callback function for the "Confirm" button."""
        if min_num_indices is not None and len(indices) < min_num_indices:
            change_error_text(
                f'Select at least {min_num_indices} unique {row_column}s')
            for ax in axs:
                ax.get_figure().canvas.draw()
        else:
            # Remove error texts and add selected indices if set
            if error_texts:
                error_texts[0].remove()
                error_texts.pop()
            get_selected_indices(change_fig_title)
            plt.close()

    # Check inputs
    a = np.asarray(a)
    if a.ndim != 2:
        raise ValueError(f'Invalid image dimension ({a.ndim})')
    if axis < 0 or axis >= a.ndim:
        raise ValueError(f'Invalid parameter axis ({axis})')
    if not axis:
        row_column = 'row'
    else:
        row_column = 'column'
    if not is_int(axis_index_offset, ge=0, log=False):
        raise ValueError(
            f'Invalid parameter axis_index_offset ({axis_index_offset})')
    if preselected_indices is not None:
        if not is_int_series(
                preselected_indices, ge=axis_index_offset,
                le=axis_index_offset+a.shape[axis], log=False):
            if interactive:
                logger.warning(
                    'Invalid parameter preselected_indices '
                    f'({preselected_indices}), ignoring preselected_indices')
                preselected_indices = None
            else:
                raise ValueError('Invalid parameter preselected_indices '
                                 f'({preselected_indices})')
    if min_range is not None and not 2 <= min_range <= a.shape[axis]:
        raise ValueError(f'Invalid parameter min_range ({min_range})')
    if title is None:
        title = f'Select or adjust image {row_column} indices'
    if b is not None:
        b = np.asarray(b)
        if b.ndim != 2:
            raise ValueError(f'Invalid image dimension ({b.ndim})')
        if a.shape[0] != b.shape[0]:
            raise ValueError(f'Inconsistent image shapes({a.shape} vs '
                             f'{b.shape})')

    indices = []
    lines = []
    fig_title = []
    error_texts = []

    title_pos = (0.5, 0.95)
    title_props = {'fontsize': 'xx-large', 'horizontalalignment': 'center',
                   'verticalalignment': 'bottom'}
    error_pos = (0.5, 0.90)
    error_props = {'fontsize': 'x-large', 'horizontalalignment': 'center',
                   'verticalalignment': 'bottom'}
    if b is None:
        fig, axs = plt.subplots(figsize=(11, 8.5))
        axs = [axs]
    else:
        if a.shape[0]+b.shape[0] > max(a.shape[1], b.shape[1]):
            fig, axs = plt.subplots(1, 2, figsize=(11, 8.5))
        else:
            fig, axs = plt.subplots(2, 1, figsize=(11, 8.5))
    # Shift the selection axis by axis_index_offset so the drawn index
    # lines line up with the (offset) indices entered by the user
    if axis:
        extent = (axis_index_offset, axis_index_offset+a.shape[1],
                  a.shape[0], 0)
    else:
        extent = (0, a.shape[1], axis_index_offset+a.shape[0],
                  axis_index_offset)
    axs[0].imshow(a, extent=extent)
    axs[0].set_title(title_a, fontsize='xx-large')
    if b is not None:
        axs[1].imshow(b, extent=extent)
        axs[1].set_title(title_b, fontsize='xx-large')
        if a.shape[0]+b.shape[0] > max(a.shape[1], b.shape[1]):
            axs[0].set_xlabel(column_label, fontsize='x-large')
            axs[0].set_ylabel(row_label, fontsize='x-large')
            axs[1].set_xlabel(column_label, fontsize='x-large')
        else:
            axs[0].set_ylabel(row_label, fontsize='x-large')
            axs[1].set_xlabel(column_label, fontsize='x-large')
            axs[1].set_ylabel(row_label, fontsize='x-large')
    for ax in axs:
        ax.set_xlim(extent[0], extent[1])
        ax.set_ylim(extent[2], extent[3])
    fig.subplots_adjust(bottom=0.0, top=0.85)

    # Setup the preselected indices if provided
    if preselected_indices is not None:
        preselected_indices = sorted(list(preselected_indices))
        for index in preselected_indices:
            add_index(index)

    if not interactive:

        get_selected_indices(change_fig_title)

    else:

        change_fig_title(title)
        get_selected_indices(change_error_text)
        fig.subplots_adjust(bottom=0.2)

        # Setup TextBox
        index_input = TextBox(
            plt.axes([0.25, 0.05, 0.15, 0.075]), f'Select {row_column} index ')
        indices_cid = index_input.on_submit(select_index)

        # Setup "Reset" button
        reset_btn = Button(plt.axes([0.5, 0.05, 0.15, 0.075]), 'Reset')
        reset_cid = reset_btn.on_clicked(reset)

        # Setup "Confirm" button
        confirm_btn = Button(plt.axes([0.75, 0.05, 0.15, 0.075]), 'Confirm')
        confirm_cid = confirm_btn.on_clicked(confirm)

        plt.show()

        # Disconnect all widget callbacks when figure is closed
        index_input.disconnect(indices_cid)
        reset_btn.disconnect(reset_cid)
        confirm_btn.disconnect(confirm_cid)

        # ... and remove the buttons before returning the figure
        index_input.ax.remove()
        reset_btn.ax.remove()
        confirm_btn.ax.remove()

    fig_title[0].set_in_layout(True)
    fig.tight_layout(rect=(0, 0, 1, 0.95))

    if return_buf:
        buf = fig_to_iobuf(fig)
    else:
        buf = None
    plt.close()
    if indices:
        return buf, tuple(sorted(indices))
    return buf, None
|
|
2206
|
+
|
|
2207
|
+
|
|
2208
|
+
def quick_imshow(
|
|
2209
|
+
a, title=None, row_label='row index', column_label='column index',
|
|
2210
|
+
path=None, name=None, show_fig=True, save_fig=False,
|
|
2211
|
+
return_fig=False, block=None, extent=None, show_grid=False,
|
|
2212
|
+
grid_color='w', grid_linewidth=1, **kwargs):
|
|
2213
|
+
"""Display and or save a 2D image and or return an in-memory object
|
|
2214
|
+
as a byte stream represention.
|
|
2215
|
+
"""
|
|
2216
|
+
if title is not None and not isinstance(title, str):
|
|
2217
|
+
raise ValueError(f'Invalid parameter title ({title})')
|
|
2218
|
+
if path is not None and not isinstance(path, str):
|
|
2219
|
+
raise ValueError(f'Invalid parameter path ({path})')
|
|
2220
|
+
if not isinstance(show_fig, bool):
|
|
2221
|
+
raise ValueError(f'Invalid parameter show_fig ({show_fig})')
|
|
2222
|
+
if not isinstance(save_fig, bool):
|
|
2223
|
+
raise ValueError(f'Invalid parameter save_fig ({save_fig})')
|
|
2224
|
+
if not isinstance(return_fig, bool):
|
|
2225
|
+
raise ValueError(f'Invalid parameter return_fig ({return_fig})')
|
|
2226
|
+
if block is not None and not isinstance(block, bool):
|
|
2227
|
+
raise ValueError(f'Invalid parameter block ({block})')
|
|
2228
|
+
if not title:
|
|
2229
|
+
title = 'quick imshow'
|
|
2230
|
+
if ('cmap' in kwargs and a.ndim == 3
|
|
2231
|
+
and (a.shape[2] == 3 or a.shape[2] == 4)):
|
|
2232
|
+
use_cmap = True
|
|
2233
|
+
if a.shape[2] == 4 and a[:,:,-1].min() != a[:,:,-1].max():
|
|
2234
|
+
use_cmap = False
|
|
2235
|
+
if any(
|
|
2236
|
+
a[i,j,0] != a[i,j,1] and a[i,j,0] != a[i,j,2]
|
|
2237
|
+
for i in range(a.shape[0])
|
|
2238
|
+
for j in range(a.shape[1])):
|
|
2239
|
+
use_cmap = False
|
|
2240
|
+
if use_cmap:
|
|
2241
|
+
a = a[:,:,0]
|
|
2242
|
+
else:
|
|
2243
|
+
logger.warning('Image incompatible with cmap option, ignore cmap')
|
|
2244
|
+
kwargs.pop('cmap')
|
|
2245
|
+
if extent is None:
|
|
2246
|
+
extent = (0, a.shape[1], a.shape[0], 0)
|
|
2247
|
+
plt.ioff()
|
|
2248
|
+
fig, ax = plt.subplots(figsize=(11, 8.5))
|
|
2249
|
+
ax.imshow(a, extent=extent, **kwargs)
|
|
2250
|
+
ax.set_title(title, fontsize='xx-large')
|
|
2251
|
+
ax.set_xlabel(column_label, fontsize='x-large')
|
|
2252
|
+
ax.set_ylabel(row_label, fontsize='x-large')
|
|
2253
|
+
if show_grid:
|
|
2254
|
+
ax.grid(color=grid_color, linewidth=grid_linewidth)
|
|
2255
|
+
if show_fig:
|
|
2256
|
+
plt.show(block=block)
|
|
2257
|
+
if save_fig:
|
|
2258
|
+
if name is None:
|
|
2259
|
+
title = re.sub(r'\s+', '_', title)
|
|
2260
|
+
if path is None:
|
|
2261
|
+
path = title
|
|
2262
|
+
else:
|
|
2263
|
+
path = f'{path}/{title}'
|
|
2264
|
+
else:
|
|
2265
|
+
if path is None:
|
|
2266
|
+
path = name
|
|
2267
|
+
else:
|
|
2268
|
+
path = f'{path}/{name}'
|
|
2269
|
+
if (os.path.splitext(path)[1]
|
|
2270
|
+
not in plt.gcf().canvas.get_supported_filetypes()):
|
|
2271
|
+
path += '.png'
|
|
2272
|
+
plt.savefig(path)
|
|
2273
|
+
if return_fig:
|
|
2274
|
+
buf = fig_to_iobuf(fig)
|
|
2275
|
+
else:
|
|
2276
|
+
buf = None
|
|
2277
|
+
plt.close()
|
|
2278
|
+
return buf
|
|
2279
|
+
|
|
2280
|
+
|
|
2281
|
+
def quick_plot(
|
|
2282
|
+
*args, xerr=None, yerr=None, vlines=None, title=None, xlim=None,
|
|
2283
|
+
ylim=None, xlabel=None, ylabel=None, legend=None, path=None, name=None,
|
|
2284
|
+
show_grid=False, save_fig=False, save_only=False, block=False,
|
|
2285
|
+
**kwargs):
|
|
2286
|
+
"""Display a 2D line plot."""
|
|
2287
|
+
#RV FIX: Update with return_buf
|
|
2288
|
+
if title is not None and not isinstance(title, str):
|
|
2289
|
+
illegal_value(title, 'title', 'quick_plot')
|
|
2290
|
+
title = None
|
|
2291
|
+
if (xlim is not None and not isinstance(xlim, (tuple, list))
|
|
2292
|
+
and len(xlim) != 2):
|
|
2293
|
+
illegal_value(xlim, 'xlim', 'quick_plot')
|
|
2294
|
+
xlim = None
|
|
2295
|
+
if (ylim is not None and not isinstance(ylim, (tuple, list))
|
|
2296
|
+
and len(ylim) != 2):
|
|
2297
|
+
illegal_value(ylim, 'ylim', 'quick_plot')
|
|
2298
|
+
ylim = None
|
|
2299
|
+
if xlabel is not None and not isinstance(xlabel, str):
|
|
2300
|
+
illegal_value(xlabel, 'xlabel', 'quick_plot')
|
|
2301
|
+
xlabel = None
|
|
2302
|
+
if ylabel is not None and not isinstance(ylabel, str):
|
|
2303
|
+
illegal_value(ylabel, 'ylabel', 'quick_plot')
|
|
2304
|
+
ylabel = None
|
|
2305
|
+
if legend is not None and not isinstance(legend, (tuple, list)):
|
|
2306
|
+
illegal_value(legend, 'legend', 'quick_plot')
|
|
2307
|
+
legend = None
|
|
2308
|
+
if path is not None and not isinstance(path, str):
|
|
2309
|
+
illegal_value(path, 'path', 'quick_plot')
|
|
2310
|
+
return
|
|
2311
|
+
if not isinstance(show_grid, bool):
|
|
2312
|
+
illegal_value(show_grid, 'show_grid', 'quick_plot')
|
|
2313
|
+
return
|
|
2314
|
+
if not isinstance(save_fig, bool):
|
|
2315
|
+
illegal_value(save_fig, 'save_fig', 'quick_plot')
|
|
2316
|
+
return
|
|
2317
|
+
if not isinstance(save_only, bool):
|
|
2318
|
+
illegal_value(save_only, 'save_only', 'quick_plot')
|
|
2319
|
+
return
|
|
2320
|
+
if not isinstance(block, bool):
|
|
2321
|
+
illegal_value(block, 'block', 'quick_plot')
|
|
2322
|
+
return
|
|
2323
|
+
if title is None:
|
|
2324
|
+
title = 'quick plot'
|
|
2325
|
+
if name is None:
|
|
2326
|
+
ttitle = re.sub(r'\s+', '_', title)
|
|
2327
|
+
if path is None:
|
|
2328
|
+
path = f'{ttitle}.png'
|
|
2329
|
+
else:
|
|
2330
|
+
path = f'{path}/{ttitle}.png'
|
|
2331
|
+
else:
|
|
2332
|
+
if path is None:
|
|
2333
|
+
path = name
|
|
2334
|
+
else:
|
|
2335
|
+
path = f'{path}/{name}'
|
|
2336
|
+
args = unwrap_tuple(args)
|
|
2337
|
+
if depth_tuple(args) > 1 and (xerr is not None or yerr is not None):
|
|
2338
|
+
logger.warning('Error bars ignored for multiple curves')
|
|
2339
|
+
if not save_only:
|
|
2340
|
+
if block:
|
|
2341
|
+
plt.ioff()
|
|
2342
|
+
else:
|
|
2343
|
+
plt.ion()
|
|
2344
|
+
plt.figure(title)
|
|
2345
|
+
if depth_tuple(args) > 1:
|
|
2346
|
+
for y in args:
|
|
2347
|
+
plt.plot(*y, **kwargs)
|
|
2348
|
+
else:
|
|
2349
|
+
if xerr is None and yerr is None:
|
|
2350
|
+
plt.plot(*args, **kwargs)
|
|
2351
|
+
else:
|
|
2352
|
+
plt.errorbar(*args, xerr=xerr, yerr=yerr, **kwargs)
|
|
2353
|
+
if vlines is not None:
|
|
2354
|
+
if isinstance(vlines, (int, float)):
|
|
2355
|
+
vlines = [vlines]
|
|
2356
|
+
for v in vlines:
|
|
2357
|
+
plt.axvline(v, color='r', linestyle='--', **kwargs)
|
|
2358
|
+
if xlim is not None:
|
|
2359
|
+
plt.xlim(xlim)
|
|
2360
|
+
if ylim is not None:
|
|
2361
|
+
plt.ylim(ylim)
|
|
2362
|
+
if xlabel is not None:
|
|
2363
|
+
plt.xlabel(xlabel)
|
|
2364
|
+
if ylabel is not None:
|
|
2365
|
+
plt.ylabel(ylabel)
|
|
2366
|
+
if show_grid:
|
|
2367
|
+
ax = plt.gca()
|
|
2368
|
+
ax.grid(color='k') # , linewidth=1)
|
|
2369
|
+
if legend is not None:
|
|
2370
|
+
plt.legend(legend)
|
|
2371
|
+
if save_only:
|
|
2372
|
+
plt.savefig(path)
|
|
2373
|
+
plt.close(fig=title)
|
|
2374
|
+
else:
|
|
2375
|
+
if save_fig:
|
|
2376
|
+
plt.savefig(path)
|
|
2377
|
+
plt.show(block=block)
|
|
2378
|
+
plt.close()
|
|
2379
|
+
|
|
2380
|
+
|
|
2381
|
+
def nxcopy(
|
|
2382
|
+
nxobject, exclude_nxpaths=None, nxpath_prefix=None,
|
|
2383
|
+
nxpathabs_prefix=None, nxpath_copy_abspath=None):
|
|
2384
|
+
"""Function that returns a copy of a nexus object, optionally
|
|
2385
|
+
exluding certain child items.
|
|
2386
|
+
|
|
2387
|
+
:param nxobject: The input nexus object to "copy".
|
|
2388
|
+
:type nxobject: nexusformat.nexus.NXobject
|
|
2389
|
+
:param exlude_nxpaths: A list of relative paths to child nexus
|
|
2390
|
+
objects that should be excluded from the returned "copy".
|
|
2391
|
+
:type exclude_nxpaths: str, list[str], optional
|
|
2392
|
+
:param nxpath_prefix: For use in recursive calls from inside this
|
|
2393
|
+
function only.
|
|
2394
|
+
:type nxpath_prefix: str
|
|
2395
|
+
:param nxpathabs_prefix: For use in recursive calls from inside
|
|
2396
|
+
this function only.
|
|
2397
|
+
:type nxpathabs_prefix: str
|
|
2398
|
+
:param nxpath_copy_abspath: For use in recursive calls from inside
|
|
2399
|
+
this function only.
|
|
2400
|
+
:type nxpath_copy_abspath: str
|
|
2401
|
+
:return: Copy of the input `nxobject` with some children optionally
|
|
2402
|
+
exluded.
|
|
2403
|
+
:rtype: nexusformat.nexus.NXobject
|
|
2404
|
+
"""
|
|
2405
|
+
# Third party modules
|
|
2406
|
+
from nexusformat.nexus import (
|
|
2407
|
+
NXentry,
|
|
2408
|
+
NXfield,
|
|
2409
|
+
NXgroup,
|
|
2410
|
+
NXlink,
|
|
2411
|
+
NXlinkgroup,
|
|
2412
|
+
NXroot,
|
|
2413
|
+
)
|
|
2414
|
+
|
|
2415
|
+
|
|
2416
|
+
if isinstance(nxobject, NXlinkgroup):
|
|
2417
|
+
# The top level nxobject is a linked group
|
|
2418
|
+
# Create a group with the same name as the top level's target
|
|
2419
|
+
nxobject_copy = nxobject[nxobject.nxtarget].__class__(
|
|
2420
|
+
name=nxobject.nxname)
|
|
2421
|
+
elif isinstance(nxobject, (NXlink, NXfield)):
|
|
2422
|
+
# The top level nxobject is a (linked) field: return a copy
|
|
2423
|
+
attrs = nxobject.attrs
|
|
2424
|
+
attrs.pop('target', None)
|
|
2425
|
+
nxobject_copy = NXfield(
|
|
2426
|
+
value=nxobject.nxdata, name=nxobject.nxname,
|
|
2427
|
+
attrs=attrs)
|
|
2428
|
+
return nxobject_copy
|
|
2429
|
+
else:
|
|
2430
|
+
# Create a group with the same type/name as the nxobject
|
|
2431
|
+
nxobject_copy = nxobject.__class__(name=nxobject.nxname)
|
|
2432
|
+
|
|
2433
|
+
# Copy attributes
|
|
2434
|
+
if isinstance(nxobject, NXroot):
|
|
2435
|
+
if 'default' in nxobject.attrs:
|
|
2436
|
+
nxobject_copy.attrs['default'] = nxobject.default
|
|
2437
|
+
else:
|
|
2438
|
+
for k, v in nxobject.attrs.items():
|
|
2439
|
+
nxobject_copy.attrs[k] = v
|
|
2440
|
+
|
|
2441
|
+
# Setup paths
|
|
2442
|
+
if exclude_nxpaths is None:
|
|
2443
|
+
exclude_nxpaths = []
|
|
2444
|
+
elif isinstance(exclude_nxpaths, str):
|
|
2445
|
+
exclude_nxpaths = [exclude_nxpaths]
|
|
2446
|
+
for exclude_nxpath in exclude_nxpaths:
|
|
2447
|
+
if exclude_nxpath[0] == '/':
|
|
2448
|
+
raise ValueError(
|
|
2449
|
+
f'Invalid parameter in exclude_nxpaths ({exclude_nxpaths}), '
|
|
2450
|
+
'excluded paths should be relative')
|
|
2451
|
+
if nxpath_prefix is None:
|
|
2452
|
+
nxpath_prefix = ''
|
|
2453
|
+
if nxpathabs_prefix is None:
|
|
2454
|
+
if isinstance(nxobject, NXentry):
|
|
2455
|
+
nxpathabs_prefix = nxobject.nxpath
|
|
2456
|
+
else:
|
|
2457
|
+
nxpathabs_prefix = nxobject.nxpath.removesuffix(nxobject.nxname)
|
|
2458
|
+
if nxpath_copy_abspath is None:
|
|
2459
|
+
nxpath_copy_abspath = ''
|
|
2460
|
+
|
|
2461
|
+
# Loop over all nxobject's children
|
|
2462
|
+
for k, v in nxobject.items():
|
|
2463
|
+
nxpath = os.path.join(nxpath_prefix, k)
|
|
2464
|
+
nxpathabs = os.path.join(nxpathabs_prefix, nxpath)
|
|
2465
|
+
if nxpath in exclude_nxpaths:
|
|
2466
|
+
if 'default' in nxobject_copy.attrs and nxobject_copy.default == k:
|
|
2467
|
+
nxobject_copy.attrs.pop('default')
|
|
2468
|
+
continue
|
|
2469
|
+
if isinstance(v, NXlinkgroup):
|
|
2470
|
+
if nxpathabs == v.nxpath and not any(
|
|
2471
|
+
v.nxtarget.startswith(os.path.join(nxpathabs_prefix, p))
|
|
2472
|
+
for p in exclude_nxpaths):
|
|
2473
|
+
nxobject_copy[k] = NXlink(v.nxtarget)
|
|
2474
|
+
else:
|
|
2475
|
+
nxobject_copy[k] = nxcopy(
|
|
2476
|
+
v, exclude_nxpaths=exclude_nxpaths,
|
|
2477
|
+
nxpath_prefix=nxpath, nxpathabs_prefix=nxpathabs_prefix,
|
|
2478
|
+
nxpath_copy_abspath=os.path.join(nxpath_copy_abspath, k))
|
|
2479
|
+
elif isinstance(v, NXlink):
|
|
2480
|
+
if nxpathabs == v.nxpath and not any(
|
|
2481
|
+
v.nxtarget.startswith(os.path.join(nxpathabs_prefix, p))
|
|
2482
|
+
for p in exclude_nxpaths):
|
|
2483
|
+
nxobject_copy[k] = v
|
|
2484
|
+
else:
|
|
2485
|
+
nxobject_copy[k] = v.nxdata
|
|
2486
|
+
for kk, vv in v.attrs.items():
|
|
2487
|
+
nxobject_copy[k].attrs[kk] = vv
|
|
2488
|
+
nxobject_copy[k].attrs.pop('target', None)
|
|
2489
|
+
elif isinstance(v, NXgroup):
|
|
2490
|
+
nxobject_copy[k] = nxcopy(
|
|
2491
|
+
v, exclude_nxpaths=exclude_nxpaths,
|
|
2492
|
+
nxpath_prefix=nxpath, nxpathabs_prefix=nxpathabs_prefix,
|
|
2493
|
+
nxpath_copy_abspath=os.path.join(nxpath_copy_abspath, k))
|
|
2494
|
+
else:
|
|
2495
|
+
nxobject_copy[k] = v.nxdata
|
|
2496
|
+
for kk, vv in v.attrs.items():
|
|
2497
|
+
nxobject_copy[k].attrs[kk] = vv
|
|
2498
|
+
if nxpathabs != os.path.join(nxpath_copy_abspath, k):
|
|
2499
|
+
nxobject_copy[k].attrs.pop('target', None)
|
|
2500
|
+
|
|
2501
|
+
return nxobject_copy
|
|
2502
|
+
|
|
2503
|
+
|
|
2504
|
+
def dictionary_update(target, source, merge_key_paths=None, sort=False):
|
|
2505
|
+
"""Recursively updates a target dictionary with values from a source
|
|
2506
|
+
dictionary. Source values superseed target values for identical keys
|
|
2507
|
+
unless both values are lists of dictionaries in which case they are
|
|
2508
|
+
merged according to the merge_key_paths parameter.
|
|
2509
|
+
|
|
2510
|
+
:param target: Target dictionary.
|
|
2511
|
+
:type target: collections.abc.Mapping
|
|
2512
|
+
:param source: Source dictionary.
|
|
2513
|
+
:type target: collections.abc.Mapping
|
|
2514
|
+
:param merge_key_paths: List key paths to merge dictionary lists,
|
|
2515
|
+
only used if items in the target and source dictionary trees
|
|
2516
|
+
are lists of dictionaries.
|
|
2517
|
+
:type merge_key_paths: Union[str, list[str]]
|
|
2518
|
+
:param sort: Sort dictionary lists on the key.
|
|
2519
|
+
:type sort: bool, optional
|
|
2520
|
+
:return: The updated target directory.
|
|
2521
|
+
:rtype: collections.abc.Mapping
|
|
2522
|
+
"""
|
|
2523
|
+
if not isinstance(target, dict):
|
|
2524
|
+
raise ValueError(
|
|
2525
|
+
'Invalid parameter type "target" ({type(target)})')
|
|
2526
|
+
if not isinstance(source, dict):
|
|
2527
|
+
raise ValueError(
|
|
2528
|
+
'Invalid parameter type "source" ({type(source)})')
|
|
2529
|
+
for k, v in source.items():
|
|
2530
|
+
if (isinstance(v, collections.abc.Mapping)
|
|
2531
|
+
and isinstance(target.get(k), collections.abc.Mapping)):
|
|
2532
|
+
if merge_key_paths is not None:
|
|
2533
|
+
raise NotImplementedError(
|
|
2534
|
+
f'"merge_key_paths" ({type(merge_key_paths)}) '
|
|
2535
|
+
'for source and target dictionaries not yet implemented')
|
|
2536
|
+
# merge_key_path = None
|
|
2537
|
+
# if '/' in merge_key_paths:
|
|
2538
|
+
# print(f'"/" in merge_key_path')
|
|
2539
|
+
# merge_key_path = merge_key_paths.split('/', 1)[1:]
|
|
2540
|
+
# elif is_str_series(merge_key_paths):
|
|
2541
|
+
# print(f'merge_key_path is string series')
|
|
2542
|
+
# merge_key_path = [
|
|
2543
|
+
# vv[1] for vv in [
|
|
2544
|
+
# v for v in [merge_key_paths.split('/', 1)
|
|
2545
|
+
# for s in sss]]
|
|
2546
|
+
# if (vv[0]==k and len(vv)>1)]
|
|
2547
|
+
# print(f'---> merge_key_path: {merge_key_path}')
|
|
2548
|
+
target[k] = dictionary_update(target.get(k, {}), v)
|
|
2549
|
+
elif (is_dict_series(v, log=False)
|
|
2550
|
+
and is_dict_series(target.get(k), log=False)):
|
|
2551
|
+
if isinstance(merge_key_paths, str):
|
|
2552
|
+
merge_key_path = merge_key_paths
|
|
2553
|
+
merge_key_type = None
|
|
2554
|
+
elif isinstance(merge_key_paths, dict):
|
|
2555
|
+
merge_key_path = merge_key_paths.get('key_path')
|
|
2556
|
+
merge_key_type = merge_key_paths.get('type')
|
|
2557
|
+
elif merge_key_path is not None:
|
|
2558
|
+
raise NotImplementedError(
|
|
2559
|
+
'Invalid/unimplemeted parameter type "merge_key_path" '
|
|
2560
|
+
f'({type(merge_key_pathsource)}) for source and target '
|
|
2561
|
+
'lists of dictionaries')
|
|
2562
|
+
merge_key = l[1] if len(
|
|
2563
|
+
l:=merge_key_path.split('/')) == 2 else None
|
|
2564
|
+
# if '/' in merge_key_paths:
|
|
2565
|
+
# merge_key_paths = [merge_key_paths]
|
|
2566
|
+
# if is_str_series(merge_key_paths):
|
|
2567
|
+
# paths = paths if len(
|
|
2568
|
+
# paths:=[l[1] for path in merge_key_paths
|
|
2569
|
+
# if (l:=path.split('/', 1))[0] == k and len(l)>1]
|
|
2570
|
+
# ) else [None]
|
|
2571
|
+
# if len(paths) > 1:
|
|
2572
|
+
# raise ValueError(
|
|
2573
|
+
# 'Ambiguous parameter merge_key_paths '
|
|
2574
|
+
# f'({merge_key_paths}) while trying to merge '
|
|
2575
|
+
# f'{source} with {target}')
|
|
2576
|
+
# merge_path = paths[0]
|
|
2577
|
+
# else:
|
|
2578
|
+
# merge_path = None
|
|
2579
|
+
target[k] = list_dictionary_update(
|
|
2580
|
+
target.get(k), v, key=merge_key, key_type=merge_key_type,
|
|
2581
|
+
sort=sort)
|
|
2582
|
+
else:
|
|
2583
|
+
target[k] = v
|
|
2584
|
+
return target
|
|
2585
|
+
|
|
2586
|
+
|
|
2587
|
+
def list_dictionary_update(
|
|
2588
|
+
target, source, key=None, key_type=None, sort=False):
|
|
2589
|
+
"""Recursively updates a target list of dictionaries with values
|
|
2590
|
+
from a source list of dictionaries. Each list item is updated item
|
|
2591
|
+
by item based on the key if given and equal to a key that is shared
|
|
2592
|
+
among all sets of source and target list item keys. Otherwise the
|
|
2593
|
+
target list appended to the source list is returned.
|
|
2594
|
+
|
|
2595
|
+
:param target: Target list.
|
|
2596
|
+
:type target: list
|
|
2597
|
+
:param source: Source list.
|
|
2598
|
+
:type source: list
|
|
2599
|
+
:param key: Selected key to merge the lists of dictionaries.
|
|
2600
|
+
:type key: str, optional
|
|
2601
|
+
:param key_type: Key type to enforce.
|
|
2602
|
+
:type key_type: type, optional
|
|
2603
|
+
:param sort: Sort the returned list on the key.
|
|
2604
|
+
:type sort: bool, optional
|
|
2605
|
+
:return: The updated list.
|
|
2606
|
+
:rtype: list
|
|
2607
|
+
"""
|
|
2608
|
+
if not isinstance(target, list):
|
|
2609
|
+
raise ValueError(
|
|
2610
|
+
'Invalid parameter type "target" ({type(target)})')
|
|
2611
|
+
if not isinstance(source, list):
|
|
2612
|
+
raise ValueError(
|
|
2613
|
+
'Invalid parameter type "source" ({type(source)})')
|
|
2614
|
+
if key is None:
|
|
2615
|
+
return source + target
|
|
2616
|
+
if not isinstance(key, str) or '/' in key:
|
|
2617
|
+
raise ValueError('Invalid parameter "key" ({key}, {type(key)})')
|
|
2618
|
+
if not (key_type is None or isinstance(key_type, type)):
|
|
2619
|
+
raise ValueError(
|
|
2620
|
+
'Invalid parameter "key_type" ({key_type}, {type(key_type)})')
|
|
2621
|
+
all_any_source = all_any(source, key)
|
|
2622
|
+
if all_any_source < 0:
|
|
2623
|
+
raise ValueError(
|
|
2624
|
+
f'Partially shared key ({key}) while trying to merge {source} '
|
|
2625
|
+
f'with {target}')
|
|
2626
|
+
all_any_target = all_any(target, key)
|
|
2627
|
+
if all_any_target < 0 or all_any_source != all_any_target:
|
|
2628
|
+
raise ValueError(
|
|
2629
|
+
f'Partially shared key ({key}) while trying to merge {source} '
|
|
2630
|
+
f'with {target}')
|
|
2631
|
+
if not all_any_source and not all_any_target:
|
|
2632
|
+
return source + target
|
|
2633
|
+
merged = []
|
|
2634
|
+
for target_dict in target:
|
|
2635
|
+
value = target_dict[key]
|
|
2636
|
+
if key_type is not None:
|
|
2637
|
+
value = key_type(value)
|
|
2638
|
+
for i, source_dict in enumerate(source):
|
|
2639
|
+
vvalue = source_dict[key]
|
|
2640
|
+
if key_type is not None:
|
|
2641
|
+
vvalue = key_type(vvalue)
|
|
2642
|
+
if value == vvalue:
|
|
2643
|
+
merged.append(dictionary_update(
|
|
2644
|
+
target_dict, source_dict, sort=sort))
|
|
2645
|
+
source.pop(i)
|
|
2646
|
+
break
|
|
2647
|
+
else:
|
|
2648
|
+
merged.append(target_dict)
|
|
2649
|
+
merged.extend(source)
|
|
2650
|
+
if sorted:
|
|
2651
|
+
if key_type is None:
|
|
2652
|
+
merged.sort(key=lambda x: x[key])
|
|
2653
|
+
else:
|
|
2654
|
+
merged.sort(key=lambda x: key_type(x[key]))
|
|
2655
|
+
return merged
|