tilupy-0.1.5-py3-none-any.whl → tilupy-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of tilupy might be problematic; see the package's registry page for more details.

tilupy/plot.py CHANGED
@@ -15,9 +15,21 @@ import matplotlib.pyplot as plt
15
15
  import matplotlib.colors as mcolors
16
16
  import seaborn as sns
17
17
 
18
- from mpl_toolkits.axes_grid1 import make_axes_locatable
19
-
20
- BOLD_CONTOURS_INTV = [0.1, 0.2, 0.5, 1, 2., 5, 10, 20, 50, 100, 200, 500, 1000]
18
+ BOLD_CONTOURS_INTV = [
19
+ 0.1,
20
+ 0.2,
21
+ 0.5,
22
+ 1,
23
+ 2.0,
24
+ 5,
25
+ 10,
26
+ 20,
27
+ 50,
28
+ 100,
29
+ 200,
30
+ 500,
31
+ 1000,
32
+ ]
21
33
  NB_THIN_CONTOURS = 10
22
34
  NB_BOLD_CONTOURS = 3
23
35
 
@@ -43,25 +55,25 @@ def centered_map(cmap, vmin, vmax, ncolors=256):
43
55
  DESCRIPTION.
44
56
 
45
57
  """
46
- p = vmax/(vmax-vmin)
47
- npos = int(ncolors*p)
58
+ p = vmax / (vmax - vmin)
59
+ npos = int(ncolors * p)
48
60
  method = getattr(plt.cm, cmap)
49
61
 
50
- colors1 = method(np.linspace(0., 1, npos*2))
51
- colors2 = method(np.linspace(0., 1, (ncolors-npos)*2))
62
+ colors1 = method(np.linspace(0.0, 1, npos * 2))
63
+ colors2 = method(np.linspace(0.0, 1, (ncolors - npos) * 2))
52
64
  colors = np.concatenate(
53
- (colors2[:ncolors-npos, :], colors1[npos:, :]), axis=0)
65
+ (colors2[: ncolors - npos, :], colors1[npos:, :]), axis=0
66
+ )
54
67
  # colors[ncolors-npos-1,:]=np.ones((1,4))
55
68
  # colors[ncolors-npos,:]=np.ones((1,4))
56
- new_map = mcolors.LinearSegmentedColormap.from_list(
57
- 'my_colormap', colors)
69
+ new_map = mcolors.LinearSegmentedColormap.from_list("my_colormap", colors)
58
70
 
59
71
  return new_map
60
72
 
61
73
 
62
- def get_contour_intervals(zmin, zmax, nb_bold_contours=None,
63
- nb_thin_contours=None):
64
-
74
+ def get_contour_intervals(
75
+ zmin, zmax, nb_bold_contours=None, nb_thin_contours=None
76
+ ):
65
77
  if nb_thin_contours is None:
66
78
  nb_thin_contours = NB_THIN_CONTOURS
67
79
  if nb_bold_contours is None:
@@ -73,28 +85,46 @@ def get_contour_intervals(zmin, zmax, nb_bold_contours=None,
73
85
  bold_intv = BOLD_CONTOURS_INTV[i]
74
86
  if BOLD_CONTOURS_INTV[i] != BOLD_CONTOURS_INTV[0]:
75
87
  if bold_intv - intv > 0:
76
- bold_intv = BOLD_CONTOURS_INTV[i-1]
88
+ bold_intv = BOLD_CONTOURS_INTV[i - 1]
77
89
 
78
90
  if nb_thin_contours is None:
79
91
  thin_intv = bold_intv / NB_THIN_CONTOURS
80
- if (zmax - zmin)/bold_intv > 5:
81
- thin_intv = thin_intv*2
92
+ if (zmax - zmin) / bold_intv > 5:
93
+ thin_intv = thin_intv * 2
82
94
  else:
83
95
  thin_intv = bold_intv / nb_thin_contours
84
96
 
85
97
  return bold_intv, thin_intv
86
98
 
87
99
 
88
- def plot_topo(z, x, y, contour_step=None, nlevels=None, level_min=None,
89
- step_contour_bold='auto', contour_labels_properties=None,
90
- label_contour=True, contour_label_effect=None,
91
- axe=None,
92
- vert_exag=1, fraction=1, ndv=-9999, uniform_grey=None,
93
- contours_prop=None, contours_bold_prop=None,
94
- figsize=None,
95
- interpolation=None,
96
- sea_level=0, sea_color=None, alpha=1, azdeg=315, altdeg=45,
97
- zmin=None, zmax=None):
100
+ def plot_topo(
101
+ z,
102
+ x,
103
+ y,
104
+ contour_step=None,
105
+ nlevels=None,
106
+ level_min=None,
107
+ step_contour_bold="auto",
108
+ contour_labels_properties=None,
109
+ label_contour=True,
110
+ contour_label_effect=None,
111
+ axe=None,
112
+ vert_exag=1,
113
+ fraction=1,
114
+ ndv=-9999,
115
+ uniform_grey=None,
116
+ contours_prop=None,
117
+ contours_bold_prop=None,
118
+ figsize=None,
119
+ interpolation=None,
120
+ sea_level=0,
121
+ sea_color=None,
122
+ alpha=1,
123
+ azdeg=315,
124
+ altdeg=45,
125
+ zmin=None,
126
+ zmax=None,
127
+ ):
98
128
  """
