roms-tools 0.0.6__py3-none-any.whl → 0.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,809 @@
1
+ from datetime import datetime
2
+ import xarray as xr
3
+ import numpy as np
4
+ import yaml
5
+ import importlib.metadata
6
+
7
+ from dataclasses import dataclass, field, asdict
8
+ from roms_tools.setup.grid import Grid
9
+ from roms_tools.setup.plot import _plot
10
+ from roms_tools.setup.fill import fill_and_interpolate
11
+ from roms_tools.setup.datasets import Dataset
12
+ from roms_tools.setup.utils import (
13
+ nan_check,
14
+ interpolate_from_rho_to_u,
15
+ interpolate_from_rho_to_v,
16
+ )
17
+ from typing import Dict, List
18
+ import matplotlib.pyplot as plt
19
+
20
+
21
@dataclass(frozen=True, kw_only=True)
class TPXO(Dataset):
    """
    Represents tidal data on the original (native TPXO) grid.

    Parameters
    ----------
    filename : str
        The path to the TPXO dataset.
    var_names : List[str], optional
        List of variable names that are required in the dataset. Defaults to
        ["h_Re", "h_Im", "sal_Re", "sal_Im", "u_Re", "u_Im", "v_Re", "v_Im", "depth"].
    dim_names : Dict[str, str], optional
        Dictionary specifying the names of dimensions in the dataset. Defaults to
        {"longitude": "ny", "latitude": "nx", "ntides": "nc"}. After loading, the
        "longitude" and "latitude" entries are replaced by "longitude" and
        "latitude" (the names of the renamed horizontal dimensions); the
        "ntides" entry is kept.

    Attributes
    ----------
    ds : xr.Dataset
        The xarray Dataset containing TPXO tidal model data.
    """

    filename: str
    var_names: List[str] = field(
        default_factory=lambda: [
            "h_Re",
            "h_Im",
            "sal_Re",
            "sal_Im",
            "u_Re",
            "u_Im",
            "v_Re",
            "v_Im",
            "depth",
        ]
    )
    dim_names: Dict[str, str] = field(
        default_factory=lambda: {"longitude": "ny", "latitude": "nx", "ntides": "nc"}
    )
    ds: xr.Dataset = field(init=False, repr=False)

    def __post_init__(self):
        # Load the raw dataset through the Dataset base class.
        ds = super().load_data()

        # Promote 1D coordinates for the horizontal axes: lon_r is constant
        # along ny (only a function of nx) and lat_r is constant along nx
        # (only a function of ny), so one row/column of each labels the axis.
        ds = ds.assign_coords(
            {
                "omega": ds["omega"],
                "nx": ds["lon_r"].isel(ny=0),
                "ny": ds["lat_r"].isel(nx=0),
            }
        )
        ds = ds.rename({"nx": "longitude", "ny": "latitude"})

        # The horizontal dimensions now carry canonical names; keep whatever
        # name the constituent dimension has.
        object.__setattr__(
            self,
            "dim_names",
            {
                "latitude": "latitude",
                "longitude": "longitude",
                "ntides": self.dim_names["ntides"],
            },
        )

        # Select relevant fields
        ds = super().select_relevant_fields(ds)

        # Check whether the data covers the entire globe
        if self.check_if_global(ds):
            ds = self.concatenate_longitudes(ds)

        object.__setattr__(self, "ds", ds)

    def check_number_constituents(self, ntides: int):
        """
        Checks if the number of constituents in the dataset is at least `ntides`.

        Parameters
        ----------
        ntides : int
            The required number of tidal constituents.

        Raises
        ------
        ValueError
            If the number of constituents in the dataset is less than `ntides`.
        """
        if len(self.ds[self.dim_names["ntides"]]) < ntides:
            raise ValueError(
                f"The dataset contains fewer than {ntides} tidal constituents."
            )

    def get_corrected_tides(self, model_reference_date, allan_factor):
        """
        Compute tidal fields with nodal and reference-date phase corrections applied.

        Parameters
        ----------
        model_reference_date : datetime
            Reference date of the ROMS simulation. Phases are advanced from the
            TPXO reference date (1992-01-01) to this date.
        allan_factor : float
            Factor multiplying the SAL term (``sal_Re + i*sal_Im``) that is
            subtracted from the equilibrium tide.

        Returns
        -------
        dict of str to xr.DataArray
            Real/imaginary parts of elevation ("ssh_Re", "ssh_Im"), transports
            ("u_Re", "u_Im", "v_Re", "v_Im") and tidal potential ("pot_Re",
            "pot_Im"), plus "omega". The constituent dimension "nc" is renamed
            to "ntides" in every entry.
        """
        # Get equilibrium tides
        tpc = compute_equilibrium_tide(self.ds["longitude"], self.ds["latitude"]).isel(
            nc=self.ds.nc
        )
        # Correct for SAL
        tsc = allan_factor * (self.ds["sal_Re"] + 1j * self.ds["sal_Im"])
        tpc = tpc - tsc

        # Elevations and transports
        thc = self.ds["h_Re"] + 1j * self.ds["h_Im"]
        tuc = self.ds["u_Re"] + 1j * self.ds["u_Im"]
        tvc = self.ds["v_Re"] + 1j * self.ds["v_Im"]

        # Apply correction for phases and amplitudes
        pf, pu, aa = egbert_correction(model_reference_date)
        pf = pf.isel(nc=self.ds.nc)
        pu = pu.isel(nc=self.ds.nc)
        aa = aa.isel(nc=self.ds.nc)

        # Seconds between the TPXO reference date and the model reference date.
        # total_seconds() keeps any sub-day part of model_reference_date, matching
        # egbert_correction, which also uses the hour/minute/second of the date.
        # (The previous `.days * 3600 * 24` silently dropped that part.)
        tpxo_reference_date = datetime(1992, 1, 1)
        dt = (model_reference_date - tpxo_reference_date).total_seconds()

        # The same complex phase factor applies to every field; compute it once.
        phase_factor = np.exp(1j * (self.ds["omega"] * dt + pu + aa))
        thc = pf * thc * phase_factor
        tuc = pf * tuc * phase_factor
        tvc = pf * tvc * phase_factor
        tpc = pf * tpc * phase_factor

        tides = {
            "ssh_Re": thc.real,
            "ssh_Im": thc.imag,
            "u_Re": tuc.real,
            "u_Im": tuc.imag,
            "v_Re": tvc.real,
            "v_Im": tvc.imag,
            "pot_Re": tpc.real,
            "pot_Im": tpc.imag,
            "omega": self.ds["omega"],
        }

        return {name: da.rename({"nc": "ntides"}) for name, da in tides.items()}
