pivtools 0.1.3__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (127)
  1. pivtools-0.1.3.dist-info/METADATA +222 -0
  2. pivtools-0.1.3.dist-info/RECORD +127 -0
  3. pivtools-0.1.3.dist-info/WHEEL +5 -0
  4. pivtools-0.1.3.dist-info/entry_points.txt +3 -0
  5. pivtools-0.1.3.dist-info/top_level.txt +3 -0
  6. pivtools_cli/__init__.py +5 -0
  7. pivtools_cli/_build_marker.c +25 -0
  8. pivtools_cli/_build_marker.cp311-win_amd64.pyd +0 -0
  9. pivtools_cli/cli.py +225 -0
  10. pivtools_cli/example.py +139 -0
  11. pivtools_cli/lib/PIV_2d_cross_correlate.c +334 -0
  12. pivtools_cli/lib/PIV_2d_cross_correlate.h +22 -0
  13. pivtools_cli/lib/common.h +36 -0
  14. pivtools_cli/lib/interp2custom.c +146 -0
  15. pivtools_cli/lib/interp2custom.h +48 -0
  16. pivtools_cli/lib/peak_locate_gsl.c +711 -0
  17. pivtools_cli/lib/peak_locate_gsl.h +40 -0
  18. pivtools_cli/lib/peak_locate_gsl_print.c +736 -0
  19. pivtools_cli/lib/peak_locate_lm.c +751 -0
  20. pivtools_cli/lib/peak_locate_lm.h +27 -0
  21. pivtools_cli/lib/xcorr.c +342 -0
  22. pivtools_cli/lib/xcorr.h +31 -0
  23. pivtools_cli/lib/xcorr_cache.c +78 -0
  24. pivtools_cli/lib/xcorr_cache.h +26 -0
  25. pivtools_cli/piv/interp2custom/interp2custom.py +69 -0
  26. pivtools_cli/piv/piv.py +240 -0
  27. pivtools_cli/piv/piv_backend/base.py +825 -0
  28. pivtools_cli/piv/piv_backend/cpu_instantaneous.py +1005 -0
  29. pivtools_cli/piv/piv_backend/factory.py +28 -0
  30. pivtools_cli/piv/piv_backend/gpu_instantaneous.py +15 -0
  31. pivtools_cli/piv/piv_backend/infilling.py +445 -0
  32. pivtools_cli/piv/piv_backend/outlier_detection.py +306 -0
  33. pivtools_cli/piv/piv_backend/profile_cpu_instantaneous.py +230 -0
  34. pivtools_cli/piv/piv_result.py +40 -0
  35. pivtools_cli/piv/save_results.py +342 -0
  36. pivtools_cli/piv_cluster/cluster.py +108 -0
  37. pivtools_cli/preprocessing/filters.py +399 -0
  38. pivtools_cli/preprocessing/preprocess.py +79 -0
  39. pivtools_cli/tests/helpers.py +107 -0
  40. pivtools_cli/tests/instantaneous_piv/test_piv_integration.py +167 -0
  41. pivtools_cli/tests/instantaneous_piv/test_piv_integration_multi.py +553 -0
  42. pivtools_cli/tests/preprocessing/test_filters.py +41 -0
  43. pivtools_core/__init__.py +5 -0
  44. pivtools_core/config.py +703 -0
  45. pivtools_core/config.yaml +135 -0
  46. pivtools_core/image_handling/__init__.py +0 -0
  47. pivtools_core/image_handling/load_images.py +464 -0
  48. pivtools_core/image_handling/readers/__init__.py +53 -0
  49. pivtools_core/image_handling/readers/generic_readers.py +50 -0
  50. pivtools_core/image_handling/readers/lavision_reader.py +190 -0
  51. pivtools_core/image_handling/readers/registry.py +24 -0
  52. pivtools_core/paths.py +49 -0
  53. pivtools_core/vector_loading.py +248 -0
  54. pivtools_gui/__init__.py +3 -0
  55. pivtools_gui/app.py +687 -0
  56. pivtools_gui/calibration/__init__.py +0 -0
  57. pivtools_gui/calibration/app/__init__.py +0 -0
  58. pivtools_gui/calibration/app/views.py +1186 -0
  59. pivtools_gui/calibration/calibration_planar/planar_calibration_production.py +570 -0
  60. pivtools_gui/calibration/vector_calibration_production.py +544 -0
  61. pivtools_gui/config.py +703 -0
  62. pivtools_gui/image_handling/__init__.py +0 -0
  63. pivtools_gui/image_handling/load_images.py +464 -0
  64. pivtools_gui/image_handling/readers/__init__.py +53 -0
  65. pivtools_gui/image_handling/readers/generic_readers.py +50 -0
  66. pivtools_gui/image_handling/readers/lavision_reader.py +190 -0
  67. pivtools_gui/image_handling/readers/registry.py +24 -0
  68. pivtools_gui/masking/__init__.py +0 -0
  69. pivtools_gui/masking/app/__init__.py +0 -0
  70. pivtools_gui/masking/app/views.py +123 -0
  71. pivtools_gui/paths.py +49 -0
  72. pivtools_gui/piv_runner.py +261 -0
  73. pivtools_gui/pivtools.py +58 -0
  74. pivtools_gui/plotting/__init__.py +0 -0
  75. pivtools_gui/plotting/app/__init__.py +0 -0
  76. pivtools_gui/plotting/app/views.py +1671 -0
  77. pivtools_gui/plotting/plot_maker.py +220 -0
  78. pivtools_gui/post_processing/POD/__init__.py +0 -0
  79. pivtools_gui/post_processing/POD/app/__init__.py +0 -0
  80. pivtools_gui/post_processing/POD/app/views.py +647 -0
  81. pivtools_gui/post_processing/POD/pod_decompose.py +979 -0
  82. pivtools_gui/post_processing/POD/views.py +1096 -0
  83. pivtools_gui/post_processing/__init__.py +0 -0
  84. pivtools_gui/static/404.html +1 -0
  85. pivtools_gui/static/_next/static/chunks/117-d5793c8e79de5511.js +2 -0
  86. pivtools_gui/static/_next/static/chunks/484-cfa8b9348ce4f00e.js +1 -0
  87. pivtools_gui/static/_next/static/chunks/869-320a6b9bdafbb6d3.js +1 -0
  88. pivtools_gui/static/_next/static/chunks/app/_not-found/page-12f067ceb7415e55.js +1 -0
  89. pivtools_gui/static/_next/static/chunks/app/layout-b907d5f31ac82e9d.js +1 -0
  90. pivtools_gui/static/_next/static/chunks/app/page-334cc4e8444cde2f.js +1 -0
  91. pivtools_gui/static/_next/static/chunks/fd9d1056-ad15f396ddf9b7e5.js +1 -0
  92. pivtools_gui/static/_next/static/chunks/framework-f66176bb897dc684.js +1 -0
  93. pivtools_gui/static/_next/static/chunks/main-a1b3ced4d5f6d998.js +1 -0
  94. pivtools_gui/static/_next/static/chunks/main-app-8a63c6f5e7baee11.js +1 -0
  95. pivtools_gui/static/_next/static/chunks/pages/_app-72b849fbd24ac258.js +1 -0
  96. pivtools_gui/static/_next/static/chunks/pages/_error-7ba65e1336b92748.js +1 -0
  97. pivtools_gui/static/_next/static/chunks/polyfills-42372ed130431b0a.js +1 -0
  98. pivtools_gui/static/_next/static/chunks/webpack-4a8ca7c99e9bb3d8.js +1 -0
  99. pivtools_gui/static/_next/static/css/7d3f2337d7ea12a5.css +3 -0
  100. pivtools_gui/static/_next/static/vQeR20OUdSSKlK4vukC4q/_buildManifest.js +1 -0
  101. pivtools_gui/static/_next/static/vQeR20OUdSSKlK4vukC4q/_ssgManifest.js +1 -0
  102. pivtools_gui/static/file.svg +1 -0
  103. pivtools_gui/static/globe.svg +1 -0
  104. pivtools_gui/static/grid.svg +8 -0
  105. pivtools_gui/static/index.html +1 -0
  106. pivtools_gui/static/index.txt +8 -0
  107. pivtools_gui/static/next.svg +1 -0
  108. pivtools_gui/static/vercel.svg +1 -0
  109. pivtools_gui/static/window.svg +1 -0
  110. pivtools_gui/stereo_reconstruction/__init__.py +0 -0
  111. pivtools_gui/stereo_reconstruction/app/__init__.py +0 -0
  112. pivtools_gui/stereo_reconstruction/app/views.py +1985 -0
  113. pivtools_gui/stereo_reconstruction/stereo_calibration_production.py +606 -0
  114. pivtools_gui/stereo_reconstruction/stereo_reconstruction_production.py +544 -0
  115. pivtools_gui/utils.py +63 -0
  116. pivtools_gui/vector_loading.py +248 -0
  117. pivtools_gui/vector_merging/__init__.py +1 -0
  118. pivtools_gui/vector_merging/app/__init__.py +1 -0
  119. pivtools_gui/vector_merging/app/views.py +759 -0
  120. pivtools_gui/vector_statistics/app/__init__.py +1 -0
  121. pivtools_gui/vector_statistics/app/views.py +710 -0
  122. pivtools_gui/vector_statistics/ensemble_statistics.py +49 -0
  123. pivtools_gui/vector_statistics/instantaneous_statistics.py +311 -0
  124. pivtools_gui/video_maker/__init__.py +0 -0
  125. pivtools_gui/video_maker/app/__init__.py +0 -0
  126. pivtools_gui/video_maker/app/views.py +436 -0
  127. pivtools_gui/video_maker/video_maker.py +662 -0
