roms-tools 0.1.0__py3-none-any.whl → 0.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,48 +1,457 @@
  import pooch
  import xarray as xr
+ from dataclasses import dataclass, field
+ import glob
+ from datetime import datetime, timedelta
+ import numpy as np
+ from typing import Dict, Optional, List
+ import dask

-
- FRANK = pooch.create(
+ # Create a Pooch object to manage the global topography data
+ pup_data = pooch.create(
      # Use the default cache folder for the operating system
      path=pooch.os_cache("roms-tools"),
      base_url="https://github.com/CWorthy-ocean/roms-tools-data/raw/main/",
-     # If this is a development version, get the data from the "main" branch
      # The registry specifies the files that can be fetched
      registry={
          "etopo5.nc": "sha256:23600e422d59bbf7c3666090166a0d468c8ee16092f4f14e32c4e928fbcd627b",
      },
  )

+ # Create a Pooch object to manage the test data
+ pup_test_data = pooch.create(
+     # Use the default cache folder for the operating system
+     path=pooch.os_cache("roms-tools"),
+     base_url="https://github.com/CWorthy-ocean/roms-tools-test-data/raw/main/",
+     # The registry specifies the files that can be fetched
+     registry={
+         "GLORYS_test_data.nc": "648f88ec29c433bcf65f257c1fb9497bd3d5d3880640186336b10ed54f7129d2",
+         "ERA5_regional_test_data.nc": "bd12ce3b562fbea2a80a3b79ba74c724294043c28dc98ae092ad816d74eac794",
+         "ERA5_global_test_data.nc": "8ed177ab64c02caf509b9fb121cf6713f286cc603b1f302f15f3f4eb0c21dc4f",
+         "TPXO_global_test_data.nc": "457bfe87a7b247ec6e04e3c7d3e741ccf223020c41593f8ae33a14f2b5255e60",
+         "TPXO_regional_test_data.nc": "11739245e2286d9c9d342dce5221e6435d2072b50028bef2e86a30287b3b4032",
+     },
+ )
+

- def fetch_topo(topography_source) -> xr.Dataset:
+ def fetch_topo(topography_source: str) -> xr.Dataset:
      """
      Load the global topography data as an xarray Dataset.
+
+     Parameters
+     ----------
+     topography_source : str
+         The source of the topography data to be loaded. Available options:
+         - "ETOPO5"
+
+     Returns
+     -------
+     xr.Dataset
+         The global topography data as an xarray Dataset.
      """
      # Mapping from user-specified topography options to corresponding filenames in the registry
-     topo_dict = {"etopo5": "etopo5.nc"}
-
-     # The file will be downloaded automatically the first time this is run
-     # returns the file path to the downloaded file. Afterwards, Pooch finds
-     # it in the local cache and doesn't repeat the download.
-     fname = FRANK.fetch(topo_dict[topography_source])
-     # The "fetch" method returns the full path to the downloaded data file.
-     # All we need to do now is load it with our standard Python tools.
+     topo_dict = {"ETOPO5": "etopo5.nc"}
+
+     # Fetch the file using Pooch, downloading if necessary
+     fname = pup_data.fetch(topo_dict[topography_source])
+
+     # Load the dataset using xarray and return it
      ds = xr.open_dataset(fname)
      return ds
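
A minimal usage sketch of the updated `fetch_topo` (the `ds_topo` name is illustrative; "ETOPO5" is the only registered source):

>>> ds_topo = fetch_topo("ETOPO5")  # first call downloads etopo5.nc, later calls reuse the Pooch cache
>>> print(ds_topo)                  # a regular xarray.Dataset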


- def fetch_ssr_correction(correction_source) -> xr.Dataset:
+ def download_test_data(filename: str) -> str:
      """
-     Load the SSR correction data as an xarray Dataset.
+     Download the test data file.
+
+     Parameters
+     ----------
+     filename : str
+         The name of the test data file to be downloaded. Available options:
+         - "GLORYS_test_data.nc"
+         - "ERA5_regional_test_data.nc"
+         - "ERA5_global_test_data.nc"
+
+     Returns
+     -------
+     str
+         The path to the downloaded test data file.
      """
-     # Mapping from user-specified topography options to corresponding filenames in the registry
-     topo_dict = {"corev2": "SSR_correction.nc"}
-
-     # The file will be downloaded automatically the first time this is run
-     # returns the file path to the downloaded file. Afterwards, Pooch finds
-     # it in the local cache and doesn't repeat the download.
-     fname = FRANK.fetch(topo_dict[correction_source])
-     # The "fetch" method returns the full path to the downloaded data file.
-     # All we need to do now is load it with our standard Python tools.
-     ds = xr.open_dataset(fname)
-     return ds
+     # Fetch the file using Pooch, downloading if necessary
+     fname = pup_test_data.fetch(filename)
+
+     return fname
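
A brief usage sketch for `download_test_data` (the file name comes from the registry above; `ds` is illustrative):

>>> fname = download_test_data("GLORYS_test_data.nc")  # returns the path in the local Pooch cache
>>> ds = xr.open_dataset(fname)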
+
+
+ @dataclass(frozen=True, kw_only=True)
+ class Dataset:
+     """
+     Represents forcing data on original grid.
+
+     Parameters
+     ----------
+     filename : str
+         The path to the data files. Can contain wildcards.
+     start_time : Optional[datetime], optional
+         The start time for selecting relevant data. If not provided, the data is not filtered by start time.
+     end_time : Optional[datetime], optional
+         The end time for selecting relevant data. If not provided, only data at the start_time is selected if start_time is provided,
+         or no filtering is applied if start_time is not provided.
+     var_names : List[str]
+         List of variable names that are required in the dataset.
+     dim_names: Dict[str, str], optional
+         Dictionary specifying the names of dimensions in the dataset.
+
+     Attributes
+     ----------
+     ds : xr.Dataset
+         The xarray Dataset containing the forcing data on its original grid.
+
+     Examples
+     --------
+     >>> dataset = Dataset(
+     ...     filename="data.nc",
+     ...     start_time=datetime(2022, 1, 1),
+     ...     end_time=datetime(2022, 12, 31),
+     ... )
+     >>> dataset.load_data()
+     >>> print(dataset.ds)
+     <xarray.Dataset>
+     Dimensions: ...
+     """
+
+     filename: str
+     start_time: Optional[datetime] = None
+     end_time: Optional[datetime] = None
+     var_names: List[str]
+     dim_names: Dict[str, str] = field(
+         default_factory=lambda: {
+             "longitude": "longitude",
+             "latitude": "latitude",
+             "time": "time",
+         }
+     )
+
+     ds: xr.Dataset = field(init=False, repr=False)
+
+     def __post_init__(self):
+         ds = self.load_data()
+
+         # Select relevant times
+         if "time" in self.dim_names and self.start_time is not None:
+             ds = self.select_relevant_times(ds)
+
+         # Select relevant fields
+         ds = self.select_relevant_fields(ds)
+
+         # Make sure that latitude is ascending
+         diff = np.diff(ds[self.dim_names["latitude"]])
+         if np.all(diff < 0):
+             ds = ds.isel(**{self.dim_names["latitude"]: slice(None, None, -1)})
+
+         # Check whether the data covers the entire globe
+         is_global = self.check_if_global(ds)
+
+         if is_global:
+             ds = self.concatenate_longitudes(ds)
+
+         object.__setattr__(self, "ds", ds)
+
+     def load_data(self) -> xr.Dataset:
+         """
+         Load dataset from the specified file.
+
+         Returns
+         -------
+         ds : xr.Dataset
+             The loaded xarray Dataset containing the forcing data.
+
+         Raises
+         ------
+         FileNotFoundError
+             If the specified file does not exist.
+         """
+
+         # Check if the file exists
+         matching_files = glob.glob(self.filename)
+         if not matching_files:
+             raise FileNotFoundError(
+                 f"No files found matching the pattern '{self.filename}'."
+             )
+
+         # Load the dataset
+         with dask.config.set(**{"array.slicing.split_large_chunks": False}):
+             # Define the chunk sizes
+             chunks = {
+                 self.dim_names["latitude"]: -1,
+                 self.dim_names["longitude"]: -1,
+             }
+             if "depth" in self.dim_names.keys():
+                 chunks[self.dim_names["depth"]] = -1
+             if "time" in self.dim_names.keys():
+                 chunks[self.dim_names["time"]] = 1
+
+                 ds = xr.open_mfdataset(
+                     self.filename,
+                     combine="nested",
+                     concat_dim=self.dim_names["time"],
+                     coords="minimal",
+                     compat="override",
+                     chunks=chunks,
+                     engine="netcdf4",
+                 )
+             else:
+                 ds = xr.open_dataset(
+                     self.filename,
+                     chunks=chunks,
+                 )
+
+         return ds
+
+     def select_relevant_fields(self, ds) -> xr.Dataset:
+         """
+         Selects and returns a subset of the dataset containing only the variables specified in `self.var_names`.
+
+         Parameters
+         ----------
+         ds : xr.Dataset
+             The input dataset from which variables will be selected.
+
+         Returns
+         -------
+         xr.Dataset
+             A dataset containing only the variables specified in `self.var_names`.
+
+         Raises
+         ------
+         ValueError
+             If `ds` does not contain all variables listed in `self.var_names`.
+
+         """
+         missing_vars = [var for var in self.var_names if var not in ds.data_vars]
+         if missing_vars:
+             raise ValueError(
+                 f"Dataset does not contain all required variables. The following variables are missing: {missing_vars}"
+             )
+
+         for var in ds.data_vars:
+             if var not in self.var_names:
+                 ds = ds.drop_vars(var)
+
+         return ds
+
+     def select_relevant_times(self, ds) -> xr.Dataset:
+
+         """
+         Selects and returns the subset of the dataset corresponding to the specified time range.
+
+         This function filters the dataset to include only the data points within the specified
+         time range, defined by `self.start_time` and `self.end_time`. If `self.end_time` is not
+         provided, it defaults to one day after `self.start_time`.
+
+         Parameters
+         ----------
+         ds : xr.Dataset
+             The input dataset to be filtered.
+
+         Returns
+         -------
+         xr.Dataset
+             A dataset containing only the data points within the specified time range.
+
+         """
+
+         time_dim = self.dim_names["time"]
+
+         if not self.end_time:
+             end_time = self.start_time + timedelta(days=1)
+         else:
+             end_time = self.end_time
+
+         times = (np.datetime64(self.start_time) <= ds[time_dim]) & (
+             ds[time_dim] < np.datetime64(end_time)
+         )
+         ds = ds.where(times, drop=True)
+
+         if not ds.sizes[time_dim]:
+             raise ValueError("No matching times found.")
+
+         if not self.end_time:
+             if ds.sizes[time_dim] != 1:
+                 found_times = ds.sizes[time_dim]
+                 raise ValueError(
+                     f"There must be exactly one time matching the start_time. Found {found_times} matching times."
+                 )
+
+         return ds
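
The time filter keeps records with `start_time <= t < end_time`; when `end_time` is omitted, the window defaults to one day and exactly one matching record is required. A sketch of the intended call pattern (file and variable names are hypothetical):

>>> from datetime import datetime
>>> data = Dataset(
...     filename="glorys_2012-01-01.nc",       # hypothetical file
...     var_names=["thetao", "so", "zos"],     # hypothetical variable names
...     start_time=datetime(2012, 1, 1),       # no end_time given
... )
>>> data.ds.sizes["time"]                      # exactly one record is kept, otherwise a ValueError is raised
1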
+
+     def check_if_global(self, ds) -> bool:
+         """
+         Checks if the dataset covers the entire globe in the longitude dimension.
+
+         This function calculates the mean difference between consecutive longitude values.
+         It then checks if the difference between the first and last longitude values (plus 360 degrees)
+         is close to this mean difference, within a specified tolerance. If it is, the dataset is considered
+         to cover the entire globe in the longitude dimension.
+
+         Returns
+         -------
+         bool
+             True if the dataset covers the entire globe in the longitude dimension, False otherwise.
+
+         """
+         dlon_mean = (
+             ds[self.dim_names["longitude"]].diff(dim=self.dim_names["longitude"]).mean()
+         )
+         dlon = (
+             ds[self.dim_names["longitude"]][0] - ds[self.dim_names["longitude"]][-1]
+         ) % 360.0
+         is_global = np.isclose(dlon, dlon_mean, rtol=0.0, atol=1e-3)
+
+         return is_global
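
The check amounts to testing that the wrap-around gap from the last longitude back to the first equals the mean grid spacing. A standalone sketch of the same criterion with plain NumPy on synthetic longitudes (not part of the diff):

>>> import numpy as np
>>> lon_global = np.arange(0.0, 360.0, 1.0)      # 0, 1, ..., 359 degrees: global coverage
>>> lon_regional = np.arange(0.0, 180.0, 1.0)    # stops at 179 degrees: regional only
>>> def wraps_around(lon):
...     gap = (lon[0] - lon[-1]) % 360.0          # distance from the last point back to the first
...     return bool(np.isclose(gap, np.diff(lon).mean(), rtol=0.0, atol=1e-3))
>>> wraps_around(lon_global), wraps_around(lon_regional)
(True, False)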
+
+     def concatenate_longitudes(self, ds):
+         """
+         Concatenates the field three times: with longitudes shifted by -360, original longitudes, and shifted by +360.
+
+         Parameters
+         ----------
+         field : xr.DataArray
+             The field to be concatenated.
+
+         Returns
+         -------
+         xr.DataArray
+             The concatenated field, with the longitude dimension extended.
+
+         Notes
+         -----
+         Concatenating three times may be overkill in most situations, but it is safe. Alternatively, we could refactor
+         to figure out whether concatenating on the lower end, upper end, or at all is needed.
+
+         """
+         ds_concatenated = xr.Dataset()
+
+         lon = ds[self.dim_names["longitude"]]
+         lon_minus360 = lon - 360
+         lon_plus360 = lon + 360
+         lon_concatenated = xr.concat(
+             [lon_minus360, lon, lon_plus360], dim=self.dim_names["longitude"]
+         )
+
+         ds_concatenated[self.dim_names["longitude"]] = lon_concatenated
+
+         for var in self.var_names:
+             if self.dim_names["longitude"] in ds[var].dims:
+                 field = ds[var]
+                 field_concatenated = xr.concat(
+                     [field, field, field], dim=self.dim_names["longitude"]
+                 ).chunk({self.dim_names["longitude"]: -1})
+                 field_concatenated[self.dim_names["longitude"]] = lon_concatenated
+                 ds_concatenated[var] = field_concatenated
+             else:
+                 ds_concatenated[var] = ds[var]
+
+         return ds_concatenated
+
+     def choose_subdomain(
+         self, latitude_range, longitude_range, margin, straddle, return_subdomain=False
+     ):
+         """
+         Selects a subdomain from the given xarray Dataset based on latitude and longitude ranges,
+         extending the selection by the specified margin. Handles the conversion of longitude values
+         in the dataset from one range to another.
+
+         Parameters
+         ----------
+         latitude_range : tuple
+             A tuple (lat_min, lat_max) specifying the minimum and maximum latitude values of the subdomain.
+         longitude_range : tuple
+             A tuple (lon_min, lon_max) specifying the minimum and maximum longitude values of the subdomain.
+         margin : float
+             Margin in degrees to extend beyond the specified latitude and longitude ranges when selecting the subdomain.
+         straddle : bool
+             If True, target longitudes are expected in the range [-180, 180].
+             If False, target longitudes are expected in the range [0, 360].
+         return_subdomain : bool, optional
+             If True, returns the subset of the original dataset. If False, assigns it to self.ds.
+             Default is False.
+
+         Returns
+         -------
+         xr.Dataset
+             The subset of the original dataset representing the chosen subdomain, including an extended area
+             to cover one extra grid point beyond the specified ranges if return_subdomain is True.
+             Otherwise, returns None.
+
+         Raises
+         ------
+         ValueError
+             If the selected latitude or longitude range does not intersect with the dataset.
+         """
+         lat_min, lat_max = latitude_range
+         lon_min, lon_max = longitude_range
+
+         lon = self.ds[self.dim_names["longitude"]]
+         # Adjust longitude range if needed to match the expected range
+         if not straddle:
+             if lon.min() < -180:
+                 if lon_max + margin > 0:
+                     lon_min -= 360
+                     lon_max -= 360
+             elif lon.min() < 0:
+                 if lon_max + margin > 180:
+                     lon_min -= 360
+                     lon_max -= 360
+
+         if straddle:
+             if lon.max() > 360:
+                 if lon_min - margin < 180:
+                     lon_min += 360
+                     lon_max += 360
+             elif lon.max() > 180:
+                 if lon_min - margin < 0:
+                     lon_min += 360
+                     lon_max += 360
+
+         # Select the subdomain
+         subdomain = self.ds.sel(
+             **{
+                 self.dim_names["latitude"]: slice(lat_min - margin, lat_max + margin),
+                 self.dim_names["longitude"]: slice(lon_min - margin, lon_max + margin),
+             }
+         )
+
+         # Check if the selected subdomain has zero dimensions in latitude or longitude
+         if subdomain[self.dim_names["latitude"]].size == 0:
+             raise ValueError("Selected latitude range does not intersect with dataset.")
+
+         if subdomain[self.dim_names["longitude"]].size == 0:
+             raise ValueError(
+                 "Selected longitude range does not intersect with dataset."
+             )
+
+         # Adjust longitudes to expected range if needed
+         lon = subdomain[self.dim_names["longitude"]]
+         if straddle:
+             subdomain[self.dim_names["longitude"]] = xr.where(lon > 180, lon - 360, lon)
+         else:
+             subdomain[self.dim_names["longitude"]] = xr.where(lon < 0, lon + 360, lon)
+
+         if return_subdomain:
+             return subdomain
+         else:
+             object.__setattr__(self, "ds", subdomain)
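
A usage sketch for `choose_subdomain`, assuming `data` is an existing `Dataset` instance; the ranges and margin are illustrative values:

>>> data.choose_subdomain(
...     latitude_range=(10.0, 40.0),
...     longitude_range=(-80.0, -40.0),
...     margin=2.0,        # widen the box by 2 degrees on every side
...     straddle=True,     # target longitudes given in [-180, 180]
... )
>>> data.ds                # self.ds now holds only the selected subdomain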
+
+     def convert_to_negative_depth(self):
+         """
+         Converts the depth values in the dataset to negative if they are non-negative.
+
+         This method checks the values in the depth dimension of the dataset (`self.ds[self.dim_names["depth"]]`).
+         If all values are greater than or equal to zero, it negates them and updates the dataset accordingly.
+
+         """
+         depth = self.ds[self.dim_names["depth"]]
+
+         if (depth >= 0).all():
+             self.ds[self.dim_names["depth"]] = -depth
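
For sources whose coordinates are not called `longitude`, `latitude`, and `time`, the `dim_names` mapping can be overridden; adding a `"depth"` entry also enables depth-aware chunking and `convert_to_negative_depth`. A sketch with hypothetical names:

>>> from datetime import datetime
>>> data = Dataset(
...     filename="ocean_state_*.nc",            # hypothetical multi-file pattern
...     var_names=["temp", "salt"],             # hypothetical variable names
...     start_time=datetime(2012, 1, 1),
...     end_time=datetime(2012, 2, 1),
...     dim_names={
...         "longitude": "lon",
...         "latitude": "lat",
...         "time": "ocean_time",
...         "depth": "depth",
...     },
... )
>>> data.convert_to_negative_depth()            # flips depths to negative if they are all >= 0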
roms_tools/setup/fill.py CHANGED
@@ -3,7 +3,103 @@ import xarray as xr
  from numba import jit


- def lateral_fill(var, land_mask, dims=["latitude", "longitude"]):
+ def fill_and_interpolate(
+     field,
+     mask,
+     fill_dims,
+     coords,
+     method="linear",
+     fillvalue_fill=0.0,
+     fillvalue_interp=np.nan,
+ ):
+     """
+     Propagates ocean values into land areas and interpolates the data to specified coordinates using a given method.
+
+     Parameters
+     ----------
+     field : xr.DataArray
+         The data array to be interpolated, typically containing oceanographic or atmospheric data
+         with dimensions such as latitude and longitude.
+
+     mask : xr.DataArray
+         A data array with the same spatial dimensions as `field`, where `1` indicates ocean points
+         and `0` indicates land points. This mask is used to identify land and ocean areas in the dataset.
+
+     fill_dims : list of str
+         List specifying the dimensions along which to perform the lateral fill, typically the horizontal
+         dimensions such as latitude and longitude, e.g., ["latitude", "longitude"].
+
+     coords : dict
+         Dictionary specifying the target coordinates for interpolation. The keys should match the dimensions
+         of `field` (e.g., {"longitude": lon_values, "latitude": lat_values, "depth": depth_values}).
+         This dictionary provides the new coordinates onto which the data array will be interpolated.
+
+     method : str, optional, default='linear'
+         The interpolation method to use. Valid options are those supported by `xarray.DataArray.interp`,
+         such as 'linear' or 'nearest'.
+
+     fillvalue_fill : float, optional, default=0.0
+         Value to use in the fill step if an entire data slice along the fill dimensions contains only NaNs.
+
+     fillvalue_interp : float, optional, default=np.nan
+         Value to use in the interpolation step. `np.nan` means that no extrapolation is applied.
+         `None` means that extrapolation is applied, which often makes sense when interpolating in the
+         vertical direction to avoid NaNs at the surface if the lowest depth is greater than zero.
+
+     Returns
+     -------
+     xr.DataArray
+         The interpolated data array. This array has the same dimensions as the input `field` but with values
+         interpolated to the new coordinates specified in `coords`.
+
+     Notes
+     -----
+     This method performs the following steps:
+     1. Sets land values to NaN based on the provided mask to ensure that interpolation does not cross
+        the land-ocean boundary.
+     2. Uses the `lateral_fill` function to propagate ocean values into the land interior, helping to fill
+        gaps in the dataset.
+     3. Interpolates the filled data array over the specified coordinates using the selected interpolation method.
+
+     Example
+     -------
+     >>> import xarray as xr
+     >>> field = xr.DataArray(...)
+     >>> mask = xr.DataArray(...)
+     >>> fill_dims = ["latitude", "longitude"]
+     >>> coords = {"latitude": new_lat_values, "longitude": new_lon_values}
+     >>> interpolated_field = fill_and_interpolate(
+     ...     field, mask, fill_dims, coords, method="linear"
+     ... )
+     >>> print(interpolated_field)
+     """
+     if not isinstance(field, xr.DataArray):
+         raise TypeError("field must be an xarray.DataArray")
+     if not isinstance(mask, xr.DataArray):
+         raise TypeError("mask must be an xarray.DataArray")
+     if not isinstance(coords, dict):
+         raise TypeError("coords must be a dictionary")
+     if not all(dim in field.dims for dim in coords.keys()):
+         raise ValueError("All keys in coords must match dimensions of field")
+     if method not in ["linear", "nearest"]:
+         raise ValueError(
+             "Unsupported interpolation method. Choose from 'linear', 'nearest'"
+         )
+
+     # Set land values to NaN
+     field = field.where(mask)
+
+     # Propagate ocean values into land interior before interpolation
+     field = lateral_fill(field, 1 - mask, fill_dims, fillvalue_fill)
+
+     field_interpolated = field.interp(
+         coords, method=method, kwargs={"fill_value": fillvalue_interp}
+     ).drop_vars(list(coords.keys()))
+
+     return field_interpolated
+
+
+ def lateral_fill(var, land_mask, dims=["latitude", "longitude"], fillvalue=0.0):
      """
      Perform lateral fill on an xarray DataArray using a land mask.

@@ -20,6 +116,9 @@ def lateral_fill(var, land_mask, dims=["latitude", "longitude"]):
      dims : list of str, optional, default=['latitude', 'longitude']
          Dimensions along which to perform the fill. The default is ['latitude', 'longitude'].

+     fillvalue : float, optional, default=0.0
+         Value to use if an entire data slice along the dims contains only NaNs.
+
      Returns
      -------
      var_filled : xarray.DataArray
@@ -27,21 +126,25 @@ def lateral_fill(var, land_mask, dims=["latitude", "longitude"]):
          specified by `land_mask` where NaNs are preserved.

      """
+
      var_filled = xr.apply_ufunc(
          _lateral_fill_np_array,
          var,
          land_mask,
          input_core_dims=[dims, dims],
          output_core_dims=[dims],
-         dask="parallelized",
          output_dtypes=[var.dtype],
+         dask="parallelized",
          vectorize=True,
+         kwargs={"fillvalue": fillvalue},
      )

      return var_filled
143
 
43
144
 
44
- def _lateral_fill_np_array(var, isvalid_mask, tol=1.0e-4, rc=1.8, max_iter=10000):
145
+ def _lateral_fill_np_array(
146
+ var, isvalid_mask, fillvalue=0.0, tol=1.0e-4, rc=1.8, max_iter=10000
147
+ ):
45
148
  """
46
149
  Perform lateral fill on a numpy array.
47
150
 
@@ -55,6 +158,9 @@ def _lateral_fill_np_array(var, isvalid_mask, tol=1.0e-4, rc=1.8, max_iter=10000
55
158
  Valid values mask: `True` where data should be filled. Must have same shape
56
159
  as `var`.
57
160
 
161
+ fillvalue: float
162
+ Value to use if the full field `var` contains only NaNs. Default is 0.0.
163
+
58
164
  tol : float, optional, default=1.0e-4
59
165
  Convergence criteria: stop filling when the value change is less than
60
166
  or equal to `tol * var`, i.e., `delta <= tol * np.abs(var[j, i])`.
@@ -90,14 +196,14 @@ def _lateral_fill_np_array(var, isvalid_mask, tol=1.0e-4, rc=1.8, max_iter=10000
90
196
 
91
197
  fillmask = np.isnan(var) # Fill all NaNs
92
198
  keepNaNs = ~isvalid_mask & np.isnan(var)
93
- var = _iterative_fill_sor(nlat, nlon, var, fillmask, tol, rc, max_iter)
199
+ var = _iterative_fill_sor(nlat, nlon, var, fillmask, tol, rc, max_iter, fillvalue)
94
200
  var[keepNaNs] = np.nan # Replace NaNs in areas not designated for filling
95
201
 
96
202
  return var
97
203
 
98
204
 
99
205
  @jit(nopython=True, parallel=True)
100
- def _iterative_fill_sor(nlat, nlon, var, fillmask, tol, rc, max_iter):
206
+ def _iterative_fill_sor(nlat, nlon, var, fillmask, tol, rc, max_iter, fillvalue=0.0):
101
207
  """
102
208
  Perform an iterative land fill algorithm using the Successive Over-Relaxation (SOR)
103
209
  solution of the Laplace Equation.
@@ -126,6 +232,9 @@ def _iterative_fill_sor(nlat, nlon, var, fillmask, tol, rc, max_iter):
126
232
  max_iter : int
127
233
  Maximum number of iterations allowed before the process is terminated.
128
234
 
235
+ fillvalue: float
236
+ Value to use if the full field is NaNs. Default is 0.0.
237
+
129
238
  Returns
130
239
  -------
131
240
  None
@@ -155,6 +264,10 @@ def _iterative_fill_sor(nlat, nlon, var, fillmask, tol, rc, max_iter):
155
264
  if np.max(np.fabs(var)) == 0.0:
156
265
  var = np.zeros_like(var)
157
266
  return var
267
+ # If field consists only of NaNs, fill NaNs with fill value
268
+ if np.isnan(var).all():
269
+ var = fillvalue * np.ones_like(var)
270
+ return var
158
271
 
159
272
  # Compute a zonal mean to use as a first guess
160
273
  zoncnt = np.zeros(nlat)