cdo-toolkit 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,2230 @@
1
+ """CDO regridding pipeline core."""
2
+
3
+ import hashlib
4
+ import logging
5
+ import multiprocessing as mp
6
+ import os
7
+ import shutil
8
+ import signal
9
+ import tempfile
10
+ import threading
11
+ import time
12
+ import traceback
13
+ import uuid
14
+ from concurrent.futures import ProcessPoolExecutor, as_completed
15
+ from pathlib import Path
16
+ from typing import Optional
17
+
18
+ import numpy as np
19
+ import psutil
20
+ import xarray as xa
21
+ from cdo import Cdo
22
+ from rich.console import Console
23
+ from rich.panel import Panel
24
+ from rich.table import Table
25
+
26
+ try:
27
+ import fcntl
28
+ except ImportError:
29
+ fcntl = None
30
+
31
+ from cdo_toolkit.constants import NC4_ENCODING_KEYS, REGRID_ERROR_LOGGER_NAME
32
+ from cdo_toolkit.errors import init_regrid_error_log, log_regrid_error, weight_file_lock
33
+ from cdo_toolkit.memory import MemoryMonitor
34
+ from cdo_toolkit.paths import is_weights_or_cache_file, weight_cache_dir_for_input
35
+ from cdo_toolkit.resolution import calc_resolution
36
+ from cdo_toolkit.timing import format_processing_time, get_processing_time, print_timestamp
37
+ from cdo_toolkit.ui import BatchRegridUI, RegridProgressUI
38
+ from cdo_toolkit.cmip import representative_files_by_directory
39
+ from cdo_toolkit.workers import poll_batch_progress, process_single_file_standalone
40
+
41
+ class CDORegridPipeline:
42
+ """
43
+ CDO regridding pipeline with all advanced features.
44
+ """
45
+
46
def __init__(
    self,
    target_resolution: tuple[float, float] = (1.0, 1.0),  # TODO: automate this based on native resolution of input file
    target_grid: str = "lonlat",
    weight_cache_dir: Optional[Path] = None,
    extract_surface: bool = False,
    extract_seafloor: bool = False,
    use_regrid_cache: bool = True,
    use_seafloor_cache: bool = True,
    verbose: bool = True,
    verbose_diagnostics: bool = False,
    max_memory_gb: float = 8.0,
    max_workers: Optional[int] = 16,
    chunk_size_gb: float = 2.0,
    enable_parallel: bool = True,
    enable_chunking: bool = True,
    memory_monitoring: bool = True,
    cleanup_weights: bool = False,
):
    """
    Initialize the CDO regridding pipeline.

    Pipeline behaviour:
    - If neither extract_surface nor extract_seafloor: regrid the whole file (use weight cache if use_regrid_cache).
    - If extract_seafloor: identify seafloor indices (from cache if use_seafloor_cache), extract seafloor, regrid only that.
    - If extract_surface: extract top level and regrid only that (use weight cache if use_regrid_cache).
    - If both extract_surface and extract_seafloor: perform seafloor then surface sequentially (use regrid_single_file twice or regrid_single_file_extreme_levels).

    Parameters
    ----------
    target_resolution (tuple): Target resolution as (lon_res, lat_res) in degrees.
    target_grid (str): Target grid type ('lonlat', 'gaussian', etc.).
    weight_cache_dir (Path, optional): Directory to cache regrid weights (created if missing).
    extract_surface (bool): If True, extract top level only and regrid that.
    extract_seafloor (bool): If True, extract seafloor values (deepest non-NaN) and regrid only that.
    use_regrid_cache (bool): If True, reuse existing regrid weight files when present.
    use_seafloor_cache (bool): If True, reuse in-memory seafloor depth indices for files in the same directory.
    verbose (bool): Enable verbose output (progress UI, etc.).
    verbose_diagnostics (bool): If True, print Grid type, File size, Levels, Large file messages (max verbosity).
    max_memory_gb (float): Maximum memory usage in GB.
    max_workers (int, optional): Maximum number of parallel workers. None means
        "one per CPU"; explicit values are capped at the CPU count (minimum 1).
    chunk_size_gb (float): Maximum chunk size in GB for large files.
    enable_parallel (bool): Enable parallel processing.
    enable_chunking (bool): Enable chunked processing for large files.
    memory_monitoring (bool): Enable memory usage monitoring.
    cleanup_weights (bool): If True and weight_cache_dir is set, call
        cleanup_weight_files() during initialization.
    """
    self.target_resolution = target_resolution
    self.target_grid = target_grid
    self.extract_surface = extract_surface
    self.extract_seafloor = extract_seafloor
    self.use_regrid_cache = use_regrid_cache
    self.use_seafloor_cache = use_seafloor_cache
    self.verbose = verbose
    self.verbose_diagnostics = verbose_diagnostics
    self.max_memory_gb = max_memory_gb
    self.chunk_size_gb = chunk_size_gb
    self.enable_parallel = enable_parallel
    self.enable_chunking = enable_chunking
    self.memory_monitoring = memory_monitoring
    self.cleanup_weights = cleanup_weights
    self.prune_regridded = True

    # Cache for seafloor depth indices per directory (for optimization)
    self._seafloor_depth_cache: dict[str, dict[str, int]] = {}

    # ensure not requesting more workers than available.
    # The previous code read self.max_workers before it was ever assigned when
    # max_workers was None, and never applied the cap to explicit values.
    available_cpus = mp.cpu_count()
    if max_workers is None:
        self.max_workers = available_cpus
    else:
        self.max_workers = max(1, min(max_workers, available_cpus))

    # set up CDO with optimized settings (reads enable_parallel and max_workers set above)
    self.cdo = self._setup_cdo()

    # set up console and logger for output
    self.console = Console()
    self.logger = self._setup_logger()
    self._error_log_path: Optional[Path] = None

    # weight cache management
    self.weight_cache_dir = Path(weight_cache_dir) if weight_cache_dir else None
    if self.weight_cache_dir:
        self.weight_cache_dir.mkdir(parents=True, exist_ok=True)
    self.weight_cache: dict[str, Path] = {}
    if self.weight_cache_dir and self.cleanup_weights:
        self.cleanup_weight_files()

    # monitor memory to prevent antisocial behaviour on shared machines and out-of-memory errors
    self.memory_monitor = MemoryMonitor() if memory_monitoring else None

    # prepare to store meta-data and processing statistics
    self.stats = {
        'files_processed': 0,
        'weights_reused': 0,
        'weights_generated': 0,
        'errors': 0,
        'total_size_gb': 0.0,
        'chunks_processed': 0,
        'memory_peak_gb': 0.0,
        'grid_types': {
            'structured': 0,
            'curvilinear': 0,
            'tripolar_ocean': 0,
            'unstructured_ncells': 0,
            'unknown': 0,
        }
    }

    # cache for file info to avoid repeated expensive operations
    self._file_info_cache: dict[Path, dict] = {}

    # track created files for cleanup on interrupt
    self._created_files: list[Path] = []
    self._progress_state = None
    self._progress_key: Optional[str] = None
    self._setup_signal_handlers()

    # timing tracking
    self._start_time: Optional[time.struct_time] = None
    self.end_time: Optional[time.struct_time] = None
166
+
167
+ def _get_target_variable(self, file_path: Path) -> str:
168
+ """Get the target variable for the file from the file_path. N.B. this is specific to CMIP data"""
169
+ return file_path.stem.split("_")[0]
170
+
171
+ def _prune_regridded(self, input_files: list[Path], overwrite: bool = False) -> list[Path]:
172
+ """Prune files with 'regridded' in the name to avoid processing them twice."""
173
+ # ignore any files which have 'cdo_weights' as any parent directory
174
+ input_files = [file for file in input_files if 'cdo_weights' not in file.parents]
175
+ # TODO: if ever getting levels from a regridded file will need to rethink this
176
+ if self.prune_regridded:
177
+ if overwrite: # delete existing regridded files
178
+ for file in input_files:
179
+ if 'regridded' in file.name:
180
+ if self.verbose:
181
+ self.console.print(f"[yellow]Removing existing regridded file: {file.name}[/yellow]")
182
+ file.unlink()
183
+ else: # skip existing regridded files (don't regenerate)
184
+ if self.verbose:
185
+ regridded_files = [file for file in input_files if 'regridded' in file.name]
186
+ if regridded_files:
187
+ self.console.print(f"[blue]Skipping {len(regridded_files)} existing regridded files (overwrite=False)[/blue]")
188
+ return [file for file in input_files if 'regridded' in file.name and not "_chunk_" in file.name]
189
+
190
+ def _cleanup_files_by_pattern(
191
+ self,
192
+ input_files: list[Path],
193
+ pattern: str,
194
+ file_type: str,
195
+ exclude_regridded: bool = False
196
+ ) -> list[Path]:
197
+ """Generic helper to clean up files matching a pattern.
198
+
199
+ Args:
200
+ input_files: List of file paths to check
201
+ pattern: Pattern to match in filename (e.g., '_top_level', '_chunk_')
202
+ file_type: Description for logging (e.g., '_top_level', '_chunk_')
203
+ exclude_regridded: If True, exclude files with 'regridded' in name
204
+
205
+ Returns:
206
+ List of cleaned files (with problematic files removed)
207
+ """
208
+ cleaned_files = []
209
+ removed_count = 0
210
+
211
+ for file_path in input_files:
212
+ should_remove = pattern in file_path.name
213
+ if exclude_regridded and should_remove:
214
+ should_remove = "regridded" not in file_path.name
215
+
216
+ if should_remove:
217
+ if self.verbose:
218
+ self.console.print(
219
+ f"[yellow]Removing problematic {file_type} file: {file_path.name}[/yellow]"
220
+ )
221
+ try:
222
+ file_path.unlink()
223
+ removed_count += 1
224
+ except Exception as e:
225
+ self.logger.warning(f"Could not remove {file_path}: {e}")
226
+ else:
227
+ cleaned_files.append(file_path)
228
+
229
+ if removed_count > 0 and self.verbose:
230
+ self.console.print(
231
+ f"[blue]Cleaned up {removed_count} problematic {file_type} files[/blue]"
232
+ )
233
+
234
+ return cleaned_files
235
+
236
+ def _cleanup_top_level_files(self, input_files: list[Path]) -> list[Path]:
237
+ """Clean up existing _top_level files that may cause HDF errors."""
238
+ # TODO: why do I need to get rid of top_level files?
239
+ return self._cleanup_files_by_pattern(
240
+ input_files,
241
+ pattern='_top_level',
242
+ file_type='_top_level',
243
+ exclude_regridded=True
244
+ )
245
+
246
+ def _cleanup_chunk_files(self, input_files: list[Path]) -> list[Path]:
247
+ """Clean up existing _chunk_ files that may cause HDF errors."""
248
+ return self._cleanup_files_by_pattern(
249
+ input_files,
250
+ pattern='_chunk_',
251
+ file_type='_chunk_',
252
+ exclude_regridded=False
253
+ )
254
+
255
+ def _cleanup_problematic_files(self, input_files: list[Path]) -> list[Path]:
256
+ """Clean up all problematic files (_top_level and _chunk_) that may cause HDF errors.
257
+ Returns the list of input files that remain for processing (non-problematic).
258
+ """
259
+ remaining = self._cleanup_files_by_pattern(
260
+ input_files,
261
+ pattern='_top_level',
262
+ file_type='_top_level',
263
+ exclude_regridded=True # keep regridded top_level files (final form)
264
+ )
265
+ return self._cleanup_files_by_pattern(
266
+ remaining,
267
+ pattern='_chunk_',
268
+ file_type='_chunk_',
269
+ exclude_regridded=False # get rid of everything that's still a chunk
270
+ )
271
+
272
+ def _setup_signal_handlers(self):
273
+ """Set up signal handlers for graceful cleanup on keyboard interrupt."""
274
+ import signal
275
+
276
+ def signal_handler(signum, frame):
277
+ """Handle keyboard interrupt and cleanup created files."""
278
+ if self.verbose:
279
+ self.console.print(f"\n[yellow]Received interrupt signal ({signum}). Cleaning up...[/yellow]")
280
+
281
+ self._cleanup_created_files()
282
+
283
+ if self.verbose:
284
+ self.console.print(f"[red]Interrupted. Cleaned up {len(self._created_files)} created files.[/red]")
285
+
286
+ # exit gracefully
287
+ import sys
288
+ sys.exit(1)
289
+
290
+ # register signal handlers
291
+ signal.signal(signal.SIGINT, signal_handler) # Ctrl+C
292
+ signal.signal(signal.SIGTERM, signal_handler) # termination signal
293
+
294
+ def _cleanup_created_files(self):
295
+ """Clean up all files created during processing."""
296
+ for file_path in self._created_files:
297
+ try:
298
+ if file_path.exists():
299
+ file_path.unlink()
300
+ if self.verbose_diagnostics:
301
+ self.logger.info(f"Cleaned up created file: {file_path.name}")
302
+ except Exception as e:
303
+ self.logger.warning(f"Could not clean up {file_path}: {e}")
304
+
305
+ self._created_files.clear()
306
+
307
+ def _track_created_file(self, file_path: Path):
308
+ """Track a file that was created during processing for cleanup."""
309
+ self._created_files.append(file_path)
310
+
311
def _setup_cdo(self) -> Cdo:
    """Export the CDO tuning environment variables and return a new Cdo instance.

    Reads self.enable_parallel and self.max_workers, so both must be set
    before this is called.
    """
    import os

    cdo_environment = {
        "CDO_NETCDF_COMPRESSION": "0",  # disable compression for speed TODO: compress if space becomes an issue
        "CDO_NETCDF_64BIT_OFFSET": "1",  # use 64-bit offsets for large files
        "CDO_NETCDF_USE_PARALLEL": "1" if self.enable_parallel else "0",  # enable parallel I/O
        "CDO_NUM_THREADS": str(self.max_workers),
    }
    os.environ.update(cdo_environment)

    return Cdo()
