pivtools 0.1.3__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (127) hide show
  1. pivtools-0.1.3.dist-info/METADATA +222 -0
  2. pivtools-0.1.3.dist-info/RECORD +127 -0
  3. pivtools-0.1.3.dist-info/WHEEL +5 -0
  4. pivtools-0.1.3.dist-info/entry_points.txt +3 -0
  5. pivtools-0.1.3.dist-info/top_level.txt +3 -0
  6. pivtools_cli/__init__.py +5 -0
  7. pivtools_cli/_build_marker.c +25 -0
  8. pivtools_cli/_build_marker.cp311-win_amd64.pyd +0 -0
  9. pivtools_cli/cli.py +225 -0
  10. pivtools_cli/example.py +139 -0
  11. pivtools_cli/lib/PIV_2d_cross_correlate.c +334 -0
  12. pivtools_cli/lib/PIV_2d_cross_correlate.h +22 -0
  13. pivtools_cli/lib/common.h +36 -0
  14. pivtools_cli/lib/interp2custom.c +146 -0
  15. pivtools_cli/lib/interp2custom.h +48 -0
  16. pivtools_cli/lib/peak_locate_gsl.c +711 -0
  17. pivtools_cli/lib/peak_locate_gsl.h +40 -0
  18. pivtools_cli/lib/peak_locate_gsl_print.c +736 -0
  19. pivtools_cli/lib/peak_locate_lm.c +751 -0
  20. pivtools_cli/lib/peak_locate_lm.h +27 -0
  21. pivtools_cli/lib/xcorr.c +342 -0
  22. pivtools_cli/lib/xcorr.h +31 -0
  23. pivtools_cli/lib/xcorr_cache.c +78 -0
  24. pivtools_cli/lib/xcorr_cache.h +26 -0
  25. pivtools_cli/piv/interp2custom/interp2custom.py +69 -0
  26. pivtools_cli/piv/piv.py +240 -0
  27. pivtools_cli/piv/piv_backend/base.py +825 -0
  28. pivtools_cli/piv/piv_backend/cpu_instantaneous.py +1005 -0
  29. pivtools_cli/piv/piv_backend/factory.py +28 -0
  30. pivtools_cli/piv/piv_backend/gpu_instantaneous.py +15 -0
  31. pivtools_cli/piv/piv_backend/infilling.py +445 -0
  32. pivtools_cli/piv/piv_backend/outlier_detection.py +306 -0
  33. pivtools_cli/piv/piv_backend/profile_cpu_instantaneous.py +230 -0
  34. pivtools_cli/piv/piv_result.py +40 -0
  35. pivtools_cli/piv/save_results.py +342 -0
  36. pivtools_cli/piv_cluster/cluster.py +108 -0
  37. pivtools_cli/preprocessing/filters.py +399 -0
  38. pivtools_cli/preprocessing/preprocess.py +79 -0
  39. pivtools_cli/tests/helpers.py +107 -0
  40. pivtools_cli/tests/instantaneous_piv/test_piv_integration.py +167 -0
  41. pivtools_cli/tests/instantaneous_piv/test_piv_integration_multi.py +553 -0
  42. pivtools_cli/tests/preprocessing/test_filters.py +41 -0
  43. pivtools_core/__init__.py +5 -0
  44. pivtools_core/config.py +703 -0
  45. pivtools_core/config.yaml +135 -0
  46. pivtools_core/image_handling/__init__.py +0 -0
  47. pivtools_core/image_handling/load_images.py +464 -0
  48. pivtools_core/image_handling/readers/__init__.py +53 -0
  49. pivtools_core/image_handling/readers/generic_readers.py +50 -0
  50. pivtools_core/image_handling/readers/lavision_reader.py +190 -0
  51. pivtools_core/image_handling/readers/registry.py +24 -0
  52. pivtools_core/paths.py +49 -0
  53. pivtools_core/vector_loading.py +248 -0
  54. pivtools_gui/__init__.py +3 -0
  55. pivtools_gui/app.py +687 -0
  56. pivtools_gui/calibration/__init__.py +0 -0
  57. pivtools_gui/calibration/app/__init__.py +0 -0
  58. pivtools_gui/calibration/app/views.py +1186 -0
  59. pivtools_gui/calibration/calibration_planar/planar_calibration_production.py +570 -0
  60. pivtools_gui/calibration/vector_calibration_production.py +544 -0
  61. pivtools_gui/config.py +703 -0
  62. pivtools_gui/image_handling/__init__.py +0 -0
  63. pivtools_gui/image_handling/load_images.py +464 -0
  64. pivtools_gui/image_handling/readers/__init__.py +53 -0
  65. pivtools_gui/image_handling/readers/generic_readers.py +50 -0
  66. pivtools_gui/image_handling/readers/lavision_reader.py +190 -0
  67. pivtools_gui/image_handling/readers/registry.py +24 -0
  68. pivtools_gui/masking/__init__.py +0 -0
  69. pivtools_gui/masking/app/__init__.py +0 -0
  70. pivtools_gui/masking/app/views.py +123 -0
  71. pivtools_gui/paths.py +49 -0
  72. pivtools_gui/piv_runner.py +261 -0
  73. pivtools_gui/pivtools.py +58 -0
  74. pivtools_gui/plotting/__init__.py +0 -0
  75. pivtools_gui/plotting/app/__init__.py +0 -0
  76. pivtools_gui/plotting/app/views.py +1671 -0
  77. pivtools_gui/plotting/plot_maker.py +220 -0
  78. pivtools_gui/post_processing/POD/__init__.py +0 -0
  79. pivtools_gui/post_processing/POD/app/__init__.py +0 -0
  80. pivtools_gui/post_processing/POD/app/views.py +647 -0
  81. pivtools_gui/post_processing/POD/pod_decompose.py +979 -0
  82. pivtools_gui/post_processing/POD/views.py +1096 -0
  83. pivtools_gui/post_processing/__init__.py +0 -0
  84. pivtools_gui/static/404.html +1 -0
  85. pivtools_gui/static/_next/static/chunks/117-d5793c8e79de5511.js +2 -0
  86. pivtools_gui/static/_next/static/chunks/484-cfa8b9348ce4f00e.js +1 -0
  87. pivtools_gui/static/_next/static/chunks/869-320a6b9bdafbb6d3.js +1 -0
  88. pivtools_gui/static/_next/static/chunks/app/_not-found/page-12f067ceb7415e55.js +1 -0
  89. pivtools_gui/static/_next/static/chunks/app/layout-b907d5f31ac82e9d.js +1 -0
  90. pivtools_gui/static/_next/static/chunks/app/page-334cc4e8444cde2f.js +1 -0
  91. pivtools_gui/static/_next/static/chunks/fd9d1056-ad15f396ddf9b7e5.js +1 -0
  92. pivtools_gui/static/_next/static/chunks/framework-f66176bb897dc684.js +1 -0
  93. pivtools_gui/static/_next/static/chunks/main-a1b3ced4d5f6d998.js +1 -0
  94. pivtools_gui/static/_next/static/chunks/main-app-8a63c6f5e7baee11.js +1 -0
  95. pivtools_gui/static/_next/static/chunks/pages/_app-72b849fbd24ac258.js +1 -0
  96. pivtools_gui/static/_next/static/chunks/pages/_error-7ba65e1336b92748.js +1 -0
  97. pivtools_gui/static/_next/static/chunks/polyfills-42372ed130431b0a.js +1 -0
  98. pivtools_gui/static/_next/static/chunks/webpack-4a8ca7c99e9bb3d8.js +1 -0
  99. pivtools_gui/static/_next/static/css/7d3f2337d7ea12a5.css +3 -0
  100. pivtools_gui/static/_next/static/vQeR20OUdSSKlK4vukC4q/_buildManifest.js +1 -0
  101. pivtools_gui/static/_next/static/vQeR20OUdSSKlK4vukC4q/_ssgManifest.js +1 -0
  102. pivtools_gui/static/file.svg +1 -0
  103. pivtools_gui/static/globe.svg +1 -0
  104. pivtools_gui/static/grid.svg +8 -0
  105. pivtools_gui/static/index.html +1 -0
  106. pivtools_gui/static/index.txt +8 -0
  107. pivtools_gui/static/next.svg +1 -0
  108. pivtools_gui/static/vercel.svg +1 -0
  109. pivtools_gui/static/window.svg +1 -0
  110. pivtools_gui/stereo_reconstruction/__init__.py +0 -0
  111. pivtools_gui/stereo_reconstruction/app/__init__.py +0 -0
  112. pivtools_gui/stereo_reconstruction/app/views.py +1985 -0
  113. pivtools_gui/stereo_reconstruction/stereo_calibration_production.py +606 -0
  114. pivtools_gui/stereo_reconstruction/stereo_reconstruction_production.py +544 -0
  115. pivtools_gui/utils.py +63 -0
  116. pivtools_gui/vector_loading.py +248 -0
  117. pivtools_gui/vector_merging/__init__.py +1 -0
  118. pivtools_gui/vector_merging/app/__init__.py +1 -0
  119. pivtools_gui/vector_merging/app/views.py +759 -0
  120. pivtools_gui/vector_statistics/app/__init__.py +1 -0
  121. pivtools_gui/vector_statistics/app/views.py +710 -0
  122. pivtools_gui/vector_statistics/ensemble_statistics.py +49 -0
  123. pivtools_gui/vector_statistics/instantaneous_statistics.py +311 -0
  124. pivtools_gui/video_maker/__init__.py +0 -0
  125. pivtools_gui/video_maker/app/__init__.py +0 -0
  126. pivtools_gui/video_maker/app/views.py +436 -0
  127. pivtools_gui/video_maker/video_maker.py +662 -0
