smashbox 1.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. smashbox/.spyproject/config/backups/codestyle.ini.bak +8 -0
  2. smashbox/.spyproject/config/backups/encoding.ini.bak +6 -0
  3. smashbox/.spyproject/config/backups/vcs.ini.bak +7 -0
  4. smashbox/.spyproject/config/backups/workspace.ini.bak +12 -0
  5. smashbox/.spyproject/config/codestyle.ini +8 -0
  6. smashbox/.spyproject/config/defaults/defaults-codestyle-0.2.0.ini +5 -0
  7. smashbox/.spyproject/config/defaults/defaults-encoding-0.2.0.ini +3 -0
  8. smashbox/.spyproject/config/defaults/defaults-vcs-0.2.0.ini +4 -0
  9. smashbox/.spyproject/config/defaults/defaults-workspace-0.2.0.ini +6 -0
  10. smashbox/.spyproject/config/encoding.ini +6 -0
  11. smashbox/.spyproject/config/vcs.ini +7 -0
  12. smashbox/.spyproject/config/workspace.ini +12 -0
  13. smashbox/__init__.py +8 -0
  14. smashbox/asset/flwdir/flowdir_fr_1000m.tif +0 -0
  15. smashbox/asset/outlets/.Rhistory +0 -0
  16. smashbox/asset/outlets/db_bnbv_fr.csv +142704 -0
  17. smashbox/asset/outlets/db_bnbv_light.csv +42084 -0
  18. smashbox/asset/outlets/db_sites.csv +8700 -0
  19. smashbox/asset/outlets/db_stations.csv +2916 -0
  20. smashbox/asset/outlets/db_stations_example.csv +19 -0
  21. smashbox/asset/outlets/edit_database.py +185 -0
  22. smashbox/asset/outlets/readme.txt +5 -0
  23. smashbox/asset/params/ci.tif +0 -0
  24. smashbox/asset/params/cp.tif +0 -0
  25. smashbox/asset/params/ct.tif +0 -0
  26. smashbox/asset/params/kexc.tif +0 -0
  27. smashbox/asset/params/kmlt.tif +0 -0
  28. smashbox/asset/params/llr.tif +0 -0
  29. smashbox/asset/setup/setup_rhax_gr4_dt3600.yaml +15 -0
  30. smashbox/asset/setup/setup_rhax_gr4_dt900.yaml +15 -0
  31. smashbox/asset/setup/setup_rhax_gr5_dt3600.yaml +15 -0
  32. smashbox/asset/setup/setup_rhax_gr5_dt900.yaml +15 -0
  33. smashbox/init/README.md +3 -0
  34. smashbox/init/__init__.py +3 -0
  35. smashbox/init/multimodel_statistics.py +405 -0
  36. smashbox/init/param.py +799 -0
  37. smashbox/init/smashbox.py +186 -0
  38. smashbox/model/__init__.py +1 -0
  39. smashbox/model/atmos_data_connector.py +518 -0
  40. smashbox/model/mesh.py +185 -0
  41. smashbox/model/model.py +829 -0
  42. smashbox/model/setup.py +109 -0
  43. smashbox/plot/__init__.py +1 -0
  44. smashbox/plot/myplot.py +1133 -0
  45. smashbox/plot/plot.py +1662 -0
  46. smashbox/read_inputdata/__init__.py +1 -0
  47. smashbox/read_inputdata/read_data.py +1229 -0
  48. smashbox/read_inputdata/smashmodel.py +395 -0
  49. smashbox/stats/__init__.py +1 -0
  50. smashbox/stats/mystats.py +1632 -0
  51. smashbox/stats/stats.py +2022 -0
  52. smashbox/test.py +532 -0
  53. smashbox/test_average_stats.py +122 -0
  54. smashbox/test_mesh.r +8 -0
  55. smashbox/test_mesh_from_graffas.py +69 -0
  56. smashbox/tools/__init__.py +1 -0
  57. smashbox/tools/geo_toolbox.py +1028 -0
  58. smashbox/tools/tools.py +461 -0
  59. smashbox/tutorial_R.r +182 -0
  60. smashbox/tutorial_R_graffas.r +88 -0
  61. smashbox/tutorial_R_graffas_local.r +33 -0
  62. smashbox/tutorial_python.py +102 -0
  63. smashbox/tutorial_readme.py +261 -0
  64. smashbox/tutorial_report.py +58 -0
  65. smashbox/tutorials/Python_tutorial.md +124 -0
  66. smashbox/tutorials/R_Graffas_tutorial.md +153 -0
  67. smashbox/tutorials/R_tutorial.md +121 -0
  68. smashbox/tutorials/__init__.py +6 -0
  69. smashbox/tutorials/generate_doc.md +7 -0
  70. smashbox-1.0.dist-info/METADATA +998 -0
  71. smashbox-1.0.dist-info/RECORD +73 -0
  72. smashbox-1.0.dist-info/WHEEL +5 -0
  73. smashbox-1.0.dist-info/licenses/LICENSE +100 -0
smashbox/plot/plot.py ADDED
@@ -0,0 +1,1662 @@
1
+ # -*- coding: utf-8 -*-
2
+ import matplotlib.pyplot as plt
3
+ import numpy as np
4
+ import matplotlib
5
+ from matplotlib.colors import ListedColormap
6
+ import pandas as pd
7
+ from pandas import DataFrame
8
+ from smashbox.stats import stats
9
+ from smashbox.tools import geo_toolbox
10
+ import datetime
11
+ from smash import Model
12
+
13
+ import matplotlib.colors as mcolors
14
+ from matplotlib import cm
15
+ import colorsys
16
+ from mpl_toolkits.axes_grid1 import make_axes_locatable
17
+
18
+ import os
19
+ import pandas as pd
20
+ from smashbox.init.param import param
21
+
22
+ from smashbox.tools import tools
23
+
24
+
25
class plot_properties:
    """Line-style settings for smashbox curves.

    Each attribute carries the matplotlib ``Axes.plot`` keyword of the same name.
    An instance (or a dict with the same keys) can be passed to any smashbox plot
    function; the user may set every attribute.
    """

    def __init__(
        self,
        ls="-",
        lw=1.5,
        marker="",
        markersize=4,
        color="black",
        label="",
    ):
        # line style, e.g. "-", "--", ":" (matplotlib ``ls``)
        self.ls = ls
        # line width (matplotlib ``lw``)
        self.lw = lw
        # marker style (matplotlib ``marker``)
        self.marker = marker
        # marker size (matplotlib ``markersize``)
        self.markersize = markersize
        # line color (matplotlib ``color``)
        self.color = color
        # legend label of the line (matplotlib ``label``)
        self.label = label

    def update(self, **kwargs):
        """Store every keyword argument as an attribute; unknown names are added as-is."""
        for name in kwargs:
            setattr(self, name, kwargs[name])
