ChessAnalysisPipeline 0.0.17.dev3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (70) hide show
  1. CHAP/TaskManager.py +216 -0
  2. CHAP/__init__.py +27 -0
  3. CHAP/common/__init__.py +57 -0
  4. CHAP/common/models/__init__.py +8 -0
  5. CHAP/common/models/common.py +124 -0
  6. CHAP/common/models/integration.py +659 -0
  7. CHAP/common/models/map.py +1291 -0
  8. CHAP/common/processor.py +2869 -0
  9. CHAP/common/reader.py +658 -0
  10. CHAP/common/utils.py +110 -0
  11. CHAP/common/writer.py +730 -0
  12. CHAP/edd/__init__.py +23 -0
  13. CHAP/edd/models.py +876 -0
  14. CHAP/edd/processor.py +3069 -0
  15. CHAP/edd/reader.py +1023 -0
  16. CHAP/edd/select_material_params_gui.py +348 -0
  17. CHAP/edd/utils.py +1572 -0
  18. CHAP/edd/writer.py +26 -0
  19. CHAP/foxden/__init__.py +19 -0
  20. CHAP/foxden/models.py +71 -0
  21. CHAP/foxden/processor.py +124 -0
  22. CHAP/foxden/reader.py +224 -0
  23. CHAP/foxden/utils.py +80 -0
  24. CHAP/foxden/writer.py +168 -0
  25. CHAP/giwaxs/__init__.py +11 -0
  26. CHAP/giwaxs/models.py +491 -0
  27. CHAP/giwaxs/processor.py +776 -0
  28. CHAP/giwaxs/reader.py +8 -0
  29. CHAP/giwaxs/writer.py +8 -0
  30. CHAP/inference/__init__.py +7 -0
  31. CHAP/inference/processor.py +69 -0
  32. CHAP/inference/reader.py +8 -0
  33. CHAP/inference/writer.py +8 -0
  34. CHAP/models.py +227 -0
  35. CHAP/pipeline.py +479 -0
  36. CHAP/processor.py +125 -0
  37. CHAP/reader.py +124 -0
  38. CHAP/runner.py +277 -0
  39. CHAP/saxswaxs/__init__.py +7 -0
  40. CHAP/saxswaxs/processor.py +8 -0
  41. CHAP/saxswaxs/reader.py +8 -0
  42. CHAP/saxswaxs/writer.py +8 -0
  43. CHAP/server.py +125 -0
  44. CHAP/sin2psi/__init__.py +7 -0
  45. CHAP/sin2psi/processor.py +8 -0
  46. CHAP/sin2psi/reader.py +8 -0
  47. CHAP/sin2psi/writer.py +8 -0
  48. CHAP/tomo/__init__.py +15 -0
  49. CHAP/tomo/models.py +210 -0
  50. CHAP/tomo/processor.py +3862 -0
  51. CHAP/tomo/reader.py +9 -0
  52. CHAP/tomo/writer.py +59 -0
  53. CHAP/utils/__init__.py +6 -0
  54. CHAP/utils/converters.py +188 -0
  55. CHAP/utils/fit.py +2947 -0
  56. CHAP/utils/general.py +2655 -0
  57. CHAP/utils/material.py +274 -0
  58. CHAP/utils/models.py +595 -0
  59. CHAP/utils/parfile.py +224 -0
  60. CHAP/writer.py +122 -0
  61. MLaaS/__init__.py +0 -0
  62. MLaaS/ktrain.py +205 -0
  63. MLaaS/mnist_img.py +83 -0
  64. MLaaS/tfaas_client.py +371 -0
  65. chessanalysispipeline-0.0.17.dev3.dist-info/LICENSE +60 -0
  66. chessanalysispipeline-0.0.17.dev3.dist-info/METADATA +29 -0
  67. chessanalysispipeline-0.0.17.dev3.dist-info/RECORD +70 -0
  68. chessanalysispipeline-0.0.17.dev3.dist-info/WHEEL +5 -0
  69. chessanalysispipeline-0.0.17.dev3.dist-info/entry_points.txt +2 -0
  70. chessanalysispipeline-0.0.17.dev3.dist-info/top_level.txt +2 -0
