roms-tools 2.2.0__py3-none-any.whl → 2.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34) hide show
  1. roms_tools/__init__.py +1 -0
  2. roms_tools/analysis/roms_output.py +586 -0
  3. roms_tools/{setup/download.py → download.py} +3 -0
  4. roms_tools/{setup/plot.py → plot.py} +34 -28
  5. roms_tools/setup/boundary_forcing.py +23 -12
  6. roms_tools/setup/datasets.py +2 -135
  7. roms_tools/setup/grid.py +54 -15
  8. roms_tools/setup/initial_conditions.py +105 -149
  9. roms_tools/setup/nesting.py +4 -4
  10. roms_tools/setup/river_forcing.py +7 -9
  11. roms_tools/setup/surface_forcing.py +14 -14
  12. roms_tools/setup/tides.py +24 -21
  13. roms_tools/setup/topography.py +1 -1
  14. roms_tools/setup/utils.py +19 -143
  15. roms_tools/tests/test_analysis/test_roms_output.py +269 -0
  16. roms_tools/tests/{test_setup/test_regrid.py → test_regrid.py} +1 -1
  17. roms_tools/tests/test_setup/test_boundary_forcing.py +1 -1
  18. roms_tools/tests/test_setup/test_datasets.py +1 -1
  19. roms_tools/tests/test_setup/test_grid.py +1 -1
  20. roms_tools/tests/test_setup/test_initial_conditions.py +8 -4
  21. roms_tools/tests/test_setup/test_river_forcing.py +1 -1
  22. roms_tools/tests/test_setup/test_surface_forcing.py +1 -1
  23. roms_tools/tests/test_setup/test_tides.py +1 -1
  24. roms_tools/tests/test_setup/test_topography.py +1 -1
  25. roms_tools/tests/test_setup/test_utils.py +56 -1
  26. roms_tools/utils.py +301 -0
  27. roms_tools/vertical_coordinate.py +306 -0
  28. {roms_tools-2.2.0.dist-info → roms_tools-2.3.0.dist-info}/METADATA +1 -1
  29. {roms_tools-2.2.0.dist-info → roms_tools-2.3.0.dist-info}/RECORD +33 -31
  30. roms_tools/setup/vertical_coordinate.py +0 -109
  31. /roms_tools/{setup/regrid.py → regrid.py} +0 -0
  32. {roms_tools-2.2.0.dist-info → roms_tools-2.3.0.dist-info}/LICENSE +0 -0
  33. {roms_tools-2.2.0.dist-info → roms_tools-2.3.0.dist-info}/WHEEL +0 -0
  34. {roms_tools-2.2.0.dist-info → roms_tools-2.3.0.dist-info}/top_level.txt +0 -0
roms_tools/utils.py CHANGED
@@ -4,6 +4,8 @@ import numpy as np
4
4
  import xarray as xr
5
5
  from typing import Union
6
6
  from pathlib import Path
7
+ import re
8
+ import glob
7
9
 
8
10
 