322
+
323
def _setup_logger(self) -> logging.Logger:
    """Return the logger registered under REGRID_ERROR_LOGGER_NAME.

    Errors themselves are written to the shared session log via log_regrid_error.
    """
    pipeline_logger = logging.getLogger(REGRID_ERROR_LOGGER_NAME)
    return pipeline_logger
326
+
327
def set_error_log_path(self, path: Path) -> None:
    """Initialize the shared regrid error log for this run and remember the path it reports."""
    resolved_log_path = init_regrid_error_log(path)
    self._error_log_path = resolved_log_path
330
+
331
def _log_error(self, message: str, exc: Optional[BaseException] = None) -> None:
    """Forward ``message`` (and the optional exception ``exc``) to log_regrid_error."""
    log_regrid_error(message, exc=exc)
333
+
334
+ def _detect_grid_type(self, ds: xa.Dataset, dims: dict) -> str:
335
+ """
336
+ Detect the type of grid structure with comprehensive support.
337
+
338
+ Returns:
339
+ - 'unstructured_ncells': Grid with ncells dimension (true unstructured)
340
+ - 'tripolar_ocean': Tri-polar ocean grid with (y,x) dimensions, 2D lat/lon coords, and cell boundaries
341
+ - 'curvilinear': Grid with (j,i) dimensions and 2D lat/lon coordinates
342
+ - 'structured': Regular lon/lat grid with 1D coordinates
343
+ - 'unknown': Cannot determine grid type
344
+ """
345
+ # check for true unstructured grid (ncells dimension)
346
+ if 'ncells' in dims:
347
+ return 'unstructured_ncells'
348
+
349
+ # check for tri-polar ocean grid (y,x dimensions with 2D lat/lon coords and cell boundaries)
350
+ if 'y' in dims and 'x' in dims:
351
+ lat_coords = ['lat', 'latitude']
352
+ lon_coords = ['lon', 'longitude']
353
+
354
+ has_2d_lat = any(coord in ds.coords and ds[coord].ndim == 2 for coord in lat_coords)
355
+ has_2d_lon = any(coord in ds.coords and ds[coord].ndim == 2 for coord in lon_coords)
356
+
357
+ if has_2d_lat and has_2d_lon:
358
+ # Check for cell boundaries (indicative of tri-polar ocean grids)
359
+ has_bounds_lat = ('bounds_lat' in list(ds.coords) or 'bounds_lat' in list(ds.data_vars))
360
+ has_bounds_lon = ('bounds_lon' in list(ds.coords) or 'bounds_lon' in list(ds.data_vars))
361
+ has_nvertex = 'nvertex' in dims
362
+
363
+ # Check for ocean-specific variables
364
+ ocean_vars = ['so', 'thetao', 'uo', 'vo', 'wo', 'zos', 'mlotst', 'sithick']
365
+ has_ocean_vars = any(var in ds.data_vars for var in ocean_vars)
366
+
367
+ if (has_bounds_lat and has_bounds_lon and has_nvertex) or has_ocean_vars:
368
+ return 'tripolar_ocean'
369
+ else:
370
+ return 'curvilinear'
371
+
372
+ # check for curvilinear grid (j,i dimensions with 2D lat/lon coords)
373
+ if 'j' in dims and 'i' in dims:
374
+ # check if lat/lon are 2D coordinates on (j,i) grid
375
+ lat_coords = ['lat', 'latitude']
376
+ lon_coords = ['lon', 'longitude']
377
+
378
+ has_2d_lat = any(coord in ds.coords and ds[coord].ndim == 2 for coord in lat_coords)
379
+ has_2d_lon = any(coord in ds.coords and ds[coord].ndim == 2 for coord in lon_coords)
380
+
381
+ if has_2d_lat and has_2d_lon:
382
+ return 'curvilinear'
383
+
384
+ # check for structured grid (1D lat/lon coordinates)
385
+ lat_coords = ['lat', 'latitude']
386
+ lon_coords = ['lon', 'longitude']
387
+
388
+ has_1d_lat = any(coord in ds.coords and ds[coord].ndim == 1 for coord in lat_coords)
389
+ has_1d_lon = any(coord in ds.coords and ds[coord].ndim == 1 for coord in lon_coords)
390
+
391
+ if has_1d_lat and has_1d_lon:
392
+ return 'structured'
393
+
394
+ raise ValueError(f"Could not determine grid type for file TODO") # TODO: get access to file_path to raise error
395
+
396
+ def _has_grid_corners(self, ds: xa.Dataset) -> bool:
397
+ """Return True when CDO can use conservative remapping (cell corners present)."""
398
+ names = set(ds.coords) | set(ds.data_vars)
399
+ lat_bounds = {"bounds_lat", "lat_bnds", "vertices_latitude"} & names
400
+ lon_bounds = {"bounds_lon", "lon_bnds", "vertices_longitude"} & names
401
+ if lat_bounds and lon_bounds:
402
+ return True
403
+
404
+ for lat_name, lon_name in (("lat", "lon"), ("latitude", "longitude")):
405
+ if lat_name not in ds.coords or lon_name not in ds.coords:
406
+ continue
407
+ lat_bnds = ds[lat_name].attrs.get("bounds")
408
+ lon_bnds = ds[lon_name].attrs.get("bounds")
409
+ if lat_bnds in names and lon_bnds in names:
410
+ return True
411
+ return False
412
+
413
+ def _regrid_operators(self, grid_type: str, has_grid_corners: bool) -> tuple[str, str]:
414
+ """Return CDO (remap, gen_weights) operators for this source grid."""
415
+ if grid_type == "structured" or has_grid_corners:
416
+ return "remapcon", "gencon"
417
+ return "remapbil", "genbil"
418
+
419
+ def _whether_multi_level(self, dims: dict) -> tuple[bool, list[str], int]:
420
+ """Whether the file has a multi-level dimension.
421
+
422
+ Returns:
423
+ - has_level (bool): Whether the file has level dimensions
424
+ - level_dims (list[str]): List of level dimension names found
425
+ - level_count (int): Total number of level dimensions
426
+ """
427
+ level_dims = ['lev', 'level', 'depth', 'z']
428
+ has_level = any(dim in dims for dim in level_dims)
429
+ level_count = sum(dims.get(dim, 0) for dim in level_dims)
430
+ return has_level, level_dims, level_count
431
+
432
+ def _has_level_lightweight(self, file_path: Path) -> bool:
433
+ """Lightweight check for level dimensions without opening the full dataset.
434
+
435
+ This method only reads the dimensions from the NetCDF file header,
436
+ avoiding the expensive operation of loading the full dataset.
437
+
438
+ Args:
439
+ file_path (Path): Path to the NetCDF file
440
+
441
+ Returns:
442
+ bool: True if the file has level dimensions, False otherwise
443
+ """
444
+ try:
445
+ # Use xarray's minimal loading to get just dimensions
446
+ with xa.open_dataset(file_path, decode_times=False) as ds:
447
+ dims = dict(ds.sizes)
448
+ has_level, _, _ = self._whether_multi_level(dims)
449
+ return has_level
450
+ except Exception as e:
451
+ self.logger.warning(f"Could not check levels in {file_path}: {e}")
452
+ return False
453
+
454
+ def _get_file_info(self, file_path: Path) -> dict:
455
+ """Get comprehensive information about a file with caching.
456
+
457
+ Returns:
458
+ - 'file_size_gb': File size in GB
459
+ - 'dims': Dimensions of the file
460
+ - 'has_level': Whether the file has a level dimension
461
+ - 'level_count': Number of level dimensions
462
+ - 'coords': Coordinate information
463
+ - 'grid_type': Grid type
464
+ - 'estimated_memory_gb': Estimated memory usage in GB
465
+ - 'time_steps': Number of time steps
466
+ """
467
+ # check cache first
468
+ if file_path in self._file_info_cache:
469
+ return self._file_info_cache[file_path]
470
+
471
+ try:
472
+ ds = xa.open_dataset(file_path, decode_times=False)
473
+
474
+ # get metadata
475
+ file_size_gb = file_path.stat().st_size / (1024**3) # faster than ds.nbytes
476
+ dims = dict(ds.sizes)
477
+ has_level, _, level_count = self._whether_multi_level(dims)
478
+
479
+ # get coordinate information
480
+ coords_info = {}
481
+ try:
482
+ for coord in ['lon', 'longitude', 'lat', 'latitude']:
483
+ if coord in ds.coords:
484
+ coord_data = ds[coord]
485
+ coords_info[coord] = {
486
+ 'shape': coord_data.shape,
487
+ 'ndim': coord_data.ndim,
488
+ 'size': coord_data.size,
489
+ }
490
+ except Exception as coord_error:
491
+ self.logger.warning(f"Could not get coordinate info for {file_path}: {coord_error}")
492
+ coords_info = {}
493
+
494
+ # determine grid type
495
+ try:
496
+ grid_type = self._detect_grid_type(ds, dims)
497
+ except Exception as grid_error:
498
+ self.logger.warning(f"Could not detect grid type for {file_path}: {grid_error}")
499
+ grid_type = 'unknown'
500
+
501
+ has_grid_corners = (
502
+ grid_type == "structured" or self._has_grid_corners(ds)
503
+ )
504
+
505
+ # estimate memory usage required for regridding
506
+ estimated_memory_gb = file_size_gb * 3 # rough but conservative estimate: when regridding will have original, multiple, and maybe an intermediate array in memory
507
+
508
+ file_info = {
509
+ 'file_size_gb': file_size_gb,
510
+ 'dims': dims,
511
+ 'has_level': has_level,
512
+ 'level_count': level_count,
513
+ 'coords': coords_info,
514
+ 'grid_type': grid_type,
515
+ 'has_grid_corners': has_grid_corners,
516
+ 'estimated_memory_gb': estimated_memory_gb,
517
+ 'time_steps': dims.get('time', 1),
518
+ }
519
+
520
+ # cache the result
521
+ self._file_info_cache[file_path] = file_info
522
+ return file_info
523
+
524
+ except Exception as e:
525
+ self.logger.warning(f"Could not analyze file {file_path}: {e}")
526
+ error_info = {
527
+ 'file_size_gb': 0.0,
528
+ 'dims': {},
529
+ 'has_level': False,
530
+ 'level_count': 0,
531
+ 'coords': {},
532
+ 'grid_type': 'unknown',
533
+ 'has_grid_corners': False,
534
+ 'estimated_memory_gb': 0.0,
535
+ 'time_steps': None,
536
+ }
537
+ # cache the error result too to avoid repeated failures
538
+ self._file_info_cache[file_path] = error_info
539
+ return error_info
540
+
541
+ def _get_grid_signature(self, file_info: dict) -> str:
542
+ """Generate a unique signature for the grid based on file info to avoid duplicating weights."""
543
+ remap_op, gen_op = self._regrid_operators(
544
+ file_info['grid_type'],
545
+ file_info.get('has_grid_corners', False),
546
+ )
547
+ signature_data = {
548
+ 'coords': file_info['coords'],
549
+ 'dims': file_info['dims'],
550
+ 'grid_type': file_info['grid_type'],
551
+ 'remap_op': remap_op,
552
+ 'gen_op': gen_op,
553
+ }
554
+
555
+ signature_str = str(sorted(signature_data.items()))
556
+ return hashlib.md5(signature_str.encode()).hexdigest()[:12]
557
+
558
+ def _get_weight_path(self, grid_signature: str) -> Path:
559
+ """Get the path for regrid weights based on grid signature."""
560
+ if self.weight_cache_dir is None:
561
+ raise RuntimeError("weight_cache_dir is not set")
562
+ return self.weight_cache_dir / f"weights_{grid_signature}.nc"
563
+
564
+ def _cleanup_stale_chunks_for_file(self, file_path: Path) -> None:
565
+ """Remove leftover chunk files from prior failed runs of the same source file."""
566
+ stem = file_path.stem.replace("_top_level", "").replace("_seafloor", "")
567
+ for chunk in file_path.parent.glob(f"{stem}_chunk_*.nc"):
568
+ try:
569
+ chunk.unlink()
570
+ except OSError as e:
571
+ self._log_error(f"Could not remove stale chunk {chunk}: {e}")
572
+
573
+ def _is_valid_weight_file(self, weight_path: Path) -> bool:
574
+ """Return False for missing, empty, or corrupt cached weight files."""
575
+ if not weight_path.exists() or weight_path.stat().st_size == 0:
576
+ return False
577
+ try:
578
+ with xa.open_dataset(weight_path, decode_times=False) as ds:
579
+ return len(ds.dims) > 0
580
+ except Exception:
581
+ return False
582
+
583
+ def _invalidate_weight_file(self, weight_path: Path) -> None:
584
+ if weight_path.exists():
585
+ try:
586
+ weight_path.unlink()
587
+ except OSError:
588
+ pass
589
+
590
+ def _generate_output_filename(
591
+ self,
592
+ input_path: Path,
593
+ has_level: bool,
594
+ extract_surface: bool = False,
595
+ extract_seafloor: bool = False,
596
+ ) -> str:
597
+ """Generate output filename by modifying the input filename.
598
+
599
+ Args:
600
+ input_path (Path): Path to the input file
601
+ has_level (bool): Whether the source has a level/depth dimension
602
+ extract_surface (bool): Whether we extracted top level only
603
+ extract_seafloor (bool): Whether we extracted seafloor only
604
+
605
+ Returns:
606
+ str: Generated output filename (e.g. name_regridded.nc, name_top_level_regridded.nc, name_seafloor_regridded.nc)
607
+ """
608
+ name = input_path.name
609
+ if extract_seafloor and "_seafloor" not in name:
610
+ name = name.replace(input_path.suffix, "_seafloor" + input_path.suffix)
611
+ if extract_surface and "_top_level" not in name and has_level:
612
+ name = name.replace(input_path.suffix, "_top_level" + input_path.suffix)
613
+ return name.replace(input_path.suffix, "_regridded" + input_path.suffix)
614
+
615
def _get_representative_file(self, input_files: list[Path]) -> Optional[Path]:
    """Pick one file from ``input_files`` to use for resolution calculation.

    Delegates entirely to pick_representative_file.
    """
    # NOTE(review): pick_representative_file is not among the imports visible at the
    # top of this module — confirm it is defined or imported elsewhere in the file.
    representative = pick_representative_file(input_files)
    return representative