163
+
164
+
165
@dataclass(frozen=True, kw_only=True)
class TidalForcing:
    """
    Represents tidal forcing data used in ocean modeling.

    Parameters
    ----------
    grid : Grid
        The grid object representing the ROMS grid associated with the tidal forcing data.
    filename : str
        The path to the native tidal dataset.
    ntides : int, optional
        Number of constituents to consider. Must not exceed the number of
        constituents in the source dataset (checked on construction). Default is 10.
    model_reference_date : datetime, optional
        The reference date for the ROMS simulation. Default is datetime(2000, 1, 1).
    source : str, optional
        The source of the tidal data. Default is "TPXO".
    allan_factor : float, optional
        The Allan factor used in tidal model computation. Default is 2.0.

    Attributes
    ----------
    ds : xr.Dataset
        The xarray Dataset containing the tidal forcing data.

    Raises
    ------
    ValueError
        If `source` is not "TPXO", or if the source dataset holds fewer than
        `ntides` constituents.

    Examples
    --------
    >>> grid = Grid(...)
    >>> tidal_forcing = TidalForcing(grid=grid, filename="tpxo_data.nc")
    >>> print(tidal_forcing.ds)
    """

    grid: Grid
    filename: str
    ntides: int = 10
    model_reference_date: datetime = datetime(2000, 1, 1)
    source: str = "TPXO"
    allan_factor: float = 2.0
    ds: xr.Dataset = field(init=False, repr=False)

    def __post_init__(self):
        if self.source == "TPXO":
            data = TPXO(filename=self.filename)
        else:
            raise ValueError('Only "TPXO" is a valid option for source.')

        data.check_number_constituents(self.ntides)

        lon = self.grid.ds.lon_rho
        lat = self.grid.ds.lat_rho
        angle = self.grid.ds.angle
        mask_rho = self.grid.ds.mask_rho

        # Operate on longitudes between -180 and 180 unless the ROMS domain lies
        # at least 5 degrees in longitude away from the Greenwich meridian.
        lon = xr.where(lon > 180, lon - 360, lon)
        straddle = True
        if not self.grid.straddle and abs(lon).min() > 5:
            lon = xr.where(lon < 0, lon + 360, lon)
            straddle = False

        # The following consists of two steps:
        # Step 1: Choose a subdomain of the forcing data, including a safety margin
        # for interpolation, and Step 2: Convert to the proper longitude range.
        # We perform these two steps for two reasons:
        # A) Since the horizontal dimensions consist of a single chunk, selecting
        #    a subdomain before interpolation is a lot more performant.
        # B) Step 1 is necessary to avoid discontinuous longitudes that could be
        #    introduced by Step 2. Discontinuous longitudes can lead to artifacts in
        #    the interpolation process: if there is a data gap because the data is
        #    not global, they could produce values that appear to come from a
        #    distant location instead of NaNs. These NaNs are important because
        #    they can be identified and handled by the nan_check function.
        data.choose_subdomain(
            latitude_range=[lat.min().values, lat.max().values],
            longitude_range=[lon.min().values, lon.max().values],
            margin=2,
            straddle=straddle,
        )

        tides = data.get_corrected_tides(self.model_reference_date, self.allan_factor)

        # Keep only the first `ntides` constituents.
        tides = {
            name: da.isel(ntides=slice(None, self.ntides))
            for name, da in tides.items()
        }

        # Interpolate onto the ROMS rho-points; source points with depth <= 0
        # are passed to fill_and_interpolate as masked.
        coords = {"latitude": lat, "longitude": lon}
        mask = xr.where(data.ds.depth > 0, 1, 0)

        varnames = [
            "ssh_Re",
            "ssh_Im",
            "pot_Re",
            "pot_Im",
            "u_Re",
            "u_Im",
            "v_Re",
            "v_Im",
        ]
        data_vars = {
            var: fill_and_interpolate(
                tides[var],
                mask,
                list(coords.keys()),
                coords,
                method="linear",
            )
            for var in varnames
        }

        # Rotate to grid orientation
        cos_angle = np.cos(angle)
        sin_angle = np.sin(angle)
        u_Re = data_vars["u_Re"] * cos_angle + data_vars["v_Re"] * sin_angle
        v_Re = data_vars["v_Re"] * cos_angle - data_vars["u_Re"] * sin_angle
        u_Im = data_vars["u_Im"] * cos_angle + data_vars["v_Im"] * sin_angle
        v_Im = data_vars["v_Im"] * cos_angle - data_vars["u_Im"] * sin_angle

        # Convert transports to barotropic velocity
        h = self.grid.ds.h
        u_Re = u_Re / h
        v_Re = v_Re / h
        u_Im = u_Im / h
        v_Im = v_Im / h

        # Check for NaNs at ocean points while every field is still on the
        # rho-grid. After the interpolation below the velocities live on
        # velocity points, whose dimensions do not match mask_rho; xarray would
        # then broadcast field and mask into a large 3-D outer product instead
        # of comparing point by point.
        nan_check(data_vars["ssh_Re"].isel(ntides=0), mask_rho)
        nan_check(u_Re.isel(ntides=0), mask_rho)
        nan_check(v_Im.isel(ntides=0), mask_rho)

        # Interpolate from rho- to velocity points
        u_Re = interpolate_from_rho_to_u(u_Re)
        v_Re = interpolate_from_rho_to_v(v_Re)
        u_Im = interpolate_from_rho_to_u(u_Im)
        v_Im = interpolate_from_rho_to_v(v_Im)

        # Output variables in dataset order, with their (long_name, units).
        outputs = {
            "ssh_Re": (data_vars["ssh_Re"], "Tidal elevation, real part", "m"),
            "ssh_Im": (data_vars["ssh_Im"], "Tidal elevation, complex part", "m"),
            "pot_Re": (data_vars["pot_Re"], "Tidal potential, real part", "m"),
            "pot_Im": (data_vars["pot_Im"], "Tidal potential, complex part", "m"),
            "u_Re": (u_Re, "Tidal velocity in x-direction, real part", "m/s"),
            "u_Im": (u_Im, "Tidal velocity in x-direction, complex part", "m/s"),
            "v_Re": (v_Re, "Tidal velocity in y-direction, real part", "m/s"),
            "v_Im": (v_Im, "Tidal velocity in y-direction, complex part", "m/s"),
        }

        ds = xr.Dataset()
        for name, (values, long_name, units) in outputs.items():
            ds[name] = values.astype(np.float32)
            ds[name].attrs["long_name"] = long_name
            ds[name].attrs["units"] = units

        ds.attrs["title"] = "ROMS tidal forcing created by ROMS-Tools"
        ds.attrs["roms_tools_version"] = self._roms_tools_version()
        ds.attrs["source"] = self.source
        ds.attrs["model_reference_date"] = str(self.model_reference_date)
        ds.attrs["allan_factor"] = self.allan_factor

        object.__setattr__(self, "ds", ds)

    @staticmethod
    def _roms_tools_version() -> str:
        """Return the installed roms-tools version, or "unknown" if it is not installed."""
        try:
            return importlib.metadata.version("roms-tools")
        except importlib.metadata.PackageNotFoundError:
            return "unknown"

    def plot(self, varname, ntides=0) -> None:
        """
        Plot the specified tidal forcing variable for a given tidal constituent.

        Parameters
        ----------
        varname : str
            The tidal forcing variable to plot. Options include:
            - "ssh_Re": Real part of tidal elevation.
            - "ssh_Im": Imaginary part of tidal elevation.
            - "pot_Re": Real part of tidal potential.
            - "pot_Im": Imaginary part of tidal potential.
            - "u_Re": Real part of tidal velocity in the x-direction.
            - "u_Im": Imaginary part of tidal velocity in the x-direction.
            - "v_Re": Real part of tidal velocity in the y-direction.
            - "v_Im": Imaginary part of tidal velocity in the y-direction.
        ntides : int, optional
            The index of the tidal constituent to plot. Default is 0, which corresponds
            to the first constituent.

        Returns
        -------
        None
            This method does not return any value. It generates and displays a plot.

        Raises
        ------
        KeyError
            If `varname` is not a variable of the dataset (raised by the lookup
            in ``self.ds``).

        Examples
        --------
        >>> tidal_forcing = TidalForcing(grid=grid, filename="tpxo_data.nc")
        >>> tidal_forcing.plot("ssh_Re", ntides=0)
        """
        # Named `da` rather than `field` to avoid shadowing dataclasses.field.
        da = self.ds[varname].isel(ntides=ntides).compute()

        title = "%s, ntides = %i" % (da.long_name, self.ds[varname].ntides[ntides])

        # Symmetric color range around zero.
        vmax = max(da.max(), -da.min())
        vmin = -vmax
        cmap = plt.colormaps.get_cmap("RdBu_r")
        cmap.set_bad(color="gray")

        kwargs = {"vmax": vmax, "vmin": vmin, "cmap": cmap}

        _plot(
            self.grid.ds,
            field=da,
            straddle=self.grid.straddle,
            c="g",
            kwargs=kwargs,
            title=title,
        )

    def save(self, filepath: str) -> None:
        """
        Save the tidal forcing information to a netCDF4 file.

        Parameters
        ----------
        filepath : str
            Destination path of the netCDF file.
        """
        self.ds.to_netcdf(filepath)

    def to_yaml(self, filepath: str) -> None:
        """
        Export the parameters of the class to a YAML file, including the version of roms-tools.

        Parameters
        ----------
        filepath : str
            The path to the YAML file where the parameters will be saved.
        """
        grid_data = asdict(self.grid)
        grid_data.pop("ds", None)  # Exclude non-serializable fields
        grid_data.pop("straddle", None)

        # Header document carrying the roms-tools version.
        header = f"---\nroms_tools_version: {self._roms_tools_version()}\n---\n"

        yaml_data = {
            "Grid": grid_data,
            "TidalForcing": {
                "filename": self.filename,
                "ntides": self.ntides,
                "model_reference_date": self.model_reference_date.isoformat(),
                "source": self.source,
                "allan_factor": self.allan_factor,
            },
        }

        with open(filepath, "w", encoding="utf-8") as file:
            file.write(header)
            yaml.dump(yaml_data, file, default_flow_style=False)

    @classmethod
    def from_yaml(cls, filepath: str) -> "TidalForcing":
        """
        Create an instance of the TidalForcing class from a YAML file.

        Parameters
        ----------
        filepath : str
            The path to the YAML file from which the parameters will be read.

        Returns
        -------
        TidalForcing
            An instance of the TidalForcing class.

        Raises
        ------
        ValueError
            If no document in the file contains a "TidalForcing" section.
        """
        with open(filepath, "r", encoding="utf-8") as file:
            documents = list(yaml.safe_load_all(file))

        # First YAML document that has a TidalForcing section.
        tidal_forcing_data = next(
            (
                doc["TidalForcing"]
                for doc in documents
                if doc is not None and "TidalForcing" in doc
            ),
            None,
        )

        if tidal_forcing_data is None:
            raise ValueError("No TidalForcing configuration found in the YAML file.")

        # Copy so the parsed document is not mutated, then restore the datetime.
        tidal_forcing_params = dict(tidal_forcing_data)
        reference_date = tidal_forcing_params["model_reference_date"]
        if isinstance(reference_date, str):
            tidal_forcing_params["model_reference_date"] = datetime.fromisoformat(
                reference_date
            )

        # Create Grid instance from the YAML file
        grid = Grid.from_yaml(filepath)

        return cls(grid=grid, **tidal_forcing_params)
