ChessAnalysisPipeline 0.0.3__py3-none-any.whl → 0.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ChessAnalysisPipeline might be problematic. Click here for more details.

@@ -0,0 +1,1225 @@
1
+ #!/usr/bin/env python3
2
+
3
+ #FIX write a function that returns a list of peak indices for a given plot
4
+ #FIX use raise_error concept on more functions to optionally raise an error
5
+
6
+ # -*- coding: utf-8 -*-
7
+ """
8
+ Created on Mon Dec 6 15:36:22 2021
9
+
10
+ @author: rv43
11
+ """
12
+
13
+ from logging import getLogger
14
+ logger = getLogger(__name__)
15
+
16
+ from ast import literal_eval
17
+ from re import compile as re_compile
18
+ from re import split as re_split
19
+ from re import sub as re_sub
20
+ from sys import float_info
21
+
22
+ import numpy as np
23
+ try:
24
+ import matplotlib.pyplot as plt
25
+ from matplotlib.widgets import Button
26
+ except:
27
+ pass
28
+
29
def depth_list(L):
    """Return the nesting depth of list L, or False when L is not a list.

    A flat list (including an empty one) has depth 1, [[1]] has depth 2, etc.
    """
    # default=0 keeps max() from raising ValueError on an empty list
    return isinstance(L, list) and max(map(depth_list, L), default=0)+1
30
def depth_tuple(T):
    """Return the nesting depth of tuple T, or False when T is not a tuple.

    A flat tuple (including an empty one) has depth 1, ((1,),) has depth 2, etc.
    """
    # default=0 keeps max() from raising ValueError on an empty tuple
    return isinstance(T, tuple) and max(map(depth_tuple, T), default=0)+1
31
def unwrap_tuple(T):
    """Strip redundant single-element wrappers from a nested tuple.

    While T is a one-element tuple nested more than one level deep, replace it
    by its only element, e.g. (((1, 2),),) -> (1, 2). Anything else is
    returned unchanged.
    """
    while depth_tuple(T) > 1 and len(T) == 1:
        (T,) = T
    return T
35
+
36
def illegal_value(value, name=None, location=None, raise_error=False, log=True):
    """Report an illegal value by logging an error and/or raising a ValueError.

    value: the offending value (its value and type are included in the message)
    name: name of the offending variable; omitted from the message when not a
        string (optional so callers may report just a value and location)
    location: name of the reporting function; omitted when not a string
    raise_error: raise a ValueError carrying the message if True
    log: log the message at error level if True
    """
    location = f'in {location} ' if isinstance(location, str) else ''
    if isinstance(name, str):
        error_msg = f'Illegal value for {name} {location}({value}, {type(value)})'
    else:
        error_msg = f'Illegal value {location}({value}, {type(value)})'
    if log:
        logger.error(error_msg)
    if raise_error:
        raise ValueError(error_msg)
49
+
50
def illegal_combination(value1, name1, value2, name2, location=None, raise_error=False,
        log=True):
    """Report a mutually incompatible pair of values by logging and/or raising.

    value1, value2: the conflicting values (values and types go in the message)
    name1, name2: their variable names; both are omitted when name1 is not a string
    location: name of the reporting function; omitted when not a string
    raise_error: raise a ValueError carrying the message if True
    log: log the message at error level if True
    """
    where = f'in {location} ' if isinstance(location, str) else ''
    values = f'({value1}, {type(value1)} and {value2}, {type(value2)})'
    if isinstance(name1, str):
        error_msg = f'Illegal combination for {name1} and {name2} {where}{values}'
    else:
        error_msg = f'Illegal combination {where}{values}'
    if log:
        logger.error(error_msg)
    if raise_error:
        raise ValueError(error_msg)
66
+
67
def test_ge_gt_le_lt(ge, gt, le, lt, func, location=None, raise_error=False, log=True):
    """Check individual and mutual validity of the ge, gt, le, lt qualifiers.

    ge, gt, le, lt: optional bounds (>=, >, <=, <); at most one of ge/gt and at
        most one of le/lt may be given, and the lower bound must lie below the
        upper one
    func: is_int or is_num, used to test each given bound for int or number type
    location: name of the calling function, used in error messages
    raise_error: raise a ValueError on an invalid qualifier if True
    log: log an error on an invalid qualifier if True
    Return: True upon success or False when invalid or mutually exclusive
    """
    if ge is None and gt is None and le is None and lt is None:
        return True
    # The type probes run with log=False: any problem is reported exactly once
    # below through illegal_value/illegal_combination, which honor log
    if ge is not None:
        if not func(ge, log=False):
            illegal_value(ge, 'ge', location, raise_error, log)
            return False
        if gt is not None:
            illegal_combination(ge, 'ge', gt, 'gt', location, raise_error, log)
            return False
    elif gt is not None and not func(gt, log=False):
        illegal_value(gt, 'gt', location, raise_error, log)
        return False
    if le is not None:
        if not func(le, log=False):
            illegal_value(le, 'le', location, raise_error, log)
            return False
        if lt is not None:
            illegal_combination(le, 'le', lt, 'lt', location, raise_error, log)
            return False
    elif lt is not None and not func(lt, log=False):
        illegal_value(lt, 'lt', location, raise_error, log)
        return False
    # Mutual consistency of the lower and upper bound
    if ge is not None:
        if le is not None and ge > le:
            illegal_combination(ge, 'ge', le, 'le', location, raise_error, log)
            return False
        elif lt is not None and ge >= lt:
            illegal_combination(ge, 'ge', lt, 'lt', location, raise_error, log)
            return False
    elif gt is not None:
        if le is not None and gt >= le:
            illegal_combination(gt, 'gt', le, 'le', location, raise_error, log)
            return False
        elif lt is not None and gt >= lt:
            illegal_combination(gt, 'gt', lt, 'lt', location, raise_error, log)
            return False
    return True
109
+
110
def range_string_ge_gt_le_lt(ge=None, gt=None, le=None, lt=None):
    """Return a human-readable range for the ge, gt, le, lt qualifiers.

    A two-sided range uses interval notation, e.g. '[1, 5)'. A one-sided range
    uses a comparison, e.g. '>= 1'. No qualifiers gives ''.
    The inputs are not validated; do that as needed before calling.
    """
    has_lower = ge is not None or gt is not None
    has_upper = le is not None or lt is not None
    if ge is not None:
        lower = f'[{ge}, ' if has_upper else f'>= {ge}'
    elif gt is not None:
        lower = f'({gt}, ' if has_upper else f'> {gt}'
    else:
        lower = ''
    if le is not None:
        upper = f'{le}]' if has_lower else f'<= {le}'
    elif lt is not None:
        upper = f'{lt})' if has_lower else f'< {lt}'
    else:
        upper = ''
    return lower + upper
136
+
137
def is_int(v, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True):
    """Check that v is an int within the range qualifiers.

    The range is ge <= v <= le, gt < v < lt, or some combination of these.
    raise_error/log control how a failure is reported (see _is_int_or_num).
    Return: True if yes or False if no
    """
    return _is_int_or_num(v, 'int', ge=ge, gt=gt, le=le, lt=lt, raise_error=raise_error,
        log=log)
142
+
143
def is_num(v, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True):
    """Check that v is a number (int or float) within the range qualifiers.

    The range is ge <= v <= le, gt < v < lt, or some combination of these.
    raise_error/log control how a failure is reported (see _is_int_or_num).
    Return: True if yes or False if no
    """
    return _is_int_or_num(v, 'num', ge=ge, gt=gt, le=le, lt=lt, raise_error=raise_error,
        log=log)