CHAP/utils/general.py ADDED
@@ -0,0 +1,2655 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ File : general.py
5
+ Author : Rolf Verberg <rolfverberg AT gmail dot com>
6
+ Description: A collection of general modules
7
+ """
8
+ # RV write function that returns a list of peak indices for a given plot
9
+ # RV use raise_error concept on more functions
10
+
11
+ # System modules
12
+ from ast import literal_eval
13
+ import collections.abc
14
+ from logging import getLogger
15
+ import os
16
+ import re
17
+ import sys
18
+
19
+ # Third party modules
20
+ import numpy as np
21
+ try:
22
+ import matplotlib.pyplot as plt
23
+ except ImportError:
24
+ pass
25
+
26
+ logger = getLogger(__name__)
27
+
28
+ # pylint: disable=no-member
29
+ tiny = np.finfo(np.float64).resolution
30
+ # pylint: enable=no-member
31
+
32
def gformat(val, length=11):
    """Format a number in a fixed-width, '%g'-like style.

    Adapted from the lmfit library. For numeric input the output:
    - is padded to the requested length (at least 7 characters),
    - leaves room for a leading blank on positive numbers,
    - uses as much precision as fits in that width,
    - keeps trailing zeros.

    `None`, booleans and values `abs()` cannot handle are returned as
    their right-aligned `repr` instead.

    :param val: Value to format.
    :type val: Any
    :param length: Requested field width, defaults to `11`.
    :type length: int, optional
    :return: The formatted string.
    :rtype: str
    """
    if val is None or isinstance(val, bool):
        return f'{val!r:>{length}s}'
    try:
        exponent = int(np.log10(abs(val)))
    except TypeError:
        # Not numeric: fall back to the representation.
        return f'{val!r:>{length}s}'
    except (OverflowError, ValueError):
        # log10 of 0 or nan cannot be converted to an int.
        exponent = 0

    width = max(length, 7)
    precision = width - 7
    fmt_type = 'e'
    if abs(exponent) > 99:
        # A three-digit exponent costs one more character.
        precision -= 1
    elif 0 < exponent < precision + 4 or -exponent < precision - 1 <= 0:
        # Fixed point fits: spend the exponent's characters on digits.
        fmt_type = 'f'
        precision += 4 - max(exponent, 0)
    return f'{val:{width}.{precision}{fmt_type}}'
61
+
62
+
63
def getfloat_attr(obj, attr, length=11):
    """Format an attribute of an object for printing.

    Adapted from the lmfit library.

    :param obj: Object holding the attribute.
    :type obj: Any
    :param attr: Attribute name.
    :type attr: str
    :param length: Field width passed on to `gformat` for floats,
        defaults to `11`.
    :type length: int, optional
    :return: `'unknown'` for a missing or `None` attribute, the plain
        string for an int, the stripped `gformat` output for a float,
        or the `repr` for anything else.
    :rtype: str
    """
    value = getattr(obj, attr, None)
    if value is None:
        formatted = 'unknown'
    elif isinstance(value, int):
        formatted = f'{value}'
    elif isinstance(value, float):
        formatted = gformat(value, length=length).strip()
    else:
        formatted = repr(value)
    return formatted
74
+
75
+
76
def depth_list(_list):
    """Return the nesting depth of a list.

    A non-list yields `False` (equal to 0), a flat list yields 1 and
    every extra level of list nesting adds 1. An empty list has depth 1.

    :param _list: Object to inspect.
    :type _list: Any
    :return: The nesting depth.
    :rtype: Union[int, bool]
    """
    # default=0 keeps max() from raising ValueError on an empty list.
    return isinstance(_list, list) and 1 + max(
        map(depth_list, _list), default=0)
79
+
80
+
81
def depth_tuple(_tuple):
    """Return the nesting depth of a tuple.

    A non-tuple yields `False` (equal to 0), a flat tuple yields 1 and
    every extra level of tuple nesting adds 1. An empty tuple has
    depth 1.

    :param _tuple: Object to inspect.
    :type _tuple: Any
    :return: The nesting depth.
    :rtype: Union[int, bool]
    """
    # default=0 keeps max() from raising ValueError on an empty tuple.
    return isinstance(_tuple, tuple) and 1 + max(
        map(depth_tuple, _tuple), default=0)
84
+
85
+
86
def unwrap_tuple(_tuple):
    """Strip redundant single-element wrapping from a nested tuple.

    While the value is a one-element tuple whose nesting depth exceeds
    1, it is replaced by its only element.

    :param _tuple: Value to unwrap.
    :type _tuple: Any
    :return: The unwrapped value.
    :rtype: Any
    """
    while depth_tuple(_tuple) > 1 and len(_tuple) == 1:
        _tuple = _tuple[0]
    return _tuple
91
+
92
+
93
def all_any(l, key):
    """Classify how a key occurs across a list of dictionaries.

    Iteration stops as soon as the key has been seen both present and
    absent, so the list is traversed at most once.

    :param l: Input list.
    :type l: list[dict]
    :param key: The dictionary key to look for.
    :type key: Any
    :return: `1` if every dictionary has the key, `0` if none has it,
        `-1` for a mix, or `None` for an empty list.
    :rtype: Union[None, int]
    """
    found = missing = False
    for d in l:
        if key in d:
            found = True
        else:
            missing = True
        if found and missing:
            return -1
    if found:
        return 1
    if missing:
        return 0
    return None
120
+
121
def illegal_value(value, name, location=None, raise_error=False, log=True):
    """Report an illegal value by logging and/or raising.

    :param value: The offending value.
    :param name: Name of the value; omitted from the message when it
        is not a string.
    :param location: Where the value was found; omitted from the
        message when it is not a string.
    :param raise_error: Raise a ValueError, defaults to `False`.
    :param log: Log the message at error level, defaults to `True`.
    :raises ValueError: When raise_error is `True`.
    """
    where = f'in {location} ' if isinstance(location, str) else ''
    subject = f' for {name}' if isinstance(name, str) else ''
    error_msg = f'Illegal value{subject} {where}({value}, {type(value)})'
    if log:
        logger.error(error_msg)
    if raise_error:
        raise ValueError(error_msg)
136
+
137
+
138
def illegal_combination(
        value1, name1, value2, name2, location=None, raise_error=False,
        log=True):
    """Report an illegal combination of two values by logging and/or
    raising.

    :param value1: First value.
    :param name1: Name of the first value; when it is not a string,
        both names are omitted from the message.
    :param value2: Second value.
    :param name2: Name of the second value.
    :param location: Where the values were found; omitted from the
        message when it is not a string.
    :param raise_error: Raise a ValueError, defaults to `False`.
    :param log: Log the message at error level, defaults to `True`.
    :raises ValueError: When raise_error is `True`.
    """
    where = f'in {location} ' if isinstance(location, str) else ''
    subject = f' for {name1} and {name2}' if isinstance(name1, str) else ''
    error_msg = (
        f'Illegal combination{subject} {where}'
        f'({value1}, {type(value1)} and {value2}, {type(value2)})')
    if log:
        logger.error(error_msg)
    if raise_error:
        raise ValueError(error_msg)
156
+
157
+
158
def not_zero(value):
    """Return value as a float with magnitude at least `tiny`.

    The sign of value is preserved (via `np.copysign`), so values
    smaller than `tiny` are pushed away from zero in their own
    direction.
    """
    magnitude = max(tiny, abs(value))
    return float(np.copysign(magnitude, value))
163
+
164
+
165
def test_ge_gt_le_lt(
        ge, gt, le, lt, func, location=None, raise_error=False, log=True):
    """Check the individual and mutual validity of the ge, gt, le and lt
    range qualifiers.

    ge and gt are mutually exclusive, as are le and lt. Each given
    qualifier must pass `func`, and the resulting range must not be
    empty.

    :param func: Validator for a single qualifier value.
    :type func: callable: is_int, is_num
    :param location: Location reported in error messages.
    :type location: str, optional
    :param raise_error: Raise a ValueError on failure,
        defaults to `False`.
    :type raise_error: bool, optional
    :param log: Log an error on failure, defaults to `True`.
    :type log: bool, optional
    :return: `True` upon success or `False` when a qualifier is invalid
        or the qualifiers are mutually exclusive.
    :rtype: bool
    """
    if all(q is None for q in (ge, gt, le, lt)):
        return True

    # Validate each side of the range: the inclusive and exclusive
    # bound on the same side may not both be given.
    sides = ((ge, 'ge', gt, 'gt'), (le, 'le', lt, 'lt'))
    for inclusive, inc_name, exclusive, exc_name in sides:
        if inclusive is not None:
            if not func(inclusive):
                illegal_value(inclusive, inc_name, location, raise_error, log)
                return False
            if exclusive is not None:
                illegal_combination(
                    inclusive, inc_name, exclusive, exc_name, location,
                    raise_error, log)
                return False
        elif exclusive is not None and not func(exclusive):
            illegal_value(exclusive, exc_name, location, raise_error, log)
            return False

    # At most one lower bound survives the checks above.
    if ge is not None:
        lower, lower_name = ge, 'ge'
    elif gt is not None:
        lower, lower_name = gt, 'gt'
    else:
        return True
    # Only ge <= le may be equal; every other pairing needs a gap.
    if le is not None and (lower > le if lower_name == 'ge' else lower >= le):
        illegal_combination(
            lower, lower_name, le, 'le', location, raise_error, log)
        return False
    if lt is not None and lower >= lt:
        illegal_combination(
            lower, lower_name, lt, 'lt', location, raise_error, log)
        return False
    return True
212
+
213
+
214
def range_string_ge_gt_le_lt(ge=None, gt=None, le=None, lt=None):
    """Return a range string representation matching the ge, gt, le,
    lt qualifiers, e.g. '>= 1', '< 5' or '[1, 5)'.

    The inputs are not validated; do that as needed before calling.
    """
    # Each bound carries its one-sided form and its interval form.
    if ge is not None:
        lower = (f'>= {ge}', f'[{ge}, ')
    elif gt is not None:
        lower = (f'> {gt}', f'({gt}, ')
    else:
        lower = None
    if le is not None:
        upper = (f'<= {le}', f'{le}]')
    elif lt is not None:
        upper = (f'< {lt}', f'{lt})')
    else:
        upper = None

    if lower is None and upper is None:
        return ''
    if upper is None:
        return lower[0]
    if lower is None:
        return upper[0]
    return lower[1] + upper[1]
241
+
242
+
243
def is_int(v, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True):
    """Check that v is an integer within the given range qualifiers.

    Any combination of ge <= v or gt < v with v <= le or v < lt may be
    given (ge/gt and le/lt are mutually exclusive pairs).

    :return: `True` if v qualifies, `False` otherwise.
    :rtype: bool
    """
    return _is_int_or_num(
        v, 'int', ge=ge, gt=gt, le=le, lt=lt, raise_error=raise_error,
        log=log)
251
+
252
+
253
def is_num(v, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True):
    """Check that v is an int or float within the given range
    qualifiers.

    Any combination of ge <= v or gt < v with v <= le or v < lt may be
    given (ge/gt and le/lt are mutually exclusive pairs).

    :return: `True` if v qualifies, `False` otherwise.
    :rtype: bool
    """
    return _is_int_or_num(
        v, 'num', ge=ge, gt=gt, le=le, lt=lt, raise_error=raise_error,
        log=log)
261
+
262
+
263
def _is_int_or_num(
        v, type_str, ge=None, gt=None, le=None, lt=None, raise_error=False,
        log=True):
    """Shared implementation of is_int and is_num.

    :param v: Value to check.
    :param type_str: `'int'` (v must be an int) or `'num'` (v must be
        an int or float).
    :type type_str: str
    :return: `True` if v has the right type and satisfies the range
        qualifiers, `False` otherwise.
    :rtype: bool
    """
    location = '_is_int_or_num'
    if type_str == 'int':
        allowed, validator = int, is_int
    elif type_str == 'num':
        allowed, validator = (int, float), is_num
    else:
        illegal_value(type_str, 'type_str', location, raise_error, log)
        return False
    if not isinstance(v, allowed):
        illegal_value(v, 'v', location, raise_error, log)
        return False
    if not test_ge_gt_le_lt(
            ge, gt, le, lt, validator, location, raise_error, log):
        return False

    # Report only the first violated qualifier.
    violation = None
    if ge is not None and v < ge:
        violation = f'Value {v} out of range: {v} !>= {ge}'
    elif gt is not None and v <= gt:
        violation = f'Value {v} out of range: {v} !> {gt}'
    elif le is not None and v > le:
        violation = f'Value {v} out of range: {v} !<= {le}'
    elif lt is not None and v >= lt:
        violation = f'Value {v} out of range: {v} !< {lt}'
    if violation is None:
        return True
    if log:
        logger.error(violation)
    if raise_error:
        raise ValueError(violation)
    return False
306
+
307
+
308
def is_int_pair(
        v, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True):
    """Check that v is a pair (tuple or list) of integers.

    Each qualifier may be a scalar applied to both elements or a pair
    giving a separate bound per element, e.g. ge <= v[i] <= le or
    ge[i] <= v[i] <= le[i], or some combination.

    :return: `True` if v qualifies, `False` otherwise.
    :rtype: bool
    """
    return _is_int_or_num_pair(
        v, 'int', ge=ge, gt=gt, le=le, lt=lt, raise_error=raise_error,
        log=log)
318
+
319
+
320
def is_num_pair(
        v, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True):
    """Check that v is a pair (tuple or list) of ints or floats.

    Each qualifier may be a scalar applied to both elements or a pair
    giving a separate bound per element, e.g. gt < v[i] < lt or
    gt[i] < v[i] < lt[i], or some combination.

    :return: `True` if v qualifies, `False` otherwise.
    :rtype: bool
    """
    return _is_int_or_num_pair(
        v, 'num', ge=ge, gt=gt, le=le, lt=lt, raise_error=raise_error,
        log=log)
330
+
331
+
332
def _is_int_or_num_pair(
        v, type_str, ge=None, gt=None, le=None, lt=None, raise_error=False,
        log=True):
    """Shared implementation of is_int_pair and is_num_pair.

    :param v: Value to check; must be a tuple or list of length 2.
    :param type_str: `'int'` or `'num'`, the required element type.
    :type type_str: str
    :param ge, gt, le, lt: Range qualifiers, each either a scalar
        applied to both elements or a pair with one bound per element.
    :return: `True` if v qualifies, `False` otherwise.
    :rtype: bool
    """
    if type_str == 'int':
        element_types, func = int, is_int
    elif type_str == 'num':
        element_types, func = (int, float), is_num
    else:
        illegal_value(
            type_str, 'type_str', '_is_int_or_num_pair', raise_error, log)
        return False
    if not (isinstance(v, (tuple, list)) and len(v) == 2
            and all(isinstance(x, element_types) for x in v)):
        illegal_value(v, 'v', '_is_int_or_num_pair', raise_error, log)
        return False
    if ge is None and gt is None and le is None and lt is None:
        return True

    bounds = []
    for bound in (ge, gt, le, lt):
        # Probe for a scalar quietly (log=False): a valid pair bound is
        # not an error and must not be logged as an illegal scalar.
        if bound is None or func(bound, log=False):
            bounds.append(2*[bound])
        elif _is_int_or_num_pair(
                bound, type_str, raise_error=raise_error, log=log):
            bounds.append(bound)
        else:
            return False
    ge, gt, le, lt = bounds
    return all(
        func(v[i], ge[i], gt[i], le[i], lt[i], raise_error, log)
        for i in (0, 1))
378
+
379
+
380
def is_int_series(
        t_or_l, ge=None, gt=None, le=None, lt=None, raise_error=False,
        log=True):
    """Check that t_or_l is a tuple or list of integers, each within
    ge <= l[i] <= le or gt < l[i] < lt or some combination.

    :return: `True` if every element qualifies, `False` otherwise.
    :rtype: bool
    """
    if not test_ge_gt_le_lt(
            ge, gt, le, lt, is_int, 'is_int_series', raise_error, log):
        return False
    if not isinstance(t_or_l, (tuple, list)):
        illegal_value(t_or_l, 't_or_l', 'is_int_series', raise_error, log)
        return False
    return all(is_int(v, ge, gt, le, lt, raise_error, log) for v in t_or_l)
395
+
396
+
397
def is_num_series(
        t_or_l, ge=None, gt=None, le=None, lt=None, raise_error=False,
        log=True):
    """Check that t_or_l is a tuple or list of numbers, each within
    ge <= l[i] <= le or gt < l[i] < lt or some combination.

    :return: `True` if every element qualifies, `False` otherwise.
    :rtype: bool
    """
    # The qualifiers are numbers too: validating them with is_int would
    # reject float bounds such as ge=0.5.
    if not test_ge_gt_le_lt(
            ge, gt, le, lt, is_num, 'is_num_series', raise_error, log):
        return False
    if not isinstance(t_or_l, (tuple, list)):
        illegal_value(t_or_l, 't_or_l', 'is_num_series', raise_error, log)
        return False
    return all(is_num(v, ge, gt, le, lt, raise_error, log) for v in t_or_l)
412
+
413
+
414
def is_str_series(t_or_l, raise_error=False, log=True):
    """Check that t_or_l is a tuple or list of strings.

    :return: `True` if so, `False` otherwise.
    :rtype: bool
    """
    valid = (isinstance(t_or_l, (tuple, list))
             and all(isinstance(item, str) for item in t_or_l))
    if not valid:
        illegal_value(t_or_l, 't_or_l', 'is_str_series', raise_error, log)
    return valid
421
+
422
def is_str_or_str_series(t_or_l, raise_error=False, log=True):
    """Check that t_or_l is a string or a tuple or list of strings.

    :return: `True` if so, `False` otherwise.
    :rtype: bool
    """
    if isinstance(t_or_l, str):
        return True
    valid = (isinstance(t_or_l, (tuple, list))
             and all(isinstance(item, str) for item in t_or_l))
    if not valid:
        illegal_value(
            t_or_l, 't_or_l', 'is_str_or_str_series', raise_error, log)
    return valid
432
+
433
+
434
def is_dict_series(t_or_l, raise_error=False, log=True):
    """Check that t_or_l is a tuple or list of dictionaries.

    :return: `True` if so, `False` otherwise.
    :rtype: bool
    """
    valid = (isinstance(t_or_l, (tuple, list))
             and all(isinstance(item, dict) for item in t_or_l))
    if not valid:
        illegal_value(t_or_l, 't_or_l', 'is_dict_series', raise_error, log)
    return valid
441
+
442
+
443
def is_dict_nums(d, raise_error=False, log=True):
    """Check that d is a dictionary whose values are all single numbers.

    :return: `True` if so, `False` otherwise.
    :rtype: bool
    """
    valid = (isinstance(d, dict)
             and all(is_num(item, log=False) for item in d.values()))
    if not valid:
        illegal_value(d, 'd', 'is_dict_nums', raise_error, log)
    return valid
450
+
451
+
452
def is_dict_strings(d, raise_error=False, log=True):
    """Check that d is a dictionary whose values are all strings.

    :return: `True` if so, `False` otherwise.
    :rtype: bool
    """
    valid = (isinstance(d, dict)
             and all(isinstance(item, str) for item in d.values()))
    if not valid:
        illegal_value(d, 'd', 'is_dict_strings', raise_error, log)
    return valid
459
+
460
+
461
def is_index(v, ge=0, lt=None, raise_error=False, log=True):
    """Check that v is an array index in range ge <= v < lt.

    NOTE lt IS NOT included!

    :param v: Value to check.
    :param ge: Inclusive lower bound, defaults to `0`; `None` means no
        lower bound.
    :type ge: int, optional
    :param lt: Exclusive upper bound.
    :type lt: int, optional
    :return: `True` if v qualifies, `False` otherwise.
    :rtype: bool
    """
    # Reject an empty range up front so the error names is_index. There
    # is nothing to compare against when ge is None, so the comparison
    # is skipped rather than raising a TypeError.
    if isinstance(lt, int) and ge is not None and lt <= ge:
        illegal_combination(ge, 'ge', lt, 'lt', 'is_index', raise_error, log)
        return False
    return is_int(v, ge=ge, lt=lt, raise_error=raise_error, log=log)
471
+
472
+
473
def is_index_range(v, ge=0, le=None, lt=None, raise_error=False, log=True):
    """Check that v is an array index range with
    ge <= v[0] <= v[1] <= le or ge <= v[0] <= v[1] < lt.

    NOTE le IS included!

    :param v: Pair of integers.
    :param ge: Inclusive lower bound, defaults to `0`; `None` means no
        lower bound.
    :type ge: int, optional
    :param le: Inclusive upper bound (exclusive with lt).
    :type le: int, optional
    :param lt: Exclusive upper bound (exclusive with le).
    :type lt: int, optional
    :return: `True` if v qualifies, `False` otherwise.
    :rtype: bool
    """
    if not is_int_pair(v, raise_error=raise_error, log=log):
        return False
    if not test_ge_gt_le_lt(
            ge, None, le, lt, is_int, 'is_index_range', raise_error, log):
        return False
    in_range = (
        (ge is None or ge <= v[0])
        and v[0] <= v[1]
        and (le is None or v[1] <= le)
        and (lt is None or v[1] < lt))
    if in_range:
        return True
    # Describe only the bounds that were actually given, so the message
    # never shows e.g. '< None'.
    lower = '' if ge is None else f'{ge} <= '
    if le is not None:
        upper = f' <= {le}'
    elif lt is not None:
        upper = f' < {lt}'
    else:
        upper = ''
    error_msg = f'Value {v} out of range: !({lower}{v[0]} <= {v[1]}{upper})'
    if log:
        logger.error(error_msg)
    if raise_error:
        raise ValueError(error_msg)
    return False
496
+
497
+
498
def index_nearest(a, value):
    """Return the index of the array element nearest to value.

    The target is nudged up by one relative machine epsilon so that, for
    positive values, an exact midpoint between two elements resolves to
    the larger one. For negative values the same nudge moves the target
    toward -inf.

    :param a: One-dimensional array-like.
    :param value: Target value.
    :raises ValueError: If `a` has more than one dimension.
    :return: Index of the nearest element.
    :rtype: int
    """
    a = np.asarray(a)
    if a.ndim > 1:
        raise ValueError(
            f'Invalid array dimension for parameter a ({a.ndim}, {a})')
    # Round up for .5. Build a new target rather than using `*=`: with a
    # NumPy array as value, in-place multiplication would modify the
    # caller's array (or fail for integer arrays).
    target = value * (1.0 + sys.float_info.epsilon)
    return int(np.argmin(np.abs(a - target)))
507
+
508
+
509
def index_nearest_down(a, value):
    """Return the index of the nearest array element, rounded down.

    When the nearest element lies above value, the previous index is
    returned instead (never below 0). This assumes `a` is sorted in
    ascending order.

    :raises ValueError: If `a` has more than one dimension.
    :rtype: int
    """
    arr = np.asarray(a)
    if arr.ndim > 1:
        raise ValueError(
            f'Invalid array dimension for parameter a ({arr.ndim}, {arr})')
    nearest = int(np.argmin(np.abs(arr - value)))
    step_down = value < arr[nearest] and nearest > 0
    return nearest - 1 if step_down else nearest
519
+
520
+
521
def index_nearest_up(a, value):
    """Return the index of the nearest array element, rounded up.

    When the nearest element lies below value, the next index is
    returned instead (never past the last element). This assumes `a` is
    sorted in ascending order.

    :raises ValueError: If `a` has more than one dimension.
    :rtype: int
    """
    arr = np.asarray(a)
    if arr.ndim > 1:
        raise ValueError(
            f'Invalid array dimension for parameter a ({arr.ndim}, {arr})')
    nearest = int(np.argmin(np.abs(arr - value)))
    step_up = value > arr[nearest] and nearest < arr.size - 1
    return nearest + 1 if step_up else nearest
531
+
532
+
533
def get_consecutive_int_range(a):
    """Return a list of `[first, last]` pairs marking the runs of
    consecutive integers in a.

    A sorted copy of the input is scanned, so the caller's sequence is
    left untouched and tuples are accepted. Duplicates merge into the
    run they belong to.
    e.g. [5, 1, 2, 3, 8] -> [[1, 3], [5, 5], [8, 8]]

    :param a: Integers to group.
    :type a: Iterable[int]
    :return: The runs in ascending order.
    :rtype: list[list[int]]
    """
    int_ranges = []
    for value in sorted(a):
        # Values arrive in ascending order, so a gap of at most 1 from
        # the current run's end extends that run.
        if int_ranges and value <= int_ranges[-1][1] + 1:
            int_ranges[-1][1] = value
        else:
            int_ranges.append([value, value])
    return int_ranges
549
+
550
+
551
def round_to_n(x, n=1):
    """Round x to n significant figures, keeping the type of x.

    Zero is returned as the int `0`.
    """
    if x == 0.0:
        return 0
    leading_power = int(np.floor(np.log10(abs(x))))
    return type(x)(round(x, n - 1 - leading_power))
556
+
557
+
558
def round_up_to_n(x, n=1):
    """Round x away from zero to n significant figures, keeping the
    type of x.

    e.g. round_up_to_n(1234, 2) -> 1300, round_up_to_n(-1234, 2) -> -1300.
    Zero is returned unchanged.
    """
    x_round = round_to_n(x, n)
    # Compare magnitudes instead of testing abs(x/x_round) > 1: x_round
    # is 0 when x is 0 and the ratio would divide by zero.
    if abs(x) > abs(x_round):
        x_round += np.sign(x) * 10**(np.floor(np.log10(abs(x)))+1-n)
    return type(x)(x_round)
564
+
565
+
566
def trunc_to_n(x, n=1):
    """Truncate x toward zero to n significant figures, keeping the
    type of x.

    e.g. trunc_to_n(1267, 2) -> 1200. Zero is returned unchanged.
    """
    x_round = round_to_n(x, n)
    # If rounding increased the magnitude, step one unit back toward
    # zero. Comparing magnitudes avoids dividing by x, which raised
    # ZeroDivisionError for x == 0.
    if abs(x_round) > abs(x):
        x_round -= np.sign(x) * 10**(np.floor(np.log10(abs(x)))+1-n)
    return type(x)(x_round)
572
+
573
+
574
def almost_equal(a, b, sig_figs):
    """Check whether a and b agree to a given number of significant
    digits.

    The difference a-b, rounded to sig_figs significant figures, must
    be smaller in magnitude than 10**(1-sig_figs).

    :raises ValueError: If a or b is not a number.
    :rtype: bool
    """
    if not (is_num(a) and is_num(b)):
        raise ValueError(
            f'Invalid value for a or b in almost_equal (a: {a}, {type(a)}, '
            f'b: {b}, {type(b)})')
    return abs(round_to_n(a - b, sig_figs)) < pow(10, 1 - sig_figs)
582
+
583
+
584
def string_to_list(
        s, split_on_dash=True, remove_duplicates=True, sort=True,
        raise_error=False):
    """Return a list of numbers by splitting/expanding a string on any
    combination of commas, whitespaces, or dashes (when
    split_on_dash=True).
    e.g: '1, 3, 5-8, 12 ' -> [1, 3, 5, 6, 7, 8, 12]

    Each token is parsed with `ast.literal_eval`. A dash range `a-b`
    requires integers with b > a. With split_on_dash=True a leading
    minus sign is taken as a range separator, so negative numbers
    cannot be entered.

    :param s: Input string.
    :type s: str
    :param split_on_dash: Allow dashes in input string,
        defaults to `True`.
    :type split_on_dash: bool, optional
    :param remove_duplicates: Removes duplicates (may also change the
        order), defaults to `True`.
    :type remove_duplicates: bool, optional
    :param sort: Sort in ascending order, defaults to `True`.
    :type sort: bool, optional
    :param raise_error: Raise an exception upon any error,
        defaults to `False`.
    :type raise_error: bool, optional
    :return: The parsed list (empty for an empty or whitespace-only
        string), or `None` upon an illegal input when raise_error is
        `False`.
    :rtype: Union[list, None]
    """
    if not isinstance(s, str):
        illegal_value(
            s, 's', location='string_to_list', raise_error=raise_error)
        return None
    s = s.strip()
    if not s:
        return []
    # One handler for every parse stage so raise_error is honoured
    # consistently, including the non-dash path and dedupe/sort of
    # unhashable or incomparable values.
    try:
        tokens = re.split(r'\s+,\s+|\s+,|,\s+|\s+|,', s)
        if split_on_dash:
            l_of_i = []
            for token in tokens:
                bounds = [
                    literal_eval(x)
                    for x in re.split(r'\s+-\s+|\s+-|-\s+|\s+|-', token)]
                if len(bounds) == 1:
                    l_of_i += bounds
                elif len(bounds) == 2 and bounds[1] > bounds[0]:
                    l_of_i += list(range(bounds[0], 1+bounds[1]))
                else:
                    raise ValueError(f'Invalid range {token!r}')
        else:
            l_of_i = [literal_eval(x) for x in tokens]
        if remove_duplicates:
            l_of_i = list(dict.fromkeys(l_of_i))
        if sort:
            l_of_i = sorted(l_of_i)
    except (ValueError, TypeError, SyntaxError, MemoryError,
            RecursionError):
        if raise_error:
            raise
        return None
    return l_of_i
645
+
646
+
647
def list_to_string(a):
    """Return the runs of consecutive integers in a as a string.

    Runs are found with get_consecutive_int_range; single values are
    written as 'n' and longer runs as 'first-last', joined by ', '.
    e.g. [1, 2, 3, 5, 7, 8] -> '1-3, 5, 7-8'
    """
    pieces = [
        f'{first}' if first == last else f'{first}-{last}'
        for first, last in get_consecutive_int_range(a)]
    return ', '.join(pieces)
663
+
664
+
665
def get_trailing_int(string):
    """Return the integer formed by the trailing digits of string, or
    `None` when the string does not end in a digit.
    """
    found = re.search(r'\d+$', string)
    return None if found is None else int(found.group())
672
+
673
+
674
def input_int(
        s=None, ge=None, gt=None, le=None, lt=None, default=None, inset=None,
        raise_error=False, log=True):
    """Interactively prompt the user to enter an integer.

    :param s: Prompt text; a generic prompt is used when `None`.
    :param ge, gt, le, lt: Range qualifiers for the entered value.
    :param default: Value used for an empty response.
    :param inset: Optional tuple or list of allowed integers.
    :return: The entered integer, or `None` upon invalid arguments.
    """
    return _input_int_or_num(
        'int', s=s, ge=ge, gt=gt, le=le, lt=lt, default=default,
        inset=inset, raise_error=raise_error, log=log)
680
+
681
+
682
def input_num(
        s=None, ge=None, gt=None, le=None, lt=None, default=None,
        raise_error=False, log=True):
    """Interactively prompt the user to enter a number.

    :param s: Prompt text; a generic prompt is used when `None`.
    :param ge, gt, le, lt: Range qualifiers for the entered value.
    :param default: Value used for an empty response.
    :return: The entered number, or `None` upon invalid arguments.
    """
    return _input_int_or_num(
        'num', s=s, ge=ge, gt=gt, le=le, lt=lt, default=default,
        inset=None, raise_error=raise_error, log=log)
688
+
689
+
690
def _input_int_or_num(
        type_str, s=None, ge=None, gt=None, le=None, lt=None, default=None,
        inset=None, raise_error=False, log=True):
    """Interactively prompt the user to enter an integer or number.

    The prompt is repeated until a valid value is entered; an empty
    response selects `default`.

    :param type_str: `'int'` or `'num'`.
    :type type_str: str
    :param s: Prompt text; a generic prompt is used when `None`.
    :type s: str, optional
    :param ge, gt, le, lt: Range qualifiers for the entered value.
    :param default: Value used for an empty response; must itself
        satisfy the type and range qualifiers.
    :param inset: Allowed integer values, checked for typed responses
        only (not for the default).
    :type inset: Union[tuple[int], list[int]], optional
    :param raise_error: Raise on invalid arguments, defaults to `False`.
    :type raise_error: bool, optional
    :param log: Log invalid arguments, defaults to `True`.
    :type log: bool, optional
    :return: The entered value, or `None` upon invalid arguments.
    """
    if type_str == 'int':
        validator = is_int
    elif type_str == 'num':
        validator = is_num
    else:
        illegal_value(
            type_str, 'type_str', '_input_int_or_num', raise_error, log)
        return None
    if not test_ge_gt_le_lt(
            ge, gt, le, lt, validator, '_input_int_or_num', raise_error,
            log):
        return None
    if default is not None:
        if not _is_int_or_num(
                default, type_str, raise_error=raise_error, log=log):
            return None
        if ge is not None and default < ge:
            illegal_combination(
                ge, 'ge', default, 'default', '_input_int_or_num',
                raise_error, log)
            return None
        if gt is not None and default <= gt:
            illegal_combination(
                gt, 'gt', default, 'default', '_input_int_or_num',
                raise_error, log)
            return None
        if le is not None and default > le:
            illegal_combination(
                le, 'le', default, 'default', '_input_int_or_num',
                raise_error, log)
            return None
        if lt is not None and default >= lt:
            illegal_combination(
                lt, 'lt', default, 'default', '_input_int_or_num',
                raise_error, log)
            return None
        default_string = f' [{default}]'
    else:
        default_string = ''
    if inset is not None:
        if (not isinstance(inset, (tuple, list))
                or any(not isinstance(i, int) for i in inset)):
            illegal_value(
                inset, 'inset', '_input_int_or_num', raise_error, log)
            return None
    v_range = range_string_ge_gt_le_lt(ge, gt, le, lt)
    if v_range:
        v_range = f' {v_range}'
    if s is None:
        what = 'an integer' if type_str == 'int' else 'a number'
        prompt = f'Enter {what}{v_range}{default_string}: '
    else:
        prompt = f'{s}{v_range}{default_string}: '

    # Retry in a loop rather than by recursion so repeated invalid
    # input cannot exhaust the recursion limit.
    while True:
        print(prompt)
        try:
            response = input()
            if isinstance(response, str) and not response:
                v = default
                print(f'{v}')
            else:
                # literal_eval only evaluates literals, so it is safe on
                # user input.
                v = literal_eval(response)
                if inset and v not in inset:
                    raise ValueError(f'{v} not part of the set {inset}')
        except (ValueError, TypeError, SyntaxError, MemoryError,
                RecursionError):
            v = None
        if _is_int_or_num(v, type_str, ge, gt, le, lt):
            return v
764
+
765
+
766
def input_int_list(
        s=None, num_max=None, ge=None, le=None, split_on_dash=True,
        remove_duplicates=True, sort=True, raise_error=False, log=True):
    """Prompt the user to input a list of integers.

    The entered string is split on any combination of commas,
    whitespaces, or dashes (dash ranges are expanded when
    split_on_dash is `True`), e.g. '1 3,5-8 , 12 ' ->
    [1, 3, 5, 6, 7, 8, 12]. Invalid input triggers a new prompt.

    :param s: Interactive user prompt, defaults to `None`.
    :type s: str, optional
    :param num_max: Maximum number of inputs in list.
    :type num_max: int, optional
    :param ge: Minimum value of inputs in list.
    :type ge: int, optional
    :param le: Maximum value of inputs in list.
    :type le: int, optional
    :param split_on_dash: Allow dashes in input string,
        defaults to `True`.
    :type split_on_dash: bool, optional
    :param remove_duplicates: Removes duplicates (may also change the
        order), defaults to `True`.
    :type remove_duplicates: bool, optional
    :param sort: Sort in ascending order, defaults to `True`.
    :type sort: bool, optional
    :param raise_error: Raise an exception upon any error,
        defaults to `False`.
    :type raise_error: bool, optional
    :param log: Print an error message upon any error,
        defaults to `True`.
    :type log: bool, optional
    :return: Input list or `None` upon illegal arguments.
    :rtype: list
    """
    return _input_int_or_num_list(
        'int', s=s, num_max=num_max, ge=ge, le=le,
        split_on_dash=split_on_dash, remove_duplicates=remove_duplicates,
        sort=sort, raise_error=raise_error, log=log)
802
+
803
+
804
def input_num_list(
        s=None, num_max=None, ge=None, le=None, remove_duplicates=True,
        sort=True, raise_error=False, log=True):
    """Prompt the user to input a list of numbers.

    The entered string is split on any combination of commas or
    whitespaces, and each item keeps the type it was typed as,
    e.g. '1.0, 3, 5.8, 12 ' -> [1.0, 3, 5.8, 12]. Invalid input
    triggers a new prompt.

    :param s: Interactive user prompt.
    :type s: str, optional
    :param num_max: Maximum number of inputs in list.
    :type num_max: int, optional
    :param ge: Minimum value of inputs in list.
    :type ge: float, optional
    :param le: Maximum value of inputs in list.
    :type le: float, optional
    :param remove_duplicates: Removes duplicates (may also change the
        order), defaults to `True`.
    :type remove_duplicates: bool, optional
    :param sort: Sort in ascending order, defaults to `True`.
    :type sort: bool, optional
    :param raise_error: Raise an exception upon any error,
        defaults to `False`.
    :type raise_error: bool, optional
    :param log: Print an error message upon any error,
        defaults to `True`.
    :type log: bool, optional
    :return: Input list or `None` upon illegal arguments.
    :rtype: list
    """
    return _input_int_or_num_list(
        'num', s=s, num_max=num_max, ge=ge, le=le, split_on_dash=False,
        remove_duplicates=remove_duplicates, sort=sort,
        raise_error=raise_error, log=log)
836
+
837
+
838
def _input_int_or_num_list(
        type_str, s=None, num_max=None, ge=None, le=None, split_on_dash=True,
        remove_duplicates=True, sort=True, raise_error=False, log=True):
    """Prompt the user for a list of integers or numbers.

    Shared implementation of input_int_list and input_num_list. The
    prompt is repeated until the entered list is valid.

    :param type_str: `'int'` or `'num'`.
    :type type_str: str
    :param s: Prompt text; a generic prompt is used when `None`.
    :type s: str, optional
    :param num_max: Maximum number of inputs in list.
    :type num_max: int, optional
    :param ge: Minimum value of inputs in list.
    :param le: Maximum value of inputs in list.
    :param split_on_dash: Allow dash ranges, defaults to `True`.
    :param remove_duplicates: Remove duplicates, defaults to `True`.
    :param sort: Sort in ascending order, defaults to `True`.
    :param raise_error: Raise on invalid arguments, defaults to `False`.
    :param log: Log invalid arguments, defaults to `True`.
    :return: The entered list, or `None` upon invalid arguments.
    :rtype: Union[list, None]
    """
    # RV do we want a limit on max dimension?
    if type_str == 'int':
        validator = is_int
    elif type_str == 'num':
        validator = is_num
    else:
        illegal_value(
            type_str, 'type_str', '_input_int_or_num_list', raise_error, log)
        return None
    if not test_ge_gt_le_lt(
            ge, None, le, None, validator, 'input_int_or_num_list',
            raise_error, log):
        return None
    if (num_max is not None
            and not is_int(num_max, gt=0, raise_error=raise_error, log=log)):
        return None
    v_range = range_string_ge_gt_le_lt(ge=ge, le=le)
    if v_range:
        v_range = f' (each value in {v_range})'
    if s is None:
        kind = 'integers' if type_str == 'int' else 'numbers'
        prompt = f'Enter a series of {kind}{v_range}: '
    else:
        prompt = f'{s}{v_range}: '
    num = '' if num_max is None else f'up to {num_max} '

    # Retry in a loop so every argument (num_max included) is kept as
    # given; a positional recursive retry is easy to misalign and
    # repeated bad input could exhaust the recursion limit.
    while True:
        print(prompt)
        try:
            _list = string_to_list(
                input(), split_on_dash, remove_duplicates, sort)
        except (ValueError, TypeError, SyntaxError, MemoryError,
                RecursionError):
            _list = None
        except Exception as exc:
            raise Exception('Unexpected error') from exc
        if (isinstance(_list, list)
                and (num_max is None or len(_list) <= num_max)
                and all(_is_int_or_num(v, type_str, ge=ge, le=le)
                        for v in _list)):
            return _list
        if split_on_dash:
            print(f'Invalid input: enter a valid set of {num}dash/comma/'
                  'whitespace separated numbers e.g. 1 3,5-8 , 12')
        else:
            print(f'Invalid input: enter a valid set of {num}comma/whitespace '
                  'separated numbers e.g. 1 3,5 8 , 12')
886
+
887
+
888
def input_yesno(s=None, default=None):
    """Interactively prompt the user to answer a yes/no question.

    An answer counts as yes if it is a non-empty, case-insensitive
    prefix of "yes" ("y", "ye", "yes"). It counts as no if it is such a
    prefix of "no" ("n", "no"). The user is re-prompted until a valid
    answer is given.

    :param s: Prompt text, defaults to a generic yes/no prompt.
    :type s: str, optional
    :param default: Answer used when the user just presses Enter. It
        must itself be a valid yes/no answer.
    :type default: str, optional
    :return: `True` for yes, `False` for no, or `None` if `default` is
        invalid.
    :rtype: bool
    """
    def _is_prefix_of(answer, word):
        # Prefix match, not substring match: 'e' or 's' must not count
        # as "yes", and '' must not match anything.
        return bool(answer) and word.startswith(answer.lower())

    if default is not None:
        if not isinstance(default, str):
            illegal_value(default, 'default', 'input_yesno')
            return None
        if _is_prefix_of(default, 'yes'):
            default = 'y'
        elif _is_prefix_of(default, 'no'):
            default = 'n'
        else:
            illegal_value(default, 'default', 'input_yesno')
            return None
        default_string = f' [{default}]'
    else:
        default_string = ''
    while True:
        if s is None:
            print(f'Enter yes or no{default_string}: ')
        else:
            print(f'{s}{default_string}: ')
        i = input()
        if isinstance(i, str) and not i:
            # Empty input: fall back on (and echo) the default answer
            i = default
            print(f'{i}')
        if i is not None and _is_prefix_of(i, 'yes'):
            return True
        if i is not None and _is_prefix_of(i, 'no'):
            return False
        print('Invalid input, enter yes or no')
920
+
921
+
922
def input_menu(items, default=None, header=None):
    """Interactively prompt the user to select an item from a menu.

    The items are listed with 1-based numbers. The user is re-prompted,
    using the same header, until a valid number is entered.

    :param items: Menu entries.
    :type items: list[str] or tuple[str]
    :param default: Item selected when the user just presses Enter. It
        must be one of `items`.
    :type default: str, optional
    :param header: Text shown instead of the generic prompt.
    :type header: str, optional
    :return: Zero-based index of the selected item, or `None` if
        `items` or `default` is invalid.
    :rtype: int
    """
    if (not isinstance(items, (tuple, list))
            or any(not isinstance(i, str) for i in items)):
        illegal_value(items, 'items', 'input_menu')
        return None
    if default is not None:
        if not (isinstance(default, str) and default in items):
            logger.error(
                f'Invalid value for default ({default}), must be in {items}')
            return None
        default_string = f' [{1+items.index(default)}]'
    else:
        default_string = ''
    while True:
        if header is None:
            print('Choose one of the following items '
                  f'(1, {len(items)}){default_string}:')
        else:
            print(f'{header} (1, {len(items)}){default_string}:')
        for i, item in enumerate(items):
            print(f' {i+1}: {item}')
        try:
            choice = input()
            if isinstance(choice, str) and not choice:
                if default is None:
                    # Nothing entered and no default to fall back on
                    raise ValueError
                choice = items.index(default)
                print(f'{1+choice}')
            else:
                # literal_eval only evaluates Python literals, so it is
                # safe to use on user input
                choice = literal_eval(choice)
                # bool is a subclass of int, so reject it explicitly
                if (isinstance(choice, int) and not isinstance(choice, bool)
                        and 1 <= choice <= len(items)):
                    choice -= 1
                else:
                    raise ValueError
        except (ValueError, TypeError, SyntaxError, MemoryError,
                RecursionError):
            choice = None
        except Exception as exc:
            raise Exception('Unexpected error') from exc
        if choice is not None:
            return choice
        print(f'Invalid choice, enter a number between 1 and {len(items)}')
962
+
963
+
964
def assert_no_duplicates_in_list_of_dicts(_list, raise_error=False):
    """Assert that there are no duplicates in a list of dictionaries.

    Two dictionaries are duplicates when they compare equal. Values may
    be unhashable (e.g. lists), and keys need not be mutually orderable.

    :param _list: The dictionaries to check.
    :type _list: list[dict]
    :param raise_error: Raise a `ValueError` on duplicates instead of
        logging an error, defaults to `False`.
    :type raise_error: bool, optional
    :raises ValueError: If duplicates are found and `raise_error` is
        `True`.
    :return: `_list` itself, or `None` if it is invalid or contains
        duplicates.
    :rtype: list[dict]
    """
    if (not isinstance(_list, list)
            or any(not isinstance(d, dict) for d in _list)):
        illegal_value(
            _list, '_list', 'assert_no_duplicates_in_list_of_dicts',
            raise_error)
        return None
    try:
        # Fast path: canonical (key-sorted) item tuples are hashable
        # when all keys are orderable and all values are hashable.
        has_duplicates = (
            len({tuple(sorted(d.items())) for d in _list}) != len(_list))
    except TypeError:
        # Fall back on pairwise dict equality, which is always defined
        has_duplicates = any(
            d1 == d2 for i, d1 in enumerate(_list) for d2 in _list[i+1:])
    if has_duplicates:
        if raise_error:
            raise ValueError(f'Duplicate items found in {_list}')
        logger.error(f'Duplicate items found in {_list}')
        return None
    return _list
984
+
985
+
986
def assert_no_duplicate_key_in_list_of_dicts(_list, key, raise_error=False):
    """Assert that there are no duplicate keys in a list of
    dictionaries.

    Every dictionary must have a non-`None` value for `key`, and these
    values must all be distinct.

    :param _list: The dictionaries to check.
    :type _list: list[dict]
    :param key: The key whose values must be unique.
    :type key: str
    :param raise_error: Raise a `ValueError` on duplicate or missing
        values instead of logging an error, defaults to `False`.
    :type raise_error: bool, optional
    :raises ValueError: If a value is duplicate or missing and
        `raise_error` is `True`.
    :return: `_list` itself, or `None` if the input is invalid.
    :rtype: list[dict]
    """
    if not isinstance(key, str):
        illegal_value(
            key, 'key', 'assert_no_duplicate_key_in_list_of_dicts',
            raise_error)
        return None
    if not isinstance(_list, list):
        illegal_value(
            _list, '_list', 'assert_no_duplicate_key_in_list_of_dicts',
            raise_error)
        return None
    # Every element must be a dict (the original check was inverted and
    # rejected any list that contained a dict)
    if any(not isinstance(d, dict) for d in _list):
        illegal_value(
            _list, '_list', 'assert_no_duplicate_key_in_list_of_dicts',
            raise_error)
        return None
    keys = [d.get(key, None) for d in _list]
    if None in keys or len(set(keys)) != len(_list):
        if raise_error:
            raise ValueError(
                f'Duplicate or missing key ({key}) found in {_list}')
        logger.error(f'Duplicate or missing key ({key}) found in {_list}')
        return None
    return _list
1013
+
1014
+
1015
def assert_no_duplicate_attr_in_list_of_objs(_list, attr, raise_error=False):
    """Assert that there are no duplicate attributes in a list of
    objects.

    Every object must have a non-`None` value for attribute `attr`, and
    these values must all be distinct.

    :param _list: The objects to check.
    :type _list: list
    :param attr: The attribute name whose values must be unique.
    :type attr: str
    :param raise_error: Raise a `ValueError` on duplicate or missing
        values instead of logging an error, defaults to `False`.
    :type raise_error: bool, optional
    :raises ValueError: If a value is duplicate or missing and
        `raise_error` is `True`.
    :return: `_list` itself, or `None` if the input is invalid.
    :rtype: list
    """
    if not isinstance(attr, str):
        illegal_value(
            attr, 'attr', 'assert_no_duplicate_attr_in_list_of_objs',
            raise_error)
        return None
    if not isinstance(_list, list):
        illegal_value(
            _list, '_list', 'assert_no_duplicate_attr_in_list_of_objs',
            raise_error)
        return None
    attrs = [getattr(obj, attr, None) for obj in _list]
    if None in attrs or len(set(attrs)) != len(_list):
        if raise_error:
            raise ValueError(
                f'Duplicate or missing attr ({attr}) found in {_list}')
        logger.error(f'Duplicate or missing attr ({attr}) found in {_list}')
        return None
    return _list
1037
+
1038
+
1039
def file_exists_and_readable(f):
    """Validate that `f` names an existing, readable regular file.

    :param f: Path to check.
    :type f: str
    :raises ValueError: If `f` is not a regular file or cannot be read.
    :return: `f`, unchanged.
    :rtype: str
    """
    if os.path.isfile(f):
        if os.access(f, os.R_OK):
            return f
        raise ValueError(f'{f} is not accessible for reading')
    raise ValueError(f'{f} is not a valid file')
1046
+
1047
+
1048
def rolling_average(
        y, x=None, dtype=None, start=0, end=None, width=None,
        stride=None, num=None, average=True, mode='valid',
        use_convolve=None):
    """Returns the rolling sum or average of an array over the last
    dimension.

    :param y: Input data. All leading dimensions are flattened and the
        rolling operation runs along the last axis. The result is
        reshaped back to the leading dimensions of `y`.
    :type y: array-like
    :param x: 1D coordinates matching the last dimension of `y`. When
        given, the window averages of `x` are returned as well.
    :type x: array-like, optional
    :param dtype: Output dtype of the rescaled/averaged result. It
        defaults to the dtype of `y` when `average` is `True` and to
        `numpy.float32` otherwise.
    :type dtype: numpy.dtype, optional
    :param start: First index along the last axis, defaults to `0`.
    :type start: int, optional
    :param end: End index (exclusive) along the last axis, defaults to
        the size of that axis.
    :type end: int, optional
    :param width: Window width.
    :type width: int, optional
    :param stride: Step between consecutive windows.
    :type stride: int, optional
    :param num: Number of windows. At least one of `width`, `stride`
        or `num` must be given.
    :type num: int, optional
    :param average: Return window averages instead of sums,
        defaults to `True`.
    :type average: bool, optional
    :param mode: `'valid'` (full windows only) or `'full'` (a trailing
        partial window is included; its sum is rescaled by
        `width/(actual window size)`), defaults to `'valid'`.
    :type mode: str, optional
    :param use_convolve: Compute the windows with `numpy.convolve`.
        Defaults to `True` for 1D `y` and `False` otherwise.
    :type use_convolve: bool, optional
    :raises ValueError: Upon any invalid input parameter.
    :return: The rolling result, and also the rolling average of `x`
        when `x` is given.
    :rtype: numpy.ndarray [, numpy.ndarray]
    """
    y = np.asarray(y)
    y_shape = y.shape
    if y.ndim == 1:
        y = np.expand_dims(y, 0)
    else:
        y = y.reshape((np.prod(y.shape[0:-1]), y.shape[-1]))
    if x is not None:
        x = np.asarray(x)
        if x.ndim != 1:
            raise ValueError('Parameter "x" must be a 1D array-like')
        if x.size != y.shape[1]:
            raise ValueError(f'Dimensions of "x" and "y[1]" do not '
                             f'match ({x.size} vs {y.shape[1]})')
    if dtype is None:
        if average:
            dtype = y.dtype
        else:
            dtype = np.float32
    if width is None and stride is None and num is None:
        raise ValueError('Invalid input parameters, specify at least one of '
                         '"width", "stride" or "num"')
    if width is not None and not is_int(width, ge=1):
        raise ValueError(f'Invalid "width" parameter ({width})')
    if stride is not None and not is_int(stride, ge=1):
        raise ValueError(f'Invalid "stride" parameter ({stride})')
    if num is not None and not is_int(num, ge=1):
        raise ValueError(f'Invalid "num" parameter ({num})')
    if not isinstance(average, bool):
        raise ValueError(f'Invalid "average" parameter ({average})')
    if mode not in ('valid', 'full'):
        raise ValueError(f'Invalid "mode" parameter ({mode})')
    size = y.shape[1]
    if size < 2:
        raise ValueError(f'Invalid y[1] dimension ({size})')
    if not is_int(start, ge=0, lt=size):
        raise ValueError(f'Invalid "start" parameter ({start})')
    if end is None:
        end = size
    elif not is_int(end, gt=start, le=size):
        raise ValueError(f'Invalid "end" parameter ({end})')
    if use_convolve is None:
        use_convolve = bool(len(y_shape) == 1)
    if use_convolve and (start or end < size):
        # The convolution path works on an explicitly cropped copy
        y = np.take(y, range(start, end), axis=1)
        if x is not None:
            x = x[start:end]
        size = y.shape[1]
    else:
        # The summation path indexes from `start` into the uncropped y
        size = end-start

    # Derive whichever of width/stride were not given
    if stride is None:
        if width is None:
            width = max(1, int(size/num))
            stride = width
        else:
            width = min(width, size)
            if num is None:
                stride = width
            else:
                stride = max(1, int((size-width) / (num-1)))
    else:
        # NOTE(review): clamps stride to size-stride (not size-width);
        # intent unclear, kept as is
        stride = min(stride, size-stride)
        if width is None:
            width = stride

    # Number of output windows
    if mode == 'valid':
        num = 1 + max(0, int((size-width) / stride))
    else:
        num = int(size/stride)
        if num*stride < size:
            num += 1

    if use_convolve:
        # Actual number of samples in each window (smaller for a
        # trailing partial window)
        n_start = 0
        n_end = width
        weight = np.empty((num))
        for n in range(num):
            n_num = n_end-n_start
            weight[n] = n_num
            n_start += stride
            n_end = min(size, n_end+stride)

        window = np.ones((width))
        if x is not None:
            if mode == 'valid':
                rx = np.convolve(x, window)[width-1:1-width:stride]
            else:
                rx = np.convolve(x, window)[width-1::stride]
            rx /= weight

        ry = []
        if mode == 'valid':
            for i in range(y.shape[0]):
                ry.append(np.convolve(y[i], window)[width-1:1-width:stride])
        else:
            for i in range(y.shape[0]):
                ry.append(np.convolve(y[i], window)[width-1::stride])
        ry = np.reshape(ry, (*y_shape[0:-1], num))
        if len(y_shape) == 1:
            ry = np.squeeze(ry)
        if average:
            ry = (np.asarray(ry).astype(np.float32)/weight).astype(dtype)
        elif mode != 'valid':
            # Rescale partial-window sums to a full window
            weight = np.where(weight < width, width/weight, 1.0)
            ry = (np.asarray(ry).astype(np.float32)*weight).astype(dtype)
    else:
        ry = np.zeros((num, y.shape[0]), dtype=y.dtype)
        if x is not None:
            rx = np.zeros(num, dtype=x.dtype)
        n_start = start
        n_end = n_start+width
        for n in range(num):
            y_sum = np.sum(y[:,n_start:n_end], 1)
            n_num = n_end-n_start
            if n_num < width:
                # Rescale a partial window to a full one. This must not
                # be done in place: for integer y an in-place multiply
                # by a float raises a numpy casting error.
                y_sum = y_sum * (width/n_num)
            ry[n] = y_sum
            if x is not None:
                rx[n] = np.sum(x[n_start:n_end])/n_num
            n_start += stride
            n_end = min(start+size, n_end+stride)
        ry = np.reshape(ry.T, (*y_shape[0:-1], num))
        if len(y_shape) == 1:
            ry = np.squeeze(ry)
        if average:
            ry = (ry.astype(np.float32)/width).astype(dtype)

    if x is None:
        return ry
    return ry, rx
1185
+
1186
+
1187
def baseline_arPLS(
        y, mask=None, w=None, tol=1.e-8, lam=1.e6, max_iter=20,
        full_output=False):
    """Returns the smoothed baseline estimate of a spectrum.

    Based on S.-J. Baek, A. Park, Y.-J Ahn, and J. Choo,
    "Baseline correction using asymmetrically reweighted penalized
    least squares smoothing", Analyst, 2015,140, 250-257

    :param y: The spectrum.
    :type y: array-like
    :param mask: A mask to apply to the spectrum before baseline
        construction. The baseline is set to zero outside the mask.
    :type mask: array-like, optional
    :param w: The weights (allows restart for additional iterations).
    :type w: numpy.array, optional
    :param tol: The convergence tolerance, defaults to `1.e-8`.
    :type tol: float, optional
    :param lam: The &lambda (smoothness) parameter (the balance
        between the residual of the data and the baseline and the
        smoothness of the baseline). The suggested range is between
        100 and 10^8, defaults to `10^6`.
    :type lam: float, optional
    :param max_iter: The maximum number of iterations,
        defaults to `20`.
    :type max_iter: int, optional
    :param full_output: Whether or not to also output the baseline
        corrected spectrum, the weights, the number of iterations and
        the error in the returned result, defaults to `False`.
    :type full_output: bool, optional
    :raises ValueError: Upon an invalid `tol`, `lam`, `max_iter` or
        `full_output` parameter.
    :return: The smoothed baseline, with optionally the baseline
        corrected spectrum, the weights, the number of iterations and
        the error in the returned result.
    :rtype: numpy.array [, numpy.array, numpy.array, int, float]
    """
    # With credit to: Daniel Casas-Orozco
    # https://stackoverflow.com/questions/29156532/python-baseline-correction-library
    # Third party modules
    from scipy.sparse import (
        spdiags,
        linalg,
    )

    if not is_num(tol, gt=0):
        raise ValueError(f'Invalid tol parameter ({tol})')
    if not is_num(lam, gt=0):
        raise ValueError(f'Invalid lam parameter ({lam})')
    if not is_int(max_iter, gt=0):
        raise ValueError(f'Invalid max_iter parameter ({max_iter})')
    if not isinstance(full_output, bool):
        raise ValueError(f'Invalid full_output parameter ({full_output})')
    y = np.asarray(y)
    if mask is not None:
        # Accept any array-like mask, not only numpy arrays
        mask = np.asarray(mask).astype(bool)
        y_org = y
        y = y[mask]
    num = y.size

    # Second-order difference operator and the smoothness penalty
    diag = np.ones((num-2))
    D = spdiags([diag, -2*diag, diag], [0, -1, -2], num, num-2)

    H = lam * D.dot(D.T)

    if w is None:
        w = np.ones(num)
    W = spdiags(w, 0, num, num)

    # Start "infinitely far" from convergence so that at least one
    # iteration always runs and z is defined, whatever the value of tol
    error = np.inf
    num_iter = 0

    # Largest exponent that does not overflow np.exp
    exp_max = int(np.log(sys.float_info.max))
    while error > tol and num_iter < max_iter:
        z = linalg.spsolve(W + H, W @ y)
        d = y - z
        dn = d[d < 0]

        m = np.mean(dn)
        s = np.std(dn)

        w_new = 1.0 / (1.0 + np.exp(
            np.clip(2.0 * (d - (2.0*s - m))/s, None, exp_max)))
        error = np.linalg.norm(w_new - w) / np.linalg.norm(w)
        num_iter += 1
        w = w_new
        W.setdiag(w)

    if mask is not None:
        # Scatter the masked baseline back onto the full spectrum
        zz = np.zeros(y_org.size)
        zz[mask] = z
        z = zz
        if full_output:
            d = y_org - z
    if full_output:
        return z, d, w, num_iter, float(error)
    return z
1282
+
1283
+
1284
def fig_to_iobuf(fig, fileformat=None):
    """Return an in-memory object as a byte stream represention of
    a Matplotlib figure.

    :param fig: Matplotlib figure object.
    :type fig: matplotlib.figure.Figure
    :param fileformat: Valid Matplotlib saved figure file format,
        defaults to `'png'`. A format the figure's canvas does not
        support silently falls back to `'png'`.
    :type fileformat: str, optional
    :return: Byte stream representation of the Matplotlib figure and
        the associated file format.
    :rtype: tuple[_io.BytesIO, str]
    """
    # System modules
    from io import BytesIO

    # Query the canvas of the figure being saved, not plt.gcf(). The
    # latter may be a different figure, or a newly created empty one
    # when no figure is current.
    if (fileformat is None
            or fileformat not in fig.canvas.get_supported_filetypes()):
        fileformat = 'png'

    buf = BytesIO()
    fig.savefig(buf, format=fileformat)
    return buf, fileformat
1309
+
1310
+
1311
def save_iobuf_fig(buf, filename, force_overwrite=False):
    """Save a byte stream represention of a Matplotlib figure to file.

    Missing parent directories of `filename` are created.

    :param buf: Byte stream representation of the Matplotlib figure.
    :type buf: _io.BytesIO
    :param filename: Filename (with a valid extension).
    :type filename: str
    :param force_overwrite: Flag to allow `filename` to be overwritten
        if it already exists, defaults to `False`.
    :type force_overwrite: bool, optional
    :raises ValueError: If `filename` has no extension, or one that
        PIL cannot save.
    :raises FileExistsError: If a file already exists and
        `force_overwrite` is `False`.
    """
    # Third party modules
    from PIL import Image

    # Extensions (e.g. '.png') for which PIL has a registered writer
    exts = {
        ex for ex, fmt in Image.registered_extensions().items()
        if fmt in Image.SAVE}

    # Validate filename and extension (PIL registers lowercase
    # extensions, so '.PNG' is accepted as well)
    _, ext = os.path.splitext(filename)
    if not ext or ext.lower() not in exts:
        raise ValueError(f'Invalid file format {ext[1:]}')
    filedir = os.path.dirname(filename)
    # A bare filename has an empty dirname: there is nothing to create
    if filedir and not os.path.isdir(filedir):
        os.makedirs(filedir, exist_ok=True)
    if os.path.isfile(filename) and not force_overwrite:
        raise FileExistsError(f'{filename} already exists')

    # Write image to file, releasing the PIL image afterwards
    buf.seek(0)
    with Image.open(buf) as img:
        img.save(filename)
1344
+
1345
+
1346
def select_mask_1d(
        y, x=None, preselected_index_ranges=None, preselected_mask=None,
        title=None, xlabel=None, ylabel=None, min_num_index_ranges=None,
        max_num_index_ranges=None, interactive=True, filename=None,
        return_buf=False):
    """Display a lineplot and have the user select a mask.

    :param y: One-dimensional data array for which a mask will be
        constructed.
    :type y: numpy.ndarray
    :param x: x-coordinates of the reference data. They must be
        strictly increasing.
    :type x: numpy.ndarray, optional
    :param preselected_index_ranges: List of preselected index ranges
        to mask (bounds are inclusive).
    :type preselected_index_ranges: Union(list[tuple(int, int)],
        list[list[int]], list[tuple(float, float)], list[list[float]]),
        optional
    :param preselected_mask: Preselected boolean mask array.
    :type preselected_mask: numpy.ndarray, optional
    :param title: Title for the displayed figure.
    :type title: str, optional
    :param xlabel: Label for the x-axis of the displayed figure.
    :type xlabel: str, optional
    :param ylabel: Label for the y-axis of the displayed figure.
    :type ylabel: str, optional
    :param min_num_index_ranges: The minimum number of selected index
        ranges.
    :type min_num_index_ranges: int, optional
    :param max_num_index_ranges: The maximum number of selected index
        ranges.
    :type max_num_index_ranges: int, optional
    :param interactive: Show the plot and allow user interactions with
        the matplotlib figure, defaults to `True`.
    :type interactive: bool, optional
    :param filename: Save a .png of the plot to filename, defaults to
        `None`, in which case the plot is not saved.
    :type filename: str, optional
    :param return_buf: Return an in-memory object as a byte stream
        represention of the Matplotlib figure, defaults to `False`.
    :type return_buf: bool, optional
    :raises ValueError: Upon an invalid `y`, `x` or
        `preselected_index_ranges`.
    :return: The result of `fig_to_iobuf` for the figure (a
        (BytesIO, file format) tuple) if return_buf is `True` (`None`
        otherwise), a boolean mask array, and the list of selected
        index ranges.
    :rtype: Union[tuple[io.BytesIO, str], None], numpy.ndarray,
        list[list[int, int]]
    """
    # Third party modules
    # pylint: disable=possibly-used-before-assignment
    if interactive or filename is not None or return_buf:
        from matplotlib.patches import Patch
        from matplotlib.widgets import Button, SpanSelector

    def change_fig_title(title):
        """Replace the figure title text."""
        if fig_title:
            fig_title[0].remove()
            fig_title.pop()
        fig_title.append(plt.figtext(*title_pos, title, **title_props))

    def change_error_text(error):
        """Replace the status/error line shown below the title."""
        if error_texts:
            error_texts[0].remove()
            error_texts.pop()
        error_texts.append(plt.figtext(*error_pos, error, **error_props))

    def get_selected_index_ranges(change_fnc=None, title=''):
        """Return the sorted [lo, hi+1] index ranges of all spans,
        optionally reporting them through `change_fnc`."""
        selected_index_ranges = sorted(
            [[index_nearest(x, span.extents[0]),
              index_nearest(x, span.extents[1])+1]
             for span in spans])
        if change_fnc is not None:
            if len(selected_index_ranges) > 1:
                change_fnc(
                    f'{title}Selected ROIs: {selected_index_ranges}')
            elif selected_index_ranges:
                change_fnc(
                    f'{title}Selected ROI: {tuple(selected_index_ranges[0])}')
            else:
                change_fnc(f'{title}Selected ROI: None')
        return selected_index_ranges

    def add_span(event, xrange_init=None):
        """Callback function for the "Add span" button."""
        if (max_num_index_ranges is not None
                and len(spans) >= max_num_index_ranges):
            change_error_text(
                'Exceeding max number of ranges, adjust an existing '
                'range or click "Reset"/"Confirm"')
        else:
            spans.append(
                SpanSelector(
                    ax, select_span, 'horizontal', props=included_props,
                    useblit=True, interactive=interactive,
                    drag_from_anywhere=True, ignore_event_outside=True,
                    grab_range=5))
            if xrange_init is None:
                # Default new span: the first 5% of the x-range. The
                # offset from min(x) keeps it inside the axes whenever
                # x does not start at zero.
                xmin_init = min(x)
                xmax_init = xmin_init + 0.05*(max(x)-min(x))
            else:
                xmin_init, xmax_init = xrange_init
            spans[-1]._selection_completed = True
            spans[-1].extents = (xmin_init, xmax_init)
            spans[-1].onselect(xmin_init, xmax_init)
        plt.draw()

    def select_span(xmin, xmax):
        """Callback function for the SpanSelector widget."""
        # Repeatedly merge any pair of overlapping spans until none
        # overlap
        combined_spans = True
        while combined_spans:
            combined_spans = False
            for i, span1 in enumerate(spans):
                for span2 in spans[i+1:]:
                    if (span1.extents[1] >= span2.extents[0]
                            and span1.extents[0] <= span2.extents[1]):
                        change_error_text(
                            'Combined overlapping spans in currently '
                            'selected mask')
                        span2.extents = (
                            min(span1.extents[0], span2.extents[0]),
                            max(span1.extents[1], span2.extents[1]))
                        span1.set_visible(False)
                        spans.remove(span1)
                        combined_spans = True
                        break
                if combined_spans:
                    break
        get_selected_index_ranges(change_error_text)
        plt.draw()

    def reset(event):
        """Callback function for the "Reset" button."""
        if error_texts:
            error_texts[0].remove()
            error_texts.pop()
        for span in reversed(spans):
            span.set_visible(False)
            spans.remove(span)
        get_selected_index_ranges(change_error_text)
        plt.draw()

    def confirm(event):
        """Callback function for the "Confirm" button."""
        if (min_num_index_ranges is not None
                and len(spans) < min_num_index_ranges):
            change_error_text(
                f'Select at least {min_num_index_ranges} unique index ranges')
            plt.draw()
        else:
            if error_texts:
                error_texts[0].remove()
                error_texts.pop()
            get_selected_index_ranges(change_fig_title, title)
            plt.close()

    def update_mask(mask, selected_index_ranges):
        """Update the mask with the selected index ranges."""
        for min_, max_ in selected_index_ranges:
            mask = np.logical_or(
                mask,
                np.logical_and(x >= x[min_], x <= x[min(max_, num_data-1)]))
        return mask

    def update_index_ranges(mask):
        """Update the selected index ranges (where mask = True)."""
        # An int entry marks an open run; it becomes an inclusive
        # (start, end) tuple when the run ends.
        selected_index_ranges = []
        for i, m in enumerate(mask):
            if m:
                if (not selected_index_ranges
                        or isinstance(selected_index_ranges[-1], tuple)):
                    selected_index_ranges.append(i)
            else:
                if (selected_index_ranges
                        and isinstance(selected_index_ranges[-1], int)):
                    selected_index_ranges[-1] = \
                        (selected_index_ranges[-1], i-1)
        if (selected_index_ranges
                and isinstance(selected_index_ranges[-1], int)):
            selected_index_ranges[-1] = (selected_index_ranges[-1], num_data-1)
        return selected_index_ranges

    # Check inputs
    y = np.asarray(y)
    if y.ndim > 1:
        raise ValueError(f'Invalid y dimension ({y.ndim})')
    num_data = y.size
    if x is None:
        x = np.arange(num_data)+0.5
    else:
        x = np.asarray(x, dtype=np.float64)
        if x.ndim > 1 or x.size != num_data:
            raise ValueError(f'Invalid x shape ({x.shape})')
        if not np.all(x[:-1] < x[1:]):
            raise ValueError('Invalid x: must be monotonically increasing')
    if title is None:
        title = ''
    else:
        title = f'{title}: '
    if preselected_index_ranges is None:
        preselected_index_ranges = []
    else:
        if not isinstance(preselected_index_ranges, list):
            raise ValueError('Invalid parameter preselected_index_ranges '
                             f'({preselected_index_ranges})')
        # NOTE(review): the upper bound is reduced by one only when a
        # figure is produced, but used as-is in the purely
        # non-interactive path below; confirm the intended convention.
        if interactive or filename is not None or return_buf:
            index_ranges = []
            for v in preselected_index_ranges:
                if not is_num_pair(v):
                    raise ValueError(
                        'Invalid parameter preselected_index_ranges '
                        f'({preselected_index_ranges})')
                index_ranges.append(
                    (max(0, int(v[0])), min(num_data, int(v[1])-1)))
            preselected_index_ranges = index_ranges

    # Setup the preselected mask and index ranges if provided
    if preselected_mask is not None:
        preselected_index_ranges = update_index_ranges(
            update_mask(
                np.copy(np.asarray(preselected_mask, dtype=bool)),
                preselected_index_ranges))

    if not interactive and filename is None and not return_buf:

        # Update the mask with the preselected index ranges (always
        # return a boolean array, also when no range is selected)
        selected_mask = update_mask(
            np.zeros(num_data, dtype=bool), preselected_index_ranges)

        return None, selected_mask, preselected_index_ranges

    spans = []
    fig_title = []
    error_texts = []

    # Setup the Matplotlib figure
    title_pos = (0.5, 0.95)
    title_props = {'fontsize': 'xx-large', 'horizontalalignment': 'center',
                   'verticalalignment': 'bottom'}
    error_pos = (0.5, 0.90)
    error_props = {'fontsize': 'x-large', 'horizontalalignment': 'center',
                   'verticalalignment': 'bottom'}
    excluded_props = {
        'facecolor': 'white', 'edgecolor': 'gray', 'linestyle': ':'}
    included_props = {
        'alpha': 0.5, 'facecolor': 'tab:blue', 'edgecolor': 'blue'}

    fig, ax = plt.subplots(figsize=(11, 8.5))
    handles = ax.plot(x, y, color='k', label='Reference Data')
    handles.append(Patch(
        label='Excluded / unselected ranges', **excluded_props))
    handles.append(Patch(
        label='Included / selected ranges', **included_props))
    ax.legend(handles=handles)
    ax.set_xlabel(xlabel, fontsize='x-large')
    ax.set_ylabel(ylabel, fontsize='x-large')
    ax.set_xlim(x[0], x[-1])
    fig.subplots_adjust(bottom=0.0, top=0.85)

    # Add the preselected index ranges
    for min_, max_ in preselected_index_ranges:
        add_span(None, xrange_init=(x[min_], x[min(max_, num_data-1)]))

    if not interactive:

        get_selected_index_ranges(change_fig_title, title)
        if error_texts:
            error_texts[0].remove()
            error_texts.pop()

    else:

        change_fig_title(f'{title}Click and drag to select ranges')
        get_selected_index_ranges(change_error_text)
        fig.subplots_adjust(bottom=0.2)

        # Setup "Add span" button
        add_span_btn = Button(
            plt.axes([0.15, 0.05, 0.15, 0.075]), 'Add span')
        add_span_cid = add_span_btn.on_clicked(add_span)

        # Setup "Reset" button
        reset_btn = Button(plt.axes([0.45, 0.05, 0.15, 0.075]), 'Reset')
        reset_cid = reset_btn.on_clicked(reset)

        # Setup "Confirm" button
        confirm_btn = Button(plt.axes([0.75, 0.05, 0.15, 0.075]), 'Confirm')
        confirm_cid = confirm_btn.on_clicked(confirm)

        # Show figure for user interaction
        plt.show()

        # Disconnect all widget callbacks when figure is closed
        add_span_btn.disconnect(add_span_cid)
        reset_btn.disconnect(reset_cid)
        confirm_btn.disconnect(confirm_cid)

        # ...and remove the buttons before returning the figure
        add_span_btn.ax.remove()
        reset_btn.ax.remove()
        confirm_btn.ax.remove()
        plt.subplots_adjust(bottom=0.0)

    selected_index_ranges = get_selected_index_ranges()

    # Update the mask with the currently selected index ranges
    selected_mask = update_mask(
        np.zeros(num_data, dtype=bool), selected_index_ranges)

    buf = None
    if filename is not None or return_buf:
        if interactive:
            if len(selected_index_ranges) > 1:
                title += f'Selected ROIs: {selected_index_ranges}'
            elif selected_index_ranges:
                title += f'Selected ROI: {tuple(selected_index_ranges[0])}'
            else:
                # The window may be closed without selecting anything
                title += 'Selected ROI: None'
            fig_title[0].set_text(title)
        fig_title[0].set_in_layout(True)
        fig.tight_layout(rect=(0, 0, 1, 0.95))
        if filename is not None:
            fig.savefig(filename)
        if return_buf:
            buf = fig_to_iobuf(fig)
    # Close this figure explicitly; the current figure may be another one
    plt.close(fig)
    return buf, selected_mask, selected_index_ranges
1664
+
1665
+
1666
def select_roi_1d(
        y, x=None, preselected_roi=None, title=None, xlabel=None, ylabel=None,
        interactive=True, filename=None, return_buf=False):
    """Display a line plot of 1D data and have the user select a
    single region of interest.

    This is a thin wrapper around `select_mask_1d` restricted to
    exactly one index range.

    :param y: One-dimensional data array for which a region of
        interest will be selected.
    :type y: numpy.ndarray
    :param x: x-coordinates of the data
    :type x: numpy.ndarray, optional
    :param preselected_roi: Preselected region of interest. Both
        bounds must lie in `[0, y.size]`.
    :type preselected_roi: tuple(int, int), optional
    :param title: Title for the displayed figure.
    :type title: str, optional
    :param xlabel: Label for the x-axis of the displayed figure.
    :type xlabel: str, optional
    :param ylabel: Label for the y-axis of the displayed figure.
    :type ylabel: str, optional
    :param interactive: Show the plot and allow user interactions with
        the matplotlib figure, defaults to `True`.
    :type interactive: bool, optional
    :param filename: Save a .png of the plot to filename, defaults to
        `None`, in which case the plot is not saved.
    :type filename: str, optional
    :param return_buf: Return an in-memory object as a byte stream
        represention of the Matplotlib figure, defaults to `False`.
    :type return_buf: bool, optional
    :raises ValueError: Upon an invalid `y` or `preselected_roi`.
    :return: The figure buffer returned by `select_mask_1d` if
        return_buf is `True` (`None` otherwise), and the selected
        region of interest (`None` if no region was selected).
    :rtype: Union[tuple[io.BytesIO, str], None],
        Union[tuple(int, int), None]
    """
    # Check inputs
    y = np.asarray(y)
    if y.ndim != 1:
        raise ValueError(f'Invalid image dimension ({y.ndim})')
    if preselected_roi is not None:
        if not is_int_pair(preselected_roi, ge=0, le=y.size, log=False):
            raise ValueError('Invalid parameter preselected_roi '
                             f'({preselected_roi})')
        preselected_roi = [preselected_roi]

    buf, _, roi = select_mask_1d(
        y, x=x, preselected_index_ranges=preselected_roi, title=title,
        xlabel=xlabel, ylabel=ylabel, min_num_index_ranges=1,
        max_num_index_ranges=1, interactive=interactive, filename=filename,
        return_buf=return_buf)

    # Nothing is selected when the window is closed without confirming,
    # or when running non-interactively without a preselected ROI
    if not roi:
        return buf, None
    return buf, tuple(roi[0])
1716
+
1717
+ def select_roi_2d(
1718
+ a, preselected_roi=None, title=None, title_a=None,
1719
+ row_label='row index', column_label='column index', interactive=True,
1720
+ filename=None, return_buf=False):
1721
+ """Display a 2D image and have the user select a single rectangular
1722
+ region of interest.
1723
+
1724
+ :param a: Two-dimensional image data array for which a region of
1725
+ interest will be selected.
1726
+ :type a: numpy.ndarray
1727
+ :param preselected_roi: Preselected region of interest.
1728
+ :type preselected_roi: tuple(int, int, int, int), optional
1729
+ :param title: Title for the displayed figure.
1730
+ :type title: str, optional
1731
+ :param title_a: Title for the image of a.
1732
+ :type title_a: str, optional
1733
+ :param row_label: Label for the y-axis of the displayed figure,
1734
+ defaults to `row index`.
1735
+ :type row_label: str, optional
1736
+ :param column_label: Label for the x-axis of the displayed figure,
1737
+ defaults to `column index`.
1738
+ :type column_label: str, optional
1739
+ :param interactive: Show the plot and allow user interactions with
1740
+ the matplotlib figure, defaults to `True`.
1741
+ :type interactive: bool, optional
1742
+ :param filename: Save a .png of the plot to filename, defaults to
1743
+ `None`, in which case the plot is not saved.
1744
+ :type filename: str, optional
1745
+ :param return_buf: Return an in-memory object as a byte stream
1746
+ represention of the Matplotlib figure, defaults to `False`.
1747
+ :type return_buf: bool, optional
1748
+ :return: A byte stream represention of the Matplotlib figure if
1749
+ return_buf is `True` (`None` otherwise), and the selected
1750
+ region of interest.
1751
+ :rtype: Union[io.BytesIO, None], tuple(int, int, int, int)
1752
+ """
1753
+ # Third party modules
1754
+ # pylint: disable=possibly-used-before-assignment
1755
+ if interactive or filename is not None or return_buf:
1756
+ from matplotlib.widgets import Button, RectangleSelector
1757
+
1758
+ def change_fig_title(title):
1759
+ if fig_title:
1760
+ fig_title[0].remove()
1761
+ fig_title.pop()
1762
+ fig_title.append(plt.figtext(*title_pos, title, **title_props))
1763
+
1764
+ def change_subfig_title(error):
1765
+ if subfig_title:
1766
+ subfig_title[0].remove()
1767
+ subfig_title.pop()
1768
+ subfig_title.append(plt.figtext(*error_pos, error, **error_props))
1769
+
1770
+ def clear_selection():
1771
+ rects[0].set_visible(False)
1772
+ rects.pop()
1773
+ rects.append(
1774
+ RectangleSelector(
1775
+ ax, on_rect_select, props=rect_props,
1776
+ useblit=True, interactive=interactive, drag_from_anywhere=True,
1777
+ ignore_event_outside=False))
1778
+
1779
+ def on_rect_select(eclick, erelease):
1780
+ """Callback function for the RectangleSelector widget."""
1781
+ if (not int(rects[0].extents[1]) - int(rects[0].extents[0])
1782
+ or not int(rects[0].extents[3]) - int(rects[0].extents[2])):
1783
+ clear_selection()
1784
+ change_subfig_title(
1785
+ 'Selected ROI too small, try again')
1786
+ else:
1787
+ change_subfig_title(
1788
+ f'Selected ROI: {tuple(int(v) for v in rects[0].extents)}')
1789
+ plt.draw()
1790
+
1791
+ def reset(event):
1792
+ """Callback function for the "Reset" button."""
1793
+ if subfig_title:
1794
+ subfig_title[0].remove()
1795
+ subfig_title.pop()
1796
+ clear_selection()
1797
+ plt.draw()
1798
+
1799
+ def confirm(event):
1800
+ """Callback function for the "Confirm" button."""
1801
+ if subfig_title:
1802
+ subfig_title[0].remove()
1803
+ subfig_title.pop()
1804
+ roi = tuple(int(v) for v in rects[0].extents)
1805
+ if roi[1]-roi[0] < 1 or roi[3]-roi[2] < 1:
1806
+ roi = None
1807
+ change_fig_title(f'Selected ROI: {roi}')
1808
+ plt.close()
1809
+
1810
+ # Check inputs
1811
+ a = np.asarray(a)
1812
+ if a.ndim != 2:
1813
+ raise ValueError(f'Invalid image dimension ({a.ndim})')
1814
+ if preselected_roi is not None:
1815
+ if (not is_int_series(preselected_roi, ge=0, log=False)
1816
+ or len(preselected_roi) != 4):
1817
+ raise ValueError('Invalid parameter preselected_roi '
1818
+ f'({preselected_roi})')
1819
+ if title is None:
1820
+ title = 'Click and drag to select or adjust a region of interest (ROI)'
1821
+
1822
+ if not interactive and filename is None and not return_buf:
1823
+ return None, preselected_roi
1824
+
1825
+ fig_title = []
1826
+ subfig_title = []
1827
+
1828
+ title_pos = (0.5, 0.95)
1829
+ title_props = {'fontsize': 'xx-large', 'horizontalalignment': 'center',
1830
+ 'verticalalignment': 'bottom'}
1831
+ error_pos = (0.5, 0.90)
1832
+ error_props = {'fontsize': 'xx-large', 'horizontalalignment': 'center',
1833
+ 'verticalalignment': 'bottom'}
1834
+ rect_props = {
1835
+ 'alpha': 0.5, 'facecolor': 'tab:blue', 'edgecolor': 'blue'}
1836
+
1837
+ fig, ax = plt.subplots(figsize=(11, 8.5))
1838
+ ax.imshow(a)
1839
+ ax.set_title(title_a, fontsize='xx-large')
1840
+ ax.set_xlabel(column_label, fontsize='x-large')
1841
+ ax.set_ylabel(row_label, fontsize='x-large')
1842
+ ax.set_xlim(0, a.shape[1])
1843
+ ax.set_ylim(a.shape[0], 0)
1844
+ fig.subplots_adjust(bottom=0.0, top=0.85)
1845
+
1846
+ # Setup the preselected range of interest if provided
1847
+ rects = [RectangleSelector(
1848
+ ax, on_rect_select, props=rect_props, useblit=True,
1849
+ interactive=interactive, drag_from_anywhere=True,
1850
+ ignore_event_outside=True)]
1851
+ if preselected_roi is not None:
1852
+ rects[0].extents = preselected_roi
1853
+
1854
+ if not interactive:
1855
+
1856
+ if preselected_roi is not None:
1857
+ change_fig_title(
1858
+ f'Selected ROI: {tuple(int(v) for v in preselected_roi)}')
1859
+
1860
+ else:
1861
+
1862
+ change_fig_title(title)
1863
+ if preselected_roi is not None:
1864
+ change_subfig_title(
1865
+ f'Preselected ROI: {tuple(int(v) for v in preselected_roi)}')
1866
+ fig.subplots_adjust(bottom=0.2)
1867
+
1868
+ # Setup "Reset" button
1869
+ reset_btn = Button(plt.axes([0.125, 0.05, 0.15, 0.075]), 'Reset')
1870
+ reset_cid = reset_btn.on_clicked(reset)
1871
+
1872
+ # Setup "Confirm" button
1873
+ confirm_btn = Button(plt.axes([0.75, 0.05, 0.15, 0.075]), 'Confirm')
1874
+ confirm_cid = confirm_btn.on_clicked(confirm)
1875
+
1876
+ # Show figure for user interaction
1877
+ plt.show()
1878
+
1879
+ # Disconnect all widget callbacks when figure is closed
1880
+ reset_btn.disconnect(reset_cid)
1881
+ confirm_btn.disconnect(confirm_cid)
1882
+
1883
+ # ... and remove the buttons before returning the figure
1884
+ reset_btn.ax.remove()
1885
+ confirm_btn.ax.remove()
1886
+
1887
+ buf = None
1888
+ if filename is not None or return_buf:
1889
+ if fig_title:
1890
+ fig_title[0].set_in_layout(True)
1891
+ fig.tight_layout(rect=(0, 0, 1, 0.95))
1892
+ else:
1893
+ fig.tight_layout(rect=(0, 0, 1, 1))
1894
+
1895
+ # Remove the handles
1896
+ if interactive:
1897
+ rects[0]._center_handle.set_visible(False)
1898
+ rects[0]._corner_handles.set_visible(False)
1899
+ rects[0]._edge_handles.set_visible(False)
1900
+ if filename is not None:
1901
+ fig.savefig(filename)
1902
+ if return_buf:
1903
+ buf = fig_to_iobuf(fig)
1904
+ plt.close()
1905
+
1906
+ roi = tuple(int(v) for v in rects[0].extents)
1907
+ if roi[1]-roi[0] < 1 or roi[3]-roi[2] < 1:
1908
+ roi = None
1909
+
1910
+ return buf, roi
1911
+
1912
+
1913
+ def select_image_indices(
1914
+ a, axis, b=None, preselected_indices=None, axis_index_offset=0,
1915
+ min_range=None, min_num_indices=2, max_num_indices=2, title=None,
1916
+ title_a=None, title_b=None, row_label='row index',
1917
+ column_label='column index', interactive=True, return_buf=False):
1918
+ """Display a 2D image and have the user select a set of image
1919
+ indices in either row or column direction.
1920
+
1921
+ :param a: Two-dimensional image data array for which a region of
1922
+ interest will be selected.
1923
+ :type a: numpy.ndarray
1924
+ :param axis: The selection direction (0: row, 1: column)
1925
+ :type axis: int
1926
+ :param b: A secondary two-dimensional image data array for which
1927
+ a shared region of interest will be selected.
1928
+ :type b: numpy.ndarray, optional
1929
+ :param preselected_indices: Preselected image indices.
1930
+ :type preselected_indices: tuple(int), list(int), optional
1931
+ :param axis_index_offset: Offset in axis index range and
1932
+ preselected indices, defaults to `0`.
1933
+ :type axis_index_offset: int, optional
1934
+ :param min_range: The minimal range spanned by the selected
1935
+ indices.
1936
+ :type min_range: int, optional
1937
+ :param min_num_indices: The minimum number of selected indices.
1938
+ :type min_num_indices: int, optional
1939
+ :param max_num_indices: The maximum number of selected indices.
1940
+ :type max_num_indices: int, optional
1941
+ :param title: Title for the displayed figure.
1942
+ :type title: str, optional
1943
+ :param title_a: Title for the image of a.
1944
+ :type title_a: str, optional
1945
+ :param title_b: Title for the image of b.
1946
+ :type title_b: str, optional
1947
+ :param row_label: Label for the y-axis of the displayed figure,
1948
+ defaults to `row index`.
1949
+ :type row_label: str, optional
1950
+ :param column_label: Label for the x-axis of the displayed figure,
1951
+ defaults to `column index`.
1952
+ :type column_label: str, optional
1953
+ :param interactive: Show the plot and allow user interactions with
1954
+ the matplotlib figure, defaults to `True`.
1955
+ :type interactive: bool, optional
1956
+ :param return_buf: Return an in-memory object as a byte stream
1957
+ represention of the Matplotlib figure instead of the
1958
+ matplotlib figure, defaults to `False`.
1959
+ :type return_buf: bool, optional
1960
+ :return: The selected region of interest as array indices and a
1961
+ matplotlib figure.
1962
+ :rtype: Union[matplotlib.figure.Figure, io.BytesIO],
1963
+ tuple(int, int, int, int)
1964
+ """
1965
+ # Third party modules
1966
+ from matplotlib.widgets import TextBox, Button
1967
+
1968
+ index_input = None
1969
+
1970
+ def change_fig_title(title):
1971
+ if fig_title:
1972
+ fig_title[0].remove()
1973
+ fig_title.pop()
1974
+ fig_title.append(plt.figtext(*title_pos, title, **title_props))
1975
+
1976
+ def change_error_text(error):
1977
+ if error_texts:
1978
+ error_texts[0].remove()
1979
+ error_texts.pop()
1980
+ error_texts.append(plt.figtext(*error_pos, error, **error_props))
1981
+
1982
+ def get_selected_indices(change_fnc=None):
1983
+ selected_indices = tuple(sorted(indices))
1984
+ if change_fnc is not None:
1985
+ num_indices = len(indices)
1986
+ if len(selected_indices) > 1:
1987
+ text = f'Selected {row_column} indices: {selected_indices}'
1988
+ elif selected_indices:
1989
+ text = f'Selected {row_column} index: {selected_indices[0]}'
1990
+ else:
1991
+ text = f'Selected {row_column} indices: None'
1992
+ if min_num_indices is not None and num_indices < min_num_indices:
1993
+ if min_num_indices == max_num_indices:
1994
+ text += \
1995
+ f', select another {max_num_indices-num_indices}'
1996
+ else:
1997
+ text += \
1998
+ f', select at least {max_num_indices-num_indices} more'
1999
+ change_fnc(text)
2000
+ return selected_indices
2001
+
2002
+ def add_index(index):
2003
+ if index in indices:
2004
+ raise ValueError(f'Ignoring duplicate of selected {row_column}s')
2005
+ if max_num_indices is not None and len(indices) >= max_num_indices:
2006
+ raise ValueError(
2007
+ f'Exceeding maximum number of selected {row_column}s, click '
2008
+ 'either "Reset" or "Confirm"')
2009
+ if (indices and min_range is not None
2010
+ and abs(max(index, *indices) - min(index, *indices))
2011
+ < min_range):
2012
+ raise ValueError(
2013
+ f'Selected {row_column} range is smaller than required '
2014
+ 'minimal range of {min_range}: ignoring last selection')
2015
+ indices.append(index)
2016
+ if not axis:
2017
+ for ax in axs:
2018
+ lines.append(ax.axhline(indices[-1], c='r', lw=2))
2019
+ else:
2020
+ for ax in axs:
2021
+ lines.append(ax.axvline(indices[-1], c='r', lw=2))
2022
+
2023
+ def select_index(expression):
2024
+ """Callback function for the "Select row/column index" TextBox.
2025
+ """
2026
+ if not expression:
2027
+ return
2028
+ if error_texts:
2029
+ error_texts[0].remove()
2030
+ error_texts.pop()
2031
+ try:
2032
+ index = int(expression)
2033
+ if (index < axis_index_offset
2034
+ or index > axis_index_offset+a.shape[axis]):
2035
+ raise ValueError
2036
+ except ValueError:
2037
+ change_error_text(
2038
+ f'Invalid {row_column} index ({expression}), enter an integer '
2039
+ f'between {axis_index_offset} and '
2040
+ f'{axis_index_offset+a.shape[axis]-1}')
2041
+ else:
2042
+ try:
2043
+ add_index(index)
2044
+ get_selected_indices(change_error_text)
2045
+ except ValueError as exc:
2046
+ change_error_text(exc)
2047
+ index_input.set_val('')
2048
+ for ax in axs:
2049
+ ax.get_figure().canvas.draw()
2050
+
2051
+ def reset(event):
2052
+ """Callback function for the "Reset" button."""
2053
+ if error_texts:
2054
+ error_texts[0].remove()
2055
+ error_texts.pop()
2056
+ for line in reversed(lines):
2057
+ line.remove()
2058
+ indices.clear()
2059
+ lines.clear()
2060
+ get_selected_indices(change_error_text)
2061
+ for ax in axs:
2062
+ ax.get_figure().canvas.draw()
2063
+
2064
+ def confirm(event):
2065
+ """Callback function for the "Confirm" button."""
2066
+ if len(indices) < min_num_indices:
2067
+ change_error_text(
2068
+ f'Select at least {min_num_indices} unique {row_column}s')
2069
+ for ax in axs:
2070
+ ax.get_figure().canvas.draw()
2071
+ else:
2072
+ # Remove error texts and add selected indices if set
2073
+ if error_texts:
2074
+ error_texts[0].remove()
2075
+ error_texts.pop()
2076
+ get_selected_indices(change_fig_title)
2077
+ plt.close()
2078
+
2079
+ # Check inputs
2080
+ a = np.asarray(a)
2081
+ if a.ndim != 2:
2082
+ raise ValueError(f'Invalid image dimension ({a.ndim})')
2083
+ if axis < 0 or axis >= a.ndim:
2084
+ raise ValueError(f'Invalid parameter axis ({axis})')
2085
+ if not axis:
2086
+ row_column = 'row'
2087
+ else:
2088
+ row_column = 'column'
2089
+ if not is_int(axis_index_offset, ge=0, log=False):
2090
+ raise ValueError(
2091
+ 'Invalid parameter axis_index_offset ({axis_index_offset})')
2092
+ if preselected_indices is not None:
2093
+ if not is_int_series(
2094
+ preselected_indices, ge=axis_index_offset,
2095
+ le=axis_index_offset+a.shape[axis], log=False):
2096
+ if interactive:
2097
+ logger.warning(
2098
+ 'Invalid parameter preselected_indices '
2099
+ f'({preselected_indices}), ignoring preselected_indices')
2100
+ preselected_indices = None
2101
+ else:
2102
+ raise ValueError('Invalid parameter preselected_indices '
2103
+ f'({preselected_indices})')
2104
+ if min_range is not None and not 2 <= min_range <= a.shape[axis]:
2105
+ raise ValueError('Invalid parameter min_range ({min_range})')
2106
+ if title is None:
2107
+ title = f'Select or adjust image {row_column} indices'
2108
+ if b is not None:
2109
+ b = np.asarray(b)
2110
+ if b.ndim != 2:
2111
+ raise ValueError(f'Invalid image dimension ({b.ndim})')
2112
+ if a.shape[0] != b.shape[0]:
2113
+ raise ValueError(f'Inconsistent image shapes({a.shape} vs '
2114
+ f'{b.shape})')
2115
+
2116
+ indices = []
2117
+ lines = []
2118
+ fig_title = []
2119
+ error_texts = []
2120
+
2121
+ title_pos = (0.5, 0.95)
2122
+ title_props = {'fontsize': 'xx-large', 'horizontalalignment': 'center',
2123
+ 'verticalalignment': 'bottom'}
2124
+ error_pos = (0.5, 0.90)
2125
+ error_props = {'fontsize': 'x-large', 'horizontalalignment': 'center',
2126
+ 'verticalalignment': 'bottom'}
2127
+ if b is None:
2128
+ fig, axs = plt.subplots(figsize=(11, 8.5))
2129
+ axs = [axs]
2130
+ else:
2131
+ if a.shape[0]+b.shape[0] > max(a.shape[1], b.shape[1]):
2132
+ fig, axs = plt.subplots(1, 2, figsize=(11, 8.5))
2133
+ else:
2134
+ fig, axs = plt.subplots(2, 1, figsize=(11, 8.5))
2135
+ extent = (0, a.shape[1], axis_index_offset+a.shape[0], axis_index_offset)
2136
+ axs[0].imshow(a, extent=extent)
2137
+ axs[0].set_title(title_a, fontsize='xx-large')
2138
+ if b is not None:
2139
+ axs[1].imshow(b, extent=extent)
2140
+ axs[1].set_title(title_b, fontsize='xx-large')
2141
+ if a.shape[0]+b.shape[0] > max(a.shape[1], b.shape[1]):
2142
+ axs[0].set_xlabel(column_label, fontsize='x-large')
2143
+ axs[0].set_ylabel(row_label, fontsize='x-large')
2144
+ axs[1].set_xlabel(column_label, fontsize='x-large')
2145
+ else:
2146
+ axs[0].set_ylabel(row_label, fontsize='x-large')
2147
+ axs[1].set_xlabel(column_label, fontsize='x-large')
2148
+ axs[1].set_ylabel(row_label, fontsize='x-large')
2149
+ for ax in axs:
2150
+ ax.set_xlim(extent[0], extent[1])
2151
+ ax.set_ylim(extent[2], extent[3])
2152
+ fig.subplots_adjust(bottom=0.0, top=0.85)
2153
+
2154
+ # Setup the preselected indices if provided
2155
+ if preselected_indices is not None:
2156
+ preselected_indices = sorted(list(preselected_indices))
2157
+ for index in preselected_indices:
2158
+ add_index(index)
2159
+
2160
+ if not interactive:
2161
+
2162
+ get_selected_indices(change_fig_title)
2163
+
2164
+ else:
2165
+
2166
+ change_fig_title(title)
2167
+ get_selected_indices(change_error_text)
2168
+ fig.subplots_adjust(bottom=0.2)
2169
+
2170
+ # Setup TextBox
2171
+ index_input = TextBox(
2172
+ plt.axes([0.25, 0.05, 0.15, 0.075]), f'Select {row_column} index ')
2173
+ indices_cid = index_input.on_submit(select_index)
2174
+
2175
+ # Setup "Reset" button
2176
+ reset_btn = Button(plt.axes([0.5, 0.05, 0.15, 0.075]), 'Reset')
2177
+ reset_cid = reset_btn.on_clicked(reset)
2178
+
2179
+ # Setup "Confirm" button
2180
+ confirm_btn = Button(plt.axes([0.75, 0.05, 0.15, 0.075]), 'Confirm')
2181
+ confirm_cid = confirm_btn.on_clicked(confirm)
2182
+
2183
+ plt.show()
2184
+
2185
+ # Disconnect all widget callbacks when figure is closed
2186
+ index_input.disconnect(indices_cid)
2187
+ reset_btn.disconnect(reset_cid)
2188
+ confirm_btn.disconnect(confirm_cid)
2189
+
2190
+ # ... and remove the buttons before returning the figure
2191
+ index_input.ax.remove()
2192
+ reset_btn.ax.remove()
2193
+ confirm_btn.ax.remove()
2194
+
2195
+ fig_title[0].set_in_layout(True)
2196
+ fig.tight_layout(rect=(0, 0, 1, 0.95))
2197
+
2198
+ if return_buf:
2199
+ buf = fig_to_iobuf(fig)
2200
+ else:
2201
+ buf = None
2202
+ plt.close()
2203
+ if indices:
2204
+ return buf, tuple(sorted(indices))
2205
+ return buf, None
2206
+
2207
+
2208
+ def quick_imshow(
2209
+ a, title=None, row_label='row index', column_label='column index',
2210
+ path=None, name=None, show_fig=True, save_fig=False,
2211
+ return_fig=False, block=None, extent=None, show_grid=False,
2212
+ grid_color='w', grid_linewidth=1, **kwargs):
2213
+ """Display and or save a 2D image and or return an in-memory object
2214
+ as a byte stream represention.
2215
+ """
2216
+ if title is not None and not isinstance(title, str):
2217
+ raise ValueError(f'Invalid parameter title ({title})')
2218
+ if path is not None and not isinstance(path, str):
2219
+ raise ValueError(f'Invalid parameter path ({path})')
2220
+ if not isinstance(show_fig, bool):
2221
+ raise ValueError(f'Invalid parameter show_fig ({show_fig})')
2222
+ if not isinstance(save_fig, bool):
2223
+ raise ValueError(f'Invalid parameter save_fig ({save_fig})')
2224
+ if not isinstance(return_fig, bool):
2225
+ raise ValueError(f'Invalid parameter return_fig ({return_fig})')
2226
+ if block is not None and not isinstance(block, bool):
2227
+ raise ValueError(f'Invalid parameter block ({block})')
2228
+ if not title:
2229
+ title = 'quick imshow'
2230
+ if ('cmap' in kwargs and a.ndim == 3
2231
+ and (a.shape[2] == 3 or a.shape[2] == 4)):
2232
+ use_cmap = True
2233
+ if a.shape[2] == 4 and a[:,:,-1].min() != a[:,:,-1].max():
2234
+ use_cmap = False
2235
+ if any(
2236
+ a[i,j,0] != a[i,j,1] and a[i,j,0] != a[i,j,2]
2237
+ for i in range(a.shape[0])
2238
+ for j in range(a.shape[1])):
2239
+ use_cmap = False
2240
+ if use_cmap:
2241
+ a = a[:,:,0]
2242
+ else:
2243
+ logger.warning('Image incompatible with cmap option, ignore cmap')
2244
+ kwargs.pop('cmap')
2245
+ if extent is None:
2246
+ extent = (0, a.shape[1], a.shape[0], 0)
2247
+ plt.ioff()
2248
+ fig, ax = plt.subplots(figsize=(11, 8.5))
2249
+ ax.imshow(a, extent=extent, **kwargs)
2250
+ ax.set_title(title, fontsize='xx-large')
2251
+ ax.set_xlabel(column_label, fontsize='x-large')
2252
+ ax.set_ylabel(row_label, fontsize='x-large')
2253
+ if show_grid:
2254
+ ax.grid(color=grid_color, linewidth=grid_linewidth)
2255
+ if show_fig:
2256
+ plt.show(block=block)
2257
+ if save_fig:
2258
+ if name is None:
2259
+ title = re.sub(r'\s+', '_', title)
2260
+ if path is None:
2261
+ path = title
2262
+ else:
2263
+ path = f'{path}/{title}'
2264
+ else:
2265
+ if path is None:
2266
+ path = name
2267
+ else:
2268
+ path = f'{path}/{name}'
2269
+ if (os.path.splitext(path)[1]
2270
+ not in plt.gcf().canvas.get_supported_filetypes()):
2271
+ path += '.png'
2272
+ plt.savefig(path)
2273
+ if return_fig:
2274
+ buf = fig_to_iobuf(fig)
2275
+ else:
2276
+ buf = None
2277
+ plt.close()
2278
+ return buf
2279
+
2280
+
2281
+ def quick_plot(
2282
+ *args, xerr=None, yerr=None, vlines=None, title=None, xlim=None,
2283
+ ylim=None, xlabel=None, ylabel=None, legend=None, path=None, name=None,
2284
+ show_grid=False, save_fig=False, save_only=False, block=False,
2285
+ **kwargs):
2286
+ """Display a 2D line plot."""
2287
+ #RV FIX: Update with return_buf
2288
+ if title is not None and not isinstance(title, str):
2289
+ illegal_value(title, 'title', 'quick_plot')
2290
+ title = None
2291
+ if (xlim is not None and not isinstance(xlim, (tuple, list))
2292
+ and len(xlim) != 2):
2293
+ illegal_value(xlim, 'xlim', 'quick_plot')
2294
+ xlim = None
2295
+ if (ylim is not None and not isinstance(ylim, (tuple, list))
2296
+ and len(ylim) != 2):
2297
+ illegal_value(ylim, 'ylim', 'quick_plot')
2298
+ ylim = None
2299
+ if xlabel is not None and not isinstance(xlabel, str):
2300
+ illegal_value(xlabel, 'xlabel', 'quick_plot')
2301
+ xlabel = None
2302
+ if ylabel is not None and not isinstance(ylabel, str):
2303
+ illegal_value(ylabel, 'ylabel', 'quick_plot')
2304
+ ylabel = None
2305
+ if legend is not None and not isinstance(legend, (tuple, list)):
2306
+ illegal_value(legend, 'legend', 'quick_plot')
2307
+ legend = None
2308
+ if path is not None and not isinstance(path, str):
2309
+ illegal_value(path, 'path', 'quick_plot')
2310
+ return
2311
+ if not isinstance(show_grid, bool):
2312
+ illegal_value(show_grid, 'show_grid', 'quick_plot')
2313
+ return
2314
+ if not isinstance(save_fig, bool):
2315
+ illegal_value(save_fig, 'save_fig', 'quick_plot')
2316
+ return
2317
+ if not isinstance(save_only, bool):
2318
+ illegal_value(save_only, 'save_only', 'quick_plot')
2319
+ return
2320
+ if not isinstance(block, bool):
2321
+ illegal_value(block, 'block', 'quick_plot')
2322
+ return
2323
+ if title is None:
2324
+ title = 'quick plot'
2325
+ if name is None:
2326
+ ttitle = re.sub(r'\s+', '_', title)
2327
+ if path is None:
2328
+ path = f'{ttitle}.png'
2329
+ else:
2330
+ path = f'{path}/{ttitle}.png'
2331
+ else:
2332
+ if path is None:
2333
+ path = name
2334
+ else:
2335
+ path = f'{path}/{name}'
2336
+ args = unwrap_tuple(args)
2337
+ if depth_tuple(args) > 1 and (xerr is not None or yerr is not None):
2338
+ logger.warning('Error bars ignored for multiple curves')
2339
+ if not save_only:
2340
+ if block:
2341
+ plt.ioff()
2342
+ else:
2343
+ plt.ion()
2344
+ plt.figure(title)
2345
+ if depth_tuple(args) > 1:
2346
+ for y in args:
2347
+ plt.plot(*y, **kwargs)
2348
+ else:
2349
+ if xerr is None and yerr is None:
2350
+ plt.plot(*args, **kwargs)
2351
+ else:
2352
+ plt.errorbar(*args, xerr=xerr, yerr=yerr, **kwargs)
2353
+ if vlines is not None:
2354
+ if isinstance(vlines, (int, float)):
2355
+ vlines = [vlines]
2356
+ for v in vlines:
2357
+ plt.axvline(v, color='r', linestyle='--', **kwargs)
2358
+ if xlim is not None:
2359
+ plt.xlim(xlim)
2360
+ if ylim is not None:
2361
+ plt.ylim(ylim)
2362
+ if xlabel is not None:
2363
+ plt.xlabel(xlabel)
2364
+ if ylabel is not None:
2365
+ plt.ylabel(ylabel)
2366
+ if show_grid:
2367
+ ax = plt.gca()
2368
+ ax.grid(color='k') # , linewidth=1)
2369
+ if legend is not None:
2370
+ plt.legend(legend)
2371
+ if save_only:
2372
+ plt.savefig(path)
2373
+ plt.close(fig=title)
2374
+ else:
2375
+ if save_fig:
2376
+ plt.savefig(path)
2377
+ plt.show(block=block)
2378
+ plt.close()
2379
+
2380
+
2381
+ def nxcopy(
2382
+ nxobject, exclude_nxpaths=None, nxpath_prefix=None,
2383
+ nxpathabs_prefix=None, nxpath_copy_abspath=None):
2384
+ """Function that returns a copy of a nexus object, optionally
2385
+ exluding certain child items.
2386
+
2387
+ :param nxobject: The input nexus object to "copy".
2388
+ :type nxobject: nexusformat.nexus.NXobject
2389
+ :param exlude_nxpaths: A list of relative paths to child nexus
2390
+ objects that should be excluded from the returned "copy".
2391
+ :type exclude_nxpaths: str, list[str], optional
2392
+ :param nxpath_prefix: For use in recursive calls from inside this
2393
+ function only.
2394
+ :type nxpath_prefix: str
2395
+ :param nxpathabs_prefix: For use in recursive calls from inside
2396
+ this function only.
2397
+ :type nxpathabs_prefix: str
2398
+ :param nxpath_copy_abspath: For use in recursive calls from inside
2399
+ this function only.
2400
+ :type nxpath_copy_abspath: str
2401
+ :return: Copy of the input `nxobject` with some children optionally
2402
+ exluded.
2403
+ :rtype: nexusformat.nexus.NXobject
2404
+ """
2405
+ # Third party modules
2406
+ from nexusformat.nexus import (
2407
+ NXentry,
2408
+ NXfield,
2409
+ NXgroup,
2410
+ NXlink,
2411
+ NXlinkgroup,
2412
+ NXroot,
2413
+ )
2414
+
2415
+
2416
+ if isinstance(nxobject, NXlinkgroup):
2417
+ # The top level nxobject is a linked group
2418
+ # Create a group with the same name as the top level's target
2419
+ nxobject_copy = nxobject[nxobject.nxtarget].__class__(
2420
+ name=nxobject.nxname)
2421
+ elif isinstance(nxobject, (NXlink, NXfield)):
2422
+ # The top level nxobject is a (linked) field: return a copy
2423
+ attrs = nxobject.attrs
2424
+ attrs.pop('target', None)
2425
+ nxobject_copy = NXfield(
2426
+ value=nxobject.nxdata, name=nxobject.nxname,
2427
+ attrs=attrs)
2428
+ return nxobject_copy
2429
+ else:
2430
+ # Create a group with the same type/name as the nxobject
2431
+ nxobject_copy = nxobject.__class__(name=nxobject.nxname)
2432
+
2433
+ # Copy attributes
2434
+ if isinstance(nxobject, NXroot):
2435
+ if 'default' in nxobject.attrs:
2436
+ nxobject_copy.attrs['default'] = nxobject.default
2437
+ else:
2438
+ for k, v in nxobject.attrs.items():
2439
+ nxobject_copy.attrs[k] = v
2440
+
2441
+ # Setup paths
2442
+ if exclude_nxpaths is None:
2443
+ exclude_nxpaths = []
2444
+ elif isinstance(exclude_nxpaths, str):
2445
+ exclude_nxpaths = [exclude_nxpaths]
2446
+ for exclude_nxpath in exclude_nxpaths:
2447
+ if exclude_nxpath[0] == '/':
2448
+ raise ValueError(
2449
+ f'Invalid parameter in exclude_nxpaths ({exclude_nxpaths}), '
2450
+ 'excluded paths should be relative')
2451
+ if nxpath_prefix is None:
2452
+ nxpath_prefix = ''
2453
+ if nxpathabs_prefix is None:
2454
+ if isinstance(nxobject, NXentry):
2455
+ nxpathabs_prefix = nxobject.nxpath
2456
+ else:
2457
+ nxpathabs_prefix = nxobject.nxpath.removesuffix(nxobject.nxname)
2458
+ if nxpath_copy_abspath is None:
2459
+ nxpath_copy_abspath = ''
2460
+
2461
+ # Loop over all nxobject's children
2462
+ for k, v in nxobject.items():
2463
+ nxpath = os.path.join(nxpath_prefix, k)
2464
+ nxpathabs = os.path.join(nxpathabs_prefix, nxpath)
2465
+ if nxpath in exclude_nxpaths:
2466
+ if 'default' in nxobject_copy.attrs and nxobject_copy.default == k:
2467
+ nxobject_copy.attrs.pop('default')
2468
+ continue
2469
+ if isinstance(v, NXlinkgroup):
2470
+ if nxpathabs == v.nxpath and not any(
2471
+ v.nxtarget.startswith(os.path.join(nxpathabs_prefix, p))
2472
+ for p in exclude_nxpaths):
2473
+ nxobject_copy[k] = NXlink(v.nxtarget)
2474
+ else:
2475
+ nxobject_copy[k] = nxcopy(
2476
+ v, exclude_nxpaths=exclude_nxpaths,
2477
+ nxpath_prefix=nxpath, nxpathabs_prefix=nxpathabs_prefix,
2478
+ nxpath_copy_abspath=os.path.join(nxpath_copy_abspath, k))
2479
+ elif isinstance(v, NXlink):
2480
+ if nxpathabs == v.nxpath and not any(
2481
+ v.nxtarget.startswith(os.path.join(nxpathabs_prefix, p))
2482
+ for p in exclude_nxpaths):
2483
+ nxobject_copy[k] = v
2484
+ else:
2485
+ nxobject_copy[k] = v.nxdata
2486
+ for kk, vv in v.attrs.items():
2487
+ nxobject_copy[k].attrs[kk] = vv
2488
+ nxobject_copy[k].attrs.pop('target', None)
2489
+ elif isinstance(v, NXgroup):
2490
+ nxobject_copy[k] = nxcopy(
2491
+ v, exclude_nxpaths=exclude_nxpaths,
2492
+ nxpath_prefix=nxpath, nxpathabs_prefix=nxpathabs_prefix,
2493
+ nxpath_copy_abspath=os.path.join(nxpath_copy_abspath, k))
2494
+ else:
2495
+ nxobject_copy[k] = v.nxdata
2496
+ for kk, vv in v.attrs.items():
2497
+ nxobject_copy[k].attrs[kk] = vv
2498
+ if nxpathabs != os.path.join(nxpath_copy_abspath, k):
2499
+ nxobject_copy[k].attrs.pop('target', None)
2500
+
2501
+ return nxobject_copy
2502
+
2503
+
2504
+ def dictionary_update(target, source, merge_key_paths=None, sort=False):
2505
+ """Recursively updates a target dictionary with values from a source
2506
+ dictionary. Source values superseed target values for identical keys
2507
+ unless both values are lists of dictionaries in which case they are
2508
+ merged according to the merge_key_paths parameter.
2509
+
2510
+ :param target: Target dictionary.
2511
+ :type target: collections.abc.Mapping
2512
+ :param source: Source dictionary.
2513
+ :type target: collections.abc.Mapping
2514
+ :param merge_key_paths: List key paths to merge dictionary lists,
2515
+ only used if items in the target and source dictionary trees
2516
+ are lists of dictionaries.
2517
+ :type merge_key_paths: Union[str, list[str]]
2518
+ :param sort: Sort dictionary lists on the key.
2519
+ :type sort: bool, optional
2520
+ :return: The updated target directory.
2521
+ :rtype: collections.abc.Mapping
2522
+ """
2523
+ if not isinstance(target, dict):
2524
+ raise ValueError(
2525
+ 'Invalid parameter type "target" ({type(target)})')
2526
+ if not isinstance(source, dict):
2527
+ raise ValueError(
2528
+ 'Invalid parameter type "source" ({type(source)})')
2529
+ for k, v in source.items():
2530
+ if (isinstance(v, collections.abc.Mapping)
2531
+ and isinstance(target.get(k), collections.abc.Mapping)):
2532
+ if merge_key_paths is not None:
2533
+ raise NotImplementedError(
2534
+ f'"merge_key_paths" ({type(merge_key_paths)}) '
2535
+ 'for source and target dictionaries not yet implemented')
2536
+ # merge_key_path = None
2537
+ # if '/' in merge_key_paths:
2538
+ # print(f'"/" in merge_key_path')
2539
+ # merge_key_path = merge_key_paths.split('/', 1)[1:]
2540
+ # elif is_str_series(merge_key_paths):
2541
+ # print(f'merge_key_path is string series')
2542
+ # merge_key_path = [
2543
+ # vv[1] for vv in [
2544
+ # v for v in [merge_key_paths.split('/', 1)
2545
+ # for s in sss]]
2546
+ # if (vv[0]==k and len(vv)>1)]
2547
+ # print(f'---> merge_key_path: {merge_key_path}')
2548
+ target[k] = dictionary_update(target.get(k, {}), v)
2549
+ elif (is_dict_series(v, log=False)
2550
+ and is_dict_series(target.get(k), log=False)):
2551
+ if isinstance(merge_key_paths, str):
2552
+ merge_key_path = merge_key_paths
2553
+ merge_key_type = None
2554
+ elif isinstance(merge_key_paths, dict):
2555
+ merge_key_path = merge_key_paths.get('key_path')
2556
+ merge_key_type = merge_key_paths.get('type')
2557
+ elif merge_key_path is not None:
2558
+ raise NotImplementedError(
2559
+ 'Invalid/unimplemeted parameter type "merge_key_path" '
2560
+ f'({type(merge_key_pathsource)}) for source and target '
2561
+ 'lists of dictionaries')
2562
+ merge_key = l[1] if len(
2563
+ l:=merge_key_path.split('/')) == 2 else None
2564
+ # if '/' in merge_key_paths:
2565
+ # merge_key_paths = [merge_key_paths]
2566
+ # if is_str_series(merge_key_paths):
2567
+ # paths = paths if len(
2568
+ # paths:=[l[1] for path in merge_key_paths
2569
+ # if (l:=path.split('/', 1))[0] == k and len(l)>1]
2570
+ # ) else [None]
2571
+ # if len(paths) > 1:
2572
+ # raise ValueError(
2573
+ # 'Ambiguous parameter merge_key_paths '
2574
+ # f'({merge_key_paths}) while trying to merge '
2575
+ # f'{source} with {target}')
2576
+ # merge_path = paths[0]
2577
+ # else:
2578
+ # merge_path = None
2579
+ target[k] = list_dictionary_update(
2580
+ target.get(k), v, key=merge_key, key_type=merge_key_type,
2581
+ sort=sort)
2582
+ else:
2583
+ target[k] = v
2584
+ return target
2585
+
2586
+
2587
def list_dictionary_update(
        target, source, key=None, key_type=None, sort=False):
    """Recursively updates a target list of dictionaries with values
    from a source list of dictionaries.

    When `key` is given and is shared by the items of both lists, the
    lists are merged item by item:

    - Each target item whose `key` value equals a source item's `key`
      value is merged with that source item via `dictionary_update`.
    - Unmatched target items are kept as they are.
    - Unmatched source items are appended at the end.

    Otherwise (no `key`, or `key` present in none of the items) the
    source list concatenated with the target list is returned.

    The caller's `source` list itself is not modified.

    :param target: Target list.
    :type target: list
    :param source: Source list.
    :type source: list
    :param key: Selected key to merge the lists of dictionaries on.
        Must be a string without a "/".
    :type key: str, optional
    :param key_type: Type applied to the key values before they are
        compared and sorted on.
    :type key_type: type, optional
    :param sort: Sort the returned merged list on the key.
    :type sort: bool, optional
    :raises ValueError: For invalid parameter types or values, or when
        the key is only partially shared among the list items (as
        reported by `all_any`).
    :return: The updated list.
    :rtype: list
    """
    if not isinstance(target, list):
        raise ValueError(
            f'Invalid parameter type "target" ({type(target)})')
    if not isinstance(source, list):
        raise ValueError(
            f'Invalid parameter type "source" ({type(source)})')
    if key is None:
        return source + target
    if not isinstance(key, str) or '/' in key:
        raise ValueError(f'Invalid parameter "key" ({key}, {type(key)})')
    if not (key_type is None or isinstance(key_type, type)):
        raise ValueError(
            f'Invalid parameter "key_type" ({key_type}, {type(key_type)})')

    # A negative all_any result signals a key that only some items have.
    all_any_source = all_any(source, key)
    if all_any_source < 0:
        raise ValueError(
            f'Partially shared key ({key}) while trying to merge {source} '
            f'with {target}')
    all_any_target = all_any(target, key)
    if all_any_target < 0 or all_any_source != all_any_target:
        raise ValueError(
            f'Partially shared key ({key}) while trying to merge {source} '
            f'with {target}')
    if not all_any_source and not all_any_target:
        return source + target

    def _key_value(item):
        # Key value of a list item, converted with key_type if given,
        # so that comparison and sorting use the same representation.
        value = item[key]
        return value if key_type is None else key_type(value)

    # Work on a copy so matched items can be removed without mutating
    # the caller's source list.
    remaining = list(source)
    merged = []
    for target_dict in target:
        value = _key_value(target_dict)
        for i, source_dict in enumerate(remaining):
            if _key_value(source_dict) == value:
                merged.append(dictionary_update(
                    target_dict, source_dict, sort=sort))
                # Each source item merges into at most one target item.
                remaining.pop(i)
                break
        else:
            merged.append(target_dict)
    merged.extend(remaining)
    if sort:
        merged.sort(key=_key_value)
    return merged