tilupy 2.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tilupy/plot.py ADDED
@@ -0,0 +1,234 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ import numpy as np
5
+ import matplotlib
6
+ import matplotlib.pyplot as plt
7
+ import seaborn as sns
8
+ import pytopomap.plot as pyplt
9
+
10
def plot_shotgather(x: np.ndarray,
                    t: np.ndarray,
                    data: np.ndarray,
                    xlabel: str = "X (m)",
                    ylabel: str = "Time (s)",
                    **kwargs
                    ) -> matplotlib.axes._axes.Axes:
    """Plot shotgather image.

    Plot a shotgather-like image, with the vertical axis as time and the
    horizontal axis as the spatial dimension. This is a simple call to
    :func:`pytopomap.plot.plot_imshow`, but the input data is transposed
    because in tilupy the last axis is time by convention.

    Parameters
    ----------
    x : numpy.ndarray
        Spatial coordinates, size NX.
    t : numpy.ndarray
        Time array (assumed in seconds), size NT.
    data : numpy.ndarray
        NX*NT array of data to be plotted.
    xlabel : string, optional
        Label for x-axis, by default "X (m)".
    ylabel : string, optional
        Label for y-axis, by default "Time (s)".
    **kwargs : dict, optional
        Parameters passed on to :func:`pytopomap.plot.plot_imshow`.
        ``aspect`` defaults to ``"auto"`` if not given.

    Returns
    -------
    matplotlib.axes._axes.Axes
        Axes instance where data is plotted.
    """
    # Let the image fill the axes unless the caller asked for another aspect.
    kwargs.setdefault("aspect", "auto")
    # Transpose NX*NT -> NT*NX so that image rows are time samples; the time
    # vector is reversed to match. NOTE(review): presumably this makes time
    # increase downwards, depending on how plot_imshow orients its y-axis —
    # confirm against pytopomap.
    axe = pyplt.plot_imshow(x, t[::-1], data.T, **kwargs)
    axe.set_adjustable("box")
    axe.set_ylabel(ylabel)
    axe.set_xlabel(xlabel)

    return axe
52
+
53
+
54
def plot_heatmaps(df,
                  values,
                  index,
                  columns,
                  aggfunc="mean",
                  figsize=None,
                  ncols=3,
                  heatmap_kws=None,
                  notations=None,
                  best_values=None,
                  plot_best_value="point",
                  text_kwargs=None,
                  ) -> matplotlib.figure.Figure:
    """Plot one or several heatmaps from a pandas DataFrame.

    Each heatmap is created by pivoting the DataFrame with the given
    `index`, `columns`, and a variable from `values`. Heatmaps are laid out
    on a grid with at most `ncols` columns; the last row may be incomplete.

    Parameters
    ----------
    df : pandas.DataFrame
        Input DataFrame containing the data.
    values : list[str]
        Column names in `df` to plot as separate heatmaps.
    index : str
        Column name to use as rows of the pivot table.
    columns : str
        Column name to use as columns of the pivot table.
    aggfunc : str or callable, optional
        Aggregation function applied when multiple values exist for
        a given (index, column) pair. By default "mean".
    figsize : tuple of float, optional
        Size of the matplotlib figure, by default None.
    ncols : int, optional
        Maximum number of heatmaps per row, by default 3.
    heatmap_kws : dict or dict[dict], optional
        Keyword arguments passed to :func:`seaborn.heatmap`.
        If dict of dict, keys must match the values in `values`.
        The given dictionaries are copied, never modified.
        If no ``cmap`` is given and the data spans both signs, a symmetric
        "seismic" colormap is used.
    notations : dict, optional
        Mapping from variable names to readable labels
        (used for axis and colorbar labels). Missing names fall back to
        the raw variable name.
    best_values : dict, optional
        Mapping from variable names to selection criterion:
        "min", "min_abs", or "max". Variables absent from the mapping are
        plotted without highlighting.
    plot_best_value : {"point", "text"}, optional
        How to highlight best values:
        - "point" : mark with circles
        - "text" : display numeric values
        By default "point".
    text_kwargs : dict, optional
        Keyword arguments passed to :meth:`matplotlib.axes.Axes.text` when
        annotating best values. Only used if ``plot_best_value="text"``.

    Returns
    -------
    matplotlib.figure.Figure
        The matplotlib Figure containing the heatmaps.

    Raises
    ------
    ValueError
        If a criterion in `best_values` is not "min", "min_abs" or "max".
    """
    nplots = len(values)
    # max(1, ...) avoids a division by zero when `values` is empty.
    ncols = max(1, min(nplots, ncols))
    nrows = int(np.ceil(nplots / ncols))
    fig = plt.figure(figsize=figsize)

    default_text_kwargs = dict(ha="center", va="center", fontsize=8)
    if text_kwargs is None:
        text_kwargs = default_text_kwargs
    else:
        text_kwargs = dict(default_text_kwargs, **text_kwargs)

    axes = []
    for iplot, var in enumerate(values):
        axe = fig.add_subplot(nrows, ncols, iplot + 1)
        axes.append(axe)
        data = df.pivot_table(
            index=index, columns=columns, values=var, aggfunc=aggfunc
        ).astype(float)
        kws = _heatmap_kwargs(data, var, heatmap_kws, notations)
        sns.heatmap(data, ax=axe, **kws)

        if best_values is None or var not in best_values:
            continue
        array = np.array(data)
        ind, i2 = _best_value_indices(array, best_values[var])
        if plot_best_value == "point":
            _plot_best_points(axe, ind, i2)
        elif plot_best_value == "text":
            _annotate_best_values(axe, array, ind, i2, text_kwargs)

    _tidy_axis_labels(axes, nplots, ncols, index, columns, notations)

    return fig