9
11
  def partition(
@@ -333,3 +335,302 @@ def partition_netcdf(
333
335
  xr.save_mfdataset(partitioned_datasets, paths_to_partitioned_files)
334
336
 
335
337
  return paths_to_partitioned_files
338
+
339
+
340
+ def _load_data(
341
+ filename,
342
+ dim_names,
343
+ use_dask,
344
+ time_chunking=True,
345
+ decode_times=True,
346
+ force_combine_nested=False,
347
+ ):
348
+ """Load dataset from the specified file.
349
+
350
+ Parameters
351
+ ----------
352
+ filename : Union[str, Path, List[Union[str, Path]]]
353
+ The path to the data file(s). Can be a single string (with or without wildcards), a single Path object,
354
+ or a list of strings or Path objects containing multiple files.
355
+ dim_names : Dict[str, str], optional
356
+ Dictionary specifying the names of dimensions in the dataset.
357
+ Required only for lat-lon datasets to map dimension names like "latitude" and "longitude".
358
+ For ROMS datasets, this parameter can be omitted, as default ROMS dimensions ("eta_rho", "xi_rho", "s_rho") are assumed.
359
+ use_dask: bool
360
+ Indicates whether to use dask for chunking. If True, data is loaded with dask; if False, data is loaded eagerly. Defaults to False.
361
+ time_chunking : bool, optional
362
+ If True and `use_dask=True`, the data will be chunked along the time dimension with a chunk size of 1.
363
+ If False, the data will not be chunked explicitly along the time dimension, but will follow the default auto chunking scheme. This option is useful for ROMS restart files.
364
+ Defaults to True.
365
+ decode_times: bool, optional
366
+ If True, decode times encoded in the standard NetCDF datetime format into datetime objects. Otherwise, leave them encoded as numbers.
367
+ Defaults to True.
368
+ force_combine_nested: bool, optional
369
+ If True, forces the use of nested combination (`combine_nested`) regardless of whether wildcards are used.
370
+ Defaults to False.
371
+
372
+ Returns
373
+ -------
374
+ ds : xr.Dataset
375
+ The loaded xarray Dataset containing the forcing data.
376
+
377
+ Raises
378
+ ------
379
+ FileNotFoundError
380
+ If the specified file does not exist.
381
+ ValueError
382
+ If a list of files is provided but dim_names["time"] is not available or use_dask=False.
383
+ """
384
+ if dim_names is None:
385
+ dim_names = {}
386
+
387
+ # Precompile the regex for matching wildcard characters
388
+ wildcard_regex = re.compile(r"[\*\?\[\]]")
389
+
390
+ # Convert Path objects to strings
391
+ if isinstance(filename, (str, Path)):
392
+ filename_str = str(filename)
393
+ elif isinstance(filename, list):
394
+ filename_str = [str(f) for f in filename]
395
+ else:
396
+ raise ValueError("filename must be a string, Path, or a list of strings/Paths.")
397
+
398
+ # Handle the case when filename is a string
399
+ contains_wildcard = False
400
+ if isinstance(filename_str, str):
401
+ contains_wildcard = bool(wildcard_regex.search(filename_str))
402
+ if contains_wildcard:
403
+ matching_files = glob.glob(filename_str)
404
+ if not matching_files:
405
+ raise FileNotFoundError(
406
+ f"No files found matching the pattern '{filename_str}'."
407
+ )
408
+ else:
409
+ matching_files = [filename_str]
410
+
411
+ # Handle the case when filename is a list
412
+ elif isinstance(filename_str, list):
413
+ contains_wildcard = any(wildcard_regex.search(f) for f in filename_str)
414
+ if contains_wildcard:
415
+ matching_files = []
416
+ for f in filename_str:
417
+ files = glob.glob(f)
418
+ if not files:
419
+ raise FileNotFoundError(
420
+ f"No files found matching the pattern '{f}'."
421
+ )
422
+ matching_files.extend(files)
423
+ else:
424
+ matching_files = filename_str
425
+
426
+ # Sort the matching files
427
+ matching_files = sorted(matching_files)
428
+
429
+ # Check if time dimension is available when multiple files are provided
430
+ if isinstance(filename_str, list) and "time" not in dim_names:
431
+ raise ValueError(
432
+ "A list of files is provided, but time dimension is not available. "
433
+ "A time dimension must be available to concatenate the files."
434
+ )
435
+
436
+ # Determine the kwargs for combining datasets
437
+ if force_combine_nested:
438
+ kwargs = {"combine": "nested", "concat_dim": dim_names["time"]}
439
+ elif contains_wildcard or len(matching_files) == 1:
440
+ kwargs = {"combine": "by_coords"}
441
+ else:
442
+ kwargs = {"combine": "nested", "concat_dim": dim_names["time"]}
443
+
444
+ # Base kwargs used for dataset combination
445
+ combine_kwargs = {
446
+ "coords": "minimal",
447
+ "compat": "override",
448
+ "combine_attrs": "override",
449
+ }
450
+
451
+ if use_dask:
452
+
453
+ if "latitude" in dim_names and "longitude" in dim_names:
454
+ # for lat-lon datasets
455
+ chunks = {
456
+ dim_names["latitude"]: -1,
457
+ dim_names["longitude"]: -1,
458
+ }
459
+ else:
460
+ # For ROMS datasets
461
+ chunks = {
462
+ "eta_rho": -1,
463
+ "eta_v": -1,
464
+ "xi_rho": -1,
465
+ "xi_u": -1,
466
+ "s_rho": -1,
467
+ }
468
+
469
+ if "depth" in dim_names:
470
+ chunks[dim_names["depth"]] = -1
471
+ if "time" in dim_names and time_chunking:
472
+ chunks[dim_names["time"]] = 1
473
+
474
+ ds = xr.open_mfdataset(
475
+ matching_files,
476
+ decode_times=decode_times,
477
+ chunks=chunks,
478
+ **combine_kwargs,
479
+ **kwargs,
480
+ )
481
+
482
+ # Rechunk the dataset along the tidal constituent dimension ("ntides") after loading
483
+ # because the original dataset does not have a chunk size of 1 along this dimension.
484
+ if "ntides" in dim_names:
485
+ ds = ds.chunk({dim_names["ntides"]: 1})
486
+
487
+ else:
488
+ ds_list = []
489
+ for file in matching_files:
490
+ ds = xr.open_dataset(file, decode_times=decode_times, chunks=None)
491
+ ds_list.append(ds)
492
+
493
+ if kwargs["combine"] == "by_coords":
494
+ ds = xr.combine_by_coords(ds_list, **combine_kwargs)
495
+ elif kwargs["combine"] == "nested":
496
+ ds = xr.combine_nested(
497
+ ds_list, concat_dim=kwargs["concat_dim"], **combine_kwargs
498
+ )
499
+
500
+ if "time" in dim_names and dim_names["time"] not in ds.dims:
501
+ ds = ds.expand_dims(dim_names["time"])
502
+
503
+ return ds
504
+
505
+
506
def interpolate_from_rho_to_u(field, method="additive"):
    """Move a field from rho points to u points along the xi direction.

    Each output value combines two rho values that are adjacent along "xi_rho":
    their mean for the additive method, their product for the multiplicative
    method. Selected rho coordinates are removed from the result and the
    "xi_rho" dimension is renamed to "xi_u".

    Parameters
    ----------
    field : xr.DataArray
        Data array on the rho grid. It must have a dimension named "xi_rho".

    method : str, optional, default='additive'
        How adjacent rho values are combined:
        - 'additive': average of the two neighbouring values.
        - 'multiplicative': product of the two neighbouring values. Appropriate
          for binary masks.

    Returns
    -------
    field_interpolated : xr.DataArray
        The combined data array, one entry shorter along the xi direction, with
        the dimension "xi_u".

    Raises
    ------
    NotImplementedError
        If `method` is neither 'additive' nor 'multiplicative'.
    """
    # Reject unknown methods before touching the data.
    if method not in ("additive", "multiplicative"):
        raise NotImplementedError(f"Unsupported method '{method}' specified.")

    neighbour = field.shift(xi_rho=1)
    if method == "additive":
        combined = 0.5 * (field + neighbour)
    else:
        combined = field * neighbour

    # The first entry along xi_rho has no shifted partner, so it is discarded.
    result = combined.isel(xi_rho=slice(1, None))

    # Remove coordinates that belong to the rho grid, if they are present.
    for name in ("lat_rho", "lon_rho", "eta_rho", "xi_rho"):
        if name not in result.coords:
            continue
        result = result.drop_vars(name)

    return result.swap_dims({"xi_rho": "xi_u"})
550
+
551
+
552
def interpolate_from_rho_to_v(field, method="additive"):
    """Move a field from rho points to v points along the eta direction.

    Each output value combines two rho values that are adjacent along "eta_rho":
    their mean for the additive method, their product for the multiplicative
    method. Selected rho coordinates are removed from the result and the
    "eta_rho" dimension is renamed to "eta_v".

    Parameters
    ----------
    field : xr.DataArray
        Data array on the rho grid. It must have a dimension named "eta_rho".

    method : str, optional, default='additive'
        How adjacent rho values are combined:
        - 'additive': average of the two neighbouring values.
        - 'multiplicative': product of the two neighbouring values. Appropriate
          for binary masks.

    Returns
    -------
    field_interpolated : xr.DataArray
        The combined data array, one entry shorter along the eta direction, with
        the dimension "eta_v".

    Raises
    ------
    NotImplementedError
        If `method` is neither 'additive' nor 'multiplicative'.
    """
    # Reject unknown methods before touching the data.
    if method not in ("additive", "multiplicative"):
        raise NotImplementedError(f"Unsupported method '{method}' specified.")

    neighbour = field.shift(eta_rho=1)
    if method == "additive":
        combined = 0.5 * (field + neighbour)
    else:
        combined = field * neighbour

    # The first entry along eta_rho has no shifted partner, so it is discarded.
    result = combined.isel(eta_rho=slice(1, None))

    # Remove coordinates that belong to the rho grid, if they are present.
    for name in ("lat_rho", "lon_rho", "eta_rho", "xi_rho"):
        if name not in result.coords:
            continue
        result = result.drop_vars(name)

    return result.swap_dims({"eta_rho": "eta_v"})
598
+
599
+
600
def transpose_dimensions(da: xr.DataArray) -> xr.DataArray:
    """Reorder the dimensions of a DataArray into a canonical order.

    Dimensions whose names start with "time", "s_", "eta_" or "xi_" are moved to
    the front, grouped in exactly that prefix order. All other dimensions follow
    in the order in which they appear in the input.

    Parameters
    ----------
    da : xarray.DataArray
        The input DataArray whose dimensions are to be reordered.

    Returns
    -------
    xarray.DataArray
        The transposed DataArray.
    """
    prefixes = ("time", "s_", "eta_", "xi_")
    current = list(da.dims)

    # Outer loop over prefixes so the groups come out in prefix order; within a
    # group the original relative order of the dimensions is kept.
    leading = [
        name for prefix in prefixes for name in current if name.startswith(prefix)
    ]
    trailing = [name for name in current if name not in leading]

    return da.transpose(*leading, *trailing)
@@ -0,0 +1,306 @@
1
+ import numpy as np
2
+ import xarray as xr
3
+ from roms_tools.utils import (
4
+ transpose_dimensions,
5
+ interpolate_from_rho_to_u,
6
+ interpolate_from_rho_to_v,
7
+ )
8
+
9
+
10
def compute_cs(sigma, theta_s, theta_b):
    """Evaluate the S-coordinate stretching curve of Shchepetkin and McWilliams
    (2009).

    Parameters
    ----------
    sigma : np.ndarray or float
        The sigma-coordinate values.
    theta_s : float
        The surface control parameter; must satisfy 0 < theta_s <= 10.
    theta_b : float
        The bottom control parameter; must satisfy 0 < theta_b <= 4.

    Returns
    -------
    C : np.ndarray or float
        The stretching curve evaluated at `sigma`.

    Raises
    ------
    ValueError
        If theta_s or theta_b lies outside its valid range.
    """
    surface_ok = 0 < theta_s <= 10
    if not surface_ok:
        raise ValueError("theta_s must be between 0 and 10.")

    bottom_ok = 0 < theta_b <= 4
    if not bottom_ok:
        raise ValueError("theta_b must be between 0 and 4.")

    # First stage: stretching controlled by theta_s.
    surface_stretch = (1 - np.cosh(theta_s * sigma)) / (np.cosh(theta_s) - 1)

    # Second stage: the theta_b-controlled transform of the first-stage curve.
    numerator = np.exp(theta_b * surface_stretch) - 1
    denominator = 1 - np.exp(-theta_b)

    return numerator / denominator
42
+
43
+
44
def sigma_stretch(theta_s, theta_b, N, type):
    """Build the sigma coordinate and its stretching curve for one vertical grid.

    Parameters
    ----------
    theta_s : float
        The surface control parameter.
    theta_b : float
        The bottom control parameter.
    N : int
        The number of vertical levels.
    type : str
        The vertical grid: 'w' for vertical velocity points (N + 1 values along
        "s_w") or 'r' for rho-points (N values along "s_rho").

    Returns
    -------
    cs : xr.DataArray
        The stretching curve values.
    sigma : xr.DataArray
        The sigma-coordinate values.

    Raises
    ------
    ValueError
        If the type is not 'w' or 'r'.
    """
    # Guard clause: only the two known vertical grids are supported.
    if type not in ("w", "r"):
        raise ValueError(
            "Type must be either 'w' for vertical velocity points or 'r' for rho-points."
        )

    if type == "w":
        # Indices 0..N along "s_w".
        level = xr.DataArray(np.arange(N + 1), dims="s_w")
        sigma = (level - N) / N
    else:
        # Indices 1..N along "s_rho", shifted by half a level.
        level = xr.DataArray(np.arange(1, N + 1), dims="s_rho")
        sigma = (level - N - 0.5) / N

    return compute_cs(sigma, theta_s, theta_b), sigma
84
+
85
+
86
def compute_depth(zeta, h, hc, cs, sigma):
    """Compute the depth at the given sigma levels.

    Parameters
    ----------
    zeta : xr.DataArray or scalar
        The sea surface height.
    h : xr.DataArray
        The depth of the sea bottom.
    hc : float
        The critical depth.
    cs : xr.DataArray
        The stretching curve values.
    sigma : xr.DataArray
        The sigma-coordinate values.

    Returns
    -------
    z : xr.DataArray
        The negative of ``zeta + (zeta + h) * S`` with
        ``S = (hc * sigma + h * cs) / (hc + h)``, with its dimensions reordered
        by `transpose_dimensions`.
    """
    # Combined vertical stretching factor S.
    stretch = (hc * sigma + h * cs) / (hc + h)

    # Vertical position relative to the free surface.
    elevation = zeta + (zeta + h) * stretch

    # NOTE(review): the sign flip presumably makes depths positive downward — confirm.
    return -transpose_dimensions(elevation)
114
+
115
+
116
def add_depth_coordinates_to_dataset(
    ds: "xr.Dataset",
    grid_ds: "xr.Dataset",
    depth_type: str,
    locations: list[str] = ["rho", "u", "v"],
) -> None:
    """Add computed vertical depth coordinates to a dataset for specified grid
    locations.

    The dataset `ds` is modified in place. If it already contains
    ``"{depth_type}_depth_{loc}"`` for every requested location, the function
    returns without doing anything. Otherwise the rho-point depth is taken from
    `ds` if present, or computed from `grid_ds` and stored in `ds`; the depths at
    the other requested locations are then interpolated from it and stored.

    Parameters
    ----------
    ds : xr.Dataset
        Target dataset to which computed depth coordinates will be added.
        If the `zeta` variable is not present, a sea surface height of 0 is used.

    grid_ds : xr.Dataset
        Grid dataset providing the bathymetry ``h``, the stretching curves
        ``Cs_r``/``Cs_w``, the sigma coordinates ``sigma_r``/``sigma_w`` and the
        attribute ``hc``.

    depth_type : str
        Type of depth coordinates to compute. Options are:
        - "layer": Layer depth coordinates.
        - "interface": Interface depth coordinates.

    locations : list of str, optional
        List of locations for which to compute depth coordinates. Each entry must
        be "rho", "u" or "v". Default is ["rho", "u", "v"]. The list is only read,
        never modified.

    Raises
    ------
    ValueError
        If a location is not "rho", "u" or "v", or if the rho-point depth has to
        be computed and `depth_type` is neither "layer" nor "interface".
    """
    required_vars = [f"{depth_type}_depth_{loc}" for loc in locations]

    if all(var in ds for var in required_vars):
        return  # Depth coordinates already exist

    # Reject unknown locations before anything is written to `ds`; otherwise
    # they would silently be interpolated with the v-point routine.
    interpolators = {"u": interpolate_from_rho_to_u, "v": interpolate_from_rho_to_v} if any(
        loc in ("u", "v") for loc in locations
    ) else {}
    unknown = [loc for loc in locations if loc not in ("rho", "u", "v")]
    if unknown:
        raise ValueError(
            f"Unsupported location(s) {unknown}. Supported locations are 'rho', 'u' and 'v'."
        )

    rho_name = f"{depth_type}_depth_rho"
    if rho_name in ds:
        depth_rho = ds[rho_name]
    else:
        # Names of the stretching curve and sigma coordinate for each depth type.
        stretching_vars = {
            "layer": ("Cs_r", "sigma_r"),
            "interface": ("Cs_w", "sigma_w"),
        }
        if depth_type not in stretching_vars:
            raise ValueError(
                f"Unsupported depth_type '{depth_type}'. Supported types are 'layer' and 'interface'."
            )
        cs_name, sigma_name = stretching_vars[depth_type]

        h = grid_ds["h"]
        zeta = ds.get("zeta", 0)
        Cs = grid_ds[cs_name]
        sigma = grid_ds[sigma_name]

        depth_rho = compute_depth(zeta, h, grid_ds.attrs["hc"], Cs, sigma)
        depth_rho.attrs.update(
            {"long_name": f"{depth_type} depth at rho-points", "units": "m"}
        )
        ds[rho_name] = depth_rho

    # Interpolate the rho-point depth to the other requested locations.
    for loc in locations:
        if loc == "rho":
            continue

        depth_loc = interpolators[loc](depth_rho)
        depth_loc.attrs.update(
            {"long_name": f"{depth_type} depth at {loc}-points", "units": "m"}
        )
        ds[f"{depth_type}_depth_{loc}"] = depth_loc
183
+
184
+
185
def compute_depth_coordinates(
    ds: "xr.Dataset",
    grid_ds: "xr.Dataset",
    depth_type: str,
    location: str,
    s: int = None,
    eta: int = None,
    xi: int = None,
) -> "xr.DataArray":
    """Compute vertical depth coordinates for a specified grid location and
    optional indices.

    The bathymetry and, if present, the free-surface elevation (`zeta`) are first
    interpolated to the requested location and sliced horizontally. Only then are
    the depth coordinates computed, so the vertical calculation runs on the
    reduced data. The vertical slice is applied to the computed result.

    Parameters
    ----------
    ds : xr.Dataset
        Dataset containing optional `zeta` (free-surface elevation). If `zeta` is
        not present, a free-surface elevation of 0 is used, which yields static
        vertical coordinates.

    grid_ds : xr.Dataset
        Grid dataset providing the bathymetry ``h``, the stretching curves
        ``Cs_r``/``Cs_w``, the sigma coordinates ``sigma_r``/``sigma_w`` and the
        attribute ``hc`` (critical depth).

    depth_type : str
        Type of depth coordinates to compute:
        - `"layer"`: Depth at the center of layers.
        - `"interface"`: Depth at layer interfaces.

    location : str
        Grid location for the computation. Options are:
        - `"rho"`: Depth at rho points (cell centers).
        - `"u"`: Depth at u points.
        - `"v"`: Depth at v points.

    s : int, optional
        Vertical index to extract a single layer or interface slice. If not
        provided, all vertical levels are included.

    eta : int, optional
        Index along the eta dimension ("eta_v" for v points, "eta_rho" otherwise).
        If not provided, all indices are included.

    xi : int, optional
        Index along the xi dimension ("xi_u" for u points, "xi_rho" otherwise).
        If not provided, all indices are included.

    Returns
    -------
    xr.DataArray
        The computed depth coordinates, with ``long_name`` and ``units``
        attributes set. Dimensions for which an index was given are removed from
        the result.

    Raises
    ------
    ValueError
        If `depth_type` is neither "layer" nor "interface", or if `location` is
        not "rho", "u" or "v".
    """
    # Names of the stretching curve and sigma coordinate for each depth type.
    stretching_vars = {
        "layer": ("Cs_r", "sigma_r"),
        "interface": ("Cs_w", "sigma_w"),
    }
    if depth_type not in stretching_vars:
        raise ValueError(
            f"Unsupported depth_type '{depth_type}'. Supported types are 'layer' and 'interface'."
        )
    if location not in ("rho", "u", "v"):
        raise ValueError(
            f"Unsupported location '{location}'. Supported locations are 'rho', 'u' and 'v'."
        )

    h = grid_ds["h"]
    zeta = ds.get("zeta", None)

    # Interpolate h and zeta to the specified location.
    if location == "u":
        h = interpolate_from_rho_to_u(h)
        if zeta is not None:
            zeta = interpolate_from_rho_to_u(zeta)
    elif location == "v":
        h = interpolate_from_rho_to_v(h)
        if zeta is not None:
            zeta = interpolate_from_rho_to_v(zeta)

    # After interpolation, v points live on "eta_v" and u points on "xi_u";
    # every other combination keeps the rho dimension names.
    eta_dim = "eta_v" if location == "v" else "eta_rho"
    xi_dim = "xi_u" if location == "u" else "xi_rho"

    # Slice horizontally before the vertical computation.
    if eta is not None:
        h = h.isel({eta_dim: eta})
        if zeta is not None:
            zeta = zeta.isel({eta_dim: eta})
    if xi is not None:
        h = h.isel({xi_dim: xi})
        if zeta is not None:
            zeta = zeta.isel({xi_dim: xi})

    # Without a free-surface elevation, fall back to zeta = 0. Passing None on
    # would make the arithmetic in compute_depth fail.
    if zeta is None:
        zeta = 0

    # Compute depth.
    cs_name, sigma_name = stretching_vars[depth_type]
    Cs = grid_ds[cs_name]
    sigma = grid_ds[sigma_name]
    depth = compute_depth(zeta, h, grid_ds.attrs["hc"], Cs, sigma)

    # Slice vertically along whichever vertical dimension the result carries.
    if s is not None:
        vertical_dim = "s_rho" if "s_rho" in depth.dims else "s_w"
        depth = depth.isel({vertical_dim: s})

    # Add metadata.
    depth.attrs.update(
        {"long_name": f"{depth_type} depth at {location}-points", "units": "m"}
    )

    return depth
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: roms-tools
3
- Version: 2.2.0
3
+ Version: 2.3.0
4
4
  Summary: Tools for running and analysing UCLA-ROMS simulations
5
5
  Author-email: Nora Loose <nora.loose@gmail.com>, Thomas Nicholas <tom@cworthy.org>
6
6
  License: Apache-2