618
+
619
+ def _calculate_target_resolution(self, input_file: Path) -> tuple[float, float]:
620
+ """Calculate target resolution based on dataset's nominal_resolution attribute.
621
+
622
+ Args:
623
+ input_file (Path): Path to the input file to analyze
624
+
625
+ Returns:
626
+ tuple[float, float]: (lon_res, lat_res) in degrees, or (9999.0, 9999.0) if calculation fails
627
+ """
628
+ try:
629
+ with xa.open_dataset(input_file, decode_times=False) as ds:
630
+ # check if nominal_resolution attribute exists
631
+ if 'nominal_resolution' not in ds.attrs:
632
+ self.logger.warning(f"No 'nominal_resolution' attribute found in {input_file.name}")
633
+ return (9999.0, 9999.0)
634
+
635
+ # get native resolution
636
+ native_res = ds.attrs['nominal_resolution']
637
+ if self.verbose_diagnostics:
638
+ self.logger.info(f"Found nominal_resolution: {native_res}")
639
+
640
+ # calculate target resolution
641
+ target_res = calc_resolution(native_res)
642
+
643
+ if target_res == 9999.0:
644
+ self.logger.warning(f"Could not parse nominal_resolution '{native_res}' from {input_file.name}")
645
+ return (9999.0, 9999.0)
646
+
647
+ # return same resolution for both lon and lat (regular grid)
648
+ return (target_res, target_res)
649
+
650
+ except Exception as e:
651
+ self.logger.warning(f"Error calculating target resolution from {input_file.name}: {e}")
652
+ return (9999.0, 9999.0)
653
+
654
def _generate_target_grid_description(self, representative_file: Optional[Path] = None) -> str:
    """Generate CDO target grid description.

    The resolution is chosen by the first source that yields a usable value:
    ``representative_file``, then ``self._representative_file`` (if that
    attribute exists), then ``self.target_resolution``. A value of 9999.0
    from _calculate_target_resolution means "could not be determined".

    Args:
        representative_file (Path, optional): File to use for resolution calculation.
            If None, uses the pipeline's target_resolution.

    Returns:
        str: Grid description text for 'lonlat' or 'gaussian' target grids.

    Raises:
        ValueError: if self.target_grid is neither 'lonlat' nor 'gaussian'.
    """
    # NOTE(review): nesting below was reconstructed from a view that had lost its
    # indentation — confirm the if/else pairing against the original file.
    # try to calculate target resolution from a representative file
    if representative_file and representative_file.exists():
        lon_res, lat_res = self._calculate_target_resolution(representative_file)
        if lon_res == 9999.0 or lat_res == 9999.0:
            # explicit file failed: retry with the pipeline-level representative file, if it is a different one
            if (
                hasattr(self, "_representative_file")
                and self._representative_file
                and self._representative_file.exists()
                and self._representative_file != representative_file
            ):
                lon_res, lat_res = self._calculate_target_resolution(self._representative_file)
            # still undetermined: use the configured default
            if lon_res == 9999.0 or lat_res == 9999.0:
                lon_res, lat_res = self.target_resolution
                if self.verbose_diagnostics:
                    self.console.print(f"[yellow]Using pipeline target resolution: {lon_res}° x {lat_res}°[/yellow]")
        else:
            if self.verbose_diagnostics:
                self.console.print(f"[green]Calculated target resolution from {representative_file.name}: {lon_res:.3f}° x {lat_res:.3f}° (to 3 decimal places)[/green]")
    elif hasattr(self, '_representative_file') and self._representative_file and self._representative_file.exists():
        # use the pipeline's representative file
        lon_res, lat_res = self._calculate_target_resolution(self._representative_file)
        if lon_res == 9999.0 or lat_res == 9999.0:
            # fall back to pipeline's target resolution
            lon_res, lat_res = self.target_resolution
            if self.verbose_diagnostics:
                self.console.print(f"[yellow]Using pipeline target resolution: {lon_res}° x {lat_res}°[/yellow]")
        else:
            if self.verbose_diagnostics:
                self.console.print(f"[green]Calculated target resolution from {self._representative_file.name}: {lon_res}° x {lat_res}°[/green]")
    else:
        # use pipeline's target resolution
        lon_res, lat_res = self.target_resolution
        if self.verbose_diagnostics:
            self.console.print(f"[blue]Using pipeline target resolution: {lon_res}° x {lat_res}°[/blue]")

    if self.target_grid == "lonlat":
        # regular lon/lat grid; int() truncates, so a resolution that does not
        # divide 360/180 evenly yields a grid slightly short of full coverage
        xsize = int(360 / lon_res)
        ysize = int(180 / lat_res)
        # first cell centre sits half a cell in from the -180 / -90 edge
        xfirst = -180 + lon_res / 2 # TODO: check this logic
        yfirst = -90 + lat_res / 2 # TODO: the extent of the grid is not always -180 to 180 and -90 to 90: does this matter? Hard to say, since trying to get everything onto the same grid. Does CDO account for this?

        # NOTE(review): leading whitespace of the continuation lines inside this
        # string could not be recovered from the view — confirm against the original.
        return f"""gridtype = lonlat
xsize = {xsize}
ysize = {ysize}
xfirst = {xfirst}
xinc = {lon_res}
yfirst = {yfirst}
yinc = {lat_res}"""

    elif self.target_grid == "gaussian": # equally spaced along latitude, unequally along longitude
        # NOTE(review): Gaussian grids are usually described as regular in longitude
        # and irregular in latitude — confirm the comment above.
        # Gaussian grid
        ysize = int(180 / lat_res)
        return f"""gridtype = gaussian
ysize = {ysize}"""

    else:
        raise ValueError(f"Unsupported target grid type: {self.target_grid}")
719
+ def _should_chunk_file(self, file_info: dict) -> bool:
720
+ """Determine if a file should be processed in chunks based on file size, memory usage, and time steps.
721
+
722
+ Returns (bool): True if the file should be processed in chunks, False otherwise
723
+ """
724
+ if not self.enable_chunking:
725
+ return False
726
+
727
+ return (
728
+ file_info['file_size_gb'] > self.chunk_size_gb or
729
+ file_info['estimated_memory_gb'] > self.max_memory_gb or
730
+ file_info['time_steps'] > 100
731
+ )
732
+
733
+ def _notify_progress(
734
+ self,
735
+ phase: str,
736
+ *,
737
+ chunks_done: int = 0,
738
+ chunks_total: int = 0,
739
+ pct: Optional[int] = None,
740
+ ui: Optional["RegridProgressUI"] = None,
741
+ ui_input_path: Optional[Path] = None,
742
+ regrid_mode: str = "complete",
743
+ ) -> None:
744
+ """Report per-file progress to the local UI and/or shared parallel-worker state."""
745
+ ui_path = ui_input_path
746
+ if chunks_total > 0:
747
+ if ui is not None and ui_path is not None:
748
+ ui.update_chunk_progress(
749
+ ui_path, chunks_done, chunks_total, phase=phase, regrid_mode=regrid_mode
750
+ )
751
+ if self._progress_state is not None and self._progress_key:
752
+ self._progress_state[self._progress_key] = {
753
+ "phase": phase,
754
+ "chunks_done": chunks_done,
755
+ "chunks_total": chunks_total,
756
+ }
757
+ elif pct is not None:
758
+ if ui is not None and ui_path is not None:
759
+ ui.update_file_progress(ui_path, pct, phase, regrid_mode=regrid_mode)
760
+ if self._progress_state is not None and self._progress_key:
761
+ self._progress_state[self._progress_key] = {
762
+ "phase": phase,
763
+ "chunks_done": pct,
764
+ "chunks_total": 0,
765
+ }
766
+
767
def _chunk_file_by_time(
    self,
    file_path: Path,
    chunk_size: int = 10,
    name_suffix: Optional[str] = None,
    ui: Optional["RegridProgressUI"] = None,
    ui_input_path: Optional[Path] = None,
    regrid_mode: str = "complete",
) -> list[Path]:
    """Split a file into time chunks; chunk files are written to a writable directory.

    When name_suffix is set (e.g. '_top_level', '_seafloor'), chunk stems use it so
    names stay consistent with extract_surface/extract_seafloor even when preparation
    returned the original file (no level dim or preparation failed).

    Chunking is all-or-nothing: if any chunk cannot be written (or is written but
    turns out missing/empty), every chunk created so far is removed and the original
    file is returned, so callers never receive a chunk list with a gap in time.

    Args:
    - file_path (Path): File to split along its 'time' dimension
    - chunk_size (int): Maximum number of time steps per chunk
    - name_suffix (Optional[str]): Suffix appended after the chunk index in chunk stems
    - ui / ui_input_path: Optional progress UI and the path it is keyed on
      (defaults to file_path)
    - regrid_mode (str): Passed through to progress reporting

    Returns (list[Path]): List of paths to the chunked files, or [file_path] when the
    file has no 'time' dimension, has at most chunk_size time steps, or chunking failed
    """
    ds = None
    chunk_files = []
    try:
        ds = xa.open_dataset(file_path, decode_times=False)

        # Check if file has time dimension and enough time steps to chunk
        if 'time' not in ds.dims:
            self.logger.warning(f"File {file_path.name} has no 'time' dimension, skipping chunking")
            return [file_path]

        time_length = len(ds.time)
        if time_length <= chunk_size:
            if self.verbose:
                self.logger.debug(f"File {file_path.name} has {time_length} time steps (<= {chunk_size}), skipping chunking")
            return [file_path]

        time_chunks = list(range(0, time_length, chunk_size))
        total_chunks = len(time_chunks)
        _ui_path = ui_input_path if ui_input_path is not None else file_path
        self._notify_progress(
            "creating",
            chunks_done=0,
            chunks_total=total_chunks,
            ui=ui,
            ui_input_path=_ui_path,
            regrid_mode=regrid_mode,
        )

        # Chunk stem: use name_suffix when in surface/seafloor mode so names are e.g.
        # *_chunk_000_top_level even if prepared_path was the original file.
        if name_suffix:
            base_stem = (
                file_path.stem.replace("_top_level", "").replace("_seafloor", "").rstrip("_")
                or file_path.stem
            )
            chunk_stem_template = f"{base_stem}_chunk_{{i:03d}}{name_suffix}"
        elif "_top_level" in file_path.stem:
            base_stem = file_path.stem.replace("_top_level", "")
            chunk_stem_template = f"{base_stem}_chunk_{{i:03d}}_top_level"
        elif "_seafloor" in file_path.stem:
            base_stem = file_path.stem.replace("_seafloor", "")
            chunk_stem_template = f"{base_stem}_chunk_{{i:03d}}_seafloor"
        else:
            chunk_stem_template = f"{file_path.stem}_chunk_{{i:03d}}"

        # Encoding for chunk writes: preserve source encoding (netCDF4-compatible keys only) and avoid _FillValue on coords/bounds.
        # Omit chunksizes so chunk time dimension (smaller than source) does not trigger "chunksize cannot exceed dimension size".
        chunk_encoding = {}
        for v in ds.variables:
            raw = getattr(ds[v], "encoding", None) or {}
            enc = {k: raw[k] for k in (raw if isinstance(raw, dict) else {}) if k in NC4_ENCODING_KEYS and k != "chunksizes"}
            chunk_encoding[v] = enc
            if v in ds.coords or v in ("lat_bnds", "lon_bnds", "lat", "lon", "time", "time_bnds"):
                chunk_encoding[v]["_FillValue"] = None

        # Write chunks to a writable directory (input dir may be read-only, e.g. shared data)
        out_dir = file_path.parent if os.access(file_path.parent, os.W_OK) else Path(tempfile.gettempdir())
        if out_dir != file_path.parent:
            # Unique suffix so concurrent runs don't collide when using system temp
            unique_suffix = f"_{os.getpid()}_{abs(hash(file_path)) % 1000000:06d}"
        else:
            unique_suffix = ""

        # Create chunks
        for i, start_idx in enumerate(time_chunks):
            end_idx = min(start_idx + chunk_size, time_length)
            ds_chunk = ds.isel(time=slice(start_idx, end_idx))

            # An empty slice carries no time steps, so skipping it loses no data
            if ds_chunk.sizes.get('time', 0) == 0:
                self.logger.warning(f"Skipping empty chunk {i} for {file_path.name}")
                continue

            chunk_path = out_dir / f"{chunk_stem_template.format(i=i)}{unique_suffix}{file_path.suffix}"

            # Write chunk file (same encoding as source so CDO accepts seafloor/top_level chunks)
            try:
                ds_chunk.to_netcdf(chunk_path, encoding=chunk_encoding)

                # A missing/empty chunk would leave a hole in the time axis of the
                # combined output, so treat it as a failure instead of skipping it.
                if not chunk_path.exists() or chunk_path.stat().st_size == 0:
                    raise RuntimeError(f"Chunk file {chunk_path.name} is empty or was not created")

                # Track the created chunk file for cleanup
                self._track_created_file(chunk_path)
                chunk_files.append(chunk_path)
                self._notify_progress(
                    "creating",
                    chunks_done=len(chunk_files),
                    chunks_total=total_chunks,
                    ui=ui,
                    ui_input_path=_ui_path,
                    regrid_mode=regrid_mode,
                )

            except Exception as chunk_error:
                self._log_error(
                    f"Failed to write chunk {i} for {file_path.name}: {chunk_error}",
                    exc=chunk_error,
                )
                # Clean up failed chunk file if it was partially created
                if chunk_path.exists():
                    try:
                        chunk_path.unlink()
                    except OSError as unlink_error:
                        self.logger.warning(
                            f"Could not remove partial chunk file {chunk_path.name}: {unlink_error}"
                        )
                raise  # Re-raise to trigger cleanup

        # If no chunks were created successfully, return original file
        if not chunk_files:
            self.logger.warning(f"No chunks created for {file_path.name}, returning original file")
            return [file_path]

        return chunk_files

    except Exception as e:
        self._log_error(f"Failed to chunk file {file_path}: {e}", exc=e)

        # Clean up any chunk files that were created before the error
        for chunk_file in chunk_files:
            try:
                if chunk_file.exists() and chunk_file != file_path:
                    self.logger.warning(f"Cleaning up incomplete chunk file: {chunk_file.name}")
                    chunk_file.unlink()
            except Exception as cleanup_error:
                self._log_error(
                    f"Failed to clean up chunk file {chunk_file.name}: {cleanup_error}",
                    exc=cleanup_error,
                )

        return [file_path]
    finally:
        # Always close the dataset if it was opened
        if ds is not None:
            try:
                ds.close()
            except Exception as close_error:
                # Best effort: a failed close must not change the returned result
                self.logger.debug(f"Could not close dataset for {file_path.name}: {close_error}")