57
+
58
+
59
+ class ax_properties:
60
+ """Class which handle differents properties of the matplotlib ax object
61
+ (see matplotlib documentation). All attributes can be defined by the user and the
62
+ object ax_properties can be passed to any smashbox plot function.
63
+ """
64
+
65
+ def __init__(
66
+ self,
67
+ title: str = None,
68
+ xlabel: str = None,
69
+ ylabel: str = None,
70
+ clabel: str = None,
71
+ font_ratio: int = 1,
72
+ title_fontsize: int = 12,
73
+ label_fontsize: int = 10,
74
+ annotate_fontsize: int = 6,
75
+ grid: bool = True,
76
+ xscale: str | None = None,
77
+ yscale: str | None = None,
78
+ legend: bool = True,
79
+ legend_loc: str = None,
80
+ legend_fontsize: int = 8,
81
+ xtics_fontsize: int = 8,
82
+ ytics_fontsize: int = 8,
83
+ cmap: str | None = None,
84
+ xlim: tuple | list | None = (None, None),
85
+ ylim: tuple | list | None = (None, None),
86
+ xticklabels_rotation: int = 0,
87
+ barlabel_fontsize=6,
88
+ ):
89
+
90
+ self.title = title
91
+ """The title of the graphic"""
92
+ self.xlabel = xlabel
93
+ """The label of the x axis"""
94
+ self.ylabel = ylabel
95
+ """The label of the y axis"""
96
+ self.clabel = clabel
97
+ """The label of the colorbar"""
98
+ self.font_ratio = font_ratio
99
+ """Ratio of the global fontsize"""
100
+ self.title_fontsize = title_fontsize
101
+ """The fontsize of the title"""
102
+ self.label_fontsize = label_fontsize
103
+ """The label fontsize"""
104
+ self.annotate_fontsize = annotate_fontsize
105
+ """The annotation fontsize in plot"""
106
+ self.grid = grid
107
+ """Set the grid (boolean), default True"""
108
+ self.xscale = xscale
109
+ """Scale of the x axis (see matplotlib documentation)"""
110
+ self.yscale = yscale
111
+ """Scale of the x axis (see matplotlib documentation)"""
112
+ self.legend = legend
113
+ """Set the legend, boolean, default is True"""
114
+ self.legend_loc = legend_loc
115
+ """Localisation of the legend"""
116
+ self.legend_fontsize = legend_fontsize
117
+ """The fontsize of the legend"""
118
+ self.xtics_fontsize = xtics_fontsize
119
+ """The fontsize of the xtics"""
120
+ self.ytics_fontsize = ytics_fontsize
121
+ """The fontsize of the ytics"""
122
+ self.cmap = cmap
123
+ """The name of the used colormap"""
124
+ self.xlim = xlim
125
+ """The limit of the x axis, tuple or list, default is (None,None)"""
126
+ self.ylim = ylim
127
+ """The limit of the y axis, tuple or list, default is (None,None)"""
128
+ self.xticklabels_rotation = xticklabels_rotation
129
+ """Angle of the xtics labels, float, default is 0."""
130
+ self.barlabel_fontsize = barlabel_fontsize * self.font_ratio
131
+ "Fontsize of the bar top label"
132
+
133
+ def update(self, **kwargs):
134
+ """Update the class attributes using kwarg (dictionnary)"""
135
+ for key, values in kwargs.items():
136
+ setattr(self, key, values)
137
+
138
+ def change(self, figure):
139
+ """Apply change to the current figure `figure`
140
+ Parameter:
141
+ ----------
142
+ figure : list or tuple
143
+ list of (fig, ax), figure and ax of a matplotlib subplot.
144
+ Return:
145
+ -------
146
+ a tuple of the modified (fig,ax)
147
+ """
148
+ fig, ax = figure
149
+
150
+ plt.rcParams.update(plt.rcParamsDefault)
151
+
152
+ plt.rcParams.update(
153
+ {
154
+ "axes.labelsize": self.label_fontsize * self.font_ratio,
155
+ "axes.titlesize": self.title_fontsize * self.font_ratio,
156
+ "legend.fontsize": self.legend_fontsize * self.font_ratio,
157
+ "figure.titlesize": self.title_fontsize * self.font_ratio,
158
+ "xtick.labelsize": self.xtics_fontsize * self.font_ratio,
159
+ "ytick.labelsize": self.ytics_fontsize * self.font_ratio,
160
+ }
161
+ )
162
+
163
+ if self.title is not None:
164
+
165
+ ax.set_title(self.title, fontsize=self.title_fontsize * self.font_ratio)
166
+
167
+ if self.xlabel is not None:
168
+ ax.set_xlabel(self.xlabel, fontsize=self.label_fontsize * self.font_ratio)
169
+
170
+ if self.ylabel is not None:
171
+ ax.set_ylabel(self.ylabel, fontsize=self.label_fontsize * self.font_ratio)
172
+
173
+ if self.grid:
174
+ ax.grid(True, which="both", linestyle="--", alpha=0.5)
175
+
176
+ if self.xscale is not None:
177
+ ax.set_xscale(self.xscale)
178
+
179
+ if self.yscale is not None:
180
+ ax.set_xyscale(self.yscale)
181
+
182
+ if self.legend:
183
+ ax.legend(
184
+ loc=self.legend_loc, fontsize=self.legend_fontsize * self.font_ratio
185
+ )
186
+
187
+ if self.cmap is not None:
188
+ plt.rc("image", cmap=self.cmap)
189
+
190
+ if self.ylim[0] is not None:
191
+ ax.set_ylim(bottom=self.ylim[0])
192
+
193
+ if self.ylim[1] is not None:
194
+ ax.set_ylim(top=self.ylim[1])
195
+
196
+ if self.xlim[0] is not None:
197
+ ax.set_xlim(left=self.xlim[0])
198
+
199
+ if self.xlim[1] is not None:
200
+ ax.set_xlim(right=self.xlim[1])
201
+
202
+ if self.xticklabels_rotation > 0:
203
+ ax.set_xticklabels(
204
+ ax.get_xticklabels(), rotation=self.xticklabels_rotation, ha="right"
205
+ )
206
+
207
+ return fig, ax
208
+
209
+
210
class fig_properties:
    """Class which handles different properties of the matplotlib fig object
    (see matplotlib documentation). All attributes can be defined by the user and the
    object fig_properties can be passed to any smashbox plot function.
    """

    def __init__(
        self,
        figname=None,
        xsize=8,
        ysize=6,
        transparent=False,
        dpi=160,
        font_ratio=1,
        bbox_inches="tight",
    ):
        # Path of the file the figure is saved to (None: the figure is not saved)
        self.figname = figname
        # Width of the figure in inch
        self.xsize = xsize
        # Height of the figure in inch
        self.ysize = ysize
        # Use transparency when exporting the figure, default is False
        self.transparent = transparent
        # Resolution (dpi), int, default is 160
        self.dpi = dpi
        # Global font ratio applied to the matplotlib rcParams font sizes
        self.font_ratio = font_ratio
        # Bounding box passed to savefig (see matplotlib documentation)
        self.bbox_inches = bbox_inches

    def update(self, **kwargs):
        """Update the class attributes using kwargs (dictionary)."""
        for key, values in kwargs.items():
            setattr(self, key, values)

    def change(self, figure):
        """Apply the stored properties to the figure `figure`.

        Resizes the figure and scales the global matplotlib font-size rcParams by
        ``font_ratio``. The scaling multiplies the current values, so repeated calls
        compound. When ``figname`` is set, it also saves the figure, creating the
        parent directory if needed.

        Parameter:
        ----------
        figure : list or tuple
            (fig, ax), figure and ax of a matplotlib subplot.

        Return:
        -------
        a tuple of the modified (fig, ax)
        """
        fig, ax = figure

        fig.set_figheight(self.ysize)
        fig.set_figwidth(self.xsize)

        # (rc group, rc parameter, full rcParams key). "font.size" comes first: the
        # other entries may hold relative sizes ("large", "medium", ...) that are
        # resolved against it.
        scaled_sizes = (
            ("font", "size", "font.size"),  # default text size
            ("axes", "titlesize", "axes.titlesize"),  # axes title
            ("axes", "labelsize", "axes.labelsize"),  # x and y labels
            ("xtick", "labelsize", "xtick.labelsize"),  # x tick labels
            ("ytick", "labelsize", "ytick.labelsize"),  # y tick labels
            ("legend", "fontsize", "legend.fontsize"),  # legend
            ("figure", "titlesize", "figure.titlesize"),  # figure title
        )
        for group, param, key in scaled_sizes:
            current = plt.rcParams[key]
            if isinstance(current, str):
                # A relative size already follows the (scaled) font.size.
                # Multiplying the string would corrupt it ("largelarge") or raise
                # TypeError for a non-integer ratio.
                continue
            plt.rc(group, **{param: current * self.font_ratio})

        if self.figname is not None:
            head_path = os.path.dirname(self.figname)
            if head_path:
                os.makedirs(head_path, exist_ok=True)

            fig.savefig(
                self.figname,
                transparent=self.transparent,
                dpi=self.dpi,
                bbox_inches=self.bbox_inches,
            )

        return fig, ax
301
+
302
+
303
@tools.autocast_args
def save_figure(
    fig=None, figname="myfigure", xsize=8, ysize=6, transparent=False, dpi=80
):
    """
    Resize and save a figure.

    Parameters:
    -----------
    fig: fig object returned by matplotlib.subplot
        the figure to save
    figname: str
        Path to the figure. Its parent directory is created if it does not exist.
    xsize: int
        width of the figure in inch
    ysize: int
        height of the figure in inch
    transparent: bool, default is False
        use transparency
    dpi : int
        resolution of the figure, default is 80

    Raises:
    -------
    ValueError
        If `fig` is None.
    """
    if fig is None:
        raise ValueError("Input figure `fig` is None.")

    # Same behaviour as fig_properties: create the destination directory on demand.
    head_path = os.path.dirname(figname)
    if head_path:
        os.makedirs(head_path, exist_ok=True)

    fig.set_size_inches(xsize, ysize, forward=True)
    fig.savefig(figname, transparent=transparent, dpi=dpi, bbox_inches="tight")