148
+
149
def _is_int_or_num(v, type_str, ge=None, gt=None, le=None, lt=None, raise_error=False,
        log=True):
    """Shared implementation of is_int and is_num.

    type_str: 'int' (v must be an int) or 'num' (v must be an int or float)
    ge, gt, le, lt: optional range qualifiers, validated with test_ge_gt_le_lt
    raise_error: raise a ValueError on failure if True
    log: log an error on failure if True
    Return: True if v has the requested type and satisfies every given bound,
    otherwise False
    """
    if type_str == 'int':
        allowed_types, checker = int, is_int
    elif type_str == 'num':
        allowed_types, checker = (int, float), is_num
    else:
        illegal_value(type_str, 'type_str', '_is_int_or_num', raise_error, log)
        return False
    if not isinstance(v, allowed_types):
        illegal_value(v, 'v', '_is_int_or_num', raise_error, log)
        return False
    if not test_ge_gt_le_lt(ge, gt, le, lt, checker, '_is_int_or_num', raise_error, log):
        return False
    # Only the first violated bound is reported, checked in order ge, gt, le, lt
    if ge is not None and v < ge:
        error_msg = f'Value {v} out of range: {v} !>= {ge}'
    elif gt is not None and v <= gt:
        error_msg = f'Value {v} out of range: {v} !> {gt}'
    elif le is not None and v > le:
        error_msg = f'Value {v} out of range: {v} !<= {le}'
    elif lt is not None and v >= lt:
        error_msg = f'Value {v} out of range: {v} !< {lt}'
    else:
        return True
    if log:
        logger.error(error_msg)
    if raise_error:
        raise ValueError(error_msg)
    return False
188
+
189
def is_int_pair(v, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True):
    """Check that v is a pair of ints, each within the given bounds.

    Each bound is either a scalar applied to both elements (ge <= v[i] <= le,
    gt < v[i] < lt, ...) or a pair of per-element bounds
    (ge[i] <= v[i] <= le[i], gt[i] < v[i] < lt[i], ...), or some combination.
    Return: True if yes or False if no
    """
    return _is_int_or_num_pair(v, 'int', ge=ge, gt=gt, le=le, lt=lt,
        raise_error=raise_error, log=log)
195
+
196
def is_num_pair(v, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True):
    """Check that v is a pair of numbers, each within the given bounds.

    Each bound is either a scalar applied to both elements (ge <= v[i] <= le,
    gt < v[i] < lt, ...) or a pair of per-element bounds
    (ge[i] <= v[i] <= le[i], gt[i] < v[i] < lt[i], ...), or some combination.
    Return: True if yes or False if no
    """
    return _is_int_or_num_pair(v, 'num', ge=ge, gt=gt, le=le, lt=lt,
        raise_error=raise_error, log=log)
202
+
203
def _is_int_or_num_pair(v, type_str, ge=None, gt=None, le=None, lt=None, raise_error=False,
        log=True):
    """Shared implementation of is_int_pair and is_num_pair.

    type_str: 'int' or 'num', the required type of both elements of v
    ge, gt, le, lt: each None, a scalar applied to both elements, or a pair of
        per-element bounds
    raise_error: raise a ValueError on failure if True
    log: log an error on failure if True
    Return: True if v is a valid pair within the bounds, otherwise False
    """
    if type_str == 'int':
        func, elem_types = is_int, int
    elif type_str == 'num':
        func, elem_types = is_num, (int, float)
    else:
        illegal_value(type_str, 'type_str', '_is_int_or_num_pair', raise_error, log)
        return False
    if not (isinstance(v, (tuple, list)) and len(v) == 2 and isinstance(v[0], elem_types)
            and isinstance(v[1], elem_types)):
        illegal_value(v, 'v', '_is_int_or_num_pair', raise_error, log)
        return False
    if ge is None and gt is None and le is None and lt is None:
        return True
    bounds = []
    for bound in (ge, gt, le, lt):
        # A scalar bound applies to both elements. The scalar probe is silent
        # because a pair is equally acceptable and is validated next
        if bound is None or func(bound, log=False):
            bounds.append(2*[bound])
        elif _is_int_or_num_pair(bound, type_str, raise_error=raise_error, log=log):
            bounds.append(bound)
        else:
            return False
    ge, gt, le, lt = bounds
    return all(func(v[i], ge[i], gt[i], le[i], lt[i], raise_error, log) for i in (0, 1))
242
+
243
def is_int_series(l, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True):
    """Check that l is a tuple or list of ints, each within the bounds.

    Every element must satisfy ge <= l[i] <= le, gt < l[i] < lt, or some
    combination of these.
    Return: True if yes or False if no
    """
    if not test_ge_gt_le_lt(ge, gt, le, lt, is_int, 'is_int_series', raise_error, log):
        return False
    if not isinstance(l, (tuple, list)):
        illegal_value(l, 'l', 'is_int_series', raise_error, log)
        return False
    return all(is_int(v, ge, gt, le, lt, raise_error, log) for v in l)
255
+
256
def is_num_series(l, ge=None, gt=None, le=None, lt=None, raise_error=False, log=True):
    """Check that l is a tuple or list of numbers, each within the bounds.

    Every element must satisfy ge <= l[i] <= le, gt < l[i] < lt, or some
    combination of these. The bounds themselves may be floats.
    Return: True if yes or False if no
    """
    # Bounds are validated as numbers (not ints) so float bounds are accepted
    if not test_ge_gt_le_lt(ge, gt, le, lt, is_num, 'is_num_series', raise_error, log):
        return False
    if not isinstance(l, (tuple, list)):
        illegal_value(l, 'l', 'is_num_series', raise_error, log)
        return False
    return all(is_num(v, ge, gt, le, lt, raise_error, log) for v in l)
268
+
269
def is_str_series(l, raise_error=False, log=True):
    """Check that l is a tuple or list whose elements are all strings.

    Return: True if yes or False if no (reported via illegal_value)
    """
    if isinstance(l, (tuple, list)) and all(isinstance(s, str) for s in l):
        return True
    illegal_value(l, 'l', 'is_str_series', raise_error, log)
    return False
277
+
278
def is_dict_series(l, raise_error=False, log=True):
    """Check that l is a tuple or list whose elements are all dictionaries.

    Return: True if yes or False if no (reported via illegal_value)
    """
    if isinstance(l, (tuple, list)) and all(isinstance(d, dict) for d in l):
        return True
    illegal_value(l, 'l', 'is_dict_series', raise_error, log)
    return False
286
+
287
def is_dict_nums(l, raise_error=False, log=True):
    """Check that l is a dictionary whose values are all single numbers.

    Return: True if yes or False if no (reported via illegal_value)
    """
    # Values are probed silently; a failure is reported once for the whole dict
    if isinstance(l, dict) and all(is_num(v, log=False) for v in l.values()):
        return True
    illegal_value(l, 'l', 'is_dict_nums', raise_error, log)
    return False
295
+
296
def is_dict_strings(l, raise_error=False, log=True):
    """Check that l is a dictionary whose values are all single strings.

    Return: True if yes or False if no (reported via illegal_value)
    """
    if isinstance(l, dict) and all(isinstance(v, str) for v in l.values()):
        return True
    illegal_value(l, 'l', 'is_dict_strings', raise_error, log)
    return False
304
+
305
def is_index(v, ge=0, lt=None, raise_error=False, log=True):
    """Check that v is an int array index in range ge <= v < lt.

    NOTE lt IS NOT included! Either bound may be None to leave that side open.
    Return: True if yes or False if no
    """
    # Only compare numeric bounds: ge=None previously raised TypeError here.
    # Non-numeric bounds are rejected by is_int via test_ge_gt_le_lt
    if isinstance(lt, int) and isinstance(ge, (int, float)) and lt <= ge:
        illegal_combination(ge, 'ge', lt, 'lt', 'is_index', raise_error, log)
        return False
    return is_int(v, ge=ge, lt=lt, raise_error=raise_error, log=log)