925
+
926
def _extract_seafloor_values(
    self,
    file_path: Path,
    ui: Optional["RegridProgressUI"] = None,
    ui_input_path: Optional[Path] = None,
) -> Path:
    """
    Extract seafloor values (deepest non-NaN values along depth dimension) from a file.
    Uses cached depth indices for files in the same directory for optimization.
    For large files, uses chunked reading and reports progress when ui is provided.

    The depth indices are computed from the target variable (first time step only)
    and, when use_seafloor_cache is set, cached per directory under the key
    "<level_dim>_<level_size>". The source dataset is closed on every exit path.

    Args:
        file_path (Path): Path to the input file
        ui: Optional progress UI (updates with phases: depth indices, extracting vars, writing)
        ui_input_path: Path to show in UI (usually same as file_path)

    Returns (Path): Path to the seafloor-extracted file, or file_path itself when the
        file has no usable level dimension

    Raises:
        RuntimeError: if extraction fails for any reason (original error is chained)
    """
    _ui_path = ui_input_path if ui_input_path is not None else file_path

    def _progress(pct: int, msg: str) -> None:
        if ui is not None:
            ui.update_file_progress(_ui_path, pct, msg, regrid_mode="seafloor")
        self.logger.info(f"Seafloor extraction: {msg}")

    ds = None
    try:
        # Write to a writable dir (temp if input dir is read-only)
        out_dir = file_path.parent if os.access(file_path.parent, os.W_OK) else Path(tempfile.gettempdir())
        seafloor_path = out_dir / f"{file_path.stem}_seafloor{file_path.suffix}"
        if seafloor_path.exists():
            # Validate existing file: re-extract if empty or no data vars (e.g. corrupt or from old code)
            try:
                with xa.open_dataset(seafloor_path, decode_times=False) as existing:
                    if existing.data_vars and all(existing.sizes.get(d, 0) > 0 for d in next(iter(existing.data_vars.values())).dims):
                        if self.verbose:
                            self.console.print(f"[cyan]Using existing seafloor file: {seafloor_path.name}[/cyan]")
                        return seafloor_path
            except Exception as validation_error:
                # Unreadable file is handled the same as an invalid one: re-extract
                self.logger.debug(
                    f"Existing seafloor file {seafloor_path.name} could not be validated: {validation_error}"
                )
            seafloor_path.unlink(missing_ok=True)
            if self.verbose:
                self.console.print(f"[yellow]Re-extracting seafloor (existing file invalid or empty)[/yellow]")

        target_variable = self._get_target_variable(file_path)
        # Chunked read for large files: avoids loading the full 3D/4D array at once
        ds = xa.open_dataset(file_path, decode_times=False, chunks="auto")
        has_level, level_dims, level_count = self._whether_multi_level(dict(ds.sizes))

        if not has_level or level_count == 0:
            # No depth dimension, return original file
            return file_path

        # Find the level dimension that exists in the dataset
        level_dim = next((dim for dim in level_dims if dim in ds.dims), None)
        if level_dim is None:
            return file_path

        # Get or compute depth indices for this directory (use cache only if use_seafloor_cache)
        dir_key = str(file_path.parent)
        cache_key = f"{level_dim}_{ds.sizes[level_dim]}"
        use_cache = self.use_seafloor_cache
        if dir_key not in self._seafloor_depth_cache:
            self._seafloor_depth_cache[dir_key] = {}
        cache_hit = use_cache and (cache_key in self._seafloor_depth_cache[dir_key])

        if not cache_hit:
            # Compute seafloor depth indices: find deepest non-NaN value for each spatial location
            _progress(11, "Seafloor: computing depth indices (may take several minutes for large files)")
            if self.verbose:
                self.console.print(f"[blue]Computing seafloor depth indices for {file_path.name}...[/blue]")

            # Bail out when nothing in the file carries the level dimension
            # (e.g. only mesh/coord vars with ncells, vertices)
            data_vars_with_level = [
                v for v in ds.data_vars
                if level_dim in ds[v].dims
            ]
            if not data_vars_with_level:
                self.logger.warning(
                    f"No data variables with level dimension '{level_dim}' in {file_path.name}"
                )
                return file_path
            var_data = ds[target_variable]
            # Take first time step only if the variable has a time dimension
            if "time" in var_data.dims:
                var_data = var_data.isel(time=0)

            # Find deepest non-NaN index along level dimension
            level_size = ds.sizes[level_dim]

            # Get spatial dimensions (all dims except level, time, bnds)
            spatial_dims = [d for d in var_data.dims if d not in [level_dim, 'time', 'bnds']]  # with bnds was returning with a bnds coordinate. May be other rogue variables out there

            if spatial_dims:
                # Mask of valid (non-NaN) values
                valid_mask = ~var_data.isnull()

                # Which locations have any valid value at all
                has_valid = valid_mask.any(dim=level_dim)

                # Reverse along the level dimension; argmax then returns the index of the
                # first True in reversed order, i.e. the deepest valid level
                reversed_valid = valid_mask.isel({level_dim: slice(None, None, -1)})
                reversed_argmax = reversed_valid.argmax(dim=level_dim)

                # Convert back to original indexing (deepest = level_size - 1 - reversed_index)
                deepest_idx = level_size - 1 - reversed_argmax

                # argmax also returns 0 when every value is False, so all-NaN columns are
                # explicitly pinned to the last index
                all_nan_mask = ~has_valid
                deepest_idx = deepest_idx.where(~all_nan_mask, level_size - 1)

                # Convert to numpy array for indexing (triggers chunked compute if dask)
                depth_indices_array = deepest_idx.values
            else:
                # No spatial dimensions (unlikely but handle it)
                # Just find deepest non-NaN across all data
                valid_mask = ~var_data.isnull()
                if valid_mask.any():
                    # Find deepest valid index
                    reversed_valid = valid_mask.isel({level_dim: slice(None, None, -1)})
                    reversed_argmax = reversed_valid.argmax(dim=level_dim)
                    deepest_idx = level_size - 1 - reversed_argmax
                    depth_indices_array = int(deepest_idx.values) if hasattr(deepest_idx, 'values') else int(deepest_idx)
                else:
                    # All NaN, use last index
                    depth_indices_array = level_size - 1

            # Cache the depth indices
            if use_cache:
                self._seafloor_depth_cache[dir_key][cache_key] = depth_indices_array
            _progress(12, "Seafloor: depth indices done" + (" (cached)" if use_cache else ""))
            if self.verbose:
                self.console.print(f"[green]Computed seafloor depth indices[/green]" + (" (cached)" if use_cache else ""))
        else:
            # Use cached depth indices
            depth_indices_array = self._seafloor_depth_cache[dir_key][cache_key]
            _progress(12, "Seafloor: using cached depth indices")
            if self.verbose:
                self.console.print(f"[cyan]Using cached seafloor depth indices[/cyan]")
            # spatial_dims is needed when depth_indices_array is an array. Derive it from the
            # target variable with the same exclusions as the compute path above, because
            # that is the variable the cached array was computed from; another variable
            # (e.g. a level-bounds variable) can have different dims and would mislabel it.
            if not (isinstance(depth_indices_array, (int, np.integer)) or np.isscalar(depth_indices_array)):
                spatial_dims = [d for d in ds[target_variable].dims if d not in [level_dim, 'time', 'bnds']]

        # Extract seafloor values: same structural approach as top_level (isel one level, drop level dim).
        # Build seafloor dataset from the source dataset so the written file has the same layout CDO
        # expects (identical to top_level: data at one level, written with plain to_netcdf).
        if self.verbose:
            self.console.print(f"[blue]Extracting seafloor values from {file_path.name}...[/blue]")

        _progress(15, "Seafloor: writing file")

        # Copy dataset and apply seafloor selection to each variable that has level_dim
        seafloor_ds = ds.copy(deep=False)
        if isinstance(depth_indices_array, (int, np.integer)) or np.isscalar(depth_indices_array):
            seafloor_ds = seafloor_ds.isel({level_dim: int(depth_indices_array)})
        else:
            depth_idx_da = xa.DataArray(
                depth_indices_array,
                dims=spatial_dims,
                coords={d: seafloor_ds[target_variable].coords[d] for d in spatial_dims if d in seafloor_ds[target_variable].coords},
            )
            for var_name in list(seafloor_ds.data_vars):
                if level_dim in seafloor_ds[var_name].dims:
                    seafloor_ds[var_name] = seafloor_ds[var_name].isel({level_dim: depth_idx_da})
        seafloor_ds = seafloor_ds.drop_vars(level_dim, errors="ignore")
        seafloor_ds = seafloor_ds.load()

        # Write to a uniquely named temp file first, then rename into place
        write_dir = seafloor_path.parent if os.access(seafloor_path.parent, os.W_OK) else Path(tempfile.gettempdir())
        write_path = write_dir / (seafloor_path.name + f".tmp.{os.getpid()}.{uuid.uuid4().hex[:8]}")
        try:
            seafloor_ds.to_netcdf(write_path)
            write_path.rename(seafloor_path)
        finally:
            # After a successful rename the temp name no longer exists
            if write_path.exists():
                write_path.unlink(missing_ok=True)
            seafloor_ds.close()

        # Track the created file for cleanup
        self._track_created_file(seafloor_path)

        if self.verbose:
            self.console.print(f"[green]Extracted seafloor values: {seafloor_path.name}[/green]")

        return seafloor_path

    except Exception as e:
        self._log_error(f"Failed to extract seafloor values from {file_path}: {e}", exc=e)
        # Raise rather than return the original path: returning it would be misleading
        raise RuntimeError(f"Seafloor extraction failed for {file_path.name}: {e}") from e
    finally:
        # Release the source file handle on every path (early returns and errors included)
        if ds is not None:
            try:
                ds.close()
            except Exception as close_error:
                self.logger.debug(f"Could not close dataset for {file_path.name}: {close_error}")