326
+
327
+
328
def generate_palette(base_color, n, variation="hue"):
    """
    Generate a palette of `n` colors derived from a base color.

    Parameter:
    ---------
    base_color: str
        Any color specification accepted by ``matplotlib.colors.to_rgb``.
    n: int
        Number of colors to generate.
    variation: 'hue' | 'brightness'
        How the palette is built:
        - 'hue' rotates the hue evenly around the color wheel, starting at the
          base color.
        - 'brightness' scales the HSV value of the base color by a factor that
          goes from 0.5 towards 1, clamped to [0.1, 1.0].

    Return:
    -------
    a list of `n` (r, g, b) tuples with components in [0, 1].

    Raises:
    -------
    ValueError
        If `variation` is neither 'hue' nor 'brightness'. Previously a typo
        silently returned `n` copies of the base color.
    """
    if variation not in ("hue", "brightness"):
        raise ValueError(
            f"variation must be 'hue' or 'brightness', got {variation!r}"
        )

    # Convert the base color to normalized RGB (0-1), then to HSV.
    h, s, v = colorsys.rgb_to_hsv(*mcolors.to_rgb(base_color))

    palette = []
    for i in range(n):
        new_h, new_v = h, v
        if variation == "hue":
            # cycle around the color wheel
            new_h = (h + i / n) % 1.0
        else:
            # lower bound of 0.1 avoids pure black
            new_v = max(0.1, min(1.0, v * (0.5 + i / (2 * n))))
        palette.append(colorsys.hsv_to_rgb(new_h, s, new_v))

    return palette
361
+
362
+
363
@tools.autocast_args
def plot_chro(
    data: np.ndarray = np.zeros(shape=(1, 10)),
    t_axis: int = 1,
    outlets_name: list | tuple = [],
    columns: list | tuple = [],
    dt: float = 0.0,
    xtics: list | tuple = [],
    date_range: list = None,
    figure=None,
    ax_settings: dict | ax_properties = ax_properties(),
    fig_settings: dict | fig_properties = fig_properties(),
    plot_settings: dict | plot_properties = plot_properties(),
):
    """
    Plot a temporal chronicle of values.

    Parameters:
    -----------
    data: np.ndarray of dimension 2
        data to plot as a matrix of 2 dimensions.
    t_axis : int
        the axis of the time in data, default is 1 (it is moved to axis 0).
    outlets_name: list
        the names of the outlets, indexed by column number. They are used as curve
        labels when `columns` is given. A missing name falls back to
        "<plot_settings.label> <column>".
    columns : list
        the columns (indices along the non-time axis) to plot. Each column gets its
        own color derived from plot_settings.color. When empty, only column 0 is
        plotted with plot_settings as-is.
    dt: float
        the timestep. When > 0 and no xtics are given, the x values are index * dt.
    xtics : list
        list of dates for the x values. Each must be readable by numpy.datetime64.
        The given sequence is not modified.
    date_range: list
        list of [date_start, date_end, timestep (s)] used to generate the x values
        (overrides xtics). date_end is included.
    figure: tuple
        input figure as (fig, ax) to add new curves
    ax_settings: dict or class ax_properties
        object or dict with any attribute of class ax_properties
    fig_settings: dict or class fig_properties
        object or dict with any attribute of class fig_properties
    plot_settings: dict or class plot_properties
        object or dict with any attribute of class plot_properties

    Return:
    -------
    a tuple (fig, ax)

    Raises:
    -------
    ValueError
        If date_range does not have exactly 3 items.
    """
    # Work on copies so the (shared) default objects are never modified.
    if isinstance(ax_settings, dict):
        ax_settings = ax_properties(**ax_settings)
    else:
        ax_settings = ax_properties(**ax_settings.__dict__)

    if isinstance(fig_settings, dict):
        fig_settings = fig_properties(**fig_settings)
    else:
        fig_settings = fig_properties(**fig_settings.__dict__)

    if isinstance(plot_settings, dict):
        plot_settings = plot_properties(**plot_settings)
    else:
        plot_settings = plot_properties(**plot_settings.__dict__)

    data = np.moveaxis(data, t_axis, 0)

    if figure is not None:
        fig, ax = figure
    else:
        fig, ax = plt.subplots()

    fig, ax = ax_settings.change(figure=(fig, ax))

    if len(xtics) == 0:
        xtics = np.arange(0, data.shape[0])
        if dt > 0:
            xtics = xtics * dt
    else:
        # Build a new array instead of converting in place. This leaves the
        # caller's list untouched and also works for tuples and string arrays.
        xtics = np.array([np.datetime64(tic) for tic in xtics])

    if date_range is not None:
        if len(date_range) != 3:
            raise ValueError(
                "date_range must have a length of 3: [date_start, date_end, step (s)]"
            )
        # The stop bound is exclusive, so one step is added to include date_end.
        xtics = np.arange(
            np.datetime64(date_range[0]),
            np.datetime64(date_range[1] + pd.Timedelta(seconds=int(date_range[2]))),
            np.timedelta64(int(date_range[2]), "s"),
        )

    if len(columns) > 0:
        args = plot_settings.__dict__.copy()
        # label and color are set per curve below
        args.pop("label", None)
        args.pop("color", None)

        palette = generate_palette(plot_settings.color, len(columns))

        for rank, col in enumerate(columns):
            try:
                label = outlets_name[col]
            except IndexError:
                label = f"{plot_settings.label} {col}".strip()
            # The palette has one entry per plotted column: index it by rank,
            # not by the column number.
            ax.plot(xtics[:], data[:, col], **args, label=label, color=palette[rank])
    else:
        ax.plot(xtics[:], data[:, 0], **plot_settings.__dict__)

    fig, ax = ax_settings.change(figure=(fig, ax))
    fig, ax = fig_settings.change(figure=(fig, ax))

    return fig, ax