@@ -0,0 +1,342 @@
1
+ """
2
+ Module for saving PIV results to .mat files compatible with post-processing code.
3
+ """
4
+ import logging
5
+ import sys
6
+ from pathlib import Path
7
+ from typing import List, Optional, Union
8
+
9
+ import numpy as np
10
+ import scipy.io
11
+ from pivtools_core.config import Config
12
+ from pivtools_core.paths import get_data_paths
13
+
14
+ from pivtools_cli.piv.piv_result import PIVResult, PIVPassResult
15
+
16
+
17
def save_piv_result_distributed(
    piv_result: PIVResult,
    output_path: Path,
    frame_number: int,
    runs_to_save: Optional[List[int]] = None,
    vector_fmt: str = "B%05d.mat",
) -> str:
    """
    Write one frame's PIV result to a compressed .mat file; safe to run on a Dask worker.

    Writing on the worker that produced the result avoids gathering every
    result back to the main process before saving.

    Parameters
    ----------
    piv_result : PIVResult
        Result holding one or more passes in ``piv_result.passes``.
    output_path : Path
        Target directory; created (with parents) if it does not exist.
    frame_number : int
        Number substituted into ``vector_fmt`` (e.g. 1 -> "B00001.mat").
    runs_to_save : Optional[List[int]]
        0-based pass indices whose data is written. ``None`` writes every
        pass; unselected passes are stored as empty arrays.
    vector_fmt : str
        %-style filename template, e.g. "B%05d.mat".

    Returns
    -------
    str
        Path of the target file. When the result has no passes nothing is
        written (a warning is logged) but the would-be path is still returned.
    """
    target_dir = Path(output_path)
    target_dir.mkdir(parents=True, exist_ok=True)
    target_file = target_dir / (vector_fmt % frame_number)

    has_passes = len(piv_result.passes) > 0
    if not has_passes:
        logging.warning(
            f"PIVResult has no passes for frame {frame_number}. "
            "Skipping save."
        )
    else:
        # A single struct whose fields are indexed by pass number.
        struct = _create_piv_struct_all_passes(piv_result, runs_to_save)
        # Compression trades a little CPU for less disk I/O.
        scipy.io.savemat(
            target_file,
            {"piv_result": struct},
            oned_as="row",
            do_compression=True,
        )
        logging.debug(f"Worker saved PIV result to {target_file}")

    return str(target_file)