1136
+
1137
def _prepare_file_for_regridding(
    self,
    file_path: Path,
    ui: Optional["RegridProgressUI"] = None,
    ui_input_path: Optional[Path] = None,
) -> Path:
    """
    Prepare a file for regridding according to pipeline mode.

    - If extract_seafloor: extract seafloor values (use seafloor cache if use_seafloor_cache), return path to seafloor file.
    - If extract_surface: extract top level only, return path to top-level file.
    - If neither: return file_path (regrid whole file).

    Files whose stem already contains "_seafloor" / "_top_level" are returned as-is
    in the respective mode. Seafloor extraction errors are logged and re-raised;
    top-level extraction errors are logged and the original file_path is returned.

    ui / ui_input_path: optional progress UI; if set, seafloor extraction will update progress (e.g. "Seafloor: computing depth indices").
    Returns (Path): Path to prepared file (possibly a temporary/prepared NetCDF).
    """
    if self.extract_seafloor:
        if "_seafloor" in file_path.stem:
            return file_path
        try:
            prepared_path = self._extract_seafloor_values(
                file_path, ui=ui, ui_input_path=ui_input_path
            )
            return prepared_path
        except Exception as e:
            self._log_error(f"Seafloor extraction failed, cannot proceed with regridding: {e}", exc=e)
            raise

    if not self.extract_surface:
        return file_path

    try:
        if "_top_level" in file_path.stem:  # TODO: is top level always correct for the sea surface value?
            return file_path
        # Context manager so the source handle is released on every path
        # (early returns and write errors included).
        with xa.open_dataset(file_path, decode_times=False) as ds:
            has_level, level_dims, level_count = self._whether_multi_level(dict(ds.sizes))
            if not isinstance(level_dims, list):
                self._log_error(f"level_dims is not a list: {type(level_dims)} = {level_dims}")
                return file_path
            if not has_level or level_count == 0:
                return file_path
            # Select index 0 along the first level dimension present in the file
            top_level_ds = ds
            for dim in level_dims:
                if dim in ds.dims:
                    top_level_ds = ds.isel({dim: 0})
                    break
            # Write to a writable dir (temp if input dir is read-only). Use a unique
            # path per process so concurrent workers never write the same file (avoids
            # PermissionError and races when the same source is processed in parallel).
            prep_dir = file_path.parent if os.access(file_path.parent, os.W_OK) else Path(tempfile.gettempdir())
            unique_suffix = f"_{os.getpid()}_{uuid.uuid4().hex[:8]}"
            prepared_path = prep_dir / f"{file_path.stem}_top_level{unique_suffix}{file_path.suffix}"
            top_level_ds.to_netcdf(prepared_path)
        self._track_created_file(prepared_path)
        if self.verbose_diagnostics:
            self.console.print(f"[cyan]Prepared file (top level): {prepared_path.name}[/cyan]")
        self.logger.info(f"Prepared file (top level): {prepared_path.name}")
        return prepared_path
    except Exception as e:
        self.logger.warning(f"Could not prepare file {file_path}: {e}")
        self._log_error(f"Could not prepare file {file_path}: {e}", exc=e)
        return file_path
1198
+
1199
def _ensure_top_level_file_for_chunking(self, file_path: Path) -> Path:
    """Return a top-level-only file to chunk when running in extract_surface mode.

    If preparation was skipped (prepared_path == input_path) but the file still has
    multiple levels, chunk files would be named *_top_level while containing every
    level. In that case the first level is written to a uniquely named file (next to
    the input, or in the system temp dir when that directory is not writable) and
    that path is returned. Files with at most one level are returned unchanged, as
    is the original path when anything goes wrong (a warning is logged).
    """
    try:
        with xa.open_dataset(file_path, decode_times=False) as source_ds:
            has_level, level_dims, level_count = self._whether_multi_level(dict(source_ds.sizes))
            if not (has_level and level_count > 1):
                return file_path

            # Index 0 along the first level dimension that the file actually has
            matching_dim = next((dim for dim in level_dims if dim in source_ds.dims), None)
            top_ds = source_ds if matching_dim is None else source_ds.isel({matching_dim: 0})

            if os.access(file_path.parent, os.W_OK):
                prep_dir = file_path.parent
            else:
                prep_dir = Path(tempfile.gettempdir())
            unique_suffix = f"_{os.getpid()}_{uuid.uuid4().hex[:8]}"
            top_name = f"{file_path.stem}_top_level{unique_suffix}{file_path.suffix}"
            top_path = prep_dir / top_name

            top_ds.to_netcdf(top_path)
            self._track_created_file(top_path)
            self.logger.info(
                f"Created top-level temp file for chunking (was multi-level): {top_path.name}"
            )
            return top_path
    except Exception as e:
        self.logger.warning(f"Could not create top-level file for chunking: {e}")
        return file_path
1227
+
1228
+ def _is_valid_prepared_file(self, prepared_path: Path) -> bool:
1229
+ """Return False if prepared file is missing, empty, or has no data variables (CDO would fail)."""
1230
+ if not prepared_path.exists():
1231
+ self.logger.warning(f"Prepared file does not exist: {prepared_path}")
1232
+ return False
1233
+ if prepared_path.stat().st_size == 0:
1234
+ self.logger.warning(
1235
+ f"Prepared file is empty (0 bytes), skipping regridding: {prepared_path.name}"
1236
+ )
1237
+ return False
1238
+ try:
1239
+ with xa.open_dataset(prepared_path, decode_times=False) as ds:
1240
+ if not ds.data_vars:
1241
+ self.logger.warning(
1242
+ f"Prepared file has no data variables, skipping regridding: {prepared_path.name}"
1243
+ )
1244
+ return False
1245
+ except Exception as e:
1246
+ self.logger.warning(
1247
+ f"Prepared file could not be read or has unsupported structure: {prepared_path.name} ({e})"
1248
+ )
1249
+ return False
1250
+ return True
1251
+
1252
+ def _regrid_chunked_file(
1253
+ self,
1254
+ input_path: Path,
1255
+ output_path: Path,
1256
+ grid_signature: str,
1257
+ grid_type: str,
1258
+ chunk_files: list[Path],
1259
+ weight_cache_dir: Optional[Path] = None,
1260
+ ui: Optional["RegridProgressUI"] = None,
1261
+ ui_input_path: Optional[Path] = None,
1262
+ regrid_mode: str = "complete",
1263
+ ) -> bool:
1264
+ """Regrid a file that has been split into chunks, with optional parallel processing.
1265
+
1266
+ Args:
1267
+ - input_path (Path): Path to the input file
1268
+ - output_path (Path): Path to the output file
1269
+ - grid_signature (str): Unique signature for the grid
1270
+ - grid_type (str): Type of grid
1271
+ - chunk_files (list[Path]): List of paths to the chunked files
1272
+
1273
+ Returns (bool): True if successful, False otherwise
1274
+ """
1275
+ try:
1276
+ if weight_cache_dir is not None:
1277
+ self.weight_cache_dir = Path(weight_cache_dir)
1278
+ self.weight_cache_dir.mkdir(parents=True, exist_ok=True)
1279
+
1280
+ weight_path = self._get_weight_path(grid_signature)
1281
+ grid_desc = self._generate_target_grid_description(input_path)
1282
+
1283
+ with tempfile.TemporaryDirectory() as tmpdir:
1284
+ tmpdir = Path(tmpdir)
1285
+ grid_file = tmpdir / "target_grid.txt"
1286
+
1287
+ with open(grid_file, 'w') as f:
1288
+ f.write(grid_desc)
1289
+
1290
+ weights_exist = (
1291
+ self.use_regrid_cache and self._is_valid_weight_file(weight_path)
1292
+ )
1293
+
1294
+ if not weights_exist and len(chunk_files) > 0:
1295
+ if self.verbose_diagnostics:
1296
+ self.console.print("[blue]Generating weights from first chunk...[/blue]")
1297
+ first_chunk_output = tmpdir / "chunk_000.nc"
1298
+ if not self._regrid_without_weights(
1299
+ chunk_files[0], first_chunk_output, grid_file, grid_type
1300
+ ):
1301
+ self._log_error(
1302
+ f"Failed to regrid first chunk for weights: {chunk_files[0]}"
1303
+ )
1304
+ return False
1305
+ if not self._save_weights(chunk_files[0], weight_path, grid_file):
1306
+ self._log_error(f"Failed to save weights to {weight_path}")
1307
+ return False
1308
+ weights_exist = self._is_valid_weight_file(weight_path)
1309
+ if not weights_exist:
1310
+ self._log_error(f"Weight file invalid after save: {weight_path}")
1311
+ return False
1312
+ if self.verbose_diagnostics:
1313
+ self.console.print("[green]Weights generated and saved[/green]")
1314
+
1315
+ chunk_outputs = []
1316
+ total_chunks = len(chunk_files)
1317
+ _ui_path = ui_input_path if ui_input_path is not None else input_path
1318
+ self._notify_progress(
1319
+ "regridding",
1320
+ chunks_done=0,
1321
+ chunks_total=total_chunks,
1322
+ ui=ui,
1323
+ ui_input_path=_ui_path,
1324
+ regrid_mode=regrid_mode,
1325
+ )
1326
+
1327
+ for i, chunk_file in enumerate(chunk_files):
1328
+ chunk_output = tmpdir / f"chunk_{i:03d}.nc"
1329
+
1330
+ if weights_exist:
1331
+ if not self._regrid_with_weights(
1332
+ chunk_file, chunk_output, grid_file, weight_path, grid_type
1333
+ ):
1334
+ self._log_error(
1335
+ f"Failed to regrid chunk {i} with weights: {chunk_file}"
1336
+ )
1337
+ return False
1338
+ elif not self._regrid_without_weights(
1339
+ chunk_file, chunk_output, grid_file, grid_type
1340
+ ):
1341
+ self._log_error(f"Failed to regrid chunk {i} without weights: {chunk_file}")
1342
+ return False
1343
+
1344
+ chunk_outputs.append(chunk_output)
1345
+ self.stats['chunks_processed'] += 1
1346
+ self._notify_progress(
1347
+ "regridding",
1348
+ chunks_done=i + 1,
1349
+ chunks_total=total_chunks,
1350
+ ui=ui,
1351
+ ui_input_path=_ui_path,
1352
+ regrid_mode=regrid_mode,
1353
+ )
1354
+
1355
+ # combine chunks into single file
1356
+ if len(chunk_outputs) > 1:
1357
+ self._notify_progress(
1358
+ "combining",
1359
+ chunks_done=total_chunks,
1360
+ chunks_total=total_chunks,
1361
+ ui=ui,
1362
+ ui_input_path=_ui_path,
1363
+ regrid_mode=regrid_mode,
1364
+ )
1365
+ if self.verbose_diagnostics:
1366
+ self.console.print(f"[blue]Combining {len(chunk_outputs)} chunks...[/blue]")
1367
+ self._combine_chunks(chunk_outputs, output_path)
1368
+ else:
1369
+ shutil.copy2(chunk_outputs[0], output_path)
1370
+
1371
+ # clean up chunk files
1372
+ for chunk_file in chunk_files:
1373
+ if chunk_file != input_path and chunk_file.exists():
1374
+ try:
1375
+ chunk_file.unlink()
1376
+ self.console.print(f"[green]Cleaned up chunk file: {chunk_file.name}[/green]") if self.verbose_diagnostics else None
1377
+ except Exception as cleanup_error:
1378
+ self.console.print(f"[red]Failed to clean up chunk file {chunk_file.name}: {cleanup_error}[/red]") if self.verbose_diagnostics else None
1379
+
1380
+ return True
1381
+
1382
+ except Exception as e:
1383
+ self._log_error(f"Failed to regrid chunked file {input_path}: {e}", exc=e)
1384
+
1385
+ # Clean up chunk files even if regridding failed
1386
+ for chunk_file in chunk_files:
1387
+ if chunk_file != input_path and chunk_file.exists():
1388
+ try:
1389
+ self.logger.warning(f"Cleaning up chunk file after error: {chunk_file.name}")
1390
+ chunk_file.unlink()
1391
+ except Exception as cleanup_error:
1392
+ self._log_error(
1393
+ f"Failed to clean up chunk file {chunk_file.name}: {cleanup_error}",
1394
+ exc=cleanup_error,
1395
+ )
1396
+
1397
+ return False
1398
+
1399
def _combine_chunks(self, chunk_outputs: list[Path], output_path: Path) -> bool:
    """Combine regridded chunks into a single file by concatenating along 'time'.

    Args:
    - chunk_outputs (list[Path]): List of paths to the regridded chunks, in time order
    - output_path (Path): Path to the output file

    Returns (bool): True if the combined file was written, False otherwise
    (the error is logged, not raised)
    """
    datasets = []
    try:
        # load all chunks # TODO: xa.open_mfdataset complains about dask not being installed
        for chunk in chunk_outputs:
            datasets.append(xa.open_dataset(chunk))
        combined = xa.concat(datasets, dim='time')  # combine along time dimension
        combined.to_netcdf(output_path)
        return True

    except Exception as e:
        self._log_error(f"Failed to combine chunks: {e}", exc=e)
        return False

    finally:
        # Close every dataset that was opened, even when concat/write failed
        for ds in datasets:
            try:
                ds.close()
            except Exception as close_error:
                self.logger.warning(f"Could not close chunk dataset: {close_error}")