467
+
468
+
469
def plot_hydrograph(
    model: Model | None = None,
    columns: list | tuple = [],
    outlets_name: list | tuple = [],
    plot_rainfall: bool = True,
    figure: list | tuple | None = None,
    ax_settings: dict | ax_properties = {},
    fig_settings: dict | fig_properties = {},
    plot_settings_sim: dict | plot_properties = {},
    plot_settings_obs: dict | plot_properties = {},
):
    """
    Plot a hydrograph from a smash model.

    Plots the observed (model.response_data.q) and simulated (model.response.q)
    discharges. When plot_rainfall is True, it also plots the average rainfall
    (model.atmos_data.mean_prcp) of the first plotted column in an inverted bar
    chart above the discharges.

    Parameters:
    -----------
    model: a smash model object
        a smash model object
    columns : list
        the columns (gauges) to plot
    outlets_name: list
        the list of the outlets names
    plot_rainfall: bool
        add the rainfall subplot, default is True
    figure: list or tuple
        input figure to add new curves: [fig, ax_discharge, ax_rainfall] when
        plot_rainfall is True, [fig, ax_discharge] otherwise
    ax_settings: dict or class ax_properties
        object or dict with any attribute of class ax_properties
    fig_settings: dict or class fig_properties
        object or dict with any attribute of class fig_properties. The figure is
        saved (if figname is set) once, after everything has been drawn.
    plot_settings_sim: dict or class plot_properties
        object or dict with any attribute of class plot_properties. Controls the
        simulated curve.
    plot_settings_obs: dict or class plot_properties
        object or dict with any attribute of class plot_properties. Controls the
        observed curve.

    Return:
    -------
    (fig, (ax_rainfall, ax_discharge)) when plot_rainfall is True,
    (fig, ax_discharge) otherwise.

    Raises:
    -------
    ValueError
        If model is None.
    """
    if model is None:
        raise ValueError("Input smash model object is None.")

    if isinstance(ax_settings, dict):
        default_ax_settings = ax_properties(
            xlabel="Time", ylabel="discharges m^3/s", xtics_fontsize=10, ytics_fontsize=10
        )
        default_ax_settings.update(**ax_settings)
    else:
        default_ax_settings = ax_properties(**ax_settings.__dict__)

    if isinstance(fig_settings, dict):
        fig_settings = fig_properties(**fig_settings)
    else:
        fig_settings = fig_properties(**fig_settings.__dict__)

    if isinstance(plot_settings_sim, dict):
        default_plot_settings_sim = plot_properties(
            ls="-",
            lw=2,
            marker="",
            markersize=4,
            color="blue",
            label="Sim",
        )
        default_plot_settings_sim.update(**plot_settings_sim)
    else:
        default_plot_settings_sim = plot_properties(**plot_settings_sim.__dict__)

    # default color for multi curves: same color but different line type
    color = "blue" if len(columns) >= 2 else "black"

    if isinstance(plot_settings_obs, dict):
        default_plot_settings_obs = plot_properties(
            ls="--",
            lw=1.5,
            marker="",
            markersize=4,
            color=color,
            label="Obs",
        )
        default_plot_settings_obs.update(**plot_settings_obs)
    else:
        # Honour a user-supplied plot_properties object; it used to be silently
        # replaced by plot_properties() defaults.
        default_plot_settings_obs = plot_properties(**plot_settings_obs.__dict__)

    # The x axis starts one time step after start_time and ends at end_time.
    date_deb = datetime.datetime.fromisoformat(
        model.setup.start_time
    ) + datetime.timedelta(seconds=int(model.setup.dt))
    date_end = datetime.datetime.fromisoformat(model.setup.end_time)
    date_range = [date_deb, date_end, model.setup.dt]

    if figure is None:
        if plot_rainfall:
            fig, (ax1, ax2) = plt.subplots(2, 1, height_ratios=[1, 4])
            fig.subplots_adjust(hspace=0)
        else:
            fig, ax2 = plt.subplots()
    else:
        fig = figure[0]
        ax2 = figure[1]
        if plot_rainfall:
            ax1 = figure[2]

    fig, ax = default_ax_settings.change(figure=(fig, ax2))

    # The intermediate plot_chro calls must not save the (still incomplete)
    # figure; it is saved once by the final fig_settings.change below.
    inner_fig_settings = fig_properties(**{**fig_settings.__dict__, "figname": None})

    fig, ax2 = plot_chro(
        model.response_data.q,
        date_range=date_range,
        columns=columns,
        outlets_name=["obs_" + name for name in outlets_name],
        figure=(fig, ax2),
        ax_settings=default_ax_settings,
        fig_settings=inner_fig_settings,
        plot_settings=default_plot_settings_obs,
    )

    fig, ax2 = plot_chro(
        model.response.q,
        date_range=date_range,
        columns=columns,
        outlets_name=["sim_" + name for name in outlets_name],
        figure=(fig, ax2),
        ax_settings=default_ax_settings,
        fig_settings=inner_fig_settings,
        plot_settings=default_plot_settings_sim,
    )

    if plot_rainfall:
        xtics = np.arange(
            np.datetime64(date_range[0]),
            np.datetime64(date_range[1] + datetime.timedelta(seconds=int(date_range[2]))),
            np.timedelta64(int(date_range[2]), "s"),
        )

        col = columns[0] if len(columns) > 0 else 0
        rainfall = model.atmos_data.mean_prcp[col, :]

        ax1.bar(
            xtics[:],
            rainfall,
            label="Average rainfall (mm)",
            width=np.timedelta64(int(date_range[2]), "s"),
        )

        ax1.invert_yaxis()
        ax1.grid(alpha=0.7, ls="--")
        ax1.get_xaxis().set_visible(False)
        # Scale on the plotted gauge (previously always gauge 0).
        ax1.set_ylim(bottom=1.2 * np.nanmax(rainfall), top=0.0)
        ax1.set_ylabel("Average rainfall (mm)")

        axes = (ax1, ax2)
    else:
        # ax1 does not exist without the rainfall subplot.
        axes = ax2

    fig, ax = fig_settings.change(figure=(fig, axes))

    return fig, ax
627
+
628
+
629
def plot_catchment_surface_error(
    mesh: dict = None,
    ax_settings: dict | ax_properties = {},
    fig_settings: dict | fig_properties = {},
):
    """
    Plot, for every outlet, the relative error (%) between the catchment surface
    delineated on the model grid and the observed surface:
    (mesh["area_dln"] - mesh["area"]) / mesh["area"] * 100.

    Parameters:
    -----------
    mesh: dict
        The mesh of the smash model. The keys "area_dln", "area" and "code" are read.
    ax_settings: dict or class ax_properties
        object or dict with any attribute of class ax_properties
    fig_settings: dict or class fig_properties
        object or dict with any attribute of class fig_properties

    Return:
    -------
    a tuple (fig, ax)

    Raises:
    -------
    ValueError
        If mesh is None.
    """
    if mesh is None:
        raise ValueError("Input mesh is None.")

    if isinstance(ax_settings, dict):
        default_ax_settings = ax_properties(
            title="Catchment surface error (Ssim-Sobs)/Sobs *100",
            ylabel="Surface error %",
            xlabel="Outlets",
            xticklabels_rotation=45,
            xtics_fontsize=6,
        )
        default_ax_settings.update(**ax_settings)
    else:
        default_ax_settings = ax_properties(**ax_settings.__dict__)

    if isinstance(fig_settings, dict):
        fig_settings = fig_properties(**fig_settings)
    else:
        fig_settings = fig_properties(**fig_settings.__dict__)

    plt.rcParams.update(plt.rcParamsDefault)
    fig, ax = plt.subplots()
    fig, ax = default_ax_settings.change(figure=(fig, ax))

    surface_error = (mesh["area_dln"] - mesh["area"]) / mesh["area"] * 100

    bar_container = ax.bar(
        mesh["code"], surface_error, color="grey", tick_label=mesh["code"]
    )

    ax.bar_label(
        bar_container,
        fmt=lambda x: f"{x:.2f}",
        fontsize=default_ax_settings.barlabel_fontsize,
    )

    # Re-apply the settings now that the bars (and their tick labels) exist.
    fig, ax = default_ax_settings.change(figure=(fig, ax))
    fig, ax = fig_settings.change(figure=(fig, ax))

    return fig, ax
692
+
693
+
694
def plot_catchment_surface_consistency(
    mesh: dict = None,
    label: bool = True,
    ax_settings: dict | ax_properties = {},
    fig_settings: dict | fig_properties = {},
    plot_settings: dict | plot_properties = {},
):
    """
    Plot the modeled surface against the observed surface of each outlet.

    Both surfaces are divided by 1000**2 before plotting. A grey 1:1 line is drawn
    over the observed range.

    Parameters:
    -----------
    mesh: dict, optional
        The mesh of the Smash model; the keys "area_dln", "area" and "code" are
        read. Defaults to None, which raises a ValueError.
    label: bool
        annotate each point with its outlet code, default is True
    ax_settings: dict or class ax_properties
        object or dict with any attribute of class ax_properties
    fig_settings: dict or class fig_properties
        object or dict with any attribute of class fig_properties
    plot_settings: dict or class plot_properties
        object or dict with any attribute of class plot_properties. Only marker,
        markersize and color are used.

    Return:
    -------
    a tuple (fig, ax)

    Raises:
    -------
    ValueError
        If mesh is None.
    """
    if mesh is None:
        raise ValueError("Input mesh is None.")

    if isinstance(ax_settings, dict):
        default_ax_settings = ax_properties(
            title="Modeled and observed surface consistency",
            ylabel="Modeled surface",
            xlabel="Observed surface",
        )
        default_ax_settings.update(**ax_settings)
    else:
        default_ax_settings = ax_properties(**ax_settings.__dict__)

    if isinstance(fig_settings, dict):
        fig_settings = fig_properties(**fig_settings)
    else:
        fig_settings = fig_properties(**fig_settings.__dict__)

    if isinstance(plot_settings, dict):
        default_plot_settings = plot_properties(
            marker="+",
            markersize=12,
            color="blue",
        )
        default_plot_settings.update(**plot_settings)
    else:
        default_plot_settings = plot_properties(**plot_settings.__dict__)

    surface_model = mesh["area_dln"] / 1000.0**2.0
    surface_obs = mesh["area"] / 1000.0**2.0

    plt.rcParams.update(plt.rcParamsDefault)
    fig, ax = plt.subplots()
    fig, ax = default_ax_settings.change(figure=(fig, ax))

    ax.plot(
        surface_obs,
        surface_model,
        markersize=default_plot_settings.markersize,
        marker=default_plot_settings.marker,
        color=default_plot_settings.color,
        linestyle="None",
    )

    # 1:1 reference line; nan-aware bounds so a missing surface does not
    # break the line.
    obs_min, obs_max = np.nanmin(surface_obs), np.nanmax(surface_obs)
    ax.plot(
        np.linspace(obs_min, obs_max, 10),
        np.linspace(obs_min, obs_max, 10),
        linewidth=2,
        color="grey",
    )

    if label:
        ha = ("left", "right")
        # `code` instead of `label`: do not shadow the boolean parameter.
        for i, code in enumerate(mesh["code"]):
            ax.annotate(
                code,  # this is the text
                (surface_obs[i], surface_model[i]),  # point to annotate
                textcoords="data",  # how to position the text
                xytext=(surface_obs[i], surface_model[i]),  # text placed on the point
                ha=ha[i % 2],  # alternate alignment to limit overlaps
                color="red",
                fontsize=default_ax_settings.annotate_fontsize,
            )

    ax.set(xlabel=default_ax_settings.xlabel, ylabel=default_ax_settings.ylabel)

    fig, ax = fig_settings.change(figure=(fig, ax))

    return fig, ax