def _heatmap_kwargs(data, var, heatmap_kws, notations):
    """Build a fresh seaborn.heatmap kwargs dict for variable `var`."""
    if heatmap_kws is None:
        kws = {}
    elif var in heatmap_kws:
        kws = dict(heatmap_kws[var])
    else:
        # Same kwargs shared by every variable: copy so that values computed
        # below for one variable do not leak into the next one.
        kws = dict(heatmap_kws)

    if "cmap" not in kws:
        minval = data.min().min()
        maxval = data.max().max()
        if minval * maxval < 0:
            # Data spans both signs: centre a diverging colormap on 0.
            val = max(np.abs(minval), maxval)
            kws["cmap"] = "seismic"
            kws["vmin"] = -val
            kws["vmax"] = val

    if "cbar_kws" in kws:
        cbar_kws = dict(kws["cbar_kws"])
    else:
        cbar_kws = dict(pad=0.03)
    if notations is None:
        label = var
    else:
        label = notations.get(var, var)
    cbar_kws.setdefault("label", label)
    kws["cbar_kws"] = cbar_kws
    return kws


def _best_value_indices(array, criterion):
    """Return (best column per row, row holding the overall best value)."""
    if criterion == "min":
        scores, pick = array, np.nanargmin
    elif criterion == "min_abs":
        scores, pick = np.abs(array), np.nanargmin
    elif criterion == "max":
        scores, pick = array, np.nanargmax
    else:
        raise ValueError(
            "Unknown best value criterion {!r}: expected 'min', "
            "'min_abs' or 'max'".format(criterion)
        )
    irow = np.arange(scores.shape[0])
    ind = pick(scores, axis=1)
    i2 = pick(scores[irow, ind])
    return ind, i2


def _plot_best_points(axe, ind, i2):
    """Mark the best cell of each row, and the overall best with a bigger circle."""
    irow = np.arange(len(ind))
    # +0.5 puts markers at cell centres in seaborn heatmap coordinates.
    axe.plot(ind + 0.5, irow + 0.5, ls="", marker="o",
             mfc="w", mec="k", mew=0.5, ms=5)
    axe.plot(ind[i2] + 0.5, i2 + 0.5, ls="", marker="o",
             mfc="w", mec="k", mew=0.8, ms=9)


def _annotate_best_values(axe, array, ind, i2, text_kwargs):
    """Write the best value of each row, the overall best one in bold."""
    for row, col in enumerate(ind):
        if row == i2:
            continue
        axe.text(col + 0.5, row + 0.5,
                 "{:.2g}".format(array[row, col]), **text_kwargs)
    axe.text(ind[i2] + 0.5, i2 + 0.5,
             "{:.2g}".format(array[i2, ind[i2]]),
             **dict(text_kwargs, fontweight="bold"))


def _tidy_axis_labels(axes, nplots, ncols, index, columns, notations):
    """Keep y labels on the first column and x labels on bottom-most plots."""
    for k, axe in enumerate(axes):
        first_col = k % ncols == 0
        # A plot is "bottom-most" when no subplot sits below it, which also
        # covers plots above the empty slots of an incomplete last row.
        bottom = k + ncols >= nplots
        if not first_col:
            axe.set_ylabel("")
        elif notations is not None:
            axe.set_ylabel(notations.get(index, index))
        if not bottom:
            axe.set_xlabel("")
        elif notations is not None:
            axe.set_xlabel(notations.get(columns, columns))