314
+
315
def is_index_range(v, ge=0, le=None, lt=None, raise_error=False, log=True):
    """Check that v is an int index pair (low, upp) forming a valid range.

    The range must satisfy ge <= low <= upp <= le or ge <= low <= upp < lt.
    NOTE le IS included, lt is NOT! Any bound may be None to leave that side open.
    Return: True if yes or False if no
    """
    if not is_int_pair(v, raise_error=raise_error, log=log):
        return False
    if not test_ge_gt_le_lt(ge, None, le, lt, is_int, 'is_index_range', raise_error, log):
        return False
    low, upp = v
    out_of_range = (low > upp or (ge is not None and low < ge)
        or (le is not None and upp > le) or (lt is not None and upp >= lt))
    if not out_of_range:
        return True
    # Show only the bounds that were actually given
    lower = '' if ge is None else f'{ge} <= '
    if le is not None:
        upper = f' <= {le}'
    elif lt is not None:
        upper = f' < {lt}'
    else:
        upper = ''
    error_msg = f'Value {v} out of range: !({lower}{low} <= {upp}{upper})'
    if log:
        logger.error(error_msg)
    if raise_error:
        raise ValueError(error_msg)
    return False
334
+
335
def index_nearest(a, value):
    """Return the index of the element of array-like a closest to value.

    value is first scaled by (1 + machine epsilon). The intent (per the
    original comment) is to round .5 ties up; for negative values the nudge
    is toward more negative.
    Raises ValueError when a has more than one dimension.
    """
    arr = np.asarray(a)
    if arr.ndim > 1:
        raise ValueError(f'Invalid array dimension for parameter a ({arr.ndim}, {arr})')
    # Round up for .5
    value *= 1.0+float_info.epsilon
    return int(np.abs(arr-value).argmin())
342
+
343
def index_nearest_low(a, value):
    """Return the index of the nearest element of array-like a, stepping down
    by one when that element is greater than value (but never below index 0).

    For an ascending a this selects an element at or below value whenever one
    exists. Raises ValueError when a has more than one dimension.
    """
    arr = np.asarray(a)
    if arr.ndim > 1:
        raise ValueError(f'Invalid array dimension for parameter a ({arr.ndim}, {arr})')
    index = int(np.abs(arr-value).argmin())
    return index-1 if value < arr[index] and index > 0 else index
351
+
352
def index_nearest_upp(a, value):
    """Return the index of the nearest element of array-like a, stepping up by
    one when that element is less than value (but never past the last index).

    For an ascending a this selects an element at or above value whenever one
    exists. Raises ValueError when a has more than one dimension.
    """
    arr = np.asarray(a)
    if arr.ndim > 1:
        raise ValueError(f'Invalid array dimension for parameter a ({arr.ndim}, {arr})')
    index = int(np.abs(arr-value).argmin())
    return index+1 if value > arr[index] and index < arr.size-1 else index
360
+
361
def round_to_n(x, n=1):
    """Round x to n significant figures, returning the same type as x.

    e.g. round_to_n(1234.5, 2) -> 1200.0, round_to_n(0.012345, 3) -> 0.0123
    Zero is returned as a zero of the same type as x (0.0 for a float).
    """
    if x == 0:
        return type(x)(0)
    # Decimals to keep so that n digits remain, counted from the leading digit
    ndigits = n-1-int(np.floor(np.log10(abs(x))))
    return type(x)(round(x, ndigits))
366
+
367
def round_up_to_n(x, n=1):
    """Round x away from zero to n significant figures, keeping the type of x.

    Zero is returned as a zero of the same type as x.
    """
    if x == 0:
        # round_to_n yields 0 here, so x/xr below would raise ZeroDivisionError
        return type(x)(0)
    xr = round_to_n(x, n)
    if abs(x/xr) > 1.0:
        # Ordinary rounding went toward zero: bump the last kept digit outward
        xr += np.sign(x)*10**(np.floor(np.log10(abs(x)))+1-n)
    return type(x)(xr)
372
+
373
def trunc_to_n(x, n=1):
    """Truncate x toward zero to n significant figures, keeping the type of x.

    Zero is returned as a zero of the same type as x.
    """
    if x == 0:
        # round_to_n yields 0 here, so xr/x below would raise ZeroDivisionError
        return type(x)(0)
    xr = round_to_n(x, n)
    if abs(xr/x) > 1.0:
        # Ordinary rounding went away from zero: drop the last kept digit back
        xr -= np.sign(x)*10**(np.floor(np.log10(abs(x)))+1-n)
    return type(x)(xr)
378
+
379
def almost_equal(a, b, sig_figs):
    """Compare two numbers for approximate equality.

    Returns True when a-b, rounded to sig_figs significant figures, is smaller
    in magnitude than 10**(1-sig_figs).
    Raises ValueError when a or b is not an int or float.
    """
    # Probe silently: the ValueError below already describes the problem
    if not (is_num(a, log=False) and is_num(b, log=False)):
        raise ValueError(f'Invalid value for a or b in almost_equal (a: {a}, {type(a)}, '+
            f'b: {b}, {type(b)})')
    return abs(round_to_n(a-b, sig_figs)) < pow(10, -sig_figs+1)
386
+
387
def string_to_list(s, split_on_dash=True, remove_duplicates=True, sort=True):
    """Return a list of numbers by splitting/expanding a string.

    The string is split on any combination of commas and whitespace. When
    split_on_dash is True, a dash-joined pair of ints is expanded to the
    inclusive range between them.
    e.g: '1, 3, 5-8, 12 ' -> [1, 3, 5, 6, 7, 8, 12]

    remove_duplicates: drop repeated values, keeping first occurrences
    sort: sort the result in ascending order
    Return: the list ([] for an empty string), or None when s is not a string
    or cannot be parsed
    """
    if not isinstance(s, str):
        illegal_value(s, 's', 'string_to_list')
        return None
    if not len(s):
        return []
    tokens = re_split(r'\s+,\s+|\s+,|,\s+|\s+|,', s.strip())
    try:
        if split_on_dash:
            values = []
            for token in tokens:
                bounds = [literal_eval(x) for x in re_split(r'\s+-\s+|\s+-|-\s+|\s+|-', token)]
                if len(bounds) == 1:
                    values += bounds
                elif len(bounds) == 2 and bounds[1] > bounds[0]:
                    values += list(range(bounds[0], bounds[1]+1))
                else:
                    raise ValueError
        else:
            values = [literal_eval(x) for x in tokens]
        if remove_duplicates:
            values = list(dict.fromkeys(values))
        if sort:
            values = sorted(values)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        # Unparsable tokens, bad ranges, unhashable or unorderable values
        return None
    return values
421
+
422
def get_trailing_int(string):
    """Return the integer formed by the digits at the end of string.

    Return: None when string does not end in a digit
    """
    match = re_compile(r'\d+$').search(string)
    return None if match is None else int(match.group())
429
+
430
def input_int(s=None, ge=None, gt=None, le=None, lt=None, default=None, inset=None,
        raise_error=False, log=True):
    """Prompt the user for an integer.

    The value may be range-limited (ge/gt/le/lt) and/or restricted to the
    values in inset. An empty entry returns default.
    """
    return _input_int_or_num('int', s=s, ge=ge, gt=gt, le=le, lt=lt, default=default,
        inset=inset, raise_error=raise_error, log=log)
433
+
434
def input_num(s=None, ge=None, gt=None, le=None, lt=None, default=None, raise_error=False,
        log=True):
    """Prompt the user for a number, optionally range-limited (ge/gt/le/lt).

    An empty entry returns default.
    """
    return _input_int_or_num('num', s=s, ge=ge, gt=gt, le=le, lt=lt, default=default,
        inset=None, raise_error=raise_error, log=log)
