roms-tools 2.2.1__py3-none-any.whl → 2.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- roms_tools/__init__.py +1 -0
- roms_tools/analysis/roms_output.py +586 -0
- roms_tools/{setup/download.py → download.py} +3 -0
- roms_tools/{setup/plot.py → plot.py} +34 -28
- roms_tools/setup/boundary_forcing.py +23 -12
- roms_tools/setup/datasets.py +2 -135
- roms_tools/setup/grid.py +54 -15
- roms_tools/setup/initial_conditions.py +105 -149
- roms_tools/setup/nesting.py +4 -4
- roms_tools/setup/river_forcing.py +7 -9
- roms_tools/setup/surface_forcing.py +14 -14
- roms_tools/setup/tides.py +24 -21
- roms_tools/setup/topography.py +1 -1
- roms_tools/setup/utils.py +20 -154
- roms_tools/tests/test_analysis/test_roms_output.py +269 -0
- roms_tools/tests/{test_setup/test_regrid.py → test_regrid.py} +1 -1
- roms_tools/tests/test_setup/test_boundary_forcing.py +1 -1
- roms_tools/tests/test_setup/test_datasets.py +1 -1
- roms_tools/tests/test_setup/test_grid.py +1 -1
- roms_tools/tests/test_setup/test_initial_conditions.py +1 -1
- roms_tools/tests/test_setup/test_river_forcing.py +1 -1
- roms_tools/tests/test_setup/test_surface_forcing.py +1 -1
- roms_tools/tests/test_setup/test_tides.py +1 -1
- roms_tools/tests/test_setup/test_topography.py +1 -1
- roms_tools/tests/test_setup/test_utils.py +56 -1
- roms_tools/utils.py +301 -0
- roms_tools/vertical_coordinate.py +306 -0
- {roms_tools-2.2.1.dist-info → roms_tools-2.3.0.dist-info}/METADATA +1 -1
- {roms_tools-2.2.1.dist-info → roms_tools-2.3.0.dist-info}/RECORD +33 -31
- roms_tools/setup/vertical_coordinate.py +0 -109
- /roms_tools/{setup/regrid.py → regrid.py} +0 -0
- {roms_tools-2.2.1.dist-info → roms_tools-2.3.0.dist-info}/LICENSE +0 -0
- {roms_tools-2.2.1.dist-info → roms_tools-2.3.0.dist-info}/WHEEL +0 -0
- {roms_tools-2.2.1.dist-info → roms_tools-2.3.0.dist-info}/top_level.txt +0 -0
roms_tools/utils.py
CHANGED
@@ -4,6 +4,8 @@ import numpy as np
 import xarray as xr
 from typing import Union
 from pathlib import Path
+import re
+import glob


 def partition(
@@ -333,3 +335,302 @@ def partition_netcdf(
     xr.save_mfdataset(partitioned_datasets, paths_to_partitioned_files)

     return paths_to_partitioned_files
+
+
+def _load_data(
+    filename,
+    dim_names,
+    use_dask,
+    time_chunking=True,
+    decode_times=True,
+    force_combine_nested=False,
+):
+    """Load dataset from the specified file.
+
+    Parameters
+    ----------
+    filename : Union[str, Path, List[Union[str, Path]]]
+        The path to the data file(s). Can be a single string (with or without wildcards),
+        a single Path object, or a list of strings or Path objects containing multiple files.
+    dim_names : Dict[str, str], optional
+        Dictionary specifying the names of dimensions in the dataset.
+        Required only for lat-lon datasets to map dimension names like "latitude" and "longitude".
+        For ROMS datasets, this parameter can be omitted, as default ROMS dimensions
+        ("eta_rho", "xi_rho", "s_rho") are assumed.
+    use_dask : bool
+        Indicates whether to use dask for chunking. If True, data is loaded with dask;
+        if False, data is loaded eagerly. Defaults to False.
+    time_chunking : bool, optional
+        If True and `use_dask=True`, the data will be chunked along the time dimension
+        with a chunk size of 1. If False, the data will not be chunked explicitly along
+        the time dimension, but will follow the default auto chunking scheme. This option
+        is useful for ROMS restart files. Defaults to True.
+    decode_times : bool, optional
+        If True, decode times encoded in the standard NetCDF datetime format into
+        datetime objects. Otherwise, leave them encoded as numbers. Defaults to True.
+    force_combine_nested : bool, optional
+        If True, forces the use of nested combination (`combine_nested`) regardless of
+        whether wildcards are used. Defaults to False.
+
+    Returns
+    -------
+    ds : xr.Dataset
+        The loaded xarray Dataset containing the forcing data.
+
+    Raises
+    ------
+    FileNotFoundError
+        If the specified file does not exist.
+    ValueError
+        If a list of files is provided but dim_names["time"] is not available or use_dask=False.
+    """
+    if dim_names is None:
+        dim_names = {}
+
+    # Precompile the regex for matching wildcard characters
+    wildcard_regex = re.compile(r"[\*\?\[\]]")
+
+    # Convert Path objects to strings
+    if isinstance(filename, (str, Path)):
+        filename_str = str(filename)
+    elif isinstance(filename, list):
+        filename_str = [str(f) for f in filename]
+    else:
+        raise ValueError("filename must be a string, Path, or a list of strings/Paths.")
+
+    # Handle the case when filename is a string
+    contains_wildcard = False
+    if isinstance(filename_str, str):
+        contains_wildcard = bool(wildcard_regex.search(filename_str))
+        if contains_wildcard:
+            matching_files = glob.glob(filename_str)
+            if not matching_files:
+                raise FileNotFoundError(
+                    f"No files found matching the pattern '{filename_str}'."
+                )
+        else:
+            matching_files = [filename_str]
+
+    # Handle the case when filename is a list
+    elif isinstance(filename_str, list):
+        contains_wildcard = any(wildcard_regex.search(f) for f in filename_str)
+        if contains_wildcard:
+            matching_files = []
+            for f in filename_str:
+                files = glob.glob(f)
+                if not files:
+                    raise FileNotFoundError(
+                        f"No files found matching the pattern '{f}'."
+                    )
+                matching_files.extend(files)
+        else:
+            matching_files = filename_str
+
+    # Sort the matching files
+    matching_files = sorted(matching_files)
+
+    # Check if time dimension is available when multiple files are provided
+    if isinstance(filename_str, list) and "time" not in dim_names:
+        raise ValueError(
+            "A list of files is provided, but time dimension is not available. "
+            "A time dimension must be available to concatenate the files."
+        )
+
+    # Determine the kwargs for combining datasets
+    if force_combine_nested:
+        kwargs = {"combine": "nested", "concat_dim": dim_names["time"]}
+    elif contains_wildcard or len(matching_files) == 1:
+        kwargs = {"combine": "by_coords"}
+    else:
+        kwargs = {"combine": "nested", "concat_dim": dim_names["time"]}
+
+    # Base kwargs used for dataset combination
+    combine_kwargs = {
+        "coords": "minimal",
+        "compat": "override",
+        "combine_attrs": "override",
+    }
+
+    if use_dask:
+        if "latitude" in dim_names and "longitude" in dim_names:
+            # For lat-lon datasets
+            chunks = {
+                dim_names["latitude"]: -1,
+                dim_names["longitude"]: -1,
+            }
+        else:
+            # For ROMS datasets
+            chunks = {
+                "eta_rho": -1,
+                "eta_v": -1,
+                "xi_rho": -1,
+                "xi_u": -1,
+                "s_rho": -1,
+            }
+
+        if "depth" in dim_names:
+            chunks[dim_names["depth"]] = -1
+        if "time" in dim_names and time_chunking:
+            chunks[dim_names["time"]] = 1
+
+        ds = xr.open_mfdataset(
+            matching_files,
+            decode_times=decode_times,
+            chunks=chunks,
+            **combine_kwargs,
+            **kwargs,
+        )
+
+        # Rechunk the dataset along the tidal constituent dimension ("ntides") after loading
+        # because the original dataset does not have a chunk size of 1 along this dimension.
+        if "ntides" in dim_names:
+            ds = ds.chunk({dim_names["ntides"]: 1})
+
+    else:
+        ds_list = []
+        for file in matching_files:
+            ds = xr.open_dataset(file, decode_times=decode_times, chunks=None)
+            ds_list.append(ds)
+
+        if kwargs["combine"] == "by_coords":
+            ds = xr.combine_by_coords(ds_list, **combine_kwargs)
+        elif kwargs["combine"] == "nested":
+            ds = xr.combine_nested(
+                ds_list, concat_dim=kwargs["concat_dim"], **combine_kwargs
+            )
+
+    if "time" in dim_names and dim_names["time"] not in ds.dims:
+        ds = ds.expand_dims(dim_names["time"])
+
+    return ds
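To illustrate how the new helper behaves, here is a minimal, self-contained sketch that writes two tiny NetCDF files and loads them through `_load_data`. The file names and dimension names are illustrative only, and `_load_data` is a private helper whose signature may change:

import numpy as np
import xarray as xr
from roms_tools.utils import _load_data

# Write two single-time-step files that match the wildcard below (illustrative names)
for i, t in enumerate([0.0, 1.0]):
    xr.Dataset(
        {"temp": (("time", "lat", "lon"), np.zeros((1, 2, 2)))},
        coords={"time": [t], "lat": [0.0, 1.0], "lon": [0.0, 1.0]},
    ).to_netcdf(f"example_{i}.nc")

# Wildcard input: files are expanded with glob and combined "by_coords";
# with use_dask=True the time dimension gets a chunk size of 1
ds = _load_data(
    "example_*.nc",
    dim_names={"latitude": "lat", "longitude": "lon", "time": "time"},
    use_dask=True,
)
assert ds.sizes["time"] == 2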
+def interpolate_from_rho_to_u(field, method="additive"):
+    """Interpolates the given field from rho points to u points.
+
+    This function performs an interpolation from the rho grid (cell centers) to the u grid
+    (cell edges in the xi direction). Depending on the chosen method, it either averages
+    (additive) or multiplies (multiplicative) the field values between adjacent rho points
+    along the xi dimension. It also handles the removal of unnecessary coordinate variables
+    and updates the dimensions accordingly.
+
+    Parameters
+    ----------
+    field : xr.DataArray
+        The input data array on the rho grid to be interpolated. It is assumed to have a
+        dimension named "xi_rho".
+    method : str, optional, default='additive'
+        The method to use for interpolation. Options are:
+        - 'additive': Average the field values between adjacent rho points.
+        - 'multiplicative': Multiply the field values between adjacent rho points.
+          Appropriate for binary masks.
+
+    Returns
+    -------
+    field_interpolated : xr.DataArray
+        The interpolated data array on the u grid with the dimension "xi_u".
+    """
+    if method == "additive":
+        field_interpolated = 0.5 * (field + field.shift(xi_rho=1)).isel(
+            xi_rho=slice(1, None)
+        )
+    elif method == "multiplicative":
+        field_interpolated = (field * field.shift(xi_rho=1)).isel(xi_rho=slice(1, None))
+    else:
+        raise NotImplementedError(f"Unsupported method '{method}' specified.")
+
+    vars_to_drop = ["lat_rho", "lon_rho", "eta_rho", "xi_rho"]
+    for var in vars_to_drop:
+        if var in field_interpolated.coords:
+            field_interpolated = field_interpolated.drop_vars(var)
+
+    field_interpolated = field_interpolated.swap_dims({"xi_rho": "xi_u"})
+
+    return field_interpolated
+def interpolate_from_rho_to_v(field, method="additive"):
+    """Interpolates the given field from rho points to v points.
+
+    This function performs an interpolation from the rho grid (cell centers) to the v grid
+    (cell edges in the eta direction). Depending on the chosen method, it either averages
+    (additive) or multiplies (multiplicative) the field values between adjacent rho points
+    along the eta dimension. It also handles the removal of unnecessary coordinate variables
+    and updates the dimensions accordingly.
+
+    Parameters
+    ----------
+    field : xr.DataArray
+        The input data array on the rho grid to be interpolated. It is assumed to have a
+        dimension named "eta_rho".
+    method : str, optional, default='additive'
+        The method to use for interpolation. Options are:
+        - 'additive': Average the field values between adjacent rho points.
+        - 'multiplicative': Multiply the field values between adjacent rho points.
+          Appropriate for binary masks.
+
+    Returns
+    -------
+    field_interpolated : xr.DataArray
+        The interpolated data array on the v grid with the dimension "eta_v".
+    """
+    if method == "additive":
+        field_interpolated = 0.5 * (field + field.shift(eta_rho=1)).isel(
+            eta_rho=slice(1, None)
+        )
+    elif method == "multiplicative":
+        field_interpolated = (field * field.shift(eta_rho=1)).isel(
+            eta_rho=slice(1, None)
+        )
+    else:
+        raise NotImplementedError(f"Unsupported method '{method}' specified.")
+
+    vars_to_drop = ["lat_rho", "lon_rho", "eta_rho", "xi_rho"]
+    for var in vars_to_drop:
+        if var in field_interpolated.coords:
+            field_interpolated = field_interpolated.drop_vars(var)
+
+    field_interpolated = field_interpolated.swap_dims({"eta_rho": "eta_v"})
+
+    return field_interpolated
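A short worked example of the two interpolation helpers (a sketch with synthetic data). The 'additive' method averages adjacent rho points, so on a 3x3 field the u grid has one fewer column and the v grid one fewer row:

import numpy as np
import xarray as xr
from roms_tools.utils import interpolate_from_rho_to_u, interpolate_from_rho_to_v

# 3x3 field on rho points (cell centers), values 0..8 row by row
field = xr.DataArray(np.arange(9.0).reshape(3, 3), dims=("eta_rho", "xi_rho"))

on_u = interpolate_from_rho_to_u(field)  # dims ("eta_rho", "xi_u"), shape (3, 2)
on_v = interpolate_from_rho_to_v(field)  # dims ("eta_v", "xi_rho"), shape (2, 3)

# Additive method: each u value averages two xi neighbors, e.g. 0.5 * (0 + 1);
# each v value averages two eta neighbors, e.g. 0.5 * (0 + 3)
assert float(on_u.isel(eta_rho=0, xi_u=0)) == 0.5
assert float(on_v.isel(eta_v=0, xi_rho=0)) == 1.5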
+def transpose_dimensions(da: xr.DataArray) -> xr.DataArray:
+    """Transpose the dimensions of an xarray.DataArray to ensure that 'time', any
+    dimension starting with 's_', 'eta_', and 'xi_' are ordered first, followed by
+    the remaining dimensions in their original order.
+
+    Parameters
+    ----------
+    da : xarray.DataArray
+        The input DataArray whose dimensions are to be reordered.
+
+    Returns
+    -------
+    xarray.DataArray
+        The DataArray with dimensions reordered so that 'time', 's_*', 'eta_*',
+        and 'xi_*' are first, in that order, if they exist.
+    """
+    # List of preferred dimension patterns
+    preferred_order = ["time", "s_", "eta_", "xi_"]
+
+    # Get the existing dimensions in the DataArray
+    dims = list(da.dims)
+
+    # Collect dimensions that match any of the preferred patterns
+    matched_dims = []
+    for pattern in preferred_order:
+        # Find dimensions that start with the pattern
+        matched_dims += [dim for dim in dims if dim.startswith(pattern)]
+
+    # Create a new order: first the matched dimensions, then the rest
+    remaining_dims = [dim for dim in dims if dim not in matched_dims]
+    new_order = matched_dims + remaining_dims
+
+    # Transpose the DataArray to the new order
+    transposed_da = da.transpose(*new_order)
+
+    return transposed_da
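The reordering is easiest to see on a small example (synthetic dimension names):

import numpy as np
import xarray as xr
from roms_tools.utils import transpose_dimensions

# Dimensions deliberately out of order; "ntides" matches no preferred pattern
da = xr.DataArray(np.zeros((2, 3, 4, 5)), dims=("xi_rho", "ntides", "time", "s_rho"))

reordered = transpose_dimensions(da)
# "time" first, then "s_*", then "eta_*"/"xi_*", then everything else in original order
assert reordered.dims == ("time", "s_rho", "xi_rho", "ntides")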
roms_tools/vertical_coordinate.py
ADDED

@@ -0,0 +1,306 @@
+import numpy as np
+import xarray as xr
+from roms_tools.utils import (
+    transpose_dimensions,
+    interpolate_from_rho_to_u,
+    interpolate_from_rho_to_v,
+)
+
+
+def compute_cs(sigma, theta_s, theta_b):
+    """Compute the S-coordinate stretching curves according to Shchepetkin and
+    McWilliams (2009).
+
+    Parameters
+    ----------
+    sigma : np.ndarray or float
+        The sigma-coordinate values.
+    theta_s : float
+        The surface control parameter.
+    theta_b : float
+        The bottom control parameter.
+
+    Returns
+    -------
+    C : np.ndarray or float
+        The stretching curve values.
+
+    Raises
+    ------
+    ValueError
+        If theta_s or theta_b are not within the valid range.
+    """
+    if not (0 < theta_s <= 10):
+        raise ValueError("theta_s must be between 0 and 10.")
+    if not (0 < theta_b <= 4):
+        raise ValueError("theta_b must be between 0 and 4.")
+
+    C = (1 - np.cosh(theta_s * sigma)) / (np.cosh(theta_s) - 1)
+    C = (np.exp(theta_b * C) - 1) / (1 - np.exp(-theta_b))
+
+    return C
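By construction the stretching curve pins the end points: C(-1) = -1 at the bottom and C(0) = 0 at the surface, for any valid theta_s and theta_b. A quick check (illustrative parameter values):

import numpy as np
from roms_tools.vertical_coordinate import compute_cs

sigma = np.linspace(-1.0, 0.0, 5)
C = compute_cs(sigma, theta_s=5.0, theta_b=2.0)

# End points map to themselves; intermediate values are stretched monotonically
assert np.isclose(C[0], -1.0) and np.isclose(C[-1], 0.0)
assert np.all(np.diff(C) > 0)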
+def sigma_stretch(theta_s, theta_b, N, type):
+    """Compute sigma and stretching curves based on the type and parameters.
+
+    Parameters
+    ----------
+    theta_s : float
+        The surface control parameter.
+    theta_b : float
+        The bottom control parameter.
+    N : int
+        The number of vertical levels.
+    type : str
+        The type of sigma ('w' for vertical velocity points, 'r' for rho-points).
+
+    Returns
+    -------
+    cs : xr.DataArray
+        The stretching curve values.
+    sigma : xr.DataArray
+        The sigma-coordinate values.
+
+    Raises
+    ------
+    ValueError
+        If the type is not 'w' or 'r'.
+    """
+    if type == "w":
+        k = xr.DataArray(np.arange(N + 1), dims="s_w")
+        sigma = (k - N) / N
+    elif type == "r":
+        k = xr.DataArray(np.arange(1, N + 1), dims="s_rho")
+        sigma = (k - N - 0.5) / N
+    else:
+        raise ValueError(
+            "Type must be either 'w' for vertical velocity points or 'r' for rho-points."
+        )
+
+    cs = compute_cs(sigma, theta_s, theta_b)
+
+    return cs, sigma
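For N vertical levels, type='r' yields N sigma values at layer centers and type='w' yields N + 1 values at layer interfaces, each returned together with its stretching curve. A sketch (illustrative parameter values):

from roms_tools.vertical_coordinate import sigma_stretch

cs_r, sigma_r = sigma_stretch(theta_s=5.0, theta_b=2.0, N=10, type="r")
cs_w, sigma_w = sigma_stretch(theta_s=5.0, theta_b=2.0, N=10, type="w")

# rho-points: N centers on dimension "s_rho"; w-points: N + 1 interfaces on "s_w"
assert sigma_r.sizes["s_rho"] == 10
assert sigma_w.sizes["s_w"] == 11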
+def compute_depth(zeta, h, hc, cs, sigma):
+    """Compute the depth at different sigma levels.
+
+    Parameters
+    ----------
+    zeta : xr.DataArray or scalar
+        The sea surface height.
+    h : xr.DataArray
+        The depth of the sea bottom.
+    hc : float
+        The critical depth.
+    cs : xr.DataArray
+        The stretching curve values.
+    sigma : xr.DataArray
+        The sigma-coordinate values.
+
+    Returns
+    -------
+    z : xr.DataArray
+        The depth at different sigma levels.
+    """
+    z = (hc * sigma + h * cs) / (hc + h)
+    z = zeta + (zeta + h) * z
+
+    z = -transpose_dimensions(z)
+
+    return z
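The function evaluates z = zeta + (zeta + h) * (hc * sigma + h * cs) / (hc + h) and returns the negated, dimension-reordered result, so depths come out positive downward. A sketch on a flat 100 m bathymetry at rest (illustrative parameter values):

import xarray as xr
from roms_tools.vertical_coordinate import compute_depth, sigma_stretch

# Flat 100 m bathymetry, sea surface at rest (zeta = 0)
h = xr.DataArray([[100.0, 100.0], [100.0, 100.0]], dims=("eta_rho", "xi_rho"))
cs, sigma = sigma_stretch(theta_s=5.0, theta_b=2.0, N=10, type="r")

depth = compute_depth(0, h, 20.0, cs, sigma)  # hc = 20 m critical depth

# Layer depths are positive downward and lie strictly inside the water column
assert float(depth.min()) > 0.0 and float(depth.max()) < 100.0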
+def add_depth_coordinates_to_dataset(
+    ds: "xr.Dataset",
+    grid_ds: "xr.Dataset",
+    depth_type: str,
+    locations: list[str] = ["rho", "u", "v"],
+) -> None:
+    """Add computed vertical depth coordinates to a dataset for specified grid
+    locations.
+
+    This function computes vertical depth coordinates (layer or interface) and updates
+    the provided dataset with these coordinates for the specified grid locations. If
+    the dataset already contains depth coordinates for all specified locations, the
+    function does nothing.
+
+    Parameters
+    ----------
+    ds : xr.Dataset
+        Target dataset to which computed depth coordinates will be added.
+        If the `zeta` variable is not present, static vertical coordinates are used.
+    grid_ds : xr.Dataset
+        Grid dataset containing bathymetry, stretching curves, and parameters.
+    depth_type : str
+        Type of depth coordinates to compute. Options are:
+        - "layer": Layer depth coordinates.
+        - "interface": Interface depth coordinates.
+    locations : list of str, optional
+        List of locations for which to compute depth coordinates. Default is ["rho", "u", "v"].
+    """
+    required_vars = [f"{depth_type}_depth_{loc}" for loc in locations]
+
+    if all(var in ds for var in required_vars):
+        return  # Depth coordinates already exist
+
+    # Compute or interpolate depth coordinates
+    if f"{depth_type}_depth_rho" in ds:
+        depth_rho = ds[f"{depth_type}_depth_rho"]
+    else:
+        h = grid_ds["h"]
+        zeta = ds.get("zeta", 0)
+        if depth_type == "layer":
+            Cs = grid_ds["Cs_r"]
+            sigma = grid_ds["sigma_r"]
+        elif depth_type == "interface":
+            Cs = grid_ds["Cs_w"]
+            sigma = grid_ds["sigma_w"]
+        depth_rho = compute_depth(zeta, h, grid_ds.attrs["hc"], Cs, sigma)
+        depth_rho.attrs.update(
+            {"long_name": f"{depth_type} depth at rho-points", "units": "m"}
+        )
+        ds[f"{depth_type}_depth_rho"] = depth_rho
+
+    # Interpolate depth to other locations
+    for loc in locations:
+        if loc == "rho":
+            continue
+
+        interp_func = (
+            interpolate_from_rho_to_u if loc == "u" else interpolate_from_rho_to_v
+        )
+        depth_loc = interp_func(depth_rho)
+        depth_loc.attrs.update(
+            {"long_name": f"{depth_type} depth at {loc}-points", "units": "m"}
+        )
+        ds[f"{depth_type}_depth_{loc}"] = depth_loc
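A self-contained sketch with a synthetic grid (flat 50 m bathymetry, 5 levels; all names and values illustrative). Without a `zeta` variable in `ds`, static depths are computed, and the dataset is updated in place:

import numpy as np
import xarray as xr
from roms_tools.vertical_coordinate import (
    add_depth_coordinates_to_dataset,
    sigma_stretch,
)

cs_r, sigma_r = sigma_stretch(theta_s=5.0, theta_b=2.0, N=5, type="r")
grid_ds = xr.Dataset(
    {
        "h": (("eta_rho", "xi_rho"), np.full((3, 3), 50.0)),
        "Cs_r": cs_r,
        "sigma_r": sigma_r,
    },
    attrs={"hc": 20.0},
)

ds = xr.Dataset()  # no "zeta" variable: static depths are used
add_depth_coordinates_to_dataset(ds, grid_ds, depth_type="layer")

# Depths were added at rho-, u-, and v-points
assert {"layer_depth_rho", "layer_depth_u", "layer_depth_v"} <= set(ds)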
+def compute_depth_coordinates(
+    ds: "xr.Dataset",
+    grid_ds: "xr.Dataset",
+    depth_type: str,
+    location: str,
+    s: int = None,
+    eta: int = None,
+    xi: int = None,
+) -> "xr.DataArray":
+    """Compute vertical depth coordinates efficiently for a specified grid location
+    and optional indices.
+
+    This function calculates vertical depth coordinates (layer or interface) for a given grid
+    location (`rho`, `u`, or `v`). It performs spatial slicing (meridional or zonal) on the
+    bathymetry and free-surface elevation (`zeta`) before computing depth coordinates. This
+    approach minimizes computational overhead by reducing the dataset size before performing
+    vertical coordinate calculations.
+
+    Parameters
+    ----------
+    ds : xr.Dataset
+        Dataset containing optional `zeta` (free-surface elevation). If `zeta` is not present,
+        static vertical coordinates are computed.
+    grid_ds : xr.Dataset
+        Grid dataset containing bathymetry (`h`), stretching curves (`Cs`), and sigma-layer
+        parameters (`sigma`). The attributes of this dataset should include the critical depth (`hc`).
+    depth_type : str
+        Type of depth coordinates to compute:
+        - `"layer"`: Depth at the center of layers.
+        - `"interface"`: Depth at layer interfaces.
+    location : str
+        Grid location for the computation. Options are:
+        - `"rho"`: Depth at rho points (cell centers).
+        - `"u"`: Depth at u points (eastward velocity points).
+        - `"v"`: Depth at v points (northward velocity points).
+    s : int, optional
+        Vertical index to extract a single layer or interface slice. If not provided, all
+        vertical layers are included.
+    eta : int, optional
+        Meridional (north-south) index to extract a slice. If not provided, all meridional
+        indices are included.
+    xi : int, optional
+        Zonal (east-west) index to extract a slice. If not provided, all zonal indices are
+        included.
+
+    Returns
+    -------
+    xr.DataArray
+        A DataArray containing the computed depth coordinates. If no indices are specified,
+        the array will have the full dimensionality of the depth coordinates. The dimensions
+        of the output depend on the provided indices:
+        - Full 3D (or 4D if `zeta` includes time) depth coordinates if no indices are provided.
+        - Reduced dimensionality for specified slices (e.g., 2D for a single vertical slice).
+
+    Notes
+    -----
+    - To ensure computational efficiency, spatial slicing (based on `eta` and `xi`) is
+      performed before computing depth coordinates. This reduces memory usage and processing time.
+    - Depth coordinates are interpolated to the specified grid location (`rho`, `u`, or `v`)
+      if necessary.
+    - If depth coordinates for the specified location and configuration already exist in `ds`,
+      they are not recomputed.
+    """
+    h = grid_ds["h"]
+    zeta = ds.get("zeta", None)
+
+    # Interpolate h and zeta to the specified location
+    if location == "u":
+        h = interpolate_from_rho_to_u(h)
+        if zeta is not None:
+            zeta = interpolate_from_rho_to_u(zeta)
+    elif location == "v":
+        h = interpolate_from_rho_to_v(h)
+        if zeta is not None:
+            zeta = interpolate_from_rho_to_v(zeta)
+
+    # Slice spatially based on the location's specific dimensions
+    if eta is not None:
+        if location == "v":
+            h = h.isel(eta_v=eta)
+            if zeta is not None:
+                zeta = zeta.isel(eta_v=eta)
+        else:  # Default to "rho" or "u"
+            h = h.isel(eta_rho=eta)
+            if zeta is not None:
+                zeta = zeta.isel(eta_rho=eta)
+    if xi is not None:
+        if location == "u":
+            h = h.isel(xi_u=xi)
+            if zeta is not None:
+                zeta = zeta.isel(xi_u=xi)
+        else:  # Default to "rho" or "v"
+            h = h.isel(xi_rho=xi)
+            if zeta is not None:
+                zeta = zeta.isel(xi_rho=xi)
+
+    # Compute depth
+    if depth_type == "layer":
+        Cs = grid_ds["Cs_r"]
+        sigma = grid_ds["sigma_r"]
+    elif depth_type == "interface":
+        Cs = grid_ds["Cs_w"]
+        sigma = grid_ds["sigma_w"]
+    depth = compute_depth(zeta, h, grid_ds.attrs["hc"], Cs, sigma)
+
+    # Slice vertically
+    if s is not None:
+        vertical_dim = "s_rho" if "s_rho" in depth.dims else "s_w"
+        depth = depth.isel({vertical_dim: s})
+
+    # Add metadata
+    depth.attrs.update(
+        {"long_name": f"{depth_type} depth at {location}-points", "units": "m"}
+    )
+
+    return depth
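Continuing the synthetic grid from the previous sketch, the following computes the surface-most layer depth along one grid row at u-points; note that the `eta` slice is applied before the vertical computation. A `zeta` variable is supplied here, since this sketch assumes a free surface is available in `ds`:

import numpy as np
import xarray as xr
from roms_tools.vertical_coordinate import compute_depth_coordinates, sigma_stretch

cs_r, sigma_r = sigma_stretch(theta_s=5.0, theta_b=2.0, N=5, type="r")
grid_ds = xr.Dataset(
    {
        "h": (("eta_rho", "xi_rho"), np.full((3, 3), 50.0)),
        "Cs_r": cs_r,
        "sigma_r": sigma_r,
    },
    attrs={"hc": 20.0},
)
# A flat free surface (zeta = 0 everywhere), purely illustrative
ds = xr.Dataset({"zeta": (("eta_rho", "xi_rho"), np.zeros((3, 3)))})

depth = compute_depth_coordinates(
    ds, grid_ds, depth_type="layer", location="u", s=-1, eta=1
)
# One row (eta fixed), surface-most layer (s = -1), on u-points
assert depth.dims == ("xi_u",)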