tilupy/raster.py ADDED
@@ -0,0 +1,199 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ import numpy as np
5
+ import importlib
6
+
7
+
8
def read_raster(file: str) -> list[np.ndarray]:
    """Convert a raster file (tif or asc) into numpy arrays.

    The format is chosen from the file extension (case-insensitive):
    ".asc"/".txt" are read with :func:`read_ascii`, ".tif"/".tiff" with
    :func:`read_tiff`.

    Parameters
    ----------
    file : str
        Path to the raster file.

    Returns
    -------
    list[np.ndarray]
        X and Y coordinates and data values, as returned by the reader.

    Raises
    ------
    ValueError
        If the file extension is not a supported raster format.
    """
    name = str(file).lower()
    if name.endswith((".asc", ".txt")):
        return read_ascii(file)
    if name.endswith((".tif", ".tiff")):
        return read_tiff(file)
    raise ValueError(
        "Unsupported raster format for {}: expected .asc, .txt, .tif "
        "or .tiff".format(file)
    )
25
+
26
+
27
def read_tiff(file: str) -> list[np.ndarray]:
    """Read and convert a tiff file into numpy arrays.

    Only the first band is read. Returned coordinates are cell centres,
    spaced by the raster resolution, which matches the convention used by
    :func:`write_tiff` (left edge = ``x[0] - res / 2``).

    Parameters
    ----------
    file : str
        Path to the tiff file.

    Returns
    -------
    list[np.ndarray]
        X and Y cell-centre coordinates (both increasing) and data values.
        Row 0 of the data is the first row stored in the file.
    """
    import rasterio

    with rasterio.open(file, "r") as src:
        dem = src.read(1)
        bounds = src.bounds
        xres, yres = src.res
    ny, nx = dem.shape
    # Cell centres: half a pixel inside the outer bounds. The previous
    # linspace(left, right, nx) gave edge coordinates and a spacing of
    # (right - left) / (nx - 1), which differs from the pixel size.
    x = bounds.left + xres * (np.arange(nx) + 0.5)
    y = bounds.bottom + yres * (np.arange(ny) + 0.5)
    return x, y, dem
48
+
49
+
50
def read_ascii(file: str) -> list[np.ndarray]:
    """Read and convert an ESRI ascii grid file into numpy arrays.

    The header is made of "key value" lines (keys are case-insensitive)
    and ends at the first line whose first token is a number. This handles
    both 5-line headers (without NODATA_value) and 6-line headers.

    Parameters
    ----------
    file : str
        Path to the ascii file.

    Returns
    -------
    list[np.ndarray]
        X and Y coordinates (increasing, spaced by ``cellsize``) and data
        values, in the row order of the file.

    Raises
    ------
    KeyError
        If a required header entry (ncols, nrows, cellsize and either
        xllcenter/yllcenter or xllcorner/yllcorner) is missing.
    """
    grid = {}
    nheader = 0
    with open(file, "r") as fid:
        for line in fid:
            tokens = line.split()
            if not tokens:
                nheader += 1
                continue
            try:
                float(tokens[0])
            except ValueError:
                # Non-numeric first token: still in the header.
                grid[tokens[0].lower()] = float(tokens[1])
                nheader += 1
            else:
                break
    dem = np.loadtxt(file, skiprows=nheader)

    if "xllcenter" in grid and "yllcenter" in grid:
        x0 = grid["xllcenter"]
        y0 = grid["yllcenter"]
    else:
        # NOTE(review): the corner is used as the first coordinate with no
        # half-cell shift, which matches write_ascii but not the strict ESRI
        # convention (centres are cellsize / 2 away from the corner).
        x0 = grid["xllcorner"]
        y0 = grid["yllcorner"]
    nx = int(grid["ncols"])
    ny = int(grid["nrows"])
    dx = dy = grid["cellsize"]
    x = np.linspace(x0, x0 + (nx - 1) * dx, nx)
    y = np.linspace(y0, y0 + (ny - 1) * dy, ny)

    return x, y, dem