786
+
787
+
788
def plot_mesh(
    mesh: dict = None,
    coef_hydro: float = 99.0,
    catchment_polygon: None | DataFrame = None,
    ax_settings: dict | ax_properties = {},
    fig_settings: dict | fig_properties = {},
):
    """
    Plot the mesh of a smash model: active cells, hydrographic network coloured by
    cumulated surface, and gauges annotated with their code.

    Parameters:
    -----------
    mesh: a smash mesh as dictionary
        Keys read: "gauge_pos", "code", "flwacc", "active_cell", "xmin", "ymax",
        "xres", "yres" (plus whatever geo_toolbox.get_bbox_from_smash_mesh reads).
    coef_hydro: float
        Threshold used to draw the hydrographic network. Cells whose cumulated
        surface is below (100 - coef_hydro)% of the maximum cumulated surface are
        hidden. The default 99 keeps cells draining at least 1% of the maximum.
    catchment_polygon: DataFrame or None
        Optional polygons with a `.plot(ax=...)` method (e.g. a GeoDataFrame),
        drawn as black outlines.
    ax_settings: dict or class ax_properties
        object or dict with any attribute of class ax_properties
    fig_settings: dict or class fig_properties
        object or dict with any attribute of class fig_properties

    Return:
    -------
    a tuple (fig, ax)

    Raises:
    -------
    ValueError
        If mesh is None or is not a dict.
    """
    if mesh is None:
        raise ValueError("mesh is mandatory and must be a dict")
    if not isinstance(mesh, dict):
        raise ValueError("mesh must be a dict")

    if isinstance(ax_settings, dict):
        default_ax_settings = ax_properties(
            title="Mesh of the Smash model",
            xlabel="x_coords",
            ylabel="y_coords",
        )
        default_ax_settings.update(**ax_settings)
    else:
        default_ax_settings = ax_properties(**ax_settings.__dict__)

    if isinstance(fig_settings, dict):
        fig_settings = fig_properties(**fig_settings)
    else:
        fig_settings = fig_properties(**fig_settings.__dict__)

    gauge = mesh["gauge_pos"]
    stations = mesh["code"]
    na = mesh["active_cell"] == 0

    # np.ma.getdata accepts both plain ndarrays and masked arrays. For a plain
    # ndarray, `.data` is a raw memoryview that cannot be divided.
    # Convert the flow accumulation to km² (value / 1e6).
    flow_accum_bv = np.where(na, 0.0, np.ma.getdata(mesh["flwacc"]) / 1000000.0)
    surfmin = (1.0 - coef_hydro / 100.0) * np.max(flow_accum_bv)
    flow_plot = np.where(flow_accum_bv < surfmin, np.nan, flow_accum_bv)
    flow_plot = np.where(na, np.nan, flow_plot)

    plt.rcParams.update(plt.rcParamsDefault)
    fig, ax = plt.subplots()
    fig, ax = default_ax_settings.change(figure=(fig, ax))

    bbox = geo_toolbox.get_bbox_from_smash_mesh(mesh)
    extent = (bbox["left"], bbox["right"], bbox["bottom"], bbox["top"])

    # Background: active cells in light gray, inactive cells transparent (NaN).
    active_cell = np.where(na, np.nan, mesh["active_cell"])
    ax.imshow(active_cell, cmap=ListedColormap(["lightgray"]), extent=extent)

    # Hydrographic network: the "Blues" colormap without its lightest 30%.
    myblues = matplotlib.colormaps["Blues"]
    cmp = ListedColormap(myblues(np.linspace(0.30, 1.0, 265)))
    im = ax.imshow(flow_plot, cmap=cmp, extent=extent)

    if catchment_polygon is not None:
        catchment_polygon.plot(ax=ax, facecolor="none", edgecolor="black")

    # Colorbar axes on the right side of ax: 5% of its width, 0.05 inch padding.
    # The colormap comes from `im`.
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    fig.colorbar(im, ax=ax, label="Cumulated surface (km²)", cax=cax)

    ha = "right"
    for pos, code in zip(gauge, stations):
        # Alternate the horizontal alignment of consecutive labels
        # (left, right, left, ...) to limit overlaps.
        ha = "left" if ha == "right" else "right"

        coord = geo_toolbox.rowcol_to_xy(
            pos[0],
            pos[1],
            mesh["xmin"],
            mesh["ymax"],
            mesh["xres"],
            mesh["yres"],
        )

        ax.plot(coord[0], coord[1], color="green", marker="o", markersize=6)
        ax.annotate(
            code,  # this is the text
            (coord[0], coord[1]),  # point to annotate
            textcoords="data",  # how to position the text
            xytext=(coord[0], coord[1]),  # text placed on the point
            ha=ha,
            color="red",
            fontsize=10,
        )

    fig, ax = default_ax_settings.change(figure=(fig, ax))
    fig, ax = fig_settings.change(figure=(fig, ax))

    return fig, ax
920
+
921
+
922
def plot_xy_quantile(
    res_quantile,
    X,
    Y,
    res_quantile_obs=None,
    gauge_pos=None,
    figure=None,
    ax_settings: dict | ax_properties = {},
    fig_settings: dict | fig_properties = {},
    plot_settings: dict | plot_properties = {},
):
    """
    Plot the discharge quantile fit at the (X, Y) pixel of the Smash mesh.

    The empirical maxima, the theoretical quantiles, the fitted Gumbel/GEV
    curve and, when available, the uncertainty bounds and the observed maxima
    at a gauge are drawn against the return period (log-scaled by default).

    Parameters:
    -----------
    res_quantile: dict
        The result of the discharge quantile computation. The per-pixel arrays
        'Q_th', 'maxima', 'fit_loc', 'fit_scale', 'fit_shape' (and the
        optional 'Umax' / 'Umin') are indexed with [X, Y]; 'T_emp', 'T' and
        'fit' are used as-is; 'chunk_size' feeds the default x label.
    X: int
        Row index of the pixel (X means row).
    Y: int
        Column index of the pixel (Y means column).
    res_quantile_obs: dict or None
        The observed discharge quantiles, as computed by
        smashbox.stats.stats.quantile_obs(). Its 'maxima' and 'Temp' arrays
        are indexed with [gauge_pos, :]. Ignored when None or empty.
    gauge_pos: int or None
        Index of the gauge in the Smash mesh whose observed quantiles are
        plotted. Mandatory when res_quantile_obs is not empty.
    figure: tuple or None
        Existing (fig, ax) to draw into. A new figure is created when None.
    ax_settings: dict or class ax_properties
        Object or dict with any attribute of class ax_properties.
    fig_settings: dict or class fig_properties
        Object or dict with any attribute of class fig_properties.
    plot_settings: dict or class plot_properties
        Object or dict with any attribute of class plot_properties; its
        markersize and lw are used for the markers and the lines.

    Returns:
    --------
    tuple
        The (fig, ax) pair holding the plot.

    Raises:
    -------
    ValueError
        If res_quantile_obs is not empty and gauge_pos is None.
    """
    if isinstance(ax_settings, dict):
        default_ax_settings = ax_properties(
            xscale="log",
            xlabel=f"Return period (*{res_quantile['chunk_size']} days)",
            ylabel="Discharges (m³/s)",
            grid=True,
            legend=True,
        )
        default_ax_settings.update(**ax_settings)
    else:
        default_ax_settings = ax_properties(**ax_settings.__dict__)

    if isinstance(fig_settings, dict):
        fig_settings = fig_properties(**fig_settings)
    else:
        fig_settings = fig_properties(**fig_settings.__dict__)

    if isinstance(plot_settings, dict):
        default_plot_settings = plot_properties(markersize=10)
        default_plot_settings.update(**plot_settings)
    else:
        default_plot_settings = plot_properties(**plot_settings.__dict__)

    markersize = default_plot_settings.markersize
    lw = default_plot_settings.lw

    quantile = res_quantile["Q_th"][X, Y]
    maxima = res_quantile["maxima"][X, Y]
    T_emp = res_quantile["T_emp"]
    loc = res_quantile["fit_loc"][X, Y]
    scale = res_quantile["fit_scale"][X, Y]
    shape = res_quantile["fit_shape"][X, Y]
    fit = res_quantile["fit"]

    sorted_data = np.sort(maxima)

    plt.rcParams.update(plt.rcParamsDefault)
    if figure is None:
        fig, ax = plt.subplots()
    else:
        fig, ax = figure

    fig, ax = default_ax_settings.change(figure=(fig, ax))

    if res_quantile_obs:
        if gauge_pos is None:
            raise ValueError(
                "gauge_pos is None. gauge_pos argument must be an integer corresponding to the gauge index."
            )
        maxima_obs = res_quantile_obs["maxima"][gauge_pos, :]
        T_emp_obs = res_quantile_obs["Temp"][gauge_pos, :]

        ax.plot(
            T_emp_obs,
            maxima_obs,
            "o",
            label="Observed",
            color="black",
            markersize=markersize,
        )

    ax.plot(T_emp, sorted_data, "o", label="Empirical", markersize=markersize)
    ax.plot(res_quantile["T"], quantile, "x", label="Theorical", markersize=markersize)

    # Smooth fitted curve, starting just above T=1 up to the largest return period.
    Trange = np.linspace(1.1, np.max(res_quantile["T"]), 50)

    if fit == "gumbel":
        ax.plot(
            Trange,
            [stats.quantile_gumbel(T, loc, scale) for T in Trange],
            "r--",
            label=f"{fit} fitted",
            lw=lw,
        )
    elif fit == "gev":
        ax.plot(
            Trange,
            [stats.quantile_gev(T, shape, loc, scale) for T in Trange],
            "r--",
            label=f"{fit} fitted",
            lw=lw,
        )

    u_max = res_quantile.get("Umax")
    u_min = res_quantile.get("Umin")
    if u_max is not None and u_min is not None:
        # Plain "--" format: a color inside the format string ("r--") would
        # conflict with color="grey" and trigger a matplotlib warning.
        ax.plot(
            res_quantile["T"],
            u_max[X, Y],
            "--",
            label="Uncertainties (max)",
            color="grey",
            lw=lw,
        )
        ax.plot(
            res_quantile["T"],
            u_min[X, Y],
            "--",
            label="Uncertainties (min)",
            color="grey",
            lw=lw,
        )

    # Applied again now that labelled artists exist (legend, limits).
    fig, ax = default_ax_settings.change(figure=(fig, ax))
    fig, ax = fig_settings.change(figure=(fig, ax))

    return fig, ax
