pastastore 1.3.0__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pastastore/plotting.py CHANGED
@@ -1,4 +1,4 @@
1
- """This module contains all the plotting methods for PastaStore.
1
+ """Module containing all the plotting methods for PastaStore.
2
2
 
3
3
  Pastastore comes with a number helpful plotting methods to quickly
4
4
  visualize time series or the locations of the time series contained in the
@@ -14,6 +14,7 @@ follows::
14
14
  ax = pstore.maps.oseries()
15
15
  pstore.maps.add_background_map(ax) # for adding a background map
16
16
  """
17
+
17
18
  import matplotlib.pyplot as plt
18
19
  import numpy as np
19
20
  import pandas as pd
@@ -49,9 +50,12 @@ class Plots:
49
50
  split=False,
50
51
  figsize=(10, 5),
51
52
  progressbar=True,
53
+ show_legend=True,
54
+ labelfunc=None,
55
+ legend_kwargs=None,
52
56
  **kwargs,
53
57
  ):
54
- """Internal method to plot time series from pastastore.
58
+ """Plot time series from pastastore (internal method).
55
59
 
56
60
  Parameters
57
61
  ----------
@@ -71,6 +75,13 @@ class Plots:
71
75
  progressbar : bool, optional
72
76
  show progressbar when loading time series from store,
73
77
  by default True
78
+ show_legend : bool, optional
79
+ show legend, default is True.
80
+ labelfunc : callable, optional
81
+ function to create custom labels, function should take name of time series
82
+ as input
83
+ legend_kwargs : dict, optional
84
+ additional arguments to pass to legend
74
85
 
75
86
  Returns
76
87
  -------
@@ -109,16 +120,33 @@ class Plots:
109
120
  iax = axes
110
121
  else:
111
122
  iax = ax
123
+ if labelfunc is not None:
124
+ n = labelfunc(n)
112
125
  iax.plot(ts.index, ts.squeeze(), label=n, **kwargs)
113
- if split:
126
+
127
+ if split and show_legend:
114
128
  iax.legend(loc="best", fontsize="x-small")
115
129
 
116
- if not split:
117
- axes.legend(loc=(0, 1), frameon=False, ncol=7, fontsize="x-small")
130
+ if not split and show_legend:
131
+ if legend_kwargs is None:
132
+ legend_kwargs = {}
133
+ ncol = legend_kwargs.pop("ncol", 7)
134
+ fontsize = legend_kwargs.pop("fontsize", "x-small")
135
+ axes.legend(loc=(0, 1), frameon=False, ncol=ncol, fontsize=fontsize)
118
136
 
119
137
  return axes
120
138
 
121
- def oseries(self, names=None, ax=None, split=False, figsize=(10, 5), **kwargs):
139
+ def oseries(
140
+ self,
141
+ names=None,
142
+ ax=None,
143
+ split=False,
144
+ figsize=(10, 5),
145
+ show_legend=True,
146
+ labelfunc=None,
147
+ legend_kwargs=None,
148
+ **kwargs,
149
+ ):
122
150
  """Plot oseries.
123
151
 
124
152
  Parameters
@@ -134,6 +162,13 @@ class Plots:
134
162
  A maximum of 20 time series is supported when split=True.
135
163
  figsize : tuple, optional
136
164
  figure size, by default (10, 5)
165
+ show_legend : bool, optional
166
+ show legend, default is True.
167
+ labelfunc : callable, optional
168
+ function to create custom labels, function should take name of time series
169
+ as input
170
+ legend_kwargs : dict, optional
171
+ additional arguments to pass to legend
137
172
 
138
173
  Returns
139
174
  -------
@@ -146,6 +181,9 @@ class Plots:
146
181
  ax=ax,
147
182
  split=split,
148
183
  figsize=figsize,
184
+ show_legend=show_legend,
185
+ labelfunc=labelfunc,
186
+ legend_kwargs=legend_kwargs,
149
187
  **kwargs,
150
188
  )
151
189
 
@@ -156,6 +194,9 @@ class Plots:
156
194
  ax=None,
157
195
  split=False,
158
196
  figsize=(10, 5),