71
+
72
+
73
def save_coordinates_from_config_distributed(
    config: Config,
    output_path: Path,
    correlator_cache: Optional[dict] = None,
    runs_to_save: Optional[List[int]] = None,
) -> str:
    """
    Build per-pass coordinate grids and write them to ``coordinates.mat``.

    Suitable for running on a Dask worker. Window centres are taken from an
    ``InstantaneousCorrelatorCPU`` built from ``config``.

    Parameters
    ----------
    config : Config
        Configuration; ``len(config.window_sizes)`` sets the number of passes.
    output_path : Path
        Directory receiving ``coordinates.mat``; created if missing.
    correlator_cache : Optional[dict]
        Precomputed cache handed to the correlator to skip recomputation.
    runs_to_save : Optional[List[int]]
        0-based pass indices to write grids for. ``None`` writes every pass;
        the others get empty arrays.

    Returns
    -------
    str
        Path to the written coordinates file.
    """
    from pivtools_cli.piv.piv_backend.cpu_instantaneous import (
        InstantaneousCorrelatorCPU,
    )

    correlator = InstantaneousCorrelatorCPU(config, precomputed_cache=correlator_cache)
    centres_x = correlator.win_ctrs_x
    centres_y = correlator.win_ctrs_y

    pass_count = len(config.window_sizes)
    selected = range(pass_count) if runs_to_save is None else runs_to_save

    # MATLAB-style struct array: one element per pass with fields 'x' and 'y'.
    coords_struct = np.empty((pass_count,), dtype=[("x", object), ("y", object)])

    for idx in range(pass_count):
        if idx not in selected:
            coords_struct["x"][idx] = np.array([], dtype=np.float16)
            coords_struct["y"][idx] = np.array([], dtype=np.float16)
            continue

        # Centres are shifted by +1 and y is reversed so the smallest y ends
        # up in the bottom row of the grid.
        grid_x, grid_y = np.meshgrid(
            centres_x[idx] + 1,
            centres_y[idx][::-1] + 1,
            indexing="xy",
        )
        coords_struct["x"][idx] = _convert_to_half_precision(grid_x)
        coords_struct["y"][idx] = _convert_to_half_precision(grid_y)

    destination = Path(output_path)
    destination.mkdir(parents=True, exist_ok=True)

    coords_file = destination / "coordinates.mat"
    scipy.io.savemat(
        coords_file,
        {"coordinates": coords_struct},
        oned_as="row",
        do_compression=True,
    )
    logging.info(f"Worker saved coordinates to {coords_file}")

    return str(coords_file)
