tilupy 0.1.4__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tilupy might be problematic. Click here for more details.

tilupy/plot.py CHANGED
@@ -15,7 +15,23 @@ import matplotlib.pyplot as plt
15
15
  import matplotlib.colors as mcolors
16
16
  import seaborn as sns
17
17
 
18
- from mpl_toolkits.axes_grid1 import make_axes_locatable
18
# Candidate intervals between bold topographic contours, in the same unit as
# the elevation data (1-2-5 progression from 0.1 to 1000).
BOLD_CONTOURS_INTV = [0.1, 0.2, 0.5, 1, 2.0, 5, 10, 20, 50, 100, 200, 500, 1000]
# Default number of thin contour intervals within one bold contour interval.
NB_THIN_CONTOURS = 10
# Default target number of bold contour intervals across the elevation range.
NB_BOLD_CONTOURS = 3
19
35
 
20
36
 
21
37
  def centered_map(cmap, vmin, vmax, ncolors=256):
@@ -39,32 +55,76 @@ def centered_map(cmap, vmin, vmax, ncolors=256):
39
55
  DESCRIPTION.
40
56
 
41
57
  """
42
- p = vmax/(vmax-vmin)
43
- npos = int(ncolors*p)
58
+ p = vmax / (vmax - vmin)
59
+ npos = int(ncolors * p)
44
60
  method = getattr(plt.cm, cmap)
45
61
 
46
- colors1 = method(np.linspace(0., 1, npos*2))
47
- colors2 = method(np.linspace(0., 1, (ncolors-npos)*2))
62
+ colors1 = method(np.linspace(0.0, 1, npos * 2))
63
+ colors2 = method(np.linspace(0.0, 1, (ncolors - npos) * 2))
48
64
  colors = np.concatenate(
49
- (colors2[:ncolors-npos, :], colors1[npos:, :]), axis=0)
65
+ (colors2[: ncolors - npos, :], colors1[npos:, :]), axis=0
66
+ )
50
67
  # colors[ncolors-npos-1,:]=np.ones((1,4))
51
68
  # colors[ncolors-npos,:]=np.ones((1,4))
52
- new_map = mcolors.LinearSegmentedColormap.from_list(
53
- 'my_colormap', colors)
69
+ new_map = mcolors.LinearSegmentedColormap.from_list("my_colormap", colors)
54
70
 
55
71
  return new_map
56
72
 
57
73
 
58
- def plot_topo(z, x, y, contour_step=None, nlevels=25, level_min=None,
59
- step_contour_bold=0, contour_labels_properties=None,
60
- label_contour=True, contour_label_effect=None,
61
- axe=None,
62
- vert_exag=1, fraction=1, ndv=0, uniform_grey=None,
63
- contours_prop=None, contours_bold_prop=None,
64
- figsize=None,
65
- interpolation=None,
66
- sea_level=0, sea_color=None, alpha=1, azdeg=315, altdeg=45,
67
- zmin=None, zmax=None):
74
def get_contour_intervals(
    zmin, zmax, nb_bold_contours=None, nb_thin_contours=None
):
    """
    Choose "round" intervals for bold and thin topographic contours.

    The bold interval is taken from ``BOLD_CONTOURS_INTV``. It is the largest
    candidate not exceeding ``(zmax - zmin) / nb_bold_contours``, or the
    smallest candidate if all of them exceed that target. The thin interval
    is the bold interval divided by the number of thin contours.

    Parameters
    ----------
    zmin : float
        Minimum elevation of the plotted surface.
    zmax : float
        Maximum elevation of the plotted surface.
    nb_bold_contours : int, optional
        Target number of bold contour intervals spanning ``zmax - zmin``.
        The default is None, in which case ``NB_BOLD_CONTOURS`` is used.
    nb_thin_contours : int, optional
        Number of thin contour intervals within one bold interval. The
        default is None. In that case ``NB_THIN_CONTOURS`` is used, and the
        thin interval is doubled when more than 5 bold intervals span
        ``zmax - zmin``, so the map is not overcrowded with thin contours.

    Returns
    -------
    bold_intv : float
        Interval between bold contours.
    thin_intv : float
        Interval between thin contours.
    """
    # Decide whether the thin interval is automatic BEFORE substituting the
    # default. Otherwise the coarsening heuristic below can never trigger.
    auto_thin = nb_thin_contours is None
    if nb_bold_contours is None:
        nb_bold_contours = NB_BOLD_CONTOURS

    intv = (zmax - zmin) / nb_bold_contours

    # Nearest candidate to the target interval. If it is larger than the
    # target, step down one candidate so that (when possible) at least
    # nb_bold_contours bold intervals fit in the elevation range.
    i = np.argmin(np.abs(np.array(BOLD_CONTOURS_INTV) - intv))
    bold_intv = BOLD_CONTOURS_INTV[i]
    if i > 0 and bold_intv > intv:
        bold_intv = BOLD_CONTOURS_INTV[i - 1]

    if auto_thin:
        thin_intv = bold_intv / NB_THIN_CONTOURS
        # Many bold intervals: halve the density of thin contours.
        if (zmax - zmin) / bold_intv > 5:
            thin_intv = thin_intv * 2
    else:
        thin_intv = bold_intv / nb_thin_contours

    return bold_intv, thin_intv
98
+
99
+
100
+ def plot_topo(
101
+ z,
102
+ x,
103
+ y,
104
+ contour_step=None,
105
+ nlevels=None,
106
+ level_min=None,
107
+ step_contour_bold="auto",
108
+ contour_labels_properties=None,
109
+ label_contour=True,
110
+ contour_label_effect=None,
111
+ axe=None,
112
+ vert_exag=1,
113
+ fraction=1,
114
+ ndv=-9999,
115
+ uniform_grey=None,
116
+ contours_prop=None,
117
+ contours_bold_prop=None,
118
+ figsize=None,
119
+ interpolation=None,
120
+ sea_level=0,
121
+ sea_color=None,
122
+ alpha=1,
123
+ azdeg=315,
124
+ altdeg=45,
125
+ zmin=None,
126
+ zmax=None,
127
+ ):
68
128
  """