197
+ show_legend=True,
198
+ labelfunc=None,
199
+ legend_kwargs=None,
159
200
  **kwargs,
160
201
  ):
161
202
  """Plot stresses.
@@ -176,6 +217,13 @@ class Plots:
176
217
  A maximum of 20 time series is supported when split=True.
177
218
  figsize : tuple, optional
178
219
  figure size, by default (10, 5)
220
+ show_legend : bool, optional
221
+ show legend, default is True.
222
+ labelfunc : callable, optional
223
+ function to create custom labels, function should take name of time series
224
+ as input
225
+ legend_kwargs : dict, optional
226
+ additional arguments to pass to legend
179
227
 
180
228
  Returns
181
229
  -------
@@ -196,6 +244,9 @@ class Plots:
196
244
  ax=ax,
197
245
  split=split,
198
246
  figsize=figsize,
247
+ show_legend=show_legend,
248
+ labelfunc=labelfunc,
249
+ legend_kwargs=legend_kwargs,
199
250
  **kwargs,
200
251
  )
201
252
 
@@ -456,7 +507,6 @@ class Plots:
456
507
  ax : matplotlib Axes
457
508
  The axes in which the cumulative histogram is plotted
458
509
  """
459
-
460
510
  statsdf = self.pstore.get_statistics(
461
511
  [statistic], modelnames=modelnames, progressbar=False
462
512
  )
@@ -511,6 +561,39 @@ class Plots:
511
561
 
512
562
  return ax
513
563
 
564
def compare_models(self, modelnames, ax=None, **kwargs):
    """Compare multiple models and plot the results.

    Uses ``pastas.CompareModels`` to plot the models side by side. When all
    models share the same oseries, that oseries name is removed from the
    model names so the legend labels only show what distinguishes the models.

    Parameters
    ----------
    modelnames : list of str
        A list of model names to compare. A single model name (str) is
        accepted as well and treated as a one-element list.
    ax : matplotlib.axes.Axes, optional
        Currently ignored: CompareModels creates its own figure. Kept for
        backward compatibility of the signature.
    **kwargs : dict
        Additional keyword arguments passed to ``CompareModels.plot``.

    Returns
    -------
    cm : pastas.CompareModels
        The CompareModels object containing the comparison results.
    """
    if isinstance(modelnames, str):
        modelnames = [modelnames]
    else:
        modelnames = list(modelnames)

    models = self.pstore.get_models(modelnames)
    # get_models may return a bare Model (not a list) when only one name is
    # requested; normalize so the iteration below always works.
    if not isinstance(models, (list, tuple)):
        models = [models]

    onames = {iml.oseries.name for iml in models}
    if len(onames) == 1:
        (oname,) = onames
        names = []
        for modelname in modelnames:
            label = modelname
            if oname in modelname:
                # Strip the shared oseries name and any leftover separators,
                # e.g. "oseries_recharge" -> "recharge". A label starting
                # with "_" would be hidden from the matplotlib legend.
                label = modelname.replace(oname, "").strip("_- ")
            # Fall back to the full model name if nothing meaningful remains.
            names.append(label if label else modelname)
    else:
        names = modelnames

    cm = ps.CompareModels(models, names=names)
    cm.plot(**kwargs)
    return cm
596
+
514
597
 
515
598
  class Maps:
516
599
  """Map Class for PastaStore.
@@ -539,10 +622,12 @@ class Maps:
539
622
  self,
540
623
  names=None,
541
624
  kind=None,
625
+ extent=None,
542
626
  labels=True,
543
627
  adjust=False,
544
628
  figsize=(10, 8),
545
629
  backgroundmap=False,
630
+ label_kwargs=None,
546
631
  **kwargs,
547
632
  ):
548
633
  """Plot stresses locations on map.
@@ -554,6 +639,8 @@ class Maps:
554
639
  kind: str, optional
555
640
  if passed, only plot stresses of a specific kind, default is None
556
641
  which plots all stresses.
642
+ extent : list of float, optional
643
+ plot only stresses within extent [xmin, xmax, ymin, ymax]
557
644
  labels: bool, optional
558
645
  label models, by default True
559
646
  adjust: bool, optional