437
+
438
def _input_int_or_num(type_str, s=None, ge=None, gt=None, le=None, lt=None, default=None,
        inset=None, raise_error=False, log=True):
    """Shared implementation of input_int and input_num.

    Prompts, re-prompting as needed, until the user enters a valid value.
    type_str: 'int' or 'num'
    s: optional prompt text (a generic prompt is used when None)
    ge, gt, le, lt: optional range qualifiers for the entered value
    default: value used on an empty entry; must itself satisfy type and range
    inset: optional tuple/list of ints the entered value must belong to
    raise_error: raise a ValueError on invalid arguments or on an unexpected
        input error (otherwise an unexpected error propagates unchanged)
    log: log errors if True
    Return: the entered (or default) value, or None upon invalid arguments
    """
    if type_str == 'int':
        func = is_int
    elif type_str == 'num':
        func = is_num
    else:
        illegal_value(type_str, 'type_str', '_input_int_or_num', raise_error, log)
        return None
    if not test_ge_gt_le_lt(ge, gt, le, lt, func, '_input_int_or_num', raise_error, log):
        return None
    if default is not None:
        if not _is_int_or_num(default, type_str, raise_error=raise_error, log=log):
            return None
        # The default must itself respect the requested range
        if ge is not None and default < ge:
            illegal_combination(ge, 'ge', default, 'default', '_input_int_or_num', raise_error,
                log)
            return None
        if gt is not None and default <= gt:
            illegal_combination(gt, 'gt', default, 'default', '_input_int_or_num', raise_error,
                log)
            return None
        if le is not None and default > le:
            illegal_combination(le, 'le', default, 'default', '_input_int_or_num', raise_error,
                log)
            return None
        if lt is not None and default >= lt:
            illegal_combination(lt, 'lt', default, 'default', '_input_int_or_num', raise_error,
                log)
            return None
        default_string = f' [{default}]'
    else:
        default_string = ''
    if inset is not None:
        if not isinstance(inset, (tuple, list)) or not all(isinstance(i, int) for i in inset):
            illegal_value(inset, 'inset', '_input_int_or_num', raise_error, log)
            return None
    v_range = f'{range_string_ge_gt_le_lt(ge, gt, le, lt)}'
    if len(v_range):
        v_range = f' {v_range}'
    if s is not None:
        prompt = f'{s}{v_range}{default_string}: '
    elif type_str == 'int':
        prompt = f'Enter an integer{v_range}{default_string}: '
    else:
        prompt = f'Enter a number{v_range}{default_string}: '
    # Re-prompt in a loop (rather than recursively) until the entry is valid
    while True:
        print(prompt)
        try:
            i = input()
            if not i:
                v = default
                print(f'{v}')
            else:
                v = literal_eval(i)
                if inset and v not in inset:
                    raise ValueError(f'{v} not part of the set {inset}')
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            v = None
        except Exception as exc:
            # Not a parse error (e.g. EOFError): never continue with v unset
            if log:
                logger.error('Unexpected error')
            if raise_error:
                raise ValueError('Unexpected error') from exc
            raise
        if _is_int_or_num(v, type_str, ge, gt, le, lt):
            return v
505
+
506
def input_int_list(s=None, ge=None, le=None, split_on_dash=True, remove_duplicates=True,
        sort=True, raise_error=False, log=True):
    """Prompt the user to input a list of integers.

    The entered string is split on any combination of commas, whitespace, or
    dashes (when split_on_dash is True), with dashes expanding to ranges.
    e.g: '1 3,5-8 , 12 ' -> [1, 3, 5, 6, 7, 8, 12]

    ge, le: optional inclusive bounds on every value
    remove_duplicates: removes duplicates if True (may also change the order)
    sort: sort in ascending order if True
    The user is re-prompted until valid; None is returned on invalid arguments.
    """
    return _input_int_or_num_list('int', s=s, ge=ge, le=le, split_on_dash=split_on_dash,
        remove_duplicates=remove_duplicates, sort=sort, raise_error=raise_error, log=log)
517
+
518
def input_num_list(s=None, ge=None, le=None, remove_duplicates=True, sort=True, raise_error=False,
        log=True):
    """Prompt the user to input a list of numbers.

    The entered string is split on any combination of commas or whitespace.
    Each entry keeps the type it is written as.
    e.g: '1.0, 3, 5.8, 12 ' -> [1.0, 3, 5.8, 12]

    ge, le: optional inclusive bounds on every value
    remove_duplicates: removes duplicates if True (may also change the order)
    sort: sort in ascending order if True
    The user is re-prompted until valid; None is returned on invalid arguments.
    """
    return _input_int_or_num_list('num', s=s, ge=ge, le=le, split_on_dash=False,
        remove_duplicates=remove_duplicates, sort=sort, raise_error=raise_error, log=log)
529
+
530
def _input_int_or_num_list(type_str, s=None, ge=None, le=None, split_on_dash=True,
        remove_duplicates=True, sort=True, raise_error=False, log=True):
    """Shared implementation of input_int_list and input_num_list.

    Prompts, re-prompting as needed, until the user enters a list (parsed with
    string_to_list) whose values all have the requested type and satisfy
    ge <= value <= le.
    type_str: 'int' or 'num'
    raise_error, log: control reporting of invalid arguments
    Return: the entered list, or None upon invalid arguments
    """
    #FIX do we want a limit on max dimension?
    if type_str == 'int':
        func, kind = is_int, 'integers'
    elif type_str == 'num':
        func, kind = is_num, 'numbers'
    else:
        illegal_value(type_str, 'type_str', '_input_int_or_num_list', raise_error, log)
        return None
    if not test_ge_gt_le_lt(ge, None, le, None, func, '_input_int_or_num_list', raise_error,
            log):
        return None
    v_range = f'{range_string_ge_gt_le_lt(ge=ge, le=le)}'
    if len(v_range):
        v_range = f' (each value in {v_range})'
    if s is None:
        prompt = f'Enter a series of {kind}{v_range}: '
    else:
        prompt = f'{s}{v_range}: '
    if split_on_dash:
        hint = ('Invalid input: enter a valid set of dash/comma/whitespace separated '
            f'{kind} e.g. 1 3,5-8 , 12')
    else:
        hint = ('Invalid input: enter a valid set of comma/whitespace separated '
            f'{kind} e.g. 1 3,5 8 , 12')
    # Re-prompt in a loop (rather than recursively) until the entry is valid
    while True:
        print(prompt)
        try:
            values = string_to_list(input(), split_on_dash, remove_duplicates, sort)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            values = None
        except Exception:
            print('Unexpected error')
            raise
        if (isinstance(values, list)
                and all(_is_int_or_num(v, type_str, ge=ge, le=le) for v in values)):
            return values
        print(hint)
569
+
570
def input_yesno(s=None, default=None):
    """Prompt the user for a yes/no answer and return it as a bool.

    s: optional prompt text (a generic prompt is used when None)
    default: optional answer used on an empty entry; any non-empty,
        case-insensitive prefix of 'yes' or 'no' (e.g. 'y', 'No')
    Return: True for yes, False for no, or None when default is invalid
    """
    def is_prefix_of(answer, word):
        # Accept a non-empty case-insensitive prefix of word ('y', 'ye', 'yes').
        # A plain substring test would also accept '', 'e', 's' or 'es' as yes
        return isinstance(answer, str) and len(answer) > 0 and word.startswith(answer.lower())

    if default is not None:
        if not isinstance(default, str):
            illegal_value(default, 'default', 'input_yesno')
            return None
        if is_prefix_of(default, 'yes'):
            default = 'y'
        elif is_prefix_of(default, 'no'):
            default = 'n'
        else:
            illegal_value(default, 'default', 'input_yesno')
            return None
        default_string = f' [{default}]'
    else:
        default_string = ''
    if s is None:
        prompt = f'Enter yes or no{default_string}: '
    else:
        prompt = f'{s}{default_string}: '
    # Re-prompt in a loop (rather than recursively) until the answer is valid
    while True:
        print(prompt)
        answer = input()
        if not answer:
            answer = default
            print(f'{answer}')
        if is_prefix_of(answer, 'yes'):
            return True
        if is_prefix_of(answer, 'no'):
            return False
        print('Invalid input, enter yes or no')
