ins-pricing 0.4.5__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93)
  1. ins_pricing/README.md +48 -22
  2. ins_pricing/__init__.py +142 -90
  3. ins_pricing/cli/BayesOpt_entry.py +58 -46
  4. ins_pricing/cli/BayesOpt_incremental.py +77 -110
  5. ins_pricing/cli/Explain_Run.py +42 -23
  6. ins_pricing/cli/Explain_entry.py +551 -577
  7. ins_pricing/cli/Pricing_Run.py +42 -23
  8. ins_pricing/cli/bayesopt_entry_runner.py +51 -16
  9. ins_pricing/cli/utils/bootstrap.py +23 -0
  10. ins_pricing/cli/utils/cli_common.py +256 -256
  11. ins_pricing/cli/utils/cli_config.py +379 -360
  12. ins_pricing/cli/utils/import_resolver.py +375 -358
  13. ins_pricing/cli/utils/notebook_utils.py +256 -242
  14. ins_pricing/cli/watchdog_run.py +216 -198
  15. ins_pricing/frontend/__init__.py +10 -10
  16. ins_pricing/frontend/app.py +132 -61
  17. ins_pricing/frontend/config_builder.py +33 -0
  18. ins_pricing/frontend/example_config.json +11 -0
  19. ins_pricing/frontend/example_workflows.py +1 -1
  20. ins_pricing/frontend/runner.py +340 -388
  21. ins_pricing/governance/__init__.py +20 -20
  22. ins_pricing/governance/release.py +159 -159
  23. ins_pricing/modelling/README.md +1 -1
  24. ins_pricing/modelling/__init__.py +147 -92
  25. ins_pricing/modelling/{core/bayesopt → bayesopt}/README.md +31 -13
  26. ins_pricing/modelling/{core/bayesopt → bayesopt}/__init__.py +64 -102
  27. ins_pricing/modelling/{core/bayesopt → bayesopt}/config_components.py +12 -0
  28. ins_pricing/modelling/{core/bayesopt → bayesopt}/config_preprocess.py +589 -552
  29. ins_pricing/modelling/{core/bayesopt → bayesopt}/core.py +987 -958
  30. ins_pricing/modelling/{core/bayesopt → bayesopt}/model_explain_mixin.py +296 -296
  31. ins_pricing/modelling/{core/bayesopt → bayesopt}/model_plotting_mixin.py +488 -548
  32. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/__init__.py +27 -27
  33. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_components.py +349 -342
  34. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_ft_trainer.py +921 -913
  35. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_gnn.py +794 -785
  36. ins_pricing/modelling/{core/bayesopt → bayesopt}/models/model_resn.py +454 -446
  37. ins_pricing/modelling/bayesopt/trainers/__init__.py +19 -0
  38. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_base.py +1294 -1282
  39. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_ft.py +64 -56
  40. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_glm.py +203 -198
  41. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_gnn.py +333 -325
  42. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_resn.py +279 -267
  43. ins_pricing/modelling/{core/bayesopt → bayesopt}/trainers/trainer_xgb.py +515 -313
  44. ins_pricing/modelling/bayesopt/utils/__init__.py +67 -0
  45. ins_pricing/modelling/bayesopt/utils/constants.py +21 -0
  46. ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/distributed_utils.py +193 -186
  47. ins_pricing/modelling/bayesopt/utils/io_utils.py +7 -0
  48. ins_pricing/modelling/bayesopt/utils/losses.py +27 -0
  49. ins_pricing/modelling/bayesopt/utils/metrics_and_devices.py +17 -0
  50. ins_pricing/modelling/{core/bayesopt → bayesopt}/utils/torch_trainer_mixin.py +636 -623
  51. ins_pricing/modelling/{core/evaluation.py → evaluation.py} +113 -104
  52. ins_pricing/modelling/explain/__init__.py +55 -55
  53. ins_pricing/modelling/explain/metrics.py +27 -174
  54. ins_pricing/modelling/explain/permutation.py +237 -237
  55. ins_pricing/modelling/plotting/__init__.py +40 -36
  56. ins_pricing/modelling/plotting/compat.py +228 -0
  57. ins_pricing/modelling/plotting/curves.py +572 -572
  58. ins_pricing/modelling/plotting/diagnostics.py +163 -163
  59. ins_pricing/modelling/plotting/geo.py +362 -362
  60. ins_pricing/modelling/plotting/importance.py +121 -121
  61. ins_pricing/pricing/__init__.py +27 -27
  62. ins_pricing/pricing/factors.py +67 -56
  63. ins_pricing/production/__init__.py +35 -25
  64. ins_pricing/production/{predict.py → inference.py} +140 -57
  65. ins_pricing/production/monitoring.py +8 -21
  66. ins_pricing/reporting/__init__.py +11 -11
  67. ins_pricing/setup.py +1 -1
  68. ins_pricing/tests/production/test_inference.py +90 -0
  69. ins_pricing/utils/__init__.py +112 -78
  70. ins_pricing/utils/device.py +258 -237
  71. ins_pricing/utils/features.py +53 -0
  72. ins_pricing/utils/io.py +72 -0
  73. ins_pricing/utils/logging.py +34 -1
  74. ins_pricing/{modelling/core/bayesopt/utils → utils}/losses.py +125 -129
  75. ins_pricing/utils/metrics.py +158 -24
  76. ins_pricing/utils/numerics.py +76 -0
  77. ins_pricing/utils/paths.py +9 -1
  78. ins_pricing/utils/profiling.py +8 -4
  79. {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.1.dist-info}/METADATA +1 -1
  80. ins_pricing-0.5.1.dist-info/RECORD +132 -0
  81. ins_pricing/modelling/core/BayesOpt.py +0 -146
  82. ins_pricing/modelling/core/__init__.py +0 -1
  83. ins_pricing/modelling/core/bayesopt/trainers/__init__.py +0 -19
  84. ins_pricing/modelling/core/bayesopt/utils/__init__.py +0 -86
  85. ins_pricing/modelling/core/bayesopt/utils/constants.py +0 -183
  86. ins_pricing/modelling/core/bayesopt/utils/io_utils.py +0 -126
  87. ins_pricing/modelling/core/bayesopt/utils/metrics_and_devices.py +0 -555
  88. ins_pricing/modelling/core/bayesopt/utils.py +0 -105
  89. ins_pricing/modelling/core/bayesopt/utils_backup.py +0 -1503
  90. ins_pricing/tests/production/test_predict.py +0 -233
  91. ins_pricing-0.4.5.dist-info/RECORD +0 -130
  92. {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.1.dist-info}/WHEEL +0 -0
  93. {ins_pricing-0.4.5.dist-info → ins_pricing-0.5.1.dist-info}/top_level.txt +0 -0