99
129
  Plot topography with hillshading.
100
130
 
@@ -150,21 +180,19 @@ def plot_topo(z, x, y, contour_step=None, nlevels=None, level_min=None,
150
180
  """
151
181
  dx = x[1] - x[0]
152
182
  dy = y[1] - y[0]
153
- im_extent = [x[0]-dx/2,
154
- x[-1]+dx/2,
155
- y[0]-dy/2,
156
- y[-1]+dy/2]
183
+ im_extent = [x[0] - dx / 2, x[-1] + dx / 2, y[0] - dy / 2, y[-1] + dy / 2]
157
184
  ls = mcolors.LightSource(azdeg=azdeg, altdeg=altdeg)
158
185
 
159
186
  auto_bold_intv = None
160
187
 
161
188
  if nlevels is None and contour_step is None:
162
- auto_bold_intv, contour_step = get_contour_intervals(np.nanmin(z),
163
- np.nanmax(z))
189
+ auto_bold_intv, contour_step = get_contour_intervals(
190
+ np.nanmin(z), np.nanmax(z)
191
+ )
164
192
 
165
193
  if level_min is None:
166
194
  if contour_step is not None:
167
- level_min = np.ceil(np.nanmin(z)/contour_step)*contour_step
195
+ level_min = np.ceil(np.nanmin(z) / contour_step) * contour_step
168
196
  else:
169
197
  level_min = np.nanmin(z)
170
198
  if contour_step is not None:
@@ -173,50 +201,63 @@ def plot_topo(z, x, y, contour_step=None, nlevels=None, level_min=None,
173
201
  levels = np.linspace(level_min, np.nanmax(z), nlevels)
174
202
 
175
203
  if axe is None:
176
- fig = plt.figure(figsize=figsize)
177
- axe = fig.gca()
178
- else:
179
- fig = axe.figure
180
- axe.set_ylabel('Y (m)')
181
- axe.set_xlabel('X (m)')
182
- axe.set_aspect('equal')
204
+ fig, axe = plt.subplots(1, 1, figsize=figsize, layout="constrained")
205
+
206
+ axe.set_ylabel("Y (m)")
207
+ axe.set_xlabel("X (m)")
208
+ axe.set_aspect("equal")
183
209
 
184
210
  if uniform_grey is None:
185
- shaded_topo = ls.hillshade(z,
186
- vert_exag=vert_exag, dx=dx, dy=dy,
187
- fraction=1)
211
+ shaded_topo = ls.hillshade(
212
+ z, vert_exag=vert_exag, dx=dx, dy=dy, fraction=1
213
+ )
188
214
  else:
189
- shaded_topo = np.ones(z.shape)*uniform_grey
215
+ shaded_topo = np.ones(z.shape) * uniform_grey
190
216
  shaded_topo[z == ndv] = np.nan
191
- axe.imshow(shaded_topo, cmap='gray', extent=im_extent,
192
- interpolation=interpolation, alpha=alpha, vmin=0, vmax=1)
217
+ axe.imshow(
218
+ shaded_topo,
219
+ cmap="gray",
220
+ extent=im_extent,
221
+ interpolation=interpolation,
222
+ alpha=alpha,
223
+ vmin=0,
224
+ vmax=1,
225
+ )
193
226
 
194
227
  if contours_prop is None:
195
- contours_prop = dict(alpha=0.5, colors='k',
196
- linewidths=0.5)
197
- axe.contour(x, y, np.flip(z, axis=0), extent=im_extent,
198
- levels=levels,
199
- **contours_prop)
228
+ contours_prop = dict(alpha=0.5, colors="k", linewidths=0.5)
229
+ axe.contour(
230
+ x,
231
+ y,
232
+ np.flip(z, axis=0),
233
+ extent=im_extent,
234
+ levels=levels,
235
+ **contours_prop
236
+ )
200
237
 
201
238
  if contours_bold_prop is None:
202
- contours_bold_prop = dict(alpha=0.8, colors='k',
203
- linewidths=0.8)
239
+ contours_bold_prop = dict(alpha=0.8, colors="k", linewidths=0.8)
204
240
 
205
- if step_contour_bold == 'auto':
241
+ if step_contour_bold == "auto":
206
242
  if auto_bold_intv is None:
207
- auto_bold_intv, _ = get_contour_intervals(np.nanmin(z),
208
- np.nanmax(z))
243
+ auto_bold_intv, _ = get_contour_intervals(
244
+ np.nanmin(z), np.nanmax(z)
245
+ )
209
246
  step_contour_bold = auto_bold_intv
210
247
 
211
248
  if step_contour_bold > 0:
212
- lmin = np.ceil(np.nanmin(z)/step_contour_bold)*step_contour_bold
249
+ lmin = np.ceil(np.nanmin(z) / step_contour_bold) * step_contour_bold
213
250
  if lmin < level_min:
214
251
  lmin = lmin + step_contour_bold
215
252
  levels = np.arange(lmin, np.nanmax(z), step_contour_bold)
216
- cs = axe.contour(x, y, np.flip(z, axis=0),
217
- extent=im_extent,
218
- levels=levels,
219
- **contours_bold_prop)
253
+ cs = axe.contour(
254
+ x,
255
+ y,
256
+ np.flip(z, axis=0),
257
+ extent=im_extent,
258
+ levels=levels,
259
+ **contours_bold_prop
260
+ )
220
261
  if label_contour:
221
262
  if contour_labels_properties is None:
222
263
  contour_labels_properties = {}
@@ -226,30 +267,79 @@ def plot_topo(z, x, y, contour_step=None, nlevels=None, level_min=None,
226
267
 
227
268
  if sea_color is not None:
228
269
  cmap_sea = mcolors.ListedColormap([sea_color])
229
- cmap_sea.set_under(color='w', alpha=0)
230
- mask_sea = (z <= sea_level)*1
270
+ cmap_sea.set_under(color="w", alpha=0)
271
+ mask_sea = (z <= sea_level) * 1
231
272
  if mask_sea.any():
232
- axe.imshow(mask_sea, extent=im_extent, cmap=cmap_sea,
233
- vmin=0.5, origin='lower', interpolation='none')
234
-
235
-
236
- def plot_data_on_topo(x, y, z, data, axe=None, figsize=(15/2.54, 15/2.54),
237
- cmap=None,
238
- minval=None, maxval=None, vmin=None, vmax=None,
239
- minval_abs=None,
240
- cmap_intervals=None, extend_cc='max',
241
- topo_kwargs=None, sup_plot=None, alpha=1,
242
- plot_colorbar=True, axecc=None, colorbar_kwargs=None,
243
- mask=None, alpha_mask=None, color_mask='k',
244
- xlims=None, ylims=None):
273
+ axe.imshow(
274
+ mask_sea,
275
+ extent=im_extent,
276
+ cmap=cmap_sea,
277
+ vmin=0.5,
278
+ origin="lower",
279
+ interpolation="none",
280
+ )
281
+
282
+
283
+ def plot_imshow(
284
+ x,
285
+ y,
286
+ data,
287
+ axe=None,
288
+ figsize=None,
289
+ cmap=None,
290
+ minval=None,
291
+ maxval=None,
292
+ vmin=None,
293
+ vmax=None,
294
+ alpha=1,
295
+ minval_abs=None,
296
+ cmap_intervals=None,
297
+ extend_cc="max",
298
+ plot_colorbar=True,
299
+ axecc=None,
300
+ colorbar_kwargs=None,
301
+ aspect=None,
302
+ ):
245
303
  """
246
- Plot array data on topo.
304
+ plt.imshow data with some pre-processing
305
+
306
+ Parameters
307
+ ----------
308
+ x : TYPE
309
+ DESCRIPTION.
310
+ y : TYPE
311
+ DESCRIPTION.
312
+ data : TYPE
313
+ DESCRIPTION.
314
+ axe : TYPE, optional
315
+ DESCRIPTION. The default is None.
316
+ figsize : TYPE, optional
317
+ DESCRIPTION. The default is None.
318
+ cmap : TYPE, optional
319
+ DESCRIPTION. The default is None.
320
+ minval : TYPE, optional
321
+ DESCRIPTION. The default is None.
322
+ maxval : TYPE, optional
323
+ DESCRIPTION. The default is None.
324
+ vmin : TYPE, optional
325
+ DESCRIPTION. The default is None.
326
+ vmax : TYPE, optional
327
+ DESCRIPTION. The default is None.
328
+ minval_abs : TYPE, optional
329
+ DESCRIPTION. The default is None.
330
+ cmap_intervals : TYPE, optional
331
+ DESCRIPTION. The default is None.
332
+ extend_cc : TYPE, optional
333
+ DESCRIPTION. The default is "max".
247
334
 
248
335
  Returns
249
336
  -------
250
337
  None.
251
338
 
252
339
  """
340
+ if axe is None:
341
+ _, axe = plt.subplots(1, 1, figsize=figsize, layout="constrained")
342
+
253
343
  f = copy.copy(data)
254
344
 
255
345
  # vmin and vmax are similar to minval and maxval
@@ -263,12 +353,13 @@ def plot_data_on_topo(x, y, z, data, axe=None, figsize=(15/2.54, 15/2.54),
263
353
  # Remove values below and above minval and maxval, depending on whether
264
354
  # cmap_intervals are given with or without extend_cc
265
355
  if cmap_intervals is not None:
266
- norm = matplotlib.colors.BoundaryNorm(cmap_intervals, 256,
267
- extend=extend_cc)
268
- if extend_cc in ['neither', 'max']:
356
+ norm = matplotlib.colors.BoundaryNorm(
357
+ cmap_intervals, 256, extend=extend_cc
358
+ )
359
+ if extend_cc in ["neither", "max"]:
269
360
  minval = cmap_intervals[0]
270
361
  f[f < minval] = np.nan
271
- elif extend_cc in ['neither', 'min']:
362
+ elif extend_cc in ["neither", "min"]:
272
363
  maxval = cmap_intervals[-1]
273
364
  f[f > maxval] = np.nan
274
365
  else:
@@ -291,61 +382,140 @@ def plot_data_on_topo(x, y, z, data, axe=None, figsize=(15/2.54, 15/2.54),
291
382
 
292
383
  # Define colormap type
293
384
  if cmap is None:
294
- if maxval*minval >= 0:
295
- cmap = 'hot_r'
385
+ if maxval * minval >= 0:
386
+ cmap = "hot_r"
296
387
  else:
297
- cmap = 'seismic'
298
- if maxval*minval >= 0:
299
- color_map = matplotlib.cm.get_cmap(cmap).copy()
388
+ cmap = "seismic"
389
+ if maxval * minval >= 0:
390
+ color_map = matplotlib.colormaps[cmap]
300
391
  else:
301
392
  color_map = centered_map(cmap, minval, maxval)
302
393
 
303
394
  if cmap_intervals is not None:
304
- norm = matplotlib.colors.BoundaryNorm(cmap_intervals, 256,
305
- extend=extend_cc)
395
+ norm = matplotlib.colors.BoundaryNorm(
396
+ cmap_intervals, 256, extend=extend_cc
397
+ )
306
398
  maxval = None
307
399
  minval = None
308
400
  else:
309
401
  norm = None
310
- # color_map.set_under([1, 1, 1], alpha=0)
402
+
403
+ # get map_extent
404
+ dx = x[1] - x[0]
405
+ dy = y[1] - y[0]
406
+ im_extent = [x[0] - dx / 2, x[-1] + dx / 2, y[0] - dy / 2, y[-1] + dy / 2]
407
+
408
+ # Plot data
409
+
410
+ fim = axe.imshow(
411
+ f,
412
+ extent=im_extent,
413
+ cmap=color_map,
414
+ vmin=minval,
415
+ vmax=maxval,
416
+ alpha=alpha,
417
+ interpolation="none",
418
+ norm=norm,
419
+ zorder=4,
420
+ aspect=aspect,
421
+ )
422
+
423
+ # Plot colorbar
424
+ if plot_colorbar:
425
+ colorbar_kwargs = {} if colorbar_kwargs is None else colorbar_kwargs
426
+ if cmap_intervals is not None and extend_cc is not None:
427
+ colorbar_kwargs["extend"] = extend_cc
428
+ axe.figure.colorbar(fim, cax=axecc, **colorbar_kwargs)
429
+
430
+ return axe
431
+
432
+
433
+ def plot_data_on_topo(
434
+ x,
435
+ y,
436
+ z,
437
+ data,
438
+ axe=None,
439
+ figsize=(15 / 2.54, 15 / 2.54),
440
+ cmap=None,
441
+ minval=None,
442
+ maxval=None,
443
+ vmin=None,
444
+ vmax=None,
445
+ minval_abs=None,
446
+ cmap_intervals=None,
447
+ extend_cc="max",
448
+ topo_kwargs=None,
449
+ sup_plot=None,
450
+ alpha=1,
451
+ plot_colorbar=True,
452
+ axecc=None,
453
+ colorbar_kwargs=None,
454
+ mask=None,
455
+ alpha_mask=None,
456
+ color_mask="k",
457
+ xlims=None,
458
+ ylims=None,
459
+ ):
460
+ """
461
+ Plot array data on topo.
462
+
463
+ Returns
464
+ -------
465
+ None.
466
+
467
+ """
311
468
 
312
469
  # Initialize figure properties
313
- dx = x[1]-x[0]
314
- dy = y[1]-y[0]
315
- im_extent = [x[0]-dx/2, x[-1]+dx/2, y[0]-dy/2, y[-1]+dy/2]
470
+ dx = x[1] - x[0]
471
+ dy = y[1] - y[0]
472
+ im_extent = [x[0] - dx / 2, x[-1] + dx / 2, y[0] - dy / 2, y[-1] + dy / 2]
316
473
  if axe is None:
317
- fig = plt.figure(figsize=figsize)
318
- axe = fig.gca()
319
- else:
320
- fig = axe.figure
321
- axe.set_ylabel('Y (m)')
322
- axe.set_xlabel('X (m)')
323
- axe.set_aspect('equal', adjustable='box')
474
+ fig, axe = plt.subplots(1, 1, figsize=figsize, layout="constrained")
475
+
476
+ axe.set_ylabel("Y (m)")
477
+ axe.set_xlabel("X (m)")
478
+ axe.set_aspect("equal", adjustable="box")
324
479
 
325
480
  # Plot topo
326
481
  topo_kwargs = {} if topo_kwargs is None else topo_kwargs
327
482
 
328
- plot_topo(z, x, y, axe=axe, **topo_kwargs)
483
+ if z is not None:
484
+ plot_topo(z, x, y, axe=axe, **topo_kwargs)
329
485
 
330
486
  # Plot mask
331
487
  if mask is not None:
332
488
  cmap_mask = mcolors.ListedColormap([color_mask])
333
- cmap_mask.set_under(color='w', alpha=0)
334
- axe.imshow(mask.transpose(), extent=im_extent, cmap=cmap_mask,
335
- vmin=0.5, origin='lower', interpolation='none',
336
- zorder=3, alpha=alpha_mask)
489
+ cmap_mask.set_under(color="w", alpha=0)
490
+ axe.imshow(
491
+ mask.transpose(),
492
+ extent=im_extent,
493
+ cmap=cmap_mask,
494
+ vmin=0.5,
495
+ origin="lower",
496
+ interpolation="none",
497
+ # zorder=3,
498
+ alpha=alpha_mask,
499
+ )
337
500
 
338
501
  # Plot data
339
- fim = axe.imshow(f, extent=im_extent, cmap=color_map,
340
- vmin=minval, vmax=maxval, alpha=alpha,
341
- interpolation='none', norm=norm, zorder=4)
342
-
343
- # Plot colorbar
344
- if plot_colorbar:
345
- colorbar_kwargs = {} if colorbar_kwargs is None else colorbar_kwargs
346
- if cmap_intervals is not None and extend_cc is not None:
347
- colorbar_kwargs['extend'] = extend_cc
348
- colorbar(fim, cax=axecc, **colorbar_kwargs)
502
+ plot_imshow(
503
+ x,
504
+ y,
505
+ data,
506
+ axe=axe,
507
+ cmap=cmap,
508
+ minval=minval,
509
+ maxval=maxval,
510
+ vmin=vmin,
511
+ vmax=vmax,
512
+ minval_abs=minval_abs,
513
+ cmap_intervals=cmap_intervals,
514
+ extend_cc=extend_cc,
515
+ plot_colorbar=plot_colorbar,
516
+ axecc=axecc,
517
+ colorbar_kwargs=colorbar_kwargs,
518
+ )
349
519
 
350
520
  # Adjust axes limits
351
521
  if xlims is not None:
@@ -357,11 +527,59 @@ def plot_data_on_topo(x, y, z, data, axe=None, figsize=(15/2.54, 15/2.54),
357
527
  return axe
358
528
 
359
529
 
360
- def plot_maps(x, y, z, data, t, file_name, folder_out=None,
361
- figsize=None, dpi=None, fmt='png',
362
- sup_plt_fn=None,
363
- sup_plt_fn_args=None,
364
- **kwargs):
530
+ def plot_shotgather(x, t, data, xlabel="X (m)", ylabel=None, **kwargs):
531
+ """
532
+ Plot shotgather like image, with vertical axis as time and horizontal axis
533
+ and spatial dimension. This is a simple call to plot_shotgather, but
534
+ input data is transposed because in tilupy the last axis is time by
535
+ convention.
536
+
537
+ Parameters
538
+ ----------
539
+ x : NX-array
540
+ spatial coordinates
541
+ t : NT-array
542
+ time array (assumed in seconds)
543
+ data : TYPE
544
+ NX*NT array of data to be plotted
545
+ spatial_label : string, optional
546
+ label for y-axis. The default is "X (m)"
547
+ **kwargs : dict, optional
548
+ parameters passed on to plot_imshow
549
+
550
+ Returns
551
+ -------
552
+ axe : Axes
553
+ Axes instance where data is plotted
554
+
555
+ """
556
+ if "aspect" not in kwargs:
557
+ kwargs["aspect"] = "auto"
558
+ axe = plot_imshow(x, t[::-1], data.T, **kwargs)
559
+ axe.set_adjustable("box")
560
+ if ylabel is None:
561
+ ylabel = "Time (s)"
562
+ axe.set_ylabel(ylabel)
563
+ axe.set_xlabel(xlabel)
564
+
565
+ return axe
566
+
567
+
568
+ def plot_maps(
569
+ x,
570
+ y,
571
+ z,
572
+ data,
573
+ t,
574
+ file_name=None,
575
+ folder_out=None,
576
+ figsize=None,
577
+ dpi=None,
578
+ fmt="png",
579
+ sup_plt_fn=None,
580
+ sup_plt_fn_args=None,
581
+ **kwargs
582
+ ):
365
583
  """
366
584
  Plot and save maps of simulations outputs at successive time steps
367
585
 
@@ -393,29 +611,34 @@ def plot_maps(x, y, z, data, t, file_name, folder_out=None,
393
611
  nfigs = len(t)
394
612
  if nfigs != data.shape[2]:
395
613
  raise ValueError(
396
- 'length of t must be similar to the last dimension of data')
614
+ "length of t must be similar to the last dimension of data"
615
+ )
397
616
  if folder_out is not None:
398
- file_path = os.path.join(folder_out, file_name + '_{:04d}.' + fmt)
399
- title_fmt = 't = {:.2f} s'
617
+ file_path = os.path.join(folder_out, file_name + "_{:04d}." + fmt)
618
+ title_fmt = "t = {:.2f} s"
400
619
 
401
620
  for i in range(nfigs):
402
- axe = plot_data_on_topo(x, y, z, data[:, :, i], axe=None,
403
- figsize=figsize,
404
- **kwargs)
621
+ axe = plot_data_on_topo(
622
+ x, y, z, data[:, :, i], axe=None, figsize=figsize, **kwargs
623
+ )
405
624
  axe.set_title(title_fmt.format(t[i]))
406
625
  if sup_plt_fn is not None:
407
626
  if sup_plt_fn_args is None:
408
627
  sup_plt_fn_args = dict()
409
628
  sup_plt_fn(axe, **sup_plt_fn_args)
410
- axe.figure.tight_layout(pad=0.1)
629
+ # axe.figure.tight_layout(pad=0.1)
411
630
  if folder_out is not None:
412
- axe.figure.savefig(file_path.format(i), dpi=dpi,
413
- bbox_inches='tight', pad_inches=0.05)
631
+ axe.figure.savefig(
632
+ file_path.format(i),
633
+ dpi=dpi,
634
+ bbox_inches="tight",
635
+ pad_inches=0.05,
636
+ )
414
637
 
415
638
 
416
- def colorbar(mappable, ax=None,
417
- cax=None, size="5%", pad=0.1, position='right',
418
- **kwargs):
639
+ def colorbar(
640
+ mappable, ax=None, cax=None, size="5%", pad=0.1, position="right", **kwargs
641
+ ):
419
642
  """
420
643
  Create nice colorbar matching height/width of axe.
421
644
 
@@ -443,45 +666,54 @@ def colorbar(mappable, ax=None,
443
666
  if ax is None:
444
667
  ax = mappable.axes
445
668
  fig = ax.figure
446
- if position in ['left', 'right']:
447
- orientation = 'vertical'
669
+ if position in ["left", "right"]:
670
+ orientation = "vertical"
448
671
  else:
449
- orientation = 'horizontal'
450
- if cax is None:
451
- # divider = ax.get_axes_locator()
452
- # if divider is None:
453
- divider = make_axes_locatable(ax)
454
- cax = divider.append_axes(position, size=size, pad=pad)
672
+ orientation = "horizontal"
455
673
 
456
- cc = fig.colorbar(mappable, cax=cax,
457
- orientation=orientation, **kwargs)
674
+ # if cax is None:
675
+ # # divider = ax.get_axes_locator()
676
+ # # if divider is None:
677
+ # divider = make_axes_locatable(ax)
678
+ # cax = divider.append_axes(position, size=size, pad=pad)
458
679
 
459
- if position == 'top':
680
+ cc = fig.colorbar(mappable, cax=cax, orientation=orientation, **kwargs)
681
+
682
+ if position == "top":
460
683
  cax.xaxis.tick_top()
461
- cax.xaxis.set_label_position('top')
462
- if position == 'left':
684
+ cax.xaxis.set_label_position("top")
685
+ if position == "left":
463
686
  cax.yaxis.tick_left()
464
- cax.xaxis.set_label_position('left')
687
+ cax.xaxis.set_label_position("left")
465
688
  return cc
466
689
 
467
690
 
468
- def plot_heatmaps(df, values, index, columns, aggfunc='mean',
469
- figsize=None, ncols=3,
470
- heatmap_kws=None, notations=None, best_values=None,
471
- plot_best_value='point', text_kwargs=None):
472
-
691
+ def plot_heatmaps(
692
+ df,
693
+ values,
694
+ index,
695
+ columns,
696
+ aggfunc="mean",
697
+ figsize=None,
698
+ ncols=3,
699
+ heatmap_kws=None,
700
+ notations=None,
701
+ best_values=None,
702
+ plot_best_value="point",
703
+ text_kwargs=None,
704
+ ):
473
705
  nplots = len(values)
474
706
  ncols = min(nplots, ncols)
475
- nrows = int(np.ceil(nplots/ncols))
707
+ nrows = int(np.ceil(nplots / ncols))
476
708
  fig = plt.figure(figsize=figsize)
477
709
  axes = []
478
710
 
479
711
  for i in range(nplots):
480
- axe = fig.add_subplot(nrows, ncols, i+1)
712
+ axe = fig.add_subplot(nrows, ncols, i + 1)
481
713
  axes.append(axe)
482
- data = df.pivot_table(index=index, columns=columns,
483
- values=values[i],
484
- aggfunc=aggfunc).astype(float)
714
+ data = df.pivot_table(
715
+ index=index, columns=columns, values=values[i], aggfunc=aggfunc
716
+ ).astype(float)
485
717
  if heatmap_kws is None:
486
718
  kws = dict()
487
719
  elif isinstance(heatmap_kws, dict):
@@ -490,25 +722,25 @@ def plot_heatmaps(df, values, index, columns, aggfunc='mean',
490
722
  else:
491
723
  kws = heatmap_kws
492
724
 
493
- if 'cmap' not in kws:
725
+ if "cmap" not in kws:
494
726
  minval = data.min().min()
495
727
  maxval = data.max().max()
496
- if minval*maxval < 0:
728
+ if minval * maxval < 0:
497
729
  val = max(np.abs(minval), maxval)
498
- kws['cmap'] = 'seismic'
499
- kws['vmin'] = -val
500
- kws['vmax'] = val
730
+ kws["cmap"] = "seismic"
731
+ kws["vmin"] = -val
732
+ kws["vmax"] = val
501
733
 
502
- if 'cbar_kws' not in kws:
503
- kws['cbar_kws'] = dict(pad=0.03)
734
+ if "cbar_kws" not in kws:
735
+ kws["cbar_kws"] = dict(pad=0.03)
504
736
 
505
737
  if notations is None:
506
- kws['cbar_kws']['label'] = values[i]
738
+ kws["cbar_kws"]["label"] = values[i]
507
739
  else:
508
740
  if values[i] in notations:
509
- kws['cbar_kws']['label'] = notations[values[i]]
741
+ kws["cbar_kws"]["label"] = notations[values[i]]
510
742
  else:
511
- kws['cbar_kws']['label'] = values[i]
743
+ kws["cbar_kws"]["label"] = values[i]
512
744
 
513
745
  sns.heatmap(data, ax=axe, **kws)
514
746
 
@@ -517,50 +749,71 @@ def plot_heatmaps(df, values, index, columns, aggfunc='mean',
517
749
  array = np.array(data)
518
750
  irow = np.arange(data.shape[0])
519
751
 
520
- if best_value == 'min':
752
+ if best_value == "min":
521
753
  ind = np.nanargmin(array, axis=1)
522
754
  i2 = np.nanargmin(array[irow, ind])
523
- if best_value == 'min_abs':
755
+ if best_value == "min_abs":
524
756
  ind = np.nanargmin(np.abs(array), axis=1)
525
757
  i2 = np.nanargmin(np.abs(array[irow, ind]))
526
- elif best_value == 'max':
758
+ elif best_value == "max":
527
759
  ind = np.nanargmax(array, axis=1)
528
760
  i2 = np.nanargmax(array[irow, ind])
529
761
 
530
- if plot_best_value == 'point':
531
- axe.plot(ind + 0.5, irow+0.5, ls='',
532
- marker='o', mfc='w', mec='k', mew=0.5, ms=5)
533
- axe.plot(ind[i2] + 0.5, i2+0.5, ls='',
534
- marker='o', mfc='w', mec='k', mew=0.8, ms=9)
535
- elif plot_best_value == 'text':
762
+ if plot_best_value == "point":
763
+ axe.plot(
764
+ ind + 0.5,
765
+ irow + 0.5,
766
+ ls="",
767
+ marker="o",
768
+ mfc="w",
769
+ mec="k",
770
+ mew=0.5,
771
+ ms=5,
772
+ )
773
+ axe.plot(
774
+ ind[i2] + 0.5,
775
+ i2 + 0.5,
776
+ ls="",
777
+ marker="o",
778
+ mfc="w",
779
+ mec="k",
780
+ mew=0.8,
781
+ ms=9,
782
+ )
783
+ elif plot_best_value == "text":
536
784
  indx = list(ind)
537
785
  indx.pop(i2)
538
786
  indy = list(irow)
539
787
  indy.pop(i2)
540
- default_kwargs = dict(ha='center', va='center',
541
- fontsize=8)
788
+ default_kwargs = dict(ha="center", va="center", fontsize=8)
542
789
  if text_kwargs is None:
543
790
  text_kwargs = default_kwargs
544
791
  else:
545
792
  text_kwargs = dict(default_kwargs, **text_kwargs)
546
793
  for i, j in zip(indx, indy):
547
- axe.text(i + 0.5, j+0.5,
548
- '{:.2g}'.format(array[j, i]),
549
- **text_kwargs)
550
- text_kwargs2 = dict(text_kwargs, fontweight='bold')
551
- axe.text(ind[i2] + 0.5, i2+0.5,
552
- '{:.2g}'.format(array[i2, ind[i2]]),
553
- **text_kwargs2)
794
+ axe.text(
795
+ i + 0.5,
796
+ j + 0.5,
797
+ "{:.2g}".format(array[j, i]),
798
+ **text_kwargs
799
+ )
800
+ text_kwargs2 = dict(text_kwargs, fontweight="bold")
801
+ axe.text(
802
+ ind[i2] + 0.5,
803
+ i2 + 0.5,
804
+ "{:.2g}".format(array[i2, ind[i2]]),
805
+ **text_kwargs2
806
+ )
554
807
 
555
808
  axes = np.array(axes).reshape((nrows, ncols))
556
809
  for i in range(nrows):
557
810
  for j in range(1, ncols):
558
- axes[i, j].set_ylabel('')
811
+ axes[i, j].set_ylabel("")
559
812
  # axes[i, j].set_yticklabels([])
560
813
 
561
- for i in range(nrows-1):
814
+ for i in range(nrows - 1):
562
815
  for j in range(ncols):
563
- axes[i, j].set_xlabel('')
816
+ axes[i, j].set_xlabel("")
564
817
  # axes[i, j].set_xticklabels([])
565
818
 
566
819
  if notations is not None:
@@ -569,6 +822,6 @@ def plot_heatmaps(df, values, index, columns, aggfunc='mean',
569
822
  for j in range(ncols):
570
823
  axes[-1, j].set_xlabel(notations[columns])
571
824
 
572
- fig.tight_layout()
825
+ # fig.tight_layout()
573
826
 
574
827
  return fig