601
+
602
def input_menu(items, default=None, header=None):
    """Prompt the user to choose one item from a numbered menu.

    items: tuple or list of strings, displayed numbered from 1
    default: optional item (one of items) chosen on an empty entry
    header: optional text shown instead of the generic menu prompt
    Return: the zero-based index of the chosen item, or None upon invalid arguments
    """
    if not isinstance(items, (tuple, list)) or not all(isinstance(i, str) for i in items):
        illegal_value(items, 'items', 'input_menu')
        return None
    if default is not None:
        if not (isinstance(default, str) and default in items):
            logger.error(f'Invalid value for default ({default}), must be in {items}')
            return None
        default_string = f' [{items.index(default)+1}]'
    else:
        default_string = ''
    if header is None:
        title = f'Choose one of the following items (1, {len(items)}){default_string}:'
    else:
        title = f'{header} (1, {len(items)}){default_string}:'
    # Re-prompt with the same title (including any header) until the choice is valid
    while True:
        print(title)
        for i, item in enumerate(items):
            print(f' {i+1}: {item}')
        try:
            choice = input()
            if not choice:
                # Without a default, items.index(None) raises ValueError: invalid
                choice = items.index(default)
                print(f'{choice+1}')
            else:
                choice = literal_eval(choice)
                if isinstance(choice, int) and 1 <= choice <= len(items):
                    choice -= 1
                else:
                    raise ValueError
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            choice = None
        except Exception:
            print('Unexpected error')
            raise
        if choice is not None:
            return choice
        print(f'Invalid choice, enter a number between 1 and {len(items)}')
640
+
641
def assert_no_duplicates_in_list_of_dicts(l: list, raise_error=False) -> list:
    """Check that list l contains no two equal dictionaries.

    Return: l itself when valid, otherwise None (after logging, or raising a
    ValueError when raise_error is True)
    """
    if not isinstance(l, list):
        illegal_value(l, 'l', 'assert_no_duplicates_in_list_of_dicts', raise_error)
        return None
    if not all(isinstance(d, dict) for d in l):
        illegal_value(l, 'l', 'assert_no_duplicates_in_list_of_dicts', raise_error)
        return None
    try:
        has_duplicates = len({tuple(sorted(d.items())) for d in l}) != len(l)
    except TypeError:
        # Unhashable values or unorderable keys: fall back to pairwise equality
        has_duplicates = any(d in l[:i] for i, d in enumerate(l))
    if not has_duplicates:
        return l
    if raise_error:
        raise ValueError(f'Duplicate items found in {l}')
    logger.error(f'Duplicate items found in {l}')
    return None
656
+
657
def assert_no_duplicate_key_in_list_of_dicts(l: list, key: str, raise_error=False) -> list:
    """Check that every dictionary in l has the given key with a unique value.

    Return: l itself when valid, otherwise None (after logging, or raising a
    ValueError when raise_error is True)
    """
    if not isinstance(key, str):
        illegal_value(key, 'key', 'assert_no_duplicate_key_in_list_of_dicts', raise_error)
        return None
    if not isinstance(l, list):
        illegal_value(l, 'l', 'assert_no_duplicate_key_in_list_of_dicts', raise_error)
        return None
    if not all(isinstance(d, dict) for d in l):
        illegal_value(l, 'l', 'assert_no_duplicate_key_in_list_of_dicts', raise_error)
        return None
    keys = [d.get(key, None) for d in l]
    try:
        unique = len(set(keys)) == len(keys)
    except TypeError:
        # Unhashable key values: fall back to pairwise equality
        unique = not any(k in keys[:i] for i, k in enumerate(keys))
    if None not in keys and unique:
        return l
    if raise_error:
        raise ValueError(f'Duplicate or missing key ({key}) found in {l}')
    logger.error(f'Duplicate or missing key ({key}) found in {l}')
    return None
676
+
677
def assert_no_duplicate_attr_in_list_of_objs(l: list, attr: str, raise_error=False) -> list:
    """Check that every object in l has the attribute attr with a unique value.

    Return: l itself when valid, otherwise None (after logging, or raising a
    ValueError when raise_error is True)
    """
    if not isinstance(attr, str):
        illegal_value(attr, 'attr', 'assert_no_duplicate_attr_in_list_of_objs', raise_error)
        return None
    if not isinstance(l, list):
        illegal_value(l, 'l', 'assert_no_duplicate_attr_in_list_of_objs', raise_error)
        return None
    attrs = [getattr(obj, attr, None) for obj in l]
    try:
        unique = len(set(attrs)) == len(attrs)
    except TypeError:
        # Unhashable attribute values: fall back to pairwise equality
        unique = not any(a in attrs[:i] for i, a in enumerate(attrs))
    if None not in attrs and unique:
        return l
    if raise_error:
        raise ValueError(f'Duplicate or missing attr ({attr}) found in {l}')
    logger.error(f'Duplicate or missing attr ({attr}) found in {l}')
    return None
693
+
694
def file_exists_and_readable(path):
    """Validate that path names an existing regular file readable by this process.

    Return: path unchanged when valid
    Raises ValueError when path is not a file or is not readable
    """
    import os
    if not os.path.isfile(path):
        raise ValueError(f'{path} is not a valid file')
    if not os.access(path, os.R_OK):
        raise ValueError(f'{path} is not accessible for reading')
    return path