@@ -1,362 +1,362 @@
1
- from __future__ import annotations
2
-
3
- from typing import Optional, Sequence, Tuple
4
-
5
- import numpy as np
6
- import pandas as pd
7
- import matplotlib.tri as mtri
8
-
9
- from .common import EPS, PlotStyle, finalize_figure, plt
10
-
11
- try: # optional map basemap support
12
- import contextily as cx
13
- except Exception: # pragma: no cover - optional dependency
14
- cx = None
15
-
16
-
17
- _MERCATOR_MAX_LAT = 85.05112878
18
- _MERCATOR_FACTOR = 20037508.34
19
-
20
-
21
- def _require_contextily(func_name: str) -> None:
22
- if cx is None:
23
- raise RuntimeError(
24
- f"{func_name} requires contextily. Install it via 'pip install contextily'."
25
- )
26
-
27
-
28
- def _lonlat_to_mercator(lon: np.ndarray, lat: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
29
- lon = np.asarray(lon, dtype=float)
30
- lat = np.asarray(lat, dtype=float)
31
- lat = np.clip(lat, -_MERCATOR_MAX_LAT, _MERCATOR_MAX_LAT)
32
- x = lon * _MERCATOR_FACTOR / 180.0
33
- y = np.log(np.tan((90.0 + lat) * np.pi / 360.0)) * _MERCATOR_FACTOR / np.pi
34
- return x, y
35
-
36
-
37
- def _apply_bounds(ax: plt.Axes, x: np.ndarray, y: np.ndarray, padding: float) -> None:
38
- x_min, x_max = float(np.min(x)), float(np.max(x))
39
- y_min, y_max = float(np.min(y)), float(np.max(y))
40
- pad_x = (x_max - x_min) * padding
41
- pad_y = (y_max - y_min) * padding
42
- if pad_x == 0:
43
- pad_x = 1.0
44
- if pad_y == 0:
45
- pad_y = 1.0
46
- ax.set_xlim(x_min - pad_x, x_max + pad_x)
47
- ax.set_ylim(y_min - pad_y, y_max + pad_y)
48
-
49
-
50
- def _resolve_basemap(source):
51
- if cx is None or source is None:
52
- return source
53
- if isinstance(source, str):
54
- provider = cx.providers
55
- for part in source.split("."):
56
- if isinstance(provider, dict):
57
- provider = provider[part]
58
- else:
59
- provider = getattr(provider, part)
60
- return provider
61
- return source
62
-
63
-
64
- def _sanitize_geo(
65
- df: pd.DataFrame,
66
- x_col: str,
67
- y_col: str,
68
- value_col: str,
69
- weight_col: Optional[str] = None,
70
- ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, Optional[np.ndarray]]:
71
- x = pd.to_numeric(df[x_col], errors="coerce").to_numpy(dtype=float)
72
- y = pd.to_numeric(df[y_col], errors="coerce").to_numpy(dtype=float)
73
- z = pd.to_numeric(df[value_col], errors="coerce").to_numpy(dtype=float)
74
- w = None
75
- if weight_col:
76
- w = pd.to_numeric(df[weight_col], errors="coerce").to_numpy(dtype=float)
77
-
78
- if w is None:
79
- mask = np.isfinite(x) & np.isfinite(y) & np.isfinite(z)
80
- else:
81
- mask = np.isfinite(x) & np.isfinite(y) & np.isfinite(z) & np.isfinite(w)
82
- w = w[mask]
83
- return x[mask], y[mask], z[mask], w
84
-
85
-
86
- def _downsample_points(
87
- x: np.ndarray,
88
- y: np.ndarray,
89
- z: np.ndarray,
90
- w: Optional[np.ndarray],
91
- max_points: Optional[int],
92
- ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, Optional[np.ndarray]]:
93
- if max_points is None:
94
- return x, y, z, w
95
- max_points = int(max_points)
96
- if max_points <= 0 or len(x) <= max_points:
97
- return x, y, z, w
98
- rng = np.random.default_rng(13)
99
- idx = rng.choice(len(x), size=max_points, replace=False)
100
- if w is None:
101
- return x[idx], y[idx], z[idx], None
102
- return x[idx], y[idx], z[idx], w[idx]
103
-
104
-
105
- def plot_geo_heatmap(
106
- df: pd.DataFrame,
107
- *,
108
- x_col: str,
109
- y_col: str,
110
- value_col: str,
111
- weight_col: Optional[str] = None,
112
- bins: int | Tuple[int, int] = 50,
113
- agg: str = "mean",
114
- cmap: str = "YlOrRd",
115
- title: str = "Geo Heatmap",
116
- ax: Optional[plt.Axes] = None,
117
- show: bool = False,
118
- save_path: Optional[str] = None,
119
- style: Optional[PlotStyle] = None,
120
- ) -> plt.Figure:
121
- style = style or PlotStyle()
122
- if agg not in {"mean", "sum"}:
123
- raise ValueError("agg must be 'mean' or 'sum'.")
124
- x, y, z, w = _sanitize_geo(df, x_col, y_col, value_col, weight_col)
125
-
126
- if isinstance(bins, int):
127
- bins = (bins, bins)
128
-
129
- if w is None:
130
- sum_z, x_edges, y_edges = np.histogram2d(x, y, bins=bins, weights=z)
131
- if agg == "sum":
132
- grid = sum_z
133
- else:
134
- count, _, _ = np.histogram2d(x, y, bins=bins)
135
- grid = sum_z / np.maximum(count, 1.0)
136
- else:
137
- sum_w, x_edges, y_edges = np.histogram2d(x, y, bins=bins, weights=w)
138
- sum_zw, _, _ = np.histogram2d(x, y, bins=bins, weights=z * w)
139
- grid = sum_zw / np.maximum(sum_w, EPS)
140
-
141
- created_fig = ax is None
142
- if created_fig:
143
- fig, ax = plt.subplots(figsize=style.figsize)
144
- else:
145
- fig = ax.figure
146
-
147
- im = ax.imshow(
148
- grid.T,
149
- origin="lower",
150
- extent=[x_edges[0], x_edges[-1], y_edges[0], y_edges[-1]],
151
- aspect="auto",
152
- cmap=cmap,
153
- )
154
- cbar = fig.colorbar(im, ax=ax)
155
- cbar.set_label(value_col, fontsize=style.label_size)
156
- cbar.ax.tick_params(labelsize=style.tick_size)
157
-
158
- ax.set_xlabel(x_col, fontsize=style.label_size)
159
- ax.set_ylabel(y_col, fontsize=style.label_size)
160
- ax.set_title(title, fontsize=style.title_size)
161
- ax.tick_params(axis="both", labelsize=style.tick_size)
162
-
163
- if created_fig:
164
- finalize_figure(fig, save_path=save_path, show=show, style=style)
165
-
166
- return fig
167
-
168
-
169
- def plot_geo_contour(
170
- df: pd.DataFrame,
171
- *,
172
- x_col: str,
173
- y_col: str,
174
- value_col: str,
175
- weight_col: Optional[str] = None,
176
- max_points: Optional[int] = None,
177
- levels: int | Sequence[float] = 10,
178
- cmap: str = "viridis",
179
- title: str = "Geo Contour",
180
- ax: Optional[plt.Axes] = None,
181
- show_points: bool = False,
182
- show: bool = False,
183
- save_path: Optional[str] = None,
184
- style: Optional[PlotStyle] = None,
185
- ) -> plt.Figure:
186
- style = style or PlotStyle()
187
- x, y, z, w = _sanitize_geo(df, x_col, y_col, value_col, weight_col)
188
- x, y, z, w = _downsample_points(x, y, z, w, max_points)
189
-
190
- if w is not None:
191
- z = z * w
192
-
193
- triang = mtri.Triangulation(x, y)
194
-
195
- created_fig = ax is None
196
- if created_fig:
197
- fig, ax = plt.subplots(figsize=style.figsize)
198
- else:
199
- fig = ax.figure
200
-
201
- contour = ax.tricontourf(triang, z, levels=levels, cmap=cmap)
202
- if show_points:
203
- ax.scatter(x, y, s=6, c="k", alpha=0.2)
204
- cbar = fig.colorbar(contour, ax=ax)
205
- cbar.set_label(value_col, fontsize=style.label_size)
206
- cbar.ax.tick_params(labelsize=style.tick_size)
207
-
208
- ax.set_xlabel(x_col, fontsize=style.label_size)
209
- ax.set_ylabel(y_col, fontsize=style.label_size)
210
- ax.set_title(title, fontsize=style.title_size)
211
- ax.tick_params(axis="both", labelsize=style.tick_size)
212
-
213
- if created_fig:
214
- finalize_figure(fig, save_path=save_path, show=show, style=style)
215
-
216
- return fig
217
-
218
-
219
- def plot_geo_heatmap_on_map(
220
- df: pd.DataFrame,
221
- *,
222
- lon_col: str,
223
- lat_col: str,
224
- value_col: str,
225
- weight_col: Optional[str] = None,
226
- bins: int | Tuple[int, int] = 100,
227
- agg: str = "mean",
228
- cmap: str = "YlOrRd",
229
- alpha: float = 0.6,
230
- basemap: Optional[object] = "CartoDB.Positron",
231
- zoom: Optional[int] = None,
232
- padding: float = 0.05,
233
- title: str = "Geo Heatmap (Map)",
234
- ax: Optional[plt.Axes] = None,
235
- show_points: bool = False,
236
- show: bool = False,
237
- save_path: Optional[str] = None,
238
- style: Optional[PlotStyle] = None,
239
- ) -> plt.Figure:
240
- _require_contextily("plot_geo_heatmap_on_map")
241
- style = style or PlotStyle()
242
- if agg not in {"mean", "sum"}:
243
- raise ValueError("agg must be 'mean' or 'sum'.")
244
- lon, lat, z, w = _sanitize_geo(df, lon_col, lat_col, value_col, weight_col)
245
- x, y = _lonlat_to_mercator(lon, lat)
246
-
247
- if isinstance(bins, int):
248
- bins = (bins, bins)
249
-
250
- if w is None:
251
- sum_z, x_edges, y_edges = np.histogram2d(x, y, bins=bins, weights=z)
252
- if agg == "sum":
253
- grid = sum_z
254
- else:
255
- count, _, _ = np.histogram2d(x, y, bins=bins)
256
- grid = sum_z / np.maximum(count, 1.0)
257
- else:
258
- sum_w, x_edges, y_edges = np.histogram2d(x, y, bins=bins, weights=w)
259
- sum_zw, _, _ = np.histogram2d(x, y, bins=bins, weights=z * w)
260
- grid = sum_zw / np.maximum(sum_w, EPS)
261
-
262
- created_fig = ax is None
263
- if created_fig:
264
- fig, ax = plt.subplots(figsize=style.figsize)
265
- else:
266
- fig = ax.figure
267
-
268
- _apply_bounds(ax, x, y, padding)
269
- ax.set_aspect("equal", adjustable="box")
270
-
271
- source = _resolve_basemap(basemap)
272
- if source is not None:
273
- if zoom is None:
274
- cx.add_basemap(ax, source=source, crs="EPSG:3857")
275
- else:
276
- cx.add_basemap(ax, source=source, crs="EPSG:3857", zoom=zoom)
277
-
278
- im = ax.imshow(
279
- grid.T,
280
- origin="lower",
281
- extent=[x_edges[0], x_edges[-1], y_edges[0], y_edges[-1]],
282
- aspect="auto",
283
- cmap=cmap,
284
- alpha=alpha,
285
- )
286
- if show_points:
287
- ax.scatter(x, y, s=6, c="k", alpha=0.25)
288
-
289
- cbar = fig.colorbar(im, ax=ax)
290
- cbar.set_label(value_col, fontsize=style.label_size)
291
- cbar.ax.tick_params(labelsize=style.tick_size)
292
-
293
- ax.set_title(title, fontsize=style.title_size)
294
- ax.tick_params(axis="both", labelsize=style.tick_size)
295
-
296
- if created_fig:
297
- finalize_figure(fig, save_path=save_path, show=show, style=style)
298
-
299
- return fig
300
-
301
-
302
- def plot_geo_contour_on_map(
303
- df: pd.DataFrame,
304
- *,
305
- lon_col: str,
306
- lat_col: str,
307
- value_col: str,
308
- weight_col: Optional[str] = None,
309
- max_points: Optional[int] = None,
310
- levels: int | Sequence[float] = 10,
311
- cmap: str = "viridis",
312
- alpha: float = 0.6,
313
- basemap: Optional[object] = "CartoDB.Positron",
314
- zoom: Optional[int] = None,
315
- padding: float = 0.05,
316
- title: str = "Geo Contour (Map)",
317
- ax: Optional[plt.Axes] = None,
318
- show_points: bool = False,
319
- show: bool = False,
320
- save_path: Optional[str] = None,
321
- style: Optional[PlotStyle] = None,
322
- ) -> plt.Figure:
323
- _require_contextily("plot_geo_contour_on_map")
324
- style = style or PlotStyle()
325
- lon, lat, z, w = _sanitize_geo(df, lon_col, lat_col, value_col, weight_col)
326
- lon, lat, z, w = _downsample_points(lon, lat, z, w, max_points)
327
- x, y = _lonlat_to_mercator(lon, lat)
328
- if w is not None:
329
- z = z * w
330
-
331
- created_fig = ax is None
332
- if created_fig:
333
- fig, ax = plt.subplots(figsize=style.figsize)
334
- else:
335
- fig = ax.figure
336
-
337
- _apply_bounds(ax, x, y, padding)
338
- ax.set_aspect("equal", adjustable="box")
339
-
340
- source = _resolve_basemap(basemap)
341
- if source is not None:
342
- if zoom is None:
343
- cx.add_basemap(ax, source=source, crs="EPSG:3857")
344
- else:
345
- cx.add_basemap(ax, source=source, crs="EPSG:3857", zoom=zoom)
346
-
347
- triang = mtri.Triangulation(x, y)
348
- contour = ax.tricontourf(triang, z, levels=levels, cmap=cmap, alpha=alpha)
349
- if show_points:
350
- ax.scatter(x, y, s=6, c="k", alpha=0.25)
351
-
352
- cbar = fig.colorbar(contour, ax=ax)
353
- cbar.set_label(value_col, fontsize=style.label_size)
354
- cbar.ax.tick_params(labelsize=style.tick_size)
355
-
356
- ax.set_title(title, fontsize=style.title_size)
357
- ax.tick_params(axis="both", labelsize=style.tick_size)
358
-
359
- if created_fig:
360
- finalize_figure(fig, save_path=save_path, show=show, style=style)
361
-
362
- return fig
1
+ from __future__ import annotations
2
+
3
+ from typing import Optional, Sequence, Tuple
4
+
5
+ import numpy as np
6
+ import pandas as pd
7
+ import matplotlib.tri as mtri
8
+
9
+ from ins_pricing.modelling.plotting.common import EPS, PlotStyle, finalize_figure, plt
10
+
11
+ try: # optional map basemap support
12
+ import contextily as cx
13
+ except Exception: # pragma: no cover - optional dependency
14
+ cx = None
15
+
16
+
17
+ _MERCATOR_MAX_LAT = 85.05112878
18
+ _MERCATOR_FACTOR = 20037508.34
19
+
20
+
21
+ def _require_contextily(func_name: str) -> None:
22
+ if cx is None:
23
+ raise RuntimeError(
24
+ f"{func_name} requires contextily. Install it via 'pip install contextily'."
25
+ )
26
+
27
+
28
+ def _lonlat_to_mercator(lon: np.ndarray, lat: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
29
+ lon = np.asarray(lon, dtype=float)
30
+ lat = np.asarray(lat, dtype=float)
31
+ lat = np.clip(lat, -_MERCATOR_MAX_LAT, _MERCATOR_MAX_LAT)
32
+ x = lon * _MERCATOR_FACTOR / 180.0
33
+ y = np.log(np.tan((90.0 + lat) * np.pi / 360.0)) * _MERCATOR_FACTOR / np.pi
34
+ return x, y
35
+
36
+
37
+ def _apply_bounds(ax: plt.Axes, x: np.ndarray, y: np.ndarray, padding: float) -> None:
38
+ x_min, x_max = float(np.min(x)), float(np.max(x))
39
+ y_min, y_max = float(np.min(y)), float(np.max(y))
40
+ pad_x = (x_max - x_min) * padding
41
+ pad_y = (y_max - y_min) * padding
42
+ if pad_x == 0:
43
+ pad_x = 1.0
44
+ if pad_y == 0:
45
+ pad_y = 1.0
46
+ ax.set_xlim(x_min - pad_x, x_max + pad_x)
47
+ ax.set_ylim(y_min - pad_y, y_max + pad_y)
48
+
49
+
50
+ def _resolve_basemap(source):
51
+ if cx is None or source is None:
52
+ return source
53
+ if isinstance(source, str):
54
+ provider = cx.providers
55
+ for part in source.split("."):
56
+ if isinstance(provider, dict):
57
+ provider = provider[part]
58
+ else:
59
+ provider = getattr(provider, part)
60
+ return provider
61
+ return source
62
+
63
+
64
+ def _sanitize_geo(
65
+ df: pd.DataFrame,
66
+ x_col: str,
67
+ y_col: str,
68
+ value_col: str,
69
+ weight_col: Optional[str] = None,
70
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, Optional[np.ndarray]]:
71
+ x = pd.to_numeric(df[x_col], errors="coerce").to_numpy(dtype=float)
72
+ y = pd.to_numeric(df[y_col], errors="coerce").to_numpy(dtype=float)
73
+ z = pd.to_numeric(df[value_col], errors="coerce").to_numpy(dtype=float)
74
+ w = None
75
+ if weight_col:
76
+ w = pd.to_numeric(df[weight_col], errors="coerce").to_numpy(dtype=float)
77
+
78
+ if w is None:
79
+ mask = np.isfinite(x) & np.isfinite(y) & np.isfinite(z)
80
+ else:
81
+ mask = np.isfinite(x) & np.isfinite(y) & np.isfinite(z) & np.isfinite(w)
82
+ w = w[mask]
83
+ return x[mask], y[mask], z[mask], w
84
+
85
+
86
+ def _downsample_points(
87
+ x: np.ndarray,
88
+ y: np.ndarray,
89
+ z: np.ndarray,
90
+ w: Optional[np.ndarray],
91
+ max_points: Optional[int],
92
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, Optional[np.ndarray]]:
93
+ if max_points is None:
94
+ return x, y, z, w
95
+ max_points = int(max_points)
96
+ if max_points <= 0 or len(x) <= max_points:
97
+ return x, y, z, w
98
+ rng = np.random.default_rng(13)
99
+ idx = rng.choice(len(x), size=max_points, replace=False)
100
+ if w is None:
101
+ return x[idx], y[idx], z[idx], None
102
+ return x[idx], y[idx], z[idx], w[idx]
103
+
104
+
105
+ def plot_geo_heatmap(
106
+ df: pd.DataFrame,
107
+ *,
108
+ x_col: str,
109
+ y_col: str,
110
+ value_col: str,
111
+ weight_col: Optional[str] = None,
112
+ bins: int | Tuple[int, int] = 50,
113
+ agg: str = "mean",
114
+ cmap: str = "YlOrRd",
115
+ title: str = "Geo Heatmap",
116
+ ax: Optional[plt.Axes] = None,
117
+ show: bool = False,
118
+ save_path: Optional[str] = None,
119
+ style: Optional[PlotStyle] = None,
120
+ ) -> plt.Figure:
121
+ style = style or PlotStyle()
122
+ if agg not in {"mean", "sum"}:
123
+ raise ValueError("agg must be 'mean' or 'sum'.")
124
+ x, y, z, w = _sanitize_geo(df, x_col, y_col, value_col, weight_col)
125
+
126
+ if isinstance(bins, int):
127
+ bins = (bins, bins)
128
+
129
+ if w is None:
130
+ sum_z, x_edges, y_edges = np.histogram2d(x, y, bins=bins, weights=z)
131
+ if agg == "sum":
132
+ grid = sum_z
133
+ else:
134
+ count, _, _ = np.histogram2d(x, y, bins=bins)
135
+ grid = sum_z / np.maximum(count, 1.0)
136
+ else:
137
+ sum_w, x_edges, y_edges = np.histogram2d(x, y, bins=bins, weights=w)
138
+ sum_zw, _, _ = np.histogram2d(x, y, bins=bins, weights=z * w)
139
+ grid = sum_zw / np.maximum(sum_w, EPS)
140
+
141
+ created_fig = ax is None
142
+ if created_fig:
143
+ fig, ax = plt.subplots(figsize=style.figsize)
144
+ else:
145
+ fig = ax.figure
146
+
147
+ im = ax.imshow(
148
+ grid.T,
149
+ origin="lower",
150
+ extent=[x_edges[0], x_edges[-1], y_edges[0], y_edges[-1]],
151
+ aspect="auto",
152
+ cmap=cmap,
153
+ )
154
+ cbar = fig.colorbar(im, ax=ax)
155
+ cbar.set_label(value_col, fontsize=style.label_size)
156
+ cbar.ax.tick_params(labelsize=style.tick_size)
157
+
158
+ ax.set_xlabel(x_col, fontsize=style.label_size)
159
+ ax.set_ylabel(y_col, fontsize=style.label_size)
160
+ ax.set_title(title, fontsize=style.title_size)
161
+ ax.tick_params(axis="both", labelsize=style.tick_size)
162
+
163
+ if created_fig:
164
+ finalize_figure(fig, save_path=save_path, show=show, style=style)
165
+
166
+ return fig
167
+
168
+
169
+ def plot_geo_contour(
170
+ df: pd.DataFrame,
171
+ *,
172
+ x_col: str,
173
+ y_col: str,
174
+ value_col: str,
175
+ weight_col: Optional[str] = None,
176
+ max_points: Optional[int] = None,
177
+ levels: int | Sequence[float] = 10,
178
+ cmap: str = "viridis",
179
+ title: str = "Geo Contour",
180
+ ax: Optional[plt.Axes] = None,
181
+ show_points: bool = False,
182
+ show: bool = False,
183
+ save_path: Optional[str] = None,
184
+ style: Optional[PlotStyle] = None,
185
+ ) -> plt.Figure:
186
+ style = style or PlotStyle()
187
+ x, y, z, w = _sanitize_geo(df, x_col, y_col, value_col, weight_col)
188
+ x, y, z, w = _downsample_points(x, y, z, w, max_points)
189
+
190
+ if w is not None:
191
+ z = z * w
192
+
193
+ triang = mtri.Triangulation(x, y)
194
+
195
+ created_fig = ax is None
196
+ if created_fig:
197
+ fig, ax = plt.subplots(figsize=style.figsize)
198
+ else:
199
+ fig = ax.figure
200
+
201
+ contour = ax.tricontourf(triang, z, levels=levels, cmap=cmap)
202
+ if show_points:
203
+ ax.scatter(x, y, s=6, c="k", alpha=0.2)
204
+ cbar = fig.colorbar(contour, ax=ax)
205
+ cbar.set_label(value_col, fontsize=style.label_size)
206
+ cbar.ax.tick_params(labelsize=style.tick_size)
207
+
208
+ ax.set_xlabel(x_col, fontsize=style.label_size)
209
+ ax.set_ylabel(y_col, fontsize=style.label_size)
210
+ ax.set_title(title, fontsize=style.title_size)
211
+ ax.tick_params(axis="both", labelsize=style.tick_size)
212
+
213
+ if created_fig:
214
+ finalize_figure(fig, save_path=save_path, show=show, style=style)
215
+
216
+ return fig
217
+
218
+
219
+ def plot_geo_heatmap_on_map(
220
+ df: pd.DataFrame,
221
+ *,
222
+ lon_col: str,
223
+ lat_col: str,
224
+ value_col: str,
225
+ weight_col: Optional[str] = None,
226
+ bins: int | Tuple[int, int] = 100,
227
+ agg: str = "mean",
228
+ cmap: str = "YlOrRd",
229
+ alpha: float = 0.6,
230
+ basemap: Optional[object] = "CartoDB.Positron",
231
+ zoom: Optional[int] = None,
232
+ padding: float = 0.05,
233
+ title: str = "Geo Heatmap (Map)",
234
+ ax: Optional[plt.Axes] = None,
235
+ show_points: bool = False,
236
+ show: bool = False,
237
+ save_path: Optional[str] = None,
238
+ style: Optional[PlotStyle] = None,
239
+ ) -> plt.Figure:
240
+ _require_contextily("plot_geo_heatmap_on_map")
241
+ style = style or PlotStyle()
242
+ if agg not in {"mean", "sum"}:
243
+ raise ValueError("agg must be 'mean' or 'sum'.")
244
+ lon, lat, z, w = _sanitize_geo(df, lon_col, lat_col, value_col, weight_col)
245
+ x, y = _lonlat_to_mercator(lon, lat)
246
+
247
+ if isinstance(bins, int):
248
+ bins = (bins, bins)
249
+
250
+ if w is None:
251
+ sum_z, x_edges, y_edges = np.histogram2d(x, y, bins=bins, weights=z)
252
+ if agg == "sum":
253
+ grid = sum_z
254
+ else:
255
+ count, _, _ = np.histogram2d(x, y, bins=bins)
256
+ grid = sum_z / np.maximum(count, 1.0)
257
+ else:
258
+ sum_w, x_edges, y_edges = np.histogram2d(x, y, bins=bins, weights=w)
259
+ sum_zw, _, _ = np.histogram2d(x, y, bins=bins, weights=z * w)
260
+ grid = sum_zw / np.maximum(sum_w, EPS)
261
+
262
+ created_fig = ax is None
263
+ if created_fig:
264
+ fig, ax = plt.subplots(figsize=style.figsize)
265
+ else:
266
+ fig = ax.figure
267
+
268
+ _apply_bounds(ax, x, y, padding)
269
+ ax.set_aspect("equal", adjustable="box")
270
+
271
+ source = _resolve_basemap(basemap)
272
+ if source is not None:
273
+ if zoom is None:
274
+ cx.add_basemap(ax, source=source, crs="EPSG:3857")
275
+ else:
276
+ cx.add_basemap(ax, source=source, crs="EPSG:3857", zoom=zoom)
277
+
278
+ im = ax.imshow(
279
+ grid.T,
280
+ origin="lower",
281
+ extent=[x_edges[0], x_edges[-1], y_edges[0], y_edges[-1]],
282
+ aspect="auto",
283
+ cmap=cmap,
284
+ alpha=alpha,
285
+ )
286
+ if show_points:
287
+ ax.scatter(x, y, s=6, c="k", alpha=0.25)
288
+
289
+ cbar = fig.colorbar(im, ax=ax)
290
+ cbar.set_label(value_col, fontsize=style.label_size)
291
+ cbar.ax.tick_params(labelsize=style.tick_size)
292
+
293
+ ax.set_title(title, fontsize=style.title_size)
294
+ ax.tick_params(axis="both", labelsize=style.tick_size)
295
+
296
+ if created_fig:
297
+ finalize_figure(fig, save_path=save_path, show=show, style=style)
298
+
299
+ return fig
300
+
301
+
302
+ def plot_geo_contour_on_map(
303
+ df: pd.DataFrame,
304
+ *,
305
+ lon_col: str,
306
+ lat_col: str,
307
+ value_col: str,
308
+ weight_col: Optional[str] = None,
309
+ max_points: Optional[int] = None,
310
+ levels: int | Sequence[float] = 10,
311
+ cmap: str = "viridis",
312
+ alpha: float = 0.6,
313
+ basemap: Optional[object] = "CartoDB.Positron",
314
+ zoom: Optional[int] = None,
315
+ padding: float = 0.05,
316
+ title: str = "Geo Contour (Map)",
317
+ ax: Optional[plt.Axes] = None,
318
+ show_points: bool = False,
319
+ show: bool = False,
320
+ save_path: Optional[str] = None,
321
+ style: Optional[PlotStyle] = None,
322
+ ) -> plt.Figure:
323
+ _require_contextily("plot_geo_contour_on_map")
324
+ style = style or PlotStyle()
325
+ lon, lat, z, w = _sanitize_geo(df, lon_col, lat_col, value_col, weight_col)
326
+ lon, lat, z, w = _downsample_points(lon, lat, z, w, max_points)
327
+ x, y = _lonlat_to_mercator(lon, lat)
328
+ if w is not None:
329
+ z = z * w
330
+
331
+ created_fig = ax is None
332
+ if created_fig:
333
+ fig, ax = plt.subplots(figsize=style.figsize)
334
+ else:
335
+ fig = ax.figure
336
+
337
+ _apply_bounds(ax, x, y, padding)
338
+ ax.set_aspect("equal", adjustable="box")
339
+
340
+ source = _resolve_basemap(basemap)
341
+ if source is not None:
342
+ if zoom is None:
343
+ cx.add_basemap(ax, source=source, crs="EPSG:3857")
344
+ else:
345
+ cx.add_basemap(ax, source=source, crs="EPSG:3857", zoom=zoom)
346
+
347
+ triang = mtri.Triangulation(x, y)
348
+ contour = ax.tricontourf(triang, z, levels=levels, cmap=cmap, alpha=alpha)
349
+ if show_points:
350
+ ax.scatter(x, y, s=6, c="k", alpha=0.25)
351
+
352
+ cbar = fig.colorbar(contour, ax=ax)
353
+ cbar.set_label(value_col, fontsize=style.label_size)
354
+ cbar.ax.tick_params(labelsize=style.tick_size)
355
+
356
+ ax.set_title(title, fontsize=style.title_size)
357
+ ax.tick_params(axis="both", labelsize=style.tick_size)
358
+
359
+ if created_fig:
360
+ finalize_figure(fig, save_path=save_path, show=show, style=style)
361
+
362
+ return fig