1420
+
1421
def _regrid_with_weights(
    self,
    input_path: Path,
    output_path: Path,
    grid_file: Path,
    weight_path: Path,
    grid_type: str = "structured",
    _allow_regenerate: bool = True,
) -> bool:
    """Regrid using existing weights; regenerate once if cache is missing or corrupt.

    Regeneration (regrid without weights, then save fresh weights) happens in two
    situations, both only when _allow_regenerate is True:

    - the weight file fails validation before the remap is attempted;
    - the remap itself fails with an error that indicates an unreadable weight
      file ("hdf error", "nc_open failed", "no such file").

    Args:
    - input_path (Path): Path to the input file
    - output_path (Path): Path to the output file
    - grid_file (Path): Path to the target grid description file
    - weight_path (Path): Path to the cached weight file
    - grid_type (str): Type of grid, forwarded to the weightless regrid
    - _allow_regenerate (bool): Internal switch; False disables regeneration

    Returns (bool): True if the output was produced, False otherwise
    """

    def _regenerate() -> bool:
        # Produce the output without weights, then refresh the cache for later files.
        if not self._regrid_without_weights(
            input_path, output_path, grid_file, grid_type
        ):
            return False
        self._save_weights(input_path, weight_path, grid_file)
        return output_path.exists() and output_path.stat().st_size > 0

    try:
        # NOTE(review): _save_weights also takes weight_file_lock(weight_path), so the
        # regeneration below assumes that lock can be re-acquired by its holder - confirm.
        with weight_file_lock(weight_path):
            if not self._is_valid_weight_file(weight_path):
                if not _allow_regenerate:
                    return False
                self._invalidate_weight_file(weight_path)
                return _regenerate()

            self.cdo.remap(
                str(grid_file),
                str(weight_path),
                input=str(input_path),
                output=str(output_path),
            )
            return True
    except Exception as e:
        err = str(e).lower()
        corrupt = (
            "hdf error" in err
            or "nc_open failed" in err
            or "no such file" in err
        )
        if corrupt:
            self._invalidate_weight_file(weight_path)
        self._log_error(f"Failed to regrid {input_path} with weights: {e}", exc=e)
        if not (corrupt and _allow_regenerate):
            return False

    # The cache was unreadable at remap time and has just been invalidated. Retrying
    # the remap with regeneration disabled could only fail again, so rebuild the
    # output and the weights here (outside the lock, which _save_weights takes itself).
    try:
        return _regenerate()
    except Exception as retry_error:
        self._log_error(
            f"Failed to regenerate weights for {input_path}: {retry_error}",
            exc=retry_error,
        )
        return False