1071
+
1072
+
1073
def plot_image(
    matrice=np.zeros(shape=(2, 2)),
    bbox=None,
    vmin=None,
    vmax=None,
    mask=None,
    extend=None,
    catchment_polygon=None,
    figure=None,
    ax_settings: dict | ax_properties = {},
    fig_settings: dict | fig_properties = {},
):
    """
    Plot a matrix as an image with a colorbar on its right side.

    Parameters
    ----------
    matrice : numpy array
        Matrix to be plotted. It is converted to a float32 copy, so the
        caller's array is never modified.
    bbox : dict or None
        Bounding box with keys "left", "right", "bottom" and "top", used as
        the image extent so that the axes show x and y coordinates instead
        of the matrix indices.
    vmin: real or None
        Minimum z value of the color scale.
    vmax: real or None
        Maximum z value of the color scale.
    mask: matrix or None
        Same shape as matrice; pixels where mask == 0 are not plotted
        (set to NaN).
    extend: any
        Currently unused; kept for backward compatibility.
    catchment_polygon: dataframe or None
        Polygons drawn over the image (black edges, no fill). Ideally the
        boundaries of the catchment read from a shp file by geopandas.
    figure: tuple or None
        Existing (fig, ax) to draw into. A new figure is created when None.
    ax_settings: dict or class ax_properties
        Object or dict with any attribute of class ax_properties; its cmap
        and clabel are used for the image and the colorbar.
    fig_settings: dict or class fig_properties
        Object or dict with any attribute of class fig_properties.

    Returns
    -------
    tuple
        The (fig, ax) pair holding the plot.

    Examples
    --------
    plot_image(
        mesh_france["drained_area"],
        bbox=bbox,
        vmin=0.0,
        vmax=1000,
        mask=mesh_france["global_active_cell"],
        ax_settings={
            "title": "Drained areas",
            "xlabel": "Longitude",
            "ylabel": "Latitude",
            "clabel": "Drained areas km^2",
        },
    )
    """
    if isinstance(ax_settings, dict):
        ax_settings = ax_properties(**ax_settings)
    else:
        ax_settings = ax_properties(**ax_settings.__dict__)

    if isinstance(fig_settings, dict):
        fig_settings = fig_properties(**fig_settings)
    else:
        fig_settings = fig_properties(**fig_settings.__dict__)

    # np.float32(arr) can hand back the caller's own array when it is already
    # float32, and the masking below would then overwrite the caller's data.
    # np.array always copies.
    matrice = np.array(matrice, dtype=np.float32)

    extent = None
    if bbox is not None:
        extent = [bbox["left"], bbox["right"], bbox["bottom"], bbox["top"]]

    if mask is not None:
        matrice[np.where(mask == 0)] = np.nan

    plt.rcParams.update(plt.rcParamsDefault)
    if figure is None:
        fig, ax = plt.subplots()
    else:
        fig, ax = figure

    fig, ax = ax_settings.change(figure=(fig, ax))

    im = ax.imshow(matrice, extent=extent, vmin=vmin, vmax=vmax, cmap=ax_settings.cmap)

    if catchment_polygon is not None:
        catchment_polygon.plot(ax=ax, facecolor="none", edgecolor="black")

    # Colorbar axes on the right of ax: 5% of its width, 0.05 inch padding.
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)

    # fig.colorbar rather than plt.colorbar: the target figure may have been
    # passed in and need not be pyplot's current figure.
    fig.colorbar(im, label=ax_settings.clabel, cax=cax)

    fig, ax = ax_settings.change(figure=(fig, ax))
    fig, ax = fig_settings.change(figure=(fig, ax))

    return (fig, ax)
1167
+
1168
+
1169
def plot_misfit(
    values: np.ndarray = [],
    names: np.ndarray = [],
    columns: list | None = None,
    misfit: str = "nse",
    figure: list | tuple | None = None,
    ax_settings: dict | ax_properties = {},
    fig_settings: dict | fig_properties = {},
):
    """
    Bar plot of the misfit criteria between simulated and observed discharges.

    Outlets whose misfit is NaN are dropped from the plot. Each bar is
    labelled with its value rounded to two decimals.

    Parameters:
    -----------
    values: np.ndarray (or sequence)
        The result of the discharge misfit for all outlets.
    names: np.ndarray (or sequence)
        Outlets name or code. When empty, outlets are numbered 0..n-1.
    columns: list | None
        Indices (or boolean mask) of the outlets to plot. All when None.
    misfit: str
        Criteria to plot. choice are ['nse', 'nnse', 'rmse', 'nrmse', 'se', 'kge'].
        Used for the default y label.
    figure: tuple or None
        Existing (fig, ax) to draw into. A new figure is created when None.
    ax_settings: dict or class ax_properties
        Object or dict with any attribute of class ax_properties.
    fig_settings: dict or class fig_properties
        Object or dict with any attribute of class fig_properties.

    Returns:
    --------
    tuple
        The (fig, ax) pair holding the plot.
    """
    if isinstance(ax_settings, dict):
        default_ax_settings = ax_properties(
            ylabel=f"{misfit} criteria",
            xlabel="Gauges stations",
            grid=True,
            legend=True,
            xticklabels_rotation=45,
            xtics_fontsize=8,
        )
        default_ax_settings.update(**ax_settings)
    else:
        default_ax_settings = ax_properties(**ax_settings.__dict__)

    if isinstance(fig_settings, dict):
        fig_settings = fig_properties(**fig_settings)
    else:
        fig_settings = fig_properties(**fig_settings.__dict__)

    # Arrays are required for the fancy/boolean indexing below; plain lists
    # would raise a TypeError.
    values = np.asarray(values)
    names = np.asarray(names)

    if names.size == 0:
        names = np.arange(len(values))

    if columns is not None:
        values = values[columns]
        names = names[columns]

    # Drop outlets without a misfit value (NaN) from the plot.
    keep = ~np.isnan(values)
    values = values[keep]
    names = names[keep]

    if figure is None:
        fig, ax = plt.subplots()
    else:
        fig, ax = figure

    fig, ax = default_ax_settings.change(figure=(fig, ax))
    bar_container = ax.bar(names, values, color="grey", tick_label=names)

    ax.bar_label(
        bar_container,
        fmt=lambda x: f"{x:.2f}",
        fontsize=default_ax_settings.barlabel_fontsize,
    )

    fig, ax = default_ax_settings.change(figure=(fig, ax))
    fig, ax = fig_settings.change(figure=(fig, ax))

    return fig, ax