@@ -0,0 +1,28 @@
1
+ import sys
2
+ from pathlib import Path
3
+ from typing import Optional
4
+
5
+ from pivtools_core.config import Config
6
+ from pivtools_cli.piv.piv_backend.cpu_instantaneous import InstantaneousCorrelatorCPU
7
+ from pivtools_cli.piv.piv_backend.gpu_instantaneous import InstantaneousCorrelatorGPU
8
+
9
# Module-level dictionaries intended to cache correlator instances and their
# associated data across calls.
# NOTE(review): make_correlator_backend in this module never reads or writes
# either dict — confirm whether any other module relies on them.
_correlator_cache = dict()
_correlator_cache_data = dict()
12
+
13
+
14
def make_correlator_backend(config: Config, precomputed_cache: Optional[dict] = None):
    """Create the correlator backend selected by ``config.backend``.

    :param config: Configuration object. Its ``backend`` attribute selects the
        implementation: "cpu" or "gpu". Matching is case-insensitive and
        surrounding whitespace is ignored. A missing, empty or ``None`` value
        falls back to "cpu".
    :param precomputed_cache: Optional precomputed cache data forwarded to the
        CPU backend to avoid redundant computation. The GPU backend does not
        receive it.
    :return: Correlator backend instance.
    :raises ValueError: If the backend name is not recognised.
    """
    # `or "cpu"` also covers an explicit ``backend: null`` in the config, which
    # would otherwise crash on ``None.lower()``.
    backend = str(getattr(config, "backend", None) or "cpu").strip().lower()

    if backend == "cpu":
        return InstantaneousCorrelatorCPU(config=config, precomputed_cache=precomputed_cache)
    if backend == "gpu":
        # NOTE(review): the GPU correlator is constructed without config or
        # cache, exactly as before; revisit once it is implemented.
        return InstantaneousCorrelatorGPU()
    raise ValueError(f"Unknown backend: {backend}")