1471
+
1472
+ def _regrid_without_weights(self, input_path: Path, output_path: Path, grid_file: Path, grid_type: str):
1473
+ """Regrid without existing weights; conservative if corners exist, else bilinear."""
1474
+ file_info = self._get_file_info(input_path)
1475
+ has_corners = file_info.get(
1476
+ "has_grid_corners", grid_type == "structured"
1477
+ )
1478
+ remap_op, _ = self._regrid_operators(grid_type, has_corners)
1479
+ try:
1480
+ getattr(self.cdo, remap_op)(
1481
+ str(grid_file),
1482
+ input=str(input_path),
1483
+ output=str(output_path),
1484
+ )
1485
+ return True
1486
+ except Exception as e:
1487
+ err = str(e).lower()
1488
+ if remap_op == "remapcon" and "corner coordinates missing" in err:
1489
+ try:
1490
+ self.cdo.remapbil(
1491
+ str(grid_file),
1492
+ input=str(input_path),
1493
+ output=str(output_path),
1494
+ )
1495
+ return True
1496
+ except Exception as fallback_error:
1497
+ self._log_error(
1498
+ f"Failed to regrid {input_path} of type {grid_type} "
1499
+ f"without weights (remapbil fallback): {fallback_error}",
1500
+ exc=fallback_error,
1501
+ )
1502
+ return False
1503
+ self._log_error(
1504
+ f"Failed to regrid {input_path} of type {grid_type} without weights: {e}",
1505
+ exc=e,
1506
+ )
1507
+ return False
1508
+
1509
+ def _regrid_single_file(
1510
+ self,
1511
+ input_path: Path,
1512
+ output_path: Path,
1513
+ grid_signature: str,
1514
+ grid_type: str,
1515
+ force_regenerate_weights: bool = False,
1516
+ ) -> bool:
1517
+ """
1518
+ Regrid a single file using CDO with grid-type-specific handling.
1519
+
1520
+ Args:
1521
+ - input_path (Path): Path to the input file
1522
+ - output_path (Path): Path to the output file
1523
+ - grid_signature (str): Unique signature for the grid
1524
+ - grid_type (str): Type of grid
1525
+ - force_regenerate_weights (bool): Whether to force regeneration of weights
1526
+
1527
+ Returns (bool): True if successful, False otherwise
1528
+ """
1529
+ try:
1530
+ weight_path = self._get_weight_path(grid_signature)
1531
+ grid_desc = self._generate_target_grid_description(input_path)
1532
+
1533
+ with tempfile.TemporaryDirectory() as tmpdir:
1534
+ tmpdir = Path(tmpdir)
1535
+ grid_file = tmpdir / "target_grid.txt"
1536
+ with open(grid_file, 'w') as f:
1537
+ f.write(grid_desc)
1538
+
1539
+ used_cached_weights = False
1540
+ if self.use_regrid_cache and not force_regenerate_weights:
1541
+ if self._is_valid_weight_file(weight_path):
1542
+ used_cached_weights = True
1543
+ if self.verbose_diagnostics:
1544
+ self.console.print(f"[green]Reusing weights: {weight_path.name}[/green]")
1545
+ self.stats['weights_reused'] += 1
1546
+ if not self._regrid_with_weights(
1547
+ input_path, output_path, grid_file, weight_path, grid_type
1548
+ ):
1549
+ used_cached_weights = False
1550
+ else:
1551
+ self._invalidate_weight_file(weight_path)
1552
+
1553
+ if not used_cached_weights:
1554
+ if self.verbose_diagnostics:
1555
+ self.console.print(f"[yellow]Generating new weights: {weight_path.name}[/yellow]")
1556
+ if not self._regrid_without_weights(input_path, output_path, grid_file, grid_type):
1557
+ return False
1558
+ self._save_weights(input_path, weight_path, grid_file)
1559
+ self.stats['weights_generated'] += 1
1560
+
1561
+ if output_path.exists() and output_path.stat().st_size > 0:
1562
+ return True
1563
+ self._log_error(f"Output file is empty or missing: {output_path}")
1564
+ return False
1565
+
1566
+ except Exception as e:
1567
+ self._log_error(f"Failed to regrid {input_path}: {e}", exc=e)
1568
+ return False
1569
+
1570
+ def _save_weights(self, input_path: Path, weight_path: Path, grid_file: Path) -> None:
1571
+ """Save regrid weights for future reuse with grid-type-specific method.
1572
+
1573
+ Args:
1574
+ - input_path (Path): Path to the input file
1575
+ - weight_path (Path): Path to the weight file
1576
+ - grid_file (Path): Path to the grid file
1577
+
1578
+ Returns (None): None
1579
+ """
1580
+ try:
1581
+ file_info = self._get_file_info(input_path)
1582
+ _, gen_op = self._regrid_operators(
1583
+ file_info.get("grid_type", "unknown"),
1584
+ file_info.get("has_grid_corners", False),
1585
+ )
1586
+ weight_path.parent.mkdir(parents=True, exist_ok=True)
1587
+ with weight_file_lock(weight_path):
1588
+ with tempfile.TemporaryDirectory() as tmpdir:
1589
+ tmpdir = Path(tmpdir)
1590
+ temp_weights = tmpdir / "temp_weights.nc"
1591
+ try:
1592
+ getattr(self.cdo, gen_op)(
1593
+ str(grid_file),
1594
+ input=str(input_path),
1595
+ output=str(temp_weights),
1596
+ )
1597
+ except Exception as e:
1598
+ if gen_op == "gencon" and "corner coordinates missing" in str(e).lower():
1599
+ self.cdo.genbil(
1600
+ str(grid_file),
1601
+ input=str(input_path),
1602
+ output=str(temp_weights),
1603
+ )
1604
+ else:
1605
+ raise
1606
+ staging = weight_path.with_suffix(weight_path.suffix + ".part")
1607
+ shutil.copy2(temp_weights, staging)
1608
+ os.replace(staging, weight_path)
1609
+ return True
1610
+ except Exception as e:
1611
+ self.logger.warning(f"Could not save weights: {e}")
1612
+ return False
1613
+
1614
+ def regrid_file(
1615
+ self,
1616
+ input_path: Path,
1617
+ output_path: Optional[Path] = None,
1618
+ force_regenerate_weights: bool = False,
1619
+ overwrite: bool = False,
1620
+ ui: Optional[RegridProgressUI] = None,
1621
+ ) -> bool:
1622
+ """
1623
+ Regrid a single file with comprehensive grid type support and memory optimization.
1624
+
1625
+ Args:
1626
+ - input_path (Path): Path to the input file
1627
+ - output_path (Path): Path to the output file
1628
+ - force_regenerate_weights (bool): Whether to force regeneration of weights
1629
+ - overwrite (bool): If True, overwrite existing output files
1630
+
1631
+ Returns (bool): True if successful, False otherwise
1632
+ """
1633
+ if not input_path.exists():
1634
+ self._log_error(f"Input file does not exist: {input_path}")
1635
+ return False
1636
+
1637
+ self.weight_cache_dir = weight_cache_dir_for_input(input_path)
1638
+ self.weight_cache_dir.mkdir(parents=True, exist_ok=True)
1639
+ init_regrid_error_log(self._error_log_path)
1640
+
1641
+ file_info = self._get_file_info(input_path)
1642
+ regrid_mode = "seafloor" if self.extract_seafloor else ("surface" if self.extract_surface else "complete")
1643
+ if ui:
1644
+ ui.start_file_processing(input_path, file_info, regrid_mode=regrid_mode)
1645
+
1646
+ if output_path is None:
1647
+ output_filename = self._generate_output_filename(
1648
+ input_path, file_info["has_level"], self.extract_surface, self.extract_seafloor
1649
+ )
1650
+ output_path = input_path.parent / output_filename
1651
+
1652
+ if output_path.exists():
1653
+ if overwrite: # handle overwrite logic
1654
+ if self.verbose:
1655
+ self.console.print(f"[yellow]Overwriting existing file: {output_path.name}[/yellow]")
1656
+ try:
1657
+ output_path.unlink()
1658
+ except Exception as e:
1659
+ self._log_error(f"Failed to remove existing file {output_path}: {e}", exc=e)
1660
+ if ui:
1661
+ ui.complete_file(input_path, success=False, message=f"Failed to remove existing file: {e}")
1662
+ return False
1663
+ else:
1664
+ if self.verbose:
1665
+ self.console.print(f"[yellow]Skipping existing file: {output_path.name}[/yellow]")
1666
+ if ui:
1667
+ ui.skip_file(input_path, "File already exists")
1668
+ return True
1669
+
1670
+ grid_type = file_info["grid_type"]
1671
+
1672
+ # Step description for progress UI
1673
+ if self.extract_seafloor:
1674
+ step_msg = "Extracting seafloor"
1675
+ elif self.extract_surface:
1676
+ step_msg = "Extracting surface (top level)"
1677
+ else:
1678
+ step_msg = "Regridding whole file"
1679
+ if self.verbose and self.verbose_diagnostics:
1680
+ self.console.print(Panel(f"[bold cyan]{step_msg}[/bold cyan]", style="dim"))
1681
+ self._notify_progress(
1682
+ step_msg, pct=10, ui=ui, ui_input_path=input_path, regrid_mode=regrid_mode
1683
+ )
1684
+ prepared_path = self._prepare_file_for_regridding(input_path, ui=ui, ui_input_path=input_path)
1685
+ if self.extract_surface and prepared_path == input_path:
1686
+ self.logger.warning(
1687
+ f"Surface extraction skipped for {input_path.name} (no level dim or preparation failed); "
1688
+ "regridding whole file. Chunk names will still use _top_level for consistency."
1689
+ )
1690
+ if self.extract_seafloor and prepared_path == input_path and "_seafloor" not in input_path.stem:
1691
+ self.logger.warning(
1692
+ f"Seafloor extraction skipped for {input_path.name}; regridding whole file."
1693
+ )
1694
+
1695
+ if not self._is_valid_prepared_file(prepared_path):
1696
+ self._log_error(
1697
+ f"Skipping regridding: prepared file is empty or invalid (CDO would report 'No arrays found'): {prepared_path.name}"
1698
+ )
1699
+ if ui:
1700
+ ui.complete_file(input_path, success=False, message="Prepared file empty or invalid")
1701
+ return False
1702
+
1703
+ # Use prepared file's metadata for chunking/display (e.g. seafloor extract is single-level and smaller)
1704
+ if prepared_path != input_path:
1705
+ file_info = self._get_file_info(prepared_path)
1706
+ grid_type = file_info['grid_type']
1707
+
1708
+ self.stats["grid_types"][grid_type] += 1
1709
+
1710
+ if self.verbose_diagnostics:
1711
+ self.console.print(f"[blue]Grid type: {grid_type}[/blue]")
1712
+ self.console.print(f"[blue]File size: {file_info['file_size_gb']:.2f} GB[/blue]")
1713
+ if file_info.get("has_level"):
1714
+ self.console.print(f"[blue]Levels: {file_info['level_count']}[/blue]")
1715
+
1716
+ should_chunk = self._should_chunk_file(file_info)
1717
+ if should_chunk and self.verbose_diagnostics:
1718
+ self.console.print(
1719
+ f"[yellow]Large file detected ({file_info['file_size_gb']:.2f} GB), using chunked processing[/yellow]"
1720
+ )
1721
+ if should_chunk:
1722
+ self._notify_progress(
1723
+ "Chunking large file",
1724
+ pct=20,
1725
+ ui=ui,
1726
+ ui_input_path=input_path,
1727
+ regrid_mode=regrid_mode,
1728
+ )
1729
+ try:
1730
+ grid_signature = self._get_grid_signature(file_info)
1731
+ if should_chunk:
1732
+ # When extract_surface but preparation was skipped, prepared_path may be the
1733
+ # full multi-level file; chunking it would produce *_chunk_*_top_level.nc with
1734
+ # all levels. Ensure we chunk a top-level-only file.
1735
+ if self.extract_surface and prepared_path == input_path:
1736
+ prepared_path = self._ensure_top_level_file_for_chunking(input_path)
1737
+ self._cleanup_stale_chunks_for_file(prepared_path)
1738
+ name_suffix = "_top_level" if self.extract_surface else ("_seafloor" if self.extract_seafloor else None)
1739
+ chunk_files = self._chunk_file_by_time(
1740
+ prepared_path,
1741
+ name_suffix=name_suffix,
1742
+ ui=ui,
1743
+ ui_input_path=input_path,
1744
+ regrid_mode=regrid_mode,
1745
+ )
1746
+ success = self._regrid_chunked_file(
1747
+ prepared_path,
1748
+ output_path,
1749
+ grid_signature,
1750
+ grid_type,
1751
+ chunk_files,
1752
+ weight_cache_dir=self.weight_cache_dir,
1753
+ ui=ui,
1754
+ ui_input_path=input_path,
1755
+ regrid_mode=regrid_mode,
1756
+ )
1757
+ else:
1758
+ self._notify_progress(
1759
+ "Regridding",
1760
+ pct=30,
1761
+ ui=ui,
1762
+ ui_input_path=input_path,
1763
+ regrid_mode=regrid_mode,
1764
+ )
1765
+ force_weights = force_regenerate_weights or not self.use_regrid_cache
1766
+ success = self._regrid_single_file(
1767
+ prepared_path,
1768
+ output_path,
1769
+ grid_signature,
1770
+ grid_type,
1771
+ force_regenerate_weights=force_weights,
1772
+ )
1773
+
1774
+ if success:
1775
+ self.stats['files_processed'] += 1
1776
+ self.stats['total_size_gb'] += file_info['file_size_gb']
1777
+
1778
+ # update memory monitoring
1779
+ if self.memory_monitor:
1780
+ self.memory_monitor.update_peak()
1781
+ self.stats['memory_peak_gb'] = self.memory_monitor.get_peak_memory_gb()
1782
+
1783
+ if ui: # complete UI progress
1784
+ ui.complete_file(input_path, success=True)
1785
+ else:
1786
+ if ui:
1787
+ ui.complete_file(input_path, success=False, message="Regridding failed")
1788
+
1789
+ if prepared_path != input_path and prepared_path.exists():
1790
+ prepared_path.unlink() # clean up prepared file if it was created
1791
+
1792
+ return success
1793
+
1794
+ except Exception as e:
1795
+ self._log_error(f"Error processing {input_path}: {e}", exc=e)
1796
+ self.stats['errors'] += 1
1797
+ if ui:
1798
+ ui.complete_file(input_path, success=False, message=f"Error: {str(e)}")
1799
+ return False
1800
+
1801
+ def regrid_batch(
1802
+ self,
1803
+ input_files: list[Path],
1804
+ output_dir: Optional[Path] = None,
1805
+ group_by_directory: bool = True,
1806
+ overwrite: bool = False,
1807
+ use_ui: bool = True,
1808
+ ) -> dict[str, list[Path]]:
1809
+ """Regrid a batch of files efficiently with optional parallel processing.
1810
+
1811
+ Args:
1812
+ - input_files (list[Path]): List of input files to regrid
1813
+ - output_dir (Path): Output directory. If None, outputs go to same directory as input
1814
+ - group_by_directory (bool): Group files by directory to maximize weight reuse
1815
+ - overwrite (bool): If True, overwrite existing output files
1816
+
1817
+ Returns (dict[str, list[Path]]): Dictionary mapping status to list of file paths
1818
+
1819
+ Note: has two child functions, _regrid_batch_sequential and _regrid_batch_parallel depending on number of files, and whether parallel processing is enabled and is successful.
1820
+ """
1821
+ results = {
1822
+ 'successful': [],
1823
+ 'failed': [],
1824
+ 'skipped': [],
1825
+ }
1826
+
1827
+ if isinstance(input_files, Path):
1828
+ input_files = [input_files]
1829
+ self.console.print(f"[yellow]Input is a single file, will be processed sequentially[/yellow]") if self.verbose_diagnostics else None
1830
+ # Exclude weight/cache files from the batch
1831
+ input_files = [f for f in input_files if not is_weights_or_cache_file(f)]
1832
+ representative_by_dir = representative_files_by_directory(input_files)
1833
+ if representative_by_dir and self.verbose:
1834
+ self.console.print(
1835
+ f"[blue]Resolution representatives: {len(representative_by_dir)} "
1836
+ f"leaf director{'y' if len(representative_by_dir) == 1 else 'ies'}[/blue]"
1837
+ )
1838
+
1839
+ # clean up problematic files (_top_level and _chunk_) first
1840
+ input_files = self._cleanup_problematic_files(input_files)
1841
+
1842
+ # Deduplicate by resolved path so the same original file is never processed by more than one worker
1843
+ seen: set[Path] = set()
1844
+ deduped: list[Path] = []
1845
+ for f in input_files:
1846
+ try:
1847
+ r = f.resolve()
1848
+ except OSError:
1849
+ r = f
1850
+ if r not in seen:
1851
+ seen.add(r)
1852
+ deduped.append(f)
1853
+ if len(deduped) < len(input_files) and self.verbose:
1854
+ self.console.print(
1855
+ f"[blue]Deduplicated input: {len(input_files)} -> {len(deduped)} files (same path only processed once)[/blue]"
1856
+ )
1857
+ input_files = deduped
1858
+
1859
+ if overwrite:
1860
+ # unlink any files with 'regridded' in the name in order to reprocess them
1861
+ for file in input_files:
1862
+ if 'regridded' in file.name:
1863
+ if self.verbose:
1864
+ self.console.print(f"[yellow]Removing existing regridded file: {file.name}[/yellow]")
1865
+ file.unlink()
1866
+ # prune any files with 'regridded' in the name to avoid processing them twice
1867
+ input_files = [
1868
+ file for file in input_files
1869
+ if 'regridded' not in file.name and not file.suffix.endswith('.part')
1870
+ ] # TODO: use _prune_regridded function instead (or remove it)
1871
+
1872
+ # also check for and remove existing output files that would be created from input files
1873
+ if self.verbose:
1874
+ self.console.print(f"[blue]Checking for existing output files to remove (overwrite=True)[/blue]")
1875
+ for file in input_files[:]: # use slice copy to avoid modifying list while iterating
1876
+ # generate the expected output filename using lightweight check
1877
+ has_level = self._has_level_lightweight(file)
1878
+ output_filename = self._generate_output_filename(
1879
+ file, has_level, self.extract_surface, self.extract_seafloor
1880
+ )
1881
+ if output_dir:
1882
+ expected_output = output_dir / output_filename
1883
+ else:
1884
+ expected_output = file.parent / output_filename
1885
+
1886
+ if expected_output.exists():
1887
+ if self.verbose:
1888
+ self.console.print(f"[yellow]Removing existing output file: {expected_output.name}[/yellow]")
1889
+ expected_output.unlink()
1890
+ else:
1891
+ # when overwrite=False, just prune file names with 'regridded' in the name from processing list but don't delete them
1892
+ if self.verbose:
1893
+ regridded_files = [file for file in input_files if 'regridded' in file.name]
1894
+ if regridded_files:
1895
+ self.console.print(f"[blue]Skipping {len(regridded_files)} existing regridded files (overwrite=False)[/blue]")
1896
+ input_files = [
1897
+ file for file in input_files
1898
+ if 'regridded' not in file.name and not file.suffix.endswith('.part')
1899
+ ]
1900
+
1901
+ # Always redirect pipeline errors (seafloor, regrid, etc.) to a log file so they are
1902
+ # captured even when running quiet or without UI
1903
+ if input_files:
1904
+ self._error_log_path = init_regrid_error_log()
1905
+ self.set_error_log_path(self._error_log_path)
1906
+ if use_ui and self.verbose:
1907
+ self.console.print(f"[dim]Errors logged to: {self._error_log_path}[/dim]")
1908
+
1909
+ self.start_time = print_timestamp(self.console, "START")
1910
+
1911
+ if not self.enable_parallel or len(input_files) < 2 or self.max_workers in [None, 1]:
1912
+ # process sequentially
1913
+ results = self._regrid_batch_sequential(
1914
+ input_files, output_dir, group_by_directory, overwrite, use_ui, representative_by_dir
1915
+ )
1916
+ self.end_time = print_timestamp(self.console, "END")
1917
+ return results
1918
+
1919
+ ui = None
1920
+ poll_thread = None
1921
+ stop_poll = threading.Event()
1922
+ try:
1923
+ # initialize compact UI for parallel processing
1924
+ if use_ui and self.verbose:
1925
+ batch_mode = "seafloor" if self.extract_seafloor else ("surface" if self.extract_surface else "complete")
1926
+ ui = BatchRegridUI(
1927
+ input_files,
1928
+ max_workers=self.max_workers,
1929
+ verbose=self.verbose,
1930
+ regrid_mode=batch_mode,
1931
+ )
1932
+ ui.__enter__()
1933
+
1934
+ # process files in parallel (individual file processing)
1935
+ if self.verbose:
1936
+ self.console.print(f"[green]Processing {len(input_files)} files in parallel with {self.max_workers} workers[/green]")
1937
+
1938
+ progress_manager = mp.Manager()
1939
+ progress_state = progress_manager.dict()
1940
+ if ui:
1941
+ poll_thread = threading.Thread(
1942
+ target=poll_batch_progress,
1943
+ args=(ui, progress_state, stop_poll),
1944
+ daemon=True,
1945
+ )
1946
+ poll_thread.start()
1947
+
1948
+ with ProcessPoolExecutor(max_workers=self.max_workers) as executor:
1949
+ futures = {
1950
+ executor.submit(
1951
+ process_single_file_standalone,
1952
+ file_path,
1953
+ output_dir,
1954
+ self.target_resolution,
1955
+ self.target_grid,
1956
+ weight_cache_dir_for_input(file_path),
1957
+ self.extract_surface,
1958
+ self.extract_seafloor,
1959
+ self.use_regrid_cache,
1960
+ self.use_seafloor_cache,
1961
+ self.max_memory_gb,
1962
+ self.chunk_size_gb,
1963
+ self.enable_chunking,
1964
+ overwrite,
1965
+ representative_by_dir.get(file_path.parent),
1966
+ self.verbose,
1967
+ self._error_log_path,
1968
+ progress_state,
1969
+ ): file_path
1970
+ for file_path in input_files
1971
+ }
1972
+
1973
+ combined_stats = {
1974
+ 'files_processed': 0,
1975
+ 'weights_reused': 0,
1976
+ 'weights_generated': 0,
1977
+ 'chunks_processed': 0,
1978
+ 'errors': 0,
1979
+ 'total_size_gb': 0.0,
1980
+ 'memory_peak_gb': 0.0,
1981
+ 'grid_types': {}
1982
+ }
1983
+
1984
+ for future in as_completed(futures):
1985
+ file_result = future.result()
1986
+
1987
+ if ui:
1988
+ ui.update_file_result(file_result["file_path"], file_result)
1989
+
1990
+ # combine statistics from worker
1991
+ if 'stats' in file_result:
1992
+ worker_stats = file_result['stats']
1993
+ combined_stats['files_processed'] += worker_stats.get('files_processed', 0)
1994
+ combined_stats['weights_reused'] += worker_stats.get('weights_reused', 0)
1995
+ combined_stats['weights_generated'] += worker_stats.get('weights_generated', 0)
1996
+ combined_stats['chunks_processed'] += worker_stats.get('chunks_processed', 0)
1997
+ combined_stats['errors'] += worker_stats.get('errors', 0)
1998
+ combined_stats['total_size_gb'] += worker_stats.get('total_size_gb', 0.0)
1999
+ combined_stats['memory_peak_gb'] = max(combined_stats['memory_peak_gb'], worker_stats.get('memory_peak_gb', 0.0))
2000
+
2001
+ # combine grid types
2002
+ for grid_type, count in worker_stats.get('grid_types', {}).items():
2003
+ if grid_type not in combined_stats['grid_types']:
2004
+ combined_stats['grid_types'][grid_type] = 0
2005
+ combined_stats['grid_types'][grid_type] += count
2006
+
2007
+ if file_result['success']:
2008
+ if file_result.get('skipped', False):
2009
+ results['skipped'].append(file_result['file_path'])
2010
+ else:
2011
+ results['successful'].append(file_result['file_path'])
2012
+ else:
2013
+ results['failed'].append(file_result['file_path'])
2014
+
2015
+ self.stats.update(combined_stats)
2016
+ except Exception as e:
2017
+ self._log_error(f"Error processing batch in parallel: {e}", exc=e)
2018
+ self.logger.warning(f"Falling back to sequential processing")
2019
+ results = self._regrid_batch_sequential(
2020
+ input_files, output_dir, group_by_directory, overwrite, use_ui, representative_by_dir
2021
+ )
2022
+ self.end_time = print_timestamp(self.console, "END")
2023
+ return results
2024
+ finally:
2025
+ if poll_thread is not None:
2026
+ stop_poll.set()
2027
+ poll_thread.join(timeout=2)
2028
+ if ui:
2029
+ self.end_time = print_timestamp(self.console, "END")
2030
+ self.stats["processing_time"] = format_processing_time(
2031
+ get_processing_time(self.start_time, self.end_time)
2032
+ )
2033
+ ui._update_stats(self.stats)
2034
+ ui.print_summary()
2035
+ ui.__exit__(None, None, None)
2036
+ else:
2037
+ self.end_time = print_timestamp(self.console, "END")
2038
+
2039
+ return results
2040
+
2041
+ def _regrid_batch_sequential(
2042
+ self,
2043
+ input_files: list[Path],
2044
+ output_dir: Optional[Path] = None,
2045
+ group_by_directory: bool = True,
2046
+ overwrite: bool = False,
2047
+ use_ui: bool = True,
2048
+ representative_by_dir: Optional[dict[Path, Path]] = None,
2049
+ ) -> dict[str, list[Path]]:
2050
+ """Sequential batch processing fallback. Used if parallel processing fails or is not enabled, if there is only one file, or if the number of workers is not specified or is set to 1.
2051
+
2052
+ Args:
2053
+ - input_files (list[Path]): List of input files to regrid
2054
+ - output_dir (Path): Output directory. If None, outputs go to same directory as input
2055
+ - group_by_directory (bool): Group files by directory to maximize weight reuse
2056
+
2057
+ Returns (dict[str, list[Path]]): Dictionary mapping status to list of file paths
2058
+
2059
+ Note: has one regridding child function, regrid_file, for processing a single file.
2060
+ """
2061
+ results = {
2062
+ 'successful': [],
2063
+ 'failed': [],
2064
+ 'skipped': [],
2065
+ }
2066
+
2067
+ # clean up problematic files (_top_level and _chunk_) first
2068
+ input_files = self._cleanup_problematic_files(input_files)
2069
+ # remove completed files from input_files (any containing 'regridded' in the name)
2070
+ chunk_files = [file for file in input_files if 'chunk' in file.name]
2071
+ regridded_files = [file for file in input_files if 'regridded' in file.name]
2072
+ input_files = [file for file in input_files if file not in regridded_files]
2073
+ # remove residual files (any containing 'chunk' in the name): TODO: these shouldn't exist at this point
2074
+ input_files = [file for file in input_files if file not in chunk_files]
2075
+ print(f"Removed {len(chunk_files)} and {len(regridded_files)} files from input_files before regridding")
2076
+
2077
+ # initialize UI if requested
2078
+ ui = None
2079
+ if use_ui and self.verbose:
2080
+ ui = RegridProgressUI(
2081
+ input_files,
2082
+ verbose=self.verbose,
2083
+ verbose_diagnostics=self.verbose_diagnostics,
2084
+ log_file=self._error_log_path,
2085
+ )
2086
+ ui.__enter__()
2087
+ # group files by directory if requested
2088
+ if group_by_directory:
2089
+ file_groups = self._group_files_by_directory(input_files)
2090
+ else:
2091
+ file_groups = {'all': input_files}
2092
+
2093
+ # process each group
2094
+ for group_name, files in file_groups.items():
2095
+ if self.verbose:
2096
+ self.console.print(f"\n[blue]Processing group: {group_name}[/blue]")
2097
+ self.console.print(f"[blue]Files in group: {len(files)}[/blue]")
2098
+
2099
+ # process files in group
2100
+ for file_path in files:
2101
+ try:
2102
+ dir_representative = (representative_by_dir or {}).get(file_path.parent)
2103
+ if dir_representative:
2104
+ self._representative_file = dir_representative
2105
+
2106
+ # Use lightweight check for has_level to avoid expensive full file analysis
2107
+ has_level = self._has_level_lightweight(file_path)
2108
+
2109
+ # determine output path
2110
+ if output_dir:
2111
+ output_filename = self._generate_output_filename(
2112
+ file_path, has_level, self.extract_surface, self.extract_seafloor
2113
+ )
2114
+ output_path = output_dir / output_filename
2115
+ else:
2116
+ output_filename = self._generate_output_filename(
2117
+ file_path, has_level, self.extract_surface, self.extract_seafloor
2118
+ )
2119
+ output_path = file_path.parent / output_filename
2120
+
2121
+ # check if output already exists
2122
+ if output_path.exists():
2123
+ results["skipped"].append(file_path)
2124
+ if self.verbose:
2125
+ self.console.print(f"[yellow]Skipping (exists): {file_path.name}[/yellow]")
2126
+ continue
2127
+
2128
+ # regrid file
2129
+ success = self.regrid_file(file_path, output_path, overwrite=overwrite, ui=ui)
2130
+
2131
+ if success:
2132
+ results['successful'].append(file_path)
2133
+ if self.verbose:
2134
+ self.console.print(f"[green]Success: {file_path.name}[/green]")
2135
+ else:
2136
+ results['failed'].append(file_path)
2137
+ if self.verbose:
2138
+ self.console.print(f"[red]Failed: {file_path.name}[/red]")
2139
+
2140
+ except Exception as e:
2141
+ self._log_error(f"Error processing {file_path}: {e}", exc=e)
2142
+ results['failed'].append(file_path)
2143
+ if ui:
2144
+ ui.complete_file(file_path, success=False, message=f"Error: {str(e)}")
2145
+
2146
+ if ui:
2147
+ self.stats['processing_time'] = format_processing_time(get_processing_time(self.start_time, self.end_time))
2148
+ ui._update_stats(self.stats)
2149
+ ui.print_summary()
2150
+ ui.__exit__(None, None, None)
2151
+
2152
+ return results
2153
+
2154
+
2155
+ def _group_files_by_directory(self, files: list[Path]) -> dict[str, list[Path]]:
2156
+ """Group files by their parent directory.
2157
+
2158
+ Args:
2159
+ - files (list[Path]): List of files to group
2160
+
2161
+ Returns (dict[str, list[Path]]): Dictionary mapping directory to list of file paths
2162
+ """
2163
+ groups = {}
2164
+ for file_path in files:
2165
+ parent_dir = str(file_path.parent)
2166
+ if parent_dir not in groups:
2167
+ groups[parent_dir] = []
2168
+ groups[parent_dir].append(file_path)
2169
+ return groups
2170
+
2171
+ def print_statistics(self):
2172
+ """Print comprehensive processing statistics."""
2173
+ table = Table(show_header=True, header_style="bold cyan")
2174
+ table.add_column("Metric", style="bold")
2175
+ table.add_column("Value", style="bold")
2176
+
2177
+ table.add_row("Files Processed", str(self.stats['files_processed']))
2178
+ table.add_row("Weights Reused", str(self.stats['weights_reused']))
2179
+ table.add_row("Weights Generated", str(self.stats['weights_generated']))
2180
+ table.add_row("Chunks Processed", str(self.stats['chunks_processed']))
2181
+ table.add_row("Errors", str(self.stats['errors']))
2182
+ table.add_row("Total Size (GB)", f"{self.stats['total_size_gb']:.2f}")
2183
+ table.add_row("Memory Peak (GB)", f"{self.stats['memory_peak_gb']:.2f}")
2184
+
2185
+ # add grid type statistics
2186
+ table.add_row("", "") # Empty row
2187
+ table.add_row("Grid Types", "")
2188
+ for grid_type, count in self.stats['grid_types'].items():
2189
+ if count > 0:
2190
+ table.add_row(f" {grid_type}", str(count))
2191
+
2192
+ self.console.print(Panel(table, title="[cyan]CDO Regridding Statistics[/cyan]", border_style="cyan"))
2193
+
2194
+ def cleanup_weight_files(self, confirm: bool = True):
2195
+ """Delete all cached weight files on demand.
2196
+
2197
+ Args:
2198
+ - confirm (bool): Require user confirmation before deleting files (default: True)
2199
+ # TODO: add confirmation option to main function?
2200
+ Returns (None): None
2201
+ """
2202
+ try:
2203
+ weight_files = list(self.weight_cache_dir.glob("weights_*.nc"))
2204
+ if not weight_files:
2205
+ if self.verbose:
2206
+ self.console.print("[yellow]No weight files to clean up.[/yellow]")
2207
+ return
2208
+
2209
+ if confirm:
2210
+ from rich.prompt import Confirm
2211
+ proceed = Confirm.ask(
2212
+ f"[red]Are you sure you want to delete all {len(weight_files)} cached weight files in {self.weight_cache_dir}?[/red]"
2213
+ )
2214
+ if not proceed:
2215
+ if self.verbose:
2216
+ self.console.print("[yellow]Cleanup cancelled by user.[/yellow]")
2217
+ return
2218
+
2219
+ for weight_file in weight_files:
2220
+ try:
2221
+ weight_file.unlink()
2222
+ if self.verbose:
2223
+ self.console.print(f"[cyan]Cleaned up {weight_file.name}[/cyan]")
2224
+ except Exception as e:
2225
+ self._log_error(f"Error deleting weight file {weight_file}: {e}", exc=e)
2226
+ if self.verbose:
2227
+ self.console.print(f"[green]Cleaned up {len(weight_files)} weight file(s).[/green]")
2228
+ except Exception as e:
2229
+ self._log_error(f"Error cleaning up weights: {e}", exc=e)
2230
+