147
+
148
+
149
def _create_piv_struct_all_passes(
    piv_result: PIVResult,
    runs_to_save: Optional[List[int]] = None,
) -> np.ndarray:
    """
    Build a MATLAB-style struct array with one element per PIV pass.

    Every pass in ``piv_result.passes`` gets an element, so
    ``struct["ux"][k]`` holds the ``ux`` data of pass ``k``. Fields of passes
    not listed in ``runs_to_save``, and fields whose source attribute is
    ``None``, keep a shared ``(0, 0)`` placeholder array.

    Parameters
    ----------
    piv_result : PIVResult
        Result whose ``passes`` supply all saved data (velocities, masks,
        window centres, peak info, predictor field, window size).
    runs_to_save : Optional[List[int]]
        0-based pass indices to fill with data. ``None`` fills every pass.

    Returns
    -------
    np.ndarray
        1-D structured array with object-typed fields, suitable for
        ``scipy.io.savemat``.

    Raises
    ------
    IndexError
        If ``piv_result.passes`` is empty (the first pass is read to choose
        the placeholder dtype).
    """
    passes = piv_result.passes
    pass_count = len(passes)
    selected = range(pass_count) if runs_to_save is None else runs_to_save

    field_names = (
        "ux",
        "uy",
        "b_mask",
        "nan_mask",
        "win_ctrs_x",
        "win_ctrs_y",
        "peak_mag",
        "peak_choice",
        "n_windows",
        "predictor_field",
        "window_size",
    )
    piv_struct = np.empty(
        (pass_count,), dtype=[(name, object) for name in field_names]
    )

    # Placeholders reuse the dtype of the first pass's ux when it holds data.
    first_ux = passes[0].ux_mat
    has_reference = first_ux is not None and first_ux.size > 0
    placeholder = np.empty(
        (0, 0), dtype=first_ux.dtype if has_reference else np.float64
    )

    for idx in range(pass_count):
        for name in field_names:
            piv_struct[name][idx] = placeholder

    for idx in range(pass_count):
        if idx not in selected:
            continue
        current = passes[idx]

        # Fields routed through _convert_to_half_precision before storing.
        converted = {
            "ux": current.ux_mat,
            "uy": current.uy_mat,
            "win_ctrs_x": current.win_ctrs_x,
            "win_ctrs_y": current.win_ctrs_y,
            "peak_mag": current.peak_mag,
            "predictor_field": current.predictor_field,
        }
        # Fields stored as-is; b_mask falls back to nan_mask when absent.
        verbatim = {
            "b_mask": current.b_mask if current.b_mask is not None else current.nan_mask,
            "nan_mask": current.nan_mask,
            "peak_choice": current.peak_choice,
            "n_windows": current.n_windows,
            "window_size": current.window_size,
        }

        for name, value in converted.items():
            if value is not None:
                piv_struct[name][idx] = _convert_to_half_precision(value)
        for name, value in verbatim.items():
            if value is not None:
                piv_struct[name][idx] = value

    return piv_struct