@@ -0,0 +1,15 @@
1
+ import sys
2
+ from pathlib import Path
3
+
4
+ import dask.array as da
5
+ import numpy as np
6
+
7
+ from pivtools_core.config import Config
8
+
9
+ from pivtools_cli.piv.piv_backend.base import CrossCorrelator
10
+
11
+
12
class InstantaneousCorrelatorGPU(CrossCorrelator):
    """Placeholder GPU cross-correlator; the GPU code path is not implemented."""

    def correlate_batch(self, images: np.ndarray, config: Config) -> da.Array:
        """Correlate a batch of images on the GPU.

        The GPU path is not implemented. Previously this method's body was
        ``pass``, so it silently returned ``None`` despite its ``da.Array``
        return annotation. The failure then surfaced later, far from its
        cause. It now fails loudly at the call site instead.

        :param images: Image batch to correlate (unused).
        :param config: Configuration object (unused).
        :raises NotImplementedError: Always.
        """
        raise NotImplementedError(
            "The GPU correlation backend is not implemented yet; "
            "use the 'cpu' backend instead."
        )
@@ -0,0 +1,445 @@
1
+ """
2
+ Infilling methods for PIV velocity fields.
3
+
4
+ This module provides various methods for filling NaN/masked values in PIV data:
5
+ - Local median infilling
6
+ - K-nearest neighbors (KNN) regression
7
+ - Biharmonic inpainting
8
+ - Griddata linear interpolation
9
+ - Radial basis function (RBF) interpolation
10
+ """
11
+
12
+ import numpy as np
13
+ import bottleneck as bn
14
+ from numpy.lib.stride_tricks import sliding_window_view
15
+ from scipy import ndimage as ndi
16
+ from scipy.interpolate import griddata, RBFInterpolator
17
+ from scipy.spatial import cKDTree
18
+ from skimage.restoration import inpaint_biharmonic
19
+
20
+ try:
21
+ from sklearn.neighbors import KNeighborsRegressor
22
+ SKLEARN_AVAILABLE = True
23
+ except ImportError:
24
+ SKLEARN_AVAILABLE = False
25
+
26
+
27
def infill_local_median(
    ux: np.ndarray,
    uy: np.ndarray,
    mask: np.ndarray,
    ksize: int = 3,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Fill masked pixels with the median of their ksize×ksize neighbors (center excluded).

    NaNs in the neighborhood are ignored (``bn.nanmedian``). This includes the
    NaN padding used beyond the field edges. A masked pixel whose neighbors
    are all NaN is therefore filled with NaN.

    Parameters
    ----------
    ux : np.ndarray
        Horizontal velocity component (2D).
    uy : np.ndarray
        Vertical velocity component (2D).
    mask : np.ndarray or None
        Boolean mask with the same shape as ``ux``, where True indicates
        pixels to fill. ``None`` or an all-False mask fills nothing.
    ksize : int, optional
        Odd kernel size (>= 3) of the local median window, defaults to 3.

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        Infilled (ux, uy) as new float32 arrays.

    Raises
    ------
    ValueError
        If ``ksize`` is not an odd integer >= 3.
    """
    # An even window has no centre pixel, and its padded sliding view has the
    # wrong shape. ksize=1 has no neighbours at all.
    if ksize < 3 or ksize % 2 == 0:
        raise ValueError(f"ksize must be an odd integer >= 3, got {ksize!r}")

    ux = np.asarray(ux, dtype=np.float32)
    uy = np.asarray(uy, dtype=np.float32)

    # Nothing to fill: skip the neighbourhood/median work entirely.
    if mask is None:
        return ux.copy(), uy.copy()
    mask = np.asarray(mask, dtype=bool)
    if not mask.any():
        return ux.copy(), uy.copy()

    pad = ksize // 2
    # Pad with NaN so out-of-bounds neighbours are ignored by nanmedian
    ux_pad = np.pad(ux, pad, mode="constant", constant_values=np.nan)
    uy_pad = np.pad(uy, pad, mode="constant", constant_values=np.nan)

    # sliding_window_view returns a zero-copy (H, W, k, k) view. Selecting
    # the masked windows *before* reshaping copies only those windows, rather
    # than materialising all H*W*k*k neighbour values.
    rows, cols = np.nonzero(mask)
    kk = ksize * ksize
    keep = np.ones(kk, dtype=bool)
    keep[kk // 2] = False  # flat index of the centre pixel (odd ksize)

    nb_x = sliding_window_view(ux_pad, (ksize, ksize))[rows, cols].reshape(-1, kk)[:, keep]
    nb_y = sliding_window_view(uy_pad, (ksize, ksize))[rows, cols].reshape(-1, kk)[:, keep]

    ux_out = ux.copy()
    uy_out = uy.copy()
    ux_out[rows, cols] = bn.nanmedian(nb_x, axis=-1)
    uy_out[rows, cols] = bn.nanmedian(nb_y, axis=-1)
    return ux_out, uy_out
94
+
95
+
96
def infill_biharmonic(
    ux: np.ndarray,
    uy: np.ndarray,
    mask: np.ndarray,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Infill using biharmonic inpainting.

    Uses scikit-image's ``inpaint_biharmonic`` on each component. Both the
    masked pixels and any pre-existing NaNs are treated as unknown while
    solving. Only the masked pixels are written back to the output.

    Parameters
    ----------
    ux : np.ndarray
        Horizontal velocity component (2D).
    uy : np.ndarray
        Vertical velocity component (2D).
    mask : np.ndarray
        Mask where truthy values indicate pixels to fill (outliers only).
        It is converted to bool.

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        Infilled (ux, uy) arrays (new arrays; inputs are not modified).
    """
    # Normalise to a boolean mask. An integer 0/1 mask would otherwise act as
    # fancy *indices* (selecting rows 0 and 1) rather than as a selector.
    mask = np.asarray(mask, dtype=bool)
    ux = np.asarray(ux)
    uy = np.asarray(uy)

    # Nothing to fill: avoid solving the biharmonic system twice for nothing.
    if not mask.any():
        return ux.copy(), uy.copy()

    # inpaint_biharmonic fills NaN regions, so we need to temporarily set outliers to NaN
    ux_temp = ux.copy()
    uy_temp = uy.copy()
    ux_temp[mask] = np.nan
    uy_temp[mask] = np.nan

    # Combined mask of what needs filling (outliers + any pre-existing NaNs)
    combined_mask = np.isnan(ux_temp) | np.isnan(uy_temp)

    ux_filled = inpaint_biharmonic(ux_temp, combined_mask)
    uy_filled = inpaint_biharmonic(uy_temp, combined_mask)

    # Keep original values everywhere else; only replace the outliers
    ux_out = ux.copy()
    uy_out = uy.copy()
    ux_out[mask] = ux_filled[mask]
    uy_out[mask] = uy_filled[mask]

    return ux_out, uy_out
141
+
142
+
143
def infill_griddata_linear(
    ux: np.ndarray,
    uy: np.ndarray,
    mask: np.ndarray,
    method: str = "linear",
) -> tuple[np.ndarray, np.ndarray]:
    """
    Infill using scipy griddata interpolation.

    Uses Delaunay triangulation for linear/cubic interpolation, with a
    nearest-neighbor fallback for masked cells outside the convex hull of the
    valid points. Only the masked cells are interpolated.

    Parameters
    ----------
    ux : np.ndarray
        Horizontal velocity component (2D).
    uy : np.ndarray
        Vertical velocity component (2D).
    mask : np.ndarray
        Mask where truthy values indicate pixels to fill (converted to bool).
    method : str, optional
        Interpolation method: 'linear', 'cubic', or 'nearest', defaults to 'linear'.
        With fewer than 3 valid points a triangulation is impossible, so
        'nearest' is used instead.

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        Infilled (ux, uy) arrays (new arrays; inputs are not modified).
    """
    ux = np.asarray(ux)
    uy = np.asarray(uy)
    # Boolean mask required: `~` on an integer mask gives -1/-2, not a negation.
    mask = np.asarray(mask, dtype=bool)

    if not mask.any():
        return ux.copy(), uy.copy()

    H, W = ux.shape
    yy, xx = np.mgrid[0:H, 0:W]

    # Valid points are those that are NOT masked for filling AND are finite
    valid = ~mask & np.isfinite(ux) & np.isfinite(uy)
    n_valid = np.count_nonzero(valid)
    if n_valid == 0:
        # No valid points to interpolate from
        return ux.copy(), uy.copy()

    pts = np.c_[yy[valid], xx[valid]]
    u_vals = ux[valid]
    v_vals = uy[valid]
    # Query only the cells that will actually be replaced, not the whole grid.
    query = (yy[mask], xx[mask])

    # Delaunay-based methods need at least 3 points (Qhull raises otherwise).
    if n_valid < 3:
        method = "nearest"

    ux_f = griddata(pts, u_vals, query, method=method)
    uy_f = griddata(pts, v_vals, query, method=method)

    # Fallback to nearest for anything left unfilled (outside the convex hull)
    nan_u = np.isnan(ux_f)
    nan_v = np.isnan(uy_f)
    if nan_u.any() or nan_v.any():
        ux_nn = griddata(pts, u_vals, query, method="nearest")
        uy_nn = griddata(pts, v_vals, query, method="nearest")
        ux_f = np.where(nan_u, ux_nn, ux_f)
        uy_f = np.where(nan_v, uy_nn, uy_f)

    # Only replace the masked outliers; boolean indexing order matches `query`.
    ux_out = ux.copy()
    uy_out = uy.copy()
    ux_out[mask] = ux_f
    uy_out[mask] = uy_f

    return ux_out, uy_out
200
+
201
+
202
+ def infill_rbf_local(
203
+ ux: np.ndarray,
204
+ uy: np.ndarray,
205
+ mask: np.ndarray,
206
+ neighbors: int = 64,
207
+ kernel: str = "thin_plate_spline",
208
+ epsilon: float = None,
209
+ smoothing: float = 0.0,
210
+ ) -> tuple[np.ndarray, np.ndarray]:
211
+ """
212
+ Infill masked PIV vectors with local RBFs (memory-safe).
213
+
214
+ Uses radial basis function interpolation with local neighborhood
215
+ support to keep memory usage reasonable for large fields.
216
+
217
+ Parameters
218
+ ----------
219
+ ux : np.ndarray
220
+ Horizontal velocity component (2D).
221
+ uy : np.ndarray
222
+ Vertical velocity component (2D).
223
+ mask : np.ndarray
224
+ Boolean mask where True indicates pixels to fill.
225
+ neighbors : int, optional
226
+ Number of nearest neighbors used per query, defaults to 64.
227
+ Typical range: 32-128.
228
+ kernel : str, optional
229
+ RBF kernel type, defaults to "thin_plate_spline".
230
+ Options: "thin_plate_spline", "multiquadric", "inverse_multiquadric",
231
+ "cubic", "quintic", "linear", "gaussian".
232
+ epsilon : float, optional
233
+ Shape parameter (length scale). Auto-determined if None.
234
+ Try 1-3 pixels for gaussian/multiquadric; TPS ignores epsilon.
235
+ smoothing : float, optional
236
+ Regularization parameter, defaults to 0.0.
237
+ Try 1e-3 to 1e-1 if data is noisy.
238
+
239
+ Returns
240
+ -------
241
+ tuple[np.ndarray, np.ndarray]
242
+ Infilled (ux, uy) arrays.
243
+ """
244
+ H, W = ux.shape
245
+ yy, xx = np.mgrid[0:H, 0:W]
246
+
247
+ # Valid training points: not masked for filling AND finite
248
+ valid = (~mask) & np.isfinite(ux) & np.isfinite(uy)
249
+
250
+ if valid.sum() < 4:
251
+ # Not enough points—just return originals
252
+ return ux.copy(), uy.copy()
253
+
254
+ X_train = np.c_[xx[valid].ravel(), yy[valid].ravel()].astype(np.float64)
255
+ # Only query points that need filling (the mask)
256
+ X_query = np.c_[xx[mask].ravel(), yy[mask].ravel()].astype(np.float64)
257
+
258
+ u_train = ux[valid].astype(np.float64)
259
+ v_train = uy[valid].astype(np.float64)
260
+
261
+ # Build local RBF models (KD-tree inside; memory ~ O(N))
262
+ rbf_u = RBFInterpolator(
263
+ X_train, u_train,
264
+ kernel=kernel,
265
+ epsilon=epsilon,
266
+ neighbors=neighbors,
267
+ smoothing=smoothing
268
+ )
269
+ rbf_v = RBFInterpolator(
270
+ X_train, v_train,
271
+ kernel=kernel,
272
+ epsilon=epsilon,
273
+ neighbors=neighbors,
274
+ smoothing=smoothing
275
+ )
276
+
277
+ u_pred = rbf_u(X_query)
278
+ v_pred = rbf_v(X_query)
279
+
280
+ # Keep originals; fill only masked cells
281
+ ux_filled = ux.copy()
282
+ uy_filled = uy.copy()
283
+ ux_filled[mask] = u_pred
284
+ uy_filled[mask] = v_pred
285
+
286
+ return ux_filled, uy_filled
287
+
288
+
289
def infill_knn(
    ux: np.ndarray,
    uy: np.ndarray,
    mask: np.ndarray,
    n_neighbors: int = 32,
    weights: str = "distance",
    algorithm: str = "kd_tree",
) -> tuple[np.ndarray, np.ndarray]:
    """
    Infill masked PIV vectors using K-nearest-neighbor regression.

    Uses scikit-learn's KNN regressor for fast, local interpolation.
    Requires scikit-learn to be installed.

    Parameters
    ----------
    ux : np.ndarray
        Horizontal velocity component (2D).
    uy : np.ndarray
        Vertical velocity component (2D).
    mask : np.ndarray
        Mask where truthy values indicate pixels to fill (converted to bool).
    n_neighbors : int, optional
        Number of neighbors for interpolation, defaults to 32.
        Typical range: 16-64. Clamped to the number of valid points.
    weights : str, optional
        Weighting scheme: "uniform" or "distance", defaults to "distance".
        "distance" usually gives smoother results.
    algorithm : str, optional
        Nearest neighbor search algorithm, defaults to "kd_tree".
        Options: "auto", "ball_tree", "kd_tree", "brute".

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        Infilled (ux, uy) as new float64 arrays. They never alias the inputs,
        including when nothing is filled.

    Raises
    ------
    ImportError
        If scikit-learn is not installed.
    """
    if not SKLEARN_AVAILABLE:
        raise ImportError(
            "scikit-learn is required for KNN infilling. "
            "Install with: pip install scikit-learn"
        )

    ux = np.asarray(ux, dtype=np.float64)
    uy = np.asarray(uy, dtype=np.float64)
    mask = np.asarray(mask, dtype=bool)
    H, W = ux.shape

    # Early exit if nothing to fill. Return copies: np.asarray does not copy
    # float64 input, so returning ux/uy directly would alias the caller's arrays.
    if not mask.any():
        return ux.copy(), uy.copy()

    yy, xx = np.mgrid[0:H, 0:W]
    valid = ~mask & np.isfinite(ux) & np.isfinite(uy)

    X_train = np.c_[xx[valid], yy[valid]]
    X_query = np.c_[xx[mask], yy[mask]]

    n_valid = X_train.shape[0]
    if n_valid == 0:
        # No training data at all: nothing can be predicted.
        return ux.copy(), uy.copy()
    # KNeighborsRegressor rejects n_neighbors > number of samples.
    n_neighbors = min(n_neighbors, n_valid)

    knn = KNeighborsRegressor(
        n_neighbors=n_neighbors,
        weights=weights,
        algorithm=algorithm,
        n_jobs=-1,  # Use all CPU cores
    )

    # Fit one multi-output model for both components (u, v)
    y_train = np.column_stack([ux[valid], uy[valid]])
    knn.fit(X_train, y_train)
    predictions = knn.predict(X_query)

    ux_filled = ux.copy()
    uy_filled = uy.copy()
    ux_filled[mask] = predictions[:, 0]
    uy_filled[mask] = predictions[:, 1]

    return ux_filled, uy_filled
380
+
381
+
382
def apply_infilling(
    ux: np.ndarray,
    uy: np.ndarray,
    mask: np.ndarray,
    method_cfg: dict,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Apply the configured infilling method.

    Dispatches to the infilling function named by ``method_cfg['method']``.
    Per-method options are passed from ``method_cfg['parameters']``.

    Parameters
    ----------
    ux : np.ndarray
        Horizontal velocity component (2D).
    uy : np.ndarray
        Vertical velocity component (2D).
    mask : np.ndarray
        Boolean mask where True indicates pixels to fill.
    method_cfg : dict
        Configuration dictionary with 'method' and 'parameters' keys.
        - A missing or null 'method' defaults to 'local_median'. The name is
          case-insensitive and surrounding whitespace is ignored.
        - A missing or null 'parameters' means "use the defaults".

    Returns
    -------
    tuple[np.ndarray, np.ndarray]
        Infilled (ux, uy) arrays.

    Raises
    ------
    ValueError
        If unknown infilling method is specified.
    """
    cfg = method_cfg or {}
    # `or` covers keys present with a null value (e.g. `parameters:` left
    # empty in YAML), which `.get(key, default)` would pass through as None.
    method = str(cfg.get('method') or 'local_median').strip().lower()
    params = cfg.get('parameters') or {}

    if method == 'local_median':
        return infill_local_median(ux, uy, mask, ksize=params.get('ksize', 3))

    if method == 'knn':
        return infill_knn(
            ux, uy, mask,
            n_neighbors=params.get('n_neighbors', 32),
            weights=params.get('weights', 'distance'),
            algorithm=params.get('algorithm', 'kd_tree'),
        )

    if method == 'biharmonic':
        return infill_biharmonic(ux, uy, mask)

    if method == 'griddata_linear':
        # Here 'method' inside parameters is the griddata interpolation kind.
        return infill_griddata_linear(ux, uy, mask, method=params.get('method', 'linear'))

    if method == 'rbf_local':
        return infill_rbf_local(
            ux, uy, mask,
            neighbors=params.get('neighbors', 64),
            kernel=params.get('kernel', 'thin_plate_spline'),
            epsilon=params.get('epsilon', None),
            smoothing=params.get('smoothing', 0.0),
        )

    raise ValueError(f"Unknown infilling method: {method}")