702
+
703
+ def draw_mask_1d(ydata, xdata=None, current_index_ranges=None, current_mask=None,
704
+ select_mask=True, num_index_ranges_max=None, title=None, legend=None, test_mode=False):
705
+ #FIX make color blind friendly
706
+ def draw_selections(ax, current_include, current_exclude, selected_index_ranges):
707
+ ax.clear()
708
+ ax.set_title(title)
709
+ ax.legend([legend])
710
+ ax.plot(xdata, ydata, 'k')
711
+ for (low, upp) in current_include:
712
+ xlow = 0.5*(xdata[max(0, low-1)]+xdata[low])
713
+ xupp = 0.5*(xdata[upp]+xdata[min(num_data-1, upp+1)])
714
+ ax.axvspan(xlow, xupp, facecolor='green', alpha=0.5)
715
+ for (low, upp) in current_exclude:
716
+ xlow = 0.5*(xdata[max(0, low-1)]+xdata[low])
717
+ xupp = 0.5*(xdata[upp]+xdata[min(num_data-1, upp+1)])
718
+ ax.axvspan(xlow, xupp, facecolor='red', alpha=0.5)
719
+ for (low, upp) in selected_index_ranges:
720
+ xlow = 0.5*(xdata[max(0, low-1)]+xdata[low])
721
+ xupp = 0.5*(xdata[upp]+xdata[min(num_data-1, upp+1)])
722
+ ax.axvspan(xlow, xupp, facecolor=selection_color, alpha=0.5)
723
+ ax.get_figure().canvas.draw()
724
+
725
+ def onclick(event):
726
+ if event.inaxes in [fig.axes[0]]:
727
+ selected_index_ranges.append(index_nearest_upp(xdata, event.xdata))
728
+
729
+ def onrelease(event):
730
+ if len(selected_index_ranges) > 0:
731
+ if isinstance(selected_index_ranges[-1], int):
732
+ if event.inaxes in [fig.axes[0]]:
733
+ event.xdata = index_nearest_low(xdata, event.xdata)
734
+ if selected_index_ranges[-1] <= event.xdata:
735
+ selected_index_ranges[-1] = (selected_index_ranges[-1], event.xdata)
736
+ else:
737
+ selected_index_ranges[-1] = (event.xdata, selected_index_ranges[-1])
738
+ draw_selections(event.inaxes, current_include, current_exclude, selected_index_ranges)
739
+ else:
740
+ selected_index_ranges.pop(-1)
741
+
742
+ def confirm_selection(event):
743
+ plt.close()
744
+
745
+ def clear_last_selection(event):
746
+ if len(selected_index_ranges):
747
+ selected_index_ranges.pop(-1)
748
+ else:
749
+ while len(current_include):
750
+ current_include.pop()
751
+ while len(current_exclude):
752
+ current_exclude.pop()
753
+ selected_mask.fill(False)
754
+ draw_selections(ax, current_include, current_exclude, selected_index_ranges)
755
+
756
+ def update_mask(mask, selected_index_ranges, unselected_index_ranges):
757
+ for (low, upp) in selected_index_ranges:
758
+ selected_mask = np.logical_and(xdata >= xdata[low], xdata <= xdata[upp])
759
+ mask = np.logical_or(mask, selected_mask)
760
+ for (low, upp) in unselected_index_ranges:
761
+ unselected_mask = np.logical_and(xdata >= xdata[low], xdata <= xdata[upp])
762
+ mask[unselected_mask] = False
763
+ return mask
764
+
765
+ def update_index_ranges(mask):
766
+ # Update the currently included index ranges (where mask is True)
767
+ current_include = []
768
+ for i, m in enumerate(mask):
769
+ if m == True:
770
+ if len(current_include) == 0 or type(current_include[-1]) == tuple:
771
+ current_include.append(i)
772
+ else:
773
+ if len(current_include) > 0 and isinstance(current_include[-1], int):
774
+ current_include[-1] = (current_include[-1], i-1)
775
+ if len(current_include) > 0 and isinstance(current_include[-1], int):
776
+ current_include[-1] = (current_include[-1], num_data-1)
777
+ return current_include
778
+
779
+ # Check inputs
780
+ ydata = np.asarray(ydata)
781
+ if ydata.ndim > 1:
782
+ logger.warning(f'Invalid ydata dimension ({ydata.ndim})')
783
+ return None, None
784
+ num_data = ydata.size
785
+ if xdata is None:
786
+ xdata = np.arange(num_data)
787
+ else:
788
+ xdata = np.asarray(xdata, dtype=np.float64)
789
+ if xdata.ndim > 1 or xdata.size != num_data:
790
+ logger.warning(f'Invalid xdata shape ({xdata.shape})')
791
+ return None, None
792
+ if not np.all(xdata[:-1] < xdata[1:]):
793
+ logger.warning('Invalid xdata: must be monotonically increasing')
794
+ return None, None
795
+ if current_index_ranges is not None:
796
+ if not isinstance(current_index_ranges, (tuple, list)):
797
+ logger.warning('Invalid current_index_ranges parameter ({current_index_ranges}, '+
798
+ f'{type(current_index_ranges)})')
799
+ return None, None
800
+ if not isinstance(select_mask, bool):
801
+ logger.warning('Invalid select_mask parameter ({select_mask}, {type(select_mask)})')
802
+ return None, None
803
+ if num_index_ranges_max is not None:
804
+ logger.warning('num_index_ranges_max input not yet implemented in draw_mask_1d')
805
+ if title is None:
806
+ title = 'select ranges of data'
807
+ elif not isinstance(title, str):
808
+ illegal(title, 'title')
809
+ title = ''
810
+ if legend is None and not isinstance(title, str):
811
+ illegal(legend, 'legend')
812
+ legend = None
813
+
814
+ if select_mask:
815
+ title = f'Click and drag to {title} you wish to include'
816
+ selection_color = 'green'
817
+ else:
818
+ title = f'Click and drag to {title} you wish to exclude'
819
+ selection_color = 'red'
820
+
821
+ # Set initial selected mask and the selected/unselected index ranges as needed
822
+ selected_index_ranges = []
823
+ unselected_index_ranges = []
824
+ selected_mask = np.full(xdata.shape, False, dtype=bool)
825
+ if current_index_ranges is None:
826
+ if current_mask is None:
827
+ if not select_mask:
828
+ selected_index_ranges = [(0, num_data-1)]
829
+ selected_mask = np.full(xdata.shape, True, dtype=bool)
830
+ else:
831
+ selected_mask = np.copy(np.asarray(current_mask, dtype=bool))
832
+ if current_index_ranges is not None and len(current_index_ranges):
833
+ current_index_ranges = sorted([(low, upp) for (low, upp) in current_index_ranges])
834
+ for (low, upp) in current_index_ranges:
835
+ if low > upp or low >= num_data or upp < 0:
836
+ continue
837
+ if low < 0:
838
+ low = 0
839
+ if upp >= num_data:
840
+ upp = num_data-1
841
+ selected_index_ranges.append((low, upp))
842
+ selected_mask = update_mask(selected_mask, selected_index_ranges, unselected_index_ranges)
843
+ if current_index_ranges is not None and current_mask is not None:
844
+ selected_mask = np.logical_and(current_mask, selected_mask)
845
+ if current_mask is not None:
846
+ selected_index_ranges = update_index_ranges(selected_mask)
847
+
848
+ # Set up range selections for display
849
+ current_include = selected_index_ranges
850
+ current_exclude = []
851
+ selected_index_ranges = []
852
+ if not len(current_include):
853
+ if select_mask:
854
+ current_exclude = [(0, num_data-1)]
855
+ else:
856
+ current_include = [(0, num_data-1)]
857
+ else:
858
+ if current_include[0][0] > 0:
859
+ current_exclude.append((0, current_include[0][0]-1))
860
+ for i in range(1, len(current_include)):
861
+ current_exclude.append((current_include[i-1][1]+1, current_include[i][0]-1))
862
+ if current_include[-1][1] < num_data-1:
863
+ current_exclude.append((current_include[-1][1]+1, num_data-1))
864
+
865
+ if not test_mode:
866
+
867
+ # Set up matplotlib figure
868
+ plt.close('all')
869
+ fig, ax = plt.subplots()
870
+ plt.subplots_adjust(bottom=0.2)
871
+ draw_selections(ax, current_include, current_exclude, selected_index_ranges)
872
+
873
+ # Set up event handling for click-and-drag range selection
874
+ cid_click = fig.canvas.mpl_connect('button_press_event', onclick)
875
+ cid_release = fig.canvas.mpl_connect('button_release_event', onrelease)
876
+
877
+ # Set up confirm / clear range selection buttons
878
+ confirm_b = Button(plt.axes([0.75, 0.05, 0.15, 0.075]), 'Confirm')
879
+ clear_b = Button(plt.axes([0.59, 0.05, 0.15, 0.075]), 'Clear')
880
+ cid_confirm = confirm_b.on_clicked(confirm_selection)
881
+ cid_clear = clear_b.on_clicked(clear_last_selection)
882
+
883
+ # Show figure
884
+ plt.show(block=True)
885
+
886
+ # Disconnect callbacks when figure is closed
887
+ fig.canvas.mpl_disconnect(cid_click)
888
+ fig.canvas.mpl_disconnect(cid_release)
889
+ confirm_b.disconnect(cid_confirm)
890
+ clear_b.disconnect(cid_clear)
891
+
892
+ # Swap selection depending on select_mask
893
+ if not select_mask:
894
+ selected_index_ranges, unselected_index_ranges = unselected_index_ranges, \
895
+ selected_index_ranges
896
+
897
+ # Update the mask with the currently selected/unselected x-ranges
898
+ selected_mask = update_mask(selected_mask, selected_index_ranges, unselected_index_ranges)
899
+
900
+ # Update the currently included index ranges (where mask is True)
901
+ current_include = update_index_ranges(selected_mask)
902
+
903
def select_image_bounds(a, axis, low=None, upp=None, num_min=None, title='select array bounds',
        raise_error=False):
    """Interactively select the lower and upper data bounds for a 2D numpy array.

    The array is displayed with quick_imshow and the user is prompted on the
    command line (input_yesno/input_int) for each missing bound, optionally
    zooming in along the selected axis first. The final selection is shown
    and must be accepted; a rejected selection restarts the procedure.

    :param a: input data, converted with np.asarray; must be 2D
    :param axis: axis along which the bounds are selected (0 or 1)
    :param low: lower bound (inclusive); prompted for when None, otherwise it
        must be an int in [0, a.shape[axis]-num_min]
    :param upp: upper bound (exclusive); prompted for when None, otherwise it
        must be an int in [low+num_min, a.shape[axis]]
    :param num_min: minimum separation between the bounds; None, or a value
        below 2 or above a.shape[axis] (the latter two with a warning), uses 1
    :param title: label of the displayed figure
    :param raise_error: forwarded to illegal_value for invalid inputs
    :return: the tuple (low, upp), or None for invalid input when not raising
    """
    a = np.asarray(a)
    if a.ndim != 2:
        illegal_value(a.ndim, 'array dimension', location='select_image_bounds',
                raise_error=raise_error)
        return None
    if axis < 0 or axis >= a.ndim:
        illegal_value(axis, 'axis', location='select_image_bounds', raise_error=raise_error)
        return None
    # Keep the caller's inputs so a rejected selection can restart from them
    low_save = low
    upp_save = upp
    num_min_save = num_min
    if num_min is None:
        num_min = 1
    else:
        if num_min < 2 or num_min > a.shape[axis]:
            logger.warning('Invalid input for num_min in select_image_bounds, input ignored')
            num_min = 1
    if low is None:
        min_ = 0
        max_ = a.shape[axis]
        low_max = a.shape[axis]-num_min
        while True:
            # Show the current zoom window along the selected axis
            if axis:
                quick_imshow(a[:,min_:max_], title=title, aspect='auto',
                        extent=[min_,max_,a.shape[0],0])
            else:
                quick_imshow(a[min_:max_,:], title=title, aspect='auto',
                        extent=[0,a.shape[1], max_,min_])
            zoom_flag = input_yesno('Set lower data bound (y) or zoom in (n)?', 'y')
            if zoom_flag:
                low = input_int(' Set lower data bound', ge=0, le=low_max)
                break
            else:
                min_ = input_int(' Set lower zoom index', ge=0, le=low_max)
                max_ = input_int(' Set upper zoom index', ge=min_+1, le=low_max+1)
    else:
        if not is_int(low, ge=0, le=a.shape[axis]-num_min):
            illegal_value(low, 'low', location='select_image_bounds', raise_error=raise_error)
            return None
    if upp is None:
        min_ = low+num_min
        max_ = a.shape[axis]
        upp_min = min_
        while True:
            # Show the current zoom window along the selected axis
            if axis:
                quick_imshow(a[:,min_:max_], title=title, aspect='auto',
                        extent=[min_,max_,a.shape[0],0])
            else:
                quick_imshow(a[min_:max_,:], title=title, aspect='auto',
                        extent=[0,a.shape[1], max_,min_])
            zoom_flag = input_yesno('Set upper data bound (y) or zoom in (n)?', 'y')
            if zoom_flag:
                upp = input_int(' Set upper data bound', ge=upp_min, le=a.shape[axis])
                break
            else:
                # The first zoom prompt asks for the lower edge of the zoom window
                min_ = input_int(' Set lower zoom index', ge=upp_min, le=a.shape[axis]-1)
                max_ = input_int(' Set upper zoom index', ge=min_+1, le=a.shape[axis])
    else:
        if not is_int(upp, ge=low+num_min, le=a.shape[axis]):
            illegal_value(upp, 'upp', location='select_image_bounds', raise_error=raise_error)
            return None
    bounds = (low, upp)
    # Mark the first and last included index with the array maximum for review
    a_tmp = np.copy(a)
    a_tmp_max = a.max()
    if axis:
        a_tmp[:,bounds[0]] = a_tmp_max
        a_tmp[:,bounds[1]-1] = a_tmp_max
    else:
        a_tmp[bounds[0],:] = a_tmp_max
        a_tmp[bounds[1]-1,:] = a_tmp_max
    print(f'lower bound = {low} (inclusive)\nupper bound = {upp} (exclusive)')
    quick_imshow(a_tmp, title=title, aspect='auto')
    del a_tmp
    if not input_yesno('Accept these bounds (y/n)?', 'y'):
        # Restart from the caller's inputs. If the caller fixed both bounds,
        # nothing would be prompted for and the same bounds would be shown
        # again forever, so select both interactively in that case.
        if low_save is not None and upp_save is not None:
            low_save = None
            upp_save = None
        bounds = select_image_bounds(a, axis, low=low_save, upp=upp_save,
                num_min=num_min_save, title=title, raise_error=raise_error)
    clear_imshow(title)
    return bounds