1248
+
1249
+
1250
def plot_outlet_stats(
    values_sim: np.ndarray | None = None,
    values_obs: np.ndarray | None = None,
    names: np.ndarray = [],
    columns: list | None = [],
    stat: str = "max",
    figure: list | tuple | None = None,
    ax_settings: dict | ax_properties = {},
    fig_settings: dict | fig_properties = {},
):
    """
    Grouped bar plot of a statistical criteria (simulated vs observed) at a
    given list of outlets.

    Parameters:
    -----------
    values_sim: np.ndarray or None
        The result of the simulated stat for all outlets (bars labelled "sim").
    values_obs: np.ndarray or None
        The result of the observed stat for all outlets (bars labelled "obs").
        Ignored when every selected value equals -99.
    names: np.ndarray (or sequence)
        Outlets name or code. When empty, outlets are numbered 0..n-1.
    columns: list | None
        Indices (or boolean mask) of the outlets to plot. All outlets are
        plotted when None or empty.
    stat: str
        Criteria to plot. choice are ['max', 'min', 'mean', 'median', 'q20', 'q80'].
        Used for the default y label.
    figure: tuple or None
        Existing (fig, ax) to draw into. A new figure is created when None.
    ax_settings: dict or class ax_properties
        Object or dict with any attribute of class ax_properties.
    fig_settings: dict or class fig_properties
        Object or dict with any attribute of class fig_properties.

    Returns:
    --------
    tuple
        The (fig, ax) pair holding the plot.

    Raises:
    -------
    ValueError
        If values_sim and values_obs are both given with different sizes.
    """
    if isinstance(ax_settings, dict):
        default_ax_settings = ax_properties(
            ylabel=f"{stat} criteria",
            xlabel="Gauges stations",
            grid=True,
            legend=True,
            xticklabels_rotation=45,
            xtics_fontsize=6,
        )
        default_ax_settings.update(**ax_settings)
    else:
        default_ax_settings = ax_properties(**ax_settings.__dict__)

    if isinstance(fig_settings, dict):
        fig_settings = fig_properties(**fig_settings)
    else:
        fig_settings = fig_properties(**fig_settings.__dict__)

    # Arrays are required for the fancy/boolean indexing below.
    if values_sim is not None:
        values_sim = np.asarray(values_sim)
    if values_obs is not None:
        values_obs = np.asarray(values_obs)

    names = np.asarray(names)
    if names.size == 0:
        # No names given: number the outlets so the x axis matches the data.
        reference = values_sim if values_sim is not None else values_obs
        names = np.arange(0 if reference is None else len(reference))

    # An empty selection (the default) means "plot every outlet" rather than
    # "plot nothing".
    if columns is not None and len(columns) > 0:
        if values_sim is not None:
            values_sim = values_sim[columns]
        if values_obs is not None:
            values_obs = values_obs[columns]
        names = names[columns]

    # Observations made only of -99 (presumably a no-data marker) are dropped.
    if values_obs is not None and np.all(values_obs == -99.0):
        values_obs = None

    if values_sim is not None and values_obs is not None:
        if values_obs.size != values_sim.size:
            raise ValueError("values_sim and values_obs must have the same size !")

    if figure is None:
        fig, ax = plt.subplots()
    else:
        fig, ax = figure

    fig, ax = default_ax_settings.change(figure=(fig, ax))

    series = []
    if values_sim is not None:
        series.append((values_sim, "sim"))
    if values_obs is not None:
        series.append((values_obs, "obs"))

    x = np.arange(len(names))
    width = 0.25  # the width of the bars

    # Side-by-side bars: series k is shifted right by k bar widths.
    for multiplier, (series_values, label) in enumerate(series):
        ax.bar(x + width * multiplier, series_values, width, label=label)

    # Center each tick under its group of bars.
    group_center = width * (max(len(series), 1) - 1) / 2
    ax.set_xticks(x + group_center, names)

    fig, ax = default_ax_settings.change(figure=(fig, ax))
    fig, ax = fig_settings.change(figure=(fig, ax))

    return fig, ax
1354
+
1355
+
1356
def plot_misfit_map(
    values: np.ndarray = [],
    names: np.ndarray = [],
    mesh=None,
    misfit: str = "nse",
    coef_hydro=99.0,
    catchment_polygon: None | DataFrame = None,
    ax_settings: dict | ax_properties = {},
    fig_settings: dict | fig_properties = {},
    plot_settings: dict | plot_properties = {},
):
    """
    Map plot of the misfit criteria between simulated and observed discharges.

    The drained area is drawn in grey levels. Each gauge of the mesh is
    marked and annotated with its code and misfit value, colored with
    ax_settings.cmap. A second colorbar gives the misfit scale.

    Parameters:
    -----------
    values: np.ndarray (or sequence)
        The discharge misfit for all outlets; values[i] is matched with the
        i-th gauge of the mesh (mesh["gauge_pos"][i], mesh["code"][i]).
    names: np.ndarray
        Unused: station codes are read from mesh["code"]. Kept for backward
        compatibility.
    mesh: dict
        The mesh of the Smash model as dict. The keys read here are
        "gauge_pos", "code", "flwacc", "active_cell", "xmin", "ymax", "xres"
        and "yres".
    misfit: str
        Criteria to plot. choice are ['nse', 'nnse', 'rmse', 'nrmse', 'se', 'kge'].
        Sets the color scale: [0, 1] for nse/nnse, [0, max] for
        rmse/nrmse/se, [min, max] otherwise (NaN values ignored).
    coef_hydro: float
        Percentage of the maximum drained area kept in the background:
        cells draining less than (1 - coef_hydro/100) * max are hidden.
    catchment_polygon: DataFrame or None
        Polygons drawn over the map (black edges, no fill).
    ax_settings: dict or class ax_properties
        Object or dict with any attribute of class ax_properties.
    fig_settings: dict or class fig_properties
        Object or dict with any attribute of class fig_properties.
    plot_settings: dict or class plot_properties
        Object or dict with any attribute of class plot_properties, used for
        the gauge markers. Its color is ignored (set from the misfit value).

    Returns:
    --------
    tuple
        The (fig, ax) pair holding the plot.

    Raises:
    -------
    ValueError
        If mesh is None or not a dict.
    """
    from matplotlib.colors import Normalize

    if mesh is None:
        raise ValueError(
            "model or mesh are mandatory and must be a dict or a smash Model object"
        )
    if not isinstance(mesh, dict):
        raise ValueError("mesh must be a dict")

    if isinstance(ax_settings, dict):
        default_ax_settings = ax_properties(
            title=f"Map of {misfit} criteria over the domain.",
            xlabel="x_coords",
            ylabel="y_coords",
            cmap="turbo_r",
        )
        default_ax_settings.update(**ax_settings)
    else:
        default_ax_settings = ax_properties(**ax_settings.__dict__)

    if isinstance(fig_settings, dict):
        fig_settings = fig_properties(**fig_settings)
    else:
        fig_settings = fig_properties(**fig_settings.__dict__)

    if isinstance(plot_settings, dict):
        default_plot_settings = plot_properties(
            marker="o",
            markersize=8,
        )
        default_plot_settings.update(**plot_settings)
    else:
        default_plot_settings = plot_properties(**plot_settings.__dict__)

    # The marker color is computed per gauge from its misfit value below.
    delattr(default_plot_settings, "color")

    values = np.asarray(values)
    gauge = mesh["gauge_pos"]
    stations = mesh["code"]
    flow_acc = mesh["flwacc"]
    na = mesh["active_cell"] == 0

    bbox = geo_toolbox.get_bbox_from_smash_mesh(mesh)
    extent = (bbox["left"], bbox["right"], bbox["bottom"], bbox["top"])

    # Divide by 1e6 so the data matches the "km²" colorbar label
    # (assumes flwacc is stored in m²). getdata accepts masked or plain arrays.
    flow_accum_bv = np.where(na, 0.0, np.ma.getdata(flow_acc) / 1000000.0)
    surfmin = (1.0 - coef_hydro / 100.0) * np.max(flow_accum_bv)
    flow_plot = np.where(flow_accum_bv < surfmin, np.nan, flow_accum_bv)
    flow_plot = np.where(na, np.nan, flow_plot)

    plt.rcParams.update(plt.rcParamsDefault)
    fig, ax = plt.subplots()
    fig, ax = default_ax_settings.change(figure=(fig, ax))

    active_cell = np.where(na, np.nan, mesh["active_cell"])
    ax.imshow(active_cell, cmap=ListedColormap(["lightgray"]), extent=extent)

    greys = matplotlib.colormaps["binary"]
    flow_cmap = ListedColormap(greys(np.linspace(0.20, 1.0, 265)))
    im = ax.imshow(flow_plot, cmap=flow_cmap, extent=extent)

    if catchment_polygon is not None:
        catchment_polygon.plot(ax=ax, facecolor="none", edgecolor="black")

    # Colorbar axes on the right of ax: 5% of its width, 0.05 inch padding.
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad=0.05)
    fig.colorbar(im, ax=ax, label="Cumulated surface (km²)", shrink=0.75, cax=cax)

    # Bounds of the misfit color scale; NaN misfits (no observation) are
    # ignored so they do not turn the whole scale into NaN.
    if misfit in ("nse", "nnse"):
        vmin, vmax = 0, 1
    elif misfit in ("rmse", "nrmse", "se"):
        vmin, vmax = 0, np.nanmax(values)
    else:
        vmin, vmax = np.nanmin(values), np.nanmax(values)

    # pyplot.get_cmap replaces matplotlib.cm.get_cmap (removed in matplotlib 3.9).
    # Values are mapped through a Normalize: a colormap called with a float
    # expects [0, 1], so raw rmse/kge values would otherwise be clamped.
    colormap = plt.get_cmap(default_ax_settings.cmap)
    norm = Normalize(vmin=vmin, vmax=vmax)

    ha = "right"
    for i, station in enumerate(stations):
        # Alternate label alignment between gauges to limit overlaps.
        ha = "left" if ha == "right" else "right"
        value_str = str(np.round(values[i], 2))
        if ha == "left":
            value_str = value_str.rjust(len(station))
        else:
            value_str = value_str.ljust(len(station))
        code = f"{station}\n {value_str}"

        coord = geo_toolbox.rowcol_to_xy(
            gauge[i][0],
            gauge[i][1],
            mesh["xmin"],
            mesh["ymax"],
            mesh["xres"],
            mesh["yres"],
        )

        color = colormap(norm(values[i]))

        ax.plot(
            coord[0],
            coord[1],
            color=color,
            **default_plot_settings.__dict__,
        )

        ax.annotate(
            code,
            (coord[0], coord[1]),
            textcoords="data",
            xytext=(coord[0], coord[1]),
            ha=ha,
            color=color,
            fontsize=default_ax_settings.annotate_fontsize
            * default_ax_settings.font_ratio,
        )

    # Second colorbar (misfit scale), placed further right.
    cax_misfit = divider.append_axes("right", size="5%", pad=0.5)
    fig.colorbar(
        cm.ScalarMappable(norm=norm, cmap=colormap),
        ax=ax,
        cax=cax_misfit,
        label=misfit,
        shrink=0.75,
        location="right",
    )

    fig, ax = default_ax_settings.change(figure=(fig, ax))
    fig, ax = fig_settings.change(figure=(fig, ax))

    return fig, ax