69
129
  Plot topography with hillshading.
70
130
 
@@ -111,7 +171,7 @@ def plot_topo(z, x, y, contour_step=None, nlevels=25, level_min=None,
111
171
  fraction : TYPE, optional
112
172
  DESCRIPTION. The default is 1.
113
173
  ndv : TYPE, optional
114
- DESCRIPTION. The default is 0.
174
+ DESCRIPTION. The default is -9999.
115
175
 
116
176
  Returns
117
177
  -------
@@ -120,15 +180,19 @@ def plot_topo(z, x, y, contour_step=None, nlevels=25, level_min=None,
120
180
  """
121
181
  dx = x[1] - x[0]
122
182
  dy = y[1] - y[0]
123
- im_extent = [x[0]-dx/2,
124
- x[-1]+dx/2,
125
- y[0]-dy/2,
126
- y[-1]+dy/2]
183
+ im_extent = [x[0] - dx / 2, x[-1] + dx / 2, y[0] - dy / 2, y[-1] + dy / 2]
127
184
  ls = mcolors.LightSource(azdeg=azdeg, altdeg=altdeg)
128
185
 
186
+ auto_bold_intv = None
187
+
188
+ if nlevels is None and contour_step is None:
189
+ auto_bold_intv, contour_step = get_contour_intervals(
190
+ np.nanmin(z), np.nanmax(z)
191
+ )
192
+
129
193
  if level_min is None:
130
194
  if contour_step is not None:
131
- level_min = np.ceil(np.nanmin(z)/contour_step)*contour_step
195
+ level_min = np.ceil(np.nanmin(z) / contour_step) * contour_step
132
196
  else:
133
197
  level_min = np.nanmin(z)
134
198
  if contour_step is not None:
@@ -137,42 +201,63 @@ def plot_topo(z, x, y, contour_step=None, nlevels=25, level_min=None,
137
201
  levels = np.linspace(level_min, np.nanmax(z), nlevels)
138
202
 
139
203
  if axe is None:
140
- fig = plt.figure(figsize=figsize)
141
- axe = fig.gca()
142
- else:
143
- fig = axe.figure
144
- axe.set_ylabel('Y (m)')
145
- axe.set_xlabel('X (m)')
146
- axe.set_aspect('equal')
204
+ fig, axe = plt.subplots(1, 1, figsize=figsize, layout="constrained")
205
+
206
+ axe.set_ylabel("Y (m)")
207
+ axe.set_xlabel("X (m)")
208
+ axe.set_aspect("equal")
147
209
 
148
210
  if uniform_grey is None:
149
- shaded_topo = ls.hillshade(z,
150
- vert_exag=vert_exag, dx=dx, dy=dy,
151
- fraction=1)
211
+ shaded_topo = ls.hillshade(
212
+ z, vert_exag=vert_exag, dx=dx, dy=dy, fraction=1
213
+ )
152
214
  else:
153
- shaded_topo = np.ones(z.shape)*uniform_grey
215
+ shaded_topo = np.ones(z.shape) * uniform_grey
154
216
  shaded_topo[z == ndv] = np.nan
155
- axe.imshow(shaded_topo, cmap='gray', extent=im_extent,
156
- interpolation=interpolation, alpha=alpha, vmin=0, vmax=1)
217
+ axe.imshow(
218
+ shaded_topo,
219
+ cmap="gray",
220
+ extent=im_extent,
221
+ interpolation=interpolation,
222
+ alpha=alpha,
223
+ vmin=0,
224
+ vmax=1,
225
+ )
157
226
 
158
227
  if contours_prop is None:
159
- contours_prop = dict(alpha=0.5, colors='k',
160
- linewidths=0.5)
161
- axe.contour(x, y, np.flip(z, axis=0), extent=im_extent,
162
- levels=levels,
163
- **contours_prop)
228
+ contours_prop = dict(alpha=0.5, colors="k", linewidths=0.5)
229
+ axe.contour(
230
+ x,
231
+ y,
232
+ np.flip(z, axis=0),
233
+ extent=im_extent,
234
+ levels=levels,
235
+ **contours_prop
236
+ )
164
237
 
165
238
  if contours_bold_prop is None:
166
- contours_bold_prop = dict(alpha=0.8, colors='k',
167
- linewidths=0.8)
239
+ contours_bold_prop = dict(alpha=0.8, colors="k", linewidths=0.8)
240
+
241
+ if step_contour_bold == "auto":
242
+ if auto_bold_intv is None:
243
+ auto_bold_intv, _ = get_contour_intervals(
244
+ np.nanmin(z), np.nanmax(z)
245
+ )
246
+ step_contour_bold = auto_bold_intv
168
247
 
169
248
  if step_contour_bold > 0:
170
- lmin = np.ceil(np.nanmin(z)/step_contour_bold)*step_contour_bold
249
+ lmin = np.ceil(np.nanmin(z) / step_contour_bold) * step_contour_bold
250
+ if lmin < level_min:
251
+ lmin = lmin + step_contour_bold
171
252
  levels = np.arange(lmin, np.nanmax(z), step_contour_bold)
172
- cs = axe.contour(x, y, np.flip(z, axis=0),
173
- extent=im_extent,
174
- levels=levels,
175
- **contours_bold_prop)
253
+ cs = axe.contour(
254
+ x,
255
+ y,
256
+ np.flip(z, axis=0),
257
+ extent=im_extent,
258
+ levels=levels,
259
+ **contours_bold_prop
260
+ )
176
261
  if label_contour:
177
262
  if contour_labels_properties is None:
178
263
  contour_labels_properties = {}
@@ -182,30 +267,79 @@ def plot_topo(z, x, y, contour_step=None, nlevels=25, level_min=None,
182
267
 
183
268
  if sea_color is not None:
184
269
  cmap_sea = mcolors.ListedColormap([sea_color])
185
- cmap_sea.set_under(color='w', alpha=0)
186
- mask_sea = (z <= sea_level)*1
270
+ cmap_sea.set_under(color="w", alpha=0)
271
+ mask_sea = (z <= sea_level) * 1
187
272
  if mask_sea.any():
188
- axe.imshow(mask_sea, extent=im_extent, cmap=cmap_sea,
189
- vmin=0.5, origin='lower', interpolation='none')
190
-
191
-
192
- def plot_data_on_topo(x, y, z, data, axe=None, figsize=(15/2.54, 15/2.54),
193
- cmap=None,
194
- minval=None, maxval=None, vmin=None, vmax=None,
195
- minval_abs=None,
196
- cmap_intervals=None, extend_cc='max',
197
- topo_kwargs=None, sup_plot=None, alpha=1,
198
- plot_colorbar=True, axecc=None, colorbar_kwargs=None,
199
- mask=None, alpha_mask=None, color_mask='k',
200
- xlims=None, ylims=None):
273
+ axe.imshow(
274
+ mask_sea,
275
+ extent=im_extent,
276
+ cmap=cmap_sea,
277
+ vmin=0.5,
278
+ origin="lower",
279
+ interpolation="none",
280
+ )
281
+
282
+
283
+ def plot_imshow(
284
+ x,
285
+ y,
286
+ data,
287
+ axe=None,
288
+ figsize=None,
289
+ cmap=None,
290
+ minval=None,
291
+ maxval=None,
292
+ vmin=None,
293
+ vmax=None,
294
+ alpha=1,
295
+ minval_abs=None,
296
+ cmap_intervals=None,
297
+ extend_cc="max",
298
+ plot_colorbar=True,
299
+ axecc=None,
300
+ colorbar_kwargs=None,
301
+ aspect=None,
302
+ ):
201
303
  """
202
- Plot array data on topo.
304
+ plt.imshow data with some pre-processing
305
+
306
+ Parameters
307
+ ----------
308
+ x : TYPE
309
+ DESCRIPTION.
310
+ y : TYPE
311
+ DESCRIPTION.
312
+ data : TYPE
313
+ DESCRIPTION.
314
+ axe : TYPE, optional
315
+ DESCRIPTION. The default is None.
316
+ figsize : TYPE, optional
317
+ DESCRIPTION. The default is None.
318
+ cmap : TYPE, optional
319
+ DESCRIPTION. The default is None.
320
+ minval : TYPE, optional
321
+ DESCRIPTION. The default is None.
322
+ maxval : TYPE, optional
323
+ DESCRIPTION. The default is None.
324
+ vmin : TYPE, optional
325
+ DESCRIPTION. The default is None.
326
+ vmax : TYPE, optional
327
+ DESCRIPTION. The default is None.
328
+ minval_abs : TYPE, optional
329
+ DESCRIPTION. The default is None.
330
+ cmap_intervals : TYPE, optional
331
+ DESCRIPTION. The default is None.
332
+ extend_cc : TYPE, optional
333
+ DESCRIPTION. The default is "max".
203
334
 
204
335
  Returns
205
336
  -------
206
337
  None.
207
338
 
208
339
  """
340
+ if axe is None:
341
+ _, axe = plt.subplots(1, 1, figsize=figsize, layout="constrained")
342
+
209
343
  f = copy.copy(data)
210
344
 
211
345
  # vmin and vmax are similar to minval and maxval
@@ -219,12 +353,13 @@ def plot_data_on_topo(x, y, z, data, axe=None, figsize=(15/2.54, 15/2.54),
219
353
  # Remove values below and above minval and maxval, depending on whether
220
354
  # cmap_intervals are given with or without extend_cc
221
355
  if cmap_intervals is not None:
222
- norm = matplotlib.colors.BoundaryNorm(cmap_intervals, 256,
223
- extend=extend_cc)
224
- if extend_cc in ['neither', 'max']:
356
+ norm = matplotlib.colors.BoundaryNorm(
357
+ cmap_intervals, 256, extend=extend_cc
358
+ )
359
+ if extend_cc in ["neither", "max"]:
225
360
  minval = cmap_intervals[0]
226
361
  f[f < minval] = np.nan
227
- elif extend_cc in ['neither', 'min']:
362
+ elif extend_cc in ["neither", "min"]:
228
363
  maxval = cmap_intervals[-1]
229
364
  f[f > maxval] = np.nan
230
365
  else:
@@ -247,61 +382,140 @@ def plot_data_on_topo(x, y, z, data, axe=None, figsize=(15/2.54, 15/2.54),
247
382
 
248
383
  # Define colormap type
249
384
  if cmap is None:
250
- if maxval*minval >= 0:
251
- cmap = 'hot_r'
385
+ if maxval * minval >= 0:
386
+ cmap = "hot_r"
252
387
  else:
253
- cmap = 'seismic'
254
- if maxval*minval >= 0:
255
- color_map = matplotlib.cm.get_cmap(cmap).copy()
388
+ cmap = "seismic"
389
+ if maxval * minval >= 0:
390
+ color_map = matplotlib.colormaps[cmap]
256
391
  else:
257
392
  color_map = centered_map(cmap, minval, maxval)
258
393
 
259
394
  if cmap_intervals is not None:
260
- norm = matplotlib.colors.BoundaryNorm(cmap_intervals, 256,
261
- extend=extend_cc)
395
+ norm = matplotlib.colors.BoundaryNorm(
396
+ cmap_intervals, 256, extend=extend_cc
397
+ )
262
398
  maxval = None
263
399
  minval = None
264
400
  else:
265
401
  norm = None
266
- # color_map.set_under([1, 1, 1], alpha=0)
402
+
403
+ # get map_extent
404
+ dx = x[1] - x[0]
405
+ dy = y[1] - y[0]
406
+ im_extent = [x[0] - dx / 2, x[-1] + dx / 2, y[0] - dy / 2, y[-1] + dy / 2]
407
+
408
+ # Plot data
409
+
410
+ fim = axe.imshow(
411
+ f,
412
+ extent=im_extent,
413
+ cmap=color_map,
414
+ vmin=minval,
415
+ vmax=maxval,
416
+ alpha=alpha,
417
+ interpolation="none",
418
+ norm=norm,
419
+ zorder=4,
420
+ aspect=aspect,
421
+ )
422
+
423
+ # Plot colorbar
424
+ if plot_colorbar:
425
+ colorbar_kwargs = {} if colorbar_kwargs is None else colorbar_kwargs
426
+ if cmap_intervals is not None and extend_cc is not None:
427
+ colorbar_kwargs["extend"] = extend_cc
428
+ axe.figure.colorbar(fim, cax=axecc, **colorbar_kwargs)
429
+
430
+ return axe
431
+
432
+
433
+ def plot_data_on_topo(
434
+ x,
435
+ y,
436
+ z,
437
+ data,
438
+ axe=None,
439
+ figsize=(15 / 2.54, 15 / 2.54),
440
+ cmap=None,
441
+ minval=None,
442
+ maxval=None,
443
+ vmin=None,
444
+ vmax=None,
445
+ minval_abs=None,
446
+ cmap_intervals=None,
447
+ extend_cc="max",
448
+ topo_kwargs=None,
449
+ sup_plot=None,
450
+ alpha=1,
451
+ plot_colorbar=True,
452
+ axecc=None,
453
+ colorbar_kwargs=None,
454
+ mask=None,
455
+ alpha_mask=None,
456
+ color_mask="k",
457
+ xlims=None,
458
+ ylims=None,
459
+ ):
460
+ """
461
+ Plot array data on topo.
462
+
463
+ Returns
464
+ -------
465
+ None.
466
+
467
+ """
267
468
 
268
469
  # Initialize figure properties
269
- dx = x[1]-x[0]
270
- dy = y[1]-y[0]
271
- im_extent = [x[0]-dx/2, x[-1]+dx/2, y[0]-dy/2, y[-1]+dy/2]
470
+ dx = x[1] - x[0]
471
+ dy = y[1] - y[0]
472
+ im_extent = [x[0] - dx / 2, x[-1] + dx / 2, y[0] - dy / 2, y[-1] + dy / 2]
272
473
  if axe is None:
273
- fig = plt.figure(figsize=figsize)
274
- axe = fig.gca()
275
- else:
276
- fig = axe.figure
277
- axe.set_ylabel('Y (m)')
278
- axe.set_xlabel('X (m)')
279
- axe.set_aspect('equal', adjustable='box')
474
+ fig, axe = plt.subplots(1, 1, figsize=figsize, layout="constrained")
475
+
476
+ axe.set_ylabel("Y (m)")
477
+ axe.set_xlabel("X (m)")
478
+ axe.set_aspect("equal", adjustable="box")
280
479
 
281
480
  # Plot topo
282
481
  topo_kwargs = {} if topo_kwargs is None else topo_kwargs
283
482
 
284
- plot_topo(z, x, y, axe=axe, **topo_kwargs)
483
+ if z is not None:
484
+ plot_topo(z, x, y, axe=axe, **topo_kwargs)
285
485
 
286
486
  # Plot mask
287
487
  if mask is not None:
288
488
  cmap_mask = mcolors.ListedColormap([color_mask])
289
- cmap_mask.set_under(color='w', alpha=0)
290
- axe.imshow(mask.transpose(), extent=im_extent, cmap=cmap_mask,
291
- vmin=0.5, origin='lower', interpolation='none',
292
- zorder=3, alpha=alpha_mask)
489
+ cmap_mask.set_under(color="w", alpha=0)
490
+ axe.imshow(
491
+ mask.transpose(),
492
+ extent=im_extent,
493
+ cmap=cmap_mask,
494
+ vmin=0.5,
495
+ origin="lower",
496
+ interpolation="none",
497
+ # zorder=3,
498
+ alpha=alpha_mask,
499
+ )
293
500
 
294
501
  # Plot data
295
- fim = axe.imshow(f, extent=im_extent, cmap=color_map,
296
- vmin=minval, vmax=maxval, alpha=alpha,
297
- interpolation='none', norm=norm, zorder=4)
298
-
299
- # Plot colorbar
300
- if plot_colorbar:
301
- colorbar_kwargs = {} if colorbar_kwargs is None else colorbar_kwargs
302
- if cmap_intervals is not None and extend_cc is not None:
303
- colorbar_kwargs['extend'] = extend_cc
304
- colorbar(fim, cax=axecc, **colorbar_kwargs)
502
+ plot_imshow(
503
+ x,
504
+ y,
505
+ data,
506
+ axe=axe,
507
+ cmap=cmap,
508
+ minval=minval,
509
+ maxval=maxval,
510
+ vmin=vmin,
511
+ vmax=vmax,
512
+ minval_abs=minval_abs,
513
+ cmap_intervals=cmap_intervals,
514
+ extend_cc=extend_cc,
515
+ plot_colorbar=plot_colorbar,
516
+ axecc=axecc,
517
+ colorbar_kwargs=colorbar_kwargs,
518
+ )
305
519
 
306
520
  # Adjust axes limits
307
521
  if xlims is not None:
@@ -313,11 +527,59 @@ def plot_data_on_topo(x, y, z, data, axe=None, figsize=(15/2.54, 15/2.54),
313
527
  return axe
314
528
 
315
529
 
316
- def plot_maps(x, y, z, data, t, file_name, folder_out=None,
317
- figsize=None, dpi=None, fmt='png',
318
- sup_plt_fn=None,
319
- sup_plt_fn_args=None,
320
- **kwargs):
530
def plot_shotgather(x, t, data, xlabel="X (m)", ylabel=None, **kwargs):
    """
    Plot a shotgather-like image: time on the vertical axis and the spatial
    dimension on the horizontal axis.

    This is a thin wrapper around plot_imshow. The input data is transposed,
    because by tilupy convention the last axis is time. The time array is
    reversed so that, for increasing ``t``, the first time step is drawn at
    the top of the image.

    Parameters
    ----------
    x : NX-array
        Spatial coordinates (horizontal axis).
    t : NT-array
        Time array, assumed in seconds (vertical axis).
    data : array-like
        NX*NT array of data to be plotted.
    xlabel : string, optional
        Label for the x-axis. The default is "X (m)".
    ylabel : string, optional
        Label for the y-axis. The default is None, which gives "Time (s)".
    **kwargs : dict, optional
        Parameters passed on to plot_imshow. ``aspect`` defaults to "auto".

    Returns
    -------
    axe : Axes
        Axes instance where data is plotted.

    """
    # Space and time axes have unrelated units, so equal aspect is meaningless.
    kwargs.setdefault("aspect", "auto")
    # asanyarray accepts nested lists while preserving masked arrays.
    axe = plot_imshow(x, t[::-1], np.asanyarray(data).T, **kwargs)
    axe.set_adjustable("box")
    axe.set_xlabel(xlabel)
    axe.set_ylabel("Time (s)" if ylabel is None else ylabel)

    return axe
566
+
567
+
568
+ def plot_maps(
569
+ x,
570
+ y,
571
+ z,
572
+ data,
573
+ t,
574
+ file_name=None,
575
+ folder_out=None,
576
+ figsize=None,
577
+ dpi=None,
578
+ fmt="png",
579
+ sup_plt_fn=None,
580
+ sup_plt_fn_args=None,
581
+ **kwargs
582
+ ):
321
583
  """
322
584
  Plot and save maps of simulations outputs at successive time steps
323
585
 
@@ -349,29 +611,34 @@ def plot_maps(x, y, z, data, t, file_name, folder_out=None,
349
611
  nfigs = len(t)
350
612
  if nfigs != data.shape[2]:
351
613
  raise ValueError(
352
- 'length of t must be similar to the last dimension of data')
614
+ "length of t must be similar to the last dimension of data"
615
+ )
353
616
  if folder_out is not None:
354
- file_path = os.path.join(folder_out, file_name + '_{:04d}.' + fmt)
355
- title_fmt = 't = {:.2f} s'
617
+ file_path = os.path.join(folder_out, file_name + "_{:04d}." + fmt)
618
+ title_fmt = "t = {:.2f} s"
356
619
 
357
620
  for i in range(nfigs):
358
- axe = plot_data_on_topo(x, y, z, data[:, :, i], axe=None,
359
- figsize=figsize,
360
- **kwargs)
621
+ axe = plot_data_on_topo(
622
+ x, y, z, data[:, :, i], axe=None, figsize=figsize, **kwargs
623
+ )
361
624
  axe.set_title(title_fmt.format(t[i]))
362
625
  if sup_plt_fn is not None:
363
626
  if sup_plt_fn_args is None:
364
627
  sup_plt_fn_args = dict()
365
628
  sup_plt_fn(axe, **sup_plt_fn_args)
366
- axe.figure.tight_layout(pad=0.1)
629
+ # axe.figure.tight_layout(pad=0.1)
367
630
  if folder_out is not None:
368
- axe.figure.savefig(file_path.format(i), dpi=dpi,
369
- bbox_inches='tight', pad_inches=0.05)
631
+ axe.figure.savefig(
632
+ file_path.format(i),
633
+ dpi=dpi,
634
+ bbox_inches="tight",
635
+ pad_inches=0.05,
636
+ )
370
637
 
371
638
 
372
- def colorbar(mappable, ax=None,
373
- cax=None, size="5%", pad=0.1, position='right',
374
- **kwargs):
639
+ def colorbar(
640
+ mappable, ax=None, cax=None, size="5%", pad=0.1, position="right", **kwargs
641
+ ):
375
642
  """
376
643
  Create nice colorbar matching height/width of axe.
377
644
 
@@ -399,45 +666,54 @@ def colorbar(mappable, ax=None,
399
666
  if ax is None:
400
667
  ax = mappable.axes
401
668
  fig = ax.figure
402
- if position in ['left', 'right']:
403
- orientation = 'vertical'
669
+ if position in ["left", "right"]:
670
+ orientation = "vertical"
404
671
  else:
405
- orientation = 'horizontal'
406
- if cax is None:
407
- # divider = ax.get_axes_locator()
408
- # if divider is None:
409
- divider = make_axes_locatable(ax)
410
- cax = divider.append_axes(position, size=size, pad=pad)
672
+ orientation = "horizontal"
411
673
 
412
- cc = fig.colorbar(mappable, cax=cax,
413
- orientation=orientation, **kwargs)
674
+ # if cax is None:
675
+ # # divider = ax.get_axes_locator()
676
+ # # if divider is None:
677
+ # divider = make_axes_locatable(ax)
678
+ # cax = divider.append_axes(position, size=size, pad=pad)
414
679
 
415
- if position == 'top':
680
+ cc = fig.colorbar(mappable, cax=cax, orientation=orientation, **kwargs)
681
+
682
+ if position == "top":
416
683
  cax.xaxis.tick_top()
417
- cax.xaxis.set_label_position('top')
418
- if position == 'left':
684
+ cax.xaxis.set_label_position("top")
685
+ if position == "left":
419
686
  cax.yaxis.tick_left()
420
- cax.xaxis.set_label_position('left')
687
+ cax.xaxis.set_label_position("left")
421
688
  return cc
422
689
 
423
690
 
424
- def plot_heatmaps(df, values, index, columns, aggfunc='mean',
425
- figsize=None, ncols=3,
426
- heatmap_kws=None, notations=None, best_values=None,
427
- plot_best_value='point', text_kwargs=None):
428
-
691
+ def plot_heatmaps(
692
+ df,
693
+ values,
694
+ index,
695
+ columns,
696
+ aggfunc="mean",
697
+ figsize=None,
698
+ ncols=3,
699
+ heatmap_kws=None,
700
+ notations=None,
701
+ best_values=None,
702
+ plot_best_value="point",
703
+ text_kwargs=None,
704
+ ):
429
705
  nplots = len(values)
430
706
  ncols = min(nplots, ncols)
431
- nrows = int(np.ceil(nplots/ncols))
707
+ nrows = int(np.ceil(nplots / ncols))
432
708
  fig = plt.figure(figsize=figsize)
433
709
  axes = []
434
710
 
435
711
  for i in range(nplots):
436
- axe = fig.add_subplot(nrows, ncols, i+1)
712
+ axe = fig.add_subplot(nrows, ncols, i + 1)
437
713
  axes.append(axe)
438
- data = df.pivot_table(index=index, columns=columns,
439
- values=values[i],
440
- aggfunc=aggfunc).astype(float)
714
+ data = df.pivot_table(
715
+ index=index, columns=columns, values=values[i], aggfunc=aggfunc
716
+ ).astype(float)
441
717
  if heatmap_kws is None:
442
718
  kws = dict()
443
719
  elif isinstance(heatmap_kws, dict):
@@ -446,25 +722,25 @@ def plot_heatmaps(df, values, index, columns, aggfunc='mean',
446
722
  else:
447
723
  kws = heatmap_kws
448
724
 
449
- if 'cmap' not in kws:
725
+ if "cmap" not in kws:
450
726
  minval = data.min().min()
451
727
  maxval = data.max().max()
452
- if minval*maxval < 0:
728
+ if minval * maxval < 0:
453
729
  val = max(np.abs(minval), maxval)
454
- kws['cmap'] = 'seismic'
455
- kws['vmin'] = -val
456
- kws['vmax'] = val
730
+ kws["cmap"] = "seismic"
731
+ kws["vmin"] = -val
732
+ kws["vmax"] = val
457
733
 
458
- if 'cbar_kws' not in kws:
459
- kws['cbar_kws'] = dict(pad=0.03)
734
+ if "cbar_kws" not in kws:
735
+ kws["cbar_kws"] = dict(pad=0.03)
460
736
 
461
737
  if notations is None:
462
- kws['cbar_kws']['label'] = values[i]
738
+ kws["cbar_kws"]["label"] = values[i]
463
739
  else:
464
740
  if values[i] in notations:
465
- kws['cbar_kws']['label'] = notations[values[i]]
741
+ kws["cbar_kws"]["label"] = notations[values[i]]
466
742
  else:
467
- kws['cbar_kws']['label'] = values[i]
743
+ kws["cbar_kws"]["label"] = values[i]
468
744
 
469
745
  sns.heatmap(data, ax=axe, **kws)
470
746
 
@@ -473,50 +749,71 @@ def plot_heatmaps(df, values, index, columns, aggfunc='mean',
473
749
  array = np.array(data)
474
750
  irow = np.arange(data.shape[0])
475
751
 
476
- if best_value == 'min':
752
+ if best_value == "min":
477
753
  ind = np.nanargmin(array, axis=1)
478
754
  i2 = np.nanargmin(array[irow, ind])
479
- if best_value == 'min_abs':
755
+ if best_value == "min_abs":
480
756
  ind = np.nanargmin(np.abs(array), axis=1)
481
757
  i2 = np.nanargmin(np.abs(array[irow, ind]))
482
- elif best_value == 'max':
758
+ elif best_value == "max":
483
759
  ind = np.nanargmax(array, axis=1)
484
760
  i2 = np.nanargmax(array[irow, ind])
485
761
 
486
- if plot_best_value == 'point':
487
- axe.plot(ind + 0.5, irow+0.5, ls='',
488
- marker='o', mfc='w', mec='k', mew=0.5, ms=5)
489
- axe.plot(ind[i2] + 0.5, i2+0.5, ls='',
490
- marker='o', mfc='w', mec='k', mew=0.8, ms=9)
491
- elif plot_best_value == 'text':
762
+ if plot_best_value == "point":
763
+ axe.plot(
764
+ ind + 0.5,
765
+ irow + 0.5,
766
+ ls="",
767
+ marker="o",
768
+ mfc="w",
769
+ mec="k",
770
+ mew=0.5,
771
+ ms=5,
772
+ )
773
+ axe.plot(
774
+ ind[i2] + 0.5,
775
+ i2 + 0.5,
776
+ ls="",
777
+ marker="o",
778
+ mfc="w",
779
+ mec="k",
780
+ mew=0.8,
781
+ ms=9,
782
+ )
783
+ elif plot_best_value == "text":
492
784
  indx = list(ind)
493
785
  indx.pop(i2)
494
786
  indy = list(irow)
495
787
  indy.pop(i2)
496
- default_kwargs = dict(ha='center', va='center',
497
- fontsize=8)
788
+ default_kwargs = dict(ha="center", va="center", fontsize=8)
498
789
  if text_kwargs is None:
499
790
  text_kwargs = default_kwargs
500
791
  else:
501
792
  text_kwargs = dict(default_kwargs, **text_kwargs)
502
793
  for i, j in zip(indx, indy):
503
- axe.text(i + 0.5, j+0.5,
504
- '{:.2g}'.format(array[j, i]),
505
- **text_kwargs)
506
- text_kwargs2 = dict(text_kwargs, fontweight='bold')
507
- axe.text(ind[i2] + 0.5, i2+0.5,
508
- '{:.2g}'.format(array[i2, ind[i2]]),
509
- **text_kwargs2)
794
+ axe.text(
795
+ i + 0.5,
796
+ j + 0.5,
797
+ "{:.2g}".format(array[j, i]),
798
+ **text_kwargs
799
+ )
800
+ text_kwargs2 = dict(text_kwargs, fontweight="bold")
801
+ axe.text(
802
+ ind[i2] + 0.5,
803
+ i2 + 0.5,
804
+ "{:.2g}".format(array[i2, ind[i2]]),
805
+ **text_kwargs2
806
+ )
510
807
 
511
808
  axes = np.array(axes).reshape((nrows, ncols))
512
809
  for i in range(nrows):
513
810
  for j in range(1, ncols):
514
- axes[i, j].set_ylabel('')
811
+ axes[i, j].set_ylabel("")
515
812
  # axes[i, j].set_yticklabels([])
516
813
 
517
- for i in range(nrows-1):
814
+ for i in range(nrows - 1):
518
815
  for j in range(ncols):
519
- axes[i, j].set_xlabel('')
816
+ axes[i, j].set_xlabel("")
520
817
  # axes[i, j].set_xticklabels([])
521
818
 
522
819
  if notations is not None:
@@ -525,6 +822,6 @@ def plot_heatmaps(df, values, index, columns, aggfunc='mean',
525
822
  for j in range(ncols):
526
823
  axes[-1, j].set_xlabel(notations[columns])
527
824
 
528
- fig.tight_layout()
825
+ # fig.tight_layout()
529
826
 
530
827
  return fig