985
+
986
def select_one_image_bound(a, axis, bound=None, bound_name=None, title='select array bounds',
        default='y', raise_error=False):
    """Interactively select a data boundary for a 2D numpy array.

    When bound is None the user picks it on the command line while the array
    is displayed (with optional zooming along axis); otherwise the given bound
    is validated and reported. The chosen bound is then highlighted and the
    user must accept it (default answer: default); a rejection restarts the
    selection interactively.

    :param a: input data, converted with np.asarray; must be 2D
    :param axis: axis along which the bound is selected (0 or 1)
    :param bound: initial bound, an int in [0, a.shape[axis]-1], or None
    :param bound_name: name used in the prompts ('data bound' when None)
    :param title: label of the displayed figure
    :param default: default answer for the acceptance prompt
    :param raise_error: forwarded to illegal_value for invalid inputs
    :return: the selected bound, or None for invalid input when not raising
    """
    arr = np.asarray(a)
    if arr.ndim != 2:
        illegal_value(arr.ndim, 'array dimension', location='select_one_image_bound',
                raise_error=raise_error)
        return None
    if axis < 0 or axis >= arr.ndim:
        illegal_value(axis, 'axis', location='select_one_image_bound', raise_error=raise_error)
        return None
    name = 'data bound' if bound_name is None else bound_name
    size = arr.shape[axis]

    if bound is not None:
        if not is_int(bound, ge=0, le=size-1):
            illegal_value(bound, 'bound', location='select_one_image_bound',
                    raise_error=raise_error)
            return None
        print(f'Current {name} = {bound}')
    else:
        lo, hi = 0, size
        while True:
            # Display the current zoom window along the selected axis
            if axis:
                quick_imshow(arr[:,lo:hi], title=title, aspect='auto',
                        extent=[lo, hi, arr.shape[0], 0])
            else:
                quick_imshow(arr[lo:hi,:], title=title, aspect='auto',
                        extent=[0, arr.shape[1], hi, lo])
            if input_yesno(f'Set {name} (y) or zoom in (n)?', 'y'):
                bound = input_int(f' Set {name}', ge=0, le=size-1)
                clear_imshow(title)
                break
            lo = input_int(' Set lower zoom index', ge=0, le=size-1)
            hi = input_int(' Set upper zoom index', ge=lo+1, le=size)

    # Highlight the selected bound with the array maximum for review
    marked = np.copy(arr)
    peak = arr.max()
    if axis:
        marked[:,bound] = peak
    else:
        marked[bound,:] = peak
    quick_imshow(marked, title=title, aspect='auto')
    del marked
    if not input_yesno(f'Accept this {name} (y/n)?', default):
        bound = select_one_image_bound(arr, axis, bound_name=name, title=title)
    clear_imshow(title)
    return bound
1037
+
1038
def clear_imshow(title=None):
    """Switch matplotlib interactive mode off and close an imshow figure.

    :param title: label of the figure to close ('quick imshow' when None)
    :raises ValueError: if title is neither None nor a str
    """
    plt.ioff()
    if title is not None and not isinstance(title, str):
        raise ValueError(f'Invalid parameter title ({title})')
    plt.close(fig='quick imshow' if title is None else title)
1045
+
1046
def clear_plot(title=None):
    """Switch matplotlib interactive mode off and close a plot figure.

    :param title: label of the figure to close ('quick plot' when None)
    :raises ValueError: if title is neither None nor a str
    """
    plt.ioff()
    if title is None:
        label = 'quick plot'
    elif isinstance(title, str):
        label = title
    else:
        raise ValueError(f'Invalid parameter title ({title})')
    plt.close(fig=label)