82
+
83
+
84
def write_tiff(x: np.ndarray,
               y: np.ndarray,
               z: np.ndarray,
               file_out: str,
               **kwargs
               ) -> None:
    """Write tif file from numpy array.

    `x` and `y` are treated as increasing cell-centre coordinates on a
    square grid whose resolution is derived from `x`. Row 0 of `z` is
    written as the top row of the raster, i.e. it is placed at ``y[-1]``.

    Parameters
    ----------
    x : np.ndarray
        X coordinates.
    y : np.ndarray
        Y coordinates.
    z : np.ndarray
        Elevation values, shape (len(y), len(x)).
    file_out : str
        Path of the output file.
    **kwargs : dict, optional
        Extra arguments passed to :func:`rasterio.open` (``driver`` defaults
        to "GTiff").
    """
    import rasterio
    from rasterio.transform import Affine

    kwargs.setdefault("driver", "GTiff")
    res = (x[-1] - x[0]) / (len(x) - 1)
    # Upper-left corner of the upper-left cell: half a cell left of x[0] and
    # half a cell ABOVE the top cell centre y[-1] (rows go downwards, hence
    # the -res y scale).
    transform = (Affine.translation(x[0] - res / 2, y[-1] + res / 2)
                 * Affine.scale(res, -res))

    with rasterio.open(file_out,
                       "w",
                       height=z.shape[0],
                       width=z.shape[1],
                       count=1,
                       dtype=z.dtype,
                       transform=transform,
                       **kwargs) as dst:
        dst.write(z, 1)
120
+
121
+
122
def write_ascii(x: np.ndarray,
                y: np.ndarray,
                z: np.ndarray,
                file_out: str,
                nodata_value: float = -99999,
                ) -> None:
    """Write ESRI ascii grid file from numpy array.

    ``x[0]`` and ``y[0]`` are written as xllcorner / yllcorner (the same
    convention read back by :func:`read_ascii`), and the cell size is taken
    from the spacing of `x`. Rows of `z` are written in order.

    Parameters
    ----------
    x : np.ndarray
        X coordinates.
    y : np.ndarray
        Y coordinates.
    z : np.ndarray
        Elevation values, shape (len(y), len(x)).
    file_out : str
        Path of the output file.
    nodata_value : float, optional
        Value written in the NODATA_value header entry, by default -99999.
        `z` itself is not modified.
    """
    ny, nx = z.shape[0], z.shape[1]
    cellsize = x[1] - x[0]
    # ".15g" keeps full precision: fixed 4-5 decimals would truncate small
    # cell sizes (e.g. geographic grids in degrees).
    header_txt = (
        "ncols {:d}\nnrows {:d}\nxllcorner {:.15g}\nyllcorner {:.15g}\n"
        "cellsize {:.15g}\nnodata_value {:g}"
    ).format(int(nx), int(ny), float(x[0]), float(y[0]), float(cellsize),
             nodata_value)
    np.savetxt(file_out, z, header=header_txt, comments="")
147
+
148
+
149
def write_raster(x: np.ndarray,
                 y: np.ndarray,
                 z: np.ndarray,
                 file_out: str,
                 fmt: str = None,
                 **kwargs
                 ) -> None:
    """Write raster file from numpy array.

    The format is taken from the extension of `file_out` when it has one
    (case-insensitive); otherwise `fmt` is used (default "asc") and appended
    to `file_out`. If a tif is requested but rasterio is not installed, a
    warning is emitted and an ".asc" file is written instead.

    Parameters
    ----------
    x : np.ndarray
        X coordinates.
    y : np.ndarray
        Y coordinates.
    z : np.ndarray
        Elevation values.
    file_out : str
        Path of the output file.
    fmt : str
        Wanted format : "asc", "ascii", "txt", "tif", "tiff".
    **kwargs : dict, optional
        Passed on to :func:`write_tiff` (ignored for ascii output).

    Raises
    ------
    ValueError
        If invalid format.
    """
    import importlib.util
    import os
    import warnings

    # File format read from file_out overrides fmt. splitext only looks at
    # the last path component, so "./dir/dem" has no extension.
    root, ext = os.path.splitext(file_out)
    if ext:
        fmt = ext[1:].lower()
    else:
        if fmt is None:
            fmt = "asc"
        file_out = file_out + "." + fmt
        fmt = fmt.lower()

    if fmt not in ("asc", "ascii", "txt", "tif", "tiff"):
        raise ValueError("File format not implemented in write_raster")

    if fmt in ("tif", "tiff") and importlib.util.find_spec("rasterio") is None:
        # Do not write ascii content into a file with a tif extension.
        file_out = os.path.splitext(file_out)[0] + ".asc"
        warnings.warn("rasterio is required to write tif files. "
                      "Switching to asc format: writing {}".format(file_out))
        fmt = "asc"

    if fmt in ("tif", "tiff"):
        write_tiff(x, y, z, file_out, **kwargs)
    else:
        write_ascii(x, y, z, file_out)