267
+
268
+
269
+ # Note: get_data_paths is imported from pivtools_core.paths at the top of this file
270
+
271
+
272
def get_output_path(
    config: Config,
    camera: Union[int, str],
    create: bool = True,
    use_uncalibrated: bool = True,
) -> Path:
    """
    Get the output directory for one camera's PIV results.

    The directory is built by ``get_data_paths`` from ``config.base_paths[0]``,
    ``config.num_images``, the camera number and the PIV type. The layout is
    documented as:
    - Uncalibrated: base_path/uncalibrated_piv/{num_images}/Cam{camera}/instantaneous
    - Calibrated: base_path/calibrated_piv/{num_images}/Cam{camera}/instantaneous
    The PIV type passed in is "instantaneous" unless
    ``processing.instantaneous`` is false in ``config.data``, in which case it
    is "ensemble".

    Parameters
    ----------
    config : Config
        Configuration object.
    camera : Union[int, str]
        Camera number (int) or label (str). Accepted labels include "2",
        "Cam2", and case/whitespace variants such as " cam2 ".
    create : bool
        If True, create the directory if it doesn't exist.
    use_uncalibrated : bool
        If True, target the uncalibrated_piv tree; otherwise calibrated_piv.

    Returns
    -------
    Path
        Output path for PIV results (``paths["data_dir"]``).

    Raises
    ------
    ValueError
        If a string ``camera`` does not contain an integer after the optional
        "Cam" prefix.
    """
    base_path = config.base_paths[0]

    if isinstance(camera, str):
        label = camera.strip()
        # Strip an optional "Cam" prefix, case-insensitively ("cam2" used to fail).
        if label[:3].lower() == "cam":
            label = label[3:]
        camera_num = int(label)
    else:
        camera_num = camera

    # "processing" may be present but null (e.g. an empty YAML section);
    # treat that like a missing section instead of calling .get on None.
    processing = config.data.get("processing") or {}
    piv_type = "instantaneous" if processing.get("instantaneous", True) else "ensemble"

    # get_data_paths positional args: base_dir, num_images, cam, type_name.
    paths = get_data_paths(
        base_path,
        config.num_images,
        camera_num,
        piv_type,
        endpoint="",
        use_uncalibrated=use_uncalibrated,
    )

    output_path = paths["data_dir"]

    if create:
        output_path.mkdir(parents=True, exist_ok=True)

    return output_path
332
+
333
+
334
def _convert_to_half_precision(arr: np.ndarray) -> np.ndarray:
    """
    Downcast floating-point arrays to single precision (float32) for space saving.

    The name is historical: this used to cast to float16. MAT-file v5 has no
    half-precision class, so ``scipy.io.savemat`` silently upcasts float16
    arrays to float64 when writing. The old cast therefore lost precision
    (float16 rounds to whole pixels above 1024 and to even pixels above 2048)
    while producing files *larger* than float32 input would. float32 is the
    smallest floating type a MAT file stores natively.

    ``None``, empty arrays and non-float arrays (e.g. bool masks, integer
    indices) are returned unchanged.
    """
    if arr is None or arr.size == 0:
        return arr
    if arr.dtype.kind == "f":
        # copy=False: already-float32 input is returned as-is, no extra copy.
        return arr.astype(np.float32, copy=False)
    return arr
@@ -0,0 +1,108 @@
1
import logging
import os
import sys
from collections import defaultdict
from pathlib import Path
from typing import List, Optional, Tuple

from dask.distributed import Client, LocalCluster

from pivtools_core.config import Config
11
+
12
+
13
def make_cluster(
    threads_per_worker: int = 1,
    n_workers_per_node: int = 2,
    memory_limit: str = "auto",
) -> Tuple[LocalCluster, Client]:
    """
    Create a local Dask cluster and a client connected to it.

    Parameters
    ----------
    threads_per_worker : int
        Threads per worker, forwarded to ``LocalCluster`` (start_cluster may
        forward None here).
    n_workers_per_node : int
        Number of workers to start (``LocalCluster(n_workers=...)``).
    memory_limit : str
        Per-worker memory limit forwarded to ``LocalCluster``.

    Returns
    -------
    Tuple[LocalCluster, Client]
        The cluster and the client attached to it.
    """
    cluster_options = {
        "n_workers": n_workers_per_node,
        "threads_per_worker": threads_per_worker,
        "memory_limit": memory_limit,
        "nanny": True,
    }
    local_cluster = LocalCluster(**cluster_options)
    return local_cluster, Client(local_cluster)