1053
+
1054
def quick_imshow(a, title=None, path=None, name=None, save_fig=False, save_only=False,
        clear=True, extent=None, show_grid=False, grid_color='w', grid_linewidth=1,
        block=False, **kwargs):
    """Display and/or save an image with matplotlib's imshow.

    :param a: image data (anything np.asanyarray accepts)
    :param title: figure label ('quick imshow' when empty or None)
    :param path: output directory for a saved figure
    :param name: output file name; defaults to the title with whitespace
        replaced by underscores plus '.png'
    :param save_fig: also save the figure when it is displayed
    :param save_only: only save the figure; it is closed afterwards
    :param clear: first close an existing figure with the same label
    :param extent: imshow extent; defaults to (0, ncols, nrows, 0)
    :param show_grid: overlay a grid using grid_color and grid_linewidth
    :param block: block on plt.show (interactive mode off) instead of
        drawing non-blocking (interactive mode on)
    :param kwargs: passed on to plt.imshow. For 3- or 4-channel images a
        'cmap' is only used when the alpha channel (if any) is constant and
        the R, G and B channels are identical; the first channel is then
        shown, otherwise 'cmap' is dropped with a warning.
    :raises ValueError: on an invalid title, path, save_fig, save_only,
        clear or block parameter
    """
    if title is not None and not isinstance(title, str):
        raise ValueError(f'Invalid parameter title ({title})')
    if path is not None and not isinstance(path, str):
        raise ValueError(f'Invalid parameter path ({path})')
    if not isinstance(save_fig, bool):
        raise ValueError(f'Invalid parameter save_fig ({save_fig})')
    if not isinstance(save_only, bool):
        raise ValueError(f'Invalid parameter save_only ({save_only})')
    if not isinstance(clear, bool):
        raise ValueError(f'Invalid parameter clear ({clear})')
    if not isinstance(block, bool):
        raise ValueError(f'Invalid parameter block ({block})')
    if not title:
        title = 'quick imshow'
    # Output file: [<path>/]<name>, name derived from the title by default
    if name is None:
        ttitle = re_sub(r"\s+", '_', title)
        name = f'{ttitle}.png'
    path = name if path is None else f'{path}/{name}'
    # Accept any array-like input; asanyarray keeps masked arrays intact
    a = np.asanyarray(a)
    if 'cmap' in kwargs and a.ndim == 3 and a.shape[2] in (3, 4):
        # A colormap only applies to grayscale data: constant alpha (if
        # present) and R == G == B for every pixel
        alpha_varies = a.shape[2] == 4 and a[:,:,-1].min() != a[:,:,-1].max()
        color_varies = bool(np.any(a[:,:,0] != a[:,:,1]) or np.any(a[:,:,0] != a[:,:,2]))
        if alpha_varies or color_varies:
            logger.warning('Image incompatible with cmap option, ignore cmap')
            kwargs.pop('cmap')
        else:
            a = a[:,:,0]
    if extent is None:
        extent = (0, a.shape[1], a.shape[0], 0)
    if clear:
        # Best effort: failing to close a stale figure should not abort plotting
        try:
            plt.close(fig=title)
        except Exception:
            logger.debug(f'Unable to close figure "{title}"')
    if not save_only:
        if block:
            plt.ioff()
        else:
            plt.ion()
    plt.figure(title)
    plt.imshow(a, extent=extent, **kwargs)
    if show_grid:
        ax = plt.gca()
        ax.grid(color=grid_color, linewidth=grid_linewidth)
    if save_only:
        plt.savefig(path)
        plt.close(fig=title)
    else:
        if save_fig:
            plt.savefig(path)
        if block:
            plt.show(block=block)
1121
+
1122
def quick_plot(*args, xerr=None, yerr=None, vlines=None, title=None, xlim=None, ylim=None,
        xlabel=None, ylabel=None, legend=None, path=None, name=None, show_grid=False,
        save_fig=False, save_only=False, clear=True, block=False, **kwargs):
    """Plot one or more curves with matplotlib, optionally saving the figure.

    :param args: plt.plot positional arguments; a tuple of argument tuples
        plots one curve per inner tuple
    :param xerr, yerr: error bars (single curve only, via plt.errorbar)
    :param vlines: x position(s) of red dashed vertical lines
    :param title: figure label ('quick plot' when None)
    :param xlim, ylim: axis limits as a 2-element tuple or list
    :param xlabel, ylabel: axis labels
    :param legend: tuple or list of legend entries
    :param path: output directory for a saved figure
    :param name: output file name; defaults to the title with whitespace
        replaced by underscores plus '.png'
    :param show_grid: overlay a black grid
    :param save_fig: also save the figure when it is displayed
    :param save_only: only save the figure; it is closed afterwards
    :param clear: first close an existing figure with the same label
    :param block: block on plt.show (interactive mode off) instead of
        drawing non-blocking (interactive mode on)
    :param kwargs: passed on to the plotting calls and to plt.axvline
        (minus any color/linestyle settings, which are fixed for vlines)

    Invalid optional display parameters are reported with illegal_value and
    ignored; an invalid path or flag parameter aborts without plotting.
    """
    if title is not None and not isinstance(title, str):
        illegal_value(title, 'title', 'quick_plot')
        title = None
    if xlim is not None and (not isinstance(xlim, (tuple, list)) or len(xlim) != 2):
        illegal_value(xlim, 'xlim', 'quick_plot')
        xlim = None
    if ylim is not None and (not isinstance(ylim, (tuple, list)) or len(ylim) != 2):
        illegal_value(ylim, 'ylim', 'quick_plot')
        ylim = None
    if xlabel is not None and not isinstance(xlabel, str):
        illegal_value(xlabel, 'xlabel', 'quick_plot')
        xlabel = None
    if ylabel is not None and not isinstance(ylabel, str):
        illegal_value(ylabel, 'ylabel', 'quick_plot')
        ylabel = None
    if legend is not None and not isinstance(legend, (tuple, list)):
        illegal_value(legend, 'legend', 'quick_plot')
        legend = None
    if path is not None and not isinstance(path, str):
        illegal_value(path, 'path', 'quick_plot')
        return
    if not isinstance(show_grid, bool):
        illegal_value(show_grid, 'show_grid', 'quick_plot')
        return
    if not isinstance(save_fig, bool):
        illegal_value(save_fig, 'save_fig', 'quick_plot')
        return
    if not isinstance(save_only, bool):
        illegal_value(save_only, 'save_only', 'quick_plot')
        return
    if not isinstance(clear, bool):
        illegal_value(clear, 'clear', 'quick_plot')
        return
    if not isinstance(block, bool):
        illegal_value(block, 'block', 'quick_plot')
        return
    if title is None:
        title = 'quick plot'
    # Output file: [<path>/]<name>, name derived from the title by default
    if name is None:
        ttitle = re_sub(r"\s+", '_', title)
        name = f'{ttitle}.png'
    path = name if path is None else f'{path}/{name}'
    if clear:
        # Best effort: failing to close a stale figure should not abort plotting
        try:
            plt.close(fig=title)
        except Exception:
            logger.debug(f'Unable to close figure "{title}"')
    args = unwrap_tuple(args)
    multiple_curves = depth_tuple(args) > 1
    if multiple_curves and (xerr is not None or yerr is not None):
        logger.warning('Error bars ignored for multiple curves')
    if not save_only:
        if block:
            plt.ioff()
        else:
            plt.ion()
    plt.figure(title)
    if multiple_curves:
        for y in args:
            plt.plot(*y, **kwargs)
    else:
        if xerr is None and yerr is None:
            plt.plot(*args, **kwargs)
        else:
            plt.errorbar(*args, xerr=xerr, yerr=yerr, **kwargs)
    if vlines is not None:
        if isinstance(vlines, (int, float, np.integer, np.floating)):
            vlines = [vlines]
        # Vertical lines are always red and dashed: drop conflicting style
        # kwargs (and their aliases) to avoid duplicate-keyword errors
        vline_kwargs = {k: v for k, v in kwargs.items()
                if k not in ('color', 'c', 'linestyle', 'ls')}
        for v in vlines:
            plt.axvline(v, color='r', linestyle='--', **vline_kwargs)
    if xlim is not None:
        plt.xlim(xlim)
    if ylim is not None:
        plt.ylim(ylim)
    if xlabel is not None:
        plt.xlabel(xlabel)
    if ylabel is not None:
        plt.ylabel(ylabel)
    if show_grid:
        ax = plt.gca()
        ax.grid(color='k')
    if legend is not None:
        plt.legend(legend)
    if save_only:
        plt.savefig(path)
        plt.close(fig=title)
    else:
        if save_fig:
            plt.savefig(path)
        if block:
            plt.show(block=block)
1225
+