@@ -565,20 +652,22 @@ class Maps:
565
652
  backgroundmap: bool, optional
566
653
  if True, add background map (default CRS is EPSG:28992) with default tiles
567
654
  by OpenStreetMap.Mapnik. Default option is False.
655
+ label_kwargs: dict, optional
656
+ dictionary with keyword arguments to pass to add_labels method
568
657
 
569
658
  Returns
570
659
  -------
571
660
  ax: matplotlib.Axes
572
661
  axes object
573
662
 
574
- See also
663
+ See Also
575
664
  --------
576
665
  self.add_background_map
577
666
  """
578
- if names is not None:
579
- df = self.pstore.stresses.loc[names]
580
- else:
581
- df = self.pstore.stresses
667
+ names = self.pstore.conn._parse_names(names, "stresses")
668
+ if extent is not None:
669
+ names = self.pstore.within(extent, names=names, libname="stresses")
670
+ df = self.pstore.stresses.loc[names]
582
671
 
583
672
  if kind is not None:
584
673
  if isinstance(kind, str):
@@ -603,7 +692,9 @@ class Maps:
603
692
  else:
604
693
  ax = r
605
694
  if labels:
606
- self.add_labels(stresses, ax, adjust=adjust)
695
+ if label_kwargs is None:
696
+ label_kwargs = {}
697
+ self.add_labels(stresses, ax, adjust=adjust, **label_kwargs)
607
698
 
608
699
  if backgroundmap:
609
700
  self.add_background_map(ax)
@@ -613,10 +704,12 @@ class Maps:
613
704
  def oseries(
614
705
  self,
615
706
  names=None,
707
+ extent=None,
616
708
  labels=True,
617
709
  adjust=False,
618
710
  figsize=(10, 8),
619
711
  backgroundmap=False,
712
+ label_kwargs=None,
620
713
  **kwargs,
621
714
  ):
622
715
  """Plot oseries locations on map.
@@ -625,8 +718,11 @@ class Maps:
625
718
  ----------
626
719
  names: list, optional
627
720
  oseries names, by default None which plots all oseries locations
628
- labels: bool, optional
629
- label models, by default True
721
+ extent : list of float, optional
722
+ plot only oseries within extent [xmin, xmax, ymin, ymax]
723
+ labels: bool or str, optional
724
+ label models, by default True, if passed as "grouped", only the first
725
+ label for each x,y-location is shown.
630
726
  adjust: bool, optional
631
727
  automated smart label placement using adjustText, by default False
632
728
  figsize: tuple, optional
@@ -634,18 +730,21 @@ class Maps:
634
730
  backgroundmap: bool, optional
635
731
  if True, add background map (default CRS is EPSG:28992) with default tiles
636
732
  by OpenStreetMap.Mapnik. Default option is False.
733
+ label_kwargs: dict, optional
734
+ dictionary with keyword arguments to pass to add_labels method
637
735
 
638
736
  Returns
639
737
  -------
640
738
  ax: matplotlib.Axes
641
739
  axes object
642
740
 
643
- See also
741
+ See Also
644
742
  --------
645
743
  self.add_background_map
646
744
  """
647
-
648
745
  names = self.pstore.conn._parse_names(names, "oseries")
746
+ if extent is not None:
747
+ names = self.pstore.within(extent, names=names)
649
748
  oseries = self.pstore.oseries.loc[names]
650
749
  mask0 = (oseries["x"] != 0.0) | (oseries["y"] != 0.0)
651
750
  r = self._plotmap_dataframe(oseries.loc[mask0], figsize=figsize, **kwargs)
@@ -654,7 +753,12 @@ class Maps:
654
753
  else:
655
754
  ax = r
656
755
  if labels:
657
- self.add_labels(oseries, ax, adjust=adjust)
756
+ if label_kwargs is None:
757
+ label_kwargs = {}
758
+ if labels == "grouped":
759
+ gr = oseries.sort_index().reset_index().groupby(["x", "y"])
760
+ oseries = oseries.loc[gr["index"].first().tolist()]
761
+ self.add_labels(oseries, ax, adjust=adjust, **label_kwargs)
658
762
 
659
763
  if backgroundmap:
660
764
  self.add_background_map(ax)
@@ -685,11 +789,10 @@ class Maps:
685
789
  ax: matplotlib.Axes
686
790
  axes object
687
791
 
688
- See also
792
+ See Also
689
793
  --------
690
794
  self.add_background_map
691
795
  """