496
+
497
+
498
def modified_julian_days(year, month, day, hour=0):
    """
    Calculate the Modified Julian Day (MJD) for a given date and time.

    The Modified Julian Day (MJD) is a modified Julian day count starting from
    November 17, 1858 AD (MJD = JD - 2400000.5). It is commonly used in
    astronomy and geodesy.

    Parameters
    ----------
    year : int
        The year.
    month : int
        The month (1-12).
    day : int
        The day of the month.
    hour : float, optional
        The hour of the day as a fractional number (0 to 23.999...). Default is 0.

    Returns
    -------
    mjd : float
        The Modified Julian Day (MJD) corresponding to the input date and time.

    Notes
    -----
    The Gregorian leap-year correction is always applied, so every date is
    interpreted in the (proleptic) Gregorian calendar. Dates before the
    Gregorian reform of October 15, 1582 are therefore not treated as Julian
    calendar dates. Negative MJD values are returned for dates before
    November 17, 1858.

    References
    ----------
    - Wikipedia article on Julian Day: https://en.wikipedia.org/wiki/Julian_day
    - Wikipedia article on Modified Julian Day: https://en.wikipedia.org/wiki/Modified_Julian_day

    Examples
    --------
    >>> modified_julian_days(2024, 5, 20, 12)
    60450.5
    >>> modified_julian_days(1858, 11, 17)
    0.0
    >>> modified_julian_days(1582, 10, 4)
    -100851.0
    """
    # Treat January and February as months 13 and 14 of the previous year so
    # that the leap day falls at the end of the shifted year.
    if month < 3:
        year -= 1
        month += 12

    # Gregorian correction: drop century leap days except every 400 years.
    centuries = year // 100
    gregorian_offset = 2 - centuries + centuries // 4
    days_from_years = int(365.25 * (year + 4716))
    days_from_months = int(30.6001 * (month + 1))
    jd = (
        gregorian_offset
        + day
        + hour / 24
        + days_from_years
        + days_from_months
        - 1524.5
    )
    return jd - 2400000.5