1543
+
1544
+
1545
+ # def _ax_settings(
1546
+ # figure,
1547
+ # title: str = None,
1548
+ # xlabel: str = None,
1549
+ # ylabel: str = None,
1550
+ # clabel: str = None,
1551
+ # font_ratio: int = 1,
1552
+ # title_fontsize: int = 12,
1553
+ # label_fontsize: int = 10,
1554
+ # grid: bool = True,
1555
+ # xscale: str | None = None,
1556
+ # yscale: str | None = None,
1557
+ # legend: bool = True,
1558
+ # legend_loc: str = None,
1559
+ # legend_fontsize: int = 8,
1560
+ # xtics_fontsize: int = 8,
1561
+ # ytics_fontsize: int = 8,
1562
+ # cmap: str | None = None,
1563
+ # xlim: tuple | list | None = (None, None),
1564
+ # ylim: tuple | list | None = (None, None),
1565
+ # ):
1566
+
1567
+ # fig, ax = figure
1568
+
1569
+ # plt.rcParams.update(plt.rcParamsDefault)
1570
+
1571
+ # if title is not None:
1572
+ # ax.set_title(title, fontsize=title_fontsize * font_ratio)
1573
+
1574
+ # if xlabel is not None:
1575
+ # ax.set_xlabel(xlabel)
1576
+
1577
+ # if ylabel is not None:
1578
+ # ax.set_ylabel(ylabel)
1579
+
1580
+ # if grid:
1581
+ # ax.grid(True, which="both", linestyle="--", alpha=0.5)
1582
+
1583
+ # if xscale is not None:
1584
+ # ax.set_xscale(xscale)
1585
+
1586
+ # if yscale is not None:
1587
+ # ax.set_xyscale(yscale)
1588
+
1589
+ # if legend:
1590
+ # ax.legend(loc=legend_loc)
1591
+
1592
+ # if cmap is not None:
1593
+ # plt.rc("image", cmap=cmap)
1594
+
1595
+ # if ylim[0] != None:
1596
+ # ax.set_ylim(bottom=ylim[0])
1597
+
1598
+ # if ylim[1] != None:
1599
+ # ax.set_ylim(top=ylim[1])
1600
+
1601
+ # if xlim[0] != None:
1602
+ # ax.set_xlim(left=xlim[0])
1603
+
1604
+ # if xlim[1] != None:
1605
+ # ax.set_xlim(right=xlim[1])
1606
+
1607
+ # plt.rcParams.update(
1608
+ # {
1609
+ # "axes.labelsize": label_fontsize * font_ratio,
1610
+ # "axes.titlesize": title_fontsize * font_ratio,
1611
+ # "legend.fontsize": legend_fontsize * font_ratio,
1612
+ # "figure.titlesize": title_fontsize * font_ratio,
1613
+ # "xtick.labelsize": xtics_fontsize * font_ratio,
1614
+ # "ytick.labelsize": ytics_fontsize * font_ratio,
1615
+ # }
1616
+ # )
1617
+
1618
+ # return fig, ax
1619
+
1620
+
1621
+ # def _fig_settings(
1622
+ # figure,
1623
+ # figname=None,
1624
+ # xsize=8,
1625
+ # ysize=6,
1626
+ # transparent=False,
1627
+ # dpi=80,
1628
+ # font_ratio=1,
1629
+ # bbox_inches="tight",
1630
+ # ):
1631
+
1632
+ # fig, ax = figure
1633
+
1634
+ # fig.set_figheight(ysize)
1635
+ # fig.set_figwidth(xsize)
1636
+
1637
+ # plt.rc(
1638
+ # "font", size=plt.rcParams["font.size"] * font_ratio
1639
+ # ) # controls default text sizes
1640
+ # plt.rc(
1641
+ # "axes", titlesize=plt.rcParams["axes.titlesize"] * font_ratio
1642
+ # ) # fontsize of the axes title
1643
+ # plt.rc(
1644
+ # "axes", labelsize=plt.rcParams["axes.labelsize"] * font_ratio
1645
+ # ) # fontsize of the x and y labels
1646
+ # plt.rc(
1647
+ # "xtick", labelsize=plt.rcParams["xtick.labelsize"] * font_ratio
1648
+ # ) # fontsize of the tick labels
1649
+ # plt.rc(
1650
+ # "ytick", labelsize=plt.rcParams["ytick.labelsize"] * font_ratio
1651
+ # ) # fontsize of the tick labels
1652
+ # plt.rc(
1653
+ # "legend", fontsize=plt.rcParams["legend.fontsize"] * font_ratio
1654
+ # ) # legend fontsize
1655
+ # plt.rc(
1656
+ # "figure", titlesize=plt.rcParams["figure.titlesize"] * font_ratio
1657
+ # ) # fontsize of the figure title
1658
+
1659
+ # if figname is not None:
1660
+ # fig.savefig(figname, transparent=transparent, dpi=dpi, bbox_inches=bbox_inches)
1661
+
1662
+ # return fig, ax