692
-
693
796
  model_oseries = [
694
797
  self.pstore.get_models(m, return_dict=True)["oseries"]["name"]
695
798
  for m in self.pstore.model_names
@@ -760,7 +863,7 @@ class Maps:
760
863
  ax: matplotlib.Axes
761
864
  axes object
762
865
 
763
- See also
866
+ See Also
764
867
  --------
765
868
  self.add_background_map
766
869
  """
@@ -809,7 +912,7 @@ class Maps:
809
912
  figsize=(10, 8),
810
913
  **kwargs,
811
914
  ):
812
- """Internal method for plotting dataframe with point locations.
915
+ """Plot dataframe with point locations (internal method).
813
916
 
814
917
  Can be called directly for more control over plot characteristics.
815
918
 
@@ -843,7 +946,6 @@ class Maps:
843
946
  sc : scatter handle
844
947
  scatter plot handle, returned if ax is not None
845
948
  """
846
-
847
949
  if ax is None:
848
950
  return_scatter = False
849
951
  fig, ax = plt.subplots(figsize=figsize)
@@ -927,7 +1029,7 @@ class Maps:
927
1029
  ax: axes object
928
1030
  axis handle of the resulting figure
929
1031
 
930
- See also
1032
+ See Also
931
1033
  --------
932
1034
  self.add_background_map
933
1035
  """
@@ -1013,7 +1115,7 @@ class Maps:
1013
1115
  uniques = stresses.loc[:, ["stressmodel", "color"]].drop_duplicates(
1014
1116
  keep="first"
1015
1117
  )
1016
- for name, row in uniques.iterrows():
1118
+ for _, row in uniques.iterrows():
1017
1119
  (h,) = ax.plot(
1018
1120
  [],
1019
1121
  [],
@@ -1112,7 +1214,7 @@ class Maps:
1112
1214
  ax: axes object
1113
1215
  axis handle of the resulting figure
1114
1216
 
1115
- See also
1217
+ See Also
1116
1218
  --------
1117
1219
  self.add_background_map
1118
1220
  """
@@ -1258,7 +1360,7 @@ class Maps:
1258
1360
  ctx.add_basemap(ax, source=providers[map_provider], crs=proj.srs, **kwargs)
1259
1361
 
1260
1362
  @staticmethod
1261
- def add_labels(df, ax, adjust=False, **kwargs):
1363
+ def add_labels(df, ax, adjust=False, objects=None, **kwargs):
1262
1364
  """Add labels to points on plot.
1263
1365
 
1264
1366
  Uses dataframe index to label points.
@@ -1271,11 +1373,12 @@ class Maps:
1271
1373
  axes object to label points on
1272
1374
  adjust: bool
1273
1375
  automated smart label placement using adjustText
1376
+ objects : list of matplotlib objects
1377
+ use to avoid labels overlapping markers
1274
1378
  **kwargs:
1275
- keyword arguments to ax.annotate
1379
+ keyword arguments to ax.annotate or adjusttext
1276
1380
  """
1277
1381
  stroke = [patheffects.withStroke(linewidth=3, foreground="w")]
1278
-
1279
1382
  fontsize = kwargs.pop("fontsize", 10)
1280
1383
 
1281
1384
  if adjust:
@@ -1295,7 +1398,9 @@ class Maps:
1295
1398
 
1296
1399
  adjust_text(
1297
1400
  texts,
1298
- force_text=0.05,
1401
+ objects=objects,
1402
+ force_text=(0.05, 0.10),
1403
+ **kwargs,
1299
1404
  **{
1300
1405
  "arrowprops": {
1301
1406
  "arrowstyle": "-",
@@ -1318,4 +1423,5 @@ class Maps:
1318
1423
  textcoords=textcoords,
1319
1424
  xytext=xytext,
1320
1425
  **{"path_effects": stroke},
1426
+ **kwargs,
1321
1427
  )