555
+
556
+
557
def egbert_correction(date):
    """
    Correct phases and amplitudes for real-time runs using parts of the
    post-processing code from Egbert's & Erofeeva's (OSU) TPXO model.

    All returned arrays have a single "nc" dimension of length 15, in the
    constituent order M2, S2, N2, K2, K1, O1, P1, Q1, Mf, Mm, M4, Mn4, Ms4,
    2N2, S1.

    Parameters
    ----------
    date : datetime.datetime
        The date and time for which corrections are to be applied
        (year, month, day, hour, minute and second are used).

    Returns
    -------
    pf : xr.DataArray
        Amplitude scaling (nodal) factor for each of the 15 tidal constituents.
    pu : xr.DataArray
        Phase correction [radians] for each of the 15 tidal constituents.
    aa : xr.DataArray
        Astronomical arguments [radians] associated with the corrections.

    References
    ----------
    - Egbert, G.D., and S.Y. Erofeeva. "Efficient inverse modeling of barotropic ocean
      tides." Journal of Atmospheric and Oceanic Technology 19, no. 2 (2002): 183-204.
    """
    rad = np.pi / 180.0
    deg = 180.0 / np.pi

    mjd = modified_julian_days(date.year, date.month, date.day)
    tstart = (
        mjd
        + date.hour / 24
        + date.minute / (60 * 24)
        + date.second / (60 * 60 * 24)
    )

    # Nodal corrections pu & pf: these expressions are valid for the period
    # 1990-2010 (Cartwright 1990).
    # Time origin for the astronomical arguments: MJD 51544.4993, i.e. about
    # noon on 2000-01-01.
    timetemp = tstart - 51544.4993

    # Mean longitude of the ascending lunar node [rad]. np.mod with a positive
    # divisor already returns a value in [0, 360), so no extra wrapping is needed.
    N = np.mod(125.0445 - 0.05295377 * timetemp, 360.0) * rad

    sinn = np.sin(N)
    cosn = np.cos(N)
    sin2n = np.sin(2 * N)
    cos2n = np.cos(2 * N)
    sin3n = np.sin(3 * N)

    # M2-type nodal factor and phase (also used for N2, 2N2 and compound tides).
    f_m2 = np.sqrt(
        (1 - 0.03731 * cosn + 0.00052 * cos2n) ** 2
        + (0.03731 * sinn - 0.00052 * sin2n) ** 2
    )
    u_m2 = (
        np.arctan(
            (-0.03731 * sinn + 0.00052 * sin2n)
            / (1.0 - 0.03731 * cosn + 0.00052 * cos2n)
        )
        * deg
    )

    # Amplitude factors. Compound constituents multiply the factors of their
    # components: M4 = M2+M2, Mn4 = M2+N2, Ms4 = M2+S2 (f_S2 = 1, so Ms4 uses
    # f_m2 alone). This matches the phase corrections below, where Ms4 gets
    # u_M2 + u_S2 = u_m2.
    pf = xr.DataArray(
        np.array(
            [
                f_m2,  # M2
                1.0,  # S2
                f_m2,  # N2
                np.sqrt(
                    (1 + 0.2852 * cosn + 0.0324 * cos2n) ** 2
                    + (0.3108 * sinn + 0.0324 * sin2n) ** 2
                ),  # K2
                np.sqrt(
                    (1 + 0.1158 * cosn - 0.0029 * cos2n) ** 2
                    + (0.1554 * sinn - 0.0029 * sin2n) ** 2
                ),  # K1
                np.sqrt(
                    (1 + 0.189 * cosn - 0.0058 * cos2n) ** 2
                    + (0.189 * sinn - 0.0058 * sin2n) ** 2
                ),  # O1
                1.0,  # P1
                np.sqrt((1 + 0.188 * cosn) ** 2 + (0.188 * sinn) ** 2),  # Q1
                1.043 + 0.414 * cosn,  # Mf
                1.0 - 0.130 * cosn,  # Mm
                f_m2**2,  # M4
                f_m2**2,  # Mn4
                f_m2,  # Ms4
                f_m2,  # 2N2
                1.0,  # S1
            ]
        ),
        dims="nc",
    )

    # Phase corrections, first in degrees, then converted to radians.
    pu = (
        xr.DataArray(
            np.array(
                [
                    u_m2,  # M2
                    0.0,  # S2
                    u_m2,  # N2
                    np.arctan(
                        -(0.3108 * sinn + 0.0324 * sin2n)
                        / (1.0 + 0.2852 * cosn + 0.0324 * cos2n)
                    )
                    * deg,  # K2
                    np.arctan(
                        (-0.1554 * sinn + 0.0029 * sin2n)
                        / (1.0 + 0.1158 * cosn - 0.0029 * cos2n)
                    )
                    * deg,  # K1
                    10.8 * sinn - 1.3 * sin2n + 0.2 * sin3n,  # O1
                    0.0,  # P1
                    np.arctan(0.189 * sinn / (1.0 + 0.189 * cosn)) * deg,  # Q1
                    -23.7 * sinn + 2.7 * sin2n - 0.4 * sin3n,  # Mf
                    0.0,  # Mm
                    u_m2 * 2.0,  # M4
                    u_m2 * 2.0,  # Mn4
                    u_m2,  # Ms4
                    u_m2,  # 2N2
                    0.0,  # S1
                ]
            ),
            dims="nc",
        )
        * rad
    )

    # Astronomical arguments in the same constituent order as pf/pu.
    # Mf (index 8) and Mm (index 9) follow pf/pu/equilibrium-tide ordering:
    # Mf = K1 - O1 - pi = 1.756042456 and Mm = M2 - N2 = 1.964021610.
    aa = xr.DataArray(
        data=np.array(
            [
                1.731557546,  # M2
                0.0,  # S2
                6.050721243,  # N2
                3.487600001,  # K2
                0.173003674,  # K1
                1.558553872,  # O1
                6.110181633,  # P1
                5.877717569,  # Q1
                1.756042456,  # Mf
                1.964021610,  # Mm
                3.463115091,  # M4
                1.499093481,  # Mn4
                1.731557546,  # Ms4
                4.086699633,  # 2N2
                0.0,  # S1
            ]
        ),
        dims="nc",
    )

    return pf, pu, aa