26
+
27
+
28
def group_workers_by_host(client: Client) -> dict[str, List[str]]:
    """
    Group the scheduler's worker addresses by the host they run on.

    Returns a plain dict mapping ``info["host"]`` to the list of worker
    addresses on that host, in the order the scheduler reports them.
    """
    workers_info = client.scheduler_info()["workers"]
    hosts = defaultdict(list)
    for worker_address in workers_info:
        host_name = workers_info[worker_address]["host"]
        hosts[host_name].append(worker_address)
    return dict(hosts)
34
+
35
+
36
def select_workers_per_node(client: Client, n_workers_per_node: int = 1) -> List[str]:
    """
    Pick up to ``n_workers_per_node`` worker addresses from each host.

    Hosts are visited in the order returned by ``group_workers_by_host``;
    within a host the first addresses are taken.
    """
    return [
        address
        for host_workers in group_workers_by_host(client).values()
        for address in host_workers[:n_workers_per_node]
    ]
42
+
43
+
44
def start_cluster(
    n_workers_per_node: int = 1,
    threads_per_worker: Optional[int] = None,
    memory_limit: str = "auto",
    config: Optional[Config] = None,
) -> tuple[LocalCluster, Client]:
    """
    Start a local Dask cluster and configure logging on every worker.

    Parameters
    ----------
    n_workers_per_node : int
        Number of workers to launch (forwarded to ``make_cluster``).
    threads_per_worker : Optional[int]
        Threads per worker, forwarded unchanged to ``make_cluster``.
    memory_limit : str
        Per-worker memory limit forwarded to ``make_cluster``.
    config : Optional[Config]
        Supplies ``log_level`` and, optionally, ``log_file`` for worker
        logging. A fresh ``Config()`` is created per call when omitted
        (previously one instance was built at import time and shared).

    Returns
    -------
    tuple[LocalCluster, Client]
        The running cluster and a client connected to it.

    Raises
    ------
    Exception
        Whatever cluster start-up or worker logging setup raised. Any client
        or cluster created so far is closed before re-raising.
    """
    if config is None:
        config = Config()

    # Resolve the configured level to an int. getattr on the logging module
    # alone is unsafe: "info" resolves to the function logging.info, which
    # setLevel rejects. Only genuine integer level constants are accepted.
    raw_level = getattr(config, "log_level", None)
    if isinstance(raw_level, int):
        log_level = raw_level
    else:
        candidate = getattr(logging, str(raw_level).upper(), None)
        log_level = candidate if isinstance(candidate, int) else logging.INFO

    cluster = None
    client = None

    try:
        cluster, client = make_cluster(
            n_workers_per_node=n_workers_per_node,
            threads_per_worker=threads_per_worker,
            memory_limit=memory_limit,
        )
        client.run(
            setup_worker_logging,
            log_level=log_level,
            log_file=getattr(config, "log_file", None),
            log_console=True,
        )

        return cluster, client

    except Exception as e:
        logging.error("Error starting Dask cluster: %s", e)
        if client is not None:
            client.close()
        if cluster is not None:
            cluster.close()
        raise
82
+
83
+
84
def setup_worker_logging(log_level=logging.INFO, log_file=None, log_console=True):
    """
    Configure the root logger inside a Dask worker process.

    Handlers already attached to the root logger are removed and closed.
    Previously they were only removed, so a FileHandler from an earlier call
    kept its file open. A file handler and/or a console handler are then
    attached, sharing the format "%(asctime)s [%(levelname)s] %(message)s".

    Parameters
    ----------
    log_level : int or str
        Numeric level (e.g. ``logging.INFO``) or level name. Names are matched
        case-insensitively, so "debug" works as well as "DEBUG".
    log_file : str or path-like, optional
        If given, records are also appended to this file.
    log_console : bool
        If True, records are also written to a StreamHandler (stderr).

    Raises
    ------
    ValueError
        If ``log_level`` is a string that is not a known level name.
    """
    if isinstance(log_level, str):
        # Logger.setLevel only recognises upper-case level names.
        log_level = log_level.strip().upper()

    logger = logging.getLogger()
    logger.setLevel(log_level)

    for handler in logger.handlers[:]:
        logger.removeHandler(handler)
        # Release resources (e.g. a FileHandler's open file) of old handlers.
        handler.close()

    formatter = logging.Formatter("%(asctime)s [%(levelname)s] %(message)s")

    if log_file:
        file_handler = logging.FileHandler(log_file)
        file_handler.setLevel(log_level)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

    if log_console:
        console_handler = logging.StreamHandler()
        console_handler.setLevel(log_level)
        console_handler.setFormatter(formatter)
        logger.addHandler(console_handler)

    logger.info("Worker logging configured successfully")