712
+
713
+
714
def compute_equilibrium_tide(lon, lat):
    """
    Compute the complex equilibrium-tide amplitude at the given positions.

    Parameters
    ----------
    lon : xr.DataArray
        Longitudes in degrees.
    lat : xr.DataArray
        Latitudes in degrees.

    Returns
    -------
    tpc : xr.DataArray
        Complex equilibrium-tide amplitude with an "nc" dimension of length 15
        (one entry per constituent), broadcast against the dimensions of
        ``lon`` and ``lat``.

    Notes
    -----
    Each of the 15 constituents has an amplitude A, an elasticity factor B and
    a species type that selects the latitude dependence and the longitude phase:

    - 2 (semidiurnal): A*B*cos(lat)**2, phase -2*lon
    - 1 (diurnal):     A*B*sin(2*lat),  phase -lon
    - 0 (long-term):   A*B*(0.5 - 1.5*cos(lat)**2), phase 0
    """
    # (amplitude A, elasticity factor B, species type) per constituent, in the
    # "nc" order used throughout this module. M4, Mn4 and Ms4 are tagged as
    # type 0 but have zero amplitude, so they contribute nothing.
    constituent_table = [
        (0.242334, 0.693, 2),  # M2
        (0.112743, 0.693, 2),  # S2
        (0.046397, 0.693, 2),  # N2
        (0.030684, 0.693, 2),  # K2
        (0.141565, 0.736, 1),  # K1
        (0.100661, 0.695, 1),  # O1
        (0.046848, 0.706, 1),  # P1
        (0.019273, 0.695, 1),  # Q1
        (0.042041, 0.693, 0),  # Mf
        (0.022191, 0.693, 0),  # Mm
        (0.0, 0.693, 0),  # M4
        (0.0, 0.693, 0),  # Mn4
        (0.0, 0.693, 0),  # Ms4
        (0.006141, 0.693, 2),  # 2n2
        (0.000764, 0.693, 1),  # S1
    ]
    amplitudes, elasticities, species = zip(*constituent_table)
    A = xr.DataArray(data=np.array(amplitudes), dims="nc")
    B = xr.DataArray(data=np.array(elasticities), dims="nc")
    ityp = xr.DataArray(data=np.array(species), dims="nc")

    deg2rad = np.pi / 180
    coslat2 = np.cos(deg2rad * lat) ** 2
    sin2lat = np.sin(2 * deg2rad * lat)

    # 0/1 selectors per species along "nc".
    is_semidiurnal = xr.where(ityp == 2, 1, 0)
    is_diurnal = xr.where(ityp == 1, 1, 0)
    is_long_term = xr.where(ityp == 0, 1, 0)

    amplitude = (
        is_semidiurnal * A * B * coslat2
        + is_diurnal * A * B * sin2lat
        + is_long_term * A * B * (0.5 - 1.5 * coslat2)
    )
    phase = (
        is_semidiurnal * (-2 * lon * deg2rad)
        + is_diurnal * (-lon * deg2rad)
        + is_long_term * xr.zeros_like(lon)
    )

    return amplitude * np.exp(-1j * phase)